summaryrefslogtreecommitdiffstats
path: root/src/arrow/cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/arrow/cpp')
-rw-r--r--src/arrow/cpp/.gitignore43
-rw-r--r--src/arrow/cpp/Brewfile41
-rw-r--r--src/arrow/cpp/CHANGELOG_PARQUET.md501
-rw-r--r--src/arrow/cpp/CMakeLists.txt958
-rw-r--r--src/arrow/cpp/CMakeSettings.json21
-rw-r--r--src/arrow/cpp/README.md34
-rw-r--r--src/arrow/cpp/apidoc/.gitignore1
-rw-r--r--src/arrow/cpp/apidoc/Doxyfile2551
-rw-r--r--src/arrow/cpp/apidoc/HDFS.md83
-rw-r--r--src/arrow/cpp/apidoc/footer.html31
-rw-r--r--src/arrow/cpp/apidoc/tutorials/plasma.md450
-rw-r--r--src/arrow/cpp/apidoc/tutorials/tensor_to_py.md127
-rwxr-xr-xsrc/arrow/cpp/build-support/asan_symbolize.py368
-rwxr-xr-xsrc/arrow/cpp/build-support/build-lz4-lib.sh32
-rwxr-xr-xsrc/arrow/cpp/build-support/build-zstd-lib.sh25
-rwxr-xr-xsrc/arrow/cpp/build-support/cpplint.py6477
-rwxr-xr-xsrc/arrow/cpp/build-support/fuzzing/generate_corpuses.sh59
-rwxr-xr-xsrc/arrow/cpp/build-support/fuzzing/pack_corpus.py54
-rwxr-xr-xsrc/arrow/cpp/build-support/get-upstream-commit.sh25
-rw-r--r--src/arrow/cpp/build-support/iwyu/iwyu-filter.awk96
-rwxr-xr-xsrc/arrow/cpp/build-support/iwyu/iwyu.sh90
-rwxr-xr-xsrc/arrow/cpp/build-support/iwyu/iwyu_tool.py280
-rw-r--r--src/arrow/cpp/build-support/iwyu/mappings/arrow-misc.imp61
-rw-r--r--src/arrow/cpp/build-support/iwyu/mappings/boost-all-private.imp4166
-rw-r--r--src/arrow/cpp/build-support/iwyu/mappings/boost-all.imp5679
-rw-r--r--src/arrow/cpp/build-support/iwyu/mappings/boost-extra.imp23
-rw-r--r--src/arrow/cpp/build-support/iwyu/mappings/gflags.imp20
-rw-r--r--src/arrow/cpp/build-support/iwyu/mappings/glog.imp27
-rw-r--r--src/arrow/cpp/build-support/iwyu/mappings/gmock.imp23
-rw-r--r--src/arrow/cpp/build-support/iwyu/mappings/gtest.imp26
-rwxr-xr-xsrc/arrow/cpp/build-support/lint_cpp_cli.py128
-rw-r--r--src/arrow/cpp/build-support/lint_exclusions.txt15
-rw-r--r--src/arrow/cpp/build-support/lintutils.py109
-rw-r--r--src/arrow/cpp/build-support/lsan-suppressions.txt21
-rwxr-xr-xsrc/arrow/cpp/build-support/run-infer.sh48
-rwxr-xr-xsrc/arrow/cpp/build-support/run-test.sh237
-rwxr-xr-xsrc/arrow/cpp/build-support/run_clang_format.py137
-rwxr-xr-xsrc/arrow/cpp/build-support/run_clang_tidy.py124
-rwxr-xr-xsrc/arrow/cpp/build-support/run_cpplint.py132
-rw-r--r--src/arrow/cpp/build-support/sanitizer-disallowed-entries.txt25
-rwxr-xr-xsrc/arrow/cpp/build-support/stacktrace_addr2line.pl92
-rwxr-xr-xsrc/arrow/cpp/build-support/trim-boost.sh72
-rw-r--r--src/arrow/cpp/build-support/tsan-suppressions.txt19
-rw-r--r--src/arrow/cpp/build-support/ubsan-suppressions.txt16
-rwxr-xr-xsrc/arrow/cpp/build-support/update-flatbuffers.sh50
-rwxr-xr-xsrc/arrow/cpp/build-support/update-thrift.sh23
-rwxr-xr-xsrc/arrow/cpp/build-support/vendor-flatbuffers.sh31
-rw-r--r--src/arrow/cpp/cmake_modules/BuildUtils.cmake936
-rw-r--r--src/arrow/cpp/cmake_modules/DefineOptions.cmake589
-rw-r--r--src/arrow/cpp/cmake_modules/FindArrow.cmake466
-rw-r--r--src/arrow/cpp/cmake_modules/FindArrowCUDA.cmake88
-rw-r--r--src/arrow/cpp/cmake_modules/FindArrowDataset.cmake89
-rw-r--r--src/arrow/cpp/cmake_modules/FindArrowFlight.cmake89
-rw-r--r--src/arrow/cpp/cmake_modules/FindArrowFlightTesting.cmake98
-rw-r--r--src/arrow/cpp/cmake_modules/FindArrowPython.cmake87
-rw-r--r--src/arrow/cpp/cmake_modules/FindArrowPythonFlight.cmake94
-rw-r--r--src/arrow/cpp/cmake_modules/FindArrowTesting.cmake89
-rw-r--r--src/arrow/cpp/cmake_modules/FindBoostAlt.cmake63
-rw-r--r--src/arrow/cpp/cmake_modules/FindBrotli.cmake130
-rw-r--r--src/arrow/cpp/cmake_modules/FindClangTools.cmake106
-rw-r--r--src/arrow/cpp/cmake_modules/FindGLOG.cmake56
-rw-r--r--src/arrow/cpp/cmake_modules/FindGandiva.cmake94
-rw-r--r--src/arrow/cpp/cmake_modules/FindInferTools.cmake47
-rw-r--r--src/arrow/cpp/cmake_modules/FindLLVMAlt.cmake76
-rw-r--r--src/arrow/cpp/cmake_modules/FindLz4.cmake84
-rw-r--r--src/arrow/cpp/cmake_modules/FindNumPy.cmake96
-rw-r--r--src/arrow/cpp/cmake_modules/FindORC.cmake55
-rw-r--r--src/arrow/cpp/cmake_modules/FindOpenSSLAlt.cmake54
-rw-r--r--src/arrow/cpp/cmake_modules/FindParquet.cmake126
-rw-r--r--src/arrow/cpp/cmake_modules/FindPlasma.cmake102
-rw-r--r--src/arrow/cpp/cmake_modules/FindPython3Alt.cmake96
-rw-r--r--src/arrow/cpp/cmake_modules/FindPythonLibsNew.cmake267
-rw-r--r--src/arrow/cpp/cmake_modules/FindRapidJSONAlt.cmake72
-rw-r--r--src/arrow/cpp/cmake_modules/FindSnappy.cmake62
-rw-r--r--src/arrow/cpp/cmake_modules/FindThrift.cmake144
-rw-r--r--src/arrow/cpp/cmake_modules/Findc-aresAlt.cmake73
-rw-r--r--src/arrow/cpp/cmake_modules/FindgRPCAlt.cmake76
-rw-r--r--src/arrow/cpp/cmake_modules/FindgflagsAlt.cmake59
-rw-r--r--src/arrow/cpp/cmake_modules/Findjemalloc.cmake94
-rw-r--r--src/arrow/cpp/cmake_modules/Findre2Alt.cmake87
-rw-r--r--src/arrow/cpp/cmake_modules/Findutf8proc.cmake101
-rw-r--r--src/arrow/cpp/cmake_modules/Findzstd.cmake89
-rw-r--r--src/arrow/cpp/cmake_modules/SetupCxxFlags.cmake648
-rw-r--r--src/arrow/cpp/cmake_modules/ThirdpartyToolchain.cmake4063
-rw-r--r--src/arrow/cpp/cmake_modules/UseCython.cmake187
-rw-r--r--src/arrow/cpp/cmake_modules/Usevcpkg.cmake249
-rw-r--r--src/arrow/cpp/cmake_modules/san-config.cmake122
-rw-r--r--src/arrow/cpp/examples/arrow/CMakeLists.txt44
-rw-r--r--src/arrow/cpp/examples/arrow/compute_and_write_csv_example.cc113
-rw-r--r--src/arrow/cpp/examples/arrow/compute_register_example.cc168
-rw-r--r--src/arrow/cpp/examples/arrow/dataset_documentation_example.cc374
-rw-r--r--src/arrow/cpp/examples/arrow/dataset_parquet_scan_example.cc190
-rw-r--r--src/arrow/cpp/examples/arrow/row_wise_conversion_example.cc207
-rw-r--r--src/arrow/cpp/examples/minimal_build/.gitignore18
-rw-r--r--src/arrow/cpp/examples/minimal_build/CMakeLists.txt40
-rw-r--r--src/arrow/cpp/examples/minimal_build/README.md88
-rwxr-xr-xsrc/arrow/cpp/examples/minimal_build/build_arrow.sh35
-rwxr-xr-xsrc/arrow/cpp/examples/minimal_build/build_example.sh27
-rw-r--r--src/arrow/cpp/examples/minimal_build/docker-compose.yml51
-rw-r--r--src/arrow/cpp/examples/minimal_build/example.cc69
-rw-r--r--src/arrow/cpp/examples/minimal_build/minimal.dockerfile27
-rwxr-xr-xsrc/arrow/cpp/examples/minimal_build/run.sh48
-rw-r--r--src/arrow/cpp/examples/minimal_build/run_static.bat88
-rwxr-xr-xsrc/arrow/cpp/examples/minimal_build/run_static.sh121
-rw-r--r--src/arrow/cpp/examples/minimal_build/system_dependency.dockerfile44
-rw-r--r--src/arrow/cpp/examples/minimal_build/test.csv3
-rw-r--r--src/arrow/cpp/examples/parquet/CMakeLists.txt78
-rw-r--r--src/arrow/cpp/examples/parquet/low_level_api/encryption_reader_writer.cc451
-rw-r--r--src/arrow/cpp/examples/parquet/low_level_api/encryption_reader_writer_all_crypto_options.cc656
-rw-r--r--src/arrow/cpp/examples/parquet/low_level_api/reader_writer.cc413
-rw-r--r--src/arrow/cpp/examples/parquet/low_level_api/reader_writer.h70
-rw-r--r--src/arrow/cpp/examples/parquet/low_level_api/reader_writer2.cc434
-rw-r--r--src/arrow/cpp/examples/parquet/parquet_arrow/CMakeLists.txt42
-rw-r--r--src/arrow/cpp/examples/parquet/parquet_arrow/README.md20
-rw-r--r--src/arrow/cpp/examples/parquet/parquet_arrow/reader_writer.cc140
-rw-r--r--src/arrow/cpp/examples/parquet/parquet_stream_api/stream_reader_writer.cc324
-rw-r--r--src/arrow/cpp/src/arrow/ArrowConfig.cmake.in92
-rw-r--r--src/arrow/cpp/src/arrow/ArrowTestingConfig.cmake.in36
-rw-r--r--src/arrow/cpp/src/arrow/CMakeLists.txt751
-rw-r--r--src/arrow/cpp/src/arrow/adapters/orc/CMakeLists.txt57
-rw-r--r--src/arrow/cpp/src/arrow/adapters/orc/adapter.cc699
-rw-r--r--src/arrow/cpp/src/arrow/adapters/orc/adapter.h291
-rw-r--r--src/arrow/cpp/src/arrow/adapters/orc/adapter_test.cc686
-rw-r--r--src/arrow/cpp/src/arrow/adapters/orc/adapter_util.cc1069
-rw-r--r--src/arrow/cpp/src/arrow/adapters/orc/adapter_util.h57
-rw-r--r--src/arrow/cpp/src/arrow/adapters/orc/arrow-orc.pc.in24
-rw-r--r--src/arrow/cpp/src/arrow/adapters/tensorflow/CMakeLists.txt21
-rw-r--r--src/arrow/cpp/src/arrow/adapters/tensorflow/arrow-tensorflow.pc.in24
-rw-r--r--src/arrow/cpp/src/arrow/adapters/tensorflow/convert.h128
-rw-r--r--src/arrow/cpp/src/arrow/api.h44
-rw-r--r--src/arrow/cpp/src/arrow/array.h44
-rw-r--r--src/arrow/cpp/src/arrow/array/CMakeLists.txt26
-rw-r--r--src/arrow/cpp/src/arrow/array/README.md20
-rw-r--r--src/arrow/cpp/src/arrow/array/array_base.cc313
-rw-r--r--src/arrow/cpp/src/arrow/array/array_base.h260
-rw-r--r--src/arrow/cpp/src/arrow/array/array_binary.cc108
-rw-r--r--src/arrow/cpp/src/arrow/array/array_binary.h261
-rw-r--r--src/arrow/cpp/src/arrow/array/array_binary_test.cc900
-rw-r--r--src/arrow/cpp/src/arrow/array/array_decimal.cc63
-rw-r--r--src/arrow/cpp/src/arrow/array/array_decimal.h72
-rw-r--r--src/arrow/cpp/src/arrow/array/array_dict.cc442
-rw-r--r--src/arrow/cpp/src/arrow/array/array_dict.h180
-rw-r--r--src/arrow/cpp/src/arrow/array/array_dict_test.cc1678
-rw-r--r--src/arrow/cpp/src/arrow/array/array_list_test.cc1182
-rw-r--r--src/arrow/cpp/src/arrow/array/array_nested.cc763
-rw-r--r--src/arrow/cpp/src/arrow/array/array_nested.h533
-rw-r--r--src/arrow/cpp/src/arrow/array/array_primitive.cc133
-rw-r--r--src/arrow/cpp/src/arrow/array/array_primitive.h178
-rw-r--r--src/arrow/cpp/src/arrow/array/array_struct_test.cc699
-rw-r--r--src/arrow/cpp/src/arrow/array/array_test.cc3291
-rw-r--r--src/arrow/cpp/src/arrow/array/array_union_test.cc582
-rw-r--r--src/arrow/cpp/src/arrow/array/array_view_test.cc441
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_adaptive.cc380
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_adaptive.h203
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_base.cc336
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_base.h307
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_binary.cc207
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_binary.h697
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_decimal.cc105
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_decimal.h94
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_dict.cc213
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_dict.h712
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_nested.cc294
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_nested.h544
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_primitive.cc145
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_primitive.h519
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_time.h56
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_union.cc151
-rw-r--r--src/arrow/cpp/src/arrow/array/builder_union.h242
-rw-r--r--src/arrow/cpp/src/arrow/array/concatenate.cc510
-rw-r--r--src/arrow/cpp/src/arrow/array/concatenate.h37
-rw-r--r--src/arrow/cpp/src/arrow/array/concatenate_test.cc398
-rw-r--r--src/arrow/cpp/src/arrow/array/data.cc331
-rw-r--r--src/arrow/cpp/src/arrow/array/data.h258
-rw-r--r--src/arrow/cpp/src/arrow/array/dict_internal.h193
-rw-r--r--src/arrow/cpp/src/arrow/array/diff.cc794
-rw-r--r--src/arrow/cpp/src/arrow/array/diff.h76
-rw-r--r--src/arrow/cpp/src/arrow/array/diff_test.cc696
-rw-r--r--src/arrow/cpp/src/arrow/array/util.cc860
-rw-r--r--src/arrow/cpp/src/arrow/array/util.h78
-rw-r--r--src/arrow/cpp/src/arrow/array/validate.cc679
-rw-r--r--src/arrow/cpp/src/arrow/array/validate.h55
-rw-r--r--src/arrow/cpp/src/arrow/arrow-config.cmake26
-rw-r--r--src/arrow/cpp/src/arrow/arrow-testing.pc.in27
-rw-r--r--src/arrow/cpp/src/arrow/arrow.pc.in31
-rw-r--r--src/arrow/cpp/src/arrow/buffer.cc207
-rw-r--r--src/arrow/cpp/src/arrow/buffer.h499
-rw-r--r--src/arrow/cpp/src/arrow/buffer_builder.h459
-rw-r--r--src/arrow/cpp/src/arrow/buffer_test.cc926
-rw-r--r--src/arrow/cpp/src/arrow/builder.cc312
-rw-r--r--src/arrow/cpp/src/arrow/builder.h32
-rw-r--r--src/arrow/cpp/src/arrow/builder_benchmark.cc453
-rw-r--r--src/arrow/cpp/src/arrow/c/CMakeLists.txt22
-rw-r--r--src/arrow/cpp/src/arrow/c/abi.h103
-rw-r--r--src/arrow/cpp/src/arrow/c/bridge.cc1818
-rw-r--r--src/arrow/cpp/src/arrow/c/bridge.h197
-rw-r--r--src/arrow/cpp/src/arrow/c/bridge_benchmark.cc159
-rw-r--r--src/arrow/cpp/src/arrow/c/bridge_test.cc3226
-rw-r--r--src/arrow/cpp/src/arrow/c/helpers.h117
-rw-r--r--src/arrow/cpp/src/arrow/c/util_internal.h85
-rw-r--r--src/arrow/cpp/src/arrow/chunked_array.cc304
-rw-r--r--src/arrow/cpp/src/arrow/chunked_array.h255
-rw-r--r--src/arrow/cpp/src/arrow/chunked_array_test.cc266
-rw-r--r--src/arrow/cpp/src/arrow/compare.cc1300
-rw-r--r--src/arrow/cpp/src/arrow/compare.h133
-rw-r--r--src/arrow/cpp/src/arrow/compare_benchmark.cc164
-rw-r--r--src/arrow/cpp/src/arrow/compute/CMakeLists.txt72
-rw-r--r--src/arrow/cpp/src/arrow/compute/README.md58
-rw-r--r--src/arrow/cpp/src/arrow/compute/api.h35
-rw-r--r--src/arrow/cpp/src/arrow/compute/api_aggregate.cc250
-rw-r--r--src/arrow/cpp/src/arrow/compute/api_aggregate.h494
-rw-r--r--src/arrow/cpp/src/arrow/compute/api_scalar.cc676
-rw-r--r--src/arrow/cpp/src/arrow/compute/api_scalar.h1219
-rw-r--r--src/arrow/cpp/src/arrow/compute/api_vector.cc328
-rw-r--r--src/arrow/cpp/src/arrow/compute/api_vector.h506
-rw-r--r--src/arrow/cpp/src/arrow/compute/arrow-compute.pc.in21
-rw-r--r--src/arrow/cpp/src/arrow/compute/cast.cc273
-rw-r--r--src/arrow/cpp/src/arrow/compute/cast.h167
-rw-r--r--src/arrow/cpp/src/arrow/compute/cast_internal.h43
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec.cc1061
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec.h268
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/CMakeLists.txt33
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/aggregate_node.cc644
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_1.jpgbin0 -> 53790 bytes
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_10.jpgbin0 -> 69625 bytes
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_11.jpgbin0 -> 60687 bytes
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_2.jpgbin0 -> 43971 bytes
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_3.jpgbin0 -> 59985 bytes
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_4.jpgbin0 -> 56289 bytes
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_5.jpgbin0 -> 61950 bytes
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_6.jpgbin0 -> 43687 bytes
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_7.jpgbin0 -> 43687 bytes
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_8.jpgbin0 -> 48054 bytes
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_9.jpgbin0 -> 52894 bytes
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/doc/key_map.md223
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/exec_plan.cc523
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/exec_plan.h422
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/expression.cc1192
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/expression.h269
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/expression_benchmark.cc88
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/expression_internal.h336
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/expression_test.cc1414
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/filter_node.cc116
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/forest_internal.h125
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/hash_join.cc795
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/hash_join.h95
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/hash_join_dict.cc665
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/hash_join_dict.h315
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/hash_join_node.cc469
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/hash_join_node_test.cc1693
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/key_compare.cc424
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/key_compare.h137
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/key_compare_avx2.cc633
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/key_encode.cc1341
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/key_encode.h567
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/key_encode_avx2.cc241
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/key_hash.cc319
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/key_hash.h106
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/key_hash_avx2.cc268
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/key_map.cc862
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/key_map.h206
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/key_map_avx2.cc414
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/options.h265
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/order_by_impl.cc104
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/order_by_impl.h53
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/plan_test.cc1226
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/project_node.cc127
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/schema_util.h209
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/sink_node.cc341
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/source_node.cc182
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/subtree_internal.h178
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/subtree_test.cc377
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/task_util.cc409
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/task_util.h100
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/test_util.cc239
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/test_util.h107
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/union_node.cc154
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/union_node_test.cc150
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/util.cc336
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/util.h277
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/util_avx2.cc221
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec/util_test.cc131
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec_internal.h145
-rw-r--r--src/arrow/cpp/src/arrow/compute/exec_test.cc891
-rw-r--r--src/arrow/cpp/src/arrow/compute/function.cc339
-rw-r--r--src/arrow/cpp/src/arrow/compute/function.h395
-rw-r--r--src/arrow/cpp/src/arrow/compute/function_benchmark.cc218
-rw-r--r--src/arrow/cpp/src/arrow/compute/function_internal.cc113
-rw-r--r--src/arrow/cpp/src/arrow/compute/function_internal.h648
-rw-r--r--src/arrow/cpp/src/arrow/compute/function_test.cc351
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernel.cc507
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernel.h752
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernel_test.cc516
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/CMakeLists.txt78
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic.cc1011
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc88
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc90
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h626
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc752
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/aggregate_internal.h223
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/aggregate_mode.cc419
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/aggregate_quantile.cc513
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc235
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/aggregate_test.cc3670
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/aggregate_var_std.cc298
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/aggregate_var_std_internal.h68
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/chunked_internal.h167
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/codegen_internal.cc420
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h1353
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/codegen_internal_test.cc163
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/common.h54
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/hash_aggregate.cc2659
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc2612
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/row_encoder.cc360
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/row_encoder.h267
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc2609
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic_benchmark.cc159
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc3174
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_boolean.cc563
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_boolean_benchmark.cc59
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_boolean_test.cc159
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_benchmark.cc117
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc70
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc126
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc299
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.h88
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc188
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc784
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_string.cc374
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc598
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_test.cc2334
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_compare.cc540
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_compare_benchmark.cc81
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_compare_test.cc1388
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_if_else.cc2912
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc457
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc2922
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_nested.cc317
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_nested_test.cc249
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc532
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup_benchmark.cc143
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc992
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_string.cc4490
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc240
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_string_test.cc1739
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_temporal_binary.cc542
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc1330
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc1158
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_validity.cc286
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/scalar_validity_test.cc206
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/select_k_test.cc716
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/temporal_internal.h269
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/test_util.cc362
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/test_util.h241
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/util_internal.cc82
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/util_internal.h166
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_array_sort.cc561
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_hash.cc807
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_hash_benchmark.cc250
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_hash_test.cc756
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_nested.cc194
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_nested_test.cc132
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_partition_benchmark.cc59
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_replace.cc541
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc89
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_replace_test.cc677
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_selection.cc2442
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc354
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_selection_test.cc2332
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_sort.cc1902
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc305
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_sort_internal.h457
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_sort_test.cc1925
-rw-r--r--src/arrow/cpp/src/arrow/compute/kernels/vector_topk_benchmark.cc59
-rw-r--r--src/arrow/cpp/src/arrow/compute/registry.cc200
-rw-r--r--src/arrow/cpp/src/arrow/compute/registry.h93
-rw-r--r--src/arrow/cpp/src/arrow/compute/registry_internal.h64
-rw-r--r--src/arrow/cpp/src/arrow/compute/registry_test.cc87
-rw-r--r--src/arrow/cpp/src/arrow/compute/type_fwd.h50
-rw-r--r--src/arrow/cpp/src/arrow/config.cc78
-rw-r--r--src/arrow/cpp/src/arrow/config.h72
-rw-r--r--src/arrow/cpp/src/arrow/csv/CMakeLists.txt39
-rw-r--r--src/arrow/cpp/src/arrow/csv/api.h26
-rw-r--r--src/arrow/cpp/src/arrow/csv/arrow-csv.pc.in24
-rw-r--r--src/arrow/cpp/src/arrow/csv/chunker.cc311
-rw-r--r--src/arrow/cpp/src/arrow/csv/chunker.h36
-rw-r--r--src/arrow/cpp/src/arrow/csv/chunker_test.cc372
-rw-r--r--src/arrow/cpp/src/arrow/csv/column_builder.cc367
-rw-r--r--src/arrow/cpp/src/arrow/csv/column_builder.h78
-rw-r--r--src/arrow/cpp/src/arrow/csv/column_builder_test.cc608
-rw-r--r--src/arrow/cpp/src/arrow/csv/column_decoder.cc250
-rw-r--r--src/arrow/cpp/src/arrow/csv/column_decoder.h64
-rw-r--r--src/arrow/cpp/src/arrow/csv/column_decoder_test.cc385
-rw-r--r--src/arrow/cpp/src/arrow/csv/converter.cc780
-rw-r--r--src/arrow/cpp/src/arrow/csv/converter.h82
-rw-r--r--src/arrow/cpp/src/arrow/csv/converter_benchmark.cc152
-rw-r--r--src/arrow/cpp/src/arrow/csv/converter_test.cc818
-rw-r--r--src/arrow/cpp/src/arrow/csv/inference_internal.h155
-rw-r--r--src/arrow/cpp/src/arrow/csv/invalid_row.h56
-rw-r--r--src/arrow/cpp/src/arrow/csv/options.cc83
-rw-r--r--src/arrow/cpp/src/arrow/csv/options.h194
-rw-r--r--src/arrow/cpp/src/arrow/csv/parser.cc608
-rw-r--r--src/arrow/cpp/src/arrow/csv/parser.h227
-rw-r--r--src/arrow/cpp/src/arrow/csv/parser_benchmark.cc205
-rw-r--r--src/arrow/cpp/src/arrow/csv/parser_test.cc805
-rw-r--r--src/arrow/cpp/src/arrow/csv/reader.cc1303
-rw-r--r--src/arrow/cpp/src/arrow/csv/reader.h125
-rw-r--r--src/arrow/cpp/src/arrow/csv/reader_test.cc490
-rw-r--r--src/arrow/cpp/src/arrow/csv/test_common.cc121
-rw-r--r--src/arrow/cpp/src/arrow/csv/test_common.h55
-rw-r--r--src/arrow/cpp/src/arrow/csv/type_fwd.h28
-rw-r--r--src/arrow/cpp/src/arrow/csv/writer.cc460
-rw-r--r--src/arrow/cpp/src/arrow/csv/writer.h73
-rw-r--r--src/arrow/cpp/src/arrow/csv/writer_test.cc159
-rw-r--r--src/arrow/cpp/src/arrow/dataset/ArrowDatasetConfig.cmake.in37
-rw-r--r--src/arrow/cpp/src/arrow/dataset/CMakeLists.txt145
-rw-r--r--src/arrow/cpp/src/arrow/dataset/README.md32
-rw-r--r--src/arrow/cpp/src/arrow/dataset/api.h30
-rw-r--r--src/arrow/cpp/src/arrow/dataset/arrow-dataset.pc.in25
-rw-r--r--src/arrow/cpp/src/arrow/dataset/dataset.cc269
-rw-r--r--src/arrow/cpp/src/arrow/dataset/dataset.h264
-rw-r--r--src/arrow/cpp/src/arrow/dataset/dataset_internal.h160
-rw-r--r--src/arrow/cpp/src/arrow/dataset/dataset_test.cc734
-rw-r--r--src/arrow/cpp/src/arrow/dataset/dataset_writer.cc529
-rw-r--r--src/arrow/cpp/src/arrow/dataset/dataset_writer.h97
-rw-r--r--src/arrow/cpp/src/arrow/dataset/dataset_writer_test.cc349
-rw-r--r--src/arrow/cpp/src/arrow/dataset/discovery.cc282
-rw-r--r--src/arrow/cpp/src/arrow/dataset/discovery.h271
-rw-r--r--src/arrow/cpp/src/arrow/dataset/discovery_test.cc479
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_base.cc466
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_base.h421
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_benchmark.cc90
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_csv.cc335
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_csv.h123
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_csv_test.cc404
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_ipc.cc310
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_ipc.h125
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_ipc_test.cc173
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_orc.cc193
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_orc.h79
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_orc_test.cc85
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_parquet.cc974
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_parquet.h385
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_parquet_test.cc612
-rw-r--r--src/arrow/cpp/src/arrow/dataset/file_test.cc346
-rw-r--r--src/arrow/cpp/src/arrow/dataset/partition.cc732
-rw-r--r--src/arrow/cpp/src/arrow/dataset/partition.h372
-rw-r--r--src/arrow/cpp/src/arrow/dataset/partition_test.cc836
-rw-r--r--src/arrow/cpp/src/arrow/dataset/pch.h27
-rw-r--r--src/arrow/cpp/src/arrow/dataset/plan.cc39
-rw-r--r--src/arrow/cpp/src/arrow/dataset/plan.h33
-rw-r--r--src/arrow/cpp/src/arrow/dataset/projector.cc63
-rw-r--r--src/arrow/cpp/src/arrow/dataset/projector.h32
-rw-r--r--src/arrow/cpp/src/arrow/dataset/scanner.cc1347
-rw-r--r--src/arrow/cpp/src/arrow/dataset/scanner.h458
-rw-r--r--src/arrow/cpp/src/arrow/dataset/scanner_benchmark.cc210
-rw-r--r--src/arrow/cpp/src/arrow/dataset/scanner_internal.h264
-rw-r--r--src/arrow/cpp/src/arrow/dataset/scanner_test.cc1814
-rw-r--r--src/arrow/cpp/src/arrow/dataset/test_util.h1300
-rw-r--r--src/arrow/cpp/src/arrow/dataset/type_fwd.h104
-rw-r--r--src/arrow/cpp/src/arrow/dataset/visibility.h50
-rw-r--r--src/arrow/cpp/src/arrow/datum.cc292
-rw-r--r--src/arrow/cpp/src/arrow/datum.h281
-rw-r--r--src/arrow/cpp/src/arrow/datum_test.cc172
-rw-r--r--src/arrow/cpp/src/arrow/dbi/README.md24
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/CMakeLists.txt116
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/api.h27
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/columnar_row_set.cc100
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/columnar_row_set.h155
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/hiveserver2_test.cc458
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/operation.cc150
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/operation.h127
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/public_api_test.cc26
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/sample_usage.cc137
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/service.cc110
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/service.h140
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/session.cc103
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/session.h84
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/.gitignore1
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/CMakeLists.txt120
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/ExecStats.thrift103
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/ImpalaService.thrift300
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/Status.thrift23
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/TCLIService.thrift1180
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/Types.thrift218
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/beeswax.thrift174
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/fb303.thrift112
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/generate_error_codes.py293
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/hive_metastore.thrift1214
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift_internal.cc301
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift_internal.h91
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/types.cc45
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/types.h131
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/util.cc250
-rw-r--r--src/arrow/cpp/src/arrow/dbi/hiveserver2/util.h36
-rw-r--r--src/arrow/cpp/src/arrow/device.cc209
-rw-r--r--src/arrow/cpp/src/arrow/device.h226
-rw-r--r--src/arrow/cpp/src/arrow/extension_type.cc169
-rw-r--r--src/arrow/cpp/src/arrow/extension_type.h161
-rw-r--r--src/arrow/cpp/src/arrow/extension_type_test.cc336
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/CMakeLists.txt79
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/api.h28
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/arrow-filesystem.pc.in24
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/filesystem.cc767
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/filesystem.h535
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/filesystem_test.cc825
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/gcsfs.cc269
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/gcsfs.h118
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/gcsfs_internal.cc67
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/gcsfs_internal.h36
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/gcsfs_test.cc264
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/hdfs.cc518
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/hdfs.h113
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/hdfs_test.cc356
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/localfs.cc448
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/localfs.h113
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/localfs_test.cc396
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/mockfs.cc778
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/mockfs.h132
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/path_util.cc271
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/path_util.h130
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/s3_internal.h215
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/s3_test_util.h153
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/s3fs.cc2453
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/s3fs.h315
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/s3fs_benchmark.cc432
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/s3fs_narrative_test.cc245
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/s3fs_test.cc1084
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/test_util.cc1135
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/test_util.h246
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/type_fwd.h49
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/util_internal.cc73
-rw-r--r--src/arrow/cpp/src/arrow/filesystem/util_internal.h56
-rw-r--r--src/arrow/cpp/src/arrow/flight/ArrowFlightConfig.cmake.in36
-rw-r--r--src/arrow/cpp/src/arrow/flight/ArrowFlightTestingConfig.cmake.in37
-rw-r--r--src/arrow/cpp/src/arrow/flight/CMakeLists.txt267
-rw-r--r--src/arrow/cpp/src/arrow/flight/README.md36
-rw-r--r--src/arrow/cpp/src/arrow/flight/api.h27
-rw-r--r--src/arrow/cpp/src/arrow/flight/arrow-flight-testing.pc.in25
-rw-r--r--src/arrow/cpp/src/arrow/flight/arrow-flight.pc.in25
-rw-r--r--src/arrow/cpp/src/arrow/flight/client.cc1355
-rw-r--r--src/arrow/cpp/src/arrow/flight/client.h330
-rw-r--r--src/arrow/cpp/src/arrow/flight/client_auth.h62
-rw-r--r--src/arrow/cpp/src/arrow/flight/client_cookie_middleware.cc65
-rw-r--r--src/arrow/cpp/src/arrow/flight/client_cookie_middleware.h33
-rw-r--r--src/arrow/cpp/src/arrow/flight/client_header_internal.cc337
-rw-r--r--src/arrow/cpp/src/arrow/flight/client_header_internal.h151
-rw-r--r--src/arrow/cpp/src/arrow/flight/client_middleware.h73
-rw-r--r--src/arrow/cpp/src/arrow/flight/customize_protobuf.h108
-rw-r--r--src/arrow/cpp/src/arrow/flight/flight_benchmark.cc493
-rw-r--r--src/arrow/cpp/src/arrow/flight/flight_test.cc2872
-rw-r--r--src/arrow/cpp/src/arrow/flight/internal.cc514
-rw-r--r--src/arrow/cpp/src/arrow/flight/internal.h128
-rw-r--r--src/arrow/cpp/src/arrow/flight/middleware.h73
-rw-r--r--src/arrow/cpp/src/arrow/flight/middleware_internal.h46
-rw-r--r--src/arrow/cpp/src/arrow/flight/pch.h26
-rw-r--r--src/arrow/cpp/src/arrow/flight/perf.proto44
-rw-r--r--src/arrow/cpp/src/arrow/flight/perf_server.cc285
-rw-r--r--src/arrow/cpp/src/arrow/flight/platform.h32
-rw-r--r--src/arrow/cpp/src/arrow/flight/protocol_internal.cc26
-rw-r--r--src/arrow/cpp/src/arrow/flight/protocol_internal.h28
-rw-r--r--src/arrow/cpp/src/arrow/flight/serialization_internal.cc474
-rw-r--r--src/arrow/cpp/src/arrow/flight/serialization_internal.h152
-rw-r--r--src/arrow/cpp/src/arrow/flight/server.cc1165
-rw-r--r--src/arrow/cpp/src/arrow/flight/server.h285
-rw-r--r--src/arrow/cpp/src/arrow/flight/server_auth.cc37
-rw-r--r--src/arrow/cpp/src/arrow/flight/server_auth.h78
-rw-r--r--src/arrow/cpp/src/arrow/flight/server_middleware.h83
-rw-r--r--src/arrow/cpp/src/arrow/flight/test_integration.cc270
-rw-r--r--src/arrow/cpp/src/arrow/flight/test_integration.h49
-rw-r--r--src/arrow/cpp/src/arrow/flight/test_integration_client.cc244
-rw-r--r--src/arrow/cpp/src/arrow/flight/test_integration_server.cc207
-rw-r--r--src/arrow/cpp/src/arrow/flight/test_server.cc62
-rw-r--r--src/arrow/cpp/src/arrow/flight/test_util.cc822
-rw-r--r--src/arrow/cpp/src/arrow/flight/test_util.h242
-rw-r--r--src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_127.cc36
-rw-r--r--src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_132.cc36
-rw-r--r--src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_134.cc44
-rw-r--r--src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_136.cc38
-rw-r--r--src/arrow/cpp/src/arrow/flight/types.cc378
-rw-r--r--src/arrow/cpp/src/arrow/flight/types.h529
-rw-r--r--src/arrow/cpp/src/arrow/flight/visibility.h48
-rw-r--r--src/arrow/cpp/src/arrow/gpu/.gitignore18
-rw-r--r--src/arrow/cpp/src/arrow/gpu/ArrowCUDAConfig.cmake.in36
-rw-r--r--src/arrow/cpp/src/arrow/gpu/CMakeLists.txt88
-rw-r--r--src/arrow/cpp/src/arrow/gpu/arrow-cuda.pc.in26
-rw-r--r--src/arrow/cpp/src/arrow/gpu/cuda_api.h23
-rw-r--r--src/arrow/cpp/src/arrow/gpu/cuda_arrow_ipc.cc69
-rw-r--r--src/arrow/cpp/src/arrow/gpu/cuda_arrow_ipc.h72
-rw-r--r--src/arrow/cpp/src/arrow/gpu/cuda_benchmark.cc94
-rw-r--r--src/arrow/cpp/src/arrow/gpu/cuda_context.cc646
-rw-r--r--src/arrow/cpp/src/arrow/gpu/cuda_context.h310
-rw-r--r--src/arrow/cpp/src/arrow/gpu/cuda_internal.cc66
-rw-r--r--src/arrow/cpp/src/arrow/gpu/cuda_internal.h60
-rw-r--r--src/arrow/cpp/src/arrow/gpu/cuda_memory.cc484
-rw-r--r--src/arrow/cpp/src/arrow/gpu/cuda_memory.h260
-rw-r--r--src/arrow/cpp/src/arrow/gpu/cuda_test.cc626
-rw-r--r--src/arrow/cpp/src/arrow/gpu/cuda_version.h.in25
-rw-r--r--src/arrow/cpp/src/arrow/io/CMakeLists.txt39
-rw-r--r--src/arrow/cpp/src/arrow/io/api.h25
-rw-r--r--src/arrow/cpp/src/arrow/io/buffered.cc489
-rw-r--r--src/arrow/cpp/src/arrow/io/buffered.h167
-rw-r--r--src/arrow/cpp/src/arrow/io/buffered_test.cc667
-rw-r--r--src/arrow/cpp/src/arrow/io/caching.cc318
-rw-r--r--src/arrow/cpp/src/arrow/io/caching.h138
-rw-r--r--src/arrow/cpp/src/arrow/io/compressed.cc450
-rw-r--r--src/arrow/cpp/src/arrow/io/compressed.h118
-rw-r--r--src/arrow/cpp/src/arrow/io/compressed_test.cc311
-rw-r--r--src/arrow/cpp/src/arrow/io/concurrency.h263
-rw-r--r--src/arrow/cpp/src/arrow/io/file.cc789
-rw-r--r--src/arrow/cpp/src/arrow/io/file.h221
-rw-r--r--src/arrow/cpp/src/arrow/io/file_benchmark.cc301
-rw-r--r--src/arrow/cpp/src/arrow/io/file_test.cc1064
-rw-r--r--src/arrow/cpp/src/arrow/io/hdfs.cc738
-rw-r--r--src/arrow/cpp/src/arrow/io/hdfs.h284
-rw-r--r--src/arrow/cpp/src/arrow/io/hdfs_internal.cc556
-rw-r--r--src/arrow/cpp/src/arrow/io/hdfs_internal.h222
-rw-r--r--src/arrow/cpp/src/arrow/io/hdfs_test.cc464
-rw-r--r--src/arrow/cpp/src/arrow/io/interfaces.cc469
-rw-r--r--src/arrow/cpp/src/arrow/io/interfaces.h340
-rw-r--r--src/arrow/cpp/src/arrow/io/memory.cc388
-rw-r--r--src/arrow/cpp/src/arrow/io/memory.h197
-rw-r--r--src/arrow/cpp/src/arrow/io/memory_benchmark.cc359
-rw-r--r--src/arrow/cpp/src/arrow/io/memory_test.cc883
-rw-r--r--src/arrow/cpp/src/arrow/io/mman.h169
-rw-r--r--src/arrow/cpp/src/arrow/io/slow.cc148
-rw-r--r--src/arrow/cpp/src/arrow/io/slow.h118
-rw-r--r--src/arrow/cpp/src/arrow/io/stdio.cc95
-rw-r--r--src/arrow/cpp/src/arrow/io/stdio.h82
-rw-r--r--src/arrow/cpp/src/arrow/io/test_common.cc121
-rw-r--r--src/arrow/cpp/src/arrow/io/test_common.h58
-rw-r--r--src/arrow/cpp/src/arrow/io/transform.cc162
-rw-r--r--src/arrow/cpp/src/arrow/io/transform.h60
-rw-r--r--src/arrow/cpp/src/arrow/io/type_fwd.h79
-rw-r--r--src/arrow/cpp/src/arrow/io/util_internal.h66
-rw-r--r--src/arrow/cpp/src/arrow/ipc/CMakeLists.txt87
-rw-r--r--src/arrow/cpp/src/arrow/ipc/api.h25
-rw-r--r--src/arrow/cpp/src/arrow/ipc/dictionary.cc412
-rw-r--r--src/arrow/cpp/src/arrow/ipc/dictionary.h177
-rw-r--r--src/arrow/cpp/src/arrow/ipc/feather.cc819
-rw-r--r--src/arrow/cpp/src/arrow/ipc/feather.fbs156
-rw-r--r--src/arrow/cpp/src/arrow/ipc/feather.h140
-rw-r--r--src/arrow/cpp/src/arrow/ipc/feather_test.cc373
-rw-r--r--src/arrow/cpp/src/arrow/ipc/file_fuzz.cc28
-rw-r--r--src/arrow/cpp/src/arrow/ipc/file_to_stream.cc64
-rw-r--r--src/arrow/cpp/src/arrow/ipc/generate_fuzz_corpus.cc161
-rw-r--r--src/arrow/cpp/src/arrow/ipc/generate_tensor_fuzz_corpus.cc134
-rw-r--r--src/arrow/cpp/src/arrow/ipc/json_simple.cc994
-rw-r--r--src/arrow/cpp/src/arrow/ipc/json_simple.h66
-rw-r--r--src/arrow/cpp/src/arrow/ipc/json_simple_test.cc1415
-rw-r--r--src/arrow/cpp/src/arrow/ipc/message.cc931
-rw-r--r--src/arrow/cpp/src/arrow/ipc/message.h536
-rw-r--r--src/arrow/cpp/src/arrow/ipc/metadata_internal.cc1497
-rw-r--r--src/arrow/cpp/src/arrow/ipc/metadata_internal.h228
-rw-r--r--src/arrow/cpp/src/arrow/ipc/options.cc41
-rw-r--r--src/arrow/cpp/src/arrow/ipc/options.h160
-rw-r--r--src/arrow/cpp/src/arrow/ipc/read_write_benchmark.cc262
-rw-r--r--src/arrow/cpp/src/arrow/ipc/read_write_test.cc2415
-rw-r--r--src/arrow/cpp/src/arrow/ipc/reader.cc2095
-rw-r--r--src/arrow/cpp/src/arrow/ipc/reader.h536
-rw-r--r--src/arrow/cpp/src/arrow/ipc/stream_fuzz.cc28
-rw-r--r--src/arrow/cpp/src/arrow/ipc/stream_to_file.cc60
-rw-r--r--src/arrow/cpp/src/arrow/ipc/tensor_stream_fuzz.cc29
-rw-r--r--src/arrow/cpp/src/arrow/ipc/tensor_test.cc506
-rw-r--r--src/arrow/cpp/src/arrow/ipc/test_common.cc1125
-rw-r--r--src/arrow/cpp/src/arrow/ipc/test_common.h175
-rw-r--r--src/arrow/cpp/src/arrow/ipc/type_fwd.h65
-rw-r--r--src/arrow/cpp/src/arrow/ipc/util.h41
-rw-r--r--src/arrow/cpp/src/arrow/ipc/writer.cc1429
-rw-r--r--src/arrow/cpp/src/arrow/ipc/writer.h459
-rw-r--r--src/arrow/cpp/src/arrow/json/CMakeLists.txt32
-rw-r--r--src/arrow/cpp/src/arrow/json/api.h21
-rw-r--r--src/arrow/cpp/src/arrow/json/arrow-json.pc.in24
-rw-r--r--src/arrow/cpp/src/arrow/json/chunked_builder.cc470
-rw-r--r--src/arrow/cpp/src/arrow/json/chunked_builder.h68
-rw-r--r--src/arrow/cpp/src/arrow/json/chunked_builder_test.cc450
-rw-r--r--src/arrow/cpp/src/arrow/json/chunker.cc186
-rw-r--r--src/arrow/cpp/src/arrow/json/chunker.h35
-rw-r--r--src/arrow/cpp/src/arrow/json/chunker_test.cc276
-rw-r--r--src/arrow/cpp/src/arrow/json/converter.cc362
-rw-r--r--src/arrow/cpp/src/arrow/json/converter.h94
-rw-r--r--src/arrow/cpp/src/arrow/json/converter_test.cc214
-rw-r--r--src/arrow/cpp/src/arrow/json/object_parser.cc83
-rw-r--r--src/arrow/cpp/src/arrow/json/object_parser.h49
-rw-r--r--src/arrow/cpp/src/arrow/json/object_writer.cc82
-rw-r--r--src/arrow/cpp/src/arrow/json/object_writer.h48
-rw-r--r--src/arrow/cpp/src/arrow/json/options.cc28
-rw-r--r--src/arrow/cpp/src/arrow/json/options.h74
-rw-r--r--src/arrow/cpp/src/arrow/json/parser.cc1107
-rw-r--r--src/arrow/cpp/src/arrow/json/parser.h101
-rw-r--r--src/arrow/cpp/src/arrow/json/parser_benchmark.cc164
-rw-r--r--src/arrow/cpp/src/arrow/json/parser_test.cc265
-rw-r--r--src/arrow/cpp/src/arrow/json/rapidjson_defs.h43
-rw-r--r--src/arrow/cpp/src/arrow/json/reader.cc218
-rw-r--r--src/arrow/cpp/src/arrow/json/reader.h64
-rw-r--r--src/arrow/cpp/src/arrow/json/reader_test.cc278
-rw-r--r--src/arrow/cpp/src/arrow/json/test_common.h260
-rw-r--r--src/arrow/cpp/src/arrow/json/type_fwd.h26
-rw-r--r--src/arrow/cpp/src/arrow/memory_pool.cc797
-rw-r--r--src/arrow/cpp/src/arrow/memory_pool.h185
-rw-r--r--src/arrow/cpp/src/arrow/memory_pool_benchmark.cc129
-rw-r--r--src/arrow/cpp/src/arrow/memory_pool_test.cc174
-rw-r--r--src/arrow/cpp/src/arrow/memory_pool_test.h92
-rw-r--r--src/arrow/cpp/src/arrow/pch.h30
-rw-r--r--src/arrow/cpp/src/arrow/pretty_print.cc646
-rw-r--r--src/arrow/cpp/src/arrow/pretty_print.h125
-rw-r--r--src/arrow/cpp/src/arrow/pretty_print_test.cc1081
-rw-r--r--src/arrow/cpp/src/arrow/public_api_test.cc93
-rw-r--r--src/arrow/cpp/src/arrow/python/ArrowPythonConfig.cmake.in36
-rw-r--r--src/arrow/cpp/src/arrow/python/ArrowPythonFlightConfig.cmake.in37
-rw-r--r--src/arrow/cpp/src/arrow/python/CMakeLists.txt184
-rw-r--r--src/arrow/cpp/src/arrow/python/api.h30
-rw-r--r--src/arrow/cpp/src/arrow/python/arrow-python-flight.pc.in25
-rw-r--r--src/arrow/cpp/src/arrow/python/arrow-python.pc.in26
-rw-r--r--src/arrow/cpp/src/arrow/python/arrow_to_pandas.cc2322
-rw-r--r--src/arrow/cpp/src/arrow/python/arrow_to_pandas.h124
-rw-r--r--src/arrow/cpp/src/arrow/python/arrow_to_python_internal.h49
-rw-r--r--src/arrow/cpp/src/arrow/python/benchmark.cc38
-rw-r--r--src/arrow/cpp/src/arrow/python/benchmark.h36
-rw-r--r--src/arrow/cpp/src/arrow/python/common.cc203
-rw-r--r--src/arrow/cpp/src/arrow/python/common.h360
-rw-r--r--src/arrow/cpp/src/arrow/python/datetime.cc566
-rw-r--r--src/arrow/cpp/src/arrow/python/datetime.h211
-rw-r--r--src/arrow/cpp/src/arrow/python/decimal.cc246
-rw-r--r--src/arrow/cpp/src/arrow/python/decimal.h128
-rw-r--r--src/arrow/cpp/src/arrow/python/deserialize.cc495
-rw-r--r--src/arrow/cpp/src/arrow/python/deserialize.h106
-rw-r--r--src/arrow/cpp/src/arrow/python/extension_type.cc217
-rw-r--r--src/arrow/cpp/src/arrow/python/extension_type.h85
-rw-r--r--src/arrow/cpp/src/arrow/python/filesystem.cc206
-rw-r--r--src/arrow/cpp/src/arrow/python/filesystem.h126
-rw-r--r--src/arrow/cpp/src/arrow/python/flight.cc408
-rw-r--r--src/arrow/cpp/src/arrow/python/flight.h357
-rw-r--r--src/arrow/cpp/src/arrow/python/helpers.cc470
-rw-r--r--src/arrow/cpp/src/arrow/python/helpers.h159
-rw-r--r--src/arrow/cpp/src/arrow/python/inference.cc723
-rw-r--r--src/arrow/cpp/src/arrow/python/inference.h64
-rw-r--r--src/arrow/cpp/src/arrow/python/init.cc24
-rw-r--r--src/arrow/cpp/src/arrow/python/init.h26
-rw-r--r--src/arrow/cpp/src/arrow/python/io.cc374
-rw-r--r--src/arrow/cpp/src/arrow/python/io.h116
-rw-r--r--src/arrow/cpp/src/arrow/python/ipc.cc67
-rw-r--r--src/arrow/cpp/src/arrow/python/ipc.h52
-rw-r--r--src/arrow/cpp/src/arrow/python/iterators.h194
-rw-r--r--src/arrow/cpp/src/arrow/python/numpy_convert.cc562
-rw-r--r--src/arrow/cpp/src/arrow/python/numpy_convert.h120
-rw-r--r--src/arrow/cpp/src/arrow/python/numpy_internal.h182
-rw-r--r--src/arrow/cpp/src/arrow/python/numpy_interop.h96
-rw-r--r--src/arrow/cpp/src/arrow/python/numpy_to_arrow.cc865
-rw-r--r--src/arrow/cpp/src/arrow/python/numpy_to_arrow.h72
-rw-r--r--src/arrow/cpp/src/arrow/python/pch.h24
-rw-r--r--src/arrow/cpp/src/arrow/python/platform.h36
-rw-r--r--src/arrow/cpp/src/arrow/python/pyarrow.cc90
-rw-r--r--src/arrow/cpp/src/arrow/python/pyarrow.h84
-rw-r--r--src/arrow/cpp/src/arrow/python/pyarrow_api.h239
-rw-r--r--src/arrow/cpp/src/arrow/python/pyarrow_lib.h82
-rw-r--r--src/arrow/cpp/src/arrow/python/python_test.cc599
-rw-r--r--src/arrow/cpp/src/arrow/python/python_to_arrow.cc1179
-rw-r--r--src/arrow/cpp/src/arrow/python/python_to_arrow.h80
-rw-r--r--src/arrow/cpp/src/arrow/python/serialize.cc798
-rw-r--r--src/arrow/cpp/src/arrow/python/serialize.h145
-rw-r--r--src/arrow/cpp/src/arrow/python/type_traits.h350
-rw-r--r--src/arrow/cpp/src/arrow/python/util/CMakeLists.txt32
-rw-r--r--src/arrow/cpp/src/arrow/python/util/test_main.cc41
-rw-r--r--src/arrow/cpp/src/arrow/python/visibility.h39
-rw-r--r--src/arrow/cpp/src/arrow/record_batch.cc367
-rw-r--r--src/arrow/cpp/src/arrow/record_batch.h241
-rw-r--r--src/arrow/cpp/src/arrow/record_batch_test.cc320
-rw-r--r--src/arrow/cpp/src/arrow/result.cc36
-rw-r--r--src/arrow/cpp/src/arrow/result.h512
-rw-r--r--src/arrow/cpp/src/arrow/result_internal.h22
-rw-r--r--src/arrow/cpp/src/arrow/result_test.cc799
-rw-r--r--src/arrow/cpp/src/arrow/scalar.cc1008
-rw-r--r--src/arrow/cpp/src/arrow/scalar.h636
-rw-r--r--src/arrow/cpp/src/arrow/scalar_test.cc1629
-rw-r--r--src/arrow/cpp/src/arrow/sparse_tensor.cc478
-rw-r--r--src/arrow/cpp/src/arrow/sparse_tensor.h617
-rw-r--r--src/arrow/cpp/src/arrow/sparse_tensor_test.cc1678
-rw-r--r--src/arrow/cpp/src/arrow/status.cc143
-rw-r--r--src/arrow/cpp/src/arrow/status.h451
-rw-r--r--src/arrow/cpp/src/arrow/status_test.cc212
-rw-r--r--src/arrow/cpp/src/arrow/stl.h466
-rw-r--r--src/arrow/cpp/src/arrow/stl_allocator.h153
-rw-r--r--src/arrow/cpp/src/arrow/stl_iterator.h146
-rw-r--r--src/arrow/cpp/src/arrow/stl_iterator_test.cc252
-rw-r--r--src/arrow/cpp/src/arrow/stl_test.cc558
-rw-r--r--src/arrow/cpp/src/arrow/symbols.map38
-rw-r--r--src/arrow/cpp/src/arrow/table.cc641
-rw-r--r--src/arrow/cpp/src/arrow/table.h295
-rw-r--r--src/arrow/cpp/src/arrow/table_builder.cc113
-rw-r--r--src/arrow/cpp/src/arrow/table_builder.h110
-rw-r--r--src/arrow/cpp/src/arrow/table_builder_test.cc182
-rw-r--r--src/arrow/cpp/src/arrow/table_test.cc753
-rw-r--r--src/arrow/cpp/src/arrow/tensor.cc342
-rw-r--r--src/arrow/cpp/src/arrow/tensor.h246
-rw-r--r--src/arrow/cpp/src/arrow/tensor/CMakeLists.txt25
-rw-r--r--src/arrow/cpp/src/arrow/tensor/converter.h67
-rw-r--r--src/arrow/cpp/src/arrow/tensor/converter_internal.h88
-rw-r--r--src/arrow/cpp/src/arrow/tensor/coo_converter.cc333
-rw-r--r--src/arrow/cpp/src/arrow/tensor/csf_converter.cc289
-rw-r--r--src/arrow/cpp/src/arrow/tensor/csx_converter.cc241
-rw-r--r--src/arrow/cpp/src/arrow/tensor/tensor_conversion_benchmark.cc230
-rw-r--r--src/arrow/cpp/src/arrow/tensor_test.cc749
-rw-r--r--src/arrow/cpp/src/arrow/testing/CMakeLists.txt37
-rw-r--r--src/arrow/cpp/src/arrow/testing/async_test_util.h54
-rw-r--r--src/arrow/cpp/src/arrow/testing/executor_util.h55
-rw-r--r--src/arrow/cpp/src/arrow/testing/extension_type.h158
-rw-r--r--src/arrow/cpp/src/arrow/testing/future_util.h142
-rw-r--r--src/arrow/cpp/src/arrow/testing/generator.cc110
-rw-r--r--src/arrow/cpp/src/arrow/testing/generator.h261
-rw-r--r--src/arrow/cpp/src/arrow/testing/gtest_common.h128
-rw-r--r--src/arrow/cpp/src/arrow/testing/gtest_compat.h33
-rw-r--r--src/arrow/cpp/src/arrow/testing/gtest_util.cc1006
-rw-r--r--src/arrow/cpp/src/arrow/testing/gtest_util.h691
-rw-r--r--src/arrow/cpp/src/arrow/testing/json_integration.cc219
-rw-r--r--src/arrow/cpp/src/arrow/testing/json_integration.h129
-rw-r--r--src/arrow/cpp/src/arrow/testing/json_integration_test.cc1188
-rw-r--r--src/arrow/cpp/src/arrow/testing/json_internal.cc1804
-rw-r--r--src/arrow/cpp/src/arrow/testing/json_internal.h126
-rw-r--r--src/arrow/cpp/src/arrow/testing/matchers.h237
-rw-r--r--src/arrow/cpp/src/arrow/testing/pch.h26
-rw-r--r--src/arrow/cpp/src/arrow/testing/random.cc949
-rw-r--r--src/arrow/cpp/src/arrow/testing/random.h489
-rw-r--r--src/arrow/cpp/src/arrow/testing/random_test.cc513
-rw-r--r--src/arrow/cpp/src/arrow/testing/uniform_real.h84
-rw-r--r--src/arrow/cpp/src/arrow/testing/util.cc188
-rw-r--r--src/arrow/cpp/src/arrow/testing/util.h190
-rw-r--r--src/arrow/cpp/src/arrow/testing/visibility.h48
-rw-r--r--src/arrow/cpp/src/arrow/type.cc2428
-rw-r--r--src/arrow/cpp/src/arrow/type.h2041
-rw-r--r--src/arrow/cpp/src/arrow/type_benchmark.cc439
-rw-r--r--src/arrow/cpp/src/arrow/type_fwd.h631
-rw-r--r--src/arrow/cpp/src/arrow/type_test.cc1792
-rw-r--r--src/arrow/cpp/src/arrow/type_traits.h1059
-rw-r--r--src/arrow/cpp/src/arrow/util/CMakeLists.txt100
-rw-r--r--src/arrow/cpp/src/arrow/util/algorithm.h33
-rw-r--r--src/arrow/cpp/src/arrow/util/align_util.h68
-rw-r--r--src/arrow/cpp/src/arrow/util/align_util_test.cc150
-rw-r--r--src/arrow/cpp/src/arrow/util/aligned_storage.h127
-rw-r--r--src/arrow/cpp/src/arrow/util/async_generator.h1804
-rw-r--r--src/arrow/cpp/src/arrow/util/async_generator_test.cc1842
-rw-r--r--src/arrow/cpp/src/arrow/util/async_util.cc206
-rw-r--r--src/arrow/cpp/src/arrow/util/async_util.h258
-rw-r--r--src/arrow/cpp/src/arrow/util/async_util_test.cc239
-rw-r--r--src/arrow/cpp/src/arrow/util/atomic_shared_ptr.h111
-rw-r--r--src/arrow/cpp/src/arrow/util/base64.h35
-rw-r--r--src/arrow/cpp/src/arrow/util/basic_decimal.cc1381
-rw-r--r--src/arrow/cpp/src/arrow/util/basic_decimal.h494
-rw-r--r--src/arrow/cpp/src/arrow/util/benchmark_main.cc24
-rw-r--r--src/arrow/cpp/src/arrow/util/benchmark_util.h138
-rw-r--r--src/arrow/cpp/src/arrow/util/bit_block_counter.cc74
-rw-r--r--src/arrow/cpp/src/arrow/util/bit_block_counter.h542
-rw-r--r--src/arrow/cpp/src/arrow/util/bit_block_counter_benchmark.cc266
-rw-r--r--src/arrow/cpp/src/arrow/util/bit_block_counter_test.cc417
-rw-r--r--src/arrow/cpp/src/arrow/util/bit_run_reader.cc54
-rw-r--r--src/arrow/cpp/src/arrow/util/bit_run_reader.h515
-rw-r--r--src/arrow/cpp/src/arrow/util/bit_stream_utils.h513
-rw-r--r--src/arrow/cpp/src/arrow/util/bit_util.cc129
-rw-r--r--src/arrow/cpp/src/arrow/util/bit_util.h354
-rw-r--r--src/arrow/cpp/src/arrow/util/bit_util_benchmark.cc560
-rw-r--r--src/arrow/cpp/src/arrow/util/bit_util_test.cc2330
-rw-r--r--src/arrow/cpp/src/arrow/util/bitmap.cc75
-rw-r--r--src/arrow/cpp/src/arrow/util/bitmap.h461
-rw-r--r--src/arrow/cpp/src/arrow/util/bitmap_builders.cc72
-rw-r--r--src/arrow/cpp/src/arrow/util/bitmap_builders.h43
-rw-r--r--src/arrow/cpp/src/arrow/util/bitmap_generate.h111
-rw-r--r--src/arrow/cpp/src/arrow/util/bitmap_ops.cc387
-rw-r--r--src/arrow/cpp/src/arrow/util/bitmap_ops.h206
-rw-r--r--src/arrow/cpp/src/arrow/util/bitmap_reader.h271
-rw-r--r--src/arrow/cpp/src/arrow/util/bitmap_reader_benchmark.cc113
-rw-r--r--src/arrow/cpp/src/arrow/util/bitmap_visit.h88
-rw-r--r--src/arrow/cpp/src/arrow/util/bitmap_writer.h285
-rw-r--r--src/arrow/cpp/src/arrow/util/bitset_stack.h89
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking.cc396
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking.h34
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking64_codegen.py131
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking64_default.h5642
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking_avx2.cc31
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking_avx2.h28
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking_avx512.cc31
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking_avx512.h28
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking_default.h4251
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking_neon.cc31
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking_neon.h28
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking_simd128_generated.h2144
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking_simd256_generated.h1271
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking_simd512_generated.h837
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking_simd_codegen.py223
-rw-r--r--src/arrow/cpp/src/arrow/util/bpacking_simd_internal.h138
-rw-r--r--src/arrow/cpp/src/arrow/util/byte_stream_split.h626
-rw-r--r--src/arrow/cpp/src/arrow/util/cache_benchmark.cc146
-rw-r--r--src/arrow/cpp/src/arrow/util/cache_internal.h210
-rw-r--r--src/arrow/cpp/src/arrow/util/cache_test.cc290
-rw-r--r--src/arrow/cpp/src/arrow/util/cancel.cc226
-rw-r--r--src/arrow/cpp/src/arrow/util/cancel.h102
-rw-r--r--src/arrow/cpp/src/arrow/util/cancel_test.cc308
-rw-r--r--src/arrow/cpp/src/arrow/util/checked_cast.h61
-rw-r--r--src/arrow/cpp/src/arrow/util/checked_cast_test.cc74
-rw-r--r--src/arrow/cpp/src/arrow/util/compare.h62
-rw-r--r--src/arrow/cpp/src/arrow/util/compression.cc261
-rw-r--r--src/arrow/cpp/src/arrow/util/compression.h202
-rw-r--r--src/arrow/cpp/src/arrow/util/compression_benchmark.cc201
-rw-r--r--src/arrow/cpp/src/arrow/util/compression_brotli.cc245
-rw-r--r--src/arrow/cpp/src/arrow/util/compression_bz2.cc287
-rw-r--r--src/arrow/cpp/src/arrow/util/compression_internal.h80
-rw-r--r--src/arrow/cpp/src/arrow/util/compression_lz4.cc495
-rw-r--r--src/arrow/cpp/src/arrow/util/compression_snappy.cc102
-rw-r--r--src/arrow/cpp/src/arrow/util/compression_test.cc635
-rw-r--r--src/arrow/cpp/src/arrow/util/compression_zlib.cc507
-rw-r--r--src/arrow/cpp/src/arrow/util/compression_zstd.cc249
-rw-r--r--src/arrow/cpp/src/arrow/util/concurrent_map.h68
-rw-r--r--src/arrow/cpp/src/arrow/util/config.h.cmake48
-rw-r--r--src/arrow/cpp/src/arrow/util/converter.h411
-rw-r--r--src/arrow/cpp/src/arrow/util/counting_semaphore.cc126
-rw-r--r--src/arrow/cpp/src/arrow/util/counting_semaphore.h60
-rw-r--r--src/arrow/cpp/src/arrow/util/counting_semaphore_test.cc98
-rw-r--r--src/arrow/cpp/src/arrow/util/cpu_info.cc563
-rw-r--r--src/arrow/cpp/src/arrow/util/cpu_info.h143
-rw-r--r--src/arrow/cpp/src/arrow/util/decimal.cc908
-rw-r--r--src/arrow/cpp/src/arrow/util/decimal.h314
-rw-r--r--src/arrow/cpp/src/arrow/util/decimal_benchmark.cc282
-rw-r--r--src/arrow/cpp/src/arrow/util/decimal_test.cc1939
-rw-r--r--src/arrow/cpp/src/arrow/util/delimiting.cc193
-rw-r--r--src/arrow/cpp/src/arrow/util/delimiting.h181
-rw-r--r--src/arrow/cpp/src/arrow/util/dispatch.h115
-rw-r--r--src/arrow/cpp/src/arrow/util/double_conversion.h32
-rw-r--r--src/arrow/cpp/src/arrow/util/endian.h245
-rw-r--r--src/arrow/cpp/src/arrow/util/formatting.cc91
-rw-r--r--src/arrow/cpp/src/arrow/util/formatting.h602
-rw-r--r--src/arrow/cpp/src/arrow/util/formatting_util_test.cc468
-rw-r--r--src/arrow/cpp/src/arrow/util/functional.h160
-rw-r--r--src/arrow/cpp/src/arrow/util/future.cc437
-rw-r--r--src/arrow/cpp/src/arrow/util/future.h978
-rw-r--r--src/arrow/cpp/src/arrow/util/future_iterator.h75
-rw-r--r--src/arrow/cpp/src/arrow/util/future_test.cc1803
-rw-r--r--src/arrow/cpp/src/arrow/util/hash_util.h66
-rw-r--r--src/arrow/cpp/src/arrow/util/hashing.h886
-rw-r--r--src/arrow/cpp/src/arrow/util/hashing_benchmark.cc123
-rw-r--r--src/arrow/cpp/src/arrow/util/hashing_test.cc490
-rw-r--r--src/arrow/cpp/src/arrow/util/int128_internal.h45
-rw-r--r--src/arrow/cpp/src/arrow/util/int_util.cc952
-rw-r--r--src/arrow/cpp/src/arrow/util/int_util.h117
-rw-r--r--src/arrow/cpp/src/arrow/util/int_util_benchmark.cc143
-rw-r--r--src/arrow/cpp/src/arrow/util/int_util_internal.h153
-rw-r--r--src/arrow/cpp/src/arrow/util/int_util_test.cc597
-rw-r--r--src/arrow/cpp/src/arrow/util/io_util.cc1685
-rw-r--r--src/arrow/cpp/src/arrow/util/io_util.h349
-rw-r--r--src/arrow/cpp/src/arrow/util/io_util_test.cc713
-rw-r--r--src/arrow/cpp/src/arrow/util/io_util_test.manifest39
-rw-r--r--src/arrow/cpp/src/arrow/util/io_util_test.rc44
-rw-r--r--src/arrow/cpp/src/arrow/util/iterator.h568
-rw-r--r--src/arrow/cpp/src/arrow/util/iterator_test.cc465
-rw-r--r--src/arrow/cpp/src/arrow/util/key_value_metadata.cc274
-rw-r--r--src/arrow/cpp/src/arrow/util/key_value_metadata.h98
-rw-r--r--src/arrow/cpp/src/arrow/util/key_value_metadata_test.cc211
-rw-r--r--src/arrow/cpp/src/arrow/util/launder.h35
-rw-r--r--src/arrow/cpp/src/arrow/util/logging.cc256
-rw-r--r--src/arrow/cpp/src/arrow/util/logging.h259
-rw-r--r--src/arrow/cpp/src/arrow/util/logging_test.cc103
-rw-r--r--src/arrow/cpp/src/arrow/util/machine_benchmark.cc74
-rw-r--r--src/arrow/cpp/src/arrow/util/macros.h225
-rw-r--r--src/arrow/cpp/src/arrow/util/make_unique.h42
-rw-r--r--src/arrow/cpp/src/arrow/util/map.h63
-rw-r--r--src/arrow/cpp/src/arrow/util/math_constants.h32
-rw-r--r--src/arrow/cpp/src/arrow/util/memory.cc74
-rw-r--r--src/arrow/cpp/src/arrow/util/memory.h43
-rw-r--r--src/arrow/cpp/src/arrow/util/mutex.cc54
-rw-r--r--src/arrow/cpp/src/arrow/util/mutex.h64
-rw-r--r--src/arrow/cpp/src/arrow/util/optional.h35
-rw-r--r--src/arrow/cpp/src/arrow/util/parallel.h102
-rw-r--r--src/arrow/cpp/src/arrow/util/pcg_random.h31
-rw-r--r--src/arrow/cpp/src/arrow/util/print.h51
-rw-r--r--src/arrow/cpp/src/arrow/util/queue.h29
-rw-r--r--src/arrow/cpp/src/arrow/util/queue_benchmark.cc85
-rw-r--r--src/arrow/cpp/src/arrow/util/queue_test.cc55
-rw-r--r--src/arrow/cpp/src/arrow/util/range.h155
-rw-r--r--src/arrow/cpp/src/arrow/util/range_benchmark.cc128
-rw-r--r--src/arrow/cpp/src/arrow/util/range_test.cc69
-rw-r--r--src/arrow/cpp/src/arrow/util/reflection_internal.h133
-rw-r--r--src/arrow/cpp/src/arrow/util/reflection_test.cc224
-rw-r--r--src/arrow/cpp/src/arrow/util/rle_encoding.h826
-rw-r--r--src/arrow/cpp/src/arrow/util/rle_encoding_test.cc573
-rw-r--r--src/arrow/cpp/src/arrow/util/simd.h50
-rw-r--r--src/arrow/cpp/src/arrow/util/small_vector.h519
-rw-r--r--src/arrow/cpp/src/arrow/util/small_vector_benchmark.cc344
-rw-r--r--src/arrow/cpp/src/arrow/util/small_vector_test.cc786
-rw-r--r--src/arrow/cpp/src/arrow/util/sort.h78
-rw-r--r--src/arrow/cpp/src/arrow/util/spaced.h98
-rw-r--r--src/arrow/cpp/src/arrow/util/stl_util_test.cc172
-rw-r--r--src/arrow/cpp/src/arrow/util/stopwatch.h48
-rw-r--r--src/arrow/cpp/src/arrow/util/string.cc191
-rw-r--r--src/arrow/cpp/src/arrow/util/string.h79
-rw-r--r--src/arrow/cpp/src/arrow/util/string_builder.cc40
-rw-r--r--src/arrow/cpp/src/arrow/util/string_builder.h84
-rw-r--r--src/arrow/cpp/src/arrow/util/string_test.cc144
-rw-r--r--src/arrow/cpp/src/arrow/util/string_view.h38
-rw-r--r--src/arrow/cpp/src/arrow/util/task_group.cc224
-rw-r--r--src/arrow/cpp/src/arrow/util/task_group.h106
-rw-r--r--src/arrow/cpp/src/arrow/util/task_group_test.cc444
-rw-r--r--src/arrow/cpp/src/arrow/util/tdigest.cc420
-rw-r--r--src/arrow/cpp/src/arrow/util/tdigest.h104
-rw-r--r--src/arrow/cpp/src/arrow/util/tdigest_benchmark.cc48
-rw-r--r--src/arrow/cpp/src/arrow/util/tdigest_test.cc290
-rw-r--r--src/arrow/cpp/src/arrow/util/test_common.cc68
-rw-r--r--src/arrow/cpp/src/arrow/util/test_common.h90
-rw-r--r--src/arrow/cpp/src/arrow/util/thread_pool.cc450
-rw-r--r--src/arrow/cpp/src/arrow/util/thread_pool.h403
-rw-r--r--src/arrow/cpp/src/arrow/util/thread_pool_benchmark.cc248
-rw-r--r--src/arrow/cpp/src/arrow/util/thread_pool_test.cc718
-rw-r--r--src/arrow/cpp/src/arrow/util/time.cc68
-rw-r--r--src/arrow/cpp/src/arrow/util/time.h83
-rw-r--r--src/arrow/cpp/src/arrow/util/time_test.cc63
-rw-r--r--src/arrow/cpp/src/arrow/util/trie.cc211
-rw-r--r--src/arrow/cpp/src/arrow/util/trie.h245
-rw-r--r--src/arrow/cpp/src/arrow/util/trie_benchmark.cc222
-rw-r--r--src/arrow/cpp/src/arrow/util/trie_test.cc305
-rw-r--r--src/arrow/cpp/src/arrow/util/type_fwd.h62
-rw-r--r--src/arrow/cpp/src/arrow/util/type_traits.h86
-rw-r--r--src/arrow/cpp/src/arrow/util/ubsan.h88
-rw-r--r--src/arrow/cpp/src/arrow/util/unreachable.cc29
-rw-r--r--src/arrow/cpp/src/arrow/util/unreachable.h24
-rw-r--r--src/arrow/cpp/src/arrow/util/uri.cc292
-rw-r--r--src/arrow/cpp/src/arrow/util/uri.h104
-rw-r--r--src/arrow/cpp/src/arrow/util/uri_test.cc312
-rw-r--r--src/arrow/cpp/src/arrow/util/utf8.cc160
-rw-r--r--src/arrow/cpp/src/arrow/util/utf8.h566
-rw-r--r--src/arrow/cpp/src/arrow/util/utf8_util_benchmark.cc150
-rw-r--r--src/arrow/cpp/src/arrow/util/utf8_util_test.cc513
-rw-r--r--src/arrow/cpp/src/arrow/util/value_parsing.cc87
-rw-r--r--src/arrow/cpp/src/arrow/util/value_parsing.h853
-rw-r--r--src/arrow/cpp/src/arrow/util/value_parsing_benchmark.cc303
-rw-r--r--src/arrow/cpp/src/arrow/util/value_parsing_test.cc643
-rw-r--r--src/arrow/cpp/src/arrow/util/variant.h443
-rw-r--r--src/arrow/cpp/src/arrow/util/variant_benchmark.cc248
-rw-r--r--src/arrow/cpp/src/arrow/util/variant_test.cc345
-rw-r--r--src/arrow/cpp/src/arrow/util/vector.h172
-rw-r--r--src/arrow/cpp/src/arrow/util/visibility.h45
-rw-r--r--src/arrow/cpp/src/arrow/util/windows_compatibility.h42
-rw-r--r--src/arrow/cpp/src/arrow/util/windows_fixup.h52
-rw-r--r--src/arrow/cpp/src/arrow/vendored/CMakeLists.txt21
-rw-r--r--src/arrow/cpp/src/arrow/vendored/ProducerConsumerQueue.h217
-rw-r--r--src/arrow/cpp/src/arrow/vendored/base64.cpp134
-rw-r--r--src/arrow/cpp/src/arrow/vendored/datetime.h26
-rw-r--r--src/arrow/cpp/src/arrow/vendored/datetime/CMakeLists.txt18
-rw-r--r--src/arrow/cpp/src/arrow/vendored/datetime/README.md28
-rw-r--r--src/arrow/cpp/src/arrow/vendored/datetime/date.h8237
-rw-r--r--src/arrow/cpp/src/arrow/vendored/datetime/ios.h53
-rw-r--r--src/arrow/cpp/src/arrow/vendored/datetime/ios.mm340
-rw-r--r--src/arrow/cpp/src/arrow/vendored/datetime/tz.cpp3951
-rw-r--r--src/arrow/cpp/src/arrow/vendored/datetime/tz.h2801
-rw-r--r--src/arrow/cpp/src/arrow/vendored/datetime/tz_private.h319
-rw-r--r--src/arrow/cpp/src/arrow/vendored/datetime/visibility.h26
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/.gitignore1
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/CMakeLists.txt18
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/README.md20
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/bignum-dtoa.cc641
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/bignum-dtoa.h84
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/bignum.cc767
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/bignum.h144
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/cached-powers.cc175
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/cached-powers.h64
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/diy-fp.cc57
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/diy-fp.h118
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/double-conversion.cc1171
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/double-conversion.h587
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/fast-dtoa.cc665
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/fast-dtoa.h88
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/fixed-dtoa.cc405
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/fixed-dtoa.h56
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/ieee.h402
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/strtod.cc580
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/strtod.h45
-rw-r--r--src/arrow/cpp/src/arrow/vendored/double-conversion/utils.h367
-rw-r--r--src/arrow/cpp/src/arrow/vendored/fast_float/README.md7
-rw-r--r--src/arrow/cpp/src/arrow/vendored/fast_float/ascii_number.h301
-rw-r--r--src/arrow/cpp/src/arrow/vendored/fast_float/decimal_to_binary.h176
-rw-r--r--src/arrow/cpp/src/arrow/vendored/fast_float/fast_float.h48
-rw-r--r--src/arrow/cpp/src/arrow/vendored/fast_float/fast_table.h691
-rw-r--r--src/arrow/cpp/src/arrow/vendored/fast_float/float_common.h345
-rw-r--r--src/arrow/cpp/src/arrow/vendored/fast_float/parse_number.h133
-rw-r--r--src/arrow/cpp/src/arrow/vendored/fast_float/simple_decimal_conversion.h362
-rw-r--r--src/arrow/cpp/src/arrow/vendored/musl/README.md25
-rw-r--r--src/arrow/cpp/src/arrow/vendored/musl/strptime.c237
-rw-r--r--src/arrow/cpp/src/arrow/vendored/optional.hpp1553
-rw-r--r--src/arrow/cpp/src/arrow/vendored/pcg/README.md26
-rw-r--r--src/arrow/cpp/src/arrow/vendored/pcg/pcg_extras.hpp670
-rw-r--r--src/arrow/cpp/src/arrow/vendored/pcg/pcg_random.hpp1954
-rw-r--r--src/arrow/cpp/src/arrow/vendored/pcg/pcg_uint128.hpp1008
-rw-r--r--src/arrow/cpp/src/arrow/vendored/portable-snippets/README.md10
-rw-r--r--src/arrow/cpp/src/arrow/vendored/portable-snippets/safe-math.h1072
-rw-r--r--src/arrow/cpp/src/arrow/vendored/string_view.hpp1531
-rw-r--r--src/arrow/cpp/src/arrow/vendored/strptime.h35
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/README.md25
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/Uri.h1090
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriBase.h377
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriCommon.c572
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriCommon.h109
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriCompare.c168
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriDefsAnsi.h82
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriDefsConfig.h102
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriDefsUnicode.h82
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriEscape.c453
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriFile.c242
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4.c329
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4.h110
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4Base.c96
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4Base.h59
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriMemory.c468
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriMemory.h78
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriNormalize.c771
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriNormalizeBase.c119
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriNormalizeBase.h53
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriParse.c2410
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriParseBase.c90
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriParseBase.h55
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriQuery.c501
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriRecompose.c577
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriResolve.c329
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/UriShorten.c324
-rw-r--r--src/arrow/cpp/src/arrow/vendored/uriparser/config.h47
-rw-r--r--src/arrow/cpp/src/arrow/vendored/utfcpp/README.md28
-rw-r--r--src/arrow/cpp/src/arrow/vendored/utfcpp/checked.h333
-rw-r--r--src/arrow/cpp/src/arrow/vendored/utfcpp/core.h338
-rw-r--r--src/arrow/cpp/src/arrow/vendored/utfcpp/cpp11.h103
-rw-r--r--src/arrow/cpp/src/arrow/vendored/xxhash.h18
-rw-r--r--src/arrow/cpp/src/arrow/vendored/xxhash/README.md22
-rw-r--r--src/arrow/cpp/src/arrow/vendored/xxhash/xxhash.c43
-rw-r--r--src/arrow/cpp/src/arrow/vendored/xxhash/xxhash.h4769
-rw-r--r--src/arrow/cpp/src/arrow/visitor.cc172
-rw-r--r--src/arrow/cpp/src/arrow/visitor.h155
-rw-r--r--src/arrow/cpp/src/arrow/visitor_inline.h450
-rw-r--r--src/arrow/cpp/src/gandiva/CMakeLists.txt253
-rw-r--r--src/arrow/cpp/src/gandiva/GandivaConfig.cmake.in36
-rw-r--r--src/arrow/cpp/src/gandiva/annotator.cc118
-rw-r--r--src/arrow/cpp/src/gandiva/annotator.h81
-rw-r--r--src/arrow/cpp/src/gandiva/annotator_test.cc102
-rw-r--r--src/arrow/cpp/src/gandiva/arrow.h57
-rw-r--r--src/arrow/cpp/src/gandiva/basic_decimal_scalar.h65
-rw-r--r--src/arrow/cpp/src/gandiva/bitmap_accumulator.cc75
-rw-r--r--src/arrow/cpp/src/gandiva/bitmap_accumulator.h79
-rw-r--r--src/arrow/cpp/src/gandiva/bitmap_accumulator_test.cc112
-rw-r--r--src/arrow/cpp/src/gandiva/cache.cc45
-rw-r--r--src/arrow/cpp/src/gandiva/cache.h60
-rw-r--r--src/arrow/cpp/src/gandiva/cast_time.cc85
-rw-r--r--src/arrow/cpp/src/gandiva/compiled_expr.h71
-rw-r--r--src/arrow/cpp/src/gandiva/condition.h37
-rw-r--r--src/arrow/cpp/src/gandiva/configuration.cc43
-rw-r--r--src/arrow/cpp/src/gandiva/configuration.h84
-rw-r--r--src/arrow/cpp/src/gandiva/context_helper.cc76
-rw-r--r--src/arrow/cpp/src/gandiva/date_utils.cc232
-rw-r--r--src/arrow/cpp/src/gandiva/date_utils.h52
-rw-r--r--src/arrow/cpp/src/gandiva/decimal_ir.cc559
-rw-r--r--src/arrow/cpp/src/gandiva/decimal_ir.h188
-rw-r--r--src/arrow/cpp/src/gandiva/decimal_scalar.h76
-rw-r--r--src/arrow/cpp/src/gandiva/decimal_type_util.cc75
-rw-r--r--src/arrow/cpp/src/gandiva/decimal_type_util.h83
-rw-r--r--src/arrow/cpp/src/gandiva/decimal_type_util_test.cc58
-rw-r--r--src/arrow/cpp/src/gandiva/decimal_xlarge.cc284
-rw-r--r--src/arrow/cpp/src/gandiva/decimal_xlarge.h41
-rw-r--r--src/arrow/cpp/src/gandiva/dex.h396
-rw-r--r--src/arrow/cpp/src/gandiva/dex_visitor.h97
-rw-r--r--src/arrow/cpp/src/gandiva/engine.cc338
-rw-r--r--src/arrow/cpp/src/gandiva/engine.h104
-rw-r--r--src/arrow/cpp/src/gandiva/engine_llvm_test.cc131
-rw-r--r--src/arrow/cpp/src/gandiva/eval_batch.h107
-rw-r--r--src/arrow/cpp/src/gandiva/execution_context.h54
-rw-r--r--src/arrow/cpp/src/gandiva/exported_funcs.h59
-rw-r--r--src/arrow/cpp/src/gandiva/exported_funcs_registry.cc30
-rw-r--r--src/arrow/cpp/src/gandiva/exported_funcs_registry.h54
-rw-r--r--src/arrow/cpp/src/gandiva/expr_decomposer.cc310
-rw-r--r--src/arrow/cpp/src/gandiva/expr_decomposer.h128
-rw-r--r--src/arrow/cpp/src/gandiva/expr_decomposer_test.cc409
-rw-r--r--src/arrow/cpp/src/gandiva/expr_validator.cc193
-rw-r--r--src/arrow/cpp/src/gandiva/expr_validator.h80
-rw-r--r--src/arrow/cpp/src/gandiva/expression.cc25
-rw-r--r--src/arrow/cpp/src/gandiva/expression.h46
-rw-r--r--src/arrow/cpp/src/gandiva/expression_registry.cc187
-rw-r--r--src/arrow/cpp/src/gandiva/expression_registry.h71
-rw-r--r--src/arrow/cpp/src/gandiva/expression_registry_test.cc68
-rw-r--r--src/arrow/cpp/src/gandiva/field_descriptor.h69
-rw-r--r--src/arrow/cpp/src/gandiva/filter.cc171
-rw-r--r--src/arrow/cpp/src/gandiva/filter.h112
-rw-r--r--src/arrow/cpp/src/gandiva/formatting_utils.h69
-rw-r--r--src/arrow/cpp/src/gandiva/func_descriptor.h50
-rw-r--r--src/arrow/cpp/src/gandiva/function_holder.h34
-rw-r--r--src/arrow/cpp/src/gandiva/function_holder_registry.h76
-rw-r--r--src/arrow/cpp/src/gandiva/function_ir_builder.cc81
-rw-r--r--src/arrow/cpp/src/gandiva/function_ir_builder.h61
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry.cc83
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry.h47
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry_arithmetic.cc125
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry_arithmetic.h27
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry_common.h268
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry_datetime.cc132
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry_datetime.h27
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry_hash.cc63
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry_hash.h27
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry_math_ops.cc106
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry_math_ops.h27
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry_string.cc422
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry_string.h27
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry_test.cc96
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry_timestamp_arithmetic.cc89
-rw-r--r--src/arrow/cpp/src/gandiva/function_registry_timestamp_arithmetic.h27
-rw-r--r--src/arrow/cpp/src/gandiva/function_signature.cc113
-rw-r--r--src/arrow/cpp/src/gandiva/function_signature.h55
-rw-r--r--src/arrow/cpp/src/gandiva/function_signature_test.cc113
-rw-r--r--src/arrow/cpp/src/gandiva/gandiva.pc.in27
-rw-r--r--src/arrow/cpp/src/gandiva/gandiva_aliases.h62
-rw-r--r--src/arrow/cpp/src/gandiva/gdv_function_stubs.cc1603
-rw-r--r--src/arrow/cpp/src/gandiva/gdv_function_stubs.h173
-rw-r--r--src/arrow/cpp/src/gandiva/gdv_function_stubs_test.cc769
-rw-r--r--src/arrow/cpp/src/gandiva/greedy_dual_size_cache.h154
-rw-r--r--src/arrow/cpp/src/gandiva/greedy_dual_size_cache_test.cc88
-rw-r--r--src/arrow/cpp/src/gandiva/hash_utils.cc134
-rw-r--r--src/arrow/cpp/src/gandiva/hash_utils.h44
-rw-r--r--src/arrow/cpp/src/gandiva/hash_utils_test.cc164
-rw-r--r--src/arrow/cpp/src/gandiva/in_holder.h91
-rw-r--r--src/arrow/cpp/src/gandiva/jni/CMakeLists.txt107
-rw-r--r--src/arrow/cpp/src/gandiva/jni/config_builder.cc53
-rw-r--r--src/arrow/cpp/src/gandiva/jni/config_holder.cc30
-rw-r--r--src/arrow/cpp/src/gandiva/jni/config_holder.h68
-rw-r--r--src/arrow/cpp/src/gandiva/jni/env_helper.h23
-rw-r--r--src/arrow/cpp/src/gandiva/jni/expression_registry_helper.cc190
-rw-r--r--src/arrow/cpp/src/gandiva/jni/id_to_module_map.h66
-rw-r--r--src/arrow/cpp/src/gandiva/jni/jni_common.cc1055
-rw-r--r--src/arrow/cpp/src/gandiva/jni/module_holder.h59
-rw-r--r--src/arrow/cpp/src/gandiva/jni/symbols.map20
-rw-r--r--src/arrow/cpp/src/gandiva/like_holder.cc156
-rw-r--r--src/arrow/cpp/src/gandiva/like_holder.h68
-rw-r--r--src/arrow/cpp/src/gandiva/like_holder_test.cc281
-rw-r--r--src/arrow/cpp/src/gandiva/literal_holder.cc45
-rw-r--r--src/arrow/cpp/src/gandiva/literal_holder.h36
-rw-r--r--src/arrow/cpp/src/gandiva/llvm_generator.cc1400
-rw-r--r--src/arrow/cpp/src/gandiva/llvm_generator.h253
-rw-r--r--src/arrow/cpp/src/gandiva/llvm_generator_test.cc116
-rw-r--r--src/arrow/cpp/src/gandiva/llvm_includes.h56
-rw-r--r--src/arrow/cpp/src/gandiva/llvm_types.cc48
-rw-r--r--src/arrow/cpp/src/gandiva/llvm_types.h130
-rw-r--r--src/arrow/cpp/src/gandiva/llvm_types_test.cc61
-rw-r--r--src/arrow/cpp/src/gandiva/local_bitmaps_holder.h85
-rw-r--r--src/arrow/cpp/src/gandiva/lvalue.h77
-rw-r--r--src/arrow/cpp/src/gandiva/make_precompiled_bitcode.py49
-rw-r--r--src/arrow/cpp/src/gandiva/native_function.h81
-rw-r--r--src/arrow/cpp/src/gandiva/node.h299
-rw-r--r--src/arrow/cpp/src/gandiva/node_visitor.h56
-rw-r--r--src/arrow/cpp/src/gandiva/pch.h24
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/CMakeLists.txt142
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/arithmetic_ops.cc274
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc180
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/bitmap.cc60
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/bitmap_test.cc62
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/decimal_ops.cc723
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/decimal_ops.h90
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/decimal_ops_test.cc1095
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/decimal_wrapper.cc433
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/epoch_time_point.h118
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/epoch_time_point_test.cc103
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/extended_math_ops.cc410
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/extended_math_ops_test.cc349
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/hash.cc407
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/hash_test.cc122
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/print.cc28
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/string_ops.cc2198
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/string_ops_test.cc1758
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/testing.h43
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/time.cc894
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/time_constants.h30
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/time_fields.h35
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/time_test.cc953
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/timestamp_arithmetic.cc283
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled/types.h592
-rw-r--r--src/arrow/cpp/src/gandiva/precompiled_bitcode.cc.in26
-rw-r--r--src/arrow/cpp/src/gandiva/projector.cc369
-rw-r--r--src/arrow/cpp/src/gandiva/projector.h143
-rw-r--r--src/arrow/cpp/src/gandiva/proto/Types.proto255
-rw-r--r--src/arrow/cpp/src/gandiva/random_generator_holder.cc45
-rw-r--r--src/arrow/cpp/src/gandiva/random_generator_holder.h57
-rw-r--r--src/arrow/cpp/src/gandiva/random_generator_holder_test.cc103
-rw-r--r--src/arrow/cpp/src/gandiva/regex_util.cc63
-rw-r--r--src/arrow/cpp/src/gandiva/regex_util.h45
-rw-r--r--src/arrow/cpp/src/gandiva/replace_holder.cc65
-rw-r--r--src/arrow/cpp/src/gandiva/replace_holder.h97
-rw-r--r--src/arrow/cpp/src/gandiva/replace_holder_test.cc129
-rw-r--r--src/arrow/cpp/src/gandiva/selection_vector.cc179
-rw-r--r--src/arrow/cpp/src/gandiva/selection_vector.h151
-rw-r--r--src/arrow/cpp/src/gandiva/selection_vector_impl.h108
-rw-r--r--src/arrow/cpp/src/gandiva/selection_vector_test.cc270
-rw-r--r--src/arrow/cpp/src/gandiva/simple_arena.h160
-rw-r--r--src/arrow/cpp/src/gandiva/simple_arena_test.cc102
-rw-r--r--src/arrow/cpp/src/gandiva/symbols.map35
-rw-r--r--src/arrow/cpp/src/gandiva/tests/CMakeLists.txt42
-rw-r--r--src/arrow/cpp/src/gandiva/tests/binary_test.cc136
-rw-r--r--src/arrow/cpp/src/gandiva/tests/boolean_expr_test.cc388
-rw-r--r--src/arrow/cpp/src/gandiva/tests/date_time_test.cc602
-rw-r--r--src/arrow/cpp/src/gandiva/tests/decimal_single_test.cc305
-rw-r--r--src/arrow/cpp/src/gandiva/tests/decimal_test.cc1194
-rw-r--r--src/arrow/cpp/src/gandiva/tests/filter_project_test.cc276
-rw-r--r--src/arrow/cpp/src/gandiva/tests/filter_test.cc340
-rw-r--r--src/arrow/cpp/src/gandiva/tests/generate_data.h152
-rw-r--r--src/arrow/cpp/src/gandiva/tests/hash_test.cc615
-rw-r--r--src/arrow/cpp/src/gandiva/tests/huge_table_test.cc157
-rw-r--r--src/arrow/cpp/src/gandiva/tests/if_expr_test.cc378
-rw-r--r--src/arrow/cpp/src/gandiva/tests/in_expr_test.cc278
-rw-r--r--src/arrow/cpp/src/gandiva/tests/literal_test.cc232
-rw-r--r--src/arrow/cpp/src/gandiva/tests/micro_benchmarks.cc456
-rw-r--r--src/arrow/cpp/src/gandiva/tests/null_validity_test.cc175
-rw-r--r--src/arrow/cpp/src/gandiva/tests/projector_build_validation_test.cc287
-rw-r--r--src/arrow/cpp/src/gandiva/tests/projector_test.cc1609
-rw-r--r--src/arrow/cpp/src/gandiva/tests/test_util.h103
-rw-r--r--src/arrow/cpp/src/gandiva/tests/timed_evaluate.h136
-rw-r--r--src/arrow/cpp/src/gandiva/tests/to_string_test.cc88
-rw-r--r--src/arrow/cpp/src/gandiva/tests/utf8_test.cc751
-rw-r--r--src/arrow/cpp/src/gandiva/to_date_holder.cc116
-rw-r--r--src/arrow/cpp/src/gandiva/to_date_holder.h58
-rw-r--r--src/arrow/cpp/src/gandiva/to_date_holder_test.cc152
-rw-r--r--src/arrow/cpp/src/gandiva/tree_expr_builder.cc223
-rw-r--r--src/arrow/cpp/src/gandiva/tree_expr_builder.h139
-rw-r--r--src/arrow/cpp/src/gandiva/tree_expr_test.cc159
-rw-r--r--src/arrow/cpp/src/gandiva/value_validity_pair.h48
-rw-r--r--src/arrow/cpp/src/gandiva/visibility.h48
-rw-r--r--src/arrow/cpp/src/generated/Expression_generated.h1870
-rw-r--r--src/arrow/cpp/src/generated/File_generated.h200
-rw-r--r--src/arrow/cpp/src/generated/Literal_generated.h2037
-rw-r--r--src/arrow/cpp/src/generated/Message_generated.h659
-rw-r--r--src/arrow/cpp/src/generated/Plan_generated.h115
-rw-r--r--src/arrow/cpp/src/generated/Relation_generated.h1647
-rw-r--r--src/arrow/cpp/src/generated/Schema_generated.h2367
-rw-r--r--src/arrow/cpp/src/generated/SparseTensor_generated.h921
-rw-r--r--src/arrow/cpp/src/generated/Tensor_generated.h387
-rw-r--r--src/arrow/cpp/src/generated/feather_generated.h863
-rw-r--r--src/arrow/cpp/src/generated/parquet_constants.cpp17
-rw-r--r--src/arrow/cpp/src/generated/parquet_constants.h24
-rw-r--r--src/arrow/cpp/src/generated/parquet_types.cpp7413
-rw-r--r--src/arrow/cpp/src/generated/parquet_types.h2917
-rw-r--r--src/arrow/cpp/src/jni/CMakeLists.txt27
-rw-r--r--src/arrow/cpp/src/jni/dataset/CMakeLists.txt65
-rw-r--r--src/arrow/cpp/src/jni/dataset/jni_util.cc242
-rw-r--r--src/arrow/cpp/src/jni/dataset/jni_util.h135
-rw-r--r--src/arrow/cpp/src/jni/dataset/jni_util_test.cc134
-rw-r--r--src/arrow/cpp/src/jni/dataset/jni_wrapper.cc545
-rw-r--r--src/arrow/cpp/src/jni/orc/CMakeLists.txt53
-rw-r--r--src/arrow/cpp/src/jni/orc/concurrent_map.h77
-rw-r--r--src/arrow/cpp/src/jni/orc/jni_wrapper.cpp306
-rw-r--r--src/arrow/cpp/src/parquet/CMakeLists.txt414
-rw-r--r--src/arrow/cpp/src/parquet/ParquetConfig.cmake.in43
-rw-r--r--src/arrow/cpp/src/parquet/README10
-rw-r--r--src/arrow/cpp/src/parquet/api/CMakeLists.txt19
-rw-r--r--src/arrow/cpp/src/parquet/api/io.h20
-rw-r--r--src/arrow/cpp/src/parquet/api/reader.h35
-rw-r--r--src/arrow/cpp/src/parquet/api/schema.h21
-rw-r--r--src/arrow/cpp/src/parquet/api/writer.h25
-rw-r--r--src/arrow/cpp/src/parquet/arrow/CMakeLists.txt31
-rw-r--r--src/arrow/cpp/src/parquet/arrow/arrow_reader_writer_test.cc4343
-rw-r--r--src/arrow/cpp/src/parquet/arrow/arrow_schema_test.cc1701
-rw-r--r--src/arrow/cpp/src/parquet/arrow/arrow_statistics_test.cc161
-rw-r--r--src/arrow/cpp/src/parquet/arrow/fuzz.cc25
-rw-r--r--src/arrow/cpp/src/parquet/arrow/generate_fuzz_corpus.cc198
-rw-r--r--src/arrow/cpp/src/parquet/arrow/path_internal.cc901
-rw-r--r--src/arrow/cpp/src/parquet/arrow/path_internal.h155
-rw-r--r--src/arrow/cpp/src/parquet/arrow/path_internal_test.cc648
-rw-r--r--src/arrow/cpp/src/parquet/arrow/reader.cc1305
-rw-r--r--src/arrow/cpp/src/parquet/arrow/reader.h344
-rw-r--r--src/arrow/cpp/src/parquet/arrow/reader_internal.cc791
-rw-r--r--src/arrow/cpp/src/parquet/arrow/reader_internal.h122
-rw-r--r--src/arrow/cpp/src/parquet/arrow/reader_writer_benchmark.cc585
-rw-r--r--src/arrow/cpp/src/parquet/arrow/reconstruct_internal_test.cc1639
-rw-r--r--src/arrow/cpp/src/parquet/arrow/schema.cc1093
-rw-r--r--src/arrow/cpp/src/parquet/arrow/schema.h184
-rw-r--r--src/arrow/cpp/src/parquet/arrow/schema_internal.cc222
-rw-r--r--src/arrow/cpp/src/parquet/arrow/schema_internal.h51
-rw-r--r--src/arrow/cpp/src/parquet/arrow/test_util.h512
-rw-r--r--src/arrow/cpp/src/parquet/arrow/writer.cc480
-rw-r--r--src/arrow/cpp/src/parquet/arrow/writer.h109
-rw-r--r--src/arrow/cpp/src/parquet/bloom_filter.cc162
-rw-r--r--src/arrow/cpp/src/parquet/bloom_filter.h247
-rw-r--r--src/arrow/cpp/src/parquet/bloom_filter_test.cc247
-rw-r--r--src/arrow/cpp/src/parquet/column_io_benchmark.cc261
-rw-r--r--src/arrow/cpp/src/parquet/column_page.h160
-rw-r--r--src/arrow/cpp/src/parquet/column_reader.cc1808
-rw-r--r--src/arrow/cpp/src/parquet/column_reader.h376
-rw-r--r--src/arrow/cpp/src/parquet/column_reader_test.cc476
-rw-r--r--src/arrow/cpp/src/parquet/column_scanner.cc91
-rw-r--r--src/arrow/cpp/src/parquet/column_scanner.h262
-rw-r--r--src/arrow/cpp/src/parquet/column_scanner_test.cc229
-rw-r--r--src/arrow/cpp/src/parquet/column_writer.cc2103
-rw-r--r--src/arrow/cpp/src/parquet/column_writer.h270
-rw-r--r--src/arrow/cpp/src/parquet/column_writer_test.cc1019
-rw-r--r--src/arrow/cpp/src/parquet/encoding.cc2597
-rw-r--r--src/arrow/cpp/src/parquet/encoding.h460
-rw-r--r--src/arrow/cpp/src/parquet/encoding_benchmark.cc802
-rw-r--r--src/arrow/cpp/src/parquet/encoding_test.cc1247
-rw-r--r--src/arrow/cpp/src/parquet/encryption/CMakeLists.txt19
-rw-r--r--src/arrow/cpp/src/parquet/encryption/crypto_factory.cc175
-rw-r--r--src/arrow/cpp/src/parquet/encryption/crypto_factory.h135
-rw-r--r--src/arrow/cpp/src/parquet/encryption/encryption.cc412
-rw-r--r--src/arrow/cpp/src/parquet/encryption/encryption.h510
-rw-r--r--src/arrow/cpp/src/parquet/encryption/encryption_internal.cc613
-rw-r--r--src/arrow/cpp/src/parquet/encryption/encryption_internal.h116
-rw-r--r--src/arrow/cpp/src/parquet/encryption/encryption_internal_nossl.cc110
-rw-r--r--src/arrow/cpp/src/parquet/encryption/file_key_material_store.h31
-rw-r--r--src/arrow/cpp/src/parquet/encryption/file_key_unwrapper.cc114
-rw-r--r--src/arrow/cpp/src/parquet/encryption/file_key_unwrapper.h66
-rw-r--r--src/arrow/cpp/src/parquet/encryption/file_key_wrapper.cc109
-rw-r--r--src/arrow/cpp/src/parquet/encryption/file_key_wrapper.h82
-rw-r--r--src/arrow/cpp/src/parquet/encryption/internal_file_decryptor.cc240
-rw-r--r--src/arrow/cpp/src/parquet/encryption/internal_file_decryptor.h121
-rw-r--r--src/arrow/cpp/src/parquet/encryption/internal_file_encryptor.cc170
-rw-r--r--src/arrow/cpp/src/parquet/encryption/internal_file_encryptor.h109
-rw-r--r--src/arrow/cpp/src/parquet/encryption/key_encryption_key.h59
-rw-r--r--src/arrow/cpp/src/parquet/encryption/key_management_test.cc225
-rw-r--r--src/arrow/cpp/src/parquet/encryption/key_material.cc159
-rw-r--r--src/arrow/cpp/src/parquet/encryption/key_material.h131
-rw-r--r--src/arrow/cpp/src/parquet/encryption/key_metadata.cc89
-rw-r--r--src/arrow/cpp/src/parquet/encryption/key_metadata.h94
-rw-r--r--src/arrow/cpp/src/parquet/encryption/key_metadata_test.cc77
-rw-r--r--src/arrow/cpp/src/parquet/encryption/key_toolkit.cc52
-rw-r--r--src/arrow/cpp/src/parquet/encryption/key_toolkit.h76
-rw-r--r--src/arrow/cpp/src/parquet/encryption/key_toolkit_internal.cc80
-rw-r--r--src/arrow/cpp/src/parquet/encryption/key_toolkit_internal.h58
-rw-r--r--src/arrow/cpp/src/parquet/encryption/key_wrapping_test.cc103
-rw-r--r--src/arrow/cpp/src/parquet/encryption/kms_client.cc44
-rw-r--r--src/arrow/cpp/src/parquet/encryption/kms_client.h95
-rw-r--r--src/arrow/cpp/src/parquet/encryption/kms_client_factory.h40
-rw-r--r--src/arrow/cpp/src/parquet/encryption/local_wrap_kms_client.cc116
-rw-r--r--src/arrow/cpp/src/parquet/encryption/local_wrap_kms_client.h96
-rw-r--r--src/arrow/cpp/src/parquet/encryption/properties_test.cc276
-rw-r--r--src/arrow/cpp/src/parquet/encryption/read_configurations_test.cc272
-rw-r--r--src/arrow/cpp/src/parquet/encryption/test_encryption_util.cc502
-rw-r--r--src/arrow/cpp/src/parquet/encryption/test_encryption_util.h118
-rw-r--r--src/arrow/cpp/src/parquet/encryption/test_in_memory_kms.cc81
-rw-r--r--src/arrow/cpp/src/parquet/encryption/test_in_memory_kms.h89
-rw-r--r--src/arrow/cpp/src/parquet/encryption/two_level_cache_with_expiration.h159
-rw-r--r--src/arrow/cpp/src/parquet/encryption/two_level_cache_with_expiration_test.cc177
-rw-r--r--src/arrow/cpp/src/parquet/encryption/write_configurations_test.cc234
-rw-r--r--src/arrow/cpp/src/parquet/exception.cc27
-rw-r--r--src/arrow/cpp/src/parquet/exception.h158
-rw-r--r--src/arrow/cpp/src/parquet/file_deserialize_test.cc372
-rw-r--r--src/arrow/cpp/src/parquet/file_reader.cc868
-rw-r--r--src/arrow/cpp/src/parquet/file_reader.h188
-rw-r--r--src/arrow/cpp/src/parquet/file_serialize_test.cc470
-rw-r--r--src/arrow/cpp/src/parquet/file_writer.cc547
-rw-r--r--src/arrow/cpp/src/parquet/file_writer.h234
-rw-r--r--src/arrow/cpp/src/parquet/hasher.h72
-rw-r--r--src/arrow/cpp/src/parquet/level_comparison.cc82
-rw-r--r--src/arrow/cpp/src/parquet/level_comparison.h40
-rw-r--r--src/arrow/cpp/src/parquet/level_comparison_avx2.cc34
-rw-r--r--src/arrow/cpp/src/parquet/level_comparison_inc.h65
-rw-r--r--src/arrow/cpp/src/parquet/level_conversion.cc183
-rw-r--r--src/arrow/cpp/src/parquet/level_conversion.h199
-rw-r--r--src/arrow/cpp/src/parquet/level_conversion_benchmark.cc80
-rw-r--r--src/arrow/cpp/src/parquet/level_conversion_bmi2.cc33
-rw-r--r--src/arrow/cpp/src/parquet/level_conversion_inc.h357
-rw-r--r--src/arrow/cpp/src/parquet/level_conversion_test.cc361
-rw-r--r--src/arrow/cpp/src/parquet/metadata.cc1797
-rw-r--r--src/arrow/cpp/src/parquet/metadata.h489
-rw-r--r--src/arrow/cpp/src/parquet/metadata_test.cc571
-rw-r--r--src/arrow/cpp/src/parquet/murmur3.cc222
-rw-r--r--src/arrow/cpp/src/parquet/murmur3.h54
-rw-r--r--src/arrow/cpp/src/parquet/parquet.pc.in31
-rw-r--r--src/arrow/cpp/src/parquet/parquet.thrift1063
-rw-r--r--src/arrow/cpp/src/parquet/parquet_version.h.in31
-rw-r--r--src/arrow/cpp/src/parquet/pch.h28
-rw-r--r--src/arrow/cpp/src/parquet/platform.cc41
-rw-r--r--src/arrow/cpp/src/parquet/platform.h111
-rw-r--r--src/arrow/cpp/src/parquet/printer.cc297
-rw-r--r--src/arrow/cpp/src/parquet/printer.h46
-rw-r--r--src/arrow/cpp/src/parquet/properties.cc64
-rw-r--r--src/arrow/cpp/src/parquet/properties.h801
-rw-r--r--src/arrow/cpp/src/parquet/properties_test.cc90
-rw-r--r--src/arrow/cpp/src/parquet/public_api_test.cc49
-rw-r--r--src/arrow/cpp/src/parquet/reader_test.cc810
-rw-r--r--src/arrow/cpp/src/parquet/schema.cc945
-rw-r--r--src/arrow/cpp/src/parquet/schema.h491
-rw-r--r--src/arrow/cpp/src/parquet/schema_internal.h54
-rw-r--r--src/arrow/cpp/src/parquet/schema_test.cc2226
-rw-r--r--src/arrow/cpp/src/parquet/statistics.cc887
-rw-r--r--src/arrow/cpp/src/parquet/statistics.h367
-rw-r--r--src/arrow/cpp/src/parquet/statistics_test.cc1178
-rw-r--r--src/arrow/cpp/src/parquet/stream_reader.cc521
-rw-r--r--src/arrow/cpp/src/parquet/stream_reader.h299
-rw-r--r--src/arrow/cpp/src/parquet/stream_reader_test.cc916
-rw-r--r--src/arrow/cpp/src/parquet/stream_writer.cc324
-rw-r--r--src/arrow/cpp/src/parquet/stream_writer.h243
-rw-r--r--src/arrow/cpp/src/parquet/stream_writer_test.cc419
-rw-r--r--src/arrow/cpp/src/parquet/symbols.map40
-rw-r--r--src/arrow/cpp/src/parquet/test_util.cc136
-rw-r--r--src/arrow/cpp/src/parquet/test_util.h715
-rw-r--r--src/arrow/cpp/src/parquet/thrift_internal.h509
-rw-r--r--src/arrow/cpp/src/parquet/type_fwd.h88
-rw-r--r--src/arrow/cpp/src/parquet/types.cc1567
-rw-r--r--src/arrow/cpp/src/parquet/types.h766
-rw-r--r--src/arrow/cpp/src/parquet/types_test.cc172
-rw-r--r--src/arrow/cpp/src/parquet/windows_compatibility.h30
-rw-r--r--src/arrow/cpp/src/plasma/.gitignore18
-rw-r--r--src/arrow/cpp/src/plasma/CMakeLists.txt235
-rw-r--r--src/arrow/cpp/src/plasma/PlasmaConfig.cmake.in39
-rw-r--r--src/arrow/cpp/src/plasma/client.cc1224
-rw-r--r--src/arrow/cpp/src/plasma/client.h309
-rw-r--r--src/arrow/cpp/src/plasma/common.cc195
-rw-r--r--src/arrow/cpp/src/plasma/common.fbs39
-rw-r--r--src/arrow/cpp/src/plasma/common.h155
-rw-r--r--src/arrow/cpp/src/plasma/common_generated.h230
-rw-r--r--src/arrow/cpp/src/plasma/compat.h32
-rw-r--r--src/arrow/cpp/src/plasma/dlmalloc.cc166
-rw-r--r--src/arrow/cpp/src/plasma/events.cc107
-rw-r--r--src/arrow/cpp/src/plasma/events.h108
-rw-r--r--src/arrow/cpp/src/plasma/eviction_policy.cc175
-rw-r--r--src/arrow/cpp/src/plasma/eviction_policy.h209
-rw-r--r--src/arrow/cpp/src/plasma/external_store.cc63
-rw-r--r--src/arrow/cpp/src/plasma/external_store.h120
-rw-r--r--src/arrow/cpp/src/plasma/fling.cc129
-rw-r--r--src/arrow/cpp/src/plasma/fling.h52
-rw-r--r--src/arrow/cpp/src/plasma/hash_table_store.cc58
-rw-r--r--src/arrow/cpp/src/plasma/hash_table_store.h50
-rw-r--r--src/arrow/cpp/src/plasma/io.cc250
-rw-r--r--src/arrow/cpp/src/plasma/io.h67
-rw-r--r--src/arrow/cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.cc263
-rw-r--r--src/arrow/cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.h141
-rw-r--r--src/arrow/cpp/src/plasma/malloc.cc70
-rw-r--r--src/arrow/cpp/src/plasma/malloc.h51
-rw-r--r--src/arrow/cpp/src/plasma/plasma.cc99
-rw-r--r--src/arrow/cpp/src/plasma/plasma.fbs357
-rw-r--r--src/arrow/cpp/src/plasma/plasma.h175
-rw-r--r--src/arrow/cpp/src/plasma/plasma.pc.in33
-rw-r--r--src/arrow/cpp/src/plasma/plasma_allocator.cc56
-rw-r--r--src/arrow/cpp/src/plasma/plasma_allocator.h61
-rw-r--r--src/arrow/cpp/src/plasma/plasma_generated.h3984
-rw-r--r--src/arrow/cpp/src/plasma/protocol.cc829
-rw-r--r--src/arrow/cpp/src/plasma/protocol.h251
-rw-r--r--src/arrow/cpp/src/plasma/quota_aware_policy.cc177
-rw-r--r--src/arrow/cpp/src/plasma/quota_aware_policy.h88
-rw-r--r--src/arrow/cpp/src/plasma/store.cc1353
-rw-r--r--src/arrow/cpp/src/plasma/store.h245
-rw-r--r--src/arrow/cpp/src/plasma/symbols.map34
-rw-r--r--src/arrow/cpp/src/plasma/test/client_tests.cc1084
-rw-r--r--src/arrow/cpp/src/plasma/test/external_store_tests.cc143
-rw-r--r--src/arrow/cpp/src/plasma/test/serialization_tests.cc333
-rw-r--r--src/arrow/cpp/src/plasma/test_util.h46
-rw-r--r--src/arrow/cpp/src/plasma/thirdparty/ae/ae.c465
-rw-r--r--src/arrow/cpp/src/plasma/thirdparty/ae/ae.h121
-rw-r--r--src/arrow/cpp/src/plasma/thirdparty/ae/ae_epoll.c137
-rw-r--r--src/arrow/cpp/src/plasma/thirdparty/ae/ae_evport.c320
-rw-r--r--src/arrow/cpp/src/plasma/thirdparty/ae/ae_kqueue.c138
-rw-r--r--src/arrow/cpp/src/plasma/thirdparty/ae/ae_select.c106
-rw-r--r--src/arrow/cpp/src/plasma/thirdparty/ae/config.h52
-rw-r--r--src/arrow/cpp/src/plasma/thirdparty/ae/zmalloc.h43
-rw-r--r--src/arrow/cpp/src/plasma/thirdparty/dlmalloc.c6296
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/LICENSE.txt202
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/README.md19
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/bad_data/PARQUET-1481.parquetbin0 -> 451 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/bad_data/README.md24
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/README.md58
-rwxr-xr-xsrc/arrow/cpp/submodules/parquet-testing/data/alltypes_dictionary.parquetbin0 -> 1698 bytes
-rwxr-xr-xsrc/arrow/cpp/submodules/parquet-testing/data/alltypes_plain.parquetbin0 -> 1851 bytes
-rwxr-xr-xsrc/arrow/cpp/submodules/parquet-testing/data/alltypes_plain.snappy.parquetbin0 -> 1736 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/binary.parquetbin0 -> 478 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/bloom_filter.binbin0 -> 1036 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/byte_array_decimal.parquetbin0 -> 324 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/datapage_v2.snappy.parquetbin0 -> 1165 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/delta_binary_packed.md440
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/delta_binary_packed.parquetbin0 -> 72971 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/delta_binary_packed_expect.csv201
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/dict-page-offset-zero.parquetbin0 -> 635 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer.parquet.encryptedbin0 -> 4930 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer_aad.parquet.encryptedbin0 -> 4938 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer_ctr.parquet.encryptedbin0 -> 4864 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer_disable_aad_storage.parquet.encryptedbin0 -> 4930 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_plaintext_footer.parquet.encryptedbin0 -> 5083 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/fixed_length_decimal.parquetbin0 -> 677 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/fixed_length_decimal_legacy.parquetbin0 -> 537 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/hadoop_lz4_compressed.parquetbin0 -> 702 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/hadoop_lz4_compressed_larger.parquetbin0 -> 358859 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/int32_decimal.parquetbin0 -> 478 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/int64_decimal.parquetbin0 -> 591 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/list_columns.parquetbin0 -> 2526 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/lz4_raw_compressed.parquetbin0 -> 797 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/lz4_raw_compressed_larger.parquetbin0 -> 380836 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/nation.dict-malformed.parquetbin0 -> 2850 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/nested_lists.snappy.parquetbin0 -> 881 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/nested_maps.snappy.parquetbin0 -> 1324 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/nested_structs.rust.parquetbin0 -> 53040 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/non_hadoop_lz4_compressed.parquetbin0 -> 1228 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/nonnullable.impala.parquetbin0 -> 3186 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/nullable.impala.parquetbin0 -> 3896 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/nulls.snappy.parquetbin0 -> 461 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/repeated_no_annotation.parquetbin0 -> 662 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/single_nan.parquetbin0 -> 660 bytes
-rw-r--r--src/arrow/cpp/submodules/parquet-testing/data/uniform_encryption.parquet.encryptedbin0 -> 5483 bytes
-rw-r--r--src/arrow/cpp/thirdparty/README.md25
-rwxr-xr-xsrc/arrow/cpp/thirdparty/download_dependencies.sh63
-rw-r--r--src/arrow/cpp/thirdparty/flatbuffers/include/flatbuffers/base.h398
-rw-r--r--src/arrow/cpp/thirdparty/flatbuffers/include/flatbuffers/flatbuffers.h2783
-rw-r--r--src/arrow/cpp/thirdparty/flatbuffers/include/flatbuffers/stl_emulation.h307
-rw-r--r--src/arrow/cpp/thirdparty/hadoop/include/hdfs.h1024
-rw-r--r--src/arrow/cpp/thirdparty/versions.txt130
-rw-r--r--src/arrow/cpp/tools/parquet/CMakeLists.txt36
-rw-r--r--src/arrow/cpp/tools/parquet/parquet_dump_schema.cc52
-rw-r--r--src/arrow/cpp/tools/parquet/parquet_reader.cc82
-rw-r--r--src/arrow/cpp/tools/parquet/parquet_scan.cc78
-rw-r--r--src/arrow/cpp/valgrind.supp53
-rw-r--r--src/arrow/cpp/vcpkg.json46
1602 files changed, 588654 insertions, 0 deletions
diff --git a/src/arrow/cpp/.gitignore b/src/arrow/cpp/.gitignore
new file mode 100644
index 000000000..03c03a401
--- /dev/null
+++ b/src/arrow/cpp/.gitignore
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+thirdparty/*.tar*
+CMakeFiles/
+CMakeCache.txt
+CTestTestfile.cmake
+Makefile
+cmake_install.cmake
+build/
+*-build/
+Testing/
+build-support/boost_*
+
+# Build directories created by Clion
+cmake-build-*/
+
+#########################################
+# Editor temporary/working/backup files #
+.#*
+*\#*\#
+[#]*#
+*~
+*$
+*.bak
+*flymake*
+*.kdev4
+*.log
+*.swp
diff --git a/src/arrow/cpp/Brewfile b/src/arrow/cpp/Brewfile
new file mode 100644
index 000000000..78ee5e64c
--- /dev/null
+++ b/src/arrow/cpp/Brewfile
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+brew "automake"
+brew "boost"
+brew "brotli"
+brew "c-ares"
+brew "ccache"
+brew "cmake"
+brew "flatbuffers"
+brew "git"
+brew "glog"
+brew "grpc"
+brew "llvm"
+brew "llvm@8"
+brew "lz4"
+brew "minio"
+brew "ninja"
+brew "numpy"
+brew "openssl@1.1"
+brew "protobuf"
+brew "python"
+brew "rapidjson"
+brew "snappy"
+brew "thrift"
+brew "wget"
+brew "zstd"
diff --git a/src/arrow/cpp/CHANGELOG_PARQUET.md b/src/arrow/cpp/CHANGELOG_PARQUET.md
new file mode 100644
index 000000000..06a09c20f
--- /dev/null
+++ b/src/arrow/cpp/CHANGELOG_PARQUET.md
@@ -0,0 +1,501 @@
+Parquet C++ 1.5.0
+--------------------------------------------------------------------------------
+## Bug
+ * [PARQUET-979] - [C++] Limit size of min, max or disable stats for long binary types
+ * [PARQUET-1071] - [C++] parquet::arrow::FileWriter::Close is not idempotent
+ * [PARQUET-1349] - [C++] PARQUET_RPATH_ORIGIN is not picked by the build
+ * [PARQUET-1334] - [C++] memory_map parameter seems missleading in parquet file opener
+ * [PARQUET-1333] - [C++] Reading of files with dictionary size 0 fails on Windows with bad_alloc
+ * [PARQUET-1283] - [C++] FormatStatValue appends trailing space to string and int96
+ * [PARQUET-1270] - [C++] Executable tools do not get installed
+ * [PARQUET-1272] - [C++] ScanFileContents reports wrong row count for nested columns
+ * [PARQUET-1268] - [C++] Conversion of Arrow null list columns fails
+ * [PARQUET-1255] - [C++] Exceptions thrown in some tests
+ * [PARQUET-1358] - [C++] index_page_offset should be unset as it is not supported.
+ * [PARQUET-1357] - [C++] FormatStatValue truncates binary statistics on zero character
+ * [PARQUET-1319] - [C++] Pass BISON_EXECUTABLE to Thrift EP for MacOS
+ * [PARQUET-1313] - [C++] Compilation failure with VS2017
+ * [PARQUET-1315] - [C++] ColumnChunkMetaData.has_dictionary_page() should return bool, not int64_t
+ * [PARQUET-1307] - [C++] memory-test fails with latest Arrow
+ * [PARQUET-1274] - [Python] SegFault in pyarrow.parquet.write_table with specific options
+ * [PARQUET-1209] - locally defined symbol ... imported in function ..
+ * [PARQUET-1245] - [C++] Segfault when writing Arrow table with duplicate columns
+ * [PARQUET-1273] - [Python] Error writing to partitioned Parquet dataset
+ * [PARQUET-1384] - [C++] Clang compiler warnings in bloom_filter-test.cc
+
+## Improvement
+ * [PARQUET-1348] - [C++] Allow Arrow FileWriter To Write FileMetaData
+ * [PARQUET-1346] - [C++] Protect against null values data in empty Arrow array
+ * [PARQUET-1340] - [C++] Fix Travis Ci valgrind errors related to std::random_device
+ * [PARQUET-1323] - [C++] Fix compiler warnings with clang-6.0
+ * [PARQUET-1279] - Use ASSERT_NO_FATAIL_FAILURE in C++ unit tests
+ * [PARQUET-1262] - [C++] Use the same BOOST_ROOT and Boost_NAMESPACE for Thrift
+ * [PARQUET-1267] - replace "unsafe" std::equal by std::memcmp
+ * [PARQUET-1360] - [C++] Minor API + style changes follow up to PARQUET-1348
+ * [PARQUET-1166] - [API Proposal] Add GetRecordBatchReader in parquet/arrow/reader.h
+ * [PARQUET-1378] - [c++] Allow RowGroups with zero rows to be written
+ * [PARQUET-1256] - [C++] Add --print-key-value-metadata option to parquet_reader tool
+ * [PARQUET-1276] - [C++] Reduce the amount of memory used for writing null decimal values
+
+## New Feature
+ * [PARQUET-1392] - [C++] Supply row group indices to parquet::arrow::FileReader::ReadTable
+
+## Sub-task
+ * [PARQUET-1227] - Thrift crypto metadata structures
+ * [PARQUET-1332] - [C++] Add bloom filter utility class
+
+## Task
+ * [PARQUET-1350] - [C++] Use abstract ResizableBuffer instead of concrete PoolBuffer
+ * [PARQUET-1366] - [C++] Streamline use of Arrow bit-util.h
+ * [PARQUET-1308] - [C++] parquet::arrow should use thread pool, not ParallelFor
+ * [PARQUET-1382] - [C++] Prepare for arrow::test namespace removal
+ * [PARQUET-1372] - [C++] Add an API to allow writing RowGroups based on their size rather than num_rows
+
+
+Parquet C++ 1.4.0
+--------------------------------------------------------------------------------
+## Bug
+ * [PARQUET-1193] - [CPP] Implement ColumnOrder to support min_value and max_value
+ * [PARQUET-1180] - C++: Fix behaviour of num_children element of primitive nodes
+ * [PARQUET-1146] - C++: Add macOS-compatible sha512sum call to release verify script
+ * [PARQUET-1167] - [C++] FieldToNode function should return a status when throwing an exception
+ * [PARQUET-1175] - [C++] Fix usage of deprecated Arrow API
+ * [PARQUET-1113] - [C++] Incorporate fix from ARROW-1601 on bitmap read path
+ * [PARQUET-1111] - dev/release/verify-release-candidate has stale help
+ * [PARQUET-1109] - C++: Update release verification script to SHA512
+ * [PARQUET-1179] - [C++] Support Apache Thrift 0.11
+ * [PARQUET-1226] - [C++] Fix new build warnings with clang 5.0
+ * [PARQUET-1233] - [CPP ]Enable option to switch between stl classes and boost classes for thrift header
+ * [PARQUET-1205] - Fix msvc static build
+ * [PARQUET-1210] - [C++] Boost 1.66 compilation fails on Windows on linkage stage
+
+## Improvement
+ * [PARQUET-1092] - [C++] Write Arrow tables with chunked columns
+ * [PARQUET-1086] - [C++] Remove usage of arrow/util/compiler-util.h after 1.3.0 release
+ * [PARQUET-1097] - [C++] Account for Arrow API deprecation in ARROW-1511
+ * [PARQUET-1150] - C++: Hide statically linked boost symbols
+ * [PARQUET-1151] - [C++] Add build options / configuration to use static runtime libraries with MSVC
+ * [PARQUET-1147] - [C++] Account for API deprecation / change in ARROW-1671
+ * [PARQUET-1162] - C++: Update dev/README after migration to Gitbox
+ * [PARQUET-1165] - [C++] Pin clang-format version to 4.0
+ * [PARQUET-1164] - [C++] Follow API changes in ARROW-1808
+ * [PARQUET-1177] - [C++] Add more extensive compiler warnings when using Clang
+ * [PARQUET-1110] - [C++] Release verification script for Windows
+ * [PARQUET-859] - [C++] Flatten parquet/file directory
+ * [PARQUET-1220] - [C++] Don't build Thrift examples and tutorials in the ExternalProject
+ * [PARQUET-1219] - [C++] Update release-candidate script links to gitbox
+ * [PARQUET-1196] - [C++] Provide a parquet_arrow example project incl. CMake setup
+ * [PARQUET-1200] - [C++] Support reading a single Arrow column from a Parquet file
+
+## New Feature
+ * [PARQUET-1095] - [C++] Read and write Arrow decimal values
+ * [PARQUET-970] - Add Add Lz4 and Zstd compression codecs
+
+## Task
+ * [PARQUET-1221] - [C++] Extend release README
+ * [PARQUET-1225] - NaN values may lead to incorrect filtering under certain circumstances
+
+
+Parquet C++ 1.3.1
+--------------------------------------------------------------------------------
+## Bug
+ * [PARQUET-1105] - [CPP] Remove libboost_system dependency
+ * [PARQUET-1138] - [C++] Fix compilation with Arrow 0.7.1
+ * [PARQUET-1123] - [C++] Update parquet-cpp to use Arrow's AssertArraysEqual
+ * [PARQUET-1121] - C++: DictionaryArrays of NullType cannot be written
+ * [PARQUET-1139] - Add license to cmake_modules/parquet-cppConfig.cmake.in
+
+## Improvement
+ * [PARQUET-1140] - [C++] Fail on RAT errors in CI
+ * [PARQUET-1070] - Add CPack support to the build
+
+
+Parquet C++ 1.3.0
+--------------------------------------------------------------------------------
+## Bug
+ * [PARQUET-1098] - [C++] Install new header in parquet/util
+ * [PARQUET-1085] - [C++] Backwards compatibility from macro cleanup in transitive dependencies in ARROW-1452
+ * [PARQUET-1074] - [C++] Switch to long key ids in KEYs file
+ * [PARQUET-1075] - C++: Coverage upload is broken
+ * [PARQUET-1088] - [CPP] remove parquet_version.h from version control since it gets auto generated
+ * [PARQUET-1002] - [C++] Compute statistics based on Logical Types
+ * [PARQUET-1100] - [C++] Reading repeated types should decode number of records rather than number of values
+ * [PARQUET-1090] - [C++] Fix int32 overflow in Arrow table writer, add max row group size property
+ * [PARQUET-1108] - [C++] Fix Int96 comparators
+
+## Improvement
+ * [PARQUET-1104] - [C++] Upgrade to Apache Arrow 0.7.0 RC0
+ * [PARQUET-1072] - [C++] Add ARROW_NO_DEPRECATED_API to CI to check for deprecated API use
+ * [PARQUET-1096] - C++: Update sha{1, 256, 512} checksums per latest ASF release policy
+ * [PARQUET-1079] - [C++] Account for Arrow API change in ARROW-1335
+ * [PARQUET-1087] - [C++] Add wrapper for ScanFileContents in parquet::arrow that catches exceptions
+ * [PARQUET-1093] - C++: Improve Arrow level generation error message
+ * [PARQUET-1094] - C++: Add benchmark for boolean Arrow column I/O
+ * [PARQUET-1083] - [C++] Refactor core logic in parquet-scan.cc so that it can be used as a library function for benchmarking
+ * [PARQUET-1037] - Allow final RowGroup to be unfilled
+
+## New Feature
+ * [PARQUET-1078] - [C++] Add Arrow writer option to coerce timestamps to milliseconds or microseconds
+ * [PARQUET-929] - [C++] Handle arrow::DictionaryArray when writing Arrow data
+
+
+Parquet C++ 1.2.0
+--------------------------------------------------------------------------------
+## Bug
+ * [PARQUET-1029] - [C++] TypedColumnReader/TypeColumnWriter symbols are no longer being exported
+ * [PARQUET-997] - Fix override compiler warnings
+ * [PARQUET-1033] - Mismatched Read and Write
+ * [PARQUET-1007] - [C++ ] Update parquet.thrift from https://github.com/apache/parquet-format
+ * [PARQUET-1039] - PARQUET-911 Breaks Arrow
+ * [PARQUET-1038] - Key value metadata should be nullptr if not set
+ * [PARQUET-1018] - [C++] parquet.dll has runtime dependencies on one or more libraries in the build toolchain
+ * [PARQUET-1003] - [C++] Modify DEFAULT_CREATED_BY value for every new release version
+ * [PARQUET-1004] - CPP Building fails on windows
+ * [PARQUET-1040] - Missing writer method implementations
+ * [PARQUET-1054] - [C++] Account for Arrow API changes in ARROW-1199
+ * [PARQUET-1042] - C++: Compilation breaks on GCC 4.8
+ * [PARQUET-1048] - [C++] Static linking of libarrow is no longer supported
+ * [PARQUET-1013] - Fix ZLIB_INCLUDE_DIR
+ * [PARQUET-998] - C++: Release script is not usable
+ * [PARQUET-1023] - [C++] Brotli libraries are not being statically linked on Windows
+ * [PARQUET-1000] - [C++] Do not build thirdparty Arrow with /WX on MSVC
+ * [PARQUET-1052] - [C++] add_compiler_export_flags() throws warning with CMake >= 3.3
+ * [PARQUET-1069] - C++: ./dev/release/verify-release-candidate is broken due to missing Arrow dependencies
+
+## Improvement
+ * [PARQUET-996] - Improve MSVC build - ThirdpartyToolchain - Arrow
+ * [PARQUET-911] - C++: Support nested structs in parquet_arrow
+ * [PARQUET-986] - Improve MSVC build - ThirdpartyToolchain - Thrift
+ * [PARQUET-864] - [C++] Consolidate non-Parquet-specific bit utility code into Apache Arrow
+ * [PARQUET-1043] - [C++] Raise minimum supported CMake version to 3.2
+ * [PARQUET-1016] - Upgrade thirdparty Arrow to 0.4.0
+ * [PARQUET-858] - [C++] Flatten parquet/column directory, consolidate related code
+ * [PARQUET-978] - [C++] Minimizing footer reads for small(ish) metadata
+ * [PARQUET-991] - [C++] Fix compiler warnings on MSVC and build with /WX in Appveyor
+ * [PARQUET-863] - [C++] Move SIMD, CPU info, hashing, and other generic utilities into Apache Arrow
+ * [PARQUET-1053] - Fix unused result warnings due to unchecked Statuses
+ * [PARQUET-1067] - C++: Update arrow hash to 0.5.0
+ * [PARQUET-1041] - C++: Support Arrow's NullArray
+ * [PARQUET-1008] - Update TypedColumnReader::ReadBatch method to accept batch_size as int64_t
+ * [PARQUET-1044] - [C++] Use compression libraries from Apache Arrow
+ * [PARQUET-999] - Improve MSVC build - Enable PARQUET_BUILD_BENCHMARKS
+ * [PARQUET-967] - [C++] Combine libparquet/libparquet_arrow libraries
+ * [PARQUET-1045] - [C++] Refactor to account for computational utility code migration in ARROW-1154
+
+## New Feature
+ * [PARQUET-1035] - Write Int96 from Arrow Timestamp(ns)
+
+## Task
+ * [PARQUET-994] - C++: release-candidate script should not push to master
+ * [PARQUET-902] - [C++] Move compressor interfaces into Apache Arrow
+
+## Test
+ * [PARQUET-706] - [C++] Create test case that uses libparquet as a 3rd party library
+
+
+Parquet C++ 1.1.0
+--------------------------------------------------------------------------------
+## Bug
+ * [PARQUET-898] - [C++] Change Travis CI OS X image to Xcode 6.4 and fix our thirdparty build
+ * [PARQUET-976] - [C++] Pass unit test suite with MSVC, build in Appveyor
+ * [PARQUET-963] - [C++] Disallow reading struct types in Arrow reader for now
+ * [PARQUET-959] - [C++] Arrow thirdparty build fails on multiarch systems
+ * [PARQUET-962] - [C++] GTEST_MAIN_STATIC_LIB is not defined in FindGTest.cmake
+ * [PARQUET-958] - [C++] Print Parquet metadata in JSON format
+ * [PARQUET-956] - C++: BUILD_BYPRODUCTS not specified anymore for gtest
+ * [PARQUET-948] - [C++] Account for API changes in ARROW-782
+ * [PARQUET-947] - [C++] Refactor to account for ARROW-795 Arrow core library consolidation
+ * [PARQUET-965] - [C++] FIXED_LEN_BYTE_ARRAY types are unhandled in the Arrow reader
+ * [PARQUET-949] - [C++] Arrow version pinning seems to not be working properly
+ * [PARQUET-955] - [C++] pkg_check_modules will override $ARROW_HOME if it is set in the environment
+ * [PARQUET-945] - [C++] Thrift static libraries are not used with recent patch
+ * [PARQUET-943] - [C++] Overflow build error on x86
+ * [PARQUET-938] - [C++] There is a typo in cmake_modules/FindSnappy.cmake comment
+ * [PARQUET-936] - [C++] parquet::arrow::WriteTable can enter infinite loop if chunk_size is 0
+ * [PARQUET-981] - Repair usage of *_HOME 3rd party dependencies environment variables during Windows build
+ * [PARQUET-992] - [C++] parquet/compression.h leaks zlib.h
+ * [PARQUET-987] - [C++] Fix regressions caused by PARQUET-981
+ * [PARQUET-933] - [C++] Account for Arrow Table API changes coming in ARROW-728
+ * [PARQUET-915] - Support Arrow Time Types in Schema
+ * [PARQUET-914] - [C++] Throw more informative exception when user writes too many values to a column in a row group
+ * [PARQUET-923] - [C++] Account for Time metadata changes in ARROW-686
+ * [PARQUET-918] - FromParquetSchema API crashes on nested schemas
+ * [PARQUET-925] - [C++] FindArrow.cmake sets the wrong library path after ARROW-648
+ * [PARQUET-932] - [c++] Add option to build parquet library with minimal dependency
+ * [PARQUET-919] - [C++] Account for API changes in ARROW-683
+ * [PARQUET-995] - [C++] Int96 reader in parquet_arrow uses size of Int96Type instead of Int96
+
+## Improvement
+ * [PARQUET-508] - Add ParquetFilePrinter
+ * [PARQUET-595] - Add API for key-value metadata
+ * [PARQUET-897] - [C++] Only use designated public headers from libarrow
+ * [PARQUET-679] - [C++] Build and unit tests support for MSVC on Windows
+ * [PARQUET-977] - Improve MSVC build
+ * [PARQUET-957] - [C++] Add optional $PARQUET_BUILD_TOOLCHAIN environment variable option for configuring build environment
+ * [PARQUET-961] - [C++] Strip debug symbols from libparquet libraries in release builds by default
+ * [PARQUET-954] - C++: Use Brolti 0.6 release
+ * [PARQUET-953] - [C++] Change arrow::FileWriter API to be initialized from a Schema, and provide for writing multiple tables
+ * [PARQUET-941] - [C++] Stop needless Boost static library detection for CentOS 7 support
+ * [PARQUET-942] - [C++] Fix wrong variabe use in FindSnappy
+ * [PARQUET-939] - [C++] Support Thrift_HOME CMake variable like FindSnappy does as Snappy_HOME
+ * [PARQUET-940] - [C++] Fix Arrow library path detection
+ * [PARQUET-937] - [C++] Support CMake < 3.4 again for Arrow detection
+ * [PARQUET-935] - [C++] Set shared library version for .deb packages
+ * [PARQUET-934] - [C++] Support multiarch on Debian
+ * [PARQUET-984] - C++: Add abi and so version to pkg-config
+ * [PARQUET-983] - C++: Update Thirdparty hash to Arrow 0.3.0
+ * [PARQUET-989] - [C++] Link dynamically to libarrow in toolchain build, set LD_LIBRARY_PATH
+ * [PARQUET-988] - [C++] Add Linux toolchain-based build to Travis CI
+ * [PARQUET-928] - [C++] Support pkg-config
+ * [PARQUET-927] - [C++] Specify shared library version of Apache Arrow
+ * [PARQUET-931] - [C++] Add option to pin thirdparty Arrow version used in ExternalProject
+ * [PARQUET-926] - [C++] Use pkg-config to find Apache Arrow
+ * [PARQUET-917] - C++: Build parquet_arrow by default
+ * [PARQUET-910] - C++: Support TIME logical type in parquet_arrow
+ * [PARQUET-909] - [CPP]: Reduce buffer allocations (mallocs) on critical path
+
+## New Feature
+ * [PARQUET-853] - [C++] Add option to link with shared boost libraries when building Arrow in the thirdparty toolchain
+ * [PARQUET-946] - [C++] Refactoring in parquet::arrow::FileReader to be able to read a single row group
+ * [PARQUET-930] - [C++] Account for all Arrow date/time types
+
+
+Parquet C++ 1.0.0
+--------------------------------------------------------------------------------
+## Bug
+ * [PARQUET-455] - Fix compiler warnings on OS X / Clang
+ * [PARQUET-558] - Support ZSH in build scripts
+ * [PARQUET-720] - Parquet-cpp fails to link when included in multiple TUs
+ * [PARQUET-718] - Reading boolean pages written by parquet-cpp fails
+ * [PARQUET-640] - [C++] Force the use of gcc 4.9 in conda builds
+ * [PARQUET-643] - Add const modifier to schema pointer reference in ParquetFileWriter
+ * [PARQUET-672] - [C++] Build testing conda artifacts in debug mode
+ * [PARQUET-661] - [C++] Do not assume that perl is found in /usr/bin
+ * [PARQUET-659] - [C++] Instantiated template visibility is broken on clang / OS X
+ * [PARQUET-657] - [C++] Don't define DISALLOW_COPY_AND_ASSIGN if already defined
+ * [PARQUET-656] - [C++] Revert PARQUET-653
+ * [PARQUET-676] - MAX_VALUES_PER_LITERAL_RUN causes RLE encoding failure
+ * [PARQUET-614] - C++: Remove unneeded LZ4-related code
+ * [PARQUET-604] - Install writer.h headers
+ * [PARQUET-621] - C++: Uninitialised DecimalMetadata is read
+ * [PARQUET-620] - C++: Duplicate calls to ParquetFileWriter::Close cause duplicate metdata writes
+ * [PARQUET-599] - ColumnWriter::RleEncodeLevels' size estimation might be wrong
+ * [PARQUET-617] - C++: Enable conda build to work on systems with non-default C++ toolchains
+ * [PARQUET-627] - Ensure that thrift headers are generated before source compilation
+ * [PARQUET-745] - TypedRowGroupStatistics fails to PlainDecode min and max in ByteArrayType
+ * [PARQUET-738] - Update arrow version that also supports newer Xcode
+ * [PARQUET-747] - [C++] TypedRowGroupStatistics are not being exported in libparquet.so
+ * [PARQUET-711] - Use metadata builders in parquet writer
+ * [PARQUET-732] - Building a subset of dependencies does not work
+ * [PARQUET-760] - On switching from dictionary to the fallback encoding, an incorrect encoding is set
+ * [PARQUET-691] - [C++] Write ColumnChunk metadata after each column chunk in the file
+ * [PARQUET-797] - [C++] Update for API changes in ARROW-418
+ * [PARQUET-837] - [C++] SerializedFile::ParseMetaData uses Seek, followed by Read, and could have race conditions
+ * [PARQUET-827] - [C++] Incorporate addition of arrow::MemoryPool::Reallocate
+ * [PARQUET-502] - Scanner segfaults when its batch size is smaller than the number of rows
+ * [PARQUET-469] - Roll back Thrift bindings to 0.9.0
+ * [PARQUET-889] - Fix compilation when PARQUET_USE_SSE is on
+ * [PARQUET-888] - C++ Memory leak in RowGroupSerializer
+ * [PARQUET-819] - C++: Trying to install non-existing parquet/arrow/utils.h
+ * [PARQUET-736] - XCode 8.0 breaks builds
+ * [PARQUET-505] - Column reader: automatically handle large data pages
+ * [PARQUET-615] - C++: Building static or shared libparquet should not be mutually exclusive
+ * [PARQUET-658] - ColumnReader has no virtual destructor
+ * [PARQUET-799] - concurrent usage of the file reader API
+ * [PARQUET-513] - Valgrind errors are not failing the Travis CI build
+ * [PARQUET-841] - [C++] Writing wrong format version when using ParquetVersion::PARQUET_1_0
+ * [PARQUET-742] - Add missing license headers
+ * [PARQUET-741] - compression_buffer_ is reused although it shouldn't
+ * [PARQUET-700] - C++: Disable dictionary encoding for boolean columns
+ * [PARQUET-662] - [C++] ParquetException must be explicitly exported in dynamic libraries
+ * [PARQUET-704] - [C++] scan-all.h is not being installed
+ * [PARQUET-865] - C++: Pass all CXXFLAGS to Thrift ExternalProject
+ * [PARQUET-875] - [C++] Fix coveralls build given changes to thirdparty build procedure
+ * [PARQUET-709] - [C++] Fix conda dev binary builds
+ * [PARQUET-638] - [C++] Revert static linking of libstdc++ in conda builds until symbol visibility addressed
+ * [PARQUET-606] - Travis coverage is broken
+ * [PARQUET-880] - [CPP] Prevent destructors from throwing
+ * [PARQUET-886] - [C++] Revise build documentation and requirements in README.md
+ * [PARQUET-900] - C++: Fix NOTICE / LICENSE issues
+ * [PARQUET-885] - [C++] Do not search for Thrift in default system paths
+ * [PARQUET-879] - C++: ExternalProject compilation for Thrift fails on older CMake versions
+ * [PARQUET-635] - [C++] Statically link libstdc++ on Linux in conda recipe
+ * [PARQUET-710] - Remove unneeded private member variables from RowGroupReader ABI
+ * [PARQUET-766] - C++: Expose ParquetFileReader through Arrow reader as const
+ * [PARQUET-876] - C++: Correct snapshot version
+ * [PARQUET-821] - [C++] zlib download link is broken
+ * [PARQUET-818] - [C++] Refactor library to share IO, Buffer, and memory management abstractions with Apache Arrow
+ * [PARQUET-537] - LocalFileSource leaks resources
+ * [PARQUET-764] - [CPP] Parquet Writer does not write Boolean values correctly
+ * [PARQUET-812] - [C++] Failure reading BYTE_ARRAY data from file in parquet-compatibility project
+ * [PARQUET-759] - Cannot store columns consisting of empty strings
+ * [PARQUET-846] - [CPP] CpuInfo::Init() is not thread safe
+ * [PARQUET-694] - C++: Revert default data page size back to 1M
+ * [PARQUET-842] - [C++] Impala rejects DOUBLE columns if decimal metadata is set
+ * [PARQUET-708] - [C++] RleEncoder does not account for "worst case scenario" in MaxBufferSize for bit_width > 1
+ * [PARQUET-639] - Do not export DCHECK in public headers
+ * [PARQUET-828] - [C++] "version" field set improperly in file metadata
+ * [PARQUET-891] - [C++] Do not search for Snappy in default system paths
+ * [PARQUET-626] - Fix builds due to unavailable llvm.org apt mirror
+ * [PARQUET-629] - RowGroupSerializer should only close itself once
+ * [PARQUET-472] - Clean up InputStream ownership semantics in ColumnReader
+ * [PARQUET-739] - Rle-decoding uses static buffer that is shared accross threads
+ * [PARQUET-561] - ParquetFileReader::Contents PIMPL missing a virtual destructor
+ * [PARQUET-892] - [C++] Clean up link library targets in CMake files
+ * [PARQUET-454] - Address inconsistencies in boolean decoding
+ * [PARQUET-816] - [C++] Failure decoding sample dict-encoded file from parquet-compatibility project
+ * [PARQUET-565] - Use PATH instead of DIRECTORY in get_filename_component to support CMake<2.8.12
+ * [PARQUET-446] - Hide thrift dependency in parquet-cpp
+ * [PARQUET-843] - [C++] Impala unable to read files created by parquet-cpp
+ * [PARQUET-555] - Dictionary page metadata handling inconsistencies
+ * [PARQUET-908] - Fix for PARQUET-890 introduces undefined symbol in libparquet_arrow.so
+ * [PARQUET-793] - [CPP] Do not return incorrect statistics
+ * [PARQUET-887] - C++: Fix issues in release scripts arise in RC1
+
+## Improvement
+ * [PARQUET-277] - Remove boost dependency
+ * [PARQUET-500] - Enable coveralls.io for apache/parquet-cpp
+ * [PARQUET-497] - Decouple Parquet physical file structure from FileReader class
+ * [PARQUET-597] - Add data rates to benchmark output
+ * [PARQUET-522] - #include cleanup with include-what-you-use
+ * [PARQUET-515] - Add "Reset" to LevelEncoder and LevelDecoder
+ * [PARQUET-514] - Automate coveralls.io updates in Travis CI
+ * [PARQUET-551] - Handle compiler warnings due to disabled DCHECKs in release builds
+ * [PARQUET-559] - Enable InputStream as a source to the ParquetFileReader
+ * [PARQUET-562] - Simplified ZSH support in build scripts
+ * [PARQUET-538] - Improve ColumnReader Tests
+ * [PARQUET-541] - Portable build scripts
+ * [PARQUET-724] - Test more advanced properties setting
+ * [PARQUET-641] - Instantiate stringstream only if needed in SerializedPageReader::NextPage
+ * [PARQUET-636] - Expose selection for different encodings
+ * [PARQUET-603] - Implement missing information in schema descriptor
+ * [PARQUET-610] - Print ColumnMetaData for each RowGroup
+ * [PARQUET-600] - Add benchmarks for RLE-Level encoding
+ * [PARQUET-592] - Support compressed writes
+ * [PARQUET-593] - Add API for writing Page statistics
+ * [PARQUET-589] - Implement Chunked InMemoryInputStream for better memory usage
+ * [PARQUET-587] - Implement BufferReader::Read(int64_t,uint8_t*)
+ * [PARQUET-616] - C++: WriteBatch should accept const arrays
+ * [PARQUET-630] - C++: Support link flags for older CMake versions
+ * [PARQUET-634] - Consistent private linking of dependencies
+ * [PARQUET-633] - Add version to WriterProperties
+ * [PARQUET-625] - Improve RLE read performance
+ * [PARQUET-737] - Use absolute namespace in macros
+ * [PARQUET-762] - C++: Use optimistic allocation instead of Arrow Builders
+ * [PARQUET-773] - C++: Check licenses with RAT in CI
+ * [PARQUET-687] - C++: Switch to PLAIN encoding if dictionary grows too large
+ * [PARQUET-784] - C++: Reference Spark, Kudu and FrameOfReference in LICENSE
+ * [PARQUET-809] - [C++] Add API to determine if two files' schemas are compatible
+ * [PARQUET-778] - Standardize the schema output to match the parquet-mr format
+ * [PARQUET-463] - Add DCHECK* macros for assertions in debug builds
+ * [PARQUET-471] - Use the same environment setup script for Travis CI as local sandbox development
+ * [PARQUET-449] - Update to latest parquet.thrift
+ * [PARQUET-496] - Fix cpplint configuration to be more restrictive
+ * [PARQUET-468] - Add a cmake option to generate the Parquet thrift headers with the thriftc in the environment
+ * [PARQUET-482] - Organize src code file structure to have a very clear folder with public headers.
+ * [PARQUET-591] - Page size estimation during writes
+ * [PARQUET-518] - Review usages of size_t and unsigned integers generally per Google style guide
+ * [PARQUET-533] - Simplify RandomAccessSource API to combine Seek/Read
+ * [PARQUET-767] - Add release scripts for parquet-cpp
+ * [PARQUET-699] - Update parquet.thrift from https://github.com/apache/parquet-format
+ * [PARQUET-653] - [C++] Re-enable -static-libstdc++ in dev artifact builds
+ * [PARQUET-763] - C++: Expose ParquetFileReader through Arrow reader
+ * [PARQUET-857] - [C++] Flatten parquet/encodings directory
+ * [PARQUET-862] - Provide defaut cache size values if CPU info probing is not available
+ * [PARQUET-689] - C++: Compress DataPages eagerly
+ * [PARQUET-874] - [C++] Use default memory allocator from Arrow
+ * [PARQUET-267] - Detach thirdparty code from build configuration.
+ * [PARQUET-418] - Add a utility to print contents of a Parquet file to stdout
+ * [PARQUET-519] - Disable compiler warning supressions and fix all DEBUG build warnings
+ * [PARQUET-447] - Add Debug and Release build types and associated compiler flags
+ * [PARQUET-868] - C++: Build snappy with optimizations
+ * [PARQUET-894] - Fix compilation warning
+ * [PARQUET-883] - C++: Support non-standard gcc version strings
+ * [PARQUET-607] - Public Writer header
+ * [PARQUET-731] - [CPP] Add API to return metadata size and Skip reading values
+ * [PARQUET-628] - Link thrift privately
+ * [PARQUET-877] - C++: Update Arrow Hash, update Version in metadata.
+ * [PARQUET-547] - Refactor most templates to use DataType structs rather than the Type::type enum
+ * [PARQUET-882] - [CPP] Improve Application Version parsing
+ * [PARQUET-448] - Add cmake option to skip building the unit tests
+ * [PARQUET-721] - Performance benchmarks for reading into Arrow structures
+ * [PARQUET-820] - C++: Decoders should directly emit arrays with spacing for null entries
+ * [PARQUET-813] - C++: Build dependencies using CMake External project
+ * [PARQUET-488] - Add SSE-related cmake options to manage compiler flags
+ * [PARQUET-564] - Add option to run unit tests with valgrind --tool=memcheck
+ * [PARQUET-572] - Rename parquet_cpp namespace to parquet
+ * [PARQUET-829] - C++: Make use of ARROW-469
+ * [PARQUET-501] - Add an OutputStream abstraction (capable of memory allocation) for Encoder public API
+ * [PARQUET-744] - Clarifications on build instructions
+ * [PARQUET-520] - Add version of LocalFileSource that uses memory-mapping for zero-copy reads
+ * [PARQUET-556] - Extend RowGroupStatistics to include "min" "max" statistics
+ * [PARQUET-671] - Improve performance of RLE/bit-packed decoding in parquet-cpp
+ * [PARQUET-681] - Add tool to scan a parquet file
+
+## New Feature
+ * [PARQUET-499] - Complete PlainEncoder implementation for all primitive types and test end to end
+ * [PARQUET-439] - Conform all copyright headers to ASF requirements
+ * [PARQUET-436] - Implement ParquetFileWriter class entry point for generating new Parquet files
+ * [PARQUET-435] - Provide vectorized ColumnReader interface
+ * [PARQUET-438] - Update RLE encoder/decoder modules from Impala upstream changes and adapt unit tests
+ * [PARQUET-512] - Add optional google/benchmark 3rd-party dependency for performance testing
+ * [PARQUET-566] - Add method to retrieve the full column path
+ * [PARQUET-613] - C++: Add conda packaging recipe
+ * [PARQUET-605] - Expose schema node in ColumnDescriptor
+ * [PARQUET-619] - C++: Add OutputStream for local files
+ * [PARQUET-583] - Implement Parquet to Thrift schema conversion
+ * [PARQUET-582] - Conversion functions for Parquet enums to Thrift enums
+ * [PARQUET-728] - [C++] Bring parquet::arrow up to date with API changes in arrow::io
+ * [PARQUET-752] - [C++] Conform parquet_arrow to upstream API changes
+ * [PARQUET-788] - [C++] Reference Impala / Apache Impala (incubating) in LICENSE
+ * [PARQUET-808] - [C++] Add API to read file given externally-provided FileMetadata
+ * [PARQUET-807] - [C++] Add API to read file metadata only from a file handle
+ * [PARQUET-805] - C++: Read Int96 into Arrow Timestamp(ns)
+ * [PARQUET-836] - [C++] Add column selection to parquet::arrow::FileReader
+ * [PARQUET-835] - [C++] Add option to parquet::arrow to read columns in parallel using a thread pool
+ * [PARQUET-830] - [C++] Add additional configuration options to parquet::arrow::OpenFIle
+ * [PARQUET-769] - C++: Add support for Brotli Compression
+ * [PARQUET-489] - Add visibility macros to be used for public and internal APIs of libparquet
+ * [PARQUET-542] - Support memory allocation from external memory
+ * [PARQUET-844] - [C++] Consolidate encodings, schema, and compression subdirectories into fewer files
+ * [PARQUET-848] - [C++] Consolidate libparquet_thrift subcomponent
+ * [PARQUET-646] - [C++] Enable easier 3rd-party toolchain clang builds on Linux
+ * [PARQUET-598] - [C++] Test writing all primitive data types
+ * [PARQUET-442] - Convert flat SchemaElement vector to implied nested schema data structure
+ * [PARQUET-867] - [C++] Support writing sliced Arrow arrays
+ * [PARQUET-456] - Add zlib codec support
+ * [PARQUET-834] - C++: Support r/w of arrow::ListArray
+ * [PARQUET-485] - Decouple data page delimiting from column reader / scanner classes, create test fixtures
+ * [PARQUET-434] - Add a ParquetFileReader class to encapsulate some low-level details of interacting with Parquet files
+ * [PARQUET-666] - PLAIN_DICTIONARY write support
+ * [PARQUET-437] - Incorporate googletest thirdparty dependency and add cmake tools (ADD_PARQUET_TEST) to simplify adding new unit tests
+ * [PARQUET-866] - [C++] Account for API changes in ARROW-33
+ * [PARQUET-545] - Improve API to support Decimal type
+ * [PARQUET-579] - Add API for writing Column statistics
+ * [PARQUET-494] - Implement PLAIN_DICTIONARY encoding and decoding
+ * [PARQUET-618] - C++: Automatically upload conda build artifacts on commits to master
+ * [PARQUET-833] - C++: Provide API to write spaced arrays (e.g. Arrow)
+ * [PARQUET-903] - C++: Add option to set RPATH to ORIGIN
+ * [PARQUET-451] - Add a RowGroup reader interface class
+ * [PARQUET-785] - C++: List conversion for Arrow Schemas
+ * [PARQUET-712] - C++: Read into Arrow memory
+ * [PARQUET-890] - C++: Support I/O of DATE columns in parquet_arrow
+ * [PARQUET-782] - C++: Support writing to Arrow sinks
+ * [PARQUET-849] - [C++] Upgrade default Thrift in thirdparty toolchain to 0.9.3 or 0.10
+ * [PARQUET-573] - C++: Create a public API for reading and writing file metadata
+
+## Task
+ * [PARQUET-814] - C++: Remove Conda recipes
+ * [PARQUET-503] - Re-enable parquet 2.0 encodings
+ * [PARQUET-169] - Parquet-cpp: Implement support for bulk reading and writing repetition/definition levels.
+ * [PARQUET-878] - C++: Remove setup_build_env from rc-verification script
+ * [PARQUET-881] - C++: Update Arrow hash to 0.2.0-rc2
+ * [PARQUET-771] - C++: Sync KEYS file
+ * [PARQUET-901] - C++: Publish RCs in apache-parquet-VERSION in SVN
+
+## Test
+ * [PARQUET-525] - Test coverage for malformed file failure modes on the read path
+ * [PARQUET-703] - [C++] Validate num_values metadata for columns with nulls
+ * [PARQUET-507] - Improve runtime of rle-test.cc
+ * [PARQUET-549] - Add scanner and column reader tests for dictionary data pages
+ * [PARQUET-457] - Add compressed data page unit tests
diff --git a/src/arrow/cpp/CMakeLists.txt b/src/arrow/cpp/CMakeLists.txt
new file mode 100644
index 000000000..65b9d96b2
--- /dev/null
+++ b/src/arrow/cpp/CMakeLists.txt
@@ -0,0 +1,958 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+cmake_minimum_required(VERSION 3.5)
+message(STATUS "Building using CMake version: ${CMAKE_VERSION}")
+
+# Compiler id for Apple Clang is now AppleClang.
+# https://www.cmake.org/cmake/help/latest/policy/CMP0025.html
+cmake_policy(SET CMP0025 NEW)
+
+# Only interpret if() arguments as variables or keywords when unquoted.
+# https://www.cmake.org/cmake/help/latest/policy/CMP0054.html
+cmake_policy(SET CMP0054 NEW)
+
+# Support new if() IN_LIST operator.
+# https://www.cmake.org/cmake/help/latest/policy/CMP0057.html
+cmake_policy(SET CMP0057 NEW)
+
+# Adapted from Apache Kudu: https://github.com/apache/kudu/commit/bd549e13743a51013585
+# Honor visibility properties for all target types.
+# https://www.cmake.org/cmake/help/latest/policy/CMP0063.html
+cmake_policy(SET CMP0063 NEW)
+
+# RPATH settings on macOS do not affect install_name.
+# https://cmake.org/cmake/help/latest/policy/CMP0068.html
+if(POLICY CMP0068)
+ cmake_policy(SET CMP0068 NEW)
+endif()
+
+# find_package() uses <PackageName>_ROOT variables.
+# https://cmake.org/cmake/help/latest/policy/CMP0074.html
+if(POLICY CMP0074)
+ cmake_policy(SET CMP0074 NEW)
+endif()
+
+set(ARROW_VERSION "6.0.1")
+
+string(REGEX MATCH "^[0-9]+\\.[0-9]+\\.[0-9]+" ARROW_BASE_VERSION "${ARROW_VERSION}")
+
+# if no build build type is specified, default to release builds
+if(NOT DEFINED CMAKE_BUILD_TYPE)
+ set(CMAKE_BUILD_TYPE
+ Release
+ CACHE STRING "Choose the type of build.")
+endif()
+string(TOLOWER ${CMAKE_BUILD_TYPE} LOWERCASE_BUILD_TYPE)
+string(TOUPPER ${CMAKE_BUILD_TYPE} UPPERCASE_BUILD_TYPE)
+
+list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules")
+
+# this must be included before the project() command, because of the way
+# vcpkg (ab)uses CMAKE_TOOLCHAIN_FILE to inject its logic into CMake
+if(ARROW_DEPENDENCY_SOURCE STREQUAL "VCPKG")
+ include(Usevcpkg)
+endif()
+
+project(arrow VERSION "${ARROW_BASE_VERSION}")
+
+set(ARROW_VERSION_MAJOR "${arrow_VERSION_MAJOR}")
+set(ARROW_VERSION_MINOR "${arrow_VERSION_MINOR}")
+set(ARROW_VERSION_PATCH "${arrow_VERSION_PATCH}")
+if(ARROW_VERSION_MAJOR STREQUAL ""
+ OR ARROW_VERSION_MINOR STREQUAL ""
+ OR ARROW_VERSION_PATCH STREQUAL "")
+ message(FATAL_ERROR "Failed to determine Arrow version from '${ARROW_VERSION}'")
+endif()
+
+# The SO version is also the ABI version
+if(ARROW_VERSION_MAJOR STREQUAL "0")
+ # Arrow 0.x.y => SO version is "x", full SO version is "x.y.0"
+ set(ARROW_SO_VERSION "${ARROW_VERSION_MINOR}")
+ set(ARROW_FULL_SO_VERSION "${ARROW_SO_VERSION}.${ARROW_VERSION_PATCH}.0")
+else()
+ # Arrow 1.x.y => SO version is "10x", full SO version is "10x.y.0"
+ math(EXPR ARROW_SO_VERSION "${ARROW_VERSION_MAJOR} * 100 + ${ARROW_VERSION_MINOR}")
+ set(ARROW_FULL_SO_VERSION "${ARROW_SO_VERSION}.${ARROW_VERSION_PATCH}.0")
+endif()
+
+message(STATUS "Arrow version: "
+ "${ARROW_VERSION_MAJOR}.${ARROW_VERSION_MINOR}.${ARROW_VERSION_PATCH} "
+ "(full: '${ARROW_VERSION}')")
+message(STATUS "Arrow SO version: ${ARROW_SO_VERSION} (full: ${ARROW_FULL_SO_VERSION})")
+
+set(ARROW_SOURCE_DIR ${PROJECT_SOURCE_DIR})
+set(ARROW_BINARY_DIR ${PROJECT_BINARY_DIR})
+
+include(CMakePackageConfigHelpers)
+include(CMakeParseArguments)
+include(ExternalProject)
+include(FindPackageHandleStandardArgs)
+
+include(GNUInstallDirs)
+
+set(BUILD_SUPPORT_DIR "${CMAKE_SOURCE_DIR}/build-support")
+
+set(ARROW_CMAKE_INSTALL_DIR "${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}")
+set(ARROW_DOC_DIR "share/doc/${PROJECT_NAME}")
+
+set(ARROW_LLVM_VERSIONS
+ "13.0"
+ "12.0"
+ "11.1"
+ "11.0"
+ "10"
+ "9"
+ "8"
+ "7")
+list(GET ARROW_LLVM_VERSIONS 0 ARROW_LLVM_VERSION_PRIMARY)
+string(REGEX REPLACE "^([0-9]+)(\\..+)?" "\\1" ARROW_LLVM_VERSION_PRIMARY_MAJOR
+ "${ARROW_LLVM_VERSION_PRIMARY}")
+
+file(READ ${CMAKE_CURRENT_SOURCE_DIR}/../.env ARROW_ENV)
+string(REGEX MATCH "CLANG_TOOLS=[^\n]+" ARROW_ENV_CLANG_TOOLS_VERSION "${ARROW_ENV}")
+string(REGEX REPLACE "^CLANG_TOOLS=" "" ARROW_CLANG_TOOLS_VERSION
+ "${ARROW_ENV_CLANG_TOOLS_VERSION}")
+string(REGEX REPLACE "^([0-9]+)(\\..+)?" "\\1" ARROW_CLANG_TOOLS_VERSION_MAJOR
+ "${ARROW_CLANG_TOOLS_VERSION}")
+
+if(APPLE)
+ find_program(BREW_BIN brew)
+ if(BREW_BIN)
+ execute_process(COMMAND ${BREW_BIN} --prefix
+ "llvm@${ARROW_LLVM_VERSION_PRIMARY_MAJOR}"
+ OUTPUT_VARIABLE LLVM_BREW_PREFIX
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+ if(NOT LLVM_BREW_PREFIX)
+ execute_process(COMMAND ${BREW_BIN} --prefix llvm
+ OUTPUT_VARIABLE LLVM_BREW_PREFIX
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+ endif()
+
+ execute_process(COMMAND ${BREW_BIN} --prefix "llvm@${ARROW_CLANG_TOOLS_VERSION_MAJOR}"
+ OUTPUT_VARIABLE CLANG_TOOLS_BREW_PREFIX
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+ if(NOT CLANG_TOOLS_BREW_PREFIX)
+ execute_process(COMMAND ${BREW_BIN} --prefix llvm
+ OUTPUT_VARIABLE CLANG_TOOLS_BREW_PREFIX
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+ endif()
+ endif()
+endif()
+
+if(WIN32 AND NOT MINGW)
+ # This is used to handle builds using e.g. clang in an MSVC setting.
+ set(MSVC_TOOLCHAIN TRUE)
+else()
+ set(MSVC_TOOLCHAIN FALSE)
+endif()
+
+find_package(ClangTools)
+find_package(InferTools)
+if("$ENV{CMAKE_EXPORT_COMPILE_COMMANDS}" STREQUAL "1"
+ OR CLANG_TIDY_FOUND
+ OR INFER_FOUND)
+ # Generate a Clang compile_commands.json "compilation database" file for use
+ # with various development tools, such as Vim's YouCompleteMe plugin.
+ # See http://clang.llvm.org/docs/JSONCompilationDatabase.html
+ set(CMAKE_EXPORT_COMPILE_COMMANDS 1)
+endif()
+
+# ----------------------------------------------------------------------
+# cmake options
+include(DefineOptions)
+
+# Needed for linting targets, etc.
+if(${CMAKE_VERSION} VERSION_LESS "3.12.0")
+ find_package(PythonInterp)
+else()
+ # Use the first Python installation on PATH, not the newest one
+ set(Python3_FIND_STRATEGY "LOCATION")
+ # On Windows, use registry last, not first
+ set(Python3_FIND_REGISTRY "LAST")
+ # On macOS, use framework last, not first
+ set(Python3_FIND_FRAMEWORK "LAST")
+
+ find_package(Python3)
+ set(PYTHON_EXECUTABLE ${Python3_EXECUTABLE})
+endif()
+
+if(ARROW_USE_CCACHE)
+ find_program(CCACHE_FOUND ccache)
+ if(CCACHE_FOUND)
+ message(STATUS "Using ccache: ${CCACHE_FOUND}")
+ set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ${CCACHE_FOUND})
+ set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ${CCACHE_FOUND})
+ # ARROW-3985: let ccache preserve C++ comments, because some of them may be
+ # meaningful to the compiler
+ set(ENV{CCACHE_COMMENTS} "1")
+ endif(CCACHE_FOUND)
+endif()
+
+if(ARROW_USE_PRECOMPILED_HEADERS AND ${CMAKE_VERSION} VERSION_LESS "3.16.0")
+ message(WARNING "Precompiled headers need CMake 3.16.0 or later, disabling")
+ set(ARROW_USE_PRECOMPILED_HEADERS OFF)
+endif()
+
+if(ARROW_OPTIONAL_INSTALL)
+ # Don't make the "install" target depend on the "all" target
+ set(CMAKE_SKIP_INSTALL_ALL_DEPENDENCY true)
+
+ set(INSTALL_IS_OPTIONAL OPTIONAL)
+endif()
+
+#
+# "make lint" target
+#
+if(NOT ARROW_VERBOSE_LINT)
+ set(ARROW_LINT_QUIET "--quiet")
+endif()
+
+if(NOT LINT_EXCLUSIONS_FILE)
+ # source files matching a glob from a line in this file
+ # will be excluded from linting (cpplint, clang-tidy, clang-format)
+ set(LINT_EXCLUSIONS_FILE ${BUILD_SUPPORT_DIR}/lint_exclusions.txt)
+endif()
+
+find_program(CPPLINT_BIN
+ NAMES cpplint cpplint.py
+ HINTS ${BUILD_SUPPORT_DIR})
+message(STATUS "Found cpplint executable at ${CPPLINT_BIN}")
+
+add_custom_target(lint
+ ${PYTHON_EXECUTABLE}
+ ${BUILD_SUPPORT_DIR}/run_cpplint.py
+ --cpplint_binary
+ ${CPPLINT_BIN}
+ --exclude_globs
+ ${LINT_EXCLUSIONS_FILE}
+ --source_dir
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ ${ARROW_LINT_QUIET})
+
+#
+# "make format" and "make check-format" targets
+#
+if(${CLANG_FORMAT_FOUND})
+ # runs clang format and updates files in place.
+ add_custom_target(format
+ ${PYTHON_EXECUTABLE}
+ ${BUILD_SUPPORT_DIR}/run_clang_format.py
+ --clang_format_binary
+ ${CLANG_FORMAT_BIN}
+ --exclude_globs
+ ${LINT_EXCLUSIONS_FILE}
+ --source_dir
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ --fix
+ ${ARROW_LINT_QUIET})
+
+ # runs clang format and exits with a non-zero exit code if any files need to be reformatted
+ add_custom_target(check-format
+ ${PYTHON_EXECUTABLE}
+ ${BUILD_SUPPORT_DIR}/run_clang_format.py
+ --clang_format_binary
+ ${CLANG_FORMAT_BIN}
+ --exclude_globs
+ ${LINT_EXCLUSIONS_FILE}
+ --source_dir
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ ${ARROW_LINT_QUIET})
+endif()
+
+add_custom_target(lint_cpp_cli ${PYTHON_EXECUTABLE} ${BUILD_SUPPORT_DIR}/lint_cpp_cli.py
+ ${CMAKE_CURRENT_SOURCE_DIR}/src)
+
+if(ARROW_LINT_ONLY)
+ message("ARROW_LINT_ONLY was specified, this is only a partial build directory")
+ return()
+endif()
+
+#
+# "make clang-tidy" and "make check-clang-tidy" targets
+#
+if(${CLANG_TIDY_FOUND})
+ # TODO check to make sure .clang-tidy is being respected
+
+ # runs clang-tidy and attempts to fix any warning automatically
+ add_custom_target(clang-tidy
+ ${PYTHON_EXECUTABLE}
+ ${BUILD_SUPPORT_DIR}/run_clang_tidy.py
+ --clang_tidy_binary
+ ${CLANG_TIDY_BIN}
+ --exclude_globs
+ ${LINT_EXCLUSIONS_FILE}
+ --compile_commands
+ ${CMAKE_BINARY_DIR}/compile_commands.json
+ --source_dir
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ --fix
+ ${ARROW_LINT_QUIET})
+
+ # runs clang-tidy and exits with a non-zero exit code if any errors are found.
+ add_custom_target(check-clang-tidy
+ ${PYTHON_EXECUTABLE}
+ ${BUILD_SUPPORT_DIR}/run_clang_tidy.py
+ --clang_tidy_binary
+ ${CLANG_TIDY_BIN}
+ --exclude_globs
+ ${LINT_EXCLUSIONS_FILE}
+ --compile_commands
+ ${CMAKE_BINARY_DIR}/compile_commands.json
+ --source_dir
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ ${ARROW_LINT_QUIET})
+endif()
+
+if(UNIX)
+ add_custom_target(iwyu ${BUILD_SUPPORT_DIR}/iwyu/iwyu.sh)
+ add_custom_target(iwyu-all ${BUILD_SUPPORT_DIR}/iwyu/iwyu.sh all)
+endif(UNIX)
+
+#
+# Set up various options
+#
+
+if(ARROW_BUILD_BENCHMARKS
+ OR ARROW_BUILD_TESTS
+ OR ARROW_BUILD_INTEGRATION
+ OR ARROW_FUZZING)
+ set(ARROW_JSON ON)
+ set(ARROW_TESTING ON)
+endif()
+
+if(ARROW_GANDIVA)
+ set(ARROW_WITH_RE2 ON)
+endif()
+
+if(ARROW_CUDA
+ OR ARROW_FLIGHT
+ OR ARROW_PARQUET
+ OR ARROW_BUILD_TESTS
+ OR ARROW_BUILD_BENCHMARKS)
+ set(ARROW_IPC ON)
+endif()
+
+if(ARROW_ENGINE)
+ set(ARROW_COMPUTE ON)
+endif()
+
+if(ARROW_DATASET)
+ set(ARROW_COMPUTE ON)
+ set(ARROW_FILESYSTEM ON)
+endif()
+
+if(ARROW_PARQUET)
+ set(ARROW_COMPUTE ON)
+endif()
+
+if(ARROW_PYTHON)
+ set(ARROW_COMPUTE ON)
+ set(ARROW_CSV ON)
+ set(ARROW_DATASET ON)
+ set(ARROW_FILESYSTEM ON)
+ set(ARROW_HDFS ON)
+ set(ARROW_JSON ON)
+endif()
+
+if(MSVC_TOOLCHAIN)
+ # ORC doesn't build on windows
+ set(ARROW_ORC OFF)
+ # Plasma using glog is not fully tested on windows.
+ set(ARROW_USE_GLOG OFF)
+endif()
+
+if(ARROW_JNI)
+ set(ARROW_BUILD_STATIC ON)
+endif()
+
+if(ARROW_ORC)
+ set(ARROW_WITH_LZ4 ON)
+ set(ARROW_WITH_SNAPPY ON)
+ set(ARROW_WITH_ZLIB ON)
+ set(ARROW_WITH_ZSTD ON)
+endif()
+
+# datetime code used by iOS requires zlib support
+if(IOS)
+ set(ARROW_WITH_ZLIB ON)
+endif()
+
+if(NOT ARROW_BUILD_TESTS)
+ set(NO_TESTS 1)
+else()
+ add_custom_target(all-tests)
+ add_custom_target(unittest
+ ctest
+ -j4
+ -L
+ unittest
+ --output-on-failure)
+ add_dependencies(unittest all-tests)
+endif()
+
+if(ARROW_ENABLE_TIMING_TESTS)
+ add_definitions(-DARROW_WITH_TIMING_TESTS)
+endif()
+
+if(NOT ARROW_BUILD_BENCHMARKS)
+ set(NO_BENCHMARKS 1)
+else()
+ add_custom_target(all-benchmarks)
+ add_custom_target(benchmark ctest -L benchmark)
+ add_dependencies(benchmark all-benchmarks)
+ if(ARROW_BUILD_BENCHMARKS_REFERENCE)
+ add_definitions(-DARROW_WITH_BENCHMARKS_REFERENCE)
+ endif()
+endif()
+
+if(NOT ARROW_BUILD_EXAMPLES)
+ set(NO_EXAMPLES 1)
+endif()
+
+if(NOT ARROW_FUZZING)
+ set(NO_FUZZING 1)
+endif()
+
+if(ARROW_LARGE_MEMORY_TESTS)
+ add_definitions(-DARROW_LARGE_MEMORY_TESTS)
+endif()
+
+if(ARROW_TEST_MEMCHECK)
+ add_definitions(-DARROW_VALGRIND)
+endif()
+
+if(ARROW_USE_UBSAN)
+ add_definitions(-DARROW_UBSAN)
+endif()
+
+#
+# Compiler flags
+#
+
+if(ARROW_NO_DEPRECATED_API)
+ add_definitions(-DARROW_NO_DEPRECATED_API)
+endif()
+
+if(ARROW_EXTRA_ERROR_CONTEXT)
+ add_definitions(-DARROW_EXTRA_ERROR_CONTEXT)
+endif()
+
+include(SetupCxxFlags)
+
+#
+# Linker flags
+#
+
+# Localize thirdparty symbols using a linker version script. This hides them
+# from the client application. The OS X linker does not support the
+# version-script option.
+if(CMAKE_VERSION VERSION_LESS 3.18)
+ if(APPLE OR WIN32)
+ set(CXX_LINKER_SUPPORTS_VERSION_SCRIPT FALSE)
+ else()
+ set(CXX_LINKER_SUPPORTS_VERSION_SCRIPT TRUE)
+ endif()
+else()
+ include(CheckLinkerFlag)
+ check_linker_flag(CXX
+ "-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/src/arrow/symbols.map"
+ CXX_LINKER_SUPPORTS_VERSION_SCRIPT)
+endif()
+
+#
+# Build output directory
+#
+
+# set compile output directory
+string(TOLOWER ${CMAKE_BUILD_TYPE} BUILD_SUBDIR_NAME)
+
+# If build in-source, create the latest symlink. If build out-of-source, which is
+# preferred, simply output the binaries in the build folder
+if(${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_CURRENT_BINARY_DIR})
+ set(BUILD_OUTPUT_ROOT_DIRECTORY
+ "${CMAKE_CURRENT_BINARY_DIR}/build/${BUILD_SUBDIR_NAME}/")
+ # Link build/latest to the current build directory, to avoid developers
+ # accidentally running the latest debug build when in fact they're building
+ # release builds.
+ file(MAKE_DIRECTORY ${BUILD_OUTPUT_ROOT_DIRECTORY})
+ if(NOT APPLE)
+ set(MORE_ARGS "-T")
+ endif()
+ execute_process(COMMAND ln ${MORE_ARGS} -sf ${BUILD_OUTPUT_ROOT_DIRECTORY}
+ ${CMAKE_CURRENT_BINARY_DIR}/build/latest)
+else()
+ set(BUILD_OUTPUT_ROOT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${BUILD_SUBDIR_NAME}/")
+endif()
+
+# where to put generated archives (.a files)
+set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${BUILD_OUTPUT_ROOT_DIRECTORY}")
+set(ARCHIVE_OUTPUT_DIRECTORY "${BUILD_OUTPUT_ROOT_DIRECTORY}")
+
+# where to put generated libraries (.so files)
+set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${BUILD_OUTPUT_ROOT_DIRECTORY}")
+set(LIBRARY_OUTPUT_DIRECTORY "${BUILD_OUTPUT_ROOT_DIRECTORY}")
+
+# where to put generated binaries
+set(EXECUTABLE_OUTPUT_PATH "${BUILD_OUTPUT_ROOT_DIRECTORY}")
+
+if(CMAKE_GENERATOR STREQUAL Xcode)
+ # Xcode projects support multi-configuration builds. This forces a single output directory
+ # when building with Xcode that is consistent with single-configuration Makefile driven build.
+ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_${UPPERCASE_BUILD_TYPE}
+ "${BUILD_OUTPUT_ROOT_DIRECTORY}")
+ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_${UPPERCASE_BUILD_TYPE}
+ "${BUILD_OUTPUT_ROOT_DIRECTORY}")
+ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_${UPPERCASE_BUILD_TYPE}
+ "${BUILD_OUTPUT_ROOT_DIRECTORY}")
+endif()
+
+#
+# Dependencies
+#
+
+include(BuildUtils)
+enable_testing()
+
+# For arrow.pc. Requires.private and Libs.private are used when
+# "pkg-config --libs --static arrow" is used.
+set(ARROW_PC_REQUIRES_PRIVATE)
+set(ARROW_PC_LIBS_PRIVATE)
+
+include(ThirdpartyToolchain)
+
+# Add common flags
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX_COMMON_FLAGS}")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARROW_CXXFLAGS}")
+
+# For any C code, use the same flags. These flags don't contain
+# C++ specific flags.
+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${CXX_COMMON_FLAGS} ${ARROW_CXXFLAGS}")
+
+# Remove --std=c++11 to avoid errors from C compilers
+string(REPLACE "-std=c++11" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
+
+# Add C++-only flags, like -std=c++11
+set(CMAKE_CXX_FLAGS "${CXX_ONLY_FLAGS} ${CMAKE_CXX_FLAGS}")
+
+# ASAN / TSAN / UBSAN
+if(ARROW_FUZZING)
+ set(ARROW_USE_COVERAGE ON)
+endif()
+include(san-config)
+
+# Code coverage
+if("${ARROW_GENERATE_COVERAGE}")
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} --coverage -DCOVERAGE_BUILD")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --coverage -DCOVERAGE_BUILD")
+endif()
+
+# CMAKE_CXX_FLAGS now fully assembled
+message(STATUS "CMAKE_C_FLAGS: ${CMAKE_C_FLAGS}")
+message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
+
+include_directories(${CMAKE_CURRENT_BINARY_DIR}/src)
+include_directories(src)
+
+# Compiled flatbuffers files
+include_directories(src/generated)
+
+#
+# Visibility
+#
+if(PARQUET_BUILD_SHARED)
+ set_target_properties(arrow_shared
+ PROPERTIES C_VISIBILITY_PRESET hidden
+ CXX_VISIBILITY_PRESET hidden
+ VISIBILITY_INLINES_HIDDEN 1)
+endif()
+
+#
+# "make ctags" target
+#
+if(UNIX)
+ add_custom_target(ctags ctags -R --languages=c++,c)
+endif(UNIX)
+
+#
+# "make etags" target
+#
+if(UNIX)
+ add_custom_target(tags
+ etags
+ --members
+ --declarations
+ `find
+ ${CMAKE_CURRENT_SOURCE_DIR}/src
+ -name
+ \\*.cc
+ -or
+ -name
+ \\*.hh
+ -or
+ -name
+ \\*.cpp
+ -or
+ -name
+ \\*.h
+ -or
+ -name
+ \\*.c
+ -or
+ -name
+ \\*.f`)
+ add_custom_target(etags DEPENDS tags)
+endif(UNIX)
+
+#
+# "make cscope" target
+#
+if(UNIX)
+ add_custom_target(cscope
+ find
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ (-name
+ \\*.cc
+ -or
+ -name
+ \\*.hh
+ -or
+ -name
+ \\*.cpp
+ -or
+ -name
+ \\*.h
+ -or
+ -name
+ \\*.c
+ -or
+ -name
+ \\*.f)
+ -exec
+ echo
+ \"{}\"
+ \;
+ >
+ cscope.files
+ &&
+ cscope
+ -q
+ -b
+ VERBATIM)
+endif(UNIX)
+
+#
+# "make infer" target
+#
+
+if(${INFER_FOUND})
+ # runs infer capture
+ add_custom_target(infer ${BUILD_SUPPORT_DIR}/run-infer.sh ${INFER_BIN}
+ ${CMAKE_BINARY_DIR}/compile_commands.json 1)
+ # runs infer analyze
+ add_custom_target(infer-analyze ${BUILD_SUPPORT_DIR}/run-infer.sh ${INFER_BIN}
+ ${CMAKE_BINARY_DIR}/compile_commands.json 2)
+ # runs infer report
+ add_custom_target(infer-report ${BUILD_SUPPORT_DIR}/run-infer.sh ${INFER_BIN}
+ ${CMAKE_BINARY_DIR}/compile_commands.json 3)
+endif()
+
+#
+# Linker and Dependencies
+#
+
+# Libraries to link statically with libarrow.so
+set(ARROW_LINK_LIBS)
+set(ARROW_STATIC_LINK_LIBS)
+set(ARROW_STATIC_INSTALL_INTERFACE_LIBS)
+
+if(ARROW_USE_OPENSSL)
+ set(ARROW_OPENSSL_LIBS OpenSSL::Crypto OpenSSL::SSL)
+ list(APPEND ARROW_LINK_LIBS ${ARROW_OPENSSL_LIBS})
+ list(APPEND ARROW_STATIC_LINK_LIBS ${ARROW_OPENSSL_LIBS})
+ list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS ${ARROW_OPENSSL_LIBS})
+endif()
+
+if(ARROW_WITH_BROTLI)
+ # Order is important for static linking
+ set(ARROW_BROTLI_LIBS Brotli::brotlienc Brotli::brotlidec Brotli::brotlicommon)
+ list(APPEND ARROW_LINK_LIBS ${ARROW_BROTLI_LIBS})
+ list(APPEND ARROW_STATIC_LINK_LIBS ${ARROW_BROTLI_LIBS})
+ if(Brotli_SOURCE STREQUAL "SYSTEM")
+ list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS ${ARROW_BROTLI_LIBS})
+ endif()
+endif()
+
+if(ARROW_WITH_BZ2)
+ list(APPEND ARROW_STATIC_LINK_LIBS BZip2::BZip2)
+ if(BZip2_SOURCE STREQUAL "SYSTEM")
+ list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS BZip2::BZip2)
+ endif()
+endif()
+
+if(ARROW_WITH_LZ4)
+ list(APPEND ARROW_STATIC_LINK_LIBS LZ4::lz4)
+ if(Lz4_SOURCE STREQUAL "SYSTEM")
+ list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS LZ4::lz4)
+ endif()
+endif()
+
+if(ARROW_WITH_SNAPPY)
+ list(APPEND ARROW_STATIC_LINK_LIBS Snappy::snappy)
+ if(Snappy_SOURCE STREQUAL "SYSTEM")
+ list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS Snappy::snappy)
+ endif()
+endif()
+
+if(ARROW_WITH_ZLIB)
+ list(APPEND ARROW_STATIC_LINK_LIBS ZLIB::ZLIB)
+ if(ZLIB_SOURCE STREQUAL "SYSTEM")
+ list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS ZLIB::ZLIB)
+ endif()
+endif()
+
+if(ARROW_WITH_ZSTD)
+ list(APPEND ARROW_STATIC_LINK_LIBS ${ARROW_ZSTD_LIBZSTD})
+ if(zstd_SOURCE STREQUAL "SYSTEM")
+ list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS ${ARROW_ZSTD_LIBZSTD})
+ endif()
+endif()
+
+if(ARROW_ORC)
+ list(APPEND ARROW_LINK_LIBS orc::liborc ${ARROW_PROTOBUF_LIBPROTOBUF})
+ list(APPEND ARROW_STATIC_LINK_LIBS orc::liborc ${ARROW_PROTOBUF_LIBPROTOBUF})
+ if(ORC_SOURCE STREQUAL "SYSTEM")
+ list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS orc::liborc
+ ${ARROW_PROTOBUF_LIBPROTOBUF})
+ endif()
+endif()
+
+if(ARROW_GCS)
+ list(APPEND ARROW_LINK_LIBS google-cloud-cpp::storage)
+ list(APPEND ARROW_STATIC_LINK_LIBS google-cloud-cpp::storage)
+ if(google_cloud_cpp_storage_SOURCE STREQUAL "SYSTEM")
+ list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS google-cloud-cpp::storage)
+ endif()
+endif()
+
+if(ARROW_USE_GLOG)
+ list(APPEND ARROW_LINK_LIBS glog::glog)
+ list(APPEND ARROW_STATIC_LINK_LIBS glog::glog)
+ if(GLOG_SOURCE STREQUAL "SYSTEM")
+ list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS glog::glog)
+ endif()
+ add_definitions("-DARROW_USE_GLOG")
+endif()
+
+if(ARROW_S3)
+ list(APPEND ARROW_LINK_LIBS ${AWSSDK_LINK_LIBRARIES})
+ list(APPEND ARROW_STATIC_LINK_LIBS ${AWSSDK_LINK_LIBRARIES})
+endif()
+
+if(ARROW_WITH_UTF8PROC)
+ list(APPEND ARROW_LINK_LIBS utf8proc::utf8proc)
+ list(APPEND ARROW_STATIC_LINK_LIBS utf8proc::utf8proc)
+ if(utf8proc_SOURCE STREQUAL "SYSTEM")
+ list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS utf8proc::utf8proc)
+ endif()
+endif()
+
+if(ARROW_WITH_RE2)
+ list(APPEND ARROW_LINK_LIBS re2::re2)
+ list(APPEND ARROW_STATIC_LINK_LIBS re2::re2)
+ if(re2_SOURCE STREQUAL "SYSTEM")
+ list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS re2::re2)
+ endif()
+endif()
+
+add_custom_target(arrow_dependencies)
+add_custom_target(arrow_benchmark_dependencies)
+add_custom_target(arrow_test_dependencies)
+
+# ARROW-4581: CMake can be finicky about invoking the ExternalProject builds
+# for some of the library dependencies, so we "nuke it from orbit" by making
+# the toolchain dependency explicit using these "dependencies" targets
+add_dependencies(arrow_dependencies toolchain)
+add_dependencies(arrow_test_dependencies toolchain-tests)
+
+if(ARROW_STATIC_LINK_LIBS)
+ add_dependencies(arrow_dependencies ${ARROW_STATIC_LINK_LIBS})
+ if(ARROW_ORC)
+ if(NOT MSVC_TOOLCHAIN)
+ list(APPEND ARROW_STATIC_LINK_LIBS ${CMAKE_DL_LIBS})
+ list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS ${CMAKE_DL_LIBS})
+ endif()
+ endif()
+endif()
+
+set(ARROW_SHARED_PRIVATE_LINK_LIBS ${ARROW_STATIC_LINK_LIBS})
+
+# boost::filesystem is needed for S3 and Flight tests as a boost::process dependency.
+if(((ARROW_FLIGHT
+ OR ARROW_S3
+ OR ARROW_GCS)
+ AND (ARROW_BUILD_TESTS OR ARROW_BUILD_INTEGRATION)
+ ))
+ list(APPEND ARROW_TEST_LINK_LIBS ${BOOST_FILESYSTEM_LIBRARY} ${BOOST_SYSTEM_LIBRARY})
+endif()
+
+if(NOT MSVC_TOOLCHAIN)
+ list(APPEND ARROW_LINK_LIBS ${CMAKE_DL_LIBS})
+ list(APPEND ARROW_SHARED_INSTALL_INTERFACE_LIBS ${CMAKE_DL_LIBS})
+endif()
+
+set(ARROW_TEST_LINK_TOOLCHAIN
+ GTest::gtest_main
+ GTest::gtest
+ GTest::gmock
+ ${BOOST_FILESYSTEM_LIBRARY}
+ ${BOOST_SYSTEM_LIBRARY})
+
+if(ARROW_BUILD_TESTS)
+ add_dependencies(arrow_test_dependencies ${ARROW_TEST_LINK_TOOLCHAIN})
+endif()
+
+if(ARROW_BUILD_BENCHMARKS)
+ # Some benchmarks use gtest
+ add_dependencies(arrow_benchmark_dependencies arrow_test_dependencies
+ toolchain-benchmarks)
+endif()
+
+set(ARROW_TEST_STATIC_LINK_LIBS arrow_testing_static arrow_static ${ARROW_LINK_LIBS}
+ ${ARROW_TEST_LINK_TOOLCHAIN})
+
+set(ARROW_TEST_SHARED_LINK_LIBS arrow_testing_shared arrow_shared ${ARROW_LINK_LIBS}
+ ${ARROW_TEST_LINK_TOOLCHAIN})
+
+if(NOT MSVC)
+ set(ARROW_TEST_SHARED_LINK_LIBS ${ARROW_TEST_SHARED_LINK_LIBS} ${CMAKE_DL_LIBS})
+endif()
+
+if("${ARROW_TEST_LINKAGE}" STREQUAL "shared")
+ if(ARROW_BUILD_TESTS AND NOT ARROW_BUILD_SHARED)
+ message(FATAL_ERROR "If using shared linkage for unit tests, must also \
+pass ARROW_BUILD_SHARED=on")
+ endif()
+ # Use shared linking for unit tests if it's available
+ set(ARROW_TEST_LINK_LIBS ${ARROW_TEST_SHARED_LINK_LIBS})
+ set(ARROW_EXAMPLE_LINK_LIBS arrow_shared)
+else()
+ if(ARROW_BUILD_TESTS AND NOT ARROW_BUILD_STATIC)
+ message(FATAL_ERROR "If using static linkage for unit tests, must also \
+pass ARROW_BUILD_STATIC=on")
+ endif()
+ set(ARROW_TEST_LINK_LIBS ${ARROW_TEST_STATIC_LINK_LIBS})
+ set(ARROW_EXAMPLE_LINK_LIBS arrow_static)
+endif()
+
+if(ARROW_BUILD_BENCHMARKS)
+ # In the case that benchmark::benchmark_main is not available,
+ # we need to provide our own version. This only happens for older versions
+ # of benchmark.
+ if(NOT TARGET benchmark::benchmark_main)
+ add_library(arrow_benchmark_main STATIC src/arrow/util/benchmark_main.cc)
+ add_library(benchmark::benchmark_main ALIAS arrow_benchmark_main)
+ endif()
+
+ set(ARROW_BENCHMARK_LINK_LIBS benchmark::benchmark_main benchmark::benchmark
+ ${ARROW_TEST_LINK_LIBS})
+ if(WIN32)
+ set(ARROW_BENCHMARK_LINK_LIBS Shlwapi.dll ${ARROW_BENCHMARK_LINK_LIBS})
+ endif()
+endif()
+
+if(ARROW_JEMALLOC)
+ add_definitions(-DARROW_JEMALLOC)
+ add_definitions(-DARROW_JEMALLOC_INCLUDE_DIR=${JEMALLOC_INCLUDE_DIR})
+ list(APPEND ARROW_LINK_LIBS jemalloc::jemalloc)
+ list(APPEND ARROW_STATIC_LINK_LIBS jemalloc::jemalloc)
+endif()
+
+if(ARROW_MIMALLOC)
+ add_definitions(-DARROW_MIMALLOC)
+ list(APPEND ARROW_LINK_LIBS mimalloc::mimalloc)
+ list(APPEND ARROW_STATIC_LINK_LIBS mimalloc::mimalloc)
+endif()
+
+# ----------------------------------------------------------------------
+# Handle platform-related libraries like -pthread
+
+set(ARROW_SYSTEM_LINK_LIBS)
+
+list(APPEND ARROW_SYSTEM_LINK_LIBS Threads::Threads)
+if(CMAKE_THREAD_LIBS_INIT)
+ string(APPEND ARROW_PC_LIBS_PRIVATE " ${CMAKE_THREAD_LIBS_INIT}")
+endif()
+
+if(WIN32)
+ # Winsock
+ list(APPEND ARROW_SYSTEM_LINK_LIBS "ws2_32.dll")
+endif()
+
+if(NOT WIN32 AND NOT APPLE)
+ # Pass -lrt on Linux only
+ list(APPEND ARROW_SYSTEM_LINK_LIBS rt)
+endif()
+
+list(APPEND ARROW_LINK_LIBS ${ARROW_SYSTEM_LINK_LIBS})
+list(APPEND ARROW_STATIC_LINK_LIBS ${ARROW_SYSTEM_LINK_LIBS})
+list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS ${ARROW_SYSTEM_LINK_LIBS})
+
+#
+# Subdirectories
+#
+
+if(NOT WIN32 AND ARROW_PLASMA)
+ add_subdirectory(src/plasma)
+endif()
+
+add_subdirectory(src/arrow)
+
+if(ARROW_PARQUET)
+ add_subdirectory(src/parquet)
+ add_subdirectory(tools/parquet)
+ if(PARQUET_BUILD_EXAMPLES)
+ add_subdirectory(examples/parquet)
+ endif()
+endif()
+
+if(ARROW_JNI)
+ add_subdirectory(src/jni)
+endif()
+
+if(ARROW_GANDIVA)
+ add_subdirectory(src/gandiva)
+endif()
+
+if(ARROW_BUILD_EXAMPLES)
+ add_custom_target(runexample ctest -L example)
+ add_subdirectory(examples/arrow)
+endif()
+
+install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/../LICENSE.txt
+ ${CMAKE_CURRENT_SOURCE_DIR}/../NOTICE.txt
+ ${CMAKE_CURRENT_SOURCE_DIR}/README.md DESTINATION "${ARROW_DOC_DIR}")
+
+#
+# Validate and print out Arrow configuration options
+#
+
+validate_config()
+config_summary_message()
+if(${ARROW_BUILD_CONFIG_SUMMARY_JSON})
+ config_summary_json()
+endif()
diff --git a/src/arrow/cpp/CMakeSettings.json b/src/arrow/cpp/CMakeSettings.json
new file mode 100644
index 000000000..90d3abbca
--- /dev/null
+++ b/src/arrow/cpp/CMakeSettings.json
@@ -0,0 +1,21 @@
+{
+ "configurations": [
+ {
+ "name": "x64-Debug (default)",
+ "generator": "Ninja",
+ "configurationType": "Debug",
+ "inheritEnvironments": [ "msvc_x64_x64" ],
+ "buildRoot": "${projectDir}\\out\\build\\${name}",
+ "installRoot": "${projectDir}\\out\\install\\${name}",
+ "cmakeCommandArgs": "",
+ "buildCommandArgs": "",
+ "ctestCommandArgs": "",
+ "variables": [
+ {
+ "name":"VCPKG_MANIFEST_MODE",
+ "value":"OFF"
+ }
+ ]
+ }
+ ]
+}
diff --git a/src/arrow/cpp/README.md b/src/arrow/cpp/README.md
new file mode 100644
index 000000000..b083f3fe7
--- /dev/null
+++ b/src/arrow/cpp/README.md
@@ -0,0 +1,34 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# Apache Arrow C++
+
+This directory contains the code and build system for the Arrow C++ libraries,
+as well as for the C++ libraries for Apache Parquet.
+
+## Installation
+
+See https://arrow.apache.org/install/ for the latest instructions how
+to install pre-compiled binary versions of the library.
+
+## Source Builds and Development
+
+Please refer to our latest [C++ Development Documentation][1].
+
+[1]: https://github.com/apache/arrow/blob/master/docs/source/developers/cpp
diff --git a/src/arrow/cpp/apidoc/.gitignore b/src/arrow/cpp/apidoc/.gitignore
new file mode 100644
index 000000000..5ccff1a6b
--- /dev/null
+++ b/src/arrow/cpp/apidoc/.gitignore
@@ -0,0 +1 @@
+html/
diff --git a/src/arrow/cpp/apidoc/Doxyfile b/src/arrow/cpp/apidoc/Doxyfile
new file mode 100644
index 000000000..8978dba53
--- /dev/null
+++ b/src/arrow/cpp/apidoc/Doxyfile
@@ -0,0 +1,2551 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Doxyfile 1.8.18
+
+# This file describes the settings to be used by the documentation system
+# doxygen (www.doxygen.org) for a project.
+#
+# All text after a double hash (##) is considered a comment and is placed in
+# front of the TAG it is preceding.
+#
+# All text after a single hash (#) is considered a comment and will be ignored.
+# The format is:
+# TAG = value [value, ...]
+# For lists, items can also be appended using:
+# TAG += value [value, ...]
+# Values that contain spaces should be placed between quotes (\" \").
+
+#---------------------------------------------------------------------------
+# Project related configuration options
+#---------------------------------------------------------------------------
+
+# This tag specifies the encoding used for all characters in the configuration
+# file that follow. The default is UTF-8 which is also the encoding used for all
+# text before the first occurrence of this tag. Doxygen uses libiconv (or the
+# iconv built into libc) for the transcoding. See
+# https://www.gnu.org/software/libiconv/ for the list of possible encodings.
+# The default value is: UTF-8.
+
+DOXYFILE_ENCODING = UTF-8
+
+# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by
+# double-quotes, unless you are using Doxywizard) that should identify the
+# project for which the documentation is generated. This name is used in the
+# title of most generated pages and in a few other places.
+# The default value is: My Project.
+
+PROJECT_NAME = "Apache Arrow (C++)"
+
+# The PROJECT_NUMBER tag can be used to enter a project or revision number. This
+# could be handy for archiving the generated documentation or if some version
+# control system is used.
+
+PROJECT_NUMBER =
+
+# Using the PROJECT_BRIEF tag one can provide an optional one line description
+# for a project that appears at the top of each page and should give viewer a
+# quick idea about the purpose of the project. Keep the description short.
+
+PROJECT_BRIEF = "A columnar in-memory analytics layer designed to accelerate big data."
+
+# With the PROJECT_LOGO tag one can specify a logo or an icon that is included
+# in the documentation. The maximum height of the logo should not exceed 55
+# pixels and the maximum width should not exceed 200 pixels. Doxygen will copy
+# the logo to the output directory.
+
+PROJECT_LOGO =
+
+# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path
+# into which the generated documentation will be written. If a relative path is
+# entered, it will be relative to the location where doxygen was started. If
+# left blank the current directory will be used.
+
+OUTPUT_DIRECTORY = $(OUTPUT_DIRECTORY)
+
+# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub-
+# directories (in 2 levels) under the output directory of each output format and
+# will distribute the generated files over these directories. Enabling this
+# option can be useful when feeding doxygen a huge amount of source files, where
+# putting all generated files in the same directory would otherwise causes
+# performance problems for the file system.
+# The default value is: NO.
+
+CREATE_SUBDIRS = NO
+
+# If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII
+# characters to appear in the names of generated files. If set to NO, non-ASCII
+# characters will be escaped, for example _xE3_x81_x84 will be used for Unicode
+# U+3044.
+# The default value is: NO.
+
+ALLOW_UNICODE_NAMES = NO
+
+# The OUTPUT_LANGUAGE tag is used to specify the language in which all
+# documentation generated by doxygen is written. Doxygen will use this
+# information to generate all constant output in the proper language.
+# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Catalan, Chinese,
+# Chinese-Traditional, Croatian, Czech, Danish, Dutch, English (United States),
+# Esperanto, Farsi (Persian), Finnish, French, German, Greek, Hungarian,
+# Indonesian, Italian, Japanese, Japanese-en (Japanese with English messages),
+# Korean, Korean-en (Korean with English messages), Latvian, Lithuanian,
+# Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, Romanian, Russian,
+# Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, Swedish, Turkish,
+# Ukrainian and Vietnamese.
+# The default value is: English.
+
+OUTPUT_LANGUAGE = English
+
+# The OUTPUT_TEXT_DIRECTION tag is used to specify the direction in which all
+# documentation generated by doxygen is written. Doxygen will use this
+# information to generate all generated output in the proper direction.
+# Possible values are: None, LTR, RTL and Context.
+# The default value is: None.
+
+OUTPUT_TEXT_DIRECTION = None
+
+# If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member
+# descriptions after the members that are listed in the file and class
+# documentation (similar to Javadoc). Set to NO to disable this.
+# The default value is: YES.
+
+BRIEF_MEMBER_DESC = YES
+
+# If the REPEAT_BRIEF tag is set to YES, doxygen will prepend the brief
+# description of a member or function before the detailed description
+#
+# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the
+# brief descriptions will be completely suppressed.
+# The default value is: YES.
+
+REPEAT_BRIEF = YES
+
+# This tag implements a quasi-intelligent brief description abbreviator that is
+# used to form the text in various listings. Each string in this list, if found
+# as the leading text of the brief description, will be stripped from the text
+# and the result, after processing the whole list, is used as the annotated
+# text. Otherwise, the brief description is used as-is. If left blank, the
+# following values are used ($name is automatically replaced with the name of
+# the entity):The $name class, The $name widget, The $name file, is, provides,
+# specifies, contains, represents, a, an and the.
+
+ABBREVIATE_BRIEF = "The $name class" \
+ "The $name widget" \
+ "The $name file" \
+ is \
+ provides \
+ specifies \
+ contains \
+ represents \
+ a \
+ an \
+ the
+
+# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then
+# doxygen will generate a detailed section even if there is only a brief
+# description.
+# The default value is: NO.
+
+ALWAYS_DETAILED_SEC = NO
+
+# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all
+# inherited members of a class in the documentation of that class as if those
+# members were ordinary class members. Constructors, destructors and assignment
+# operators of the base classes will not be shown.
+# The default value is: NO.
+
+INLINE_INHERITED_MEMB = NO
+
+# If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path
+# before files name in the file list and in the header files. If set to NO the
+# shortest path that makes the file name unique will be used
+# The default value is: YES.
+
+FULL_PATH_NAMES = YES
+
+# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path.
+# Stripping is only done if one of the specified strings matches the left-hand
+# part of the path. The tag can be used to show relative paths in the file list.
+# If left blank the directory from which doxygen is run is used as the path to
+# strip.
+#
+# Note that you can specify absolute paths here, but also relative paths, which
+# will be relative from the directory where doxygen is started.
+# This tag requires that the tag FULL_PATH_NAMES is set to YES.
+
+STRIP_FROM_PATH =
+
+# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the
+# path mentioned in the documentation of a class, which tells the reader which
+# header file to include in order to use a class. If left blank only the name of
+# the header file containing the class definition is used. Otherwise one should
+# specify the list of include paths that are normally passed to the compiler
+# using the -I flag.
+
+STRIP_FROM_INC_PATH = ../src
+
+# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but
+# less readable) file names. This can be useful is your file systems doesn't
+# support long names like on DOS, Mac, or CD-ROM.
+# The default value is: NO.
+
+SHORT_NAMES = NO
+
+# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the
+# first line (until the first dot) of a Javadoc-style comment as the brief
+# description. If set to NO, the Javadoc-style will behave just like regular Qt-
+# style comments (thus requiring an explicit @brief command for a brief
+# description.)
+# The default value is: NO.
+
+JAVADOC_AUTOBRIEF = YES
+
+# If the JAVADOC_BANNER tag is set to YES then doxygen will interpret a line
+# such as
+# /***************
+# as being the beginning of a Javadoc-style comment "banner". If set to NO, the
+# Javadoc-style will behave just like regular comments and it will not be
+# interpreted by doxygen.
+# The default value is: NO.
+
+JAVADOC_BANNER = NO
+
+# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first
+# line (until the first dot) of a Qt-style comment as the brief description. If
+# set to NO, the Qt-style will behave just like regular Qt-style comments (thus
+# requiring an explicit \brief command for a brief description.)
+# The default value is: NO.
+
+QT_AUTOBRIEF = NO
+
+# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a
+# multi-line C++ special comment block (i.e. a block of //! or /// comments) as
+# a brief description. This used to be the default behavior. The new default is
+# to treat a multi-line C++ comment block as a detailed description. Set this
+# tag to YES if you prefer the old behavior instead.
+#
+# Note that setting this tag to YES also means that rational rose comments are
+# not recognized any more.
+# The default value is: NO.
+
+MULTILINE_CPP_IS_BRIEF = NO
+
+# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the
+# documentation from any documented member that it re-implements.
+# The default value is: YES.
+
+INHERIT_DOCS = YES
+
+# If the SEPARATE_MEMBER_PAGES tag is set to YES then doxygen will produce a new
+# page for each member. If set to NO, the documentation of a member will be part
+# of the file/class/namespace that contains it.
+# The default value is: NO.
+
+SEPARATE_MEMBER_PAGES = NO
+
+# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen
+# uses this value to replace tabs by spaces in code fragments.
+# Minimum value: 1, maximum value: 16, default value: 4.
+
+TAB_SIZE = 4
+
+# This tag can be used to specify a number of aliases that act as commands in
+# the documentation. An alias has the form:
+# name=value
+# For example adding
+# "sideeffect=@par Side Effects:\n"
+# will allow you to put the command \sideeffect (or @sideeffect) in the
+# documentation, which will result in a user-defined paragraph with heading
+# "Side Effects:". You can put \n's in the value part of an alias to insert
+# newlines (in the resulting output). You can put ^^ in the value part of an
+# alias to insert a newline as if a physical newline was in the original file.
+# When you need a literal { or } or , in the value part of an alias you have to
+# escape them by means of a backslash (\), this can lead to conflicts with the
+# commands \{ and \} for these it is advised to use the version @{ and @} or use
+# a double escape (\\{ and \\})
+
+ALIASES =
+
+# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources
+# only. Doxygen will then generate output that is more tailored for C. For
+# instance, some of the names that are used will be different. The list of all
+# members will be omitted, etc.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_FOR_C = NO
+
+# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or
+# Python sources only. Doxygen will then generate output that is more tailored
+# for that language. For instance, namespaces will be presented as packages,
+# qualified scopes will look different, etc.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_JAVA = NO
+
+# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran
+# sources. Doxygen will then generate output that is tailored for Fortran.
+# The default value is: NO.
+
+OPTIMIZE_FOR_FORTRAN = NO
+
+# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL
+# sources. Doxygen will then generate output that is tailored for VHDL.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_VHDL = NO
+
+# Set the OPTIMIZE_OUTPUT_SLICE tag to YES if your project consists of Slice
+# sources only. Doxygen will then generate output that is more tailored for that
+# language. For instance, namespaces will be presented as modules, types will be
+# separated into more groups, etc.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_SLICE = NO
+
+# Doxygen selects the parser to use depending on the extension of the files it
+# parses. With this tag you can assign which parser to use for a given
+# extension. Doxygen has a built-in mapping, but you can override or extend it
+# using this tag. The format is ext=language, where ext is a file extension, and
+# language is one of the parsers supported by doxygen: IDL, Java, JavaScript,
+# Csharp (C#), C, C++, D, PHP, md (Markdown), Objective-C, Python, Slice, VHDL,
+# Fortran (fixed format Fortran: FortranFixed, free formatted Fortran:
+# FortranFree, unknown formatted Fortran: Fortran. In the later case the parser
+# tries to guess whether the code is fixed or free formatted code, this is the
+# default for Fortran type files). For instance to make doxygen treat .inc files
+# as Fortran files (default is PHP), and .f files as C (default is Fortran),
+# use: inc=Fortran f=C.
+#
+# Note: For files without extension you can use no_extension as a placeholder.
+#
+# Note that for custom extensions you also need to set FILE_PATTERNS otherwise
+# the files are not read by doxygen.
+
+EXTENSION_MAPPING =
+
+# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments
+# according to the Markdown format, which allows for more readable
+# documentation. See https://daringfireball.net/projects/markdown/ for details.
+# The output of markdown processing is further processed by doxygen, so you can
+# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in
+# case of backward compatibilities issues.
+# The default value is: YES.
+
+MARKDOWN_SUPPORT = YES
+
+# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up
+# to that level are automatically included in the table of contents, even if
+# they do not have an id attribute.
+# Note: This feature currently applies only to Markdown headings.
+# Minimum value: 0, maximum value: 99, default value: 5.
+# This tag requires that the tag MARKDOWN_SUPPORT is set to YES.
+
+TOC_INCLUDE_HEADINGS = 0
+
+# When enabled doxygen tries to link words that correspond to documented
+# classes, or namespaces to their corresponding documentation. Such a link can
+# be prevented in individual cases by putting a % sign in front of the word or
+# globally by setting AUTOLINK_SUPPORT to NO.
+# The default value is: YES.
+
+AUTOLINK_SUPPORT = YES
+
+# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want
+# to include (a tag file for) the STL sources as input, then you should set this
+# tag to YES in order to let doxygen match functions declarations and
+# definitions whose arguments contain STL classes (e.g. func(std::string);
+# versus func(std::string) {}). This also make the inheritance and collaboration
+# diagrams that involve STL classes more complete and accurate.
+# The default value is: NO.
+
+BUILTIN_STL_SUPPORT = NO
+
+# If you use Microsoft's C++/CLI language, you should set this option to YES to
+# enable parsing support.
+# The default value is: NO.
+
+CPP_CLI_SUPPORT = NO
+
+# Set the SIP_SUPPORT tag to YES if your project consists of sip (see:
+# https://www.riverbankcomputing.com/software/sip/intro) sources only. Doxygen
+# will parse them like normal C++ but will assume all classes use public instead
+# of private inheritance when no explicit protection keyword is present.
+# The default value is: NO.
+
+SIP_SUPPORT = NO
+
+# For Microsoft's IDL there are propget and propput attributes to indicate
+# getter and setter methods for a property. Setting this option to YES will make
+# doxygen to replace the get and set methods by a property in the documentation.
+# This will only work if the methods are indeed getting or setting a simple
+# type. If this is not the case, or you want to show the methods anyway, you
+# should set this option to NO.
+# The default value is: YES.
+
+IDL_PROPERTY_SUPPORT = YES
+
+# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC
+# tag is set to YES then doxygen will reuse the documentation of the first
+# member in the group (if any) for the other members of the group. By default
+# all members of a group must be documented explicitly.
+# The default value is: NO.
+
+DISTRIBUTE_GROUP_DOC = NO
+
+# If one adds a struct or class to a group and this option is enabled, then also
+# any nested class or struct is added to the same group. By default this option
+# is disabled and one has to add nested compounds explicitly via \ingroup.
+# The default value is: NO.
+
+GROUP_NESTED_COMPOUNDS = NO
+
+# Set the SUBGROUPING tag to YES to allow class member groups of the same type
+# (for instance a group of public functions) to be put as a subgroup of that
+# type (e.g. under the Public Functions section). Set it to NO to prevent
+# subgrouping. Alternatively, this can be done per class using the
+# \nosubgrouping command.
+# The default value is: YES.
+
+SUBGROUPING = YES
+
+# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions
+# are shown inside the group in which they are included (e.g. using \ingroup)
+# instead of on a separate page (for HTML and Man pages) or section (for LaTeX
+# and RTF).
+#
+# Note that this feature does not work in combination with
+# SEPARATE_MEMBER_PAGES.
+# The default value is: NO.
+
+INLINE_GROUPED_CLASSES = NO
+
+# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions
+# with only public data fields or simple typedef fields will be shown inline in
+# the documentation of the scope in which they are defined (i.e. file,
+# namespace, or group documentation), provided this scope is documented. If set
+# to NO, structs, classes, and unions are shown on a separate page (for HTML and
+# Man pages) or section (for LaTeX and RTF).
+# The default value is: NO.
+
+INLINE_SIMPLE_STRUCTS = NO
+
+# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or
+# enum is documented as struct, union, or enum with the name of the typedef. So
+# typedef struct TypeS {} TypeT, will appear in the documentation as a struct
+# with name TypeT. When disabled the typedef will appear as a member of a file,
+# namespace, or class. And the struct will be named TypeS. This can typically be
+# useful for C code in case the coding convention dictates that all compound
+# types are typedef'ed and only the typedef is referenced, never the tag name.
+# The default value is: NO.
+
+TYPEDEF_HIDES_STRUCT = NO
+
+# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This
+# cache is used to resolve symbols given their name and scope. Since this can be
+# an expensive process and often the same symbol appears multiple times in the
+# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small
+# doxygen will become slower. If the cache is too large, memory is wasted. The
+# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range
+# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536
+# symbols. At the end of a run doxygen will report the cache usage and suggest
+# the optimal cache size from a speed point of view.
+# Minimum value: 0, maximum value: 9, default value: 0.
+
+LOOKUP_CACHE_SIZE = 0
+
+#---------------------------------------------------------------------------
+# Build related configuration options
+#---------------------------------------------------------------------------
+
+# If the EXTRACT_ALL tag is set to YES, doxygen will assume all entities in
+# documentation are documented, even if no documentation was available. Private
+# class members and static file members will be hidden unless the
+# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES.
+# Note: This will also disable the warnings about undocumented members that are
+# normally produced when WARNINGS is set to YES.
+# The default value is: NO.
+
+EXTRACT_ALL = YES
+
+# If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will
+# be included in the documentation.
+# The default value is: NO.
+
+EXTRACT_PRIVATE = NO
+
+# If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual
+# methods of a class will be included in the documentation.
+# The default value is: NO.
+
+EXTRACT_PRIV_VIRTUAL = NO
+
+# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal
+# scope will be included in the documentation.
+# The default value is: NO.
+
+EXTRACT_PACKAGE = NO
+
+# If the EXTRACT_STATIC tag is set to YES, all static members of a file will be
+# included in the documentation.
+# The default value is: NO.
+
+EXTRACT_STATIC = NO
+
+# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined
+# locally in source files will be included in the documentation. If set to NO,
+# only classes defined in header files are included. Does not have any effect
+# for Java sources.
+# The default value is: YES.
+
+EXTRACT_LOCAL_CLASSES = YES
+
+# This flag is only useful for Objective-C code. If set to YES, local methods,
+# which are defined in the implementation section but not in the interface are
+# included in the documentation. If set to NO, only methods in the interface are
+# included.
+# The default value is: NO.
+
+EXTRACT_LOCAL_METHODS = NO
+
+# If this flag is set to YES, the members of anonymous namespaces will be
+# extracted and appear in the documentation as a namespace called
+# 'anonymous_namespace{file}', where file will be replaced with the base name of
+# the file that contains the anonymous namespace. By default anonymous namespace
+# are hidden.
+# The default value is: NO.
+
+EXTRACT_ANON_NSPACES = NO
+
+# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all
+# undocumented members inside documented classes or files. If set to NO these
+# members will be included in the various overviews, but no documentation
+# section is generated. This option has no effect if EXTRACT_ALL is enabled.
+# The default value is: NO.
+
+HIDE_UNDOC_MEMBERS = NO
+
+# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all
+# undocumented classes that are normally visible in the class hierarchy. If set
+# to NO, these classes will be included in the various overviews. This option
+# has no effect if EXTRACT_ALL is enabled.
+# The default value is: NO.
+
+HIDE_UNDOC_CLASSES = NO
+
+# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend
+# declarations. If set to NO, these declarations will be included in the
+# documentation.
+# The default value is: NO.
+
+HIDE_FRIEND_COMPOUNDS = YES
+
+# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any
+# documentation blocks found inside the body of a function. If set to NO, these
+# blocks will be appended to the function's detailed documentation block.
+# The default value is: NO.
+
+HIDE_IN_BODY_DOCS = NO
+
+# The INTERNAL_DOCS tag determines if documentation that is typed after a
+# \internal command is included. If the tag is set to NO then the documentation
+# will be excluded. Set it to YES to include the internal documentation.
+# The default value is: NO.
+
+INTERNAL_DOCS = NO
+
+# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file
+# names in lower-case letters. If set to YES, upper-case letters are also
+# allowed. This is useful if you have classes or files whose names only differ
+# in case and if your file system supports case sensitive file names. Windows
+# (including Cygwin) ands Mac users are advised to set this option to NO.
+# The default value is: system dependent.
+
+CASE_SENSE_NAMES = NO
+
+# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with
+# their full class and namespace scopes in the documentation. If set to YES, the
+# scope will be hidden.
+# The default value is: NO.
+
+HIDE_SCOPE_NAMES = NO
+
+# If the HIDE_COMPOUND_REFERENCE tag is set to NO (default) then doxygen will
+# append additional text to a page's title, such as Class Reference. If set to
+# YES the compound reference will be hidden.
+# The default value is: NO.
+
+HIDE_COMPOUND_REFERENCE= NO
+
+# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of
+# the files that are included by a file in the documentation of that file.
+# The default value is: YES.
+
+SHOW_INCLUDE_FILES = YES
+
+# If the SHOW_GROUPED_MEMB_INC tag is set to YES then Doxygen will add for each
+# grouped member an include statement to the documentation, telling the reader
+# which file to include in order to use the member.
+# The default value is: NO.
+
+SHOW_GROUPED_MEMB_INC = NO
+
+# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include
+# files with double quotes in the documentation rather than with sharp brackets.
+# The default value is: NO.
+
+FORCE_LOCAL_INCLUDES = NO
+
+# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the
+# documentation for inline members.
+# The default value is: YES.
+
+INLINE_INFO = YES
+
+# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the
+# (detailed) documentation of file and class members alphabetically by member
+# name. If set to NO, the members will appear in declaration order.
+# The default value is: YES.
+
+SORT_MEMBER_DOCS = YES
+
+# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief
+# descriptions of file, namespace and class members alphabetically by member
+# name. If set to NO, the members will appear in declaration order. Note that
+# this will also influence the order of the classes in the class list.
+# The default value is: NO.
+
+SORT_BRIEF_DOCS = NO
+
+# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the
+# (brief and detailed) documentation of class members so that constructors and
+# destructors are listed first. If set to NO the constructors will appear in the
+# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS.
+# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief
+# member documentation.
+# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting
+# detailed member documentation.
+# The default value is: NO.
+
+SORT_MEMBERS_CTORS_1ST = NO
+
+# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy
+# of group names into alphabetical order. If set to NO the group names will
+# appear in their defined order.
+# The default value is: NO.
+
+SORT_GROUP_NAMES = NO
+
+# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by
+# fully-qualified names, including namespaces. If set to NO, the class list will
+# be sorted only by class name, not including the namespace part.
+# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES.
+# Note: This option applies only to the class list, not to the alphabetical
+# list.
+# The default value is: NO.
+
+SORT_BY_SCOPE_NAME = NO
+
+# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper
+# type resolution of all parameters of a function it will reject a match between
+# the prototype and the implementation of a member function even if there is
+# only one candidate or it is obvious which candidate to choose by doing a
+# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still
+# accept a match between prototype and implementation in such cases.
+# The default value is: NO.
+
+STRICT_PROTO_MATCHING = NO
+
+# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the todo
+# list. This list is created by putting \todo commands in the documentation.
+# The default value is: YES.
+
+GENERATE_TODOLIST = YES
+
+# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test
+# list. This list is created by putting \test commands in the documentation.
+# The default value is: YES.
+
+GENERATE_TESTLIST = YES
+
+# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug
+# list. This list is created by putting \bug commands in the documentation.
+# The default value is: YES.
+
+GENERATE_BUGLIST = YES
+
+# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO)
+# the deprecated list. This list is created by putting \deprecated commands in
+# the documentation.
+# The default value is: YES.
+
+GENERATE_DEPRECATEDLIST= YES
+
+# The ENABLED_SECTIONS tag can be used to enable conditional documentation
+# sections, marked by \if <section_label> ... \endif and \cond <section_label>
+# ... \endcond blocks.
+
+ENABLED_SECTIONS =
+
+# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the
+# initial value of a variable or macro / define can have for it to appear in the
+# documentation. If the initializer consists of more lines than specified here
+# it will be hidden. Use a value of 0 to hide initializers completely. The
+# appearance of the value of individual variables and macros / defines can be
+# controlled using \showinitializer or \hideinitializer command in the
+# documentation regardless of this setting.
+# Minimum value: 0, maximum value: 10000, default value: 30.
+
+MAX_INITIALIZER_LINES = 30
+
+# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at
+# the bottom of the documentation of classes and structs. If set to YES, the
+# list will mention the files that were used to generate the documentation.
+# The default value is: YES.
+
+SHOW_USED_FILES = YES
+
+# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This
+# will remove the Files entry from the Quick Index and from the Folder Tree View
+# (if specified).
+# The default value is: YES.
+
+SHOW_FILES = YES
+
+# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces
+# page. This will remove the Namespaces entry from the Quick Index and from the
+# Folder Tree View (if specified).
+# The default value is: YES.
+
+SHOW_NAMESPACES = YES
+
+# The FILE_VERSION_FILTER tag can be used to specify a program or script that
+# doxygen should invoke to get the current version for each file (typically from
+# the version control system). Doxygen will invoke the program by executing (via
+# popen()) the command command input-file, where command is the value of the
+# FILE_VERSION_FILTER tag, and input-file is the name of an input file provided
+# by doxygen. Whatever the program writes to standard output is used as the file
+# version. For an example see the documentation.
+
+FILE_VERSION_FILTER =
+
+# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed
+# by doxygen. The layout file controls the global structure of the generated
+# output files in an output format independent way. To create the layout file
+# that represents doxygen's defaults, run doxygen with the -l option. You can
+# optionally specify a file name after the option, if omitted DoxygenLayout.xml
+# will be used as the name of the layout file.
+#
+# Note that if you run doxygen from a directory containing a file called
+# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE
+# tag is left empty.
+
+LAYOUT_FILE =
+
+# The CITE_BIB_FILES tag can be used to specify one or more bib files containing
+# the reference definitions. This must be a list of .bib files. The .bib
+# extension is automatically appended if omitted. This requires the bibtex tool
+# to be installed. See also https://en.wikipedia.org/wiki/BibTeX for more info.
+# For LaTeX the style of the bibliography can be controlled using
+# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the
+# search path. See also \cite for info how to create references.
+
+CITE_BIB_FILES =
+
+#---------------------------------------------------------------------------
+# Configuration options related to warning and progress messages
+#---------------------------------------------------------------------------
+
+# The QUIET tag can be used to turn on/off the messages that are generated to
+# standard output by doxygen. If QUIET is set to YES this implies that the
+# messages are off.
+# The default value is: NO.
+
+QUIET = YES
+
+# The WARNINGS tag can be used to turn on/off the warning messages that are
+# generated to standard error (stderr) by doxygen. If WARNINGS is set to YES
+# this implies that the warnings are on.
+#
+# Tip: Turn warnings on while writing the documentation.
+# The default value is: YES.
+
+WARNINGS = YES
+
+# If the WARN_IF_UNDOCUMENTED tag is set to YES then doxygen will generate
+# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag
+# will automatically be disabled.
+# The default value is: YES.
+
+WARN_IF_UNDOCUMENTED = YES
+
+# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for
+# potential errors in the documentation, such as not documenting some parameters
+# in a documented function, or documenting parameters that don't exist or using
+# markup commands wrongly.
+# The default value is: YES.
+
+WARN_IF_DOC_ERROR = YES
+
+# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that
+# are documented, but have no documentation for their parameters or return
+# value. If set to NO, doxygen will only warn about wrong or incomplete
+# parameter documentation, but not about the absence of documentation. If
+# EXTRACT_ALL is set to YES then this flag will automatically be disabled.
+# The default value is: NO.
+
+WARN_NO_PARAMDOC = NO
+
+# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when
+# a warning is encountered.
+# The default value is: NO.
+
+WARN_AS_ERROR = YES
+
+# The WARN_FORMAT tag determines the format of the warning messages that doxygen
+# can produce. The string should contain the $file, $line, and $text tags, which
+# will be replaced by the file and line number from which the warning originated
+# and the warning text. Optionally the format may contain $version, which will
+# be replaced by the version of the file (if it could be obtained via
+# FILE_VERSION_FILTER)
+# The default value is: $file:$line: $text.
+
+WARN_FORMAT = "$file:$line: $text"
+
+# The WARN_LOGFILE tag can be used to specify a file to which warning and error
+# messages should be written. If left blank the output is written to standard
+# error (stderr).
+
+WARN_LOGFILE =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the input files
+#---------------------------------------------------------------------------
+
+# The INPUT tag is used to specify the files and/or directories that contain
+# documented source files. You may enter file names like myfile.cpp or
+# directories like /usr/src/myproject. Separate the files or directories with
+# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING
+# Note: If this tag is empty the current directory is searched.
+
+INPUT = ../src \
+ .
+
+# This tag can be used to specify the character encoding of the source files
+# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
+# libiconv (or the iconv built into libc) for the transcoding. See the libiconv
+# documentation (see: https://www.gnu.org/software/libiconv/) for the list of
+# possible encodings.
+# The default value is: UTF-8.
+
+INPUT_ENCODING = UTF-8
+
+# If the value of the INPUT tag contains directories, you can use the
+# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and
+# *.h) to filter out the source-files in the directories.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# read by doxygen.
+#
+# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp,
+# *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h,
+# *.hh, *.hxx, *.hpp, *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc,
+# *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C comment),
+# *.doc (to be provided as doxygen C comment), *.txt (to be provided as doxygen
+# C comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f18, *.f, *.for, *.vhd,
+# *.vhdl, *.ucf, *.qsf and *.ice.
+
+FILE_PATTERNS = *.h \
+ *.hh \
+ *.hxx \
+ *.hpp \
+ *.inc \
+ *.m \
+ *.markdown \
+ *.md \
+ *.mm \
+ *.dox \
+ *.py
+
+# The RECURSIVE tag can be used to specify whether or not subdirectories should
+# be searched for input files as well.
+# The default value is: NO.
+
+RECURSIVE = YES
+
+# The EXCLUDE tag can be used to specify files and/or directories that should be
+# excluded from the INPUT source files. This way you can easily exclude a
+# subdirectory from a directory tree whose root is specified with the INPUT tag.
+#
+# Note that relative paths are relative to the directory from which doxygen is
+# run.
+
+EXCLUDE = ../src/arrow/vendored \
+ ../src/generated
+
+# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or
+# directories that are symbolic links (a Unix file system feature) are excluded
+# from the input.
+# The default value is: NO.
+
+EXCLUDE_SYMLINKS = NO
+
+# If the value of the INPUT tag contains directories, you can use the
+# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude
+# certain files from those directories.
+#
+# Note that the wildcards are matched against the file with absolute path, so to
+# exclude all test directories for example use the pattern */test/*
+
+EXCLUDE_PATTERNS = *-test.cc \
+ *test* \
+ *_generated.h \
+ *-benchmark.cc \
+ *internal*
+
+# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names
+# (namespaces, classes, functions, etc.) that should be excluded from the
+# output. The symbol name can be a fully qualified name, a word, or if the
+# wildcard * is used, a substring. Examples: ANamespace, AClass,
+# AClass::ANamespace, ANamespace::*Test
+#
+# Note that the wildcards are matched against the file with absolute path, so to
+# exclude all test directories use the pattern */test/*
+
+EXCLUDE_SYMBOLS = *::detail \
+ *::internal \
+ _* \
+ BitUtil \
+ SSEUtil
+
+# The EXAMPLE_PATH tag can be used to specify one or more files or directories
+# that contain example code fragments that are included (see the \include
+# command).
+
+EXAMPLE_PATH =
+
+# If the value of the EXAMPLE_PATH tag contains directories, you can use the
+# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and
+# *.h) to filter out the source-files in the directories. If left blank all
+# files are included.
+
+EXAMPLE_PATTERNS = *
+
+# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be
+# searched for input files to be used with the \include or \dontinclude commands
+# irrespective of the value of the RECURSIVE tag.
+# The default value is: NO.
+
+EXAMPLE_RECURSIVE = NO
+
+# The IMAGE_PATH tag can be used to specify one or more files or directories
+# that contain images that are to be included in the documentation (see the
+# \image command).
+
+IMAGE_PATH =
+
+# The INPUT_FILTER tag can be used to specify a program that doxygen should
+# invoke to filter for each input file. Doxygen will invoke the filter program
+# by executing (via popen()) the command:
+#
+# <filter> <input-file>
+#
+# where <filter> is the value of the INPUT_FILTER tag, and <input-file> is the
+# name of an input file. Doxygen will then use the output that the filter
+# program writes to standard output. If FILTER_PATTERNS is specified, this tag
+# will be ignored.
+#
+# Note that the filter must not add or remove lines; it is applied before the
+# code is scanned, but not when the output code is generated. If lines are added
+# or removed, the anchors will not be placed correctly.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# properly processed by doxygen.
+
+INPUT_FILTER =
+
+# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern
+# basis. Doxygen will compare the file name with each pattern and apply the
+# filter if there is a match. The filters are a list of the form: pattern=filter
+# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how
+# filters are used. If the FILTER_PATTERNS tag is empty or if none of the
+# patterns match the file name, INPUT_FILTER is applied.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# properly processed by doxygen.
+
+FILTER_PATTERNS =
+
+# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using
+# INPUT_FILTER) will also be used to filter the input files that are used for
+# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES).
+# The default value is: NO.
+
+FILTER_SOURCE_FILES = NO
+
+# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file
+# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and
+# it is also possible to disable source filtering for a specific pattern using
+# *.ext= (so without naming a filter).
+# This tag requires that the tag FILTER_SOURCE_FILES is set to YES.
+
+FILTER_SOURCE_PATTERNS =
+
+# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that
+# is part of the input, its contents will be placed on the main page
+# (index.html). This can be useful if you have a project on for instance GitHub
+# and want to reuse the introduction page also for the doxygen output.
+
+USE_MDFILE_AS_MAINPAGE =
+
+#---------------------------------------------------------------------------
+# Configuration options related to source browsing
+#---------------------------------------------------------------------------
+
+# If the SOURCE_BROWSER tag is set to YES then a list of source files will be
+# generated. Documented entities will be cross-referenced with these sources.
+#
+# Note: To get rid of all source code in the generated output, make sure that
+# also VERBATIM_HEADERS is set to NO.
+# The default value is: NO.
+
+SOURCE_BROWSER = NO
+
+# Setting the INLINE_SOURCES tag to YES will include the body of functions,
+# classes and enums directly into the documentation.
+# The default value is: NO.
+
+INLINE_SOURCES = NO
+
+# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any
+# special comment blocks from generated source code fragments. Normal C, C++ and
+# Fortran comments will always remain visible.
+# The default value is: YES.
+
+STRIP_CODE_COMMENTS = YES
+
+# If the REFERENCED_BY_RELATION tag is set to YES then for each documented
+# entity all documented functions referencing it will be listed.
+# The default value is: NO.
+
+REFERENCED_BY_RELATION = NO
+
+# If the REFERENCES_RELATION tag is set to YES then for each documented function
+# all documented entities called/used by that function will be listed.
+# The default value is: NO.
+
+REFERENCES_RELATION = NO
+
+# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set
+# to YES then the hyperlinks from functions in REFERENCES_RELATION and
+# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will
+# link to the documentation.
+# The default value is: YES.
+
+REFERENCES_LINK_SOURCE = YES
+
+# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the
+# source code will show a tooltip with additional information such as prototype,
+# brief description and links to the definition and documentation. Since this
+# will make the HTML file larger and loading of large files a bit slower, you
+# can opt to disable this feature.
+# The default value is: YES.
+# This tag requires that the tag SOURCE_BROWSER is set to YES.
+
+SOURCE_TOOLTIPS = YES
+
+# If the USE_HTAGS tag is set to YES then the references to source code will
+# point to the HTML generated by the htags(1) tool instead of doxygen built-in
+# source browser. The htags tool is part of GNU's global source tagging system
+# (see https://www.gnu.org/software/global/global.html). You will need version
+# 4.8.6 or higher.
+#
+# To use it do the following:
+# - Install the latest version of global
+# - Enable SOURCE_BROWSER and USE_HTAGS in the configuration file
+# - Make sure the INPUT points to the root of the source tree
+# - Run doxygen as normal
+#
+# Doxygen will invoke htags (and that will in turn invoke gtags), so these
+# tools must be available from the command line (i.e. in the search path).
+#
+# The result: instead of the source browser generated by doxygen, the links to
+# source code will now point to the output of htags.
+# The default value is: NO.
+# This tag requires that the tag SOURCE_BROWSER is set to YES.
+
+USE_HTAGS = NO
+
+# If the VERBATIM_HEADERS tag is set the YES then doxygen will generate a
+# verbatim copy of the header file for each class for which an include is
+# specified. Set to NO to disable this.
+# See also: Section \class.
+# The default value is: YES.
+
+VERBATIM_HEADERS = YES
+
+#---------------------------------------------------------------------------
+# Configuration options related to the alphabetical class index
+#---------------------------------------------------------------------------
+
+# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all
+# compounds will be generated. Enable this if the project contains a lot of
+# classes, structs, unions or interfaces.
+# The default value is: YES.
+
+ALPHABETICAL_INDEX = YES
+
+# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in
+# which the alphabetical index list will be split.
+# Minimum value: 1, maximum value: 20, default value: 5.
+# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
+
+COLS_IN_ALPHA_INDEX = 5
+
+# In case all classes in a project start with a common prefix, all classes will
+# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag
+# can be used to specify a prefix (or a list of prefixes) that should be ignored
+# while generating the index headers.
+# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
+
+IGNORE_PREFIX =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the HTML output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output
+# The default value is: YES.
+
+GENERATE_HTML = YES
+
+# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it.
+# The default directory is: html.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_OUTPUT = html
+
+# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each
+# generated HTML page (for example: .htm, .php, .asp).
+# The default value is: .html.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_FILE_EXTENSION = .html
+
+# The HTML_HEADER tag can be used to specify a user-defined HTML header file for
+# each generated HTML page. If the tag is left blank doxygen will generate a
+# standard header.
+#
+# To get valid HTML the header file that includes any scripts and style sheets
+# that doxygen needs, which is dependent on the configuration options used (e.g.
+# the setting GENERATE_TREEVIEW). It is highly recommended to start with a
+# default header using
+# doxygen -w html new_header.html new_footer.html new_stylesheet.css
+# YourConfigFile
+# and then modify the file new_header.html. See also section "Doxygen usage"
+# for information on how to generate the default header that doxygen normally
+# uses.
+# Note: The header is subject to change so you typically have to regenerate the
+# default header when upgrading to a newer version of doxygen. For a description
+# of the possible markers and block names see the documentation.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_HEADER =
+
+# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each
+# generated HTML page. If the tag is left blank doxygen will generate a standard
+# footer. See HTML_HEADER for more information on how to generate a default
+# footer and what special commands can be used inside the footer. See also
+# section "Doxygen usage" for information on how to generate the default footer
+# that doxygen normally uses.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_FOOTER = footer.html
+
+# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style
+# sheet that is used by each HTML page. It can be used to fine-tune the look of
+# the HTML output. If left blank doxygen will generate a default style sheet.
+# See also section "Doxygen usage" for information on how to generate the style
+# sheet that doxygen normally uses.
+# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as
+# it is more robust and this tag (HTML_STYLESHEET) will in the future become
+# obsolete.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_STYLESHEET =
+
+# The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined
+# cascading style sheets that are included after the standard style sheets
+# created by doxygen. Using this option one can overrule certain style aspects.
+# This is preferred over using HTML_STYLESHEET since it does not replace the
+# standard style sheet and is therefore more robust against future updates.
+# Doxygen will copy the style sheet files to the output directory.
+# Note: The order of the extra style sheet files is of importance (e.g. the last
+# style sheet in the list overrules the setting of the previous ones in the
+# list). For an example see the documentation.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_EXTRA_STYLESHEET =
+
+# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or
+# other source files which should be copied to the HTML output directory. Note
+# that these files will be copied to the base HTML output directory. Use the
+# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these
+# files. In the HTML_STYLESHEET file, use the file name only. Also note that the
+# files will be copied as-is; there are no commands or markers available.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_EXTRA_FILES =
+
+# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen
+# will adjust the colors in the style sheet and background images according to
+# this color. Hue is specified as an angle on a colorwheel, see
+# https://en.wikipedia.org/wiki/Hue for more information. For instance the value
+# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300
+# purple, and 360 is red again.
+# Minimum value: 0, maximum value: 359, default value: 220.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_COLORSTYLE_HUE = 220
+
+# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors
+# in the HTML output. For a value of 0 the output will use grayscales only. A
+# value of 255 will produce the most vivid colors.
+# Minimum value: 0, maximum value: 255, default value: 100.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_COLORSTYLE_SAT = 100
+
+# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the
+# luminance component of the colors in the HTML output. Values below 100
+# gradually make the output lighter, whereas values above 100 make the output
+# darker. The value divided by 100 is the actual gamma applied, so 80 represents
+# a gamma of 0.8, The value 220 represents a gamma of 2.2, and 100 does not
+# change the gamma.
+# Minimum value: 40, maximum value: 240, default value: 80.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_COLORSTYLE_GAMMA = 80
+
+# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML
+# page will contain the date and time when the page was generated. Setting this
+# to YES can help to show when doxygen was last run and thus if the
+# documentation is up to date.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_TIMESTAMP = NO
+
+# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML
+# documentation will contain a main index with vertical navigation menus that
+# are dynamically created via JavaScript. If disabled, the navigation index will
+# consists of multiple levels of tabs that are statically embedded in every HTML
+# page. Disable this option to support browsers that do not have JavaScript,
+# like the Qt help browser.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_DYNAMIC_MENUS = YES
+
+# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML
+# documentation will contain sections that can be hidden and shown after the
+# page has loaded.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_DYNAMIC_SECTIONS = NO
+
+# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries
+# shown in the various tree structured indices initially; the user can expand
+# and collapse entries dynamically later on. Doxygen will expand the tree to
+# such a level that at most the specified number of entries are visible (unless
+# a fully collapsed tree already exceeds this amount). So setting the number of
+# entries 1 will produce a full collapsed tree by default. 0 is a special value
+# representing an infinite number of entries and will result in a full expanded
+# tree by default.
+# Minimum value: 0, maximum value: 9999, default value: 100.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_INDEX_NUM_ENTRIES = 100
+
+# If the GENERATE_DOCSET tag is set to YES, additional index files will be
+# generated that can be used as input for Apple's Xcode 3 integrated development
+# environment (see: https://developer.apple.com/xcode/), introduced with OSX
+# 10.5 (Leopard). To create a documentation set, doxygen will generate a
+# Makefile in the HTML output directory. Running make will produce the docset in
+# that directory and running make install will install the docset in
+# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at
+# startup. See https://developer.apple.com/library/archive/featuredarticles/Doxy
+# genXcode/_index.html for more information.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_DOCSET = NO
+
+# This tag determines the name of the docset feed. A documentation feed provides
+# an umbrella under which multiple documentation sets from a single provider
+# (such as a company or product suite) can be grouped.
+# The default value is: Doxygen generated docs.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_FEEDNAME = "Doxygen generated docs"
+
+# This tag specifies a string that should uniquely identify the documentation
+# set bundle. This should be a reverse domain-name style string, e.g.
+# com.mycompany.MyDocSet. Doxygen will append .docset to the name.
+# The default value is: org.doxygen.Project.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_BUNDLE_ID = org.doxygen.Project
+
+# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify
+# the documentation publisher. This should be a reverse domain-name style
+# string, e.g. com.mycompany.MyDocSet.documentation.
+# The default value is: org.doxygen.Publisher.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_PUBLISHER_ID = org.doxygen.Publisher
+
+# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher.
+# The default value is: Publisher.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_PUBLISHER_NAME = Publisher
+
+# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three
+# additional HTML index files: index.hhp, index.hhc, and index.hhk. The
+# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop
+# (see: https://www.microsoft.com/en-us/download/details.aspx?id=21138) on
+# Windows.
+#
+# The HTML Help Workshop contains a compiler that can convert all HTML output
+# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML
+# files are now used as the Windows 98 help format, and will replace the old
+# Windows help format (.hlp) on all Windows platforms in the future. Compressed
+# HTML files also contain an index, a table of contents, and you can search for
+# words in the documentation. The HTML workshop also contains a viewer for
+# compressed HTML files.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_HTMLHELP = NO
+
+# The CHM_FILE tag can be used to specify the file name of the resulting .chm
+# file. You can add a path in front of the file if the result should not be
+# written to the html output directory.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+CHM_FILE =
+
+# The HHC_LOCATION tag can be used to specify the location (absolute path
+# including file name) of the HTML help compiler (hhc.exe). If non-empty,
+# doxygen will try to run the HTML help compiler on the generated index.hhp.
+# The file has to be specified with full path.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+HHC_LOCATION =
+
+# The GENERATE_CHI flag controls if a separate .chi index file is generated
+# (YES) or that it should be included in the master .chm file (NO).
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+GENERATE_CHI = NO
+
+# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc)
+# and project file content.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+CHM_INDEX_ENCODING =
+
+# The BINARY_TOC flag controls whether a binary table of contents is generated
+# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it
+# enables the Previous and Next buttons.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+BINARY_TOC = NO
+
+# The TOC_EXPAND flag can be set to YES to add extra items for group members to
+# the table of contents of the HTML help documentation and to the tree view.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+TOC_EXPAND = NO
+
+# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and
+# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that
+# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help
+# (.qch) of the generated HTML documentation.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_QHP = NO
+
+# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify
+# the file name of the resulting .qch file. The path specified is relative to
+# the HTML output folder.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QCH_FILE =
+
+# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help
+# Project output. For more information please see Qt Help Project / Namespace
+# (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#namespace).
+# The default value is: org.doxygen.Project.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_NAMESPACE = org.doxygen.Project
+
+# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt
+# Help Project output. For more information please see Qt Help Project / Virtual
+# Folders (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#virtual-
+# folders).
+# The default value is: doc.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_VIRTUAL_FOLDER = doc
+
+# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom
+# filter to add. For more information please see Qt Help Project / Custom
+# Filters (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-
+# filters).
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_CUST_FILTER_NAME =
+
+# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the
+# custom filter to add. For more information please see Qt Help Project / Custom
+# Filters (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-
+# filters).
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_CUST_FILTER_ATTRS =
+
+# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this
+# project's filter section matches. Qt Help Project / Filter Attributes (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#filter-attributes).
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_SECT_FILTER_ATTRS =
+
+# The QHG_LOCATION tag can be used to specify the location of Qt's
+# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the
+# generated .qhp file.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHG_LOCATION =
+
+# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be
+# generated, together with the HTML files, they form an Eclipse help plugin. To
+# install this plugin and make it available under the help contents menu in
+# Eclipse, the contents of the directory containing the HTML and XML files needs
+# to be copied into the plugins directory of eclipse. The name of the directory
+# within the plugins directory should be the same as the ECLIPSE_DOC_ID value.
+# After copying Eclipse needs to be restarted before the help appears.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_ECLIPSEHELP = NO
+
+# A unique identifier for the Eclipse help plugin. When installing the plugin
+# the directory name containing the HTML and XML files should also have this
+# name. Each documentation set should have its own identifier.
+# The default value is: org.doxygen.Project.
+# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES.
+
+ECLIPSE_DOC_ID = org.doxygen.Project
+
+# If you want full control over the layout of the generated HTML pages it might
+# be necessary to disable the index and replace it with your own. The
+# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top
+# of each HTML page. A value of NO enables the index and the value YES disables
+# it. Since the tabs in the index contain the same information as the navigation
+# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+DISABLE_INDEX = NO
+
+# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index
+# structure should be generated to display hierarchical information. If the tag
+# value is set to YES, a side panel will be generated containing a tree-like
+# index structure (just like the one that is generated for HTML Help). For this
+# to work a browser that supports JavaScript, DHTML, CSS and frames is required
+# (i.e. any modern browser). Windows users are probably better off using the
+# HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can
+# further fine-tune the look of the index. As an example, the default style
+# sheet generated by doxygen has an example that shows how to put an image at
+# the root of the tree instead of the PROJECT_NAME. Since the tree basically has
+# the same information as the tab index, you could consider setting
+# DISABLE_INDEX to YES when enabling this option.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_TREEVIEW = NO
+
+# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that
+# doxygen will group on one line in the generated HTML documentation.
+#
+# Note that a value of 0 will completely suppress the enum values from appearing
+# in the overview section.
+# Minimum value: 0, maximum value: 20, default value: 4.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+ENUM_VALUES_PER_LINE = 4
+
+# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used
+# to set the initial width (in pixels) of the frame in which the tree is shown.
+# Minimum value: 0, maximum value: 1500, default value: 250.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+TREEVIEW_WIDTH = 250
+
+# If the EXT_LINKS_IN_WINDOW option is set to YES, doxygen will open links to
+# external symbols imported via tag files in a separate window.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+EXT_LINKS_IN_WINDOW = NO
+
+# If the HTML_FORMULA_FORMAT option is set to svg, doxygen will use the pdf2svg
+# tool (see https://github.com/dawbarton/pdf2svg) or inkscape (see
+# https://inkscape.org) to generate formulas as SVG images instead of PNGs for
+# the HTML output. These images will generally look nicer at scaled resolutions.
+# Possible values are: png The default and svg Looks nicer but requires the
+# pdf2svg tool.
+# The default value is: png.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_FORMULA_FORMAT = png
+
+# Use this tag to change the font size of LaTeX formulas included as images in
+# the HTML documentation. When you change the font size after a successful
+# doxygen run you need to manually remove any form_*.png images from the HTML
+# output directory to force them to be regenerated.
+# Minimum value: 8, maximum value: 50, default value: 10.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+FORMULA_FONTSIZE = 10
+
+# Use the FORMULA_TRANSPARENT tag to determine whether or not the images
+# generated for formulas are transparent PNGs. Transparent PNGs are not
+# supported properly for IE 6.0, but are supported on all modern browsers.
+#
+# Note that when changing this option you need to delete any form_*.png files in
+# the HTML output directory before the changes have effect.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+FORMULA_TRANSPARENT = YES
+
+# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands
+# to create new LaTeX commands to be used in formulas as building blocks. See
+# the section "Including formulas" for details.
+
+FORMULA_MACROFILE =
+
+# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see
+# https://www.mathjax.org) which uses client side JavaScript for the rendering
+# instead of using pre-rendered bitmaps. Use this if you do not have LaTeX
+# installed or if you want to formulas look prettier in the HTML output. When
+# enabled you may also need to install MathJax separately and configure the path
+# to it using the MATHJAX_RELPATH option.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+USE_MATHJAX = NO
+
+# When MathJax is enabled you can set the default output format to be used for
+# the MathJax output. See the MathJax site (see:
+# http://docs.mathjax.org/en/latest/output.html) for more details.
+# Possible values are: HTML-CSS (which is slower, but has the best
+# compatibility), NativeMML (i.e. MathML) and SVG.
+# The default value is: HTML-CSS.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_FORMAT = HTML-CSS
+
+# When MathJax is enabled you need to specify the location relative to the HTML
+# output directory using the MATHJAX_RELPATH option. The destination directory
+# should contain the MathJax.js script. For instance, if the mathjax directory
+# is located at the same level as the HTML output directory, then
+# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax
+# Content Delivery Network so you can quickly see the result without installing
+# MathJax. However, it is strongly recommended to install a local copy of
+# MathJax from https://www.mathjax.org before deployment.
+# The default value is: https://cdn.jsdelivr.net/npm/mathjax@2.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_RELPATH = http://cdn.mathjax.org/mathjax/latest
+
+# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax
+# extension names that should be enabled during MathJax rendering. For example
+# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_EXTENSIONS =
+
+# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces
+# of code that will be used on startup of the MathJax code. See the MathJax site
+# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an
+# example see the documentation.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_CODEFILE =
+
+# When the SEARCHENGINE tag is enabled doxygen will generate a search box for
+# the HTML output. The underlying search engine uses javascript and DHTML and
+# should work on any modern browser. Note that when using HTML help
+# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET)
+# there is already a search function so this one should typically be disabled.
+# For large projects the javascript based search engine can be slow, then
+# enabling SERVER_BASED_SEARCH may provide a better solution. It is possible to
+# search using the keyboard; to jump to the search box use <access key> + S
+# (what the <access key> is depends on the OS and browser, but it is typically
+# <CTRL>, <ALT>/<option>, or both). Inside the search box use the <cursor down
+# key> to jump into the search results window, the results can be navigated
+# using the <cursor keys>. Press <Enter> to select an item or <escape> to cancel
+# the search. The filter options can be selected when the cursor is inside the
+# search box by pressing <Shift>+<cursor down>. Also here use the <cursor keys>
+# to select a filter and <Enter> or <escape> to activate or cancel the filter
+# option.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+SEARCHENGINE = YES
+
+# When the SERVER_BASED_SEARCH tag is enabled the search engine will be
+# implemented using a web server instead of a web client using JavaScript. There
+# are two flavors of web server based searching depending on the EXTERNAL_SEARCH
+# setting. When disabled, doxygen will generate a PHP script for searching and
+# an index file used by the script. When EXTERNAL_SEARCH is enabled the indexing
+# and searching needs to be provided by external tools. See the section
+# "External Indexing and Searching" for details.
+# The default value is: NO.
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+SERVER_BASED_SEARCH = NO
+
+# When EXTERNAL_SEARCH tag is enabled doxygen will no longer generate the PHP
+# script for searching. Instead the search results are written to an XML file
+# which needs to be processed by an external indexer. Doxygen will invoke an
+# external search engine pointed to by the SEARCHENGINE_URL option to obtain the
+# search results.
+#
+# Doxygen ships with an example indexer (doxyindexer) and search engine
+# (doxysearch.cgi) which are based on the open source search engine library
+# Xapian (see: https://xapian.org/).
+#
+# See the section "External Indexing and Searching" for details.
+# The default value is: NO.
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+EXTERNAL_SEARCH = NO
+
+# The SEARCHENGINE_URL should point to a search engine hosted by a web server
+# which will return the search results when EXTERNAL_SEARCH is enabled.
+#
+# Doxygen ships with an example indexer (doxyindexer) and search engine
+# (doxysearch.cgi) which are based on the open source search engine library
+# Xapian (see: https://xapian.org/). See the section "External Indexing and
+# Searching" for details.
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+SEARCHENGINE_URL =
+
+# When SERVER_BASED_SEARCH and EXTERNAL_SEARCH are both enabled the unindexed
+# search data is written to a file for indexing by an external tool. With the
+# SEARCHDATA_FILE tag the name of this file can be specified.
+# The default file is: searchdata.xml.
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+SEARCHDATA_FILE = searchdata.xml
+
+# When SERVER_BASED_SEARCH and EXTERNAL_SEARCH are both enabled the
+# EXTERNAL_SEARCH_ID tag can be used as an identifier for the project. This is
+# useful in combination with EXTRA_SEARCH_MAPPINGS to search through multiple
+# projects and redirect the results back to the right project.
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+EXTERNAL_SEARCH_ID =
+
+# The EXTRA_SEARCH_MAPPINGS tag can be used to enable searching through doxygen
+# projects other than the one defined by this configuration file, but that are
+# all added to the same external search index. Each project needs to have a
+# unique id set via EXTERNAL_SEARCH_ID. The search mapping then maps the id of
+# to a relative location where the documentation can be found. The format is:
+# EXTRA_SEARCH_MAPPINGS = tagname1=loc1 tagname2=loc2 ...
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+EXTRA_SEARCH_MAPPINGS =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the LaTeX output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_LATEX tag is set to YES, doxygen will generate LaTeX output.
+# The default value is: YES.
+
+GENERATE_LATEX = NO
+
+# The LATEX_OUTPUT tag is used to specify where the LaTeX docs will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it.
+# The default directory is: latex.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_OUTPUT = latex
+
+# The LATEX_CMD_NAME tag can be used to specify the LaTeX command name to be
+# invoked.
+#
+# Note that when not enabling USE_PDFLATEX the default is latex when enabling
+# USE_PDFLATEX the default is pdflatex and when in the later case latex is
+# chosen this is overwritten by pdflatex. For specific output languages the
+# default can have been set differently, this depends on the implementation of
+# the output language.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_CMD_NAME = latex
+
+# The MAKEINDEX_CMD_NAME tag can be used to specify the command name to generate
+# index for LaTeX.
+# Note: This tag is used in the Makefile / make.bat.
+# See also: LATEX_MAKEINDEX_CMD for the part in the generated output file
+# (.tex).
+# The default file is: makeindex.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+MAKEINDEX_CMD_NAME = makeindex
+
+# The LATEX_MAKEINDEX_CMD tag can be used to specify the command name to
+# generate index for LaTeX. In case there is no backslash (\) as first character
+# it will be automatically added in the LaTeX code.
+# Note: This tag is used in the generated output file (.tex).
+# See also: MAKEINDEX_CMD_NAME for the part in the Makefile / make.bat.
+# The default value is: makeindex.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_MAKEINDEX_CMD = makeindex
+
+# If the COMPACT_LATEX tag is set to YES, doxygen generates more compact LaTeX
+# documents. This may be useful for small projects and may help to save some
+# trees in general.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+COMPACT_LATEX = NO
+
+# The PAPER_TYPE tag can be used to set the paper type that is used by the
+# printer.
+# Possible values are: a4 (210 x 297 mm), letter (8.5 x 11 inches), legal (8.5 x
+# 14 inches) and executive (7.25 x 10.5 inches).
+# The default value is: a4.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+PAPER_TYPE = a4
+
+# The EXTRA_PACKAGES tag can be used to specify one or more LaTeX package names
+# that should be included in the LaTeX output. The package can be specified just
+# by its name or with the correct syntax as to be used with the LaTeX
+# \usepackage command. To get the times font for instance you can specify :
+# EXTRA_PACKAGES=times or EXTRA_PACKAGES={times}
+# To use the option intlimits with the amsmath package you can specify:
+# EXTRA_PACKAGES=[intlimits]{amsmath}
+# If left blank no extra packages will be included.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+EXTRA_PACKAGES =
+
+# The LATEX_HEADER tag can be used to specify a personal LaTeX header for the
+# generated LaTeX document. The header should contain everything until the first
+# chapter. If it is left blank doxygen will generate a standard header. See
+# section "Doxygen usage" for information on how to let doxygen write the
+# default header to a separate file.
+#
+# Note: Only use a user-defined header if you know what you are doing! The
+# following commands have a special meaning inside the header: $title,
+# $datetime, $date, $doxygenversion, $projectname, $projectnumber,
+# $projectbrief, $projectlogo. Doxygen will replace $title with the empty
+# string, for the replacement values of the other commands the user is referred
+# to HTML_HEADER.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_HEADER =
+
+# The LATEX_FOOTER tag can be used to specify a personal LaTeX footer for the
+# generated LaTeX document. The footer should contain everything after the last
+# chapter. If it is left blank doxygen will generate a standard footer. See
+# LATEX_HEADER for more information on how to generate a default footer and what
+# special commands can be used inside the footer.
+#
+# Note: Only use a user-defined footer if you know what you are doing!
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_FOOTER =
+
+# The LATEX_EXTRA_STYLESHEET tag can be used to specify additional user-defined
+# LaTeX style sheets that are included after the standard style sheets created
+# by doxygen. Using this option one can overrule certain style aspects. Doxygen
+# will copy the style sheet files to the output directory.
+# Note: The order of the extra style sheet files is of importance (e.g. the last
+# style sheet in the list overrules the setting of the previous ones in the
+# list).
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_EXTRA_STYLESHEET =
+
+# The LATEX_EXTRA_FILES tag can be used to specify one or more extra images or
+# other source files which should be copied to the LATEX_OUTPUT output
+# directory. Note that the files will be copied as-is; there are no commands or
+# markers available.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_EXTRA_FILES =
+
+# If the PDF_HYPERLINKS tag is set to YES, the LaTeX that is generated is
+# prepared for conversion to PDF (using ps2pdf or pdflatex). The PDF file will
+# contain links (just like the HTML output) instead of page references. This
+# makes the output suitable for online browsing using a PDF viewer.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+PDF_HYPERLINKS = YES
+
+# If the USE_PDFLATEX tag is set to YES, doxygen will use pdflatex to generate
+# the PDF file directly from the LaTeX files. Set this option to YES, to get a
+# higher quality PDF documentation.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+USE_PDFLATEX = YES
+
+# If the LATEX_BATCHMODE tag is set to YES, doxygen will add the \batchmode
+# command to the generated LaTeX files. This will instruct LaTeX to keep running
+# if errors occur, instead of asking the user for help. This option is also used
+# when generating formulas in HTML.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_BATCHMODE = NO
+
+# If the LATEX_HIDE_INDICES tag is set to YES then doxygen will not include the
+# index chapters (such as File Index, Compound Index, etc.) in the output.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_HIDE_INDICES = NO
+
+# If the LATEX_SOURCE_CODE tag is set to YES then doxygen will include source
+# code with syntax highlighting in the LaTeX output.
+#
+# Note that which sources are shown also depends on other settings such as
+# SOURCE_BROWSER.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_SOURCE_CODE = NO
+
+# The LATEX_BIB_STYLE tag can be used to specify the style to use for the
+# bibliography, e.g. plainnat, or ieeetr. See
+# https://en.wikipedia.org/wiki/BibTeX and \cite for more info.
+# The default value is: plain.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_BIB_STYLE = plain
+
+# If the LATEX_TIMESTAMP tag is set to YES then the footer of each generated
+# page will contain the date and time when the page was generated. Setting this
+# to NO can help when comparing the output of multiple runs.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_TIMESTAMP = NO
+
+# The LATEX_EMOJI_DIRECTORY tag is used to specify the (relative or absolute)
+# path from which the emoji images will be read. If a relative path is entered,
+# it will be relative to the LATEX_OUTPUT directory. If left blank the
+# LATEX_OUTPUT directory will be used.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_EMOJI_DIRECTORY =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the RTF output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_RTF tag is set to YES, doxygen will generate RTF output. The
+# RTF output is optimized for Word 97 and may not look too pretty with other RTF
+# readers/editors.
+# The default value is: NO.
+
+GENERATE_RTF = NO
+
+# The RTF_OUTPUT tag is used to specify where the RTF docs will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it.
+# The default directory is: rtf.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_OUTPUT = rtf
+
+# If the COMPACT_RTF tag is set to YES, doxygen generates more compact RTF
+# documents. This may be useful for small projects and may help to save some
+# trees in general.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+COMPACT_RTF = NO
+
+# If the RTF_HYPERLINKS tag is set to YES, the RTF that is generated will
+# contain hyperlink fields. The RTF file will contain links (just like the HTML
+# output) instead of page references. This makes the output suitable for online
+# browsing using Word or some other Word compatible readers that support those
+# fields.
+#
+# Note: WordPad (write) and others do not support links.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_HYPERLINKS = NO
+
+# Load stylesheet definitions from file. Syntax is similar to doxygen's
+# configuration file, i.e. a series of assignments. You only have to provide
+# replacements, missing definitions are set to their default value.
+#
+# See also section "Doxygen usage" for information on how to generate the
+# default style sheet that doxygen normally uses.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_STYLESHEET_FILE =
+
+# Set optional variables used in the generation of an RTF document. Syntax is
+# similar to doxygen's configuration file. A template extensions file can be
+# generated using doxygen -e rtf extensionFile.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_EXTENSIONS_FILE =
+
+# If the RTF_SOURCE_CODE tag is set to YES then doxygen will include source code
+# with syntax highlighting in the RTF output.
+#
+# Note that which sources are shown also depends on other settings such as
+# SOURCE_BROWSER.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_SOURCE_CODE = NO
+
+#---------------------------------------------------------------------------
+# Configuration options related to the man page output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_MAN tag is set to YES, doxygen will generate man pages for
+# classes and files.
+# The default value is: NO.
+
+GENERATE_MAN = NO
+
+# The MAN_OUTPUT tag is used to specify where the man pages will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it. A directory man3 will be created inside the directory specified by
+# MAN_OUTPUT.
+# The default directory is: man.
+# This tag requires that the tag GENERATE_MAN is set to YES.
+
+MAN_OUTPUT = man
+
+# The MAN_EXTENSION tag determines the extension that is added to the generated
+# man pages. In case the manual section does not start with a number, the number
+# 3 is prepended. The dot (.) at the beginning of the MAN_EXTENSION tag is
+# optional.
+# The default value is: .3.
+# This tag requires that the tag GENERATE_MAN is set to YES.
+
+MAN_EXTENSION = .3
+
+# The MAN_SUBDIR tag determines the name of the directory created within
+# MAN_OUTPUT in which the man pages are placed. If defaults to man followed by
+# MAN_EXTENSION with the initial . removed.
+# This tag requires that the tag GENERATE_MAN is set to YES.
+
+MAN_SUBDIR =
+
+# If the MAN_LINKS tag is set to YES and doxygen generates man output, then it
+# will generate one additional man file for each entity documented in the real
+# man page(s). These additional files only source the real man page, but without
+# them the man command would be unable to find the correct page.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_MAN is set to YES.
+
+MAN_LINKS = NO
+
+#---------------------------------------------------------------------------
+# Configuration options related to the XML output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_XML tag is set to YES, doxygen will generate an XML file that
+# captures the structure of the code including all documentation.
+# The default value is: NO.
+
+GENERATE_XML = YES
+
+# The XML_OUTPUT tag is used to specify where the XML pages will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it.
+# The default directory is: xml.
+# This tag requires that the tag GENERATE_XML is set to YES.
+
+XML_OUTPUT = xml
+
+# If the XML_PROGRAMLISTING tag is set to YES, doxygen will dump the program
+# listings (including syntax highlighting and cross-referencing information) to
+# the XML output. Note that enabling this will significantly increase the size
+# of the XML output.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_XML is set to YES.
+
+XML_PROGRAMLISTING = YES
+
+# If the XML_NS_MEMB_FILE_SCOPE tag is set to YES, doxygen will include
+# namespace members in file scope as well, matching the HTML output.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_XML is set to YES.
+
+XML_NS_MEMB_FILE_SCOPE = NO
+
+#---------------------------------------------------------------------------
+# Configuration options related to the DOCBOOK output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_DOCBOOK tag is set to YES, doxygen will generate Docbook files
+# that can be used to generate PDF.
+# The default value is: NO.
+
+GENERATE_DOCBOOK = NO
+
+# The DOCBOOK_OUTPUT tag is used to specify where the Docbook pages will be put.
+# If a relative path is entered the value of OUTPUT_DIRECTORY will be put in
+# front of it.
+# The default directory is: docbook.
+# This tag requires that the tag GENERATE_DOCBOOK is set to YES.
+
+DOCBOOK_OUTPUT = docbook
+
+# If the DOCBOOK_PROGRAMLISTING tag is set to YES, doxygen will include the
+# program listings (including syntax highlighting and cross-referencing
+# information) to the DOCBOOK output. Note that enabling this will significantly
+# increase the size of the DOCBOOK output.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_DOCBOOK is set to YES.
+
+DOCBOOK_PROGRAMLISTING = NO
+
+#---------------------------------------------------------------------------
+# Configuration options for the AutoGen Definitions output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_AUTOGEN_DEF tag is set to YES, doxygen will generate an
+# AutoGen Definitions (see http://autogen.sourceforge.net/) file that captures
+# the structure of the code including all documentation. Note that this feature
+# is still experimental and incomplete at the moment.
+# The default value is: NO.
+
+GENERATE_AUTOGEN_DEF = NO
+
+#---------------------------------------------------------------------------
+# Configuration options related to the Perl module output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_PERLMOD tag is set to YES, doxygen will generate a Perl module
+# file that captures the structure of the code including all documentation.
+#
+# Note that this feature is still experimental and incomplete at the moment.
+# The default value is: NO.
+
+GENERATE_PERLMOD = NO
+
+# If the PERLMOD_LATEX tag is set to YES, doxygen will generate the necessary
+# Makefile rules, Perl scripts and LaTeX code to be able to generate PDF and DVI
+# output from the Perl module output.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_PERLMOD is set to YES.
+
+PERLMOD_LATEX = NO
+
+# If the PERLMOD_PRETTY tag is set to YES, the Perl module output will be nicely
+# formatted so it can be parsed by a human reader. This is useful if you want to
+# understand what is going on. On the other hand, if this tag is set to NO, the
+# size of the Perl module output will be much smaller and Perl will parse it
+# just the same.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_PERLMOD is set to YES.
+
+PERLMOD_PRETTY = YES
+
+# The names of the make variables in the generated doxyrules.make file are
+# prefixed with the string contained in PERLMOD_MAKEVAR_PREFIX. This is useful
+# so different doxyrules.make files included by the same Makefile don't
+# overwrite each other's variables.
+# This tag requires that the tag GENERATE_PERLMOD is set to YES.
+
+PERLMOD_MAKEVAR_PREFIX =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the preprocessor
+#---------------------------------------------------------------------------
+
+# If the ENABLE_PREPROCESSING tag is set to YES, doxygen will evaluate all
+# C-preprocessor directives found in the sources and include files.
+# The default value is: YES.
+
+ENABLE_PREPROCESSING = YES
+
+# If the MACRO_EXPANSION tag is set to YES, doxygen will expand all macro names
+# in the source code. If set to NO, only conditional compilation will be
+# performed. Macro expansion can be done in a controlled way by setting
+# EXPAND_ONLY_PREDEF to YES.
+# The default value is: NO.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+MACRO_EXPANSION = YES
+
+# If the EXPAND_ONLY_PREDEF and MACRO_EXPANSION tags are both set to YES then
+# the macro expansion is limited to the macros specified with the PREDEFINED and
+# EXPAND_AS_DEFINED tags.
+# The default value is: NO.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+EXPAND_ONLY_PREDEF = YES
+
+# If the SEARCH_INCLUDES tag is set to YES, the include files in the
+# INCLUDE_PATH will be searched if a #include is found.
+# The default value is: YES.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+SEARCH_INCLUDES = YES
+
+# The INCLUDE_PATH tag can be used to specify one or more directories that
+# contain include files that are not input files but should be processed by the
+# preprocessor.
+# This tag requires that the tag SEARCH_INCLUDES is set to YES.
+
+INCLUDE_PATH =
+
+# You can use the INCLUDE_FILE_PATTERNS tag to specify one or more wildcard
+# patterns (like *.h and *.hpp) to filter out the header-files in the
+# directories. If left blank, the patterns specified with FILE_PATTERNS will be
+# used.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+INCLUDE_FILE_PATTERNS =
+
+# The PREDEFINED tag can be used to specify one or more macro names that are
+# defined before the preprocessor is started (similar to the -D option of e.g.
+# gcc). The argument of the tag is a list of macros of the form: name or
+# name=definition (no spaces). If the definition and the "=" are omitted, "=1"
+# is assumed. To prevent a macro definition from being undefined via #undef or
+# recursively expanded use the := operator instead of the = operator.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+PREDEFINED = __attribute__(x)= \
+ __declspec(x)= \
+ PARQUET_EXPORT= \
+ ARROW_EXPORT= \
+ ARROW_DS_EXPORT= \
+ ARROW_FLIGHT_EXPORT= \
+ ARROW_EXTERN_TEMPLATE= \
+ ARROW_DEPRECATED(x)=
+
+# If the MACRO_EXPANSION and EXPAND_ONLY_PREDEF tags are set to YES then this
+# tag can be used to specify a list of macro names that should be expanded. The
+# macro definition that is found in the sources will be used. Use the PREDEFINED
+# tag if you want to use a different macro definition that overrules the
+# definition found in the source code.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+EXPAND_AS_DEFINED =
+
+# If the SKIP_FUNCTION_MACROS tag is set to YES then doxygen's preprocessor will
+# remove all references to function-like macros that are alone on a line, have
+# an all uppercase name, and do not end with a semicolon. Such function macros
+# are typically used for boiler-plate code, and will confuse the parser if not
+# removed.
+# The default value is: YES.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+SKIP_FUNCTION_MACROS = YES
+
+#---------------------------------------------------------------------------
+# Configuration options related to external references
+#---------------------------------------------------------------------------
+
+# The TAGFILES tag can be used to specify one or more tag files. For each tag
+# file the location of the external documentation should be added. The format of
+# a tag file without this location is as follows:
+# TAGFILES = file1 file2 ...
+# Adding location for the tag files is done as follows:
+# TAGFILES = file1=loc1 "file2 = loc2" ...
+# where loc1 and loc2 can be relative or absolute paths or URLs. See the
+# section "Linking to external documentation" for more information about the use
+# of tag files.
+# Note: Each tag file must have a unique name (where the name does NOT include
+# the path). If a tag file is not located in the directory in which doxygen is
+# run, you must also specify the path to the tagfile here.
+
+TAGFILES =
+
+# When a file name is specified after GENERATE_TAGFILE, doxygen will create a
+# tag file that is based on the input files it reads. See section "Linking to
+# external documentation" for more information about the usage of tag files.
+
+GENERATE_TAGFILE =
+
+# If the ALLEXTERNALS tag is set to YES, all external class will be listed in
+# the class index. If set to NO, only the inherited external classes will be
+# listed.
+# The default value is: NO.
+
+ALLEXTERNALS = NO
+
+# If the EXTERNAL_GROUPS tag is set to YES, all external groups will be listed
+# in the modules index. If set to NO, only the current project's groups will be
+# listed.
+# The default value is: YES.
+
+EXTERNAL_GROUPS = YES
+
+# If the EXTERNAL_PAGES tag is set to YES, all external pages will be listed in
+# the related pages index. If set to NO, only the current project's pages will
+# be listed.
+# The default value is: YES.
+
+EXTERNAL_PAGES = YES
+
+#---------------------------------------------------------------------------
+# Configuration options related to the dot tool
+#---------------------------------------------------------------------------
+
+# If the CLASS_DIAGRAMS tag is set to YES, doxygen will generate a class diagram
+# (in HTML and LaTeX) for classes with base or super classes. Setting the tag to
+# NO turns the diagrams off. Note that this option also works with HAVE_DOT
+# disabled, but it is recommended to install and use dot, since it yields more
+# powerful graphs.
+# The default value is: YES.
+
+CLASS_DIAGRAMS = YES
+
+# You can include diagrams made with dia in doxygen documentation. Doxygen will
+# then run dia to produce the diagram and insert it in the documentation. The
+# DIA_PATH tag allows you to specify the directory where the dia binary resides.
+# If left empty dia is assumed to be found in the default search path.
+
+DIA_PATH =
+
+# If set to YES the inheritance and collaboration graphs will hide inheritance
+# and usage relations if the target is undocumented or is not a class.
+# The default value is: YES.
+
+HIDE_UNDOC_RELATIONS = YES
+
+# If you set the HAVE_DOT tag to YES then doxygen will assume the dot tool is
+# available from the path. This tool is part of Graphviz (see:
+# http://www.graphviz.org/), a graph visualization toolkit from AT&T and Lucent
+# Bell Labs. The other options in this section have no effect if this option is
+# set to NO
+# The default value is: NO.
+
+HAVE_DOT = NO
+
+# The DOT_NUM_THREADS specifies the number of dot invocations doxygen is allowed
+# to run in parallel. When set to 0 doxygen will base this on the number of
+# processors available in the system. You can set it explicitly to a value
+# larger than 0 to get control over the balance between CPU load and processing
+# speed.
+# Minimum value: 0, maximum value: 32, default value: 0.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_NUM_THREADS = 0
+
+# When you want a differently looking font in the dot files that doxygen
+# generates you can specify the font name using DOT_FONTNAME. You need to make
+# sure dot is able to find the font, which can be done by putting it in a
+# standard location or by setting the DOTFONTPATH environment variable or by
+# setting DOT_FONTPATH to the directory containing the font.
+# The default value is: Helvetica.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_FONTNAME = Helvetica
+
+# The DOT_FONTSIZE tag can be used to set the size (in points) of the font of
+# dot graphs.
+# Minimum value: 4, maximum value: 24, default value: 10.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_FONTSIZE = 10
+
+# By default doxygen will tell dot to use the default font as specified with
+# DOT_FONTNAME. If you specify a different font using DOT_FONTNAME you can set
+# the path where dot can find it using this tag.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_FONTPATH =
+
+# If the CLASS_GRAPH tag is set to YES then doxygen will generate a graph for
+# each documented class showing the direct and indirect inheritance relations.
+# Setting this tag to YES will force the CLASS_DIAGRAMS tag to NO.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+CLASS_GRAPH = YES
+
+# If the COLLABORATION_GRAPH tag is set to YES then doxygen will generate a
+# graph for each documented class showing the direct and indirect implementation
+# dependencies (inheritance, containment, and class references variables) of the
+# class with other documented classes.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+COLLABORATION_GRAPH = YES
+
+# If the GROUP_GRAPHS tag is set to YES then doxygen will generate a graph for
+# groups, showing the direct groups dependencies.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+GROUP_GRAPHS = YES
+
+# If the UML_LOOK tag is set to YES, doxygen will generate inheritance and
+# collaboration diagrams in a style similar to the OMG's Unified Modeling
+# Language.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+UML_LOOK = NO
+
+# If the UML_LOOK tag is enabled, the fields and methods are shown inside the
+# class node. If there are many fields or methods and many nodes the graph may
+# become too big to be useful. The UML_LIMIT_NUM_FIELDS threshold limits the
+# number of items for each type to make the size more manageable. Set this to 0
+# for no limit. Note that the threshold may be exceeded by 50% before the limit
+# is enforced. So when you set the threshold to 10, up to 15 fields may appear,
+# but if the number exceeds 15, the total amount of fields shown is limited to
+# 10.
+# Minimum value: 0, maximum value: 100, default value: 10.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+UML_LIMIT_NUM_FIELDS = 10
+
+# If the TEMPLATE_RELATIONS tag is set to YES then the inheritance and
+# collaboration graphs will show the relations between templates and their
+# instances.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+TEMPLATE_RELATIONS = NO
+
+# If the INCLUDE_GRAPH, ENABLE_PREPROCESSING and SEARCH_INCLUDES tags are set to
+# YES then doxygen will generate a graph for each documented file showing the
+# direct and indirect include dependencies of the file with other documented
+# files.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+INCLUDE_GRAPH = YES
+
+# If the INCLUDED_BY_GRAPH, ENABLE_PREPROCESSING and SEARCH_INCLUDES tags are
+# set to YES then doxygen will generate a graph for each documented file showing
+# the direct and indirect include dependencies of the file with other documented
+# files.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+INCLUDED_BY_GRAPH = YES
+
+# If the CALL_GRAPH tag is set to YES then doxygen will generate a call
+# dependency graph for every global function or class method.
+#
+# Note that enabling this option will significantly increase the time of a run.
+# So in most cases it will be better to enable call graphs for selected
+# functions only using the \callgraph command. Disabling a call graph can be
+# accomplished by means of the command \hidecallgraph.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+CALL_GRAPH = NO
+
+# If the CALLER_GRAPH tag is set to YES then doxygen will generate a caller
+# dependency graph for every global function or class method.
+#
+# Note that enabling this option will significantly increase the time of a run.
+# So in most cases it will be better to enable caller graphs for selected
+# functions only using the \callergraph command. Disabling a caller graph can be
+# accomplished by means of the command \hidecallergraph.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+CALLER_GRAPH = NO
+
+# If the GRAPHICAL_HIERARCHY tag is set to YES then doxygen will graphical
+# hierarchy of all classes instead of a textual one.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+GRAPHICAL_HIERARCHY = YES
+
+# If the DIRECTORY_GRAPH tag is set to YES then doxygen will show the
+# dependencies a directory has on other directories in a graphical way. The
+# dependency relations are determined by the #include relations between the
+# files in the directories.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DIRECTORY_GRAPH = YES
+
+# The DOT_IMAGE_FORMAT tag can be used to set the image format of the images
+# generated by dot. For an explanation of the image formats see the section
+# output formats in the documentation of the dot tool (Graphviz (see:
+# http://www.graphviz.org/)).
+# Note: If you choose svg you need to set HTML_FILE_EXTENSION to xhtml in order
+# to make the SVG files visible in IE 9+ (other browsers do not have this
+# requirement).
+# Possible values are: png, jpg, gif, svg, png:gd, png:gd:gd, png:cairo,
+# png:cairo:gd, png:cairo:cairo, png:cairo:gdiplus, png:gdiplus and
+# png:gdiplus:gdiplus.
+# The default value is: png.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_IMAGE_FORMAT = png
+
+# If DOT_IMAGE_FORMAT is set to svg, then this option can be set to YES to
+# enable generation of interactive SVG images that allow zooming and panning.
+#
+# Note that this requires a modern browser other than Internet Explorer. Tested
+# and working are Firefox, Chrome, Safari, and Opera.
+# Note: For IE 9+ you need to set HTML_FILE_EXTENSION to xhtml in order to make
+# the SVG files visible. Older versions of IE do not have SVG support.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+INTERACTIVE_SVG = NO
+
+# The DOT_PATH tag can be used to specify the path where the dot tool can be
+# found. If left blank, it is assumed the dot tool can be found in the path.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_PATH =
+
+# The DOTFILE_DIRS tag can be used to specify one or more directories that
+# contain dot files that are included in the documentation (see the \dotfile
+# command).
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOTFILE_DIRS =
+
+# The MSCFILE_DIRS tag can be used to specify one or more directories that
+# contain msc files that are included in the documentation (see the \mscfile
+# command).
+
+MSCFILE_DIRS =
+
+# The DIAFILE_DIRS tag can be used to specify one or more directories that
+# contain dia files that are included in the documentation (see the \diafile
+# command).
+
+DIAFILE_DIRS =
+
+# When using plantuml, the PLANTUML_JAR_PATH tag should be used to specify the
+# path where java can find the plantuml.jar file. If left blank, it is assumed
+# PlantUML is not used or called during a preprocessing step. Doxygen will
+# generate a warning when it encounters a \startuml command in this case and
+# will not generate output for the diagram.
+
+PLANTUML_JAR_PATH =
+
+# When using plantuml, the PLANTUML_CFG_FILE tag can be used to specify a
+# configuration file for plantuml.
+
+PLANTUML_CFG_FILE =
+
+# When using plantuml, the specified paths are searched for files specified by
+# the !include statement in a plantuml block.
+
+PLANTUML_INCLUDE_PATH =
+
+# The DOT_GRAPH_MAX_NODES tag can be used to set the maximum number of nodes
+# that will be shown in the graph. If the number of nodes in a graph becomes
+# larger than this value, doxygen will truncate the graph, which is visualized
+# by representing a node as a red box. Note that doxygen if the number of direct
+# children of the root node in a graph is already larger than
+# DOT_GRAPH_MAX_NODES then the graph will not be shown at all. Also note that
+# the size of a graph can be further restricted by MAX_DOT_GRAPH_DEPTH.
+# Minimum value: 0, maximum value: 10000, default value: 50.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_GRAPH_MAX_NODES = 50
+
+# The MAX_DOT_GRAPH_DEPTH tag can be used to set the maximum depth of the graphs
+# generated by dot. A depth value of 3 means that only nodes reachable from the
+# root by following a path via at most 3 edges will be shown. Nodes that lay
+# further from the root node will be omitted. Note that setting this option to 1
+# or 2 may greatly reduce the computation time needed for large code bases. Also
+# note that the size of a graph can be further restricted by
+# DOT_GRAPH_MAX_NODES. Using a depth of 0 means no depth restriction.
+# Minimum value: 0, maximum value: 1000, default value: 0.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+MAX_DOT_GRAPH_DEPTH = 0
+
+# Set the DOT_TRANSPARENT tag to YES to generate images with a transparent
+# background. This is disabled by default, because dot on Windows does not seem
+# to support this out of the box.
+#
+# Warning: Depending on the platform used, enabling this option may lead to
+# badly anti-aliased labels on the edges of a graph (i.e. they become hard to
+# read).
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_TRANSPARENT = NO
+
+# Set the DOT_MULTI_TARGETS tag to YES to allow dot to generate multiple output
+# files in one run (i.e. multiple -o and -T options on the command line). This
+# makes dot run faster, but since only newer versions of dot (>1.8.10) support
+# this, this feature is disabled by default.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_MULTI_TARGETS = NO
+
+# If the GENERATE_LEGEND tag is set to YES doxygen will generate a legend page
+# explaining the meaning of the various boxes and arrows in the dot generated
+# graphs.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+GENERATE_LEGEND = YES
+
+# If the DOT_CLEANUP tag is set to YES, doxygen will remove the intermediate dot
+# files that are used to generate the various graphs.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_CLEANUP = YES
diff --git a/src/arrow/cpp/apidoc/HDFS.md b/src/arrow/cpp/apidoc/HDFS.md
new file mode 100644
index 000000000..d3671fb76
--- /dev/null
+++ b/src/arrow/cpp/apidoc/HDFS.md
@@ -0,0 +1,83 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+## Using Arrow's HDFS (Apache Hadoop Distributed File System) interface
+
+### Build requirements
+
+To build the integration, pass the following option to CMake
+
+```shell
+-DARROW_HDFS=on
+```
+
+For convenience, we have bundled `hdfs.h` for libhdfs from Apache Hadoop in
+Arrow's thirdparty. If you wish to build against the `hdfs.h` in your installed
+Hadoop distribution, set the `$HADOOP_HOME` environment variable.
+
+### Runtime requirements
+
+By default, the HDFS client C++ class in `libarrow_io` uses the libhdfs JNI
+interface to the Java Hadoop client. This library is loaded **at runtime**
+(rather than at link / library load time, since the library may not be in your
+LD_LIBRARY_PATH), and relies on some environment variables.
+
+* `HADOOP_HOME`: the root of your installed Hadoop distribution. Often has
+`lib/native/libhdfs.so`.
+* `JAVA_HOME`: the location of your Java SDK installation.
+* `CLASSPATH`: must contain the Hadoop jars. You can set these using:
+
+```shell
+export CLASSPATH=`$HADOOP_HOME/bin/hadoop classpath --glob`
+```
+
+* `ARROW_LIBHDFS_DIR` (optional): explicit location of `libhdfs.so` if it is
+installed somewhere other than `$HADOOP_HOME/lib/native`.
+
+To accommodate distribution-specific nuances, the `JAVA_HOME` variable may be
+set to the root path for the Java SDK, the JRE path itself, or to the directory
+containing the `libjvm` library.
+
+### Mac Specifics
+
+The installed location of Java on OS X can vary, however the following snippet
+will set it automatically for you:
+
+```shell
+export JAVA_HOME=$(/usr/libexec/java_home)
+```
+
+Homebrew's Hadoop does not have native libs. Apache doesn't build these, so
+users must build Hadoop to get the native libs. See this Stack Overflow
+answer for details:
+
+http://stackoverflow.com/a/40051353/478288
+
+Be sure to include the path to the native libs in `JAVA_LIBRARY_PATH`:
+
+```shell
+export JAVA_LIBRARY_PATH=$HADOOP_HOME/lib/native:$JAVA_LIBRARY_PATH
+```
+
+If you get an error about needing to install Java 6, then add *BundledApp* and
+*JNI* to the `JVMCapabilities` in `$JAVA_HOME/../Info.plist`. See
+
+https://oliverdowling.com.au/2015/10/09/oracles-jre-8-on-mac-os-x-el-capitan/
+
+https://derflounder.wordpress.com/2015/08/08/modifying-oracles-java-sdk-to-run-java-applications-on-os-x/
diff --git a/src/arrow/cpp/apidoc/footer.html b/src/arrow/cpp/apidoc/footer.html
new file mode 100644
index 000000000..01f4ad2d5
--- /dev/null
+++ b/src/arrow/cpp/apidoc/footer.html
@@ -0,0 +1,31 @@
+<!-- HTML footer for doxygen 1.8.14-->
+<!-- start footer part -->
+<!--BEGIN GENERATE_TREEVIEW-->
+<div id="nav-path" class="navpath"><!-- id is needed for treeview function! -->
+ <ul>
+ $navpath
+ <li class="footer">$generatedby
+ <a href="http://www.doxygen.org/index.html">
+ <img class="footer" src="$relpath^doxygen.png" alt="doxygen"/></a> $doxygenversion </li>
+ </ul>
+</div>
+<!--END GENERATE_TREEVIEW-->
+<!--BEGIN !GENERATE_TREEVIEW-->
+<hr class="footer"/><address class="footer"><small>
+$generatedby &#160;<a href="http://www.doxygen.org/index.html">
+<img class="footer" src="$relpath^doxygen.png" alt="doxygen"/>
+</a> $doxygenversion
+</small></address>
+<!--END !GENERATE_TREEVIEW-->
+
+<script async src="https://www.googletagmanager.com/gtag/js?id=UA-107500873-1"></script>
+<script>
+ window.dataLayer = window.dataLayer || [];
+ function gtag(){dataLayer.push(arguments);}
+ gtag('js', new Date());
+
+ gtag('config', 'UA-107500873-1');
+</script>
+
+</body>
+</html>
diff --git a/src/arrow/cpp/apidoc/tutorials/plasma.md b/src/arrow/cpp/apidoc/tutorials/plasma.md
new file mode 100644
index 000000000..fef452220
--- /dev/null
+++ b/src/arrow/cpp/apidoc/tutorials/plasma.md
@@ -0,0 +1,450 @@
+<!---
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. See accompanying LICENSE file.
+-->
+
+Using the Plasma In-Memory Object Store from C++
+================================================
+
+Apache Arrow offers the ability to share your data structures among multiple
+processes simultaneously through Plasma, an in-memory object store.
+
+Note that **the Plasma API is not stable**.
+
+Plasma clients are processes that run on the same machine as the object store.
+They communicate with the object store over Unix domain sockets, and they read
+and write data in the object store through shared memory.
+
+Plasma objects are immutable once they have been created.
+
+The following goes over the basics so you can begin using Plasma in your big
+data applications.
+
+Starting the Plasma store
+-------------------------
+
+To start running the Plasma object store so that clients may
+connect and access the data, run the following command:
+
+```
+plasma_store_server -m 1000000000 -s /tmp/plasma
+```
+
+The `-m` flag specifies the size of the object store in bytes. The `-s` flag
+specifies the path of the Unix domain socket that the store will listen at.
+
+Therefore, the above command initializes a Plasma store up to 1 GB of memory
+and sets the socket to `/tmp/plasma.`
+
+The Plasma store will remain available as long as the `plasma_store_server` process is
+running in a terminal window. Messages, such as alerts for disconnecting
+clients, may occasionally be output. To stop running the Plasma store, you
+can press `Ctrl-C` in the terminal window.
+
+Alternatively, you can run the Plasma store in the background and ignore all
+message output with the following terminal command:
+
+```
+plasma_store_server -m 1000000000 -s /tmp/plasma 1> /dev/null 2> /dev/null &
+```
+
+The Plasma store will instead run silently in the background. To stop running
+the Plasma store in this case, issue the command below:
+
+```
+killall plasma_store_server
+```
+
+Creating a Plasma client
+------------------------
+
+Now that the Plasma object store is up and running, it is time to make a client
+process connect to it. To use the Plasma object store as a client, your
+application should initialize a `plasma::PlasmaClient` object and tell it to
+connect to the socket specified when starting up the Plasma object store.
+
+```cpp
+#include <plasma/client.h>
+
+using namespace plasma;
+
+int main(int argc, char** argv) {
+ // Start up and connect a Plasma client.
+ PlasmaClient client;
+ ARROW_CHECK_OK(client.Connect("/tmp/plasma"));
+ // Disconnect the Plasma client.
+ ARROW_CHECK_OK(client.Disconnect());
+}
+```
+
+Save this program in a file `test.cc` and compile it with
+
+```
+g++ test.cc `pkg-config --cflags --libs plasma` --std=c++11
+```
+
+Note that multiple clients can be created within the same process.
+
+If the Plasma store is still running, you can now execute the `a.out` executable
+and the store will print something like
+
+```
+Disconnecting client on fd 5
+```
+
+which shows that the client was successfully disconnected.
+
+Object IDs
+----------
+
+The Plasma object store uses twenty-byte identifiers for accessing objects
+stored in shared memory. Each object in the Plasma store should be associated
+with a unique ID. The Object ID is then a key that can be used by **any** client
+to fetch that object from the Plasma store.
+
+Random generation of Object IDs is often good enough to ensure unique IDs.
+For test purposes, you can use the function `random_object_id` from the header
+`plasma/test-util.h` to generate random Object IDs, which uses a global random
+number generator. In your own applications, we recommend to generate a string of
+`ObjectID::size()` many random bytes using your own random number generator
+and pass them to `ObjectID::from_bytes` to generate the ObjectID.
+
+```cpp
+#include <plasma/test-util.h>
+
+// Randomly generate an Object ID.
+ObjectID object_id = random_object_id();
+```
+
+Now, any connected client that knows the object's Object ID can access the
+same object from the Plasma object store. For easy transportation of Object IDs,
+you can convert/serialize an Object ID into a binary string and back as
+follows:
+
+```cpp
+// From ObjectID to binary string
+std:string id_string = object_id.binary();
+
+// From binary string to ObjectID
+ObjectID id_object = ObjectID::from_binary(&id_string);
+```
+
+You can also get a human readable representation of ObjectIDs in the same
+format that git uses for commit hashes by running `ObjectID::hex`.
+
+Here is a test program you can run:
+
+```cpp
+#include <iostream>
+#include <string>
+#include <plasma/client.h>
+#include <plasma/test-util.h>
+
+using namespace plasma;
+
+int main(int argc, char** argv) {
+ ObjectID object_id1 = random_object_id();
+ std::cout << "object_id1 is " << object_id1.hex() << std::endl;
+
+ std::string id_string = object_id1.binary();
+ ObjectID object_id2 = ObjectID::from_binary(id_string);
+ std::cout << "object_id2 is " << object_id2.hex() << std::endl;
+}
+```
+
+Creating an Object
+------------------
+
+Now that you learned about Object IDs that are used to refer to objects,
+let's look at how objects can be stored in Plasma.
+
+Storing objects is a two-stage process. First a buffer is allocated with a call
+to `Create`. Then it can be constructed in place by the client. Then it is made
+immutable and shared with other clients via a call to `Seal`.
+
+The `Create` call blocks while the Plasma store allocates a buffer of the
+appropriate size. The client will then map the buffer into its own address
+space. At this point the object can be constructed in place using a pointer that
+was written by the `Create` command.
+
+```cpp
+int64_t data_size = 100;
+// The address of the buffer allocated by the Plasma store will be written at
+// this address.
+std::shared_ptr<Buffer> data;
+// Create a Plasma object by specifying its ID and size.
+ARROW_CHECK_OK(client.Create(object_id, data_size, NULL, 0, &data));
+```
+
+You can also specify metadata for the object; the third argument is the
+metadata (as raw bytes) and the fourth argument is the size of the metadata.
+
+```cpp
+// Create a Plasma object with metadata.
+int64_t data_size = 100;
+std::string metadata = "{'author': 'john'}";
+std::shared_ptr<Buffer> data;
+client.Create(object_id, data_size, (uint8_t*) metadata.data(), metadata.size(), &data);
+```
+
+Now that we've obtained a pointer to our object's data, we can
+write our data to it:
+
+```cpp
+// Write some data for the Plasma object.
+for (int64_t i = 0; i < data_size; i++) {
+ data[i] = static_cast<uint8_t>(i % 4);
+}
+```
+
+When the client is done, the client **seals** the buffer, making the object
+immutable, and making it available to other Plasma clients:
+
+```cpp
+// Seal the object. This makes it available for all clients.
+client.Seal(object_id);
+```
+
+Here is an example that combines all these features:
+
+```cpp
+#include <plasma/client.h>
+
+using namespace plasma;
+
+int main(int argc, char** argv) {
+ // Start up and connect a Plasma client.
+ PlasmaClient client;
+ ARROW_CHECK_OK(client.Connect("/tmp/plasma"));
+ // Create an object with a fixed ObjectID.
+ ObjectID object_id = ObjectID::from_binary("00000000000000000000");
+ int64_t data_size = 1000;
+ std::shared_ptr<Buffer> data;
+ std::string metadata = "{'author': 'john'}";
+ ARROW_CHECK_OK(client.Create(object_id, data_size, (uint8_t*) metadata.data(), metadata.size(), &data));
+ // Write some data into the object.
+ auto d = data->mutable_data();
+ for (int64_t i = 0; i < data_size; i++) {
+ d[i] = static_cast<uint8_t>(i % 4);
+ }
+ // Seal the object.
+ ARROW_CHECK_OK(client.Seal(object_id));
+ // Disconnect the client.
+ ARROW_CHECK_OK(client.Disconnect());
+}
+```
+
+This example can be compiled with
+
+```
+g++ create.cc `pkg-config --cflags --libs plasma` --std=c++11 -o create
+```
+
+To verify that an object exists in the Plasma object store, you can
+call `PlasmaClient::Contains()` to check if an object has
+been created and sealed for a given Object ID. Note that this function
+will still return False if the object has been created, but not yet
+sealed:
+
+```cpp
+// Check if an object has been created and sealed.
+bool has_object;
+client.Contains(object_id, &has_object);
+if (has_object) {
+ // Object has been created and sealed, proceed
+}
+```
+
+Getting an Object
+-----------------
+
+After an object has been sealed, any client who knows the Object ID can get
+the object. To store the retrieved object contents, you should create an
+`ObjectBuffer`, then call `PlasmaClient::Get()` as follows:
+
+```cpp
+// Get from the Plasma store by Object ID.
+ObjectBuffer object_buffer;
+client.Get(&object_id, 1, -1, &object_buffer);
+```
+
+`PlasmaClient::Get()` isn't limited to fetching a single object
+from the Plasma store at once. You can specify an array of Object IDs and
+`ObjectBuffers` to fetch at once, so long as you also specify the
+number of objects being fetched:
+
+```cpp
+// Get two objects at once from the Plasma store. This function
+// call will block until both objects have been fetched.
+ObjectBuffer multiple_buffers[2];
+ObjectID multiple_ids[2] = {object_id1, object_id2};
+client.Get(multiple_ids, 2, -1, multiple_buffers);
+```
+
+Since `PlasmaClient::Get()` is a blocking function call, it may be
+necessary to limit the amount of time the function is allowed to take
+when trying to fetch from the Plasma store. You can pass in a timeout
+in milliseconds when calling `PlasmaClient::Get().` To use `PlasmaClient::Get()`
+without a timeout, just pass in -1 like in the previous example calls:
+
+```cpp
+// Make the function call give up fetching the object if it takes
+// more than 100 milliseconds.
+int64_t timeout = 100;
+client.Get(&object_id, 1, timeout, &object_buffer);
+```
+
+Finally, to access the object, you can access the `data` and
+`metadata` attributes of the `ObjectBuffer`. The `data` can be indexed
+like any array:
+
+```cpp
+// Access object data.
+uint8_t* data = object_buffer.data;
+int64_t data_size = object_buffer.data_size;
+
+// Access object metadata.
+uint8_t* metadata = object_buffer.metadata;
+uint8_t metadata_size = object_buffer.metadata_size;
+
+// Index into data array.
+uint8_t first_data_byte = data[0];
+```
+
+Here is a longer example that shows these capabilities:
+
+```cpp
+#include <plasma/client.h>
+
+using namespace plasma;
+
+int main(int argc, char** argv) {
+ // Start up and connect a Plasma client.
+ PlasmaClient client;
+ ARROW_CHECK_OK(client.Connect("/tmp/plasma"));
+ ObjectID object_id = ObjectID::from_binary("00000000000000000000");
+ ObjectBuffer object_buffer;
+ ARROW_CHECK_OK(client.Get(&object_id, 1, -1, &object_buffer));
+
+ // Retrieve object data.
+ auto buffer = object_buffer.data;
+ const uint8_t* data = buffer->data();
+ int64_t data_size = buffer->size();
+
+ // Check that the data agrees with what was written in the other process.
+ for (int64_t i = 0; i < data_size; i++) {
+ ARROW_CHECK(data[i] == static_cast<uint8_t>(i % 4));
+ }
+
+ // Disconnect the client.
+ ARROW_CHECK_OK(client.Disconnect());
+}
+```
+
+If you compile it with
+
+```
+g++ get.cc `pkg-config --cflags --libs plasma` --std=c++11 -o get
+```
+
+and run it with `./get`, all the assertions will pass if you run the `create`
+example from above on the same Plasma store.
+
+
+Object Lifetime Management
+--------------------------
+
+The Plasma store internally does reference counting to make sure objects that
+are mapped into the address space of one of the clients with `PlasmaClient::Get`
+are accessible. To unmap objects from a client, call `PlasmaClient::Release`.
+All objects that are mapped into a clients address space will automatically
+be released when the client is disconnected from the store (this happens even
+if the client process crashes or otherwise fails to call `Disconnect`).
+
+If a new object is created and there is not enough space in the Plasma store,
+the store will evict the least recently used object (an object is in use if at
+least one client has gotten it but not released it).
+
+Object notifications
+--------------------
+
+Additionally, you can arrange to have Plasma notify you when objects are
+sealed in the object store. This may especially be handy when your
+program is collaborating with other Plasma clients, and needs to know
+when they make objects available.
+
+First, you can subscribe your current Plasma client to such notifications
+by getting a file descriptor:
+
+```cpp
+// Start receiving notifications into file_descriptor.
+int fd;
+ARROW_CHECK_OK(client.Subscribe(&fd));
+```
+
+Once you have the file descriptor, you can have your current Plasma client
+wait to receive the next object notification. Object notifications
+include information such as Object ID, data size, and metadata size of
+the next newly available object:
+
+```cpp
+// Receive notification of the next newly available object.
+// Notification information is stored in object_id, data_size, and metadata_size
+ObjectID object_id;
+int64_t data_size;
+int64_t metadata_size;
+ARROW_CHECK_OK(client.GetNotification(fd, &object_id, &data_size, &metadata_size));
+
+// Get the newly available object.
+ObjectBuffer object_buffer;
+ARROW_CHECK_OK(client.Get(&object_id, 1, -1, &object_buffer));
+```
+
+Here is a full program that shows this capability:
+
+```cpp
+#include <plasma/client.h>
+
+using namespace plasma;
+
+int main(int argc, char** argv) {
+ // Start up and connect a Plasma client.
+ PlasmaClient client;
+ ARROW_CHECK_OK(client.Connect("/tmp/plasma"));
+
+ int fd;
+ ARROW_CHECK_OK(client.Subscribe(&fd));
+
+ ObjectID object_id;
+ int64_t data_size;
+ int64_t metadata_size;
+ while (true) {
+ ARROW_CHECK_OK(client.GetNotification(fd, &object_id, &data_size, &metadata_size));
+
+ std::cout << "Received object notification for object_id = "
+ << object_id.hex() << ", with data_size = " << data_size
+ << ", and metadata_size = " << metadata_size << std::endl;
+ }
+
+ // Disconnect the client.
+ ARROW_CHECK_OK(client.Disconnect());
+}
+```
+
+If you compile it with
+
+```
+g++ subscribe.cc `pkg-config --cflags --libs plasma` --std=c++11 -o subscribe
+```
+
+and invoke `./create` and `./subscribe` while the Plasma store is running,
+you can observe the new object arriving.
diff --git a/src/arrow/cpp/apidoc/tutorials/tensor_to_py.md b/src/arrow/cpp/apidoc/tutorials/tensor_to_py.md
new file mode 100644
index 000000000..cd191fea0
--- /dev/null
+++ b/src/arrow/cpp/apidoc/tutorials/tensor_to_py.md
@@ -0,0 +1,127 @@
+<!---
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. See accompanying LICENSE file.
+-->
+
+Use Plasma to Access Tensors from C++ in Python
+==============================================
+
+This short tutorial shows how to use Arrow and the Plasma Store to send data
+from C++ to Python.
+
+In detail, we will show how to:
+1. Serialize a floating-point array in C++ into an Arrow tensor
+2. Save the Arrow tensor to Plasma
+3. Access the Tensor in a Python process
+
+This approach has the advantage that multiple python processes can all read
+the tensor with zero-copy. Therefore, only one copy is necessary when we send
+a tensor from one C++ process to many python processes.
+
+
+Step 0: Set up
+------
+We will include the following header files and construct a Plasma client.
+
+```cpp
+#include <plasma/client.h>
+#include <arrow/tensor.h>
+#include <arrow/array.h>
+#include <arrow/buffer.h>
+#include <arrow/io/memory.h>
+#include <arrow/ipc/writer.h>
+
+PlasmaClient client_;
+ARROW_CHECK_OK(client_.Connect("/tmp/plasma", "", 0));
+```
+
+
+Step 1: Serialize a floating point array in C++ into an Arrow Tensor
+--------------------------------------------------------------------
+In this step, we will construct a floating-point array in C++.
+
+```cpp
+// Generate an Object ID for Plasma
+ObjectID object_id = ObjectID::from_binary("11111111111111111111");
+
+// Generate Float Array
+int64_t input_length = 1000;
+std::vector<float> input(input_length);
+for (int64_t i = 0; i < input_length; ++i) {
+ input[i] = 2.0;
+}
+
+// Create Arrow Tensor Object, no copy made!
+// {input_length} is the shape of the tensor
+auto value_buffer = Buffer::Wrap<float>(input);
+Tensor t(float32(), value_buffer, {input_length});
+```
+
+Step 2: Save the Arrow Tensor to Plasma In-Memory Object Store
+--------------------------------------------------------------
+Continuing from Step 1, this step will save the tensor to Plasma Store. We
+use `arrow::ipc::WriteTensor` to write the data.
+
+The variable `meta_len` will contain the length of the tensor metadata
+after the call to `arrow::ipc::WriteTensor`.
+
+```cpp
+// Get the size of the tensor to be stored in Plasma
+int64_t datasize;
+ARROW_CHECK_OK(ipc::GetTensorSize(t, &datasize));
+int32_t meta_len = 0;
+
+// Create the Plasma Object
+// Plasma is responsible for initializing and resizing the buffer
+// This buffer will contain the _serialized_ tensor
+std::shared_ptr<Buffer> buffer;
+ARROW_CHECK_OK(
+ client_.Create(object_id, datasize, NULL, 0, &buffer));
+
+// Writing Process, this will copy the tensor into Plasma
+io::FixedSizeBufferWriter stream(buffer);
+ARROW_CHECK_OK(arrow::ipc::WriteTensor(t, &stream, &meta_len, &datasize));
+
+// Seal Plasma Object
+// This computes a hash of the object data by default
+ARROW_CHECK_OK(client_.Seal(object_id));
+```
+
+Step 3: Access the Tensor in a Python Process
+---------------------------------------------
+In Python, we will construct a Plasma client and point it to the store's socket.
+The `inputs` variable will be a list of Object IDs in their raw byte string form.
+
+```python
+import pyarrow as pa
+import pyarrow.plasma as plasma
+
+plasma_client = plasma.connect('/tmp/plasma')
+
+# inputs: a list of object ids
+inputs = [20 * b'1']
+
+# Construct Object ID and perform a batch get
+object_ids = [plasma.ObjectID(inp) for inp in inputs]
+buffers = plasma_client.get_buffers(object_ids)
+
+# Read the tensor and convert to numpy array for each object
+arrs = []
+for buffer in buffers:
+ reader = pa.BufferReader(buffer)
+ t = pa.read_tensor(reader)
+ arr = t.to_numpy()
+ arrs.append(arr)
+
+# arrs is now a list of numpy arrays
+assert np.all(arrs[0] == 2.0 * np.ones(1000, dtype="float32"))
+```
diff --git a/src/arrow/cpp/build-support/asan_symbolize.py b/src/arrow/cpp/build-support/asan_symbolize.py
new file mode 100755
index 000000000..854090ae5
--- /dev/null
+++ b/src/arrow/cpp/build-support/asan_symbolize.py
@@ -0,0 +1,368 @@
+#!/usr/bin/env python3
+#===- lib/asan/scripts/asan_symbolize.py -----------------------------------===#
+#
+# The LLVM Compiler Infrastructure
+#
+# This file is distributed under the University of Illinois Open Source
+# License. See LICENSE.TXT for details.
+#
+#===------------------------------------------------------------------------===#
+import bisect
+import os
+import re
+import subprocess
+import sys
+
+llvm_symbolizer = None
+symbolizers = {}
+filetypes = {}
+vmaddrs = {}
+DEBUG = False
+
+
+# FIXME: merge the code that calls fix_filename().
+def fix_filename(file_name):
+ for path_to_cut in sys.argv[1:]:
+ file_name = re.sub('.*' + path_to_cut, '', file_name)
+ file_name = re.sub('.*asan_[a-z_]*.cc:[0-9]*', '_asan_rtl_', file_name)
+ file_name = re.sub('.*crtstuff.c:0', '???:0', file_name)
+ return file_name
+
+
+class Symbolizer(object):
+ def __init__(self):
+ pass
+
+ def symbolize(self, addr, binary, offset):
+ """Symbolize the given address (pair of binary and offset).
+
+ Overridden in subclasses.
+ Args:
+ addr: virtual address of an instruction.
+ binary: path to executable/shared object containing this instruction.
+ offset: instruction offset in the @binary.
+ Returns:
+ list of strings (one string for each inlined frame) describing
+ the code locations for this instruction (that is, function name, file
+ name, line and column numbers).
+ """
+ return None
+
+
+class LLVMSymbolizer(Symbolizer):
+ def __init__(self, symbolizer_path):
+ super(LLVMSymbolizer, self).__init__()
+ self.symbolizer_path = symbolizer_path
+ self.pipe = self.open_llvm_symbolizer()
+
+ def open_llvm_symbolizer(self):
+ if not os.path.exists(self.symbolizer_path):
+ return None
+ cmd = [self.symbolizer_path,
+ '--use-symbol-table=true',
+ '--demangle=false',
+ '--functions=true',
+ '--inlining=true']
+ if DEBUG:
+ print(' '.join(cmd))
+ return subprocess.Popen(cmd, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE)
+
+ def symbolize(self, addr, binary, offset):
+ """Overrides Symbolizer.symbolize."""
+ if not self.pipe:
+ return None
+ result = []
+ try:
+ symbolizer_input = '%s %s' % (binary, offset)
+ if DEBUG:
+ print(symbolizer_input)
+ self.pipe.stdin.write(symbolizer_input)
+ self.pipe.stdin.write('\n')
+ while True:
+ function_name = self.pipe.stdout.readline().rstrip()
+ if not function_name:
+ break
+ file_name = self.pipe.stdout.readline().rstrip()
+ file_name = fix_filename(file_name)
+ if (not function_name.startswith('??') and
+ not file_name.startswith('??')):
+ # Append only valid frames.
+ result.append('%s in %s %s' % (addr, function_name,
+ file_name))
+ except Exception:
+ result = []
+ if not result:
+ result = None
+ return result
+
+
+def LLVMSymbolizerFactory(system):
+ symbolizer_path = os.getenv('LLVM_SYMBOLIZER_PATH')
+ if not symbolizer_path:
+ # Assume llvm-symbolizer is in PATH.
+ symbolizer_path = 'llvm-symbolizer'
+ return LLVMSymbolizer(symbolizer_path)
+
+
+class Addr2LineSymbolizer(Symbolizer):
+ def __init__(self, binary):
+ super(Addr2LineSymbolizer, self).__init__()
+ self.binary = binary
+ self.pipe = self.open_addr2line()
+
+ def open_addr2line(self):
+ cmd = ['addr2line', '-f', '-e', self.binary]
+ if DEBUG:
+ print(' '.join(cmd))
+ return subprocess.Popen(cmd,
+ stdin=subprocess.PIPE, stdout=subprocess.PIPE)
+
+ def symbolize(self, addr, binary, offset):
+ """Overrides Symbolizer.symbolize."""
+ if self.binary != binary:
+ return None
+ try:
+ self.pipe.stdin.write(offset)
+ self.pipe.stdin.write('\n')
+ function_name = self.pipe.stdout.readline().rstrip()
+ file_name = self.pipe.stdout.readline().rstrip()
+ except Exception:
+ function_name = ''
+ file_name = ''
+ file_name = fix_filename(file_name)
+ return ['%s in %s %s' % (addr, function_name, file_name)]
+
+
+class DarwinSymbolizer(Symbolizer):
+ def __init__(self, addr, binary):
+ super(DarwinSymbolizer, self).__init__()
+ self.binary = binary
+ # Guess which arch we're running. 10 = len('0x') + 8 hex digits.
+ if len(addr) > 10:
+ self.arch = 'x86_64'
+ else:
+ self.arch = 'i386'
+ self.vmaddr = None
+ self.pipe = None
+
+ def write_addr_to_pipe(self, offset):
+ self.pipe.stdin.write('0x%x' % int(offset, 16))
+ self.pipe.stdin.write('\n')
+
+ def open_atos(self):
+ if DEBUG:
+ print('atos -o %s -arch %s' % (self.binary, self.arch))
+ cmdline = ['atos', '-o', self.binary, '-arch', self.arch]
+ self.pipe = subprocess.Popen(cmdline,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+
+ def symbolize(self, addr, binary, offset):
+ """Overrides Symbolizer.symbolize."""
+ if self.binary != binary:
+ return None
+ self.open_atos()
+ self.write_addr_to_pipe(offset)
+ self.pipe.stdin.close()
+ atos_line = self.pipe.stdout.readline().rstrip()
+ # A well-formed atos response looks like this:
+ # foo(type1, type2) (in object.name) (filename.cc:80)
+ match = re.match('^(.*) \(in (.*)\) \((.*:\d*)\)$', atos_line)
+ if DEBUG:
+ print('atos_line: {0}'.format(atos_line))
+ if match:
+ function_name = match.group(1)
+ function_name = re.sub('\(.*?\)', '', function_name)
+ file_name = fix_filename(match.group(3))
+ return ['%s in %s %s' % (addr, function_name, file_name)]
+ else:
+ return ['%s in %s' % (addr, atos_line)]
+
+
+# Chain several symbolizers so that if one symbolizer fails, we fall back
+# to the next symbolizer in chain.
+class ChainSymbolizer(Symbolizer):
+ def __init__(self, symbolizer_list):
+ super(ChainSymbolizer, self).__init__()
+ self.symbolizer_list = symbolizer_list
+
+ def symbolize(self, addr, binary, offset):
+ """Overrides Symbolizer.symbolize."""
+ for symbolizer in self.symbolizer_list:
+ if symbolizer:
+ result = symbolizer.symbolize(addr, binary, offset)
+ if result:
+ return result
+ return None
+
+ def append_symbolizer(self, symbolizer):
+ self.symbolizer_list.append(symbolizer)
+
+
+def BreakpadSymbolizerFactory(binary):
+ suffix = os.getenv('BREAKPAD_SUFFIX')
+ if suffix:
+ filename = binary + suffix
+ if os.access(filename, os.F_OK):
+ return BreakpadSymbolizer(filename)
+ return None
+
+
+def SystemSymbolizerFactory(system, addr, binary):
+ if system == 'Darwin':
+ return DarwinSymbolizer(addr, binary)
+ elif system == 'Linux':
+ return Addr2LineSymbolizer(binary)
+
+
+class BreakpadSymbolizer(Symbolizer):
+ def __init__(self, filename):
+ super(BreakpadSymbolizer, self).__init__()
+ self.filename = filename
+ lines = file(filename).readlines()
+ self.files = []
+ self.symbols = {}
+ self.address_list = []
+ self.addresses = {}
+ # MODULE mac x86_64 A7001116478B33F18FF9BEDE9F615F190 t
+ fragments = lines[0].rstrip().split()
+ self.arch = fragments[2]
+ self.debug_id = fragments[3]
+ self.binary = ' '.join(fragments[4:])
+ self.parse_lines(lines[1:])
+
+ def parse_lines(self, lines):
+ cur_function_addr = ''
+ for line in lines:
+ fragments = line.split()
+ if fragments[0] == 'FILE':
+ assert int(fragments[1]) == len(self.files)
+ self.files.append(' '.join(fragments[2:]))
+ elif fragments[0] == 'PUBLIC':
+ self.symbols[int(fragments[1], 16)] = ' '.join(fragments[3:])
+ elif fragments[0] in ['CFI', 'STACK']:
+ pass
+ elif fragments[0] == 'FUNC':
+ cur_function_addr = int(fragments[1], 16)
+ if not cur_function_addr in self.symbols.keys():
+ self.symbols[cur_function_addr] = ' '.join(fragments[4:])
+ else:
+ # Line starting with an address.
+ addr = int(fragments[0], 16)
+ self.address_list.append(addr)
+ # Tuple of symbol address, size, line, file number.
+ self.addresses[addr] = (cur_function_addr,
+ int(fragments[1], 16),
+ int(fragments[2]),
+ int(fragments[3]))
+ self.address_list.sort()
+
+ def get_sym_file_line(self, addr):
+ key = None
+ if addr in self.addresses.keys():
+ key = addr
+ else:
+ index = bisect.bisect_left(self.address_list, addr)
+ if index == 0:
+ return None
+ else:
+ key = self.address_list[index - 1]
+ sym_id, size, line_no, file_no = self.addresses[key]
+ symbol = self.symbols[sym_id]
+ filename = self.files[file_no]
+ if addr < key + size:
+ return symbol, filename, line_no
+ else:
+ return None
+
+ def symbolize(self, addr, binary, offset):
+ if self.binary != binary:
+ return None
+ res = self.get_sym_file_line(int(offset, 16))
+ if res:
+ function_name, file_name, line_no = res
+ result = ['%s in %s %s:%d' % (
+ addr, function_name, file_name, line_no)]
+ print(result)
+ return result
+ else:
+ return None
+
+
+class SymbolizationLoop(object):
+ def __init__(self, binary_name_filter=None):
+ # Used by clients who may want to supply a different binary name.
+ # E.g. in Chrome several binaries may share a single .dSYM.
+ self.binary_name_filter = binary_name_filter
+ self.system = os.uname()[0]
+ if self.system in ['Linux', 'Darwin']:
+ self.llvm_symbolizer = LLVMSymbolizerFactory(self.system)
+ else:
+ raise Exception('Unknown system')
+
+ def symbolize_address(self, addr, binary, offset):
+ # Use the chain of symbolizers:
+ # Breakpad symbolizer -> LLVM symbolizer -> addr2line/atos
+ # (fall back to next symbolizer if the previous one fails).
+ if not binary in symbolizers:
+ symbolizers[binary] = ChainSymbolizer(
+ [BreakpadSymbolizerFactory(binary), self.llvm_symbolizer])
+ result = symbolizers[binary].symbolize(addr, binary, offset)
+ if result is None:
+ # Initialize system symbolizer only if other symbolizers failed.
+ symbolizers[binary].append_symbolizer(
+ SystemSymbolizerFactory(self.system, addr, binary))
+ result = symbolizers[binary].symbolize(addr, binary, offset)
+ # The system symbolizer must produce some result.
+ assert result
+ return result
+
+ def print_symbolized_lines(self, symbolized_lines):
+ if not symbolized_lines:
+ print(self.current_line)
+ else:
+ for symbolized_frame in symbolized_lines:
+ print(' #' + str(self.frame_no) + ' ' + symbolized_frame.rstrip())
+ self.frame_no += 1
+
+ def process_stdin(self):
+ self.frame_no = 0
+
+ if sys.version_info[0] == 2:
+ sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 0)
+ else:
+ # Unbuffered output is not supported in Python 3
+ sys.stdout = os.fdopen(sys.stdout.fileno(), 'w')
+
+ while True:
+ line = sys.stdin.readline()
+ if not line: break
+ self.current_line = line.rstrip()
+ #0 0x7f6e35cf2e45 (/blah/foo.so+0x11fe45)
+ stack_trace_line_format = (
+ '^( *#([0-9]+) *)(0x[0-9a-f]+) *\((.*)\+(0x[0-9a-f]+)\)')
+ match = re.match(stack_trace_line_format, line)
+ if not match:
+ print(self.current_line)
+ continue
+ if DEBUG:
+ print(line)
+ _, frameno_str, addr, binary, offset = match.groups()
+ if frameno_str == '0':
+ # Assume that frame #0 is the first frame of new stack trace.
+ self.frame_no = 0
+ original_binary = binary
+ if self.binary_name_filter:
+ binary = self.binary_name_filter(binary)
+ symbolized_line = self.symbolize_address(addr, binary, offset)
+ if not symbolized_line:
+ if original_binary != binary:
+ symbolized_line = self.symbolize_address(addr, binary, offset)
+ self.print_symbolized_lines(symbolized_line)
+
+
+if __name__ == '__main__':
+ loop = SymbolizationLoop()
+ loop.process_stdin()
diff --git a/src/arrow/cpp/build-support/build-lz4-lib.sh b/src/arrow/cpp/build-support/build-lz4-lib.sh
new file mode 100755
index 000000000..37c564848
--- /dev/null
+++ b/src/arrow/cpp/build-support/build-lz4-lib.sh
@@ -0,0 +1,32 @@
+#!/usr/bin/env sh
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+# mingw-w64-x86_64-make installs mingw32-make, not make
+MAKE="make"
+if [ -n "$MSYSTEM" ]; then
+ MAKE="mingw32-make"
+fi
+
+export CFLAGS="${CFLAGS} -O3 -fPIC"
+if [ -z "$MAKELEVEL" ]; then
+ "$MAKE" -j4 CFLAGS="$CFLAGS" "$@"
+else
+ "$MAKE" CFLAGS="$CFLAGS" "$@"
+fi
diff --git a/src/arrow/cpp/build-support/build-zstd-lib.sh b/src/arrow/cpp/build-support/build-zstd-lib.sh
new file mode 100755
index 000000000..444e62599
--- /dev/null
+++ b/src/arrow/cpp/build-support/build-zstd-lib.sh
@@ -0,0 +1,25 @@
+#!/usr/bin/env sh
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+export CFLAGS="${CFLAGS} -O3 -fPIC"
+if [ -z "$MAKELEVEL" ]; then
+ make -j4
+else
+ make
+fi
diff --git a/src/arrow/cpp/build-support/cpplint.py b/src/arrow/cpp/build-support/cpplint.py
new file mode 100755
index 000000000..a40c538e7
--- /dev/null
+++ b/src/arrow/cpp/build-support/cpplint.py
@@ -0,0 +1,6477 @@
+#!/usr/bin/env python3
+#
+# Copyright (c) 2009 Google Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Does google-lint on c++ files.
+
+The goal of this script is to identify places in the code that *may*
+be in non-compliance with google style. It does not attempt to fix
+up these problems -- the point is to educate. It does also not
+attempt to find all problems, or to ensure that everything it does
+find is legitimately a problem.
+
+In particular, we can get very confused by /* and // inside strings!
+We do a small hack, which is to ignore //'s with "'s after them on the
+same line, but it is far from perfect (in either direction).
+"""
+
+import codecs
+import copy
+import getopt
+import glob
+import itertools
+import math # for log
+import os
+import re
+import sre_compile
+import string
+import sys
+import unicodedata
+import xml.etree.ElementTree
+
+# if empty, use defaults
+_header_extensions = set([])
+
+# if empty, use defaults
+_valid_extensions = set([])
+
+
+# Files with any of these extensions are considered to be
+# header files (and will undergo different style checks).
+# This set can be extended by using the --headers
+# option (also supported in CPPLINT.cfg)
+def GetHeaderExtensions():
+ if not _header_extensions:
+ return set(['h', 'hpp', 'hxx', 'h++', 'cuh'])
+ return _header_extensions
+
+# The allowed extensions for file names
+# This is set by --extensions flag
+def GetAllExtensions():
+ if not _valid_extensions:
+ return GetHeaderExtensions().union(set(['c', 'cc', 'cpp', 'cxx', 'c++', 'cu']))
+ return _valid_extensions
+
+def GetNonHeaderExtensions():
+ return GetAllExtensions().difference(GetHeaderExtensions())
+
+
+_USAGE = """
+Syntax: cpplint.py [--verbose=#] [--output=emacs|eclipse|vs7|junit]
+ [--filter=-x,+y,...]
+ [--counting=total|toplevel|detailed] [--repository=path]
+ [--root=subdir] [--linelength=digits] [--recursive]
+ [--exclude=path]
+ [--headers=ext1,ext2]
+ [--extensions=hpp,cpp,...]
+ <file> [file] ...
+
+ The style guidelines this tries to follow are those in
+ https://google.github.io/styleguide/cppguide.html
+
+ Every problem is given a confidence score from 1-5, with 5 meaning we are
+ certain of the problem, and 1 meaning it could be a legitimate construct.
+ This will miss some errors, and is not a substitute for a code review.
+
+ To suppress false-positive errors of a certain category, add a
+ 'NOLINT(category)' comment to the line. NOLINT or NOLINT(*)
+ suppresses errors of all categories on that line.
+
+ The files passed in will be linted; at least one file must be provided.
+ Default linted extensions are %s.
+ Other file types will be ignored.
+ Change the extensions with the --extensions flag.
+
+ Flags:
+
+ output=emacs|eclipse|vs7|junit
+ By default, the output is formatted to ease emacs parsing. Output
+ compatible with eclipse (eclipse), Visual Studio (vs7), and JUnit
+ XML parsers such as those used in Jenkins and Bamboo may also be
+ used. Other formats are unsupported.
+
+ verbose=#
+ Specify a number 0-5 to restrict errors to certain verbosity levels.
+ Errors with lower verbosity levels have lower confidence and are more
+ likely to be false positives.
+
+ quiet
+ Suppress output other than linting errors, such as information about
+ which files have been processed and excluded.
+
+ filter=-x,+y,...
+ Specify a comma-separated list of category-filters to apply: only
+ error messages whose category names pass the filters will be printed.
+ (Category names are printed with the message and look like
+ "[whitespace/indent]".) Filters are evaluated left to right.
+ "-FOO" and "FOO" means "do not print categories that start with FOO".
+ "+FOO" means "do print categories that start with FOO".
+
+ Examples: --filter=-whitespace,+whitespace/braces
+ --filter=whitespace,runtime/printf,+runtime/printf_format
+ --filter=-,+build/include_what_you_use
+
+ To see a list of all the categories used in cpplint, pass no arg:
+ --filter=
+
+ counting=total|toplevel|detailed
+ The total number of errors found is always printed. If
+ 'toplevel' is provided, then the count of errors in each of
+ the top-level categories like 'build' and 'whitespace' will
+ also be printed. If 'detailed' is provided, then a count
+ is provided for each category like 'build/class'.
+
+ repository=path
+ The top level directory of the repository, used to derive the header
+ guard CPP variable. By default, this is determined by searching for a
+ path that contains .git, .hg, or .svn. When this flag is specified, the
+ given path is used instead. This option allows the header guard CPP
+ variable to remain consistent even if members of a team have different
+ repository root directories (such as when checking out a subdirectory
+ with SVN). In addition, users of non-mainstream version control systems
+ can use this flag to ensure readable header guard CPP variables.
+
+ Examples:
+ Assuming that Alice checks out ProjectName and Bob checks out
+ ProjectName/trunk and trunk contains src/chrome/ui/browser.h, then
+ with no --repository flag, the header guard CPP variable will be:
+
+ Alice => TRUNK_SRC_CHROME_BROWSER_UI_BROWSER_H_
+ Bob => SRC_CHROME_BROWSER_UI_BROWSER_H_
+
+ If Alice uses the --repository=trunk flag and Bob omits the flag or
+ uses --repository=. then the header guard CPP variable will be:
+
+ Alice => SRC_CHROME_BROWSER_UI_BROWSER_H_
+ Bob => SRC_CHROME_BROWSER_UI_BROWSER_H_
+
+ root=subdir
+ The root directory used for deriving header guard CPP variables. This
+ directory is relative to the top level directory of the repository which
+ by default is determined by searching for a directory that contains .git,
+ .hg, or .svn but can also be controlled with the --repository flag. If
+ the specified directory does not exist, this flag is ignored.
+
+ Examples:
+ Assuming that src is the top level directory of the repository, the
+ header guard CPP variables for src/chrome/browser/ui/browser.h are:
+
+ No flag => CHROME_BROWSER_UI_BROWSER_H_
+ --root=chrome => BROWSER_UI_BROWSER_H_
+ --root=chrome/browser => UI_BROWSER_H_
+
+ linelength=digits
+ This is the allowed line length for the project. The default value is
+ 80 characters.
+
+ Examples:
+ --linelength=120
+
+ recursive
+ Search for files to lint recursively. Each directory given in the list
+ of files to be linted is replaced by all files that descend from that
+ directory. Files with extensions not in the valid extensions list are
+ excluded.
+
+ exclude=path
+ Exclude the given path from the list of files to be linted. Relative
+ paths are evaluated relative to the current directory and shell globbing
+ is performed. This flag can be provided multiple times to exclude
+ multiple files.
+
+ Examples:
+ --exclude=one.cc
+ --exclude=src/*.cc
+ --exclude=src/*.cc --exclude=test/*.cc
+
+ extensions=extension,extension,...
+ The allowed file extensions that cpplint will check
+
+ Examples:
+ --extensions=%s
+
+ headers=extension,extension,...
+ The allowed header extensions that cpplint will consider to be header files
+ (by default, only files with extensions %s
+ will be assumed to be headers)
+
+ Examples:
+ --headers=%s
+
+ cpplint.py supports per-directory configurations specified in CPPLINT.cfg
+ files. CPPLINT.cfg file can contain a number of key=value pairs.
+ Currently the following options are supported:
+
+ set noparent
+ filter=+filter1,-filter2,...
+ exclude_files=regex
+ linelength=80
+ root=subdir
+
+ "set noparent" option prevents cpplint from traversing directory tree
+ upwards looking for more .cfg files in parent directories. This option
+ is usually placed in the top-level project directory.
+
+ The "filter" option is similar in function to --filter flag. It specifies
+ message filters in addition to the |_DEFAULT_FILTERS| and those specified
+ through --filter command-line flag.
+
+ "exclude_files" allows to specify a regular expression to be matched against
+ a file name. If the expression matches, the file is skipped and not run
+ through the linter.
+
+ "linelength" specifies the allowed line length for the project.
+
+ The "root" option is similar in function to the --root flag (see example
+ above).
+
+ CPPLINT.cfg has an effect on files in the same directory and all
+ subdirectories, unless overridden by a nested configuration file.
+
+ Example file:
+ filter=-build/include_order,+build/include_alpha
+ exclude_files=.*\\.cc
+
+ The above example disables build/include_order warning and enables
+ build/include_alpha as well as excludes all .cc from being
+ processed by linter, in the current directory (where the .cfg
+ file is located) and all subdirectories.
+""" % (list(GetAllExtensions()),
+ ','.join(list(GetAllExtensions())),
+ GetHeaderExtensions(),
+ ','.join(GetHeaderExtensions()))
+
+# We categorize each error message we print. Here are the categories.
+# We want an explicit list so we can list them all in cpplint --filter=.
+# If you add a new error message with a new category, add it to the list
+# here! cpplint_unittest.py should tell you if you forget to do this.
+_ERROR_CATEGORIES = [
+ 'build/class',
+ 'build/c++11',
+ 'build/c++14',
+ 'build/c++tr1',
+ 'build/deprecated',
+ 'build/endif_comment',
+ 'build/explicit_make_pair',
+ 'build/forward_decl',
+ 'build/header_guard',
+ 'build/include',
+ 'build/include_subdir',
+ 'build/include_alpha',
+ 'build/include_order',
+ 'build/include_what_you_use',
+ 'build/namespaces_literals',
+ 'build/namespaces',
+ 'build/printf_format',
+ 'build/storage_class',
+ 'legal/copyright',
+ 'readability/alt_tokens',
+ 'readability/braces',
+ 'readability/casting',
+ 'readability/check',
+ 'readability/constructors',
+ 'readability/fn_size',
+ 'readability/inheritance',
+ 'readability/multiline_comment',
+ 'readability/multiline_string',
+ 'readability/namespace',
+ 'readability/nolint',
+ 'readability/nul',
+ 'readability/strings',
+ 'readability/todo',
+ 'readability/utf8',
+ 'runtime/arrays',
+ 'runtime/casting',
+ 'runtime/explicit',
+ 'runtime/int',
+ 'runtime/init',
+ 'runtime/invalid_increment',
+ 'runtime/member_string_references',
+ 'runtime/memset',
+ 'runtime/indentation_namespace',
+ 'runtime/operator',
+ 'runtime/printf',
+ 'runtime/printf_format',
+ 'runtime/references',
+ 'runtime/string',
+ 'runtime/threadsafe_fn',
+ 'runtime/vlog',
+ 'whitespace/blank_line',
+ 'whitespace/braces',
+ 'whitespace/comma',
+ 'whitespace/comments',
+ 'whitespace/empty_conditional_body',
+ 'whitespace/empty_if_body',
+ 'whitespace/empty_loop_body',
+ 'whitespace/end_of_line',
+ 'whitespace/ending_newline',
+ 'whitespace/forcolon',
+ 'whitespace/indent',
+ 'whitespace/line_length',
+ 'whitespace/newline',
+ 'whitespace/operators',
+ 'whitespace/parens',
+ 'whitespace/semicolon',
+ 'whitespace/tab',
+ 'whitespace/todo',
+ ]
+
+# These error categories are no longer enforced by cpplint, but for backwards-
+# compatibility they may still appear in NOLINT comments.
+_LEGACY_ERROR_CATEGORIES = [
+ 'readability/streams',
+ 'readability/function',
+ ]
+
+# The default state of the category filter. This is overridden by the --filter=
+# flag. By default all errors are on, so only add here categories that should be
+# off by default (i.e., categories that must be enabled by the --filter= flags).
+# All entries here should start with a '-' or '+', as in the --filter= flag.
+_DEFAULT_FILTERS = ['-build/include_alpha']
+
+# The default list of categories suppressed for C (not C++) files.
+_DEFAULT_C_SUPPRESSED_CATEGORIES = [
+ 'readability/casting',
+ ]
+
+# The default list of categories suppressed for Linux Kernel files.
+_DEFAULT_KERNEL_SUPPRESSED_CATEGORIES = [
+ 'whitespace/tab',
+ ]
+
+# We used to check for high-bit characters, but after much discussion we
+# decided those were OK, as long as they were in UTF-8 and didn't represent
+# hard-coded international strings, which belong in a separate i18n file.
+
+# C++ headers
+_CPP_HEADERS = frozenset([
+ # Legacy
+ 'algobase.h',
+ 'algo.h',
+ 'alloc.h',
+ 'builtinbuf.h',
+ 'bvector.h',
+ 'complex.h',
+ 'defalloc.h',
+ 'deque.h',
+ 'editbuf.h',
+ 'fstream.h',
+ 'function.h',
+ 'hash_map',
+ 'hash_map.h',
+ 'hash_set',
+ 'hash_set.h',
+ 'hashtable.h',
+ 'heap.h',
+ 'indstream.h',
+ 'iomanip.h',
+ 'iostream.h',
+ 'istream.h',
+ 'iterator.h',
+ 'list.h',
+ 'map.h',
+ 'multimap.h',
+ 'multiset.h',
+ 'ostream.h',
+ 'pair.h',
+ 'parsestream.h',
+ 'pfstream.h',
+ 'procbuf.h',
+ 'pthread_alloc',
+ 'pthread_alloc.h',
+ 'rope',
+ 'rope.h',
+ 'ropeimpl.h',
+ 'set.h',
+ 'slist',
+ 'slist.h',
+ 'stack.h',
+ 'stdiostream.h',
+ 'stl_alloc.h',
+ 'stl_relops.h',
+ 'streambuf.h',
+ 'stream.h',
+ 'strfile.h',
+ 'strstream.h',
+ 'tempbuf.h',
+ 'tree.h',
+ 'type_traits.h',
+ 'vector.h',
+ # 17.6.1.2 C++ library headers
+ 'algorithm',
+ 'array',
+ 'atomic',
+ 'bitset',
+ 'chrono',
+ 'codecvt',
+ 'complex',
+ 'condition_variable',
+ 'deque',
+ 'exception',
+ 'forward_list',
+ 'fstream',
+ 'functional',
+ 'future',
+ 'initializer_list',
+ 'iomanip',
+ 'ios',
+ 'iosfwd',
+ 'iostream',
+ 'istream',
+ 'iterator',
+ 'limits',
+ 'list',
+ 'locale',
+ 'map',
+ 'memory',
+ 'mutex',
+ 'new',
+ 'numeric',
+ 'ostream',
+ 'queue',
+ 'random',
+ 'ratio',
+ 'regex',
+ 'scoped_allocator',
+ 'set',
+ 'sstream',
+ 'stack',
+ 'stdexcept',
+ 'streambuf',
+ 'string',
+ 'strstream',
+ 'system_error',
+ 'thread',
+ 'tuple',
+ 'typeindex',
+ 'typeinfo',
+ 'type_traits',
+ 'unordered_map',
+ 'unordered_set',
+ 'utility',
+ 'valarray',
+ 'vector',
+ # 17.6.1.2 C++ headers for C library facilities
+ 'cassert',
+ 'ccomplex',
+ 'cctype',
+ 'cerrno',
+ 'cfenv',
+ 'cfloat',
+ 'cinttypes',
+ 'ciso646',
+ 'climits',
+ 'clocale',
+ 'cmath',
+ 'csetjmp',
+ 'csignal',
+ 'cstdalign',
+ 'cstdarg',
+ 'cstdbool',
+ 'cstddef',
+ 'cstdint',
+ 'cstdio',
+ 'cstdlib',
+ 'cstring',
+ 'ctgmath',
+ 'ctime',
+ 'cuchar',
+ 'cwchar',
+ 'cwctype',
+ ])
+
+# Type names
+_TYPES = re.compile(
+ r'^(?:'
+ # [dcl.type.simple]
+ r'(char(16_t|32_t)?)|wchar_t|'
+ r'bool|short|int|long|signed|unsigned|float|double|'
+ # [support.types]
+ r'(ptrdiff_t|size_t|max_align_t|nullptr_t)|'
+ # [cstdint.syn]
+ r'(u?int(_fast|_least)?(8|16|32|64)_t)|'
+ r'(u?int(max|ptr)_t)|'
+ r')$')
+
+
+# These headers are excluded from [build/include] and [build/include_order]
+# checks:
+# - Anything not following google file name conventions (containing an
+# uppercase character, such as Python.h or nsStringAPI.h, for example).
+# - Lua headers.
+_THIRD_PARTY_HEADERS_PATTERN = re.compile(
+ r'^(?:[^/]*[A-Z][^/]*\.h|lua\.h|lauxlib\.h|lualib\.h)$')
+
+# Pattern for matching FileInfo.BaseName() against test file name
+_test_suffixes = ['_test', '_regtest', '_unittest']
+_TEST_FILE_SUFFIX = '(' + '|'.join(_test_suffixes) + r')$'
+
+# Pattern that matches only complete whitespace, possibly across multiple lines.
+_EMPTY_CONDITIONAL_BODY_PATTERN = re.compile(r'^\s*$', re.DOTALL)
+
+# Assertion macros. These are defined in base/logging.h and
+# testing/base/public/gunit.h.
+_CHECK_MACROS = [
+ 'DCHECK', 'CHECK',
+ 'EXPECT_TRUE', 'ASSERT_TRUE',
+ 'EXPECT_FALSE', 'ASSERT_FALSE',
+ ]
+
+# Replacement macros for CHECK/DCHECK/EXPECT_TRUE/EXPECT_FALSE
+_CHECK_REPLACEMENT = dict([(macro_var, {}) for macro_var in _CHECK_MACROS])
+
+for op, replacement in [('==', 'EQ'), ('!=', 'NE'),
+ ('>=', 'GE'), ('>', 'GT'),
+ ('<=', 'LE'), ('<', 'LT')]:
+ _CHECK_REPLACEMENT['DCHECK'][op] = 'DCHECK_%s' % replacement
+ _CHECK_REPLACEMENT['CHECK'][op] = 'CHECK_%s' % replacement
+ _CHECK_REPLACEMENT['EXPECT_TRUE'][op] = 'EXPECT_%s' % replacement
+ _CHECK_REPLACEMENT['ASSERT_TRUE'][op] = 'ASSERT_%s' % replacement
+
+for op, inv_replacement in [('==', 'NE'), ('!=', 'EQ'),
+ ('>=', 'LT'), ('>', 'LE'),
+ ('<=', 'GT'), ('<', 'GE')]:
+ _CHECK_REPLACEMENT['EXPECT_FALSE'][op] = 'EXPECT_%s' % inv_replacement
+ _CHECK_REPLACEMENT['ASSERT_FALSE'][op] = 'ASSERT_%s' % inv_replacement
+
+# Alternative tokens and their replacements. For full list, see section 2.5
+# Alternative tokens [lex.digraph] in the C++ standard.
+#
+# Digraphs (such as '%:') are not included here since it's a mess to
+# match those on a word boundary.
+_ALT_TOKEN_REPLACEMENT = {
+ 'and': '&&',
+ 'bitor': '|',
+ 'or': '||',
+ 'xor': '^',
+ 'compl': '~',
+ 'bitand': '&',
+ 'and_eq': '&=',
+ 'or_eq': '|=',
+ 'xor_eq': '^=',
+ 'not': '!',
+ 'not_eq': '!='
+ }
+
+# Compile regular expression that matches all the above keywords. The "[ =()]"
+# bit is meant to avoid matching these keywords outside of boolean expressions.
+#
+# False positives include C-style multi-line comments and multi-line strings
+# but those have always been troublesome for cpplint.
+_ALT_TOKEN_REPLACEMENT_PATTERN = re.compile(
+ r'[ =()](' + ('|'.join(_ALT_TOKEN_REPLACEMENT.keys())) + r')(?=[ (]|$)')
+
+
+# These constants define types of headers for use with
+# _IncludeState.CheckNextIncludeOrder().
+_C_SYS_HEADER = 1
+_CPP_SYS_HEADER = 2
+_LIKELY_MY_HEADER = 3
+_POSSIBLE_MY_HEADER = 4
+_OTHER_HEADER = 5
+
+# These constants define the current inline assembly state
+_NO_ASM = 0 # Outside of inline assembly block
+_INSIDE_ASM = 1 # Inside inline assembly block
+_END_ASM = 2 # Last line of inline assembly block
+_BLOCK_ASM = 3 # The whole block is an inline assembly block
+
+# Match start of assembly blocks
+_MATCH_ASM = re.compile(r'^\s*(?:asm|_asm|__asm|__asm__)'
+ r'(?:\s+(volatile|__volatile__))?'
+ r'\s*[{(]')
+
+# Match strings that indicate we're working on a C (not C++) file.
+_SEARCH_C_FILE = re.compile(r'\b(?:LINT_C_FILE|'
+ r'vim?:\s*.*(\s*|:)filetype=c(\s*|:|$))')
+
+# Match string that indicates we're working on a Linux Kernel file.
+_SEARCH_KERNEL_FILE = re.compile(r'\b(?:LINT_KERNEL_FILE)')
+
+_regexp_compile_cache = {}
+
+# {str, set(int)}: a map from error categories to sets of linenumbers
+# on which those errors are expected and should be suppressed.
+_error_suppressions = {}
+
+# The root directory used for deriving header guard CPP variable.
+# This is set by --root flag.
+_root = None
+
+# The top level repository directory. If set, _root is calculated relative to
+# this directory instead of the directory containing version control artifacts.
+# This is set by the --repository flag.
+_repository = None
+
+# Files to exclude from linting. This is set by the --exclude flag.
+_excludes = None
+
+# Whether to suppress PrintInfo messages
+_quiet = False
+
+# The allowed line length of files.
+# This is set by --linelength flag.
+_line_length = 80
+
+try:
+ xrange(1, 0)
+except NameError:
+ # -- pylint: disable=redefined-builtin
+ xrange = range
+
+try:
+ unicode
+except NameError:
+ # -- pylint: disable=redefined-builtin
+ basestring = unicode = str
+
+try:
+ long(2)
+except NameError:
+ # -- pylint: disable=redefined-builtin
+ long = int
+
+if sys.version_info < (3,):
+ # -- pylint: disable=no-member
+ # BINARY_TYPE = str
+ itervalues = dict.itervalues
+ iteritems = dict.iteritems
+else:
+ # BINARY_TYPE = bytes
+ itervalues = dict.values
+ iteritems = dict.items
+
+def unicode_escape_decode(x):
+ if sys.version_info < (3,):
+ return codecs.unicode_escape_decode(x)[0]
+ else:
+ return x
+
+# {str, bool}: a map from error categories to booleans which indicate if the
+# category should be suppressed for every line.
+_global_error_suppressions = {}
+
+
+
+
+def ParseNolintSuppressions(filename, raw_line, linenum, error):
+ """Updates the global list of line error-suppressions.
+
+ Parses any NOLINT comments on the current line, updating the global
+ error_suppressions store. Reports an error if the NOLINT comment
+ was malformed.
+
+ Args:
+ filename: str, the name of the input file.
+ raw_line: str, the line of input text, with comments.
+ linenum: int, the number of the current line.
+ error: function, an error handler.
+ """
+ matched = Search(r'\bNOLINT(NEXTLINE)?\b(\([^)]+\))?', raw_line)
+ if matched:
+ if matched.group(1):
+ suppressed_line = linenum + 1
+ else:
+ suppressed_line = linenum
+ category = matched.group(2)
+ if category in (None, '(*)'): # => "suppress all"
+ _error_suppressions.setdefault(None, set()).add(suppressed_line)
+ else:
+ if category.startswith('(') and category.endswith(')'):
+ category = category[1:-1]
+ if category in _ERROR_CATEGORIES:
+ _error_suppressions.setdefault(category, set()).add(suppressed_line)
+ elif category not in _LEGACY_ERROR_CATEGORIES:
+ error(filename, linenum, 'readability/nolint', 5,
+ 'Unknown NOLINT error category: %s' % category)
+
+
+def ProcessGlobalSuppresions(lines):
+ """Updates the list of global error suppressions.
+
+ Parses any lint directives in the file that have global effect.
+
+ Args:
+ lines: An array of strings, each representing a line of the file, with the
+ last element being empty if the file is terminated with a newline.
+ """
+ for line in lines:
+ if _SEARCH_C_FILE.search(line):
+ for category in _DEFAULT_C_SUPPRESSED_CATEGORIES:
+ _global_error_suppressions[category] = True
+ if _SEARCH_KERNEL_FILE.search(line):
+ for category in _DEFAULT_KERNEL_SUPPRESSED_CATEGORIES:
+ _global_error_suppressions[category] = True
+
+
+def ResetNolintSuppressions():
+ """Resets the set of NOLINT suppressions to empty."""
+ _error_suppressions.clear()
+ _global_error_suppressions.clear()
+
+
+def IsErrorSuppressedByNolint(category, linenum):
+ """Returns true if the specified error category is suppressed on this line.
+
+ Consults the global error_suppressions map populated by
+ ParseNolintSuppressions/ProcessGlobalSuppresions/ResetNolintSuppressions.
+
+ Args:
+ category: str, the category of the error.
+ linenum: int, the current line number.
+ Returns:
+ bool, True iff the error should be suppressed due to a NOLINT comment or
+ global suppression.
+ """
+ return (_global_error_suppressions.get(category, False) or
+ linenum in _error_suppressions.get(category, set()) or
+ linenum in _error_suppressions.get(None, set()))
+
+
+def Match(pattern, s):
+ """Matches the string with the pattern, caching the compiled regexp."""
+ # The regexp compilation caching is inlined in both Match and Search for
+ # performance reasons; factoring it out into a separate function turns out
+ # to be noticeably expensive.
+ if pattern not in _regexp_compile_cache:
+ _regexp_compile_cache[pattern] = sre_compile.compile(pattern)
+ return _regexp_compile_cache[pattern].match(s)
+
+
+def ReplaceAll(pattern, rep, s):
+ """Replaces instances of pattern in a string with a replacement.
+
+ The compiled regex is kept in a cache shared by Match and Search.
+
+ Args:
+ pattern: regex pattern
+ rep: replacement text
+ s: search string
+
+ Returns:
+ string with replacements made (or original string if no replacements)
+ """
+ if pattern not in _regexp_compile_cache:
+ _regexp_compile_cache[pattern] = sre_compile.compile(pattern)
+ return _regexp_compile_cache[pattern].sub(rep, s)
+
+
+def Search(pattern, s):
+ """Searches the string for the pattern, caching the compiled regexp."""
+ if pattern not in _regexp_compile_cache:
+ _regexp_compile_cache[pattern] = sre_compile.compile(pattern)
+ return _regexp_compile_cache[pattern].search(s)
+
+
+def _IsSourceExtension(s):
+ """File extension (excluding dot) matches a source file extension."""
+ return s in GetNonHeaderExtensions()
+
+
+class _IncludeState(object):
+ """Tracks line numbers for includes, and the order in which includes appear.
+
+ include_list contains list of lists of (header, line number) pairs.
+ It's a lists of lists rather than just one flat list to make it
+ easier to update across preprocessor boundaries.
+
+ Call CheckNextIncludeOrder() once for each header in the file, passing
+ in the type constants defined above. Calls in an illegal order will
+ raise an _IncludeError with an appropriate error message.
+
+ """
+ # self._section will move monotonically through this set. If it ever
+ # needs to move backwards, CheckNextIncludeOrder will raise an error.
+ _INITIAL_SECTION = 0
+ _MY_H_SECTION = 1
+ _C_SECTION = 2
+ _CPP_SECTION = 3
+ _OTHER_H_SECTION = 4
+
+ _TYPE_NAMES = {
+ _C_SYS_HEADER: 'C system header',
+ _CPP_SYS_HEADER: 'C++ system header',
+ _LIKELY_MY_HEADER: 'header this file implements',
+ _POSSIBLE_MY_HEADER: 'header this file may implement',
+ _OTHER_HEADER: 'other header',
+ }
+ _SECTION_NAMES = {
+ _INITIAL_SECTION: "... nothing. (This can't be an error.)",
+ _MY_H_SECTION: 'a header this file implements',
+ _C_SECTION: 'C system header',
+ _CPP_SECTION: 'C++ system header',
+ _OTHER_H_SECTION: 'other header',
+ }
+
+ def __init__(self):
+ self.include_list = [[]]
+ self._section = None
+ self._last_header = None
+ self.ResetSection('')
+
+ def FindHeader(self, header):
+ """Check if a header has already been included.
+
+ Args:
+ header: header to check.
+ Returns:
+ Line number of previous occurrence, or -1 if the header has not
+ been seen before.
+ """
+ for section_list in self.include_list:
+ for f in section_list:
+ if f[0] == header:
+ return f[1]
+ return -1
+
+ def ResetSection(self, directive):
+ """Reset section checking for preprocessor directive.
+
+ Args:
+ directive: preprocessor directive (e.g. "if", "else").
+ """
+ # The name of the current section.
+ self._section = self._INITIAL_SECTION
+ # The path of last found header.
+ self._last_header = ''
+
+ # Update list of includes. Note that we never pop from the
+ # include list.
+ if directive in ('if', 'ifdef', 'ifndef'):
+ self.include_list.append([])
+ elif directive in ('else', 'elif'):
+ self.include_list[-1] = []
+
+ def SetLastHeader(self, header_path):
+ self._last_header = header_path
+
+ def CanonicalizeAlphabeticalOrder(self, header_path):
+ """Returns a path canonicalized for alphabetical comparison.
+
+ - replaces "-" with "_" so they both cmp the same.
+ - removes '-inl' since we don't require them to be after the main header.
+ - lowercase everything, just in case.
+
+ Args:
+ header_path: Path to be canonicalized.
+
+ Returns:
+ Canonicalized path.
+ """
+ return header_path.replace('-inl.h', '.h').replace('-', '_').lower()
+
+ def IsInAlphabeticalOrder(self, clean_lines, linenum, header_path):
+ """Check if a header is in alphabetical order with the previous header.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ header_path: Canonicalized header to be checked.
+
+ Returns:
+ Returns true if the header is in alphabetical order.
+ """
+ # If previous section is different from current section, _last_header will
+ # be reset to empty string, so it's always less than current header.
+ #
+ # If previous line was a blank line, assume that the headers are
+ # intentionally sorted the way they are.
+ if (self._last_header > header_path and
+ Match(r'^\s*#\s*include\b', clean_lines.elided[linenum - 1])):
+ return False
+ return True
+
+ def CheckNextIncludeOrder(self, header_type):
+ """Returns a non-empty error message if the next header is out of order.
+
+ This function also updates the internal state to be ready to check
+ the next include.
+
+ Args:
+ header_type: One of the _XXX_HEADER constants defined above.
+
+ Returns:
+ The empty string if the header is in the right order, or an
+ error message describing what's wrong.
+
+ """
+ error_message = ('Found %s after %s' %
+ (self._TYPE_NAMES[header_type],
+ self._SECTION_NAMES[self._section]))
+
+ last_section = self._section
+
+ if header_type == _C_SYS_HEADER:
+ if self._section <= self._C_SECTION:
+ self._section = self._C_SECTION
+ else:
+ self._last_header = ''
+ return error_message
+ elif header_type == _CPP_SYS_HEADER:
+ if self._section <= self._CPP_SECTION:
+ self._section = self._CPP_SECTION
+ else:
+ self._last_header = ''
+ return error_message
+ elif header_type == _LIKELY_MY_HEADER:
+ if self._section <= self._MY_H_SECTION:
+ self._section = self._MY_H_SECTION
+ else:
+ self._section = self._OTHER_H_SECTION
+ elif header_type == _POSSIBLE_MY_HEADER:
+ if self._section <= self._MY_H_SECTION:
+ self._section = self._MY_H_SECTION
+ else:
+ # This will always be the fallback because we're not sure
+ # enough that the header is associated with this file.
+ self._section = self._OTHER_H_SECTION
+ else:
+ assert header_type == _OTHER_HEADER
+ self._section = self._OTHER_H_SECTION
+
+ if last_section != self._section:
+ self._last_header = ''
+
+ return ''
+
+
+class _CppLintState(object):
+ """Maintains module-wide state.."""
+
+ def __init__(self):
+ self.verbose_level = 1 # global setting.
+ self.error_count = 0 # global count of reported errors
+ # filters to apply when emitting error messages
+ self.filters = _DEFAULT_FILTERS[:]
+ # backup of filter list. Used to restore the state after each file.
+ self._filters_backup = self.filters[:]
+ self.counting = 'total' # In what way are we counting errors?
+ self.errors_by_category = {} # string to int dict storing error counts
+
+ # output format:
+ # "emacs" - format that emacs can parse (default)
+ # "eclipse" - format that eclipse can parse
+ # "vs7" - format that Microsoft Visual Studio 7 can parse
+ # "junit" - format that Jenkins, Bamboo, etc can parse
+ self.output_format = 'emacs'
+
+ # For JUnit output, save errors and failures until the end so that they
+ # can be written into the XML
+ self._junit_errors = []
+ self._junit_failures = []
+
+ def SetOutputFormat(self, output_format):
+ """Sets the output format for errors."""
+ self.output_format = output_format
+
+ def SetVerboseLevel(self, level):
+ """Sets the module's verbosity, and returns the previous setting."""
+ last_verbose_level = self.verbose_level
+ self.verbose_level = level
+ return last_verbose_level
+
+ def SetCountingStyle(self, counting_style):
+ """Sets the module's counting options."""
+ self.counting = counting_style
+
+ def SetFilters(self, filters):
+ """Sets the error-message filters.
+
+ These filters are applied when deciding whether to emit a given
+ error message.
+
+ Args:
+ filters: A string of comma-separated filters (eg "+whitespace/indent").
+ Each filter should start with + or -; else we die.
+
+ Raises:
+ ValueError: The comma-separated filters did not all start with '+' or '-'.
+ E.g. "-,+whitespace,-whitespace/indent,whitespace/badfilter"
+ """
+ # Default filters always have less priority than the flag ones.
+ self.filters = _DEFAULT_FILTERS[:]
+ self.AddFilters(filters)
+
+ def AddFilters(self, filters):
+ """ Adds more filters to the existing list of error-message filters. """
+ for filt in filters.split(','):
+ clean_filt = filt.strip()
+ if clean_filt:
+ self.filters.append(clean_filt)
+ for filt in self.filters:
+ if not (filt.startswith('+') or filt.startswith('-')):
+ raise ValueError('Every filter in --filters must start with + or -'
+ ' (%s does not)' % filt)
+
+ def BackupFilters(self):
+ """ Saves the current filter list to backup storage."""
+ self._filters_backup = self.filters[:]
+
+ def RestoreFilters(self):
+ """ Restores filters previously backed up."""
+ self.filters = self._filters_backup[:]
+
+ def ResetErrorCounts(self):
+ """Sets the module's error statistic back to zero."""
+ self.error_count = 0
+ self.errors_by_category = {}
+
+ def IncrementErrorCount(self, category):
+ """Bumps the module's error statistic."""
+ self.error_count += 1
+ if self.counting in ('toplevel', 'detailed'):
+ if self.counting != 'detailed':
+ category = category.split('/')[0]
+ if category not in self.errors_by_category:
+ self.errors_by_category[category] = 0
+ self.errors_by_category[category] += 1
+
+ def PrintErrorCounts(self):
+ """Print a summary of errors by category, and the total."""
+ for category, count in sorted(iteritems(self.errors_by_category)):
+ self.PrintInfo('Category \'%s\' errors found: %d\n' %
+ (category, count))
+ if self.error_count > 0:
+ self.PrintInfo('Total errors found: %d\n' % self.error_count)
+
+ def PrintInfo(self, message):
+ if not _quiet and self.output_format != 'junit':
+ sys.stderr.write(message)
+
+ def PrintError(self, message):
+ if self.output_format == 'junit':
+ self._junit_errors.append(message)
+ else:
+ sys.stderr.write(message)
+
+ def AddJUnitFailure(self, filename, linenum, message, category, confidence):
+ self._junit_failures.append((filename, linenum, message, category,
+ confidence))
+
+ def FormatJUnitXML(self):
+ num_errors = len(self._junit_errors)
+ num_failures = len(self._junit_failures)
+
+ testsuite = xml.etree.ElementTree.Element('testsuite')
+ testsuite.attrib['name'] = 'cpplint'
+ testsuite.attrib['errors'] = str(num_errors)
+ testsuite.attrib['failures'] = str(num_failures)
+
+ if num_errors == 0 and num_failures == 0:
+ testsuite.attrib['tests'] = str(1)
+ xml.etree.ElementTree.SubElement(testsuite, 'testcase', name='passed')
+
+ else:
+ testsuite.attrib['tests'] = str(num_errors + num_failures)
+ if num_errors > 0:
+ testcase = xml.etree.ElementTree.SubElement(testsuite, 'testcase')
+ testcase.attrib['name'] = 'errors'
+ error = xml.etree.ElementTree.SubElement(testcase, 'error')
+ error.text = '\n'.join(self._junit_errors)
+ if num_failures > 0:
+ # Group failures by file
+ failed_file_order = []
+ failures_by_file = {}
+ for failure in self._junit_failures:
+ failed_file = failure[0]
+ if failed_file not in failed_file_order:
+ failed_file_order.append(failed_file)
+ failures_by_file[failed_file] = []
+ failures_by_file[failed_file].append(failure)
+ # Create a testcase for each file
+ for failed_file in failed_file_order:
+ failures = failures_by_file[failed_file]
+ testcase = xml.etree.ElementTree.SubElement(testsuite, 'testcase')
+ testcase.attrib['name'] = failed_file
+ failure = xml.etree.ElementTree.SubElement(testcase, 'failure')
+ template = '{0}: {1} [{2}] [{3}]'
+ texts = [template.format(f[1], f[2], f[3], f[4]) for f in failures]
+ failure.text = '\n'.join(texts)
+
+ xml_decl = '<?xml version="1.0" encoding="UTF-8" ?>\n'
+ return xml_decl + xml.etree.ElementTree.tostring(testsuite, 'utf-8').decode('utf-8')
+
+
+_cpplint_state = _CppLintState()
+
+
+def _OutputFormat():
+ """Gets the module's output format."""
+ return _cpplint_state.output_format
+
+
+def _SetOutputFormat(output_format):
+ """Sets the module's output format."""
+ _cpplint_state.SetOutputFormat(output_format)
+
+
+def _VerboseLevel():
+ """Returns the module's verbosity setting."""
+ return _cpplint_state.verbose_level
+
+
+def _SetVerboseLevel(level):
+ """Sets the module's verbosity, and returns the previous setting."""
+ return _cpplint_state.SetVerboseLevel(level)
+
+
+def _SetCountingStyle(level):
+ """Sets the module's counting options."""
+ _cpplint_state.SetCountingStyle(level)
+
+
+def _Filters():
+ """Returns the module's list of output filters, as a list."""
+ return _cpplint_state.filters
+
+
+def _SetFilters(filters):
+ """Sets the module's error-message filters.
+
+ These filters are applied when deciding whether to emit a given
+ error message.
+
+ Args:
+ filters: A string of comma-separated filters (eg "whitespace/indent").
+ Each filter should start with + or -; else we die.
+ """
+ _cpplint_state.SetFilters(filters)
+
+def _AddFilters(filters):
+ """Adds more filter overrides.
+
+ Unlike _SetFilters, this function does not reset the current list of filters
+ available.
+
+ Args:
+ filters: A string of comma-separated filters (eg "whitespace/indent").
+ Each filter should start with + or -; else we die.
+ """
+ _cpplint_state.AddFilters(filters)
+
+def _BackupFilters():
+ """ Saves the current filter list to backup storage."""
+ _cpplint_state.BackupFilters()
+
+def _RestoreFilters():
+ """ Restores filters previously backed up."""
+ _cpplint_state.RestoreFilters()
+
+class _FunctionState(object):
+ """Tracks current function name and the number of lines in its body."""
+
+ _NORMAL_TRIGGER = 250 # for --v=0, 500 for --v=1, etc.
+ _TEST_TRIGGER = 400 # about 50% more than _NORMAL_TRIGGER.
+
+ def __init__(self):
+ self.in_a_function = False
+ self.lines_in_function = 0
+ self.current_function = ''
+
+ def Begin(self, function_name):
+ """Start analyzing function body.
+
+ Args:
+ function_name: The name of the function being tracked.
+ """
+ self.in_a_function = True
+ self.lines_in_function = 0
+ self.current_function = function_name
+
+ def Count(self):
+ """Count line in current function body."""
+ if self.in_a_function:
+ self.lines_in_function += 1
+
+ def Check(self, error, filename, linenum):
+ """Report if too many lines in function body.
+
+ Args:
+ error: The function to call with any errors found.
+ filename: The name of the current file.
+ linenum: The number of the line to check.
+ """
+ if not self.in_a_function:
+ return
+
+ if Match(r'T(EST|est)', self.current_function):
+ base_trigger = self._TEST_TRIGGER
+ else:
+ base_trigger = self._NORMAL_TRIGGER
+ trigger = base_trigger * 2**_VerboseLevel()
+
+ if self.lines_in_function > trigger:
+ error_level = int(math.log(self.lines_in_function / base_trigger, 2))
+ # 50 => 0, 100 => 1, 200 => 2, 400 => 3, 800 => 4, 1600 => 5, ...
+ if error_level > 5:
+ error_level = 5
+ error(filename, linenum, 'readability/fn_size', error_level,
+ 'Small and focused functions are preferred:'
+ ' %s has %d non-comment lines'
+ ' (error triggered by exceeding %d lines).' % (
+ self.current_function, self.lines_in_function, trigger))
+
+ def End(self):
+ """Stop analyzing function body."""
+ self.in_a_function = False
+
+
+class _IncludeError(Exception):
+ """Indicates a problem with the include order in a file."""
+ pass
+
+
+class FileInfo(object):
+ """Provides utility functions for filenames.
+
+ FileInfo provides easy access to the components of a file's path
+ relative to the project root.
+ """
+
+ def __init__(self, filename):
+ self._filename = filename
+
+ def FullName(self):
+ """Make Windows paths like Unix."""
+ return os.path.abspath(self._filename).replace('\\', '/')
+
+ def RepositoryName(self):
+ r"""FullName after removing the local path to the repository.
+
+ If we have a real absolute path name here we can try to do something smart:
+ detecting the root of the checkout and truncating /path/to/checkout from
+ the name so that we get header guards that don't include things like
+ "C:\Documents and Settings\..." or "/home/username/..." in them and thus
+ people on different computers who have checked the source out to different
+ locations won't see bogus errors.
+ """
+ fullname = self.FullName()
+
+ if os.path.exists(fullname):
+ project_dir = os.path.dirname(fullname)
+
+ # If the user specified a repository path, it exists, and the file is
+ # contained in it, use the specified repository path
+ if _repository:
+ repo = FileInfo(_repository).FullName()
+ root_dir = project_dir
+ while os.path.exists(root_dir):
+ # allow case insensitive compare on Windows
+ if os.path.normcase(root_dir) == os.path.normcase(repo):
+ return os.path.relpath(fullname, root_dir).replace('\\', '/')
+ one_up_dir = os.path.dirname(root_dir)
+ if one_up_dir == root_dir:
+ break
+ root_dir = one_up_dir
+
+ if os.path.exists(os.path.join(project_dir, ".svn")):
+ # If there's a .svn file in the current directory, we recursively look
+ # up the directory tree for the top of the SVN checkout
+ root_dir = project_dir
+ one_up_dir = os.path.dirname(root_dir)
+ while os.path.exists(os.path.join(one_up_dir, ".svn")):
+ root_dir = os.path.dirname(root_dir)
+ one_up_dir = os.path.dirname(one_up_dir)
+
+ prefix = os.path.commonprefix([root_dir, project_dir])
+ return fullname[len(prefix) + 1:]
+
+ # Not SVN <= 1.6? Try to find a git, hg, or svn top level directory by
+ # searching up from the current path.
+ root_dir = current_dir = os.path.dirname(fullname)
+ while current_dir != os.path.dirname(current_dir):
+ if (os.path.exists(os.path.join(current_dir, ".git")) or
+ os.path.exists(os.path.join(current_dir, ".hg")) or
+ os.path.exists(os.path.join(current_dir, ".svn"))):
+ root_dir = current_dir
+ current_dir = os.path.dirname(current_dir)
+
+ if (os.path.exists(os.path.join(root_dir, ".git")) or
+ os.path.exists(os.path.join(root_dir, ".hg")) or
+ os.path.exists(os.path.join(root_dir, ".svn"))):
+ prefix = os.path.commonprefix([root_dir, project_dir])
+ return fullname[len(prefix) + 1:]
+
+ # Don't know what to do; header guard warnings may be wrong...
+ return fullname
+
+ def Split(self):
+ """Splits the file into the directory, basename, and extension.
+
+ For 'chrome/browser/browser.cc', Split() would
+ return ('chrome/browser', 'browser', '.cc')
+
+ Returns:
+ A tuple of (directory, basename, extension).
+ """
+
+ googlename = self.RepositoryName()
+ project, rest = os.path.split(googlename)
+ return (project,) + os.path.splitext(rest)
+
+ def BaseName(self):
+ """File base name - text after the final slash, before the final period."""
+ return self.Split()[1]
+
+ def Extension(self):
+ """File extension - text following the final period, includes that period."""
+ return self.Split()[2]
+
+ def NoExtension(self):
+ """File has no source file extension."""
+ return '/'.join(self.Split()[0:2])
+
+ def IsSource(self):
+ """File has a source file extension."""
+ return _IsSourceExtension(self.Extension()[1:])
+
+
+def _ShouldPrintError(category, confidence, linenum):
+ """If confidence >= verbose, category passes filter and is not suppressed."""
+
+ # There are three ways we might decide not to print an error message:
+ # a "NOLINT(category)" comment appears in the source,
+ # the verbosity level isn't high enough, or the filters filter it out.
+ if IsErrorSuppressedByNolint(category, linenum):
+ return False
+
+ if confidence < _cpplint_state.verbose_level:
+ return False
+
+ is_filtered = False
+ for one_filter in _Filters():
+ if one_filter.startswith('-'):
+ if category.startswith(one_filter[1:]):
+ is_filtered = True
+ elif one_filter.startswith('+'):
+ if category.startswith(one_filter[1:]):
+ is_filtered = False
+ else:
+ assert False # should have been checked for in SetFilter.
+ if is_filtered:
+ return False
+
+ return True
+
+
+def Error(filename, linenum, category, confidence, message):
+ """Logs the fact we've found a lint error.
+
+ We log where the error was found, and also our confidence in the error,
+ that is, how certain we are this is a legitimate style regression, and
+ not a misidentification or a use that's sometimes justified.
+
+ False positives can be suppressed by the use of
+ "cpplint(category)" comments on the offending line. These are
+ parsed into _error_suppressions.
+
+ Args:
+ filename: The name of the file containing the error.
+ linenum: The number of the line containing the error.
+ category: A string used to describe the "category" this bug
+ falls under: "whitespace", say, or "runtime". Categories
+ may have a hierarchy separated by slashes: "whitespace/indent".
+ confidence: A number from 1-5 representing a confidence score for
+ the error, with 5 meaning that we are certain of the problem,
+ and 1 meaning that it could be a legitimate construct.
+ message: The error message.
+ """
+ if _ShouldPrintError(category, confidence, linenum):
+ _cpplint_state.IncrementErrorCount(category)
+ if _cpplint_state.output_format == 'vs7':
+ _cpplint_state.PrintError('%s(%s): warning: %s [%s] [%d]\n' % (
+ filename, linenum, message, category, confidence))
+ elif _cpplint_state.output_format == 'eclipse':
+ sys.stderr.write('%s:%s: warning: %s [%s] [%d]\n' % (
+ filename, linenum, message, category, confidence))
+ elif _cpplint_state.output_format == 'junit':
+ _cpplint_state.AddJUnitFailure(filename, linenum, message, category,
+ confidence)
+ else:
+ final_message = '%s:%s: %s [%s] [%d]\n' % (
+ filename, linenum, message, category, confidence)
+ sys.stderr.write(final_message)
+
+# Matches standard C++ escape sequences per 2.13.2.3 of the C++ standard.
+_RE_PATTERN_CLEANSE_LINE_ESCAPES = re.compile(
+ r'\\([abfnrtv?"\\\']|\d+|x[0-9a-fA-F]+)')
+# Match a single C style comment on the same line.
+_RE_PATTERN_C_COMMENTS = r'/\*(?:[^*]|\*(?!/))*\*/'
+# Matches multi-line C style comments.
+# This RE is a little bit more complicated than one might expect, because we
+# have to take care of space removals tools so we can handle comments inside
+# statements better.
+# The current rule is: We only clear spaces from both sides when we're at the
+# end of the line. Otherwise, we try to remove spaces from the right side,
+# if this doesn't work we try on left side but only if there's a non-character
+# on the right.
+_RE_PATTERN_CLEANSE_LINE_C_COMMENTS = re.compile(
+ r'(\s*' + _RE_PATTERN_C_COMMENTS + r'\s*$|' +
+ _RE_PATTERN_C_COMMENTS + r'\s+|' +
+ r'\s+' + _RE_PATTERN_C_COMMENTS + r'(?=\W)|' +
+ _RE_PATTERN_C_COMMENTS + r')')
+
+
+def IsCppString(line):
+ """Does line terminate so, that the next symbol is in string constant.
+
+ This function does not consider single-line nor multi-line comments.
+
+ Args:
+ line: is a partial line of code starting from the 0..n.
+
+ Returns:
+ True, if next character appended to 'line' is inside a
+ string constant.
+ """
+
+ line = line.replace(r'\\', 'XX') # after this, \\" does not match to \"
+ return ((line.count('"') - line.count(r'\"') - line.count("'\"'")) & 1) == 1
+
+
+def CleanseRawStrings(raw_lines):
+ """Removes C++11 raw strings from lines.
+
+ Before:
+ static const char kData[] = R"(
+ multi-line string
+ )";
+
+ After:
+ static const char kData[] = ""
+ (replaced by blank line)
+ "";
+
+ Args:
+ raw_lines: list of raw lines.
+
+ Returns:
+ list of lines with C++11 raw strings replaced by empty strings.
+ """
+
+ delimiter = None
+ lines_without_raw_strings = []
+ for line in raw_lines:
+ if delimiter:
+ # Inside a raw string, look for the end
+ end = line.find(delimiter)
+ if end >= 0:
+ # Found the end of the string, match leading space for this
+ # line and resume copying the original lines, and also insert
+ # a "" on the last line.
+ leading_space = Match(r'^(\s*)\S', line)
+ line = leading_space.group(1) + '""' + line[end + len(delimiter):]
+ delimiter = None
+ else:
+ # Haven't found the end yet, append a blank line.
+ line = '""'
+
+ # Look for beginning of a raw string, and replace them with
+ # empty strings. This is done in a loop to handle multiple raw
+ # strings on the same line.
+ while delimiter is None:
+ # Look for beginning of a raw string.
+ # See 2.14.15 [lex.string] for syntax.
+ #
+ # Once we have matched a raw string, we check the prefix of the
+ # line to make sure that the line is not part of a single line
+ # comment. It's done this way because we remove raw strings
+ # before removing comments as opposed to removing comments
+ # before removing raw strings. This is because there are some
+ # cpplint checks that requires the comments to be preserved, but
+ # we don't want to check comments that are inside raw strings.
+ matched = Match(r'^(.*?)\b(?:R|u8R|uR|UR|LR)"([^\s\\()]*)\((.*)$', line)
+ if (matched and
+ not Match(r'^([^\'"]|\'(\\.|[^\'])*\'|"(\\.|[^"])*")*//',
+ matched.group(1))):
+ delimiter = ')' + matched.group(2) + '"'
+
+ end = matched.group(3).find(delimiter)
+ if end >= 0:
+ # Raw string ended on same line
+ line = (matched.group(1) + '""' +
+ matched.group(3)[end + len(delimiter):])
+ delimiter = None
+ else:
+ # Start of a multi-line raw string
+ line = matched.group(1) + '""'
+ else:
+ break
+
+ lines_without_raw_strings.append(line)
+
+ # TODO(unknown): if delimiter is not None here, we might want to
+ # emit a warning for unterminated string.
+ return lines_without_raw_strings
+
+
+def FindNextMultiLineCommentStart(lines, lineix):
+ """Find the beginning marker for a multiline comment."""
+ while lineix < len(lines):
+ if lines[lineix].strip().startswith('/*'):
+ # Only return this marker if the comment goes beyond this line
+ if lines[lineix].strip().find('*/', 2) < 0:
+ return lineix
+ lineix += 1
+ return len(lines)
+
+
+def FindNextMultiLineCommentEnd(lines, lineix):
+ """We are inside a comment, find the end marker."""
+ while lineix < len(lines):
+ if lines[lineix].strip().endswith('*/'):
+ return lineix
+ lineix += 1
+ return len(lines)
+
+
+def RemoveMultiLineCommentsFromRange(lines, begin, end):
+ """Clears a range of lines for multi-line comments."""
+ # Having // dummy comments makes the lines non-empty, so we will not get
+ # unnecessary blank line warnings later in the code.
+ for i in range(begin, end):
+ lines[i] = '/**/'
+
+
+def RemoveMultiLineComments(filename, lines, error):
+ """Removes multiline (c-style) comments from lines."""
+ lineix = 0
+ while lineix < len(lines):
+ lineix_begin = FindNextMultiLineCommentStart(lines, lineix)
+ if lineix_begin >= len(lines):
+ return
+ lineix_end = FindNextMultiLineCommentEnd(lines, lineix_begin)
+ if lineix_end >= len(lines):
+ error(filename, lineix_begin + 1, 'readability/multiline_comment', 5,
+ 'Could not find end of multi-line comment')
+ return
+ RemoveMultiLineCommentsFromRange(lines, lineix_begin, lineix_end + 1)
+ lineix = lineix_end + 1
+
+
+def CleanseComments(line):
+ """Removes //-comments and single-line C-style /* */ comments.
+
+ Args:
+ line: A line of C++ source.
+
+ Returns:
+ The line with single-line comments removed.
+ """
+ commentpos = line.find('//')
+ if commentpos != -1 and not IsCppString(line[:commentpos]):
+ line = line[:commentpos].rstrip()
+ # get rid of /* ... */
+ return _RE_PATTERN_CLEANSE_LINE_C_COMMENTS.sub('', line)
+
+
+class CleansedLines(object):
+ """Holds 4 copies of all lines with different preprocessing applied to them.
+
+ 1) elided member contains lines without strings and comments.
+ 2) lines member contains lines without comments.
+ 3) raw_lines member contains all the lines without processing.
+ 4) lines_without_raw_strings member is same as raw_lines, but with C++11 raw
+ strings removed.
+ All these members are of <type 'list'>, and of the same length.
+ """
+
+ def __init__(self, lines):
+ self.elided = []
+ self.lines = []
+ self.raw_lines = lines
+ self.num_lines = len(lines)
+ self.lines_without_raw_strings = CleanseRawStrings(lines)
+ for linenum in range(len(self.lines_without_raw_strings)):
+ self.lines.append(CleanseComments(
+ self.lines_without_raw_strings[linenum]))
+ elided = self._CollapseStrings(self.lines_without_raw_strings[linenum])
+ self.elided.append(CleanseComments(elided))
+
+ def NumLines(self):
+ """Returns the number of lines represented."""
+ return self.num_lines
+
+ @staticmethod
+ def _CollapseStrings(elided):
+ """Collapses strings and chars on a line to simple "" or '' blocks.
+
+ We nix strings first so we're not fooled by text like '"http://"'
+
+ Args:
+ elided: The line being processed.
+
+ Returns:
+ The line with collapsed strings.
+ """
+ if _RE_PATTERN_INCLUDE.match(elided):
+ return elided
+
+ # Remove escaped characters first to make quote/single quote collapsing
+ # basic. Things that look like escaped characters shouldn't occur
+ # outside of strings and chars.
+ elided = _RE_PATTERN_CLEANSE_LINE_ESCAPES.sub('', elided)
+
+ # Replace quoted strings and digit separators. Both single quotes
+ # and double quotes are processed in the same loop, otherwise
+ # nested quotes wouldn't work.
+ collapsed = ''
+ while True:
+ # Find the first quote character
+ match = Match(r'^([^\'"]*)([\'"])(.*)$', elided)
+ if not match:
+ collapsed += elided
+ break
+ head, quote, tail = match.groups()
+
+ if quote == '"':
+ # Collapse double quoted strings
+ second_quote = tail.find('"')
+ if second_quote >= 0:
+ collapsed += head + '""'
+ elided = tail[second_quote + 1:]
+ else:
+ # Unmatched double quote, don't bother processing the rest
+ # of the line since this is probably a multiline string.
+ collapsed += elided
+ break
+ else:
+ # Found single quote, check nearby text to eliminate digit separators.
+ #
+ # There is no special handling for floating point here, because
+ # the integer/fractional/exponent parts would all be parsed
+ # correctly as long as there are digits on both sides of the
+ # separator. So we are fine as long as we don't see something
+ # like "0.'3" (gcc 4.9.0 will not allow this literal).
+ if Search(r'\b(?:0[bBxX]?|[1-9])[0-9a-fA-F]*$', head):
+ match_literal = Match(r'^((?:\'?[0-9a-zA-Z_])*)(.*)$', "'" + tail)
+ collapsed += head + match_literal.group(1).replace("'", '')
+ elided = match_literal.group(2)
+ else:
+ second_quote = tail.find('\'')
+ if second_quote >= 0:
+ collapsed += head + "''"
+ elided = tail[second_quote + 1:]
+ else:
+ # Unmatched single quote
+ collapsed += elided
+ break
+
+ return collapsed
+
+
+def FindEndOfExpressionInLine(line, startpos, stack):
+ """Find the position just after the end of current parenthesized expression.
+
+ Args:
+ line: a CleansedLines line.
+ startpos: start searching at this position.
+ stack: nesting stack at startpos.
+
+ Returns:
+ On finding matching end: (index just after matching end, None)
+ On finding an unclosed expression: (-1, None)
+ Otherwise: (-1, new stack at end of this line)
+ """
+ for i in xrange(startpos, len(line)):
+ char = line[i]
+ if char in '([{':
+ # Found start of parenthesized expression, push to expression stack
+ stack.append(char)
+ elif char == '<':
+ # Found potential start of template argument list
+ if i > 0 and line[i - 1] == '<':
+ # Left shift operator
+ if stack and stack[-1] == '<':
+ stack.pop()
+ if not stack:
+ return (-1, None)
+ elif i > 0 and Search(r'\boperator\s*$', line[0:i]):
+ # operator<, don't add to stack
+ continue
+ else:
+ # Tentative start of template argument list
+ stack.append('<')
+ elif char in ')]}':
+ # Found end of parenthesized expression.
+ #
+ # If we are currently expecting a matching '>', the pending '<'
+ # must have been an operator. Remove them from expression stack.
+ while stack and stack[-1] == '<':
+ stack.pop()
+ if not stack:
+ return (-1, None)
+ if ((stack[-1] == '(' and char == ')') or
+ (stack[-1] == '[' and char == ']') or
+ (stack[-1] == '{' and char == '}')):
+ stack.pop()
+ if not stack:
+ return (i + 1, None)
+ else:
+ # Mismatched parentheses
+ return (-1, None)
+ elif char == '>':
+ # Found potential end of template argument list.
+
+ # Ignore "->" and operator functions
+ if (i > 0 and
+ (line[i - 1] == '-' or Search(r'\boperator\s*$', line[0:i - 1]))):
+ continue
+
+ # Pop the stack if there is a matching '<'. Otherwise, ignore
+ # this '>' since it must be an operator.
+ if stack:
+ if stack[-1] == '<':
+ stack.pop()
+ if not stack:
+ return (i + 1, None)
+ elif char == ';':
+ # Found something that look like end of statements. If we are currently
+ # expecting a '>', the matching '<' must have been an operator, since
+ # template argument list should not contain statements.
+ while stack and stack[-1] == '<':
+ stack.pop()
+ if not stack:
+ return (-1, None)
+
+ # Did not find end of expression or unbalanced parentheses on this line
+ return (-1, stack)
+
+
+def CloseExpression(clean_lines, linenum, pos):
+ """If input points to ( or { or [ or <, finds the position that closes it.
+
+ If lines[linenum][pos] points to a '(' or '{' or '[' or '<', finds the
+ linenum/pos that correspond to the closing of the expression.
+
+ TODO(unknown): cpplint spends a fair bit of time matching parentheses.
+ Ideally we would want to index all opening and closing parentheses once
+ and have CloseExpression be just a simple lookup, but due to preprocessor
+ tricks, this is not so easy.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ pos: A position on the line.
+
+ Returns:
+ A tuple (line, linenum, pos) pointer *past* the closing brace, or
+ (line, len(lines), -1) if we never find a close. Note we ignore
+ strings and comments when matching; and the line we return is the
+ 'cleansed' line at linenum.
+ """
+
+ line = clean_lines.elided[linenum]
+ if (line[pos] not in '({[<') or Match(r'<[<=]', line[pos:]):
+ return (line, clean_lines.NumLines(), -1)
+
+ # Check first line
+ (end_pos, stack) = FindEndOfExpressionInLine(line, pos, [])
+ if end_pos > -1:
+ return (line, linenum, end_pos)
+
+ # Continue scanning forward
+ while stack and linenum < clean_lines.NumLines() - 1:
+ linenum += 1
+ line = clean_lines.elided[linenum]
+ (end_pos, stack) = FindEndOfExpressionInLine(line, 0, stack)
+ if end_pos > -1:
+ return (line, linenum, end_pos)
+
+ # Did not find end of expression before end of file, give up
+ return (line, clean_lines.NumLines(), -1)
+
+
+def FindStartOfExpressionInLine(line, endpos, stack):
+ """Find position at the matching start of current expression.
+
+ This is almost the reverse of FindEndOfExpressionInLine, but note
+ that the input position and returned position differs by 1.
+
+ Args:
+ line: a CleansedLines line.
+ endpos: start searching at this position.
+ stack: nesting stack at endpos.
+
+ Returns:
+ On finding matching start: (index at matching start, None)
+ On finding an unclosed expression: (-1, None)
+ Otherwise: (-1, new stack at beginning of this line)
+ """
+ i = endpos
+ while i >= 0:
+ char = line[i]
+ if char in ')]}':
+ # Found end of expression, push to expression stack
+ stack.append(char)
+ elif char == '>':
+ # Found potential end of template argument list.
+ #
+ # Ignore it if it's a "->" or ">=" or "operator>"
+ if (i > 0 and
+ (line[i - 1] == '-' or
+ Match(r'\s>=\s', line[i - 1:]) or
+ Search(r'\boperator\s*$', line[0:i]))):
+ i -= 1
+ else:
+ stack.append('>')
+ elif char == '<':
+ # Found potential start of template argument list
+ if i > 0 and line[i - 1] == '<':
+ # Left shift operator
+ i -= 1
+ else:
+ # If there is a matching '>', we can pop the expression stack.
+ # Otherwise, ignore this '<' since it must be an operator.
+ if stack and stack[-1] == '>':
+ stack.pop()
+ if not stack:
+ return (i, None)
+ elif char in '([{':
+ # Found start of expression.
+ #
+ # If there are any unmatched '>' on the stack, they must be
+ # operators. Remove those.
+ while stack and stack[-1] == '>':
+ stack.pop()
+ if not stack:
+ return (-1, None)
+ if ((char == '(' and stack[-1] == ')') or
+ (char == '[' and stack[-1] == ']') or
+ (char == '{' and stack[-1] == '}')):
+ stack.pop()
+ if not stack:
+ return (i, None)
+ else:
+ # Mismatched parentheses
+ return (-1, None)
+ elif char == ';':
+ # Found something that look like end of statements. If we are currently
+ # expecting a '<', the matching '>' must have been an operator, since
+ # template argument list should not contain statements.
+ while stack and stack[-1] == '>':
+ stack.pop()
+ if not stack:
+ return (-1, None)
+
+ i -= 1
+
+ return (-1, stack)
+
+
+def ReverseCloseExpression(clean_lines, linenum, pos):
+ """If input points to ) or } or ] or >, finds the position that opens it.
+
+ If lines[linenum][pos] points to a ')' or '}' or ']' or '>', finds the
+ linenum/pos that correspond to the opening of the expression.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ pos: A position on the line.
+
+ Returns:
+ A tuple (line, linenum, pos) pointer *at* the opening brace, or
+ (line, 0, -1) if we never find the matching opening brace. Note
+ we ignore strings and comments when matching; and the line we
+ return is the 'cleansed' line at linenum.
+ """
+ line = clean_lines.elided[linenum]
+ if line[pos] not in ')}]>':
+ return (line, 0, -1)
+
+ # Check last line
+ (start_pos, stack) = FindStartOfExpressionInLine(line, pos, [])
+ if start_pos > -1:
+ return (line, linenum, start_pos)
+
+ # Continue scanning backward
+ while stack and linenum > 0:
+ linenum -= 1
+ line = clean_lines.elided[linenum]
+ (start_pos, stack) = FindStartOfExpressionInLine(line, len(line) - 1, stack)
+ if start_pos > -1:
+ return (line, linenum, start_pos)
+
+ # Did not find start of expression before beginning of file, give up
+ return (line, 0, -1)
+
+
+def CheckForCopyright(filename, lines, error):
+ """Logs an error if no Copyright message appears at the top of the file."""
+
+ # We'll say it should occur by line 10. Don't forget there's a
+ # dummy line at the front.
+ for line in range(1, min(len(lines), 11)):
+ if re.search(r'Copyright', lines[line], re.I): break
+ else: # means no copyright line was found
+ error(filename, 0, 'legal/copyright', 5,
+ 'No copyright message found. '
+ 'You should have a line: "Copyright [year] <Copyright Owner>"')
+
+
+def GetIndentLevel(line):
+ """Return the number of leading spaces in line.
+
+ Args:
+ line: A string to check.
+
+ Returns:
+ An integer count of leading spaces, possibly zero.
+ """
+ indent = Match(r'^( *)\S', line)
+ if indent:
+ return len(indent.group(1))
+ else:
+ return 0
+
+
+def GetHeaderGuardCPPVariable(filename):
+ """Returns the CPP variable that should be used as a header guard.
+
+ Args:
+ filename: The name of a C++ header file.
+
+ Returns:
+ The CPP variable that should be used as a header guard in the
+ named file.
+
+ """
+
+ # Restores original filename in case that cpplint is invoked from Emacs's
+ # flymake.
+ filename = re.sub(r'_flymake\.h$', '.h', filename)
+ filename = re.sub(r'/\.flymake/([^/]*)$', r'/\1', filename)
+ # Replace 'c++' with 'cpp'.
+ filename = filename.replace('C++', 'cpp').replace('c++', 'cpp')
+
+ fileinfo = FileInfo(filename)
+ file_path_from_root = fileinfo.RepositoryName()
+ if _root:
+ suffix = os.sep
+ # On Windows using directory separator will leave us with
+ # "bogus escape error" unless we properly escape regex.
+ if suffix == '\\':
+ suffix += '\\'
+ file_path_from_root = re.sub('^' + _root + suffix, '', file_path_from_root)
+ return re.sub(r'[^a-zA-Z0-9]', '_', file_path_from_root).upper() + '_'
+
+
+def CheckForHeaderGuard(filename, clean_lines, error):
+ """Checks that the file contains a header guard.
+
+ Logs an error if no #ifndef header guard is present. For other
+ headers, checks that the full pathname is used.
+
+ Args:
+ filename: The name of the C++ header file.
+ clean_lines: A CleansedLines instance containing the file.
+ error: The function to call with any errors found.
+ """
+
+ # Don't check for header guards if there are error suppression
+ # comments somewhere in this file.
+ #
+ # Because this is silencing a warning for a nonexistent line, we
+ # only support the very specific NOLINT(build/header_guard) syntax,
+ # and not the general NOLINT or NOLINT(*) syntax.
+ raw_lines = clean_lines.lines_without_raw_strings
+ for i in raw_lines:
+ if Search(r'//\s*NOLINT\(build/header_guard\)', i):
+ return
+
+ # Allow pragma once instead of header guards
+ for i in raw_lines:
+ if Search(r'^\s*#pragma\s+once', i):
+ return
+
+ cppvar = GetHeaderGuardCPPVariable(filename)
+
+ ifndef = ''
+ ifndef_linenum = 0
+ define = ''
+ endif = ''
+ endif_linenum = 0
+ for linenum, line in enumerate(raw_lines):
+ linesplit = line.split()
+ if len(linesplit) >= 2:
+ # find the first occurrence of #ifndef and #define, save arg
+ if not ifndef and linesplit[0] == '#ifndef':
+ # set ifndef to the header guard presented on the #ifndef line.
+ ifndef = linesplit[1]
+ ifndef_linenum = linenum
+ if not define and linesplit[0] == '#define':
+ define = linesplit[1]
+ # find the last occurrence of #endif, save entire line
+ if line.startswith('#endif'):
+ endif = line
+ endif_linenum = linenum
+
+ if not ifndef or not define or ifndef != define:
+ error(filename, 0, 'build/header_guard', 5,
+ 'No #ifndef header guard found, suggested CPP variable is: %s' %
+ cppvar)
+ return
+
+ # The guard should be PATH_FILE_H_, but we also allow PATH_FILE_H__
+ # for backward compatibility.
+ if ifndef != cppvar:
+ error_level = 0
+ if ifndef != cppvar + '_':
+ error_level = 5
+
+ ParseNolintSuppressions(filename, raw_lines[ifndef_linenum], ifndef_linenum,
+ error)
+ error(filename, ifndef_linenum, 'build/header_guard', error_level,
+ '#ifndef header guard has wrong style, please use: %s' % cppvar)
+
+ # Check for "//" comments on endif line.
+ ParseNolintSuppressions(filename, raw_lines[endif_linenum], endif_linenum,
+ error)
+ match = Match(r'#endif\s*//\s*' + cppvar + r'(_)?\b', endif)
+ if match:
+ if match.group(1) == '_':
+ # Issue low severity warning for deprecated double trailing underscore
+ error(filename, endif_linenum, 'build/header_guard', 0,
+ '#endif line should be "#endif // %s"' % cppvar)
+ return
+
+ # Didn't find the corresponding "//" comment. If this file does not
+ # contain any "//" comments at all, it could be that the compiler
+ # only wants "/**/" comments, look for those instead.
+ no_single_line_comments = True
+ for i in xrange(1, len(raw_lines) - 1):
+ line = raw_lines[i]
+ if Match(r'^(?:(?:\'(?:\.|[^\'])*\')|(?:"(?:\.|[^"])*")|[^\'"])*//', line):
+ no_single_line_comments = False
+ break
+
+ if no_single_line_comments:
+ match = Match(r'#endif\s*/\*\s*' + cppvar + r'(_)?\s*\*/', endif)
+ if match:
+ if match.group(1) == '_':
+ # Low severity warning for double trailing underscore
+ error(filename, endif_linenum, 'build/header_guard', 0,
+ '#endif line should be "#endif /* %s */"' % cppvar)
+ return
+
+ # Didn't find anything
+ error(filename, endif_linenum, 'build/header_guard', 5,
+ '#endif line should be "#endif // %s"' % cppvar)
+
+
+def CheckHeaderFileIncluded(filename, include_state, error):
+ """Logs an error if a source file does not include its header."""
+
+ # Do not check test files
+ fileinfo = FileInfo(filename)
+ if Search(_TEST_FILE_SUFFIX, fileinfo.BaseName()):
+ return
+
+ for ext in GetHeaderExtensions():
+ basefilename = filename[0:len(filename) - len(fileinfo.Extension())]
+ headerfile = basefilename + '.' + ext
+ if not os.path.exists(headerfile):
+ continue
+ headername = FileInfo(headerfile).RepositoryName()
+ first_include = None
+ for section_list in include_state.include_list:
+ for f in section_list:
+ if headername in f[0] or f[0] in headername:
+ return
+ if not first_include:
+ first_include = f[1]
+
+ error(filename, first_include, 'build/include', 5,
+ '%s should include its header file %s' % (fileinfo.RepositoryName(),
+ headername))
+
+
+def CheckForBadCharacters(filename, lines, error):
+ """Logs an error for each line containing bad characters.
+
+ Two kinds of bad characters:
+
+ 1. Unicode replacement characters: These indicate that either the file
+ contained invalid UTF-8 (likely) or Unicode replacement characters (which
+ it shouldn't). Note that it's possible for this to throw off line
+ numbering if the invalid UTF-8 occurred adjacent to a newline.
+
+ 2. NUL bytes. These are problematic for some tools.
+
+ Args:
+ filename: The name of the current file.
+ lines: An array of strings, each representing a line of the file.
+ error: The function to call with any errors found.
+ """
+ for linenum, line in enumerate(lines):
+ if unicode_escape_decode('\ufffd') in line:
+ error(filename, linenum, 'readability/utf8', 5,
+ 'Line contains invalid UTF-8 (or Unicode replacement character).')
+ if '\0' in line:
+ error(filename, linenum, 'readability/nul', 5, 'Line contains NUL byte.')
+
+
+def CheckForNewlineAtEOF(filename, lines, error):
+ """Logs an error if there is no newline char at the end of the file.
+
+ Args:
+ filename: The name of the current file.
+ lines: An array of strings, each representing a line of the file.
+ error: The function to call with any errors found.
+ """
+
+ # The array lines() was created by adding two newlines to the
+ # original file (go figure), then splitting on \n.
+ # To verify that the file ends in \n, we just have to make sure the
+ # last-but-two element of lines() exists and is empty.
+ if len(lines) < 3 or lines[-2]:
+ error(filename, len(lines) - 2, 'whitespace/ending_newline', 5,
+ 'Could not find a newline character at the end of the file.')
+
+
+def CheckForMultilineCommentsAndStrings(filename, clean_lines, linenum, error):
+ """Logs an error if we see /* ... */ or "..." that extend past one line.
+
+ /* ... */ comments are legit inside macros, for one line.
+ Otherwise, we prefer // comments, so it's ok to warn about the
+ other. Likewise, it's ok for strings to extend across multiple
+ lines, as long as a line continuation character (backslash)
+ terminates each line. Although not currently prohibited by the C++
+ style guide, it's ugly and unnecessary. We don't do well with either
+ in this lint program, so we warn about both.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Remove all \\ (escaped backslashes) from the line. They are OK, and the
+ # second (escaped) slash may trigger later \" detection erroneously.
+ line = line.replace('\\\\', '')
+
+ if line.count('/*') > line.count('*/'):
+ error(filename, linenum, 'readability/multiline_comment', 5,
+ 'Complex multi-line /*...*/-style comment found. '
+ 'Lint may give bogus warnings. '
+ 'Consider replacing these with //-style comments, '
+ 'with #if 0...#endif, '
+ 'or with more clearly structured multi-line comments.')
+
+ if (line.count('"') - line.count('\\"')) % 2:
+ error(filename, linenum, 'readability/multiline_string', 5,
+ 'Multi-line string ("...") found. This lint script doesn\'t '
+ 'do well with such strings, and may give bogus warnings. '
+ 'Use C++11 raw strings or concatenation instead.')
+
+
+# (non-threadsafe name, thread-safe alternative, validation pattern)
+#
+# The validation pattern is used to eliminate false positives such as:
+# _rand(); // false positive due to substring match.
+# ->rand(); // some member function rand().
+# ACMRandom rand(seed); // some variable named rand.
+# ISAACRandom rand(); // another variable named rand.
+#
+# Basically we require the return value of these functions to be used
+# in some expression context on the same line by matching on some
+# operator before the function name. This eliminates constructors and
+# member function calls.
+_UNSAFE_FUNC_PREFIX = r'(?:[-+*/=%^&|(<]\s*|>\s+)'
+_THREADING_LIST = (
+ ('asctime(', 'asctime_r(', _UNSAFE_FUNC_PREFIX + r'asctime\([^)]+\)'),
+ ('ctime(', 'ctime_r(', _UNSAFE_FUNC_PREFIX + r'ctime\([^)]+\)'),
+ ('getgrgid(', 'getgrgid_r(', _UNSAFE_FUNC_PREFIX + r'getgrgid\([^)]+\)'),
+ ('getgrnam(', 'getgrnam_r(', _UNSAFE_FUNC_PREFIX + r'getgrnam\([^)]+\)'),
+ ('getlogin(', 'getlogin_r(', _UNSAFE_FUNC_PREFIX + r'getlogin\(\)'),
+ ('getpwnam(', 'getpwnam_r(', _UNSAFE_FUNC_PREFIX + r'getpwnam\([^)]+\)'),
+ ('getpwuid(', 'getpwuid_r(', _UNSAFE_FUNC_PREFIX + r'getpwuid\([^)]+\)'),
+ ('gmtime(', 'gmtime_r(', _UNSAFE_FUNC_PREFIX + r'gmtime\([^)]+\)'),
+ ('localtime(', 'localtime_r(', _UNSAFE_FUNC_PREFIX + r'localtime\([^)]+\)'),
+ ('rand(', 'rand_r(', _UNSAFE_FUNC_PREFIX + r'rand\(\)'),
+ ('strtok(', 'strtok_r(',
+ _UNSAFE_FUNC_PREFIX + r'strtok\([^)]+\)'),
+ ('ttyname(', 'ttyname_r(', _UNSAFE_FUNC_PREFIX + r'ttyname\([^)]+\)'),
+ )
+
+
+def CheckPosixThreading(filename, clean_lines, linenum, error):
+ """Checks for calls to thread-unsafe functions.
+
+ Much code has been originally written without consideration of
+ multi-threading. Also, engineers are relying on their old experience;
+ they have learned posix before threading extensions were added. These
+ tests guide the engineers to use thread-safe functions (when using
+ posix directly).
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+ for single_thread_func, multithread_safe_func, pattern in _THREADING_LIST:
+ # Additional pattern matching check to confirm that this is the
+ # function we are looking for
+ if Search(pattern, line):
+ error(filename, linenum, 'runtime/threadsafe_fn', 2,
+ 'Consider using ' + multithread_safe_func +
+ '...) instead of ' + single_thread_func +
+ '...) for improved thread safety.')
+
+
+def CheckVlogArguments(filename, clean_lines, linenum, error):
+ """Checks that VLOG() is only used for defining a logging level.
+
+ For example, VLOG(2) is correct. VLOG(INFO), VLOG(WARNING), VLOG(ERROR), and
+ VLOG(FATAL) are not.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+ if Search(r'\bVLOG\((INFO|ERROR|WARNING|DFATAL|FATAL)\)', line):
+ error(filename, linenum, 'runtime/vlog', 5,
+ 'VLOG() should be used with numeric verbosity level. '
+ 'Use LOG() if you want symbolic severity levels.')
+
+# Matches invalid increment: *count++, which moves pointer instead of
+# incrementing a value.
+_RE_PATTERN_INVALID_INCREMENT = re.compile(
+ r'^\s*\*\w+(\+\+|--);')
+
+
+def CheckInvalidIncrement(filename, clean_lines, linenum, error):
+ """Checks for invalid increment *count++.
+
+ For example following function:
+ void increment_counter(int* count) {
+ *count++;
+ }
+ is invalid, because it effectively does count++, moving pointer, and should
+ be replaced with ++*count, (*count)++ or *count += 1.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+ if _RE_PATTERN_INVALID_INCREMENT.match(line):
+ error(filename, linenum, 'runtime/invalid_increment', 5,
+ 'Changing pointer instead of value (or unused value of operator*).')
+
+
+def IsMacroDefinition(clean_lines, linenum):
+ if Search(r'^#define', clean_lines[linenum]):
+ return True
+
+ if linenum > 0 and Search(r'\\$', clean_lines[linenum - 1]):
+ return True
+
+ return False
+
+
+def IsForwardClassDeclaration(clean_lines, linenum):
+ return Match(r'^\s*(\btemplate\b)*.*class\s+\w+;\s*$', clean_lines[linenum])
+
+
+class _BlockInfo(object):
+ """Stores information about a generic block of code."""
+
+ def __init__(self, linenum, seen_open_brace):
+ self.starting_linenum = linenum
+ self.seen_open_brace = seen_open_brace
+ self.open_parentheses = 0
+ self.inline_asm = _NO_ASM
+ self.check_namespace_indentation = False
+
+ def CheckBegin(self, filename, clean_lines, linenum, error):
+ """Run checks that applies to text up to the opening brace.
+
+ This is mostly for checking the text after the class identifier
+ and the "{", usually where the base class is specified. For other
+ blocks, there isn't much to check, so we always pass.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ pass
+
+ def CheckEnd(self, filename, clean_lines, linenum, error):
+ """Run checks that applies to text after the closing brace.
+
+ This is mostly used for checking end of namespace comments.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ pass
+
+ def IsBlockInfo(self):
+ """Returns true if this block is a _BlockInfo.
+
+ This is convenient for verifying that an object is an instance of
+ a _BlockInfo, but not an instance of any of the derived classes.
+
+ Returns:
+ True for this class, False for derived classes.
+ """
+ return self.__class__ == _BlockInfo
+
+
+class _ExternCInfo(_BlockInfo):
+ """Stores information about an 'extern "C"' block."""
+
+ def __init__(self, linenum):
+ _BlockInfo.__init__(self, linenum, True)
+
+
+class _ClassInfo(_BlockInfo):
+ """Stores information about a class."""
+
+ def __init__(self, name, class_or_struct, clean_lines, linenum):
+ _BlockInfo.__init__(self, linenum, False)
+ self.name = name
+ self.is_derived = False
+ self.check_namespace_indentation = True
+ if class_or_struct == 'struct':
+ self.access = 'public'
+ self.is_struct = True
+ else:
+ self.access = 'private'
+ self.is_struct = False
+
+ # Remember initial indentation level for this class. Using raw_lines here
+ # instead of elided to account for leading comments.
+ self.class_indent = GetIndentLevel(clean_lines.raw_lines[linenum])
+
+ # Try to find the end of the class. This will be confused by things like:
+ # class A {
+ # } *x = { ...
+ #
+ # But it's still good enough for CheckSectionSpacing.
+ self.last_line = 0
+ depth = 0
+ for i in range(linenum, clean_lines.NumLines()):
+ line = clean_lines.elided[i]
+ depth += line.count('{') - line.count('}')
+ if not depth:
+ self.last_line = i
+ break
+
+ def CheckBegin(self, filename, clean_lines, linenum, error):
+ # Look for a bare ':'
+ if Search('(^|[^:]):($|[^:])', clean_lines.elided[linenum]):
+ self.is_derived = True
+
+ def CheckEnd(self, filename, clean_lines, linenum, error):
+ # If there is a DISALLOW macro, it should appear near the end of
+ # the class.
+ seen_last_thing_in_class = False
+ for i in xrange(linenum - 1, self.starting_linenum, -1):
+ match = Search(
+ r'\b(DISALLOW_COPY_AND_ASSIGN|DISALLOW_IMPLICIT_CONSTRUCTORS)\(' +
+ self.name + r'\)',
+ clean_lines.elided[i])
+ if match:
+ if seen_last_thing_in_class:
+ error(filename, i, 'readability/constructors', 3,
+ match.group(1) + ' should be the last thing in the class')
+ break
+
+ if not Match(r'^\s*$', clean_lines.elided[i]):
+ seen_last_thing_in_class = True
+
+ # Check that closing brace is aligned with beginning of the class.
+ # Only do this if the closing brace is indented by only whitespaces.
+ # This means we will not check single-line class definitions.
+ indent = Match(r'^( *)\}', clean_lines.elided[linenum])
+ if indent and len(indent.group(1)) != self.class_indent:
+ if self.is_struct:
+ parent = 'struct ' + self.name
+ else:
+ parent = 'class ' + self.name
+ error(filename, linenum, 'whitespace/indent', 3,
+ 'Closing brace should be aligned with beginning of %s' % parent)
+
+
+class _NamespaceInfo(_BlockInfo):
+ """Stores information about a namespace."""
+
+ def __init__(self, name, linenum):
+ _BlockInfo.__init__(self, linenum, False)
+ self.name = name or ''
+ self.check_namespace_indentation = True
+
+ def CheckEnd(self, filename, clean_lines, linenum, error):
+ """Check end of namespace comments."""
+ line = clean_lines.raw_lines[linenum]
+
+ # Check how many lines is enclosed in this namespace. Don't issue
+ # warning for missing namespace comments if there aren't enough
+ # lines. However, do apply checks if there is already an end of
+ # namespace comment and it's incorrect.
+ #
+ # TODO(unknown): We always want to check end of namespace comments
+ # if a namespace is large, but sometimes we also want to apply the
+ # check if a short namespace contained nontrivial things (something
+ # other than forward declarations). There is currently no logic on
+ # deciding what these nontrivial things are, so this check is
+ # triggered by namespace size only, which works most of the time.
+ if (linenum - self.starting_linenum < 10
+ and not Match(r'^\s*};*\s*(//|/\*).*\bnamespace\b', line)):
+ return
+
+ # Look for matching comment at end of namespace.
+ #
+ # Note that we accept C style "/* */" comments for terminating
+ # namespaces, so that code that terminate namespaces inside
+ # preprocessor macros can be cpplint clean.
+ #
+ # We also accept stuff like "// end of namespace <name>." with the
+ # period at the end.
+ #
+ # Besides these, we don't accept anything else, otherwise we might
+ # get false negatives when existing comment is a substring of the
+ # expected namespace.
+ if self.name:
+ # Named namespace
+ if not Match((r'^\s*};*\s*(//|/\*).*\bnamespace\s+' +
+ re.escape(self.name) + r'[\*/\.\\\s]*$'),
+ line):
+ error(filename, linenum, 'readability/namespace', 5,
+ 'Namespace should be terminated with "// namespace %s"' %
+ self.name)
+ else:
+ # Anonymous namespace
+ if not Match(r'^\s*};*\s*(//|/\*).*\bnamespace[\*/\.\\\s]*$', line):
+ # If "// namespace anonymous" or "// anonymous namespace (more text)",
+ # mention "// anonymous namespace" as an acceptable form
+ if Match(r'^\s*}.*\b(namespace anonymous|anonymous namespace)\b', line):
+ error(filename, linenum, 'readability/namespace', 5,
+ 'Anonymous namespace should be terminated with "// namespace"'
+ ' or "// anonymous namespace"')
+ else:
+ error(filename, linenum, 'readability/namespace', 5,
+ 'Anonymous namespace should be terminated with "// namespace"')
+
+
+class _PreprocessorInfo(object):
+ """Stores checkpoints of nesting stacks when #if/#else is seen."""
+
+ def __init__(self, stack_before_if):
+ # The entire nesting stack before #if
+ self.stack_before_if = stack_before_if
+
+ # The entire nesting stack up to #else
+ self.stack_before_else = []
+
+ # Whether we have already seen #else or #elif
+ self.seen_else = False
+
+
+class NestingState(object):
+ """Holds states related to parsing braces."""
+
+ def __init__(self):
+ # Stack for tracking all braces. An object is pushed whenever we
+ # see a "{", and popped when we see a "}". Only 3 types of
+ # objects are possible:
+ # - _ClassInfo: a class or struct.
+ # - _NamespaceInfo: a namespace.
+ # - _BlockInfo: some other type of block.
+ self.stack = []
+
+ # Top of the previous stack before each Update().
+ #
+ # Because the nesting_stack is updated at the end of each line, we
+ # had to do some convoluted checks to find out what is the current
+ # scope at the beginning of the line. This check is simplified by
+ # saving the previous top of nesting stack.
+ #
+ # We could save the full stack, but we only need the top. Copying
+ # the full nesting stack would slow down cpplint by ~10%.
+ self.previous_stack_top = []
+
+ # Stack of _PreprocessorInfo objects.
+ self.pp_stack = []
+
+ def SeenOpenBrace(self):
+ """Check if we have seen the opening brace for the innermost block.
+
+ Returns:
+ True if we have seen the opening brace, False if the innermost
+ block is still expecting an opening brace.
+ """
+ return (not self.stack) or self.stack[-1].seen_open_brace
+
+ def InNamespaceBody(self):
+ """Check if we are currently one level inside a namespace body.
+
+ Returns:
+ True if top of the stack is a namespace block, False otherwise.
+ """
+ return self.stack and isinstance(self.stack[-1], _NamespaceInfo)
+
+ def InExternC(self):
+ """Check if we are currently one level inside an 'extern "C"' block.
+
+ Returns:
+ True if top of the stack is an extern block, False otherwise.
+ """
+ return self.stack and isinstance(self.stack[-1], _ExternCInfo)
+
+ def InClassDeclaration(self):
+ """Check if we are currently one level inside a class or struct declaration.
+
+ Returns:
+ True if top of the stack is a class/struct, False otherwise.
+ """
+ return self.stack and isinstance(self.stack[-1], _ClassInfo)
+
+ def InAsmBlock(self):
+ """Check if we are currently one level inside an inline ASM block.
+
+ Returns:
+ True if the top of the stack is a block containing inline ASM.
+ """
+ return self.stack and self.stack[-1].inline_asm != _NO_ASM
+
+ def InTemplateArgumentList(self, clean_lines, linenum, pos):
+ """Check if current position is inside template argument list.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ pos: position just after the suspected template argument.
+ Returns:
+ True if (linenum, pos) is inside template arguments.
+ """
+ while linenum < clean_lines.NumLines():
+ # Find the earliest character that might indicate a template argument
+ line = clean_lines.elided[linenum]
+ match = Match(r'^[^{};=\[\]\.<>]*(.)', line[pos:])
+ if not match:
+ linenum += 1
+ pos = 0
+ continue
+ token = match.group(1)
+ pos += len(match.group(0))
+
+ # These things do not look like template argument list:
+ # class Suspect {
+ # class Suspect x; }
+ if token in ('{', '}', ';'): return False
+
+ # These things look like template argument list:
+ # template <class Suspect>
+ # template <class Suspect = default_value>
+ # template <class Suspect[]>
+ # template <class Suspect...>
+ if token in ('>', '=', '[', ']', '.'): return True
+
+ # Check if token is an unmatched '<'.
+ # If not, move on to the next character.
+ if token != '<':
+ pos += 1
+ if pos >= len(line):
+ linenum += 1
+ pos = 0
+ continue
+
+ # We can't be sure if we just find a single '<', and need to
+ # find the matching '>'.
+ (_, end_line, end_pos) = CloseExpression(clean_lines, linenum, pos - 1)
+ if end_pos < 0:
+ # Not sure if template argument list or syntax error in file
+ return False
+ linenum = end_line
+ pos = end_pos
+ return False
+
+ def UpdatePreprocessor(self, line):
+ """Update preprocessor stack.
+
+ We need to handle preprocessors due to classes like this:
+ #ifdef SWIG
+ struct ResultDetailsPageElementExtensionPoint {
+ #else
+ struct ResultDetailsPageElementExtensionPoint : public Extension {
+ #endif
+
+ We make the following assumptions (good enough for most files):
+ - Preprocessor condition evaluates to true from #if up to first
+ #else/#elif/#endif.
+
+ - Preprocessor condition evaluates to false from #else/#elif up
+ to #endif. We still perform lint checks on these lines, but
+ these do not affect nesting stack.
+
+ Args:
+ line: current line to check.
+ """
+ if Match(r'^\s*#\s*(if|ifdef|ifndef)\b', line):
+ # Beginning of #if block, save the nesting stack here. The saved
+ # stack will allow us to restore the parsing state in the #else case.
+ self.pp_stack.append(_PreprocessorInfo(copy.deepcopy(self.stack)))
+ elif Match(r'^\s*#\s*(else|elif)\b', line):
+ # Beginning of #else block
+ if self.pp_stack:
+ if not self.pp_stack[-1].seen_else:
+ # This is the first #else or #elif block. Remember the
+ # whole nesting stack up to this point. This is what we
+ # keep after the #endif.
+ self.pp_stack[-1].seen_else = True
+ self.pp_stack[-1].stack_before_else = copy.deepcopy(self.stack)
+
+ # Restore the stack to how it was before the #if
+ self.stack = copy.deepcopy(self.pp_stack[-1].stack_before_if)
+ else:
+ # TODO(unknown): unexpected #else, issue warning?
+ pass
+ elif Match(r'^\s*#\s*endif\b', line):
+ # End of #if or #else blocks.
+ if self.pp_stack:
+ # If we saw an #else, we will need to restore the nesting
+ # stack to its former state before the #else, otherwise we
+ # will just continue from where we left off.
+ if self.pp_stack[-1].seen_else:
+ # Here we can just use a shallow copy since we are the last
+ # reference to it.
+ self.stack = self.pp_stack[-1].stack_before_else
+ # Drop the corresponding #if
+ self.pp_stack.pop()
+ else:
+ # TODO(unknown): unexpected #endif, issue warning?
+ pass
+
+ # TODO(unknown): Update() is too long, but we will refactor later.
+ def Update(self, filename, clean_lines, linenum, error):
+ """Update nesting state with current line.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Remember top of the previous nesting stack.
+ #
+ # The stack is always pushed/popped and not modified in place, so
+ # we can just do a shallow copy instead of copy.deepcopy. Using
+ # deepcopy would slow down cpplint by ~28%.
+ if self.stack:
+ self.previous_stack_top = self.stack[-1]
+ else:
+ self.previous_stack_top = None
+
+ # Update pp_stack
+ self.UpdatePreprocessor(line)
+
+ # Count parentheses. This is to avoid adding struct arguments to
+ # the nesting stack.
+ if self.stack:
+ inner_block = self.stack[-1]
+ depth_change = line.count('(') - line.count(')')
+ inner_block.open_parentheses += depth_change
+
+ # Also check if we are starting or ending an inline assembly block.
+ if inner_block.inline_asm in (_NO_ASM, _END_ASM):
+ if (depth_change != 0 and
+ inner_block.open_parentheses == 1 and
+ _MATCH_ASM.match(line)):
+ # Enter assembly block
+ inner_block.inline_asm = _INSIDE_ASM
+ else:
+ # Not entering assembly block. If previous line was _END_ASM,
+ # we will now shift to _NO_ASM state.
+ inner_block.inline_asm = _NO_ASM
+ elif (inner_block.inline_asm == _INSIDE_ASM and
+ inner_block.open_parentheses == 0):
+ # Exit assembly block
+ inner_block.inline_asm = _END_ASM
+
+ # Consume namespace declaration at the beginning of the line. Do
+ # this in a loop so that we catch same line declarations like this:
+ # namespace proto2 { namespace bridge { class MessageSet; } }
+ while True:
+ # Match start of namespace. The "\b\s*" below catches namespace
+ # declarations even if it weren't followed by a whitespace, this
+ # is so that we don't confuse our namespace checker. The
+ # missing spaces will be flagged by CheckSpacing.
+ namespace_decl_match = Match(r'^\s*namespace\b\s*([:\w]+)?(.*)$', line)
+ if not namespace_decl_match:
+ break
+
+ new_namespace = _NamespaceInfo(namespace_decl_match.group(1), linenum)
+ self.stack.append(new_namespace)
+
+ line = namespace_decl_match.group(2)
+ if line.find('{') != -1:
+ new_namespace.seen_open_brace = True
+ line = line[line.find('{') + 1:]
+
+ # Look for a class declaration in whatever is left of the line
+ # after parsing namespaces. The regexp accounts for decorated classes
+ # such as in:
+ # class LOCKABLE API Object {
+ # };
+ class_decl_match = Match(
+ r'^(\s*(?:template\s*<[\w\s<>,:=]*>\s*)?'
+ r'(class|struct)\s+(?:[A-Z_]+\s+)*(\w+(?:::\w+)*))'
+ r'(.*)$', line)
+ if (class_decl_match and
+ (not self.stack or self.stack[-1].open_parentheses == 0)):
+ # We do not want to accept classes that are actually template arguments:
+ # template <class Ignore1,
+ # class Ignore2 = Default<Args>,
+ # template <Args> class Ignore3>
+ # void Function() {};
+ #
+ # To avoid template argument cases, we scan forward and look for
+ # an unmatched '>'. If we see one, assume we are inside a
+ # template argument list.
+ end_declaration = len(class_decl_match.group(1))
+ if not self.InTemplateArgumentList(clean_lines, linenum, end_declaration):
+ self.stack.append(_ClassInfo(
+ class_decl_match.group(3), class_decl_match.group(2),
+ clean_lines, linenum))
+ line = class_decl_match.group(4)
+
+ # If we have not yet seen the opening brace for the innermost block,
+ # run checks here.
+ if not self.SeenOpenBrace():
+ self.stack[-1].CheckBegin(filename, clean_lines, linenum, error)
+
+ # Update access control if we are inside a class/struct
+ if self.stack and isinstance(self.stack[-1], _ClassInfo):
+ classinfo = self.stack[-1]
+ access_match = Match(
+ r'^(.*)\b(public|private|protected|signals)(\s+(?:slots\s*)?)?'
+ r':(?:[^:]|$)',
+ line)
+ if access_match:
+ classinfo.access = access_match.group(2)
+
+ # Check that access keywords are indented +1 space. Skip this
+ # check if the keywords are not preceded by whitespaces.
+ indent = access_match.group(1)
+ if (len(indent) != classinfo.class_indent + 1 and
+ Match(r'^\s*$', indent)):
+ if classinfo.is_struct:
+ parent = 'struct ' + classinfo.name
+ else:
+ parent = 'class ' + classinfo.name
+ slots = ''
+ if access_match.group(3):
+ slots = access_match.group(3)
+ error(filename, linenum, 'whitespace/indent', 3,
+ '%s%s: should be indented +1 space inside %s' % (
+ access_match.group(2), slots, parent))
+
+ # Consume braces or semicolons from what's left of the line
+ while True:
+ # Match first brace, semicolon, or closed parenthesis.
+ matched = Match(r'^[^{;)}]*([{;)}])(.*)$', line)
+ if not matched:
+ break
+
+ token = matched.group(1)
+ if token == '{':
+ # If namespace or class hasn't seen a opening brace yet, mark
+ # namespace/class head as complete. Push a new block onto the
+ # stack otherwise.
+ if not self.SeenOpenBrace():
+ self.stack[-1].seen_open_brace = True
+ elif Match(r'^extern\s*"[^"]*"\s*\{', line):
+ self.stack.append(_ExternCInfo(linenum))
+ else:
+ self.stack.append(_BlockInfo(linenum, True))
+ if _MATCH_ASM.match(line):
+ self.stack[-1].inline_asm = _BLOCK_ASM
+
+ elif token == ';' or token == ')':
+ # If we haven't seen an opening brace yet, but we already saw
+ # a semicolon, this is probably a forward declaration. Pop
+ # the stack for these.
+ #
+ # Similarly, if we haven't seen an opening brace yet, but we
+ # already saw a closing parenthesis, then these are probably
+ # function arguments with extra "class" or "struct" keywords.
+ # Also pop these stack for these.
+ if not self.SeenOpenBrace():
+ self.stack.pop()
+ else: # token == '}'
+ # Perform end of block checks and pop the stack.
+ if self.stack:
+ self.stack[-1].CheckEnd(filename, clean_lines, linenum, error)
+ self.stack.pop()
+ line = matched.group(2)
+
+ def InnermostClass(self):
+ """Get class info on the top of the stack.
+
+ Returns:
+ A _ClassInfo object if we are inside a class, or None otherwise.
+ """
+ for i in range(len(self.stack), 0, -1):
+ classinfo = self.stack[i - 1]
+ if isinstance(classinfo, _ClassInfo):
+ return classinfo
+ return None
+
+ def CheckCompletedBlocks(self, filename, error):
+ """Checks that all classes and namespaces have been completely parsed.
+
+ Call this when all lines in a file have been processed.
+ Args:
+ filename: The name of the current file.
+ error: The function to call with any errors found.
+ """
+ # Note: This test can result in false positives if #ifdef constructs
+ # get in the way of brace matching. See the testBuildClass test in
+ # cpplint_unittest.py for an example of this.
+ for obj in self.stack:
+ if isinstance(obj, _ClassInfo):
+ error(filename, obj.starting_linenum, 'build/class', 5,
+ 'Failed to find complete declaration of class %s' %
+ obj.name)
+ elif isinstance(obj, _NamespaceInfo):
+ error(filename, obj.starting_linenum, 'build/namespaces', 5,
+ 'Failed to find complete declaration of namespace %s' %
+ obj.name)
+
+
+def CheckForNonStandardConstructs(filename, clean_lines, linenum,
+ nesting_state, error):
+ r"""Logs an error if we see certain non-ANSI constructs ignored by gcc-2.
+
+ Complain about several constructs which gcc-2 accepts, but which are
+ not standard C++. Warning about these in lint is one way to ease the
+ transition to new compilers.
+ - put storage class first (e.g. "static const" instead of "const static").
+ - "%lld" instead of %qd" in printf-type functions.
+ - "%1$d" is non-standard in printf-type functions.
+ - "\%" is an undefined character escape sequence.
+ - text after #endif is not allowed.
+ - invalid inner-style forward declaration.
+ - >? and <? operators, and their >?= and <?= cousins.
+
+ Additionally, check for constructor/destructor style violations and reference
+ members, as it is very convenient to do so while checking for
+ gcc-2 compliance.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ error: A callable to which errors are reported, which takes 4 arguments:
+ filename, line number, error level, and message
+ """
+
+ # Remove comments from the line, but leave in strings for now.
+ line = clean_lines.lines[linenum]
+
+ if Search(r'printf\s*\(.*".*%[-+ ]?\d*q', line):
+ error(filename, linenum, 'runtime/printf_format', 3,
+ '%q in format strings is deprecated. Use %ll instead.')
+
+ if Search(r'printf\s*\(.*".*%\d+\$', line):
+ error(filename, linenum, 'runtime/printf_format', 2,
+ '%N$ formats are unconventional. Try rewriting to avoid them.')
+
+ # Remove escaped backslashes before looking for undefined escapes.
+ line = line.replace('\\\\', '')
+
+ if Search(r'("|\').*\\(%|\[|\(|{)', line):
+ error(filename, linenum, 'build/printf_format', 3,
+ '%, [, (, and { are undefined character escapes. Unescape them.')
+
+ # For the rest, work with both comments and strings removed.
+ line = clean_lines.elided[linenum]
+
+ if Search(r'\b(const|volatile|void|char|short|int|long'
+ r'|float|double|signed|unsigned'
+ r'|schar|u?int8|u?int16|u?int32|u?int64)'
+ r'\s+(register|static|extern|typedef)\b',
+ line):
+ error(filename, linenum, 'build/storage_class', 5,
+ 'Storage-class specifier (static, extern, typedef, etc) should be '
+ 'at the beginning of the declaration.')
+
+ if Match(r'\s*#\s*endif\s*[^/\s]+', line):
+ error(filename, linenum, 'build/endif_comment', 5,
+ 'Uncommented text after #endif is non-standard. Use a comment.')
+
+ if Match(r'\s*class\s+(\w+\s*::\s*)+\w+\s*;', line):
+ error(filename, linenum, 'build/forward_decl', 5,
+ 'Inner-style forward declarations are invalid. Remove this line.')
+
+ if Search(r'(\w+|[+-]?\d+(\.\d*)?)\s*(<|>)\?=?\s*(\w+|[+-]?\d+)(\.\d*)?',
+ line):
+ error(filename, linenum, 'build/deprecated', 3,
+ '>? and <? (max and min) operators are non-standard and deprecated.')
+
+ if Search(r'^\s*const\s*string\s*&\s*\w+\s*;', line):
+ # TODO(unknown): Could it be expanded safely to arbitrary references,
+ # without triggering too many false positives? The first
+ # attempt triggered 5 warnings for mostly benign code in the regtest, hence
+ # the restriction.
+ # Here's the original regexp, for the reference:
+ # type_name = r'\w+((\s*::\s*\w+)|(\s*<\s*\w+?\s*>))?'
+ # r'\s*const\s*' + type_name + '\s*&\s*\w+\s*;'
+ error(filename, linenum, 'runtime/member_string_references', 2,
+ 'const string& members are dangerous. It is much better to use '
+ 'alternatives, such as pointers or simple constants.')
+
+ # Everything else in this function operates on class declarations.
+ # Return early if the top of the nesting stack is not a class, or if
+ # the class head is not completed yet.
+ classinfo = nesting_state.InnermostClass()
+ if not classinfo or not classinfo.seen_open_brace:
+ return
+
+ # The class may have been declared with namespace or classname qualifiers.
+ # The constructor and destructor will not have those qualifiers.
+ base_classname = classinfo.name.split('::')[-1]
+
+ # Look for single-argument constructors that aren't marked explicit.
+ # Technically a valid construct, but against style.
+ explicit_constructor_match = Match(
+ r'\s+(?:inline\s+)?(explicit\s+)?(?:inline\s+)?%s\s*'
+ r'\(((?:[^()]|\([^()]*\))*)\)'
+ % re.escape(base_classname),
+ line)
+
+ if explicit_constructor_match:
+ is_marked_explicit = explicit_constructor_match.group(1)
+
+ if not explicit_constructor_match.group(2):
+ constructor_args = []
+ else:
+ constructor_args = explicit_constructor_match.group(2).split(',')
+
+ # collapse arguments so that commas in template parameter lists and function
+ # argument parameter lists don't split arguments in two
+ i = 0
+ while i < len(constructor_args):
+ constructor_arg = constructor_args[i]
+ while (constructor_arg.count('<') > constructor_arg.count('>') or
+ constructor_arg.count('(') > constructor_arg.count(')')):
+ constructor_arg += ',' + constructor_args[i + 1]
+ del constructor_args[i + 1]
+ constructor_args[i] = constructor_arg
+ i += 1
+
+ variadic_args = [arg for arg in constructor_args if '&&...' in arg]
+ defaulted_args = [arg for arg in constructor_args if '=' in arg]
+ noarg_constructor = (not constructor_args or # empty arg list
+ # 'void' arg specifier
+ (len(constructor_args) == 1 and
+ constructor_args[0].strip() == 'void'))
+ onearg_constructor = ((len(constructor_args) == 1 and # exactly one arg
+ not noarg_constructor) or
+ # all but at most one arg defaulted
+ (len(constructor_args) >= 1 and
+ not noarg_constructor and
+ len(defaulted_args) >= len(constructor_args) - 1) or
+ # variadic arguments with zero or one argument
+ (len(constructor_args) <= 2 and
+ len(variadic_args) >= 1))
+ initializer_list_constructor = bool(
+ onearg_constructor and
+ Search(r'\bstd\s*::\s*initializer_list\b', constructor_args[0]))
+ copy_constructor = bool(
+ onearg_constructor and
+ Match(r'(const\s+)?%s(\s*<[^>]*>)?(\s+const)?\s*(?:<\w+>\s*)?&'
+ % re.escape(base_classname), constructor_args[0].strip()))
+
+ if (not is_marked_explicit and
+ onearg_constructor and
+ not initializer_list_constructor and
+ not copy_constructor):
+ if defaulted_args or variadic_args:
+ error(filename, linenum, 'runtime/explicit', 5,
+ 'Constructors callable with one argument '
+ 'should be marked explicit.')
+ else:
+ error(filename, linenum, 'runtime/explicit', 5,
+ 'Single-parameter constructors should be marked explicit.')
+ elif is_marked_explicit and not onearg_constructor:
+ if noarg_constructor:
+ error(filename, linenum, 'runtime/explicit', 5,
+ 'Zero-parameter constructors should not be marked explicit.')
+
+
+def CheckSpacingForFunctionCall(filename, clean_lines, linenum, error):
+ """Checks for the correctness of various spacing around function calls.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Since function calls often occur inside if/for/while/switch
+ # expressions - which have their own, more liberal conventions - we
+ # first see if we should be looking inside such an expression for a
+ # function call, to which we can apply more strict standards.
+ fncall = line # if there's no control flow construct, look at whole line
+ for pattern in (r'\bif\s*\((.*)\)\s*{',
+ r'\bfor\s*\((.*)\)\s*{',
+ r'\bwhile\s*\((.*)\)\s*[{;]',
+ r'\bswitch\s*\((.*)\)\s*{'):
+ match = Search(pattern, line)
+ if match:
+ fncall = match.group(1) # look inside the parens for function calls
+ break
+
+ # Except in if/for/while/switch, there should never be space
+ # immediately inside parens (eg "f( 3, 4 )"). We make an exception
+ # for nested parens ( (a+b) + c ). Likewise, there should never be
+ # a space before a ( when it's a function argument. I assume it's a
+ # function argument when the char before the whitespace is legal in
+ # a function name (alnum + _) and we're not starting a macro. Also ignore
+ # pointers and references to arrays and functions coz they're too tricky:
+ # we use a very simple way to recognize these:
+ # " (something)(maybe-something)" or
+ # " (something)(maybe-something," or
+ # " (something)[something]"
+ # Note that we assume the contents of [] to be short enough that
+ # they'll never need to wrap.
+ if ( # Ignore control structures.
+ not Search(r'\b(if|for|while|switch|return|new|delete|catch|sizeof)\b',
+ fncall) and
+ # Ignore pointers/references to functions.
+ not Search(r' \([^)]+\)\([^)]*(\)|,$)', fncall) and
+ # Ignore pointers/references to arrays.
+ not Search(r' \([^)]+\)\[[^\]]+\]', fncall)):
+ if Search(r'\w\s*\(\s(?!\s*\\$)', fncall): # a ( used for a fn call
+ error(filename, linenum, 'whitespace/parens', 4,
+ 'Extra space after ( in function call')
+ elif Search(r'\(\s+(?!(\s*\\)|\()', fncall):
+ error(filename, linenum, 'whitespace/parens', 2,
+ 'Extra space after (')
+ if (Search(r'\w\s+\(', fncall) and
+ not Search(r'_{0,2}asm_{0,2}\s+_{0,2}volatile_{0,2}\s+\(', fncall) and
+ not Search(r'#\s*define|typedef|using\s+\w+\s*=', fncall) and
+ not Search(r'\w\s+\((\w+::)*\*\w+\)\(', fncall) and
+ not Search(r'\b(' + '|'.join(_ALT_TOKEN_REPLACEMENT.keys()) + r')\b\s+\(',
+ fncall) and
+ not Search(r'\bcase\s+\(', fncall)):
+ # TODO(unknown): Space after an operator function seem to be a common
+ # error, silence those for now by restricting them to highest verbosity.
+ if Search(r'\boperator_*\b', line):
+ error(filename, linenum, 'whitespace/parens', 0,
+ 'Extra space before ( in function call')
+ else:
+ error(filename, linenum, 'whitespace/parens', 4,
+ 'Extra space before ( in function call')
+ # If the ) is followed only by a newline or a { + newline, assume it's
+ # part of a control statement (if/while/etc), and don't complain
+ if Search(r'[^)]\s+\)\s*[^{\s]', fncall):
+ # If the closing parenthesis is preceded by only whitespaces,
+ # try to give a more descriptive error message.
+ if Search(r'^\s+\)', fncall):
+ error(filename, linenum, 'whitespace/parens', 2,
+ 'Closing ) should be moved to the previous line')
+ else:
+ error(filename, linenum, 'whitespace/parens', 2,
+ 'Extra space before )')
+
+
+def IsBlankLine(line):
+ """Returns true if the given line is blank.
+
+ We consider a line to be blank if the line is empty or consists of
+ only white spaces.
+
+ Args:
+ line: A line of a string.
+
+ Returns:
+ True, if the given line is blank.
+ """
+ return not line or line.isspace()
+
+
+def CheckForNamespaceIndentation(filename, nesting_state, clean_lines, line,
+ error):
+ is_namespace_indent_item = (
+ len(nesting_state.stack) > 1 and
+ nesting_state.stack[-1].check_namespace_indentation and
+ isinstance(nesting_state.previous_stack_top, _NamespaceInfo) and
+ nesting_state.previous_stack_top == nesting_state.stack[-2])
+
+ if ShouldCheckNamespaceIndentation(nesting_state, is_namespace_indent_item,
+ clean_lines.elided, line):
+ CheckItemIndentationInNamespace(filename, clean_lines.elided,
+ line, error)
+
+
+def CheckForFunctionLengths(filename, clean_lines, linenum,
+ function_state, error):
+ """Reports for long function bodies.
+
+ For an overview why this is done, see:
+ https://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Write_Short_Functions
+
+ Uses a simplistic algorithm assuming other style guidelines
+ (especially spacing) are followed.
+ Only checks unindented functions, so class members are unchecked.
+ Trivial bodies are unchecked, so constructors with huge initializer lists
+ may be missed.
+ Blank/comment lines are not counted so as to avoid encouraging the removal
+ of vertical space and comments just to get through a lint check.
+ NOLINT *on the last line of a function* disables this check.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ function_state: Current function name and lines in body so far.
+ error: The function to call with any errors found.
+ """
+ lines = clean_lines.lines
+ line = lines[linenum]
+ joined_line = ''
+
+ starting_func = False
+ regexp = r'(\w(\w|::|\*|\&|\s)*)\(' # decls * & space::name( ...
+ match_result = Match(regexp, line)
+ if match_result:
+ # If the name is all caps and underscores, figure it's a macro and
+ # ignore it, unless it's TEST or TEST_F.
+ function_name = match_result.group(1).split()[-1]
+ if function_name == 'TEST' or function_name == 'TEST_F' or (
+ not Match(r'[A-Z_]+$', function_name)):
+ starting_func = True
+
+ if starting_func:
+ body_found = False
+ for start_linenum in range(linenum, clean_lines.NumLines()):
+ start_line = lines[start_linenum]
+ joined_line += ' ' + start_line.lstrip()
+ if Search(r'(;|})', start_line): # Declarations and trivial functions
+ body_found = True
+ break # ... ignore
+ elif Search(r'{', start_line):
+ body_found = True
+ function = Search(r'((\w|:)*)\(', line).group(1)
+ if Match(r'TEST', function): # Handle TEST... macros
+ parameter_regexp = Search(r'(\(.*\))', joined_line)
+ if parameter_regexp: # Ignore bad syntax
+ function += parameter_regexp.group(1)
+ else:
+ function += '()'
+ function_state.Begin(function)
+ break
+ if not body_found:
+ # No body for the function (or evidence of a non-function) was found.
+ error(filename, linenum, 'readability/fn_size', 5,
+ 'Lint failed to find start of function body.')
+ elif Match(r'^\}\s*$', line): # function end
+ function_state.Check(error, filename, linenum)
+ function_state.End()
+ elif not Match(r'^\s*$', line):
+ function_state.Count() # Count non-blank/non-comment lines.
+
+
+_RE_PATTERN_TODO = re.compile(r'^//(\s*)TODO(\(.+?\))?:?(\s|$)?')
+
+
+def CheckComment(line, filename, linenum, next_line_start, error):
+ """Checks for common mistakes in comments.
+
+ Args:
+ line: The line in question.
+ filename: The name of the current file.
+ linenum: The number of the line to check.
+ next_line_start: The first non-whitespace column of the next line.
+ error: The function to call with any errors found.
+ """
+ commentpos = line.find('//')
+ if commentpos != -1:
+ # Check if the // may be in quotes. If so, ignore it
+ if re.sub(r'\\.', '', line[0:commentpos]).count('"') % 2 == 0:
+ # Allow one space for new scopes, two spaces otherwise:
+ if (not (Match(r'^.*{ *//', line) and next_line_start == commentpos) and
+ ((commentpos >= 1 and
+ line[commentpos-1] not in string.whitespace) or
+ (commentpos >= 2 and
+ line[commentpos-2] not in string.whitespace))):
+ error(filename, linenum, 'whitespace/comments', 2,
+ 'At least two spaces is best between code and comments')
+
+ # Checks for common mistakes in TODO comments.
+ comment = line[commentpos:]
+ match = _RE_PATTERN_TODO.match(comment)
+ if match:
+ # One whitespace is correct; zero whitespace is handled elsewhere.
+ leading_whitespace = match.group(1)
+ if len(leading_whitespace) > 1:
+ error(filename, linenum, 'whitespace/todo', 2,
+ 'Too many spaces before TODO')
+
+ username = match.group(2)
+ if not username:
+ error(filename, linenum, 'readability/todo', 2,
+ 'Missing username in TODO; it should look like '
+ '"// TODO(my_username): Stuff."')
+
+ middle_whitespace = match.group(3)
+ # Comparisons made explicit for correctness -- pylint: disable=g-explicit-bool-comparison
+ if middle_whitespace != ' ' and middle_whitespace != '':
+ error(filename, linenum, 'whitespace/todo', 2,
+ 'TODO(my_username) should be followed by a space')
+
+ # If the comment contains an alphanumeric character, there
+ # should be a space somewhere between it and the // unless
+ # it's a /// or //! Doxygen comment.
+ if (Match(r'//[^ ]*\w', comment) and
+ not Match(r'(///|//\!)(\s+|$)', comment)):
+ error(filename, linenum, 'whitespace/comments', 4,
+ 'Should have a space between // and comment')
+
+
+def CheckAccess(filename, clean_lines, linenum, nesting_state, error):
+ """Checks for improper use of DISALLOW* macros.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum] # get rid of comments and strings
+
+ matched = Match((r'\s*(DISALLOW_COPY_AND_ASSIGN|'
+ r'DISALLOW_IMPLICIT_CONSTRUCTORS)'), line)
+ if not matched:
+ return
+ if nesting_state.stack and isinstance(nesting_state.stack[-1], _ClassInfo):
+ if nesting_state.stack[-1].access != 'private':
+ error(filename, linenum, 'readability/constructors', 3,
+ '%s must be in the private: section' % matched.group(1))
+
+ else:
+ # Found DISALLOW* macro outside a class declaration, or perhaps it
+ # was used inside a function when it should have been part of the
+ # class declaration. We could issue a warning here, but it
+ # probably resulted in a compiler error already.
+ pass
+
+
+def CheckSpacing(filename, clean_lines, linenum, nesting_state, error):
+ """Checks for the correctness of various spacing issues in the code.
+
+ Things we check for: spaces around operators, spaces after
+ if/for/while/switch, no spaces around parens in function calls, two
+ spaces between code and comment, don't start a block with a blank
+ line, don't end a function with a blank line, don't add a blank line
+ after public/protected/private, don't have too many blank lines in a row.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ error: The function to call with any errors found.
+ """
+
+ # Don't use "elided" lines here, otherwise we can't check commented lines.
+ # Don't want to use "raw" either, because we don't want to check inside C++11
+ # raw strings,
+ raw = clean_lines.lines_without_raw_strings
+ line = raw[linenum]
+
+ # Before nixing comments, check if the line is blank for no good
+ # reason. This includes the first line after a block is opened, and
+ # blank lines at the end of a function (ie, right before a line like '}'
+ #
+ # Skip all the blank line checks if we are immediately inside a
+ # namespace body. In other words, don't issue blank line warnings
+ # for this block:
+ # namespace {
+ #
+ # }
+ #
+ # A warning about missing end of namespace comments will be issued instead.
+ #
+ # Also skip blank line checks for 'extern "C"' blocks, which are formatted
+ # like namespaces.
+ if (IsBlankLine(line) and
+ not nesting_state.InNamespaceBody() and
+ not nesting_state.InExternC()):
+ elided = clean_lines.elided
+ prev_line = elided[linenum - 1]
+ prevbrace = prev_line.rfind('{')
+ # TODO(unknown): Don't complain if line before blank line, and line after,
+ # both start with alnums and are indented the same amount.
+ # This ignores whitespace at the start of a namespace block
+ # because those are not usually indented.
+ if prevbrace != -1 and prev_line[prevbrace:].find('}') == -1:
+ # OK, we have a blank line at the start of a code block. Before we
+ # complain, we check if it is an exception to the rule: The previous
+ # non-empty line has the parameters of a function header that are indented
+ # 4 spaces (because they did not fit in a 80 column line when placed on
+ # the same line as the function name). We also check for the case where
+ # the previous line is indented 6 spaces, which may happen when the
+ # initializers of a constructor do not fit into a 80 column line.
+ exception = False
+ if Match(r' {6}\w', prev_line): # Initializer list?
+ # We are looking for the opening column of initializer list, which
+ # should be indented 4 spaces to cause 6 space indentation afterwards.
+ search_position = linenum-2
+ while (search_position >= 0
+ and Match(r' {6}\w', elided[search_position])):
+ search_position -= 1
+ exception = (search_position >= 0
+ and elided[search_position][:5] == ' :')
+ else:
+ # Search for the function arguments or an initializer list. We use a
+ # simple heuristic here: If the line is indented 4 spaces; and we have a
+ # closing paren, without the opening paren, followed by an opening brace
+ # or colon (for initializer lists) we assume that it is the last line of
+ # a function header. If we have a colon indented 4 spaces, it is an
+ # initializer list.
+ exception = (Match(r' {4}\w[^\(]*\)\s*(const\s*)?(\{\s*$|:)',
+ prev_line)
+ or Match(r' {4}:', prev_line))
+
+ if not exception:
+ error(filename, linenum, 'whitespace/blank_line', 2,
+ 'Redundant blank line at the start of a code block '
+ 'should be deleted.')
+ # Ignore blank lines at the end of a block in a long if-else
+ # chain, like this:
+ # if (condition1) {
+ # // Something followed by a blank line
+ #
+ # } else if (condition2) {
+ # // Something else
+ # }
+ if linenum + 1 < clean_lines.NumLines():
+ next_line = raw[linenum + 1]
+ if (next_line
+ and Match(r'\s*}', next_line)
+ and next_line.find('} else ') == -1):
+ error(filename, linenum, 'whitespace/blank_line', 3,
+ 'Redundant blank line at the end of a code block '
+ 'should be deleted.')
+
+ matched = Match(r'\s*(public|protected|private):', prev_line)
+ if matched:
+ error(filename, linenum, 'whitespace/blank_line', 3,
+ 'Do not leave a blank line after "%s:"' % matched.group(1))
+
+ # Next, check comments
+ next_line_start = 0
+ if linenum + 1 < clean_lines.NumLines():
+ next_line = raw[linenum + 1]
+ next_line_start = len(next_line) - len(next_line.lstrip())
+ CheckComment(line, filename, linenum, next_line_start, error)
+
+ # get rid of comments and strings
+ line = clean_lines.elided[linenum]
+
+ # You shouldn't have spaces before your brackets, except maybe after
+ # 'delete []' or 'return []() {};'
+ if Search(r'\w\s+\[', line) and not Search(r'(?:delete|return)\s+\[', line):
+ error(filename, linenum, 'whitespace/braces', 5,
+ 'Extra space before [')
+
+ # In range-based for, we wanted spaces before and after the colon, but
+ # not around "::" tokens that might appear.
+ if (Search(r'for *\(.*[^:]:[^: ]', line) or
+ Search(r'for *\(.*[^: ]:[^:]', line)):
+ error(filename, linenum, 'whitespace/forcolon', 2,
+ 'Missing space around colon in range-based for loop')
+
+
+def CheckOperatorSpacing(filename, clean_lines, linenum, error):
+ """Checks for horizontal spacing around operators.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Don't try to do spacing checks for operator methods. Do this by
+ # replacing the troublesome characters with something else,
+ # preserving column position for all other characters.
+ #
+ # The replacement is done repeatedly to avoid false positives from
+ # operators that call operators.
+ while True:
+ match = Match(r'^(.*\boperator\b)(\S+)(\s*\(.*)$', line)
+ if match:
+ line = match.group(1) + ('_' * len(match.group(2))) + match.group(3)
+ else:
+ break
+
+ # We allow no-spaces around = within an if: "if ( (a=Foo()) == 0 )".
+ # Otherwise not. Note we only check for non-spaces on *both* sides;
+ # sometimes people put non-spaces on one side when aligning ='s among
+ # many lines (not that this is behavior that I approve of...)
+ if ((Search(r'[\w.]=', line) or
+ Search(r'=[\w.]', line))
+ and not Search(r'\b(if|while|for) ', line)
+ # Operators taken from [lex.operators] in C++11 standard.
+ and not Search(r'(>=|<=|==|!=|&=|\^=|\|=|\+=|\*=|\/=|\%=)', line)
+ and not Search(r'operator=', line)):
+ error(filename, linenum, 'whitespace/operators', 4,
+ 'Missing spaces around =')
+
+ # It's ok not to have spaces around binary operators like + - * /, but if
+ # there's too little whitespace, we get concerned. It's hard to tell,
+ # though, so we punt on this one for now. TODO.
+
+ # You should always have whitespace around binary operators.
+ #
+ # Check <= and >= first to avoid false positives with < and >, then
+ # check non-include lines for spacing around < and >.
+ #
+ # If the operator is followed by a comma, assume it's be used in a
+ # macro context and don't do any checks. This avoids false
+ # positives.
+ #
+ # Note that && is not included here. This is because there are too
+ # many false positives due to RValue references.
+ match = Search(r'[^<>=!\s](==|!=|<=|>=|\|\|)[^<>=!\s,;\)]', line)
+ if match:
+ error(filename, linenum, 'whitespace/operators', 3,
+ 'Missing spaces around %s' % match.group(1))
+ elif not Match(r'#.*include', line):
+ # Look for < that is not surrounded by spaces. This is only
+ # triggered if both sides are missing spaces, even though
+ # technically should should flag if at least one side is missing a
+ # space. This is done to avoid some false positives with shifts.
+ match = Match(r'^(.*[^\s<])<[^\s=<,]', line)
+ if match:
+ (_, _, end_pos) = CloseExpression(
+ clean_lines, linenum, len(match.group(1)))
+ if end_pos <= -1:
+ error(filename, linenum, 'whitespace/operators', 3,
+ 'Missing spaces around <')
+
+ # Look for > that is not surrounded by spaces. Similar to the
+ # above, we only trigger if both sides are missing spaces to avoid
+ # false positives with shifts.
+ match = Match(r'^(.*[^-\s>])>[^\s=>,]', line)
+ if match:
+ (_, _, start_pos) = ReverseCloseExpression(
+ clean_lines, linenum, len(match.group(1)))
+ if start_pos <= -1:
+ error(filename, linenum, 'whitespace/operators', 3,
+ 'Missing spaces around >')
+
+ # We allow no-spaces around << when used like this: 10<<20, but
+ # not otherwise (particularly, not when used as streams)
+ #
+ # We also allow operators following an opening parenthesis, since
+ # those tend to be macros that deal with operators.
+ match = Search(r'(operator|[^\s(<])(?:L|UL|LL|ULL|l|ul|ll|ull)?<<([^\s,=<])', line)
+ if (match and not (match.group(1).isdigit() and match.group(2).isdigit()) and
+ not (match.group(1) == 'operator' and match.group(2) == ';')):
+ error(filename, linenum, 'whitespace/operators', 3,
+ 'Missing spaces around <<')
+
+ # We allow no-spaces around >> for almost anything. This is because
+ # C++11 allows ">>" to close nested templates, which accounts for
+ # most cases when ">>" is not followed by a space.
+ #
+ # We still warn on ">>" followed by alpha character, because that is
+ # likely due to ">>" being used for right shifts, e.g.:
+ # value >> alpha
+ #
+ # When ">>" is used to close templates, the alphanumeric letter that
+ # follows would be part of an identifier, and there should still be
+ # a space separating the template type and the identifier.
+ # type<type<type>> alpha
+ match = Search(r'>>[a-zA-Z_]', line)
+ if match:
+ error(filename, linenum, 'whitespace/operators', 3,
+ 'Missing spaces around >>')
+
+ # There shouldn't be space around unary operators
+ match = Search(r'(!\s|~\s|[\s]--[\s;]|[\s]\+\+[\s;])', line)
+ if match:
+ error(filename, linenum, 'whitespace/operators', 4,
+ 'Extra space for operator %s' % match.group(1))
+
+
+def CheckParenthesisSpacing(filename, clean_lines, linenum, error):
+ """Checks for horizontal spacing around parentheses.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # No spaces after an if, while, switch, or for
+ match = Search(r' (if\(|for\(|while\(|switch\()', line)
+ if match:
+ error(filename, linenum, 'whitespace/parens', 5,
+ 'Missing space before ( in %s' % match.group(1))
+
+ # For if/for/while/switch, the left and right parens should be
+ # consistent about how many spaces are inside the parens, and
+ # there should either be zero or one spaces inside the parens.
+ # We don't want: "if ( foo)" or "if ( foo )".
+ # Exception: "for ( ; foo; bar)" and "for (foo; bar; )" are allowed.
+ match = Search(r'\b(if|for|while|switch)\s*'
+ r'\(([ ]*)(.).*[^ ]+([ ]*)\)\s*{\s*$',
+ line)
+ if match:
+ if len(match.group(2)) != len(match.group(4)):
+ if not (match.group(3) == ';' and
+ len(match.group(2)) == 1 + len(match.group(4)) or
+ not match.group(2) and Search(r'\bfor\s*\(.*; \)', line)):
+ error(filename, linenum, 'whitespace/parens', 5,
+ 'Mismatching spaces inside () in %s' % match.group(1))
+ if len(match.group(2)) not in [0, 1]:
+ error(filename, linenum, 'whitespace/parens', 5,
+ 'Should have zero or one spaces inside ( and ) in %s' %
+ match.group(1))
+
+
+def CheckCommaSpacing(filename, clean_lines, linenum, error):
+ """Checks for horizontal spacing near commas and semicolons.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ raw = clean_lines.lines_without_raw_strings
+ line = clean_lines.elided[linenum]
+
+ # You should always have a space after a comma (either as fn arg or operator)
+ #
+ # This does not apply when the non-space character following the
+ # comma is another comma, since the only time when that happens is
+ # for empty macro arguments.
+ #
+ # We run this check in two passes: first pass on elided lines to
+ # verify that lines contain missing whitespaces, second pass on raw
+ # lines to confirm that those missing whitespaces are not due to
+ # elided comments.
+ if (Search(r',[^,\s]', ReplaceAll(r'\boperator\s*,\s*\(', 'F(', line)) and
+ Search(r',[^,\s]', raw[linenum])):
+ error(filename, linenum, 'whitespace/comma', 3,
+ 'Missing space after ,')
+
+ # You should always have a space after a semicolon
+ # except for few corner cases
+ # TODO(unknown): clarify if 'if (1) { return 1;}' is requires one more
+ # space after ;
+ if Search(r';[^\s};\\)/]', line):
+ error(filename, linenum, 'whitespace/semicolon', 3,
+ 'Missing space after ;')
+
+
+def _IsType(clean_lines, nesting_state, expr):
+ """Check if expression looks like a type name, returns true if so.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ expr: The expression to check.
+ Returns:
+ True, if token looks like a type.
+ """
+ # Keep only the last token in the expression
+ last_word = Match(r'^.*(\b\S+)$', expr)
+ if last_word:
+ token = last_word.group(1)
+ else:
+ token = expr
+
+ # Match native types and stdint types
+ if _TYPES.match(token):
+ return True
+
+ # Try a bit harder to match templated types. Walk up the nesting
+ # stack until we find something that resembles a typename
+ # declaration for what we are looking for.
+ typename_pattern = (r'\b(?:typename|class|struct)\s+' + re.escape(token) +
+ r'\b')
+ block_index = len(nesting_state.stack) - 1
+ while block_index >= 0:
+ if isinstance(nesting_state.stack[block_index], _NamespaceInfo):
+ return False
+
+ # Found where the opening brace is. We want to scan from this
+ # line up to the beginning of the function, minus a few lines.
+ # template <typename Type1, // stop scanning here
+ # ...>
+ # class C
+ # : public ... { // start scanning here
+ last_line = nesting_state.stack[block_index].starting_linenum
+
+ next_block_start = 0
+ if block_index > 0:
+ next_block_start = nesting_state.stack[block_index - 1].starting_linenum
+ first_line = last_line
+ while first_line >= next_block_start:
+ if clean_lines.elided[first_line].find('template') >= 0:
+ break
+ first_line -= 1
+ if first_line < next_block_start:
+ # Didn't find any "template" keyword before reaching the next block,
+ # there are probably no template things to check for this block
+ block_index -= 1
+ continue
+
+ # Look for typename in the specified range
+ for i in xrange(first_line, last_line + 1, 1):
+ if Search(typename_pattern, clean_lines.elided[i]):
+ return True
+ block_index -= 1
+
+ return False
+
+
+def CheckBracesSpacing(filename, clean_lines, linenum, nesting_state, error):
+ """Checks for horizontal spacing near commas.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Except after an opening paren, or after another opening brace (in case of
+ # an initializer list, for instance), you should have spaces before your
+ # braces when they are delimiting blocks, classes, namespaces etc.
+ # And since you should never have braces at the beginning of a line,
+ # this is an easy test. Except that braces used for initialization don't
+ # follow the same rule; we often don't want spaces before those.
+ match = Match(r'^(.*[^ ({>]){', line)
+
+ if match:
+ # Try a bit harder to check for brace initialization. This
+ # happens in one of the following forms:
+ # Constructor() : initializer_list_{} { ... }
+ # Constructor{}.MemberFunction()
+ # Type variable{};
+ # FunctionCall(type{}, ...);
+ # LastArgument(..., type{});
+ # LOG(INFO) << type{} << " ...";
+ # map_of_type[{...}] = ...;
+ # ternary = expr ? new type{} : nullptr;
+ # OuterTemplate<InnerTemplateConstructor<Type>{}>
+ #
+ # We check for the character following the closing brace, and
+ # silence the warning if it's one of those listed above, i.e.
+ # "{.;,)<>]:".
+ #
+ # To account for nested initializer list, we allow any number of
+ # closing braces up to "{;,)<". We can't simply silence the
+ # warning on first sight of closing brace, because that would
+ # cause false negatives for things that are not initializer lists.
+ # Silence this: But not this:
+ # Outer{ if (...) {
+ # Inner{...} if (...){ // Missing space before {
+ # }; }
+ #
+ # There is a false negative with this approach if people inserted
+ # spurious semicolons, e.g. "if (cond){};", but we will catch the
+ # spurious semicolon with a separate check.
+ leading_text = match.group(1)
+ (endline, endlinenum, endpos) = CloseExpression(
+ clean_lines, linenum, len(match.group(1)))
+ trailing_text = ''
+ if endpos > -1:
+ trailing_text = endline[endpos:]
+ for offset in xrange(endlinenum + 1,
+ min(endlinenum + 3, clean_lines.NumLines() - 1)):
+ trailing_text += clean_lines.elided[offset]
+ # We also suppress warnings for `uint64_t{expression}` etc., as the style
+ # guide recommends brace initialization for integral types to avoid
+ # overflow/truncation.
+ if (not Match(r'^[\s}]*[{.;,)<>\]:]', trailing_text)
+ and not _IsType(clean_lines, nesting_state, leading_text)):
+ error(filename, linenum, 'whitespace/braces', 5,
+ 'Missing space before {')
+
+ # Make sure '} else {' has spaces.
+ if Search(r'}else', line):
+ error(filename, linenum, 'whitespace/braces', 5,
+ 'Missing space before else')
+
+ # You shouldn't have a space before a semicolon at the end of the line.
+ # There's a special case for "for" since the style guide allows space before
+ # the semicolon there.
+ if Search(r':\s*;\s*$', line):
+ error(filename, linenum, 'whitespace/semicolon', 5,
+ 'Semicolon defining empty statement. Use {} instead.')
+ elif Search(r'^\s*;\s*$', line):
+ error(filename, linenum, 'whitespace/semicolon', 5,
+ 'Line contains only semicolon. If this should be an empty statement, '
+ 'use {} instead.')
+ elif (Search(r'\s+;\s*$', line) and
+ not Search(r'\bfor\b', line)):
+ error(filename, linenum, 'whitespace/semicolon', 5,
+ 'Extra space before last semicolon. If this should be an empty '
+ 'statement, use {} instead.')
+
+
+def IsDecltype(clean_lines, linenum, column):
+ """Check if the token ending on (linenum, column) is decltype().
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: the number of the line to check.
+ column: end column of the token to check.
+ Returns:
+ True if this token is decltype() expression, False otherwise.
+ """
+ (text, _, start_col) = ReverseCloseExpression(clean_lines, linenum, column)
+ if start_col < 0:
+ return False
+ if Search(r'\bdecltype\s*$', text[0:start_col]):
+ return True
+ return False
+
+def CheckSectionSpacing(filename, clean_lines, class_info, linenum, error):
+ """Checks for additional blank line issues related to sections.
+
+ Currently the only thing checked here is blank line before protected/private.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ class_info: A _ClassInfo objects.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ # Skip checks if the class is small, where small means 25 lines or less.
+ # 25 lines seems like a good cutoff since that's the usual height of
+ # terminals, and any class that can't fit in one screen can't really
+ # be considered "small".
+ #
+ # Also skip checks if we are on the first line. This accounts for
+ # classes that look like
+ # class Foo { public: ... };
+ #
+ # If we didn't find the end of the class, last_line would be zero,
+ # and the check will be skipped by the first condition.
+ if (class_info.last_line - class_info.starting_linenum <= 24 or
+ linenum <= class_info.starting_linenum):
+ return
+
+ matched = Match(r'\s*(public|protected|private):', clean_lines.lines[linenum])
+ if matched:
+ # Issue warning if the line before public/protected/private was
+ # not a blank line, but don't do this if the previous line contains
+ # "class" or "struct". This can happen two ways:
+ # - We are at the beginning of the class.
+ # - We are forward-declaring an inner class that is semantically
+ # private, but needed to be public for implementation reasons.
+ # Also ignores cases where the previous line ends with a backslash as can be
+ # common when defining classes in C macros.
+ prev_line = clean_lines.lines[linenum - 1]
+ if (not IsBlankLine(prev_line) and
+ not Search(r'\b(class|struct)\b', prev_line) and
+ not Search(r'\\$', prev_line)):
+ # Try a bit harder to find the beginning of the class. This is to
+ # account for multi-line base-specifier lists, e.g.:
+ # class Derived
+ # : public Base {
+ end_class_head = class_info.starting_linenum
+ for i in range(class_info.starting_linenum, linenum):
+ if Search(r'\{\s*$', clean_lines.lines[i]):
+ end_class_head = i
+ break
+ if end_class_head < linenum - 1:
+ error(filename, linenum, 'whitespace/blank_line', 3,
+ '"%s:" should be preceded by a blank line' % matched.group(1))
+
+
+def GetPreviousNonBlankLine(clean_lines, linenum):
+ """Return the most recent non-blank line and its line number.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file contents.
+ linenum: The number of the line to check.
+
+ Returns:
+ A tuple with two elements. The first element is the contents of the last
+ non-blank line before the current line, or the empty string if this is the
+ first non-blank line. The second is the line number of that line, or -1
+ if this is the first non-blank line.
+ """
+
+ prevlinenum = linenum - 1
+ while prevlinenum >= 0:
+ prevline = clean_lines.elided[prevlinenum]
+ if not IsBlankLine(prevline): # if not a blank line...
+ return (prevline, prevlinenum)
+ prevlinenum -= 1
+ return ('', -1)
+
+
+def CheckBraces(filename, clean_lines, linenum, error):
+ """Looks for misplaced braces (e.g. at the end of line).
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+
+ line = clean_lines.elided[linenum] # get rid of comments and strings
+
+ if Match(r'\s*{\s*$', line):
+ # We allow an open brace to start a line in the case where someone is using
+ # braces in a block to explicitly create a new scope, which is commonly used
+ # to control the lifetime of stack-allocated variables. Braces are also
+ # used for brace initializers inside function calls. We don't detect this
+ # perfectly: we just don't complain if the last non-whitespace character on
+ # the previous non-blank line is ',', ';', ':', '(', '{', or '}', or if the
+ # previous line starts a preprocessor block. We also allow a brace on the
+ # following line if it is part of an array initialization and would not fit
+ # within the 80 character limit of the preceding line.
+ prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0]
+ if (not Search(r'[,;:}{(]\s*$', prevline) and
+ not Match(r'\s*#', prevline) and
+ not (GetLineWidth(prevline) > _line_length - 2 and '[]' in prevline)):
+ error(filename, linenum, 'whitespace/braces', 4,
+ '{ should almost always be at the end of the previous line')
+
+ # An else clause should be on the same line as the preceding closing brace.
+ if Match(r'\s*else\b\s*(?:if\b|\{|$)', line):
+ prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0]
+ if Match(r'\s*}\s*$', prevline):
+ error(filename, linenum, 'whitespace/newline', 4,
+ 'An else should appear on the same line as the preceding }')
+
+ # If braces come on one side of an else, they should be on both.
+ # However, we have to worry about "else if" that spans multiple lines!
+ if Search(r'else if\s*\(', line): # could be multi-line if
+ brace_on_left = bool(Search(r'}\s*else if\s*\(', line))
+ # find the ( after the if
+ pos = line.find('else if')
+ pos = line.find('(', pos)
+ if pos > 0:
+ (endline, _, endpos) = CloseExpression(clean_lines, linenum, pos)
+ brace_on_right = endline[endpos:].find('{') != -1
+ if brace_on_left != brace_on_right: # must be brace after if
+ error(filename, linenum, 'readability/braces', 5,
+ 'If an else has a brace on one side, it should have it on both')
+ elif Search(r'}\s*else[^{]*$', line) or Match(r'[^}]*else\s*{', line):
+ error(filename, linenum, 'readability/braces', 5,
+ 'If an else has a brace on one side, it should have it on both')
+
+ # Likewise, an else should never have the else clause on the same line
+ if Search(r'\belse [^\s{]', line) and not Search(r'\belse if\b', line):
+ error(filename, linenum, 'whitespace/newline', 4,
+ 'Else clause should never be on same line as else (use 2 lines)')
+
+ # In the same way, a do/while should never be on one line
+ if Match(r'\s*do [^\s{]', line):
+ error(filename, linenum, 'whitespace/newline', 4,
+ 'do/while clauses should not be on a single line')
+
+ # Check single-line if/else bodies. The style guide says 'curly braces are not
+ # required for single-line statements'. We additionally allow multi-line,
+ # single statements, but we reject anything with more than one semicolon in
+ # it. This means that the first semicolon after the if should be at the end of
+ # its line, and the line after that should have an indent level equal to or
+ # lower than the if. We also check for ambiguous if/else nesting without
+ # braces.
+ if_else_match = Search(r'\b(if\s*\(|else\b)', line)
+ if if_else_match and not Match(r'\s*#', line):
+ if_indent = GetIndentLevel(line)
+ endline, endlinenum, endpos = line, linenum, if_else_match.end()
+ if_match = Search(r'\bif\s*\(', line)
+ if if_match:
+ # This could be a multiline if condition, so find the end first.
+ pos = if_match.end() - 1
+ (endline, endlinenum, endpos) = CloseExpression(clean_lines, linenum, pos)
+ # Check for an opening brace, either directly after the if or on the next
+ # line. If found, this isn't a single-statement conditional.
+ if (not Match(r'\s*{', endline[endpos:])
+ and not (Match(r'\s*$', endline[endpos:])
+ and endlinenum < (len(clean_lines.elided) - 1)
+ and Match(r'\s*{', clean_lines.elided[endlinenum + 1]))):
+ while (endlinenum < len(clean_lines.elided)
+ and ';' not in clean_lines.elided[endlinenum][endpos:]):
+ endlinenum += 1
+ endpos = 0
+ if endlinenum < len(clean_lines.elided):
+ endline = clean_lines.elided[endlinenum]
+ # We allow a mix of whitespace and closing braces (e.g. for one-liner
+ # methods) and a single \ after the semicolon (for macros)
+ endpos = endline.find(';')
+ if not Match(r';[\s}]*(\\?)$', endline[endpos:]):
+ # Semicolon isn't the last character, there's something trailing.
+ # Output a warning if the semicolon is not contained inside
+ # a lambda expression.
+ if not Match(r'^[^{};]*\[[^\[\]]*\][^{}]*\{[^{}]*\}\s*\)*[;,]\s*$',
+ endline):
+ error(filename, linenum, 'readability/braces', 4,
+ 'If/else bodies with multiple statements require braces')
+ elif endlinenum < len(clean_lines.elided) - 1:
+ # Make sure the next line is dedented
+ next_line = clean_lines.elided[endlinenum + 1]
+ next_indent = GetIndentLevel(next_line)
+ # With ambiguous nested if statements, this will error out on the
+ # if that *doesn't* match the else, regardless of whether it's the
+ # inner one or outer one.
+ if (if_match and Match(r'\s*else\b', next_line)
+ and next_indent != if_indent):
+ error(filename, linenum, 'readability/braces', 4,
+ 'Else clause should be indented at the same level as if. '
+ 'Ambiguous nested if/else chains require braces.')
+ elif next_indent > if_indent:
+ error(filename, linenum, 'readability/braces', 4,
+ 'If/else bodies with multiple statements require braces')
+
+
+def CheckTrailingSemicolon(filename, clean_lines, linenum, error):
+ """Looks for redundant trailing semicolon.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+
+ line = clean_lines.elided[linenum]
+
+ # Block bodies should not be followed by a semicolon. Due to C++11
+ # brace initialization, there are more places where semicolons are
+ # required than not, so we use explicitly list the allowed rules
+ # rather than listing the disallowed ones. These are the places
+ # where "};" should be replaced by just "}":
+ # 1. Some flavor of block following closing parenthesis:
+ # for (;;) {};
+ # while (...) {};
+ # switch (...) {};
+ # Function(...) {};
+ # if (...) {};
+ # if (...) else if (...) {};
+ #
+ # 2. else block:
+ # if (...) else {};
+ #
+ # 3. const member function:
+ # Function(...) const {};
+ #
+ # 4. Block following some statement:
+ # x = 42;
+ # {};
+ #
+ # 5. Block at the beginning of a function:
+ # Function(...) {
+ # {};
+ # }
+ #
+ # Note that naively checking for the preceding "{" will also match
+ # braces inside multi-dimensional arrays, but this is fine since
+ # that expression will not contain semicolons.
+ #
+ # 6. Block following another block:
+ # while (true) {}
+ # {};
+ #
+ # 7. End of namespaces:
+ # namespace {};
+ #
+ # These semicolons seems far more common than other kinds of
+ # redundant semicolons, possibly due to people converting classes
+ # to namespaces. For now we do not warn for this case.
+ #
+ # Try matching case 1 first.
+ match = Match(r'^(.*\)\s*)\{', line)
+ if match:
+ # Matched closing parenthesis (case 1). Check the token before the
+ # matching opening parenthesis, and don't warn if it looks like a
+ # macro. This avoids these false positives:
+ # - macro that defines a base class
+ # - multi-line macro that defines a base class
+ # - macro that defines the whole class-head
+ #
+ # But we still issue warnings for macros that we know are safe to
+ # warn, specifically:
+ # - TEST, TEST_F, TEST_P, MATCHER, MATCHER_P
+ # - TYPED_TEST
+ # - INTERFACE_DEF
+ # - EXCLUSIVE_LOCKS_REQUIRED, SHARED_LOCKS_REQUIRED, LOCKS_EXCLUDED:
+ #
+ # We implement a list of safe macros instead of a list of
+ # unsafe macros, even though the latter appears less frequently in
+ # google code and would have been easier to implement. This is because
+ # the downside for getting the allowed checks wrong means some extra
+ # semicolons, while the downside for getting disallowed checks wrong
+ # would result in compile errors.
+ #
+ # In addition to macros, we also don't want to warn on
+ # - Compound literals
+ # - Lambdas
+ # - alignas specifier with anonymous structs
+ # - decltype
+ closing_brace_pos = match.group(1).rfind(')')
+ opening_parenthesis = ReverseCloseExpression(
+ clean_lines, linenum, closing_brace_pos)
+ if opening_parenthesis[2] > -1:
+ line_prefix = opening_parenthesis[0][0:opening_parenthesis[2]]
+ macro = Search(r'\b([A-Z_][A-Z0-9_]*)\s*$', line_prefix)
+ func = Match(r'^(.*\])\s*$', line_prefix)
+ if ((macro and
+ macro.group(1) not in (
+ 'TEST', 'TEST_F', 'MATCHER', 'MATCHER_P', 'TYPED_TEST',
+ 'EXCLUSIVE_LOCKS_REQUIRED', 'SHARED_LOCKS_REQUIRED',
+ 'LOCKS_EXCLUDED', 'INTERFACE_DEF')) or
+ (func and not Search(r'\boperator\s*\[\s*\]', func.group(1))) or
+ Search(r'\b(?:struct|union)\s+alignas\s*$', line_prefix) or
+ Search(r'\bdecltype$', line_prefix) or
+ Search(r'\s+=\s*$', line_prefix)):
+ match = None
+ if (match and
+ opening_parenthesis[1] > 1 and
+ Search(r'\]\s*$', clean_lines.elided[opening_parenthesis[1] - 1])):
+ # Multi-line lambda-expression
+ match = None
+
+ else:
+ # Try matching cases 2-3.
+ match = Match(r'^(.*(?:else|\)\s*const)\s*)\{', line)
+ if not match:
+ # Try matching cases 4-6. These are always matched on separate lines.
+ #
+ # Note that we can't simply concatenate the previous line to the
+ # current line and do a single match, otherwise we may output
+ # duplicate warnings for the blank line case:
+ # if (cond) {
+ # // blank line
+ # }
+ prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0]
+ if prevline and Search(r'[;{}]\s*$', prevline):
+ match = Match(r'^(\s*)\{', line)
+
+ # Check matching closing brace
+ if match:
+ (endline, endlinenum, endpos) = CloseExpression(
+ clean_lines, linenum, len(match.group(1)))
+ if endpos > -1 and Match(r'^\s*;', endline[endpos:]):
+ # Current {} pair is eligible for semicolon check, and we have found
+ # the redundant semicolon, output warning here.
+ #
+ # Note: because we are scanning forward for opening braces, and
+ # outputting warnings for the matching closing brace, if there are
+ # nested blocks with trailing semicolons, we will get the error
+ # messages in reversed order.
+
+ # We need to check the line forward for NOLINT
+ raw_lines = clean_lines.raw_lines
+ ParseNolintSuppressions(filename, raw_lines[endlinenum-1], endlinenum-1,
+ error)
+ ParseNolintSuppressions(filename, raw_lines[endlinenum], endlinenum,
+ error)
+
+ error(filename, endlinenum, 'readability/braces', 4,
+ "You don't need a ; after a }")
+
+
+def CheckEmptyBlockBody(filename, clean_lines, linenum, error):
+ """Look for empty loop/conditional body with only a single semicolon.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+
+ # Search for loop keywords at the beginning of the line. Because only
+ # whitespaces are allowed before the keywords, this will also ignore most
+ # do-while-loops, since those lines should start with closing brace.
+ #
+ # We also check "if" blocks here, since an empty conditional block
+ # is likely an error.
+ line = clean_lines.elided[linenum]
+ matched = Match(r'\s*(for|while|if)\s*\(', line)
+ if matched:
+ # Find the end of the conditional expression.
+ (end_line, end_linenum, end_pos) = CloseExpression(
+ clean_lines, linenum, line.find('('))
+
+ # Output warning if what follows the condition expression is a semicolon.
+ # No warning for all other cases, including whitespace or newline, since we
+ # have a separate check for semicolons preceded by whitespace.
+ if end_pos >= 0 and Match(r';', end_line[end_pos:]):
+ if matched.group(1) == 'if':
+ error(filename, end_linenum, 'whitespace/empty_conditional_body', 5,
+ 'Empty conditional bodies should use {}')
+ else:
+ error(filename, end_linenum, 'whitespace/empty_loop_body', 5,
+ 'Empty loop bodies should use {} or continue')
+
+ # Check for if statements that have completely empty bodies (no comments)
+ # and no else clauses.
+ if end_pos >= 0 and matched.group(1) == 'if':
+ # Find the position of the opening { for the if statement.
+ # Return without logging an error if it has no brackets.
+ opening_linenum = end_linenum
+ opening_line_fragment = end_line[end_pos:]
+ # Loop until EOF or find anything that's not whitespace or opening {.
+ while not Search(r'^\s*\{', opening_line_fragment):
+ if Search(r'^(?!\s*$)', opening_line_fragment):
+ # Conditional has no brackets.
+ return
+ opening_linenum += 1
+ if opening_linenum == len(clean_lines.elided):
+ # Couldn't find conditional's opening { or any code before EOF.
+ return
+ opening_line_fragment = clean_lines.elided[opening_linenum]
+ # Set opening_line (opening_line_fragment may not be entire opening line).
+ opening_line = clean_lines.elided[opening_linenum]
+
+ # Find the position of the closing }.
+ opening_pos = opening_line_fragment.find('{')
+ if opening_linenum == end_linenum:
+ # We need to make opening_pos relative to the start of the entire line.
+ opening_pos += end_pos
+ (closing_line, closing_linenum, closing_pos) = CloseExpression(
+ clean_lines, opening_linenum, opening_pos)
+ if closing_pos < 0:
+ return
+
+ # Now construct the body of the conditional. This consists of the portion
+ # of the opening line after the {, all lines until the closing line,
+ # and the portion of the closing line before the }.
+ if (clean_lines.raw_lines[opening_linenum] !=
+ CleanseComments(clean_lines.raw_lines[opening_linenum])):
+ # Opening line ends with a comment, so conditional isn't empty.
+ return
+ if closing_linenum > opening_linenum:
+ # Opening line after the {. Ignore comments here since we checked above.
+ bodylist = list(opening_line[opening_pos+1:])
+ # All lines until closing line, excluding closing line, with comments.
+ bodylist.extend(clean_lines.raw_lines[opening_linenum+1:closing_linenum])
+ # Closing line before the }. Won't (and can't) have comments.
+ bodylist.append(clean_lines.elided[closing_linenum][:closing_pos-1])
+ body = '\n'.join(bodylist)
+ else:
+ # If statement has brackets and fits on a single line.
+ body = opening_line[opening_pos+1:closing_pos-1]
+
+ # Check if the body is empty
+ if not _EMPTY_CONDITIONAL_BODY_PATTERN.search(body):
+ return
+ # The body is empty. Now make sure there's not an else clause.
+ current_linenum = closing_linenum
+ current_line_fragment = closing_line[closing_pos:]
+ # Loop until EOF or find anything that's not whitespace or else clause.
+ while Search(r'^\s*$|^(?=\s*else)', current_line_fragment):
+ if Search(r'^(?=\s*else)', current_line_fragment):
+ # Found an else clause, so don't log an error.
+ return
+ current_linenum += 1
+ if current_linenum == len(clean_lines.elided):
+ break
+ current_line_fragment = clean_lines.elided[current_linenum]
+
+ # The body is empty and there's no else clause until EOF or other code.
+ error(filename, end_linenum, 'whitespace/empty_if_body', 4,
+ ('If statement had no body and no else clause'))
+
+
+def FindCheckMacro(line):
+ """Find a replaceable CHECK-like macro.
+
+ Args:
+ line: line to search on.
+ Returns:
+ (macro name, start position), or (None, -1) if no replaceable
+ macro is found.
+ """
+ for macro in _CHECK_MACROS:
+ i = line.find(macro)
+ if i >= 0:
+ # Find opening parenthesis. Do a regular expression match here
+ # to make sure that we are matching the expected CHECK macro, as
+ # opposed to some other macro that happens to contain the CHECK
+ # substring.
+ matched = Match(r'^(.*\b' + macro + r'\s*)\(', line)
+ if not matched:
+ continue
+ return (macro, len(matched.group(1)))
+ return (None, -1)
+
+
+def CheckCheck(filename, clean_lines, linenum, error):
+ """Checks the use of CHECK and EXPECT macros.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+
+ # Decide the set of replacement macros that should be suggested
+ lines = clean_lines.elided
+ (check_macro, start_pos) = FindCheckMacro(lines[linenum])
+ if not check_macro:
+ return
+
+ # Find end of the boolean expression by matching parentheses
+ (last_line, end_line, end_pos) = CloseExpression(
+ clean_lines, linenum, start_pos)
+ if end_pos < 0:
+ return
+
+ # If the check macro is followed by something other than a
+ # semicolon, assume users will log their own custom error messages
+ # and don't suggest any replacements.
+ if not Match(r'\s*;', last_line[end_pos:]):
+ return
+
+ if linenum == end_line:
+ expression = lines[linenum][start_pos + 1:end_pos - 1]
+ else:
+ expression = lines[linenum][start_pos + 1:]
+ for i in xrange(linenum + 1, end_line):
+ expression += lines[i]
+ expression += last_line[0:end_pos - 1]
+
+ # Parse expression so that we can take parentheses into account.
+ # This avoids false positives for inputs like "CHECK((a < 4) == b)",
+ # which is not replaceable by CHECK_LE.
+ lhs = ''
+ rhs = ''
+ operator = None
+ while expression:
+ matched = Match(r'^\s*(<<|<<=|>>|>>=|->\*|->|&&|\|\||'
+ r'==|!=|>=|>|<=|<|\()(.*)$', expression)
+ if matched:
+ token = matched.group(1)
+ if token == '(':
+ # Parenthesized operand
+ expression = matched.group(2)
+ (end, _) = FindEndOfExpressionInLine(expression, 0, ['('])
+ if end < 0:
+ return # Unmatched parenthesis
+ lhs += '(' + expression[0:end]
+ expression = expression[end:]
+ elif token in ('&&', '||'):
+ # Logical and/or operators. This means the expression
+ # contains more than one term, for example:
+ # CHECK(42 < a && a < b);
+ #
+ # These are not replaceable with CHECK_LE, so bail out early.
+ return
+ elif token in ('<<', '<<=', '>>', '>>=', '->*', '->'):
+ # Non-relational operator
+ lhs += token
+ expression = matched.group(2)
+ else:
+ # Relational operator
+ operator = token
+ rhs = matched.group(2)
+ break
+ else:
+ # Unparenthesized operand. Instead of appending to lhs one character
+ # at a time, we do another regular expression match to consume several
+ # characters at once if possible. Trivial benchmark shows that this
+ # is more efficient when the operands are longer than a single
+ # character, which is generally the case.
+ matched = Match(r'^([^-=!<>()&|]+)(.*)$', expression)
+ if not matched:
+ matched = Match(r'^(\s*\S)(.*)$', expression)
+ if not matched:
+ break
+ lhs += matched.group(1)
+ expression = matched.group(2)
+
+ # Only apply checks if we got all parts of the boolean expression
+ if not (lhs and operator and rhs):
+ return
+
+ # Check that rhs do not contain logical operators. We already know
+ # that lhs is fine since the loop above parses out && and ||.
+ if rhs.find('&&') > -1 or rhs.find('||') > -1:
+ return
+
+ # At least one of the operands must be a constant literal. This is
+ # to avoid suggesting replacements for unprintable things like
+ # CHECK(variable != iterator)
+ #
+ # The following pattern matches decimal, hex integers, strings, and
+ # characters (in that order).
+ lhs = lhs.strip()
+ rhs = rhs.strip()
+ match_constant = r'^([-+]?(\d+|0[xX][0-9a-fA-F]+)[lLuU]{0,3}|".*"|\'.*\')$'
+ if Match(match_constant, lhs) or Match(match_constant, rhs):
+ # Note: since we know both lhs and rhs, we can provide a more
+ # descriptive error message like:
+ # Consider using CHECK_EQ(x, 42) instead of CHECK(x == 42)
+ # Instead of:
+ # Consider using CHECK_EQ instead of CHECK(a == b)
+ #
+ # We are still keeping the less descriptive message because if lhs
+ # or rhs gets long, the error message might become unreadable.
+ error(filename, linenum, 'readability/check', 2,
+ 'Consider using %s instead of %s(a %s b)' % (
+ _CHECK_REPLACEMENT[check_macro][operator],
+ check_macro, operator))
+
+
+def CheckAltTokens(filename, clean_lines, linenum, error):
+ """Check alternative keywords being used in boolean expressions.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Avoid preprocessor lines
+ if Match(r'^\s*#', line):
+ return
+
+ # Last ditch effort to avoid multi-line comments. This will not help
+ # if the comment started before the current line or ended after the
+ # current line, but it catches most of the false positives. At least,
+ # it provides a way to workaround this warning for people who use
+ # multi-line comments in preprocessor macros.
+ #
+ # TODO(unknown): remove this once cpplint has better support for
+ # multi-line comments.
+ if line.find('/*') >= 0 or line.find('*/') >= 0:
+ return
+
+ for match in _ALT_TOKEN_REPLACEMENT_PATTERN.finditer(line):
+ error(filename, linenum, 'readability/alt_tokens', 2,
+ 'Use operator %s instead of %s' % (
+ _ALT_TOKEN_REPLACEMENT[match.group(1)], match.group(1)))
+
+
+def GetLineWidth(line):
+ """Determines the width of the line in column positions.
+
+ Args:
+ line: A string, which may be a Unicode string.
+
+ Returns:
+ The width of the line in column positions, accounting for Unicode
+ combining characters and wide characters.
+ """
+ if isinstance(line, unicode):
+ width = 0
+ for uc in unicodedata.normalize('NFC', line):
+ if unicodedata.east_asian_width(uc) in ('W', 'F'):
+ width += 2
+ elif not unicodedata.combining(uc):
+ width += 1
+ return width
+ else:
+ return len(line)
+
+
+def CheckStyle(filename, clean_lines, linenum, file_extension, nesting_state,
+ error):
+ """Checks rules from the 'C++ style rules' section of cppguide.html.
+
+ Most of these rules are hard to test (naming, comment style), but we
+ do what we can. In particular we check for 2-space indents, line lengths,
+ tab usage, spaces inside code, etc.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ file_extension: The extension (without the dot) of the filename.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ error: The function to call with any errors found.
+ """
+
+ # Don't use "elided" lines here, otherwise we can't check commented lines.
+ # Don't want to use "raw" either, because we don't want to check inside C++11
+ # raw strings,
+ raw_lines = clean_lines.lines_without_raw_strings
+ line = raw_lines[linenum]
+ prev = raw_lines[linenum - 1] if linenum > 0 else ''
+
+ if line.find('\t') != -1:
+ error(filename, linenum, 'whitespace/tab', 1,
+ 'Tab found; better to use spaces')
+
+ # One or three blank spaces at the beginning of the line is weird; it's
+ # hard to reconcile that with 2-space indents.
+ # NOTE: here are the conditions rob pike used for his tests. Mine aren't
+ # as sophisticated, but it may be worth becoming so: RLENGTH==initial_spaces
+ # if(RLENGTH > 20) complain = 0;
+ # if(match($0, " +(error|private|public|protected):")) complain = 0;
+ # if(match(prev, "&& *$")) complain = 0;
+ # if(match(prev, "\\|\\| *$")) complain = 0;
+ # if(match(prev, "[\",=><] *$")) complain = 0;
+ # if(match($0, " <<")) complain = 0;
+ # if(match(prev, " +for \\(")) complain = 0;
+ # if(prevodd && match(prevprev, " +for \\(")) complain = 0;
+ scope_or_label_pattern = r'\s*\w+\s*:\s*\\?$'
+ classinfo = nesting_state.InnermostClass()
+ initial_spaces = 0
+ cleansed_line = clean_lines.elided[linenum]
+ while initial_spaces < len(line) and line[initial_spaces] == ' ':
+ initial_spaces += 1
+ # There are certain situations we allow one space, notably for
+ # section labels, and also lines containing multi-line raw strings.
+ # We also don't check for lines that look like continuation lines
+ # (of lines ending in double quotes, commas, equals, or angle brackets)
+ # because the rules for how to indent those are non-trivial.
+ if (not Search(r'[",=><] *$', prev) and
+ (initial_spaces == 1 or initial_spaces == 3) and
+ not Match(scope_or_label_pattern, cleansed_line) and
+ not (clean_lines.raw_lines[linenum] != line and
+ Match(r'^\s*""', line))):
+ error(filename, linenum, 'whitespace/indent', 3,
+ 'Weird number of spaces at line-start. '
+ 'Are you using a 2-space indent?')
+
+ if line and line[-1].isspace():
+ error(filename, linenum, 'whitespace/end_of_line', 4,
+ 'Line ends in whitespace. Consider deleting these extra spaces.')
+
+ # Check if the line is a header guard.
+ is_header_guard = False
+ if file_extension in GetHeaderExtensions():
+ cppvar = GetHeaderGuardCPPVariable(filename)
+ if (line.startswith('#ifndef %s' % cppvar) or
+ line.startswith('#define %s' % cppvar) or
+ line.startswith('#endif // %s' % cppvar)):
+ is_header_guard = True
+ # #include lines and header guards can be long, since there's no clean way to
+ # split them.
+ #
+ # URLs can be long too. It's possible to split these, but it makes them
+ # harder to cut&paste.
+ #
+ # The "$Id:...$" comment may also get very long without it being the
+ # developers fault.
+ #
+ # Doxygen documentation copying can get pretty long when using an overloaded
+ # function declaration
+ if (not line.startswith('#include') and not is_header_guard and
+ not Match(r'^\s*//.*http(s?)://\S*$', line) and
+ not Match(r'^\s*//\s*[^\s]*$', line) and
+ not Match(r'^// \$Id:.*#[0-9]+ \$$', line) and
+ not Match(r'^\s*/// [@\\](copydoc|copydetails|copybrief) .*$', line)):
+ line_width = GetLineWidth(line)
+ if line_width > _line_length:
+ error(filename, linenum, 'whitespace/line_length', 2,
+ 'Lines should be <= %i characters long' % _line_length)
+
+ if (cleansed_line.count(';') > 1 and
+ # allow simple single line lambdas
+ not Match(r'^[^{};]*\[[^\[\]]*\][^{}]*\{[^{}\n\r]*\}',
+ line) and
+ # for loops are allowed two ;'s (and may run over two lines).
+ cleansed_line.find('for') == -1 and
+ (GetPreviousNonBlankLine(clean_lines, linenum)[0].find('for') == -1 or
+ GetPreviousNonBlankLine(clean_lines, linenum)[0].find(';') != -1) and
+ # It's ok to have many commands in a switch case that fits in 1 line
+ not ((cleansed_line.find('case ') != -1 or
+ cleansed_line.find('default:') != -1) and
+ cleansed_line.find('break;') != -1)):
+ error(filename, linenum, 'whitespace/newline', 0,
+ 'More than one command on the same line')
+
+ # Some more style checks
+ CheckBraces(filename, clean_lines, linenum, error)
+ CheckTrailingSemicolon(filename, clean_lines, linenum, error)
+ CheckEmptyBlockBody(filename, clean_lines, linenum, error)
+ CheckAccess(filename, clean_lines, linenum, nesting_state, error)
+ CheckSpacing(filename, clean_lines, linenum, nesting_state, error)
+ CheckOperatorSpacing(filename, clean_lines, linenum, error)
+ CheckParenthesisSpacing(filename, clean_lines, linenum, error)
+ CheckCommaSpacing(filename, clean_lines, linenum, error)
+ CheckBracesSpacing(filename, clean_lines, linenum, nesting_state, error)
+ CheckSpacingForFunctionCall(filename, clean_lines, linenum, error)
+ CheckCheck(filename, clean_lines, linenum, error)
+ CheckAltTokens(filename, clean_lines, linenum, error)
+ classinfo = nesting_state.InnermostClass()
+ if classinfo:
+ CheckSectionSpacing(filename, clean_lines, classinfo, linenum, error)
+
+
+_RE_PATTERN_INCLUDE = re.compile(r'^\s*#\s*include\s*([<"])([^>"]*)[>"].*$')
+# Matches the first component of a filename delimited by -s and _s. That is:
+# _RE_FIRST_COMPONENT.match('foo').group(0) == 'foo'
+# _RE_FIRST_COMPONENT.match('foo.cc').group(0) == 'foo'
+# _RE_FIRST_COMPONENT.match('foo-bar_baz.cc').group(0) == 'foo'
+# _RE_FIRST_COMPONENT.match('foo_bar-baz.cc').group(0) == 'foo'
+_RE_FIRST_COMPONENT = re.compile(r'^[^-_.]+')
+
+
+def _DropCommonSuffixes(filename):
+ """Drops common suffixes like _test.cc or -inl.h from filename.
+
+ For example:
+ >>> _DropCommonSuffixes('foo/foo-inl.h')
+ 'foo/foo'
+ >>> _DropCommonSuffixes('foo/bar/foo.cc')
+ 'foo/bar/foo'
+ >>> _DropCommonSuffixes('foo/foo_internal.h')
+ 'foo/foo'
+ >>> _DropCommonSuffixes('foo/foo_unusualinternal.h')
+ 'foo/foo_unusualinternal'
+
+ Args:
+ filename: The input filename.
+
+ Returns:
+ The filename with the common suffix removed.
+ """
+ for suffix in itertools.chain(
+ ('%s.%s' % (test_suffix.lstrip('_'), ext)
+ for test_suffix, ext in itertools.product(_test_suffixes, GetNonHeaderExtensions())),
+ ('%s.%s' % (suffix, ext)
+ for suffix, ext in itertools.product(['inl', 'imp', 'internal'], GetHeaderExtensions()))):
+ if (filename.endswith(suffix) and len(filename) > len(suffix) and
+ filename[-len(suffix) - 1] in ('-', '_')):
+ return filename[:-len(suffix) - 1]
+ return os.path.splitext(filename)[0]
+
+
+def _ClassifyInclude(fileinfo, include, is_system):
+ """Figures out what kind of header 'include' is.
+
+ Args:
+ fileinfo: The current file cpplint is running over. A FileInfo instance.
+ include: The path to a #included file.
+ is_system: True if the #include used <> rather than "".
+
+ Returns:
+ One of the _XXX_HEADER constants.
+
+ For example:
+ >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'stdio.h', True)
+ _C_SYS_HEADER
+ >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'string', True)
+ _CPP_SYS_HEADER
+ >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'foo/foo.h', False)
+ _LIKELY_MY_HEADER
+ >>> _ClassifyInclude(FileInfo('foo/foo_unknown_extension.cc'),
+ ... 'bar/foo_other_ext.h', False)
+ _POSSIBLE_MY_HEADER
+ >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'foo/bar.h', False)
+ _OTHER_HEADER
+ """
+ # This is a list of all standard c++ header files, except
+ # those already checked for above.
+ is_cpp_h = include in _CPP_HEADERS
+
+ # Headers with C++ extensions shouldn't be considered C system headers
+ if is_system and os.path.splitext(include)[1] in ['.hpp', '.hxx', '.h++']:
+ is_system = False
+
+ if is_system:
+ if is_cpp_h:
+ return _CPP_SYS_HEADER
+ else:
+ return _C_SYS_HEADER
+
+ # If the target file and the include we're checking share a
+ # basename when we drop common extensions, and the include
+ # lives in . , then it's likely to be owned by the target file.
+ target_dir, target_base = (
+ os.path.split(_DropCommonSuffixes(fileinfo.RepositoryName())))
+ include_dir, include_base = os.path.split(_DropCommonSuffixes(include))
+ target_dir_pub = os.path.normpath(target_dir + '/../public')
+ target_dir_pub = target_dir_pub.replace('\\', '/')
+ if target_base == include_base and (
+ include_dir == target_dir or
+ include_dir == target_dir_pub):
+ return _LIKELY_MY_HEADER
+
+ # If the target and include share some initial basename
+ # component, it's possible the target is implementing the
+ # include, so it's allowed to be first, but we'll never
+ # complain if it's not there.
+ target_first_component = _RE_FIRST_COMPONENT.match(target_base)
+ include_first_component = _RE_FIRST_COMPONENT.match(include_base)
+ if (target_first_component and include_first_component and
+ target_first_component.group(0) ==
+ include_first_component.group(0)):
+ return _POSSIBLE_MY_HEADER
+
+ return _OTHER_HEADER
+
+
+
+def CheckIncludeLine(filename, clean_lines, linenum, include_state, error):
+ """Check rules that are applicable to #include lines.
+
+ Strings on #include lines are NOT removed from elided line, to make
+ certain tasks easier. However, to prevent false positives, checks
+ applicable to #include lines in CheckLanguage must be put here.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ include_state: An _IncludeState instance in which the headers are inserted.
+ error: The function to call with any errors found.
+ """
+ fileinfo = FileInfo(filename)
+ line = clean_lines.lines[linenum]
+
+ # "include" should use the new style "foo/bar.h" instead of just "bar.h"
+ # Only do this check if the included header follows google naming
+ # conventions. If not, assume that it's a 3rd party API that
+ # requires special include conventions.
+ #
+ # We also make an exception for Lua headers, which follow google
+ # naming convention but not the include convention.
+ match = Match(r'#include\s*"([^/]+\.h)"', line)
+ if match and not _THIRD_PARTY_HEADERS_PATTERN.match(match.group(1)):
+ error(filename, linenum, 'build/include_subdir', 4,
+ 'Include the directory when naming .h files')
+
+ # we shouldn't include a file more than once. actually, there are a
+ # handful of instances where doing so is okay, but in general it's
+ # not.
+ match = _RE_PATTERN_INCLUDE.search(line)
+ if match:
+ include = match.group(2)
+ is_system = (match.group(1) == '<')
+ duplicate_line = include_state.FindHeader(include)
+ if duplicate_line >= 0:
+ error(filename, linenum, 'build/include', 4,
+ '"%s" already included at %s:%s' %
+ (include, filename, duplicate_line))
+ return
+
+ for extension in GetNonHeaderExtensions():
+ if (include.endswith('.' + extension) and
+ os.path.dirname(fileinfo.RepositoryName()) != os.path.dirname(include)):
+ error(filename, linenum, 'build/include', 4,
+ 'Do not include .' + extension + ' files from other packages')
+ return
+
+ if not _THIRD_PARTY_HEADERS_PATTERN.match(include):
+ include_state.include_list[-1].append((include, linenum))
+
+ # We want to ensure that headers appear in the right order:
+ # 1) for foo.cc, foo.h (preferred location)
+ # 2) c system files
+ # 3) cpp system files
+ # 4) for foo.cc, foo.h (deprecated location)
+ # 5) other google headers
+ #
+ # We classify each include statement as one of those 5 types
+ # using a number of techniques. The include_state object keeps
+ # track of the highest type seen, and complains if we see a
+ # lower type after that.
+ error_message = include_state.CheckNextIncludeOrder(
+ _ClassifyInclude(fileinfo, include, is_system))
+ if error_message:
+ error(filename, linenum, 'build/include_order', 4,
+ '%s. Should be: %s.h, c system, c++ system, other.' %
+ (error_message, fileinfo.BaseName()))
+ canonical_include = include_state.CanonicalizeAlphabeticalOrder(include)
+ if not include_state.IsInAlphabeticalOrder(
+ clean_lines, linenum, canonical_include):
+ error(filename, linenum, 'build/include_alpha', 4,
+ 'Include "%s" not in alphabetical order' % include)
+ include_state.SetLastHeader(canonical_include)
+
+
+
+def _GetTextInside(text, start_pattern):
+ r"""Retrieves all the text between matching open and close parentheses.
+
+ Given a string of lines and a regular expression string, retrieve all the text
+ following the expression and between opening punctuation symbols like
+ (, [, or {, and the matching close-punctuation symbol. This properly nested
+ occurrences of the punctuations, so for the text like
+ printf(a(), b(c()));
+ a call to _GetTextInside(text, r'printf\(') will return 'a(), b(c())'.
+ start_pattern must match string having an open punctuation symbol at the end.
+
+ Args:
+ text: The lines to extract text. Its comments and strings must be elided.
+ It can be single line and can span multiple lines.
+ start_pattern: The regexp string indicating where to start extracting
+ the text.
+ Returns:
+ The extracted text.
+ None if either the opening string or ending punctuation could not be found.
+ """
+ # TODO(unknown): Audit cpplint.py to see what places could be profitably
+ # rewritten to use _GetTextInside (and use inferior regexp matching today).
+
+ # Give opening punctuations to get the matching close-punctuations.
+ matching_punctuation = {'(': ')', '{': '}', '[': ']'}
+ closing_punctuation = set(itervalues(matching_punctuation))
+
+ # Find the position to start extracting text.
+ match = re.search(start_pattern, text, re.M)
+ if not match: # start_pattern not found in text.
+ return None
+ start_position = match.end(0)
+
+ assert start_position > 0, (
+ 'start_pattern must ends with an opening punctuation.')
+ assert text[start_position - 1] in matching_punctuation, (
+ 'start_pattern must ends with an opening punctuation.')
+ # Stack of closing punctuations we expect to have in text after position.
+ punctuation_stack = [matching_punctuation[text[start_position - 1]]]
+ position = start_position
+ while punctuation_stack and position < len(text):
+ if text[position] == punctuation_stack[-1]:
+ punctuation_stack.pop()
+ elif text[position] in closing_punctuation:
+ # A closing punctuation without matching opening punctuations.
+ return None
+ elif text[position] in matching_punctuation:
+ punctuation_stack.append(matching_punctuation[text[position]])
+ position += 1
+ if punctuation_stack:
+ # Opening punctuations left without matching close-punctuations.
+ return None
+ # punctuations match.
+ return text[start_position:position - 1]
+
+
+# Patterns for matching call-by-reference parameters.
+#
+# Supports nested templates up to 2 levels deep using this messy pattern:
+# < (?: < (?: < [^<>]*
+# >
+# | [^<>] )*
+# >
+# | [^<>] )*
+# >
+_RE_PATTERN_IDENT = r'[_a-zA-Z]\w*' # =~ [[:alpha:]][[:alnum:]]*
+_RE_PATTERN_TYPE = (
+ r'(?:const\s+)?(?:typename\s+|class\s+|struct\s+|union\s+|enum\s+)?'
+ r'(?:\w|'
+ r'\s*<(?:<(?:<[^<>]*>|[^<>])*>|[^<>])*>|'
+ r'::)+')
+# A call-by-reference parameter ends with '& identifier'.
+_RE_PATTERN_REF_PARAM = re.compile(
+ r'(' + _RE_PATTERN_TYPE + r'(?:\s*(?:\bconst\b|[*]))*\s*'
+ r'&\s*' + _RE_PATTERN_IDENT + r')\s*(?:=[^,()]+)?[,)]')
+# A call-by-const-reference parameter either ends with 'const& identifier'
+# or looks like 'const type& identifier' when 'type' is atomic.
+_RE_PATTERN_CONST_REF_PARAM = (
+ r'(?:.*\s*\bconst\s*&\s*' + _RE_PATTERN_IDENT +
+ r'|const\s+' + _RE_PATTERN_TYPE + r'\s*&\s*' + _RE_PATTERN_IDENT + r')')
+# Stream types.
+_RE_PATTERN_REF_STREAM_PARAM = (
+ r'(?:.*stream\s*&\s*' + _RE_PATTERN_IDENT + r')')
+
+
+def CheckLanguage(filename, clean_lines, linenum, file_extension,
+ include_state, nesting_state, error):
+ """Checks rules from the 'C++ language rules' section of cppguide.html.
+
+ Some of these rules are hard to test (function overloading, using
+ uint32 inappropriately), but we do the best we can.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ file_extension: The extension (without the dot) of the filename.
+ include_state: An _IncludeState instance in which the headers are inserted.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ error: The function to call with any errors found.
+ """
+ # If the line is empty or consists of entirely a comment, no need to
+ # check it.
+ line = clean_lines.elided[linenum]
+ if not line:
+ return
+
+ match = _RE_PATTERN_INCLUDE.search(line)
+ if match:
+ CheckIncludeLine(filename, clean_lines, linenum, include_state, error)
+ return
+
+ # Reset include state across preprocessor directives. This is meant
+ # to silence warnings for conditional includes.
+ match = Match(r'^\s*#\s*(if|ifdef|ifndef|elif|else|endif)\b', line)
+ if match:
+ include_state.ResetSection(match.group(1))
+
+
+ # Perform other checks now that we are sure that this is not an include line
+ CheckCasts(filename, clean_lines, linenum, error)
+ CheckGlobalStatic(filename, clean_lines, linenum, error)
+ CheckPrintf(filename, clean_lines, linenum, error)
+
+ if file_extension in GetHeaderExtensions():
+ # TODO(unknown): check that 1-arg constructors are explicit.
+ # How to tell it's a constructor?
+ # (handled in CheckForNonStandardConstructs for now)
+ # TODO(unknown): check that classes declare or disable copy/assign
+ # (level 1 error)
+ pass
+
+ # Check if people are using the verboten C basic types. The only exception
+ # we regularly allow is "unsigned short port" for port.
+ if Search(r'\bshort port\b', line):
+ if not Search(r'\bunsigned short port\b', line):
+ error(filename, linenum, 'runtime/int', 4,
+ 'Use "unsigned short" for ports, not "short"')
+ else:
+ match = Search(r'\b(short|long(?! +double)|long long)\b', line)
+ if match:
+ error(filename, linenum, 'runtime/int', 4,
+ 'Use int16/int64/etc, rather than the C type %s' % match.group(1))
+
+ # Check if some verboten operator overloading is going on
+ # TODO(unknown): catch out-of-line unary operator&:
+ # class X {};
+ # int operator&(const X& x) { return 42; } // unary operator&
+ # The trick is it's hard to tell apart from binary operator&:
+ # class Y { int operator&(const Y& x) { return 23; } }; // binary operator&
+ if Search(r'\boperator\s*&\s*\(\s*\)', line):
+ error(filename, linenum, 'runtime/operator', 4,
+ 'Unary operator& is dangerous. Do not use it.')
+
+ # Check for suspicious usage of "if" like
+ # } if (a == b) {
+ if Search(r'\}\s*if\s*\(', line):
+ error(filename, linenum, 'readability/braces', 4,
+ 'Did you mean "else if"? If not, start a new line for "if".')
+
+ # Check for potential format string bugs like printf(foo).
+ # We constrain the pattern not to pick things like DocidForPrintf(foo).
+ # Not perfect but it can catch printf(foo.c_str()) and printf(foo->c_str())
+ # TODO(unknown): Catch the following case. Need to change the calling
+ # convention of the whole function to process multiple line to handle it.
+ # printf(
+ # boy_this_is_a_really_long_variable_that_cannot_fit_on_the_prev_line);
+ printf_args = _GetTextInside(line, r'(?i)\b(string)?printf\s*\(')
+ if printf_args:
+ match = Match(r'([\w.\->()]+)$', printf_args)
+ if match and match.group(1) != '__VA_ARGS__':
+ function_name = re.search(r'\b((?:string)?printf)\s*\(',
+ line, re.I).group(1)
+ error(filename, linenum, 'runtime/printf', 4,
+ 'Potential format string bug. Do %s("%%s", %s) instead.'
+ % (function_name, match.group(1)))
+
+ # Check for potential memset bugs like memset(buf, sizeof(buf), 0).
+ match = Search(r'memset\s*\(([^,]*),\s*([^,]*),\s*0\s*\)', line)
+ if match and not Match(r"^''|-?[0-9]+|0x[0-9A-Fa-f]$", match.group(2)):
+ error(filename, linenum, 'runtime/memset', 4,
+ 'Did you mean "memset(%s, 0, %s)"?'
+ % (match.group(1), match.group(2)))
+
+ if Search(r'\busing namespace\b', line):
+ if Search(r'\bliterals\b', line):
+ error(filename, linenum, 'build/namespaces_literals', 5,
+ 'Do not use namespace using-directives. '
+ 'Use using-declarations instead.')
+ else:
+ error(filename, linenum, 'build/namespaces', 5,
+ 'Do not use namespace using-directives. '
+ 'Use using-declarations instead.')
+
+ # Detect variable-length arrays.
+ match = Match(r'\s*(.+::)?(\w+) [a-z]\w*\[(.+)];', line)
+ if (match and match.group(2) != 'return' and match.group(2) != 'delete' and
+ match.group(3).find(']') == -1):
+ # Split the size using space and arithmetic operators as delimiters.
+ # If any of the resulting tokens are not compile time constants then
+ # report the error.
+ tokens = re.split(r'\s|\+|\-|\*|\/|<<|>>]', match.group(3))
+ is_const = True
+ skip_next = False
+ for tok in tokens:
+ if skip_next:
+ skip_next = False
+ continue
+
+ if Search(r'sizeof\(.+\)', tok): continue
+ if Search(r'arraysize\(\w+\)', tok): continue
+
+ tok = tok.lstrip('(')
+ tok = tok.rstrip(')')
+ if not tok: continue
+ if Match(r'\d+', tok): continue
+ if Match(r'0[xX][0-9a-fA-F]+', tok): continue
+ if Match(r'k[A-Z0-9]\w*', tok): continue
+ if Match(r'(.+::)?k[A-Z0-9]\w*', tok): continue
+ if Match(r'(.+::)?[A-Z][A-Z0-9_]*', tok): continue
+ # A catch all for tricky sizeof cases, including 'sizeof expression',
+ # 'sizeof(*type)', 'sizeof(const type)', 'sizeof(struct StructName)'
+ # requires skipping the next token because we split on ' ' and '*'.
+ if tok.startswith('sizeof'):
+ skip_next = True
+ continue
+ is_const = False
+ break
+ if not is_const:
+ error(filename, linenum, 'runtime/arrays', 1,
+ 'Do not use variable-length arrays. Use an appropriately named '
+ "('k' followed by CamelCase) compile-time constant for the size.")
+
+ # Check for use of unnamed namespaces in header files. Registration
+ # macros are typically OK, so we allow use of "namespace {" on lines
+ # that end with backslashes.
+ if (file_extension in GetHeaderExtensions()
+ and Search(r'\bnamespace\s*{', line)
+ and line[-1] != '\\'):
+ error(filename, linenum, 'build/namespaces', 4,
+ 'Do not use unnamed namespaces in header files. See '
+ 'https://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Namespaces'
+ ' for more information.')
+
+
+def CheckGlobalStatic(filename, clean_lines, linenum, error):
+ """Check for unsafe global or static objects.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Match two lines at a time to support multiline declarations
+ if linenum + 1 < clean_lines.NumLines() and not Search(r'[;({]', line):
+ line += clean_lines.elided[linenum + 1].strip()
+
+ # Check for people declaring static/global STL strings at the top level.
+ # This is dangerous because the C++ language does not guarantee that
+ # globals with constructors are initialized before the first access, and
+ # also because globals can be destroyed when some threads are still running.
+ # TODO(unknown): Generalize this to also find static unique_ptr instances.
+ # TODO(unknown): File bugs for clang-tidy to find these.
+ match = Match(
+ r'((?:|static +)(?:|const +))(?::*std::)?string( +const)? +'
+ r'([a-zA-Z0-9_:]+)\b(.*)',
+ line)
+
+ # Remove false positives:
+ # - String pointers (as opposed to values).
+ # string *pointer
+ # const string *pointer
+ # string const *pointer
+ # string *const pointer
+ #
+ # - Functions and template specializations.
+ # string Function<Type>(...
+ # string Class<Type>::Method(...
+ #
+ # - Operators. These are matched separately because operator names
+ # cross non-word boundaries, and trying to match both operators
+ # and functions at the same time would decrease accuracy of
+ # matching identifiers.
+ # string Class::operator*()
+ if (match and
+ not Search(r'\bstring\b(\s+const)?\s*[\*\&]\s*(const\s+)?\w', line) and
+ not Search(r'\boperator\W', line) and
+ not Match(r'\s*(<.*>)?(::[a-zA-Z0-9_]+)*\s*\(([^"]|$)', match.group(4))):
+ if Search(r'\bconst\b', line):
+ error(filename, linenum, 'runtime/string', 4,
+ 'For a static/global string constant, use a C style string '
+ 'instead: "%schar%s %s[]".' %
+ (match.group(1), match.group(2) or '', match.group(3)))
+ else:
+ error(filename, linenum, 'runtime/string', 4,
+ 'Static/global string variables are not permitted.')
+
+ if (Search(r'\b([A-Za-z0-9_]*_)\(\1\)', line) or
+ Search(r'\b([A-Za-z0-9_]*_)\(CHECK_NOTNULL\(\1\)\)', line)):
+ error(filename, linenum, 'runtime/init', 4,
+ 'You seem to be initializing a member variable with itself.')
+
+
+def CheckPrintf(filename, clean_lines, linenum, error):
+ """Check for printf related issues.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # When snprintf is used, the second argument shouldn't be a literal.
+ match = Search(r'snprintf\s*\(([^,]*),\s*([0-9]*)\s*,', line)
+ if match and match.group(2) != '0':
+ # If 2nd arg is zero, snprintf is used to calculate size.
+ error(filename, linenum, 'runtime/printf', 3,
+ 'If you can, use sizeof(%s) instead of %s as the 2nd arg '
+ 'to snprintf.' % (match.group(1), match.group(2)))
+
+ # Check if some verboten C functions are being used.
+ if Search(r'\bsprintf\s*\(', line):
+ error(filename, linenum, 'runtime/printf', 5,
+ 'Never use sprintf. Use snprintf instead.')
+ match = Search(r'\b(strcpy|strcat)\s*\(', line)
+ if match:
+ error(filename, linenum, 'runtime/printf', 4,
+ 'Almost always, snprintf is better than %s' % match.group(1))
+
+
+def IsDerivedFunction(clean_lines, linenum):
+ """Check if current line contains an inherited function.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ Returns:
+ True if current line contains a function with "override"
+ virt-specifier.
+ """
+ # Scan back a few lines for start of current function
+ for i in xrange(linenum, max(-1, linenum - 10), -1):
+ match = Match(r'^([^()]*\w+)\(', clean_lines.elided[i])
+ if match:
+ # Look for "override" after the matching closing parenthesis
+ line, _, closing_paren = CloseExpression(
+ clean_lines, i, len(match.group(1)))
+ return (closing_paren >= 0 and
+ Search(r'\boverride\b', line[closing_paren:]))
+ return False
+
+
+def IsOutOfLineMethodDefinition(clean_lines, linenum):
+ """Check if current line contains an out-of-line method definition.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ Returns:
+ True if current line contains an out-of-line method definition.
+ """
+ # Scan back a few lines for start of current function
+ for i in xrange(linenum, max(-1, linenum - 10), -1):
+ if Match(r'^([^()]*\w+)\(', clean_lines.elided[i]):
+ return Match(r'^[^()]*\w+::\w+\(', clean_lines.elided[i]) is not None
+ return False
+
+
+def IsInitializerList(clean_lines, linenum):
+ """Check if current line is inside constructor initializer list.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ Returns:
+ True if current line appears to be inside constructor initializer
+ list, False otherwise.
+ """
+ for i in xrange(linenum, 1, -1):
+ line = clean_lines.elided[i]
+ if i == linenum:
+ remove_function_body = Match(r'^(.*)\{\s*$', line)
+ if remove_function_body:
+ line = remove_function_body.group(1)
+
+ if Search(r'\s:\s*\w+[({]', line):
+ # A lone colon tend to indicate the start of a constructor
+ # initializer list. It could also be a ternary operator, which
+ # also tend to appear in constructor initializer lists as
+ # opposed to parameter lists.
+ return True
+ if Search(r'\}\s*,\s*$', line):
+ # A closing brace followed by a comma is probably the end of a
+ # brace-initialized member in constructor initializer list.
+ return True
+ if Search(r'[{};]\s*$', line):
+ # Found one of the following:
+ # - A closing brace or semicolon, probably the end of the previous
+ # function.
+ # - An opening brace, probably the start of current class or namespace.
+ #
+ # Current line is probably not inside an initializer list since
+ # we saw one of those things without seeing the starting colon.
+ return False
+
+ # Got to the beginning of the file without seeing the start of
+ # constructor initializer list.
+ return False
+
+
+def CheckForNonConstReference(filename, clean_lines, linenum,
+ nesting_state, error):
+ """Check for non-const references.
+
+ Separate from CheckLanguage since it scans backwards from current
+ line, instead of scanning forward.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ error: The function to call with any errors found.
+ """
+ # Do nothing if there is no '&' on current line.
+ line = clean_lines.elided[linenum]
+ if '&' not in line:
+ return
+
+ # If a function is inherited, current function doesn't have much of
+ # a choice, so any non-const references should not be blamed on
+ # derived function.
+ if IsDerivedFunction(clean_lines, linenum):
+ return
+
+ # Don't warn on out-of-line method definitions, as we would warn on the
+ # in-line declaration, if it isn't marked with 'override'.
+ if IsOutOfLineMethodDefinition(clean_lines, linenum):
+ return
+
+ # Long type names may be broken across multiple lines, usually in one
+ # of these forms:
+ # LongType
+ # ::LongTypeContinued &identifier
+ # LongType::
+ # LongTypeContinued &identifier
+ # LongType<
+ # ...>::LongTypeContinued &identifier
+ #
+ # If we detected a type split across two lines, join the previous
+ # line to current line so that we can match const references
+ # accordingly.
+ #
+ # Note that this only scans back one line, since scanning back
+ # arbitrary number of lines would be expensive. If you have a type
+ # that spans more than 2 lines, please use a typedef.
+ if linenum > 1:
+ previous = None
+ if Match(r'\s*::(?:[\w<>]|::)+\s*&\s*\S', line):
+ # previous_line\n + ::current_line
+ previous = Search(r'\b((?:const\s*)?(?:[\w<>]|::)+[\w<>])\s*$',
+ clean_lines.elided[linenum - 1])
+ elif Match(r'\s*[a-zA-Z_]([\w<>]|::)+\s*&\s*\S', line):
+ # previous_line::\n + current_line
+ previous = Search(r'\b((?:const\s*)?(?:[\w<>]|::)+::)\s*$',
+ clean_lines.elided[linenum - 1])
+ if previous:
+ line = previous.group(1) + line.lstrip()
+ else:
+ # Check for templated parameter that is split across multiple lines
+ endpos = line.rfind('>')
+ if endpos > -1:
+ (_, startline, startpos) = ReverseCloseExpression(
+ clean_lines, linenum, endpos)
+ if startpos > -1 and startline < linenum:
+ # Found the matching < on an earlier line, collect all
+ # pieces up to current line.
+ line = ''
+ for i in xrange(startline, linenum + 1):
+ line += clean_lines.elided[i].strip()
+
+ # Check for non-const references in function parameters. A single '&' may
+ # found in the following places:
+ # inside expression: binary & for bitwise AND
+ # inside expression: unary & for taking the address of something
+ # inside declarators: reference parameter
+ # We will exclude the first two cases by checking that we are not inside a
+ # function body, including one that was just introduced by a trailing '{'.
+ # TODO(unknown): Doesn't account for 'catch(Exception& e)' [rare].
+ if (nesting_state.previous_stack_top and
+ not (isinstance(nesting_state.previous_stack_top, _ClassInfo) or
+ isinstance(nesting_state.previous_stack_top, _NamespaceInfo))):
+ # Not at toplevel, not within a class, and not within a namespace
+ return
+
+ # Avoid initializer lists. We only need to scan back from the
+ # current line for something that starts with ':'.
+ #
+ # We don't need to check the current line, since the '&' would
+ # appear inside the second set of parentheses on the current line as
+ # opposed to the first set.
+ if linenum > 0:
+ for i in xrange(linenum - 1, max(0, linenum - 10), -1):
+ previous_line = clean_lines.elided[i]
+ if not Search(r'[),]\s*$', previous_line):
+ break
+ if Match(r'^\s*:\s+\S', previous_line):
+ return
+
+ # Avoid preprocessors
+ if Search(r'\\\s*$', line):
+ return
+
+ # Avoid constructor initializer lists
+ if IsInitializerList(clean_lines, linenum):
+ return
+
+ # We allow non-const references in a few standard places, like functions
+ # called "swap()" or iostream operators like "<<" or ">>". Do not check
+ # those function parameters.
+ #
+ # We also accept & in static_assert, which looks like a function but
+ # it's actually a declaration expression.
+ allowed_functions = (r'(?:[sS]wap(?:<\w:+>)?|'
+ r'operator\s*[<>][<>]|'
+ r'static_assert|COMPILE_ASSERT'
+ r')\s*\(')
+ if Search(allowed_functions, line):
+ return
+ elif not Search(r'\S+\([^)]*$', line):
+ # Don't see an allowed function on this line. Actually we
+ # didn't see any function name on this line, so this is likely a
+ # multi-line parameter list. Try a bit harder to catch this case.
+ for i in xrange(2):
+ if (linenum > i and
+ Search(allowed_functions, clean_lines.elided[linenum - i - 1])):
+ return
+
+ decls = ReplaceAll(r'{[^}]*}', ' ', line) # exclude function body
+ for parameter in re.findall(_RE_PATTERN_REF_PARAM, decls):
+ if (not Match(_RE_PATTERN_CONST_REF_PARAM, parameter) and
+ not Match(_RE_PATTERN_REF_STREAM_PARAM, parameter)):
+ error(filename, linenum, 'runtime/references', 2,
+ 'Is this a non-const reference? '
+ 'If so, make const or use a pointer: ' +
+ ReplaceAll(' *<', '<', parameter))
+
+
+def CheckCasts(filename, clean_lines, linenum, error):
+ """Various cast related checks.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ # Check to see if they're using an conversion function cast.
+ # I just try to capture the most common basic types, though there are more.
+ # Parameterless conversion functions, such as bool(), are allowed as they are
+ # probably a member operator declaration or default constructor.
+ match = Search(
+ r'(\bnew\s+(?:const\s+)?|\S<\s*(?:const\s+)?)?\b'
+ r'(int|float|double|bool|char|int32|uint32|int64|uint64)'
+ r'(\([^)].*)', line)
+ expecting_function = ExpectingFunctionArgs(clean_lines, linenum)
+ if match and not expecting_function:
+ matched_type = match.group(2)
+
+ # matched_new_or_template is used to silence two false positives:
+ # - New operators
+ # - Template arguments with function types
+ #
+ # For template arguments, we match on types immediately following
+ # an opening bracket without any spaces. This is a fast way to
+ # silence the common case where the function type is the first
+ # template argument. False negative with less-than comparison is
+ # avoided because those operators are usually followed by a space.
+ #
+ # function<double(double)> // bracket + no space = false positive
+ # value < double(42) // bracket + space = true positive
+ matched_new_or_template = match.group(1)
+
+ # Avoid arrays by looking for brackets that come after the closing
+ # parenthesis.
+ if Match(r'\([^()]+\)\s*\[', match.group(3)):
+ return
+
+ # Other things to ignore:
+ # - Function pointers
+ # - Casts to pointer types
+ # - Placement new
+ # - Alias declarations
+ matched_funcptr = match.group(3)
+ if (matched_new_or_template is None and
+ not (matched_funcptr and
+ (Match(r'\((?:[^() ]+::\s*\*\s*)?[^() ]+\)\s*\(',
+ matched_funcptr) or
+ matched_funcptr.startswith('(*)'))) and
+ not Match(r'\s*using\s+\S+\s*=\s*' + matched_type, line) and
+ not Search(r'new\(\S+\)\s*' + matched_type, line)):
+ error(filename, linenum, 'readability/casting', 4,
+ 'Using deprecated casting style. '
+ 'Use static_cast<%s>(...) instead' %
+ matched_type)
+
+ if not expecting_function:
+ CheckCStyleCast(filename, clean_lines, linenum, 'static_cast',
+ r'\((int|float|double|bool|char|u?int(16|32|64))\)', error)
+
+ # This doesn't catch all cases. Consider (const char * const)"hello".
+ #
+ # (char *) "foo" should always be a const_cast (reinterpret_cast won't
+ # compile).
+ if CheckCStyleCast(filename, clean_lines, linenum, 'const_cast',
+ r'\((char\s?\*+\s?)\)\s*"', error):
+ pass
+ else:
+ # Check pointer casts for other than string constants
+ CheckCStyleCast(filename, clean_lines, linenum, 'reinterpret_cast',
+ r'\((\w+\s?\*+\s?)\)', error)
+
+ # In addition, we look for people taking the address of a cast. This
+ # is dangerous -- casts can assign to temporaries, so the pointer doesn't
+ # point where you think.
+ #
+ # Some non-identifier character is required before the '&' for the
+ # expression to be recognized as a cast. These are casts:
+ # expression = &static_cast<int*>(temporary());
+ # function(&(int*)(temporary()));
+ #
+ # This is not a cast:
+ # reference_type&(int* function_param);
+ match = Search(
+ r'(?:[^\w]&\(([^)*][^)]*)\)[\w(])|'
+ r'(?:[^\w]&(static|dynamic|down|reinterpret)_cast\b)', line)
+ if match:
+ # Try a better error message when the & is bound to something
+ # dereferenced by the casted pointer, as opposed to the casted
+ # pointer itself.
+ parenthesis_error = False
+ match = Match(r'^(.*&(?:static|dynamic|down|reinterpret)_cast\b)<', line)
+ if match:
+ _, y1, x1 = CloseExpression(clean_lines, linenum, len(match.group(1)))
+ if x1 >= 0 and clean_lines.elided[y1][x1] == '(':
+ _, y2, x2 = CloseExpression(clean_lines, y1, x1)
+ if x2 >= 0:
+ extended_line = clean_lines.elided[y2][x2:]
+ if y2 < clean_lines.NumLines() - 1:
+ extended_line += clean_lines.elided[y2 + 1]
+ if Match(r'\s*(?:->|\[)', extended_line):
+ parenthesis_error = True
+
+ if parenthesis_error:
+ error(filename, linenum, 'readability/casting', 4,
+ ('Are you taking an address of something dereferenced '
+ 'from a cast? Wrapping the dereferenced expression in '
+ 'parentheses will make the binding more obvious'))
+ else:
+ error(filename, linenum, 'runtime/casting', 4,
+ ('Are you taking an address of a cast? '
+ 'This is dangerous: could be a temp var. '
+ 'Take the address before doing the cast, rather than after'))
+
+
+def CheckCStyleCast(filename, clean_lines, linenum, cast_type, pattern, error):
+ """Checks for a C-style cast by looking for the pattern.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ cast_type: The string for the C++ cast to recommend. This is either
+ reinterpret_cast, static_cast, or const_cast, depending.
+ pattern: The regular expression used to find C-style casts.
+ error: The function to call with any errors found.
+
+ Returns:
+ True if an error was emitted.
+ False otherwise.
+ """
+ line = clean_lines.elided[linenum]
+ match = Search(pattern, line)
+ if not match:
+ return False
+
+ # Exclude lines with keywords that tend to look like casts
+ context = line[0:match.start(1) - 1]
+ if Match(r'.*\b(?:sizeof|alignof|alignas|[_A-Z][_A-Z0-9]*)\s*$', context):
+ return False
+
+ # Try expanding current context to see if we one level of
+ # parentheses inside a macro.
+ if linenum > 0:
+ for i in xrange(linenum - 1, max(0, linenum - 5), -1):
+ context = clean_lines.elided[i] + context
+ if Match(r'.*\b[_A-Z][_A-Z0-9]*\s*\((?:\([^()]*\)|[^()])*$', context):
+ return False
+
+ # operator++(int) and operator--(int)
+ if context.endswith(' operator++') or context.endswith(' operator--'):
+ return False
+
+ # A single unnamed argument for a function tends to look like old style cast.
+ # If we see those, don't issue warnings for deprecated casts.
+ remainder = line[match.end(0):]
+ if Match(r'^\s*(?:;|const\b|throw\b|final\b|override\b|[=>{),]|->)',
+ remainder):
+ return False
+
+ # At this point, all that should be left is actual casts.
+ error(filename, linenum, 'readability/casting', 4,
+ 'Using C-style cast. Use %s<%s>(...) instead' %
+ (cast_type, match.group(1)))
+
+ return True
+
+
+def ExpectingFunctionArgs(clean_lines, linenum):
+ """Checks whether where function type arguments are expected.
+
+ Args:
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+
+ Returns:
+ True if the line at 'linenum' is inside something that expects arguments
+ of function types.
+ """
+ line = clean_lines.elided[linenum]
+ return (Match(r'^\s*MOCK_(CONST_)?METHOD\d+(_T)?\(', line) or
+ (linenum >= 2 and
+ (Match(r'^\s*MOCK_(?:CONST_)?METHOD\d+(?:_T)?\((?:\S+,)?\s*$',
+ clean_lines.elided[linenum - 1]) or
+ Match(r'^\s*MOCK_(?:CONST_)?METHOD\d+(?:_T)?\(\s*$',
+ clean_lines.elided[linenum - 2]) or
+ Search(r'\bstd::m?function\s*\<\s*$',
+ clean_lines.elided[linenum - 1]))))
+
+
+_HEADERS_CONTAINING_TEMPLATES = (
+ ('<deque>', ('deque',)),
+ ('<functional>', ('unary_function', 'binary_function',
+ 'plus', 'minus', 'multiplies', 'divides', 'modulus',
+ 'negate',
+ 'equal_to', 'not_equal_to', 'greater', 'less',
+ 'greater_equal', 'less_equal',
+ 'logical_and', 'logical_or', 'logical_not',
+ 'unary_negate', 'not1', 'binary_negate', 'not2',
+ 'bind1st', 'bind2nd',
+ 'pointer_to_unary_function',
+ 'pointer_to_binary_function',
+ 'ptr_fun',
+ 'mem_fun_t', 'mem_fun', 'mem_fun1_t', 'mem_fun1_ref_t',
+ 'mem_fun_ref_t',
+ 'const_mem_fun_t', 'const_mem_fun1_t',
+ 'const_mem_fun_ref_t', 'const_mem_fun1_ref_t',
+ 'mem_fun_ref',
+ )),
+ ('<limits>', ('numeric_limits',)),
+ ('<list>', ('list',)),
+ ('<map>', ('map', 'multimap',)),
+ ('<memory>', ('allocator', 'make_shared', 'make_unique', 'shared_ptr',
+ 'unique_ptr', 'weak_ptr')),
+ ('<queue>', ('queue', 'priority_queue',)),
+ ('<set>', ('set', 'multiset',)),
+ ('<stack>', ('stack',)),
+ ('<string>', ('char_traits', 'basic_string',)),
+ ('<tuple>', ('tuple',)),
+ ('<unordered_map>', ('unordered_map', 'unordered_multimap')),
+ ('<unordered_set>', ('unordered_set', 'unordered_multiset')),
+ ('<utility>', ('pair',)),
+ ('<vector>', ('vector',)),
+
+ # gcc extensions.
+ # Note: std::hash is their hash, ::hash is our hash
+ ('<hash_map>', ('hash_map', 'hash_multimap',)),
+ ('<hash_set>', ('hash_set', 'hash_multiset',)),
+ ('<slist>', ('slist',)),
+ )
+
+_HEADERS_MAYBE_TEMPLATES = (
+ ('<algorithm>', ('copy', 'max', 'min', 'min_element', 'sort',
+ 'transform',
+ )),
+ ('<utility>', ('forward', 'make_pair', 'move', 'swap')),
+ )
+
+_RE_PATTERN_STRING = re.compile(r'\bstring\b')
+
+_re_pattern_headers_maybe_templates = []
+for _header, _templates in _HEADERS_MAYBE_TEMPLATES:
+ for _template in _templates:
+ # Match max<type>(..., ...), max(..., ...), but not foo->max, foo.max or
+ # type::max().
+ _re_pattern_headers_maybe_templates.append(
+ (re.compile(r'[^>.]\b' + _template + r'(<.*?>)?\([^\)]'),
+ _template,
+ _header))
+
+# Other scripts may reach in and modify this pattern.
+_re_pattern_templates = []
+for _header, _templates in _HEADERS_CONTAINING_TEMPLATES:
+ for _template in _templates:
+ _re_pattern_templates.append(
+ (re.compile(r'(\<|\b)' + _template + r'\s*\<'),
+ _template + '<>',
+ _header))
+
+
+def FilesBelongToSameModule(filename_cc, filename_h):
+ """Check if these two filenames belong to the same module.
+
+ The concept of a 'module' here is a as follows:
+ foo.h, foo-inl.h, foo.cc, foo_test.cc and foo_unittest.cc belong to the
+ same 'module' if they are in the same directory.
+ some/path/public/xyzzy and some/path/internal/xyzzy are also considered
+ to belong to the same module here.
+
+ If the filename_cc contains a longer path than the filename_h, for example,
+ '/absolute/path/to/base/sysinfo.cc', and this file would include
+ 'base/sysinfo.h', this function also produces the prefix needed to open the
+ header. This is used by the caller of this function to more robustly open the
+ header file. We don't have access to the real include paths in this context,
+ so we need this guesswork here.
+
+ Known bugs: tools/base/bar.cc and base/bar.h belong to the same module
+ according to this implementation. Because of this, this function gives
+ some false positives. This should be sufficiently rare in practice.
+
+ Args:
+ filename_cc: is the path for the source (e.g. .cc) file
+ filename_h: is the path for the header path
+
+ Returns:
+ Tuple with a bool and a string:
+ bool: True if filename_cc and filename_h belong to the same module.
+ string: the additional prefix needed to open the header file.
+ """
+ fileinfo_cc = FileInfo(filename_cc)
+ if not fileinfo_cc.Extension().lstrip('.') in GetNonHeaderExtensions():
+ return (False, '')
+
+ fileinfo_h = FileInfo(filename_h)
+ if not fileinfo_h.Extension().lstrip('.') in GetHeaderExtensions():
+ return (False, '')
+
+ filename_cc = filename_cc[:-(len(fileinfo_cc.Extension()))]
+ matched_test_suffix = Search(_TEST_FILE_SUFFIX, fileinfo_cc.BaseName())
+ if matched_test_suffix:
+ filename_cc = filename_cc[:-len(matched_test_suffix.group(1))]
+
+ filename_cc = filename_cc.replace('/public/', '/')
+ filename_cc = filename_cc.replace('/internal/', '/')
+
+ filename_h = filename_h[:-(len(fileinfo_h.Extension()))]
+ if filename_h.endswith('-inl'):
+ filename_h = filename_h[:-len('-inl')]
+ filename_h = filename_h.replace('/public/', '/')
+ filename_h = filename_h.replace('/internal/', '/')
+
+ files_belong_to_same_module = filename_cc.endswith(filename_h)
+ common_path = ''
+ if files_belong_to_same_module:
+ common_path = filename_cc[:-len(filename_h)]
+ return files_belong_to_same_module, common_path
+
+
+def UpdateIncludeState(filename, include_dict, io=codecs):
+ """Fill up the include_dict with new includes found from the file.
+
+ Args:
+ filename: the name of the header to read.
+ include_dict: a dictionary in which the headers are inserted.
+ io: The io factory to use to read the file. Provided for testability.
+
+ Returns:
+ True if a header was successfully added. False otherwise.
+ """
+ headerfile = None
+ try:
+ headerfile = io.open(filename, 'r', 'utf8', 'replace')
+ except IOError:
+ return False
+ linenum = 0
+ for line in headerfile:
+ linenum += 1
+ clean_line = CleanseComments(line)
+ match = _RE_PATTERN_INCLUDE.search(clean_line)
+ if match:
+ include = match.group(2)
+ include_dict.setdefault(include, linenum)
+ return True
+
+
+def CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error,
+ io=codecs):
+ """Reports for missing stl includes.
+
+ This function will output warnings to make sure you are including the headers
+ necessary for the stl containers and functions that you use. We only give one
+ reason to include a header. For example, if you use both equal_to<> and
+ less<> in a .h file, only one (the latter in the file) of these will be
+ reported as a reason to include the <functional>.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ include_state: An _IncludeState instance.
+ error: The function to call with any errors found.
+ io: The IO factory to use to read the header file. Provided for unittest
+ injection.
+ """
+ required = {} # A map of header name to linenumber and the template entity.
+ # Example of required: { '<functional>': (1219, 'less<>') }
+
+ for linenum in range(clean_lines.NumLines()):
+ line = clean_lines.elided[linenum]
+ if not line or line[0] == '#':
+ continue
+
+ # String is special -- it is a non-templatized type in STL.
+ matched = _RE_PATTERN_STRING.search(line)
+ if matched:
+ # Don't warn about strings in non-STL namespaces:
+ # (We check only the first match per line; good enough.)
+ prefix = line[:matched.start()]
+ if prefix.endswith('std::') or not prefix.endswith('::'):
+ required['<string>'] = (linenum, 'string')
+
+ for pattern, template, header in _re_pattern_headers_maybe_templates:
+ if pattern.search(line):
+ required[header] = (linenum, template)
+
+ # The following function is just a speed up, no semantics are changed.
+ if not '<' in line: # Reduces the cpu time usage by skipping lines.
+ continue
+
+ for pattern, template, header in _re_pattern_templates:
+ matched = pattern.search(line)
+ if matched:
+ # Don't warn about IWYU in non-STL namespaces:
+ # (We check only the first match per line; good enough.)
+ prefix = line[:matched.start()]
+ if prefix.endswith('std::') or not prefix.endswith('::'):
+ required[header] = (linenum, template)
+
+ # The policy is that if you #include something in foo.h you don't need to
+ # include it again in foo.cc. Here, we will look at possible includes.
+ # Let's flatten the include_state include_list and copy it into a dictionary.
+ include_dict = dict([item for sublist in include_state.include_list
+ for item in sublist])
+
+ # Did we find the header for this file (if any) and successfully load it?
+ header_found = False
+
+ # Use the absolute path so that matching works properly.
+ abs_filename = FileInfo(filename).FullName()
+
+ # For Emacs's flymake.
+ # If cpplint is invoked from Emacs's flymake, a temporary file is generated
+ # by flymake and that file name might end with '_flymake.cc'. In that case,
+ # restore original file name here so that the corresponding header file can be
+ # found.
+ # e.g. If the file name is 'foo_flymake.cc', we should search for 'foo.h'
+ # instead of 'foo_flymake.h'
+ abs_filename = re.sub(r'_flymake\.cc$', '.cc', abs_filename)
+
+ # include_dict is modified during iteration, so we iterate over a copy of
+ # the keys.
+ header_keys = list(include_dict.keys())
+ for header in header_keys:
+ (same_module, common_path) = FilesBelongToSameModule(abs_filename, header)
+ fullpath = common_path + header
+ if same_module and UpdateIncludeState(fullpath, include_dict, io):
+ header_found = True
+
+ # If we can't find the header file for a .cc, assume it's because we don't
+ # know where to look. In that case we'll give up as we're not sure they
+ # didn't include it in the .h file.
+ # TODO(unknown): Do a better job of finding .h files so we are confident that
+ # not having the .h file means there isn't one.
+ if not header_found:
+ for extension in GetNonHeaderExtensions():
+ if filename.endswith('.' + extension):
+ return
+
+ # All the lines have been processed, report the errors found.
+ for required_header_unstripped in sorted(required, key=required.__getitem__):
+ template = required[required_header_unstripped][1]
+ if required_header_unstripped.strip('<>"') not in include_dict:
+ error(filename, required[required_header_unstripped][0],
+ 'build/include_what_you_use', 4,
+ 'Add #include ' + required_header_unstripped + ' for ' + template)
+
+
+_RE_PATTERN_EXPLICIT_MAKEPAIR = re.compile(r'\bmake_pair\s*<')
+
+
+def CheckMakePairUsesDeduction(filename, clean_lines, linenum, error):
+ """Check that make_pair's template arguments are deduced.
+
+ G++ 4.6 in C++11 mode fails badly if make_pair's template arguments are
+ specified explicitly, and such use isn't intended in any case.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+ match = _RE_PATTERN_EXPLICIT_MAKEPAIR.search(line)
+ if match:
+ error(filename, linenum, 'build/explicit_make_pair',
+ 4, # 4 = high confidence
+ 'For C++11-compatibility, omit template arguments from make_pair'
+ ' OR use pair directly OR if appropriate, construct a pair directly')
+
+
+def CheckRedundantVirtual(filename, clean_lines, linenum, error):
+ """Check if line contains a redundant "virtual" function-specifier.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ # Look for "virtual" on current line.
+ line = clean_lines.elided[linenum]
+ virtual = Match(r'^(.*)(\bvirtual\b)(.*)$', line)
+ if not virtual: return
+
+ # Ignore "virtual" keywords that are near access-specifiers. These
+ # are only used in class base-specifier and do not apply to member
+ # functions.
+ if (Search(r'\b(public|protected|private)\s+$', virtual.group(1)) or
+ Match(r'^\s+(public|protected|private)\b', virtual.group(3))):
+ return
+
+ # Ignore the "virtual" keyword from virtual base classes. Usually
+ # there is a column on the same line in these cases (virtual base
+ # classes are rare in google3 because multiple inheritance is rare).
+ if Match(r'^.*[^:]:[^:].*$', line): return
+
+ # Look for the next opening parenthesis. This is the start of the
+ # parameter list (possibly on the next line shortly after virtual).
+ # TODO(unknown): doesn't work if there are virtual functions with
+ # decltype() or other things that use parentheses, but csearch suggests
+ # that this is rare.
+ end_col = -1
+ end_line = -1
+ start_col = len(virtual.group(2))
+ for start_line in xrange(linenum, min(linenum + 3, clean_lines.NumLines())):
+ line = clean_lines.elided[start_line][start_col:]
+ parameter_list = Match(r'^([^(]*)\(', line)
+ if parameter_list:
+ # Match parentheses to find the end of the parameter list
+ (_, end_line, end_col) = CloseExpression(
+ clean_lines, start_line, start_col + len(parameter_list.group(1)))
+ break
+ start_col = 0
+
+ if end_col < 0:
+ return # Couldn't find end of parameter list, give up
+
+ # Look for "override" or "final" after the parameter list
+ # (possibly on the next few lines).
+ for i in xrange(end_line, min(end_line + 3, clean_lines.NumLines())):
+ line = clean_lines.elided[i][end_col:]
+ match = Search(r'\b(override|final)\b', line)
+ if match:
+ error(filename, linenum, 'readability/inheritance', 4,
+ ('"virtual" is redundant since function is '
+ 'already declared as "%s"' % match.group(1)))
+
+ # Set end_col to check whole lines after we are done with the
+ # first line.
+ end_col = 0
+ if Search(r'[^\w]\s*$', line):
+ break
+
+
+def CheckRedundantOverrideOrFinal(filename, clean_lines, linenum, error):
+ """Check if line contains a redundant "override" or "final" virt-specifier.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ # Look for closing parenthesis nearby. We need one to confirm where
+ # the declarator ends and where the virt-specifier starts to avoid
+ # false positives.
+ line = clean_lines.elided[linenum]
+ declarator_end = line.rfind(')')
+ if declarator_end >= 0:
+ fragment = line[declarator_end:]
+ else:
+ if linenum > 1 and clean_lines.elided[linenum - 1].rfind(')') >= 0:
+ fragment = line
+ else:
+ return
+
+ # Check that at most one of "override" or "final" is present, not both
+ if Search(r'\boverride\b', fragment) and Search(r'\bfinal\b', fragment):
+ error(filename, linenum, 'readability/inheritance', 4,
+ ('"override" is redundant since function is '
+ 'already declared as "final"'))
+
+
+
+
+# Returns true if we are at a new block, and it is directly
+# inside of a namespace.
+def IsBlockInNameSpace(nesting_state, is_forward_declaration):
+ """Checks that the new block is directly in a namespace.
+
+ Args:
+ nesting_state: The _NestingState object that contains info about our state.
+ is_forward_declaration: If the class is a forward declared class.
+ Returns:
+ Whether or not the new block is directly in a namespace.
+ """
+ if is_forward_declaration:
+ return len(nesting_state.stack) >= 1 and (
+ isinstance(nesting_state.stack[-1], _NamespaceInfo))
+
+
+ return (len(nesting_state.stack) > 1 and
+ nesting_state.stack[-1].check_namespace_indentation and
+ isinstance(nesting_state.stack[-2], _NamespaceInfo))
+
+
+def ShouldCheckNamespaceIndentation(nesting_state, is_namespace_indent_item,
+ raw_lines_no_comments, linenum):
+ """This method determines if we should apply our namespace indentation check.
+
+ Args:
+ nesting_state: The current nesting state.
+ is_namespace_indent_item: If we just put a new class on the stack, True.
+ If the top of the stack is not a class, or we did not recently
+ add the class, False.
+ raw_lines_no_comments: The lines without the comments.
+ linenum: The current line number we are processing.
+
+ Returns:
+ True if we should apply our namespace indentation check. Currently, it
+ only works for classes and namespaces inside of a namespace.
+ """
+
+ is_forward_declaration = IsForwardClassDeclaration(raw_lines_no_comments,
+ linenum)
+
+ if not (is_namespace_indent_item or is_forward_declaration):
+ return False
+
+ # If we are in a macro, we do not want to check the namespace indentation.
+ if IsMacroDefinition(raw_lines_no_comments, linenum):
+ return False
+
+ return IsBlockInNameSpace(nesting_state, is_forward_declaration)
+
+
+# Call this method if the line is directly inside of a namespace.
+# If the line above is blank (excluding comments) or the start of
+# an inner namespace, it cannot be indented.
+def CheckItemIndentationInNamespace(filename, raw_lines_no_comments, linenum,
+ error):
+ line = raw_lines_no_comments[linenum]
+ if Match(r'^\s+', line):
+ error(filename, linenum, 'runtime/indentation_namespace', 4,
+ 'Do not indent within a namespace')
+
+
+def ProcessLine(filename, file_extension, clean_lines, line,
+ include_state, function_state, nesting_state, error,
+ extra_check_functions=None):
+ """Processes a single line in the file.
+
+ Args:
+ filename: Filename of the file that is being processed.
+ file_extension: The extension (dot not included) of the file.
+ clean_lines: An array of strings, each representing a line of the file,
+ with comments stripped.
+ line: Number of line being processed.
+ include_state: An _IncludeState instance in which the headers are inserted.
+ function_state: A _FunctionState instance which counts function lines, etc.
+ nesting_state: A NestingState instance which maintains information about
+ the current stack of nested blocks being parsed.
+ error: A callable to which errors are reported, which takes 4 arguments:
+ filename, line number, error level, and message
+ extra_check_functions: An array of additional check functions that will be
+ run on each source line. Each function takes 4
+ arguments: filename, clean_lines, line, error
+ """
+ raw_lines = clean_lines.raw_lines
+ ParseNolintSuppressions(filename, raw_lines[line], line, error)
+ nesting_state.Update(filename, clean_lines, line, error)
+ CheckForNamespaceIndentation(filename, nesting_state, clean_lines, line,
+ error)
+ if nesting_state.InAsmBlock(): return
+ CheckForFunctionLengths(filename, clean_lines, line, function_state, error)
+ CheckForMultilineCommentsAndStrings(filename, clean_lines, line, error)
+ CheckStyle(filename, clean_lines, line, file_extension, nesting_state, error)
+ CheckLanguage(filename, clean_lines, line, file_extension, include_state,
+ nesting_state, error)
+ CheckForNonConstReference(filename, clean_lines, line, nesting_state, error)
+ CheckForNonStandardConstructs(filename, clean_lines, line,
+ nesting_state, error)
+ CheckVlogArguments(filename, clean_lines, line, error)
+ CheckPosixThreading(filename, clean_lines, line, error)
+ CheckInvalidIncrement(filename, clean_lines, line, error)
+ CheckMakePairUsesDeduction(filename, clean_lines, line, error)
+ CheckRedundantVirtual(filename, clean_lines, line, error)
+ CheckRedundantOverrideOrFinal(filename, clean_lines, line, error)
+ if extra_check_functions:
+ for check_fn in extra_check_functions:
+ check_fn(filename, clean_lines, line, error)
+
+def FlagCxx11Features(filename, clean_lines, linenum, error):
+ """Flag those c++11 features that we only allow in certain places.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ include = Match(r'\s*#\s*include\s+[<"]([^<"]+)[">]', line)
+
+ # Flag unapproved C++ TR1 headers.
+ if include and include.group(1).startswith('tr1/'):
+ error(filename, linenum, 'build/c++tr1', 5,
+ ('C++ TR1 headers such as <%s> are unapproved.') % include.group(1))
+
+ # Flag unapproved C++11 headers.
+ if include and include.group(1) in ('cfenv',
+ 'condition_variable',
+ 'fenv.h',
+ 'future',
+ 'mutex',
+ 'thread',
+ 'chrono',
+ 'ratio',
+ 'regex',
+ 'system_error',
+ ):
+ error(filename, linenum, 'build/c++11', 5,
+ ('<%s> is an unapproved C++11 header.') % include.group(1))
+
+ # The only place where we need to worry about C++11 keywords and library
+ # features in preprocessor directives is in macro definitions.
+ if Match(r'\s*#', line) and not Match(r'\s*#\s*define\b', line): return
+
+ # These are classes and free functions. The classes are always
+ # mentioned as std::*, but we only catch the free functions if
+ # they're not found by ADL. They're alphabetical by header.
+ for top_name in (
+ # type_traits
+ 'alignment_of',
+ 'aligned_union',
+ ):
+ if Search(r'\bstd::%s\b' % top_name, line):
+ error(filename, linenum, 'build/c++11', 5,
+ ('std::%s is an unapproved C++11 class or function. Send c-style '
+ 'an example of where it would make your code more readable, and '
+ 'they may let you use it.') % top_name)
+
+
+def FlagCxx14Features(filename, clean_lines, linenum, error):
+ """Flag those C++14 features that we restrict.
+
+ Args:
+ filename: The name of the current file.
+ clean_lines: A CleansedLines instance containing the file.
+ linenum: The number of the line to check.
+ error: The function to call with any errors found.
+ """
+ line = clean_lines.elided[linenum]
+
+ include = Match(r'\s*#\s*include\s+[<"]([^<"]+)[">]', line)
+
+ # Flag unapproved C++14 headers.
+ if include and include.group(1) in ('scoped_allocator', 'shared_mutex'):
+ error(filename, linenum, 'build/c++14', 5,
+ ('<%s> is an unapproved C++14 header.') % include.group(1))
+
+
+def ProcessFileData(filename, file_extension, lines, error,
+ extra_check_functions=None):
+ """Performs lint checks and reports any errors to the given error function.
+
+ Args:
+ filename: Filename of the file that is being processed.
+ file_extension: The extension (dot not included) of the file.
+ lines: An array of strings, each representing a line of the file, with the
+ last element being empty if the file is terminated with a newline.
+ error: A callable to which errors are reported, which takes 4 arguments:
+ filename, line number, error level, and message
+ extra_check_functions: An array of additional check functions that will be
+ run on each source line. Each function takes 4
+ arguments: filename, clean_lines, line, error
+ """
+ lines = (['// marker so line numbers and indices both start at 1'] + lines +
+ ['// marker so line numbers end in a known way'])
+
+ include_state = _IncludeState()
+ function_state = _FunctionState()
+ nesting_state = NestingState()
+
+ ResetNolintSuppressions()
+
+ CheckForCopyright(filename, lines, error)
+ ProcessGlobalSuppresions(lines)
+ RemoveMultiLineComments(filename, lines, error)
+ clean_lines = CleansedLines(lines)
+
+ if file_extension in GetHeaderExtensions():
+ CheckForHeaderGuard(filename, clean_lines, error)
+
+ for line in range(clean_lines.NumLines()):
+ ProcessLine(filename, file_extension, clean_lines, line,
+ include_state, function_state, nesting_state, error,
+ extra_check_functions)
+ FlagCxx11Features(filename, clean_lines, line, error)
+ nesting_state.CheckCompletedBlocks(filename, error)
+
+ CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error)
+
+ # Check that the .cc file has included its header if it exists.
+ if _IsSourceExtension(file_extension):
+ CheckHeaderFileIncluded(filename, include_state, error)
+
+ # We check here rather than inside ProcessLine so that we see raw
+ # lines rather than "cleaned" lines.
+ CheckForBadCharacters(filename, lines, error)
+
+ CheckForNewlineAtEOF(filename, lines, error)
+
+def ProcessConfigOverrides(filename):
+ """ Loads the configuration files and processes the config overrides.
+
+ Args:
+ filename: The name of the file being processed by the linter.
+
+ Returns:
+ False if the current |filename| should not be processed further.
+ """
+
+ abs_filename = os.path.abspath(filename)
+ cfg_filters = []
+ keep_looking = True
+ while keep_looking:
+ abs_path, base_name = os.path.split(abs_filename)
+ if not base_name:
+ break # Reached the root directory.
+
+ cfg_file = os.path.join(abs_path, "CPPLINT.cfg")
+ abs_filename = abs_path
+ if not os.path.isfile(cfg_file):
+ continue
+
+ try:
+ with open(cfg_file) as file_handle:
+ for line in file_handle:
+ line, _, _ = line.partition('#') # Remove comments.
+ if not line.strip():
+ continue
+
+ name, _, val = line.partition('=')
+ name = name.strip()
+ val = val.strip()
+ if name == 'set noparent':
+ keep_looking = False
+ elif name == 'filter':
+ cfg_filters.append(val)
+ elif name == 'exclude_files':
+ # When matching exclude_files pattern, use the base_name of
+ # the current file name or the directory name we are processing.
+ # For example, if we are checking for lint errors in /foo/bar/baz.cc
+ # and we found the .cfg file at /foo/CPPLINT.cfg, then the config
+ # file's "exclude_files" filter is meant to be checked against "bar"
+ # and not "baz" nor "bar/baz.cc".
+ if base_name:
+ pattern = re.compile(val)
+ if pattern.match(base_name):
+ _cpplint_state.PrintInfo('Ignoring "%s": file excluded by '
+ '"%s". File path component "%s" matches pattern "%s"\n' %
+ (filename, cfg_file, base_name, val))
+ return False
+ elif name == 'linelength':
+ global _line_length
+ try:
+ _line_length = int(val)
+ except ValueError:
+ _cpplint_state.PrintError('Line length must be numeric.')
+ elif name == 'extensions':
+ global _valid_extensions
+ try:
+ extensions = [ext.strip() for ext in val.split(',')]
+ _valid_extensions = set(extensions)
+ except ValueError:
+ sys.stderr.write('Extensions should be a comma-separated list of values;'
+ 'for example: extensions=hpp,cpp\n'
+ 'This could not be parsed: "%s"' % (val,))
+ elif name == 'headers':
+ global _header_extensions
+ try:
+ extensions = [ext.strip() for ext in val.split(',')]
+ _header_extensions = set(extensions)
+ except ValueError:
+ sys.stderr.write('Extensions should be a comma-separated list of values;'
+ 'for example: extensions=hpp,cpp\n'
+ 'This could not be parsed: "%s"' % (val,))
+ elif name == 'root':
+ global _root
+ _root = val
+ else:
+ _cpplint_state.PrintError(
+ 'Invalid configuration option (%s) in file %s\n' %
+ (name, cfg_file))
+
+ except IOError:
+ _cpplint_state.PrintError(
+ "Skipping config file '%s': Can't open for reading\n" % cfg_file)
+ keep_looking = False
+
+ # Apply all the accumulated filters in reverse order (top-level directory
+ # config options having the least priority).
+ for cfg_filter in reversed(cfg_filters):
+ _AddFilters(cfg_filter)
+
+ return True
+
+
+def ProcessFile(filename, vlevel, extra_check_functions=None):
+ """Does google-lint on a single file.
+
+ Args:
+ filename: The name of the file to parse.
+
+ vlevel: The level of errors to report. Every error of confidence
+ >= verbose_level will be reported. 0 is a good default.
+
+ extra_check_functions: An array of additional check functions that will be
+ run on each source line. Each function takes 4
+ arguments: filename, clean_lines, line, error
+ """
+
+ _SetVerboseLevel(vlevel)
+ _BackupFilters()
+
+ if not ProcessConfigOverrides(filename):
+ _RestoreFilters()
+ return
+
+ lf_lines = []
+ crlf_lines = []
+ try:
+ # Support the UNIX convention of using "-" for stdin. Note that
+ # we are not opening the file with universal newline support
+ # (which codecs doesn't support anyway), so the resulting lines do
+ # contain trailing '\r' characters if we are reading a file that
+ # has CRLF endings.
+ # If after the split a trailing '\r' is present, it is removed
+ # below.
+ if filename == '-':
+ lines = codecs.StreamReaderWriter(sys.stdin,
+ codecs.getreader('utf8'),
+ codecs.getwriter('utf8'),
+ 'replace').read().split('\n')
+ else:
+ lines = codecs.open(filename, 'r', 'utf8', 'replace').read().split('\n')
+
+ # Remove trailing '\r'.
+ # The -1 accounts for the extra trailing blank line we get from split()
+ for linenum in range(len(lines) - 1):
+ if lines[linenum].endswith('\r'):
+ lines[linenum] = lines[linenum].rstrip('\r')
+ crlf_lines.append(linenum + 1)
+ else:
+ lf_lines.append(linenum + 1)
+
+ except IOError:
+ _cpplint_state.PrintError(
+ "Skipping input '%s': Can't open for reading\n" % filename)
+ _RestoreFilters()
+ return
+
+ # Note, if no dot is found, this will give the entire filename as the ext.
+ file_extension = filename[filename.rfind('.') + 1:]
+
+ # When reading from stdin, the extension is unknown, so no cpplint tests
+ # should rely on the extension.
+ if filename != '-' and file_extension not in GetAllExtensions():
+ _cpplint_state.PrintError('Ignoring %s; not a valid file name '
+ '(%s)\n' % (filename, ', '.join(GetAllExtensions())))
+ else:
+ ProcessFileData(filename, file_extension, lines, Error,
+ extra_check_functions)
+
+ # If end-of-line sequences are a mix of LF and CR-LF, issue
+ # warnings on the lines with CR.
+ #
+ # Don't issue any warnings if all lines are uniformly LF or CR-LF,
+ # since critique can handle these just fine, and the style guide
+ # doesn't dictate a particular end of line sequence.
+ #
+ # We can't depend on os.linesep to determine what the desired
+ # end-of-line sequence should be, since that will return the
+ # server-side end-of-line sequence.
+ if lf_lines and crlf_lines:
+ # Warn on every line with CR. An alternative approach might be to
+ # check whether the file is mostly CRLF or just LF, and warn on the
+ # minority, we bias toward LF here since most tools prefer LF.
+ for linenum in crlf_lines:
+ Error(filename, linenum, 'whitespace/newline', 1,
+ 'Unexpected \\r (^M) found; better to use only \\n')
+
+ _cpplint_state.PrintInfo('Done processing %s\n' % filename)
+ _RestoreFilters()
+
+
+def PrintUsage(message):
+ """Prints a brief usage string and exits, optionally with an error message.
+
+ Args:
+ message: The optional error message.
+ """
+ sys.stderr.write(_USAGE)
+
+ if message:
+ sys.exit('\nFATAL ERROR: ' + message)
+ else:
+ sys.exit(0)
+
+
+def PrintCategories():
+ """Prints a list of all the error-categories used by error messages.
+
+ These are the categories used to filter messages via --filter.
+ """
+ sys.stderr.write(''.join(' %s\n' % cat for cat in _ERROR_CATEGORIES))
+ sys.exit(0)
+
+
+def ParseArguments(args):
+ """Parses the command line arguments.
+
+ This may set the output format and verbosity level as side-effects.
+
+ Args:
+ args: The command line arguments:
+
+ Returns:
+ The list of filenames to lint.
+ """
+ try:
+ (opts, filenames) = getopt.getopt(args, '', ['help', 'output=', 'verbose=',
+ 'counting=',
+ 'filter=',
+ 'root=',
+ 'repository=',
+ 'linelength=',
+ 'extensions=',
+ 'exclude=',
+ 'headers=',
+ 'quiet',
+ 'recursive'])
+ except getopt.GetoptError:
+ PrintUsage('Invalid arguments.')
+
+ verbosity = _VerboseLevel()
+ output_format = _OutputFormat()
+ filters = ''
+ counting_style = ''
+ recursive = False
+
+ for (opt, val) in opts:
+ if opt == '--help':
+ PrintUsage(None)
+ elif opt == '--output':
+ if val not in ('emacs', 'vs7', 'eclipse', 'junit'):
+ PrintUsage('The only allowed output formats are emacs, vs7, eclipse '
+ 'and junit.')
+ output_format = val
+ elif opt == '--verbose':
+ verbosity = int(val)
+ elif opt == '--filter':
+ filters = val
+ if not filters:
+ PrintCategories()
+ elif opt == '--counting':
+ if val not in ('total', 'toplevel', 'detailed'):
+ PrintUsage('Valid counting options are total, toplevel, and detailed')
+ counting_style = val
+ elif opt == '--root':
+ global _root
+ _root = val
+ elif opt == '--repository':
+ global _repository
+ _repository = val
+ elif opt == '--linelength':
+ global _line_length
+ try:
+ _line_length = int(val)
+ except ValueError:
+ PrintUsage('Line length must be digits.')
+ elif opt == '--exclude':
+ global _excludes
+ if not _excludes:
+ _excludes = set()
+ _excludes.update(glob.glob(val))
+ elif opt == '--extensions':
+ global _valid_extensions
+ try:
+ _valid_extensions = set(val.split(','))
+ except ValueError:
+ PrintUsage('Extensions must be comma separated list.')
+ elif opt == '--headers':
+ global _header_extensions
+ try:
+ _header_extensions = set(val.split(','))
+ except ValueError:
+ PrintUsage('Extensions must be comma separated list.')
+ elif opt == '--recursive':
+ recursive = True
+ elif opt == '--quiet':
+ global _quiet
+ _quiet = True
+
+ if not filenames:
+ PrintUsage('No files were specified.')
+
+ if recursive:
+ filenames = _ExpandDirectories(filenames)
+
+ if _excludes:
+ filenames = _FilterExcludedFiles(filenames)
+
+ _SetOutputFormat(output_format)
+ _SetVerboseLevel(verbosity)
+ _SetFilters(filters)
+ _SetCountingStyle(counting_style)
+
+ return filenames
+
+def _ExpandDirectories(filenames):
+ """Searches a list of filenames and replaces directories in the list with
+ all files descending from those directories. Files with extensions not in
+ the valid extensions list are excluded.
+
+ Args:
+ filenames: A list of files or directories
+
+ Returns:
+ A list of all files that are members of filenames or descended from a
+ directory in filenames
+ """
+ expanded = set()
+ for filename in filenames:
+ if not os.path.isdir(filename):
+ expanded.add(filename)
+ continue
+
+ for root, _, files in os.walk(filename):
+ for loopfile in files:
+ fullname = os.path.join(root, loopfile)
+ if fullname.startswith('.' + os.path.sep):
+ fullname = fullname[len('.' + os.path.sep):]
+ expanded.add(fullname)
+
+ filtered = []
+ for filename in expanded:
+ if os.path.splitext(filename)[1][1:] in GetAllExtensions():
+ filtered.append(filename)
+
+ return filtered
+
+def _FilterExcludedFiles(filenames):
+ """Filters out files listed in the --exclude command line switch. File paths
+ in the switch are evaluated relative to the current working directory
+ """
+ exclude_paths = [os.path.abspath(f) for f in _excludes]
+ return [f for f in filenames if os.path.abspath(f) not in exclude_paths]
+
+def main():
+ filenames = ParseArguments(sys.argv[1:])
+ backup_err = sys.stderr
+ try:
+ # Change stderr to write with replacement characters so we don't die
+ # if we try to print something containing non-ASCII characters.
+ sys.stderr = codecs.StreamReader(sys.stderr, 'replace')
+
+ _cpplint_state.ResetErrorCounts()
+ for filename in filenames:
+ ProcessFile(filename, _cpplint_state.verbose_level)
+ _cpplint_state.PrintErrorCounts()
+
+ if _cpplint_state.output_format == 'junit':
+ sys.stderr.write(_cpplint_state.FormatJUnitXML())
+
+ finally:
+ sys.stderr = backup_err
+
+ sys.exit(_cpplint_state.error_count > 0)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/src/arrow/cpp/build-support/fuzzing/generate_corpuses.sh b/src/arrow/cpp/build-support/fuzzing/generate_corpuses.sh
new file mode 100755
index 000000000..e3f00e647
--- /dev/null
+++ b/src/arrow/cpp/build-support/fuzzing/generate_corpuses.sh
@@ -0,0 +1,59 @@
+#!/usr/bin/env bash
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Generate and pack seed corpus files, for OSS-Fuzz
+
+if [ $# -ne 1 ]; then
+ echo "Usage: $0 <build output dir>"
+ exit 1
+fi
+
+set -ex
+
+CORPUS_DIR=/tmp/corpus
+ARROW_ROOT=$(cd $(dirname $BASH_SOURCE)/../../..; pwd)
+ARROW_CPP=$ARROW_ROOT/cpp
+OUT=$1
+
+# NOTE: name of seed corpus output file should be "<FUZZ TARGET>-seed_corpus.zip"
+# where "<FUZZ TARGET>" is the exact name of the fuzz target executable the
+# seed corpus is generated for.
+
+IPC_INTEGRATION_FILES=$(find ${ARROW_ROOT}/testing/data/arrow-ipc-stream/integration -name "*.stream")
+
+rm -rf ${CORPUS_DIR}
+${OUT}/arrow-ipc-generate-fuzz-corpus -stream ${CORPUS_DIR}
+# Several IPC integration files can have the same name, make sure
+# they all appear in the corpus by numbering the duplicates.
+cp --backup=numbered ${IPC_INTEGRATION_FILES} ${CORPUS_DIR}
+${ARROW_CPP}/build-support/fuzzing/pack_corpus.py ${CORPUS_DIR} ${OUT}/arrow-ipc-stream-fuzz_seed_corpus.zip
+
+rm -rf ${CORPUS_DIR}
+${OUT}/arrow-ipc-generate-fuzz-corpus -file ${CORPUS_DIR}
+${ARROW_CPP}/build-support/fuzzing/pack_corpus.py ${CORPUS_DIR} ${OUT}/arrow-ipc-file-fuzz_seed_corpus.zip
+
+rm -rf ${CORPUS_DIR}
+${OUT}/arrow-ipc-generate-tensor-fuzz-corpus -stream ${CORPUS_DIR}
+${ARROW_CPP}/build-support/fuzzing/pack_corpus.py ${CORPUS_DIR} ${OUT}/arrow-ipc-tensor-stream-fuzz_seed_corpus.zip
+
+rm -rf ${CORPUS_DIR}
+${OUT}/parquet-arrow-generate-fuzz-corpus ${CORPUS_DIR}
+# Add Parquet testing examples
+cp ${ARROW_CPP}/submodules/parquet-testing/data/*.parquet ${CORPUS_DIR}
+${ARROW_CPP}/build-support/fuzzing/pack_corpus.py ${CORPUS_DIR} ${OUT}/parquet-arrow-fuzz_seed_corpus.zip
diff --git a/src/arrow/cpp/build-support/fuzzing/pack_corpus.py b/src/arrow/cpp/build-support/fuzzing/pack_corpus.py
new file mode 100755
index 000000000..2064fed60
--- /dev/null
+++ b/src/arrow/cpp/build-support/fuzzing/pack_corpus.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Rename a bunch of corpus files to their SHA1 hashes, and
+# pack them into a ZIP archive.
+
+import hashlib
+from pathlib import Path
+import sys
+import zipfile
+
+
+def process_dir(corpus_dir, zip_output):
+ seen = set()
+
+ for child in corpus_dir.iterdir():
+ if not child.is_file():
+ raise IOError("Not a file: {0}".format(child))
+ with child.open('rb') as f:
+ data = f.read()
+ arcname = hashlib.sha1(data).hexdigest()
+ if arcname in seen:
+ raise ValueError("Duplicate hash: {0} (in file {1})"
+ .format(arcname, child))
+ zip_output.writestr(str(arcname), data)
+ seen.add(arcname)
+
+
+def main(corpus_dir, zip_output_name):
+ with zipfile.ZipFile(zip_output_name, 'w') as zip_output:
+ process_dir(Path(corpus_dir), zip_output)
+
+
+if __name__ == "__main__":
+ if len(sys.argv) != 3:
+ print("Usage: {0} <corpus dir> <output zip file>".format(sys.argv[0]))
+ sys.exit(1)
+ main(sys.argv[1], sys.argv[2])
diff --git a/src/arrow/cpp/build-support/get-upstream-commit.sh b/src/arrow/cpp/build-support/get-upstream-commit.sh
new file mode 100755
index 000000000..779cb012b
--- /dev/null
+++ b/src/arrow/cpp/build-support/get-upstream-commit.sh
@@ -0,0 +1,25 @@
+#!/usr/bin/env bash
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Script which tries to determine the most recent git hash in the current
+# branch which originates from master by checking for the
+# 'ARROW-1234: Description` commit message
+set -e
+
+git log --grep='^ARROW-[0-9]*:.*' -n1 --pretty=format:%H
diff --git a/src/arrow/cpp/build-support/iwyu/iwyu-filter.awk b/src/arrow/cpp/build-support/iwyu/iwyu-filter.awk
new file mode 100644
index 000000000..943ab115c
--- /dev/null
+++ b/src/arrow/cpp/build-support/iwyu/iwyu-filter.awk
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#
+# This is an awk script to process output from the include-what-you-use (IWYU)
+# tool. As of now, IWYU is of alpha quality and it gives many incorrect
+# recommendations -- obviously invalid or leading to compilation breakage.
+# Most of those can be silenced using appropriate IWYU pragmas, but it's not
+# the case for the auto-generated files.
+#
+# Also, it's possible to address invalid recommendation using mappings:
+# https://github.com/include-what-you-use/include-what-you-use/blob/master/docs/IWYUMappings.md
+#
+# Usage:
+# 1. Run the CMake with -DCMAKE_CXX_INCLUDE_WHAT_YOU_USE=<iwyu_cmd_line>
+#
+# The path to the IWYU binary should be absolute. The path to the binary
+# and the command-line options should be separated by semicolon
+# (that's for feeding it into CMake list variables).
+#
+# E.g., from the build directory (line breaks are just for readability):
+#
+# CC=../../thirdparty/clang-toolchain/bin/clang
+# CXX=../../thirdparty/clang-toolchain/bin/clang++
+# IWYU="`pwd`../../thirdparty/clang-toolchain/bin/include-what-you-use;\
+# -Xiwyu;--mapping_file=`pwd`../../build-support/iwyu/mappings/map.imp"
+#
+# ../../build-support/enable_devtoolset.sh \
+# env CC=$CC CXX=$CXX \
+# ../../thirdparty/installed/common/bin/cmake \
+# -DCMAKE_CXX_INCLUDE_WHAT_YOU_USE=\"$IWYU\" \
+# ../..
+#
+# NOTE:
+# Since the arrow code has some 'ifdef NDEBUG' directives, it's possible
+# that IWYU would produce different results if run against release, not
+# debug build. However, we plan to use the tool only with debug builds.
+#
+# 2. Run make, separating the output from the IWYU tool into a separate file
+# (it's possible to use piping the output from the tool to the script
+# but having a file is good for future reference, if necessary):
+#
+# make -j$(nproc) 2>/tmp/iwyu.log
+#
+# 3. Process the output from the IWYU tool using the script:
+#
+# awk -f ../../build-support/iwyu/iwyu-filter.awk /tmp/iwyu.log
+#
+
+BEGIN {
+ # This is the list of the files for which the suggestions from IWYU are
+ # ignored. Eventually, this list should become empty as soon as all the valid
+ # suggestions are addressed and invalid ones are taken care either by proper
+ # IWYU pragmas or adding special mappings (e.g. like boost mappings).
+ # muted["relative/path/to/file"]
+ muted["arrow/util/bit-util-test.cc"]
+ muted["arrow/util/rle-encoding-test.cc"]
+ muted["arrow/vendored"]
+ muted["include/hdfs.h"]
+ muted["arrow/visitor.h"]
+}
+
+# mute all suggestions for the auto-generated files
+/.*\.(pb|proxy|service)\.(cc|h) should (add|remove) these lines:/, /^$/ {
+ next
+}
+
+# mute suggestions for the explicitly specified files
+/.* should (add|remove) these lines:/ {
+ do_print = 1
+ for (path in muted) {
+ if (index($0, path)) {
+ do_print = 0
+ break
+ }
+ }
+}
+/^$/ {
+ if (do_print) print
+ do_print = 0
+}
+{ if (do_print) print }
diff --git a/src/arrow/cpp/build-support/iwyu/iwyu.sh b/src/arrow/cpp/build-support/iwyu/iwyu.sh
new file mode 100755
index 000000000..55e39d772
--- /dev/null
+++ b/src/arrow/cpp/build-support/iwyu/iwyu.sh
@@ -0,0 +1,90 @@
+#!/usr/bin/env bash
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+set -uo pipefail
+
+ROOT=$(cd $(dirname $BASH_SOURCE)/../../..; pwd)
+
+IWYU_LOG=$(mktemp -t arrow-cpp-iwyu.XXXXXX)
+trap "rm -f $IWYU_LOG" EXIT
+
+IWYU_MAPPINGS_PATH="$ROOT/cpp/build-support/iwyu/mappings"
+IWYU_ARGS="\
+ --mapping_file=$IWYU_MAPPINGS_PATH/boost-all.imp \
+ --mapping_file=$IWYU_MAPPINGS_PATH/boost-all-private.imp \
+ --mapping_file=$IWYU_MAPPINGS_PATH/boost-extra.imp \
+ --mapping_file=$IWYU_MAPPINGS_PATH/gflags.imp \
+ --mapping_file=$IWYU_MAPPINGS_PATH/glog.imp \
+ --mapping_file=$IWYU_MAPPINGS_PATH/gmock.imp \
+ --mapping_file=$IWYU_MAPPINGS_PATH/gtest.imp \
+ --mapping_file=$IWYU_MAPPINGS_PATH/arrow-misc.imp"
+
+set -e
+
+affected_files() {
+ pushd $ROOT > /dev/null
+ local commit=$($ROOT/cpp/build-support/get-upstream-commit.sh)
+ git diff --name-only $commit | awk '/\.(c|cc|h)$/'
+ popd > /dev/null
+}
+
+# Show the IWYU version. Also causes the script to fail if iwyu is not in your
+# PATH
+include-what-you-use --version
+
+if [[ "${1:-}" == "all" ]]; then
+ python $ROOT/cpp/build-support/iwyu/iwyu_tool.py -p ${IWYU_COMPILATION_DATABASE_PATH:-.} \
+ -- $IWYU_ARGS | awk -f $ROOT/cpp/build-support/iwyu/iwyu-filter.awk
+elif [[ "${1:-}" == "match" ]]; then
+ ALL_FILES=
+ IWYU_FILE_LIST=
+ for path in $(find $ROOT/cpp/src -type f | awk '/\.(c|cc|h)$/'); do
+ if [[ $path =~ $2 ]]; then
+ IWYU_FILE_LIST="$IWYU_FILE_LIST $path"
+ fi
+ done
+
+ echo "Running IWYU on $IWYU_FILE_LIST"
+ python $ROOT/cpp/build-support/iwyu/iwyu_tool.py \
+ -p ${IWYU_COMPILATION_DATABASE_PATH:-.} $IWYU_FILE_LIST -- \
+ $IWYU_ARGS | awk -f $ROOT/cpp/build-support/iwyu/iwyu-filter.awk
+else
+ # Build the list of updated files which are of IWYU interest.
+ file_list_tmp=$(affected_files)
+ if [ -z "$file_list_tmp" ]; then
+ exit 0
+ fi
+
+ # Adjust the path for every element in the list. The iwyu_tool.py normalizes
+ # paths (via realpath) to match the records from the compilation database.
+ IWYU_FILE_LIST=
+ for p in $file_list_tmp; do
+ IWYU_FILE_LIST="$IWYU_FILE_LIST $ROOT/$p"
+ done
+
+ python $ROOT/cpp/build-support/iwyu/iwyu_tool.py \
+ -p ${IWYU_COMPILATION_DATABASE_PATH:-.} $IWYU_FILE_LIST -- \
+ $IWYU_ARGS | awk -f $ROOT/cpp/build-support/iwyu/iwyu-filter.awk > $IWYU_LOG
+fi
+
+if [ -s "$IWYU_LOG" ]; then
+ # The output is not empty: the changelist needs correction.
+ cat $IWYU_LOG 1>&2
+ exit 1
+fi
diff --git a/src/arrow/cpp/build-support/iwyu/iwyu_tool.py b/src/arrow/cpp/build-support/iwyu/iwyu_tool.py
new file mode 100755
index 000000000..1429e0c0e
--- /dev/null
+++ b/src/arrow/cpp/build-support/iwyu/iwyu_tool.py
@@ -0,0 +1,280 @@
+#!/usr/bin/env python
+
+# This file has been imported into the apache source tree from
+# the IWYU source tree as of version 0.8
+# https://github.com/include-what-you-use/include-what-you-use/blob/master/iwyu_tool.py
+# and corresponding license has been added:
+# https://github.com/include-what-you-use/include-what-you-use/blob/master/LICENSE.TXT
+#
+# ==============================================================================
+# LLVM Release License
+# ==============================================================================
+# University of Illinois/NCSA
+# Open Source License
+#
+# Copyright (c) 2003-2010 University of Illinois at Urbana-Champaign.
+# All rights reserved.
+#
+# Developed by:
+#
+# LLVM Team
+#
+# University of Illinois at Urbana-Champaign
+#
+# http://llvm.org
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal with
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is furnished to do
+# so, subject to the following conditions:
+#
+# * Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimers.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimers in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the names of the LLVM Team, University of Illinois at
+# Urbana-Champaign, nor the names of its contributors may be used to
+# endorse or promote products derived from this Software without specific
+# prior written permission.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE
+# SOFTWARE.
+
+""" Driver to consume a Clang compilation database and invoke IWYU.
+
+Example usage with CMake:
+
+ # Unix systems
+ $ mkdir build && cd build
+ $ CC="clang" CXX="clang++" cmake -DCMAKE_EXPORT_COMPILE_COMMANDS=ON ...
+ $ iwyu_tool.py -p .
+
+ # Windows systems
+ $ mkdir build && cd build
+ $ cmake -DCMAKE_CXX_COMPILER="%VCINSTALLDIR%/bin/cl.exe" \
+ -DCMAKE_C_COMPILER="%VCINSTALLDIR%/VC/bin/cl.exe" \
+ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
+ -G Ninja ...
+ $ python iwyu_tool.py -p .
+
+See iwyu_tool.py -h for more details on command-line arguments.
+"""
+
+import os
+import sys
+import json
+import argparse
+import subprocess
+import re
+
+import logging
+
+logging.basicConfig(filename='iwyu.log')
+LOGGER = logging.getLogger("iwyu")
+
+
+def iwyu_formatter(output):
+ """ Process iwyu's output, basically a no-op. """
+ print('\n'.join(output))
+
+
+CORRECT_RE = re.compile(r'^\((.*?) has correct #includes/fwd-decls\)$')
+SHOULD_ADD_RE = re.compile(r'^(.*?) should add these lines:$')
+SHOULD_REMOVE_RE = re.compile(r'^(.*?) should remove these lines:$')
+FULL_LIST_RE = re.compile(r'The full include-list for (.*?):$')
+END_RE = re.compile(r'^---$')
+LINES_RE = re.compile(r'^- (.*?) // lines ([0-9]+)-[0-9]+$')
+
+
+GENERAL, ADD, REMOVE, LIST = range(4)
+
+
+def clang_formatter(output):
+ """ Process iwyu's output into something clang-like. """
+ state = (GENERAL, None)
+ for line in output:
+ match = CORRECT_RE.match(line)
+ if match:
+ print('%s:1:1: note: #includes/fwd-decls are correct', match.groups(1))
+ continue
+ match = SHOULD_ADD_RE.match(line)
+ if match:
+ state = (ADD, match.group(1))
+ continue
+ match = SHOULD_REMOVE_RE.match(line)
+ if match:
+ state = (REMOVE, match.group(1))
+ continue
+ match = FULL_LIST_RE.match(line)
+ if match:
+ state = (LIST, match.group(1))
+ elif END_RE.match(line):
+ state = (GENERAL, None)
+ elif not line.strip():
+ continue
+ elif state[0] == GENERAL:
+ print(line)
+ elif state[0] == ADD:
+ print('%s:1:1: error: add the following line', state[1])
+ print(line)
+ elif state[0] == REMOVE:
+ match = LINES_RE.match(line)
+ line_no = match.group(2) if match else '1'
+ print('%s:%s:1: error: remove the following line', state[1], line_no)
+ print(match.group(1))
+
+
+DEFAULT_FORMAT = 'iwyu'
+FORMATTERS = {
+ 'iwyu': iwyu_formatter,
+ 'clang': clang_formatter
+}
+
+
+def get_output(cwd, command):
+ """ Run the given command and return its output as a string. """
+ process = subprocess.Popen(command,
+ cwd=cwd,
+ shell=True,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT)
+ return process.communicate()[0].decode("utf-8").splitlines()
+
+
+def run_iwyu(cwd, compile_command, iwyu_args, verbose, formatter):
+ """ Rewrite compile_command to an IWYU command, and run it. """
+ compiler, _, args = compile_command.partition(' ')
+ if compiler.endswith('cl.exe'):
+ # If the compiler name is cl.exe, let IWYU be cl-compatible
+ clang_args = ['--driver-mode=cl']
+ else:
+ clang_args = []
+
+ iwyu_args = ['-Xiwyu ' + a for a in iwyu_args]
+ command = ['include-what-you-use'] + clang_args + iwyu_args
+ command = '%s %s' % (' '.join(command), args.strip())
+
+ if verbose:
+ print('%s:', command)
+
+ formatter(get_output(cwd, command))
+
+
+def main(compilation_db_path, source_files, verbose, formatter, iwyu_args):
+ """ Entry point. """
+ # Canonicalize compilation database path
+ if os.path.isdir(compilation_db_path):
+ compilation_db_path = os.path.join(compilation_db_path,
+ 'compile_commands.json')
+
+ compilation_db_path = os.path.realpath(compilation_db_path)
+ if not os.path.isfile(compilation_db_path):
+ print('ERROR: No such file or directory: \'%s\'', compilation_db_path)
+ return 1
+
+ # Read compilation db from disk
+ with open(compilation_db_path, 'r') as fileobj:
+ compilation_db = json.load(fileobj)
+
+ # expand symlinks
+ for entry in compilation_db:
+ entry['file'] = os.path.realpath(entry['file'])
+
+ # Cross-reference source files with compilation database
+ source_files = [os.path.realpath(s) for s in source_files]
+ if not source_files:
+ # No source files specified, analyze entire compilation database
+ entries = compilation_db
+ else:
+ # Source files specified, analyze the ones appearing in compilation db,
+ # warn for the rest.
+ entries = []
+ for source in source_files:
+ matches = [e for e in compilation_db if e['file'] == source]
+ if matches:
+ entries.extend(matches)
+ else:
+ print("{} not in compilation database".format(source))
+ # TODO: As long as there is no complete compilation database available this check cannot be performed
+ pass
+ #print('WARNING: \'%s\' not found in compilation database.', source)
+
+ # Run analysis
+ try:
+ for entry in entries:
+ cwd, compile_command = entry['directory'], entry['command']
+ run_iwyu(cwd, compile_command, iwyu_args, verbose, formatter)
+ except OSError as why:
+ print('ERROR: Failed to launch include-what-you-use: %s', why)
+ return 1
+
+ return 0
+
+
+def _bootstrap():
+ """ Parse arguments and dispatch to main(). """
+ # This hackery is necessary to add the forwarded IWYU args to the
+ # usage and help strings.
+ def customize_usage(parser):
+ """ Rewrite the parser's format_usage. """
+ original_format_usage = parser.format_usage
+ parser.format_usage = lambda: original_format_usage().rstrip() + \
+ ' -- [<IWYU args>]' + os.linesep
+
+ def customize_help(parser):
+ """ Rewrite the parser's format_help. """
+ original_format_help = parser.format_help
+
+ def custom_help():
+ """ Customized help string, calls the adjusted format_usage. """
+ helpmsg = original_format_help()
+ helplines = helpmsg.splitlines()
+ helplines[0] = parser.format_usage().rstrip()
+ return os.linesep.join(helplines) + os.linesep
+
+ parser.format_help = custom_help
+
+ # Parse arguments
+ parser = argparse.ArgumentParser(
+ description='Include-what-you-use compilation database driver.',
+ epilog='Assumes include-what-you-use is available on the PATH.')
+ customize_usage(parser)
+ customize_help(parser)
+
+ parser.add_argument('-v', '--verbose', action='store_true',
+ help='Print IWYU commands')
+ parser.add_argument('-o', '--output-format', type=str,
+ choices=FORMATTERS.keys(), default=DEFAULT_FORMAT,
+ help='Output format (default: %s)' % DEFAULT_FORMAT)
+ parser.add_argument('-p', metavar='<build-path>', required=True,
+ help='Compilation database path', dest='dbpath')
+ parser.add_argument('source', nargs='*',
+ help='Zero or more source files to run IWYU on. '
+ 'Defaults to all in compilation database.')
+
+ def partition_args(argv):
+ """ Split around '--' into driver args and IWYU args. """
+ try:
+ double_dash = argv.index('--')
+ return argv[:double_dash], argv[double_dash+1:]
+ except ValueError:
+ return argv, []
+ argv, iwyu_args = partition_args(sys.argv[1:])
+ args = parser.parse_args(argv)
+
+ sys.exit(main(args.dbpath, args.source, args.verbose,
+ FORMATTERS[args.output_format], iwyu_args))
+
+
+if __name__ == '__main__':
+ _bootstrap()
diff --git a/src/arrow/cpp/build-support/iwyu/mappings/arrow-misc.imp b/src/arrow/cpp/build-support/iwyu/mappings/arrow-misc.imp
new file mode 100644
index 000000000..6f144f1f3
--- /dev/null
+++ b/src/arrow/cpp/build-support/iwyu/mappings/arrow-misc.imp
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+[
+ { include: ["<ext/new_allocator.h>", private, "<cstddef>", public ] },
+ { include: ["<ext/alloc_traits.h>", private, "<memory>", public ] },
+ { include: ["<ext/alloc_traits.h>", private, "<condition_variable>", public ] },
+ { include: ["<ext/alloc_traits.h>", private, "<deque>", public ] },
+ { include: ["<ext/alloc_traits.h>", private, "<forward_list>", public ] },
+ { include: ["<ext/alloc_traits.h>", private, "<future>", public ] },
+ { include: ["<ext/alloc_traits.h>", private, "<map>", public ] },
+ { include: ["<ext/alloc_traits.h>", private, "<set>", public ] },
+ { include: ["<ext/alloc_traits.h>", private, "<string>", public ] },
+ { include: ["<ext/alloc_traits.h>", private, "<unordered_map>", public ] },
+ { include: ["<ext/alloc_traits.h>", private, "<unordered_set>", public ] },
+ { include: ["<ext/alloc_traits.h>", private, "<vector>", public ] },
+ { include: ["<bits/exception.h>", private, "<exception>", public ] },
+ { include: ["<bits/stdint-intn.h>", private, "<cstdint>", public ] },
+ { include: ["<bits/stdint-uintn.h>", private, "<cstdint>", public ] },
+ { include: ["<bits/shared_ptr.h>", private, "<memory>", public ] },
+ { include: ["<initializer_list>", public, "<vector>", public ] },
+ { include: ["<stdint.h>", public, "<cstdint>", public ] },
+ { include: ["<string.h>", public, "<cstring>", public ] },
+ { symbol: ["bool", private, "<cstdint>", public ] },
+ { symbol: ["false", private, "<cstdint>", public ] },
+ { symbol: ["true", private, "<cstdint>", public ] },
+ { symbol: ["int8_t", private, "<cstdint>", public ] },
+ { symbol: ["int16_t", private, "<cstdint>", public ] },
+ { symbol: ["int32_t", private, "<cstdint>", public ] },
+ { symbol: ["int64_t", private, "<cstdint>", public ] },
+ { symbol: ["uint8_t", private, "<cstdint>", public ] },
+ { symbol: ["uint16_t", private, "<cstdint>", public ] },
+ { symbol: ["uint32_t", private, "<cstdint>", public ] },
+ { symbol: ["uint64_t", private, "<cstdint>", public ] },
+ { symbol: ["size_t", private, "<cstddef>", public ] },
+ { symbol: ["variant", private, "\"arrow/compute/kernel.h\"", public ] },
+ { symbol: ["default_memory_pool", private, "\"arrow/type_fwd.h\"", public ] },
+ { symbol: ["make_shared", private, "<memory>", public ] },
+ { symbol: ["shared_ptr", private, "<memory>", public ] },
+ { symbol: ["_Node_const_iterator", private, "<flatbuffers/flatbuffers.h>", public ] },
+ { symbol: ["unordered_map<>::mapped_type", private, "<flatbuffers/flatbuffers.h>", public ] },
+ { symbol: ["std::copy", private, "<algorithm>", public ] },
+ { symbol: ["std::move", private, "<utility>", public ] },
+ { symbol: ["std::transform", private, "<algorithm>", public ] },
+ { symbol: ["pair", private, "<utility>", public ] },
+ { symbol: ["errno", private, "<cerrno>", public ] },
+ { symbol: ["posix_memalign", private, "<cstdlib>", public ] }
+]
diff --git a/src/arrow/cpp/build-support/iwyu/mappings/boost-all-private.imp b/src/arrow/cpp/build-support/iwyu/mappings/boost-all-private.imp
new file mode 100644
index 000000000..133eef113
--- /dev/null
+++ b/src/arrow/cpp/build-support/iwyu/mappings/boost-all-private.imp
@@ -0,0 +1,4166 @@
+# This file has been imported into the arrow source tree from
+# the IWYU source tree as of version 0.8
+# https://github.com/include-what-you-use/include-what-you-use/blob/master/boost-all-private.imp
+# and corresponding license has been added:
+# https://github.com/include-what-you-use/include-what-you-use/blob/master/LICENSE.TXT
+#
+# ==============================================================================
+# LLVM Release License
+# ==============================================================================
+# University of Illinois/NCSA
+# Open Source License
+#
+# Copyright (c) 2003-2010 University of Illinois at Urbana-Champaign.
+# All rights reserved.
+#
+# Developed by:
+#
+# LLVM Team
+#
+# University of Illinois at Urbana-Champaign
+#
+# http://llvm.org
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal with
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is furnished to do
+# so, subject to the following conditions:
+#
+# * Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimers.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimers in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the names of the LLVM Team, University of Illinois at
+# Urbana-Champaign, nor the names of its contributors may be used to
+# endorse or promote products derived from this Software without specific
+# prior written permission.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE
+# SOFTWARE.
+
+[
+#grep -r '^ *# *include' boost/ | grep -e "boost/[^:]*/detail/.*hp*:" -e "boost/[^:]*/impl/.*hp*:" | grep -e "\:.*/detail/" -e "\:.*/impl/" | perl -nle 'm/^([^:]+).*["<]([^>]+)[">]/ && print qq@ { include: ["<$2>", private, "<$1>", private ] },@' | grep -e \\[\"\<boost/ | sort -u
+#remove circular dependencies
+# boost/fusion/container/set/detail/value_of_data_impl.hpp with itself...
+#
+# { include: ["<boost/numeric/odeint/integrate/detail/integrate_adaptive.hpp>", private, "<boost/numeric/odeint/integrate/detail/integrate_n_steps.hpp>", private ] },
+# { include: ["<boost/numeric/odeint/integrate/detail/integrate_n_steps.hpp>", private, "<boost/numeric/odeint/integrate/detail/integrate_adaptive.hpp>", private ] },
+#
+# { include: ["<boost/python/detail/type_list.hpp>", private, "<boost/python/detail/type_list_impl.hpp>", private ] },
+# { include: ["<boost/python/detail/type_list.hpp>", private, "<boost/python/detail/type_list_impl_no_pts.hpp>", private ] },
+# { include: ["<boost/python/detail/type_list_impl.hpp>", private, "<boost/python/detail/type_list.hpp>", private ] },
+# { include: ["<boost/python/detail/type_list_impl_no_pts.hpp>", private, "<boost/python/detail/type_list.hpp>", private ] },
+
+ { include: ["<boost/accumulators/numeric/detail/function_n.hpp>", private, "<boost/accumulators/numeric/detail/function2.hpp>", private ] },
+ { include: ["<boost/accumulators/numeric/detail/function_n.hpp>", private, "<boost/accumulators/numeric/detail/function3.hpp>", private ] },
+ { include: ["<boost/accumulators/numeric/detail/function_n.hpp>", private, "<boost/accumulators/numeric/detail/function4.hpp>", private ] },
+ { include: ["<boost/algorithm/searching/detail/debugging.hpp>", private, "<boost/algorithm/searching/detail/bm_traits.hpp>", private ] },
+ { include: ["<boost/algorithm/string/detail/finder_regex.hpp>", private, "<boost/algorithm/string/detail/formatter_regex.hpp>", private ] },
+ { include: ["<boost/algorithm/string/detail/find_format_store.hpp>", private, "<boost/algorithm/string/detail/find_format_all.hpp>", private ] },
+ { include: ["<boost/algorithm/string/detail/find_format_store.hpp>", private, "<boost/algorithm/string/detail/find_format.hpp>", private ] },
+ { include: ["<boost/algorithm/string/detail/replace_storage.hpp>", private, "<boost/algorithm/string/detail/find_format_all.hpp>", private ] },
+ { include: ["<boost/algorithm/string/detail/replace_storage.hpp>", private, "<boost/algorithm/string/detail/find_format.hpp>", private ] },
+ { include: ["<boost/algorithm/string/detail/sequence.hpp>", private, "<boost/algorithm/string/detail/replace_storage.hpp>", private ] },
+ { include: ["<boost/algorithm/string/detail/util.hpp>", private, "<boost/algorithm/string/detail/formatter.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/detail/archive_serializer_map.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/detail/basic_archive_impl.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/detail/basic_iarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/detail/basic_iserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/detail/basic_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/detail/basic_oserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/detail/basic_pointer_iserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/detail/basic_pointer_oserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/detail/basic_serializer_map.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/detail/interface_iarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/detail/interface_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/detail/polymorphic_iarchive_route.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/detail/polymorphic_oarchive_route.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/detail/archive_serializer_map.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/detail/basic_archive_impl.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/detail/basic_iarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/detail/basic_iserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/detail/basic_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/detail/basic_oserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/detail/basic_pointer_iserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/detail/basic_pointer_oserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/detail/basic_serializer_map.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/detail/interface_iarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/detail/interface_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/detail/polymorphic_iarchive_route.hpp>", private ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/detail/polymorphic_oarchive_route.hpp>", private ] },
+ { include: ["<boost/archive/detail/archive_serializer_map.hpp>", private, "<boost/archive/detail/iserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/archive_serializer_map.hpp>", private, "<boost/archive/detail/oserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/detail/archive_serializer_map.hpp>", private ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/detail/basic_iserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/detail/basic_oserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/detail/basic_pointer_iserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/detail/basic_pointer_oserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/detail/basic_serializer_map.hpp>", private ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/detail/interface_iarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/detail/interface_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/mpi/detail/content_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/mpi/detail/forward_skeleton_iarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/mpi/detail/forward_skeleton_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/mpi/detail/ignore_skeleton_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/mpi/detail/mpi_datatype_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/mpi/detail/text_skeleton_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/basic_iarchive.hpp>", private, "<boost/archive/detail/common_iarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/basic_iarchive.hpp>", private, "<boost/archive/detail/iserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/basic_iserializer.hpp>", private, "<boost/archive/detail/iserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/basic_oarchive.hpp>", private, "<boost/archive/detail/common_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/basic_oarchive.hpp>", private, "<boost/archive/detail/oserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/basic_oserializer.hpp>", private, "<boost/archive/detail/oserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/basic_pointer_iserializer.hpp>", private, "<boost/archive/detail/common_iarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/basic_pointer_iserializer.hpp>", private, "<boost/archive/detail/iserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/basic_pointer_oserializer.hpp>", private, "<boost/archive/detail/oserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/basic_serializer.hpp>", private, "<boost/archive/detail/basic_iserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/basic_serializer.hpp>", private, "<boost/archive/detail/basic_oserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/basic_serializer.hpp>", private, "<boost/archive/detail/basic_pointer_iserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/basic_serializer.hpp>", private, "<boost/archive/detail/basic_pointer_oserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/check.hpp>", private, "<boost/archive/detail/iserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/check.hpp>", private, "<boost/archive/detail/oserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/common_iarchive.hpp>", private, "<boost/mpi/detail/forward_skeleton_iarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/common_oarchive.hpp>", private, "<boost/mpi/detail/forward_skeleton_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/common_oarchive.hpp>", private, "<boost/mpi/detail/ignore_skeleton_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/decl.hpp>", private, "<boost/archive/detail/auto_link_archive.hpp>", private ] },
+ { include: ["<boost/archive/detail/decl.hpp>", private, "<boost/archive/detail/auto_link_warchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/decl.hpp>", private, "<boost/archive/detail/basic_iarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/decl.hpp>", private, "<boost/archive/detail/basic_iserializer.hpp>", private ] },
+ { include: ["<boost/archive/detail/interface_iarchive.hpp>", private, "<boost/archive/detail/common_iarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/interface_iarchive.hpp>", private, "<boost/mpi/detail/forward_skeleton_iarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/interface_oarchive.hpp>", private, "<boost/archive/detail/common_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/interface_oarchive.hpp>", private, "<boost/mpi/detail/forward_skeleton_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/iserializer.hpp>", private, "<boost/archive/detail/interface_iarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/iserializer.hpp>", private, "<boost/mpi/detail/forward_skeleton_iarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/oserializer.hpp>", private, "<boost/archive/detail/interface_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/oserializer.hpp>", private, "<boost/mpi/detail/forward_skeleton_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/oserializer.hpp>", private, "<boost/mpi/detail/ignore_skeleton_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/oserializer.hpp>", private, "<boost/mpi/detail/mpi_datatype_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/mpi/detail/content_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/mpi/detail/mpi_datatype_oarchive.hpp>", private ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/mpi/detail/text_skeleton_oarchive.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/completion_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/deadline_timer_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/descriptor_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/descriptor_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/handler_alloc_helpers.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/handler_cont_helpers.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/handler_invoke_helpers.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/impl/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/impl/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/impl/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/reactive_descriptor_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/reactive_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/reactive_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/reactive_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/reactive_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/reactive_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/reactive_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/reactive_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/reactive_socket_sendto_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/reactive_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/resolve_endpoint_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/signal_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/signal_set_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/wait_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/win_iocp_handle_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/win_iocp_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/win_iocp_handle_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/win_iocp_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/win_iocp_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/win_iocp_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/win_iocp_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/win_object_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/winrt_resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/winrt_resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/winrt_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/winrt_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/winrt_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/winrt_ssocket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/winrt_ssocket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/addressof.hpp>", private, "<boost/asio/detail/winrt_utils.hpp>", private ] },
+ { include: ["<boost/asio/detail/array_fwd.hpp>", private, "<boost/asio/detail/buffer_sequence_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/array_fwd.hpp>", private, "<boost/asio/impl/read_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/array_fwd.hpp>", private, "<boost/asio/impl/read.hpp>", private ] },
+ { include: ["<boost/asio/detail/array_fwd.hpp>", private, "<boost/asio/impl/write_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/array_fwd.hpp>", private, "<boost/asio/impl/write.hpp>", private ] },
+ { include: ["<boost/asio/detail/assert.hpp>", private, "<boost/asio/detail/buffered_stream_storage.hpp>", private ] },
+ { include: ["<boost/asio/detail/assert.hpp>", private, "<boost/asio/detail/hash_map.hpp>", private ] },
+ { include: ["<boost/asio/detail/assert.hpp>", private, "<boost/asio/detail/posix_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/assert.hpp>", private, "<boost/asio/detail/std_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/assert.hpp>", private, "<boost/asio/detail/win_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/assert.hpp>", private, "<boost/asio/ssl/old/detail/openssl_operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/atomic_count.hpp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/atomic_count.hpp>", private, "<boost/asio/detail/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/atomic_count.hpp>", private, "<boost/asio/detail/winrt_async_manager.hpp>", private ] },
+ { include: ["<boost/asio/detail/base_from_completion_cond.hpp>", private, "<boost/asio/impl/read_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/base_from_completion_cond.hpp>", private, "<boost/asio/impl/read.hpp>", private ] },
+ { include: ["<boost/asio/detail/base_from_completion_cond.hpp>", private, "<boost/asio/impl/write_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/base_from_completion_cond.hpp>", private, "<boost/asio/impl/write.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/deadline_timer_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/descriptor_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/descriptor_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/null_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/reactive_descriptor_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/reactive_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/reactive_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/reactive_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/reactive_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/reactive_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/reactive_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/reactive_socket_sendto_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/resolve_endpoint_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/win_iocp_handle_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/win_iocp_handle_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/win_iocp_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/win_iocp_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/win_iocp_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/win_iocp_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/winrt_resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/winrt_resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/winrt_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/winrt_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/winrt_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/detail/wrapped_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/impl/connect.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/impl/read_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/impl/read.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/impl/read_until.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/impl/write_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/impl/write.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/descriptor_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/descriptor_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/reactive_descriptor_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/reactive_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/reactive_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/reactive_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/reactive_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/reactive_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/reactive_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/reactive_socket_sendto_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/reactive_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/win_iocp_handle_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/win_iocp_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/win_iocp_handle_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/win_iocp_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/win_iocp_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/win_iocp_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/win_iocp_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/winrt_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/winrt_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/winrt_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/detail/winrt_ssocket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/ssl/detail/read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/ssl/detail/write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/ssl/old/detail/openssl_stream_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/call_stack.hpp>", private, "<boost/asio/detail/impl/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/call_stack.hpp>", private, "<boost/asio/detail/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/call_stack.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/completion_handler.hpp>", private, "<boost/asio/detail/impl/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/completion_handler.hpp>", private, "<boost/asio/detail/impl/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/completion_handler.hpp>", private, "<boost/asio/detail/impl/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/addressof.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/array_fwd.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/array.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/assert.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/atomic_count.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/base_from_completion_cond.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/bind_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/buffered_stream_storage.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/buffer_resize_guard.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/buffer_sequence_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/call_stack.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/completion_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/consuming_buffers.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/cstdint.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/date_time_fwd.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/deadline_timer_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/dependent_type.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/descriptor_ops.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/descriptor_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/descriptor_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/eventfd_select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/event.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/fd_set_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/function.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/gcc_arm_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/gcc_hppa_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/gcc_sync_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/gcc_x86_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/handler_alloc_helpers.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/handler_cont_helpers.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/handler_invoke_helpers.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/handler_tracking.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/handler_type_requirements.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/hash_map.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/impl/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/impl/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/impl/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/impl/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/impl/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/io_control.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/keyword_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/limits.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/local_free_on_block_exit.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/macos_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/noncopyable.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/null_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/null_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/null_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/null_signal_blocker.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/null_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/null_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/null_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/null_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/old_win_sdk_compat.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/pipe_select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/posix_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/posix_fd_set_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/posix_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/posix_signal_blocker.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/posix_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/posix_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/posix_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/reactive_descriptor_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/reactive_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/reactive_serial_port_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/reactive_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/reactive_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/reactive_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/reactive_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/reactive_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/reactive_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/reactive_socket_sendto_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/reactive_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/reactor_fwd.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/reactor_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/reactor_op_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/resolve_endpoint_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/resolver_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/scoped_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/service_registry.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/shared_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/signal_blocker.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/signal_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/signal_init.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/signal_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/signal_set_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/socket_holder.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/socket_ops.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/socket_option.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/socket_select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/socket_types.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/solaris_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/std_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/std_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/std_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/std_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/throw_error.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/throw_exception.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/timer_queue_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/timer_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/timer_queue_set.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/timer_scheduler_fwd.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/type_traits.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/variadic_templates.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/wait_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/wait_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/weak_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/wince_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_fd_set_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_handle_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_handle_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_serial_port_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_object_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/winrt_async_manager.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/winrt_async_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/winrt_resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/winrt_resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/winrt_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/winrt_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/winrt_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/winrt_ssocket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/winrt_ssocket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/winrt_utils.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/winsock_init.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/detail/win_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/generic/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/impl/spawn.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/impl/use_future.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/detail/socket_option.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/local/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/detail/buffered_handshake_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/detail/engine.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/detail/handshake_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/detail/io.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/detail/openssl_init.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/detail/openssl_types.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/detail/password_callback.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/detail/read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/detail/shutdown_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/detail/stream_core.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/detail/verify_callback.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/detail/write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/impl/context.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/old/detail/openssl_context_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/old/detail/openssl_operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/old/detail/openssl_stream_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/consuming_buffers.hpp>", private, "<boost/asio/impl/connect.hpp>", private ] },
+ { include: ["<boost/asio/detail/consuming_buffers.hpp>", private, "<boost/asio/impl/read_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/consuming_buffers.hpp>", private, "<boost/asio/impl/read.hpp>", private ] },
+ { include: ["<boost/asio/detail/consuming_buffers.hpp>", private, "<boost/asio/impl/write_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/consuming_buffers.hpp>", private, "<boost/asio/impl/write.hpp>", private ] },
+ { include: ["<boost/asio/detail/cstdint.hpp>", private, "<boost/asio/detail/chrono_time_traits.hpp>", private ] },
+ { include: ["<boost/asio/detail/cstdint.hpp>", private, "<boost/asio/detail/handler_tracking.hpp>", private ] },
+ { include: ["<boost/asio/detail/cstdint.hpp>", private, "<boost/asio/detail/timer_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/cstdint.hpp>", private, "<boost/asio/detail/win_iocp_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/date_time_fwd.hpp>", private, "<boost/asio/detail/timer_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/dependent_type.hpp>", private, "<boost/asio/impl/read_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/dependent_type.hpp>", private, "<boost/asio/impl/read.hpp>", private ] },
+ { include: ["<boost/asio/detail/dependent_type.hpp>", private, "<boost/asio/impl/write_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/dependent_type.hpp>", private, "<boost/asio/impl/write.hpp>", private ] },
+ { include: ["<boost/asio/detail/descriptor_ops.hpp>", private, "<boost/asio/detail/descriptor_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/descriptor_ops.hpp>", private, "<boost/asio/detail/descriptor_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/descriptor_ops.hpp>", private, "<boost/asio/detail/reactive_descriptor_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/descriptor_ops.hpp>", private, "<boost/asio/detail/reactive_serial_port_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/descriptor_read_op.hpp>", private, "<boost/asio/detail/reactive_descriptor_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/descriptor_write_op.hpp>", private, "<boost/asio/detail/reactive_descriptor_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/dev_poll_reactor.hpp>", private, "<boost/asio/detail/reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/dev_poll_reactor.hpp>", private, "<boost/asio/detail/timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/epoll_reactor.hpp>", private, "<boost/asio/detail/reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/epoll_reactor.hpp>", private, "<boost/asio/detail/timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/eventfd_select_interrupter.hpp>", private, "<boost/asio/detail/select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/event.hpp>", private, "<boost/asio/detail/task_io_service_thread_info.hpp>", private ] },
+ { include: ["<boost/asio/detail/event.hpp>", private, "<boost/asio/detail/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/fd_set_adapter.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/completion_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/deadline_timer_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/descriptor_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/descriptor_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/impl/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/impl/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/impl/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/reactive_descriptor_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/reactive_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/reactive_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/reactive_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/reactive_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/reactive_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/reactive_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/reactive_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/reactive_socket_sendto_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/resolve_endpoint_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/signal_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/wait_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/win_iocp_handle_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/win_iocp_handle_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/win_iocp_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/win_iocp_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/win_iocp_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/win_iocp_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/winrt_resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/winrt_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/winrt_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/fenced_block.hpp>", private, "<boost/asio/detail/winrt_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/gcc_arm_fenced_block.hpp>", private, "<boost/asio/detail/fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/gcc_hppa_fenced_block.hpp>", private, "<boost/asio/detail/fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/gcc_sync_fenced_block.hpp>", private, "<boost/asio/detail/fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/gcc_x86_fenced_block.hpp>", private, "<boost/asio/detail/fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/bind_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/completion_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/impl/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/impl/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/impl/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/reactive_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/resolve_endpoint_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/signal_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/signal_set_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/wait_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/win_iocp_handle_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/win_iocp_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/win_iocp_handle_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/win_iocp_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/win_iocp_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/win_iocp_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/win_iocp_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/win_object_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/winrt_resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/winrt_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/winrt_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/winrt_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/detail/wrapped_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/impl/buffered_read_stream.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/impl/buffered_write_stream.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/impl/connect.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/impl/read_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/impl/read.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/impl/read_until.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/impl/spawn.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/impl/write_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_alloc_helpers.hpp>", private, "<boost/asio/impl/write.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_cont_helpers.hpp>", private, "<boost/asio/detail/bind_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_cont_helpers.hpp>", private, "<boost/asio/detail/impl/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_cont_helpers.hpp>", private, "<boost/asio/detail/wrapped_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_cont_helpers.hpp>", private, "<boost/asio/impl/buffered_read_stream.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_cont_helpers.hpp>", private, "<boost/asio/impl/buffered_write_stream.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_cont_helpers.hpp>", private, "<boost/asio/impl/connect.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_cont_helpers.hpp>", private, "<boost/asio/impl/read_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_cont_helpers.hpp>", private, "<boost/asio/impl/read.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_cont_helpers.hpp>", private, "<boost/asio/impl/read_until.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_cont_helpers.hpp>", private, "<boost/asio/impl/spawn.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_cont_helpers.hpp>", private, "<boost/asio/impl/write_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_cont_helpers.hpp>", private, "<boost/asio/impl/write.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/bind_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/completion_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/impl/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/impl/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/impl/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/reactive_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/resolve_endpoint_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/signal_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/wait_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/win_iocp_handle_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/win_iocp_handle_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/win_iocp_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/win_iocp_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/win_iocp_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/win_iocp_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/winrt_resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/winrt_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/winrt_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/winrt_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/detail/wrapped_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/impl/buffered_read_stream.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/impl/buffered_write_stream.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/impl/connect.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/impl/read_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/impl/read.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/impl/read_until.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/impl/spawn.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/impl/write_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_invoke_helpers.hpp>", private, "<boost/asio/impl/write.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_tracking.hpp>", private, "<boost/asio/detail/task_io_service_operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_tracking.hpp>", private, "<boost/asio/detail/win_iocp_operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/impl/buffered_read_stream.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/impl/buffered_write_stream.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/impl/connect.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/impl/io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/impl/read_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/impl/read.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/impl/read_until.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/impl/write_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/impl/write.hpp>", private ] },
+ { include: ["<boost/asio/detail/hash_map.hpp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/hash_map.hpp>", private, "<boost/asio/detail/reactor_op_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/buffer_sequence_adapter.ipp>", private, "<boost/asio/detail/buffer_sequence_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/buffer_sequence_adapter.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/descriptor_ops.ipp>", private, "<boost/asio/detail/descriptor_ops.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/descriptor_ops.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/dev_poll_reactor.hpp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/dev_poll_reactor.ipp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/dev_poll_reactor.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/epoll_reactor.hpp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/epoll_reactor.ipp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/epoll_reactor.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/eventfd_select_interrupter.ipp>", private, "<boost/asio/detail/eventfd_select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/eventfd_select_interrupter.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/handler_tracking.ipp>", private, "<boost/asio/detail/handler_tracking.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/handler_tracking.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/kqueue_reactor.hpp>", private, "<boost/asio/detail/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/kqueue_reactor.ipp>", private, "<boost/asio/detail/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/kqueue_reactor.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/pipe_select_interrupter.ipp>", private, "<boost/asio/detail/pipe_select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/pipe_select_interrupter.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/posix_event.ipp>", private, "<boost/asio/detail/posix_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/posix_event.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/posix_mutex.ipp>", private, "<boost/asio/detail/posix_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/posix_mutex.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/posix_thread.ipp>", private, "<boost/asio/detail/posix_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/posix_thread.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/posix_tss_ptr.ipp>", private, "<boost/asio/detail/posix_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/posix_tss_ptr.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/reactive_descriptor_service.ipp>", private, "<boost/asio/detail/reactive_descriptor_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/reactive_descriptor_service.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/reactive_serial_port_service.ipp>", private, "<boost/asio/detail/reactive_serial_port_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/reactive_serial_port_service.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/reactive_socket_service_base.ipp>", private, "<boost/asio/detail/reactive_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/reactive_socket_service_base.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/resolver_service_base.ipp>", private, "<boost/asio/detail/resolver_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/resolver_service_base.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/select_reactor.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/select_reactor.ipp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/select_reactor.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/service_registry.hpp>", private, "<boost/asio/detail/service_registry.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/service_registry.ipp>", private, "<boost/asio/detail/service_registry.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/service_registry.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/signal_set_service.ipp>", private, "<boost/asio/detail/signal_set_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/signal_set_service.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/socket_ops.ipp>", private, "<boost/asio/detail/socket_ops.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/socket_ops.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/socket_select_interrupter.ipp>", private, "<boost/asio/detail/socket_select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/socket_select_interrupter.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/strand_service.hpp>", private, "<boost/asio/detail/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/strand_service.ipp>", private, "<boost/asio/detail/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/strand_service.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/task_io_service.hpp>", private, "<boost/asio/detail/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/task_io_service.ipp>", private, "<boost/asio/detail/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/task_io_service.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/throw_error.ipp>", private, "<boost/asio/detail/throw_error.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/throw_error.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/timer_queue_ptime.ipp>", private, "<boost/asio/detail/timer_queue_ptime.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/timer_queue_ptime.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/timer_queue_set.ipp>", private, "<boost/asio/detail/timer_queue_set.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/timer_queue_set.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_event.ipp>", private, "<boost/asio/detail/win_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_event.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_iocp_handle_service.ipp>", private, "<boost/asio/detail/win_iocp_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_iocp_handle_service.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_iocp_io_service.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_iocp_io_service.ipp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_iocp_io_service.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_iocp_serial_port_service.ipp>", private, "<boost/asio/detail/win_iocp_serial_port_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_iocp_serial_port_service.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_iocp_socket_service_base.ipp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_iocp_socket_service_base.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_mutex.ipp>", private, "<boost/asio/detail/win_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_mutex.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_object_handle_service.ipp>", private, "<boost/asio/detail/win_object_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_object_handle_service.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/winrt_ssocket_service_base.ipp>", private, "<boost/asio/detail/winrt_ssocket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/winrt_ssocket_service_base.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/winrt_timer_scheduler.hpp>", private, "<boost/asio/detail/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/winrt_timer_scheduler.ipp>", private, "<boost/asio/detail/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/winrt_timer_scheduler.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/winsock_init.ipp>", private, "<boost/asio/detail/winsock_init.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/winsock_init.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_static_mutex.ipp>", private, "<boost/asio/detail/win_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_static_mutex.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_thread.ipp>", private, "<boost/asio/detail/win_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_thread.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_tss_ptr.ipp>", private, "<boost/asio/detail/win_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/impl/win_tss_ptr.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/detail/keyword_tss_ptr.hpp>", private, "<boost/asio/detail/tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/kqueue_reactor.hpp>", private, "<boost/asio/detail/reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/kqueue_reactor.hpp>", private, "<boost/asio/detail/timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/limits.hpp>", private, "<boost/asio/detail/buffer_resize_guard.hpp>", private ] },
+ { include: ["<boost/asio/detail/limits.hpp>", private, "<boost/asio/detail/consuming_buffers.hpp>", private ] },
+ { include: ["<boost/asio/detail/limits.hpp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/limits.hpp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/limits.hpp>", private, "<boost/asio/detail/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/limits.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/limits.hpp>", private, "<boost/asio/detail/timer_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/limits.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/limits.hpp>", private, "<boost/asio/detail/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/limits.hpp>", private, "<boost/asio/impl/read_until.hpp>", private ] },
+ { include: ["<boost/asio/detail/macos_fenced_block.hpp>", private, "<boost/asio/detail/fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/mutex.hpp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/mutex.hpp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/mutex.hpp>", private, "<boost/asio/detail/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/mutex.hpp>", private, "<boost/asio/detail/resolver_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/mutex.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/mutex.hpp>", private, "<boost/asio/detail/service_registry.hpp>", private ] },
+ { include: ["<boost/asio/detail/mutex.hpp>", private, "<boost/asio/detail/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/mutex.hpp>", private, "<boost/asio/detail/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/mutex.hpp>", private, "<boost/asio/detail/win_iocp_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/mutex.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/mutex.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/mutex.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/mutex.hpp>", private, "<boost/asio/detail/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/call_stack.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/deadline_timer_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/handler_alloc_helpers.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/hash_map.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/keyword_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/local_free_on_block_exit.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/null_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/null_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/null_signal_blocker.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/null_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/null_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/object_pool.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/op_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/posix_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/posix_fd_set_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/posix_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/posix_signal_blocker.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/posix_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/posix_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/reactive_descriptor_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/reactor_op_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/resolver_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/scoped_lock.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/service_registry.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/socket_holder.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/std_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/std_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/std_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/std_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/thread_info_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/timer_queue_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/wince_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/win_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/win_fd_set_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/win_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/win_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/detail/win_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/impl/spawn.hpp>", private ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/ssl/detail/openssl_init.hpp>", private ] },
+ { include: ["<boost/asio/detail/null_event.hpp>", private, "<boost/asio/detail/event.hpp>", private ] },
+ { include: ["<boost/asio/detail/null_fenced_block.hpp>", private, "<boost/asio/detail/fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/null_mutex.hpp>", private, "<boost/asio/detail/mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/null_reactor.hpp>", private, "<boost/asio/detail/reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/null_signal_blocker.hpp>", private, "<boost/asio/detail/signal_blocker.hpp>", private ] },
+ { include: ["<boost/asio/detail/null_static_mutex.hpp>", private, "<boost/asio/detail/static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/null_thread.hpp>", private, "<boost/asio/detail/thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/null_tss_ptr.hpp>", private, "<boost/asio/detail/tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/object_pool.hpp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/object_pool.hpp>", private, "<boost/asio/detail/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/old_win_sdk_compat.hpp>", private, "<boost/asio/detail/socket_types.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/completion_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/reactor_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/resolve_endpoint_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/resolver_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/signal_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/timer_queue_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/wait_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/win_iocp_handle_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/win_iocp_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/win_iocp_handle_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/win_iocp_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/win_iocp_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/win_iocp_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/operation.hpp>", private, "<boost/asio/detail/winrt_async_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/op_queue.hpp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/op_queue.hpp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/op_queue.hpp>", private, "<boost/asio/detail/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/op_queue.hpp>", private, "<boost/asio/detail/reactor_op_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/op_queue.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/op_queue.hpp>", private, "<boost/asio/detail/signal_set_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/op_queue.hpp>", private, "<boost/asio/detail/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/op_queue.hpp>", private, "<boost/asio/detail/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/op_queue.hpp>", private, "<boost/asio/detail/task_io_service_operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/op_queue.hpp>", private, "<boost/asio/detail/task_io_service_thread_info.hpp>", private ] },
+ { include: ["<boost/asio/detail/op_queue.hpp>", private, "<boost/asio/detail/timer_queue_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/op_queue.hpp>", private, "<boost/asio/detail/timer_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/op_queue.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/op_queue.hpp>", private, "<boost/asio/detail/win_iocp_operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/op_queue.hpp>", private, "<boost/asio/detail/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/pipe_select_interrupter.hpp>", private, "<boost/asio/detail/select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/base_from_completion_cond.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/bind_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/buffered_stream_storage.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/buffer_resize_guard.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/buffer_sequence_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/call_stack.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/chrono_time_traits.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/completion_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/consuming_buffers.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/deadline_timer_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/dependent_type.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/descriptor_ops.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/descriptor_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/descriptor_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/eventfd_select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/gcc_arm_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/gcc_hppa_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/gcc_sync_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/gcc_x86_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/handler_alloc_helpers.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/handler_cont_helpers.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/handler_invoke_helpers.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/handler_tracking.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/hash_map.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/impl/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/impl/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/impl/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/impl/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/impl/service_registry.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/impl/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/impl/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/impl/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/impl/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/io_control.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/keyword_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/local_free_on_block_exit.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/macos_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/noncopyable.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/null_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/null_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/null_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/null_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/null_signal_blocker.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/null_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/null_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/null_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/null_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/object_pool.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/old_win_sdk_compat.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/op_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/pipe_select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/posix_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/posix_fd_set_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/posix_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/posix_signal_blocker.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/posix_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/posix_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/posix_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/reactive_descriptor_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/reactive_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/reactive_serial_port_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/reactive_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/reactive_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/reactive_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/reactive_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/reactive_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/reactive_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/reactive_socket_sendto_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/reactive_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/reactor_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/reactor_op_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/resolve_endpoint_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/resolver_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/scoped_lock.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/scoped_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/service_registry.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/signal_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/signal_init.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/signal_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/signal_set_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/socket_holder.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/socket_ops.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/socket_option.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/socket_select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/socket_types.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/solaris_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/std_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/std_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/std_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/std_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/task_io_service_operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/task_io_service_thread_info.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/thread_info_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/throw_error.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/timer_queue_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/timer_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/timer_queue_ptime.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/timer_queue_set.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/wait_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/wait_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/wince_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_fd_set_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_handle_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_handle_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_serial_port_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_iocp_thread_info.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_object_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/winrt_async_manager.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/winrt_async_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/winrt_resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/winrt_resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/winrt_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/winrt_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/winrt_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/winrt_ssocket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/winrt_ssocket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/winrt_utils.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/winsock_init.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/win_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/detail/wrapped_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/generic/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/impl/buffered_read_stream.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/impl/buffered_write_stream.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/impl/connect.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/impl/io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/impl/read_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/impl/read.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/impl/read_until.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/impl/serial_port_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/impl/spawn.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/impl/use_future.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/impl/write_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/impl/write.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/detail/socket_option.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/impl/address.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/impl/address_v4.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/impl/address_v6.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/impl/basic_endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/local/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/detail/buffered_handshake_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/detail/engine.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/detail/handshake_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/detail/io.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/detail/openssl_init.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/detail/password_callback.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/detail/read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/detail/shutdown_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/detail/stream_core.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/detail/verify_callback.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/detail/write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/impl/context.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/old/detail/openssl_context_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/old/detail/openssl_operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/old/detail/openssl_stream_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/posix_event.hpp>", private, "<boost/asio/detail/event.hpp>", private ] },
+ { include: ["<boost/asio/detail/posix_fd_set_adapter.hpp>", private, "<boost/asio/detail/fd_set_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/posix_mutex.hpp>", private, "<boost/asio/detail/mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/posix_signal_blocker.hpp>", private, "<boost/asio/detail/signal_blocker.hpp>", private ] },
+ { include: ["<boost/asio/detail/posix_static_mutex.hpp>", private, "<boost/asio/detail/static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/posix_thread.hpp>", private, "<boost/asio/detail/thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/posix_tss_ptr.hpp>", private, "<boost/asio/detail/tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/base_from_completion_cond.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/bind_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/buffered_stream_storage.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/buffer_resize_guard.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/buffer_sequence_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/call_stack.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/chrono_time_traits.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/completion_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/consuming_buffers.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/deadline_timer_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/dependent_type.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/descriptor_ops.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/descriptor_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/descriptor_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/eventfd_select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/gcc_arm_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/gcc_hppa_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/gcc_sync_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/gcc_x86_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/handler_alloc_helpers.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/handler_cont_helpers.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/handler_invoke_helpers.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/handler_tracking.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/hash_map.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/impl/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/impl/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/impl/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/impl/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/impl/service_registry.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/impl/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/impl/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/impl/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/impl/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/io_control.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/keyword_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/local_free_on_block_exit.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/macos_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/noncopyable.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/null_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/null_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/null_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/null_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/null_signal_blocker.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/null_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/null_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/null_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/null_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/object_pool.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/old_win_sdk_compat.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/op_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/pipe_select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/posix_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/posix_fd_set_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/posix_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/posix_signal_blocker.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/posix_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/posix_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/posix_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/reactive_descriptor_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/reactive_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/reactive_serial_port_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/reactive_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/reactive_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/reactive_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/reactive_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/reactive_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/reactive_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/reactive_socket_sendto_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/reactive_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/reactor_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/reactor_op_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/resolve_endpoint_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/resolver_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/scoped_lock.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/scoped_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/service_registry.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/signal_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/signal_init.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/signal_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/signal_set_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/socket_holder.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/socket_ops.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/socket_option.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/socket_select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/socket_types.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/solaris_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/std_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/std_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/std_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/std_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/task_io_service_operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/task_io_service_thread_info.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/thread_info_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/throw_error.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/timer_queue_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/timer_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/timer_queue_ptime.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/timer_queue_set.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/wait_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/wait_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/wince_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_fd_set_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_handle_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_handle_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_serial_port_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_iocp_thread_info.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_object_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/winrt_async_manager.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/winrt_async_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/winrt_resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/winrt_resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/winrt_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/winrt_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/winrt_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/winrt_ssocket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/winrt_ssocket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/winrt_utils.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/winsock_init.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/win_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/detail/wrapped_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/generic/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/impl/buffered_read_stream.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/impl/buffered_write_stream.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/impl/connect.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/impl/io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/impl/read_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/impl/read.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/impl/read_until.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/impl/serial_port_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/impl/spawn.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/impl/use_future.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/impl/write_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/impl/write.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/detail/socket_option.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/impl/address.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/impl/address_v4.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/impl/address_v6.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/impl/basic_endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/local/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/detail/buffered_handshake_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/detail/engine.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/detail/handshake_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/detail/io.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/detail/openssl_init.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/detail/password_callback.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/detail/read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/detail/shutdown_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/detail/stream_core.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/detail/verify_callback.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/detail/write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/impl/context.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/old/detail/openssl_context_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/old/detail/openssl_operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/old/detail/openssl_stream_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactive_descriptor_service.hpp>", private, "<boost/asio/detail/reactive_serial_port_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactive_null_buffers_op.hpp>", private, "<boost/asio/detail/reactive_descriptor_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactive_null_buffers_op.hpp>", private, "<boost/asio/detail/reactive_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactive_null_buffers_op.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactive_socket_accept_op.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactive_socket_connect_op.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactive_socket_connect_op.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactive_socket_recvfrom_op.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactive_socket_recvmsg_op.hpp>", private, "<boost/asio/detail/reactive_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactive_socket_recv_op.hpp>", private, "<boost/asio/detail/reactive_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactive_socket_send_op.hpp>", private, "<boost/asio/detail/reactive_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactive_socket_sendto_op.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactive_socket_service_base.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_fwd.hpp>", private, "<boost/asio/detail/reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_fwd.hpp>", private, "<boost/asio/detail/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor.hpp>", private, "<boost/asio/detail/reactive_descriptor_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor.hpp>", private, "<boost/asio/detail/reactive_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor.hpp>", private, "<boost/asio/detail/signal_set_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/descriptor_read_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/descriptor_write_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/reactive_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/reactive_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/reactive_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/reactive_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/reactive_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/reactive_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/reactive_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/reactive_socket_sendto_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/reactive_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/reactor_op_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/win_iocp_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op_queue.hpp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/reactor_op_queue.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/resolve_endpoint_op.hpp>", private, "<boost/asio/detail/resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/resolve_op.hpp>", private, "<boost/asio/detail/resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/resolver_service_base.hpp>", private, "<boost/asio/detail/resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/scoped_lock.hpp>", private, "<boost/asio/detail/null_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/scoped_lock.hpp>", private, "<boost/asio/detail/null_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/scoped_lock.hpp>", private, "<boost/asio/detail/posix_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/scoped_lock.hpp>", private, "<boost/asio/detail/posix_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/scoped_lock.hpp>", private, "<boost/asio/detail/std_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/scoped_lock.hpp>", private, "<boost/asio/detail/std_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/scoped_lock.hpp>", private, "<boost/asio/detail/win_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/scoped_lock.hpp>", private, "<boost/asio/detail/win_static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/scoped_ptr.hpp>", private, "<boost/asio/detail/resolver_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/scoped_ptr.hpp>", private, "<boost/asio/detail/strand_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/scoped_ptr.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/select_interrupter.hpp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/select_interrupter.hpp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/select_interrupter.hpp>", private, "<boost/asio/detail/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/select_interrupter.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/select_reactor.hpp>", private, "<boost/asio/detail/reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/select_reactor.hpp>", private, "<boost/asio/detail/timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/service_registry.hpp>", private, "<boost/asio/impl/io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/shared_ptr.hpp>", private, "<boost/asio/detail/socket_ops.hpp>", private ] },
+ { include: ["<boost/asio/detail/shared_ptr.hpp>", private, "<boost/asio/impl/spawn.hpp>", private ] },
+ { include: ["<boost/asio/detail/shared_ptr.hpp>", private, "<boost/asio/ssl/detail/openssl_init.hpp>", private ] },
+ { include: ["<boost/asio/detail/signal_handler.hpp>", private, "<boost/asio/detail/signal_set_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/signal_op.hpp>", private, "<boost/asio/detail/signal_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/signal_op.hpp>", private, "<boost/asio/detail/signal_set_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_holder.hpp>", private, "<boost/asio/detail/reactive_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_holder.hpp>", private, "<boost/asio/detail/reactive_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_holder.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_holder.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_holder.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/deadline_timer_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/reactive_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/reactive_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/reactive_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/reactive_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/reactive_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/reactive_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/reactive_socket_sendto_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/reactive_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/resolve_endpoint_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/resolver_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/socket_holder.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/win_iocp_null_buffers_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/win_iocp_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvfrom_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/win_iocp_socket_recvmsg_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/win_iocp_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/win_iocp_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/winrt_resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/detail/winrt_utils.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/ip/detail/socket_option.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/ssl/old/detail/openssl_operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_select_interrupter.hpp>", private, "<boost/asio/detail/select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/buffer_sequence_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/deadline_timer_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/descriptor_ops.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/hash_map.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/io_control.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/local_free_on_block_exit.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/posix_fd_set_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/reactive_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/reactive_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/resolver_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/signal_set_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/socket_ops.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/socket_option.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/socket_select_interrupter.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/wince_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/win_event.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/win_fd_set_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/win_fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/win_iocp_operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/win_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/winrt_ssocket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/win_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/detail/win_tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/generic/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/ip/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/ip/detail/socket_option.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/local/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/ssl/detail/openssl_types.hpp>", private ] },
+ { include: ["<boost/asio/detail/solaris_fenced_block.hpp>", private, "<boost/asio/detail/fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/static_mutex.hpp>", private, "<boost/asio/detail/handler_tracking.hpp>", private ] },
+ { include: ["<boost/asio/detail/static_mutex.hpp>", private, "<boost/asio/ssl/detail/engine.hpp>", private ] },
+ { include: ["<boost/asio/detail/std_event.hpp>", private, "<boost/asio/detail/event.hpp>", private ] },
+ { include: ["<boost/asio/detail/std_mutex.hpp>", private, "<boost/asio/detail/mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/std_static_mutex.hpp>", private, "<boost/asio/detail/static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/std_thread.hpp>", private, "<boost/asio/detail/thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/task_io_service.hpp>", private, "<boost/asio/impl/io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/task_io_service_operation.hpp>", private, "<boost/asio/detail/operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/task_io_service_operation.hpp>", private, "<boost/asio/detail/task_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/thread.hpp>", private, "<boost/asio/detail/resolver_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/thread.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/thread.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/thread.hpp>", private, "<boost/asio/detail/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/thread_info_base.hpp>", private, "<boost/asio/detail/task_io_service_thread_info.hpp>", private ] },
+ { include: ["<boost/asio/detail/thread_info_base.hpp>", private, "<boost/asio/detail/win_iocp_thread_info.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/detail/null_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/detail/wince_thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/impl/connect.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/impl/read_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/impl/read.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/impl/read_until.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/impl/write_at.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/impl/write.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/ip/impl/address.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/ip/impl/address_v4.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/ip/impl/address_v6.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/ip/impl/basic_endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/ssl/impl/context.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/ssl/old/detail/openssl_context_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_exception.hpp>", private, "<boost/asio/detail/socket_option.hpp>", private ] },
+ { include: ["<boost/asio/detail/throw_exception.hpp>", private, "<boost/asio/ip/detail/socket_option.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue_base.hpp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue_base.hpp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue_base.hpp>", private, "<boost/asio/detail/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue_base.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue_base.hpp>", private, "<boost/asio/detail/timer_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue_base.hpp>", private, "<boost/asio/detail/timer_queue_set.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue_base.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue_base.hpp>", private, "<boost/asio/detail/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue.hpp>", private, "<boost/asio/detail/deadline_timer_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue.hpp>", private, "<boost/asio/detail/timer_queue_ptime.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue_set.hpp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue_set.hpp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue_set.hpp>", private, "<boost/asio/detail/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue_set.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue_set.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_queue_set.hpp>", private, "<boost/asio/detail/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_scheduler_fwd.hpp>", private, "<boost/asio/detail/timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/timer_scheduler.hpp>", private, "<boost/asio/detail/deadline_timer_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/tss_ptr.hpp>", private, "<boost/asio/detail/call_stack.hpp>", private ] },
+ { include: ["<boost/asio/detail/tss_ptr.hpp>", private, "<boost/asio/detail/handler_tracking.hpp>", private ] },
+ { include: ["<boost/asio/detail/wait_handler.hpp>", private, "<boost/asio/detail/deadline_timer_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/wait_handler.hpp>", private, "<boost/asio/detail/win_object_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/wait_op.hpp>", private, "<boost/asio/detail/deadline_timer_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/wait_op.hpp>", private, "<boost/asio/detail/dev_poll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/wait_op.hpp>", private, "<boost/asio/detail/epoll_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/wait_op.hpp>", private, "<boost/asio/detail/kqueue_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/wait_op.hpp>", private, "<boost/asio/detail/select_reactor.hpp>", private ] },
+ { include: ["<boost/asio/detail/wait_op.hpp>", private, "<boost/asio/detail/timer_queue.hpp>", private ] },
+ { include: ["<boost/asio/detail/wait_op.hpp>", private, "<boost/asio/detail/wait_handler.hpp>", private ] },
+ { include: ["<boost/asio/detail/wait_op.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/wait_op.hpp>", private, "<boost/asio/detail/winrt_timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/weak_ptr.hpp>", private, "<boost/asio/detail/socket_ops.hpp>", private ] },
+ { include: ["<boost/asio/detail/wince_thread.hpp>", private, "<boost/asio/detail/thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_event.hpp>", private, "<boost/asio/detail/event.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_fd_set_adapter.hpp>", private, "<boost/asio/detail/fd_set_adapter.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_fenced_block.hpp>", private, "<boost/asio/detail/fenced_block.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_handle_read_op.hpp>", private, "<boost/asio/detail/win_iocp_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_handle_service.hpp>", private, "<boost/asio/detail/win_iocp_serial_port_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_handle_write_op.hpp>", private, "<boost/asio/detail/win_iocp_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_io_service.hpp>", private, "<boost/asio/detail/timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_io_service.hpp>", private, "<boost/asio/detail/win_iocp_handle_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_io_service.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_io_service.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_io_service.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_io_service.hpp>", private, "<boost/asio/impl/io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_null_buffers_op.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_null_buffers_op.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_operation.hpp>", private, "<boost/asio/detail/operation.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_operation.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_overlapped_op.hpp>", private, "<boost/asio/detail/win_iocp_overlapped_ptr.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_socket_accept_op.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_socket_recvfrom_op.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_socket_recvmsg_op.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_socket_recv_op.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_socket_send_op.hpp>", private, "<boost/asio/detail/win_iocp_socket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_socket_send_op.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_socket_service_base.hpp>", private, "<boost/asio/detail/win_iocp_socket_accept_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_socket_service_base.hpp>", private, "<boost/asio/detail/win_iocp_socket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_iocp_thread_info.hpp>", private, "<boost/asio/detail/win_iocp_io_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_mutex.hpp>", private, "<boost/asio/detail/mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/winrt_async_manager.hpp>", private, "<boost/asio/detail/winrt_resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/winrt_async_manager.hpp>", private, "<boost/asio/detail/winrt_ssocket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/winrt_async_op.hpp>", private, "<boost/asio/detail/winrt_async_manager.hpp>", private ] },
+ { include: ["<boost/asio/detail/winrt_async_op.hpp>", private, "<boost/asio/detail/winrt_resolve_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/winrt_async_op.hpp>", private, "<boost/asio/detail/winrt_socket_connect_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/winrt_async_op.hpp>", private, "<boost/asio/detail/winrt_socket_recv_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/winrt_async_op.hpp>", private, "<boost/asio/detail/winrt_socket_send_op.hpp>", private ] },
+ { include: ["<boost/asio/detail/winrt_resolve_op.hpp>", private, "<boost/asio/detail/winrt_resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/winrt_socket_connect_op.hpp>", private, "<boost/asio/detail/winrt_ssocket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/winrt_socket_recv_op.hpp>", private, "<boost/asio/detail/winrt_ssocket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/winrt_socket_send_op.hpp>", private, "<boost/asio/detail/winrt_ssocket_service_base.hpp>", private ] },
+ { include: ["<boost/asio/detail/winrt_ssocket_service_base.hpp>", private, "<boost/asio/detail/winrt_ssocket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/winrt_timer_scheduler.hpp>", private, "<boost/asio/detail/timer_scheduler.hpp>", private ] },
+ { include: ["<boost/asio/detail/winrt_utils.hpp>", private, "<boost/asio/detail/winrt_resolver_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/winrt_utils.hpp>", private, "<boost/asio/detail/winrt_ssocket_service.hpp>", private ] },
+ { include: ["<boost/asio/detail/winsock_init.hpp>", private, "<boost/asio/ip/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_static_mutex.hpp>", private, "<boost/asio/detail/static_mutex.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_thread.hpp>", private, "<boost/asio/detail/thread.hpp>", private ] },
+ { include: ["<boost/asio/detail/win_tss_ptr.hpp>", private, "<boost/asio/detail/tss_ptr.hpp>", private ] },
+ { include: ["<boost/asio/generic/detail/impl/endpoint.ipp>", private, "<boost/asio/generic/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/generic/detail/impl/endpoint.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/impl/error.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/impl/handler_alloc_hook.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/impl/io_service.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/impl/serial_port_base.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/ip/detail/impl/endpoint.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/ip/detail/impl/endpoint.ipp>", private, "<boost/asio/ip/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/ip/impl/address.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/ip/impl/address_v4.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/ip/impl/address_v6.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/ip/impl/host_name.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/local/detail/impl/endpoint.ipp>", private, "<boost/asio/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/local/detail/impl/endpoint.ipp>", private, "<boost/asio/local/detail/endpoint.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/engine.hpp>", private, "<boost/asio/ssl/detail/buffered_handshake_op.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/engine.hpp>", private, "<boost/asio/ssl/detail/handshake_op.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/engine.hpp>", private, "<boost/asio/ssl/detail/io.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/engine.hpp>", private, "<boost/asio/ssl/detail/read_op.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/engine.hpp>", private, "<boost/asio/ssl/detail/shutdown_op.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/engine.hpp>", private, "<boost/asio/ssl/detail/stream_core.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/engine.hpp>", private, "<boost/asio/ssl/detail/write_op.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/impl/engine.ipp>", private, "<boost/asio/ssl/detail/engine.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/impl/engine.ipp>", private, "<boost/asio/ssl/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/impl/openssl_init.ipp>", private, "<boost/asio/ssl/detail/openssl_init.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/impl/openssl_init.ipp>", private, "<boost/asio/ssl/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/openssl_init.hpp>", private, "<boost/asio/ssl/old/detail/openssl_context_service.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/openssl_types.hpp>", private, "<boost/asio/ssl/detail/engine.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/openssl_types.hpp>", private, "<boost/asio/ssl/detail/openssl_init.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/openssl_types.hpp>", private, "<boost/asio/ssl/old/detail/openssl_context_service.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/openssl_types.hpp>", private, "<boost/asio/ssl/old/detail/openssl_operation.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/openssl_types.hpp>", private, "<boost/asio/ssl/old/detail/openssl_stream_service.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/stream_core.hpp>", private, "<boost/asio/ssl/detail/io.hpp>", private ] },
+ { include: ["<boost/asio/ssl/detail/verify_callback.hpp>", private, "<boost/asio/ssl/detail/engine.hpp>", private ] },
+ { include: ["<boost/asio/ssl/impl/context.ipp>", private, "<boost/asio/ssl/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/ssl/impl/error.ipp>", private, "<boost/asio/ssl/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/ssl/impl/rfc2818_verification.ipp>", private, "<boost/asio/ssl/impl/src.hpp>", private ] },
+ { include: ["<boost/asio/ssl/old/detail/openssl_operation.hpp>", private, "<boost/asio/ssl/old/detail/openssl_stream_service.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/cas128strong.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/cas32strong.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/cas32weak.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/cas64strong.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/cas64strong-ptr.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/gcc-alpha.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/gcc-armv6plus.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/gcc-atomic.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/gcc-cas.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/gcc-ppc.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/gcc-sparcv9.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/gcc-x86.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/generic-cas.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/linux-arm.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/platform.hpp>", private ] },
+ { include: ["<boost/atomic/detail/base.hpp>", private, "<boost/atomic/detail/windows.hpp>", private ] },
+ { include: ["<boost/atomic/detail/builder.hpp>", private, "<boost/atomic/detail/gcc-alpha.hpp>", private ] },
+ { include: ["<boost/atomic/detail/builder.hpp>", private, "<boost/atomic/detail/generic-cas.hpp>", private ] },
+ { include: ["<boost/atomic/detail/cas128strong.hpp>", private, "<boost/atomic/detail/gcc-x86.hpp>", private ] },
+ { include: ["<boost/atomic/detail/cas32strong.hpp>", private, "<boost/atomic/detail/gcc-cas.hpp>", private ] },
+ { include: ["<boost/atomic/detail/cas32weak.hpp>", private, "<boost/atomic/detail/gcc-armv6plus.hpp>", private ] },
+ { include: ["<boost/atomic/detail/cas32weak.hpp>", private, "<boost/atomic/detail/linux-arm.hpp>", private ] },
+ { include: ["<boost/atomic/detail/cas64strong.hpp>", private, "<boost/atomic/detail/gcc-x86.hpp>", private ] },
+ { include: ["<boost/atomic/detail/cas64strong.hpp>", private, "<boost/atomic/detail/windows.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/base.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/cas128strong.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/cas32strong.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/cas32weak.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/cas64strong.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/cas64strong-ptr.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/gcc-alpha.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/gcc-armv6plus.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/gcc-atomic.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/gcc-cas.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/gcc-ppc.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/gcc-sparcv9.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/gcc-x86.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/generic-cas.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/interlocked.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/link.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/linux-arm.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/lockpool.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/platform.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/type-classification.hpp>", private ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/detail/windows.hpp>", private ] },
+ { include: ["<boost/atomic/detail/gcc-alpha.hpp>", private, "<boost/atomic/detail/platform.hpp>", private ] },
+ { include: ["<boost/atomic/detail/gcc-armv6plus.hpp>", private, "<boost/atomic/detail/platform.hpp>", private ] },
+ { include: ["<boost/atomic/detail/gcc-atomic.hpp>", private, "<boost/atomic/detail/platform.hpp>", private ] },
+ { include: ["<boost/atomic/detail/gcc-cas.hpp>", private, "<boost/atomic/detail/platform.hpp>", private ] },
+ { include: ["<boost/atomic/detail/gcc-ppc.hpp>", private, "<boost/atomic/detail/platform.hpp>", private ] },
+ { include: ["<boost/atomic/detail/gcc-sparcv9.hpp>", private, "<boost/atomic/detail/platform.hpp>", private ] },
+ { include: ["<boost/atomic/detail/gcc-x86.hpp>", private, "<boost/atomic/detail/platform.hpp>", private ] },
+ { include: ["<boost/atomic/detail/interlocked.hpp>", private, "<boost/atomic/detail/windows.hpp>", private ] },
+ { include: ["<boost/atomic/detail/link.hpp>", private, "<boost/atomic/detail/lockpool.hpp>", private ] },
+ { include: ["<boost/atomic/detail/linux-arm.hpp>", private, "<boost/atomic/detail/platform.hpp>", private ] },
+ { include: ["<boost/atomic/detail/lockpool.hpp>", private, "<boost/atomic/detail/base.hpp>", private ] },
+ { include: ["<boost/atomic/detail/windows.hpp>", private, "<boost/atomic/detail/platform.hpp>", private ] },
+ { include: ["<boost/bimap/detail/concept_tags.hpp>", private, "<boost/bimap/detail/is_set_type_of.hpp>", private ] },
+ { include: ["<boost/bimap/detail/debug/static_error.hpp>", private, "<boost/bimap/detail/map_view_base.hpp>", private ] },
+ { include: ["<boost/bimap/detail/debug/static_error.hpp>", private, "<boost/bimap/relation/detail/metadata_access_builder.hpp>", private ] },
+ { include: ["<boost/bimap/detail/debug/static_error.hpp>", private, "<boost/bimap/relation/detail/mutant.hpp>", private ] },
+ { include: ["<boost/bimap/detail/debug/static_error.hpp>", private, "<boost/bimap/relation/detail/static_access_builder.hpp>", private ] },
+ { include: ["<boost/bimap/detail/is_set_type_of.hpp>", private, "<boost/bimap/detail/manage_additional_parameters.hpp>", private ] },
+ { include: ["<boost/bimap/detail/is_set_type_of.hpp>", private, "<boost/bimap/detail/manage_bimap_key.hpp>", private ] },
+ { include: ["<boost/bimap/detail/manage_additional_parameters.hpp>", private, "<boost/bimap/detail/bimap_core.hpp>", private ] },
+ { include: ["<boost/bimap/detail/manage_bimap_key.hpp>", private, "<boost/bimap/detail/bimap_core.hpp>", private ] },
+ { include: ["<boost/bimap/detail/map_view_iterator.hpp>", private, "<boost/bimap/detail/bimap_core.hpp>", private ] },
+ { include: ["<boost/bimap/detail/map_view_iterator.hpp>", private, "<boost/bimap/detail/map_view_base.hpp>", private ] },
+ { include: ["<boost/bimap/detail/modifier_adaptor.hpp>", private, "<boost/bimap/detail/map_view_base.hpp>", private ] },
+ { include: ["<boost/bimap/detail/modifier_adaptor.hpp>", private, "<boost/bimap/detail/set_view_base.hpp>", private ] },
+ { include: ["<boost/bimap/detail/set_view_iterator.hpp>", private, "<boost/bimap/detail/bimap_core.hpp>", private ] },
+ { include: ["<boost/bimap/detail/set_view_iterator.hpp>", private, "<boost/bimap/detail/set_view_base.hpp>", private ] },
+ { include: ["<boost/bimap/relation/detail/metadata_access_builder.hpp>", private, "<boost/bimap/detail/map_view_iterator.hpp>", private ] },
+ { include: ["<boost/bimap/relation/detail/mutant.hpp>", private, "<boost/bimap/relation/detail/to_mutable_relation_functor.hpp>", private ] },
+ { include: ["<boost/bimap/relation/detail/static_access_builder.hpp>", private, "<boost/bimap/detail/map_view_iterator.hpp>", private ] },
+ { include: ["<boost/bimap/relation/detail/to_mutable_relation_functor.hpp>", private, "<boost/bimap/detail/map_view_base.hpp>", private ] },
+ { include: ["<boost/bimap/relation/detail/to_mutable_relation_functor.hpp>", private, "<boost/bimap/detail/set_view_base.hpp>", private ] },
+ { include: ["<boost/chrono/detail/inlined/mac/chrono.hpp>", private, "<boost/chrono/detail/inlined/chrono.hpp>", private ] },
+ { include: ["<boost/chrono/detail/inlined/mac/process_cpu_clocks.hpp>", private, "<boost/chrono/detail/inlined/process_cpu_clocks.hpp>", private ] },
+ { include: ["<boost/chrono/detail/inlined/mac/thread_clock.hpp>", private, "<boost/chrono/detail/inlined/thread_clock.hpp>", private ] },
+ { include: ["<boost/chrono/detail/inlined/posix/chrono.hpp>", private, "<boost/chrono/detail/inlined/chrono.hpp>", private ] },
+ { include: ["<boost/chrono/detail/inlined/posix/process_cpu_clocks.hpp>", private, "<boost/chrono/detail/inlined/process_cpu_clocks.hpp>", private ] },
+ { include: ["<boost/chrono/detail/inlined/posix/thread_clock.hpp>", private, "<boost/chrono/detail/inlined/mac/thread_clock.hpp>", private ] },
+ { include: ["<boost/chrono/detail/inlined/posix/thread_clock.hpp>", private, "<boost/chrono/detail/inlined/thread_clock.hpp>", private ] },
+ { include: ["<boost/chrono/detail/inlined/win/chrono.hpp>", private, "<boost/chrono/detail/inlined/chrono.hpp>", private ] },
+ { include: ["<boost/chrono/detail/inlined/win/process_cpu_clocks.hpp>", private, "<boost/chrono/detail/inlined/process_cpu_clocks.hpp>", private ] },
+ { include: ["<boost/chrono/detail/inlined/win/thread_clock.hpp>", private, "<boost/chrono/detail/inlined/thread_clock.hpp>", private ] },
+ { include: ["<boost/chrono/detail/system.hpp>", private, "<boost/chrono/detail/inlined/chrono.hpp>", private ] },
+ { include: ["<boost/chrono/detail/system.hpp>", private, "<boost/chrono/detail/inlined/thread_clock.hpp>", private ] },
+ { include: ["<boost/concept/detail/backward_compatibility.hpp>", private, "<boost/concept/detail/borland.hpp>", private ] },
+ { include: ["<boost/concept/detail/backward_compatibility.hpp>", private, "<boost/concept/detail/general.hpp>", private ] },
+ { include: ["<boost/concept/detail/backward_compatibility.hpp>", private, "<boost/concept/detail/has_constraints.hpp>", private ] },
+ { include: ["<boost/concept/detail/backward_compatibility.hpp>", private, "<boost/concept/detail/msvc.hpp>", private ] },
+ { include: ["<boost/concept/detail/concept_def.hpp>", private, "<boost/icl/detail/concept_check.hpp>", private ] },
+ { include: ["<boost/concept/detail/has_constraints.hpp>", private, "<boost/concept/detail/general.hpp>", private ] },
+ { include: ["<boost/concept/detail/has_constraints.hpp>", private, "<boost/concept/detail/msvc.hpp>", private ] },
+ { include: ["<boost/container/detail/adaptive_node_pool_impl.hpp>", private, "<boost/interprocess/allocators/detail/adaptive_node_pool.hpp>", private ] },
+ { include: ["<boost/container/detail/algorithms.hpp>", private, "<boost/container/detail/node_alloc_holder.hpp>", private ] },
+ { include: ["<boost/container/detail/algorithms.hpp>", private, "<boost/container/detail/tree.hpp>", private ] },
+ { include: ["<boost/container/detail/allocation_type.hpp>", private, "<boost/container/detail/allocator_version_traits.hpp>", private ] },
+ { include: ["<boost/container/detail/allocator_version_traits.hpp>", private, "<boost/container/detail/node_alloc_holder.hpp>", private ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/detail/allocator_version_traits.hpp>", private ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/detail/function_detector.hpp>", private ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/detail/memory_util.hpp>", private ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/detail/preprocessor.hpp>", private ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/detail/workaround.hpp>", private ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/geometry/index/detail/varray.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/adaptive_node_pool_impl.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/advanced_insert_int.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/algorithms.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/allocation_type.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/allocator_version_traits.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/destroyers.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/flat_tree.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/function_detector.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/iterators.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/math_functions.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/memory_util.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/multiallocation_chain.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/node_alloc_holder.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/node_pool_impl.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/pair.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/pool_common.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/preprocessor.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/transform_iterator.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/tree.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/type_traits.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/utilities.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/value_init.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/variadic_templates_tools.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/detail/workaround.hpp>", private ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/geometry/index/detail/varray.hpp>", private ] },
+ { include: ["<boost/container/detail/destroyers.hpp>", private, "<boost/container/detail/advanced_insert_int.hpp>", private ] },
+ { include: ["<boost/container/detail/destroyers.hpp>", private, "<boost/container/detail/flat_tree.hpp>", private ] },
+ { include: ["<boost/container/detail/destroyers.hpp>", private, "<boost/container/detail/node_alloc_holder.hpp>", private ] },
+ { include: ["<boost/container/detail/destroyers.hpp>", private, "<boost/container/detail/tree.hpp>", private ] },
+ { include: ["<boost/container/detail/iterators.hpp>", private, "<boost/container/detail/algorithms.hpp>", private ] },
+ { include: ["<boost/container/detail/iterators.hpp>", private, "<boost/container/detail/tree.hpp>", private ] },
+ { include: ["<boost/container/detail/math_functions.hpp>", private, "<boost/container/detail/adaptive_node_pool_impl.hpp>", private ] },
+ { include: ["<boost/container/detail/math_functions.hpp>", private, "<boost/container/detail/node_pool_impl.hpp>", private ] },
+ { include: ["<boost/container/detail/memory_util.hpp>", private, "<boost/container/detail/utilities.hpp>", private ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/detail/adaptive_node_pool_impl.hpp>", private ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/detail/algorithms.hpp>", private ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/detail/allocator_version_traits.hpp>", private ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/detail/node_alloc_holder.hpp>", private ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/detail/node_pool_impl.hpp>", private ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/detail/pair.hpp>", private ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/detail/utilities.hpp>", private ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/detail/version_type.hpp>", private ] },
+ { include: ["<boost/container/detail/multiallocation_chain.hpp>", private, "<boost/container/detail/allocator_version_traits.hpp>", private ] },
+ { include: ["<boost/container/detail/multiallocation_chain.hpp>", private, "<boost/interprocess/allocators/detail/allocator_common.hpp>", private ] },
+ { include: ["<boost/container/detail/multiallocation_chain.hpp>", private, "<boost/interprocess/mem_algo/detail/mem_algo_common.hpp>", private ] },
+ { include: ["<boost/container/detail/multiallocation_chain.hpp>", private, "<boost/interprocess/mem_algo/detail/simple_seq_fit_impl.hpp>", private ] },
+ { include: ["<boost/container/detail/node_alloc_holder.hpp>", private, "<boost/container/detail/tree.hpp>", private ] },
+ { include: ["<boost/container/detail/node_pool_impl.hpp>", private, "<boost/interprocess/allocators/detail/node_pool.hpp>", private ] },
+ { include: ["<boost/container/detail/pair.hpp>", private, "<boost/container/detail/flat_tree.hpp>", private ] },
+ { include: ["<boost/container/detail/pair.hpp>", private, "<boost/container/detail/tree.hpp>", private ] },
+ { include: ["<boost/container/detail/pool_common.hpp>", private, "<boost/container/detail/adaptive_node_pool_impl.hpp>", private ] },
+ { include: ["<boost/container/detail/pool_common.hpp>", private, "<boost/container/detail/node_pool_impl.hpp>", private ] },
+ { include: ["<boost/container/detail/preprocessor.hpp>", private, "<boost/container/detail/advanced_insert_int.hpp>", private ] },
+ { include: ["<boost/container/detail/preprocessor.hpp>", private, "<boost/container/detail/iterators.hpp>", private ] },
+ { include: ["<boost/container/detail/preprocessor.hpp>", private, "<boost/container/detail/memory_util.hpp>", private ] },
+ { include: ["<boost/container/detail/preprocessor.hpp>", private, "<boost/container/detail/node_alloc_holder.hpp>", private ] },
+ { include: ["<boost/container/detail/preprocessor.hpp>", private, "<boost/container/detail/pair.hpp>", private ] },
+ { include: ["<boost/container/detail/preprocessor.hpp>", private, "<boost/container/detail/tree.hpp>", private ] },
+ { include: ["<boost/container/detail/preprocessor.hpp>", private, "<boost/geometry/index/detail/varray.hpp>", private ] },
+ { include: ["<boost/container/detail/transform_iterator.hpp>", private, "<boost/container/detail/multiallocation_chain.hpp>", private ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/detail/adaptive_node_pool_impl.hpp>", private ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/detail/algorithms.hpp>", private ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/detail/iterators.hpp>", private ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/detail/multiallocation_chain.hpp>", private ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/detail/node_alloc_holder.hpp>", private ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/detail/node_pool_impl.hpp>", private ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/detail/pair.hpp>", private ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/detail/transform_iterator.hpp>", private ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/detail/tree.hpp>", private ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/detail/utilities.hpp>", private ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/detail/variadic_templates_tools.hpp>", private ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/detail/version_type.hpp>", private ] },
+ { include: ["<boost/container/detail/utilities.hpp>", private, "<boost/container/detail/adaptive_node_pool_impl.hpp>", private ] },
+ { include: ["<boost/container/detail/utilities.hpp>", private, "<boost/container/detail/destroyers.hpp>", private ] },
+ { include: ["<boost/container/detail/utilities.hpp>", private, "<boost/container/detail/flat_tree.hpp>", private ] },
+ { include: ["<boost/container/detail/utilities.hpp>", private, "<boost/container/detail/multiallocation_chain.hpp>", private ] },
+ { include: ["<boost/container/detail/utilities.hpp>", private, "<boost/container/detail/node_alloc_holder.hpp>", private ] },
+ { include: ["<boost/container/detail/utilities.hpp>", private, "<boost/container/detail/node_pool_impl.hpp>", private ] },
+ { include: ["<boost/container/detail/utilities.hpp>", private, "<boost/container/detail/tree.hpp>", private ] },
+ { include: ["<boost/container/detail/value_init.hpp>", private, "<boost/container/detail/advanced_insert_int.hpp>", private ] },
+ { include: ["<boost/container/detail/value_init.hpp>", private, "<boost/container/detail/flat_tree.hpp>", private ] },
+ { include: ["<boost/container/detail/variadic_templates_tools.hpp>", private, "<boost/container/detail/advanced_insert_int.hpp>", private ] },
+ { include: ["<boost/container/detail/variadic_templates_tools.hpp>", private, "<boost/container/detail/iterators.hpp>", private ] },
+ { include: ["<boost/container/detail/version_type.hpp>", private, "<boost/container/detail/allocator_version_traits.hpp>", private ] },
+ { include: ["<boost/container/detail/version_type.hpp>", private, "<boost/container/detail/destroyers.hpp>", private ] },
+ { include: ["<boost/container/detail/version_type.hpp>", private, "<boost/container/detail/node_alloc_holder.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/adaptive_node_pool_impl.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/advanced_insert_int.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/algorithms.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/allocation_type.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/allocator_version_traits.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/destroyers.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/flat_tree.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/iterators.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/memory_util.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/node_alloc_holder.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/node_pool_impl.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/pair.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/preprocessor.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/transform_iterator.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/tree.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/value_init.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/container/detail/variadic_templates_tools.hpp>", private ] },
+ { include: ["<boost/container/detail/workaround.hpp>", private, "<boost/geometry/index/detail/varray.hpp>", private ] },
+ { include: ["<boost/context/detail/config.hpp>", private, "<boost/context/detail/fcontext_arm.hpp>", private ] },
+ { include: ["<boost/context/detail/config.hpp>", private, "<boost/context/detail/fcontext_i386.hpp>", private ] },
+ { include: ["<boost/context/detail/config.hpp>", private, "<boost/context/detail/fcontext_i386_win.hpp>", private ] },
+ { include: ["<boost/context/detail/config.hpp>", private, "<boost/context/detail/fcontext_mips.hpp>", private ] },
+ { include: ["<boost/context/detail/config.hpp>", private, "<boost/context/detail/fcontext_ppc.hpp>", private ] },
+ { include: ["<boost/context/detail/config.hpp>", private, "<boost/context/detail/fcontext_sparc.hpp>", private ] },
+ { include: ["<boost/context/detail/config.hpp>", private, "<boost/context/detail/fcontext_x86_64.hpp>", private ] },
+ { include: ["<boost/context/detail/config.hpp>", private, "<boost/context/detail/fcontext_x86_64_win.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/detail/coroutine_context.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/detail/segmented_stack_allocator.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/detail/stack_tuple.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/detail/standard_stack_allocator.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/v1/detail/arg.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/v1/detail/coroutine_base.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/v1/detail/coroutine_base_resume.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/v1/detail/coroutine_caller.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/v1/detail/coroutine_get.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/v1/detail/coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/v1/detail/coroutine_op.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/v2/detail/pull_coroutine_base.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/v2/detail/pull_coroutine_caller.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/v2/detail/pull_coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/v2/detail/push_coroutine_base.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/v2/detail/push_coroutine_caller.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/v2/detail/push_coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/coroutine_context.hpp>", private, "<boost/coroutine/detail/holder.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/coroutine_context.hpp>", private, "<boost/coroutine/v1/detail/coroutine_base.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/coroutine_context.hpp>", private, "<boost/coroutine/v1/detail/coroutine_base_resume.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/coroutine_context.hpp>", private, "<boost/coroutine/v2/detail/pull_coroutine_base.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/coroutine_context.hpp>", private, "<boost/coroutine/v2/detail/push_coroutine_base.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/flags.hpp>", private, "<boost/coroutine/v1/detail/coroutine_base.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/flags.hpp>", private, "<boost/coroutine/v1/detail/coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/flags.hpp>", private, "<boost/coroutine/v2/detail/pull_coroutine_base.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/flags.hpp>", private, "<boost/coroutine/v2/detail/pull_coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/flags.hpp>", private, "<boost/coroutine/v2/detail/push_coroutine_base.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/flags.hpp>", private, "<boost/coroutine/v2/detail/push_coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/holder.hpp>", private, "<boost/coroutine/v1/detail/coroutine_base_resume.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/holder.hpp>", private, "<boost/coroutine/v1/detail/coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/holder.hpp>", private, "<boost/coroutine/v2/detail/pull_coroutine_base.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/holder.hpp>", private, "<boost/coroutine/v2/detail/pull_coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/holder.hpp>", private, "<boost/coroutine/v2/detail/push_coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/param.hpp>", private, "<boost/coroutine/v1/detail/coroutine_get.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/param.hpp>", private, "<boost/coroutine/v1/detail/coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/param.hpp>", private, "<boost/coroutine/v2/detail/pull_coroutine_base.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/param.hpp>", private, "<boost/coroutine/v2/detail/pull_coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/param.hpp>", private, "<boost/coroutine/v2/detail/push_coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/stack_tuple.hpp>", private, "<boost/coroutine/v1/detail/coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/stack_tuple.hpp>", private, "<boost/coroutine/v2/detail/pull_coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/stack_tuple.hpp>", private, "<boost/coroutine/v2/detail/push_coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/trampoline.hpp>", private, "<boost/coroutine/v1/detail/coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/trampoline.hpp>", private, "<boost/coroutine/v2/detail/pull_coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/detail/trampoline.hpp>", private, "<boost/coroutine/v2/detail/push_coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/v1/detail/arg.hpp>", private, "<boost/coroutine/v1/detail/coroutine_base_resume.hpp>", private ] },
+ { include: ["<boost/coroutine/v1/detail/arg.hpp>", private, "<boost/coroutine/v1/detail/coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/v1/detail/arg.hpp>", private, "<boost/coroutine/v1/detail/coroutine_op.hpp>", private ] },
+ { include: ["<boost/coroutine/v1/detail/coroutine_base.hpp>", private, "<boost/coroutine/v1/detail/coroutine_caller.hpp>", private ] },
+ { include: ["<boost/coroutine/v1/detail/coroutine_base.hpp>", private, "<boost/coroutine/v1/detail/coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/v1/detail/coroutine_base_resume.hpp>", private, "<boost/coroutine/v1/detail/coroutine_base.hpp>", private ] },
+ { include: ["<boost/coroutine/v1/detail/coroutine_object_result_0.ipp>", private, "<boost/coroutine/v1/detail/coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/v1/detail/coroutine_object_result_1.ipp>", private, "<boost/coroutine/v1/detail/coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/v1/detail/coroutine_object_result_arity.ipp>", private, "<boost/coroutine/v1/detail/coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/v1/detail/coroutine_object_void_0.ipp>", private, "<boost/coroutine/v1/detail/coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/v1/detail/coroutine_object_void_1.ipp>", private, "<boost/coroutine/v1/detail/coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/v1/detail/coroutine_object_void_arity.ipp>", private, "<boost/coroutine/v1/detail/coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/v2/detail/pull_coroutine_base.hpp>", private, "<boost/coroutine/v2/detail/pull_coroutine_caller.hpp>", private ] },
+ { include: ["<boost/coroutine/v2/detail/pull_coroutine_base.hpp>", private, "<boost/coroutine/v2/detail/pull_coroutine_object.hpp>", private ] },
+ { include: ["<boost/coroutine/v2/detail/push_coroutine_base.hpp>", private, "<boost/coroutine/v2/detail/push_coroutine_caller.hpp>", private ] },
+ { include: ["<boost/coroutine/v2/detail/push_coroutine_base.hpp>", private, "<boost/coroutine/v2/detail/push_coroutine_object.hpp>", private ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/multi_index/detail/auto_space.hpp>", private ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/multi_index/detail/hash_index_node.hpp>", private ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/multi_index/detail/index_base.hpp>", private ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/multi_index/detail/ord_index_node.hpp>", private ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/multi_index/detail/rnd_index_loader.hpp>", private ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/multi_index/detail/rnd_index_node.hpp>", private ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/multi_index/detail/rnd_index_ptr_array.hpp>", private ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/multi_index/detail/seq_index_node.hpp>", private ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/statechart/detail/memory.hpp>", private ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/statechart/detail/state_base.hpp>", private ] },
+ { include: ["<boost/detail/atomic_count.hpp>", private, "<boost/asio/detail/atomic_count.hpp>", private ] },
+ { include: ["<boost/detail/atomic_count.hpp>", private, "<boost/serialization/detail/shared_ptr_nmt_132.hpp>", private ] },
+ { include: ["<boost/detail/atomic_count.hpp>", private, "<boost/spirit/home/support/iterators/detail/ref_counted_policy.hpp>", private ] },
+ { include: ["<boost/detail/atomic_count.hpp>", private, "<boost/statechart/detail/counted_base.hpp>", private ] },
+ { include: ["<boost/detail/atomic_count.hpp>", private, "<boost/xpressive/detail/utility/counted_base.hpp>", private ] },
+ { include: ["<boost/detail/atomic_count.hpp>", private, "<boost/xpressive/detail/utility/tracking_ptr.hpp>", private ] },
+ { include: ["<boost/detail/binary_search.hpp>", private, "<boost/python/suite/indexing/detail/indexing_suite_detail.hpp>", private ] },
+ { include: ["<boost/detail/container_fwd.hpp>", private, "<boost/lambda/detail/operator_return_type_traits.hpp>", private ] },
+ { include: ["<boost/detail/container_fwd.hpp>", private, "<boost/phoenix/stl/algorithm/detail/is_std_list.hpp>", private ] },
+ { include: ["<boost/detail/container_fwd.hpp>", private, "<boost/phoenix/stl/algorithm/detail/is_std_map.hpp>", private ] },
+ { include: ["<boost/detail/container_fwd.hpp>", private, "<boost/phoenix/stl/algorithm/detail/is_std_set.hpp>", private ] },
+ { include: ["<boost/detail/container_fwd.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/detail/is_std_list.hpp>", private ] },
+ { include: ["<boost/detail/container_fwd.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/detail/is_std_map.hpp>", private ] },
+ { include: ["<boost/detail/container_fwd.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/detail/is_std_set.hpp>", private ] },
+ { include: ["<boost/detail/endian.hpp>", private, "<boost/math/special_functions/detail/fp_traits.hpp>", private ] },
+ { include: ["<boost/detail/endian.hpp>", private, "<boost/spirit/home/support/detail/endian/endian.hpp>", private ] },
+ { include: ["<boost/detail/endian.hpp>", private, "<boost/spirit/home/support/detail/math/detail/fp_traits.hpp>", private ] },
+ { include: ["<boost/detail/fenv.hpp>", private, "<boost/numeric/interval/detail/c99sub_rounding_control.hpp>", private ] },
+ { include: ["<boost/detail/indirect_traits.hpp>", private, "<boost/iterator/detail/facade_iterator_category.hpp>", private ] },
+ { include: ["<boost/detail/indirect_traits.hpp>", private, "<boost/python/detail/caller.hpp>", private ] },
+ { include: ["<boost/detail/indirect_traits.hpp>", private, "<boost/python/detail/indirect_traits.hpp>", private ] },
+ { include: ["<boost/detail/interlocked.hpp>", private, "<boost/atomic/detail/interlocked.hpp>", private ] },
+ { include: ["<boost/detail/interlocked.hpp>", private, "<boost/interprocess/detail/win32_api.hpp>", private ] },
+ { include: ["<boost/detail/interlocked.hpp>", private, "<boost/log/detail/spin_mutex.hpp>", private ] },
+ { include: ["<boost/detail/interlocked.hpp>", private, "<boost/smart_ptr/detail/atomic_count_win32.hpp>", private ] },
+ { include: ["<boost/detail/interlocked.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_w32.hpp>", private ] },
+ { include: ["<boost/detail/interlocked.hpp>", private, "<boost/smart_ptr/detail/spinlock_w32.hpp>", private ] },
+ { include: ["<boost/detail/is_incrementable.hpp>", private, "<boost/iostreams/detail/resolve.hpp>", private ] },
+ { include: ["<boost/detail/is_xxx.hpp>", private, "<boost/python/detail/is_xxx.hpp>", private ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/algorithm/string/detail/finder.hpp>", private ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/algorithm/string/detail/trim.hpp>", private ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/iostreams/detail/adapter/range_adapter.hpp>", private ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/multi_index/detail/safe_mode.hpp>", private ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/range/detail/collection_traits_detail.hpp>", private ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/qi/numeric/detail/numeric_utils.hpp>", private ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/qi/stream/detail/iterator_source.hpp>", private ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/qi/string/detail/tst.hpp>", private ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/support/detail/lexer/generate_cpp.hpp>", private ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/support/detail/lexer/input.hpp>", private ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/support/iterators/detail/buffering_input_iterator_policy.hpp>", private ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/support/iterators/detail/input_iterator_policy.hpp>", private ] },
+ { include: ["<boost/detail/lightweight_mutex.hpp>", private, "<boost/flyweight/detail/recursive_lw_mutex.hpp>", private ] },
+ { include: ["<boost/detail/lightweight_mutex.hpp>", private, "<boost/multi_index/detail/safe_mode.hpp>", private ] },
+ { include: ["<boost/detail/lightweight_mutex.hpp>", private, "<boost/serialization/detail/shared_count_132.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/archive/detail/iserializer.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/detail/adaptive_node_pool_impl.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/detail/advanced_insert_int.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/detail/algorithms.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/detail/allocator_version_traits.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/detail/node_alloc_holder.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/detail/node_pool_impl.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/detail/tree.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/detail/utilities.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/flyweight/detail/flyweight_core.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/geometry/index/detail/varray_detail.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/interprocess/detail/intersegment_ptr.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/interprocess/detail/managed_memory_impl.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/interprocess/detail/managed_multi_shared_memory.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/interprocess/detail/segment_manager_helper.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/interprocess/smart_ptr/detail/shared_count.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/intrusive/detail/utilities.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/multi_index/detail/archive_constructed.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/multi_index/detail/copy_map.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/multi_index/detail/scope_guard.hpp>", private ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/multi_index/detail/seq_index_ops.hpp>", private ] },
+ { include: ["<boost/detail/quick_allocator.hpp>", private, "<boost/serialization/detail/shared_count_132.hpp>", private ] },
+ { include: ["<boost/detail/reference_content.hpp>", private, "<boost/variant/detail/initializer.hpp>", private ] },
+ { include: ["<boost/detail/scoped_enum_emulation.hpp>", private, "<boost/spirit/home/support/detail/scoped_enum_emulation.hpp>", private ] },
+ { include: ["<boost/detail/select_type.hpp>", private, "<boost/unordered/detail/allocate.hpp>", private ] },
+ { include: ["<boost/detail/select_type.hpp>", private, "<boost/unordered/detail/util.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/exception/detail/type_info.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_acc_ia64.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_aix.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_cw_ppc.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_cw_x86.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_gcc_ia64.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_gcc_mips.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_gcc_ppc.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_gcc_sparc.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_gcc_x86.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_nt.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_pt.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_snc_ps3.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_solaris.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_spin.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_sync.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_vacpp_ppc.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_w32.hpp>", private ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/spirit/home/support/detail/hold_any.hpp>", private ] },
+ { include: ["<boost/detail/templated_streams.hpp>", private, "<boost/variant/detail/variant_io.hpp>", private ] },
+ { include: ["<boost/detail/utf8_codecvt_facet.hpp>", private, "<boost/archive/detail/utf8_codecvt_facet.hpp>", private ] },
+ { include: ["<boost/detail/utf8_codecvt_facet.hpp>", private, "<boost/filesystem/detail/utf8_codecvt_facet.hpp>", private ] },
+ { include: ["<boost/detail/utf8_codecvt_facet.hpp>", private, "<boost/program_options/detail/utf8_codecvt_facet.hpp>", private ] },
+ { include: ["<boost/detail/winapi/GetCurrentProcess.hpp>", private, "<boost/chrono/detail/inlined/win/process_cpu_clocks.hpp>", private ] },
+ { include: ["<boost/detail/winapi/GetCurrentThread.hpp>", private, "<boost/chrono/detail/inlined/win/thread_clock.hpp>", private ] },
+ { include: ["<boost/detail/winapi/GetLastError.hpp>", private, "<boost/chrono/detail/inlined/win/chrono.hpp>", private ] },
+ { include: ["<boost/detail/winapi/GetLastError.hpp>", private, "<boost/chrono/detail/inlined/win/process_cpu_clocks.hpp>", private ] },
+ { include: ["<boost/detail/winapi/GetLastError.hpp>", private, "<boost/chrono/detail/inlined/win/thread_clock.hpp>", private ] },
+ { include: ["<boost/detail/winapi/GetProcessTimes.hpp>", private, "<boost/chrono/detail/inlined/win/process_cpu_clocks.hpp>", private ] },
+ { include: ["<boost/detail/winapi/GetThreadTimes.hpp>", private, "<boost/chrono/detail/inlined/win/thread_clock.hpp>", private ] },
+ { include: ["<boost/detail/winapi/time.hpp>", private, "<boost/chrono/detail/inlined/win/chrono.hpp>", private ] },
+ { include: ["<boost/detail/winapi/timers.hpp>", private, "<boost/chrono/detail/inlined/win/chrono.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/archive/detail/iserializer.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/archive/detail/oserializer.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/archive/impl/basic_xml_grammar.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/concept/detail/has_constraints.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/context/detail/config.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/coroutine/detail/config.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/flyweight/detail/flyweight_core.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/flyweight/detail/not_placeholder_expr.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/format/detail/config_macros.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/function_types/detail/cv_traits.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/graph/detail/adjacency_list.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/graph/detail/adj_list_edge_iterator.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/graph/detail/read_graphviz_new.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/graph/detail/read_graphviz_spirit.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/interprocess/detail/xsi_shared_memory_device.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/interprocess/detail/xsi_shared_memory_file_wrapper.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/adapter/direct_adapter.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/adapter/mode_adapter.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/add_facet.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/codecvt_helper.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/config/codecvt.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/config/disable_warnings.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/config/dyn_link.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/config/overload_resolution.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/config/wide_streams.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/default_arg.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/double_object.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/execute.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/forward.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/ios.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/is_dereferenceable.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/is_iterator_range.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/push.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/resolve.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/streambuf/chainbuf.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/streambuf/indirect_streambuf.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iostreams/detail/wrap_unwrap.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iterator/detail/config_def.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/iterator/detail/enable_if.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/lambda/detail/lambda_functors.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/multi_index/detail/access_specifier.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/multi_index/detail/base_type.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/multi_index/detail/index_base.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/multi_index/detail/msvc_index_specifier.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/multi_index/detail/node_type.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/multi_index/detail/prevent_eti.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/multi_index/detail/safe_ctr_proxy.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/multi_index/detail/unbounded.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/program_options/detail/cmdline.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/program_options/detail/config_file.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/program_options/detail/convert.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/proto/detail/as_expr.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/proto/detail/decltype.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/python/detail/config.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/python/detail/destroy.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/python/detail/enable_if.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/python/detail/string_literal.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/random/detail/operators.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/range/detail/begin.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/range/detail/end.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/range/detail/size.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/serialization/detail/shared_ptr_132.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/signals2/detail/auto_buffer.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/smart_ptr/detail/shared_count.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_w32.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/spirit/home/karma/detail/alternative_function.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/spirit/home/karma/numeric/detail/bool_utils.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/spirit/home/karma/numeric/detail/real_utils.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/spirit/home/support/detail/lexer/parser/tokeniser/num_token.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/spirit/home/support/detail/what_function.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/statechart/detail/rtti_policy.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/statechart/detail/state_base.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/test/detail/config.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/thread/detail/config.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/tti/detail/dmem_data.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/tuple/detail/tuple_basic.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/type_traits/detail/cv_traits_impl.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/variant/detail/apply_visitor_binary.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/variant/detail/apply_visitor_unary.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/variant/detail/config.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/variant/detail/move.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/xpressive/detail/core/results_cache.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/xpressive/detail/utility/literals.hpp>", private ] },
+ { include: ["<boost/detail/workaround.hpp>", private, "<boost/xpressive/detail/utility/tracking_ptr.hpp>", private ] },
+ { include: ["<boost/exception/detail/clone_current_exception.hpp>", private, "<boost/exception/detail/exception_ptr.hpp>", private ] },
+ { include: ["<boost/exception/detail/type_info.hpp>", private, "<boost/exception/detail/exception_ptr.hpp>", private ] },
+ { include: ["<boost/exception/detail/type_info.hpp>", private, "<boost/exception/detail/object_hex_dump.hpp>", private ] },
+ { include: ["<boost/flyweight/detail/dyn_perfect_fwd.hpp>", private, "<boost/flyweight/detail/perfect_fwd.hpp>", private ] },
+ { include: ["<boost/flyweight/detail/is_placeholder_expr.hpp>", private, "<boost/flyweight/detail/nested_xxx_if_not_ph.hpp>", private ] },
+ { include: ["<boost/flyweight/detail/perfect_fwd.hpp>", private, "<boost/flyweight/detail/default_value_policy.hpp>", private ] },
+ { include: ["<boost/flyweight/detail/perfect_fwd.hpp>", private, "<boost/flyweight/detail/flyweight_core.hpp>", private ] },
+ { include: ["<boost/flyweight/detail/pp_perfect_fwd.hpp>", private, "<boost/flyweight/detail/perfect_fwd.hpp>", private ] },
+ { include: ["<boost/flyweight/detail/value_tag.hpp>", private, "<boost/flyweight/detail/default_value_policy.hpp>", private ] },
+ { include: ["<boost/format/detail/config_macros.hpp>", private, "<boost/format/detail/compat_workarounds.hpp>", private ] },
+ { include: ["<boost/format/detail/workarounds_gcc-2_95.hpp>", private, "<boost/format/detail/config_macros.hpp>", private ] },
+ { include: ["<boost/format/detail/workarounds_stlport.hpp>", private, "<boost/format/detail/config_macros.hpp>", private ] },
+ { include: ["<boost/functional/hash/detail/float_functions.hpp>", private, "<boost/functional/hash/detail/hash_float.hpp>", private ] },
+ { include: ["<boost/functional/hash/detail/limits.hpp>", private, "<boost/functional/hash/detail/hash_float.hpp>", private ] },
+ { include: ["<boost/function/detail/maybe_include.hpp>", private, "<boost/function/detail/function_iterate.hpp>", private ] },
+ { include: ["<boost/function_types/detail/classifier_impl/arity10_0.hpp>", private, "<boost/function_types/detail/classifier_impl/arity20_0.hpp>", private ] },
+ { include: ["<boost/function_types/detail/classifier_impl/arity10_1.hpp>", private, "<boost/function_types/detail/classifier_impl/arity20_1.hpp>", private ] },
+ { include: ["<boost/function_types/detail/classifier_impl/arity20_0.hpp>", private, "<boost/function_types/detail/classifier_impl/arity30_0.hpp>", private ] },
+ { include: ["<boost/function_types/detail/classifier_impl/arity20_1.hpp>", private, "<boost/function_types/detail/classifier_impl/arity30_1.hpp>", private ] },
+ { include: ["<boost/function_types/detail/classifier_impl/arity30_0.hpp>", private, "<boost/function_types/detail/classifier_impl/arity40_0.hpp>", private ] },
+ { include: ["<boost/function_types/detail/classifier_impl/arity30_1.hpp>", private, "<boost/function_types/detail/classifier_impl/arity40_1.hpp>", private ] },
+ { include: ["<boost/function_types/detail/classifier_impl/arity40_0.hpp>", private, "<boost/function_types/detail/classifier_impl/arity50_0.hpp>", private ] },
+ { include: ["<boost/function_types/detail/classifier_impl/arity40_1.hpp>", private, "<boost/function_types/detail/classifier_impl/arity50_1.hpp>", private ] },
+ { include: ["<boost/function_types/detail/components_impl/arity10_0.hpp>", private, "<boost/function_types/detail/components_impl/arity20_0.hpp>", private ] },
+ { include: ["<boost/function_types/detail/components_impl/arity10_1.hpp>", private, "<boost/function_types/detail/components_impl/arity20_1.hpp>", private ] },
+ { include: ["<boost/function_types/detail/components_impl/arity20_0.hpp>", private, "<boost/function_types/detail/components_impl/arity30_0.hpp>", private ] },
+ { include: ["<boost/function_types/detail/components_impl/arity20_1.hpp>", private, "<boost/function_types/detail/components_impl/arity30_1.hpp>", private ] },
+ { include: ["<boost/function_types/detail/components_impl/arity30_0.hpp>", private, "<boost/function_types/detail/components_impl/arity40_0.hpp>", private ] },
+ { include: ["<boost/function_types/detail/components_impl/arity30_1.hpp>", private, "<boost/function_types/detail/components_impl/arity40_1.hpp>", private ] },
+ { include: ["<boost/function_types/detail/components_impl/arity40_0.hpp>", private, "<boost/function_types/detail/components_impl/arity50_0.hpp>", private ] },
+ { include: ["<boost/function_types/detail/components_impl/arity40_1.hpp>", private, "<boost/function_types/detail/components_impl/arity50_1.hpp>", private ] },
+ { include: ["<boost/function_types/detail/cv_traits.hpp>", private, "<boost/function_types/detail/synthesize.hpp>", private ] },
+ { include: ["<boost/function_types/detail/encoding/aliases_def.hpp>", private, "<boost/function_types/detail/pp_cc_loop/master.hpp>", private ] },
+ { include: ["<boost/function_types/detail/encoding/aliases_def.hpp>", private, "<boost/function_types/detail/pp_loop.hpp>", private ] },
+ { include: ["<boost/function_types/detail/encoding/aliases_def.hpp>", private, "<boost/function_types/detail/pp_retag_default_cc/master.hpp>", private ] },
+ { include: ["<boost/function_types/detail/encoding/aliases_undef.hpp>", private, "<boost/function_types/detail/pp_cc_loop/master.hpp>", private ] },
+ { include: ["<boost/function_types/detail/encoding/aliases_undef.hpp>", private, "<boost/function_types/detail/pp_loop.hpp>", private ] },
+ { include: ["<boost/function_types/detail/encoding/aliases_undef.hpp>", private, "<boost/function_types/detail/pp_retag_default_cc/master.hpp>", private ] },
+ { include: ["<boost/function_types/detail/encoding/def.hpp>", private, "<boost/function_types/detail/pp_cc_loop/master.hpp>", private ] },
+ { include: ["<boost/function_types/detail/encoding/def.hpp>", private, "<boost/function_types/detail/pp_loop.hpp>", private ] },
+ { include: ["<boost/function_types/detail/encoding/def.hpp>", private, "<boost/function_types/detail/pp_retag_default_cc/master.hpp>", private ] },
+ { include: ["<boost/function_types/detail/encoding/def.hpp>", private, "<boost/function_types/detail/pp_tags/master.hpp>", private ] },
+ { include: ["<boost/function_types/detail/encoding/def.hpp>", private, "<boost/function_types/detail/pp_variate_loop/master.hpp>", private ] },
+ { include: ["<boost/function_types/detail/encoding/undef.hpp>", private, "<boost/function_types/detail/pp_cc_loop/master.hpp>", private ] },
+ { include: ["<boost/function_types/detail/encoding/undef.hpp>", private, "<boost/function_types/detail/pp_loop.hpp>", private ] },
+ { include: ["<boost/function_types/detail/encoding/undef.hpp>", private, "<boost/function_types/detail/pp_retag_default_cc/master.hpp>", private ] },
+ { include: ["<boost/function_types/detail/encoding/undef.hpp>", private, "<boost/function_types/detail/pp_variate_loop/master.hpp>", private ] },
+ { include: ["<boost/function_types/detail/pp_loop.hpp>", private, "<boost/function_types/detail/classifier.hpp>", private ] },
+ { include: ["<boost/function_types/detail/pp_loop.hpp>", private, "<boost/function_types/detail/synthesize.hpp>", private ] },
+ { include: ["<boost/function_types/detail/pp_retag_default_cc/master.hpp>", private, "<boost/function_types/detail/retag_default_cc.hpp>", private ] },
+ { include: ["<boost/function_types/detail/pp_retag_default_cc/preprocessed.hpp>", private, "<boost/function_types/detail/retag_default_cc.hpp>", private ] },
+ { include: ["<boost/function_types/detail/retag_default_cc.hpp>", private, "<boost/function_types/detail/synthesize.hpp>", private ] },
+ { include: ["<boost/function_types/detail/synthesize_impl/arity10_0.hpp>", private, "<boost/function_types/detail/synthesize_impl/arity20_0.hpp>", private ] },
+ { include: ["<boost/function_types/detail/synthesize_impl/arity10_1.hpp>", private, "<boost/function_types/detail/synthesize_impl/arity20_1.hpp>", private ] },
+ { include: ["<boost/function_types/detail/synthesize_impl/arity20_0.hpp>", private, "<boost/function_types/detail/synthesize_impl/arity30_0.hpp>", private ] },
+ { include: ["<boost/function_types/detail/synthesize_impl/arity20_1.hpp>", private, "<boost/function_types/detail/synthesize_impl/arity30_1.hpp>", private ] },
+ { include: ["<boost/function_types/detail/synthesize_impl/arity30_0.hpp>", private, "<boost/function_types/detail/synthesize_impl/arity40_0.hpp>", private ] },
+ { include: ["<boost/function_types/detail/synthesize_impl/arity30_1.hpp>", private, "<boost/function_types/detail/synthesize_impl/arity40_1.hpp>", private ] },
+ { include: ["<boost/function_types/detail/synthesize_impl/arity40_0.hpp>", private, "<boost/function_types/detail/synthesize_impl/arity50_0.hpp>", private ] },
+ { include: ["<boost/function_types/detail/synthesize_impl/arity40_1.hpp>", private, "<boost/function_types/detail/synthesize_impl/arity50_1.hpp>", private ] },
+ { include: ["<boost/fusion/adapted/struct/detail/adapt_base.hpp>", private, "<boost/fusion/adapted/struct/detail/define_struct.hpp>", private ] },
+ { include: ["<boost/fusion/adapted/struct/detail/define_struct.hpp>", private, "<boost/fusion/adapted/struct/detail/define_struct_inline.hpp>", private ] },
+ { include: ["<boost/fusion/adapted/struct/detail/extension.hpp>", private, "<boost/fusion/adapted/adt/detail/extension.hpp>", private ] },
+ { include: ["<boost/fusion/adapted/struct/detail/namespace.hpp>", private, "<boost/fusion/adapted/struct/detail/define_struct.hpp>", private ] },
+ { include: ["<boost/fusion/adapted/struct/detail/namespace.hpp>", private, "<boost/fusion/adapted/struct/detail/proxy_type.hpp>", private ] },
+ { include: ["<boost/fusion/algorithm/query/detail/find_if.hpp>", private, "<boost/fusion/view/filter_view/detail/next_impl.hpp>", private ] },
+ { include: ["<boost/fusion/algorithm/transformation/detail/preprocessed/zip10.hpp>", private, "<boost/fusion/algorithm/transformation/detail/preprocessed/zip.hpp>", private ] },
+ { include: ["<boost/fusion/algorithm/transformation/detail/preprocessed/zip20.hpp>", private, "<boost/fusion/algorithm/transformation/detail/preprocessed/zip.hpp>", private ] },
+ { include: ["<boost/fusion/algorithm/transformation/detail/preprocessed/zip30.hpp>", private, "<boost/fusion/algorithm/transformation/detail/preprocessed/zip.hpp>", private ] },
+ { include: ["<boost/fusion/algorithm/transformation/detail/preprocessed/zip40.hpp>", private, "<boost/fusion/algorithm/transformation/detail/preprocessed/zip.hpp>", private ] },
+ { include: ["<boost/fusion/algorithm/transformation/detail/preprocessed/zip50.hpp>", private, "<boost/fusion/algorithm/transformation/detail/preprocessed/zip.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/at_impl.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/begin_impl.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/as_deque.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/build_deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/deque_forward_ctor.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/deque_initial_size.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/deque_keyed_values_call.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque_keyed_values.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/deque_keyed_values.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/limits.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/limits.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/limits.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque_keyed_values.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/as_deque10.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/as_deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/as_deque20.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/as_deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/as_deque30.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/as_deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/as_deque40.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/as_deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/as_deque50.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/as_deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/as_deque.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/as_deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque10_fwd.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque10.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque20_fwd.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque20.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque30_fwd.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque30.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque40_fwd.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque40.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque50_fwd.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque50.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_fwd.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_initial_size10.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_initial_size.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_initial_size20.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_initial_size.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_initial_size30.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_initial_size.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_initial_size40.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_initial_size.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_initial_size50.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_initial_size.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_initial_size.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque_initial_size.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_keyed_values10.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_keyed_values.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_keyed_values20.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_keyed_values.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_keyed_values30.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_keyed_values.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_keyed_values40.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_keyed_values.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_keyed_values50.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_keyed_values.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/preprocessed/deque_keyed_values.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque_keyed_values.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/end_impl.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/is_sequence_impl.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/keyed_element.hpp>", private, "<boost/fusion/container/deque/detail/at_impl.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/keyed_element.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/keyed_element.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque_keyed_values.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/keyed_element.hpp>", private, "<boost/fusion/container/deque/detail/deque_keyed_values.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/keyed_element.hpp>", private, "<boost/fusion/container/deque/detail/value_at_impl.hpp>", private ] },
+ { include: ["<boost/fusion/container/deque/detail/value_at_impl.hpp>", private, "<boost/fusion/container/deque/detail/cpp03/deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/deque_tie10.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/deque_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/deque_tie20.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/deque_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/deque_tie30.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/deque_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/deque_tie40.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/deque_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/deque_tie50.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/deque_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/deque_tie.hpp>", private, "<boost/fusion/container/generation/detail/pp_deque_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/list_tie10.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/list_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/list_tie20.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/list_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/list_tie30.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/list_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/list_tie40.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/list_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/list_tie50.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/list_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_deque10.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_deque20.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_deque30.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_deque40.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_deque50.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_deque.hpp>", private, "<boost/fusion/container/generation/detail/pp_make_deque.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_list10.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_list.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_list20.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_list.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_list30.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_list.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_list40.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_list.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_list50.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_list.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_map10.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_map.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_map20.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_map.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_map30.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_map.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_map40.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_map.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_map50.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_map.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_map.hpp>", private, "<boost/fusion/container/generation/detail/pp_make_map.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_set10.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_set.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_set20.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_set.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_set30.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_set.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_set40.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_set.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_set50.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_set.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_vector10.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_vector.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_vector20.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_vector.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_vector30.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_vector.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_vector40.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_vector.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_vector50.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/make_vector.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/map_tie10.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/map_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/map_tie20.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/map_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/map_tie30.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/map_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/map_tie40.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/map_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/map_tie50.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/map_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/map_tie.hpp>", private, "<boost/fusion/container/generation/detail/pp_map_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/vector_tie10.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/vector_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/vector_tie20.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/vector_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/vector_tie30.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/vector_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/vector_tie40.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/vector_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/vector_tie50.hpp>", private, "<boost/fusion/container/generation/detail/preprocessed/vector_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/build_cons.hpp>", private, "<boost/fusion/container/list/detail/convert_impl.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/list_to_cons_call.hpp>", private, "<boost/fusion/container/list/detail/list_to_cons.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list10_fwd.hpp>", private, "<boost/fusion/container/list/detail/preprocessed/list_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list10.hpp>", private, "<boost/fusion/container/list/detail/preprocessed/list.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list20_fwd.hpp>", private, "<boost/fusion/container/list/detail/preprocessed/list_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list20.hpp>", private, "<boost/fusion/container/list/detail/preprocessed/list.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list30_fwd.hpp>", private, "<boost/fusion/container/list/detail/preprocessed/list_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list30.hpp>", private, "<boost/fusion/container/list/detail/preprocessed/list.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list40_fwd.hpp>", private, "<boost/fusion/container/list/detail/preprocessed/list_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list40.hpp>", private, "<boost/fusion/container/list/detail/preprocessed/list.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list50_fwd.hpp>", private, "<boost/fusion/container/list/detail/preprocessed/list_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list50.hpp>", private, "<boost/fusion/container/list/detail/preprocessed/list.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list_to_cons10.hpp>", private, "<boost/fusion/container/list/detail/preprocessed/list_to_cons.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list_to_cons20.hpp>", private, "<boost/fusion/container/list/detail/preprocessed/list_to_cons.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list_to_cons30.hpp>", private, "<boost/fusion/container/list/detail/preprocessed/list_to_cons.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list_to_cons40.hpp>", private, "<boost/fusion/container/list/detail/preprocessed/list_to_cons.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list_to_cons50.hpp>", private, "<boost/fusion/container/list/detail/preprocessed/list_to_cons.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list_to_cons.hpp>", private, "<boost/fusion/container/list/detail/list_to_cons.hpp>", private ] },
+ { include: ["<boost/fusion/container/list/detail/reverse_cons.hpp>", private, "<boost/fusion/view/iterator_range/detail/segmented_iterator_range.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/as_map.hpp>", private, "<boost/fusion/container/map/detail/cpp03/convert.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/as_map.hpp>", private, "<boost/fusion/container/map/detail/cpp03/convert_impl.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/at_impl.hpp>", private, "<boost/fusion/container/map/detail/cpp03/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/begin_impl.hpp>", private, "<boost/fusion/container/map/detail/cpp03/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/convert_impl.hpp>", private, "<boost/fusion/container/map/detail/cpp03/convert.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/deref_data_impl.hpp>", private, "<boost/fusion/container/map/detail/cpp03/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/deref_impl.hpp>", private, "<boost/fusion/container/map/detail/cpp03/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/end_impl.hpp>", private, "<boost/fusion/container/map/detail/cpp03/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/key_of_impl.hpp>", private, "<boost/fusion/container/map/detail/cpp03/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/limits.hpp>", private, "<boost/fusion/container/generation/detail/pp_map_tie.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/limits.hpp>", private, "<boost/fusion/container/map/detail/cpp03/map_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/map_forward_ctor.hpp>", private, "<boost/fusion/container/map/detail/cpp03/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/map_fwd.hpp>", private, "<boost/fusion/container/map/detail/cpp03/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/map.hpp>", private, "<boost/fusion/container/map/detail/cpp03/as_map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/map.hpp>", private, "<boost/fusion/container/map/detail/cpp03/convert.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/map.hpp>", private, "<boost/fusion/container/map/detail/cpp03/convert_impl.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/as_map10.hpp>", private, "<boost/fusion/container/map/detail/cpp03/preprocessed/as_map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/as_map20.hpp>", private, "<boost/fusion/container/map/detail/cpp03/preprocessed/as_map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/as_map30.hpp>", private, "<boost/fusion/container/map/detail/cpp03/preprocessed/as_map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/as_map40.hpp>", private, "<boost/fusion/container/map/detail/cpp03/preprocessed/as_map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/as_map50.hpp>", private, "<boost/fusion/container/map/detail/cpp03/preprocessed/as_map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/as_map.hpp>", private, "<boost/fusion/container/map/detail/cpp03/as_map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/map10_fwd.hpp>", private, "<boost/fusion/container/map/detail/cpp03/preprocessed/map_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/map10.hpp>", private, "<boost/fusion/container/map/detail/cpp03/preprocessed/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/map20_fwd.hpp>", private, "<boost/fusion/container/map/detail/cpp03/preprocessed/map_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/map20.hpp>", private, "<boost/fusion/container/map/detail/cpp03/preprocessed/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/map30_fwd.hpp>", private, "<boost/fusion/container/map/detail/cpp03/preprocessed/map_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/map30.hpp>", private, "<boost/fusion/container/map/detail/cpp03/preprocessed/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/map40_fwd.hpp>", private, "<boost/fusion/container/map/detail/cpp03/preprocessed/map_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/map40.hpp>", private, "<boost/fusion/container/map/detail/cpp03/preprocessed/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/map50_fwd.hpp>", private, "<boost/fusion/container/map/detail/cpp03/preprocessed/map_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/map50.hpp>", private, "<boost/fusion/container/map/detail/cpp03/preprocessed/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/map_fwd.hpp>", private, "<boost/fusion/container/map/detail/cpp03/map_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/preprocessed/map.hpp>", private, "<boost/fusion/container/map/detail/cpp03/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/value_at_impl.hpp>", private, "<boost/fusion/container/map/detail/cpp03/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/value_of_data_impl.hpp>", private, "<boost/fusion/container/map/detail/cpp03/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/value_of_impl.hpp>", private, "<boost/fusion/container/map/detail/cpp03/key_of_impl.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/value_of_impl.hpp>", private, "<boost/fusion/container/map/detail/cpp03/map.hpp>", private ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/value_of_impl.hpp>", private, "<boost/fusion/container/map/detail/cpp03/value_of_data_impl.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/as_set.hpp>", private, "<boost/fusion/container/set/detail/convert_impl.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/deref_impl.hpp>", private, "<boost/fusion/container/set/detail/deref_data_impl.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/as_set10.hpp>", private, "<boost/fusion/container/set/detail/preprocessed/as_set.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/as_set20.hpp>", private, "<boost/fusion/container/set/detail/preprocessed/as_set.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/as_set30.hpp>", private, "<boost/fusion/container/set/detail/preprocessed/as_set.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/as_set40.hpp>", private, "<boost/fusion/container/set/detail/preprocessed/as_set.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/as_set50.hpp>", private, "<boost/fusion/container/set/detail/preprocessed/as_set.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/as_set.hpp>", private, "<boost/fusion/container/set/detail/as_set.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/set10_fwd.hpp>", private, "<boost/fusion/container/set/detail/preprocessed/set_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/set10.hpp>", private, "<boost/fusion/container/set/detail/preprocessed/set.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/set20_fwd.hpp>", private, "<boost/fusion/container/set/detail/preprocessed/set_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/set20.hpp>", private, "<boost/fusion/container/set/detail/preprocessed/set.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/set30_fwd.hpp>", private, "<boost/fusion/container/set/detail/preprocessed/set_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/set30.hpp>", private, "<boost/fusion/container/set/detail/preprocessed/set.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/set40_fwd.hpp>", private, "<boost/fusion/container/set/detail/preprocessed/set_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/set40.hpp>", private, "<boost/fusion/container/set/detail/preprocessed/set.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/set50_fwd.hpp>", private, "<boost/fusion/container/set/detail/preprocessed/set_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/set50.hpp>", private, "<boost/fusion/container/set/detail/preprocessed/set.hpp>", private ] },
+ { include: ["<boost/fusion/container/set/detail/value_of_data_impl.hpp>", private, "<boost/fusion/container/set/detail/key_of_impl.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/as_vector.hpp>", private, "<boost/fusion/container/vector/detail/convert_impl.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/as_vector10.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/as_vector.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/as_vector20.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/as_vector.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/as_vector30.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/as_vector.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/as_vector40.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/as_vector.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/as_vector50.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/as_vector.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/as_vector.hpp>", private, "<boost/fusion/container/vector/detail/as_vector.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector_chooser10.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/vector_chooser.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector_chooser20.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/vector_chooser.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector_chooser30.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/vector_chooser.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector_chooser40.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/vector_chooser.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector_chooser50.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/vector_chooser.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector_chooser.hpp>", private, "<boost/fusion/container/vector/detail/vector_n_chooser.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vvector10_fwd.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/vector_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vvector10.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/vector.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vvector20_fwd.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/vector_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vvector20.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/vector.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vvector30_fwd.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/vector_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vvector30.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/vector.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vvector40_fwd.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/vector_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vvector40.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/vector.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vvector50_fwd.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/vector_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vvector50.hpp>", private, "<boost/fusion/container/vector/detail/preprocessed/vector.hpp>", private ] },
+ { include: ["<boost/fusion/iterator/detail/adapt_deref_traits.hpp>", private, "<boost/fusion/view/filter_view/detail/deref_impl.hpp>", private ] },
+ { include: ["<boost/fusion/iterator/detail/adapt_deref_traits.hpp>", private, "<boost/fusion/view/joint_view/detail/deref_impl.hpp>", private ] },
+ { include: ["<boost/fusion/iterator/detail/adapt_value_traits.hpp>", private, "<boost/fusion/view/filter_view/detail/value_of_impl.hpp>", private ] },
+ { include: ["<boost/fusion/iterator/detail/adapt_value_traits.hpp>", private, "<boost/fusion/view/joint_view/detail/value_of_impl.hpp>", private ] },
+ { include: ["<boost/fusion/iterator/detail/segmented_equal_to.hpp>", private, "<boost/fusion/iterator/detail/segmented_iterator.hpp>", private ] },
+ { include: ["<boost/fusion/iterator/detail/segment_sequence.hpp>", private, "<boost/fusion/view/iterator_range/detail/segmented_iterator_range.hpp>", private ] },
+ { include: ["<boost/fusion/sequence/intrinsic/detail/segmented_begin_impl.hpp>", private, "<boost/fusion/sequence/intrinsic/detail/segmented_begin.hpp>", private ] },
+ { include: ["<boost/fusion/sequence/intrinsic/detail/segmented_end_impl.hpp>", private, "<boost/fusion/sequence/intrinsic/detail/segmented_begin_impl.hpp>", private ] },
+ { include: ["<boost/fusion/sequence/intrinsic/detail/segmented_end_impl.hpp>", private, "<boost/fusion/sequence/intrinsic/detail/segmented_end.hpp>", private ] },
+ { include: ["<boost/fusion/sequence/io/detail/manip.hpp>", private, "<boost/fusion/sequence/io/detail/in.hpp>", private ] },
+ { include: ["<boost/fusion/sequence/io/detail/manip.hpp>", private, "<boost/fusion/sequence/io/detail/out.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/adapted/std_tuple/detail/at_impl.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/algorithm/query/detail/count.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/deque/detail/keyed_element.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/list/detail/at_impl.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/list/detail/value_at_impl.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/map/detail/at_impl.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/map/detail/at_key_impl.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/map/detail/cpp03/at_impl.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/map/detail/cpp03/deref_data_impl.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/map/detail/cpp03/map.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/map/detail/map_impl.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/map/detail/value_at_key_impl.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/vector/detail/at_impl.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/vector/detail/deref_impl.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/container/generation/detail/pp_make_deque.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/container/generation/detail/pp_make_map.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/functional/generation/detail/gen_make_adapter.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/spirit/home/support/detail/make_vector.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/mpl_iterator_category.hpp>", private, "<boost/fusion/adapted/mpl/detail/category_of_impl.hpp>", private ] },
+ { include: ["<boost/fusion/support/detail/segmented_fold_until_impl.hpp>", private, "<boost/fusion/sequence/intrinsic/detail/segmented_begin_impl.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/make_tuple10.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/make_tuple.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/make_tuple20.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/make_tuple.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/make_tuple30.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/make_tuple.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/make_tuple40.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/make_tuple.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/make_tuple50.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/make_tuple.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple10_fwd.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/tuple_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple10.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/tuple.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple20_fwd.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/tuple_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple20.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/tuple.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple30_fwd.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/tuple_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple30.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/tuple.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple40_fwd.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/tuple_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple40.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/tuple.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple50_fwd.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/tuple_fwd.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple50.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/tuple.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple_tie10.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/tuple_tie.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple_tie20.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/tuple_tie.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple_tie30.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/tuple_tie.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple_tie40.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/tuple_tie.hpp>", private ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple_tie50.hpp>", private, "<boost/fusion/tuple/detail/preprocessed/tuple_tie.hpp>", private ] },
+ { include: ["<boost/fusion/view/iterator_range/detail/segmented_iterator_range.hpp>", private, "<boost/fusion/view/iterator_range/detail/segments_impl.hpp>", private ] },
+ { include: ["<boost/fusion/view/transform_view/detail/apply_transform_result.hpp>", private, "<boost/fusion/view/transform_view/detail/at_impl.hpp>", private ] },
+ { include: ["<boost/fusion/view/transform_view/detail/apply_transform_result.hpp>", private, "<boost/fusion/view/transform_view/detail/deref_impl.hpp>", private ] },
+ { include: ["<boost/fusion/view/transform_view/detail/apply_transform_result.hpp>", private, "<boost/fusion/view/transform_view/detail/value_at_impl.hpp>", private ] },
+ { include: ["<boost/fusion/view/transform_view/detail/apply_transform_result.hpp>", private, "<boost/fusion/view/transform_view/detail/value_of_impl.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/assign_values.hpp>", private, "<boost/geometry/algorithms/detail/assign_box_corners.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/assign_values.hpp>", private, "<boost/geometry/algorithms/detail/assign_indexed_point.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/convert_point_to_point.hpp>", private, "<boost/geometry/algorithms/detail/point_is_spike_or_equal.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/convert_point_to_point.hpp>", private, "<boost/geometry/algorithms/detail/point_on_border.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/disjoint.hpp>", private, "<boost/geometry/algorithms/detail/overlay/append_no_duplicates.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/disjoint.hpp>", private, "<boost/geometry/algorithms/detail/overlay/append_no_dups_or_spikes.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/disjoint.hpp>", private, "<boost/geometry/algorithms/detail/overlay/get_turns.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/disjoint.hpp>", private, "<boost/geometry/algorithms/detail/overlay/self_turn_points.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/disjoint.hpp>", private, "<boost/geometry/algorithms/detail/point_on_border.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/for_each_range.hpp>", private, "<boost/geometry/multi/algorithms/detail/for_each_range.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/get_left_turns.hpp>", private, "<boost/geometry/algorithms/detail/occupation_info.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/has_self_intersections.hpp>", private, "<boost/geometry/algorithms/detail/overlay/backtrack_check_si.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/add_rings.hpp>", private, "<boost/geometry/algorithms/detail/overlay/overlay.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/append_no_duplicates.hpp>", private, "<boost/geometry/algorithms/detail/overlay/clip_linestring.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/append_no_duplicates.hpp>", private, "<boost/geometry/algorithms/detail/overlay/follow.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/append_no_dups_or_spikes.hpp>", private, "<boost/geometry/algorithms/detail/overlay/copy_segments.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/append_no_dups_or_spikes.hpp>", private, "<boost/geometry/algorithms/detail/overlay/traverse.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/assign_parents.hpp>", private, "<boost/geometry/algorithms/detail/overlay/overlay.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/backtrack_check_si.hpp>", private, "<boost/geometry/algorithms/detail/overlay/traverse.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/calculate_distance_policy.hpp>", private, "<boost/geometry/algorithms/detail/overlay/overlay.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/check_enrich.hpp>", private, "<boost/geometry/algorithms/detail/overlay/enrich_intersection_points.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/clip_linestring.hpp>", private, "<boost/geometry/algorithms/detail/overlay/intersection_insert.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/convert_ring.hpp>", private, "<boost/geometry/algorithms/detail/overlay/add_rings.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/copy_segment_point.hpp>", private, "<boost/geometry/algorithms/detail/overlay/enrich_intersection_points.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/copy_segment_point.hpp>", private, "<boost/geometry/algorithms/detail/overlay/handle_tangencies.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/copy_segment_point.hpp>", private, "<boost/geometry/multi/algorithms/detail/overlay/copy_segment_point.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/copy_segments.hpp>", private, "<boost/geometry/algorithms/detail/overlay/follow.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/copy_segments.hpp>", private, "<boost/geometry/algorithms/detail/overlay/traverse.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/copy_segments.hpp>", private, "<boost/geometry/multi/algorithms/detail/overlay/copy_segments.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/debug_turn_info.hpp>", private, "<boost/geometry/algorithms/detail/has_self_intersections.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/debug_turn_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/backtrack_check_si.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/debug_turn_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/enrich_intersection_points.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/debug_turn_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/traverse.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/enrich_intersection_points.hpp>", private, "<boost/geometry/algorithms/detail/overlay/overlay.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/enrichment_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/overlay.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/enrichment_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/traversal_info.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/follow.hpp>", private, "<boost/geometry/algorithms/detail/overlay/intersection_insert.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/get_intersection_points.hpp>", private, "<boost/geometry/algorithms/detail/overlay/intersection_insert.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/get_relative_order.hpp>", private, "<boost/geometry/algorithms/detail/overlay/enrich_intersection_points.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/get_ring.hpp>", private, "<boost/geometry/algorithms/detail/overlay/add_rings.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/get_ring.hpp>", private, "<boost/geometry/algorithms/detail/overlay/assign_parents.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/get_ring.hpp>", private, "<boost/geometry/multi/algorithms/detail/overlay/get_ring.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/get_turn_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/get_turns.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/get_turns.hpp>", private, "<boost/geometry/algorithms/detail/has_self_intersections.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/get_turns.hpp>", private, "<boost/geometry/algorithms/detail/overlay/get_intersection_points.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/get_turns.hpp>", private, "<boost/geometry/algorithms/detail/overlay/overlay.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/get_turns.hpp>", private, "<boost/geometry/algorithms/detail/overlay/self_turn_points.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/get_turns.hpp>", private, "<boost/geometry/multi/algorithms/detail/overlay/get_turns.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/handle_tangencies.hpp>", private, "<boost/geometry/algorithms/detail/overlay/enrich_intersection_points.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/msm_state.hpp>", private, "<boost/geometry/algorithms/detail/overlay/visit_info.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/overlay.hpp>", private, "<boost/geometry/algorithms/detail/overlay/intersection_insert.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/overlay_type.hpp>", private, "<boost/geometry/algorithms/detail/overlay/intersection_insert.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/overlay_type.hpp>", private, "<boost/geometry/algorithms/detail/overlay/overlay.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/overlay_type.hpp>", private, "<boost/geometry/algorithms/detail/overlay/select_rings.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/ring_properties.hpp>", private, "<boost/geometry/algorithms/detail/overlay/overlay.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/ring_properties.hpp>", private, "<boost/geometry/algorithms/detail/overlay/select_rings.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/segment_identifier.hpp>", private, "<boost/geometry/algorithms/detail/overlay/get_turns.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/segment_identifier.hpp>", private, "<boost/geometry/algorithms/detail/overlay/traversal_info.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/segment_identifier.hpp>", private, "<boost/geometry/algorithms/detail/overlay/turn_info.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/select_rings.hpp>", private, "<boost/geometry/algorithms/detail/overlay/overlay.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/select_rings.hpp>", private, "<boost/geometry/multi/algorithms/detail/overlay/select_rings.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/self_turn_points.hpp>", private, "<boost/geometry/algorithms/detail/has_self_intersections.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/self_turn_points.hpp>", private, "<boost/geometry/multi/algorithms/detail/overlay/self_turn_points.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/traversal_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/overlay.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/traverse.hpp>", private, "<boost/geometry/algorithms/detail/overlay/overlay.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/turn_info.hpp>", private, "<boost/geometry/algorithms/detail/has_self_intersections.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/turn_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/backtrack_check_si.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/turn_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/debug_turn_info.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/turn_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/follow.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/turn_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/get_turn_info.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/turn_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/handle_tangencies.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/turn_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/overlay.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/turn_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/traversal_info.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/turn_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/traverse.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/visit_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/debug_turn_info.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/visit_info.hpp>", private, "<boost/geometry/algorithms/detail/overlay/traversal_info.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/partition.hpp>", private, "<boost/geometry/algorithms/detail/overlay/assign_parents.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/partition.hpp>", private, "<boost/geometry/algorithms/detail/overlay/get_turns.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/partition.hpp>", private, "<boost/geometry/algorithms/detail/overlay/self_turn_points.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/point_is_spike_or_equal.hpp>", private, "<boost/geometry/algorithms/detail/overlay/append_no_dups_or_spikes.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/point_on_border.hpp>", private, "<boost/geometry/algorithms/detail/overlay/follow.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/point_on_border.hpp>", private, "<boost/geometry/algorithms/detail/overlay/intersection_insert.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/point_on_border.hpp>", private, "<boost/geometry/algorithms/detail/overlay/ring_properties.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/point_on_border.hpp>", private, "<boost/geometry/algorithms/detail/overlay/select_rings.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/point_on_border.hpp>", private, "<boost/geometry/multi/algorithms/detail/point_on_border.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/ring_identifier.hpp>", private, "<boost/geometry/algorithms/detail/overlay/check_enrich.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/ring_identifier.hpp>", private, "<boost/geometry/algorithms/detail/overlay/convert_ring.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/ring_identifier.hpp>", private, "<boost/geometry/algorithms/detail/overlay/enrich_intersection_points.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/ring_identifier.hpp>", private, "<boost/geometry/algorithms/detail/overlay/get_ring.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/ring_identifier.hpp>", private, "<boost/geometry/algorithms/detail/overlay/handle_tangencies.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/ring_identifier.hpp>", private, "<boost/geometry/algorithms/detail/overlay/select_rings.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/ring_identifier.hpp>", private, "<boost/geometry/algorithms/detail/sections/sectionalize.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/sections/range_by_section.hpp>", private, "<boost/geometry/algorithms/detail/overlay/get_turns.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/sections/range_by_section.hpp>", private, "<boost/geometry/multi/algorithms/detail/sections/range_by_section.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/sections/sectionalize.hpp>", private, "<boost/geometry/algorithms/detail/overlay/get_turns.hpp>", private ] },
+ { include: ["<boost/geometry/algorithms/detail/sections/sectionalize.hpp>", private, "<boost/geometry/multi/algorithms/detail/sections/sectionalize.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/bounds.hpp>", private, "<boost/geometry/index/detail/rtree/node/node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/comparable_distance_centroid.hpp>", private, "<boost/geometry/index/detail/distance_predicates.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/comparable_distance_far.hpp>", private, "<boost/geometry/index/detail/distance_predicates.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/comparable_distance_near.hpp>", private, "<boost/geometry/index/detail/distance_predicates.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/content.hpp>", private, "<boost/geometry/index/detail/algorithms/intersection_content.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/content.hpp>", private, "<boost/geometry/index/detail/algorithms/union_content.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/content.hpp>", private, "<boost/geometry/index/detail/rtree/linear/redistribute_elements.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/content.hpp>", private, "<boost/geometry/index/detail/rtree/quadratic/redistribute_elements.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/content.hpp>", private, "<boost/geometry/index/detail/rtree/rstar/choose_next_node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/content.hpp>", private, "<boost/geometry/index/detail/rtree/rstar/insert.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/content.hpp>", private, "<boost/geometry/index/detail/rtree/visitors/insert.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/diff_abs.hpp>", private, "<boost/geometry/index/detail/algorithms/comparable_distance_centroid.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/diff_abs.hpp>", private, "<boost/geometry/index/detail/algorithms/comparable_distance_far.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/diff_abs.hpp>", private, "<boost/geometry/index/detail/algorithms/minmaxdist.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/intersection_content.hpp>", private, "<boost/geometry/index/detail/rtree/rstar/choose_next_node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/intersection_content.hpp>", private, "<boost/geometry/index/detail/rtree/rstar/redistribute_elements.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/margin.hpp>", private, "<boost/geometry/index/detail/rtree/rstar/redistribute_elements.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/path_intersection.hpp>", private, "<boost/geometry/index/detail/distance_predicates.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/segment_intersection.hpp>", private, "<boost/geometry/index/detail/algorithms/path_intersection.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/smallest_for_indexable.hpp>", private, "<boost/geometry/index/detail/algorithms/minmaxdist.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/sum_for_indexable.hpp>", private, "<boost/geometry/index/detail/algorithms/comparable_distance_centroid.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/sum_for_indexable.hpp>", private, "<boost/geometry/index/detail/algorithms/comparable_distance_far.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/sum_for_indexable.hpp>", private, "<boost/geometry/index/detail/algorithms/comparable_distance_near.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/sum_for_indexable.hpp>", private, "<boost/geometry/index/detail/algorithms/minmaxdist.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/union_content.hpp>", private, "<boost/geometry/index/detail/rtree/quadratic/redistribute_elements.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/union_content.hpp>", private, "<boost/geometry/index/detail/rtree/rstar/choose_next_node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/algorithms/union_content.hpp>", private, "<boost/geometry/index/detail/rtree/rstar/redistribute_elements.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/assert.hpp>", private, "<boost/geometry/index/detail/pushable_array.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/assert.hpp>", private, "<boost/geometry/index/detail/varray.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/linear/redistribute_elements.hpp>", private, "<boost/geometry/index/detail/rtree/linear/linear.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/auto_deallocator.hpp>", private, "<boost/geometry/index/detail/rtree/node/node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/concept.hpp>", private, "<boost/geometry/index/detail/rtree/node/node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/dynamic_visitor.hpp>", private, "<boost/geometry/index/detail/rtree/node/node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/node_auto_ptr.hpp>", private, "<boost/geometry/index/detail/rtree/node/node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/node_d_mem_dynamic.hpp>", private, "<boost/geometry/index/detail/rtree/node/node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/node_d_mem_static.hpp>", private, "<boost/geometry/index/detail/rtree/node/node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/node.hpp>", private, "<boost/geometry/index/detail/rtree/linear/redistribute_elements.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/node.hpp>", private, "<boost/geometry/index/detail/rtree/quadratic/redistribute_elements.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/node.hpp>", private, "<boost/geometry/index/detail/rtree/rstar/choose_next_node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/node.hpp>", private, "<boost/geometry/index/detail/rtree/rstar/redistribute_elements.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/node.hpp>", private, "<boost/geometry/index/detail/rtree/utilities/are_boxes_ok.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/node.hpp>", private, "<boost/geometry/index/detail/rtree/utilities/are_levels_ok.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/node_s_mem_dynamic.hpp>", private, "<boost/geometry/index/detail/rtree/node/node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/node_s_mem_static.hpp>", private, "<boost/geometry/index/detail/rtree/node/node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/pairs.hpp>", private, "<boost/geometry/index/detail/rtree/node/node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/static_visitor.hpp>", private, "<boost/geometry/index/detail/rtree/node/node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/quadratic/redistribute_elements.hpp>", private, "<boost/geometry/index/detail/rtree/quadratic/quadratic.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/rstar/choose_next_node.hpp>", private, "<boost/geometry/index/detail/rtree/rstar/rstar.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/rstar/insert.hpp>", private, "<boost/geometry/index/detail/rtree/rstar/rstar.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/rstar/redistribute_elements.hpp>", private, "<boost/geometry/index/detail/rtree/rstar/rstar.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/destroy.hpp>", private, "<boost/geometry/index/detail/rtree/node/node_auto_ptr.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/insert.hpp>", private, "<boost/geometry/index/detail/rtree/linear/redistribute_elements.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/insert.hpp>", private, "<boost/geometry/index/detail/rtree/quadratic/redistribute_elements.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/insert.hpp>", private, "<boost/geometry/index/detail/rtree/rstar/redistribute_elements.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/is_leaf.hpp>", private, "<boost/geometry/index/detail/rtree/linear/redistribute_elements.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/is_leaf.hpp>", private, "<boost/geometry/index/detail/rtree/node/node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/is_leaf.hpp>", private, "<boost/geometry/index/detail/rtree/quadratic/redistribute_elements.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/is_leaf.hpp>", private, "<boost/geometry/index/detail/rtree/rstar/choose_next_node.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/is_leaf.hpp>", private, "<boost/geometry/index/detail/rtree/rstar/redistribute_elements.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/is_leaf.hpp>", private, "<boost/geometry/index/detail/rtree/visitors/remove.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/tags.hpp>", private, "<boost/geometry/index/detail/distance_predicates.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/tags.hpp>", private, "<boost/geometry/index/detail/predicates.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/varray_detail.hpp>", private, "<boost/geometry/index/detail/varray.hpp>", private ] },
+ { include: ["<boost/geometry/index/detail/varray.hpp>", private, "<boost/geometry/index/detail/rtree/node/node.hpp>", private ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/overlay/self_turn_points.hpp>", private, "<boost/geometry/algorithms/detail/has_self_intersections.hpp>", private ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/sections/range_by_section.hpp>", private, "<boost/geometry/multi/algorithms/detail/overlay/get_turns.hpp>", private ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/sections/sectionalize.hpp>", private, "<boost/geometry/multi/algorithms/detail/overlay/get_turns.hpp>", private ] },
+ { include: ["<boost/geometry/multi/views/detail/range_type.hpp>", private, "<boost/geometry/multi/algorithms/detail/overlay/get_turns.hpp>", private ] },
+ { include: ["<boost/geometry/views/detail/range_type.hpp>", private, "<boost/geometry/algorithms/detail/overlay/get_turns.hpp>", private ] },
+ { include: ["<boost/geometry/views/detail/range_type.hpp>", private, "<boost/geometry/multi/views/detail/range_type.hpp>", private ] },
+ { include: ["<boost/graph/detail/adj_list_edge_iterator.hpp>", private, "<boost/graph/detail/adjacency_list.hpp>", private ] },
+ { include: ["<boost/graph/detail/histogram_sort.hpp>", private, "<boost/graph/detail/compressed_sparse_row_struct.hpp>", private ] },
+ { include: ["<boost/graph/detail/indexed_properties.hpp>", private, "<boost/graph/detail/compressed_sparse_row_struct.hpp>", private ] },
+ { include: ["<boost/graph/detail/shadow_iterator.hpp>", private, "<boost/graph/detail/permutation.hpp>", private ] },
+ { include: ["<boost/graph/parallel/detail/untracked_pair.hpp>", private, "<boost/graph/parallel/detail/property_holders.hpp>", private ] },
+ { include: ["<boost/heap/detail/ordered_adaptor_iterator.hpp>", private, "<boost/heap/detail/mutable_heap.hpp>", private ] },
+ { include: ["<boost/heap/detail/tree_iterator.hpp>", private, "<boost/heap/detail/ordered_adaptor_iterator.hpp>", private ] },
+ { include: ["<boost/icl/detail/associated_value.hpp>", private, "<boost/icl/detail/interval_set_algo.hpp>", private ] },
+ { include: ["<boost/icl/detail/design_config.hpp>", private, "<boost/icl/detail/associated_value.hpp>", private ] },
+ { include: ["<boost/icl/detail/element_comparer.hpp>", private, "<boost/icl/detail/interval_map_algo.hpp>", private ] },
+ { include: ["<boost/icl/detail/element_comparer.hpp>", private, "<boost/icl/detail/interval_set_algo.hpp>", private ] },
+ { include: ["<boost/icl/detail/interval_subset_comparer.hpp>", private, "<boost/icl/detail/interval_map_algo.hpp>", private ] },
+ { include: ["<boost/icl/detail/interval_subset_comparer.hpp>", private, "<boost/icl/detail/interval_set_algo.hpp>", private ] },
+ { include: ["<boost/icl/detail/mapped_reference.hpp>", private, "<boost/icl/detail/element_iterator.hpp>", private ] },
+ { include: ["<boost/icl/detail/notate.hpp>", private, "<boost/icl/detail/element_comparer.hpp>", private ] },
+ { include: ["<boost/icl/detail/notate.hpp>", private, "<boost/icl/detail/interval_map_algo.hpp>", private ] },
+ { include: ["<boost/icl/detail/notate.hpp>", private, "<boost/icl/detail/interval_morphism.hpp>", private ] },
+ { include: ["<boost/icl/detail/notate.hpp>", private, "<boost/icl/detail/interval_set_algo.hpp>", private ] },
+ { include: ["<boost/icl/detail/notate.hpp>", private, "<boost/icl/detail/interval_subset_comparer.hpp>", private ] },
+ { include: ["<boost/icl/detail/notate.hpp>", private, "<boost/icl/detail/map_algo.hpp>", private ] },
+ { include: ["<boost/icl/detail/notate.hpp>", private, "<boost/icl/detail/set_algo.hpp>", private ] },
+ { include: ["<boost/icl/detail/notate.hpp>", private, "<boost/icl/detail/subset_comparer.hpp>", private ] },
+ { include: ["<boost/icl/detail/relation_state.hpp>", private, "<boost/icl/detail/interval_map_algo.hpp>", private ] },
+ { include: ["<boost/icl/detail/relation_state.hpp>", private, "<boost/icl/detail/interval_set_algo.hpp>", private ] },
+ { include: ["<boost/icl/detail/relation_state.hpp>", private, "<boost/icl/detail/interval_subset_comparer.hpp>", private ] },
+ { include: ["<boost/icl/detail/relation_state.hpp>", private, "<boost/icl/detail/subset_comparer.hpp>", private ] },
+ { include: ["<boost/icl/detail/set_algo.hpp>", private, "<boost/icl/detail/map_algo.hpp>", private ] },
+ { include: ["<boost/interprocess/allocators/detail/allocator_common.hpp>", private, "<boost/interprocess/allocators/detail/adaptive_node_pool.hpp>", private ] },
+ { include: ["<boost/interprocess/allocators/detail/allocator_common.hpp>", private, "<boost/interprocess/allocators/detail/node_pool.hpp>", private ] },
+ { include: ["<boost/interprocess/allocators/detail/node_tools.hpp>", private, "<boost/interprocess/allocators/detail/adaptive_node_pool.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/atomic.hpp>", private, "<boost/interprocess/detail/intermodule_singleton_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/atomic.hpp>", private, "<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/atomic.hpp>", private, "<boost/interprocess/detail/portable_intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/atomic.hpp>", private, "<boost/interprocess/detail/robust_emulation.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/atomic.hpp>", private, "<boost/interprocess/smart_ptr/detail/sp_counted_base_atomic.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/cast_tags.hpp>", private, "<boost/interprocess/detail/intersegment_ptr.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/allocators/detail/adaptive_node_pool.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/allocators/detail/allocator_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/allocators/detail/node_pool.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/allocators/detail/node_tools.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/atomic.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/cast_tags.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/file_locking_helpers.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/file_wrapper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/in_place_interface.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/intermodule_singleton_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/intersegment_ptr.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/managed_global_memory.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/managed_memory_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/managed_multi_shared_memory.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/min_max.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/multi_segment_services.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/named_proxy.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/os_file_functions.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/os_thread_functions.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/pointer_type.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/portable_intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/preprocessor.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/robust_emulation.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/segment_manager_helper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/tmp_dir_helpers.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/transform_iterator.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/type_traits.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/utilities.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/variadic_templates_tools.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/win32_api.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/windows_intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/workaround.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/xsi_shared_memory_device.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/detail/xsi_shared_memory_file_wrapper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/mem_algo/detail/mem_algo_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/mem_algo/detail/multi_simple_seq_fit.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/mem_algo/detail/multi_simple_seq_fit_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/mem_algo/detail/simple_seq_fit_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/smart_ptr/detail/bad_weak_ptr.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/smart_ptr/detail/shared_count.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/smart_ptr/detail/sp_counted_base_atomic.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/smart_ptr/detail/sp_counted_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/detail/condition_algorithm_8a.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/detail/condition_any_algorithm.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/detail/locks.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/allocators/detail/adaptive_node_pool.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/allocators/detail/allocator_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/allocators/detail/node_pool.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/allocators/detail/node_tools.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/atomic.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/cast_tags.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/file_locking_helpers.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/file_wrapper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/in_place_interface.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/intermodule_singleton_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/intersegment_ptr.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/managed_global_memory.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/managed_memory_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/managed_multi_shared_memory.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/min_max.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/multi_segment_services.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/named_proxy.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/os_file_functions.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/os_thread_functions.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/pointer_type.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/portable_intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/preprocessor.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/robust_emulation.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/segment_manager_helper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/tmp_dir_helpers.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/transform_iterator.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/type_traits.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/utilities.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/variadic_templates_tools.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/win32_api.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/windows_intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/workaround.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/xsi_shared_memory_device.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/detail/xsi_shared_memory_file_wrapper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/mem_algo/detail/mem_algo_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/mem_algo/detail/multi_simple_seq_fit.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/mem_algo/detail/multi_simple_seq_fit_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/mem_algo/detail/simple_seq_fit_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/smart_ptr/detail/bad_weak_ptr.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/smart_ptr/detail/shared_count.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/smart_ptr/detail/sp_counted_base_atomic.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/smart_ptr/detail/sp_counted_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/detail/condition_algorithm_8a.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/detail/condition_any_algorithm.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/detail/locks.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/file_locking_helpers.hpp>", private, "<boost/interprocess/detail/portable_intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/in_place_interface.hpp>", private, "<boost/interprocess/detail/named_proxy.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/in_place_interface.hpp>", private, "<boost/interprocess/detail/segment_manager_helper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/intermodule_singleton_common.hpp>", private, "<boost/interprocess/detail/portable_intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/intermodule_singleton_common.hpp>", private, "<boost/interprocess/detail/windows_intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/intermodule_singleton.hpp>", private, "<boost/interprocess/detail/robust_emulation.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/interprocess_tester.hpp>", private, "<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/managed_global_memory.hpp>", private, "<boost/interprocess/detail/portable_intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/managed_memory_impl.hpp>", private, "<boost/interprocess/detail/managed_global_memory.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/managed_memory_impl.hpp>", private, "<boost/interprocess/detail/managed_multi_shared_memory.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private, "<boost/interprocess/detail/managed_global_memory.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private, "<boost/interprocess/detail/managed_multi_shared_memory.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/math_functions.hpp>", private, "<boost/interprocess/allocators/detail/adaptive_node_pool.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/math_functions.hpp>", private, "<boost/interprocess/detail/intersegment_ptr.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/math_functions.hpp>", private, "<boost/interprocess/mem_algo/detail/mem_algo_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/min_max.hpp>", private, "<boost/interprocess/detail/utilities.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/min_max.hpp>", private, "<boost/interprocess/mem_algo/detail/mem_algo_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/min_max.hpp>", private, "<boost/interprocess/mem_algo/detail/multi_simple_seq_fit_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/min_max.hpp>", private, "<boost/interprocess/mem_algo/detail/simple_seq_fit_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/mpl.hpp>", private, "<boost/interprocess/detail/intermodule_singleton_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/mpl.hpp>", private, "<boost/interprocess/detail/intersegment_ptr.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/mpl.hpp>", private, "<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/mpl.hpp>", private, "<boost/interprocess/detail/named_proxy.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/mpl.hpp>", private, "<boost/interprocess/detail/utilities.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/multi_segment_services.hpp>", private, "<boost/interprocess/detail/intersegment_ptr.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/multi_segment_services.hpp>", private, "<boost/interprocess/detail/managed_multi_shared_memory.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/multi_segment_services.hpp>", private, "<boost/interprocess/mem_algo/detail/multi_simple_seq_fit_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/detail/file_locking_helpers.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/detail/file_wrapper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/detail/managed_memory_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/detail/portable_intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/detail/robust_emulation.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/detail/tmp_dir_helpers.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/detail/xsi_shared_memory_device.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/detail/xsi_shared_memory_file_wrapper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/os_thread_functions.hpp>", private, "<boost/interprocess/detail/intermodule_singleton_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/os_thread_functions.hpp>", private, "<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/os_thread_functions.hpp>", private, "<boost/interprocess/detail/portable_intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/portable_intermodule_singleton.hpp>", private, "<boost/interprocess/detail/intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/detail/os_thread_functions.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/preprocessor.hpp>", private, "<boost/interprocess/detail/named_proxy.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/segment_manager_helper.hpp>", private, "<boost/interprocess/allocators/detail/allocator_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/tmp_dir_helpers.hpp>", private, "<boost/interprocess/detail/file_locking_helpers.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/tmp_dir_helpers.hpp>", private, "<boost/interprocess/detail/portable_intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/tmp_dir_helpers.hpp>", private, "<boost/interprocess/detail/robust_emulation.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/tmp_dir_helpers.hpp>", private, "<boost/interprocess/detail/xsi_shared_memory_device.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/tmp_dir_helpers.hpp>", private, "<boost/interprocess/detail/xsi_shared_memory_file_wrapper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/transform_iterator.hpp>", private, "<boost/interprocess/detail/utilities.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/allocators/detail/adaptive_node_pool.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/allocators/detail/allocator_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/detail/in_place_interface.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/detail/pointer_type.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/detail/segment_manager_helper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/detail/transform_iterator.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/detail/utilities.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/detail/variadic_templates_tools.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/mem_algo/detail/mem_algo_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/mem_algo/detail/simple_seq_fit_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/allocators/detail/adaptive_node_pool.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/allocators/detail/allocator_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/allocators/detail/node_pool.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/detail/intersegment_ptr.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/detail/managed_memory_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/detail/managed_multi_shared_memory.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/detail/segment_manager_helper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/detail/xsi_shared_memory_device.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/detail/xsi_shared_memory_file_wrapper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/mem_algo/detail/mem_algo_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/mem_algo/detail/multi_simple_seq_fit_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/mem_algo/detail/simple_seq_fit_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/smart_ptr/detail/shared_count.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/smart_ptr/detail/sp_counted_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/variadic_templates_tools.hpp>", private, "<boost/interprocess/detail/named_proxy.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/win32_api.hpp>", private, "<boost/interprocess/detail/atomic.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/win32_api.hpp>", private, "<boost/interprocess/detail/os_file_functions.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/win32_api.hpp>", private, "<boost/interprocess/detail/os_thread_functions.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/windows_intermodule_singleton.hpp>", private, "<boost/interprocess/detail/intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/windows_intermodule_singleton.hpp>", private, "<boost/interprocess/detail/tmp_dir_helpers.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/allocators/detail/adaptive_node_pool.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/allocators/detail/allocator_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/allocators/detail/node_pool.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/allocators/detail/node_tools.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/atomic.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/cast_tags.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/file_locking_helpers.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/file_wrapper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/in_place_interface.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/intermodule_singleton_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/intersegment_ptr.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/managed_global_memory.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/managed_memory_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/managed_multi_shared_memory.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/min_max.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/multi_segment_services.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/named_proxy.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/os_file_functions.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/os_thread_functions.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/pointer_type.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/portable_intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/robust_emulation.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/segment_manager_helper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/tmp_dir_helpers.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/transform_iterator.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/utilities.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/variadic_templates_tools.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/win32_api.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/windows_intermodule_singleton.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/xsi_shared_memory_device.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/detail/xsi_shared_memory_file_wrapper.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/mem_algo/detail/mem_algo_common.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/mem_algo/detail/multi_simple_seq_fit.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/mem_algo/detail/multi_simple_seq_fit_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/mem_algo/detail/simple_seq_fit_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/smart_ptr/detail/bad_weak_ptr.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/smart_ptr/detail/shared_count.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/smart_ptr/detail/sp_counted_base_atomic.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/smart_ptr/detail/sp_counted_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/sync/detail/condition_algorithm_8a.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/sync/detail/condition_any_algorithm.hpp>", private ] },
+ { include: ["<boost/interprocess/detail/workaround.hpp>", private, "<boost/interprocess/sync/detail/locks.hpp>", private ] },
+ { include: ["<boost/interprocess/mem_algo/detail/mem_algo_common.hpp>", private, "<boost/interprocess/allocators/detail/adaptive_node_pool.hpp>", private ] },
+ { include: ["<boost/interprocess/mem_algo/detail/mem_algo_common.hpp>", private, "<boost/interprocess/allocators/detail/allocator_common.hpp>", private ] },
+ { include: ["<boost/interprocess/mem_algo/detail/mem_algo_common.hpp>", private, "<boost/interprocess/mem_algo/detail/simple_seq_fit_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/mem_algo/detail/simple_seq_fit_impl.hpp>", private, "<boost/interprocess/mem_algo/detail/multi_simple_seq_fit.hpp>", private ] },
+ { include: ["<boost/interprocess/smart_ptr/detail/bad_weak_ptr.hpp>", private, "<boost/interprocess/smart_ptr/detail/shared_count.hpp>", private ] },
+ { include: ["<boost/interprocess/smart_ptr/detail/sp_counted_base_atomic.hpp>", private, "<boost/interprocess/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/interprocess/smart_ptr/detail/sp_counted_base.hpp>", private, "<boost/interprocess/smart_ptr/detail/sp_counted_impl.hpp>", private ] },
+ { include: ["<boost/interprocess/smart_ptr/detail/sp_counted_impl.hpp>", private, "<boost/interprocess/smart_ptr/detail/shared_count.hpp>", private ] },
+ { include: ["<boost/interprocess/sync/detail/locks.hpp>", private, "<boost/interprocess/sync/detail/condition_algorithm_8a.hpp>", private ] },
+ { include: ["<boost/interprocess/sync/detail/locks.hpp>", private, "<boost/interprocess/sync/detail/condition_any_algorithm.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/detail/any_node_and_algorithms.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/detail/common_slist_algorithms.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/detail/hashtable_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/detail/list_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/detail/slist_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/detail/utilities.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/any_node_and_algorithms.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/avltree_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/clear_on_destructor_base.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/common_slist_algorithms.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/ebo_functor_holder.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/function_detector.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/generic_hook.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/hashtable_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/has_member_function_callable_with.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/is_stateful_value_traits.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/list_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/memory_util.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/mpl.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/parent_from_member.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/preprocessor.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/rbtree_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/slist_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/transform_iterator.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/tree_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/utilities.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/detail/workaround.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/any_node_and_algorithms.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/avltree_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/clear_on_destructor_base.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/common_slist_algorithms.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/ebo_functor_holder.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/function_detector.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/generic_hook.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/has_member_function_callable_with.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/is_stateful_value_traits.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/list_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/memory_util.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/mpl.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/parent_from_member.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/preprocessor.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/rbtree_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/slist_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/transform_iterator.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/tree_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/utilities.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/detail/workaround.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/ebo_functor_holder.hpp>", private, "<boost/intrusive/detail/utilities.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/function_detector.hpp>", private, "<boost/intrusive/detail/is_stateful_value_traits.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/has_member_function_callable_with.hpp>", private, "<boost/container/detail/memory_util.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/has_member_function_callable_with.hpp>", private, "<boost/intrusive/detail/memory_util.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/is_stateful_value_traits.hpp>", private, "<boost/intrusive/detail/utilities.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/detail/any_node_and_algorithms.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/detail/avltree_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/detail/ebo_functor_holder.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/detail/generic_hook.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/detail/hashtable_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/detail/has_member_function_callable_with.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/detail/is_stateful_value_traits.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/detail/memory_util.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/detail/rbtree_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/detail/transform_iterator.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/detail/tree_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/detail/utilities.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/parent_from_member.hpp>", private, "<boost/intrusive/detail/utilities.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/preprocessor.hpp>", private, "<boost/intrusive/detail/has_member_function_callable_with.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/preprocessor.hpp>", private, "<boost/intrusive/detail/memory_util.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/tree_node.hpp>", private, "<boost/intrusive/detail/rbtree_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/detail/generic_hook.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/detail/hashtable_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/detail/rbtree_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/detail/slist_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/detail/tree_node.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/workaround.hpp>", private, "<boost/intrusive/detail/has_member_function_callable_with.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/workaround.hpp>", private, "<boost/intrusive/detail/memory_util.hpp>", private ] },
+ { include: ["<boost/intrusive/detail/workaround.hpp>", private, "<boost/intrusive/detail/preprocessor.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/access_control.hpp>", private, "<boost/iostreams/detail/streambuf/chainbuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/adapter/concept_adapter.hpp>", private, "<boost/iostreams/detail/streambuf/indirect_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/adapter/device_adapter.hpp>", private, "<boost/iostreams/detail/restrict_impl.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/adapter/filter_adapter.hpp>", private, "<boost/iostreams/detail/restrict_impl.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/adapter/mode_adapter.hpp>", private, "<boost/iostreams/detail/resolve.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/adapter/non_blocking_adapter.hpp>", private, "<boost/iostreams/detail/adapter/concept_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/adapter/output_iterator_adapter.hpp>", private, "<boost/iostreams/detail/resolve.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/adapter/range_adapter.hpp>", private, "<boost/iostreams/detail/push.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/adapter/range_adapter.hpp>", private, "<boost/iostreams/detail/resolve.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/bool_trait_def.hpp>", private, "<boost/iostreams/detail/is_iterator_range.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/broken_overload_resolution/forward.hpp>", private, "<boost/iostreams/detail/broken_overload_resolution/stream_buffer.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/broken_overload_resolution/forward.hpp>", private, "<boost/iostreams/detail/broken_overload_resolution/stream.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/buffer.hpp>", private, "<boost/iostreams/detail/current_directory.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/buffer.hpp>", private, "<boost/iostreams/detail/streambuf/indirect_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/call_traits.hpp>", private, "<boost/iostreams/detail/adapter/concept_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/call_traits.hpp>", private, "<boost/iostreams/detail/adapter/device_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/call_traits.hpp>", private, "<boost/iostreams/detail/adapter/filter_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/call_traits.hpp>", private, "<boost/iostreams/detail/restrict_impl.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/detail/adapter/concept_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/detail/counted_array.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/detail/streambuf/direct_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/detail/streambuf/linked_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/codecvt.hpp>", private, "<boost/iostreams/detail/codecvt_helper.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/codecvt.hpp>", private, "<boost/iostreams/detail/codecvt_holder.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/detail/adapter/concept_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/detail/adapter/direct_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/detail/adapter/range_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/detail/current_directory.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/detail/is_iterator_range.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/detail/resolve.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/detail/restrict_impl.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/detail/streambuf/direct_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/detail/streambuf/indirect_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/detail/streambuf/linked_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/detail/adapter/concept_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/detail/adapter/direct_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/detail/adapter/range_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/detail/current_directory.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/detail/is_iterator_range.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/detail/resolve.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/detail/streambuf/direct_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/detail/streambuf/indirect_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/detail/streambuf/linked_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/gcc.hpp>", private, "<boost/iostreams/detail/config/overload_resolution.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/gcc.hpp>", private, "<boost/iostreams/detail/resolve.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/limits.hpp>", private, "<boost/iostreams/detail/adapter/direct_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/limits.hpp>", private, "<boost/iostreams/detail/execute.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/limits.hpp>", private, "<boost/iostreams/detail/forward.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/overload_resolution.hpp>", private, "<boost/iostreams/detail/resolve.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/unreachable_return.hpp>", private, "<boost/iostreams/detail/adapter/concept_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/detail/adapter/direct_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/detail/char_traits.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/detail/config/codecvt.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/detail/fstream.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/detail/ios.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/detail/iostream.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/detail/path.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/detail/push.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/detail/resolve.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/detail/streambuf/chainbuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/detail/streambuf/direct_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/detail/streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/detail/streambuf/indirect_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/windows_posix.hpp>", private, "<boost/iostreams/detail/absolute_path.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/windows_posix.hpp>", private, "<boost/iostreams/detail/config/rtl.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/windows_posix.hpp>", private, "<boost/iostreams/detail/current_directory.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/windows_posix.hpp>", private, "<boost/iostreams/detail/file_handle.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/config/windows_posix.hpp>", private, "<boost/iostreams/detail/system_failure.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/current_directory.hpp>", private, "<boost/iostreams/detail/absolute_path.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/dispatch.hpp>", private, "<boost/iostreams/detail/adapter/concept_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/double_object.hpp>", private, "<boost/iostreams/detail/adapter/direct_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/double_object.hpp>", private, "<boost/iostreams/detail/streambuf/indirect_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/enable_if_stream.hpp>", private, "<boost/iostreams/detail/push.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/enable_if_stream.hpp>", private, "<boost/iostreams/detail/resolve.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/enable_if_stream.hpp>", private, "<boost/iostreams/detail/restrict_impl.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/enable_if_stream.hpp>", private, "<boost/iostreams/detail/wrap_unwrap.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/error.hpp>", private, "<boost/iostreams/detail/adapter/concept_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/error.hpp>", private, "<boost/iostreams/detail/adapter/direct_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/error.hpp>", private, "<boost/iostreams/detail/adapter/range_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/error.hpp>", private, "<boost/iostreams/detail/restrict_impl.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/error.hpp>", private, "<boost/iostreams/detail/streambuf/direct_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/execute.hpp>", private, "<boost/iostreams/detail/streambuf/direct_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/execute.hpp>", private, "<boost/iostreams/detail/streambuf/indirect_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/functional.hpp>", private, "<boost/iostreams/detail/streambuf/direct_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/functional.hpp>", private, "<boost/iostreams/detail/streambuf/indirect_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/detail/adapter/device_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/detail/adapter/direct_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/detail/adapter/filter_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/detail/adapter/mode_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/detail/adapter/non_blocking_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/detail/buffer.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/detail/counted_array.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/detail/error.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/detail/functional.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/detail/restrict_impl.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/detail/streambuf/direct_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/detail/streambuf/indirect_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/detail/streambuf/linked_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/detail/system_failure.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/is_dereferenceable.hpp>", private, "<boost/iostreams/detail/resolve.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/is_iterator_range.hpp>", private, "<boost/iostreams/detail/resolve.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/optional.hpp>", private, "<boost/iostreams/detail/streambuf/direct_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/optional.hpp>", private, "<boost/iostreams/detail/streambuf/indirect_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/push.hpp>", private, "<boost/iostreams/detail/streambuf/indirect_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/push_params.hpp>", private, "<boost/iostreams/detail/forward.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/push_params.hpp>", private, "<boost/iostreams/detail/push.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/resolve.hpp>", private, "<boost/iostreams/detail/push.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/select.hpp>", private, "<boost/iostreams/detail/access_control.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/select.hpp>", private, "<boost/iostreams/detail/dispatch.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/select.hpp>", private, "<boost/iostreams/detail/resolve.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/select.hpp>", private, "<boost/iostreams/detail/restrict_impl.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/streambuf.hpp>", private, "<boost/iostreams/detail/adapter/concept_adapter.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/streambuf.hpp>", private, "<boost/iostreams/detail/streambuf/chainbuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/streambuf.hpp>", private, "<boost/iostreams/detail/streambuf/direct_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/streambuf.hpp>", private, "<boost/iostreams/detail/streambuf/linked_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/streambuf/linked_streambuf.hpp>", private, "<boost/iostreams/detail/streambuf/chainbuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/streambuf/linked_streambuf.hpp>", private, "<boost/iostreams/detail/streambuf/direct_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/streambuf/linked_streambuf.hpp>", private, "<boost/iostreams/detail/streambuf/indirect_streambuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/system_failure.hpp>", private, "<boost/iostreams/detail/current_directory.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/template_params.hpp>", private, "<boost/iostreams/detail/bool_trait_def.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/translate_int_type.hpp>", private, "<boost/iostreams/detail/streambuf/chainbuf.hpp>", private ] },
+ { include: ["<boost/iostreams/detail/wrap_unwrap.hpp>", private, "<boost/iostreams/detail/resolve.hpp>", private ] },
+ { include: ["<boost/iterator/detail/config_def.hpp>", private, "<boost/iterator/detail/enable_if.hpp>", private ] },
+ { include: ["<boost/iterator/detail/config_def.hpp>", private, "<boost/iterator/detail/facade_iterator_category.hpp>", private ] },
+ { include: ["<boost/iterator/detail/config_undef.hpp>", private, "<boost/iterator/detail/enable_if.hpp>", private ] },
+ { include: ["<boost/iterator/detail/config_undef.hpp>", private, "<boost/iterator/detail/facade_iterator_category.hpp>", private ] },
+ { include: ["<boost/iterator/detail/enable_if.hpp>", private, "<boost/bimap/detail/map_view_iterator.hpp>", private ] },
+ { include: ["<boost/iterator/detail/enable_if.hpp>", private, "<boost/bimap/detail/set_view_iterator.hpp>", private ] },
+ { include: ["<boost/lambda/detail/is_instance_of.hpp>", private, "<boost/lambda/detail/operator_return_type_traits.hpp>", private ] },
+ { include: ["<boost/lambda/detail/is_instance_of.hpp>", private, "<boost/lambda/detail/operators.hpp>", private ] },
+ { include: ["<boost/lambda/detail/lambda_fwd.hpp>", private, "<boost/lambda/detail/lambda_functor_base.hpp>", private ] },
+ { include: ["<boost/lambda/detail/lambda_traits.hpp>", private, "<boost/lambda/detail/lambda_functor_base.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/bind.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/const_bind.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/const.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/const_bind.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/add.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/auto.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/add.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/bind.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/add.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/const_bind.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/add.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/const.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/add.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/default.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/add.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/inline.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/add.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/recursive.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/add.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/register.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/add.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/return.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/add.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/this.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/add.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/thisunderscore.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/add.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/void.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/is.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/auto.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/is.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/bind.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/is.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/const.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/is.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/default.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/is.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/inline.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/is.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/recursive.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/is.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/register.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/is.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/return.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/is.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/this.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/is.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/thisunderscore.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/is.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/void.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/remove.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/auto.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/remove.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/bind.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/remove.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/const.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/remove.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/default.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/remove.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/inline.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/remove.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/recursive.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/remove.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/register.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/remove.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/return.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/remove.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/this.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/remove.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/thisunderscore.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/facility/remove.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/void.hpp>", private ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/void.hpp>", private, "<boost/local_function/detail/preprocessor/void_list.hpp>", private ] },
+ { include: ["<boost/lockfree/detail/atomic.hpp>", private, "<boost/lockfree/detail/freelist.hpp>", private ] },
+ { include: ["<boost/lockfree/detail/branch_hints.hpp>", private, "<boost/lockfree/detail/tagged_ptr_dcas.hpp>", private ] },
+ { include: ["<boost/lockfree/detail/branch_hints.hpp>", private, "<boost/lockfree/detail/tagged_ptr_ptrcompression.hpp>", private ] },
+ { include: ["<boost/lockfree/detail/parameter.hpp>", private, "<boost/lockfree/detail/freelist.hpp>", private ] },
+ { include: ["<boost/lockfree/detail/prefix.hpp>", private, "<boost/lockfree/detail/tagged_ptr.hpp>", private ] },
+ { include: ["<boost/lockfree/detail/tagged_ptr_dcas.hpp>", private, "<boost/lockfree/detail/tagged_ptr.hpp>", private ] },
+ { include: ["<boost/lockfree/detail/tagged_ptr.hpp>", private, "<boost/lockfree/detail/freelist.hpp>", private ] },
+ { include: ["<boost/lockfree/detail/tagged_ptr_ptrcompression.hpp>", private, "<boost/lockfree/detail/tagged_ptr.hpp>", private ] },
+ { include: ["<boost/log/detail/attr_output_terminal.hpp>", private, "<boost/log/detail/attr_output_impl.hpp>", private ] },
+ { include: ["<boost/log/detail/cleanup_scope_guard.hpp>", private, "<boost/log/detail/format.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/asio_fwd.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/attachable_sstream_buf.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/attribute_get_value_impl.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/attribute_predicate.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/attr_output_impl.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/attr_output_terminal.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/cleanup_scope_guard.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/code_conversion.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/custom_terminal_spec.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/date_time_fmt_gen_traits_fwd.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/date_time_format_parser.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/decomposed_time.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/deduce_char_type.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/default_attribute_names.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/embedded_string_type.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/event.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/fake_mutex.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/format.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/function_traits.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/id.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/light_function.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/light_rw_mutex.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/locking_ptr.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/locks.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/malloc_aligned.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/native_typeof.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/parameter_tools.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/pp_identity.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/process_id.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/setup_config.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/singleton.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/sink_init_helpers.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/snprintf.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/spin_mutex.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/tagged_integer.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/thread_id.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/threadsafe_queue.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/thread_specific.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/timestamp.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/trivial_keyword.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/unary_function_terminal.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/unhandled_exception_count.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/value_ref_visitation.hpp>", private ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/detail/visible_type.hpp>", private ] },
+ { include: ["<boost/log/detail/custom_terminal_spec.hpp>", private, "<boost/log/detail/attr_output_terminal.hpp>", private ] },
+ { include: ["<boost/log/detail/custom_terminal_spec.hpp>", private, "<boost/log/detail/unary_function_terminal.hpp>", private ] },
+ { include: ["<boost/log/detail/date_time_format_parser.hpp>", private, "<boost/log/detail/decomposed_time.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/attachable_sstream_buf.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/attribute_get_value_impl.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/attribute_predicate.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/attr_output_impl.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/attr_output_terminal.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/cleanup_scope_guard.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/code_conversion.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/custom_terminal_spec.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/date_time_format_parser.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/decomposed_time.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/deduce_char_type.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/default_attribute_names.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/embedded_string_type.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/event.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/fake_mutex.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/format.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/function_traits.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/id.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/light_function.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/light_rw_mutex.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/locking_ptr.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/locks.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/malloc_aligned.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/parameter_tools.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/process_id.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/singleton.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/sink_init_helpers.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/snprintf.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/spin_mutex.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/tagged_integer.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/thread_id.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/threadsafe_queue.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/thread_specific.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/timestamp.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/trivial_keyword.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/unary_function_terminal.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/value_ref_visitation.hpp>", private ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/detail/visible_type.hpp>", private ] },
+ { include: ["<boost/log/detail/generate_overloads.hpp>", private, "<boost/log/detail/attr_output_impl.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/attachable_sstream_buf.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/attribute_get_value_impl.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/attribute_predicate.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/attr_output_impl.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/attr_output_terminal.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/cleanup_scope_guard.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/code_conversion.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/custom_terminal_spec.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/date_time_format_parser.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/decomposed_time.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/deduce_char_type.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/default_attribute_names.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/embedded_string_type.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/event.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/fake_mutex.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/format.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/function_traits.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/id.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/light_function.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/light_rw_mutex.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/locking_ptr.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/locks.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/malloc_aligned.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/parameter_tools.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/process_id.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/singleton.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/sink_init_helpers.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/snprintf.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/spin_mutex.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/tagged_integer.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/thread_id.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/threadsafe_queue.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/thread_specific.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/timestamp.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/trivial_keyword.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/unary_function_terminal.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/value_ref_visitation.hpp>", private ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/detail/visible_type.hpp>", private ] },
+ { include: ["<boost/log/detail/id.hpp>", private, "<boost/log/detail/process_id.hpp>", private ] },
+ { include: ["<boost/log/detail/id.hpp>", private, "<boost/log/detail/thread_id.hpp>", private ] },
+ { include: ["<boost/log/detail/unhandled_exception_count.hpp>", private, "<boost/log/detail/format.hpp>", private ] },
+ { include: ["<boost/math/bindings/detail/big_lanczos.hpp>", private, "<boost/multiprecision/detail/big_lanczos.hpp>", private ] },
+ { include: ["<boost/math/distributions/detail/hypergeometric_pdf.hpp>", private, "<boost/math/distributions/detail/hypergeometric_cdf.hpp>", private ] },
+ { include: ["<boost/math/distributions/detail/hypergeometric_pdf.hpp>", private, "<boost/math/distributions/detail/hypergeometric_quantile.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/airy_ai_bi_zero.hpp>", private, "<boost/math/special_functions/detail/bessel_jy_zero.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/bessel_j0.hpp>", private, "<boost/math/special_functions/detail/bessel_jn.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/bessel_j0.hpp>", private, "<boost/math/special_functions/detail/bessel_y0.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/bessel_j1.hpp>", private, "<boost/math/special_functions/detail/bessel_jn.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/bessel_j1.hpp>", private, "<boost/math/special_functions/detail/bessel_y1.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/bessel_jy_asym.hpp>", private, "<boost/math/special_functions/detail/bessel_jn.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/bessel_jy_asym.hpp>", private, "<boost/math/special_functions/detail/bessel_jy.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/bessel_jy.hpp>", private, "<boost/math/special_functions/detail/bessel_jn.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/bessel_jy_series.hpp>", private, "<boost/math/special_functions/detail/bessel_jn.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/bessel_jy_series.hpp>", private, "<boost/math/special_functions/detail/bessel_jy.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/bessel_jy_series.hpp>", private, "<boost/math/special_functions/detail/bessel_yn.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/bessel_k0.hpp>", private, "<boost/math/special_functions/detail/bessel_kn.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/bessel_k1.hpp>", private, "<boost/math/special_functions/detail/bessel_kn.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/bessel_y0.hpp>", private, "<boost/math/special_functions/detail/bessel_yn.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/bessel_y1.hpp>", private, "<boost/math/special_functions/detail/bessel_yn.hpp>", private ] },
+ { include: ["<boost/math/special_functions/detail/t_distribution_inv.hpp>", private, "<boost/math/special_functions/detail/ibeta_inverse.hpp>", private ] },
+ { include: ["<boost/move/detail/config_begin.hpp>", private, "<boost/move/detail/meta_utils.hpp>", private ] },
+ { include: ["<boost/move/detail/config_end.hpp>", private, "<boost/move/detail/meta_utils.hpp>", private ] },
+ { include: ["<boost/mpi/detail/forward_skeleton_oarchive.hpp>", private, "<boost/mpi/detail/text_skeleton_oarchive.hpp>", private ] },
+ { include: ["<boost/mpi/detail/ignore_oprimitive.hpp>", private, "<boost/mpi/detail/text_skeleton_oarchive.hpp>", private ] },
+ { include: ["<boost/mpi/detail/ignore_skeleton_oarchive.hpp>", private, "<boost/mpi/detail/content_oarchive.hpp>", private ] },
+ { include: ["<boost/mpi/detail/ignore_skeleton_oarchive.hpp>", private, "<boost/mpi/detail/mpi_datatype_oarchive.hpp>", private ] },
+ { include: ["<boost/mpi/detail/mpi_datatype_oarchive.hpp>", private, "<boost/mpi/detail/mpi_datatype_cache.hpp>", private ] },
+ { include: ["<boost/mpi/detail/mpi_datatype_primitive.hpp>", private, "<boost/mpi/detail/content_oarchive.hpp>", private ] },
+ { include: ["<boost/mpi/detail/mpi_datatype_primitive.hpp>", private, "<boost/mpi/detail/mpi_datatype_oarchive.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/access_specifier.hpp>", private, "<boost/multi_index/detail/safe_mode.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/adl_swap.hpp>", private, "<boost/multi_index/detail/auto_space.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/auto_space.hpp>", private, "<boost/multi_index/detail/bucket_array.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/auto_space.hpp>", private, "<boost/multi_index/detail/copy_map.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/auto_space.hpp>", private, "<boost/multi_index/detail/index_loader.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/auto_space.hpp>", private, "<boost/multi_index/detail/index_matcher.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/auto_space.hpp>", private, "<boost/multi_index/detail/rnd_index_loader.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/auto_space.hpp>", private, "<boost/multi_index/detail/rnd_index_ptr_array.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/copy_map.hpp>", private, "<boost/multi_index/detail/index_base.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/do_not_copy_elements_tag.hpp>", private, "<boost/multi_index/detail/index_base.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/hash_index_node.hpp>", private, "<boost/multi_index/detail/bucket_array.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/header_holder.hpp>", private, "<boost/multi_index/detail/node_type.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/index_base.hpp>", private, "<boost/multi_index/detail/base_type.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/index_loader.hpp>", private, "<boost/multi_index/detail/index_base.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/index_matcher.hpp>", private, "<boost/multi_index/detail/index_saver.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/index_node_base.hpp>", private, "<boost/multi_index/detail/node_type.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/index_saver.hpp>", private, "<boost/multi_index/detail/index_base.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/is_index_list.hpp>", private, "<boost/multi_index/detail/base_type.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/is_index_list.hpp>", private, "<boost/multi_index/detail/node_type.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/iter_adaptor.hpp>", private, "<boost/multi_index/detail/safe_mode.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/msvc_index_specifier.hpp>", private, "<boost/multi_index/detail/base_type.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/msvc_index_specifier.hpp>", private, "<boost/multi_index/detail/node_type.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/node_type.hpp>", private, "<boost/multi_index/detail/index_base.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/prevent_eti.hpp>", private, "<boost/multi_index/detail/auto_space.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/prevent_eti.hpp>", private, "<boost/multi_index/detail/bucket_array.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/prevent_eti.hpp>", private, "<boost/multi_index/detail/copy_map.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/prevent_eti.hpp>", private, "<boost/multi_index/detail/hash_index_node.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/prevent_eti.hpp>", private, "<boost/multi_index/detail/iter_adaptor.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/prevent_eti.hpp>", private, "<boost/multi_index/detail/ord_index_node.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/prevent_eti.hpp>", private, "<boost/multi_index/detail/rnd_index_loader.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/prevent_eti.hpp>", private, "<boost/multi_index/detail/rnd_index_node.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/prevent_eti.hpp>", private, "<boost/multi_index/detail/rnd_index_ptr_array.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/prevent_eti.hpp>", private, "<boost/multi_index/detail/seq_index_node.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/rnd_index_node.hpp>", private, "<boost/multi_index/detail/rnd_index_ptr_array.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/rnd_index_ptr_array.hpp>", private, "<boost/multi_index/detail/rnd_index_loader.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/rnd_index_ptr_array.hpp>", private, "<boost/multi_index/detail/rnd_index_ops.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/safe_mode.hpp>", private, "<boost/multi_index/detail/safe_ctr_proxy.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/scope_guard.hpp>", private, "<boost/signals2/detail/auto_buffer.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/seq_index_node.hpp>", private, "<boost/multi_index/detail/seq_index_ops.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/uintptr_type.hpp>", private, "<boost/multi_index/detail/ord_index_node.hpp>", private ] },
+ { include: ["<boost/multi_index/detail/vartempl_support.hpp>", private, "<boost/multi_index/detail/index_base.hpp>", private ] },
+ { include: ["<boost/multiprecision/detail/default_ops.hpp>", private, "<boost/multiprecision/detail/generic_interconvert.hpp>", private ] },
+ { include: ["<boost/multiprecision/detail/et_ops.hpp>", private, "<boost/multiprecision/detail/default_ops.hpp>", private ] },
+ { include: ["<boost/multiprecision/detail/functions/constants.hpp>", private, "<boost/multiprecision/detail/default_ops.hpp>", private ] },
+ { include: ["<boost/multiprecision/detail/functions/pow.hpp>", private, "<boost/multiprecision/detail/default_ops.hpp>", private ] },
+ { include: ["<boost/multiprecision/detail/functions/trig.hpp>", private, "<boost/multiprecision/detail/default_ops.hpp>", private ] },
+ { include: ["<boost/multiprecision/detail/no_et_ops.hpp>", private, "<boost/multiprecision/detail/default_ops.hpp>", private ] },
+ { include: ["<boost/multiprecision/detail/number_base.hpp>", private, "<boost/multiprecision/detail/default_ops.hpp>", private ] },
+ { include: ["<boost/multiprecision/detail/rebind.hpp>", private, "<boost/multiprecision/detail/dynamic_array.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/conversion_traits.hpp>", private, "<boost/numeric/conversion/detail/converter.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/int_float_mixture.hpp>", private, "<boost/numeric/conversion/detail/conversion_traits.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/int_float_mixture.hpp>", private, "<boost/numeric/conversion/detail/is_subranged.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/is_subranged.hpp>", private, "<boost/numeric/conversion/detail/conversion_traits.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/meta.hpp>", private, "<boost/numeric/conversion/detail/conversion_traits.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/meta.hpp>", private, "<boost/numeric/conversion/detail/converter.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/meta.hpp>", private, "<boost/numeric/conversion/detail/int_float_mixture.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/meta.hpp>", private, "<boost/numeric/conversion/detail/is_subranged.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/meta.hpp>", private, "<boost/numeric/conversion/detail/sign_mixture.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/meta.hpp>", private, "<boost/numeric/conversion/detail/udt_builtin_mixture.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/preprocessed/numeric_cast_traits_common.hpp>", private, "<boost/numeric/conversion/detail/numeric_cast_traits.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/preprocessed/numeric_cast_traits_long_long.hpp>", private, "<boost/numeric/conversion/detail/numeric_cast_traits.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/sign_mixture.hpp>", private, "<boost/numeric/conversion/detail/conversion_traits.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/sign_mixture.hpp>", private, "<boost/numeric/conversion/detail/is_subranged.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/udt_builtin_mixture.hpp>", private, "<boost/numeric/conversion/detail/conversion_traits.hpp>", private ] },
+ { include: ["<boost/numeric/conversion/detail/udt_builtin_mixture.hpp>", private, "<boost/numeric/conversion/detail/is_subranged.hpp>", private ] },
+ { include: ["<boost/numeric/interval/detail/bcc_rounding_control.hpp>", private, "<boost/numeric/interval/detail/x86_rounding_control.hpp>", private ] },
+ { include: ["<boost/numeric/interval/detail/bugs.hpp>", private, "<boost/numeric/interval/detail/division.hpp>", private ] },
+ { include: ["<boost/numeric/interval/detail/c99sub_rounding_control.hpp>", private, "<boost/numeric/interval/detail/c99_rounding_control.hpp>", private ] },
+ { include: ["<boost/numeric/interval/detail/c99sub_rounding_control.hpp>", private, "<boost/numeric/interval/detail/x86_rounding_control.hpp>", private ] },
+ { include: ["<boost/numeric/interval/detail/interval_prototype.hpp>", private, "<boost/numeric/interval/detail/division.hpp>", private ] },
+ { include: ["<boost/numeric/interval/detail/interval_prototype.hpp>", private, "<boost/numeric/interval/detail/test_input.hpp>", private ] },
+ { include: ["<boost/numeric/interval/detail/msvc_rounding_control.hpp>", private, "<boost/numeric/interval/detail/x86_rounding_control.hpp>", private ] },
+ { include: ["<boost/numeric/interval/detail/test_input.hpp>", private, "<boost/numeric/interval/detail/division.hpp>", private ] },
+ { include: ["<boost/numeric/interval/detail/x86gcc_rounding_control.hpp>", private, "<boost/numeric/interval/detail/x86_rounding_control.hpp>", private ] },
+ { include: ["<boost/numeric/odeint/integrate/detail/integrate_adaptive.hpp>", private, "<boost/numeric/odeint/integrate/detail/integrate_const.hpp>", private ] },
+ { include: ["<boost/numeric/odeint/stepper/detail/generic_rk_call_algebra.hpp>", private, "<boost/numeric/odeint/stepper/detail/generic_rk_algorithm.hpp>", private ] },
+ { include: ["<boost/numeric/odeint/stepper/detail/generic_rk_operations.hpp>", private, "<boost/numeric/odeint/stepper/detail/generic_rk_algorithm.hpp>", private ] },
+ { include: ["<boost/numeric/odeint/util/detail/less_with_sign.hpp>", private, "<boost/numeric/odeint/integrate/detail/integrate_adaptive.hpp>", private ] },
+ { include: ["<boost/numeric/odeint/util/detail/less_with_sign.hpp>", private, "<boost/numeric/odeint/integrate/detail/integrate_const.hpp>", private ] },
+ { include: ["<boost/numeric/odeint/util/detail/less_with_sign.hpp>", private, "<boost/numeric/odeint/integrate/detail/integrate_n_steps.hpp>", private ] },
+ { include: ["<boost/numeric/odeint/util/detail/less_with_sign.hpp>", private, "<boost/numeric/odeint/integrate/detail/integrate_times.hpp>", private ] },
+ { include: ["<boost/numeric/ublas/detail/definitions.hpp>", private, "<boost/numeric/ublas/detail/config.hpp>", private ] },
+ { include: ["<boost/phoenix/bind/detail/preprocessed/function_ptr_10.hpp>", private, "<boost/phoenix/bind/detail/preprocessed/function_ptr.hpp>", private ] },
+ { include: ["<boost/phoenix/bind/detail/preprocessed/function_ptr_20.hpp>", private, "<boost/phoenix/bind/detail/preprocessed/function_ptr.hpp>", private ] },
+ { include: ["<boost/phoenix/bind/detail/preprocessed/function_ptr_30.hpp>", private, "<boost/phoenix/bind/detail/preprocessed/function_ptr.hpp>", private ] },
+ { include: ["<boost/phoenix/bind/detail/preprocessed/function_ptr_40.hpp>", private, "<boost/phoenix/bind/detail/preprocessed/function_ptr.hpp>", private ] },
+ { include: ["<boost/phoenix/bind/detail/preprocessed/function_ptr_50.hpp>", private, "<boost/phoenix/bind/detail/preprocessed/function_ptr.hpp>", private ] },
+ { include: ["<boost/phoenix/bind/detail/preprocessed/function_ptr.hpp>", private, "<boost/phoenix/bind/detail/function_ptr.hpp>", private ] },
+ { include: ["<boost/phoenix/bind/detail/preprocessed/member_function_ptr_10.hpp>", private, "<boost/phoenix/bind/detail/preprocessed/member_function_ptr.hpp>", private ] },
+ { include: ["<boost/phoenix/bind/detail/preprocessed/member_function_ptr_20.hpp>", private, "<boost/phoenix/bind/detail/preprocessed/member_function_ptr.hpp>", private ] },
+ { include: ["<boost/phoenix/bind/detail/preprocessed/member_function_ptr_30.hpp>", private, "<boost/phoenix/bind/detail/preprocessed/member_function_ptr.hpp>", private ] },
+ { include: ["<boost/phoenix/bind/detail/preprocessed/member_function_ptr_40.hpp>", private, "<boost/phoenix/bind/detail/preprocessed/member_function_ptr.hpp>", private ] },
+ { include: ["<boost/phoenix/bind/detail/preprocessed/member_function_ptr_50.hpp>", private, "<boost/phoenix/bind/detail/preprocessed/member_function_ptr.hpp>", private ] },
+ { include: ["<boost/phoenix/bind/detail/preprocessed/member_function_ptr.hpp>", private, "<boost/phoenix/bind/detail/member_function_ptr.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/phx2_result.hpp>", private, "<boost/phoenix/core/detail/function_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/actor_operator_10.hpp>", private, "<boost/phoenix/core/detail/preprocessed/actor_operator.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/actor_operator_20.hpp>", private, "<boost/phoenix/core/detail/preprocessed/actor_operator.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/actor_operator_30.hpp>", private, "<boost/phoenix/core/detail/preprocessed/actor_operator.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/actor_operator_40.hpp>", private, "<boost/phoenix/core/detail/preprocessed/actor_operator.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/actor_operator_50.hpp>", private, "<boost/phoenix/core/detail/preprocessed/actor_operator.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/actor_operator.hpp>", private, "<boost/phoenix/core/detail/actor_operator.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/actor_result_of_10.hpp>", private, "<boost/phoenix/core/detail/preprocessed/actor_result_of.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/actor_result_of_20.hpp>", private, "<boost/phoenix/core/detail/preprocessed/actor_result_of.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/actor_result_of_30.hpp>", private, "<boost/phoenix/core/detail/preprocessed/actor_result_of.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/actor_result_of_40.hpp>", private, "<boost/phoenix/core/detail/preprocessed/actor_result_of.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/actor_result_of_50.hpp>", private, "<boost/phoenix/core/detail/preprocessed/actor_result_of.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/actor_result_of.hpp>", private, "<boost/phoenix/core/detail/actor_result_of.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/call_10.hpp>", private, "<boost/phoenix/core/detail/preprocessed/call.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/call_20.hpp>", private, "<boost/phoenix/core/detail/preprocessed/call.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/call_30.hpp>", private, "<boost/phoenix/core/detail/preprocessed/call.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/call_40.hpp>", private, "<boost/phoenix/core/detail/preprocessed/call.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/call_50.hpp>", private, "<boost/phoenix/core/detail/preprocessed/call.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/call.hpp>", private, "<boost/phoenix/core/detail/call.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/function_eval_10.hpp>", private, "<boost/phoenix/core/detail/preprocessed/function_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/function_eval_20.hpp>", private, "<boost/phoenix/core/detail/preprocessed/function_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/function_eval_30.hpp>", private, "<boost/phoenix/core/detail/preprocessed/function_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/function_eval_40.hpp>", private, "<boost/phoenix/core/detail/preprocessed/function_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/function_eval_50.hpp>", private, "<boost/phoenix/core/detail/preprocessed/function_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/function_eval.hpp>", private, "<boost/phoenix/core/detail/function_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/phx2_result_10.hpp>", private, "<boost/phoenix/core/detail/preprocessed/phx2_result.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/phx2_result_20.hpp>", private, "<boost/phoenix/core/detail/preprocessed/phx2_result.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/phx2_result_30.hpp>", private, "<boost/phoenix/core/detail/preprocessed/phx2_result.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/phx2_result_40.hpp>", private, "<boost/phoenix/core/detail/preprocessed/phx2_result.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/phx2_result_50.hpp>", private, "<boost/phoenix/core/detail/preprocessed/phx2_result.hpp>", private ] },
+ { include: ["<boost/phoenix/core/detail/preprocessed/phx2_result.hpp>", private, "<boost/phoenix/core/detail/phx2_result.hpp>", private ] },
+ { include: ["<boost/phoenix/function/detail/preprocessed/function_operator_10.hpp>", private, "<boost/phoenix/function/detail/preprocessed/function_operator.hpp>", private ] },
+ { include: ["<boost/phoenix/function/detail/preprocessed/function_operator_20.hpp>", private, "<boost/phoenix/function/detail/preprocessed/function_operator.hpp>", private ] },
+ { include: ["<boost/phoenix/function/detail/preprocessed/function_operator_30.hpp>", private, "<boost/phoenix/function/detail/preprocessed/function_operator.hpp>", private ] },
+ { include: ["<boost/phoenix/function/detail/preprocessed/function_operator_40.hpp>", private, "<boost/phoenix/function/detail/preprocessed/function_operator.hpp>", private ] },
+ { include: ["<boost/phoenix/function/detail/preprocessed/function_operator_50.hpp>", private, "<boost/phoenix/function/detail/preprocessed/function_operator.hpp>", private ] },
+ { include: ["<boost/phoenix/function/detail/preprocessed/function_operator.hpp>", private, "<boost/phoenix/function/detail/function_operator.hpp>", private ] },
+ { include: ["<boost/phoenix/function/detail/preprocessed/function_result_of_10.hpp>", private, "<boost/phoenix/function/detail/preprocessed/function_result_of.hpp>", private ] },
+ { include: ["<boost/phoenix/function/detail/preprocessed/function_result_of_20.hpp>", private, "<boost/phoenix/function/detail/preprocessed/function_result_of.hpp>", private ] },
+ { include: ["<boost/phoenix/function/detail/preprocessed/function_result_of_30.hpp>", private, "<boost/phoenix/function/detail/preprocessed/function_result_of.hpp>", private ] },
+ { include: ["<boost/phoenix/function/detail/preprocessed/function_result_of_40.hpp>", private, "<boost/phoenix/function/detail/preprocessed/function_result_of.hpp>", private ] },
+ { include: ["<boost/phoenix/function/detail/preprocessed/function_result_of_50.hpp>", private, "<boost/phoenix/function/detail/preprocessed/function_result_of.hpp>", private ] },
+ { include: ["<boost/phoenix/function/detail/preprocessed/function_result_of.hpp>", private, "<boost/phoenix/function/detail/function_result_of.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/construct_10.hpp>", private, "<boost/phoenix/object/detail/preprocessed/construct.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/construct_20.hpp>", private, "<boost/phoenix/object/detail/preprocessed/construct.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/construct_30.hpp>", private, "<boost/phoenix/object/detail/preprocessed/construct.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/construct_40.hpp>", private, "<boost/phoenix/object/detail/preprocessed/construct.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/construct_50.hpp>", private, "<boost/phoenix/object/detail/preprocessed/construct.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/construct_eval_10.hpp>", private, "<boost/phoenix/object/detail/preprocessed/construct_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/construct_eval_20.hpp>", private, "<boost/phoenix/object/detail/preprocessed/construct_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/construct_eval_30.hpp>", private, "<boost/phoenix/object/detail/preprocessed/construct_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/construct_eval_40.hpp>", private, "<boost/phoenix/object/detail/preprocessed/construct_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/construct_eval_50.hpp>", private, "<boost/phoenix/object/detail/preprocessed/construct_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/construct_eval.hpp>", private, "<boost/phoenix/object/detail/construct_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/construct.hpp>", private, "<boost/phoenix/object/detail/construct.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/new_10.hpp>", private, "<boost/phoenix/object/detail/preprocessed/new.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/new_20.hpp>", private, "<boost/phoenix/object/detail/preprocessed/new.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/new_30.hpp>", private, "<boost/phoenix/object/detail/preprocessed/new.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/new_40.hpp>", private, "<boost/phoenix/object/detail/preprocessed/new.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/new_50.hpp>", private, "<boost/phoenix/object/detail/preprocessed/new.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/new_eval_10.hpp>", private, "<boost/phoenix/object/detail/preprocessed/new_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/new_eval_20.hpp>", private, "<boost/phoenix/object/detail/preprocessed/new_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/new_eval_30.hpp>", private, "<boost/phoenix/object/detail/preprocessed/new_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/new_eval_40.hpp>", private, "<boost/phoenix/object/detail/preprocessed/new_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/new_eval_50.hpp>", private, "<boost/phoenix/object/detail/preprocessed/new_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/new_eval.hpp>", private, "<boost/phoenix/object/detail/new_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/object/detail/preprocessed/new.hpp>", private, "<boost/phoenix/object/detail/new.hpp>", private ] },
+ { include: ["<boost/phoenix/operator/detail/preprocessed/mem_fun_ptr_gen_10.hpp>", private, "<boost/phoenix/operator/detail/preprocessed/mem_fun_ptr_gen.hpp>", private ] },
+ { include: ["<boost/phoenix/operator/detail/preprocessed/mem_fun_ptr_gen_20.hpp>", private, "<boost/phoenix/operator/detail/preprocessed/mem_fun_ptr_gen.hpp>", private ] },
+ { include: ["<boost/phoenix/operator/detail/preprocessed/mem_fun_ptr_gen_30.hpp>", private, "<boost/phoenix/operator/detail/preprocessed/mem_fun_ptr_gen.hpp>", private ] },
+ { include: ["<boost/phoenix/operator/detail/preprocessed/mem_fun_ptr_gen_40.hpp>", private, "<boost/phoenix/operator/detail/preprocessed/mem_fun_ptr_gen.hpp>", private ] },
+ { include: ["<boost/phoenix/operator/detail/preprocessed/mem_fun_ptr_gen_50.hpp>", private, "<boost/phoenix/operator/detail/preprocessed/mem_fun_ptr_gen.hpp>", private ] },
+ { include: ["<boost/phoenix/operator/detail/preprocessed/mem_fun_ptr_gen.hpp>", private, "<boost/phoenix/operator/detail/mem_fun_ptr_gen.hpp>", private ] },
+ { include: ["<boost/phoenix/scope/detail/preprocessed/dynamic_10.hpp>", private, "<boost/phoenix/scope/detail/preprocessed/dynamic.hpp>", private ] },
+ { include: ["<boost/phoenix/scope/detail/preprocessed/dynamic_20.hpp>", private, "<boost/phoenix/scope/detail/preprocessed/dynamic.hpp>", private ] },
+ { include: ["<boost/phoenix/scope/detail/preprocessed/dynamic_30.hpp>", private, "<boost/phoenix/scope/detail/preprocessed/dynamic.hpp>", private ] },
+ { include: ["<boost/phoenix/scope/detail/preprocessed/dynamic_40.hpp>", private, "<boost/phoenix/scope/detail/preprocessed/dynamic.hpp>", private ] },
+ { include: ["<boost/phoenix/scope/detail/preprocessed/dynamic_50.hpp>", private, "<boost/phoenix/scope/detail/preprocessed/dynamic.hpp>", private ] },
+ { include: ["<boost/phoenix/scope/detail/preprocessed/dynamic.hpp>", private, "<boost/phoenix/scope/detail/dynamic.hpp>", private ] },
+ { include: ["<boost/phoenix/scope/detail/preprocessed/make_locals_10.hpp>", private, "<boost/phoenix/scope/detail/preprocessed/make_locals.hpp>", private ] },
+ { include: ["<boost/phoenix/scope/detail/preprocessed/make_locals_20.hpp>", private, "<boost/phoenix/scope/detail/preprocessed/make_locals.hpp>", private ] },
+ { include: ["<boost/phoenix/scope/detail/preprocessed/make_locals_30.hpp>", private, "<boost/phoenix/scope/detail/preprocessed/make_locals.hpp>", private ] },
+ { include: ["<boost/phoenix/scope/detail/preprocessed/make_locals_40.hpp>", private, "<boost/phoenix/scope/detail/preprocessed/make_locals.hpp>", private ] },
+ { include: ["<boost/phoenix/scope/detail/preprocessed/make_locals_50.hpp>", private, "<boost/phoenix/scope/detail/preprocessed/make_locals.hpp>", private ] },
+ { include: ["<boost/phoenix/scope/detail/preprocessed/make_locals.hpp>", private, "<boost/phoenix/scope/detail/make_locals.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/catch_push_back_10.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/catch_push_back.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/catch_push_back_20.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/catch_push_back.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/catch_push_back_30.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/catch_push_back.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/catch_push_back_40.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/catch_push_back.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/catch_push_back_50.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/catch_push_back.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/catch_push_back.hpp>", private, "<boost/phoenix/statement/detail/catch_push_back.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/switch_10.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/switch.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/switch_20.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/switch.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/switch_30.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/switch.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/switch_40.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/switch.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/switch_50.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/switch.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/switch.hpp>", private, "<boost/phoenix/statement/detail/switch.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/try_catch_eval_10.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/try_catch_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/try_catch_eval_20.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/try_catch_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/try_catch_eval_30.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/try_catch_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/try_catch_eval_40.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/try_catch_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/try_catch_eval_50.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/try_catch_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/try_catch_eval.hpp>", private, "<boost/phoenix/statement/detail/try_catch_eval.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/try_catch_expression_10.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/try_catch_expression.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/try_catch_expression_20.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/try_catch_expression.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/try_catch_expression_30.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/try_catch_expression.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/try_catch_expression_40.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/try_catch_expression.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/try_catch_expression_50.hpp>", private, "<boost/phoenix/statement/detail/preprocessed/try_catch_expression.hpp>", private ] },
+ { include: ["<boost/phoenix/statement/detail/preprocessed/try_catch_expression.hpp>", private, "<boost/phoenix/statement/detail/try_catch_expression.hpp>", private ] },
+ { include: ["<boost/phoenix/support/detail/iterate_define.hpp>", private, "<boost/phoenix/support/detail/iterate.hpp>", private ] },
+ { include: ["<boost/phoenix/support/detail/iterate_undef.hpp>", private, "<boost/phoenix/support/detail/iterate.hpp>", private ] },
+ { include: ["<boost/preprocessor/detail/check.hpp>", private, "<boost/preprocessor/detail/is_binary.hpp>", private ] },
+ { include: ["<boost/preprocessor/detail/check.hpp>", private, "<boost/preprocessor/detail/is_nullary.hpp>", private ] },
+ { include: ["<boost/preprocessor/detail/check.hpp>", private, "<boost/preprocessor/detail/is_unary.hpp>", private ] },
+ { include: ["<boost/preprocessor/detail/dmc/auto_rec.hpp>", private, "<boost/preprocessor/detail/auto_rec.hpp>", private ] },
+ { include: ["<boost/preprocessor/detail/is_binary.hpp>", private, "<boost/tti/detail/dvm_template_params.hpp>", private ] },
+ { include: ["<boost/preprocessor/detail/is_unary.hpp>", private, "<boost/local_function/detail/preprocessor/keyword/facility/is.hpp>", private ] },
+ { include: ["<boost/preprocessor/detail/is_unary.hpp>", private, "<boost/proto/detail/remove_typename.hpp>", private ] },
+ { include: ["<boost/preprocessor/detail/is_unary.hpp>", private, "<boost/range/detail/microsoft.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/bounds/lower1.hpp>", private, "<boost/preprocessor/iteration/detail/iter/forward1.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/bounds/lower2.hpp>", private, "<boost/preprocessor/iteration/detail/iter/forward2.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/bounds/lower3.hpp>", private, "<boost/preprocessor/iteration/detail/iter/forward3.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/bounds/lower4.hpp>", private, "<boost/preprocessor/iteration/detail/iter/forward4.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/bounds/lower5.hpp>", private, "<boost/preprocessor/iteration/detail/iter/forward5.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/bounds/upper1.hpp>", private, "<boost/preprocessor/iteration/detail/iter/forward1.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/bounds/upper2.hpp>", private, "<boost/preprocessor/iteration/detail/iter/forward2.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/bounds/upper3.hpp>", private, "<boost/preprocessor/iteration/detail/iter/forward3.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/bounds/upper4.hpp>", private, "<boost/preprocessor/iteration/detail/iter/forward4.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/bounds/upper5.hpp>", private, "<boost/preprocessor/iteration/detail/iter/forward5.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/finish.hpp>", private, "<boost/preprocessor/iteration/detail/local.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/iter/reverse1.hpp>", private, "<boost/preprocessor/iteration/detail/iter/forward1.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/iter/reverse2.hpp>", private, "<boost/preprocessor/iteration/detail/iter/forward2.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/iter/reverse3.hpp>", private, "<boost/preprocessor/iteration/detail/iter/forward3.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/iter/reverse4.hpp>", private, "<boost/preprocessor/iteration/detail/iter/forward4.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/iter/reverse5.hpp>", private, "<boost/preprocessor/iteration/detail/iter/forward5.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/rlocal.hpp>", private, "<boost/preprocessor/iteration/detail/local.hpp>", private ] },
+ { include: ["<boost/preprocessor/iteration/detail/start.hpp>", private, "<boost/preprocessor/iteration/detail/local.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/iteration/detail/bounds/lower1.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/iteration/detail/bounds/lower2.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/iteration/detail/bounds/lower3.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/iteration/detail/bounds/lower4.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/iteration/detail/bounds/lower5.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/iteration/detail/bounds/upper1.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/iteration/detail/bounds/upper2.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/iteration/detail/bounds/upper3.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/iteration/detail/bounds/upper4.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/iteration/detail/bounds/upper5.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/iteration/detail/finish.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/iteration/detail/start.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/slot/detail/counter.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/slot/detail/slot1.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/slot/detail/slot2.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/slot/detail/slot3.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/slot/detail/slot4.hpp>", private ] },
+ { include: ["<boost/preprocessor/slot/detail/shared.hpp>", private, "<boost/preprocessor/slot/detail/slot5.hpp>", private ] },
+ { include: ["<boost/program_options/detail/convert.hpp>", private, "<boost/program_options/detail/config_file.hpp>", private ] },
+ { include: ["<boost/program_options/detail/convert.hpp>", private, "<boost/program_options/detail/parsers.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/file_parser_error.hpp>", private, "<boost/property_tree/detail/info_parser_error.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/file_parser_error.hpp>", private, "<boost/property_tree/detail/json_parser_error.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/file_parser_error.hpp>", private, "<boost/property_tree/detail/xml_parser_error.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/info_parser_error.hpp>", private, "<boost/property_tree/detail/info_parser_read.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/info_parser_utils.hpp>", private, "<boost/property_tree/detail/info_parser_read.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/info_parser_utils.hpp>", private, "<boost/property_tree/detail/info_parser_write.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/json_parser_error.hpp>", private, "<boost/property_tree/detail/json_parser_read.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/ptree_utils.hpp>", private, "<boost/property_tree/detail/json_parser_read.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/ptree_utils.hpp>", private, "<boost/property_tree/detail/xml_parser_utils.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/ptree_utils.hpp>", private, "<boost/property_tree/detail/xml_parser_writer_settings.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/rapidxml.hpp>", private, "<boost/property_tree/detail/xml_parser_read_rapidxml.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/xml_parser_error.hpp>", private, "<boost/property_tree/detail/xml_parser_read_rapidxml.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/xml_parser_error.hpp>", private, "<boost/property_tree/detail/xml_parser_utils.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/xml_parser_flags.hpp>", private, "<boost/property_tree/detail/xml_parser_read_rapidxml.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/xml_parser_utils.hpp>", private, "<boost/property_tree/detail/xml_parser_read_rapidxml.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/xml_parser_utils.hpp>", private, "<boost/property_tree/detail/xml_parser_write.hpp>", private ] },
+ { include: ["<boost/property_tree/detail/xml_parser_writer_settings.hpp>", private, "<boost/property_tree/detail/xml_parser_utils.hpp>", private ] },
+ { include: ["<boost/proto/context/detail/preprocessed/callable_eval.hpp>", private, "<boost/proto/context/detail/callable_eval.hpp>", private ] },
+ { include: ["<boost/proto/context/detail/preprocessed/default_eval.hpp>", private, "<boost/proto/context/detail/default_eval.hpp>", private ] },
+ { include: ["<boost/proto/context/detail/preprocessed/null_eval.hpp>", private, "<boost/proto/context/detail/null_eval.hpp>", private ] },
+ { include: ["<boost/proto/detail/any.hpp>", private, "<boost/proto/detail/decltype.hpp>", private ] },
+ { include: ["<boost/proto/detail/class_member_traits.hpp>", private, "<boost/proto/detail/decltype.hpp>", private ] },
+ { include: ["<boost/proto/detail/decltype.hpp>", private, "<boost/phoenix/bind/detail/member_variable.hpp>", private ] },
+ { include: ["<boost/proto/detail/deduce_domain_n.hpp>", private, "<boost/proto/detail/deduce_domain.hpp>", private ] },
+ { include: ["<boost/proto/detail/ignore_unused.hpp>", private, "<boost/xpressive/detail/utility/ignore_unused.hpp>", private ] },
+ { include: ["<boost/proto/detail/is_noncopyable.hpp>", private, "<boost/proto/detail/poly_function.hpp>", private ] },
+ { include: ["<boost/proto/detail/memfun_funop.hpp>", private, "<boost/proto/detail/decltype.hpp>", private ] },
+ { include: ["<boost/proto/detail/poly_function_funop.hpp>", private, "<boost/proto/detail/poly_function.hpp>", private ] },
+ { include: ["<boost/proto/detail/poly_function_traits.hpp>", private, "<boost/proto/detail/poly_function.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/and_n.hpp>", private, "<boost/proto/detail/and_n.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/args.hpp>", private, "<boost/proto/detail/args.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/basic_expr.hpp>", private, "<boost/proto/detail/basic_expr.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/class_member_traits.hpp>", private, "<boost/proto/detail/class_member_traits.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/deduce_domain_n.hpp>", private, "<boost/proto/detail/deduce_domain_n.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/deep_copy.hpp>", private, "<boost/proto/detail/deep_copy.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/expr.hpp>", private, "<boost/proto/detail/expr.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/expr_variadic.hpp>", private, "<boost/proto/detail/expr.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/extends_funop_const.hpp>", private, "<boost/proto/detail/extends_funop_const.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/extends_funop.hpp>", private, "<boost/proto/detail/extends_funop.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/funop.hpp>", private, "<boost/proto/detail/funop.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/generate_by_value.hpp>", private, "<boost/proto/detail/generate_by_value.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/lambda_matches.hpp>", private, "<boost/proto/detail/lambda_matches.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/make_expr_funop.hpp>", private, "<boost/proto/detail/make_expr_funop.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/make_expr_.hpp>", private, "<boost/proto/detail/make_expr_.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/make_expr.hpp>", private, "<boost/proto/detail/make_expr.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/matches_.hpp>", private, "<boost/proto/detail/matches_.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/memfun_funop.hpp>", private, "<boost/proto/detail/memfun_funop.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/or_n.hpp>", private, "<boost/proto/detail/or_n.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/poly_function_funop.hpp>", private, "<boost/proto/detail/poly_function_funop.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/poly_function_traits.hpp>", private, "<boost/proto/detail/poly_function_traits.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/template_arity_helper.hpp>", private, "<boost/proto/detail/template_arity_helper.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/traits.hpp>", private, "<boost/proto/detail/traits.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/unpack_expr_.hpp>", private, "<boost/proto/detail/unpack_expr_.hpp>", private ] },
+ { include: ["<boost/proto/detail/preprocessed/vararg_matches_impl.hpp>", private, "<boost/proto/detail/vararg_matches_impl.hpp>", private ] },
+ { include: ["<boost/proto/detail/template_arity_helper.hpp>", private, "<boost/proto/detail/template_arity.hpp>", private ] },
+ { include: ["<boost/proto/transform/detail/expand_pack.hpp>", private, "<boost/proto/transform/detail/pack.hpp>", private ] },
+ { include: ["<boost/proto/transform/detail/pack_impl.hpp>", private, "<boost/proto/transform/detail/pack.hpp>", private ] },
+ { include: ["<boost/proto/transform/detail/preprocessed/call.hpp>", private, "<boost/proto/transform/detail/call.hpp>", private ] },
+ { include: ["<boost/proto/transform/detail/preprocessed/construct_funop.hpp>", private, "<boost/proto/transform/detail/construct_funop.hpp>", private ] },
+ { include: ["<boost/proto/transform/detail/preprocessed/construct_pod_funop.hpp>", private, "<boost/proto/transform/detail/construct_pod_funop.hpp>", private ] },
+ { include: ["<boost/proto/transform/detail/preprocessed/default_function_impl.hpp>", private, "<boost/proto/transform/detail/default_function_impl.hpp>", private ] },
+ { include: ["<boost/proto/transform/detail/preprocessed/expand_pack.hpp>", private, "<boost/proto/transform/detail/expand_pack.hpp>", private ] },
+ { include: ["<boost/proto/transform/detail/preprocessed/fold_impl.hpp>", private, "<boost/proto/transform/detail/fold_impl.hpp>", private ] },
+ { include: ["<boost/proto/transform/detail/preprocessed/lazy.hpp>", private, "<boost/proto/transform/detail/lazy.hpp>", private ] },
+ { include: ["<boost/proto/transform/detail/preprocessed/make_gcc_workaround.hpp>", private, "<boost/proto/transform/detail/make_gcc_workaround.hpp>", private ] },
+ { include: ["<boost/proto/transform/detail/preprocessed/make.hpp>", private, "<boost/proto/transform/detail/make.hpp>", private ] },
+ { include: ["<boost/proto/transform/detail/preprocessed/pack_impl.hpp>", private, "<boost/proto/transform/detail/pack_impl.hpp>", private ] },
+ { include: ["<boost/proto/transform/detail/preprocessed/pass_through_impl.hpp>", private, "<boost/proto/transform/detail/pass_through_impl.hpp>", private ] },
+ { include: ["<boost/proto/transform/detail/preprocessed/when.hpp>", private, "<boost/proto/transform/detail/when.hpp>", private ] },
+ { include: ["<boost/ptr_container/detail/default_deleter.hpp>", private, "<boost/ptr_container/detail/static_move_ptr.hpp>", private ] },
+ { include: ["<boost/ptr_container/detail/is_convertible.hpp>", private, "<boost/ptr_container/detail/static_move_ptr.hpp>", private ] },
+ { include: ["<boost/ptr_container/detail/move.hpp>", private, "<boost/ptr_container/detail/static_move_ptr.hpp>", private ] },
+ { include: ["<boost/ptr_container/detail/reversible_ptr_container.hpp>", private, "<boost/ptr_container/detail/associative_ptr_container.hpp>", private ] },
+ { include: ["<boost/ptr_container/detail/reversible_ptr_container.hpp>", private, "<boost/ptr_container/detail/serialize_reversible_cont.hpp>", private ] },
+ { include: ["<boost/ptr_container/detail/scoped_deleter.hpp>", private, "<boost/ptr_container/detail/reversible_ptr_container.hpp>", private ] },
+ { include: ["<boost/ptr_container/detail/serialize_xml_names.hpp>", private, "<boost/ptr_container/detail/serialize_ptr_map_adapter.hpp>", private ] },
+ { include: ["<boost/ptr_container/detail/serialize_xml_names.hpp>", private, "<boost/ptr_container/detail/serialize_reversible_cont.hpp>", private ] },
+ { include: ["<boost/ptr_container/detail/static_move_ptr.hpp>", private, "<boost/ptr_container/detail/reversible_ptr_container.hpp>", private ] },
+ { include: ["<boost/ptr_container/detail/throw_exception.hpp>", private, "<boost/ptr_container/detail/reversible_ptr_container.hpp>", private ] },
+ { include: ["<boost/python/detail/config.hpp>", private, "<boost/python/detail/exception_handler.hpp>", private ] },
+ { include: ["<boost/python/detail/config.hpp>", private, "<boost/python/detail/prefix.hpp>", private ] },
+ { include: ["<boost/python/detail/config.hpp>", private, "<boost/python/detail/scope.hpp>", private ] },
+ { include: ["<boost/python/detail/copy_ctor_mutates_rhs.hpp>", private, "<boost/python/detail/value_arg.hpp>", private ] },
+ { include: ["<boost/python/detail/cv_category.hpp>", private, "<boost/python/detail/unwind_type.hpp>", private ] },
+ { include: ["<boost/python/detail/defaults_gen.hpp>", private, "<boost/python/detail/defaults_def.hpp>", private ] },
+ { include: ["<boost/python/detail/def_helper_fwd.hpp>", private, "<boost/python/detail/def_helper.hpp>", private ] },
+ { include: ["<boost/python/detail/exception_handler.hpp>", private, "<boost/python/detail/translate_exception.hpp>", private ] },
+ { include: ["<boost/python/detail/indirect_traits.hpp>", private, "<boost/python/detail/decorated_type_id.hpp>", private ] },
+ { include: ["<boost/python/detail/indirect_traits.hpp>", private, "<boost/python/detail/def_helper.hpp>", private ] },
+ { include: ["<boost/python/detail/indirect_traits.hpp>", private, "<boost/python/detail/signature.hpp>", private ] },
+ { include: ["<boost/python/detail/indirect_traits.hpp>", private, "<boost/python/detail/unwind_type.hpp>", private ] },
+ { include: ["<boost/python/detail/invoke.hpp>", private, "<boost/python/detail/caller.hpp>", private ] },
+ { include: ["<boost/python/detail/is_auto_ptr.hpp>", private, "<boost/python/detail/copy_ctor_mutates_rhs.hpp>", private ] },
+ { include: ["<boost/python/detail/is_wrapper.hpp>", private, "<boost/python/detail/unwrap_wrapper.hpp>", private ] },
+ { include: ["<boost/python/detail/is_xxx.hpp>", private, "<boost/python/detail/is_auto_ptr.hpp>", private ] },
+ { include: ["<boost/python/detail/is_xxx.hpp>", private, "<boost/python/detail/is_shared_ptr.hpp>", private ] },
+ { include: ["<boost/python/detail/is_xxx.hpp>", private, "<boost/python/detail/value_is_xxx.hpp>", private ] },
+ { include: ["<boost/python/detail/make_keyword_range_fn.hpp>", private, "<boost/python/detail/defaults_def.hpp>", private ] },
+ { include: ["<boost/python/detail/none.hpp>", private, "<boost/python/detail/invoke.hpp>", private ] },
+ { include: ["<boost/python/detail/not_specified.hpp>", private, "<boost/python/detail/def_helper_fwd.hpp>", private ] },
+ { include: ["<boost/python/detail/not_specified.hpp>", private, "<boost/python/detail/def_helper.hpp>", private ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/detail/aix_init_module.hpp>", private ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/detail/invoke.hpp>", private ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/detail/is_wrapper.hpp>", private ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/detail/none.hpp>", private ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/detail/nullary_function_adaptor.hpp>", private ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/detail/sfinae.hpp>", private ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/detail/unwrap_wrapper.hpp>", private ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/detail/wrapper_base.hpp>", private ] },
+ { include: ["<boost/python/detail/preprocessor.hpp>", private, "<boost/python/detail/caller.hpp>", private ] },
+ { include: ["<boost/python/detail/preprocessor.hpp>", private, "<boost/python/detail/defaults_gen.hpp>", private ] },
+ { include: ["<boost/python/detail/preprocessor.hpp>", private, "<boost/python/detail/invoke.hpp>", private ] },
+ { include: ["<boost/python/detail/preprocessor.hpp>", private, "<boost/python/detail/result.hpp>", private ] },
+ { include: ["<boost/python/detail/preprocessor.hpp>", private, "<boost/python/detail/signature.hpp>", private ] },
+ { include: ["<boost/python/detail/preprocessor.hpp>", private, "<boost/python/detail/target.hpp>", private ] },
+ { include: ["<boost/python/detail/preprocessor.hpp>", private, "<boost/python/detail/type_list.hpp>", private ] },
+ { include: ["<boost/python/detail/python22_fixed.h>", private, "<boost/python/detail/wrap_python.hpp>", private ] },
+ { include: ["<boost/python/detail/scope.hpp>", private, "<boost/python/detail/defaults_def.hpp>", private ] },
+ { include: ["<boost/python/detail/sfinae.hpp>", private, "<boost/python/detail/enable_if.hpp>", private ] },
+ { include: ["<boost/python/detail/signature.hpp>", private, "<boost/python/detail/caller.hpp>", private ] },
+ { include: ["<boost/python/detail/value_is_xxx.hpp>", private, "<boost/python/detail/value_is_shared_ptr.hpp>", private ] },
+ { include: ["<boost/python/detail/wrap_python.hpp>", private, "<boost/python/detail/prefix.hpp>", private ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/detail/operators.hpp>", private ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/detail/uniform_int_float.hpp>", private ] },
+ { include: ["<boost/random/detail/const_mod.hpp>", private, "<boost/random/detail/seed_impl.hpp>", private ] },
+ { include: ["<boost/random/detail/disable_warnings.hpp>", private, "<boost/random/detail/const_mod.hpp>", private ] },
+ { include: ["<boost/random/detail/disable_warnings.hpp>", private, "<boost/random/detail/large_arithmetic.hpp>", private ] },
+ { include: ["<boost/random/detail/disable_warnings.hpp>", private, "<boost/random/detail/seed_impl.hpp>", private ] },
+ { include: ["<boost/random/detail/disable_warnings.hpp>", private, "<boost/random/detail/uniform_int_float.hpp>", private ] },
+ { include: ["<boost/random/detail/enable_warnings.hpp>", private, "<boost/random/detail/const_mod.hpp>", private ] },
+ { include: ["<boost/random/detail/enable_warnings.hpp>", private, "<boost/random/detail/large_arithmetic.hpp>", private ] },
+ { include: ["<boost/random/detail/enable_warnings.hpp>", private, "<boost/random/detail/seed_impl.hpp>", private ] },
+ { include: ["<boost/random/detail/enable_warnings.hpp>", private, "<boost/random/detail/uniform_int_float.hpp>", private ] },
+ { include: ["<boost/random/detail/generator_bits.hpp>", private, "<boost/random/detail/seed_impl.hpp>", private ] },
+ { include: ["<boost/random/detail/generator_bits.hpp>", private, "<boost/random/detail/uniform_int_float.hpp>", private ] },
+ { include: ["<boost/random/detail/integer_log2.hpp>", private, "<boost/random/detail/large_arithmetic.hpp>", private ] },
+ { include: ["<boost/random/detail/integer_log2.hpp>", private, "<boost/random/detail/seed_impl.hpp>", private ] },
+ { include: ["<boost/random/detail/large_arithmetic.hpp>", private, "<boost/random/detail/const_mod.hpp>", private ] },
+ { include: ["<boost/random/detail/signed_unsigned_tools.hpp>", private, "<boost/random/detail/seed_impl.hpp>", private ] },
+ { include: ["<boost/range/detail/any_iterator_buffer.hpp>", private, "<boost/range/detail/any_iterator.hpp>", private ] },
+ { include: ["<boost/range/detail/any_iterator_buffer.hpp>", private, "<boost/range/detail/any_iterator_interface.hpp>", private ] },
+ { include: ["<boost/range/detail/any_iterator_interface.hpp>", private, "<boost/range/detail/any_iterator.hpp>", private ] },
+ { include: ["<boost/range/detail/any_iterator_interface.hpp>", private, "<boost/range/detail/any_iterator_wrapper.hpp>", private ] },
+ { include: ["<boost/range/detail/any_iterator_wrapper.hpp>", private, "<boost/range/detail/any_iterator.hpp>", private ] },
+ { include: ["<boost/range/detail/begin.hpp>", private, "<boost/range/detail/detail_str.hpp>", private ] },
+ { include: ["<boost/range/detail/collection_traits_detail.hpp>", private, "<boost/range/detail/collection_traits.hpp>", private ] },
+ { include: ["<boost/range/detail/common.hpp>", private, "<boost/range/detail/begin.hpp>", private ] },
+ { include: ["<boost/range/detail/common.hpp>", private, "<boost/range/detail/const_iterator.hpp>", private ] },
+ { include: ["<boost/range/detail/common.hpp>", private, "<boost/range/detail/detail_str.hpp>", private ] },
+ { include: ["<boost/range/detail/common.hpp>", private, "<boost/range/detail/difference_type.hpp>", private ] },
+ { include: ["<boost/range/detail/common.hpp>", private, "<boost/range/detail/empty.hpp>", private ] },
+ { include: ["<boost/range/detail/common.hpp>", private, "<boost/range/detail/end.hpp>", private ] },
+ { include: ["<boost/range/detail/common.hpp>", private, "<boost/range/detail/implementation_help.hpp>", private ] },
+ { include: ["<boost/range/detail/common.hpp>", private, "<boost/range/detail/iterator.hpp>", private ] },
+ { include: ["<boost/range/detail/common.hpp>", private, "<boost/range/detail/size.hpp>", private ] },
+ { include: ["<boost/range/detail/common.hpp>", private, "<boost/range/detail/size_type.hpp>", private ] },
+ { include: ["<boost/range/detail/common.hpp>", private, "<boost/range/detail/value_type.hpp>", private ] },
+ { include: ["<boost/range/detail/common.hpp>", private, "<boost/range/detail/vc6/end.hpp>", private ] },
+ { include: ["<boost/range/detail/common.hpp>", private, "<boost/range/detail/vc6/size.hpp>", private ] },
+ { include: ["<boost/range/detail/demote_iterator_traversal_tag.hpp>", private, "<boost/range/detail/join_iterator.hpp>", private ] },
+ { include: ["<boost/range/detail/detail_str.hpp>", private, "<boost/range/detail/as_literal.hpp>", private ] },
+ { include: ["<boost/range/detail/end.hpp>", private, "<boost/range/detail/detail_str.hpp>", private ] },
+ { include: ["<boost/range/detail/implementation_help.hpp>", private, "<boost/range/detail/end.hpp>", private ] },
+ { include: ["<boost/range/detail/implementation_help.hpp>", private, "<boost/range/detail/size.hpp>", private ] },
+ { include: ["<boost/range/detail/implementation_help.hpp>", private, "<boost/range/detail/vc6/end.hpp>", private ] },
+ { include: ["<boost/range/detail/implementation_help.hpp>", private, "<boost/range/detail/vc6/size.hpp>", private ] },
+ { include: ["<boost/range/detail/remove_extent.hpp>", private, "<boost/range/detail/const_iterator.hpp>", private ] },
+ { include: ["<boost/range/detail/remove_extent.hpp>", private, "<boost/range/detail/end.hpp>", private ] },
+ { include: ["<boost/range/detail/remove_extent.hpp>", private, "<boost/range/detail/iterator.hpp>", private ] },
+ { include: ["<boost/range/detail/remove_extent.hpp>", private, "<boost/range/detail/size.hpp>", private ] },
+ { include: ["<boost/range/detail/remove_extent.hpp>", private, "<boost/range/detail/value_type.hpp>", private ] },
+ { include: ["<boost/range/detail/remove_extent.hpp>", private, "<boost/range/detail/vc6/end.hpp>", private ] },
+ { include: ["<boost/range/detail/remove_extent.hpp>", private, "<boost/range/detail/vc6/size.hpp>", private ] },
+ { include: ["<boost/range/detail/sfinae.hpp>", private, "<boost/range/detail/common.hpp>", private ] },
+ { include: ["<boost/range/detail/size_type.hpp>", private, "<boost/range/detail/detail_str.hpp>", private ] },
+ { include: ["<boost/range/detail/size_type.hpp>", private, "<boost/range/detail/size.hpp>", private ] },
+ { include: ["<boost/range/detail/size_type.hpp>", private, "<boost/range/detail/vc6/size.hpp>", private ] },
+ { include: ["<boost/range/detail/value_type.hpp>", private, "<boost/range/detail/detail_str.hpp>", private ] },
+ { include: ["<boost/range/detail/vc6/end.hpp>", private, "<boost/range/detail/end.hpp>", private ] },
+ { include: ["<boost/range/detail/vc6/size.hpp>", private, "<boost/range/detail/size.hpp>", private ] },
+ { include: ["<boost/ratio/detail/mpl/abs.hpp>", private, "<boost/ratio/detail/mpl/gcd.hpp>", private ] },
+ { include: ["<boost/ratio/detail/mpl/abs.hpp>", private, "<boost/ratio/detail/mpl/lcm.hpp>", private ] },
+ { include: ["<boost/ratio/detail/mpl/abs.hpp>", private, "<boost/ratio/detail/overflow_helpers.hpp>", private ] },
+ { include: ["<boost/ratio/detail/mpl/sign.hpp>", private, "<boost/ratio/detail/overflow_helpers.hpp>", private ] },
+ { include: ["<boost/ratio/detail/overflow_helpers.hpp>", private, "<boost/chrono/detail/is_evenly_divisible_by.hpp>", private ] },
+ { include: ["<boost/serialization/detail/get_data.hpp>", private, "<boost/mpi/detail/mpi_datatype_primitive.hpp>", private ] },
+ { include: ["<boost/serialization/detail/get_data.hpp>", private, "<boost/mpi/detail/packed_iprimitive.hpp>", private ] },
+ { include: ["<boost/serialization/detail/get_data.hpp>", private, "<boost/mpi/detail/packed_oprimitive.hpp>", private ] },
+ { include: ["<boost/serialization/detail/shared_count_132.hpp>", private, "<boost/serialization/detail/shared_ptr_132.hpp>", private ] },
+ { include: ["<boost/serialization/detail/shared_ptr_nmt_132.hpp>", private, "<boost/serialization/detail/shared_ptr_132.hpp>", private ] },
+ { include: ["<boost/signals2/detail/auto_buffer.hpp>", private, "<boost/signals2/detail/slot_call_iterator.hpp>", private ] },
+ { include: ["<boost/signals2/detail/signals_common.hpp>", private, "<boost/signals2/detail/tracked_objects_visitor.hpp>", private ] },
+ { include: ["<boost/signals2/detail/signals_common_macros.hpp>", private, "<boost/signals2/detail/preprocessed_arg_type.hpp>", private ] },
+ { include: ["<boost/signals2/detail/unique_lock.hpp>", private, "<boost/signals2/detail/slot_call_iterator.hpp>", private ] },
+ { include: ["<boost/signals2/detail/variadic_arg_type.hpp>", private, "<boost/signals2/detail/variadic_slot_invoker.hpp>", private ] },
+ { include: ["<boost/signals/detail/config.hpp>", private, "<boost/signals/detail/named_slot_map.hpp>", private ] },
+ { include: ["<boost/signals/detail/config.hpp>", private, "<boost/signals/detail/signal_base.hpp>", private ] },
+ { include: ["<boost/signals/detail/config.hpp>", private, "<boost/signals/detail/signals_common.hpp>", private ] },
+ { include: ["<boost/signals/detail/config.hpp>", private, "<boost/signals/detail/slot_call_iterator.hpp>", private ] },
+ { include: ["<boost/signals/detail/named_slot_map.hpp>", private, "<boost/signals/detail/signal_base.hpp>", private ] },
+ { include: ["<boost/signals/detail/signals_common.hpp>", private, "<boost/signals/detail/named_slot_map.hpp>", private ] },
+ { include: ["<boost/signals/detail/signals_common.hpp>", private, "<boost/signals/detail/signal_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/array_utility.hpp>", private, "<boost/smart_ptr/detail/array_deleter.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/atomic_count_gcc.hpp>", private, "<boost/smart_ptr/detail/atomic_count.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/atomic_count_gcc_x86.hpp>", private, "<boost/smart_ptr/detail/atomic_count.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/atomic_count.hpp>", private, "<boost/smart_ptr/detail/shared_array_nmt.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/atomic_count.hpp>", private, "<boost/smart_ptr/detail/shared_ptr_nmt.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/atomic_count_pthreads.hpp>", private, "<boost/smart_ptr/detail/atomic_count.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/atomic_count_sync.hpp>", private, "<boost/smart_ptr/detail/atomic_count.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/atomic_count_win32.hpp>", private, "<boost/smart_ptr/detail/atomic_count.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/lightweight_mutex.hpp>", private, "<boost/atomic/detail/lockpool.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/lightweight_mutex.hpp>", private, "<boost/smart_ptr/detail/quick_allocator.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/lwm_nop.hpp>", private, "<boost/smart_ptr/detail/lightweight_mutex.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/lwm_pthreads.hpp>", private, "<boost/smart_ptr/detail/lightweight_mutex.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/lwm_win32_cs.hpp>", private, "<boost/smart_ptr/detail/lightweight_mutex.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/quick_allocator.hpp>", private, "<boost/smart_ptr/detail/sp_counted_impl.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base_acc_ia64.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base_aix.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base_cw_ppc.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base_gcc_ia64.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base_gcc_mips.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base_gcc_ppc.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base_gcc_sparc.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base_gcc_x86.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base.hpp>", private, "<boost/smart_ptr/detail/shared_count.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base.hpp>", private, "<boost/smart_ptr/detail/sp_counted_impl.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base_nt.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base_pt.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base_snc_ps3.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base_spin.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base_sync.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base_vacpp_ppc.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_base_w32.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_counted_impl.hpp>", private, "<boost/smart_ptr/detail/shared_count.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_forward.hpp>", private, "<boost/smart_ptr/detail/array_deleter.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_has_sync.hpp>", private, "<boost/smart_ptr/detail/atomic_count.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_has_sync.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/sp_has_sync.hpp>", private, "<boost/smart_ptr/detail/spinlock.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/spinlock_gcc_arm.hpp>", private, "<boost/smart_ptr/detail/spinlock.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/spinlock.hpp>", private, "<boost/smart_ptr/detail/spinlock_pool.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/spinlock_nt.hpp>", private, "<boost/smart_ptr/detail/spinlock.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/spinlock_pool.hpp>", private, "<boost/smart_ptr/detail/sp_counted_base_spin.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/spinlock_pt.hpp>", private, "<boost/smart_ptr/detail/spinlock.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/spinlock_sync.hpp>", private, "<boost/smart_ptr/detail/spinlock.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/spinlock_w32.hpp>", private, "<boost/smart_ptr/detail/spinlock.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/yield_k.hpp>", private, "<boost/smart_ptr/detail/spinlock_gcc_arm.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/yield_k.hpp>", private, "<boost/smart_ptr/detail/spinlock_sync.hpp>", private ] },
+ { include: ["<boost/smart_ptr/detail/yield_k.hpp>", private, "<boost/smart_ptr/detail/spinlock_w32.hpp>", private ] },
+ { include: ["<boost/spirit/fusion/detail/access.hpp>", private, "<boost/xpressive/detail/utility/cons.hpp>", private ] },
+ { include: ["<boost/spirit/fusion/detail/config.hpp>", private, "<boost/xpressive/detail/utility/cons.hpp>", private ] },
+ { include: ["<boost/spirit/fusion/iterator/detail/iterator_base.hpp>", private, "<boost/xpressive/detail/utility/cons.hpp>", private ] },
+ { include: ["<boost/spirit/fusion/sequence/detail/sequence_base.hpp>", private, "<boost/xpressive/detail/utility/cons.hpp>", private ] },
+ { include: ["<boost/spirit/home/classic/utility/impl/chset/basic_chset.ipp>", private, "<boost/spirit/home/classic/utility/impl/chset/basic_chset.hpp>", private ] },
+ { include: ["<boost/spirit/home/classic/utility/impl/chset/range_run.hpp>", private, "<boost/spirit/home/classic/utility/impl/chset/basic_chset.hpp>", private ] },
+ { include: ["<boost/spirit/home/classic/utility/impl/chset/range_run.ipp>", private, "<boost/spirit/home/classic/utility/impl/chset/range_run.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/detail/alternative_function.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/detail/extract_from.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/detail/pass_container.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/detail/generate_to.hpp>", private, "<boost/spirit/home/karma/detail/string_compare.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/detail/generate_to.hpp>", private, "<boost/spirit/home/karma/detail/string_generate.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/detail/generate_to.hpp>", private, "<boost/spirit/home/karma/numeric/detail/bool_utils.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/detail/generate_to.hpp>", private, "<boost/spirit/home/karma/numeric/detail/numeric_utils.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/detail/generate_to.hpp>", private, "<boost/spirit/home/karma/numeric/detail/real_utils.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/detail/generate_to.hpp>", private, "<boost/spirit/home/karma/stream/detail/iterator_sink.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/detail/alternative_function.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/detail/generate.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/detail/generate_to.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/detail/string_generate.hpp>", private, "<boost/spirit/home/karma/numeric/detail/bool_utils.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/detail/string_generate.hpp>", private, "<boost/spirit/home/karma/numeric/detail/numeric_utils.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/detail/string_generate.hpp>", private, "<boost/spirit/home/karma/numeric/detail/real_utils.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/numeric/detail/numeric_utils.hpp>", private, "<boost/spirit/home/karma/numeric/detail/bool_utils.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/numeric/detail/numeric_utils.hpp>", private, "<boost/spirit/home/karma/numeric/detail/real_utils.hpp>", private ] },
+ { include: ["<boost/spirit/home/karma/stream/detail/format_manip.hpp>", private, "<boost/spirit/home/karma/stream/detail/format_manip_auto.hpp>", private ] },
+ { include: ["<boost/spirit/home/phoenix/detail/local_reference.hpp>", private, "<boost/spirit/home/phoenix/detail/type_deduction.hpp>", private ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/mem_fun_ptr_eval.hpp>", private, "<boost/spirit/home/phoenix/operator/detail/mem_fun_ptr_gen.hpp>", private ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/mem_fun_ptr_return.hpp>", private, "<boost/spirit/home/phoenix/operator/detail/mem_fun_ptr_eval.hpp>", private ] },
+ { include: ["<boost/spirit/home/phoenix/statement/detail/switch_eval.ipp>", private, "<boost/spirit/home/phoenix/statement/detail/switch_eval.hpp>", private ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/qi/detail/alternative_function.hpp>", private ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/qi/detail/string_parse.hpp>", private ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/detail/alternative_function.hpp>", private ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/detail/assign_to.hpp>", private ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/detail/pass_container.hpp>", private ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/numeric/detail/numeric_utils.hpp>", private ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/numeric/detail/real_impl.hpp>", private ] },
+ { include: ["<boost/spirit/home/qi/detail/construct.hpp>", private, "<boost/spirit/home/qi/detail/assign_to.hpp>", private ] },
+ { include: ["<boost/spirit/home/qi/stream/detail/match_manip.hpp>", private, "<boost/spirit/home/qi/stream/detail/match_manip_auto.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/detail/endian/cover_operators.hpp>", private, "<boost/spirit/home/support/detail/endian/endian.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/detail/endian/endian.hpp>", private, "<boost/spirit/home/support/detail/endian.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/detail/hold_any.hpp>", private, "<boost/spirit/home/karma/detail/alternative_function.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/detail/hold_any.hpp>", private, "<boost/spirit/home/karma/detail/pass_container.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/detail/math/detail/fp_traits.hpp>", private, "<boost/spirit/home/support/detail/math/fpclassify.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/detail/math/detail/fp_traits.hpp>", private, "<boost/spirit/home/support/detail/math/signbit.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/detail/math/fpclassify.hpp>", private, "<boost/spirit/home/support/detail/sign.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/detail/math/signbit.hpp>", private, "<boost/spirit/home/support/detail/sign.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/detail/pow10.hpp>", private, "<boost/spirit/home/karma/numeric/detail/numeric_utils.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/detail/pow10.hpp>", private, "<boost/spirit/home/karma/numeric/detail/real_utils.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/detail/pow10.hpp>", private, "<boost/spirit/home/qi/numeric/detail/real_impl.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/detail/scoped_enum_emulation.hpp>", private, "<boost/spirit/home/support/detail/endian/endian.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/detail/sign.hpp>", private, "<boost/spirit/home/karma/numeric/detail/numeric_utils.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/detail/sign.hpp>", private, "<boost/spirit/home/karma/numeric/detail/real_utils.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/detail/sign.hpp>", private, "<boost/spirit/home/qi/numeric/detail/real_impl.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/fixed_size_queue.hpp>", private, "<boost/spirit/home/support/iterators/detail/fixed_size_queue_policy.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/input_iterator_policy.hpp>", private, "<boost/spirit/home/support/iterators/detail/buffering_input_iterator_policy.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/multi_pass.hpp>", private, "<boost/spirit/home/support/iterators/detail/buffering_input_iterator_policy.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/multi_pass.hpp>", private, "<boost/spirit/home/support/iterators/detail/buf_id_check_policy.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/multi_pass.hpp>", private, "<boost/spirit/home/support/iterators/detail/first_owner_policy.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/multi_pass.hpp>", private, "<boost/spirit/home/support/iterators/detail/fixed_size_queue_policy.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/multi_pass.hpp>", private, "<boost/spirit/home/support/iterators/detail/functor_input_policy.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/multi_pass.hpp>", private, "<boost/spirit/home/support/iterators/detail/input_iterator_policy.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/multi_pass.hpp>", private, "<boost/spirit/home/support/iterators/detail/istream_policy.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/multi_pass.hpp>", private, "<boost/spirit/home/support/iterators/detail/lex_input_policy.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/multi_pass.hpp>", private, "<boost/spirit/home/support/iterators/detail/no_check_policy.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/multi_pass.hpp>", private, "<boost/spirit/home/support/iterators/detail/ref_counted_policy.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/multi_pass.hpp>", private, "<boost/spirit/home/support/iterators/detail/split_functor_input_policy.hpp>", private ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/multi_pass.hpp>", private, "<boost/spirit/home/support/iterators/detail/split_std_deque_policy.hpp>", private ] },
+ { include: ["<boost/statechart/detail/avoid_unused_warning.hpp>", private, "<boost/statechart/detail/memory.hpp>", private ] },
+ { include: ["<boost/statechart/detail/counted_base.hpp>", private, "<boost/statechart/detail/state_base.hpp>", private ] },
+ { include: ["<boost/statechart/detail/state_base.hpp>", private, "<boost/statechart/detail/leaf_state.hpp>", private ] },
+ { include: ["<boost/statechart/detail/state_base.hpp>", private, "<boost/statechart/detail/node_state.hpp>", private ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/detail/global_typedef.hpp>", private ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/detail/unit_test_parameters.hpp>", private ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/detail/workaround.hpp>", private ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/detail/unit_test_parameters.hpp>", private ] },
+ { include: ["<boost/test/detail/log_level.hpp>", private, "<boost/test/detail/unit_test_parameters.hpp>", private ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/detail/global_typedef.hpp>", private ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/detail/unit_test_parameters.hpp>", private ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/detail/workaround.hpp>", private ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/log/detail/config.hpp>", private ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/detail/counter.hpp>", private ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/detail/force_cast.hpp>", private ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/detail/lockable_wrapper.hpp>", private ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/detail/log.hpp>", private ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/detail/move.hpp>", private ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/detail/singleton.hpp>", private ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/detail/thread.hpp>", private ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/detail/thread_interruption.hpp>", private ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/detail/tss_hooks.hpp>", private ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/detail/counter.hpp>", private ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/detail/move.hpp>", private ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/detail/thread_interruption.hpp>", private ] },
+ { include: ["<boost/thread/detail/invoke.hpp>", private, "<boost/thread/detail/async_func.hpp>", private ] },
+ { include: ["<boost/thread/detail/invoke.hpp>", private, "<boost/thread/detail/thread.hpp>", private ] },
+ { include: ["<boost/thread/detail/is_convertible.hpp>", private, "<boost/thread/detail/thread.hpp>", private ] },
+ { include: ["<boost/thread/detail/make_tuple_indices.hpp>", private, "<boost/thread/detail/async_func.hpp>", private ] },
+ { include: ["<boost/thread/detail/make_tuple_indices.hpp>", private, "<boost/thread/detail/thread.hpp>", private ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/detail/async_func.hpp>", private ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/detail/invoke.hpp>", private ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/detail/is_convertible.hpp>", private ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/detail/thread.hpp>", private ] },
+ { include: ["<boost/thread/detail/platform.hpp>", private, "<boost/thread/detail/config.hpp>", private ] },
+ { include: ["<boost/thread/detail/platform.hpp>", private, "<boost/thread/detail/thread_heap_alloc.hpp>", private ] },
+ { include: ["<boost/thread/detail/thread_heap_alloc.hpp>", private, "<boost/thread/detail/thread.hpp>", private ] },
+ { include: ["<boost/tr1/detail/config.hpp>", private, "<boost/tr1/detail/config_all.hpp>", private ] },
+ { include: ["<boost/tti/detail/dcomp_mem_fun.hpp>", private, "<boost/tti/detail/dmem_fun.hpp>", private ] },
+ { include: ["<boost/tti/detail/ddeftype.hpp>", private, "<boost/tti/detail/dmem_data.hpp>", private ] },
+ { include: ["<boost/tti/detail/ddeftype.hpp>", private, "<boost/tti/detail/dmem_fun.hpp>", private ] },
+ { include: ["<boost/tti/detail/ddeftype.hpp>", private, "<boost/tti/detail/dtype.hpp>", private ] },
+ { include: ["<boost/tti/detail/dftclass.hpp>", private, "<boost/tti/detail/dcomp_mem_fun.hpp>", private ] },
+ { include: ["<boost/tti/detail/dftclass.hpp>", private, "<boost/tti/detail/dmem_data.hpp>", private ] },
+ { include: ["<boost/tti/detail/dlambda.hpp>", private, "<boost/tti/detail/dtype.hpp>", private ] },
+ { include: ["<boost/tti/detail/dmem_data.hpp>", private, "<boost/tti/detail/ddata.hpp>", private ] },
+ { include: ["<boost/tti/detail/dmem_fun.hpp>", private, "<boost/tti/detail/dfunction.hpp>", private ] },
+ { include: ["<boost/tti/detail/dmem_fun.hpp>", private, "<boost/tti/detail/dmem_data.hpp>", private ] },
+ { include: ["<boost/tti/detail/dmetafunc.hpp>", private, "<boost/tti/detail/dlambda.hpp>", private ] },
+ { include: ["<boost/tti/detail/dnullptr.hpp>", private, "<boost/tti/detail/dcomp_mem_fun.hpp>", private ] },
+ { include: ["<boost/tti/detail/dnullptr.hpp>", private, "<boost/tti/detail/dcomp_static_mem_fun.hpp>", private ] },
+ { include: ["<boost/tti/detail/dnullptr.hpp>", private, "<boost/tti/detail/dmem_fun.hpp>", private ] },
+ { include: ["<boost/tti/detail/dnullptr.hpp>", private, "<boost/tti/detail/dstatic_mem_data.hpp>", private ] },
+ { include: ["<boost/tti/detail/dnullptr.hpp>", private, "<boost/tti/detail/dstatic_mem_fun.hpp>", private ] },
+ { include: ["<boost/tti/detail/dplaceholder.hpp>", private, "<boost/tti/detail/dlambda.hpp>", private ] },
+ { include: ["<boost/tti/detail/dptmf.hpp>", private, "<boost/tti/detail/dmem_fun.hpp>", private ] },
+ { include: ["<boost/tti/detail/dstatic_mem_data.hpp>", private, "<boost/tti/detail/ddata.hpp>", private ] },
+ { include: ["<boost/tti/detail/dstatic_mem_fun.hpp>", private, "<boost/tti/detail/dfunction.hpp>", private ] },
+ { include: ["<boost/tti/detail/dtemplate.hpp>", private, "<boost/tti/detail/dvm_template_params.hpp>", private ] },
+ { include: ["<boost/tti/detail/dtemplate_params.hpp>", private, "<boost/tti/detail/dvm_template_params.hpp>", private ] },
+ { include: ["<boost/tti/detail/dtfunction.hpp>", private, "<boost/tti/detail/dstatic_mem_fun.hpp>", private ] },
+ { include: ["<boost/type_erasure/detail/any_base.hpp>", private, "<boost/type_erasure/detail/access.hpp>", private ] },
+ { include: ["<boost/type_erasure/detail/get_placeholders.hpp>", private, "<boost/type_erasure/detail/check_map.hpp>", private ] },
+ { include: ["<boost/type_erasure/detail/get_placeholders.hpp>", private, "<boost/type_erasure/detail/normalize.hpp>", private ] },
+ { include: ["<boost/type_erasure/detail/get_signature.hpp>", private, "<boost/type_erasure/detail/adapt_to_vtable.hpp>", private ] },
+ { include: ["<boost/type_erasure/detail/normalize_deduced.hpp>", private, "<boost/type_erasure/detail/normalize.hpp>", private ] },
+ { include: ["<boost/type_erasure/detail/normalize.hpp>", private, "<boost/type_erasure/detail/instantiate.hpp>", private ] },
+ { include: ["<boost/type_erasure/detail/rebind_placeholders.hpp>", private, "<boost/type_erasure/detail/instantiate.hpp>", private ] },
+ { include: ["<boost/type_erasure/detail/rebind_placeholders.hpp>", private, "<boost/type_erasure/detail/normalize.hpp>", private ] },
+ { include: ["<boost/type_erasure/detail/rebind_placeholders.hpp>", private, "<boost/type_erasure/detail/vtable.hpp>", private ] },
+ { include: ["<boost/type_erasure/detail/storage.hpp>", private, "<boost/type_erasure/detail/access.hpp>", private ] },
+ { include: ["<boost/type_erasure/detail/storage.hpp>", private, "<boost/type_erasure/detail/adapt_to_vtable.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/iostreams/detail/is_dereferenceable.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/detail/has_binary_operator.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/detail/has_postfix_operator.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/detail/has_prefix_operator.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/variant/detail/bool_trait_def.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/detail/has_binary_operator.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/detail/has_postfix_operator.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/detail/has_prefix_operator.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/variant/detail/bool_trait_undef.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/range/detail/common.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/iostreams/detail/is_dereferenceable.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/type_traits/detail/bool_trait_def.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/type_traits/detail/size_t_trait_def.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/type_traits/detail/type_trait_def.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/iostreams/detail/bool_trait_def.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/range/detail/sfinae.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/tti/detail/dcomp_mem_fun.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/tti/detail/dcomp_static_mem_fun.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/tti/detail/dmem_data.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/tti/detail/dmem_fun.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/tti/detail/dstatic_mem_data.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/tti/detail/dstatic_mem_fun.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/detail/cv_traits_impl.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/detail/is_function_ptr_tester.hpp>", private ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/detail/is_mem_fun_pointer_tester.hpp>", private ] },
+ { include: ["<boost/units/detail/dimension_list.hpp>", private, "<boost/units/detail/conversion_impl.hpp>", private ] },
+ { include: ["<boost/units/detail/dimension_list.hpp>", private, "<boost/units/detail/dimension_impl.hpp>", private ] },
+ { include: ["<boost/units/detail/dimension_list.hpp>", private, "<boost/units/detail/linear_algebra.hpp>", private ] },
+ { include: ["<boost/units/detail/dimension_list.hpp>", private, "<boost/units/detail/sort.hpp>", private ] },
+ { include: ["<boost/units/detail/heterogeneous_conversion.hpp>", private, "<boost/units/detail/conversion_impl.hpp>", private ] },
+ { include: ["<boost/units/detail/linear_algebra.hpp>", private, "<boost/units/detail/heterogeneous_conversion.hpp>", private ] },
+ { include: ["<boost/units/detail/one.hpp>", private, "<boost/units/detail/conversion_impl.hpp>", private ] },
+ { include: ["<boost/units/detail/one.hpp>", private, "<boost/units/detail/static_rational_power.hpp>", private ] },
+ { include: ["<boost/units/detail/one.hpp>", private, "<boost/units/detail/unscale.hpp>", private ] },
+ { include: ["<boost/units/detail/one.hpp>", private, "<boost/units/systems/detail/constants.hpp>", private ] },
+ { include: ["<boost/units/detail/push_front_if.hpp>", private, "<boost/units/detail/dimension_impl.hpp>", private ] },
+ { include: ["<boost/units/detail/push_front_if.hpp>", private, "<boost/units/detail/push_front_or_add.hpp>", private ] },
+ { include: ["<boost/units/detail/push_front_or_add.hpp>", private, "<boost/units/detail/dimension_impl.hpp>", private ] },
+ { include: ["<boost/units/detail/sort.hpp>", private, "<boost/units/detail/linear_algebra.hpp>", private ] },
+ { include: ["<boost/units/detail/static_rational_power.hpp>", private, "<boost/units/detail/conversion_impl.hpp>", private ] },
+ { include: ["<boost/units/detail/unscale.hpp>", private, "<boost/units/detail/conversion_impl.hpp>", private ] },
+ { include: ["<boost/unordered/detail/allocate.hpp>", private, "<boost/unordered/detail/buckets.hpp>", private ] },
+ { include: ["<boost/unordered/detail/buckets.hpp>", private, "<boost/unordered/detail/table.hpp>", private ] },
+ { include: ["<boost/unordered/detail/extract_key.hpp>", private, "<boost/unordered/detail/equivalent.hpp>", private ] },
+ { include: ["<boost/unordered/detail/extract_key.hpp>", private, "<boost/unordered/detail/unique.hpp>", private ] },
+ { include: ["<boost/unordered/detail/fwd.hpp>", private, "<boost/unordered/detail/allocate.hpp>", private ] },
+ { include: ["<boost/unordered/detail/table.hpp>", private, "<boost/unordered/detail/equivalent.hpp>", private ] },
+ { include: ["<boost/unordered/detail/table.hpp>", private, "<boost/unordered/detail/extract_key.hpp>", private ] },
+ { include: ["<boost/unordered/detail/table.hpp>", private, "<boost/unordered/detail/unique.hpp>", private ] },
+ { include: ["<boost/unordered/detail/util.hpp>", private, "<boost/unordered/detail/buckets.hpp>", private ] },
+ { include: ["<boost/unordered/detail/util.hpp>", private, "<boost/unordered/detail/table.hpp>", private ] },
+ { include: ["<boost/variant/detail/apply_visitor_binary.hpp>", private, "<boost/variant/detail/apply_visitor_delayed.hpp>", private ] },
+ { include: ["<boost/variant/detail/apply_visitor_unary.hpp>", private, "<boost/variant/detail/apply_visitor_binary.hpp>", private ] },
+ { include: ["<boost/variant/detail/apply_visitor_unary.hpp>", private, "<boost/variant/detail/apply_visitor_delayed.hpp>", private ] },
+ { include: ["<boost/variant/detail/backup_holder.hpp>", private, "<boost/variant/detail/visitation_impl.hpp>", private ] },
+ { include: ["<boost/variant/detail/cast_storage.hpp>", private, "<boost/variant/detail/visitation_impl.hpp>", private ] },
+ { include: ["<boost/variant/detail/enable_recursive_fwd.hpp>", private, "<boost/variant/detail/enable_recursive.hpp>", private ] },
+ { include: ["<boost/variant/detail/forced_return.hpp>", private, "<boost/variant/detail/visitation_impl.hpp>", private ] },
+ { include: ["<boost/variant/detail/generic_result_type.hpp>", private, "<boost/variant/detail/apply_visitor_binary.hpp>", private ] },
+ { include: ["<boost/variant/detail/generic_result_type.hpp>", private, "<boost/variant/detail/apply_visitor_delayed.hpp>", private ] },
+ { include: ["<boost/variant/detail/generic_result_type.hpp>", private, "<boost/variant/detail/apply_visitor_unary.hpp>", private ] },
+ { include: ["<boost/variant/detail/generic_result_type.hpp>", private, "<boost/variant/detail/forced_return.hpp>", private ] },
+ { include: ["<boost/variant/detail/generic_result_type.hpp>", private, "<boost/variant/detail/visitation_impl.hpp>", private ] },
+ { include: ["<boost/variant/detail/move.hpp>", private, "<boost/variant/detail/initializer.hpp>", private ] },
+ { include: ["<boost/variant/detail/substitute_fwd.hpp>", private, "<boost/variant/detail/substitute.hpp>", private ] },
+ { include: ["<boost/variant/detail/substitute.hpp>", private, "<boost/variant/detail/enable_recursive.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/access.hpp>", private, "<boost/xpressive/detail/core/results_cache.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/access.hpp>", private, "<boost/xpressive/detail/core/state.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/action.hpp>", private, "<boost/xpressive/detail/core/matcher/action_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/action.hpp>", private, "<boost/xpressive/detail/core/state.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/adaptor.hpp>", private, "<boost/xpressive/detail/core/matcher/regex_byref_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/adaptor.hpp>", private, "<boost/xpressive/detail/core/matcher/regex_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/adaptor.hpp>", private, "<boost/xpressive/detail/static/compile.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/finder.hpp>", private, "<boost/xpressive/detail/core/optimize.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/flow_control.hpp>", private, "<boost/xpressive/detail/core/matcher/end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/icase.hpp>", private, "<boost/xpressive/detail/dynamic/dynamic.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/linker.hpp>", private, "<boost/xpressive/detail/core/icase.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/linker.hpp>", private, "<boost/xpressive/detail/core/optimize.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/linker.hpp>", private, "<boost/xpressive/detail/static/compile.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/linker.hpp>", private, "<boost/xpressive/detail/static/static.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/list.hpp>", private, "<boost/xpressive/detail/core/results_cache.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/action_matcher.hpp>", private, "<boost/xpressive/detail/core/matcher/predicate_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/alternate_end_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/alternate_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/alternate_matcher.hpp>", private, "<boost/xpressive/detail/static/transforms/as_alternate.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/any_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/assert_bol_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/assert_bos_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/assert_eol_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/assert_eos_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/assert_line_base.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_bol_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/assert_line_base.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_eol_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/assert_word_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/attr_end_matcher.hpp>", private, "<boost/xpressive/detail/static/transforms/as_action.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/attr_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/charset_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/end_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/end_matcher.hpp>", private, "<boost/xpressive/detail/static/compile.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/epsilon_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/keeper_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/literal_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/logical_newline_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/lookahead_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/lookbehind_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/mark_begin_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/mark_begin_matcher.hpp>", private, "<boost/xpressive/detail/static/visitor.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/mark_end_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/mark_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/optional_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/posix_charset_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/range_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/regex_byref_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/regex_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/repeat_begin_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/repeat_end_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/set_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matchers.hpp>", private, "<boost/xpressive/detail/core/linker.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matchers.hpp>", private, "<boost/xpressive/detail/core/peeker.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matchers.hpp>", private, "<boost/xpressive/detail/dynamic/parser.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matchers.hpp>", private, "<boost/xpressive/detail/static/transmogrify.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/simple_repeat_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/string_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/matcher/true_matcher.hpp>", private, "<boost/xpressive/detail/core/matchers.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/optimize.hpp>", private, "<boost/xpressive/detail/static/compile.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/peeker.hpp>", private, "<boost/xpressive/detail/core/linker.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/peeker.hpp>", private, "<boost/xpressive/detail/core/optimize.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/peeker.hpp>", private, "<boost/xpressive/detail/static/static.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/action_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/alternate_end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/alternate_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/any_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_bol_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_bos_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_eol_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_eos_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_line_base.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_word_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/attr_begin_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/attr_end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/attr_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/charset_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/epsilon_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/keeper_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/literal_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/logical_newline_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/lookahead_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/lookbehind_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/mark_begin_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/mark_end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/mark_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/optional_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/posix_charset_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/predicate_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/range_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/regex_byref_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/regex_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/repeat_begin_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/repeat_end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/set_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/simple_repeat_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/string_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/core/matcher/true_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/dynamic/dynamic.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/dynamic/matchable.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/quant_style.hpp>", private, "<boost/xpressive/detail/static/placeholders.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/regex_impl.hpp>", private, "<boost/xpressive/detail/core/finder.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/regex_impl.hpp>", private, "<boost/xpressive/detail/core/flow_control.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/regex_impl.hpp>", private, "<boost/xpressive/detail/core/matcher/regex_byref_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/regex_impl.hpp>", private, "<boost/xpressive/detail/core/matcher/regex_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/regex_impl.hpp>", private, "<boost/xpressive/detail/core/optimize.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/regex_impl.hpp>", private, "<boost/xpressive/detail/core/state.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/regex_impl.hpp>", private, "<boost/xpressive/detail/static/compile.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/regex_impl.hpp>", private, "<boost/xpressive/detail/static/placeholders.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/regex_impl.hpp>", private, "<boost/xpressive/detail/static/visitor.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/flow_control.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/action_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/alternate_end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/alternate_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/any_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_bol_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_bos_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_eol_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_eos_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_line_base.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_word_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/attr_begin_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/attr_end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/attr_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/charset_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/epsilon_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/keeper_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/literal_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/logical_newline_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/lookahead_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/lookbehind_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/mark_begin_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/mark_end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/mark_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/optional_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/posix_charset_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/predicate_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/range_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/regex_byref_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/regex_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/repeat_begin_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/repeat_end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/set_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/simple_repeat_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/string_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/core/matcher/true_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/detail/static/static.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/sub_match_impl.hpp>", private, "<boost/xpressive/detail/core/matcher/end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/sub_match_impl.hpp>", private, "<boost/xpressive/detail/core/sub_match_vector.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/core/sub_match_vector.hpp>", private, "<boost/xpressive/detail/core/state.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/access.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/action.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/adaptor.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/finder.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/flow_control.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/icase.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/linker.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/action_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/alternate_end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/alternate_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/any_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_bol_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_bos_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_eol_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_eos_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_line_base.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_word_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/attr_begin_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/attr_end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/attr_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/charset_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/epsilon_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/keeper_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/literal_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/logical_newline_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/lookahead_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/lookbehind_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/mark_begin_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/mark_end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/mark_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/optional_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/posix_charset_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/predicate_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/range_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/regex_byref_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/regex_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/repeat_begin_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/repeat_end_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/set_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/simple_repeat_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/string_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/matcher/true_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/peeker.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/quant_style.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/regex_impl.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/results_cache.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/state.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/core/sub_match_vector.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/dynamic/dynamic.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/dynamic/matchable.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/dynamic/parse_charset.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/dynamic/parser.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/dynamic/parser_traits.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/dynamic/sequence.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/grammar.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/is_pure.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/modifier.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/static.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/transforms/as_action.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/transforms/as_alternate.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/transforms/as_independent.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/transforms/as_inverse.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/transforms/as_marker.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/transforms/as_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/transforms/as_modifier.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/transforms/as_quantifier.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/transforms/as_sequence.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/transforms/as_set.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/transmogrify.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/type_traits.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/visitor.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/static/width_of.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/utility/boyer_moore.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/detail/utility/chset/chset.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/traits/detail/c_ctype.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/dynamic/dynamic.hpp>", private, "<boost/xpressive/detail/dynamic/parser.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/dynamic/matchable.hpp>", private, "<boost/xpressive/detail/core/access.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/dynamic/matchable.hpp>", private, "<boost/xpressive/detail/core/adaptor.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/dynamic/matchable.hpp>", private, "<boost/xpressive/detail/core/linker.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/dynamic/matchable.hpp>", private, "<boost/xpressive/detail/core/matcher/alternate_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/dynamic/matchable.hpp>", private, "<boost/xpressive/detail/core/regex_impl.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/dynamic/matchable.hpp>", private, "<boost/xpressive/detail/dynamic/dynamic.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/dynamic/matchable.hpp>", private, "<boost/xpressive/detail/dynamic/parser_traits.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/dynamic/parser_enum.hpp>", private, "<boost/xpressive/detail/dynamic/parse_charset.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/dynamic/parser_enum.hpp>", private, "<boost/xpressive/detail/dynamic/parser_traits.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/dynamic/sequence.hpp>", private, "<boost/xpressive/detail/dynamic/dynamic.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/dynamic/sequence.hpp>", private, "<boost/xpressive/detail/dynamic/matchable.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/grammar.hpp>", private, "<boost/xpressive/detail/static/compile.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/is_pure.hpp>", private, "<boost/xpressive/detail/static/grammar.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/modifier.hpp>", private, "<boost/xpressive/detail/core/icase.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/placeholders.hpp>", private, "<boost/xpressive/detail/static/static.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/placeholders.hpp>", private, "<boost/xpressive/detail/static/transmogrify.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/static.hpp>", private, "<boost/xpressive/detail/static/compile.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/static.hpp>", private, "<boost/xpressive/detail/static/transforms/as_action.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/static.hpp>", private, "<boost/xpressive/detail/static/transforms/as_alternate.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/static.hpp>", private, "<boost/xpressive/detail/static/transforms/as_independent.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/static.hpp>", private, "<boost/xpressive/detail/static/transforms/as_inverse.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/static.hpp>", private, "<boost/xpressive/detail/static/transforms/as_marker.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/static.hpp>", private, "<boost/xpressive/detail/static/transforms/as_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/static.hpp>", private, "<boost/xpressive/detail/static/transforms/as_modifier.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/static.hpp>", private, "<boost/xpressive/detail/static/transforms/as_quantifier.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/static.hpp>", private, "<boost/xpressive/detail/static/transforms/as_sequence.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/static.hpp>", private, "<boost/xpressive/detail/static/transforms/as_set.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/transforms/as_action.hpp>", private, "<boost/xpressive/detail/core/matcher/action_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/transforms/as_action.hpp>", private, "<boost/xpressive/detail/static/grammar.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/transforms/as_alternate.hpp>", private, "<boost/xpressive/detail/static/grammar.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/transforms/as_independent.hpp>", private, "<boost/xpressive/detail/static/grammar.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/transforms/as_inverse.hpp>", private, "<boost/xpressive/detail/static/grammar.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/transforms/as_marker.hpp>", private, "<boost/xpressive/detail/static/grammar.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/transforms/as_matcher.hpp>", private, "<boost/xpressive/detail/static/grammar.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/transforms/as_modifier.hpp>", private, "<boost/xpressive/detail/static/grammar.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/transforms/as_quantifier.hpp>", private, "<boost/xpressive/detail/static/grammar.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/transforms/as_quantifier.hpp>", private, "<boost/xpressive/detail/static/transforms/as_action.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/transforms/as_sequence.hpp>", private, "<boost/xpressive/detail/static/grammar.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/transforms/as_set.hpp>", private, "<boost/xpressive/detail/static/grammar.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/transmogrify.hpp>", private, "<boost/xpressive/detail/static/visitor.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/type_traits.hpp>", private, "<boost/xpressive/detail/core/matcher/simple_repeat_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/type_traits.hpp>", private, "<boost/xpressive/detail/static/width_of.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/visitor.hpp>", private, "<boost/xpressive/detail/static/compile.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/static/width_of.hpp>", private, "<boost/xpressive/detail/static/is_pure.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/algorithm.hpp>", private, "<boost/xpressive/detail/core/matcher/alternate_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/algorithm.hpp>", private, "<boost/xpressive/detail/core/matcher/lookbehind_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/algorithm.hpp>", private, "<boost/xpressive/detail/core/matcher/string_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/algorithm.hpp>", private, "<boost/xpressive/detail/core/peeker.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/algorithm.hpp>", private, "<boost/xpressive/detail/dynamic/parser_traits.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/algorithm.hpp>", private, "<boost/xpressive/detail/utility/chset/chset.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/algorithm.hpp>", private, "<boost/xpressive/detail/utility/traits_utils.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/any.hpp>", private, "<boost/xpressive/detail/core/matcher/alternate_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/boyer_moore.hpp>", private, "<boost/xpressive/detail/core/finder.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/chset/basic_chset.ipp>", private, "<boost/xpressive/detail/utility/chset/chset.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/chset/basic_chset.ipp>", private, "<boost/xpressive/detail/utility/hash_peek_bitset.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/chset/chset.hpp>", private, "<boost/xpressive/detail/dynamic/parse_charset.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/chset/chset.hpp>", private, "<boost/xpressive/detail/static/transforms/as_set.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/chset/range_run.ipp>", private, "<boost/xpressive/detail/utility/chset/basic_chset.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/cons.hpp>", private, "<boost/xpressive/detail/static/transforms/as_alternate.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/counted_base.hpp>", private, "<boost/xpressive/detail/core/regex_impl.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/counted_base.hpp>", private, "<boost/xpressive/detail/dynamic/matchable.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/dont_care.hpp>", private, "<boost/xpressive/detail/static/transmogrify.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/hash_peek_bitset.hpp>", private, "<boost/xpressive/detail/core/finder.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/hash_peek_bitset.hpp>", private, "<boost/xpressive/detail/core/matcher/alternate_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/hash_peek_bitset.hpp>", private, "<boost/xpressive/detail/core/peeker.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/ignore_unused.hpp>", private, "<boost/xpressive/detail/core/flow_control.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/ignore_unused.hpp>", private, "<boost/xpressive/detail/core/icase.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/ignore_unused.hpp>", private, "<boost/xpressive/detail/core/matcher/assert_word_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/ignore_unused.hpp>", private, "<boost/xpressive/detail/core/matcher/lookahead_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/ignore_unused.hpp>", private, "<boost/xpressive/detail/core/matcher/lookbehind_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/ignore_unused.hpp>", private, "<boost/xpressive/detail/dynamic/parser.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/ignore_unused.hpp>", private, "<boost/xpressive/detail/utility/algorithm.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/literals.hpp>", private, "<boost/xpressive/detail/dynamic/parse_charset.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/literals.hpp>", private, "<boost/xpressive/detail/dynamic/parser_traits.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/never_true.hpp>", private, "<boost/xpressive/detail/core/linker.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/never_true.hpp>", private, "<boost/xpressive/detail/core/peeker.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/save_restore.hpp>", private, "<boost/xpressive/detail/core/matcher/lookahead_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/save_restore.hpp>", private, "<boost/xpressive/detail/core/matcher/lookbehind_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/sequence_stack.hpp>", private, "<boost/xpressive/detail/core/state.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/symbols.hpp>", private, "<boost/xpressive/detail/core/matcher/attr_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/tracking_ptr.hpp>", private, "<boost/xpressive/detail/core/regex_impl.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/traits_utils.hpp>", private, "<boost/xpressive/detail/core/matcher/literal_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/traits_utils.hpp>", private, "<boost/xpressive/detail/core/matcher/mark_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/traits_utils.hpp>", private, "<boost/xpressive/detail/core/matcher/posix_charset_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/traits_utils.hpp>", private, "<boost/xpressive/detail/core/matcher/string_matcher.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/traits_utils.hpp>", private, "<boost/xpressive/detail/static/transforms/as_set.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/traits_utils.hpp>", private, "<boost/xpressive/detail/static/transmogrify.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/width.hpp>", private, "<boost/xpressive/detail/core/quant_style.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/width.hpp>", private, "<boost/xpressive/detail/dynamic/sequence.hpp>", private ] },
+ { include: ["<boost/xpressive/detail/utility/width.hpp>", private, "<boost/xpressive/detail/static/static.hpp>", private ] }
+]
diff --git a/src/arrow/cpp/build-support/iwyu/mappings/boost-all.imp b/src/arrow/cpp/build-support/iwyu/mappings/boost-all.imp
new file mode 100644
index 000000000..5427ae2ac
--- /dev/null
+++ b/src/arrow/cpp/build-support/iwyu/mappings/boost-all.imp
@@ -0,0 +1,5679 @@
+# This file has been imported into the apache source tree from
+# the IWYU source tree as of version 0.8
+# https://github.com/include-what-you-use/include-what-you-use/blob/master/boost-all.imp
+# and corresponding license has been added:
+# https://github.com/include-what-you-use/include-what-you-use/blob/master/LICENSE.TXT
+#
+# ==============================================================================
+# LLVM Release License
+# ==============================================================================
+# University of Illinois/NCSA
+# Open Source License
+#
+# Copyright (c) 2003-2010 University of Illinois at Urbana-Champaign.
+# All rights reserved.
+#
+# Developed by:
+#
+# LLVM Team
+#
+# University of Illinois at Urbana-Champaign
+#
+# http://llvm.org
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal with
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is furnished to do
+# so, subject to the following conditions:
+#
+# * Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimers.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimers in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the names of the LLVM Team, University of Illinois at
+# Urbana-Champaign, nor the names of its contributors may be used to
+# endorse or promote products derived from this Software without specific
+# prior written permission.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE
+# SOFTWARE.
+
+[
+# cd /usr/include && grep -r --exclude-dir={detail,impl} '^ *# *include' boost/ | perl -nle 'm/^([^:]+).*["<]([^>]+)[">]/ && print qq@ { include: ["<$2>", private, "<$1>", public ] },@' | grep -e \/detail\/ -e \/impl\/ | grep -e \\[\"\<boost/ | sort -u
+#manually include:
+{ include: ["@<boost/bind/.*>", private, "<boost/bind.hpp>", public ] },
+{ include: ["@<boost/format/.*>", private, "<boost/format.hpp>", public ] },
+{ include: ["@<boost/filesystem/.*>", private, "<boost/filesystem.hpp>", public ] },
+{ include: ["@<boost/function/.*>", private, "<boost/function.hpp>", public ] },
+#manually delete $ sed '/workarounds*\.hpp/d' -i boost-all.imp
+#also good idea to remove all lines referring to folders above (e.g., sed '/\/format\//d' -i boost-all.imp)
+#programatically include:
+ { include: ["<boost/accumulators/numeric/detail/function1.hpp>", private, "<boost/accumulators/numeric/functional.hpp>", public ] },
+ { include: ["<boost/accumulators/numeric/detail/function2.hpp>", private, "<boost/accumulators/numeric/functional.hpp>", public ] },
+ { include: ["<boost/accumulators/numeric/detail/pod_singleton.hpp>", private, "<boost/accumulators/numeric/functional.hpp>", public ] },
+ { include: ["<boost/algorithm/searching/detail/bm_traits.hpp>", private, "<boost/algorithm/searching/boyer_moore_horspool.hpp>", public ] },
+ { include: ["<boost/algorithm/searching/detail/bm_traits.hpp>", private, "<boost/algorithm/searching/boyer_moore.hpp>", public ] },
+ { include: ["<boost/algorithm/searching/detail/debugging.hpp>", private, "<boost/algorithm/searching/boyer_moore_horspool.hpp>", public ] },
+ { include: ["<boost/algorithm/searching/detail/debugging.hpp>", private, "<boost/algorithm/searching/boyer_moore.hpp>", public ] },
+ { include: ["<boost/algorithm/searching/detail/debugging.hpp>", private, "<boost/algorithm/searching/knuth_morris_pratt.hpp>", public ] },
+ { include: ["<boost/algorithm/string/detail/case_conv.hpp>", private, "<boost/algorithm/string/case_conv.hpp>", public ] },
+ { include: ["<boost/algorithm/string/detail/classification.hpp>", private, "<boost/algorithm/string/classification.hpp>", public ] },
+ { include: ["<boost/algorithm/string/detail/finder.hpp>", private, "<boost/algorithm/string/finder.hpp>", public ] },
+ { include: ["<boost/algorithm/string/detail/finder_regex.hpp>", private, "<boost/algorithm/string/regex_find_format.hpp>", public ] },
+ { include: ["<boost/algorithm/string/detail/find_format_all.hpp>", private, "<boost/algorithm/string/find_format.hpp>", public ] },
+ { include: ["<boost/algorithm/string/detail/find_format.hpp>", private, "<boost/algorithm/string/find_format.hpp>", public ] },
+ { include: ["<boost/algorithm/string/detail/find_iterator.hpp>", private, "<boost/algorithm/string/find_iterator.hpp>", public ] },
+ { include: ["<boost/algorithm/string/detail/formatter.hpp>", private, "<boost/algorithm/string/formatter.hpp>", public ] },
+ { include: ["<boost/algorithm/string/detail/formatter_regex.hpp>", private, "<boost/algorithm/string/regex_find_format.hpp>", public ] },
+ { include: ["<boost/algorithm/string/detail/predicate.hpp>", private, "<boost/algorithm/string/predicate.hpp>", public ] },
+ { include: ["<boost/algorithm/string/detail/sequence.hpp>", private, "<boost/algorithm/string/join.hpp>", public ] },
+ { include: ["<boost/algorithm/string/detail/trim.hpp>", private, "<boost/algorithm/string/trim.hpp>", public ] },
+ { include: ["<boost/algorithm/string/detail/util.hpp>", private, "<boost/algorithm/string/iter_find.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/archive_exception.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/basic_archive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/basic_binary_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/basic_binary_iprimitive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/basic_binary_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/basic_binary_oprimitive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/basic_text_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/basic_text_iprimitive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/basic_text_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/basic_text_oprimitive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/basic_xml_archive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/basic_xml_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/basic_xml_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/codecvt_null.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/polymorphic_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/polymorphic_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/shared_ptr_helper.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/text_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/text_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/text_wiarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/text_woarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/xml_archive_exception.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/xml_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/xml_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/xml_wiarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_prefix.hpp>", private, "<boost/archive/xml_woarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/archive_exception.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/basic_archive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/basic_binary_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/basic_binary_iprimitive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/basic_binary_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/basic_binary_oprimitive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/basic_text_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/basic_text_iprimitive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/basic_text_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/basic_text_oprimitive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/basic_xml_archive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/basic_xml_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/basic_xml_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/codecvt_null.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/polymorphic_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/polymorphic_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/shared_ptr_helper.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/text_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/text_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/text_wiarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/text_woarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/xml_archive_exception.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/xml_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/xml_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/xml_wiarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/abi_suffix.hpp>", private, "<boost/archive/xml_woarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/basic_archive.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/basic_binary_iprimitive.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/basic_binary_oprimitive.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/basic_xml_archive.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/codecvt_null.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/text_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/text_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/xml_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/archive/xml_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/mpi/packed_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/mpi/packed_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_archive.hpp>", private, "<boost/mpi/skeleton_and_content.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_warchive.hpp>", private, "<boost/archive/text_wiarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_warchive.hpp>", private, "<boost/archive/text_woarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_warchive.hpp>", private, "<boost/archive/xml_wiarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/auto_link_warchive.hpp>", private, "<boost/archive/xml_woarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/basic_iarchive.hpp>", private, "<boost/serialization/collections_load_imp.hpp>", public ] },
+ { include: ["<boost/archive/detail/basic_iarchive.hpp>", private, "<boost/serialization/hash_collections_load_imp.hpp>", public ] },
+ { include: ["<boost/archive/detail/basic_iarchive.hpp>", private, "<boost/serialization/optional.hpp>", public ] },
+ { include: ["<boost/archive/detail/common_iarchive.hpp>", private, "<boost/archive/basic_binary_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/common_iarchive.hpp>", private, "<boost/archive/basic_text_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/common_iarchive.hpp>", private, "<boost/archive/basic_xml_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/common_iarchive.hpp>", private, "<boost/mpi/packed_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/common_oarchive.hpp>", private, "<boost/archive/basic_binary_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/common_oarchive.hpp>", private, "<boost/archive/basic_text_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/common_oarchive.hpp>", private, "<boost/archive/basic_xml_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/common_oarchive.hpp>", private, "<boost/mpi/packed_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/decl.hpp>", private, "<boost/archive/archive_exception.hpp>", public ] },
+ { include: ["<boost/archive/detail/decl.hpp>", private, "<boost/archive/polymorphic_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/decl.hpp>", private, "<boost/archive/polymorphic_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/decl.hpp>", private, "<boost/archive/shared_ptr_helper.hpp>", public ] },
+ { include: ["<boost/archive/detail/decl.hpp>", private, "<boost/archive/xml_archive_exception.hpp>", public ] },
+ { include: ["<boost/archive/detail/interface_iarchive.hpp>", private, "<boost/archive/polymorphic_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/interface_oarchive.hpp>", private, "<boost/archive/polymorphic_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/iserializer.hpp>", private, "<boost/archive/polymorphic_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/oserializer.hpp>", private, "<boost/archive/polymorphic_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/polymorphic_iarchive_route.hpp>", private, "<boost/archive/polymorphic_binary_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/polymorphic_iarchive_route.hpp>", private, "<boost/archive/polymorphic_text_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/polymorphic_iarchive_route.hpp>", private, "<boost/archive/polymorphic_text_wiarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/polymorphic_iarchive_route.hpp>", private, "<boost/archive/polymorphic_xml_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/polymorphic_iarchive_route.hpp>", private, "<boost/archive/polymorphic_xml_wiarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/polymorphic_oarchive_route.hpp>", private, "<boost/archive/polymorphic_binary_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/polymorphic_oarchive_route.hpp>", private, "<boost/archive/polymorphic_text_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/polymorphic_oarchive_route.hpp>", private, "<boost/archive/polymorphic_text_woarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/polymorphic_oarchive_route.hpp>", private, "<boost/archive/polymorphic_xml_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/polymorphic_oarchive_route.hpp>", private, "<boost/archive/polymorphic_xml_woarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/archive/binary_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/archive/binary_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/archive/binary_wiarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/archive/binary_woarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/archive/polymorphic_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/archive/polymorphic_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/archive/text_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/archive/text_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/archive/text_wiarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/archive/text_woarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/archive/xml_iarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/archive/xml_oarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/archive/xml_wiarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/archive/xml_woarchive.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/mpi/skeleton_and_content.hpp>", public ] },
+ { include: ["<boost/archive/detail/register_archive.hpp>", private, "<boost/serialization/export.hpp>", public ] },
+ { include: ["<boost/asio/detail/array_fwd.hpp>", private, "<boost/asio/buffer.hpp>", public ] },
+ { include: ["<boost/asio/detail/array.hpp>", private, "<boost/asio/basic_socket_streambuf.hpp>", public ] },
+ { include: ["<boost/asio/detail/array.hpp>", private, "<boost/asio/ip/address_v4.hpp>", public ] },
+ { include: ["<boost/asio/detail/array.hpp>", private, "<boost/asio/ip/address_v6.hpp>", public ] },
+ { include: ["<boost/asio/detail/assert.hpp>", private, "<boost/asio/buffers_iterator.hpp>", public ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/buffered_read_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/bind_handler.hpp>", private, "<boost/asio/buffered_write_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/buffered_stream_storage.hpp>", private, "<boost/asio/buffered_read_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/buffered_stream_storage.hpp>", private, "<boost/asio/buffered_write_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/buffer_resize_guard.hpp>", private, "<boost/asio/buffered_read_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/buffer_sequence_adapter.hpp>", private, "<boost/asio/ssl/stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/chrono_time_traits.hpp>", private, "<boost/asio/waitable_timer_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/async_result.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/basic_datagram_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/basic_deadline_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/basic_io_object.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/basic_raw_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/basic_seq_packet_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/basic_serial_port.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/basic_signal_set.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/basic_socket_acceptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/basic_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/basic_socket_iostream.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/basic_socket_streambuf.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/basic_streambuf_fwd.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/basic_streambuf.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/basic_stream_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/basic_waitable_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/buffered_read_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/buffered_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/buffered_write_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/buffer.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/buffers_iterator.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/completion_condition.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/connect.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/datagram_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/deadline_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/deadline_timer_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/error.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/generic/basic_endpoint.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/generic/datagram_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/generic/raw_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/generic/seq_packet_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/generic/stream_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/handler_alloc_hook.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/handler_continuation_hook.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/handler_invoke_hook.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/handler_type.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/high_resolution_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/io_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/address.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/address_v4.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/address_v6.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/basic_endpoint.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/basic_resolver_entry.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/basic_resolver.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/basic_resolver_iterator.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/basic_resolver_query.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/host_name.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/icmp.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/multicast.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/resolver_query_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/resolver_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/tcp.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/udp.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/unicast.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ip/v6_only.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/is_read_buffered.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/is_write_buffered.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/local/basic_endpoint.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/local/connect_pair.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/local/datagram_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/local/stream_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/placeholders.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/posix/basic_descriptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/posix/basic_stream_descriptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/posix/descriptor_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/posix/stream_descriptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/posix/stream_descriptor_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/raw_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/read_at.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/read.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/read_until.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/seq_packet_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/serial_port_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/serial_port.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/serial_port_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/signal_set.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/signal_set_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/socket_acceptor_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/socket_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/spawn.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/basic_context.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/context_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/context.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/context_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/error.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/old/basic_context.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/old/context_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/old/stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/old/stream_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/rfc2818_verification.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/stream_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/stream_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/verify_context.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/ssl/verify_mode.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/steady_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/strand.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/streambuf.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/stream_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/system_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/use_future.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/waitable_timer_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/windows/basic_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/windows/basic_object_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/windows/basic_random_access_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/windows/basic_stream_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/windows/object_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/windows/object_handle_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/windows/overlapped_ptr.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/windows/random_access_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/windows/random_access_handle_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/windows/stream_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/windows/stream_handle_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/write_at.hpp>", public ] },
+ { include: ["<boost/asio/detail/config.hpp>", private, "<boost/asio/write.hpp>", public ] },
+ { include: ["<boost/asio/detail/cstdint.hpp>", private, "<boost/asio/read_at.hpp>", public ] },
+ { include: ["<boost/asio/detail/cstdint.hpp>", private, "<boost/asio/windows/random_access_handle_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/cstdint.hpp>", private, "<boost/asio/write_at.hpp>", public ] },
+ { include: ["<boost/asio/detail/deadline_timer_service.hpp>", private, "<boost/asio/deadline_timer_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/deadline_timer_service.hpp>", private, "<boost/asio/waitable_timer_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/function.hpp>", private, "<boost/asio/buffer.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/basic_datagram_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/basic_deadline_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/basic_raw_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/basic_seq_packet_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/basic_serial_port.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/basic_signal_set.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/basic_socket_acceptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/basic_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/basic_stream_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/basic_waitable_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/ip/basic_resolver.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/posix/basic_stream_descriptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/ssl/stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/strand.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/windows/basic_random_access_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/handler_type_requirements.hpp>", private, "<boost/asio/windows/basic_stream_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/io_control.hpp>", private, "<boost/asio/posix/descriptor_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/io_control.hpp>", private, "<boost/asio/socket_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/limits.hpp>", private, "<boost/asio/basic_streambuf.hpp>", public ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/basic_streambuf.hpp>", public ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/buffered_read_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/buffered_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/buffered_write_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/io_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/ssl/stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/ssl/verify_context.hpp>", public ] },
+ { include: ["<boost/asio/detail/noncopyable.hpp>", private, "<boost/asio/windows/overlapped_ptr.hpp>", public ] },
+ { include: ["<boost/asio/detail/null_socket_service.hpp>", private, "<boost/asio/datagram_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/null_socket_service.hpp>", private, "<boost/asio/raw_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/null_socket_service.hpp>", private, "<boost/asio/seq_packet_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/null_socket_service.hpp>", private, "<boost/asio/socket_acceptor_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/async_result.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/basic_datagram_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/basic_deadline_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/basic_io_object.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/basic_raw_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/basic_seq_packet_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/basic_serial_port.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/basic_signal_set.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/basic_socket_acceptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/basic_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/basic_socket_iostream.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/basic_socket_streambuf.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/basic_streambuf.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/basic_stream_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/basic_waitable_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/buffered_read_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/buffered_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/buffered_write_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/buffer.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/buffers_iterator.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/completion_condition.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/connect.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/datagram_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/deadline_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/deadline_timer_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/error.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/generic/basic_endpoint.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/generic/datagram_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/generic/raw_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/generic/seq_packet_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/generic/stream_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/handler_alloc_hook.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/handler_continuation_hook.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/handler_invoke_hook.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/handler_type.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/io_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/address.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/address_v4.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/address_v6.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/basic_endpoint.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/basic_resolver_entry.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/basic_resolver.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/basic_resolver_iterator.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/basic_resolver_query.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/host_name.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/icmp.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/multicast.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/resolver_query_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/resolver_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/tcp.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/udp.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/unicast.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ip/v6_only.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/is_read_buffered.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/is_write_buffered.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/local/basic_endpoint.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/local/connect_pair.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/local/datagram_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/local/stream_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/placeholders.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/posix/basic_descriptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/posix/basic_stream_descriptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/posix/descriptor_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/posix/stream_descriptor_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/raw_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/read_at.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/read.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/read_until.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/seq_packet_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/serial_port_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/serial_port_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/signal_set_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/socket_acceptor_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/socket_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/spawn.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/basic_context.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/context_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/context.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/context_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/error.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/old/basic_context.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/old/context_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/old/stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/old/stream_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/rfc2818_verification.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/stream_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/stream_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/verify_context.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/ssl/verify_mode.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/strand.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/stream_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/time_traits.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/use_future.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/waitable_timer_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/wait_traits.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/windows/basic_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/windows/basic_object_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/windows/basic_random_access_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/windows/basic_stream_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/windows/object_handle_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/windows/overlapped_ptr.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/windows/random_access_handle_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/windows/stream_handle_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/write_at.hpp>", public ] },
+ { include: ["<boost/asio/detail/pop_options.hpp>", private, "<boost/asio/write.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/async_result.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/basic_datagram_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/basic_deadline_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/basic_io_object.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/basic_raw_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/basic_seq_packet_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/basic_serial_port.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/basic_signal_set.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/basic_socket_acceptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/basic_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/basic_socket_iostream.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/basic_socket_streambuf.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/basic_streambuf.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/basic_stream_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/basic_waitable_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/buffered_read_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/buffered_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/buffered_write_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/buffer.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/buffers_iterator.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/completion_condition.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/connect.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/datagram_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/deadline_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/deadline_timer_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/error.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/generic/basic_endpoint.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/generic/datagram_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/generic/raw_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/generic/seq_packet_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/generic/stream_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/handler_alloc_hook.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/handler_continuation_hook.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/handler_invoke_hook.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/handler_type.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/io_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/address.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/address_v4.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/address_v6.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/basic_endpoint.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/basic_resolver_entry.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/basic_resolver.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/basic_resolver_iterator.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/basic_resolver_query.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/host_name.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/icmp.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/multicast.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/resolver_query_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/resolver_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/tcp.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/udp.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/unicast.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ip/v6_only.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/is_read_buffered.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/is_write_buffered.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/local/basic_endpoint.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/local/connect_pair.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/local/datagram_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/local/stream_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/placeholders.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/posix/basic_descriptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/posix/basic_stream_descriptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/posix/descriptor_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/posix/stream_descriptor_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/raw_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/read_at.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/read.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/read_until.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/seq_packet_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/serial_port_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/serial_port_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/signal_set_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/socket_acceptor_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/socket_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/spawn.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/basic_context.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/context_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/context.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/context_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/error.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/old/basic_context.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/old/context_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/old/stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/old/stream_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/rfc2818_verification.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/stream_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/stream_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/verify_context.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/ssl/verify_mode.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/strand.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/stream_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/time_traits.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/use_future.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/waitable_timer_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/wait_traits.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/windows/basic_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/windows/basic_object_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/windows/basic_random_access_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/windows/basic_stream_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/windows/object_handle_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/windows/overlapped_ptr.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/windows/random_access_handle_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/windows/stream_handle_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/write_at.hpp>", public ] },
+ { include: ["<boost/asio/detail/push_options.hpp>", private, "<boost/asio/write.hpp>", public ] },
+ { include: ["<boost/asio/detail/reactive_descriptor_service.hpp>", private, "<boost/asio/posix/stream_descriptor_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/reactive_serial_port_service.hpp>", private, "<boost/asio/serial_port_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/reactive_socket_service.hpp>", private, "<boost/asio/datagram_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/reactive_socket_service.hpp>", private, "<boost/asio/raw_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/reactive_socket_service.hpp>", private, "<boost/asio/seq_packet_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/reactive_socket_service.hpp>", private, "<boost/asio/socket_acceptor_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/reactive_socket_service.hpp>", private, "<boost/asio/stream_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/regex_fwd.hpp>", private, "<boost/asio/read_until.hpp>", public ] },
+ { include: ["<boost/asio/detail/resolver_service.hpp>", private, "<boost/asio/ip/resolver_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/shared_ptr.hpp>", private, "<boost/asio/ip/basic_resolver_iterator.hpp>", public ] },
+ { include: ["<boost/asio/detail/signal_init.hpp>", private, "<boost/asio/io_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/signal_set_service.hpp>", private, "<boost/asio/signal_set_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/ip/basic_resolver_iterator.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/ip/basic_resolver_query.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_ops.hpp>", private, "<boost/asio/local/connect_pair.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_option.hpp>", private, "<boost/asio/ip/tcp.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_option.hpp>", private, "<boost/asio/ip/v6_only.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_option.hpp>", private, "<boost/asio/posix/descriptor_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_option.hpp>", private, "<boost/asio/socket_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/deadline_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/generic/datagram_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/generic/raw_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/generic/seq_packet_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/generic/stream_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/ip/address_v4.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/ip/address_v6.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/ip/basic_resolver_iterator.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/ip/icmp.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/ip/resolver_query_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/ip/tcp.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/ip/udp.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/local/datagram_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/local/stream_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/serial_port_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/socket_base.hpp>", public ] },
+ { include: ["<boost/asio/detail/socket_types.hpp>", private, "<boost/asio/time_traits.hpp>", public ] },
+ { include: ["<boost/asio/detail/strand_service.hpp>", private, "<boost/asio/strand.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/basic_datagram_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/basic_deadline_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/basic_raw_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/basic_seq_packet_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/basic_serial_port.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/basic_signal_set.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/basic_socket_acceptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/basic_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/basic_socket_streambuf.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/basic_stream_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/basic_waitable_timer.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/ip/basic_resolver.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/local/connect_pair.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/posix/basic_descriptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/posix/basic_stream_descriptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/ssl/old/basic_context.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/ssl/old/stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/windows/basic_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/windows/basic_object_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/windows/basic_random_access_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_error.hpp>", private, "<boost/asio/windows/basic_stream_handle.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_exception.hpp>", private, "<boost/asio/basic_streambuf.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_exception.hpp>", private, "<boost/asio/generic/datagram_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_exception.hpp>", private, "<boost/asio/generic/raw_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_exception.hpp>", private, "<boost/asio/generic/seq_packet_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/throw_exception.hpp>", private, "<boost/asio/generic/stream_protocol.hpp>", public ] },
+ { include: ["<boost/asio/detail/timer_queue_ptime.hpp>", private, "<boost/asio/deadline_timer_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/basic_datagram_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/basic_raw_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/basic_socket_acceptor.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/basic_socket.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/buffered_read_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/buffered_write_stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/buffer.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/buffers_iterator.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/datagram_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/raw_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/read_until.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/seq_packet_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/socket_acceptor_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/ssl/old/stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/ssl/stream.hpp>", public ] },
+ { include: ["<boost/asio/detail/type_traits.hpp>", private, "<boost/asio/stream_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/variadic_templates.hpp>", private, "<boost/asio/basic_socket_iostream.hpp>", public ] },
+ { include: ["<boost/asio/detail/variadic_templates.hpp>", private, "<boost/asio/basic_socket_streambuf.hpp>", public ] },
+ { include: ["<boost/asio/detail/weak_ptr.hpp>", private, "<boost/asio/spawn.hpp>", public ] },
+ { include: ["<boost/asio/detail/win_iocp_handle_service.hpp>", private, "<boost/asio/windows/random_access_handle_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/win_iocp_handle_service.hpp>", private, "<boost/asio/windows/stream_handle_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/win_iocp_overlapped_ptr.hpp>", private, "<boost/asio/windows/overlapped_ptr.hpp>", public ] },
+ { include: ["<boost/asio/detail/win_iocp_serial_port_service.hpp>", private, "<boost/asio/serial_port_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/win_iocp_socket_service.hpp>", private, "<boost/asio/datagram_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/win_iocp_socket_service.hpp>", private, "<boost/asio/raw_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/win_iocp_socket_service.hpp>", private, "<boost/asio/seq_packet_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/win_iocp_socket_service.hpp>", private, "<boost/asio/socket_acceptor_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/win_iocp_socket_service.hpp>", private, "<boost/asio/stream_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/win_object_handle_service.hpp>", private, "<boost/asio/windows/object_handle_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/winrt_resolver_service.hpp>", private, "<boost/asio/ip/resolver_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/winrt_ssocket_service.hpp>", private, "<boost/asio/stream_socket_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/winrt_utils.hpp>", private, "<boost/asio/ip/basic_resolver_iterator.hpp>", public ] },
+ { include: ["<boost/asio/detail/winsock_init.hpp>", private, "<boost/asio/io_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/winsock_init.hpp>", private, "<boost/asio/ip/address_v4.hpp>", public ] },
+ { include: ["<boost/asio/detail/winsock_init.hpp>", private, "<boost/asio/ip/address_v6.hpp>", public ] },
+ { include: ["<boost/asio/detail/wrapped_handler.hpp>", private, "<boost/asio/io_service.hpp>", public ] },
+ { include: ["<boost/asio/detail/wrapped_handler.hpp>", private, "<boost/asio/spawn.hpp>", public ] },
+ { include: ["<boost/asio/detail/wrapped_handler.hpp>", private, "<boost/asio/strand.hpp>", public ] },
+ { include: ["<boost/asio/generic/detail/endpoint.hpp>", private, "<boost/asio/generic/basic_endpoint.hpp>", public ] },
+ { include: ["<boost/asio/impl/buffered_read_stream.hpp>", private, "<boost/asio/buffered_read_stream.hpp>", public ] },
+ { include: ["<boost/asio/impl/buffered_write_stream.hpp>", private, "<boost/asio/buffered_write_stream.hpp>", public ] },
+ { include: ["<boost/asio/impl/connect.hpp>", private, "<boost/asio/connect.hpp>", public ] },
+ { include: ["<boost/asio/impl/error.ipp>", private, "<boost/asio/error.hpp>", public ] },
+ { include: ["<boost/asio/impl/handler_alloc_hook.ipp>", private, "<boost/asio/handler_alloc_hook.hpp>", public ] },
+ { include: ["<boost/asio/impl/io_service.hpp>", private, "<boost/asio/io_service.hpp>", public ] },
+ { include: ["<boost/asio/impl/io_service.ipp>", private, "<boost/asio/io_service.hpp>", public ] },
+ { include: ["<boost/asio/impl/read_at.hpp>", private, "<boost/asio/read_at.hpp>", public ] },
+ { include: ["<boost/asio/impl/read.hpp>", private, "<boost/asio/read.hpp>", public ] },
+ { include: ["<boost/asio/impl/read_until.hpp>", private, "<boost/asio/read_until.hpp>", public ] },
+ { include: ["<boost/asio/impl/serial_port_base.hpp>", private, "<boost/asio/serial_port_base.hpp>", public ] },
+ { include: ["<boost/asio/impl/serial_port_base.ipp>", private, "<boost/asio/serial_port_base.hpp>", public ] },
+ { include: ["<boost/asio/impl/spawn.hpp>", private, "<boost/asio/spawn.hpp>", public ] },
+ { include: ["<boost/asio/impl/use_future.hpp>", private, "<boost/asio/use_future.hpp>", public ] },
+ { include: ["<boost/asio/impl/write_at.hpp>", private, "<boost/asio/write_at.hpp>", public ] },
+ { include: ["<boost/asio/impl/write.hpp>", private, "<boost/asio/write.hpp>", public ] },
+ { include: ["<boost/asio/ip/detail/endpoint.hpp>", private, "<boost/asio/ip/basic_endpoint.hpp>", public ] },
+ { include: ["<boost/asio/ip/detail/socket_option.hpp>", private, "<boost/asio/ip/multicast.hpp>", public ] },
+ { include: ["<boost/asio/ip/detail/socket_option.hpp>", private, "<boost/asio/ip/unicast.hpp>", public ] },
+ { include: ["<boost/asio/ip/impl/address.hpp>", private, "<boost/asio/ip/address.hpp>", public ] },
+ { include: ["<boost/asio/ip/impl/address.ipp>", private, "<boost/asio/ip/address.hpp>", public ] },
+ { include: ["<boost/asio/ip/impl/address_v4.hpp>", private, "<boost/asio/ip/address_v4.hpp>", public ] },
+ { include: ["<boost/asio/ip/impl/address_v4.ipp>", private, "<boost/asio/ip/address_v4.hpp>", public ] },
+ { include: ["<boost/asio/ip/impl/address_v6.hpp>", private, "<boost/asio/ip/address_v6.hpp>", public ] },
+ { include: ["<boost/asio/ip/impl/address_v6.ipp>", private, "<boost/asio/ip/address_v6.hpp>", public ] },
+ { include: ["<boost/asio/ip/impl/basic_endpoint.hpp>", private, "<boost/asio/ip/basic_endpoint.hpp>", public ] },
+ { include: ["<boost/asio/ip/impl/host_name.ipp>", private, "<boost/asio/ip/host_name.hpp>", public ] },
+ { include: ["<boost/asio/local/detail/endpoint.hpp>", private, "<boost/asio/local/basic_endpoint.hpp>", public ] },
+ { include: ["<boost/asio/ssl/detail/buffered_handshake_op.hpp>", private, "<boost/asio/ssl/stream.hpp>", public ] },
+ { include: ["<boost/asio/ssl/detail/handshake_op.hpp>", private, "<boost/asio/ssl/stream.hpp>", public ] },
+ { include: ["<boost/asio/ssl/detail/io.hpp>", private, "<boost/asio/ssl/stream.hpp>", public ] },
+ { include: ["<boost/asio/ssl/detail/openssl_init.hpp>", private, "<boost/asio/ssl/context.hpp>", public ] },
+ { include: ["<boost/asio/ssl/detail/openssl_types.hpp>", private, "<boost/asio/ssl/context_base.hpp>", public ] },
+ { include: ["<boost/asio/ssl/detail/openssl_types.hpp>", private, "<boost/asio/ssl/context.hpp>", public ] },
+ { include: ["<boost/asio/ssl/detail/openssl_types.hpp>", private, "<boost/asio/ssl/rfc2818_verification.hpp>", public ] },
+ { include: ["<boost/asio/ssl/detail/openssl_types.hpp>", private, "<boost/asio/ssl/verify_context.hpp>", public ] },
+ { include: ["<boost/asio/ssl/detail/openssl_types.hpp>", private, "<boost/asio/ssl/verify_mode.hpp>", public ] },
+ { include: ["<boost/asio/ssl/detail/password_callback.hpp>", private, "<boost/asio/ssl/context.hpp>", public ] },
+ { include: ["<boost/asio/ssl/detail/read_op.hpp>", private, "<boost/asio/ssl/stream.hpp>", public ] },
+ { include: ["<boost/asio/ssl/detail/shutdown_op.hpp>", private, "<boost/asio/ssl/stream.hpp>", public ] },
+ { include: ["<boost/asio/ssl/detail/stream_core.hpp>", private, "<boost/asio/ssl/stream.hpp>", public ] },
+ { include: ["<boost/asio/ssl/detail/verify_callback.hpp>", private, "<boost/asio/ssl/context.hpp>", public ] },
+ { include: ["<boost/asio/ssl/detail/write_op.hpp>", private, "<boost/asio/ssl/stream.hpp>", public ] },
+ { include: ["<boost/asio/ssl/impl/context.hpp>", private, "<boost/asio/ssl/context.hpp>", public ] },
+ { include: ["<boost/asio/ssl/impl/context.ipp>", private, "<boost/asio/ssl/context.hpp>", public ] },
+ { include: ["<boost/asio/ssl/impl/error.ipp>", private, "<boost/asio/ssl/error.hpp>", public ] },
+ { include: ["<boost/asio/ssl/impl/rfc2818_verification.ipp>", private, "<boost/asio/ssl/rfc2818_verification.hpp>", public ] },
+ { include: ["<boost/asio/ssl/old/detail/openssl_context_service.hpp>", private, "<boost/asio/ssl/old/context_service.hpp>", public ] },
+ { include: ["<boost/asio/ssl/old/detail/openssl_stream_service.hpp>", private, "<boost/asio/ssl/old/stream_service.hpp>", public ] },
+ { include: ["<boost/atomic/detail/config.hpp>", private, "<boost/atomic/atomic.hpp>", public ] },
+ { include: ["<boost/atomic/detail/platform.hpp>", private, "<boost/atomic/atomic.hpp>", public ] },
+ { include: ["<boost/atomic/detail/type-classification.hpp>", private, "<boost/atomic/atomic.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/comparison_adaptor.hpp>", private, "<boost/bimap/container_adaptor/list_adaptor.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/comparison_adaptor.hpp>", private, "<boost/bimap/container_adaptor/list_map_adaptor.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/comparison_adaptor.hpp>", private, "<boost/bimap/container_adaptor/ordered_associative_container_adaptor.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/comparison_adaptor.hpp>", private, "<boost/bimap/views/multiset_view.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/comparison_adaptor.hpp>", private, "<boost/bimap/views/vector_map_view.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/comparison_adaptor.hpp>", private, "<boost/bimap/views/vector_set_view.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/functor_bag.hpp>", private, "<boost/bimap/container_adaptor/container_adaptor.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/identity_converters.hpp>", private, "<boost/bimap/container_adaptor/associative_container_adaptor.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/identity_converters.hpp>", private, "<boost/bimap/container_adaptor/container_adaptor.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/identity_converters.hpp>", private, "<boost/bimap/container_adaptor/list_map_adaptor.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/identity_converters.hpp>", private, "<boost/bimap/container_adaptor/sequence_container_adaptor.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/identity_converters.hpp>", private, "<boost/bimap/container_adaptor/vector_map_adaptor.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/key_extractor.hpp>", private, "<boost/bimap/container_adaptor/list_map_adaptor.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/non_unique_container_helper.hpp>", private, "<boost/bimap/container_adaptor/multimap_adaptor.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/non_unique_container_helper.hpp>", private, "<boost/bimap/container_adaptor/multiset_adaptor.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/non_unique_container_helper.hpp>", private, "<boost/bimap/container_adaptor/unordered_multimap_adaptor.hpp>", public ] },
+ { include: ["<boost/bimap/container_adaptor/detail/non_unique_container_helper.hpp>", private, "<boost/bimap/container_adaptor/unordered_multiset_adaptor.hpp>", public ] },
+ { include: ["<boost/bimap/detail/bimap_core.hpp>", private, "<boost/bimap/bimap.hpp>", public ] },
+ { include: ["<boost/bimap/detail/concept_tags.hpp>", private, "<boost/bimap/list_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/concept_tags.hpp>", private, "<boost/bimap/multiset_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/concept_tags.hpp>", private, "<boost/bimap/set_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/concept_tags.hpp>", private, "<boost/bimap/unconstrained_set_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/concept_tags.hpp>", private, "<boost/bimap/unordered_multiset_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/concept_tags.hpp>", private, "<boost/bimap/unordered_set_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/concept_tags.hpp>", private, "<boost/bimap/vector_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/debug/static_error.hpp>", private, "<boost/bimap/relation/structured_pair.hpp>", public ] },
+ { include: ["<boost/bimap/detail/debug/static_error.hpp>", private, "<boost/bimap/relation/support/member_with_tag.hpp>", public ] },
+ { include: ["<boost/bimap/detail/debug/static_error.hpp>", private, "<boost/bimap/tags/support/tag_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_index_binder.hpp>", private, "<boost/bimap/list_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_index_binder.hpp>", private, "<boost/bimap/multiset_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_index_binder.hpp>", private, "<boost/bimap/set_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_index_binder.hpp>", private, "<boost/bimap/unconstrained_set_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_index_binder.hpp>", private, "<boost/bimap/unordered_multiset_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_index_binder.hpp>", private, "<boost/bimap/unordered_set_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_index_binder.hpp>", private, "<boost/bimap/vector_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_relation_binder.hpp>", private, "<boost/bimap/list_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_relation_binder.hpp>", private, "<boost/bimap/multiset_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_relation_binder.hpp>", private, "<boost/bimap/set_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_relation_binder.hpp>", private, "<boost/bimap/unconstrained_set_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_relation_binder.hpp>", private, "<boost/bimap/unordered_multiset_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_relation_binder.hpp>", private, "<boost/bimap/unordered_set_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_relation_binder.hpp>", private, "<boost/bimap/vector_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_view_binder.hpp>", private, "<boost/bimap/list_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_view_binder.hpp>", private, "<boost/bimap/multiset_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_view_binder.hpp>", private, "<boost/bimap/set_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_view_binder.hpp>", private, "<boost/bimap/unconstrained_set_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_view_binder.hpp>", private, "<boost/bimap/unordered_multiset_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_view_binder.hpp>", private, "<boost/bimap/unordered_set_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/generate_view_binder.hpp>", private, "<boost/bimap/vector_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/map_view_base.hpp>", private, "<boost/bimap/bimap.hpp>", public ] },
+ { include: ["<boost/bimap/detail/map_view_base.hpp>", private, "<boost/bimap/views/list_map_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/map_view_base.hpp>", private, "<boost/bimap/views/list_set_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/map_view_base.hpp>", private, "<boost/bimap/views/map_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/map_view_base.hpp>", private, "<boost/bimap/views/multimap_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/map_view_base.hpp>", private, "<boost/bimap/views/unordered_map_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/map_view_base.hpp>", private, "<boost/bimap/views/unordered_multimap_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/map_view_base.hpp>", private, "<boost/bimap/views/vector_map_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/map_view_base.hpp>", private, "<boost/bimap/views/vector_set_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/modifier_adaptor.hpp>", private, "<boost/bimap/bimap.hpp>", public ] },
+ { include: ["<boost/bimap/detail/non_unique_views_helper.hpp>", private, "<boost/bimap/views/multimap_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/non_unique_views_helper.hpp>", private, "<boost/bimap/views/multiset_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/non_unique_views_helper.hpp>", private, "<boost/bimap/views/unordered_multimap_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/non_unique_views_helper.hpp>", private, "<boost/bimap/views/unordered_multiset_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/set_view_base.hpp>", private, "<boost/bimap/views/list_set_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/set_view_base.hpp>", private, "<boost/bimap/views/multiset_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/set_view_base.hpp>", private, "<boost/bimap/views/set_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/set_view_base.hpp>", private, "<boost/bimap/views/unordered_multiset_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/set_view_base.hpp>", private, "<boost/bimap/views/unordered_set_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/set_view_base.hpp>", private, "<boost/bimap/views/vector_set_view.hpp>", public ] },
+ { include: ["<boost/bimap/detail/user_interface_config.hpp>", private, "<boost/bimap/bimap.hpp>", public ] },
+ { include: ["<boost/bimap/detail/user_interface_config.hpp>", private, "<boost/bimap/list_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/user_interface_config.hpp>", private, "<boost/bimap/multiset_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/user_interface_config.hpp>", private, "<boost/bimap/set_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/user_interface_config.hpp>", private, "<boost/bimap/unconstrained_set_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/user_interface_config.hpp>", private, "<boost/bimap/unordered_multiset_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/user_interface_config.hpp>", private, "<boost/bimap/unordered_set_of.hpp>", public ] },
+ { include: ["<boost/bimap/detail/user_interface_config.hpp>", private, "<boost/bimap/vector_of.hpp>", public ] },
+ { include: ["<boost/bimap/relation/detail/access_builder.hpp>", private, "<boost/bimap/relation/support/get.hpp>", public ] },
+ { include: ["<boost/bimap/relation/detail/access_builder.hpp>", private, "<boost/bimap/relation/support/pair_by.hpp>", public ] },
+ { include: ["<boost/bimap/relation/detail/access_builder.hpp>", private, "<boost/bimap/support/map_by.hpp>", public ] },
+ { include: ["<boost/bimap/relation/detail/metadata_access_builder.hpp>", private, "<boost/bimap/relation/support/data_extractor.hpp>", public ] },
+ { include: ["<boost/bimap/relation/detail/metadata_access_builder.hpp>", private, "<boost/bimap/relation/support/opposite_tag.hpp>", public ] },
+ { include: ["<boost/bimap/relation/detail/metadata_access_builder.hpp>", private, "<boost/bimap/relation/support/pair_type_by.hpp>", public ] },
+ { include: ["<boost/bimap/relation/detail/metadata_access_builder.hpp>", private, "<boost/bimap/relation/support/value_type_of.hpp>", public ] },
+ { include: ["<boost/bimap/relation/detail/metadata_access_builder.hpp>", private, "<boost/bimap/support/data_type_by.hpp>", public ] },
+ { include: ["<boost/bimap/relation/detail/metadata_access_builder.hpp>", private, "<boost/bimap/support/iterator_type_by.hpp>", public ] },
+ { include: ["<boost/bimap/relation/detail/metadata_access_builder.hpp>", private, "<boost/bimap/support/key_type_by.hpp>", public ] },
+ { include: ["<boost/bimap/relation/detail/metadata_access_builder.hpp>", private, "<boost/bimap/support/map_type_by.hpp>", public ] },
+ { include: ["<boost/bimap/relation/detail/metadata_access_builder.hpp>", private, "<boost/bimap/support/value_type_by.hpp>", public ] },
+ { include: ["<boost/bimap/relation/detail/mutant.hpp>", private, "<boost/bimap/relation/mutant_relation.hpp>", public ] },
+ { include: ["<boost/bimap/relation/detail/static_access_builder.hpp>", private, "<boost/bimap/support/iterator_type_by.hpp>", public ] },
+ { include: ["<boost/bimap/relation/detail/to_mutable_relation_functor.hpp>", private, "<boost/bimap/views/list_map_view.hpp>", public ] },
+ { include: ["<boost/chrono/detail/inlined/chrono.hpp>", private, "<boost/chrono/system_clocks.hpp>", public ] },
+ { include: ["<boost/chrono/detail/inlined/process_cpu_clocks.hpp>", private, "<boost/chrono/process_cpu_clocks.hpp>", public ] },
+ { include: ["<boost/chrono/detail/inlined/thread_clock.hpp>", private, "<boost/chrono/thread_clock.hpp>", public ] },
+ { include: ["<boost/chrono/detail/is_evenly_divisible_by.hpp>", private, "<boost/chrono/duration.hpp>", public ] },
+ { include: ["<boost/chrono/detail/no_warning/signed_unsigned_cmp.hpp>", private, "<boost/chrono/io/duration_get.hpp>", public ] },
+ { include: ["<boost/chrono/detail/no_warning/signed_unsigned_cmp.hpp>", private, "<boost/chrono/io_v1/chrono_io.hpp>", public ] },
+ { include: ["<boost/chrono/detail/scan_keyword.hpp>", private, "<boost/chrono/io/duration_get.hpp>", public ] },
+ { include: ["<boost/chrono/detail/scan_keyword.hpp>", private, "<boost/chrono/io/time_point_get.hpp>", public ] },
+ { include: ["<boost/chrono/detail/scan_keyword.hpp>", private, "<boost/chrono/io/time_point_io.hpp>", public ] },
+ { include: ["<boost/chrono/detail/scan_keyword.hpp>", private, "<boost/chrono/io_v1/chrono_io.hpp>", public ] },
+ { include: ["<boost/chrono/detail/static_assert.hpp>", private, "<boost/chrono/duration.hpp>", public ] },
+ { include: ["<boost/chrono/detail/system.hpp>", private, "<boost/chrono/process_cpu_clocks.hpp>", public ] },
+ { include: ["<boost/chrono/detail/system.hpp>", private, "<boost/chrono/system_clocks.hpp>", public ] },
+ { include: ["<boost/chrono/detail/system.hpp>", private, "<boost/chrono/thread_clock.hpp>", public ] },
+ { include: ["<boost/concept/detail/backward_compatibility.hpp>", private, "<boost/concept/usage.hpp>", public ] },
+ { include: ["<boost/concept/detail/borland.hpp>", private, "<boost/concept/assert.hpp>", public ] },
+ { include: ["<boost/concept/detail/concept_def.hpp>", private, "<boost/concept_check.hpp>", public ] },
+ { include: ["<boost/concept/detail/concept_def.hpp>", private, "<boost/graph/bron_kerbosch_all_cliques.hpp>", public ] },
+ { include: ["<boost/concept/detail/concept_def.hpp>", private, "<boost/graph/buffer_concepts.hpp>", public ] },
+ { include: ["<boost/concept/detail/concept_def.hpp>", private, "<boost/graph/distributed/concepts.hpp>", public ] },
+ { include: ["<boost/concept/detail/concept_def.hpp>", private, "<boost/graph/graph_concepts.hpp>", public ] },
+ { include: ["<boost/concept/detail/concept_def.hpp>", private, "<boost/graph/tiernan_all_cycles.hpp>", public ] },
+ { include: ["<boost/concept/detail/concept_def.hpp>", private, "<boost/iterator/iterator_concepts.hpp>", public ] },
+ { include: ["<boost/concept/detail/concept_undef.hpp>", private, "<boost/concept_check.hpp>", public ] },
+ { include: ["<boost/concept/detail/concept_undef.hpp>", private, "<boost/graph/bron_kerbosch_all_cliques.hpp>", public ] },
+ { include: ["<boost/concept/detail/concept_undef.hpp>", private, "<boost/graph/distributed/concepts.hpp>", public ] },
+ { include: ["<boost/concept/detail/concept_undef.hpp>", private, "<boost/graph/graph_concepts.hpp>", public ] },
+ { include: ["<boost/concept/detail/concept_undef.hpp>", private, "<boost/graph/tiernan_all_cycles.hpp>", public ] },
+ { include: ["<boost/concept/detail/concept_undef.hpp>", private, "<boost/iterator/iterator_concepts.hpp>", public ] },
+ { include: ["<boost/concept/detail/general.hpp>", private, "<boost/concept/assert.hpp>", public ] },
+ { include: ["<boost/concept/detail/msvc.hpp>", private, "<boost/concept/assert.hpp>", public ] },
+ { include: ["<boost/container/detail/advanced_insert_int.hpp>", private, "<boost/container/deque.hpp>", public ] },
+ { include: ["<boost/container/detail/advanced_insert_int.hpp>", private, "<boost/container/vector.hpp>", public ] },
+ { include: ["<boost/container/detail/algorithms.hpp>", private, "<boost/container/deque.hpp>", public ] },
+ { include: ["<boost/container/detail/algorithms.hpp>", private, "<boost/container/list.hpp>", public ] },
+ { include: ["<boost/container/detail/algorithms.hpp>", private, "<boost/container/stable_vector.hpp>", public ] },
+ { include: ["<boost/container/detail/algorithms.hpp>", private, "<boost/container/string.hpp>", public ] },
+ { include: ["<boost/container/detail/algorithms.hpp>", private, "<boost/container/vector.hpp>", public ] },
+ { include: ["<boost/container/detail/allocation_type.hpp>", private, "<boost/container/string.hpp>", public ] },
+ { include: ["<boost/container/detail/allocation_type.hpp>", private, "<boost/container/vector.hpp>", public ] },
+ { include: ["<boost/container/detail/allocation_type.hpp>", private, "<boost/interprocess/containers/allocation_type.hpp>", public ] },
+ { include: ["<boost/container/detail/allocator_version_traits.hpp>", private, "<boost/container/stable_vector.hpp>", public ] },
+ { include: ["<boost/container/detail/allocator_version_traits.hpp>", private, "<boost/container/string.hpp>", public ] },
+ { include: ["<boost/container/detail/allocator_version_traits.hpp>", private, "<boost/container/vector.hpp>", public ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/allocator_traits.hpp>", public ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/deque.hpp>", public ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/flat_map.hpp>", public ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/flat_set.hpp>", public ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/list.hpp>", public ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/map.hpp>", public ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/scoped_allocator_fwd.hpp>", public ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/scoped_allocator.hpp>", public ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/set.hpp>", public ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/slist.hpp>", public ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/stable_vector.hpp>", public ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/static_vector.hpp>", public ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/string.hpp>", public ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/throw_exception.hpp>", public ] },
+ { include: ["<boost/container/detail/config_begin.hpp>", private, "<boost/container/vector.hpp>", public ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/allocator_traits.hpp>", public ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/deque.hpp>", public ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/flat_map.hpp>", public ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/flat_set.hpp>", public ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/list.hpp>", public ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/map.hpp>", public ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/scoped_allocator_fwd.hpp>", public ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/scoped_allocator.hpp>", public ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/set.hpp>", public ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/slist.hpp>", public ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/stable_vector.hpp>", public ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/static_vector.hpp>", public ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/string.hpp>", public ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/throw_exception.hpp>", public ] },
+ { include: ["<boost/container/detail/config_end.hpp>", private, "<boost/container/vector.hpp>", public ] },
+ { include: ["<boost/container/detail/destroyers.hpp>", private, "<boost/container/vector.hpp>", public ] },
+ { include: ["<boost/container/detail/flat_tree.hpp>", private, "<boost/container/flat_map.hpp>", public ] },
+ { include: ["<boost/container/detail/flat_tree.hpp>", private, "<boost/container/flat_set.hpp>", public ] },
+ { include: ["<boost/container/detail/iterators.hpp>", private, "<boost/container/deque.hpp>", public ] },
+ { include: ["<boost/container/detail/iterators.hpp>", private, "<boost/container/list.hpp>", public ] },
+ { include: ["<boost/container/detail/iterators.hpp>", private, "<boost/container/slist.hpp>", public ] },
+ { include: ["<boost/container/detail/iterators.hpp>", private, "<boost/container/stable_vector.hpp>", public ] },
+ { include: ["<boost/container/detail/iterators.hpp>", private, "<boost/container/string.hpp>", public ] },
+ { include: ["<boost/container/detail/iterators.hpp>", private, "<boost/container/vector.hpp>", public ] },
+ { include: ["<boost/container/detail/memory_util.hpp>", private, "<boost/container/allocator_traits.hpp>", public ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/allocator_traits.hpp>", public ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/deque.hpp>", public ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/flat_map.hpp>", public ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/flat_set.hpp>", public ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/list.hpp>", public ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/map.hpp>", public ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/set.hpp>", public ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/slist.hpp>", public ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/string.hpp>", public ] },
+ { include: ["<boost/container/detail/mpl.hpp>", private, "<boost/container/vector.hpp>", public ] },
+ { include: ["<boost/container/detail/multiallocation_chain.hpp>", private, "<boost/interprocess/allocators/adaptive_pool.hpp>", public ] },
+ { include: ["<boost/container/detail/multiallocation_chain.hpp>", private, "<boost/interprocess/allocators/allocator.hpp>", public ] },
+ { include: ["<boost/container/detail/multiallocation_chain.hpp>", private, "<boost/interprocess/allocators/node_allocator.hpp>", public ] },
+ { include: ["<boost/container/detail/multiallocation_chain.hpp>", private, "<boost/interprocess/allocators/private_adaptive_pool.hpp>", public ] },
+ { include: ["<boost/container/detail/multiallocation_chain.hpp>", private, "<boost/interprocess/allocators/private_node_allocator.hpp>", public ] },
+ { include: ["<boost/container/detail/multiallocation_chain.hpp>", private, "<boost/interprocess/mem_algo/rbtree_best_fit.hpp>", public ] },
+ { include: ["<boost/container/detail/node_alloc_holder.hpp>", private, "<boost/container/list.hpp>", public ] },
+ { include: ["<boost/container/detail/node_alloc_holder.hpp>", private, "<boost/container/slist.hpp>", public ] },
+ { include: ["<boost/container/detail/pair.hpp>", private, "<boost/container/map.hpp>", public ] },
+ { include: ["<boost/container/detail/pair.hpp>", private, "<boost/container/scoped_allocator.hpp>", public ] },
+ { include: ["<boost/container/detail/pair.hpp>", private, "<boost/interprocess/containers/pair.hpp>", public ] },
+ { include: ["<boost/container/detail/preprocessor.hpp>", private, "<boost/container/allocator_traits.hpp>", public ] },
+ { include: ["<boost/container/detail/preprocessor.hpp>", private, "<boost/container/list.hpp>", public ] },
+ { include: ["<boost/container/detail/preprocessor.hpp>", private, "<boost/container/scoped_allocator_fwd.hpp>", public ] },
+ { include: ["<boost/container/detail/preprocessor.hpp>", private, "<boost/container/set.hpp>", public ] },
+ { include: ["<boost/container/detail/preprocessor.hpp>", private, "<boost/container/slist.hpp>", public ] },
+ { include: ["<boost/container/detail/tree.hpp>", private, "<boost/container/map.hpp>", public ] },
+ { include: ["<boost/container/detail/tree.hpp>", private, "<boost/container/set.hpp>", public ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/map.hpp>", public ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/scoped_allocator_fwd.hpp>", public ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/scoped_allocator.hpp>", public ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/slist.hpp>", public ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/string.hpp>", public ] },
+ { include: ["<boost/container/detail/type_traits.hpp>", private, "<boost/container/vector.hpp>", public ] },
+ { include: ["<boost/container/detail/utilities.hpp>", private, "<boost/container/deque.hpp>", public ] },
+ { include: ["<boost/container/detail/utilities.hpp>", private, "<boost/container/list.hpp>", public ] },
+ { include: ["<boost/container/detail/utilities.hpp>", private, "<boost/container/map.hpp>", public ] },
+ { include: ["<boost/container/detail/utilities.hpp>", private, "<boost/container/scoped_allocator.hpp>", public ] },
+ { include: ["<boost/container/detail/utilities.hpp>", private, "<boost/container/slist.hpp>", public ] },
+ { include: ["<boost/container/detail/utilities.hpp>", private, "<boost/container/stable_vector.hpp>", public ] },
+ { include: ["<boost/container/detail/utilities.hpp>", private, "<boost/container/string.hpp>", public ] },
+ { include: ["<boost/container/detail/utilities.hpp>", private, "<boost/container/vector.hpp>", public ] },
+ { include: ["<boost/container/detail/value_init.hpp>", private, "<boost/container/map.hpp>", public ] },
+ { include: ["<boost/container/detail/version_type.hpp>", private, "<boost/container/list.hpp>", public ] },
+ { include: ["<boost/container/detail/version_type.hpp>", private, "<boost/container/string.hpp>", public ] },
+ { include: ["<boost/container/detail/version_type.hpp>", private, "<boost/container/vector.hpp>", public ] },
+ { include: ["<boost/container/detail/version_type.hpp>", private, "<boost/interprocess/containers/version_type.hpp>", public ] },
+ { include: ["<boost/context/detail/config.hpp>", private, "<boost/context/fcontext.hpp>", public ] },
+ { include: ["<boost/context/detail/config.hpp>", private, "<boost/coroutine/stack_allocator.hpp>", public ] },
+ { include: ["<boost/context/detail/fcontext_arm.hpp>", private, "<boost/context/fcontext.hpp>", public ] },
+ { include: ["<boost/context/detail/fcontext_i386.hpp>", private, "<boost/context/fcontext.hpp>", public ] },
+ { include: ["<boost/context/detail/fcontext_i386_win.hpp>", private, "<boost/context/fcontext.hpp>", public ] },
+ { include: ["<boost/context/detail/fcontext_mips.hpp>", private, "<boost/context/fcontext.hpp>", public ] },
+ { include: ["<boost/context/detail/fcontext_ppc.hpp>", private, "<boost/context/fcontext.hpp>", public ] },
+ { include: ["<boost/context/detail/fcontext_sparc.hpp>", private, "<boost/context/fcontext.hpp>", public ] },
+ { include: ["<boost/context/detail/fcontext_x86_64.hpp>", private, "<boost/context/fcontext.hpp>", public ] },
+ { include: ["<boost/context/detail/fcontext_x86_64_win.hpp>", private, "<boost/context/fcontext.hpp>", public ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/exceptions.hpp>", public ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/stack_context.hpp>", public ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/v1/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/detail/config.hpp>", private, "<boost/coroutine/v2/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/detail/coroutine_context.hpp>", private, "<boost/coroutine/v1/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/detail/coroutine_context.hpp>", private, "<boost/coroutine/v2/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/detail/param.hpp>", private, "<boost/coroutine/v2/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/detail/segmented_stack_allocator.hpp>", private, "<boost/coroutine/stack_allocator.hpp>", public ] },
+ { include: ["<boost/coroutine/detail/standard_stack_allocator.hpp>", private, "<boost/coroutine/stack_allocator.hpp>", public ] },
+ { include: ["<boost/coroutine/v1/detail/arg.hpp>", private, "<boost/coroutine/v1/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/v1/detail/coroutine_base.hpp>", private, "<boost/coroutine/v1/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/v1/detail/coroutine_caller.hpp>", private, "<boost/coroutine/v1/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/v1/detail/coroutine_get.hpp>", private, "<boost/coroutine/v1/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/v1/detail/coroutine_object.hpp>", private, "<boost/coroutine/v1/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/v1/detail/coroutine_op.hpp>", private, "<boost/coroutine/v1/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/v2/detail/pull_coroutine_base.hpp>", private, "<boost/coroutine/v2/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/v2/detail/pull_coroutine_caller.hpp>", private, "<boost/coroutine/v2/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/v2/detail/pull_coroutine_object.hpp>", private, "<boost/coroutine/v2/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/v2/detail/push_coroutine_base.hpp>", private, "<boost/coroutine/v2/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/v2/detail/push_coroutine_caller.hpp>", private, "<boost/coroutine/v2/coroutine.hpp>", public ] },
+ { include: ["<boost/coroutine/v2/detail/push_coroutine_object.hpp>", private, "<boost/coroutine/v2/coroutine.hpp>", public ] },
+ { include: ["<boost/detail/algorithm.hpp>", private, "<boost/graph/graph_utility.hpp>", public ] },
+ { include: ["<boost/detail/algorithm.hpp>", private, "<boost/graph/isomorphism.hpp>", public ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/flyweight/set_factory.hpp>", public ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/multi_index/hashed_index.hpp>", public ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/multi_index/sequenced_index.hpp>", public ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/statechart/fifo_worker.hpp>", public ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/statechart/processor_container.hpp>", public ] },
+ { include: ["<boost/detail/allocator_utilities.hpp>", private, "<boost/statechart/state_machine.hpp>", public ] },
+ { include: ["<boost/detail/atomic_count.hpp>", private, "<boost/flyweight/refcounted.hpp>", public ] },
+ { include: ["<boost/detail/atomic_count.hpp>", private, "<boost/log/attributes/counter.hpp>", public ] },
+ { include: ["<boost/detail/atomic_count.hpp>", private, "<boost/log/core/record_view.hpp>", public ] },
+ { include: ["<boost/detail/atomic_count.hpp>", private, "<boost/wave/cpplexer/cpp_lex_token.hpp>", public ] },
+ { include: ["<boost/detail/atomic_count.hpp>", private, "<boost/wave/util/macro_definition.hpp>", public ] },
+ { include: ["<boost/detail/atomic_redef_macros.hpp>", private, "<boost/thread/future.hpp>", public ] },
+ { include: ["<boost/detail/atomic_undef_macros.hpp>", private, "<boost/thread/future.hpp>", public ] },
+ { include: ["<boost/detail/binary_search.hpp>", private, "<boost/test/utils/fixed_mapping.hpp>", public ] },
+ { include: ["<boost/detail/call_traits.hpp>", private, "<boost/call_traits.hpp>", public ] },
+ { include: ["<boost/detail/compressed_pair.hpp>", private, "<boost/compressed_pair.hpp>", public ] },
+ { include: ["<boost/detail/container_fwd.hpp>", private, "<boost/functional/hash/extensions.hpp>", public ] },
+ { include: ["<boost/detail/dynamic_bitset.hpp>", private, "<boost/dynamic_bitset/dynamic_bitset.hpp>", public ] },
+ { include: ["<boost/detail/endian.hpp>", private, "<boost/mpl/string.hpp>", public ] },
+ { include: ["<boost/detail/endian.hpp>", private, "<boost/multiprecision/cpp_int.hpp>", public ] },
+ { include: ["<boost/detail/fenv.hpp>", private, "<boost/math/tools/config.hpp>", public ] },
+ { include: ["<boost/detail/indirect_traits.hpp>", private, "<boost/iterator/indirect_iterator.hpp>", public ] },
+ { include: ["<boost/detail/interlocked.hpp>", private, "<boost/thread/win32/basic_timed_mutex.hpp>", public ] },
+ { include: ["<boost/detail/interlocked.hpp>", private, "<boost/thread/win32/interlocked_read.hpp>", public ] },
+ { include: ["<boost/detail/interlocked.hpp>", private, "<boost/thread/win32/once.hpp>", public ] },
+ { include: ["<boost/detail/interlocked.hpp>", private, "<boost/thread/win32/shared_mutex.hpp>", public ] },
+ { include: ["<boost/detail/interlocked.hpp>", private, "<boost/thread/win32/thread_primitives.hpp>", public ] },
+ { include: ["<boost/detail/is_incrementable.hpp>", private, "<boost/icl/type_traits/is_discrete.hpp>", public ] },
+ { include: ["<boost/detail/is_incrementable.hpp>", private, "<boost/indirect_reference.hpp>", public ] },
+ { include: ["<boost/detail/is_incrementable.hpp>", private, "<boost/iterator/new_iterator_tests.hpp>", public ] },
+ { include: ["<boost/detail/is_incrementable.hpp>", private, "<boost/pointee.hpp>", public ] },
+ { include: ["<boost/detail/is_sorted.hpp>", private, "<boost/graph/distributed/connected_components.hpp>", public ] },
+ { include: ["<boost/detail/is_sorted.hpp>", private, "<boost/range/algorithm_ext/is_sorted.hpp>", public ] },
+ { include: ["<boost/detail/is_xxx.hpp>", private, "<boost/parameter/parameters.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/algorithm/string/find_format.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/algorithm/string/formatter.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/array.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/circular_buffer.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/concept_check.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/dynamic_bitset/dynamic_bitset.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/graph/incremental_components.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/iterator/indirect_iterator.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/iterator/is_lvalue_iterator.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/iterator/is_readable_iterator.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/iterator/iterator_adaptor.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/iterator/iterator_categories.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/iterator/iterator_concepts.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/iterator/iterator_traits.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/iterator/new_iterator_tests.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/iterator/zip_iterator.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/property_map/property_map.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/python/object/iterator.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/classic/core/scanner/scanner.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/classic/iterator/multi_pass.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/classic/iterator/position_iterator_fwd.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/classic/tree/common.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/classic/utility/regex.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/lex/lexer/lexer.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/functor.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/iterator_tokenizer.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/position_token.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/token.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/lex/lexer/token_def.hpp>", public ] },
+ { include: ["<boost/detail/iterator.hpp>", private, "<boost/spirit/home/support/container.hpp>", public ] },
+ { include: ["<boost/detail/lcast_precision.hpp>", private, "<boost/lexical_cast.hpp>", public ] },
+ { include: ["<boost/detail/lightweight_test.hpp>", private, "<boost/iterator/new_iterator_tests.hpp>", public ] },
+ { include: ["<boost/detail/lightweight_test.hpp>", private, "<boost/mpl/aux_/test.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/chrono/io/duration_io.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/chrono/io/time_point_io.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/circular_buffer/details.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/deque.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/flat_map.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/map.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/scoped_allocator.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/slist.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/stable_vector.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/string.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/container/vector.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/interprocess/ipc/message_queue.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/interprocess/managed_heap_memory.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/interprocess/segment_manager.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/interprocess/smart_ptr/weak_ptr.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/move/algorithm.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/msm/back/state_machine.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/multi_index/hashed_index.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/multi_index/ordered_index.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/multi_index/random_access_index.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/multi_index/sequenced_index.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/serialization/state_saver.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/thread/pthread/once_atomic.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/thread/pthread/once.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/thread/win32/once.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/thread/win32/thread_heap_alloc.hpp>", public ] },
+ { include: ["<boost/detail/no_exceptions_support.hpp>", private, "<boost/variant/variant.hpp>", public ] },
+ { include: ["<boost/detail/numeric_traits.hpp>", private, "<boost/graph/bandwidth.hpp>", public ] },
+ { include: ["<boost/detail/numeric_traits.hpp>", private, "<boost/graph/minimum_degree_ordering.hpp>", public ] },
+ { include: ["<boost/detail/numeric_traits.hpp>", private, "<boost/graph/profile.hpp>", public ] },
+ { include: ["<boost/detail/numeric_traits.hpp>", private, "<boost/graph/wavefront.hpp>", public ] },
+ { include: ["<boost/detail/numeric_traits.hpp>", private, "<boost/iterator/counting_iterator.hpp>", public ] },
+ { include: ["<boost/detail/ob_call_traits.hpp>", private, "<boost/call_traits.hpp>", public ] },
+ { include: ["<boost/detail/ob_compressed_pair.hpp>", private, "<boost/compressed_pair.hpp>", public ] },
+ { include: ["<boost/detail/reference_content.hpp>", private, "<boost/optional/optional.hpp>", public ] },
+ { include: ["<boost/detail/reference_content.hpp>", private, "<boost/variant/variant.hpp>", public ] },
+ { include: ["<boost/detail/scoped_enum_emulation.hpp>", private, "<boost/chrono/io/duration_style.hpp>", public ] },
+ { include: ["<boost/detail/scoped_enum_emulation.hpp>", private, "<boost/chrono/io/timezone.hpp>", public ] },
+ { include: ["<boost/detail/scoped_enum_emulation.hpp>", private, "<boost/coroutine/exceptions.hpp>", public ] },
+ { include: ["<boost/detail/scoped_enum_emulation.hpp>", private, "<boost/thread/cv_status.hpp>", public ] },
+ { include: ["<boost/detail/scoped_enum_emulation.hpp>", private, "<boost/thread/future_error_code.hpp>", public ] },
+ { include: ["<boost/detail/scoped_enum_emulation.hpp>", private, "<boost/thread/future.hpp>", public ] },
+ { include: ["<boost/detail/select_type.hpp>", private, "<boost/cast.hpp>", public ] },
+ { include: ["<boost/detail/sp_typeinfo.hpp>", private, "<boost/proto/debug.hpp>", public ] },
+ { include: ["<boost/detail/templated_streams.hpp>", private, "<boost/blank.hpp>", public ] },
+ { include: ["<boost/detail/templated_streams.hpp>", private, "<boost/flyweight/flyweight_fwd.hpp>", public ] },
+ { include: ["<boost/exception/detail/attribute_noreturn.hpp>", private, "<boost/throw_exception.hpp>", public ] },
+ { include: ["<boost/exception/detail/error_info_impl.hpp>", private, "<boost/exception/get_error_info.hpp>", public ] },
+ { include: ["<boost/exception/detail/error_info_impl.hpp>", private, "<boost/exception/info.hpp>", public ] },
+ { include: ["<boost/exception/detail/exception_ptr.hpp>", private, "<boost/exception_ptr.hpp>", public ] },
+ { include: ["<boost/exception/detail/is_output_streamable.hpp>", private, "<boost/exception/to_string.hpp>", public ] },
+ { include: ["<boost/exception/detail/object_hex_dump.hpp>", private, "<boost/exception/to_string_stub.hpp>", public ] },
+ { include: ["<boost/exception/detail/type_info.hpp>", private, "<boost/exception/get_error_info.hpp>", public ] },
+ { include: ["<boost/flyweight/detail/default_value_policy.hpp>", private, "<boost/flyweight/flyweight.hpp>", public ] },
+ { include: ["<boost/flyweight/detail/flyweight_core.hpp>", private, "<boost/flyweight/flyweight.hpp>", public ] },
+ { include: ["<boost/flyweight/detail/is_placeholder_expr.hpp>", private, "<boost/flyweight/assoc_container_factory.hpp>", public ] },
+ { include: ["<boost/flyweight/detail/nested_xxx_if_not_ph.hpp>", private, "<boost/flyweight/assoc_container_factory.hpp>", public ] },
+ { include: ["<boost/flyweight/detail/not_placeholder_expr.hpp>", private, "<boost/flyweight/assoc_container_factory_fwd.hpp>", public ] },
+ { include: ["<boost/flyweight/detail/not_placeholder_expr.hpp>", private, "<boost/flyweight/hashed_factory_fwd.hpp>", public ] },
+ { include: ["<boost/flyweight/detail/not_placeholder_expr.hpp>", private, "<boost/flyweight/set_factory_fwd.hpp>", public ] },
+ { include: ["<boost/flyweight/detail/perfect_fwd.hpp>", private, "<boost/flyweight/flyweight.hpp>", public ] },
+ { include: ["<boost/flyweight/detail/perfect_fwd.hpp>", private, "<boost/flyweight/key_value.hpp>", public ] },
+ { include: ["<boost/flyweight/detail/recursive_lw_mutex.hpp>", private, "<boost/flyweight/simple_locking.hpp>", public ] },
+ { include: ["<boost/flyweight/detail/value_tag.hpp>", private, "<boost/flyweight/key_value.hpp>", public ] },
+ { include: ["<boost/functional/hash/detail/hash_float.hpp>", private, "<boost/functional/hash/hash.hpp>", public ] },
+ { include: ["<boost/functional/overloaded_function/detail/base.hpp>", private, "<boost/functional/overloaded_function.hpp>", public ] },
+ { include: ["<boost/functional/overloaded_function/detail/function_type.hpp>", private, "<boost/functional/overloaded_function.hpp>", public ] },
+ { include: ["<boost/function_types/detail/classifier.hpp>", private, "<boost/function_types/components.hpp>", public ] },
+ { include: ["<boost/function_types/detail/class_transform.hpp>", private, "<boost/function_types/components.hpp>", public ] },
+ { include: ["<boost/function_types/detail/components_as_mpl_sequence.hpp>", private, "<boost/function_types/components.hpp>", public ] },
+ { include: ["<boost/function_types/detail/pp_loop.hpp>", private, "<boost/function_types/components.hpp>", public ] },
+ { include: ["<boost/function_types/detail/pp_loop.hpp>", private, "<boost/function_types/property_tags.hpp>", public ] },
+ { include: ["<boost/function_types/detail/pp_tags/preprocessed.hpp>", private, "<boost/function_types/property_tags.hpp>", public ] },
+ { include: ["<boost/function_types/detail/retag_default_cc.hpp>", private, "<boost/function_types/components.hpp>", public ] },
+ { include: ["<boost/function_types/detail/synthesize.hpp>", private, "<boost/function_types/function_type.hpp>", public ] },
+ { include: ["<boost/function_types/detail/synthesize.hpp>", private, "<boost/function_types/member_function_pointer.hpp>", public ] },
+ { include: ["<boost/function_types/detail/synthesize.hpp>", private, "<boost/function_types/member_object_pointer.hpp>", public ] },
+ { include: ["<boost/function_types/detail/to_sequence.hpp>", private, "<boost/function_types/function_type.hpp>", public ] },
+ { include: ["<boost/function_types/detail/to_sequence.hpp>", private, "<boost/function_types/member_function_pointer.hpp>", public ] },
+ { include: ["<boost/function_types/detail/to_sequence.hpp>", private, "<boost/function_types/member_object_pointer.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/adt/detail/adapt_base.hpp>", private, "<boost/fusion/adapted/adt/adapt_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/adt/detail/adapt_base.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/adt/detail/extension.hpp>", private, "<boost/fusion/adapted/adt/adapt_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/adt/detail/extension.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_array/detail/at_impl.hpp>", private, "<boost/fusion/adapted/boost_array.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_array/detail/begin_impl.hpp>", private, "<boost/fusion/adapted/boost_array.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_array/detail/category_of_impl.hpp>", private, "<boost/fusion/adapted/boost_array.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_array/detail/end_impl.hpp>", private, "<boost/fusion/adapted/boost_array.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_array/detail/is_sequence_impl.hpp>", private, "<boost/fusion/adapted/boost_array.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_array/detail/is_view_impl.hpp>", private, "<boost/fusion/adapted/boost_array.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_array/detail/size_impl.hpp>", private, "<boost/fusion/adapted/boost_array.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_array/detail/value_at_impl.hpp>", private, "<boost/fusion/adapted/boost_array.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_tuple/detail/at_impl.hpp>", private, "<boost/fusion/adapted/boost_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_tuple/detail/begin_impl.hpp>", private, "<boost/fusion/adapted/boost_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_tuple/detail/category_of_impl.hpp>", private, "<boost/fusion/adapted/boost_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_tuple/detail/end_impl.hpp>", private, "<boost/fusion/adapted/boost_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_tuple/detail/is_sequence_impl.hpp>", private, "<boost/fusion/adapted/boost_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_tuple/detail/is_view_impl.hpp>", private, "<boost/fusion/adapted/boost_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_tuple/detail/size_impl.hpp>", private, "<boost/fusion/adapted/boost_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/boost_tuple/detail/value_at_impl.hpp>", private, "<boost/fusion/adapted/boost_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/mpl/detail/at_impl.hpp>", private, "<boost/fusion/adapted/mpl.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/mpl/detail/begin_impl.hpp>", private, "<boost/fusion/adapted/mpl.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/mpl/detail/begin_impl.hpp>", private, "<boost/fusion/mpl/begin.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/mpl/detail/category_of_impl.hpp>", private, "<boost/fusion/adapted/mpl.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/mpl/detail/empty_impl.hpp>", private, "<boost/fusion/adapted/mpl.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/mpl/detail/end_impl.hpp>", private, "<boost/fusion/adapted/mpl.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/mpl/detail/end_impl.hpp>", private, "<boost/fusion/mpl/end.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/mpl/detail/has_key_impl.hpp>", private, "<boost/fusion/adapted/mpl.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/mpl/detail/is_sequence_impl.hpp>", private, "<boost/fusion/adapted/mpl.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/mpl/detail/is_view_impl.hpp>", private, "<boost/fusion/adapted/mpl.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/mpl/detail/size_impl.hpp>", private, "<boost/fusion/adapted/mpl.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/mpl/detail/value_at_impl.hpp>", private, "<boost/fusion/adapted/mpl.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/std_tuple/detail/at_impl.hpp>", private, "<boost/fusion/adapted/std_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/std_tuple/detail/begin_impl.hpp>", private, "<boost/fusion/adapted/std_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/std_tuple/detail/category_of_impl.hpp>", private, "<boost/fusion/adapted/std_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/std_tuple/detail/end_impl.hpp>", private, "<boost/fusion/adapted/std_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/std_tuple/detail/is_sequence_impl.hpp>", private, "<boost/fusion/adapted/std_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/std_tuple/detail/is_view_impl.hpp>", private, "<boost/fusion/adapted/std_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/std_tuple/detail/size_impl.hpp>", private, "<boost/fusion/adapted/std_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/std_tuple/detail/value_at_impl.hpp>", private, "<boost/fusion/adapted/std_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/adapt_base.hpp>", private, "<boost/fusion/adapted/adt/adapt_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/adapt_base.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/adapt_base.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/adapt_base.hpp>", private, "<boost/fusion/adapted/struct/adapt_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/at_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/at_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/at_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/at_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/begin_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/begin_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/begin_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/begin_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/category_of_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/category_of_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/category_of_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/category_of_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/define_struct.hpp>", private, "<boost/fusion/adapted/struct/define_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/define_struct.hpp>", private, "<boost/fusion/adapted/struct/define_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/define_struct_inline.hpp>", private, "<boost/fusion/adapted/struct/define_struct_inline.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/deref_data_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/deref_data_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/deref_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/deref_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/deref_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/deref_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/end_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/end_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/end_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/end_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/extension.hpp>", private, "<boost/fusion/adapted/adt/adapt_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/extension.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/extension.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/extension.hpp>", private, "<boost/fusion/adapted/struct/adapt_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/is_sequence_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/is_sequence_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/is_sequence_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/is_sequence_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/is_view_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/is_view_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/is_view_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/is_view_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/key_of_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/key_of_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/proxy_type.hpp>", private, "<boost/fusion/adapted/adt/adapt_adt_named.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/proxy_type.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt_named.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/proxy_type.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct_named.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/proxy_type.hpp>", private, "<boost/fusion/adapted/struct/adapt_struct_named.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/proxy_type.hpp>", private, "<boost/fusion/include/proxy_type.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/size_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/size_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/size_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/size_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/value_at_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/value_at_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/value_at_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/value_at_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/value_of_data_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/value_of_data_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/value_of_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/value_of_impl.hpp>", private, "<boost/fusion/adapted/adt/adapt_assoc_adt.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/value_of_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_assoc_struct.hpp>", public ] },
+ { include: ["<boost/fusion/adapted/struct/detail/value_of_impl.hpp>", private, "<boost/fusion/adapted/struct/adapt_struct.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/iteration/detail/fold.hpp>", private, "<boost/fusion/algorithm/iteration/fold.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/iteration/detail/fold.hpp>", private, "<boost/fusion/algorithm/iteration/iter_fold.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/iteration/detail/fold.hpp>", private, "<boost/fusion/algorithm/iteration/reverse_fold.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/iteration/detail/fold.hpp>", private, "<boost/fusion/algorithm/iteration/reverse_iter_fold.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/iteration/detail/for_each.hpp>", private, "<boost/fusion/algorithm/iteration/for_each.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/iteration/detail/preprocessed/fold.hpp>", private, "<boost/fusion/algorithm/iteration/fold.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/iteration/detail/preprocessed/iter_fold.hpp>", private, "<boost/fusion/algorithm/iteration/iter_fold.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/iteration/detail/preprocessed/reverse_fold.hpp>", private, "<boost/fusion/algorithm/iteration/reverse_fold.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/iteration/detail/preprocessed/reverse_iter_fold.hpp>", private, "<boost/fusion/algorithm/iteration/reverse_iter_fold.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/iteration/detail/segmented_fold.hpp>", private, "<boost/fusion/algorithm/iteration/fold.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/iteration/detail/segmented_for_each.hpp>", private, "<boost/fusion/algorithm/iteration/for_each.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/query/detail/all.hpp>", private, "<boost/fusion/algorithm/query/all.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/query/detail/any.hpp>", private, "<boost/fusion/algorithm/query/any.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/query/detail/count.hpp>", private, "<boost/fusion/algorithm/query/count.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/query/detail/count_if.hpp>", private, "<boost/fusion/algorithm/query/count_if.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/query/detail/find_if.hpp>", private, "<boost/fusion/algorithm/query/find.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/query/detail/find_if.hpp>", private, "<boost/fusion/algorithm/query/find_if.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/query/detail/find_if.hpp>", private, "<boost/fusion/view/filter_view/filter_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/query/detail/segmented_find.hpp>", private, "<boost/fusion/algorithm/query/find.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/query/detail/segmented_find_if.hpp>", private, "<boost/fusion/algorithm/query/find_if.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/transformation/detail/preprocessed/zip.hpp>", private, "<boost/fusion/algorithm/transformation/zip.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/transformation/detail/replace.hpp>", private, "<boost/fusion/algorithm/transformation/replace.hpp>", public ] },
+ { include: ["<boost/fusion/algorithm/transformation/detail/replace_if.hpp>", private, "<boost/fusion/algorithm/transformation/replace_if.hpp>", public ] },
+ { include: ["<boost/fusion/container/deque/detail/at_impl.hpp>", private, "<boost/fusion/container/deque/deque.hpp>", public ] },
+ { include: ["<boost/fusion/container/deque/detail/begin_impl.hpp>", private, "<boost/fusion/container/deque/deque.hpp>", public ] },
+ { include: ["<boost/fusion/container/deque/detail/build_deque.hpp>", private, "<boost/fusion/container/deque/convert.hpp>", public ] },
+ { include: ["<boost/fusion/container/deque/detail/convert_impl.hpp>", private, "<boost/fusion/container/deque/convert.hpp>", public ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/build_deque.hpp>", private, "<boost/fusion/container/deque/convert.hpp>", public ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/deque_fwd.hpp>", private, "<boost/fusion/container/deque/deque_fwd.hpp>", public ] },
+ { include: ["<boost/fusion/container/deque/detail/cpp03/deque.hpp>", private, "<boost/fusion/container/deque/deque.hpp>", public ] },
+ { include: ["<boost/fusion/container/deque/detail/deque_keyed_values.hpp>", private, "<boost/fusion/container/deque/deque.hpp>", public ] },
+ { include: ["<boost/fusion/container/deque/detail/end_impl.hpp>", private, "<boost/fusion/container/deque/deque.hpp>", public ] },
+ { include: ["<boost/fusion/container/deque/detail/is_sequence_impl.hpp>", private, "<boost/fusion/container/deque/deque.hpp>", public ] },
+ { include: ["<boost/fusion/container/deque/detail/keyed_element.hpp>", private, "<boost/fusion/container/deque/back_extended_deque.hpp>", public ] },
+ { include: ["<boost/fusion/container/deque/detail/keyed_element.hpp>", private, "<boost/fusion/container/deque/deque.hpp>", public ] },
+ { include: ["<boost/fusion/container/deque/detail/keyed_element.hpp>", private, "<boost/fusion/container/deque/deque_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/container/deque/detail/keyed_element.hpp>", private, "<boost/fusion/container/deque/front_extended_deque.hpp>", public ] },
+ { include: ["<boost/fusion/container/deque/detail/value_at_impl.hpp>", private, "<boost/fusion/container/deque/deque.hpp>", public ] },
+ { include: ["<boost/fusion/container/generation/detail/pp_deque_tie.hpp>", private, "<boost/fusion/container/generation/deque_tie.hpp>", public ] },
+ { include: ["<boost/fusion/container/generation/detail/pp_make_deque.hpp>", private, "<boost/fusion/container/generation/make_deque.hpp>", public ] },
+ { include: ["<boost/fusion/container/generation/detail/pp_make_map.hpp>", private, "<boost/fusion/container/generation/make_map.hpp>", public ] },
+ { include: ["<boost/fusion/container/generation/detail/pp_map_tie.hpp>", private, "<boost/fusion/container/generation/map_tie.hpp>", public ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/list_tie.hpp>", private, "<boost/fusion/container/generation/list_tie.hpp>", public ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_list.hpp>", private, "<boost/fusion/container/generation/make_list.hpp>", public ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_set.hpp>", private, "<boost/fusion/container/generation/make_set.hpp>", public ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/make_vector.hpp>", private, "<boost/fusion/container/generation/make_vector.hpp>", public ] },
+ { include: ["<boost/fusion/container/generation/detail/preprocessed/vector_tie.hpp>", private, "<boost/fusion/container/generation/vector_tie.hpp>", public ] },
+ { include: ["<boost/fusion/container/list/detail/at_impl.hpp>", private, "<boost/fusion/container/list/cons.hpp>", public ] },
+ { include: ["<boost/fusion/container/list/detail/begin_impl.hpp>", private, "<boost/fusion/container/list/cons.hpp>", public ] },
+ { include: ["<boost/fusion/container/list/detail/build_cons.hpp>", private, "<boost/fusion/container/list/convert.hpp>", public ] },
+ { include: ["<boost/fusion/container/list/detail/convert_impl.hpp>", private, "<boost/fusion/container/list/convert.hpp>", public ] },
+ { include: ["<boost/fusion/container/list/detail/deref_impl.hpp>", private, "<boost/fusion/container/list/cons_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/container/list/detail/empty_impl.hpp>", private, "<boost/fusion/container/list/cons.hpp>", public ] },
+ { include: ["<boost/fusion/container/list/detail/end_impl.hpp>", private, "<boost/fusion/container/list/cons.hpp>", public ] },
+ { include: ["<boost/fusion/container/list/detail/equal_to_impl.hpp>", private, "<boost/fusion/container/list/cons_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/container/list/detail/list_forward_ctor.hpp>", private, "<boost/fusion/container/list/list.hpp>", public ] },
+ { include: ["<boost/fusion/container/list/detail/list_to_cons.hpp>", private, "<boost/fusion/container/list/list.hpp>", public ] },
+ { include: ["<boost/fusion/container/list/detail/next_impl.hpp>", private, "<boost/fusion/container/list/cons_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list_fwd.hpp>", private, "<boost/fusion/container/list/list_fwd.hpp>", public ] },
+ { include: ["<boost/fusion/container/list/detail/preprocessed/list.hpp>", private, "<boost/fusion/container/list/list.hpp>", public ] },
+ { include: ["<boost/fusion/container/list/detail/value_at_impl.hpp>", private, "<boost/fusion/container/list/cons.hpp>", public ] },
+ { include: ["<boost/fusion/container/list/detail/value_of_impl.hpp>", private, "<boost/fusion/container/list/cons_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/container/map/detail/at_impl.hpp>", private, "<boost/fusion/container/map/map.hpp>", public ] },
+ { include: ["<boost/fusion/container/map/detail/at_key_impl.hpp>", private, "<boost/fusion/container/map/map.hpp>", public ] },
+ { include: ["<boost/fusion/container/map/detail/begin_impl.hpp>", private, "<boost/fusion/container/map/map.hpp>", public ] },
+ { include: ["<boost/fusion/container/map/detail/build_map.hpp>", private, "<boost/fusion/container/map/convert.hpp>", public ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/convert.hpp>", private, "<boost/fusion/container/map/convert.hpp>", public ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/map_fwd.hpp>", private, "<boost/fusion/container/map/map_fwd.hpp>", public ] },
+ { include: ["<boost/fusion/container/map/detail/cpp03/map.hpp>", private, "<boost/fusion/container/map/map.hpp>", public ] },
+ { include: ["<boost/fusion/container/map/detail/end_impl.hpp>", private, "<boost/fusion/container/map/map.hpp>", public ] },
+ { include: ["<boost/fusion/container/map/detail/map_impl.hpp>", private, "<boost/fusion/container/map/map_fwd.hpp>", public ] },
+ { include: ["<boost/fusion/container/map/detail/map_impl.hpp>", private, "<boost/fusion/container/map/map.hpp>", public ] },
+ { include: ["<boost/fusion/container/map/detail/value_at_impl.hpp>", private, "<boost/fusion/container/map/map.hpp>", public ] },
+ { include: ["<boost/fusion/container/map/detail/value_at_key_impl.hpp>", private, "<boost/fusion/container/map/map.hpp>", public ] },
+ { include: ["<boost/fusion/container/set/detail/as_set.hpp>", private, "<boost/fusion/container/set/convert.hpp>", public ] },
+ { include: ["<boost/fusion/container/set/detail/begin_impl.hpp>", private, "<boost/fusion/container/set/set.hpp>", public ] },
+ { include: ["<boost/fusion/container/set/detail/convert_impl.hpp>", private, "<boost/fusion/container/set/convert.hpp>", public ] },
+ { include: ["<boost/fusion/container/set/detail/deref_data_impl.hpp>", private, "<boost/fusion/container/set/set.hpp>", public ] },
+ { include: ["<boost/fusion/container/set/detail/deref_impl.hpp>", private, "<boost/fusion/container/set/set.hpp>", public ] },
+ { include: ["<boost/fusion/container/set/detail/end_impl.hpp>", private, "<boost/fusion/container/set/set.hpp>", public ] },
+ { include: ["<boost/fusion/container/set/detail/key_of_impl.hpp>", private, "<boost/fusion/container/set/set.hpp>", public ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/set_fwd.hpp>", private, "<boost/fusion/container/set/set_fwd.hpp>", public ] },
+ { include: ["<boost/fusion/container/set/detail/preprocessed/set.hpp>", private, "<boost/fusion/container/set/set.hpp>", public ] },
+ { include: ["<boost/fusion/container/set/detail/set_forward_ctor.hpp>", private, "<boost/fusion/container/set/set.hpp>", public ] },
+ { include: ["<boost/fusion/container/set/detail/value_of_data_impl.hpp>", private, "<boost/fusion/container/set/set.hpp>", public ] },
+ { include: ["<boost/fusion/container/set/detail/value_of_impl.hpp>", private, "<boost/fusion/container/set/set.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/advance_impl.hpp>", private, "<boost/fusion/container/vector/vector_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/as_vector.hpp>", private, "<boost/fusion/container/vector/convert.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/at_impl.hpp>", private, "<boost/fusion/container/vector/vector10.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/at_impl.hpp>", private, "<boost/fusion/container/vector/vector20.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/at_impl.hpp>", private, "<boost/fusion/container/vector/vector30.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/at_impl.hpp>", private, "<boost/fusion/container/vector/vector40.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/at_impl.hpp>", private, "<boost/fusion/container/vector/vector50.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/begin_impl.hpp>", private, "<boost/fusion/container/vector/vector10.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/begin_impl.hpp>", private, "<boost/fusion/container/vector/vector20.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/begin_impl.hpp>", private, "<boost/fusion/container/vector/vector30.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/begin_impl.hpp>", private, "<boost/fusion/container/vector/vector40.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/begin_impl.hpp>", private, "<boost/fusion/container/vector/vector50.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/convert_impl.hpp>", private, "<boost/fusion/container/vector/convert.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/deref_impl.hpp>", private, "<boost/fusion/container/vector/vector_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/distance_impl.hpp>", private, "<boost/fusion/container/vector/vector_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/end_impl.hpp>", private, "<boost/fusion/container/vector/vector10.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/end_impl.hpp>", private, "<boost/fusion/container/vector/vector20.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/end_impl.hpp>", private, "<boost/fusion/container/vector/vector30.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/end_impl.hpp>", private, "<boost/fusion/container/vector/vector40.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/end_impl.hpp>", private, "<boost/fusion/container/vector/vector50.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/equal_to_impl.hpp>", private, "<boost/fusion/container/vector/vector_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/next_impl.hpp>", private, "<boost/fusion/container/vector/vector_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector10_fwd.hpp>", private, "<boost/fusion/container/vector/vector10_fwd.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector10.hpp>", private, "<boost/fusion/container/vector/vector10.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector20_fwd.hpp>", private, "<boost/fusion/container/vector/vector20_fwd.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector20.hpp>", private, "<boost/fusion/container/vector/vector20.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector30_fwd.hpp>", private, "<boost/fusion/container/vector/vector30_fwd.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector30.hpp>", private, "<boost/fusion/container/vector/vector30.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector40_fwd.hpp>", private, "<boost/fusion/container/vector/vector40_fwd.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector40.hpp>", private, "<boost/fusion/container/vector/vector40.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector50_fwd.hpp>", private, "<boost/fusion/container/vector/vector50_fwd.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector50.hpp>", private, "<boost/fusion/container/vector/vector50.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector_fwd.hpp>", private, "<boost/fusion/container/vector/vector_fwd.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/preprocessed/vector.hpp>", private, "<boost/fusion/container/vector/vector.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/prior_impl.hpp>", private, "<boost/fusion/container/vector/vector_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/value_at_impl.hpp>", private, "<boost/fusion/container/vector/vector10.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/value_at_impl.hpp>", private, "<boost/fusion/container/vector/vector20.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/value_at_impl.hpp>", private, "<boost/fusion/container/vector/vector30.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/value_at_impl.hpp>", private, "<boost/fusion/container/vector/vector40.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/value_at_impl.hpp>", private, "<boost/fusion/container/vector/vector50.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/value_of_impl.hpp>", private, "<boost/fusion/container/vector/vector_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/vector_forward_ctor.hpp>", private, "<boost/fusion/container/vector/vector.hpp>", public ] },
+ { include: ["<boost/fusion/container/vector/detail/vector_n_chooser.hpp>", private, "<boost/fusion/container/vector/vector.hpp>", public ] },
+ { include: ["<boost/fusion/functional/adapter/detail/access.hpp>", private, "<boost/fusion/functional/adapter/fused_function_object.hpp>", public ] },
+ { include: ["<boost/fusion/functional/adapter/detail/access.hpp>", private, "<boost/fusion/functional/adapter/fused.hpp>", public ] },
+ { include: ["<boost/fusion/functional/adapter/detail/access.hpp>", private, "<boost/fusion/functional/adapter/fused_procedure.hpp>", public ] },
+ { include: ["<boost/fusion/functional/adapter/detail/access.hpp>", private, "<boost/fusion/functional/adapter/unfused.hpp>", public ] },
+ { include: ["<boost/fusion/functional/adapter/detail/access.hpp>", private, "<boost/fusion/functional/adapter/unfused_typed.hpp>", public ] },
+ { include: ["<boost/fusion/functional/generation/detail/gen_make_adapter.hpp>", private, "<boost/fusion/functional/generation/make_fused_function_object.hpp>", public ] },
+ { include: ["<boost/fusion/functional/generation/detail/gen_make_adapter.hpp>", private, "<boost/fusion/functional/generation/make_fused.hpp>", public ] },
+ { include: ["<boost/fusion/functional/generation/detail/gen_make_adapter.hpp>", private, "<boost/fusion/functional/generation/make_fused_procedure.hpp>", public ] },
+ { include: ["<boost/fusion/functional/generation/detail/gen_make_adapter.hpp>", private, "<boost/fusion/functional/generation/make_unfused.hpp>", public ] },
+ { include: ["<boost/fusion/functional/invocation/detail/that_ptr.hpp>", private, "<boost/fusion/functional/invocation/invoke.hpp>", public ] },
+ { include: ["<boost/fusion/functional/invocation/detail/that_ptr.hpp>", private, "<boost/fusion/functional/invocation/invoke_procedure.hpp>", public ] },
+ { include: ["<boost/fusion/iterator/detail/advance.hpp>", private, "<boost/fusion/iterator/advance.hpp>", public ] },
+ { include: ["<boost/fusion/iterator/detail/advance.hpp>", private, "<boost/fusion/iterator/iterator_adapter.hpp>", public ] },
+ { include: ["<boost/fusion/iterator/detail/advance.hpp>", private, "<boost/fusion/iterator/iterator_facade.hpp>", public ] },
+ { include: ["<boost/fusion/iterator/detail/distance.hpp>", private, "<boost/fusion/iterator/distance.hpp>", public ] },
+ { include: ["<boost/fusion/iterator/detail/distance.hpp>", private, "<boost/fusion/iterator/iterator_facade.hpp>", public ] },
+ { include: ["<boost/fusion/iterator/detail/segmented_iterator.hpp>", private, "<boost/fusion/iterator/segmented_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/iterator/detail/segmented_next_impl.hpp>", private, "<boost/fusion/iterator/segmented_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/mpl/detail/clear.hpp>", private, "<boost/fusion/mpl/clear.hpp>", public ] },
+ { include: ["<boost/fusion/sequence/comparison/detail/equal_to.hpp>", private, "<boost/fusion/algorithm/auxiliary/copy.hpp>", public ] },
+ { include: ["<boost/fusion/sequence/comparison/detail/equal_to.hpp>", private, "<boost/fusion/algorithm/auxiliary/move.hpp>", public ] },
+ { include: ["<boost/fusion/sequence/comparison/detail/equal_to.hpp>", private, "<boost/fusion/sequence/comparison/equal_to.hpp>", public ] },
+ { include: ["<boost/fusion/sequence/comparison/detail/greater_equal.hpp>", private, "<boost/fusion/sequence/comparison/greater_equal.hpp>", public ] },
+ { include: ["<boost/fusion/sequence/comparison/detail/greater.hpp>", private, "<boost/fusion/sequence/comparison/greater.hpp>", public ] },
+ { include: ["<boost/fusion/sequence/comparison/detail/less_equal.hpp>", private, "<boost/fusion/sequence/comparison/less_equal.hpp>", public ] },
+ { include: ["<boost/fusion/sequence/comparison/detail/less.hpp>", private, "<boost/fusion/sequence/comparison/less.hpp>", public ] },
+ { include: ["<boost/fusion/sequence/comparison/detail/not_equal_to.hpp>", private, "<boost/fusion/sequence/comparison/not_equal_to.hpp>", public ] },
+ { include: ["<boost/fusion/sequence/intrinsic/detail/segmented_begin.hpp>", private, "<boost/fusion/sequence/intrinsic/begin.hpp>", public ] },
+ { include: ["<boost/fusion/sequence/intrinsic/detail/segmented_end.hpp>", private, "<boost/fusion/sequence/intrinsic/end.hpp>", public ] },
+ { include: ["<boost/fusion/sequence/intrinsic/detail/segmented_size.hpp>", private, "<boost/fusion/sequence/intrinsic/size.hpp>", public ] },
+ { include: ["<boost/fusion/sequence/io/detail/in.hpp>", private, "<boost/fusion/sequence/io/in.hpp>", public ] },
+ { include: ["<boost/fusion/sequence/io/detail/out.hpp>", private, "<boost/fusion/sequence/io/out.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/adapted/std_tuple/std_tuple_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/deque/deque.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/list/cons.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/set/set.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/vector/vector10.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/vector/vector20.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/vector/vector30.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/vector/vector40.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/container/vector/vector50.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/functional/adapter/unfused_typed.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/sequence/intrinsic/at.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/sequence/intrinsic/at_key.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/support/pair.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/view/filter_view/filter_view.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/view/iterator_range/iterator_range.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/view/joint_view/joint_view.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/view/reverse_view/reverse_view.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/view/single_view/single_view.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/view/single_view/single_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/access.hpp>", private, "<boost/fusion/view/transform_view/transform_view.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/algorithm/transformation/erase.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/algorithm/transformation/insert.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/algorithm/transformation/insert_range.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/algorithm/transformation/push_back.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/algorithm/transformation/push_front.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/container/generation/deque_tie.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/container/generation/make_cons.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/container/generation/make_deque.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/container/generation/make_list.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/container/generation/make_map.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/container/generation/make_set.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/container/generation/make_vector.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/support/pair.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/tuple/make_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/as_fusion_element.hpp>", private, "<boost/fusion/view/single_view/single_view.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/category_of.hpp>", private, "<boost/fusion/support/category_of.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/is_mpl_sequence.hpp>", private, "<boost/fusion/support/tag_of.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/is_view.hpp>", private, "<boost/fusion/support/is_view.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/mpl_iterator_category.hpp>", private, "<boost/fusion/adapted/mpl/mpl_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/pp_round.hpp>", private, "<boost/fusion/algorithm/transformation/zip.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/pp_round.hpp>", private, "<boost/fusion/container/vector/limits.hpp>", public ] },
+ { include: ["<boost/fusion/support/detail/segmented_fold_until_impl.hpp>", private, "<boost/fusion/support/segmented_fold_until.hpp>", public ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/make_tuple.hpp>", private, "<boost/fusion/tuple/make_tuple.hpp>", public ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple_fwd.hpp>", private, "<boost/fusion/tuple/tuple_fwd.hpp>", public ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple.hpp>", private, "<boost/fusion/tuple/tuple.hpp>", public ] },
+ { include: ["<boost/fusion/tuple/detail/preprocessed/tuple_tie.hpp>", private, "<boost/fusion/tuple/tuple_tie.hpp>", public ] },
+ { include: ["<boost/fusion/tuple/detail/tuple_expand.hpp>", private, "<boost/fusion/tuple/tuple.hpp>", public ] },
+ { include: ["<boost/fusion/view/detail/strictest_traversal.hpp>", private, "<boost/fusion/view/transform_view/transform_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/detail/strictest_traversal.hpp>", private, "<boost/fusion/view/zip_view/zip_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/detail/strictest_traversal.hpp>", private, "<boost/fusion/view/zip_view/zip_view_iterator_fwd.hpp>", public ] },
+ { include: ["<boost/fusion/view/filter_view/detail/begin_impl.hpp>", private, "<boost/fusion/view/filter_view/filter_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/filter_view/detail/deref_data_impl.hpp>", private, "<boost/fusion/view/filter_view/filter_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/filter_view/detail/deref_impl.hpp>", private, "<boost/fusion/view/filter_view/filter_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/filter_view/detail/end_impl.hpp>", private, "<boost/fusion/view/filter_view/filter_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/filter_view/detail/equal_to_impl.hpp>", private, "<boost/fusion/view/filter_view/filter_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/filter_view/detail/key_of_impl.hpp>", private, "<boost/fusion/view/filter_view/filter_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/filter_view/detail/next_impl.hpp>", private, "<boost/fusion/view/filter_view/filter_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/filter_view/detail/size_impl.hpp>", private, "<boost/fusion/view/filter_view/filter_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/filter_view/detail/value_of_data_impl.hpp>", private, "<boost/fusion/view/filter_view/filter_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/filter_view/detail/value_of_impl.hpp>", private, "<boost/fusion/view/filter_view/filter_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/iterator_range/detail/at_impl.hpp>", private, "<boost/fusion/view/iterator_range/iterator_range.hpp>", public ] },
+ { include: ["<boost/fusion/view/iterator_range/detail/begin_impl.hpp>", private, "<boost/fusion/view/iterator_range/iterator_range.hpp>", public ] },
+ { include: ["<boost/fusion/view/iterator_range/detail/end_impl.hpp>", private, "<boost/fusion/view/iterator_range/iterator_range.hpp>", public ] },
+ { include: ["<boost/fusion/view/iterator_range/detail/is_segmented_impl.hpp>", private, "<boost/fusion/view/iterator_range/iterator_range.hpp>", public ] },
+ { include: ["<boost/fusion/view/iterator_range/detail/segments_impl.hpp>", private, "<boost/fusion/view/iterator_range/iterator_range.hpp>", public ] },
+ { include: ["<boost/fusion/view/iterator_range/detail/size_impl.hpp>", private, "<boost/fusion/view/iterator_range/iterator_range.hpp>", public ] },
+ { include: ["<boost/fusion/view/iterator_range/detail/value_at_impl.hpp>", private, "<boost/fusion/view/iterator_range/iterator_range.hpp>", public ] },
+ { include: ["<boost/fusion/view/joint_view/detail/begin_impl.hpp>", private, "<boost/fusion/view/joint_view/joint_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/joint_view/detail/deref_data_impl.hpp>", private, "<boost/fusion/view/joint_view/joint_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/joint_view/detail/deref_impl.hpp>", private, "<boost/fusion/view/joint_view/joint_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/joint_view/detail/end_impl.hpp>", private, "<boost/fusion/view/joint_view/joint_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/joint_view/detail/key_of_impl.hpp>", private, "<boost/fusion/view/joint_view/joint_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/joint_view/detail/next_impl.hpp>", private, "<boost/fusion/view/joint_view/joint_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/joint_view/detail/value_of_data_impl.hpp>", private, "<boost/fusion/view/joint_view/joint_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/joint_view/detail/value_of_impl.hpp>", private, "<boost/fusion/view/joint_view/joint_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/nview/detail/advance_impl.hpp>", private, "<boost/fusion/view/nview/nview_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/nview/detail/at_impl.hpp>", private, "<boost/fusion/view/nview/nview_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/nview/detail/begin_impl.hpp>", private, "<boost/fusion/view/nview/nview_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/nview/detail/deref_impl.hpp>", private, "<boost/fusion/view/nview/nview_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/nview/detail/distance_impl.hpp>", private, "<boost/fusion/view/nview/nview_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/nview/detail/end_impl.hpp>", private, "<boost/fusion/view/nview/nview_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/nview/detail/equal_to_impl.hpp>", private, "<boost/fusion/view/nview/nview_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/nview/detail/next_impl.hpp>", private, "<boost/fusion/view/nview/nview_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/nview/detail/nview_impl.hpp>", private, "<boost/fusion/view/nview/nview.hpp>", public ] },
+ { include: ["<boost/fusion/view/nview/detail/prior_impl.hpp>", private, "<boost/fusion/view/nview/nview_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/nview/detail/size_impl.hpp>", private, "<boost/fusion/view/nview/nview_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/nview/detail/value_at_impl.hpp>", private, "<boost/fusion/view/nview/nview_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/nview/detail/value_of_impl.hpp>", private, "<boost/fusion/view/nview/nview_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/repetitive_view/detail/begin_impl.hpp>", private, "<boost/fusion/view/repetitive_view/repetitive_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/repetitive_view/detail/deref_impl.hpp>", private, "<boost/fusion/view/repetitive_view/repetitive_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/repetitive_view/detail/end_impl.hpp>", private, "<boost/fusion/view/repetitive_view/repetitive_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/repetitive_view/detail/next_impl.hpp>", private, "<boost/fusion/view/repetitive_view/repetitive_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/repetitive_view/detail/value_of_impl.hpp>", private, "<boost/fusion/view/repetitive_view/repetitive_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/reverse_view/detail/advance_impl.hpp>", private, "<boost/fusion/view/reverse_view/reverse_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/reverse_view/detail/at_impl.hpp>", private, "<boost/fusion/view/reverse_view/reverse_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/reverse_view/detail/begin_impl.hpp>", private, "<boost/fusion/view/reverse_view/reverse_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/reverse_view/detail/deref_data_impl.hpp>", private, "<boost/fusion/view/reverse_view/reverse_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/reverse_view/detail/deref_impl.hpp>", private, "<boost/fusion/view/reverse_view/reverse_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/reverse_view/detail/distance_impl.hpp>", private, "<boost/fusion/view/reverse_view/reverse_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/reverse_view/detail/end_impl.hpp>", private, "<boost/fusion/view/reverse_view/reverse_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/reverse_view/detail/key_of_impl.hpp>", private, "<boost/fusion/view/reverse_view/reverse_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/reverse_view/detail/next_impl.hpp>", private, "<boost/fusion/view/reverse_view/reverse_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/reverse_view/detail/prior_impl.hpp>", private, "<boost/fusion/view/reverse_view/reverse_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/reverse_view/detail/value_at_impl.hpp>", private, "<boost/fusion/view/reverse_view/reverse_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/reverse_view/detail/value_of_data_impl.hpp>", private, "<boost/fusion/view/reverse_view/reverse_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/reverse_view/detail/value_of_impl.hpp>", private, "<boost/fusion/view/reverse_view/reverse_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/single_view/detail/advance_impl.hpp>", private, "<boost/fusion/view/single_view/single_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/single_view/detail/at_impl.hpp>", private, "<boost/fusion/view/single_view/single_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/single_view/detail/begin_impl.hpp>", private, "<boost/fusion/view/single_view/single_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/single_view/detail/deref_impl.hpp>", private, "<boost/fusion/view/single_view/single_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/single_view/detail/distance_impl.hpp>", private, "<boost/fusion/view/single_view/single_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/single_view/detail/end_impl.hpp>", private, "<boost/fusion/view/single_view/single_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/single_view/detail/equal_to_impl.hpp>", private, "<boost/fusion/view/single_view/single_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/single_view/detail/next_impl.hpp>", private, "<boost/fusion/view/single_view/single_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/single_view/detail/prior_impl.hpp>", private, "<boost/fusion/view/single_view/single_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/single_view/detail/size_impl.hpp>", private, "<boost/fusion/view/single_view/single_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/single_view/detail/value_at_impl.hpp>", private, "<boost/fusion/view/single_view/single_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/single_view/detail/value_of_impl.hpp>", private, "<boost/fusion/view/single_view/single_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/transform_view/detail/advance_impl.hpp>", private, "<boost/fusion/view/transform_view/transform_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/transform_view/detail/at_impl.hpp>", private, "<boost/fusion/view/transform_view/transform_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/transform_view/detail/begin_impl.hpp>", private, "<boost/fusion/view/transform_view/transform_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/transform_view/detail/deref_impl.hpp>", private, "<boost/fusion/view/transform_view/transform_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/transform_view/detail/distance_impl.hpp>", private, "<boost/fusion/view/transform_view/transform_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/transform_view/detail/end_impl.hpp>", private, "<boost/fusion/view/transform_view/transform_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/transform_view/detail/equal_to_impl.hpp>", private, "<boost/fusion/view/transform_view/transform_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/transform_view/detail/next_impl.hpp>", private, "<boost/fusion/view/transform_view/transform_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/transform_view/detail/prior_impl.hpp>", private, "<boost/fusion/view/transform_view/transform_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/transform_view/detail/value_at_impl.hpp>", private, "<boost/fusion/view/transform_view/transform_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/transform_view/detail/value_of_impl.hpp>", private, "<boost/fusion/view/transform_view/transform_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/zip_view/detail/advance_impl.hpp>", private, "<boost/fusion/view/zip_view/zip_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/zip_view/detail/at_impl.hpp>", private, "<boost/fusion/view/zip_view/zip_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/zip_view/detail/begin_impl.hpp>", private, "<boost/fusion/view/zip_view/zip_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/zip_view/detail/deref_impl.hpp>", private, "<boost/fusion/view/zip_view/zip_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/zip_view/detail/distance_impl.hpp>", private, "<boost/fusion/view/zip_view/zip_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/zip_view/detail/end_impl.hpp>", private, "<boost/fusion/view/zip_view/zip_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/zip_view/detail/equal_to_impl.hpp>", private, "<boost/fusion/view/zip_view/zip_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/zip_view/detail/next_impl.hpp>", private, "<boost/fusion/view/zip_view/zip_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/zip_view/detail/prior_impl.hpp>", private, "<boost/fusion/view/zip_view/zip_view_iterator.hpp>", public ] },
+ { include: ["<boost/fusion/view/zip_view/detail/size_impl.hpp>", private, "<boost/fusion/view/zip_view/zip_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/zip_view/detail/value_at_impl.hpp>", private, "<boost/fusion/view/zip_view/zip_view.hpp>", public ] },
+ { include: ["<boost/fusion/view/zip_view/detail/value_of_impl.hpp>", private, "<boost/fusion/view/zip_view/zip_view_iterator.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/as_range.hpp>", private, "<boost/geometry/algorithms/convex_hull.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/assign_box_corners.hpp>", private, "<boost/geometry/algorithms/assign.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/assign_box_corners.hpp>", private, "<boost/geometry/algorithms/convert.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/assign_box_corners.hpp>", private, "<boost/geometry/algorithms/convex_hull.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/assign_indexed_point.hpp>", private, "<boost/geometry/algorithms/assign.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/assign_indexed_point.hpp>", private, "<boost/geometry/algorithms/convert.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/assign_values.hpp>", private, "<boost/geometry/algorithms/assign.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/assign_values.hpp>", private, "<boost/geometry/algorithms/convert.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/assign_values.hpp>", private, "<boost/geometry/strategies/cartesian/cart_intersect.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/calculate_null.hpp>", private, "<boost/geometry/algorithms/area.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/calculate_null.hpp>", private, "<boost/geometry/algorithms/length.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/calculate_null.hpp>", private, "<boost/geometry/algorithms/perimeter.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/calculate_sum.hpp>", private, "<boost/geometry/algorithms/area.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/calculate_sum.hpp>", private, "<boost/geometry/algorithms/perimeter.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/convert_indexed_to_indexed.hpp>", private, "<boost/geometry/algorithms/convert.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/convert_point_to_point.hpp>", private, "<boost/geometry/algorithms/append.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/convert_point_to_point.hpp>", private, "<boost/geometry/algorithms/convert.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/disjoint.hpp>", private, "<boost/geometry/algorithms/buffer.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/disjoint.hpp>", private, "<boost/geometry/algorithms/disjoint.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/disjoint.hpp>", private, "<boost/geometry/algorithms/equals.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/equals/collect_vectors.hpp>", private, "<boost/geometry/algorithms/equals.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/for_each_range.hpp>", private, "<boost/geometry/algorithms/disjoint.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/for_each_range.hpp>", private, "<boost/geometry/strategies/agnostic/hull_graham_andrew.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/not.hpp>", private, "<boost/geometry/algorithms/equals.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/get_turns.hpp>", private, "<boost/geometry/algorithms/disjoint.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/get_turns.hpp>", private, "<boost/geometry/algorithms/touches.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/intersection_insert.hpp>", private, "<boost/geometry/algorithms/difference.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/intersection_insert.hpp>", private, "<boost/geometry/algorithms/intersection.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/overlay.hpp>", private, "<boost/geometry/algorithms/union.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/self_turn_points.hpp>", private, "<boost/geometry/algorithms/intersects.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/overlay/self_turn_points.hpp>", private, "<boost/geometry/algorithms/touches.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/point_on_border.hpp>", private, "<boost/geometry/algorithms/disjoint.hpp>", public ] },
+ { include: ["<boost/geometry/algorithms/detail/throw_on_empty_input.hpp>", private, "<boost/geometry/algorithms/distance.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/algorithms/is_valid.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/assert.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/config_begin.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/config_end.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/distance_predicates.hpp>", private, "<boost/geometry/index/distance_predicates.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/exception.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/meta.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/predicates.hpp>", private, "<boost/geometry/index/predicates.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/adaptors.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/linear/linear.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/node/node.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/options.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/pack_create.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/quadratic/quadratic.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/query_iterators.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/rstar/rstar.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/utilities/view.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/children_box.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/copy.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/count.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/destroy.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/distance_query.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/insert.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/remove.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/rtree/visitors/spatial_query.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/serialization.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/translator.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/tuples.hpp>", private, "<boost/geometry/index/predicates.hpp>", public ] },
+ { include: ["<boost/geometry/index/detail/utilities.hpp>", private, "<boost/geometry/index/rtree.hpp>", public ] },
+ { include: ["<boost/geometry/io/wkt/detail/prefix.hpp>", private, "<boost/geometry/io/wkt/read.hpp>", public ] },
+ { include: ["<boost/geometry/io/wkt/detail/prefix.hpp>", private, "<boost/geometry/io/wkt/write.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/for_each_range.hpp>", private, "<boost/geometry/multi/multi.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/modify.hpp>", private, "<boost/geometry/multi/algorithms/correct.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/modify.hpp>", private, "<boost/geometry/multi/algorithms/reverse.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/modify_with_predicate.hpp>", private, "<boost/geometry/multi/multi.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/multi_sum.hpp>", private, "<boost/geometry/multi/algorithms/area.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/multi_sum.hpp>", private, "<boost/geometry/multi/algorithms/length.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/multi_sum.hpp>", private, "<boost/geometry/multi/algorithms/perimeter.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/multi_sum.hpp>", private, "<boost/geometry/multi/multi.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/overlay/copy_segment_point.hpp>", private, "<boost/geometry/multi/algorithms/intersection.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/overlay/copy_segment_point.hpp>", private, "<boost/geometry/multi/multi.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/overlay/copy_segments.hpp>", private, "<boost/geometry/multi/algorithms/intersection.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/overlay/copy_segments.hpp>", private, "<boost/geometry/multi/multi.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/overlay/get_ring.hpp>", private, "<boost/geometry/multi/algorithms/intersection.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/overlay/get_ring.hpp>", private, "<boost/geometry/multi/multi.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/overlay/get_turns.hpp>", private, "<boost/geometry/multi/algorithms/intersection.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/overlay/get_turns.hpp>", private, "<boost/geometry/multi/multi.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/overlay/select_rings.hpp>", private, "<boost/geometry/multi/algorithms/intersection.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/overlay/select_rings.hpp>", private, "<boost/geometry/multi/multi.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/overlay/self_turn_points.hpp>", private, "<boost/geometry/multi/multi.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/sections/range_by_section.hpp>", private, "<boost/geometry/multi/algorithms/intersection.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/sections/range_by_section.hpp>", private, "<boost/geometry/multi/multi.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/sections/sectionalize.hpp>", private, "<boost/geometry/multi/algorithms/intersection.hpp>", public ] },
+ { include: ["<boost/geometry/multi/algorithms/detail/sections/sectionalize.hpp>", private, "<boost/geometry/multi/multi.hpp>", public ] },
+ { include: ["<boost/geometry/multi/io/wkt/detail/prefix.hpp>", private, "<boost/geometry/multi/io/wkt/read.hpp>", public ] },
+ { include: ["<boost/geometry/multi/io/wkt/detail/prefix.hpp>", private, "<boost/geometry/multi/io/wkt/write.hpp>", public ] },
+ { include: ["<boost/geometry/multi/views/detail/range_type.hpp>", private, "<boost/geometry/multi/multi.hpp>", public ] },
+ { include: ["<boost/geometry/views/detail/points_view.hpp>", private, "<boost/geometry/views/box_view.hpp>", public ] },
+ { include: ["<boost/geometry/views/detail/points_view.hpp>", private, "<boost/geometry/views/segment_view.hpp>", public ] },
+ { include: ["<boost/geometry/views/detail/range_type.hpp>", private, "<boost/geometry/algorithms/convex_hull.hpp>", public ] },
+ { include: ["<boost/geometry/views/detail/range_type.hpp>", private, "<boost/geometry/strategies/agnostic/hull_graham_andrew.hpp>", public ] },
+ { include: ["<boost/graph/detail/adjacency_list.hpp>", private, "<boost/graph/adjacency_list.hpp>", public ] },
+ { include: ["<boost/graph/detail/array_binary_tree.hpp>", private, "<boost/pending/mutable_queue.hpp>", public ] },
+ { include: ["<boost/graph/detail/augment.hpp>", private, "<boost/graph/cycle_canceling.hpp>", public ] },
+ { include: ["<boost/graph/detail/augment.hpp>", private, "<boost/graph/successive_shortest_path_nonnegative_weights.hpp>", public ] },
+ { include: ["<boost/graph/detail/compressed_sparse_row_struct.hpp>", private, "<boost/graph/compressed_sparse_row_graph.hpp>", public ] },
+ { include: ["<boost/graph/detail/d_ary_heap.hpp>", private, "<boost/graph/astar_search.hpp>", public ] },
+ { include: ["<boost/graph/detail/d_ary_heap.hpp>", private, "<boost/graph/core_numbers.hpp>", public ] },
+ { include: ["<boost/graph/detail/d_ary_heap.hpp>", private, "<boost/graph/dijkstra_shortest_paths.hpp>", public ] },
+ { include: ["<boost/graph/detail/d_ary_heap.hpp>", private, "<boost/graph/dijkstra_shortest_paths_no_color_map.hpp>", public ] },
+ { include: ["<boost/graph/detail/d_ary_heap.hpp>", private, "<boost/graph/named_function_params.hpp>", public ] },
+ { include: ["<boost/graph/detail/d_ary_heap.hpp>", private, "<boost/graph/stoer_wagner_min_cut.hpp>", public ] },
+ { include: ["<boost/graph/detail/edge.hpp>", private, "<boost/graph/adjacency_list.hpp>", public ] },
+ { include: ["<boost/graph/detail/edge.hpp>", private, "<boost/graph/adjacency_matrix.hpp>", public ] },
+ { include: ["<boost/graph/detail/edge.hpp>", private, "<boost/property_map/parallel/distributed_property_map.hpp>", public ] },
+ { include: ["<boost/graph/detail/geodesic.hpp>", private, "<boost/graph/closeness_centrality.hpp>", public ] },
+ { include: ["<boost/graph/detail/geodesic.hpp>", private, "<boost/graph/eccentricity.hpp>", public ] },
+ { include: ["<boost/graph/detail/geodesic.hpp>", private, "<boost/graph/geodesic_distance.hpp>", public ] },
+ { include: ["<boost/graph/detail/incremental_components.hpp>", private, "<boost/graph/incremental_components.hpp>", public ] },
+ { include: ["<boost/graph/detail/indexed_properties.hpp>", private, "<boost/graph/compressed_sparse_row_graph.hpp>", public ] },
+ { include: ["<boost/graph/detail/index.hpp>", private, "<boost/graph/property_maps/container_property_map.hpp>", public ] },
+ { include: ["<boost/graph/detail/is_distributed_selector.hpp>", private, "<boost/graph/compressed_sparse_row_graph.hpp>", public ] },
+ { include: ["<boost/graph/detail/is_distributed_selector.hpp>", private, "<boost/graph/distributed/selector.hpp>", public ] },
+ { include: ["<boost/graph/detail/labeled_graph_traits.hpp>", private, "<boost/graph/labeled_graph.hpp>", public ] },
+ { include: ["<boost/graph/detail/read_graphviz_new.hpp>", private, "<boost/graph/graphviz.hpp>", public ] },
+ { include: ["<boost/graph/detail/read_graphviz_spirit.hpp>", private, "<boost/graph/graphviz.hpp>", public ] },
+ { include: ["<boost/graph/detail/set_adaptor.hpp>", private, "<boost/graph/filtered_graph.hpp>", public ] },
+ { include: ["<boost/graph/detail/sparse_ordering.hpp>", private, "<boost/graph/cuthill_mckee_ordering.hpp>", public ] },
+ { include: ["<boost/graph/detail/sparse_ordering.hpp>", private, "<boost/graph/king_ordering.hpp>", public ] },
+ { include: ["<boost/graph/distributed/detail/dijkstra_shortest_paths.hpp>", private, "<boost/graph/distributed/betweenness_centrality.hpp>", public ] },
+ { include: ["<boost/graph/distributed/detail/dijkstra_shortest_paths.hpp>", private, "<boost/graph/distributed/crauser_et_al_shortest_paths.hpp>", public ] },
+ { include: ["<boost/graph/distributed/detail/dijkstra_shortest_paths.hpp>", private, "<boost/graph/distributed/delta_stepping_shortest_paths.hpp>", public ] },
+ { include: ["<boost/graph/distributed/detail/dijkstra_shortest_paths.hpp>", private, "<boost/graph/distributed/eager_dijkstra_shortest_paths.hpp>", public ] },
+ { include: ["<boost/graph/distributed/detail/filtered_queue.hpp>", private, "<boost/graph/distributed/breadth_first_search.hpp>", public ] },
+ { include: ["<boost/graph/distributed/detail/filtered_queue.hpp>", private, "<boost/graph/distributed/strong_components.hpp>", public ] },
+ { include: ["<boost/graph/distributed/detail/mpi_process_group.ipp>", private, "<boost/graph/distributed/mpi_process_group.hpp>", public ] },
+ { include: ["<boost/graph/distributed/detail/queue.ipp>", private, "<boost/graph/distributed/queue.hpp>", public ] },
+ { include: ["<boost/graph/distributed/detail/remote_update_set.hpp>", private, "<boost/graph/distributed/crauser_et_al_shortest_paths.hpp>", public ] },
+ { include: ["<boost/graph/distributed/detail/remote_update_set.hpp>", private, "<boost/graph/distributed/eager_dijkstra_shortest_paths.hpp>", public ] },
+ { include: ["<boost/graph/parallel/detail/inplace_all_to_all.hpp>", private, "<boost/graph/parallel/algorithm.hpp>", public ] },
+ { include: ["<boost/graph/parallel/detail/property_holders.hpp>", private, "<boost/graph/distributed/adjacency_list.hpp>", public ] },
+ { include: ["<boost/graph/parallel/detail/property_holders.hpp>", private, "<boost/graph/distributed/named_graph.hpp>", public ] },
+ { include: ["<boost/graph/parallel/detail/untracked_pair.hpp>", private, "<boost/graph/distributed/adjlist/handlers.hpp>", public ] },
+ { include: ["<boost/graph/parallel/detail/untracked_pair.hpp>", private, "<boost/graph/distributed/dehne_gotz_min_spanning_tree.hpp>", public ] },
+ { include: ["<boost/graph/parallel/detail/untracked_pair.hpp>", private, "<boost/property_map/parallel/distributed_property_map.hpp>", public ] },
+ { include: ["<boost/heap/detail/heap_comparison.hpp>", private, "<boost/heap/binomial_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/heap_comparison.hpp>", private, "<boost/heap/d_ary_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/heap_comparison.hpp>", private, "<boost/heap/fibonacci_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/heap_comparison.hpp>", private, "<boost/heap/pairing_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/heap_comparison.hpp>", private, "<boost/heap/priority_queue.hpp>", public ] },
+ { include: ["<boost/heap/detail/heap_comparison.hpp>", private, "<boost/heap/skew_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/heap_node.hpp>", private, "<boost/heap/binomial_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/heap_node.hpp>", private, "<boost/heap/fibonacci_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/heap_node.hpp>", private, "<boost/heap/pairing_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/heap_node.hpp>", private, "<boost/heap/skew_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/mutable_heap.hpp>", private, "<boost/heap/d_ary_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/ordered_adaptor_iterator.hpp>", private, "<boost/heap/d_ary_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/stable_heap.hpp>", private, "<boost/heap/binomial_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/stable_heap.hpp>", private, "<boost/heap/d_ary_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/stable_heap.hpp>", private, "<boost/heap/fibonacci_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/stable_heap.hpp>", private, "<boost/heap/pairing_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/stable_heap.hpp>", private, "<boost/heap/priority_queue.hpp>", public ] },
+ { include: ["<boost/heap/detail/stable_heap.hpp>", private, "<boost/heap/skew_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/tree_iterator.hpp>", private, "<boost/heap/binomial_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/tree_iterator.hpp>", private, "<boost/heap/fibonacci_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/tree_iterator.hpp>", private, "<boost/heap/pairing_heap.hpp>", public ] },
+ { include: ["<boost/heap/detail/tree_iterator.hpp>", private, "<boost/heap/skew_heap.hpp>", public ] },
+ { include: ["<boost/icl/detail/boost_config.hpp>", private, "<boost/icl/gregorian.hpp>", public ] },
+ { include: ["<boost/icl/detail/boost_config.hpp>", private, "<boost/icl/impl_config.hpp>", public ] },
+ { include: ["<boost/icl/detail/boost_config.hpp>", private, "<boost/icl/ptime.hpp>", public ] },
+ { include: ["<boost/icl/detail/concept_check.hpp>", private, "<boost/icl/continuous_interval.hpp>", public ] },
+ { include: ["<boost/icl/detail/concept_check.hpp>", private, "<boost/icl/discrete_interval.hpp>", public ] },
+ { include: ["<boost/icl/detail/concept_check.hpp>", private, "<boost/icl/map.hpp>", public ] },
+ { include: ["<boost/icl/detail/design_config.hpp>", private, "<boost/icl/concept/interval.hpp>", public ] },
+ { include: ["<boost/icl/detail/design_config.hpp>", private, "<boost/icl/interval_base_map.hpp>", public ] },
+ { include: ["<boost/icl/detail/design_config.hpp>", private, "<boost/icl/interval_bounds.hpp>", public ] },
+ { include: ["<boost/icl/detail/design_config.hpp>", private, "<boost/icl/map.hpp>", public ] },
+ { include: ["<boost/icl/detail/design_config.hpp>", private, "<boost/icl/type_traits/interval_type_default.hpp>", public ] },
+ { include: ["<boost/icl/detail/element_iterator.hpp>", private, "<boost/icl/interval_base_set.hpp>", public ] },
+ { include: ["<boost/icl/detail/exclusive_less_than.hpp>", private, "<boost/icl/interval_base_set.hpp>", public ] },
+ { include: ["<boost/icl/detail/interval_map_algo.hpp>", private, "<boost/icl/concept/interval_associator.hpp>", public ] },
+ { include: ["<boost/icl/detail/interval_map_algo.hpp>", private, "<boost/icl/concept/interval_map.hpp>", public ] },
+ { include: ["<boost/icl/detail/interval_map_algo.hpp>", private, "<boost/icl/interval_base_map.hpp>", public ] },
+ { include: ["<boost/icl/detail/interval_set_algo.hpp>", private, "<boost/icl/concept/interval_associator.hpp>", public ] },
+ { include: ["<boost/icl/detail/interval_set_algo.hpp>", private, "<boost/icl/concept/interval_set.hpp>", public ] },
+ { include: ["<boost/icl/detail/interval_set_algo.hpp>", private, "<boost/icl/interval_base_set.hpp>", public ] },
+ { include: ["<boost/icl/detail/map_algo.hpp>", private, "<boost/icl/associative_element_container.hpp>", public ] },
+ { include: ["<boost/icl/detail/map_algo.hpp>", private, "<boost/icl/concept/element_map.hpp>", public ] },
+ { include: ["<boost/icl/detail/map_algo.hpp>", private, "<boost/icl/concept/interval_associator.hpp>", public ] },
+ { include: ["<boost/icl/detail/notate.hpp>", private, "<boost/icl/interval_base_map.hpp>", public ] },
+ { include: ["<boost/icl/detail/notate.hpp>", private, "<boost/icl/interval_base_set.hpp>", public ] },
+ { include: ["<boost/icl/detail/notate.hpp>", private, "<boost/icl/map.hpp>", public ] },
+ { include: ["<boost/icl/detail/on_absorbtion.hpp>", private, "<boost/icl/concept/element_map.hpp>", public ] },
+ { include: ["<boost/icl/detail/on_absorbtion.hpp>", private, "<boost/icl/interval_base_map.hpp>", public ] },
+ { include: ["<boost/icl/detail/on_absorbtion.hpp>", private, "<boost/icl/map.hpp>", public ] },
+ { include: ["<boost/icl/detail/set_algo.hpp>", private, "<boost/icl/concept/element_set.hpp>", public ] },
+ { include: ["<boost/icl/detail/set_algo.hpp>", private, "<boost/icl/concept/interval_associator.hpp>", public ] },
+ { include: ["<boost/icl/detail/set_algo.hpp>", private, "<boost/icl/concept/interval_map.hpp>", public ] },
+ { include: ["<boost/icl/detail/set_algo.hpp>", private, "<boost/icl/concept/interval_set.hpp>", public ] },
+ { include: ["<boost/icl/detail/std_set.hpp>", private, "<boost/icl/concept/element_set.hpp>", public ] },
+ { include: ["<boost/icl/detail/subset_comparer.hpp>", private, "<boost/icl/concept/element_associator.hpp>", public ] },
+ { include: ["<boost/interprocess/allocators/detail/adaptive_node_pool.hpp>", private, "<boost/interprocess/allocators/adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/allocators/detail/adaptive_node_pool.hpp>", private, "<boost/interprocess/allocators/cached_adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/allocators/detail/adaptive_node_pool.hpp>", private, "<boost/interprocess/allocators/private_adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/allocators/detail/allocator_common.hpp>", private, "<boost/interprocess/allocators/adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/allocators/detail/allocator_common.hpp>", private, "<boost/interprocess/allocators/allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/allocators/detail/allocator_common.hpp>", private, "<boost/interprocess/allocators/cached_adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/allocators/detail/allocator_common.hpp>", private, "<boost/interprocess/allocators/cached_node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/allocators/detail/allocator_common.hpp>", private, "<boost/interprocess/allocators/node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/allocators/detail/node_pool.hpp>", private, "<boost/interprocess/allocators/cached_node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/allocators/detail/node_pool.hpp>", private, "<boost/interprocess/allocators/node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/allocators/detail/node_pool.hpp>", private, "<boost/interprocess/allocators/private_node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/allocators/detail/node_tools.hpp>", private, "<boost/interprocess/allocators/cached_adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/allocators/detail/node_tools.hpp>", private, "<boost/interprocess/allocators/cached_node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/atomic.hpp>", private, "<boost/interprocess/sync/spin/condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/atomic.hpp>", private, "<boost/interprocess/sync/spin/mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/atomic.hpp>", private, "<boost/interprocess/sync/spin/recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/atomic.hpp>", private, "<boost/interprocess/sync/spin/semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/cast_tags.hpp>", private, "<boost/interprocess/offset_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/cast_tags.hpp>", private, "<boost/interprocess/smart_ptr/shared_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/allocators/adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/allocators/allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/allocators/cached_adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/allocators/cached_node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/allocators/node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/allocators/private_adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/allocators/private_node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/anonymous_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/containers/allocation_type.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/containers/containers_fwd.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/containers/deque.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/containers/flat_map.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/containers/flat_set.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/containers/list.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/containers/map.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/containers/pair.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/containers/set.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/containers/slist.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/containers/stable_vector.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/containers/string.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/containers/vector.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/containers/version_type.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/creation_tags.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/errors.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/exceptions.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/file_mapping.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/indexes/flat_map_index.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/indexes/iset_index.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/indexes/iunordered_set_index.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/indexes/map_index.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/indexes/null_index.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/indexes/unordered_map_index.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/interprocess_fwd.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/ipc/message_queue.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/managed_external_buffer.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/managed_heap_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/managed_mapped_file.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/managed_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/managed_windows_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/managed_xsi_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/mapped_region.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/mem_algo/rbtree_best_fit.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/mem_algo/simple_seq_fit.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/offset_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/permissions.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/segment_manager.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/shared_memory_object.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/smart_ptr/deleter.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/smart_ptr/enable_shared_from_this.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/smart_ptr/intrusive_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/smart_ptr/scoped_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/smart_ptr/shared_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/smart_ptr/unique_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/smart_ptr/weak_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/streams/bufferstream.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/streams/vectorstream.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/file_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/interprocess_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/interprocess_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/interprocess_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/interprocess_recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/interprocess_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/interprocess_sharable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/interprocess_upgradable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/lock_options.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/mutex_family.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/named_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/named_recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/named_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/named_sharable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/named_upgradable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/null_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/posix/condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/posix/mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/posix/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/posix/named_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/posix/pthread_helpers.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/posix/recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/posix/semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/scoped_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/sharable_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/shm/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/shm/named_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/shm/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/shm/named_recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/shm/named_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/shm/named_upgradable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/spin/condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/spin/mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/spin/recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/spin/semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/spin/wait.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/upgradable_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/windows/condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/windows/mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/windows/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/windows/named_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/windows/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/windows/named_recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/windows/named_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/windows/named_sync.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/windows/recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/windows/semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/windows/sync_utils.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/windows/winapi_mutex_wrapper.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/windows/winapi_semaphore_wrapper.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/sync/xsi/xsi_named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/windows_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/xsi_key.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_begin.hpp>", private, "<boost/interprocess/xsi_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/allocators/adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/allocators/allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/allocators/cached_adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/allocators/cached_node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/allocators/node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/allocators/private_adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/allocators/private_node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/anonymous_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/containers/allocation_type.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/containers/containers_fwd.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/containers/deque.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/containers/flat_map.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/containers/flat_set.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/containers/list.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/containers/map.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/containers/pair.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/containers/set.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/containers/slist.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/containers/stable_vector.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/containers/string.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/containers/vector.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/containers/version_type.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/creation_tags.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/errors.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/exceptions.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/file_mapping.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/indexes/flat_map_index.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/indexes/iset_index.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/indexes/iunordered_set_index.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/indexes/map_index.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/indexes/null_index.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/indexes/unordered_map_index.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/interprocess_fwd.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/ipc/message_queue.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/managed_external_buffer.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/managed_heap_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/managed_mapped_file.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/managed_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/managed_windows_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/managed_xsi_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/mapped_region.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/mem_algo/rbtree_best_fit.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/mem_algo/simple_seq_fit.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/offset_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/permissions.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/segment_manager.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/shared_memory_object.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/smart_ptr/deleter.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/smart_ptr/enable_shared_from_this.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/smart_ptr/intrusive_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/smart_ptr/scoped_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/smart_ptr/shared_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/smart_ptr/unique_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/smart_ptr/weak_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/streams/bufferstream.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/streams/vectorstream.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/file_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/interprocess_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/interprocess_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/interprocess_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/interprocess_recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/interprocess_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/interprocess_sharable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/interprocess_upgradable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/lock_options.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/mutex_family.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/named_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/named_recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/named_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/named_sharable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/named_upgradable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/null_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/posix/condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/posix/mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/posix/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/posix/named_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/posix/pthread_helpers.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/posix/recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/posix/semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/scoped_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/sharable_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/shm/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/shm/named_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/shm/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/shm/named_recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/shm/named_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/shm/named_upgradable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/spin/condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/spin/mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/spin/recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/spin/semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/spin/wait.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/upgradable_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/windows/condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/windows/mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/windows/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/windows/named_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/windows/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/windows/named_recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/windows/named_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/windows/named_sync.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/windows/recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/windows/semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/windows/sync_utils.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/windows/winapi_mutex_wrapper.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/windows/winapi_semaphore_wrapper.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/sync/xsi/xsi_named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/windows_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/xsi_key.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_end.hpp>", private, "<boost/interprocess/xsi_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_external_begin.hpp>", private, "<boost/interprocess/sync/windows/sync_utils.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/config_external_end.hpp>", private, "<boost/interprocess/sync/windows/sync_utils.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/file_wrapper.hpp>", private, "<boost/interprocess/managed_mapped_file.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/intermodule_singleton.hpp>", private, "<boost/flyweight/intermodule_holder.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/interprocess_tester.hpp>", private, "<boost/interprocess/sync/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/interprocess_tester.hpp>", private, "<boost/interprocess/sync/named_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/interprocess_tester.hpp>", private, "<boost/interprocess/sync/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/interprocess_tester.hpp>", private, "<boost/interprocess/sync/named_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/interprocess_tester.hpp>", private, "<boost/interprocess/sync/posix/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/interprocess_tester.hpp>", private, "<boost/interprocess/sync/shm/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/interprocess_tester.hpp>", private, "<boost/interprocess/sync/shm/named_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/interprocess_tester.hpp>", private, "<boost/interprocess/sync/windows/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/interprocess_tester.hpp>", private, "<boost/interprocess/sync/windows/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/interprocess_tester.hpp>", private, "<boost/interprocess/sync/windows/named_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_memory_impl.hpp>", private, "<boost/interprocess/managed_external_buffer.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_memory_impl.hpp>", private, "<boost/interprocess/managed_heap_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_memory_impl.hpp>", private, "<boost/interprocess/managed_mapped_file.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_memory_impl.hpp>", private, "<boost/interprocess/managed_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_memory_impl.hpp>", private, "<boost/interprocess/managed_windows_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_memory_impl.hpp>", private, "<boost/interprocess/managed_xsi_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private, "<boost/interprocess/ipc/message_queue.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private, "<boost/interprocess/managed_mapped_file.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private, "<boost/interprocess/managed_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private, "<boost/interprocess/managed_windows_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private, "<boost/interprocess/managed_xsi_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private, "<boost/interprocess/sync/named_sharable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private, "<boost/interprocess/sync/named_upgradable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private, "<boost/interprocess/sync/shm/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private, "<boost/interprocess/sync/shm/named_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private, "<boost/interprocess/sync/shm/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private, "<boost/interprocess/sync/shm/named_recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private, "<boost/interprocess/sync/shm/named_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/managed_open_or_create_impl.hpp>", private, "<boost/interprocess/sync/shm/named_upgradable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/math_functions.hpp>", private, "<boost/interprocess/mem_algo/rbtree_best_fit.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/min_max.hpp>", private, "<boost/interprocess/mem_algo/rbtree_best_fit.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/mpl.hpp>", private, "<boost/interprocess/allocators/adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/mpl.hpp>", private, "<boost/interprocess/offset_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/mpl.hpp>", private, "<boost/interprocess/segment_manager.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/mpl.hpp>", private, "<boost/interprocess/smart_ptr/shared_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/mpl.hpp>", private, "<boost/interprocess/smart_ptr/unique_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/mpl.hpp>", private, "<boost/interprocess/sync/scoped_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/mpl.hpp>", private, "<boost/interprocess/sync/sharable_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/mpl.hpp>", private, "<boost/interprocess/sync/shm/named_creation_functor.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/mpl.hpp>", private, "<boost/interprocess/sync/upgradable_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/named_proxy.hpp>", private, "<boost/interprocess/segment_manager.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/file_mapping.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/mapped_region.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/shared_memory_object.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/sync/file_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/sync/posix/semaphore_wrapper.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/sync/xsi/xsi_named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/windows_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/xsi_key.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_file_functions.hpp>", private, "<boost/interprocess/xsi_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_thread_functions.hpp>", private, "<boost/interprocess/sync/file_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_thread_functions.hpp>", private, "<boost/interprocess/sync/posix/mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_thread_functions.hpp>", private, "<boost/interprocess/sync/posix/recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_thread_functions.hpp>", private, "<boost/interprocess/sync/posix/semaphore_wrapper.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_thread_functions.hpp>", private, "<boost/interprocess/sync/spin/condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_thread_functions.hpp>", private, "<boost/interprocess/sync/spin/mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_thread_functions.hpp>", private, "<boost/interprocess/sync/spin/recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_thread_functions.hpp>", private, "<boost/interprocess/sync/spin/semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/os_thread_functions.hpp>", private, "<boost/interprocess/sync/spin/wait.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/pointer_type.hpp>", private, "<boost/interprocess/smart_ptr/scoped_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/pointer_type.hpp>", private, "<boost/interprocess/smart_ptr/unique_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/file_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/interprocess_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/interprocess_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/interprocess_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/interprocess_recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/interprocess_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/interprocess_sharable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/interprocess_upgradable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/named_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/named_recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/named_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/named_sharable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/named_upgradable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/posix/condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/posix/mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/posix/ptime_to_timespec.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/posix/recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/posix/semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/posix/semaphore_wrapper.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/scoped_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/sharable_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/shm/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/shm/named_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/shm/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/shm/named_recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/shm/named_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/shm/named_upgradable_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/spin/condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/spin/interprocess_barrier.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/spin/mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/spin/recursive_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/spin/semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/upgradable_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/windows/condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/windows/mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/windows/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/windows/named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/windows/named_semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/windows/semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/windows/winapi_mutex_wrapper.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/posix_time_types_wrk.hpp>", private, "<boost/interprocess/sync/windows/winapi_semaphore_wrapper.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/segment_manager_helper.hpp>", private, "<boost/interprocess/segment_manager.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/tmp_dir_helpers.hpp>", private, "<boost/interprocess/shared_memory_object.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/tmp_dir_helpers.hpp>", private, "<boost/interprocess/sync/posix/semaphore_wrapper.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/tmp_dir_helpers.hpp>", private, "<boost/interprocess/sync/windows/named_sync.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/transform_iterator.hpp>", private, "<boost/interprocess/segment_manager.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/allocators/adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/allocators/allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/allocators/node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/ipc/message_queue.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/mem_algo/rbtree_best_fit.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/segment_manager.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/smart_ptr/shared_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/smart_ptr/unique_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/sync/scoped_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/sync/sharable_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/sync/shm/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/sync/shm/named_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/sync/shm/named_creation_functor.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/type_traits.hpp>", private, "<boost/interprocess/sync/upgradable_lock.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/allocators/adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/allocators/allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/allocators/cached_adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/allocators/cached_node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/allocators/node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/allocators/private_adaptive_pool.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/allocators/private_node_allocator.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/file_mapping.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/indexes/iset_index.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/indexes/iunordered_set_index.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/indexes/unordered_map_index.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/ipc/message_queue.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/mapped_region.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/mem_algo/rbtree_best_fit.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/offset_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/segment_manager.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/smart_ptr/deleter.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/smart_ptr/intrusive_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/smart_ptr/scoped_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/smart_ptr/shared_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/smart_ptr/unique_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/sync/xsi/xsi_named_mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/windows_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/xsi_key.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/utilities.hpp>", private, "<boost/interprocess/xsi_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/win32_api.hpp>", private, "<boost/interprocess/errors.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/win32_api.hpp>", private, "<boost/interprocess/mapped_region.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/win32_api.hpp>", private, "<boost/interprocess/permissions.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/win32_api.hpp>", private, "<boost/interprocess/sync/windows/mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/win32_api.hpp>", private, "<boost/interprocess/sync/windows/semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/win32_api.hpp>", private, "<boost/interprocess/sync/windows/sync_utils.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/win32_api.hpp>", private, "<boost/interprocess/sync/windows/winapi_mutex_wrapper.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/win32_api.hpp>", private, "<boost/interprocess/sync/windows/winapi_semaphore_wrapper.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/win32_api.hpp>", private, "<boost/interprocess/windows_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/windows_intermodule_singleton.hpp>", private, "<boost/interprocess/mapped_region.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/windows_intermodule_singleton.hpp>", private, "<boost/interprocess/sync/windows/mutex.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/windows_intermodule_singleton.hpp>", private, "<boost/interprocess/sync/windows/semaphore.hpp>", public ] },
+ { include: ["<boost/interprocess/detail/xsi_shared_memory_file_wrapper.hpp>", private, "<boost/interprocess/managed_xsi_shared_memory.hpp>", public ] },
+ { include: ["<boost/interprocess/mem_algo/detail/mem_algo_common.hpp>", private, "<boost/interprocess/mem_algo/rbtree_best_fit.hpp>", public ] },
+ { include: ["<boost/interprocess/mem_algo/detail/simple_seq_fit_impl.hpp>", private, "<boost/interprocess/mem_algo/simple_seq_fit.hpp>", public ] },
+ { include: ["<boost/interprocess/smart_ptr/detail/shared_count.hpp>", private, "<boost/interprocess/smart_ptr/shared_ptr.hpp>", public ] },
+ { include: ["<boost/interprocess/sync/detail/condition_algorithm_8a.hpp>", private, "<boost/interprocess/sync/windows/condition.hpp>", public ] },
+ { include: ["<boost/interprocess/sync/detail/condition_algorithm_8a.hpp>", private, "<boost/interprocess/sync/windows/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/sync/detail/condition_any_algorithm.hpp>", private, "<boost/interprocess/sync/interprocess_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/sync/detail/condition_any_algorithm.hpp>", private, "<boost/interprocess/sync/shm/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/sync/detail/condition_any_algorithm.hpp>", private, "<boost/interprocess/sync/shm/named_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/sync/detail/locks.hpp>", private, "<boost/interprocess/sync/interprocess_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/sync/detail/locks.hpp>", private, "<boost/interprocess/sync/named_condition_any.hpp>", public ] },
+ { include: ["<boost/interprocess/sync/detail/locks.hpp>", private, "<boost/interprocess/sync/named_condition.hpp>", public ] },
+ { include: ["<boost/interprocess/sync/detail/locks.hpp>", private, "<boost/interprocess/sync/shm/named_condition.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/any_node_and_algorithms.hpp>", private, "<boost/intrusive/any_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/avltree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/avltree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/bstree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/bstree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/circular_slist_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/hashtable.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/list.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/pointer_plus_bits.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/rbtree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/rbtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/sgtree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/sgtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/slist.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/splaytree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/splaytree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/treap_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/assert.hpp>", private, "<boost/intrusive/treap.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/avltree_node.hpp>", private, "<boost/intrusive/avl_set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/avltree_node.hpp>", private, "<boost/intrusive/avltree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/clear_on_destructor_base.hpp>", private, "<boost/intrusive/bstree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/clear_on_destructor_base.hpp>", private, "<boost/intrusive/hashtable.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/clear_on_destructor_base.hpp>", private, "<boost/intrusive/list.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/clear_on_destructor_base.hpp>", private, "<boost/intrusive/slist.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/common_slist_algorithms.hpp>", private, "<boost/intrusive/circular_slist_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/common_slist_algorithms.hpp>", private, "<boost/intrusive/linear_slist_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/any_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/avl_set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/avl_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/avltree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/avltree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/bs_set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/bs_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/bstree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/bstree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/circular_list_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/circular_slist_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/hashtable.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/linear_slist_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/list_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/list.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/options.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/parent_from_member.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/pointer_traits.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/priority_compare.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/rbtree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/rbtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/sg_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/sgtree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/sgtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/slist_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/slist.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/splay_set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/splay_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/splaytree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/splaytree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/treap_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/treap.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/treap_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/unordered_set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_begin.hpp>", private, "<boost/intrusive/unordered_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/any_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/avl_set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/avl_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/avltree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/avltree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/bs_set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/bs_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/bstree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/bstree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/circular_list_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/circular_slist_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/hashtable.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/linear_slist_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/list_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/list.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/options.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/parent_from_member.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/pointer_traits.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/priority_compare.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/rbtree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/rbtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/sg_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/sgtree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/sgtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/slist_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/slist.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/splay_set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/splay_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/splaytree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/splaytree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/treap_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/treap.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/treap_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/unordered_set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/config_end.hpp>", private, "<boost/intrusive/unordered_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/ebo_functor_holder.hpp>", private, "<boost/intrusive/avltree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/ebo_functor_holder.hpp>", private, "<boost/intrusive/bstree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/ebo_functor_holder.hpp>", private, "<boost/intrusive/hashtable.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/ebo_functor_holder.hpp>", private, "<boost/intrusive/rbtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/ebo_functor_holder.hpp>", private, "<boost/intrusive/sgtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/ebo_functor_holder.hpp>", private, "<boost/intrusive/splaytree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/ebo_functor_holder.hpp>", private, "<boost/intrusive/treap.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/function_detector.hpp>", private, "<boost/intrusive/bstree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/function_detector.hpp>", private, "<boost/intrusive/rbtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/function_detector.hpp>", private, "<boost/intrusive/splaytree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/generic_hook.hpp>", private, "<boost/intrusive/any_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/generic_hook.hpp>", private, "<boost/intrusive/avl_set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/generic_hook.hpp>", private, "<boost/intrusive/bs_set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/generic_hook.hpp>", private, "<boost/intrusive/list_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/generic_hook.hpp>", private, "<boost/intrusive/set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/generic_hook.hpp>", private, "<boost/intrusive/slist_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/generic_hook.hpp>", private, "<boost/intrusive/unordered_set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/hashtable_node.hpp>", private, "<boost/intrusive/hashtable.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/list_node.hpp>", private, "<boost/intrusive/list_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/memory_util.hpp>", private, "<boost/container/allocator_traits.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/memory_util.hpp>", private, "<boost/intrusive/pointer_traits.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/avl_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/avltree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/bs_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/bstree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/hashtable.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/list.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/options.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/pointer_plus_bits.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/rbtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/sg_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/sgtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/splay_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/splaytree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/treap.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/mpl.hpp>", private, "<boost/intrusive/treap_set.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/parent_from_member.hpp>", private, "<boost/intrusive/member_value_traits.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/parent_from_member.hpp>", private, "<boost/intrusive/parent_from_member.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/rbtree_node.hpp>", private, "<boost/intrusive/rbtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/rbtree_node.hpp>", private, "<boost/intrusive/set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/slist_node.hpp>", private, "<boost/intrusive/slist_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/transform_iterator.hpp>", private, "<boost/intrusive/hashtable.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/tree_node.hpp>", private, "<boost/intrusive/avltree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/tree_node.hpp>", private, "<boost/intrusive/bs_set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/tree_node.hpp>", private, "<boost/intrusive/bstree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/tree_node.hpp>", private, "<boost/intrusive/rbtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/tree_node.hpp>", private, "<boost/intrusive/sgtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/tree_node.hpp>", private, "<boost/intrusive/splaytree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/tree_node.hpp>", private, "<boost/intrusive/treap.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/any_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/avl_set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/avltree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/avltree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/bs_set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/bstree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/bstree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/circular_list_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/circular_slist_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/hashtable.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/linear_slist_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/list_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/list.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/options.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/rbtree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/rbtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/set_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/sgtree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/sgtree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/slist_hook.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/slist.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/splaytree_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/splaytree.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/treap_algorithms.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/treap.hpp>", public ] },
+ { include: ["<boost/intrusive/detail/utilities.hpp>", private, "<boost/intrusive/unordered_set_hook.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/access_control.hpp>", private, "<boost/iostreams/chain.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/access_control.hpp>", private, "<boost/iostreams/filtering_streambuf.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/access_control.hpp>", private, "<boost/iostreams/filtering_stream.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/adapter/concept_adapter.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/adapter/device_adapter.hpp>", private, "<boost/iostreams/tee.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/adapter/direct_adapter.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/adapter/direct_adapter.hpp>", private, "<boost/iostreams/compose.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/adapter/filter_adapter.hpp>", private, "<boost/iostreams/tee.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/adapter/non_blocking_adapter.hpp>", private, "<boost/iostreams/close.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/adapter/non_blocking_adapter.hpp>", private, "<boost/iostreams/copy.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/adapter/non_blocking_adapter.hpp>", private, "<boost/iostreams/filter/gzip.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/adapter/range_adapter.hpp>", private, "<boost/iostreams/filter/gzip.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/bool_trait_def.hpp>", private, "<boost/iostreams/filter/test.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/bool_trait_def.hpp>", private, "<boost/iostreams/traits.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/broken_overload_resolution/stream_buffer.hpp>", private, "<boost/iostreams/stream_buffer.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/broken_overload_resolution/stream.hpp>", private, "<boost/iostreams/stream.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/buffer.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/buffer.hpp>", private, "<boost/iostreams/copy.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/buffer.hpp>", private, "<boost/iostreams/filter/symmetric.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/buffer.hpp>", private, "<boost/iostreams/invert.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/call_traits.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/call_traits.hpp>", private, "<boost/iostreams/compose.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/call_traits.hpp>", private, "<boost/iostreams/tee.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/chain.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/char_traits.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/filter/aggregate.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/filter/gzip.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/filtering_streambuf.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/filtering_stream.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/filter/newline.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/filter/stdio.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/filter/symmetric.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/read.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/stream_buffer.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/stream.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/char_traits.hpp>", private, "<boost/iostreams/write.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/codecvt_helper.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/codecvt_holder.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/auto_link.hpp>", private, "<boost/iostreams/device/file_descriptor.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/auto_link.hpp>", private, "<boost/iostreams/device/mapped_file.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/auto_link.hpp>", private, "<boost/iostreams/filter/bzip2.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/auto_link.hpp>", private, "<boost/iostreams/filter/zlib.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/bzip2.hpp>", private, "<boost/iostreams/filter/bzip2.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/codecvt.hpp>", private, "<boost/iostreams/positioning.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/checked_operations.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/close.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/combine.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/compose.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/device/file.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/filter/aggregate.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/filter/counter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/filtering_stream.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/filter/line.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/filter/newline.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/filter/symmetric.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/flush.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/imbue.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/input_sequence.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/invert.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/optimal_buffer_size.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/output_sequence.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/positioning.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/read.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/seek.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/stream_buffer.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/traits.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/disable_warnings.hpp>", private, "<boost/iostreams/write.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/dyn_link.hpp>", private, "<boost/iostreams/device/file_descriptor.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/dyn_link.hpp>", private, "<boost/iostreams/device/mapped_file.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/dyn_link.hpp>", private, "<boost/iostreams/filter/bzip2.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/dyn_link.hpp>", private, "<boost/iostreams/filter/zlib.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/checked_operations.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/close.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/combine.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/compose.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/device/file.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/filter/aggregate.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/filter/counter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/filtering_stream.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/filter/line.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/filter/newline.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/filter/symmetric.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/flush.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/imbue.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/input_sequence.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/invert.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/optimal_buffer_size.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/output_sequence.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/positioning.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/read.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/seek.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/stream_buffer.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/traits.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/enable_warnings.hpp>", private, "<boost/iostreams/write.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/fpos.hpp>", private, "<boost/iostreams/positioning.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/limits.hpp>", private, "<boost/iostreams/filter/symmetric.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/overload_resolution.hpp>", private, "<boost/iostreams/stream_buffer.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/overload_resolution.hpp>", private, "<boost/iostreams/stream.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/unreachable_return.hpp>", private, "<boost/iostreams/checked_operations.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/char_traits.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/device/file.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/device/mapped_file.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/filter/bzip2.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/filter/stdio.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/filter/zlib.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/wide_streams.hpp>", private, "<boost/iostreams/traits.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/windows_posix.hpp>", private, "<boost/iostreams/device/file_descriptor.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/config/zlib.hpp>", private, "<boost/iostreams/filter/zlib.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/counted_array.hpp>", private, "<boost/iostreams/invert.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/default_arg.hpp>", private, "<boost/iostreams/concepts.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/dispatch.hpp>", private, "<boost/iostreams/checked_operations.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/dispatch.hpp>", private, "<boost/iostreams/flush.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/dispatch.hpp>", private, "<boost/iostreams/imbue.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/dispatch.hpp>", private, "<boost/iostreams/optimal_buffer_size.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/dispatch.hpp>", private, "<boost/iostreams/read.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/dispatch.hpp>", private, "<boost/iostreams/seek.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/dispatch.hpp>", private, "<boost/iostreams/write.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/double_object.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/enable_if_stream.hpp>", private, "<boost/iostreams/compose.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/enable_if_stream.hpp>", private, "<boost/iostreams/copy.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/error.hpp>", private, "<boost/iostreams/checked_operations.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/error.hpp>", private, "<boost/iostreams/filter/gzip.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/execute.hpp>", private, "<boost/iostreams/chain.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/execute.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/execute.hpp>", private, "<boost/iostreams/compose.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/execute.hpp>", private, "<boost/iostreams/copy.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/execute.hpp>", private, "<boost/iostreams/invert.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/execute.hpp>", private, "<boost/iostreams/tee.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/file_handle.hpp>", private, "<boost/iostreams/device/file_descriptor.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/forward.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/forward.hpp>", private, "<boost/iostreams/stream_buffer.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/forward.hpp>", private, "<boost/iostreams/stream.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/fstream.hpp>", private, "<boost/iostreams/device/file.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/functional.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/functional.hpp>", private, "<boost/iostreams/compose.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/functional.hpp>", private, "<boost/iostreams/copy.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/functional.hpp>", private, "<boost/iostreams/invert.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/functional.hpp>", private, "<boost/iostreams/tee.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/close.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/combine.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/concepts.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/constants.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/copy.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/device/back_inserter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/device/file_descriptor.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/device/file.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/device/mapped_file.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/device/null.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/filter/aggregate.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/filter/bzip2.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/filter/gzip.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/filter/line.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/filter/newline.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/filter/stdio.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/filter/test.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/filter/zlib.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/positioning.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/read.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/seek.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/skip.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/stream_buffer.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/ios.hpp>", private, "<boost/iostreams/write.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/iostream.hpp>", private, "<boost/iostreams/filtering_stream.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/iostream.hpp>", private, "<boost/iostreams/stream.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/is_iterator_range.hpp>", private, "<boost/iostreams/traits.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/optional.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/path.hpp>", private, "<boost/iostreams/device/file_descriptor.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/path.hpp>", private, "<boost/iostreams/device/mapped_file.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/push.hpp>", private, "<boost/iostreams/chain.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/push.hpp>", private, "<boost/iostreams/filtering_streambuf.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/push.hpp>", private, "<boost/iostreams/filtering_stream.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/resolve.hpp>", private, "<boost/iostreams/copy.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/restrict_impl.hpp>", private, "<boost/iostreams/restrict.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/restrict_impl.hpp>", private, "<boost/iostreams/slice.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/select_by_size.hpp>", private, "<boost/iostreams/traits.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/select.hpp>", private, "<boost/iostreams/close.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/select.hpp>", private, "<boost/iostreams/code_converter.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/select.hpp>", private, "<boost/iostreams/filtering_stream.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/select.hpp>", private, "<boost/iostreams/stream.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/select.hpp>", private, "<boost/iostreams/traits.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/streambuf/chainbuf.hpp>", private, "<boost/iostreams/filtering_streambuf.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/streambuf/direct_streambuf.hpp>", private, "<boost/iostreams/stream_buffer.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/streambuf.hpp>", private, "<boost/iostreams/chain.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/streambuf.hpp>", private, "<boost/iostreams/filtering_streambuf.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/streambuf.hpp>", private, "<boost/iostreams/filtering_stream.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/streambuf.hpp>", private, "<boost/iostreams/flush.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/streambuf.hpp>", private, "<boost/iostreams/imbue.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/streambuf.hpp>", private, "<boost/iostreams/read.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/streambuf.hpp>", private, "<boost/iostreams/seek.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/streambuf.hpp>", private, "<boost/iostreams/write.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/streambuf/indirect_streambuf.hpp>", private, "<boost/iostreams/stream_buffer.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/template_params.hpp>", private, "<boost/iostreams/filter/symmetric.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/template_params.hpp>", private, "<boost/iostreams/pipeline.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/vc6/close.hpp>", private, "<boost/iostreams/close.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/vc6/read.hpp>", private, "<boost/iostreams/read.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/vc6/write.hpp>", private, "<boost/iostreams/write.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/wrap_unwrap.hpp>", private, "<boost/iostreams/chain.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/wrap_unwrap.hpp>", private, "<boost/iostreams/close.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/wrap_unwrap.hpp>", private, "<boost/iostreams/combine.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/wrap_unwrap.hpp>", private, "<boost/iostreams/copy.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/wrap_unwrap.hpp>", private, "<boost/iostreams/flush.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/wrap_unwrap.hpp>", private, "<boost/iostreams/imbue.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/wrap_unwrap.hpp>", private, "<boost/iostreams/input_sequence.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/wrap_unwrap.hpp>", private, "<boost/iostreams/optimal_buffer_size.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/wrap_unwrap.hpp>", private, "<boost/iostreams/output_sequence.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/wrap_unwrap.hpp>", private, "<boost/iostreams/read.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/wrap_unwrap.hpp>", private, "<boost/iostreams/seek.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/wrap_unwrap.hpp>", private, "<boost/iostreams/traits.hpp>", public ] },
+ { include: ["<boost/iostreams/detail/wrap_unwrap.hpp>", private, "<boost/iostreams/write.hpp>", public ] },
+ { include: ["<boost/iterator/detail/any_conversion_eater.hpp>", private, "<boost/iterator/is_lvalue_iterator.hpp>", public ] },
+ { include: ["<boost/iterator/detail/any_conversion_eater.hpp>", private, "<boost/iterator/is_readable_iterator.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_def.hpp>", private, "<boost/iterator/indirect_iterator.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_def.hpp>", private, "<boost/iterator/interoperable.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_def.hpp>", private, "<boost/iterator/is_lvalue_iterator.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_def.hpp>", private, "<boost/iterator/is_readable_iterator.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_def.hpp>", private, "<boost/iterator/iterator_adaptor.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_def.hpp>", private, "<boost/iterator/iterator_categories.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_def.hpp>", private, "<boost/iterator/iterator_facade.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_def.hpp>", private, "<boost/iterator/new_iterator_tests.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_def.hpp>", private, "<boost/iterator/transform_iterator.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_def.hpp>", private, "<boost/python/object_operators.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_undef.hpp>", private, "<boost/iterator/indirect_iterator.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_undef.hpp>", private, "<boost/iterator/interoperable.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_undef.hpp>", private, "<boost/iterator/is_lvalue_iterator.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_undef.hpp>", private, "<boost/iterator/is_readable_iterator.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_undef.hpp>", private, "<boost/iterator/iterator_adaptor.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_undef.hpp>", private, "<boost/iterator/iterator_categories.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_undef.hpp>", private, "<boost/iterator/iterator_facade.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_undef.hpp>", private, "<boost/iterator/new_iterator_tests.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_undef.hpp>", private, "<boost/iterator/transform_iterator.hpp>", public ] },
+ { include: ["<boost/iterator/detail/config_undef.hpp>", private, "<boost/python/object_operators.hpp>", public ] },
+ { include: ["<boost/iterator/detail/enable_if.hpp>", private, "<boost/iterator/iterator_adaptor.hpp>", public ] },
+ { include: ["<boost/iterator/detail/enable_if.hpp>", private, "<boost/iterator/iterator_facade.hpp>", public ] },
+ { include: ["<boost/iterator/detail/enable_if.hpp>", private, "<boost/iterator/transform_iterator.hpp>", public ] },
+ { include: ["<boost/iterator/detail/enable_if.hpp>", private, "<boost/python/object_operators.hpp>", public ] },
+ { include: ["<boost/iterator/detail/facade_iterator_category.hpp>", private, "<boost/iterator/iterator_archetypes.hpp>", public ] },
+ { include: ["<boost/iterator/detail/facade_iterator_category.hpp>", private, "<boost/iterator/iterator_facade.hpp>", public ] },
+ { include: ["<boost/iterator/detail/minimum_category.hpp>", private, "<boost/iterator/zip_iterator.hpp>", public ] },
+ { include: ["<boost/iterator/detail/minimum_category.hpp>", private, "<boost/token_iterator.hpp>", public ] },
+ { include: ["<boost/lambda/detail/actions.hpp>", private, "<boost/lambda/core.hpp>", public ] },
+ { include: ["<boost/lambda/detail/arity_code.hpp>", private, "<boost/lambda/core.hpp>", public ] },
+ { include: ["<boost/lambda/detail/bind_functions.hpp>", private, "<boost/lambda/bind.hpp>", public ] },
+ { include: ["<boost/lambda/detail/control_constructs_common.hpp>", private, "<boost/lambda/exceptions.hpp>", public ] },
+ { include: ["<boost/lambda/detail/control_constructs_common.hpp>", private, "<boost/lambda/switch.hpp>", public ] },
+ { include: ["<boost/lambda/detail/function_adaptors.hpp>", private, "<boost/lambda/core.hpp>", public ] },
+ { include: ["<boost/lambda/detail/lambda_config.hpp>", private, "<boost/lambda/core.hpp>", public ] },
+ { include: ["<boost/lambda/detail/lambda_functor_base.hpp>", private, "<boost/lambda/core.hpp>", public ] },
+ { include: ["<boost/lambda/detail/lambda_functors.hpp>", private, "<boost/lambda/core.hpp>", public ] },
+ { include: ["<boost/lambda/detail/lambda_fwd.hpp>", private, "<boost/lambda/core.hpp>", public ] },
+ { include: ["<boost/lambda/detail/lambda_traits.hpp>", private, "<boost/lambda/core.hpp>", public ] },
+ { include: ["<boost/lambda/detail/member_ptr.hpp>", private, "<boost/lambda/lambda.hpp>", public ] },
+ { include: ["<boost/lambda/detail/operator_actions.hpp>", private, "<boost/lambda/control_structures.hpp>", public ] },
+ { include: ["<boost/lambda/detail/operator_actions.hpp>", private, "<boost/lambda/if.hpp>", public ] },
+ { include: ["<boost/lambda/detail/operator_actions.hpp>", private, "<boost/lambda/lambda.hpp>", public ] },
+ { include: ["<boost/lambda/detail/operator_lambda_func_base.hpp>", private, "<boost/lambda/lambda.hpp>", public ] },
+ { include: ["<boost/lambda/detail/operator_return_type_traits.hpp>", private, "<boost/lambda/control_structures.hpp>", public ] },
+ { include: ["<boost/lambda/detail/operator_return_type_traits.hpp>", private, "<boost/lambda/if.hpp>", public ] },
+ { include: ["<boost/lambda/detail/operator_return_type_traits.hpp>", private, "<boost/lambda/lambda.hpp>", public ] },
+ { include: ["<boost/lambda/detail/operators.hpp>", private, "<boost/lambda/lambda.hpp>", public ] },
+ { include: ["<boost/lambda/detail/ret.hpp>", private, "<boost/lambda/core.hpp>", public ] },
+ { include: ["<boost/lambda/detail/return_type_traits.hpp>", private, "<boost/lambda/core.hpp>", public ] },
+ { include: ["<boost/lambda/detail/select_functions.hpp>", private, "<boost/lambda/core.hpp>", public ] },
+ { include: ["<boost/lambda/detail/suppress_unused.hpp>", private, "<boost/lambda/casts.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/auto.hpp>", private, "<boost/local_function/aux_/macro/code_/functor.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/bind.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_sign_/sign.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/bind.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_sign_/validate_/defaults.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/bind.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_sign_/validate_/this.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/const_bind.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_sign_/any_bind_type.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/const_bind.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_sign_/sign.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/const_bind.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_sign_/validate_/defaults.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/const_bind.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_sign_/validate_/this.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/const.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_sign_/validate_/this.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/default.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_params.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/default.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_sign_/sign.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/default.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_sign_/validate_/defaults.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/inline.hpp>", private, "<boost/local_function/aux_/macro/name.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/recursive.hpp>", private, "<boost/local_function/aux_/macro/name.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/register.hpp>", private, "<boost/local_function/aux_/macro/code_/functor.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/return.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_/append.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/return.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_sign_/sign.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/this.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_sign_/any_bind_type.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/this.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_sign_/validate_/this.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/thisunderscore.hpp>", private, "<boost/local_function/aux_/macro/code_/functor.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/thisunderscore.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_sign_/sign.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/keyword/thisunderscore.hpp>", private, "<boost/scope_exit.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/line_counter.hpp>", private, "<boost/local_function.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/line_counter.hpp>", private, "<boost/scope_exit.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/void_list.hpp>", private, "<boost/local_function.hpp>", public ] },
+ { include: ["<boost/local_function/detail/preprocessor/void_list.hpp>", private, "<boost/scope_exit.hpp>", public ] },
+ { include: ["<boost/lockfree/detail/atomic.hpp>", private, "<boost/lockfree/queue.hpp>", public ] },
+ { include: ["<boost/lockfree/detail/atomic.hpp>", private, "<boost/lockfree/spsc_queue.hpp>", public ] },
+ { include: ["<boost/lockfree/detail/atomic.hpp>", private, "<boost/lockfree/stack.hpp>", public ] },
+ { include: ["<boost/lockfree/detail/branch_hints.hpp>", private, "<boost/lockfree/spsc_queue.hpp>", public ] },
+ { include: ["<boost/lockfree/detail/copy_payload.hpp>", private, "<boost/lockfree/queue.hpp>", public ] },
+ { include: ["<boost/lockfree/detail/copy_payload.hpp>", private, "<boost/lockfree/stack.hpp>", public ] },
+ { include: ["<boost/lockfree/detail/freelist.hpp>", private, "<boost/lockfree/queue.hpp>", public ] },
+ { include: ["<boost/lockfree/detail/freelist.hpp>", private, "<boost/lockfree/stack.hpp>", public ] },
+ { include: ["<boost/lockfree/detail/parameter.hpp>", private, "<boost/lockfree/queue.hpp>", public ] },
+ { include: ["<boost/lockfree/detail/parameter.hpp>", private, "<boost/lockfree/spsc_queue.hpp>", public ] },
+ { include: ["<boost/lockfree/detail/parameter.hpp>", private, "<boost/lockfree/stack.hpp>", public ] },
+ { include: ["<boost/lockfree/detail/prefix.hpp>", private, "<boost/lockfree/spsc_queue.hpp>", public ] },
+ { include: ["<boost/lockfree/detail/tagged_ptr.hpp>", private, "<boost/lockfree/queue.hpp>", public ] },
+ { include: ["<boost/lockfree/detail/tagged_ptr.hpp>", private, "<boost/lockfree/stack.hpp>", public ] },
+ { include: ["<boost/log/detail/asio_fwd.hpp>", private, "<boost/log/sinks/syslog_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/attachable_sstream_buf.hpp>", private, "<boost/log/sinks/basic_sink_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/attachable_sstream_buf.hpp>", private, "<boost/log/utility/formatting_ostream.hpp>", public ] },
+ { include: ["<boost/log/detail/attribute_get_value_impl.hpp>", private, "<boost/log/attributes/attribute.hpp>", public ] },
+ { include: ["<boost/log/detail/attribute_get_value_impl.hpp>", private, "<boost/log/attributes/attribute_value.hpp>", public ] },
+ { include: ["<boost/log/detail/attribute_predicate.hpp>", private, "<boost/log/expressions/predicates/begins_with.hpp>", public ] },
+ { include: ["<boost/log/detail/attribute_predicate.hpp>", private, "<boost/log/expressions/predicates/contains.hpp>", public ] },
+ { include: ["<boost/log/detail/attribute_predicate.hpp>", private, "<boost/log/expressions/predicates/ends_with.hpp>", public ] },
+ { include: ["<boost/log/detail/attribute_predicate.hpp>", private, "<boost/log/expressions/predicates/is_in_range.hpp>", public ] },
+ { include: ["<boost/log/detail/attribute_predicate.hpp>", private, "<boost/log/expressions/predicates/matches.hpp>", public ] },
+ { include: ["<boost/log/detail/attr_output_impl.hpp>", private, "<boost/log/expressions/attr.hpp>", public ] },
+ { include: ["<boost/log/detail/attr_output_impl.hpp>", private, "<boost/log/expressions/formatters/stream.hpp>", public ] },
+ { include: ["<boost/log/detail/attr_output_terminal.hpp>", private, "<boost/log/expressions/formatters/date_time.hpp>", public ] },
+ { include: ["<boost/log/detail/attr_output_terminal.hpp>", private, "<boost/log/expressions/formatters/named_scope.hpp>", public ] },
+ { include: ["<boost/log/detail/cleanup_scope_guard.hpp>", private, "<boost/log/sinks/basic_sink_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/cleanup_scope_guard.hpp>", private, "<boost/log/sinks/text_multifile_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/code_conversion.hpp>", private, "<boost/log/sinks/basic_sink_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/code_conversion.hpp>", private, "<boost/log/utility/formatting_ostream.hpp>", public ] },
+ { include: ["<boost/log/detail/code_conversion.hpp>", private, "<boost/log/utility/setup/filter_parser.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/attribute_cast.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/attribute.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/attribute_name.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/attribute_set.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/attribute_value.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/attribute_value_impl.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/attribute_value_set.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/clock.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/constant.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/counter.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/current_process_id.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/current_process_name.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/current_thread_id.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/fallback_policy_fwd.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/fallback_policy.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/function.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/mutable_constant.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/named_scope.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/scoped_attribute.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/timer.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/time_traits.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/value_extraction_fwd.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/value_extraction.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/value_visitation_fwd.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/attributes/value_visitation.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/common.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/core/core.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/core.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/core/record.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/core/record_view.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/exceptions.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/attr_fwd.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/attr.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/filter.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/formatter.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/formatters/c_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/formatters/char_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/formatters/csv_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/formatters/date_time.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/formatters/format.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/formatters.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/formatters/if.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/formatters/named_scope.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/formatters/stream.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/formatters/wrap_formatter.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/formatters/xml_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/is_keyword_descriptor.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/keyword_fwd.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/keyword.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/message.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/predicates/begins_with.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/predicates/channel_severity_filter.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/predicates/contains.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/predicates/ends_with.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/predicates/has_attr.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/predicates.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/predicates/is_debugger_present.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/predicates/is_in_range.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/predicates/matches.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/expressions/record.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/auto_flush.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/channel.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/delimiter.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/depth.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/facility.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/file_name.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/filter.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/format.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/ident.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/ip_version.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/iteration.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/log_name.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/log_source.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/max_size.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/message_file.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/min_free_space.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/open_mode.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/order.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/ordering_window.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/registration.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/rotation_size.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/scan_method.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/severity.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/start_thread.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/target.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/time_based_rotation.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/keywords/use_impl.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/async_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/attribute_mapping.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/basic_sink_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/basic_sink_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/block_on_overflow.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/bounded_fifo_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/bounded_ordering_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/debug_output_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/drop_on_overflow.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/event_log_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/event_log_constants.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/frontend_requirements.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/sink.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/sync_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/syslog_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/syslog_constants.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/text_file_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/text_multifile_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/text_ostream_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/unbounded_fifo_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/unbounded_ordering_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sinks/unlocked_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sources/basic_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sources/channel_feature.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sources/channel_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sources/exception_handler_feature.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sources/features.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sources/global_logger_storage.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sources/logger.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sources/record_ostream.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sources/severity_channel_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sources/severity_feature.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sources/severity_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/sources/threading_models.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/support/date_time.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/support/exception.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/support/regex.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/support/spirit_classic.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/support/spirit_qi.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/support/xpressive.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/trivial.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/empty_deleter.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/exception_handler.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/explicit_operator_bool.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/formatting_ostream_fwd.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/formatting_ostream.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/functional/as_action.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/functional/begins_with.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/functional/bind_assign.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/functional/bind.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/functional/bind_output.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/functional/bind_to_log.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/functional/contains.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/functional/ends_with.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/functional/fun_ref.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/functional.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/functional/in_range.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/functional/logical.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/functional/matches.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/functional/nop.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/functional/save_result.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/intrusive_ref_counter.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/manipulators/add_value.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/manipulators/dump.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/manipulators.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/manipulators/to_log.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/once_block.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/record_ordering.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/setup/common_attributes.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/setup/console.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/setup/file.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/strictest_lock.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/string_literal_fwd.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/string_literal.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/type_dispatch/date_time_types.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/type_dispatch/dynamic_type_dispatcher.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/type_dispatch/standard_types.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/type_dispatch/static_type_dispatcher.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/type_dispatch/type_dispatcher.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/type_info_wrapper.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/unique_identifier_name.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/unused_variable.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/value_ref_fwd.hpp>", public ] },
+ { include: ["<boost/log/detail/config.hpp>", private, "<boost/log/utility/value_ref.hpp>", public ] },
+ { include: ["<boost/log/detail/custom_terminal_spec.hpp>", private, "<boost/log/expressions/attr.hpp>", public ] },
+ { include: ["<boost/log/detail/custom_terminal_spec.hpp>", private, "<boost/log/expressions/formatters/char_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/custom_terminal_spec.hpp>", private, "<boost/log/expressions/formatters/date_time.hpp>", public ] },
+ { include: ["<boost/log/detail/custom_terminal_spec.hpp>", private, "<boost/log/expressions/formatters/format.hpp>", public ] },
+ { include: ["<boost/log/detail/custom_terminal_spec.hpp>", private, "<boost/log/expressions/formatters/if.hpp>", public ] },
+ { include: ["<boost/log/detail/custom_terminal_spec.hpp>", private, "<boost/log/expressions/formatters/named_scope.hpp>", public ] },
+ { include: ["<boost/log/detail/custom_terminal_spec.hpp>", private, "<boost/log/expressions/formatters/wrap_formatter.hpp>", public ] },
+ { include: ["<boost/log/detail/custom_terminal_spec.hpp>", private, "<boost/log/expressions/keyword.hpp>", public ] },
+ { include: ["<boost/log/detail/custom_terminal_spec.hpp>", private, "<boost/log/expressions/predicates/channel_severity_filter.hpp>", public ] },
+ { include: ["<boost/log/detail/date_time_fmt_gen_traits_fwd.hpp>", private, "<boost/log/expressions/formatters/date_time.hpp>", public ] },
+ { include: ["<boost/log/detail/date_time_fmt_gen_traits_fwd.hpp>", private, "<boost/log/support/date_time.hpp>", public ] },
+ { include: ["<boost/log/detail/date_time_format_parser.hpp>", private, "<boost/log/support/date_time.hpp>", public ] },
+ { include: ["<boost/log/detail/decomposed_time.hpp>", private, "<boost/log/support/date_time.hpp>", public ] },
+ { include: ["<boost/log/detail/deduce_char_type.hpp>", private, "<boost/log/expressions/formatters/char_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/deduce_char_type.hpp>", private, "<boost/log/expressions/formatters/named_scope.hpp>", public ] },
+ { include: ["<boost/log/detail/default_attribute_names.hpp>", private, "<boost/log/expressions/message.hpp>", public ] },
+ { include: ["<boost/log/detail/default_attribute_names.hpp>", private, "<boost/log/sources/channel_feature.hpp>", public ] },
+ { include: ["<boost/log/detail/default_attribute_names.hpp>", private, "<boost/log/sources/severity_feature.hpp>", public ] },
+ { include: ["<boost/log/detail/default_attribute_names.hpp>", private, "<boost/log/utility/setup/common_attributes.hpp>", public ] },
+ { include: ["<boost/log/detail/embedded_string_type.hpp>", private, "<boost/log/attributes/constant.hpp>", public ] },
+ { include: ["<boost/log/detail/embedded_string_type.hpp>", private, "<boost/log/expressions/predicates/begins_with.hpp>", public ] },
+ { include: ["<boost/log/detail/embedded_string_type.hpp>", private, "<boost/log/expressions/predicates/contains.hpp>", public ] },
+ { include: ["<boost/log/detail/embedded_string_type.hpp>", private, "<boost/log/expressions/predicates/ends_with.hpp>", public ] },
+ { include: ["<boost/log/detail/embedded_string_type.hpp>", private, "<boost/log/expressions/predicates/is_in_range.hpp>", public ] },
+ { include: ["<boost/log/detail/embedded_string_type.hpp>", private, "<boost/log/utility/manipulators/add_value.hpp>", public ] },
+ { include: ["<boost/log/detail/event.hpp>", private, "<boost/log/sinks/unbounded_fifo_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/fake_mutex.hpp>", private, "<boost/log/sinks/basic_sink_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/fake_mutex.hpp>", private, "<boost/log/sinks/unlocked_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/attribute_cast.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/attribute.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/attribute_name.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/attribute_set.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/attribute_value.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/attribute_value_impl.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/attribute_value_set.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/clock.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/constant.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/counter.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/current_process_id.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/current_process_name.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/current_thread_id.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/fallback_policy.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/function.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/mutable_constant.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/named_scope.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/scoped_attribute.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/timer.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/time_traits.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/value_extraction.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/attributes/value_visitation.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/core/core.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/core/record.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/core/record_view.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/exceptions.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/attr.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/filter.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/formatter.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/formatters/c_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/formatters/char_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/formatters/csv_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/formatters/date_time.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/formatters/format.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/formatters/if.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/formatters/named_scope.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/formatters/stream.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/formatters/wrap_formatter.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/formatters/xml_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/is_keyword_descriptor.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/keyword.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/message.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/predicates/begins_with.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/predicates/channel_severity_filter.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/predicates/contains.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/predicates/ends_with.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/predicates/has_attr.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/predicates/is_debugger_present.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/predicates/is_in_range.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/predicates/matches.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/expressions/record.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/async_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/attribute_mapping.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/basic_sink_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/basic_sink_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/block_on_overflow.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/bounded_fifo_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/bounded_ordering_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/debug_output_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/drop_on_overflow.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/event_log_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/event_log_constants.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/frontend_requirements.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/sink.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/sync_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/syslog_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/syslog_constants.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/text_file_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/text_multifile_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/text_ostream_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/unbounded_fifo_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/unbounded_ordering_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sinks/unlocked_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sources/basic_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sources/channel_feature.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sources/channel_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sources/exception_handler_feature.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sources/features.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sources/global_logger_storage.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sources/logger.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sources/record_ostream.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sources/severity_channel_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sources/severity_feature.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sources/severity_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/sources/threading_models.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/support/date_time.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/support/exception.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/support/regex.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/support/spirit_classic.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/support/spirit_qi.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/support/xpressive.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/trivial.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/exception_handler.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/formatting_ostream.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/functional/as_action.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/functional/begins_with.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/functional/bind_assign.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/functional/bind.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/functional/bind_output.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/functional/bind_to_log.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/functional/contains.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/functional/ends_with.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/functional/fun_ref.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/functional/in_range.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/functional/logical.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/functional/matches.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/functional/nop.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/functional/save_result.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/intrusive_ref_counter.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/manipulators/add_value.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/manipulators/dump.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/manipulators/to_log.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/once_block.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/record_ordering.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/setup/common_attributes.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/setup/console.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/setup/file.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/setup/filter_parser.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/setup/formatter_parser.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/setup/from_settings.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/setup/from_stream.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/setup/settings.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/setup/settings_parser.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/strictest_lock.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/string_literal.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/type_dispatch/date_time_types.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/type_dispatch/dynamic_type_dispatcher.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/type_dispatch/standard_types.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/type_dispatch/static_type_dispatcher.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/type_dispatch/type_dispatcher.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/type_info_wrapper.hpp>", public ] },
+ { include: ["<boost/log/detail/footer.hpp>", private, "<boost/log/utility/value_ref.hpp>", public ] },
+ { include: ["<boost/log/detail/format.hpp>", private, "<boost/log/expressions/formatters/format.hpp>", public ] },
+ { include: ["<boost/log/detail/function_traits.hpp>", private, "<boost/log/expressions/formatters/wrap_formatter.hpp>", public ] },
+ { include: ["<boost/log/detail/function_traits.hpp>", private, "<boost/log/utility/record_ordering.hpp>", public ] },
+ { include: ["<boost/log/detail/generate_overloads.hpp>", private, "<boost/log/expressions/formatters/char_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/generate_overloads.hpp>", private, "<boost/log/expressions/formatters/date_time.hpp>", public ] },
+ { include: ["<boost/log/detail/generate_overloads.hpp>", private, "<boost/log/expressions/formatters/if.hpp>", public ] },
+ { include: ["<boost/log/detail/generate_overloads.hpp>", private, "<boost/log/expressions/formatters/named_scope.hpp>", public ] },
+ { include: ["<boost/log/detail/generate_overloads.hpp>", private, "<boost/log/expressions/formatters/wrap_formatter.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/attribute_cast.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/attribute.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/attribute_name.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/attribute_set.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/attribute_value.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/attribute_value_impl.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/attribute_value_set.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/clock.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/constant.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/counter.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/current_process_id.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/current_process_name.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/current_thread_id.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/fallback_policy.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/function.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/mutable_constant.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/named_scope.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/scoped_attribute.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/timer.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/time_traits.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/value_extraction.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/attributes/value_visitation.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/core/core.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/core/record.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/core/record_view.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/exceptions.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/attr.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/filter.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/formatter.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/formatters/c_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/formatters/char_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/formatters/csv_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/formatters/date_time.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/formatters/format.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/formatters/if.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/formatters/named_scope.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/formatters/stream.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/formatters/wrap_formatter.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/formatters/xml_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/is_keyword_descriptor.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/keyword.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/message.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/predicates/begins_with.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/predicates/channel_severity_filter.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/predicates/contains.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/predicates/ends_with.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/predicates/has_attr.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/predicates/is_debugger_present.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/predicates/is_in_range.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/predicates/matches.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/expressions/record.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/async_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/attribute_mapping.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/basic_sink_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/basic_sink_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/block_on_overflow.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/bounded_fifo_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/bounded_ordering_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/debug_output_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/drop_on_overflow.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/event_log_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/event_log_constants.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/frontend_requirements.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/sink.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/sync_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/syslog_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/syslog_constants.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/text_file_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/text_multifile_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/text_ostream_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/unbounded_fifo_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/unbounded_ordering_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sinks/unlocked_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sources/basic_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sources/channel_feature.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sources/channel_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sources/exception_handler_feature.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sources/features.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sources/global_logger_storage.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sources/logger.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sources/record_ostream.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sources/severity_channel_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sources/severity_feature.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sources/severity_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/sources/threading_models.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/support/date_time.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/support/exception.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/support/regex.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/support/spirit_classic.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/support/spirit_qi.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/support/xpressive.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/trivial.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/exception_handler.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/formatting_ostream.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/functional/as_action.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/functional/begins_with.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/functional/bind_assign.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/functional/bind.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/functional/bind_output.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/functional/bind_to_log.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/functional/contains.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/functional/ends_with.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/functional/fun_ref.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/functional/in_range.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/functional/logical.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/functional/matches.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/functional/nop.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/functional/save_result.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/intrusive_ref_counter.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/manipulators/add_value.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/manipulators/dump.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/manipulators/to_log.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/once_block.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/record_ordering.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/setup/common_attributes.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/setup/console.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/setup/file.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/setup/filter_parser.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/setup/formatter_parser.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/setup/from_settings.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/setup/from_stream.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/setup/settings.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/setup/settings_parser.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/strictest_lock.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/string_literal.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/type_dispatch/date_time_types.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/type_dispatch/dynamic_type_dispatcher.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/type_dispatch/standard_types.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/type_dispatch/static_type_dispatcher.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/type_dispatch/type_dispatcher.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/type_info_wrapper.hpp>", public ] },
+ { include: ["<boost/log/detail/header.hpp>", private, "<boost/log/utility/value_ref.hpp>", public ] },
+ { include: ["<boost/log/detail/light_function.hpp>", private, "<boost/log/core/core.hpp>", public ] },
+ { include: ["<boost/log/detail/light_function.hpp>", private, "<boost/log/expressions/filter.hpp>", public ] },
+ { include: ["<boost/log/detail/light_function.hpp>", private, "<boost/log/expressions/formatter.hpp>", public ] },
+ { include: ["<boost/log/detail/light_function.hpp>", private, "<boost/log/expressions/formatters/date_time.hpp>", public ] },
+ { include: ["<boost/log/detail/light_function.hpp>", private, "<boost/log/expressions/formatters/named_scope.hpp>", public ] },
+ { include: ["<boost/log/detail/light_function.hpp>", private, "<boost/log/sinks/event_log_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/light_function.hpp>", private, "<boost/log/sinks/sink.hpp>", public ] },
+ { include: ["<boost/log/detail/light_function.hpp>", private, "<boost/log/sinks/syslog_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/light_function.hpp>", private, "<boost/log/sinks/text_file_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/light_function.hpp>", private, "<boost/log/sinks/text_multifile_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/light_function.hpp>", private, "<boost/log/sources/exception_handler_feature.hpp>", public ] },
+ { include: ["<boost/log/detail/light_function.hpp>", private, "<boost/log/support/date_time.hpp>", public ] },
+ { include: ["<boost/log/detail/light_rw_mutex.hpp>", private, "<boost/log/sinks/basic_sink_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/light_rw_mutex.hpp>", private, "<boost/log/sources/channel_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/light_rw_mutex.hpp>", private, "<boost/log/sources/logger.hpp>", public ] },
+ { include: ["<boost/log/detail/light_rw_mutex.hpp>", private, "<boost/log/sources/severity_channel_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/light_rw_mutex.hpp>", private, "<boost/log/sources/severity_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/locking_ptr.hpp>", private, "<boost/log/sinks/async_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/locking_ptr.hpp>", private, "<boost/log/sinks/sync_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/locks.hpp>", private, "<boost/log/attributes/mutable_constant.hpp>", public ] },
+ { include: ["<boost/log/detail/locks.hpp>", private, "<boost/log/sinks/basic_sink_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/locks.hpp>", private, "<boost/log/sources/channel_feature.hpp>", public ] },
+ { include: ["<boost/log/detail/locks.hpp>", private, "<boost/log/sources/exception_handler_feature.hpp>", public ] },
+ { include: ["<boost/log/detail/locks.hpp>", private, "<boost/log/sources/severity_feature.hpp>", public ] },
+ { include: ["<boost/log/detail/locks.hpp>", private, "<boost/log/sources/threading_models.hpp>", public ] },
+ { include: ["<boost/log/detail/locks.hpp>", private, "<boost/log/utility/strictest_lock.hpp>", public ] },
+ { include: ["<boost/log/detail/native_typeof.hpp>", private, "<boost/log/sources/record_ostream.hpp>", public ] },
+ { include: ["<boost/log/detail/native_typeof.hpp>", private, "<boost/log/utility/setup/settings.hpp>", public ] },
+ { include: ["<boost/log/detail/parameter_tools.hpp>", private, "<boost/log/expressions/formatters/named_scope.hpp>", public ] },
+ { include: ["<boost/log/detail/parameter_tools.hpp>", private, "<boost/log/sinks/async_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/parameter_tools.hpp>", private, "<boost/log/sinks/event_log_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/parameter_tools.hpp>", private, "<boost/log/sinks/sync_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/parameter_tools.hpp>", private, "<boost/log/sinks/syslog_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/parameter_tools.hpp>", private, "<boost/log/sinks/text_file_backend.hpp>", public ] },
+ { include: ["<boost/log/detail/parameter_tools.hpp>", private, "<boost/log/sinks/unlocked_frontend.hpp>", public ] },
+ { include: ["<boost/log/detail/parameter_tools.hpp>", private, "<boost/log/sources/basic_logger.hpp>", public ] },
+ { include: ["<boost/log/detail/parameter_tools.hpp>", private, "<boost/log/utility/setup/file.hpp>", public ] },
+ { include: ["<boost/log/detail/parameter_tools.hpp>", private, "<boost/log/utility/value_ref.hpp>", public ] },
+ { include: ["<boost/log/detail/pp_identity.hpp>", private, "<boost/log/utility/strictest_lock.hpp>", public ] },
+ { include: ["<boost/log/detail/process_id.hpp>", private, "<boost/log/attributes/current_process_id.hpp>", public ] },
+ { include: ["<boost/log/detail/setup_config.hpp>", private, "<boost/log/utility/setup/filter_parser.hpp>", public ] },
+ { include: ["<boost/log/detail/setup_config.hpp>", private, "<boost/log/utility/setup/formatter_parser.hpp>", public ] },
+ { include: ["<boost/log/detail/setup_config.hpp>", private, "<boost/log/utility/setup/from_settings.hpp>", public ] },
+ { include: ["<boost/log/detail/setup_config.hpp>", private, "<boost/log/utility/setup/from_stream.hpp>", public ] },
+ { include: ["<boost/log/detail/setup_config.hpp>", private, "<boost/log/utility/setup.hpp>", public ] },
+ { include: ["<boost/log/detail/setup_config.hpp>", private, "<boost/log/utility/setup/settings.hpp>", public ] },
+ { include: ["<boost/log/detail/setup_config.hpp>", private, "<boost/log/utility/setup/settings_parser.hpp>", public ] },
+ { include: ["<boost/log/detail/singleton.hpp>", private, "<boost/log/sources/global_logger_storage.hpp>", public ] },
+ { include: ["<boost/log/detail/sink_init_helpers.hpp>", private, "<boost/log/utility/setup/console.hpp>", public ] },
+ { include: ["<boost/log/detail/sink_init_helpers.hpp>", private, "<boost/log/utility/setup/file.hpp>", public ] },
+ { include: ["<boost/log/detail/snprintf.hpp>", private, "<boost/log/expressions/formatters/c_decorator.hpp>", public ] },
+ { include: ["<boost/log/detail/tagged_integer.hpp>", private, "<boost/log/sinks/attribute_mapping.hpp>", public ] },
+ { include: ["<boost/log/detail/tagged_integer.hpp>", private, "<boost/log/sinks/event_log_constants.hpp>", public ] },
+ { include: ["<boost/log/detail/thread_id.hpp>", private, "<boost/log/attributes/current_thread_id.hpp>", public ] },
+ { include: ["<boost/log/detail/threadsafe_queue.hpp>", private, "<boost/log/sinks/unbounded_fifo_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/timestamp.hpp>", private, "<boost/log/sinks/bounded_ordering_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/timestamp.hpp>", private, "<boost/log/sinks/unbounded_ordering_queue.hpp>", public ] },
+ { include: ["<boost/log/detail/trivial_keyword.hpp>", private, "<boost/log/expressions/keyword.hpp>", public ] },
+ { include: ["<boost/log/detail/trivial_keyword.hpp>", private, "<boost/log/trivial.hpp>", public ] },
+ { include: ["<boost/log/detail/unary_function_terminal.hpp>", private, "<boost/log/expressions/predicates/begins_with.hpp>", public ] },
+ { include: ["<boost/log/detail/unary_function_terminal.hpp>", private, "<boost/log/expressions/predicates/contains.hpp>", public ] },
+ { include: ["<boost/log/detail/unary_function_terminal.hpp>", private, "<boost/log/expressions/predicates/ends_with.hpp>", public ] },
+ { include: ["<boost/log/detail/unary_function_terminal.hpp>", private, "<boost/log/expressions/predicates/has_attr.hpp>", public ] },
+ { include: ["<boost/log/detail/unary_function_terminal.hpp>", private, "<boost/log/expressions/predicates/is_in_range.hpp>", public ] },
+ { include: ["<boost/log/detail/unary_function_terminal.hpp>", private, "<boost/log/expressions/predicates/matches.hpp>", public ] },
+ { include: ["<boost/log/detail/unhandled_exception_count.hpp>", private, "<boost/log/sources/record_ostream.hpp>", public ] },
+ { include: ["<boost/log/detail/value_ref_visitation.hpp>", private, "<boost/log/utility/value_ref.hpp>", public ] },
+ { include: ["<boost/log/detail/visible_type.hpp>", private, "<boost/log/sources/global_logger_storage.hpp>", public ] },
+ { include: ["<boost/log/detail/visible_type.hpp>", private, "<boost/log/utility/type_dispatch/dynamic_type_dispatcher.hpp>", public ] },
+ { include: ["<boost/log/detail/visible_type.hpp>", private, "<boost/log/utility/type_dispatch/static_type_dispatcher.hpp>", public ] },
+ { include: ["<boost/log/detail/visible_type.hpp>", private, "<boost/log/utility/type_dispatch/type_dispatcher.hpp>", public ] },
+ { include: ["<boost/math/bindings/detail/big_digamma.hpp>", private, "<boost/math/bindings/e_float.hpp>", public ] },
+ { include: ["<boost/math/bindings/detail/big_digamma.hpp>", private, "<boost/math/bindings/mpfr.hpp>", public ] },
+ { include: ["<boost/math/bindings/detail/big_digamma.hpp>", private, "<boost/math/bindings/mpreal.hpp>", public ] },
+ { include: ["<boost/math/bindings/detail/big_digamma.hpp>", private, "<boost/math/bindings/rr.hpp>", public ] },
+ { include: ["<boost/math/bindings/detail/big_lanczos.hpp>", private, "<boost/math/bindings/e_float.hpp>", public ] },
+ { include: ["<boost/math/bindings/detail/big_lanczos.hpp>", private, "<boost/math/bindings/mpfr.hpp>", public ] },
+ { include: ["<boost/math/bindings/detail/big_lanczos.hpp>", private, "<boost/math/bindings/mpreal.hpp>", public ] },
+ { include: ["<boost/math/bindings/detail/big_lanczos.hpp>", private, "<boost/math/bindings/rr.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/bernoulli.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/beta.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/binomial.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/cauchy.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/chi_squared.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/exponential.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/extreme_value.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/fisher_f.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/gamma.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/geometric.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/hypergeometric.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/inverse_chi_squared.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/inverse_gamma.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/inverse_gaussian.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/laplace.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/logistic.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/lognormal.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/negative_binomial.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/non_central_beta.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/non_central_chi_squared.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/normal.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/pareto.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/poisson.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/rayleigh.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/skew_normal.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/students_t.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/triangular.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/uniform.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/common_error_handling.hpp>", private, "<boost/math/distributions/weibull.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/bernoulli.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/beta.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/binomial.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/cauchy.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/chi_squared.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/exponential.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/extreme_value.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/fisher_f.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/gamma.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/geometric.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/hypergeometric.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/inverse_chi_squared.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/inverse_gamma.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/inverse_gaussian.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/laplace.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/logistic.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/lognormal.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/negative_binomial.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/non_central_beta.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/non_central_chi_squared.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/non_central_f.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/non_central_t.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/normal.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/pareto.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/poisson.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/rayleigh.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/skew_normal.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/students_t.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/triangular.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/uniform.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/derived_accessors.hpp>", private, "<boost/math/distributions/weibull.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/generic_mode.hpp>", private, "<boost/math/distributions/non_central_beta.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/generic_mode.hpp>", private, "<boost/math/distributions/non_central_chi_squared.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/generic_mode.hpp>", private, "<boost/math/distributions/non_central_f.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/generic_mode.hpp>", private, "<boost/math/distributions/skew_normal.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/generic_quantile.hpp>", private, "<boost/math/distributions/non_central_chi_squared.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/generic_quantile.hpp>", private, "<boost/math/distributions/non_central_t.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/hypergeometric_cdf.hpp>", private, "<boost/math/distributions/hypergeometric.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/hypergeometric_pdf.hpp>", private, "<boost/math/distributions/hypergeometric.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/hypergeometric_quantile.hpp>", private, "<boost/math/distributions/hypergeometric.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/inv_discrete_quantile.hpp>", private, "<boost/math/distributions/binomial.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/inv_discrete_quantile.hpp>", private, "<boost/math/distributions/geometric.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/inv_discrete_quantile.hpp>", private, "<boost/math/distributions/negative_binomial.hpp>", public ] },
+ { include: ["<boost/math/distributions/detail/inv_discrete_quantile.hpp>", private, "<boost/math/distributions/poisson.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/airy_ai_bi_zero.hpp>", private, "<boost/math/special_functions/airy.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/bessel_i0.hpp>", private, "<boost/math/special_functions/bessel.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/bessel_i1.hpp>", private, "<boost/math/special_functions/bessel.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/bessel_ik.hpp>", private, "<boost/math/special_functions/bessel.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/bessel_jn.hpp>", private, "<boost/math/special_functions/bessel.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/bessel_jy.hpp>", private, "<boost/math/special_functions/bessel.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/bessel_jy_zero.hpp>", private, "<boost/math/special_functions/bessel.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/bessel_kn.hpp>", private, "<boost/math/special_functions/bessel.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/bessel_yn.hpp>", private, "<boost/math/special_functions/bessel.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/erf_inv.hpp>", private, "<boost/math/special_functions/erf.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/fp_traits.hpp>", private, "<boost/math/special_functions/fpclassify.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/fp_traits.hpp>", private, "<boost/math/special_functions/sign.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/gamma_inva.hpp>", private, "<boost/math/special_functions/gamma.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/ibeta_inv_ab.hpp>", private, "<boost/math/special_functions/beta.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/ibeta_inverse.hpp>", private, "<boost/math/special_functions/beta.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/iconv.hpp>", private, "<boost/math/special_functions/bessel.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/igamma_inverse.hpp>", private, "<boost/math/special_functions/gamma.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/igamma_large.hpp>", private, "<boost/math/special_functions/gamma.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/lanczos_sse2.hpp>", private, "<boost/math/special_functions/lanczos.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/lgamma_small.hpp>", private, "<boost/math/special_functions/gamma.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/round_fwd.hpp>", private, "<boost/math/special_functions/math_fwd.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/unchecked_factorial.hpp>", private, "<boost/math/special_functions/factorials.hpp>", public ] },
+ { include: ["<boost/math/special_functions/detail/unchecked_factorial.hpp>", private, "<boost/math/special_functions/gamma.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_10.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_11.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_12.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_13.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_14.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_15.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_16.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_17.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_18.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_19.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_20.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_2.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_3.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_4.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_5.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_6.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_7.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_8.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner1_9.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_10.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_11.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_12.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_13.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_14.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_15.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_16.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_17.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_18.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_19.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_20.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_2.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_3.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_4.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_5.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_6.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_7.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_8.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner2_9.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_10.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_11.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_12.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_13.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_14.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_15.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_16.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_17.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_18.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_19.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_20.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_2.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_3.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_4.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_5.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_6.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_7.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_8.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/polynomial_horner3_9.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_10.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_11.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_12.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_13.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_14.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_15.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_16.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_17.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_18.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_19.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_20.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_2.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_3.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_4.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_5.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_6.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_7.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_8.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner1_9.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_10.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_11.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_12.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_13.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_14.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_15.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_16.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_17.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_18.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_19.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_20.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_2.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_3.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_4.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_5.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_6.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_7.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_8.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner2_9.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_10.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_11.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_12.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_13.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_14.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_15.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_16.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_17.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_18.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_19.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_20.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_2.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_3.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_4.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_5.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_6.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_7.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_8.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/math/tools/detail/rational_horner3_9.hpp>", private, "<boost/math/tools/rational.hpp>", public ] },
+ { include: ["<boost/move/detail/config_begin.hpp>", private, "<boost/move/algorithm.hpp>", public ] },
+ { include: ["<boost/move/detail/config_begin.hpp>", private, "<boost/move/core.hpp>", public ] },
+ { include: ["<boost/move/detail/config_begin.hpp>", private, "<boost/move/iterator.hpp>", public ] },
+ { include: ["<boost/move/detail/config_begin.hpp>", private, "<boost/move/move.hpp>", public ] },
+ { include: ["<boost/move/detail/config_begin.hpp>", private, "<boost/move/traits.hpp>", public ] },
+ { include: ["<boost/move/detail/config_begin.hpp>", private, "<boost/move/utility.hpp>", public ] },
+ { include: ["<boost/move/detail/config_end.hpp>", private, "<boost/move/algorithm.hpp>", public ] },
+ { include: ["<boost/move/detail/config_end.hpp>", private, "<boost/move/core.hpp>", public ] },
+ { include: ["<boost/move/detail/config_end.hpp>", private, "<boost/move/iterator.hpp>", public ] },
+ { include: ["<boost/move/detail/config_end.hpp>", private, "<boost/move/move.hpp>", public ] },
+ { include: ["<boost/move/detail/config_end.hpp>", private, "<boost/move/traits.hpp>", public ] },
+ { include: ["<boost/move/detail/config_end.hpp>", private, "<boost/move/utility.hpp>", public ] },
+ { include: ["<boost/move/detail/meta_utils.hpp>", private, "<boost/move/core.hpp>", public ] },
+ { include: ["<boost/move/detail/meta_utils.hpp>", private, "<boost/move/traits.hpp>", public ] },
+ { include: ["<boost/move/detail/meta_utils.hpp>", private, "<boost/move/utility.hpp>", public ] },
+ { include: ["<boost/move/detail/move_helpers.hpp>", private, "<boost/container/deque.hpp>", public ] },
+ { include: ["<boost/move/detail/move_helpers.hpp>", private, "<boost/container/flat_map.hpp>", public ] },
+ { include: ["<boost/move/detail/move_helpers.hpp>", private, "<boost/container/flat_set.hpp>", public ] },
+ { include: ["<boost/move/detail/move_helpers.hpp>", private, "<boost/container/list.hpp>", public ] },
+ { include: ["<boost/move/detail/move_helpers.hpp>", private, "<boost/container/map.hpp>", public ] },
+ { include: ["<boost/move/detail/move_helpers.hpp>", private, "<boost/container/set.hpp>", public ] },
+ { include: ["<boost/move/detail/move_helpers.hpp>", private, "<boost/container/slist.hpp>", public ] },
+ { include: ["<boost/move/detail/move_helpers.hpp>", private, "<boost/container/stable_vector.hpp>", public ] },
+ { include: ["<boost/move/detail/move_helpers.hpp>", private, "<boost/container/vector.hpp>", public ] },
+ { include: ["<boost/mpi/detail/binary_buffer_iprimitive.hpp>", private, "<boost/mpi/packed_iarchive.hpp>", public ] },
+ { include: ["<boost/mpi/detail/binary_buffer_oprimitive.hpp>", private, "<boost/mpi/packed_oarchive.hpp>", public ] },
+ { include: ["<boost/mpi/detail/broadcast_sc.hpp>", private, "<boost/mpi/collectives/broadcast.hpp>", public ] },
+ { include: ["<boost/mpi/detail/broadcast_sc.hpp>", private, "<boost/mpi/skeleton_and_content.hpp>", public ] },
+ { include: ["<boost/mpi/detail/communicator_sc.hpp>", private, "<boost/mpi/communicator.hpp>", public ] },
+ { include: ["<boost/mpi/detail/communicator_sc.hpp>", private, "<boost/mpi/skeleton_and_content.hpp>", public ] },
+ { include: ["<boost/mpi/detail/computation_tree.hpp>", private, "<boost/mpi/collectives/reduce.hpp>", public ] },
+ { include: ["<boost/mpi/detail/computation_tree.hpp>", private, "<boost/mpi/collectives/scan.hpp>", public ] },
+ { include: ["<boost/mpi/detail/content_oarchive.hpp>", private, "<boost/mpi/skeleton_and_content.hpp>", public ] },
+ { include: ["<boost/mpi/detail/forward_skeleton_iarchive.hpp>", private, "<boost/mpi/skeleton_and_content.hpp>", public ] },
+ { include: ["<boost/mpi/detail/forward_skeleton_oarchive.hpp>", private, "<boost/mpi/skeleton_and_content.hpp>", public ] },
+ { include: ["<boost/mpi/detail/ignore_iprimitive.hpp>", private, "<boost/mpi/skeleton_and_content.hpp>", public ] },
+ { include: ["<boost/mpi/detail/ignore_oprimitive.hpp>", private, "<boost/mpi/skeleton_and_content.hpp>", public ] },
+ { include: ["<boost/mpi/detail/mpi_datatype_cache.hpp>", private, "<boost/mpi/datatype.hpp>", public ] },
+ { include: ["<boost/mpi/detail/packed_iprimitive.hpp>", private, "<boost/mpi/packed_iarchive.hpp>", public ] },
+ { include: ["<boost/mpi/detail/packed_oprimitive.hpp>", private, "<boost/mpi/packed_oarchive.hpp>", public ] },
+ { include: ["<boost/mpi/detail/point_to_point.hpp>", private, "<boost/mpi/collectives/gather.hpp>", public ] },
+ { include: ["<boost/mpi/detail/point_to_point.hpp>", private, "<boost/mpi/collectives/reduce.hpp>", public ] },
+ { include: ["<boost/mpi/detail/point_to_point.hpp>", private, "<boost/mpi/collectives/scan.hpp>", public ] },
+ { include: ["<boost/mpi/detail/point_to_point.hpp>", private, "<boost/mpi/collectives/scatter.hpp>", public ] },
+ { include: ["<boost/mpi/detail/point_to_point.hpp>", private, "<boost/mpi/communicator.hpp>", public ] },
+ { include: ["<boost/msm/front/detail/common_states.hpp>", private, "<boost/msm/front/common_states.hpp>", public ] },
+ { include: ["<boost/msm/front/detail/row2_helper.hpp>", private, "<boost/msm/front/internal_row.hpp>", public ] },
+ { include: ["<boost/msm/front/detail/row2_helper.hpp>", private, "<boost/msm/front/row2.hpp>", public ] },
+ { include: ["<boost/msm/mpl_graph/detail/adjacency_list_graph.ipp>", private, "<boost/msm/mpl_graph/adjacency_list_graph.hpp>", public ] },
+ { include: ["<boost/msm/mpl_graph/detail/graph_implementation_interface.ipp>", private, "<boost/msm/mpl_graph/mpl_graph.hpp>", public ] },
+ { include: ["<boost/msm/mpl_graph/detail/incidence_list_graph.ipp>", private, "<boost/msm/mpl_graph/incidence_list_graph.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/access_specifier.hpp>", private, "<boost/multi_index/composite_key.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/access_specifier.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/access_specifier.hpp>", private, "<boost/multi_index/hashed_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/access_specifier.hpp>", private, "<boost/multi_index/ordered_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/access_specifier.hpp>", private, "<boost/multi_index/random_access_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/access_specifier.hpp>", private, "<boost/multi_index/sequenced_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/adl_swap.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/archive_constructed.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/auto_space.hpp>", private, "<boost/multi_index/hashed_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/base_type.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/bidir_node_iterator.hpp>", private, "<boost/multi_index/ordered_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/bidir_node_iterator.hpp>", private, "<boost/multi_index/sequenced_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/bucket_array.hpp>", private, "<boost/multi_index/hashed_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/converter.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/do_not_copy_elements_tag.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/do_not_copy_elements_tag.hpp>", private, "<boost/multi_index/hashed_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/do_not_copy_elements_tag.hpp>", private, "<boost/multi_index/ordered_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/do_not_copy_elements_tag.hpp>", private, "<boost/multi_index/random_access_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/do_not_copy_elements_tag.hpp>", private, "<boost/multi_index/sequenced_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/duplicates_iterator.hpp>", private, "<boost/multi_index/ordered_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/hash_index_args.hpp>", private, "<boost/multi_index/hashed_index_fwd.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/hash_index_iterator.hpp>", private, "<boost/multi_index/hashed_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/has_tag.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/header_holder.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/index_node_base.hpp>", private, "<boost/multi_index/hashed_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/index_node_base.hpp>", private, "<boost/multi_index/ordered_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/index_node_base.hpp>", private, "<boost/multi_index/random_access_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/index_node_base.hpp>", private, "<boost/multi_index/sequenced_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/invariant_assert.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/modify_key_adaptor.hpp>", private, "<boost/multi_index/hashed_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/modify_key_adaptor.hpp>", private, "<boost/multi_index/ordered_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/no_duplicate_tags.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/no_duplicate_tags.hpp>", private, "<boost/multi_index/tag.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/ord_index_args.hpp>", private, "<boost/multi_index/ordered_index_fwd.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/ord_index_node.hpp>", private, "<boost/multi_index/ordered_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/ord_index_ops.hpp>", private, "<boost/multi_index/ordered_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/prevent_eti.hpp>", private, "<boost/multi_index/composite_key.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/prevent_eti.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/rnd_index_loader.hpp>", private, "<boost/multi_index/random_access_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/rnd_index_node.hpp>", private, "<boost/multi_index/random_access_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/rnd_index_ops.hpp>", private, "<boost/multi_index/random_access_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/rnd_index_ptr_array.hpp>", private, "<boost/multi_index/random_access_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/rnd_node_iterator.hpp>", private, "<boost/multi_index/random_access_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/safe_ctr_proxy.hpp>", private, "<boost/multi_index/hashed_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/safe_ctr_proxy.hpp>", private, "<boost/multi_index/ordered_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/safe_ctr_proxy.hpp>", private, "<boost/multi_index/random_access_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/safe_ctr_proxy.hpp>", private, "<boost/multi_index/sequenced_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/safe_mode.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/safe_mode.hpp>", private, "<boost/multi_index/hashed_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/safe_mode.hpp>", private, "<boost/multi_index/ordered_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/safe_mode.hpp>", private, "<boost/multi_index/random_access_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/safe_mode.hpp>", private, "<boost/multi_index/sequenced_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/scope_guard.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/scope_guard.hpp>", private, "<boost/multi_index/hashed_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/scope_guard.hpp>", private, "<boost/multi_index/ordered_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/scope_guard.hpp>", private, "<boost/multi_index/random_access_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/scope_guard.hpp>", private, "<boost/multi_index/sequenced_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/seq_index_node.hpp>", private, "<boost/multi_index/sequenced_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/seq_index_ops.hpp>", private, "<boost/multi_index/sequenced_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/serialization_version.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/unbounded.hpp>", private, "<boost/bimap/bimap.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/unbounded.hpp>", private, "<boost/multi_index/ordered_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/value_compare.hpp>", private, "<boost/multi_index/ordered_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/vartempl_support.hpp>", private, "<boost/multi_index_container.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/vartempl_support.hpp>", private, "<boost/multi_index/hashed_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/vartempl_support.hpp>", private, "<boost/multi_index/ordered_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/vartempl_support.hpp>", private, "<boost/multi_index/random_access_index.hpp>", public ] },
+ { include: ["<boost/multi_index/detail/vartempl_support.hpp>", private, "<boost/multi_index/sequenced_index.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/big_lanczos.hpp>", private, "<boost/multiprecision/cpp_dec_float.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/big_lanczos.hpp>", private, "<boost/multiprecision/gmp.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/big_lanczos.hpp>", private, "<boost/multiprecision/mpfi.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/big_lanczos.hpp>", private, "<boost/multiprecision/mpfr.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/bitscan.hpp>", private, "<boost/multiprecision/cpp_int/misc.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/bitscan.hpp>", private, "<boost/multiprecision/integer.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/digits.hpp>", private, "<boost/multiprecision/gmp.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/digits.hpp>", private, "<boost/multiprecision/mpfi.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/digits.hpp>", private, "<boost/multiprecision/mpfr.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/dynamic_array.hpp>", private, "<boost/multiprecision/cpp_dec_float.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/float_string_cvt.hpp>", private, "<boost/multiprecision/float128.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/generic_interconvert.hpp>", private, "<boost/multiprecision/number.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/integer_ops.hpp>", private, "<boost/multiprecision/cpp_int.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/integer_ops.hpp>", private, "<boost/multiprecision/debug_adaptor.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/integer_ops.hpp>", private, "<boost/multiprecision/gmp.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/integer_ops.hpp>", private, "<boost/multiprecision/logged_adaptor.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/integer_ops.hpp>", private, "<boost/multiprecision/tommath.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/number_base.hpp>", private, "<boost/multiprecision/traits/is_restricted_conversion.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/number_compare.hpp>", private, "<boost/multiprecision/number.hpp>", public ] },
+ { include: ["<boost/multiprecision/detail/ublas_interop.hpp>", private, "<boost/multiprecision/number.hpp>", public ] },
+ { include: ["<boost/numeric/conversion/detail/bounds.hpp>", private, "<boost/numeric/conversion/bounds.hpp>", public ] },
+ { include: ["<boost/numeric/conversion/detail/conversion_traits.hpp>", private, "<boost/numeric/conversion/conversion_traits.hpp>", public ] },
+ { include: ["<boost/numeric/conversion/detail/converter.hpp>", private, "<boost/numeric/conversion/converter.hpp>", public ] },
+ { include: ["<boost/numeric/conversion/detail/int_float_mixture.hpp>", private, "<boost/numeric/conversion/int_float_mixture.hpp>", public ] },
+ { include: ["<boost/numeric/conversion/detail/is_subranged.hpp>", private, "<boost/numeric/conversion/is_subranged.hpp>", public ] },
+ { include: ["<boost/numeric/conversion/detail/numeric_cast_traits.hpp>", private, "<boost/numeric/conversion/numeric_cast_traits.hpp>", public ] },
+ { include: ["<boost/numeric/conversion/detail/old_numeric_cast.hpp>", private, "<boost/numeric/conversion/cast.hpp>", public ] },
+ { include: ["<boost/numeric/conversion/detail/sign_mixture.hpp>", private, "<boost/numeric/conversion/sign_mixture.hpp>", public ] },
+ { include: ["<boost/numeric/conversion/detail/udt_builtin_mixture.hpp>", private, "<boost/numeric/conversion/udt_builtin_mixture.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/alpha_rounding_control.hpp>", private, "<boost/numeric/interval/hw_rounding.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/bugs.hpp>", private, "<boost/numeric/interval/arith2.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/bugs.hpp>", private, "<boost/numeric/interval/arith.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/bugs.hpp>", private, "<boost/numeric/interval/rounded_arith.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/bugs.hpp>", private, "<boost/numeric/interval/rounded_transc.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/bugs.hpp>", private, "<boost/numeric/interval/transc.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/bugs.hpp>", private, "<boost/numeric/interval/utility.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/c99_rounding_control.hpp>", private, "<boost/numeric/interval/hw_rounding.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/division.hpp>", private, "<boost/numeric/interval/arith2.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/division.hpp>", private, "<boost/numeric/interval/arith.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/ia64_rounding_control.hpp>", private, "<boost/numeric/interval/hw_rounding.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/interval_prototype.hpp>", private, "<boost/numeric/interval/arith2.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/interval_prototype.hpp>", private, "<boost/numeric/interval/arith3.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/interval_prototype.hpp>", private, "<boost/numeric/interval/compare/certain.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/interval_prototype.hpp>", private, "<boost/numeric/interval/compare/explicit.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/interval_prototype.hpp>", private, "<boost/numeric/interval/compare/lexicographic.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/interval_prototype.hpp>", private, "<boost/numeric/interval/compare/possible.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/interval_prototype.hpp>", private, "<boost/numeric/interval/compare/set.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/interval_prototype.hpp>", private, "<boost/numeric/interval/compare/tribool.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/interval_prototype.hpp>", private, "<boost/numeric/interval/ext/integer.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/interval_prototype.hpp>", private, "<boost/numeric/interval/interval.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/interval_prototype.hpp>", private, "<boost/numeric/interval/limits.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/interval_prototype.hpp>", private, "<boost/numeric/interval/transc.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/interval_prototype.hpp>", private, "<boost/numeric/interval/utility.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/ppc_rounding_control.hpp>", private, "<boost/numeric/interval/hw_rounding.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/sparc_rounding_control.hpp>", private, "<boost/numeric/interval/hw_rounding.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/test_input.hpp>", private, "<boost/numeric/interval/arith2.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/test_input.hpp>", private, "<boost/numeric/interval/arith3.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/test_input.hpp>", private, "<boost/numeric/interval/arith.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/test_input.hpp>", private, "<boost/numeric/interval/compare/certain.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/test_input.hpp>", private, "<boost/numeric/interval/compare/lexicographic.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/test_input.hpp>", private, "<boost/numeric/interval/compare/possible.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/test_input.hpp>", private, "<boost/numeric/interval/compare/set.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/test_input.hpp>", private, "<boost/numeric/interval/compare/tribool.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/test_input.hpp>", private, "<boost/numeric/interval/ext/integer.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/test_input.hpp>", private, "<boost/numeric/interval/transc.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/test_input.hpp>", private, "<boost/numeric/interval/utility.hpp>", public ] },
+ { include: ["<boost/numeric/interval/detail/x86_rounding_control.hpp>", private, "<boost/numeric/interval/hw_rounding.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/algebra/detail/for_each.hpp>", private, "<boost/numeric/odeint/algebra/range_algebra.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/algebra/detail/macros.hpp>", private, "<boost/numeric/odeint/algebra/range_algebra.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/algebra/detail/reduce.hpp>", private, "<boost/numeric/odeint/algebra/range_algebra.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/integrate/detail/integrate_adaptive.hpp>", private, "<boost/numeric/odeint/integrate/integrate_adaptive.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/integrate/detail/integrate_adaptive.hpp>", private, "<boost/numeric/odeint/integrate/integrate_const.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/integrate/detail/integrate_const.hpp>", private, "<boost/numeric/odeint/integrate/integrate_const.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/integrate/detail/integrate_n_steps.hpp>", private, "<boost/numeric/odeint/integrate/integrate_n_steps.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/integrate/detail/integrate_times.hpp>", private, "<boost/numeric/odeint/integrate/integrate_times.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/stepper/detail/adams_bashforth_call_algebra.hpp>", private, "<boost/numeric/odeint/stepper/adams_bashforth.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/stepper/detail/adams_bashforth_coefficients.hpp>", private, "<boost/numeric/odeint/stepper/adams_bashforth.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/stepper/detail/adams_moulton_call_algebra.hpp>", private, "<boost/numeric/odeint/stepper/adams_moulton.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/stepper/detail/adams_moulton_coefficients.hpp>", private, "<boost/numeric/odeint/stepper/adams_moulton.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/stepper/detail/generic_rk_algorithm.hpp>", private, "<boost/numeric/odeint/stepper/explicit_error_generic_rk.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/stepper/detail/generic_rk_algorithm.hpp>", private, "<boost/numeric/odeint/stepper/explicit_generic_rk.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/stepper/detail/generic_rk_call_algebra.hpp>", private, "<boost/numeric/odeint/stepper/explicit_error_generic_rk.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/stepper/detail/generic_rk_operations.hpp>", private, "<boost/numeric/odeint/stepper/explicit_error_generic_rk.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/stepper/detail/rotating_buffer.hpp>", private, "<boost/numeric/odeint/stepper/adams_bashforth.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/stepper/detail/rotating_buffer.hpp>", private, "<boost/numeric/odeint/stepper/adams_moulton.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/util/detail/is_range.hpp>", private, "<boost/numeric/odeint/util/copy.hpp>", public ] },
+ { include: ["<boost/numeric/odeint/util/detail/less_with_sign.hpp>", private, "<boost/numeric/odeint/stepper/bulirsch_stoer.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/config.hpp>", private, "<boost/numeric/ublas/exception.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/config.hpp>", private, "<boost/numeric/ublas/operation/num_columns.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/config.hpp>", private, "<boost/numeric/ublas/operation/num_rows.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/config.hpp>", private, "<boost/numeric/ublas/operation/size.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/config.hpp>", private, "<boost/numeric/ublas/traits.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/definitions.hpp>", private, "<boost/numeric/ublas/functional.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/duff.hpp>", private, "<boost/numeric/ublas/functional.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/iterator.hpp>", private, "<boost/numeric/ublas/storage.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/iterator.hpp>", private, "<boost/numeric/ublas/traits.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/matrix_assign.hpp>", private, "<boost/numeric/ublas/experimental/sparse_view.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/matrix_assign.hpp>", private, "<boost/numeric/ublas/matrix.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/matrix_assign.hpp>", private, "<boost/numeric/ublas/matrix_proxy.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/matrix_assign.hpp>", private, "<boost/numeric/ublas/matrix_sparse.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/matrix_assign.hpp>", private, "<boost/numeric/ublas/operation_blocked.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/raw.hpp>", private, "<boost/numeric/ublas/functional.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/returntype_deduction.hpp>", private, "<boost/numeric/ublas/traits.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/temporary.hpp>", private, "<boost/numeric/ublas/banded.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/temporary.hpp>", private, "<boost/numeric/ublas/hermitian.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/temporary.hpp>", private, "<boost/numeric/ublas/matrix_proxy.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/temporary.hpp>", private, "<boost/numeric/ublas/symmetric.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/temporary.hpp>", private, "<boost/numeric/ublas/triangular.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/temporary.hpp>", private, "<boost/numeric/ublas/vector_proxy.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/vector_assign.hpp>", private, "<boost/numeric/ublas/matrix_proxy.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/vector_assign.hpp>", private, "<boost/numeric/ublas/operation_blocked.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/vector_assign.hpp>", private, "<boost/numeric/ublas/vector.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/vector_assign.hpp>", private, "<boost/numeric/ublas/vector_proxy.hpp>", public ] },
+ { include: ["<boost/numeric/ublas/detail/vector_assign.hpp>", private, "<boost/numeric/ublas/vector_sparse.hpp>", public ] },
+ { include: ["<boost/pending/detail/disjoint_sets.hpp>", private, "<boost/pending/disjoint_sets.hpp>", public ] },
+ { include: ["<boost/pending/detail/int_iterator.hpp>", private, "<boost/graph/matrix_as_graph.hpp>", public ] },
+ { include: ["<boost/pending/detail/property.hpp>", private, "<boost/pending/property.hpp>", public ] },
+ { include: ["<boost/phoenix/core/detail/actor_operator.hpp>", private, "<boost/phoenix/core/actor.hpp>", public ] },
+ { include: ["<boost/phoenix/core/detail/actor_result_of.hpp>", private, "<boost/phoenix/core/actor.hpp>", public ] },
+ { include: ["<boost/phoenix/core/detail/argument.hpp>", private, "<boost/phoenix/core/argument.hpp>", public ] },
+ { include: ["<boost/phoenix/core/detail/call.hpp>", private, "<boost/phoenix/core/call.hpp>", public ] },
+ { include: ["<boost/phoenix/core/detail/expression.hpp>", private, "<boost/phoenix/core/expression.hpp>", public ] },
+ { include: ["<boost/phoenix/object/detail/construct_eval.hpp>", private, "<boost/phoenix/object/construct.hpp>", public ] },
+ { include: ["<boost/phoenix/object/detail/construct.hpp>", private, "<boost/phoenix/object/construct.hpp>", public ] },
+ { include: ["<boost/phoenix/object/detail/new_eval.hpp>", private, "<boost/phoenix/object/new.hpp>", public ] },
+ { include: ["<boost/phoenix/object/detail/new.hpp>", private, "<boost/phoenix/object/new.hpp>", public ] },
+ { include: ["<boost/phoenix/object/detail/target.hpp>", private, "<boost/phoenix/object/const_cast.hpp>", public ] },
+ { include: ["<boost/phoenix/object/detail/target.hpp>", private, "<boost/phoenix/object/construct.hpp>", public ] },
+ { include: ["<boost/phoenix/object/detail/target.hpp>", private, "<boost/phoenix/object/dynamic_cast.hpp>", public ] },
+ { include: ["<boost/phoenix/object/detail/target.hpp>", private, "<boost/phoenix/object/new.hpp>", public ] },
+ { include: ["<boost/phoenix/object/detail/target.hpp>", private, "<boost/phoenix/object/reinterpret_cast.hpp>", public ] },
+ { include: ["<boost/phoenix/object/detail/target.hpp>", private, "<boost/phoenix/object/static_cast.hpp>", public ] },
+ { include: ["<boost/phoenix/operator/detail/define_operator.hpp>", private, "<boost/phoenix/operator/arithmetic.hpp>", public ] },
+ { include: ["<boost/phoenix/operator/detail/define_operator.hpp>", private, "<boost/phoenix/operator/bitwise.hpp>", public ] },
+ { include: ["<boost/phoenix/operator/detail/define_operator.hpp>", private, "<boost/phoenix/operator/comparison.hpp>", public ] },
+ { include: ["<boost/phoenix/operator/detail/define_operator.hpp>", private, "<boost/phoenix/operator/logical.hpp>", public ] },
+ { include: ["<boost/phoenix/operator/detail/define_operator.hpp>", private, "<boost/phoenix/operator/member.hpp>", public ] },
+ { include: ["<boost/phoenix/operator/detail/define_operator.hpp>", private, "<boost/phoenix/operator/self.hpp>", public ] },
+ { include: ["<boost/phoenix/operator/detail/mem_fun_ptr_eval_result_of.hpp>", private, "<boost/phoenix/operator/member.hpp>", public ] },
+ { include: ["<boost/phoenix/operator/detail/mem_fun_ptr_gen.hpp>", private, "<boost/phoenix/operator/member.hpp>", public ] },
+ { include: ["<boost/phoenix/operator/detail/undef_operator.hpp>", private, "<boost/phoenix/operator/arithmetic.hpp>", public ] },
+ { include: ["<boost/phoenix/operator/detail/undef_operator.hpp>", private, "<boost/phoenix/operator/bitwise.hpp>", public ] },
+ { include: ["<boost/phoenix/operator/detail/undef_operator.hpp>", private, "<boost/phoenix/operator/comparison.hpp>", public ] },
+ { include: ["<boost/phoenix/operator/detail/undef_operator.hpp>", private, "<boost/phoenix/operator/logical.hpp>", public ] },
+ { include: ["<boost/phoenix/operator/detail/undef_operator.hpp>", private, "<boost/phoenix/operator/member.hpp>", public ] },
+ { include: ["<boost/phoenix/operator/detail/undef_operator.hpp>", private, "<boost/phoenix/operator/self.hpp>", public ] },
+ { include: ["<boost/phoenix/scope/detail/dynamic.hpp>", private, "<boost/phoenix/scope/dynamic.hpp>", public ] },
+ { include: ["<boost/phoenix/scope/detail/local_gen.hpp>", private, "<boost/phoenix/scope/lambda.hpp>", public ] },
+ { include: ["<boost/phoenix/scope/detail/local_gen.hpp>", private, "<boost/phoenix/scope/let.hpp>", public ] },
+ { include: ["<boost/phoenix/scope/detail/local_variable.hpp>", private, "<boost/phoenix/scope/local_variable.hpp>", public ] },
+ { include: ["<boost/phoenix/statement/detail/catch_push_back.hpp>", private, "<boost/phoenix/statement/try_catch.hpp>", public ] },
+ { include: ["<boost/phoenix/statement/detail/switch.hpp>", private, "<boost/phoenix/statement/switch.hpp>", public ] },
+ { include: ["<boost/phoenix/statement/detail/try_catch_eval.hpp>", private, "<boost/phoenix/statement/try_catch.hpp>", public ] },
+ { include: ["<boost/phoenix/statement/detail/try_catch_expression.hpp>", private, "<boost/phoenix/statement/try_catch.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/begin.hpp>", private, "<boost/phoenix/stl/algorithm/iteration.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/begin.hpp>", private, "<boost/phoenix/stl/algorithm/querying.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/begin.hpp>", private, "<boost/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/decay_array.hpp>", private, "<boost/phoenix/stl/algorithm/querying.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/decay_array.hpp>", private, "<boost/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/end.hpp>", private, "<boost/phoenix/stl/algorithm/iteration.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/end.hpp>", private, "<boost/phoenix/stl/algorithm/querying.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/end.hpp>", private, "<boost/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/has_equal_range.hpp>", private, "<boost/phoenix/stl/algorithm/querying.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/has_find.hpp>", private, "<boost/phoenix/stl/algorithm/querying.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/has_lower_bound.hpp>", private, "<boost/phoenix/stl/algorithm/querying.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/has_remove.hpp>", private, "<boost/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/has_remove_if.hpp>", private, "<boost/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/has_reverse.hpp>", private, "<boost/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/has_sort.hpp>", private, "<boost/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/has_unique.hpp>", private, "<boost/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/algorithm/detail/has_upper_bound.hpp>", private, "<boost/phoenix/stl/algorithm/querying.hpp>", public ] },
+ { include: ["<boost/phoenix/stl/container/detail/container.hpp>", private, "<boost/phoenix/stl/container/container.hpp>", public ] },
+ { include: ["<boost/phoenix/support/detail/iterate_define.hpp>", private, "<boost/phoenix/support/iterate.hpp>", public ] },
+ { include: ["<boost/pool/detail/guard.hpp>", private, "<boost/pool/singleton_pool.hpp>", public ] },
+ { include: ["<boost/pool/detail/mutex.hpp>", private, "<boost/pool/poolfwd.hpp>", public ] },
+ { include: ["<boost/pool/detail/pool_construct.ipp>", private, "<boost/pool/object_pool.hpp>", public ] },
+ { include: ["<boost/pool/detail/pool_construct_simple.ipp>", private, "<boost/pool/object_pool.hpp>", public ] },
+ { include: ["<boost/predef/detail/_cassert.h>", private, "<boost/predef/library/c/_prefix.h>", public ] },
+ { include: ["<boost/predef/detail/_exception.h>", private, "<boost/predef/library/std/_prefix.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/aix.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/amigaos.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/android.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/beos.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/bsd/bsdi.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/bsd/dragonfly.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/bsd/free.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/bsd.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/bsd/net.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/bsd/open.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/cygwin.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/hpux.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/irix.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/linux.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/macos.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/os400.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/qnxnto.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/solaris.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/unix.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/vms.h>", public ] },
+ { include: ["<boost/predef/detail/os_detected.h>", private, "<boost/predef/os/windows.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/alpha.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/arm.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/blackfin.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/convex.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/ia64.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/m68k.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/mips.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/parisc.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/ppc.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/pyramid.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/rs6k.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/sparc.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/superh.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/sys370.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/sys390.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/x86/32.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/x86/64.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/x86.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/architecture/z.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/borland.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/clang.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/comeau.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/compaq.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/diab.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/digitalmars.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/dignus.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/edg.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/ekopath.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/gcc.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/gcc_xml.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/greenhills.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/hp_acc.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/iar.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/ibm.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/intel.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/kai.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/llvm.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/metaware.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/metrowerks.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/microtec.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/mpw.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/palm.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/pgi.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/sgi_mipspro.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/sunpro.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/tendra.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/visualc.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/compiler/watcom.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/language/objc.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/language/stdc.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/language/stdcpp.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/library/c/gnu.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/library/c/uc.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/library/c/vms.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/library/c/zos.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/library/std/cxx.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/library/std/dinkumware.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/library/std/libcomo.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/library/std/modena.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/library/std/msl.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/library/std/roguewave.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/library/std/sgi.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/library/std/stdcpp3.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/library/std/stlport.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/library/std/vacpp.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/make.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/aix.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/amigaos.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/android.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/beos.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/bsd/bsdi.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/bsd/dragonfly.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/bsd/free.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/bsd.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/bsd/net.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/bsd/open.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/cygwin.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/hpux.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/irix.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/linux.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/macos.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/os400.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/qnxnto.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/solaris.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/unix.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/vms.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/os/windows.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/other/endian.h>", public ] },
+ { include: ["<boost/predef/detail/test.h>", private, "<boost/predef/platform/mingw.h>", public ] },
+ { include: ["<boost/preprocessor/arithmetic/detail/div_base.hpp>", private, "<boost/preprocessor/arithmetic/div.hpp>", public ] },
+ { include: ["<boost/preprocessor/arithmetic/detail/div_base.hpp>", private, "<boost/preprocessor/arithmetic/mod.hpp>", public ] },
+ { include: ["<boost/preprocessor/control/detail/dmc/while.hpp>", private, "<boost/preprocessor/control/while.hpp>", public ] },
+ { include: ["<boost/preprocessor/control/detail/edg/while.hpp>", private, "<boost/preprocessor/control/while.hpp>", public ] },
+ { include: ["<boost/preprocessor/control/detail/msvc/while.hpp>", private, "<boost/preprocessor/control/while.hpp>", public ] },
+ { include: ["<boost/preprocessor/control/detail/while.hpp>", private, "<boost/preprocessor/control/while.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/auto_rec.hpp>", private, "<boost/preprocessor/control/deduce_d.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/auto_rec.hpp>", private, "<boost/preprocessor/control/while.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/auto_rec.hpp>", private, "<boost/preprocessor/list/fold_left.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/auto_rec.hpp>", private, "<boost/preprocessor/list/fold_right.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/auto_rec.hpp>", private, "<boost/preprocessor/repetition/deduce_r.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/auto_rec.hpp>", private, "<boost/preprocessor/repetition/deduce_z.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/auto_rec.hpp>", private, "<boost/preprocessor/repetition/enum.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/auto_rec.hpp>", private, "<boost/preprocessor/repetition/enum_shifted.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/auto_rec.hpp>", private, "<boost/preprocessor/repetition/enum_trailing.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/auto_rec.hpp>", private, "<boost/preprocessor/repetition/for.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/auto_rec.hpp>", private, "<boost/preprocessor/repetition/repeat_from_to.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/auto_rec.hpp>", private, "<boost/preprocessor/repetition/repeat.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/auto_rec.hpp>", private, "<boost/preprocessor/seq/fold_left.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/auto_rec.hpp>", private, "<boost/preprocessor/seq/fold_right.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/is_binary.hpp>", private, "<boost/parameter/name.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/is_binary.hpp>", private, "<boost/preprocessor/list/adt.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/is_binary.hpp>", private, "<boost/spirit/home/classic/utility/rule_parser.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/is_binary.hpp>", private, "<boost/tti/has_template.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/is_nullary.hpp>", private, "<boost/parameter/preprocessor.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/is_unary.hpp>", private, "<boost/local_function/aux_/preprocessor/traits/decl_sign_/any_bind_type.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/is_unary.hpp>", private, "<boost/preprocessor/facilities/apply.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/is_unary.hpp>", private, "<boost/spirit/home/classic/utility/rule_parser.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/is_unary.hpp>", private, "<boost/typeof/template_encoding.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/split.hpp>", private, "<boost/parameter/aux_/preprocessor/for_each.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/split.hpp>", private, "<boost/parameter/name.hpp>", public ] },
+ { include: ["<boost/preprocessor/detail/split.hpp>", private, "<boost/preprocessor/facilities/is_empty.hpp>", public ] },
+ { include: ["<boost/preprocessor/list/detail/dmc/fold_left.hpp>", private, "<boost/preprocessor/list/fold_left.hpp>", public ] },
+ { include: ["<boost/preprocessor/list/detail/edg/fold_left.hpp>", private, "<boost/preprocessor/list/fold_left.hpp>", public ] },
+ { include: ["<boost/preprocessor/list/detail/edg/fold_right.hpp>", private, "<boost/preprocessor/list/fold_right.hpp>", public ] },
+ { include: ["<boost/preprocessor/list/detail/fold_left.hpp>", private, "<boost/preprocessor/list/fold_left.hpp>", public ] },
+ { include: ["<boost/preprocessor/list/detail/fold_right.hpp>", private, "<boost/preprocessor/list/fold_right.hpp>", public ] },
+ { include: ["<boost/preprocessor/repetition/detail/dmc/for.hpp>", private, "<boost/preprocessor/repetition/for.hpp>", public ] },
+ { include: ["<boost/preprocessor/repetition/detail/edg/for.hpp>", private, "<boost/preprocessor/repetition/for.hpp>", public ] },
+ { include: ["<boost/preprocessor/repetition/detail/for.hpp>", private, "<boost/preprocessor/repetition/for.hpp>", public ] },
+ { include: ["<boost/preprocessor/repetition/detail/msvc/for.hpp>", private, "<boost/preprocessor/repetition/for.hpp>", public ] },
+ { include: ["<boost/preprocessor/seq/detail/binary_transform.hpp>", private, "<boost/preprocessor/seq/to_list.hpp>", public ] },
+ { include: ["<boost/preprocessor/seq/detail/split.hpp>", private, "<boost/preprocessor/seq/first_n.hpp>", public ] },
+ { include: ["<boost/preprocessor/seq/detail/split.hpp>", private, "<boost/preprocessor/seq/rest_n.hpp>", public ] },
+ { include: ["<boost/preprocessor/slot/detail/def.hpp>", private, "<boost/preprocessor/slot/counter.hpp>", public ] },
+ { include: ["<boost/preprocessor/slot/detail/def.hpp>", private, "<boost/preprocessor/slot/slot.hpp>", public ] },
+ { include: ["<boost/program_options/detail/cmdline.hpp>", private, "<boost/program_options/parsers.hpp>", public ] },
+ { include: ["<boost/program_options/detail/parsers.hpp>", private, "<boost/program_options/parsers.hpp>", public ] },
+ { include: ["<boost/program_options/detail/value_semantic.hpp>", private, "<boost/program_options/value_semantic.hpp>", public ] },
+ { include: ["<boost/property_map/parallel/impl/distributed_property_map.ipp>", private, "<boost/property_map/parallel/distributed_property_map.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/exception_implementation.hpp>", private, "<boost/property_tree/exceptions.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/file_parser_error.hpp>", private, "<boost/property_tree/ini_parser.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/info_parser_error.hpp>", private, "<boost/property_tree/info_parser.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/info_parser_read.hpp>", private, "<boost/property_tree/info_parser.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/info_parser_write.hpp>", private, "<boost/property_tree/info_parser.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/info_parser_writer_settings.hpp>", private, "<boost/property_tree/info_parser.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/json_parser_error.hpp>", private, "<boost/property_tree/json_parser.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/json_parser_read.hpp>", private, "<boost/property_tree/json_parser.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/json_parser_write.hpp>", private, "<boost/property_tree/json_parser.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/ptree_implementation.hpp>", private, "<boost/property_tree/ptree.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/ptree_utils.hpp>", private, "<boost/property_tree/ini_parser.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/ptree_utils.hpp>", private, "<boost/property_tree/ptree.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/ptree_utils.hpp>", private, "<boost/property_tree/string_path.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/xml_parser_error.hpp>", private, "<boost/property_tree/xml_parser.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/xml_parser_flags.hpp>", private, "<boost/property_tree/xml_parser.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/xml_parser_read_rapidxml.hpp>", private, "<boost/property_tree/xml_parser.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/xml_parser_utils.hpp>", private, "<boost/graph/graphml.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/xml_parser_write.hpp>", private, "<boost/property_tree/xml_parser.hpp>", public ] },
+ { include: ["<boost/property_tree/detail/xml_parser_writer_settings.hpp>", private, "<boost/property_tree/xml_parser.hpp>", public ] },
+ { include: ["<boost/proto/context/detail/callable_eval.hpp>", private, "<boost/proto/context/callable.hpp>", public ] },
+ { include: ["<boost/proto/context/detail/default_eval.hpp>", private, "<boost/proto/context/default.hpp>", public ] },
+ { include: ["<boost/proto/context/detail/null_eval.hpp>", private, "<boost/proto/context/null.hpp>", public ] },
+ { include: ["<boost/proto/detail/and_n.hpp>", private, "<boost/proto/matches.hpp>", public ] },
+ { include: ["<boost/proto/detail/any.hpp>", private, "<boost/proto/transform/impl.hpp>", public ] },
+ { include: ["<boost/proto/detail/args.hpp>", private, "<boost/proto/args.hpp>", public ] },
+ { include: ["<boost/proto/detail/as_expr.hpp>", private, "<boost/proto/domain.hpp>", public ] },
+ { include: ["<boost/proto/detail/as_lvalue.hpp>", private, "<boost/proto/transform/call.hpp>", public ] },
+ { include: ["<boost/proto/detail/as_lvalue.hpp>", private, "<boost/proto/transform/make.hpp>", public ] },
+ { include: ["<boost/proto/detail/basic_expr.hpp>", private, "<boost/proto/expr.hpp>", public ] },
+ { include: ["<boost/proto/detail/decltype.hpp>", private, "<boost/proto/context/default.hpp>", public ] },
+ { include: ["<boost/proto/detail/decltype.hpp>", private, "<boost/proto/transform/default.hpp>", public ] },
+ { include: ["<boost/proto/detail/deduce_domain.hpp>", private, "<boost/proto/domain.hpp>", public ] },
+ { include: ["<boost/proto/detail/deep_copy.hpp>", private, "<boost/proto/deep_copy.hpp>", public ] },
+ { include: ["<boost/proto/detail/deprecated.hpp>", private, "<boost/proto/make_expr.hpp>", public ] },
+ { include: ["<boost/proto/detail/expr.hpp>", private, "<boost/proto/expr.hpp>", public ] },
+ { include: ["<boost/proto/detail/extends_funop_const.hpp>", private, "<boost/proto/extends.hpp>", public ] },
+ { include: ["<boost/proto/detail/extends_funop.hpp>", private, "<boost/proto/extends.hpp>", public ] },
+ { include: ["<boost/proto/detail/funop.hpp>", private, "<boost/proto/expr.hpp>", public ] },
+ { include: ["<boost/proto/detail/generate_by_value.hpp>", private, "<boost/proto/generate.hpp>", public ] },
+ { include: ["<boost/proto/detail/ignore_unused.hpp>", private, "<boost/proto/transform/make.hpp>", public ] },
+ { include: ["<boost/proto/detail/ignore_unused.hpp>", private, "<boost/proto/transform/pass_through.hpp>", public ] },
+ { include: ["<boost/proto/detail/is_noncopyable.hpp>", private, "<boost/proto/args.hpp>", public ] },
+ { include: ["<boost/proto/detail/is_noncopyable.hpp>", private, "<boost/proto/transform/env.hpp>", public ] },
+ { include: ["<boost/proto/detail/lambda_matches.hpp>", private, "<boost/proto/matches.hpp>", public ] },
+ { include: ["<boost/proto/detail/make_expr_funop.hpp>", private, "<boost/proto/make_expr.hpp>", public ] },
+ { include: ["<boost/proto/detail/make_expr_.hpp>", private, "<boost/proto/make_expr.hpp>", public ] },
+ { include: ["<boost/proto/detail/make_expr.hpp>", private, "<boost/proto/make_expr.hpp>", public ] },
+ { include: ["<boost/proto/detail/matches_.hpp>", private, "<boost/proto/matches.hpp>", public ] },
+ { include: ["<boost/proto/detail/or_n.hpp>", private, "<boost/proto/matches.hpp>", public ] },
+ { include: ["<boost/proto/detail/poly_function.hpp>", private, "<boost/proto/make_expr.hpp>", public ] },
+ { include: ["<boost/proto/detail/poly_function.hpp>", private, "<boost/proto/transform/call.hpp>", public ] },
+ { include: ["<boost/proto/detail/poly_function.hpp>", private, "<boost/proto/transform/env.hpp>", public ] },
+ { include: ["<boost/proto/detail/remove_typename.hpp>", private, "<boost/proto/extends.hpp>", public ] },
+ { include: ["<boost/proto/detail/static_const.hpp>", private, "<boost/proto/transform/impl.hpp>", public ] },
+ { include: ["<boost/proto/detail/template_arity.hpp>", private, "<boost/proto/matches.hpp>", public ] },
+ { include: ["<boost/proto/detail/template_arity.hpp>", private, "<boost/proto/traits.hpp>", public ] },
+ { include: ["<boost/proto/detail/template_arity.hpp>", private, "<boost/proto/transform/make.hpp>", public ] },
+ { include: ["<boost/proto/detail/traits.hpp>", private, "<boost/proto/traits.hpp>", public ] },
+ { include: ["<boost/proto/detail/unpack_expr_.hpp>", private, "<boost/proto/make_expr.hpp>", public ] },
+ { include: ["<boost/proto/detail/vararg_matches_impl.hpp>", private, "<boost/proto/matches.hpp>", public ] },
+ { include: ["<boost/proto/transform/detail/call.hpp>", private, "<boost/proto/transform/call.hpp>", public ] },
+ { include: ["<boost/proto/transform/detail/construct_funop.hpp>", private, "<boost/proto/transform/make.hpp>", public ] },
+ { include: ["<boost/proto/transform/detail/construct_pod_funop.hpp>", private, "<boost/proto/transform/make.hpp>", public ] },
+ { include: ["<boost/proto/transform/detail/default_function_impl.hpp>", private, "<boost/proto/transform/default.hpp>", public ] },
+ { include: ["<boost/proto/transform/detail/fold_impl.hpp>", private, "<boost/proto/transform/fold.hpp>", public ] },
+ { include: ["<boost/proto/transform/detail/lazy.hpp>", private, "<boost/proto/transform/lazy.hpp>", public ] },
+ { include: ["<boost/proto/transform/detail/make.hpp>", private, "<boost/proto/transform/make.hpp>", public ] },
+ { include: ["<boost/proto/transform/detail/pack.hpp>", private, "<boost/proto/transform/call.hpp>", public ] },
+ { include: ["<boost/proto/transform/detail/pack.hpp>", private, "<boost/proto/transform/lazy.hpp>", public ] },
+ { include: ["<boost/proto/transform/detail/pack.hpp>", private, "<boost/proto/transform/make.hpp>", public ] },
+ { include: ["<boost/proto/transform/detail/pass_through_impl.hpp>", private, "<boost/proto/transform/pass_through.hpp>", public ] },
+ { include: ["<boost/proto/transform/detail/when.hpp>", private, "<boost/proto/transform/when.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/associative_ptr_container.hpp>", private, "<boost/ptr_container/ptr_map_adapter.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/associative_ptr_container.hpp>", private, "<boost/ptr_container/ptr_set_adapter.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/map_iterator.hpp>", private, "<boost/ptr_container/ptr_map_adapter.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/meta_functions.hpp>", private, "<boost/ptr_container/ptr_map_adapter.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/meta_functions.hpp>", private, "<boost/ptr_container/ptr_set_adapter.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/reversible_ptr_container.hpp>", private, "<boost/ptr_container/ptr_sequence_adapter.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/serialize_ptr_map_adapter.hpp>", private, "<boost/ptr_container/serialize_ptr_map.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/serialize_ptr_map_adapter.hpp>", private, "<boost/ptr_container/serialize_ptr_unordered_map.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/serialize_reversible_cont.hpp>", private, "<boost/ptr_container/serialize_ptr_array.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/serialize_reversible_cont.hpp>", private, "<boost/ptr_container/serialize_ptr_circular_buffer.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/serialize_reversible_cont.hpp>", private, "<boost/ptr_container/serialize_ptr_deque.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/serialize_reversible_cont.hpp>", private, "<boost/ptr_container/serialize_ptr_list.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/serialize_reversible_cont.hpp>", private, "<boost/ptr_container/serialize_ptr_set.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/serialize_reversible_cont.hpp>", private, "<boost/ptr_container/serialize_ptr_unordered_set.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/serialize_reversible_cont.hpp>", private, "<boost/ptr_container/serialize_ptr_vector.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/void_ptr_iterator.hpp>", private, "<boost/ptr_container/ptr_sequence_adapter.hpp>", public ] },
+ { include: ["<boost/ptr_container/detail/void_ptr_iterator.hpp>", private, "<boost/ptr_container/ptr_set_adapter.hpp>", public ] },
+ { include: ["<boost/python/detail/borrowed_ptr.hpp>", private, "<boost/python/borrowed.hpp>", public ] },
+ { include: ["<boost/python/detail/caller.hpp>", private, "<boost/python/make_constructor.hpp>", public ] },
+ { include: ["<boost/python/detail/caller.hpp>", private, "<boost/python/make_function.hpp>", public ] },
+ { include: ["<boost/python/detail/caller.hpp>", private, "<boost/python/object/function_handle.hpp>", public ] },
+ { include: ["<boost/python/detail/construct.hpp>", private, "<boost/python/converter/obj_mgr_arg_from_python.hpp>", public ] },
+ { include: ["<boost/python/detail/convertible.hpp>", private, "<boost/python/cast.hpp>", public ] },
+ { include: ["<boost/python/detail/convertible.hpp>", private, "<boost/python/converter/arg_to_python.hpp>", public ] },
+ { include: ["<boost/python/detail/copy_ctor_mutates_rhs.hpp>", private, "<boost/python/extract.hpp>", public ] },
+ { include: ["<boost/python/detail/copy_ctor_mutates_rhs.hpp>", private, "<boost/python/object/forward.hpp>", public ] },
+ { include: ["<boost/python/detail/dealloc.hpp>", private, "<boost/python/opaque_pointer_converter.hpp>", public ] },
+ { include: ["<boost/python/detail/decref_guard.hpp>", private, "<boost/python/object/make_instance.hpp>", public ] },
+ { include: ["<boost/python/detail/defaults_def.hpp>", private, "<boost/python/overloads.hpp>", public ] },
+ { include: ["<boost/python/detail/def_helper_fwd.hpp>", private, "<boost/python/object_core.hpp>", public ] },
+ { include: ["<boost/python/detail/def_helper.hpp>", private, "<boost/python/class.hpp>", public ] },
+ { include: ["<boost/python/detail/def_helper.hpp>", private, "<boost/python/def.hpp>", public ] },
+ { include: ["<boost/python/detail/dependent.hpp>", private, "<boost/python/back_reference.hpp>", public ] },
+ { include: ["<boost/python/detail/dependent.hpp>", private, "<boost/python/object_core.hpp>", public ] },
+ { include: ["<boost/python/detail/destroy.hpp>", private, "<boost/python/converter/obj_mgr_arg_from_python.hpp>", public ] },
+ { include: ["<boost/python/detail/destroy.hpp>", private, "<boost/python/converter/rvalue_from_python_data.hpp>", public ] },
+ { include: ["<boost/python/detail/exception_handler.hpp>", private, "<boost/python/exception_translator.hpp>", public ] },
+ { include: ["<boost/python/detail/force_instantiate.hpp>", private, "<boost/python/class.hpp>", public ] },
+ { include: ["<boost/python/detail/force_instantiate.hpp>", private, "<boost/python/object/class_metadata.hpp>", public ] },
+ { include: ["<boost/python/detail/force_instantiate.hpp>", private, "<boost/python/object/pointer_holder.hpp>", public ] },
+ { include: ["<boost/python/detail/force_instantiate.hpp>", private, "<boost/python/object/value_holder.hpp>", public ] },
+ { include: ["<boost/python/detail/force_instantiate.hpp>", private, "<boost/python/return_opaque_pointer.hpp>", public ] },
+ { include: ["<boost/python/detail/indirect_traits.hpp>", private, "<boost/python/converter/arg_from_python.hpp>", public ] },
+ { include: ["<boost/python/detail/indirect_traits.hpp>", private, "<boost/python/converter/arg_to_python.hpp>", public ] },
+ { include: ["<boost/python/detail/indirect_traits.hpp>", private, "<boost/python/converter/object_manager.hpp>", public ] },
+ { include: ["<boost/python/detail/indirect_traits.hpp>", private, "<boost/python/copy_const_reference.hpp>", public ] },
+ { include: ["<boost/python/detail/indirect_traits.hpp>", private, "<boost/python/copy_non_const_reference.hpp>", public ] },
+ { include: ["<boost/python/detail/indirect_traits.hpp>", private, "<boost/python/data_members.hpp>", public ] },
+ { include: ["<boost/python/detail/indirect_traits.hpp>", private, "<boost/python/manage_new_object.hpp>", public ] },
+ { include: ["<boost/python/detail/indirect_traits.hpp>", private, "<boost/python/reference_existing_object.hpp>", public ] },
+ { include: ["<boost/python/detail/is_xxx.hpp>", private, "<boost/python/object_core.hpp>", public ] },
+ { include: ["<boost/python/detail/make_keyword_range_fn.hpp>", private, "<boost/python/init.hpp>", public ] },
+ { include: ["<boost/python/detail/mpl_lambda.hpp>", private, "<boost/python/args.hpp>", public ] },
+ { include: ["<boost/python/detail/msvc_typeinfo.hpp>", private, "<boost/python/type_id.hpp>", public ] },
+ { include: ["<boost/python/detail/none.hpp>", private, "<boost/python/converter/builtin_converters.hpp>", public ] },
+ { include: ["<boost/python/detail/none.hpp>", private, "<boost/python/make_constructor.hpp>", public ] },
+ { include: ["<boost/python/detail/none.hpp>", private, "<boost/python/object/make_instance.hpp>", public ] },
+ { include: ["<boost/python/detail/none.hpp>", private, "<boost/python/opaque_pointer_converter.hpp>", public ] },
+ { include: ["<boost/python/detail/none.hpp>", private, "<boost/python/return_arg.hpp>", public ] },
+ { include: ["<boost/python/detail/none.hpp>", private, "<boost/python/to_python_indirect.hpp>", public ] },
+ { include: ["<boost/python/detail/not_specified.hpp>", private, "<boost/python/class_fwd.hpp>", public ] },
+ { include: ["<boost/python/detail/not_specified.hpp>", private, "<boost/python/data_members.hpp>", public ] },
+ { include: ["<boost/python/detail/not_specified.hpp>", private, "<boost/python/object/class_metadata.hpp>", public ] },
+ { include: ["<boost/python/detail/not_specified.hpp>", private, "<boost/python/operators.hpp>", public ] },
+ { include: ["<boost/python/detail/nullary_function_adaptor.hpp>", private, "<boost/python/pure_virtual.hpp>", public ] },
+ { include: ["<boost/python/detail/operator_id.hpp>", private, "<boost/python/class.hpp>", public ] },
+ { include: ["<boost/python/detail/operator_id.hpp>", private, "<boost/python/operators.hpp>", public ] },
+ { include: ["<boost/python/detail/overloads_fwd.hpp>", private, "<boost/python/class.hpp>", public ] },
+ { include: ["<boost/python/detail/overloads_fwd.hpp>", private, "<boost/python/def.hpp>", public ] },
+ { include: ["<boost/python/detail/overloads_fwd.hpp>", private, "<boost/python/overloads.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/arg_from_python.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/args_fwd.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/args.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/back_reference.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/bases.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/base_type_traits.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/borrowed.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/call.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/call_method.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/cast.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/class_fwd.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/class.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/converter/arg_from_python.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/converter/builtin_converters.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/converter/from_python.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/converter/obj_mgr_arg_from_python.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/converter/pyobject_traits.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/converter/pytype_function.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/converter/pytype_object_mgr_traits.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/converter/registrations.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/converter/to_python_function_type.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/copy_const_reference.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/copy_non_const_reference.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/data_members.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/default_call_policies.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/def.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/def_visitor.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/dict.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/enum.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/errors.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/exception_translator.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/extract.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/handle_fwd.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/handle.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/has_back_reference.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/implicit.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/init.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/instance_holder.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/iterator.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/list.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/long.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/lvalue_from_pytype.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/make_constructor.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/make_function.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/manage_new_object.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/module.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/module_init.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/numeric.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object_attributes.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object/class.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object_core.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object/function.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object/function_object.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object_fwd.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object/instance.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object_items.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object/iterator.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object/life_support.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object/make_holder.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object/make_instance.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object_operators.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object/pickle_support.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object_protocol_core.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object_protocol.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/object_slices.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/opaque_pointer_converter.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/operators.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/other.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/overloads.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/override.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/pointee.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/proxy.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/ptr.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/raw_function.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/refcount.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/reference_existing_object.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/return_by_value.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/return_internal_reference.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/return_opaque_pointer.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/return_value_policy.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/scope.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/self.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/signature.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/slice.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/slice_nil.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/ssize_t.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/stl_iterator.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/str.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/tag.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/to_python_converter.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/to_python_indirect.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/to_python_value.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/tuple.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/type_id.hpp>", public ] },
+ { include: ["<boost/python/detail/prefix.hpp>", private, "<boost/python/with_custodian_and_ward.hpp>", public ] },
+ { include: ["<boost/python/detail/preprocessor.hpp>", private, "<boost/python/args.hpp>", public ] },
+ { include: ["<boost/python/detail/preprocessor.hpp>", private, "<boost/python/call.hpp>", public ] },
+ { include: ["<boost/python/detail/preprocessor.hpp>", private, "<boost/python/call_method.hpp>", public ] },
+ { include: ["<boost/python/detail/preprocessor.hpp>", private, "<boost/python/object_core.hpp>", public ] },
+ { include: ["<boost/python/detail/preprocessor.hpp>", private, "<boost/python/object/make_holder.hpp>", public ] },
+ { include: ["<boost/python/detail/preprocessor.hpp>", private, "<boost/python/object/pointer_holder.hpp>", public ] },
+ { include: ["<boost/python/detail/preprocessor.hpp>", private, "<boost/python/object/value_holder.hpp>", public ] },
+ { include: ["<boost/python/detail/preprocessor.hpp>", private, "<boost/python/signature.hpp>", public ] },
+ { include: ["<boost/python/detail/python_type.hpp>", private, "<boost/python/object/make_holder.hpp>", public ] },
+ { include: ["<boost/python/detail/raw_pyobject.hpp>", private, "<boost/python/back_reference.hpp>", public ] },
+ { include: ["<boost/python/detail/raw_pyobject.hpp>", private, "<boost/python/converter/obj_mgr_arg_from_python.hpp>", public ] },
+ { include: ["<boost/python/detail/raw_pyobject.hpp>", private, "<boost/python/converter/pytype_object_mgr_traits.hpp>", public ] },
+ { include: ["<boost/python/detail/raw_pyobject.hpp>", private, "<boost/python/handle.hpp>", public ] },
+ { include: ["<boost/python/detail/raw_pyobject.hpp>", private, "<boost/python/object_core.hpp>", public ] },
+ { include: ["<boost/python/detail/raw_pyobject.hpp>", private, "<boost/python/object/iterator.hpp>", public ] },
+ { include: ["<boost/python/detail/referent_storage.hpp>", private, "<boost/python/converter/arg_from_python.hpp>", public ] },
+ { include: ["<boost/python/detail/referent_storage.hpp>", private, "<boost/python/converter/obj_mgr_arg_from_python.hpp>", public ] },
+ { include: ["<boost/python/detail/referent_storage.hpp>", private, "<boost/python/converter/rvalue_from_python_data.hpp>", public ] },
+ { include: ["<boost/python/detail/scope.hpp>", private, "<boost/python/def.hpp>", public ] },
+ { include: ["<boost/python/detail/sfinae.hpp>", private, "<boost/python/wrapper.hpp>", public ] },
+ { include: ["<boost/python/detail/signature.hpp>", private, "<boost/python/object/function_doc_signature.hpp>", public ] },
+ { include: ["<boost/python/detail/signature.hpp>", private, "<boost/python/object/py_function.hpp>", public ] },
+ { include: ["<boost/python/detail/string_literal.hpp>", private, "<boost/python/converter/arg_to_python.hpp>", public ] },
+ { include: ["<boost/python/detail/string_literal.hpp>", private, "<boost/python/object_core.hpp>", public ] },
+ { include: ["<boost/python/detail/target.hpp>", private, "<boost/python/iterator.hpp>", public ] },
+ { include: ["<boost/python/detail/translate_exception.hpp>", private, "<boost/python/exception_translator.hpp>", public ] },
+ { include: ["<boost/python/detail/type_list.hpp>", private, "<boost/python/args.hpp>", public ] },
+ { include: ["<boost/python/detail/type_list.hpp>", private, "<boost/python/bases.hpp>", public ] },
+ { include: ["<boost/python/detail/type_list.hpp>", private, "<boost/python/init.hpp>", public ] },
+ { include: ["<boost/python/detail/type_list.hpp>", private, "<boost/python/signature.hpp>", public ] },
+ { include: ["<boost/python/detail/unwind_type.hpp>", private, "<boost/python/converter/pytype_function.hpp>", public ] },
+ { include: ["<boost/python/detail/unwrap_type_id.hpp>", private, "<boost/python/class.hpp>", public ] },
+ { include: ["<boost/python/detail/unwrap_wrapper.hpp>", private, "<boost/python/class.hpp>", public ] },
+ { include: ["<boost/python/detail/unwrap_wrapper.hpp>", private, "<boost/python/operators.hpp>", public ] },
+ { include: ["<boost/python/detail/value_arg.hpp>", private, "<boost/python/data_members.hpp>", public ] },
+ { include: ["<boost/python/detail/value_arg.hpp>", private, "<boost/python/default_call_policies.hpp>", public ] },
+ { include: ["<boost/python/detail/value_arg.hpp>", private, "<boost/python/object/forward.hpp>", public ] },
+ { include: ["<boost/python/detail/value_arg.hpp>", private, "<boost/python/return_arg.hpp>", public ] },
+ { include: ["<boost/python/detail/value_arg.hpp>", private, "<boost/python/return_by_value.hpp>", public ] },
+ { include: ["<boost/python/detail/value_arg.hpp>", private, "<boost/python/return_opaque_pointer.hpp>", public ] },
+ { include: ["<boost/python/detail/value_arg.hpp>", private, "<boost/python/to_python_value.hpp>", public ] },
+ { include: ["<boost/python/detail/value_is_shared_ptr.hpp>", private, "<boost/python/converter/arg_to_python.hpp>", public ] },
+ { include: ["<boost/python/detail/value_is_shared_ptr.hpp>", private, "<boost/python/to_python_value.hpp>", public ] },
+ { include: ["<boost/python/detail/void_ptr.hpp>", private, "<boost/python/converter/arg_from_python.hpp>", public ] },
+ { include: ["<boost/python/detail/void_ptr.hpp>", private, "<boost/python/converter/return_from_python.hpp>", public ] },
+ { include: ["<boost/python/detail/void_ptr.hpp>", private, "<boost/python/extract.hpp>", public ] },
+ { include: ["<boost/python/detail/void_ptr.hpp>", private, "<boost/python/lvalue_from_pytype.hpp>", public ] },
+ { include: ["<boost/python/detail/void_return.hpp>", private, "<boost/python/call.hpp>", public ] },
+ { include: ["<boost/python/detail/void_return.hpp>", private, "<boost/python/call_method.hpp>", public ] },
+ { include: ["<boost/python/detail/void_return.hpp>", private, "<boost/python/converter/return_from_python.hpp>", public ] },
+ { include: ["<boost/python/detail/void_return.hpp>", private, "<boost/python/extract.hpp>", public ] },
+ { include: ["<boost/python/detail/wrapper_base.hpp>", private, "<boost/python/object/pointer_holder.hpp>", public ] },
+ { include: ["<boost/python/detail/wrapper_base.hpp>", private, "<boost/python/wrapper.hpp>", public ] },
+ { include: ["<boost/python/suite/indexing/detail/indexing_suite_detail.hpp>", private, "<boost/python/suite/indexing/indexing_suite.hpp>", public ] },
+ { include: ["<boost/random/detail/auto_link.hpp>", private, "<boost/random/random_device.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/additive_combine.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/bernoulli_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/binomial_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/cauchy_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/chi_squared_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/discard_block.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/discrete_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/exponential_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/gamma_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/geometric_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/independent_bits.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/inversive_congruential.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/lagged_fibonacci.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/linear_congruential.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/linear_feedback_shift.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/lognormal_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/mersenne_twister.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/negative_binomial_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/normal_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/piecewise_constant_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/piecewise_linear_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/poisson_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/subtract_with_carry.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/triangle_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/uniform_01.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/uniform_int_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/uniform_on_sphere.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/uniform_real_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/uniform_smallint.hpp>", public ] },
+ { include: ["<boost/random/detail/config.hpp>", private, "<boost/random/xor_combine.hpp>", public ] },
+ { include: ["<boost/random/detail/const_mod.hpp>", private, "<boost/random/inversive_congruential.hpp>", public ] },
+ { include: ["<boost/random/detail/const_mod.hpp>", private, "<boost/random/linear_congruential.hpp>", public ] },
+ { include: ["<boost/random/detail/disable_warnings.hpp>", private, "<boost/random/binomial_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/disable_warnings.hpp>", private, "<boost/random/discrete_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/disable_warnings.hpp>", private, "<boost/random/inversive_congruential.hpp>", public ] },
+ { include: ["<boost/random/detail/disable_warnings.hpp>", private, "<boost/random/linear_congruential.hpp>", public ] },
+ { include: ["<boost/random/detail/disable_warnings.hpp>", private, "<boost/random/poisson_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/disable_warnings.hpp>", private, "<boost/random/random_number_generator.hpp>", public ] },
+ { include: ["<boost/random/detail/disable_warnings.hpp>", private, "<boost/random/shuffle_order.hpp>", public ] },
+ { include: ["<boost/random/detail/disable_warnings.hpp>", private, "<boost/random/uniform_01.hpp>", public ] },
+ { include: ["<boost/random/detail/disable_warnings.hpp>", private, "<boost/random/variate_generator.hpp>", public ] },
+ { include: ["<boost/random/detail/enable_warnings.hpp>", private, "<boost/random/binomial_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/enable_warnings.hpp>", private, "<boost/random/discrete_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/enable_warnings.hpp>", private, "<boost/random/inversive_congruential.hpp>", public ] },
+ { include: ["<boost/random/detail/enable_warnings.hpp>", private, "<boost/random/linear_congruential.hpp>", public ] },
+ { include: ["<boost/random/detail/enable_warnings.hpp>", private, "<boost/random/poisson_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/enable_warnings.hpp>", private, "<boost/random/random_number_generator.hpp>", public ] },
+ { include: ["<boost/random/detail/enable_warnings.hpp>", private, "<boost/random/shuffle_order.hpp>", public ] },
+ { include: ["<boost/random/detail/enable_warnings.hpp>", private, "<boost/random/uniform_01.hpp>", public ] },
+ { include: ["<boost/random/detail/enable_warnings.hpp>", private, "<boost/random/variate_generator.hpp>", public ] },
+ { include: ["<boost/random/detail/generator_bits.hpp>", private, "<boost/random/generate_canonical.hpp>", public ] },
+ { include: ["<boost/random/detail/generator_seed_seq.hpp>", private, "<boost/random/lagged_fibonacci.hpp>", public ] },
+ { include: ["<boost/random/detail/generator_seed_seq.hpp>", private, "<boost/random/mersenne_twister.hpp>", public ] },
+ { include: ["<boost/random/detail/generator_seed_seq.hpp>", private, "<boost/random/subtract_with_carry.hpp>", public ] },
+ { include: ["<boost/random/detail/integer_log2.hpp>", private, "<boost/random/independent_bits.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/additive_combine.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/bernoulli_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/cauchy_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/discrete_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/exponential_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/extreme_value_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/fisher_f_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/geometric_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/independent_bits.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/inversive_congruential.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/lagged_fibonacci.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/linear_feedback_shift.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/lognormal_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/normal_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/piecewise_constant_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/piecewise_linear_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/shuffle_order.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/student_t_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/subtract_with_carry.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/triangle_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/uniform_int_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/uniform_on_sphere.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/uniform_real_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/uniform_smallint.hpp>", public ] },
+ { include: ["<boost/random/detail/operators.hpp>", private, "<boost/random/weibull_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/ptr_helper.hpp>", private, "<boost/random/mersenne_twister.hpp>", public ] },
+ { include: ["<boost/random/detail/ptr_helper.hpp>", private, "<boost/random/uniform_01.hpp>", public ] },
+ { include: ["<boost/random/detail/ptr_helper.hpp>", private, "<boost/random/variate_generator.hpp>", public ] },
+ { include: ["<boost/random/detail/seed.hpp>", private, "<boost/random/additive_combine.hpp>", public ] },
+ { include: ["<boost/random/detail/seed.hpp>", private, "<boost/random/discard_block.hpp>", public ] },
+ { include: ["<boost/random/detail/seed.hpp>", private, "<boost/random/independent_bits.hpp>", public ] },
+ { include: ["<boost/random/detail/seed.hpp>", private, "<boost/random/inversive_congruential.hpp>", public ] },
+ { include: ["<boost/random/detail/seed.hpp>", private, "<boost/random/lagged_fibonacci.hpp>", public ] },
+ { include: ["<boost/random/detail/seed.hpp>", private, "<boost/random/linear_congruential.hpp>", public ] },
+ { include: ["<boost/random/detail/seed.hpp>", private, "<boost/random/linear_feedback_shift.hpp>", public ] },
+ { include: ["<boost/random/detail/seed.hpp>", private, "<boost/random/mersenne_twister.hpp>", public ] },
+ { include: ["<boost/random/detail/seed.hpp>", private, "<boost/random/shuffle_order.hpp>", public ] },
+ { include: ["<boost/random/detail/seed.hpp>", private, "<boost/random/subtract_with_carry.hpp>", public ] },
+ { include: ["<boost/random/detail/seed.hpp>", private, "<boost/random/xor_combine.hpp>", public ] },
+ { include: ["<boost/random/detail/seed_impl.hpp>", private, "<boost/random/independent_bits.hpp>", public ] },
+ { include: ["<boost/random/detail/seed_impl.hpp>", private, "<boost/random/inversive_congruential.hpp>", public ] },
+ { include: ["<boost/random/detail/seed_impl.hpp>", private, "<boost/random/linear_congruential.hpp>", public ] },
+ { include: ["<boost/random/detail/seed_impl.hpp>", private, "<boost/random/linear_feedback_shift.hpp>", public ] },
+ { include: ["<boost/random/detail/seed_impl.hpp>", private, "<boost/random/mersenne_twister.hpp>", public ] },
+ { include: ["<boost/random/detail/seed_impl.hpp>", private, "<boost/random/subtract_with_carry.hpp>", public ] },
+ { include: ["<boost/random/detail/seed_impl.hpp>", private, "<boost/random/xor_combine.hpp>", public ] },
+ { include: ["<boost/random/detail/signed_unsigned_tools.hpp>", private, "<boost/random/generate_canonical.hpp>", public ] },
+ { include: ["<boost/random/detail/signed_unsigned_tools.hpp>", private, "<boost/random/independent_bits.hpp>", public ] },
+ { include: ["<boost/random/detail/signed_unsigned_tools.hpp>", private, "<boost/random/shuffle_order.hpp>", public ] },
+ { include: ["<boost/random/detail/signed_unsigned_tools.hpp>", private, "<boost/random/uniform_int_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/signed_unsigned_tools.hpp>", private, "<boost/random/uniform_real_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/signed_unsigned_tools.hpp>", private, "<boost/random/uniform_smallint.hpp>", public ] },
+ { include: ["<boost/random/detail/uniform_int_float.hpp>", private, "<boost/random/uniform_int_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/vector_io.hpp>", private, "<boost/random/discrete_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/vector_io.hpp>", private, "<boost/random/piecewise_constant_distribution.hpp>", public ] },
+ { include: ["<boost/random/detail/vector_io.hpp>", private, "<boost/random/piecewise_linear_distribution.hpp>", public ] },
+ { include: ["<boost/range/detail/any_iterator.hpp>", private, "<boost/range/any_range.hpp>", public ] },
+ { include: ["<boost/range/detail/as_literal.hpp>", private, "<boost/range/as_literal.hpp>", public ] },
+ { include: ["<boost/range/detail/begin.hpp>", private, "<boost/range/begin.hpp>", public ] },
+ { include: ["<boost/range/detail/collection_traits.hpp>", private, "<boost/range.hpp>", public ] },
+ { include: ["<boost/range/detail/const_iterator.hpp>", private, "<boost/range/const_iterator.hpp>", public ] },
+ { include: ["<boost/range/detail/end.hpp>", private, "<boost/range/end.hpp>", public ] },
+ { include: ["<boost/range/detail/extract_optional_type.hpp>", private, "<boost/range/const_iterator.hpp>", public ] },
+ { include: ["<boost/range/detail/extract_optional_type.hpp>", private, "<boost/range/mutable_iterator.hpp>", public ] },
+ { include: ["<boost/range/detail/implementation_help.hpp>", private, "<boost/range/end.hpp>", public ] },
+ { include: ["<boost/range/detail/iterator.hpp>", private, "<boost/range/mutable_iterator.hpp>", public ] },
+ { include: ["<boost/range/detail/join_iterator.hpp>", private, "<boost/range/join.hpp>", public ] },
+ { include: ["<boost/range/detail/microsoft.hpp>", private, "<boost/range/atl.hpp>", public ] },
+ { include: ["<boost/range/detail/microsoft.hpp>", private, "<boost/range/mfc.hpp>", public ] },
+ { include: ["<boost/range/detail/misc_concept.hpp>", private, "<boost/range/concepts.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/adjacent_find.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/find_end.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/find_first_of.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/find.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/find_if.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/lower_bound.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/max_element.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/min_element.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/partition.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/remove.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/remove_if.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/reverse.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/search.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/search_n.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/stable_partition.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/unique.hpp>", public ] },
+ { include: ["<boost/range/detail/range_return.hpp>", private, "<boost/range/algorithm/upper_bound.hpp>", public ] },
+ { include: ["<boost/range/detail/safe_bool.hpp>", private, "<boost/range/iterator_range_core.hpp>", public ] },
+ { include: ["<boost/range/detail/size_type.hpp>", private, "<boost/range/size_type.hpp>", public ] },
+ { include: ["<boost/range/detail/str_types.hpp>", private, "<boost/range/as_array.hpp>", public ] },
+ { include: ["<boost/range/detail/str_types.hpp>", private, "<boost/range/as_literal.hpp>", public ] },
+ { include: ["<boost/ratio/detail/mpl/abs.hpp>", private, "<boost/ratio/mpl/abs.hpp>", public ] },
+ { include: ["<boost/ratio/detail/mpl/abs.hpp>", private, "<boost/ratio/ratio.hpp>", public ] },
+ { include: ["<boost/ratio/detail/mpl/gcd.hpp>", private, "<boost/ratio/mpl/gcd.hpp>", public ] },
+ { include: ["<boost/ratio/detail/mpl/gcd.hpp>", private, "<boost/ratio/ratio.hpp>", public ] },
+ { include: ["<boost/ratio/detail/mpl/lcm.hpp>", private, "<boost/ratio/mpl/lcm.hpp>", public ] },
+ { include: ["<boost/ratio/detail/mpl/lcm.hpp>", private, "<boost/ratio/ratio.hpp>", public ] },
+ { include: ["<boost/ratio/detail/mpl/sign.hpp>", private, "<boost/ratio/mpl/sign.hpp>", public ] },
+ { include: ["<boost/ratio/detail/mpl/sign.hpp>", private, "<boost/ratio/ratio.hpp>", public ] },
+ { include: ["<boost/ratio/detail/overflow_helpers.hpp>", private, "<boost/ratio/ratio.hpp>", public ] },
+ { include: ["<boost/ratio/detail/ratio_io.hpp>", private, "<boost/ratio/ratio_io.hpp>", public ] },
+ { include: ["<boost/serialization/detail/get_data.hpp>", private, "<boost/serialization/valarray.hpp>", public ] },
+ { include: ["<boost/serialization/detail/get_data.hpp>", private, "<boost/serialization/vector.hpp>", public ] },
+ { include: ["<boost/serialization/detail/shared_ptr_132.hpp>", private, "<boost/serialization/shared_ptr_132.hpp>", public ] },
+ { include: ["<boost/serialization/detail/stack_constructor.hpp>", private, "<boost/serialization/collections_load_imp.hpp>", public ] },
+ { include: ["<boost/serialization/detail/stack_constructor.hpp>", private, "<boost/serialization/optional.hpp>", public ] },
+ { include: ["<boost/signals2/detail/foreign_ptr.hpp>", private, "<boost/signals2/slot_base.hpp>", public ] },
+ { include: ["<boost/signals2/detail/lwm_nop.hpp>", private, "<boost/signals2/mutex.hpp>", public ] },
+ { include: ["<boost/signals2/detail/lwm_pthreads.hpp>", private, "<boost/signals2/mutex.hpp>", public ] },
+ { include: ["<boost/signals2/detail/lwm_win32_cs.hpp>", private, "<boost/signals2/mutex.hpp>", public ] },
+ { include: ["<boost/signals2/detail/null_output_iterator.hpp>", private, "<boost/signals2/connection.hpp>", public ] },
+ { include: ["<boost/signals2/detail/preprocessed_arg_type.hpp>", private, "<boost/signals2/preprocessed_signal.hpp>", public ] },
+ { include: ["<boost/signals2/detail/preprocessed_arg_type.hpp>", private, "<boost/signals2/preprocessed_slot.hpp>", public ] },
+ { include: ["<boost/signals2/detail/replace_slot_function.hpp>", private, "<boost/signals2/signal.hpp>", public ] },
+ { include: ["<boost/signals2/detail/result_type_wrapper.hpp>", private, "<boost/signals2/signal.hpp>", public ] },
+ { include: ["<boost/signals2/detail/signals_common.hpp>", private, "<boost/signals2/signal.hpp>", public ] },
+ { include: ["<boost/signals2/detail/signals_common.hpp>", private, "<boost/signals2/slot.hpp>", public ] },
+ { include: ["<boost/signals2/detail/signals_common_macros.hpp>", private, "<boost/signals2/signal.hpp>", public ] },
+ { include: ["<boost/signals2/detail/signals_common_macros.hpp>", private, "<boost/signals2/slot.hpp>", public ] },
+ { include: ["<boost/signals2/detail/signal_template.hpp>", private, "<boost/signals2/variadic_signal.hpp>", public ] },
+ { include: ["<boost/signals2/detail/slot_call_iterator.hpp>", private, "<boost/signals2/signal.hpp>", public ] },
+ { include: ["<boost/signals2/detail/slot_groups.hpp>", private, "<boost/signals2/signal.hpp>", public ] },
+ { include: ["<boost/signals2/detail/slot_template.hpp>", private, "<boost/signals2/variadic_slot.hpp>", public ] },
+ { include: ["<boost/signals2/detail/tracked_objects_visitor.hpp>", private, "<boost/signals2/slot.hpp>", public ] },
+ { include: ["<boost/signals2/detail/unique_lock.hpp>", private, "<boost/signals2/connection.hpp>", public ] },
+ { include: ["<boost/signals2/detail/unique_lock.hpp>", private, "<boost/signals2/signal.hpp>", public ] },
+ { include: ["<boost/signals2/detail/variadic_arg_type.hpp>", private, "<boost/signals2/variadic_signal.hpp>", public ] },
+ { include: ["<boost/signals2/detail/variadic_arg_type.hpp>", private, "<boost/signals2/variadic_slot.hpp>", public ] },
+ { include: ["<boost/signals2/detail/variadic_slot_invoker.hpp>", private, "<boost/signals2/variadic_signal.hpp>", public ] },
+ { include: ["<boost/signals/detail/signal_base.hpp>", private, "<boost/signals/signal_template.hpp>", public ] },
+ { include: ["<boost/signals/detail/signals_common.hpp>", private, "<boost/signals/connection.hpp>", public ] },
+ { include: ["<boost/signals/detail/signals_common.hpp>", private, "<boost/signals/slot.hpp>", public ] },
+ { include: ["<boost/signals/detail/slot_call_iterator.hpp>", private, "<boost/signals/signal_template.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/allocate_array_helper.hpp>", private, "<boost/smart_ptr/allocate_shared_array.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/array_deleter.hpp>", private, "<boost/smart_ptr/allocate_shared_array.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/array_deleter.hpp>", private, "<boost/smart_ptr/make_shared_array.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/array_traits.hpp>", private, "<boost/smart_ptr/allocate_shared_array.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/array_traits.hpp>", private, "<boost/smart_ptr/make_shared_array.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/atomic_count.hpp>", private, "<boost/smart_ptr/intrusive_ref_counter.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/make_array_helper.hpp>", private, "<boost/smart_ptr/make_shared_array.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/operator_bool.hpp>", private, "<boost/smart_ptr/intrusive_ptr.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/operator_bool.hpp>", private, "<boost/smart_ptr/scoped_array.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/operator_bool.hpp>", private, "<boost/smart_ptr/scoped_ptr.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/operator_bool.hpp>", private, "<boost/smart_ptr/shared_array.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/operator_bool.hpp>", private, "<boost/smart_ptr/shared_ptr.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/shared_array_nmt.hpp>", private, "<boost/smart_ptr/shared_array.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/shared_count.hpp>", private, "<boost/smart_ptr/shared_array.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/shared_count.hpp>", private, "<boost/smart_ptr/shared_ptr.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/shared_count.hpp>", private, "<boost/smart_ptr/weak_ptr.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/shared_ptr_nmt.hpp>", private, "<boost/smart_ptr/shared_ptr.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/sp_convertible.hpp>", private, "<boost/smart_ptr/intrusive_ptr.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/sp_convertible.hpp>", private, "<boost/smart_ptr/shared_ptr.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/sp_forward.hpp>", private, "<boost/smart_ptr/make_shared_object.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/sp_if_array.hpp>", private, "<boost/smart_ptr/allocate_shared_array.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/sp_if_array.hpp>", private, "<boost/smart_ptr/make_shared_array.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/spinlock_pool.hpp>", private, "<boost/smart_ptr/shared_ptr.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/sp_nullptr_t.hpp>", private, "<boost/smart_ptr/intrusive_ptr.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/sp_nullptr_t.hpp>", private, "<boost/smart_ptr/scoped_array.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/sp_nullptr_t.hpp>", private, "<boost/smart_ptr/scoped_ptr.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/sp_nullptr_t.hpp>", private, "<boost/smart_ptr/shared_array.hpp>", public ] },
+ { include: ["<boost/smart_ptr/detail/sp_nullptr_t.hpp>", private, "<boost/smart_ptr/shared_ptr.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/composite/impl/alternative.ipp>", private, "<boost/spirit/home/classic/core/composite/alternative.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/composite/impl/difference.ipp>", private, "<boost/spirit/home/classic/core/composite/difference.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/composite/impl/directives.ipp>", private, "<boost/spirit/home/classic/core/composite/directives.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/composite/impl/directives.ipp>", private, "<boost/spirit/home/classic/core/primitives/primitives.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/composite/impl/exclusive_or.ipp>", private, "<boost/spirit/home/classic/core/composite/exclusive_or.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/composite/impl/intersection.ipp>", private, "<boost/spirit/home/classic/core/composite/intersection.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/composite/impl/kleene_star.ipp>", private, "<boost/spirit/home/classic/core/composite/kleene_star.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/composite/impl/list.ipp>", private, "<boost/spirit/home/classic/core/composite/list.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/composite/impl/optional.ipp>", private, "<boost/spirit/home/classic/core/composite/optional.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/composite/impl/positive.ipp>", private, "<boost/spirit/home/classic/core/composite/positive.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/composite/impl/sequence.ipp>", private, "<boost/spirit/home/classic/core/composite/sequence.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/composite/impl/sequential_and.ipp>", private, "<boost/spirit/home/classic/core/composite/sequential_and.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/composite/impl/sequential_or.ipp>", private, "<boost/spirit/home/classic/core/composite/sequential_or.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/impl/match_attr_traits.ipp>", private, "<boost/spirit/home/classic/core/match.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/impl/match.ipp>", private, "<boost/spirit/home/classic/core/match.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/impl/parser.ipp>", private, "<boost/spirit/home/classic/core/parser.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/non_terminal/impl/grammar.ipp>", private, "<boost/spirit/home/classic/core/non_terminal/grammar.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/non_terminal/impl/rule.ipp>", private, "<boost/spirit/home/classic/core/non_terminal/rule.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/non_terminal/impl/rule.ipp>", private, "<boost/spirit/home/classic/dynamic/stored_rule.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/non_terminal/impl/static.hpp>", private, "<boost/spirit/include/classic_static.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/non_terminal/impl/subrule.ipp>", private, "<boost/spirit/home/classic/core/non_terminal/subrule.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/primitives/impl/numerics.ipp>", private, "<boost/spirit/home/classic/core/primitives/numerics.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/primitives/impl/primitives.ipp>", private, "<boost/spirit/home/classic/core/primitives/primitives.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/primitives/impl/primitives.ipp>", private, "<boost/spirit/home/classic/core/scanner/skipper.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/core/scanner/impl/skipper.ipp>", private, "<boost/spirit/home/classic/core/scanner/skipper.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/debug/impl/parser_names.ipp>", private, "<boost/spirit/home/classic/debug/parser_names.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/dynamic/impl/conditions.ipp>", private, "<boost/spirit/home/classic/dynamic/for.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/dynamic/impl/conditions.ipp>", private, "<boost/spirit/home/classic/dynamic/if.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/dynamic/impl/conditions.ipp>", private, "<boost/spirit/home/classic/dynamic/while.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/dynamic/impl/select.ipp>", private, "<boost/spirit/home/classic/dynamic/select.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/dynamic/impl/switch.ipp>", private, "<boost/spirit/home/classic/dynamic/switch.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/error_handling/impl/exceptions.ipp>", private, "<boost/spirit/home/classic/error_handling/exceptions.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/iterator/impl/file_iterator.ipp>", private, "<boost/spirit/home/classic/iterator/file_iterator.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/iterator/impl/position_iterator.ipp>", private, "<boost/spirit/home/classic/iterator/position_iterator.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/meta/impl/fundamental.ipp>", private, "<boost/spirit/home/classic/meta/fundamental.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/meta/impl/parser_traits.ipp>", private, "<boost/spirit/home/classic/meta/parser_traits.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/meta/impl/refactoring.ipp>", private, "<boost/spirit/home/classic/meta/refactoring.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/meta/impl/traverse.ipp>", private, "<boost/spirit/home/classic/meta/traverse.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/symbols/impl/symbols.ipp>", private, "<boost/spirit/home/classic/symbols/symbols.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/tree/impl/parse_tree_utils.ipp>", private, "<boost/spirit/home/classic/tree/parse_tree_utils.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/tree/impl/tree_to_xml.ipp>", private, "<boost/spirit/home/classic/tree/tree_to_xml.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/utility/impl/chset/basic_chset.hpp>", private, "<boost/spirit/home/classic/utility/chset.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/utility/impl/chset/basic_chset.hpp>", private, "<boost/spirit/include/classic_basic_chset.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/utility/impl/chset.ipp>", private, "<boost/spirit/home/classic/utility/chset.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/utility/impl/chset_operators.ipp>", private, "<boost/spirit/home/classic/utility/chset_operators.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/utility/impl/chset/range_run.hpp>", private, "<boost/spirit/include/classic_range_run.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/utility/impl/confix.ipp>", private, "<boost/spirit/home/classic/utility/confix.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/utility/impl/escape_char.ipp>", private, "<boost/spirit/home/classic/utility/escape_char.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/utility/impl/lists.ipp>", private, "<boost/spirit/home/classic/utility/lists.hpp>", public ] },
+ { include: ["<boost/spirit/home/classic/utility/impl/regex.ipp>", private, "<boost/spirit/home/classic/utility/regex.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/alternative_function.hpp>", private, "<boost/spirit/home/karma/operator/alternative.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/as.hpp>", private, "<boost/spirit/home/karma/directive/as.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/action/action.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/auxiliary/attr_cast.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/auxiliary/eol.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/auxiliary/lazy.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/directive/as.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/directive/buffer.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/directive/center_alignment.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/directive/columns.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/directive/delimit.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/directive/duplicate.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/directive/left_alignment.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/directive/maxwidth.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/directive/no_delimit.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/directive/omit.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/directive/repeat.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/directive/right_alignment.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/directive/verbatim.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/nonterminal/rule.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/operator/and_predicate.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/operator/kleene.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/operator/list.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/operator/not_predicate.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/operator/optional.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/operator/plus.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/operator/sequence.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/phoenix_attributes.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/home/karma/string/symbols.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/repository/home/karma/directive/confix.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/attributes.hpp>", private, "<boost/spirit/repository/home/karma/nonterminal/subrule.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/default_width.hpp>", private, "<boost/spirit/home/karma/directive/center_alignment.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/default_width.hpp>", private, "<boost/spirit/home/karma/directive/columns.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/default_width.hpp>", private, "<boost/spirit/home/karma/directive/left_alignment.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/default_width.hpp>", private, "<boost/spirit/home/karma/directive/maxwidth.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/default_width.hpp>", private, "<boost/spirit/home/karma/directive/right_alignment.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/enable_lit.hpp>", private, "<boost/spirit/home/karma/char/char.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/enable_lit.hpp>", private, "<boost/spirit/home/karma/numeric/bool.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/enable_lit.hpp>", private, "<boost/spirit/home/karma/numeric/int.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/enable_lit.hpp>", private, "<boost/spirit/home/karma/numeric/real.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/enable_lit.hpp>", private, "<boost/spirit/home/karma/numeric/uint.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/enable_lit.hpp>", private, "<boost/spirit/home/karma/string/lit.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/extract_from.hpp>", private, "<boost/spirit/home/karma/binary/binary.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/extract_from.hpp>", private, "<boost/spirit/home/karma/char/char_generator.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/extract_from.hpp>", private, "<boost/spirit/home/karma/numeric/bool.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/extract_from.hpp>", private, "<boost/spirit/home/karma/numeric/int.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/extract_from.hpp>", private, "<boost/spirit/home/karma/numeric/real.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/extract_from.hpp>", private, "<boost/spirit/home/karma/numeric/uint.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/extract_from.hpp>", private, "<boost/spirit/home/karma/stream/stream.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/extract_from.hpp>", private, "<boost/spirit/home/karma/string/lit.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/extract_from.hpp>", private, "<boost/spirit/home/karma/string/symbols.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/fail_function.hpp>", private, "<boost/spirit/home/karma/operator/kleene.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/fail_function.hpp>", private, "<boost/spirit/home/karma/operator/list.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/fail_function.hpp>", private, "<boost/spirit/home/karma/operator/plus.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/fail_function.hpp>", private, "<boost/spirit/home/karma/operator/sequence.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/generate_auto.hpp>", private, "<boost/spirit/home/karma/auto.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/generate_auto.hpp>", private, "<boost/spirit/include/karma_generate_auto.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/generate.hpp>", private, "<boost/spirit/home/karma/generate.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/generate_to.hpp>", private, "<boost/spirit/home/karma/auxiliary/eol.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/generate_to.hpp>", private, "<boost/spirit/home/karma/binary/binary.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/generate_to.hpp>", private, "<boost/spirit/home/karma/binary/padding.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/generate_to.hpp>", private, "<boost/spirit/home/karma/char/char_class.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/generate_to.hpp>", private, "<boost/spirit/home/karma/char/char_generator.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/generate_to.hpp>", private, "<boost/spirit/home/karma/char/char.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/get_casetag.hpp>", private, "<boost/spirit/home/karma/char/char_class.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/get_casetag.hpp>", private, "<boost/spirit/home/karma/char/char.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/get_casetag.hpp>", private, "<boost/spirit/home/karma/numeric/bool.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/get_casetag.hpp>", private, "<boost/spirit/home/karma/numeric/int.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/get_casetag.hpp>", private, "<boost/spirit/home/karma/numeric/real.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/get_casetag.hpp>", private, "<boost/spirit/home/karma/numeric/uint.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/get_casetag.hpp>", private, "<boost/spirit/home/karma/stream/stream.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/get_casetag.hpp>", private, "<boost/spirit/home/karma/string/lit.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/get_casetag.hpp>", private, "<boost/spirit/home/karma/string/symbols.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/get_stricttag.hpp>", private, "<boost/spirit/home/karma/directive/repeat.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/get_stricttag.hpp>", private, "<boost/spirit/home/karma/operator/alternative.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/get_stricttag.hpp>", private, "<boost/spirit/home/karma/operator/kleene.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/get_stricttag.hpp>", private, "<boost/spirit/home/karma/operator/list.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/get_stricttag.hpp>", private, "<boost/spirit/home/karma/operator/plus.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/get_stricttag.hpp>", private, "<boost/spirit/home/karma/operator/sequence.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/indirect_iterator.hpp>", private, "<boost/spirit/home/karma/operator/kleene.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/indirect_iterator.hpp>", private, "<boost/spirit/home/karma/operator/list.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/indirect_iterator.hpp>", private, "<boost/spirit/home/karma/operator/plus.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/indirect_iterator.hpp>", private, "<boost/spirit/home/karma/operator/sequence.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/indirect_iterator.hpp>", private, "<boost/spirit/home/karma/phoenix_attributes.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/directive/as.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/directive/buffer.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/directive/center_alignment.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/directive/left_alignment.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/directive/maxwidth.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/directive/repeat.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/directive/right_alignment.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/nonterminal/rule.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/operator/and_predicate.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/operator/kleene.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/operator/list.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/operator/not_predicate.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/operator/plus.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/output_iterator.hpp>", private, "<boost/spirit/home/karma/stream/format_manip.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/pass_container.hpp>", private, "<boost/spirit/home/karma/operator/kleene.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/pass_container.hpp>", private, "<boost/spirit/home/karma/operator/list.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/pass_container.hpp>", private, "<boost/spirit/home/karma/operator/plus.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/pass_container.hpp>", private, "<boost/spirit/home/karma/operator/sequence.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/string_compare.hpp>", private, "<boost/spirit/home/karma/string/lit.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/string_generate.hpp>", private, "<boost/spirit/home/karma/string/lit.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/string_generate.hpp>", private, "<boost/spirit/home/karma/string/symbols.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/unused_delimiter.hpp>", private, "<boost/spirit/home/karma/delimit_out.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/unused_delimiter.hpp>", private, "<boost/spirit/home/karma/directive/delimit.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/unused_delimiter.hpp>", private, "<boost/spirit/home/karma/directive/no_delimit.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/detail/unused_delimiter.hpp>", private, "<boost/spirit/home/karma/directive/verbatim.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/nonterminal/detail/fcall.hpp>", private, "<boost/spirit/home/karma/nonterminal/grammar.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/nonterminal/detail/fcall.hpp>", private, "<boost/spirit/home/karma/nonterminal/rule.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/nonterminal/detail/fcall.hpp>", private, "<boost/spirit/repository/home/karma/nonterminal/subrule.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/nonterminal/detail/generator_binder.hpp>", private, "<boost/spirit/home/karma/nonterminal/rule.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/nonterminal/detail/generator_binder.hpp>", private, "<boost/spirit/repository/home/karma/nonterminal/subrule.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/nonterminal/detail/parameterized.hpp>", private, "<boost/spirit/home/karma/nonterminal/rule.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/nonterminal/detail/parameterized.hpp>", private, "<boost/spirit/repository/home/karma/nonterminal/subrule.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/numeric/detail/bool_utils.hpp>", private, "<boost/spirit/home/karma/numeric/bool.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/numeric/detail/numeric_utils.hpp>", private, "<boost/spirit/home/karma/numeric/bool_policies.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/numeric/detail/numeric_utils.hpp>", private, "<boost/spirit/home/karma/numeric/int.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/numeric/detail/numeric_utils.hpp>", private, "<boost/spirit/home/karma/numeric/uint.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/numeric/detail/real_utils.hpp>", private, "<boost/spirit/home/karma/numeric/real.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/numeric/detail/real_utils.hpp>", private, "<boost/spirit/home/karma/numeric/real_policies.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/stream/detail/format_manip_auto.hpp>", private, "<boost/spirit/home/karma/format_auto.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/stream/detail/format_manip.hpp>", private, "<boost/spirit/home/karma/stream/format_manip.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/stream/detail/format_manip.hpp>", private, "<boost/spirit/home/karma/stream/stream.hpp>", public ] },
+ { include: ["<boost/spirit/home/karma/stream/detail/iterator_sink.hpp>", private, "<boost/spirit/home/karma/stream/stream.hpp>", public ] },
+ { include: ["<boost/spirit/home/lex/detail/sequence_function.hpp>", private, "<boost/spirit/home/lex/lexer/sequence.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/core/detail/actor.hpp>", private, "<boost/spirit/home/phoenix/core/actor.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/core/detail/basic_environment.hpp>", private, "<boost/spirit/home/phoenix/core/basic_environment.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/core/detail/compose.hpp>", private, "<boost/spirit/home/phoenix/core/compose.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/core/detail/composite_eval.hpp>", private, "<boost/spirit/home/phoenix/core/composite.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/core/detail/composite.hpp>", private, "<boost/spirit/home/phoenix/core/composite.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/detail/local_reference.hpp>", private, "<boost/spirit/home/phoenix/scope/lambda.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/detail/local_reference.hpp>", private, "<boost/spirit/home/phoenix/scope/let.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/detail/local_reference.hpp>", private, "<boost/spirit/home/phoenix/scope/local_variable.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/detail/type_deduction.hpp>", private, "<boost/spirit/home/phoenix/operator/arithmetic.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/detail/type_deduction.hpp>", private, "<boost/spirit/home/phoenix/operator/bitwise.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/detail/type_deduction.hpp>", private, "<boost/spirit/home/phoenix/operator/comparison.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/detail/type_deduction.hpp>", private, "<boost/spirit/home/phoenix/operator/if_else.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/detail/type_deduction.hpp>", private, "<boost/spirit/home/phoenix/operator/logical.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/detail/type_deduction.hpp>", private, "<boost/spirit/home/phoenix/operator/self.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/object/detail/construct_eval.hpp>", private, "<boost/spirit/home/phoenix/object/construct.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/object/detail/construct.hpp>", private, "<boost/spirit/home/phoenix/object/construct.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/object/detail/new_eval.hpp>", private, "<boost/spirit/home/phoenix/object/new.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/object/detail/new.hpp>", private, "<boost/spirit/home/phoenix/object/new.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/binary_compose.hpp>", private, "<boost/spirit/home/phoenix/operator/arithmetic.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/binary_compose.hpp>", private, "<boost/spirit/home/phoenix/operator/bitwise.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/binary_compose.hpp>", private, "<boost/spirit/home/phoenix/operator/comparison.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/binary_compose.hpp>", private, "<boost/spirit/home/phoenix/operator/logical.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/binary_compose.hpp>", private, "<boost/spirit/home/phoenix/operator/self.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/binary_eval.hpp>", private, "<boost/spirit/home/phoenix/operator/arithmetic.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/binary_eval.hpp>", private, "<boost/spirit/home/phoenix/operator/bitwise.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/binary_eval.hpp>", private, "<boost/spirit/home/phoenix/operator/comparison.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/binary_eval.hpp>", private, "<boost/spirit/home/phoenix/operator/logical.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/binary_eval.hpp>", private, "<boost/spirit/home/phoenix/operator/self.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/io.hpp>", private, "<boost/spirit/home/phoenix/operator/io.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/mem_fun_ptr_gen.hpp>", private, "<boost/spirit/home/phoenix/operator/member.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/unary_compose.hpp>", private, "<boost/spirit/home/phoenix/operator/arithmetic.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/unary_compose.hpp>", private, "<boost/spirit/home/phoenix/operator/bitwise.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/unary_compose.hpp>", private, "<boost/spirit/home/phoenix/operator/comparison.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/unary_compose.hpp>", private, "<boost/spirit/home/phoenix/operator/logical.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/unary_compose.hpp>", private, "<boost/spirit/home/phoenix/operator/self.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/unary_eval.hpp>", private, "<boost/spirit/home/phoenix/operator/arithmetic.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/unary_eval.hpp>", private, "<boost/spirit/home/phoenix/operator/bitwise.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/unary_eval.hpp>", private, "<boost/spirit/home/phoenix/operator/comparison.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/unary_eval.hpp>", private, "<boost/spirit/home/phoenix/operator/logical.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/operator/detail/unary_eval.hpp>", private, "<boost/spirit/home/phoenix/operator/self.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/scope/detail/local_gen.hpp>", private, "<boost/spirit/home/phoenix/scope/lambda.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/scope/detail/local_gen.hpp>", private, "<boost/spirit/home/phoenix/scope/let.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/scope/detail/local_variable.hpp>", private, "<boost/spirit/home/phoenix/scope/lambda.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/scope/detail/local_variable.hpp>", private, "<boost/spirit/home/phoenix/scope/let.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/scope/detail/local_variable.hpp>", private, "<boost/spirit/home/phoenix/scope/local_variable.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/statement/detail/catch_all_eval.hpp>", private, "<boost/spirit/home/phoenix/statement/try_catch.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/statement/detail/catch_composite.hpp>", private, "<boost/spirit/home/phoenix/statement/try_catch.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/statement/detail/catch_eval.hpp>", private, "<boost/spirit/home/phoenix/statement/try_catch.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/statement/detail/switch_eval.hpp>", private, "<boost/spirit/home/phoenix/statement/switch.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/statement/detail/switch.hpp>", private, "<boost/spirit/home/phoenix/statement/switch.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/begin.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/iteration.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/begin.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/querying.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/begin.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/decay_array.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/querying.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/decay_array.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/end.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/iteration.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/end.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/querying.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/end.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/has_equal_range.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/querying.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/has_find.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/querying.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/has_lower_bound.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/querying.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/has_remove.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/has_remove_if.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/has_reverse.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/has_sort.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/has_unique.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/transformation.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/algorithm/detail/has_upper_bound.hpp>", private, "<boost/spirit/home/phoenix/stl/algorithm/querying.hpp>", public ] },
+ { include: ["<boost/spirit/home/phoenix/stl/container/detail/container.hpp>", private, "<boost/spirit/home/phoenix/stl/container/container.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/alternative_function.hpp>", private, "<boost/spirit/home/qi/operator/alternative.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/lex/lexer/lexer.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/functor_data.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/position_token.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/token.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/lex/lexer/token_def.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/lex/qi/plain_raw_token.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/lex/qi/plain_token.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/lex/qi/plain_tokenid.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/lex/qi/plain_tokenid_mask.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/qi/auxiliary/attr.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/qi/binary/binary.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/qi/char/char_parser.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/qi/directive/as.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/qi/directive/matches.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/qi/directive/raw.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/qi/numeric/bool_policies.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/qi/numeric/numeric_utils.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/qi/operator/optional.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/assign_to.hpp>", private, "<boost/spirit/home/qi/string/symbols.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/lex/qi/plain_raw_token.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/lex/qi/plain_token.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/lex/qi/plain_tokenid.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/lex/qi/plain_tokenid_mask.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/lex/qi/state_switcher.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/action/action.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/auxiliary/attr_cast.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/auxiliary/lazy.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/binary/binary.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/directive/lexeme.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/directive/repeat.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/directive/skip.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/nonterminal/rule.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/operator/alternative.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/operator/and_predicate.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/operator/difference.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/operator/kleene.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/operator/list.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/operator/not_predicate.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/operator/optional.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/operator/permutation.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/operator/plus.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/operator/sequence_base.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/home/qi/operator/sequential_or.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/repository/home/qi/directive/confix.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/repository/home/qi/directive/distinct.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/repository/home/qi/directive/kwd.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/repository/home/qi/directive/seek.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/repository/home/qi/nonterminal/subrule.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/attributes.hpp>", private, "<boost/spirit/repository/home/qi/operator/keywords.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/construct.hpp>", private, "<boost/spirit/home/lex/lexer/token_def.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/enable_lit.hpp>", private, "<boost/spirit/home/qi/char/char.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/enable_lit.hpp>", private, "<boost/spirit/home/qi/numeric/bool.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/enable_lit.hpp>", private, "<boost/spirit/home/qi/numeric/int.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/enable_lit.hpp>", private, "<boost/spirit/home/qi/numeric/real.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/enable_lit.hpp>", private, "<boost/spirit/home/qi/numeric/uint.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/enable_lit.hpp>", private, "<boost/spirit/home/qi/string/lit.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/expect_function.hpp>", private, "<boost/spirit/home/qi/operator/expect.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/fail_function.hpp>", private, "<boost/spirit/home/qi/directive/repeat.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/fail_function.hpp>", private, "<boost/spirit/home/qi/operator/kleene.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/fail_function.hpp>", private, "<boost/spirit/home/qi/operator/list.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/fail_function.hpp>", private, "<boost/spirit/home/qi/operator/plus.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/fail_function.hpp>", private, "<boost/spirit/home/qi/operator/sequence.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/fail_function.hpp>", private, "<boost/spirit/repository/home/qi/directive/kwd.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/parse_auto.hpp>", private, "<boost/spirit/home/qi/auto.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/parse_auto.hpp>", private, "<boost/spirit/include/qi_parse_auto.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/parse.hpp>", private, "<boost/spirit/home/qi/parse.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/pass_container.hpp>", private, "<boost/spirit/home/qi/directive/repeat.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/pass_container.hpp>", private, "<boost/spirit/home/qi/operator/kleene.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/pass_container.hpp>", private, "<boost/spirit/home/qi/operator/list.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/pass_container.hpp>", private, "<boost/spirit/home/qi/operator/plus.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/pass_container.hpp>", private, "<boost/spirit/home/qi/operator/sequence_base.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/pass_function.hpp>", private, "<boost/spirit/home/qi/operator/sequential_or.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/permute_function.hpp>", private, "<boost/spirit/home/qi/operator/permutation.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/permute_function.hpp>", private, "<boost/spirit/repository/home/qi/operator/keywords.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/string_parse.hpp>", private, "<boost/spirit/home/qi/numeric/bool_policies.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/string_parse.hpp>", private, "<boost/spirit/home/qi/numeric/real_policies.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/string_parse.hpp>", private, "<boost/spirit/home/qi/stream/stream.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/string_parse.hpp>", private, "<boost/spirit/home/qi/string/lit.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/unused_skipper.hpp>", private, "<boost/spirit/home/qi/directive/lexeme.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/unused_skipper.hpp>", private, "<boost/spirit/home/qi/directive/no_skip.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/unused_skipper.hpp>", private, "<boost/spirit/home/qi/directive/skip.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/unused_skipper.hpp>", private, "<boost/spirit/home/qi/skip_over.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/detail/unused_skipper.hpp>", private, "<boost/spirit/repository/home/qi/directive/distinct.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/nonterminal/detail/fcall.hpp>", private, "<boost/spirit/home/qi/nonterminal/grammar.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/nonterminal/detail/fcall.hpp>", private, "<boost/spirit/home/qi/nonterminal/rule.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/nonterminal/detail/fcall.hpp>", private, "<boost/spirit/repository/home/qi/nonterminal/subrule.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/nonterminal/detail/parameterized.hpp>", private, "<boost/spirit/home/qi/nonterminal/rule.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/nonterminal/detail/parameterized.hpp>", private, "<boost/spirit/repository/home/qi/nonterminal/subrule.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/nonterminal/detail/parser_binder.hpp>", private, "<boost/spirit/home/qi/nonterminal/rule.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/nonterminal/detail/parser_binder.hpp>", private, "<boost/spirit/repository/home/qi/nonterminal/subrule.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/numeric/detail/numeric_utils.hpp>", private, "<boost/spirit/home/qi/numeric/numeric_utils.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/numeric/detail/real_impl.hpp>", private, "<boost/spirit/home/qi/numeric/real.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/stream/detail/iterator_source.hpp>", private, "<boost/spirit/home/qi/stream/stream.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/stream/detail/match_manip_auto.hpp>", private, "<boost/spirit/home/qi/match_auto.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/stream/detail/match_manip.hpp>", private, "<boost/spirit/home/qi/stream/match_manip.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/stream/detail/match_manip.hpp>", private, "<boost/spirit/home/qi/stream/stream.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/string/detail/tst.hpp>", private, "<boost/spirit/home/qi/string/tst.hpp>", public ] },
+ { include: ["<boost/spirit/home/qi/string/detail/tst.hpp>", private, "<boost/spirit/home/qi/string/tst_map.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/as_variant.hpp>", private, "<boost/spirit/home/support/attributes.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/endian.hpp>", private, "<boost/spirit/home/karma/binary/binary.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/endian.hpp>", private, "<boost/spirit/home/qi/binary/binary.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/get_encoding.hpp>", private, "<boost/spirit/home/karma/char/char_class.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/get_encoding.hpp>", private, "<boost/spirit/home/karma/char/char.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/get_encoding.hpp>", private, "<boost/spirit/home/karma/numeric/int.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/get_encoding.hpp>", private, "<boost/spirit/home/karma/numeric/real.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/get_encoding.hpp>", private, "<boost/spirit/home/karma/numeric/uint.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/get_encoding.hpp>", private, "<boost/spirit/home/karma/stream/stream.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/get_encoding.hpp>", private, "<boost/spirit/home/karma/string/lit.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/get_encoding.hpp>", private, "<boost/spirit/home/karma/string/symbols.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/get_encoding.hpp>", private, "<boost/spirit/home/qi/char/char_class.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/get_encoding.hpp>", private, "<boost/spirit/home/qi/char/char.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/get_encoding.hpp>", private, "<boost/spirit/home/qi/string/lit.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/get_encoding.hpp>", private, "<boost/spirit/home/qi/string/symbols.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/hold_any.hpp>", private, "<boost/spirit/home/karma/auto/auto.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/hold_any.hpp>", private, "<boost/spirit/home/karma/stream/stream.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/hold_any.hpp>", private, "<boost/spirit/home/qi/auto/auto.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/hold_any.hpp>", private, "<boost/spirit/home/qi/stream/stream.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/hold_any.hpp>", private, "<boost/spirit/home/support/attributes.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/is_spirit_tag.hpp>", private, "<boost/spirit/home/karma/numeric/int.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/is_spirit_tag.hpp>", private, "<boost/spirit/home/karma/numeric/uint.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/is_spirit_tag.hpp>", private, "<boost/spirit/home/karma/stream/stream.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/is_spirit_tag.hpp>", private, "<boost/spirit/home/qi/numeric/uint.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/is_spirit_tag.hpp>", private, "<boost/spirit/home/support/char_class.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/is_spirit_tag.hpp>", private, "<boost/spirit/home/support/lazy.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/is_spirit_tag.hpp>", private, "<boost/spirit/home/support/terminal.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/char_traits.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/generate_static.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/char_traits.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/iterator_tokenizer.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/consts.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/generate_static.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/consts.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/iterator_tokenizer.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/consts.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/lexer.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/consts.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/position_token.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/consts.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/token.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/debug.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/generate_static.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/debug.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/lexer.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/debug.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/static_lexer.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/generator.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/functor_data.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/generator.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/lexer.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/generator.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/position_token.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/generator.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/static_functor_data.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/generator.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/token.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/rules.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/functor_data.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/rules.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/generate_static.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/rules.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/lexer.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/rules.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/position_token.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/rules.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/static_functor_data.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/rules.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/token.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/size_t.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/generate_static.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/size_t.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/iterator_tokenizer.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/state_machine.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/functor_data.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/state_machine.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/generate_static.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/state_machine.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/iterator_tokenizer.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/lexer/state_machine.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/static_functor_data.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/make_cons.hpp>", private, "<boost/spirit/home/support/make_component.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/make_cons.hpp>", private, "<boost/spirit/home/support/meta_compiler.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/make_vector.hpp>", private, "<boost/spirit/home/support/terminal.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/scoped_enum_emulation.hpp>", private, "<boost/spirit/home/karma/delimit_flag.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/scoped_enum_emulation.hpp>", private, "<boost/spirit/home/lex/lexer/pass_flags.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/scoped_enum_emulation.hpp>", private, "<boost/spirit/home/lex/lexer/support_functions.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/scoped_enum_emulation.hpp>", private, "<boost/spirit/home/qi/skip_flag.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/scoped_enum_emulation.hpp>", private, "<boost/spirit/home/support/multi_pass_wrapper.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/what_function.hpp>", private, "<boost/spirit/home/karma/operator/alternative.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/what_function.hpp>", private, "<boost/spirit/home/karma/operator/sequence.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/what_function.hpp>", private, "<boost/spirit/home/qi/operator/alternative.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/what_function.hpp>", private, "<boost/spirit/home/qi/operator/permutation.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/what_function.hpp>", private, "<boost/spirit/home/qi/operator/sequence_base.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/what_function.hpp>", private, "<boost/spirit/home/qi/operator/sequential_or.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/detail/what_function.hpp>", private, "<boost/spirit/repository/home/qi/operator/keywords.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/buffering_input_iterator_policy.hpp>", private, "<boost/spirit/home/support/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/buf_id_check_policy.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/iterator.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/buf_id_check_policy.hpp>", private, "<boost/spirit/home/support/iterators/istream_iterator.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/buf_id_check_policy.hpp>", private, "<boost/spirit/home/support/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/combine_policies.hpp>", private, "<boost/spirit/home/support/iterators/istream_iterator.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/combine_policies.hpp>", private, "<boost/spirit/home/support/iterators/look_ahead.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/combine_policies.hpp>", private, "<boost/spirit/home/support/iterators/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/combine_policies.hpp>", private, "<boost/spirit/home/support/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/first_owner_policy.hpp>", private, "<boost/spirit/home/support/iterators/look_ahead.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/first_owner_policy.hpp>", private, "<boost/spirit/home/support/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/fixed_size_queue_policy.hpp>", private, "<boost/spirit/home/support/iterators/look_ahead.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/fixed_size_queue_policy.hpp>", private, "<boost/spirit/home/support/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/functor_input_policy.hpp>", private, "<boost/spirit/home/support/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/input_iterator_policy.hpp>", private, "<boost/spirit/home/support/iterators/look_ahead.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/input_iterator_policy.hpp>", private, "<boost/spirit/home/support/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/istream_policy.hpp>", private, "<boost/spirit/home/support/iterators/istream_iterator.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/istream_policy.hpp>", private, "<boost/spirit/home/support/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/lex_input_policy.hpp>", private, "<boost/spirit/home/support/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/multi_pass.hpp>", private, "<boost/spirit/home/support/iterators/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/multi_pass.hpp>", private, "<boost/spirit/home/support/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/no_check_policy.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/iterator.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/no_check_policy.hpp>", private, "<boost/spirit/home/support/iterators/istream_iterator.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/no_check_policy.hpp>", private, "<boost/spirit/home/support/iterators/look_ahead.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/no_check_policy.hpp>", private, "<boost/spirit/home/support/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/ref_counted_policy.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/iterator.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/ref_counted_policy.hpp>", private, "<boost/spirit/home/support/iterators/istream_iterator.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/ref_counted_policy.hpp>", private, "<boost/spirit/home/support/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/split_functor_input_policy.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/iterator.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/split_functor_input_policy.hpp>", private, "<boost/spirit/home/support/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/split_std_deque_policy.hpp>", private, "<boost/spirit/home/lex/lexer/lexertl/iterator.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/split_std_deque_policy.hpp>", private, "<boost/spirit/home/support/iterators/istream_iterator.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/iterators/detail/split_std_deque_policy.hpp>", private, "<boost/spirit/home/support/multi_pass.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/utree/detail/utree_detail1.hpp>", private, "<boost/spirit/home/support/utree/utree.hpp>", public ] },
+ { include: ["<boost/spirit/home/support/utree/detail/utree_detail2.hpp>", private, "<boost/spirit/home/support/utree.hpp>", public ] },
+ { include: ["<boost/spirit/repository/home/qi/operator/detail/keywords.hpp>", private, "<boost/spirit/repository/home/qi/operator/keywords.hpp>", public ] },
+ { include: ["<boost/statechart/detail/avoid_unused_warning.hpp>", private, "<boost/statechart/state_machine.hpp>", public ] },
+ { include: ["<boost/statechart/detail/constructor.hpp>", private, "<boost/statechart/simple_state.hpp>", public ] },
+ { include: ["<boost/statechart/detail/constructor.hpp>", private, "<boost/statechart/state_machine.hpp>", public ] },
+ { include: ["<boost/statechart/detail/counted_base.hpp>", private, "<boost/statechart/event_base.hpp>", public ] },
+ { include: ["<boost/statechart/detail/leaf_state.hpp>", private, "<boost/statechart/simple_state.hpp>", public ] },
+ { include: ["<boost/statechart/detail/leaf_state.hpp>", private, "<boost/statechart/state_machine.hpp>", public ] },
+ { include: ["<boost/statechart/detail/memory.hpp>", private, "<boost/statechart/event.hpp>", public ] },
+ { include: ["<boost/statechart/detail/memory.hpp>", private, "<boost/statechart/simple_state.hpp>", public ] },
+ { include: ["<boost/statechart/detail/node_state.hpp>", private, "<boost/statechart/simple_state.hpp>", public ] },
+ { include: ["<boost/statechart/detail/node_state.hpp>", private, "<boost/statechart/state_machine.hpp>", public ] },
+ { include: ["<boost/statechart/detail/reaction_dispatcher.hpp>", private, "<boost/statechart/in_state_reaction.hpp>", public ] },
+ { include: ["<boost/statechart/detail/reaction_dispatcher.hpp>", private, "<boost/statechart/transition.hpp>", public ] },
+ { include: ["<boost/statechart/detail/rtti_policy.hpp>", private, "<boost/statechart/event_base.hpp>", public ] },
+ { include: ["<boost/statechart/detail/rtti_policy.hpp>", private, "<boost/statechart/event.hpp>", public ] },
+ { include: ["<boost/statechart/detail/rtti_policy.hpp>", private, "<boost/statechart/state_machine.hpp>", public ] },
+ { include: ["<boost/statechart/detail/state_base.hpp>", private, "<boost/statechart/state_machine.hpp>", public ] },
+ { include: ["<boost/test/detail/config.hpp>", private, "<boost/test/debug.hpp>", public ] },
+ { include: ["<boost/test/detail/config.hpp>", private, "<boost/test/exception_safety.hpp>", public ] },
+ { include: ["<boost/test/detail/config.hpp>", private, "<boost/test/interaction_based.hpp>", public ] },
+ { include: ["<boost/test/detail/config.hpp>", private, "<boost/test/logged_expectations.hpp>", public ] },
+ { include: ["<boost/test/detail/config.hpp>", private, "<boost/test/mock_object.hpp>", public ] },
+ { include: ["<boost/test/detail/config.hpp>", private, "<boost/test/prg_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/detail/config.hpp>", private, "<boost/test/test_observer.hpp>", public ] },
+ { include: ["<boost/test/detail/config.hpp>", private, "<boost/test/test_tools.hpp>", public ] },
+ { include: ["<boost/test/detail/config.hpp>", private, "<boost/test/unit_test_suite_impl.hpp>", public ] },
+ { include: ["<boost/test/detail/config.hpp>", private, "<boost/test/utils/basic_cstring/bcs_char_traits.hpp>", public ] },
+ { include: ["<boost/test/detail/config.hpp>", private, "<boost/test/utils/class_properties.hpp>", public ] },
+ { include: ["<boost/test/detail/config.hpp>", private, "<boost/test/utils/foreach.hpp>", public ] },
+ { include: ["<boost/test/detail/config.hpp>", private, "<boost/test/utils/lazy_ostream.hpp>", public ] },
+ { include: ["<boost/test/detail/config.hpp>", private, "<boost/test/utils/runtime/config.hpp>", public ] },
+ { include: ["<boost/test/detail/config.hpp>", private, "<boost/test/utils/wrap_stringstream.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/debug.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/exception_safety.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/execution_monitor.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/floating_point_comparison.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/framework.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/interaction_based.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/logged_expectations.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/minimal.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/mock_object.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/output/compiler_log_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/output/plain_report_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/output_test_stream.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/output/xml_log_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/output/xml_report_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/parameterized_test.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/predicate_result.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/progress_monitor.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/results_collector.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/results_reporter.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/test_observer.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/test_tools.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/unit_test_log_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/unit_test_log.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/unit_test_monitor.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/unit_test_suite_impl.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/algorithm.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/basic_cstring/basic_cstring.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/basic_cstring/bcs_char_traits.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/basic_cstring/compare.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/basic_cstring/io.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/callback.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/class_properties.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/custom_manip.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/fixed_mapping.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/foreach.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/iterator/ifstream_line_iterator.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/iterator/input_iterator_facade.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/iterator/istream_line_iterator.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/iterator/token_iterator.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/lazy_ostream.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/named_params.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/nullstream.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/trivial_singleton.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/wrap_stringstream.hpp>", public ] },
+ { include: ["<boost/test/detail/enable_warnings.hpp>", private, "<boost/test/utils/xml_printer.hpp>", public ] },
+ { include: ["<boost/test/detail/fwd_decl.hpp>", private, "<boost/test/execution_monitor.hpp>", public ] },
+ { include: ["<boost/test/detail/fwd_decl.hpp>", private, "<boost/test/framework.hpp>", public ] },
+ { include: ["<boost/test/detail/fwd_decl.hpp>", private, "<boost/test/results_collector.hpp>", public ] },
+ { include: ["<boost/test/detail/fwd_decl.hpp>", private, "<boost/test/results_reporter.hpp>", public ] },
+ { include: ["<boost/test/detail/fwd_decl.hpp>", private, "<boost/test/test_observer.hpp>", public ] },
+ { include: ["<boost/test/detail/fwd_decl.hpp>", private, "<boost/test/unit_test_log_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/fwd_decl.hpp>", private, "<boost/test/unit_test_log.hpp>", public ] },
+ { include: ["<boost/test/detail/fwd_decl.hpp>", private, "<boost/test/unit_test_monitor.hpp>", public ] },
+ { include: ["<boost/test/detail/fwd_decl.hpp>", private, "<boost/test/unit_test_suite_impl.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/execution_monitor.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/floating_point_comparison.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/framework.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/interaction_based.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/minimal.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/output/compiler_log_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/output/plain_report_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/output_test_stream.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/output/xml_log_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/output/xml_report_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/results_collector.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/results_reporter.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/test_observer.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/test_tools.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/unit_test_log_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/unit_test_log.hpp>", public ] },
+ { include: ["<boost/test/detail/global_typedef.hpp>", private, "<boost/test/unit_test_suite_impl.hpp>", public ] },
+ { include: ["<boost/test/detail/log_level.hpp>", private, "<boost/test/unit_test_log_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/log_level.hpp>", private, "<boost/test/unit_test_log.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/debug.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/exception_safety.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/execution_monitor.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/floating_point_comparison.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/framework.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/interaction_based.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/logged_expectations.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/minimal.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/mock_object.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/output/compiler_log_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/output/plain_report_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/output_test_stream.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/output/xml_log_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/output/xml_report_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/parameterized_test.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/predicate_result.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/progress_monitor.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/results_collector.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/results_reporter.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/test_observer.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/test_tools.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/unit_test_log_formatter.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/unit_test_log.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/unit_test_monitor.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/unit_test_suite_impl.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/algorithm.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/basic_cstring/basic_cstring.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/basic_cstring/bcs_char_traits.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/basic_cstring/compare.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/basic_cstring/io.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/callback.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/class_properties.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/custom_manip.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/fixed_mapping.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/foreach.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/iterator/ifstream_line_iterator.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/iterator/input_iterator_facade.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/iterator/istream_line_iterator.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/iterator/token_iterator.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/lazy_ostream.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/named_params.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/nullstream.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/trivial_singleton.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/wrap_stringstream.hpp>", public ] },
+ { include: ["<boost/test/detail/suppress_warnings.hpp>", private, "<boost/test/utils/xml_printer.hpp>", public ] },
+ { include: ["<boost/test/detail/unit_test_parameters.hpp>", private, "<boost/test/logged_expectations.hpp>", public ] },
+ { include: ["<boost/test/impl/compiler_log_formatter.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/compiler_log_formatter.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/cpp_main.ipp>", private, "<boost/test/included/prg_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/debug.ipp>", private, "<boost/test/included/prg_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/debug.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/debug.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/debug.ipp>", private, "<boost/test/minimal.hpp>", public ] },
+ { include: ["<boost/test/impl/exception_safety.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/execution_monitor.ipp>", private, "<boost/test/included/prg_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/execution_monitor.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/execution_monitor.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/execution_monitor.ipp>", private, "<boost/test/minimal.hpp>", public ] },
+ { include: ["<boost/test/impl/framework.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/framework.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/interaction_based.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/logged_expectations.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/plain_report_formatter.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/plain_report_formatter.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/progress_monitor.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/progress_monitor.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/results_collector.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/results_collector.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/results_reporter.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/results_reporter.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/test_main.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/test_tools.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/test_tools.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/unit_test_log.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/unit_test_log.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/unit_test_main.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/unit_test_main.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/unit_test_monitor.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/unit_test_monitor.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/unit_test_parameters.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/unit_test_parameters.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/unit_test_suite.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/unit_test_suite.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/xml_log_formatter.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/xml_log_formatter.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/impl/xml_report_formatter.ipp>", private, "<boost/test/included/test_exec_monitor.hpp>", public ] },
+ { include: ["<boost/test/impl/xml_report_formatter.ipp>", private, "<boost/test/included/unit_test.hpp>", public ] },
+ { include: ["<boost/test/utils/runtime/cla/detail/argument_value_usage.hpp>", private, "<boost/test/utils/runtime/cla/argument_factory.hpp>", public ] },
+ { include: ["<boost/thread/detail/async_func.hpp>", private, "<boost/thread/future.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/barrier.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/completion_latch.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/condition.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/exceptions.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/externally_locked.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/externally_locked_stream.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/future_error_code.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/future.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/is_locked_by_this_thread.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/latch.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/lockable_traits.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/lock_algorithms.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/lock_guard.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/lock_traits.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/lock_types.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/null_mutex.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/once.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/pthread/mutex.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/pthread/once_atomic.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/pthread/once.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/pthread/thread_data.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/pthread/timespec.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/reverse_lock.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/scoped_thread.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/shared_lock_guard.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/shared_mutex.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/strict_lock.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/sync_bounded_queue.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/synchronized_value.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/sync_queue.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/testable_mutex.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/thread_functors.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/tss.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/v2/shared_mutex.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/v2/thread.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/win32/interlocked_read.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/win32/thread_data.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/win32/thread_heap_alloc.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/win32/thread_primitives.hpp>", public ] },
+ { include: ["<boost/thread/detail/config.hpp>", private, "<boost/thread/xtime.hpp>", public ] },
+ { include: ["<boost/thread/detail/counter.hpp>", private, "<boost/thread/completion_latch.hpp>", public ] },
+ { include: ["<boost/thread/detail/counter.hpp>", private, "<boost/thread/latch.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/barrier.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/completion_latch.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/externally_locked_stream.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/latch.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/lockable_adapter.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/lock_guard.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/null_mutex.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/poly_lockable.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/pthread/condition_variable_fwd.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/pthread/condition_variable.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/pthread/mutex.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/pthread/once.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/pthread/recursive_mutex.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/pthread/shared_mutex_assert.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/pthread/shared_mutex.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/reverse_lock.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/scoped_thread.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/shared_lock_guard.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/strict_lock.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/thread_functors.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/thread_guard.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/win32/mutex.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/win32/recursive_mutex.hpp>", public ] },
+ { include: ["<boost/thread/detail/delete.hpp>", private, "<boost/thread/win32/shared_mutex.hpp>", public ] },
+ { include: ["<boost/thread/detail/invoke.hpp>", private, "<boost/thread/pthread/once_atomic.hpp>", public ] },
+ { include: ["<boost/thread/detail/invoke.hpp>", private, "<boost/thread/pthread/once.hpp>", public ] },
+ { include: ["<boost/thread/detail/invoke.hpp>", private, "<boost/thread/win32/once.hpp>", public ] },
+ { include: ["<boost/thread/detail/is_convertible.hpp>", private, "<boost/thread/future.hpp>", public ] },
+ { include: ["<boost/thread/detail/lockable_wrapper.hpp>", private, "<boost/thread/lock_guard.hpp>", public ] },
+ { include: ["<boost/thread/detail/lockable_wrapper.hpp>", private, "<boost/thread/strict_lock.hpp>", public ] },
+ { include: ["<boost/thread/detail/memory.hpp>", private, "<boost/thread/future.hpp>", public ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/externally_locked_stream.hpp>", public ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/future.hpp>", public ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/lock_concepts.hpp>", public ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/lock_guard.hpp>", public ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/lock_types.hpp>", public ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/pthread/once_atomic.hpp>", public ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/pthread/once.hpp>", public ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/reverse_lock.hpp>", public ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/scoped_thread.hpp>", public ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/sync_bounded_queue.hpp>", public ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/synchronized_value.hpp>", public ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/sync_queue.hpp>", public ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/thread_functors.hpp>", public ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/thread_guard.hpp>", public ] },
+ { include: ["<boost/thread/detail/move.hpp>", private, "<boost/thread/win32/once.hpp>", public ] },
+ { include: ["<boost/thread/detail/platform.hpp>", private, "<boost/thread/condition_variable.hpp>", public ] },
+ { include: ["<boost/thread/detail/platform.hpp>", private, "<boost/thread/mutex.hpp>", public ] },
+ { include: ["<boost/thread/detail/platform.hpp>", private, "<boost/thread/once.hpp>", public ] },
+ { include: ["<boost/thread/detail/platform.hpp>", private, "<boost/thread/recursive_mutex.hpp>", public ] },
+ { include: ["<boost/thread/detail/platform.hpp>", private, "<boost/thread/thread_only.hpp>", public ] },
+ { include: ["<boost/thread/detail/thread_group.hpp>", private, "<boost/thread/thread.hpp>", public ] },
+ { include: ["<boost/thread/detail/thread_heap_alloc.hpp>", private, "<boost/thread/tss.hpp>", public ] },
+ { include: ["<boost/thread/detail/thread.hpp>", private, "<boost/thread/thread_only.hpp>", public ] },
+ { include: ["<boost/thread/detail/thread_interruption.hpp>", private, "<boost/thread/pthread/shared_mutex_assert.hpp>", public ] },
+ { include: ["<boost/thread/detail/thread_interruption.hpp>", private, "<boost/thread/pthread/shared_mutex.hpp>", public ] },
+ { include: ["<boost/thread/detail/thread_interruption.hpp>", private, "<boost/thread/thread_only.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/array.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/cmath.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/complex.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/functional.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/memory.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/random.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/regex.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/algorithm>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/array>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/bitset>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/cmath>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/complex>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/deque>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/exception>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/fstream>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/functional>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/iomanip>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/ios>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/iostream>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/istream>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/iterator>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/limits>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/list>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/locale>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/map>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/memory>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/new>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/numeric>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/ostream>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/queue>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/random>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/regex>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/set>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/sstream>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/stack>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/stdexcept>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/streambuf>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/string>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/strstream>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/tuple>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/typeinfo>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/type_traits>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/unordered_map>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/unordered_set>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/utility>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/valarray>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tr1/vector>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/tuple.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/type_traits.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/unordered_map.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/unordered_set.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config_all.hpp>", private, "<boost/tr1/utility.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config.hpp>", private, "<boost/math/tools/tuple.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config.hpp>", private, "<boost/tr1/array.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config.hpp>", private, "<boost/tr1/cmath.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config.hpp>", private, "<boost/tr1/complex.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config.hpp>", private, "<boost/tr1/functional.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config.hpp>", private, "<boost/tr1/memory.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config.hpp>", private, "<boost/tr1/random.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config.hpp>", private, "<boost/tr1/regex.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config.hpp>", private, "<boost/tr1/tuple.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config.hpp>", private, "<boost/tr1/type_traits.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config.hpp>", private, "<boost/tr1/unordered_map.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config.hpp>", private, "<boost/tr1/unordered_set.hpp>", public ] },
+ { include: ["<boost/tr1/detail/config.hpp>", private, "<boost/tr1/utility.hpp>", public ] },
+ { include: ["<boost/tr1/detail/functor2iterator.hpp>", private, "<boost/tr1/random.hpp>", public ] },
+ { include: ["<boost/tr1/detail/math_overloads.hpp>", private, "<boost/tr1/complex.hpp>", public ] },
+ { include: ["<boost/tti/detail/ddata.hpp>", private, "<boost/tti/has_data.hpp>", public ] },
+ { include: ["<boost/tti/detail/ddeftype.hpp>", private, "<boost/tti/has_member_data.hpp>", public ] },
+ { include: ["<boost/tti/detail/ddeftype.hpp>", private, "<boost/tti/has_member_function.hpp>", public ] },
+ { include: ["<boost/tti/detail/ddeftype.hpp>", private, "<boost/tti/has_type.hpp>", public ] },
+ { include: ["<boost/tti/detail/dfunction.hpp>", private, "<boost/tti/has_function.hpp>", public ] },
+ { include: ["<boost/tti/detail/dmem_data.hpp>", private, "<boost/tti/has_member_data.hpp>", public ] },
+ { include: ["<boost/tti/detail/dmem_fun.hpp>", private, "<boost/tti/has_member_function.hpp>", public ] },
+ { include: ["<boost/tti/detail/dmem_type.hpp>", private, "<boost/tti/member_type.hpp>", public ] },
+ { include: ["<boost/tti/detail/dnotype.hpp>", private, "<boost/tti/member_type.hpp>", public ] },
+ { include: ["<boost/tti/detail/dstatic_mem_data.hpp>", private, "<boost/tti/has_static_member_data.hpp>", public ] },
+ { include: ["<boost/tti/detail/dstatic_mem_fun.hpp>", private, "<boost/tti/has_static_member_function.hpp>", public ] },
+ { include: ["<boost/tti/detail/dtemplate.hpp>", private, "<boost/tti/has_template.hpp>", public ] },
+ { include: ["<boost/tti/detail/dtemplate_params.hpp>", private, "<boost/tti/has_template.hpp>", public ] },
+ { include: ["<boost/tti/detail/dtype.hpp>", private, "<boost/tti/has_type.hpp>", public ] },
+ { include: ["<boost/tti/detail/dvm_template_params.hpp>", private, "<boost/tti/has_template.hpp>", public ] },
+ { include: ["<boost/tuple/detail/tuple_basic.hpp>", private, "<boost/tuple/tuple.hpp>", public ] },
+ { include: ["<boost/tuple/detail/tuple_basic_no_partial_spec.hpp>", private, "<boost/tuple/tuple.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/access.hpp>", private, "<boost/type_erasure/any_cast.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/access.hpp>", private, "<boost/type_erasure/any.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/access.hpp>", private, "<boost/type_erasure/binding_of.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/access.hpp>", private, "<boost/type_erasure/call.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/access.hpp>", private, "<boost/type_erasure/check_match.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/access.hpp>", private, "<boost/type_erasure/is_empty.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/access.hpp>", private, "<boost/type_erasure/param.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/access.hpp>", private, "<boost/type_erasure/typeid_of.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/adapt_to_vtable.hpp>", private, "<boost/type_erasure/binding.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/adapt_to_vtable.hpp>", private, "<boost/type_erasure/call.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/any_base.hpp>", private, "<boost/type_erasure/any.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/check_call.hpp>", private, "<boost/type_erasure/call.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/check_map.hpp>", private, "<boost/type_erasure/binding.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/const.hpp>", private, "<boost/type_erasure/free.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/const.hpp>", private, "<boost/type_erasure/member.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/const.hpp>", private, "<boost/type_erasure/operators.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/construct.hpp>", private, "<boost/type_erasure/any.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/extract_concept.hpp>", private, "<boost/type_erasure/call.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/extract_concept.hpp>", private, "<boost/type_erasure/check_match.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/extract_concept.hpp>", private, "<boost/type_erasure/require_match.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/get_placeholders.hpp>", private, "<boost/type_erasure/deduced.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/get_signature.hpp>", private, "<boost/type_erasure/call.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/instantiate.hpp>", private, "<boost/type_erasure/any.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/instantiate.hpp>", private, "<boost/type_erasure/binding.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/macro.hpp>", private, "<boost/type_erasure/free.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/macro.hpp>", private, "<boost/type_erasure/member.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/normalize.hpp>", private, "<boost/type_erasure/any.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/normalize.hpp>", private, "<boost/type_erasure/binding.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/normalize.hpp>", private, "<boost/type_erasure/is_subconcept.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/null.hpp>", private, "<boost/type_erasure/binding.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/rebind_placeholders.hpp>", private, "<boost/type_erasure/binding.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/rebind_placeholders.hpp>", private, "<boost/type_erasure/is_subconcept.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/storage.hpp>", private, "<boost/type_erasure/any.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/storage.hpp>", private, "<boost/type_erasure/builtin.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/storage.hpp>", private, "<boost/type_erasure/constructible.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/storage.hpp>", private, "<boost/type_erasure/param.hpp>", public ] },
+ { include: ["<boost/type_erasure/detail/vtable.hpp>", private, "<boost/type_erasure/binding.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/aligned_storage.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/iterator/is_lvalue_iterator.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/iterator/is_readable_iterator.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/math/tools/traits.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/mpl/empty_base.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/has_new_operator.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/has_nothrow_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/has_nothrow_constructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/has_nothrow_copy.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/has_nothrow_destructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/has_trivial_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/has_trivial_constructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/has_trivial_copy.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/has_trivial_destructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/has_trivial_move_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/has_trivial_move_constructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/has_virtual_destructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_abstract.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_arithmetic.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_array.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_base_and_derived.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_base_of.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_base_of_tr1.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_class.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_complex.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_compound.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_const.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_convertible.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_copy_constructible.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_empty.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_enum.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_float.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_floating_point.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_function.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_fundamental.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_integral.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_lvalue_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_member_function_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_member_object_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_member_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_nothrow_move_assignable.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_nothrow_move_constructible.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_object.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_pod.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_polymorphic.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_rvalue_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_same.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_scalar.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_signed.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_stateless.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_union.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_unsigned.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_virtual_base_of.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_void.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/is_volatile.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/type_traits/type_with_alignment.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/variant/recursive_wrapper_fwd.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_def.hpp>", private, "<boost/variant/static_visitor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/aligned_storage.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/iterator/is_lvalue_iterator.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/mpl/empty_base.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/has_new_operator.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/has_nothrow_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/has_nothrow_constructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/has_nothrow_copy.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/has_nothrow_destructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/has_trivial_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/has_trivial_constructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/has_trivial_copy.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/has_trivial_destructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/has_trivial_move_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/has_trivial_move_constructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/has_virtual_destructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_abstract.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_arithmetic.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_array.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_base_and_derived.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_base_of.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_base_of_tr1.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_class.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_complex.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_compound.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_const.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_convertible.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_copy_constructible.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_empty.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_enum.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_float.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_floating_point.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_function.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_fundamental.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_integral.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_lvalue_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_member_function_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_member_object_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_member_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_nothrow_move_assignable.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_nothrow_move_constructible.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_object.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_pod.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_polymorphic.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_rvalue_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_same.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_scalar.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_signed.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_stateless.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_union.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_unsigned.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_virtual_base_of.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_void.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/is_volatile.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/type_traits/type_with_alignment.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/variant/recursive_wrapper_fwd.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/bool_trait_undef.hpp>", private, "<boost/variant/static_visitor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/common_type_imp.hpp>", private, "<boost/type_traits/common_type.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/cv_traits_impl.hpp>", private, "<boost/type_traits/is_const.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/cv_traits_impl.hpp>", private, "<boost/type_traits/is_volatile.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/cv_traits_impl.hpp>", private, "<boost/type_traits/remove_const.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/cv_traits_impl.hpp>", private, "<boost/type_traits/remove_cv.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/cv_traits_impl.hpp>", private, "<boost/type_traits/remove_volatile.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/false_result.hpp>", private, "<boost/type_traits/is_const.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/false_result.hpp>", private, "<boost/type_traits/is_function.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/false_result.hpp>", private, "<boost/type_traits/is_member_function_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/false_result.hpp>", private, "<boost/type_traits/is_member_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/false_result.hpp>", private, "<boost/type_traits/is_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/false_result.hpp>", private, "<boost/type_traits/is_volatile.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_bit_and_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_bit_and.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_bit_or_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_bit_or.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_bit_xor_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_bit_xor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_divides_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_divides.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_equal_to.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_greater_equal.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_greater.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_left_shift_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_left_shift.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_less_equal.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_less.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_logical_and.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_logical_or.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_minus_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_minus.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_modulus_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_modulus.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_multiplies_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_multiplies.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_not_equal_to.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_plus_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_plus.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_right_shift_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_binary_operator.hpp>", private, "<boost/type_traits/has_right_shift.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_postfix_operator.hpp>", private, "<boost/type_traits/has_post_decrement.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_postfix_operator.hpp>", private, "<boost/type_traits/has_post_increment.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_prefix_operator.hpp>", private, "<boost/type_traits/has_complement.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_prefix_operator.hpp>", private, "<boost/type_traits/has_dereference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_prefix_operator.hpp>", private, "<boost/type_traits/has_logical_not.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_prefix_operator.hpp>", private, "<boost/type_traits/has_negate.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_prefix_operator.hpp>", private, "<boost/type_traits/has_pre_decrement.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_prefix_operator.hpp>", private, "<boost/type_traits/has_pre_increment.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_prefix_operator.hpp>", private, "<boost/type_traits/has_unary_minus.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/has_prefix_operator.hpp>", private, "<boost/type_traits/has_unary_plus.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/has_trivial_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/has_trivial_copy.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/has_trivial_move_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/has_trivial_move_constructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/ice.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/is_abstract.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/is_base_and_derived.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/is_base_of.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/is_class.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/is_empty.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/is_member_object_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/is_nothrow_move_assignable.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/is_nothrow_move_constructible.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/is_object.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/is_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/is_same.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/is_stateless.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/make_signed.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/type_traits/make_unsigned.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_and.hpp>", private, "<boost/units/scaled_base_unit.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_eq.hpp>", private, "<boost/type_traits/ice.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_not.hpp>", private, "<boost/type_traits/has_trivial_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_not.hpp>", private, "<boost/type_traits/has_trivial_copy.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_not.hpp>", private, "<boost/type_traits/has_trivial_move_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_not.hpp>", private, "<boost/type_traits/has_trivial_move_constructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_not.hpp>", private, "<boost/type_traits/ice.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_not.hpp>", private, "<boost/type_traits/is_class.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_not.hpp>", private, "<boost/type_traits/is_compound.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_not.hpp>", private, "<boost/type_traits/is_empty.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_not.hpp>", private, "<boost/type_traits/is_member_object_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_not.hpp>", private, "<boost/type_traits/is_nothrow_move_assignable.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_not.hpp>", private, "<boost/type_traits/is_object.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_not.hpp>", private, "<boost/type_traits/is_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_not.hpp>", private, "<boost/type_traits/make_signed.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_not.hpp>", private, "<boost/type_traits/make_unsigned.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/has_new_operator.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/has_trivial_assign.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/has_trivial_constructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/has_trivial_copy.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/has_trivial_destructor.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/ice.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/is_arithmetic.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/is_base_of.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/is_base_of_tr1.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/is_empty.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/is_fundamental.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/is_member_function_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/is_member_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/is_nothrow_move_assignable.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/is_nothrow_move_constructible.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/is_pod.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/is_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/is_scalar.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/is_signed.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/is_unsigned.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/make_signed.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/type_traits/make_unsigned.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/ice_or.hpp>", private, "<boost/units/scaled_base_unit.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/is_function_ptr_helper.hpp>", private, "<boost/type_traits/is_function.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/is_function_ptr_tester.hpp>", private, "<boost/type_traits/is_function.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/is_function_ptr_tester.hpp>", private, "<boost/type_traits/is_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/is_mem_fun_pointer_impl.hpp>", private, "<boost/type_traits/is_member_function_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/is_mem_fun_pointer_tester.hpp>", private, "<boost/type_traits/is_member_function_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/is_mem_fun_pointer_tester.hpp>", private, "<boost/type_traits/is_member_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/size_t_trait_def.hpp>", private, "<boost/type_traits/alignment_of.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/size_t_trait_def.hpp>", private, "<boost/type_traits/extent.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/size_t_trait_def.hpp>", private, "<boost/type_traits/rank.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/size_t_trait_undef.hpp>", private, "<boost/type_traits/alignment_of.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/size_t_trait_undef.hpp>", private, "<boost/type_traits/extent.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/size_t_trait_undef.hpp>", private, "<boost/type_traits/rank.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/components.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/function_arity.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/function_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/function_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/is_callable_builtin.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/is_function.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/is_function_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/is_function_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/is_member_function_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/is_member_object_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/is_member_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/is_nonmember_callable_builtin.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/member_function_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/member_object_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/parameter_types.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/template_arity_spec.hpp>", private, "<boost/function_types/result_type.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/add_const.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/add_cv.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/add_lvalue_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/add_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/add_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/add_rvalue_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/add_volatile.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/floating_point_promotion.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/integral_promotion.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/make_signed.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/make_unsigned.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/promote.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/remove_all_extents.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/remove_bounds.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/remove_const.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/remove_cv.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/remove_extent.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/remove_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/remove_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_def.hpp>", private, "<boost/type_traits/remove_volatile.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/add_const.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/add_cv.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/add_lvalue_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/add_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/add_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/add_rvalue_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/add_volatile.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/broken_compiler_spec.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/floating_point_promotion.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/integral_promotion.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/make_signed.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/make_unsigned.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/promote.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/remove_all_extents.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/remove_bounds.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/remove_const.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/remove_cv.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/remove_extent.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/remove_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/remove_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/type_trait_undef.hpp>", private, "<boost/type_traits/remove_volatile.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/wrap.hpp>", private, "<boost/type_traits/is_array.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/wrap.hpp>", private, "<boost/type_traits/is_lvalue_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/assign/list_of.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/assign/ptr_list_of.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/ptr_container/nullable.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/has_new_operator.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/ice.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/is_abstract.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/is_array.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/is_class.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/is_const.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/is_convertible.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/is_copy_constructible.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/is_function.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/is_lvalue_reference.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/is_member_function_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/is_member_pointer.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/is_same.hpp>", public ] },
+ { include: ["<boost/type_traits/detail/yes_no_type.hpp>", private, "<boost/type_traits/is_volatile.hpp>", public ] },
+ { include: ["<boost/units/detail/absolute_impl.hpp>", private, "<boost/units/absolute.hpp>", public ] },
+ { include: ["<boost/units/detail/cmath_impl.hpp>", private, "<boost/units/cmath.hpp>", public ] },
+ { include: ["<boost/units/detail/conversion_impl.hpp>", private, "<boost/units/conversion.hpp>", public ] },
+ { include: ["<boost/units/detail/dimension_impl.hpp>", private, "<boost/units/dimension.hpp>", public ] },
+ { include: ["<boost/units/detail/dimensionless_unit.hpp>", private, "<boost/units/cmath.hpp>", public ] },
+ { include: ["<boost/units/detail/dimensionless_unit.hpp>", private, "<boost/units/lambda.hpp>", public ] },
+ { include: ["<boost/units/detail/dimensionless_unit.hpp>", private, "<boost/units/quantity.hpp>", public ] },
+ { include: ["<boost/units/detail/dimension_list.hpp>", private, "<boost/units/base_dimension.hpp>", public ] },
+ { include: ["<boost/units/detail/dimension_list.hpp>", private, "<boost/units/base_unit.hpp>", public ] },
+ { include: ["<boost/units/detail/dimension_list.hpp>", private, "<boost/units/derived_dimension.hpp>", public ] },
+ { include: ["<boost/units/detail/dimension_list.hpp>", private, "<boost/units/dimension.hpp>", public ] },
+ { include: ["<boost/units/detail/dimension_list.hpp>", private, "<boost/units/make_system.hpp>", public ] },
+ { include: ["<boost/units/detail/dim_impl.hpp>", private, "<boost/units/dim.hpp>", public ] },
+ { include: ["<boost/units/detail/linear_algebra.hpp>", private, "<boost/units/heterogeneous_system.hpp>", public ] },
+ { include: ["<boost/units/detail/linear_algebra.hpp>", private, "<boost/units/homogeneous_system.hpp>", public ] },
+ { include: ["<boost/units/detail/one.hpp>", private, "<boost/units/scale.hpp>", public ] },
+ { include: ["<boost/units/detail/ordinal.hpp>", private, "<boost/units/base_dimension.hpp>", public ] },
+ { include: ["<boost/units/detail/ordinal.hpp>", private, "<boost/units/base_unit.hpp>", public ] },
+ { include: ["<boost/units/detail/prevent_redefinition.hpp>", private, "<boost/units/base_dimension.hpp>", public ] },
+ { include: ["<boost/units/detail/prevent_redefinition.hpp>", private, "<boost/units/base_unit.hpp>", public ] },
+ { include: ["<boost/units/detail/push_front_if.hpp>", private, "<boost/units/heterogeneous_system.hpp>", public ] },
+ { include: ["<boost/units/detail/push_front_or_add.hpp>", private, "<boost/units/heterogeneous_system.hpp>", public ] },
+ { include: ["<boost/units/detail/sort.hpp>", private, "<boost/units/make_system.hpp>", public ] },
+ { include: ["<boost/units/detail/static_rational_power.hpp>", private, "<boost/units/pow.hpp>", public ] },
+ { include: ["<boost/units/detail/static_rational_power.hpp>", private, "<boost/units/scale.hpp>", public ] },
+ { include: ["<boost/units/detail/unscale.hpp>", private, "<boost/units/heterogeneous_system.hpp>", public ] },
+ { include: ["<boost/units/detail/utility.hpp>", private, "<boost/units/io.hpp>", public ] },
+ { include: ["<boost/units/systems/detail/constants.hpp>", private, "<boost/units/systems/si/codata/alpha_constants.hpp>", public ] },
+ { include: ["<boost/units/systems/detail/constants.hpp>", private, "<boost/units/systems/si/codata/deuteron_constants.hpp>", public ] },
+ { include: ["<boost/units/systems/detail/constants.hpp>", private, "<boost/units/systems/si/codata/electromagnetic_constants.hpp>", public ] },
+ { include: ["<boost/units/systems/detail/constants.hpp>", private, "<boost/units/systems/si/codata/electron_constants.hpp>", public ] },
+ { include: ["<boost/units/systems/detail/constants.hpp>", private, "<boost/units/systems/si/codata/helion_constants.hpp>", public ] },
+ { include: ["<boost/units/systems/detail/constants.hpp>", private, "<boost/units/systems/si/codata/muon_constants.hpp>", public ] },
+ { include: ["<boost/units/systems/detail/constants.hpp>", private, "<boost/units/systems/si/codata/neutron_constants.hpp>", public ] },
+ { include: ["<boost/units/systems/detail/constants.hpp>", private, "<boost/units/systems/si/codata/physico-chemical_constants.hpp>", public ] },
+ { include: ["<boost/units/systems/detail/constants.hpp>", private, "<boost/units/systems/si/codata/proton_constants.hpp>", public ] },
+ { include: ["<boost/units/systems/detail/constants.hpp>", private, "<boost/units/systems/si/codata/tau_constants.hpp>", public ] },
+ { include: ["<boost/units/systems/detail/constants.hpp>", private, "<boost/units/systems/si/codata/triton_constants.hpp>", public ] },
+ { include: ["<boost/units/systems/detail/constants.hpp>", private, "<boost/units/systems/si/codata/universal_constants.hpp>", public ] },
+ { include: ["<boost/unordered/detail/equivalent.hpp>", private, "<boost/unordered/unordered_map.hpp>", public ] },
+ { include: ["<boost/unordered/detail/equivalent.hpp>", private, "<boost/unordered/unordered_set.hpp>", public ] },
+ { include: ["<boost/unordered/detail/fwd.hpp>", private, "<boost/unordered/unordered_map_fwd.hpp>", public ] },
+ { include: ["<boost/unordered/detail/fwd.hpp>", private, "<boost/unordered/unordered_set_fwd.hpp>", public ] },
+ { include: ["<boost/unordered/detail/unique.hpp>", private, "<boost/unordered/unordered_map.hpp>", public ] },
+ { include: ["<boost/unordered/detail/unique.hpp>", private, "<boost/unordered/unordered_set.hpp>", public ] },
+ { include: ["<boost/unordered/detail/util.hpp>", private, "<boost/unordered/unordered_map.hpp>", public ] },
+ { include: ["<boost/unordered/detail/util.hpp>", private, "<boost/unordered/unordered_set.hpp>", public ] },
+ { include: ["<boost/utility/detail/in_place_factory_prefix.hpp>", private, "<boost/utility/in_place_factory.hpp>", public ] },
+ { include: ["<boost/utility/detail/in_place_factory_prefix.hpp>", private, "<boost/utility/typed_in_place_factory.hpp>", public ] },
+ { include: ["<boost/utility/detail/in_place_factory_suffix.hpp>", private, "<boost/utility/in_place_factory.hpp>", public ] },
+ { include: ["<boost/utility/detail/in_place_factory_suffix.hpp>", private, "<boost/utility/typed_in_place_factory.hpp>", public ] },
+ { include: ["<boost/variant/detail/apply_visitor_binary.hpp>", private, "<boost/variant/apply_visitor.hpp>", public ] },
+ { include: ["<boost/variant/detail/apply_visitor_delayed.hpp>", private, "<boost/variant/apply_visitor.hpp>", public ] },
+ { include: ["<boost/variant/detail/apply_visitor_unary.hpp>", private, "<boost/variant/apply_visitor.hpp>", public ] },
+ { include: ["<boost/variant/detail/backup_holder.hpp>", private, "<boost/variant/variant.hpp>", public ] },
+ { include: ["<boost/variant/detail/config.hpp>", private, "<boost/variant/variant_fwd.hpp>", public ] },
+ { include: ["<boost/variant/detail/config.hpp>", private, "<boost/variant/variant.hpp>", public ] },
+ { include: ["<boost/variant/detail/enable_recursive_fwd.hpp>", private, "<boost/variant/variant.hpp>", public ] },
+ { include: ["<boost/variant/detail/enable_recursive.hpp>", private, "<boost/variant/recursive_variant.hpp>", public ] },
+ { include: ["<boost/variant/detail/forced_return.hpp>", private, "<boost/variant/variant.hpp>", public ] },
+ { include: ["<boost/variant/detail/generic_result_type.hpp>", private, "<boost/variant/variant.hpp>", public ] },
+ { include: ["<boost/variant/detail/hash_variant.hpp>", private, "<boost/variant/variant.hpp>", public ] },
+ { include: ["<boost/variant/detail/initializer.hpp>", private, "<boost/variant/variant.hpp>", public ] },
+ { include: ["<boost/variant/detail/make_variant_list.hpp>", private, "<boost/variant/recursive_variant.hpp>", public ] },
+ { include: ["<boost/variant/detail/make_variant_list.hpp>", private, "<boost/variant/variant.hpp>", public ] },
+ { include: ["<boost/variant/detail/move.hpp>", private, "<boost/variant/recursive_wrapper.hpp>", public ] },
+ { include: ["<boost/variant/detail/move.hpp>", private, "<boost/variant/variant.hpp>", public ] },
+ { include: ["<boost/variant/detail/over_sequence.hpp>", private, "<boost/variant/recursive_variant.hpp>", public ] },
+ { include: ["<boost/variant/detail/over_sequence.hpp>", private, "<boost/variant/variant.hpp>", public ] },
+ { include: ["<boost/variant/detail/substitute_fwd.hpp>", private, "<boost/variant/recursive_variant.hpp>", public ] },
+ { include: ["<boost/variant/detail/substitute_fwd.hpp>", private, "<boost/variant/variant_fwd.hpp>", public ] },
+ { include: ["<boost/variant/detail/variant_io.hpp>", private, "<boost/variant/variant.hpp>", public ] },
+ { include: ["<boost/variant/detail/visitation_impl.hpp>", private, "<boost/variant/variant.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/access.hpp>", private, "<boost/xpressive/regex_iterator.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/icase.hpp>", private, "<boost/xpressive/regex_primitives.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/linker.hpp>", private, "<boost/xpressive/regex_compiler.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/matcher/action_matcher.hpp>", private, "<boost/xpressive/regex_actions.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/matcher/attr_begin_matcher.hpp>", private, "<boost/xpressive/regex_actions.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/matcher/attr_end_matcher.hpp>", private, "<boost/xpressive/regex_actions.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/matcher/attr_matcher.hpp>", private, "<boost/xpressive/regex_actions.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/matcher/predicate_matcher.hpp>", private, "<boost/xpressive/regex_actions.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/matchers.hpp>", private, "<boost/xpressive/regex_primitives.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/optimize.hpp>", private, "<boost/xpressive/regex_compiler.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/regex_domain.hpp>", private, "<boost/xpressive/basic_regex.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/regex_domain.hpp>", private, "<boost/xpressive/regex_primitives.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/regex_impl.hpp>", private, "<boost/xpressive/basic_regex.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/regex_impl.hpp>", private, "<boost/xpressive/match_results.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/results_cache.hpp>", private, "<boost/xpressive/match_results.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/regex_actions.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/state.hpp>", private, "<boost/xpressive/regex_algorithms.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/core/sub_match_vector.hpp>", private, "<boost/xpressive/match_results.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/basic_regex.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/match_results.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/regex_actions.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/regex_algorithms.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/regex_iterator.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/regex_primitives.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/regex_traits.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/sub_match.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/traits/cpp_regex_traits.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/traits/null_regex_traits.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/detail_fwd.hpp>", private, "<boost/xpressive/xpressive_typeof.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/dynamic/parse_charset.hpp>", private, "<boost/xpressive/regex_compiler.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/dynamic/parser_enum.hpp>", private, "<boost/xpressive/regex_compiler.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/dynamic/parser.hpp>", private, "<boost/xpressive/regex_compiler.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/dynamic/parser_traits.hpp>", private, "<boost/xpressive/regex_compiler.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/static/compile.hpp>", private, "<boost/xpressive/regex_primitives.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/static/grammar.hpp>", private, "<boost/xpressive/basic_regex.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/static/modifier.hpp>", private, "<boost/xpressive/regex_primitives.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/static/type_traits.hpp>", private, "<boost/xpressive/regex_actions.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/utility/algorithm.hpp>", private, "<boost/xpressive/match_results.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/utility/counted_base.hpp>", private, "<boost/xpressive/match_results.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/utility/counted_base.hpp>", private, "<boost/xpressive/regex_iterator.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/utility/ignore_unused.hpp>", private, "<boost/xpressive/regex_actions.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/utility/ignore_unused.hpp>", private, "<boost/xpressive/regex_primitives.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/utility/ignore_unused.hpp>", private, "<boost/xpressive/traits/null_regex_traits.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/utility/literals.hpp>", private, "<boost/xpressive/match_results.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/utility/literals.hpp>", private, "<boost/xpressive/traits/cpp_regex_traits.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/utility/never_true.hpp>", private, "<boost/xpressive/traits/null_regex_traits.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/utility/save_restore.hpp>", private, "<boost/xpressive/regex_algorithms.hpp>", public ] },
+ { include: ["<boost/xpressive/detail/utility/sequence_stack.hpp>", private, "<boost/xpressive/match_results.hpp>", public ] },
+ { include: ["<boost/xpressive/traits/detail/c_ctype.hpp>", private, "<boost/xpressive/traits/c_regex_traits.hpp>", public ] }
+]
diff --git a/src/arrow/cpp/build-support/iwyu/mappings/boost-extra.imp b/src/arrow/cpp/build-support/iwyu/mappings/boost-extra.imp
new file mode 100644
index 000000000..aba1e4191
--- /dev/null
+++ b/src/arrow/cpp/build-support/iwyu/mappings/boost-extra.imp
@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+[
+ { include: ["<boost/core/explicit_operator_bool.hpp>", private, "<boost/optional/optional.hpp>", public ] },
+ { include: ["<boost/cstdint.hpp>", private, "<cstdint>", public ] },
+ { include: ["<boost/none.hpp>", private, "<boost/optional/optional.hpp>", public ] },
+ { include: ["<boost/optional/detail/optional_relops.hpp>", private, "<boost/optional/optional.hpp>", public ] },
+ { include: ["<boost/optional/detail/optional_reference_spec.hpp>", private, "<boost/optional/optional.hpp>", public ] }
+]
diff --git a/src/arrow/cpp/build-support/iwyu/mappings/gflags.imp b/src/arrow/cpp/build-support/iwyu/mappings/gflags.imp
new file mode 100644
index 000000000..46ce63d1e
--- /dev/null
+++ b/src/arrow/cpp/build-support/iwyu/mappings/gflags.imp
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+[
+ # <gflags/gflags_declare.h> confuses the IWYU tool because of the 'using '
+ { symbol: [ "fLS::clstring", private, "<string>", public ] }
+]
diff --git a/src/arrow/cpp/build-support/iwyu/mappings/glog.imp b/src/arrow/cpp/build-support/iwyu/mappings/glog.imp
new file mode 100644
index 000000000..08c5e3529
--- /dev/null
+++ b/src/arrow/cpp/build-support/iwyu/mappings/glog.imp
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+[
+ { symbol: [ "LOG", private, "<glog/logging.h>", public ] },
+ { symbol: [ "VLOG", private, "<glog/logging.h>", public ] },
+ { symbol: [ "CHECK_EQ", private, "<glog/logging.h>", public ] },
+ { symbol: [ "CHECK_NE", private, "<glog/logging.h>", public ] },
+ { symbol: [ "CHECK_LT", private, "<glog/logging.h>", public ] },
+ { symbol: [ "CHECK_GE", private, "<glog/logging.h>", public ] },
+ { symbol: [ "CHECK_GT", private, "<glog/logging.h>", public ] },
+ { symbol: [ "ErrnoLogMessage", private, "<glog/logging.h>", public ] },
+ { symbol: [ "COMPACT_GOOGLE_LOG_0", private, "<glog/logging.h>", public ] }
+]
diff --git a/src/arrow/cpp/build-support/iwyu/mappings/gmock.imp b/src/arrow/cpp/build-support/iwyu/mappings/gmock.imp
new file mode 100644
index 000000000..76e7cafdd
--- /dev/null
+++ b/src/arrow/cpp/build-support/iwyu/mappings/gmock.imp
@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#include
+#include
+
+[
+ { include: [ "<gmock/gmock-generated-matchers.h>", private, "<gmock/gmock.h>", public ] },
+ { include: [ "<gmock/gmock-matchers.h>", private, "<gmock/gmock.h>", public ] }
+] \ No newline at end of file
diff --git a/src/arrow/cpp/build-support/iwyu/mappings/gtest.imp b/src/arrow/cpp/build-support/iwyu/mappings/gtest.imp
new file mode 100644
index 000000000..a54165027
--- /dev/null
+++ b/src/arrow/cpp/build-support/iwyu/mappings/gtest.imp
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+[
+ { include: [ "<gtest/internal/gtest-internal.h>", private, "<gtest/gtest.h>", public ] },
+ { include: [ "<gtest/internal/gtest-string.h>", private, "<gtest/gtest.h>", public ] },
+ { include: [ "<gtest/gtest-death-test.h>", private, "<gtest/gtest.h>", public ] },
+ { include: [ "<gtest/gtest-message.h>", private, "<gtest/gtest.h>", public ] },
+ { include: [ "<gtest/gtest-param-test.h>", private, "<gtest/gtest.h>", public ] },
+ { include: [ "<gtest/gtest-printers.h>", private, "<gtest/gtest.h>", public ] },
+ { include: [ "<gtest/gtest-test-part.h>", private, "<gtest/gtest.h>", public ] },
+ { include: [ "<gtest/gtest-typed-test.h>", private, "<gtest/gtest.h>", public ] }
+]
diff --git a/src/arrow/cpp/build-support/lint_cpp_cli.py b/src/arrow/cpp/build-support/lint_cpp_cli.py
new file mode 100755
index 000000000..a0eb8f0ef
--- /dev/null
+++ b/src/arrow/cpp/build-support/lint_cpp_cli.py
@@ -0,0 +1,128 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import argparse
+import re
+import os
+
+parser = argparse.ArgumentParser(
+ description="Check for illegal headers for C++/CLI applications")
+parser.add_argument("source_path",
+ help="Path to source code")
+arguments = parser.parse_args()
+
+
+_STRIP_COMMENT_REGEX = re.compile('(.+)?(?=//)')
+_NULLPTR_REGEX = re.compile(r'.*\bnullptr\b.*')
+_RETURN_NOT_OK_REGEX = re.compile(r'.*\sRETURN_NOT_OK.*')
+_ASSIGN_OR_RAISE_REGEX = re.compile(r'.*\sASSIGN_OR_RAISE.*')
+
+
+def _paths(paths):
+ return [p.strip().replace('/', os.path.sep) for p in paths.splitlines()]
+
+
+def _strip_comments(line):
+ m = _STRIP_COMMENT_REGEX.match(line)
+ if not m:
+ return line
+ else:
+ return m.group(0)
+
+
+def lint_file(path):
+ fail_rules = [
+ # rule, error message, rule-specific exclusions list
+ (lambda x: '<mutex>' in x, 'Uses <mutex>', []),
+ (lambda x: '<iostream>' in x, 'Uses <iostream>', []),
+ (lambda x: re.match(_NULLPTR_REGEX, x), 'Uses nullptr', []),
+ (lambda x: re.match(_RETURN_NOT_OK_REGEX, x),
+ 'Use ARROW_RETURN_NOT_OK in header files', _paths('''\
+ arrow/status.h
+ test
+ arrow/util/hash.h
+ arrow/python/util''')),
+ (lambda x: re.match(_ASSIGN_OR_RAISE_REGEX, x),
+ 'Use ARROW_ASSIGN_OR_RAISE in header files', _paths('''\
+ arrow/result_internal.h
+ test
+ '''))
+
+ ]
+
+ with open(path) as f:
+ for i, line in enumerate(f):
+ stripped_line = _strip_comments(line)
+ for rule, why, rule_exclusions in fail_rules:
+ if any([True for excl in rule_exclusions if excl in path]):
+ continue
+
+ if rule(stripped_line):
+ yield path, why, i, line
+
+
+EXCLUSIONS = _paths('''\
+ arrow/arrow-config.cmake
+ arrow/python/iterators.h
+ arrow/util/hashing.h
+ arrow/util/macros.h
+ arrow/util/parallel.h
+ arrow/vendored
+ arrow/visitor_inline.h
+ gandiva/cache.h
+ gandiva/jni
+ jni/
+ test
+ internal
+ _generated''')
+
+
+def lint_files():
+ for dirpath, _, filenames in os.walk(arguments.source_path):
+ for filename in filenames:
+ full_path = os.path.join(dirpath, filename)
+
+ exclude = False
+ for exclusion in EXCLUSIONS:
+ if exclusion in full_path:
+ exclude = True
+ break
+
+ if exclude:
+ continue
+
+ # Lint file name, except for pkg-config templates
+ if not filename.endswith('.pc.in'):
+ if '-' in filename:
+ why = ("Please use underscores, not hyphens, "
+ "in source file names")
+ yield full_path, why, 0, full_path
+
+ # Only run on header files
+ if filename.endswith('.h'):
+ for _ in lint_file(full_path):
+ yield _
+
+
+if __name__ == '__main__':
+ failures = list(lint_files())
+ for path, why, i, line in failures:
+ print('File {0} failed C++/CLI lint check: {1}\n'
+ 'Line {2}: {3}'.format(path, why, i + 1, line))
+ if failures:
+ exit(1)
diff --git a/src/arrow/cpp/build-support/lint_exclusions.txt b/src/arrow/cpp/build-support/lint_exclusions.txt
new file mode 100644
index 000000000..4feb8fbe1
--- /dev/null
+++ b/src/arrow/cpp/build-support/lint_exclusions.txt
@@ -0,0 +1,15 @@
+*_generated*
+*.grpc.fb.*
+*apidoc/*
+*arrowExports.cpp*
+*build_support/*
+*parquet_constants.*
+*parquet_types.*
+*pyarrow_api.h
+*pyarrow_lib.h
+*python/config.h
+*python/platform.h
+*RcppExports.cpp*
+*thirdparty/*
+*vendored/*
+*windows_compatibility.h
diff --git a/src/arrow/cpp/build-support/lintutils.py b/src/arrow/cpp/build-support/lintutils.py
new file mode 100644
index 000000000..2386eb2e6
--- /dev/null
+++ b/src/arrow/cpp/build-support/lintutils.py
@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import multiprocessing as mp
+import os
+from fnmatch import fnmatch
+from subprocess import Popen
+
+
+def chunk(seq, n):
+ """
+ divide a sequence into equal sized chunks
+ (the last chunk may be smaller, but won't be empty)
+ """
+ chunks = []
+ some = []
+ for element in seq:
+ if len(some) == n:
+ chunks.append(some)
+ some = []
+ some.append(element)
+ if len(some) > 0:
+ chunks.append(some)
+ return chunks
+
+
+def dechunk(chunks):
+ "flatten chunks into a single list"
+ seq = []
+ for chunk in chunks:
+ seq.extend(chunk)
+ return seq
+
+
+def run_parallel(cmds, **kwargs):
+ """
+ Run each of cmds (with shared **kwargs) using subprocess.Popen
+ then wait for all of them to complete.
+ Runs batches of multiprocessing.cpu_count() * 2 from cmds
+ returns a list of tuples containing each process'
+ returncode, stdout, stderr
+ """
+ complete = []
+ for cmds_batch in chunk(cmds, mp.cpu_count() * 2):
+ procs_batch = [Popen(cmd, **kwargs) for cmd in cmds_batch]
+ for proc in procs_batch:
+ stdout, stderr = proc.communicate()
+ complete.append((proc.returncode, stdout, stderr))
+ return complete
+
+
+_source_extensions = '''
+.h
+.cc
+.cpp
+'''.split()
+
+
+def get_sources(source_dir, exclude_globs=[]):
+ sources = []
+ for directory, subdirs, basenames in os.walk(source_dir):
+ for path in [os.path.join(directory, basename)
+ for basename in basenames]:
+ # filter out non-source files
+ if os.path.splitext(path)[1] not in _source_extensions:
+ continue
+
+ path = os.path.abspath(path)
+
+ # filter out files that match the globs in the globs file
+ if any([fnmatch(path, glob) for glob in exclude_globs]):
+ continue
+
+ sources.append(path)
+ return sources
+
+
+def stdout_pathcolonline(completed_process, filenames):
+ """
+ given a completed process which may have reported some files as problematic
+ by printing the path name followed by ':' then a line number, examine
+ stdout and return the set of actually reported file names
+ """
+ returncode, stdout, stderr = completed_process
+ bfilenames = set()
+ for filename in filenames:
+ bfilenames.add(filename.encode('utf-8') + b':')
+ problem_files = set()
+ for line in stdout.splitlines():
+ for filename in bfilenames:
+ if line.startswith(filename):
+ problem_files.add(filename.decode('utf-8'))
+ bfilenames.remove(filename)
+ break
+ return problem_files, stdout
diff --git a/src/arrow/cpp/build-support/lsan-suppressions.txt b/src/arrow/cpp/build-support/lsan-suppressions.txt
new file mode 100644
index 000000000..566857a9c
--- /dev/null
+++ b/src/arrow/cpp/build-support/lsan-suppressions.txt
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# False positive from atexit() registration in libc
+leak:*__new_exitfn*
+# Leak at shutdown in OpenSSL
+leak:CRYPTO_zalloc
diff --git a/src/arrow/cpp/build-support/run-infer.sh b/src/arrow/cpp/build-support/run-infer.sh
new file mode 100755
index 000000000..7d1853437
--- /dev/null
+++ b/src/arrow/cpp/build-support/run-infer.sh
@@ -0,0 +1,48 @@
+#!/usr/bin/env bash
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+#
+# Runs infer in the given directory
+# Arguments:
+# $1 - Path to the infer binary
+# $2 - Path to the compile_commands.json to use
+# $3 - Apply infer step (1=capture, 2=analyze, 3=report)
+#
+INFER=$1
+shift
+COMPILE_COMMANDS=$1
+shift
+APPLY_STEP=$1
+shift
+
+if [ "$APPLY_STEP" == "1" ]; then
+ $INFER capture --compilation-database $COMPILE_COMMANDS
+ echo ""
+ echo "Run 'make infer-analyze' next."
+elif [ "$APPLY_STEP" == "2" ]; then
+ # infer's analyze step can take a very long time to complete
+ $INFER analyze
+ echo ""
+ echo "Run 'make infer-report' next."
+ echo "See: http://fbinfer.com/docs/steps-for-ci.html"
+elif [ "$APPLY_STEP" == "3" ]; then
+ $INFER report --issues-csv ./infer-out/report.csv 1> /dev/null
+ $INFER report --issues-txt ./infer-out/report.txt 1> /dev/null
+ $INFER report --issues-json ./infer-out/report.json 1> /dev/null
+ echo ""
+ echo "Reports (report.txt, report.csv, report.json) can be found in the infer-out subdirectory."
+else
+ echo ""
+ echo "See: http://fbinfer.com/docs/steps-for-ci.html"
+fi
diff --git a/src/arrow/cpp/build-support/run-test.sh b/src/arrow/cpp/build-support/run-test.sh
new file mode 100755
index 000000000..d2d327cfd
--- /dev/null
+++ b/src/arrow/cpp/build-support/run-test.sh
@@ -0,0 +1,237 @@
+#!/usr/bin/env bash
+# Copyright 2014 Cloudera, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Script which wraps running a test and redirects its output to a
+# test log directory.
+#
+# Arguments:
+# $1 - Base path for logs/artifacts.
+# $2 - type of test (e.g. test or benchmark)
+# $3 - path to executable
+# $ARGN - arguments for executable
+#
+
+OUTPUT_ROOT=$1
+shift
+ROOT=$(cd $(dirname $BASH_SOURCE)/..; pwd)
+
+TEST_LOGDIR=$OUTPUT_ROOT/build/$1-logs
+mkdir -p $TEST_LOGDIR
+
+RUN_TYPE=$1
+shift
+TEST_DEBUGDIR=$OUTPUT_ROOT/build/$RUN_TYPE-debug
+mkdir -p $TEST_DEBUGDIR
+
+TEST_DIRNAME=$(cd $(dirname $1); pwd)
+TEST_FILENAME=$(basename $1)
+shift
+TEST_EXECUTABLE="$TEST_DIRNAME/$TEST_FILENAME"
+TEST_NAME=$(echo $TEST_FILENAME | perl -pe 's/\..+?$//') # Remove path and extension (if any).
+
+# We run each test in its own subdir to avoid core file related races.
+TEST_WORKDIR=$OUTPUT_ROOT/build/test-work/$TEST_NAME
+mkdir -p $TEST_WORKDIR
+pushd $TEST_WORKDIR >/dev/null || exit 1
+rm -f *
+
+set -o pipefail
+
+LOGFILE=$TEST_LOGDIR/$TEST_NAME.txt
+XMLFILE=$TEST_LOGDIR/$TEST_NAME.xml
+
+TEST_EXECUTION_ATTEMPTS=1
+
+# Remove both the uncompressed output, so the developer doesn't accidentally get confused
+# and read output from a prior test run.
+rm -f $LOGFILE $LOGFILE.gz
+
+pipe_cmd=cat
+
+function setup_sanitizers() {
+ # Sets environment variables for different sanitizers (it configures how) the run_tests. Function works.
+
+ # Configure TSAN (ignored if this isn't a TSAN build).
+ #
+ TSAN_OPTIONS="$TSAN_OPTIONS suppressions=$ROOT/build-support/tsan-suppressions.txt"
+ TSAN_OPTIONS="$TSAN_OPTIONS history_size=7"
+ # Some tests deliberately fail allocating memory
+ TSAN_OPTIONS="$TSAN_OPTIONS allocator_may_return_null=1"
+ export TSAN_OPTIONS
+
+ UBSAN_OPTIONS="$UBSAN_OPTIONS print_stacktrace=1"
+ UBSAN_OPTIONS="$UBSAN_OPTIONS suppressions=$ROOT/build-support/ubsan-suppressions.txt"
+ export UBSAN_OPTIONS
+
+ # Enable leak detection even under LLVM 3.4, where it was disabled by default.
+ # This flag only takes effect when running an ASAN build.
+ # ASAN_OPTIONS="$ASAN_OPTIONS detect_leaks=1"
+ # export ASAN_OPTIONS
+
+ # Set up suppressions for LeakSanitizer
+ LSAN_OPTIONS="$LSAN_OPTIONS suppressions=$ROOT/build-support/lsan-suppressions.txt"
+ export LSAN_OPTIONS
+}
+
+function run_test() {
+ # Run gtest style tests with sanitizers if they are setup appropriately.
+
+ # gtest won't overwrite old junit test files, resulting in a build failure
+ # even when retries are successful.
+ rm -f $XMLFILE
+
+ $TEST_EXECUTABLE "$@" > $LOGFILE.raw 2>&1
+ STATUS=$?
+ cat $LOGFILE.raw \
+ | ${PYTHON:-python} $ROOT/build-support/asan_symbolize.py \
+ | ${CXXFILT:-c++filt} \
+ | $ROOT/build-support/stacktrace_addr2line.pl $TEST_EXECUTABLE \
+ | $pipe_cmd 2>&1 | tee $LOGFILE
+ rm -f $LOGFILE.raw
+
+ # TSAN doesn't always exit with a non-zero exit code due to a bug:
+ # mutex errors don't get reported through the normal error reporting infrastructure.
+ # So we make sure to detect this and exit 1.
+ #
+ # Additionally, certain types of failures won't show up in the standard JUnit
+ # XML output from gtest. We assume that gtest knows better than us and our
+ # regexes in most cases, but for certain errors we delete the resulting xml
+ # file and let our own post-processing step regenerate it.
+ export GREP=$(which egrep)
+ if zgrep --silent "ThreadSanitizer|Leak check.*detected leaks" $LOGFILE ; then
+ echo ThreadSanitizer or leak check failures in $LOGFILE
+ STATUS=1
+ rm -f $XMLFILE
+ fi
+}
+
+function print_coredumps() {
+ # The script expects core files relative to the build directory with unique
+ # names per test executable because of the parallel running. So the corefile
+ # patterns must be set with prefix `core.{test-executable}*`:
+ #
+ # In case of macOS:
+ # sudo sysctl -w kern.corefile=core.%N.%P
+ # On Linux:
+ # sudo sysctl -w kernel.core_pattern=core.%e.%p
+ #
+ # and the ulimit must be increased:
+ # ulimit -c unlimited
+
+ # filename is truncated to the first 15 characters in case of linux, so limit
+ # the pattern for the first 15 characters
+ FILENAME=$(basename "${TEST_EXECUTABLE}")
+ FILENAME=$(echo ${FILENAME} | cut -c-15)
+ PATTERN="^core\.${FILENAME}"
+
+ COREFILES=$(ls | grep $PATTERN)
+ if [ -n "$COREFILES" ]; then
+ echo "Found core dump, printing backtrace:"
+
+ for COREFILE in $COREFILES; do
+ # Print backtrace
+ if [ "$(uname)" == "Darwin" ]; then
+ lldb -c "${COREFILE}" --batch --one-line "thread backtrace all -e true"
+ else
+ gdb -c "${COREFILE}" $TEST_EXECUTABLE -ex "thread apply all bt" -ex "set pagination 0" -batch
+ fi
+ # Remove the coredump, regenerate it via running the test case directly
+ rm "${COREFILE}"
+ done
+ fi
+}
+
+function post_process_tests() {
+ # If we have a LeakSanitizer report, and XML reporting is configured, add a new test
+ # case result to the XML file for the leak report. Otherwise Jenkins won't show
+ # us which tests had LSAN errors.
+ if zgrep --silent "ERROR: LeakSanitizer: detected memory leaks" $LOGFILE ; then
+ echo Test had memory leaks. Editing XML
+ perl -p -i -e '
+ if (m#</testsuite>#) {
+ print "<testcase name=\"LeakSanitizer\" status=\"run\" classname=\"LSAN\">\n";
+ print " <failure message=\"LeakSanitizer failed\" type=\"\">\n";
+ print " See txt log file for details\n";
+ print " </failure>\n";
+ print "</testcase>\n";
+ }' $XMLFILE
+ fi
+}
+
+function run_other() {
+ # Generic run function for test like executables that aren't actually gtest
+ $TEST_EXECUTABLE "$@" 2>&1 | $pipe_cmd > $LOGFILE
+ STATUS=$?
+}
+
+if [ $RUN_TYPE = "test" ]; then
+ setup_sanitizers
+fi
+
+# Run the actual test.
+for ATTEMPT_NUMBER in $(seq 1 $TEST_EXECUTION_ATTEMPTS) ; do
+ if [ $ATTEMPT_NUMBER -lt $TEST_EXECUTION_ATTEMPTS ]; then
+ # If the test fails, the test output may or may not be left behind,
+ # depending on whether the test cleaned up or exited immediately. Either
+ # way we need to clean it up. We do this by comparing the data directory
+ # contents before and after the test runs, and deleting anything new.
+ #
+ # The comm program requires that its two inputs be sorted.
+ TEST_TMPDIR_BEFORE=$(find $TEST_TMPDIR -maxdepth 1 -type d | sort)
+ fi
+
+ if [ $ATTEMPT_NUMBER -lt $TEST_EXECUTION_ATTEMPTS ]; then
+ # Now delete any new test output.
+ TEST_TMPDIR_AFTER=$(find $TEST_TMPDIR -maxdepth 1 -type d | sort)
+ DIFF=$(comm -13 <(echo "$TEST_TMPDIR_BEFORE") \
+ <(echo "$TEST_TMPDIR_AFTER"))
+ for DIR in $DIFF; do
+ # Multiple tests may be running concurrently. To avoid deleting the
+ # wrong directories, constrain to only directories beginning with the
+ # test name.
+ #
+ # This may delete old test directories belonging to this test, but
+ # that's not typically a concern when rerunning flaky tests.
+ if [[ $DIR =~ ^$TEST_TMPDIR/$TEST_NAME ]]; then
+ echo Deleting leftover flaky test directory "$DIR"
+ rm -Rf "$DIR"
+ fi
+ done
+ fi
+ echo "Running $TEST_NAME, redirecting output into $LOGFILE" \
+ "(attempt ${ATTEMPT_NUMBER}/$TEST_EXECUTION_ATTEMPTS)"
+ if [ $RUN_TYPE = "test" ]; then
+ run_test $*
+ else
+ run_other $*
+ fi
+ if [ "$STATUS" -eq "0" ]; then
+ break
+ elif [ "$ATTEMPT_NUMBER" -lt "$TEST_EXECUTION_ATTEMPTS" ]; then
+ echo Test failed attempt number $ATTEMPT_NUMBER
+ echo Will retry...
+ fi
+done
+
+if [ $RUN_TYPE = "test" ]; then
+ post_process_tests
+fi
+
+print_coredumps
+
+popd
+rm -Rf $TEST_WORKDIR
+
+exit $STATUS
diff --git a/src/arrow/cpp/build-support/run_clang_format.py b/src/arrow/cpp/build-support/run_clang_format.py
new file mode 100755
index 000000000..fd653a530
--- /dev/null
+++ b/src/arrow/cpp/build-support/run_clang_format.py
@@ -0,0 +1,137 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import print_function
+import lintutils
+from subprocess import PIPE
+import argparse
+import difflib
+import multiprocessing as mp
+import sys
+from functools import partial
+
+
+# examine the output of clang-format and if changes are
+# present assemble a (unified)patch of the difference
+def _check_one_file(filename, formatted):
+ with open(filename, "rb") as reader:
+ original = reader.read()
+
+ if formatted != original:
+ # Run the equivalent of diff -u
+ diff = list(difflib.unified_diff(
+ original.decode('utf8').splitlines(True),
+ formatted.decode('utf8').splitlines(True),
+ fromfile=filename,
+ tofile="{} (after clang format)".format(
+ filename)))
+ else:
+ diff = None
+
+ return filename, diff
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Runs clang-format on all of the source "
+ "files. If --fix is specified enforce format by "
+ "modifying in place, otherwise compare the output "
+ "with the existing file and output any necessary "
+ "changes as a patch in unified diff format")
+ parser.add_argument("--clang_format_binary",
+ required=True,
+ help="Path to the clang-format binary")
+ parser.add_argument("--exclude_globs",
+ help="Filename containing globs for files "
+ "that should be excluded from the checks")
+ parser.add_argument("--source_dir",
+ required=True,
+ help="Root directory of the source code")
+ parser.add_argument("--fix", default=False,
+ action="store_true",
+ help="If specified, will re-format the source "
+ "code instead of comparing the re-formatted "
+ "output, defaults to %(default)s")
+ parser.add_argument("--quiet", default=False,
+ action="store_true",
+ help="If specified, only print errors")
+ arguments = parser.parse_args()
+
+ exclude_globs = []
+ if arguments.exclude_globs:
+ with open(arguments.exclude_globs) as f:
+ exclude_globs.extend(line.strip() for line in f)
+
+ formatted_filenames = []
+ for path in lintutils.get_sources(arguments.source_dir, exclude_globs):
+ formatted_filenames.append(str(path))
+
+ if arguments.fix:
+ if not arguments.quiet:
+ print("\n".join(map(lambda x: "Formatting {}".format(x),
+ formatted_filenames)))
+
+ # Break clang-format invocations into chunks: each invocation formats
+ # 16 files. Wait for all processes to complete
+ results = lintutils.run_parallel([
+ [arguments.clang_format_binary, "-i"] + some
+ for some in lintutils.chunk(formatted_filenames, 16)
+ ])
+ for returncode, stdout, stderr in results:
+ # if any clang-format reported a parse error, bubble it
+ if returncode != 0:
+ sys.exit(returncode)
+
+ else:
+ # run an instance of clang-format for each source file in parallel,
+ # then wait for all processes to complete
+ results = lintutils.run_parallel([
+ [arguments.clang_format_binary, filename]
+ for filename in formatted_filenames
+ ], stdout=PIPE, stderr=PIPE)
+
+ checker_args = []
+ for filename, res in zip(formatted_filenames, results):
+ # if any clang-format reported a parse error, bubble it
+ returncode, stdout, stderr = res
+ if returncode != 0:
+ print(stderr)
+ sys.exit(returncode)
+ checker_args.append((filename, stdout))
+
+ error = False
+ pool = mp.Pool()
+ try:
+ # check the output from each invocation of clang-format in parallel
+ for filename, diff in pool.starmap(_check_one_file, checker_args):
+ if not arguments.quiet:
+ print("Checking {}".format(filename))
+ if diff:
+ print("{} had clang-format style issues".format(filename))
+ # Print out the diff to stderr
+ error = True
+ # pad with a newline
+ print(file=sys.stderr)
+ sys.stderr.writelines(diff)
+ except Exception:
+ error = True
+ raise
+ finally:
+ pool.terminate()
+ pool.join()
+ sys.exit(1 if error else 0)
diff --git a/src/arrow/cpp/build-support/run_clang_tidy.py b/src/arrow/cpp/build-support/run_clang_tidy.py
new file mode 100755
index 000000000..e5211be84
--- /dev/null
+++ b/src/arrow/cpp/build-support/run_clang_tidy.py
@@ -0,0 +1,124 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import print_function
+import argparse
+import multiprocessing as mp
+import lintutils
+from subprocess import PIPE
+import sys
+from functools import partial
+
+
+def _get_chunk_key(filenames):
+ # lists are not hashable so key on the first filename in a chunk
+ return filenames[0]
+
+
+# clang-tidy outputs complaints in '/path:line_number: complaint' format,
+# so we can scan its output to get a list of files to fix
+def _check_some_files(completed_processes, filenames):
+ result = completed_processes[_get_chunk_key(filenames)]
+ return lintutils.stdout_pathcolonline(result, filenames)
+
+
+def _check_all(cmd, filenames):
+ # each clang-tidy instance will process 16 files
+ chunks = lintutils.chunk(filenames, 16)
+ cmds = [cmd + some for some in chunks]
+ results = lintutils.run_parallel(cmds, stderr=PIPE, stdout=PIPE)
+ error = False
+ # record completed processes (keyed by the first filename in the input
+ # chunk) for lookup in _check_some_files
+ completed_processes = {
+ _get_chunk_key(some): result
+ for some, result in zip(chunks, results)
+ }
+ checker = partial(_check_some_files, completed_processes)
+ pool = mp.Pool()
+ try:
+ # check output of completed clang-tidy invocations in parallel
+ for problem_files, stdout in pool.imap(checker, chunks):
+ if problem_files:
+ msg = "clang-tidy suggested fixes for {}"
+ print("\n".join(map(msg.format, problem_files)))
+ error = True
+ except Exception:
+ error = True
+ raise
+ finally:
+ pool.terminate()
+ pool.join()
+
+ if error:
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Runs clang-tidy on all ")
+ parser.add_argument("--clang_tidy_binary",
+ required=True,
+ help="Path to the clang-tidy binary")
+ parser.add_argument("--exclude_globs",
+ help="Filename containing globs for files "
+ "that should be excluded from the checks")
+ parser.add_argument("--compile_commands",
+ required=True,
+ help="compile_commands.json to pass clang-tidy")
+ parser.add_argument("--source_dir",
+ required=True,
+ help="Root directory of the source code")
+ parser.add_argument("--fix", default=False,
+ action="store_true",
+ help="If specified, will attempt to fix the "
+ "source code instead of recommending fixes, "
+ "defaults to %(default)s")
+ parser.add_argument("--quiet", default=False,
+ action="store_true",
+ help="If specified, only print errors")
+ arguments = parser.parse_args()
+
+ exclude_globs = []
+ if arguments.exclude_globs:
+ for line in open(arguments.exclude_globs):
+ exclude_globs.append(line.strip())
+
+ linted_filenames = []
+ for path in lintutils.get_sources(arguments.source_dir, exclude_globs):
+ linted_filenames.append(path)
+
+ if not arguments.quiet:
+ msg = 'Tidying {}' if arguments.fix else 'Checking {}'
+ print("\n".join(map(msg.format, linted_filenames)))
+
+ cmd = [
+ arguments.clang_tidy_binary,
+ '-p',
+ arguments.compile_commands
+ ]
+ if arguments.fix:
+ cmd.append('-fix')
+ results = lintutils.run_parallel(
+ [cmd + some for some in lintutils.chunk(linted_filenames, 16)])
+ for returncode, stdout, stderr in results:
+ if returncode != 0:
+ sys.exit(returncode)
+
+ else:
+ _check_all(cmd, linted_filenames)
diff --git a/src/arrow/cpp/build-support/run_cpplint.py b/src/arrow/cpp/build-support/run_cpplint.py
new file mode 100755
index 000000000..cc98e094e
--- /dev/null
+++ b/src/arrow/cpp/build-support/run_cpplint.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import print_function
+import lintutils
+from subprocess import PIPE, STDOUT
+import argparse
+import multiprocessing as mp
+import sys
+import platform
+from functools import partial
+
+
+# NOTE(wesm):
+#
+# * readability/casting is disabled as it aggressively warns about functions
+# with names like "int32", so "int32(x)", where int32 is a function name,
+# warns with
+_filters = '''
+-whitespace/comments
+-readability/casting
+-readability/todo
+-readability/alt_tokens
+-build/header_guard
+-build/c++11
+-build/include_what_you_use
+-runtime/references
+-build/include_order
+'''.split()
+
+
+def _get_chunk_key(filenames):
+ # lists are not hashable so key on the first filename in a chunk
+ return filenames[0]
+
+
+def _check_some_files(completed_processes, filenames):
+ # cpplint outputs complaints in '/path:line_number: complaint' format,
+ # so we can scan its output to get a list of files to fix
+ result = completed_processes[_get_chunk_key(filenames)]
+ return lintutils.stdout_pathcolonline(result, filenames)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Runs cpplint on all of the source files.")
+ parser.add_argument("--cpplint_binary",
+ required=True,
+ help="Path to the cpplint binary")
+ parser.add_argument("--exclude_globs",
+ help="Filename containing globs for files "
+ "that should be excluded from the checks")
+ parser.add_argument("--source_dir",
+ required=True,
+ help="Root directory of the source code")
+ parser.add_argument("--quiet", default=False,
+ action="store_true",
+ help="If specified, only print errors")
+ arguments = parser.parse_args()
+
+ exclude_globs = []
+ if arguments.exclude_globs:
+ with open(arguments.exclude_globs) as f:
+ exclude_globs.extend(line.strip() for line in f)
+
+ linted_filenames = []
+ for path in lintutils.get_sources(arguments.source_dir, exclude_globs):
+ linted_filenames.append(str(path))
+
+ cmd = [
+ arguments.cpplint_binary,
+ '--verbose=2',
+ '--linelength=90',
+ '--filter=' + ','.join(_filters)
+ ]
+ if (arguments.cpplint_binary.endswith('.py') and
+ platform.system() == 'Windows'):
+ # Windows doesn't support executable scripts; execute with
+ # sys.executable
+ cmd.insert(0, sys.executable)
+ if arguments.quiet:
+ cmd.append('--quiet')
+ else:
+ print("\n".join(map(lambda x: "Linting {}".format(x),
+ linted_filenames)))
+
+ # lint files in chunks: each invocation of cpplint will process 16 files
+ chunks = lintutils.chunk(linted_filenames, 16)
+ cmds = [cmd + some for some in chunks]
+ results = lintutils.run_parallel(cmds, stdout=PIPE, stderr=STDOUT)
+
+ error = False
+ # record completed processes (keyed by the first filename in the input
+ # chunk) for lookup in _check_some_files
+ completed_processes = {
+ _get_chunk_key(filenames): result
+ for filenames, result in zip(chunks, results)
+ }
+ checker = partial(_check_some_files, completed_processes)
+ pool = mp.Pool()
+ try:
+ # scan the outputs of various cpplint invocations in parallel to
+ # distill a list of problematic files
+ for problem_files, stdout in pool.imap(checker, chunks):
+ if problem_files:
+ if isinstance(stdout, bytes):
+ stdout = stdout.decode('utf8')
+ print(stdout, file=sys.stderr)
+ error = True
+ except Exception:
+ error = True
+ raise
+ finally:
+ pool.terminate()
+ pool.join()
+
+ sys.exit(1 if error else 0)
diff --git a/src/arrow/cpp/build-support/sanitizer-disallowed-entries.txt b/src/arrow/cpp/build-support/sanitizer-disallowed-entries.txt
new file mode 100644
index 000000000..636cfda23
--- /dev/null
+++ b/src/arrow/cpp/build-support/sanitizer-disallowed-entries.txt
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Workaround for a problem with gmock where a runtime error is caused by a call on a null pointer,
+# on a mocked object.
+# Seen error:
+# thirdparty/gmock-1.7.0/include/gmock/gmock-spec-builders.h:1529:12: runtime error: member call on null pointer of type 'testing::internal::ActionResultHolder<void>'
+fun:*testing*internal*InvokeWith*
+
+# Workaround for RapidJSON https://github.com/Tencent/rapidjson/issues/1724
+src:*/rapidjson/internal/*
diff --git a/src/arrow/cpp/build-support/stacktrace_addr2line.pl b/src/arrow/cpp/build-support/stacktrace_addr2line.pl
new file mode 100755
index 000000000..caedc5c07
--- /dev/null
+++ b/src/arrow/cpp/build-support/stacktrace_addr2line.pl
@@ -0,0 +1,92 @@
+#!/usr/bin/env perl
+# Copyright 2014 Cloudera, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#######################################################################
+# This script will convert a stack trace with addresses:
+# @ 0x5fb015 kudu::master::Master::Init()
+# @ 0x5c2d38 kudu::master::MiniMaster::StartOnPorts()
+# @ 0x5c31fa kudu::master::MiniMaster::Start()
+# @ 0x58270a kudu::MiniCluster::Start()
+# @ 0x57dc71 kudu::CreateTableStressTest::SetUp()
+# To one with line numbers:
+# @ 0x5fb015 kudu::master::Master::Init() at /home/mpercy/src/kudu/src/master/master.cc:54
+# @ 0x5c2d38 kudu::master::MiniMaster::StartOnPorts() at /home/mpercy/src/kudu/src/master/mini_master.cc:52
+# @ 0x5c31fa kudu::master::MiniMaster::Start() at /home/mpercy/src/kudu/src/master/mini_master.cc:33
+# @ 0x58270a kudu::MiniCluster::Start() at /home/mpercy/src/kudu/src/integration-tests/mini_cluster.cc:48
+# @ 0x57dc71 kudu::CreateTableStressTest::SetUp() at /home/mpercy/src/kudu/src/integration-tests/create-table-stress-test.cc:61
+#
+# If the script detects that the output is not symbolized, it will also attempt
+# to determine the function names, i.e. it will convert:
+# @ 0x5fb015
+# @ 0x5c2d38
+# @ 0x5c31fa
+# To:
+# @ 0x5fb015 kudu::master::Master::Init() at /home/mpercy/src/kudu/src/master/master.cc:54
+# @ 0x5c2d38 kudu::master::MiniMaster::StartOnPorts() at /home/mpercy/src/kudu/src/master/mini_master.cc:52
+# @ 0x5c31fa kudu::master::MiniMaster::Start() at /home/mpercy/src/kudu/src/master/mini_master.cc:33
+#######################################################################
+use strict;
+use warnings;
+
+if (!@ARGV) {
+ die <<EOF
+Usage: $0 executable [stack-trace-file]
+
+This script will read addresses from a file containing stack traces and
+will convert the addresses that conform to the pattern " @ 0x123456" to line
+numbers by calling addr2line on the provided executable.
+If no stack-trace-file is specified, it will take input from stdin.
+EOF
+}
+
+# el6 and other older systems don't support the -p flag,
+# so we do our own "pretty" parsing.
+sub parse_addr2line_output($$) {
+ defined(my $output = shift) or die;
+ defined(my $lookup_func_name = shift) or die;
+ my @lines = grep { $_ ne '' } split("\n", $output);
+ my $pretty_str = '';
+ if ($lookup_func_name) {
+ $pretty_str .= ' ' . $lines[0];
+ }
+ $pretty_str .= ' at ' . $lines[1];
+ return $pretty_str;
+}
+
+my $binary = shift @ARGV;
+if (! -x $binary || ! -r $binary) {
+ die "Error: Cannot access executable ($binary)";
+}
+
+# Cache lookups to speed processing of files with repeated trace addresses.
+my %addr2line_map = ();
+
+# Disable stdout buffering
+$| = 1;
+
+# Reading from <ARGV> is magical in Perl.
+while (defined(my $input = <ARGV>)) {
+ if ($input =~ /^\s+\@\s+(0x[[:xdigit:]]{6,})(?:\s+(\S+))?/) {
+ my $addr = $1;
+ my $lookup_func_name = (!defined $2);
+ if (!exists($addr2line_map{$addr})) {
+ $addr2line_map{$addr} = `addr2line -ifC -e $binary $addr`;
+ }
+ chomp $input;
+ $input .= parse_addr2line_output($addr2line_map{$addr}, $lookup_func_name) . "\n";
+ }
+ print $input;
+}
+
+exit 0;
diff --git a/src/arrow/cpp/build-support/trim-boost.sh b/src/arrow/cpp/build-support/trim-boost.sh
new file mode 100755
index 000000000..477a5e965
--- /dev/null
+++ b/src/arrow/cpp/build-support/trim-boost.sh
@@ -0,0 +1,72 @@
+#!/usr/bin/env bash
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+# This script is used to make the subset of boost that we actually use,
+# so that we don't have to download the whole big boost project when we build
+# boost from source.
+#
+# To test building Arrow locally with the boost bundle this creates, add:
+#
+# set(BOOST_SOURCE_URL /path/to/arrow/cpp/build-support/boost_1_75_0/boost_1_75_0.tar.gz)
+#
+# to the beginning of the build_boost() macro in ThirdpartyToolchain.cmake,
+#
+# or set the env var ARROW_BOOST_URL before calling cmake, like:
+#
+# ARROW_BOOST_URL=/path/to/arrow/cpp/build-support/boost_1_75_0/boost_1_75_0.tar.gz cmake ...
+#
+# After running this script, upload the bundle to
+# https://github.com/ursa-labs/thirdparty/releases/edit/latest
+# TODO(ARROW-6407) automate uploading to github
+
+set -eu
+
+# if version is not defined by the caller, set a default.
+: ${BOOST_VERSION:=1.75.0}
+: ${BOOST_FILE:=boost_${BOOST_VERSION//./_}}
+: ${BOOST_URL:=https://sourceforge.net/projects/boost/files/boost/${BOOST_VERSION}/${BOOST_FILE}.tar.gz}
+
+# Arrow tests require these
+BOOST_LIBS="system.hpp filesystem.hpp process.hpp"
+# Add these to be able to build those
+BOOST_LIBS="$BOOST_LIBS config build boost_install headers log predef"
+# Gandiva needs these (and some Arrow tests do too)
+BOOST_LIBS="$BOOST_LIBS multiprecision/cpp_int.hpp"
+# These are for Thrift when Thrift_SOURCE=BUNDLED
+BOOST_LIBS="$BOOST_LIBS locale.hpp scope_exit.hpp boost/typeof/incr_registration_group.hpp"
+
+if [ ! -d ${BOOST_FILE} ]; then
+ curl -L "${BOOST_URL}" > ${BOOST_FILE}.tar.gz
+ tar -xzf ${BOOST_FILE}.tar.gz
+fi
+
+pushd ${BOOST_FILE}
+
+if [ ! -f "dist/bin/bcp" ]; then
+ ./bootstrap.sh
+ ./b2 tools/bcp
+fi
+mkdir -p ${BOOST_FILE}
+./dist/bin/bcp ${BOOST_LIBS} ${BOOST_FILE}
+
+tar -czf ${BOOST_FILE}.tar.gz ${BOOST_FILE}/
+# Resulting tarball is in ${BOOST_FILE}/${BOOST_FILE}.tar.gz
+
+popd
diff --git a/src/arrow/cpp/build-support/tsan-suppressions.txt b/src/arrow/cpp/build-support/tsan-suppressions.txt
new file mode 100644
index 000000000..ce897c859
--- /dev/null
+++ b/src/arrow/cpp/build-support/tsan-suppressions.txt
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Thread leak in CUDA
+thread:libcuda.so
diff --git a/src/arrow/cpp/build-support/ubsan-suppressions.txt b/src/arrow/cpp/build-support/ubsan-suppressions.txt
new file mode 100644
index 000000000..13a83393a
--- /dev/null
+++ b/src/arrow/cpp/build-support/ubsan-suppressions.txt
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/src/arrow/cpp/build-support/update-flatbuffers.sh b/src/arrow/cpp/build-support/update-flatbuffers.sh
new file mode 100755
index 000000000..b1116a1cb
--- /dev/null
+++ b/src/arrow/cpp/build-support/update-flatbuffers.sh
@@ -0,0 +1,50 @@
+#!/usr/bin/env bash
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+# Run this from cpp/ directory. flatc is expected to be in your path
+
+set -euo pipefail
+
+CWD="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)"
+SOURCE_DIR="$CWD/../src"
+PYTHON_SOURCE_DIR="$CWD/../../python"
+FORMAT_DIR="$CWD/../../format"
+TOP="$FORMAT_DIR/.."
+FLATC="flatc"
+
+OUT_DIR="$SOURCE_DIR/generated"
+FILES=($(find $FORMAT_DIR -name '*.fbs'))
+FILES+=("$SOURCE_DIR/arrow/ipc/feather.fbs")
+
+# add compute ir files
+FILES+=($(find "$TOP/experimental/computeir" -name '*.fbs'))
+
+$FLATC --cpp --cpp-std c++11 \
+ --scoped-enums \
+ -o "$OUT_DIR" \
+ "${FILES[@]}"
+
+PLASMA_FBS=("$SOURCE_DIR"/plasma/{plasma,common}.fbs)
+
+$FLATC --cpp --cpp-std c++11 \
+ -o "$SOURCE_DIR/plasma" \
+ --gen-object-api \
+ --scoped-enums \
+ "${PLASMA_FBS[@]}"
diff --git a/src/arrow/cpp/build-support/update-thrift.sh b/src/arrow/cpp/build-support/update-thrift.sh
new file mode 100755
index 000000000..1213a628e
--- /dev/null
+++ b/src/arrow/cpp/build-support/update-thrift.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+# Run this from cpp/ directory. thrift is expected to be in your path
+
+thrift --gen cpp -out src/generated src/parquet/parquet.thrift
diff --git a/src/arrow/cpp/build-support/vendor-flatbuffers.sh b/src/arrow/cpp/build-support/vendor-flatbuffers.sh
new file mode 100755
index 000000000..6cbf77b9c
--- /dev/null
+++ b/src/arrow/cpp/build-support/vendor-flatbuffers.sh
@@ -0,0 +1,31 @@
+#!/usr/bin/env bash
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+set -eu
+
+# Run this from cpp/ directory with $FLATBUFFERS_HOME set to location of your
+# Flatbuffers installation
+SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)"
+
+VENDOR_LOCATION=$SOURCE_DIR/../thirdparty/flatbuffers/include/flatbuffers
+mkdir -p $VENDOR_LOCATION
+cp -f $FLATBUFFERS_HOME/include/flatbuffers/base.h $VENDOR_LOCATION
+cp -f $FLATBUFFERS_HOME/include/flatbuffers/flatbuffers.h $VENDOR_LOCATION
+cp -f $FLATBUFFERS_HOME/include/flatbuffers/stl_emulation.h $VENDOR_LOCATION
diff --git a/src/arrow/cpp/cmake_modules/BuildUtils.cmake b/src/arrow/cpp/cmake_modules/BuildUtils.cmake
new file mode 100644
index 000000000..cd8290d1b
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/BuildUtils.cmake
@@ -0,0 +1,936 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Common path suffixes to be searched by find_library or find_path.
+# Windows artifacts may be found under "<root>/Library", so
+# search there as well.
+set(ARROW_LIBRARY_PATH_SUFFIXES
+ "${CMAKE_LIBRARY_ARCHITECTURE}"
+ "lib/${CMAKE_LIBRARY_ARCHITECTURE}"
+ "lib64"
+ "lib32"
+ "lib"
+ "bin"
+ "Library"
+ "Library/lib"
+ "Library/bin")
+set(ARROW_INCLUDE_PATH_SUFFIXES "include" "Library" "Library/include")
+
+set(ARROW_BOOST_PROCESS_COMPILE_DEFINITIONS)
+if(WIN32 AND CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+ # boost/process/detail/windows/handle_workaround.hpp doesn't work
+ # without BOOST_USE_WINDOWS_H with MinGW because MinGW doesn't
+ # provide __kernel_entry without winternl.h.
+ #
+ # See also:
+ # https://github.com/boostorg/process/blob/develop/include/boost/process/detail/windows/handle_workaround.hpp
+ #
+ # You can use this like the following:
+ #
+ # target_compile_definitions(target PRIVATE
+ # ${ARROW_BOOST_PROCESS_COMPILE_DEFINITIONS})
+ list(APPEND ARROW_BOOST_PROCESS_COMPILE_DEFINITIONS "BOOST_USE_WINDOWS_H=1")
+endif()
+
+function(ADD_THIRDPARTY_LIB LIB_NAME)
+ set(options)
+ set(one_value_args SHARED_LIB STATIC_LIB)
+ set(multi_value_args DEPS INCLUDE_DIRECTORIES)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+ if(ARG_UNPARSED_ARGUMENTS)
+ message(SEND_ERROR "Error: unrecognized arguments: ${ARG_UNPARSED_ARGUMENTS}")
+ endif()
+
+ if(ARG_STATIC_LIB AND ARG_SHARED_LIB)
+ set(AUG_LIB_NAME "${LIB_NAME}_static")
+ add_library(${AUG_LIB_NAME} STATIC IMPORTED)
+ set_target_properties(${AUG_LIB_NAME} PROPERTIES IMPORTED_LOCATION
+ "${ARG_STATIC_LIB}")
+ if(ARG_DEPS)
+ set_target_properties(${AUG_LIB_NAME} PROPERTIES INTERFACE_LINK_LIBRARIES
+ "${ARG_DEPS}")
+ endif()
+ message(STATUS "Added static library dependency ${AUG_LIB_NAME}: ${ARG_STATIC_LIB}")
+ if(ARG_INCLUDE_DIRECTORIES)
+ set_target_properties(${AUG_LIB_NAME} PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
+ "${ARG_INCLUDE_DIRECTORIES}")
+ endif()
+
+ set(AUG_LIB_NAME "${LIB_NAME}_shared")
+ add_library(${AUG_LIB_NAME} SHARED IMPORTED)
+
+ if(WIN32)
+ # Mark the ".lib" location as part of a Windows DLL
+ set_target_properties(${AUG_LIB_NAME} PROPERTIES IMPORTED_IMPLIB
+ "${ARG_SHARED_LIB}")
+ else()
+ set_target_properties(${AUG_LIB_NAME} PROPERTIES IMPORTED_LOCATION
+ "${ARG_SHARED_LIB}")
+ endif()
+ if(ARG_DEPS)
+ set_target_properties(${AUG_LIB_NAME} PROPERTIES INTERFACE_LINK_LIBRARIES
+ "${ARG_DEPS}")
+ endif()
+ message(STATUS "Added shared library dependency ${AUG_LIB_NAME}: ${ARG_SHARED_LIB}")
+ if(ARG_INCLUDE_DIRECTORIES)
+ set_target_properties(${AUG_LIB_NAME} PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
+ "${ARG_INCLUDE_DIRECTORIES}")
+ endif()
+ elseif(ARG_STATIC_LIB)
+ set(AUG_LIB_NAME "${LIB_NAME}_static")
+ add_library(${AUG_LIB_NAME} STATIC IMPORTED)
+ set_target_properties(${AUG_LIB_NAME} PROPERTIES IMPORTED_LOCATION
+ "${ARG_STATIC_LIB}")
+ if(ARG_DEPS)
+ set_target_properties(${AUG_LIB_NAME} PROPERTIES INTERFACE_LINK_LIBRARIES
+ "${ARG_DEPS}")
+ endif()
+ message(STATUS "Added static library dependency ${AUG_LIB_NAME}: ${ARG_STATIC_LIB}")
+ if(ARG_INCLUDE_DIRECTORIES)
+ set_target_properties(${AUG_LIB_NAME} PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
+ "${ARG_INCLUDE_DIRECTORIES}")
+ endif()
+ elseif(ARG_SHARED_LIB)
+ set(AUG_LIB_NAME "${LIB_NAME}_shared")
+ add_library(${AUG_LIB_NAME} SHARED IMPORTED)
+
+ if(WIN32)
+ # Mark the ".lib" location as part of a Windows DLL
+ set_target_properties(${AUG_LIB_NAME} PROPERTIES IMPORTED_IMPLIB
+ "${ARG_SHARED_LIB}")
+ else()
+ set_target_properties(${AUG_LIB_NAME} PROPERTIES IMPORTED_LOCATION
+ "${ARG_SHARED_LIB}")
+ endif()
+ message(STATUS "Added shared library dependency ${AUG_LIB_NAME}: ${ARG_SHARED_LIB}")
+ if(ARG_DEPS)
+ set_target_properties(${AUG_LIB_NAME} PROPERTIES INTERFACE_LINK_LIBRARIES
+ "${ARG_DEPS}")
+ endif()
+ if(ARG_INCLUDE_DIRECTORIES)
+ set_target_properties(${AUG_LIB_NAME} PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
+ "${ARG_INCLUDE_DIRECTORIES}")
+ endif()
+ else()
+ message(FATAL_ERROR "No static or shared library provided for ${LIB_NAME}")
+ endif()
+endfunction()
+
+function(REUSE_PRECOMPILED_HEADER_LIB TARGET_NAME LIB_NAME)
+ if(ARROW_USE_PRECOMPILED_HEADERS)
+ target_precompile_headers(${TARGET_NAME} REUSE_FROM ${LIB_NAME})
+ endif()
+endfunction()
+
+# Based on MIT-licensed
+# https://gist.github.com/cristianadam/ef920342939a89fae3e8a85ca9459b49
+function(create_merged_static_lib output_target)
+ set(options)
+ set(one_value_args NAME ROOT)
+ set(multi_value_args TO_MERGE)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+ if(ARG_UNPARSED_ARGUMENTS)
+ message(SEND_ERROR "Error: unrecognized arguments: ${ARG_UNPARSED_ARGUMENTS}")
+ endif()
+
+ set(output_lib_path
+ ${BUILD_OUTPUT_ROOT_DIRECTORY}${CMAKE_STATIC_LIBRARY_PREFIX}${ARG_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}
+ )
+
+ set(all_library_paths $<TARGET_FILE:${ARG_ROOT}>)
+ foreach(lib ${ARG_TO_MERGE})
+ list(APPEND all_library_paths $<TARGET_FILE:${lib}>)
+ endforeach()
+
+ if(APPLE)
+ set(BUNDLE_COMMAND "libtool" "-no_warning_for_no_symbols" "-static" "-o"
+ ${output_lib_path} ${all_library_paths})
+ elseif(CMAKE_CXX_COMPILER_ID MATCHES "^(Clang|GNU|Intel)$")
+ set(ar_script_path ${CMAKE_BINARY_DIR}/${ARG_NAME}.ar)
+
+ file(WRITE ${ar_script_path}.in "CREATE ${output_lib_path}\n")
+ file(APPEND ${ar_script_path}.in "ADDLIB $<TARGET_FILE:${ARG_ROOT}>\n")
+
+ foreach(lib ${ARG_TO_MERGE})
+ file(APPEND ${ar_script_path}.in "ADDLIB $<TARGET_FILE:${lib}>\n")
+ endforeach()
+
+ file(APPEND ${ar_script_path}.in "SAVE\nEND\n")
+ file(GENERATE
+ OUTPUT ${ar_script_path}
+ INPUT ${ar_script_path}.in)
+ set(ar_tool ${CMAKE_AR})
+
+ if(CMAKE_INTERPROCEDURAL_OPTIMIZATION)
+ set(ar_tool ${CMAKE_CXX_COMPILER_AR})
+ endif()
+
+ set(BUNDLE_COMMAND ${ar_tool} -M < ${ar_script_path})
+
+ elseif(MSVC)
+ if(NOT CMAKE_LIBTOOL)
+ find_program(lib_tool lib HINTS "${CMAKE_CXX_COMPILER}/..")
+ if("${lib_tool}" STREQUAL "lib_tool-NOTFOUND")
+ message(FATAL_ERROR "Cannot locate libtool to bundle libraries")
+ endif()
+ else()
+ set(${lib_tool} ${CMAKE_LIBTOOL})
+ endif()
+ set(BUNDLE_TOOL ${lib_tool})
+ set(BUNDLE_COMMAND ${BUNDLE_TOOL} /NOLOGO /OUT:${output_lib_path}
+ ${all_library_paths})
+ else()
+ message(FATAL_ERROR "Unknown bundle scenario!")
+ endif()
+
+ add_custom_command(COMMAND ${BUNDLE_COMMAND}
+ OUTPUT ${output_lib_path}
+ COMMENT "Bundling ${output_lib_path}"
+ VERBATIM)
+
+ message(STATUS "Creating bundled static library target ${output_target} at ${output_lib_path}"
+ )
+
+ add_custom_target(${output_target} ALL DEPENDS ${output_lib_path})
+ add_dependencies(${output_target} ${ARG_ROOT} ${ARG_TO_MERGE})
+ install(FILES ${output_lib_path} DESTINATION ${CMAKE_INSTALL_LIBDIR})
+endfunction()
+
+# \arg OUTPUTS list to append built targets to
+function(ADD_ARROW_LIB LIB_NAME)
+ set(options)
+ set(one_value_args
+ BUILD_SHARED
+ BUILD_STATIC
+ CMAKE_PACKAGE_NAME
+ PKG_CONFIG_NAME
+ SHARED_LINK_FLAGS
+ PRECOMPILED_HEADER_LIB)
+ set(multi_value_args
+ SOURCES
+ PRECOMPILED_HEADERS
+ OUTPUTS
+ STATIC_LINK_LIBS
+ SHARED_LINK_LIBS
+ SHARED_PRIVATE_LINK_LIBS
+ EXTRA_INCLUDES
+ PRIVATE_INCLUDES
+ DEPENDENCIES
+ SHARED_INSTALL_INTERFACE_LIBS
+ STATIC_INSTALL_INTERFACE_LIBS
+ OUTPUT_PATH)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+ if(ARG_UNPARSED_ARGUMENTS)
+ message(SEND_ERROR "Error: unrecognized arguments: ${ARG_UNPARSED_ARGUMENTS}")
+ endif()
+
+ if(ARG_OUTPUTS)
+ set(${ARG_OUTPUTS})
+ endif()
+
+ # Allow overriding ARROW_BUILD_SHARED and ARROW_BUILD_STATIC
+ if(DEFINED ARG_BUILD_SHARED)
+ set(BUILD_SHARED ${ARG_BUILD_SHARED})
+ else()
+ set(BUILD_SHARED ${ARROW_BUILD_SHARED})
+ endif()
+ if(DEFINED ARG_BUILD_STATIC)
+ set(BUILD_STATIC ${ARG_BUILD_STATIC})
+ else()
+ set(BUILD_STATIC ${ARROW_BUILD_STATIC})
+ endif()
+ if(ARG_OUTPUT_PATH)
+ set(OUTPUT_PATH ${ARG_OUTPUT_PATH})
+ else()
+ set(OUTPUT_PATH ${BUILD_OUTPUT_ROOT_DIRECTORY})
+ endif()
+
+ if(WIN32 OR (CMAKE_GENERATOR STREQUAL Xcode))
+ # We need to compile C++ separately for each library kind (shared and static)
+ # because of dllexport declarations on Windows.
+ # The Xcode generator doesn't reliably work with Xcode as target names are not
+ # guessed correctly.
+ set(USE_OBJLIB OFF)
+ else()
+ set(USE_OBJLIB ON)
+ endif()
+
+ if(USE_OBJLIB)
+ # Generate a single "objlib" from all C++ modules and link
+ # that "objlib" into each library kind, to avoid compiling twice
+ add_library(${LIB_NAME}_objlib OBJECT ${ARG_SOURCES})
+ # Necessary to make static linking into other shared libraries work properly
+ set_property(TARGET ${LIB_NAME}_objlib PROPERTY POSITION_INDEPENDENT_CODE 1)
+ if(ARG_DEPENDENCIES)
+ add_dependencies(${LIB_NAME}_objlib ${ARG_DEPENDENCIES})
+ endif()
+ if(ARG_PRECOMPILED_HEADER_LIB)
+ reuse_precompiled_header_lib(${LIB_NAME}_objlib ${ARG_PRECOMPILED_HEADER_LIB})
+ endif()
+ if(ARG_PRECOMPILED_HEADERS AND ARROW_USE_PRECOMPILED_HEADERS)
+ target_precompile_headers(${LIB_NAME}_objlib PRIVATE ${ARG_PRECOMPILED_HEADERS})
+ endif()
+ set(LIB_DEPS $<TARGET_OBJECTS:${LIB_NAME}_objlib>)
+ set(LIB_INCLUDES)
+ set(EXTRA_DEPS)
+
+ if(ARG_OUTPUTS)
+ list(APPEND ${ARG_OUTPUTS} ${LIB_NAME}_objlib)
+ endif()
+
+ if(ARG_EXTRA_INCLUDES)
+ target_include_directories(${LIB_NAME}_objlib SYSTEM PUBLIC ${ARG_EXTRA_INCLUDES})
+ endif()
+ if(ARG_PRIVATE_INCLUDES)
+ target_include_directories(${LIB_NAME}_objlib PRIVATE ${ARG_PRIVATE_INCLUDES})
+ endif()
+ else()
+ # Prepare arguments for separate compilation of static and shared libs below
+ # TODO: add PCH directives
+ set(LIB_DEPS ${ARG_SOURCES})
+ set(EXTRA_DEPS ${ARG_DEPENDENCIES})
+
+ if(ARG_EXTRA_INCLUDES)
+ set(LIB_INCLUDES ${ARG_EXTRA_INCLUDES})
+ endif()
+ endif()
+
+ set(RUNTIME_INSTALL_DIR bin)
+
+ if(BUILD_SHARED)
+ add_library(${LIB_NAME}_shared SHARED ${LIB_DEPS})
+ if(EXTRA_DEPS)
+ add_dependencies(${LIB_NAME}_shared ${EXTRA_DEPS})
+ endif()
+
+ if(ARG_PRECOMPILED_HEADER_LIB)
+ reuse_precompiled_header_lib(${LIB_NAME}_shared ${ARG_PRECOMPILED_HEADER_LIB})
+ endif()
+
+ if(ARG_OUTPUTS)
+ list(APPEND ${ARG_OUTPUTS} ${LIB_NAME}_shared)
+ endif()
+
+ if(LIB_INCLUDES)
+ target_include_directories(${LIB_NAME}_shared SYSTEM PUBLIC ${ARG_EXTRA_INCLUDES})
+ endif()
+
+ if(ARG_PRIVATE_INCLUDES)
+ target_include_directories(${LIB_NAME}_shared PRIVATE ${ARG_PRIVATE_INCLUDES})
+ endif()
+
+ # On iOS, specifying -undefined conflicts with enabling bitcode
+ if(APPLE
+ AND NOT IOS
+ AND NOT DEFINED ENV{EMSCRIPTEN})
+ # On OS X, you can avoid linking at library load time and instead
+ # expecting that the symbols have been loaded separately. This happens
+ # with libpython* where there can be conflicts between system Python and
+ # the Python from a thirdparty distribution
+ #
+ # When running with the Emscripten Compiler, we need not worry about
+ # python, and the Emscripten Compiler does not support this option.
+ set(ARG_SHARED_LINK_FLAGS "-undefined dynamic_lookup ${ARG_SHARED_LINK_FLAGS}")
+ endif()
+
+ set_target_properties(${LIB_NAME}_shared
+ PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${OUTPUT_PATH}"
+ RUNTIME_OUTPUT_DIRECTORY "${OUTPUT_PATH}"
+ PDB_OUTPUT_DIRECTORY "${OUTPUT_PATH}"
+ LINK_FLAGS "${ARG_SHARED_LINK_FLAGS}"
+ OUTPUT_NAME ${LIB_NAME}
+ VERSION "${ARROW_FULL_SO_VERSION}"
+ SOVERSION "${ARROW_SO_VERSION}")
+
+ target_link_libraries(${LIB_NAME}_shared
+ LINK_PUBLIC
+ "$<BUILD_INTERFACE:${ARG_SHARED_LINK_LIBS}>"
+ "$<INSTALL_INTERFACE:${ARG_SHARED_INSTALL_INTERFACE_LIBS}>"
+ LINK_PRIVATE
+ ${ARG_SHARED_PRIVATE_LINK_LIBS})
+
+ if(ARROW_RPATH_ORIGIN)
+ if(APPLE)
+ set(_lib_install_rpath "@loader_path")
+ else()
+ set(_lib_install_rpath "\$ORIGIN")
+ endif()
+ set_target_properties(${LIB_NAME}_shared PROPERTIES INSTALL_RPATH
+ ${_lib_install_rpath})
+ endif()
+
+ if(APPLE)
+ if(ARROW_INSTALL_NAME_RPATH)
+ set(_lib_install_name "@rpath")
+ else()
+ set(_lib_install_name "${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}")
+ endif()
+ set_target_properties(${LIB_NAME}_shared
+ PROPERTIES BUILD_WITH_INSTALL_RPATH ON INSTALL_NAME_DIR
+ "${_lib_install_name}")
+ endif()
+
+ install(TARGETS ${LIB_NAME}_shared ${INSTALL_IS_OPTIONAL}
+ EXPORT ${LIB_NAME}_targets
+ RUNTIME DESTINATION ${RUNTIME_INSTALL_DIR}
+ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
+ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
+ INCLUDES
+ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
+ endif()
+
+ if(BUILD_STATIC)
+ add_library(${LIB_NAME}_static STATIC ${LIB_DEPS})
+ if(EXTRA_DEPS)
+ add_dependencies(${LIB_NAME}_static ${EXTRA_DEPS})
+ endif()
+
+ if(ARG_PRECOMPILED_HEADER_LIB)
+ reuse_precompiled_header_lib(${LIB_NAME}_static ${ARG_PRECOMPILED_HEADER_LIB})
+ endif()
+
+ if(ARG_OUTPUTS)
+ list(APPEND ${ARG_OUTPUTS} ${LIB_NAME}_static)
+ endif()
+
+ if(LIB_INCLUDES)
+ target_include_directories(${LIB_NAME}_static SYSTEM PUBLIC ${ARG_EXTRA_INCLUDES})
+ endif()
+
+ if(ARG_PRIVATE_INCLUDES)
+ target_include_directories(${LIB_NAME}_static PRIVATE ${ARG_PRIVATE_INCLUDES})
+ endif()
+
+ if(MSVC_TOOLCHAIN)
+ set(LIB_NAME_STATIC ${LIB_NAME}_static)
+ else()
+ set(LIB_NAME_STATIC ${LIB_NAME})
+ endif()
+
+ if(ARROW_BUILD_STATIC AND WIN32)
+ target_compile_definitions(${LIB_NAME}_static PUBLIC ARROW_STATIC)
+ endif()
+
+ set_target_properties(${LIB_NAME}_static
+ PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${OUTPUT_PATH}"
+ OUTPUT_NAME ${LIB_NAME_STATIC})
+
+ if(ARG_STATIC_INSTALL_INTERFACE_LIBS)
+ target_link_libraries(${LIB_NAME}_static LINK_PUBLIC
+ "$<INSTALL_INTERFACE:${ARG_STATIC_INSTALL_INTERFACE_LIBS}>")
+ endif()
+
+ if(ARG_STATIC_LINK_LIBS)
+ target_link_libraries(${LIB_NAME}_static LINK_PRIVATE
+ "$<BUILD_INTERFACE:${ARG_STATIC_LINK_LIBS}>")
+ endif()
+
+ install(TARGETS ${LIB_NAME}_static ${INSTALL_IS_OPTIONAL}
+ EXPORT ${LIB_NAME}_targets
+ RUNTIME DESTINATION ${RUNTIME_INSTALL_DIR}
+ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
+ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
+ INCLUDES
+ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
+ endif()
+
+ if(ARG_CMAKE_PACKAGE_NAME)
+ arrow_install_cmake_find_module("${ARG_CMAKE_PACKAGE_NAME}")
+
+ set(TARGETS_CMAKE "${ARG_CMAKE_PACKAGE_NAME}Targets.cmake")
+ install(EXPORT ${LIB_NAME}_targets
+ FILE "${TARGETS_CMAKE}"
+ DESTINATION "${ARROW_CMAKE_INSTALL_DIR}")
+
+ set(CONFIG_CMAKE "${ARG_CMAKE_PACKAGE_NAME}Config.cmake")
+ set(BUILT_CONFIG_CMAKE "${CMAKE_CURRENT_BINARY_DIR}/${CONFIG_CMAKE}")
+ configure_package_config_file("${CONFIG_CMAKE}.in" "${BUILT_CONFIG_CMAKE}"
+ INSTALL_DESTINATION "${ARROW_CMAKE_INSTALL_DIR}")
+ install(FILES "${BUILT_CONFIG_CMAKE}" DESTINATION "${ARROW_CMAKE_INSTALL_DIR}")
+
+ set(CONFIG_VERSION_CMAKE "${ARG_CMAKE_PACKAGE_NAME}ConfigVersion.cmake")
+ set(BUILT_CONFIG_VERSION_CMAKE "${CMAKE_CURRENT_BINARY_DIR}/${CONFIG_VERSION_CMAKE}")
+ write_basic_package_version_file(
+ "${BUILT_CONFIG_VERSION_CMAKE}"
+ VERSION ${${PROJECT_NAME}_VERSION}
+ COMPATIBILITY AnyNewerVersion)
+ install(FILES "${BUILT_CONFIG_VERSION_CMAKE}"
+ DESTINATION "${ARROW_CMAKE_INSTALL_DIR}")
+ endif()
+
+ if(ARG_PKG_CONFIG_NAME)
+ arrow_add_pkg_config("${ARG_PKG_CONFIG_NAME}")
+ endif()
+
+ # Modify variable in calling scope
+ if(ARG_OUTPUTS)
+ set(${ARG_OUTPUTS}
+ ${${ARG_OUTPUTS}}
+ PARENT_SCOPE)
+ endif()
+endfunction()
+
+#
+# Benchmarking
+#
+# Add a new micro benchmark, with or without an executable that should be built.
+# If benchmarks are enabled then they will be run along side unit tests with ctest.
+# 'make benchmark' and 'make unittest' to build/run only benchmark or unittests,
+# respectively.
+#
+# REL_BENCHMARK_NAME is the name of the benchmark app. It may be a single component
+# (e.g. monotime-benchmark) or contain additional components (e.g.
+# net/net_util-benchmark). Either way, the last component must be a globally
+# unique name.
+
+# The benchmark will registered as unit test with ctest with a label
+# of 'benchmark'.
+#
+# Arguments after the test name will be passed to set_tests_properties().
+#
+# \arg PREFIX a string to append to the name of the benchmark executable. For
+# example, if you have src/arrow/foo/bar-benchmark.cc, then PREFIX "foo" will
+# create test executable foo-bar-benchmark
+# \arg LABELS the benchmark label or labels to assign the unit tests to. By
+# default, benchmarks will go in the "benchmark" group. Custom targets for the
+# group names must exist
+function(ADD_BENCHMARK REL_BENCHMARK_NAME)
+ set(options)
+ set(one_value_args)
+ set(multi_value_args
+ EXTRA_LINK_LIBS
+ STATIC_LINK_LIBS
+ DEPENDENCIES
+ PREFIX
+ LABELS)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+ if(ARG_UNPARSED_ARGUMENTS)
+ message(SEND_ERROR "Error: unrecognized arguments: ${ARG_UNPARSED_ARGUMENTS}")
+ endif()
+
+ if(NO_BENCHMARKS)
+ return()
+ endif()
+ get_filename_component(BENCHMARK_NAME ${REL_BENCHMARK_NAME} NAME_WE)
+
+ if(ARG_PREFIX)
+ set(BENCHMARK_NAME "${ARG_PREFIX}-${BENCHMARK_NAME}")
+ endif()
+
+ # Make sure the executable name contains only hyphens, not underscores
+ string(REPLACE "_" "-" BENCHMARK_NAME ${BENCHMARK_NAME})
+
+ if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${REL_BENCHMARK_NAME}.cc)
+ # This benchmark has a corresponding .cc file, set it up as an executable.
+ set(BENCHMARK_PATH "${EXECUTABLE_OUTPUT_PATH}/${BENCHMARK_NAME}")
+ add_executable(${BENCHMARK_NAME} "${REL_BENCHMARK_NAME}.cc")
+
+ if(ARG_STATIC_LINK_LIBS)
+ # Customize link libraries
+ target_link_libraries(${BENCHMARK_NAME} PRIVATE ${ARG_STATIC_LINK_LIBS})
+ else()
+ target_link_libraries(${BENCHMARK_NAME} PRIVATE ${ARROW_BENCHMARK_LINK_LIBS})
+ endif()
+ add_dependencies(benchmark ${BENCHMARK_NAME})
+ set(NO_COLOR "--color_print=false")
+
+ if(ARG_EXTRA_LINK_LIBS)
+ target_link_libraries(${BENCHMARK_NAME} PRIVATE ${ARG_EXTRA_LINK_LIBS})
+ endif()
+ else()
+ # No executable, just invoke the benchmark (probably a script) directly.
+ set(BENCHMARK_PATH ${CMAKE_CURRENT_SOURCE_DIR}/${REL_BENCHMARK_NAME})
+ set(NO_COLOR "")
+ endif()
+
+ # With OSX and conda, we need to set the correct RPATH so that dependencies
+ # are found. The installed libraries with conda have an RPATH that matches
+ # for executables and libraries lying in $ENV{CONDA_PREFIX}/bin or
+ # $ENV{CONDA_PREFIX}/lib but our test libraries and executables are not
+ # installed there.
+ if(NOT "$ENV{CONDA_PREFIX}" STREQUAL "" AND APPLE)
+ set_target_properties(${BENCHMARK_NAME}
+ PROPERTIES BUILD_WITH_INSTALL_RPATH TRUE
+ INSTALL_RPATH_USE_LINK_PATH TRUE
+ INSTALL_RPATH
+ "$ENV{CONDA_PREFIX}/lib;${EXECUTABLE_OUTPUT_PATH}")
+ endif()
+
+ # Add test as dependency of relevant label targets
+ add_dependencies(all-benchmarks ${BENCHMARK_NAME})
+ foreach(TARGET ${ARG_LABELS})
+ add_dependencies(${TARGET} ${BENCHMARK_NAME})
+ endforeach()
+
+ if(ARG_DEPENDENCIES)
+ add_dependencies(${BENCHMARK_NAME} ${ARG_DEPENDENCIES})
+ endif()
+
+ if(ARG_LABELS)
+ set(ARG_LABELS "benchmark;${ARG_LABELS}")
+ else()
+ set(ARG_LABELS benchmark)
+ endif()
+
+ add_test(${BENCHMARK_NAME}
+ ${BUILD_SUPPORT_DIR}/run-test.sh
+ ${CMAKE_BINARY_DIR}
+ benchmark
+ ${BENCHMARK_PATH}
+ ${NO_COLOR})
+ set_property(TEST ${BENCHMARK_NAME}
+ APPEND
+ PROPERTY LABELS ${ARG_LABELS})
+endfunction()
+
+#
+# Testing
+#
+# Add a new test case, with or without an executable that should be built.
+#
+# REL_TEST_NAME is the name of the test. It may be a single component
+# (e.g. monotime-test) or contain additional components (e.g.
+# net/net_util-test). Either way, the last component must be a globally
+# unique name.
+#
+# If given, SOURCES is the list of C++ source files to compile into the test
+# executable. Otherwise, "REL_TEST_NAME.cc" is used.
+#
+# The unit test is added with a label of "unittest" to support filtering with
+# ctest.
+#
+# Arguments after the test name will be passed to set_tests_properties().
+#
+# \arg ENABLED if passed, add this unit test even if ARROW_BUILD_TESTS is off
+# \arg PREFIX a string to append to the name of the test executable. For
+# example, if you have src/arrow/foo/bar-test.cc, then PREFIX "foo" will create
+# test executable foo-bar-test
+# \arg LABELS the unit test label or labels to assign the unit tests
+# to. By default, unit tests will go in the "unittest" group, but if we have
+# multiple unit tests in some subgroup, you can assign a test to multiple
+# groups use the syntax unittest;GROUP2;GROUP3. Custom targets for the group
+# names must exist
+function(ADD_TEST_CASE REL_TEST_NAME)
+ set(options NO_VALGRIND ENABLED)
+ set(one_value_args PRECOMPILED_HEADER_LIB)
+ set(multi_value_args
+ SOURCES
+ PRECOMPILED_HEADERS
+ STATIC_LINK_LIBS
+ EXTRA_LINK_LIBS
+ EXTRA_INCLUDES
+ EXTRA_DEPENDENCIES
+ LABELS
+ EXTRA_LABELS
+ PREFIX)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+ if(ARG_UNPARSED_ARGUMENTS)
+ message(SEND_ERROR "Error: unrecognized arguments: ${ARG_UNPARSED_ARGUMENTS}")
+ endif()
+
+ if(NO_TESTS AND NOT ARG_ENABLED)
+ return()
+ endif()
+ get_filename_component(TEST_NAME ${REL_TEST_NAME} NAME_WE)
+
+ if(ARG_PREFIX)
+ set(TEST_NAME "${ARG_PREFIX}-${TEST_NAME}")
+ endif()
+
+ if(ARG_SOURCES)
+ set(SOURCES ${ARG_SOURCES})
+ else()
+ set(SOURCES "${REL_TEST_NAME}.cc")
+ endif()
+
+ # Make sure the executable name contains only hyphens, not underscores
+ string(REPLACE "_" "-" TEST_NAME ${TEST_NAME})
+
+ set(TEST_PATH "${EXECUTABLE_OUTPUT_PATH}/${TEST_NAME}")
+ add_executable(${TEST_NAME} ${SOURCES})
+
+ # With OSX and conda, we need to set the correct RPATH so that dependencies
+ # are found. The installed libraries with conda have an RPATH that matches
+ # for executables and libraries lying in $ENV{CONDA_PREFIX}/bin or
+ # $ENV{CONDA_PREFIX}/lib but our test libraries and executables are not
+ # installed there.
+ if(NOT "$ENV{CONDA_PREFIX}" STREQUAL "" AND APPLE)
+ set_target_properties(${TEST_NAME}
+ PROPERTIES BUILD_WITH_INSTALL_RPATH TRUE
+ INSTALL_RPATH_USE_LINK_PATH TRUE
+ INSTALL_RPATH
+ "${EXECUTABLE_OUTPUT_PATH};$ENV{CONDA_PREFIX}/lib")
+ endif()
+
+ if(ARG_STATIC_LINK_LIBS)
+ # Customize link libraries
+ target_link_libraries(${TEST_NAME} PRIVATE ${ARG_STATIC_LINK_LIBS})
+ else()
+ target_link_libraries(${TEST_NAME} PRIVATE ${ARROW_TEST_LINK_LIBS})
+ endif()
+
+ if(ARG_PRECOMPILED_HEADER_LIB)
+ reuse_precompiled_header_lib(${TEST_NAME} ${ARG_PRECOMPILED_HEADER_LIB})
+ endif()
+
+ if(ARG_PRECOMPILED_HEADERS AND ARROW_USE_PRECOMPILED_HEADERS)
+ target_precompile_headers(${TEST_NAME} PRIVATE ${ARG_PRECOMPILED_HEADERS})
+ endif()
+
+ if(ARG_EXTRA_LINK_LIBS)
+ target_link_libraries(${TEST_NAME} PRIVATE ${ARG_EXTRA_LINK_LIBS})
+ endif()
+
+ if(ARG_EXTRA_INCLUDES)
+ target_include_directories(${TEST_NAME} SYSTEM PUBLIC ${ARG_EXTRA_INCLUDES})
+ endif()
+
+ if(ARG_EXTRA_DEPENDENCIES)
+ add_dependencies(${TEST_NAME} ${ARG_EXTRA_DEPENDENCIES})
+ endif()
+
+ if(ARROW_TEST_MEMCHECK AND NOT ARG_NO_VALGRIND)
+ add_test(${TEST_NAME}
+ bash
+ -c
+ "cd '${CMAKE_SOURCE_DIR}'; \
+ valgrind --suppressions=valgrind.supp --tool=memcheck --gen-suppressions=all \
+ --num-callers=500 --leak-check=full --leak-check-heuristics=stdstring \
+ --error-exitcode=1 ${TEST_PATH}")
+ elseif(WIN32)
+ add_test(${TEST_NAME} ${TEST_PATH})
+ else()
+ add_test(${TEST_NAME}
+ ${BUILD_SUPPORT_DIR}/run-test.sh
+ ${CMAKE_BINARY_DIR}
+ test
+ ${TEST_PATH})
+ endif()
+
+ # Add test as dependency of relevant targets
+ add_dependencies(all-tests ${TEST_NAME})
+ foreach(TARGET ${ARG_LABELS})
+ add_dependencies(${TARGET} ${TEST_NAME})
+ endforeach()
+
+ set(LABELS)
+ list(APPEND LABELS "unittest")
+ if(ARG_LABELS)
+ list(APPEND LABELS ${ARG_LABELS})
+ endif()
+ # EXTRA_LABELS don't create their own dependencies, they are only used
+ # to ease running certain test categories.
+ if(ARG_EXTRA_LABELS)
+ list(APPEND LABELS ${ARG_EXTRA_LABELS})
+ endif()
+
+ foreach(LABEL ${ARG_LABELS})
+ # ensure there is a cmake target which exercises tests with this LABEL
+ set(LABEL_TEST_NAME "test-${LABEL}")
+ if(NOT TARGET ${LABEL_TEST_NAME})
+ add_custom_target(${LABEL_TEST_NAME}
+ ctest -L "${LABEL}" --output-on-failure
+ USES_TERMINAL)
+ endif()
+ # ensure the test is (re)built before the LABEL test runs
+ add_dependencies(${LABEL_TEST_NAME} ${TEST_NAME})
+ endforeach()
+
+ set_property(TEST ${TEST_NAME}
+ APPEND
+ PROPERTY LABELS ${LABELS})
+endfunction()
+
+#
+# Examples
+#
+# Add a new example, with or without an executable that should be built.
+# If examples are enabled then they will be run along side unit tests with ctest.
+# 'make runexample' to build/run only examples.
+#
+# REL_EXAMPLE_NAME is the name of the example app. It may be a single component
+# (e.g. monotime-example) or contain additional components (e.g.
+# net/net_util-example). Either way, the last component must be a globally
+# unique name.
+
+# The example will registered as unit test with ctest with a label
+# of 'example'.
+#
+# Arguments after the test name will be passed to set_tests_properties().
+#
+# \arg PREFIX a string to append to the name of the example executable. For
+# example, if you have src/arrow/foo/bar-example.cc, then PREFIX "foo" will
+# create test executable foo-bar-example
+function(ADD_ARROW_EXAMPLE REL_EXAMPLE_NAME)
+ set(options)
+ set(one_value_args)
+ set(multi_value_args EXTRA_LINK_LIBS DEPENDENCIES PREFIX)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+ if(ARG_UNPARSED_ARGUMENTS)
+ message(SEND_ERROR "Error: unrecognized arguments: ${ARG_UNPARSED_ARGUMENTS}")
+ endif()
+
+ if(NO_EXAMPLES)
+ return()
+ endif()
+ get_filename_component(EXAMPLE_NAME ${REL_EXAMPLE_NAME} NAME_WE)
+
+ if(ARG_PREFIX)
+ set(EXAMPLE_NAME "${ARG_PREFIX}-${EXAMPLE_NAME}")
+ endif()
+
+ if(EXISTS ${CMAKE_SOURCE_DIR}/examples/arrow/${REL_EXAMPLE_NAME}.cc)
+ # This example has a corresponding .cc file, set it up as an executable.
+ set(EXAMPLE_PATH "${EXECUTABLE_OUTPUT_PATH}/${EXAMPLE_NAME}")
+ add_executable(${EXAMPLE_NAME} "${REL_EXAMPLE_NAME}.cc")
+ target_link_libraries(${EXAMPLE_NAME} ${ARROW_EXAMPLE_LINK_LIBS})
+ add_dependencies(runexample ${EXAMPLE_NAME})
+ set(NO_COLOR "--color_print=false")
+
+ if(ARG_EXTRA_LINK_LIBS)
+ target_link_libraries(${EXAMPLE_NAME} ${ARG_EXTRA_LINK_LIBS})
+ endif()
+ endif()
+
+ if(ARG_DEPENDENCIES)
+ add_dependencies(${EXAMPLE_NAME} ${ARG_DEPENDENCIES})
+ endif()
+
+ add_test(${EXAMPLE_NAME} ${EXAMPLE_PATH})
+ set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "example")
+endfunction()
+
+#
+# Fuzzing
+#
+# Add new fuzz target executable.
+#
+# The single source file must define a function:
+# extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size)
+#
+# No main function must be present within the source file!
+#
+function(ADD_FUZZ_TARGET REL_FUZZING_NAME)
+ set(options)
+ set(one_value_args PREFIX)
+ set(multi_value_args LINK_LIBS)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+ if(ARG_UNPARSED_ARGUMENTS)
+ message(SEND_ERROR "Error: unrecognized arguments: ${ARG_UNPARSED_ARGUMENTS}")
+ endif()
+
+ if(NO_FUZZING)
+ return()
+ endif()
+
+ get_filename_component(FUZZING_NAME ${REL_FUZZING_NAME} NAME_WE)
+
+ # Make sure the executable name contains only hyphens, not underscores
+ string(REPLACE "_" "-" FUZZING_NAME ${FUZZING_NAME})
+
+ if(ARG_PREFIX)
+ set(FUZZING_NAME "${ARG_PREFIX}-${FUZZING_NAME}")
+ endif()
+
+ # For OSS-Fuzz
+ # (https://google.github.io/oss-fuzz/advanced-topics/ideal-integration/)
+ if(DEFINED ENV{LIB_FUZZING_ENGINE})
+ set(FUZZ_LDFLAGS $ENV{LIB_FUZZING_ENGINE})
+ else()
+ set(FUZZ_LDFLAGS "-fsanitize=fuzzer")
+ endif()
+
+ add_executable(${FUZZING_NAME} "${REL_FUZZING_NAME}.cc")
+ target_link_libraries(${FUZZING_NAME} ${LINK_LIBS})
+ target_compile_options(${FUZZING_NAME} PRIVATE ${FUZZ_LDFLAGS})
+ set_target_properties(${FUZZING_NAME} PROPERTIES LINK_FLAGS ${FUZZ_LDFLAGS} LABELS
+ "fuzzing")
+endfunction()
+
+function(ARROW_INSTALL_ALL_HEADERS PATH)
+ set(options)
+ set(one_value_args)
+ set(multi_value_args PATTERN)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+ if(NOT ARG_PATTERN)
+ # The .hpp extension is used by some vendored libraries
+ set(ARG_PATTERN "*.h" "*.hpp")
+ endif()
+ file(GLOB CURRENT_DIRECTORY_HEADERS ${ARG_PATTERN})
+
+ set(PUBLIC_HEADERS)
+ foreach(HEADER ${CURRENT_DIRECTORY_HEADERS})
+ get_filename_component(HEADER_BASENAME ${HEADER} NAME)
+ if(HEADER_BASENAME MATCHES "internal")
+ continue()
+ endif()
+ list(APPEND PUBLIC_HEADERS ${HEADER})
+ endforeach()
+ install(FILES ${PUBLIC_HEADERS} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/${PATH}")
+endfunction()
+
+function(ARROW_ADD_PKG_CONFIG MODULE)
+ configure_file(${MODULE}.pc.in "${CMAKE_CURRENT_BINARY_DIR}/${MODULE}.pc" @ONLY)
+ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/${MODULE}.pc"
+ DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig/")
+endfunction()
+
+function(ARROW_INSTALL_CMAKE_FIND_MODULE MODULE)
+ install(FILES "${ARROW_SOURCE_DIR}/cmake_modules/Find${MODULE}.cmake"
+ DESTINATION "${ARROW_CMAKE_INSTALL_DIR}")
+endfunction()
+
+# Implementations of lisp "car" and "cdr" functions
+macro(ARROW_CAR var)
+ set(${var} ${ARGV1})
+endmacro()
+
+macro(ARROW_CDR var rest)
+ set(${var} ${ARGN})
+endmacro()
diff --git a/src/arrow/cpp/cmake_modules/DefineOptions.cmake b/src/arrow/cpp/cmake_modules/DefineOptions.cmake
new file mode 100644
index 000000000..3568887fa
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/DefineOptions.cmake
@@ -0,0 +1,589 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+macro(set_option_category name)
+ set(ARROW_OPTION_CATEGORY ${name})
+ list(APPEND "ARROW_OPTION_CATEGORIES" ${name})
+endmacro()
+
+function(check_description_length name description)
+ foreach(description_line ${description})
+ string(LENGTH ${description_line} line_length)
+ if(${line_length} GREATER 80)
+ message(FATAL_ERROR "description for ${name} contained a\n\
+ line ${line_length} characters long!\n\
+ (max is 80). Split it into more lines with semicolons")
+ endif()
+ endforeach()
+endfunction()
+
+function(list_join lst glue out)
+ if("${${lst}}" STREQUAL "")
+ set(${out}
+ ""
+ PARENT_SCOPE)
+ return()
+ endif()
+
+ list(GET ${lst} 0 joined)
+ list(REMOVE_AT ${lst} 0)
+ foreach(item ${${lst}})
+ set(joined "${joined}${glue}${item}")
+ endforeach()
+ set(${out}
+ ${joined}
+ PARENT_SCOPE)
+endfunction()
+
+macro(define_option name description default)
+ check_description_length(${name} ${description})
+ list_join(description "\n" multiline_description)
+
+ option(${name} "${multiline_description}" ${default})
+
+ list(APPEND "ARROW_${ARROW_OPTION_CATEGORY}_OPTION_NAMES" ${name})
+ set("${name}_OPTION_DESCRIPTION" ${description})
+ set("${name}_OPTION_DEFAULT" ${default})
+ set("${name}_OPTION_TYPE" "bool")
+endmacro()
+
+macro(define_option_string name description default)
+ check_description_length(${name} ${description})
+ list_join(description "\n" multiline_description)
+
+ set(${name}
+ ${default}
+ CACHE STRING "${multiline_description}")
+
+ list(APPEND "ARROW_${ARROW_OPTION_CATEGORY}_OPTION_NAMES" ${name})
+ set("${name}_OPTION_DESCRIPTION" ${description})
+ set("${name}_OPTION_DEFAULT" "\"${default}\"")
+ set("${name}_OPTION_TYPE" "string")
+ set("${name}_OPTION_POSSIBLE_VALUES" ${ARGN})
+
+ list_join("${name}_OPTION_POSSIBLE_VALUES" "|" "${name}_OPTION_ENUM")
+ if(NOT ("${${name}_OPTION_ENUM}" STREQUAL ""))
+ set_property(CACHE ${name} PROPERTY STRINGS "${name}_OPTION_POSSIBLE_VALUES")
+ endif()
+endmacro()
+
+# Top level cmake dir
+if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}")
+ #----------------------------------------------------------------------
+ set_option_category("Compile and link")
+
+ define_option_string(ARROW_CXXFLAGS "Compiler flags to append when compiling Arrow" "")
+
+ define_option(ARROW_BUILD_STATIC "Build static libraries" ON)
+
+ define_option(ARROW_BUILD_SHARED "Build shared libraries" ON)
+
+ define_option_string(ARROW_PACKAGE_KIND
+ "Arbitrary string that identifies the kind of package;\
+(for informational purposes)" "")
+
+ define_option_string(ARROW_GIT_ID "The Arrow git commit id (if any)" "")
+
+ define_option_string(ARROW_GIT_DESCRIPTION "The Arrow git commit description (if any)"
+ "")
+
+ define_option(ARROW_NO_DEPRECATED_API "Exclude deprecated APIs from build" OFF)
+
+ define_option(ARROW_USE_CCACHE "Use ccache when compiling (if available)" ON)
+
+ define_option(ARROW_USE_LD_GOLD "Use ld.gold for linking on Linux (if available)" OFF)
+
+ define_option(ARROW_USE_PRECOMPILED_HEADERS "Use precompiled headers when compiling"
+ OFF)
+
+ define_option_string(ARROW_SIMD_LEVEL
+ "Compile-time SIMD optimization level"
+ "DEFAULT" # default to SSE4_2 on x86, NEON on Arm, NONE otherwise
+ "NONE"
+ "SSE4_2"
+ "AVX2"
+ "AVX512"
+ "NEON"
+ "DEFAULT")
+
+ define_option_string(ARROW_RUNTIME_SIMD_LEVEL
+ "Max runtime SIMD optimization level"
+ "MAX" # default to max supported by compiler
+ "NONE"
+ "SSE4_2"
+ "AVX2"
+ "AVX512"
+ "MAX")
+
+ # Arm64 architectures and extensions can lead to exploding combinations.
+ # So set it directly through cmake command line.
+ #
+ # If you change this, you need to change the definition in
+ # python/CMakeLists.txt too.
+ define_option_string(ARROW_ARMV8_ARCH
+ "Arm64 arch and extensions"
+ "armv8-a" # Default
+ "armv8-a"
+ "armv8-a+crc+crypto")
+
+ define_option(ARROW_ALTIVEC "Build with Altivec if compiler has support" ON)
+
+ define_option(ARROW_RPATH_ORIGIN "Build Arrow libraries with RATH set to \$ORIGIN" OFF)
+
+ define_option(ARROW_INSTALL_NAME_RPATH
+ "Build Arrow libraries with install_name set to @rpath" ON)
+
+ define_option(ARROW_GGDB_DEBUG "Pass -ggdb flag to debug builds" ON)
+
+ #----------------------------------------------------------------------
+ set_option_category("Test and benchmark")
+
+ define_option(ARROW_BUILD_EXAMPLES "Build the Arrow examples" OFF)
+
+ define_option(ARROW_BUILD_TESTS "Build the Arrow googletest unit tests" OFF)
+
+ define_option(ARROW_ENABLE_TIMING_TESTS "Enable timing-sensitive tests" ON)
+
+ define_option(ARROW_BUILD_INTEGRATION "Build the Arrow integration test executables"
+ OFF)
+
+ define_option(ARROW_BUILD_BENCHMARKS "Build the Arrow micro benchmarks" OFF)
+
+ # Reference benchmarks are used to compare to naive implementation, or
+ # discover various hardware limits.
+ define_option(ARROW_BUILD_BENCHMARKS_REFERENCE
+ "Build the Arrow micro reference benchmarks" OFF)
+
+ if(ARROW_BUILD_SHARED)
+ set(ARROW_TEST_LINKAGE_DEFAULT "shared")
+ else()
+ set(ARROW_TEST_LINKAGE_DEFAULT "static")
+ endif()
+
+ define_option_string(ARROW_TEST_LINKAGE
+ "Linkage of Arrow libraries with unit tests executables."
+ "${ARROW_TEST_LINKAGE_DEFAULT}"
+ "shared"
+ "static")
+
+ define_option(ARROW_FUZZING "Build Arrow Fuzzing executables" OFF)
+
+ define_option(ARROW_LARGE_MEMORY_TESTS "Enable unit tests which use large memory" OFF)
+
+ #----------------------------------------------------------------------
+ set_option_category("Lint")
+
+ define_option(ARROW_ONLY_LINT "Only define the lint and check-format targets" OFF)
+
+ define_option(ARROW_VERBOSE_LINT
+ "If off, 'quiet' flags will be passed to linting tools" OFF)
+
+ define_option(ARROW_GENERATE_COVERAGE "Build with C++ code coverage enabled" OFF)
+
+ #----------------------------------------------------------------------
+ set_option_category("Checks")
+
+ define_option(ARROW_TEST_MEMCHECK "Run the test suite using valgrind --tool=memcheck"
+ OFF)
+
+ define_option(ARROW_USE_ASAN "Enable Address Sanitizer checks" OFF)
+
+ define_option(ARROW_USE_TSAN "Enable Thread Sanitizer checks" OFF)
+
+ define_option(ARROW_USE_UBSAN "Enable Undefined Behavior sanitizer checks" OFF)
+
+ #----------------------------------------------------------------------
+ set_option_category("Project component")
+
+ define_option(ARROW_BUILD_UTILITIES "Build Arrow commandline utilities" OFF)
+
+ define_option(ARROW_COMPUTE "Build the Arrow Compute Modules" OFF)
+
+ define_option(ARROW_CSV "Build the Arrow CSV Parser Module" OFF)
+
+ define_option(ARROW_CUDA "Build the Arrow CUDA extensions (requires CUDA toolkit)" OFF)
+
+ define_option(ARROW_DATASET "Build the Arrow Dataset Modules" OFF)
+
+ define_option(ARROW_FILESYSTEM "Build the Arrow Filesystem Layer" OFF)
+
+ define_option(ARROW_FLIGHT
+ "Build the Arrow Flight RPC System (requires GRPC, Protocol Buffers)" OFF)
+
+ define_option(ARROW_GANDIVA "Build the Gandiva libraries" OFF)
+
+ define_option(ARROW_GCS
+ "Build Arrow with GCS support (requires the GCloud SDK for C++)" OFF)
+ mark_as_advanced(ARROW_GCS) # TODO(ARROW-1231) - remove once completed
+
+ define_option(ARROW_HDFS "Build the Arrow HDFS bridge" OFF)
+
+ define_option(ARROW_HIVESERVER2 "Build the HiveServer2 client and Arrow adapter" OFF)
+
+ define_option(ARROW_IPC "Build the Arrow IPC extensions" ON)
+
+ set(ARROW_JEMALLOC_DESCRIPTION "Build the Arrow jemalloc-based allocator")
+ if(WIN32 OR "${CMAKE_SYSTEM_NAME}" STREQUAL "FreeBSD")
+ # jemalloc is not supported on Windows.
+ #
+ # jemalloc is the default malloc implementation on FreeBSD and can't
+ # be built with --disable-libdl on FreeBSD. Because lazy-lock feature
+ # is required on FreeBSD. Lazy-lock feature requires libdl.
+ define_option(ARROW_JEMALLOC ${ARROW_JEMALLOC_DESCRIPTION} OFF)
+ else()
+ define_option(ARROW_JEMALLOC ${ARROW_JEMALLOC_DESCRIPTION} ON)
+ endif()
+
+ define_option(ARROW_JNI "Build the Arrow JNI lib" OFF)
+
+ define_option(ARROW_JSON "Build Arrow with JSON support (requires RapidJSON)" OFF)
+
+ define_option(ARROW_MIMALLOC "Build the Arrow mimalloc-based allocator" OFF)
+
+ define_option(ARROW_PARQUET "Build the Parquet libraries" OFF)
+
+ define_option(ARROW_ORC "Build the Arrow ORC adapter" OFF)
+
+ define_option(ARROW_PLASMA "Build the plasma object store along with Arrow" OFF)
+
+ define_option(ARROW_PLASMA_JAVA_CLIENT "Build the plasma object store java client" OFF)
+
+ define_option(ARROW_PYTHON "Build the Arrow CPython extensions" OFF)
+
+ define_option(ARROW_S3 "Build Arrow with S3 support (requires the AWS SDK for C++)" OFF)
+
+ define_option(ARROW_TENSORFLOW "Build Arrow with TensorFlow support enabled" OFF)
+
+ define_option(ARROW_TESTING "Build the Arrow testing libraries" OFF)
+
+ #----------------------------------------------------------------------
+ set_option_category("Thirdparty toolchain")
+
+ # Determine how we will look for dependencies
+ # * AUTO: Guess which packaging systems we're running in and pull the
+ # dependencies from there. Build any missing ones through the
+ # ExternalProject setup. This is the default unless the CONDA_PREFIX
+ # environment variable is set, in which case the CONDA method is used
+ # * BUNDLED: Build dependencies through CMake's ExternalProject facility. If
+ # you wish to build individual dependencies from source instead of using
+ # one of the other methods, pass -D$NAME_SOURCE=BUNDLED
+ # * SYSTEM: Use CMake's find_package and find_library without any custom
+ # paths. If individual packages are on non-default locations, you can pass
+ # $NAME_ROOT arguments to CMake, or set environment variables for the same
+ # with CMake 3.11 and higher. If your system packages are in a non-default
+ # location, or if you are using a non-standard toolchain, you can also pass
+ # ARROW_PACKAGE_PREFIX to set the *_ROOT variables to look in that
+ # directory
+ # * CONDA: Same as SYSTEM but set all *_ROOT variables to
+ # ENV{CONDA_PREFIX}. If this is run within an active conda environment,
+ # then ENV{CONDA_PREFIX} will be used for dependencies unless
+ # ARROW_DEPENDENCY_SOURCE is set explicitly to one of the other options
+ # * VCPKG: Searches for dependencies installed by vcpkg.
+ # * BREW: Use SYSTEM but search for select packages with brew.
+ if(NOT "$ENV{CONDA_PREFIX}" STREQUAL "")
+ set(ARROW_DEPENDENCY_SOURCE_DEFAULT "CONDA")
+ else()
+ set(ARROW_DEPENDENCY_SOURCE_DEFAULT "AUTO")
+ endif()
+ define_option_string(ARROW_DEPENDENCY_SOURCE
+ "Method to use for acquiring arrow's build dependencies"
+ "${ARROW_DEPENDENCY_SOURCE_DEFAULT}"
+ "AUTO"
+ "BUNDLED"
+ "SYSTEM"
+ "CONDA"
+ "VCPKG"
+ "BREW")
+
+ define_option(ARROW_VERBOSE_THIRDPARTY_BUILD
+ "Show output from ExternalProjects rather than just logging to files" OFF)
+
+ define_option(ARROW_DEPENDENCY_USE_SHARED "Link to shared libraries" ON)
+
+ define_option(ARROW_BOOST_USE_SHARED "Rely on boost shared libraries where relevant"
+ ${ARROW_DEPENDENCY_USE_SHARED})
+
+ define_option(ARROW_BROTLI_USE_SHARED "Rely on Brotli shared libraries where relevant"
+ ${ARROW_DEPENDENCY_USE_SHARED})
+
+ define_option(ARROW_BZ2_USE_SHARED "Rely on Bz2 shared libraries where relevant"
+ ${ARROW_DEPENDENCY_USE_SHARED})
+
+ define_option(ARROW_GFLAGS_USE_SHARED "Rely on GFlags shared libraries where relevant"
+ ${ARROW_DEPENDENCY_USE_SHARED})
+
+ define_option(ARROW_GRPC_USE_SHARED "Rely on gRPC shared libraries where relevant"
+ ${ARROW_DEPENDENCY_USE_SHARED})
+
+ define_option(ARROW_LZ4_USE_SHARED "Rely on lz4 shared libraries where relevant"
+ ${ARROW_DEPENDENCY_USE_SHARED})
+
+ define_option(ARROW_OPENSSL_USE_SHARED
+ "Rely on OpenSSL shared libraries where relevant"
+ ${ARROW_DEPENDENCY_USE_SHARED})
+
+ define_option(ARROW_PROTOBUF_USE_SHARED
+ "Rely on Protocol Buffers shared libraries where relevant"
+ ${ARROW_DEPENDENCY_USE_SHARED})
+
+ if(WIN32)
+ # It seems that Thrift doesn't support DLL well yet.
+ # MSYS2, conda-forge and vcpkg don't build shared library.
+ set(ARROW_THRIFT_USE_SHARED_DEFAULT OFF)
+ else()
+ set(ARROW_THRIFT_USE_SHARED_DEFAULT ${ARROW_DEPENDENCY_USE_SHARED})
+ endif()
+ define_option(ARROW_THRIFT_USE_SHARED "Rely on thrift shared libraries where relevant"
+ ${ARROW_THRIFT_USE_SHARED_DEFAULT})
+
+ define_option(ARROW_UTF8PROC_USE_SHARED
+ "Rely on utf8proc shared libraries where relevant"
+ ${ARROW_DEPENDENCY_USE_SHARED})
+
+ define_option(ARROW_SNAPPY_USE_SHARED "Rely on snappy shared libraries where relevant"
+ ${ARROW_DEPENDENCY_USE_SHARED})
+
+ define_option(ARROW_UTF8PROC_USE_SHARED
+ "Rely on utf8proc shared libraries where relevant"
+ ${ARROW_DEPENDENCY_USE_SHARED})
+
+ define_option(ARROW_ZSTD_USE_SHARED "Rely on zstd shared libraries where relevant"
+ ${ARROW_DEPENDENCY_USE_SHARED})
+
+ define_option(ARROW_USE_GLOG "Build libraries with glog support for pluggable logging"
+ OFF)
+
+ define_option(ARROW_WITH_BACKTRACE "Build with backtrace support" ON)
+
+ define_option(ARROW_WITH_BROTLI "Build with Brotli compression" OFF)
+ define_option(ARROW_WITH_BZ2 "Build with BZ2 compression" OFF)
+ define_option(ARROW_WITH_LZ4 "Build with lz4 compression" OFF)
+ define_option(ARROW_WITH_SNAPPY "Build with Snappy compression" OFF)
+ define_option(ARROW_WITH_ZLIB "Build with zlib compression" OFF)
+ define_option(ARROW_WITH_ZSTD "Build with zstd compression" OFF)
+
+ define_option(ARROW_WITH_UTF8PROC
+ "Build with support for Unicode properties using the utf8proc library;(only used if ARROW_COMPUTE is ON or ARROW_GANDIVA is ON)"
+ ON)
+ define_option(ARROW_WITH_RE2
+ "Build with support for regular expressions using the re2 library;(only used if ARROW_COMPUTE or ARROW_GANDIVA is ON)"
+ ON)
+
+ #----------------------------------------------------------------------
+ if(MSVC_TOOLCHAIN)
+ set_option_category("MSVC")
+
+ define_option(MSVC_LINK_VERBOSE
+ "Pass verbose linking options when linking libraries and executables"
+ OFF)
+
+ define_option_string(BROTLI_MSVC_STATIC_LIB_SUFFIX
+ "Brotli static lib suffix used on Windows with MSVC" "-static")
+
+ define_option_string(PROTOBUF_MSVC_STATIC_LIB_SUFFIX
+ "Protobuf static lib suffix used on Windows with MSVC" "")
+
+ define_option_string(RE2_MSVC_STATIC_LIB_SUFFIX
+ "re2 static lib suffix used on Windows with MSVC" "_static")
+
+ if(DEFINED ENV{CONDA_PREFIX})
+ # Conda package changes the output name.
+ # https://github.com/conda-forge/snappy-feedstock/blob/master/recipe/windows-static-lib-name.patch
+ set(SNAPPY_MSVC_STATIC_LIB_SUFFIX_DEFAULT "_static")
+ else()
+ set(SNAPPY_MSVC_STATIC_LIB_SUFFIX_DEFAULT "")
+ endif()
+ define_option_string(SNAPPY_MSVC_STATIC_LIB_SUFFIX
+ "Snappy static lib suffix used on Windows with MSVC"
+ "${SNAPPY_MSVC_STATIC_LIB_SUFFIX_DEFAULT}")
+
+ define_option_string(LZ4_MSVC_STATIC_LIB_SUFFIX
+ "Lz4 static lib suffix used on Windows with MSVC" "_static")
+
+ define_option_string(ZSTD_MSVC_STATIC_LIB_SUFFIX
+ "ZStd static lib suffix used on Windows with MSVC" "_static")
+
+ define_option(ARROW_USE_STATIC_CRT "Build Arrow with statically linked CRT" OFF)
+ endif()
+
+ #----------------------------------------------------------------------
+ set_option_category("Parquet")
+
+ define_option(PARQUET_MINIMAL_DEPENDENCY
+ "Depend only on Thirdparty headers to build libparquet.;\
+Always OFF if building binaries" OFF)
+
+ define_option(PARQUET_BUILD_EXECUTABLES
+ "Build the Parquet executable CLI tools. Requires static libraries to be built."
+ OFF)
+
+ define_option(PARQUET_BUILD_EXAMPLES
+ "Build the Parquet examples. Requires static libraries to be built." OFF)
+
+ define_option(PARQUET_REQUIRE_ENCRYPTION
+ "Build support for encryption. Fail if OpenSSL is not found" OFF)
+
+ #----------------------------------------------------------------------
+ set_option_category("Gandiva")
+
+ define_option(ARROW_GANDIVA_JAVA "Build the Gandiva JNI wrappers" OFF)
+
+ # ARROW-3860: Temporary workaround
+ define_option(ARROW_GANDIVA_STATIC_LIBSTDCPP
+ "Include -static-libstdc++ -static-libgcc when linking with;Gandiva static libraries"
+ OFF)
+
+ define_option_string(ARROW_GANDIVA_PC_CXX_FLAGS
+ "Compiler flags to append when pre-compiling Gandiva operations"
+ "")
+
+ #----------------------------------------------------------------------
+ set_option_category("Advanced developer")
+
+ define_option(ARROW_EXTRA_ERROR_CONTEXT
+ "Compile with extra error context (line numbers, code)" OFF)
+
+ define_option(ARROW_OPTIONAL_INSTALL
+ "If enabled install ONLY targets that have already been built. Please be;\
+advised that if this is enabled 'install' will fail silently on components;\
+that have not been built"
+ OFF)
+
+ option(ARROW_BUILD_CONFIG_SUMMARY_JSON "Summarize build configuration in a JSON file"
+ ON)
+endif()
+
+macro(validate_config)
+ foreach(category ${ARROW_OPTION_CATEGORIES})
+ set(option_names ${ARROW_${category}_OPTION_NAMES})
+
+ foreach(name ${option_names})
+ set(possible_values ${${name}_OPTION_POSSIBLE_VALUES})
+ set(value "${${name}}")
+ if(possible_values)
+ if(NOT "${value}" IN_LIST possible_values)
+ message(FATAL_ERROR "Configuration option ${name} got invalid value '${value}'. "
+ "Allowed values: ${${name}_OPTION_ENUM}.")
+ endif()
+ endif()
+ endforeach()
+
+ endforeach()
+endmacro()
+
+macro(config_summary_message)
+ message(STATUS "---------------------------------------------------------------------")
+ message(STATUS "Arrow version: ${ARROW_VERSION}")
+ message(STATUS)
+ message(STATUS "Build configuration summary:")
+
+ message(STATUS " Generator: ${CMAKE_GENERATOR}")
+ message(STATUS " Build type: ${CMAKE_BUILD_TYPE}")
+ message(STATUS " Source directory: ${CMAKE_CURRENT_SOURCE_DIR}")
+ message(STATUS " Install prefix: ${CMAKE_INSTALL_PREFIX}")
+ if(${CMAKE_EXPORT_COMPILE_COMMANDS})
+ message(STATUS " Compile commands: ${CMAKE_CURRENT_BINARY_DIR}/compile_commands.json"
+ )
+ endif()
+
+ foreach(category ${ARROW_OPTION_CATEGORIES})
+
+ message(STATUS)
+ message(STATUS "${category} options:")
+ message(STATUS)
+
+ set(option_names ${ARROW_${category}_OPTION_NAMES})
+
+ foreach(name ${option_names})
+ set(value "${${name}}")
+ if("${value}" STREQUAL "")
+ set(value "\"\"")
+ endif()
+
+ set(description ${${name}_OPTION_DESCRIPTION})
+
+ if(NOT ("${${name}_OPTION_ENUM}" STREQUAL ""))
+ set(summary "=${value} [default=${${name}_OPTION_ENUM}]")
+ else()
+ set(summary "=${value} [default=${${name}_OPTION_DEFAULT}]")
+ endif()
+
+ message(STATUS " ${name}${summary}")
+ foreach(description_line ${description})
+ message(STATUS " ${description_line}")
+ endforeach()
+ endforeach()
+
+ endforeach()
+
+endmacro()
+
+macro(config_summary_json)
+ set(summary "${CMAKE_CURRENT_BINARY_DIR}/cmake_summary.json")
+ message(STATUS " Outputting build configuration summary to ${summary}")
+ file(WRITE ${summary} "{\n")
+
+ foreach(category ${ARROW_OPTION_CATEGORIES})
+ foreach(name ${ARROW_${category}_OPTION_NAMES})
+ file(APPEND ${summary} "\"${name}\": \"${${name}}\",\n")
+ endforeach()
+ endforeach()
+
+ file(APPEND ${summary} "\"generator\": \"${CMAKE_GENERATOR}\",\n")
+ file(APPEND ${summary} "\"build_type\": \"${CMAKE_BUILD_TYPE}\",\n")
+ file(APPEND ${summary} "\"source_dir\": \"${CMAKE_CURRENT_SOURCE_DIR}\",\n")
+ if(${CMAKE_EXPORT_COMPILE_COMMANDS})
+ file(APPEND ${summary} "\"compile_commands\": "
+ "\"${CMAKE_CURRENT_BINARY_DIR}/compile_commands.json\",\n")
+ endif()
+ file(APPEND ${summary} "\"install_prefix\": \"${CMAKE_INSTALL_PREFIX}\",\n")
+ file(APPEND ${summary} "\"arrow_version\": \"${ARROW_VERSION}\"\n")
+ file(APPEND ${summary} "}\n")
+endmacro()
+
+macro(config_summary_cmake_setters path)
+ file(WRITE ${path} "# Options used to build arrow:")
+
+ foreach(category ${ARROW_OPTION_CATEGORIES})
+ file(APPEND ${path} "\n\n## ${category} options:")
+ foreach(name ${ARROW_${category}_OPTION_NAMES})
+ set(description ${${name}_OPTION_DESCRIPTION})
+ foreach(description_line ${description})
+ file(APPEND ${path} "\n### ${description_line}")
+ endforeach()
+ file(APPEND ${path} "\nset(${name} \"${${name}}\")")
+ endforeach()
+ endforeach()
+
+endmacro()
+
+#----------------------------------------------------------------------
+# Compute default values for omitted variables
+
+if(NOT ARROW_GIT_ID)
+ execute_process(COMMAND "git" "log" "-n1" "--format=%H"
+ WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
+ OUTPUT_VARIABLE ARROW_GIT_ID
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+endif()
+if(NOT ARROW_GIT_DESCRIPTION)
+ execute_process(COMMAND "git" "describe" "--tags" "--dirty"
+ WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
+ ERROR_QUIET
+ OUTPUT_VARIABLE ARROW_GIT_DESCRIPTION
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindArrow.cmake b/src/arrow/cpp/cmake_modules/FindArrow.cmake
new file mode 100644
index 000000000..68024cc27
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindArrow.cmake
@@ -0,0 +1,466 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# - Find Arrow (arrow/api.h, libarrow.a, libarrow.so)
+# This module defines
+# ARROW_FOUND, whether Arrow has been found
+# ARROW_FULL_SO_VERSION, full shared object version of found Arrow "100.0.0"
+# ARROW_IMPORT_LIB, path to libarrow's import library (Windows only)
+# ARROW_INCLUDE_DIR, directory containing headers
+# ARROW_LIBS, deprecated. Use ARROW_LIB_DIR instead
+# ARROW_LIB_DIR, directory containing Arrow libraries
+# ARROW_SHARED_IMP_LIB, deprecated. Use ARROW_IMPORT_LIB instead
+# ARROW_SHARED_LIB, path to libarrow's shared library
+# ARROW_SO_VERSION, shared object version of found Arrow such as "100"
+# ARROW_STATIC_LIB, path to libarrow.a
+# ARROW_VERSION, version of found Arrow
+# ARROW_VERSION_MAJOR, major version of found Arrow
+# ARROW_VERSION_MINOR, minor version of found Arrow
+# ARROW_VERSION_PATCH, patch version of found Arrow
+
+if(DEFINED ARROW_FOUND)
+ return()
+endif()
+
+include(FindPkgConfig)
+include(FindPackageHandleStandardArgs)
+
+if(WIN32 AND NOT MINGW)
+ # This is used to handle builds using e.g. clang in an MSVC setting.
+ set(MSVC_TOOLCHAIN TRUE)
+else()
+ set(MSVC_TOOLCHAIN FALSE)
+endif()
+
+set(ARROW_SEARCH_LIB_PATH_SUFFIXES)
+if(CMAKE_LIBRARY_ARCHITECTURE)
+ list(APPEND ARROW_SEARCH_LIB_PATH_SUFFIXES "lib/${CMAKE_LIBRARY_ARCHITECTURE}")
+endif()
+list(APPEND
+ ARROW_SEARCH_LIB_PATH_SUFFIXES
+ "lib64"
+ "lib32"
+ "lib"
+ "bin")
+set(ARROW_CONFIG_SUFFIXES
+ "_RELEASE"
+ "_RELWITHDEBINFO"
+ "_MINSIZEREL"
+ "_DEBUG"
+ "")
+if(CMAKE_BUILD_TYPE)
+ string(TOUPPER ${CMAKE_BUILD_TYPE} ARROW_CONFIG_SUFFIX_PREFERRED)
+ set(ARROW_CONFIG_SUFFIX_PREFERRED "_${ARROW_CONFIG_SUFFIX_PREFERRED}")
+ list(INSERT ARROW_CONFIG_SUFFIXES 0 "${ARROW_CONFIG_SUFFIX_PREFERRED}")
+endif()
+
+if(NOT DEFINED ARROW_MSVC_STATIC_LIB_SUFFIX)
+ if(MSVC_TOOLCHAIN)
+ set(ARROW_MSVC_STATIC_LIB_SUFFIX "_static")
+ else()
+ set(ARROW_MSVC_STATIC_LIB_SUFFIX "")
+ endif()
+endif()
+
+# Internal function.
+#
+# Set shared library name for ${base_name} to ${output_variable}.
+#
+# Example:
+# arrow_build_shared_library_name(ARROW_SHARED_LIBRARY_NAME arrow)
+# # -> ARROW_SHARED_LIBRARY_NAME=libarrow.so on Linux
+# # -> ARROW_SHARED_LIBRARY_NAME=libarrow.dylib on macOS
+# # -> ARROW_SHARED_LIBRARY_NAME=arrow.dll with MSVC on Windows
+# # -> ARROW_SHARED_LIBRARY_NAME=libarrow.dll with MinGW on Windows
+function(arrow_build_shared_library_name output_variable base_name)
+ set(${output_variable}
+ "${CMAKE_SHARED_LIBRARY_PREFIX}${base_name}${CMAKE_SHARED_LIBRARY_SUFFIX}"
+ PARENT_SCOPE)
+endfunction()
+
+# Internal function.
+#
+# Set import library name for ${base_name} to ${output_variable}.
+# This is useful only for MSVC build. Import library is used only
+# with MSVC build.
+#
+# Example:
+# arrow_build_import_library_name(ARROW_IMPORT_LIBRARY_NAME arrow)
+# # -> ARROW_IMPORT_LIBRARY_NAME=arrow on Linux (meaningless)
+# # -> ARROW_IMPORT_LIBRARY_NAME=arrow on macOS (meaningless)
+# # -> ARROW_IMPORT_LIBRARY_NAME=arrow.lib with MSVC on Windows
+# # -> ARROW_IMPORT_LIBRARY_NAME=libarrow.dll.a with MinGW on Windows
+function(arrow_build_import_library_name output_variable base_name)
+ set(${output_variable}
+ "${CMAKE_IMPORT_LIBRARY_PREFIX}${base_name}${CMAKE_IMPORT_LIBRARY_SUFFIX}"
+ PARENT_SCOPE)
+endfunction()
+
+# Internal function.
+#
+# Set static library name for ${base_name} to ${output_variable}.
+#
+# Example:
+# arrow_build_static_library_name(ARROW_STATIC_LIBRARY_NAME arrow)
+# # -> ARROW_STATIC_LIBRARY_NAME=libarrow.a on Linux
+# # -> ARROW_STATIC_LIBRARY_NAME=libarrow.a on macOS
+# # -> ARROW_STATIC_LIBRARY_NAME=arrow.lib with MSVC on Windows
+# # -> ARROW_STATIC_LIBRARY_NAME=libarrow.dll.a with MinGW on Windows
+function(arrow_build_static_library_name output_variable base_name)
+ set(${output_variable}
+ "${CMAKE_STATIC_LIBRARY_PREFIX}${base_name}${ARROW_MSVC_STATIC_LIB_SUFFIX}${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ PARENT_SCOPE)
+endfunction()
+
+# Internal function.
+#
+# Set macro value for ${macro_name} in ${header_content} to ${output_variable}.
+#
+# Example:
+# arrow_extract_macro_value(version_major
+# "ARROW_VERSION_MAJOR"
+# "#define ARROW_VERSION_MAJOR 1.0.0")
+# # -> version_major=1.0.0
+function(arrow_extract_macro_value output_variable macro_name header_content)
+ string(REGEX MATCH "#define +${macro_name} +[^\r\n]+" macro_definition
+ "${header_content}")
+ string(REGEX REPLACE "^#define +${macro_name} +(.+)$" "\\1" macro_value
+ "${macro_definition}")
+ set(${output_variable}
+ "${macro_value}"
+ PARENT_SCOPE)
+endfunction()
+
+# Internal macro only for arrow_find_package.
+#
+# Find package in HOME.
+macro(arrow_find_package_home)
+ find_path(${prefix}_include_dir "${header_path}"
+ PATHS "${home}"
+ PATH_SUFFIXES "include"
+ NO_DEFAULT_PATH)
+ set(include_dir "${${prefix}_include_dir}")
+ set(${prefix}_INCLUDE_DIR
+ "${include_dir}"
+ PARENT_SCOPE)
+
+ if(MSVC_TOOLCHAIN)
+ set(CMAKE_SHARED_LIBRARY_SUFFIXES_ORIGINAL ${CMAKE_FIND_LIBRARY_SUFFIXES})
+ # .dll isn't found by find_library with MSVC because .dll isn't included in
+ # CMAKE_FIND_LIBRARY_SUFFIXES.
+ list(APPEND CMAKE_FIND_LIBRARY_SUFFIXES "${CMAKE_SHARED_LIBRARY_SUFFIX}")
+ endif()
+ find_library(${prefix}_shared_lib
+ NAMES "${shared_lib_name}"
+ PATHS "${home}"
+ PATH_SUFFIXES ${ARROW_SEARCH_LIB_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ if(MSVC_TOOLCHAIN)
+ set(CMAKE_SHARED_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES_ORIGINAL})
+ endif()
+ set(shared_lib "${${prefix}_shared_lib}")
+ set(${prefix}_SHARED_LIB
+ "${shared_lib}"
+ PARENT_SCOPE)
+ if(shared_lib)
+ add_library(${target_shared} SHARED IMPORTED)
+ set_target_properties(${target_shared} PROPERTIES IMPORTED_LOCATION "${shared_lib}")
+ if(include_dir)
+ set_target_properties(${target_shared} PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
+ "${include_dir}")
+ endif()
+ find_library(${prefix}_import_lib
+ NAMES "${import_lib_name}"
+ PATHS "${home}"
+ PATH_SUFFIXES ${ARROW_SEARCH_LIB_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ set(import_lib "${${prefix}_import_lib}")
+ set(${prefix}_IMPORT_LIB
+ "${import_lib}"
+ PARENT_SCOPE)
+ if(import_lib)
+ set_target_properties(${target_shared} PROPERTIES IMPORTED_IMPLIB "${import_lib}")
+ endif()
+ endif()
+
+ find_library(${prefix}_static_lib
+ NAMES "${static_lib_name}"
+ PATHS "${home}"
+ PATH_SUFFIXES ${ARROW_SEARCH_LIB_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ set(static_lib "${${prefix}_static_lib}")
+ set(${prefix}_STATIC_LIB
+ "${static_lib}"
+ PARENT_SCOPE)
+ if(static_lib)
+ add_library(${target_static} STATIC IMPORTED)
+ set_target_properties(${target_static} PROPERTIES IMPORTED_LOCATION "${static_lib}")
+ if(include_dir)
+ set_target_properties(${target_static} PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
+ "${include_dir}")
+ endif()
+ endif()
+endmacro()
+
+# Internal macro only for arrow_find_package.
+#
+# Find package by CMake package configuration.
+macro(arrow_find_package_cmake_package_configuration)
+ find_package(${cmake_package_name} CONFIG)
+ if(${cmake_package_name}_FOUND)
+ set(${prefix}_USE_CMAKE_PACKAGE_CONFIG
+ TRUE
+ PARENT_SCOPE)
+ if(TARGET ${target_shared})
+ foreach(suffix ${ARROW_CONFIG_SUFFIXES})
+ get_target_property(shared_lib ${target_shared} IMPORTED_LOCATION${suffix})
+ if(shared_lib)
+ # Remove shared library version:
+ # libarrow.so.100.0.0 -> libarrow.so
+ # Because ARROW_HOME and pkg-config approaches don't add
+ # shared library version.
+ string(REGEX REPLACE "(${CMAKE_SHARED_LIBRARY_SUFFIX})[.0-9]+$" "\\1"
+ shared_lib "${shared_lib}")
+ set(${prefix}_SHARED_LIB
+ "${shared_lib}"
+ PARENT_SCOPE)
+ break()
+ endif()
+ endforeach()
+ endif()
+ if(TARGET ${target_static})
+ foreach(suffix ${ARROW_CONFIG_SUFFIXES})
+ get_target_property(static_lib ${target_static} IMPORTED_LOCATION${suffix})
+ if(static_lib)
+ set(${prefix}_STATIC_LIB
+ "${static_lib}"
+ PARENT_SCOPE)
+ break()
+ endif()
+ endforeach()
+ endif()
+ endif()
+endmacro()
+
+# Internal macro only for arrow_find_package.
+#
+# Find package by pkg-config.
+macro(arrow_find_package_pkg_config)
+ pkg_check_modules(${prefix}_PC ${pkg_config_name})
+ if(${prefix}_PC_FOUND)
+ set(${prefix}_USE_PKG_CONFIG
+ TRUE
+ PARENT_SCOPE)
+
+ set(include_dir "${${prefix}_PC_INCLUDEDIR}")
+ set(lib_dir "${${prefix}_PC_LIBDIR}")
+ set(shared_lib_paths "${${prefix}_PC_LINK_LIBRARIES}")
+ # Use the first shared library path as the IMPORTED_LOCATION
+ # for ${target_shared}. This assumes that the first shared library
+ # path is the shared library path for this module.
+ list(GET shared_lib_paths 0 first_shared_lib_path)
+ # Use the rest shared library paths as the INTERFACE_LINK_LIBRARIES
+ # for ${target_shared}. This assumes that the rest shared library
+ # paths are dependency library paths for this module.
+ list(LENGTH shared_lib_paths n_shared_lib_paths)
+ if(n_shared_lib_paths LESS_EQUAL 1)
+ set(rest_shared_lib_paths)
+ else()
+ list(SUBLIST
+ shared_lib_paths
+ 1
+ -1
+ rest_shared_lib_paths)
+ endif()
+
+ set(${prefix}_VERSION
+ "${${prefix}_PC_VERSION}"
+ PARENT_SCOPE)
+ set(${prefix}_INCLUDE_DIR
+ "${include_dir}"
+ PARENT_SCOPE)
+ set(${prefix}_SHARED_LIB
+ "${first_shared_lib_path}"
+ PARENT_SCOPE)
+
+ add_library(${target_shared} SHARED IMPORTED)
+ set_target_properties(${target_shared}
+ PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${include_dir}"
+ INTERFACE_LINK_LIBRARIES "${rest_shared_lib_paths}"
+ IMPORTED_LOCATION "${first_shared_lib_path}")
+ get_target_property(shared_lib ${target_shared} IMPORTED_LOCATION)
+
+ find_library(${prefix}_static_lib
+ NAMES "${static_lib_name}"
+ PATHS "${lib_dir}"
+ NO_DEFAULT_PATH)
+ set(static_lib "${${prefix}_static_lib}")
+ set(${prefix}_STATIC_LIB
+ "${static_lib}"
+ PARENT_SCOPE)
+ if(static_lib)
+ add_library(${target_static} STATIC IMPORTED)
+ set_target_properties(${target_static}
+ PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${include_dir}"
+ IMPORTED_LOCATION "${static_lib}")
+ endif()
+ endif()
+endmacro()
+
+function(arrow_find_package
+ prefix
+ home
+ base_name
+ header_path
+ cmake_package_name
+ pkg_config_name)
+ arrow_build_shared_library_name(shared_lib_name ${base_name})
+ arrow_build_import_library_name(import_lib_name ${base_name})
+ arrow_build_static_library_name(static_lib_name ${base_name})
+
+ set(target_shared ${base_name}_shared)
+ set(target_static ${base_name}_static)
+
+ if(home)
+ arrow_find_package_home()
+ set(${prefix}_FIND_APPROACH
+ "HOME: ${home}"
+ PARENT_SCOPE)
+ else()
+ arrow_find_package_cmake_package_configuration()
+ if(${cmake_package_name}_FOUND)
+ set(${prefix}_FIND_APPROACH
+ "CMake package configuration: ${cmake_package_name}"
+ PARENT_SCOPE)
+ else()
+ arrow_find_package_pkg_config()
+ set(${prefix}_FIND_APPROACH
+ "pkg-config: ${pkg_config_name}"
+ PARENT_SCOPE)
+ endif()
+ endif()
+
+ if(NOT include_dir)
+ if(TARGET ${target_shared})
+ get_target_property(include_dir ${target_shared} INTERFACE_INCLUDE_DIRECTORIES)
+ elseif(TARGET ${target_static})
+ get_target_property(include_dir ${target_static} INTERFACE_INCLUDE_DIRECTORIES)
+ endif()
+ endif()
+ if(include_dir)
+ set(${prefix}_INCLUDE_DIR
+ "${include_dir}"
+ PARENT_SCOPE)
+ endif()
+
+ if(shared_lib)
+ get_filename_component(lib_dir "${shared_lib}" DIRECTORY)
+ elseif(static_lib)
+ get_filename_component(lib_dir "${static_lib}" DIRECTORY)
+ else()
+ set(lib_dir NOTFOUND)
+ endif()
+ set(${prefix}_LIB_DIR
+ "${lib_dir}"
+ PARENT_SCOPE)
+ # For backward compatibility
+ set(${prefix}_LIBS
+ "${lib_dir}"
+ PARENT_SCOPE)
+endfunction()
+
+if(NOT "$ENV{ARROW_HOME}" STREQUAL "")
+ file(TO_CMAKE_PATH "$ENV{ARROW_HOME}" ARROW_HOME)
+endif()
+arrow_find_package(ARROW
+ "${ARROW_HOME}"
+ arrow
+ arrow/api.h
+ Arrow
+ arrow)
+
+if(ARROW_HOME)
+ if(ARROW_INCLUDE_DIR)
+ file(READ "${ARROW_INCLUDE_DIR}/arrow/util/config.h" ARROW_CONFIG_H_CONTENT)
+ arrow_extract_macro_value(ARROW_VERSION_MAJOR "ARROW_VERSION_MAJOR"
+ "${ARROW_CONFIG_H_CONTENT}")
+ arrow_extract_macro_value(ARROW_VERSION_MINOR "ARROW_VERSION_MINOR"
+ "${ARROW_CONFIG_H_CONTENT}")
+ arrow_extract_macro_value(ARROW_VERSION_PATCH "ARROW_VERSION_PATCH"
+ "${ARROW_CONFIG_H_CONTENT}")
+ if("${ARROW_VERSION_MAJOR}" STREQUAL ""
+ OR "${ARROW_VERSION_MINOR}" STREQUAL ""
+ OR "${ARROW_VERSION_PATCH}" STREQUAL "")
+ set(ARROW_VERSION "0.0.0")
+ else()
+ set(ARROW_VERSION
+ "${ARROW_VERSION_MAJOR}.${ARROW_VERSION_MINOR}.${ARROW_VERSION_PATCH}")
+ endif()
+
+ arrow_extract_macro_value(ARROW_SO_VERSION_QUOTED "ARROW_SO_VERSION"
+ "${ARROW_CONFIG_H_CONTENT}")
+ string(REGEX REPLACE "^\"(.+)\"$" "\\1" ARROW_SO_VERSION "${ARROW_SO_VERSION_QUOTED}")
+ arrow_extract_macro_value(ARROW_FULL_SO_VERSION_QUOTED "ARROW_FULL_SO_VERSION"
+ "${ARROW_CONFIG_H_CONTENT}")
+ string(REGEX REPLACE "^\"(.+)\"$" "\\1" ARROW_FULL_SO_VERSION
+ "${ARROW_FULL_SO_VERSION_QUOTED}")
+ endif()
+else()
+ if(ARROW_USE_CMAKE_PACKAGE_CONFIG)
+ find_package(Arrow CONFIG)
+ elseif(ARROW_USE_PKG_CONFIG)
+ pkg_get_variable(ARROW_SO_VERSION arrow so_version)
+ pkg_get_variable(ARROW_FULL_SO_VERSION arrow full_so_version)
+ endif()
+endif()
+
+set(ARROW_ABI_VERSION ${ARROW_SO_VERSION})
+
+mark_as_advanced(ARROW_ABI_VERSION
+ ARROW_CONFIG_SUFFIXES
+ ARROW_FULL_SO_VERSION
+ ARROW_IMPORT_LIB
+ ARROW_INCLUDE_DIR
+ ARROW_LIBS
+ ARROW_LIB_DIR
+ ARROW_SEARCH_LIB_PATH_SUFFIXES
+ ARROW_SHARED_IMP_LIB
+ ARROW_SHARED_LIB
+ ARROW_SO_VERSION
+ ARROW_STATIC_LIB
+ ARROW_VERSION
+ ARROW_VERSION_MAJOR
+ ARROW_VERSION_MINOR
+ ARROW_VERSION_PATCH)
+
+find_package_handle_standard_args(
+ Arrow
+ REQUIRED_VARS # The first required variable is shown
+ # in the found message. So this list is
+ # not sorted alphabetically.
+ ARROW_INCLUDE_DIR ARROW_LIB_DIR ARROW_FULL_SO_VERSION ARROW_SO_VERSION
+ VERSION_VAR ARROW_VERSION)
+set(ARROW_FOUND ${Arrow_FOUND})
+
+if(Arrow_FOUND AND NOT Arrow_FIND_QUIETLY)
+ message(STATUS "Arrow version: ${ARROW_VERSION} (${ARROW_FIND_APPROACH})")
+ message(STATUS "Arrow SO and ABI version: ${ARROW_SO_VERSION}")
+ message(STATUS "Arrow full SO version: ${ARROW_FULL_SO_VERSION}")
+ message(STATUS "Found the Arrow core shared library: ${ARROW_SHARED_LIB}")
+ message(STATUS "Found the Arrow core import library: ${ARROW_IMPORT_LIB}")
+ message(STATUS "Found the Arrow core static library: ${ARROW_STATIC_LIB}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindArrowCUDA.cmake b/src/arrow/cpp/cmake_modules/FindArrowCUDA.cmake
new file mode 100644
index 000000000..014386f30
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindArrowCUDA.cmake
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# - Find Arrow CUDA (arrow/gpu/cuda_api.h, libarrow_cuda.a, libarrow_cuda.so)
+#
+# This module requires Arrow from which it uses
+# arrow_find_package()
+#
+# This module defines
+# ARROW_CUDA_FOUND, whether Arrow CUDA has been found
+# ARROW_CUDA_IMPORT_LIB, path to libarrow_cuda's import library (Windows only)
+# ARROW_CUDA_INCLUDE_DIR, directory containing headers
+# ARROW_CUDA_LIBS, deprecated. Use ARROW_CUDA_LIB_DIR instead
+# ARROW_CUDA_LIB_DIR, directory containing Arrow CUDA libraries
+# ARROW_CUDA_SHARED_IMP_LIB, deprecated. Use ARROW_CUDA_IMPORT_LIB instead
+# ARROW_CUDA_SHARED_LIB, path to libarrow_cuda's shared library
+# ARROW_CUDA_STATIC_LIB, path to libarrow_cuda.a
+
+if(DEFINED ARROW_CUDA_FOUND)
+ return()
+endif()
+
+set(find_package_arguments)
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION)
+ list(APPEND find_package_arguments "${${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION}")
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED)
+ list(APPEND find_package_arguments REQUIRED)
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY)
+ list(APPEND find_package_arguments QUIET)
+endif()
+find_package(Arrow ${find_package_arguments})
+
+if(ARROW_FOUND)
+ arrow_find_package(ARROW_CUDA
+ "${ARROW_HOME}"
+ arrow_cuda
+ arrow/gpu/cuda_api.h
+ ArrowCUDA
+ arrow-cuda)
+ if(NOT ARROW_CUDA_VERSION)
+ set(ARROW_CUDA_VERSION "${ARROW_VERSION}")
+ endif()
+endif()
+
+if("${ARROW_CUDA_VERSION}" VERSION_EQUAL "${ARROW_VERSION}")
+ set(ARROW_CUDA_VERSION_MATCH TRUE)
+else()
+ set(ARROW_CUDA_VERSION_MATCH FALSE)
+endif()
+
+mark_as_advanced(ARROW_CUDA_IMPORT_LIB
+ ARROW_CUDA_INCLUDE_DIR
+ ARROW_CUDA_LIBS
+ ARROW_CUDA_LIB_DIR
+ ARROW_CUDA_SHARED_IMP_LIB
+ ARROW_CUDA_SHARED_LIB
+ ARROW_CUDA_STATIC_LIB
+ ARROW_CUDA_VERSION
+ ARROW_CUDA_VERSION_MATCH)
+
+find_package_handle_standard_args(
+ ArrowCUDA
+ REQUIRED_VARS ARROW_CUDA_INCLUDE_DIR ARROW_CUDA_LIB_DIR ARROW_CUDA_VERSION_MATCH
+ VERSION_VAR ARROW_CUDA_VERSION)
+set(ARROW_CUDA_FOUND ${ArrowCUDA_FOUND})
+
+if(ArrowCUDA_FOUND AND NOT ArrowCUDA_FIND_QUIETLY)
+ message(STATUS "Found the Arrow CUDA by ${ARROW_CUDA_FIND_APPROACH}")
+ message(STATUS "Found the Arrow CUDA shared library: ${ARROW_CUDA_SHARED_LIB}")
+ message(STATUS "Found the Arrow CUDA import library: ${ARROW_CUDA_IMPORT_LIB}")
+ message(STATUS "Found the Arrow CUDA static library: ${ARROW_CUDA_STATIC_LIB}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindArrowDataset.cmake b/src/arrow/cpp/cmake_modules/FindArrowDataset.cmake
new file mode 100644
index 000000000..fe74f247f
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindArrowDataset.cmake
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# - Find Arrow Dataset (arrow/dataset/api.h, libarrow_dataset.a, libarrow_dataset.so)
+#
+# This module requires Arrow from which it uses
+# arrow_find_package()
+#
+# This module defines
+# ARROW_DATASET_FOUND, whether Arrow Dataset has been found
+# ARROW_DATASET_IMPORT_LIB,
+# path to libarrow_dataset's import library (Windows only)
+# ARROW_DATASET_INCLUDE_DIR, directory containing headers
+# ARROW_DATASET_LIB_DIR, directory containing Arrow Dataset libraries
+# ARROW_DATASET_SHARED_LIB, path to libarrow_dataset's shared library
+# ARROW_DATASET_STATIC_LIB, path to libarrow_dataset.a
+
+if(DEFINED ARROW_DATASET_FOUND)
+ return()
+endif()
+
+set(find_package_arguments)
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION)
+ list(APPEND find_package_arguments "${${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION}")
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED)
+ list(APPEND find_package_arguments REQUIRED)
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY)
+ list(APPEND find_package_arguments QUIET)
+endif()
+find_package(Arrow ${find_package_arguments})
+find_package(Parquet ${find_package_arguments})
+
+if(ARROW_FOUND AND PARQUET_FOUND)
+ arrow_find_package(ARROW_DATASET
+ "${ARROW_HOME}"
+ arrow_dataset
+ arrow/dataset/api.h
+ ArrowDataset
+ arrow-dataset)
+ if(NOT ARROW_DATASET_VERSION)
+ set(ARROW_DATASET_VERSION "${ARROW_VERSION}")
+ endif()
+endif()
+
+if("${ARROW_DATASET_VERSION}" VERSION_EQUAL "${ARROW_VERSION}")
+ set(ARROW_DATASET_VERSION_MATCH TRUE)
+else()
+ set(ARROW_DATASET_VERSION_MATCH FALSE)
+endif()
+
+mark_as_advanced(ARROW_DATASET_IMPORT_LIB
+ ARROW_DATASET_INCLUDE_DIR
+ ARROW_DATASET_LIBS
+ ARROW_DATASET_LIB_DIR
+ ARROW_DATASET_SHARED_IMP_LIB
+ ARROW_DATASET_SHARED_LIB
+ ARROW_DATASET_STATIC_LIB
+ ARROW_DATASET_VERSION
+ ARROW_DATASET_VERSION_MATCH)
+
+find_package_handle_standard_args(
+ ArrowDataset
+ REQUIRED_VARS ARROW_DATASET_INCLUDE_DIR ARROW_DATASET_LIB_DIR
+ ARROW_DATASET_VERSION_MATCH
+ VERSION_VAR ARROW_DATASET_VERSION)
+set(ARROW_DATASET_FOUND ${ArrowDataset_FOUND})
+
+if(ArrowDataset_FOUND AND NOT ArrowDataset_FIND_QUIETLY)
+ message(STATUS "Found the Arrow Dataset by ${ARROW_DATASET_FIND_APPROACH}")
+ message(STATUS "Found the Arrow Dataset shared library: ${ARROW_DATASET_SHARED_LIB}")
+ message(STATUS "Found the Arrow Dataset import library: ${ARROW_DATASET_IMPORT_LIB}")
+ message(STATUS "Found the Arrow Dataset static library: ${ARROW_DATASET_STATIC_LIB}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindArrowFlight.cmake b/src/arrow/cpp/cmake_modules/FindArrowFlight.cmake
new file mode 100644
index 000000000..805a4ff38
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindArrowFlight.cmake
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# - Find Arrow Flight (arrow/flight/api.h, libarrow_flight.a, libarrow_flight.so)
+#
+# This module requires Arrow from which it uses
+# arrow_find_package()
+#
+# This module defines
+# ARROW_FLIGHT_FOUND, whether Flight has been found
+# ARROW_FLIGHT_IMPORT_LIB,
+# path to libarrow_flight's import library (Windows only)
+# ARROW_FLIGHT_INCLUDE_DIR, directory containing headers
+# ARROW_FLIGHT_LIBS, deprecated. Use ARROW_FLIGHT_LIB_DIR instead
+# ARROW_FLIGHT_LIB_DIR, directory containing Flight libraries
+# ARROW_FLIGHT_SHARED_IMP_LIB, deprecated. Use ARROW_FLIGHT_IMPORT_LIB instead
+# ARROW_FLIGHT_SHARED_LIB, path to libarrow_flight's shared library
+# ARROW_FLIGHT_STATIC_LIB, path to libarrow_flight.a
+
+if(DEFINED ARROW_FLIGHT_FOUND)
+ return()
+endif()
+
+set(find_package_arguments)
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION)
+ list(APPEND find_package_arguments "${${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION}")
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED)
+ list(APPEND find_package_arguments REQUIRED)
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY)
+ list(APPEND find_package_arguments QUIET)
+endif()
+find_package(Arrow ${find_package_arguments})
+
+if(ARROW_FOUND)
+ arrow_find_package(ARROW_FLIGHT
+ "${ARROW_HOME}"
+ arrow_flight
+ arrow/flight/api.h
+ ArrowFlight
+ arrow-flight)
+ if(NOT ARROW_FLIGHT_VERSION)
+ set(ARROW_FLIGHT_VERSION "${ARROW_VERSION}")
+ endif()
+endif()
+
+if("${ARROW_FLIGHT_VERSION}" VERSION_EQUAL "${ARROW_VERSION}")
+ set(ARROW_FLIGHT_VERSION_MATCH TRUE)
+else()
+ set(ARROW_FLIGHT_VERSION_MATCH FALSE)
+endif()
+
+mark_as_advanced(ARROW_FLIGHT_IMPORT_LIB
+ ARROW_FLIGHT_INCLUDE_DIR
+ ARROW_FLIGHT_LIBS
+ ARROW_FLIGHT_LIB_DIR
+ ARROW_FLIGHT_SHARED_IMP_LIB
+ ARROW_FLIGHT_SHARED_LIB
+ ARROW_FLIGHT_STATIC_LIB
+ ARROW_FLIGHT_VERSION
+ ARROW_FLIGHT_VERSION_MATCH)
+
+find_package_handle_standard_args(
+ ArrowFlight
+ REQUIRED_VARS ARROW_FLIGHT_INCLUDE_DIR ARROW_FLIGHT_LIB_DIR ARROW_FLIGHT_VERSION_MATCH
+ VERSION_VAR ARROW_FLIGHT_VERSION)
+set(ARROW_FLIGHT_FOUND ${ArrowFlight_FOUND})
+
+if(ArrowFlight_FOUND AND NOT ArrowFlight_FIND_QUIETLY)
+ message(STATUS "Found the Arrow Flight by ${ARROW_FLIGHT_FIND_APPROACH}")
+ message(STATUS "Found the Arrow Flight shared library: ${ARROW_FLIGHT_SHARED_LIB}")
+ message(STATUS "Found the Arrow Flight import library: ${ARROW_FLIGHT_IMPORT_LIB}")
+ message(STATUS "Found the Arrow Flight static library: ${ARROW_FLIGHT_STATIC_LIB}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindArrowFlightTesting.cmake b/src/arrow/cpp/cmake_modules/FindArrowFlightTesting.cmake
new file mode 100644
index 000000000..c0756cf63
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindArrowFlightTesting.cmake
@@ -0,0 +1,98 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# - Find Arrow Flight testing library
+# (arrow/flight/test_util.h,
+# libarrow_flight_testing.a,
+# libarrow_flight_testing.so)
+#
+# This module requires Arrow from which it uses
+# arrow_find_package()
+#
+# This module defines
+# ARROW_FLIGHT_TESTING_FOUND,
+# whether Arrow Flight testing library has been found
+# ARROW_FLIGHT_TESTING_IMPORT_LIB,
+# path to libarrow_flight_testing's import library (Windows only)
+# ARROW_FLIGHT_TESTING_INCLUDE_DIR, directory containing headers
+# ARROW_FLIGHT_TESTING_LIB_DIR, directory containing Arrow testing libraries
+# ARROW_FLIGHT_TESTING_SHARED_LIB,
+# path to libarrow_flight_testing's shared library
+# ARROW_FLIGHT_TESTING_STATIC_LIB, path to libarrow_flight_testing.a
+
+if(DEFINED ARROW_FLIGHT_TESTING_FOUND)
+ return()
+endif()
+
+set(find_package_arguments)
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION)
+ list(APPEND find_package_arguments "${${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION}")
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED)
+ list(APPEND find_package_arguments REQUIRED)
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY)
+ list(APPEND find_package_arguments QUIET)
+endif()
+find_package(ArrowFlight ${find_package_arguments})
+find_package(ArrowTesting ${find_package_arguments})
+
+if(ARROW_TESTING_FOUND AND ARROW_FLIGHT_FOUND)
+ arrow_find_package(ARROW_FLIGHT_TESTING
+ "${ARROW_HOME}"
+ arrow_flight_testing
+ arrow/flight/test_util.h
+ ArrowFlightTesting
+ arrow-flight-testing)
+ if(NOT ARROW_FLIGHT_TESTING_VERSION)
+ set(ARROW_FLIGHT_TESTING_VERSION "${ARROW_VERSION}")
+ endif()
+endif()
+
+if("${ARROW_FLIGHT_TESTING_VERSION}" VERSION_EQUAL "${ARROW_VERSION}")
+ set(ARROW_FLIGHT_TESTING_VERSION_MATCH TRUE)
+else()
+ set(ARROW_FLIGHT_TESTING_VERSION_MATCH FALSE)
+endif()
+
+mark_as_advanced(ARROW_FLIGHT_TESTING_IMPORT_LIB
+ ARROW_FLIGHT_TESTING_INCLUDE_DIR
+ ARROW_FLIGHT_TESTING_LIBS
+ ARROW_FLIGHT_TESTING_LIB_DIR
+ ARROW_FLIGHT_TESTING_SHARED_IMP_LIB
+ ARROW_FLIGHT_TESTING_SHARED_LIB
+ ARROW_FLIGHT_TESTING_STATIC_LIB
+ ARROW_FLIGHT_TESTING_VERSION
+ ARROW_FLIGHT_TESTING_VERSION_MATCH)
+
+find_package_handle_standard_args(
+ ArrowFlightTesting
+ REQUIRED_VARS ARROW_FLIGHT_TESTING_INCLUDE_DIR ARROW_FLIGHT_TESTING_LIB_DIR
+ ARROW_FLIGHT_TESTING_VERSION_MATCH
+ VERSION_VAR ARROW_FLIGHT_TESTING_VERSION)
+set(ARROW_FLIGHT_TESTING_FOUND ${ArrowFlightTesting_FOUND})
+
+if(ArrowFlightTesting_FOUND AND NOT ArrowFlightTesting_FIND_QUIETLY)
+ message(STATUS "Found the Arrow Flight testing by ${ARROW_FLIGHT_TESTING_FIND_APPROACH}"
+ )
+ message(STATUS "Found the Arrow Flight testing shared library: ${ARROW_FLIGHT_TESTING_SHARED_LIB}"
+ )
+ message(STATUS "Found the Arrow Flight testing import library: ${ARROW_FLIGHT_TESTING_IMPORT_LIB}"
+ )
+ message(STATUS "Found the Arrow Flight testing static library: ${ARROW_FLIGHT_TESTING_STATIC_LIB}"
+ )
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindArrowPython.cmake b/src/arrow/cpp/cmake_modules/FindArrowPython.cmake
new file mode 100644
index 000000000..b503e6a9e
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindArrowPython.cmake
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# - Find Arrow Python (arrow/python/api.h, libarrow_python.a, libarrow_python.so)
+#
+# This module requires Arrow from which it uses
+# arrow_find_package()
+#
+# This module defines
+# ARROW_PYTHON_FOUND, whether Arrow Python has been found
+# ARROW_PYTHON_IMPORT_LIB,
+# path to libarrow_python's import library (Windows only)
+# ARROW_PYTHON_INCLUDE_DIR, directory containing headers
+# ARROW_PYTHON_LIB_DIR, directory containing Arrow Python libraries
+# ARROW_PYTHON_SHARED_LIB, path to libarrow_python's shared library
+# ARROW_PYTHON_STATIC_LIB, path to libarrow_python.a
+
+if(DEFINED ARROW_PYTHON_FOUND)
+ return()
+endif()
+
+set(find_package_arguments)
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION)
+ list(APPEND find_package_arguments "${${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION}")
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED)
+ list(APPEND find_package_arguments REQUIRED)
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY)
+ list(APPEND find_package_arguments QUIET)
+endif()
+find_package(Arrow ${find_package_arguments})
+
+if(ARROW_FOUND)
+ arrow_find_package(ARROW_PYTHON
+ "${ARROW_HOME}"
+ arrow_python
+ arrow/python/api.h
+ ArrowPython
+ arrow-python)
+ if(NOT ARROW_PYTHON_VERSION)
+ set(ARROW_PYTHON_VERSION "${ARROW_VERSION}")
+ endif()
+endif()
+
+if("${ARROW_PYTHON_VERSION}" VERSION_EQUAL "${ARROW_VERSION}")
+ set(ARROW_PYTHON_VERSION_MATCH TRUE)
+else()
+ set(ARROW_PYTHON_VERSION_MATCH FALSE)
+endif()
+
+mark_as_advanced(ARROW_PYTHON_IMPORT_LIB
+ ARROW_PYTHON_INCLUDE_DIR
+ ARROW_PYTHON_LIBS
+ ARROW_PYTHON_LIB_DIR
+ ARROW_PYTHON_SHARED_IMP_LIB
+ ARROW_PYTHON_SHARED_LIB
+ ARROW_PYTHON_STATIC_LIB
+ ARROW_PYTHON_VERSION
+ ARROW_PYTHON_VERSION_MATCH)
+
+find_package_handle_standard_args(
+ ArrowPython
+ REQUIRED_VARS ARROW_PYTHON_INCLUDE_DIR ARROW_PYTHON_LIB_DIR ARROW_PYTHON_VERSION_MATCH
+ VERSION_VAR ARROW_PYTHON_VERSION)
+set(ARROW_PYTHON_FOUND ${ArrowPython_FOUND})
+
+if(ArrowPython_FOUND AND NOT ArrowPython_FIND_QUIETLY)
+ message(STATUS "Found the Arrow Python by ${ARROW_PYTHON_FIND_APPROACH}")
+ message(STATUS "Found the Arrow Python shared library: ${ARROW_PYTHON_SHARED_LIB}")
+ message(STATUS "Found the Arrow Python import library: ${ARROW_PYTHON_IMPORT_LIB}")
+ message(STATUS "Found the Arrow Python static library: ${ARROW_PYTHON_STATIC_LIB}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindArrowPythonFlight.cmake b/src/arrow/cpp/cmake_modules/FindArrowPythonFlight.cmake
new file mode 100644
index 000000000..3a639928c
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindArrowPythonFlight.cmake
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# - Find Arrow Python Flight
+# (arrow/python/flight.h, libarrow_python_flight.a, libarrow_python_flight.so)
+#
+# This module requires Arrow from which it uses
+# arrow_find_package()
+#
+# This module defines
+# ARROW_PYTHON_FLIGHT_FOUND, whether Arrow Python Flight has been found
+# ARROW_PYTHON_FLIGHT_IMPORT_LIB,
+# path to libarrow_python_flight's import library (Windows only)
+# ARROW_PYTHON_FLIGHT_INCLUDE_DIR, directory containing headers
+# ARROW_PYTHON_FLIGHT_LIB_DIR,
+# directory containing Arrow Python Flight libraries
+# ARROW_PYTHON_FLIGHT_SHARED_LIB, path to libarrow_python_flight's shared library
+# ARROW_PYTHON_FLIGHT_STATIC_LIB, path to libarrow_python_flight.a
+
+if(DEFINED ARROW_PYTHON_FLIGHT_FOUND)
+ return()
+endif()
+
+set(find_package_arguments)
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION)
+ list(APPEND find_package_arguments "${${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION}")
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED)
+ list(APPEND find_package_arguments REQUIRED)
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY)
+ list(APPEND find_package_arguments QUIET)
+endif()
+find_package(ArrowFlight ${find_package_arguments})
+find_package(ArrowPython ${find_package_arguments})
+
+if(ARROW_PYTHON_FOUND AND ARROW_FLIGHT_FOUND)
+ arrow_find_package(ARROW_PYTHON_FLIGHT
+ "${ARROW_HOME}"
+ arrow_python_flight
+ arrow/python/flight.h
+ ArrowPythonFlight
+ arrow-python-flight)
+ if(NOT ARROW_PYTHON_FLIGHT_VERSION)
+ set(ARROW_PYTHON_FLIGHT_VERSION "${ARROW_VERSION}")
+ endif()
+endif()
+
+if("${ARROW_PYTHON_FLIGHT_VERSION}" VERSION_EQUAL "${ARROW_VERSION}")
+ set(ARROW_PYTHON_FLIGHT_VERSION_MATCH TRUE)
+else()
+ set(ARROW_PYTHON_FLIGHT_VERSION_MATCH FALSE)
+endif()
+
+mark_as_advanced(ARROW_PYTHON_FLIGHT_IMPORT_LIB
+ ARROW_PYTHON_FLIGHT_INCLUDE_DIR
+ ARROW_PYTHON_FLIGHT_LIBS
+ ARROW_PYTHON_FLIGHT_LIB_DIR
+ ARROW_PYTHON_FLIGHT_SHARED_IMP_LIB
+ ARROW_PYTHON_FLIGHT_SHARED_LIB
+ ARROW_PYTHON_FLIGHT_STATIC_LIB
+ ARROW_PYTHON_FLIGHT_VERSION
+ ARROW_PYTHON_FLIGHT_VERSION_MATCH)
+
+find_package_handle_standard_args(
+ ArrowPythonFlight
+ REQUIRED_VARS ARROW_PYTHON_FLIGHT_INCLUDE_DIR ARROW_PYTHON_FLIGHT_LIB_DIR
+ ARROW_PYTHON_FLIGHT_VERSION_MATCH
+ VERSION_VAR ARROW_PYTHON_FLIGHT_VERSION)
+set(ARROW_PYTHON_FLIGHT_FOUND ${ArrowPythonFlight_FOUND})
+
+if(ArrowPythonFlight_FOUND AND NOT ArrowPythonFlight_FIND_QUIETLY)
+ message(STATUS "Found the Arrow Python Flight by ${ARROW_PYTHON_FLIGHT_FIND_APPROACH}")
+ message(STATUS "Found the Arrow Python Flight shared library: ${ARROW_PYTHON_FLIGHT_SHARED_LIB}"
+ )
+ message(STATUS "Found the Arrow Python Flight import library: ${ARROW_PYTHON_FLIGHT_IMPORT_LIB}"
+ )
+ message(STATUS "Found the Arrow Python Flight static library: ${ARROW_PYTHON_FLIGHT_STATIC_LIB}"
+ )
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindArrowTesting.cmake b/src/arrow/cpp/cmake_modules/FindArrowTesting.cmake
new file mode 100644
index 000000000..c405003ad
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindArrowTesting.cmake
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# - Find Arrow testing library
+# (arrow/testing/util.h, libarrow_testing.a, libarrow_testing.so)
+#
+# This module requires Arrow from which it uses
+# arrow_find_package()
+#
+# This module defines
+# ARROW_TESTING_FOUND, whether Arrow testing library has been found
+# ARROW_TESTING_IMPORT_LIB,
+# path to libarrow_testing's import library (Windows only)
+# ARROW_TESTING_INCLUDE_DIR, directory containing headers
+# ARROW_TESTING_LIB_DIR, directory containing Arrow testing libraries
+# ARROW_TESTING_SHARED_LIB, path to libarrow_testing's shared library
+# ARROW_TESTING_STATIC_LIB, path to libarrow_testing.a
+
+if(DEFINED ARROW_TESTING_FOUND)
+ return()
+endif()
+
+set(find_package_arguments)
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION)
+ list(APPEND find_package_arguments "${${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION}")
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED)
+ list(APPEND find_package_arguments REQUIRED)
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY)
+ list(APPEND find_package_arguments QUIET)
+endif()
+find_package(Arrow ${find_package_arguments})
+
+if(ARROW_FOUND)
+ arrow_find_package(ARROW_TESTING
+ "${ARROW_HOME}"
+ arrow_testing
+ arrow/testing/util.h
+ ArrowTesting
+ arrow-testing)
+ if(NOT ARROW_TESTING_VERSION)
+ set(ARROW_TESTING_VERSION "${ARROW_VERSION}")
+ endif()
+endif()
+
+if("${ARROW_TESTING_VERSION}" VERSION_EQUAL "${ARROW_VERSION}")
+ set(ARROW_TESTING_VERSION_MATCH TRUE)
+else()
+ set(ARROW_TESTING_VERSION_MATCH FALSE)
+endif()
+
+mark_as_advanced(ARROW_TESTING_IMPORT_LIB
+ ARROW_TESTING_INCLUDE_DIR
+ ARROW_TESTING_LIBS
+ ARROW_TESTING_LIB_DIR
+ ARROW_TESTING_SHARED_IMP_LIB
+ ARROW_TESTING_SHARED_LIB
+ ARROW_TESTING_STATIC_LIB
+ ARROW_TESTING_VERSION
+ ARROW_TESTING_VERSION_MATCH)
+
+find_package_handle_standard_args(
+ ArrowTesting
+ REQUIRED_VARS ARROW_TESTING_INCLUDE_DIR ARROW_TESTING_LIB_DIR
+ ARROW_TESTING_VERSION_MATCH
+ VERSION_VAR ARROW_TESTING_VERSION)
+set(ARROW_TESTING_FOUND ${ArrowTesting_FOUND})
+
+if(ArrowTesting_FOUND AND NOT ArrowTesting_FIND_QUIETLY)
+ message(STATUS "Found the Arrow testing by ${ARROW_TESTING_FIND_APPROACH}")
+ message(STATUS "Found the Arrow testing shared library: ${ARROW_TESTING_SHARED_LIB}")
+ message(STATUS "Found the Arrow testing import library: ${ARROW_TESTING_IMPORT_LIB}")
+ message(STATUS "Found the Arrow testing static library: ${ARROW_TESTING_STATIC_LIB}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindBoostAlt.cmake b/src/arrow/cpp/cmake_modules/FindBoostAlt.cmake
new file mode 100644
index 000000000..177193712
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindBoostAlt.cmake
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(DEFINED ENV{BOOST_ROOT} OR DEFINED BOOST_ROOT)
+ # In older versions of CMake (such as 3.2), the system paths for Boost will
+ # be looked in first even if we set $BOOST_ROOT or pass -DBOOST_ROOT
+ set(Boost_NO_SYSTEM_PATHS ON)
+endif()
+
+set(BoostAlt_FIND_VERSION_OPTIONS)
+if(BoostAlt_FIND_VERSION)
+ list(APPEND BoostAlt_FIND_VERSION_OPTIONS ${BoostAlt_FIND_VERSION})
+endif()
+if(BoostAlt_FIND_REQUIRED)
+ list(APPEND BoostAlt_FIND_VERSION_OPTIONS REQUIRED)
+endif()
+if(BoostAlt_FIND_QUIETLY)
+ list(APPEND BoostAlt_FIND_VERSION_OPTIONS QUIET)
+endif()
+
+if(ARROW_BOOST_USE_SHARED)
+ # Find shared Boost libraries.
+ set(Boost_USE_STATIC_LIBS OFF)
+ set(BUILD_SHARED_LIBS_KEEP ${BUILD_SHARED_LIBS})
+ set(BUILD_SHARED_LIBS ON)
+
+ find_package(Boost ${BoostAlt_FIND_VERSION_OPTIONS} COMPONENTS system filesystem)
+ set(BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS_KEEP})
+ unset(BUILD_SHARED_LIBS_KEEP)
+else()
+ # Find static boost headers and libs
+ # TODO Differentiate here between release and debug builds
+ set(Boost_USE_STATIC_LIBS ON)
+ find_package(Boost ${BoostAlt_FIND_VERSION_OPTIONS} COMPONENTS system filesystem)
+endif()
+
+if(Boost_FOUND)
+ set(BoostAlt_FOUND ON)
+ if(MSVC_TOOLCHAIN)
+ # disable autolinking in boost
+ add_definitions(-DBOOST_ALL_NO_LIB)
+ if(ARROW_BOOST_USE_SHARED)
+ # force all boost libraries to dynamic link
+ add_definitions(-DBOOST_ALL_DYN_LINK)
+ endif()
+ endif()
+else()
+ set(BoostAlt_FOUND OFF)
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindBrotli.cmake b/src/arrow/cpp/cmake_modules/FindBrotli.cmake
new file mode 100644
index 000000000..e2670b51a
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindBrotli.cmake
@@ -0,0 +1,130 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Tries to find Brotli headers and libraries.
+#
+# Usage of this module as follows:
+#
+# find_package(Brotli)
+
+if(ARROW_BROTLI_USE_SHARED)
+ set(BROTLI_COMMON_LIB_NAMES
+ brotlicommon
+ ${CMAKE_SHARED_LIBRARY_PREFIX}brotlicommon${CMAKE_SHARED_LIBRARY_SUFFIX})
+
+ set(BROTLI_ENC_LIB_NAMES
+ brotlienc ${CMAKE_SHARED_LIBRARY_PREFIX}brotlienc${CMAKE_SHARED_LIBRARY_SUFFIX})
+
+ set(BROTLI_DEC_LIB_NAMES
+ brotlidec ${CMAKE_SHARED_LIBRARY_PREFIX}brotlidec${CMAKE_SHARED_LIBRARY_SUFFIX})
+else()
+ set(BROTLI_COMMON_LIB_NAMES
+ brotlicommon-static
+ ${CMAKE_STATIC_LIBRARY_PREFIX}brotlicommon-static${CMAKE_STATIC_LIBRARY_SUFFIX}
+ ${CMAKE_STATIC_LIBRARY_PREFIX}brotlicommon_static${CMAKE_STATIC_LIBRARY_SUFFIX}
+ ${CMAKE_STATIC_LIBRARY_PREFIX}brotlicommon${CMAKE_STATIC_LIBRARY_SUFFIX})
+
+ set(BROTLI_ENC_LIB_NAMES
+ brotlienc-static
+ ${CMAKE_STATIC_LIBRARY_PREFIX}brotlienc-static${CMAKE_STATIC_LIBRARY_SUFFIX}
+ ${CMAKE_STATIC_LIBRARY_PREFIX}brotlienc_static${CMAKE_STATIC_LIBRARY_SUFFIX}
+ ${CMAKE_STATIC_LIBRARY_PREFIX}brotlienc${CMAKE_STATIC_LIBRARY_SUFFIX})
+
+ set(BROTLI_DEC_LIB_NAMES
+ brotlidec-static
+ ${CMAKE_STATIC_LIBRARY_PREFIX}brotlidec-static${CMAKE_STATIC_LIBRARY_SUFFIX}
+ ${CMAKE_STATIC_LIBRARY_PREFIX}brotlidec_static${CMAKE_STATIC_LIBRARY_SUFFIX}
+ ${CMAKE_STATIC_LIBRARY_PREFIX}brotlidec${CMAKE_STATIC_LIBRARY_SUFFIX})
+endif()
+
+if(BROTLI_ROOT)
+ find_library(BROTLI_COMMON_LIBRARY
+ NAMES ${BROTLI_COMMON_LIB_NAMES}
+ PATHS ${BROTLI_ROOT}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ find_library(BROTLI_ENC_LIBRARY
+ NAMES ${BROTLI_ENC_LIB_NAMES}
+ PATHS ${BROTLI_ROOT}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ find_library(BROTLI_DEC_LIBRARY
+ NAMES ${BROTLI_DEC_LIB_NAMES}
+ PATHS ${BROTLI_ROOT}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ find_path(BROTLI_INCLUDE_DIR
+ NAMES brotli/decode.h
+ PATHS ${BROTLI_ROOT}
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+else()
+ find_package(PkgConfig QUIET)
+ pkg_check_modules(BROTLI_PC libbrotlicommon libbrotlienc libbrotlidec)
+ if(BROTLI_PC_FOUND)
+ set(BROTLI_INCLUDE_DIR "${BROTLI_PC_libbrotlicommon_INCLUDEDIR}")
+
+ # Some systems (e.g. Fedora) don't fill Brotli_LIBRARY_DIRS, so add the other dirs here.
+ list(APPEND BROTLI_PC_LIBRARY_DIRS "${BROTLI_PC_libbrotlicommon_LIBDIR}")
+ list(APPEND BROTLI_PC_LIBRARY_DIRS "${BROTLI_PC_libbrotlienc_LIBDIR}")
+ list(APPEND BROTLI_PC_LIBRARY_DIRS "${BROTLI_PC_libbrotlidec_LIBDIR}")
+
+ find_library(BROTLI_COMMON_LIBRARY
+ NAMES ${BROTLI_COMMON_LIB_NAMES}
+ PATHS ${BROTLI_PC_LIBRARY_DIRS}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ find_library(BROTLI_ENC_LIBRARY
+ NAMES ${BROTLI_ENC_LIB_NAMES}
+ PATHS ${BROTLI_PC_LIBRARY_DIRS}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ find_library(BROTLI_DEC_LIBRARY
+ NAMES ${BROTLI_DEC_LIB_NAMES}
+ PATHS ${BROTLI_PC_LIBRARY_DIRS}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ else()
+ find_library(BROTLI_COMMON_LIBRARY
+ NAMES ${BROTLI_COMMON_LIB_NAMES}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES})
+ find_library(BROTLI_ENC_LIBRARY
+ NAMES ${BROTLI_ENC_LIB_NAMES}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES})
+ find_library(BROTLI_DEC_LIBRARY
+ NAMES ${BROTLI_DEC_LIB_NAMES}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES})
+ find_path(BROTLI_INCLUDE_DIR
+ NAMES brotli/decode.h
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+ endif()
+endif()
+
+find_package_handle_standard_args(
+ Brotli REQUIRED_VARS BROTLI_COMMON_LIBRARY BROTLI_ENC_LIBRARY BROTLI_DEC_LIBRARY
+ BROTLI_INCLUDE_DIR)
+if(Brotli_FOUND OR BROTLI_FOUND)
+ set(Brotli_FOUND TRUE)
+ add_library(Brotli::brotlicommon UNKNOWN IMPORTED)
+ set_target_properties(Brotli::brotlicommon
+ PROPERTIES IMPORTED_LOCATION "${BROTLI_COMMON_LIBRARY}"
+ INTERFACE_INCLUDE_DIRECTORIES "${BROTLI_INCLUDE_DIR}")
+ add_library(Brotli::brotlienc UNKNOWN IMPORTED)
+ set_target_properties(Brotli::brotlienc
+ PROPERTIES IMPORTED_LOCATION "${BROTLI_ENC_LIBRARY}"
+ INTERFACE_INCLUDE_DIRECTORIES "${BROTLI_INCLUDE_DIR}")
+ add_library(Brotli::brotlidec UNKNOWN IMPORTED)
+ set_target_properties(Brotli::brotlidec
+ PROPERTIES IMPORTED_LOCATION "${BROTLI_DEC_LIBRARY}"
+ INTERFACE_INCLUDE_DIRECTORIES "${BROTLI_INCLUDE_DIR}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindClangTools.cmake b/src/arrow/cpp/cmake_modules/FindClangTools.cmake
new file mode 100644
index 000000000..52fc59895
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindClangTools.cmake
@@ -0,0 +1,106 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Tries to find the clang-tidy and clang-format modules
+#
+# Usage of this module as follows:
+#
+# find_package(ClangTools)
+#
+# Variables used by this module which can change the default behaviour and need
+# to be set before calling find_package:
+#
+# CLANG_FORMAT_VERSION -
+# The version of clang-format to find. If this is not specified, clang-format
+# will not be searched for.
+#
+# ClangTools_PATH -
+# When set, this path is inspected in addition to standard library binary locations
+# to find clang-tidy and clang-format
+#
+# This module defines
+# CLANG_TIDY_BIN, The path to the clang tidy binary
+# CLANG_TIDY_FOUND, Whether clang tidy was found
+# CLANG_FORMAT_BIN, The path to the clang format binary
+# CLANG_FORMAT_FOUND, Whether clang format was found
+
+set(CLANG_TOOLS_SEARCH_PATHS
+ ${ClangTools_PATH}
+ $ENV{CLANG_TOOLS_PATH}
+ /usr/local/bin
+ /usr/bin
+ "C:/Program Files/LLVM/bin" # Windows, non-conda
+ "$ENV{CONDA_PREFIX}/Library/bin") # Windows, conda
+if(CLANG_TOOLS_BREW_PREFIX)
+ list(APPEND CLANG_TOOLS_SEARCH_PATHS "${CLANG_TOOLS_BREW}/bin")
+endif()
+
+function(FIND_CLANG_TOOL NAME OUTPUT VERSION_CHECK_PATTERN)
+ unset(CLANG_TOOL_BIN CACHE)
+ find_program(CLANG_TOOL_BIN
+ NAMES ${NAME}-${ARROW_CLANG_TOOLS_VERSION}
+ ${NAME}-${ARROW_CLANG_TOOLS_VERSION_MAJOR}
+ PATHS ${CLANG_TOOLS_SEARCH_PATHS}
+ NO_DEFAULT_PATH)
+ if(NOT CLANG_TOOL_BIN)
+ # try searching for non-versioned tool and check the version
+ find_program(CLANG_TOOL_BIN
+ NAMES ${NAME}
+ PATHS ${CLANG_TOOLS_SEARCH_PATHS}
+ NO_DEFAULT_PATH)
+ if(CLANG_TOOL_BIN)
+ unset(CLANG_TOOL_VERSION_MESSAGE)
+ execute_process(COMMAND ${CLANG_TOOL_BIN} "-version"
+ OUTPUT_VARIABLE CLANG_TOOL_VERSION_MESSAGE
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+ if(NOT (${CLANG_TOOL_VERSION_MESSAGE} MATCHES ${VERSION_CHECK_PATTERN}))
+ set(CLANG_TOOL_BIN "CLANG_TOOL_BIN-NOTFOUND")
+ endif()
+ endif()
+ endif()
+ if(CLANG_TOOL_BIN)
+ set(${OUTPUT}
+ ${CLANG_TOOL_BIN}
+ PARENT_SCOPE)
+ else()
+ set(${OUTPUT}
+ "${OUTPUT}-NOTFOUND"
+ PARENT_SCOPE)
+ endif()
+endfunction()
+
+string(REGEX REPLACE "\\." "\\\\." ARROW_CLANG_TOOLS_VERSION_ESCAPED
+ "${ARROW_CLANG_TOOLS_VERSION}")
+
+find_clang_tool(clang-tidy CLANG_TIDY_BIN
+ "LLVM version ${ARROW_CLANG_TOOLS_VERSION_ESCAPED}")
+if(CLANG_TIDY_BIN)
+ set(CLANG_TIDY_FOUND 1)
+ message(STATUS "clang-tidy found at ${CLANG_TIDY_BIN}")
+else()
+ set(CLANG_TIDY_FOUND 0)
+ message(STATUS "clang-tidy not found")
+endif()
+
+find_clang_tool(clang-format CLANG_FORMAT_BIN
+ "^clang-format version ${ARROW_CLANG_TOOLS_VERSION_ESCAPED}")
+if(CLANG_FORMAT_BIN)
+ set(CLANG_FORMAT_FOUND 1)
+ message(STATUS "clang-format found at ${CLANG_FORMAT_BIN}")
+else()
+ set(CLANG_FORMAT_FOUND 0)
+ message(STATUS "clang-format not found")
+endif()
+
+find_package_handle_standard_args(ClangTools REQUIRED_VARS CLANG_FORMAT_BIN
+ CLANG_TIDY_BIN)
diff --git a/src/arrow/cpp/cmake_modules/FindGLOG.cmake b/src/arrow/cpp/cmake_modules/FindGLOG.cmake
new file mode 100644
index 000000000..d67eb0056
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindGLOG.cmake
@@ -0,0 +1,56 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Tries to find GLog headers and libraries.
+#
+# Usage of this module as follows:
+#
+# find_package(GLOG)
+
+find_package(PkgConfig QUIET)
+pkg_check_modules(GLOG_PC libglog)
+if(GLOG_PC_FOUND)
+ set(GLOG_INCLUDE_DIR "${GLOG_PC_INCLUDEDIR}")
+ list(APPEND GLOG_PC_LIBRARY_DIRS "${GLOG_PC_LIBDIR}")
+ find_library(GLOG_LIB glog
+ PATHS ${GLOG_PC_LIBRARY_DIRS}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+elseif(GLOG_ROOT)
+ find_library(GLOG_LIB
+ NAMES glog
+ PATHS ${GLOG_ROOT}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ find_path(GLOG_INCLUDE_DIR
+ NAMES glog/logging.h
+ PATHS ${GLOG_ROOT}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+else()
+ find_library(GLOG_LIB
+ NAMES glog
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES})
+ find_path(GLOG_INCLUDE_DIR
+ NAMES glog/logging.h
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+endif()
+
+find_package_handle_standard_args(GLOG REQUIRED_VARS GLOG_INCLUDE_DIR GLOG_LIB)
+
+if(GLOG_FOUND)
+ add_library(glog::glog UNKNOWN IMPORTED)
+ set_target_properties(glog::glog
+ PROPERTIES IMPORTED_LOCATION "${GLOG_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${GLOG_INCLUDE_DIR}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindGandiva.cmake b/src/arrow/cpp/cmake_modules/FindGandiva.cmake
new file mode 100644
index 000000000..c533abed7
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindGandiva.cmake
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# - Find Gandiva (gandiva/arrow.h, libgandiva.a, libgandiva.so)
+#
+# This module requires Arrow from which it uses
+# arrow_find_package()
+#
+# This module defines
+# GANDIVA_FOUND, whether Gandiva has been found
+# GANDIVA_IMPORT_LIB, path to libgandiva's import library (Windows only)
+# GANDIVA_INCLUDE_DIR, directory containing headers
+# GANDIVA_LIBS, deprecated. Use GANDIVA_LIB_DIR instead
+# GANDIVA_LIB_DIR, directory containing Gandiva libraries
+# GANDIVA_SHARED_IMP_LIB, deprecated. Use GANDIVA_IMPORT_LIB instead
+# GANDIVA_SHARED_LIB, path to libgandiva's shared library
+# GANDIVA_SO_VERSION, shared object version of found Gandiva such as "100"
+# GANDIVA_STATIC_LIB, path to libgandiva.a
+
+if(DEFINED GANDIVA_FOUND)
+ return()
+endif()
+
+set(find_package_arguments)
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION)
+ list(APPEND find_package_arguments "${${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION}")
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED)
+ list(APPEND find_package_arguments REQUIRED)
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY)
+ list(APPEND find_package_arguments QUIET)
+endif()
+find_package(Arrow ${find_package_arguments})
+
+if(ARROW_FOUND)
+ arrow_find_package(GANDIVA
+ "${ARROW_HOME}"
+ gandiva
+ gandiva/arrow.h
+ Gandiva
+ gandiva)
+ if(NOT GANDIVA_VERSION)
+ set(GANDIVA_VERSION "${ARROW_VERSION}")
+ endif()
+ set(GANDIVA_ABI_VERSION "${ARROW_ABI_VERSION}")
+ set(GANDIVA_SO_VERSION "${ARROW_SO_VERSION}")
+endif()
+
+if("${GANDIVA_VERSION}" VERSION_EQUAL "${ARROW_VERSION}")
+ set(GANDIVA_VERSION_MATCH TRUE)
+else()
+ set(GANDIVA_VERSION_MATCH FALSE)
+endif()
+
+mark_as_advanced(GANDIVA_ABI_VERSION
+ GANDIVA_IMPORT_LIB
+ GANDIVA_INCLUDE_DIR
+ GANDIVA_LIBS
+ GANDIVA_LIB_DIR
+ GANDIVA_SHARED_IMP_LIB
+ GANDIVA_SHARED_LIB
+ GANDIVA_SO_VERSION
+ GANDIVA_STATIC_LIB
+ GANDIVA_VERSION
+ GANDIVA_VERSION_MATCH)
+
+find_package_handle_standard_args(
+ Gandiva
+ REQUIRED_VARS GANDIVA_INCLUDE_DIR GANDIVA_LIB_DIR GANDIVA_SO_VERSION
+ GANDIVA_VERSION_MATCH
+ VERSION_VAR GANDIVA_VERSION)
+set(GANDIVA_FOUND ${Gandiva_FOUND})
+
+if(Gandiva_FOUND AND NOT Gandiva_FIND_QUIETLY)
+ message(STATUS "Found the Gandiva by ${GANDIVA_FIND_APPROACH}")
+ message(STATUS "Found the Gandiva shared library: ${GANDIVA_SHARED_LIB}")
+ message(STATUS "Found the Gandiva import library: ${GANDIVA_IMPORT_LIB}")
+ message(STATUS "Found the Gandiva static library: ${GANDIVA_STATIC_LIB}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindInferTools.cmake b/src/arrow/cpp/cmake_modules/FindInferTools.cmake
new file mode 100644
index 000000000..c4b65653a
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindInferTools.cmake
@@ -0,0 +1,47 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Tries to find the infer module
+#
+# Usage of this module as follows:
+#
+# find_package(InferTools)
+#
+# Variables used by this module, they can change the default behaviour and need
+# to be set before calling find_package:
+#
+# InferTools_PATH -
+# When set, this path is inspected instead of standard library binary locations
+# to find infer
+#
+# This module defines
+# INFER_BIN, The path to the infer binary
+# INFER_FOUND, Whether infer was found
+
+find_program(INFER_BIN
+ NAMES infer
+ PATHS ${InferTools_PATH}
+ $ENV{INFER_TOOLS_PATH}
+ /usr/local/bin
+ /usr/bin
+ /usr/local/homebrew/bin
+ /opt/local/bin
+ NO_DEFAULT_PATH)
+
+if("${INFER_BIN}" STREQUAL "INFER_BIN-NOTFOUND")
+ set(INFER_FOUND 0)
+ message(STATUS "infer not found")
+else()
+ set(INFER_FOUND 1)
+ message(STATUS "infer found at ${INFER_BIN}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindLLVMAlt.cmake b/src/arrow/cpp/cmake_modules/FindLLVMAlt.cmake
new file mode 100644
index 000000000..380f2d47c
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindLLVMAlt.cmake
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Usage of this module as follows:
+#
+# find_package(LLVMAlt)
+
+set(LLVM_HINTS ${LLVM_ROOT} ${LLVM_DIR} /usr/lib /usr/share)
+if(LLVM_BREW_PREFIX)
+ list(APPEND LLVM_HINTS ${LLVM_BREW_PREFIX})
+endif()
+foreach(ARROW_LLVM_VERSION ${ARROW_LLVM_VERSIONS})
+ find_package(LLVM
+ ${ARROW_LLVM_VERSION}
+ CONFIG
+ HINTS
+ ${LLVM_HINTS})
+ if(LLVM_FOUND)
+ break()
+ endif()
+endforeach()
+
+if(LLVM_FOUND)
+ # Find the libraries that correspond to the LLVM components
+ llvm_map_components_to_libnames(LLVM_LIBS
+ core
+ mcjit
+ native
+ ipo
+ bitreader
+ target
+ linker
+ analysis
+ debuginfodwarf)
+
+ find_program(LLVM_LINK_EXECUTABLE llvm-link HINTS ${LLVM_TOOLS_BINARY_DIR})
+
+ find_program(CLANG_EXECUTABLE
+ NAMES clang-${LLVM_PACKAGE_VERSION}
+ clang-${LLVM_VERSION_MAJOR}.${LLVM_VERSION_MINOR}
+ clang-${LLVM_VERSION_MAJOR} clang
+ HINTS ${LLVM_TOOLS_BINARY_DIR})
+
+ add_library(LLVM::LLVM_INTERFACE INTERFACE IMPORTED)
+
+ set_target_properties(LLVM::LLVM_INTERFACE
+ PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${LLVM_INCLUDE_DIRS}"
+ INTERFACE_COMPILE_FLAGS "${LLVM_DEFINITIONS}"
+ INTERFACE_LINK_LIBRARIES "${LLVM_LIBS}")
+endif()
+
+mark_as_advanced(CLANG_EXECUTABLE LLVM_LINK_EXECUTABLE)
+
+find_package_handle_standard_args(
+ LLVMAlt
+ REQUIRED_VARS # The first variable is used for display.
+ LLVM_PACKAGE_VERSION CLANG_EXECUTABLE LLVM_FOUND LLVM_LINK_EXECUTABLE)
+if(LLVMAlt_FOUND)
+ message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
+ message(STATUS "Found llvm-link ${LLVM_LINK_EXECUTABLE}")
+ message(STATUS "Found clang ${CLANG_EXECUTABLE}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindLz4.cmake b/src/arrow/cpp/cmake_modules/FindLz4.cmake
new file mode 100644
index 000000000..bc8051fe9
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindLz4.cmake
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(MSVC_TOOLCHAIN AND NOT DEFINED LZ4_MSVC_LIB_PREFIX)
+ set(LZ4_MSVC_LIB_PREFIX "lib")
+endif()
+set(LZ4_LIB_NAME_BASE "${LZ4_MSVC_LIB_PREFIX}lz4")
+
+if(ARROW_LZ4_USE_SHARED)
+ set(LZ4_LIB_NAMES)
+ if(CMAKE_IMPORT_LIBRARY_SUFFIX)
+ list(APPEND
+ LZ4_LIB_NAMES
+ "${CMAKE_IMPORT_LIBRARY_PREFIX}${LZ4_LIB_NAME_BASE}${CMAKE_IMPORT_LIBRARY_SUFFIX}"
+ )
+ endif()
+ list(APPEND LZ4_LIB_NAMES
+ "${CMAKE_SHARED_LIBRARY_PREFIX}${LZ4_LIB_NAME_BASE}${CMAKE_SHARED_LIBRARY_SUFFIX}")
+else()
+ if(MSVC AND NOT DEFINED LZ4_MSVC_STATIC_LIB_SUFFIX)
+ set(LZ4_MSVC_STATIC_LIB_SUFFIX "_static")
+ endif()
+ set(LZ4_STATIC_LIB_SUFFIX "${LZ4_MSVC_STATIC_LIB_SUFFIX}${CMAKE_STATIC_LIBRARY_SUFFIX}")
+ set(LZ4_LIB_NAMES
+ "${CMAKE_STATIC_LIBRARY_PREFIX}${LZ4_LIB_NAME_BASE}${LZ4_STATIC_LIB_SUFFIX}")
+endif()
+
+if(LZ4_ROOT)
+ find_library(LZ4_LIB
+ NAMES ${LZ4_LIB_NAMES}
+ PATHS ${LZ4_ROOT}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ find_path(LZ4_INCLUDE_DIR
+ NAMES lz4.h
+ PATHS ${LZ4_ROOT}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+
+else()
+ find_package(PkgConfig QUIET)
+ pkg_check_modules(LZ4_PC liblz4)
+ if(LZ4_PC_FOUND)
+ set(LZ4_INCLUDE_DIR "${LZ4_PC_INCLUDEDIR}")
+
+ list(APPEND LZ4_PC_LIBRARY_DIRS "${LZ4_PC_LIBDIR}")
+ find_library(LZ4_LIB
+ NAMES ${LZ4_LIB_NAMES}
+ PATHS ${LZ4_PC_LIBRARY_DIRS}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES})
+ else()
+ find_library(LZ4_LIB
+ NAMES ${LZ4_LIB_NAMES}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES})
+ find_path(LZ4_INCLUDE_DIR
+ NAMES lz4.h
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+ endif()
+endif()
+
+find_package_handle_standard_args(Lz4 REQUIRED_VARS LZ4_LIB LZ4_INCLUDE_DIR)
+
+if(Lz4_FOUND)
+ set(Lz4_FOUND TRUE)
+ add_library(LZ4::lz4 UNKNOWN IMPORTED)
+ set_target_properties(LZ4::lz4
+ PROPERTIES IMPORTED_LOCATION "${LZ4_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${LZ4_INCLUDE_DIR}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindNumPy.cmake b/src/arrow/cpp/cmake_modules/FindNumPy.cmake
new file mode 100644
index 000000000..c3daba149
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindNumPy.cmake
@@ -0,0 +1,96 @@
+# - Find the NumPy libraries
+# This module finds if NumPy is installed, and sets the following variables
+# indicating where it is.
+#
+# TODO: Update to provide the libraries and paths for linking npymath lib.
+#
+# NUMPY_FOUND - was NumPy found
+# NUMPY_VERSION - the version of NumPy found as a string
+# NUMPY_VERSION_MAJOR - the major version number of NumPy
+# NUMPY_VERSION_MINOR - the minor version number of NumPy
+# NUMPY_VERSION_PATCH - the patch version number of NumPy
+# NUMPY_VERSION_DECIMAL - e.g. version 1.6.1 is 10601
+# NUMPY_INCLUDE_DIRS - path to the NumPy include files
+
+#============================================================================
+# Copyright 2012 Continuum Analytics, Inc.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files
+# (the "Software"), to deal in the Software without restriction, including
+# without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and to permit
+# persons to whom the Software is furnished to do so, subject to
+# the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
+# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
+# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+# OTHER DEALINGS IN THE SOFTWARE.
+#
+#============================================================================
+
+# Legacy code for CMake < 3.15.0. The primary point of entry should be
+# FindPython3Alt.cmake.
+
+if(NOT PYTHONINTERP_FOUND)
+ set(NUMPY_FOUND FALSE)
+ return()
+endif()
+
+execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c"
+ "import numpy as n; print(n.__version__); print(n.get_include());"
+ RESULT_VARIABLE _NUMPY_SEARCH_SUCCESS
+ OUTPUT_VARIABLE _NUMPY_VALUES_OUTPUT
+ ERROR_VARIABLE _NUMPY_ERROR_VALUE
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+
+if(NOT _NUMPY_SEARCH_SUCCESS MATCHES 0)
+ if(NumPy_FIND_REQUIRED)
+ message(FATAL_ERROR
+ "NumPy import failure:\n${_NUMPY_ERROR_VALUE}")
+ endif()
+ set(NUMPY_FOUND FALSE)
+ return()
+endif()
+
+# Convert the process output into a list
+string(REGEX REPLACE ";" "\\\\;" _NUMPY_VALUES ${_NUMPY_VALUES_OUTPUT})
+string(REGEX REPLACE "\n" ";" _NUMPY_VALUES ${_NUMPY_VALUES})
+list(GET _NUMPY_VALUES 0 NUMPY_VERSION)
+list(GET _NUMPY_VALUES 1 NUMPY_INCLUDE_DIRS)
+
+string(REGEX MATCH "^[0-9]+\\.[0-9]+\\.[0-9]+" _VER_CHECK "${NUMPY_VERSION}")
+if("${_VER_CHECK}" STREQUAL "")
+ # The output from Python was unexpected. Raise an error always
+ # here, because we found NumPy, but it appears to be corrupted somehow.
+ message(FATAL_ERROR
+ "Requested version and include path from NumPy, got instead:\n${_NUMPY_VALUES_OUTPUT}\n")
+ return()
+endif()
+
+# Make sure all directory separators are '/'
+string(REGEX REPLACE "\\\\" "/" NUMPY_INCLUDE_DIRS ${NUMPY_INCLUDE_DIRS})
+
+# Get the major and minor version numbers
+string(REGEX REPLACE "\\." ";" _NUMPY_VERSION_LIST ${NUMPY_VERSION})
+list(GET _NUMPY_VERSION_LIST 0 NUMPY_VERSION_MAJOR)
+list(GET _NUMPY_VERSION_LIST 1 NUMPY_VERSION_MINOR)
+list(GET _NUMPY_VERSION_LIST 2 NUMPY_VERSION_PATCH)
+string(REGEX MATCH "[0-9]*" NUMPY_VERSION_PATCH ${NUMPY_VERSION_PATCH})
+math(EXPR NUMPY_VERSION_DECIMAL
+ "(${NUMPY_VERSION_MAJOR} * 10000) + (${NUMPY_VERSION_MINOR} * 100) + ${NUMPY_VERSION_PATCH}")
+
+find_package_message(NUMPY
+ "Found NumPy: version \"${NUMPY_VERSION}\" ${NUMPY_INCLUDE_DIRS}"
+ "${NUMPY_INCLUDE_DIRS}${NUMPY_VERSION}")
+
+set(NUMPY_FOUND TRUE)
diff --git a/src/arrow/cpp/cmake_modules/FindORC.cmake b/src/arrow/cpp/cmake_modules/FindORC.cmake
new file mode 100644
index 000000000..d45b16078
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindORC.cmake
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# - Find Apache ORC C++ (orc/orc-config.h, liborc.a)
+# This module defines
+# ORC_INCLUDE_DIR, directory containing headers
+# ORC_STATIC_LIB, path to liborc.a
+# ORC_FOUND, whether orc has been found
+
+if(ORC_ROOT)
+ find_library(ORC_STATIC_LIB
+ NAMES orc
+ PATHS ${ORC_ROOT}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES})
+ find_path(ORC_INCLUDE_DIR
+ NAMES orc/orc-config.hh
+ PATHS ${ORC_ROOT}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+else()
+ find_library(ORC_STATIC_LIB
+ NAMES orc
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES})
+ find_path(ORC_INCLUDE_DIR
+ NAMES orc/orc-config.hh
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+endif()
+
+if(ORC_STATIC_LIB AND ORC_INCLUDE_DIR)
+ set(ORC_FOUND TRUE)
+ add_library(orc::liborc STATIC IMPORTED)
+ set_target_properties(orc::liborc
+ PROPERTIES IMPORTED_LOCATION "${ORC_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${ORC_INCLUDE_DIR}")
+else()
+ if(ORC_FIND_REQUIRED)
+ message(FATAL_ERROR "ORC library was required in toolchain and unable to locate")
+ endif()
+ set(ORC_FOUND FALSE)
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindOpenSSLAlt.cmake b/src/arrow/cpp/cmake_modules/FindOpenSSLAlt.cmake
new file mode 100644
index 000000000..603e7d066
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindOpenSSLAlt.cmake
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(ARROW_OPENSSL_USE_SHARED)
+ # Find shared OpenSSL libraries.
+ set(OpenSSL_USE_STATIC_LIBS OFF)
+ set(OPENSSL_USE_STATIC_LIBS OFF)
+ find_package(OpenSSL)
+else()
+ # Find static OpenSSL headers and libs
+ set(OpenSSL_USE_STATIC_LIBS ON)
+ set(OPENSSL_USE_STATIC_LIBS ON)
+ find_package(OpenSSL)
+endif()
+
+if(OPENSSL_FOUND)
+ message(STATUS "OpenSSL found with ${OPENSSL_VERSION} version")
+ if(OPENSSL_VERSION LESS "1.1.0")
+ message(SEND_ERROR "The OpenSSL must be greater than or equal to 1.1.0")
+ endif()
+else()
+ message(SEND_ERROR "Not found the OpenSSL library")
+endif()
+
+if(NOT GANDIVA_OPENSSL_LIBS)
+ if(WIN32)
+ if(CMAKE_VERSION VERSION_LESS 3.18)
+ set(GANDIVA_OPENSSL_LIBS OpenSSL::Crypto OpenSSL::SSL)
+ else()
+ set(GANDIVA_OPENSSL_LIBS OpenSSL::Crypto OpenSSL::SSL OpenSSL::applink)
+ endif()
+ else()
+ set(GANDIVA_OPENSSL_LIBS OpenSSL::Crypto OpenSSL::SSL)
+ endif()
+endif()
+
+if(NOT GANDIVA_OPENSSL_INCLUDE_DIR)
+ set(GANDIVA_OPENSSL_INCLUDE_DIR ${OPENSSL_INCLUDE_DIR})
+ message(STATUS "OpenSSL include dir: ${GANDIVA_OPENSSL_INCLUDE_DIR}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindParquet.cmake b/src/arrow/cpp/cmake_modules/FindParquet.cmake
new file mode 100644
index 000000000..e071fc822
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindParquet.cmake
@@ -0,0 +1,126 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# - Find Parquet (parquet/api/reader.h, libparquet.a, libparquet.so)
+#
+# This module requires Arrow from which it uses
+# arrow_find_package()
+#
+# This module defines
+# PARQUET_FOUND, whether Parquet has been found
+# PARQUET_IMPORT_LIB, path to libparquet's import library (Windows only)
+# PARQUET_INCLUDE_DIR, directory containing headers
+# PARQUET_LIBS, deprecated. Use PARQUET_LIB_DIR instead
+# PARQUET_LIB_DIR, directory containing Parquet libraries
+# PARQUET_SHARED_IMP_LIB, deprecated. Use PARQUET_IMPORT_LIB instead
+# PARQUET_SHARED_LIB, path to libparquet's shared library
+# PARQUET_SO_VERSION, shared object version of found Parquet such as "100"
+# PARQUET_STATIC_LIB, path to libparquet.a
+
+if(DEFINED PARQUET_FOUND)
+ return()
+endif()
+
+set(find_package_arguments)
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION)
+ list(APPEND find_package_arguments "${${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION}")
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED)
+ list(APPEND find_package_arguments REQUIRED)
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY)
+ list(APPEND find_package_arguments QUIET)
+endif()
+find_package(Arrow ${find_package_arguments})
+
+if(NOT "$ENV{PARQUET_HOME}" STREQUAL "")
+ file(TO_CMAKE_PATH "$ENV{PARQUET_HOME}" PARQUET_HOME)
+endif()
+
+if((NOT PARQUET_HOME) AND ARROW_HOME)
+ set(PARQUET_HOME ${ARROW_HOME})
+endif()
+
+if(ARROW_FOUND)
+ arrow_find_package(PARQUET
+ "${PARQUET_HOME}"
+ parquet
+ parquet/api/reader.h
+ Parquet
+ parquet)
+ if(PARQUET_HOME)
+ if(PARQUET_INCLUDE_DIR)
+ file(READ "${PARQUET_INCLUDE_DIR}/parquet/parquet_version.h"
+ PARQUET_VERSION_H_CONTENT)
+ arrow_extract_macro_value(PARQUET_VERSION_MAJOR "PARQUET_VERSION_MAJOR"
+ "${PARQUET_VERSION_H_CONTENT}")
+ arrow_extract_macro_value(PARQUET_VERSION_MINOR "PARQUET_VERSION_MINOR"
+ "${PARQUET_VERSION_H_CONTENT}")
+ arrow_extract_macro_value(PARQUET_VERSION_PATCH "PARQUET_VERSION_PATCH"
+ "${PARQUET_VERSION_H_CONTENT}")
+ if("${PARQUET_VERSION_MAJOR}" STREQUAL ""
+ OR "${PARQUET_VERSION_MINOR}" STREQUAL ""
+ OR "${PARQUET_VERSION_PATCH}" STREQUAL "")
+ set(PARQUET_VERSION "0.0.0")
+ else()
+ set(PARQUET_VERSION
+ "${PARQUET_VERSION_MAJOR}.${PARQUET_VERSION_MINOR}.${PARQUET_VERSION_PATCH}")
+ endif()
+
+ arrow_extract_macro_value(PARQUET_SO_VERSION_QUOTED "PARQUET_SO_VERSION"
+ "${PARQUET_VERSION_H_CONTENT}")
+ string(REGEX REPLACE "^\"(.+)\"$" "\\1" PARQUET_SO_VERSION
+ "${PARQUET_SO_VERSION_QUOTED}")
+ arrow_extract_macro_value(PARQUET_FULL_SO_VERSION_QUOTED "PARQUET_FULL_SO_VERSION"
+ "${PARQUET_VERSION_H_CONTENT}")
+ string(REGEX REPLACE "^\"(.+)\"$" "\\1" PARQUET_FULL_SO_VERSION
+ "${PARQUET_FULL_SO_VERSION_QUOTED}")
+ endif()
+ else()
+ if(PARQUET_USE_CMAKE_PACKAGE_CONFIG)
+ find_package(Parquet CONFIG)
+ elseif(PARQUET_USE_PKG_CONFIG)
+ pkg_get_variable(PARQUET_SO_VERSION parquet so_version)
+ pkg_get_variable(PARQUET_FULL_SO_VERSION parquet full_so_version)
+ endif()
+ endif()
+ set(PARQUET_ABI_VERSION "${PARQUET_SO_VERSION}")
+endif()
+
+mark_as_advanced(PARQUET_ABI_VERSION
+ PARQUET_IMPORT_LIB
+ PARQUET_INCLUDE_DIR
+ PARQUET_LIBS
+ PARQUET_LIB_DIR
+ PARQUET_SHARED_IMP_LIB
+ PARQUET_SHARED_LIB
+ PARQUET_SO_VERSION
+ PARQUET_STATIC_LIB
+ PARQUET_VERSION)
+
+find_package_handle_standard_args(
+ Parquet
+ REQUIRED_VARS PARQUET_INCLUDE_DIR PARQUET_LIB_DIR PARQUET_SO_VERSION
+ VERSION_VAR PARQUET_VERSION)
+set(PARQUET_FOUND ${Parquet_FOUND})
+
+if(Parquet_FOUND AND NOT Parquet_FIND_QUIETLY)
+ message(STATUS "Parquet version: ${PARQUET_VERSION} (${PARQUET_FIND_APPROACH})")
+ message(STATUS "Found the Parquet shared library: ${PARQUET_SHARED_LIB}")
+ message(STATUS "Found the Parquet import library: ${PARQUET_IMPORT_LIB}")
+ message(STATUS "Found the Parquet static library: ${PARQUET_STATIC_LIB}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindPlasma.cmake b/src/arrow/cpp/cmake_modules/FindPlasma.cmake
new file mode 100644
index 000000000..2e634844c
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindPlasma.cmake
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# - Find Plasma (plasma/client.h, libplasma.a, libplasma.so)
+#
+# This module requires Arrow from which it uses
+# arrow_find_package()
+#
+# This module defines
+# PLASMA_EXECUTABLE, deprecated. Use PLASMA_STORE_SERVER instead
+# PLASMA_FOUND, whether Plasma has been found
+# PLASMA_IMPORT_LIB, path to libplasma's import library (Windows only)
+# PLASMA_INCLUDE_DIR, directory containing headers
+# PLASMA_LIBS, deprecated. Use PLASMA_LIB_DIR instead
+# PLASMA_LIB_DIR, directory containing Plasma libraries
+# PLASMA_SHARED_IMP_LIB, deprecated. Use PLASMA_IMPORT_LIB instead
+# PLASMA_SHARED_LIB, path to libplasma's shared library
+# PLASMA_SO_VERSION, shared object version of found Plasma such as "100"
+# PLASMA_STATIC_LIB, path to libplasma.a
+# PLASMA_STORE_SERVER, path to plasma-store-server
+
+if(DEFINED PLASMA_FOUND)
+ return()
+endif()
+
+set(find_package_arguments)
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION)
+ list(APPEND find_package_arguments "${${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION}")
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED)
+ list(APPEND find_package_arguments REQUIRED)
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY)
+ list(APPEND find_package_arguments QUIET)
+endif()
+find_package(Arrow ${find_package_arguments})
+
+if(ARROW_FOUND)
+ arrow_find_package(PLASMA
+ "${ARROW_HOME}"
+ plasma
+ plasma/client.h
+ Plasma
+ plasma)
+ if(ARROW_HOME)
+ set(PLASMA_STORE_SERVER
+ ${ARROW_HOME}/bin/plasma-store-server${CMAKE_EXECUTABLE_SUFFIX})
+ else()
+ if(PLASMA_USE_CMAKE_PACKAGE_CONFIG)
+ find_package(Plasma CONFIG)
+ elseif(PLASMA_USE_PKG_CONFIG)
+ pkg_get_variable(PLASMA_STORE_SERVER plasma plasma_store_server)
+ endif()
+ endif()
+ set(PLASMA_VERSION "${ARROW_VERSION}")
+ set(PLASMA_SO_VERSION "${ARROW_SO_VERSION}")
+ set(PLASMA_ABI_VERSION "${PLASMA_SO_VERSION}")
+ # For backward compatibility
+ set(PLASMA_EXECUTABLE "${PLASMA_STORE_SERVER}")
+ set(PLASMA_LIBS "${PLASMA_LIB_DIR}")
+endif()
+
+mark_as_advanced(PLASMA_ABI_VERSION
+ PLASMA_EXECUTABLE
+ PLASMA_IMPORT_LIB
+ PLASMA_INCLUDE_DIR
+ PLASMA_LIBS
+ PLASMA_LIB_DIR
+ PLASMA_SHARED_IMP_LIB
+ PLASMA_SHARED_LIB
+ PLASMA_SO_VERSION
+ PLASMA_STATIC_LIB
+ PLASMA_STORE_SERVER
+ PLASMA_VERSION)
+
+find_package_handle_standard_args(
+ Plasma
+ REQUIRED_VARS PLASMA_INCLUDE_DIR PLASMA_LIB_DIR PLASMA_SO_VERSION PLASMA_STORE_SERVER
+ VERSION_VAR PLASMA_VERSION)
+set(PLASMA_FOUND ${Plasma_FOUND})
+
+if(Plasma_FOUND AND NOT Plasma_FIND_QUIETLY)
+ message(STATUS "Found the Plasma by ${PLASMA_FIND_APPROACH}")
+ message(STATUS "Found the plasma-store-server: ${PLASMA_STORE_SERVER}")
+ message(STATUS "Found the Plasma shared library: ${PLASMA_SHARED_LIB}")
+ message(STATUS "Found the Plasma import library: ${PLASMA_IMPORT_LIB}")
+ message(STATUS "Found the Plasma static library: ${PLASMA_STATIC_LIB}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindPython3Alt.cmake b/src/arrow/cpp/cmake_modules/FindPython3Alt.cmake
new file mode 100644
index 000000000..ab91c7be0
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindPython3Alt.cmake
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This module finds the libraries corresponding to the Python 3 interpreter
+# and the NumPy package, and sets the following variables:
+# - PYTHON_EXECUTABLE
+# - PYTHON_INCLUDE_DIRS
+# - PYTHON_LIBRARIES
+# - PYTHON_OTHER_LIBS
+# - NUMPY_INCLUDE_DIRS
+
+# Need CMake 3.15 or later for Python3_FIND_STRATEGY
+if(${CMAKE_VERSION} VERSION_LESS "3.15.0")
+ # Use deprecated Python- and NumPy-finding code
+ if(Python3Alt_FIND_REQUIRED)
+ find_package(PythonLibsNew REQUIRED)
+ find_package(NumPy REQUIRED)
+ else()
+ find_package(PythonLibsNew)
+ find_package(NumPy)
+ endif()
+ find_package_handle_standard_args(
+ Python3Alt REQUIRED_VARS PYTHON_EXECUTABLE PYTHON_INCLUDE_DIRS NUMPY_INCLUDE_DIRS)
+ return()
+endif()
+
+if(${CMAKE_VERSION} VERSION_LESS "3.18.0" OR ARROW_BUILD_TESTS)
+ # When building arrow-python-test, we need libpython to be present, so ask for
+ # the full "Development" component. Also ask for it on CMake < 3.18,
+ # where "Development.Module" is not available.
+ if(Python3Alt_FIND_REQUIRED)
+ find_package(Python3
+ COMPONENTS Interpreter Development NumPy
+ REQUIRED)
+ else()
+ find_package(Python3 COMPONENTS Interpreter Development NumPy)
+ endif()
+else()
+ if(Python3Alt_FIND_REQUIRED)
+ find_package(Python3
+ COMPONENTS Interpreter Development.Module NumPy
+ REQUIRED)
+ else()
+ find_package(Python3 COMPONENTS Interpreter Development.Module NumPy)
+ endif()
+endif()
+
+if(NOT Python3_FOUND)
+ return()
+endif()
+
+set(PYTHON_EXECUTABLE ${Python3_EXECUTABLE})
+set(PYTHON_INCLUDE_DIRS ${Python3_INCLUDE_DIRS})
+set(PYTHON_LIBRARIES ${Python3_LIBRARIES})
+set(PYTHON_OTHER_LIBS)
+
+get_target_property(NUMPY_INCLUDE_DIRS Python3::NumPy INTERFACE_INCLUDE_DIRECTORIES)
+
+# CMake's python3_add_library() doesn't apply the required extension suffix,
+# detect it ourselves.
+# (https://gitlab.kitware.com/cmake/cmake/issues/20408)
+execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c"
+ "from distutils import sysconfig; print(sysconfig.get_config_var('EXT_SUFFIX'))"
+ RESULT_VARIABLE _PYTHON_RESULT
+ OUTPUT_VARIABLE _PYTHON_STDOUT
+ ERROR_VARIABLE _PYTHON_STDERR)
+
+if(NOT _PYTHON_RESULT MATCHES 0)
+ if(Python3Alt_FIND_REQUIRED)
+ message(FATAL_ERROR "Python 3 config failure:\n${_PYTHON_STDERR}")
+ endif()
+endif()
+
+string(STRIP ${_PYTHON_STDOUT} _EXT_SUFFIX)
+
+function(PYTHON_ADD_MODULE name)
+ python3_add_library(${name} MODULE ${ARGN})
+ set_target_properties(${name} PROPERTIES SUFFIX ${_EXT_SUFFIX})
+endfunction()
+
+find_package_handle_standard_args(
+ Python3Alt REQUIRED_VARS PYTHON_EXECUTABLE PYTHON_INCLUDE_DIRS NUMPY_INCLUDE_DIRS)
diff --git a/src/arrow/cpp/cmake_modules/FindPythonLibsNew.cmake b/src/arrow/cpp/cmake_modules/FindPythonLibsNew.cmake
new file mode 100644
index 000000000..581bba9d4
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindPythonLibsNew.cmake
@@ -0,0 +1,267 @@
+# - Find python libraries
+# This module finds the libraries corresponding to the Python interpreter
+# FindPythonInterp provides.
+# This code sets the following variables:
+#
+# PYTHONLIBS_FOUND - have the Python libs been found
+# PYTHON_PREFIX - path to the Python installation
+# PYTHON_LIBRARIES - path to the python library
+# PYTHON_INCLUDE_DIRS - path to where Python.h is found
+# PYTHON_SITE_PACKAGES - path to installation site-packages
+# PYTHON_IS_DEBUG - whether the Python interpreter is a debug build
+# PYTHON_OTHER_LIBS - third-party libraries (as link flags) needed
+# for linking with Python
+#
+# PYTHON_INCLUDE_PATH - path to where Python.h is found (deprecated)
+#
+# A function PYTHON_ADD_MODULE(<name> src1 src2 ... srcN) is defined
+# to build modules for python.
+#
+# Thanks to talljimbo for the patch adding the 'LDVERSION' config
+# variable usage.
+
+#=============================================================================
+# Copyright 2001-2009 Kitware, Inc.
+# Copyright 2012-2014 Continuum Analytics, Inc.
+#
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the names of Kitware, Inc., the Insight Software Consortium,
+# nor the names of their contributors may be used to endorse or promote
+# products derived from this software without specific prior written
+# permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#=============================================================================
+# (To distribute this file outside of CMake, substitute the full
+# License text for the above reference.)
+
+# Legacy code for CMake < 3.15.0. The primary point of entry should be
+# FindPython3Alt.cmake.
+
+# Use the Python interpreter to find the libs.
+if(PythonLibsNew_FIND_REQUIRED)
+ find_package(PythonInterp REQUIRED)
+else()
+ find_package(PythonInterp)
+endif()
+
+if(NOT PYTHONINTERP_FOUND)
+ set(PYTHONLIBS_FOUND FALSE)
+ return()
+endif()
+
+# According to http://stackoverflow.com/questions/646518/python-how-to-detect-debug-interpreter
+# testing whether sys has the gettotalrefcount function is a reliable,
+# cross-platform way to detect a CPython debug interpreter.
+#
+# The library suffix is from the config var LDVERSION sometimes, otherwise
+# VERSION. VERSION will typically be like "2.7" on unix, and "27" on windows.
+#
+# The config var LIBPL is for Linux, and helps on Debian Jessie where the
+# addition of multi-arch support shuffled things around.
+execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c"
+ "from distutils import sysconfig as s;import sys;import struct;
+print('.'.join(str(v) for v in sys.version_info));
+print(sys.prefix);
+print(s.get_python_inc(plat_specific=True));
+print(s.get_python_lib(plat_specific=True));
+print(s.get_config_var('SO'));
+print(hasattr(sys, 'gettotalrefcount')+0);
+print(struct.calcsize('@P'));
+print(s.get_config_var('LDVERSION') or s.get_config_var('VERSION'));
+print(s.get_config_var('LIBPL'));
+print(s.get_config_var('LIBS') or '');
+"
+ RESULT_VARIABLE _PYTHON_SUCCESS
+ OUTPUT_VARIABLE _PYTHON_VALUES
+ ERROR_VARIABLE _PYTHON_ERROR_VALUE)
+
+if(NOT _PYTHON_SUCCESS MATCHES 0)
+ if(PythonLibsNew_FIND_REQUIRED)
+ message(FATAL_ERROR
+ "Python config failure:\n${_PYTHON_ERROR_VALUE}")
+ endif()
+ set(PYTHONLIBS_FOUND FALSE)
+ return()
+endif()
+
+# Convert the process output into a list
+string(REGEX REPLACE ";" "\\\\;" _PYTHON_VALUES ${_PYTHON_VALUES})
+string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES})
+list(GET _PYTHON_VALUES 0 _PYTHON_VERSION_LIST)
+list(GET _PYTHON_VALUES 1 PYTHON_PREFIX)
+list(GET _PYTHON_VALUES 2 PYTHON_INCLUDE_DIR)
+list(GET _PYTHON_VALUES 3 PYTHON_SITE_PACKAGES)
+list(GET _PYTHON_VALUES 4 PYTHON_MODULE_EXTENSION)
+list(GET _PYTHON_VALUES 5 PYTHON_IS_DEBUG)
+list(GET _PYTHON_VALUES 6 PYTHON_SIZEOF_VOID_P)
+list(GET _PYTHON_VALUES 7 PYTHON_LIBRARY_SUFFIX)
+list(GET _PYTHON_VALUES 8 PYTHON_LIBRARY_PATH)
+list(GET _PYTHON_VALUES 9 PYTHON_OTHER_LIBS)
+
+# Make sure the Python has the same pointer-size as the chosen compiler
+# Skip the check on OS X, it doesn't consistently have CMAKE_SIZEOF_VOID_P defined
+if((NOT APPLE) AND (NOT "${PYTHON_SIZEOF_VOID_P}" STREQUAL "${CMAKE_SIZEOF_VOID_P}"))
+ if(PythonLibsNew_FIND_REQUIRED)
+ math(EXPR _PYTHON_BITS "${PYTHON_SIZEOF_VOID_P} * 8")
+ math(EXPR _CMAKE_BITS "${CMAKE_SIZEOF_VOID_P} * 8")
+ message(FATAL_ERROR
+ "Python config failure: Python is ${_PYTHON_BITS}-bit, "
+ "chosen compiler is ${_CMAKE_BITS}-bit")
+ endif()
+ set(PYTHONLIBS_FOUND FALSE)
+ return()
+endif()
+
+# The built-in FindPython didn't always give the version numbers
+string(REGEX REPLACE "\\." ";" _PYTHON_VERSION_LIST ${_PYTHON_VERSION_LIST})
+list(GET _PYTHON_VERSION_LIST 0 PYTHON_VERSION_MAJOR)
+list(GET _PYTHON_VERSION_LIST 1 PYTHON_VERSION_MINOR)
+list(GET _PYTHON_VERSION_LIST 2 PYTHON_VERSION_PATCH)
+
+# Make sure all directory separators are '/'
+string(REGEX REPLACE "\\\\" "/" PYTHON_PREFIX ${PYTHON_PREFIX})
+string(REGEX REPLACE "\\\\" "/" PYTHON_INCLUDE_DIR ${PYTHON_INCLUDE_DIR})
+string(REGEX REPLACE "\\\\" "/" PYTHON_SITE_PACKAGES ${PYTHON_SITE_PACKAGES})
+
+if(CMAKE_HOST_WIN32)
+ # Appease CMP0054
+ if(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
+ set(PYTHON_LIBRARY
+ "${PYTHON_PREFIX}/libs/Python${PYTHON_LIBRARY_SUFFIX}.lib")
+ else()
+ find_library(PYTHON_LIBRARY
+ NAMES "python${PYTHON_LIBRARY_SUFFIX}"
+ PATHS "${PYTHON_PREFIX}" NO_DEFAULT_PATH
+ PATH_SUFFIXES "lib" "libs")
+ endif()
+elseif(APPLE)
+
+ set(PYTHON_LIBRARY "${PYTHON_PREFIX}/lib/libpython${PYTHON_LIBRARY_SUFFIX}.dylib")
+
+ if (NOT EXISTS ${PYTHON_LIBRARY})
+ # In some cases libpythonX.X.dylib is not part of the PYTHON_PREFIX and we
+ # need to call `python-config --prefix` to determine the correct location.
+ find_program(PYTHON_CONFIG python-config
+ NO_CMAKE_SYSTEM_PATH)
+ if (PYTHON_CONFIG)
+ execute_process(
+ COMMAND "${PYTHON_CONFIG}" "--prefix"
+ OUTPUT_VARIABLE PYTHON_CONFIG_PREFIX
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+ set(PYTHON_LIBRARY "${PYTHON_CONFIG_PREFIX}/lib/libpython${PYTHON_LIBRARY_SUFFIX}.dylib")
+ endif()
+ endif()
+else()
+ if(${PYTHON_SIZEOF_VOID_P} MATCHES 8)
+ set(_PYTHON_LIBS_SEARCH "${PYTHON_PREFIX}/lib64" "${PYTHON_PREFIX}/lib" "${PYTHON_LIBRARY_PATH}")
+ else()
+ set(_PYTHON_LIBS_SEARCH "${PYTHON_PREFIX}/lib" "${PYTHON_LIBRARY_PATH}")
+ endif()
+ message(STATUS "Searching for Python libs in ${_PYTHON_LIBS_SEARCH}")
+ message(STATUS "Looking for python${PYTHON_LIBRARY_SUFFIX}")
+ # Probably this needs to be more involved. It would be nice if the config
+ # information the python interpreter itself gave us were more complete.
+ find_library(PYTHON_LIBRARY
+ NAMES "python${PYTHON_LIBRARY_SUFFIX}"
+ PATHS ${_PYTHON_LIBS_SEARCH}
+ NO_SYSTEM_ENVIRONMENT_PATH
+ NO_CMAKE_SYSTEM_PATH)
+ message(STATUS "Found Python lib ${PYTHON_LIBRARY}")
+endif()
+
+# For backward compatibility, set PYTHON_INCLUDE_PATH, but make it internal.
+SET(PYTHON_INCLUDE_PATH "${PYTHON_INCLUDE_DIR}" CACHE INTERNAL
+ "Path to where Python.h is found (deprecated)")
+
+MARK_AS_ADVANCED(
+ PYTHON_LIBRARY
+ PYTHON_INCLUDE_DIR
+)
+
+# We use PYTHON_INCLUDE_DIR, PYTHON_LIBRARY and PYTHON_DEBUG_LIBRARY for the
+# cache entries because they are meant to specify the location of a single
+# library. We now set the variables listed by the documentation for this
+# module.
+SET(PYTHON_INCLUDE_DIRS "${PYTHON_INCLUDE_DIR}")
+SET(PYTHON_LIBRARIES "${PYTHON_LIBRARY}")
+SET(PYTHON_DEBUG_LIBRARIES "${PYTHON_DEBUG_LIBRARY}")
+
+
+# Don't know how to get to this directory, just doing something simple :P
+#INCLUDE(${CMAKE_CURRENT_LIST_DIR}/FindPackageHandleStandardArgs.cmake)
+#FIND_PACKAGE_HANDLE_STANDARD_ARGS(PythonLibs DEFAULT_MSG PYTHON_LIBRARIES PYTHON_INCLUDE_DIRS)
+find_package_message(PYTHON
+ "Found PythonLibs: ${PYTHON_LIBRARY}"
+ "${PYTHON_EXECUTABLE}${PYTHON_VERSION}")
+
+
+# PYTHON_ADD_MODULE(<name> src1 src2 ... srcN) is used to build modules for python.
+FUNCTION(PYTHON_ADD_MODULE _NAME )
+ GET_PROPERTY(_TARGET_SUPPORTS_SHARED_LIBS
+ GLOBAL PROPERTY TARGET_SUPPORTS_SHARED_LIBS)
+ OPTION(PYTHON_ENABLE_MODULE_${_NAME} "Add module ${_NAME}" TRUE)
+ OPTION(PYTHON_MODULE_${_NAME}_BUILD_SHARED
+ "Add module ${_NAME} shared" ${_TARGET_SUPPORTS_SHARED_LIBS})
+
+ # Mark these options as advanced
+ MARK_AS_ADVANCED(PYTHON_ENABLE_MODULE_${_NAME}
+ PYTHON_MODULE_${_NAME}_BUILD_SHARED)
+
+ IF(PYTHON_ENABLE_MODULE_${_NAME})
+ IF(PYTHON_MODULE_${_NAME}_BUILD_SHARED)
+ SET(PY_MODULE_TYPE MODULE)
+ ELSE(PYTHON_MODULE_${_NAME}_BUILD_SHARED)
+ SET(PY_MODULE_TYPE STATIC)
+ SET_PROPERTY(GLOBAL APPEND PROPERTY PY_STATIC_MODULES_LIST ${_NAME})
+ ENDIF(PYTHON_MODULE_${_NAME}_BUILD_SHARED)
+
+ SET_PROPERTY(GLOBAL APPEND PROPERTY PY_MODULES_LIST ${_NAME})
+ ADD_LIBRARY(${_NAME} ${PY_MODULE_TYPE} ${ARGN})
+ IF(APPLE)
+ # On OS X, linking against the Python libraries causes
+ # segfaults, so do this dynamic lookup instead.
+ SET_TARGET_PROPERTIES(${_NAME} PROPERTIES LINK_FLAGS
+ "-undefined dynamic_lookup")
+ ELSEIF(MSVC)
+ target_link_libraries(${_NAME} ${PYTHON_LIBRARIES})
+ ELSE()
+ # In general, we should not link against libpython as we do not embed the
+ # Python interpreter. The python binary itself can then define where the
+ # symbols should loaded from. For being manylinux1 compliant, one is not
+ # allowed to link to libpython. Partly because not all systems ship it,
+ # also because the interpreter ABI/API was not stable between patch
+ # releases for Python < 3.5.
+ SET_TARGET_PROPERTIES(${_NAME} PROPERTIES LINK_FLAGS
+ "-Wl,-undefined,dynamic_lookup")
+ ENDIF()
+ IF(PYTHON_MODULE_${_NAME}_BUILD_SHARED)
+ SET_TARGET_PROPERTIES(${_NAME} PROPERTIES PREFIX "${PYTHON_MODULE_PREFIX}")
+ SET_TARGET_PROPERTIES(${_NAME} PROPERTIES SUFFIX "${PYTHON_MODULE_EXTENSION}")
+ ELSE()
+ ENDIF()
+
+ ENDIF(PYTHON_ENABLE_MODULE_${_NAME})
+ENDFUNCTION(PYTHON_ADD_MODULE)
diff --git a/src/arrow/cpp/cmake_modules/FindRapidJSONAlt.cmake b/src/arrow/cpp/cmake_modules/FindRapidJSONAlt.cmake
new file mode 100644
index 000000000..9a449a528
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindRapidJSONAlt.cmake
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set(find_package_args)
+if(RapidJSONAlt_FIND_VERSION)
+ list(APPEND find_package_args ${RapidJSONAlt_FIND_VERSION})
+endif()
+if(RapidJSONAlt_FIND_QUIETLY)
+ list(APPEND find_package_args QUIET)
+endif()
+find_package(RapidJSON ${find_package_args})
+if(RapidJSON_FOUND)
+ set(RapidJSONAlt_FOUND TRUE)
+ set(RAPIDJSON_INCLUDE_DIR ${RAPIDJSON_INCLUDE_DIRS})
+ return()
+endif()
+
+if(RapidJSON_ROOT)
+ find_path(RAPIDJSON_INCLUDE_DIR
+ NAMES rapidjson/rapidjson.h
+ PATHS ${RapidJSON_ROOT}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES "include")
+else()
+ find_path(RAPIDJSON_INCLUDE_DIR
+ NAMES rapidjson/rapidjson.h
+ PATH_SUFFIXES "include")
+endif()
+
+if(RAPIDJSON_INCLUDE_DIR)
+ file(READ "${RAPIDJSON_INCLUDE_DIR}/rapidjson/rapidjson.h" RAPIDJSON_H_CONTENT)
+ string(REGEX MATCH "#define RAPIDJSON_MAJOR_VERSION ([0-9]+)"
+ RAPIDJSON_MAJOR_VERSION_DEFINITION "${RAPIDJSON_H_CONTENT}")
+ string(REGEX REPLACE "^.+ ([0-9]+)$" "\\1" RAPIDJSON_MAJOR_VERSION
+ "${RAPIDJSON_MAJOR_VERSION_DEFINITION}")
+ string(REGEX MATCH "#define RAPIDJSON_MINOR_VERSION ([0-9]+)"
+ RAPIDJSON_MINOR_VERSION_DEFINITION "${RAPIDJSON_H_CONTENT}")
+ string(REGEX REPLACE "^.+ ([0-9]+)$" "\\1" RAPIDJSON_MINOR_VERSION
+ "${RAPIDJSON_MINOR_VERSION_DEFINITION}")
+ string(REGEX MATCH "#define RAPIDJSON_PATCH_VERSION ([0-9]+)"
+ RAPIDJSON_PATCH_VERSION_DEFINITION "${RAPIDJSON_H_CONTENT}")
+ string(REGEX REPLACE "^.+ ([0-9]+)$" "\\1" RAPIDJSON_PATCH_VERSION
+ "${RAPIDJSON_PATCH_VERSION_DEFINITION}")
+ if("${RAPIDJSON_MAJOR_VERSION}" STREQUAL ""
+ OR "${RAPIDJSON_MINOR_VERSION}" STREQUAL ""
+ OR "${RAPIDJSON_PATCH_VERSION}" STREQUAL "")
+ set(RAPIDJSON_VERSION "0.0.0")
+ else()
+ set(RAPIDJSON_VERSION
+ "${RAPIDJSON_MAJOR_VERSION}.${RAPIDJSON_MINOR_VERSION}.${RAPIDJSON_PATCH_VERSION}"
+ )
+ endif()
+endif()
+
+find_package_handle_standard_args(
+ RapidJSONAlt
+ REQUIRED_VARS RAPIDJSON_INCLUDE_DIR
+ VERSION_VAR RAPIDJSON_VERSION)
diff --git a/src/arrow/cpp/cmake_modules/FindSnappy.cmake b/src/arrow/cpp/cmake_modules/FindSnappy.cmake
new file mode 100644
index 000000000..747df3185
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindSnappy.cmake
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(ARROW_SNAPPY_USE_SHARED)
+ set(SNAPPY_LIB_NAMES)
+ if(CMAKE_IMPORT_LIBRARY_SUFFIX)
+ list(APPEND SNAPPY_LIB_NAMES
+ "${CMAKE_IMPORT_LIBRARY_PREFIX}snappy${CMAKE_IMPORT_LIBRARY_SUFFIX}")
+ endif()
+ list(APPEND SNAPPY_LIB_NAMES
+ "${CMAKE_SHARED_LIBRARY_PREFIX}snappy${CMAKE_SHARED_LIBRARY_SUFFIX}")
+else()
+ set(SNAPPY_STATIC_LIB_NAME_BASE "snappy")
+ if(MSVC)
+ set(SNAPPY_STATIC_LIB_NAME_BASE
+ "${SNAPPY_STATIC_LIB_NAME_BASE}${SNAPPY_MSVC_STATIC_LIB_SUFFIX}")
+ endif()
+ set(SNAPPY_LIB_NAMES
+ "${CMAKE_STATIC_LIBRARY_PREFIX}${SNAPPY_STATIC_LIB_NAME_BASE}${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+endif()
+
+if(Snappy_ROOT)
+ find_library(Snappy_LIB
+ NAMES ${SNAPPY_LIB_NAMES}
+ PATHS ${Snappy_ROOT}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ find_path(Snappy_INCLUDE_DIR
+ NAMES snappy.h
+ PATHS ${Snappy_ROOT}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+else()
+ find_library(Snappy_LIB NAMES ${SNAPPY_LIB_NAMES})
+ find_path(Snappy_INCLUDE_DIR
+ NAMES snappy.h
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+endif()
+
+find_package_handle_standard_args(Snappy REQUIRED_VARS Snappy_LIB Snappy_INCLUDE_DIR)
+
+if(Snappy_FOUND)
+ add_library(Snappy::snappy UNKNOWN IMPORTED)
+ set_target_properties(Snappy::snappy
+ PROPERTIES IMPORTED_LOCATION "${Snappy_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${Snappy_INCLUDE_DIR}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindThrift.cmake b/src/arrow/cpp/cmake_modules/FindThrift.cmake
new file mode 100644
index 000000000..750d8ce83
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindThrift.cmake
@@ -0,0 +1,144 @@
+# Copyright 2012 Cloudera Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# - Find Thrift (a cross platform RPC lib/tool)
+#
+# Variables used by this module, they can change the default behaviour and need
+# to be set before calling find_package:
+#
+# Thrift_ROOT - When set, this path is inspected instead of standard library
+# locations as the root of the Thrift installation.
+# The environment variable THRIFT_HOME overrides this variable.
+#
+# This module defines
+# THRIFT_VERSION, version string of ant if found
+# THRIFT_INCLUDE_DIR, where to find THRIFT headers
+# THRIFT_LIB, THRIFT library
+# THRIFT_FOUND, If false, do not try to use ant
+
+function(EXTRACT_THRIFT_VERSION)
+ if(THRIFT_INCLUDE_DIR)
+ file(READ "${THRIFT_INCLUDE_DIR}/thrift/config.h" THRIFT_CONFIG_H_CONTENT)
+ string(REGEX MATCH "#define PACKAGE_VERSION \"[0-9.]+\"" THRIFT_VERSION_DEFINITION
+ "${THRIFT_CONFIG_H_CONTENT}")
+ string(REGEX MATCH "[0-9.]+" THRIFT_VERSION "${THRIFT_VERSION_DEFINITION}")
+ set(THRIFT_VERSION
+ "${THRIFT_VERSION}"
+ PARENT_SCOPE)
+ else()
+ set(THRIFT_VERSION
+ ""
+ PARENT_SCOPE)
+ endif()
+endfunction(EXTRACT_THRIFT_VERSION)
+
+if(MSVC_TOOLCHAIN AND NOT DEFINED THRIFT_MSVC_LIB_SUFFIX)
+ if(NOT ARROW_THRIFT_USE_SHARED)
+ if(ARROW_USE_STATIC_CRT)
+ set(THRIFT_MSVC_LIB_SUFFIX "mt")
+ else()
+ set(THRIFT_MSVC_LIB_SUFFIX "md")
+ endif()
+ endif()
+endif()
+set(THRIFT_LIB_NAME_BASE "thrift${THRIFT_MSVC_LIB_SUFFIX}")
+
+if(ARROW_THRIFT_USE_SHARED)
+ set(THRIFT_LIB_NAMES thrift)
+ if(CMAKE_IMPORT_LIBRARY_SUFFIX)
+ list(APPEND
+ THRIFT_LIB_NAMES
+ "${CMAKE_IMPORT_LIBRARY_PREFIX}${THRIFT_LIB_NAME_BASE}${CMAKE_IMPORT_LIBRARY_SUFFIX}"
+ )
+ endif()
+ list(APPEND
+ THRIFT_LIB_NAMES
+ "${CMAKE_SHARED_LIBRARY_PREFIX}${THRIFT_LIB_NAME_BASE}${CMAKE_SHARED_LIBRARY_SUFFIX}"
+ )
+else()
+ set(THRIFT_LIB_NAMES
+ "${CMAKE_STATIC_LIBRARY_PREFIX}${THRIFT_LIB_NAME_BASE}${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+endif()
+
+if(Thrift_ROOT)
+ find_library(THRIFT_LIB
+ NAMES ${THRIFT_LIB_NAMES}
+ PATHS ${Thrift_ROOT}
+ PATH_SUFFIXES "lib/${CMAKE_LIBRARY_ARCHITECTURE}" "lib")
+ find_path(THRIFT_INCLUDE_DIR thrift/Thrift.h
+ PATHS ${Thrift_ROOT}
+ PATH_SUFFIXES "include")
+ find_program(THRIFT_COMPILER thrift
+ PATHS ${Thrift_ROOT}
+ PATH_SUFFIXES "bin")
+ extract_thrift_version()
+else()
+ # THRIFT-4760: The pkgconfig files are currently only installed when using autotools.
+ # Starting with 0.13, they are also installed for the CMake-based installations of Thrift.
+ find_package(PkgConfig QUIET)
+ pkg_check_modules(THRIFT_PC thrift)
+ if(THRIFT_PC_FOUND)
+ set(THRIFT_INCLUDE_DIR "${THRIFT_PC_INCLUDEDIR}")
+
+ list(APPEND THRIFT_PC_LIBRARY_DIRS "${THRIFT_PC_LIBDIR}")
+
+ find_library(THRIFT_LIB
+ NAMES ${THRIFT_LIB_NAMES}
+ PATHS ${THRIFT_PC_LIBRARY_DIRS}
+ NO_DEFAULT_PATH)
+ find_program(THRIFT_COMPILER thrift
+ HINTS ${THRIFT_PC_PREFIX}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES "bin")
+ set(THRIFT_VERSION ${THRIFT_PC_VERSION})
+ else()
+ find_library(THRIFT_LIB
+ NAMES ${THRIFT_LIB_NAMES}
+ PATH_SUFFIXES "lib/${CMAKE_LIBRARY_ARCHITECTURE}" "lib")
+ find_path(THRIFT_INCLUDE_DIR thrift/Thrift.h PATH_SUFFIXES "include")
+ find_program(THRIFT_COMPILER thrift PATH_SUFFIXES "bin")
+ extract_thrift_version()
+ endif()
+endif()
+
+if(THRIFT_COMPILER)
+ set(Thrift_COMPILER_FOUND TRUE)
+else()
+ set(Thrift_COMPILER_FOUND FALSE)
+endif()
+
+find_package_handle_standard_args(
+ Thrift
+ REQUIRED_VARS THRIFT_LIB THRIFT_INCLUDE_DIR
+ VERSION_VAR THRIFT_VERSION
+ HANDLE_COMPONENTS)
+
+if(Thrift_FOUND OR THRIFT_FOUND)
+ set(Thrift_FOUND TRUE)
+ if(ARROW_THRIFT_USE_SHARED)
+ add_library(thrift::thrift SHARED IMPORTED)
+ else()
+ add_library(thrift::thrift STATIC IMPORTED)
+ endif()
+ set_target_properties(thrift::thrift
+ PROPERTIES IMPORTED_LOCATION "${THRIFT_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${THRIFT_INCLUDE_DIR}")
+ if(WIN32 AND NOT MSVC_TOOLCHAIN)
+ # We don't need this for Visual C++ because Thrift uses
+ # "#pragma comment(lib, "Ws2_32.lib")" in
+ # thrift/windows/config.h for Visual C++.
+ set_target_properties(thrift::thrift PROPERTIES INTERFACE_LINK_LIBRARIES "ws2_32")
+ endif()
+endif()
diff --git a/src/arrow/cpp/cmake_modules/Findc-aresAlt.cmake b/src/arrow/cpp/cmake_modules/Findc-aresAlt.cmake
new file mode 100644
index 000000000..5213e8d12
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/Findc-aresAlt.cmake
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set(find_package_args)
+if(c-aresAlt_FIND_VERSION)
+ list(APPEND find_package_args ${c-aresAlt_FIND_VERSION})
+endif()
+if(c-aresAlt_FIND_QUIETLY)
+ list(APPEND find_package_args QUIET)
+endif()
+find_package(c-ares ${find_package_args})
+if(c-ares_FOUND)
+ set(c-aresAlt_FOUND TRUE)
+ return()
+endif()
+
+find_package(PkgConfig QUIET)
+pkg_check_modules(c-ares_PC libcares)
+if(c-ares_PC_FOUND)
+ set(c-ares_INCLUDE_DIR "${c-ares_PC_INCLUDEDIR}")
+
+ list(APPEND c-ares_PC_LIBRARY_DIRS "${c-ares_PC_LIBDIR}")
+ find_library(c-ares_LIB cares
+ PATHS ${c-ares_PC_LIBRARY_DIRS}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+elseif(c-ares_ROOT)
+ find_library(c-ares_LIB
+ NAMES cares
+ "${CMAKE_SHARED_LIBRARY_PREFIX}cares${CMAKE_SHARED_LIBRARY_SUFFIX}"
+ PATHS ${c-ares_ROOT}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ find_path(c-ares_INCLUDE_DIR
+ NAMES ares.h
+ PATHS ${c-ares_ROOT}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+else()
+ find_library(c-ares_LIB
+ NAMES cares
+ "${CMAKE_SHARED_LIBRARY_PREFIX}cares${CMAKE_SHARED_LIBRARY_SUFFIX}"
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES})
+ find_path(c-ares_INCLUDE_DIR
+ NAMES ares.h
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+endif()
+
+find_package_handle_standard_args(c-aresAlt REQUIRED_VARS c-ares_LIB c-ares_INCLUDE_DIR)
+
+if(c-aresAlt_FOUND)
+ if(NOT TARGET c-ares::cares)
+ add_library(c-ares::cares UNKNOWN IMPORTED)
+ set_target_properties(c-ares::cares
+ PROPERTIES IMPORTED_LOCATION "${c-ares_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${c-ares_INCLUDE_DIR}")
+ endif()
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindgRPCAlt.cmake b/src/arrow/cpp/cmake_modules/FindgRPCAlt.cmake
new file mode 100644
index 000000000..18b23f322
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindgRPCAlt.cmake
@@ -0,0 +1,76 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set(find_package_args)
+if(gRPCAlt_FIND_VERSION)
+ list(APPEND find_package_args ${gRPCAlt_FIND_VERSION})
+endif()
+if(gRPCAlt_FIND_QUIETLY)
+ list(APPEND find_package_args QUIET)
+endif()
+find_package(gRPC ${find_package_args})
+if(gRPC_FOUND)
+ set(gRPCAlt_FOUND TRUE)
+ return()
+endif()
+
+find_package(PkgConfig QUIET)
+pkg_check_modules(GRPCPP_PC grpc++)
+if(GRPCPP_PC_FOUND)
+ set(gRPCAlt_VERSION "${GRPCPP_PC_VERSION}")
+ set(GRPCPP_INCLUDE_DIRECTORIES ${GRPCPP_PC_INCLUDEDIR})
+ if(ARROW_GRPC_USE_SHARED)
+ set(GRPCPP_LINK_LIBRARIES ${GRPCPP_PC_LINK_LIBRARIES})
+ set(GRPCPP_LINK_OPTIONS ${GRPCPP_PC_LDFLAGS_OTHER})
+ set(GRPCPP_COMPILE_OPTIONS ${GRPCPP_PC_CFLAGS_OTHER})
+ else()
+ set(GRPCPP_LINK_LIBRARIES)
+ foreach(GRPCPP_LIBRARY_NAME ${GRPCPP_PC_STATIC_LIBRARIES})
+ find_library(GRPCPP_LIBRARY_${GRPCPP_LIBRARY_NAME}
+ NAMES "${CMAKE_STATIC_LIBRARY_PREFIX}${GRPCPP_LIBRARY_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ HINTS ${GRPCPP_PC_STATIC_LIBRARY_DIRS})
+ list(APPEND GRPCPP_LINK_LIBRARIES "${GRPCPP_LIBRARY_${GRPCPP_LIBRARY_NAME}}")
+ endforeach()
+ set(GRPCPP_LINK_OPTIONS ${GRPCPP_PC_STATIC_LDFLAGS_OTHER})
+ set(GRPCPP_COMPILE_OPTIONS ${GRPCPP_PC_STATIC_CFLAGS_OTHER})
+ endif()
+ list(GET GRPCPP_LINK_LIBRARIES 0 GRPCPP_IMPORTED_LOCATION)
+ list(REMOVE_AT GRPCPP_LINK_LIBRARIES 0)
+ find_program(GRPC_CPP_PLUGIN grpc_cpp_plugin
+ HINTS ${GRPCPP_PC_PREFIX}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES "bin")
+ set(gRPCAlt_FIND_PACKAGE_ARGS gRPCAlt REQUIRED_VARS GRPCPP_IMPORTED_LOCATION
+ GRPC_CPP_PLUGIN)
+ if(gRPCAlt_VERSION)
+ list(APPEND gRPCAlt_FIND_PACKAGE_ARGS VERSION_VAR gRPCAlt_VERSION)
+ endif()
+ find_package_handle_standard_args(${gRPCAlt_FIND_PACKAGE_ARGS})
+else()
+ set(gRPCAlt_FOUND FALSE)
+endif()
+
+if(gRPCAlt_FOUND)
+ add_library(gRPC::grpc++ UNKNOWN IMPORTED)
+ set_target_properties(gRPC::grpc++
+ PROPERTIES IMPORTED_LOCATION "${GRPCPP_IMPORTED_LOCATION}"
+ INTERFACE_COMPILE_OPTIONS "${GRPCPP_COMPILE_OPTIONS}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${GRPCPP_INCLUDE_DIRECTORIES}"
+ INTERFACE_LINK_LIBRARIES "${GRPCPP_LINK_LIBRARIES}"
+ INTERFACE_LINK_OPTIONS "${GRPCPP_LINK_OPTIONS}")
+
+ add_executable(gRPC::grpc_cpp_plugin IMPORTED)
+ set_target_properties(gRPC::grpc_cpp_plugin PROPERTIES IMPORTED_LOCATION
+ ${GRPC_CPP_PLUGIN})
+endif()
diff --git a/src/arrow/cpp/cmake_modules/FindgflagsAlt.cmake b/src/arrow/cpp/cmake_modules/FindgflagsAlt.cmake
new file mode 100644
index 000000000..e092ea3e9
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/FindgflagsAlt.cmake
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set(find_package_args)
+if(gflagsAlt_FIND_VERSION)
+ list(APPEND find_package_args ${gflagsAlt_FIND_VERSION})
+endif()
+if(gflagsAlt_FIND_QUIETLY)
+ list(APPEND find_package_args QUIET)
+endif()
+find_package(gflags ${find_package_args})
+if(gflags_FOUND)
+ set(gflagsAlt_FOUND TRUE)
+ return()
+endif()
+
+# TODO: Support version detection.
+
+if(gflags_ROOT)
+ find_library(gflags_LIB
+ NAMES gflags
+ PATHS ${gflags_ROOT}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ find_path(GFLAGS_INCLUDE_DIR
+ NAMES gflags/gflags.h
+ PATHS ${gflags_ROOT}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+else()
+ find_library(gflags_LIB NAMES gflags)
+ find_path(GFLAGS_INCLUDE_DIR
+ NAMES gflags/gflags.h
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+endif()
+
+find_package_handle_standard_args(gflagsAlt REQUIRED_VARS gflags_LIB GFLAGS_INCLUDE_DIR)
+
+if(gflagsAlt_FOUND)
+ add_library(gflags::gflags UNKNOWN IMPORTED)
+ set_target_properties(gflags::gflags
+ PROPERTIES IMPORTED_LOCATION "${gflags_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${GFLAGS_INCLUDE_DIR}")
+ set(GFLAGS_LIBRARIES gflags::gflags)
+endif()
diff --git a/src/arrow/cpp/cmake_modules/Findjemalloc.cmake b/src/arrow/cpp/cmake_modules/Findjemalloc.cmake
new file mode 100644
index 000000000..84bb81fcb
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/Findjemalloc.cmake
@@ -0,0 +1,94 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Tries to find jemalloc headers and libraries.
+#
+# Usage of this module as follows:
+#
+# find_package(jemalloc)
+#
+# Variables used by this module, they can change the default behaviour and need
+# to be set before calling find_package:
+#
+# JEMALLOC_HOME -
+# When set, this path is inspected instead of standard library locations as
+# the root of the jemalloc installation. The environment variable
+# JEMALLOC_HOME overrides this veriable.
+#
+# This module defines
+# JEMALLOC_INCLUDE_DIR, directory containing headers
+# JEMALLOC_SHARED_LIB, path to libjemalloc.so/dylib
+# JEMALLOC_FOUND, whether flatbuffers has been found
+
+if(NOT "${JEMALLOC_HOME}" STREQUAL "")
+ file(TO_CMAKE_PATH "${JEMALLOC_HOME}" _native_path)
+ list(APPEND _jemalloc_roots ${_native_path})
+elseif(JEMALLOC_HOME)
+ list(APPEND _jemalloc_roots ${JEMALLOC_HOME})
+endif()
+
+set(LIBJEMALLOC_NAMES jemalloc libjemalloc.so.1 libjemalloc.so.2 libjemalloc.dylib)
+
+# Try the parameterized roots, if they exist
+if(_jemalloc_roots)
+ find_path(JEMALLOC_INCLUDE_DIR
+ NAMES jemalloc/jemalloc.h
+ PATHS ${_jemalloc_roots}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES "include")
+ find_library(JEMALLOC_SHARED_LIB
+ NAMES ${LIBJEMALLOC_NAMES}
+ PATHS ${_jemalloc_roots}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES "lib")
+ find_library(JEMALLOC_STATIC_LIB
+ NAMES jemalloc_pic
+ PATHS ${_jemalloc_roots}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES "lib")
+else()
+ find_path(JEMALLOC_INCLUDE_DIR NAMES jemalloc/jemalloc.h)
+ message(STATUS ${JEMALLOC_INCLUDE_DIR})
+ find_library(JEMALLOC_SHARED_LIB NAMES ${LIBJEMALLOC_NAMES})
+ message(STATUS ${JEMALLOC_SHARED_LIB})
+ find_library(JEMALLOC_STATIC_LIB NAMES jemalloc_pic)
+ message(STATUS ${JEMALLOC_STATIC_LIB})
+endif()
+
+if(JEMALLOC_INCLUDE_DIR AND JEMALLOC_SHARED_LIB)
+ set(JEMALLOC_FOUND TRUE)
+else()
+ set(JEMALLOC_FOUND FALSE)
+endif()
+
+if(JEMALLOC_FOUND)
+ if(NOT jemalloc_FIND_QUIETLY)
+ message(STATUS "Found the jemalloc library: ${JEMALLOC_LIBRARIES}")
+ endif()
+else()
+ if(NOT jemalloc_FIND_QUIETLY)
+ set(JEMALLOC_ERR_MSG "Could not find the jemalloc library. Looked in ")
+ if(_flatbuffers_roots)
+ set(JEMALLOC_ERR_MSG "${JEMALLOC_ERR_MSG} in ${_jemalloc_roots}.")
+ else()
+ set(JEMALLOC_ERR_MSG "${JEMALLOC_ERR_MSG} system search paths.")
+ endif()
+ if(jemalloc_FIND_REQUIRED)
+ message(FATAL_ERROR "${JEMALLOC_ERR_MSG}")
+ else(jemalloc_FIND_REQUIRED)
+ message(STATUS "${JEMALLOC_ERR_MSG}")
+ endif(jemalloc_FIND_REQUIRED)
+ endif()
+endif()
+
+mark_as_advanced(JEMALLOC_INCLUDE_DIR JEMALLOC_SHARED_LIB)
diff --git a/src/arrow/cpp/cmake_modules/Findre2Alt.cmake b/src/arrow/cpp/cmake_modules/Findre2Alt.cmake
new file mode 100644
index 000000000..68abf1b75
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/Findre2Alt.cmake
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set(find_package_args)
+if(re2Alt_FIND_VERSION)
+ list(APPEND find_package_args ${re2Alt_FIND_VERSION})
+endif()
+if(re2Alt_FIND_QUIETLY)
+ list(APPEND find_package_args QUIET)
+endif()
+find_package(re2 ${find_package_args})
+if(re2_FOUND)
+ set(re2Alt_FOUND TRUE)
+ return()
+endif()
+
+find_package(PkgConfig QUIET)
+pkg_check_modules(RE2_PC re2)
+if(RE2_PC_FOUND)
+ set(RE2_INCLUDE_DIR "${RE2_PC_INCLUDEDIR}")
+
+ list(APPEND RE2_PC_LIBRARY_DIRS "${RE2_PC_LIBDIR}")
+ find_library(RE2_LIB re2
+ PATHS ${RE2_PC_LIBRARY_DIRS}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+
+ # On Fedora, the reported prefix is wrong. As users likely run into this,
+ # workaround.
+ # https://bugzilla.redhat.com/show_bug.cgi?id=1652589
+ if(UNIX
+ AND NOT APPLE
+ AND NOT RE2_LIB)
+ if(RE2_PC_PREFIX STREQUAL "/usr/local")
+ find_library(RE2_LIB re2)
+ endif()
+ endif()
+elseif(RE2_ROOT)
+ find_library(RE2_LIB
+ NAMES re2_static
+ re2
+ "${CMAKE_STATIC_LIBRARY_PREFIX}re2${RE2_MSVC_STATIC_LIB_SUFFIX}${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ "${CMAKE_SHARED_LIBRARY_PREFIX}re2${CMAKE_SHARED_LIBRARY_SUFFIX}"
+ PATHS ${RE2_ROOT}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ find_path(RE2_INCLUDE_DIR
+ NAMES re2/re2.h
+ PATHS ${RE2_ROOT}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+else()
+ find_library(RE2_LIB
+ NAMES re2_static
+ re2
+ "${CMAKE_STATIC_LIBRARY_PREFIX}re2${RE2_MSVC_STATIC_LIB_SUFFIX}${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ "${CMAKE_SHARED_LIBRARY_PREFIX}re2${CMAKE_SHARED_LIBRARY_SUFFIX}"
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES})
+ find_path(RE2_INCLUDE_DIR
+ NAMES re2/re2.h
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+endif()
+
+find_package_handle_standard_args(re2Alt REQUIRED_VARS RE2_LIB RE2_INCLUDE_DIR)
+
+if(re2Alt_FOUND)
+ if(NOT TARGET re2::re2)
+ add_library(re2::re2 UNKNOWN IMPORTED)
+ set_target_properties(re2::re2
+ PROPERTIES IMPORTED_LOCATION "${RE2_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${RE2_INCLUDE_DIR}")
+ endif()
+endif()
diff --git a/src/arrow/cpp/cmake_modules/Findutf8proc.cmake b/src/arrow/cpp/cmake_modules/Findutf8proc.cmake
new file mode 100644
index 000000000..4d732f186
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/Findutf8proc.cmake
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+function(extract_utf8proc_version)
+ if(utf8proc_INCLUDE_DIR)
+ file(READ "${utf8proc_INCLUDE_DIR}/utf8proc.h" UTF8PROC_H_CONTENT)
+
+ string(REGEX MATCH "#define UTF8PROC_VERSION_MAJOR [0-9]+"
+ UTF8PROC_MAJOR_VERSION_DEFINITION "${UTF8PROC_H_CONTENT}")
+ string(REGEX MATCH "#define UTF8PROC_VERSION_MINOR [0-9]+"
+ UTF8PROC_MINOR_VERSION_DEFINITION "${UTF8PROC_H_CONTENT}")
+ string(REGEX MATCH "#define UTF8PROC_VERSION_PATCH [0-9]+"
+ UTF8PROC_PATCH_VERSION_DEFINITION "${UTF8PROC_H_CONTENT}")
+
+ string(REGEX MATCH "[0-9]+$" UTF8PROC_MAJOR_VERSION
+ "${UTF8PROC_MAJOR_VERSION_DEFINITION}")
+ string(REGEX MATCH "[0-9]+$" UTF8PROC_MINOR_VERSION
+ "${UTF8PROC_MINOR_VERSION_DEFINITION}")
+ string(REGEX MATCH "[0-9]+$" UTF8PROC_PATCH_VERSION
+ "${UTF8PROC_PATCH_VERSION_DEFINITION}")
+ set(utf8proc_VERSION
+ "${UTF8PROC_MAJOR_VERSION}.${UTF8PROC_MINOR_VERSION}.${UTF8PROC_PATCH_VERSION}"
+ PARENT_SCOPE)
+ else()
+ set(utf8proc_VERSION
+ ""
+ PARENT_SCOPE)
+ endif()
+endfunction(extract_utf8proc_version)
+
+if(ARROW_UTF8PROC_USE_SHARED)
+ set(utf8proc_LIB_NAMES)
+ if(CMAKE_IMPORT_LIBRARY_SUFFIX)
+ list(APPEND utf8proc_LIB_NAMES
+ "${CMAKE_IMPORT_LIBRARY_PREFIX}utf8proc${CMAKE_IMPORT_LIBRARY_SUFFIX}")
+ endif()
+ list(APPEND utf8proc_LIB_NAMES
+ "${CMAKE_SHARED_LIBRARY_PREFIX}utf8proc${CMAKE_SHARED_LIBRARY_SUFFIX}")
+else()
+ if(MSVC AND NOT DEFINED utf8proc_MSVC_STATIC_LIB_SUFFIX)
+ set(utf8proc_MSVC_STATIC_LIB_SUFFIX "_static")
+ endif()
+ set(utf8proc_STATIC_LIB_SUFFIX
+ "${utf8proc_MSVC_STATIC_LIB_SUFFIX}${CMAKE_STATIC_LIBRARY_SUFFIX}")
+ set(utf8proc_LIB_NAMES
+ "${CMAKE_STATIC_LIBRARY_PREFIX}utf8proc${utf8proc_STATIC_LIB_SUFFIX}")
+endif()
+
+if(utf8proc_ROOT)
+ find_library(utf8proc_LIB
+ NAMES ${utf8proc_LIB_NAMES}
+ PATHS ${utf8proc_ROOT}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ find_path(utf8proc_INCLUDE_DIR
+ NAMES utf8proc.h
+ PATHS ${utf8proc_ROOT}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+ extract_utf8proc_version()
+else()
+ find_library(utf8proc_LIB
+ NAMES ${utf8proc_LIB_NAMES}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES})
+ find_path(utf8proc_INCLUDE_DIR
+ NAMES utf8proc.h
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+ extract_utf8proc_version()
+endif()
+
+find_package_handle_standard_args(
+ utf8proc
+ REQUIRED_VARS utf8proc_LIB utf8proc_INCLUDE_DIR
+ VERSION_VAR utf8proc_VERSION)
+
+if(utf8proc_FOUND)
+ set(utf8proc_FOUND TRUE)
+ add_library(utf8proc::utf8proc UNKNOWN IMPORTED)
+ set_target_properties(utf8proc::utf8proc
+ PROPERTIES IMPORTED_LOCATION "${utf8proc_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${utf8proc_INCLUDE_DIR}")
+ if(NOT ARROW_UTF8PROC_USE_SHARED)
+ set_target_properties(utf8proc::utf8proc PROPERTIES INTERFACE_COMPILER_DEFINITIONS
+ "UTF8PROC_STATIC")
+ endif()
+endif()
diff --git a/src/arrow/cpp/cmake_modules/Findzstd.cmake b/src/arrow/cpp/cmake_modules/Findzstd.cmake
new file mode 100644
index 000000000..3fc14ec0d
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/Findzstd.cmake
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(MSVC AND NOT DEFINED ZSTD_MSVC_LIB_PREFIX)
+ set(ZSTD_MSVC_LIB_PREFIX "lib")
+endif()
+set(ZSTD_LIB_NAME_BASE "${ZSTD_MSVC_LIB_PREFIX}zstd")
+
+if(ARROW_ZSTD_USE_SHARED)
+ set(ZSTD_LIB_NAMES)
+ if(CMAKE_IMPORT_LIBRARY_SUFFIX)
+ list(APPEND
+ ZSTD_LIB_NAMES
+ "${CMAKE_IMPORT_LIBRARY_PREFIX}${ZSTD_LIB_NAME_BASE}${CMAKE_IMPORT_LIBRARY_SUFFIX}"
+ )
+ endif()
+ list(APPEND ZSTD_LIB_NAMES
+ "${CMAKE_SHARED_LIBRARY_PREFIX}${ZSTD_LIB_NAME_BASE}${CMAKE_SHARED_LIBRARY_SUFFIX}"
+ )
+else()
+ if(MSVC AND NOT DEFINED ZSTD_MSVC_STATIC_LIB_SUFFIX)
+ set(ZSTD_MSVC_STATIC_LIB_SUFFIX "_static")
+ endif()
+ set(ZSTD_STATIC_LIB_SUFFIX
+ "${ZSTD_MSVC_STATIC_LIB_SUFFIX}${CMAKE_STATIC_LIBRARY_SUFFIX}")
+ set(ZSTD_LIB_NAMES
+ "${CMAKE_STATIC_LIBRARY_PREFIX}${ZSTD_LIB_NAME_BASE}${ZSTD_STATIC_LIB_SUFFIX}")
+endif()
+
+# First, find via if specified ZSTD_ROOT
+if(ZSTD_ROOT)
+ message(STATUS "Using ZSTD_ROOT: ${ZSTD_ROOT}")
+ find_library(ZSTD_LIB
+ NAMES ${ZSTD_LIB_NAMES}
+ PATHS ${ZSTD_ROOT}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES}
+ NO_DEFAULT_PATH)
+ find_path(ZSTD_INCLUDE_DIR
+ NAMES zstd.h
+ PATHS ${ZSTD_ROOT}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+
+else()
+ # Second, find via pkg_check_modules
+ find_package(PkgConfig QUIET)
+ pkg_check_modules(ZSTD_PC libzstd)
+ if(ZSTD_PC_FOUND)
+ set(ZSTD_INCLUDE_DIR "${ZSTD_PC_INCLUDEDIR}")
+
+ list(APPEND ZSTD_PC_LIBRARY_DIRS "${ZSTD_PC_LIBDIR}")
+ find_library(ZSTD_LIB
+ NAMES ${ZSTD_LIB_NAMES}
+ PATHS ${ZSTD_PC_LIBRARY_DIRS}
+ NO_DEFAULT_PATH
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES})
+ else()
+ # Third, check all other CMake paths
+ find_library(ZSTD_LIB
+ NAMES ${ZSTD_LIB_NAMES}
+ PATH_SUFFIXES ${ARROW_LIBRARY_PATH_SUFFIXES})
+ find_path(ZSTD_INCLUDE_DIR
+ NAMES zstd.h
+ PATH_SUFFIXES ${ARROW_INCLUDE_PATH_SUFFIXES})
+ endif()
+endif()
+
+find_package_handle_standard_args(zstd REQUIRED_VARS ZSTD_LIB ZSTD_INCLUDE_DIR)
+
+if(zstd_FOUND)
+ add_library(zstd::libzstd UNKNOWN IMPORTED)
+ set_target_properties(zstd::libzstd
+ PROPERTIES IMPORTED_LOCATION "${ZSTD_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${ZSTD_INCLUDE_DIR}")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/SetupCxxFlags.cmake b/src/arrow/cpp/cmake_modules/SetupCxxFlags.cmake
new file mode 100644
index 000000000..c1a1ba043
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/SetupCxxFlags.cmake
@@ -0,0 +1,648 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Check if the target architecture and compiler supports some special
+# instruction sets that would boost performance.
+include(CheckCXXCompilerFlag)
+include(CheckCXXSourceCompiles)
+# Get cpu architecture
+
+message(STATUS "System processor: ${CMAKE_SYSTEM_PROCESSOR}")
+
+if(NOT DEFINED ARROW_CPU_FLAG)
+ if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|ARM64|arm64")
+ set(ARROW_CPU_FLAG "armv8")
+ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "armv7")
+ set(ARROW_CPU_FLAG "armv7")
+ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "powerpc|ppc")
+ set(ARROW_CPU_FLAG "ppc")
+ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "s390x")
+ set(ARROW_CPU_FLAG "s390x")
+ else()
+ set(ARROW_CPU_FLAG "x86")
+ endif()
+endif()
+
+# Check architecture specific compiler flags
+if(ARROW_CPU_FLAG STREQUAL "x86")
+ # x86/amd64 compiler flags, msvc/gcc/clang
+ if(MSVC)
+ set(ARROW_SSE4_2_FLAG "")
+ set(ARROW_AVX2_FLAG "/arch:AVX2")
+ set(ARROW_AVX512_FLAG "/arch:AVX512")
+ set(CXX_SUPPORTS_SSE4_2 TRUE)
+ else()
+ set(ARROW_SSE4_2_FLAG "-msse4.2")
+ set(ARROW_AVX2_FLAG "-march=haswell")
+ # skylake-avx512 consists of AVX512F,AVX512BW,AVX512VL,AVX512CD,AVX512DQ
+ set(ARROW_AVX512_FLAG "-march=skylake-avx512 -mbmi2")
+ # Append the avx2/avx512 subset option also, fix issue ARROW-9877 for homebrew-cpp
+ set(ARROW_AVX2_FLAG "${ARROW_AVX2_FLAG} -mavx2")
+ set(ARROW_AVX512_FLAG
+ "${ARROW_AVX512_FLAG} -mavx512f -mavx512cd -mavx512vl -mavx512dq -mavx512bw")
+ check_cxx_compiler_flag(${ARROW_SSE4_2_FLAG} CXX_SUPPORTS_SSE4_2)
+ endif()
+ check_cxx_compiler_flag(${ARROW_AVX2_FLAG} CXX_SUPPORTS_AVX2)
+ if(MINGW)
+ # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65782
+ message(STATUS "Disable AVX512 support on MINGW for now")
+ else()
+ # Check for AVX512 support in the compiler.
+ set(OLD_CMAKE_REQURED_FLAGS ${CMAKE_REQUIRED_FLAGS})
+ set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} ${ARROW_AVX512_FLAG}")
+ check_cxx_source_compiles("
+ #ifdef _MSC_VER
+ #include <intrin.h>
+ #else
+ #include <immintrin.h>
+ #endif
+
+ int main() {
+ __m512i mask = _mm512_set1_epi32(0x1);
+ char out[32];
+ _mm512_storeu_si512(out, mask);
+ return 0;
+ }"
+ CXX_SUPPORTS_AVX512)
+ set(CMAKE_REQUIRED_FLAGS ${OLD_CMAKE_REQURED_FLAGS})
+ endif()
+ # Runtime SIMD level it can get from compiler and ARROW_RUNTIME_SIMD_LEVEL
+ if(CXX_SUPPORTS_SSE4_2 AND ARROW_RUNTIME_SIMD_LEVEL MATCHES
+ "^(SSE4_2|AVX2|AVX512|MAX)$")
+ set(ARROW_HAVE_RUNTIME_SSE4_2 ON)
+ add_definitions(-DARROW_HAVE_RUNTIME_SSE4_2)
+ endif()
+ if(CXX_SUPPORTS_AVX2 AND ARROW_RUNTIME_SIMD_LEVEL MATCHES "^(AVX2|AVX512|MAX)$")
+ set(ARROW_HAVE_RUNTIME_AVX2 ON)
+ add_definitions(-DARROW_HAVE_RUNTIME_AVX2 -DARROW_HAVE_RUNTIME_BMI2)
+ endif()
+ if(CXX_SUPPORTS_AVX512 AND ARROW_RUNTIME_SIMD_LEVEL MATCHES "^(AVX512|MAX)$")
+ set(ARROW_HAVE_RUNTIME_AVX512 ON)
+ add_definitions(-DARROW_HAVE_RUNTIME_AVX512 -DARROW_HAVE_RUNTIME_BMI2)
+ endif()
+ if(ARROW_SIMD_LEVEL STREQUAL "DEFAULT")
+ set(ARROW_SIMD_LEVEL "SSE4_2")
+ endif()
+elseif(ARROW_CPU_FLAG STREQUAL "ppc")
+ # power compiler flags, gcc/clang only
+ set(ARROW_ALTIVEC_FLAG "-maltivec")
+ check_cxx_compiler_flag(${ARROW_ALTIVEC_FLAG} CXX_SUPPORTS_ALTIVEC)
+ if(ARROW_SIMD_LEVEL STREQUAL "DEFAULT")
+ set(ARROW_SIMD_LEVEL "NONE")
+ endif()
+elseif(ARROW_CPU_FLAG STREQUAL "armv8")
+ # Arm64 compiler flags, gcc/clang only
+ set(ARROW_ARMV8_ARCH_FLAG "-march=${ARROW_ARMV8_ARCH}")
+ check_cxx_compiler_flag(${ARROW_ARMV8_ARCH_FLAG} CXX_SUPPORTS_ARMV8_ARCH)
+ if(ARROW_SIMD_LEVEL STREQUAL "DEFAULT")
+ set(ARROW_SIMD_LEVEL "NEON")
+ endif()
+endif()
+
+# Support C11
+if(NOT DEFINED CMAKE_C_STANDARD)
+ set(CMAKE_C_STANDARD 11)
+endif()
+
+# This ensures that things like c++11 get passed correctly
+if(NOT DEFINED CMAKE_CXX_STANDARD)
+ set(CMAKE_CXX_STANDARD 11)
+endif()
+
+# We require a C++11 compliant compiler
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+# ARROW-6848: Do not use GNU (or other CXX) extensions
+set(CMAKE_CXX_EXTENSIONS OFF)
+
+# Build with -fPIC so that can static link our libraries into other people's
+# shared libraries
+set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+
+string(TOUPPER ${CMAKE_BUILD_TYPE} CMAKE_BUILD_TYPE)
+
+set(UNKNOWN_COMPILER_MESSAGE
+ "Unknown compiler: ${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}")
+
+# compiler flags that are common across debug/release builds
+if(WIN32)
+ # TODO(wesm): Change usages of C runtime functions that MSVC says are
+ # insecure, like std::getenv
+ add_definitions(-D_CRT_SECURE_NO_WARNINGS)
+
+ if(MSVC)
+ if(MSVC_VERSION VERSION_LESS 19)
+ message(FATAL_ERROR "Only MSVC 2015 (Version 19.0) and later are supported
+ by Arrow. Found version ${CMAKE_CXX_COMPILER_VERSION}.")
+ endif()
+
+ # ARROW-1931 See https://github.com/google/googletest/issues/1318
+ #
+ # This is added to CMAKE_CXX_FLAGS instead of CXX_COMMON_FLAGS since only the
+ # former is passed into the external projects
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /D_SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING")
+
+ if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
+ # clang-cl
+ set(CXX_COMMON_FLAGS "-EHsc")
+ else()
+ # Fix annoying D9025 warning
+ string(REPLACE "/W3" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
+
+ # Set desired warning level (e.g. set /W4 for more warnings)
+ #
+ # ARROW-2986: Without /EHsc we get C4530 warning
+ set(CXX_COMMON_FLAGS "/W3 /EHsc")
+ endif()
+
+ # Disable C5105 (macro expansion producing 'defined' has undefined
+ # behavior) warning because there are codes that produce this
+ # warning in Windows Kits. e.g.:
+ #
+ # #define _CRT_INTERNAL_NONSTDC_NAMES \
+ # ( \
+ # ( defined _CRT_DECLARE_NONSTDC_NAMES && _CRT_DECLARE_NONSTDC_NAMES) || \
+ # (!defined _CRT_DECLARE_NONSTDC_NAMES && !__STDC__ ) \
+ # )
+ #
+ # See also:
+ # * C5105: https://docs.microsoft.com/en-US/cpp/error-messages/compiler-warnings/c5105
+ # * Related reports:
+ # * https://developercommunity.visualstudio.com/content/problem/387684/c5105-with-stdioh-and-experimentalpreprocessor.html
+ # * https://developercommunity.visualstudio.com/content/problem/1249671/stdc17-generates-warning-compiling-windowsh.html
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /wd5105")
+
+ if(ARROW_USE_STATIC_CRT)
+ foreach(c_flag
+ CMAKE_CXX_FLAGS
+ CMAKE_CXX_FLAGS_RELEASE
+ CMAKE_CXX_FLAGS_DEBUG
+ CMAKE_CXX_FLAGS_MINSIZEREL
+ CMAKE_CXX_FLAGS_RELWITHDEBINFO
+ CMAKE_C_FLAGS
+ CMAKE_C_FLAGS_RELEASE
+ CMAKE_C_FLAGS_DEBUG
+ CMAKE_C_FLAGS_MINSIZEREL
+ CMAKE_C_FLAGS_RELWITHDEBINFO)
+ string(REPLACE "/MD" "-MT" ${c_flag} "${${c_flag}}")
+ endforeach()
+ endif()
+
+ # Support large object code
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /bigobj")
+
+ # We may use UTF-8 in source code such as
+ # cpp/src/arrow/compute/kernels/scalar_string_test.cc
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /utf-8")
+ else()
+ # MinGW
+ check_cxx_compiler_flag(-Wa,-mbig-obj CXX_SUPPORTS_BIG_OBJ)
+ if(CXX_SUPPORTS_BIG_OBJ)
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wa,-mbig-obj")
+ endif()
+ endif(MSVC)
+else()
+ # Common flags set below with warning level
+ set(CXX_COMMON_FLAGS "")
+endif()
+
+# BUILD_WARNING_LEVEL add warning/error compiler flags. The possible values are
+# - PRODUCTION: Build with `-Wall` but do not add `-Werror`, so warnings do not
+# halt the build.
+# - CHECKIN: Build with `-Wall` and `-Wextra`. Also, add `-Werror` in debug mode
+# so that any important warnings fail the build.
+# - EVERYTHING: Like `CHECKIN`, but possible extra flags depending on the
+# compiler, including `-Wextra`, `-Weverything`, `-pedantic`.
+# This is the most aggressive warning level.
+
+# Defaults BUILD_WARNING_LEVEL to `CHECKIN`, unless CMAKE_BUILD_TYPE is
+# `RELEASE`, then it will default to `PRODUCTION`. The goal of defaulting to
+# `CHECKIN` is to avoid friction with long response time from CI.
+if(NOT BUILD_WARNING_LEVEL)
+ if("${CMAKE_BUILD_TYPE}" STREQUAL "RELEASE")
+ set(BUILD_WARNING_LEVEL PRODUCTION)
+ else()
+ set(BUILD_WARNING_LEVEL CHECKIN)
+ endif()
+endif(NOT BUILD_WARNING_LEVEL)
+string(TOUPPER ${BUILD_WARNING_LEVEL} BUILD_WARNING_LEVEL)
+
+message(STATUS "Arrow build warning level: ${BUILD_WARNING_LEVEL}")
+
+macro(arrow_add_werror_if_debug)
+ if("${CMAKE_BUILD_TYPE}" STREQUAL "DEBUG")
+ # Treat all compiler warnings as errors
+ if(MSVC)
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /WX")
+ else()
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Werror")
+ endif()
+ endif()
+endmacro()
+
+if("${BUILD_WARNING_LEVEL}" STREQUAL "CHECKIN")
+ # Pre-checkin builds
+ if(MSVC)
+ # https://docs.microsoft.com/en-us/cpp/error-messages/compiler-warnings/compiler-warnings-by-compiler-version
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /W3")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /wd4365")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /wd4267")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /wd4838")
+ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL
+ "Clang")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wall")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wextra")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wdocumentation")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-missing-braces")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-unused-parameter")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-constant-logical-operand")
+ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wall")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-conversion")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-deprecated-declarations")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-sign-conversion")
+ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel")
+ if(WIN32)
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /Wall")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /Wno-deprecated")
+ else()
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wall")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-deprecated")
+ endif()
+ else()
+ message(FATAL_ERROR "${UNKNOWN_COMPILER_MESSAGE}")
+ endif()
+ arrow_add_werror_if_debug()
+
+elseif("${BUILD_WARNING_LEVEL}" STREQUAL "EVERYTHING")
+ # Pedantic builds for fixing warnings
+ if(MSVC)
+ string(REPLACE "/W3" "" CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS}")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /Wall")
+ # https://docs.microsoft.com/en-us/cpp/build/reference/compiler-option-warning-level
+ # /wdnnnn disables a warning where "nnnn" is a warning number
+ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL
+ "Clang")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Weverything")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-c++98-compat")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-c++98-compat-pedantic")
+ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wall")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wpedantic")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wextra")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-unused-parameter")
+ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel")
+ if(WIN32)
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /Wall")
+ else()
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wall")
+ endif()
+ else()
+ message(FATAL_ERROR "${UNKNOWN_COMPILER_MESSAGE}")
+ endif()
+ arrow_add_werror_if_debug()
+
+else()
+ # Production builds (warning are not treated as errors)
+ if(MSVC)
+ # https://docs.microsoft.com/en-us/cpp/build/reference/compiler-option-warning-level
+ # TODO: Enable /Wall and disable individual warnings until build compiles without errors
+ # /wdnnnn disables a warning where "nnnn" is a warning number
+ string(REPLACE "/W3" "" CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS}")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /W3")
+ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
+ OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
+ OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wall")
+ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel")
+ if(WIN32)
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /Wall")
+ else()
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wall")
+ endif()
+ else()
+ message(FATAL_ERROR "${UNKNOWN_COMPILER_MESSAGE}")
+ endif()
+
+endif()
+
+if(MSVC)
+ # Disable annoying "performance warning" about int-to-bool conversion
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /wd4800")
+
+ # Disable unchecked iterator warnings, equivalent to /D_SCL_SECURE_NO_WARNINGS
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /wd4996")
+
+ # Disable "switch statement contains 'default' but no 'case' labels" warning
+ # (required for protobuf, see https://github.com/protocolbuffers/protobuf/issues/6885)
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} /wd4065")
+
+elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+ if(CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL "7.0" OR CMAKE_CXX_COMPILER_VERSION
+ VERSION_GREATER "7.0")
+ # Without this, gcc >= 7 warns related to changes in C++17
+ set(CXX_ONLY_FLAGS "${CXX_ONLY_FLAGS} -Wno-noexcept-type")
+ endif()
+
+ if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "5.2")
+ # Disabling semantic interposition allows faster calling conventions
+ # when calling global functions internally, and can also help inlining.
+ # See https://stackoverflow.com/questions/35745543/new-option-in-gcc-5-3-fno-semantic-interposition
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -fno-semantic-interposition")
+ endif()
+
+ if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.9")
+ # Add colors when paired with ninja
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fdiagnostics-color=always")
+ endif()
+
+ if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "6.0")
+ # Work around https://gcc.gnu.org/bugzilla/show_bug.cgi?id=43407
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-attributes")
+ endif()
+
+ if(CMAKE_UNITY_BUILD)
+ # Work around issue similar to https://bugs.webkit.org/show_bug.cgi?id=176869
+ set(CXX_ONLY_FLAGS "${CXX_ONLY_FLAGS} -Wno-subobject-linkage")
+ endif()
+
+elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL
+ "Clang")
+ # Clang options for all builds
+
+ # Using Clang with ccache causes a bunch of spurious warnings that are
+ # purportedly fixed in the next version of ccache. See the following for details:
+ #
+ # http://petereisentraut.blogspot.com/2011/05/ccache-and-clang.html
+ # http://petereisentraut.blogspot.com/2011/09/ccache-and-clang-part-2.html
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Qunused-arguments")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Qunused-arguments")
+
+ # Avoid error when an unknown warning flag is passed
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-unknown-warning-option")
+ # Add colors when paired with ninja
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fcolor-diagnostics")
+
+ # Don't complain about optimization passes that were not possible
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-pass-failed")
+
+ if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
+ # Depending on the default OSX_DEPLOYMENT_TARGET (< 10.9), libstdc++ may be
+ # the default standard library which does not support C++11. libc++ is the
+ # default from 10.9 onward.
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -stdlib=libc++")
+ endif()
+endif()
+
+# if build warning flags is set, add to CXX_COMMON_FLAGS
+if(BUILD_WARNING_FLAGS)
+ # Use BUILD_WARNING_FLAGS with BUILD_WARNING_LEVEL=everything to disable
+ # warnings (use with Clang's -Weverything flag to find potential errors)
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} ${BUILD_WARNING_FLAGS}")
+endif(BUILD_WARNING_FLAGS)
+
+# Only enable additional instruction sets if they are supported
+if(ARROW_CPU_FLAG STREQUAL "x86")
+ if(MINGW)
+ # Enable _xgetbv() intrinsic to query OS support for ZMM register saves
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -mxsave")
+ endif()
+ if(ARROW_SIMD_LEVEL STREQUAL "AVX512")
+ if(NOT CXX_SUPPORTS_AVX512)
+ message(FATAL_ERROR "AVX512 required but compiler doesn't support it.")
+ endif()
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} ${ARROW_AVX512_FLAG}")
+ add_definitions(-DARROW_HAVE_AVX512 -DARROW_HAVE_AVX2 -DARROW_HAVE_BMI2
+ -DARROW_HAVE_SSE4_2)
+ elseif(ARROW_SIMD_LEVEL STREQUAL "AVX2")
+ if(NOT CXX_SUPPORTS_AVX2)
+ message(FATAL_ERROR "AVX2 required but compiler doesn't support it.")
+ endif()
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} ${ARROW_AVX2_FLAG}")
+ add_definitions(-DARROW_HAVE_AVX2 -DARROW_HAVE_BMI2 -DARROW_HAVE_SSE4_2)
+ elseif(ARROW_SIMD_LEVEL STREQUAL "SSE4_2")
+ if(NOT CXX_SUPPORTS_SSE4_2)
+ message(FATAL_ERROR "SSE4.2 required but compiler doesn't support it.")
+ endif()
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} ${ARROW_SSE4_2_FLAG}")
+ add_definitions(-DARROW_HAVE_SSE4_2)
+ elseif(NOT ARROW_SIMD_LEVEL STREQUAL "NONE")
+ message(WARNING "ARROW_SIMD_LEVEL=${ARROW_SIMD_LEVEL} not supported by x86.")
+ endif()
+endif()
+
+if(ARROW_CPU_FLAG STREQUAL "ppc")
+ if(CXX_SUPPORTS_ALTIVEC AND ARROW_ALTIVEC)
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} ${ARROW_ALTIVEC_FLAG}")
+ endif()
+endif()
+
+if(ARROW_CPU_FLAG STREQUAL "armv8")
+ if(ARROW_SIMD_LEVEL STREQUAL "NEON")
+ set(ARROW_HAVE_NEON ON)
+
+ if(NOT CXX_SUPPORTS_ARMV8_ARCH)
+ message(FATAL_ERROR "Unsupported arch flag: ${ARROW_ARMV8_ARCH_FLAG}.")
+ endif()
+ if(ARROW_ARMV8_ARCH_FLAG MATCHES "native")
+ message(FATAL_ERROR "native arch not allowed, please specify arch explicitly.")
+ endif()
+ set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} ${ARROW_ARMV8_ARCH_FLAG}")
+
+ add_definitions(-DARROW_HAVE_NEON)
+
+ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS
+ "5.4")
+ message(WARNING "Disable Armv8 CRC and Crypto as compiler doesn't support them well."
+ )
+ else()
+ if(ARROW_ARMV8_ARCH_FLAG MATCHES "\\+crypto")
+ add_definitions(-DARROW_HAVE_ARMV8_CRYPTO)
+ endif()
+ # armv8.1+ implies crc support
+ if(ARROW_ARMV8_ARCH_FLAG MATCHES "armv8\\.[1-9]|\\+crc")
+ add_definitions(-DARROW_HAVE_ARMV8_CRC)
+ endif()
+ endif()
+ elseif(NOT ARROW_SIMD_LEVEL STREQUAL "NONE")
+ message(WARNING "ARROW_SIMD_LEVEL=${ARROW_SIMD_LEVEL} not supported by Arm.")
+ endif()
+endif()
+
+# ----------------------------------------------------------------------
+# Setup Gold linker, if available. Code originally from Apache Kudu
+
+# Interrogates the linker version via the C++ compiler to determine whether
+# we're using the gold linker, and if so, extracts its version.
+#
+# If the gold linker is being used, sets GOLD_VERSION in the parent scope with
+# the extracted version.
+#
+# Any additional arguments are passed verbatim into the C++ compiler invocation.
+function(GET_GOLD_VERSION)
+ # The gold linker is only for ELF binaries, which macOS doesn't use.
+ execute_process(COMMAND ${CMAKE_CXX_COMPILER} "-Wl,--version" ${ARGN}
+ ERROR_QUIET
+ OUTPUT_VARIABLE LINKER_OUTPUT)
+ # We're expecting LINKER_OUTPUT to look like one of these:
+ # GNU gold (version 2.24) 1.11
+ # GNU gold (GNU Binutils for Ubuntu 2.30) 1.15
+ if(LINKER_OUTPUT MATCHES "GNU gold")
+ string(REGEX MATCH "GNU gold \\([^\\)]*\\) (([0-9]+\\.?)+)" _ "${LINKER_OUTPUT}")
+ if(NOT CMAKE_MATCH_1)
+ message(SEND_ERROR "Could not extract GNU gold version. "
+ "Linker version output: ${LINKER_OUTPUT}")
+ endif()
+ set(GOLD_VERSION
+ "${CMAKE_MATCH_1}"
+ PARENT_SCOPE)
+ endif()
+endfunction()
+
+# Is the compiler hard-wired to use the gold linker?
+if(NOT WIN32 AND NOT APPLE)
+ get_gold_version()
+ if(GOLD_VERSION)
+ set(MUST_USE_GOLD 1)
+ elseif(ARROW_USE_LD_GOLD)
+ # Can the compiler optionally enable the gold linker?
+ get_gold_version("-fuse-ld=gold")
+
+ # We can't use the gold linker if it's inside devtoolset because the compiler
+ # won't find it when invoked directly from make/ninja (which is typically
+ # done outside devtoolset).
+ execute_process(COMMAND which ld.gold
+ OUTPUT_VARIABLE GOLD_LOCATION
+ OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET)
+ if("${GOLD_LOCATION}" MATCHES "^/opt/rh/devtoolset")
+ message("Skipping optional gold linker (version ${GOLD_VERSION}) because "
+ "it's in devtoolset")
+ set(GOLD_VERSION)
+ endif()
+ endif()
+
+ if(GOLD_VERSION)
+ # Older versions of the gold linker are vulnerable to a bug [1] which
+ # prevents weak symbols from being overridden properly. This leads to
+ # omitting of dependencies like tcmalloc (used in Kudu, where this
+ # workaround was written originally)
+ #
+ # How we handle this situation depends on other factors:
+ # - If gold is optional, we won't use it.
+ # - If gold is required, we'll either:
+ # - Raise an error in RELEASE builds (we shouldn't release such a product), or
+ # - Drop tcmalloc in all other builds.
+ #
+ # 1. https://sourceware.org/bugzilla/show_bug.cgi?id=16979.
+ if("${GOLD_VERSION}" VERSION_LESS "1.12")
+ set(ARROW_BUGGY_GOLD 1)
+ endif()
+ if(MUST_USE_GOLD)
+ message("Using hard-wired gold linker (version ${GOLD_VERSION})")
+ if(ARROW_BUGGY_GOLD)
+ if("${ARROW_LINK}" STREQUAL "d" AND "${CMAKE_BUILD_TYPE}" STREQUAL "RELEASE")
+ message(SEND_ERROR "Configured to use buggy gold with dynamic linking "
+ "in a RELEASE build")
+ endif()
+ endif()
+ elseif(NOT ARROW_BUGGY_GOLD)
+ # The Gold linker must be manually enabled.
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fuse-ld=gold")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fuse-ld=gold")
+ message("Using optional gold linker (version ${GOLD_VERSION})")
+ else()
+ message("Optional gold linker is buggy, using ld linker instead")
+ endif()
+ else()
+ message("Using ld linker")
+ endif()
+endif()
+
+# compiler flags for different build types (run 'cmake -DCMAKE_BUILD_TYPE=<type> .')
+# For all builds:
+# For CMAKE_BUILD_TYPE=Debug
+# -ggdb: Enable gdb debugging
+# For CMAKE_BUILD_TYPE=FastDebug
+# Same as DEBUG, except with some optimizations on.
+# For CMAKE_BUILD_TYPE=Release
+# -O3: Enable all compiler optimizations
+# Debug symbols are stripped for reduced binary size. Add
+# -DARROW_CXXFLAGS="-g" to add them
+if(NOT MSVC)
+ if(ARROW_GGDB_DEBUG)
+ set(ARROW_DEBUG_SYMBOL_TYPE "gdb")
+ set(C_FLAGS_DEBUG "-g${ARROW_DEBUG_SYMBOL_TYPE} -O0")
+ set(C_FLAGS_FASTDEBUG "-g${ARROW_DEBUG_SYMBOL_TYPE} -O1")
+ set(CXX_FLAGS_DEBUG "-g${ARROW_DEBUG_SYMBOL_TYPE} -O0")
+ set(CXX_FLAGS_FASTDEBUG "-g${ARROW_DEBUG_SYMBOL_TYPE} -O1")
+ else()
+ set(C_FLAGS_DEBUG "-g -O0")
+ set(C_FLAGS_FASTDEBUG "-g -O1")
+ set(CXX_FLAGS_DEBUG "-g -O0")
+ set(CXX_FLAGS_FASTDEBUG "-g -O1")
+ endif()
+
+ set(C_FLAGS_RELEASE "-O3 -DNDEBUG")
+ set(CXX_FLAGS_RELEASE "-O3 -DNDEBUG")
+endif()
+
+set(C_FLAGS_PROFILE_GEN "${CXX_FLAGS_RELEASE} -fprofile-generate")
+set(C_FLAGS_PROFILE_BUILD "${CXX_FLAGS_RELEASE} -fprofile-use")
+set(CXX_FLAGS_PROFILE_GEN "${CXX_FLAGS_RELEASE} -fprofile-generate")
+set(CXX_FLAGS_PROFILE_BUILD "${CXX_FLAGS_RELEASE} -fprofile-use")
+
+# Set compile flags based on the build type.
+message("Configured for ${CMAKE_BUILD_TYPE} build (set with cmake -DCMAKE_BUILD_TYPE={release,debug,...})"
+)
+if("${CMAKE_BUILD_TYPE}" STREQUAL "DEBUG")
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${C_FLAGS_DEBUG}")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX_FLAGS_DEBUG}")
+elseif("${CMAKE_BUILD_TYPE}" STREQUAL "RELWITHDEBINFO")
+
+elseif("${CMAKE_BUILD_TYPE}" STREQUAL "FASTDEBUG")
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${C_FLAGS_FASTDEBUG}")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX_FLAGS_FASTDEBUG}")
+elseif("${CMAKE_BUILD_TYPE}" STREQUAL "RELEASE")
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${C_FLAGS_RELEASE}")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX_FLAGS_RELEASE}")
+elseif("${CMAKE_BUILD_TYPE}" STREQUAL "PROFILE_GEN")
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${C_FLAGS_PROFILE_GEN}")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX_FLAGS_PROFILE_GEN}")
+elseif("${CMAKE_BUILD_TYPE}" STREQUAL "PROFILE_BUILD")
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${C_FLAGS_PROFILE_BUILD}")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX_FLAGS_PROFILE_BUILD}")
+else()
+ message(FATAL_ERROR "Unknown build type: ${CMAKE_BUILD_TYPE}")
+endif()
+
+message(STATUS "Build Type: ${CMAKE_BUILD_TYPE}")
+
+# ----------------------------------------------------------------------
+# MSVC-specific linker options
+
+if(MSVC)
+ set(MSVC_LINKER_FLAGS)
+ if(MSVC_LINK_VERBOSE)
+ set(MSVC_LINKER_FLAGS "${MSVC_LINKER_FLAGS} /VERBOSE:LIB")
+ endif()
+ if(NOT ARROW_USE_STATIC_CRT)
+ set(MSVC_LINKER_FLAGS "${MSVC_LINKER_FLAGS} /NODEFAULTLIB:LIBCMT")
+ set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${MSVC_LINKER_FLAGS}")
+ set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} ${MSVC_LINKER_FLAGS}")
+ set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${MSVC_LINKER_FLAGS}")
+ endif()
+endif()
diff --git a/src/arrow/cpp/cmake_modules/ThirdpartyToolchain.cmake b/src/arrow/cpp/cmake_modules/ThirdpartyToolchain.cmake
new file mode 100644
index 000000000..a793f3046
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/ThirdpartyToolchain.cmake
@@ -0,0 +1,4063 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+include(ProcessorCount)
+processorcount(NPROC)
+
+add_custom_target(rapidjson)
+add_custom_target(toolchain)
+add_custom_target(toolchain-benchmarks)
+add_custom_target(toolchain-tests)
+
+# Accumulate all bundled targets and we will splice them together later as
+# libarrow_dependencies.a so that third party libraries have something usable
+# to create statically-linked builds with some BUNDLED dependencies, including
+# allocators like jemalloc and mimalloc
+set(ARROW_BUNDLED_STATIC_LIBS)
+
+# Accumulate all system dependencies to provide suitable static link
+# parameters to the third party libraries.
+set(ARROW_SYSTEM_DEPENDENCIES)
+
+# ----------------------------------------------------------------------
+# Toolchain linkage options
+
+set(ARROW_RE2_LINKAGE
+ "static"
+ CACHE STRING "How to link the re2 library. static|shared (default static)")
+
+if(ARROW_PROTOBUF_USE_SHARED)
+ set(Protobuf_USE_STATIC_LIBS OFF)
+else()
+ set(Protobuf_USE_STATIC_LIBS ON)
+endif()
+
+# ----------------------------------------------------------------------
+# Resolve the dependencies
+
+set(ARROW_THIRDPARTY_DEPENDENCIES
+ AWSSDK
+ benchmark
+ Boost
+ Brotli
+ BZip2
+ c-ares
+ gflags
+ GLOG
+ google_cloud_cpp_storage
+ gRPC
+ GTest
+ LLVM
+ Lz4
+ ORC
+ re2
+ Protobuf
+ RapidJSON
+ Snappy
+ Thrift
+ utf8proc
+ xsimd
+ ZLIB
+ zstd)
+
+# TODO(wesm): External GTest shared libraries are not currently
+# supported when building with MSVC because of the way that
+# conda-forge packages have 4 variants of the libraries packaged
+# together
+if(MSVC AND "${GTest_SOURCE}" STREQUAL "")
+ set(GTest_SOURCE "BUNDLED")
+endif()
+
+# For backward compatibility. We use "BOOST_SOURCE" if "Boost_SOURCE"
+# isn't specified and "BOOST_SOURCE" is specified.
+# We renamed "BOOST" dependency name to "Boost" in 3.0.0 because
+# upstreams (CMake and Boost) use "Boost" not "BOOST" as package name.
+if("${Boost_SOURCE}" STREQUAL "" AND NOT "${BOOST_SOURCE}" STREQUAL "")
+ set(Boost_SOURCE ${BOOST_SOURCE})
+endif()
+
+# For backward compatibility. We use "RE2_SOURCE" if "re2_SOURCE"
+# isn't specified and "RE2_SOURCE" is specified.
+# We renamed "RE2" dependency name to "re2" in 3.0.0 because
+# upstream uses "re2" not "RE2" as package name.
+if("${re2_SOURCE}" STREQUAL "" AND NOT "${RE2_SOURCE}" STREQUAL "")
+ set(re2_SOURCE ${RE2_SOURCE})
+endif()
+
+message(STATUS "Using ${ARROW_DEPENDENCY_SOURCE} approach to find dependencies")
+
+if(ARROW_DEPENDENCY_SOURCE STREQUAL "CONDA")
+ if(MSVC)
+ set(ARROW_PACKAGE_PREFIX "$ENV{CONDA_PREFIX}/Library")
+ else()
+ set(ARROW_PACKAGE_PREFIX $ENV{CONDA_PREFIX})
+ endif()
+ set(ARROW_ACTUAL_DEPENDENCY_SOURCE "SYSTEM")
+ message(STATUS "Using CONDA_PREFIX for ARROW_PACKAGE_PREFIX: ${ARROW_PACKAGE_PREFIX}")
+else()
+ set(ARROW_ACTUAL_DEPENDENCY_SOURCE "${ARROW_DEPENDENCY_SOURCE}")
+endif()
+
+if(ARROW_PACKAGE_PREFIX)
+ message(STATUS "Setting (unset) dependency *_ROOT variables: ${ARROW_PACKAGE_PREFIX}")
+ set(ENV{PKG_CONFIG_PATH} "${ARROW_PACKAGE_PREFIX}/lib/pkgconfig/")
+
+ if(NOT ENV{BOOST_ROOT})
+ set(ENV{BOOST_ROOT} ${ARROW_PACKAGE_PREFIX})
+ endif()
+ if(NOT ENV{Boost_ROOT})
+ set(ENV{Boost_ROOT} ${ARROW_PACKAGE_PREFIX})
+ endif()
+endif()
+
+# For each dependency, set dependency source to global default, if unset
+foreach(DEPENDENCY ${ARROW_THIRDPARTY_DEPENDENCIES})
+ if("${${DEPENDENCY}_SOURCE}" STREQUAL "")
+ set(${DEPENDENCY}_SOURCE ${ARROW_ACTUAL_DEPENDENCY_SOURCE})
+ # If no ROOT was supplied and we have a global prefix, use it
+ if(NOT ${DEPENDENCY}_ROOT AND ARROW_PACKAGE_PREFIX)
+ set(${DEPENDENCY}_ROOT ${ARROW_PACKAGE_PREFIX})
+ endif()
+ endif()
+endforeach()
+
+macro(build_dependency DEPENDENCY_NAME)
+ if("${DEPENDENCY_NAME}" STREQUAL "AWSSDK")
+ build_awssdk()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "benchmark")
+ build_benchmark()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "Boost")
+ build_boost()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "Brotli")
+ build_brotli()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "BZip2")
+ build_bzip2()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "c-ares")
+ build_cares()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "gflags")
+ build_gflags()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "GLOG")
+ build_glog()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "google_cloud_cpp_storage")
+ build_google_cloud_cpp_storage()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "gRPC")
+ build_grpc()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "GTest")
+ build_gtest()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "Lz4")
+ build_lz4()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "ORC")
+ build_orc()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "Protobuf")
+ build_protobuf()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "RapidJSON")
+ build_rapidjson()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "re2")
+ build_re2()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "Snappy")
+ build_snappy()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "Thrift")
+ build_thrift()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "utf8proc")
+ build_utf8proc()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "xsimd")
+ build_xsimd()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "ZLIB")
+ build_zlib()
+ elseif("${DEPENDENCY_NAME}" STREQUAL "zstd")
+ build_zstd()
+ else()
+ message(FATAL_ERROR "Unknown thirdparty dependency to build: ${DEPENDENCY_NAME}")
+ endif()
+endmacro()
+
+# Find modules are needed by the consumer in case of a static build, or if the
+# linkage is PUBLIC or INTERFACE.
+macro(provide_find_module PACKAGE_NAME)
+ set(module_ "${CMAKE_SOURCE_DIR}/cmake_modules/Find${PACKAGE_NAME}.cmake")
+ if(EXISTS "${module_}")
+ message(STATUS "Providing CMake module for ${PACKAGE_NAME}")
+ install(FILES "${module_}" DESTINATION "${ARROW_CMAKE_INSTALL_DIR}")
+ endif()
+ unset(module_)
+endmacro()
+
+macro(resolve_dependency DEPENDENCY_NAME)
+ set(options)
+ set(one_value_args HAVE_ALT IS_RUNTIME_DEPENDENCY REQUIRED_VERSION USE_CONFIG)
+ set(multi_value_args PC_PACKAGE_NAMES)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+ if(ARG_UNPARSED_ARGUMENTS)
+ message(SEND_ERROR "Error: unrecognized arguments: ${ARG_UNPARSED_ARGUMENTS}")
+ endif()
+ if("${ARG_IS_RUNTIME_DEPENDENCY}" STREQUAL "")
+ set(ARG_IS_RUNTIME_DEPENDENCY TRUE)
+ endif()
+
+ if(ARG_HAVE_ALT)
+ set(PACKAGE_NAME "${DEPENDENCY_NAME}Alt")
+ else()
+ set(PACKAGE_NAME ${DEPENDENCY_NAME})
+ endif()
+ set(FIND_PACKAGE_ARGUMENTS ${PACKAGE_NAME})
+ if(ARG_REQUIRED_VERSION)
+ list(APPEND FIND_PACKAGE_ARGUMENTS ${ARG_REQUIRED_VERSION})
+ endif()
+ if(ARG_USE_CONFIG)
+ list(APPEND FIND_PACKAGE_ARGUMENTS CONFIG)
+ endif()
+ if(${DEPENDENCY_NAME}_SOURCE STREQUAL "AUTO")
+ find_package(${FIND_PACKAGE_ARGUMENTS})
+ if(${${PACKAGE_NAME}_FOUND})
+ set(${DEPENDENCY_NAME}_SOURCE "SYSTEM")
+ else()
+ build_dependency(${DEPENDENCY_NAME})
+ set(${DEPENDENCY_NAME}_SOURCE "BUNDLED")
+ endif()
+ elseif(${DEPENDENCY_NAME}_SOURCE STREQUAL "BUNDLED")
+ build_dependency(${DEPENDENCY_NAME})
+ elseif(${DEPENDENCY_NAME}_SOURCE STREQUAL "SYSTEM")
+ find_package(${FIND_PACKAGE_ARGUMENTS} REQUIRED)
+ endif()
+ if(${DEPENDENCY_NAME}_SOURCE STREQUAL "SYSTEM" AND ARG_IS_RUNTIME_DEPENDENCY)
+ provide_find_module(${PACKAGE_NAME})
+ list(APPEND ARROW_SYSTEM_DEPENDENCIES ${PACKAGE_NAME})
+ find_package(PkgConfig QUIET)
+ foreach(ARG_PC_PACKAGE_NAME ${ARG_PC_PACKAGE_NAMES})
+ pkg_check_modules(${ARG_PC_PACKAGE_NAME}_PC
+ ${ARG_PC_PACKAGE_NAME}
+ NO_CMAKE_PATH
+ NO_CMAKE_ENVIRONMENT_PATH
+ QUIET)
+ if(${${ARG_PC_PACKAGE_NAME}_PC_FOUND})
+ string(APPEND ARROW_PC_REQUIRES_PRIVATE " ${ARG_PC_PACKAGE_NAME}")
+ endif()
+ endforeach()
+ endif()
+endmacro()
+
+# ----------------------------------------------------------------------
+# Thirdparty versions, environment variables, source URLs
+
+set(THIRDPARTY_DIR "${arrow_SOURCE_DIR}/thirdparty")
+
+# Include vendored Flatbuffers
+include_directories(SYSTEM "${THIRDPARTY_DIR}/flatbuffers/include")
+
+# ----------------------------------------------------------------------
+# Some EP's require other EP's
+
+if(PARQUET_REQUIRE_ENCRYPTION)
+ set(ARROW_JSON ON)
+endif()
+
+if(ARROW_THRIFT)
+ set(ARROW_WITH_ZLIB ON)
+endif()
+
+if(ARROW_HIVESERVER2 OR ARROW_PARQUET)
+ set(ARROW_WITH_THRIFT ON)
+ if(ARROW_HIVESERVER2)
+ set(ARROW_THRIFT_REQUIRED_COMPONENTS COMPILER)
+ else()
+ set(ARROW_THRIFT_REQUIRED_COMPONENTS)
+ endif()
+else()
+ set(ARROW_WITH_THRIFT OFF)
+endif()
+
+if(ARROW_FLIGHT)
+ set(ARROW_WITH_GRPC ON)
+ # gRPC requires zlib
+ set(ARROW_WITH_ZLIB ON)
+endif()
+
+if(ARROW_GCS)
+ set(ARROW_WITH_GOOGLE_CLOUD_CPP ON)
+endif()
+
+if(ARROW_JSON)
+ set(ARROW_WITH_RAPIDJSON ON)
+endif()
+
+if(ARROW_ORC
+ OR ARROW_FLIGHT
+ OR ARROW_GANDIVA)
+ set(ARROW_WITH_PROTOBUF ON)
+endif()
+
+if(ARROW_S3)
+ set(ARROW_WITH_ZLIB ON)
+endif()
+
+if((NOT ARROW_COMPUTE) AND (NOT ARROW_GANDIVA))
+ set(ARROW_WITH_UTF8PROC OFF)
+endif()
+
+if((NOT ARROW_COMPUTE)
+ AND (NOT ARROW_GANDIVA)
+ AND (NOT ARROW_WITH_GRPC))
+ set(ARROW_WITH_RE2 OFF)
+endif()
+
+# ----------------------------------------------------------------------
+# Versions and URLs for toolchain builds, which also can be used to configure
+# offline builds
+# Note: We should not use the Apache dist server for build dependencies
+
+macro(set_urls URLS)
+ set(${URLS} ${ARGN})
+ if(CMAKE_VERSION VERSION_LESS 3.7)
+ # ExternalProject doesn't support backup URLs;
+ # Feature only available starting in 3.7
+ list(GET ${URLS} 0 ${URLS})
+ endif()
+endmacro()
+
+# Read toolchain versions from cpp/thirdparty/versions.txt
+file(STRINGS "${THIRDPARTY_DIR}/versions.txt" TOOLCHAIN_VERSIONS_TXT)
+foreach(_VERSION_ENTRY ${TOOLCHAIN_VERSIONS_TXT})
+ # Exclude comments
+ if(NOT ((_VERSION_ENTRY MATCHES "^[^#][A-Za-z0-9-_]+_VERSION=")
+ OR (_VERSION_ENTRY MATCHES "^[^#][A-Za-z0-9-_]+_CHECKSUM=")))
+ continue()
+ endif()
+
+ string(REGEX MATCH "^[^=]*" _VARIABLE_NAME ${_VERSION_ENTRY})
+ string(REPLACE "${_VARIABLE_NAME}=" "" _VARIABLE_VALUE ${_VERSION_ENTRY})
+
+ # Skip blank or malformed lines
+ if(_VARIABLE_VALUE STREQUAL "")
+ continue()
+ endif()
+
+ # For debugging
+ message(STATUS "${_VARIABLE_NAME}: ${_VARIABLE_VALUE}")
+
+ set(${_VARIABLE_NAME} ${_VARIABLE_VALUE})
+endforeach()
+
+if(DEFINED ENV{ARROW_ABSL_URL})
+ set(ABSL_SOURCE_URL "$ENV{ARROW_ABSL_URL}")
+else()
+ set_urls(ABSL_SOURCE_URL
+ "https://github.com/abseil/abseil-cpp/archive/${ARROW_ABSL_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_AWS_C_COMMON_URL})
+ set(AWS_C_COMMON_SOURCE_URL "$ENV{ARROW_AWS_C_COMMON_URL}")
+else()
+ set_urls(AWS_C_COMMON_SOURCE_URL
+ "https://github.com/awslabs/aws-c-common/archive/${ARROW_AWS_C_COMMON_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_AWS_CHECKSUMS_URL})
+ set(AWS_CHECKSUMS_SOURCE_URL "$ENV{ARROW_AWS_CHECKSUMS_URL}")
+else()
+ set_urls(AWS_CHECKSUMS_SOURCE_URL
+ "https://github.com/awslabs/aws-checksums/archive/${ARROW_AWS_CHECKSUMS_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_AWS_C_EVENT_STREAM_URL})
+ set(AWS_C_EVENT_STREAM_SOURCE_URL "$ENV{ARROW_AWS_C_EVENT_STREAM_URL}")
+else()
+ set_urls(AWS_C_EVENT_STREAM_SOURCE_URL
+ "https://github.com/awslabs/aws-c-event-stream/archive/${ARROW_AWS_C_EVENT_STREAM_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_AWSSDK_URL})
+ set(AWSSDK_SOURCE_URL "$ENV{ARROW_AWSSDK_URL}")
+else()
+ set_urls(AWSSDK_SOURCE_URL
+ "https://github.com/aws/aws-sdk-cpp/archive/${ARROW_AWSSDK_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/aws-sdk-cpp-${ARROW_AWSSDK_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_BOOST_URL})
+ set(BOOST_SOURCE_URL "$ENV{ARROW_BOOST_URL}")
+else()
+ string(REPLACE "." "_" ARROW_BOOST_BUILD_VERSION_UNDERSCORES
+ ${ARROW_BOOST_BUILD_VERSION})
+ set_urls(BOOST_SOURCE_URL
+ # These are trimmed boost bundles we maintain.
+ # See cpp/build-support/trim-boost.sh
+ # FIXME(ARROW-6407) automate uploading this archive to ensure it reflects
+ # our currently used packages and doesn't fall out of sync with
+ # ${ARROW_BOOST_BUILD_VERSION_UNDERSCORES}
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/boost_${ARROW_BOOST_BUILD_VERSION_UNDERSCORES}.tar.gz"
+ "https://boostorg.jfrog.io/artifactory/main/release/${ARROW_BOOST_BUILD_VERSION}/source/boost_${ARROW_BOOST_BUILD_VERSION_UNDERSCORES}.tar.gz"
+ "https://sourceforge.net/projects/boost/files/boost/${ARROW_BOOST_BUILD_VERSION}/boost_${ARROW_BOOST_BUILD_VERSION_UNDERSCORES}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_BROTLI_URL})
+ set(BROTLI_SOURCE_URL "$ENV{ARROW_BROTLI_URL}")
+else()
+ set_urls(BROTLI_SOURCE_URL
+ "https://github.com/google/brotli/archive/${ARROW_BROTLI_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/brotli-${ARROW_BROTLI_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_BZIP2_URL})
+ set(ARROW_BZIP2_SOURCE_URL "$ENV{ARROW_BZIP2_URL}")
+else()
+ set_urls(ARROW_BZIP2_SOURCE_URL
+ "https://sourceware.org/pub/bzip2/bzip2-${ARROW_BZIP2_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/bzip2-${ARROW_BZIP2_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_CARES_URL})
+ set(CARES_SOURCE_URL "$ENV{ARROW_CARES_URL}")
+else()
+ set_urls(CARES_SOURCE_URL
+ "https://c-ares.haxx.se/download/c-ares-${ARROW_CARES_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/cares-${ARROW_CARES_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_CRC32C_URL})
+ set(CRC32C_URL "$ENV{ARROW_CRC32C_URL}")
+else()
+ set_urls(CRC32C_SOURCE_URL
+ "https://github.com/google/crc32c/archive/${ARROW_CRC32C_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_GBENCHMARK_URL})
+ set(GBENCHMARK_SOURCE_URL "$ENV{ARROW_GBENCHMARK_URL}")
+else()
+ set_urls(GBENCHMARK_SOURCE_URL
+ "https://github.com/google/benchmark/archive/${ARROW_GBENCHMARK_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/gbenchmark-${ARROW_GBENCHMARK_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_GFLAGS_URL})
+ set(GFLAGS_SOURCE_URL "$ENV{ARROW_GFLAGS_URL}")
+else()
+ set_urls(GFLAGS_SOURCE_URL
+ "https://github.com/gflags/gflags/archive/${ARROW_GFLAGS_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/gflags-${ARROW_GFLAGS_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_GLOG_URL})
+ set(GLOG_SOURCE_URL "$ENV{ARROW_GLOG_URL}")
+else()
+ set_urls(GLOG_SOURCE_URL
+ "https://github.com/google/glog/archive/${ARROW_GLOG_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/glog-${ARROW_GLOG_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_GOOGLE_CLOUD_CPP_URL})
+ set(google_cloud_cpp_storage_SOURCE_URL "$ENV{ARROW_GOOGLE_CLOUD_CPP_URL}")
+else()
+ set_urls(google_cloud_cpp_storage_SOURCE_URL
+ "https://github.com/googleapis/google-cloud-cpp/archive/${ARROW_GOOGLE_CLOUD_CPP_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_GRPC_URL})
+ set(GRPC_SOURCE_URL "$ENV{ARROW_GRPC_URL}")
+else()
+ set_urls(GRPC_SOURCE_URL
+ "https://github.com/grpc/grpc/archive/${ARROW_GRPC_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/grpc-${ARROW_GRPC_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_GTEST_URL})
+ set(GTEST_SOURCE_URL "$ENV{ARROW_GTEST_URL}")
+else()
+ set_urls(GTEST_SOURCE_URL
+ "https://github.com/google/googletest/archive/release-${ARROW_GTEST_BUILD_VERSION}.tar.gz"
+ "https://chromium.googlesource.com/external/github.com/google/googletest/+archive/release-${ARROW_GTEST_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/gtest-${ARROW_GTEST_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_JEMALLOC_URL})
+ set(JEMALLOC_SOURCE_URL "$ENV{ARROW_JEMALLOC_URL}")
+else()
+ set_urls(JEMALLOC_SOURCE_URL
+ "https://github.com/jemalloc/jemalloc/releases/download/${ARROW_JEMALLOC_BUILD_VERSION}/jemalloc-${ARROW_JEMALLOC_BUILD_VERSION}.tar.bz2"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/jemalloc-${ARROW_JEMALLOC_BUILD_VERSION}.tar.bz2"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_MIMALLOC_URL})
+ set(MIMALLOC_SOURCE_URL "$ENV{ARROW_MIMALLOC_URL}")
+else()
+ set_urls(MIMALLOC_SOURCE_URL
+ "https://github.com/microsoft/mimalloc/archive/${ARROW_MIMALLOC_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/mimalloc-${ARROW_MIMALLOC_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_NLOHMANN_JSON_URL})
+ set(NLOHMANN_JSON_SOURCE_URL "$ENV{ARROW_NLOHMANN_JSON_URL}")
+else()
+ set_urls(NLOHMANN_JSON_SOURCE_URL
+ "https://github.com/nlohmann/json/archive/${ARROW_NLOHMANN_JSON_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_LZ4_URL})
+ set(LZ4_SOURCE_URL "$ENV{ARROW_LZ4_URL}")
+else()
+ set_urls(LZ4_SOURCE_URL
+ "https://github.com/lz4/lz4/archive/${ARROW_LZ4_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/lz4-${ARROW_LZ4_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_ORC_URL})
+ set(ORC_SOURCE_URL "$ENV{ARROW_ORC_URL}")
+else()
+ set_urls(ORC_SOURCE_URL
+ "https://github.com/apache/orc/archive/rel/release-${ARROW_ORC_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/orc-${ARROW_ORC_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_PROTOBUF_URL})
+ set(PROTOBUF_SOURCE_URL "$ENV{ARROW_PROTOBUF_URL}")
+else()
+ string(SUBSTRING ${ARROW_PROTOBUF_BUILD_VERSION} 1 -1
+ ARROW_PROTOBUF_STRIPPED_BUILD_VERSION)
+ # strip the leading `v`
+ set_urls(PROTOBUF_SOURCE_URL
+ "https://github.com/protocolbuffers/protobuf/releases/download/${ARROW_PROTOBUF_BUILD_VERSION}/protobuf-all-${ARROW_PROTOBUF_STRIPPED_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/protobuf-${ARROW_PROTOBUF_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_RE2_URL})
+ set(RE2_SOURCE_URL "$ENV{ARROW_RE2_URL}")
+else()
+ set_urls(RE2_SOURCE_URL
+ "https://github.com/google/re2/archive/${ARROW_RE2_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/re2-${ARROW_RE2_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_RAPIDJSON_URL})
+ set(RAPIDJSON_SOURCE_URL "$ENV{ARROW_RAPIDJSON_URL}")
+else()
+ set_urls(RAPIDJSON_SOURCE_URL
+ "https://github.com/miloyip/rapidjson/archive/${ARROW_RAPIDJSON_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/rapidjson-${ARROW_RAPIDJSON_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_SNAPPY_URL})
+ set(SNAPPY_SOURCE_URL "$ENV{ARROW_SNAPPY_URL}")
+else()
+ set_urls(SNAPPY_SOURCE_URL
+ "https://github.com/google/snappy/archive/${ARROW_SNAPPY_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/snappy-${ARROW_SNAPPY_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_THRIFT_URL})
+ set(THRIFT_SOURCE_URL "$ENV{ARROW_THRIFT_URL}")
+else()
+ set_urls(THRIFT_SOURCE_URL
+ "https://www.apache.org/dyn/closer.cgi?action=download&filename=/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ "https://downloads.apache.org/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ "https://github.com/apache/thrift/archive/v${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ "https://apache.claz.org/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ "https://apache.cs.utah.edu/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ "https://apache.mirrors.lucidnetworks.net/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ "https://apache.osuosl.org/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ "https://ftp.wayne.edu/apache/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ "https://mirror.olnevhost.net/pub/apache/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ "https://mirrors.gigenet.com/apache/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ "https://mirrors.koehn.com/apache/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ "https://mirrors.ocf.berkeley.edu/apache/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ "https://mirrors.sonic.net/apache/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ "https://us.mirrors.quenda.co/apache/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_UTF8PROC_URL})
+ set(ARROW_UTF8PROC_SOURCE_URL "$ENV{ARROW_UTF8PROC_URL}")
+else()
+ set_urls(ARROW_UTF8PROC_SOURCE_URL
+ "https://github.com/JuliaStrings/utf8proc/archive/${ARROW_UTF8PROC_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_XSIMD_URL})
+ set(XSIMD_SOURCE_URL "$ENV{ARROW_XSIMD_URL}")
+else()
+ set_urls(XSIMD_SOURCE_URL
+ "https://github.com/xtensor-stack/xsimd/archive/${ARROW_XSIMD_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_ZLIB_URL})
+ set(ZLIB_SOURCE_URL "$ENV{ARROW_ZLIB_URL}")
+else()
+ set_urls(ZLIB_SOURCE_URL
+ "https://zlib.net/fossils/zlib-${ARROW_ZLIB_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/zlib-${ARROW_ZLIB_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+if(DEFINED ENV{ARROW_ZSTD_URL})
+ set(ZSTD_SOURCE_URL "$ENV{ARROW_ZSTD_URL}")
+else()
+ set_urls(ZSTD_SOURCE_URL
+ "https://github.com/facebook/zstd/archive/${ARROW_ZSTD_BUILD_VERSION}.tar.gz"
+ "https://github.com/ursa-labs/thirdparty/releases/download/latest/zstd-${ARROW_ZSTD_BUILD_VERSION}.tar.gz"
+ )
+endif()
+
+# ----------------------------------------------------------------------
+# ExternalProject options
+
+set(EP_CXX_FLAGS
+ "${CMAKE_CXX_COMPILER_ARG1} ${CMAKE_CXX_FLAGS} ${CMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}}"
+)
+set(EP_C_FLAGS
+ "${CMAKE_C_COMPILER_ARG1} ${CMAKE_C_FLAGS} ${CMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}}")
+
+if(NOT MSVC_TOOLCHAIN)
+ # Set -fPIC on all external projects
+ set(EP_CXX_FLAGS "${EP_CXX_FLAGS} -fPIC")
+ set(EP_C_FLAGS "${EP_C_FLAGS} -fPIC")
+endif()
+
+# CC/CXX environment variables are captured on the first invocation of the
+# builder (e.g make or ninja) instead of when CMake is invoked into to build
+# directory. This leads to issues if the variables are exported in a subshell
+# and the invocation of make/ninja is in distinct subshell without the same
+# environment (CC/CXX).
+set(EP_COMMON_TOOLCHAIN -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
+ -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER})
+
+if(CMAKE_AR)
+ set(EP_COMMON_TOOLCHAIN ${EP_COMMON_TOOLCHAIN} -DCMAKE_AR=${CMAKE_AR})
+endif()
+
+if(CMAKE_RANLIB)
+ set(EP_COMMON_TOOLCHAIN ${EP_COMMON_TOOLCHAIN} -DCMAKE_RANLIB=${CMAKE_RANLIB})
+endif()
+
+# External projects are still able to override the following declarations.
+# cmake command line will favor the last defined variable when a duplicate is
+# encountered. This requires that `EP_COMMON_CMAKE_ARGS` is always the first
+# argument.
+set(EP_COMMON_CMAKE_ARGS
+ ${EP_COMMON_TOOLCHAIN}
+ ${EP_COMMON_CMAKE_ARGS}
+ -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
+ -DCMAKE_C_FLAGS=${EP_C_FLAGS}
+ -DCMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}=${EP_C_FLAGS}
+ -DCMAKE_CXX_FLAGS=${EP_CXX_FLAGS}
+ -DCMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}=${EP_CXX_FLAGS}
+ -DCMAKE_CXX_STANDARD=${CMAKE_CXX_STANDARD}
+ -DCMAKE_EXPORT_NO_PACKAGE_REGISTRY=${CMAKE_EXPORT_NO_PACKAGE_REGISTRY}
+ -DCMAKE_FIND_PACKAGE_NO_PACKAGE_REGISTRY=${CMAKE_FIND_PACKAGE_NO_PACKAGE_REGISTRY})
+
+if(NOT ARROW_VERBOSE_THIRDPARTY_BUILD)
+ set(EP_LOG_OPTIONS
+ LOG_CONFIGURE
+ 1
+ LOG_BUILD
+ 1
+ LOG_INSTALL
+ 1
+ LOG_DOWNLOAD
+ 1
+ LOG_OUTPUT_ON_FAILURE
+ 1)
+ set(Boost_DEBUG FALSE)
+else()
+ set(EP_LOG_OPTIONS)
+ set(Boost_DEBUG TRUE)
+endif()
+
+# Ensure that a default make is set
+if("${MAKE}" STREQUAL "")
+ if(NOT MSVC)
+ find_program(MAKE make)
+ endif()
+endif()
+
+# Using make -j in sub-make is fragile
+# see discussion https://github.com/apache/arrow/pull/2779
+if(${CMAKE_GENERATOR} MATCHES "Makefiles")
+ set(MAKE_BUILD_ARGS "")
+else()
+ # limit the maximum number of jobs for ninja
+ set(MAKE_BUILD_ARGS "-j${NPROC}")
+endif()
+
+# ----------------------------------------------------------------------
+# Find pthreads
+
+set(THREADS_PREFER_PTHREAD_FLAG ON)
+find_package(Threads REQUIRED)
+
+# ----------------------------------------------------------------------
+# Add Boost dependencies (code adapted from Apache Kudu)
+
+macro(build_boost)
+ set(BOOST_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/boost_ep-prefix/src/boost_ep")
+
+ # This is needed by the thrift_ep build
+ set(BOOST_ROOT ${BOOST_PREFIX})
+
+ if(ARROW_BOOST_REQUIRE_LIBRARY)
+ set(BOOST_LIB_DIR "${BOOST_PREFIX}/stage/lib")
+ set(BOOST_BUILD_LINK "static")
+ if("${CMAKE_BUILD_TYPE}" STREQUAL "DEBUG")
+ set(BOOST_BUILD_VARIANT "debug")
+ else()
+ set(BOOST_BUILD_VARIANT "release")
+ endif()
+ if(MSVC)
+ set(BOOST_CONFIGURE_COMMAND ".\\\\bootstrap.bat")
+ else()
+ set(BOOST_CONFIGURE_COMMAND "./bootstrap.sh")
+ endif()
+
+ set(BOOST_BUILD_WITH_LIBRARIES "filesystem" "system")
+ string(REPLACE ";" "," BOOST_CONFIGURE_LIBRARIES "${BOOST_BUILD_WITH_LIBRARIES}")
+ list(APPEND BOOST_CONFIGURE_COMMAND "--prefix=${BOOST_PREFIX}"
+ "--with-libraries=${BOOST_CONFIGURE_LIBRARIES}")
+ set(BOOST_BUILD_COMMAND "./b2" "-j${NPROC}" "link=${BOOST_BUILD_LINK}"
+ "variant=${BOOST_BUILD_VARIANT}")
+ if(MSVC)
+ string(REGEX REPLACE "([0-9])$" ".\\1" BOOST_TOOLSET_MSVC_VERSION
+ ${MSVC_TOOLSET_VERSION})
+ list(APPEND BOOST_BUILD_COMMAND "toolset=msvc-${BOOST_TOOLSET_MSVC_VERSION}")
+ set(BOOST_BUILD_WITH_LIBRARIES_MSVC)
+ foreach(_BOOST_LIB ${BOOST_BUILD_WITH_LIBRARIES})
+ list(APPEND BOOST_BUILD_WITH_LIBRARIES_MSVC "--with-${_BOOST_LIB}")
+ endforeach()
+ list(APPEND BOOST_BUILD_COMMAND ${BOOST_BUILD_WITH_LIBRARIES_MSVC})
+ else()
+ list(APPEND BOOST_BUILD_COMMAND "cxxflags=-fPIC")
+ endif()
+
+ if(MSVC)
+ string(REGEX
+ REPLACE "^([0-9]+)\\.([0-9]+)\\.[0-9]+$" "\\1_\\2"
+ ARROW_BOOST_BUILD_VERSION_NO_MICRO_UNDERSCORE
+ ${ARROW_BOOST_BUILD_VERSION})
+ set(BOOST_LIBRARY_SUFFIX "-vc${MSVC_TOOLSET_VERSION}-mt")
+ if(BOOST_BUILD_VARIANT STREQUAL "debug")
+ set(BOOST_LIBRARY_SUFFIX "${BOOST_LIBRARY_SUFFIX}-gd")
+ endif()
+ set(BOOST_LIBRARY_SUFFIX
+ "${BOOST_LIBRARY_SUFFIX}-x64-${ARROW_BOOST_BUILD_VERSION_NO_MICRO_UNDERSCORE}")
+ else()
+ set(BOOST_LIBRARY_SUFFIX "")
+ endif()
+ set(BOOST_STATIC_SYSTEM_LIBRARY
+ "${BOOST_LIB_DIR}/libboost_system${BOOST_LIBRARY_SUFFIX}${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ set(BOOST_STATIC_FILESYSTEM_LIBRARY
+ "${BOOST_LIB_DIR}/libboost_filesystem${BOOST_LIBRARY_SUFFIX}${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ set(BOOST_SYSTEM_LIBRARY boost_system_static)
+ set(BOOST_FILESYSTEM_LIBRARY boost_filesystem_static)
+ set(BOOST_BUILD_PRODUCTS ${BOOST_STATIC_SYSTEM_LIBRARY}
+ ${BOOST_STATIC_FILESYSTEM_LIBRARY})
+
+ add_thirdparty_lib(boost_system STATIC_LIB "${BOOST_STATIC_SYSTEM_LIBRARY}")
+
+ add_thirdparty_lib(boost_filesystem STATIC_LIB "${BOOST_STATIC_FILESYSTEM_LIBRARY}")
+
+ externalproject_add(boost_ep
+ URL ${BOOST_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_BOOST_BUILD_SHA256_CHECKSUM}"
+ BUILD_BYPRODUCTS ${BOOST_BUILD_PRODUCTS}
+ BUILD_IN_SOURCE 1
+ CONFIGURE_COMMAND ${BOOST_CONFIGURE_COMMAND}
+ BUILD_COMMAND ${BOOST_BUILD_COMMAND}
+ INSTALL_COMMAND "" ${EP_LOG_OPTIONS})
+ add_dependencies(boost_system_static boost_ep)
+ add_dependencies(boost_filesystem_static boost_ep)
+ else()
+ externalproject_add(boost_ep
+ ${EP_LOG_OPTIONS}
+ BUILD_COMMAND ""
+ CONFIGURE_COMMAND ""
+ INSTALL_COMMAND ""
+ URL ${BOOST_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_BOOST_BUILD_SHA256_CHECKSUM}")
+ endif()
+ set(Boost_INCLUDE_DIR "${BOOST_PREFIX}")
+ set(Boost_INCLUDE_DIRS "${Boost_INCLUDE_DIR}")
+ add_dependencies(toolchain boost_ep)
+ set(BOOST_VENDORED TRUE)
+endmacro()
+
+if(ARROW_FLIGHT AND ARROW_BUILD_TESTS)
+ set(ARROW_BOOST_REQUIRED_VERSION "1.64")
+else()
+ set(ARROW_BOOST_REQUIRED_VERSION "1.58")
+endif()
+
+set(Boost_USE_MULTITHREADED ON)
+if(MSVC AND ARROW_USE_STATIC_CRT)
+ set(Boost_USE_STATIC_RUNTIME ON)
+endif()
+set(Boost_ADDITIONAL_VERSIONS
+ "1.75.0"
+ "1.75"
+ "1.74.0"
+ "1.74"
+ "1.73.0"
+ "1.73"
+ "1.72.0"
+ "1.72"
+ "1.71.0"
+ "1.71"
+ "1.70.0"
+ "1.70"
+ "1.69.0"
+ "1.69"
+ "1.68.0"
+ "1.68"
+ "1.67.0"
+ "1.67"
+ "1.66.0"
+ "1.66"
+ "1.65.0"
+ "1.65"
+ "1.64.0"
+ "1.64"
+ "1.63.0"
+ "1.63"
+ "1.62.0"
+ "1.61"
+ "1.61.0"
+ "1.62"
+ "1.60.0"
+ "1.60")
+
+# Thrift needs Boost if we're building the bundled version with version < 0.13,
+# so we first need to determine whether we're building it
+if(ARROW_WITH_THRIFT AND Thrift_SOURCE STREQUAL "AUTO")
+ find_package(Thrift 0.11.0 MODULE COMPONENTS ${ARROW_THRIFT_REQUIRED_COMPONENTS})
+ if(Thrift_FOUND)
+ find_package(PkgConfig QUIET)
+ pkg_check_modules(THRIFT_PC
+ thrift
+ NO_CMAKE_PATH
+ NO_CMAKE_ENVIRONMENT_PATH
+ QUIET)
+ if(THRIFT_PC_FOUND)
+ string(APPEND ARROW_PC_REQUIRES_PRIVATE " thrift")
+ endif()
+ else()
+ set(Thrift_SOURCE "BUNDLED")
+ endif()
+endif()
+
+# Thrift < 0.13 has a compile-time header dependency on boost
+if(Thrift_SOURCE STREQUAL "BUNDLED" AND ARROW_THRIFT_BUILD_VERSION VERSION_LESS "0.13")
+ set(THRIFT_REQUIRES_BOOST TRUE)
+elseif(THRIFT_VERSION VERSION_LESS "0.13")
+ set(THRIFT_REQUIRES_BOOST TRUE)
+else()
+ set(THRIFT_REQUIRES_BOOST FALSE)
+endif()
+
+# Compilers that don't support int128_t have a compile-time
+# (header-only) dependency on Boost for int128_t.
+if(ARROW_USE_UBSAN)
+ # NOTE: Avoid native int128_t on clang with UBSan as it produces linker errors
+ # (such as "undefined reference to '__muloti4'")
+ set(ARROW_USE_NATIVE_INT128 FALSE)
+else()
+ include(CheckCXXSymbolExists)
+ check_cxx_symbol_exists("__SIZEOF_INT128__" "" ARROW_USE_NATIVE_INT128)
+endif()
+
+# - Gandiva has a compile-time (header-only) dependency on Boost, not runtime.
+# - Tests need Boost at runtime.
+# - S3FS and Flight benchmarks need Boost at runtime.
+if(ARROW_BUILD_INTEGRATION
+ OR ARROW_BUILD_TESTS
+ OR (ARROW_FLIGHT AND ARROW_BUILD_BENCHMARKS)
+ OR (ARROW_S3 AND ARROW_BUILD_BENCHMARKS))
+ set(ARROW_BOOST_REQUIRED TRUE)
+ set(ARROW_BOOST_REQUIRE_LIBRARY TRUE)
+elseif(ARROW_GANDIVA
+ OR (ARROW_WITH_THRIFT AND THRIFT_REQUIRES_BOOST)
+ OR (NOT ARROW_USE_NATIVE_INT128))
+ set(ARROW_BOOST_REQUIRED TRUE)
+ set(ARROW_BOOST_REQUIRE_LIBRARY FALSE)
+else()
+ set(ARROW_BOOST_REQUIRED FALSE)
+endif()
+
+if(ARROW_BOOST_REQUIRED)
+ resolve_dependency(Boost
+ HAVE_ALT
+ TRUE
+ REQUIRED_VERSION
+ ${ARROW_BOOST_REQUIRED_VERSION}
+ IS_RUNTIME_DEPENDENCY
+ # libarrow.so doesn't depend on libboost*.
+ FALSE)
+
+ if(TARGET Boost::system)
+ set(BOOST_SYSTEM_LIBRARY Boost::system)
+ set(BOOST_FILESYSTEM_LIBRARY Boost::filesystem)
+ elseif(BoostAlt_FOUND)
+ set(BOOST_SYSTEM_LIBRARY ${Boost_SYSTEM_LIBRARY})
+ set(BOOST_FILESYSTEM_LIBRARY ${Boost_FILESYSTEM_LIBRARY})
+ else()
+ set(BOOST_SYSTEM_LIBRARY boost_system_static)
+ set(BOOST_FILESYSTEM_LIBRARY boost_filesystem_static)
+ endif()
+ set(ARROW_BOOST_LIBS ${BOOST_SYSTEM_LIBRARY} ${BOOST_FILESYSTEM_LIBRARY})
+
+ message(STATUS "Boost include dir: ${Boost_INCLUDE_DIR}")
+ message(STATUS "Boost libraries: ${ARROW_BOOST_LIBS}")
+
+ include_directories(SYSTEM ${Boost_INCLUDE_DIR})
+endif()
+
+# ----------------------------------------------------------------------
+# Snappy
+
+macro(build_snappy)
+ message(STATUS "Building snappy from source")
+ set(SNAPPY_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/snappy_ep/src/snappy_ep-install")
+ set(SNAPPY_STATIC_LIB_NAME snappy)
+ set(SNAPPY_STATIC_LIB
+ "${SNAPPY_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}${SNAPPY_STATIC_LIB_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+
+ set(SNAPPY_CMAKE_ARGS
+ ${EP_COMMON_CMAKE_ARGS} -DCMAKE_INSTALL_LIBDIR=lib -DSNAPPY_BUILD_TESTS=OFF
+ "-DCMAKE_INSTALL_PREFIX=${SNAPPY_PREFIX}")
+
+ externalproject_add(snappy_ep
+ ${EP_LOG_OPTIONS}
+ BUILD_IN_SOURCE 1
+ INSTALL_DIR ${SNAPPY_PREFIX}
+ URL ${SNAPPY_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_SNAPPY_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${SNAPPY_CMAKE_ARGS}
+ BUILD_BYPRODUCTS "${SNAPPY_STATIC_LIB}")
+
+ file(MAKE_DIRECTORY "${SNAPPY_PREFIX}/include")
+
+ add_library(Snappy::snappy STATIC IMPORTED)
+ set_target_properties(Snappy::snappy
+ PROPERTIES IMPORTED_LOCATION "${SNAPPY_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${SNAPPY_PREFIX}/include")
+ add_dependencies(toolchain snappy_ep)
+ add_dependencies(Snappy::snappy snappy_ep)
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS Snappy::snappy)
+endmacro()
+
+if(ARROW_WITH_SNAPPY)
+ resolve_dependency(Snappy PC_PACKAGE_NAMES snappy)
+ if(${Snappy_SOURCE} STREQUAL "SYSTEM" AND NOT snappy_PC_FOUND)
+ get_target_property(SNAPPY_LIB Snappy::snappy IMPORTED_LOCATION)
+ string(APPEND ARROW_PC_LIBS_PRIVATE " ${SNAPPY_LIB}")
+ endif()
+ # TODO: Don't use global includes but rather target_include_directories
+ get_target_property(SNAPPY_INCLUDE_DIRS Snappy::snappy INTERFACE_INCLUDE_DIRECTORIES)
+ include_directories(SYSTEM ${SNAPPY_INCLUDE_DIRS})
+endif()
+
+# ----------------------------------------------------------------------
+# Brotli
+
+macro(build_brotli)
+ message(STATUS "Building brotli from source")
+ set(BROTLI_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/brotli_ep/src/brotli_ep-install")
+ set(BROTLI_INCLUDE_DIR "${BROTLI_PREFIX}/include")
+ set(BROTLI_LIB_DIR lib)
+ set(BROTLI_STATIC_LIBRARY_ENC
+ "${BROTLI_PREFIX}/${BROTLI_LIB_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}brotlienc-static${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ set(BROTLI_STATIC_LIBRARY_DEC
+ "${BROTLI_PREFIX}/${BROTLI_LIB_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}brotlidec-static${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ set(BROTLI_STATIC_LIBRARY_COMMON
+ "${BROTLI_PREFIX}/${BROTLI_LIB_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}brotlicommon-static${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ set(BROTLI_CMAKE_ARGS ${EP_COMMON_CMAKE_ARGS} "-DCMAKE_INSTALL_PREFIX=${BROTLI_PREFIX}"
+ -DCMAKE_INSTALL_LIBDIR=${BROTLI_LIB_DIR})
+
+ externalproject_add(brotli_ep
+ URL ${BROTLI_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_BROTLI_BUILD_SHA256_CHECKSUM}"
+ BUILD_BYPRODUCTS "${BROTLI_STATIC_LIBRARY_ENC}"
+ "${BROTLI_STATIC_LIBRARY_DEC}"
+ "${BROTLI_STATIC_LIBRARY_COMMON}"
+ ${BROTLI_BUILD_BYPRODUCTS}
+ ${EP_LOG_OPTIONS}
+ CMAKE_ARGS ${BROTLI_CMAKE_ARGS}
+ STEP_TARGETS headers_copy)
+
+ add_dependencies(toolchain brotli_ep)
+ file(MAKE_DIRECTORY "${BROTLI_INCLUDE_DIR}")
+
+ add_library(Brotli::brotlicommon STATIC IMPORTED)
+ set_target_properties(Brotli::brotlicommon
+ PROPERTIES IMPORTED_LOCATION "${BROTLI_STATIC_LIBRARY_COMMON}"
+ INTERFACE_INCLUDE_DIRECTORIES "${BROTLI_INCLUDE_DIR}")
+ add_dependencies(Brotli::brotlicommon brotli_ep)
+
+ add_library(Brotli::brotlienc STATIC IMPORTED)
+ set_target_properties(Brotli::brotlienc
+ PROPERTIES IMPORTED_LOCATION "${BROTLI_STATIC_LIBRARY_ENC}"
+ INTERFACE_INCLUDE_DIRECTORIES "${BROTLI_INCLUDE_DIR}")
+ add_dependencies(Brotli::brotlienc brotli_ep)
+
+ add_library(Brotli::brotlidec STATIC IMPORTED)
+ set_target_properties(Brotli::brotlidec
+ PROPERTIES IMPORTED_LOCATION "${BROTLI_STATIC_LIBRARY_DEC}"
+ INTERFACE_INCLUDE_DIRECTORIES "${BROTLI_INCLUDE_DIR}")
+ add_dependencies(Brotli::brotlidec brotli_ep)
+
+ list(APPEND
+ ARROW_BUNDLED_STATIC_LIBS
+ Brotli::brotlicommon
+ Brotli::brotlienc
+ Brotli::brotlidec)
+endmacro()
+
+if(ARROW_WITH_BROTLI)
+ resolve_dependency(Brotli PC_PACKAGE_NAMES libbrotlidec libbrotlienc)
+ # TODO: Don't use global includes but rather target_include_directories
+ get_target_property(BROTLI_INCLUDE_DIR Brotli::brotlicommon
+ INTERFACE_INCLUDE_DIRECTORIES)
+ include_directories(SYSTEM ${BROTLI_INCLUDE_DIR})
+endif()
+
+if(PARQUET_REQUIRE_ENCRYPTION AND NOT ARROW_PARQUET)
+ set(PARQUET_REQUIRE_ENCRYPTION OFF)
+endif()
+set(ARROW_OPENSSL_REQUIRED_VERSION "1.0.2")
+if(BREW_BIN AND NOT OPENSSL_ROOT_DIR)
+ execute_process(COMMAND ${BREW_BIN} --prefix "openssl@1.1"
+ OUTPUT_VARIABLE OPENSSL11_BREW_PREFIX
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+ if(OPENSSL11_BREW_PREFIX)
+ set(OPENSSL_ROOT_DIR ${OPENSSL11_BREW_PREFIX})
+ else()
+ execute_process(COMMAND ${BREW_BIN} --prefix "openssl"
+ OUTPUT_VARIABLE OPENSSL_BREW_PREFIX
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+ if(OPENSSL_BREW_PREFIX)
+ set(OPENSSL_ROOT_DIR ${OPENSSL_BREW_PREFIX})
+ endif()
+ endif()
+endif()
+
+set(ARROW_USE_OPENSSL OFF)
+if(PARQUET_REQUIRE_ENCRYPTION
+ OR ARROW_FLIGHT
+ OR ARROW_S3)
+ # OpenSSL is required
+ if(ARROW_OPENSSL_USE_SHARED)
+ # Find shared OpenSSL libraries.
+ set(OpenSSL_USE_STATIC_LIBS OFF)
+ # Seems that different envs capitalize this differently?
+ set(OPENSSL_USE_STATIC_LIBS OFF)
+ set(BUILD_SHARED_LIBS_KEEP ${BUILD_SHARED_LIBS})
+ set(BUILD_SHARED_LIBS ON)
+
+ find_package(OpenSSL ${ARROW_OPENSSL_REQUIRED_VERSION} REQUIRED)
+ set(BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS_KEEP})
+ unset(BUILD_SHARED_LIBS_KEEP)
+ else()
+ # Find static OpenSSL headers and libs
+ set(OpenSSL_USE_STATIC_LIBS ON)
+ set(OPENSSL_USE_STATIC_LIBS ON)
+ find_package(OpenSSL ${ARROW_OPENSSL_REQUIRED_VERSION} REQUIRED)
+ endif()
+ set(ARROW_USE_OPENSSL ON)
+endif()
+
+if(ARROW_USE_OPENSSL)
+ message(STATUS "Found OpenSSL Crypto Library: ${OPENSSL_CRYPTO_LIBRARY}")
+ message(STATUS "Building with OpenSSL (Version: ${OPENSSL_VERSION}) support")
+
+ list(APPEND ARROW_SYSTEM_DEPENDENCIES "OpenSSL")
+
+ include_directories(SYSTEM ${OPENSSL_INCLUDE_DIR})
+else()
+ message(STATUS "Building without OpenSSL support. Minimum OpenSSL version ${ARROW_OPENSSL_REQUIRED_VERSION} required."
+ )
+endif()
+
+# ----------------------------------------------------------------------
+# GLOG
+
+macro(build_glog)
+ message(STATUS "Building glog from source")
+ set(GLOG_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/glog_ep-prefix/src/glog_ep")
+ set(GLOG_INCLUDE_DIR "${GLOG_BUILD_DIR}/include")
+ if(${UPPERCASE_BUILD_TYPE} STREQUAL "DEBUG")
+ set(GLOG_LIB_SUFFIX "d")
+ else()
+ set(GLOG_LIB_SUFFIX "")
+ endif()
+ set(GLOG_STATIC_LIB
+ "${GLOG_BUILD_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}glog${GLOG_LIB_SUFFIX}${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ set(GLOG_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
+ set(GLOG_CMAKE_C_FLAGS "${EP_C_FLAGS} -fPIC")
+ if(CMAKE_THREAD_LIBS_INIT)
+ set(GLOG_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CMAKE_THREAD_LIBS_INIT}")
+ set(GLOG_CMAKE_C_FLAGS "${EP_C_FLAGS} ${CMAKE_THREAD_LIBS_INIT}")
+ endif()
+
+ if(APPLE)
+ # If we don't set this flag, the binary built with 10.13 cannot be used in 10.12.
+ set(GLOG_CMAKE_CXX_FLAGS "${GLOG_CMAKE_CXX_FLAGS} -mmacosx-version-min=10.9")
+ endif()
+
+ set(GLOG_CMAKE_ARGS
+ ${EP_COMMON_CMAKE_ARGS}
+ "-DCMAKE_INSTALL_PREFIX=${GLOG_BUILD_DIR}"
+ -DBUILD_SHARED_LIBS=OFF
+ -DBUILD_TESTING=OFF
+ -DWITH_GFLAGS=OFF
+ -DCMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}=${GLOG_CMAKE_CXX_FLAGS}
+ -DCMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}=${GLOG_CMAKE_C_FLAGS}
+ -DCMAKE_CXX_FLAGS=${GLOG_CMAKE_CXX_FLAGS})
+ externalproject_add(glog_ep
+ URL ${GLOG_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_GLOG_BUILD_SHA256_CHECKSUM}"
+ BUILD_IN_SOURCE 1
+ BUILD_BYPRODUCTS "${GLOG_STATIC_LIB}"
+ CMAKE_ARGS ${GLOG_CMAKE_ARGS} ${EP_LOG_OPTIONS})
+
+ add_dependencies(toolchain glog_ep)
+ file(MAKE_DIRECTORY "${GLOG_INCLUDE_DIR}")
+
+ add_library(glog::glog STATIC IMPORTED)
+ set_target_properties(glog::glog
+ PROPERTIES IMPORTED_LOCATION "${GLOG_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${GLOG_INCLUDE_DIR}")
+ add_dependencies(glog::glog glog_ep)
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS glog::glog)
+endmacro()
+
+if(ARROW_USE_GLOG)
+ resolve_dependency(GLOG PC_PACKAGE_NAMES libglog)
+ # TODO: Don't use global includes but rather target_include_directories
+ get_target_property(GLOG_INCLUDE_DIR glog::glog INTERFACE_INCLUDE_DIRECTORIES)
+ include_directories(SYSTEM ${GLOG_INCLUDE_DIR})
+endif()
+
+# ----------------------------------------------------------------------
+# gflags
+
+if(ARROW_BUILD_TESTS
+ OR ARROW_BUILD_BENCHMARKS
+ OR ARROW_BUILD_INTEGRATION
+ OR ARROW_PLASMA
+ OR ARROW_USE_GLOG
+ OR ARROW_WITH_GRPC)
+ set(ARROW_NEED_GFLAGS 1)
+else()
+ set(ARROW_NEED_GFLAGS 0)
+endif()
+
+macro(build_gflags)
+ message(STATUS "Building gflags from source")
+
+ set(GFLAGS_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/gflags_ep-prefix/src/gflags_ep")
+ set(GFLAGS_INCLUDE_DIR "${GFLAGS_PREFIX}/include")
+ if(${UPPERCASE_BUILD_TYPE} STREQUAL "DEBUG")
+ set(GFLAGS_LIB_SUFFIX "_debug")
+ else()
+ set(GFLAGS_LIB_SUFFIX "")
+ endif()
+ if(MSVC)
+ set(GFLAGS_STATIC_LIB "${GFLAGS_PREFIX}/lib/gflags_static${GFLAGS_LIB_SUFFIX}.lib")
+ else()
+ set(GFLAGS_STATIC_LIB "${GFLAGS_PREFIX}/lib/libgflags${GFLAGS_LIB_SUFFIX}.a")
+ endif()
+ set(GFLAGS_CMAKE_ARGS
+ ${EP_COMMON_CMAKE_ARGS}
+ "-DCMAKE_INSTALL_PREFIX=${GFLAGS_PREFIX}"
+ -DBUILD_SHARED_LIBS=OFF
+ -DBUILD_STATIC_LIBS=ON
+ -DBUILD_PACKAGING=OFF
+ -DBUILD_TESTING=OFF
+ -DBUILD_CONFIG_TESTS=OFF
+ -DINSTALL_HEADERS=ON)
+
+ file(MAKE_DIRECTORY "${GFLAGS_INCLUDE_DIR}")
+ externalproject_add(gflags_ep
+ URL ${GFLAGS_SOURCE_URL} ${EP_LOG_OPTIONS}
+ URL_HASH "SHA256=${ARROW_GFLAGS_BUILD_SHA256_CHECKSUM}"
+ BUILD_IN_SOURCE 1
+ BUILD_BYPRODUCTS "${GFLAGS_STATIC_LIB}"
+ CMAKE_ARGS ${GFLAGS_CMAKE_ARGS})
+
+ add_dependencies(toolchain gflags_ep)
+
+ add_thirdparty_lib(gflags STATIC_LIB ${GFLAGS_STATIC_LIB})
+ set(GFLAGS_LIBRARY gflags_static)
+ set_target_properties(${GFLAGS_LIBRARY}
+ PROPERTIES INTERFACE_COMPILE_DEFINITIONS "GFLAGS_IS_A_DLL=0"
+ INTERFACE_INCLUDE_DIRECTORIES "${GFLAGS_INCLUDE_DIR}")
+ if(MSVC)
+ set_target_properties(${GFLAGS_LIBRARY} PROPERTIES INTERFACE_LINK_LIBRARIES
+ "shlwapi.lib")
+ endif()
+ set(GFLAGS_LIBRARIES ${GFLAGS_LIBRARY})
+
+ set(GFLAGS_VENDORED TRUE)
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS gflags_static)
+endmacro()
+
+if(ARROW_NEED_GFLAGS)
+ set(ARROW_GFLAGS_REQUIRED_VERSION "2.1.0")
+ resolve_dependency(gflags
+ HAVE_ALT
+ TRUE
+ REQUIRED_VERSION
+ ${ARROW_GFLAGS_REQUIRED_VERSION}
+ IS_RUNTIME_DEPENDENCY
+ FALSE)
+ # TODO: Don't use global includes but rather target_include_directories
+ include_directories(SYSTEM ${GFLAGS_INCLUDE_DIR})
+
+ if(NOT TARGET ${GFLAGS_LIBRARIES})
+ if(TARGET gflags-shared)
+ set(GFLAGS_LIBRARIES gflags-shared)
+ elseif(TARGET gflags_shared)
+ set(GFLAGS_LIBRARIES gflags_shared)
+ endif()
+ endif()
+endif()
+
+# ----------------------------------------------------------------------
+# Thrift
+
+macro(build_thrift)
+ if(CMAKE_VERSION VERSION_LESS 3.10)
+ message(FATAL_ERROR "Building thrift using ExternalProject requires at least CMake 3.10"
+ )
+ endif()
+ message("Building Apache Thrift from source")
+ set(THRIFT_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/thrift_ep-install")
+ set(THRIFT_INCLUDE_DIR "${THRIFT_PREFIX}/include")
+ set(THRIFT_CMAKE_ARGS
+ ${EP_COMMON_CMAKE_ARGS}
+ "-DCMAKE_INSTALL_PREFIX=${THRIFT_PREFIX}"
+ "-DCMAKE_INSTALL_RPATH=${THRIFT_PREFIX}/lib"
+ -DBUILD_COMPILER=OFF
+ -DBUILD_SHARED_LIBS=OFF
+ -DBUILD_TESTING=OFF
+ -DBUILD_EXAMPLES=OFF
+ -DBUILD_TUTORIALS=OFF
+ -DWITH_QT4=OFF
+ -DWITH_C_GLIB=OFF
+ -DWITH_JAVA=OFF
+ -DWITH_PYTHON=OFF
+ -DWITH_HASKELL=OFF
+ -DWITH_CPP=ON
+ -DWITH_STATIC_LIB=ON
+ -DWITH_LIBEVENT=OFF
+ # Work around https://gitlab.kitware.com/cmake/cmake/issues/18865
+ -DBoost_NO_BOOST_CMAKE=ON)
+
+ # Thrift also uses boost. Forward important boost settings if there were ones passed.
+ if(DEFINED BOOST_ROOT)
+ list(APPEND THRIFT_CMAKE_ARGS "-DBOOST_ROOT=${BOOST_ROOT}")
+ endif()
+ if(DEFINED Boost_NAMESPACE)
+ list(APPEND THRIFT_CMAKE_ARGS "-DBoost_NAMESPACE=${Boost_NAMESPACE}")
+ endif()
+
+ set(THRIFT_STATIC_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}thrift")
+ if(MSVC)
+ if(ARROW_USE_STATIC_CRT)
+ set(THRIFT_STATIC_LIB_NAME "${THRIFT_STATIC_LIB_NAME}mt")
+ list(APPEND THRIFT_CMAKE_ARGS "-DWITH_MT=ON")
+ else()
+ set(THRIFT_STATIC_LIB_NAME "${THRIFT_STATIC_LIB_NAME}md")
+ list(APPEND THRIFT_CMAKE_ARGS "-DWITH_MT=OFF")
+ endif()
+ endif()
+ if(${UPPERCASE_BUILD_TYPE} STREQUAL "DEBUG")
+ set(THRIFT_STATIC_LIB_NAME "${THRIFT_STATIC_LIB_NAME}d")
+ endif()
+ set(THRIFT_STATIC_LIB
+ "${THRIFT_PREFIX}/lib/${THRIFT_STATIC_LIB_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}")
+
+ if(BOOST_VENDORED)
+ set(THRIFT_DEPENDENCIES ${THRIFT_DEPENDENCIES} boost_ep)
+ endif()
+
+ externalproject_add(thrift_ep
+ URL ${THRIFT_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_THRIFT_BUILD_SHA256_CHECKSUM}"
+ BUILD_BYPRODUCTS "${THRIFT_STATIC_LIB}"
+ CMAKE_ARGS ${THRIFT_CMAKE_ARGS}
+ DEPENDS ${THRIFT_DEPENDENCIES} ${EP_LOG_OPTIONS})
+
+ add_library(thrift::thrift STATIC IMPORTED)
+ # The include directory must exist before it is referenced by a target.
+ file(MAKE_DIRECTORY "${THRIFT_INCLUDE_DIR}")
+ set_target_properties(thrift::thrift
+ PROPERTIES IMPORTED_LOCATION "${THRIFT_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${THRIFT_INCLUDE_DIR}")
+ add_dependencies(toolchain thrift_ep)
+ add_dependencies(thrift::thrift thrift_ep)
+ set(THRIFT_VERSION ${ARROW_THRIFT_BUILD_VERSION})
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS thrift::thrift)
+endmacro()
+
+if(ARROW_WITH_THRIFT)
+ # We already may have looked for Thrift earlier, when considering whether
+ # to build Boost, so don't look again if already found.
+ if(NOT Thrift_FOUND)
+ # Thrift c++ code generated by 0.13 requires 0.11 or greater
+ resolve_dependency(Thrift
+ REQUIRED_VERSION
+ 0.11.0
+ PC_PACKAGE_NAMES
+ thrift)
+ endif()
+ # TODO: Don't use global includes but rather target_include_directories
+ include_directories(SYSTEM ${THRIFT_INCLUDE_DIR})
+
+ string(REPLACE "." ";" VERSION_LIST ${THRIFT_VERSION})
+ list(GET VERSION_LIST 0 THRIFT_VERSION_MAJOR)
+ list(GET VERSION_LIST 1 THRIFT_VERSION_MINOR)
+ list(GET VERSION_LIST 2 THRIFT_VERSION_PATCH)
+endif()
+
+# ----------------------------------------------------------------------
+# Protocol Buffers (required for ORC and Flight and Gandiva libraries)
+
+macro(build_protobuf)
+ message("Building Protocol Buffers from source")
+ set(PROTOBUF_VENDORED TRUE)
+ set(PROTOBUF_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/protobuf_ep-install")
+ set(PROTOBUF_INCLUDE_DIR "${PROTOBUF_PREFIX}/include")
+ # Newer protobuf releases always have a lib prefix independent from CMAKE_STATIC_LIBRARY_PREFIX
+ set(PROTOBUF_STATIC_LIB
+ "${PROTOBUF_PREFIX}/lib/libprotobuf${CMAKE_STATIC_LIBRARY_SUFFIX}")
+ set(PROTOC_STATIC_LIB "${PROTOBUF_PREFIX}/lib/libprotoc${CMAKE_STATIC_LIBRARY_SUFFIX}")
+ set(Protobuf_PROTOC_LIBRARY "${PROTOC_STATIC_LIB}")
+ set(PROTOBUF_COMPILER "${PROTOBUF_PREFIX}/bin/protoc")
+
+ if(CMAKE_VERSION VERSION_LESS 3.7)
+ set(PROTOBUF_CONFIGURE_ARGS
+ "AR=${CMAKE_AR}"
+ "RANLIB=${CMAKE_RANLIB}"
+ "CC=${CMAKE_C_COMPILER}"
+ "CXX=${CMAKE_CXX_COMPILER}"
+ "--disable-shared"
+ "--prefix=${PROTOBUF_PREFIX}"
+ "CFLAGS=${EP_C_FLAGS}"
+ "CXXFLAGS=${EP_CXX_FLAGS}")
+ set(PROTOBUF_BUILD_COMMAND ${MAKE} ${MAKE_BUILD_ARGS})
+ if(CMAKE_OSX_SYSROOT)
+ list(APPEND PROTOBUF_CONFIGURE_ARGS "SDKROOT=${CMAKE_OSX_SYSROOT}")
+ list(APPEND PROTOBUF_BUILD_COMMAND "SDKROOT=${CMAKE_OSX_SYSROOT}")
+ endif()
+ set(PROTOBUF_EXTERNAL_PROJECT_ADD_ARGS
+ CONFIGURE_COMMAND
+ "./configure"
+ ${PROTOBUF_CONFIGURE_ARGS}
+ BUILD_COMMAND
+ ${PROTOBUF_BUILD_COMMAND})
+ else()
+ # Strip lto flags (which may be added by dh_auto_configure)
+ # See https://github.com/protocolbuffers/protobuf/issues/7092
+ set(PROTOBUF_C_FLAGS ${EP_C_FLAGS})
+ set(PROTOBUF_CXX_FLAGS ${EP_CXX_FLAGS})
+ string(REPLACE "-flto=auto" "" PROTOBUF_C_FLAGS "${PROTOBUF_C_FLAGS}")
+ string(REPLACE "-ffat-lto-objects" "" PROTOBUF_C_FLAGS "${PROTOBUF_C_FLAGS}")
+ string(REPLACE "-flto=auto" "" PROTOBUF_CXX_FLAGS "${PROTOBUF_CXX_FLAGS}")
+ string(REPLACE "-ffat-lto-objects" "" PROTOBUF_CXX_FLAGS "${PROTOBUF_CXX_FLAGS}")
+ set(PROTOBUF_CMAKE_ARGS
+ ${EP_COMMON_CMAKE_ARGS}
+ -DBUILD_SHARED_LIBS=OFF
+ -DCMAKE_INSTALL_LIBDIR=lib
+ "-DCMAKE_INSTALL_PREFIX=${PROTOBUF_PREFIX}"
+ -Dprotobuf_BUILD_TESTS=OFF
+ -Dprotobuf_DEBUG_POSTFIX=
+ "-DCMAKE_C_FLAGS=${PROTOBUF_C_FLAGS}"
+ "-DCMAKE_CXX_FLAGS=${PROTOBUF_CXX_FLAGS}"
+ "-DCMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}=${PROTOBUF_C_FLAGS}"
+ "-DCMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}=${PROTOBUF_CXX_FLAGS}")
+ if(MSVC AND NOT ARROW_USE_STATIC_CRT)
+ list(APPEND PROTOBUF_CMAKE_ARGS "-Dprotobuf_MSVC_STATIC_RUNTIME=OFF")
+ endif()
+ if(ZLIB_ROOT)
+ list(APPEND PROTOBUF_CMAKE_ARGS "-DZLIB_ROOT=${ZLIB_ROOT}")
+ endif()
+ set(PROTOBUF_EXTERNAL_PROJECT_ADD_ARGS CMAKE_ARGS ${PROTOBUF_CMAKE_ARGS}
+ SOURCE_SUBDIR "cmake")
+ endif()
+
+ externalproject_add(protobuf_ep
+ ${PROTOBUF_EXTERNAL_PROJECT_ADD_ARGS}
+ BUILD_BYPRODUCTS "${PROTOBUF_STATIC_LIB}" "${PROTOBUF_COMPILER}"
+ ${EP_LOG_OPTIONS}
+ BUILD_IN_SOURCE 1
+ URL ${PROTOBUF_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_PROTOBUF_BUILD_SHA256_CHECKSUM}")
+
+ file(MAKE_DIRECTORY "${PROTOBUF_INCLUDE_DIR}")
+
+ add_library(arrow::protobuf::libprotobuf STATIC IMPORTED)
+ set_target_properties(arrow::protobuf::libprotobuf
+ PROPERTIES IMPORTED_LOCATION "${PROTOBUF_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${PROTOBUF_INCLUDE_DIR}")
+ add_library(arrow::protobuf::libprotoc STATIC IMPORTED)
+ set_target_properties(arrow::protobuf::libprotoc
+ PROPERTIES IMPORTED_LOCATION "${PROTOC_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${PROTOBUF_INCLUDE_DIR}")
+ add_executable(arrow::protobuf::protoc IMPORTED)
+ set_target_properties(arrow::protobuf::protoc PROPERTIES IMPORTED_LOCATION
+ "${PROTOBUF_COMPILER}")
+
+ add_dependencies(toolchain protobuf_ep)
+ add_dependencies(arrow::protobuf::libprotobuf protobuf_ep)
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS arrow::protobuf::libprotobuf)
+endmacro()
+
+if(ARROW_WITH_PROTOBUF)
+ if(ARROW_WITH_GRPC)
+ # FlightSQL uses proto3 optionals, which require 3.15 or later.
+ set(ARROW_PROTOBUF_REQUIRED_VERSION "3.15.0")
+ elseif(ARROW_GANDIVA_JAVA)
+ # google::protobuf::MessageLite::ByteSize() is deprecated since
+ # Protobuf 3.4.0.
+ set(ARROW_PROTOBUF_REQUIRED_VERSION "3.4.0")
+ else()
+ set(ARROW_PROTOBUF_REQUIRED_VERSION "2.6.1")
+ endif()
+ resolve_dependency(Protobuf
+ REQUIRED_VERSION
+ ${ARROW_PROTOBUF_REQUIRED_VERSION}
+ PC_PACKAGE_NAMES
+ protobuf)
+
+ if(ARROW_PROTOBUF_USE_SHARED AND MSVC_TOOLCHAIN)
+ add_definitions(-DPROTOBUF_USE_DLLS)
+ endif()
+
+ # TODO: Don't use global includes but rather target_include_directories
+ include_directories(SYSTEM ${PROTOBUF_INCLUDE_DIR})
+
+ if(TARGET arrow::protobuf::libprotobuf)
+ set(ARROW_PROTOBUF_LIBPROTOBUF arrow::protobuf::libprotobuf)
+ else()
+ # CMake 3.8 or older don't define the targets
+ if(NOT TARGET protobuf::libprotobuf)
+ add_library(protobuf::libprotobuf UNKNOWN IMPORTED)
+ set_target_properties(protobuf::libprotobuf
+ PROPERTIES IMPORTED_LOCATION "${PROTOBUF_LIBRARY}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${PROTOBUF_INCLUDE_DIR}")
+ endif()
+ set(ARROW_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf)
+ endif()
+ if(TARGET arrow::protobuf::libprotoc)
+ set(ARROW_PROTOBUF_LIBPROTOC arrow::protobuf::libprotoc)
+ else()
+ # CMake 3.8 or older don't define the targets
+ if(NOT TARGET protobuf::libprotoc)
+ if(PROTOBUF_PROTOC_LIBRARY AND NOT Protobuf_PROTOC_LIBRARY)
+ # Old CMake versions have a different casing.
+ set(Protobuf_PROTOC_LIBRARY ${PROTOBUF_PROTOC_LIBRARY})
+ endif()
+ if(NOT Protobuf_PROTOC_LIBRARY)
+ message(FATAL_ERROR "libprotoc was set to ${Protobuf_PROTOC_LIBRARY}")
+ endif()
+ add_library(protobuf::libprotoc UNKNOWN IMPORTED)
+ set_target_properties(protobuf::libprotoc
+ PROPERTIES IMPORTED_LOCATION "${Protobuf_PROTOC_LIBRARY}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${PROTOBUF_INCLUDE_DIR}")
+ endif()
+ set(ARROW_PROTOBUF_LIBPROTOC protobuf::libprotoc)
+ endif()
+ if(TARGET arrow::protobuf::protoc)
+ set(ARROW_PROTOBUF_PROTOC arrow::protobuf::protoc)
+ else()
+ if(NOT TARGET protobuf::protoc)
+ add_executable(protobuf::protoc IMPORTED)
+ set_target_properties(protobuf::protoc PROPERTIES IMPORTED_LOCATION
+ "${PROTOBUF_PROTOC_EXECUTABLE}")
+ endif()
+ set(ARROW_PROTOBUF_PROTOC protobuf::protoc)
+ endif()
+
+ # Log protobuf paths as we often see issues with mixed sources for
+ # the libraries and protoc.
+ get_target_property(PROTOBUF_PROTOC_EXECUTABLE ${ARROW_PROTOBUF_PROTOC}
+ IMPORTED_LOCATION)
+ message(STATUS "Found protoc: ${PROTOBUF_PROTOC_EXECUTABLE}")
+ # Protobuf_PROTOC_LIBRARY is set by all versions of FindProtobuf.cmake
+ message(STATUS "Found libprotoc: ${Protobuf_PROTOC_LIBRARY}")
+ get_target_property(PROTOBUF_LIBRARY ${ARROW_PROTOBUF_LIBPROTOBUF} IMPORTED_LOCATION)
+ message(STATUS "Found libprotobuf: ${PROTOBUF_LIBRARY}")
+ message(STATUS "Found protobuf headers: ${PROTOBUF_INCLUDE_DIR}")
+endif()
+
+# ----------------------------------------------------------------------
+# jemalloc - Unix-only high-performance allocator
+
+if(ARROW_JEMALLOC)
+ message(STATUS "Building (vendored) jemalloc from source")
+ # We only use a vendored jemalloc as we want to control its version.
+ # Also our build of jemalloc is specially prefixed so that it will not
+ # conflict with the default allocator as well as other jemalloc
+ # installations.
+ # find_package(jemalloc)
+
+ set(ARROW_JEMALLOC_USE_SHARED OFF)
+ set(JEMALLOC_PREFIX
+ "${CMAKE_CURRENT_BINARY_DIR}/jemalloc_ep-prefix/src/jemalloc_ep/dist/")
+ set(JEMALLOC_LIB_DIR "${JEMALLOC_PREFIX}/lib")
+ set(JEMALLOC_STATIC_LIB
+ "${JEMALLOC_LIB_DIR}/libjemalloc_pic${CMAKE_STATIC_LIBRARY_SUFFIX}")
+ set(JEMALLOC_CONFIGURE_COMMAND ./configure "AR=${CMAKE_AR}" "CC=${CMAKE_C_COMPILER}")
+ if(CMAKE_OSX_SYSROOT)
+ list(APPEND JEMALLOC_CONFIGURE_COMMAND "SDKROOT=${CMAKE_OSX_SYSROOT}")
+ endif()
+ if(DEFINED ARROW_JEMALLOC_LG_PAGE)
+ # Used for arm64 manylinux wheels in order to make the wheel work on both
+ # 4k and 64k page arm64 systems.
+ list(APPEND JEMALLOC_CONFIGURE_COMMAND "--with-lg-page=${ARROW_JEMALLOC_LG_PAGE}")
+ endif()
+ list(APPEND
+ JEMALLOC_CONFIGURE_COMMAND
+ "--prefix=${JEMALLOC_PREFIX}"
+ "--libdir=${JEMALLOC_LIB_DIR}"
+ "--with-jemalloc-prefix=je_arrow_"
+ "--with-private-namespace=je_arrow_private_"
+ "--without-export"
+ "--disable-shared"
+ # Don't override operator new()
+ "--disable-cxx"
+ "--disable-libdl"
+ # See https://github.com/jemalloc/jemalloc/issues/1237
+ "--disable-initial-exec-tls"
+ ${EP_LOG_OPTIONS})
+ set(JEMALLOC_BUILD_COMMAND ${MAKE} ${MAKE_BUILD_ARGS})
+ if(CMAKE_OSX_SYSROOT)
+ list(APPEND JEMALLOC_BUILD_COMMAND "SDKROOT=${CMAKE_OSX_SYSROOT}")
+ endif()
+ externalproject_add(jemalloc_ep
+ URL ${JEMALLOC_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_JEMALLOC_BUILD_SHA256_CHECKSUM}"
+ PATCH_COMMAND touch doc/jemalloc.3 doc/jemalloc.html
+ # The prefix "je_arrow_" must be kept in sync with the value in memory_pool.cc
+ CONFIGURE_COMMAND ${JEMALLOC_CONFIGURE_COMMAND}
+ BUILD_IN_SOURCE 1
+ BUILD_COMMAND ${JEMALLOC_BUILD_COMMAND}
+ BUILD_BYPRODUCTS "${JEMALLOC_STATIC_LIB}"
+ INSTALL_COMMAND ${MAKE} -j1 install)
+
+ # Don't use the include directory directly so that we can point to a path
+ # that is unique to our codebase.
+ include_directories(SYSTEM "${CMAKE_CURRENT_BINARY_DIR}/jemalloc_ep-prefix/src/")
+ # The include directory must exist before it is referenced by a target.
+ file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/jemalloc_ep-prefix/src/")
+ add_library(jemalloc::jemalloc STATIC IMPORTED)
+ set_target_properties(jemalloc::jemalloc
+ PROPERTIES INTERFACE_LINK_LIBRARIES Threads::Threads
+ IMPORTED_LOCATION "${JEMALLOC_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${CMAKE_CURRENT_BINARY_DIR}/jemalloc_ep-prefix/src")
+ add_dependencies(jemalloc::jemalloc jemalloc_ep)
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS jemalloc::jemalloc)
+endif()
+
+# ----------------------------------------------------------------------
+# mimalloc - Cross-platform high-performance allocator, from Microsoft
+
+if(ARROW_MIMALLOC)
+ message(STATUS "Building (vendored) mimalloc from source")
+ # We only use a vendored mimalloc as we want to control its build options.
+
+ set(MIMALLOC_LIB_BASE_NAME "mimalloc")
+ if(WIN32)
+ set(MIMALLOC_LIB_BASE_NAME "${MIMALLOC_LIB_BASE_NAME}-static")
+ endif()
+ if(${UPPERCASE_BUILD_TYPE} STREQUAL "DEBUG")
+ set(MIMALLOC_LIB_BASE_NAME "${MIMALLOC_LIB_BASE_NAME}-${LOWERCASE_BUILD_TYPE}")
+ endif()
+
+ set(MIMALLOC_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/mimalloc_ep/src/mimalloc_ep")
+ set(MIMALLOC_INCLUDE_DIR "${MIMALLOC_PREFIX}/include/mimalloc-1.7")
+ set(MIMALLOC_STATIC_LIB
+ "${MIMALLOC_PREFIX}/lib/mimalloc-1.7/${CMAKE_STATIC_LIBRARY_PREFIX}${MIMALLOC_LIB_BASE_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+
+ set(MIMALLOC_CMAKE_ARGS
+ ${EP_COMMON_CMAKE_ARGS}
+ "-DCMAKE_INSTALL_PREFIX=${MIMALLOC_PREFIX}"
+ -DMI_OVERRIDE=OFF
+ -DMI_LOCAL_DYNAMIC_TLS=ON
+ -DMI_BUILD_OBJECT=OFF
+ -DMI_BUILD_SHARED=OFF
+ -DMI_BUILD_TESTS=OFF)
+
+ externalproject_add(mimalloc_ep
+ URL ${MIMALLOC_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_MIMALLOC_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${MIMALLOC_CMAKE_ARGS}
+ BUILD_BYPRODUCTS "${MIMALLOC_STATIC_LIB}")
+
+ include_directories(SYSTEM ${MIMALLOC_INCLUDE_DIR})
+ file(MAKE_DIRECTORY ${MIMALLOC_INCLUDE_DIR})
+
+ add_library(mimalloc::mimalloc STATIC IMPORTED)
+ set_target_properties(mimalloc::mimalloc
+ PROPERTIES INTERFACE_LINK_LIBRARIES Threads::Threads
+ IMPORTED_LOCATION "${MIMALLOC_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${MIMALLOC_INCLUDE_DIR}")
+ add_dependencies(mimalloc::mimalloc mimalloc_ep)
+ add_dependencies(toolchain mimalloc_ep)
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS mimalloc::mimalloc)
+endif()
+
+# ----------------------------------------------------------------------
+# Google gtest
+
+macro(build_gtest)
+ message(STATUS "Building gtest from source")
+ set(GTEST_VENDORED TRUE)
+ set(GTEST_CMAKE_CXX_FLAGS ${EP_CXX_FLAGS})
+
+ if(CMAKE_BUILD_TYPE MATCHES DEBUG)
+ set(CMAKE_GTEST_DEBUG_EXTENSION "d")
+ else()
+ set(CMAKE_GTEST_DEBUG_EXTENSION "")
+ endif()
+
+ if(APPLE)
+ set(GTEST_CMAKE_CXX_FLAGS ${GTEST_CMAKE_CXX_FLAGS} -DGTEST_USE_OWN_TR1_TUPLE=1
+ -Wno-unused-value -Wno-ignored-attributes)
+ endif()
+
+ if(MSVC)
+ set(GTEST_CMAKE_CXX_FLAGS "${GTEST_CMAKE_CXX_FLAGS} -DGTEST_CREATE_SHARED_LIBRARY=1")
+ endif()
+
+ set(GTEST_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/googletest_ep-prefix")
+ set(GTEST_INCLUDE_DIR "${GTEST_PREFIX}/include")
+
+ set(_GTEST_LIBRARY_DIR "${GTEST_PREFIX}/lib")
+
+ if(MSVC)
+ set(_GTEST_IMPORTED_TYPE IMPORTED_IMPLIB)
+ set(_GTEST_LIBRARY_SUFFIX
+ "${CMAKE_GTEST_DEBUG_EXTENSION}${CMAKE_IMPORT_LIBRARY_SUFFIX}")
+ else()
+ set(_GTEST_IMPORTED_TYPE IMPORTED_LOCATION)
+ set(_GTEST_LIBRARY_SUFFIX
+ "${CMAKE_GTEST_DEBUG_EXTENSION}${CMAKE_SHARED_LIBRARY_SUFFIX}")
+
+ endif()
+
+ set(GTEST_SHARED_LIB
+ "${_GTEST_LIBRARY_DIR}/${CMAKE_SHARED_LIBRARY_PREFIX}gtest${_GTEST_LIBRARY_SUFFIX}")
+ set(GMOCK_SHARED_LIB
+ "${_GTEST_LIBRARY_DIR}/${CMAKE_SHARED_LIBRARY_PREFIX}gmock${_GTEST_LIBRARY_SUFFIX}")
+ set(GTEST_MAIN_SHARED_LIB
+ "${_GTEST_LIBRARY_DIR}/${CMAKE_SHARED_LIBRARY_PREFIX}gtest_main${_GTEST_LIBRARY_SUFFIX}"
+ )
+ set(GTEST_INSTALL_NAME_DIR "$<INSTALL_PREFIX$<ANGLE-R>/lib")
+ # Fix syntax highlighting mess introduced by unclosed bracket above
+ set(dummy ">")
+
+ set(GTEST_CMAKE_ARGS
+ ${EP_COMMON_TOOLCHAIN}
+ -DBUILD_SHARED_LIBS=ON
+ -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
+ -DCMAKE_CXX_FLAGS=${GTEST_CMAKE_CXX_FLAGS}
+ -DCMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}=${GTEST_CMAKE_CXX_FLAGS}
+ -DCMAKE_INSTALL_LIBDIR=lib
+ -DCMAKE_INSTALL_NAME_DIR=${GTEST_INSTALL_NAME_DIR}
+ -DCMAKE_INSTALL_PREFIX=${GTEST_PREFIX}
+ -DCMAKE_MACOSX_RPATH=OFF)
+ set(GMOCK_INCLUDE_DIR "${GTEST_PREFIX}/include")
+
+ add_definitions(-DGTEST_LINKED_AS_SHARED_LIBRARY=1)
+
+ if(MSVC AND NOT ARROW_USE_STATIC_CRT)
+ set(GTEST_CMAKE_ARGS ${GTEST_CMAKE_ARGS} -Dgtest_force_shared_crt=ON)
+ endif()
+
+ externalproject_add(googletest_ep
+ URL ${GTEST_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_GTEST_BUILD_SHA256_CHECKSUM}"
+ BUILD_BYPRODUCTS ${GTEST_SHARED_LIB} ${GTEST_MAIN_SHARED_LIB}
+ ${GMOCK_SHARED_LIB}
+ CMAKE_ARGS ${GTEST_CMAKE_ARGS} ${EP_LOG_OPTIONS})
+ if(WIN32)
+ # Copy the built shared libraries to the same directory as our
+ # test programs because Windows doesn't provided rpath (run-time
+ # search path) feature. We need to put these shared libraries to
+ # the same directory as our test programs or add
+ # _GTEST_LIBRARY_DIR to PATH when we run our test programs. We
+ # choose the former because the latter may be forgotten.
+ set(_GTEST_RUNTIME_DIR "${GTEST_PREFIX}/bin")
+ set(_GTEST_RUNTIME_SUFFIX
+ "${CMAKE_GTEST_DEBUG_EXTENSION}${CMAKE_SHARED_LIBRARY_SUFFIX}")
+ set(_GTEST_RUNTIME_LIB
+ "${_GTEST_RUNTIME_DIR}/${CMAKE_SHARED_LIBRARY_PREFIX}gtest${_GTEST_RUNTIME_SUFFIX}"
+ )
+ set(_GMOCK_RUNTIME_LIB
+ "${_GTEST_RUNTIME_DIR}/${CMAKE_SHARED_LIBRARY_PREFIX}gmock${_GTEST_RUNTIME_SUFFIX}"
+ )
+ set(_GTEST_MAIN_RUNTIME_LIB
+ "${_GTEST_RUNTIME_DIR}/${CMAKE_SHARED_LIBRARY_PREFIX}gtest_main${_GTEST_RUNTIME_SUFFIX}"
+ )
+ if(CMAKE_VERSION VERSION_LESS 3.9)
+ message(FATAL_ERROR "Building GoogleTest from source on Windows requires at least CMake 3.9"
+ )
+ endif()
+ get_property(_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
+ if(_GENERATOR_IS_MULTI_CONFIG)
+ set(_GTEST_RUNTIME_OUTPUT_DIR "${BUILD_OUTPUT_ROOT_DIRECTORY}/${CMAKE_BUILD_TYPE}")
+ else()
+ set(_GTEST_RUNTIME_OUTPUT_DIR ${BUILD_OUTPUT_ROOT_DIRECTORY})
+ endif()
+ externalproject_add_step(googletest_ep copy
+ COMMAND ${CMAKE_COMMAND} -E make_directory
+ ${_GTEST_RUNTIME_OUTPUT_DIR}
+ COMMAND ${CMAKE_COMMAND} -E copy ${_GTEST_RUNTIME_LIB}
+ ${_GTEST_RUNTIME_OUTPUT_DIR}
+ COMMAND ${CMAKE_COMMAND} -E copy ${_GMOCK_RUNTIME_LIB}
+ ${_GTEST_RUNTIME_OUTPUT_DIR}
+ COMMAND ${CMAKE_COMMAND} -E copy ${_GTEST_MAIN_RUNTIME_LIB}
+ ${_GTEST_RUNTIME_OUTPUT_DIR}
+ DEPENDEES install)
+ endif()
+
+ # The include directory must exist before it is referenced by a target.
+ file(MAKE_DIRECTORY "${GTEST_INCLUDE_DIR}")
+
+ add_library(GTest::gtest SHARED IMPORTED)
+ set_target_properties(GTest::gtest
+ PROPERTIES ${_GTEST_IMPORTED_TYPE} "${GTEST_SHARED_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
+
+ add_library(GTest::gtest_main SHARED IMPORTED)
+ set_target_properties(GTest::gtest_main
+ PROPERTIES ${_GTEST_IMPORTED_TYPE} "${GTEST_MAIN_SHARED_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
+
+ add_library(GTest::gmock SHARED IMPORTED)
+ set_target_properties(GTest::gmock
+ PROPERTIES ${_GTEST_IMPORTED_TYPE} "${GMOCK_SHARED_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}")
+ add_dependencies(toolchain-tests googletest_ep)
+ add_dependencies(GTest::gtest googletest_ep)
+ add_dependencies(GTest::gtest_main googletest_ep)
+ add_dependencies(GTest::gmock googletest_ep)
+endmacro()
+
+if(ARROW_TESTING)
+ resolve_dependency(GTest
+ REQUIRED_VERSION
+ 1.10.0
+ USE_CONFIG
+ TRUE)
+
+ if(NOT GTEST_VENDORED)
+ # TODO(wesm): This logic does not work correctly with the MSVC static libraries
+ # built for the shared crt
+
+ # set(CMAKE_REQUIRED_LIBRARIES GTest::GTest GTest::Main GTest::GMock)
+ # CHECK_CXX_SOURCE_COMPILES("
+ # #include <gmock/gmock.h>
+ # #include <gtest/gtest.h>
+
+ # class A {
+ # public:
+ # int run() const { return 1; }
+ # };
+
+ # class B : public A {
+ # public:
+ # MOCK_CONST_METHOD0(run, int());
+ # };
+
+ # TEST(Base, Test) {
+ # B b;
+ # }" GTEST_COMPILES_WITHOUT_MACRO)
+ # if (NOT GTEST_COMPILES_WITHOUT_MACRO)
+ # message(STATUS "Setting GTEST_LINKED_AS_SHARED_LIBRARY=1 on GTest::GTest")
+ # add_compile_definitions("GTEST_LINKED_AS_SHARED_LIBRARY=1")
+ # endif()
+ # set(CMAKE_REQUIRED_LIBRARIES)
+ endif()
+
+ get_target_property(GTEST_INCLUDE_DIR GTest::gtest INTERFACE_INCLUDE_DIRECTORIES)
+ # TODO: Don't use global includes but rather target_include_directories
+ include_directories(SYSTEM ${GTEST_INCLUDE_DIR})
+endif()
+
+macro(build_benchmark)
+ message(STATUS "Building benchmark from source")
+ if(CMAKE_VERSION VERSION_LESS 3.6)
+ message(FATAL_ERROR "Building gbenchmark from source requires at least CMake 3.6")
+ endif()
+
+ if(NOT MSVC)
+ set(GBENCHMARK_CMAKE_CXX_FLAGS "${EP_CXX_FLAGS} -std=c++11")
+ endif()
+
+ if(APPLE AND (CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID
+ STREQUAL "Clang"))
+ set(GBENCHMARK_CMAKE_CXX_FLAGS "${GBENCHMARK_CMAKE_CXX_FLAGS} -stdlib=libc++")
+ endif()
+
+ set(GBENCHMARK_PREFIX
+ "${CMAKE_CURRENT_BINARY_DIR}/gbenchmark_ep/src/gbenchmark_ep-install")
+ set(GBENCHMARK_INCLUDE_DIR "${GBENCHMARK_PREFIX}/include")
+ set(GBENCHMARK_STATIC_LIB
+ "${GBENCHMARK_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}benchmark${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ set(GBENCHMARK_MAIN_STATIC_LIB
+ "${GBENCHMARK_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}benchmark_main${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ set(GBENCHMARK_CMAKE_ARGS
+ ${EP_COMMON_CMAKE_ARGS}
+ "-DCMAKE_INSTALL_PREFIX=${GBENCHMARK_PREFIX}"
+ -DCMAKE_INSTALL_LIBDIR=lib
+ -DBENCHMARK_ENABLE_TESTING=OFF
+ -DCMAKE_CXX_FLAGS=${GBENCHMARK_CMAKE_CXX_FLAGS})
+ if(APPLE)
+ set(GBENCHMARK_CMAKE_ARGS ${GBENCHMARK_CMAKE_ARGS} "-DBENCHMARK_USE_LIBCXX=ON")
+ endif()
+
+ externalproject_add(gbenchmark_ep
+ URL ${GBENCHMARK_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_GBENCHMARK_BUILD_SHA256_CHECKSUM}"
+ BUILD_BYPRODUCTS "${GBENCHMARK_STATIC_LIB}"
+ "${GBENCHMARK_MAIN_STATIC_LIB}"
+ CMAKE_ARGS ${GBENCHMARK_CMAKE_ARGS} ${EP_LOG_OPTIONS})
+
+ # The include directory must exist before it is referenced by a target.
+ file(MAKE_DIRECTORY "${GBENCHMARK_INCLUDE_DIR}")
+
+ add_library(benchmark::benchmark STATIC IMPORTED)
+ set_target_properties(benchmark::benchmark
+ PROPERTIES IMPORTED_LOCATION "${GBENCHMARK_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${GBENCHMARK_INCLUDE_DIR}")
+
+ add_library(benchmark::benchmark_main STATIC IMPORTED)
+ set_target_properties(benchmark::benchmark_main
+ PROPERTIES IMPORTED_LOCATION "${GBENCHMARK_MAIN_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${GBENCHMARK_INCLUDE_DIR}")
+
+ add_dependencies(toolchain-benchmarks gbenchmark_ep)
+ add_dependencies(benchmark::benchmark gbenchmark_ep)
+ add_dependencies(benchmark::benchmark_main gbenchmark_ep)
+endmacro()
+
+if(ARROW_BUILD_BENCHMARKS)
+ # ArgsProduct() is available since 1.5.2
+ set(BENCHMARK_REQUIRED_VERSION 1.5.2)
+ resolve_dependency(benchmark
+ REQUIRED_VERSION
+ ${BENCHMARK_REQUIRED_VERSION}
+ IS_RUNTIME_DEPENDENCY
+ FALSE)
+ # TODO: Don't use global includes but rather target_include_directories
+ get_target_property(BENCHMARK_INCLUDE_DIR benchmark::benchmark
+ INTERFACE_INCLUDE_DIRECTORIES)
+ include_directories(SYSTEM ${BENCHMARK_INCLUDE_DIR})
+endif()
+
+macro(build_rapidjson)
+ message(STATUS "Building RapidJSON from source")
+ set(RAPIDJSON_PREFIX
+ "${CMAKE_CURRENT_BINARY_DIR}/rapidjson_ep/src/rapidjson_ep-install")
+ set(RAPIDJSON_CMAKE_ARGS
+ ${EP_COMMON_CMAKE_ARGS}
+ -DRAPIDJSON_BUILD_DOC=OFF
+ -DRAPIDJSON_BUILD_EXAMPLES=OFF
+ -DRAPIDJSON_BUILD_TESTS=OFF
+ "-DCMAKE_INSTALL_PREFIX=${RAPIDJSON_PREFIX}")
+
+ externalproject_add(rapidjson_ep
+ ${EP_LOG_OPTIONS}
+ PREFIX "${CMAKE_BINARY_DIR}"
+ URL ${RAPIDJSON_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_RAPIDJSON_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${RAPIDJSON_CMAKE_ARGS})
+
+ set(RAPIDJSON_INCLUDE_DIR "${RAPIDJSON_PREFIX}/include")
+
+ add_dependencies(toolchain rapidjson_ep)
+ add_dependencies(toolchain-tests rapidjson_ep)
+ add_dependencies(rapidjson rapidjson_ep)
+
+ set(RAPIDJSON_VENDORED TRUE)
+endmacro()
+
+if(ARROW_WITH_RAPIDJSON)
+ set(ARROW_RAPIDJSON_REQUIRED_VERSION "1.1.0")
+ resolve_dependency(RapidJSON
+ HAVE_ALT
+ TRUE
+ REQUIRED_VERSION
+ ${ARROW_RAPIDJSON_REQUIRED_VERSION}
+ IS_RUNTIME_DEPENDENCY
+ FALSE)
+
+ if(RapidJSON_INCLUDE_DIR)
+ set(RAPIDJSON_INCLUDE_DIR "${RapidJSON_INCLUDE_DIR}")
+ endif()
+
+ # TODO: Don't use global includes but rather target_include_directories
+ include_directories(SYSTEM ${RAPIDJSON_INCLUDE_DIR})
+endif()
+
+macro(build_xsimd)
+ message(STATUS "Building xsimd from source")
+ set(XSIMD_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/xsimd_ep/src/xsimd_ep-install")
+ set(XSIMD_CMAKE_ARGS ${EP_COMMON_CMAKE_ARGS} "-DCMAKE_INSTALL_PREFIX=${XSIMD_PREFIX}")
+
+ externalproject_add(xsimd_ep
+ ${EP_LOG_OPTIONS}
+ PREFIX "${CMAKE_BINARY_DIR}"
+ URL ${XSIMD_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_XSIMD_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${XSIMD_CMAKE_ARGS})
+
+ set(XSIMD_INCLUDE_DIR "${XSIMD_PREFIX}/include")
+
+ add_dependencies(toolchain xsimd_ep)
+ add_dependencies(toolchain-tests xsimd_ep)
+
+ set(XSIMD_VENDORED TRUE)
+endmacro()
+
+if((NOT ARROW_SIMD_LEVEL STREQUAL "NONE") OR (NOT ARROW_RUNTIME_SIMD_LEVEL STREQUAL "NONE"
+ ))
+ set(xsimd_SOURCE "BUNDLED")
+ resolve_dependency(xsimd)
+ # TODO: Don't use global includes but rather target_include_directories
+ include_directories(SYSTEM ${XSIMD_INCLUDE_DIR})
+endif()
+
+macro(build_zlib)
+ message(STATUS "Building ZLIB from source")
+ set(ZLIB_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/zlib_ep/src/zlib_ep-install")
+ if(MSVC)
+ if(${UPPERCASE_BUILD_TYPE} STREQUAL "DEBUG")
+ set(ZLIB_STATIC_LIB_NAME zlibstaticd.lib)
+ else()
+ set(ZLIB_STATIC_LIB_NAME zlibstatic.lib)
+ endif()
+ else()
+ set(ZLIB_STATIC_LIB_NAME libz.a)
+ endif()
+ set(ZLIB_STATIC_LIB "${ZLIB_PREFIX}/lib/${ZLIB_STATIC_LIB_NAME}")
+ set(ZLIB_CMAKE_ARGS ${EP_COMMON_CMAKE_ARGS} "-DCMAKE_INSTALL_PREFIX=${ZLIB_PREFIX}"
+ -DBUILD_SHARED_LIBS=OFF)
+
+ externalproject_add(zlib_ep
+ URL ${ZLIB_SOURCE_URL} ${EP_LOG_OPTIONS}
+ URL_HASH "SHA256=${ARROW_ZLIB_BUILD_SHA256_CHECKSUM}"
+ BUILD_BYPRODUCTS "${ZLIB_STATIC_LIB}"
+ CMAKE_ARGS ${ZLIB_CMAKE_ARGS})
+
+ file(MAKE_DIRECTORY "${ZLIB_PREFIX}/include")
+
+ add_library(ZLIB::ZLIB STATIC IMPORTED)
+ set(ZLIB_LIBRARIES ${ZLIB_STATIC_LIB})
+ set(ZLIB_INCLUDE_DIRS "${ZLIB_PREFIX}/include")
+ set_target_properties(ZLIB::ZLIB
+ PROPERTIES IMPORTED_LOCATION ${ZLIB_LIBRARIES}
+ INTERFACE_INCLUDE_DIRECTORIES ${ZLIB_INCLUDE_DIRS})
+
+ add_dependencies(toolchain zlib_ep)
+ add_dependencies(ZLIB::ZLIB zlib_ep)
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS ZLIB::ZLIB)
+endmacro()
+
+if(ARROW_WITH_ZLIB)
+ resolve_dependency(ZLIB PC_PACKAGE_NAMES zlib)
+
+ # TODO: Don't use global includes but rather target_include_directories
+ get_target_property(ZLIB_INCLUDE_DIR ZLIB::ZLIB INTERFACE_INCLUDE_DIRECTORIES)
+ include_directories(SYSTEM ${ZLIB_INCLUDE_DIR})
+endif()
+
+macro(build_lz4)
+ message(STATUS "Building lz4 from source")
+ set(LZ4_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/lz4_ep-prefix/src/lz4_ep")
+ set(LZ4_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/lz4_ep-prefix")
+
+ if(MSVC)
+ if(ARROW_USE_STATIC_CRT)
+ if(${UPPERCASE_BUILD_TYPE} STREQUAL "DEBUG")
+ set(LZ4_RUNTIME_LIBRARY_LINKAGE "/p:RuntimeLibrary=MultiThreadedDebug")
+ else()
+ set(LZ4_RUNTIME_LIBRARY_LINKAGE "/p:RuntimeLibrary=MultiThreaded")
+ endif()
+ endif()
+ set(LZ4_STATIC_LIB
+ "${LZ4_BUILD_DIR}/build/VS2010/bin/x64_${CMAKE_BUILD_TYPE}/liblz4_static.lib")
+ set(LZ4_BUILD_COMMAND
+ BUILD_COMMAND msbuild.exe /m /p:Configuration=${CMAKE_BUILD_TYPE} /p:Platform=x64
+ /p:PlatformToolset=v140 ${LZ4_RUNTIME_LIBRARY_LINKAGE} /t:Build
+ ${LZ4_BUILD_DIR}/build/VS2010/lz4.sln)
+ else()
+ set(LZ4_STATIC_LIB "${LZ4_BUILD_DIR}/lib/liblz4.a")
+ # Must explicitly invoke sh on MinGW
+ set(LZ4_BUILD_COMMAND
+ BUILD_COMMAND sh "${CMAKE_SOURCE_DIR}/build-support/build-lz4-lib.sh"
+ "AR=${CMAKE_AR}" "OS=${CMAKE_SYSTEM_NAME}")
+ endif()
+
+ # We need to copy the header in lib to directory outside of the build
+ externalproject_add(lz4_ep
+ URL ${LZ4_SOURCE_URL} ${EP_LOG_OPTIONS}
+ URL_HASH "SHA256=${ARROW_LZ4_BUILD_SHA256_CHECKSUM}"
+ UPDATE_COMMAND ${CMAKE_COMMAND} -E copy_directory
+ "${LZ4_BUILD_DIR}/lib" "${LZ4_PREFIX}/include"
+ ${LZ4_PATCH_COMMAND}
+ CONFIGURE_COMMAND ""
+ INSTALL_COMMAND ""
+ BINARY_DIR ${LZ4_BUILD_DIR}
+ BUILD_BYPRODUCTS ${LZ4_STATIC_LIB} ${LZ4_BUILD_COMMAND})
+
+ file(MAKE_DIRECTORY "${LZ4_PREFIX}/include")
+ add_library(LZ4::lz4 STATIC IMPORTED)
+ set_target_properties(LZ4::lz4
+ PROPERTIES IMPORTED_LOCATION "${LZ4_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${LZ4_PREFIX}/include")
+ add_dependencies(toolchain lz4_ep)
+ add_dependencies(LZ4::lz4 lz4_ep)
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS LZ4::lz4)
+endmacro()
+
+if(ARROW_WITH_LZ4)
+ resolve_dependency(Lz4 PC_PACKAGE_NAMES liblz4)
+
+ # TODO: Don't use global includes but rather target_include_directories
+ get_target_property(LZ4_INCLUDE_DIR LZ4::lz4 INTERFACE_INCLUDE_DIRECTORIES)
+ include_directories(SYSTEM ${LZ4_INCLUDE_DIR})
+endif()
+
+macro(build_zstd)
+ message(STATUS "Building zstd from source")
+ set(ZSTD_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/zstd_ep-install")
+
+ set(ZSTD_CMAKE_ARGS
+ ${EP_COMMON_TOOLCHAIN}
+ "-DCMAKE_INSTALL_PREFIX=${ZSTD_PREFIX}"
+ -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
+ -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR}
+ -DZSTD_BUILD_PROGRAMS=off
+ -DZSTD_BUILD_SHARED=off
+ -DZSTD_BUILD_STATIC=on
+ -DZSTD_MULTITHREAD_SUPPORT=off)
+
+ if(MSVC)
+ set(ZSTD_STATIC_LIB "${ZSTD_PREFIX}/${CMAKE_INSTALL_LIBDIR}/zstd_static.lib")
+ if(ARROW_USE_STATIC_CRT)
+ set(ZSTD_CMAKE_ARGS ${ZSTD_CMAKE_ARGS} "-DZSTD_USE_STATIC_RUNTIME=on")
+ endif()
+ else()
+ set(ZSTD_STATIC_LIB "${ZSTD_PREFIX}/${CMAKE_INSTALL_LIBDIR}/libzstd.a")
+ # Only pass our C flags on Unix as on MSVC it leads to a
+ # "incompatible command-line options" error
+ set(ZSTD_CMAKE_ARGS
+ ${ZSTD_CMAKE_ARGS}
+ -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
+ -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
+ -DCMAKE_C_FLAGS=${EP_C_FLAGS}
+ -DCMAKE_CXX_FLAGS=${EP_CXX_FLAGS})
+ endif()
+
+ if(CMAKE_VERSION VERSION_LESS 3.7)
+ message(FATAL_ERROR "Building zstd using ExternalProject requires at least CMake 3.7")
+ endif()
+
+ externalproject_add(zstd_ep
+ ${EP_LOG_OPTIONS}
+ CMAKE_ARGS ${ZSTD_CMAKE_ARGS}
+ SOURCE_SUBDIR "build/cmake"
+ INSTALL_DIR ${ZSTD_PREFIX}
+ URL ${ZSTD_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_ZSTD_BUILD_SHA256_CHECKSUM}"
+ BUILD_BYPRODUCTS "${ZSTD_STATIC_LIB}")
+
+ file(MAKE_DIRECTORY "${ZSTD_PREFIX}/include")
+
+ add_library(zstd::libzstd STATIC IMPORTED)
+ set_target_properties(zstd::libzstd
+ PROPERTIES IMPORTED_LOCATION "${ZSTD_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${ZSTD_PREFIX}/include")
+
+ add_dependencies(toolchain zstd_ep)
+ add_dependencies(zstd::libzstd zstd_ep)
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS zstd::libzstd)
+endmacro()
+
+if(ARROW_WITH_ZSTD)
+ # ARROW-13384: ZSTD_minCLevel was added in v1.4.0, required by ARROW-13091
+ resolve_dependency(zstd
+ PC_PACKAGE_NAMES
+ libzstd
+ REQUIRED_VERSION
+ 1.4.0)
+
+ if(TARGET zstd::libzstd)
+ set(ARROW_ZSTD_LIBZSTD zstd::libzstd)
+ else()
+ # "SYSTEM" source will prioritize cmake config, which exports
+ # zstd::libzstd_{static,shared}
+ if(ARROW_ZSTD_USE_SHARED)
+ if(TARGET zstd::libzstd_shared)
+ set(ARROW_ZSTD_LIBZSTD zstd::libzstd_shared)
+ endif()
+ else()
+ if(TARGET zstd::libzstd_static)
+ set(ARROW_ZSTD_LIBZSTD zstd::libzstd_static)
+ endif()
+ endif()
+ endif()
+
+ # TODO: Don't use global includes but rather target_include_directories
+ get_target_property(ZSTD_INCLUDE_DIR ${ARROW_ZSTD_LIBZSTD}
+ INTERFACE_INCLUDE_DIRECTORIES)
+ include_directories(SYSTEM ${ZSTD_INCLUDE_DIR})
+endif()
+
+# ----------------------------------------------------------------------
+# RE2 (required for Gandiva)
+
+macro(build_re2)
+ message(STATUS "Building RE2 from source")
+ set(RE2_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/re2_ep-install")
+ set(RE2_STATIC_LIB
+ "${RE2_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}re2${CMAKE_STATIC_LIBRARY_SUFFIX}")
+
+ set(RE2_CMAKE_ARGS ${EP_COMMON_CMAKE_ARGS} "-DCMAKE_INSTALL_PREFIX=${RE2_PREFIX}"
+ -DCMAKE_INSTALL_LIBDIR=lib)
+
+ externalproject_add(re2_ep
+ ${EP_LOG_OPTIONS}
+ INSTALL_DIR ${RE2_PREFIX}
+ URL ${RE2_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_RE2_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${RE2_CMAKE_ARGS}
+ BUILD_BYPRODUCTS "${RE2_STATIC_LIB}")
+
+ file(MAKE_DIRECTORY "${RE2_PREFIX}/include")
+ add_library(re2::re2 STATIC IMPORTED)
+ set_target_properties(re2::re2
+ PROPERTIES IMPORTED_LOCATION "${RE2_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${RE2_PREFIX}/include")
+
+ add_dependencies(toolchain re2_ep)
+ add_dependencies(re2::re2 re2_ep)
+ set(RE2_VENDORED TRUE)
+ # Set values so that FindRE2 finds this too
+ set(RE2_LIB ${RE2_STATIC_LIB})
+ set(RE2_INCLUDE_DIR "${RE2_PREFIX}/include")
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS re2::re2)
+endmacro()
+
+if(ARROW_WITH_RE2)
+ # Don't specify "PC_PACKAGE_NAMES re2" here because re2.pc may
+ # include -std=c++11. It's not compatible with C source and C++
+ # source not uses C++ 11.
+ resolve_dependency(re2 HAVE_ALT TRUE)
+ if(${re2_SOURCE} STREQUAL "SYSTEM")
+ get_target_property(RE2_LIB re2::re2 IMPORTED_LOCATION)
+ string(APPEND ARROW_PC_LIBS_PRIVATE " ${RE2_LIB}")
+ endif()
+ add_definitions(-DARROW_WITH_RE2)
+
+ # TODO: Don't use global includes but rather target_include_directories
+ get_target_property(RE2_INCLUDE_DIR re2::re2 INTERFACE_INCLUDE_DIRECTORIES)
+ include_directories(SYSTEM ${RE2_INCLUDE_DIR})
+endif()
+
+macro(build_bzip2)
+ message(STATUS "Building BZip2 from source")
+ set(BZIP2_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/bzip2_ep-install")
+ set(BZIP2_STATIC_LIB
+ "${BZIP2_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}bz2${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+
+ set(BZIP2_EXTRA_ARGS "CC=${CMAKE_C_COMPILER}" "CFLAGS=${EP_C_FLAGS}")
+
+ if(CMAKE_OSX_SYSROOT)
+ list(APPEND BZIP2_EXTRA_ARGS "SDKROOT=${CMAKE_OSX_SYSROOT}")
+ endif()
+
+ if(CMAKE_AR)
+ list(APPEND BZIP2_EXTRA_ARGS AR=${CMAKE_AR})
+ endif()
+
+ if(CMAKE_RANLIB)
+ list(APPEND BZIP2_EXTRA_ARGS RANLIB=${CMAKE_RANLIB})
+ endif()
+
+ externalproject_add(bzip2_ep
+ ${EP_LOG_OPTIONS}
+ CONFIGURE_COMMAND ""
+ BUILD_IN_SOURCE 1
+ BUILD_COMMAND ${MAKE} libbz2.a ${MAKE_BUILD_ARGS}
+ ${BZIP2_EXTRA_ARGS}
+ INSTALL_COMMAND ${MAKE} install PREFIX=${BZIP2_PREFIX}
+ ${BZIP2_EXTRA_ARGS}
+ INSTALL_DIR ${BZIP2_PREFIX}
+ URL ${ARROW_BZIP2_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_BZIP2_BUILD_SHA256_CHECKSUM}"
+ BUILD_BYPRODUCTS "${BZIP2_STATIC_LIB}")
+
+ file(MAKE_DIRECTORY "${BZIP2_PREFIX}/include")
+ add_library(BZip2::BZip2 STATIC IMPORTED)
+ set_target_properties(BZip2::BZip2
+ PROPERTIES IMPORTED_LOCATION "${BZIP2_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${BZIP2_PREFIX}/include")
+ set(BZIP2_INCLUDE_DIR "${BZIP2_PREFIX}/include")
+
+ add_dependencies(toolchain bzip2_ep)
+ add_dependencies(BZip2::BZip2 bzip2_ep)
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS BZip2::BZip2)
+endmacro()
+
+if(ARROW_WITH_BZ2)
+ resolve_dependency(BZip2)
+ if(${BZip2_SOURCE} STREQUAL "SYSTEM")
+ string(APPEND ARROW_PC_LIBS_PRIVATE " ${BZIP2_LIBRARIES}")
+ endif()
+
+ if(NOT TARGET BZip2::BZip2)
+ add_library(BZip2::BZip2 UNKNOWN IMPORTED)
+ set_target_properties(BZip2::BZip2
+ PROPERTIES IMPORTED_LOCATION "${BZIP2_LIBRARIES}"
+ INTERFACE_INCLUDE_DIRECTORIES "${BZIP2_INCLUDE_DIR}")
+ endif()
+ include_directories(SYSTEM "${BZIP2_INCLUDE_DIR}")
+endif()
+
+macro(build_utf8proc)
+ message(STATUS "Building utf8proc from source")
+ set(UTF8PROC_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/utf8proc_ep-install")
+ if(MSVC)
+ set(UTF8PROC_STATIC_LIB "${UTF8PROC_PREFIX}/lib/utf8proc_static.lib")
+ else()
+ set(UTF8PROC_STATIC_LIB
+ "${UTF8PROC_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}utf8proc${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ endif()
+
+ set(UTF8PROC_CMAKE_ARGS
+ ${EP_COMMON_TOOLCHAIN}
+ "-DCMAKE_INSTALL_PREFIX=${UTF8PROC_PREFIX}"
+ -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
+ -DCMAKE_INSTALL_LIBDIR=lib
+ -DBUILD_SHARED_LIBS=OFF)
+
+ externalproject_add(utf8proc_ep
+ ${EP_LOG_OPTIONS}
+ CMAKE_ARGS ${UTF8PROC_CMAKE_ARGS}
+ INSTALL_DIR ${UTF8PROC_PREFIX}
+ URL ${ARROW_UTF8PROC_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_UTF8PROC_BUILD_SHA256_CHECKSUM}"
+ BUILD_BYPRODUCTS "${UTF8PROC_STATIC_LIB}")
+
+ file(MAKE_DIRECTORY "${UTF8PROC_PREFIX}/include")
+ add_library(utf8proc::utf8proc STATIC IMPORTED)
+ set_target_properties(utf8proc::utf8proc
+ PROPERTIES IMPORTED_LOCATION "${UTF8PROC_STATIC_LIB}"
+ INTERFACE_COMPILER_DEFINITIONS "UTF8PROC_STATIC"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${UTF8PROC_PREFIX}/include")
+
+ add_dependencies(toolchain utf8proc_ep)
+ add_dependencies(utf8proc::utf8proc utf8proc_ep)
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS utf8proc::utf8proc)
+endmacro()
+
+if(ARROW_WITH_UTF8PROC)
+ resolve_dependency(utf8proc
+ REQUIRED_VERSION
+ "2.2.0"
+ PC_PACKAGE_NAMES
+ libutf8proc)
+
+ add_definitions(-DARROW_WITH_UTF8PROC)
+
+ # TODO: Don't use global definitions but rather
+ # target_compile_definitions or target_link_libraries
+ get_target_property(UTF8PROC_COMPILER_DEFINITIONS utf8proc::utf8proc
+ INTERFACE_COMPILER_DEFINITIONS)
+ if(UTF8PROC_COMPILER_DEFINITIONS)
+ add_definitions(-D${UTF8PROC_COMPILER_DEFINITIONS})
+ endif()
+
+ # TODO: Don't use global includes but rather
+ # target_include_directories or target_link_libraries
+ get_target_property(UTF8PROC_INCLUDE_DIR utf8proc::utf8proc
+ INTERFACE_INCLUDE_DIRECTORIES)
+ include_directories(SYSTEM ${UTF8PROC_INCLUDE_DIR})
+endif()
+
+macro(build_cares)
+ message(STATUS "Building c-ares from source")
+ set(CARES_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/cares_ep-install")
+ set(CARES_INCLUDE_DIR "${CARES_PREFIX}/include")
+
+ # If you set -DCARES_SHARED=ON then the build system names the library
+ # libcares_static.a
+ set(CARES_STATIC_LIB
+ "${CARES_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}cares${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+
+ set(CARES_CMAKE_ARGS
+ "${EP_COMMON_CMAKE_ARGS}"
+ -DCARES_STATIC=ON
+ -DCARES_SHARED=OFF
+ -DCMAKE_INSTALL_LIBDIR=lib
+ "-DCMAKE_INSTALL_PREFIX=${CARES_PREFIX}")
+
+ externalproject_add(cares_ep
+ ${EP_LOG_OPTIONS}
+ URL ${CARES_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_CARES_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${CARES_CMAKE_ARGS}
+ BUILD_BYPRODUCTS "${CARES_STATIC_LIB}")
+
+ file(MAKE_DIRECTORY ${CARES_INCLUDE_DIR})
+
+ add_dependencies(toolchain cares_ep)
+ add_library(c-ares::cares STATIC IMPORTED)
+ set_target_properties(c-ares::cares
+ PROPERTIES IMPORTED_LOCATION "${CARES_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${CARES_INCLUDE_DIR}")
+ add_dependencies(c-ares::cares cares_ep)
+
+ if(APPLE)
+ # libresolv must be linked from c-ares version 1.16.1
+ find_library(LIBRESOLV_LIBRARY NAMES resolv libresolv REQUIRED)
+ set_target_properties(c-ares::cares PROPERTIES INTERFACE_LINK_LIBRARIES
+ "${LIBRESOLV_LIBRARY}")
+ endif()
+
+ set(CARES_VENDORED TRUE)
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS c-ares::cares)
+endmacro()
+
+# ----------------------------------------------------------------------
+# Dependencies for Arrow Flight RPC
+
+macro(build_absl_once)
+ if(NOT TARGET absl_ep)
+ message(STATUS "Building Abseil-cpp from source")
+ set(ABSL_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/absl_ep-install")
+ set(ABSL_INCLUDE_DIR "${ABSL_PREFIX}/include")
+ set(ABSL_CMAKE_ARGS
+ "${EP_COMMON_CMAKE_ARGS}" -DABSL_RUN_TESTS=OFF -DCMAKE_INSTALL_LIBDIR=lib
+ "-DCMAKE_INSTALL_PREFIX=${ABSL_PREFIX}")
+ set(ABSL_BUILD_BYPRODUCTS)
+ set(ABSL_LIBRARIES)
+
+ # Abseil produces the following libraries, each is fairly small, but there
+ # are (as you can see), many of them. We need to add the libraries first,
+ # and then describe how they depend on each other. The list can be
+ # refreshed using:
+ # ls -1 $PREFIX/lib/libabsl_*.a | sed -e 's/.*libabsl_//' -e 's/.a$//'
+ set(_ABSL_LIBS
+ bad_any_cast_impl
+ bad_optional_access
+ bad_variant_access
+ base
+ city
+ civil_time
+ cord
+ debugging_internal
+ demangle_internal
+ examine_stack
+ exponential_biased
+ failure_signal_handler
+ flags
+ flags_commandlineflag
+ flags_commandlineflag_internal
+ flags_config
+ flags_internal
+ flags_marshalling
+ flags_parse
+ flags_private_handle_accessor
+ flags_program_name
+ flags_reflection
+ flags_usage
+ flags_usage_internal
+ graphcycles_internal
+ hash
+ hashtablez_sampler
+ int128
+ leak_check
+ leak_check_disable
+ log_severity
+ malloc_internal
+ periodic_sampler
+ random_distributions
+ random_internal_distribution_test_util
+ random_internal_platform
+ random_internal_pool_urbg
+ random_internal_randen
+ random_internal_randen_hwaes
+ random_internal_randen_hwaes_impl
+ random_internal_randen_slow
+ random_internal_seed_material
+ random_seed_gen_exception
+ random_seed_sequences
+ raw_hash_set
+ raw_logging_internal
+ scoped_set_env
+ spinlock_wait
+ stacktrace
+ status
+ statusor
+ strerror
+ str_format_internal
+ strings
+ strings_internal
+ symbolize
+ synchronization
+ throw_delegate
+ time
+ time_zone
+ wyhash)
+ # Abseil creates a number of header-only targets, which are needed to resolve dependencies.
+ # The list can be refreshed using:
+ # comm -13 <(ls -l $PREFIX/lib/libabsl_*.a | sed -e 's/.*libabsl_//' -e 's/.a$//' | sort -u) \
+ # <(ls -1 $PREFIX/lib/pkgconfig/absl_*.pc | sed -e 's/.*absl_//' -e 's/.pc$//' | sort -u)
+ set(_ABSL_INTERFACE_LIBS
+ algorithm
+ algorithm_container
+ any
+ atomic_hook
+ bad_any_cast
+ base_internal
+ bind_front
+ bits
+ btree
+ cleanup
+ cleanup_internal
+ compare
+ compressed_tuple
+ config
+ container_common
+ container_memory
+ core_headers
+ counting_allocator
+ debugging
+ dynamic_annotations
+ endian
+ errno_saver
+ fast_type_id
+ fixed_array
+ flags_path_util
+ flat_hash_map
+ flat_hash_set
+ function_ref
+ hash_function_defaults
+ hash_policy_traits
+ hashtable_debug
+ hashtable_debug_hooks
+ have_sse
+ inlined_vector
+ inlined_vector_internal
+ kernel_timeout_internal
+ layout
+ memory
+ meta
+ node_hash_map
+ node_hash_policy
+ node_hash_set
+ numeric
+ numeric_representation
+ optional
+ pretty_function
+ random_bit_gen_ref
+ random_internal_distribution_caller
+ random_internal_fastmath
+ random_internal_fast_uniform_bits
+ random_internal_generate_real
+ random_internal_iostream_state_saver
+ random_internal_mock_helpers
+ random_internal_nonsecure_base
+ random_internal_pcg_engine
+ random_internal_randen_engine
+ random_internal_salted_seed_seq
+ random_internal_traits
+ random_internal_uniform_helper
+ random_internal_wide_multiply
+ random_random
+ raw_hash_map
+ span
+ str_format
+ type_traits
+ utility
+ variant)
+
+ foreach(_ABSL_LIB ${_ABSL_LIBS})
+ set(_ABSL_STATIC_LIBRARY
+ "${ABSL_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}absl_${_ABSL_LIB}${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ add_library(absl::${_ABSL_LIB} STATIC IMPORTED)
+ set_target_properties(absl::${_ABSL_LIB}
+ PROPERTIES IMPORTED_LOCATION ${_ABSL_STATIC_LIBRARY}
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${ABSL_INCLUDE_DIR}")
+ list(APPEND ABSL_BUILD_BYPRODUCTS ${_ABSL_STATIC_LIBRARY})
+ endforeach()
+ foreach(_ABSL_LIB ${_ABSL_INTERFACE_LIBS})
+ add_library(absl::${_ABSL_LIB} INTERFACE IMPORTED)
+ set_target_properties(absl::${_ABSL_LIB} PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
+ "${ABSL_INCLUDE_DIR}")
+ endforeach()
+
+ # Extracted the dependency information using the Abseil pkg-config files:
+ # grep Requires $PREFIX/pkgconfig/absl_*.pc | \
+ # sed -e 's;.*/absl_;set_property(TARGET absl::;' \
+ # -e 's/.pc:Requires:/ PROPERTY INTERFACE_LINK_LIBRARIES /' \
+ # -e 's/ = 20210324,//g' \
+ # -e 's/ = 20210324//g' \
+ # -e 's/absl_/absl::/g' \
+ # -e 's/$/)/' | \
+ # grep -v 'INTERFACE_LINK_LIBRARIES[ ]*)'
+ set_property(TARGET absl::algorithm_container
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::algorithm absl::core_headers
+ absl::meta)
+ set_property(TARGET absl::algorithm PROPERTY INTERFACE_LINK_LIBRARIES absl::config)
+ set_property(TARGET absl::any
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::bad_any_cast
+ absl::config
+ absl::core_headers
+ absl::fast_type_id
+ absl::type_traits
+ absl::utility)
+ set_property(TARGET absl::atomic_hook PROPERTY INTERFACE_LINK_LIBRARIES absl::config
+ absl::core_headers)
+ set_property(TARGET absl::bad_any_cast_impl
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::config
+ absl::raw_logging_internal)
+ set_property(TARGET absl::bad_any_cast PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::bad_any_cast_impl absl::config)
+ set_property(TARGET absl::bad_optional_access
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::config
+ absl::raw_logging_internal)
+ set_property(TARGET absl::bad_variant_access
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::config
+ absl::raw_logging_internal)
+ set_property(TARGET absl::base_internal PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config absl::type_traits)
+ set_property(TARGET absl::base
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::atomic_hook
+ absl::base_internal
+ absl::config
+ absl::core_headers
+ absl::dynamic_annotations
+ absl::log_severity
+ absl::raw_logging_internal
+ absl::spinlock_wait
+ absl::type_traits)
+ set_property(TARGET absl::bind_front
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::base_internal
+ absl::compressed_tuple)
+ set_property(TARGET absl::bits PROPERTY INTERFACE_LINK_LIBRARIES absl::core_headers)
+ set_property(TARGET absl::btree
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::container_common
+ absl::compare
+ absl::compressed_tuple
+ absl::container_memory
+ absl::cord
+ absl::core_headers
+ absl::layout
+ absl::memory
+ absl::strings
+ absl::throw_delegate
+ absl::type_traits
+ absl::utility)
+ set_property(TARGET absl::city PROPERTY INTERFACE_LINK_LIBRARIES absl::config
+ absl::core_headers absl::endian)
+ set_property(TARGET absl::cleanup_internal
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::base_internal absl::core_headers
+ absl::utility)
+ set_property(TARGET absl::cleanup
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::cleanup_internal absl::config
+ absl::core_headers)
+ set_property(TARGET absl::compare PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::core_headers absl::type_traits)
+ set_property(TARGET absl::compressed_tuple PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::utility)
+ set_property(TARGET absl::container_common PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::type_traits)
+ set_property(TARGET absl::container_memory
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::memory
+ absl::type_traits
+ absl::utility)
+ set_property(TARGET absl::cord
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::base
+ absl::base_internal
+ absl::compressed_tuple
+ absl::config
+ absl::core_headers
+ absl::endian
+ absl::fixed_array
+ absl::function_ref
+ absl::inlined_vector
+ absl::optional
+ absl::raw_logging_internal
+ absl::strings
+ absl::strings_internal
+ absl::throw_delegate
+ absl::type_traits)
+ set_property(TARGET absl::core_headers PROPERTY INTERFACE_LINK_LIBRARIES absl::config)
+ set_property(TARGET absl::counting_allocator PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config)
+ set_property(TARGET absl::debugging_internal
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::core_headers
+ absl::config
+ absl::dynamic_annotations
+ absl::errno_saver
+ absl::raw_logging_internal)
+ set_property(TARGET absl::debugging PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::stacktrace absl::leak_check)
+ set_property(TARGET absl::demangle_internal PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::base absl::core_headers)
+ set_property(TARGET absl::dynamic_annotations PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config)
+ set_property(TARGET absl::endian PROPERTY INTERFACE_LINK_LIBRARIES absl::base
+ absl::config absl::core_headers)
+ set_property(TARGET absl::errno_saver PROPERTY INTERFACE_LINK_LIBRARIES absl::config)
+ set_property(TARGET absl::examine_stack
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::stacktrace
+ absl::symbolize
+ absl::config
+ absl::core_headers
+ absl::raw_logging_internal)
+ set_property(TARGET absl::exponential_biased PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config absl::core_headers)
+ set_property(TARGET absl::failure_signal_handler
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::examine_stack
+ absl::stacktrace
+ absl::base
+ absl::config
+ absl::core_headers
+ absl::errno_saver
+ absl::raw_logging_internal)
+ set_property(TARGET absl::fast_type_id PROPERTY INTERFACE_LINK_LIBRARIES absl::config)
+ set_property(TARGET absl::fixed_array
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::compressed_tuple
+ absl::algorithm
+ absl::config
+ absl::core_headers
+ absl::dynamic_annotations
+ absl::throw_delegate
+ absl::memory)
+ set_property(TARGET absl::flags_commandlineflag_internal
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::config absl::fast_type_id)
+ set_property(TARGET absl::flags_commandlineflag
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::fast_type_id
+ absl::flags_commandlineflag_internal
+ absl::optional
+ absl::strings)
+ set_property(TARGET absl::flags_config
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::flags_path_util
+ absl::flags_program_name
+ absl::core_headers
+ absl::strings
+ absl::synchronization)
+ set_property(TARGET absl::flags_internal
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::base
+ absl::config
+ absl::flags_commandlineflag
+ absl::flags_commandlineflag_internal
+ absl::flags_config
+ absl::flags_marshalling
+ absl::synchronization
+ absl::meta
+ absl::utility)
+ set_property(TARGET absl::flags_marshalling
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::core_headers
+ absl::log_severity
+ absl::strings
+ absl::str_format)
+ set_property(TARGET absl::flags_parse
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::core_headers
+ absl::flags_config
+ absl::flags
+ absl::flags_commandlineflag
+ absl::flags_commandlineflag_internal
+ absl::flags_internal
+ absl::flags_private_handle_accessor
+ absl::flags_program_name
+ absl::flags_reflection
+ absl::flags_usage
+ absl::strings
+ absl::synchronization)
+ set_property(TARGET absl::flags_path_util PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config absl::strings)
+ set_property(TARGET absl::flags
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::flags_commandlineflag
+ absl::flags_config
+ absl::flags_internal
+ absl::flags_reflection
+ absl::base
+ absl::core_headers
+ absl::strings)
+ set_property(TARGET absl::flags_private_handle_accessor
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::flags_commandlineflag
+ absl::flags_commandlineflag_internal
+ absl::strings)
+ set_property(TARGET absl::flags_program_name
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::core_headers
+ absl::flags_path_util
+ absl::strings
+ absl::synchronization)
+ set_property(TARGET absl::flags_reflection
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::flags_commandlineflag
+ absl::flags_private_handle_accessor
+ absl::flags_config
+ absl::strings
+ absl::synchronization
+ absl::flat_hash_map)
+ set_property(TARGET absl::flags_usage_internal
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::flags_config
+ absl::flags
+ absl::flags_commandlineflag
+ absl::flags_internal
+ absl::flags_path_util
+ absl::flags_private_handle_accessor
+ absl::flags_program_name
+ absl::flags_reflection
+ absl::strings
+ absl::synchronization)
+ set_property(TARGET absl::flags_usage
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::core_headers
+ absl::flags_usage_internal
+ absl::strings
+ absl::synchronization)
+ set_property(TARGET absl::flat_hash_map
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::container_memory
+ absl::hash_function_defaults
+ absl::raw_hash_map
+ absl::algorithm_container
+ absl::memory)
+ set_property(TARGET absl::flat_hash_set
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::container_memory
+ absl::hash_function_defaults
+ absl::raw_hash_set
+ absl::algorithm_container
+ absl::core_headers
+ absl::memory)
+ set_property(TARGET absl::function_ref PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::base_internal absl::meta)
+ set_property(TARGET absl::graphcycles_internal
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::base
+ absl::base_internal
+ absl::config
+ absl::core_headers
+ absl::malloc_internal
+ absl::raw_logging_internal)
+ set_property(TARGET absl::hash_function_defaults
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::cord
+ absl::hash
+ absl::strings)
+ set_property(TARGET absl::hash
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::city
+ absl::config
+ absl::core_headers
+ absl::endian
+ absl::fixed_array
+ absl::meta
+ absl::int128
+ absl::strings
+ absl::optional
+ absl::variant
+ absl::utility
+ absl::wyhash)
+ set_property(TARGET absl::hash_policy_traits PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::meta)
+ set_property(TARGET absl::hashtable_debug_hooks PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config)
+ set_property(TARGET absl::hashtable_debug PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::hashtable_debug_hooks)
+ set_property(TARGET absl::hashtablez_sampler
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::base
+ absl::exponential_biased
+ absl::have_sse
+ absl::synchronization)
+ set_property(TARGET absl::inlined_vector_internal
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::compressed_tuple
+ absl::core_headers
+ absl::memory
+ absl::span
+ absl::type_traits)
+ set_property(TARGET absl::inlined_vector
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::algorithm
+ absl::core_headers
+ absl::inlined_vector_internal
+ absl::throw_delegate
+ absl::memory)
+ set_property(TARGET absl::int128 PROPERTY INTERFACE_LINK_LIBRARIES absl::config
+ absl::core_headers absl::bits)
+ set_property(TARGET absl::kernel_timeout_internal
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::core_headers
+ absl::raw_logging_internal absl::time)
+ set_property(TARGET absl::layout
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::core_headers
+ absl::meta
+ absl::strings
+ absl::span
+ absl::utility)
+ set_property(TARGET absl::leak_check PROPERTY INTERFACE_LINK_LIBRARIES absl::config
+ absl::core_headers)
+ set_property(TARGET absl::log_severity PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::core_headers)
+ set_property(TARGET absl::malloc_internal
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::base
+ absl::base_internal
+ absl::config
+ absl::core_headers
+ absl::dynamic_annotations
+ absl::raw_logging_internal)
+ set_property(TARGET absl::memory PROPERTY INTERFACE_LINK_LIBRARIES absl::core_headers
+ absl::meta)
+ set_property(TARGET absl::meta PROPERTY INTERFACE_LINK_LIBRARIES absl::type_traits)
+ set_property(TARGET absl::node_hash_map
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::container_memory
+ absl::hash_function_defaults
+ absl::node_hash_policy
+ absl::raw_hash_map
+ absl::algorithm_container
+ absl::memory)
+ set_property(TARGET absl::node_hash_policy PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config)
+ set_property(TARGET absl::node_hash_set
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::hash_function_defaults
+ absl::node_hash_policy
+ absl::raw_hash_set
+ absl::algorithm_container
+ absl::memory)
+ set_property(TARGET absl::numeric PROPERTY INTERFACE_LINK_LIBRARIES absl::int128)
+ set_property(TARGET absl::numeric_representation PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config)
+ set_property(TARGET absl::optional
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::bad_optional_access
+ absl::base_internal
+ absl::config
+ absl::core_headers
+ absl::memory
+ absl::type_traits
+ absl::utility)
+ set_property(TARGET absl::periodic_sampler
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::core_headers
+ absl::exponential_biased)
+ set_property(TARGET absl::random_bit_gen_ref
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::core_headers
+ absl::random_internal_distribution_caller
+ absl::random_internal_fast_uniform_bits
+ absl::type_traits)
+ set_property(TARGET absl::random_distributions
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::base_internal
+ absl::config
+ absl::core_headers
+ absl::random_internal_generate_real
+ absl::random_internal_distribution_caller
+ absl::random_internal_fast_uniform_bits
+ absl::random_internal_fastmath
+ absl::random_internal_iostream_state_saver
+ absl::random_internal_traits
+ absl::random_internal_uniform_helper
+ absl::random_internal_wide_multiply
+ absl::strings
+ absl::type_traits)
+ set_property(TARGET absl::random_internal_distribution_caller
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::config absl::utility
+ absl::fast_type_id)
+ set_property(TARGET absl::random_internal_distribution_test_util
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::core_headers
+ absl::raw_logging_internal
+ absl::strings
+ absl::str_format
+ absl::span)
+ set_property(TARGET absl::random_internal_fastmath PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::bits)
+ set_property(TARGET absl::random_internal_fast_uniform_bits
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::config)
+ set_property(TARGET absl::random_internal_generate_real
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::bits
+ absl::random_internal_fastmath
+ absl::random_internal_traits
+ absl::type_traits)
+ set_property(TARGET absl::random_internal_iostream_state_saver
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::int128 absl::type_traits)
+ set_property(TARGET absl::random_internal_mock_helpers
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::fast_type_id absl::optional)
+ set_property(TARGET absl::random_internal_nonsecure_base
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::core_headers
+ absl::optional
+ absl::random_internal_pool_urbg
+ absl::random_internal_salted_seed_seq
+ absl::random_internal_seed_material
+ absl::span
+ absl::type_traits)
+ set_property(TARGET absl::random_internal_pcg_engine
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::int128
+ absl::random_internal_fastmath
+ absl::random_internal_iostream_state_saver
+ absl::type_traits)
+ set_property(TARGET absl::random_internal_platform PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config)
+ set_property(TARGET absl::random_internal_pool_urbg
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::base
+ absl::config
+ absl::core_headers
+ absl::endian
+ absl::random_internal_randen
+ absl::random_internal_seed_material
+ absl::random_internal_traits
+ absl::random_seed_gen_exception
+ absl::raw_logging_internal
+ absl::span)
+ set_property(TARGET absl::random_internal_randen_engine
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::endian
+ absl::random_internal_iostream_state_saver
+ absl::random_internal_randen
+ absl::raw_logging_internal
+ absl::type_traits)
+ set_property(TARGET absl::random_internal_randen_hwaes_impl
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::random_internal_platform
+ absl::config)
+ set_property(TARGET absl::random_internal_randen_hwaes
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::random_internal_platform
+ absl::random_internal_randen_hwaes_impl absl::config)
+ set_property(TARGET absl::random_internal_randen
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::random_internal_platform
+ absl::random_internal_randen_hwaes
+ absl::random_internal_randen_slow)
+ set_property(TARGET absl::random_internal_randen_slow
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::random_internal_platform
+ absl::config)
+ set_property(TARGET absl::random_internal_salted_seed_seq
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::inlined_vector
+ absl::optional
+ absl::span
+ absl::random_internal_seed_material
+ absl::type_traits)
+ set_property(TARGET absl::random_internal_seed_material
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::core_headers
+ absl::optional
+ absl::random_internal_fast_uniform_bits
+ absl::raw_logging_internal
+ absl::span
+ absl::strings)
+ set_property(TARGET absl::random_internal_traits PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config)
+ set_property(TARGET absl::random_internal_uniform_helper
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::config
+ absl::random_internal_traits absl::type_traits)
+ set_property(TARGET absl::random_internal_wide_multiply
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::bits absl::config absl::int128)
+ set_property(TARGET absl::random_random
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::random_distributions
+ absl::random_internal_nonsecure_base
+ absl::random_internal_pcg_engine
+ absl::random_internal_pool_urbg
+ absl::random_internal_randen_engine
+ absl::random_seed_sequences)
+ set_property(TARGET absl::random_seed_gen_exception PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config)
+ set_property(TARGET absl::random_seed_sequences
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::inlined_vector
+ absl::random_internal_nonsecure_base
+ absl::random_internal_pool_urbg
+ absl::random_internal_salted_seed_seq
+ absl::random_internal_seed_material
+ absl::random_seed_gen_exception
+ absl::span)
+ set_property(TARGET absl::raw_hash_map
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::container_memory
+ absl::raw_hash_set absl::throw_delegate)
+ set_property(TARGET absl::raw_hash_set
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::bits
+ absl::compressed_tuple
+ absl::config
+ absl::container_common
+ absl::container_memory
+ absl::core_headers
+ absl::endian
+ absl::hash_policy_traits
+ absl::hashtable_debug_hooks
+ absl::have_sse
+ absl::layout
+ absl::memory
+ absl::meta
+ absl::optional
+ absl::utility
+ absl::hashtablez_sampler)
+ set_property(TARGET absl::raw_logging_internal
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::atomic_hook
+ absl::config
+ absl::core_headers
+ absl::log_severity)
+ set_property(TARGET absl::scoped_set_env
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::config
+ absl::raw_logging_internal)
+ set_property(TARGET absl::span
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::algorithm
+ absl::core_headers
+ absl::throw_delegate
+ absl::type_traits)
+ set_property(TARGET absl::spinlock_wait
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::base_internal absl::core_headers
+ absl::errno_saver)
+ set_property(TARGET absl::stacktrace
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::debugging_internal absl::config
+ absl::core_headers)
+ set_property(TARGET absl::statusor
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::status
+ absl::core_headers
+ absl::raw_logging_internal
+ absl::type_traits
+ absl::strings
+ absl::utility
+ absl::variant)
+ set_property(TARGET absl::status
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::atomic_hook
+ absl::config
+ absl::core_headers
+ absl::raw_logging_internal
+ absl::inlined_vector
+ absl::stacktrace
+ absl::symbolize
+ absl::strings
+ absl::cord
+ absl::str_format
+ absl::optional)
+ set_property(TARGET absl::strerror PROPERTY INTERFACE_LINK_LIBRARIES absl::config
+ absl::core_headers absl::errno_saver)
+ set_property(TARGET absl::str_format_internal
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::bits
+ absl::strings
+ absl::config
+ absl::core_headers
+ absl::numeric_representation
+ absl::type_traits
+ absl::int128
+ absl::span)
+ set_property(TARGET absl::str_format PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::str_format_internal)
+ set_property(TARGET absl::strings_internal
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::config
+ absl::core_headers
+ absl::endian
+ absl::raw_logging_internal
+ absl::type_traits)
+ set_property(TARGET absl::strings
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::strings_internal
+ absl::base
+ absl::bits
+ absl::config
+ absl::core_headers
+ absl::endian
+ absl::int128
+ absl::memory
+ absl::raw_logging_internal
+ absl::throw_delegate
+ absl::type_traits)
+ set_property(TARGET absl::symbolize
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::debugging_internal
+ absl::demangle_internal
+ absl::base
+ absl::config
+ absl::core_headers
+ absl::dynamic_annotations
+ absl::malloc_internal
+ absl::raw_logging_internal
+ absl::strings)
+ set_property(TARGET absl::synchronization
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::graphcycles_internal
+ absl::kernel_timeout_internal
+ absl::atomic_hook
+ absl::base
+ absl::base_internal
+ absl::config
+ absl::core_headers
+ absl::dynamic_annotations
+ absl::malloc_internal
+ absl::raw_logging_internal
+ absl::stacktrace
+ absl::symbolize
+ absl::time)
+ set_property(TARGET absl::throw_delegate
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::config
+ absl::raw_logging_internal)
+ set_property(TARGET absl::time
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::base
+ absl::civil_time
+ absl::core_headers
+ absl::int128
+ absl::raw_logging_internal
+ absl::strings
+ absl::time_zone)
+ set_property(TARGET absl::type_traits PROPERTY INTERFACE_LINK_LIBRARIES absl::config)
+ set_property(TARGET absl::utility
+ PROPERTY INTERFACE_LINK_LIBRARIES absl::base_internal absl::config
+ absl::type_traits)
+ set_property(TARGET absl::variant
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::bad_variant_access
+ absl::base_internal
+ absl::config
+ absl::core_headers
+ absl::type_traits
+ absl::utility)
+ set_property(TARGET absl::wyhash PROPERTY INTERFACE_LINK_LIBRARIES absl::config
+ absl::endian absl::int128)
+
+ externalproject_add(absl_ep
+ ${EP_LOG_OPTIONS}
+ URL ${ABSL_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_ABSL_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${ABSL_CMAKE_ARGS}
+ BUILD_BYPRODUCTS ${ABSL_BUILD_BYPRODUCTS})
+
+ # Work around https://gitlab.kitware.com/cmake/cmake/issues/15052
+ file(MAKE_DIRECTORY ${ABSL_INCLUDE_DIR})
+
+ endif()
+endmacro()
+
+macro(build_grpc)
+ resolve_dependency(c-ares
+ HAVE_ALT
+ TRUE
+ PC_PACKAGE_NAMES
+ libcares)
+ # TODO: Don't use global includes but rather target_include_directories
+ get_target_property(c-ares_INCLUDE_DIR c-ares::cares INTERFACE_INCLUDE_DIRECTORIES)
+ include_directories(SYSTEM ${c-ares_INCLUDE_DIR})
+
+ # First need to build Abseil
+ build_absl_once()
+
+ message(STATUS "Building gRPC from source")
+
+ set(GRPC_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/grpc_ep-prefix/src/grpc_ep-build")
+ set(GRPC_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/grpc_ep-install")
+ set(GRPC_HOME "${GRPC_PREFIX}")
+ set(GRPC_INCLUDE_DIR "${GRPC_PREFIX}/include")
+
+ set(GRPC_STATIC_LIBRARY_GPR
+ "${GRPC_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gpr${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ set(GRPC_STATIC_LIBRARY_GRPC
+ "${GRPC_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}grpc${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ set(GRPC_STATIC_LIBRARY_GRPCPP
+ "${GRPC_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}grpc++${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ set(GRPC_STATIC_LIBRARY_ADDRESS_SORTING
+ "${GRPC_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}address_sorting${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ set(GRPC_STATIC_LIBRARY_UPB
+ "${GRPC_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}upb${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ set(GRPC_CPP_PLUGIN "${GRPC_PREFIX}/bin/grpc_cpp_plugin${CMAKE_EXECUTABLE_SUFFIX}")
+
+ set(GRPC_CMAKE_PREFIX)
+
+ add_custom_target(grpc_dependencies)
+
+ add_dependencies(grpc_dependencies absl_ep)
+ if(CARES_VENDORED)
+ add_dependencies(grpc_dependencies cares_ep)
+ endif()
+
+ if(GFLAGS_VENDORED)
+ add_dependencies(grpc_dependencies gflags_ep)
+ endif()
+
+ if(RE2_VENDORED)
+ add_dependencies(grpc_dependencies re2_ep)
+ endif()
+
+ add_dependencies(grpc_dependencies ${ARROW_PROTOBUF_LIBPROTOBUF} c-ares::cares
+ ZLIB::ZLIB)
+
+ get_target_property(GRPC_PROTOBUF_INCLUDE_DIR ${ARROW_PROTOBUF_LIBPROTOBUF}
+ INTERFACE_INCLUDE_DIRECTORIES)
+ get_filename_component(GRPC_PB_ROOT "${GRPC_PROTOBUF_INCLUDE_DIR}" DIRECTORY)
+ get_target_property(GRPC_Protobuf_PROTOC_LIBRARY ${ARROW_PROTOBUF_LIBPROTOC}
+ IMPORTED_LOCATION)
+ get_target_property(GRPC_CARES_INCLUDE_DIR c-ares::cares INTERFACE_INCLUDE_DIRECTORIES)
+ get_filename_component(GRPC_CARES_ROOT "${GRPC_CARES_INCLUDE_DIR}" DIRECTORY)
+ get_target_property(GRPC_GFLAGS_INCLUDE_DIR ${GFLAGS_LIBRARIES}
+ INTERFACE_INCLUDE_DIRECTORIES)
+ get_filename_component(GRPC_GFLAGS_ROOT "${GRPC_GFLAGS_INCLUDE_DIR}" DIRECTORY)
+ get_target_property(GRPC_RE2_INCLUDE_DIR re2::re2 INTERFACE_INCLUDE_DIRECTORIES)
+ get_filename_component(GRPC_RE2_ROOT "${GRPC_RE2_INCLUDE_DIR}" DIRECTORY)
+
+ set(GRPC_CMAKE_PREFIX "${GRPC_CMAKE_PREFIX};${GRPC_PB_ROOT}")
+ set(GRPC_CMAKE_PREFIX "${GRPC_CMAKE_PREFIX};${GRPC_GFLAGS_ROOT}")
+ set(GRPC_CMAKE_PREFIX "${GRPC_CMAKE_PREFIX};${GRPC_CARES_ROOT}")
+ set(GRPC_CMAKE_PREFIX "${GRPC_CMAKE_PREFIX};${GRPC_RE2_ROOT}")
+
+ # ZLIB is never vendored
+ set(GRPC_CMAKE_PREFIX "${GRPC_CMAKE_PREFIX};${ZLIB_ROOT}")
+ set(GRPC_CMAKE_PREFIX "${GRPC_CMAKE_PREFIX};${ABSL_PREFIX}")
+
+ if(RAPIDJSON_VENDORED)
+ add_dependencies(grpc_dependencies rapidjson_ep)
+ endif()
+
+ # Yuck, see https://stackoverflow.com/a/45433229/776560
+ string(REPLACE ";" "|" GRPC_PREFIX_PATH_ALT_SEP "${GRPC_CMAKE_PREFIX}")
+
+ set(GRPC_CMAKE_ARGS
+ "${EP_COMMON_CMAKE_ARGS}"
+ -DCMAKE_PREFIX_PATH='${GRPC_PREFIX_PATH_ALT_SEP}'
+ -DgRPC_ABSL_PROVIDER=package
+ -DgRPC_BUILD_CSHARP_EXT=OFF
+ -DgRPC_BUILD_GRPC_CSHARP_PLUGIN=OFF
+ -DgRPC_BUILD_GRPC_NODE_PLUGIN=OFF
+ -DgRPC_BUILD_GRPC_OBJECTIVE_C_PLUGIN=OFF
+ -DgRPC_BUILD_GRPC_PHP_PLUGIN=OFF
+ -DgRPC_BUILD_GRPC_PYTHON_PLUGIN=OFF
+ -DgRPC_BUILD_GRPC_RUBY_PLUGIN=OFF
+ -DgRPC_BUILD_TESTS=OFF
+ -DgRPC_CARES_PROVIDER=package
+ -DgRPC_GFLAGS_PROVIDER=package
+ -DgRPC_MSVC_STATIC_RUNTIME=${ARROW_USE_STATIC_CRT}
+ -DgRPC_PROTOBUF_PROVIDER=package
+ -DgRPC_RE2_PROVIDER=package
+ -DgRPC_SSL_PROVIDER=package
+ -DgRPC_ZLIB_PROVIDER=package
+ -DCMAKE_INSTALL_PREFIX=${GRPC_PREFIX}
+ -DCMAKE_INSTALL_LIBDIR=lib
+ -DBUILD_SHARED_LIBS=OFF)
+ if(PROTOBUF_VENDORED)
+ list(APPEND GRPC_CMAKE_ARGS -DgRPC_PROTOBUF_PACKAGE_TYPE=CONFIG)
+ endif()
+ if(OPENSSL_ROOT_DIR)
+ list(APPEND GRPC_CMAKE_ARGS -DOPENSSL_ROOT_DIR=${OPENSSL_ROOT_DIR})
+ endif()
+
+ # XXX the gRPC git checkout is huge and takes a long time
+ # Ideally, we should be able to use the tarballs, but they don't contain
+ # vendored dependencies such as c-ares...
+ externalproject_add(grpc_ep
+ URL ${GRPC_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_GRPC_BUILD_SHA256_CHECKSUM}"
+ LIST_SEPARATOR |
+ BUILD_BYPRODUCTS ${GRPC_STATIC_LIBRARY_GPR}
+ ${GRPC_STATIC_LIBRARY_GRPC}
+ ${GRPC_STATIC_LIBRARY_GRPCPP}
+ ${GRPC_STATIC_LIBRARY_ADDRESS_SORTING}
+ ${GRPC_STATIC_LIBRARY_UPB}
+ ${GRPC_CPP_PLUGIN}
+ CMAKE_ARGS ${GRPC_CMAKE_ARGS} ${EP_LOG_OPTIONS}
+ DEPENDS ${grpc_dependencies})
+
+ # Work around https://gitlab.kitware.com/cmake/cmake/issues/15052
+ file(MAKE_DIRECTORY ${GRPC_INCLUDE_DIR})
+
+ add_library(gRPC::upb STATIC IMPORTED)
+ set_target_properties(gRPC::upb
+ PROPERTIES IMPORTED_LOCATION "${GRPC_STATIC_LIBRARY_UPB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${GRPC_INCLUDE_DIR}")
+
+ set(GRPC_GPR_ABSL_LIBRARIES
+ absl::base
+ absl::statusor
+ absl::status
+ absl::cord
+ absl::strings
+ absl::synchronization
+ absl::time)
+ add_library(gRPC::gpr STATIC IMPORTED)
+ set_target_properties(gRPC::gpr
+ PROPERTIES IMPORTED_LOCATION "${GRPC_STATIC_LIBRARY_GPR}"
+ INTERFACE_INCLUDE_DIRECTORIES "${GRPC_INCLUDE_DIR}"
+ INTERFACE_LINK_LIBRARIES "${GRPC_GPR_ABSL_LIBRARIES}")
+
+ add_library(gRPC::address_sorting STATIC IMPORTED)
+ set_target_properties(gRPC::address_sorting
+ PROPERTIES IMPORTED_LOCATION
+ "${GRPC_STATIC_LIBRARY_ADDRESS_SORTING}"
+ INTERFACE_INCLUDE_DIRECTORIES "${GRPC_INCLUDE_DIR}")
+
+ add_library(gRPC::grpc STATIC IMPORTED)
+ set(GRPC_LINK_LIBRARIES
+ gRPC::gpr
+ gRPC::upb
+ gRPC::address_sorting
+ re2::re2
+ c-ares::cares
+ ZLIB::ZLIB
+ OpenSSL::SSL
+ Threads::Threads)
+ set_target_properties(gRPC::grpc
+ PROPERTIES IMPORTED_LOCATION "${GRPC_STATIC_LIBRARY_GRPC}"
+ INTERFACE_INCLUDE_DIRECTORIES "${GRPC_INCLUDE_DIR}"
+ INTERFACE_LINK_LIBRARIES "${GRPC_LINK_LIBRARIES}")
+
+ add_library(gRPC::grpc++ STATIC IMPORTED)
+ set(GRPCPP_LINK_LIBRARIES gRPC::grpc ${ARROW_PROTOBUF_LIBPROTOBUF})
+ set_target_properties(gRPC::grpc++
+ PROPERTIES IMPORTED_LOCATION "${GRPC_STATIC_LIBRARY_GRPCPP}"
+ INTERFACE_INCLUDE_DIRECTORIES "${GRPC_INCLUDE_DIR}"
+ INTERFACE_LINK_LIBRARIES "${GRPCPP_LINK_LIBRARIES}")
+
+ add_executable(gRPC::grpc_cpp_plugin IMPORTED)
+ set_target_properties(gRPC::grpc_cpp_plugin PROPERTIES IMPORTED_LOCATION
+ ${GRPC_CPP_PLUGIN})
+
+ add_dependencies(grpc_ep grpc_dependencies)
+ add_dependencies(toolchain grpc_ep)
+ add_dependencies(gRPC::grpc++ grpc_ep)
+ add_dependencies(gRPC::grpc_cpp_plugin grpc_ep)
+ set(GRPC_VENDORED TRUE)
+
+ # ar -M rejects with the "libgrpc++.a" filename because "+" is a line
+ # continuation character in these scripts, so we have to create a copy of the
+ # static lib that we will bundle later
+
+ set(GRPC_STATIC_LIBRARY_GRPCPP_FOR_AR
+ "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}grpcpp${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ add_custom_command(OUTPUT ${GRPC_STATIC_LIBRARY_GRPCPP_FOR_AR}
+ COMMAND ${CMAKE_COMMAND} -E copy $<TARGET_FILE:gRPC::grpc++>
+ ${GRPC_STATIC_LIBRARY_GRPCPP_FOR_AR}
+ DEPENDS grpc_ep)
+ add_library(gRPC::grpcpp_for_bundling STATIC IMPORTED)
+ set_target_properties(gRPC::grpcpp_for_bundling
+ PROPERTIES IMPORTED_LOCATION
+ "${GRPC_STATIC_LIBRARY_GRPCPP_FOR_AR}")
+
+ set_source_files_properties("${GRPC_STATIC_LIBRARY_GRPCPP_FOR_AR}" PROPERTIES GENERATED
+ TRUE)
+ add_custom_target(grpc_copy_grpc++ ALL DEPENDS "${GRPC_STATIC_LIBRARY_GRPCPP_FOR_AR}")
+ add_dependencies(gRPC::grpcpp_for_bundling grpc_copy_grpc++)
+
+ list(APPEND
+ ARROW_BUNDLED_STATIC_LIBS
+ ${GRPC_GPR_ABSL_LIBRARIES}
+ gRPC::address_sorting
+ gRPC::gpr
+ gRPC::grpc
+ gRPC::grpcpp_for_bundling
+ gRPC::upb)
+endmacro()
+
+if(ARROW_WITH_GRPC)
+ set(ARROW_GRPC_REQUIRED_VERSION "1.17.0")
+ resolve_dependency(gRPC
+ HAVE_ALT
+ TRUE
+ REQUIRED_VERSION
+ ${ARROW_GRPC_REQUIRED_VERSION}
+ PC_PACKAGE_NAMES
+ grpc++)
+
+ # TODO: Don't use global includes but rather target_include_directories
+ get_target_property(GRPC_INCLUDE_DIR gRPC::grpc++ INTERFACE_INCLUDE_DIRECTORIES)
+ include_directories(SYSTEM ${GRPC_INCLUDE_DIR})
+
+ if(GRPC_VENDORED)
+ set(GRPCPP_PP_INCLUDE TRUE)
+ else()
+ # grpc++ headers may reside in ${GRPC_INCLUDE_DIR}/grpc++ or ${GRPC_INCLUDE_DIR}/grpcpp
+ # depending on the gRPC version.
+ if(EXISTS "${GRPC_INCLUDE_DIR}/grpcpp/impl/codegen/config_protobuf.h")
+ set(GRPCPP_PP_INCLUDE TRUE)
+ elseif(EXISTS "${GRPC_INCLUDE_DIR}/grpc++/impl/codegen/config_protobuf.h")
+ set(GRPCPP_PP_INCLUDE FALSE)
+ else()
+ message(FATAL_ERROR "Cannot find grpc++ headers in ${GRPC_INCLUDE_DIR}")
+ endif()
+ endif()
+endif()
+
+# ----------------------------------------------------------------------
+# GCS and dependencies
+
+macro(build_crc32c_once)
+ if(NOT TARGET crc32c_ep)
+ message(STATUS "Building crc32c from source")
+ # Build crc32c
+ set(CRC32C_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/crc32c_ep-install")
+ set(CRC32C_INCLUDE_DIR "${CRC32C_PREFIX}/include")
+ set(CRC32C_CMAKE_ARGS
+ ${EP_COMMON_CMAKE_ARGS}
+ -DCMAKE_INSTALL_LIBDIR=lib
+ "-DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>"
+ -DCMAKE_CXX_STANDARD=11
+ -DCRC32C_BUILD_TESTS=OFF
+ -DCRC32C_BUILD_BENCHMARKS=OFF
+ -DCRC32C_USE_GLOG=OFF)
+
+ set(_CRC32C_STATIC_LIBRARY
+ "${CRC32C_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}crc32c${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ set(CRC32C_BUILD_BYPRODUCTS ${_CRC32C_STATIC_LIBRARY})
+ set(CRC32C_LIBRARIES crc32c)
+
+ externalproject_add(crc32c_ep
+ ${EP_LOG_OPTIONS}
+ INSTALL_DIR ${CRC32C_PREFIX}
+ URL ${CRC32C_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_CRC32C_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${CRC32C_CMAKE_ARGS}
+ BUILD_BYPRODUCTS ${CRC32C_BUILD_BYPRODUCTS})
+ # Work around https://gitlab.kitware.com/cmake/cmake/issues/15052
+ file(MAKE_DIRECTORY "${CRC32C_INCLUDE_DIR}")
+ add_library(Crc32c::crc32c STATIC IMPORTED)
+ set_target_properties(Crc32c::crc32c
+ PROPERTIES IMPORTED_LOCATION ${_CRC32C_STATIC_LIBRARY}
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${CRC32C_INCLUDE_DIR}")
+ add_dependencies(Crc32c::crc32c crc32c_ep)
+ endif()
+endmacro()
+
+macro(build_nlohmann_json_once)
+ if(NOT TARGET nlohmann_json_ep)
+ message(STATUS "Building nlohmann-json from source")
+ # "Build" nlohmann-json
+ set(NLOHMANN_JSON_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/nlohmann_json_ep-install")
+ set(NLOHMANN_JSON_INCLUDE_DIR "${NLOHMANN_JSON_PREFIX}/include")
+ set(NLOHMANN_JSON_CMAKE_ARGS
+ ${EP_COMMON_CMAKE_ARGS} -DCMAKE_CXX_STANDARD=11
+ "-DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>" -DBUILD_TESTING=OFF)
+
+ set(NLOHMANN_JSON_BUILD_BYPRODUCTS ${NLOHMANN_JSON_PREFIX}/include/nlohmann/json.hpp)
+
+ externalproject_add(nlohmann_json_ep
+ ${EP_LOG_OPTIONS}
+ INSTALL_DIR ${NLOHMANN_JSON_PREFIX}
+ URL ${NLOHMANN_JSON_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_NLOHMANN_JSON_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${NLOHMANN_JSON_CMAKE_ARGS}
+ BUILD_BYPRODUCTS ${NLOHMANN_JSON_BUILD_BYPRODUCTS})
+
+ # Work around https://gitlab.kitware.com/cmake/cmake/issues/15052
+ file(MAKE_DIRECTORY ${NLOHMANN_JSON_INCLUDE_DIR})
+
+ add_library(nlohmann_json::nlohmann_json INTERFACE IMPORTED)
+ set_target_properties(nlohmann_json::nlohmann_json
+ PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
+ "${NLOHMANN_JSON_INCLUDE_DIR}")
+ add_dependencies(nlohmann_json::nlohmann_json nlohmann_json_ep)
+ endif()
+endmacro()
+
+macro(build_google_cloud_cpp_storage)
+ message(STATUS "Building google-cloud-cpp from source")
+ message(STATUS "Only building the google-cloud-cpp::storage component")
+
+ # List of dependencies taken from https://github.com/googleapis/google-cloud-cpp/blob/master/doc/packaging.md
+ build_absl_once()
+ build_crc32c_once()
+ build_nlohmann_json_once()
+
+ # Curl is required on all platforms, but building it internally might also trip over S3's copy.
+ # For now, force its inclusion from the underlying system or fail.
+ find_package(CURL 7.47.0 REQUIRED)
+ find_package(OpenSSL ${ARROW_OPENSSL_REQUIRED_VERSION} REQUIRED)
+
+ # Build google-cloud-cpp, with only storage_client
+
+ # Inject vendored packages via CMAKE_PREFIX_PATH
+ list(APPEND GOOGLE_CLOUD_CPP_PREFIX_PATH_LIST ${ABSL_PREFIX})
+ list(APPEND GOOGLE_CLOUD_CPP_PREFIX_PATH_LIST ${CRC32C_PREFIX})
+ list(APPEND GOOGLE_CLOUD_CPP_PREFIX_PATH_LIST ${NLOHMANN_JSON_PREFIX})
+
+ set(GOOGLE_CLOUD_CPP_PREFIX_PATH_LIST_SEP_CHAR "|")
+ list(JOIN GOOGLE_CLOUD_CPP_PREFIX_PATH_LIST
+ ${GOOGLE_CLOUD_CPP_PREFIX_PATH_LIST_SEP_CHAR} GOOGLE_CLOUD_CPP_PREFIX_PATH)
+
+ set(GOOGLE_CLOUD_CPP_INSTALL_PREFIX
+ "${CMAKE_CURRENT_BINARY_DIR}/google_cloud_cpp_ep-install")
+ set(GOOGLE_CLOUD_CPP_INCLUDE_DIR "${GOOGLE_CLOUD_CPP_INSTALL_PREFIX}/include")
+ set(GOOGLE_CLOUD_CPP_CMAKE_ARGS
+ ${EP_COMMON_CMAKE_ARGS}
+ -DBUILD_TESTING=OFF
+ -DCMAKE_INSTALL_LIBDIR=lib
+ "-DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>"
+ -DCMAKE_INSTALL_RPATH=$ORIGIN
+ -DCMAKE_PREFIX_PATH=${GOOGLE_CLOUD_CPP_PREFIX_PATH}
+ # Compile only the storage library and its dependencies. To enable
+ # other services (Spanner, Bigtable, etc.) add them (as a list) to this
+ # parameter. Each has its own `google-cloud-cpp::*` library.
+ -DGOOGLE_CLOUD_CPP_ENABLE=storage)
+ if(OPENSSL_ROOT_DIR)
+ list(APPEND GOOGLE_CLOUD_CPP_CMAKE_ARGS -DOPENSSL_ROOT_DIR=${OPENSSL_ROOT_DIR})
+ endif()
+
+ add_custom_target(google_cloud_cpp_dependencies)
+
+ add_dependencies(google_cloud_cpp_dependencies absl_ep)
+ add_dependencies(google_cloud_cpp_dependencies crc32c_ep)
+ add_dependencies(google_cloud_cpp_dependencies nlohmann_json_ep)
+ # Typically the steps to build the AWKSSDK provide `CURL::libcurl`, but if that is
+ # disabled we need to provide our own.
+ if(NOT TARGET CURL::libcurl)
+ find_package(CURL REQUIRED)
+ if(NOT TARGET CURL::libcurl)
+ # For CMake 3.11 or older
+ add_library(CURL::libcurl UNKNOWN IMPORTED)
+ set_target_properties(CURL::libcurl
+ PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
+ "${CURL_INCLUDE_DIRS}" IMPORTED_LOCATION
+ "${CURL_LIBRARIES}")
+ endif()
+ endif()
+
+ set(GOOGLE_CLOUD_CPP_STATIC_LIBRARY_STORAGE
+ "${GOOGLE_CLOUD_CPP_INSTALL_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}google_cloud_cpp_storage${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+
+ set(GOOGLE_CLOUD_CPP_STATIC_LIBRARY_COMMON
+ "${GOOGLE_CLOUD_CPP_INSTALL_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}google_cloud_cpp_common${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+
+ externalproject_add(google_cloud_cpp_ep
+ ${EP_LOG_OPTIONS}
+ LIST_SEPARATOR ${GOOGLE_CLOUD_CPP_PREFIX_PATH_LIST_SEP_CHAR}
+ INSTALL_DIR ${GOOGLE_CLOUD_CPP_INSTALL_PREFIX}
+ URL ${google_cloud_cpp_storage_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_GOOGLE_CLOUD_CPP_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${GOOGLE_CLOUD_CPP_CMAKE_ARGS}
+ BUILD_BYPRODUCTS ${GOOGLE_CLOUD_CPP_STATIC_LIBRARY_STORAGE}
+ ${GOOGLE_CLOUD_CPP_STATIC_LIBRARY_COMMON}
+ DEPENDS google_cloud_cpp_dependencies)
+
+ # Work around https://gitlab.kitware.com/cmake/cmake/issues/15052
+ file(MAKE_DIRECTORY ${GOOGLE_CLOUD_CPP_INCLUDE_DIR})
+
+ add_dependencies(toolchain google_cloud_cpp_ep)
+
+ add_library(google-cloud-cpp::common STATIC IMPORTED)
+ set_target_properties(google-cloud-cpp::common
+ PROPERTIES IMPORTED_LOCATION
+ "${GOOGLE_CLOUD_CPP_STATIC_LIBRARY_COMMON}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${GOOGLE_CLOUD_CPP_INCLUDE_DIR}")
+ set_property(TARGET google-cloud-cpp::common
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ absl::any
+ absl::flat_hash_map
+ absl::memory
+ absl::optional
+ absl::time
+ Threads::Threads)
+
+ add_library(google-cloud-cpp::storage STATIC IMPORTED)
+ set_target_properties(google-cloud-cpp::storage
+ PROPERTIES IMPORTED_LOCATION
+ "${GOOGLE_CLOUD_CPP_STATIC_LIBRARY_STORAGE}"
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${GOOGLE_CLOUD_CPP_INCLUDE_DIR}")
+ set_property(TARGET google-cloud-cpp::storage
+ PROPERTY INTERFACE_LINK_LIBRARIES
+ google-cloud-cpp::common
+ absl::memory
+ absl::strings
+ absl::str_format
+ absl::time
+ absl::variant
+ nlohmann_json::nlohmann_json
+ Crc32c::crc32c
+ CURL::libcurl
+ Threads::Threads
+ OpenSSL::SSL
+ OpenSSL::Crypto)
+ add_dependencies(google-cloud-cpp::storage google_cloud_cpp_ep)
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS google-cloud-cpp::storage
+ google-cloud-cpp::common)
+endmacro()
+
+if(ARROW_WITH_GOOGLE_CLOUD_CPP)
+ resolve_dependency(google_cloud_cpp_storage)
+ get_target_property(google_cloud_cpp_storage_INCLUDE_DIR google-cloud-cpp::storage
+ INTERFACE_INCLUDE_DIRECTORIES)
+ include_directories(SYSTEM ${google_cloud_cpp_storage_INCLUDE_DIR})
+ get_target_property(absl_base_INCLUDE_DIR absl::base INTERFACE_INCLUDE_DIRECTORIES)
+ include_directories(SYSTEM ${absl_base_INCLUDE_DIR})
+ message(STATUS "Found google-cloud-cpp::storage headers: ${google_cloud_cpp_storage_INCLUDE_DIR}"
+ )
+endif()
+
+#
+# HDFS thirdparty setup
+
+if(DEFINED ENV{HADOOP_HOME})
+ set(HADOOP_HOME $ENV{HADOOP_HOME})
+ if(NOT EXISTS "${HADOOP_HOME}/include/hdfs.h")
+ message(STATUS "Did not find hdfs.h in expected location, using vendored one")
+ set(HADOOP_HOME "${THIRDPARTY_DIR}/hadoop")
+ endif()
+else()
+ set(HADOOP_HOME "${THIRDPARTY_DIR}/hadoop")
+endif()
+
+set(HDFS_H_PATH "${HADOOP_HOME}/include/hdfs.h")
+if(NOT EXISTS ${HDFS_H_PATH})
+ message(FATAL_ERROR "Did not find hdfs.h at ${HDFS_H_PATH}")
+endif()
+message(STATUS "Found hdfs.h at: " ${HDFS_H_PATH})
+
+include_directories(SYSTEM "${HADOOP_HOME}/include")
+
+# ----------------------------------------------------------------------
+# Apache ORC
+
+macro(build_orc)
+ message("Building Apache ORC from source")
+ set(ORC_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/orc_ep-install")
+ set(ORC_HOME "${ORC_PREFIX}")
+ set(ORC_INCLUDE_DIR "${ORC_PREFIX}/include")
+ set(ORC_STATIC_LIB
+ "${ORC_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}orc${CMAKE_STATIC_LIBRARY_SUFFIX}")
+
+ get_target_property(ORC_PROTOBUF_INCLUDE_DIR ${ARROW_PROTOBUF_LIBPROTOBUF}
+ INTERFACE_INCLUDE_DIRECTORIES)
+ get_filename_component(ORC_PB_ROOT "${ORC_PROTOBUF_INCLUDE_DIR}" DIRECTORY)
+ get_target_property(ORC_PROTOBUF_LIBRARY ${ARROW_PROTOBUF_LIBPROTOBUF}
+ IMPORTED_LOCATION)
+
+ get_target_property(ORC_SNAPPY_INCLUDE_DIR Snappy::snappy INTERFACE_INCLUDE_DIRECTORIES)
+ get_filename_component(ORC_SNAPPY_ROOT "${ORC_SNAPPY_INCLUDE_DIR}" DIRECTORY)
+
+ get_target_property(ORC_LZ4_ROOT LZ4::lz4 INTERFACE_INCLUDE_DIRECTORIES)
+ get_filename_component(ORC_LZ4_ROOT "${ORC_LZ4_ROOT}" DIRECTORY)
+
+ # Weirdly passing in PROTOBUF_LIBRARY for PROTOC_LIBRARY still results in ORC finding
+ # the protoc library.
+ set(ORC_CMAKE_ARGS
+ ${EP_COMMON_CMAKE_ARGS}
+ "-DCMAKE_INSTALL_PREFIX=${ORC_PREFIX}"
+ -DCMAKE_CXX_FLAGS=${EP_CXX_FLAGS}
+ -DSTOP_BUILD_ON_WARNING=OFF
+ -DBUILD_LIBHDFSPP=OFF
+ -DBUILD_JAVA=OFF
+ -DBUILD_TOOLS=OFF
+ -DBUILD_CPP_TESTS=OFF
+ -DINSTALL_VENDORED_LIBS=OFF
+ "-DSNAPPY_HOME=${ORC_SNAPPY_ROOT}"
+ "-DSNAPPY_INCLUDE_DIR=${ORC_SNAPPY_INCLUDE_DIR}"
+ "-DPROTOBUF_HOME=${ORC_PB_ROOT}"
+ "-DPROTOBUF_INCLUDE_DIR=${ORC_PROTOBUF_INCLUDE_DIR}"
+ "-DPROTOBUF_LIBRARY=${ORC_PROTOBUF_LIBRARY}"
+ "-DPROTOC_LIBRARY=${ORC_PROTOBUF_LIBRARY}"
+ "-DLZ4_HOME=${LZ4_HOME}"
+ "-DZSTD_HOME=${ZSTD_HOME}")
+ if(ZLIB_ROOT)
+ set(ORC_CMAKE_ARGS ${ORC_CMAKE_ARGS} "-DZLIB_HOME=${ZLIB_ROOT}")
+ endif()
+
+ # Work around CMake bug
+ file(MAKE_DIRECTORY ${ORC_INCLUDE_DIR})
+
+ externalproject_add(orc_ep
+ URL ${ORC_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_ORC_BUILD_SHA256_CHECKSUM}"
+ BUILD_BYPRODUCTS ${ORC_STATIC_LIB}
+ CMAKE_ARGS ${ORC_CMAKE_ARGS} ${EP_LOG_OPTIONS})
+
+ add_dependencies(toolchain orc_ep)
+
+ set(ORC_VENDORED 1)
+ add_dependencies(orc_ep ZLIB::ZLIB)
+ add_dependencies(orc_ep LZ4::lz4)
+ add_dependencies(orc_ep Snappy::snappy)
+ add_dependencies(orc_ep ${ARROW_PROTOBUF_LIBPROTOBUF})
+
+ add_library(orc::liborc STATIC IMPORTED)
+ set_target_properties(orc::liborc
+ PROPERTIES IMPORTED_LOCATION "${ORC_STATIC_LIB}"
+ INTERFACE_INCLUDE_DIRECTORIES "${ORC_INCLUDE_DIR}")
+
+ add_dependencies(toolchain orc_ep)
+ add_dependencies(orc::liborc orc_ep)
+
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS orc::liborc)
+endmacro()
+
+if(ARROW_ORC)
+ resolve_dependency(ORC)
+ include_directories(SYSTEM ${ORC_INCLUDE_DIR})
+ message(STATUS "Found ORC static library: ${ORC_STATIC_LIB}")
+ message(STATUS "Found ORC headers: ${ORC_INCLUDE_DIR}")
+endif()
+
+# ----------------------------------------------------------------------
+# AWS SDK for C++
+
+macro(build_awssdk)
+ message("Building AWS C++ SDK from source")
+ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS
+ "4.9")
+ message(FATAL_ERROR "AWS C++ SDK requires gcc >= 4.9")
+ endif()
+ set(AWSSDK_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/awssdk_ep-install")
+ set(AWSSDK_INCLUDE_DIR "${AWSSDK_PREFIX}/include")
+ set(AWSSDK_LIB_DIR "lib")
+
+ if(WIN32)
+ # On Windows, need to match build types
+ set(AWSSDK_BUILD_TYPE ${CMAKE_BUILD_TYPE})
+ else()
+ # Otherwise, always build in release mode.
+ # Especially with gcc, debug builds can fail with "asm constraint" errors:
+ # https://github.com/TileDB-Inc/TileDB/issues/1351
+ set(AWSSDK_BUILD_TYPE release)
+ endif()
+
+ set(AWSSDK_COMMON_CMAKE_ARGS
+ ${EP_COMMON_CMAKE_ARGS}
+ -DBUILD_SHARED_LIBS=OFF
+ -DCMAKE_BUILD_TYPE=${AWSSDK_BUILD_TYPE}
+ -DCMAKE_INSTALL_LIBDIR=${AWSSDK_LIB_DIR}
+ -DENABLE_TESTING=OFF
+ -DENABLE_UNITY_BUILD=ON
+ "-DCMAKE_INSTALL_PREFIX=${AWSSDK_PREFIX}"
+ "-DCMAKE_PREFIX_PATH=${AWSSDK_PREFIX}")
+
+ set(AWSSDK_CMAKE_ARGS
+ ${AWSSDK_COMMON_CMAKE_ARGS}
+ -DBUILD_DEPS=OFF
+ -DBUILD_ONLY=config\\$<SEMICOLON>s3\\$<SEMICOLON>transfer\\$<SEMICOLON>identity-management\\$<SEMICOLON>sts
+ -DMINIMIZE_SIZE=ON)
+ if(UNIX AND TARGET zlib_ep)
+ list(APPEND AWSSDK_CMAKE_ARGS -DZLIB_INCLUDE_DIR=${ZLIB_INCLUDE_DIRS}
+ -DZLIB_LIBRARY=${ZLIB_LIBRARIES})
+ endif()
+
+ file(MAKE_DIRECTORY ${AWSSDK_INCLUDE_DIR})
+
+ # AWS C++ SDK related libraries to link statically
+ set(_AWSSDK_LIBS
+ aws-cpp-sdk-identity-management
+ aws-cpp-sdk-sts
+ aws-cpp-sdk-cognito-identity
+ aws-cpp-sdk-s3
+ aws-cpp-sdk-core
+ aws-c-event-stream
+ aws-checksums
+ aws-c-common)
+ set(AWSSDK_LIBRARIES)
+ foreach(_AWSSDK_LIB ${_AWSSDK_LIBS})
+ # aws-c-common -> AWS-C-COMMON
+ string(TOUPPER ${_AWSSDK_LIB} _AWSSDK_LIB_UPPER)
+ # AWS-C-COMMON -> AWS_C_COMMON
+ string(REPLACE "-" "_" _AWSSDK_LIB_NAME_PREFIX ${_AWSSDK_LIB_UPPER})
+ set(_AWSSDK_STATIC_LIBRARY
+ "${AWSSDK_PREFIX}/${AWSSDK_LIB_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}${_AWSSDK_LIB}${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+ if(${_AWSSDK_LIB} MATCHES "^aws-cpp-sdk-")
+ set(_AWSSDK_TARGET_NAME ${_AWSSDK_LIB})
+ else()
+ set(_AWSSDK_TARGET_NAME AWS::${_AWSSDK_LIB})
+ endif()
+ add_library(${_AWSSDK_TARGET_NAME} STATIC IMPORTED)
+ set_target_properties(${_AWSSDK_TARGET_NAME}
+ PROPERTIES IMPORTED_LOCATION ${_AWSSDK_STATIC_LIBRARY}
+ INTERFACE_INCLUDE_DIRECTORIES
+ "${AWSSDK_INCLUDE_DIR}")
+ set("${_AWSSDK_LIB_NAME_PREFIX}_STATIC_LIBRARY" ${_AWSSDK_STATIC_LIBRARY})
+ list(APPEND AWSSDK_LIBRARIES ${_AWSSDK_TARGET_NAME})
+ endforeach()
+
+ externalproject_add(aws_c_common_ep
+ ${EP_LOG_OPTIONS}
+ URL ${AWS_C_COMMON_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_AWS_C_COMMON_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${AWSSDK_COMMON_CMAKE_ARGS}
+ BUILD_BYPRODUCTS ${AWS_C_COMMON_STATIC_LIBRARY})
+ add_dependencies(AWS::aws-c-common aws_c_common_ep)
+
+ externalproject_add(aws_checksums_ep
+ ${EP_LOG_OPTIONS}
+ URL ${AWS_CHECKSUMS_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_AWS_CHECKSUMS_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${AWSSDK_COMMON_CMAKE_ARGS}
+ BUILD_BYPRODUCTS ${AWS_CHECKSUMS_STATIC_LIBRARY}
+ DEPENDS aws_c_common_ep)
+ add_dependencies(AWS::aws-checksums aws_checksums_ep)
+
+ externalproject_add(aws_c_event_stream_ep
+ ${EP_LOG_OPTIONS}
+ URL ${AWS_C_EVENT_STREAM_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_AWS_C_EVENT_STREAM_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${AWSSDK_COMMON_CMAKE_ARGS}
+ BUILD_BYPRODUCTS ${AWS_C_EVENT_STREAM_STATIC_LIBRARY}
+ DEPENDS aws_checksums_ep)
+ add_dependencies(AWS::aws-c-event-stream aws_c_event_stream_ep)
+
+ set(AWSSDK_PATCH_COMMAND)
+ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER
+ "10")
+ # Workaround for https://github.com/aws/aws-sdk-cpp/issues/1750
+ set(AWSSDK_PATCH_COMMAND "sed" "-i.bak" "-e" "s/\"-Werror\"//g"
+ "<SOURCE_DIR>/cmake/compiler_settings.cmake")
+ endif()
+ externalproject_add(awssdk_ep
+ ${EP_LOG_OPTIONS}
+ URL ${AWSSDK_SOURCE_URL}
+ URL_HASH "SHA256=${ARROW_AWSSDK_BUILD_SHA256_CHECKSUM}"
+ CMAKE_ARGS ${AWSSDK_CMAKE_ARGS}
+ PATCH_COMMAND ${AWSSDK_PATCH_COMMAND}
+ BUILD_BYPRODUCTS ${AWS_CPP_SDK_COGNITO_IDENTITY_STATIC_LIBRARY}
+ ${AWS_CPP_SDK_CORE_STATIC_LIBRARY}
+ ${AWS_CPP_SDK_IDENTITY_MANAGEMENT_STATIC_LIBRARY}
+ ${AWS_CPP_SDK_S3_STATIC_LIBRARY}
+ ${AWS_CPP_SDK_STS_STATIC_LIBRARY}
+ DEPENDS aws_c_event_stream_ep)
+ add_dependencies(toolchain awssdk_ep)
+ foreach(_AWSSDK_LIB ${_AWSSDK_LIBS})
+ if(${_AWSSDK_LIB} MATCHES "^aws-cpp-sdk-")
+ add_dependencies(${_AWSSDK_LIB} awssdk_ep)
+ endif()
+ endforeach()
+
+ set(AWSSDK_VENDORED TRUE)
+ list(APPEND ARROW_BUNDLED_STATIC_LIBS ${AWSSDK_LIBRARIES})
+ set(AWSSDK_LINK_LIBRARIES ${AWSSDK_LIBRARIES})
+ if(UNIX)
+ # on Linux and macOS curl seems to be required
+ find_package(CURL REQUIRED)
+ if(NOT TARGET CURL::libcurl)
+ # For CMake 3.11 or older
+ add_library(CURL::libcurl UNKNOWN IMPORTED)
+ set_target_properties(CURL::libcurl
+ PROPERTIES INTERFACE_INCLUDE_DIRECTORIES
+ "${CURL_INCLUDE_DIRS}" IMPORTED_LOCATION
+ "${CURL_LIBRARIES}")
+ endif()
+ set_property(TARGET aws-cpp-sdk-core
+ APPEND
+ PROPERTY INTERFACE_LINK_LIBRARIES CURL::libcurl)
+ set_property(TARGET CURL::libcurl
+ APPEND
+ PROPERTY INTERFACE_LINK_LIBRARIES OpenSSL::SSL)
+ if(TARGET zlib_ep)
+ set_property(TARGET aws-cpp-sdk-core
+ APPEND
+ PROPERTY INTERFACE_LINK_LIBRARIES ZLIB::ZLIB)
+ add_dependencies(awssdk_ep zlib_ep)
+ endif()
+ endif()
+
+ # AWSSDK is static-only build
+endmacro()
+
+if(ARROW_S3)
+ # See https://aws.amazon.com/blogs/developer/developer-experience-of-the-aws-sdk-for-c-now-simplified-by-cmake/
+
+ # Workaround to force AWS CMake configuration to look for shared libraries
+ if(DEFINED ENV{CONDA_PREFIX})
+ if(DEFINED BUILD_SHARED_LIBS)
+ set(BUILD_SHARED_LIBS_WAS_SET TRUE)
+ set(BUILD_SHARED_LIBS_VALUE ${BUILD_SHARED_LIBS})
+ else()
+ set(BUILD_SHARED_LIBS_WAS_SET FALSE)
+ endif()
+ set(BUILD_SHARED_LIBS "ON")
+ endif()
+
+ # Need to customize the find_package() call, so cannot call resolve_dependency()
+ if(AWSSDK_SOURCE STREQUAL "AUTO")
+ find_package(AWSSDK
+ COMPONENTS config
+ s3
+ transfer
+ identity-management
+ sts)
+ if(NOT AWSSDK_FOUND)
+ build_awssdk()
+ endif()
+ elseif(AWSSDK_SOURCE STREQUAL "BUNDLED")
+ build_awssdk()
+ elseif(AWSSDK_SOURCE STREQUAL "SYSTEM")
+ find_package(AWSSDK REQUIRED
+ COMPONENTS config
+ s3
+ transfer
+ identity-management
+ sts)
+ endif()
+
+ # Restore previous value of BUILD_SHARED_LIBS
+ if(DEFINED ENV{CONDA_PREFIX})
+ if(BUILD_SHARED_LIBS_WAS_SET)
+ set(BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS_VALUE})
+ else()
+ unset(BUILD_SHARED_LIBS)
+ endif()
+ endif()
+
+ include_directories(SYSTEM ${AWSSDK_INCLUDE_DIR})
+ message(STATUS "Found AWS SDK headers: ${AWSSDK_INCLUDE_DIR}")
+ message(STATUS "Found AWS SDK libraries: ${AWSSDK_LINK_LIBRARIES}")
+
+ if(APPLE)
+ # CoreFoundation's path is hardcoded in the CMake files provided by
+ # aws-sdk-cpp to use the MacOSX SDK provided by XCode which makes
+ # XCode a hard dependency. Command Line Tools is often used instead
+ # of the full XCode suite, so let the linker to find it.
+ set_target_properties(AWS::aws-c-common
+ PROPERTIES INTERFACE_LINK_LIBRARIES
+ "-pthread;pthread;-framework CoreFoundation")
+ endif()
+endif()
+
+message(STATUS "All bundled static libraries: ${ARROW_BUNDLED_STATIC_LIBS}")
+
+# Write out the package configurations.
+
+configure_file("src/arrow/util/config.h.cmake" "src/arrow/util/config.h" ESCAPE_QUOTES)
+install(FILES "${ARROW_BINARY_DIR}/src/arrow/util/config.h"
+ DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/arrow/util")
diff --git a/src/arrow/cpp/cmake_modules/UseCython.cmake b/src/arrow/cpp/cmake_modules/UseCython.cmake
new file mode 100644
index 000000000..f2025efb4
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/UseCython.cmake
@@ -0,0 +1,187 @@
+# Define a function to create Cython modules.
+#
+# For more information on the Cython project, see http://cython.org/.
+# "Cython is a language that makes writing C extensions for the Python language
+# as easy as Python itself."
+#
+# This file defines a CMake function to build a Cython Python module.
+# To use it, first include this file.
+#
+# include( UseCython )
+#
+# Then call cython_add_module to create a module.
+#
+# cython_add_module( <target_name> <pyx_target_name> <output_files> <src1> <src2> ... <srcN> )
+#
+# Where <module_name> is the desired name of the target for the resulting Python module,
+# <pyx_target_name> is the desired name of the target that runs the Cython compiler
+# to generate the needed C or C++ files, <output_files> is a variable to hold the
+# files generated by Cython, and <src1> <src2> ... are source files
+# to be compiled into the module, e.g. *.pyx, *.c, *.cxx, etc.
+# only one .pyx file may be present for each target
+# (this is an inherent limitation of Cython).
+#
+# The sample paths set with the CMake include_directories() command will be used
+# for include directories to search for *.pxd when running the Cython compiler.
+#
+# Cache variables that effect the behavior include:
+#
+# CYTHON_ANNOTATE
+# CYTHON_NO_DOCSTRINGS
+# CYTHON_FLAGS
+#
+# Source file properties that effect the build process are
+#
+# CYTHON_IS_CXX
+# CYTHON_IS_PUBLIC
+# CYTHON_IS_API
+#
+# If this is set of a *.pyx file with CMake set_source_files_properties()
+# command, the file will be compiled as a C++ file.
+#
+# See also FindCython.cmake
+
+#=============================================================================
+# Copyright 2011 Kitware, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#=============================================================================
+
+# Configuration options.
+set(CYTHON_ANNOTATE OFF CACHE BOOL "Create an annotated .html file when compiling *.pyx.")
+set(CYTHON_NO_DOCSTRINGS OFF CACHE BOOL "Strip docstrings from the compiled module.")
+set(CYTHON_FLAGS "" CACHE STRING "Extra flags to the cython compiler.")
+mark_as_advanced(CYTHON_ANNOTATE CYTHON_NO_DOCSTRINGS CYTHON_FLAGS)
+
+find_package(Python3Alt REQUIRED)
+
+# (using another C++ extension breaks coverage)
+set(CYTHON_CXX_EXTENSION "cpp")
+set(CYTHON_C_EXTENSION "c")
+
+# Create a *.c or *.cpp file from a *.pyx file.
+# Input the generated file basename. The generate files will put into the variable
+# placed in the "generated_files" argument. Finally all the *.py and *.pyx files.
+function(compile_pyx
+ _name
+ pyx_target_name
+ generated_files
+ pyx_file)
+ # Default to assuming all files are C.
+ set(cxx_arg "")
+ set(extension ${CYTHON_C_EXTENSION})
+ set(pyx_lang "C")
+ set(comment "Compiling Cython C source for ${_name}...")
+
+ get_filename_component(pyx_file_basename "${pyx_file}" NAME_WE)
+
+ # Determine if it is a C or C++ file.
+ get_source_file_property(property_is_cxx ${pyx_file} CYTHON_IS_CXX)
+ if(${property_is_cxx})
+ set(cxx_arg "--cplus")
+ set(extension ${CYTHON_CXX_EXTENSION})
+ set(pyx_lang "CXX")
+ set(comment "Compiling Cython CXX source for ${_name}...")
+ endif()
+ get_source_file_property(pyx_location ${pyx_file} LOCATION)
+
+ set(output_file "${_name}.${extension}")
+
+ # Set additional flags.
+ if(CYTHON_ANNOTATE)
+ set(annotate_arg "--annotate")
+ endif()
+
+ if(CYTHON_NO_DOCSTRINGS)
+ set(no_docstrings_arg "--no-docstrings")
+ endif()
+
+ if(NOT WIN32)
+ string( TOLOWER "${CMAKE_BUILD_TYPE}" build_type )
+ if("${build_type}" STREQUAL "debug"
+ OR "${build_type}" STREQUAL "relwithdebinfo")
+ set(cython_debug_arg "--gdb")
+ endif()
+ endif()
+
+ # Determining generated file names.
+ get_source_file_property(property_is_public ${pyx_file} CYTHON_PUBLIC)
+ get_source_file_property(property_is_api ${pyx_file} CYTHON_API)
+ if(${property_is_api})
+ set(_generated_files "${output_file}" "${_name}.h" "${name}_api.h")
+ elseif(${property_is_public})
+ set(_generated_files "${output_file}" "${_name}.h")
+ else()
+ set(_generated_files "${output_file}")
+ endif()
+ set_source_files_properties(${_generated_files} PROPERTIES GENERATED TRUE)
+
+ if(NOT WIN32)
+ # Cython creates a lot of compiler warning detritus on clang
+ set_source_files_properties(${_generated_files} PROPERTIES COMPILE_FLAGS
+ -Wno-unused-function)
+ endif()
+
+ set(${generated_files} ${_generated_files} PARENT_SCOPE)
+
+ # Add the command to run the compiler.
+ add_custom_target(
+ ${pyx_target_name}
+ COMMAND ${PYTHON_EXECUTABLE}
+ -m
+ cython
+ ${cxx_arg}
+ ${annotate_arg}
+ ${no_docstrings_arg}
+ ${cython_debug_arg}
+ ${CYTHON_FLAGS}
+ # Necessary for autodoc of function arguments
+ --directive embedsignature=True
+ # Necessary for Cython code coverage
+ --working
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ --output-file
+ "${CMAKE_CURRENT_BINARY_DIR}/${output_file}"
+ "${CMAKE_CURRENT_SOURCE_DIR}/${pyx_file}"
+ DEPENDS ${pyx_location}
+ # Do not specify byproducts for now since they don't work with the older
+ # version of cmake available in the apt repositories.
+ #BYPRODUCTS ${_generated_files}
+ COMMENT ${comment})
+
+ # Remove their visibility to the user.
+ set(corresponding_pxd_file "" CACHE INTERNAL "")
+ set(header_location "" CACHE INTERNAL "")
+ set(pxd_location "" CACHE INTERNAL "")
+endfunction()
+
+# cython_add_module( <name> src1 src2 ... srcN )
+# Build the Cython Python module.
+function(cython_add_module _name pyx_target_name generated_files)
+ set(pyx_module_source "")
+ set(other_module_sources "")
+ foreach(_file ${ARGN})
+ if(${_file} MATCHES ".*\\.py[x]?$")
+ list(APPEND pyx_module_source ${_file})
+ else()
+ list(APPEND other_module_sources ${_file})
+ endif()
+ endforeach()
+ compile_pyx(${_name} ${pyx_target_name} _generated_files ${pyx_module_source})
+ set(${generated_files} ${_generated_files} PARENT_SCOPE)
+ include_directories(${PYTHON_INCLUDE_DIRS})
+ python_add_module(${_name} ${_generated_files} ${other_module_sources})
+ add_dependencies(${_name} ${pyx_target_name})
+endfunction()
+
+include(CMakeParseArguments)
diff --git a/src/arrow/cpp/cmake_modules/Usevcpkg.cmake b/src/arrow/cpp/cmake_modules/Usevcpkg.cmake
new file mode 100644
index 000000000..06ac4dd07
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/Usevcpkg.cmake
@@ -0,0 +1,249 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+message(STATUS "Using vcpkg to find dependencies")
+
+# ----------------------------------------------------------------------
+# Define macros
+
+# macro to list subdirectirectories (non-recursive)
+macro(list_subdirs SUBDIRS DIR)
+ file(GLOB children_
+ RELATIVE ${DIR}
+ ${DIR}/*)
+ set(subdirs_ "")
+ foreach(child_ ${children_})
+ if(IS_DIRECTORY "${DIR}/${child_}")
+ list(APPEND subdirs_ ${child_})
+ endif()
+ endforeach()
+ set("${SUBDIRS}" ${subdirs_})
+ unset(children_)
+ unset(subdirs_)
+endmacro()
+
+# ----------------------------------------------------------------------
+# Get VCPKG_ROOT
+
+if(DEFINED CMAKE_TOOLCHAIN_FILE)
+ # Get it from the CMake variable CMAKE_TOOLCHAIN_FILE
+ get_filename_component(_VCPKG_DOT_CMAKE "${CMAKE_TOOLCHAIN_FILE}" NAME)
+ if(EXISTS "${CMAKE_TOOLCHAIN_FILE}" AND _VCPKG_DOT_CMAKE STREQUAL "vcpkg.cmake")
+ get_filename_component(_VCPKG_BUILDSYSTEMS_DIR "${CMAKE_TOOLCHAIN_FILE}" DIRECTORY)
+ get_filename_component(VCPKG_ROOT "${_VCPKG_BUILDSYSTEMS_DIR}/../.." ABSOLUTE)
+ else()
+ message(FATAL_ERROR "vcpkg toolchain file not found at path specified in -DCMAKE_TOOLCHAIN_FILE"
+ )
+ endif()
+else()
+ if(DEFINED VCPKG_ROOT)
+ # Get it from the CMake variable VCPKG_ROOT
+ find_program(_VCPKG_BIN vcpkg
+ PATHS "${VCPKG_ROOT}"
+ NO_DEFAULT_PATH)
+ if(NOT _VCPKG_BIN)
+ message(FATAL_ERROR "vcpkg not found in directory specified in -DVCPKG_ROOT")
+ endif()
+ elseif(DEFINED ENV{VCPKG_ROOT})
+ # Get it from the environment variable VCPKG_ROOT
+ set(VCPKG_ROOT $ENV{VCPKG_ROOT})
+ find_program(_VCPKG_BIN vcpkg
+ PATHS "${VCPKG_ROOT}"
+ NO_DEFAULT_PATH)
+ if(NOT _VCPKG_BIN)
+ message(FATAL_ERROR "vcpkg not found in directory in environment variable VCPKG_ROOT"
+ )
+ endif()
+ else()
+ # Get it from the file vcpkg.path.txt
+ find_program(_VCPKG_BIN vcpkg)
+ if(_VCPKG_BIN)
+ get_filename_component(_VCPKG_REAL_BIN "${_VCPKG_BIN}" REALPATH)
+ get_filename_component(VCPKG_ROOT "${_VCPKG_REAL_BIN}" DIRECTORY)
+ else()
+ if(CMAKE_HOST_WIN32)
+ set(_VCPKG_PATH_TXT "$ENV{LOCALAPPDATA}/vcpkg/vcpkg.path.txt")
+ else()
+ set(_VCPKG_PATH_TXT "$ENV{HOME}/.vcpkg/vcpkg.path.txt")
+ endif()
+ if(EXISTS "${_VCPKG_PATH_TXT}")
+ file(READ "${_VCPKG_PATH_TXT}" VCPKG_ROOT)
+ else()
+ message(FATAL_ERROR "vcpkg not found. Install vcpkg if not installed, "
+ "then run vcpkg integrate install or set environment variable VCPKG_ROOT."
+ )
+ endif()
+ find_program(_VCPKG_BIN vcpkg
+ PATHS "${VCPKG_ROOT}"
+ NO_DEFAULT_PATH)
+ if(NOT _VCPKG_BIN)
+ message(FATAL_ERROR "vcpkg not found. Re-run vcpkg integrate install "
+ "or set environment variable VCPKG_ROOT.")
+ endif()
+ endif()
+ endif()
+ set(CMAKE_TOOLCHAIN_FILE
+ "${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake"
+ CACHE FILEPATH "Path to vcpkg CMake toolchain file")
+endif()
+message(STATUS "Using CMAKE_TOOLCHAIN_FILE: ${CMAKE_TOOLCHAIN_FILE}")
+message(STATUS "Using VCPKG_ROOT: ${VCPKG_ROOT}")
+
+# ----------------------------------------------------------------------
+# Get VCPKG_TARGET_TRIPLET
+
+if(DEFINED ENV{VCPKG_DEFAULT_TRIPLET} AND NOT DEFINED VCPKG_TARGET_TRIPLET)
+ set(VCPKG_TARGET_TRIPLET "$ENV{VCPKG_DEFAULT_TRIPLET}")
+endif()
+# Explicitly set manifest mode on if it is not set and vcpkg.json exists
+if(NOT DEFINED VCPKG_MANIFEST_MODE AND EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/vcpkg.json")
+ set(VCPKG_MANIFEST_MODE
+ ON
+ CACHE BOOL "Use vcpkg.json manifest")
+ message(STATUS "vcpkg.json manifest found. Using VCPKG_MANIFEST_MODE: ON")
+endif()
+# vcpkg can install packages in three different places
+set(_INST_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/vcpkg_installed") # try here first
+set(_INST_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/vcpkg_installed") # try here second
+set(_INST_VCPKG_ROOT "${VCPKG_ROOT}/installed")
+# Iterate over the places
+foreach(_INST_DIR IN LISTS _INST_BUILD_DIR _INST_SOURCE_DIR _INST_VCPKG_ROOT "notfound")
+ if(_INST_DIR STREQUAL "notfound")
+ message(FATAL_ERROR "vcpkg installed libraries directory not found. "
+ "Install packages with vcpkg before executing cmake.")
+ elseif(NOT EXISTS "${_INST_DIR}")
+ continue()
+ elseif((_INST_DIR STREQUAL _INST_BUILD_DIR OR _INST_DIR STREQUAL _INST_SOURCE_DIR)
+ AND NOT VCPKG_MANIFEST_MODE)
+ # Do not look for packages in the build or source dirs if manifest mode is off
+ message(STATUS "Skipped looking for installed packages in ${_INST_DIR} "
+ "because -DVCPKG_MANIFEST_MODE=OFF")
+ continue()
+ else()
+ message(STATUS "Looking for installed packages in ${_INST_DIR}")
+ endif()
+ if(DEFINED VCPKG_TARGET_TRIPLET)
+ # Check if a subdirectory named VCPKG_TARGET_TRIPLET
+ # exists in the vcpkg installed directory
+ if(EXISTS "${_INST_DIR}/${VCPKG_TARGET_TRIPLET}")
+ set(_VCPKG_INSTALLED_DIR "${_INST_DIR}")
+ break()
+ endif()
+ else()
+ # Infer VCPKG_TARGET_TRIPLET from the name of the
+ # subdirectory in the vcpkg installed directory
+ list_subdirs(_VCPKG_TRIPLET_SUBDIRS "${_INST_DIR}")
+ list(REMOVE_ITEM _VCPKG_TRIPLET_SUBDIRS "vcpkg")
+ list(LENGTH _VCPKG_TRIPLET_SUBDIRS _NUM_VCPKG_TRIPLET_SUBDIRS)
+ if(_NUM_VCPKG_TRIPLET_SUBDIRS EQUAL 1)
+ list(GET _VCPKG_TRIPLET_SUBDIRS 0 VCPKG_TARGET_TRIPLET)
+ set(_VCPKG_INSTALLED_DIR "${_INST_DIR}")
+ break()
+ endif()
+ endif()
+endforeach()
+if(NOT DEFINED VCPKG_TARGET_TRIPLET)
+ message(FATAL_ERROR "Could not infer VCPKG_TARGET_TRIPLET. "
+ "Specify triplet with -DVCPKG_TARGET_TRIPLET.")
+elseif(NOT DEFINED _VCPKG_INSTALLED_DIR)
+ message(FATAL_ERROR "Could not find installed vcpkg packages for triplet ${VCPKG_TARGET_TRIPLET}. "
+ "Install packages with vcpkg before executing cmake.")
+endif()
+
+set(VCPKG_TARGET_TRIPLET
+ "${VCPKG_TARGET_TRIPLET}"
+ CACHE STRING "vcpkg triplet for the target environment")
+
+if(NOT DEFINED VCPKG_BUILD_TYPE)
+ set(VCPKG_BUILD_TYPE
+ "${LOWERCASE_BUILD_TYPE}"
+ CACHE STRING "vcpkg build type (release|debug)")
+endif()
+
+if(NOT DEFINED VCPKG_LIBRARY_LINKAGE)
+ if(ARROW_DEPENDENCY_USE_SHARED)
+ set(VCPKG_LIBRARY_LINKAGE "dynamic")
+ else()
+ set(VCPKG_LIBRARY_LINKAGE "static")
+ endif()
+ set(VCPKG_LIBRARY_LINKAGE
+ "${VCPKG_LIBRARY_LINKAGE}"
+ CACHE STRING "vcpkg preferred library linkage (static|dynamic)")
+endif()
+
+message(STATUS "Using vcpkg installed libraries directory: ${_VCPKG_INSTALLED_DIR}")
+message(STATUS "Using VCPKG_TARGET_TRIPLET: ${VCPKG_TARGET_TRIPLET}")
+message(STATUS "Using VCPKG_BUILD_TYPE: ${VCPKG_BUILD_TYPE}")
+message(STATUS "Using VCPKG_LIBRARY_LINKAGE: ${VCPKG_LIBRARY_LINKAGE}")
+
+set(ARROW_VCPKG_PREFIX
+ "${_VCPKG_INSTALLED_DIR}/${VCPKG_TARGET_TRIPLET}"
+ CACHE PATH "Path to target triplet subdirectory in vcpkg installed directory")
+
+set(ARROW_VCPKG
+ ON
+ CACHE BOOL "Use vcpkg for dependencies")
+
+set(ARROW_DEPENDENCY_SOURCE
+ "SYSTEM"
+ CACHE STRING "The specified value VCPKG is implemented internally as SYSTEM" FORCE)
+
+set(BOOST_ROOT
+ "${ARROW_VCPKG_PREFIX}"
+ CACHE STRING "")
+set(BOOST_INCLUDEDIR
+ "${ARROW_VCPKG_PREFIX}/include/boost"
+ CACHE STRING "")
+set(BOOST_LIBRARYDIR
+ "${ARROW_VCPKG_PREFIX}/lib"
+ CACHE STRING "")
+set(OPENSSL_INCLUDE_DIR
+ "${ARROW_VCPKG_PREFIX}/include"
+ CACHE STRING "")
+set(OPENSSL_LIBRARIES
+ "${ARROW_VCPKG_PREFIX}/lib"
+ CACHE STRING "")
+set(OPENSSL_ROOT_DIR
+ "${ARROW_VCPKG_PREFIX}"
+ CACHE STRING "")
+set(Thrift_ROOT
+ "${ARROW_VCPKG_PREFIX}/lib"
+ CACHE STRING "")
+set(ZSTD_INCLUDE_DIR
+ "${ARROW_VCPKG_PREFIX}/include"
+ CACHE STRING "")
+set(ZSTD_ROOT
+ "${ARROW_VCPKG_PREFIX}"
+ CACHE STRING "")
+set(BROTLI_ROOT
+ "${ARROW_VCPKG_PREFIX}"
+ CACHE STRING "")
+set(LZ4_ROOT
+ "${ARROW_VCPKG_PREFIX}"
+ CACHE STRING "")
+
+if(CMAKE_HOST_WIN32)
+ set(LZ4_MSVC_LIB_PREFIX
+ ""
+ CACHE STRING "")
+ set(LZ4_MSVC_STATIC_LIB_SUFFIX
+ ""
+ CACHE STRING "")
+ set(ZSTD_MSVC_LIB_PREFIX
+ ""
+ CACHE STRING "")
+endif()
diff --git a/src/arrow/cpp/cmake_modules/san-config.cmake b/src/arrow/cpp/cmake_modules/san-config.cmake
new file mode 100644
index 000000000..bde9af23e
--- /dev/null
+++ b/src/arrow/cpp/cmake_modules/san-config.cmake
@@ -0,0 +1,122 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License. See accompanying LICENSE file.
+
+# Clang does not support using ASAN and TSAN simultaneously.
+if("${ARROW_USE_ASAN}" AND "${ARROW_USE_TSAN}")
+ message(SEND_ERROR "Can only enable one of ASAN or TSAN at a time")
+endif()
+
+# Flag to enable clang address sanitizer
+# This will only build if clang or a recent enough gcc is the chosen compiler
+if(${ARROW_USE_ASAN})
+ if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
+ OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
+ OR (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION
+ VERSION_GREATER "4.8"))
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -DADDRESS_SANITIZER")
+ else()
+ message(SEND_ERROR "Cannot use ASAN without clang or gcc >= 4.8")
+ endif()
+endif()
+
+# Flag to enable clang undefined behavior sanitizer
+# We explicitly don't enable all of the sanitizer flags:
+# - disable 'vptr' because of RTTI issues across shared libraries (?)
+# - disable 'alignment' because unaligned access is really OK on Nehalem and we do it
+# all over the place.
+# - disable 'function' because it appears to give a false positive
+# (https://github.com/google/sanitizers/issues/911)
+# - disable 'float-divide-by-zero' on clang, which considers it UB
+# (https://bugs.llvm.org/show_bug.cgi?id=17000#c1)
+# Note: GCC does not support the 'function' flag.
+if(${ARROW_USE_UBSAN})
+ if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL
+ "Clang")
+ set(CMAKE_CXX_FLAGS
+ "${CMAKE_CXX_FLAGS} -fsanitize=undefined -fno-sanitize=alignment,vptr,function,float-divide-by-zero -fno-sanitize-recover=all"
+ )
+ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION
+ VERSION_GREATER_EQUAL "5.1")
+ set(CMAKE_CXX_FLAGS
+ "${CMAKE_CXX_FLAGS} -fsanitize=undefined -fno-sanitize=alignment,vptr -fno-sanitize-recover=all"
+ )
+ else()
+ message(SEND_ERROR "Cannot use UBSAN without clang or gcc >= 5.1")
+ endif()
+endif()
+
+# Flag to enable thread sanitizer (clang or gcc 4.8)
+if(${ARROW_USE_TSAN})
+ if(NOT
+ (CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
+ OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
+ OR (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION
+ VERSION_GREATER "4.8")))
+ message(SEND_ERROR "Cannot use TSAN without clang or gcc >= 4.8")
+ endif()
+
+ add_definitions("-fsanitize=thread")
+
+ # Enables dynamic_annotations.h to actually generate code
+ add_definitions("-DDYNAMIC_ANNOTATIONS_ENABLED")
+
+ # changes atomicops to use the tsan implementations
+ add_definitions("-DTHREAD_SANITIZER")
+
+ # Disables using the precompiled template specializations for std::string, shared_ptr, etc
+ # so that the annotations in the header actually take effect.
+ add_definitions("-D_GLIBCXX_EXTERN_TEMPLATE=0")
+
+ # Some of the above also need to be passed to the linker.
+ set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -pie -fsanitize=thread")
+
+ # Strictly speaking, TSAN doesn't require dynamic linking. But it does
+ # require all code to be position independent, and the easiest way to
+ # guarantee that is via dynamic linking (not all 3rd party archives are
+ # compiled with -fPIC e.g. boost).
+ if("${ARROW_LINK}" STREQUAL "a")
+ message("Using dynamic linking for TSAN")
+ set(ARROW_LINK "d")
+ elseif("${ARROW_LINK}" STREQUAL "s")
+ message(SEND_ERROR "Cannot use TSAN with static linking")
+ endif()
+endif()
+
+if(${ARROW_USE_COVERAGE})
+ if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL
+ "Clang")
+ add_definitions("-fsanitize-coverage=pc-table,inline-8bit-counters,edge,no-prune,trace-cmp,trace-div,trace-gep"
+ )
+
+ set(CMAKE_CXX_FLAGS
+ "${CMAKE_CXX_FLAGS} -fsanitize-coverage=pc-table,inline-8bit-counters,edge,no-prune,trace-cmp,trace-div,trace-gep"
+ )
+ else()
+ message(SEND_ERROR "You can only enable coverage with clang")
+ endif()
+endif()
+
+if("${ARROW_USE_UBSAN}"
+ OR "${ARROW_USE_ASAN}"
+ OR "${ARROW_USE_TSAN}")
+ # GCC 4.8 and 4.9 (latest as of this writing) don't allow you to specify
+ # disallowed entries for the sanitizer.
+ if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL
+ "Clang")
+ set(CMAKE_CXX_FLAGS
+ "${CMAKE_CXX_FLAGS} -fsanitize-blacklist=${BUILD_SUPPORT_DIR}/sanitizer-disallowed-entries.txt"
+ )
+ else()
+ message(WARNING "GCC does not support specifying a sanitizer disallowed entries list. Known sanitizer check failures will not be suppressed."
+ )
+ endif()
+endif()
diff --git a/src/arrow/cpp/examples/arrow/CMakeLists.txt b/src/arrow/cpp/examples/arrow/CMakeLists.txt
new file mode 100644
index 000000000..ac758b92d
--- /dev/null
+++ b/src/arrow/cpp/examples/arrow/CMakeLists.txt
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+ADD_ARROW_EXAMPLE(row_wise_conversion_example)
+
+if (ARROW_COMPUTE)
+ ADD_ARROW_EXAMPLE(compute_register_example)
+endif()
+
+if (ARROW_COMPUTE AND ARROW_CSV)
+ ADD_ARROW_EXAMPLE(compute_and_write_csv_example)
+endif()
+
+if (ARROW_PARQUET AND ARROW_DATASET)
+ if (ARROW_BUILD_SHARED)
+ set(DATASET_EXAMPLES_LINK_LIBS arrow_dataset_shared)
+ else()
+ set(DATASET_EXAMPLES_LINK_LIBS arrow_dataset_static)
+ endif()
+
+ ADD_ARROW_EXAMPLE(dataset_parquet_scan_example
+ EXTRA_LINK_LIBS
+ ${DATASET_EXAMPLES_LINK_LIBS})
+ add_dependencies(dataset_parquet_scan_example parquet)
+
+ ADD_ARROW_EXAMPLE(dataset_documentation_example
+ EXTRA_LINK_LIBS
+ ${DATASET_EXAMPLES_LINK_LIBS})
+ add_dependencies(dataset_documentation_example parquet)
+endif()
diff --git a/src/arrow/cpp/examples/arrow/compute_and_write_csv_example.cc b/src/arrow/cpp/examples/arrow/compute_and_write_csv_example.cc
new file mode 100644
index 000000000..db3478759
--- /dev/null
+++ b/src/arrow/cpp/examples/arrow/compute_and_write_csv_example.cc
@@ -0,0 +1,113 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/api.h>
+#include <arrow/compute/api_aggregate.h>
+#include <arrow/csv/api.h>
+#include <arrow/csv/writer.h>
+#include <arrow/io/api.h>
+#include <arrow/result.h>
+#include <arrow/status.h>
+
+#include <iostream>
+#include <vector>
+
+// Many operations in Apache Arrow operate on
+// columns of data, and the columns of data are
+// assembled into a table. In this example, we
+// examine how to compare two arrays which are
+// combined to form a table that is then written
+// out to a CSV file.
+//
+// To run this example you can use
+// ./compute_and_write_csv_example
+//
+// the program will write the files into
+// compute_and_write_output.csv
+// in the current directory
+
+arrow::Status RunMain(int argc, char** argv) {
+ // Make Arrays
+ arrow::NumericBuilder<arrow::Int64Type> int64_builder;
+ arrow::BooleanBuilder boolean_builder;
+
+ // Make place for 8 values in total
+ ARROW_RETURN_NOT_OK(int64_builder.Resize(8));
+ ARROW_RETURN_NOT_OK(boolean_builder.Resize(8));
+
+ // Bulk append the given values
+ std::vector<int64_t> int64_values = {1, 2, 3, 4, 5, 6, 7, 8};
+ ARROW_RETURN_NOT_OK(int64_builder.AppendValues(int64_values));
+ std::shared_ptr<arrow::Array> array_a;
+ ARROW_RETURN_NOT_OK(int64_builder.Finish(&array_a));
+ int64_builder.Reset();
+ int64_values = {2, 5, 1, 3, 6, 2, 7, 4};
+ std::shared_ptr<arrow::Array> array_b;
+ ARROW_RETURN_NOT_OK(int64_builder.AppendValues(int64_values));
+ ARROW_RETURN_NOT_OK(int64_builder.Finish(&array_b));
+
+ // Cast the arrays to their actual types
+ auto int64_array_a = std::static_pointer_cast<arrow::Int64Array>(array_a);
+ auto int64_array_b = std::static_pointer_cast<arrow::Int64Array>(array_b);
+ // Explicit comparison of values using a loop
+ for (int64_t i = 0; i < 8; i++) {
+ if ((!int64_array_a->IsNull(i)) && (!int64_array_b->IsNull(i))) {
+ bool comparison_result = int64_array_a->Value(i) > int64_array_b->Value(i);
+ boolean_builder.UnsafeAppend(comparison_result);
+ } else {
+ boolean_builder.UnsafeAppendNull();
+ }
+ }
+ std::shared_ptr<arrow::Array> array_a_gt_b_self;
+ ARROW_RETURN_NOT_OK(boolean_builder.Finish(&array_a_gt_b_self));
+ std::cout << "Array explicitly compared" << std::endl;
+
+ // Explicit comparison of values using a compute function
+ ARROW_ASSIGN_OR_RAISE(arrow::Datum compared_datum,
+ arrow::compute::CallFunction("greater", {array_a, array_b}));
+ auto array_a_gt_b_compute = compared_datum.make_array();
+ std::cout << "Arrays compared using a compute function" << std::endl;
+
+ // Create a table for the output
+ auto schema =
+ arrow::schema({arrow::field("a", arrow::int64()), arrow::field("b", arrow::int64()),
+ arrow::field("a>b? (self written)", arrow::boolean()),
+ arrow::field("a>b? (arrow)", arrow::boolean())});
+ std::shared_ptr<arrow::Table> my_table = arrow::Table::Make(
+ schema, {array_a, array_b, array_a_gt_b_self, array_a_gt_b_compute});
+
+ std::cout << "Table created" << std::endl;
+
+ // Write table to CSV file
+ auto csv_filename = "compute_and_write_output.csv";
+ ARROW_ASSIGN_OR_RAISE(auto outstream, arrow::io::FileOutputStream::Open(csv_filename));
+
+ std::cout << "Writing CSV file" << std::endl;
+ ARROW_RETURN_NOT_OK(arrow::csv::WriteCSV(
+ *my_table, arrow::csv::WriteOptions::Defaults(), outstream.get()));
+
+ return arrow::Status::OK();
+}
+
+int main(int argc, char** argv) {
+ arrow::Status status = RunMain(argc, argv);
+ if (!status.ok()) {
+ std::cerr << status << std::endl;
+ return EXIT_FAILURE;
+ }
+ return EXIT_SUCCESS;
+}
diff --git a/src/arrow/cpp/examples/arrow/compute_register_example.cc b/src/arrow/cpp/examples/arrow/compute_register_example.cc
new file mode 100644
index 000000000..dd760bb60
--- /dev/null
+++ b/src/arrow/cpp/examples/arrow/compute_register_example.cc
@@ -0,0 +1,168 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/api.h>
+#include <arrow/compute/api.h>
+#include <arrow/compute/exec/exec_plan.h>
+#include <arrow/compute/exec/expression.h>
+#include <arrow/compute/exec/options.h>
+#include <arrow/util/async_generator.h>
+#include <arrow/util/future.h>
+
+#include <cstdlib>
+#include <iostream>
+#include <memory>
+
+// Demonstrate registering an Arrow compute function outside of the Arrow source tree
+
+namespace cp = ::arrow::compute;
+
+#define ABORT_ON_FAILURE(expr) \
+ do { \
+ arrow::Status status_ = (expr); \
+ if (!status_.ok()) { \
+ std::cerr << status_.message() << std::endl; \
+ abort(); \
+ } \
+ } while (0);
+
+class ExampleFunctionOptionsType : public cp::FunctionOptionsType {
+ const char* type_name() const override { return "ExampleFunctionOptionsType"; }
+ std::string Stringify(const cp::FunctionOptions&) const override {
+ return "ExampleFunctionOptionsType";
+ }
+ bool Compare(const cp::FunctionOptions&, const cp::FunctionOptions&) const override {
+ return true;
+ }
+ std::unique_ptr<cp::FunctionOptions> Copy(const cp::FunctionOptions&) const override;
+ // optional: support for serialization
+ // Result<std::shared_ptr<Buffer>> Serialize(const FunctionOptions&) const override;
+ // Result<std::unique_ptr<FunctionOptions>> Deserialize(const Buffer&) const override;
+};
+
+cp::FunctionOptionsType* GetExampleFunctionOptionsType() {
+ static ExampleFunctionOptionsType options_type;
+ return &options_type;
+}
+
+class ExampleFunctionOptions : public cp::FunctionOptions {
+ public:
+ ExampleFunctionOptions() : cp::FunctionOptions(GetExampleFunctionOptionsType()) {}
+};
+
+std::unique_ptr<cp::FunctionOptions> ExampleFunctionOptionsType::Copy(
+ const cp::FunctionOptions&) const {
+ return std::unique_ptr<cp::FunctionOptions>(new ExampleFunctionOptions());
+}
+
+arrow::Status ExampleFunctionImpl(cp::KernelContext* ctx, const cp::ExecBatch& batch,
+ arrow::Datum* out) {
+ *out->mutable_array() = *batch[0].array();
+ return arrow::Status::OK();
+}
+
+class ExampleNodeOptions : public cp::ExecNodeOptions {};
+
+// a basic ExecNode which ignores all input batches
+class ExampleNode : public cp::ExecNode {
+ public:
+ ExampleNode(ExecNode* input, const ExampleNodeOptions&)
+ : ExecNode(/*plan=*/input->plan(), /*inputs=*/{input},
+ /*input_labels=*/{"ignored"},
+ /*output_schema=*/input->output_schema(), /*num_outputs=*/1) {}
+
+ const char* kind_name() const override { return "ExampleNode"; }
+
+ arrow::Status StartProducing() override {
+ outputs_[0]->InputFinished(this, 0);
+ return arrow::Status::OK();
+ }
+
+ void ResumeProducing(ExecNode* output) override {}
+ void PauseProducing(ExecNode* output) override {}
+
+ void StopProducing(ExecNode* output) override { inputs_[0]->StopProducing(this); }
+ void StopProducing() override { inputs_[0]->StopProducing(); }
+
+ void InputReceived(ExecNode* input, cp::ExecBatch batch) override {}
+ void ErrorReceived(ExecNode* input, arrow::Status error) override {}
+ void InputFinished(ExecNode* input, int total_batches) override {}
+
+ arrow::Future<> finished() override { return inputs_[0]->finished(); }
+};
+
+arrow::Result<cp::ExecNode*> ExampleExecNodeFactory(cp::ExecPlan* plan,
+ std::vector<cp::ExecNode*> inputs,
+ const cp::ExecNodeOptions& options) {
+ const auto& example_options =
+ arrow::internal::checked_cast<const ExampleNodeOptions&>(options);
+
+ return plan->EmplaceNode<ExampleNode>(inputs[0], example_options);
+}
+
+const cp::FunctionDoc func_doc{
+ "Example function to demonstrate registering an out-of-tree function",
+ "",
+ {"x"},
+ "ExampleFunctionOptions"};
+
+int main(int argc, char** argv) {
+ const std::string name = "compute_register_example";
+ auto func = std::make_shared<cp::ScalarFunction>(name, cp::Arity::Unary(), &func_doc);
+ ABORT_ON_FAILURE(func->AddKernel({cp::InputType::Array(arrow::int64())}, arrow::int64(),
+ ExampleFunctionImpl));
+
+ auto registry = cp::GetFunctionRegistry();
+ ABORT_ON_FAILURE(registry->AddFunction(std::move(func)));
+
+ arrow::Int64Builder builder(arrow::default_memory_pool());
+ std::shared_ptr<arrow::Array> arr;
+ ABORT_ON_FAILURE(builder.Append(42));
+ ABORT_ON_FAILURE(builder.Finish(&arr));
+ auto options = std::make_shared<ExampleFunctionOptions>();
+ auto maybe_result = cp::CallFunction(name, {arr}, options.get());
+ ABORT_ON_FAILURE(maybe_result.status());
+
+ std::cout << maybe_result->make_array()->ToString() << std::endl;
+
+ // Expression serialization will raise NotImplemented if an expression includes
+ // FunctionOptions for which serialization is not supported.
+ auto expr = cp::call(name, {}, options);
+ auto maybe_serialized = cp::Serialize(expr);
+ std::cerr << maybe_serialized.status().ToString() << std::endl;
+
+ auto exec_registry = cp::default_exec_factory_registry();
+ ABORT_ON_FAILURE(
+ exec_registry->AddFactory("compute_register_example", ExampleExecNodeFactory));
+
+ auto maybe_plan = cp::ExecPlan::Make();
+ ABORT_ON_FAILURE(maybe_plan.status());
+ auto plan = maybe_plan.ValueOrDie();
+
+ arrow::AsyncGenerator<arrow::util::optional<cp::ExecBatch>> source_gen, sink_gen;
+ ABORT_ON_FAILURE(
+ cp::Declaration::Sequence(
+ {
+ {"source", cp::SourceNodeOptions{arrow::schema({}), source_gen}},
+ {"compute_register_example", ExampleNodeOptions{}},
+ {"sink", cp::SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get())
+ .status());
+
+ return EXIT_SUCCESS;
+}
diff --git a/src/arrow/cpp/examples/arrow/dataset_documentation_example.cc b/src/arrow/cpp/examples/arrow/dataset_documentation_example.cc
new file mode 100644
index 000000000..1aac66d4a
--- /dev/null
+++ b/src/arrow/cpp/examples/arrow/dataset_documentation_example.cc
@@ -0,0 +1,374 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This example showcases various ways to work with Datasets. It's
+// intended to be paired with the documentation.
+
+#include <arrow/api.h>
+#include <arrow/compute/cast.h>
+#include <arrow/compute/exec/expression.h>
+#include <arrow/dataset/dataset.h>
+#include <arrow/dataset/discovery.h>
+#include <arrow/dataset/file_base.h>
+#include <arrow/dataset/file_ipc.h>
+#include <arrow/dataset/file_parquet.h>
+#include <arrow/dataset/scanner.h>
+#include <arrow/filesystem/filesystem.h>
+#include <arrow/ipc/writer.h>
+#include <arrow/util/iterator.h>
+#include <parquet/arrow/writer.h>
+
+#include <iostream>
+#include <vector>
+
+namespace ds = arrow::dataset;
+namespace fs = arrow::fs;
+namespace cp = arrow::compute;
+
+#define ABORT_ON_FAILURE(expr) \
+ do { \
+ arrow::Status status_ = (expr); \
+ if (!status_.ok()) { \
+ std::cerr << status_.message() << std::endl; \
+ abort(); \
+ } \
+ } while (0);
+
+// (Doc section: Reading Datasets)
+// Generate some data for the rest of this example.
+std::shared_ptr<arrow::Table> CreateTable() {
+ auto schema =
+ arrow::schema({arrow::field("a", arrow::int64()), arrow::field("b", arrow::int64()),
+ arrow::field("c", arrow::int64())});
+ std::shared_ptr<arrow::Array> array_a;
+ std::shared_ptr<arrow::Array> array_b;
+ std::shared_ptr<arrow::Array> array_c;
+ arrow::NumericBuilder<arrow::Int64Type> builder;
+ ABORT_ON_FAILURE(builder.AppendValues({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}));
+ ABORT_ON_FAILURE(builder.Finish(&array_a));
+ builder.Reset();
+ ABORT_ON_FAILURE(builder.AppendValues({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}));
+ ABORT_ON_FAILURE(builder.Finish(&array_b));
+ builder.Reset();
+ ABORT_ON_FAILURE(builder.AppendValues({1, 2, 1, 2, 1, 2, 1, 2, 1, 2}));
+ ABORT_ON_FAILURE(builder.Finish(&array_c));
+ return arrow::Table::Make(schema, {array_a, array_b, array_c});
+}
+
+// Set up a dataset by writing two Parquet files.
+std::string CreateExampleParquetDataset(const std::shared_ptr<fs::FileSystem>& filesystem,
+ const std::string& root_path) {
+ auto base_path = root_path + "/parquet_dataset";
+ ABORT_ON_FAILURE(filesystem->CreateDir(base_path));
+ // Create an Arrow Table
+ auto table = CreateTable();
+ // Write it into two Parquet files
+ auto output = filesystem->OpenOutputStream(base_path + "/data1.parquet").ValueOrDie();
+ ABORT_ON_FAILURE(parquet::arrow::WriteTable(
+ *table->Slice(0, 5), arrow::default_memory_pool(), output, /*chunk_size=*/2048));
+ output = filesystem->OpenOutputStream(base_path + "/data2.parquet").ValueOrDie();
+ ABORT_ON_FAILURE(parquet::arrow::WriteTable(
+ *table->Slice(5), arrow::default_memory_pool(), output, /*chunk_size=*/2048));
+ return base_path;
+}
+// (Doc section: Reading Datasets)
+
+// (Doc section: Reading different file formats)
+// Set up a dataset by writing two Feather files.
+std::string CreateExampleFeatherDataset(const std::shared_ptr<fs::FileSystem>& filesystem,
+ const std::string& root_path) {
+ auto base_path = root_path + "/feather_dataset";
+ ABORT_ON_FAILURE(filesystem->CreateDir(base_path));
+ // Create an Arrow Table
+ auto table = CreateTable();
+ // Write it into two Feather files
+ auto output = filesystem->OpenOutputStream(base_path + "/data1.feather").ValueOrDie();
+ auto writer = arrow::ipc::MakeFileWriter(output.get(), table->schema()).ValueOrDie();
+ ABORT_ON_FAILURE(writer->WriteTable(*table->Slice(0, 5)));
+ ABORT_ON_FAILURE(writer->Close());
+ output = filesystem->OpenOutputStream(base_path + "/data2.feather").ValueOrDie();
+ writer = arrow::ipc::MakeFileWriter(output.get(), table->schema()).ValueOrDie();
+ ABORT_ON_FAILURE(writer->WriteTable(*table->Slice(5)));
+ ABORT_ON_FAILURE(writer->Close());
+ return base_path;
+}
+// (Doc section: Reading different file formats)
+
+// (Doc section: Reading and writing partitioned data)
+// Set up a dataset by writing files with partitioning
+std::string CreateExampleParquetHivePartitionedDataset(
+ const std::shared_ptr<fs::FileSystem>& filesystem, const std::string& root_path) {
+ auto base_path = root_path + "/parquet_dataset";
+ ABORT_ON_FAILURE(filesystem->CreateDir(base_path));
+ // Create an Arrow Table
+ auto schema = arrow::schema(
+ {arrow::field("a", arrow::int64()), arrow::field("b", arrow::int64()),
+ arrow::field("c", arrow::int64()), arrow::field("part", arrow::utf8())});
+ std::vector<std::shared_ptr<arrow::Array>> arrays(4);
+ arrow::NumericBuilder<arrow::Int64Type> builder;
+ ABORT_ON_FAILURE(builder.AppendValues({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}));
+ ABORT_ON_FAILURE(builder.Finish(&arrays[0]));
+ builder.Reset();
+ ABORT_ON_FAILURE(builder.AppendValues({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}));
+ ABORT_ON_FAILURE(builder.Finish(&arrays[1]));
+ builder.Reset();
+ ABORT_ON_FAILURE(builder.AppendValues({1, 2, 1, 2, 1, 2, 1, 2, 1, 2}));
+ ABORT_ON_FAILURE(builder.Finish(&arrays[2]));
+ arrow::StringBuilder string_builder;
+ ABORT_ON_FAILURE(
+ string_builder.AppendValues({"a", "a", "a", "a", "a", "b", "b", "b", "b", "b"}));
+ ABORT_ON_FAILURE(string_builder.Finish(&arrays[3]));
+ auto table = arrow::Table::Make(schema, arrays);
+ // Write it using Datasets
+ auto dataset = std::make_shared<ds::InMemoryDataset>(table);
+ auto scanner_builder = dataset->NewScan().ValueOrDie();
+ auto scanner = scanner_builder->Finish().ValueOrDie();
+
+ // The partition schema determines which fields are part of the partitioning.
+ auto partition_schema = arrow::schema({arrow::field("part", arrow::utf8())});
+ // We'll use Hive-style partitioning, which creates directories with "key=value" pairs.
+ auto partitioning = std::make_shared<ds::HivePartitioning>(partition_schema);
+ // We'll write Parquet files.
+ auto format = std::make_shared<ds::ParquetFileFormat>();
+ ds::FileSystemDatasetWriteOptions write_options;
+ write_options.file_write_options = format->DefaultWriteOptions();
+ write_options.filesystem = filesystem;
+ write_options.base_dir = base_path;
+ write_options.partitioning = partitioning;
+ write_options.basename_template = "part{i}.parquet";
+ ABORT_ON_FAILURE(ds::FileSystemDataset::Write(write_options, scanner));
+ return base_path;
+}
+// (Doc section: Reading and writing partitioned data)
+
+// (Doc section: Dataset discovery)
+// Read the whole dataset with the given format, without partitioning.
+std::shared_ptr<arrow::Table> ScanWholeDataset(
+ const std::shared_ptr<fs::FileSystem>& filesystem,
+ const std::shared_ptr<ds::FileFormat>& format, const std::string& base_dir) {
+ // Create a dataset by scanning the filesystem for files
+ fs::FileSelector selector;
+ selector.base_dir = base_dir;
+ auto factory = ds::FileSystemDatasetFactory::Make(filesystem, selector, format,
+ ds::FileSystemFactoryOptions())
+ .ValueOrDie();
+ auto dataset = factory->Finish().ValueOrDie();
+ // Print out the fragments
+ for (const auto& fragment : dataset->GetFragments().ValueOrDie()) {
+ std::cout << "Found fragment: " << (*fragment)->ToString() << std::endl;
+ }
+ // Read the entire dataset as a Table
+ auto scan_builder = dataset->NewScan().ValueOrDie();
+ auto scanner = scan_builder->Finish().ValueOrDie();
+ return scanner->ToTable().ValueOrDie();
+}
+// (Doc section: Dataset discovery)
+
+// (Doc section: Filtering data)
+// Read a dataset, but select only column "b" and only rows where b < 4.
+//
+// This is useful when you only want a few columns from a dataset. Where possible,
+// Datasets will push down the column selection such that less work is done.
+std::shared_ptr<arrow::Table> FilterAndSelectDataset(
+ const std::shared_ptr<fs::FileSystem>& filesystem,
+ const std::shared_ptr<ds::FileFormat>& format, const std::string& base_dir) {
+ fs::FileSelector selector;
+ selector.base_dir = base_dir;
+ auto factory = ds::FileSystemDatasetFactory::Make(filesystem, selector, format,
+ ds::FileSystemFactoryOptions())
+ .ValueOrDie();
+ auto dataset = factory->Finish().ValueOrDie();
+ // Read specified columns with a row filter
+ auto scan_builder = dataset->NewScan().ValueOrDie();
+ ABORT_ON_FAILURE(scan_builder->Project({"b"}));
+ ABORT_ON_FAILURE(scan_builder->Filter(cp::less(cp::field_ref("b"), cp::literal(4))));
+ auto scanner = scan_builder->Finish().ValueOrDie();
+ return scanner->ToTable().ValueOrDie();
+}
+// (Doc section: Filtering data)
+
+// (Doc section: Projecting columns)
+// Read a dataset, but with column projection.
+//
+// This is useful to derive new columns from existing data. For example, here we
+// demonstrate casting a column to a different type, and turning a numeric column into a
+// boolean column based on a predicate. You could also rename columns or perform
+// computations involving multiple columns.
+std::shared_ptr<arrow::Table> ProjectDataset(
+ const std::shared_ptr<fs::FileSystem>& filesystem,
+ const std::shared_ptr<ds::FileFormat>& format, const std::string& base_dir) {
+ fs::FileSelector selector;
+ selector.base_dir = base_dir;
+ auto factory = ds::FileSystemDatasetFactory::Make(filesystem, selector, format,
+ ds::FileSystemFactoryOptions())
+ .ValueOrDie();
+ auto dataset = factory->Finish().ValueOrDie();
+ // Read specified columns with a row filter
+ auto scan_builder = dataset->NewScan().ValueOrDie();
+ ABORT_ON_FAILURE(scan_builder->Project(
+ {
+ // Leave column "a" as-is.
+ cp::field_ref("a"),
+ // Cast column "b" to float32.
+ cp::call("cast", {cp::field_ref("b")},
+ arrow::compute::CastOptions::Safe(arrow::float32())),
+ // Derive a boolean column from "c".
+ cp::equal(cp::field_ref("c"), cp::literal(1)),
+ },
+ {"a_renamed", "b_as_float32", "c_1"}));
+ auto scanner = scan_builder->Finish().ValueOrDie();
+ return scanner->ToTable().ValueOrDie();
+}
+// (Doc section: Projecting columns)
+
+// (Doc section: Projecting columns #2)
+// Read a dataset, but with column projection.
+//
+// This time, we read all original columns plus one derived column. This simply combines
+// the previous two examples: selecting a subset of columns by name, and deriving new
+// columns with an expression.
+std::shared_ptr<arrow::Table> SelectAndProjectDataset(
+ const std::shared_ptr<fs::FileSystem>& filesystem,
+ const std::shared_ptr<ds::FileFormat>& format, const std::string& base_dir) {
+ fs::FileSelector selector;
+ selector.base_dir = base_dir;
+ auto factory = ds::FileSystemDatasetFactory::Make(filesystem, selector, format,
+ ds::FileSystemFactoryOptions())
+ .ValueOrDie();
+ auto dataset = factory->Finish().ValueOrDie();
+ // Read specified columns with a row filter
+ auto scan_builder = dataset->NewScan().ValueOrDie();
+ std::vector<std::string> names;
+ std::vector<cp::Expression> exprs;
+ // Read all the original columns.
+ for (const auto& field : dataset->schema()->fields()) {
+ names.push_back(field->name());
+ exprs.push_back(cp::field_ref(field->name()));
+ }
+ // Also derive a new column.
+ names.emplace_back("b_large");
+ exprs.push_back(cp::greater(cp::field_ref("b"), cp::literal(1)));
+ ABORT_ON_FAILURE(scan_builder->Project(exprs, names));
+ auto scanner = scan_builder->Finish().ValueOrDie();
+ return scanner->ToTable().ValueOrDie();
+}
+// (Doc section: Projecting columns #2)
+
+// (Doc section: Reading and writing partitioned data #2)
+// Read an entire dataset, but with partitioning information.
+std::shared_ptr<arrow::Table> ScanPartitionedDataset(
+ const std::shared_ptr<fs::FileSystem>& filesystem,
+ const std::shared_ptr<ds::FileFormat>& format, const std::string& base_dir) {
+ fs::FileSelector selector;
+ selector.base_dir = base_dir;
+ selector.recursive = true; // Make sure to search subdirectories
+ ds::FileSystemFactoryOptions options;
+ // We'll use Hive-style partitioning. We'll let Arrow Datasets infer the partition
+ // schema.
+ options.partitioning = ds::HivePartitioning::MakeFactory();
+ auto factory = ds::FileSystemDatasetFactory::Make(filesystem, selector, format, options)
+ .ValueOrDie();
+ auto dataset = factory->Finish().ValueOrDie();
+ // Print out the fragments
+ for (const auto& fragment : dataset->GetFragments().ValueOrDie()) {
+ std::cout << "Found fragment: " << (*fragment)->ToString() << std::endl;
+ std::cout << "Partition expression: "
+ << (*fragment)->partition_expression().ToString() << std::endl;
+ }
+ auto scan_builder = dataset->NewScan().ValueOrDie();
+ auto scanner = scan_builder->Finish().ValueOrDie();
+ return scanner->ToTable().ValueOrDie();
+}
+// (Doc section: Reading and writing partitioned data #2)
+
+// (Doc section: Reading and writing partitioned data #3)
+// Read an entire dataset, but with partitioning information. Also, filter the dataset on
+// the partition values.
+std::shared_ptr<arrow::Table> FilterPartitionedDataset(
+ const std::shared_ptr<fs::FileSystem>& filesystem,
+ const std::shared_ptr<ds::FileFormat>& format, const std::string& base_dir) {
+ fs::FileSelector selector;
+ selector.base_dir = base_dir;
+ selector.recursive = true;
+ ds::FileSystemFactoryOptions options;
+ options.partitioning = ds::HivePartitioning::MakeFactory();
+ auto factory = ds::FileSystemDatasetFactory::Make(filesystem, selector, format, options)
+ .ValueOrDie();
+ auto dataset = factory->Finish().ValueOrDie();
+ auto scan_builder = dataset->NewScan().ValueOrDie();
+ // Filter based on the partition values. This will mean that we won't even read the
+ // files whose partition expressions don't match the filter.
+ ABORT_ON_FAILURE(
+ scan_builder->Filter(cp::equal(cp::field_ref("part"), cp::literal("b"))));
+ auto scanner = scan_builder->Finish().ValueOrDie();
+ return scanner->ToTable().ValueOrDie();
+}
+// (Doc section: Reading and writing partitioned data #3)
+
+int main(int argc, char** argv) {
+ if (argc < 3) {
+ // Fake success for CI purposes.
+ return EXIT_SUCCESS;
+ }
+
+ std::string uri = argv[1];
+ std::string format_name = argv[2];
+ std::string mode = argc > 3 ? argv[3] : "no_filter";
+ std::string root_path;
+ auto fs = fs::FileSystemFromUri(uri, &root_path).ValueOrDie();
+
+ std::string base_path;
+ std::shared_ptr<ds::FileFormat> format;
+ if (format_name == "feather") {
+ format = std::make_shared<ds::IpcFileFormat>();
+ base_path = CreateExampleFeatherDataset(fs, root_path);
+ } else if (format_name == "parquet") {
+ format = std::make_shared<ds::ParquetFileFormat>();
+ base_path = CreateExampleParquetDataset(fs, root_path);
+ } else if (format_name == "parquet_hive") {
+ format = std::make_shared<ds::ParquetFileFormat>();
+ base_path = CreateExampleParquetHivePartitionedDataset(fs, root_path);
+ } else {
+ std::cerr << "Unknown format: " << format_name << std::endl;
+ std::cerr << "Supported formats: feather, parquet, parquet_hive" << std::endl;
+ return EXIT_FAILURE;
+ }
+
+ std::shared_ptr<arrow::Table> table;
+ if (mode == "no_filter") {
+ table = ScanWholeDataset(fs, format, base_path);
+ } else if (mode == "filter") {
+ table = FilterAndSelectDataset(fs, format, base_path);
+ } else if (mode == "project") {
+ table = ProjectDataset(fs, format, base_path);
+ } else if (mode == "select_project") {
+ table = SelectAndProjectDataset(fs, format, base_path);
+ } else if (mode == "partitioned") {
+ table = ScanPartitionedDataset(fs, format, base_path);
+ } else if (mode == "filter_partitioned") {
+ table = FilterPartitionedDataset(fs, format, base_path);
+ } else {
+ std::cerr << "Unknown mode: " << mode << std::endl;
+ std::cerr
+ << "Supported modes: no_filter, filter, project, select_project, partitioned"
+ << std::endl;
+ return EXIT_FAILURE;
+ }
+ std::cout << "Read " << table->num_rows() << " rows" << std::endl;
+ std::cout << table->ToString() << std::endl;
+ return EXIT_SUCCESS;
+}
diff --git a/src/arrow/cpp/examples/arrow/dataset_parquet_scan_example.cc b/src/arrow/cpp/examples/arrow/dataset_parquet_scan_example.cc
new file mode 100644
index 000000000..cd9b89fe3
--- /dev/null
+++ b/src/arrow/cpp/examples/arrow/dataset_parquet_scan_example.cc
@@ -0,0 +1,190 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/api.h>
+#include <arrow/compute/exec/expression.h>
+#include <arrow/dataset/dataset.h>
+#include <arrow/dataset/discovery.h>
+#include <arrow/dataset/file_base.h>
+#include <arrow/dataset/file_parquet.h>
+#include <arrow/dataset/scanner.h>
+#include <arrow/filesystem/filesystem.h>
+#include <arrow/filesystem/path_util.h>
+
+#include <cstdlib>
+#include <iostream>
+
+using arrow::field;
+using arrow::int16;
+using arrow::Schema;
+using arrow::Table;
+
+namespace fs = arrow::fs;
+
+namespace ds = arrow::dataset;
+
+namespace cp = arrow::compute;
+
+#define ABORT_ON_FAILURE(expr) \
+ do { \
+ arrow::Status status_ = (expr); \
+ if (!status_.ok()) { \
+ std::cerr << status_.message() << std::endl; \
+ abort(); \
+ } \
+ } while (0);
+
+struct Configuration {
+ // Increase the ds::DataSet by repeating `repeat` times the ds::Dataset.
+ size_t repeat = 1;
+
+ // Indicates if the Scanner::ToTable should consume in parallel.
+ bool use_threads = true;
+
+ // Indicates to the Scan operator which columns are requested. This
+ // optimization avoid deserializing unneeded columns.
+ std::vector<std::string> projected_columns = {"pickup_at", "dropoff_at",
+ "total_amount"};
+
+ // Indicates the filter by which rows will be filtered. This optimization can
+ // make use of partition information and/or file metadata if possible.
+ cp::Expression filter =
+ cp::greater(cp::field_ref("total_amount"), cp::literal(1000.0f));
+
+ ds::InspectOptions inspect_options{};
+ ds::FinishOptions finish_options{};
+} conf;
+
+std::shared_ptr<fs::FileSystem> GetFileSystemFromUri(const std::string& uri,
+ std::string* path) {
+ return fs::FileSystemFromUri(uri, path).ValueOrDie();
+}
+
+std::shared_ptr<ds::Dataset> GetDatasetFromDirectory(
+ std::shared_ptr<fs::FileSystem> fs, std::shared_ptr<ds::ParquetFileFormat> format,
+ std::string dir) {
+ // Find all files under `path`
+ fs::FileSelector s;
+ s.base_dir = dir;
+ s.recursive = true;
+
+ ds::FileSystemFactoryOptions options;
+ // The factory will try to build a child dataset.
+ auto factory = ds::FileSystemDatasetFactory::Make(fs, s, format, options).ValueOrDie();
+
+ // Try to infer a common schema for all files.
+ auto schema = factory->Inspect(conf.inspect_options).ValueOrDie();
+ // Caller can optionally decide another schema as long as it is compatible
+ // with the previous one, e.g. `factory->Finish(compatible_schema)`.
+ auto child = factory->Finish(conf.finish_options).ValueOrDie();
+
+ ds::DatasetVector children{conf.repeat, child};
+ auto dataset = ds::UnionDataset::Make(std::move(schema), std::move(children));
+
+ return dataset.ValueOrDie();
+}
+
+std::shared_ptr<ds::Dataset> GetParquetDatasetFromMetadata(
+ std::shared_ptr<fs::FileSystem> fs, std::shared_ptr<ds::ParquetFileFormat> format,
+ std::string metadata_path) {
+ ds::ParquetFactoryOptions options;
+ auto factory =
+ ds::ParquetDatasetFactory::Make(metadata_path, fs, format, options).ValueOrDie();
+ return factory->Finish().ValueOrDie();
+}
+
+std::shared_ptr<ds::Dataset> GetDatasetFromFile(
+ std::shared_ptr<fs::FileSystem> fs, std::shared_ptr<ds::ParquetFileFormat> format,
+ std::string file) {
+ ds::FileSystemFactoryOptions options;
+ // The factory will try to build a child dataset.
+ auto factory =
+ ds::FileSystemDatasetFactory::Make(fs, {file}, format, options).ValueOrDie();
+
+ // Try to infer a common schema for all files.
+ auto schema = factory->Inspect(conf.inspect_options).ValueOrDie();
+ // Caller can optionally decide another schema as long as it is compatible
+ // with the previous one, e.g. `factory->Finish(compatible_schema)`.
+ auto child = factory->Finish(conf.finish_options).ValueOrDie();
+
+ ds::DatasetVector children;
+ children.resize(conf.repeat, child);
+ auto dataset = ds::UnionDataset::Make(std::move(schema), std::move(children));
+
+ return dataset.ValueOrDie();
+}
+
+std::shared_ptr<ds::Dataset> GetDatasetFromPath(
+ std::shared_ptr<fs::FileSystem> fs, std::shared_ptr<ds::ParquetFileFormat> format,
+ std::string path) {
+ auto info = fs->GetFileInfo(path).ValueOrDie();
+ if (info.IsDirectory()) {
+ return GetDatasetFromDirectory(fs, format, path);
+ }
+
+ auto dirname_basename = arrow::fs::internal::GetAbstractPathParent(path);
+ auto basename = dirname_basename.second;
+
+ if (basename == "_metadata") {
+ return GetParquetDatasetFromMetadata(fs, format, path);
+ }
+
+ return GetDatasetFromFile(fs, format, path);
+}
+
+std::shared_ptr<ds::Scanner> GetScannerFromDataset(std::shared_ptr<ds::Dataset> dataset,
+ std::vector<std::string> columns,
+ cp::Expression filter,
+ bool use_threads) {
+ auto scanner_builder = dataset->NewScan().ValueOrDie();
+
+ if (!columns.empty()) {
+ ABORT_ON_FAILURE(scanner_builder->Project(columns));
+ }
+
+ ABORT_ON_FAILURE(scanner_builder->Filter(filter));
+
+ ABORT_ON_FAILURE(scanner_builder->UseThreads(use_threads));
+
+ return scanner_builder->Finish().ValueOrDie();
+}
+
+std::shared_ptr<Table> GetTableFromScanner(std::shared_ptr<ds::Scanner> scanner) {
+ return scanner->ToTable().ValueOrDie();
+}
+
+int main(int argc, char** argv) {
+ auto format = std::make_shared<ds::ParquetFileFormat>();
+
+ if (argc != 2) {
+ // Fake success for CI purposes.
+ return EXIT_SUCCESS;
+ }
+
+ std::string path;
+ auto fs = GetFileSystemFromUri(argv[1], &path);
+
+ auto dataset = GetDatasetFromPath(fs, format, path);
+
+ auto scanner = GetScannerFromDataset(dataset, conf.projected_columns, conf.filter,
+ conf.use_threads);
+
+ auto table = GetTableFromScanner(scanner);
+ std::cout << "Table size: " << table->num_rows() << "\n";
+
+ return EXIT_SUCCESS;
+}
diff --git a/src/arrow/cpp/examples/arrow/row_wise_conversion_example.cc b/src/arrow/cpp/examples/arrow/row_wise_conversion_example.cc
new file mode 100644
index 000000000..1af1c5547
--- /dev/null
+++ b/src/arrow/cpp/examples/arrow/row_wise_conversion_example.cc
@@ -0,0 +1,207 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/api.h>
+#include <arrow/result.h>
+
+#include <cstdint>
+#include <iomanip>
+#include <iostream>
+#include <vector>
+
+using arrow::DoubleBuilder;
+using arrow::Int64Builder;
+using arrow::ListBuilder;
+
+// While we want to use columnar data structures to build efficient operations, we
+// often receive data in a row-wise fashion from other systems. In the following,
+// we want give a brief introduction into the classes provided by Apache Arrow by
+// showing how to transform row-wise data into a columnar table.
+//
+// The table contains an id for a product, the number of components in the product
+// and the cost of each component.
+//
+// The data in this example is stored in the following struct:
+struct data_row {
+ int64_t id;
+ int64_t components;
+ std::vector<double> component_cost;
+};
+
+// Transforming a vector of structs into a columnar Table.
+//
+// The final representation should be an `arrow::Table` which in turn
+// is made up of an `arrow::Schema` and a list of
+// `arrow::ChunkedArray` instances. As the first step, we will iterate
+// over the data and build up the arrays incrementally. For this
+// task, we provide `arrow::ArrayBuilder` classes that help in the
+// construction of the final `arrow::Array` instances.
+//
+// For each type, Arrow has a specially typed builder class. For the primitive
+// values `id` and `components` we can use the `arrow::Int64Builder`. For the
+// `component_cost` vector, we need to have two builders, a top-level
+// `arrow::ListBuilder` that builds the array of offsets and a nested
+// `arrow::DoubleBuilder` that constructs the underlying values array that
+// is referenced by the offsets in the former array.
+arrow::Result<std::shared_ptr<arrow::Table>> VectorToColumnarTable(
+ const std::vector<struct data_row>& rows) {
+ // The builders are more efficient using
+ // arrow::jemalloc::MemoryPool::default_pool() as this can increase the size of
+ // the underlying memory regions in-place. At the moment, arrow::jemalloc is only
+ // supported on Unix systems, not Windows.
+ arrow::MemoryPool* pool = arrow::default_memory_pool();
+
+ Int64Builder id_builder(pool);
+ Int64Builder components_builder(pool);
+ ListBuilder component_cost_builder(pool, std::make_shared<DoubleBuilder>(pool));
+ // The following builder is owned by component_cost_builder.
+ DoubleBuilder* component_item_cost_builder =
+ (static_cast<DoubleBuilder*>(component_cost_builder.value_builder()));
+
+ // Now we can loop over our existing data and insert it into the builders. The
+ // `Append` calls here may fail (e.g. we cannot allocate enough additional memory).
+ // Thus we need to check their return values. For more information on these values,
+ // check the documentation about `arrow::Status`.
+ for (const data_row& row : rows) {
+ ARROW_RETURN_NOT_OK(id_builder.Append(row.id));
+ ARROW_RETURN_NOT_OK(components_builder.Append(row.components));
+
+ // Indicate the start of a new list row. This will memorise the current
+ // offset in the values builder.
+ ARROW_RETURN_NOT_OK(component_cost_builder.Append());
+ // Store the actual values. The same memory layout is
+ // used for the component cost data, in this case a vector of
+ // type double, as for the memory that Arrow uses to hold this
+ // data and will be created.
+ ARROW_RETURN_NOT_OK(component_item_cost_builder->AppendValues(
+ row.component_cost.data(), row.component_cost.size()));
+ }
+
+ // At the end, we finalise the arrays, declare the (type) schema and combine them
+ // into a single `arrow::Table`:
+ std::shared_ptr<arrow::Array> id_array;
+ ARROW_RETURN_NOT_OK(id_builder.Finish(&id_array));
+ std::shared_ptr<arrow::Array> components_array;
+ ARROW_RETURN_NOT_OK(components_builder.Finish(&components_array));
+ // No need to invoke component_cost_builder.Finish because it is implied by
+ // the parent builder's Finish invocation.
+ std::shared_ptr<arrow::Array> component_cost_array;
+ ARROW_RETURN_NOT_OK(component_cost_builder.Finish(&component_cost_array));
+
+ std::vector<std::shared_ptr<arrow::Field>> schema_vector = {
+ arrow::field("id", arrow::int64()), arrow::field("components", arrow::int64()),
+ arrow::field("component_cost", arrow::list(arrow::float64()))};
+
+ auto schema = std::make_shared<arrow::Schema>(schema_vector);
+
+ // The final `table` variable is the one we can then pass on to other functions
+ // that can consume Apache Arrow memory structures. This object has ownership of
+ // all referenced data, thus we don't have to care about undefined references once
+ // we leave the scope of the function building the table and its underlying arrays.
+ std::shared_ptr<arrow::Table> table =
+ arrow::Table::Make(schema, {id_array, components_array, component_cost_array});
+
+ return table;
+}
+
+arrow::Result<std::vector<data_row>> ColumnarTableToVector(
+ const std::shared_ptr<arrow::Table>& table) {
+ // To convert an Arrow table back into the same row-wise representation as in the
+ // above section, we first will check that the table conforms to our expected
+ // schema and then will build up the vector of rows incrementally.
+ //
+ // For the check if the table is as expected, we can utilise solely its schema.
+ std::vector<std::shared_ptr<arrow::Field>> schema_vector = {
+ arrow::field("id", arrow::int64()), arrow::field("components", arrow::int64()),
+ arrow::field("component_cost", arrow::list(arrow::float64()))};
+ auto expected_schema = std::make_shared<arrow::Schema>(schema_vector);
+
+ if (!expected_schema->Equals(*table->schema())) {
+ // The table doesn't have the expected schema thus we cannot directly
+ // convert it to our target representation.
+ return arrow::Status::Invalid("Schemas are not matching!");
+ }
+
+ // As we have ensured that the table has the expected structure, we can unpack the
+ // underlying arrays. For the primitive columns `id` and `components` we can use the
+ // high level functions to get the values whereas for the nested column
+ // `component_costs` we need to access the C-pointer to the data to copy its
+ // contents into the resulting `std::vector<double>`. Here we need to be careful to
+ // also add the offset to the pointer. This offset is needed to enable zero-copy
+ // slicing operations. While this could be adjusted automatically for double
+ // arrays, this cannot be done for the accompanying bitmap as often the slicing
+ // border would be inside a byte.
+
+ auto ids = std::static_pointer_cast<arrow::Int64Array>(table->column(0)->chunk(0));
+ auto components =
+ std::static_pointer_cast<arrow::Int64Array>(table->column(1)->chunk(0));
+ auto component_cost =
+ std::static_pointer_cast<arrow::ListArray>(table->column(2)->chunk(0));
+ auto component_cost_values =
+ std::static_pointer_cast<arrow::DoubleArray>(component_cost->values());
+ // To enable zero-copy slices, the native values pointer might need to account
+ // for this slicing offset. This is not needed for the higher level functions
+ // like Value(…) that already account for this offset internally.
+ const double* ccv_ptr = component_cost_values->raw_values();
+ std::vector<data_row> rows;
+ for (int64_t i = 0; i < table->num_rows(); i++) {
+ // Another simplification in this example is that we assume that there are
+ // no null entries, e.g. each row is fill with valid values.
+ int64_t id = ids->Value(i);
+ int64_t component = components->Value(i);
+ const double* first = ccv_ptr + component_cost->value_offset(i);
+ const double* last = ccv_ptr + component_cost->value_offset(i + 1);
+ std::vector<double> components_vec(first, last);
+ rows.push_back({id, component, components_vec});
+ }
+
+ return rows;
+}
+
+int main(int argc, char** argv) {
+ std::vector<data_row> rows = {
+ {1, 1, {10.0}}, {2, 3, {11.0, 12.0, 13.0}}, {3, 2, {15.0, 25.0}}};
+ std::shared_ptr<arrow::Table> table;
+ std::vector<data_row> expected_rows;
+
+ arrow::Result<std::shared_ptr<arrow::Table>> table_result = VectorToColumnarTable(rows);
+ table = std::move(table_result).ValueOrDie();
+
+ arrow::Result<std::vector<data_row>> expected_rows_result =
+ ColumnarTableToVector(table);
+ expected_rows = std::move(expected_rows_result).ValueOrDie();
+
+ assert(rows.size() == expected_rows.size());
+
+ // Print out contents of table, should get
+ // ID Components Component prices
+ // 1 1 10
+ // 2 3 11 12 13
+ // 3 2 15 25
+ std::cout << std::left << std::setw(3) << "ID " << std::left << std::setw(11)
+ << "Components " << std::left << std::setw(15) << "Component prices "
+ << std::endl;
+ for (const auto& row : rows) {
+ std::cout << std::left << std::setw(3) << row.id << std::left << std::setw(11)
+ << row.components;
+ for (const auto& cost : row.component_cost) {
+ std::cout << std::left << std::setw(4) << cost;
+ }
+ std::cout << std::endl;
+ }
+ return EXIT_SUCCESS;
+}
diff --git a/src/arrow/cpp/examples/minimal_build/.gitignore b/src/arrow/cpp/examples/minimal_build/.gitignore
new file mode 100644
index 000000000..c94f3ec42
--- /dev/null
+++ b/src/arrow/cpp/examples/minimal_build/.gitignore
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+test.arrow
diff --git a/src/arrow/cpp/examples/minimal_build/CMakeLists.txt b/src/arrow/cpp/examples/minimal_build/CMakeLists.txt
new file mode 100644
index 000000000..9fc20c70f
--- /dev/null
+++ b/src/arrow/cpp/examples/minimal_build/CMakeLists.txt
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+cmake_minimum_required(VERSION 3.0)
+
+project(ArrowMinimalExample)
+
+option(ARROW_LINK_SHARED "Link to the Arrow shared library" ON)
+
+find_package(Arrow REQUIRED)
+
+set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_BUILD_TYPE Release)
+
+message(STATUS "Arrow version: ${ARROW_VERSION}")
+message(STATUS "Arrow SO version: ${ARROW_FULL_SO_VERSION}")
+
+add_executable(arrow_example example.cc)
+
+if (ARROW_LINK_SHARED)
+ target_link_libraries(arrow_example PRIVATE arrow_shared)
+else()
+ set(THREADS_PREFER_PTHREAD_FLAG ON)
+ find_package(Threads REQUIRED)
+ target_link_libraries(arrow_example PRIVATE arrow_static Threads::Threads)
+endif()
diff --git a/src/arrow/cpp/examples/minimal_build/README.md b/src/arrow/cpp/examples/minimal_build/README.md
new file mode 100644
index 000000000..9f889f6ad
--- /dev/null
+++ b/src/arrow/cpp/examples/minimal_build/README.md
@@ -0,0 +1,88 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# Minimal C++ build example
+
+This directory showcases a minimal build of Arrow C++ (in `build_arrow.sh`).
+This minimal build is then used by an example third-party C++ project
+using CMake logic to compile and link against the Arrow C++ library
+(in `build_example.sh` and `CMakeLists.txt`).
+
+When run, the example executable reads a file named `test.csv`,
+displays its parsed contents, and then saves them in Arrow IPC format in
+a file named `test.arrow`.
+
+## Running the example
+
+You can run this simple example using [Docker Compose][docker-compose]
+and the given `docker-compose.yml` and dockerfiles, which installs a
+minimal Ubuntu image with a basic C++ toolchain.
+
+Just open a terminal in this directory and run the following commands:
+
+```bash
+docker-compose run --rm minimal
+```
+
+Note that this example mounts two volumes inside the Docker image:
+* `/arrow` points to the Arrow source tree
+* `/io` points to this example directory
+
+## Statically-linked builds
+
+We've provided an example build configuration here with CMake to show how to
+create a statically-linked executable with bundled dependencies.
+
+To run it on Linux, you can use the above Docker image:
+
+```bash
+docker-compose run --rm static
+```
+
+On macOS, you can use the `run_static.sh` but you must set some environment
+variables to point the script to your Arrow checkout, for example:
+
+```bash
+export ARROW_DIR=path/to/arrow-clone
+export EXAMPLE_DIR=$ARROW_DIR/cpp/examples/minimal_build
+export ARROW_BUILD_DIR=$(pwd)/arrow-build
+export EXAMPLE_BUILD_DIR=$(pwd)/example
+
+./run_static.sh
+```
+
+On Windows, you can run `run_static.bat` from the command prompt with Visual
+Studio's command line tools enabled and CMake and ninja build in the path:
+
+```
+call run_static.bat
+```
+
+### Static linking against system libraries
+
+You can also use static libraries of Arrow's dependencies from the
+system. To run this configuration, set
+`ARROW_DEPENDENCY_SOURCE=SYSTEM` for `run_static.sh`. You can use
+`docker-compose` for this too:
+
+```bash
+docker-compose run --rm static-system-dependency
+```
+
+[docker-compose]: https://docs.docker.com/compose/
diff --git a/src/arrow/cpp/examples/minimal_build/build_arrow.sh b/src/arrow/cpp/examples/minimal_build/build_arrow.sh
new file mode 100755
index 000000000..402c312e4
--- /dev/null
+++ b/src/arrow/cpp/examples/minimal_build/build_arrow.sh
@@ -0,0 +1,35 @@
+#!/usr/bin/env bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -ex
+
+NPROC=$(nproc)
+
+mkdir -p $ARROW_BUILD_DIR
+pushd $ARROW_BUILD_DIR
+
+# Enable the CSV reader as it's used by the example third-party build
+cmake /arrow/cpp \
+ -DARROW_CSV=ON \
+ -DARROW_JEMALLOC=OFF \
+ $ARROW_CMAKE_OPTIONS
+
+make -j$NPROC
+make install
+
+popd
diff --git a/src/arrow/cpp/examples/minimal_build/build_example.sh b/src/arrow/cpp/examples/minimal_build/build_example.sh
new file mode 100755
index 000000000..a315755a5
--- /dev/null
+++ b/src/arrow/cpp/examples/minimal_build/build_example.sh
@@ -0,0 +1,27 @@
+#!/usr/bin/env bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -ex
+
+mkdir -p $EXAMPLE_BUILD_DIR
+pushd $EXAMPLE_BUILD_DIR
+
+cmake /io
+make
+
+popd
diff --git a/src/arrow/cpp/examples/minimal_build/docker-compose.yml b/src/arrow/cpp/examples/minimal_build/docker-compose.yml
new file mode 100644
index 000000000..6e2dcef81
--- /dev/null
+++ b/src/arrow/cpp/examples/minimal_build/docker-compose.yml
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+version: '3.5'
+
+services:
+ minimal:
+ build:
+ context: .
+ dockerfile: minimal.dockerfile
+ volumes:
+ - ../../../:/arrow:delegated
+ - .:/io:delegated
+ command:
+ - "/io/run.sh"
+
+ static:
+ build:
+ context: .
+ dockerfile: minimal.dockerfile
+ volumes:
+ - ../../../:/arrow:delegated
+ - .:/io:delegated
+ command:
+ - "/io/run_static.sh"
+
+ static-system-dependency:
+ build:
+ context: .
+ dockerfile: system_dependency.dockerfile
+ environment:
+ ARROW_DEPENDENCY_SOURCE: "SYSTEM"
+ volumes:
+ - ../../../:/arrow:delegated
+ - .:/io:delegated
+ command:
+ - "/io/run_static.sh"
diff --git a/src/arrow/cpp/examples/minimal_build/example.cc b/src/arrow/cpp/examples/minimal_build/example.cc
new file mode 100644
index 000000000..9bfb9953e
--- /dev/null
+++ b/src/arrow/cpp/examples/minimal_build/example.cc
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/csv/api.h>
+#include <arrow/io/api.h>
+#include <arrow/ipc/api.h>
+#include <arrow/pretty_print.h>
+#include <arrow/result.h>
+#include <arrow/status.h>
+#include <arrow/table.h>
+
+#include <iostream>
+
+using arrow::Status;
+
+namespace {
+
+Status RunMain(int argc, char** argv) {
+ const char* csv_filename = "test.csv";
+ const char* arrow_filename = "test.arrow";
+
+ std::cerr << "* Reading CSV file '" << csv_filename << "' into table" << std::endl;
+ ARROW_ASSIGN_OR_RAISE(auto input_file, arrow::io::ReadableFile::Open(csv_filename));
+ ARROW_ASSIGN_OR_RAISE(auto csv_reader, arrow::csv::TableReader::Make(
+ arrow::io::default_io_context(), input_file,
+ arrow::csv::ReadOptions::Defaults(),
+ arrow::csv::ParseOptions::Defaults(),
+ arrow::csv::ConvertOptions::Defaults()));
+ ARROW_ASSIGN_OR_RAISE(auto table, csv_reader->Read());
+
+ std::cerr << "* Read table:" << std::endl;
+ ARROW_RETURN_NOT_OK(arrow::PrettyPrint(*table, {}, &std::cerr));
+
+ std::cerr << "* Writing table into Arrow IPC file '" << arrow_filename << "'"
+ << std::endl;
+ ARROW_ASSIGN_OR_RAISE(auto output_file,
+ arrow::io::FileOutputStream::Open(arrow_filename));
+ ARROW_ASSIGN_OR_RAISE(auto batch_writer,
+ arrow::ipc::MakeFileWriter(output_file, table->schema()));
+ ARROW_RETURN_NOT_OK(batch_writer->WriteTable(*table));
+ ARROW_RETURN_NOT_OK(batch_writer->Close());
+
+ return Status::OK();
+}
+
+} // namespace
+
+int main(int argc, char** argv) {
+ Status st = RunMain(argc, argv);
+ if (!st.ok()) {
+ std::cerr << st << std::endl;
+ return 1;
+ }
+ return 0;
+}
diff --git a/src/arrow/cpp/examples/minimal_build/minimal.dockerfile b/src/arrow/cpp/examples/minimal_build/minimal.dockerfile
new file mode 100644
index 000000000..9361fc5e8
--- /dev/null
+++ b/src/arrow/cpp/examples/minimal_build/minimal.dockerfile
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+FROM ubuntu:focal
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get update -y -q && \
+ apt-get install -y -q --no-install-recommends \
+ build-essential \
+ cmake \
+ pkg-config && \
+ apt-get clean && rm -rf /var/lib/apt/lists*
diff --git a/src/arrow/cpp/examples/minimal_build/run.sh b/src/arrow/cpp/examples/minimal_build/run.sh
new file mode 100755
index 000000000..a76058b0b
--- /dev/null
+++ b/src/arrow/cpp/examples/minimal_build/run.sh
@@ -0,0 +1,48 @@
+#!/usr/bin/env bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+
+cd /io
+
+export ARROW_BUILD_DIR=/build/arrow
+export EXAMPLE_BUILD_DIR=/build/example
+
+echo
+echo "=="
+echo "== Building Arrow C++ library"
+echo "=="
+echo
+
+./build_arrow.sh
+
+echo
+echo "=="
+echo "== Building example project using Arrow C++ library"
+echo "=="
+echo
+
+./build_example.sh
+
+echo
+echo "=="
+echo "== Running example project"
+echo "=="
+echo
+
+${EXAMPLE_BUILD_DIR}/arrow_example
diff --git a/src/arrow/cpp/examples/minimal_build/run_static.bat b/src/arrow/cpp/examples/minimal_build/run_static.bat
new file mode 100644
index 000000000..bbc7ff8f7
--- /dev/null
+++ b/src/arrow/cpp/examples/minimal_build/run_static.bat
@@ -0,0 +1,88 @@
+@rem Licensed to the Apache Software Foundation (ASF) under one
+@rem or more contributor license agreements. See the NOTICE file
+@rem distributed with this work for additional information
+@rem regarding copyright ownership. The ASF licenses this file
+@rem to you under the Apache License, Version 2.0 (the
+@rem "License"); you may not use this file except in compliance
+@rem with the License. You may obtain a copy of the License at
+@rem
+@rem http://www.apache.org/licenses/LICENSE-2.0
+@rem
+@rem Unless required by applicable law or agreed to in writing,
+@rem software distributed under the License is distributed on an
+@rem "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+@rem KIND, either express or implied. See the License for the
+@rem specific language governing permissions and limitations
+@rem under the License.
+
+@echo on
+
+@rem clean up prior attempts
+if exist "arrow-build" rd arrow-build /s /q
+if exist "dist" rd dist /s /q
+if exist "example" rd example /s /q
+
+echo
+echo "=="
+echo "== Building Arrow C++ library"
+echo "=="
+echo
+
+set INSTALL_PREFIX=%cd%\dist
+
+mkdir arrow-build
+pushd arrow-build
+
+@rem bzip2_ep fails with this method
+
+cmake ..\..\.. ^
+ -GNinja ^
+ -DCMAKE_INSTALL_PREFIX=%INSTALL_PREFIX% ^
+ -DARROW_DEPENDENCY_SOURCE=BUNDLED ^
+ -DARROW_BUILD_SHARED=OFF ^
+ -DARROW_BUILD_STATIC=ON ^
+ -DARROW_COMPUTE=ON ^
+ -DARROW_CSV=ON ^
+ -DARROW_DATASET=ON ^
+ -DARROW_FILESYSTEM=ON ^
+ -DARROW_HDFS=ON ^
+ -DARROW_JSON=ON ^
+ -DARROW_MIMALLOC=ON ^
+ -DARROW_ORC=ON ^
+ -DARROW_PARQUET=ON ^
+ -DARROW_PLASMA=ON ^
+ -DARROW_WITH_BROTLI=ON ^
+ -DARROW_WITH_BZ2=OFF ^
+ -DARROW_WITH_LZ4=ON ^
+ -DARROW_WITH_SNAPPY=ON ^
+ -DARROW_WITH_ZLIB=ON ^
+ -DARROW_WITH_ZSTD=ON
+
+ninja install
+
+popd
+
+echo
+echo "=="
+echo "== Building example project using Arrow C++ library"
+echo "=="
+echo
+
+mkdir example
+pushd example
+
+cmake .. ^
+ -GNinja ^
+ -DCMAKE_PREFIX_PATH="%INSTALL_PREFIX%" ^
+ -DARROW_LINK_SHARED=OFF
+ninja
+
+popd
+
+echo
+echo "=="
+echo "== Running example project"
+echo "=="
+echo
+
+call example\arrow_example.exe
diff --git a/src/arrow/cpp/examples/minimal_build/run_static.sh b/src/arrow/cpp/examples/minimal_build/run_static.sh
new file mode 100755
index 000000000..ff3bb8945
--- /dev/null
+++ b/src/arrow/cpp/examples/minimal_build/run_static.sh
@@ -0,0 +1,121 @@
+#!/usr/bin/env bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+
+: ${ARROW_DIR:=/arrow}
+: ${EXAMPLE_DIR:=/io}
+: ${ARROW_BUILD_DIR:=/build/arrow}
+: ${EXAMPLE_BUILD_DIR:=/build/example}
+
+: ${ARROW_DEPENDENCY_SOURCE:=BUNDLED}
+
+echo
+echo "=="
+echo "== Building Arrow C++ library"
+echo "=="
+echo
+
+mkdir -p $ARROW_BUILD_DIR
+pushd $ARROW_BUILD_DIR
+
+NPROC=$(nproc)
+
+cmake $ARROW_DIR/cpp \
+ -DARROW_BUILD_SHARED=OFF \
+ -DARROW_BUILD_STATIC=ON \
+ -DARROW_COMPUTE=ON \
+ -DARROW_CSV=ON \
+ -DARROW_DATASET=ON \
+ -DARROW_DEPENDENCY_SOURCE=${ARROW_DEPENDENCY_SOURCE} \
+ -DARROW_DEPENDENCY_USE_SHARED=OFF \
+ -DARROW_FILESYSTEM=ON \
+ -DARROW_HDFS=ON \
+ -DARROW_JEMALLOC=ON \
+ -DARROW_JSON=ON \
+ -DARROW_ORC=ON \
+ -DARROW_PARQUET=ON \
+ -DARROW_PLASMA=ON \
+ -DARROW_WITH_BROTLI=ON \
+ -DARROW_WITH_BZ2=ON \
+ -DARROW_WITH_LZ4=ON \
+ -DARROW_WITH_SNAPPY=ON \
+ -DARROW_WITH_ZLIB=ON \
+ -DARROW_WITH_ZSTD=ON \
+ -DORC_SOURCE=BUNDLED \
+ $ARROW_CMAKE_OPTIONS
+
+make -j$NPROC
+make install
+
+popd
+
+echo
+echo "=="
+echo "== CMake:"
+echo "== Building example project using Arrow C++ library"
+echo "=="
+echo
+
+rm -rf $EXAMPLE_BUILD_DIR
+mkdir -p $EXAMPLE_BUILD_DIR
+pushd $EXAMPLE_BUILD_DIR
+
+cmake $EXAMPLE_DIR -DARROW_LINK_SHARED=OFF
+make
+
+popd
+
+echo
+echo "=="
+echo "== CMake:"
+echo "== Running example project"
+echo "=="
+echo
+
+pushd $EXAMPLE_DIR
+
+$EXAMPLE_BUILD_DIR/arrow_example
+
+echo
+echo "=="
+echo "== pkg-config"
+echo "== Building example project using Arrow C++ library"
+echo "=="
+echo
+
+rm -rf $EXAMPLE_BUILD_DIR
+mkdir -p $EXAMPLE_BUILD_DIR
+${CXX:-c++} \
+ -o $EXAMPLE_BUILD_DIR/arrow_example \
+ $EXAMPLE_DIR/example.cc \
+ $(PKG_CONFIG_PATH=$ARROW_BUILD_DIR/lib/pkgconfig \
+ pkg-config --cflags --libs --static arrow)
+
+popd
+
+echo
+echo "=="
+echo "== pkg-config:"
+echo "== Running example project"
+echo "=="
+echo
+
+pushd $EXAMPLE_DIR
+
+$EXAMPLE_BUILD_DIR/arrow_example
diff --git a/src/arrow/cpp/examples/minimal_build/system_dependency.dockerfile b/src/arrow/cpp/examples/minimal_build/system_dependency.dockerfile
new file mode 100644
index 000000000..926fcaf6f
--- /dev/null
+++ b/src/arrow/cpp/examples/minimal_build/system_dependency.dockerfile
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+FROM ubuntu:focal
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get update -y -q && \
+ apt-get install -y -q --no-install-recommends \
+ build-essential \
+ cmake \
+ libboost-filesystem-dev \
+ libboost-regex-dev \
+ libboost-system-dev \
+ libbrotli-dev \
+ libbz2-dev \
+ libgflags-dev \
+ liblz4-dev \
+ libprotobuf-dev \
+ libprotoc-dev \
+ libre2-dev \
+ libsnappy-dev \
+ libthrift-dev \
+ libutf8proc-dev \
+ libzstd-dev \
+ pkg-config \
+ protobuf-compiler \
+ rapidjson-dev \
+ zlib1g-dev && \
+ apt-get clean && rm -rf /var/lib/apt/lists*
diff --git a/src/arrow/cpp/examples/minimal_build/test.csv b/src/arrow/cpp/examples/minimal_build/test.csv
new file mode 100644
index 000000000..ca2440852
--- /dev/null
+++ b/src/arrow/cpp/examples/minimal_build/test.csv
@@ -0,0 +1,3 @@
+Integers,Strings,Timestamps
+1,Some,2018-11-13 17:11:10
+2,data,N/A
diff --git a/src/arrow/cpp/examples/parquet/CMakeLists.txt b/src/arrow/cpp/examples/parquet/CMakeLists.txt
new file mode 100644
index 000000000..2d16948ae
--- /dev/null
+++ b/src/arrow/cpp/examples/parquet/CMakeLists.txt
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+add_executable(parquet_low_level_example low_level_api/reader_writer.cc)
+add_executable(parquet_low_level_example2 low_level_api/reader_writer2.cc)
+add_executable(parquet_arrow_example parquet_arrow/reader_writer.cc)
+add_executable(parquet_stream_api_example parquet_stream_api/stream_reader_writer.cc)
+target_include_directories(parquet_low_level_example PRIVATE low_level_api/)
+target_include_directories(parquet_low_level_example2 PRIVATE low_level_api/)
+
+# The variables in these files are for illustration purposes
+set(PARQUET_EXAMPLES_WARNING_SUPPRESSIONS
+ low_level_api/reader_writer.cc
+ low_level_api/reader_writer2.cc)
+
+if (PARQUET_REQUIRE_ENCRYPTION)
+ add_executable(parquet_encryption_example low_level_api/encryption_reader_writer.cc)
+ add_executable(parquet_encryption_example_all_crypto_options low_level_api/encryption_reader_writer_all_crypto_options.cc)
+ target_include_directories(parquet_encryption_example PRIVATE low_level_api/)
+ target_include_directories(parquet_encryption_example_all_crypto_options PRIVATE low_level_api/)
+
+ set(PARQUET_EXAMPLES_WARNING_SUPPRESSIONS
+ ${PARQUET_EXAMPLES_WARNING_SUPPRESSIONS}
+ low_level_api/encryption_reader_writer.cc
+ low_level_api/encryption_reader_writer_all_crypto_options.cc)
+
+endif()
+
+if(UNIX)
+ foreach(FILE ${PARQUET_EXAMPLES_WARNING_SUPPRESSIONS})
+ set_property(SOURCE ${FILE}
+ APPEND_STRING
+ PROPERTY COMPILE_FLAGS "-Wno-unused-variable")
+ endforeach()
+endif()
+
+# Prefer shared linkage but use static if shared build is deactivated
+if (ARROW_BUILD_SHARED)
+ set(PARQUET_EXAMPLE_LINK_LIBS parquet_shared)
+else()
+ set(PARQUET_EXAMPLE_LINK_LIBS parquet_static)
+endif()
+
+target_link_libraries(parquet_arrow_example ${PARQUET_EXAMPLE_LINK_LIBS})
+target_link_libraries(parquet_low_level_example ${PARQUET_EXAMPLE_LINK_LIBS})
+target_link_libraries(parquet_low_level_example2 ${PARQUET_EXAMPLE_LINK_LIBS})
+target_link_libraries(parquet_stream_api_example ${PARQUET_EXAMPLE_LINK_LIBS})
+
+if(PARQUET_REQUIRE_ENCRYPTION)
+ target_link_libraries(parquet_encryption_example ${PARQUET_EXAMPLE_LINK_LIBS})
+ target_link_libraries(parquet_encryption_example_all_crypto_options ${PARQUET_EXAMPLE_LINK_LIBS})
+endif()
+
+add_dependencies(parquet
+ parquet_low_level_example
+ parquet_low_level_example2
+ parquet_arrow_example
+ parquet_stream_api_example)
+
+if (PARQUET_REQUIRE_ENCRYPTION)
+ add_dependencies(parquet
+ parquet_encryption_example
+ parquet_encryption_example_all_crypto_options)
+endif()
diff --git a/src/arrow/cpp/examples/parquet/low_level_api/encryption_reader_writer.cc b/src/arrow/cpp/examples/parquet/low_level_api/encryption_reader_writer.cc
new file mode 100644
index 000000000..75788b283
--- /dev/null
+++ b/src/arrow/cpp/examples/parquet/low_level_api/encryption_reader_writer.cc
@@ -0,0 +1,451 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <reader_writer.h>
+
+#include <cassert>
+#include <fstream>
+#include <iostream>
+#include <memory>
+
+/*
+ * This file contains sample for writing and reading encrypted Parquet file with
+ * basic encryption configuration.
+ *
+ * A detailed description of the Parquet Modular Encryption specification can be found
+ * here:
+ * https://github.com/apache/parquet-format/blob/encryption/Encryption.md
+ *
+ * The write sample creates a file with eight columns where two of the columns and the
+ * footer are encrypted.
+ *
+ * The read sample decrypts using key retriever that holds the keys of two encrypted
+ * columns and the footer key.
+ */
+
+constexpr int NUM_ROWS_PER_ROW_GROUP = 500;
+const char* PARQUET_FILENAME = "parquet_cpp_example.parquet.encrypted";
+const char* kFooterEncryptionKey = "0123456789012345"; // 128bit/16
+const char* kColumnEncryptionKey1 = "1234567890123450";
+const char* kColumnEncryptionKey2 = "1234567890123451";
+
+int main(int argc, char** argv) {
+ /**********************************************************************************
+ PARQUET ENCRYPTION WRITER EXAMPLE
+ **********************************************************************************/
+
+ try {
+ // Create a local file output stream instance.
+ using FileClass = ::arrow::io::FileOutputStream;
+ std::shared_ptr<FileClass> out_file;
+ PARQUET_ASSIGN_OR_THROW(out_file, FileClass::Open(PARQUET_FILENAME));
+
+ // Setup the parquet schema
+ std::shared_ptr<GroupNode> schema = SetupSchema();
+
+ // Add encryption properties
+ // Encryption configuration: Encrypt two columns and the footer.
+ std::map<std::string, std::shared_ptr<parquet::ColumnEncryptionProperties>>
+ encryption_cols;
+
+ parquet::SchemaDescriptor schema_desc;
+ schema_desc.Init(schema);
+ auto column_path1 = schema_desc.Column(5)->path()->ToDotString();
+ auto column_path2 = schema_desc.Column(4)->path()->ToDotString();
+
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder0(column_path1);
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder1(column_path2);
+ encryption_col_builder0.key(kColumnEncryptionKey1)->key_id("kc1");
+ encryption_col_builder1.key(kColumnEncryptionKey2)->key_id("kc2");
+
+ encryption_cols[column_path1] = encryption_col_builder0.build();
+ encryption_cols[column_path2] = encryption_col_builder1.build();
+
+ parquet::FileEncryptionProperties::Builder file_encryption_builder(
+ kFooterEncryptionKey);
+
+ parquet::WriterProperties::Builder builder;
+ // Add the current encryption configuration to WriterProperties.
+ builder.encryption(file_encryption_builder.footer_key_metadata("kf")
+ ->encrypted_columns(encryption_cols)
+ ->build());
+
+ // Add other writer properties
+ builder.compression(parquet::Compression::SNAPPY);
+
+ std::shared_ptr<parquet::WriterProperties> props = builder.build();
+
+ // Create a ParquetFileWriter instance
+ std::shared_ptr<parquet::ParquetFileWriter> file_writer =
+ parquet::ParquetFileWriter::Open(out_file, schema, props);
+
+ // Append a RowGroup with a specific number of rows.
+ parquet::RowGroupWriter* rg_writer = file_writer->AppendRowGroup();
+
+ // Write the Bool column
+ parquet::BoolWriter* bool_writer =
+ static_cast<parquet::BoolWriter*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ bool value = ((i % 2) == 0) ? true : false;
+ bool_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Write the Int32 column
+ parquet::Int32Writer* int32_writer =
+ static_cast<parquet::Int32Writer*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ int32_t value = i;
+ int32_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Write the Int64 column. Each row has repeats twice.
+ parquet::Int64Writer* int64_writer =
+ static_cast<parquet::Int64Writer*>(rg_writer->NextColumn());
+ for (int i = 0; i < 2 * NUM_ROWS_PER_ROW_GROUP; i++) {
+ int64_t value = i * 1000 * 1000;
+ value *= 1000 * 1000;
+ int16_t definition_level = 1;
+ int16_t repetition_level = 0;
+ if ((i % 2) == 0) {
+ repetition_level = 1; // start of a new record
+ }
+ int64_writer->WriteBatch(1, &definition_level, &repetition_level, &value);
+ }
+
+ // Write the INT96 column.
+ parquet::Int96Writer* int96_writer =
+ static_cast<parquet::Int96Writer*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ parquet::Int96 value;
+ value.value[0] = i;
+ value.value[1] = i + 1;
+ value.value[2] = i + 2;
+ int96_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Write the Float column
+ parquet::FloatWriter* float_writer =
+ static_cast<parquet::FloatWriter*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ float value = static_cast<float>(i) * 1.1f;
+ float_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Write the Double column
+ parquet::DoubleWriter* double_writer =
+ static_cast<parquet::DoubleWriter*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ double value = i * 1.1111111;
+ double_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Write the ByteArray column. Make every alternate values NULL
+ parquet::ByteArrayWriter* ba_writer =
+ static_cast<parquet::ByteArrayWriter*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ parquet::ByteArray value;
+ char hello[FIXED_LENGTH] = "parquet";
+ hello[7] = static_cast<char>(static_cast<int>('0') + i / 100);
+ hello[8] = static_cast<char>(static_cast<int>('0') + (i / 10) % 10);
+ hello[9] = static_cast<char>(static_cast<int>('0') + i % 10);
+ if (i % 2 == 0) {
+ int16_t definition_level = 1;
+ value.ptr = reinterpret_cast<const uint8_t*>(&hello[0]);
+ value.len = FIXED_LENGTH;
+ ba_writer->WriteBatch(1, &definition_level, nullptr, &value);
+ } else {
+ int16_t definition_level = 0;
+ ba_writer->WriteBatch(1, &definition_level, nullptr, nullptr);
+ }
+ }
+
+ // Write the FixedLengthByteArray column
+ parquet::FixedLenByteArrayWriter* flba_writer =
+ static_cast<parquet::FixedLenByteArrayWriter*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ parquet::FixedLenByteArray value;
+ char v = static_cast<char>(i);
+ char flba[FIXED_LENGTH] = {v, v, v, v, v, v, v, v, v, v};
+ value.ptr = reinterpret_cast<const uint8_t*>(&flba[0]);
+
+ flba_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Close the ParquetFileWriter
+ file_writer->Close();
+
+ // Write the bytes to file
+ DCHECK(out_file->Close().ok());
+ } catch (const std::exception& e) {
+ std::cerr << "Parquet write error: " << e.what() << std::endl;
+ return -1;
+ }
+
+ /**********************************************************************************
+ PARQUET ENCRYPTION READER EXAMPLE
+ **********************************************************************************/
+
+ // Decryption configuration: Decrypt using key retriever callback that holds the keys
+ // of two encrypted columns and the footer key.
+ std::shared_ptr<parquet::StringKeyIdRetriever> string_kr1 =
+ std::make_shared<parquet::StringKeyIdRetriever>();
+ string_kr1->PutKey("kf", kFooterEncryptionKey);
+ string_kr1->PutKey("kc1", kColumnEncryptionKey1);
+ string_kr1->PutKey("kc2", kColumnEncryptionKey2);
+ std::shared_ptr<parquet::DecryptionKeyRetriever> kr1 =
+ std::static_pointer_cast<parquet::StringKeyIdRetriever>(string_kr1);
+
+ parquet::FileDecryptionProperties::Builder file_decryption_builder;
+
+ try {
+ parquet::ReaderProperties reader_properties = parquet::default_reader_properties();
+
+ // Add the current decryption configuration to ReaderProperties.
+ reader_properties.file_decryption_properties(
+ file_decryption_builder.key_retriever(kr1)->build());
+
+ // Create a ParquetReader instance
+ std::unique_ptr<parquet::ParquetFileReader> parquet_reader =
+ parquet::ParquetFileReader::OpenFile(PARQUET_FILENAME, false, reader_properties);
+
+ // Get the File MetaData
+ std::shared_ptr<parquet::FileMetaData> file_metadata = parquet_reader->metadata();
+
+ // Get the number of RowGroups
+ int num_row_groups = file_metadata->num_row_groups();
+ assert(num_row_groups == 1);
+
+ // Get the number of Columns
+ int num_columns = file_metadata->num_columns();
+ assert(num_columns == 8);
+
+ // Iterate over all the RowGroups in the file
+ for (int r = 0; r < num_row_groups; ++r) {
+ // Get the RowGroup Reader
+ std::shared_ptr<parquet::RowGroupReader> row_group_reader =
+ parquet_reader->RowGroup(r);
+
+ int64_t values_read = 0;
+ int64_t rows_read = 0;
+ int16_t definition_level;
+ int16_t repetition_level;
+ int i;
+ std::shared_ptr<parquet::ColumnReader> column_reader;
+
+ // Get the Column Reader for the boolean column
+ column_reader = row_group_reader->Column(0);
+ parquet::BoolReader* bool_reader =
+ static_cast<parquet::BoolReader*>(column_reader.get());
+
+ // Read all the rows in the column
+ i = 0;
+ while (bool_reader->HasNext()) {
+ bool value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = bool_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ bool expected_value = ((i % 2) == 0) ? true : false;
+ assert(value == expected_value);
+ i++;
+ }
+
+ // Get the Column Reader for the Int32 column
+ column_reader = row_group_reader->Column(1);
+ parquet::Int32Reader* int32_reader =
+ static_cast<parquet::Int32Reader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (int32_reader->HasNext()) {
+ int32_t value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = int32_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ assert(value == i);
+ i++;
+ }
+
+ // Get the Column Reader for the Int64 column
+ column_reader = row_group_reader->Column(2);
+ parquet::Int64Reader* int64_reader =
+ static_cast<parquet::Int64Reader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (int64_reader->HasNext()) {
+ int64_t value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = int64_reader->ReadBatch(1, &definition_level, &repetition_level,
+ &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ int64_t expected_value = i * 1000 * 1000;
+ expected_value *= 1000 * 1000;
+ assert(value == expected_value);
+ if ((i % 2) == 0) {
+ assert(repetition_level == 1);
+ } else {
+ assert(repetition_level == 0);
+ }
+ i++;
+ }
+
+ // Get the Column Reader for the Int96 column
+ column_reader = row_group_reader->Column(3);
+ parquet::Int96Reader* int96_reader =
+ static_cast<parquet::Int96Reader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (int96_reader->HasNext()) {
+ parquet::Int96 value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = int96_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ parquet::Int96 expected_value;
+ expected_value.value[0] = i;
+ expected_value.value[1] = i + 1;
+ expected_value.value[2] = i + 2;
+ for (int j = 0; j < 3; j++) {
+ assert(value.value[j] == expected_value.value[j]);
+ }
+ ARROW_UNUSED(expected_value); // suppress compiler warning in release builds
+ i++;
+ }
+
+ // Get the Column Reader for the Float column
+ column_reader = row_group_reader->Column(4);
+ parquet::FloatReader* float_reader =
+ static_cast<parquet::FloatReader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (float_reader->HasNext()) {
+ float value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = float_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ float expected_value = static_cast<float>(i) * 1.1f;
+ assert(value == expected_value);
+ i++;
+ }
+
+ // Get the Column Reader for the Double column
+ column_reader = row_group_reader->Column(5);
+ parquet::DoubleReader* double_reader =
+ static_cast<parquet::DoubleReader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (double_reader->HasNext()) {
+ double value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = double_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ double expected_value = i * 1.1111111;
+ assert(value == expected_value);
+ i++;
+ }
+
+ // Get the Column Reader for the ByteArray column
+ column_reader = row_group_reader->Column(6);
+ parquet::ByteArrayReader* ba_reader =
+ static_cast<parquet::ByteArrayReader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (ba_reader->HasNext()) {
+ parquet::ByteArray value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read =
+ ba_reader->ReadBatch(1, &definition_level, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ ARROW_UNUSED(rows_read); // suppress compiler warning in release builds
+ // Verify the value written
+ char expected_value[FIXED_LENGTH] = "parquet";
+ expected_value[7] = static_cast<char>('0' + i / 100);
+ expected_value[8] = static_cast<char>('0' + (i / 10) % 10);
+ expected_value[9] = static_cast<char>('0' + i % 10);
+ if (i % 2 == 0) { // only alternate values exist
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ assert(value.len == FIXED_LENGTH);
+ assert(memcmp(value.ptr, &expected_value[0], FIXED_LENGTH) == 0);
+ assert(definition_level == 1);
+ } else {
+ // There are NULL values in the rows written
+ assert(values_read == 0);
+ assert(definition_level == 0);
+ }
+ ARROW_UNUSED(expected_value); // suppress compiler warning in release builds
+ i++;
+ }
+
+ // Get the Column Reader for the FixedLengthByteArray column
+ column_reader = row_group_reader->Column(7);
+ parquet::FixedLenByteArrayReader* flba_reader =
+ static_cast<parquet::FixedLenByteArrayReader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (flba_reader->HasNext()) {
+ parquet::FixedLenByteArray value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = flba_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ char v = static_cast<char>(i);
+ char expected_value[FIXED_LENGTH] = {v, v, v, v, v, v, v, v, v, v};
+ assert(memcmp(value.ptr, &expected_value[0], FIXED_LENGTH) == 0);
+ i++;
+ }
+ }
+ } catch (const std::exception& e) {
+ std::cerr << "Parquet read error: " << e.what() << std::endl;
+ }
+
+ std::cout << "Parquet Writing and Reading Complete" << std::endl;
+ return 0;
+}
diff --git a/src/arrow/cpp/examples/parquet/low_level_api/encryption_reader_writer_all_crypto_options.cc b/src/arrow/cpp/examples/parquet/low_level_api/encryption_reader_writer_all_crypto_options.cc
new file mode 100644
index 000000000..5b01e0284
--- /dev/null
+++ b/src/arrow/cpp/examples/parquet/low_level_api/encryption_reader_writer_all_crypto_options.cc
@@ -0,0 +1,656 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/io/file.h>
+#include <arrow/util/logging.h>
+#include <dirent.h>
+#include <parquet/api/reader.h>
+#include <parquet/api/writer.h>
+
+#include <cassert>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <regex>
+#include <sstream>
+
+/*
+ * This file contains samples for writing and reading encrypted Parquet files in different
+ * encryption and decryption configurations.
+ * Each sample section is dedicated to an independent configuration and shows its creation
+ * from beginning to end.
+ * The samples have the following goals:
+ * 1) Demonstrate usage of different options for data encryption and decryption.
+ * 2) Produce encrypted files for interoperability tests with other (eg parquet-mr)
+ * readers that support encryption.
+ * 3) Produce encrypted files with plaintext footer, for testing the ability of legacy
+ * readers to parse the footer and read unencrypted columns.
+ * 4) Perform interoperability tests with other (eg parquet-mr) writers, by reading
+ * encrypted files produced by these writers.
+ *
+ * Each write sample produces new independent parquet file, encrypted with a different
+ * encryption configuration as described below.
+ * The name of each file is in the form of:
+ * tester<encryption config number>.parquet.encrypted.
+ *
+ * The read sample creates a set of decryption configurations and then uses each of them
+ * to read all encrypted files in the input directory.
+ *
+ * The different encryption and decryption configurations are listed below.
+ *
+ * Usage: ./encryption-interop-tests <write/read> <path-to-directory-of-parquet-files>
+ *
+ * A detailed description of the Parquet Modular Encryption specification can be found
+ * here:
+ * https://github.com/apache/parquet-format/blob/encryption/Encryption.md
+ *
+ * The write sample creates files with four columns in the following
+ * encryption configurations:
+ *
+ * - Encryption configuration 1: Encrypt all columns and the footer with the same key.
+ * (uniform encryption)
+ * - Encryption configuration 2: Encrypt two columns and the footer, with different
+ * keys.
+ * - Encryption configuration 3: Encrypt two columns, with different keys.
+ * Don’t encrypt footer (to enable legacy readers)
+ * - plaintext footer mode.
+ * - Encryption configuration 4: Encrypt two columns and the footer, with different
+ * keys. Supply aad_prefix for file identity
+ * verification.
+ * - Encryption configuration 5: Encrypt two columns and the footer, with different
+ * keys. Supply aad_prefix, and call
+ * disable_aad_prefix_storage to prevent file
+ * identity storage in file metadata.
+ * - Encryption configuration 6: Encrypt two columns and the footer, with different
+ * keys. Use the alternative (AES_GCM_CTR_V1) algorithm.
+ *
+ * The read sample uses each of the following decryption configurations to read every
+ * encrypted files in the input directory:
+ *
+ * - Decryption configuration 1: Decrypt using key retriever that holds the keys of
+ * two encrypted columns and the footer key.
+ * - Decryption configuration 2: Decrypt using key retriever that holds the keys of
+ * two encrypted columns and the footer key. Supplies
+ * aad_prefix to verify file identity.
+ * - Decryption configuration 3: Decrypt using explicit column and footer keys
+ * (instead of key retrieval callback).
+ */
+
+constexpr int NUM_ROWS_PER_ROW_GROUP = 500;
+
+const char* kFooterEncryptionKey = "0123456789012345"; // 128bit/16
+const char* kColumnEncryptionKey1 = "1234567890123450";
+const char* kColumnEncryptionKey2 = "1234567890123451";
+const char* fileName = "tester";
+
+using FileClass = ::arrow::io::FileOutputStream;
+using parquet::ConvertedType;
+using parquet::Repetition;
+using parquet::Type;
+using parquet::schema::GroupNode;
+using parquet::schema::PrimitiveNode;
+
+void PrintDecryptionConfiguration(int configuration);
+// Check that the decryption result is as expected.
+void CheckResult(std::string file, int example_id, std::string exception_msg);
+// Returns true if FileName ends with suffix. Otherwise returns false.
+// Used to skip unencrypted parquet files.
+bool FileNameEndsWith(std::string file_name, std::string suffix);
+
+std::vector<std::string> GetDirectoryFiles(const std::string& path) {
+ std::vector<std::string> files;
+ struct dirent* entry;
+ DIR* dir = opendir(path.c_str());
+
+ if (dir == NULL) {
+ exit(-1);
+ }
+ while ((entry = readdir(dir)) != NULL) {
+ files.push_back(std::string(entry->d_name));
+ }
+ closedir(dir);
+ return files;
+}
+
+static std::shared_ptr<GroupNode> SetupSchema() {
+ parquet::schema::NodeVector fields;
+ // Create a primitive node named 'boolean_field' with type:BOOLEAN,
+ // repetition:REQUIRED
+ fields.push_back(PrimitiveNode::Make("boolean_field", Repetition::REQUIRED,
+ Type::BOOLEAN, ConvertedType::NONE));
+
+ // Create a primitive node named 'int32_field' with type:INT32, repetition:REQUIRED,
+ // logical type:TIME_MILLIS
+ fields.push_back(PrimitiveNode::Make("int32_field", Repetition::REQUIRED, Type::INT32,
+ ConvertedType::TIME_MILLIS));
+
+ fields.push_back(PrimitiveNode::Make("float_field", Repetition::REQUIRED, Type::FLOAT,
+ ConvertedType::NONE));
+
+ fields.push_back(PrimitiveNode::Make("double_field", Repetition::REQUIRED, Type::DOUBLE,
+ ConvertedType::NONE));
+
+ // Create a GroupNode named 'schema' using the primitive nodes defined above
+ // This GroupNode is the root node of the schema tree
+ return std::static_pointer_cast<GroupNode>(
+ GroupNode::Make("schema", Repetition::REQUIRED, fields));
+}
+
+void InteropTestWriteEncryptedParquetFiles(std::string root_path) {
+ /**********************************************************************************
+ Creating a number of Encryption configurations
+ **********************************************************************************/
+
+ // This vector will hold various encryption configurations.
+ std::vector<std::shared_ptr<parquet::FileEncryptionProperties>>
+ vector_of_encryption_configurations;
+
+ // Encryption configuration 1: Encrypt all columns and the footer with the same key.
+ // (uniform encryption)
+ parquet::FileEncryptionProperties::Builder file_encryption_builder_1(
+ kFooterEncryptionKey);
+ // Add to list of encryption configurations.
+ vector_of_encryption_configurations.push_back(
+ file_encryption_builder_1.footer_key_metadata("kf")->build());
+
+ // Encryption configuration 2: Encrypt two columns and the footer, with different keys.
+ std::map<std::string, std::shared_ptr<parquet::ColumnEncryptionProperties>>
+ encryption_cols2;
+ std::string path1 = "double_field";
+ std::string path2 = "float_field";
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_20(path1);
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_21(path2);
+ encryption_col_builder_20.key(kColumnEncryptionKey1)->key_id("kc1");
+ encryption_col_builder_21.key(kColumnEncryptionKey2)->key_id("kc2");
+
+ encryption_cols2[path1] = encryption_col_builder_20.build();
+ encryption_cols2[path2] = encryption_col_builder_21.build();
+
+ parquet::FileEncryptionProperties::Builder file_encryption_builder_2(
+ kFooterEncryptionKey);
+
+ vector_of_encryption_configurations.push_back(
+ file_encryption_builder_2.footer_key_metadata("kf")
+ ->encrypted_columns(encryption_cols2)
+ ->build());
+
+ // Encryption configuration 3: Encrypt two columns, with different keys.
+ // Don’t encrypt footer.
+ // (plaintext footer mode, readable by legacy readers)
+ std::map<std::string, std::shared_ptr<parquet::ColumnEncryptionProperties>>
+ encryption_cols3;
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_30(path1);
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_31(path2);
+ encryption_col_builder_30.key(kColumnEncryptionKey1)->key_id("kc1");
+ encryption_col_builder_31.key(kColumnEncryptionKey2)->key_id("kc2");
+
+ encryption_cols3[path1] = encryption_col_builder_30.build();
+ encryption_cols3[path2] = encryption_col_builder_31.build();
+ parquet::FileEncryptionProperties::Builder file_encryption_builder_3(
+ kFooterEncryptionKey);
+
+ vector_of_encryption_configurations.push_back(
+ file_encryption_builder_3.footer_key_metadata("kf")
+ ->encrypted_columns(encryption_cols3)
+ ->set_plaintext_footer()
+ ->build());
+
+ // Encryption configuration 4: Encrypt two columns and the footer, with different keys.
+ // Use aad_prefix.
+ std::map<std::string, std::shared_ptr<parquet::ColumnEncryptionProperties>>
+ encryption_cols4;
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_40(path1);
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_41(path2);
+ encryption_col_builder_40.key(kColumnEncryptionKey1)->key_id("kc1");
+ encryption_col_builder_41.key(kColumnEncryptionKey2)->key_id("kc2");
+
+ encryption_cols4[path1] = encryption_col_builder_40.build();
+ encryption_cols4[path2] = encryption_col_builder_41.build();
+ parquet::FileEncryptionProperties::Builder file_encryption_builder_4(
+ kFooterEncryptionKey);
+
+ vector_of_encryption_configurations.push_back(
+ file_encryption_builder_4.footer_key_metadata("kf")
+ ->encrypted_columns(encryption_cols4)
+ ->aad_prefix(fileName)
+ ->build());
+
+ // Encryption configuration 5: Encrypt two columns and the footer, with different keys.
+ // Use aad_prefix and disable_aad_prefix_storage.
+ std::map<std::string, std::shared_ptr<parquet::ColumnEncryptionProperties>>
+ encryption_cols5;
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_50(path1);
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_51(path2);
+ encryption_col_builder_50.key(kColumnEncryptionKey1)->key_id("kc1");
+ encryption_col_builder_51.key(kColumnEncryptionKey2)->key_id("kc2");
+
+ encryption_cols5[path1] = encryption_col_builder_50.build();
+ encryption_cols5[path2] = encryption_col_builder_51.build();
+ parquet::FileEncryptionProperties::Builder file_encryption_builder_5(
+ kFooterEncryptionKey);
+
+ vector_of_encryption_configurations.push_back(
+ file_encryption_builder_5.encrypted_columns(encryption_cols5)
+ ->footer_key_metadata("kf")
+ ->aad_prefix(fileName)
+ ->disable_aad_prefix_storage()
+ ->build());
+
+ // Encryption configuration 6: Encrypt two columns and the footer, with different keys.
+ // Use AES_GCM_CTR_V1 algorithm.
+ std::map<std::string, std::shared_ptr<parquet::ColumnEncryptionProperties>>
+ encryption_cols6;
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_60(path1);
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_61(path2);
+ encryption_col_builder_60.key(kColumnEncryptionKey1)->key_id("kc1");
+ encryption_col_builder_61.key(kColumnEncryptionKey2)->key_id("kc2");
+
+ encryption_cols6[path1] = encryption_col_builder_60.build();
+ encryption_cols6[path2] = encryption_col_builder_61.build();
+ parquet::FileEncryptionProperties::Builder file_encryption_builder_6(
+ kFooterEncryptionKey);
+
+ vector_of_encryption_configurations.push_back(
+ file_encryption_builder_6.footer_key_metadata("kf")
+ ->encrypted_columns(encryption_cols6)
+ ->algorithm(parquet::ParquetCipher::AES_GCM_CTR_V1)
+ ->build());
+
+ /**********************************************************************************
+ PARQUET WRITER EXAMPLE
+ **********************************************************************************/
+
+ // Iterate over the encryption configurations and for each one write a parquet file.
+ for (unsigned example_id = 0; example_id < vector_of_encryption_configurations.size();
+ ++example_id) {
+ std::stringstream ss;
+ ss << example_id + 1;
+ std::string test_number_string = ss.str();
+ try {
+ // Create a local file output stream instance.
+ std::shared_ptr<FileClass> out_file;
+ std::string file =
+ root_path + fileName + std::string(test_number_string) + ".parquet.encrypted";
+ std::cout << "Write " << file << std::endl;
+ PARQUET_ASSIGN_OR_THROW(out_file, FileClass::Open(file));
+
+ // Setup the parquet schema
+ std::shared_ptr<GroupNode> schema = SetupSchema();
+
+ // Add writer properties
+ parquet::WriterProperties::Builder builder;
+ builder.compression(parquet::Compression::SNAPPY);
+
+ // Add the current encryption configuration to WriterProperties.
+ builder.encryption(vector_of_encryption_configurations[example_id]);
+
+ std::shared_ptr<parquet::WriterProperties> props = builder.build();
+
+ // Create a ParquetFileWriter instance
+ std::shared_ptr<parquet::ParquetFileWriter> file_writer =
+ parquet::ParquetFileWriter::Open(out_file, schema, props);
+
+ // Append a RowGroup with a specific number of rows.
+ parquet::RowGroupWriter* rg_writer = file_writer->AppendRowGroup();
+
+ // Write the Bool column
+ parquet::BoolWriter* bool_writer =
+ static_cast<parquet::BoolWriter*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ bool value = ((i % 2) == 0) ? true : false;
+ bool_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Write the Int32 column
+ parquet::Int32Writer* int32_writer =
+ static_cast<parquet::Int32Writer*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ int32_t value = i;
+ int32_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Write the Float column
+ parquet::FloatWriter* float_writer =
+ static_cast<parquet::FloatWriter*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ float value = static_cast<float>(i) * 1.1f;
+ float_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Write the Double column
+ parquet::DoubleWriter* double_writer =
+ static_cast<parquet::DoubleWriter*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ double value = i * 1.1111111;
+ double_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+ // Close the ParquetFileWriter
+ file_writer->Close();
+
+ // Write the bytes to file
+ DCHECK(out_file->Close().ok());
+ } catch (const std::exception& e) {
+ std::cerr << "Parquet write error: " << e.what() << std::endl;
+ return;
+ }
+ }
+}
+
+void InteropTestReadEncryptedParquetFiles(std::string root_path) {
+ std::vector<std::string> files_in_directory = GetDirectoryFiles(root_path);
+
+ /**********************************************************************************
+ Creating a number of Decryption configurations
+ **********************************************************************************/
+
+ // This vector will hold various decryption configurations.
+ std::vector<std::shared_ptr<parquet::FileDecryptionProperties>>
+ vector_of_decryption_configurations;
+
+ // Decryption configuration 1: Decrypt using key retriever callback that holds the keys
+ // of two encrypted columns and the footer key.
+ std::shared_ptr<parquet::StringKeyIdRetriever> string_kr1 =
+ std::make_shared<parquet::StringKeyIdRetriever>();
+ string_kr1->PutKey("kf", kFooterEncryptionKey);
+ string_kr1->PutKey("kc1", kColumnEncryptionKey1);
+ string_kr1->PutKey("kc2", kColumnEncryptionKey2);
+ std::shared_ptr<parquet::DecryptionKeyRetriever> kr1 =
+ std::static_pointer_cast<parquet::StringKeyIdRetriever>(string_kr1);
+
+ parquet::FileDecryptionProperties::Builder file_decryption_builder_1;
+ vector_of_decryption_configurations.push_back(
+ file_decryption_builder_1.key_retriever(kr1)->build());
+
+ // Decryption configuration 2: Decrypt using key retriever callback that holds the keys
+ // of two encrypted columns and the footer key. Supply aad_prefix.
+ std::shared_ptr<parquet::StringKeyIdRetriever> string_kr2 =
+ std::make_shared<parquet::StringKeyIdRetriever>();
+ string_kr2->PutKey("kf", kFooterEncryptionKey);
+ string_kr2->PutKey("kc1", kColumnEncryptionKey1);
+ string_kr2->PutKey("kc2", kColumnEncryptionKey2);
+ std::shared_ptr<parquet::DecryptionKeyRetriever> kr2 =
+ std::static_pointer_cast<parquet::StringKeyIdRetriever>(string_kr2);
+
+ parquet::FileDecryptionProperties::Builder file_decryption_builder_2;
+ vector_of_decryption_configurations.push_back(
+ file_decryption_builder_2.key_retriever(kr2)->aad_prefix(fileName)->build());
+
+ // Decryption configuration 3: Decrypt using explicit column and footer keys.
+ std::string path_double = "double_field";
+ std::string path_float = "float_field";
+ std::map<std::string, std::shared_ptr<parquet::ColumnDecryptionProperties>>
+ decryption_cols;
+ parquet::ColumnDecryptionProperties::Builder decryption_col_builder31(path_double);
+ parquet::ColumnDecryptionProperties::Builder decryption_col_builder32(path_float);
+
+ decryption_cols[path_double] =
+ decryption_col_builder31.key(kColumnEncryptionKey1)->build();
+ decryption_cols[path_float] =
+ decryption_col_builder32.key(kColumnEncryptionKey2)->build();
+
+ parquet::FileDecryptionProperties::Builder file_decryption_builder_3;
+ vector_of_decryption_configurations.push_back(
+ file_decryption_builder_3.footer_key(kFooterEncryptionKey)
+ ->column_keys(decryption_cols)
+ ->build());
+
+ /**********************************************************************************
+ PARQUET READER EXAMPLE
+ **********************************************************************************/
+
+ // Iterate over the decryption configurations and use each one to read every files
+ // in the input directory.
+ for (unsigned example_id = 0; example_id < vector_of_decryption_configurations.size();
+ ++example_id) {
+ PrintDecryptionConfiguration(example_id + 1);
+ for (auto const& file : files_in_directory) {
+ std::string exception_msg = "";
+ if (!FileNameEndsWith(file, "parquet.encrypted")) // Skip non encrypted files
+ continue;
+ try {
+ std::cout << "--> Read file " << file << std::endl;
+
+ parquet::ReaderProperties reader_properties =
+ parquet::default_reader_properties();
+
+ // Add the current decryption configuration to ReaderProperties.
+ reader_properties.file_decryption_properties(
+ vector_of_decryption_configurations[example_id]->DeepClone());
+
+ // Create a ParquetReader instance
+ std::unique_ptr<parquet::ParquetFileReader> parquet_reader =
+ parquet::ParquetFileReader::OpenFile(root_path + file, false,
+ reader_properties);
+
+ // Get the File MetaData
+ std::shared_ptr<parquet::FileMetaData> file_metadata = parquet_reader->metadata();
+
+ // Get the number of RowGroups
+ int num_row_groups = file_metadata->num_row_groups();
+ assert(num_row_groups == 1);
+
+ // Get the number of Columns
+ int num_columns = file_metadata->num_columns();
+ assert(num_columns == 4);
+
+ // Iterate over all the RowGroups in the file
+ for (int r = 0; r < num_row_groups; ++r) {
+ // Get the RowGroup Reader
+ std::shared_ptr<parquet::RowGroupReader> row_group_reader =
+ parquet_reader->RowGroup(r);
+
+ int64_t values_read = 0;
+ int64_t rows_read = 0;
+ int i;
+ std::shared_ptr<parquet::ColumnReader> column_reader;
+
+ // Get the Column Reader for the boolean column
+ column_reader = row_group_reader->Column(0);
+ parquet::BoolReader* bool_reader =
+ static_cast<parquet::BoolReader*>(column_reader.get());
+
+ // Read all the rows in the column
+ i = 0;
+ while (bool_reader->HasNext()) {
+ bool value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = bool_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ bool expected_value = ((i % 2) == 0) ? true : false;
+ assert(value == expected_value);
+ i++;
+ }
+ ARROW_UNUSED(rows_read); // suppress compiler warning in release builds
+
+ // Get the Column Reader for the Int32 column
+ column_reader = row_group_reader->Column(1);
+ parquet::Int32Reader* int32_reader =
+ static_cast<parquet::Int32Reader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (int32_reader->HasNext()) {
+ int32_t value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read =
+ int32_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ assert(value == i);
+ i++;
+ }
+
+ // Get the Column Reader for the Float column
+ column_reader = row_group_reader->Column(2);
+ parquet::FloatReader* float_reader =
+ static_cast<parquet::FloatReader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (float_reader->HasNext()) {
+ float value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read =
+ float_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ float expected_value = static_cast<float>(i) * 1.1f;
+ assert(value == expected_value);
+ i++;
+ }
+
+ // Get the Column Reader for the Double column
+ column_reader = row_group_reader->Column(3);
+ parquet::DoubleReader* double_reader =
+ static_cast<parquet::DoubleReader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (double_reader->HasNext()) {
+ double value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read =
+ double_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ double expected_value = i * 1.1111111;
+ assert(value == expected_value);
+ i++;
+ }
+ }
+ } catch (const std::exception& e) {
+ exception_msg = e.what();
+ }
+ CheckResult(file, example_id, exception_msg);
+ std::cout << "file [" << file << "] Parquet Reading Complete" << std::endl;
+ }
+ }
+}
+
+void PrintDecryptionConfiguration(int configuration) {
+ std::cout << "\n\nDecryption configuration ";
+ if (configuration == 1) {
+ std::cout << "1: \n\nDecrypt using key retriever that holds"
+ " the keys of two encrypted columns and the footer key."
+ << std::endl;
+ } else if (configuration == 2) {
+ std::cout << "2: \n\nDecrypt using key retriever that holds"
+ " the keys of two encrypted columns and the footer key. Pass aad_prefix."
+ << std::endl;
+ } else if (configuration == 3) {
+ std::cout << "3: \n\nDecrypt using explicit column and footer keys." << std::endl;
+ } else {
+ std::cout << "Unknown configuration" << std::endl;
+ exit(-1);
+ }
+ std::cout << std::endl;
+}
+
+// Check that the decryption result is as expected.
+void CheckResult(std::string file, int example_id, std::string exception_msg) {
+ int encryption_configuration_number;
+ std::regex r("tester([0-9]+)\\.parquet.encrypted");
+ std::smatch m;
+ std::regex_search(file, m, r);
+ if (m.size() == 0) {
+ std::cerr
+ << "Error: Error parsing filename to extract encryption configuration number. "
+ << std::endl;
+ }
+ std::string encryption_configuration_number_str = m.str(1);
+ encryption_configuration_number = atoi(encryption_configuration_number_str.c_str());
+ if (encryption_configuration_number < 1 || encryption_configuration_number > 6) {
+ std::cerr << "Error: Unknown encryption configuration number. " << std::endl;
+ }
+
+ int decryption_configuration_number = example_id + 1;
+
+ // Encryption_configuration number five contains aad_prefix and
+ // disable_aad_prefix_storage.
+ // An exception is expected to be thrown if the file is not decrypted with aad_prefix.
+ if (encryption_configuration_number == 5) {
+ if (decryption_configuration_number == 1 || decryption_configuration_number == 3) {
+ std::size_t found = exception_msg.find("AAD");
+ if (found == std::string::npos)
+ std::cout << "Error: Expecting AAD related exception.";
+ return;
+ }
+ }
+ // Decryption configuration number two contains aad_prefix. An exception is expected to
+ // be thrown if the file was not encrypted with the same aad_prefix.
+ if (decryption_configuration_number == 2) {
+ if (encryption_configuration_number != 5 && encryption_configuration_number != 4) {
+ std::size_t found = exception_msg.find("AAD");
+ if (found == std::string::npos) {
+ std::cout << "Error: Expecting AAD related exception." << std::endl;
+ }
+ return;
+ }
+ }
+ if (!exception_msg.empty())
+ std::cout << "Error: Unexpected exception was thrown." << exception_msg;
+}
+
+bool FileNameEndsWith(std::string file_name, std::string suffix) {
+ std::string::size_type idx = file_name.find_first_of('.');
+
+ if (idx != std::string::npos) {
+ std::string extension = file_name.substr(idx + 1);
+ if (extension.compare(suffix) == 0) return true;
+ }
+ return false;
+}
+
+int main(int argc, char** argv) {
+ enum Operation { write, read };
+ std::string root_path;
+ Operation operation = write;
+ if (argc < 3) {
+ std::cout << "Usage: encryption-reader-writer-all-crypto-options <read/write> "
+ "<Path-to-parquet-files>"
+ << std::endl;
+ exit(1);
+ }
+ root_path = argv[1];
+ if (root_path.compare("read") == 0) {
+ operation = read;
+ }
+
+ root_path = argv[2];
+ std::cout << "Root path is: " << root_path << std::endl;
+
+ if (operation == write) {
+ InteropTestWriteEncryptedParquetFiles(root_path);
+ } else {
+ InteropTestReadEncryptedParquetFiles(root_path);
+ }
+ return 0;
+}
diff --git a/src/arrow/cpp/examples/parquet/low_level_api/reader_writer.cc b/src/arrow/cpp/examples/parquet/low_level_api/reader_writer.cc
new file mode 100644
index 000000000..09af32289
--- /dev/null
+++ b/src/arrow/cpp/examples/parquet/low_level_api/reader_writer.cc
@@ -0,0 +1,413 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <reader_writer.h>
+
+#include <cassert>
+#include <fstream>
+#include <iostream>
+#include <memory>
+
+/*
+ * This example describes writing and reading Parquet Files in C++ and serves as a
+ * reference to the API.
+ * The file contains all the physical data types supported by Parquet.
+ * This example uses the RowGroupWriter API that supports writing RowGroups optimized for
+ * memory consumption.
+ **/
+
+/* Parquet is a structured columnar file format
+ * Parquet File = "Parquet data" + "Parquet Metadata"
+ * "Parquet data" is simply a vector of RowGroups. Each RowGroup is a batch of rows in a
+ * columnar layout
+ * "Parquet Metadata" contains the "file schema" and attributes of the RowGroups and their
+ * Columns
+ * "file schema" is a tree where each node is either a primitive type (leaf nodes) or a
+ * complex (nested) type (internal nodes)
+ * For specific details, please refer the format here:
+ * https://github.com/apache/parquet-format/blob/master/LogicalTypes.md
+ **/
+
+constexpr int NUM_ROWS_PER_ROW_GROUP = 500;
+const char PARQUET_FILENAME[] = "parquet_cpp_example.parquet";
+
+int main(int argc, char** argv) {
+ /**********************************************************************************
+ PARQUET WRITER EXAMPLE
+ **********************************************************************************/
+ // parquet::REQUIRED fields do not need definition and repetition level values
+ // parquet::OPTIONAL fields require only definition level values
+ // parquet::REPEATED fields require both definition and repetition level values
+ try {
+ // Create a local file output stream instance.
+ using FileClass = ::arrow::io::FileOutputStream;
+ std::shared_ptr<FileClass> out_file;
+ PARQUET_ASSIGN_OR_THROW(out_file, FileClass::Open(PARQUET_FILENAME));
+
+ // Setup the parquet schema
+ std::shared_ptr<GroupNode> schema = SetupSchema();
+
+ // Add writer properties
+ parquet::WriterProperties::Builder builder;
+ builder.compression(parquet::Compression::SNAPPY);
+ std::shared_ptr<parquet::WriterProperties> props = builder.build();
+
+ // Create a ParquetFileWriter instance
+ std::shared_ptr<parquet::ParquetFileWriter> file_writer =
+ parquet::ParquetFileWriter::Open(out_file, schema, props);
+
+ // Append a RowGroup with a specific number of rows.
+ parquet::RowGroupWriter* rg_writer = file_writer->AppendRowGroup();
+
+ // Write the Bool column
+ parquet::BoolWriter* bool_writer =
+ static_cast<parquet::BoolWriter*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ bool value = ((i % 2) == 0) ? true : false;
+ bool_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Write the Int32 column
+ parquet::Int32Writer* int32_writer =
+ static_cast<parquet::Int32Writer*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ int32_t value = i;
+ int32_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Write the Int64 column. Each row has repeats twice.
+ parquet::Int64Writer* int64_writer =
+ static_cast<parquet::Int64Writer*>(rg_writer->NextColumn());
+ for (int i = 0; i < 2 * NUM_ROWS_PER_ROW_GROUP; i++) {
+ int64_t value = i * 1000 * 1000;
+ value *= 1000 * 1000;
+ int16_t definition_level = 1;
+ int16_t repetition_level = 0;
+ if ((i % 2) == 0) {
+ repetition_level = 1; // start of a new record
+ }
+ int64_writer->WriteBatch(1, &definition_level, &repetition_level, &value);
+ }
+
+ // Write the INT96 column.
+ parquet::Int96Writer* int96_writer =
+ static_cast<parquet::Int96Writer*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ parquet::Int96 value;
+ value.value[0] = i;
+ value.value[1] = i + 1;
+ value.value[2] = i + 2;
+ int96_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Write the Float column
+ parquet::FloatWriter* float_writer =
+ static_cast<parquet::FloatWriter*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ float value = static_cast<float>(i) * 1.1f;
+ float_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Write the Double column
+ parquet::DoubleWriter* double_writer =
+ static_cast<parquet::DoubleWriter*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ double value = i * 1.1111111;
+ double_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Write the ByteArray column. Make every alternate values NULL
+ parquet::ByteArrayWriter* ba_writer =
+ static_cast<parquet::ByteArrayWriter*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ parquet::ByteArray value;
+ char hello[FIXED_LENGTH] = "parquet";
+ hello[7] = static_cast<char>(static_cast<int>('0') + i / 100);
+ hello[8] = static_cast<char>(static_cast<int>('0') + (i / 10) % 10);
+ hello[9] = static_cast<char>(static_cast<int>('0') + i % 10);
+ if (i % 2 == 0) {
+ int16_t definition_level = 1;
+ value.ptr = reinterpret_cast<const uint8_t*>(&hello[0]);
+ value.len = FIXED_LENGTH;
+ ba_writer->WriteBatch(1, &definition_level, nullptr, &value);
+ } else {
+ int16_t definition_level = 0;
+ ba_writer->WriteBatch(1, &definition_level, nullptr, nullptr);
+ }
+ }
+
+ // Write the FixedLengthByteArray column
+ parquet::FixedLenByteArrayWriter* flba_writer =
+ static_cast<parquet::FixedLenByteArrayWriter*>(rg_writer->NextColumn());
+ for (int i = 0; i < NUM_ROWS_PER_ROW_GROUP; i++) {
+ parquet::FixedLenByteArray value;
+ char v = static_cast<char>(i);
+ char flba[FIXED_LENGTH] = {v, v, v, v, v, v, v, v, v, v};
+ value.ptr = reinterpret_cast<const uint8_t*>(&flba[0]);
+
+ flba_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Close the ParquetFileWriter
+ file_writer->Close();
+
+ // Write the bytes to file
+ DCHECK(out_file->Close().ok());
+ } catch (const std::exception& e) {
+ std::cerr << "Parquet write error: " << e.what() << std::endl;
+ return -1;
+ }
+
+ /**********************************************************************************
+ PARQUET READER EXAMPLE
+ **********************************************************************************/
+
+ try {
+ // Create a ParquetReader instance
+ std::unique_ptr<parquet::ParquetFileReader> parquet_reader =
+ parquet::ParquetFileReader::OpenFile(PARQUET_FILENAME, false);
+
+ // Get the File MetaData
+ std::shared_ptr<parquet::FileMetaData> file_metadata = parquet_reader->metadata();
+
+ // Get the number of RowGroups
+ int num_row_groups = file_metadata->num_row_groups();
+ assert(num_row_groups == 1);
+
+ // Get the number of Columns
+ int num_columns = file_metadata->num_columns();
+ assert(num_columns == 8);
+
+ // Iterate over all the RowGroups in the file
+ for (int r = 0; r < num_row_groups; ++r) {
+ // Get the RowGroup Reader
+ std::shared_ptr<parquet::RowGroupReader> row_group_reader =
+ parquet_reader->RowGroup(r);
+
+ int64_t values_read = 0;
+ int64_t rows_read = 0;
+ int16_t definition_level;
+ int16_t repetition_level;
+ int i;
+ std::shared_ptr<parquet::ColumnReader> column_reader;
+
+ ARROW_UNUSED(rows_read); // prevent warning in release build
+
+ // Get the Column Reader for the boolean column
+ column_reader = row_group_reader->Column(0);
+ parquet::BoolReader* bool_reader =
+ static_cast<parquet::BoolReader*>(column_reader.get());
+
+ // Read all the rows in the column
+ i = 0;
+ while (bool_reader->HasNext()) {
+ bool value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = bool_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ bool expected_value = ((i % 2) == 0) ? true : false;
+ assert(value == expected_value);
+ i++;
+ }
+
+ // Get the Column Reader for the Int32 column
+ column_reader = row_group_reader->Column(1);
+ parquet::Int32Reader* int32_reader =
+ static_cast<parquet::Int32Reader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (int32_reader->HasNext()) {
+ int32_t value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = int32_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ assert(value == i);
+ i++;
+ }
+
+ // Get the Column Reader for the Int64 column
+ column_reader = row_group_reader->Column(2);
+ parquet::Int64Reader* int64_reader =
+ static_cast<parquet::Int64Reader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (int64_reader->HasNext()) {
+ int64_t value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = int64_reader->ReadBatch(1, &definition_level, &repetition_level,
+ &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ int64_t expected_value = i * 1000 * 1000;
+ expected_value *= 1000 * 1000;
+ assert(value == expected_value);
+ if ((i % 2) == 0) {
+ assert(repetition_level == 1);
+ } else {
+ assert(repetition_level == 0);
+ }
+ i++;
+ }
+
+ // Get the Column Reader for the Int96 column
+ column_reader = row_group_reader->Column(3);
+ parquet::Int96Reader* int96_reader =
+ static_cast<parquet::Int96Reader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (int96_reader->HasNext()) {
+ parquet::Int96 value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = int96_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ parquet::Int96 expected_value;
+ ARROW_UNUSED(expected_value); // prevent warning in release build
+ expected_value.value[0] = i;
+ expected_value.value[1] = i + 1;
+ expected_value.value[2] = i + 2;
+ for (int j = 0; j < 3; j++) {
+ assert(value.value[j] == expected_value.value[j]);
+ }
+ i++;
+ }
+
+ // Get the Column Reader for the Float column
+ column_reader = row_group_reader->Column(4);
+ parquet::FloatReader* float_reader =
+ static_cast<parquet::FloatReader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (float_reader->HasNext()) {
+ float value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = float_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ float expected_value = static_cast<float>(i) * 1.1f;
+ assert(value == expected_value);
+ i++;
+ }
+
+ // Get the Column Reader for the Double column
+ column_reader = row_group_reader->Column(5);
+ parquet::DoubleReader* double_reader =
+ static_cast<parquet::DoubleReader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (double_reader->HasNext()) {
+ double value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = double_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ double expected_value = i * 1.1111111;
+ assert(value == expected_value);
+ i++;
+ }
+
+ // Get the Column Reader for the ByteArray column
+ column_reader = row_group_reader->Column(6);
+ parquet::ByteArrayReader* ba_reader =
+ static_cast<parquet::ByteArrayReader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (ba_reader->HasNext()) {
+ parquet::ByteArray value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read =
+ ba_reader->ReadBatch(1, &definition_level, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // Verify the value written
+ char expected_value[FIXED_LENGTH] = "parquet";
+ ARROW_UNUSED(expected_value); // prevent warning in release build
+ expected_value[7] = static_cast<char>('0' + i / 100);
+ expected_value[8] = static_cast<char>('0' + (i / 10) % 10);
+ expected_value[9] = static_cast<char>('0' + i % 10);
+ if (i % 2 == 0) { // only alternate values exist
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ assert(value.len == FIXED_LENGTH);
+ assert(memcmp(value.ptr, &expected_value[0], FIXED_LENGTH) == 0);
+ assert(definition_level == 1);
+ } else {
+ // There are NULL values in the rows written
+ assert(values_read == 0);
+ assert(definition_level == 0);
+ }
+ i++;
+ }
+
+ // Get the Column Reader for the FixedLengthByteArray column
+ column_reader = row_group_reader->Column(7);
+ parquet::FixedLenByteArrayReader* flba_reader =
+ static_cast<parquet::FixedLenByteArrayReader*>(column_reader.get());
+ // Read all the rows in the column
+ i = 0;
+ while (flba_reader->HasNext()) {
+ parquet::FixedLenByteArray value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = flba_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ char v = static_cast<char>(i);
+ char expected_value[FIXED_LENGTH] = {v, v, v, v, v, v, v, v, v, v};
+ assert(memcmp(value.ptr, &expected_value[0], FIXED_LENGTH) == 0);
+ i++;
+ }
+ }
+ } catch (const std::exception& e) {
+ std::cerr << "Parquet read error: " << e.what() << std::endl;
+ return -1;
+ }
+
+ std::cout << "Parquet Writing and Reading Complete" << std::endl;
+
+ return 0;
+}
diff --git a/src/arrow/cpp/examples/parquet/low_level_api/reader_writer.h b/src/arrow/cpp/examples/parquet/low_level_api/reader_writer.h
new file mode 100644
index 000000000..ed8e74653
--- /dev/null
+++ b/src/arrow/cpp/examples/parquet/low_level_api/reader_writer.h
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/io/file.h>
+#include <arrow/util/logging.h>
+#include <parquet/api/reader.h>
+#include <parquet/api/writer.h>
+
+using parquet::ConvertedType;
+using parquet::Repetition;
+using parquet::Type;
+using parquet::schema::GroupNode;
+using parquet::schema::PrimitiveNode;
+
+constexpr int FIXED_LENGTH = 10;
+
+static std::shared_ptr<GroupNode> SetupSchema() {
+ parquet::schema::NodeVector fields;
+ // Create a primitive node named 'boolean_field' with type:BOOLEAN,
+ // repetition:REQUIRED
+ fields.push_back(PrimitiveNode::Make("boolean_field", Repetition::REQUIRED,
+ Type::BOOLEAN, ConvertedType::NONE));
+
+ // Create a primitive node named 'int32_field' with type:INT32, repetition:REQUIRED,
+ // logical type:TIME_MILLIS
+ fields.push_back(PrimitiveNode::Make("int32_field", Repetition::REQUIRED, Type::INT32,
+ ConvertedType::TIME_MILLIS));
+
+ // Create a primitive node named 'int64_field' with type:INT64, repetition:REPEATED
+ fields.push_back(PrimitiveNode::Make("int64_field", Repetition::REPEATED, Type::INT64,
+ ConvertedType::NONE));
+
+ fields.push_back(PrimitiveNode::Make("int96_field", Repetition::REQUIRED, Type::INT96,
+ ConvertedType::NONE));
+
+ fields.push_back(PrimitiveNode::Make("float_field", Repetition::REQUIRED, Type::FLOAT,
+ ConvertedType::NONE));
+
+ fields.push_back(PrimitiveNode::Make("double_field", Repetition::REQUIRED, Type::DOUBLE,
+ ConvertedType::NONE));
+
+ // Create a primitive node named 'ba_field' with type:BYTE_ARRAY, repetition:OPTIONAL
+ fields.push_back(PrimitiveNode::Make("ba_field", Repetition::OPTIONAL, Type::BYTE_ARRAY,
+ ConvertedType::NONE));
+
+ // Create a primitive node named 'flba_field' with type:FIXED_LEN_BYTE_ARRAY,
+ // repetition:REQUIRED, field_length = FIXED_LENGTH
+ fields.push_back(PrimitiveNode::Make("flba_field", Repetition::REQUIRED,
+ Type::FIXED_LEN_BYTE_ARRAY, ConvertedType::NONE,
+ FIXED_LENGTH));
+
+ // Create a GroupNode named 'schema' using the primitive nodes defined above
+ // This GroupNode is the root node of the schema tree
+ return std::static_pointer_cast<GroupNode>(
+ GroupNode::Make("schema", Repetition::REQUIRED, fields));
+}
diff --git a/src/arrow/cpp/examples/parquet/low_level_api/reader_writer2.cc b/src/arrow/cpp/examples/parquet/low_level_api/reader_writer2.cc
new file mode 100644
index 000000000..65dd5799e
--- /dev/null
+++ b/src/arrow/cpp/examples/parquet/low_level_api/reader_writer2.cc
@@ -0,0 +1,434 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <reader_writer.h>
+
+#include <cassert>
+#include <fstream>
+#include <iostream>
+#include <memory>
+
+/*
+ * This example describes writing and reading Parquet Files in C++ and serves as a
+ * reference to the API.
+ * The file contains all the physical data types supported by Parquet.
+ * This example uses the RowGroupWriter API that supports writing RowGroups based on a
+ * certain size.
+ **/
+
+/* Parquet is a structured columnar file format
+ * Parquet File = "Parquet data" + "Parquet Metadata"
+ * "Parquet data" is simply a vector of RowGroups. Each RowGroup is a batch of rows in a
+ * columnar layout
+ * "Parquet Metadata" contains the "file schema" and attributes of the RowGroups and their
+ * Columns
+ * "file schema" is a tree where each node is either a primitive type (leaf nodes) or a
+ * complex (nested) type (internal nodes)
+ * For specific details, please refer the format here:
+ * https://github.com/apache/parquet-format/blob/master/LogicalTypes.md
+ **/
+
+constexpr int NUM_ROWS = 2500000;
+constexpr int64_t ROW_GROUP_SIZE = 16 * 1024 * 1024; // 16 MB
+const char PARQUET_FILENAME[] = "parquet_cpp_example2.parquet";
+
+int main(int argc, char** argv) {
+ /**********************************************************************************
+ PARQUET WRITER EXAMPLE
+ **********************************************************************************/
+ // parquet::REQUIRED fields do not need definition and repetition level values
+ // parquet::OPTIONAL fields require only definition level values
+ // parquet::REPEATED fields require both definition and repetition level values
+ try {
+ // Create a local file output stream instance.
+ using FileClass = ::arrow::io::FileOutputStream;
+ std::shared_ptr<FileClass> out_file;
+ PARQUET_ASSIGN_OR_THROW(out_file, FileClass::Open(PARQUET_FILENAME));
+
+ // Setup the parquet schema
+ std::shared_ptr<GroupNode> schema = SetupSchema();
+
+ // Add writer properties
+ parquet::WriterProperties::Builder builder;
+ builder.compression(parquet::Compression::SNAPPY);
+ std::shared_ptr<parquet::WriterProperties> props = builder.build();
+
+ // Create a ParquetFileWriter instance
+ std::shared_ptr<parquet::ParquetFileWriter> file_writer =
+ parquet::ParquetFileWriter::Open(out_file, schema, props);
+
+ // Append a BufferedRowGroup to keep the RowGroup open until a certain size
+ parquet::RowGroupWriter* rg_writer = file_writer->AppendBufferedRowGroup();
+
+ int num_columns = file_writer->num_columns();
+ std::vector<int64_t> buffered_values_estimate(num_columns, 0);
+ for (int i = 0; i < NUM_ROWS; i++) {
+ int64_t estimated_bytes = 0;
+ // Get the estimated size of the values that are not written to a page yet
+ for (int n = 0; n < num_columns; n++) {
+ estimated_bytes += buffered_values_estimate[n];
+ }
+
+ // We need to consider the compressed pages
+ // as well as the values that are not compressed yet
+ if ((rg_writer->total_bytes_written() + rg_writer->total_compressed_bytes() +
+ estimated_bytes) > ROW_GROUP_SIZE) {
+ rg_writer->Close();
+ std::fill(buffered_values_estimate.begin(), buffered_values_estimate.end(), 0);
+ rg_writer = file_writer->AppendBufferedRowGroup();
+ }
+
+ int col_id = 0;
+ // Write the Bool column
+ parquet::BoolWriter* bool_writer =
+ static_cast<parquet::BoolWriter*>(rg_writer->column(col_id));
+ bool bool_value = ((i % 2) == 0) ? true : false;
+ bool_writer->WriteBatch(1, nullptr, nullptr, &bool_value);
+ buffered_values_estimate[col_id] = bool_writer->EstimatedBufferedValueBytes();
+
+ // Write the Int32 column
+ col_id++;
+ parquet::Int32Writer* int32_writer =
+ static_cast<parquet::Int32Writer*>(rg_writer->column(col_id));
+ int32_t int32_value = i;
+ int32_writer->WriteBatch(1, nullptr, nullptr, &int32_value);
+ buffered_values_estimate[col_id] = int32_writer->EstimatedBufferedValueBytes();
+
+ // Write the Int64 column. Each row has repeats twice.
+ col_id++;
+ parquet::Int64Writer* int64_writer =
+ static_cast<parquet::Int64Writer*>(rg_writer->column(col_id));
+ int64_t int64_value1 = 2 * i;
+ int16_t definition_level = 1;
+ int16_t repetition_level = 0;
+ int64_writer->WriteBatch(1, &definition_level, &repetition_level, &int64_value1);
+ int64_t int64_value2 = (2 * i + 1);
+ repetition_level = 1; // start of a new record
+ int64_writer->WriteBatch(1, &definition_level, &repetition_level, &int64_value2);
+ buffered_values_estimate[col_id] = int64_writer->EstimatedBufferedValueBytes();
+
+ // Write the INT96 column.
+ col_id++;
+ parquet::Int96Writer* int96_writer =
+ static_cast<parquet::Int96Writer*>(rg_writer->column(col_id));
+ parquet::Int96 int96_value;
+ int96_value.value[0] = i;
+ int96_value.value[1] = i + 1;
+ int96_value.value[2] = i + 2;
+ int96_writer->WriteBatch(1, nullptr, nullptr, &int96_value);
+ buffered_values_estimate[col_id] = int96_writer->EstimatedBufferedValueBytes();
+
+ // Write the Float column
+ col_id++;
+ parquet::FloatWriter* float_writer =
+ static_cast<parquet::FloatWriter*>(rg_writer->column(col_id));
+ float float_value = static_cast<float>(i) * 1.1f;
+ float_writer->WriteBatch(1, nullptr, nullptr, &float_value);
+ buffered_values_estimate[col_id] = float_writer->EstimatedBufferedValueBytes();
+
+ // Write the Double column
+ col_id++;
+ parquet::DoubleWriter* double_writer =
+ static_cast<parquet::DoubleWriter*>(rg_writer->column(col_id));
+ double double_value = i * 1.1111111;
+ double_writer->WriteBatch(1, nullptr, nullptr, &double_value);
+ buffered_values_estimate[col_id] = double_writer->EstimatedBufferedValueBytes();
+
+ // Write the ByteArray column. Make every alternate values NULL
+ col_id++;
+ parquet::ByteArrayWriter* ba_writer =
+ static_cast<parquet::ByteArrayWriter*>(rg_writer->column(col_id));
+ parquet::ByteArray ba_value;
+ char hello[FIXED_LENGTH] = "parquet";
+ hello[7] = static_cast<char>(static_cast<int>('0') + i / 100);
+ hello[8] = static_cast<char>(static_cast<int>('0') + (i / 10) % 10);
+ hello[9] = static_cast<char>(static_cast<int>('0') + i % 10);
+ if (i % 2 == 0) {
+ int16_t definition_level = 1;
+ ba_value.ptr = reinterpret_cast<const uint8_t*>(&hello[0]);
+ ba_value.len = FIXED_LENGTH;
+ ba_writer->WriteBatch(1, &definition_level, nullptr, &ba_value);
+ } else {
+ int16_t definition_level = 0;
+ ba_writer->WriteBatch(1, &definition_level, nullptr, nullptr);
+ }
+ buffered_values_estimate[col_id] = ba_writer->EstimatedBufferedValueBytes();
+
+ // Write the FixedLengthByteArray column
+ col_id++;
+ parquet::FixedLenByteArrayWriter* flba_writer =
+ static_cast<parquet::FixedLenByteArrayWriter*>(rg_writer->column(col_id));
+ parquet::FixedLenByteArray flba_value;
+ char v = static_cast<char>(i);
+ char flba[FIXED_LENGTH] = {v, v, v, v, v, v, v, v, v, v};
+ flba_value.ptr = reinterpret_cast<const uint8_t*>(&flba[0]);
+
+ flba_writer->WriteBatch(1, nullptr, nullptr, &flba_value);
+ buffered_values_estimate[col_id] = flba_writer->EstimatedBufferedValueBytes();
+ }
+
+ // Close the RowGroupWriter
+ rg_writer->Close();
+ // Close the ParquetFileWriter
+ file_writer->Close();
+
+ // Write the bytes to file
+ DCHECK(out_file->Close().ok());
+ } catch (const std::exception& e) {
+ std::cerr << "Parquet write error: " << e.what() << std::endl;
+ return -1;
+ }
+
+ /**********************************************************************************
+ PARQUET READER EXAMPLE
+ **********************************************************************************/
+
+ try {
+ // Create a ParquetReader instance
+ std::unique_ptr<parquet::ParquetFileReader> parquet_reader =
+ parquet::ParquetFileReader::OpenFile(PARQUET_FILENAME, false);
+
+ // Get the File MetaData
+ std::shared_ptr<parquet::FileMetaData> file_metadata = parquet_reader->metadata();
+
+ int num_row_groups = file_metadata->num_row_groups();
+
+ // Get the number of Columns
+ int num_columns = file_metadata->num_columns();
+ assert(num_columns == 8);
+
+ std::vector<int> col_row_counts(num_columns, 0);
+
+ // Iterate over all the RowGroups in the file
+ for (int r = 0; r < num_row_groups; ++r) {
+ // Get the RowGroup Reader
+ std::shared_ptr<parquet::RowGroupReader> row_group_reader =
+ parquet_reader->RowGroup(r);
+
+ assert(row_group_reader->metadata()->total_byte_size() < ROW_GROUP_SIZE);
+
+ int64_t values_read = 0;
+ int64_t rows_read = 0;
+ int16_t definition_level;
+ int16_t repetition_level;
+ std::shared_ptr<parquet::ColumnReader> column_reader;
+ int col_id = 0;
+
+ ARROW_UNUSED(rows_read); // prevent warning in release build
+
+ // Get the Column Reader for the boolean column
+ column_reader = row_group_reader->Column(col_id);
+ parquet::BoolReader* bool_reader =
+ static_cast<parquet::BoolReader*>(column_reader.get());
+
+ // Read all the rows in the column
+ while (bool_reader->HasNext()) {
+ bool value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = bool_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ bool expected_value = ((col_row_counts[col_id] % 2) == 0) ? true : false;
+ assert(value == expected_value);
+ col_row_counts[col_id]++;
+ }
+
+ // Get the Column Reader for the Int32 column
+ col_id++;
+ column_reader = row_group_reader->Column(col_id);
+ parquet::Int32Reader* int32_reader =
+ static_cast<parquet::Int32Reader*>(column_reader.get());
+ // Read all the rows in the column
+ while (int32_reader->HasNext()) {
+ int32_t value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = int32_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ assert(value == col_row_counts[col_id]);
+ col_row_counts[col_id]++;
+ }
+
+ // Get the Column Reader for the Int64 column
+ col_id++;
+ column_reader = row_group_reader->Column(col_id);
+ parquet::Int64Reader* int64_reader =
+ static_cast<parquet::Int64Reader*>(column_reader.get());
+ // Read all the rows in the column
+ while (int64_reader->HasNext()) {
+ int64_t value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = int64_reader->ReadBatch(1, &definition_level, &repetition_level,
+ &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ int64_t expected_value = col_row_counts[col_id];
+ assert(value == expected_value);
+ if ((col_row_counts[col_id] % 2) == 0) {
+ assert(repetition_level == 0);
+ } else {
+ assert(repetition_level == 1);
+ }
+ col_row_counts[col_id]++;
+ }
+
+ // Get the Column Reader for the Int96 column
+ col_id++;
+ column_reader = row_group_reader->Column(col_id);
+ parquet::Int96Reader* int96_reader =
+ static_cast<parquet::Int96Reader*>(column_reader.get());
+ // Read all the rows in the column
+ while (int96_reader->HasNext()) {
+ parquet::Int96 value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = int96_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ parquet::Int96 expected_value;
+ ARROW_UNUSED(expected_value); // prevent warning in release build
+ expected_value.value[0] = col_row_counts[col_id];
+ expected_value.value[1] = col_row_counts[col_id] + 1;
+ expected_value.value[2] = col_row_counts[col_id] + 2;
+ for (int j = 0; j < 3; j++) {
+ assert(value.value[j] == expected_value.value[j]);
+ }
+ col_row_counts[col_id]++;
+ }
+
+ // Get the Column Reader for the Float column
+ col_id++;
+ column_reader = row_group_reader->Column(col_id);
+ parquet::FloatReader* float_reader =
+ static_cast<parquet::FloatReader*>(column_reader.get());
+ // Read all the rows in the column
+ while (float_reader->HasNext()) {
+ float value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = float_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ float expected_value = static_cast<float>(col_row_counts[col_id]) * 1.1f;
+ assert(value == expected_value);
+ col_row_counts[col_id]++;
+ }
+
+ // Get the Column Reader for the Double column
+ col_id++;
+ column_reader = row_group_reader->Column(col_id);
+ parquet::DoubleReader* double_reader =
+ static_cast<parquet::DoubleReader*>(column_reader.get());
+ // Read all the rows in the column
+ while (double_reader->HasNext()) {
+ double value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = double_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ double expected_value = col_row_counts[col_id] * 1.1111111;
+ assert(value == expected_value);
+ col_row_counts[col_id]++;
+ }
+
+ // Get the Column Reader for the ByteArray column
+ col_id++;
+ column_reader = row_group_reader->Column(col_id);
+ parquet::ByteArrayReader* ba_reader =
+ static_cast<parquet::ByteArrayReader*>(column_reader.get());
+ // Read all the rows in the column
+ while (ba_reader->HasNext()) {
+ parquet::ByteArray value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read =
+ ba_reader->ReadBatch(1, &definition_level, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // Verify the value written
+ char expected_value[FIXED_LENGTH] = "parquet";
+ ARROW_UNUSED(expected_value); // prevent warning in release build
+ expected_value[7] = static_cast<char>('0' + col_row_counts[col_id] / 100);
+ expected_value[8] = static_cast<char>('0' + (col_row_counts[col_id] / 10) % 10);
+ expected_value[9] = static_cast<char>('0' + col_row_counts[col_id] % 10);
+ if (col_row_counts[col_id] % 2 == 0) { // only alternate values exist
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ assert(value.len == FIXED_LENGTH);
+ assert(memcmp(value.ptr, &expected_value[0], FIXED_LENGTH) == 0);
+ assert(definition_level == 1);
+ } else {
+ // There are NULL values in the rows written
+ assert(values_read == 0);
+ assert(definition_level == 0);
+ }
+ col_row_counts[col_id]++;
+ }
+
+ // Get the Column Reader for the FixedLengthByteArray column
+ col_id++;
+ column_reader = row_group_reader->Column(col_id);
+ parquet::FixedLenByteArrayReader* flba_reader =
+ static_cast<parquet::FixedLenByteArrayReader*>(column_reader.get());
+ // Read all the rows in the column
+ while (flba_reader->HasNext()) {
+ parquet::FixedLenByteArray value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = flba_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ assert(rows_read == 1);
+ // There are no NULL values in the rows written
+ assert(values_read == 1);
+ // Verify the value written
+ char v = static_cast<char>(col_row_counts[col_id]);
+ char expected_value[FIXED_LENGTH] = {v, v, v, v, v, v, v, v, v, v};
+ assert(memcmp(value.ptr, &expected_value[0], FIXED_LENGTH) == 0);
+ col_row_counts[col_id]++;
+ }
+ }
+ } catch (const std::exception& e) {
+ std::cerr << "Parquet read error: " << e.what() << std::endl;
+ return -1;
+ }
+
+ std::cout << "Parquet Writing and Reading Complete" << std::endl;
+
+ return 0;
+}
diff --git a/src/arrow/cpp/examples/parquet/parquet_arrow/CMakeLists.txt b/src/arrow/cpp/examples/parquet/parquet_arrow/CMakeLists.txt
new file mode 100644
index 000000000..43eb21957
--- /dev/null
+++ b/src/arrow/cpp/examples/parquet/parquet_arrow/CMakeLists.txt
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Require cmake that supports BYPRODUCTS in add_custom_command, ExternalProject_Add [1].
+cmake_minimum_required(VERSION 3.2.0)
+
+project(parquet_arrow_example)
+
+include(ExternalProject)
+include(FindPkgConfig)
+include(GNUInstallDirs)
+
+set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake_modules")
+
+# This ensures that things like gnu++11 get passed correctly
+set(CMAKE_CXX_STANDARD 11)
+
+# We require a C++11 compliant compiler
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+# Look for installed packages the system
+find_package(Arrow REQUIRED)
+find_package(Parquet REQUIRED)
+
+include_directories(SYSTEM ${ARROW_INCLUDE_DIR} ${PARQUET_INCLUDE_DIR})
+
+add_executable(parquet_arrow_example reader_writer.cc)
+target_link_libraries(parquet_arrow_example ${PARQUET_SHARED_LIB} ${ARROW_SHARED_LIB})
diff --git a/src/arrow/cpp/examples/parquet/parquet_arrow/README.md b/src/arrow/cpp/examples/parquet/parquet_arrow/README.md
new file mode 100644
index 000000000..e99819fd2
--- /dev/null
+++ b/src/arrow/cpp/examples/parquet/parquet_arrow/README.md
@@ -0,0 +1,20 @@
+<!---
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. See accompanying LICENSE file.
+-->
+
+Using parquet-cpp with the arrow interface
+==========================================
+
+This folder contains an example project that shows how to setup a CMake project
+that consumes `parquet-cpp` as a library as well as how you can use the
+`parquet/arrow` interface to reading and write Apache Parquet files.
diff --git a/src/arrow/cpp/examples/parquet/parquet_arrow/reader_writer.cc b/src/arrow/cpp/examples/parquet/parquet_arrow/reader_writer.cc
new file mode 100644
index 000000000..f5d96ec16
--- /dev/null
+++ b/src/arrow/cpp/examples/parquet/parquet_arrow/reader_writer.cc
@@ -0,0 +1,140 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/api.h>
+#include <arrow/io/api.h>
+#include <parquet/arrow/reader.h>
+#include <parquet/arrow/writer.h>
+#include <parquet/exception.h>
+
+#include <iostream>
+
+// #0 Build dummy data to pass around
+// To have some input data, we first create an Arrow Table that holds
+// some data.
+std::shared_ptr<arrow::Table> generate_table() {
+ arrow::Int64Builder i64builder;
+ PARQUET_THROW_NOT_OK(i64builder.AppendValues({1, 2, 3, 4, 5}));
+ std::shared_ptr<arrow::Array> i64array;
+ PARQUET_THROW_NOT_OK(i64builder.Finish(&i64array));
+
+ arrow::StringBuilder strbuilder;
+ PARQUET_THROW_NOT_OK(strbuilder.Append("some"));
+ PARQUET_THROW_NOT_OK(strbuilder.Append("string"));
+ PARQUET_THROW_NOT_OK(strbuilder.Append("content"));
+ PARQUET_THROW_NOT_OK(strbuilder.Append("in"));
+ PARQUET_THROW_NOT_OK(strbuilder.Append("rows"));
+ std::shared_ptr<arrow::Array> strarray;
+ PARQUET_THROW_NOT_OK(strbuilder.Finish(&strarray));
+
+ std::shared_ptr<arrow::Schema> schema = arrow::schema(
+ {arrow::field("int", arrow::int64()), arrow::field("str", arrow::utf8())});
+
+ return arrow::Table::Make(schema, {i64array, strarray});
+}
+
+// #1 Write out the data as a Parquet file
+void write_parquet_file(const arrow::Table& table) {
+ std::shared_ptr<arrow::io::FileOutputStream> outfile;
+ PARQUET_ASSIGN_OR_THROW(
+ outfile, arrow::io::FileOutputStream::Open("parquet-arrow-example.parquet"));
+ // The last argument to the function call is the size of the RowGroup in
+ // the parquet file. Normally you would choose this to be rather large but
+ // for the example, we use a small value to have multiple RowGroups.
+ PARQUET_THROW_NOT_OK(
+ parquet::arrow::WriteTable(table, arrow::default_memory_pool(), outfile, 3));
+}
+
+// #2: Fully read in the file
+void read_whole_file() {
+ std::cout << "Reading parquet-arrow-example.parquet at once" << std::endl;
+ std::shared_ptr<arrow::io::ReadableFile> infile;
+ PARQUET_ASSIGN_OR_THROW(infile,
+ arrow::io::ReadableFile::Open("parquet-arrow-example.parquet",
+ arrow::default_memory_pool()));
+
+ std::unique_ptr<parquet::arrow::FileReader> reader;
+ PARQUET_THROW_NOT_OK(
+ parquet::arrow::OpenFile(infile, arrow::default_memory_pool(), &reader));
+ std::shared_ptr<arrow::Table> table;
+ PARQUET_THROW_NOT_OK(reader->ReadTable(&table));
+ std::cout << "Loaded " << table->num_rows() << " rows in " << table->num_columns()
+ << " columns." << std::endl;
+}
+
+// #3: Read only a single RowGroup of the parquet file
+void read_single_rowgroup() {
+ std::cout << "Reading first RowGroup of parquet-arrow-example.parquet" << std::endl;
+ std::shared_ptr<arrow::io::ReadableFile> infile;
+ PARQUET_ASSIGN_OR_THROW(infile,
+ arrow::io::ReadableFile::Open("parquet-arrow-example.parquet",
+ arrow::default_memory_pool()));
+
+ std::unique_ptr<parquet::arrow::FileReader> reader;
+ PARQUET_THROW_NOT_OK(
+ parquet::arrow::OpenFile(infile, arrow::default_memory_pool(), &reader));
+ std::shared_ptr<arrow::Table> table;
+ PARQUET_THROW_NOT_OK(reader->RowGroup(0)->ReadTable(&table));
+ std::cout << "Loaded " << table->num_rows() << " rows in " << table->num_columns()
+ << " columns." << std::endl;
+}
+
+// #4: Read only a single column of the whole parquet file
+void read_single_column() {
+ std::cout << "Reading first column of parquet-arrow-example.parquet" << std::endl;
+ std::shared_ptr<arrow::io::ReadableFile> infile;
+ PARQUET_ASSIGN_OR_THROW(infile,
+ arrow::io::ReadableFile::Open("parquet-arrow-example.parquet",
+ arrow::default_memory_pool()));
+
+ std::unique_ptr<parquet::arrow::FileReader> reader;
+ PARQUET_THROW_NOT_OK(
+ parquet::arrow::OpenFile(infile, arrow::default_memory_pool(), &reader));
+ std::shared_ptr<arrow::ChunkedArray> array;
+ PARQUET_THROW_NOT_OK(reader->ReadColumn(0, &array));
+ PARQUET_THROW_NOT_OK(arrow::PrettyPrint(*array, 4, &std::cout));
+ std::cout << std::endl;
+}
+
+// #5: Read only a single column of a RowGroup (this is known as ColumnChunk)
+// from the Parquet file.
+void read_single_column_chunk() {
+ std::cout << "Reading first ColumnChunk of the first RowGroup of "
+ "parquet-arrow-example.parquet"
+ << std::endl;
+ std::shared_ptr<arrow::io::ReadableFile> infile;
+ PARQUET_ASSIGN_OR_THROW(infile,
+ arrow::io::ReadableFile::Open("parquet-arrow-example.parquet",
+ arrow::default_memory_pool()));
+
+ std::unique_ptr<parquet::arrow::FileReader> reader;
+ PARQUET_THROW_NOT_OK(
+ parquet::arrow::OpenFile(infile, arrow::default_memory_pool(), &reader));
+ std::shared_ptr<arrow::ChunkedArray> array;
+ PARQUET_THROW_NOT_OK(reader->RowGroup(0)->Column(0)->Read(&array));
+ PARQUET_THROW_NOT_OK(arrow::PrettyPrint(*array, 4, &std::cout));
+ std::cout << std::endl;
+}
+
+int main(int argc, char** argv) {
+ std::shared_ptr<arrow::Table> table = generate_table();
+ write_parquet_file(*table);
+ read_whole_file();
+ read_single_rowgroup();
+ read_single_column();
+ read_single_column_chunk();
+}
diff --git a/src/arrow/cpp/examples/parquet/parquet_stream_api/stream_reader_writer.cc b/src/arrow/cpp/examples/parquet/parquet_stream_api/stream_reader_writer.cc
new file mode 100644
index 000000000..64ab7af49
--- /dev/null
+++ b/src/arrow/cpp/examples/parquet/parquet_stream_api/stream_reader_writer.cc
@@ -0,0 +1,324 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cassert>
+#include <chrono>
+#include <cstdint>
+#include <cstring>
+#include <ctime>
+#include <iomanip>
+#include <iostream>
+#include <utility>
+
+#include "arrow/io/file.h"
+#include "parquet/exception.h"
+#include "parquet/stream_reader.h"
+#include "parquet/stream_writer.h"
+
+// This file gives an example of how to use the parquet::StreamWriter
+// and parquet::StreamReader classes.
+// It shows writing/reading of the supported types as well as how a
+// user-defined type can be handled.
+
+template <typename T>
+using optional = parquet::StreamReader::optional<T>;
+
+// Example of a user-defined type to be written to/read from Parquet
+// using C++ input/output operators.
+class UserTimestamp {
+ public:
+ UserTimestamp() = default;
+
+ explicit UserTimestamp(const std::chrono::microseconds v) : ts_{v} {}
+
+ bool operator==(const UserTimestamp& x) const { return ts_ == x.ts_; }
+
+ void dump(std::ostream& os) const {
+ const auto t = static_cast<std::time_t>(
+ std::chrono::duration_cast<std::chrono::seconds>(ts_).count());
+ os << std::put_time(std::gmtime(&t), "%Y%m%d-%H%M%S");
+ }
+
+ void dump(parquet::StreamWriter& os) const { os << ts_; }
+
+ private:
+ std::chrono::microseconds ts_;
+};
+
+std::ostream& operator<<(std::ostream& os, const UserTimestamp& v) {
+ v.dump(os);
+ return os;
+}
+
+parquet::StreamWriter& operator<<(parquet::StreamWriter& os, const UserTimestamp& v) {
+ v.dump(os);
+ return os;
+}
+
+parquet::StreamReader& operator>>(parquet::StreamReader& os, UserTimestamp& v) {
+ std::chrono::microseconds ts;
+
+ os >> ts;
+ v = UserTimestamp{ts};
+
+ return os;
+}
+
+std::shared_ptr<parquet::schema::GroupNode> GetSchema() {
+ parquet::schema::NodeVector fields;
+
+ fields.push_back(parquet::schema::PrimitiveNode::Make(
+ "string_field", parquet::Repetition::OPTIONAL, parquet::Type::BYTE_ARRAY,
+ parquet::ConvertedType::UTF8));
+
+ fields.push_back(parquet::schema::PrimitiveNode::Make(
+ "char_field", parquet::Repetition::REQUIRED, parquet::Type::FIXED_LEN_BYTE_ARRAY,
+ parquet::ConvertedType::NONE, 1));
+
+ fields.push_back(parquet::schema::PrimitiveNode::Make(
+ "char[4]_field", parquet::Repetition::REQUIRED, parquet::Type::FIXED_LEN_BYTE_ARRAY,
+ parquet::ConvertedType::NONE, 4));
+
+ fields.push_back(parquet::schema::PrimitiveNode::Make(
+ "int8_field", parquet::Repetition::REQUIRED, parquet::Type::INT32,
+ parquet::ConvertedType::INT_8));
+
+ fields.push_back(parquet::schema::PrimitiveNode::Make(
+ "uint16_field", parquet::Repetition::REQUIRED, parquet::Type::INT32,
+ parquet::ConvertedType::UINT_16));
+
+ fields.push_back(parquet::schema::PrimitiveNode::Make(
+ "int32_field", parquet::Repetition::REQUIRED, parquet::Type::INT32,
+ parquet::ConvertedType::INT_32));
+
+ fields.push_back(parquet::schema::PrimitiveNode::Make(
+ "uint64_field", parquet::Repetition::OPTIONAL, parquet::Type::INT64,
+ parquet::ConvertedType::UINT_64));
+
+ fields.push_back(parquet::schema::PrimitiveNode::Make(
+ "double_field", parquet::Repetition::REQUIRED, parquet::Type::DOUBLE,
+ parquet::ConvertedType::NONE));
+
+ // User defined timestamp type.
+ fields.push_back(parquet::schema::PrimitiveNode::Make(
+ "timestamp_field", parquet::Repetition::REQUIRED, parquet::Type::INT64,
+ parquet::ConvertedType::TIMESTAMP_MICROS));
+
+ fields.push_back(parquet::schema::PrimitiveNode::Make(
+ "chrono_milliseconds_field", parquet::Repetition::REQUIRED, parquet::Type::INT64,
+ parquet::ConvertedType::TIMESTAMP_MILLIS));
+
+ return std::static_pointer_cast<parquet::schema::GroupNode>(
+ parquet::schema::GroupNode::Make("schema", parquet::Repetition::REQUIRED, fields));
+}
+
+struct TestData {
+ static const int num_rows = 2000;
+
+ static void init() { std::time(&ts_offset_); }
+
+ static optional<std::string> GetOptString(const int i) {
+ if (i % 2 == 0) return {};
+ return "Str #" + std::to_string(i);
+ }
+ static arrow::util::string_view GetStringView(const int i) {
+ static std::string string;
+ string = "StringView #" + std::to_string(i);
+ return arrow::util::string_view(string);
+ }
+ static const char* GetCharPtr(const int i) {
+ static std::string string;
+ string = "CharPtr #" + std::to_string(i);
+ return string.c_str();
+ }
+ static char GetChar(const int i) { return i & 1 ? 'M' : 'F'; }
+ static int8_t GetInt8(const int i) { return static_cast<int8_t>((i % 256) - 128); }
+ static uint16_t GetUInt16(const int i) { return static_cast<uint16_t>(i); }
+ static int32_t GetInt32(const int i) { return 3 * i - 17; }
+ static optional<uint64_t> GetOptUInt64(const int i) {
+ if (i % 11 == 0) return {};
+ return (1ull << 40) + i * i + 101;
+ }
+ static double GetDouble(const int i) { return 6.62607004e-34 * 3e8 * i; }
+ static UserTimestamp GetUserTimestamp(const int i) {
+ return UserTimestamp{std::chrono::microseconds{(ts_offset_ + 3 * i) * 1000000 + i}};
+ }
+ static std::chrono::milliseconds GetChronoMilliseconds(const int i) {
+ return std::chrono::milliseconds{(ts_offset_ + 3 * i) * 1000ull + i};
+ }
+
+ static char char4_array[4];
+
+ private:
+ static std::time_t ts_offset_;
+};
+
+char TestData::char4_array[] = "XYZ";
+std::time_t TestData::ts_offset_;
+
+void WriteParquetFile() {
+ std::shared_ptr<arrow::io::FileOutputStream> outfile;
+
+ PARQUET_ASSIGN_OR_THROW(
+ outfile, arrow::io::FileOutputStream::Open("parquet-stream-api-example.parquet"));
+
+ parquet::WriterProperties::Builder builder;
+
+#if defined ARROW_WITH_BROTLI
+ builder.compression(parquet::Compression::BROTLI);
+#elif defined ARROW_WITH_ZSTD
+ builder.compression(parquet::Compression::ZSTD);
+#endif
+
+ parquet::StreamWriter os{
+ parquet::ParquetFileWriter::Open(outfile, GetSchema(), builder.build())};
+
+ os.SetMaxRowGroupSize(1000);
+
+ for (auto i = 0; i < TestData::num_rows; ++i) {
+ // Output string using 3 different types: std::string, arrow::util::string_view and
+ // const char *.
+ switch (i % 3) {
+ case 0:
+ os << TestData::GetOptString(i);
+ break;
+ case 1:
+ os << TestData::GetStringView(i);
+ break;
+ case 2:
+ os << TestData::GetCharPtr(i);
+ break;
+ }
+ os << TestData::GetChar(i);
+ switch (i % 2) {
+ case 0:
+ os << TestData::char4_array;
+ break;
+ case 1:
+ os << parquet::StreamWriter::FixedStringView{TestData::GetCharPtr(i), 4};
+ break;
+ }
+ os << TestData::GetInt8(i);
+ os << TestData::GetUInt16(i);
+ os << TestData::GetInt32(i);
+ os << TestData::GetOptUInt64(i);
+ os << TestData::GetDouble(i);
+ os << TestData::GetUserTimestamp(i);
+ os << TestData::GetChronoMilliseconds(i);
+ os << parquet::EndRow;
+
+ if (i == TestData::num_rows / 2) {
+ os << parquet::EndRowGroup;
+ }
+ }
+ std::cout << "Parquet Stream Writing complete." << std::endl;
+}
+
+void ReadParquetFile() {
+ std::shared_ptr<arrow::io::ReadableFile> infile;
+
+ PARQUET_ASSIGN_OR_THROW(
+ infile, arrow::io::ReadableFile::Open("parquet-stream-api-example.parquet"));
+
+ parquet::StreamReader os{parquet::ParquetFileReader::Open(infile)};
+
+ optional<std::string> opt_string;
+ char ch;
+ char char_array[4];
+ int8_t int8;
+ uint16_t uint16;
+ int32_t int32;
+ optional<uint64_t> opt_uint64;
+ double d;
+ UserTimestamp ts_user;
+ std::chrono::milliseconds ts_ms;
+ int i;
+
+ for (i = 0; !os.eof(); ++i) {
+ os >> opt_string;
+ os >> ch;
+ os >> char_array;
+ os >> int8;
+ os >> uint16;
+ os >> int32;
+ os >> opt_uint64;
+ os >> d;
+ os >> ts_user;
+ os >> ts_ms;
+ os >> parquet::EndRow;
+
+ if (0) {
+ // For debugging.
+ std::cout << "Row #" << i << std::endl;
+
+ std::cout << "string[";
+ if (opt_string) {
+ std::cout << *opt_string;
+ } else {
+ std::cout << "N/A";
+ }
+ std::cout << "] char[" << ch << "] charArray[" << char_array << "] int8["
+ << int(int8) << "] uint16[" << uint16 << "] int32[" << int32;
+ std::cout << "] uint64[";
+ if (opt_uint64) {
+ std::cout << *opt_uint64;
+ } else {
+ std::cout << "N/A";
+ }
+ std::cout << "] double[" << d << "] tsUser[" << ts_user << "] tsMs["
+ << ts_ms.count() << "]" << std::endl;
+ }
+ // Check data.
+ switch (i % 3) {
+ case 0:
+ assert(opt_string == TestData::GetOptString(i));
+ break;
+ case 1:
+ assert(*opt_string == TestData::GetStringView(i));
+ break;
+ case 2:
+ assert(*opt_string == TestData::GetCharPtr(i));
+ break;
+ }
+ assert(ch == TestData::GetChar(i));
+ switch (i % 2) {
+ case 0:
+ assert(0 == std::memcmp(char_array, TestData::char4_array, sizeof(char_array)));
+ break;
+ case 1:
+ assert(0 == std::memcmp(char_array, TestData::GetCharPtr(i), sizeof(char_array)));
+ break;
+ }
+ assert(int8 == TestData::GetInt8(i));
+ assert(uint16 == TestData::GetUInt16(i));
+ assert(int32 == TestData::GetInt32(i));
+ assert(opt_uint64 == TestData::GetOptUInt64(i));
+ assert(std::abs(d - TestData::GetDouble(i)) < 1e-6);
+ assert(ts_user == TestData::GetUserTimestamp(i));
+ assert(ts_ms == TestData::GetChronoMilliseconds(i));
+ }
+ assert(TestData::num_rows == i);
+
+ std::cout << "Parquet Stream Reading complete." << std::endl;
+}
+
+int main() {
+ WriteParquetFile();
+ ReadParquetFile();
+
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/ArrowConfig.cmake.in b/src/arrow/cpp/src/arrow/ArrowConfig.cmake.in
new file mode 100644
index 000000000..6209baeec
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ArrowConfig.cmake.in
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This config sets the following variables in your project::
+#
+# ARROW_FULL_SO_VERSION - full shared library version of the found Arrow
+# ARROW_SO_VERSION - shared library version of the found Arrow
+# ARROW_VERSION - version of the found Arrow
+# ARROW_* - options used when the found Arrow is build such as ARROW_COMPUTE
+# Arrow_FOUND - true if Arrow found on the system
+#
+# This config sets the following targets in your project::
+#
+# arrow_shared - for linked as shared library if shared library is built
+# arrow_static - for linked as static library if static library is built
+
+@PACKAGE_INIT@
+
+set(ARROW_VERSION "@ARROW_VERSION@")
+set(ARROW_SO_VERSION "@ARROW_SO_VERSION@")
+set(ARROW_FULL_SO_VERSION "@ARROW_FULL_SO_VERSION@")
+
+set(ARROW_LIBRARY_PATH_SUFFIXES "@ARROW_LIBRARY_PATH_SUFFIXES@")
+set(ARROW_INCLUDE_PATH_SUFFIXES "@ARROW_INCLUDE_PATH_SUFFIXES@")
+set(ARROW_SYSTEM_DEPENDENCIES "@ARROW_SYSTEM_DEPENDENCIES@")
+set(ARROW_BUNDLED_STATIC_LIBS "@ARROW_BUNDLED_STATIC_LIBS@")
+
+include("${CMAKE_CURRENT_LIST_DIR}/ArrowOptions.cmake")
+
+include(CMakeFindDependencyMacro)
+
+# Load targets only once. If we load targets multiple times, CMake reports
+# already existent target error.
+if(NOT (TARGET arrow_shared OR TARGET arrow_static))
+ include("${CMAKE_CURRENT_LIST_DIR}/ArrowTargets.cmake")
+
+ if(TARGET arrow_static)
+ set(CMAKE_THREAD_PREFER_PTHREAD TRUE)
+ set(THREADS_PREFER_PTHREAD_FLAG TRUE)
+ find_dependency(Threads)
+
+ if(DEFINED CMAKE_MODULE_PATH)
+ set(_CMAKE_MODULE_PATH_OLD ${CMAKE_MODULE_PATH})
+ endif()
+ set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}")
+
+ foreach(_DEPENDENCY ${ARROW_SYSTEM_DEPENDENCIES})
+ find_dependency(${_DEPENDENCY})
+ endforeach()
+
+ if(DEFINED _CMAKE_MODULE_PATH_OLD)
+ set(CMAKE_MODULE_PATH ${_CMAKE_MODULE_PATH_OLD})
+ unset(_CMAKE_MODULE_PATH_OLD)
+ else()
+ unset(CMAKE_MODULE_PATH)
+ endif()
+
+ get_property(arrow_static_loc TARGET arrow_static PROPERTY LOCATION)
+ get_filename_component(arrow_lib_dir ${arrow_static_loc} DIRECTORY)
+
+ if(ARROW_BUNDLED_STATIC_LIBS)
+ add_library(arrow_bundled_dependencies STATIC IMPORTED)
+ set_target_properties(
+ arrow_bundled_dependencies
+ PROPERTIES
+ IMPORTED_LOCATION
+ "${arrow_lib_dir}/${CMAKE_STATIC_LIBRARY_PREFIX}arrow_bundled_dependencies${CMAKE_STATIC_LIBRARY_SUFFIX}"
+ )
+
+ get_property(arrow_static_interface_link_libraries
+ TARGET arrow_static
+ PROPERTY INTERFACE_LINK_LIBRARIES)
+ set_target_properties(
+ arrow_static PROPERTIES INTERFACE_LINK_LIBRARIES
+ "${arrow_static_interface_link_libraries};arrow_bundled_dependencies")
+ endif()
+ endif()
+endif()
diff --git a/src/arrow/cpp/src/arrow/ArrowTestingConfig.cmake.in b/src/arrow/cpp/src/arrow/ArrowTestingConfig.cmake.in
new file mode 100644
index 000000000..2b5548c8b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ArrowTestingConfig.cmake.in
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This config sets the following variables in your project::
+#
+# ArrowTesting_FOUND - true if Arrow testing library found on the system
+#
+# This config sets the following targets in your project::
+#
+# arrow_testing_shared - for linked as shared library if shared library is built
+# arrow_testing_static - for linked as static library if static library is built
+
+@PACKAGE_INIT@
+
+include(CMakeFindDependencyMacro)
+find_dependency(Arrow)
+
+# Load targets only once. If we load targets multiple times, CMake reports
+# already existent target error.
+if(NOT (TARGET arrow_testing_shared OR TARGET arrow_testing_static))
+ include("${CMAKE_CURRENT_LIST_DIR}/ArrowTestingTargets.cmake")
+endif()
diff --git a/src/arrow/cpp/src/arrow/CMakeLists.txt b/src/arrow/cpp/src/arrow/CMakeLists.txt
new file mode 100644
index 000000000..231000ac7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/CMakeLists.txt
@@ -0,0 +1,751 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+add_custom_target(arrow-all)
+add_custom_target(arrow)
+add_custom_target(arrow-benchmarks)
+add_custom_target(arrow-tests)
+add_custom_target(arrow-integration)
+add_dependencies(arrow-all
+ arrow
+ arrow-tests
+ arrow-benchmarks
+ arrow-integration)
+
+# Adding unit tests part of the "arrow" portion of the test suite
+function(ADD_ARROW_TEST REL_TEST_NAME)
+ set(options)
+ set(one_value_args PREFIX)
+ set(multi_value_args LABELS PRECOMPILED_HEADERS)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+
+ if(ARG_PREFIX)
+ set(PREFIX ${ARG_PREFIX})
+ else()
+ set(PREFIX "arrow")
+ endif()
+
+ if(ARG_LABELS)
+ set(LABELS ${ARG_LABELS})
+ else()
+ set(LABELS "arrow-tests")
+ endif()
+
+ # Because of https://gitlab.kitware.com/cmake/cmake/issues/20289,
+ # we must generate the precompiled header on an executable target.
+ # Do that on the first unit test target (here "arrow-array-test")
+ # and reuse the PCH for the other tests.
+ if(ARG_PRECOMPILED_HEADERS)
+ set(PCH_ARGS PRECOMPILED_HEADERS ${ARG_PRECOMPILED_HEADERS})
+ else()
+ set(PCH_ARGS PRECOMPILED_HEADER_LIB "arrow-array-test")
+ endif()
+
+ add_test_case(${REL_TEST_NAME}
+ PREFIX
+ ${PREFIX}
+ LABELS
+ ${LABELS}
+ ${PCH_ARGS}
+ ${ARG_UNPARSED_ARGUMENTS})
+endfunction()
+
+function(ADD_ARROW_FUZZ_TARGET REL_FUZZING_NAME)
+ set(options)
+ set(one_value_args PREFIX)
+ set(multi_value_args)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+
+ if(ARG_PREFIX)
+ set(PREFIX ${ARG_PREFIX})
+ else()
+ set(PREFIX "arrow")
+ endif()
+
+ if(ARROW_BUILD_STATIC)
+ set(LINK_LIBS arrow_static)
+ else()
+ set(LINK_LIBS arrow_shared)
+ endif()
+ add_fuzz_target(${REL_FUZZING_NAME}
+ PREFIX
+ ${PREFIX}
+ LINK_LIBS
+ ${LINK_LIBS}
+ ${ARG_UNPARSED_ARGUMENTS})
+endfunction()
+
+function(ADD_ARROW_BENCHMARK REL_TEST_NAME)
+ set(options)
+ set(one_value_args PREFIX)
+ set(multi_value_args)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+ if(ARG_PREFIX)
+ set(PREFIX ${ARG_PREFIX})
+ else()
+ set(PREFIX "arrow")
+ endif()
+ add_benchmark(${REL_TEST_NAME}
+ PREFIX
+ ${PREFIX}
+ LABELS
+ "arrow-benchmarks"
+ ${ARG_UNPARSED_ARGUMENTS})
+endfunction()
+
+macro(append_avx2_src SRC)
+ if(ARROW_HAVE_RUNTIME_AVX2)
+ list(APPEND ARROW_SRCS ${SRC})
+ set_source_files_properties(${SRC} PROPERTIES SKIP_PRECOMPILE_HEADERS ON)
+ set_source_files_properties(${SRC} PROPERTIES COMPILE_FLAGS ${ARROW_AVX2_FLAG})
+ endif()
+endmacro()
+
+macro(append_avx512_src SRC)
+ if(ARROW_HAVE_RUNTIME_AVX512)
+ list(APPEND ARROW_SRCS ${SRC})
+ set_source_files_properties(${SRC} PROPERTIES SKIP_PRECOMPILE_HEADERS ON)
+ set_source_files_properties(${SRC} PROPERTIES COMPILE_FLAGS ${ARROW_AVX512_FLAG})
+ endif()
+endmacro()
+
+set(ARROW_SRCS
+ array/array_base.cc
+ array/array_binary.cc
+ array/array_decimal.cc
+ array/array_dict.cc
+ array/array_nested.cc
+ array/array_primitive.cc
+ array/builder_adaptive.cc
+ array/builder_base.cc
+ array/builder_binary.cc
+ array/builder_decimal.cc
+ array/builder_dict.cc
+ array/builder_nested.cc
+ array/builder_primitive.cc
+ array/builder_union.cc
+ array/concatenate.cc
+ array/data.cc
+ array/diff.cc
+ array/util.cc
+ array/validate.cc
+ builder.cc
+ buffer.cc
+ chunked_array.cc
+ compare.cc
+ config.cc
+ datum.cc
+ device.cc
+ extension_type.cc
+ memory_pool.cc
+ pretty_print.cc
+ record_batch.cc
+ result.cc
+ scalar.cc
+ sparse_tensor.cc
+ status.cc
+ table.cc
+ table_builder.cc
+ tensor.cc
+ tensor/coo_converter.cc
+ tensor/csf_converter.cc
+ tensor/csx_converter.cc
+ type.cc
+ visitor.cc
+ c/bridge.cc
+ io/buffered.cc
+ io/caching.cc
+ io/compressed.cc
+ io/file.cc
+ io/hdfs.cc
+ io/hdfs_internal.cc
+ io/interfaces.cc
+ io/memory.cc
+ io/slow.cc
+ io/stdio.cc
+ io/transform.cc
+ util/async_util.cc
+ util/basic_decimal.cc
+ util/bit_block_counter.cc
+ util/bit_run_reader.cc
+ util/bit_util.cc
+ util/bitmap.cc
+ util/bitmap_builders.cc
+ util/bitmap_ops.cc
+ util/bpacking.cc
+ util/cancel.cc
+ util/compression.cc
+ util/counting_semaphore.cc
+ util/cpu_info.cc
+ util/decimal.cc
+ util/delimiting.cc
+ util/formatting.cc
+ util/future.cc
+ util/int_util.cc
+ util/io_util.cc
+ util/logging.cc
+ util/key_value_metadata.cc
+ util/memory.cc
+ util/mutex.cc
+ util/string.cc
+ util/string_builder.cc
+ util/task_group.cc
+ util/tdigest.cc
+ util/thread_pool.cc
+ util/time.cc
+ util/trie.cc
+ util/unreachable.cc
+ util/uri.cc
+ util/utf8.cc
+ util/value_parsing.cc
+ vendored/base64.cpp
+ vendored/datetime/tz.cpp
+ vendored/double-conversion/bignum.cc
+ vendored/double-conversion/double-conversion.cc
+ vendored/double-conversion/bignum-dtoa.cc
+ vendored/double-conversion/fast-dtoa.cc
+ vendored/double-conversion/cached-powers.cc
+ vendored/double-conversion/fixed-dtoa.cc
+ vendored/double-conversion/diy-fp.cc
+ vendored/double-conversion/strtod.cc)
+
+append_avx2_src(util/bpacking_avx2.cc)
+append_avx512_src(util/bpacking_avx512.cc)
+
+if(ARROW_HAVE_NEON)
+ list(APPEND ARROW_SRCS util/bpacking_neon.cc)
+endif()
+
+if(APPLE)
+ list(APPEND ARROW_SRCS vendored/datetime/ios.mm)
+endif()
+
+set(ARROW_C_SRCS
+ vendored/musl/strptime.c
+ vendored/uriparser/UriCommon.c
+ vendored/uriparser/UriCompare.c
+ vendored/uriparser/UriEscape.c
+ vendored/uriparser/UriFile.c
+ vendored/uriparser/UriIp4Base.c
+ vendored/uriparser/UriIp4.c
+ vendored/uriparser/UriMemory.c
+ vendored/uriparser/UriNormalizeBase.c
+ vendored/uriparser/UriNormalize.c
+ vendored/uriparser/UriParseBase.c
+ vendored/uriparser/UriParse.c
+ vendored/uriparser/UriQuery.c
+ vendored/uriparser/UriRecompose.c
+ vendored/uriparser/UriResolve.c
+ vendored/uriparser/UriShorten.c)
+
+set_source_files_properties(vendored/datetime/tz.cpp
+ PROPERTIES SKIP_PRECOMPILE_HEADERS ON
+ SKIP_UNITY_BUILD_INCLUSION ON)
+
+# Disable DLL exports in vendored uriparser library
+add_definitions(-DURI_STATIC_BUILD)
+
+if(ARROW_WITH_BROTLI)
+ add_definitions(-DARROW_WITH_BROTLI)
+ list(APPEND ARROW_SRCS util/compression_brotli.cc)
+endif()
+
+if(ARROW_WITH_BZ2)
+ add_definitions(-DARROW_WITH_BZ2)
+ list(APPEND ARROW_SRCS util/compression_bz2.cc)
+endif()
+
+if(ARROW_WITH_LZ4)
+ add_definitions(-DARROW_WITH_LZ4)
+ list(APPEND ARROW_SRCS util/compression_lz4.cc)
+endif()
+
+if(ARROW_WITH_SNAPPY)
+ add_definitions(-DARROW_WITH_SNAPPY)
+ list(APPEND ARROW_SRCS util/compression_snappy.cc)
+endif()
+
+if(ARROW_WITH_ZLIB)
+ add_definitions(-DARROW_WITH_ZLIB)
+ list(APPEND ARROW_SRCS util/compression_zlib.cc)
+endif()
+
+if(ARROW_WITH_ZSTD)
+ add_definitions(-DARROW_WITH_ZSTD)
+ list(APPEND ARROW_SRCS util/compression_zstd.cc)
+endif()
+
+set(ARROW_TESTING_SRCS
+ io/test_common.cc
+ ipc/test_common.cc
+ testing/json_integration.cc
+ testing/json_internal.cc
+ testing/gtest_util.cc
+ testing/random.cc
+ testing/generator.cc
+ testing/util.cc)
+
+# Add dependencies for third-party allocators.
+# If possible we only want memory_pool.cc to wait for allocators to finish building,
+# but that only works with Ninja
+# (see https://gitlab.kitware.com/cmake/cmake/issues/19677)
+
+set(_allocator_dependencies "") # Empty list
+if(ARROW_JEMALLOC)
+ list(APPEND _allocator_dependencies jemalloc_ep)
+endif()
+if(ARROW_MIMALLOC)
+ list(APPEND _allocator_dependencies mimalloc_ep)
+endif()
+
+if(_allocator_dependencies)
+ if("${CMAKE_GENERATOR}" STREQUAL "Ninja")
+ set_source_files_properties(memory_pool.cc PROPERTIES OBJECT_DEPENDS
+ "${_allocator_dependencies}")
+ else()
+ add_dependencies(arrow_dependencies ${_allocator_dependencies})
+ endif()
+ set_source_files_properties(memory_pool.cc PROPERTIES SKIP_PRECOMPILE_HEADERS ON
+ SKIP_UNITY_BUILD_INCLUSION ON)
+endif()
+
+unset(_allocator_dependencies)
+
+if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
+ set_property(SOURCE util/io_util.cc
+ APPEND_STRING
+ PROPERTY COMPILE_FLAGS " -Wno-unused-macros ")
+endif()
+
+#
+# Configure the base Arrow libraries
+#
+
+if(ARROW_CSV)
+ list(APPEND
+ ARROW_SRCS
+ csv/converter.cc
+ csv/chunker.cc
+ csv/column_builder.cc
+ csv/column_decoder.cc
+ csv/options.cc
+ csv/parser.cc
+ csv/reader.cc)
+ if(ARROW_COMPUTE)
+ list(APPEND ARROW_SRCS csv/writer.cc)
+ endif()
+
+ list(APPEND ARROW_TESTING_SRCS csv/test_common.cc)
+endif()
+
+if(ARROW_COMPUTE)
+ list(APPEND
+ ARROW_SRCS
+ compute/api_aggregate.cc
+ compute/api_scalar.cc
+ compute/api_vector.cc
+ compute/cast.cc
+ compute/exec.cc
+ compute/exec/aggregate_node.cc
+ compute/exec/exec_plan.cc
+ compute/exec/expression.cc
+ compute/exec/filter_node.cc
+ compute/exec/project_node.cc
+ compute/exec/source_node.cc
+ compute/exec/sink_node.cc
+ compute/exec/order_by_impl.cc
+ compute/function.cc
+ compute/function_internal.cc
+ compute/kernel.cc
+ compute/registry.cc
+ compute/kernels/aggregate_basic.cc
+ compute/kernels/aggregate_mode.cc
+ compute/kernels/aggregate_quantile.cc
+ compute/kernels/aggregate_tdigest.cc
+ compute/kernels/aggregate_var_std.cc
+ compute/kernels/codegen_internal.cc
+ compute/kernels/hash_aggregate.cc
+ compute/kernels/scalar_arithmetic.cc
+ compute/kernels/scalar_boolean.cc
+ compute/kernels/scalar_cast_boolean.cc
+ compute/kernels/scalar_cast_dictionary.cc
+ compute/kernels/scalar_cast_internal.cc
+ compute/kernels/scalar_cast_nested.cc
+ compute/kernels/scalar_cast_numeric.cc
+ compute/kernels/scalar_cast_string.cc
+ compute/kernels/scalar_cast_temporal.cc
+ compute/kernels/scalar_compare.cc
+ compute/kernels/scalar_nested.cc
+ compute/kernels/scalar_set_lookup.cc
+ compute/kernels/scalar_string.cc
+ compute/kernels/scalar_temporal_binary.cc
+ compute/kernels/scalar_temporal_unary.cc
+ compute/kernels/scalar_validity.cc
+ compute/kernels/scalar_if_else.cc
+ compute/kernels/util_internal.cc
+ compute/kernels/vector_array_sort.cc
+ compute/kernels/vector_hash.cc
+ compute/kernels/vector_nested.cc
+ compute/kernels/vector_replace.cc
+ compute/kernels/vector_selection.cc
+ compute/kernels/vector_sort.cc
+ compute/kernels/row_encoder.cc
+ compute/exec/union_node.cc
+ compute/exec/key_hash.cc
+ compute/exec/key_map.cc
+ compute/exec/key_compare.cc
+ compute/exec/key_encode.cc
+ compute/exec/util.cc
+ compute/exec/hash_join_dict.cc
+ compute/exec/hash_join.cc
+ compute/exec/hash_join_node.cc
+ compute/exec/task_util.cc)
+
+ append_avx2_src(compute/kernels/aggregate_basic_avx2.cc)
+ append_avx512_src(compute/kernels/aggregate_basic_avx512.cc)
+
+ append_avx2_src(compute/exec/key_hash_avx2.cc)
+ append_avx2_src(compute/exec/key_map_avx2.cc)
+ append_avx2_src(compute/exec/key_compare_avx2.cc)
+ append_avx2_src(compute/exec/key_encode_avx2.cc)
+ append_avx2_src(compute/exec/util_avx2.cc)
+
+ list(APPEND ARROW_TESTING_SRCS compute/exec/test_util.cc)
+endif()
+
+if(ARROW_FILESYSTEM)
+ if(ARROW_HDFS)
+ add_definitions(-DARROW_HDFS)
+ endif()
+
+ list(APPEND
+ ARROW_SRCS
+ filesystem/filesystem.cc
+ filesystem/localfs.cc
+ filesystem/mockfs.cc
+ filesystem/path_util.cc
+ filesystem/util_internal.cc)
+
+ if(ARROW_GCS)
+ list(APPEND ARROW_SRCS filesystem/gcsfs.cc filesystem/gcsfs_internal.cc)
+ set_source_files_properties(filesystem/gcsfs.cc filesystem/gcsfs_internal.cc
+ PROPERTIES SKIP_PRECOMPILE_HEADERS ON
+ SKIP_UNITY_BUILD_INCLUSION ON)
+ endif()
+ if(ARROW_HDFS)
+ list(APPEND ARROW_SRCS filesystem/hdfs.cc)
+ endif()
+ if(ARROW_S3)
+ list(APPEND ARROW_SRCS filesystem/s3fs.cc)
+ set_source_files_properties(filesystem/s3fs.cc
+ PROPERTIES SKIP_PRECOMPILE_HEADERS ON
+ SKIP_UNITY_BUILD_INCLUSION ON)
+ endif()
+
+ list(APPEND ARROW_TESTING_SRCS filesystem/test_util.cc)
+endif()
+
+if(ARROW_IPC)
+ list(APPEND
+ ARROW_SRCS
+ ipc/dictionary.cc
+ ipc/feather.cc
+ ipc/message.cc
+ ipc/metadata_internal.cc
+ ipc/options.cc
+ ipc/reader.cc
+ ipc/writer.cc)
+
+ if(ARROW_JSON)
+ list(APPEND ARROW_SRCS ipc/json_simple.cc)
+ endif()
+endif()
+
+if(ARROW_JSON)
+ list(APPEND
+ ARROW_SRCS
+ json/options.cc
+ json/chunked_builder.cc
+ json/chunker.cc
+ json/converter.cc
+ json/object_parser.cc
+ json/object_writer.cc
+ json/parser.cc
+ json/reader.cc)
+endif()
+
+if(ARROW_ORC)
+ list(APPEND ARROW_SRCS adapters/orc/adapter.cc adapters/orc/adapter_util.cc)
+endif()
+
+if(CXX_LINKER_SUPPORTS_VERSION_SCRIPT)
+ set(ARROW_VERSION_SCRIPT_FLAGS
+ "-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/symbols.map")
+ set(ARROW_SHARED_LINK_FLAGS ${ARROW_VERSION_SCRIPT_FLAGS})
+endif()
+
+set(ARROW_ALL_SRCS ${ARROW_SRCS} ${ARROW_C_SRCS})
+
+if(ARROW_BUILD_STATIC AND ARROW_BUNDLED_STATIC_LIBS)
+ set(ARROW_BUILD_BUNDLED_DEPENDENCIES TRUE)
+else()
+ set(ARROW_BUILD_BUNDLED_DEPENDENCIES FALSE)
+endif()
+
+if(ARROW_BUILD_BUNDLED_DEPENDENCIES)
+ string(APPEND ARROW_PC_LIBS_PRIVATE " -larrow_bundled_dependencies")
+endif()
+# Need -latomic on Raspbian.
+# See also: https://issues.apache.org/jira/browse/ARROW-12860
+if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux" AND ${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7")
+ string(APPEND ARROW_PC_LIBS_PRIVATE " -latomic")
+endif()
+
+add_arrow_lib(arrow
+ CMAKE_PACKAGE_NAME
+ Arrow
+ PKG_CONFIG_NAME
+ arrow
+ SOURCES
+ ${ARROW_ALL_SRCS}
+ OUTPUTS
+ ARROW_LIBRARIES
+ PRECOMPILED_HEADERS
+ "$<$<COMPILE_LANGUAGE:CXX>:arrow/pch.h>"
+ DEPENDENCIES
+ arrow_dependencies
+ SHARED_LINK_FLAGS
+ ${ARROW_SHARED_LINK_FLAGS}
+ SHARED_LINK_LIBS
+ ${ARROW_LINK_LIBS}
+ SHARED_PRIVATE_LINK_LIBS
+ ${ARROW_SHARED_PRIVATE_LINK_LIBS}
+ STATIC_LINK_LIBS
+ ${ARROW_STATIC_LINK_LIBS}
+ SHARED_INSTALL_INTERFACE_LIBS
+ ${ARROW_SHARED_INSTALL_INTERFACE_LIBS}
+ STATIC_INSTALL_INTERFACE_LIBS
+ ${ARROW_STATIC_INSTALL_INTERFACE_LIBS})
+
+add_dependencies(arrow ${ARROW_LIBRARIES})
+
+if(ARROW_BUILD_STATIC AND WIN32)
+ target_compile_definitions(arrow_static PUBLIC ARROW_STATIC)
+endif()
+
+foreach(LIB_TARGET ${ARROW_LIBRARIES})
+ target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_EXPORTING)
+endforeach()
+
+if(ARROW_WITH_BACKTRACE)
+ find_package(Backtrace)
+
+ foreach(LIB_TARGET ${ARROW_LIBRARIES})
+ if(Backtrace_FOUND AND ARROW_WITH_BACKTRACE)
+ target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_WITH_BACKTRACE)
+ endif()
+ endforeach()
+endif()
+
+if(ARROW_BUILD_BUNDLED_DEPENDENCIES)
+ arrow_car(_FIRST_LIB ${ARROW_BUNDLED_STATIC_LIBS})
+ arrow_cdr(_OTHER_LIBS ${ARROW_BUNDLED_STATIC_LIBS})
+ create_merged_static_lib(arrow_bundled_dependencies
+ NAME
+ arrow_bundled_dependencies
+ ROOT
+ ${_FIRST_LIB}
+ TO_MERGE
+ ${_OTHER_LIBS})
+endif()
+
+if(ARROW_TESTING)
+ # that depend on gtest
+ add_arrow_lib(arrow_testing
+ CMAKE_PACKAGE_NAME
+ ArrowTesting
+ PKG_CONFIG_NAME
+ arrow-testing
+ SOURCES
+ ${ARROW_TESTING_SRCS}
+ OUTPUTS
+ ARROW_TESTING_LIBRARIES
+ PRECOMPILED_HEADERS
+ "$<$<COMPILE_LANGUAGE:CXX>:arrow/pch.h>"
+ DEPENDENCIES
+ arrow_test_dependencies
+ SHARED_LINK_LIBS
+ arrow_shared
+ GTest::gtest
+ STATIC_LINK_LIBS
+ arrow_static)
+
+ add_custom_target(arrow_testing)
+ add_dependencies(arrow_testing ${ARROW_TESTING_LIBRARIES})
+
+ if(ARROW_BUILD_STATIC AND WIN32)
+ target_compile_definitions(arrow_testing_static PUBLIC ARROW_TESTING_STATIC)
+ endif()
+
+ foreach(LIB_TARGET ${ARROW_TESTING_LIBRARIES})
+ target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_TESTING_EXPORTING)
+ endforeach()
+endif()
+
+arrow_install_all_headers("arrow")
+
+config_summary_cmake_setters("${CMAKE_CURRENT_BINARY_DIR}/ArrowOptions.cmake")
+install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ArrowOptions.cmake
+ DESTINATION "${ARROW_CMAKE_INSTALL_DIR}")
+
+# For backward compatibility for find_package(arrow)
+install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/arrow-config.cmake
+ DESTINATION "${ARROW_CMAKE_INSTALL_DIR}")
+
+#
+# Unit tests
+#
+add_arrow_test(array_test
+ SOURCES
+ array/array_test.cc
+ array/array_binary_test.cc
+ array/array_dict_test.cc
+ array/array_list_test.cc
+ array/array_struct_test.cc
+ array/array_union_test.cc
+ array/array_view_test.cc
+ PRECOMPILED_HEADERS
+ "$<$<COMPILE_LANGUAGE:CXX>:arrow/testing/pch.h>")
+
+add_arrow_test(buffer_test)
+
+if(ARROW_IPC)
+ # The extension type unit tests require IPC / Flatbuffers support
+ add_arrow_test(extension_type_test)
+endif()
+
+add_arrow_test(misc_test
+ SOURCES
+ datum_test.cc
+ memory_pool_test.cc
+ result_test.cc
+ pretty_print_test.cc
+ status_test.cc)
+
+add_arrow_test(public_api_test)
+
+set_source_files_properties(public_api_test.cc PROPERTIES SKIP_PRECOMPILE_HEADERS ON
+ SKIP_UNITY_BUILD_INCLUSION ON)
+
+add_arrow_test(scalar_test)
+add_arrow_test(type_test)
+
+add_arrow_test(table_test
+ SOURCES
+ chunked_array_test.cc
+ record_batch_test.cc
+ table_test.cc
+ table_builder_test.cc)
+
+add_arrow_test(tensor_test)
+add_arrow_test(sparse_tensor_test)
+
+set(STL_TEST_SRCS stl_iterator_test.cc)
+if(ARROW_COMPUTE)
+ # This unit test uses compute code
+ list(APPEND STL_TEST_SRCS stl_test.cc)
+endif()
+add_arrow_test(stl_test SOURCES ${STL_TEST_SRCS})
+
+add_arrow_benchmark(builder_benchmark)
+add_arrow_benchmark(compare_benchmark)
+add_arrow_benchmark(memory_pool_benchmark)
+add_arrow_benchmark(type_benchmark)
+
+#
+# Recurse into sub-directories
+#
+
+# Unconditionally install testing headers that are also useful for Arrow consumers.
+add_subdirectory(testing)
+
+add_subdirectory(array)
+add_subdirectory(c)
+add_subdirectory(io)
+add_subdirectory(tensor)
+add_subdirectory(util)
+add_subdirectory(vendored)
+
+if(ARROW_CSV)
+ add_subdirectory(csv)
+endif()
+
+if(ARROW_COMPUTE)
+ add_subdirectory(compute)
+endif()
+
+if(ARROW_CUDA)
+ add_subdirectory(gpu)
+endif()
+
+if(ARROW_DATASET)
+ add_subdirectory(dataset)
+endif()
+
+if(ARROW_FILESYSTEM)
+ add_subdirectory(filesystem)
+endif()
+
+if(ARROW_FLIGHT)
+ add_subdirectory(flight)
+endif()
+
+if(ARROW_HIVESERVER2)
+ add_subdirectory(dbi/hiveserver2)
+endif()
+
+if(ARROW_IPC)
+ add_subdirectory(ipc)
+endif()
+
+if(ARROW_JSON)
+ add_subdirectory(json)
+endif()
+
+if(ARROW_ORC)
+ add_subdirectory(adapters/orc)
+endif()
+
+if(ARROW_PYTHON)
+ add_subdirectory(python)
+endif()
+
+if(ARROW_TENSORFLOW)
+ add_subdirectory(adapters/tensorflow)
+endif()
diff --git a/src/arrow/cpp/src/arrow/adapters/orc/CMakeLists.txt b/src/arrow/cpp/src/arrow/adapters/orc/CMakeLists.txt
new file mode 100644
index 000000000..ca901b07d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/adapters/orc/CMakeLists.txt
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#
+# arrow_orc
+#
+
+# Headers: top level
+install(FILES adapter.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/arrow/adapters/orc")
+
+# pkg-config support
+arrow_add_pkg_config("arrow-orc")
+
+set(ORC_MIN_TEST_LIBS
+ GTest::gtest_main
+ GTest::gtest
+ Snappy::snappy
+ LZ4::lz4
+ ZLIB::ZLIB)
+
+if(ARROW_BUILD_STATIC)
+ set(ARROW_LIBRARIES_FOR_STATIC_TESTS arrow_testing_static arrow_static)
+else()
+ set(ARROW_LIBRARIES_FOR_STATIC_TESTS arrow_testing_shared arrow_shared)
+endif()
+
+if(APPLE)
+ set(ORC_MIN_TEST_LIBS ${ORC_MIN_TEST_LIBS} ${CMAKE_DL_LIBS})
+elseif(NOT MSVC)
+ set(ORC_MIN_TEST_LIBS ${ORC_MIN_TEST_LIBS} pthread ${CMAKE_DL_LIBS})
+endif()
+
+set(ORC_STATIC_TEST_LINK_LIBS orc::liborc ${ARROW_LIBRARIES_FOR_STATIC_TESTS}
+ ${ORC_MIN_TEST_LIBS})
+
+add_arrow_test(adapter_test
+ PREFIX
+ "arrow-orc"
+ STATIC_LINK_LIBS
+ ${ORC_STATIC_TEST_LINK_LIBS})
+
+set_source_files_properties(adapter_test.cc PROPERTIES SKIP_PRECOMPILE_HEADERS ON
+ SKIP_UNITY_BUILD_INCLUSION ON)
diff --git a/src/arrow/cpp/src/arrow/adapters/orc/adapter.cc b/src/arrow/cpp/src/arrow/adapters/orc/adapter.cc
new file mode 100644
index 000000000..9bb4abfd0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/adapters/orc/adapter.cc
@@ -0,0 +1,699 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/adapters/orc/adapter.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <list>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/adapters/orc/adapter_util.h"
+#include "arrow/buffer.h"
+#include "arrow/builder.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/memory_pool.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/table_builder.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/range.h"
+#include "arrow/util/visibility.h"
+#include "orc/Exceptions.hh"
+
+// alias to not interfere with nested orc namespace
+namespace liborc = orc;
+
+#define ORC_THROW_NOT_OK(s) \
+ do { \
+ Status _s = (s); \
+ if (!_s.ok()) { \
+ std::stringstream ss; \
+ ss << "Arrow error: " << _s.ToString(); \
+ throw liborc::ParseError(ss.str()); \
+ } \
+ } while (0)
+
+#define ORC_ASSIGN_OR_THROW_IMPL(status_name, lhs, rexpr) \
+ auto status_name = (rexpr); \
+ ORC_THROW_NOT_OK(status_name.status()); \
+ lhs = std::move(status_name).ValueOrDie();
+
+#define ORC_ASSIGN_OR_THROW(lhs, rexpr) \
+ ORC_ASSIGN_OR_THROW_IMPL(ARROW_ASSIGN_OR_RAISE_NAME(_error_or_value, __COUNTER__), \
+ lhs, rexpr);
+
+#define ORC_BEGIN_CATCH_NOT_OK try {
+#define ORC_END_CATCH_NOT_OK \
+ } \
+ catch (const liborc::ParseError& e) { \
+ return Status::IOError(e.what()); \
+ } \
+ catch (const liborc::InvalidArgument& e) { \
+ return Status::Invalid(e.what()); \
+ } \
+ catch (const liborc::NotImplementedYet& e) { \
+ return Status::NotImplemented(e.what()); \
+ }
+
+#define ORC_CATCH_NOT_OK(_s) \
+ ORC_BEGIN_CATCH_NOT_OK(_s); \
+ ORC_END_CATCH_NOT_OK
+
+namespace arrow {
+namespace adapters {
+namespace orc {
+
+namespace {
+
+// The following are required by ORC to be uint64_t
+constexpr uint64_t kOrcWriterBatchSize = 128 * 1024;
+constexpr uint64_t kOrcNaturalWriteSize = 128 * 1024;
+
+using internal::checked_cast;
+
+class ArrowInputFile : public liborc::InputStream {
+ public:
+ explicit ArrowInputFile(const std::shared_ptr<io::RandomAccessFile>& file)
+ : file_(file) {}
+
+ uint64_t getLength() const override {
+ ORC_ASSIGN_OR_THROW(int64_t size, file_->GetSize());
+ return static_cast<uint64_t>(size);
+ }
+
+ uint64_t getNaturalReadSize() const override { return 128 * 1024; }
+
+ void read(void* buf, uint64_t length, uint64_t offset) override {
+ ORC_ASSIGN_OR_THROW(int64_t bytes_read, file_->ReadAt(offset, length, buf));
+
+ if (static_cast<uint64_t>(bytes_read) != length) {
+ throw liborc::ParseError("Short read from arrow input file");
+ }
+ }
+
+ const std::string& getName() const override {
+ static const std::string filename("ArrowInputFile");
+ return filename;
+ }
+
+ private:
+ std::shared_ptr<io::RandomAccessFile> file_;
+};
+
+struct StripeInformation {
+ uint64_t offset;
+ uint64_t length;
+ uint64_t num_rows;
+ uint64_t first_row_of_stripe;
+};
+
+// The number of rows to read in a ColumnVectorBatch
+constexpr int64_t kReadRowsBatch = 1000;
+
+class OrcStripeReader : public RecordBatchReader {
+ public:
+ OrcStripeReader(std::unique_ptr<liborc::RowReader> row_reader,
+ std::shared_ptr<Schema> schema, int64_t batch_size, MemoryPool* pool)
+ : row_reader_(std::move(row_reader)),
+ schema_(schema),
+ pool_(pool),
+ batch_size_{batch_size} {}
+
+ std::shared_ptr<Schema> schema() const override { return schema_; }
+
+ Status ReadNext(std::shared_ptr<RecordBatch>* out) override {
+ std::unique_ptr<liborc::ColumnVectorBatch> batch;
+ ORC_CATCH_NOT_OK(batch = row_reader_->createRowBatch(batch_size_));
+
+ const liborc::Type& type = row_reader_->getSelectedType();
+ if (!row_reader_->next(*batch)) {
+ out->reset();
+ return Status::OK();
+ }
+
+ std::unique_ptr<RecordBatchBuilder> builder;
+ RETURN_NOT_OK(RecordBatchBuilder::Make(schema_, pool_, batch->numElements, &builder));
+
+ // The top-level type must be a struct to read into an arrow table
+ const auto& struct_batch = checked_cast<liborc::StructVectorBatch&>(*batch);
+
+ for (int i = 0; i < builder->num_fields(); i++) {
+ RETURN_NOT_OK(AppendBatch(type.getSubtype(i), struct_batch.fields[i], 0,
+ batch->numElements, builder->GetField(i)));
+ }
+
+ RETURN_NOT_OK(builder->Flush(out));
+ return Status::OK();
+ }
+
+ private:
+ std::unique_ptr<liborc::RowReader> row_reader_;
+ std::shared_ptr<Schema> schema_;
+ MemoryPool* pool_;
+ int64_t batch_size_;
+};
+
+} // namespace
+
+class ORCFileReader::Impl {
+ public:
+ Impl() {}
+ ~Impl() {}
+
+ Status Open(const std::shared_ptr<io::RandomAccessFile>& file, MemoryPool* pool) {
+ std::unique_ptr<ArrowInputFile> io_wrapper(new ArrowInputFile(file));
+ liborc::ReaderOptions options;
+ std::unique_ptr<liborc::Reader> liborc_reader;
+ ORC_CATCH_NOT_OK(liborc_reader = createReader(std::move(io_wrapper), options));
+ pool_ = pool;
+ reader_ = std::move(liborc_reader);
+ current_row_ = 0;
+
+ return Init();
+ }
+
+ Status Init() {
+ int64_t nstripes = reader_->getNumberOfStripes();
+ stripes_.resize(nstripes);
+ std::unique_ptr<liborc::StripeInformation> stripe;
+ uint64_t first_row_of_stripe = 0;
+ for (int i = 0; i < nstripes; ++i) {
+ stripe = reader_->getStripe(i);
+ stripes_[i] = StripeInformation({stripe->getOffset(), stripe->getLength(),
+ stripe->getNumberOfRows(), first_row_of_stripe});
+ first_row_of_stripe += stripe->getNumberOfRows();
+ }
+ return Status::OK();
+ }
+
+ int64_t NumberOfStripes() { return stripes_.size(); }
+
+ int64_t NumberOfRows() { return reader_->getNumberOfRows(); }
+
+ Status ReadSchema(std::shared_ptr<Schema>* out) {
+ const liborc::Type& type = reader_->getType();
+ return GetArrowSchema(type, out);
+ }
+
+ Status ReadSchema(const liborc::RowReaderOptions& opts, std::shared_ptr<Schema>* out) {
+ std::unique_ptr<liborc::RowReader> row_reader;
+ ORC_CATCH_NOT_OK(row_reader = reader_->createRowReader(opts));
+ const liborc::Type& type = row_reader->getSelectedType();
+ return GetArrowSchema(type, out);
+ }
+
+ Result<std::shared_ptr<const KeyValueMetadata>> ReadMetadata() {
+ const std::list<std::string> keys = reader_->getMetadataKeys();
+ auto metadata = std::make_shared<KeyValueMetadata>();
+ for (const auto& key : keys) {
+ metadata->Append(key, reader_->getMetadataValue(key));
+ }
+ return std::const_pointer_cast<const KeyValueMetadata>(metadata);
+ }
+
+ Status GetArrowSchema(const liborc::Type& type, std::shared_ptr<Schema>* out) {
+ if (type.getKind() != liborc::STRUCT) {
+ return Status::NotImplemented(
+ "Only ORC files with a top-level struct "
+ "can be handled");
+ }
+ int size = static_cast<int>(type.getSubtypeCount());
+ std::vector<std::shared_ptr<Field>> fields;
+ for (int child = 0; child < size; ++child) {
+ std::shared_ptr<DataType> elemtype;
+ RETURN_NOT_OK(GetArrowType(type.getSubtype(child), &elemtype));
+ std::string name = type.getFieldName(child);
+ fields.push_back(field(name, elemtype));
+ }
+ ARROW_ASSIGN_OR_RAISE(auto metadata, ReadMetadata());
+ *out = std::make_shared<Schema>(std::move(fields), std::move(metadata));
+ return Status::OK();
+ }
+
+ Status Read(std::shared_ptr<Table>* out) {
+ liborc::RowReaderOptions opts;
+ std::shared_ptr<Schema> schema;
+ RETURN_NOT_OK(ReadSchema(opts, &schema));
+ return ReadTable(opts, schema, out);
+ }
+
+ Status Read(const std::shared_ptr<Schema>& schema, std::shared_ptr<Table>* out) {
+ liborc::RowReaderOptions opts;
+ return ReadTable(opts, schema, out);
+ }
+
+ Status Read(const std::vector<int>& include_indices, std::shared_ptr<Table>* out) {
+ liborc::RowReaderOptions opts;
+ RETURN_NOT_OK(SelectIndices(&opts, include_indices));
+ std::shared_ptr<Schema> schema;
+ RETURN_NOT_OK(ReadSchema(opts, &schema));
+ return ReadTable(opts, schema, out);
+ }
+
+ Status Read(const std::vector<std::string>& include_names,
+ std::shared_ptr<Table>* out) {
+ liborc::RowReaderOptions opts;
+ RETURN_NOT_OK(SelectNames(&opts, include_names));
+ std::shared_ptr<Schema> schema;
+ RETURN_NOT_OK(ReadSchema(opts, &schema));
+ return ReadTable(opts, schema, out);
+ }
+
+ Status Read(const std::shared_ptr<Schema>& schema,
+ const std::vector<int>& include_indices, std::shared_ptr<Table>* out) {
+ liborc::RowReaderOptions opts;
+ RETURN_NOT_OK(SelectIndices(&opts, include_indices));
+ return ReadTable(opts, schema, out);
+ }
+
+ Status ReadStripe(int64_t stripe, std::shared_ptr<RecordBatch>* out) {
+ liborc::RowReaderOptions opts;
+ RETURN_NOT_OK(SelectStripe(&opts, stripe));
+ std::shared_ptr<Schema> schema;
+ RETURN_NOT_OK(ReadSchema(opts, &schema));
+ return ReadBatch(opts, schema, stripes_[stripe].num_rows, out);
+ }
+
+ Status ReadStripe(int64_t stripe, const std::vector<int>& include_indices,
+ std::shared_ptr<RecordBatch>* out) {
+ liborc::RowReaderOptions opts;
+ RETURN_NOT_OK(SelectIndices(&opts, include_indices));
+ RETURN_NOT_OK(SelectStripe(&opts, stripe));
+ std::shared_ptr<Schema> schema;
+ RETURN_NOT_OK(ReadSchema(opts, &schema));
+ return ReadBatch(opts, schema, stripes_[stripe].num_rows, out);
+ }
+
+ Status ReadStripe(int64_t stripe, const std::vector<std::string>& include_names,
+ std::shared_ptr<RecordBatch>* out) {
+ liborc::RowReaderOptions opts;
+ RETURN_NOT_OK(SelectNames(&opts, include_names));
+ RETURN_NOT_OK(SelectStripe(&opts, stripe));
+ std::shared_ptr<Schema> schema;
+ RETURN_NOT_OK(ReadSchema(opts, &schema));
+ return ReadBatch(opts, schema, stripes_[stripe].num_rows, out);
+ }
+
+ Status SelectStripe(liborc::RowReaderOptions* opts, int64_t stripe) {
+ ARROW_RETURN_IF(stripe < 0 || stripe >= NumberOfStripes(),
+ Status::Invalid("Out of bounds stripe: ", stripe));
+
+ opts->range(stripes_[stripe].offset, stripes_[stripe].length);
+ return Status::OK();
+ }
+
+ Status SelectStripeWithRowNumber(liborc::RowReaderOptions* opts, int64_t row_number,
+ StripeInformation* out) {
+ ARROW_RETURN_IF(row_number >= NumberOfRows(),
+ Status::Invalid("Out of bounds row number: ", row_number));
+
+ for (auto it = stripes_.begin(); it != stripes_.end(); it++) {
+ if (static_cast<uint64_t>(row_number) >= it->first_row_of_stripe &&
+ static_cast<uint64_t>(row_number) < it->first_row_of_stripe + it->num_rows) {
+ opts->range(it->offset, it->length);
+ *out = *it;
+ return Status::OK();
+ }
+ }
+
+ return Status::Invalid("Invalid row number", row_number);
+ }
+
+ Status SelectIndices(liborc::RowReaderOptions* opts,
+ const std::vector<int>& include_indices) {
+ std::list<uint64_t> include_indices_list;
+ for (auto it = include_indices.begin(); it != include_indices.end(); ++it) {
+ ARROW_RETURN_IF(*it < 0, Status::Invalid("Negative field index"));
+ include_indices_list.push_back(*it);
+ }
+ opts->includeTypes(include_indices_list);
+ return Status::OK();
+ }
+
+ Status SelectNames(liborc::RowReaderOptions* opts,
+ const std::vector<std::string>& include_names) {
+ std::list<std::string> include_names_list(include_names.begin(), include_names.end());
+ opts->include(include_names_list);
+ return Status::OK();
+ }
+
+ Status ReadTable(const liborc::RowReaderOptions& row_opts,
+ const std::shared_ptr<Schema>& schema, std::shared_ptr<Table>* out) {
+ liborc::RowReaderOptions opts(row_opts);
+ std::vector<std::shared_ptr<RecordBatch>> batches(stripes_.size());
+ for (size_t stripe = 0; stripe < stripes_.size(); stripe++) {
+ opts.range(stripes_[stripe].offset, stripes_[stripe].length);
+ RETURN_NOT_OK(ReadBatch(opts, schema, stripes_[stripe].num_rows, &batches[stripe]));
+ }
+ return Table::FromRecordBatches(schema, std::move(batches)).Value(out);
+ }
+
+ Status ReadBatch(const liborc::RowReaderOptions& opts,
+ const std::shared_ptr<Schema>& schema, int64_t nrows,
+ std::shared_ptr<RecordBatch>* out) {
+ std::unique_ptr<liborc::RowReader> row_reader;
+ std::unique_ptr<liborc::ColumnVectorBatch> batch;
+
+ ORC_BEGIN_CATCH_NOT_OK
+ row_reader = reader_->createRowReader(opts);
+ batch = row_reader->createRowBatch(std::min(nrows, kReadRowsBatch));
+ ORC_END_CATCH_NOT_OK
+
+ std::unique_ptr<RecordBatchBuilder> builder;
+ RETURN_NOT_OK(RecordBatchBuilder::Make(schema, pool_, nrows, &builder));
+
+ // The top-level type must be a struct to read into an arrow table
+ const auto& struct_batch = checked_cast<liborc::StructVectorBatch&>(*batch);
+
+ const liborc::Type& type = row_reader->getSelectedType();
+ while (row_reader->next(*batch)) {
+ for (int i = 0; i < builder->num_fields(); i++) {
+ RETURN_NOT_OK(AppendBatch(type.getSubtype(i), struct_batch.fields[i], 0,
+ batch->numElements, builder->GetField(i)));
+ }
+ }
+ RETURN_NOT_OK(builder->Flush(out));
+ return Status::OK();
+ }
+
+ Status Seek(int64_t row_number) {
+ ARROW_RETURN_IF(row_number >= NumberOfRows(),
+ Status::Invalid("Out of bounds row number: ", row_number));
+
+ current_row_ = row_number;
+ return Status::OK();
+ }
+
+ Status NextStripeReader(int64_t batch_size, const std::vector<int>& include_indices,
+ std::shared_ptr<RecordBatchReader>* out) {
+ if (current_row_ >= NumberOfRows()) {
+ out->reset();
+ return Status::OK();
+ }
+
+ liborc::RowReaderOptions opts;
+ if (!include_indices.empty()) {
+ RETURN_NOT_OK(SelectIndices(&opts, include_indices));
+ }
+ StripeInformation stripe_info({0, 0, 0, 0});
+ RETURN_NOT_OK(SelectStripeWithRowNumber(&opts, current_row_, &stripe_info));
+ std::shared_ptr<Schema> schema;
+ RETURN_NOT_OK(ReadSchema(opts, &schema));
+ std::unique_ptr<liborc::RowReader> row_reader;
+
+ ORC_BEGIN_CATCH_NOT_OK
+ row_reader = reader_->createRowReader(opts);
+ row_reader->seekToRow(current_row_);
+ current_row_ = stripe_info.first_row_of_stripe + stripe_info.num_rows;
+ ORC_END_CATCH_NOT_OK
+
+ *out = std::shared_ptr<RecordBatchReader>(
+ new OrcStripeReader(std::move(row_reader), schema, batch_size, pool_));
+ return Status::OK();
+ }
+
+ Status NextStripeReader(int64_t batch_size, std::shared_ptr<RecordBatchReader>* out) {
+ return NextStripeReader(batch_size, {}, out);
+ }
+
+ private:
+ MemoryPool* pool_;
+ std::unique_ptr<liborc::Reader> reader_;
+ std::vector<StripeInformation> stripes_;
+ int64_t current_row_;
+};
+
+ORCFileReader::ORCFileReader() { impl_.reset(new ORCFileReader::Impl()); }
+
+ORCFileReader::~ORCFileReader() {}
+
+Status ORCFileReader::Open(const std::shared_ptr<io::RandomAccessFile>& file,
+ MemoryPool* pool, std::unique_ptr<ORCFileReader>* reader) {
+ return Open(file, pool).Value(reader);
+}
+
+Result<std::unique_ptr<ORCFileReader>> ORCFileReader::Open(
+ const std::shared_ptr<io::RandomAccessFile>& file, MemoryPool* pool) {
+ auto result = std::unique_ptr<ORCFileReader>(new ORCFileReader());
+ RETURN_NOT_OK(result->impl_->Open(file, pool));
+ return std::move(result);
+}
+
+Result<std::shared_ptr<const KeyValueMetadata>> ORCFileReader::ReadMetadata() {
+ return impl_->ReadMetadata();
+}
+
+Status ORCFileReader::ReadSchema(std::shared_ptr<Schema>* out) {
+ return impl_->ReadSchema(out);
+}
+
+Result<std::shared_ptr<Schema>> ORCFileReader::ReadSchema() {
+ std::shared_ptr<Schema> schema;
+ RETURN_NOT_OK(impl_->ReadSchema(&schema));
+ return schema;
+}
+
+Status ORCFileReader::Read(std::shared_ptr<Table>* out) { return impl_->Read(out); }
+
+Result<std::shared_ptr<Table>> ORCFileReader::Read() {
+ std::shared_ptr<Table> table;
+ RETURN_NOT_OK(impl_->Read(&table));
+ return table;
+}
+
+Status ORCFileReader::Read(const std::shared_ptr<Schema>& schema,
+ std::shared_ptr<Table>* out) {
+ return impl_->Read(schema, out);
+}
+
+Result<std::shared_ptr<Table>> ORCFileReader::Read(
+ const std::shared_ptr<Schema>& schema) {
+ std::shared_ptr<Table> table;
+ RETURN_NOT_OK(impl_->Read(schema, &table));
+ return table;
+}
+
+Status ORCFileReader::Read(const std::vector<int>& include_indices,
+ std::shared_ptr<Table>* out) {
+ return impl_->Read(include_indices, out);
+}
+
+Result<std::shared_ptr<Table>> ORCFileReader::Read(
+ const std::vector<int>& include_indices) {
+ std::shared_ptr<Table> table;
+ RETURN_NOT_OK(impl_->Read(include_indices, &table));
+ return table;
+}
+
+Result<std::shared_ptr<Table>> ORCFileReader::Read(
+ const std::vector<std::string>& include_names) {
+ std::shared_ptr<Table> table;
+ RETURN_NOT_OK(impl_->Read(include_names, &table));
+ return table;
+}
+
+Status ORCFileReader::Read(const std::shared_ptr<Schema>& schema,
+ const std::vector<int>& include_indices,
+ std::shared_ptr<Table>* out) {
+ return impl_->Read(schema, include_indices, out);
+}
+
+Result<std::shared_ptr<Table>> ORCFileReader::Read(
+ const std::shared_ptr<Schema>& schema, const std::vector<int>& include_indices) {
+ std::shared_ptr<Table> table;
+ RETURN_NOT_OK(impl_->Read(schema, include_indices, &table));
+ return table;
+}
+
+Status ORCFileReader::ReadStripe(int64_t stripe, std::shared_ptr<RecordBatch>* out) {
+ return impl_->ReadStripe(stripe, out);
+}
+
+Result<std::shared_ptr<RecordBatch>> ORCFileReader::ReadStripe(int64_t stripe) {
+ std::shared_ptr<RecordBatch> recordBatch;
+ RETURN_NOT_OK(impl_->ReadStripe(stripe, &recordBatch));
+ return recordBatch;
+}
+
+Status ORCFileReader::ReadStripe(int64_t stripe, const std::vector<int>& include_indices,
+ std::shared_ptr<RecordBatch>* out) {
+ return impl_->ReadStripe(stripe, include_indices, out);
+}
+
+Result<std::shared_ptr<RecordBatch>> ORCFileReader::ReadStripe(
+ int64_t stripe, const std::vector<int>& include_indices) {
+ std::shared_ptr<RecordBatch> recordBatch;
+ RETURN_NOT_OK(impl_->ReadStripe(stripe, include_indices, &recordBatch));
+ return recordBatch;
+}
+
+Result<std::shared_ptr<RecordBatch>> ORCFileReader::ReadStripe(
+ int64_t stripe, const std::vector<std::string>& include_names) {
+ std::shared_ptr<RecordBatch> recordBatch;
+ RETURN_NOT_OK(impl_->ReadStripe(stripe, include_names, &recordBatch));
+ return recordBatch;
+}
+
+Status ORCFileReader::Seek(int64_t row_number) { return impl_->Seek(row_number); }
+
+Status ORCFileReader::NextStripeReader(int64_t batch_sizes,
+ std::shared_ptr<RecordBatchReader>* out) {
+ return impl_->NextStripeReader(batch_sizes, out);
+}
+
+Result<std::shared_ptr<RecordBatchReader>> ORCFileReader::NextStripeReader(
+ int64_t batch_size) {
+ std::shared_ptr<RecordBatchReader> reader;
+ RETURN_NOT_OK(impl_->NextStripeReader(batch_size, &reader));
+ return reader;
+}
+
+Status ORCFileReader::NextStripeReader(int64_t batch_size,
+ const std::vector<int>& include_indices,
+ std::shared_ptr<RecordBatchReader>* out) {
+ return impl_->NextStripeReader(batch_size, include_indices, out);
+}
+
+Result<std::shared_ptr<RecordBatchReader>> ORCFileReader::NextStripeReader(
+ int64_t batch_size, const std::vector<int>& include_indices) {
+ std::shared_ptr<RecordBatchReader> reader;
+ RETURN_NOT_OK(impl_->NextStripeReader(batch_size, include_indices, &reader));
+ return reader;
+}
+
+int64_t ORCFileReader::NumberOfStripes() { return impl_->NumberOfStripes(); }
+
+int64_t ORCFileReader::NumberOfRows() { return impl_->NumberOfRows(); }
+
+namespace {
+
+class ArrowOutputStream : public liborc::OutputStream {
+ public:
+ explicit ArrowOutputStream(arrow::io::OutputStream& output_stream)
+ : output_stream_(output_stream), length_(0) {}
+
+ uint64_t getLength() const override { return length_; }
+
+ uint64_t getNaturalWriteSize() const override { return kOrcNaturalWriteSize; }
+
+ void write(const void* buf, size_t length) override {
+ ORC_THROW_NOT_OK(output_stream_.Write(buf, static_cast<int64_t>(length)));
+ length_ += static_cast<int64_t>(length);
+ }
+
+ // Mandatory due to us implementing an ORC virtual class.
+ // Used by ORC for error messages, not used by Arrow
+ const std::string& getName() const override {
+ static const std::string filename("ArrowOutputFile");
+ return filename;
+ }
+
+ void close() override {
+ if (!output_stream_.closed()) {
+ ORC_THROW_NOT_OK(output_stream_.Close());
+ }
+ }
+
+ void set_length(int64_t length) { length_ = length; }
+
+ private:
+ arrow::io::OutputStream& output_stream_;
+ int64_t length_;
+};
+
+} // namespace
+
+class ORCFileWriter::Impl {
+ public:
+ Status Open(arrow::io::OutputStream* output_stream) {
+ out_stream_ = std::unique_ptr<liborc::OutputStream>(
+ checked_cast<liborc::OutputStream*>(new ArrowOutputStream(*output_stream)));
+ return Status::OK();
+ }
+
+ Status Write(const Table& table) {
+ std::unique_ptr<liborc::WriterOptions> orc_options =
+ std::unique_ptr<liborc::WriterOptions>(new liborc::WriterOptions());
+ ARROW_ASSIGN_OR_RAISE(auto orc_schema, GetOrcType(*(table.schema())));
+ ORC_CATCH_NOT_OK(
+ writer_ = liborc::createWriter(*orc_schema, out_stream_.get(), *orc_options))
+
+ int64_t num_rows = table.num_rows();
+ const int num_cols_ = table.num_columns();
+ std::vector<int64_t> arrow_index_offset(num_cols_, 0);
+ std::vector<int> arrow_chunk_offset(num_cols_, 0);
+ std::unique_ptr<liborc::ColumnVectorBatch> batch =
+ writer_->createRowBatch(kOrcWriterBatchSize);
+ liborc::StructVectorBatch* root =
+ internal::checked_cast<liborc::StructVectorBatch*>(batch.get());
+ while (num_rows > 0) {
+ for (int i = 0; i < num_cols_; i++) {
+ RETURN_NOT_OK(adapters::orc::WriteBatch(
+ *(table.column(i)), kOrcWriterBatchSize, &(arrow_chunk_offset[i]),
+ &(arrow_index_offset[i]), (root->fields)[i]));
+ }
+ root->numElements = (root->fields)[0]->numElements;
+ writer_->add(*batch);
+ batch->clear();
+ num_rows -= kOrcWriterBatchSize;
+ }
+ return Status::OK();
+ }
+
+ Status Close() {
+ writer_->close();
+ return Status::OK();
+ }
+
+ private:
+ std::unique_ptr<liborc::Writer> writer_;
+ std::unique_ptr<liborc::OutputStream> out_stream_;
+};
+
+ORCFileWriter::~ORCFileWriter() {}
+
+ORCFileWriter::ORCFileWriter() { impl_.reset(new ORCFileWriter::Impl()); }
+
+Result<std::unique_ptr<ORCFileWriter>> ORCFileWriter::Open(
+ io::OutputStream* output_stream) {
+ std::unique_ptr<ORCFileWriter> result =
+ std::unique_ptr<ORCFileWriter>(new ORCFileWriter());
+ Status status = result->impl_->Open(output_stream);
+ RETURN_NOT_OK(status);
+ return std::move(result);
+}
+
+Status ORCFileWriter::Write(const Table& table) { return impl_->Write(table); }
+
+Status ORCFileWriter::Close() { return impl_->Close(); }
+
+} // namespace orc
+} // namespace adapters
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/adapters/orc/adapter.h b/src/arrow/cpp/src/arrow/adapters/orc/adapter.h
new file mode 100644
index 000000000..e053eab43
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/adapters/orc/adapter.h
@@ -0,0 +1,291 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "arrow/io/interfaces.h"
+#include "arrow/memory_pool.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace adapters {
+namespace orc {
+
+/// \class ORCFileReader
+/// \brief Read an Arrow Table or RecordBatch from an ORC file.
+class ARROW_EXPORT ORCFileReader {
+ public:
+ ~ORCFileReader();
+
+ /// \brief Creates a new ORC reader.
+ ///
+ /// \param[in] file the data source
+ /// \param[in] pool a MemoryPool to use for buffer allocations
+ /// \param[out] reader the returned reader object
+ /// \return Status
+ ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
+ static Status Open(const std::shared_ptr<io::RandomAccessFile>& file, MemoryPool* pool,
+ std::unique_ptr<ORCFileReader>* reader);
+
+ /// \brief Creates a new ORC reader
+ ///
+ /// \param[in] file the data source
+ /// \param[in] pool a MemoryPool to use for buffer allocations
+ /// \return the returned reader object
+ static Result<std::unique_ptr<ORCFileReader>> Open(
+ const std::shared_ptr<io::RandomAccessFile>& file, MemoryPool* pool);
+
+ /// \brief Return the metadata read from the ORC file
+ ///
+ /// \return A KeyValueMetadata object containing the ORC metadata
+ Result<std::shared_ptr<const KeyValueMetadata>> ReadMetadata();
+
+ /// \brief Return the schema read from the ORC file
+ ///
+ /// \param[out] out the returned Schema object
+ ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
+ Status ReadSchema(std::shared_ptr<Schema>* out);
+
+ /// \brief Return the schema read from the ORC file
+ ///
+ /// \return the returned Schema object
+ Result<std::shared_ptr<Schema>> ReadSchema();
+
+ /// \brief Read the file as a Table
+ ///
+ /// The table will be composed of one record batch per stripe.
+ ///
+ /// \param[out] out the returned Table
+ ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
+ Status Read(std::shared_ptr<Table>* out);
+
+ /// \brief Read the file as a Table
+ ///
+ /// The table will be composed of one record batch per stripe.
+ ///
+ /// \return the returned Table
+ Result<std::shared_ptr<Table>> Read();
+
+ /// \brief Read the file as a Table
+ ///
+ /// The table will be composed of one record batch per stripe.
+ ///
+ /// \param[in] schema the Table schema
+ /// \param[out] out the returned Table
+ ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
+ Status Read(const std::shared_ptr<Schema>& schema, std::shared_ptr<Table>* out);
+
+ /// \brief Read the file as a Table
+ ///
+ /// The table will be composed of one record batch per stripe.
+ ///
+ /// \param[in] schema the Table schema
+ /// \return the returned Table
+ Result<std::shared_ptr<Table>> Read(const std::shared_ptr<Schema>& schema);
+
+ /// \brief Read the file as a Table
+ ///
+ /// The table will be composed of one record batch per stripe.
+ ///
+ /// \param[in] include_indices the selected field indices to read
+ /// \param[out] out the returned Table
+ ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
+ Status Read(const std::vector<int>& include_indices, std::shared_ptr<Table>* out);
+
+ /// \brief Read the file as a Table
+ ///
+ /// The table will be composed of one record batch per stripe.
+ ///
+ /// \param[in] include_indices the selected field indices to read
+ /// \return the returned Table
+ Result<std::shared_ptr<Table>> Read(const std::vector<int>& include_indices);
+
+ /// \brief Read the file as a Table
+ ///
+ /// The table will be composed of one record batch per stripe.
+ ///
+ /// \param[in] include_names the selected field names to read
+ /// \return the returned Table
+ Result<std::shared_ptr<Table>> Read(const std::vector<std::string>& include_names);
+
+ /// \brief Read the file as a Table
+ ///
+ /// The table will be composed of one record batch per stripe.
+ ///
+ /// \param[in] schema the Table schema
+ /// \param[in] include_indices the selected field indices to read
+ /// \param[out] out the returned Table
+ ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
+ Status Read(const std::shared_ptr<Schema>& schema,
+ const std::vector<int>& include_indices, std::shared_ptr<Table>* out);
+
+ /// \brief Read the file as a Table
+ ///
+ /// The table will be composed of one record batch per stripe.
+ ///
+ /// \param[in] schema the Table schema
+ /// \param[in] include_indices the selected field indices to read
+ /// \return the returned Table
+ Result<std::shared_ptr<Table>> Read(const std::shared_ptr<Schema>& schema,
+ const std::vector<int>& include_indices);
+
+ /// \brief Read a single stripe as a RecordBatch
+ ///
+ /// \param[in] stripe the stripe index
+ /// \param[out] out the returned RecordBatch
+ ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
+ Status ReadStripe(int64_t stripe, std::shared_ptr<RecordBatch>* out);
+
+ /// \brief Read a single stripe as a RecordBatch
+ ///
+ /// \param[in] stripe the stripe index
+ /// \return the returned RecordBatch
+ Result<std::shared_ptr<RecordBatch>> ReadStripe(int64_t stripe);
+
+ /// \brief Read a single stripe as a RecordBatch
+ ///
+ /// \param[in] stripe the stripe index
+ /// \param[in] include_indices the selected field indices to read
+ /// \param[out] out the returned RecordBatch
+ ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
+ Status ReadStripe(int64_t stripe, const std::vector<int>& include_indices,
+ std::shared_ptr<RecordBatch>* out);
+
+ /// \brief Read a single stripe as a RecordBatch
+ ///
+ /// \param[in] stripe the stripe index
+ /// \param[in] include_indices the selected field indices to read
+ /// \return the returned RecordBatch
+ Result<std::shared_ptr<RecordBatch>> ReadStripe(
+ int64_t stripe, const std::vector<int>& include_indices);
+
+ /// \brief Read a single stripe as a RecordBatch
+ ///
+ /// \param[in] stripe the stripe index
+ /// \param[in] include_names the selected field names to read
+ /// \return the returned RecordBatch
+ Result<std::shared_ptr<RecordBatch>> ReadStripe(
+ int64_t stripe, const std::vector<std::string>& include_names);
+
+ /// \brief Seek to designated row. Invoke NextStripeReader() after seek
+ /// will return stripe reader starting from designated row.
+ ///
+ /// \param[in] row_number the rows number to seek
+ Status Seek(int64_t row_number);
+
+ /// \brief Get a stripe level record batch iterator with specified row count
+ /// in each record batch. NextStripeReader serves as a fine grain
+ /// alternative to ReadStripe which may cause OOM issue by loading
+ /// the whole stripes into memory.
+ ///
+ /// \param[in] batch_size the number of rows each record batch contains in
+ /// record batch iteration.
+ /// \param[out] out the returned stripe reader
+ ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
+ Status NextStripeReader(int64_t batch_size, std::shared_ptr<RecordBatchReader>* out);
+
+ /// \brief Get a stripe level record batch iterator with specified row count
+ /// in each record batch. NextStripeReader serves as a fine grain
+ /// alternative to ReadStripe which may cause OOM issue by loading
+ /// the whole stripes into memory.
+ ///
+ /// \param[in] batch_size the number of rows each record batch contains in
+ /// record batch iteration.
+ /// \return the returned stripe reader
+ Result<std::shared_ptr<RecordBatchReader>> NextStripeReader(int64_t batch_size);
+
+ /// \brief Get a stripe level record batch iterator with specified row count
+ /// in each record batch. NextStripeReader serves as a fine grain
+ /// alternative to ReadStripe which may cause OOM issue by loading
+ /// the whole stripes into memory.
+ ///
+ /// \param[in] batch_size Get a stripe level record batch iterator with specified row
+ /// count in each record batch.
+ ///
+ /// \param[in] include_indices the selected field indices to read
+ /// \param[out] out the returned stripe reader
+ ARROW_DEPRECATED("Deprecated in 6.0.0. Use Result-returning overload instead.")
+ Status NextStripeReader(int64_t batch_size, const std::vector<int>& include_indices,
+ std::shared_ptr<RecordBatchReader>* out);
+
+ /// \brief Get a stripe level record batch iterator with specified row count
+ /// in each record batch. NextStripeReader serves as a fine grain
+ /// alternative to ReadStripe which may cause OOM issue by loading
+ /// the whole stripes into memory.
+ ///
+ /// \param[in] batch_size Get a stripe level record batch iterator with specified row
+ /// count in each record batch.
+ ///
+ /// \param[in] include_indices the selected field indices to read
+ /// \return the returned stripe reader
+ Result<std::shared_ptr<RecordBatchReader>> NextStripeReader(
+ int64_t batch_size, const std::vector<int>& include_indices);
+
+ /// \brief The number of stripes in the file
+ int64_t NumberOfStripes();
+
+ /// \brief The number of rows in the file
+ int64_t NumberOfRows();
+
+ private:
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+ ORCFileReader();
+};
+
+/// \class ORCFileWriter
+/// \brief Write an Arrow Table or RecordBatch to an ORC file.
+class ARROW_EXPORT ORCFileWriter {
+ public:
+ ~ORCFileWriter();
+ /// \brief Creates a new ORC writer.
+ ///
+ /// \param[in] output_stream a pointer to the io::OutputStream to write into
+ /// \return the returned writer object
+ static Result<std::unique_ptr<ORCFileWriter>> Open(io::OutputStream* output_stream);
+
+ /// \brief Write a table
+ ///
+ /// \param[in] table the Arrow table from which data is extracted
+ /// \return Status
+ Status Write(const Table& table);
+
+ /// \brief Close an ORC writer (orc::Writer)
+ ///
+ /// \return Status
+ Status Close();
+
+ private:
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+
+ private:
+ ORCFileWriter();
+};
+
+} // namespace orc
+} // namespace adapters
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/adapters/orc/adapter_test.cc b/src/arrow/cpp/src/arrow/adapters/orc/adapter_test.cc
new file mode 100644
index 000000000..39c66b90f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/adapters/orc/adapter_test.cc
@@ -0,0 +1,686 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/adapters/orc/adapter.h"
+
+#include <gtest/gtest.h>
+
+#include <orc/OrcFile.hh>
+#include <string>
+
+#include "arrow/adapters/orc/adapter_util.h"
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/chunked_array.h"
+#include "arrow/compute/cast.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/io/memory.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/type.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace liborc = orc;
+
+namespace arrow {
+
+using internal::checked_pointer_cast;
+
+constexpr int kDefaultSmallMemStreamSize = 16384 * 5; // 80KB
+constexpr int kDefaultMemStreamSize = 10 * 1024 * 1024;
+constexpr int64_t kNanoMax = std::numeric_limits<int64_t>::max();
+constexpr int64_t kNanoMin = std::numeric_limits<int64_t>::lowest();
+const int64_t kMicroMax = std::floor(kNanoMax / 1000);
+const int64_t kMicroMin = std::ceil(kNanoMin / 1000);
+const int64_t kMilliMax = std::floor(kMicroMax / 1000);
+const int64_t kMilliMin = std::ceil(kMicroMin / 1000);
+const int64_t kSecondMax = std::floor(kMilliMax / 1000);
+const int64_t kSecondMin = std::ceil(kMilliMin / 1000);
+
+static constexpr random::SeedType kRandomSeed = 0x0ff1ce;
+
+class MemoryOutputStream : public liborc::OutputStream {
+ public:
+ explicit MemoryOutputStream(ssize_t capacity)
+ : data_(capacity), name_("MemoryOutputStream"), length_(0) {}
+
+ uint64_t getLength() const override { return length_; }
+
+ uint64_t getNaturalWriteSize() const override { return natural_write_size_; }
+
+ void write(const void* buf, size_t size) override {
+ memcpy(data_.data() + length_, buf, size);
+ length_ += size;
+ }
+
+ const std::string& getName() const override { return name_; }
+
+ const char* getData() const { return data_.data(); }
+
+ void close() override {}
+
+ void reset() { length_ = 0; }
+
+ private:
+ std::vector<char> data_;
+ std::string name_;
+ uint64_t length_, natural_write_size_;
+};
+
+std::shared_ptr<Buffer> GenerateFixedDifferenceBuffer(int32_t fixed_length,
+ int64_t length) {
+ BufferBuilder builder;
+ int32_t offsets[length];
+ ARROW_EXPECT_OK(builder.Resize(4 * length));
+ for (int32_t i = 0; i < length; i++) {
+ offsets[i] = fixed_length * i;
+ }
+ ARROW_EXPECT_OK(builder.Append(offsets, 4 * length));
+ std::shared_ptr<Buffer> buffer;
+ ARROW_EXPECT_OK(builder.Finish(&buffer));
+ return buffer;
+}
+
+std::shared_ptr<Array> CastFixedSizeBinaryArrayToBinaryArray(
+ std::shared_ptr<Array> array) {
+ auto fixed_size_binary_array = checked_pointer_cast<FixedSizeBinaryArray>(array);
+ std::shared_ptr<Buffer> value_offsets = GenerateFixedDifferenceBuffer(
+ fixed_size_binary_array->byte_width(), array->length() + 1);
+ return std::make_shared<BinaryArray>(array->length(), value_offsets,
+ array->data()->buffers[1],
+ array->data()->buffers[0]);
+}
+
+template <typename TargetArrayType>
+std::shared_ptr<Array> CastInt64ArrayToTemporalArray(
+ const std::shared_ptr<DataType>& type, std::shared_ptr<Array> array) {
+ std::shared_ptr<ArrayData> new_array_data =
+ ArrayData::Make(type, array->length(), array->data()->buffers);
+ return std::make_shared<TargetArrayType>(new_array_data);
+}
+
+Result<std::shared_ptr<Array>> GenerateRandomDate64Array(int64_t size,
+ double null_probability) {
+ arrow::random::RandomArrayGenerator rand(kRandomSeed);
+ return CastInt64ArrayToTemporalArray<Date64Array>(
+ date64(), rand.Int64(size, kMilliMin, kMilliMax, null_probability));
+}
+
+Result<std::shared_ptr<Array>> GenerateRandomTimestampArray(int64_t size,
+ arrow::TimeUnit::type type,
+ double null_probability) {
+ arrow::random::RandomArrayGenerator rand(kRandomSeed);
+ switch (type) {
+ case arrow::TimeUnit::type::SECOND: {
+ return CastInt64ArrayToTemporalArray<TimestampArray>(
+ timestamp(TimeUnit::SECOND),
+ rand.Int64(size, kSecondMin, kSecondMax, null_probability));
+ }
+ case arrow::TimeUnit::type::MILLI: {
+ return CastInt64ArrayToTemporalArray<TimestampArray>(
+ timestamp(TimeUnit::MILLI),
+ rand.Int64(size, kMilliMin, kMilliMax, null_probability));
+ }
+ case arrow::TimeUnit::type::MICRO: {
+ return CastInt64ArrayToTemporalArray<TimestampArray>(
+ timestamp(TimeUnit::MICRO),
+ rand.Int64(size, kMicroMin, kMicroMax, null_probability));
+ }
+ case arrow::TimeUnit::type::NANO: {
+ return CastInt64ArrayToTemporalArray<TimestampArray>(
+ timestamp(TimeUnit::NANO),
+ rand.Int64(size, kNanoMin, kNanoMax, null_probability));
+ }
+ default: {
+ return arrow::Status::TypeError("Unknown or unsupported Arrow TimeUnit: ", type);
+ }
+ }
+}
+
+/// \brief Construct a random weak composition of a nonnegative integer
+/// i.e. a way of writing it as the sum of a sequence of n non-negative
+/// integers.
+///
+/// \param[in] n the number of integers in the weak composition
+/// \param[in] sum the integer of which a random weak composition is generated
+/// \param[out] out The generated weak composition
+template <typename T, typename U>
+void RandWeakComposition(int64_t n, T sum, std::vector<U>* out) {
+ const int random_seed = 0;
+ std::default_random_engine gen(random_seed);
+ out->resize(n, static_cast<T>(0));
+ T remaining_sum = sum;
+ std::generate(out->begin(), out->end() - 1, [&gen, &remaining_sum] {
+ std::uniform_int_distribution<T> d(static_cast<T>(0), remaining_sum);
+ auto res = d(gen);
+ remaining_sum -= res;
+ return static_cast<U>(res);
+ });
+ (*out)[n - 1] += remaining_sum;
+ std::random_shuffle(out->begin(), out->end());
+}
+
+std::shared_ptr<ChunkedArray> GenerateRandomChunkedArray(
+ const std::shared_ptr<DataType>& data_type, int64_t size, int64_t min_num_chunks,
+ int64_t max_num_chunks, double null_probability) {
+ arrow::random::RandomArrayGenerator rand(kRandomSeed);
+ std::vector<int64_t> num_chunks(1, 0);
+ std::vector<int64_t> current_size_chunks;
+ arrow::randint<int64_t, int64_t>(1, min_num_chunks, max_num_chunks, &num_chunks);
+ int64_t current_num_chunks = num_chunks[0];
+ ArrayVector arrays(current_num_chunks, nullptr);
+ arrow::RandWeakComposition(current_num_chunks, size, &current_size_chunks);
+ for (int j = 0; j < current_num_chunks; j++) {
+ switch (data_type->id()) {
+ case arrow::Type::type::DATE64: {
+ EXPECT_OK_AND_ASSIGN(arrays[j], GenerateRandomDate64Array(current_size_chunks[j],
+ null_probability));
+ break;
+ }
+ case arrow::Type::type::TIMESTAMP: {
+ EXPECT_OK_AND_ASSIGN(
+ arrays[j],
+ GenerateRandomTimestampArray(
+ current_size_chunks[j],
+ arrow::internal::checked_pointer_cast<arrow::TimestampType>(data_type)
+ ->unit(),
+ null_probability));
+ break;
+ }
+ default:
+ arrays[j] = rand.ArrayOf(data_type, current_size_chunks[j], null_probability);
+ }
+ }
+ return std::make_shared<ChunkedArray>(arrays);
+}
+
+std::shared_ptr<Table> GenerateRandomTable(const std::shared_ptr<Schema>& schema,
+ int64_t size, int64_t min_num_chunks,
+ int64_t max_num_chunks,
+ double null_probability) {
+ int num_cols = schema->num_fields();
+ ChunkedArrayVector cv;
+ for (int col = 0; col < num_cols; col++) {
+ cv.push_back(GenerateRandomChunkedArray(schema->field(col)->type(), size,
+ min_num_chunks, max_num_chunks,
+ null_probability));
+ }
+ return Table::Make(schema, cv);
+}
+
+void AssertTableWriteReadEqual(const std::shared_ptr<Table>& input_table,
+ const std::shared_ptr<Table>& expected_output_table,
+ const int64_t max_size = kDefaultSmallMemStreamSize) {
+ EXPECT_OK_AND_ASSIGN(auto buffer_output_stream,
+ io::BufferOutputStream::Create(max_size));
+ EXPECT_OK_AND_ASSIGN(auto writer,
+ adapters::orc::ORCFileWriter::Open(buffer_output_stream.get()));
+ ARROW_EXPECT_OK(writer->Write(*input_table));
+ ARROW_EXPECT_OK(writer->Close());
+ EXPECT_OK_AND_ASSIGN(auto buffer, buffer_output_stream->Finish());
+ std::shared_ptr<io::RandomAccessFile> in_stream(new io::BufferReader(buffer));
+ EXPECT_OK_AND_ASSIGN(
+ auto reader, adapters::orc::ORCFileReader::Open(in_stream, default_memory_pool()));
+ EXPECT_OK_AND_ASSIGN(auto actual_output_table, reader->Read());
+ AssertTablesEqual(*expected_output_table, *actual_output_table, false, false);
+}
+
+void AssertArrayWriteReadEqual(const std::shared_ptr<Array>& input_array,
+ const std::shared_ptr<Array>& expected_output_array,
+ const int64_t max_size = kDefaultSmallMemStreamSize) {
+ std::shared_ptr<Schema> input_schema = schema({field("col0", input_array->type())}),
+ output_schema =
+ schema({field("col0", expected_output_array->type())});
+ auto input_chunked_array = std::make_shared<ChunkedArray>(input_array),
+ expected_output_chunked_array =
+ std::make_shared<ChunkedArray>(expected_output_array);
+ std::shared_ptr<Table> input_table = Table::Make(input_schema, {input_chunked_array}),
+ expected_output_table =
+ Table::Make(output_schema, {expected_output_chunked_array});
+ AssertTableWriteReadEqual(input_table, expected_output_table, max_size);
+}
+
+void SchemaORCWriteReadTest(const std::shared_ptr<Schema>& schema, int64_t size,
+ int64_t min_num_chunks, int64_t max_num_chunks,
+ double null_probability,
+ int64_t max_size = kDefaultSmallMemStreamSize) {
+ const std::shared_ptr<Table> table =
+ GenerateRandomTable(schema, size, min_num_chunks, max_num_chunks, null_probability);
+ AssertTableWriteReadEqual(table, table, max_size);
+}
+
+std::unique_ptr<liborc::Writer> CreateWriter(uint64_t stripe_size,
+ const liborc::Type& type,
+ liborc::OutputStream* stream) {
+ liborc::WriterOptions options;
+ options.setStripeSize(stripe_size);
+ options.setCompressionBlockSize(1024);
+ options.setMemoryPool(liborc::getDefaultPool());
+ options.setRowIndexStride(0);
+ return liborc::createWriter(type, stream, options);
+}
+
+TEST(TestAdapterRead, ReadIntAndStringFileMultipleStripes) {
+ MemoryOutputStream mem_stream(kDefaultMemStreamSize);
+ ORC_UNIQUE_PTR<liborc::Type> type(
+ liborc::Type::buildTypeFromString("struct<col1:int,col2:string>"));
+
+ constexpr uint64_t stripe_size = 1024; // 1K
+ constexpr uint64_t stripe_count = 10;
+ constexpr uint64_t stripe_row_count = 16384;
+ constexpr uint64_t reader_batch_size = 1024;
+
+ auto writer = CreateWriter(stripe_size, *type, &mem_stream);
+ auto batch = writer->createRowBatch(stripe_row_count);
+ auto struct_batch = internal::checked_cast<liborc::StructVectorBatch*>(batch.get());
+ auto long_batch =
+ internal::checked_cast<liborc::LongVectorBatch*>(struct_batch->fields[0]);
+ auto str_batch =
+ internal::checked_cast<liborc::StringVectorBatch*>(struct_batch->fields[1]);
+ int64_t accumulated = 0;
+
+ for (uint64_t j = 0; j < stripe_count; ++j) {
+ std::string data_buffer(stripe_row_count * 5, '\0');
+ uint64_t offset = 0;
+ for (uint64_t i = 0; i < stripe_row_count; ++i) {
+ std::string str_data = std::to_string(accumulated % stripe_row_count);
+ long_batch->data[i] = static_cast<int64_t>(accumulated % stripe_row_count);
+ str_batch->data[i] = &data_buffer[offset];
+ str_batch->length[i] = static_cast<int64_t>(str_data.size());
+ memcpy(&data_buffer[offset], str_data.c_str(), str_data.size());
+ accumulated++;
+ offset += str_data.size();
+ }
+ struct_batch->numElements = stripe_row_count;
+ long_batch->numElements = stripe_row_count;
+ str_batch->numElements = stripe_row_count;
+
+ writer->add(*batch);
+ }
+
+ writer->close();
+
+ std::shared_ptr<io::RandomAccessFile> in_stream(new io::BufferReader(
+ std::make_shared<Buffer>(reinterpret_cast<const uint8_t*>(mem_stream.getData()),
+ static_cast<int64_t>(mem_stream.getLength()))));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto reader, adapters::orc::ORCFileReader::Open(in_stream, default_memory_pool()));
+
+ EXPECT_OK_AND_ASSIGN(auto metadata, reader->ReadMetadata());
+ auto expected_metadata = std::const_pointer_cast<const KeyValueMetadata>(
+ key_value_metadata(std::vector<std::string>(), std::vector<std::string>()));
+ ASSERT_TRUE(metadata->Equals(*expected_metadata));
+ ASSERT_EQ(stripe_row_count * stripe_count, reader->NumberOfRows());
+ ASSERT_EQ(stripe_count, reader->NumberOfStripes());
+ accumulated = 0;
+ EXPECT_OK_AND_ASSIGN(auto stripe_reader, reader->NextStripeReader(reader_batch_size));
+ while (stripe_reader) {
+ std::shared_ptr<RecordBatch> record_batch;
+ EXPECT_TRUE(stripe_reader->ReadNext(&record_batch).ok());
+ while (record_batch) {
+ auto int32_array = checked_pointer_cast<Int32Array>(record_batch->column(0));
+ auto str_array = checked_pointer_cast<StringArray>(record_batch->column(1));
+ for (int j = 0; j < record_batch->num_rows(); ++j) {
+ EXPECT_EQ(accumulated % stripe_row_count, int32_array->Value(j));
+ EXPECT_EQ(std::to_string(accumulated % stripe_row_count),
+ str_array->GetString(j));
+ accumulated++;
+ }
+ EXPECT_TRUE(stripe_reader->ReadNext(&record_batch).ok());
+ }
+ EXPECT_OK_AND_ASSIGN(stripe_reader, reader->NextStripeReader(reader_batch_size));
+ }
+
+ // test seek operation
+ int64_t start_offset = 830;
+ EXPECT_TRUE(reader->Seek(stripe_row_count + start_offset).ok());
+
+ EXPECT_OK_AND_ASSIGN(stripe_reader, reader->NextStripeReader(reader_batch_size));
+ std::shared_ptr<RecordBatch> record_batch;
+ EXPECT_TRUE(stripe_reader->ReadNext(&record_batch).ok());
+ while (record_batch) {
+ auto int32_array = std::dynamic_pointer_cast<Int32Array>(record_batch->column(0));
+ auto str_array = std::dynamic_pointer_cast<StringArray>(record_batch->column(1));
+ for (int j = 0; j < record_batch->num_rows(); ++j) {
+ std::ostringstream os;
+ os << start_offset % stripe_row_count;
+ EXPECT_EQ(start_offset % stripe_row_count, int32_array->Value(j));
+ EXPECT_EQ(os.str(), str_array->GetString(j));
+ start_offset++;
+ }
+ EXPECT_TRUE(stripe_reader->ReadNext(&record_batch).ok());
+ }
+}
+
+// WriteORC tests
+// Trivial
+
+class TestORCWriterTrivialNoConversion : public ::testing::Test {
+ public:
+ TestORCWriterTrivialNoConversion() {
+ table_schema = schema(
+ {field("bool", boolean()), field("int8", int8()), field("int16", int16()),
+ field("int32", int32()), field("int64", int64()), field("float", float32()),
+ field("double", float64()), field("decimal128nz", decimal128(25, 6)),
+ field("decimal128z", decimal128(32, 0)), field("date32", date32()),
+ field("ts3", timestamp(TimeUnit::NANO)), field("string", utf8()),
+ field("binary", binary()),
+ field("struct", struct_({field("a", utf8()), field("b", int64())})),
+ field("list", list(int32())),
+ field("lsl", list(struct_({field("lsl0", list(int32()))}))),
+ field("map", map(utf8(), utf8()))});
+ }
+
+ protected:
+ std::shared_ptr<Schema> table_schema;
+};
+TEST_F(TestORCWriterTrivialNoConversion, writeTrivialChunk) {
+ std::shared_ptr<Table> table = TableFromJSON(table_schema, {R"([])"});
+ AssertTableWriteReadEqual(table, table, kDefaultSmallMemStreamSize / 16);
+}
+TEST_F(TestORCWriterTrivialNoConversion, writeChunkless) {
+ std::shared_ptr<Table> table = TableFromJSON(table_schema, {});
+ AssertTableWriteReadEqual(table, table, kDefaultSmallMemStreamSize / 16);
+}
+class TestORCWriterTrivialWithConversion : public ::testing::Test {
+ public:
+ TestORCWriterTrivialWithConversion() {
+ input_schema = schema(
+ {field("date64", date64()), field("ts0", timestamp(TimeUnit::SECOND)),
+ field("ts1", timestamp(TimeUnit::MILLI)),
+ field("ts2", timestamp(TimeUnit::MICRO)), field("large_string", large_utf8()),
+ field("large_binary", large_binary()),
+ field("fixed_size_binary0", fixed_size_binary(0)),
+ field("fixed_size_binary", fixed_size_binary(5)),
+ field("large_list", large_list(int32())),
+ field("fixed_size_list", fixed_size_list(int32(), 3))}),
+ output_schema = schema(
+ {field("date64", timestamp(TimeUnit::NANO)),
+ field("ts0", timestamp(TimeUnit::NANO)), field("ts1", timestamp(TimeUnit::NANO)),
+ field("ts2", timestamp(TimeUnit::NANO)), field("large_string", utf8()),
+ field("large_binary", binary()), field("fixed_size_binary0", binary()),
+ field("fixed_size_binary", binary()), field("large_list", list(int32())),
+ field("fixed_size_list", list(int32()))});
+ }
+
+ protected:
+ std::shared_ptr<Schema> input_schema, output_schema;
+};
+TEST_F(TestORCWriterTrivialWithConversion, writeTrivialChunk) {
+ std::shared_ptr<Table> input_table = TableFromJSON(input_schema, {R"([])"}),
+ expected_output_table = TableFromJSON(output_schema, {R"([])"});
+ AssertTableWriteReadEqual(input_table, expected_output_table,
+ kDefaultSmallMemStreamSize / 16);
+}
+TEST_F(TestORCWriterTrivialWithConversion, writeChunkless) {
+ std::shared_ptr<Table> input_table = TableFromJSON(input_schema, {}),
+ expected_output_table = TableFromJSON(output_schema, {});
+ AssertTableWriteReadEqual(input_table, expected_output_table,
+ kDefaultSmallMemStreamSize / 16);
+}
+
+// General
+
+class TestORCWriterNoConversion : public ::testing::Test {
+ public:
+ TestORCWriterNoConversion() {
+ table_schema = schema(
+ {field("bool", boolean()), field("int8", int8()), field("int16", int16()),
+ field("int32", int32()), field("int64", int64()), field("float", float32()),
+ field("double", float64()), field("date32", date32()),
+ field("decimal64", decimal128(18, 4)), field("decimal64z", decimal128(18, 0)),
+ field("ts3", timestamp(TimeUnit::NANO)), field("string", utf8()),
+ field("binary", binary())});
+ }
+
+ protected:
+ std::shared_ptr<Schema> table_schema;
+};
+TEST_F(TestORCWriterNoConversion, writeNoNulls) {
+ SchemaORCWriteReadTest(table_schema, 11203, 5, 10, 0, kDefaultSmallMemStreamSize * 5);
+}
+TEST_F(TestORCWriterNoConversion, writeMixed) {
+ SchemaORCWriteReadTest(table_schema, 9405, 1, 20, 0.6, kDefaultSmallMemStreamSize * 5);
+}
+TEST_F(TestORCWriterNoConversion, writeAllNulls) {
+ SchemaORCWriteReadTest(table_schema, 4006, 1, 5, 1);
+}
+
+// Converts
+// Since Arrow has way more types than ORC type conversions are unavoidable
+class TestORCWriterWithConversion : public ::testing::Test {
+ public:
+ TestORCWriterWithConversion() {
+ input_schema = schema(
+ {field("date64", date64()), field("ts0", timestamp(TimeUnit::SECOND)),
+ field("ts1", timestamp(TimeUnit::MILLI)),
+ field("ts2", timestamp(TimeUnit::MICRO)), field("large_string", large_utf8()),
+ field("large_binary", large_binary()),
+ field("fixed_size_binary0", fixed_size_binary(0)),
+ field("fixed_size_binary", fixed_size_binary(5))}),
+ output_schema = schema(
+ {field("date64", timestamp(TimeUnit::NANO)),
+ field("ts0", timestamp(TimeUnit::NANO)), field("ts1", timestamp(TimeUnit::NANO)),
+ field("ts2", timestamp(TimeUnit::NANO)), field("large_string", utf8()),
+ field("large_binary", binary()), field("fixed_size_binary0", binary()),
+ field("fixed_size_binary", binary())});
+ }
+ void RunTest(int64_t num_rows, double null_possibility,
+ int64_t max_size = kDefaultSmallMemStreamSize) {
+ int64_t num_cols = (input_schema->fields()).size();
+ std::shared_ptr<Table> input_table =
+ GenerateRandomTable(input_schema, num_rows, 1, 1, null_possibility);
+ ArrayVector av(num_cols);
+ for (int i = 0; i < num_cols - 2; i++) {
+ EXPECT_OK_AND_ASSIGN(av[i],
+ arrow::compute::Cast(*(input_table->column(i)->chunk(0)),
+ output_schema->field(i)->type()));
+ }
+ for (int i = num_cols - 2; i < num_cols; i++) {
+ av[i] = CastFixedSizeBinaryArrayToBinaryArray(input_table->column(i)->chunk(0));
+ }
+ std::shared_ptr<Table> expected_output_table = Table::Make(output_schema, av);
+ AssertTableWriteReadEqual(input_table, expected_output_table, max_size);
+ }
+
+ protected:
+ std::shared_ptr<Schema> input_schema, output_schema;
+};
+TEST_F(TestORCWriterWithConversion, writeAllNulls) { RunTest(12000, 1); }
+TEST_F(TestORCWriterWithConversion, writeNoNulls) { RunTest(10009, 0); }
+TEST_F(TestORCWriterWithConversion, writeMixed) { RunTest(8021, 0.5); }
+
+class TestORCWriterSingleArray : public ::testing::Test {
+ public:
+ TestORCWriterSingleArray() : rand(kRandomSeed) {}
+
+ protected:
+ arrow::random::RandomArrayGenerator rand;
+};
+
+// Nested types
+TEST_F(TestORCWriterSingleArray, WriteStruct) {
+ std::vector<std::shared_ptr<Field>> subfields{field("int32", boolean())};
+ const int64_t num_rows = 1234;
+ int num_subcols = subfields.size();
+ ArrayVector av0(num_subcols);
+ for (int i = 0; i < num_subcols; i++) {
+ av0[i] = rand.ArrayOf(subfields[i]->type(), num_rows, 0.4);
+ }
+ std::shared_ptr<Buffer> bitmap = rand.NullBitmap(num_rows, 0.5);
+ std::shared_ptr<Array> array =
+ std::make_shared<StructArray>(struct_(subfields), num_rows, av0, bitmap);
+ AssertArrayWriteReadEqual(array, array, kDefaultSmallMemStreamSize * 10);
+}
+
+TEST_F(TestORCWriterSingleArray, WriteStructOfStruct) {
+ std::vector<std::shared_ptr<Field>> subsubfields{
+ field("bool", boolean()),
+ field("int8", int8()),
+ field("int16", int16()),
+ field("int32", int32()),
+ field("int64", int64()),
+ field("date32", date32()),
+ field("ts3", timestamp(TimeUnit::NANO)),
+ field("string", utf8()),
+ field("binary", binary())};
+ const int64_t num_rows = 1234;
+ int num_subsubcols = subsubfields.size();
+ ArrayVector av00(num_subsubcols), av0(1);
+ for (int i = 0; i < num_subsubcols; i++) {
+ av00[i] = rand.ArrayOf(subsubfields[i]->type(), num_rows, 0);
+ }
+ std::shared_ptr<Buffer> bitmap0 = rand.NullBitmap(num_rows, 0);
+ av0[0] = std::make_shared<StructArray>(struct_(subsubfields), num_rows, av00, bitmap0);
+ std::shared_ptr<Buffer> bitmap = rand.NullBitmap(num_rows, 0.2);
+ std::shared_ptr<Array> array = std::make_shared<StructArray>(
+ struct_({field("struct2", struct_(subsubfields))}), num_rows, av0, bitmap);
+ AssertArrayWriteReadEqual(array, array, kDefaultSmallMemStreamSize * 10);
+}
+
+TEST_F(TestORCWriterSingleArray, WriteList) {
+ const int64_t num_rows = 1234;
+ auto value_array = rand.ArrayOf(int32(), 125 * num_rows, 0);
+ std::shared_ptr<Array> array = rand.List(*value_array, num_rows, 1);
+ AssertArrayWriteReadEqual(array, array, kDefaultSmallMemStreamSize * 100);
+}
+
+TEST_F(TestORCWriterSingleArray, WriteLargeList) {
+ const int64_t num_rows = 1234;
+ auto value_array = rand.ArrayOf(int32(), 5 * num_rows, 0.5);
+ auto output_offsets = rand.Offsets(num_rows + 1, 0, 5 * num_rows, 0.6, false);
+ EXPECT_OK_AND_ASSIGN(auto input_offsets,
+ arrow::compute::Cast(*output_offsets, int64()));
+ EXPECT_OK_AND_ASSIGN(auto input_array,
+ arrow::LargeListArray::FromArrays(*input_offsets, *value_array));
+ EXPECT_OK_AND_ASSIGN(auto output_array,
+ arrow::ListArray::FromArrays(*output_offsets, *value_array));
+ AssertArrayWriteReadEqual(input_array, output_array, kDefaultSmallMemStreamSize * 10);
+}
+
+TEST_F(TestORCWriterSingleArray, WriteFixedSizeList) {
+ const int64_t num_rows = 1234;
+ std::shared_ptr<Array> value_array = rand.ArrayOf(int32(), 3 * num_rows, 0.8);
+ std::shared_ptr<Buffer> bitmap = rand.NullBitmap(num_rows, 1);
+ std::shared_ptr<Buffer> buffer = GenerateFixedDifferenceBuffer(3, num_rows + 1);
+ std::shared_ptr<Array> input_array = std::make_shared<FixedSizeListArray>(
+ fixed_size_list(int32(), 3), num_rows, value_array, bitmap),
+ output_array = std::make_shared<ListArray>(
+ list(int32()), num_rows, buffer, value_array, bitmap);
+ AssertArrayWriteReadEqual(input_array, output_array, kDefaultSmallMemStreamSize * 10);
+}
+
+TEST_F(TestORCWriterSingleArray, WriteListOfList) {
+ const int64_t num_rows = 1234;
+ auto value_value_array = rand.ArrayOf(utf8(), 4 * num_rows, 0.5);
+ std::shared_ptr<Array> value_array = rand.List(*value_value_array, 2 * num_rows, 0.7);
+ std::shared_ptr<Array> array = rand.List(*value_array, num_rows, 0.4);
+ AssertArrayWriteReadEqual(array, array, kDefaultSmallMemStreamSize * 10);
+}
+
+TEST_F(TestORCWriterSingleArray, WriteListOfListOfList) {
+ const int64_t num_rows = 1234;
+ auto value3_array = rand.ArrayOf(int64(), 12 * num_rows, 0.1);
+ std::shared_ptr<Array> value2_array = rand.List(*value3_array, 5 * num_rows, 0);
+ std::shared_ptr<Array> value_array = rand.List(*value2_array, 2 * num_rows, 0.1);
+ std::shared_ptr<Array> array = rand.List(*value_array, num_rows, 0.1);
+ AssertArrayWriteReadEqual(array, array, kDefaultSmallMemStreamSize * 35);
+}
+
+TEST_F(TestORCWriterSingleArray, WriteListOfStruct) {
+ const int64_t num_rows = 1234, num_values = 3 * num_rows;
+ ArrayVector av00(1);
+ av00[0] = rand.ArrayOf(int32(), num_values, 0);
+ std::shared_ptr<Buffer> bitmap = rand.NullBitmap(num_values, 0.2);
+ std::shared_ptr<Array> value_array = std::make_shared<StructArray>(
+ struct_({field("a", int32())}), num_values, av00, bitmap);
+ std::shared_ptr<Array> array = rand.List(*value_array, num_rows, 0);
+ AssertArrayWriteReadEqual(array, array, kDefaultSmallMemStreamSize * 30);
+}
+
+TEST_F(TestORCWriterSingleArray, WriteStructOfList) {
+ const int64_t num_rows = 1234;
+ ArrayVector av0(1);
+ auto value_array = rand.ArrayOf(int32(), 5 * num_rows, 0.2);
+ av0[0] = rand.List(*value_array, num_rows, 0);
+ std::shared_ptr<Buffer> bitmap = rand.NullBitmap(num_rows, 0.2);
+ std::shared_ptr<Array> array = std::make_shared<StructArray>(
+ struct_({field("a", list(int32()))}), num_rows, av0, bitmap);
+ AssertArrayWriteReadEqual(array, array, kDefaultSmallMemStreamSize * 20);
+}
+
+TEST_F(TestORCWriterSingleArray, WriteMap) {
+ const int64_t num_rows = 1234;
+ auto key_array = rand.ArrayOf(int32(), 20 * num_rows, 0);
+ auto item_array = rand.ArrayOf(int32(), 20 * num_rows, 1);
+ std::shared_ptr<Array> array = rand.Map(key_array, item_array, num_rows, 0.1);
+ AssertArrayWriteReadEqual(array, array, kDefaultSmallMemStreamSize * 50);
+}
+
+TEST_F(TestORCWriterSingleArray, WriteStructOfMap) {
+ const int64_t num_rows = 1234, num_values = 5 * num_rows;
+ ArrayVector av0(1);
+ auto key_array = rand.ArrayOf(binary(), num_values, 0);
+ auto item_array = rand.ArrayOf(int32(), num_values, 0.5);
+ av0[0] = rand.Map(key_array, item_array, num_rows, 0.2);
+ std::shared_ptr<Array> array = std::make_shared<StructArray>(
+ struct_({field("a", map(binary(), int32()))}), num_rows, av0);
+ AssertArrayWriteReadEqual(array, array, kDefaultSmallMemStreamSize * 20);
+}
+
+TEST_F(TestORCWriterSingleArray, WriteMapOfStruct) {
+ const int64_t num_rows = 1234, num_values = 10 * num_rows;
+ std::shared_ptr<Array> key_array = rand.ArrayOf(utf8(), num_values, 0);
+ ArrayVector av00(1);
+ av00[0] = rand.ArrayOf(int32(), num_values, 0.1);
+ std::shared_ptr<Buffer> bitmap = rand.NullBitmap(num_values, 0.2);
+ std::shared_ptr<Array> item_array = std::make_shared<StructArray>(
+ struct_({field("a", int32())}), num_values, av00, bitmap);
+ std::shared_ptr<Array> array = rand.Map(key_array, item_array, num_rows, 0.1);
+ AssertArrayWriteReadEqual(array, array, kDefaultSmallMemStreamSize * 10);
+}
+
+TEST_F(TestORCWriterSingleArray, WriteMapOfMap) {
+ const int64_t num_rows = 1234;
+ auto key_key_array = rand.ArrayOf(utf8(), 4 * num_rows, 0);
+ auto key_item_array = rand.ArrayOf(int32(), 4 * num_rows, 0.5);
+ std::shared_ptr<Array> key_array =
+ rand.Map(key_key_array, key_item_array, 2 * num_rows, 0);
+ auto item_key_array = rand.ArrayOf(utf8(), 4 * num_rows, 0);
+ auto item_item_array = rand.ArrayOf(int32(), 4 * num_rows, 0.2);
+ std::shared_ptr<Array> item_array =
+ rand.Map(item_key_array, item_item_array, 2 * num_rows, 0.3);
+ std::shared_ptr<Array> array = rand.Map(key_array, item_array, num_rows, 0.4);
+ AssertArrayWriteReadEqual(array, array, kDefaultSmallMemStreamSize * 10);
+}
+
+TEST_F(TestORCWriterSingleArray, WriteListOfMap) {
+ const int64_t num_rows = 1234;
+ auto value_key_array = rand.ArrayOf(utf8(), 4 * num_rows, 0);
+ auto value_item_array = rand.ArrayOf(int32(), 4 * num_rows, 0.5);
+ std::shared_ptr<Array> value_array =
+ rand.Map(value_key_array, value_item_array, 2 * num_rows, 0.2);
+ std::shared_ptr<Array> array = rand.List(*value_array, num_rows, 0.4);
+ AssertArrayWriteReadEqual(array, array, kDefaultSmallMemStreamSize * 10);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/adapters/orc/adapter_util.cc b/src/arrow/cpp/src/arrow/adapters/orc/adapter_util.cc
new file mode 100644
index 000000000..f956a6f62
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/adapters/orc/adapter_util.cc
@@ -0,0 +1,1069 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/adapters/orc/adapter_util.h"
+
+#include <cmath>
+#include <string>
+#include <vector>
+
+#include "arrow/array/builder_base.h"
+#include "arrow/builder.h"
+#include "arrow/chunked_array.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/range.h"
+#include "arrow/util/string_view.h"
+#include "arrow/visitor_inline.h"
+#include "orc/Exceptions.hh"
+#include "orc/MemoryPool.hh"
+#include "orc/OrcFile.hh"
+
+// alias to not interfere with nested orc namespace
+namespace liborc = orc;
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace adapters {
+namespace orc {
+
+namespace {
+
+// The number of milliseconds, microseconds and nanoseconds in a second
+constexpr int64_t kOneSecondMillis = 1000LL;
+constexpr int64_t kOneMicroNanos = 1000LL;
+constexpr int64_t kOneSecondMicros = 1000000LL;
+constexpr int64_t kOneMilliNanos = 1000000LL;
+constexpr int64_t kOneSecondNanos = 1000000000LL;
+
+Status AppendStructBatch(const liborc::Type* type,
+ liborc::ColumnVectorBatch* column_vector_batch, int64_t offset,
+ int64_t length, ArrayBuilder* abuilder) {
+ auto builder = checked_cast<StructBuilder*>(abuilder);
+ auto batch = checked_cast<liborc::StructVectorBatch*>(column_vector_batch);
+
+ const uint8_t* valid_bytes = nullptr;
+ if (batch->hasNulls) {
+ valid_bytes = reinterpret_cast<const uint8_t*>(batch->notNull.data()) + offset;
+ }
+ RETURN_NOT_OK(builder->AppendValues(length, valid_bytes));
+
+ for (int i = 0; i < builder->num_fields(); i++) {
+ RETURN_NOT_OK(AppendBatch(type->getSubtype(i), batch->fields[i], offset, length,
+ builder->field_builder(i)));
+ }
+ return Status::OK();
+}
+
+Status AppendListBatch(const liborc::Type* type,
+ liborc::ColumnVectorBatch* column_vector_batch, int64_t offset,
+ int64_t length, ArrayBuilder* abuilder) {
+ auto builder = checked_cast<ListBuilder*>(abuilder);
+ auto batch = checked_cast<liborc::ListVectorBatch*>(column_vector_batch);
+ liborc::ColumnVectorBatch* elements = batch->elements.get();
+ const liborc::Type* elemtype = type->getSubtype(0);
+
+ const bool has_nulls = batch->hasNulls;
+ for (int64_t i = offset; i < length + offset; i++) {
+ if (!has_nulls || batch->notNull[i]) {
+ int64_t start = batch->offsets[i];
+ int64_t end = batch->offsets[i + 1];
+ RETURN_NOT_OK(builder->Append());
+ RETURN_NOT_OK(
+ AppendBatch(elemtype, elements, start, end - start, builder->value_builder()));
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ }
+ }
+ return Status::OK();
+}
+
+Status AppendMapBatch(const liborc::Type* type,
+ liborc::ColumnVectorBatch* column_vector_batch, int64_t offset,
+ int64_t length, ArrayBuilder* abuilder) {
+ auto builder = checked_cast<MapBuilder*>(abuilder);
+ auto batch = checked_cast<liborc::MapVectorBatch*>(column_vector_batch);
+ liborc::ColumnVectorBatch* keys = batch->keys.get();
+ liborc::ColumnVectorBatch* items = batch->elements.get();
+ const liborc::Type* key_type = type->getSubtype(0);
+ const liborc::Type* item_type = type->getSubtype(1);
+
+ const bool has_nulls = batch->hasNulls;
+ for (int64_t i = offset; i < length + offset; i++) {
+ if (!has_nulls || batch->notNull[i]) {
+ int64_t start = batch->offsets[i];
+ int64_t end = batch->offsets[i + 1];
+ RETURN_NOT_OK(builder->Append());
+ RETURN_NOT_OK(
+ AppendBatch(key_type, keys, start, end - start, builder->key_builder()));
+ RETURN_NOT_OK(
+ AppendBatch(item_type, items, start, end - start, builder->item_builder()));
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ }
+ }
+ return Status::OK();
+}
+
+template <class BuilderType, class BatchType, class ElemType>
+Status AppendNumericBatch(liborc::ColumnVectorBatch* column_vector_batch, int64_t offset,
+ int64_t length, ArrayBuilder* abuilder) {
+ auto builder = checked_cast<BuilderType*>(abuilder);
+ auto batch = checked_cast<BatchType*>(column_vector_batch);
+
+ if (length == 0) {
+ return Status::OK();
+ }
+ const uint8_t* valid_bytes = nullptr;
+ if (batch->hasNulls) {
+ valid_bytes = reinterpret_cast<const uint8_t*>(batch->notNull.data()) + offset;
+ }
+ const ElemType* source = batch->data.data() + offset;
+ RETURN_NOT_OK(builder->AppendValues(source, length, valid_bytes));
+ return Status::OK();
+}
+
+template <class BuilderType, class TargetType, class BatchType, class SourceType>
+Status AppendNumericBatchCast(liborc::ColumnVectorBatch* column_vector_batch,
+ int64_t offset, int64_t length, ArrayBuilder* abuilder) {
+ auto builder = checked_cast<BuilderType*>(abuilder);
+ auto batch = checked_cast<BatchType*>(column_vector_batch);
+
+ if (length == 0) {
+ return Status::OK();
+ }
+
+ const uint8_t* valid_bytes = nullptr;
+ if (batch->hasNulls) {
+ valid_bytes = reinterpret_cast<const uint8_t*>(batch->notNull.data()) + offset;
+ }
+ const SourceType* source = batch->data.data() + offset;
+ auto cast_iter = internal::MakeLazyRange(
+ [&source](int64_t index) { return static_cast<TargetType>(source[index]); },
+ length);
+
+ RETURN_NOT_OK(builder->AppendValues(cast_iter.begin(), cast_iter.end(), valid_bytes));
+
+ return Status::OK();
+}
+
+Status AppendBoolBatch(liborc::ColumnVectorBatch* column_vector_batch, int64_t offset,
+ int64_t length, ArrayBuilder* abuilder) {
+ auto builder = checked_cast<BooleanBuilder*>(abuilder);
+ auto batch = checked_cast<liborc::LongVectorBatch*>(column_vector_batch);
+
+ if (length == 0) {
+ return Status::OK();
+ }
+
+ const uint8_t* valid_bytes = nullptr;
+ if (batch->hasNulls) {
+ valid_bytes = reinterpret_cast<const uint8_t*>(batch->notNull.data()) + offset;
+ }
+ const int64_t* source = batch->data.data() + offset;
+
+ auto cast_iter = internal::MakeLazyRange(
+ [&source](int64_t index) { return static_cast<bool>(source[index]); }, length);
+
+ RETURN_NOT_OK(builder->AppendValues(cast_iter.begin(), cast_iter.end(), valid_bytes));
+
+ return Status::OK();
+}
+
+Status AppendTimestampBatch(liborc::ColumnVectorBatch* column_vector_batch,
+ int64_t offset, int64_t length, ArrayBuilder* abuilder) {
+ auto builder = checked_cast<TimestampBuilder*>(abuilder);
+ auto batch = checked_cast<liborc::TimestampVectorBatch*>(column_vector_batch);
+
+ if (length == 0) {
+ return Status::OK();
+ }
+
+ const uint8_t* valid_bytes = nullptr;
+ if (batch->hasNulls) {
+ valid_bytes = reinterpret_cast<const uint8_t*>(batch->notNull.data()) + offset;
+ }
+
+ const int64_t* seconds = batch->data.data() + offset;
+ const int64_t* nanos = batch->nanoseconds.data() + offset;
+
+ auto transform_timestamp = [seconds, nanos](int64_t index) {
+ return seconds[index] * kOneSecondNanos + nanos[index];
+ };
+
+ auto transform_range = internal::MakeLazyRange(transform_timestamp, length);
+
+ RETURN_NOT_OK(
+ builder->AppendValues(transform_range.begin(), transform_range.end(), valid_bytes));
+ return Status::OK();
+}
+
+template <class BuilderType>
+Status AppendBinaryBatch(liborc::ColumnVectorBatch* column_vector_batch, int64_t offset,
+ int64_t length, ArrayBuilder* abuilder) {
+ auto builder = checked_cast<BuilderType*>(abuilder);
+ auto batch = checked_cast<liborc::StringVectorBatch*>(column_vector_batch);
+
+ const bool has_nulls = batch->hasNulls;
+ for (int64_t i = offset; i < length + offset; i++) {
+ if (!has_nulls || batch->notNull[i]) {
+ RETURN_NOT_OK(
+ builder->Append(batch->data[i], static_cast<int32_t>(batch->length[i])));
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ }
+ }
+ return Status::OK();
+}
+
+Status AppendFixedBinaryBatch(liborc::ColumnVectorBatch* column_vector_batch,
+ int64_t offset, int64_t length, ArrayBuilder* abuilder) {
+ auto builder = checked_cast<FixedSizeBinaryBuilder*>(abuilder);
+ auto batch = checked_cast<liborc::StringVectorBatch*>(column_vector_batch);
+
+ const bool has_nulls = batch->hasNulls;
+ for (int64_t i = offset; i < length + offset; i++) {
+ if (!has_nulls || batch->notNull[i]) {
+ RETURN_NOT_OK(builder->Append(batch->data[i]));
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ }
+ }
+ return Status::OK();
+}
+
+Status AppendDecimalBatch(const liborc::Type* type,
+ liborc::ColumnVectorBatch* column_vector_batch, int64_t offset,
+ int64_t length, ArrayBuilder* abuilder) {
+ auto builder = checked_cast<Decimal128Builder*>(abuilder);
+
+ const bool has_nulls = column_vector_batch->hasNulls;
+ if (type->getPrecision() == 0 || type->getPrecision() > 18) {
+ auto batch = checked_cast<liborc::Decimal128VectorBatch*>(column_vector_batch);
+ for (int64_t i = offset; i < length + offset; i++) {
+ if (!has_nulls || batch->notNull[i]) {
+ RETURN_NOT_OK(builder->Append(
+ Decimal128(batch->values[i].getHighBits(), batch->values[i].getLowBits())));
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ }
+ }
+ } else {
+ auto batch = checked_cast<liborc::Decimal64VectorBatch*>(column_vector_batch);
+ for (int64_t i = offset; i < length + offset; i++) {
+ if (!has_nulls || batch->notNull[i]) {
+ RETURN_NOT_OK(builder->Append(Decimal128(batch->values[i])));
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ }
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status AppendBatch(const liborc::Type* type, liborc::ColumnVectorBatch* batch,
+ int64_t offset, int64_t length, ArrayBuilder* builder) {
+ if (type == nullptr) {
+ return Status::OK();
+ }
+ liborc::TypeKind kind = type->getKind();
+ switch (kind) {
+ case liborc::STRUCT:
+ return AppendStructBatch(type, batch, offset, length, builder);
+ case liborc::LIST:
+ return AppendListBatch(type, batch, offset, length, builder);
+ case liborc::MAP:
+ return AppendMapBatch(type, batch, offset, length, builder);
+ case liborc::LONG:
+ return AppendNumericBatch<Int64Builder, liborc::LongVectorBatch, int64_t>(
+ batch, offset, length, builder);
+ case liborc::INT:
+ return AppendNumericBatchCast<Int32Builder, int32_t, liborc::LongVectorBatch,
+ int64_t>(batch, offset, length, builder);
+ case liborc::SHORT:
+ return AppendNumericBatchCast<Int16Builder, int16_t, liborc::LongVectorBatch,
+ int64_t>(batch, offset, length, builder);
+ case liborc::BYTE:
+ return AppendNumericBatchCast<Int8Builder, int8_t, liborc::LongVectorBatch,
+ int64_t>(batch, offset, length, builder);
+ case liborc::DOUBLE:
+ return AppendNumericBatch<DoubleBuilder, liborc::DoubleVectorBatch, double>(
+ batch, offset, length, builder);
+ case liborc::FLOAT:
+ return AppendNumericBatchCast<FloatBuilder, float, liborc::DoubleVectorBatch,
+ double>(batch, offset, length, builder);
+ case liborc::BOOLEAN:
+ return AppendBoolBatch(batch, offset, length, builder);
+ case liborc::VARCHAR:
+ case liborc::STRING:
+ return AppendBinaryBatch<StringBuilder>(batch, offset, length, builder);
+ case liborc::BINARY:
+ return AppendBinaryBatch<BinaryBuilder>(batch, offset, length, builder);
+ case liborc::CHAR:
+ return AppendFixedBinaryBatch(batch, offset, length, builder);
+ case liborc::DATE:
+ return AppendNumericBatchCast<Date32Builder, int32_t, liborc::LongVectorBatch,
+ int64_t>(batch, offset, length, builder);
+ case liborc::TIMESTAMP:
+ return AppendTimestampBatch(batch, offset, length, builder);
+ case liborc::DECIMAL:
+ return AppendDecimalBatch(type, batch, offset, length, builder);
+ default:
+ return Status::NotImplemented("Not implemented type kind: ", kind);
+ }
+}
+
+namespace {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+Status WriteBatch(const Array& parray, int64_t orc_offset,
+ liborc::ColumnVectorBatch* column_vector_batch);
+
+// Make sure children of StructArray have appropriate null.
+Result<std::shared_ptr<Array>> NormalizeArray(const std::shared_ptr<Array>& array) {
+ Type::type kind = array->type_id();
+ switch (kind) {
+ case Type::type::STRUCT: {
+ if (array->null_count() == 0) {
+ return array;
+ } else {
+ auto struct_array = checked_pointer_cast<StructArray>(array);
+ const std::shared_ptr<Buffer> bitmap = struct_array->null_bitmap();
+ std::shared_ptr<DataType> struct_type = struct_array->type();
+ std::size_t size = struct_type->fields().size();
+ std::vector<std::shared_ptr<Array>> new_children(size, nullptr);
+ for (std::size_t i = 0; i < size; i++) {
+ std::shared_ptr<Array> child = struct_array->field(i);
+ const std::shared_ptr<Buffer> child_bitmap = child->null_bitmap();
+ std::shared_ptr<Buffer> final_child_bitmap;
+ if (child_bitmap == nullptr) {
+ final_child_bitmap = bitmap;
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ final_child_bitmap,
+ internal::BitmapAnd(default_memory_pool(), bitmap->data(), 0,
+ child_bitmap->data(), 0, struct_array->length(), 0));
+ }
+ std::shared_ptr<ArrayData> child_array_data = child->data();
+ std::vector<std::shared_ptr<Buffer>> child_buffers = child_array_data->buffers;
+ child_buffers[0] = final_child_bitmap;
+ std::shared_ptr<ArrayData> new_child_array_data =
+ ArrayData::Make(child->type(), child->length(), child_buffers,
+ child_array_data->child_data, child_array_data->dictionary);
+ ARROW_ASSIGN_OR_RAISE(new_children[i],
+ NormalizeArray(MakeArray(new_child_array_data)));
+ }
+ return std::make_shared<StructArray>(struct_type, struct_array->length(),
+ new_children, bitmap);
+ }
+ }
+ case Type::type::LIST: {
+ auto list_array = checked_pointer_cast<ListArray>(array);
+ ARROW_ASSIGN_OR_RAISE(auto value_array, NormalizeArray(list_array->values()));
+ return std::make_shared<ListArray>(list_array->type(), list_array->length(),
+ list_array->value_offsets(), value_array,
+ list_array->null_bitmap());
+ }
+ case Type::type::LARGE_LIST: {
+ auto list_array = checked_pointer_cast<LargeListArray>(array);
+ ARROW_ASSIGN_OR_RAISE(auto value_array, NormalizeArray(list_array->values()));
+ return std::make_shared<LargeListArray>(list_array->type(), list_array->length(),
+ list_array->value_offsets(), value_array,
+ list_array->null_bitmap());
+ }
+ case Type::type::FIXED_SIZE_LIST: {
+ auto list_array = checked_pointer_cast<FixedSizeListArray>(array);
+ ARROW_ASSIGN_OR_RAISE(auto value_array, NormalizeArray(list_array->values()));
+ return std::make_shared<FixedSizeListArray>(list_array->type(),
+ list_array->length(), value_array,
+ list_array->null_bitmap());
+ }
+ case Type::type::MAP: {
+ auto map_array = checked_pointer_cast<MapArray>(array);
+ ARROW_ASSIGN_OR_RAISE(auto key_array, NormalizeArray(map_array->keys()));
+ ARROW_ASSIGN_OR_RAISE(auto item_array, NormalizeArray(map_array->items()));
+ return std::make_shared<MapArray>(map_array->type(), map_array->length(),
+ map_array->value_offsets(), key_array, item_array,
+ map_array->null_bitmap());
+ }
+ default: {
+ return array;
+ }
+ }
+}
+
+template <class DataType, class BatchType, typename Enable = void>
+struct Appender {};
+
+// Types for long/double-like Appender, that is, numeric, boolean or date32
+template <typename T>
+using is_generic_type =
+ std::integral_constant<bool, is_number_type<T>::value ||
+ std::is_same<Date32Type, T>::value ||
+ is_boolean_type<T>::value>;
+template <typename T, typename R = void>
+using enable_if_generic = enable_if_t<is_generic_type<T>::value, R>;
+
+// Number-like
+template <class DataType, class BatchType>
+struct Appender<DataType, BatchType, enable_if_generic<DataType>> {
+ using ArrayType = typename TypeTraits<DataType>::ArrayType;
+ using ValueType = typename TypeTraits<DataType>::CType;
+ Status VisitNull() {
+ batch->notNull[running_orc_offset] = false;
+ running_orc_offset++;
+ running_arrow_offset++;
+ return Status::OK();
+ }
+ Status VisitValue(ValueType v) {
+ batch->data[running_orc_offset] = array.Value(running_arrow_offset);
+ batch->notNull[running_orc_offset] = true;
+ running_orc_offset++;
+ running_arrow_offset++;
+ return Status::OK();
+ }
+ const ArrayType& array;
+ BatchType* batch;
+ int64_t running_orc_offset, running_arrow_offset;
+};
+
+// Binary
+template <class DataType>
+struct Appender<DataType, liborc::StringVectorBatch> {
+ using ArrayType = typename TypeTraits<DataType>::ArrayType;
+ using COffsetType = typename TypeTraits<DataType>::OffsetType::c_type;
+ Status VisitNull() {
+ batch->notNull[running_orc_offset] = false;
+ running_orc_offset++;
+ running_arrow_offset++;
+ return Status::OK();
+ }
+ Status VisitValue(util::string_view v) {
+ batch->notNull[running_orc_offset] = true;
+ COffsetType data_length = 0;
+ batch->data[running_orc_offset] = reinterpret_cast<char*>(
+ const_cast<uint8_t*>(array.GetValue(running_arrow_offset, &data_length)));
+ batch->length[running_orc_offset] = data_length;
+ running_orc_offset++;
+ running_arrow_offset++;
+ return Status::OK();
+ }
+ const ArrayType& array;
+ liborc::StringVectorBatch* batch;
+ int64_t running_orc_offset, running_arrow_offset;
+};
+
+// Decimal
+template <>
+struct Appender<Decimal128Type, liborc::Decimal64VectorBatch> {
+ Status VisitNull() {
+ batch->notNull[running_orc_offset] = false;
+ running_orc_offset++;
+ running_arrow_offset++;
+ return Status::OK();
+ }
+ Status VisitValue(util::string_view v) {
+ batch->notNull[running_orc_offset] = true;
+ const Decimal128 dec_value(array.GetValue(running_arrow_offset));
+ batch->values[running_orc_offset] = static_cast<int64_t>(dec_value.low_bits());
+ running_orc_offset++;
+ running_arrow_offset++;
+ return Status::OK();
+ }
+ const Decimal128Array& array;
+ liborc::Decimal64VectorBatch* batch;
+ int64_t running_orc_offset, running_arrow_offset;
+};
+
+template <>
+struct Appender<Decimal128Type, liborc::Decimal128VectorBatch> {
+ Status VisitNull() {
+ batch->notNull[running_orc_offset] = false;
+ running_orc_offset++;
+ running_arrow_offset++;
+ return Status::OK();
+ }
+ Status VisitValue(util::string_view v) {
+ batch->notNull[running_orc_offset] = true;
+ const Decimal128 dec_value(array.GetValue(running_arrow_offset));
+ batch->values[running_orc_offset] =
+ liborc::Int128(dec_value.high_bits(), dec_value.low_bits());
+ running_orc_offset++;
+ running_arrow_offset++;
+ return Status::OK();
+ }
+ const Decimal128Array& array;
+ liborc::Decimal128VectorBatch* batch;
+ int64_t running_orc_offset, running_arrow_offset;
+};
+
+// Date64 and Timestamp
+template <class DataType>
+struct TimestampAppender {
+ using ArrayType = typename TypeTraits<DataType>::ArrayType;
+ Status VisitNull() {
+ batch->notNull[running_orc_offset] = false;
+ running_orc_offset++;
+ running_arrow_offset++;
+ return Status::OK();
+ }
+ Status VisitValue(int64_t v) {
+ int64_t data = array.Value(running_arrow_offset);
+ batch->notNull[running_orc_offset] = true;
+ batch->data[running_orc_offset] =
+ static_cast<int64_t>(std::floor(data / conversion_factor_from_second));
+ batch->nanoseconds[running_orc_offset] =
+ (data - conversion_factor_from_second * batch->data[running_orc_offset]) *
+ conversion_factor_to_nano;
+ running_orc_offset++;
+ running_arrow_offset++;
+ return Status::OK();
+ }
+ const ArrayType& array;
+ liborc::TimestampVectorBatch* batch;
+ int64_t running_orc_offset, running_arrow_offset;
+ int64_t conversion_factor_from_second, conversion_factor_to_nano;
+};
+
+// FSB
+struct FixedSizeBinaryAppender {
+ Status VisitNull() {
+ batch->notNull[running_orc_offset] = false;
+ running_orc_offset++;
+ running_arrow_offset++;
+ return Status::OK();
+ }
+ Status VisitValue(util::string_view v) {
+ batch->notNull[running_orc_offset] = true;
+ batch->data[running_orc_offset] = reinterpret_cast<char*>(
+ const_cast<uint8_t*>(array.GetValue(running_arrow_offset)));
+ batch->length[running_orc_offset] = data_length;
+ running_orc_offset++;
+ running_arrow_offset++;
+ return Status::OK();
+ }
+ const FixedSizeBinaryArray& array;
+ liborc::StringVectorBatch* batch;
+ int64_t running_orc_offset, running_arrow_offset;
+ const int32_t data_length;
+};
+
+// static_cast from int64_t or double to itself shouldn't introduce overhead
+// Pleae see
+// https://stackoverflow.com/questions/19106826/
+// can-static-cast-to-same-type-introduce-runtime-overhead
+template <class DataType, class BatchType>
+Status WriteGenericBatch(const Array& array, int64_t orc_offset,
+ liborc::ColumnVectorBatch* column_vector_batch) {
+ using ArrayType = typename TypeTraits<DataType>::ArrayType;
+ const ArrayType& array_(checked_cast<const ArrayType&>(array));
+ auto batch = checked_cast<BatchType*>(column_vector_batch);
+ if (array.null_count()) {
+ batch->hasNulls = true;
+ }
+ Appender<DataType, BatchType> appender{array_, batch, orc_offset, 0};
+ ArrayDataVisitor<DataType> visitor;
+ RETURN_NOT_OK(visitor.Visit(*(array_.data()), &appender));
+ return Status::OK();
+}
+
+template <class DataType>
+Status WriteTimestampBatch(const Array& array, int64_t orc_offset,
+ liborc::ColumnVectorBatch* column_vector_batch,
+ const int64_t& conversion_factor_from_second,
+ const int64_t& conversion_factor_to_nano) {
+ using ArrayType = typename TypeTraits<DataType>::ArrayType;
+ const ArrayType& array_(checked_cast<const ArrayType&>(array));
+ auto batch = checked_cast<liborc::TimestampVectorBatch*>(column_vector_batch);
+ if (array.null_count()) {
+ batch->hasNulls = true;
+ }
+ TimestampAppender<DataType> appender{array_,
+ batch,
+ orc_offset,
+ 0,
+ conversion_factor_from_second,
+ conversion_factor_to_nano};
+ ArrayDataVisitor<DataType> visitor;
+ RETURN_NOT_OK(visitor.Visit(*(array_.data()), &appender));
+ return Status::OK();
+}
+
+Status WriteFixedSizeBinaryBatch(const Array& array, int64_t orc_offset,
+ liborc::ColumnVectorBatch* column_vector_batch) {
+ const FixedSizeBinaryArray& array_(checked_cast<const FixedSizeBinaryArray&>(array));
+ auto batch = checked_cast<liborc::StringVectorBatch*>(column_vector_batch);
+ if (array.null_count()) {
+ batch->hasNulls = true;
+ }
+ FixedSizeBinaryAppender appender{array_, batch, orc_offset, 0, array_.byte_width()};
+ ArrayDataVisitor<FixedSizeBinaryType> visitor;
+ RETURN_NOT_OK(visitor.Visit(*(array_.data()), &appender));
+ return Status::OK();
+}
+
+Status WriteStructBatch(const Array& array, int64_t orc_offset,
+ liborc::ColumnVectorBatch* column_vector_batch) {
+ std::shared_ptr<Array> array_ = MakeArray(array.data());
+ std::shared_ptr<StructArray> struct_array(checked_pointer_cast<StructArray>(array_));
+ auto batch = checked_cast<liborc::StructVectorBatch*>(column_vector_batch);
+ std::size_t size = array.type()->fields().size();
+ int64_t arrow_length = array.length();
+ int64_t running_arrow_offset = 0, running_orc_offset = orc_offset;
+ // First fill fields of ColumnVectorBatch
+ if (array.null_count()) {
+ batch->hasNulls = true;
+ }
+ for (; running_arrow_offset < arrow_length;
+ running_orc_offset++, running_arrow_offset++) {
+ if (array.IsNull(running_arrow_offset)) {
+ batch->notNull[running_orc_offset] = false;
+ } else {
+ batch->notNull[running_orc_offset] = true;
+ }
+ }
+ // Fill the fields
+ for (std::size_t i = 0; i < size; i++) {
+ batch->fields[i]->resize(orc_offset + arrow_length);
+ RETURN_NOT_OK(WriteBatch(*(struct_array->field(i)), orc_offset, batch->fields[i]));
+ }
+ return Status::OK();
+}
+
+template <class ArrayType>
+Status WriteListBatch(const Array& array, int64_t orc_offset,
+ liborc::ColumnVectorBatch* column_vector_batch) {
+ const ArrayType& list_array(checked_cast<const ArrayType&>(array));
+ auto batch = checked_cast<liborc::ListVectorBatch*>(column_vector_batch);
+ liborc::ColumnVectorBatch* element_batch = (batch->elements).get();
+ int64_t arrow_length = array.length();
+ int64_t running_arrow_offset = 0, running_orc_offset = orc_offset;
+ if (orc_offset == 0) {
+ batch->offsets[0] = 0;
+ }
+ if (array.null_count()) {
+ batch->hasNulls = true;
+ }
+ for (; running_arrow_offset < arrow_length;
+ running_orc_offset++, running_arrow_offset++) {
+ if (array.IsNull(running_arrow_offset)) {
+ batch->notNull[running_orc_offset] = false;
+ batch->offsets[running_orc_offset + 1] = batch->offsets[running_orc_offset];
+ } else {
+ batch->notNull[running_orc_offset] = true;
+ batch->offsets[running_orc_offset + 1] =
+ batch->offsets[running_orc_offset] +
+ list_array.value_offset(running_arrow_offset + 1) -
+ list_array.value_offset(running_arrow_offset);
+ element_batch->resize(batch->offsets[running_orc_offset + 1]);
+ int64_t subarray_arrow_offset = list_array.value_offset(running_arrow_offset),
+ subarray_orc_offset = batch->offsets[running_orc_offset],
+ subarray_orc_length =
+ batch->offsets[running_orc_offset + 1] - subarray_orc_offset;
+ RETURN_NOT_OK(WriteBatch(
+ *(list_array.values()->Slice(subarray_arrow_offset, subarray_orc_length)),
+ subarray_orc_offset, element_batch));
+ }
+ }
+ return Status::OK();
+}
+
+Status WriteMapBatch(const Array& array, int64_t orc_offset,
+ liborc::ColumnVectorBatch* column_vector_batch) {
+ const MapArray& map_array(checked_cast<const MapArray&>(array));
+ auto batch = checked_cast<liborc::MapVectorBatch*>(column_vector_batch);
+ liborc::ColumnVectorBatch* key_batch = (batch->keys).get();
+ liborc::ColumnVectorBatch* element_batch = (batch->elements).get();
+ std::shared_ptr<Array> key_array = map_array.keys();
+ std::shared_ptr<Array> element_array = map_array.items();
+ int64_t arrow_length = array.length();
+ int64_t running_arrow_offset = 0, running_orc_offset = orc_offset;
+ if (orc_offset == 0) {
+ batch->offsets[0] = 0;
+ }
+ if (array.null_count()) {
+ batch->hasNulls = true;
+ }
+ for (; running_arrow_offset < arrow_length;
+ running_orc_offset++, running_arrow_offset++) {
+ if (array.IsNull(running_arrow_offset)) {
+ batch->notNull[running_orc_offset] = false;
+ batch->offsets[running_orc_offset + 1] = batch->offsets[running_orc_offset];
+ } else {
+ batch->notNull[running_orc_offset] = true;
+ batch->offsets[running_orc_offset + 1] =
+ batch->offsets[running_orc_offset] +
+ map_array.value_offset(running_arrow_offset + 1) -
+ map_array.value_offset(running_arrow_offset);
+ int64_t subarray_arrow_offset = map_array.value_offset(running_arrow_offset),
+ subarray_orc_offset = batch->offsets[running_orc_offset],
+ new_subarray_orc_offset = batch->offsets[running_orc_offset + 1],
+ subarray_orc_length = new_subarray_orc_offset - subarray_orc_offset;
+ key_batch->resize(new_subarray_orc_offset);
+ element_batch->resize(new_subarray_orc_offset);
+ RETURN_NOT_OK(
+ WriteBatch(*(key_array->Slice(subarray_arrow_offset, subarray_orc_length)),
+ subarray_orc_offset, key_batch));
+ RETURN_NOT_OK(
+ WriteBatch(*(element_array->Slice(subarray_arrow_offset, subarray_orc_length)),
+ subarray_orc_offset, element_batch));
+ }
+ }
+ return Status::OK();
+}
+
+Status WriteBatch(const Array& array, int64_t orc_offset,
+ liborc::ColumnVectorBatch* column_vector_batch) {
+ Type::type kind = array.type_id();
+ column_vector_batch->numElements = orc_offset;
+ switch (kind) {
+ case Type::type::BOOL:
+ return WriteGenericBatch<BooleanType, liborc::LongVectorBatch>(array, orc_offset,
+ column_vector_batch);
+ case Type::type::INT8:
+ return WriteGenericBatch<Int8Type, liborc::LongVectorBatch>(array, orc_offset,
+ column_vector_batch);
+ case Type::type::INT16:
+ return WriteGenericBatch<Int16Type, liborc::LongVectorBatch>(array, orc_offset,
+ column_vector_batch);
+ case Type::type::INT32:
+ return WriteGenericBatch<Int32Type, liborc::LongVectorBatch>(array, orc_offset,
+ column_vector_batch);
+ case Type::type::INT64:
+ return WriteGenericBatch<Int64Type, liborc::LongVectorBatch>(array, orc_offset,
+ column_vector_batch);
+ case Type::type::FLOAT:
+ return WriteGenericBatch<FloatType, liborc::DoubleVectorBatch>(array, orc_offset,
+ column_vector_batch);
+ case Type::type::DOUBLE:
+ return WriteGenericBatch<DoubleType, liborc::DoubleVectorBatch>(
+ array, orc_offset, column_vector_batch);
+ case Type::type::BINARY:
+ return WriteGenericBatch<BinaryType, liborc::StringVectorBatch>(
+ array, orc_offset, column_vector_batch);
+ case Type::type::LARGE_BINARY:
+ return WriteGenericBatch<LargeBinaryType, liborc::StringVectorBatch>(
+ array, orc_offset, column_vector_batch);
+ case Type::type::STRING:
+ return WriteGenericBatch<StringType, liborc::StringVectorBatch>(
+ array, orc_offset, column_vector_batch);
+ case Type::type::LARGE_STRING:
+ return WriteGenericBatch<LargeStringType, liborc::StringVectorBatch>(
+ array, orc_offset, column_vector_batch);
+ case Type::type::FIXED_SIZE_BINARY:
+ return WriteFixedSizeBinaryBatch(array, orc_offset, column_vector_batch);
+ case Type::type::DATE32:
+ return WriteGenericBatch<Date32Type, liborc::LongVectorBatch>(array, orc_offset,
+ column_vector_batch);
+ case Type::type::DATE64:
+ return WriteTimestampBatch<Date64Type>(array, orc_offset, column_vector_batch,
+ kOneSecondMillis, kOneMilliNanos);
+ case Type::type::TIMESTAMP: {
+ switch (internal::checked_pointer_cast<TimestampType>(array.type())->unit()) {
+ case TimeUnit::type::SECOND:
+ return WriteTimestampBatch<TimestampType>(
+ array, orc_offset, column_vector_batch, 1, kOneSecondNanos);
+ case TimeUnit::type::MILLI:
+ return WriteTimestampBatch<TimestampType>(
+ array, orc_offset, column_vector_batch, kOneSecondMillis, kOneMilliNanos);
+ case TimeUnit::type::MICRO:
+ return WriteTimestampBatch<TimestampType>(
+ array, orc_offset, column_vector_batch, kOneSecondMicros, kOneMicroNanos);
+ case TimeUnit::type::NANO:
+ return WriteTimestampBatch<TimestampType>(
+ array, orc_offset, column_vector_batch, kOneSecondNanos, 1);
+ default:
+ return Status::TypeError("Unknown or unsupported Arrow type: ",
+ array.type()->ToString());
+ }
+ }
+ case Type::type::DECIMAL128: {
+ int32_t precision = checked_pointer_cast<Decimal128Type>(array.type())->precision();
+ if (precision > 18) {
+ return WriteGenericBatch<Decimal128Type, liborc::Decimal128VectorBatch>(
+ array, orc_offset, column_vector_batch);
+ } else {
+ return WriteGenericBatch<Decimal128Type, liborc::Decimal64VectorBatch>(
+ array, orc_offset, column_vector_batch);
+ }
+ }
+ case Type::type::STRUCT:
+ return WriteStructBatch(array, orc_offset, column_vector_batch);
+ case Type::type::LIST:
+ return WriteListBatch<ListArray>(array, orc_offset, column_vector_batch);
+ case Type::type::LARGE_LIST:
+ return WriteListBatch<LargeListArray>(array, orc_offset, column_vector_batch);
+ case Type::type::FIXED_SIZE_LIST:
+ return WriteListBatch<FixedSizeListArray>(array, orc_offset, column_vector_batch);
+ case Type::type::MAP:
+ return WriteMapBatch(array, orc_offset, column_vector_batch);
+ default: {
+ return Status::NotImplemented("Unknown or unsupported Arrow type: ",
+ array.type()->ToString());
+ }
+ }
+ return Status::OK();
+}
+
+Result<ORC_UNIQUE_PTR<liborc::Type>> GetOrcType(const DataType& type) {
+ Type::type kind = type.id();
+ switch (kind) {
+ case Type::type::BOOL:
+ return liborc::createPrimitiveType(liborc::TypeKind::BOOLEAN);
+ case Type::type::INT8:
+ return liborc::createPrimitiveType(liborc::TypeKind::BYTE);
+ case Type::type::INT16:
+ return liborc::createPrimitiveType(liborc::TypeKind::SHORT);
+ case Type::type::INT32:
+ return liborc::createPrimitiveType(liborc::TypeKind::INT);
+ case Type::type::INT64:
+ return liborc::createPrimitiveType(liborc::TypeKind::LONG);
+ case Type::type::FLOAT:
+ return liborc::createPrimitiveType(liborc::TypeKind::FLOAT);
+ case Type::type::DOUBLE:
+ return liborc::createPrimitiveType(liborc::TypeKind::DOUBLE);
+ // Use STRING instead of VARCHAR for now, both use UTF-8
+ case Type::type::STRING:
+ case Type::type::LARGE_STRING:
+ return liborc::createPrimitiveType(liborc::TypeKind::STRING);
+ case Type::type::BINARY:
+ case Type::type::LARGE_BINARY:
+ case Type::type::FIXED_SIZE_BINARY:
+ return liborc::createPrimitiveType(liborc::TypeKind::BINARY);
+ case Type::type::DATE32:
+ return liborc::createPrimitiveType(liborc::TypeKind::DATE);
+ case Type::type::DATE64:
+ case Type::type::TIMESTAMP:
+ return liborc::createPrimitiveType(liborc::TypeKind::TIMESTAMP);
+ case Type::type::DECIMAL128: {
+ const uint64_t precision =
+ static_cast<uint64_t>(checked_cast<const Decimal128Type&>(type).precision());
+ const uint64_t scale =
+ static_cast<uint64_t>(checked_cast<const Decimal128Type&>(type).scale());
+ return liborc::createDecimalType(precision, scale);
+ }
+ case Type::type::LIST:
+ case Type::type::FIXED_SIZE_LIST:
+ case Type::type::LARGE_LIST: {
+ std::shared_ptr<DataType> arrow_child_type =
+ checked_cast<const BaseListType&>(type).value_type();
+ ARROW_ASSIGN_OR_RAISE(auto orc_subtype, GetOrcType(*arrow_child_type));
+ return liborc::createListType(std::move(orc_subtype));
+ }
+ case Type::type::STRUCT: {
+ ORC_UNIQUE_PTR<liborc::Type> out_type = liborc::createStructType();
+ std::vector<std::shared_ptr<Field>> arrow_fields =
+ checked_cast<const StructType&>(type).fields();
+ for (std::vector<std::shared_ptr<Field>>::iterator it = arrow_fields.begin();
+ it != arrow_fields.end(); ++it) {
+ std::string field_name = (*it)->name();
+ std::shared_ptr<DataType> arrow_child_type = (*it)->type();
+ ARROW_ASSIGN_OR_RAISE(auto orc_subtype, GetOrcType(*arrow_child_type));
+ out_type->addStructField(field_name, std::move(orc_subtype));
+ }
+ return std::move(out_type);
+ }
+ case Type::type::MAP: {
+ std::shared_ptr<DataType> key_arrow_type =
+ checked_cast<const MapType&>(type).key_type();
+ std::shared_ptr<DataType> item_arrow_type =
+ checked_cast<const MapType&>(type).item_type();
+ ARROW_ASSIGN_OR_RAISE(auto key_orc_type, GetOrcType(*key_arrow_type));
+ ARROW_ASSIGN_OR_RAISE(auto item_orc_type, GetOrcType(*item_arrow_type));
+ return liborc::createMapType(std::move(key_orc_type), std::move(item_orc_type));
+ }
+ case Type::type::DENSE_UNION:
+ case Type::type::SPARSE_UNION: {
+ ORC_UNIQUE_PTR<liborc::Type> out_type = liborc::createUnionType();
+ std::vector<std::shared_ptr<Field>> arrow_fields =
+ checked_cast<const UnionType&>(type).fields();
+ for (std::vector<std::shared_ptr<Field>>::iterator it = arrow_fields.begin();
+ it != arrow_fields.end(); ++it) {
+ std::string field_name = (*it)->name();
+ std::shared_ptr<DataType> arrow_child_type = (*it)->type();
+ ARROW_ASSIGN_OR_RAISE(auto orc_subtype, GetOrcType(*arrow_child_type));
+ out_type->addUnionChild(std::move(orc_subtype));
+ }
+ return std::move(out_type);
+ }
+ default: {
+ return Status::NotImplemented("Unknown or unsupported Arrow type: ",
+ type.ToString());
+ }
+ }
+}
+
+} // namespace
+
+Status WriteBatch(const ChunkedArray& chunked_array, int64_t length,
+ int* arrow_chunk_offset, int64_t* arrow_index_offset,
+ liborc::ColumnVectorBatch* column_vector_batch) {
+ int num_batch = chunked_array.num_chunks();
+ int64_t orc_offset = 0;
+ while (*arrow_chunk_offset < num_batch && orc_offset < length) {
+ ARROW_ASSIGN_OR_RAISE(auto array,
+ NormalizeArray(chunked_array.chunk(*arrow_chunk_offset)));
+ int64_t num_written_elements =
+ std::min(length - orc_offset, array->length() - *arrow_index_offset);
+ if (num_written_elements > 0) {
+ RETURN_NOT_OK(WriteBatch(*(array->Slice(*arrow_index_offset, num_written_elements)),
+ orc_offset, column_vector_batch));
+ orc_offset += num_written_elements;
+ *arrow_index_offset += num_written_elements;
+ }
+ if (orc_offset < length) { // Another Arrow Array done
+ *arrow_index_offset = 0;
+ (*arrow_chunk_offset)++;
+ }
+ }
+ column_vector_batch->numElements = orc_offset;
+ return Status::OK();
+}
+
+Status GetArrowType(const liborc::Type* type, std::shared_ptr<DataType>* out) {
+ // When subselecting fields on read, liborc will set some nodes to nullptr,
+ // so we need to check for nullptr before progressing
+ if (type == nullptr) {
+ *out = null();
+ return Status::OK();
+ }
+ liborc::TypeKind kind = type->getKind();
+ const int subtype_count = static_cast<int>(type->getSubtypeCount());
+
+ switch (kind) {
+ case liborc::BOOLEAN:
+ *out = boolean();
+ break;
+ case liborc::BYTE:
+ *out = int8();
+ break;
+ case liborc::SHORT:
+ *out = int16();
+ break;
+ case liborc::INT:
+ *out = int32();
+ break;
+ case liborc::LONG:
+ *out = int64();
+ break;
+ case liborc::FLOAT:
+ *out = float32();
+ break;
+ case liborc::DOUBLE:
+ *out = float64();
+ break;
+ case liborc::VARCHAR:
+ case liborc::STRING:
+ *out = utf8();
+ break;
+ case liborc::BINARY:
+ *out = binary();
+ break;
+ case liborc::CHAR:
+ *out = fixed_size_binary(static_cast<int>(type->getMaximumLength()));
+ break;
+ case liborc::TIMESTAMP:
+ *out = timestamp(TimeUnit::NANO);
+ break;
+ case liborc::DATE:
+ *out = date32();
+ break;
+ case liborc::DECIMAL: {
+ const int precision = static_cast<int>(type->getPrecision());
+ const int scale = static_cast<int>(type->getScale());
+ if (precision == 0) {
+ // In HIVE 0.11/0.12 precision is set as 0, but means max precision
+ *out = decimal128(38, 6);
+ } else {
+ *out = decimal128(precision, scale);
+ }
+ break;
+ }
+ case liborc::LIST: {
+ if (subtype_count != 1) {
+ return Status::TypeError("Invalid Orc List type");
+ }
+ std::shared_ptr<DataType> elemtype;
+ RETURN_NOT_OK(GetArrowType(type->getSubtype(0), &elemtype));
+ *out = list(elemtype);
+ break;
+ }
+ case liborc::MAP: {
+ if (subtype_count != 2) {
+ return Status::TypeError("Invalid Orc Map type");
+ }
+ std::shared_ptr<DataType> key_type, item_type;
+ RETURN_NOT_OK(GetArrowType(type->getSubtype(0), &key_type));
+ RETURN_NOT_OK(GetArrowType(type->getSubtype(1), &item_type));
+ *out = map(key_type, item_type);
+ break;
+ }
+ case liborc::STRUCT: {
+ std::vector<std::shared_ptr<Field>> fields;
+ for (int child = 0; child < subtype_count; ++child) {
+ std::shared_ptr<DataType> elem_type;
+ RETURN_NOT_OK(GetArrowType(type->getSubtype(child), &elem_type));
+ std::string name = type->getFieldName(child);
+ fields.push_back(field(name, elem_type));
+ }
+ *out = struct_(fields);
+ break;
+ }
+ case liborc::UNION: {
+ std::vector<std::shared_ptr<Field>> fields;
+ std::vector<int8_t> type_codes;
+ for (int child = 0; child < subtype_count; ++child) {
+ std::shared_ptr<DataType> elem_type;
+ RETURN_NOT_OK(GetArrowType(type->getSubtype(child), &elem_type));
+ fields.push_back(field("_union_" + std::to_string(child), elem_type));
+ type_codes.push_back(static_cast<int8_t>(child));
+ }
+ *out = sparse_union(fields, type_codes);
+ break;
+ }
+ default: {
+ return Status::TypeError("Unknown Orc type kind: ", type->toString());
+ }
+ }
+ return Status::OK();
+}
+
+Result<ORC_UNIQUE_PTR<liborc::Type>> GetOrcType(const Schema& schema) {
+ int numFields = schema.num_fields();
+ ORC_UNIQUE_PTR<liborc::Type> out_type = liborc::createStructType();
+ for (int i = 0; i < numFields; i++) {
+ std::shared_ptr<Field> field = schema.field(i);
+ std::string field_name = field->name();
+ std::shared_ptr<DataType> arrow_child_type = field->type();
+ ARROW_ASSIGN_OR_RAISE(auto orc_subtype, GetOrcType(*arrow_child_type));
+ out_type->addStructField(field_name, std::move(orc_subtype));
+ }
+ return std::move(out_type);
+}
+
+} // namespace orc
+} // namespace adapters
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/adapters/orc/adapter_util.h b/src/arrow/cpp/src/arrow/adapters/orc/adapter_util.h
new file mode 100644
index 000000000..3e6d0fcc6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/adapters/orc/adapter_util.h
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/array/builder_base.h"
+#include "arrow/status.h"
+#include "orc/OrcFile.hh"
+
+namespace liborc = orc;
+
+namespace arrow {
+
+namespace adapters {
+
+namespace orc {
+
+Status GetArrowType(const liborc::Type* type, std::shared_ptr<DataType>* out);
+
+Result<ORC_UNIQUE_PTR<liborc::Type>> GetOrcType(const Schema& schema);
+
+Status AppendBatch(const liborc::Type* type, liborc::ColumnVectorBatch* batch,
+ int64_t offset, int64_t length, arrow::ArrayBuilder* builder);
+
+/// \brief Write a chunked array to an orc::ColumnVectorBatch
+///
+/// \param[in] chunked_array the chunked array
+/// \param[in] length the orc::ColumnVectorBatch size limit
+/// \param[in,out] arrow_chunk_offset The current chunk being processed
+/// \param[in,out] arrow_index_offset The index of the arrow_chunk_offset array
+/// before or after a process
+/// \param[in,out] column_vector_batch the orc::ColumnVectorBatch to be filled
+/// \return Status
+Status WriteBatch(const ChunkedArray& chunked_array, int64_t length,
+ int* arrow_chunk_offset, int64_t* arrow_index_offset,
+ liborc::ColumnVectorBatch* column_vector_batch);
+
+} // namespace orc
+} // namespace adapters
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/adapters/orc/arrow-orc.pc.in b/src/arrow/cpp/src/arrow/adapters/orc/arrow-orc.pc.in
new file mode 100644
index 000000000..eec59ccc5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/adapters/orc/arrow-orc.pc.in
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+Name: Apache Arrow ORC
+Description: ORC modules for Apache Arrow
+Version: @ARROW_VERSION@
+Requires: arrow
diff --git a/src/arrow/cpp/src/arrow/adapters/tensorflow/CMakeLists.txt b/src/arrow/cpp/src/arrow/adapters/tensorflow/CMakeLists.txt
new file mode 100644
index 000000000..a627db417
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/adapters/tensorflow/CMakeLists.txt
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+arrow_install_all_headers("arrow/adapters/tensorflow")
+
+# pkg-config support
+arrow_add_pkg_config("arrow-tensorflow")
diff --git a/src/arrow/cpp/src/arrow/adapters/tensorflow/arrow-tensorflow.pc.in b/src/arrow/cpp/src/arrow/adapters/tensorflow/arrow-tensorflow.pc.in
new file mode 100644
index 000000000..a2b38a0a0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/adapters/tensorflow/arrow-tensorflow.pc.in
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+Name: Apache Arrow Tensorflow
+Description: TensorFlow modules for Apache Arrow
+Version: @ARROW_VERSION@
+Requires: arrow
diff --git a/src/arrow/cpp/src/arrow/adapters/tensorflow/convert.h b/src/arrow/cpp/src/arrow/adapters/tensorflow/convert.h
new file mode 100644
index 000000000..9d093eddf
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/adapters/tensorflow/convert.h
@@ -0,0 +1,128 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "tensorflow/core/framework/op.h"
+
+#include "arrow/type.h"
+
+// These utilities are supposed to be included in TensorFlow operators
+// that need to be compiled separately from Arrow because of ABI issues.
+// They therefore need to be header-only.
+
+namespace arrow {
+
+namespace adapters {
+
+namespace tensorflow {
+
+Status GetArrowType(::tensorflow::DataType dtype, std::shared_ptr<DataType>* out) {
+ switch (dtype) {
+ case ::tensorflow::DT_BOOL:
+ *out = arrow::boolean();
+ break;
+ case ::tensorflow::DT_FLOAT:
+ *out = arrow::float32();
+ break;
+ case ::tensorflow::DT_DOUBLE:
+ *out = arrow::float64();
+ break;
+ case ::tensorflow::DT_HALF:
+ *out = arrow::float16();
+ break;
+ case ::tensorflow::DT_INT8:
+ *out = arrow::int8();
+ break;
+ case ::tensorflow::DT_INT16:
+ *out = arrow::int16();
+ break;
+ case ::tensorflow::DT_INT32:
+ *out = arrow::int32();
+ break;
+ case ::tensorflow::DT_INT64:
+ *out = arrow::int64();
+ break;
+ case ::tensorflow::DT_UINT8:
+ *out = arrow::uint8();
+ break;
+ case ::tensorflow::DT_UINT16:
+ *out = arrow::uint16();
+ break;
+ case ::tensorflow::DT_UINT32:
+ *out = arrow::uint32();
+ break;
+ case ::tensorflow::DT_UINT64:
+ *out = arrow::uint64();
+ break;
+ default:
+ return Status::TypeError("TensorFlow data type is not supported");
+ }
+ return Status::OK();
+}
+
+Status GetTensorFlowType(std::shared_ptr<DataType> dtype, ::tensorflow::DataType* out) {
+ switch (dtype->id()) {
+ case Type::BOOL:
+ *out = ::tensorflow::DT_BOOL;
+ break;
+ case Type::UINT8:
+ *out = ::tensorflow::DT_UINT8;
+ break;
+ case Type::INT8:
+ *out = ::tensorflow::DT_INT8;
+ break;
+ case Type::UINT16:
+ *out = ::tensorflow::DT_UINT16;
+ break;
+ case Type::INT16:
+ *out = ::tensorflow::DT_INT16;
+ break;
+ case Type::UINT32:
+ *out = ::tensorflow::DT_UINT32;
+ break;
+ case Type::INT32:
+ *out = ::tensorflow::DT_INT32;
+ break;
+ case Type::UINT64:
+ *out = ::tensorflow::DT_UINT64;
+ break;
+ case Type::INT64:
+ *out = ::tensorflow::DT_INT64;
+ break;
+ case Type::HALF_FLOAT:
+ *out = ::tensorflow::DT_HALF;
+ break;
+ case Type::FLOAT:
+ *out = ::tensorflow::DT_FLOAT;
+ break;
+ case Type::DOUBLE:
+ *out = ::tensorflow::DT_DOUBLE;
+ break;
+ default:
+ return Status::TypeError("Arrow data type is not supported");
+ }
+ return arrow::Status::OK();
+}
+
+} // namespace tensorflow
+
+} // namespace adapters
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/api.h b/src/arrow/cpp/src/arrow/api.h
new file mode 100644
index 000000000..8958eaf1c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/api.h
@@ -0,0 +1,44 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Coarse public API while the library is in development
+
+#pragma once
+
+#include "arrow/array.h" // IYWU pragma: export
+#include "arrow/array/concatenate.h" // IYWU pragma: export
+#include "arrow/buffer.h" // IYWU pragma: export
+#include "arrow/builder.h" // IYWU pragma: export
+#include "arrow/chunked_array.h" // IYWU pragma: export
+#include "arrow/compare.h" // IYWU pragma: export
+#include "arrow/config.h" // IYWU pragma: export
+#include "arrow/datum.h" // IYWU pragma: export
+#include "arrow/extension_type.h" // IYWU pragma: export
+#include "arrow/memory_pool.h" // IYWU pragma: export
+#include "arrow/pretty_print.h" // IYWU pragma: export
+#include "arrow/record_batch.h" // IYWU pragma: export
+#include "arrow/result.h" // IYWU pragma: export
+#include "arrow/status.h" // IYWU pragma: export
+#include "arrow/table.h" // IYWU pragma: export
+#include "arrow/table_builder.h" // IYWU pragma: export
+#include "arrow/tensor.h" // IYWU pragma: export
+#include "arrow/type.h" // IYWU pragma: export
+#include "arrow/util/key_value_metadata.h" // IWYU pragma: export
+#include "arrow/visitor.h" // IYWU pragma: export
+
+/// \brief Top-level namespace for Apache Arrow C++ API
+namespace arrow {}
diff --git a/src/arrow/cpp/src/arrow/array.h b/src/arrow/cpp/src/arrow/array.h
new file mode 100644
index 000000000..918c76174
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array.h
@@ -0,0 +1,44 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Kitchen-sink public API for arrow::Array data structures. C++ library code
+// (especially header files) in Apache Arrow should use more specific headers
+// unless it's a file that uses most or all Array types in which case using
+// arrow/array.h is fine.
+
+#pragma once
+
+/// \defgroup numeric-arrays Concrete classes for numeric arrays
+/// @{
+/// @}
+
+/// \defgroup binary-arrays Concrete classes for binary/string arrays
+/// @{
+/// @}
+
+/// \defgroup nested-arrays Concrete classes for nested arrays
+/// @{
+/// @}
+
+#include "arrow/array/array_base.h" // IWYU pragma: keep
+#include "arrow/array/array_binary.h" // IWYU pragma: keep
+#include "arrow/array/array_decimal.h" // IWYU pragma: keep
+#include "arrow/array/array_dict.h" // IWYU pragma: keep
+#include "arrow/array/array_nested.h" // IWYU pragma: keep
+#include "arrow/array/array_primitive.h" // IWYU pragma: keep
+#include "arrow/array/data.h" // IWYU pragma: keep
+#include "arrow/array/util.h" // IWYU pragma: keep
diff --git a/src/arrow/cpp/src/arrow/array/CMakeLists.txt b/src/arrow/cpp/src/arrow/array/CMakeLists.txt
new file mode 100644
index 000000000..c0fc17687
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+add_arrow_test(concatenate_test)
+
+if(ARROW_COMPUTE)
+ # This unit test uses compute code
+ add_arrow_test(diff_test)
+endif()
+
+# Headers: top level
+arrow_install_all_headers("arrow/array")
diff --git a/src/arrow/cpp/src/arrow/array/README.md b/src/arrow/cpp/src/arrow/array/README.md
new file mode 100644
index 000000000..01ffa104e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/README.md
@@ -0,0 +1,20 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+## Implementation details related to columnar (array) data structures
diff --git a/src/arrow/cpp/src/arrow/array/array_base.cc b/src/arrow/cpp/src/arrow/array/array_base.cc
new file mode 100644
index 000000000..dd3cec1d7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_base.cc
@@ -0,0 +1,313 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/array_base.h"
+
+#include <cstdint>
+#include <memory>
+#include <sstream> // IWYU pragma: keep
+#include <string>
+#include <type_traits>
+#include <utility>
+
+#include "arrow/array/array_binary.h"
+#include "arrow/array/array_dict.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/array/array_primitive.h"
+#include "arrow/array/util.h"
+#include "arrow/array/validate.h"
+#include "arrow/buffer.h"
+#include "arrow/compare.h"
+#include "arrow/pretty_print.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/logging.h"
+#include "arrow/visitor.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+class ExtensionArray;
+
+// ----------------------------------------------------------------------
+// Base array class
+
+int64_t Array::null_count() const { return data_->GetNullCount(); }
+
+namespace internal {
+
+struct ScalarFromArraySlotImpl {
+ template <typename T>
+ using ScalarType = typename TypeTraits<T>::ScalarType;
+
+ Status Visit(const NullArray& a) {
+ out_ = std::make_shared<NullScalar>();
+ return Status::OK();
+ }
+
+ Status Visit(const BooleanArray& a) { return Finish(a.Value(index_)); }
+
+ template <typename T>
+ Status Visit(const NumericArray<T>& a) {
+ return Finish(a.Value(index_));
+ }
+
+ Status Visit(const Decimal128Array& a) {
+ return Finish(Decimal128(a.GetValue(index_)));
+ }
+
+ Status Visit(const Decimal256Array& a) {
+ return Finish(Decimal256(a.GetValue(index_)));
+ }
+
+ template <typename T>
+ Status Visit(const BaseBinaryArray<T>& a) {
+ return Finish(a.GetString(index_));
+ }
+
+ Status Visit(const FixedSizeBinaryArray& a) { return Finish(a.GetString(index_)); }
+
+ Status Visit(const DayTimeIntervalArray& a) { return Finish(a.Value(index_)); }
+ Status Visit(const MonthDayNanoIntervalArray& a) { return Finish(a.Value(index_)); }
+
+ template <typename T>
+ Status Visit(const BaseListArray<T>& a) {
+ return Finish(a.value_slice(index_));
+ }
+
+ Status Visit(const FixedSizeListArray& a) { return Finish(a.value_slice(index_)); }
+
+ Status Visit(const StructArray& a) {
+ ScalarVector children;
+ for (const auto& child : a.fields()) {
+ children.emplace_back();
+ ARROW_ASSIGN_OR_RAISE(children.back(), child->GetScalar(index_));
+ }
+ return Finish(std::move(children));
+ }
+
+ Status Visit(const SparseUnionArray& a) {
+ const auto type_code = a.type_code(index_);
+ // child array which stores the actual value
+ const auto arr = a.field(a.child_id(index_));
+ // no need to adjust the index
+ ARROW_ASSIGN_OR_RAISE(auto value, arr->GetScalar(index_));
+ if (value->is_valid) {
+ out_ = std::shared_ptr<Scalar>(new SparseUnionScalar(value, type_code, a.type()));
+ } else {
+ out_ = std::shared_ptr<Scalar>(new SparseUnionScalar(type_code, a.type()));
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const DenseUnionArray& a) {
+ const auto type_code = a.type_code(index_);
+ // child array which stores the actual value
+ auto arr = a.field(a.child_id(index_));
+ // need to look up the value based on offsets
+ auto offset = a.value_offset(index_);
+ ARROW_ASSIGN_OR_RAISE(auto value, arr->GetScalar(offset));
+ if (value->is_valid) {
+ out_ = std::shared_ptr<Scalar>(new DenseUnionScalar(value, type_code, a.type()));
+ } else {
+ out_ = std::shared_ptr<Scalar>(new DenseUnionScalar(type_code, a.type()));
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryArray& a) {
+ auto ty = a.type();
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto index, MakeScalar(checked_cast<const DictionaryType&>(*ty).index_type(),
+ a.GetValueIndex(index_)));
+
+ auto scalar = DictionaryScalar(ty);
+ scalar.is_valid = a.IsValid(index_);
+ scalar.value.index = index;
+ scalar.value.dictionary = a.dictionary();
+
+ out_ = std::make_shared<DictionaryScalar>(std::move(scalar));
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionArray& a) {
+ ARROW_ASSIGN_OR_RAISE(auto storage, a.storage()->GetScalar(index_));
+ out_ = std::make_shared<ExtensionScalar>(std::move(storage), a.type());
+ return Status::OK();
+ }
+
+ template <typename Arg>
+ Status Finish(Arg&& arg) {
+ return MakeScalar(array_.type(), std::forward<Arg>(arg)).Value(&out_);
+ }
+
+ Status Finish(std::string arg) {
+ return MakeScalar(array_.type(), Buffer::FromString(std::move(arg))).Value(&out_);
+ }
+
+ Result<std::shared_ptr<Scalar>> Finish() && {
+ if (index_ >= array_.length()) {
+ return Status::IndexError("tried to refer to element ", index_,
+ " but array is only ", array_.length(), " long");
+ }
+
+ if (array_.IsNull(index_)) {
+ auto null = MakeNullScalar(array_.type());
+ if (is_dictionary(array_.type()->id())) {
+ auto& dict_null = checked_cast<DictionaryScalar&>(*null);
+ const auto& dict_array = checked_cast<const DictionaryArray&>(array_);
+ dict_null.value.dictionary = dict_array.dictionary();
+ }
+ return null;
+ }
+
+ RETURN_NOT_OK(VisitArrayInline(array_, this));
+ return std::move(out_);
+ }
+
+ ScalarFromArraySlotImpl(const Array& array, int64_t index)
+ : array_(array), index_(index) {}
+
+ const Array& array_;
+ int64_t index_;
+ std::shared_ptr<Scalar> out_;
+};
+
+} // namespace internal
+
+Result<std::shared_ptr<Scalar>> Array::GetScalar(int64_t i) const {
+ return internal::ScalarFromArraySlotImpl{*this, i}.Finish();
+}
+
+std::string Array::Diff(const Array& other) const {
+ std::stringstream diff;
+ ARROW_IGNORE_EXPR(Equals(other, EqualOptions().diff_sink(&diff)));
+ return diff.str();
+}
+
+bool Array::Equals(const Array& arr, const EqualOptions& opts) const {
+ return ArrayEquals(*this, arr, opts);
+}
+
+bool Array::Equals(const std::shared_ptr<Array>& arr, const EqualOptions& opts) const {
+ if (!arr) {
+ return false;
+ }
+ return Equals(*arr, opts);
+}
+
+bool Array::ApproxEquals(const Array& arr, const EqualOptions& opts) const {
+ return ArrayApproxEquals(*this, arr, opts);
+}
+
+bool Array::ApproxEquals(const std::shared_ptr<Array>& arr,
+ const EqualOptions& opts) const {
+ if (!arr) {
+ return false;
+ }
+ return ApproxEquals(*arr, opts);
+}
+
+bool Array::RangeEquals(const Array& other, int64_t start_idx, int64_t end_idx,
+ int64_t other_start_idx, const EqualOptions& opts) const {
+ return ArrayRangeEquals(*this, other, start_idx, end_idx, other_start_idx, opts);
+}
+
+bool Array::RangeEquals(const std::shared_ptr<Array>& other, int64_t start_idx,
+ int64_t end_idx, int64_t other_start_idx,
+ const EqualOptions& opts) const {
+ if (!other) {
+ return false;
+ }
+ return ArrayRangeEquals(*this, *other, start_idx, end_idx, other_start_idx, opts);
+}
+
+bool Array::RangeEquals(int64_t start_idx, int64_t end_idx, int64_t other_start_idx,
+ const Array& other, const EqualOptions& opts) const {
+ return ArrayRangeEquals(*this, other, start_idx, end_idx, other_start_idx, opts);
+}
+
+bool Array::RangeEquals(int64_t start_idx, int64_t end_idx, int64_t other_start_idx,
+ const std::shared_ptr<Array>& other,
+ const EqualOptions& opts) const {
+ if (!other) {
+ return false;
+ }
+ return ArrayRangeEquals(*this, *other, start_idx, end_idx, other_start_idx, opts);
+}
+
+std::shared_ptr<Array> Array::Slice(int64_t offset, int64_t length) const {
+ return MakeArray(data_->Slice(offset, length));
+}
+
+std::shared_ptr<Array> Array::Slice(int64_t offset) const {
+ int64_t slice_length = data_->length - offset;
+ return Slice(offset, slice_length);
+}
+
+Result<std::shared_ptr<Array>> Array::SliceSafe(int64_t offset, int64_t length) const {
+ ARROW_ASSIGN_OR_RAISE(auto sliced_data, data_->SliceSafe(offset, length));
+ return MakeArray(std::move(sliced_data));
+}
+
+Result<std::shared_ptr<Array>> Array::SliceSafe(int64_t offset) const {
+ if (offset < 0) {
+ // Avoid UBSAN in subtraction below
+ return Status::Invalid("Negative buffer slice offset");
+ }
+ return SliceSafe(offset, data_->length - offset);
+}
+
+std::string Array::ToString() const {
+ std::stringstream ss;
+ ARROW_CHECK_OK(PrettyPrint(*this, 0, &ss));
+ return ss.str();
+}
+
+Result<std::shared_ptr<Array>> Array::View(
+ const std::shared_ptr<DataType>& out_type) const {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ArrayData> result,
+ internal::GetArrayView(data_, out_type));
+ return MakeArray(result);
+}
+
+// ----------------------------------------------------------------------
+// NullArray
+
+NullArray::NullArray(int64_t length) {
+ SetData(ArrayData::Make(null(), length, {nullptr}, length));
+}
+
+// ----------------------------------------------------------------------
+// Implement Array::Accept as inline visitor
+
+Status Array::Accept(ArrayVisitor* visitor) const {
+ return VisitArrayInline(*this, visitor);
+}
+
+Status Array::Validate() const { return internal::ValidateArray(*this); }
+
+Status Array::ValidateFull() const {
+ RETURN_NOT_OK(internal::ValidateArray(*this));
+ return internal::ValidateArrayFull(*this);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_base.h b/src/arrow/cpp/src/arrow/array/array_base.h
new file mode 100644
index 000000000..2add572e7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_base.h
@@ -0,0 +1,260 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <iosfwd>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/data.h"
+#include "arrow/buffer.h"
+#include "arrow/compare.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+#include "arrow/visitor.h"
+
+namespace arrow {
+
+// ----------------------------------------------------------------------
+// User array accessor types
+
+/// \brief Array base type
+/// Immutable data array with some logical type and some length.
+///
+/// Any memory is owned by the respective Buffer instance (or its parents).
+///
+/// The base class is only required to have a null bitmap buffer if the null
+/// count is greater than 0
+///
+/// If known, the null count can be provided in the base Array constructor. If
+/// the null count is not known, pass -1 to indicate that the null count is to
+/// be computed on the first call to null_count()
+class ARROW_EXPORT Array {
+ public:
+ virtual ~Array() = default;
+
+ /// \brief Return true if value at index is null. Does not boundscheck
+ bool IsNull(int64_t i) const {
+ return null_bitmap_data_ != NULLPTR
+ ? !BitUtil::GetBit(null_bitmap_data_, i + data_->offset)
+ : data_->null_count == data_->length;
+ }
+
+ /// \brief Return true if value at index is valid (not null). Does not
+ /// boundscheck
+ bool IsValid(int64_t i) const {
+ return null_bitmap_data_ != NULLPTR
+ ? BitUtil::GetBit(null_bitmap_data_, i + data_->offset)
+ : data_->null_count != data_->length;
+ }
+
+ /// \brief Return a Scalar containing the value of this array at i
+ Result<std::shared_ptr<Scalar>> GetScalar(int64_t i) const;
+
+ /// Size in the number of elements this array contains.
+ int64_t length() const { return data_->length; }
+
+ /// A relative position into another array's data, to enable zero-copy
+ /// slicing. This value defaults to zero
+ int64_t offset() const { return data_->offset; }
+
+ /// The number of null entries in the array. If the null count was not known
+ /// at time of construction (and set to a negative value), then the null
+ /// count will be computed and cached on the first invocation of this
+ /// function
+ int64_t null_count() const;
+
+ std::shared_ptr<DataType> type() const { return data_->type; }
+ Type::type type_id() const { return data_->type->id(); }
+
+ /// Buffer for the validity (null) bitmap, if any. Note that Union types
+ /// never have a null bitmap.
+ ///
+ /// Note that for `null_count == 0` or for null type, this will be null.
+ /// This buffer does not account for any slice offset
+ const std::shared_ptr<Buffer>& null_bitmap() const { return data_->buffers[0]; }
+
+ /// Raw pointer to the null bitmap.
+ ///
+ /// Note that for `null_count == 0` or for null type, this will be null.
+ /// This buffer does not account for any slice offset
+ const uint8_t* null_bitmap_data() const { return null_bitmap_data_; }
+
+ /// Equality comparison with another array
+ bool Equals(const Array& arr, const EqualOptions& = EqualOptions::Defaults()) const;
+ bool Equals(const std::shared_ptr<Array>& arr,
+ const EqualOptions& = EqualOptions::Defaults()) const;
+
+ /// \brief Return the formatted unified diff of arrow::Diff between this
+ /// Array and another Array
+ std::string Diff(const Array& other) const;
+
+ /// Approximate equality comparison with another array
+ ///
+ /// epsilon is only used if this is FloatArray or DoubleArray
+ bool ApproxEquals(const std::shared_ptr<Array>& arr,
+ const EqualOptions& = EqualOptions::Defaults()) const;
+ bool ApproxEquals(const Array& arr,
+ const EqualOptions& = EqualOptions::Defaults()) const;
+
+ /// Compare if the range of slots specified are equal for the given array and
+ /// this array. end_idx exclusive. This methods does not bounds check.
+ bool RangeEquals(int64_t start_idx, int64_t end_idx, int64_t other_start_idx,
+ const Array& other,
+ const EqualOptions& = EqualOptions::Defaults()) const;
+ bool RangeEquals(int64_t start_idx, int64_t end_idx, int64_t other_start_idx,
+ const std::shared_ptr<Array>& other,
+ const EqualOptions& = EqualOptions::Defaults()) const;
+ bool RangeEquals(const Array& other, int64_t start_idx, int64_t end_idx,
+ int64_t other_start_idx,
+ const EqualOptions& = EqualOptions::Defaults()) const;
+ bool RangeEquals(const std::shared_ptr<Array>& other, int64_t start_idx,
+ int64_t end_idx, int64_t other_start_idx,
+ const EqualOptions& = EqualOptions::Defaults()) const;
+
+ Status Accept(ArrayVisitor* visitor) const;
+
+ /// Construct a zero-copy view of this array with the given type.
+ ///
+ /// This method checks if the types are layout-compatible.
+ /// Nested types are traversed in depth-first order. Data buffers must have
+ /// the same item sizes, even though the logical types may be different.
+ /// An error is returned if the types are not layout-compatible.
+ Result<std::shared_ptr<Array>> View(const std::shared_ptr<DataType>& type) const;
+
+ /// Construct a zero-copy slice of the array with the indicated offset and
+ /// length
+ ///
+ /// \param[in] offset the position of the first element in the constructed
+ /// slice
+ /// \param[in] length the length of the slice. If there are not enough
+ /// elements in the array, the length will be adjusted accordingly
+ ///
+ /// \return a new object wrapped in std::shared_ptr<Array>
+ std::shared_ptr<Array> Slice(int64_t offset, int64_t length) const;
+
+ /// Slice from offset until end of the array
+ std::shared_ptr<Array> Slice(int64_t offset) const;
+
+ /// Input-checking variant of Array::Slice
+ Result<std::shared_ptr<Array>> SliceSafe(int64_t offset, int64_t length) const;
+ /// Input-checking variant of Array::Slice
+ Result<std::shared_ptr<Array>> SliceSafe(int64_t offset) const;
+
+ const std::shared_ptr<ArrayData>& data() const { return data_; }
+
+ int num_fields() const { return static_cast<int>(data_->child_data.size()); }
+
+ /// \return PrettyPrint representation of array suitable for debugging
+ std::string ToString() const;
+
+ /// \brief Perform cheap validation checks to determine obvious inconsistencies
+ /// within the array's internal data.
+ ///
+ /// This is O(k) where k is the number of descendents.
+ ///
+ /// \return Status
+ Status Validate() const;
+
+ /// \brief Perform extensive validation checks to determine inconsistencies
+ /// within the array's internal data.
+ ///
+ /// This is potentially O(k*n) where k is the number of descendents and n
+ /// is the array length.
+ ///
+ /// \return Status
+ Status ValidateFull() const;
+
+ protected:
+ Array() : null_bitmap_data_(NULLPTR) {}
+
+ std::shared_ptr<ArrayData> data_;
+ const uint8_t* null_bitmap_data_;
+
+ /// Protected method for constructors
+ void SetData(const std::shared_ptr<ArrayData>& data) {
+ if (data->buffers.size() > 0) {
+ null_bitmap_data_ = data->GetValuesSafe<uint8_t>(0, /*offset=*/0);
+ } else {
+ null_bitmap_data_ = NULLPTR;
+ }
+ data_ = data;
+ }
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Array);
+};
+
+static inline std::ostream& operator<<(std::ostream& os, const Array& x) {
+ os << x.ToString();
+ return os;
+}
+
+/// Base class for non-nested arrays
+class ARROW_EXPORT FlatArray : public Array {
+ protected:
+ using Array::Array;
+};
+
+/// Base class for arrays of fixed-size logical types
+class ARROW_EXPORT PrimitiveArray : public FlatArray {
+ public:
+ PrimitiveArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ /// Does not account for any slice offset
+ std::shared_ptr<Buffer> values() const { return data_->buffers[1]; }
+
+ protected:
+ PrimitiveArray() : raw_values_(NULLPTR) {}
+
+ void SetData(const std::shared_ptr<ArrayData>& data) {
+ this->Array::SetData(data);
+ raw_values_ = data->GetValuesSafe<uint8_t>(1, /*offset=*/0);
+ }
+
+ explicit PrimitiveArray(const std::shared_ptr<ArrayData>& data) { SetData(data); }
+
+ const uint8_t* raw_values_;
+};
+
+/// Degenerate null type Array
+class ARROW_EXPORT NullArray : public FlatArray {
+ public:
+ using TypeClass = NullType;
+
+ explicit NullArray(const std::shared_ptr<ArrayData>& data) { SetData(data); }
+ explicit NullArray(int64_t length);
+
+ private:
+ void SetData(const std::shared_ptr<ArrayData>& data) {
+ null_bitmap_data_ = NULLPTR;
+ data->null_count = data->length;
+ data_ = data;
+ }
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_binary.cc b/src/arrow/cpp/src/arrow/array/array_binary.cc
new file mode 100644
index 000000000..9466b5a48
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_binary.cc
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/array_binary.h"
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/validate.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+BinaryArray::BinaryArray(const std::shared_ptr<ArrayData>& data) {
+ ARROW_CHECK(is_binary_like(data->type->id()));
+ SetData(data);
+}
+
+BinaryArray::BinaryArray(int64_t length, const std::shared_ptr<Buffer>& value_offsets,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap, int64_t null_count,
+ int64_t offset) {
+ SetData(ArrayData::Make(binary(), length, {null_bitmap, value_offsets, data},
+ null_count, offset));
+}
+
+LargeBinaryArray::LargeBinaryArray(const std::shared_ptr<ArrayData>& data) {
+ ARROW_CHECK(is_large_binary_like(data->type->id()));
+ SetData(data);
+}
+
+LargeBinaryArray::LargeBinaryArray(int64_t length,
+ const std::shared_ptr<Buffer>& value_offsets,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap,
+ int64_t null_count, int64_t offset) {
+ SetData(ArrayData::Make(large_binary(), length, {null_bitmap, value_offsets, data},
+ null_count, offset));
+}
+
+StringArray::StringArray(const std::shared_ptr<ArrayData>& data) {
+ ARROW_CHECK_EQ(data->type->id(), Type::STRING);
+ SetData(data);
+}
+
+StringArray::StringArray(int64_t length, const std::shared_ptr<Buffer>& value_offsets,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap, int64_t null_count,
+ int64_t offset) {
+ SetData(ArrayData::Make(utf8(), length, {null_bitmap, value_offsets, data}, null_count,
+ offset));
+}
+
+Status StringArray::ValidateUTF8() const { return internal::ValidateUTF8(*data_); }
+
+LargeStringArray::LargeStringArray(const std::shared_ptr<ArrayData>& data) {
+ ARROW_CHECK_EQ(data->type->id(), Type::LARGE_STRING);
+ SetData(data);
+}
+
+LargeStringArray::LargeStringArray(int64_t length,
+ const std::shared_ptr<Buffer>& value_offsets,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap,
+ int64_t null_count, int64_t offset) {
+ SetData(ArrayData::Make(large_utf8(), length, {null_bitmap, value_offsets, data},
+ null_count, offset));
+}
+
+Status LargeStringArray::ValidateUTF8() const { return internal::ValidateUTF8(*data_); }
+
+FixedSizeBinaryArray::FixedSizeBinaryArray(const std::shared_ptr<ArrayData>& data) {
+ SetData(data);
+}
+
+FixedSizeBinaryArray::FixedSizeBinaryArray(const std::shared_ptr<DataType>& type,
+ int64_t length,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap,
+ int64_t null_count, int64_t offset)
+ : PrimitiveArray(type, length, data, null_bitmap, null_count, offset),
+ byte_width_(checked_cast<const FixedSizeBinaryType&>(*type).byte_width()) {}
+
+const uint8_t* FixedSizeBinaryArray::GetValue(int64_t i) const {
+ return raw_values_ + (i + data_->offset) * byte_width_;
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_binary.h b/src/arrow/cpp/src/arrow/array/array_binary.h
new file mode 100644
index 000000000..19d3d65a4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_binary.h
@@ -0,0 +1,261 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Array accessor classes for Binary, LargeBinart, String, LargeString,
+// FixedSizeBinary
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/data.h"
+#include "arrow/buffer.h"
+#include "arrow/stl_iterator.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_view.h" // IWYU pragma: export
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+/// \addtogroup binary-arrays
+///
+/// @{
+
+// ----------------------------------------------------------------------
+// Binary and String
+
+/// Base class for variable-sized binary arrays, regardless of offset size
+/// and logical interpretation.
+template <typename TYPE>
+class BaseBinaryArray : public FlatArray {
+ public:
+ using TypeClass = TYPE;
+ using offset_type = typename TypeClass::offset_type;
+ using IteratorType = stl::ArrayIterator<BaseBinaryArray<TYPE>>;
+
+ /// Return the pointer to the given elements bytes
+ // XXX should GetValue(int64_t i) return a string_view?
+ const uint8_t* GetValue(int64_t i, offset_type* out_length) const {
+ // Account for base offset
+ i += data_->offset;
+ const offset_type pos = raw_value_offsets_[i];
+ *out_length = raw_value_offsets_[i + 1] - pos;
+ return raw_data_ + pos;
+ }
+
+ /// \brief Get binary value as a string_view
+ ///
+ /// \param i the value index
+ /// \return the view over the selected value
+ util::string_view GetView(int64_t i) const {
+ // Account for base offset
+ i += data_->offset;
+ const offset_type pos = raw_value_offsets_[i];
+ return util::string_view(reinterpret_cast<const char*>(raw_data_ + pos),
+ raw_value_offsets_[i + 1] - pos);
+ }
+
+ /// \brief Get binary value as a string_view
+ /// Provided for consistency with other arrays.
+ ///
+ /// \param i the value index
+ /// \return the view over the selected value
+ util::string_view Value(int64_t i) const { return GetView(i); }
+
+ /// \brief Get binary value as a std::string
+ ///
+ /// \param i the value index
+ /// \return the value copied into a std::string
+ std::string GetString(int64_t i) const { return std::string(GetView(i)); }
+
+ /// Note that this buffer does not account for any slice offset
+ std::shared_ptr<Buffer> value_offsets() const { return data_->buffers[1]; }
+
+ /// Note that this buffer does not account for any slice offset
+ std::shared_ptr<Buffer> value_data() const { return data_->buffers[2]; }
+
+ const offset_type* raw_value_offsets() const {
+ return raw_value_offsets_ + data_->offset;
+ }
+
+ const uint8_t* raw_data() const { return raw_data_; }
+
+ /// \brief Return the data buffer absolute offset of the data for the value
+ /// at the passed index.
+ ///
+ /// Does not perform boundschecking
+ offset_type value_offset(int64_t i) const {
+ return raw_value_offsets_[i + data_->offset];
+ }
+
+ /// \brief Return the length of the data for the value at the passed index.
+ ///
+ /// Does not perform boundschecking
+ offset_type value_length(int64_t i) const {
+ i += data_->offset;
+ return raw_value_offsets_[i + 1] - raw_value_offsets_[i];
+ }
+
+ /// \brief Return the total length of the memory in the data buffer
+ /// referenced by this array. If the array has been sliced then this may be
+ /// less than the size of the data buffer (data_->buffers[2]).
+ offset_type total_values_length() const {
+ if (data_->length > 0) {
+ return raw_value_offsets_[data_->length + data_->offset] -
+ raw_value_offsets_[data_->offset];
+ } else {
+ return 0;
+ }
+ }
+
+ IteratorType begin() const { return IteratorType(*this); }
+
+ IteratorType end() const { return IteratorType(*this, length()); }
+
+ protected:
+ // For subclasses
+ BaseBinaryArray() = default;
+
+ // Protected method for constructors
+ void SetData(const std::shared_ptr<ArrayData>& data) {
+ this->Array::SetData(data);
+ raw_value_offsets_ = data->GetValuesSafe<offset_type>(1, /*offset=*/0);
+ raw_data_ = data->GetValuesSafe<uint8_t>(2, /*offset=*/0);
+ }
+
+ const offset_type* raw_value_offsets_ = NULLPTR;
+ const uint8_t* raw_data_ = NULLPTR;
+};
+
+/// Concrete Array class for variable-size binary data
+class ARROW_EXPORT BinaryArray : public BaseBinaryArray<BinaryType> {
+ public:
+ explicit BinaryArray(const std::shared_ptr<ArrayData>& data);
+
+ BinaryArray(int64_t length, const std::shared_ptr<Buffer>& value_offsets,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ protected:
+ // For subclasses such as StringArray
+ BinaryArray() : BaseBinaryArray() {}
+};
+
+/// Concrete Array class for variable-size string (utf-8) data
+class ARROW_EXPORT StringArray : public BinaryArray {
+ public:
+ using TypeClass = StringType;
+
+ explicit StringArray(const std::shared_ptr<ArrayData>& data);
+
+ StringArray(int64_t length, const std::shared_ptr<Buffer>& value_offsets,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ /// \brief Validate that this array contains only valid UTF8 entries
+ ///
+ /// This check is also implied by ValidateFull()
+ Status ValidateUTF8() const;
+};
+
+/// Concrete Array class for large variable-size binary data
+class ARROW_EXPORT LargeBinaryArray : public BaseBinaryArray<LargeBinaryType> {
+ public:
+ explicit LargeBinaryArray(const std::shared_ptr<ArrayData>& data);
+
+ LargeBinaryArray(int64_t length, const std::shared_ptr<Buffer>& value_offsets,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ protected:
+ // For subclasses such as LargeStringArray
+ LargeBinaryArray() : BaseBinaryArray() {}
+};
+
+/// Concrete Array class for large variable-size string (utf-8) data
+class ARROW_EXPORT LargeStringArray : public LargeBinaryArray {
+ public:
+ using TypeClass = LargeStringType;
+
+ explicit LargeStringArray(const std::shared_ptr<ArrayData>& data);
+
+ LargeStringArray(int64_t length, const std::shared_ptr<Buffer>& value_offsets,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ /// \brief Validate that this array contains only valid UTF8 entries
+ ///
+ /// This check is also implied by ValidateFull()
+ Status ValidateUTF8() const;
+};
+
+// ----------------------------------------------------------------------
+// Fixed width binary
+
+/// Concrete Array class for fixed-size binary data
+class ARROW_EXPORT FixedSizeBinaryArray : public PrimitiveArray {
+ public:
+ using TypeClass = FixedSizeBinaryType;
+ using IteratorType = stl::ArrayIterator<FixedSizeBinaryArray>;
+
+ explicit FixedSizeBinaryArray(const std::shared_ptr<ArrayData>& data);
+
+ FixedSizeBinaryArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ const uint8_t* GetValue(int64_t i) const;
+ const uint8_t* Value(int64_t i) const { return GetValue(i); }
+
+ util::string_view GetView(int64_t i) const {
+ return util::string_view(reinterpret_cast<const char*>(GetValue(i)), byte_width());
+ }
+
+ std::string GetString(int64_t i) const { return std::string(GetView(i)); }
+
+ int32_t byte_width() const { return byte_width_; }
+
+ const uint8_t* raw_values() const { return raw_values_ + data_->offset * byte_width_; }
+
+ IteratorType begin() const { return IteratorType(*this); }
+
+ IteratorType end() const { return IteratorType(*this, length()); }
+
+ protected:
+ void SetData(const std::shared_ptr<ArrayData>& data) {
+ this->PrimitiveArray::SetData(data);
+ byte_width_ =
+ internal::checked_cast<const FixedSizeBinaryType&>(*type()).byte_width();
+ }
+
+ int32_t byte_width_;
+};
+
+/// @}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_binary_test.cc b/src/arrow/cpp/src/arrow/array/array_binary_test.cc
new file mode 100644
index 000000000..6892e5f0a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_binary_test.cc
@@ -0,0 +1,900 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/buffer.h"
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_builders.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/string_view.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+// ----------------------------------------------------------------------
+// String / Binary tests
+
+template <typename ArrayType>
+void CheckStringArray(const ArrayType& array, const std::vector<std::string>& strings,
+ const std::vector<uint8_t>& is_valid, int repeats = 1) {
+ int64_t length = array.length();
+ int64_t base_length = static_cast<int64_t>(strings.size());
+ ASSERT_EQ(base_length, static_cast<int64_t>(is_valid.size()));
+ ASSERT_EQ(base_length * repeats, length);
+
+ int32_t value_pos = 0;
+ for (int i = 0; i < length; ++i) {
+ auto j = i % base_length;
+ if (is_valid[j]) {
+ ASSERT_FALSE(array.IsNull(i));
+ auto view = array.GetView(i);
+ ASSERT_EQ(value_pos, array.value_offset(i));
+ ASSERT_EQ(strings[j].size(), view.size());
+ ASSERT_EQ(util::string_view(strings[j]), view);
+ value_pos += static_cast<int32_t>(view.size());
+ } else {
+ ASSERT_TRUE(array.IsNull(i));
+ }
+ }
+}
+
+template <typename T>
+class TestStringArray : public ::testing::Test {
+ public:
+ using TypeClass = T;
+ using offset_type = typename TypeClass::offset_type;
+ using ArrayType = typename TypeTraits<TypeClass>::ArrayType;
+ using BuilderType = typename TypeTraits<TypeClass>::BuilderType;
+
+ void SetUp() {
+ chars_ = {'a', 'b', 'b', 'c', 'c', 'c'};
+ offsets_ = {0, 1, 1, 1, 3, 6};
+ valid_bytes_ = {1, 1, 0, 1, 1};
+ expected_ = {"a", "", "", "bb", "ccc"};
+
+ MakeArray();
+ }
+
+ void MakeArray() {
+ length_ = static_cast<int64_t>(offsets_.size()) - 1;
+ value_buf_ = Buffer::Wrap(chars_);
+ offsets_buf_ = Buffer::Wrap(offsets_);
+ ASSERT_OK_AND_ASSIGN(null_bitmap_, internal::BytesToBits(valid_bytes_));
+ null_count_ = CountNulls(valid_bytes_);
+
+ strings_ = std::make_shared<ArrayType>(length_, offsets_buf_, value_buf_,
+ null_bitmap_, null_count_);
+ }
+
+ void TestArrayBasics() {
+ ASSERT_EQ(length_, strings_->length());
+ ASSERT_EQ(1, strings_->null_count());
+ ASSERT_OK(strings_->ValidateFull());
+ TestInitialized(*strings_);
+ AssertZeroPadded(*strings_);
+ }
+
+ void TestArrayCtors() {
+ // ARROW-8863: ArrayData::null_count set to 0 when no validity bitmap
+ // provided
+ ArrayType arr(length_, offsets_buf_, value_buf_);
+ ASSERT_EQ(arr.data()->null_count, 0);
+ }
+
+ void TestTotalValuesLength() {
+ auto ty = TypeTraits<T>::type_singleton();
+ auto arr = ArrayFromJSON(ty, R"(["a", null, "bbb", "cccc", "ddddd"])");
+
+ offset_type values_length = arr.total_values_length();
+ ASSERT_EQ(values_length, static_cast<offset_type>(13));
+
+ offset_type sliced_values_length =
+ checked_cast<const ArrayType&>(*arr.Slice(3)).total_values_length();
+ ASSERT_EQ(sliced_values_length, static_cast<offset_type>(9));
+
+ // Zero-length array is a special case
+ offset_type zero_size_length =
+ checked_cast<const ArrayType&>(*arr.Slice(0, 0)).total_values_length();
+ ASSERT_EQ(zero_size_length, static_cast<offset_type>(0));
+ }
+
+ void TestType() {
+ std::shared_ptr<DataType> type = this->strings_->type();
+
+ if (std::is_same<TypeClass, StringType>::value) {
+ ASSERT_EQ(Type::STRING, type->id());
+ ASSERT_EQ(Type::STRING, this->strings_->type_id());
+ } else if (std::is_same<TypeClass, LargeStringType>::value) {
+ ASSERT_EQ(Type::LARGE_STRING, type->id());
+ ASSERT_EQ(Type::LARGE_STRING, this->strings_->type_id());
+ } else if (std::is_same<TypeClass, BinaryType>::value) {
+ ASSERT_EQ(Type::BINARY, type->id());
+ ASSERT_EQ(Type::BINARY, this->strings_->type_id());
+ } else if (std::is_same<TypeClass, LargeBinaryType>::value) {
+ ASSERT_EQ(Type::LARGE_BINARY, type->id());
+ ASSERT_EQ(Type::LARGE_BINARY, this->strings_->type_id());
+ } else {
+ FAIL();
+ }
+ }
+
+ void TestListFunctions() {
+ int64_t pos = 0;
+ for (size_t i = 0; i < expected_.size(); ++i) {
+ ASSERT_EQ(pos, strings_->value_offset(i));
+ ASSERT_EQ(expected_[i].size(), strings_->value_length(i));
+ pos += expected_[i].size();
+ }
+ }
+
+ void TestDestructor() {
+ auto arr = std::make_shared<ArrayType>(length_, offsets_buf_, value_buf_,
+ null_bitmap_, null_count_);
+ }
+
+ void TestGetString() {
+ for (size_t i = 0; i < expected_.size(); ++i) {
+ if (valid_bytes_[i] == 0) {
+ ASSERT_TRUE(strings_->IsNull(i));
+ } else {
+ ASSERT_FALSE(strings_->IsNull(i));
+ ASSERT_EQ(expected_[i], strings_->GetString(i));
+ }
+ }
+ }
+
+ void TestEmptyStringComparison() {
+ offsets_ = {0, 0, 0, 0, 0, 0};
+ offsets_buf_ = Buffer::Wrap(offsets_);
+ length_ = static_cast<int64_t>(offsets_.size() - 1);
+
+ auto strings_a = std::make_shared<ArrayType>(length_, offsets_buf_, nullptr,
+ null_bitmap_, null_count_);
+ auto strings_b = std::make_shared<ArrayType>(length_, offsets_buf_, nullptr,
+ null_bitmap_, null_count_);
+ ASSERT_TRUE(strings_a->Equals(strings_b));
+ }
+
+ void TestCompareNullByteSlots() {
+ BuilderType builder;
+ BuilderType builder2;
+ BuilderType builder3;
+
+ ASSERT_OK(builder.Append("foo"));
+ ASSERT_OK(builder2.Append("foo"));
+ ASSERT_OK(builder3.Append("foo"));
+
+ ASSERT_OK(builder.Append("bar"));
+ ASSERT_OK(builder2.AppendNull());
+
+ // same length, but different
+ ASSERT_OK(builder3.Append("xyz"));
+
+ ASSERT_OK(builder.Append("baz"));
+ ASSERT_OK(builder2.Append("baz"));
+ ASSERT_OK(builder3.Append("baz"));
+
+ std::shared_ptr<Array> array, array2, array3;
+ FinishAndCheckPadding(&builder, &array);
+ ASSERT_OK(builder2.Finish(&array2));
+ ASSERT_OK(builder3.Finish(&array3));
+
+ const auto& a1 = checked_cast<const ArrayType&>(*array);
+ const auto& a2 = checked_cast<const ArrayType&>(*array2);
+ const auto& a3 = checked_cast<const ArrayType&>(*array3);
+
+ // The validity bitmaps are the same, the data is different, but the unequal
+ // portion is masked out
+ ArrayType equal_array(3, a1.value_offsets(), a1.value_data(), a2.null_bitmap(), 1);
+ ArrayType equal_array2(3, a3.value_offsets(), a3.value_data(), a2.null_bitmap(), 1);
+
+ ASSERT_TRUE(equal_array.Equals(equal_array2));
+ ASSERT_TRUE(a2.RangeEquals(equal_array2, 0, 3, 0));
+
+ ASSERT_TRUE(equal_array.Array::Slice(1)->Equals(equal_array2.Array::Slice(1)));
+ ASSERT_TRUE(
+ equal_array.Array::Slice(1)->RangeEquals(0, 2, 0, equal_array2.Array::Slice(1)));
+ }
+
+ void TestSliceGetString() {
+ BuilderType builder;
+
+ ASSERT_OK(builder.Append("a"));
+ ASSERT_OK(builder.Append("b"));
+ ASSERT_OK(builder.Append("c"));
+
+ std::shared_ptr<Array> array;
+ ASSERT_OK(builder.Finish(&array));
+ auto s = array->Slice(1, 10);
+ auto arr = std::dynamic_pointer_cast<ArrayType>(s);
+ ASSERT_EQ(arr->GetString(0), "b");
+ }
+
+ Status ValidateFull(int64_t length, std::vector<offset_type> offsets,
+ util::string_view data, int64_t offset = 0) {
+ ArrayType arr(length, Buffer::Wrap(offsets), std::make_shared<Buffer>(data),
+ /*null_bitmap=*/nullptr, /*null_count=*/0, offset);
+ return arr.ValidateFull();
+ }
+
+ Status ValidateFull(const std::string& json) {
+ auto ty = TypeTraits<T>::type_singleton();
+ auto arr = ArrayFromJSON(ty, json);
+ return arr->ValidateFull();
+ }
+
+ void TestValidateOffsets() {
+ ASSERT_OK(ValidateFull(0, {0}, ""));
+ ASSERT_OK(ValidateFull(1, {0, 4}, "data"));
+ ASSERT_OK(ValidateFull(2, {0, 4, 4}, "data"));
+ ASSERT_OK(ValidateFull(2, {0, 5, 9}, "some data"));
+
+ // Non-zero array offset
+ ASSERT_OK(ValidateFull(0, {0, 4}, "data", 1));
+ ASSERT_OK(ValidateFull(1, {0, 5, 9}, "some data", 1));
+ ASSERT_OK(ValidateFull(0, {0, 5, 9}, "some data", 2));
+
+ // Not enough offsets
+ ASSERT_RAISES(Invalid, ValidateFull(1, {}, ""));
+ ASSERT_RAISES(Invalid, ValidateFull(1, {0}, ""));
+ ASSERT_RAISES(Invalid, ValidateFull(2, {0, 4}, "data"));
+ ASSERT_RAISES(Invalid, ValidateFull(1, {0, 4}, "data", 1));
+
+ // Offset out of bounds
+ ASSERT_RAISES(Invalid, ValidateFull(1, {0, 5}, "data"));
+ // Negative offset
+ ASSERT_RAISES(Invalid, ValidateFull(1, {-1, 0}, "data"));
+ ASSERT_RAISES(Invalid, ValidateFull(1, {0, -1}, "data"));
+ ASSERT_RAISES(Invalid, ValidateFull(1, {0, -1, -1}, "data", 1));
+ // Offsets non-monotonic
+ ASSERT_RAISES(Invalid, ValidateFull(2, {0, 5, 4}, "some data"));
+ }
+
+ void TestValidateData() {
+ // Valid UTF8
+ ASSERT_OK(ValidateFull(R"(["Voix", "ambiguë", "d’un", "cœur"])"));
+ ASSERT_OK(ValidateFull(R"(["いろはにほへと", "ちりぬるを", "わかよたれそ"])"));
+ ASSERT_OK(ValidateFull(R"(["😀", "😄"])"));
+ ASSERT_OK(ValidateFull(1, {0, 4}, "\xf4\x8f\xbf\xbf")); // \U0010ffff
+
+ // Invalid UTF8
+ auto ty = TypeTraits<T>::type_singleton();
+ auto st1 = ValidateFull(3, {0, 4, 6, 9}, "abc \xff def");
+ // Hypothetical \U00110000
+ auto st2 = ValidateFull(1, {0, 4}, "\xf4\x90\x80\x80");
+ // Single UTF8 character straddles two entries
+ auto st3 = ValidateFull(2, {0, 1, 2}, "\xc3\xa9");
+ if (T::is_utf8) {
+ ASSERT_RAISES(Invalid, st1);
+ ASSERT_RAISES(Invalid, st2);
+ ASSERT_RAISES(Invalid, st3);
+ } else {
+ ASSERT_OK(st1);
+ ASSERT_OK(st2);
+ ASSERT_OK(st3);
+ }
+ }
+
+ protected:
+ std::vector<offset_type> offsets_;
+ std::vector<char> chars_;
+ std::vector<uint8_t> valid_bytes_;
+
+ std::vector<std::string> expected_;
+
+ std::shared_ptr<Buffer> value_buf_;
+ std::shared_ptr<Buffer> offsets_buf_;
+ std::shared_ptr<Buffer> null_bitmap_;
+
+ int64_t null_count_;
+ int64_t length_;
+
+ std::shared_ptr<ArrayType> strings_;
+};
+
+TYPED_TEST_SUITE(TestStringArray, BinaryArrowTypes);
+
+TYPED_TEST(TestStringArray, TestArrayBasics) { this->TestArrayBasics(); }
+
+TYPED_TEST(TestStringArray, TestArrayCtors) { this->TestArrayCtors(); }
+
+TYPED_TEST(TestStringArray, TestType) { this->TestType(); }
+
+TYPED_TEST(TestStringArray, TestListFunctions) { this->TestListFunctions(); }
+
+TYPED_TEST(TestStringArray, TestDestructor) { this->TestDestructor(); }
+
+TYPED_TEST(TestStringArray, TestGetString) { this->TestGetString(); }
+
+TYPED_TEST(TestStringArray, TestEmptyStringComparison) {
+ this->TestEmptyStringComparison();
+}
+
+TYPED_TEST(TestStringArray, CompareNullByteSlots) { this->TestCompareNullByteSlots(); }
+
+TYPED_TEST(TestStringArray, TestSliceGetString) { this->TestSliceGetString(); }
+
+TYPED_TEST(TestStringArray, TestValidateOffsets) { this->TestValidateOffsets(); }
+
+TYPED_TEST(TestStringArray, TestValidateData) { this->TestValidateData(); }
+
+template <typename T>
+class TestUTF8Array : public ::testing::Test {
+ public:
+ using TypeClass = T;
+ using offset_type = typename TypeClass::offset_type;
+ using ArrayType = typename TypeTraits<TypeClass>::ArrayType;
+
+ Status ValidateUTF8(int64_t length, std::vector<offset_type> offsets,
+ util::string_view data, int64_t offset = 0) {
+ ArrayType arr(length, Buffer::Wrap(offsets), std::make_shared<Buffer>(data),
+ /*null_bitmap=*/nullptr, /*null_count=*/0, offset);
+ return arr.ValidateUTF8();
+ }
+
+ Status ValidateUTF8(const std::string& json) {
+ auto ty = TypeTraits<T>::type_singleton();
+ auto arr = ArrayFromJSON(ty, json);
+ return checked_cast<const ArrayType&>(*arr).ValidateUTF8();
+ }
+
+ void TestValidateUTF8() {
+ ASSERT_OK(ValidateUTF8(R"(["Voix", "ambiguë", "d’un", "cœur"])"));
+ ASSERT_OK(ValidateUTF8(1, {0, 4}, "\xf4\x8f\xbf\xbf")); // \U0010ffff
+
+ ASSERT_RAISES(Invalid, ValidateUTF8(1, {0, 1}, "\xf4"));
+
+ // More tests in TestValidateData() above
+ // (ValidateFull() calls ValidateUTF8() internally)
+ }
+};
+
+TYPED_TEST_SUITE(TestUTF8Array, StringArrowTypes);
+
+TYPED_TEST(TestUTF8Array, TestValidateUTF8) { this->TestValidateUTF8(); }
+
+// ----------------------------------------------------------------------
+// String builder tests
+
+template <typename T>
+class TestStringBuilder : public TestBuilder {
+ public:
+ using TypeClass = T;
+ using offset_type = typename TypeClass::offset_type;
+ using ArrayType = typename TypeTraits<TypeClass>::ArrayType;
+ using BuilderType = typename TypeTraits<TypeClass>::BuilderType;
+
+ void SetUp() {
+ TestBuilder::SetUp();
+ builder_.reset(new BuilderType(pool_));
+ }
+
+ void Done() {
+ std::shared_ptr<Array> out;
+ FinishAndCheckPadding(builder_.get(), &out);
+
+ result_ = std::dynamic_pointer_cast<ArrayType>(out);
+ ASSERT_OK(result_->ValidateFull());
+ }
+
+ void TestScalarAppend() {
+ std::vector<std::string> strings = {"", "bb", "a", "", "ccc"};
+ std::vector<uint8_t> is_valid = {1, 1, 1, 0, 1};
+
+ int N = static_cast<int>(strings.size());
+ int reps = 10;
+
+ for (int j = 0; j < reps; ++j) {
+ for (int i = 0; i < N; ++i) {
+ if (!is_valid[i]) {
+ ASSERT_OK(builder_->AppendNull());
+ } else {
+ ASSERT_OK(builder_->Append(strings[i]));
+ }
+ }
+ }
+ Done();
+
+ ASSERT_EQ(reps * N, result_->length());
+ ASSERT_EQ(reps, result_->null_count());
+ ASSERT_EQ(reps * 6, result_->value_data()->size());
+
+ CheckStringArray(*result_, strings, is_valid, reps);
+ }
+
+ void TestScalarAppendUnsafe() {
+ std::vector<std::string> strings = {"", "bb", "a", "", "ccc"};
+ std::vector<uint8_t> is_valid = {1, 1, 1, 0, 1};
+
+ int N = static_cast<int>(strings.size());
+ int reps = 13;
+ int64_t total_length = 0;
+ for (const auto& s : strings) {
+ total_length += static_cast<int64_t>(s.size());
+ }
+
+ ASSERT_OK(builder_->Reserve(N * reps));
+ ASSERT_OK(builder_->ReserveData(total_length * reps));
+
+ for (int j = 0; j < reps; ++j) {
+ for (int i = 0; i < N; ++i) {
+ if (!is_valid[i]) {
+ builder_->UnsafeAppendNull();
+ } else {
+ builder_->UnsafeAppend(strings[i]);
+ }
+ }
+ }
+ ASSERT_EQ(builder_->value_data_length(), total_length * reps);
+ Done();
+
+ ASSERT_OK(result_->ValidateFull());
+ ASSERT_EQ(reps * N, result_->length());
+ ASSERT_EQ(reps, result_->null_count());
+ ASSERT_EQ(reps * total_length, result_->value_data()->size());
+
+ CheckStringArray(*result_, strings, is_valid, reps);
+ }
+
+ void TestExtendCurrent() {
+ std::vector<std::string> strings = {"", "bbbb", "aaaaa", "", "ccc"};
+ std::vector<uint8_t> is_valid = {1, 1, 1, 0, 1};
+
+ int N = static_cast<int>(strings.size());
+ int reps = 10;
+
+ for (int j = 0; j < reps; ++j) {
+ for (int i = 0; i < N; ++i) {
+ if (!is_valid[i]) {
+ ASSERT_OK(builder_->AppendNull());
+ } else if (strings[i].length() > 3) {
+ ASSERT_OK(builder_->Append(strings[i].substr(0, 3)));
+ ASSERT_OK(builder_->ExtendCurrent(strings[i].substr(3)));
+ } else {
+ ASSERT_OK(builder_->Append(strings[i]));
+ }
+ }
+ }
+ Done();
+
+ ASSERT_EQ(reps * N, result_->length());
+ ASSERT_EQ(reps, result_->null_count());
+ ASSERT_EQ(reps * 12, result_->value_data()->size());
+
+ CheckStringArray(*result_, strings, is_valid, reps);
+ }
+
+ void TestExtendCurrentUnsafe() {
+ std::vector<std::string> strings = {"", "bbbb", "aaaaa", "", "ccc"};
+ std::vector<uint8_t> is_valid = {1, 1, 1, 0, 1};
+
+ int N = static_cast<int>(strings.size());
+ int reps = 13;
+ int64_t total_length = 0;
+ for (const auto& s : strings) {
+ total_length += static_cast<int64_t>(s.size());
+ }
+
+ ASSERT_OK(builder_->Reserve(N * reps));
+ ASSERT_OK(builder_->ReserveData(total_length * reps));
+
+ for (int j = 0; j < reps; ++j) {
+ for (int i = 0; i < N; ++i) {
+ if (!is_valid[i]) {
+ builder_->UnsafeAppendNull();
+ } else if (strings[i].length() > 3) {
+ builder_->UnsafeAppend(strings[i].substr(0, 3));
+ builder_->UnsafeExtendCurrent(strings[i].substr(3));
+ } else {
+ builder_->UnsafeAppend(strings[i]);
+ }
+ }
+ }
+ ASSERT_EQ(builder_->value_data_length(), total_length * reps);
+ Done();
+
+ ASSERT_EQ(reps * N, result_->length());
+ ASSERT_EQ(reps, result_->null_count());
+ ASSERT_EQ(reps * 12, result_->value_data()->size());
+
+ CheckStringArray(*result_, strings, is_valid, reps);
+ }
+
+ void TestVectorAppend() {
+ std::vector<std::string> strings = {"", "bb", "a", "", "ccc"};
+ std::vector<uint8_t> valid_bytes = {1, 1, 1, 0, 1};
+
+ int N = static_cast<int>(strings.size());
+ int reps = 1000;
+
+ for (int j = 0; j < reps; ++j) {
+ ASSERT_OK(builder_->AppendValues(strings, valid_bytes.data()));
+ }
+ Done();
+
+ ASSERT_EQ(reps * N, result_->length());
+ ASSERT_EQ(reps, result_->null_count());
+ ASSERT_EQ(reps * 6, result_->value_data()->size());
+
+ CheckStringArray(*result_, strings, valid_bytes, reps);
+ }
+
+ void TestAppendCStringsWithValidBytes() {
+ const char* strings[] = {nullptr, "aaa", nullptr, "ignored", ""};
+ std::vector<uint8_t> valid_bytes = {1, 1, 1, 0, 1};
+
+ int N = static_cast<int>(sizeof(strings) / sizeof(strings[0]));
+ int reps = 1000;
+
+ for (int j = 0; j < reps; ++j) {
+ ASSERT_OK(builder_->AppendValues(strings, N, valid_bytes.data()));
+ }
+ Done();
+
+ ASSERT_EQ(reps * N, result_->length());
+ ASSERT_EQ(reps * 3, result_->null_count());
+ ASSERT_EQ(reps * 3, result_->value_data()->size());
+
+ CheckStringArray(*result_, {"", "aaa", "", "", ""}, {0, 1, 0, 0, 1}, reps);
+ }
+
+ void TestAppendCStringsWithoutValidBytes() {
+ const char* strings[] = {"", "bb", "a", nullptr, "ccc"};
+
+ int N = static_cast<int>(sizeof(strings) / sizeof(strings[0]));
+ int reps = 1000;
+
+ for (int j = 0; j < reps; ++j) {
+ ASSERT_OK(builder_->AppendValues(strings, N));
+ }
+ Done();
+
+ ASSERT_EQ(reps * N, result_->length());
+ ASSERT_EQ(reps, result_->null_count());
+ ASSERT_EQ(reps * 6, result_->value_data()->size());
+
+ CheckStringArray(*result_, {"", "bb", "a", "", "ccc"}, {1, 1, 1, 0, 1}, reps);
+ }
+
+ void TestCapacityReserve() {
+ std::vector<std::string> strings = {"aaaaa", "bbbbbbbbbb", "ccccccccccccccc",
+ "dddddddddd"};
+ int N = static_cast<int>(strings.size());
+ int reps = 15;
+ int64_t length = 0;
+ int64_t capacity = 1000;
+ int64_t expected_capacity = BitUtil::RoundUpToMultipleOf64(capacity);
+
+ ASSERT_OK(builder_->ReserveData(capacity));
+
+ ASSERT_EQ(length, builder_->value_data_length());
+ ASSERT_EQ(expected_capacity, builder_->value_data_capacity());
+
+ for (int j = 0; j < reps; ++j) {
+ for (int i = 0; i < N; ++i) {
+ ASSERT_OK(builder_->Append(strings[i]));
+ length += static_cast<int64_t>(strings[i].size());
+
+ ASSERT_EQ(length, builder_->value_data_length());
+ ASSERT_EQ(expected_capacity, builder_->value_data_capacity());
+ }
+ }
+
+ int extra_capacity = 500;
+ expected_capacity = BitUtil::RoundUpToMultipleOf64(length + extra_capacity);
+
+ ASSERT_OK(builder_->ReserveData(extra_capacity));
+
+ ASSERT_EQ(length, builder_->value_data_length());
+ int64_t actual_capacity = builder_->value_data_capacity();
+ ASSERT_GE(actual_capacity, expected_capacity);
+ ASSERT_EQ(actual_capacity & 63, 0);
+
+ Done();
+
+ ASSERT_EQ(reps * N, result_->length());
+ ASSERT_EQ(0, result_->null_count());
+ ASSERT_EQ(reps * 40, result_->value_data()->size());
+ }
+
+ void TestOverflowCheck() {
+ auto max_size = builder_->memory_limit();
+
+ ASSERT_OK(builder_->ValidateOverflow(1));
+ ASSERT_OK(builder_->ValidateOverflow(max_size));
+ ASSERT_RAISES(CapacityError, builder_->ValidateOverflow(max_size + 1));
+
+ ASSERT_OK(builder_->Append("bb"));
+ ASSERT_OK(builder_->ValidateOverflow(max_size - 2));
+ ASSERT_RAISES(CapacityError, builder_->ValidateOverflow(max_size - 1));
+
+ ASSERT_OK(builder_->AppendNull());
+ ASSERT_OK(builder_->ValidateOverflow(max_size - 2));
+ ASSERT_RAISES(CapacityError, builder_->ValidateOverflow(max_size - 1));
+
+ ASSERT_OK(builder_->Append("ccc"));
+ ASSERT_OK(builder_->ValidateOverflow(max_size - 5));
+ ASSERT_RAISES(CapacityError, builder_->ValidateOverflow(max_size - 4));
+ }
+
+ void TestZeroLength() {
+ // All buffers are null
+ Done();
+ ASSERT_EQ(result_->length(), 0);
+ ASSERT_EQ(result_->null_count(), 0);
+ }
+
+ protected:
+ std::unique_ptr<BuilderType> builder_;
+ std::shared_ptr<ArrayType> result_;
+};
+
+TYPED_TEST_SUITE(TestStringBuilder, BinaryArrowTypes);
+
+TYPED_TEST(TestStringBuilder, TestScalarAppend) { this->TestScalarAppend(); }
+
+TYPED_TEST(TestStringBuilder, TestScalarAppendUnsafe) { this->TestScalarAppendUnsafe(); }
+
+TYPED_TEST(TestStringBuilder, TestExtendCurrent) { this->TestExtendCurrent(); }
+
+TYPED_TEST(TestStringBuilder, TestExtendCurrentUnsafe) {
+ this->TestExtendCurrentUnsafe();
+}
+
+TYPED_TEST(TestStringBuilder, TestVectorAppend) { this->TestVectorAppend(); }
+
+TYPED_TEST(TestStringBuilder, TestAppendCStringsWithValidBytes) {
+ this->TestAppendCStringsWithValidBytes();
+}
+
+TYPED_TEST(TestStringBuilder, TestAppendCStringsWithoutValidBytes) {
+ this->TestAppendCStringsWithoutValidBytes();
+}
+
+TYPED_TEST(TestStringBuilder, TestCapacityReserve) { this->TestCapacityReserve(); }
+
+TYPED_TEST(TestStringBuilder, TestZeroLength) { this->TestZeroLength(); }
+
+TYPED_TEST(TestStringBuilder, TestOverflowCheck) { this->TestOverflowCheck(); }
+
+// ----------------------------------------------------------------------
+// ChunkedBinaryBuilder tests
+
+class TestChunkedBinaryBuilder : public ::testing::Test {
+ public:
+ void SetUp() {}
+
+ void Init(int32_t chunksize) {
+ builder_.reset(new internal::ChunkedBinaryBuilder(chunksize));
+ }
+
+ void Init(int32_t chunksize, int32_t chunklength) {
+ builder_.reset(new internal::ChunkedBinaryBuilder(chunksize, chunklength));
+ }
+
+ protected:
+ std::unique_ptr<internal::ChunkedBinaryBuilder> builder_;
+};
+
+TEST_F(TestChunkedBinaryBuilder, BasicOperation) {
+ const int32_t chunksize = 1000;
+ Init(chunksize);
+
+ const int elem_size = 10;
+ uint8_t buf[elem_size];
+
+ BinaryBuilder unchunked_builder;
+
+ const int iterations = 1000;
+ for (int i = 0; i < iterations; ++i) {
+ random_bytes(elem_size, i, buf);
+
+ ASSERT_OK(unchunked_builder.Append(buf, elem_size));
+ ASSERT_OK(builder_->Append(buf, elem_size));
+ }
+
+ std::shared_ptr<Array> unchunked;
+ ASSERT_OK(unchunked_builder.Finish(&unchunked));
+
+ ArrayVector chunks;
+ ASSERT_OK(builder_->Finish(&chunks));
+
+ // This assumes that everything is evenly divisible
+ ArrayVector expected_chunks;
+ const int elems_per_chunk = chunksize / elem_size;
+ for (int i = 0; i < iterations / elems_per_chunk; ++i) {
+ expected_chunks.emplace_back(unchunked->Slice(i * elems_per_chunk, elems_per_chunk));
+ }
+
+ ASSERT_EQ(expected_chunks.size(), chunks.size());
+ for (size_t i = 0; i < chunks.size(); ++i) {
+ AssertArraysEqual(*expected_chunks[i], *chunks[i]);
+ }
+}
+
+TEST_F(TestChunkedBinaryBuilder, Reserve) {
+ // ARROW-6060
+ const int32_t chunksize = 1000;
+ Init(chunksize);
+ ASSERT_OK(builder_->Reserve(chunksize / 2));
+ auto bytes_after_first_reserve = default_memory_pool()->bytes_allocated();
+ for (int i = 0; i < 8; ++i) {
+ ASSERT_OK(builder_->Reserve(chunksize / 2));
+ }
+ // no new memory will be allocated since capacity was sufficient for the loop's
+ // Reserve() calls
+ ASSERT_EQ(default_memory_pool()->bytes_allocated(), bytes_after_first_reserve);
+}
+
+TEST_F(TestChunkedBinaryBuilder, NoData) {
+ Init(1000);
+
+ ArrayVector chunks;
+ ASSERT_OK(builder_->Finish(&chunks));
+
+ ASSERT_EQ(1, chunks.size());
+ ASSERT_EQ(0, chunks[0]->length());
+}
+
+TEST_F(TestChunkedBinaryBuilder, LargeElements) {
+ Init(100);
+
+ const int bufsize = 101;
+ uint8_t buf[bufsize];
+
+ const int iterations = 100;
+ for (int i = 0; i < iterations; ++i) {
+ random_bytes(bufsize, i, buf);
+ ASSERT_OK(builder_->Append(buf, bufsize));
+ }
+
+ ArrayVector chunks;
+ ASSERT_OK(builder_->Finish(&chunks));
+ ASSERT_EQ(iterations, static_cast<int>(chunks.size()));
+
+ int64_t total_data_size = 0;
+ for (auto chunk : chunks) {
+ ASSERT_EQ(1, chunk->length());
+ total_data_size +=
+ static_cast<int64_t>(static_cast<const BinaryArray&>(*chunk).GetView(0).size());
+ }
+ ASSERT_EQ(iterations * bufsize, total_data_size);
+}
+
+TEST_F(TestChunkedBinaryBuilder, LargeElementCount) {
+ int32_t max_chunk_length = 100;
+ Init(100, max_chunk_length);
+
+ auto length = max_chunk_length + 1;
+
+ // ChunkedBinaryBuilder can reserve memory for more than its configured maximum
+ // (per chunk) element count
+ ASSERT_OK(builder_->Reserve(length));
+
+ for (int64_t i = 0; i < 2 * length; ++i) {
+ // Appending more elements than have been reserved memory simply overflows to the next
+ // chunk
+ ASSERT_OK(builder_->Append(""));
+ }
+
+ ArrayVector chunks;
+ ASSERT_OK(builder_->Finish(&chunks));
+
+ // should have two chunks full of empty strings and another with two more empty strings
+ ASSERT_EQ(chunks.size(), 3);
+ ASSERT_EQ(chunks[0]->length(), max_chunk_length);
+ ASSERT_EQ(chunks[1]->length(), max_chunk_length);
+ ASSERT_EQ(chunks[2]->length(), 2);
+ for (auto&& boxed_chunk : chunks) {
+ const auto& chunk = checked_cast<const BinaryArray&>(*boxed_chunk);
+ ASSERT_EQ(chunk.value_offset(0), chunk.value_offset(chunk.length()));
+ }
+}
+
+TEST(TestChunkedStringBuilder, BasicOperation) {
+ const int chunksize = 100;
+ internal::ChunkedStringBuilder builder(chunksize);
+
+ std::string value = "0123456789";
+
+ const int iterations = 100;
+ for (int i = 0; i < iterations; ++i) {
+ ASSERT_OK(builder.Append(value));
+ }
+
+ ArrayVector chunks;
+ ASSERT_OK(builder.Finish(&chunks));
+
+ ASSERT_EQ(10, chunks.size());
+
+ // Type is correct
+ for (auto chunk : chunks) {
+ ASSERT_TRUE(chunk->type()->Equals(utf8()));
+ }
+}
+
+// ----------------------------------------------------------------------
+// ArrayDataVisitor<binary-like> tests
+
+struct BinaryAppender {
+ Status VisitNull() {
+ data.emplace_back("(null)");
+ return Status::OK();
+ }
+
+ Status VisitValue(util::string_view v) {
+ data.push_back(v);
+ return Status::OK();
+ }
+
+ std::vector<util::string_view> data;
+};
+
+template <typename T>
+class TestBinaryDataVisitor : public ::testing::Test {
+ public:
+ using TypeClass = T;
+
+ void SetUp() override { type_ = TypeTraits<TypeClass>::type_singleton(); }
+
+ void TestBasics() {
+ auto array = ArrayFromJSON(type_, R"(["foo", null, "bar"])");
+ BinaryAppender appender;
+ ArrayDataVisitor<TypeClass> visitor;
+ ASSERT_OK(visitor.Visit(*array->data(), &appender));
+ ASSERT_THAT(appender.data, ::testing::ElementsAreArray({"foo", "(null)", "bar"}));
+ ARROW_UNUSED(visitor); // Workaround weird MSVC warning
+ }
+
+ void TestSliced() {
+ auto array = ArrayFromJSON(type_, R"(["ab", null, "cd", "ef"])")->Slice(1, 2);
+ BinaryAppender appender;
+ ArrayDataVisitor<TypeClass> visitor;
+ ASSERT_OK(visitor.Visit(*array->data(), &appender));
+ ASSERT_THAT(appender.data, ::testing::ElementsAreArray({"(null)", "cd"}));
+ ARROW_UNUSED(visitor); // Workaround weird MSVC warning
+ }
+
+ protected:
+ std::shared_ptr<DataType> type_;
+};
+
+TYPED_TEST_SUITE(TestBinaryDataVisitor, BinaryArrowTypes);
+
+TYPED_TEST(TestBinaryDataVisitor, Basics) { this->TestBasics(); }
+
+TYPED_TEST(TestBinaryDataVisitor, Sliced) { this->TestSliced(); }
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_decimal.cc b/src/arrow/cpp/src/arrow/array/array_decimal.cc
new file mode 100644
index 000000000..d65f6ee53
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_decimal.cc
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/array_decimal.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "arrow/array/array_binary.h"
+#include "arrow/array/data.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+// ----------------------------------------------------------------------
+// Decimal128
+
+Decimal128Array::Decimal128Array(const std::shared_ptr<ArrayData>& data)
+ : FixedSizeBinaryArray(data) {
+ ARROW_CHECK_EQ(data->type->id(), Type::DECIMAL128);
+}
+
+std::string Decimal128Array::FormatValue(int64_t i) const {
+ const auto& type_ = checked_cast<const Decimal128Type&>(*type());
+ const Decimal128 value(GetValue(i));
+ return value.ToString(type_.scale());
+}
+
+// ----------------------------------------------------------------------
+// Decimal256
+
+Decimal256Array::Decimal256Array(const std::shared_ptr<ArrayData>& data)
+ : FixedSizeBinaryArray(data) {
+ ARROW_CHECK_EQ(data->type->id(), Type::DECIMAL256);
+}
+
+std::string Decimal256Array::FormatValue(int64_t i) const {
+ const auto& type_ = checked_cast<const Decimal256Type&>(*type());
+ const Decimal256 value(GetValue(i));
+ return value.ToString(type_.scale());
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_decimal.h b/src/arrow/cpp/src/arrow/array/array_decimal.h
new file mode 100644
index 000000000..f14812549
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_decimal.h
@@ -0,0 +1,72 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "arrow/array/array_binary.h"
+#include "arrow/array/data.h"
+#include "arrow/type.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+/// \addtogroup numeric-arrays
+///
+/// @{
+
+// ----------------------------------------------------------------------
+// Decimal128Array
+
+/// Concrete Array class for 128-bit decimal data
+class ARROW_EXPORT Decimal128Array : public FixedSizeBinaryArray {
+ public:
+ using TypeClass = Decimal128Type;
+
+ using FixedSizeBinaryArray::FixedSizeBinaryArray;
+
+ /// \brief Construct Decimal128Array from ArrayData instance
+ explicit Decimal128Array(const std::shared_ptr<ArrayData>& data);
+
+ std::string FormatValue(int64_t i) const;
+};
+
+// Backward compatibility
+using DecimalArray = Decimal128Array;
+
+// ----------------------------------------------------------------------
+// Decimal256Array
+
+/// Concrete Array class for 256-bit decimal data
+class ARROW_EXPORT Decimal256Array : public FixedSizeBinaryArray {
+ public:
+ using TypeClass = Decimal256Type;
+
+ using FixedSizeBinaryArray::FixedSizeBinaryArray;
+
+ /// \brief Construct Decimal256Array from ArrayData instance
+ explicit Decimal256Array(const std::shared_ptr<ArrayData>& data);
+
+ std::string FormatValue(int64_t i) const;
+};
+
+/// @}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_dict.cc b/src/arrow/cpp/src/arrow/array/array_dict.cc
new file mode 100644
index 000000000..2fa95e9a1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_dict.cc
@@ -0,0 +1,442 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/array_dict.h"
+
+#include <algorithm>
+#include <climits>
+#include <cstdint>
+#include <limits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/array_primitive.h"
+#include "arrow/array/data.h"
+#include "arrow/array/dict_internal.h"
+#include "arrow/array/util.h"
+#include "arrow/buffer.h"
+#include "arrow/chunked_array.h"
+#include "arrow/datum.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/int_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::CopyBitmap;
+
+// ----------------------------------------------------------------------
+// DictionaryArray
+
+std::shared_ptr<Array> DictionaryArray::indices() const { return indices_; }
+
+int64_t DictionaryArray::GetValueIndex(int64_t i) const {
+ const uint8_t* indices_data = data_->buffers[1]->data();
+ // If the value is non-negative then we can use the unsigned path
+ switch (indices_->type_id()) {
+ case Type::UINT8:
+ case Type::INT8:
+ return static_cast<int64_t>(indices_data[data_->offset + i]);
+ case Type::UINT16:
+ case Type::INT16:
+ return static_cast<int64_t>(
+ reinterpret_cast<const uint16_t*>(indices_data)[data_->offset + i]);
+ case Type::UINT32:
+ case Type::INT32:
+ return static_cast<int64_t>(
+ reinterpret_cast<const uint32_t*>(indices_data)[data_->offset + i]);
+ case Type::UINT64:
+ case Type::INT64:
+ return static_cast<int64_t>(
+ reinterpret_cast<const uint64_t*>(indices_data)[data_->offset + i]);
+ default:
+ ARROW_CHECK(false) << "unreachable";
+ return -1;
+ }
+}
+
+DictionaryArray::DictionaryArray(const std::shared_ptr<ArrayData>& data)
+ : dict_type_(checked_cast<const DictionaryType*>(data->type.get())) {
+ ARROW_CHECK_EQ(data->type->id(), Type::DICTIONARY);
+ ARROW_CHECK_NE(data->dictionary, nullptr);
+ SetData(data);
+}
+
+void DictionaryArray::SetData(const std::shared_ptr<ArrayData>& data) {
+ this->Array::SetData(data);
+ auto indices_data = data_->Copy();
+ indices_data->type = dict_type_->index_type();
+ indices_data->dictionary = nullptr;
+ indices_ = MakeArray(indices_data);
+}
+
+DictionaryArray::DictionaryArray(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<Array>& indices,
+ const std::shared_ptr<Array>& dictionary)
+ : dict_type_(checked_cast<const DictionaryType*>(type.get())) {
+ ARROW_CHECK_EQ(type->id(), Type::DICTIONARY);
+ ARROW_CHECK_EQ(indices->type_id(), dict_type_->index_type()->id());
+ ARROW_CHECK_EQ(dict_type_->value_type()->id(), dictionary->type()->id());
+ DCHECK(dict_type_->value_type()->Equals(*dictionary->type()));
+ auto data = indices->data()->Copy();
+ data->type = type;
+ data->dictionary = dictionary->data();
+ SetData(data);
+}
+
+std::shared_ptr<Array> DictionaryArray::dictionary() const {
+ if (!dictionary_) {
+ dictionary_ = MakeArray(data_->dictionary);
+ }
+ return dictionary_;
+}
+
+Result<std::shared_ptr<Array>> DictionaryArray::FromArrays(
+ const std::shared_ptr<DataType>& type, const std::shared_ptr<Array>& indices,
+ const std::shared_ptr<Array>& dictionary) {
+ if (type->id() != Type::DICTIONARY) {
+ return Status::TypeError("Expected a dictionary type");
+ }
+ const auto& dict = checked_cast<const DictionaryType&>(*type);
+ if (indices->type_id() != dict.index_type()->id()) {
+ return Status::TypeError(
+ "Dictionary type's index type does not match "
+ "indices array's type");
+ }
+ RETURN_NOT_OK(internal::CheckIndexBounds(*indices->data(),
+ static_cast<uint64_t>(dictionary->length())));
+ return std::make_shared<DictionaryArray>(type, indices, dictionary);
+}
+
+bool DictionaryArray::CanCompareIndices(const DictionaryArray& other) const {
+ DCHECK(dictionary()->type()->Equals(other.dictionary()->type()))
+ << "dictionaries have differing type " << *dictionary()->type() << " vs "
+ << *other.dictionary()->type();
+
+ if (!indices()->type()->Equals(other.indices()->type())) {
+ return false;
+ }
+
+ auto min_length = std::min(dictionary()->length(), other.dictionary()->length());
+ return dictionary()->RangeEquals(other.dictionary(), 0, min_length, 0);
+}
+
+// ----------------------------------------------------------------------
+// Dictionary transposition
+
+namespace {
+
+inline bool IsTrivialTransposition(const int32_t* transpose_map,
+ int64_t input_dict_size) {
+ for (int64_t i = 0; i < input_dict_size; ++i) {
+ if (transpose_map[i] != i) {
+ return false;
+ }
+ }
+ return true;
+}
+
+Result<std::shared_ptr<ArrayData>> TransposeDictIndices(
+ const std::shared_ptr<ArrayData>& data, const std::shared_ptr<DataType>& in_type,
+ const std::shared_ptr<DataType>& out_type,
+ const std::shared_ptr<ArrayData>& dictionary, const int32_t* transpose_map,
+ MemoryPool* pool) {
+ // Note that in_type may be different from data->type if data is of type ExtensionType
+ if (in_type->id() != Type::DICTIONARY || out_type->id() != Type::DICTIONARY) {
+ return Status::TypeError("Expected dictionary type");
+ }
+ const int64_t in_dict_len = data->dictionary->length;
+ const auto& in_dict_type = checked_cast<const DictionaryType&>(*in_type);
+ const auto& out_dict_type = checked_cast<const DictionaryType&>(*out_type);
+
+ const auto& in_index_type = *in_dict_type.index_type();
+ const auto& out_index_type =
+ checked_cast<const FixedWidthType&>(*out_dict_type.index_type());
+
+ if (in_index_type.id() == out_index_type.id() &&
+ IsTrivialTransposition(transpose_map, in_dict_len)) {
+ // Index type and values will be identical => we can simply reuse
+ // the existing buffers.
+ auto out_data =
+ ArrayData::Make(out_type, data->length, {data->buffers[0], data->buffers[1]},
+ data->null_count, data->offset);
+ out_data->dictionary = dictionary;
+ return out_data;
+ }
+
+ // Default path: compute a buffer of transposed indices.
+ ARROW_ASSIGN_OR_RAISE(
+ auto out_buffer,
+ AllocateBuffer(data->length * (out_index_type.bit_width() / CHAR_BIT), pool));
+
+ // Shift null buffer if the original offset is non-zero
+ std::shared_ptr<Buffer> null_bitmap;
+ if (data->offset != 0 && data->null_count != 0) {
+ ARROW_ASSIGN_OR_RAISE(null_bitmap, CopyBitmap(pool, data->buffers[0]->data(),
+ data->offset, data->length));
+ } else {
+ null_bitmap = data->buffers[0];
+ }
+
+ auto out_data = ArrayData::Make(out_type, data->length,
+ {null_bitmap, std::move(out_buffer)}, data->null_count);
+ out_data->dictionary = dictionary;
+ RETURN_NOT_OK(internal::TransposeInts(
+ in_index_type, out_index_type, data->GetValues<uint8_t>(1, 0),
+ out_data->GetMutableValues<uint8_t>(1, 0), data->offset, out_data->offset,
+ data->length, transpose_map));
+ return out_data;
+}
+
+} // namespace
+
+Result<std::shared_ptr<Array>> DictionaryArray::Transpose(
+ const std::shared_ptr<DataType>& type, const std::shared_ptr<Array>& dictionary,
+ const int32_t* transpose_map, MemoryPool* pool) const {
+ ARROW_ASSIGN_OR_RAISE(auto transposed,
+ TransposeDictIndices(data_, data_->type, type, dictionary->data(),
+ transpose_map, pool));
+ return MakeArray(std::move(transposed));
+}
+
+// ----------------------------------------------------------------------
+// Dictionary unification
+
+namespace {
+
+template <typename T>
+class DictionaryUnifierImpl : public DictionaryUnifier {
+ public:
+ using ArrayType = typename TypeTraits<T>::ArrayType;
+ using DictTraits = typename internal::DictionaryTraits<T>;
+ using MemoTableType = typename DictTraits::MemoTableType;
+
+ DictionaryUnifierImpl(MemoryPool* pool, std::shared_ptr<DataType> value_type)
+ : pool_(pool), value_type_(value_type), memo_table_(pool) {}
+
+ Status Unify(const Array& dictionary, std::shared_ptr<Buffer>* out) override {
+ if (dictionary.null_count() > 0) {
+ return Status::Invalid("Cannot yet unify dictionaries with nulls");
+ }
+ if (!dictionary.type()->Equals(*value_type_)) {
+ return Status::Invalid("Dictionary type different from unifier: ",
+ dictionary.type()->ToString());
+ }
+ const ArrayType& values = checked_cast<const ArrayType&>(dictionary);
+ if (out != nullptr) {
+ ARROW_ASSIGN_OR_RAISE(auto result,
+ AllocateBuffer(dictionary.length() * sizeof(int32_t), pool_));
+ auto result_raw = reinterpret_cast<int32_t*>(result->mutable_data());
+ for (int64_t i = 0; i < values.length(); ++i) {
+ RETURN_NOT_OK(memo_table_.GetOrInsert(values.GetView(i), &result_raw[i]));
+ }
+ *out = std::move(result);
+ } else {
+ for (int64_t i = 0; i < values.length(); ++i) {
+ int32_t unused_memo_index;
+ RETURN_NOT_OK(memo_table_.GetOrInsert(values.GetView(i), &unused_memo_index));
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Unify(const Array& dictionary) override { return Unify(dictionary, nullptr); }
+
+ Status GetResult(std::shared_ptr<DataType>* out_type,
+ std::shared_ptr<Array>* out_dict) override {
+ int64_t dict_length = memo_table_.size();
+ std::shared_ptr<DataType> index_type;
+ if (dict_length <= std::numeric_limits<int8_t>::max()) {
+ index_type = int8();
+ } else if (dict_length <= std::numeric_limits<int16_t>::max()) {
+ index_type = int16();
+ } else if (dict_length <= std::numeric_limits<int32_t>::max()) {
+ index_type = int32();
+ } else {
+ index_type = int64();
+ }
+ // Build unified dictionary type with the right index type
+ *out_type = arrow::dictionary(index_type, value_type_);
+
+ // Build unified dictionary array
+ std::shared_ptr<ArrayData> data;
+ RETURN_NOT_OK(DictTraits::GetDictionaryArrayData(pool_, value_type_, memo_table_,
+ 0 /* start_offset */, &data));
+ *out_dict = MakeArray(data);
+ return Status::OK();
+ }
+
+ Status GetResultWithIndexType(const std::shared_ptr<DataType>& index_type,
+ std::shared_ptr<Array>* out_dict) override {
+ int64_t dict_length = memo_table_.size();
+ if (!internal::IntegersCanFit(Datum(dict_length), *index_type).ok()) {
+ return Status::Invalid(
+ "These dictionaries cannot be combined. The unified dictionary requires a "
+ "larger index type.");
+ }
+
+ // Build unified dictionary array
+ std::shared_ptr<ArrayData> data;
+ RETURN_NOT_OK(DictTraits::GetDictionaryArrayData(pool_, value_type_, memo_table_,
+ 0 /* start_offset */, &data));
+ *out_dict = MakeArray(data);
+ return Status::OK();
+ }
+
+ private:
+ MemoryPool* pool_;
+ std::shared_ptr<DataType> value_type_;
+ MemoTableType memo_table_;
+};
+
+struct MakeUnifier {
+ MemoryPool* pool;
+ std::shared_ptr<DataType> value_type;
+ std::unique_ptr<DictionaryUnifier> result;
+
+ MakeUnifier(MemoryPool* pool, std::shared_ptr<DataType> value_type)
+ : pool(pool), value_type(value_type) {}
+
+ template <typename T>
+ enable_if_no_memoize<T, Status> Visit(const T&) {
+ // Default implementation for non-dictionary-supported datatypes
+ return Status::NotImplemented("Unification of ", *value_type,
+ " dictionaries is not implemented");
+ }
+
+ template <typename T>
+ enable_if_memoize<T, Status> Visit(const T&) {
+ result.reset(new DictionaryUnifierImpl<T>(pool, value_type));
+ return Status::OK();
+ }
+};
+
+struct RecursiveUnifier {
+ MemoryPool* pool;
+
+ // Return true if any of the arrays was changed (including descendents)
+ Result<bool> Unify(std::shared_ptr<DataType> type, ArrayDataVector* chunks) {
+ DCHECK(!chunks->empty());
+ bool changed = false;
+ std::shared_ptr<DataType> ext_type = nullptr;
+
+ if (type->id() == Type::EXTENSION) {
+ ext_type = std::move(type);
+ type = checked_cast<const ExtensionType&>(*ext_type).storage_type();
+ }
+
+ // Unify all child dictionaries (if any)
+ if (type->num_fields() > 0) {
+ ArrayDataVector children(chunks->size());
+ for (int i = 0; i < type->num_fields(); ++i) {
+ std::transform(chunks->begin(), chunks->end(), children.begin(),
+ [i](const std::shared_ptr<ArrayData>& array) {
+ return array->child_data[i];
+ });
+ ARROW_ASSIGN_OR_RAISE(bool child_changed,
+ Unify(type->field(i)->type(), &children));
+ if (child_changed) {
+ // Only do this when unification actually occurred
+ for (size_t j = 0; j < chunks->size(); ++j) {
+ (*chunks)[j]->child_data[i] = std::move(children[j]);
+ }
+ changed = true;
+ }
+ }
+ }
+
+ // Unify this dictionary
+ if (type->id() == Type::DICTIONARY) {
+ const auto& dict_type = checked_cast<const DictionaryType&>(*type);
+ // XXX Ideally, we should unify dictionaries nested in value_type first,
+ // but DictionaryUnifier doesn't supported nested dictionaries anyway,
+ // so this will fail.
+ ARROW_ASSIGN_OR_RAISE(auto unifier,
+ DictionaryUnifier::Make(dict_type.value_type(), this->pool));
+ // Unify all dictionary array chunks
+ BufferVector transpose_maps(chunks->size());
+ for (size_t j = 0; j < chunks->size(); ++j) {
+ DCHECK_NE((*chunks)[j]->dictionary, nullptr);
+ RETURN_NOT_OK(
+ unifier->Unify(*MakeArray((*chunks)[j]->dictionary), &transpose_maps[j]));
+ }
+ std::shared_ptr<Array> dictionary;
+ RETURN_NOT_OK(unifier->GetResultWithIndexType(dict_type.index_type(), &dictionary));
+ for (size_t j = 0; j < chunks->size(); ++j) {
+ ARROW_ASSIGN_OR_RAISE(
+ (*chunks)[j],
+ TransposeDictIndices(
+ (*chunks)[j], type, type, dictionary->data(),
+ reinterpret_cast<const int32_t*>(transpose_maps[j]->data()), this->pool));
+ if (ext_type) {
+ (*chunks)[j]->type = ext_type;
+ }
+ }
+ changed = true;
+ }
+
+ return changed;
+ }
+};
+
+} // namespace
+
+Result<std::unique_ptr<DictionaryUnifier>> DictionaryUnifier::Make(
+ std::shared_ptr<DataType> value_type, MemoryPool* pool) {
+ MakeUnifier maker(pool, value_type);
+ RETURN_NOT_OK(VisitTypeInline(*value_type, &maker));
+ return std::move(maker.result);
+}
+
+Result<std::shared_ptr<ChunkedArray>> DictionaryUnifier::UnifyChunkedArray(
+ const std::shared_ptr<ChunkedArray>& array, MemoryPool* pool) {
+ if (array->num_chunks() <= 1) {
+ return array;
+ }
+
+ ArrayDataVector data_chunks(array->num_chunks());
+ std::transform(array->chunks().begin(), array->chunks().end(), data_chunks.begin(),
+ [](const std::shared_ptr<Array>& array) { return array->data(); });
+ ARROW_ASSIGN_OR_RAISE(bool changed,
+ RecursiveUnifier{pool}.Unify(array->type(), &data_chunks));
+ if (!changed) {
+ return array;
+ }
+ ArrayVector chunks(array->num_chunks());
+ std::transform(data_chunks.begin(), data_chunks.end(), chunks.begin(),
+ [](const std::shared_ptr<ArrayData>& data) { return MakeArray(data); });
+ return std::make_shared<ChunkedArray>(std::move(chunks), array->type());
+}
+
+Result<std::shared_ptr<Table>> DictionaryUnifier::UnifyTable(const Table& table,
+ MemoryPool* pool) {
+ ChunkedArrayVector columns = table.columns();
+ for (auto& col : columns) {
+ ARROW_ASSIGN_OR_RAISE(col, DictionaryUnifier::UnifyChunkedArray(col, pool));
+ }
+ return Table::Make(table.schema(), std::move(columns), table.num_rows());
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_dict.h b/src/arrow/cpp/src/arrow/array/array_dict.h
new file mode 100644
index 000000000..8791eaa07
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_dict.h
@@ -0,0 +1,180 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/data.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+// ----------------------------------------------------------------------
+// DictionaryArray
+
+/// \brief Array type for dictionary-encoded data with a
+/// data-dependent dictionary
+///
+/// A dictionary array contains an array of non-negative integers (the
+/// "dictionary indices") along with a data type containing a "dictionary"
+/// corresponding to the distinct values represented in the data.
+///
+/// For example, the array
+///
+/// ["foo", "bar", "foo", "bar", "foo", "bar"]
+///
+/// with dictionary ["bar", "foo"], would have dictionary array representation
+///
+/// indices: [1, 0, 1, 0, 1, 0]
+/// dictionary: ["bar", "foo"]
+///
+/// The indices in principle may be any integer type.
+class ARROW_EXPORT DictionaryArray : public Array {
+ public:
+ using TypeClass = DictionaryType;
+
+ explicit DictionaryArray(const std::shared_ptr<ArrayData>& data);
+
+ DictionaryArray(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<Array>& indices,
+ const std::shared_ptr<Array>& dictionary);
+
+ /// \brief Construct DictionaryArray from dictionary and indices
+ /// array and validate
+ ///
+ /// This function does the validation of the indices and input type. It checks if
+ /// all indices are non-negative and smaller than the size of the dictionary.
+ ///
+ /// \param[in] type a dictionary type
+ /// \param[in] dictionary the dictionary with same value type as the
+ /// type object
+ /// \param[in] indices an array of non-negative integers smaller than the
+ /// size of the dictionary
+ static Result<std::shared_ptr<Array>> FromArrays(
+ const std::shared_ptr<DataType>& type, const std::shared_ptr<Array>& indices,
+ const std::shared_ptr<Array>& dictionary);
+
+ static Result<std::shared_ptr<Array>> FromArrays(
+ const std::shared_ptr<Array>& indices, const std::shared_ptr<Array>& dictionary) {
+ return FromArrays(::arrow::dictionary(indices->type(), dictionary->type()), indices,
+ dictionary);
+ }
+
+ /// \brief Transpose this DictionaryArray
+ ///
+ /// This method constructs a new dictionary array with the given dictionary
+ /// type, transposing indices using the transpose map. The type and the
+ /// transpose map are typically computed using DictionaryUnifier.
+ ///
+ /// \param[in] type the new type object
+ /// \param[in] dictionary the new dictionary
+ /// \param[in] transpose_map transposition array of this array's indices
+ /// into the target array's indices
+ /// \param[in] pool a pool to allocate the array data from
+ Result<std::shared_ptr<Array>> Transpose(
+ const std::shared_ptr<DataType>& type, const std::shared_ptr<Array>& dictionary,
+ const int32_t* transpose_map, MemoryPool* pool = default_memory_pool()) const;
+
+ /// \brief Determine whether dictionary arrays may be compared without unification
+ bool CanCompareIndices(const DictionaryArray& other) const;
+
+ /// \brief Return the dictionary for this array, which is stored as
+ /// a member of the ArrayData internal structure
+ std::shared_ptr<Array> dictionary() const;
+ std::shared_ptr<Array> indices() const;
+
+ /// \brief Return the ith value of indices, cast to int64_t. Not recommended
+ /// for use in performance-sensitive code. Does not validate whether the
+ /// value is null or out-of-bounds.
+ int64_t GetValueIndex(int64_t i) const;
+
+ const DictionaryType* dict_type() const { return dict_type_; }
+
+ private:
+ void SetData(const std::shared_ptr<ArrayData>& data);
+ const DictionaryType* dict_type_;
+ std::shared_ptr<Array> indices_;
+
+ // Lazily initialized when invoking dictionary()
+ mutable std::shared_ptr<Array> dictionary_;
+};
+
+/// \brief Helper class for incremental dictionary unification
+class ARROW_EXPORT DictionaryUnifier {
+ public:
+ virtual ~DictionaryUnifier() = default;
+
+ /// \brief Construct a DictionaryUnifier
+ /// \param[in] value_type the data type of the dictionaries
+ /// \param[in] pool MemoryPool to use for memory allocations
+ static Result<std::unique_ptr<DictionaryUnifier>> Make(
+ std::shared_ptr<DataType> value_type, MemoryPool* pool = default_memory_pool());
+
+ /// \brief Unify dictionaries accross array chunks
+ ///
+ /// The dictionaries in the array chunks will be unified, their indices
+ /// accordingly transposed.
+ ///
+ /// Only dictionaries with a primitive value type are currently supported.
+ /// However, dictionaries nested inside a more complex type are correctly unified.
+ static Result<std::shared_ptr<ChunkedArray>> UnifyChunkedArray(
+ const std::shared_ptr<ChunkedArray>& array,
+ MemoryPool* pool = default_memory_pool());
+
+ /// \brief Unify dictionaries accross the chunks of each table column
+ ///
+ /// The dictionaries in each table column will be unified, their indices
+ /// accordingly transposed.
+ ///
+ /// Only dictionaries with a primitive value type are currently supported.
+ /// However, dictionaries nested inside a more complex type are correctly unified.
+ static Result<std::shared_ptr<Table>> UnifyTable(
+ const Table& table, MemoryPool* pool = default_memory_pool());
+
+ /// \brief Append dictionary to the internal memo
+ virtual Status Unify(const Array& dictionary) = 0;
+
+ /// \brief Append dictionary and compute transpose indices
+ /// \param[in] dictionary the dictionary values to unify
+ /// \param[out] out_transpose a Buffer containing computed transpose indices
+ /// as int32_t values equal in length to the passed dictionary. The value in
+ /// each slot corresponds to the new index value for each original index
+ /// for a DictionaryArray with the old dictionary
+ virtual Status Unify(const Array& dictionary,
+ std::shared_ptr<Buffer>* out_transpose) = 0;
+
+ /// \brief Return a result DictionaryType with the smallest possible index
+ /// type to accommodate the unified dictionary. The unifier cannot be used
+ /// after this is called
+ virtual Status GetResult(std::shared_ptr<DataType>* out_type,
+ std::shared_ptr<Array>* out_dict) = 0;
+
+ /// \brief Return a unified dictionary with the given index type. If
+ /// the index type is not large enough then an invalid status will be returned.
+ /// The unifier cannot be used after this is called
+ virtual Status GetResultWithIndexType(const std::shared_ptr<DataType>& index_type,
+ std::shared_ptr<Array>* out_dict) = 0;
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_dict_test.cc b/src/arrow/cpp/src/arrow/array/array_dict_test.cc
new file mode 100644
index 000000000..d6f7f3c86
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_dict_test.cc
@@ -0,0 +1,1678 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <array>
+#include <cstdint>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_decimal.h"
+#include "arrow/array/builder_dict.h"
+#include "arrow/array/builder_nested.h"
+#include "arrow/chunked_array.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/testing/extension_type.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+void CheckTransposeMap(const Buffer& map, std::vector<int32_t> expected) {
+ AssertBufferEqual(map, *Buffer::Wrap(expected));
+}
+
+void CheckDictionaryArray(const std::shared_ptr<Array>& array,
+ const std::shared_ptr<Array>& expected_values,
+ const std::shared_ptr<Array>& expected_indices) {
+ const auto& dict_array = checked_cast<const DictionaryArray&>(*array);
+ AssertArraysEqual(*expected_values, *dict_array.dictionary(), /*verbose=*/true);
+ AssertArraysEqual(*expected_indices, *dict_array.indices(), /*verbose=*/true);
+}
+
+std::shared_ptr<Array> DictExtensionFromJSON(const std::shared_ptr<DataType>& type,
+ const std::string& json) {
+ auto ext_type = checked_pointer_cast<ExtensionType>(type);
+ auto storage = ArrayFromJSON(ext_type->storage_type(), json);
+ auto ext_data = storage->data()->Copy();
+ ext_data->type = ext_type;
+ return MakeArray(ext_data);
+}
+
+// ----------------------------------------------------------------------
+// Dictionary tests
+
+template <typename Type>
+class TestDictionaryBuilder : public TestBuilder {};
+
+typedef ::testing::Types<Int8Type, UInt8Type, Int16Type, UInt16Type, Int32Type,
+ UInt32Type, Int64Type, UInt64Type, FloatType, DoubleType>
+ PrimitiveDictionaries;
+
+TYPED_TEST_SUITE(TestDictionaryBuilder, PrimitiveDictionaries);
+
+TYPED_TEST(TestDictionaryBuilder, Basic) {
+ using c_type = typename TypeParam::c_type;
+
+ DictionaryBuilder<TypeParam> builder;
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(2)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.AppendNull());
+
+ ASSERT_EQ(builder.length(), 4);
+ ASSERT_EQ(builder.null_count(), 1);
+
+ // Build expected data
+ auto value_type = std::make_shared<TypeParam>();
+ auto dict_type = dictionary(int8(), value_type);
+
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder.Finish(&result));
+
+ DictionaryArray expected(dict_type, ArrayFromJSON(int8(), "[0, 1, 0, null]"),
+ ArrayFromJSON(value_type, "[1, 2]"));
+ ASSERT_TRUE(expected.Equals(result));
+}
+
+TYPED_TEST(TestDictionaryBuilder, ArrayInit) {
+ using c_type = typename TypeParam::c_type;
+
+ auto value_type = std::make_shared<TypeParam>();
+ auto dict_array = ArrayFromJSON(value_type, "[1, 2]");
+ auto dict_type = dictionary(int8(), value_type);
+
+ DictionaryBuilder<TypeParam> builder(dict_array);
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(2)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.AppendNull());
+
+ ASSERT_EQ(builder.length(), 4);
+ ASSERT_EQ(builder.null_count(), 1);
+
+ // Build expected data
+
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder.Finish(&result));
+
+ auto indices = ArrayFromJSON(int8(), "[0, 1, 0, null]");
+ DictionaryArray expected(dict_type, indices, dict_array);
+
+ AssertArraysEqual(expected, *result);
+}
+
+TYPED_TEST(TestDictionaryBuilder, MakeBuilder) {
+ using c_type = typename TypeParam::c_type;
+
+ auto value_type = std::make_shared<TypeParam>();
+ auto dict_array = ArrayFromJSON(value_type, "[1, 2]");
+ auto dict_type = dictionary(int8(), value_type);
+ std::unique_ptr<ArrayBuilder> boxed_builder;
+ ASSERT_OK(MakeBuilder(default_memory_pool(), dict_type, &boxed_builder));
+ auto& builder = checked_cast<DictionaryBuilder<TypeParam>&>(*boxed_builder);
+
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(2)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.AppendNull());
+
+ ASSERT_EQ(builder.length(), 4);
+ ASSERT_EQ(builder.null_count(), 1);
+
+ // Build expected data
+
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder.Finish(&result));
+
+ auto int_array = ArrayFromJSON(int8(), "[0, 1, 0, null]");
+ DictionaryArray expected(dict_type, int_array, dict_array);
+
+ AssertArraysEqual(expected, *result);
+}
+
+TYPED_TEST(TestDictionaryBuilder, ArrayConversion) {
+ auto type = std::make_shared<TypeParam>();
+
+ auto intermediate_result = ArrayFromJSON(type, "[1, 2, 1]");
+ DictionaryBuilder<TypeParam> dictionary_builder;
+ ASSERT_OK(dictionary_builder.AppendArray(*intermediate_result));
+ std::shared_ptr<Array> result;
+ ASSERT_OK(dictionary_builder.Finish(&result));
+
+ // Build expected data
+ auto dict_array = ArrayFromJSON(type, "[1, 2]");
+ auto dict_type = dictionary(int8(), type);
+
+ auto int_array = ArrayFromJSON(int8(), "[0, 1, 0]");
+ DictionaryArray expected(dict_type, int_array, dict_array);
+
+ ASSERT_TRUE(expected.Equals(result));
+}
+
+TYPED_TEST(TestDictionaryBuilder, DoubleTableSize) {
+ using Scalar = typename TypeParam::c_type;
+ // Skip this test for (u)int8
+ if (sizeof(Scalar) > 1) {
+ // Build the dictionary Array
+ DictionaryBuilder<TypeParam> builder;
+ // Build expected data
+ NumericBuilder<TypeParam> dict_builder;
+ Int16Builder int_builder;
+
+ // Fill with 1024 different values
+ for (int64_t i = 0; i < 1024; i++) {
+ ASSERT_OK(builder.Append(static_cast<Scalar>(i)));
+ ASSERT_OK(dict_builder.Append(static_cast<Scalar>(i)));
+ ASSERT_OK(int_builder.Append(static_cast<uint16_t>(i)));
+ }
+ // Fill with an already existing value
+ for (int64_t i = 0; i < 1024; i++) {
+ ASSERT_OK(builder.Append(static_cast<Scalar>(1)));
+ ASSERT_OK(int_builder.Append(1));
+ }
+
+ // Finalize result
+ std::shared_ptr<Array> result;
+ FinishAndCheckPadding(&builder, &result);
+
+ // Finalize expected data
+ std::shared_ptr<Array> dict_array;
+ ASSERT_OK(dict_builder.Finish(&dict_array));
+
+ auto dtype = dictionary(int16(), dict_array->type());
+ std::shared_ptr<Array> int_array;
+ ASSERT_OK(int_builder.Finish(&int_array));
+
+ DictionaryArray expected(dtype, int_array, dict_array);
+ AssertArraysEqual(expected, *result);
+ }
+}
+
+TYPED_TEST(TestDictionaryBuilder, DeltaDictionary) {
+ using c_type = typename TypeParam::c_type;
+ auto type = std::make_shared<TypeParam>();
+
+ DictionaryBuilder<TypeParam> builder;
+
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(2)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(2)));
+ std::shared_ptr<Array> result;
+ FinishAndCheckPadding(&builder, &result);
+
+ // Build expected data for the initial dictionary
+ auto ex_dict = ArrayFromJSON(type, "[1, 2]");
+ auto dict_type1 = dictionary(int8(), type);
+ DictionaryArray expected(dict_type1, ArrayFromJSON(int8(), "[0, 1, 0, 1]"), ex_dict);
+
+ ASSERT_TRUE(expected.Equals(result));
+
+ // extend the dictionary builder with new data
+ ASSERT_OK(builder.Append(static_cast<c_type>(2)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(3)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(3)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(3)));
+
+ std::shared_ptr<Array> result_indices, result_delta;
+ ASSERT_OK(builder.FinishDelta(&result_indices, &result_delta));
+ AssertArraysEqual(*ArrayFromJSON(int8(), "[1, 2, 2, 0, 2]"), *result_indices);
+ AssertArraysEqual(*ArrayFromJSON(type, "[3]"), *result_delta);
+}
+
+TYPED_TEST(TestDictionaryBuilder, DoubleDeltaDictionary) {
+ using c_type = typename TypeParam::c_type;
+ auto type = std::make_shared<TypeParam>();
+ auto dict_type = dictionary(int8(), type);
+
+ DictionaryBuilder<TypeParam> builder;
+
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(2)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(2)));
+ std::shared_ptr<Array> result;
+ FinishAndCheckPadding(&builder, &result);
+
+ // Build expected data for the initial dictionary
+ auto ex_dict1 = ArrayFromJSON(type, "[1, 2]");
+ DictionaryArray expected(dict_type, ArrayFromJSON(int8(), "[0, 1, 0, 1]"), ex_dict1);
+
+ ASSERT_TRUE(expected.Equals(result));
+
+ // extend the dictionary builder with new data
+ ASSERT_OK(builder.Append(static_cast<c_type>(2)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(3)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(3)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(3)));
+
+ std::shared_ptr<Array> result_indices1, result_delta1;
+ ASSERT_OK(builder.FinishDelta(&result_indices1, &result_delta1));
+ AssertArraysEqual(*ArrayFromJSON(int8(), "[1, 2, 2, 0, 2]"), *result_indices1);
+ AssertArraysEqual(*ArrayFromJSON(type, "[3]"), *result_delta1);
+
+ // extend the dictionary builder with new data again
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(2)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(3)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(4)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(5)));
+
+ std::shared_ptr<Array> result_indices2, result_delta2;
+ ASSERT_OK(builder.FinishDelta(&result_indices2, &result_delta2));
+ AssertArraysEqual(*ArrayFromJSON(int8(), "[0, 1, 2, 3, 4]"), *result_indices2);
+ AssertArraysEqual(*ArrayFromJSON(type, "[4, 5]"), *result_delta2);
+}
+
+TYPED_TEST(TestDictionaryBuilder, Dictionary32_BasicPrimitive) {
+ using c_type = typename TypeParam::c_type;
+ auto type = std::make_shared<TypeParam>();
+ auto dict_type = dictionary(int32(), type);
+
+ Dictionary32Builder<TypeParam> builder;
+
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(2)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(2)));
+ std::shared_ptr<Array> result;
+ FinishAndCheckPadding(&builder, &result);
+
+ // Build expected data for the initial dictionary
+ auto ex_dict1 = ArrayFromJSON(type, "[1, 2]");
+ DictionaryArray expected(dict_type, ArrayFromJSON(int32(), "[0, 1, 0, 1]"), ex_dict1);
+ ASSERT_TRUE(expected.Equals(result));
+}
+
+TYPED_TEST(TestDictionaryBuilder, FinishResetBehavior) {
+ // ARROW-6861
+ using c_type = typename TypeParam::c_type;
+ auto type = std::make_shared<TypeParam>();
+
+ Dictionary32Builder<TypeParam> builder;
+
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.AppendNull());
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(2)));
+
+ // Properties from indices_builder propagated
+ ASSERT_LT(0, builder.capacity());
+ ASSERT_LT(0, builder.null_count());
+ ASSERT_EQ(4, builder.length());
+
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder.Finish(&result));
+
+ // Everything reset
+ ASSERT_EQ(0, builder.capacity());
+ ASSERT_EQ(0, builder.length());
+ ASSERT_EQ(0, builder.null_count());
+
+ // Use the builder again
+ ASSERT_OK(builder.Append(static_cast<c_type>(3)));
+ ASSERT_OK(builder.AppendNull());
+ ASSERT_OK(builder.Append(static_cast<c_type>(4)));
+
+ ASSERT_OK(builder.Finish(&result));
+
+ // Dictionary has 4 elements because the dictionary memo was not reset
+ ASSERT_EQ(4, static_cast<const DictionaryArray&>(*result).dictionary()->length());
+}
+
+TYPED_TEST(TestDictionaryBuilder, ResetFull) {
+ using c_type = typename TypeParam::c_type;
+ auto type = std::make_shared<TypeParam>();
+
+ Dictionary32Builder<TypeParam> builder;
+
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.AppendNull());
+ ASSERT_OK(builder.Append(static_cast<c_type>(1)));
+ ASSERT_OK(builder.Append(static_cast<c_type>(2)));
+
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder.Finish(&result));
+
+ ASSERT_OK(builder.Append(static_cast<c_type>(3)));
+ ASSERT_OK(builder.Finish(&result));
+
+ // Dictionary expanded
+ const auto& dict_result = static_cast<const DictionaryArray&>(*result);
+ AssertArraysEqual(*ArrayFromJSON(int32(), "[2]"), *dict_result.indices());
+ AssertArraysEqual(*ArrayFromJSON(type, "[1, 2, 3]"),
+ *static_cast<const DictionaryArray&>(*result).dictionary());
+
+ builder.ResetFull();
+ ASSERT_OK(builder.Append(static_cast<c_type>(4)));
+ ASSERT_OK(builder.Finish(&result));
+ const auto& dict_result2 = static_cast<const DictionaryArray&>(*result);
+ AssertArraysEqual(*ArrayFromJSON(int32(), "[0]"), *dict_result2.indices());
+ AssertArraysEqual(*ArrayFromJSON(type, "[4]"), *dict_result2.dictionary());
+}
+
+TEST(TestDictionaryBuilderAdHoc, AppendIndicesUpdateCapacity) {
+ DictionaryBuilder<Int32Type> builder;
+ Dictionary32Builder<Int32Type> builder32;
+
+ std::vector<int32_t> indices_i32 = {0, 1, 2};
+ std::vector<int64_t> indices_i64 = {0, 1, 2};
+
+ ASSERT_OK(builder.AppendIndices(indices_i64.data(), 3));
+ ASSERT_OK(builder32.AppendIndices(indices_i32.data(), 3));
+
+ ASSERT_LT(0, builder.capacity());
+ ASSERT_LT(0, builder32.capacity());
+}
+
+TEST(TestStringDictionaryBuilder, Basic) {
+ // Build the dictionary Array
+ StringDictionaryBuilder builder;
+ ASSERT_OK(builder.Append("test"));
+ ASSERT_OK(builder.Append("test2"));
+ ASSERT_OK(builder.Append("test", 4));
+
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder.Finish(&result));
+
+ // Build expected data
+ auto ex_dict = ArrayFromJSON(utf8(), "[\"test\", \"test2\"]");
+ auto dtype = dictionary(int8(), utf8());
+ auto int_array = ArrayFromJSON(int8(), "[0, 1, 0]");
+ DictionaryArray expected(dtype, int_array, ex_dict);
+
+ ASSERT_TRUE(expected.Equals(result));
+}
+
+template <typename BuilderType, typename IndexType, typename AppendCType>
+void TestStringDictionaryAppendIndices() {
+ auto index_type = TypeTraits<IndexType>::type_singleton();
+
+ auto ex_dict = ArrayFromJSON(utf8(), R"(["c", "a", "b", "d"])");
+ auto invalid_dict = ArrayFromJSON(binary(), R"(["e", "f"])");
+
+ BuilderType builder;
+ ASSERT_OK(builder.InsertMemoValues(*ex_dict));
+
+ // Inserting again should have no effect
+ ASSERT_OK(builder.InsertMemoValues(*ex_dict));
+
+ // Type mismatch
+ ASSERT_RAISES(Invalid, builder.InsertMemoValues(*invalid_dict));
+
+ std::vector<AppendCType> raw_indices = {0, 1, 2, -1, 3};
+ std::vector<uint8_t> is_valid = {1, 1, 1, 0, 1};
+ for (int i = 0; i < 2; ++i) {
+ ASSERT_OK(builder.AppendIndices(
+ raw_indices.data(), static_cast<int64_t>(raw_indices.size()), is_valid.data()));
+ }
+
+ ASSERT_EQ(10, builder.length());
+
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder.Finish(&result));
+
+ auto ex_indices = ArrayFromJSON(index_type, R"([0, 1, 2, null, 3, 0, 1, 2, null, 3])");
+ auto dtype = dictionary(index_type, utf8());
+ DictionaryArray expected(dtype, ex_indices, ex_dict);
+ ASSERT_TRUE(expected.Equals(result));
+}
+
+TEST(TestStringDictionaryBuilder, AppendIndices) {
+ // Currently AdaptiveIntBuilder only accepts int64_t in bulk appends
+ TestStringDictionaryAppendIndices<StringDictionaryBuilder, Int8Type, int64_t>();
+
+ TestStringDictionaryAppendIndices<StringDictionary32Builder, Int32Type, int32_t>();
+}
+
+TEST(TestStringDictionaryBuilder, ArrayInit) {
+ auto dict_array = ArrayFromJSON(utf8(), R"(["test", "test2"])");
+ auto int_array = ArrayFromJSON(int8(), "[0, 1, 0]");
+
+ // Build the dictionary Array
+ StringDictionaryBuilder builder(dict_array);
+ ASSERT_OK(builder.Append("test"));
+ ASSERT_OK(builder.Append("test2"));
+ ASSERT_OK(builder.Append("test"));
+
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder.Finish(&result));
+
+ // Build expected data
+ DictionaryArray expected(dictionary(int8(), utf8()), int_array, dict_array);
+
+ AssertArraysEqual(expected, *result);
+}
+
+template <typename BuilderType>
+void TestStringDictionaryMakeBuilder(const std::shared_ptr<DataType>& value_type) {
+ auto dict_array = ArrayFromJSON(value_type, R"(["test", "test2"])");
+ auto dict_type = dictionary(int8(), value_type);
+ auto int_array = ArrayFromJSON(int8(), "[0, 1, 0]");
+ std::unique_ptr<ArrayBuilder> boxed_builder;
+ ASSERT_OK(MakeBuilder(default_memory_pool(), dict_type, &boxed_builder));
+ auto& builder = checked_cast<BuilderType&>(*boxed_builder);
+
+ // Build the dictionary Array
+ ASSERT_OK(builder.Append("test"));
+ ASSERT_OK(builder.Append("test2"));
+ ASSERT_OK(builder.Append("test"));
+
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder.Finish(&result));
+
+ // Build expected data
+ DictionaryArray expected(dict_type, int_array, dict_array);
+
+ AssertArraysEqual(expected, *result);
+}
+
+TEST(TestStringDictionaryBuilder, MakeBuilder) {
+ TestStringDictionaryMakeBuilder<DictionaryBuilder<StringType>>(utf8());
+}
+
+TEST(TestLargeStringDictionaryBuilder, MakeBuilder) {
+ TestStringDictionaryMakeBuilder<DictionaryBuilder<LargeStringType>>(large_utf8());
+}
+
+// ARROW-4367
+TEST(TestStringDictionaryBuilder, OnlyNull) {
+ // Build the dictionary Array
+ StringDictionaryBuilder builder;
+ ASSERT_OK(builder.AppendNull());
+
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder.Finish(&result));
+
+ // Build expected data
+ auto dict = ArrayFromJSON(utf8(), "[]");
+ auto dtype = dictionary(int8(), utf8());
+ auto int_array = ArrayFromJSON(int8(), "[null]");
+ DictionaryArray expected(dtype, int_array, dict);
+
+ ASSERT_TRUE(expected.Equals(result));
+}
+
+TEST(TestStringDictionaryBuilder, DoubleTableSize) {
+ // Build the dictionary Array
+ StringDictionaryBuilder builder;
+ // Build expected data
+ StringBuilder str_builder;
+ Int16Builder int_builder;
+
+ // Fill with 1024 different values
+ for (int64_t i = 0; i < 1024; i++) {
+ std::stringstream ss;
+ ss << "test" << i;
+ ASSERT_OK(builder.Append(ss.str()));
+ ASSERT_OK(str_builder.Append(ss.str()));
+ ASSERT_OK(int_builder.Append(static_cast<uint16_t>(i)));
+ }
+ // Fill with an already existing value
+ for (int64_t i = 0; i < 1024; i++) {
+ ASSERT_OK(builder.Append("test1"));
+ ASSERT_OK(int_builder.Append(1));
+ }
+
+ // Finalize result
+ std::shared_ptr<Array> result;
+ FinishAndCheckPadding(&builder, &result);
+
+ // Finalize expected data
+ std::shared_ptr<Array> str_array;
+ ASSERT_OK(str_builder.Finish(&str_array));
+ auto dtype = dictionary(int16(), utf8());
+ std::shared_ptr<Array> int_array;
+ ASSERT_OK(int_builder.Finish(&int_array));
+
+ DictionaryArray expected(dtype, int_array, str_array);
+ ASSERT_TRUE(expected.Equals(result));
+}
+
+TEST(TestStringDictionaryBuilder, DeltaDictionary) {
+ // Build the dictionary Array
+ StringDictionaryBuilder builder;
+ ASSERT_OK(builder.Append("test"));
+ ASSERT_OK(builder.Append("test2"));
+ ASSERT_OK(builder.Append("test"));
+
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder.Finish(&result));
+
+ // Build expected data
+ auto dict = ArrayFromJSON(utf8(), "[\"test\", \"test2\"]");
+ auto dtype = dictionary(int8(), utf8());
+ auto int_array = ArrayFromJSON(int8(), "[0, 1, 0]");
+ DictionaryArray expected(dtype, int_array, dict);
+
+ ASSERT_TRUE(expected.Equals(result));
+
+ // build a delta dictionary
+ ASSERT_OK(builder.Append("test2"));
+ ASSERT_OK(builder.Append("test3"));
+ ASSERT_OK(builder.Append("test2"));
+
+ std::shared_ptr<Array> result_indices, result_delta;
+ ASSERT_OK(builder.FinishDelta(&result_indices, &result_delta));
+
+ // Build expected data
+ AssertArraysEqual(*ArrayFromJSON(int8(), "[1, 2, 1]"), *result_indices);
+ AssertArraysEqual(*ArrayFromJSON(utf8(), "[\"test3\"]"), *result_delta);
+}
+
+TEST(TestStringDictionaryBuilder, BigDeltaDictionary) {
+ constexpr int16_t kTestLength = 2048;
+ // Build the dictionary Array
+ StringDictionaryBuilder builder;
+
+ StringBuilder str_builder1;
+ Int16Builder int_builder1;
+
+ for (int16_t idx = 0; idx < kTestLength; ++idx) {
+ std::stringstream sstream;
+ sstream << "test" << idx;
+ ASSERT_OK(builder.Append(sstream.str()));
+ ASSERT_OK(str_builder1.Append(sstream.str()));
+ ASSERT_OK(int_builder1.Append(idx));
+ }
+
+ std::shared_ptr<Array> result;
+ FinishAndCheckPadding(&builder, &result);
+
+ std::shared_ptr<Array> str_array1;
+ ASSERT_OK(str_builder1.Finish(&str_array1));
+
+ auto dtype1 = dictionary(int16(), utf8());
+
+ std::shared_ptr<Array> int_array1;
+ ASSERT_OK(int_builder1.Finish(&int_array1));
+
+ DictionaryArray expected(dtype1, int_array1, str_array1);
+ ASSERT_TRUE(expected.Equals(result));
+
+ // build delta 1
+ StringBuilder str_builder2;
+ Int16Builder int_builder2;
+
+ for (int16_t idx = 0; idx < kTestLength; ++idx) {
+ ASSERT_OK(builder.Append("test1"));
+ ASSERT_OK(int_builder2.Append(1));
+ }
+
+ for (int16_t idx = 0; idx < kTestLength; ++idx) {
+ ASSERT_OK(builder.Append("test_new_value1"));
+ ASSERT_OK(int_builder2.Append(kTestLength));
+ }
+ ASSERT_OK(str_builder2.Append("test_new_value1"));
+
+ std::shared_ptr<Array> indices2, delta2;
+ ASSERT_OK(builder.FinishDelta(&indices2, &delta2));
+
+ std::shared_ptr<Array> str_array2;
+ ASSERT_OK(str_builder2.Finish(&str_array2));
+
+ std::shared_ptr<Array> int_array2;
+ ASSERT_OK(int_builder2.Finish(&int_array2));
+
+ AssertArraysEqual(*int_array2, *indices2);
+ AssertArraysEqual(*str_array2, *delta2);
+
+ // build delta 2
+ StringBuilder str_builder3;
+ Int16Builder int_builder3;
+
+ for (int16_t idx = 0; idx < kTestLength; ++idx) {
+ ASSERT_OK(builder.Append("test2"));
+ ASSERT_OK(int_builder3.Append(2));
+ }
+
+ for (int16_t idx = 0; idx < kTestLength; ++idx) {
+ ASSERT_OK(builder.Append("test_new_value2"));
+ ASSERT_OK(int_builder3.Append(kTestLength + 1));
+ }
+ ASSERT_OK(str_builder3.Append("test_new_value2"));
+
+ std::shared_ptr<Array> indices3, delta3;
+ ASSERT_OK(builder.FinishDelta(&indices3, &delta3));
+
+ std::shared_ptr<Array> str_array3;
+ ASSERT_OK(str_builder3.Finish(&str_array3));
+
+ std::shared_ptr<Array> int_array3;
+ ASSERT_OK(int_builder3.Finish(&int_array3));
+
+ AssertArraysEqual(*int_array3, *indices3);
+ AssertArraysEqual(*str_array3, *delta3);
+}
+
+TEST(TestFixedSizeBinaryDictionaryBuilder, Basic) {
+ // Build the dictionary Array
+ DictionaryBuilder<FixedSizeBinaryType> builder(arrow::fixed_size_binary(4));
+ std::vector<uint8_t> test{12, 12, 11, 12};
+ std::vector<uint8_t> test2{12, 12, 11, 11};
+ ASSERT_OK(builder.Append(test.data()));
+ ASSERT_OK(builder.Append(test2.data()));
+ ASSERT_OK(builder.Append(test.data()));
+
+ std::shared_ptr<Array> result;
+ FinishAndCheckPadding(&builder, &result);
+
+ // Build expected data
+ auto value_type = arrow::fixed_size_binary(4);
+ FixedSizeBinaryBuilder fsb_builder(value_type);
+ ASSERT_OK(fsb_builder.Append(test.data()));
+ ASSERT_OK(fsb_builder.Append(test2.data()));
+ std::shared_ptr<Array> fsb_array;
+ ASSERT_OK(fsb_builder.Finish(&fsb_array));
+
+ auto dtype = dictionary(int8(), value_type);
+
+ Int8Builder int_builder;
+ ASSERT_OK(int_builder.Append(0));
+ ASSERT_OK(int_builder.Append(1));
+ ASSERT_OK(int_builder.Append(0));
+ std::shared_ptr<Array> int_array;
+ ASSERT_OK(int_builder.Finish(&int_array));
+
+ DictionaryArray expected(dtype, int_array, fsb_array);
+ ASSERT_TRUE(expected.Equals(result));
+}
+
+TEST(TestFixedSizeBinaryDictionaryBuilder, ArrayInit) {
+ // Build the dictionary Array
+ auto value_type = fixed_size_binary(4);
+ auto dict_array = ArrayFromJSON(value_type, R"(["abcd", "wxyz"])");
+ util::string_view test = "abcd", test2 = "wxyz";
+ DictionaryBuilder<FixedSizeBinaryType> builder(dict_array);
+ ASSERT_OK(builder.Append(test));
+ ASSERT_OK(builder.Append(test2));
+ ASSERT_OK(builder.Append(test));
+
+ std::shared_ptr<Array> result;
+ FinishAndCheckPadding(&builder, &result);
+
+ // Build expected data
+ auto indices = ArrayFromJSON(int8(), "[0, 1, 0]");
+ DictionaryArray expected(dictionary(int8(), value_type), indices, dict_array);
+ AssertArraysEqual(expected, *result);
+}
+
+TEST(TestFixedSizeBinaryDictionaryBuilder, MakeBuilder) {
+ // Build the dictionary Array
+ auto value_type = fixed_size_binary(4);
+ auto dict_array = ArrayFromJSON(value_type, R"(["abcd", "wxyz"])");
+ auto dict_type = dictionary(int8(), value_type);
+
+ std::unique_ptr<ArrayBuilder> boxed_builder;
+ ASSERT_OK(MakeBuilder(default_memory_pool(), dict_type, &boxed_builder));
+ auto& builder = checked_cast<DictionaryBuilder<FixedSizeBinaryType>&>(*boxed_builder);
+ util::string_view test = "abcd", test2 = "wxyz";
+ ASSERT_OK(builder.Append(test));
+ ASSERT_OK(builder.Append(test2));
+ ASSERT_OK(builder.Append(test));
+
+ std::shared_ptr<Array> result;
+ FinishAndCheckPadding(&builder, &result);
+
+ // Build expected data
+ auto indices = ArrayFromJSON(int8(), "[0, 1, 0]");
+ DictionaryArray expected(dict_type, indices, dict_array);
+ AssertArraysEqual(expected, *result);
+}
+
+TEST(TestFixedSizeBinaryDictionaryBuilder, DeltaDictionary) {
+ // Build the dictionary Array
+ auto value_type = arrow::fixed_size_binary(4);
+ auto dict_type = dictionary(int8(), value_type);
+
+ DictionaryBuilder<FixedSizeBinaryType> builder(value_type);
+ std::vector<uint8_t> test{12, 12, 11, 12};
+ std::vector<uint8_t> test2{12, 12, 11, 11};
+ std::vector<uint8_t> test3{12, 12, 11, 10};
+
+ ASSERT_OK(builder.Append(test.data()));
+ ASSERT_OK(builder.Append(test2.data()));
+ ASSERT_OK(builder.Append(test.data()));
+
+ std::shared_ptr<Array> result1;
+ FinishAndCheckPadding(&builder, &result1);
+
+ // Build expected data
+ FixedSizeBinaryBuilder fsb_builder1(value_type);
+ ASSERT_OK(fsb_builder1.Append(test.data()));
+ ASSERT_OK(fsb_builder1.Append(test2.data()));
+ std::shared_ptr<Array> fsb_array1;
+ ASSERT_OK(fsb_builder1.Finish(&fsb_array1));
+
+ Int8Builder int_builder1;
+ ASSERT_OK(int_builder1.Append(0));
+ ASSERT_OK(int_builder1.Append(1));
+ ASSERT_OK(int_builder1.Append(0));
+ std::shared_ptr<Array> int_array1;
+ ASSERT_OK(int_builder1.Finish(&int_array1));
+
+ DictionaryArray expected1(dict_type, int_array1, fsb_array1);
+ ASSERT_TRUE(expected1.Equals(result1));
+
+ // build delta dictionary
+ ASSERT_OK(builder.Append(test.data()));
+ ASSERT_OK(builder.Append(test2.data()));
+ ASSERT_OK(builder.Append(test3.data()));
+
+ std::shared_ptr<Array> indices2, delta2;
+ ASSERT_OK(builder.FinishDelta(&indices2, &delta2));
+
+ // Build expected data
+ FixedSizeBinaryBuilder fsb_builder2(value_type);
+ ASSERT_OK(fsb_builder2.Append(test3.data()));
+ std::shared_ptr<Array> fsb_array2;
+ ASSERT_OK(fsb_builder2.Finish(&fsb_array2));
+
+ Int8Builder int_builder2;
+ ASSERT_OK(int_builder2.Append(0));
+ ASSERT_OK(int_builder2.Append(1));
+ ASSERT_OK(int_builder2.Append(2));
+
+ std::shared_ptr<Array> int_array2;
+ ASSERT_OK(int_builder2.Finish(&int_array2));
+
+ AssertArraysEqual(*int_array2, *indices2);
+ AssertArraysEqual(*fsb_array2, *delta2);
+}
+
+TEST(TestFixedSizeBinaryDictionaryBuilder, DoubleTableSize) {
+ // Build the dictionary Array
+ auto value_type = arrow::fixed_size_binary(4);
+ auto dict_type = dictionary(int16(), value_type);
+
+ DictionaryBuilder<FixedSizeBinaryType> builder(value_type);
+ // Build expected data
+ FixedSizeBinaryBuilder fsb_builder(value_type);
+ Int16Builder int_builder;
+
+ // Fill with 1024 different values
+ for (int64_t i = 0; i < 1024; i++) {
+ std::vector<uint8_t> value{12, 12, static_cast<uint8_t>(i / 128),
+ static_cast<uint8_t>(i % 128)};
+ ASSERT_OK(builder.Append(value.data()));
+ ASSERT_OK(fsb_builder.Append(value.data()));
+ ASSERT_OK(int_builder.Append(static_cast<uint16_t>(i)));
+ }
+ // Fill with an already existing value
+ std::vector<uint8_t> known_value{12, 12, 0, 1};
+ for (int64_t i = 0; i < 1024; i++) {
+ ASSERT_OK(builder.Append(known_value.data()));
+ ASSERT_OK(int_builder.Append(1));
+ }
+
+ // Finalize result
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder.Finish(&result));
+
+ // Finalize expected data
+ std::shared_ptr<Array> fsb_array;
+ ASSERT_OK(fsb_builder.Finish(&fsb_array));
+ std::shared_ptr<Array> int_array;
+ ASSERT_OK(int_builder.Finish(&int_array));
+
+ DictionaryArray expected(dict_type, int_array, fsb_array);
+ ASSERT_TRUE(expected.Equals(result));
+}
+
+#ifndef NDEBUG
+TEST(TestFixedSizeBinaryDictionaryBuilder, AppendArrayInvalidType) {
+ // Build the dictionary Array
+ auto value_type = fixed_size_binary(4);
+ DictionaryBuilder<FixedSizeBinaryType> builder(value_type);
+ // Build an array with different byte width
+ auto fsb_array = ArrayFromJSON(fixed_size_binary(3), R"(["foo", "bar"])");
+
+ ASSERT_RAISES(TypeError, builder.AppendArray(*fsb_array));
+}
+#endif
+
+template <typename DecimalValue>
+void TestDecimalDictionaryBuilderBasic(std::shared_ptr<DataType> decimal_type) {
+ // Build the dictionary Array
+ DictionaryBuilder<FixedSizeBinaryType> builder(decimal_type);
+
+ // Test data
+ std::vector<DecimalValue> test{12, 12, 11, 12};
+ for (const auto& value : test) {
+ ASSERT_OK(builder.Append(value.ToBytes().data()));
+ }
+
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder.Finish(&result));
+
+ // Build expected data
+ DictionaryArray expected(dictionary(int8(), decimal_type),
+ ArrayFromJSON(int8(), "[0, 0, 1, 0]"),
+ ArrayFromJSON(decimal_type, "[\"12\", \"11\"]"));
+
+ ASSERT_TRUE(expected.Equals(result));
+}
+
+TEST(TestDecimal128DictionaryBuilder, Basic) {
+ TestDecimalDictionaryBuilderBasic<Decimal128>(arrow::decimal128(2, 0));
+}
+
+TEST(TestDecimal256DictionaryBuilder, Basic) {
+ TestDecimalDictionaryBuilderBasic<Decimal256>(arrow::decimal256(76, 0));
+}
+
+void TestDecimalDictionaryBuilderDoubleTableSize(
+ std::shared_ptr<DataType> decimal_type, FixedSizeBinaryBuilder& decimal_builder) {
+ // Build the dictionary Array
+ DictionaryBuilder<FixedSizeBinaryType> dict_builder(decimal_type);
+
+ // Build expected data
+ Int16Builder int_builder;
+
+ // Fill with 1024 different values
+ for (int64_t i = 0; i < 1024; i++) {
+ // Decimal256Builder takes 32 bytes, while Decimal128Builder takes only the first 16
+ // bytes.
+ const uint8_t bytes[32] = {0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 12,
+ 12,
+ static_cast<uint8_t>(i / 128),
+ static_cast<uint8_t>(i % 128)};
+ ASSERT_OK(dict_builder.Append(bytes));
+ ASSERT_OK(decimal_builder.Append(bytes));
+ ASSERT_OK(int_builder.Append(static_cast<uint16_t>(i)));
+ }
+ // Fill with an already existing value
+ const uint8_t known_value[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 12, 0, 1};
+ for (int64_t i = 0; i < 1024; i++) {
+ ASSERT_OK(dict_builder.Append(known_value));
+ ASSERT_OK(int_builder.Append(1));
+ }
+
+ // Finalize result
+ std::shared_ptr<Array> result;
+ ASSERT_OK(dict_builder.Finish(&result));
+
+ // Finalize expected data
+ std::shared_ptr<Array> decimal_array;
+ ASSERT_OK(decimal_builder.Finish(&decimal_array));
+
+ std::shared_ptr<Array> int_array;
+ ASSERT_OK(int_builder.Finish(&int_array));
+
+ DictionaryArray expected(dictionary(int16(), decimal_type), int_array, decimal_array);
+ ASSERT_TRUE(expected.Equals(result));
+}
+
+TEST(TestDecimal128DictionaryBuilder, DoubleTableSize) {
+ const auto& decimal_type = arrow::decimal128(21, 0);
+ Decimal128Builder decimal_builder(decimal_type);
+ TestDecimalDictionaryBuilderDoubleTableSize(decimal_type, decimal_builder);
+}
+
+TEST(TestDecimal256DictionaryBuilder, DoubleTableSize) {
+ const auto& decimal_type = arrow::decimal256(21, 0);
+ Decimal256Builder decimal_builder(decimal_type);
+ TestDecimalDictionaryBuilderDoubleTableSize(decimal_type, decimal_builder);
+}
+
+TEST(TestNullDictionaryBuilder, Basic) {
+ // MakeBuilder
+ auto dict_type = dictionary(int8(), null());
+ std::unique_ptr<ArrayBuilder> boxed_builder;
+ ASSERT_OK(MakeBuilder(default_memory_pool(), dict_type, &boxed_builder));
+ auto& builder = checked_cast<DictionaryBuilder<NullType>&>(*boxed_builder);
+
+ ASSERT_OK(builder.AppendNull());
+ ASSERT_OK(builder.AppendNull());
+ ASSERT_OK(builder.AppendNull());
+ ASSERT_EQ(3, builder.length());
+ ASSERT_EQ(3, builder.null_count());
+
+ ASSERT_OK(builder.AppendNulls(4));
+ ASSERT_EQ(7, builder.length());
+ ASSERT_EQ(7, builder.null_count());
+
+ auto null_array = ArrayFromJSON(null(), "[null, null, null, null]");
+ ASSERT_OK(builder.AppendArray(*null_array));
+ ASSERT_EQ(11, builder.length());
+ ASSERT_EQ(11, builder.null_count());
+
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder.Finish(&result));
+ AssertTypeEqual(*dict_type, *result->type());
+ ASSERT_EQ(11, result->length());
+ ASSERT_EQ(11, result->null_count());
+}
+
+#ifndef NDEBUG
+TEST(TestNullDictionaryBuilder, AppendArrayInvalidType) {
+ // MakeBuilder
+ auto dict_type = dictionary(int8(), null());
+ std::unique_ptr<ArrayBuilder> boxed_builder;
+ ASSERT_OK(MakeBuilder(default_memory_pool(), dict_type, &boxed_builder));
+ auto& builder = checked_cast<DictionaryBuilder<NullType>&>(*boxed_builder);
+
+ auto int8_array = ArrayFromJSON(int8(), "[0, 1, 0, null]");
+ ASSERT_RAISES(TypeError, builder.AppendArray(*int8_array));
+}
+#endif
+
+// ----------------------------------------------------------------------
+// Index byte width tests
+
+template <typename IndexType, typename ValueType>
+void AssertIndexByteWidth(const std::shared_ptr<DataType>& value_type =
+ TypeTraits<ValueType>::type_singleton()) {
+ auto index_type = TypeTraits<IndexType>::type_singleton();
+ auto dict_type =
+ checked_pointer_cast<DictionaryType>(dictionary(index_type, value_type));
+ std::unique_ptr<ArrayBuilder> builder;
+ ASSERT_OK(MakeBuilder(default_memory_pool(), dict_type, &builder));
+ auto builder_dict_type = checked_pointer_cast<DictionaryType>(builder->type());
+ AssertTypeEqual(dict_type->index_type(), builder_dict_type->index_type());
+}
+
+typedef ::testing::Types<Int8Type, Int16Type, Int32Type, Int64Type> IndexTypes;
+
+template <typename Type>
+class TestDictionaryBuilderIndexByteWidth : public TestBuilder {};
+
+TYPED_TEST_SUITE(TestDictionaryBuilderIndexByteWidth, IndexTypes);
+
+TYPED_TEST(TestDictionaryBuilderIndexByteWidth, MakeBuilder) {
+ AssertIndexByteWidth<TypeParam, FloatType>();
+ AssertIndexByteWidth<TypeParam, BinaryType>();
+ AssertIndexByteWidth<TypeParam, StringType>();
+ AssertIndexByteWidth<TypeParam, FixedSizeBinaryType>(fixed_size_binary(4));
+ AssertIndexByteWidth<TypeParam, NullType>();
+}
+
+// ----------------------------------------------------------------------
+// DictionaryArray tests
+
+TEST(TestDictionary, Equals) {
+ std::vector<bool> is_valid = {true, true, false, true, true, true};
+ std::shared_ptr<Array> dict, dict2, indices, indices2, indices3;
+
+ dict = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]");
+ std::shared_ptr<DataType> dict_type = dictionary(int16(), utf8());
+
+ dict2 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\", \"qux\"]");
+ std::shared_ptr<DataType> dict2_type = dictionary(int16(), utf8());
+
+ std::vector<int16_t> indices_values = {1, 2, -1, 0, 2, 0};
+ ArrayFromVector<Int16Type, int16_t>(is_valid, indices_values, &indices);
+
+ std::vector<int16_t> indices2_values = {1, 2, 0, 0, 2, 0};
+ ArrayFromVector<Int16Type, int16_t>(is_valid, indices2_values, &indices2);
+
+ std::vector<int16_t> indices3_values = {1, 1, 0, 0, 2, 0};
+ ArrayFromVector<Int16Type, int16_t>(is_valid, indices3_values, &indices3);
+
+ auto array = std::make_shared<DictionaryArray>(dict_type, indices, dict);
+ auto array2 = std::make_shared<DictionaryArray>(dict_type, indices2, dict);
+ auto array3 = std::make_shared<DictionaryArray>(dict2_type, indices, dict2);
+ auto array4 = std::make_shared<DictionaryArray>(dict_type, indices3, dict);
+
+ ASSERT_TRUE(array->Equals(array));
+
+ // Equal, because the unequal index is masked by null
+ ASSERT_TRUE(array->Equals(array2));
+
+ // Unequal dictionaries
+ ASSERT_FALSE(array->Equals(array3));
+
+ // Unequal indices
+ ASSERT_FALSE(array->Equals(array4));
+
+ // RangeEquals
+ ASSERT_TRUE(array->RangeEquals(3, 6, 3, array4));
+ ASSERT_FALSE(array->RangeEquals(1, 3, 1, array4));
+
+ // ARROW-33 Test slices
+ const int64_t size = array->length();
+
+ std::shared_ptr<Array> slice, slice2;
+ slice = array->Array::Slice(2);
+ slice2 = array->Array::Slice(2);
+ ASSERT_EQ(size - 2, slice->length());
+
+ ASSERT_TRUE(slice->Equals(slice2));
+ ASSERT_TRUE(array->RangeEquals(2, array->length(), 0, slice));
+
+ // Chained slices
+ slice2 = array->Array::Slice(1)->Array::Slice(1);
+ ASSERT_TRUE(slice->Equals(slice2));
+
+ slice = array->Slice(1, 3);
+ slice2 = array->Slice(1, 3);
+ ASSERT_EQ(3, slice->length());
+
+ ASSERT_TRUE(slice->Equals(slice2));
+ ASSERT_TRUE(array->RangeEquals(1, 4, 0, slice));
+}
+
+TEST(TestDictionary, Validate) {
+ auto dict = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]");
+ auto dict_type = dictionary(int16(), utf8());
+
+ auto indices = ArrayFromJSON(int16(), "[1, 2, null, 0, 2, 0]");
+ std::shared_ptr<Array> arr =
+ std::make_shared<DictionaryArray>(dict_type, indices, dict);
+
+ // Only checking index type for now
+ ASSERT_OK(arr->ValidateFull());
+
+ // ARROW-7008: Invalid dict was not being validated
+ std::vector<std::shared_ptr<Buffer>> buffers = {nullptr, nullptr, nullptr};
+ auto invalid_data = std::make_shared<ArrayData>(utf8(), 0, buffers);
+
+ indices = ArrayFromJSON(int16(), "[]");
+ arr = std::make_shared<DictionaryArray>(dict_type, indices, MakeArray(invalid_data));
+ ASSERT_RAISES(Invalid, arr->ValidateFull());
+
+ // Make the data buffer non-null
+ ASSERT_OK_AND_ASSIGN(buffers[2], AllocateBuffer(0));
+ arr = std::make_shared<DictionaryArray>(dict_type, indices, MakeArray(invalid_data));
+ ASSERT_RAISES(Invalid, arr->ValidateFull());
+
+ ASSERT_DEATH(
+ {
+ std::shared_ptr<Array> null_dict_arr =
+ std::make_shared<DictionaryArray>(dict_type, indices, nullptr);
+ },
+ "");
+}
+
+TEST(TestDictionary, FromArrays) {
+ auto dict = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]");
+ for (auto index_ty : all_dictionary_index_types()) {
+ auto dict_type = dictionary(index_ty, utf8());
+
+ auto indices1 = ArrayFromJSON(index_ty, "[1, 2, 0, 0, 2, 0]");
+ // Index out of bounds
+ auto indices2 = ArrayFromJSON(index_ty, "[1, 2, 0, 3, 2, 0]");
+
+ ASSERT_OK_AND_ASSIGN(auto arr1,
+ DictionaryArray::FromArrays(dict_type, indices1, dict));
+ ASSERT_RAISES(IndexError, DictionaryArray::FromArrays(dict_type, indices2, dict));
+
+ if (checked_cast<const IntegerType&>(*index_ty).is_signed()) {
+ // Invalid index is masked by null, so it's OK
+ auto indices3 = ArrayFromJSON(index_ty, "[1, 2, -1, null, 2, 0]");
+ BitUtil::ClearBit(indices3->data()->buffers[0]->mutable_data(), 2);
+ ASSERT_OK_AND_ASSIGN(auto arr3,
+ DictionaryArray::FromArrays(dict_type, indices3, dict));
+ }
+
+ auto indices4 = ArrayFromJSON(index_ty, "[1, 2, null, 3, 2, 0]");
+ ASSERT_RAISES(IndexError, DictionaryArray::FromArrays(dict_type, indices4, dict));
+
+ // Probe other validation checks
+ ASSERT_RAISES(TypeError, DictionaryArray::FromArrays(index_ty, indices4, dict));
+
+ auto different_index_ty =
+ dictionary(index_ty->id() == Type::INT8 ? uint8() : int8(), utf8());
+ ASSERT_RAISES(TypeError,
+ DictionaryArray::FromArrays(different_index_ty, indices4, dict));
+ }
+}
+
+static void CheckTranspose(const std::shared_ptr<Array>& input,
+ const int32_t* transpose_map,
+ const std::shared_ptr<DataType>& out_dict_type,
+ const std::shared_ptr<Array>& out_dict,
+ const std::shared_ptr<Array>& expected_indices) {
+ ASSERT_OK_AND_ASSIGN(auto transposed,
+ internal::checked_cast<const DictionaryArray&>(*input).Transpose(
+ out_dict_type, out_dict, transpose_map));
+ ASSERT_OK(transposed->ValidateFull());
+
+ ASSERT_OK_AND_ASSIGN(auto expected, DictionaryArray::FromArrays(
+ out_dict_type, expected_indices, out_dict));
+ AssertArraysEqual(*transposed, *expected);
+}
+
+TEST(TestDictionary, TransposeBasic) {
+ auto dict = ArrayFromJSON(utf8(), "[\"A\", \"B\", \"C\"]");
+
+ auto CheckIndexType = [&](const std::shared_ptr<DataType>& index_ty) {
+ auto dict_type = dictionary(index_ty, utf8());
+ auto indices = ArrayFromJSON(index_ty, "[1, 2, 0, 0]");
+ // ["B", "C", "A", "A"]
+ ASSERT_OK_AND_ASSIGN(auto arr, DictionaryArray::FromArrays(dict_type, indices, dict));
+ // ["C", "A"]
+ auto sliced = arr->Slice(1, 2);
+
+ // Transpose to same index type
+ {
+ auto out_dict_type = dict_type;
+ auto out_dict = ArrayFromJSON(utf8(), "[\"Z\", \"A\", \"C\", \"B\"]");
+ auto expected_indices = ArrayFromJSON(index_ty, "[3, 2, 1, 1]");
+ std::vector<int32_t> transpose_map = {1, 3, 2};
+ CheckTranspose(arr, transpose_map.data(), out_dict_type, out_dict,
+ expected_indices);
+
+ // Sliced
+ expected_indices = ArrayFromJSON(index_ty, "[2, 1]");
+ CheckTranspose(sliced, transpose_map.data(), out_dict_type, out_dict,
+ expected_indices);
+ }
+
+ // Transpose to other index type
+ auto out_dict = ArrayFromJSON(utf8(), "[\"Z\", \"A\", \"C\", \"B\"]");
+ std::vector<int32_t> transpose_map = {1, 3, 2};
+ for (auto other_ty : all_dictionary_index_types()) {
+ auto out_dict_type = dictionary(other_ty, utf8());
+ auto expected_indices = ArrayFromJSON(other_ty, "[3, 2, 1, 1]");
+ CheckTranspose(arr, transpose_map.data(), out_dict_type, out_dict,
+ expected_indices);
+
+ // Sliced
+ expected_indices = ArrayFromJSON(other_ty, "[2, 1]");
+ CheckTranspose(sliced, transpose_map.data(), out_dict_type, out_dict,
+ expected_indices);
+ }
+ };
+
+ for (auto ty : all_dictionary_index_types()) {
+ CheckIndexType(ty);
+ }
+}
+
+TEST(TestDictionary, TransposeTrivial) {
+ // Test a trivial transposition, possibly optimized away
+
+ auto dict = ArrayFromJSON(utf8(), "[\"A\", \"B\", \"C\"]");
+ auto dict_type = dictionary(int16(), utf8());
+ auto indices = ArrayFromJSON(int16(), "[1, 2, 0, 0]");
+ // ["B", "C", "A", "A"]
+ ASSERT_OK_AND_ASSIGN(auto arr, DictionaryArray::FromArrays(dict_type, indices, dict));
+ // ["C", "A"]
+ auto sliced = arr->Slice(1, 2);
+
+ std::vector<int32_t> transpose_map = {0, 1, 2};
+
+ // Transpose to same index type
+ {
+ auto out_dict_type = dict_type;
+ auto out_dict = ArrayFromJSON(utf8(), "[\"A\", \"B\", \"C\", \"D\"]");
+ auto expected_indices = ArrayFromJSON(int16(), "[1, 2, 0, 0]");
+ CheckTranspose(arr, transpose_map.data(), out_dict_type, out_dict, expected_indices);
+
+ // Sliced
+ expected_indices = ArrayFromJSON(int16(), "[2, 0]");
+ CheckTranspose(sliced, transpose_map.data(), out_dict_type, out_dict,
+ expected_indices);
+ }
+
+ // Transpose to other index type
+ {
+ auto out_dict_type = dictionary(int8(), utf8());
+ auto out_dict = ArrayFromJSON(utf8(), "[\"A\", \"B\", \"C\", \"D\"]");
+ auto expected_indices = ArrayFromJSON(int8(), "[1, 2, 0, 0]");
+ CheckTranspose(arr, transpose_map.data(), out_dict_type, out_dict, expected_indices);
+
+ // Sliced
+ expected_indices = ArrayFromJSON(int8(), "[2, 0]");
+ CheckTranspose(sliced, transpose_map.data(), out_dict_type, out_dict,
+ expected_indices);
+ }
+}
+
+TEST(TestDictionary, GetValueIndex) {
+ const char* indices_json = "[5, 0, 1, 3, 2, 4]";
+ auto indices_int64 = ArrayFromJSON(int64(), indices_json);
+ auto dict = ArrayFromJSON(int32(), "[10, 20, 30, 40, 50, 60]");
+
+ const auto& typed_indices_int64 = checked_cast<const Int64Array&>(*indices_int64);
+ for (auto index_ty : all_dictionary_index_types()) {
+ auto indices = ArrayFromJSON(index_ty, indices_json);
+ auto dict_ty = dictionary(index_ty, int32());
+
+ DictionaryArray dict_arr(dict_ty, indices, dict);
+
+ int64_t offset = 1;
+ auto sliced_dict_arr = dict_arr.Slice(offset);
+
+ for (int64_t i = 0; i < indices->length(); ++i) {
+ ASSERT_EQ(dict_arr.GetValueIndex(i), typed_indices_int64.Value(i));
+ if (i < sliced_dict_arr->length()) {
+ ASSERT_EQ(checked_cast<const DictionaryArray&>(*sliced_dict_arr).GetValueIndex(i),
+ typed_indices_int64.Value(i + offset));
+ }
+ }
+ }
+}
+
+TEST(TestDictionary, TransposeNulls) {
+ auto dict = ArrayFromJSON(utf8(), "[\"A\", \"B\", \"C\"]");
+ auto dict_type = dictionary(int16(), utf8());
+ auto indices = ArrayFromJSON(int16(), "[1, 2, null, 0]");
+ // ["B", "C", null, "A"]
+ ASSERT_OK_AND_ASSIGN(auto arr, DictionaryArray::FromArrays(dict_type, indices, dict));
+ // ["C", null]
+ auto sliced = arr->Slice(1, 2);
+
+ auto out_dict = ArrayFromJSON(utf8(), "[\"Z\", \"A\", \"C\", \"B\"]");
+ auto out_dict_type = dictionary(int16(), utf8());
+ auto expected_indices = ArrayFromJSON(int16(), "[3, 2, null, 1]");
+
+ std::vector<int32_t> transpose_map = {1, 3, 2};
+ CheckTranspose(arr, transpose_map.data(), out_dict_type, out_dict, expected_indices);
+
+ // Sliced
+ expected_indices = ArrayFromJSON(int16(), "[2, null]");
+ CheckTranspose(sliced, transpose_map.data(), out_dict_type, out_dict, expected_indices);
+}
+
+TEST(TestDictionary, ListOfDictionary) {
+ std::unique_ptr<ArrayBuilder> root_builder;
+ ASSERT_OK(MakeBuilder(default_memory_pool(), list(dictionary(int8(), utf8())),
+ &root_builder));
+ auto list_builder = checked_cast<ListBuilder*>(root_builder.get());
+ auto dict_builder =
+ checked_cast<DictionaryBuilder<StringType>*>(list_builder->value_builder());
+
+ ASSERT_OK(list_builder->Append());
+ std::vector<std::string> expected;
+ for (char a : util::string_view("abc")) {
+ for (char d : util::string_view("def")) {
+ for (char g : util::string_view("ghi")) {
+ for (char j : util::string_view("jkl")) {
+ for (char m : util::string_view("mno")) {
+ for (char p : util::string_view("pqr")) {
+ if ((static_cast<int>(a) + d + g + j + m + p) % 16 == 0) {
+ ASSERT_OK(list_builder->Append());
+ }
+ // 3**6 distinct strings; too large for int8
+ char str[] = {a, d, g, j, m, p, '\0'};
+ ASSERT_OK(dict_builder->Append(str));
+ expected.push_back(str);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ ASSERT_TRUE(list_builder->type()->Equals(list(dictionary(int16(), utf8()))));
+
+ std::shared_ptr<Array> expected_dict;
+ ArrayFromVector<StringType, std::string>(expected, &expected_dict);
+
+ std::shared_ptr<Array> array;
+ ASSERT_OK(root_builder->Finish(&array));
+ ASSERT_OK(array->ValidateFull());
+
+ auto expected_type = list(dictionary(int16(), utf8()));
+ ASSERT_EQ(array->type()->ToString(), expected_type->ToString());
+
+ auto list_array = checked_cast<const ListArray*>(array.get());
+ auto actual_dict =
+ checked_cast<const DictionaryArray&>(*list_array->values()).dictionary();
+ ASSERT_ARRAYS_EQUAL(*expected_dict, *actual_dict);
+}
+
+TEST(TestDictionary, CanCompareIndices) {
+ auto make_dict = [](std::shared_ptr<DataType> index_type,
+ std::shared_ptr<DataType> value_type, std::string dictionary_json) {
+ std::shared_ptr<Array> out;
+ ARROW_EXPECT_OK(
+ DictionaryArray::FromArrays(dictionary(index_type, value_type),
+ ArrayFromJSON(index_type, "[]"),
+ ArrayFromJSON(value_type, dictionary_json))
+ .Value(&out));
+ return checked_pointer_cast<DictionaryArray>(out);
+ };
+
+ auto compare_and_swap = [](const DictionaryArray& l, const DictionaryArray& r,
+ bool expected) {
+ ASSERT_EQ(l.CanCompareIndices(r), expected)
+ << "left: " << l.ToString() << "\nright: " << r.ToString();
+ ASSERT_EQ(r.CanCompareIndices(l), expected)
+ << "left: " << r.ToString() << "\nright: " << l.ToString();
+ };
+
+ {
+ auto array = make_dict(int16(), utf8(), R"(["foo", "bar"])");
+ auto same = make_dict(int16(), utf8(), R"(["foo", "bar"])");
+ compare_and_swap(*array, *same, true);
+ }
+
+ {
+ auto array = make_dict(int16(), utf8(), R"(["foo", "bar", "quux"])");
+ auto prefix_dict = make_dict(int16(), utf8(), R"(["foo", "bar"])");
+ compare_and_swap(*array, *prefix_dict, true);
+ }
+
+ {
+ auto array = make_dict(int16(), utf8(), R"(["foo", "bar"])");
+ auto indices_need_casting = make_dict(int8(), utf8(), R"(["foo", "bar"])");
+ compare_and_swap(*array, *indices_need_casting, false);
+ }
+
+ {
+ auto array = make_dict(int16(), utf8(), R"(["foo", "bar", "quux"])");
+ auto non_prefix_dict = make_dict(int16(), utf8(), R"(["foo", "blink"])");
+ compare_and_swap(*array, *non_prefix_dict, false);
+ }
+}
+
+TEST(TestDictionary, IndicesArray) {
+ auto dict = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]");
+ auto dict_type = dictionary(int16(), utf8());
+ auto indices = ArrayFromJSON(int16(), "[1, 2, null, 0, 2, 0]");
+ auto arr = std::make_shared<DictionaryArray>(dict_type, indices, dict);
+
+ // The indices array should not have dictionary data
+ ASSERT_EQ(arr->indices()->data()->dictionary, nullptr);
+
+ // Validate the indices array
+ ASSERT_OK(arr->indices()->ValidateFull());
+}
+
+TEST(TestDictionaryUnifier, Numeric) {
+ auto dict_ty = int64();
+
+ auto d1 = ArrayFromJSON(dict_ty, "[3, 4, 7]");
+ auto d2 = ArrayFromJSON(dict_ty, "[1, 7, 4, 8]");
+ auto d3 = ArrayFromJSON(dict_ty, "[1, -200]");
+
+ auto expected = dictionary(int8(), dict_ty);
+ auto expected_dict = ArrayFromJSON(dict_ty, "[3, 4, 7, 1, 8, -200]");
+
+ ASSERT_OK_AND_ASSIGN(auto unifier, DictionaryUnifier::Make(dict_ty));
+
+ std::shared_ptr<DataType> out_type;
+ std::shared_ptr<Array> out_dict;
+
+ ASSERT_OK(unifier->Unify(*d1));
+ ASSERT_OK(unifier->Unify(*d2));
+ ASSERT_OK(unifier->Unify(*d3));
+
+ ASSERT_RAISES(Invalid, unifier->Unify(*ArrayFromJSON(int32(), "[1, -200]")));
+
+ ASSERT_OK(unifier->GetResult(&out_type, &out_dict));
+ ASSERT_TRUE(out_type->Equals(*expected));
+ ASSERT_TRUE(out_dict->Equals(*expected_dict));
+
+ std::shared_ptr<Buffer> b1, b2, b3;
+
+ ASSERT_OK(unifier->Unify(*d1, &b1));
+ ASSERT_OK(unifier->Unify(*d2, &b2));
+ ASSERT_OK(unifier->Unify(*d3, &b3));
+ ASSERT_OK(unifier->GetResult(&out_type, &out_dict));
+ ASSERT_TRUE(out_type->Equals(*expected));
+ ASSERT_TRUE(out_dict->Equals(*expected_dict));
+
+ CheckTransposeMap(*b1, {0, 1, 2});
+ CheckTransposeMap(*b2, {3, 2, 1, 4});
+ CheckTransposeMap(*b3, {3, 5});
+}
+
+TEST(TestDictionaryUnifier, String) {
+ auto dict_ty = utf8();
+
+ auto t1 = dictionary(int16(), dict_ty);
+ auto d1 = ArrayFromJSON(dict_ty, "[\"foo\", \"bar\"]");
+
+ auto t2 = dictionary(int32(), dict_ty);
+ auto d2 = ArrayFromJSON(dict_ty, "[\"quux\", \"foo\"]");
+
+ auto expected = dictionary(int8(), dict_ty);
+ auto expected_dict = ArrayFromJSON(dict_ty, "[\"foo\", \"bar\", \"quux\"]");
+
+ ASSERT_OK_AND_ASSIGN(auto unifier, DictionaryUnifier::Make(dict_ty));
+
+ std::shared_ptr<DataType> out_type;
+ std::shared_ptr<Array> out_dict;
+ ASSERT_OK(unifier->Unify(*d1));
+ ASSERT_OK(unifier->Unify(*d2));
+ ASSERT_OK(unifier->GetResult(&out_type, &out_dict));
+ ASSERT_TRUE(out_type->Equals(*expected));
+ ASSERT_TRUE(out_dict->Equals(*expected_dict));
+
+ std::shared_ptr<Buffer> b1, b2;
+
+ ASSERT_OK(unifier->Unify(*d1, &b1));
+ ASSERT_OK(unifier->Unify(*d2, &b2));
+ ASSERT_OK(unifier->GetResult(&out_type, &out_dict));
+ ASSERT_TRUE(out_type->Equals(*expected));
+ ASSERT_TRUE(out_dict->Equals(*expected_dict));
+
+ CheckTransposeMap(*b1, {0, 1});
+ CheckTransposeMap(*b2, {2, 0});
+}
+
+TEST(TestDictionaryUnifier, FixedSizeBinary) {
+ auto type = fixed_size_binary(3);
+
+ std::string data = "foobarbazqux";
+ auto buf = std::make_shared<Buffer>(data);
+ // ["foo", "bar"]
+ auto dict1 = std::make_shared<FixedSizeBinaryArray>(type, 2, SliceBuffer(buf, 0, 6));
+ auto t1 = dictionary(int16(), type);
+ // ["bar", "baz", "qux"]
+ auto dict2 = std::make_shared<FixedSizeBinaryArray>(type, 3, SliceBuffer(buf, 3, 9));
+ auto t2 = dictionary(int16(), type);
+
+ // ["foo", "bar", "baz", "qux"]
+ auto expected_dict = std::make_shared<FixedSizeBinaryArray>(type, 4, buf);
+ auto expected = dictionary(int8(), type);
+
+ ASSERT_OK_AND_ASSIGN(auto unifier, DictionaryUnifier::Make(type));
+ std::shared_ptr<DataType> out_type;
+ std::shared_ptr<Array> out_dict;
+ ASSERT_OK(unifier->Unify(*dict1));
+ ASSERT_OK(unifier->Unify(*dict2));
+ ASSERT_OK(unifier->GetResult(&out_type, &out_dict));
+ ASSERT_TRUE(out_type->Equals(*expected));
+ ASSERT_TRUE(out_dict->Equals(*expected_dict));
+
+ std::shared_ptr<Buffer> b1, b2;
+ ASSERT_OK(unifier->Unify(*dict1, &b1));
+ ASSERT_OK(unifier->Unify(*dict2, &b2));
+ ASSERT_OK(unifier->GetResult(&out_type, &out_dict));
+ ASSERT_TRUE(out_type->Equals(*expected));
+ ASSERT_TRUE(out_dict->Equals(*expected_dict));
+
+ CheckTransposeMap(*b1, {0, 1});
+ CheckTransposeMap(*b2, {1, 2, 3});
+}
+
+TEST(TestDictionaryUnifier, Large) {
+ // Unifying "large" dictionary types should choose the right index type
+ std::shared_ptr<Array> dict1, dict2, expected_dict;
+
+ Int32Builder builder;
+ ASSERT_OK(builder.Reserve(120));
+ for (int32_t i = 0; i < 120; ++i) {
+ builder.UnsafeAppend(i);
+ }
+ ASSERT_OK(builder.Finish(&dict1));
+ ASSERT_EQ(dict1->length(), 120);
+ auto t1 = dictionary(int8(), int32());
+
+ ASSERT_OK(builder.Reserve(30));
+ for (int32_t i = 110; i < 140; ++i) {
+ builder.UnsafeAppend(i);
+ }
+ ASSERT_OK(builder.Finish(&dict2));
+ ASSERT_EQ(dict2->length(), 30);
+ auto t2 = dictionary(int8(), int32());
+
+ ASSERT_OK(builder.Reserve(140));
+ for (int32_t i = 0; i < 140; ++i) {
+ builder.UnsafeAppend(i);
+ }
+ ASSERT_OK(builder.Finish(&expected_dict));
+ ASSERT_EQ(expected_dict->length(), 140);
+
+ // int8 would be too narrow to hold all possible index values
+ auto expected = dictionary(int16(), int32());
+
+ ASSERT_OK_AND_ASSIGN(auto unifier, DictionaryUnifier::Make(int32()));
+ std::shared_ptr<DataType> out_type;
+ std::shared_ptr<Array> out_dict;
+ ASSERT_OK(unifier->Unify(*dict1));
+ ASSERT_OK(unifier->Unify(*dict2));
+ ASSERT_OK(unifier->GetResult(&out_type, &out_dict));
+ ASSERT_TRUE(out_type->Equals(*expected));
+ ASSERT_TRUE(out_dict->Equals(*expected_dict));
+}
+
+TEST(TestDictionaryUnifier, ChunkedArraySimple) {
+ auto type = dictionary(int8(), utf8());
+ auto chunk1 = ArrayFromJSON(type, R"(["ab", "cd", null, "cd"])");
+ auto chunk2 = ArrayFromJSON(type, R"(["ef", "cd", "ef"])");
+ auto chunk3 = ArrayFromJSON(type, R"(["ef", "ab", null, "ab"])");
+ auto chunk4 = ArrayFromJSON(type, "[]");
+ ASSERT_OK_AND_ASSIGN(auto chunked,
+ ChunkedArray::Make({chunk1, chunk2, chunk3, chunk4}));
+
+ ASSERT_OK_AND_ASSIGN(auto unified, DictionaryUnifier::UnifyChunkedArray(chunked));
+ ASSERT_EQ(unified->num_chunks(), 4);
+ auto expected_dict = ArrayFromJSON(utf8(), R"(["ab", "cd", "ef"])");
+ CheckDictionaryArray(unified->chunk(0), expected_dict,
+ ArrayFromJSON(int8(), "[0, 1, null, 1]"));
+ CheckDictionaryArray(unified->chunk(1), expected_dict,
+ ArrayFromJSON(int8(), "[2, 1, 2]"));
+ CheckDictionaryArray(unified->chunk(2), expected_dict,
+ ArrayFromJSON(int8(), "[2, 0, null, 0]"));
+ CheckDictionaryArray(unified->chunk(3), expected_dict, ArrayFromJSON(int8(), "[]"));
+}
+
+TEST(TestDictionaryUnifier, ChunkedArrayZeroChunk) {
+ auto type = dictionary(int8(), utf8());
+ ASSERT_OK_AND_ASSIGN(auto chunked, ChunkedArray::Make(ArrayVector{}, type));
+ ASSERT_OK_AND_ASSIGN(auto unified, DictionaryUnifier::UnifyChunkedArray(chunked));
+ AssertChunkedEqual(*chunked, *unified);
+}
+
+TEST(TestDictionaryUnifier, ChunkedArrayOneChunk) {
+ auto type = dictionary(int8(), utf8());
+ auto chunk1 = ArrayFromJSON(type, R"(["ab", "cd", null, "cd"])");
+ ASSERT_OK_AND_ASSIGN(auto chunked, ChunkedArray::Make({chunk1}));
+ ASSERT_OK_AND_ASSIGN(auto unified, DictionaryUnifier::UnifyChunkedArray(chunked));
+ AssertChunkedEqual(*chunked, *unified);
+}
+
+TEST(TestDictionaryUnifier, ChunkedArrayNoDict) {
+ auto type = int8();
+ auto chunk1 = ArrayFromJSON(type, "[1, 1, 2, 3]");
+ auto chunk2 = ArrayFromJSON(type, "[5, 8, 13]");
+ ASSERT_OK_AND_ASSIGN(auto chunked, ChunkedArray::Make({chunk1, chunk2}));
+ ASSERT_OK_AND_ASSIGN(auto unified, DictionaryUnifier::UnifyChunkedArray(chunked));
+ AssertChunkedEqual(*chunked, *unified);
+}
+
+TEST(TestDictionaryUnifier, ChunkedArrayNested) {
+ // Dict in a nested type: ok
+ auto type = list(dictionary(int16(), utf8()));
+ auto chunk1 = ArrayFromJSON(type, R"([["ab", "cd"], ["cd"]])");
+ auto chunk2 = ArrayFromJSON(type, R"([[], ["ef", "cd", "ef"]])");
+ ASSERT_OK_AND_ASSIGN(auto chunked, ChunkedArray::Make({chunk1, chunk2}));
+
+ ASSERT_OK_AND_ASSIGN(auto unified, DictionaryUnifier::UnifyChunkedArray(chunked));
+ ASSERT_EQ(unified->num_chunks(), 2);
+ auto expected_dict = ArrayFromJSON(utf8(), R"(["ab", "cd", "ef"])");
+ auto unified1 = checked_pointer_cast<ListArray>(unified->chunk(0));
+ AssertArraysEqual(*unified1->offsets(), *ArrayFromJSON(int32(), "[0, 2, 3]"));
+ CheckDictionaryArray(unified1->values(), expected_dict,
+ ArrayFromJSON(int16(), "[0, 1, 1]"));
+ auto unified2 = checked_pointer_cast<ListArray>(unified->chunk(1));
+ AssertArraysEqual(*unified2->offsets(), *ArrayFromJSON(int32(), "[0, 0, 3]"));
+ CheckDictionaryArray(unified2->values(), expected_dict,
+ ArrayFromJSON(int16(), "[2, 1, 2]"));
+}
+
+TEST(TestDictionaryUnifier, ChunkedArrayExtension) {
+ // Dict in an extension type: ok
+ auto type = dict_extension_type();
+ auto chunk1 = DictExtensionFromJSON(type, R"(["ab", null, "cd", "ab"])");
+ auto chunk2 = DictExtensionFromJSON(type, R"(["ef", "ab", "ab"])");
+ ASSERT_OK_AND_ASSIGN(auto chunked, ChunkedArray::Make({chunk1, chunk2}));
+
+ ASSERT_OK_AND_ASSIGN(auto unified, DictionaryUnifier::UnifyChunkedArray(chunked));
+ ASSERT_EQ(unified->num_chunks(), 2);
+
+ auto expected_dict = ArrayFromJSON(utf8(), R"(["ab", "cd", "ef"])");
+ auto unified1 = checked_pointer_cast<ExtensionArray>(unified->chunk(0));
+ AssertTypeEqual(*type, *unified1->type());
+ CheckDictionaryArray(unified1->storage(), expected_dict,
+ ArrayFromJSON(int8(), "[0, null, 1, 0]"));
+ auto unified2 = checked_pointer_cast<ExtensionArray>(unified->chunk(1));
+ AssertTypeEqual(*type, *unified2->type());
+ CheckDictionaryArray(unified2->storage(), expected_dict,
+ ArrayFromJSON(int8(), "[2, 0, 0]"));
+}
+
+TEST(TestDictionaryUnifier, ChunkedArrayNestedDict) {
+ // Dict in a dict type: unsupported
+ auto inner_type = list(dictionary(uint32(), utf8()));
+ auto inner_dict1 = ArrayFromJSON(inner_type, R"([["ab", "cd"], [], ["cd", null]])");
+ ASSERT_OK_AND_ASSIGN(
+ auto chunk1, DictionaryArray::FromArrays(ArrayFromJSON(int32(), "[2, 1, 0, 1, 2]"),
+ inner_dict1));
+ auto inner_dict2 = ArrayFromJSON(inner_type, R"([["cd", "ef"], ["cd", null], []])");
+ ASSERT_OK_AND_ASSIGN(
+ auto chunk2,
+ DictionaryArray::FromArrays(ArrayFromJSON(int32(), "[1, 2, 2, 0]"), inner_dict2));
+ ASSERT_OK_AND_ASSIGN(auto chunked, ChunkedArray::Make({chunk1, chunk2}));
+
+ ASSERT_RAISES(NotImplemented, DictionaryUnifier::UnifyChunkedArray(chunked));
+}
+
+TEST(TestDictionaryUnifier, TableZeroColumns) {
+ auto schema = ::arrow::schema(FieldVector{});
+ auto table = Table::Make(schema, ArrayVector{}, /*num_rows=*/42);
+
+ ASSERT_OK_AND_ASSIGN(auto unified, DictionaryUnifier::UnifyTable(*table));
+ AssertSchemaEqual(*schema, *unified->schema());
+ ASSERT_EQ(unified->num_rows(), 42);
+ AssertTablesEqual(*table, *unified);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_list_test.cc b/src/arrow/cpp/src/arrow/array/array_list_test.cc
new file mode 100644
index 000000000..a503cbd51
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_list_test.cc
@@ -0,0 +1,1182 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_nested.h"
+#include "arrow/buffer.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_builders.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+using ListTypes = ::testing::Types<ListType, LargeListType>;
+
+// ----------------------------------------------------------------------
+// List tests
+
+template <typename T>
+class TestListArray : public TestBuilder {
+ public:
+ using TypeClass = T;
+ using offset_type = typename TypeClass::offset_type;
+ using ArrayType = typename TypeTraits<TypeClass>::ArrayType;
+ using BuilderType = typename TypeTraits<TypeClass>::BuilderType;
+ using OffsetType = typename TypeTraits<TypeClass>::OffsetType;
+ using OffsetArrayType = typename TypeTraits<TypeClass>::OffsetArrayType;
+ using OffsetBuilderType = typename TypeTraits<TypeClass>::OffsetBuilderType;
+
+ void SetUp() {
+ TestBuilder::SetUp();
+
+ value_type_ = int16();
+ type_ = std::make_shared<T>(value_type_);
+
+ std::unique_ptr<ArrayBuilder> tmp;
+ ASSERT_OK(MakeBuilder(pool_, type_, &tmp));
+ builder_.reset(checked_cast<BuilderType*>(tmp.release()));
+ }
+
+ void Done() {
+ std::shared_ptr<Array> out;
+ FinishAndCheckPadding(builder_.get(), &out);
+ result_ = std::dynamic_pointer_cast<ArrayType>(out);
+ }
+
+ void ValidateBasicListArray(const ArrayType* result, const std::vector<int16_t>& values,
+ const std::vector<uint8_t>& is_valid) {
+ ASSERT_OK(result->ValidateFull());
+ ASSERT_EQ(1, result->null_count());
+ ASSERT_EQ(0, result->values()->null_count());
+
+ ASSERT_EQ(3, result->length());
+ std::vector<offset_type> ex_offsets = {0, 3, 3, 7};
+ for (size_t i = 0; i < ex_offsets.size(); ++i) {
+ ASSERT_EQ(ex_offsets[i], result->value_offset(i));
+ }
+
+ for (int i = 0; i < result->length(); ++i) {
+ ASSERT_EQ(is_valid[i] == 0, result->IsNull(i));
+ }
+
+ ASSERT_EQ(7, result->values()->length());
+ auto varr = std::dynamic_pointer_cast<Int16Array>(result->values());
+
+ for (size_t i = 0; i < values.size(); ++i) {
+ ASSERT_EQ(values[i], varr->Value(i));
+ }
+
+ auto offsets = std::dynamic_pointer_cast<OffsetArrayType>(result->offsets());
+ ASSERT_EQ(offsets->length(), result->length() + 1);
+ ASSERT_EQ(offsets->null_count(), 0);
+ AssertTypeEqual(*offsets->type(), OffsetType());
+
+ for (int64_t i = 0; i < result->length(); ++i) {
+ ASSERT_EQ(offsets->Value(i), result_->raw_value_offsets()[i]);
+ }
+ // last offset
+ ASSERT_EQ(offsets->Value(result->length()),
+ result_->raw_value_offsets()[result->length()]);
+ }
+
+ void TestBasics() {
+ std::vector<int16_t> values = {0, 1, 2, 3, 4, 5, 6};
+ std::vector<int> lengths = {3, 0, 4};
+ std::vector<uint8_t> is_valid = {1, 0, 1};
+
+ Int16Builder* vb = checked_cast<Int16Builder*>(builder_->value_builder());
+
+ ASSERT_OK(builder_->Reserve(lengths.size()));
+ ASSERT_OK(vb->Reserve(values.size()));
+
+ int pos = 0;
+ for (size_t i = 0; i < lengths.size(); ++i) {
+ ASSERT_OK(builder_->Append(is_valid[i] > 0));
+ for (int j = 0; j < lengths[i]; ++j) {
+ ASSERT_OK(vb->Append(values[pos++]));
+ }
+ }
+
+ Done();
+ ValidateBasicListArray(result_.get(), values, is_valid);
+ }
+
+ void TestEquality() {
+ auto vb = checked_cast<Int16Builder*>(builder_->value_builder());
+
+ std::shared_ptr<Array> array, equal_array, unequal_array;
+ std::vector<offset_type> equal_offsets = {0, 1, 2, 5, 6, 7, 8, 10};
+ std::vector<int16_t> equal_values = {1, 2, 3, 4, 5, 2, 2, 2, 5, 6};
+ std::vector<offset_type> unequal_offsets = {0, 1, 4, 7};
+ std::vector<int16_t> unequal_values = {1, 2, 2, 2, 3, 4, 5};
+
+ // setup two equal arrays
+ ASSERT_OK(builder_->AppendValues(equal_offsets.data(), equal_offsets.size()));
+ ASSERT_OK(vb->AppendValues(equal_values.data(), equal_values.size()));
+
+ ASSERT_OK(builder_->Finish(&array));
+ ASSERT_OK(builder_->AppendValues(equal_offsets.data(), equal_offsets.size()));
+ ASSERT_OK(vb->AppendValues(equal_values.data(), equal_values.size()));
+
+ ASSERT_OK(builder_->Finish(&equal_array));
+ // now an unequal one
+ ASSERT_OK(builder_->AppendValues(unequal_offsets.data(), unequal_offsets.size()));
+ ASSERT_OK(vb->AppendValues(unequal_values.data(), unequal_values.size()));
+
+ ASSERT_OK(builder_->Finish(&unequal_array));
+
+ // Test array equality
+ EXPECT_TRUE(array->Equals(array));
+ EXPECT_TRUE(array->Equals(equal_array));
+ EXPECT_TRUE(equal_array->Equals(array));
+ EXPECT_FALSE(equal_array->Equals(unequal_array));
+ EXPECT_FALSE(unequal_array->Equals(equal_array));
+
+ // Test range equality
+ EXPECT_TRUE(array->RangeEquals(0, 1, 0, unequal_array));
+ EXPECT_FALSE(array->RangeEquals(0, 2, 0, unequal_array));
+ EXPECT_FALSE(array->RangeEquals(1, 2, 1, unequal_array));
+ EXPECT_TRUE(array->RangeEquals(2, 3, 2, unequal_array));
+
+ // Check with slices, ARROW-33
+ std::shared_ptr<Array> slice, slice2;
+
+ slice = array->Slice(2);
+ slice2 = array->Slice(2);
+ ASSERT_EQ(array->length() - 2, slice->length());
+
+ ASSERT_TRUE(slice->Equals(slice2));
+ ASSERT_TRUE(array->RangeEquals(2, slice->length(), 0, slice));
+
+ // Chained slices
+ slice2 = array->Slice(1)->Slice(1);
+ ASSERT_TRUE(slice->Equals(slice2));
+
+ slice = array->Slice(1, 4);
+ slice2 = array->Slice(1, 4);
+ ASSERT_EQ(4, slice->length());
+
+ ASSERT_TRUE(slice->Equals(slice2));
+ ASSERT_TRUE(array->RangeEquals(1, 5, 0, slice));
+ }
+
+ void TestValuesEquality() {
+ auto type = std::make_shared<T>(int32());
+ auto left = ArrayFromJSON(type, "[[1, 2], [3], [0]]");
+ auto right = ArrayFromJSON(type, "[[1, 2], [3], [100000]]");
+ auto offset = 2;
+ EXPECT_FALSE(left->Slice(offset)->Equals(right->Slice(offset)));
+ }
+
+ void TestFromArrays() {
+ std::shared_ptr<Array> offsets1, offsets2, offsets3, offsets4, offsets5, values;
+
+ std::vector<bool> offsets_is_valid3 = {true, false, true, true};
+ std::vector<bool> offsets_is_valid4 = {true, true, false, true};
+ std::vector<bool> offsets_is_valid5 = {true, true, false, false};
+
+ std::vector<bool> values_is_valid = {true, false, true, true, true, true};
+
+ std::vector<offset_type> offset1_values = {0, 2, 2, 6};
+ std::vector<offset_type> offset2_values = {0, 2, 6, 6};
+
+ std::vector<int8_t> values_values = {0, 1, 2, 3, 4, 5};
+ const int length = 3;
+
+ ArrayFromVector<OffsetType, offset_type>(offset1_values, &offsets1);
+ ArrayFromVector<OffsetType, offset_type>(offset2_values, &offsets2);
+
+ ArrayFromVector<OffsetType, offset_type>(offsets_is_valid3, offset1_values,
+ &offsets3);
+ ArrayFromVector<OffsetType, offset_type>(offsets_is_valid4, offset2_values,
+ &offsets4);
+ ArrayFromVector<OffsetType, offset_type>(offsets_is_valid5, offset2_values,
+ &offsets5);
+
+ ArrayFromVector<Int8Type, int8_t>(values_is_valid, values_values, &values);
+
+ auto list_type = std::make_shared<T>(int8());
+
+ ASSERT_OK_AND_ASSIGN(auto list1, ArrayType::FromArrays(*offsets1, *values, pool_));
+ ASSERT_OK_AND_ASSIGN(auto list3, ArrayType::FromArrays(*offsets3, *values, pool_));
+ ASSERT_OK_AND_ASSIGN(auto list4, ArrayType::FromArrays(*offsets4, *values, pool_));
+ ASSERT_OK(list1->ValidateFull());
+ ASSERT_OK(list3->ValidateFull());
+ ASSERT_OK(list4->ValidateFull());
+
+ ArrayType expected1(list_type, length, offsets1->data()->buffers[1], values,
+ offsets1->data()->buffers[0], 0);
+ AssertArraysEqual(expected1, *list1);
+
+ // Use null bitmap from offsets3, but clean offsets from non-null version
+ ArrayType expected3(list_type, length, offsets1->data()->buffers[1], values,
+ offsets3->data()->buffers[0], 1);
+ AssertArraysEqual(expected3, *list3);
+
+ // Check that the last offset bit is zero
+ ASSERT_FALSE(BitUtil::GetBit(list3->null_bitmap()->data(), length + 1));
+
+ ArrayType expected4(list_type, length, offsets2->data()->buffers[1], values,
+ offsets4->data()->buffers[0], 1);
+ AssertArraysEqual(expected4, *list4);
+
+ // Test failure modes
+
+ std::shared_ptr<Array> tmp;
+
+ // Zero-length offsets
+ ASSERT_RAISES(Invalid, ArrayType::FromArrays(*offsets1->Slice(0, 0), *values, pool_));
+
+ // Offsets not the right type
+ ASSERT_RAISES(TypeError, ArrayType::FromArrays(*values, *offsets1, pool_));
+
+ // Null final offset
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Last list offset should be non-null"),
+ ArrayType::FromArrays(*offsets5, *values, pool_));
+
+ // ARROW-12077: check for off-by-one in construction (need mimalloc/ASan/Valgrind)
+ {
+ std::shared_ptr<Array> offsets, values;
+ // Length multiple of 8 - we'll allocate a validity buffer with exactly enough bits
+ // (Need a large enough buffer or else ASan doesn't catch it)
+ std::vector<bool> offsets_is_valid(4096);
+ std::vector<offset_type> offset_values(4096);
+ std::vector<int8_t> values_values(4096);
+ std::fill(offsets_is_valid.begin(), offsets_is_valid.end(), true);
+ offsets_is_valid[1] = false;
+ std::fill(offset_values.begin(), offset_values.end(), 0);
+ std::fill(values_values.begin(), values_values.end(), 0);
+ ArrayFromVector<OffsetType, offset_type>(offsets_is_valid, offset_values, &offsets);
+ ArrayFromVector<Int8Type, int8_t>(values_values, &values);
+ ASSERT_OK_AND_ASSIGN(auto list, ArrayType::FromArrays(*offsets, *values, pool_));
+ }
+ }
+
+ void TestAppendNull() {
+ ASSERT_OK(builder_->AppendNull());
+ ASSERT_OK(builder_->AppendNull());
+
+ Done();
+
+ ASSERT_OK(result_->ValidateFull());
+ ASSERT_TRUE(result_->IsNull(0));
+ ASSERT_TRUE(result_->IsNull(1));
+
+ ASSERT_EQ(0, result_->raw_value_offsets()[0]);
+ ASSERT_EQ(0, result_->value_offset(1));
+ ASSERT_EQ(0, result_->value_offset(2));
+
+ auto values = result_->values();
+ ASSERT_EQ(0, values->length());
+ // Values buffer should be non-null
+ ASSERT_NE(nullptr, values->data()->buffers[1]);
+ }
+
+ void TestAppendNulls() {
+ ASSERT_OK(builder_->AppendNulls(3));
+
+ Done();
+
+ ASSERT_OK(result_->ValidateFull());
+ ASSERT_EQ(result_->length(), 3);
+ ASSERT_EQ(result_->null_count(), 3);
+ ASSERT_TRUE(result_->IsNull(0));
+ ASSERT_TRUE(result_->IsNull(1));
+ ASSERT_TRUE(result_->IsNull(2));
+
+ ASSERT_EQ(0, result_->raw_value_offsets()[0]);
+ ASSERT_EQ(0, result_->value_offset(1));
+ ASSERT_EQ(0, result_->value_offset(2));
+ ASSERT_EQ(0, result_->value_offset(3));
+
+ auto values = result_->values();
+ ASSERT_EQ(0, values->length());
+ // Values buffer should be non-null
+ ASSERT_NE(nullptr, values->data()->buffers[1]);
+ }
+
+ void TestBulkAppend() {
+ std::vector<int16_t> values = {0, 1, 2, 3, 4, 5, 6};
+ std::vector<uint8_t> is_valid = {1, 0, 1};
+ std::vector<offset_type> offsets = {0, 3, 3};
+
+ Int16Builder* vb = checked_cast<Int16Builder*>(builder_->value_builder());
+ ASSERT_OK(vb->Reserve(values.size()));
+
+ ASSERT_OK(builder_->AppendValues(offsets.data(), offsets.size(), is_valid.data()));
+ for (int16_t value : values) {
+ ASSERT_OK(vb->Append(value));
+ }
+ Done();
+ ValidateBasicListArray(result_.get(), values, is_valid);
+ }
+
+ void TestBulkAppendInvalid() {
+ std::vector<int16_t> values = {0, 1, 2, 3, 4, 5, 6};
+ std::vector<int> lengths = {3, 0, 4};
+ std::vector<uint8_t> is_valid = {1, 0, 1};
+ // Should be {0, 3, 3} given the is_valid array
+ std::vector<offset_type> offsets = {0, 2, 4};
+
+ Int16Builder* vb = checked_cast<Int16Builder*>(builder_->value_builder());
+ ASSERT_OK(vb->Reserve(values.size()));
+
+ ASSERT_OK(builder_->AppendValues(offsets.data(), offsets.size(), is_valid.data()));
+ ASSERT_OK(builder_->AppendValues(offsets.data(), offsets.size(), is_valid.data()));
+ for (int16_t value : values) {
+ ASSERT_OK(vb->Append(value));
+ }
+
+ Done();
+ ASSERT_RAISES(Invalid, result_->ValidateFull());
+ }
+
+ void TestZeroLength() {
+ // All buffers are null
+ Done();
+ ASSERT_OK(result_->ValidateFull());
+ }
+
+ void TestBuilderPreserveFieldName() {
+ auto list_type_with_name = std::make_shared<T>(field("counts", int16()));
+
+ std::unique_ptr<ArrayBuilder> tmp;
+ ASSERT_OK(MakeBuilder(pool_, list_type_with_name, &tmp));
+ builder_.reset(checked_cast<BuilderType*>(tmp.release()));
+
+ std::vector<offset_type> offsets = {1, 2, 4, 8};
+ ASSERT_OK(builder_->AppendValues(offsets.data(), offsets.size()));
+
+ std::shared_ptr<Array> list_array;
+ ASSERT_OK(builder_->Finish(&list_array));
+
+ const auto& type = checked_cast<T&>(*list_array->type());
+ ASSERT_EQ("counts", type.value_field()->name());
+ }
+
+ void TestFlattenZeroLength() {
+ Done();
+ ASSERT_OK_AND_ASSIGN(auto flattened, result_->Flatten());
+ ASSERT_OK(flattened->ValidateFull());
+ ASSERT_EQ(0, flattened->length());
+ }
+
+ void TestFlattenSimple() {
+ auto type = std::make_shared<T>(int32());
+ auto list_array = std::dynamic_pointer_cast<ArrayType>(
+ ArrayFromJSON(type, "[[1, 2], [3], [4], null, [5], [], [6]]"));
+ ASSERT_OK_AND_ASSIGN(auto flattened, list_array->Flatten());
+ ASSERT_OK(flattened->ValidateFull());
+ EXPECT_TRUE(flattened->Equals(ArrayFromJSON(int32(), "[1, 2, 3, 4, 5, 6]")));
+ }
+
+ void TestFlattenSliced() {
+ auto type = std::make_shared<T>(int32());
+ auto list_array = std::dynamic_pointer_cast<ArrayType>(
+ ArrayFromJSON(type, "[[1, 2], [3], [4], null, [5], [], [6]]"));
+ auto sliced_list_array =
+ std::dynamic_pointer_cast<ArrayType>(list_array->Slice(3, 4));
+ ASSERT_OK_AND_ASSIGN(auto flattened, list_array->Flatten());
+ ASSERT_OK(flattened->ValidateFull());
+ // Note the difference between values() and Flatten().
+ EXPECT_TRUE(flattened->Equals(ArrayFromJSON(int32(), "[5, 6]")));
+ EXPECT_TRUE(sliced_list_array->values()->Equals(
+ ArrayFromJSON(int32(), "[1, 2, 3, 4, 5, 6]")));
+ }
+
+ void TestFlattenNonEmptyBackingNulls() {
+ auto type = std::make_shared<T>(int32());
+ auto array_data =
+ std::dynamic_pointer_cast<ArrayType>(
+ ArrayFromJSON(type, "[[1, 2], [3], null, [5, 6], [7, 8], [], [9]]"))
+ ->data();
+ ASSERT_EQ(2, array_data->buffers.size());
+ auto null_bitmap_buffer = array_data->buffers[0];
+ ASSERT_NE(nullptr, null_bitmap_buffer);
+ BitUtil::ClearBit(null_bitmap_buffer->mutable_data(), 1);
+ BitUtil::ClearBit(null_bitmap_buffer->mutable_data(), 3);
+ BitUtil::ClearBit(null_bitmap_buffer->mutable_data(), 4);
+ array_data->null_count += 3;
+ auto list_array = std::dynamic_pointer_cast<ArrayType>(MakeArray(array_data));
+ ASSERT_OK(list_array->ValidateFull());
+ ASSERT_OK_AND_ASSIGN(auto flattened, list_array->Flatten());
+ EXPECT_TRUE(flattened->Equals(ArrayFromJSON(int32(), "[1, 2, 9]")))
+ << flattened->ToString();
+ }
+
+ Status ValidateOffsets(int64_t length, std::vector<offset_type> offsets,
+ const std::shared_ptr<Array>& values, int64_t offset = 0) {
+ auto type = std::make_shared<TypeClass>(values->type());
+ ArrayType arr(type, length, Buffer::Wrap(offsets), values,
+ /*null_bitmap=*/nullptr, /*null_count=*/0, offset);
+ return arr.ValidateFull();
+ }
+
+ void TestValidateOffsets() {
+ auto empty_values = ArrayFromJSON(int16(), "[]");
+ auto values = ArrayFromJSON(int16(), "[1, 2, 3, 4, 5, 6, 7]");
+
+ // An empty list array can have omitted or 0-length offsets
+ ASSERT_OK(ValidateOffsets(0, {}, empty_values));
+
+ ASSERT_OK(ValidateOffsets(0, {0}, empty_values));
+ ASSERT_OK(ValidateOffsets(1, {0, 7}, values));
+ ASSERT_OK(ValidateOffsets(2, {0, 4, 7}, values));
+ ASSERT_OK(ValidateOffsets(3, {0, 4, 7, 7}, values));
+
+ // Non-zero array offset
+ ASSERT_OK(ValidateOffsets(1, {0, 4, 7}, values, 1));
+ ASSERT_OK(ValidateOffsets(0, {0, 4, 7}, values, 2));
+
+ // Not enough offsets
+ ASSERT_RAISES(Invalid, ValidateOffsets(1, {0}, values));
+ ASSERT_RAISES(Invalid, ValidateOffsets(2, {0, 0}, values, 1));
+
+ // Offset out of bounds
+ ASSERT_RAISES(Invalid, ValidateOffsets(1, {0, 8}, values));
+ ASSERT_RAISES(Invalid, ValidateOffsets(1, {0, 8, 8}, values, 1));
+ // Negative offset
+ ASSERT_RAISES(Invalid, ValidateOffsets(1, {-1, 0}, values));
+ ASSERT_RAISES(Invalid, ValidateOffsets(1, {0, -1}, values));
+ ASSERT_RAISES(Invalid, ValidateOffsets(2, {0, -1, -1}, values, 1));
+ // Offsets non-monotonic
+ ASSERT_RAISES(Invalid, ValidateOffsets(2, {0, 7, 4}, values));
+ }
+
+ void TestCornerCases() {
+ // ARROW-7985
+ ASSERT_OK(builder_->AppendNull());
+ Done();
+ auto expected = ArrayFromJSON(type_, "[null]");
+ AssertArraysEqual(*result_, *expected);
+
+ SetUp();
+ ASSERT_OK(builder_->Append());
+ Done();
+ expected = ArrayFromJSON(type_, "[[]]");
+ AssertArraysEqual(*result_, *expected);
+
+ SetUp();
+ ASSERT_OK(builder_->AppendNull());
+ ASSERT_OK(builder_->value_builder()->Reserve(100));
+ Done();
+ expected = ArrayFromJSON(type_, "[null]");
+ AssertArraysEqual(*result_, *expected);
+ }
+
+ void TestOverflowCheck() {
+ Int16Builder* vb = checked_cast<Int16Builder*>(builder_->value_builder());
+ auto max_elements = builder_->maximum_elements();
+
+ ASSERT_OK(builder_->ValidateOverflow(1));
+ ASSERT_OK(builder_->ValidateOverflow(max_elements));
+ ASSERT_RAISES(CapacityError, builder_->ValidateOverflow(max_elements + 1));
+
+ ASSERT_OK(builder_->Append());
+ ASSERT_OK(vb->Append(1));
+ ASSERT_OK(vb->Append(2));
+ ASSERT_OK(builder_->ValidateOverflow(max_elements - 2));
+ ASSERT_RAISES(CapacityError, builder_->ValidateOverflow(max_elements - 1));
+
+ ASSERT_OK(builder_->AppendNull());
+ ASSERT_OK(builder_->ValidateOverflow(max_elements - 2));
+ ASSERT_RAISES(CapacityError, builder_->ValidateOverflow(max_elements - 1));
+
+ ASSERT_OK(builder_->Append());
+ ASSERT_OK(vb->Append(1));
+ ASSERT_OK(vb->Append(2));
+ ASSERT_OK(vb->Append(3));
+ ASSERT_OK(builder_->ValidateOverflow(max_elements - 5));
+ ASSERT_RAISES(CapacityError, builder_->ValidateOverflow(max_elements - 4));
+ }
+
+ protected:
+ std::shared_ptr<DataType> value_type_;
+
+ std::shared_ptr<BuilderType> builder_;
+ std::shared_ptr<ArrayType> result_;
+};
+
+TYPED_TEST_SUITE(TestListArray, ListTypes);
+
+TYPED_TEST(TestListArray, Basics) { this->TestBasics(); }
+
+TYPED_TEST(TestListArray, Equality) { this->TestEquality(); }
+
+TYPED_TEST(TestListArray, ValuesEquality) { this->TestValuesEquality(); }
+
+TYPED_TEST(TestListArray, FromArrays) { this->TestFromArrays(); }
+
+TYPED_TEST(TestListArray, AppendNull) { this->TestAppendNull(); }
+
+TYPED_TEST(TestListArray, AppendNulls) { this->TestAppendNulls(); }
+
+TYPED_TEST(TestListArray, BulkAppend) { this->TestBulkAppend(); }
+
+TYPED_TEST(TestListArray, BulkAppendInvalid) { this->TestBulkAppendInvalid(); }
+
+TYPED_TEST(TestListArray, ZeroLength) { this->TestZeroLength(); }
+
+TYPED_TEST(TestListArray, BuilderPreserveFieldName) {
+ this->TestBuilderPreserveFieldName();
+}
+
+TYPED_TEST(TestListArray, FlattenSimple) { this->TestFlattenSimple(); }
+TYPED_TEST(TestListArray, FlattenZeroLength) { this->TestFlattenZeroLength(); }
+TYPED_TEST(TestListArray, TestFlattenNonEmptyBackingNulls) {
+ this->TestFlattenNonEmptyBackingNulls();
+}
+
+TYPED_TEST(TestListArray, ValidateOffsets) { this->TestValidateOffsets(); }
+
+TYPED_TEST(TestListArray, CornerCases) { this->TestCornerCases(); }
+
+#ifndef ARROW_LARGE_MEMORY_TESTS
+TYPED_TEST(TestListArray, DISABLED_TestOverflowCheck) { this->TestOverflowCheck(); }
+#else
+TYPED_TEST(TestListArray, TestOverflowCheck) { this->TestOverflowCheck(); }
+#endif
+
+// ----------------------------------------------------------------------
+// Map tests
+
+class TestMapArray : public TestBuilder {
+ public:
+ using offset_type = typename MapType::offset_type;
+ using OffsetType = typename TypeTraits<MapType>::OffsetType;
+
+ void SetUp() {
+ TestBuilder::SetUp();
+
+ key_type_ = utf8();
+ value_type_ = int32();
+ type_ = map(key_type_, value_type_);
+
+ std::unique_ptr<ArrayBuilder> tmp;
+ ASSERT_OK(MakeBuilder(pool_, type_, &tmp));
+ builder_ = checked_pointer_cast<MapBuilder>(std::move(tmp));
+ }
+
+ void Done() {
+ std::shared_ptr<Array> out;
+ FinishAndCheckPadding(builder_.get(), &out);
+ result_ = std::dynamic_pointer_cast<MapArray>(out);
+ }
+
+ protected:
+ std::shared_ptr<DataType> value_type_, key_type_;
+
+ std::shared_ptr<MapBuilder> builder_;
+ std::shared_ptr<MapArray> result_;
+};
+
+TEST_F(TestMapArray, Equality) {
+ auto& kb = checked_cast<StringBuilder&>(*builder_->key_builder());
+ auto& ib = checked_cast<Int32Builder&>(*builder_->item_builder());
+
+ std::shared_ptr<Array> array, equal_array, unequal_array;
+ std::vector<int32_t> equal_offsets = {0, 1, 2, 5, 6, 7, 8, 10};
+ std::vector<util::string_view> equal_keys = {"a", "a", "a", "b", "c",
+ "a", "a", "a", "a", "b"};
+ std::vector<int32_t> equal_values = {1, 2, 3, 4, 5, 2, 2, 2, 5, 6};
+ std::vector<int32_t> unequal_offsets = {0, 1, 4, 7};
+ std::vector<util::string_view> unequal_keys = {"a", "a", "b", "c", "a", "b", "c"};
+ std::vector<int32_t> unequal_values = {1, 2, 2, 2, 3, 4, 5};
+
+ // setup two equal arrays
+ for (auto out : {&array, &equal_array}) {
+ ASSERT_OK(builder_->AppendValues(equal_offsets.data(), equal_offsets.size()));
+ for (auto&& key : equal_keys) {
+ ASSERT_OK(kb.Append(key));
+ }
+ ASSERT_OK(ib.AppendValues(equal_values.data(), equal_values.size()));
+ ASSERT_OK(builder_->Finish(out));
+ }
+
+ // now an unequal one
+ ASSERT_OK(builder_->AppendValues(unequal_offsets.data(), unequal_offsets.size()));
+ for (auto&& key : unequal_keys) {
+ ASSERT_OK(kb.Append(key));
+ }
+ ASSERT_OK(ib.AppendValues(unequal_values.data(), unequal_values.size()));
+ ASSERT_OK(builder_->Finish(&unequal_array));
+
+ // Test array equality
+ EXPECT_TRUE(array->Equals(array));
+ EXPECT_TRUE(array->Equals(equal_array));
+ EXPECT_TRUE(equal_array->Equals(array));
+ EXPECT_FALSE(equal_array->Equals(unequal_array));
+ EXPECT_FALSE(unequal_array->Equals(equal_array));
+
+ // Test range equality
+ EXPECT_TRUE(array->RangeEquals(0, 1, 0, unequal_array));
+ EXPECT_FALSE(array->RangeEquals(0, 2, 0, unequal_array));
+ EXPECT_FALSE(array->RangeEquals(1, 2, 1, unequal_array));
+ EXPECT_TRUE(array->RangeEquals(2, 3, 2, unequal_array));
+}
+
+TEST_F(TestMapArray, BuildingIntToInt) {
+ auto type = map(int16(), int16());
+
+ auto expected_keys = ArrayFromJSON(int16(), R"([
+ 0, 1, 2, 3, 4, 5,
+ 0, 1, 2, 3, 4, 5
+ ])");
+ auto expected_items = ArrayFromJSON(int16(), R"([
+ 1, 1, 2, 3, 5, 8,
+ null, null, 0, 1, null, 2
+ ])");
+ auto expected_offsets = ArrayFromJSON(int32(), "[0, 6, 6, 12, 12]")->data()->buffers[1];
+ auto expected_null_bitmap =
+ ArrayFromJSON(boolean(), "[1, 0, 1, 1]")->data()->buffers[1];
+
+ MapArray expected(type, 4, expected_offsets, expected_keys, expected_items,
+ expected_null_bitmap, 1, 0);
+
+ auto key_builder = std::make_shared<Int16Builder>();
+ auto item_builder = std::make_shared<Int16Builder>();
+ MapBuilder map_builder(default_memory_pool(), key_builder, item_builder);
+
+ std::shared_ptr<Array> actual;
+ ASSERT_OK(map_builder.Append());
+ ASSERT_OK(key_builder->AppendValues({0, 1, 2, 3, 4, 5}));
+ ASSERT_OK(item_builder->AppendValues({1, 1, 2, 3, 5, 8}));
+ ASSERT_OK(map_builder.AppendNull());
+ ASSERT_OK(map_builder.Append());
+ ASSERT_OK(key_builder->AppendValues({0, 1, 2, 3, 4, 5}));
+ ASSERT_OK(item_builder->AppendValues({-1, -1, 0, 1, -1, 2}, {0, 0, 1, 1, 0, 1}));
+ ASSERT_OK(map_builder.Append());
+ ASSERT_OK(map_builder.Finish(&actual));
+ ASSERT_OK(actual->ValidateFull());
+
+ ASSERT_ARRAYS_EQUAL(*actual, expected);
+}
+
+TEST_F(TestMapArray, BuildingStringToInt) {
+ auto type = map(utf8(), int32());
+
+ std::vector<int32_t> offsets = {0, 2, 2, 3, 3};
+ auto expected_keys = ArrayFromJSON(utf8(), R"(["joe", "mark", "cap"])");
+ auto expected_values = ArrayFromJSON(int32(), "[0, null, 8]");
+ ASSERT_OK_AND_ASSIGN(auto expected_null_bitmap, internal::BytesToBits({1, 0, 1, 1}));
+ MapArray expected(type, 4, Buffer::Wrap(offsets), expected_keys, expected_values,
+ expected_null_bitmap, 1);
+
+ auto key_builder = std::make_shared<StringBuilder>();
+ auto item_builder = std::make_shared<Int32Builder>();
+ MapBuilder map_builder(default_memory_pool(), key_builder, item_builder);
+
+ std::shared_ptr<Array> actual;
+ ASSERT_OK(map_builder.Append());
+ ASSERT_OK(key_builder->Append("joe"));
+ ASSERT_OK(item_builder->Append(0));
+ ASSERT_OK(key_builder->Append("mark"));
+ ASSERT_OK(item_builder->AppendNull());
+ ASSERT_OK(map_builder.AppendNull());
+ ASSERT_OK(map_builder.Append());
+ ASSERT_OK(key_builder->Append("cap"));
+ ASSERT_OK(item_builder->Append(8));
+ ASSERT_OK(map_builder.Append());
+ ASSERT_OK(map_builder.Finish(&actual));
+ ASSERT_OK(actual->ValidateFull());
+
+ ASSERT_ARRAYS_EQUAL(*actual, expected);
+}
+
+TEST_F(TestMapArray, FromArrays) {
+ std::shared_ptr<Array> offsets1, offsets2, offsets3, offsets4, keys, items;
+
+ std::vector<bool> offsets_is_valid3 = {true, false, true, true};
+ std::vector<bool> offsets_is_valid4 = {true, true, false, true};
+
+ std::vector<bool> items_is_valid = {true, false, true, true, true, true};
+
+ std::vector<MapType::offset_type> offset1_values = {0, 2, 2, 6};
+ std::vector<MapType::offset_type> offset2_values = {0, 2, 6, 6};
+
+ std::vector<int8_t> key_values = {0, 1, 2, 3, 4, 5};
+ std::vector<int16_t> item_values = {10, 9, 8, 7, 6, 5};
+ const int length = 3;
+
+ ArrayFromVector<OffsetType, offset_type>(offset1_values, &offsets1);
+ ArrayFromVector<OffsetType, offset_type>(offset2_values, &offsets2);
+
+ ArrayFromVector<OffsetType, offset_type>(offsets_is_valid3, offset1_values, &offsets3);
+ ArrayFromVector<OffsetType, offset_type>(offsets_is_valid4, offset2_values, &offsets4);
+
+ ArrayFromVector<Int8Type, int8_t>(key_values, &keys);
+ ArrayFromVector<Int16Type, int16_t>(items_is_valid, item_values, &items);
+
+ auto map_type = map(int8(), int16());
+
+ ASSERT_OK_AND_ASSIGN(auto map1, MapArray::FromArrays(offsets1, keys, items, pool_));
+ ASSERT_OK_AND_ASSIGN(auto map3, MapArray::FromArrays(offsets3, keys, items, pool_));
+ ASSERT_OK_AND_ASSIGN(auto map4, MapArray::FromArrays(offsets4, keys, items, pool_));
+ ASSERT_OK(map1->Validate());
+ ASSERT_OK(map3->Validate());
+ ASSERT_OK(map4->Validate());
+
+ MapArray expected1(map_type, length, offsets1->data()->buffers[1], keys, items,
+ offsets1->data()->buffers[0], 0);
+ AssertArraysEqual(expected1, *map1);
+
+ // Use null bitmap from offsets3, but clean offsets from non-null version
+ MapArray expected3(map_type, length, offsets1->data()->buffers[1], keys, items,
+ offsets3->data()->buffers[0], 1);
+ AssertArraysEqual(expected3, *map3);
+
+ // Check that the last offset bit is zero
+ ASSERT_FALSE(BitUtil::GetBit(map3->null_bitmap()->data(), length + 1));
+
+ MapArray expected4(map_type, length, offsets2->data()->buffers[1], keys, items,
+ offsets4->data()->buffers[0], 1);
+ AssertArraysEqual(expected4, *map4);
+
+ // Test failure modes
+
+ std::shared_ptr<Array> tmp;
+
+ // Zero-length offsets
+ ASSERT_RAISES(Invalid, MapArray::FromArrays(offsets1->Slice(0, 0), keys, items, pool_));
+
+ // Offsets not the right type
+ ASSERT_RAISES(TypeError, MapArray::FromArrays(keys, offsets1, items, pool_));
+
+ // Keys and Items different lengths
+ ASSERT_RAISES(Invalid, MapArray::FromArrays(offsets1, keys->Slice(0, 1), items, pool_));
+
+ // Keys contains null values
+ std::shared_ptr<Array> keys_with_null = offsets3;
+ std::shared_ptr<Array> tmp_items = items->Slice(0, offsets3->length());
+ ASSERT_EQ(keys_with_null->length(), tmp_items->length());
+ ASSERT_RAISES(Invalid,
+ MapArray::FromArrays(offsets1, keys_with_null, tmp_items, pool_));
+}
+
+TEST_F(TestMapArray, FromArraysEquality) {
+ // More equality tests using MapArray::FromArrays
+ auto keys1 = ArrayFromJSON(utf8(), R"(["ab", "cd", "ef", "gh", "ij", "kl"])");
+ auto keys2 = ArrayFromJSON(utf8(), R"(["ab", "cd", "ef", "gh", "ij", "kl"])");
+ auto keys3 = ArrayFromJSON(utf8(), R"(["ab", "cd", "ef", "gh", "zz", "kl"])");
+ auto items1 = ArrayFromJSON(int16(), "[1, 2, 3, 4, 5, 6]");
+ auto items2 = ArrayFromJSON(int16(), "[1, 2, 3, 4, 5, 6]");
+ auto items3 = ArrayFromJSON(int16(), "[1, 2, 3, null, 5, 6]");
+ auto offsets1 = ArrayFromJSON(int32(), "[0, 1, 3, null, 5, 6]");
+ auto offsets2 = ArrayFromJSON(int32(), "[0, 1, 3, null, 5, 6]");
+ auto offsets3 = ArrayFromJSON(int32(), "[0, 1, 3, 3, 5, 6]");
+
+ ASSERT_OK_AND_ASSIGN(auto array1, MapArray::FromArrays(offsets1, keys1, items1));
+ ASSERT_OK_AND_ASSIGN(auto array2, MapArray::FromArrays(offsets2, keys2, items2));
+ ASSERT_OK_AND_ASSIGN(auto array3, MapArray::FromArrays(offsets3, keys2, items2));
+ ASSERT_OK_AND_ASSIGN(auto array4, MapArray::FromArrays(offsets2, keys3, items2));
+ ASSERT_OK_AND_ASSIGN(auto array5, MapArray::FromArrays(offsets2, keys2, items3));
+ ASSERT_OK_AND_ASSIGN(auto array6, MapArray::FromArrays(offsets3, keys3, items3));
+
+ ASSERT_TRUE(array1->Equals(array2));
+ ASSERT_TRUE(array1->RangeEquals(array2, 0, 5, 0));
+
+ ASSERT_FALSE(array1->Equals(array3)); // different offsets
+ ASSERT_FALSE(array1->RangeEquals(array3, 0, 5, 0));
+ ASSERT_TRUE(array1->RangeEquals(array3, 0, 2, 0));
+ ASSERT_FALSE(array1->RangeEquals(array3, 2, 5, 2));
+
+ ASSERT_FALSE(array1->Equals(array4)); // different keys
+ ASSERT_FALSE(array1->RangeEquals(array4, 0, 5, 0));
+ ASSERT_TRUE(array1->RangeEquals(array4, 0, 2, 0));
+ ASSERT_FALSE(array1->RangeEquals(array4, 2, 5, 2));
+
+ ASSERT_FALSE(array1->Equals(array5)); // different items
+ ASSERT_FALSE(array1->RangeEquals(array5, 0, 5, 0));
+ ASSERT_TRUE(array1->RangeEquals(array5, 0, 2, 0));
+ ASSERT_FALSE(array1->RangeEquals(array5, 2, 5, 2));
+
+ ASSERT_FALSE(array1->Equals(array6)); // different everything
+ ASSERT_FALSE(array1->RangeEquals(array6, 0, 5, 0));
+ ASSERT_TRUE(array1->RangeEquals(array6, 0, 2, 0));
+ ASSERT_FALSE(array1->RangeEquals(array6, 2, 5, 2));
+
+ // Map array equality should be indifferent to field names
+ ASSERT_OK_AND_ASSIGN(auto other_map_type,
+ MapType::Make(field("some_entries",
+ struct_({field("some_key", utf8(), false),
+ field("some_value", int16())}),
+ false)));
+ ASSERT_OK_AND_ASSIGN(auto array7,
+ MapArray::FromArrays(other_map_type, offsets2, keys2, items2));
+ ASSERT_TRUE(array1->Equals(array7));
+ ASSERT_TRUE(array1->RangeEquals(array7, 0, 5, 0));
+}
+
+namespace {
+
+template <typename TYPE>
+Status BuildListOfStructPairs(TYPE& builder, std::shared_ptr<Array>* out) {
+ auto struct_builder = internal::checked_cast<StructBuilder*>(builder.value_builder());
+ auto field0_builder =
+ internal::checked_cast<Int16Builder*>(struct_builder->field_builder(0));
+ auto field1_builder =
+ internal::checked_cast<Int16Builder*>(struct_builder->field_builder(1));
+
+ RETURN_NOT_OK(builder.Append());
+ RETURN_NOT_OK(field0_builder->AppendValues({0, 1}));
+ RETURN_NOT_OK(field1_builder->AppendValues({1, -1}, {1, 0}));
+ RETURN_NOT_OK(struct_builder->AppendValues(2, NULLPTR));
+ RETURN_NOT_OK(builder.AppendNull());
+ RETURN_NOT_OK(builder.Append());
+ RETURN_NOT_OK(field0_builder->Append(2));
+ RETURN_NOT_OK(field1_builder->Append(3));
+ RETURN_NOT_OK(struct_builder->Append());
+ RETURN_NOT_OK(builder.Append());
+ RETURN_NOT_OK(builder.Append());
+ RETURN_NOT_OK(field0_builder->AppendValues({3, 4}));
+ RETURN_NOT_OK(field1_builder->AppendValues({4, 5}));
+ RETURN_NOT_OK(struct_builder->AppendValues(2, NULLPTR));
+ RETURN_NOT_OK(builder.Finish(out));
+ RETURN_NOT_OK((*out)->Validate());
+
+ return Status::OK();
+}
+
+} // namespace
+
+TEST_F(TestMapArray, ValueBuilder) {
+ auto key_builder = std::make_shared<Int16Builder>();
+ auto item_builder = std::make_shared<Int16Builder>();
+ MapBuilder map_builder(default_memory_pool(), key_builder, item_builder);
+
+ // Build Map array using key/item builder
+ std::shared_ptr<Array> expected;
+ ASSERT_OK(map_builder.Append());
+ ASSERT_OK(key_builder->AppendValues({0, 1}));
+ ASSERT_OK(item_builder->AppendValues({1, -1}, {1, 0}));
+ ASSERT_OK(map_builder.AppendNull());
+ ASSERT_OK(map_builder.Append());
+ ASSERT_OK(key_builder->Append(2));
+ ASSERT_OK(item_builder->Append(3));
+ ASSERT_OK(map_builder.Append());
+ ASSERT_OK(map_builder.Append());
+ ASSERT_OK(key_builder->AppendValues({3, 4}));
+ ASSERT_OK(item_builder->AppendValues({4, 5}));
+ ASSERT_OK(map_builder.Finish(&expected));
+ ASSERT_OK(expected->Validate());
+
+ map_builder.Reset();
+
+ // Build Map array like an Array of Structs using value builder
+ std::shared_ptr<Array> actual_map;
+ ASSERT_OK(BuildListOfStructPairs(map_builder, &actual_map));
+ ASSERT_ARRAYS_EQUAL(*actual_map, *expected);
+
+ map_builder.Reset();
+
+ // Build a ListArray of Structs, and compare MapArray to the List
+ auto map_type = internal::checked_pointer_cast<MapType>(map_builder.type());
+ auto struct_type = map_type->value_type();
+ std::vector<std::shared_ptr<ArrayBuilder>> child_builders{key_builder, item_builder};
+ auto struct_builder =
+ std::make_shared<StructBuilder>(struct_type, default_memory_pool(), child_builders);
+ ListBuilder list_builder(default_memory_pool(), struct_builder, map_type);
+
+ std::shared_ptr<Array> actual_list;
+ ASSERT_OK(BuildListOfStructPairs(list_builder, &actual_list));
+
+ MapArray* map_ptr = internal::checked_cast<MapArray*>(actual_map.get());
+ auto list_type = std::make_shared<ListType>(map_type->field(0));
+ ListArray map_as_list(list_type, map_ptr->length(), map_ptr->data()->buffers[1],
+ map_ptr->values(), actual_map->data()->buffers[0],
+ map_ptr->null_count());
+
+ ASSERT_ARRAYS_EQUAL(*actual_list, map_as_list);
+}
+
+// ----------------------------------------------------------------------
+// FixedSizeList tests
+
+class TestFixedSizeListArray : public TestBuilder {
+ public:
+ void SetUp() {
+ TestBuilder::SetUp();
+
+ value_type_ = int32();
+ type_ = fixed_size_list(value_type_, list_size());
+
+ std::unique_ptr<ArrayBuilder> tmp;
+ ASSERT_OK(MakeBuilder(pool_, type_, &tmp));
+ builder_.reset(checked_cast<FixedSizeListBuilder*>(tmp.release()));
+ }
+
+ void Done() {
+ std::shared_ptr<Array> out;
+ FinishAndCheckPadding(builder_.get(), &out);
+ result_ = std::dynamic_pointer_cast<FixedSizeListArray>(out);
+ }
+
+ protected:
+ static constexpr int32_t list_size() { return 2; }
+ std::shared_ptr<DataType> value_type_;
+
+ std::shared_ptr<FixedSizeListBuilder> builder_;
+ std::shared_ptr<FixedSizeListArray> result_;
+};
+
+TEST_F(TestFixedSizeListArray, Equality) {
+ Int32Builder* vb = checked_cast<Int32Builder*>(builder_->value_builder());
+
+ std::shared_ptr<Array> array, equal_array, unequal_array;
+ std::vector<int32_t> equal_values = {1, 2, 3, 4, 5, 2, 2, 2, 5, 6};
+ std::vector<int32_t> unequal_values = {1, 2, 2, 2, 3, 4, 5, 2};
+
+ // setup two equal arrays
+ ASSERT_OK(builder_->AppendValues(equal_values.size() / list_size()));
+ ASSERT_OK(vb->AppendValues(equal_values.data(), equal_values.size()));
+ ASSERT_OK(builder_->Finish(&array));
+
+ ASSERT_OK(builder_->AppendValues(equal_values.size() / list_size()));
+ ASSERT_OK(vb->AppendValues(equal_values.data(), equal_values.size()));
+
+ ASSERT_OK(builder_->Finish(&equal_array));
+
+ // now an unequal one
+ ASSERT_OK(builder_->AppendValues(unequal_values.size() / list_size()));
+ ASSERT_OK(vb->AppendValues(unequal_values.data(), unequal_values.size()));
+ ASSERT_OK(builder_->Finish(&unequal_array));
+
+ // Test array equality
+ AssertArraysEqual(*array, *array);
+ AssertArraysEqual(*array, *equal_array);
+ EXPECT_FALSE(equal_array->Equals(unequal_array));
+ EXPECT_FALSE(unequal_array->Equals(equal_array));
+
+ // Test range equality
+ EXPECT_TRUE(array->RangeEquals(0, 1, 0, unequal_array));
+ EXPECT_FALSE(array->RangeEquals(0, 2, 0, unequal_array));
+ EXPECT_FALSE(array->RangeEquals(1, 2, 1, unequal_array));
+ EXPECT_TRUE(array->RangeEquals(1, 3, 2, unequal_array));
+}
+
+TEST_F(TestFixedSizeListArray, TestAppendNull) {
+ ASSERT_OK(builder_->AppendNull());
+ ASSERT_OK(builder_->AppendNull());
+
+ Done();
+
+ ASSERT_OK(result_->ValidateFull());
+ ASSERT_TRUE(result_->IsNull(0));
+ ASSERT_TRUE(result_->IsNull(1));
+
+ ASSERT_EQ(0, result_->value_offset(0));
+ ASSERT_EQ(list_size(), result_->value_offset(1));
+
+ auto values = result_->values();
+ ASSERT_EQ(list_size() * 2, values->length());
+}
+
+TEST_F(TestFixedSizeListArray, TestAppendNulls) {
+ ASSERT_OK(builder_->AppendNulls(3));
+
+ Done();
+
+ ASSERT_OK(result_->ValidateFull());
+ ASSERT_EQ(result_->length(), 3);
+ ASSERT_EQ(result_->null_count(), 3);
+ ASSERT_TRUE(result_->IsNull(0));
+ ASSERT_TRUE(result_->IsNull(1));
+ ASSERT_TRUE(result_->IsNull(2));
+
+ ASSERT_EQ(0, result_->value_offset(0));
+ ASSERT_EQ(list_size(), result_->value_offset(1));
+ ASSERT_EQ(list_size() * 2, result_->value_offset(2));
+
+ auto values = result_->values();
+ ASSERT_EQ(list_size() * 3, values->length());
+}
+
+void ValidateBasicFixedSizeListArray(const FixedSizeListArray* result,
+ const std::vector<int32_t>& values,
+ const std::vector<uint8_t>& is_valid) {
+ ASSERT_OK(result->ValidateFull());
+ ASSERT_EQ(1, result->null_count());
+ ASSERT_LE(result->values()->null_count(), 2);
+
+ ASSERT_EQ(3, result->length());
+ for (int32_t i = 0; i < 3; ++i) {
+ ASSERT_EQ(i * result->value_length(), result->value_offset(i));
+ }
+
+ for (int i = 0; i < result->length(); ++i) {
+ ASSERT_EQ(is_valid[i] == 0, result->IsNull(i));
+ }
+
+ ASSERT_LE(result->length() * result->value_length(), result->values()->length());
+ auto varr = std::dynamic_pointer_cast<Int32Array>(result->values());
+
+ for (size_t i = 0; i < values.size(); ++i) {
+ if (is_valid[i / result->value_length()] == 0) {
+ continue;
+ }
+ ASSERT_EQ(values[i], varr->Value(i));
+ }
+}
+
+TEST_F(TestFixedSizeListArray, TestBasics) {
+ std::vector<int32_t> values = {0, 1, 2, 3, 4, 5};
+ std::vector<uint8_t> is_valid = {1, 0, 1};
+
+ Int32Builder* vb = checked_cast<Int32Builder*>(builder_->value_builder());
+
+ int pos = 0;
+ for (size_t i = 0; i < values.size() / list_size(); ++i) {
+ if (is_valid[i] == 0) {
+ ASSERT_OK(builder_->AppendNull());
+ pos += list_size();
+ continue;
+ }
+ ASSERT_OK(builder_->Append());
+ for (int j = 0; j < list_size(); ++j) {
+ ASSERT_OK(vb->Append(values[pos++]));
+ }
+ }
+
+ Done();
+ ValidateBasicFixedSizeListArray(result_.get(), values, is_valid);
+}
+
+TEST_F(TestFixedSizeListArray, BulkAppend) {
+ std::vector<int32_t> values = {0, 1, 2, 3, 4, 5};
+ std::vector<uint8_t> is_valid = {1, 0, 1};
+
+ Int32Builder* vb = checked_cast<Int32Builder*>(builder_->value_builder());
+
+ ASSERT_OK(builder_->AppendValues(values.size() / list_size(), is_valid.data()));
+ for (int32_t value : values) {
+ ASSERT_OK(vb->Append(value));
+ }
+ Done();
+ ValidateBasicFixedSizeListArray(result_.get(), values, is_valid);
+}
+
+TEST_F(TestFixedSizeListArray, BulkAppendExcess) {
+ std::vector<int32_t> values = {0, 1, 2, 3, 4, 5};
+ std::vector<uint8_t> is_valid = {1, 0, 1};
+
+ Int32Builder* vb = checked_cast<Int32Builder*>(builder_->value_builder());
+
+ ASSERT_OK(builder_->AppendValues(values.size() / list_size(), is_valid.data()));
+ for (int32_t value : values) {
+ ASSERT_OK(vb->Append(value));
+ }
+ for (int32_t value : values) {
+ ASSERT_OK(vb->Append(value));
+ }
+
+ Done();
+ // We appended too many values to the child array, but that's OK
+ ValidateBasicFixedSizeListArray(result_.get(), values, is_valid);
+}
+
+TEST_F(TestFixedSizeListArray, TestZeroLength) {
+ // All buffers are null
+ Done();
+ ASSERT_OK(result_->ValidateFull());
+}
+
+TEST_F(TestFixedSizeListArray, TestBuilderPreserveFieldName) {
+ auto list_type_with_name = fixed_size_list(field("counts", int32()), list_size());
+
+ std::unique_ptr<ArrayBuilder> tmp;
+ ASSERT_OK(MakeBuilder(pool_, list_type_with_name, &tmp));
+ builder_.reset(checked_cast<FixedSizeListBuilder*>(tmp.release()));
+
+ ASSERT_OK(builder_->AppendValues(4));
+
+ std::shared_ptr<Array> list_array;
+ ASSERT_OK(builder_->Finish(&list_array));
+
+ const auto& type = checked_cast<FixedSizeListType&>(*list_array->type());
+ ASSERT_EQ("counts", type.value_field()->name());
+}
+
+TEST_F(TestFixedSizeListArray, NegativeLength) {
+ type_ = fixed_size_list(value_type_, -42);
+ auto values = ArrayFromJSON(value_type_, "[]");
+ result_ = std::make_shared<FixedSizeListArray>(type_, 0, values);
+ ASSERT_RAISES(Invalid, result_->ValidateFull());
+}
+
+TEST_F(TestFixedSizeListArray, NotEnoughValues) {
+ type_ = fixed_size_list(value_type_, 2);
+ auto values = ArrayFromJSON(value_type_, "[]");
+ result_ = std::make_shared<FixedSizeListArray>(type_, 1, values);
+ ASSERT_RAISES(Invalid, result_->ValidateFull());
+
+ // ARROW-13437: too many values is OK though
+ values = ArrayFromJSON(value_type_, "[1, 2, 3, 4]");
+ result_ = std::make_shared<FixedSizeListArray>(type_, 1, values);
+ ASSERT_OK(result_->ValidateFull());
+}
+
+TEST_F(TestFixedSizeListArray, FlattenZeroLength) {
+ Done();
+ ASSERT_OK_AND_ASSIGN(auto flattened, result_->Flatten());
+ ASSERT_OK(flattened->ValidateFull());
+ ASSERT_EQ(0, flattened->length());
+ AssertTypeEqual(*flattened->type(), *value_type_);
+}
+
+TEST_F(TestFixedSizeListArray, Flatten) {
+ std::vector<int32_t> values = {0, 1, 2, 3, 4, 5, 6, 7};
+ std::vector<uint8_t> is_valid = {1, 0, 1, 1};
+ ASSERT_OK(builder_->AppendValues(4, is_valid.data()));
+ auto* vb = checked_cast<Int32Builder*>(builder_->value_builder());
+ ASSERT_OK(vb->AppendValues(values.data(), static_cast<int64_t>(values.size())));
+ Done();
+
+ {
+ ASSERT_OK_AND_ASSIGN(auto flattened, result_->Flatten());
+ ASSERT_OK(flattened->ValidateFull());
+ ASSERT_EQ(6, flattened->length());
+ AssertArraysEqual(*flattened, *ArrayFromJSON(value_type_, "[0, 1, 4, 5, 6, 7]"),
+ /*verbose=*/true);
+ }
+
+ {
+ auto sliced = std::dynamic_pointer_cast<FixedSizeListArray>(result_->Slice(1, 2));
+ ASSERT_OK_AND_ASSIGN(auto flattened, sliced->Flatten());
+ ASSERT_OK(flattened->ValidateFull());
+ ASSERT_EQ(2, flattened->length());
+ AssertArraysEqual(*flattened, *ArrayFromJSON(value_type_, "[4, 5]"),
+ /*verbose=*/true);
+ }
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_nested.cc b/src/arrow/cpp/src/arrow/array/array_nested.cc
new file mode 100644
index 000000000..22ad728a4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_nested.cc
@@ -0,0 +1,763 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/array_nested.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_primitive.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/array/util.h"
+#include "arrow/buffer.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/atomic_shared_ptr.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::BitmapAnd;
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+using internal::CopyBitmap;
+
+// ----------------------------------------------------------------------
+// ListArray / LargeListArray
+
+namespace {
+
+template <typename TYPE>
+Status CleanListOffsets(const Array& offsets, MemoryPool* pool,
+ std::shared_ptr<Buffer>* offset_buf_out,
+ std::shared_ptr<Buffer>* validity_buf_out) {
+ using offset_type = typename TYPE::offset_type;
+ using OffsetArrowType = typename CTypeTraits<offset_type>::ArrowType;
+ using OffsetArrayType = typename TypeTraits<OffsetArrowType>::ArrayType;
+
+ const auto& typed_offsets = checked_cast<const OffsetArrayType&>(offsets);
+ const int64_t num_offsets = offsets.length();
+
+ if (offsets.null_count() > 0) {
+ if (!offsets.IsValid(num_offsets - 1)) {
+ return Status::Invalid("Last list offset should be non-null");
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto clean_offsets,
+ AllocateBuffer(num_offsets * sizeof(offset_type), pool));
+
+ // Copy valid bits, ignoring the final offset (since for a length N list array,
+ // we have N + 1 offsets)
+ ARROW_ASSIGN_OR_RAISE(
+ auto clean_valid_bits,
+ offsets.null_bitmap()->CopySlice(0, BitUtil::BytesForBits(num_offsets - 1)));
+ *validity_buf_out = clean_valid_bits;
+
+ const offset_type* raw_offsets = typed_offsets.raw_values();
+ auto clean_raw_offsets =
+ reinterpret_cast<offset_type*>(clean_offsets->mutable_data());
+
+ // Must work backwards so we can tell how many values were in the last non-null value
+ offset_type current_offset = raw_offsets[num_offsets - 1];
+ for (int64_t i = num_offsets - 1; i >= 0; --i) {
+ if (offsets.IsValid(i)) {
+ current_offset = raw_offsets[i];
+ }
+ clean_raw_offsets[i] = current_offset;
+ }
+
+ *offset_buf_out = std::move(clean_offsets);
+ } else {
+ *validity_buf_out = offsets.null_bitmap();
+ *offset_buf_out = typed_offsets.values();
+ }
+
+ return Status::OK();
+}
+
+template <typename TYPE>
+Result<std::shared_ptr<typename TypeTraits<TYPE>::ArrayType>> ListArrayFromArrays(
+ const Array& offsets, const Array& values, MemoryPool* pool) {
+ using offset_type = typename TYPE::offset_type;
+ using ArrayType = typename TypeTraits<TYPE>::ArrayType;
+ using OffsetArrowType = typename CTypeTraits<offset_type>::ArrowType;
+
+ if (offsets.length() == 0) {
+ return Status::Invalid("List offsets must have non-zero length");
+ }
+
+ if (offsets.type_id() != OffsetArrowType::type_id) {
+ return Status::TypeError("List offsets must be ", OffsetArrowType::type_name());
+ }
+
+ std::shared_ptr<Buffer> offset_buf, validity_buf;
+ RETURN_NOT_OK(CleanListOffsets<TYPE>(offsets, pool, &offset_buf, &validity_buf));
+ BufferVector buffers = {validity_buf, offset_buf};
+
+ auto list_type = std::make_shared<TYPE>(values.type());
+ auto internal_data =
+ ArrayData::Make(list_type, offsets.length() - 1, std::move(buffers),
+ offsets.null_count(), offsets.offset());
+ internal_data->child_data.push_back(values.data());
+
+ return std::make_shared<ArrayType>(internal_data);
+}
+
+static std::shared_ptr<Array> SliceArrayWithOffsets(const Array& array, int64_t begin,
+ int64_t end) {
+ return array.Slice(begin, end - begin);
+}
+
+template <typename ListArrayT>
+Result<std::shared_ptr<Array>> FlattenListArray(const ListArrayT& list_array,
+ MemoryPool* memory_pool) {
+ const int64_t list_array_length = list_array.length();
+ std::shared_ptr<arrow::Array> value_array = list_array.values();
+
+ // Shortcut: if a ListArray does not contain nulls, then simply slice its
+ // value array with the first and the last offsets.
+ if (list_array.null_count() == 0) {
+ return SliceArrayWithOffsets(*value_array, list_array.value_offset(0),
+ list_array.value_offset(list_array_length));
+ }
+
+ // The ListArray contains nulls: there may be a non-empty sub-list behind
+ // a null and it must not be contained in the result.
+ std::vector<std::shared_ptr<Array>> non_null_fragments;
+ int64_t valid_begin = 0;
+ while (valid_begin < list_array_length) {
+ int64_t valid_end = valid_begin;
+ while (valid_end < list_array_length &&
+ (list_array.IsValid(valid_end) || list_array.value_length(valid_end) == 0)) {
+ ++valid_end;
+ }
+ if (valid_begin < valid_end) {
+ non_null_fragments.push_back(
+ SliceArrayWithOffsets(*value_array, list_array.value_offset(valid_begin),
+ list_array.value_offset(valid_end)));
+ }
+ valid_begin = valid_end + 1; // skip null entry
+ }
+
+ // Final attempt to avoid invoking Concatenate().
+ if (non_null_fragments.size() == 1) {
+ return non_null_fragments[0];
+ }
+
+ return Concatenate(non_null_fragments, memory_pool);
+}
+
+} // namespace
+
+namespace internal {
+
+template <typename TYPE>
+inline void SetListData(BaseListArray<TYPE>* self, const std::shared_ptr<ArrayData>& data,
+ Type::type expected_type_id) {
+ ARROW_CHECK_EQ(data->buffers.size(), 2);
+ ARROW_CHECK_EQ(data->type->id(), expected_type_id);
+ ARROW_CHECK_EQ(data->child_data.size(), 1);
+
+ self->Array::SetData(data);
+
+ self->list_type_ = checked_cast<const TYPE*>(data->type.get());
+ self->raw_value_offsets_ =
+ data->GetValuesSafe<typename TYPE::offset_type>(1, /*offset=*/0);
+
+ ARROW_CHECK_EQ(self->list_type_->value_type()->id(), data->child_data[0]->type->id());
+ DCHECK(self->list_type_->value_type()->Equals(data->child_data[0]->type));
+ self->values_ = MakeArray(self->data_->child_data[0]);
+}
+
+} // namespace internal
+
+ListArray::ListArray(std::shared_ptr<ArrayData> data) { SetData(std::move(data)); }
+
+LargeListArray::LargeListArray(const std::shared_ptr<ArrayData>& data) { SetData(data); }
+
+ListArray::ListArray(std::shared_ptr<DataType> type, int64_t length,
+ std::shared_ptr<Buffer> value_offsets, std::shared_ptr<Array> values,
+ std::shared_ptr<Buffer> null_bitmap, int64_t null_count,
+ int64_t offset) {
+ ARROW_CHECK_EQ(type->id(), Type::LIST);
+ auto internal_data = ArrayData::Make(
+ std::move(type), length,
+ BufferVector{std::move(null_bitmap), std::move(value_offsets)}, null_count, offset);
+ internal_data->child_data.emplace_back(values->data());
+ SetData(std::move(internal_data));
+}
+
+void ListArray::SetData(const std::shared_ptr<ArrayData>& data) {
+ internal::SetListData(this, data);
+}
+
+LargeListArray::LargeListArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::shared_ptr<Buffer>& value_offsets,
+ const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Buffer>& null_bitmap,
+ int64_t null_count, int64_t offset) {
+ ARROW_CHECK_EQ(type->id(), Type::LARGE_LIST);
+ auto internal_data =
+ ArrayData::Make(type, length, {null_bitmap, value_offsets}, null_count, offset);
+ internal_data->child_data.emplace_back(values->data());
+ SetData(internal_data);
+}
+
+void LargeListArray::SetData(const std::shared_ptr<ArrayData>& data) {
+ internal::SetListData(this, data);
+}
+
+Result<std::shared_ptr<ListArray>> ListArray::FromArrays(const Array& offsets,
+ const Array& values,
+ MemoryPool* pool) {
+ return ListArrayFromArrays<ListType>(offsets, values, pool);
+}
+
+Result<std::shared_ptr<LargeListArray>> LargeListArray::FromArrays(const Array& offsets,
+ const Array& values,
+ MemoryPool* pool) {
+ return ListArrayFromArrays<LargeListType>(offsets, values, pool);
+}
+
+Result<std::shared_ptr<Array>> ListArray::Flatten(MemoryPool* memory_pool) const {
+ return FlattenListArray(*this, memory_pool);
+}
+
+Result<std::shared_ptr<Array>> LargeListArray::Flatten(MemoryPool* memory_pool) const {
+ return FlattenListArray(*this, memory_pool);
+}
+
+static std::shared_ptr<Array> BoxOffsets(const std::shared_ptr<DataType>& boxed_type,
+ const ArrayData& data) {
+ std::vector<std::shared_ptr<Buffer>> buffers = {nullptr, data.buffers[1]};
+ auto offsets_data =
+ std::make_shared<ArrayData>(boxed_type, data.length + 1, std::move(buffers),
+ /*null_count=*/0, data.offset);
+ return MakeArray(offsets_data);
+}
+
+std::shared_ptr<Array> ListArray::offsets() const { return BoxOffsets(int32(), *data_); }
+
+std::shared_ptr<Array> LargeListArray::offsets() const {
+ return BoxOffsets(int64(), *data_);
+}
+
+// ----------------------------------------------------------------------
+// MapArray
+
+MapArray::MapArray(const std::shared_ptr<ArrayData>& data) { SetData(data); }
+
+MapArray::MapArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::shared_ptr<Buffer>& offsets,
+ const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Buffer>& null_bitmap, int64_t null_count,
+ int64_t offset) {
+ SetData(ArrayData::Make(type, length, {null_bitmap, offsets}, {values->data()},
+ null_count, offset));
+}
+
+MapArray::MapArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::shared_ptr<Buffer>& offsets,
+ const std::shared_ptr<Array>& keys,
+ const std::shared_ptr<Array>& items,
+ const std::shared_ptr<Buffer>& null_bitmap, int64_t null_count,
+ int64_t offset) {
+ auto pair_data = ArrayData::Make(type->fields()[0]->type(), keys->data()->length,
+ {nullptr}, {keys->data(), items->data()}, 0, offset);
+ auto map_data = ArrayData::Make(type, length, {null_bitmap, offsets}, {pair_data},
+ null_count, offset);
+ SetData(map_data);
+}
+
+Result<std::shared_ptr<Array>> MapArray::FromArraysInternal(
+ std::shared_ptr<DataType> type, const std::shared_ptr<Array>& offsets,
+ const std::shared_ptr<Array>& keys, const std::shared_ptr<Array>& items,
+ MemoryPool* pool) {
+ using offset_type = typename MapType::offset_type;
+ using OffsetArrowType = typename CTypeTraits<offset_type>::ArrowType;
+
+ if (offsets->length() == 0) {
+ return Status::Invalid("Map offsets must have non-zero length");
+ }
+
+ if (offsets->type_id() != OffsetArrowType::type_id) {
+ return Status::TypeError("Map offsets must be ", OffsetArrowType::type_name());
+ }
+
+ if (keys->null_count() != 0) {
+ return Status::Invalid("Map can not contain NULL valued keys");
+ }
+
+ if (keys->length() != items->length()) {
+ return Status::Invalid("Map key and item arrays must be equal length");
+ }
+
+ std::shared_ptr<Buffer> offset_buf, validity_buf;
+ RETURN_NOT_OK(CleanListOffsets<MapType>(*offsets, pool, &offset_buf, &validity_buf));
+
+ return std::make_shared<MapArray>(type, offsets->length() - 1, offset_buf, keys, items,
+ validity_buf, offsets->null_count(),
+ offsets->offset());
+}
+
+Result<std::shared_ptr<Array>> MapArray::FromArrays(const std::shared_ptr<Array>& offsets,
+ const std::shared_ptr<Array>& keys,
+ const std::shared_ptr<Array>& items,
+ MemoryPool* pool) {
+ return FromArraysInternal(std::make_shared<MapType>(keys->type(), items->type()),
+ offsets, keys, items, pool);
+}
+
+Result<std::shared_ptr<Array>> MapArray::FromArrays(std::shared_ptr<DataType> type,
+ const std::shared_ptr<Array>& offsets,
+ const std::shared_ptr<Array>& keys,
+ const std::shared_ptr<Array>& items,
+ MemoryPool* pool) {
+ if (type->id() != Type::MAP) {
+ return Status::TypeError("Expected map type, got ", type->ToString());
+ }
+ const auto& map_type = checked_cast<const MapType&>(*type);
+ if (!map_type.key_type()->Equals(keys->type())) {
+ return Status::TypeError("Mismatching map keys type");
+ }
+ if (!map_type.item_type()->Equals(items->type())) {
+ return Status::TypeError("Mismatching map items type");
+ }
+ return FromArraysInternal(std::move(type), offsets, keys, items, pool);
+}
+
+Status MapArray::ValidateChildData(
+ const std::vector<std::shared_ptr<ArrayData>>& child_data) {
+ if (child_data.size() != 1) {
+ return Status::Invalid("Expected one child array for map array");
+ }
+ const auto& pair_data = child_data[0];
+ if (pair_data->type->id() != Type::STRUCT) {
+ return Status::Invalid("Map array child array should have struct type");
+ }
+ if (pair_data->null_count != 0) {
+ return Status::Invalid("Map array child array should have no nulls");
+ }
+ if (pair_data->child_data.size() != 2) {
+ return Status::Invalid("Map array child array should have two fields");
+ }
+ if (pair_data->child_data[0]->null_count != 0) {
+ return Status::Invalid("Map array keys array should have no nulls");
+ }
+ return Status::OK();
+}
+
+void MapArray::SetData(const std::shared_ptr<ArrayData>& data) {
+ ARROW_CHECK_OK(ValidateChildData(data->child_data));
+
+ internal::SetListData(this, data, Type::MAP);
+ map_type_ = checked_cast<const MapType*>(data->type.get());
+ const auto& pair_data = data->child_data[0];
+ keys_ = MakeArray(pair_data->child_data[0]);
+ items_ = MakeArray(pair_data->child_data[1]);
+}
+
+// ----------------------------------------------------------------------
+// FixedSizeListArray
+
+FixedSizeListArray::FixedSizeListArray(const std::shared_ptr<ArrayData>& data) {
+ SetData(data);
+}
+
+FixedSizeListArray::FixedSizeListArray(const std::shared_ptr<DataType>& type,
+ int64_t length,
+ const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Buffer>& null_bitmap,
+ int64_t null_count, int64_t offset) {
+ auto internal_data = ArrayData::Make(type, length, {null_bitmap}, null_count, offset);
+ internal_data->child_data.emplace_back(values->data());
+ SetData(internal_data);
+}
+
+void FixedSizeListArray::SetData(const std::shared_ptr<ArrayData>& data) {
+ ARROW_CHECK_EQ(data->type->id(), Type::FIXED_SIZE_LIST);
+ this->Array::SetData(data);
+
+ ARROW_CHECK_EQ(list_type()->value_type()->id(), data->child_data[0]->type->id());
+ DCHECK(list_type()->value_type()->Equals(data->child_data[0]->type));
+ list_size_ = list_type()->list_size();
+
+ ARROW_CHECK_EQ(data_->child_data.size(), 1);
+ values_ = MakeArray(data_->child_data[0]);
+}
+
+const FixedSizeListType* FixedSizeListArray::list_type() const {
+ return checked_cast<const FixedSizeListType*>(data_->type.get());
+}
+
+std::shared_ptr<DataType> FixedSizeListArray::value_type() const {
+ return list_type()->value_type();
+}
+
+std::shared_ptr<Array> FixedSizeListArray::values() const { return values_; }
+
+Result<std::shared_ptr<Array>> FixedSizeListArray::FromArrays(
+ const std::shared_ptr<Array>& values, int32_t list_size) {
+ if (list_size <= 0) {
+ return Status::Invalid("list_size needs to be a strict positive integer");
+ }
+
+ if ((values->length() % list_size) != 0) {
+ return Status::Invalid(
+ "The length of the values Array needs to be a multiple of the list_size");
+ }
+ int64_t length = values->length() / list_size;
+ auto list_type = std::make_shared<FixedSizeListType>(values->type(), list_size);
+ std::shared_ptr<Buffer> validity_buf;
+
+ return std::make_shared<FixedSizeListArray>(list_type, length, values, validity_buf,
+ /*null_count=*/0, /*offset=*/0);
+}
+
+Result<std::shared_ptr<Array>> FixedSizeListArray::Flatten(
+ MemoryPool* memory_pool) const {
+ return FlattenListArray(*this, memory_pool);
+}
+
+// ----------------------------------------------------------------------
+// Struct
+
+StructArray::StructArray(const std::shared_ptr<ArrayData>& data) {
+ ARROW_CHECK_EQ(data->type->id(), Type::STRUCT);
+ SetData(data);
+ boxed_fields_.resize(data->child_data.size());
+}
+
+StructArray::StructArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::vector<std::shared_ptr<Array>>& children,
+ std::shared_ptr<Buffer> null_bitmap, int64_t null_count,
+ int64_t offset) {
+ ARROW_CHECK_EQ(type->id(), Type::STRUCT);
+ SetData(ArrayData::Make(type, length, {null_bitmap}, null_count, offset));
+ for (const auto& child : children) {
+ data_->child_data.push_back(child->data());
+ }
+ boxed_fields_.resize(children.size());
+}
+
+Result<std::shared_ptr<StructArray>> StructArray::Make(
+ const std::vector<std::shared_ptr<Array>>& children,
+ const std::vector<std::shared_ptr<Field>>& fields,
+ std::shared_ptr<Buffer> null_bitmap, int64_t null_count, int64_t offset) {
+ if (children.size() != fields.size()) {
+ return Status::Invalid("Mismatching number of fields and child arrays");
+ }
+ int64_t length = 0;
+ if (children.size() == 0) {
+ return Status::Invalid("Can't infer struct array length with 0 child arrays");
+ }
+ length = children.front()->length();
+ for (const auto& child : children) {
+ if (length != child->length()) {
+ return Status::Invalid("Mismatching child array lengths");
+ }
+ }
+ if (offset > length) {
+ return Status::IndexError("Offset greater than length of child arrays");
+ }
+ if (null_bitmap == nullptr) {
+ if (null_count > 0) {
+ return Status::Invalid("null_count = ", null_count, " but no null bitmap given");
+ }
+ null_count = 0;
+ }
+ return std::make_shared<StructArray>(struct_(fields), length - offset, children,
+ null_bitmap, null_count, offset);
+}
+
+Result<std::shared_ptr<StructArray>> StructArray::Make(
+ const std::vector<std::shared_ptr<Array>>& children,
+ const std::vector<std::string>& field_names, std::shared_ptr<Buffer> null_bitmap,
+ int64_t null_count, int64_t offset) {
+ if (children.size() != field_names.size()) {
+ return Status::Invalid("Mismatching number of field names and child arrays");
+ }
+ std::vector<std::shared_ptr<Field>> fields(children.size());
+ for (size_t i = 0; i < children.size(); ++i) {
+ fields[i] = ::arrow::field(field_names[i], children[i]->type());
+ }
+ return Make(children, fields, std::move(null_bitmap), null_count, offset);
+}
+
+const StructType* StructArray::struct_type() const {
+ return checked_cast<const StructType*>(data_->type.get());
+}
+
+const ArrayVector& StructArray::fields() const {
+ for (int i = 0; i < num_fields(); ++i) {
+ (void)field(i);
+ }
+ return boxed_fields_;
+}
+
+std::shared_ptr<Array> StructArray::field(int i) const {
+ std::shared_ptr<Array> result = internal::atomic_load(&boxed_fields_[i]);
+ if (!result) {
+ std::shared_ptr<ArrayData> field_data;
+ if (data_->offset != 0 || data_->child_data[i]->length != data_->length) {
+ field_data = data_->child_data[i]->Slice(data_->offset, data_->length);
+ } else {
+ field_data = data_->child_data[i];
+ }
+ result = MakeArray(field_data);
+ internal::atomic_store(&boxed_fields_[i], result);
+ }
+ return result;
+}
+
+std::shared_ptr<Array> StructArray::GetFieldByName(const std::string& name) const {
+ int i = struct_type()->GetFieldIndex(name);
+ return i == -1 ? nullptr : field(i);
+}
+
+Result<ArrayVector> StructArray::Flatten(MemoryPool* pool) const {
+ ArrayVector flattened;
+ flattened.reserve(data_->child_data.size());
+ std::shared_ptr<Buffer> null_bitmap = data_->buffers[0];
+
+ for (const auto& child_data_ptr : data_->child_data) {
+ auto child_data = child_data_ptr->Copy();
+
+ std::shared_ptr<Buffer> flattened_null_bitmap;
+ int64_t flattened_null_count = kUnknownNullCount;
+
+ // Need to adjust for parent offset
+ if (data_->offset != 0 || data_->length != child_data->length) {
+ child_data = child_data->Slice(data_->offset, data_->length);
+ }
+ std::shared_ptr<Buffer> child_null_bitmap = child_data->buffers[0];
+ const int64_t child_offset = child_data->offset;
+
+ // The validity of a flattened datum is the logical AND of the struct
+ // element's validity and the individual field element's validity.
+ if (null_bitmap && child_null_bitmap) {
+ ARROW_ASSIGN_OR_RAISE(
+ flattened_null_bitmap,
+ BitmapAnd(pool, child_null_bitmap->data(), child_offset, null_bitmap_data_,
+ data_->offset, data_->length, child_offset));
+ } else if (child_null_bitmap) {
+ flattened_null_bitmap = child_null_bitmap;
+ flattened_null_count = child_data->null_count;
+ } else if (null_bitmap) {
+ if (child_offset == data_->offset) {
+ flattened_null_bitmap = null_bitmap;
+ } else {
+ // If the child has an offset, need to synthesize a validity
+ // buffer with an offset too
+ ARROW_ASSIGN_OR_RAISE(flattened_null_bitmap,
+ AllocateEmptyBitmap(child_offset + data_->length, pool));
+ CopyBitmap(null_bitmap_data_, data_->offset, data_->length,
+ flattened_null_bitmap->mutable_data(), child_offset);
+ }
+ flattened_null_count = data_->null_count;
+ } else {
+ flattened_null_count = 0;
+ }
+
+ auto flattened_data = child_data->Copy();
+ flattened_data->buffers[0] = flattened_null_bitmap;
+ flattened_data->null_count = flattened_null_count;
+
+ flattened.push_back(MakeArray(flattened_data));
+ }
+
+ return flattened;
+}
+
+// ----------------------------------------------------------------------
+// UnionArray
+
+void UnionArray::SetData(std::shared_ptr<ArrayData> data) {
+ this->Array::SetData(std::move(data));
+
+ union_type_ = checked_cast<const UnionType*>(data_->type.get());
+
+ ARROW_CHECK_GE(data_->buffers.size(), 2);
+ raw_type_codes_ = data->GetValuesSafe<int8_t>(1, /*offset=*/0);
+ boxed_fields_.resize(data_->child_data.size());
+}
+
+void SparseUnionArray::SetData(std::shared_ptr<ArrayData> data) {
+ this->UnionArray::SetData(std::move(data));
+ ARROW_CHECK_EQ(data_->type->id(), Type::SPARSE_UNION);
+ ARROW_CHECK_EQ(data_->buffers.size(), 2);
+
+ // No validity bitmap
+ ARROW_CHECK_EQ(data_->buffers[0], nullptr);
+}
+
+void DenseUnionArray::SetData(const std::shared_ptr<ArrayData>& data) {
+ this->UnionArray::SetData(std::move(data));
+
+ ARROW_CHECK_EQ(data_->type->id(), Type::DENSE_UNION);
+ ARROW_CHECK_EQ(data_->buffers.size(), 3);
+
+ // No validity bitmap
+ ARROW_CHECK_EQ(data_->buffers[0], nullptr);
+
+ raw_value_offsets_ = data->GetValuesSafe<int32_t>(2, /*offset=*/0);
+}
+
+SparseUnionArray::SparseUnionArray(std::shared_ptr<ArrayData> data) {
+ SetData(std::move(data));
+}
+
+SparseUnionArray::SparseUnionArray(std::shared_ptr<DataType> type, int64_t length,
+ ArrayVector children,
+ std::shared_ptr<Buffer> type_codes, int64_t offset) {
+ auto internal_data = ArrayData::Make(std::move(type), length,
+ BufferVector{nullptr, std::move(type_codes)},
+ /*null_count=*/0, offset);
+ for (const auto& child : children) {
+ internal_data->child_data.push_back(child->data());
+ }
+ SetData(std::move(internal_data));
+}
+
+DenseUnionArray::DenseUnionArray(const std::shared_ptr<ArrayData>& data) {
+ SetData(data);
+}
+
+DenseUnionArray::DenseUnionArray(std::shared_ptr<DataType> type, int64_t length,
+ ArrayVector children, std::shared_ptr<Buffer> type_ids,
+ std::shared_ptr<Buffer> value_offsets, int64_t offset) {
+ auto internal_data = ArrayData::Make(
+ std::move(type), length,
+ BufferVector{nullptr, std::move(type_ids), std::move(value_offsets)},
+ /*null_count=*/0, offset);
+ for (const auto& child : children) {
+ internal_data->child_data.push_back(child->data());
+ }
+ SetData(internal_data);
+}
+
+Result<std::shared_ptr<Array>> DenseUnionArray::Make(
+ const Array& type_ids, const Array& value_offsets, ArrayVector children,
+ std::vector<std::string> field_names, std::vector<type_code_t> type_codes) {
+ if (value_offsets.length() == 0) {
+ return Status::Invalid("UnionArray offsets must have non-zero length");
+ }
+
+ if (value_offsets.type_id() != Type::INT32) {
+ return Status::TypeError("UnionArray offsets must be signed int32");
+ }
+
+ if (type_ids.type_id() != Type::INT8) {
+ return Status::TypeError("UnionArray type_ids must be signed int8");
+ }
+
+ if (type_ids.null_count() != 0) {
+ return Status::Invalid("Union type ids may not have nulls");
+ }
+
+ if (value_offsets.null_count() != 0) {
+ return Status::Invalid("Make does not allow nulls in value_offsets");
+ }
+
+ if (field_names.size() > 0 && field_names.size() != children.size()) {
+ return Status::Invalid("field_names must have the same length as children");
+ }
+
+ if (type_codes.size() > 0 && type_codes.size() != children.size()) {
+ return Status::Invalid("type_codes must have the same length as children");
+ }
+
+ BufferVector buffers = {nullptr, checked_cast<const Int8Array&>(type_ids).values(),
+ checked_cast<const Int32Array&>(value_offsets).values()};
+
+ auto union_type = dense_union(children, std::move(field_names), std::move(type_codes));
+ auto internal_data =
+ ArrayData::Make(std::move(union_type), type_ids.length(), std::move(buffers),
+ /*null_count=*/0, type_ids.offset());
+ for (const auto& child : children) {
+ internal_data->child_data.push_back(child->data());
+ }
+ return std::make_shared<DenseUnionArray>(std::move(internal_data));
+}
+
+Result<std::shared_ptr<Array>> SparseUnionArray::Make(
+ const Array& type_ids, ArrayVector children, std::vector<std::string> field_names,
+ std::vector<int8_t> type_codes) {
+ if (type_ids.type_id() != Type::INT8) {
+ return Status::TypeError("UnionArray type_ids must be signed int8");
+ }
+
+ if (type_ids.null_count() != 0) {
+ return Status::Invalid("Union type ids may not have nulls");
+ }
+
+ if (field_names.size() > 0 && field_names.size() != children.size()) {
+ return Status::Invalid("field_names must have the same length as children");
+ }
+
+ if (type_codes.size() > 0 && type_codes.size() != children.size()) {
+ return Status::Invalid("type_codes must have the same length as children");
+ }
+
+ BufferVector buffers = {nullptr, checked_cast<const Int8Array&>(type_ids).values()};
+ auto union_type = sparse_union(children, std::move(field_names), std::move(type_codes));
+ auto internal_data =
+ ArrayData::Make(std::move(union_type), type_ids.length(), std::move(buffers),
+ /*null_count=*/0, type_ids.offset());
+ for (const auto& child : children) {
+ internal_data->child_data.push_back(child->data());
+ if (child->length() != type_ids.length()) {
+ return Status::Invalid(
+ "Sparse UnionArray must have len(child) == len(type_ids) for all children");
+ }
+ }
+ return std::make_shared<SparseUnionArray>(std::move(internal_data));
+}
+
+std::shared_ptr<Array> UnionArray::field(int i) const {
+ if (i < 0 ||
+ static_cast<decltype(boxed_fields_)::size_type>(i) >= boxed_fields_.size()) {
+ return nullptr;
+ }
+ std::shared_ptr<Array> result = internal::atomic_load(&boxed_fields_[i]);
+ if (!result) {
+ std::shared_ptr<ArrayData> child_data = data_->child_data[i]->Copy();
+ if (mode() == UnionMode::SPARSE) {
+ // Sparse union: need to adjust child if union is sliced
+ // (for dense unions, the need to lookup through the offsets
+ // makes this unnecessary)
+ if (data_->offset != 0 || child_data->length > data_->length) {
+ child_data = child_data->Slice(data_->offset, data_->length);
+ }
+ }
+ result = MakeArray(child_data);
+ internal::atomic_store(&boxed_fields_[i], result);
+ }
+ return result;
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_nested.h b/src/arrow/cpp/src/arrow/array/array_nested.h
new file mode 100644
index 000000000..762ba24f2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_nested.h
@@ -0,0 +1,533 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Array accessor classes for List, LargeList, FixedSizeList, Map, Struct, and
+// Union
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/data.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+/// \addtogroup nested-arrays
+///
+/// @{
+
+// ----------------------------------------------------------------------
+// ListArray
+
+template <typename TYPE>
+class BaseListArray;
+
+namespace internal {
+
+// Private helper for ListArray::SetData.
+// Unfortunately, trying to define BaseListArray::SetData outside of this header
+// doesn't play well with MSVC.
+template <typename TYPE>
+void SetListData(BaseListArray<TYPE>* self, const std::shared_ptr<ArrayData>& data,
+ Type::type expected_type_id = TYPE::type_id);
+
+} // namespace internal
+
+/// Base class for variable-sized list arrays, regardless of offset size.
+template <typename TYPE>
+class BaseListArray : public Array {
+ public:
+ using TypeClass = TYPE;
+ using offset_type = typename TypeClass::offset_type;
+
+ const TypeClass* list_type() const { return list_type_; }
+
+ /// \brief Return array object containing the list's values
+ std::shared_ptr<Array> values() const { return values_; }
+
+ /// Note that this buffer does not account for any slice offset
+ std::shared_ptr<Buffer> value_offsets() const { return data_->buffers[1]; }
+
+ std::shared_ptr<DataType> value_type() const { return list_type_->value_type(); }
+
+ /// Return pointer to raw value offsets accounting for any slice offset
+ const offset_type* raw_value_offsets() const {
+ return raw_value_offsets_ + data_->offset;
+ }
+
+ // The following functions will not perform boundschecking
+ offset_type value_offset(int64_t i) const {
+ return raw_value_offsets_[i + data_->offset];
+ }
+ offset_type value_length(int64_t i) const {
+ i += data_->offset;
+ return raw_value_offsets_[i + 1] - raw_value_offsets_[i];
+ }
+ std::shared_ptr<Array> value_slice(int64_t i) const {
+ return values_->Slice(value_offset(i), value_length(i));
+ }
+
+ protected:
+ friend void internal::SetListData<TYPE>(BaseListArray<TYPE>* self,
+ const std::shared_ptr<ArrayData>& data,
+ Type::type expected_type_id);
+
+ const TypeClass* list_type_ = NULLPTR;
+ std::shared_ptr<Array> values_;
+ const offset_type* raw_value_offsets_ = NULLPTR;
+};
+
+/// Concrete Array class for list data
+class ARROW_EXPORT ListArray : public BaseListArray<ListType> {
+ public:
+ explicit ListArray(std::shared_ptr<ArrayData> data);
+
+ ListArray(std::shared_ptr<DataType> type, int64_t length,
+ std::shared_ptr<Buffer> value_offsets, std::shared_ptr<Array> values,
+ std::shared_ptr<Buffer> null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ /// \brief Construct ListArray from array of offsets and child value array
+ ///
+ /// This function does the bare minimum of validation of the offsets and
+ /// input types, and will allocate a new offsets array if necessary (i.e. if
+ /// the offsets contain any nulls). If the offsets do not have nulls, they
+ /// are assumed to be well-formed
+ ///
+ /// \param[in] offsets Array containing n + 1 offsets encoding length and
+ /// size. Must be of int32 type
+ /// \param[in] values Array containing list values
+ /// \param[in] pool MemoryPool in case new offsets array needs to be
+ /// allocated because of null values
+ static Result<std::shared_ptr<ListArray>> FromArrays(
+ const Array& offsets, const Array& values,
+ MemoryPool* pool = default_memory_pool());
+
+ /// \brief Return an Array that is a concatenation of the lists in this array.
+ ///
+ /// Note that it's different from `values()` in that it takes into
+ /// consideration of this array's offsets as well as null elements backed
+ /// by non-empty lists (they are skipped, thus copying may be needed).
+ Result<std::shared_ptr<Array>> Flatten(
+ MemoryPool* memory_pool = default_memory_pool()) const;
+
+ /// \brief Return list offsets as an Int32Array
+ std::shared_ptr<Array> offsets() const;
+
+ protected:
+ // This constructor defers SetData to a derived array class
+ ListArray() = default;
+
+ void SetData(const std::shared_ptr<ArrayData>& data);
+};
+
+/// Concrete Array class for large list data (with 64-bit offsets)
+class ARROW_EXPORT LargeListArray : public BaseListArray<LargeListType> {
+ public:
+ explicit LargeListArray(const std::shared_ptr<ArrayData>& data);
+
+ LargeListArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::shared_ptr<Buffer>& value_offsets,
+ const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ /// \brief Construct LargeListArray from array of offsets and child value array
+ ///
+ /// This function does the bare minimum of validation of the offsets and
+ /// input types, and will allocate a new offsets array if necessary (i.e. if
+ /// the offsets contain any nulls). If the offsets do not have nulls, they
+ /// are assumed to be well-formed
+ ///
+ /// \param[in] offsets Array containing n + 1 offsets encoding length and
+ /// size. Must be of int64 type
+ /// \param[in] values Array containing list values
+ /// \param[in] pool MemoryPool in case new offsets array needs to be
+ /// allocated because of null values
+ static Result<std::shared_ptr<LargeListArray>> FromArrays(
+ const Array& offsets, const Array& values,
+ MemoryPool* pool = default_memory_pool());
+
+ /// \brief Return an Array that is a concatenation of the lists in this array.
+ ///
+ /// Note that it's different from `values()` in that it takes into
+ /// consideration of this array's offsets as well as null elements backed
+ /// by non-empty lists (they are skipped, thus copying may be needed).
+ Result<std::shared_ptr<Array>> Flatten(
+ MemoryPool* memory_pool = default_memory_pool()) const;
+
+ /// \brief Return list offsets as an Int64Array
+ std::shared_ptr<Array> offsets() const;
+
+ protected:
+ void SetData(const std::shared_ptr<ArrayData>& data);
+};
+
+// ----------------------------------------------------------------------
+// MapArray
+
+/// Concrete Array class for map data
+///
+/// NB: "value" in this context refers to a pair of a key and the corresponding item
+class ARROW_EXPORT MapArray : public ListArray {
+ public:
+ using TypeClass = MapType;
+
+ explicit MapArray(const std::shared_ptr<ArrayData>& data);
+
+ MapArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::shared_ptr<Buffer>& value_offsets,
+ const std::shared_ptr<Array>& keys, const std::shared_ptr<Array>& items,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ MapArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::shared_ptr<Buffer>& value_offsets,
+ const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ /// \brief Construct MapArray from array of offsets and child key, item arrays
+ ///
+ /// This function does the bare minimum of validation of the offsets and
+ /// input types, and will allocate a new offsets array if necessary (i.e. if
+ /// the offsets contain any nulls). If the offsets do not have nulls, they
+ /// are assumed to be well-formed
+ ///
+ /// \param[in] offsets Array containing n + 1 offsets encoding length and
+ /// size. Must be of int32 type
+ /// \param[in] keys Array containing key values
+ /// \param[in] items Array containing item values
+ /// \param[in] pool MemoryPool in case new offsets array needs to be
+ /// allocated because of null values
+ static Result<std::shared_ptr<Array>> FromArrays(
+ const std::shared_ptr<Array>& offsets, const std::shared_ptr<Array>& keys,
+ const std::shared_ptr<Array>& items, MemoryPool* pool = default_memory_pool());
+
+ static Result<std::shared_ptr<Array>> FromArrays(
+ std::shared_ptr<DataType> type, const std::shared_ptr<Array>& offsets,
+ const std::shared_ptr<Array>& keys, const std::shared_ptr<Array>& items,
+ MemoryPool* pool = default_memory_pool());
+
+ const MapType* map_type() const { return map_type_; }
+
+ /// \brief Return array object containing all map keys
+ std::shared_ptr<Array> keys() const { return keys_; }
+
+ /// \brief Return array object containing all mapped items
+ std::shared_ptr<Array> items() const { return items_; }
+
+ /// Validate child data before constructing the actual MapArray.
+ static Status ValidateChildData(
+ const std::vector<std::shared_ptr<ArrayData>>& child_data);
+
+ protected:
+ void SetData(const std::shared_ptr<ArrayData>& data);
+
+ static Result<std::shared_ptr<Array>> FromArraysInternal(
+ std::shared_ptr<DataType> type, const std::shared_ptr<Array>& offsets,
+ const std::shared_ptr<Array>& keys, const std::shared_ptr<Array>& items,
+ MemoryPool* pool);
+
+ private:
+ const MapType* map_type_;
+ std::shared_ptr<Array> keys_, items_;
+};
+
+// ----------------------------------------------------------------------
+// FixedSizeListArray
+
+/// Concrete Array class for fixed size list data
+class ARROW_EXPORT FixedSizeListArray : public Array {
+ public:
+ using TypeClass = FixedSizeListType;
+ using offset_type = TypeClass::offset_type;
+
+ explicit FixedSizeListArray(const std::shared_ptr<ArrayData>& data);
+
+ FixedSizeListArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ const FixedSizeListType* list_type() const;
+
+ /// \brief Return array object containing the list's values
+ std::shared_ptr<Array> values() const;
+
+ std::shared_ptr<DataType> value_type() const;
+
+ // The following functions will not perform boundschecking
+ int32_t value_offset(int64_t i) const {
+ i += data_->offset;
+ return static_cast<int32_t>(list_size_ * i);
+ }
+ int32_t value_length(int64_t i = 0) const {
+ ARROW_UNUSED(i);
+ return list_size_;
+ }
+ std::shared_ptr<Array> value_slice(int64_t i) const {
+ return values_->Slice(value_offset(i), value_length(i));
+ }
+
+ /// \brief Return an Array that is a concatenation of the lists in this array.
+ ///
+ /// Note that it's different from `values()` in that it takes into
+ /// consideration null elements (they are skipped, thus copying may be needed).
+ Result<std::shared_ptr<Array>> Flatten(
+ MemoryPool* memory_pool = default_memory_pool()) const;
+
+ /// \brief Construct FixedSizeListArray from child value array and value_length
+ ///
+ /// \param[in] values Array containing list values
+ /// \param[in] list_size The fixed length of each list
+ /// \return Will have length equal to values.length() / list_size
+ static Result<std::shared_ptr<Array>> FromArrays(const std::shared_ptr<Array>& values,
+ int32_t list_size);
+
+ protected:
+ void SetData(const std::shared_ptr<ArrayData>& data);
+ int32_t list_size_;
+
+ private:
+ std::shared_ptr<Array> values_;
+};
+
+// ----------------------------------------------------------------------
+// Struct
+
+/// Concrete Array class for struct data
+class ARROW_EXPORT StructArray : public Array {
+ public:
+ using TypeClass = StructType;
+
+ explicit StructArray(const std::shared_ptr<ArrayData>& data);
+
+ StructArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::vector<std::shared_ptr<Array>>& children,
+ std::shared_ptr<Buffer> null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ /// \brief Return a StructArray from child arrays and field names.
+ ///
+ /// The length and data type are automatically inferred from the arguments.
+ /// There should be at least one child array.
+ static Result<std::shared_ptr<StructArray>> Make(
+ const ArrayVector& children, const std::vector<std::string>& field_names,
+ std::shared_ptr<Buffer> null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ /// \brief Return a StructArray from child arrays and fields.
+ ///
+ /// The length is automatically inferred from the arguments.
+ /// There should be at least one child array. This method does not
+ /// check that field types and child array types are consistent.
+ static Result<std::shared_ptr<StructArray>> Make(
+ const ArrayVector& children, const FieldVector& fields,
+ std::shared_ptr<Buffer> null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ const StructType* struct_type() const;
+
+ // Return a shared pointer in case the requestor desires to share ownership
+ // with this array. The returned array has its offset, length and null
+ // count adjusted.
+ std::shared_ptr<Array> field(int pos) const;
+
+ const ArrayVector& fields() const;
+
+ /// Returns null if name not found
+ std::shared_ptr<Array> GetFieldByName(const std::string& name) const;
+
+ /// \brief Flatten this array as a vector of arrays, one for each field
+ ///
+ /// \param[in] pool The pool to allocate null bitmaps from, if necessary
+ Result<ArrayVector> Flatten(MemoryPool* pool = default_memory_pool()) const;
+
+ private:
+ // For caching boxed child data
+ // XXX This is not handled in a thread-safe manner.
+ mutable ArrayVector boxed_fields_;
+};
+
+// ----------------------------------------------------------------------
+// Union
+
+/// Base class for SparseUnionArray and DenseUnionArray
+class ARROW_EXPORT UnionArray : public Array {
+ public:
+ using type_code_t = int8_t;
+
+ /// Note that this buffer does not account for any slice offset
+ std::shared_ptr<Buffer> type_codes() const { return data_->buffers[1]; }
+
+ const type_code_t* raw_type_codes() const { return raw_type_codes_ + data_->offset; }
+
+ /// The logical type code of the value at index.
+ type_code_t type_code(int64_t i) const { return raw_type_codes_[i + data_->offset]; }
+
+ /// The physical child id containing value at index.
+ int child_id(int64_t i) const {
+ return union_type_->child_ids()[raw_type_codes_[i + data_->offset]];
+ }
+
+ const UnionType* union_type() const { return union_type_; }
+
+ UnionMode::type mode() const { return union_type_->mode(); }
+
+ /// \brief Return the given field as an individual array.
+ ///
+ /// For sparse unions, the returned array has its offset, length and null
+ /// count adjusted.
+ std::shared_ptr<Array> field(int pos) const;
+
+ protected:
+ void SetData(std::shared_ptr<ArrayData> data);
+
+ const type_code_t* raw_type_codes_;
+ const UnionType* union_type_;
+
+ // For caching boxed child data
+ mutable std::vector<std::shared_ptr<Array>> boxed_fields_;
+};
+
+/// Concrete Array class for sparse union data
+class ARROW_EXPORT SparseUnionArray : public UnionArray {
+ public:
+ using TypeClass = SparseUnionType;
+
+ explicit SparseUnionArray(std::shared_ptr<ArrayData> data);
+
+ SparseUnionArray(std::shared_ptr<DataType> type, int64_t length, ArrayVector children,
+ std::shared_ptr<Buffer> type_ids, int64_t offset = 0);
+
+ /// \brief Construct SparseUnionArray from type_ids and children
+ ///
+ /// This function does the bare minimum of validation of the input types.
+ ///
+ /// \param[in] type_ids An array of logical type ids for the union type
+ /// \param[in] children Vector of children Arrays containing the data for each type.
+ /// \param[in] type_codes Vector of type codes.
+ static Result<std::shared_ptr<Array>> Make(const Array& type_ids, ArrayVector children,
+ std::vector<type_code_t> type_codes) {
+ return Make(std::move(type_ids), std::move(children), std::vector<std::string>{},
+ std::move(type_codes));
+ }
+
+ /// \brief Construct SparseUnionArray with custom field names from type_ids and children
+ ///
+ /// This function does the bare minimum of validation of the input types.
+ ///
+ /// \param[in] type_ids An array of logical type ids for the union type
+ /// \param[in] children Vector of children Arrays containing the data for each type.
+ /// \param[in] field_names Vector of strings containing the name of each field.
+ /// \param[in] type_codes Vector of type codes.
+ static Result<std::shared_ptr<Array>> Make(const Array& type_ids, ArrayVector children,
+ std::vector<std::string> field_names = {},
+ std::vector<type_code_t> type_codes = {});
+
+ const SparseUnionType* union_type() const {
+ return internal::checked_cast<const SparseUnionType*>(union_type_);
+ }
+
+ protected:
+ void SetData(std::shared_ptr<ArrayData> data);
+};
+
+/// \brief Concrete Array class for dense union data
+///
+/// Note that union types do not have a validity bitmap
+class ARROW_EXPORT DenseUnionArray : public UnionArray {
+ public:
+ using TypeClass = DenseUnionType;
+
+ explicit DenseUnionArray(const std::shared_ptr<ArrayData>& data);
+
+ DenseUnionArray(std::shared_ptr<DataType> type, int64_t length, ArrayVector children,
+ std::shared_ptr<Buffer> type_ids,
+ std::shared_ptr<Buffer> value_offsets = NULLPTR, int64_t offset = 0);
+
+ /// \brief Construct DenseUnionArray from type_ids, value_offsets, and children
+ ///
+ /// This function does the bare minimum of validation of the offsets and
+ /// input types.
+ ///
+ /// \param[in] type_ids An array of logical type ids for the union type
+ /// \param[in] value_offsets An array of signed int32 values indicating the
+ /// relative offset into the respective child array for the type in a given slot.
+ /// The respective offsets for each child value array must be in order / increasing.
+ /// \param[in] children Vector of children Arrays containing the data for each type.
+ /// \param[in] type_codes Vector of type codes.
+ static Result<std::shared_ptr<Array>> Make(const Array& type_ids,
+ const Array& value_offsets,
+ ArrayVector children,
+ std::vector<type_code_t> type_codes) {
+ return Make(type_ids, value_offsets, std::move(children), std::vector<std::string>{},
+ std::move(type_codes));
+ }
+
+ /// \brief Construct DenseUnionArray with custom field names from type_ids,
+ /// value_offsets, and children
+ ///
+ /// This function does the bare minimum of validation of the offsets and
+ /// input types.
+ ///
+ /// \param[in] type_ids An array of logical type ids for the union type
+ /// \param[in] value_offsets An array of signed int32 values indicating the
+ /// relative offset into the respective child array for the type in a given slot.
+ /// The respective offsets for each child value array must be in order / increasing.
+ /// \param[in] children Vector of children Arrays containing the data for each type.
+ /// \param[in] field_names Vector of strings containing the name of each field.
+ /// \param[in] type_codes Vector of type codes.
+ static Result<std::shared_ptr<Array>> Make(const Array& type_ids,
+ const Array& value_offsets,
+ ArrayVector children,
+ std::vector<std::string> field_names = {},
+ std::vector<type_code_t> type_codes = {});
+
+ const DenseUnionType* union_type() const {
+ return internal::checked_cast<const DenseUnionType*>(union_type_);
+ }
+
+ /// Note that this buffer does not account for any slice offset
+ std::shared_ptr<Buffer> value_offsets() const { return data_->buffers[2]; }
+
+ int32_t value_offset(int64_t i) const { return raw_value_offsets_[i + data_->offset]; }
+
+ const int32_t* raw_value_offsets() const { return raw_value_offsets_ + data_->offset; }
+
+ protected:
+ const int32_t* raw_value_offsets_;
+
+ void SetData(const std::shared_ptr<ArrayData>& data);
+};
+
+/// @}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_primitive.cc b/src/arrow/cpp/src/arrow/array/array_primitive.cc
new file mode 100644
index 000000000..5312c3ece
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_primitive.cc
@@ -0,0 +1,133 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/array_primitive.h"
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/array/array_base.h"
+#include "arrow/type.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+// ----------------------------------------------------------------------
+// Primitive array base
+
+PrimitiveArray::PrimitiveArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap,
+ int64_t null_count, int64_t offset) {
+ SetData(ArrayData::Make(type, length, {null_bitmap, data}, null_count, offset));
+}
+
+// ----------------------------------------------------------------------
+// BooleanArray
+
+BooleanArray::BooleanArray(const std::shared_ptr<ArrayData>& data)
+ : PrimitiveArray(data) {
+ ARROW_CHECK_EQ(data->type->id(), Type::BOOL);
+}
+
+BooleanArray::BooleanArray(int64_t length, const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap, int64_t null_count,
+ int64_t offset)
+ : PrimitiveArray(boolean(), length, data, null_bitmap, null_count, offset) {}
+
+int64_t BooleanArray::false_count() const {
+ return this->length() - this->null_count() - this->true_count();
+}
+
+int64_t BooleanArray::true_count() const {
+ if (data_->null_count.load() != 0) {
+ DCHECK(data_->buffers[0]);
+ internal::BinaryBitBlockCounter bit_counter(data_->buffers[0]->data(), data_->offset,
+ data_->buffers[1]->data(), data_->offset,
+ data_->length);
+ int64_t count = 0;
+ while (true) {
+ internal::BitBlockCount block = bit_counter.NextAndWord();
+ if (block.length == 0) {
+ break;
+ }
+ count += block.popcount;
+ }
+ return count;
+ } else {
+ return internal::CountSetBits(data_->buffers[1]->data(), data_->offset,
+ data_->length);
+ }
+}
+
+// ----------------------------------------------------------------------
+// Day time interval
+
+DayTimeIntervalArray::DayTimeIntervalArray(const std::shared_ptr<ArrayData>& data) {
+ SetData(data);
+}
+
+DayTimeIntervalArray::DayTimeIntervalArray(const std::shared_ptr<DataType>& type,
+ int64_t length,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap,
+ int64_t null_count, int64_t offset)
+ : PrimitiveArray(type, length, data, null_bitmap, null_count, offset) {}
+
+DayTimeIntervalArray::DayTimeIntervalArray(int64_t length,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap,
+ int64_t null_count, int64_t offset)
+ : PrimitiveArray(day_time_interval(), length, data, null_bitmap, null_count, offset) {
+}
+
+DayTimeIntervalType::DayMilliseconds DayTimeIntervalArray::GetValue(int64_t i) const {
+ DCHECK(i < length());
+ return *reinterpret_cast<const DayTimeIntervalType::DayMilliseconds*>(
+ raw_values_ + (i + data_->offset) * byte_width());
+}
+
+// ----------------------------------------------------------------------
+// Month, day and Nanos interval
+
+MonthDayNanoIntervalArray::MonthDayNanoIntervalArray(
+ const std::shared_ptr<ArrayData>& data) {
+ SetData(data);
+}
+
+MonthDayNanoIntervalArray::MonthDayNanoIntervalArray(
+ const std::shared_ptr<DataType>& type, int64_t length,
+ const std::shared_ptr<Buffer>& data, const std::shared_ptr<Buffer>& null_bitmap,
+ int64_t null_count, int64_t offset)
+ : PrimitiveArray(type, length, data, null_bitmap, null_count, offset) {}
+
+MonthDayNanoIntervalArray::MonthDayNanoIntervalArray(
+ int64_t length, const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap, int64_t null_count, int64_t offset)
+ : PrimitiveArray(month_day_nano_interval(), length, data, null_bitmap, null_count,
+ offset) {}
+
+MonthDayNanoIntervalType::MonthDayNanos MonthDayNanoIntervalArray::GetValue(
+ int64_t i) const {
+ DCHECK(i < length());
+ return *reinterpret_cast<const MonthDayNanoIntervalType::MonthDayNanos*>(
+ raw_values_ + (i + data_->offset) * byte_width());
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_primitive.h b/src/arrow/cpp/src/arrow/array/array_primitive.h
new file mode 100644
index 000000000..b5385f965
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_primitive.h
@@ -0,0 +1,178 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Array accessor types for primitive/C-type-based arrays, such as numbers,
+// boolean, and temporal types.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/data.h"
+#include "arrow/stl_iterator.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h" // IWYU pragma: export
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+/// Concrete Array class for boolean data
+class ARROW_EXPORT BooleanArray : public PrimitiveArray {
+ public:
+ using TypeClass = BooleanType;
+ using IteratorType = stl::ArrayIterator<BooleanArray>;
+
+ explicit BooleanArray(const std::shared_ptr<ArrayData>& data);
+
+ BooleanArray(int64_t length, const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ bool Value(int64_t i) const {
+ return BitUtil::GetBit(reinterpret_cast<const uint8_t*>(raw_values_),
+ i + data_->offset);
+ }
+
+ bool GetView(int64_t i) const { return Value(i); }
+
+ /// \brief Return the number of false (0) values among the valid
+ /// values. Result is not cached.
+ int64_t false_count() const;
+
+ /// \brief Return the number of true (1) values among the valid
+ /// values. Result is not cached.
+ int64_t true_count() const;
+
+ IteratorType begin() const { return IteratorType(*this); }
+
+ IteratorType end() const { return IteratorType(*this, length()); }
+
+ protected:
+ using PrimitiveArray::PrimitiveArray;
+};
+
+/// \addtogroup numeric-arrays
+///
+/// @{
+
+/// \brief Concrete Array class for numeric data with a corresponding C type
+///
+/// This class is templated on the corresponding DataType subclass for the
+/// given data, for example NumericArray<Int8Type> or NumericArray<Date32Type>.
+///
+/// Note that convenience aliases are available for all accepted types
+/// (for example Int8Array for NumericArray<Int8Type>).
+template <typename TYPE>
+class NumericArray : public PrimitiveArray {
+ public:
+ using TypeClass = TYPE;
+ using value_type = typename TypeClass::c_type;
+ using IteratorType = stl::ArrayIterator<NumericArray<TYPE>>;
+
+ explicit NumericArray(const std::shared_ptr<ArrayData>& data) : PrimitiveArray(data) {}
+
+ // Only enable this constructor without a type argument for types without additional
+ // metadata
+ template <typename T1 = TYPE>
+ NumericArray(enable_if_parameter_free<T1, int64_t> length,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0)
+ : PrimitiveArray(TypeTraits<T1>::type_singleton(), length, data, null_bitmap,
+ null_count, offset) {}
+
+ const value_type* raw_values() const {
+ return reinterpret_cast<const value_type*>(raw_values_) + data_->offset;
+ }
+
+ value_type Value(int64_t i) const { return raw_values()[i]; }
+
+ // For API compatibility with BinaryArray etc.
+ value_type GetView(int64_t i) const { return Value(i); }
+
+ IteratorType begin() const { return IteratorType(*this); }
+
+ IteratorType end() const { return IteratorType(*this, length()); }
+
+ protected:
+ using PrimitiveArray::PrimitiveArray;
+};
+
+/// DayTimeArray
+/// ---------------------
+/// \brief Array of Day and Millisecond values.
+class ARROW_EXPORT DayTimeIntervalArray : public PrimitiveArray {
+ public:
+ using TypeClass = DayTimeIntervalType;
+
+ explicit DayTimeIntervalArray(const std::shared_ptr<ArrayData>& data);
+
+ DayTimeIntervalArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ DayTimeIntervalArray(int64_t length, const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ TypeClass::DayMilliseconds GetValue(int64_t i) const;
+ TypeClass::DayMilliseconds Value(int64_t i) const { return GetValue(i); }
+
+ // For compatibility with Take kernel.
+ TypeClass::DayMilliseconds GetView(int64_t i) const { return GetValue(i); }
+
+ int32_t byte_width() const { return sizeof(TypeClass::DayMilliseconds); }
+
+ const uint8_t* raw_values() const { return raw_values_ + data_->offset * byte_width(); }
+};
+
+/// \brief Array of Month, Day and nanosecond values.
+class ARROW_EXPORT MonthDayNanoIntervalArray : public PrimitiveArray {
+ public:
+ using TypeClass = MonthDayNanoIntervalType;
+
+ explicit MonthDayNanoIntervalArray(const std::shared_ptr<ArrayData>& data);
+
+ MonthDayNanoIntervalArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ MonthDayNanoIntervalArray(int64_t length, const std::shared_ptr<Buffer>& data,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ TypeClass::MonthDayNanos GetValue(int64_t i) const;
+ TypeClass::MonthDayNanos Value(int64_t i) const { return GetValue(i); }
+
+ // For compatibility with Take kernel.
+ TypeClass::MonthDayNanos GetView(int64_t i) const { return GetValue(i); }
+
+ int32_t byte_width() const { return sizeof(TypeClass::MonthDayNanos); }
+
+ const uint8_t* raw_values() const { return raw_values_ + data_->offset * byte_width(); }
+};
+
+/// @}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_struct_test.cc b/src/arrow/cpp/src/arrow/array/array_struct_test.cc
new file mode 100644
index 000000000..49573af89
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_struct_test.cc
@@ -0,0 +1,699 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_nested.h"
+#include "arrow/chunked_array.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+// ----------------------------------------------------------------------
+// Struct tests
+
+void ValidateBasicStructArray(const StructArray* result,
+ const std::vector<uint8_t>& struct_is_valid,
+ const std::vector<char>& list_values,
+ const std::vector<uint8_t>& list_is_valid,
+ const std::vector<int>& list_lengths,
+ const std::vector<int>& list_offsets,
+ const std::vector<int32_t>& int_values) {
+ ASSERT_EQ(4, result->length());
+ ASSERT_OK(result->ValidateFull());
+
+ auto list_char_arr = std::dynamic_pointer_cast<ListArray>(result->field(0));
+ auto char_arr = std::dynamic_pointer_cast<Int8Array>(list_char_arr->values());
+ auto int32_arr = std::dynamic_pointer_cast<Int32Array>(result->field(1));
+
+ ASSERT_EQ(nullptr, result->GetFieldByName("nonexistent"));
+ ASSERT_TRUE(list_char_arr->Equals(result->GetFieldByName("list")));
+ ASSERT_TRUE(int32_arr->Equals(result->GetFieldByName("int")));
+
+ ASSERT_EQ(0, result->null_count());
+ ASSERT_EQ(1, list_char_arr->null_count());
+ ASSERT_EQ(0, int32_arr->null_count());
+
+ // List<char>
+ ASSERT_EQ(4, list_char_arr->length());
+ ASSERT_EQ(10, list_char_arr->values()->length());
+ for (size_t i = 0; i < list_offsets.size(); ++i) {
+ ASSERT_EQ(list_offsets[i], list_char_arr->raw_value_offsets()[i]);
+ }
+ for (size_t i = 0; i < list_values.size(); ++i) {
+ ASSERT_EQ(list_values[i], char_arr->Value(i));
+ }
+
+ // Int32
+ ASSERT_EQ(4, int32_arr->length());
+ for (size_t i = 0; i < int_values.size(); ++i) {
+ ASSERT_EQ(int_values[i], int32_arr->Value(i));
+ }
+}
+
+TEST(StructArray, FromFieldNames) {
+ std::shared_ptr<Array> a, b, c, array, expected;
+ a = ArrayFromJSON(int32(), "[4, null]");
+ b = ArrayFromJSON(utf8(), R"([null, "foo"])");
+ std::vector<std::string> field_names = {"a", "b"};
+
+ auto res = StructArray::Make({a, b}, field_names);
+ ASSERT_OK(res);
+ array = *res;
+ expected = ArrayFromJSON(struct_({field("a", int32()), field("b", utf8())}),
+ R"([{"a": 4, "b": null}, {"a": null, "b": "foo"}])");
+ AssertArraysEqual(*array, *expected);
+
+ // With non-zero offsets
+ res =
+ StructArray::Make({a, b}, field_names, /*null_bitmap =*/nullptr, /*null_count =*/0,
+ /*offset =*/1);
+ ASSERT_OK(res);
+ array = *res;
+ expected = ArrayFromJSON(struct_({field("a", int32()), field("b", utf8())}),
+ R"([{"a": null, "b": "foo"}])");
+ AssertArraysEqual(*array, *expected);
+
+ res =
+ StructArray::Make({a, b}, field_names, /*null_bitmap =*/nullptr, /*null_count =*/0,
+ /*offset =*/2);
+ ASSERT_OK(res);
+ array = *res;
+ expected = ArrayFromJSON(struct_({field("a", int32()), field("b", utf8())}), R"([])");
+ AssertArraysEqual(*array, *expected);
+
+ // Offset greater than length
+ res =
+ StructArray::Make({a, b}, field_names, /*null_bitmap =*/nullptr, /*null_count =*/0,
+ /*offset =*/3);
+ ASSERT_RAISES(IndexError, res);
+
+ // With null bitmap
+ std::shared_ptr<Buffer> null_bitmap;
+ BitmapFromVector<bool>({false, true}, &null_bitmap);
+ res = StructArray::Make({a, b}, field_names, null_bitmap);
+ ASSERT_OK(res);
+ array = *res;
+ expected = ArrayFromJSON(struct_({field("a", int32()), field("b", utf8())}),
+ R"([null, {"a": null, "b": "foo"}])");
+ AssertArraysEqual(*array, *expected);
+
+ // Mismatching array lengths
+ field_names = {"a", "c"};
+ c = ArrayFromJSON(int64(), "[1, 2, 3]");
+ res = StructArray::Make({a, c}, field_names);
+ ASSERT_RAISES(Invalid, res);
+
+ // Mismatching number of fields
+ field_names = {"a", "b", "c"};
+ res = StructArray::Make({a, b}, field_names);
+ ASSERT_RAISES(Invalid, res);
+
+ // Fail on 0 children (cannot infer array length)
+ field_names = {};
+ res = StructArray::Make({}, field_names);
+ ASSERT_RAISES(Invalid, res);
+}
+
+TEST(StructArray, FromFields) {
+ std::shared_ptr<Array> a, b, c, array, expected;
+ std::shared_ptr<Field> fa, fb, fc;
+ a = ArrayFromJSON(int32(), "[4, 5]");
+ b = ArrayFromJSON(utf8(), R"([null, "foo"])");
+ fa = field("a", int32(), /*nullable =*/false);
+ fb = field("b", utf8(), /*nullable =*/true);
+ fc = field("b", int64(), /*nullable =*/true);
+
+ auto res = StructArray::Make({a, b}, {fa, fb});
+ ASSERT_OK(res);
+ array = *res;
+ expected =
+ ArrayFromJSON(struct_({fa, fb}), R"([{"a": 4, "b": null}, {"a": 5, "b": "foo"}])");
+ AssertArraysEqual(*array, *expected);
+
+ // Mismatching array lengths
+ c = ArrayFromJSON(int64(), "[1, 2, 3]");
+ res = StructArray::Make({a, c}, {fa, fc});
+ ASSERT_RAISES(Invalid, res);
+
+ // Mismatching number of fields
+ res = StructArray::Make({a, b}, {fa, fb, fc});
+ ASSERT_RAISES(Invalid, res);
+
+ // Fail on 0 children (cannot infer array length)
+ res = StructArray::Make({}, std::vector<std::shared_ptr<Field>>{});
+ ASSERT_RAISES(Invalid, res);
+}
+
+TEST(StructArray, Validate) {
+ auto a = ArrayFromJSON(int32(), "[4, 5]");
+ auto type = struct_({field("a", int32())});
+ auto children = std::vector<std::shared_ptr<Array>>{a};
+
+ auto arr = std::make_shared<StructArray>(type, 2, children);
+ ASSERT_OK(arr->ValidateFull());
+ arr = std::make_shared<StructArray>(type, 1, children, nullptr, 0, /*offset=*/1);
+ ASSERT_OK(arr->ValidateFull());
+ arr = std::make_shared<StructArray>(type, 0, children, nullptr, 0, /*offset=*/2);
+ ASSERT_OK(arr->ValidateFull());
+
+ // Length + offset < child length, but it's ok
+ arr = std::make_shared<StructArray>(type, 1, children, nullptr, 0, /*offset=*/0);
+ ASSERT_OK(arr->ValidateFull());
+
+ // Length + offset > child length
+ arr = std::make_shared<StructArray>(type, 1, children, nullptr, 0, /*offset=*/2);
+ ASSERT_RAISES(Invalid, arr->ValidateFull());
+
+ // Offset > child length
+ arr = std::make_shared<StructArray>(type, 0, children, nullptr, 0, /*offset=*/3);
+ ASSERT_RAISES(Invalid, arr->ValidateFull());
+}
+
+TEST(StructArray, Flatten) {
+ auto type =
+ struct_({field("a", int32()), field("b", utf8()), field("c", list(boolean()))});
+ {
+ auto struct_arr = std::static_pointer_cast<StructArray>(ArrayFromJSON(
+ type, R"([[1, "a", [null, false]], [null, "bc", []], [2, null, null]])"));
+ ASSERT_OK_AND_ASSIGN(auto flattened, struct_arr->Flatten(default_memory_pool()));
+ AssertArraysEqual(*ArrayFromJSON(int32(), "[1, null, 2]"), *flattened[0],
+ /*verbose=*/true);
+ AssertArraysEqual(*ArrayFromJSON(utf8(), R"(["a", "bc", null])"), *flattened[1],
+ /*verbose=*/true);
+ AssertArraysEqual(*ArrayFromJSON(list(boolean()), "[[null, false], [], null]"),
+ *flattened[2], /*verbose=*/true);
+ }
+ {
+ ArrayVector children = {
+ ArrayFromJSON(int32(), "[1, 2, 3, 4]")->Slice(1, 3),
+ ArrayFromJSON(utf8(), R"([null, "ab", "cde", null])")->Slice(1, 3),
+ ArrayFromJSON(list(boolean()), "[[true], [], [true, false, null], [false]]")
+ ->Slice(1, 3),
+ };
+
+ // Without slice or top-level nulls
+ auto struct_arr = std::make_shared<StructArray>(type, 3, children);
+ ASSERT_OK_AND_ASSIGN(auto flattened, struct_arr->Flatten(default_memory_pool()));
+ AssertArraysEqual(*ArrayFromJSON(int32(), "[2, 3, 4]"), *flattened[0],
+ /*verbose=*/true);
+ AssertArraysEqual(*ArrayFromJSON(utf8(), R"(["ab", "cde", null])"), *flattened[1],
+ /*verbose=*/true);
+ AssertArraysEqual(
+ *ArrayFromJSON(list(boolean()), "[[], [true, false, null], [false]]"),
+ *flattened[2], /*verbose=*/true);
+
+ // With slice
+ struct_arr = std::make_shared<StructArray>(type, 2, children, /*null_bitmap=*/nullptr,
+ /*null_count=*/0, /*offset=*/1);
+ ASSERT_OK_AND_ASSIGN(flattened, struct_arr->Flatten(default_memory_pool()));
+ AssertArraysEqual(*ArrayFromJSON(int32(), "[3, 4]"), *flattened[0],
+ /*verbose=*/true);
+ AssertArraysEqual(*ArrayFromJSON(utf8(), R"(["cde", null])"), *flattened[1],
+ /*verbose=*/true);
+ AssertArraysEqual(*ArrayFromJSON(list(boolean()), "[[true, false, null], [false]]"),
+ *flattened[2], /*verbose=*/true);
+
+ struct_arr = std::make_shared<StructArray>(type, 1, children, /*null_bitmap=*/nullptr,
+ /*null_count=*/0, /*offset=*/2);
+ ASSERT_OK_AND_ASSIGN(flattened, struct_arr->Flatten(default_memory_pool()));
+ AssertArraysEqual(*ArrayFromJSON(int32(), "[4]"), *flattened[0],
+ /*verbose=*/true);
+ AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null])"), *flattened[1],
+ /*verbose=*/true);
+ AssertArraysEqual(*ArrayFromJSON(list(boolean()), "[[false]]"), *flattened[2],
+ /*verbose=*/true);
+
+ // With top-level nulls
+ std::shared_ptr<Buffer> null_bitmap;
+ BitmapFromVector<bool>({true, false, true}, &null_bitmap);
+ struct_arr = std::make_shared<StructArray>(type, 3, children, null_bitmap);
+ ASSERT_OK_AND_ASSIGN(flattened, struct_arr->Flatten(default_memory_pool()));
+ AssertArraysEqual(*ArrayFromJSON(int32(), "[2, null, 4]"), *flattened[0],
+ /*verbose=*/true);
+ AssertArraysEqual(*ArrayFromJSON(utf8(), R"(["ab", null, null])"), *flattened[1],
+ /*verbose=*/true);
+ AssertArraysEqual(*ArrayFromJSON(list(boolean()), "[[], null, [false]]"),
+ *flattened[2], /*verbose=*/true);
+
+ // With slice and top-level nulls
+ struct_arr = std::make_shared<StructArray>(type, 2, children, null_bitmap,
+ /*null_count=*/1, /*offset=*/1);
+ ASSERT_OK_AND_ASSIGN(flattened, struct_arr->Flatten(default_memory_pool()));
+ AssertArraysEqual(*ArrayFromJSON(int32(), "[null, 4]"), *flattened[0],
+ /*verbose=*/true);
+ AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null, null])"), *flattened[1],
+ /*verbose=*/true);
+ AssertArraysEqual(*ArrayFromJSON(list(boolean()), "[null, [false]]"), *flattened[2],
+ /*verbose=*/true);
+
+ struct_arr = std::make_shared<StructArray>(type, 1, children, null_bitmap,
+ /*null_count=*/0, /*offset=*/2);
+ ASSERT_OK_AND_ASSIGN(flattened, struct_arr->Flatten(default_memory_pool()));
+ AssertArraysEqual(*ArrayFromJSON(int32(), "[4]"), *flattened[0],
+ /*verbose=*/true);
+ AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null])"), *flattened[1],
+ /*verbose=*/true);
+ AssertArraysEqual(*ArrayFromJSON(list(boolean()), "[[false]]"), *flattened[2],
+ /*verbose=*/true);
+ }
+}
+
+/// ARROW-7740: Flattening a slice shouldn't affect the parent array.
+TEST(StructArray, FlattenOfSlice) {
+ auto a = ArrayFromJSON(int32(), "[4, 5]");
+ auto type = struct_({field("a", int32())});
+ auto children = std::vector<std::shared_ptr<Array>>{a};
+
+ auto arr = std::make_shared<StructArray>(type, 2, children);
+ ASSERT_OK(arr->ValidateFull());
+
+ auto slice = internal::checked_pointer_cast<StructArray>(arr->Slice(0, 1));
+ ASSERT_OK(slice->ValidateFull());
+
+ ASSERT_OK_AND_ASSIGN(auto flattened, slice->Flatten(default_memory_pool()));
+
+ ASSERT_OK(slice->ValidateFull());
+ ASSERT_OK(arr->ValidateFull());
+}
+
+// ----------------------------------------------------------------------------------
+// Struct test
+class TestStructBuilder : public TestBuilder {
+ public:
+ void SetUp() {
+ TestBuilder::SetUp();
+
+ auto int32_type = int32();
+ auto char_type = int8();
+ auto list_type = list(char_type);
+
+ std::vector<std::shared_ptr<DataType>> types = {list_type, int32_type};
+ std::vector<std::shared_ptr<Field>> fields;
+ fields.push_back(field("list", list_type));
+ fields.push_back(field("int", int32_type));
+
+ type_ = struct_(fields);
+ value_fields_ = fields;
+
+ std::unique_ptr<ArrayBuilder> tmp;
+ ASSERT_OK(MakeBuilder(pool_, type_, &tmp));
+ builder_.reset(checked_cast<StructBuilder*>(tmp.release()));
+ ASSERT_EQ(2, static_cast<int>(builder_->num_fields()));
+ }
+
+ void Done() {
+ std::shared_ptr<Array> out;
+ FinishAndCheckPadding(builder_.get(), &out);
+ result_ = std::dynamic_pointer_cast<StructArray>(out);
+ }
+
+ protected:
+ std::vector<std::shared_ptr<Field>> value_fields_;
+
+ std::shared_ptr<StructBuilder> builder_;
+ std::shared_ptr<StructArray> result_;
+};
+
+TEST_F(TestStructBuilder, TestAppendNull) {
+ ASSERT_OK(builder_->AppendNull());
+ ASSERT_OK(builder_->AppendNull());
+ ASSERT_EQ(2, static_cast<int>(builder_->num_fields()));
+
+ Done();
+
+ ASSERT_OK(result_->ValidateFull());
+
+ ASSERT_EQ(2, static_cast<int>(result_->num_fields()));
+ ASSERT_EQ(2, result_->length());
+ ASSERT_EQ(2, result_->field(0)->length());
+ ASSERT_EQ(2, result_->field(1)->length());
+ ASSERT_TRUE(result_->IsNull(0));
+ ASSERT_TRUE(result_->IsNull(1));
+ ASSERT_EQ(0, result_->field(0)->null_count());
+ ASSERT_EQ(0, result_->field(1)->null_count());
+
+ ASSERT_EQ(Type::LIST, result_->field(0)->type_id());
+ ASSERT_EQ(Type::INT32, result_->field(1)->type_id());
+}
+
+TEST_F(TestStructBuilder, TestBasics) {
+ std::vector<int32_t> int_values = {1, 2, 3, 4};
+ std::vector<char> list_values = {'j', 'o', 'e', 'b', 'o', 'b', 'm', 'a', 'r', 'k'};
+ std::vector<int> list_lengths = {3, 0, 3, 4};
+ std::vector<int> list_offsets = {0, 3, 3, 6, 10};
+ std::vector<uint8_t> list_is_valid = {1, 0, 1, 1};
+ std::vector<uint8_t> struct_is_valid = {1, 1, 1, 1};
+
+ ListBuilder* list_vb = checked_cast<ListBuilder*>(builder_->field_builder(0));
+ Int8Builder* char_vb = checked_cast<Int8Builder*>(list_vb->value_builder());
+ Int32Builder* int_vb = checked_cast<Int32Builder*>(builder_->field_builder(1));
+ ASSERT_EQ(2, static_cast<int>(builder_->num_fields()));
+
+ ARROW_EXPECT_OK(builder_->Resize(list_lengths.size()));
+ ARROW_EXPECT_OK(char_vb->Resize(list_values.size()));
+ ARROW_EXPECT_OK(int_vb->Resize(int_values.size()));
+
+ int pos = 0;
+ for (size_t i = 0; i < list_lengths.size(); ++i) {
+ ASSERT_OK(list_vb->Append(list_is_valid[i] > 0));
+ int_vb->UnsafeAppend(int_values[i]);
+ for (int j = 0; j < list_lengths[i]; ++j) {
+ char_vb->UnsafeAppend(list_values[pos++]);
+ }
+ }
+
+ for (size_t i = 0; i < struct_is_valid.size(); ++i) {
+ ASSERT_OK(builder_->Append(struct_is_valid[i] > 0));
+ }
+
+ Done();
+
+ ValidateBasicStructArray(result_.get(), struct_is_valid, list_values, list_is_valid,
+ list_lengths, list_offsets, int_values);
+}
+
+TEST_F(TestStructBuilder, BulkAppend) {
+ std::vector<int32_t> int_values = {1, 2, 3, 4};
+ std::vector<char> list_values = {'j', 'o', 'e', 'b', 'o', 'b', 'm', 'a', 'r', 'k'};
+ std::vector<int> list_lengths = {3, 0, 3, 4};
+ std::vector<int> list_offsets = {0, 3, 3, 6};
+ std::vector<uint8_t> list_is_valid = {1, 0, 1, 1};
+ std::vector<uint8_t> struct_is_valid = {1, 1, 1, 1};
+
+ ListBuilder* list_vb = checked_cast<ListBuilder*>(builder_->field_builder(0));
+ Int8Builder* char_vb = checked_cast<Int8Builder*>(list_vb->value_builder());
+ Int32Builder* int_vb = checked_cast<Int32Builder*>(builder_->field_builder(1));
+
+ ASSERT_OK(builder_->Resize(list_lengths.size()));
+ ASSERT_OK(char_vb->Resize(list_values.size()));
+ ASSERT_OK(int_vb->Resize(int_values.size()));
+
+ ASSERT_OK(builder_->AppendValues(struct_is_valid.size(), struct_is_valid.data()));
+
+ ASSERT_OK(list_vb->AppendValues(list_offsets.data(), list_offsets.size(),
+ list_is_valid.data()));
+ for (int8_t value : list_values) {
+ char_vb->UnsafeAppend(value);
+ }
+ for (int32_t value : int_values) {
+ int_vb->UnsafeAppend(value);
+ }
+
+ Done();
+ ValidateBasicStructArray(result_.get(), struct_is_valid, list_values, list_is_valid,
+ list_lengths, list_offsets, int_values);
+}
+
+TEST_F(TestStructBuilder, BulkAppendInvalid) {
+ std::vector<int32_t> int_values = {1, 2, 3, 4};
+ std::vector<char> list_values = {'j', 'o', 'e', 'b', 'o', 'b', 'm', 'a', 'r', 'k'};
+ std::vector<int> list_lengths = {3, 0, 3, 4};
+ std::vector<int> list_offsets = {0, 3, 3, 6};
+ std::vector<uint8_t> list_is_valid = {1, 0, 1, 1};
+ std::vector<uint8_t> struct_is_valid = {1, 0, 1, 1}; // should be 1, 1, 1, 1
+
+ ListBuilder* list_vb = checked_cast<ListBuilder*>(builder_->field_builder(0));
+ Int8Builder* char_vb = checked_cast<Int8Builder*>(list_vb->value_builder());
+ Int32Builder* int_vb = checked_cast<Int32Builder*>(builder_->field_builder(1));
+
+ ASSERT_OK(builder_->Reserve(list_lengths.size()));
+ ASSERT_OK(char_vb->Reserve(list_values.size()));
+ ASSERT_OK(int_vb->Reserve(int_values.size()));
+
+ ASSERT_OK(builder_->AppendValues(struct_is_valid.size(), struct_is_valid.data()));
+
+ ASSERT_OK(list_vb->AppendValues(list_offsets.data(), list_offsets.size(),
+ list_is_valid.data()));
+ for (int8_t value : list_values) {
+ char_vb->UnsafeAppend(value);
+ }
+ for (int32_t value : int_values) {
+ int_vb->UnsafeAppend(value);
+ }
+
+ Done();
+ // Even null bitmap of the parent Struct is not valid, validate will ignore it.
+ ASSERT_OK(result_->ValidateFull());
+}
+
+TEST_F(TestStructBuilder, TestEquality) {
+ std::shared_ptr<Array> array, equal_array;
+ std::shared_ptr<Array> unequal_bitmap_array, unequal_offsets_array,
+ unequal_values_array;
+
+ std::vector<int32_t> int_values = {101, 102, 103, 104};
+ std::vector<char> list_values = {'j', 'o', 'e', 'b', 'o', 'b', 'm', 'a', 'r', 'k'};
+ std::vector<int> list_lengths = {3, 0, 3, 4};
+ std::vector<int> list_offsets = {0, 3, 3, 6};
+ std::vector<uint8_t> list_is_valid = {1, 0, 1, 1};
+ std::vector<uint8_t> struct_is_valid = {1, 1, 1, 1};
+
+ std::vector<int32_t> unequal_int_values = {104, 102, 103, 101};
+ std::vector<char> unequal_list_values = {'j', 'o', 'e', 'b', 'o',
+ 'b', 'l', 'u', 'c', 'y'};
+ std::vector<int> unequal_list_offsets = {0, 3, 4, 6};
+ std::vector<uint8_t> unequal_list_is_valid = {1, 1, 1, 1};
+ std::vector<uint8_t> unequal_struct_is_valid = {1, 0, 0, 1};
+
+ ListBuilder* list_vb = checked_cast<ListBuilder*>(builder_->field_builder(0));
+ Int8Builder* char_vb = checked_cast<Int8Builder*>(list_vb->value_builder());
+ Int32Builder* int_vb = checked_cast<Int32Builder*>(builder_->field_builder(1));
+ ASSERT_OK(builder_->Reserve(list_lengths.size()));
+ ASSERT_OK(char_vb->Reserve(list_values.size()));
+ ASSERT_OK(int_vb->Reserve(int_values.size()));
+
+ // setup two equal arrays, one of which takes an unequal bitmap
+ ASSERT_OK(builder_->AppendValues(struct_is_valid.size(), struct_is_valid.data()));
+ ASSERT_OK(list_vb->AppendValues(list_offsets.data(), list_offsets.size(),
+ list_is_valid.data()));
+ for (int8_t value : list_values) {
+ char_vb->UnsafeAppend(value);
+ }
+ for (int32_t value : int_values) {
+ int_vb->UnsafeAppend(value);
+ }
+
+ FinishAndCheckPadding(builder_.get(), &array);
+
+ ASSERT_OK(builder_->Resize(list_lengths.size()));
+ ASSERT_OK(char_vb->Resize(list_values.size()));
+ ASSERT_OK(int_vb->Resize(int_values.size()));
+
+ ASSERT_OK(builder_->AppendValues(struct_is_valid.size(), struct_is_valid.data()));
+ ASSERT_OK(list_vb->AppendValues(list_offsets.data(), list_offsets.size(),
+ list_is_valid.data()));
+ for (int8_t value : list_values) {
+ char_vb->UnsafeAppend(value);
+ }
+ for (int32_t value : int_values) {
+ int_vb->UnsafeAppend(value);
+ }
+
+ ASSERT_OK(builder_->Finish(&equal_array));
+
+ ASSERT_OK(builder_->Resize(list_lengths.size()));
+ ASSERT_OK(char_vb->Resize(list_values.size()));
+ ASSERT_OK(int_vb->Resize(int_values.size()));
+
+ // setup an unequal one with the unequal bitmap
+ ASSERT_OK(builder_->AppendValues(unequal_struct_is_valid.size(),
+ unequal_struct_is_valid.data()));
+ ASSERT_OK(list_vb->AppendValues(list_offsets.data(), list_offsets.size(),
+ list_is_valid.data()));
+ for (int8_t value : list_values) {
+ char_vb->UnsafeAppend(value);
+ }
+ for (int32_t value : int_values) {
+ int_vb->UnsafeAppend(value);
+ }
+
+ ASSERT_OK(builder_->Finish(&unequal_bitmap_array));
+
+ ASSERT_OK(builder_->Resize(list_lengths.size()));
+ ASSERT_OK(char_vb->Resize(list_values.size()));
+ ASSERT_OK(int_vb->Resize(int_values.size()));
+
+ // setup an unequal one with unequal offsets
+ ASSERT_OK(builder_->AppendValues(struct_is_valid.size(), struct_is_valid.data()));
+ ASSERT_OK(list_vb->AppendValues(unequal_list_offsets.data(),
+ unequal_list_offsets.size(),
+ unequal_list_is_valid.data()));
+ for (int8_t value : list_values) {
+ char_vb->UnsafeAppend(value);
+ }
+ for (int32_t value : int_values) {
+ int_vb->UnsafeAppend(value);
+ }
+
+ ASSERT_OK(builder_->Finish(&unequal_offsets_array));
+
+ ASSERT_OK(builder_->Resize(list_lengths.size()));
+ ASSERT_OK(char_vb->Resize(list_values.size()));
+ ASSERT_OK(int_vb->Resize(int_values.size()));
+
+ // setup anunequal one with unequal values
+ ASSERT_OK(builder_->AppendValues(struct_is_valid.size(), struct_is_valid.data()));
+ ASSERT_OK(list_vb->AppendValues(list_offsets.data(), list_offsets.size(),
+ list_is_valid.data()));
+ for (int8_t value : unequal_list_values) {
+ char_vb->UnsafeAppend(value);
+ }
+ for (int32_t value : unequal_int_values) {
+ int_vb->UnsafeAppend(value);
+ }
+
+ ASSERT_OK(builder_->Finish(&unequal_values_array));
+
+ // Test array equality
+ EXPECT_TRUE(array->Equals(array));
+ EXPECT_TRUE(array->Equals(equal_array));
+ EXPECT_TRUE(equal_array->Equals(array));
+ EXPECT_FALSE(equal_array->Equals(unequal_bitmap_array));
+ EXPECT_FALSE(unequal_bitmap_array->Equals(equal_array));
+ EXPECT_FALSE(unequal_bitmap_array->Equals(unequal_values_array));
+ EXPECT_FALSE(unequal_values_array->Equals(unequal_bitmap_array));
+ EXPECT_FALSE(unequal_bitmap_array->Equals(unequal_offsets_array));
+ EXPECT_FALSE(unequal_offsets_array->Equals(unequal_bitmap_array));
+
+ // Test range equality
+ EXPECT_TRUE(array->RangeEquals(0, 4, 0, equal_array));
+ EXPECT_TRUE(array->RangeEquals(3, 4, 3, unequal_bitmap_array));
+ EXPECT_TRUE(array->RangeEquals(0, 1, 0, unequal_offsets_array));
+ EXPECT_FALSE(array->RangeEquals(0, 2, 0, unequal_offsets_array));
+ EXPECT_FALSE(array->RangeEquals(1, 2, 1, unequal_offsets_array));
+ EXPECT_FALSE(array->RangeEquals(0, 1, 0, unequal_values_array));
+ EXPECT_TRUE(array->RangeEquals(1, 3, 1, unequal_values_array));
+ EXPECT_FALSE(array->RangeEquals(3, 4, 3, unequal_values_array));
+}
+
+TEST_F(TestStructBuilder, TestZeroLength) {
+ // All buffers are null
+ Done();
+ ASSERT_OK(result_->ValidateFull());
+}
+
+TEST_F(TestStructBuilder, TestSlice) {
+ std::shared_ptr<Array> array, equal_array;
+ std::shared_ptr<Array> unequal_bitmap_array, unequal_offsets_array,
+ unequal_values_array;
+
+ std::vector<int32_t> int_values = {101, 102, 103, 104};
+ std::vector<char> list_values = {'j', 'o', 'e', 'b', 'o', 'b', 'm', 'a', 'r', 'k'};
+ std::vector<int> list_lengths = {3, 0, 3, 4};
+ std::vector<int> list_offsets = {0, 3, 3, 6};
+ std::vector<uint8_t> list_is_valid = {1, 0, 1, 1};
+ std::vector<uint8_t> struct_is_valid = {1, 1, 1, 1};
+
+ ListBuilder* list_vb = checked_cast<ListBuilder*>(builder_->field_builder(0));
+ Int8Builder* char_vb = checked_cast<Int8Builder*>(list_vb->value_builder());
+ Int32Builder* int_vb = checked_cast<Int32Builder*>(builder_->field_builder(1));
+ ASSERT_OK(builder_->Reserve(list_lengths.size()));
+ ASSERT_OK(char_vb->Reserve(list_values.size()));
+ ASSERT_OK(int_vb->Reserve(int_values.size()));
+
+ ASSERT_OK(builder_->AppendValues(struct_is_valid.size(), struct_is_valid.data()));
+ ASSERT_OK(list_vb->AppendValues(list_offsets.data(), list_offsets.size(),
+ list_is_valid.data()));
+ for (int8_t value : list_values) {
+ char_vb->UnsafeAppend(value);
+ }
+ for (int32_t value : int_values) {
+ int_vb->UnsafeAppend(value);
+ }
+ FinishAndCheckPadding(builder_.get(), &array);
+
+ std::shared_ptr<StructArray> slice, slice2;
+ std::shared_ptr<Int32Array> int_field;
+ std::shared_ptr<ListArray> list_field;
+
+ slice = std::dynamic_pointer_cast<StructArray>(array->Slice(2));
+ slice2 = std::dynamic_pointer_cast<StructArray>(array->Slice(2));
+ ASSERT_EQ(array->length() - 2, slice->length());
+
+ ASSERT_TRUE(slice->Equals(slice2));
+ ASSERT_TRUE(array->RangeEquals(2, slice->length(), 0, slice));
+
+ int_field = std::dynamic_pointer_cast<Int32Array>(slice->field(1));
+ ASSERT_EQ(int_field->length(), slice->length());
+ ASSERT_EQ(int_field->Value(0), 103);
+ ASSERT_EQ(int_field->Value(1), 104);
+ ASSERT_EQ(int_field->null_count(), 0);
+ list_field = std::dynamic_pointer_cast<ListArray>(slice->field(0));
+ ASSERT_FALSE(list_field->IsNull(0));
+ ASSERT_FALSE(list_field->IsNull(1));
+ ASSERT_EQ(list_field->value_length(0), 3);
+ ASSERT_EQ(list_field->value_length(1), 4);
+ ASSERT_EQ(list_field->null_count(), 0);
+
+ slice = std::dynamic_pointer_cast<StructArray>(array->Slice(1, 2));
+ slice2 = std::dynamic_pointer_cast<StructArray>(array->Slice(1, 2));
+ ASSERT_EQ(2, slice->length());
+
+ ASSERT_TRUE(slice->Equals(slice2));
+ ASSERT_TRUE(array->RangeEquals(1, 3, 0, slice));
+
+ int_field = std::dynamic_pointer_cast<Int32Array>(slice->field(1));
+ ASSERT_EQ(int_field->length(), slice->length());
+ ASSERT_EQ(int_field->Value(0), 102);
+ ASSERT_EQ(int_field->Value(1), 103);
+ ASSERT_EQ(int_field->null_count(), 0);
+ list_field = std::dynamic_pointer_cast<ListArray>(slice->field(0));
+ ASSERT_TRUE(list_field->IsNull(0));
+ ASSERT_FALSE(list_field->IsNull(1));
+ ASSERT_EQ(list_field->value_length(0), 0);
+ ASSERT_EQ(list_field->value_length(1), 3);
+ ASSERT_EQ(list_field->null_count(), 1);
+}
+
+TEST(TestFieldRef, GetChildren) {
+ auto struct_array = ArrayFromJSON(struct_({field("a", float64())}), R"([
+ {"a": 6.125},
+ {"a": 0.0},
+ {"a": -1}
+ ])");
+
+ ASSERT_OK_AND_ASSIGN(auto a, FieldRef("a").GetOne(*struct_array));
+ auto expected_a = ArrayFromJSON(float64(), "[6.125, 0.0, -1]");
+ AssertArraysEqual(*a, *expected_a);
+
+ // more nested:
+ struct_array =
+ ArrayFromJSON(struct_({field("a", struct_({field("a", float64())}))}), R"([
+ {"a": {"a": 6.125}},
+ {"a": {"a": 0.0}},
+ {"a": {"a": -1}}
+ ])");
+
+ ASSERT_OK_AND_ASSIGN(a, FieldRef("a", "a").GetOne(*struct_array));
+ expected_a = ArrayFromJSON(float64(), "[6.125, 0.0, -1]");
+ AssertArraysEqual(*a, *expected_a);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_test.cc b/src/arrow/cpp/src/arrow/array/array_test.cc
new file mode 100644
index 000000000..62ee032db
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_test.cc
@@ -0,0 +1,3291 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_binary.h"
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/array_primitive.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_decimal.h"
+#include "arrow/array/builder_dict.h"
+#include "arrow/array/builder_time.h"
+#include "arrow/array/data.h"
+#include "arrow/array/util.h"
+#include "arrow/buffer.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/compare.h"
+#include "arrow/result.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/testing/extension_type.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_compat.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_builders.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/range.h"
+#include "arrow/visitor_inline.h"
+
+// This file is compiled together with array-*-test.cc into a single
+// executable array-test.
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+class TestArray : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = default_memory_pool(); }
+
+ protected:
+ MemoryPool* pool_;
+};
+
+TEST_F(TestArray, TestNullCount) {
+ // These are placeholders
+ auto data = std::make_shared<Buffer>(nullptr, 0);
+ auto null_bitmap = std::make_shared<Buffer>(nullptr, 0);
+
+ std::unique_ptr<Int32Array> arr(new Int32Array(100, data, null_bitmap, 10));
+ ASSERT_EQ(10, arr->null_count());
+
+ std::unique_ptr<Int32Array> arr_no_nulls(new Int32Array(100, data));
+ ASSERT_EQ(0, arr_no_nulls->null_count());
+
+ std::unique_ptr<Int32Array> arr_default_null_count(
+ new Int32Array(100, data, null_bitmap));
+ ASSERT_EQ(kUnknownNullCount, arr_default_null_count->data()->null_count);
+}
+
+TEST_F(TestArray, TestSlicePreservesAllNullCount) {
+ // These are placeholders
+ auto data = std::make_shared<Buffer>(nullptr, 0);
+ auto null_bitmap = std::make_shared<Buffer>(nullptr, 0);
+
+ Int32Array arr(/*length=*/100, data, null_bitmap,
+ /*null_count*/ 100);
+ EXPECT_EQ(arr.Slice(1, 99)->data()->null_count, arr.Slice(1, 99)->length());
+}
+
+TEST_F(TestArray, TestLength) {
+ // Placeholder buffer
+ auto data = std::make_shared<Buffer>(nullptr, 0);
+
+ std::unique_ptr<Int32Array> arr(new Int32Array(100, data));
+ ASSERT_EQ(arr->length(), 100);
+}
+
+TEST_F(TestArray, TestNullToString) {
+ // Invalid NULL buffer
+ auto data = std::make_shared<Buffer>(nullptr, 400);
+
+ std::unique_ptr<Int32Array> arr(new Int32Array(100, data));
+ ASSERT_EQ(arr->ToString(), "<Invalid array: Missing values buffer in non-empty array>");
+}
+
+TEST_F(TestArray, TestSliceSafe) {
+ std::vector<int32_t> original_data{1, 2, 3, 4, 5, 6, 7};
+ auto arr = std::make_shared<Int32Array>(7, Buffer::Wrap(original_data));
+
+ auto check_data = [](const Array& arr, const std::vector<int32_t>& expected) {
+ ASSERT_EQ(arr.length(), static_cast<int64_t>(expected.size()));
+ const int32_t* data = arr.data()->GetValues<int32_t>(1);
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ ASSERT_EQ(data[i], expected[i]);
+ }
+ };
+
+ check_data(*arr, {1, 2, 3, 4, 5, 6, 7});
+
+ ASSERT_OK_AND_ASSIGN(auto sliced, arr->SliceSafe(0, 0));
+ check_data(*sliced, {});
+
+ ASSERT_OK_AND_ASSIGN(sliced, arr->SliceSafe(0, 7));
+ check_data(*sliced, original_data);
+
+ ASSERT_OK_AND_ASSIGN(sliced, arr->SliceSafe(3, 4));
+ check_data(*sliced, {4, 5, 6, 7});
+
+ ASSERT_OK_AND_ASSIGN(sliced, arr->SliceSafe(0, 7));
+ check_data(*sliced, {1, 2, 3, 4, 5, 6, 7});
+
+ ASSERT_OK_AND_ASSIGN(sliced, arr->SliceSafe(7, 0));
+ check_data(*sliced, {});
+
+ ASSERT_RAISES(Invalid, arr->SliceSafe(8, 0));
+ ASSERT_RAISES(Invalid, arr->SliceSafe(0, 8));
+ ASSERT_RAISES(Invalid, arr->SliceSafe(-1, 0));
+ ASSERT_RAISES(Invalid, arr->SliceSafe(0, -1));
+ ASSERT_RAISES(Invalid, arr->SliceSafe(6, 2));
+ ASSERT_RAISES(Invalid, arr->SliceSafe(6, std::numeric_limits<int64_t>::max() - 5));
+
+ ASSERT_OK_AND_ASSIGN(sliced, arr->SliceSafe(0));
+ check_data(*sliced, original_data);
+
+ ASSERT_OK_AND_ASSIGN(sliced, arr->SliceSafe(3));
+ check_data(*sliced, {4, 5, 6, 7});
+
+ ASSERT_OK_AND_ASSIGN(sliced, arr->SliceSafe(7));
+ check_data(*sliced, {});
+
+ ASSERT_RAISES(Invalid, arr->SliceSafe(8));
+ ASSERT_RAISES(Invalid, arr->SliceSafe(-1));
+}
+
+Status MakeArrayFromValidBytes(const std::vector<uint8_t>& v, MemoryPool* pool,
+ std::shared_ptr<Array>* out) {
+ int64_t null_count = v.size() - std::accumulate(v.begin(), v.end(), 0);
+
+ ARROW_ASSIGN_OR_RAISE(auto null_buf, internal::BytesToBits(v));
+
+ TypedBufferBuilder<int32_t> value_builder(pool);
+ for (size_t i = 0; i < v.size(); ++i) {
+ RETURN_NOT_OK(value_builder.Append(0));
+ }
+
+ std::shared_ptr<Buffer> values;
+ RETURN_NOT_OK(value_builder.Finish(&values));
+ *out = std::make_shared<Int32Array>(v.size(), values, null_buf, null_count);
+ return Status::OK();
+}
+
+TEST_F(TestArray, TestEquality) {
+ std::shared_ptr<Array> array, equal_array, unequal_array;
+
+ ASSERT_OK(MakeArrayFromValidBytes({1, 0, 1, 1, 0, 1, 0, 0}, pool_, &array));
+ ASSERT_OK(MakeArrayFromValidBytes({1, 0, 1, 1, 0, 1, 0, 0}, pool_, &equal_array));
+ ASSERT_OK(MakeArrayFromValidBytes({1, 1, 1, 1, 0, 1, 0, 0}, pool_, &unequal_array));
+
+ EXPECT_TRUE(array->Equals(array));
+ EXPECT_TRUE(array->Equals(equal_array));
+ EXPECT_TRUE(equal_array->Equals(array));
+ EXPECT_FALSE(equal_array->Equals(unequal_array));
+ EXPECT_FALSE(unequal_array->Equals(equal_array));
+ EXPECT_TRUE(array->RangeEquals(4, 8, 4, unequal_array));
+ EXPECT_FALSE(array->RangeEquals(0, 4, 0, unequal_array));
+ EXPECT_FALSE(array->RangeEquals(0, 8, 0, unequal_array));
+ EXPECT_FALSE(array->RangeEquals(1, 2, 1, unequal_array));
+
+ auto timestamp_ns_array = std::make_shared<NumericArray<TimestampType>>(
+ timestamp(TimeUnit::NANO), array->length(), array->data()->buffers[1],
+ array->data()->buffers[0], array->null_count());
+ auto timestamp_us_array = std::make_shared<NumericArray<TimestampType>>(
+ timestamp(TimeUnit::MICRO), array->length(), array->data()->buffers[1],
+ array->data()->buffers[0], array->null_count());
+ ASSERT_FALSE(array->Equals(timestamp_ns_array));
+ // ARROW-2567: Ensure that not only the type id but also the type equality
+ // itself is checked.
+ ASSERT_FALSE(timestamp_us_array->Equals(timestamp_ns_array));
+ ASSERT_TRUE(timestamp_us_array->RangeEquals(0, 1, 0, timestamp_us_array));
+ ASSERT_FALSE(timestamp_us_array->RangeEquals(0, 1, 0, timestamp_ns_array));
+}
+
+TEST_F(TestArray, TestNullArrayEquality) {
+ auto array_1 = std::make_shared<NullArray>(10);
+ auto array_2 = std::make_shared<NullArray>(10);
+ auto array_3 = std::make_shared<NullArray>(20);
+
+ EXPECT_TRUE(array_1->Equals(array_1));
+ EXPECT_TRUE(array_1->Equals(array_2));
+ EXPECT_FALSE(array_1->Equals(array_3));
+}
+
+TEST_F(TestArray, SliceRecomputeNullCount) {
+ std::vector<uint8_t> valid_bytes = {1, 0, 1, 1, 0, 1, 0, 0, 0};
+
+ std::shared_ptr<Array> array;
+ ASSERT_OK(MakeArrayFromValidBytes(valid_bytes, pool_, &array));
+
+ ASSERT_EQ(5, array->null_count());
+
+ auto slice = array->Slice(1, 4);
+ ASSERT_EQ(2, slice->null_count());
+
+ slice = array->Slice(4);
+ ASSERT_EQ(4, slice->null_count());
+
+ auto slice2 = slice->Slice(0);
+ ASSERT_EQ(4, slice2->null_count());
+
+ slice = array->Slice(0);
+ ASSERT_EQ(5, slice->null_count());
+
+ // No bitmap, compute 0
+ const int kBufferSize = 64;
+ ASSERT_OK_AND_ASSIGN(auto data, AllocateBuffer(kBufferSize, pool_));
+ memset(data->mutable_data(), 0, kBufferSize);
+
+ auto arr = std::make_shared<Int32Array>(16, std::move(data), nullptr, -1);
+ ASSERT_EQ(0, arr->null_count());
+}
+
+TEST_F(TestArray, NullArraySliceNullCount) {
+ auto null_arr = std::make_shared<NullArray>(10);
+ auto null_arr_sliced = null_arr->Slice(3, 6);
+
+ // The internal null count is 6, does not require recomputation
+ ASSERT_EQ(6, null_arr_sliced->data()->null_count);
+
+ ASSERT_EQ(6, null_arr_sliced->null_count());
+}
+
+TEST_F(TestArray, TestIsNullIsValid) {
+ // clang-format off
+ std::vector<uint8_t> null_bitmap = {1, 0, 1, 1, 0, 1, 0, 0,
+ 1, 0, 1, 1, 0, 1, 0, 0,
+ 1, 0, 1, 1, 0, 1, 0, 0,
+ 1, 0, 1, 1, 0, 1, 0, 0,
+ 1, 0, 0, 1};
+ // clang-format on
+ int64_t null_count = 0;
+ for (uint8_t x : null_bitmap) {
+ if (x == 0) {
+ ++null_count;
+ }
+ }
+
+ ASSERT_OK_AND_ASSIGN(auto null_buf, internal::BytesToBits(null_bitmap));
+
+ std::unique_ptr<Array> arr;
+ arr.reset(new Int32Array(null_bitmap.size(), nullptr, null_buf, null_count));
+
+ ASSERT_EQ(null_count, arr->null_count());
+ ASSERT_EQ(5, null_buf->size());
+
+ ASSERT_TRUE(arr->null_bitmap()->Equals(*null_buf.get()));
+
+ for (size_t i = 0; i < null_bitmap.size(); ++i) {
+ EXPECT_EQ(null_bitmap[i] != 0, !arr->IsNull(i)) << i;
+ EXPECT_EQ(null_bitmap[i] != 0, arr->IsValid(i)) << i;
+ }
+}
+
+TEST_F(TestArray, TestIsNullIsValidNoNulls) {
+ const int64_t size = 10;
+
+ std::unique_ptr<Array> arr;
+ arr.reset(new Int32Array(size, nullptr, nullptr, 0));
+
+ for (size_t i = 0; i < size; ++i) {
+ EXPECT_TRUE(arr->IsValid(i));
+ EXPECT_FALSE(arr->IsNull(i));
+ }
+}
+
+TEST_F(TestArray, BuildLargeInMemoryArray) {
+#ifdef NDEBUG
+ const int64_t length = static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
+#elif !defined(ARROW_VALGRIND)
+ // use a smaller size since the insert function isn't optimized properly on debug and
+ // the test takes a long time to complete
+ const int64_t length = 2 << 24;
+#else
+ // use an even smaller size with valgrind
+ const int64_t length = 2 << 20;
+#endif
+
+ BooleanBuilder builder;
+ std::vector<bool> zeros(length);
+ ASSERT_OK(builder.AppendValues(zeros));
+
+ std::shared_ptr<Array> result;
+ FinishAndCheckPadding(&builder, &result);
+
+ ASSERT_EQ(length, result->length());
+}
+
+TEST_F(TestArray, TestMakeArrayOfNull) {
+ std::shared_ptr<DataType> types[] = {
+ // clang-format off
+ null(),
+ boolean(),
+ int8(),
+ uint16(),
+ int32(),
+ uint64(),
+ float64(),
+ binary(),
+ large_binary(),
+ fixed_size_binary(3),
+ decimal(16, 4),
+ utf8(),
+ large_utf8(),
+ list(utf8()),
+ list(int64()), // ARROW-9071
+ large_list(large_utf8()),
+ fixed_size_list(utf8(), 3),
+ fixed_size_list(int64(), 4),
+ dictionary(int32(), utf8()),
+ struct_({field("a", utf8()), field("b", int32())}),
+ smallint(), // extension type
+ // clang-format on
+ };
+
+ for (int64_t length : {0, 1, 16, 133}) {
+ for (auto type : types) {
+ ASSERT_OK_AND_ASSIGN(auto array, MakeArrayOfNull(type, length));
+ ASSERT_OK(array->ValidateFull());
+ ASSERT_EQ(array->length(), length);
+ ASSERT_EQ(array->null_count(), length);
+ for (int64_t i = 0; i < length; ++i) {
+ ASSERT_TRUE(array->IsNull(i));
+ ASSERT_FALSE(array->IsValid(i));
+ }
+ }
+ }
+}
+
+TEST_F(TestArray, TestMakeArrayOfNullUnion) {
+ // Unions need special checking -- the top level null count is 0 (per
+ // ARROW-9222) so we check the first child to make sure is contains all nulls
+ // and check that the type_ids all point to the first child
+ const int64_t union_length = 10;
+ auto s_union_ty = sparse_union({field("a", utf8()), field("b", int32())}, {0, 1});
+ ASSERT_OK_AND_ASSIGN(auto s_union_nulls, MakeArrayOfNull(s_union_ty, union_length));
+ ASSERT_OK(s_union_nulls->ValidateFull());
+ ASSERT_EQ(s_union_nulls->null_count(), 0);
+ {
+ const auto& typed_union = checked_cast<const SparseUnionArray&>(*s_union_nulls);
+ ASSERT_EQ(typed_union.field(0)->null_count(), union_length);
+
+ // Check type codes are all 0
+ for (int i = 0; i < union_length; ++i) {
+ ASSERT_EQ(typed_union.raw_type_codes()[i], 0);
+ }
+ }
+
+ s_union_ty = sparse_union({field("a", utf8()), field("b", int32())}, {2, 7});
+ ASSERT_OK_AND_ASSIGN(s_union_nulls, MakeArrayOfNull(s_union_ty, union_length));
+ ASSERT_OK(s_union_nulls->ValidateFull());
+ ASSERT_EQ(s_union_nulls->null_count(), 0);
+ {
+ const auto& typed_union = checked_cast<const SparseUnionArray&>(*s_union_nulls);
+ ASSERT_EQ(typed_union.field(0)->null_count(), union_length);
+
+ // Check type codes are all 2
+ for (int i = 0; i < union_length; ++i) {
+ ASSERT_EQ(typed_union.raw_type_codes()[i], 2);
+ }
+ }
+
+ auto d_union_ty = dense_union({field("a", utf8()), field("b", int32())}, {0, 1});
+ ASSERT_OK_AND_ASSIGN(auto d_union_nulls, MakeArrayOfNull(d_union_ty, union_length));
+ ASSERT_OK(d_union_nulls->ValidateFull());
+ ASSERT_EQ(d_union_nulls->null_count(), 0);
+ {
+ const auto& typed_union = checked_cast<const DenseUnionArray&>(*d_union_nulls);
+
+ // Child field has length 1 which is a null element
+ ASSERT_EQ(typed_union.field(0)->length(), 1);
+ ASSERT_EQ(typed_union.field(0)->null_count(), 1);
+
+ // Check type codes are all 0 and the offsets point to the first element of
+ // the first child
+ for (int i = 0; i < union_length; ++i) {
+ ASSERT_EQ(typed_union.raw_type_codes()[i], 0);
+ ASSERT_EQ(typed_union.raw_value_offsets()[i], 0);
+ }
+ }
+}
+
+TEST_F(TestArray, TestValidateNullCount) {
+ Int32Builder builder(pool_);
+ ASSERT_OK(builder.Append(5));
+ ASSERT_OK(builder.Append(42));
+ ASSERT_OK(builder.AppendNull());
+ ASSERT_OK_AND_ASSIGN(auto array, builder.Finish());
+
+ ArrayData* data = array->data().get();
+ data->null_count = kUnknownNullCount;
+ ASSERT_OK(array->ValidateFull());
+ data->null_count = 1;
+ ASSERT_OK(array->ValidateFull());
+
+ // null_count out of bounds
+ data->null_count = -2;
+ ASSERT_RAISES(Invalid, array->Validate());
+ ASSERT_RAISES(Invalid, array->ValidateFull());
+ data->null_count = 4;
+ ASSERT_RAISES(Invalid, array->Validate());
+ ASSERT_RAISES(Invalid, array->ValidateFull());
+
+ // null_count inconsistent with data
+ for (const int64_t null_count : {0, 2, 3}) {
+ data->null_count = null_count;
+ ASSERT_OK(array->Validate());
+ ASSERT_RAISES(Invalid, array->ValidateFull());
+ }
+}
+
+void AssertAppendScalar(MemoryPool* pool, const std::shared_ptr<Scalar>& scalar) {
+ std::unique_ptr<arrow::ArrayBuilder> builder;
+ auto null_scalar = MakeNullScalar(scalar->type);
+ ASSERT_OK(MakeBuilderExactIndex(pool, scalar->type, &builder));
+ ASSERT_OK(builder->AppendScalar(*scalar));
+ ASSERT_OK(builder->AppendScalar(*scalar));
+ ASSERT_OK(builder->AppendScalar(*null_scalar));
+ ASSERT_OK(builder->AppendScalars({scalar, null_scalar}));
+ ASSERT_OK(builder->AppendScalar(*scalar, /*n_repeats=*/2));
+ ASSERT_OK(builder->AppendScalar(*null_scalar, /*n_repeats=*/2));
+
+ std::shared_ptr<Array> out;
+ FinishAndCheckPadding(builder.get(), &out);
+ ASSERT_OK(out->ValidateFull());
+ AssertTypeEqual(scalar->type, out->type());
+ ASSERT_EQ(out->length(), 9);
+
+ const bool can_check_nulls = internal::HasValidityBitmap(out->type()->id());
+ // For a dictionary builder, the output dictionary won't necessarily be the same
+ const bool can_check_values = !is_dictionary(out->type()->id());
+
+ if (can_check_nulls) {
+ ASSERT_EQ(out->null_count(), 4);
+ }
+
+ for (const auto index : {0, 1, 3, 5, 6}) {
+ ASSERT_FALSE(out->IsNull(index));
+ ASSERT_OK_AND_ASSIGN(auto scalar_i, out->GetScalar(index));
+ ASSERT_OK(scalar_i->ValidateFull());
+ if (can_check_values) AssertScalarsEqual(*scalar, *scalar_i, /*verbose=*/true);
+ }
+ for (const auto index : {2, 4, 7, 8}) {
+ ASSERT_EQ(out->IsNull(index), can_check_nulls);
+ ASSERT_OK_AND_ASSIGN(auto scalar_i, out->GetScalar(index));
+ ASSERT_OK(scalar_i->ValidateFull());
+ AssertScalarsEqual(*null_scalar, *scalar_i, /*verbose=*/true);
+ }
+}
+
+static ScalarVector GetScalars() {
+ auto hello = Buffer::FromString("hello");
+ DayTimeIntervalType::DayMilliseconds daytime{1, 100};
+ MonthDayNanoIntervalType::MonthDayNanos month_day_nanos{5, 4, 100};
+
+ FieldVector union_fields{field("string", utf8()), field("number", int32()),
+ field("other_number", int32())};
+ std::vector<int8_t> union_type_codes{5, 6, 42};
+
+ const auto sparse_union_ty = ::arrow::sparse_union(union_fields, union_type_codes);
+ const auto dense_union_ty = ::arrow::dense_union(union_fields, union_type_codes);
+
+ return {
+ std::make_shared<BooleanScalar>(false),
+ std::make_shared<Int8Scalar>(3),
+ std::make_shared<UInt16Scalar>(3),
+ std::make_shared<Int32Scalar>(3),
+ std::make_shared<UInt64Scalar>(3),
+ std::make_shared<DoubleScalar>(3.0),
+ std::make_shared<Date32Scalar>(10),
+ std::make_shared<Date64Scalar>(11),
+ std::make_shared<Time32Scalar>(1000, time32(TimeUnit::SECOND)),
+ std::make_shared<Time64Scalar>(1111, time64(TimeUnit::MICRO)),
+ std::make_shared<TimestampScalar>(1111, timestamp(TimeUnit::MILLI)),
+ std::make_shared<MonthIntervalScalar>(1),
+ std::make_shared<DayTimeIntervalScalar>(daytime),
+ std::make_shared<MonthDayNanoIntervalScalar>(month_day_nanos),
+ std::make_shared<DurationScalar>(60, duration(TimeUnit::SECOND)),
+ std::make_shared<BinaryScalar>(hello),
+ std::make_shared<LargeBinaryScalar>(hello),
+ std::make_shared<FixedSizeBinaryScalar>(
+ hello, fixed_size_binary(static_cast<int32_t>(hello->size()))),
+ std::make_shared<Decimal128Scalar>(Decimal128(10), decimal(16, 4)),
+ std::make_shared<Decimal256Scalar>(Decimal256(10), decimal(76, 38)),
+ std::make_shared<StringScalar>(hello),
+ std::make_shared<LargeStringScalar>(hello),
+ std::make_shared<ListScalar>(ArrayFromJSON(int8(), "[1, 2, 3]")),
+ ScalarFromJSON(map(int8(), utf8()), R"([[1, "foo"], [2, "bar"]])"),
+ std::make_shared<LargeListScalar>(ArrayFromJSON(int8(), "[1, 1, 2, 2, 3, 3]")),
+ std::make_shared<FixedSizeListScalar>(ArrayFromJSON(int8(), "[1, 2, 3, 4]")),
+ std::make_shared<StructScalar>(
+ ScalarVector{
+ std::make_shared<Int32Scalar>(2),
+ std::make_shared<Int32Scalar>(6),
+ },
+ struct_({field("min", int32()), field("max", int32())})),
+ // Same values, different union type codes
+ std::make_shared<SparseUnionScalar>(std::make_shared<Int32Scalar>(100), 6,
+ sparse_union_ty),
+ std::make_shared<SparseUnionScalar>(std::make_shared<Int32Scalar>(100), 42,
+ sparse_union_ty),
+ std::make_shared<SparseUnionScalar>(42, sparse_union_ty),
+ std::make_shared<DenseUnionScalar>(std::make_shared<Int32Scalar>(101), 6,
+ dense_union_ty),
+ std::make_shared<DenseUnionScalar>(std::make_shared<Int32Scalar>(101), 42,
+ dense_union_ty),
+ std::make_shared<DenseUnionScalar>(42, dense_union_ty),
+ DictionaryScalar::Make(ScalarFromJSON(int8(), "1"),
+ ArrayFromJSON(utf8(), R"(["foo", "bar"])")),
+ DictionaryScalar::Make(ScalarFromJSON(uint8(), "1"),
+ ArrayFromJSON(utf8(), R"(["foo", "bar"])")),
+ };
+}
+
+TEST_F(TestArray, TestMakeArrayFromScalar) {
+ ASSERT_OK_AND_ASSIGN(auto null_array, MakeArrayFromScalar(NullScalar(), 5));
+ ASSERT_OK(null_array->ValidateFull());
+ ASSERT_EQ(null_array->length(), 5);
+ ASSERT_EQ(null_array->null_count(), 5);
+
+ auto scalars = GetScalars();
+
+ for (int64_t length : {16}) {
+ for (auto scalar : scalars) {
+ ASSERT_OK_AND_ASSIGN(auto array, MakeArrayFromScalar(*scalar, length));
+ ASSERT_OK(array->ValidateFull());
+ ASSERT_EQ(array->length(), length);
+ ASSERT_EQ(array->null_count(), 0);
+
+ // test case for ARROW-13321
+ for (int64_t i : std::vector<int64_t>{0, length / 2, length - 1}) {
+ ASSERT_OK_AND_ASSIGN(auto s, array->GetScalar(i));
+ AssertScalarsEqual(*s, *scalar, /*verbose=*/true);
+ }
+ }
+ }
+
+ for (auto scalar : scalars) {
+ AssertAppendScalar(pool_, scalar);
+ }
+}
+
+TEST_F(TestArray, TestMakeArrayFromScalarSliced) {
+ // Regression test for ARROW-13437
+ auto scalars = GetScalars();
+
+ for (auto scalar : scalars) {
+ SCOPED_TRACE(scalar->type->ToString());
+ ASSERT_OK_AND_ASSIGN(auto array, MakeArrayFromScalar(*scalar, 32));
+ auto sliced = array->Slice(1, 4);
+ ASSERT_EQ(sliced->length(), 4);
+ ASSERT_EQ(sliced->null_count(), 0);
+ ARROW_EXPECT_OK(sliced->ValidateFull());
+ }
+}
+
+TEST_F(TestArray, TestMakeArrayFromDictionaryScalar) {
+ auto dictionary = ArrayFromJSON(utf8(), R"(["foo", "bar", "baz"])");
+ auto type = std::make_shared<DictionaryType>(int8(), utf8());
+ ASSERT_OK_AND_ASSIGN(auto value, MakeScalar(int8(), 1));
+ auto scalar = DictionaryScalar({value, dictionary}, type);
+
+ ASSERT_OK_AND_ASSIGN(auto array, MakeArrayFromScalar(scalar, 4));
+ ASSERT_OK(array->ValidateFull());
+ ASSERT_EQ(array->length(), 4);
+ ASSERT_EQ(array->null_count(), 0);
+
+ for (int i = 0; i < 4; i++) {
+ ASSERT_OK_AND_ASSIGN(auto item, array->GetScalar(i));
+ ASSERT_TRUE(item->Equals(scalar));
+ }
+}
+
+TEST_F(TestArray, TestMakeArrayFromMapScalar) {
+ auto value =
+ ArrayFromJSON(struct_({field("key", utf8(), false), field("value", int8())}),
+ R"([{"key": "a", "value": 1}, {"key": "b", "value": 2}])");
+ auto scalar = MapScalar(value);
+
+ ASSERT_OK_AND_ASSIGN(auto array, MakeArrayFromScalar(scalar, 11));
+ ASSERT_OK(array->ValidateFull());
+ ASSERT_EQ(array->length(), 11);
+ ASSERT_EQ(array->null_count(), 0);
+
+ for (int i = 0; i < 11; i++) {
+ ASSERT_OK_AND_ASSIGN(auto item, array->GetScalar(i));
+ ASSERT_TRUE(item->Equals(scalar));
+ }
+
+ AssertAppendScalar(pool_, std::make_shared<MapScalar>(scalar));
+}
+
+TEST_F(TestArray, TestAppendArraySlice) {
+ auto scalars = GetScalars();
+ for (const auto& scalar : scalars) {
+ ARROW_SCOPED_TRACE(*scalar->type);
+ ASSERT_OK_AND_ASSIGN(auto array, MakeArrayFromScalar(*scalar, 16));
+ ASSERT_OK_AND_ASSIGN(auto nulls, MakeArrayOfNull(scalar->type, 16));
+
+ std::unique_ptr<arrow::ArrayBuilder> builder;
+ ASSERT_OK(MakeBuilder(pool_, scalar->type, &builder));
+
+ ASSERT_OK(builder->AppendArraySlice(*array->data(), 0, 4));
+ ASSERT_EQ(4, builder->length());
+ ASSERT_OK(builder->AppendArraySlice(*array->data(), 0, 0));
+ ASSERT_EQ(4, builder->length());
+ ASSERT_OK(builder->AppendArraySlice(*array->data(), 1, 0));
+ ASSERT_EQ(4, builder->length());
+ ASSERT_OK(builder->AppendArraySlice(*array->data(), 1, 4));
+ ASSERT_EQ(8, builder->length());
+
+ ASSERT_OK(builder->AppendArraySlice(*nulls->data(), 0, 4));
+ ASSERT_EQ(12, builder->length());
+ if (!is_union(scalar->type->id())) {
+ ASSERT_EQ(4, builder->null_count());
+ }
+ ASSERT_OK(builder->AppendArraySlice(*nulls->data(), 0, 0));
+ ASSERT_EQ(12, builder->length());
+ if (!is_union(scalar->type->id())) {
+ ASSERT_EQ(4, builder->null_count());
+ }
+ ASSERT_OK(builder->AppendArraySlice(*nulls->data(), 1, 0));
+ ASSERT_EQ(12, builder->length());
+ if (!is_union(scalar->type->id())) {
+ ASSERT_EQ(4, builder->null_count());
+ }
+ ASSERT_OK(builder->AppendArraySlice(*nulls->data(), 1, 4));
+ ASSERT_EQ(16, builder->length());
+ if (!is_union(scalar->type->id())) {
+ ASSERT_EQ(8, builder->null_count());
+ }
+
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder->Finish(&result));
+ ASSERT_OK(result->ValidateFull());
+ ASSERT_EQ(16, result->length());
+ if (!is_union(scalar->type->id())) {
+ ASSERT_EQ(8, result->null_count());
+ }
+ }
+
+ {
+ ASSERT_OK_AND_ASSIGN(auto array, MakeArrayOfNull(null(), 16));
+ NullBuilder builder(pool_);
+ ASSERT_OK(builder.AppendArraySlice(*array->data(), 0, 4));
+ ASSERT_EQ(4, builder.length());
+ ASSERT_OK(builder.AppendArraySlice(*array->data(), 0, 0));
+ ASSERT_EQ(4, builder.length());
+ ASSERT_OK(builder.AppendArraySlice(*array->data(), 1, 0));
+ ASSERT_EQ(4, builder.length());
+ ASSERT_OK(builder.AppendArraySlice(*array->data(), 1, 4));
+ ASSERT_EQ(8, builder.length());
+ std::shared_ptr<Array> result;
+ ASSERT_OK(builder.Finish(&result));
+ ASSERT_OK(result->ValidateFull());
+ ASSERT_EQ(8, result->length());
+ ASSERT_EQ(8, result->null_count());
+ }
+}
+
+TEST_F(TestArray, ValidateBuffersPrimitive) {
+ auto empty_buffer = std::make_shared<Buffer>("");
+ auto null_buffer = Buffer::FromString("\xff");
+ auto data_buffer = Buffer::FromString("123456789abcdef0");
+
+ auto data = ArrayData::Make(int64(), 2, {null_buffer, data_buffer});
+ auto array = MakeArray(data);
+ ASSERT_OK(array->ValidateFull());
+ data = ArrayData::Make(boolean(), 8, {null_buffer, data_buffer});
+ array = MakeArray(data);
+ ASSERT_OK(array->ValidateFull());
+
+ // Null buffer too small
+ data = ArrayData::Make(int64(), 2, {empty_buffer, data_buffer});
+ array = MakeArray(data);
+ ASSERT_RAISES(Invalid, array->Validate());
+ data = ArrayData::Make(boolean(), 9, {null_buffer, data_buffer});
+ array = MakeArray(data);
+ ASSERT_RAISES(Invalid, array->Validate());
+
+ // Data buffer too small
+ data = ArrayData::Make(int64(), 3, {null_buffer, data_buffer});
+ array = MakeArray(data);
+ ASSERT_RAISES(Invalid, array->Validate());
+
+ // Null buffer absent but null_count > 0.
+ data = ArrayData::Make(int64(), 2, {nullptr, data_buffer}, 1);
+ array = MakeArray(data);
+ ASSERT_RAISES(Invalid, array->Validate());
+
+ //
+ // With offset > 0
+ //
+ data = ArrayData::Make(int64(), 1, {null_buffer, data_buffer}, kUnknownNullCount, 1);
+ array = MakeArray(data);
+ ASSERT_OK(array->ValidateFull());
+ data = ArrayData::Make(boolean(), 6, {null_buffer, data_buffer}, kUnknownNullCount, 2);
+ array = MakeArray(data);
+ ASSERT_OK(array->ValidateFull());
+
+ // Null buffer too small
+ data = ArrayData::Make(boolean(), 7, {null_buffer, data_buffer}, kUnknownNullCount, 2);
+ array = MakeArray(data);
+ ASSERT_RAISES(Invalid, array->Validate());
+
+ // Data buffer too small
+ data = ArrayData::Make(int64(), 2, {null_buffer, data_buffer}, kUnknownNullCount, 1);
+ array = MakeArray(data);
+ ASSERT_RAISES(Invalid, array->Validate());
+}
+
+// ----------------------------------------------------------------------
+// Null type tests
+
+TEST(TestNullBuilder, Basics) {
+ NullBuilder builder;
+ std::shared_ptr<Array> array;
+
+ ASSERT_OK(builder.AppendNull());
+ ASSERT_OK(builder.Append(nullptr));
+ ASSERT_OK(builder.AppendNull());
+ ASSERT_OK(builder.AppendNulls(2));
+ ASSERT_EQ(5, builder.null_count());
+ ASSERT_OK(builder.Finish(&array));
+
+ const auto& null_array = checked_cast<NullArray&>(*array);
+ ASSERT_EQ(null_array.length(), 5);
+ ASSERT_EQ(null_array.null_count(), 5);
+}
+
+// ----------------------------------------------------------------------
+// Primitive type tests
+
+TEST(TestPrimitiveArray, CtorNoValidityBitmap) {
+ // ARROW-8863
+ std::shared_ptr<Buffer> data = *AllocateBuffer(40);
+ Int32Array arr(10, data);
+ ASSERT_EQ(arr.data()->null_count, 0);
+}
+
+TEST_F(TestBuilder, TestReserve) {
+ UInt8Builder builder(pool_);
+
+ ASSERT_OK(builder.Resize(1000));
+ ASSERT_EQ(1000, builder.capacity());
+
+ // Reserve overallocates for small upsizes.
+ ASSERT_OK(builder.Reserve(1030));
+ ASSERT_GE(builder.capacity(), 2000);
+}
+
+TEST_F(TestBuilder, TestResizeDownsize) {
+ UInt8Builder builder(pool_);
+
+ ASSERT_OK(builder.Resize(1000));
+ ASSERT_EQ(1000, builder.capacity());
+ ASSERT_EQ(0, builder.length());
+ ASSERT_OK(builder.AppendNulls(500));
+ ASSERT_EQ(1000, builder.capacity());
+ ASSERT_EQ(500, builder.length());
+
+ // Can downsize below current capacity
+ ASSERT_OK(builder.Resize(500));
+ // ... but not below current populated length
+ ASSERT_RAISES(Invalid, builder.Resize(499));
+ ASSERT_GE(500, builder.capacity());
+ ASSERT_EQ(500, builder.length());
+}
+
+template <typename Attrs>
+class TestPrimitiveBuilder : public TestBuilder {
+ public:
+ typedef Attrs TestAttrs;
+ typedef typename Attrs::ArrayType ArrayType;
+ typedef typename Attrs::BuilderType BuilderType;
+ typedef typename Attrs::T CType;
+ typedef typename Attrs::Type Type;
+
+ virtual void SetUp() {
+ TestBuilder::SetUp();
+
+ type_ = Attrs::type();
+
+ std::unique_ptr<ArrayBuilder> tmp;
+ ASSERT_OK(MakeBuilder(pool_, type_, &tmp));
+ builder_.reset(checked_cast<BuilderType*>(tmp.release()));
+
+ ASSERT_OK(MakeBuilder(pool_, type_, &tmp));
+ builder_nn_.reset(checked_cast<BuilderType*>(tmp.release()));
+ }
+
+ void RandomData(int64_t N, double pct_null = 0.1) {
+ Attrs::draw(N, &draws_);
+
+ valid_bytes_.resize(static_cast<size_t>(N));
+ random_null_bytes(N, pct_null, valid_bytes_.data());
+ }
+
+ void Check(const std::unique_ptr<BuilderType>& builder, bool nullable) {
+ int64_t size = builder->length();
+ auto ex_data = Buffer::Wrap(draws_.data(), size);
+
+ std::shared_ptr<Buffer> ex_null_bitmap;
+ int64_t ex_null_count = 0;
+
+ if (nullable) {
+ ASSERT_OK_AND_ASSIGN(ex_null_bitmap, internal::BytesToBits(valid_bytes_));
+ ex_null_count = CountNulls(valid_bytes_);
+ } else {
+ ex_null_bitmap = nullptr;
+ }
+
+ auto expected =
+ std::make_shared<ArrayType>(size, ex_data, ex_null_bitmap, ex_null_count);
+
+ std::shared_ptr<Array> out;
+ FinishAndCheckPadding(builder.get(), &out);
+
+ std::shared_ptr<ArrayType> result = checked_pointer_cast<ArrayType>(out);
+
+ // Builder is now reset
+ ASSERT_EQ(0, builder->length());
+ ASSERT_EQ(0, builder->capacity());
+ ASSERT_EQ(0, builder->null_count());
+
+ ASSERT_EQ(ex_null_count, result->null_count());
+ ASSERT_TRUE(result->Equals(*expected));
+ }
+
+ void FlipValue(CType* ptr) {
+ auto byteptr = reinterpret_cast<uint8_t*>(ptr);
+ *byteptr = static_cast<uint8_t>(~*byteptr);
+ }
+
+ protected:
+ std::unique_ptr<BuilderType> builder_;
+ std::unique_ptr<BuilderType> builder_nn_;
+
+ std::vector<CType> draws_;
+ std::vector<uint8_t> valid_bytes_;
+};
+
+/// \brief uint8_t isn't a valid template parameter to uniform_int_distribution, so
+/// we use SampleType to determine which kind of integer to use to sample.
+template <typename T, typename = enable_if_t<std::is_integral<T>::value, T>>
+struct UniformIntSampleType {
+ using type = T;
+};
+
+template <>
+struct UniformIntSampleType<uint8_t> {
+ using type = uint16_t;
+};
+
+template <>
+struct UniformIntSampleType<int8_t> {
+ using type = int16_t;
+};
+
+#define PTYPE_DECL(CapType, c_type) \
+ typedef CapType##Array ArrayType; \
+ typedef CapType##Builder BuilderType; \
+ typedef CapType##Type Type; \
+ typedef c_type T; \
+ \
+ static std::shared_ptr<DataType> type() { return std::make_shared<Type>(); }
+
+#define PINT_DECL(CapType, c_type) \
+ struct P##CapType { \
+ PTYPE_DECL(CapType, c_type) \
+ static void draw(int64_t N, std::vector<T>* draws) { \
+ using sample_type = typename UniformIntSampleType<c_type>::type; \
+ const T lower = std::numeric_limits<T>::min(); \
+ const T upper = std::numeric_limits<T>::max(); \
+ randint(N, static_cast<sample_type>(lower), static_cast<sample_type>(upper), \
+ draws); \
+ } \
+ static T Modify(T inp) { return inp / 2; } \
+ typedef \
+ typename std::conditional<std::is_unsigned<T>::value, uint64_t, int64_t>::type \
+ ConversionType; \
+ }
+
+#define PFLOAT_DECL(CapType, c_type, LOWER, UPPER) \
+ struct P##CapType { \
+ PTYPE_DECL(CapType, c_type) \
+ static void draw(int64_t N, std::vector<T>* draws) { \
+ random_real(N, 0, LOWER, UPPER, draws); \
+ } \
+ static T Modify(T inp) { return inp / 2; } \
+ typedef double ConversionType; \
+ }
+
+PINT_DECL(UInt8, uint8_t);
+PINT_DECL(UInt16, uint16_t);
+PINT_DECL(UInt32, uint32_t);
+PINT_DECL(UInt64, uint64_t);
+
+PINT_DECL(Int8, int8_t);
+PINT_DECL(Int16, int16_t);
+PINT_DECL(Int32, int32_t);
+PINT_DECL(Int64, int64_t);
+
+PFLOAT_DECL(Float, float, -1000.0f, 1000.0f);
+PFLOAT_DECL(Double, double, -1000.0, 1000.0);
+
+struct PBoolean {
+ PTYPE_DECL(Boolean, uint8_t)
+ static T Modify(T inp) { return !inp; }
+ typedef int64_t ConversionType;
+};
+
+struct PDayTimeInterval {
+ using DayMilliseconds = DayTimeIntervalType::DayMilliseconds;
+ PTYPE_DECL(DayTimeInterval, DayMilliseconds);
+ static void draw(int64_t N, std::vector<T>* draws) { return rand_day_millis(N, draws); }
+
+ static DayMilliseconds Modify(DayMilliseconds inp) {
+ inp.days /= 2;
+ return inp;
+ }
+ typedef DayMilliseconds ConversionType;
+};
+
+struct PMonthDayNanoInterval {
+ using MonthDayNanos = MonthDayNanoIntervalType::MonthDayNanos;
+ PTYPE_DECL(MonthDayNanoInterval, MonthDayNanos);
+ static void draw(int64_t N, std::vector<T>* draws) {
+ return rand_month_day_nanos(N, draws);
+ }
+ static MonthDayNanos Modify(MonthDayNanos inp) {
+ inp.days /= 2;
+ return inp;
+ }
+ typedef MonthDayNanos ConversionType;
+};
+
+template <>
+void TestPrimitiveBuilder<PBoolean>::RandomData(int64_t N, double pct_null) {
+ draws_.resize(static_cast<size_t>(N));
+ valid_bytes_.resize(static_cast<size_t>(N));
+
+ random_null_bytes(N, 0.5, draws_.data());
+ random_null_bytes(N, pct_null, valid_bytes_.data());
+}
+
+template <>
+void TestPrimitiveBuilder<PBoolean>::FlipValue(CType* ptr) {
+ *ptr = !*ptr;
+}
+
+template <>
+void TestPrimitiveBuilder<PBoolean>::Check(const std::unique_ptr<BooleanBuilder>& builder,
+ bool nullable) {
+ const int64_t size = builder->length();
+
+ // Build expected result array
+ std::shared_ptr<Buffer> ex_data;
+ std::shared_ptr<Buffer> ex_null_bitmap;
+ int64_t ex_null_count = 0;
+
+ ASSERT_OK_AND_ASSIGN(ex_data, internal::BytesToBits(draws_));
+ if (nullable) {
+ ASSERT_OK_AND_ASSIGN(ex_null_bitmap, internal::BytesToBits(valid_bytes_));
+ ex_null_count = CountNulls(valid_bytes_);
+ } else {
+ ex_null_bitmap = nullptr;
+ }
+ auto expected =
+ std::make_shared<BooleanArray>(size, ex_data, ex_null_bitmap, ex_null_count);
+ ASSERT_EQ(size, expected->length());
+
+ // Finish builder and check result array
+ std::shared_ptr<Array> out;
+ FinishAndCheckPadding(builder.get(), &out);
+
+ std::shared_ptr<BooleanArray> result = checked_pointer_cast<BooleanArray>(out);
+
+ ASSERT_EQ(ex_null_count, result->null_count());
+ ASSERT_EQ(size, result->length());
+
+ for (int64_t i = 0; i < size; ++i) {
+ if (nullable) {
+ ASSERT_EQ(valid_bytes_[i] == 0, result->IsNull(i)) << i;
+ } else {
+ ASSERT_FALSE(result->IsNull(i));
+ }
+ if (!result->IsNull(i)) {
+ bool actual = BitUtil::GetBit(result->values()->data(), i);
+ ASSERT_EQ(draws_[i] != 0, actual) << i;
+ }
+ }
+ AssertArraysEqual(*result, *expected);
+
+ // buffers are correctly sized
+ if (result->data()->buffers[0]) {
+ ASSERT_EQ(result->data()->buffers[0]->size(), BitUtil::BytesForBits(size));
+ } else {
+ ASSERT_EQ(result->data()->null_count, 0);
+ }
+ ASSERT_EQ(result->data()->buffers[1]->size(), BitUtil::BytesForBits(size));
+
+ // Builder is now reset
+ ASSERT_EQ(0, builder->length());
+ ASSERT_EQ(0, builder->capacity());
+ ASSERT_EQ(0, builder->null_count());
+}
+
+TEST(TestBooleanArray, TrueCountFalseCount) {
+ random::RandomArrayGenerator rng(/*seed=*/0);
+
+ const int64_t length = 10000;
+ auto arr = rng.Boolean(length, /*true_probability=*/0.5, /*null_probability=*/0.1);
+
+ auto CheckArray = [&](const BooleanArray& values) {
+ int64_t expected_false = 0;
+ int64_t expected_true = 0;
+ for (int64_t i = 0; i < values.length(); ++i) {
+ if (values.IsValid(i)) {
+ if (values.Value(i)) {
+ ++expected_true;
+ } else {
+ ++expected_false;
+ }
+ }
+ }
+ ASSERT_EQ(values.true_count(), expected_true);
+ ASSERT_EQ(values.false_count(), expected_false);
+ };
+
+ CheckArray(checked_cast<const BooleanArray&>(*arr));
+ CheckArray(checked_cast<const BooleanArray&>(*arr->Slice(5)));
+ CheckArray(checked_cast<const BooleanArray&>(*arr->Slice(0, 0)));
+}
+
+TEST(TestPrimitiveAdHoc, TestType) {
+ Int8Builder i8(default_memory_pool());
+ ASSERT_TRUE(i8.type()->Equals(int8()));
+
+ DictionaryBuilder<Int8Type> d_i8(utf8());
+ ASSERT_TRUE(d_i8.type()->Equals(dictionary(int8(), utf8())));
+
+ Dictionary32Builder<Int8Type> d32_i8(utf8());
+ ASSERT_TRUE(d32_i8.type()->Equals(dictionary(int32(), utf8())));
+}
+
+TEST(NumericBuilderAccessors, TestSettersGetters) {
+ int64_t datum = 42;
+ int64_t new_datum = 43;
+ NumericBuilder<Int64Type> builder(int64(), default_memory_pool());
+
+ builder.Reset();
+ ASSERT_OK(builder.Append(datum));
+ ASSERT_EQ(builder.GetValue(0), datum);
+
+ // Now update the value.
+ builder[0] = new_datum;
+
+ ASSERT_EQ(builder.GetValue(0), new_datum);
+ ASSERT_EQ(((const NumericBuilder<Int64Type>&)builder)[0], new_datum);
+}
+
+typedef ::testing::Types<PBoolean, PUInt8, PUInt16, PUInt32, PUInt64, PInt8, PInt16,
+ PInt32, PInt64, PFloat, PDouble, PDayTimeInterval,
+ PMonthDayNanoInterval>
+ Primitives;
+
+TYPED_TEST_SUITE(TestPrimitiveBuilder, Primitives);
+
+TYPED_TEST(TestPrimitiveBuilder, TestInit) {
+ ASSERT_OK(this->builder_->Reserve(1000));
+ ASSERT_EQ(1000, this->builder_->capacity());
+
+ // Small upsize => should overallocate
+ ASSERT_OK(this->builder_->Reserve(1200));
+ ASSERT_GE(2000, this->builder_->capacity());
+
+ // Large upsize => should allocate exactly
+ ASSERT_OK(this->builder_->Reserve(32768));
+ ASSERT_EQ(32768, this->builder_->capacity());
+
+ // unsure if this should go in all builder classes
+ ASSERT_EQ(0, this->builder_->num_children());
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestAppendNull) {
+ int64_t size = 1000;
+ for (int64_t i = 0; i < size; ++i) {
+ ASSERT_OK(this->builder_->AppendNull());
+ ASSERT_EQ(i + 1, this->builder_->null_count());
+ }
+
+ std::shared_ptr<Array> out;
+ FinishAndCheckPadding(this->builder_.get(), &out);
+ auto result = checked_pointer_cast<typename TypeParam::ArrayType>(out);
+
+ for (int64_t i = 0; i < size; ++i) {
+ ASSERT_TRUE(result->IsNull(i)) << i;
+ }
+
+ for (auto buffer : result->data()->buffers) {
+ for (int64_t i = 0; i < buffer->capacity(); i++) {
+ // Validates current implementation, algorithms shouldn't rely on this
+ ASSERT_EQ(0, *(buffer->data() + i)) << i;
+ }
+ }
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestAppendNulls) {
+ const int64_t size = 10;
+ ASSERT_OK(this->builder_->AppendNulls(size));
+ ASSERT_EQ(size, this->builder_->null_count());
+
+ std::shared_ptr<Array> result;
+ FinishAndCheckPadding(this->builder_.get(), &result);
+
+ for (int64_t i = 0; i < size; ++i) {
+ ASSERT_FALSE(result->IsValid(i));
+ }
+
+ for (auto buffer : result->data()->buffers) {
+ for (int64_t i = 0; i < buffer->capacity(); i++) {
+ // Validates current implementation, algorithms shouldn't rely on this
+ ASSERT_EQ(0, *(buffer->data() + i)) << i;
+ }
+ }
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestAppendEmptyValue) {
+ ASSERT_OK(this->builder_->AppendNull());
+ ASSERT_OK(this->builder_->AppendEmptyValue());
+ ASSERT_OK(this->builder_->AppendNulls(2));
+ ASSERT_OK(this->builder_->AppendEmptyValues(2));
+
+ std::shared_ptr<Array> out;
+ FinishAndCheckPadding(this->builder_.get(), &out);
+ ASSERT_OK(out->ValidateFull());
+
+ auto result = checked_pointer_cast<typename TypeParam::ArrayType>(out);
+ ASSERT_EQ(result->length(), 6);
+ ASSERT_EQ(result->null_count(), 3);
+
+ ASSERT_TRUE(result->IsNull(0));
+ ASSERT_FALSE(result->IsNull(1));
+ ASSERT_TRUE(result->IsNull(2));
+ ASSERT_TRUE(result->IsNull(3));
+ ASSERT_FALSE(result->IsNull(4));
+ ASSERT_FALSE(result->IsNull(5));
+
+ // implementation detail: the value slots are 0-initialized
+ for (int64_t i = 0; i < result->length(); ++i) {
+ typename TestFixture::CType t{};
+ ASSERT_EQ(result->Value(i), t);
+ }
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestArrayDtorDealloc) {
+ typedef typename TestFixture::CType T;
+
+ int64_t size = 1000;
+
+ std::vector<T>& draws = this->draws_;
+ std::vector<uint8_t>& valid_bytes = this->valid_bytes_;
+
+ int64_t memory_before = this->pool_->bytes_allocated();
+
+ this->RandomData(size);
+ ASSERT_OK(this->builder_->Reserve(size));
+
+ int64_t i;
+ for (i = 0; i < size; ++i) {
+ if (valid_bytes[i] > 0) {
+ ASSERT_OK(this->builder_->Append(draws[i]));
+ } else {
+ ASSERT_OK(this->builder_->AppendNull());
+ }
+ }
+
+ do {
+ std::shared_ptr<Array> result;
+ FinishAndCheckPadding(this->builder_.get(), &result);
+ } while (false);
+
+ ASSERT_EQ(memory_before, this->pool_->bytes_allocated());
+}
+
+TYPED_TEST(TestPrimitiveBuilder, Equality) {
+ typedef typename TestFixture::CType T;
+
+ const int64_t size = 1000;
+ this->RandomData(size);
+ std::vector<T>& draws = this->draws_;
+ std::vector<uint8_t>& valid_bytes = this->valid_bytes_;
+ std::shared_ptr<Array> array, equal_array, unequal_array;
+ auto builder = this->builder_.get();
+ ASSERT_OK(MakeArray(valid_bytes, draws, size, builder, &array));
+ ASSERT_OK(MakeArray(valid_bytes, draws, size, builder, &equal_array));
+
+ // Make the not equal array by negating the first valid element with itself.
+ const auto first_valid = std::find_if(valid_bytes.begin(), valid_bytes.end(),
+ [](uint8_t valid) { return valid > 0; });
+ const int64_t first_valid_idx = std::distance(valid_bytes.begin(), first_valid);
+ // This should be true with a very high probability, but might introduce flakiness
+ ASSERT_LT(first_valid_idx, size - 1);
+ this->FlipValue(&draws[first_valid_idx]);
+ ASSERT_OK(MakeArray(valid_bytes, draws, size, builder, &unequal_array));
+
+ // test normal equality
+ EXPECT_TRUE(array->Equals(array));
+ EXPECT_TRUE(array->Equals(equal_array));
+ EXPECT_TRUE(equal_array->Equals(array));
+ EXPECT_FALSE(equal_array->Equals(unequal_array));
+ EXPECT_FALSE(unequal_array->Equals(equal_array));
+
+ // Test range equality
+ EXPECT_FALSE(array->RangeEquals(0, first_valid_idx + 1, 0, unequal_array));
+ EXPECT_FALSE(array->RangeEquals(first_valid_idx, size, first_valid_idx, unequal_array));
+ EXPECT_TRUE(array->RangeEquals(0, first_valid_idx, 0, unequal_array));
+ EXPECT_TRUE(
+ array->RangeEquals(first_valid_idx + 1, size, first_valid_idx + 1, unequal_array));
+}
+
+TYPED_TEST(TestPrimitiveBuilder, SliceEquality) {
+ typedef typename TestFixture::CType T;
+
+ const int64_t size = 1000;
+ this->RandomData(size);
+ std::vector<T>& draws = this->draws_;
+ std::vector<uint8_t>& valid_bytes = this->valid_bytes_;
+ auto builder = this->builder_.get();
+
+ std::shared_ptr<Array> array;
+ ASSERT_OK(MakeArray(valid_bytes, draws, size, builder, &array));
+
+ std::shared_ptr<Array> slice, slice2;
+
+ slice = array->Slice(5);
+ slice2 = array->Slice(5);
+ ASSERT_EQ(size - 5, slice->length());
+
+ ASSERT_TRUE(slice->Equals(slice2));
+ ASSERT_TRUE(array->RangeEquals(5, array->length(), 0, slice));
+
+ // Chained slices
+ slice2 = array->Slice(2)->Slice(3);
+ ASSERT_TRUE(slice->Equals(slice2));
+
+ slice = array->Slice(5, 10);
+ slice2 = array->Slice(5, 10);
+ ASSERT_EQ(10, slice->length());
+
+ ASSERT_TRUE(slice->Equals(slice2));
+ ASSERT_TRUE(array->RangeEquals(5, 15, 0, slice));
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestAppendScalar) {
+ typedef typename TestFixture::CType T;
+
+ const int64_t size = 10000;
+
+ std::vector<T>& draws = this->draws_;
+ std::vector<uint8_t>& valid_bytes = this->valid_bytes_;
+
+ this->RandomData(size);
+
+ ASSERT_OK(this->builder_->Reserve(1000));
+ ASSERT_OK(this->builder_nn_->Reserve(1000));
+
+ int64_t null_count = 0;
+ // Append the first 1000
+ for (size_t i = 0; i < 1000; ++i) {
+ if (valid_bytes[i] > 0) {
+ ASSERT_OK(this->builder_->Append(draws[i]));
+ } else {
+ ASSERT_OK(this->builder_->AppendNull());
+ ++null_count;
+ }
+ ASSERT_OK(this->builder_nn_->Append(draws[i]));
+ }
+
+ ASSERT_EQ(null_count, this->builder_->null_count());
+
+ ASSERT_EQ(1000, this->builder_->length());
+ ASSERT_EQ(1000, this->builder_->capacity());
+
+ ASSERT_EQ(1000, this->builder_nn_->length());
+ ASSERT_EQ(1000, this->builder_nn_->capacity());
+
+ ASSERT_OK(this->builder_->Reserve(size - 1000));
+ ASSERT_OK(this->builder_nn_->Reserve(size - 1000));
+
+ // Append the next 9000
+ for (size_t i = 1000; i < size; ++i) {
+ if (valid_bytes[i] > 0) {
+ ASSERT_OK(this->builder_->Append(draws[i]));
+ } else {
+ ASSERT_OK(this->builder_->AppendNull());
+ }
+ ASSERT_OK(this->builder_nn_->Append(draws[i]));
+ }
+
+ ASSERT_EQ(size, this->builder_->length());
+ ASSERT_GE(size, this->builder_->capacity());
+
+ ASSERT_EQ(size, this->builder_nn_->length());
+ ASSERT_GE(size, this->builder_nn_->capacity());
+
+ this->Check(this->builder_, true);
+ this->Check(this->builder_nn_, false);
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestAppendValues) {
+ typedef typename TestFixture::CType T;
+
+ int64_t size = 10000;
+ this->RandomData(size);
+
+ std::vector<T>& draws = this->draws_;
+ std::vector<uint8_t>& valid_bytes = this->valid_bytes_;
+
+ // first slug
+ int64_t K = 1000;
+
+ ASSERT_OK(this->builder_->AppendValues(draws.data(), K, valid_bytes.data()));
+ ASSERT_OK(this->builder_nn_->AppendValues(draws.data(), K));
+
+ ASSERT_EQ(1000, this->builder_->length());
+ ASSERT_EQ(1000, this->builder_->capacity());
+
+ ASSERT_EQ(1000, this->builder_nn_->length());
+ ASSERT_EQ(1000, this->builder_nn_->capacity());
+
+ // Append the next 9000
+ ASSERT_OK(
+ this->builder_->AppendValues(draws.data() + K, size - K, valid_bytes.data() + K));
+ ASSERT_OK(this->builder_nn_->AppendValues(draws.data() + K, size - K));
+
+ ASSERT_EQ(size, this->builder_->length());
+ ASSERT_GE(size, this->builder_->capacity());
+
+ ASSERT_EQ(size, this->builder_nn_->length());
+ ASSERT_GE(size, this->builder_nn_->capacity());
+
+ this->Check(this->builder_, true);
+ this->Check(this->builder_nn_, false);
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestTypedFinish) {
+ typedef typename TestFixture::CType T;
+
+ int64_t size = 1000;
+ this->RandomData(size);
+
+ std::vector<T>& draws = this->draws_;
+ std::vector<uint8_t>& valid_bytes = this->valid_bytes_;
+
+ ASSERT_OK(this->builder_->AppendValues(draws.data(), size, valid_bytes.data()));
+ std::shared_ptr<Array> result_untyped;
+ ASSERT_OK(this->builder_->Finish(&result_untyped));
+
+ ASSERT_OK(this->builder_->AppendValues(draws.data(), size, valid_bytes.data()));
+ std::shared_ptr<typename TestFixture::ArrayType> result;
+ ASSERT_OK(this->builder_->Finish(&result));
+
+ AssertArraysEqual(*result_untyped, *result);
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestAppendValuesIter) {
+ int64_t size = 10000;
+ this->RandomData(size);
+
+ ASSERT_OK(this->builder_->AppendValues(this->draws_.begin(), this->draws_.end(),
+ this->valid_bytes_.begin()));
+ ASSERT_OK(this->builder_nn_->AppendValues(this->draws_.begin(), this->draws_.end()));
+
+ ASSERT_EQ(size, this->builder_->length());
+ ASSERT_GE(size, this->builder_->capacity());
+
+ this->Check(this->builder_, true);
+ this->Check(this->builder_nn_, false);
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestAppendValuesIterNullValid) {
+ int64_t size = 10000;
+ this->RandomData(size);
+
+ ASSERT_OK(this->builder_nn_->AppendValues(this->draws_.begin(),
+ this->draws_.begin() + size / 2,
+ static_cast<uint8_t*>(nullptr)));
+
+ ASSERT_GE(size / 2, this->builder_nn_->capacity());
+
+ ASSERT_OK(this->builder_nn_->AppendValues(this->draws_.begin() + size / 2,
+ this->draws_.end(),
+ static_cast<uint64_t*>(nullptr)));
+
+ this->Check(this->builder_nn_, false);
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestAppendValuesLazyIter) {
+ typedef typename TestFixture::CType T;
+
+ int64_t size = 10000;
+ this->RandomData(size);
+
+ auto& draws = this->draws_;
+ auto& valid_bytes = this->valid_bytes_;
+
+ auto halve = [&draws](int64_t index) {
+ return TestFixture::TestAttrs::Modify(draws[index]);
+ };
+ auto lazy_iter = internal::MakeLazyRange(halve, size);
+
+ ASSERT_OK(this->builder_->AppendValues(lazy_iter.begin(), lazy_iter.end(),
+ valid_bytes.begin()));
+
+ std::vector<T> halved;
+ transform(draws.begin(), draws.end(), back_inserter(halved),
+ [](T in) { return TestFixture::TestAttrs::Modify(in); });
+
+ std::shared_ptr<Array> result;
+ FinishAndCheckPadding(this->builder_.get(), &result);
+
+ std::shared_ptr<Array> expected;
+ ASSERT_OK(
+ this->builder_->AppendValues(halved.data(), halved.size(), valid_bytes.data()));
+ FinishAndCheckPadding(this->builder_.get(), &expected);
+
+ ASSERT_TRUE(expected->Equals(result));
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestAppendValuesIterConverted) {
+ typedef typename TestFixture::CType T;
+ // find type we can safely convert the tested values to and from
+ using conversion_type = typename TestFixture::TestAttrs::ConversionType;
+
+ int64_t size = 10000;
+ this->RandomData(size);
+
+ // append convertible values
+ std::vector<conversion_type> draws_converted(this->draws_.begin(), this->draws_.end());
+ std::vector<int32_t> valid_bytes_converted(this->valid_bytes_.begin(),
+ this->valid_bytes_.end());
+
+ auto cast_values = internal::MakeLazyRange(
+ [&draws_converted](int64_t index) {
+ return static_cast<T>(draws_converted[index]);
+ },
+ size);
+ auto cast_valid = internal::MakeLazyRange(
+ [&valid_bytes_converted](int64_t index) {
+ return static_cast<bool>(valid_bytes_converted[index]);
+ },
+ size);
+
+ ASSERT_OK(this->builder_->AppendValues(cast_values.begin(), cast_values.end(),
+ cast_valid.begin()));
+ ASSERT_OK(this->builder_nn_->AppendValues(cast_values.begin(), cast_values.end()));
+
+ ASSERT_EQ(size, this->builder_->length());
+ ASSERT_GE(size, this->builder_->capacity());
+
+ ASSERT_EQ(size, this->builder_->length());
+ ASSERT_GE(size, this->builder_->capacity());
+
+ this->Check(this->builder_, true);
+ this->Check(this->builder_nn_, false);
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestZeroPadded) {
+ typedef typename TestFixture::CType T;
+
+ int64_t size = 10000;
+ this->RandomData(size);
+
+ std::vector<T>& draws = this->draws_;
+ std::vector<uint8_t>& valid_bytes = this->valid_bytes_;
+
+ // first slug
+ int64_t K = 1000;
+
+ ASSERT_OK(this->builder_->AppendValues(draws.data(), K, valid_bytes.data()));
+
+ std::shared_ptr<Array> out;
+ FinishAndCheckPadding(this->builder_.get(), &out);
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestAppendValuesStdBool) {
+ // ARROW-1383
+ typedef typename TestFixture::CType T;
+
+ int64_t size = 10000;
+ this->RandomData(size);
+
+ std::vector<T>& draws = this->draws_;
+
+ std::vector<bool> is_valid;
+
+ // first slug
+ int64_t K = 1000;
+
+ for (int64_t i = 0; i < K; ++i) {
+ is_valid.push_back(this->valid_bytes_[i] != 0);
+ }
+ ASSERT_OK(this->builder_->AppendValues(draws.data(), K, is_valid));
+ ASSERT_OK(this->builder_nn_->AppendValues(draws.data(), K));
+
+ ASSERT_EQ(1000, this->builder_->length());
+ ASSERT_EQ(1000, this->builder_->capacity());
+ ASSERT_EQ(1000, this->builder_nn_->length());
+ ASSERT_EQ(1000, this->builder_nn_->capacity());
+
+ // Append the next 9000
+ is_valid.clear();
+ std::vector<T> partial_draws;
+ for (int64_t i = K; i < size; ++i) {
+ partial_draws.push_back(draws[i]);
+ is_valid.push_back(this->valid_bytes_[i] != 0);
+ }
+
+ ASSERT_OK(this->builder_->AppendValues(partial_draws, is_valid));
+ ASSERT_OK(this->builder_nn_->AppendValues(partial_draws));
+
+ ASSERT_EQ(size, this->builder_->length());
+ ASSERT_GE(size, this->builder_->capacity());
+
+ ASSERT_EQ(size, this->builder_nn_->length());
+ ASSERT_GE(size, this->builder_->capacity());
+
+ this->Check(this->builder_, true);
+ this->Check(this->builder_nn_, false);
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestAdvance) {
+ ARROW_SUPPRESS_DEPRECATION_WARNING
+ int64_t n = 1000;
+ ASSERT_OK(this->builder_->Reserve(n));
+
+ ASSERT_OK(this->builder_->Advance(100));
+ ASSERT_EQ(100, this->builder_->length());
+
+ ASSERT_OK(this->builder_->Advance(900));
+
+ int64_t too_many = this->builder_->capacity() - 1000 + 1;
+ ASSERT_RAISES(Invalid, this->builder_->Advance(too_many));
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestResize) {
+ int64_t cap = kMinBuilderCapacity * 2;
+
+ ASSERT_OK(this->builder_->Reserve(cap));
+ ASSERT_EQ(cap, this->builder_->capacity());
+}
+
+TYPED_TEST(TestPrimitiveBuilder, TestReserve) {
+ ASSERT_OK(this->builder_->Reserve(10));
+ ASSERT_EQ(0, this->builder_->length());
+ ASSERT_EQ(kMinBuilderCapacity, this->builder_->capacity());
+
+ ASSERT_OK(this->builder_->Reserve(100));
+ ASSERT_EQ(0, this->builder_->length());
+ ASSERT_GE(100, this->builder_->capacity());
+ ARROW_SUPPRESS_DEPRECATION_WARNING
+ ASSERT_OK(this->builder_->Advance(100));
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
+ ASSERT_EQ(100, this->builder_->length());
+ ASSERT_GE(100, this->builder_->capacity());
+
+ ASSERT_RAISES(Invalid, this->builder_->Resize(1));
+}
+
+TEST(TestBooleanBuilder, AppendNullsAdvanceBuilder) {
+ BooleanBuilder builder;
+
+ std::vector<uint8_t> values = {1, 0, 0, 1};
+ std::vector<uint8_t> is_valid = {1, 1, 0, 1};
+
+ std::shared_ptr<Array> arr;
+ ASSERT_OK(builder.AppendValues(values.data(), 2));
+ ASSERT_OK(builder.AppendNulls(1));
+ ASSERT_OK(builder.AppendValues(values.data() + 3, 1));
+ ASSERT_OK(builder.Finish(&arr));
+
+ ASSERT_EQ(1, arr->null_count());
+
+ const auto& barr = static_cast<const BooleanArray&>(*arr);
+ ASSERT_TRUE(barr.Value(0));
+ ASSERT_FALSE(barr.Value(1));
+ ASSERT_TRUE(barr.IsNull(2));
+ ASSERT_TRUE(barr.Value(3));
+}
+
+TEST(TestBooleanBuilder, TestStdBoolVectorAppend) {
+ BooleanBuilder builder;
+ BooleanBuilder builder_nn;
+
+ std::vector<bool> values, is_valid;
+
+ const int length = 10000;
+ random_is_valid(length, 0.5, &values);
+ random_is_valid(length, 0.1, &is_valid);
+
+ const int chunksize = 1000;
+ for (int chunk = 0; chunk < length / chunksize; ++chunk) {
+ std::vector<bool> chunk_values, chunk_is_valid;
+ for (int i = chunk * chunksize; i < (chunk + 1) * chunksize; ++i) {
+ chunk_values.push_back(values[i]);
+ chunk_is_valid.push_back(is_valid[i]);
+ }
+ ASSERT_OK(builder.AppendValues(chunk_values, chunk_is_valid));
+ ASSERT_OK(builder_nn.AppendValues(chunk_values));
+ }
+
+ std::shared_ptr<Array> result, result_nn;
+ ASSERT_OK(builder.Finish(&result));
+ ASSERT_OK(builder_nn.Finish(&result_nn));
+
+ const auto& arr = checked_cast<const BooleanArray&>(*result);
+ const auto& arr_nn = checked_cast<const BooleanArray&>(*result_nn);
+ for (int i = 0; i < length; ++i) {
+ if (is_valid[i]) {
+ ASSERT_FALSE(arr.IsNull(i));
+ ASSERT_EQ(values[i], arr.Value(i));
+ } else {
+ ASSERT_TRUE(arr.IsNull(i));
+ }
+ ASSERT_EQ(values[i], arr_nn.Value(i));
+ }
+}
+
+template <typename TYPE>
+void CheckApproxEquals() {
+ std::shared_ptr<Array> a, b;
+ std::shared_ptr<DataType> type = TypeTraits<TYPE>::type_singleton();
+
+ ArrayFromVector<TYPE>(type, {true, false}, {0.5, 1.0}, &a);
+ ArrayFromVector<TYPE>(type, {true, false}, {0.5000001f, 1.0000001f}, &b);
+ ASSERT_TRUE(a->ApproxEquals(b));
+ ASSERT_TRUE(b->ApproxEquals(a));
+ ASSERT_TRUE(a->ApproxEquals(b, EqualOptions().nans_equal(true)));
+ ASSERT_TRUE(b->ApproxEquals(a, EqualOptions().nans_equal(true)));
+
+ ArrayFromVector<TYPE>(type, {true, false}, {0.5001f, 1.000001f}, &b);
+ ASSERT_FALSE(a->ApproxEquals(b));
+ ASSERT_FALSE(b->ApproxEquals(a));
+ ASSERT_FALSE(a->ApproxEquals(b, EqualOptions().nans_equal(true)));
+ ASSERT_FALSE(b->ApproxEquals(a, EqualOptions().nans_equal(true)));
+ ASSERT_TRUE(a->ApproxEquals(b, EqualOptions().atol(1e-3)));
+ ASSERT_TRUE(b->ApproxEquals(a, EqualOptions().atol(1e-3)));
+ ASSERT_TRUE(a->ApproxEquals(b, EqualOptions().atol(1e-3).nans_equal(true)));
+ ASSERT_TRUE(b->ApproxEquals(a, EqualOptions().atol(1e-3).nans_equal(true)));
+
+ ArrayFromVector<TYPE>(type, {true, false}, {0.5, 1.25}, &b);
+ ASSERT_TRUE(a->ApproxEquals(b));
+ ASSERT_TRUE(b->ApproxEquals(a));
+ ASSERT_TRUE(a->ApproxEquals(b, EqualOptions().nans_equal(true)));
+ ASSERT_TRUE(b->ApproxEquals(a, EqualOptions().nans_equal(true)));
+
+ // Mismatching validity
+ ArrayFromVector<TYPE>(type, {true, false}, {0.5, 1.0}, &a);
+ ArrayFromVector<TYPE>(type, {false, false}, {0.5, 1.0}, &b);
+ ASSERT_FALSE(a->ApproxEquals(b));
+ ASSERT_FALSE(b->ApproxEquals(a));
+ ASSERT_FALSE(a->ApproxEquals(b, EqualOptions().nans_equal(true)));
+ ASSERT_FALSE(b->ApproxEquals(a, EqualOptions().nans_equal(true)));
+
+ ArrayFromVector<TYPE>(type, {false, true}, {0.5, 1.0}, &b);
+ ASSERT_FALSE(a->ApproxEquals(b));
+ ASSERT_FALSE(b->ApproxEquals(a));
+ ASSERT_FALSE(a->ApproxEquals(b, EqualOptions().nans_equal(true)));
+ ASSERT_FALSE(b->ApproxEquals(a, EqualOptions().nans_equal(true)));
+}
+
+template <typename TYPE>
+void CheckSliceApproxEquals() {
+ using T = typename TYPE::c_type;
+
+ const int64_t kSize = 50;
+ std::vector<T> draws1;
+ std::vector<T> draws2;
+
+ const uint32_t kSeed = 0;
+ random_real(kSize, kSeed, 0.0, 100.0, &draws1);
+ random_real(kSize, kSeed + 1, 0.0, 100.0, &draws2);
+
+ // Make the draws equal in the sliced segment, but unequal elsewhere (to
+ // catch not using the slice offset)
+ for (int64_t i = 10; i < 30; ++i) {
+ draws2[i] = draws1[i];
+ }
+
+ std::vector<bool> is_valid;
+ random_is_valid(kSize, 0.1, &is_valid);
+
+ std::shared_ptr<Array> array1, array2;
+ ArrayFromVector<TYPE, T>(is_valid, draws1, &array1);
+ ArrayFromVector<TYPE, T>(is_valid, draws2, &array2);
+
+ std::shared_ptr<Array> slice1 = array1->Slice(10, 20);
+ std::shared_ptr<Array> slice2 = array2->Slice(10, 20);
+
+ ASSERT_TRUE(slice1->ApproxEquals(slice2));
+}
+
+template <typename TYPE>
+void CheckFloatingNanEquality() {
+ std::shared_ptr<Array> a, b;
+ std::shared_ptr<DataType> type = TypeTraits<TYPE>::type_singleton();
+
+ const auto nan_value = static_cast<typename TYPE::c_type>(NAN);
+
+ // NaN in a null entry
+ ArrayFromVector<TYPE>(type, {true, false}, {0.5, nan_value}, &a);
+ ArrayFromVector<TYPE>(type, {true, false}, {0.5, nan_value}, &b);
+ ASSERT_TRUE(a->Equals(b));
+ ASSERT_TRUE(b->Equals(a));
+ ASSERT_TRUE(a->ApproxEquals(b));
+ ASSERT_TRUE(b->ApproxEquals(a));
+ ASSERT_TRUE(a->RangeEquals(b, 0, 2, 0));
+ ASSERT_TRUE(b->RangeEquals(a, 0, 2, 0));
+ ASSERT_TRUE(a->RangeEquals(b, 1, 2, 1));
+ ASSERT_TRUE(b->RangeEquals(a, 1, 2, 1));
+
+ // NaN in a valid entry
+ ArrayFromVector<TYPE>(type, {false, true}, {0.5, nan_value}, &a);
+ ArrayFromVector<TYPE>(type, {false, true}, {0.5, nan_value}, &b);
+ ASSERT_FALSE(a->Equals(b));
+ ASSERT_FALSE(b->Equals(a));
+ ASSERT_TRUE(a->Equals(b, EqualOptions().nans_equal(true)));
+ ASSERT_TRUE(b->Equals(a, EqualOptions().nans_equal(true)));
+ ASSERT_FALSE(a->ApproxEquals(b));
+ ASSERT_FALSE(b->ApproxEquals(a));
+ ASSERT_TRUE(a->ApproxEquals(b, EqualOptions().atol(1e-5).nans_equal(true)));
+ ASSERT_TRUE(b->ApproxEquals(a, EqualOptions().atol(1e-5).nans_equal(true)));
+ // NaN in tested range
+ ASSERT_FALSE(a->RangeEquals(b, 0, 2, 0));
+ ASSERT_FALSE(b->RangeEquals(a, 0, 2, 0));
+ ASSERT_FALSE(a->RangeEquals(b, 1, 2, 1));
+ ASSERT_FALSE(b->RangeEquals(a, 1, 2, 1));
+ // NaN not in tested range
+ ASSERT_TRUE(a->RangeEquals(b, 0, 1, 0));
+ ASSERT_TRUE(b->RangeEquals(a, 0, 1, 0));
+
+ // NaN != non-NaN
+ ArrayFromVector<TYPE>(type, {false, true}, {0.5, nan_value}, &a);
+ ArrayFromVector<TYPE>(type, {false, true}, {0.5, 0.0}, &b);
+ ASSERT_FALSE(a->Equals(b));
+ ASSERT_FALSE(b->Equals(a));
+ ASSERT_FALSE(a->Equals(b, EqualOptions().nans_equal(true)));
+ ASSERT_FALSE(b->Equals(a, EqualOptions().nans_equal(true)));
+ ASSERT_FALSE(a->ApproxEquals(b));
+ ASSERT_FALSE(b->ApproxEquals(a));
+ ASSERT_FALSE(a->ApproxEquals(b, EqualOptions().atol(1e-5).nans_equal(true)));
+ ASSERT_FALSE(b->ApproxEquals(a, EqualOptions().atol(1e-5).nans_equal(true)));
+ // NaN in tested range
+ ASSERT_FALSE(a->RangeEquals(b, 0, 2, 0));
+ ASSERT_FALSE(b->RangeEquals(a, 0, 2, 0));
+ ASSERT_FALSE(a->RangeEquals(b, 1, 2, 1));
+ ASSERT_FALSE(b->RangeEquals(a, 1, 2, 1));
+ // NaN not in tested range
+ ASSERT_TRUE(a->RangeEquals(b, 0, 1, 0));
+ ASSERT_TRUE(b->RangeEquals(a, 0, 1, 0));
+}
+
+template <typename TYPE>
+void CheckFloatingInfinityEquality() {
+ std::shared_ptr<Array> a, b;
+ std::shared_ptr<DataType> type = TypeTraits<TYPE>::type_singleton();
+
+ const auto infinity = std::numeric_limits<typename TYPE::c_type>::infinity();
+
+ for (auto nans_equal : {false, true}) {
+ // Infinity in a null entry
+ ArrayFromVector<TYPE>(type, {true, false}, {0.5, infinity}, &a);
+ ArrayFromVector<TYPE>(type, {true, false}, {0.5, -infinity}, &b);
+ ASSERT_TRUE(a->Equals(b));
+ ASSERT_TRUE(b->Equals(a));
+ ASSERT_TRUE(a->ApproxEquals(b, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
+ ASSERT_TRUE(b->ApproxEquals(a, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
+ ASSERT_TRUE(a->RangeEquals(b, 0, 2, 0));
+ ASSERT_TRUE(b->RangeEquals(a, 0, 2, 0));
+ ASSERT_TRUE(a->RangeEquals(b, 1, 2, 1));
+ ASSERT_TRUE(b->RangeEquals(a, 1, 2, 1));
+
+ // Infinity in a valid entry
+ ArrayFromVector<TYPE>(type, {false, true}, {0.5, infinity}, &a);
+ ArrayFromVector<TYPE>(type, {false, true}, {0.5, infinity}, &b);
+ ASSERT_TRUE(a->Equals(b));
+ ASSERT_TRUE(b->Equals(a));
+ ASSERT_TRUE(a->ApproxEquals(b, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
+ ASSERT_TRUE(b->ApproxEquals(a, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
+ ASSERT_TRUE(a->ApproxEquals(b, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
+ ASSERT_TRUE(b->ApproxEquals(a, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
+ // Infinity in tested range
+ ASSERT_TRUE(a->RangeEquals(b, 0, 2, 0));
+ ASSERT_TRUE(b->RangeEquals(a, 0, 2, 0));
+ ASSERT_TRUE(a->RangeEquals(b, 1, 2, 1));
+ ASSERT_TRUE(b->RangeEquals(a, 1, 2, 1));
+ // Infinity not in tested range
+ ASSERT_TRUE(a->RangeEquals(b, 0, 1, 0));
+ ASSERT_TRUE(b->RangeEquals(a, 0, 1, 0));
+
+ // Infinity != non-infinity
+ ArrayFromVector<TYPE>(type, {false, true}, {0.5, -infinity}, &a);
+ ArrayFromVector<TYPE>(type, {false, true}, {0.5, 0.0}, &b);
+ ASSERT_FALSE(a->Equals(b));
+ ASSERT_FALSE(b->Equals(a));
+ ASSERT_FALSE(a->ApproxEquals(b, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
+ ASSERT_FALSE(b->ApproxEquals(a));
+ ASSERT_FALSE(a->ApproxEquals(b, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
+ ASSERT_FALSE(b->ApproxEquals(a, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
+ // Infinity != Negative infinity
+ ArrayFromVector<TYPE>(type, {true, true}, {0.5, -infinity}, &a);
+ ArrayFromVector<TYPE>(type, {true, true}, {0.5, infinity}, &b);
+ ASSERT_FALSE(a->Equals(b));
+ ASSERT_FALSE(b->Equals(a));
+ ASSERT_FALSE(a->ApproxEquals(b));
+ ASSERT_FALSE(b->ApproxEquals(a));
+ ASSERT_FALSE(a->ApproxEquals(b, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
+ ASSERT_FALSE(b->ApproxEquals(a, EqualOptions().atol(1e-5).nans_equal(nans_equal)));
+ // Infinity in tested range
+ ASSERT_FALSE(a->RangeEquals(b, 0, 2, 0));
+ ASSERT_FALSE(b->RangeEquals(a, 0, 2, 0));
+ ASSERT_FALSE(a->RangeEquals(b, 1, 2, 1));
+ ASSERT_FALSE(b->RangeEquals(a, 1, 2, 1));
+ // Infinity not in tested range
+ ASSERT_TRUE(a->RangeEquals(b, 0, 1, 0));
+ ASSERT_TRUE(b->RangeEquals(a, 0, 1, 0));
+ }
+}
+
+TEST(TestPrimitiveAdHoc, FloatingApproxEquals) {
+ CheckApproxEquals<FloatType>();
+ CheckApproxEquals<DoubleType>();
+}
+
+TEST(TestPrimitiveAdHoc, FloatingSliceApproxEquals) {
+ CheckSliceApproxEquals<FloatType>();
+ CheckSliceApproxEquals<DoubleType>();
+}
+
+TEST(TestPrimitiveAdHoc, FloatingNanEquality) {
+ CheckFloatingNanEquality<FloatType>();
+ CheckFloatingNanEquality<DoubleType>();
+}
+
+TEST(TestPrimitiveAdHoc, FloatingInfinityEquality) {
+ CheckFloatingInfinityEquality<FloatType>();
+ CheckFloatingInfinityEquality<DoubleType>();
+}
+
+// ----------------------------------------------------------------------
+// FixedSizeBinary tests
+
+class TestFWBinaryArray : public ::testing::Test {
+ public:
+ void SetUp() {}
+
+ void InitBuilder(int byte_width) {
+ auto type = fixed_size_binary(byte_width);
+ builder_.reset(new FixedSizeBinaryBuilder(type, default_memory_pool()));
+ }
+
+ protected:
+ std::unique_ptr<FixedSizeBinaryBuilder> builder_;
+};
+
+TEST_F(TestFWBinaryArray, Builder) {
+ int32_t byte_width = 10;
+ int64_t length = 4096;
+
+ int64_t nbytes = length * byte_width;
+
+ std::vector<uint8_t> data(nbytes);
+ random_bytes(nbytes, 0, data.data());
+
+ std::vector<uint8_t> is_valid(length);
+ random_null_bytes(length, 0.1, is_valid.data());
+
+ const uint8_t* raw_data = data.data();
+
+ std::shared_ptr<Array> result;
+
+ auto CheckResult = [&length, &is_valid, &raw_data, &byte_width](const Array& result) {
+ // Verify output
+ const auto& fw_result = checked_cast<const FixedSizeBinaryArray&>(result);
+
+ ASSERT_EQ(length, result.length());
+
+ for (int64_t i = 0; i < result.length(); ++i) {
+ if (is_valid[i]) {
+ ASSERT_EQ(0,
+ memcmp(raw_data + byte_width * i, fw_result.GetValue(i), byte_width));
+ } else {
+ ASSERT_TRUE(fw_result.IsNull(i));
+ }
+ }
+ };
+
+ // Build using iterative API
+ InitBuilder(byte_width);
+ for (int64_t i = 0; i < length; ++i) {
+ if (is_valid[i]) {
+ ASSERT_OK(builder_->Append(raw_data + byte_width * i));
+ } else {
+ ASSERT_OK(builder_->AppendNull());
+ }
+ }
+
+ FinishAndCheckPadding(builder_.get(), &result);
+ CheckResult(*result);
+
+ // Build using batch API
+ InitBuilder(byte_width);
+
+ const uint8_t* raw_is_valid = is_valid.data();
+
+ ASSERT_OK(builder_->AppendValues(raw_data, 50, raw_is_valid));
+ ASSERT_OK(
+ builder_->AppendValues(raw_data + 50 * byte_width, length - 50, raw_is_valid + 50));
+ FinishAndCheckPadding(builder_.get(), &result);
+
+ CheckResult(*result);
+
+ // Build from std::string
+ InitBuilder(byte_width);
+ for (int64_t i = 0; i < length; ++i) {
+ if (is_valid[i]) {
+ ASSERT_OK(builder_->Append(std::string(
+ reinterpret_cast<const char*>(raw_data + byte_width * i), byte_width)));
+ } else {
+ ASSERT_OK(builder_->AppendNull());
+ }
+ }
+
+ ASSERT_OK(builder_->Finish(&result));
+ CheckResult(*result);
+}
+
+TEST_F(TestFWBinaryArray, EqualsRangeEquals) {
+ // Check that we don't compare data in null slots
+
+ auto type = fixed_size_binary(4);
+ FixedSizeBinaryBuilder builder1(type);
+ FixedSizeBinaryBuilder builder2(type);
+
+ ASSERT_OK(builder1.Append("foo1"));
+ ASSERT_OK(builder1.AppendNull());
+
+ ASSERT_OK(builder2.Append("foo1"));
+ ASSERT_OK(builder2.Append("foo2"));
+
+ std::shared_ptr<Array> array1, array2;
+ ASSERT_OK(builder1.Finish(&array1));
+ ASSERT_OK(builder2.Finish(&array2));
+
+ const auto& a1 = checked_cast<const FixedSizeBinaryArray&>(*array1);
+ const auto& a2 = checked_cast<const FixedSizeBinaryArray&>(*array2);
+
+ FixedSizeBinaryArray equal1(type, 2, a1.values(), a1.null_bitmap(), 1);
+ FixedSizeBinaryArray equal2(type, 2, a2.values(), a1.null_bitmap(), 1);
+
+ ASSERT_TRUE(equal1.Equals(equal2));
+ ASSERT_TRUE(equal1.RangeEquals(equal2, 0, 2, 0));
+}
+
+TEST_F(TestFWBinaryArray, ZeroSize) {
+ auto type = fixed_size_binary(0);
+ FixedSizeBinaryBuilder builder(type);
+
+ ASSERT_OK(builder.Append(""));
+ ASSERT_OK(builder.Append(std::string()));
+ ASSERT_OK(builder.AppendNull());
+ ASSERT_OK(builder.AppendNull());
+ ASSERT_OK(builder.AppendNull());
+
+ std::shared_ptr<Array> array;
+ ASSERT_OK(builder.Finish(&array));
+
+ const auto& fw_array = checked_cast<const FixedSizeBinaryArray&>(*array);
+
+ // data is never allocated
+ ASSERT_EQ(fw_array.values()->size(), 0);
+ ASSERT_EQ(0, fw_array.byte_width());
+
+ ASSERT_EQ(5, array->length());
+ ASSERT_EQ(3, array->null_count());
+}
+
+TEST_F(TestFWBinaryArray, ZeroPadding) {
+ auto type = fixed_size_binary(4);
+ FixedSizeBinaryBuilder builder(type);
+
+ ASSERT_OK(builder.Append("foo1"));
+ ASSERT_OK(builder.AppendNull());
+ ASSERT_OK(builder.Append("foo2"));
+ ASSERT_OK(builder.AppendNull());
+ ASSERT_OK(builder.Append("foo3"));
+
+ std::shared_ptr<Array> array;
+ FinishAndCheckPadding(&builder, &array);
+}
+
+TEST_F(TestFWBinaryArray, Slice) {
+ auto type = fixed_size_binary(4);
+ FixedSizeBinaryBuilder builder(type);
+
+ std::vector<std::string> strings = {"foo1", "foo2", "foo3", "foo4", "foo5"};
+ std::vector<uint8_t> is_null = {0, 1, 0, 0, 0};
+
+ for (int i = 0; i < 5; ++i) {
+ if (is_null[i]) {
+ ASSERT_OK(builder.AppendNull());
+ } else {
+ ASSERT_OK(builder.Append(strings[i]));
+ }
+ }
+
+ std::shared_ptr<Array> array;
+ ASSERT_OK(builder.Finish(&array));
+
+ std::shared_ptr<Array> slice, slice2;
+
+ slice = array->Slice(1);
+ slice2 = array->Slice(1);
+ ASSERT_EQ(4, slice->length());
+
+ ASSERT_TRUE(slice->Equals(slice2));
+ ASSERT_TRUE(array->RangeEquals(1, slice->length(), 0, slice));
+
+ // Chained slices
+ slice = array->Slice(2);
+ slice2 = array->Slice(1)->Slice(1);
+ ASSERT_TRUE(slice->Equals(slice2));
+
+ slice = array->Slice(1, 3);
+ ASSERT_EQ(3, slice->length());
+
+ slice2 = array->Slice(1, 3);
+ ASSERT_TRUE(slice->Equals(slice2));
+ ASSERT_TRUE(array->RangeEquals(1, 3, 0, slice));
+}
+
+TEST_F(TestFWBinaryArray, BuilderNulls) {
+ auto type = fixed_size_binary(4);
+ FixedSizeBinaryBuilder builder(type);
+
+ for (int x = 0; x < 100; x++) {
+ ASSERT_OK(builder.AppendNull());
+ }
+ ASSERT_OK(builder.AppendNulls(500));
+
+ std::shared_ptr<Array> array;
+ ASSERT_OK(builder.Finish(&array));
+
+ for (auto buffer : array->data()->buffers) {
+ for (int64_t i = 0; i < buffer->capacity(); i++) {
+ // Validates current implementation, algorithms shouldn't rely on this
+ ASSERT_EQ(0, *(buffer->data() + i)) << i;
+ }
+ }
+}
+
+struct FWBinaryAppender {
+ Status VisitNull() {
+ data.emplace_back("(null)");
+ return Status::OK();
+ }
+
+ Status VisitValue(util::string_view v) {
+ data.push_back(v);
+ return Status::OK();
+ }
+
+ std::vector<util::string_view> data;
+};
+
+TEST_F(TestFWBinaryArray, ArrayDataVisitor) {
+ auto type = fixed_size_binary(3);
+
+ auto array = ArrayFromJSON(type, R"(["abc", null, "def"])");
+ FWBinaryAppender appender;
+ ArrayDataVisitor<FixedSizeBinaryType> visitor;
+ ASSERT_OK(visitor.Visit(*array->data(), &appender));
+ ASSERT_THAT(appender.data, ::testing::ElementsAreArray({"abc", "(null)", "def"}));
+ ARROW_UNUSED(visitor); // Workaround weird MSVC warning
+}
+
+TEST_F(TestFWBinaryArray, ArrayDataVisitorSliced) {
+ auto type = fixed_size_binary(3);
+
+ auto array = ArrayFromJSON(type, R"(["abc", null, "def", "ghi"])")->Slice(1, 2);
+ FWBinaryAppender appender;
+ ArrayDataVisitor<FixedSizeBinaryType> visitor;
+ ASSERT_OK(visitor.Visit(*array->data(), &appender));
+ ASSERT_THAT(appender.data, ::testing::ElementsAreArray({"(null)", "def"}));
+ ARROW_UNUSED(visitor); // Workaround weird MSVC warning
+}
+
+// ----------------------------------------------------------------------
+// AdaptiveInt tests
+
+class TestAdaptiveIntBuilder : public TestBuilder {
+ public:
+ void SetUp() {
+ TestBuilder::SetUp();
+ builder_ = std::make_shared<AdaptiveIntBuilder>(pool_);
+ }
+
+ void Done() { FinishAndCheckPadding(builder_.get(), &result_); }
+
+ template <typename ExpectedType>
+ void TestAppendValues() {
+ using CType = typename TypeTraits<ExpectedType>::CType;
+ auto type = TypeTraits<ExpectedType>::type_singleton();
+
+ std::vector<int64_t> values(
+ {0, std::numeric_limits<CType>::min(), std::numeric_limits<CType>::max()});
+ std::vector<CType> expected_values(
+ {0, std::numeric_limits<CType>::min(), std::numeric_limits<CType>::max()});
+ ArrayFromVector<ExpectedType, CType>(expected_values, &expected_);
+
+ SetUp();
+ ASSERT_OK(builder_->AppendValues(values.data(), values.size()));
+ AssertTypeEqual(*builder_->type(), *type);
+ ASSERT_EQ(builder_->length(), static_cast<int64_t>(values.size()));
+ Done();
+ ASSERT_EQ(builder_->length(), 0);
+ AssertArraysEqual(*expected_, *result_);
+
+ // Reuse builder
+ builder_->Reset();
+ AssertTypeEqual(*builder_->type(), *int8());
+ ASSERT_OK(builder_->AppendValues(values.data(), values.size()));
+ AssertTypeEqual(*builder_->type(), *type);
+ ASSERT_EQ(builder_->length(), static_cast<int64_t>(values.size()));
+ Done();
+ ASSERT_EQ(builder_->length(), 0);
+ AssertArraysEqual(*expected_, *result_);
+ }
+
+ protected:
+ std::shared_ptr<AdaptiveIntBuilder> builder_;
+
+ std::shared_ptr<Array> expected_;
+ std::shared_ptr<Array> result_;
+};
+
+TEST_F(TestAdaptiveIntBuilder, TestInt8) {
+ ASSERT_EQ(builder_->type()->id(), Type::INT8);
+ ASSERT_EQ(builder_->length(), 0);
+ ASSERT_OK(builder_->Append(0));
+ ASSERT_EQ(builder_->length(), 1);
+ ASSERT_OK(builder_->Append(127));
+ ASSERT_EQ(builder_->length(), 2);
+ ASSERT_OK(builder_->Append(-128));
+ ASSERT_EQ(builder_->length(), 3);
+
+ Done();
+
+ std::vector<int8_t> expected_values({0, 127, -128});
+ ArrayFromVector<Int8Type, int8_t>(expected_values, &expected_);
+ AssertArraysEqual(*expected_, *result_);
+}
+
+TEST_F(TestAdaptiveIntBuilder, TestInt16) {
+ ASSERT_OK(builder_->Append(0));
+ ASSERT_EQ(builder_->type()->id(), Type::INT8);
+ ASSERT_OK(builder_->Append(128));
+ ASSERT_EQ(builder_->type()->id(), Type::INT16);
+ Done();
+
+ std::vector<int16_t> expected_values({0, 128});
+ ArrayFromVector<Int16Type, int16_t>(expected_values, &expected_);
+ AssertArraysEqual(*expected_, *result_);
+
+ SetUp();
+ ASSERT_OK(builder_->Append(-129));
+ expected_values = {-129};
+ Done();
+
+ ArrayFromVector<Int16Type, int16_t>(expected_values, &expected_);
+ AssertArraysEqual(*expected_, *result_);
+
+ SetUp();
+ ASSERT_OK(builder_->Append(std::numeric_limits<int16_t>::max()));
+ ASSERT_OK(builder_->Append(std::numeric_limits<int16_t>::min()));
+ expected_values = {std::numeric_limits<int16_t>::max(),
+ std::numeric_limits<int16_t>::min()};
+ Done();
+
+ ArrayFromVector<Int16Type, int16_t>(expected_values, &expected_);
+ AssertArraysEqual(*expected_, *result_);
+}
+
+TEST_F(TestAdaptiveIntBuilder, TestInt16Nulls) {
+ ASSERT_OK(builder_->Append(0));
+ ASSERT_EQ(builder_->type()->id(), Type::INT8);
+ ASSERT_OK(builder_->Append(128));
+ ASSERT_EQ(builder_->type()->id(), Type::INT16);
+ ASSERT_OK(builder_->AppendNull());
+ ASSERT_EQ(builder_->type()->id(), Type::INT16);
+ ASSERT_EQ(1, builder_->null_count());
+ Done();
+
+ std::vector<int16_t> expected_values({0, 128, 0});
+ ArrayFromVector<Int16Type, int16_t>({1, 1, 0}, expected_values, &expected_);
+ ASSERT_ARRAYS_EQUAL(*expected_, *result_);
+}
+
+TEST_F(TestAdaptiveIntBuilder, TestInt32) {
+ ASSERT_OK(builder_->Append(0));
+ ASSERT_EQ(builder_->type()->id(), Type::INT8);
+ ASSERT_OK(
+ builder_->Append(static_cast<int64_t>(std::numeric_limits<int16_t>::max()) + 1));
+ ASSERT_EQ(builder_->type()->id(), Type::INT32);
+ Done();
+
+ std::vector<int32_t> expected_values(
+ {0, static_cast<int32_t>(std::numeric_limits<int16_t>::max()) + 1});
+ ArrayFromVector<Int32Type, int32_t>(expected_values, &expected_);
+ AssertArraysEqual(*expected_, *result_);
+
+ SetUp();
+ ASSERT_OK(
+ builder_->Append(static_cast<int64_t>(std::numeric_limits<int16_t>::min()) - 1));
+ expected_values = {static_cast<int32_t>(std::numeric_limits<int16_t>::min()) - 1};
+ Done();
+
+ ArrayFromVector<Int32Type, int32_t>(expected_values, &expected_);
+ AssertArraysEqual(*expected_, *result_);
+
+ SetUp();
+ ASSERT_OK(builder_->Append(std::numeric_limits<int32_t>::max()));
+ ASSERT_OK(builder_->Append(std::numeric_limits<int32_t>::min()));
+ expected_values = {std::numeric_limits<int32_t>::max(),
+ std::numeric_limits<int32_t>::min()};
+ Done();
+
+ ArrayFromVector<Int32Type, int32_t>(expected_values, &expected_);
+ AssertArraysEqual(*expected_, *result_);
+}
+
+TEST_F(TestAdaptiveIntBuilder, TestInt64) {
+ ASSERT_OK(builder_->Append(0));
+ ASSERT_EQ(builder_->type()->id(), Type::INT8);
+ ASSERT_OK(
+ builder_->Append(static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1));
+ ASSERT_EQ(builder_->type()->id(), Type::INT64);
+ Done();
+
+ std::vector<int64_t> expected_values(
+ {0, static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1});
+ ArrayFromVector<Int64Type, int64_t>(expected_values, &expected_);
+ AssertArraysEqual(*expected_, *result_);
+
+ SetUp();
+ ASSERT_OK(
+ builder_->Append(static_cast<int64_t>(std::numeric_limits<int32_t>::min()) - 1));
+ expected_values = {static_cast<int64_t>(std::numeric_limits<int32_t>::min()) - 1};
+ Done();
+
+ ArrayFromVector<Int64Type, int64_t>(expected_values, &expected_);
+ AssertArraysEqual(*expected_, *result_);
+
+ SetUp();
+ ASSERT_OK(builder_->Append(std::numeric_limits<int64_t>::max()));
+ ASSERT_OK(builder_->Append(std::numeric_limits<int64_t>::min()));
+ expected_values = {std::numeric_limits<int64_t>::max(),
+ std::numeric_limits<int64_t>::min()};
+ Done();
+
+ ArrayFromVector<Int64Type, int64_t>(expected_values, &expected_);
+ AssertArraysEqual(*expected_, *result_);
+}
+
+TEST_F(TestAdaptiveIntBuilder, TestManyAppends) {
+ // More than the builder's internal scratchpad size
+ const int32_t n_values = 99999;
+ std::vector<int32_t> expected_values(n_values);
+
+ for (int32_t i = 0; i < n_values; ++i) {
+ int32_t val = (i & 1) ? i : -i;
+ expected_values[i] = val;
+ ASSERT_OK(builder_->Append(val));
+ ASSERT_EQ(builder_->length(), i + 1);
+ }
+ ASSERT_EQ(builder_->type()->id(), Type::INT32);
+ Done();
+
+ ArrayFromVector<Int32Type, int32_t>(expected_values, &expected_);
+ AssertArraysEqual(*expected_, *result_);
+}
+
+TEST_F(TestAdaptiveIntBuilder, TestAppendValues) {
+ this->template TestAppendValues<Int64Type>();
+ this->template TestAppendValues<Int32Type>();
+ this->template TestAppendValues<Int16Type>();
+ this->template TestAppendValues<Int8Type>();
+}
+
+TEST_F(TestAdaptiveIntBuilder, TestAssertZeroPadded) {
+ std::vector<int64_t> values(
+ {0, static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1});
+ ASSERT_OK(builder_->AppendValues(values.data(), values.size()));
+ Done();
+}
+
+TEST_F(TestAdaptiveIntBuilder, TestAppendNull) {
+ int64_t size = 1000;
+ ASSERT_OK(builder_->Append(127));
+ ASSERT_EQ(0, builder_->null_count());
+ for (unsigned index = 1; index < size - 1; ++index) {
+ ASSERT_OK(builder_->AppendNull());
+ ASSERT_EQ(index, builder_->null_count());
+ }
+ ASSERT_OK(builder_->Append(-128));
+ ASSERT_EQ(size - 2, builder_->null_count());
+
+ Done();
+
+ std::vector<bool> expected_valid(size, false);
+ expected_valid[0] = true;
+ expected_valid[size - 1] = true;
+ std::vector<int8_t> expected_values(size);
+ expected_values[0] = 127;
+ expected_values[size - 1] = -128;
+ std::shared_ptr<Array> expected;
+ ArrayFromVector<Int8Type, int8_t>(expected_valid, expected_values, &expected_);
+ AssertArraysEqual(*expected_, *result_);
+}
+
+TEST_F(TestAdaptiveIntBuilder, TestAppendNulls) {
+ constexpr int64_t size = 10;
+ ASSERT_EQ(0, builder_->null_count());
+ ASSERT_OK(builder_->AppendNulls(size));
+ ASSERT_EQ(size, builder_->null_count());
+
+ Done();
+
+ for (unsigned index = 0; index < size; ++index) {
+ ASSERT_FALSE(result_->IsValid(index));
+ }
+}
+
+TEST_F(TestAdaptiveIntBuilder, TestAppendEmptyValue) {
+ ASSERT_OK(builder_->AppendNulls(2));
+ ASSERT_OK(builder_->AppendEmptyValue());
+ ASSERT_OK(builder_->Append(42));
+ ASSERT_OK(builder_->AppendEmptyValues(2));
+ Done();
+
+ ASSERT_OK(result_->ValidateFull());
+ // NOTE: The fact that we get 0 is really an implementation detail
+ AssertArraysEqual(*result_, *ArrayFromJSON(int8(), "[null, null, 0, 42, 0, 0]"));
+}
+
+TEST(TestAdaptiveIntBuilderWithStartIntSize, TestReset) {
+ auto builder = std::make_shared<AdaptiveIntBuilder>(
+ static_cast<uint8_t>(sizeof(int16_t)), default_memory_pool());
+ AssertTypeEqual(*int16(), *builder->type());
+
+ ASSERT_OK(
+ builder->Append(static_cast<int64_t>(std::numeric_limits<int16_t>::max()) + 1));
+ AssertTypeEqual(*int32(), *builder->type());
+
+ builder->Reset();
+ AssertTypeEqual(*int16(), *builder->type());
+}
+
+class TestAdaptiveUIntBuilder : public TestBuilder {
+ public:
+ void SetUp() {
+ TestBuilder::SetUp();
+ builder_ = std::make_shared<AdaptiveUIntBuilder>(pool_);
+ }
+
+ void Done() { FinishAndCheckPadding(builder_.get(), &result_); }
+
+ template <typename ExpectedType>
+ void TestAppendValues() {
+ using CType = typename TypeTraits<ExpectedType>::CType;
+ auto type = TypeTraits<ExpectedType>::type_singleton();
+
+ std::vector<uint64_t> values(
+ {0, std::numeric_limits<CType>::min(), std::numeric_limits<CType>::max()});
+ std::vector<CType> expected_values(
+ {0, std::numeric_limits<CType>::min(), std::numeric_limits<CType>::max()});
+ ArrayFromVector<ExpectedType, CType>(expected_values, &expected_);
+
+ SetUp();
+ ASSERT_OK(builder_->AppendValues(values.data(), values.size()));
+ AssertTypeEqual(*builder_->type(), *type);
+ ASSERT_EQ(builder_->length(), static_cast<int64_t>(values.size()));
+ Done();
+ ASSERT_EQ(builder_->length(), 0);
+ AssertArraysEqual(*expected_, *result_);
+
+ // Reuse builder
+ builder_->Reset();
+ AssertTypeEqual(*builder_->type(), *uint8());
+ ASSERT_OK(builder_->AppendValues(values.data(), values.size()));
+ AssertTypeEqual(*builder_->type(), *type);
+ ASSERT_EQ(builder_->length(), static_cast<int64_t>(values.size()));
+ Done();
+ ASSERT_EQ(builder_->length(), 0);
+ AssertArraysEqual(*expected_, *result_);
+ }
+
+ protected:
+ std::shared_ptr<AdaptiveUIntBuilder> builder_;
+
+ std::shared_ptr<Array> expected_;
+ std::shared_ptr<Array> result_;
+};
+
+TEST_F(TestAdaptiveUIntBuilder, TestUInt8) {
+ ASSERT_EQ(builder_->length(), 0);
+ ASSERT_OK(builder_->Append(0));
+ ASSERT_EQ(builder_->length(), 1);
+ ASSERT_OK(builder_->Append(255));
+ ASSERT_EQ(builder_->length(), 2);
+
+ Done();
+
+ std::vector<uint8_t> expected_values({0, 255});
+ ArrayFromVector<UInt8Type, uint8_t>(expected_values, &expected_);
+ ASSERT_TRUE(expected_->Equals(result_));
+}
+
+TEST_F(TestAdaptiveUIntBuilder, TestUInt16) {
+ ASSERT_OK(builder_->Append(0));
+ ASSERT_EQ(builder_->type()->id(), Type::UINT8);
+ ASSERT_OK(builder_->Append(256));
+ ASSERT_EQ(builder_->type()->id(), Type::UINT16);
+ Done();
+
+ std::vector<uint16_t> expected_values({0, 256});
+ ArrayFromVector<UInt16Type, uint16_t>(expected_values, &expected_);
+ ASSERT_ARRAYS_EQUAL(*expected_, *result_);
+
+ SetUp();
+ ASSERT_OK(builder_->Append(std::numeric_limits<uint16_t>::max()));
+ expected_values = {std::numeric_limits<uint16_t>::max()};
+ Done();
+
+ ArrayFromVector<UInt16Type, uint16_t>(expected_values, &expected_);
+ ASSERT_TRUE(expected_->Equals(result_));
+}
+
+TEST_F(TestAdaptiveUIntBuilder, TestUInt16Nulls) {
+ ASSERT_OK(builder_->Append(0));
+ ASSERT_EQ(builder_->type()->id(), Type::UINT8);
+ ASSERT_OK(builder_->Append(256));
+ ASSERT_EQ(builder_->type()->id(), Type::UINT16);
+ ASSERT_OK(builder_->AppendNull());
+ ASSERT_EQ(builder_->type()->id(), Type::UINT16);
+ ASSERT_EQ(1, builder_->null_count());
+ Done();
+
+ std::vector<uint16_t> expected_values({0, 256, 0});
+ ArrayFromVector<UInt16Type, uint16_t>({1, 1, 0}, expected_values, &expected_);
+ ASSERT_ARRAYS_EQUAL(*expected_, *result_);
+}
+
+TEST_F(TestAdaptiveUIntBuilder, TestUInt32) {
+ ASSERT_OK(builder_->Append(0));
+ ASSERT_EQ(builder_->type()->id(), Type::UINT8);
+ ASSERT_OK(
+ builder_->Append(static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 1));
+ ASSERT_EQ(builder_->type()->id(), Type::UINT32);
+ Done();
+
+ std::vector<uint32_t> expected_values(
+ {0, static_cast<uint32_t>(std::numeric_limits<uint16_t>::max()) + 1});
+ ArrayFromVector<UInt32Type, uint32_t>(expected_values, &expected_);
+ ASSERT_TRUE(expected_->Equals(result_));
+
+ SetUp();
+ ASSERT_OK(builder_->Append(std::numeric_limits<uint32_t>::max()));
+ expected_values = {std::numeric_limits<uint32_t>::max()};
+ Done();
+
+ ArrayFromVector<UInt32Type, uint32_t>(expected_values, &expected_);
+ ASSERT_TRUE(expected_->Equals(result_));
+}
+
+TEST_F(TestAdaptiveUIntBuilder, TestUInt64) {
+ ASSERT_OK(builder_->Append(0));
+ ASSERT_EQ(builder_->type()->id(), Type::UINT8);
+ ASSERT_OK(
+ builder_->Append(static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1));
+ ASSERT_EQ(builder_->type()->id(), Type::UINT64);
+ Done();
+
+ std::vector<uint64_t> expected_values(
+ {0, static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1});
+ ArrayFromVector<UInt64Type, uint64_t>(expected_values, &expected_);
+ ASSERT_TRUE(expected_->Equals(result_));
+
+ SetUp();
+ ASSERT_OK(builder_->Append(std::numeric_limits<uint64_t>::max()));
+ expected_values = {std::numeric_limits<uint64_t>::max()};
+ Done();
+
+ ArrayFromVector<UInt64Type, uint64_t>(expected_values, &expected_);
+ ASSERT_TRUE(expected_->Equals(result_));
+}
+
+TEST_F(TestAdaptiveUIntBuilder, TestManyAppends) {
+ // More than the builder's internal scratchpad size
+ const int32_t n_values = 99999;
+ std::vector<uint32_t> expected_values(n_values);
+
+ for (int32_t i = 0; i < n_values; ++i) {
+ auto val = static_cast<uint32_t>(i);
+ expected_values[i] = val;
+ ASSERT_OK(builder_->Append(val));
+ ASSERT_EQ(builder_->length(), i + 1);
+ }
+ ASSERT_EQ(builder_->type()->id(), Type::UINT32);
+ Done();
+
+ ArrayFromVector<UInt32Type, uint32_t>(expected_values, &expected_);
+ AssertArraysEqual(*expected_, *result_);
+}
+
+TEST_F(TestAdaptiveUIntBuilder, TestAppendValues) {
+ this->template TestAppendValues<UInt64Type>();
+ this->template TestAppendValues<UInt32Type>();
+ this->template TestAppendValues<UInt16Type>();
+ this->template TestAppendValues<UInt8Type>();
+}
+
+TEST_F(TestAdaptiveUIntBuilder, TestAssertZeroPadded) {
+ std::vector<uint64_t> values(
+ {0, static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1});
+ ASSERT_OK(builder_->AppendValues(values.data(), values.size()));
+ Done();
+}
+
+TEST_F(TestAdaptiveUIntBuilder, TestAppendNull) {
+ int64_t size = 1000;
+ ASSERT_OK(builder_->Append(254));
+ for (unsigned index = 1; index < size - 1; ++index) {
+ ASSERT_OK(builder_->AppendNull());
+ ASSERT_EQ(index, builder_->null_count());
+ }
+ ASSERT_OK(builder_->Append(255));
+
+ Done();
+
+ std::vector<bool> expected_valid(size, false);
+ expected_valid[0] = true;
+ expected_valid[size - 1] = true;
+ std::vector<uint8_t> expected_values(size);
+ expected_values[0] = 254;
+ expected_values[size - 1] = 255;
+ std::shared_ptr<Array> expected;
+ ArrayFromVector<UInt8Type, uint8_t>(expected_valid, expected_values, &expected_);
+ AssertArraysEqual(*expected_, *result_);
+}
+
+TEST_F(TestAdaptiveUIntBuilder, TestAppendNulls) {
+ constexpr int64_t size = 10;
+ ASSERT_OK(builder_->AppendNulls(size));
+ ASSERT_EQ(size, builder_->null_count());
+
+ Done();
+
+ for (unsigned index = 0; index < size; ++index) {
+ ASSERT_FALSE(result_->IsValid(index));
+ }
+}
+
+TEST_F(TestAdaptiveUIntBuilder, TestAppendEmptyValue) {
+ ASSERT_OK(builder_->AppendNulls(2));
+ ASSERT_OK(builder_->AppendEmptyValue());
+ ASSERT_OK(builder_->Append(42));
+ ASSERT_OK(builder_->AppendEmptyValues(2));
+ Done();
+
+ ASSERT_OK(result_->ValidateFull());
+ // NOTE: The fact that we get 0 is really an implementation detail
+ AssertArraysEqual(*result_, *ArrayFromJSON(uint8(), "[null, null, 0, 42, 0, 0]"));
+}
+
+TEST(TestAdaptiveUIntBuilderWithStartIntSize, TestReset) {
+ auto builder = std::make_shared<AdaptiveUIntBuilder>(
+ static_cast<uint8_t>(sizeof(uint16_t)), default_memory_pool());
+ AssertTypeEqual(uint16(), builder->type());
+
+ ASSERT_OK(
+ builder->Append(static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 1));
+ AssertTypeEqual(uint32(), builder->type());
+
+ builder->Reset();
+ AssertTypeEqual(uint16(), builder->type());
+}
+
+// ----------------------------------------------------------------------
+// Test Decimal arrays
+
+template <typename TYPE>
+class DecimalTest : public ::testing::TestWithParam<int> {
+ public:
+ using DecimalBuilder = typename TypeTraits<TYPE>::BuilderType;
+ using DecimalValue = typename TypeTraits<TYPE>::ScalarType::ValueType;
+ using DecimalArray = typename TypeTraits<TYPE>::ArrayType;
+ using DecimalVector = std::vector<DecimalValue>;
+
+ DecimalTest() {}
+
+ template <size_t BYTE_WIDTH = 16>
+ void MakeData(const DecimalVector& input, std::vector<uint8_t>* out) const {
+ out->reserve(input.size() * BYTE_WIDTH);
+
+ for (const auto& value : input) {
+ auto bytes = value.ToBytes();
+ out->insert(out->end(), bytes.cbegin(), bytes.cend());
+ }
+ }
+
+ template <size_t BYTE_WIDTH = 16>
+ void TestCreate(int32_t precision, const DecimalVector& draw,
+ const std::vector<uint8_t>& valid_bytes, int64_t offset) const {
+ auto type = std::make_shared<TYPE>(precision, 4);
+ auto builder = std::make_shared<DecimalBuilder>(type);
+
+ size_t null_count = 0;
+
+ const size_t size = draw.size();
+
+ ASSERT_OK(builder->Reserve(size));
+
+ for (size_t i = 0; i < size; ++i) {
+ if (valid_bytes[i]) {
+ ASSERT_OK(builder->Append(draw[i]));
+ } else {
+ ASSERT_OK(builder->AppendNull());
+ ++null_count;
+ }
+ }
+
+ std::shared_ptr<Array> out;
+ FinishAndCheckPadding(builder.get(), &out);
+ ASSERT_EQ(builder->length(), 0);
+
+ std::vector<uint8_t> raw_bytes;
+
+ raw_bytes.reserve(size * BYTE_WIDTH);
+ MakeData<BYTE_WIDTH>(draw, &raw_bytes);
+
+ auto expected_data = std::make_shared<Buffer>(raw_bytes.data(), BYTE_WIDTH);
+ std::shared_ptr<Buffer> expected_null_bitmap;
+ ASSERT_OK_AND_ASSIGN(expected_null_bitmap, internal::BytesToBits(valid_bytes));
+
+ int64_t expected_null_count = CountNulls(valid_bytes);
+ auto expected = std::make_shared<DecimalArray>(
+ type, size, expected_data, expected_null_bitmap, expected_null_count);
+
+ std::shared_ptr<Array> lhs = out->Slice(offset);
+ std::shared_ptr<Array> rhs = expected->Slice(offset);
+ ASSERT_ARRAYS_EQUAL(*rhs, *lhs);
+ }
+};
+
+using Decimal128Test = DecimalTest<Decimal128Type>;
+
+TEST_P(Decimal128Test, NoNulls) {
+ int32_t precision = GetParam();
+ std::vector<Decimal128> draw = {Decimal128(1), Decimal128(-2), Decimal128(2389),
+ Decimal128(4), Decimal128(-12348)};
+ std::vector<uint8_t> valid_bytes = {true, true, true, true, true};
+ this->TestCreate(precision, draw, valid_bytes, 0);
+ this->TestCreate(precision, draw, valid_bytes, 2);
+}
+
+TEST_P(Decimal128Test, WithNulls) {
+ int32_t precision = GetParam();
+ std::vector<Decimal128> draw = {Decimal128(1), Decimal128(2), Decimal128(-1),
+ Decimal128(4), Decimal128(-1), Decimal128(1),
+ Decimal128(2)};
+ Decimal128 big;
+ ASSERT_OK_AND_ASSIGN(big, Decimal128::FromString("230342903942.234234"));
+ draw.push_back(big);
+
+ Decimal128 big_negative;
+ ASSERT_OK_AND_ASSIGN(big_negative, Decimal128::FromString("-23049302932.235234"));
+ draw.push_back(big_negative);
+
+ std::vector<uint8_t> valid_bytes = {true, true, false, true, false,
+ true, true, true, true};
+ this->TestCreate(precision, draw, valid_bytes, 0);
+ this->TestCreate(precision, draw, valid_bytes, 2);
+}
+
+INSTANTIATE_TEST_SUITE_P(Decimal128Test, Decimal128Test, ::testing::Range(1, 38));
+
+using Decimal256Test = DecimalTest<Decimal256Type>;
+
+TEST_P(Decimal256Test, NoNulls) {
+ int32_t precision = GetParam();
+ std::vector<Decimal256> draw = {Decimal256(1), Decimal256(-2), Decimal256(2389),
+ Decimal256(4), Decimal256(-12348)};
+ std::vector<uint8_t> valid_bytes = {true, true, true, true, true};
+ this->TestCreate(precision, draw, valid_bytes, 0);
+ this->TestCreate(precision, draw, valid_bytes, 2);
+}
+
+TEST_P(Decimal256Test, WithNulls) {
+ int32_t precision = GetParam();
+ std::vector<Decimal256> draw = {Decimal256(1), Decimal256(2), Decimal256(-1),
+ Decimal256(4), Decimal256(-1), Decimal256(1),
+ Decimal256(2)};
+ Decimal256 big; // (pow(2, 255) - 1) / pow(10, 38)
+ ASSERT_OK_AND_ASSIGN(big,
+ Decimal256::FromString("578960446186580977117854925043439539266."
+ "34992332820282019728792003956564819967"));
+ draw.push_back(big);
+
+ Decimal256 big_negative; // -pow(2, 255) / pow(10, 38)
+ ASSERT_OK_AND_ASSIGN(big_negative,
+ Decimal256::FromString("-578960446186580977117854925043439539266."
+ "34992332820282019728792003956564819968"));
+ draw.push_back(big_negative);
+
+ std::vector<uint8_t> valid_bytes = {true, true, false, true, false,
+ true, true, true, true};
+ this->TestCreate(precision, draw, valid_bytes, 0);
+ this->TestCreate(precision, draw, valid_bytes, 2);
+}
+
+INSTANTIATE_TEST_SUITE_P(Decimal256Test, Decimal256Test,
+ ::testing::Values(1, 2, 5, 10, 38, 39, 40, 75, 76));
+
+// ----------------------------------------------------------------------
+// Test rechunking
+
+TEST(TestRechunkArraysConsistently, Trivial) {
+ std::vector<ArrayVector> groups, rechunked;
+ rechunked = internal::RechunkArraysConsistently(groups);
+ ASSERT_EQ(rechunked.size(), 0);
+
+ std::shared_ptr<Array> a1, a2, b1;
+ ArrayFromVector<Int16Type, int16_t>({}, &a1);
+ ArrayFromVector<Int16Type, int16_t>({}, &a2);
+ ArrayFromVector<Int32Type, int32_t>({}, &b1);
+
+ groups = {{a1, a2}, {}, {b1}};
+ rechunked = internal::RechunkArraysConsistently(groups);
+ ASSERT_EQ(rechunked.size(), 3);
+
+ for (auto& arrvec : rechunked) {
+ for (auto& arr : arrvec) {
+ AssertZeroPadded(*arr);
+ TestInitialized(*arr);
+ }
+ }
+}
+
+TEST(TestRechunkArraysConsistently, Plain) {
+ std::shared_ptr<Array> expected;
+ std::shared_ptr<Array> a1, a2, a3, b1, b2, b3, b4;
+ ArrayFromVector<Int16Type, int16_t>({1, 2, 3}, &a1);
+ ArrayFromVector<Int16Type, int16_t>({4, 5}, &a2);
+ ArrayFromVector<Int16Type, int16_t>({6, 7, 8, 9}, &a3);
+
+ ArrayFromVector<Int32Type, int32_t>({41, 42}, &b1);
+ ArrayFromVector<Int32Type, int32_t>({43, 44, 45}, &b2);
+ ArrayFromVector<Int32Type, int32_t>({46, 47}, &b3);
+ ArrayFromVector<Int32Type, int32_t>({48, 49}, &b4);
+
+ ArrayVector a{a1, a2, a3};
+ ArrayVector b{b1, b2, b3, b4};
+
+ std::vector<ArrayVector> groups{a, b}, rechunked;
+ rechunked = internal::RechunkArraysConsistently(groups);
+ ASSERT_EQ(rechunked.size(), 2);
+ auto ra = rechunked[0];
+ auto rb = rechunked[1];
+
+ ASSERT_EQ(ra.size(), 5);
+ ArrayFromVector<Int16Type, int16_t>({1, 2}, &expected);
+ ASSERT_ARRAYS_EQUAL(*ra[0], *expected);
+ ArrayFromVector<Int16Type, int16_t>({3}, &expected);
+ ASSERT_ARRAYS_EQUAL(*ra[1], *expected);
+ ArrayFromVector<Int16Type, int16_t>({4, 5}, &expected);
+ ASSERT_ARRAYS_EQUAL(*ra[2], *expected);
+ ArrayFromVector<Int16Type, int16_t>({6, 7}, &expected);
+ ASSERT_ARRAYS_EQUAL(*ra[3], *expected);
+ ArrayFromVector<Int16Type, int16_t>({8, 9}, &expected);
+ ASSERT_ARRAYS_EQUAL(*ra[4], *expected);
+
+ ASSERT_EQ(rb.size(), 5);
+ ArrayFromVector<Int32Type, int32_t>({41, 42}, &expected);
+ ASSERT_ARRAYS_EQUAL(*rb[0], *expected);
+ ArrayFromVector<Int32Type, int32_t>({43}, &expected);
+ ASSERT_ARRAYS_EQUAL(*rb[1], *expected);
+ ArrayFromVector<Int32Type, int32_t>({44, 45}, &expected);
+ ASSERT_ARRAYS_EQUAL(*rb[2], *expected);
+ ArrayFromVector<Int32Type, int32_t>({46, 47}, &expected);
+ ASSERT_ARRAYS_EQUAL(*rb[3], *expected);
+ ArrayFromVector<Int32Type, int32_t>({48, 49}, &expected);
+ ASSERT_ARRAYS_EQUAL(*rb[4], *expected);
+
+ for (auto& arrvec : rechunked) {
+ for (auto& arr : arrvec) {
+ AssertZeroPadded(*arr);
+ TestInitialized(*arr);
+ }
+ }
+}
+
+// ----------------------------------------------------------------------
+// Test SwapEndianArrayData
+
+/// \brief Indicate if fields are equals.
+///
+/// \param[in] target ArrayData to be converted and tested
+/// \param[in] expected result ArrayData
+void AssertArrayDataEqualsWithSwapEndian(const std::shared_ptr<ArrayData>& target,
+ const std::shared_ptr<ArrayData>& expected) {
+ auto swap_array = MakeArray(*::arrow::internal::SwapEndianArrayData(target));
+ auto expected_array = MakeArray(expected);
+ ASSERT_ARRAYS_EQUAL(*swap_array, *expected_array);
+ ASSERT_OK(swap_array->ValidateFull());
+}
+
+TEST(TestSwapEndianArrayData, PrimitiveType) {
+ auto null_buffer = Buffer::FromString("\xff");
+ auto data_int_buffer = Buffer::FromString("01234567");
+
+ auto data = ArrayData::Make(null(), 0, {nullptr}, 0);
+ auto expected_data = data;
+ AssertArrayDataEqualsWithSwapEndian(data, expected_data);
+
+ data = ArrayData::Make(boolean(), 8, {null_buffer, data_int_buffer}, 0);
+ expected_data = data;
+ AssertArrayDataEqualsWithSwapEndian(data, expected_data);
+
+ data = ArrayData::Make(int8(), 8, {null_buffer, data_int_buffer}, 0);
+ expected_data = data;
+ AssertArrayDataEqualsWithSwapEndian(data, expected_data);
+
+ data = ArrayData::Make(uint16(), 4, {null_buffer, data_int_buffer}, 0);
+ auto data_int16_buffer = Buffer::FromString("10325476");
+ expected_data = ArrayData::Make(uint16(), 4, {null_buffer, data_int16_buffer}, 0);
+ AssertArrayDataEqualsWithSwapEndian(data, expected_data);
+
+ data = ArrayData::Make(int32(), 2, {null_buffer, data_int_buffer}, 0);
+ auto data_int32_buffer = Buffer::FromString("32107654");
+ expected_data = ArrayData::Make(int32(), 2, {null_buffer, data_int32_buffer}, 0);
+ AssertArrayDataEqualsWithSwapEndian(data, expected_data);
+
+ data = ArrayData::Make(uint64(), 1, {null_buffer, data_int_buffer}, 0);
+ auto data_int64_buffer = Buffer::FromString("76543210");
+ expected_data = ArrayData::Make(uint64(), 1, {null_buffer, data_int64_buffer}, 0);
+ AssertArrayDataEqualsWithSwapEndian(data, expected_data);
+
+ auto data_16byte_buffer = Buffer::FromString("0123456789abcdef");
+ data = ArrayData::Make(decimal128(38, 10), 1, {null_buffer, data_16byte_buffer});
+ auto data_decimal128_buffer = Buffer::FromString("fedcba9876543210");
+ expected_data =
+ ArrayData::Make(decimal128(38, 10), 1, {null_buffer, data_decimal128_buffer}, 0);
+ AssertArrayDataEqualsWithSwapEndian(data, expected_data);
+
+ auto data_32byte_buffer = Buffer::FromString("0123456789abcdef123456789ABCDEF0");
+ data = ArrayData::Make(decimal256(76, 20), 1, {null_buffer, data_32byte_buffer});
+ auto data_decimal256_buffer = Buffer::FromString("0FEDCBA987654321fedcba9876543210");
+ expected_data =
+ ArrayData::Make(decimal256(76, 20), 1, {null_buffer, data_decimal256_buffer}, 0);
+ AssertArrayDataEqualsWithSwapEndian(data, expected_data);
+
+ auto data_float_buffer = Buffer::FromString("01200560");
+ data = ArrayData::Make(float32(), 2, {null_buffer, data_float_buffer}, 0);
+ auto data_float32_buffer = Buffer::FromString("02100650");
+ expected_data = ArrayData::Make(float32(), 2, {null_buffer, data_float32_buffer}, 0);
+ AssertArrayDataEqualsWithSwapEndian(data, expected_data);
+
+ data = ArrayData::Make(float64(), 1, {null_buffer, data_float_buffer});
+ auto data_float64_buffer = Buffer::FromString("06500210");
+ expected_data = ArrayData::Make(float64(), 1, {null_buffer, data_float64_buffer}, 0);
+ AssertArrayDataEqualsWithSwapEndian(data, expected_data);
+
+ // With offset > 0
+ data =
+ ArrayData::Make(int64(), 1, {null_buffer, data_int_buffer}, kUnknownNullCount, 1);
+ ASSERT_RAISES(Invalid, ::arrow::internal::SwapEndianArrayData(data));
+}
+
+std::shared_ptr<ArrayData> ReplaceBuffers(const std::shared_ptr<ArrayData>& data,
+ const int32_t buffer_index,
+ const std::vector<uint8_t>& buffer_data) {
+ const auto test_data = data->Copy();
+ test_data->buffers[buffer_index] =
+ std::make_shared<Buffer>(buffer_data.data(), buffer_data.size());
+ return test_data;
+}
+
+std::shared_ptr<ArrayData> ReplaceBuffersInChild(const std::shared_ptr<ArrayData>& data,
+ const int32_t child_index,
+ const std::vector<uint8_t>& child_data) {
+ const auto test_data = data->Copy();
+ // assume updating only buffer[1] in child_data
+ auto child_array_data = test_data->child_data[child_index]->Copy();
+ child_array_data->buffers[1] =
+ std::make_shared<Buffer>(child_data.data(), child_data.size());
+ test_data->child_data[child_index] = child_array_data;
+ return test_data;
+}
+
+std::shared_ptr<ArrayData> ReplaceBuffersInDictionary(
+ const std::shared_ptr<ArrayData>& data, const int32_t buffer_index,
+ const std::vector<uint8_t>& buffer_data) {
+ const auto test_data = data->Copy();
+ auto dict_array_data = test_data->dictionary->Copy();
+ dict_array_data->buffers[buffer_index] =
+ std::make_shared<Buffer>(buffer_data.data(), buffer_data.size());
+ test_data->dictionary = dict_array_data;
+ return test_data;
+}
+
+TEST(TestSwapEndianArrayData, BinaryType) {
+ auto array = ArrayFromJSON(binary(), R"(["0123", null, "45"])");
+ const std::vector<uint8_t> offset1 =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 6};
+#else
+ {0, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0};
+#endif
+ auto expected_data = array->data();
+ auto test_data = ReplaceBuffers(expected_data, 1, offset1);
+ AssertArrayDataEqualsWithSwapEndian(test_data, expected_data);
+
+ array = ArrayFromJSON(large_binary(), R"(["01234", null, "567"])");
+ const std::vector<uint8_t> offset2 =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5,
+ 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 8};
+#else
+ {0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0,
+ 5, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0};
+#endif
+ expected_data = array->data();
+ test_data = ReplaceBuffers(expected_data, 1, offset2);
+ AssertArrayDataEqualsWithSwapEndian(test_data, expected_data);
+
+ array = ArrayFromJSON(fixed_size_binary(3), R"(["012", null, "345"])");
+ expected_data = array->data();
+ AssertArrayDataEqualsWithSwapEndian(expected_data, expected_data);
+}
+
+TEST(TestSwapEndianArrayData, StringType) {
+ auto array = ArrayFromJSON(utf8(), R"(["ABCD", null, "EF"])");
+ const std::vector<uint8_t> offset1 =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 6};
+#else
+ {0, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0};
+#endif
+ auto expected_data = array->data();
+ auto test_data = ReplaceBuffers(expected_data, 1, offset1);
+ AssertArrayDataEqualsWithSwapEndian(test_data, expected_data);
+
+ array = ArrayFromJSON(large_utf8(), R"(["ABCDE", null, "FGH"])");
+ const std::vector<uint8_t> offset2 =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5,
+ 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 8};
+#else
+ {0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0,
+ 5, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0};
+#endif
+ expected_data = array->data();
+ test_data = ReplaceBuffers(expected_data, 1, offset2);
+ AssertArrayDataEqualsWithSwapEndian(test_data, expected_data);
+}
+
+TEST(TestSwapEndianArrayData, ListType) {
+ auto type1 = std::make_shared<ListType>(int32());
+ auto array = ArrayFromJSON(type1, "[[0, 1, 2, 3], null, [4, 5]]");
+ const std::vector<uint8_t> offset1 =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 6};
+#else
+ {0, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0};
+#endif
+ const std::vector<uint8_t> data1 =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 5};
+#else
+ {0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0};
+#endif
+ auto expected_data = array->data();
+ auto test_data = ReplaceBuffers(expected_data, 1, offset1);
+ test_data = ReplaceBuffersInChild(test_data, 0, data1);
+ AssertArrayDataEqualsWithSwapEndian(test_data, expected_data);
+
+ auto type2 = std::make_shared<LargeListType>(int64());
+ array = ArrayFromJSON(type2, "[[0, 1, 2], null, [3]]");
+ const std::vector<uint8_t> offset2 =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3,
+ 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4};
+#else
+ {0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0,
+ 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0};
+#endif
+ const std::vector<uint8_t> data2 =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3};
+#else
+ {0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
+ 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0};
+#endif
+ expected_data = array->data();
+ test_data = ReplaceBuffers(expected_data, 1, offset2);
+ test_data = ReplaceBuffersInChild(test_data, 0, data2);
+ AssertArrayDataEqualsWithSwapEndian(test_data, expected_data);
+
+ auto type3 = std::make_shared<FixedSizeListType>(int32(), 2);
+ array = ArrayFromJSON(type3, "[[0, 1], null, [2, 3]]");
+ expected_data = array->data();
+ const std::vector<uint8_t> data3 =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3};
+#else
+ {0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0};
+#endif
+ test_data = ReplaceBuffersInChild(expected_data, 0, data3);
+ AssertArrayDataEqualsWithSwapEndian(test_data, expected_data);
+}
+
+TEST(TestSwapEndianArrayData, DictionaryType) {
+ auto type = dictionary(int32(), int16());
+ auto dict = ArrayFromJSON(int16(), "[4, 5, 6, 7]");
+ DictionaryArray array(type, ArrayFromJSON(int32(), "[0, 2, 3]"), dict);
+ auto expected_data = array.data();
+ const std::vector<uint8_t> data1 =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3};
+#else
+ {0, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0};
+#endif
+ const std::vector<uint8_t> data2 =
+#if ARROW_LITTLE_ENDIAN
+ {0, 4, 0, 5, 0, 6, 0, 7};
+#else
+ {4, 0, 5, 0, 6, 0, 7, 0};
+#endif
+ auto test_data = ReplaceBuffers(expected_data, 1, data1);
+ test_data = ReplaceBuffersInDictionary(test_data, 1, data2);
+ // dictionary must be explicitly swapped
+ test_data->dictionary = *::arrow::internal::SwapEndianArrayData(test_data->dictionary);
+ AssertArrayDataEqualsWithSwapEndian(test_data, expected_data);
+}
+
+TEST(TestSwapEndianArrayData, StructType) {
+ auto array = ArrayFromJSON(struct_({field("a", int32()), field("b", utf8())}),
+ R"([{"a": 4, "b": null}, {"a": null, "b": "foo"}])");
+ auto expected_data = array->data();
+ const std::vector<uint8_t> data1 =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 4, 0, 0, 0, 0};
+#else
+ {4, 0, 0, 0, 0, 0, 0, 0};
+#endif
+ const std::vector<uint8_t> data2 =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3};
+#else
+ {0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0};
+#endif
+ auto test_data = ReplaceBuffersInChild(expected_data, 0, data1);
+ test_data = ReplaceBuffersInChild(test_data, 1, data2);
+ AssertArrayDataEqualsWithSwapEndian(test_data, expected_data);
+}
+
+TEST(TestSwapEndianArrayData, UnionType) {
+ auto expected_i8 = ArrayFromJSON(int8(), "[127, null, null, null, null]");
+ auto expected_str = ArrayFromJSON(utf8(), R"([null, "abcd", null, null, ""])");
+ auto expected_i32 = ArrayFromJSON(int32(), "[null, null, 1, 2, null]");
+ std::vector<uint8_t> expected_types_vector;
+ expected_types_vector.push_back(Type::INT8);
+ expected_types_vector.insert(expected_types_vector.end(), 2, Type::STRING);
+ expected_types_vector.insert(expected_types_vector.end(), 2, Type::INT32);
+ std::shared_ptr<Array> expected_types;
+ ArrayFromVector<Int8Type, uint8_t>(expected_types_vector, &expected_types);
+ auto arr1 = SparseUnionArray::Make(
+ *expected_types, {expected_i8, expected_str, expected_i32}, {"i8", "str", "i32"},
+ {Type::INT8, Type::STRING, Type::INT32});
+ auto expected_data = (*arr1)->data();
+ const std::vector<uint8_t> data1a =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4};
+#else
+ {0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0};
+#endif
+ const std::vector<uint8_t> data1b =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0};
+#else
+ {0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0};
+#endif
+ auto test_data = ReplaceBuffersInChild(expected_data, 1, data1a);
+ test_data = ReplaceBuffersInChild(test_data, 2, data1b);
+ AssertArrayDataEqualsWithSwapEndian(test_data, expected_data);
+
+ expected_i8 = ArrayFromJSON(int8(), "[33, 10, -10]");
+ expected_str = ArrayFromJSON(utf8(), R"(["abc", "", "def"])");
+ expected_i32 = ArrayFromJSON(int32(), "[1, -259, 2]");
+ auto expected_offsets = ArrayFromJSON(int32(), "[0, 0, 0, 1, 1, 1, 2, 2, 2]");
+ auto arr2 = DenseUnionArray::Make(
+ *expected_types, *expected_offsets, {expected_i8, expected_str, expected_i32},
+ {"i8", "str", "i32"}, {Type::INT8, Type::STRING, Type::INT32});
+ expected_data = (*arr2)->data();
+ const std::vector<uint8_t> data2a =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
+ 0, 1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2};
+#else
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,
+ 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0};
+#endif
+ const std::vector<uint8_t> data2b =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 6};
+#else
+ {0, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0};
+#endif
+ const std::vector<uint8_t> data2c =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 1, 255, 255, 254, 253, 0, 0, 0, 2};
+#else
+ {1, 0, 0, 0, 253, 254, 255, 255, 2, 0, 0, 0};
+#endif
+ test_data = ReplaceBuffers(expected_data, 2, data2a);
+ test_data = ReplaceBuffersInChild(test_data, 1, data2b);
+ test_data = ReplaceBuffersInChild(test_data, 2, data2c);
+ AssertArrayDataEqualsWithSwapEndian(test_data, expected_data);
+}
+
+TEST(TestSwapEndianArrayData, ExtensionType) {
+ auto array_int16 = ArrayFromJSON(int16(), "[0, 1, 2, 3]");
+ auto ext_data = array_int16->data()->Copy();
+ ext_data->type = std::make_shared<SmallintType>();
+ auto array = MakeArray(ext_data);
+ auto expected_data = array->data();
+ const std::vector<uint8_t> data =
+#if ARROW_LITTLE_ENDIAN
+ {0, 0, 0, 1, 0, 2, 0, 3};
+#else
+ {0, 0, 1, 0, 2, 0, 3, 0};
+#endif
+ auto test_data = ReplaceBuffers(expected_data, 1, data);
+ AssertArrayDataEqualsWithSwapEndian(test_data, expected_data);
+}
+
+TEST(TestSwapEndianArrayData, MonthDayNanoInterval) {
+ auto array = ArrayFromJSON(month_day_nano_interval(), R"([[0, 1, 2],
+ [5000, 200, 3000000000]])");
+ auto expected_array =
+ ArrayFromJSON(month_day_nano_interval(), R"([[0, 16777216, 144115188075855872],
+ [-2012020736, -939524096, 26688110733557760]])");
+
+ auto swap_array = MakeArray(*::arrow::internal::SwapEndianArrayData(array->data()));
+ EXPECT_TRUE(!swap_array->Equals(array));
+ ASSERT_ARRAYS_EQUAL(*swap_array, *expected_array);
+ ASSERT_ARRAYS_EQUAL(
+ *MakeArray(*::arrow::internal::SwapEndianArrayData(swap_array->data())), *array);
+ ASSERT_OK(swap_array->ValidateFull());
+}
+
+DataTypeVector SwappableTypes() {
+ return DataTypeVector{int8(),
+ int16(),
+ int32(),
+ int64(),
+ uint8(),
+ uint16(),
+ uint32(),
+ uint64(),
+ decimal128(19, 4),
+ decimal256(37, 8),
+ timestamp(TimeUnit::MICRO, ""),
+ time32(TimeUnit::SECOND),
+ time64(TimeUnit::NANO),
+ date32(),
+ date64(),
+ day_time_interval(),
+ month_interval(),
+ month_day_nano_interval(),
+ binary(),
+ utf8(),
+ large_binary(),
+ large_utf8(),
+ list(int16()),
+ large_list(int16()),
+ dictionary(int16(), utf8())};
+}
+
+TEST(TestSwapEndianArrayData, RandomData) {
+ random::RandomArrayGenerator rng(42);
+
+ for (const auto& type : SwappableTypes()) {
+ ARROW_SCOPED_TRACE("type = ", type->ToString());
+ auto arr = rng.ArrayOf(*field("", type), /*size=*/31);
+ ASSERT_OK_AND_ASSIGN(auto swapped_data,
+ ::arrow::internal::SwapEndianArrayData(arr->data()));
+ auto swapped = MakeArray(swapped_data);
+ ASSERT_OK_AND_ASSIGN(auto roundtripped_data,
+ ::arrow::internal::SwapEndianArrayData(swapped_data));
+ auto roundtripped = MakeArray(roundtripped_data);
+ ASSERT_OK(roundtripped->ValidateFull());
+
+ AssertArraysEqual(*arr, *roundtripped, /*verbose=*/true);
+ if (type->id() == Type::INT8 || type->id() == Type::UINT8) {
+ AssertArraysEqual(*arr, *swapped, /*verbose=*/true);
+ } else {
+ // Random generated data is unlikely to be made of byte-palindromes
+ ASSERT_FALSE(arr->Equals(*swapped));
+ }
+ }
+}
+
+TEST(TestSwapEndianArrayData, InvalidLength) {
+ // IPC-incoming data may be invalid, SwapEndianArrayData shouldn't crash
+ // by accessing memory out of bounds.
+ random::RandomArrayGenerator rng(42);
+
+ for (const auto& type : SwappableTypes()) {
+ ARROW_SCOPED_TRACE("type = ", type->ToString());
+ ASSERT_OK_AND_ASSIGN(auto arr, MakeArrayOfNull(type, 0));
+ auto data = arr->data();
+ // Fake length
+ data->length = 123456789;
+ ASSERT_OK_AND_ASSIGN(auto swapped_data, ::arrow::internal::SwapEndianArrayData(data));
+ auto swapped = MakeArray(swapped_data);
+ ASSERT_RAISES(Invalid, swapped->Validate());
+ }
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_union_test.cc b/src/arrow/cpp/src/arrow/array/array_union_test.cc
new file mode 100644
index 000000000..d3afe40df
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_union_test.cc
@@ -0,0 +1,582 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <string>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_nested.h"
+#include "arrow/array/builder_union.h"
+// TODO ipc shouldn't be included here
+#include "arrow/ipc/test_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+TEST(TestUnionArray, TestSliceEquals) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(ipc::test::MakeUnion(&batch));
+
+ auto CheckUnion = [](std::shared_ptr<Array> array) {
+ const int64_t size = array->length();
+ std::shared_ptr<Array> slice, slice2;
+ slice = array->Slice(2);
+ ASSERT_EQ(size - 2, slice->length());
+
+ slice2 = array->Slice(2);
+ ASSERT_EQ(size - 2, slice->length());
+
+ ASSERT_TRUE(slice->Equals(slice2));
+ ASSERT_TRUE(array->RangeEquals(2, array->length(), 0, slice));
+
+ // Chained slices
+ slice2 = array->Slice(1)->Slice(1);
+ ASSERT_TRUE(slice->Equals(slice2));
+
+ slice = array->Slice(1, 5);
+ slice2 = array->Slice(1, 5);
+ ASSERT_EQ(5, slice->length());
+
+ ASSERT_TRUE(slice->Equals(slice2));
+ ASSERT_TRUE(array->RangeEquals(1, 6, 0, slice));
+
+ AssertZeroPadded(*array);
+ TestInitialized(*array);
+ };
+
+ CheckUnion(batch->column(0));
+ CheckUnion(batch->column(1));
+}
+
+TEST(TestSparseUnionArray, Validate) {
+ auto a = ArrayFromJSON(int32(), "[4, 5]");
+ auto type = sparse_union({field("a", int32())});
+ auto children = std::vector<std::shared_ptr<Array>>{a};
+ auto type_ids_array = ArrayFromJSON(int8(), "[0, 0, 0]");
+ auto type_ids = type_ids_array->data()->buffers[1];
+
+ auto arr = std::make_shared<SparseUnionArray>(type, 2, children, type_ids);
+ ASSERT_OK(arr->ValidateFull());
+ arr = std::make_shared<SparseUnionArray>(type, 1, children, type_ids,
+ /*offset=*/1);
+ ASSERT_OK(arr->ValidateFull());
+ arr = std::make_shared<SparseUnionArray>(type, 0, children, type_ids,
+ /*offset=*/2);
+ ASSERT_OK(arr->ValidateFull());
+
+ // Length + offset < child length, but it's ok
+ arr = std::make_shared<SparseUnionArray>(type, 1, children, type_ids,
+ /*offset=*/0);
+ ASSERT_OK(arr->ValidateFull());
+
+ // Length + offset > child length
+ arr = std::make_shared<SparseUnionArray>(type, 1, children, type_ids,
+ /*offset=*/2);
+ ASSERT_RAISES(Invalid, arr->ValidateFull());
+
+ // Offset > child length
+ arr = std::make_shared<SparseUnionArray>(type, 0, children, type_ids,
+ /*offset=*/3);
+ ASSERT_RAISES(Invalid, arr->ValidateFull());
+}
+
+// -------------------------------------------------------------------------
+// Tests for MakeDense and MakeSparse
+
+class TestUnionArrayFactories : public ::testing::Test {
+ public:
+ void SetUp() {
+ pool_ = default_memory_pool();
+ type_codes_ = {1, 2, 4, 127};
+ ArrayFromVector<Int8Type>({0, 1, 2, 0, 1, 3, 2, 0, 2, 1}, &type_ids_);
+ ArrayFromVector<Int8Type>({1, 2, 4, 1, 2, 127, 4, 1, 4, 2}, &logical_type_ids_);
+ ArrayFromVector<Int8Type>({1, 2, 4, 1, -2, 127, 4, 1, 4, 2}, &invalid_type_ids1_);
+ ArrayFromVector<Int8Type>({1, 2, 4, 1, 3, 127, 4, 1, 4, 2}, &invalid_type_ids2_);
+ }
+
+ void CheckUnionArray(const UnionArray& array, UnionMode::type mode,
+ const std::vector<std::string>& field_names,
+ const std::vector<int8_t>& type_codes) {
+ ASSERT_EQ(mode, array.mode());
+ CheckFieldNames(array, field_names);
+ CheckTypeCodes(array, type_codes);
+ const auto& type_ids = checked_cast<const Int8Array&>(*type_ids_);
+ for (int64_t i = 0; i < type_ids.length(); ++i) {
+ ASSERT_EQ(array.child_id(i), type_ids.Value(i));
+ }
+ ASSERT_EQ(nullptr, array.field(-1));
+ ASSERT_EQ(nullptr, array.field(static_cast<int>(type_ids.length())));
+ }
+
+ void CheckFieldNames(const UnionArray& array, const std::vector<std::string>& names) {
+ const auto& type = checked_cast<const UnionType&>(*array.type());
+ ASSERT_EQ(type.num_fields(), names.size());
+ for (int i = 0; i < type.num_fields(); ++i) {
+ ASSERT_EQ(type.field(i)->name(), names[i]);
+ }
+ }
+
+ void CheckTypeCodes(const UnionArray& array, const std::vector<int8_t>& codes) {
+ const auto& type = checked_cast<const UnionType&>(*array.type());
+ ASSERT_EQ(codes, type.type_codes());
+ }
+
+ protected:
+ MemoryPool* pool_;
+ std::vector<int8_t> type_codes_;
+ std::shared_ptr<Array> type_ids_;
+ std::shared_ptr<Array> logical_type_ids_;
+ std::shared_ptr<Array> invalid_type_ids1_;
+ std::shared_ptr<Array> invalid_type_ids2_;
+};
+
+TEST_F(TestUnionArrayFactories, TestMakeDense) {
+ std::shared_ptr<Array> value_offsets;
+ // type_ids_: {0, 1, 2, 0, 1, 3, 2, 0, 2, 1}
+ ArrayFromVector<Int32Type, int32_t>({0, 0, 0, 1, 1, 0, 1, 2, 1, 2}, &value_offsets);
+
+ auto children = std::vector<std::shared_ptr<Array>>(4);
+ ArrayFromVector<StringType, std::string>({"abc", "def", "xyz"}, &children[0]);
+ ArrayFromVector<UInt8Type>({10, 20, 30}, &children[1]);
+ ArrayFromVector<DoubleType>({1.618, 2.718, 3.142}, &children[2]);
+ ArrayFromVector<Int8Type>({-12}, &children[3]);
+
+ std::vector<std::string> field_names = {"str", "int1", "real", "int2"};
+
+ std::shared_ptr<Array> result;
+ const UnionArray* union_array;
+
+ // without field names and type codes
+ ASSERT_OK_AND_ASSIGN(result,
+ DenseUnionArray::Make(*type_ids_, *value_offsets, children));
+ ASSERT_OK(result->ValidateFull());
+ union_array = checked_cast<const UnionArray*>(result.get());
+ CheckUnionArray(*union_array, UnionMode::DENSE, {"0", "1", "2", "3"}, {0, 1, 2, 3});
+
+ // with field name
+ ASSERT_RAISES(Invalid,
+ DenseUnionArray::Make(*type_ids_, *value_offsets, children, {"one"}));
+ ASSERT_OK_AND_ASSIGN(
+ result, DenseUnionArray::Make(*type_ids_, *value_offsets, children, field_names));
+ ASSERT_OK(result->ValidateFull());
+ union_array = checked_cast<const UnionArray*>(result.get());
+ CheckUnionArray(*union_array, UnionMode::DENSE, field_names, {0, 1, 2, 3});
+
+ // with type codes
+ ASSERT_RAISES(Invalid, DenseUnionArray::Make(*logical_type_ids_, *value_offsets,
+ children, std::vector<int8_t>{0}));
+ ASSERT_OK_AND_ASSIGN(result, DenseUnionArray::Make(*logical_type_ids_, *value_offsets,
+ children, type_codes_));
+ ASSERT_OK(result->ValidateFull());
+ union_array = checked_cast<const UnionArray*>(result.get());
+ CheckUnionArray(*union_array, UnionMode::DENSE, {"0", "1", "2", "3"}, type_codes_);
+
+ // with field names and type codes
+ ASSERT_RAISES(Invalid, DenseUnionArray::Make(*logical_type_ids_, *value_offsets,
+ children, {"one"}, type_codes_));
+ ASSERT_OK_AND_ASSIGN(result, DenseUnionArray::Make(*logical_type_ids_, *value_offsets,
+ children, field_names, type_codes_));
+ ASSERT_OK(result->ValidateFull());
+ union_array = checked_cast<const UnionArray*>(result.get());
+ CheckUnionArray(*union_array, UnionMode::DENSE, field_names, type_codes_);
+
+ // Invalid type codes
+ ASSERT_OK_AND_ASSIGN(result, DenseUnionArray::Make(*invalid_type_ids1_, *value_offsets,
+ children, type_codes_));
+ ASSERT_RAISES(Invalid, result->ValidateFull());
+ ASSERT_OK_AND_ASSIGN(result, DenseUnionArray::Make(*invalid_type_ids2_, *value_offsets,
+ children, type_codes_));
+ ASSERT_RAISES(Invalid, result->ValidateFull());
+
+ // Invalid offsets
+ // - offset out of bounds at index 5
+ std::shared_ptr<Array> invalid_offsets;
+ ArrayFromVector<Int32Type, int32_t>({0, 0, 0, 1, 1, 1, 1, 2, 1, 2}, &invalid_offsets);
+ ASSERT_OK_AND_ASSIGN(result,
+ DenseUnionArray::Make(*type_ids_, *invalid_offsets, children));
+ ASSERT_RAISES(Invalid, result->ValidateFull());
+ // - negative offset at index 5
+ ArrayFromVector<Int32Type, int32_t>({0, 0, 0, 1, 1, -1, 1, 2, 1, 2}, &invalid_offsets);
+ ASSERT_OK_AND_ASSIGN(result,
+ DenseUnionArray::Make(*type_ids_, *invalid_offsets, children));
+ ASSERT_RAISES(Invalid, result->ValidateFull());
+ // - non-monotonic offset at index 3
+ ArrayFromVector<Int32Type, int32_t>({1, 0, 0, 0, 1, 0, 1, 2, 1, 2}, &invalid_offsets);
+ ASSERT_OK_AND_ASSIGN(result,
+ DenseUnionArray::Make(*type_ids_, *invalid_offsets, children));
+ ASSERT_RAISES(Invalid, result->ValidateFull());
+}
+
+TEST_F(TestUnionArrayFactories, TestMakeSparse) {
+ auto children = std::vector<std::shared_ptr<Array>>(4);
+ ArrayFromVector<StringType, std::string>(
+ {"abc", "", "", "def", "", "", "", "xyz", "", ""}, &children[0]);
+ ArrayFromVector<UInt8Type>({0, 10, 0, 0, 20, 0, 0, 0, 0, 30}, &children[1]);
+ ArrayFromVector<DoubleType>({0.0, 0.0, 1.618, 0.0, 0.0, 0.0, 2.718, 0.0, 3.142, 0.0},
+ &children[2]);
+ ArrayFromVector<Int8Type>({0, 0, 0, 0, 0, -12, 0, 0, 0, 0}, &children[3]);
+
+ std::vector<std::string> field_names = {"str", "int1", "real", "int2"};
+
+ std::shared_ptr<Array> result;
+
+ // without field names and type codes
+ ASSERT_OK_AND_ASSIGN(result, SparseUnionArray::Make(*type_ids_, children));
+ ASSERT_OK(result->ValidateFull());
+ CheckUnionArray(checked_cast<UnionArray&>(*result), UnionMode::SPARSE,
+ {"0", "1", "2", "3"}, {0, 1, 2, 3});
+
+ // with field names
+ ASSERT_RAISES(Invalid, SparseUnionArray::Make(*type_ids_, children, {"one"}));
+ ASSERT_OK_AND_ASSIGN(result, SparseUnionArray::Make(*type_ids_, children, field_names));
+ ASSERT_OK(result->ValidateFull());
+ CheckUnionArray(checked_cast<UnionArray&>(*result), UnionMode::SPARSE, field_names,
+ {0, 1, 2, 3});
+
+ // with type codes
+ ASSERT_RAISES(Invalid, SparseUnionArray::Make(*logical_type_ids_, children,
+ std::vector<int8_t>{0}));
+ ASSERT_OK_AND_ASSIGN(result,
+ SparseUnionArray::Make(*logical_type_ids_, children, type_codes_));
+ ASSERT_OK(result->ValidateFull());
+ CheckUnionArray(checked_cast<UnionArray&>(*result), UnionMode::SPARSE,
+ {"0", "1", "2", "3"}, type_codes_);
+
+ // with field names and type codes
+ ASSERT_RAISES(Invalid, SparseUnionArray::Make(*logical_type_ids_, children, {"one"},
+ type_codes_));
+ ASSERT_OK_AND_ASSIGN(result, SparseUnionArray::Make(*logical_type_ids_, children,
+ field_names, type_codes_));
+ ASSERT_OK(result->ValidateFull());
+ CheckUnionArray(checked_cast<UnionArray&>(*result), UnionMode::SPARSE, field_names,
+ type_codes_);
+
+ // Invalid type codes
+ ASSERT_OK_AND_ASSIGN(
+ result, SparseUnionArray::Make(*invalid_type_ids1_, children, type_codes_));
+ ASSERT_RAISES(Invalid, result->ValidateFull());
+ ASSERT_OK_AND_ASSIGN(
+ result, SparseUnionArray::Make(*invalid_type_ids2_, children, type_codes_));
+ ASSERT_RAISES(Invalid, result->ValidateFull());
+
+ // Invalid child length
+ ArrayFromVector<Int8Type>({0, 0, 0, 0, 0, -12, 0, 0, 0}, &children[3]);
+ ASSERT_RAISES(Invalid, SparseUnionArray::Make(*type_ids_, children));
+}
+
+template <typename B>
+class UnionBuilderTest : public ::testing::Test {
+ public:
+ int8_t I8 = 8, STR = 13, DBL = 7;
+
+ virtual void AppendInt(int8_t i) {
+ expected_types_vector.push_back(I8);
+ ASSERT_OK(union_builder->Append(I8));
+ ASSERT_OK(i8_builder->Append(i));
+ }
+
+ virtual void AppendString(const std::string& str) {
+ expected_types_vector.push_back(STR);
+ ASSERT_OK(union_builder->Append(STR));
+ ASSERT_OK(str_builder->Append(str));
+ }
+
+ virtual void AppendDouble(double dbl) {
+ expected_types_vector.push_back(DBL);
+ ASSERT_OK(union_builder->Append(DBL));
+ ASSERT_OK(dbl_builder->Append(dbl));
+ }
+
+ void AppendBasics() {
+ AppendInt(33);
+ AppendString("abc");
+ AppendDouble(1.0);
+ AppendDouble(-1.0);
+ AppendString("");
+ AppendInt(10);
+ AppendString("def");
+ AppendInt(-10);
+ AppendDouble(0.5);
+
+ ASSERT_OK(union_builder->Finish(&actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int8Type, uint8_t>(expected_types_vector, &expected_types);
+ }
+
+ void AppendNullsAndEmptyValues() {
+ AppendString("abc");
+ ASSERT_OK(union_builder->AppendNull());
+ ASSERT_OK(union_builder->AppendEmptyValue());
+ expected_types_vector.insert(expected_types_vector.end(), 3, I8);
+ AppendInt(42);
+ ASSERT_OK(union_builder->AppendNulls(2));
+ ASSERT_OK(union_builder->AppendEmptyValues(2));
+ expected_types_vector.insert(expected_types_vector.end(), 3, I8);
+
+ ASSERT_OK(union_builder->Finish(&actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int8Type, uint8_t>(expected_types_vector, &expected_types);
+ }
+
+ void AppendInferred() {
+ I8 = union_builder->AppendChild(i8_builder, "i8");
+ ASSERT_EQ(I8, 0);
+ AppendInt(33);
+ AppendInt(10);
+
+ STR = union_builder->AppendChild(str_builder, "str");
+ ASSERT_EQ(STR, 1);
+ AppendString("abc");
+ AppendString("");
+ AppendString("def");
+ AppendInt(-10);
+
+ DBL = union_builder->AppendChild(dbl_builder, "dbl");
+ ASSERT_EQ(DBL, 2);
+ AppendDouble(1.0);
+ AppendDouble(-1.0);
+ AppendDouble(0.5);
+
+ ASSERT_OK(union_builder->Finish(&actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int8Type, uint8_t>(expected_types_vector, &expected_types);
+
+ ASSERT_EQ(I8, 0);
+ ASSERT_EQ(STR, 1);
+ ASSERT_EQ(DBL, 2);
+ }
+
+ void AppendListOfInferred(std::shared_ptr<ListArray>* actual) {
+ ListBuilder list_builder(default_memory_pool(), union_builder);
+
+ ASSERT_OK(list_builder.Append());
+ I8 = union_builder->AppendChild(i8_builder, "i8");
+ ASSERT_EQ(I8, 0);
+ AppendInt(10);
+
+ ASSERT_OK(list_builder.Append());
+ STR = union_builder->AppendChild(str_builder, "str");
+ ASSERT_EQ(STR, 1);
+ AppendString("abc");
+ AppendInt(-10);
+
+ ASSERT_OK(list_builder.Append());
+ DBL = union_builder->AppendChild(dbl_builder, "dbl");
+ ASSERT_EQ(DBL, 2);
+ AppendDouble(0.5);
+
+ ASSERT_OK(list_builder.Finish(actual));
+ ASSERT_OK((*actual)->ValidateFull());
+ ArrayFromVector<Int8Type, uint8_t>(expected_types_vector, &expected_types);
+ }
+
+ std::vector<uint8_t> expected_types_vector;
+ std::shared_ptr<Array> expected_types;
+ std::shared_ptr<Int8Builder> i8_builder = std::make_shared<Int8Builder>();
+ std::shared_ptr<StringBuilder> str_builder = std::make_shared<StringBuilder>();
+ std::shared_ptr<DoubleBuilder> dbl_builder = std::make_shared<DoubleBuilder>();
+ std::shared_ptr<B> union_builder = std::make_shared<B>(default_memory_pool());
+ std::shared_ptr<UnionArray> actual;
+};
+
+class DenseUnionBuilderTest : public UnionBuilderTest<DenseUnionBuilder> {};
+class SparseUnionBuilderTest : public UnionBuilderTest<SparseUnionBuilder> {
+ public:
+ using Base = UnionBuilderTest<SparseUnionBuilder>;
+
+ void AppendInt(int8_t i) override {
+ Base::AppendInt(i);
+ ASSERT_OK(str_builder->AppendEmptyValue());
+ ASSERT_OK(dbl_builder->AppendEmptyValue());
+ }
+
+ void AppendString(const std::string& str) override {
+ Base::AppendString(str);
+ ASSERT_OK(i8_builder->AppendEmptyValue());
+ ASSERT_OK(dbl_builder->AppendEmptyValue());
+ }
+
+ void AppendDouble(double dbl) override {
+ Base::AppendDouble(dbl);
+ ASSERT_OK(i8_builder->AppendEmptyValue());
+ ASSERT_OK(str_builder->AppendEmptyValue());
+ }
+};
+
+TEST_F(DenseUnionBuilderTest, Basics) {
+ union_builder.reset(new DenseUnionBuilder(
+ default_memory_pool(), {i8_builder, str_builder, dbl_builder},
+ dense_union({field("i8", int8()), field("str", utf8()), field("dbl", float64())},
+ {I8, STR, DBL})));
+ AppendBasics();
+
+ auto expected_i8 = ArrayFromJSON(int8(), "[33, 10, -10]");
+ auto expected_str = ArrayFromJSON(utf8(), R"(["abc", "", "def"])");
+ auto expected_dbl = ArrayFromJSON(float64(), "[1.0, -1.0, 0.5]");
+
+ auto expected_offsets = ArrayFromJSON(int32(), "[0, 0, 0, 1, 1, 1, 2, 2, 2]");
+
+ ASSERT_OK_AND_ASSIGN(auto expected,
+ DenseUnionArray::Make(*expected_types, *expected_offsets,
+ {expected_i8, expected_str, expected_dbl},
+ {"i8", "str", "dbl"}, {I8, STR, DBL}));
+
+ ASSERT_EQ(expected->type()->ToString(), actual->type()->ToString());
+ ASSERT_ARRAYS_EQUAL(*expected, *actual);
+}
+
+TEST_F(DenseUnionBuilderTest, NullsAndEmptyValues) {
+ union_builder.reset(new DenseUnionBuilder(
+ default_memory_pool(), {i8_builder, str_builder, dbl_builder},
+ dense_union({field("i8", int8()), field("str", utf8()), field("dbl", float64())},
+ {I8, STR, DBL})));
+ AppendNullsAndEmptyValues();
+
+ // Four null / empty values (the latter implementation-defined) were appended to I8
+ auto expected_i8 = ArrayFromJSON(int8(), "[null, 0, 42, null, 0]");
+ auto expected_str = ArrayFromJSON(utf8(), R"(["abc"])");
+ auto expected_dbl = ArrayFromJSON(float64(), "[]");
+
+ // "abc", null, 0, 42, null, null, 0, 0
+ auto expected_offsets = ArrayFromJSON(int32(), "[0, 0, 1, 2, 3, 3, 4, 4]");
+
+ ASSERT_OK_AND_ASSIGN(auto expected,
+ DenseUnionArray::Make(*expected_types, *expected_offsets,
+ {expected_i8, expected_str, expected_dbl},
+ {"i8", "str", "dbl"}, {I8, STR, DBL}));
+
+ ASSERT_EQ(expected->type()->ToString(), actual->type()->ToString());
+ ASSERT_ARRAYS_EQUAL(*expected, *actual);
+ // Physical arrays must be as expected
+ ASSERT_ARRAYS_EQUAL(*expected_i8, *actual->field(0));
+ ASSERT_ARRAYS_EQUAL(*expected_str, *actual->field(1));
+ ASSERT_ARRAYS_EQUAL(*expected_dbl, *actual->field(2));
+}
+
+TEST_F(DenseUnionBuilderTest, InferredType) {
+ AppendInferred();
+
+ auto expected_i8 = ArrayFromJSON(int8(), "[33, 10, -10]");
+ auto expected_str = ArrayFromJSON(utf8(), R"(["abc", "", "def"])");
+ auto expected_dbl = ArrayFromJSON(float64(), "[1.0, -1.0, 0.5]");
+
+ auto expected_offsets = ArrayFromJSON(int32(), "[0, 1, 0, 1, 2, 2, 0, 1, 2]");
+
+ ASSERT_OK_AND_ASSIGN(auto expected,
+ DenseUnionArray::Make(*expected_types, *expected_offsets,
+ {expected_i8, expected_str, expected_dbl},
+ {"i8", "str", "dbl"}, {I8, STR, DBL}));
+
+ ASSERT_EQ(expected->type()->ToString(), actual->type()->ToString());
+ ASSERT_ARRAYS_EQUAL(*expected, *actual);
+}
+
+TEST_F(DenseUnionBuilderTest, ListOfInferredType) {
+ std::shared_ptr<ListArray> actual;
+ AppendListOfInferred(&actual);
+
+ auto expected_type = list(
+ dense_union({field("i8", int8()), field("str", utf8()), field("dbl", float64())},
+ {I8, STR, DBL}));
+ ASSERT_EQ(expected_type->ToString(), actual->type()->ToString());
+}
+
+TEST_F(SparseUnionBuilderTest, Basics) {
+ union_builder.reset(new SparseUnionBuilder(
+ default_memory_pool(), {i8_builder, str_builder, dbl_builder},
+ sparse_union({field("i8", int8()), field("str", utf8()), field("dbl", float64())},
+ {I8, STR, DBL})));
+
+ AppendBasics();
+
+ auto expected_i8 =
+ ArrayFromJSON(int8(), "[33, null, null, null, null, 10, null, -10, null]");
+ auto expected_str =
+ ArrayFromJSON(utf8(), R"([null, "abc", null, null, "", null, "def", null, null])");
+ auto expected_dbl =
+ ArrayFromJSON(float64(), "[null, null, 1.0, -1.0, null, null, null, null, 0.5]");
+
+ ASSERT_OK_AND_ASSIGN(
+ auto expected,
+ SparseUnionArray::Make(*expected_types, {expected_i8, expected_str, expected_dbl},
+ {"i8", "str", "dbl"}, {I8, STR, DBL}));
+
+ ASSERT_EQ(expected->type()->ToString(), actual->type()->ToString());
+ ASSERT_ARRAYS_EQUAL(*expected, *actual);
+}
+
+TEST_F(SparseUnionBuilderTest, NullsAndEmptyValues) {
+ union_builder.reset(new SparseUnionBuilder(
+ default_memory_pool(), {i8_builder, str_builder, dbl_builder},
+ sparse_union({field("i8", int8()), field("str", utf8()), field("dbl", float64())},
+ {I8, STR, DBL})));
+ AppendNullsAndEmptyValues();
+
+ // "abc", null, 0, 42, null, null, 0, 0
+ // (note that getting 0 for empty values is implementation-defined)
+ auto expected_i8 = ArrayFromJSON(int8(), "[0, null, 0, 42, null, null, 0, 0]");
+ auto expected_str = ArrayFromJSON(utf8(), R"(["abc", "", "", "", "", "", "", ""])");
+ auto expected_dbl = ArrayFromJSON(float64(), "[0, 0, 0, 0, 0, 0, 0, 0]");
+
+ ASSERT_OK_AND_ASSIGN(
+ auto expected,
+ SparseUnionArray::Make(*expected_types, {expected_i8, expected_str, expected_dbl},
+ {"i8", "str", "dbl"}, {I8, STR, DBL}));
+
+ ASSERT_EQ(expected->type()->ToString(), actual->type()->ToString());
+ ASSERT_ARRAYS_EQUAL(*expected, *actual);
+ // Physical arrays must be as expected
+ ASSERT_ARRAYS_EQUAL(*expected_i8, *actual->field(0));
+ ASSERT_ARRAYS_EQUAL(*expected_str, *actual->field(1));
+ ASSERT_ARRAYS_EQUAL(*expected_dbl, *actual->field(2));
+}
+
+TEST_F(SparseUnionBuilderTest, InferredType) {
+ AppendInferred();
+
+ auto expected_i8 =
+ ArrayFromJSON(int8(), "[33, 10, null, null, null, -10, null, null, null]");
+ auto expected_str =
+ ArrayFromJSON(utf8(), R"([null, null, "abc", "", "def", null, null, null, null])");
+ auto expected_dbl =
+ ArrayFromJSON(float64(), "[null, null, null, null, null, null, 1.0, -1.0, 0.5]");
+
+ ASSERT_OK_AND_ASSIGN(
+ auto expected,
+ SparseUnionArray::Make(*expected_types, {expected_i8, expected_str, expected_dbl},
+ {"i8", "str", "dbl"}, {I8, STR, DBL}));
+
+ ASSERT_EQ(expected->type()->ToString(), actual->type()->ToString());
+ ASSERT_ARRAYS_EQUAL(*expected, *actual);
+}
+
+TEST_F(SparseUnionBuilderTest, StructWithUnion) {
+ auto union_builder = std::make_shared<SparseUnionBuilder>(default_memory_pool());
+ StructBuilder builder(struct_({field("u", union_builder->type())}),
+ default_memory_pool(), {union_builder});
+ ASSERT_EQ(union_builder->AppendChild(std::make_shared<Int32Builder>(), "i"), 0);
+ ASSERT_TRUE(builder.type()->Equals(
+ struct_({field("u", sparse_union({field("i", int32())}, {0}))})));
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/array_view_test.cc b/src/arrow/cpp/src/arrow/array/array_view_test.cc
new file mode 100644
index 000000000..07dc3014e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/array_view_test.cc
@@ -0,0 +1,441 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <string>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_dict.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/array/data.h"
+#include "arrow/extension_type.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+void CheckView(const std::shared_ptr<Array>& input,
+ const std::shared_ptr<DataType>& view_type,
+ const std::shared_ptr<Array>& expected) {
+ ASSERT_OK_AND_ASSIGN(auto result, input->View(view_type));
+ ASSERT_OK(result->ValidateFull());
+ AssertArraysEqual(*expected, *result);
+}
+
+void CheckView(const std::shared_ptr<Array>& input,
+ const std::shared_ptr<Array>& expected_view) {
+ CheckView(input, expected_view->type(), expected_view);
+}
+
+void CheckViewFails(const std::shared_ptr<Array>& input,
+ const std::shared_ptr<DataType>& view_type) {
+ ASSERT_RAISES(Invalid, input->View(view_type));
+}
+
+class IPv4Type : public ExtensionType {
+ public:
+ IPv4Type() : ExtensionType(fixed_size_binary(4)) {}
+
+ std::string extension_name() const override { return "ipv4"; }
+
+ bool ExtensionEquals(const ExtensionType& other) const override {
+ return other.extension_name() == this->extension_name();
+ }
+
+ std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
+ DCHECK_EQ(data->type->id(), Type::EXTENSION);
+ DCHECK_EQ("ipv4", static_cast<const ExtensionType&>(*data->type).extension_name());
+ return std::make_shared<ExtensionArray>(data);
+ }
+
+ Result<std::shared_ptr<DataType>> Deserialize(
+ std::shared_ptr<DataType> storage_type,
+ const std::string& serialized) const override {
+ return Status::NotImplemented("IPv4Type::Deserialize");
+ }
+
+ std::string Serialize() const override { return ""; }
+};
+
+TEST(TestArrayView, IdentityPrimitive) {
+ auto arr = ArrayFromJSON(int16(), "[0, -1, 42]");
+ CheckView(arr, arr->type(), arr);
+ arr = ArrayFromJSON(int16(), "[0, -1, 42, null]");
+ CheckView(arr, arr->type(), arr);
+ arr = ArrayFromJSON(boolean(), "[true, false, null]");
+ CheckView(arr, arr->type(), arr);
+}
+
+TEST(TestArrayView, IdentityNullType) {
+ auto arr = ArrayFromJSON(null(), "[null, null, null]");
+ CheckView(arr, arr->type(), arr);
+}
+
+TEST(TestArrayView, PrimitiveAsPrimitive) {
+ auto arr = ArrayFromJSON(int16(), "[0, -1, 42]");
+ auto expected = ArrayFromJSON(uint16(), "[0, 65535, 42]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ arr = ArrayFromJSON(int32(), "[0, 1069547520, -1071644672, null]");
+ expected = ArrayFromJSON(float32(), "[0.0, 1.5, -2.5, null]");
+ CheckView(arr, expected);
+
+ arr = ArrayFromJSON(timestamp(TimeUnit::SECOND),
+ R"(["1970-01-01","2000-02-29","3989-07-14","1900-02-28"])");
+ expected = ArrayFromJSON(int64(), "[0, 951782400, 63730281600, -2203977600]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+}
+
+TEST(TestArrayView, PrimitiveAsFixedSizeBinary) {
+#if ARROW_LITTLE_ENDIAN
+ auto arr = ArrayFromJSON(int32(), "[2020568934, 2054316386, null]");
+#else
+ auto arr = ArrayFromJSON(int32(), "[1718579064, 1650553466, null]");
+#endif
+ auto expected = ArrayFromJSON(fixed_size_binary(4), R"(["foox", "barz", null])");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+}
+
+TEST(TestArrayView, StringAsBinary) {
+ auto arr = ArrayFromJSON(utf8(), R"(["foox", "barz", null])");
+ auto expected = ArrayFromJSON(binary(), R"(["foox", "barz", null])");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+}
+
+TEST(TestArrayView, PrimitiveWrongSize) {
+ auto arr = ArrayFromJSON(int16(), "[0, -1, 42]");
+ CheckViewFails(arr, int8());
+ CheckViewFails(arr, fixed_size_binary(3));
+ CheckViewFails(arr, null());
+}
+
+TEST(TestArrayView, StructAsStructSimple) {
+ auto ty1 = struct_({field("a", int8()), field("b", int32())});
+ auto ty2 = struct_({field("c", uint8()), field("d", float32())});
+
+ auto arr = ArrayFromJSON(ty1, "[[0, 0], [1, 1069547520], [-1, -1071644672]]");
+ auto expected = ArrayFromJSON(ty2, "[[0, 0], [1, 1.5], [255, -2.5]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ // With nulls
+ arr = ArrayFromJSON(ty1, "[[0, 0], null, [-1, -1071644672]]");
+ expected = ArrayFromJSON(ty2, "[[0, 0], null, [255, -2.5]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ // With nested nulls
+ arr = ArrayFromJSON(ty1, "[[0, null], null, [-1, -1071644672]]");
+ expected = ArrayFromJSON(ty2, "[[0, null], null, [255, -2.5]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ ty2 = struct_({field("c", uint8()), field("d", fixed_size_binary(4))});
+#if ARROW_LITTLE_ENDIAN
+ arr = ArrayFromJSON(ty1, "[[0, null], null, [-1, 2020568934]]");
+#else
+ arr = ArrayFromJSON(ty1, "[[0, null], null, [-1, 1718579064]]");
+#endif
+ expected = ArrayFromJSON(ty2, R"([[0, null], null, [255, "foox"]])");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+}
+
+TEST(TestArrayView, StructAsStructNonNullable) {
+ auto ty1 = struct_({field("a", int8()), field("b", int32())});
+ auto ty2 = struct_({field("c", uint8(), /*nullable=*/false), field("d", float32())});
+
+ auto arr = ArrayFromJSON(ty1, "[[0, 0], [1, 1069547520], [-1, -1071644672]]");
+ auto expected = ArrayFromJSON(ty2, "[[0, 0], [1, 1.5], [255, -2.5]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ // With nested nulls
+ arr = ArrayFromJSON(ty1, "[[0, null], [-1, -1071644672]]");
+ expected = ArrayFromJSON(ty2, "[[0, null], [255, -2.5]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ // Nested null cannot be viewed as non-null field
+ arr = ArrayFromJSON(ty1, "[[0, null], [null, -1071644672]]");
+ CheckViewFails(arr, ty2);
+}
+
+TEST(TestArrayView, StructAsStructWrongLayout) {
+ auto ty1 = struct_({field("a", int8()), field("b", int32())});
+ auto arr = ArrayFromJSON(ty1, "[[0, 0], [1, 1069547520], [-1, -1071644672]]");
+
+ auto ty2 = struct_({field("c", int16()), field("d", int32())});
+ CheckViewFails(arr, ty2);
+ ty2 = struct_({field("c", int32()), field("d", int8())});
+ CheckViewFails(arr, ty2);
+ ty2 = struct_({field("c", int8())});
+ CheckViewFails(arr, ty2);
+ ty2 = struct_({field("c", fixed_size_binary(5))});
+ CheckViewFails(arr, ty2);
+}
+
+TEST(TestArrayView, StructAsStructWithNullType) {
+ auto ty1 = struct_({field("a", int8()), field("b", null())});
+ auto ty2 = struct_({field("c", uint8()), field("d", null())});
+
+ auto arr = ArrayFromJSON(ty1, "[[0, null], [1, null], [-1, null]]");
+ auto expected = ArrayFromJSON(ty2, "[[0, null], [1, null], [255, null]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ // With nulls and nested nulls
+ arr = ArrayFromJSON(ty1, "[null, [null, null], [-1, null]]");
+ expected = ArrayFromJSON(ty2, "[null, [null, null], [255, null]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ // Moving the null types around
+ ty2 = struct_({field("c", null()), field("d", uint8())});
+ expected = ArrayFromJSON(ty2, "[null, [null, null], [null, 255]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ // Removing the null type
+ ty2 = struct_({field("c", uint8())});
+ expected = ArrayFromJSON(ty2, "[null, [null], [255]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+}
+
+TEST(TestArrayView, StructAsFlat) {
+ auto ty1 = struct_({field("a", int16())});
+ auto arr = ArrayFromJSON(ty1, "[[0], [1], [-1]]");
+ auto expected = ArrayFromJSON(uint16(), "[0, 1, 65535]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ // With nulls
+ arr = ArrayFromJSON(ty1, "[[0], null, [-1]]");
+ expected = ArrayFromJSON(uint16(), "[0, null, 65535]");
+ // CheckView(arr, expected); // XXX currently fails
+ CheckView(expected, arr);
+
+ // With nested nulls => fails
+ arr = ArrayFromJSON(ty1, "[[0], [null], [-1]]");
+ CheckViewFails(arr, uint16());
+}
+
+TEST(TestArrayView, StructAsFlatWithNullType) {
+ auto ty1 = struct_({field("a", null()), field("b", int16()), field("c", null())});
+ auto arr = ArrayFromJSON(ty1, "[[null, 0, null], [null, -1, null]]");
+ auto expected = ArrayFromJSON(uint16(), "[0, 65535]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ // With nulls
+ arr = ArrayFromJSON(ty1, "[[null, 0, null], null, [null, -1, null]]");
+ expected = ArrayFromJSON(uint16(), "[0, null, 65535]");
+ // CheckView(arr, expected); // XXX currently fails
+ CheckView(expected, arr);
+
+ // With nested nulls => fails
+ arr = ArrayFromJSON(ty1, "[[null, null, null]]");
+ CheckViewFails(arr, uint16());
+}
+
+TEST(TestArrayView, StructAsStructNested) {
+ // Nesting tree shape need not be identical
+ auto ty1 = struct_({field("a", struct_({field("b", int8())})), field("d", int32())});
+ auto ty2 = struct_({field("a", uint8()), field("b", struct_({field("b", float32())}))});
+ auto arr = ArrayFromJSON(ty1, "[[[0], 1069547520], [[-1], -1071644672]]");
+ auto expected = ArrayFromJSON(ty2, "[[0, [1.5]], [255, [-2.5]]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ // With null types
+ ty1 = struct_({field("a", struct_({field("xx", null()), field("b", int8())})),
+ field("d", int32())});
+ ty2 = struct_({field("a", uint8()),
+ field("b", struct_({field("b", float32()), field("xx", null())}))});
+ arr = ArrayFromJSON(ty1, "[[[null, 0], 1069547520], [[null, -1], -1071644672]]");
+ expected = ArrayFromJSON(ty2, "[[0, [1.5, null]], [255, [-2.5, null]]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ // XXX With nulls (currently fails)
+}
+
+TEST(TestArrayView, ListAsListSimple) {
+ auto arr = ArrayFromJSON(list(int16()), "[[0, -1], [], [42]]");
+ auto expected = ArrayFromJSON(list(uint16()), "[[0, 65535], [], [42]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ // With nulls
+ arr = ArrayFromJSON(list(int16()), "[[0, -1], null, [42]]");
+ expected = ArrayFromJSON(list(uint16()), "[[0, 65535], null, [42]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ // With nested nulls
+ arr = ArrayFromJSON(list(int16()), "[[0, -1], null, [null, 42]]");
+ expected = ArrayFromJSON(list(uint16()), "[[0, 65535], null, [null, 42]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+}
+
+TEST(TestArrayView, FixedSizeListAsFixedSizeList) {
+ auto ty1 = fixed_size_list(int16(), 3);
+ auto ty2 = fixed_size_list(uint16(), 3);
+ auto arr = ArrayFromJSON(ty1, "[[0, -1, 42], [5, 6, -16384]]");
+ auto expected = ArrayFromJSON(ty2, "[[0, 65535, 42], [5, 6, 49152]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ // With nested nulls
+ arr = ArrayFromJSON(ty1, "[[0, -1, null], null, [5, 6, -16384]]");
+ expected = ArrayFromJSON(ty2, "[[0, 65535, null], null, [5, 6, 49152]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+}
+
+TEST(TestArrayView, FixedSizeListAsFlat) {
+ auto ty1 = fixed_size_list(int16(), 3);
+ auto arr = ArrayFromJSON(ty1, "[[0, -1, 42], [5, 6, -16384]]");
+ auto expected = ArrayFromJSON(uint16(), "[0, 65535, 42, 5, 6, 49152]");
+ CheckView(arr, expected);
+ // CheckView(expected, arr); // XXX currently fails
+
+ // XXX With nulls (currently fails)
+}
+
+TEST(TestArrayView, FixedSizeListAsFixedSizeBinary) {
+ auto ty1 = fixed_size_list(int32(), 1);
+#if ARROW_LITTLE_ENDIAN
+ auto arr = ArrayFromJSON(ty1, "[[2020568934], [2054316386]]");
+#else
+ auto arr = ArrayFromJSON(ty1, "[[1718579064], [1650553466]]");
+#endif
+ auto expected = ArrayFromJSON(fixed_size_binary(4), R"(["foox", "barz"])");
+ CheckView(arr, expected);
+}
+
+TEST(TestArrayView, SparseUnionAsStruct) {
+ auto child1 = ArrayFromJSON(int16(), "[0, -1, 42]");
+ auto child2 = ArrayFromJSON(int32(), "[0, 1069547520, -1071644672]");
+ auto indices = ArrayFromJSON(int8(), "[0, 0, 1]");
+ ASSERT_OK_AND_ASSIGN(auto arr, SparseUnionArray::Make(*indices, {child1, child2}));
+ ASSERT_OK(arr->ValidateFull());
+
+ auto ty1 = struct_({field("a", int8()), field("b", uint16()), field("c", float32())});
+ auto expected = ArrayFromJSON(ty1, "[[0, 0, 0], [0, 65535, 1.5], [1, 42, -2.5]]");
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+}
+
+TEST(TestArrayView, DecimalRoundTrip) {
+ auto ty1 = decimal(10, 4);
+ auto arr = ArrayFromJSON(ty1, R"(["123.4567", "-78.9000", null])");
+
+ auto ty2 = fixed_size_binary(16);
+ ASSERT_OK_AND_ASSIGN(auto v, arr->View(ty2));
+ ASSERT_OK(v->ValidateFull());
+ ASSERT_OK_AND_ASSIGN(auto w, v->View(ty1));
+ ASSERT_OK(w->ValidateFull());
+ AssertArraysEqual(*arr, *w);
+}
+
+TEST(TestArrayView, Dictionaries) {
+ // ARROW-6049
+ auto ty1 = dictionary(int8(), float32());
+ auto ty2 = dictionary(int8(), int32());
+
+ auto indices = ArrayFromJSON(int8(), "[0, 2, null, 1]");
+ auto values = ArrayFromJSON(float32(), "[0.0, 1.5, -2.5]");
+
+ ASSERT_OK_AND_ASSIGN(auto expected_dict, values->View(int32()));
+ ASSERT_OK_AND_ASSIGN(auto arr, DictionaryArray::FromArrays(ty1, indices, values));
+ ASSERT_OK_AND_ASSIGN(auto expected,
+ DictionaryArray::FromArrays(ty2, indices, expected_dict));
+
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+
+ // Incompatible index type
+ auto ty3 = dictionary(int16(), int32());
+ CheckViewFails(arr, ty3);
+
+ // Incompatible dictionary type
+ auto ty4 = dictionary(int16(), float64());
+ CheckViewFails(arr, ty4);
+
+ // Check dictionary-encoded child
+ auto offsets = ArrayFromJSON(int32(), "[0, 2, 2, 4]");
+ ASSERT_OK_AND_ASSIGN(auto list_arr, ListArray::FromArrays(*offsets, *arr));
+ ASSERT_OK_AND_ASSIGN(auto expected_list_arr,
+ ListArray::FromArrays(*offsets, *expected));
+ CheckView(list_arr, expected_list_arr);
+ CheckView(expected_list_arr, list_arr);
+}
+
+TEST(TestArrayView, ExtensionType) {
+ auto ty1 = std::make_shared<IPv4Type>();
+ auto data = ArrayFromJSON(ty1->storage_type(), R"(["ABCD", null])")->data();
+ data->type = ty1;
+ auto arr = ty1->MakeArray(data);
+#if ARROW_LITTLE_ENDIAN
+ auto expected = ArrayFromJSON(uint32(), "[1145258561, null]");
+#else
+ auto expected = ArrayFromJSON(uint32(), "[1094861636, null]");
+#endif
+ CheckView(arr, expected);
+ CheckView(expected, arr);
+}
+
+TEST(TestArrayView, NonZeroOffset) {
+ auto arr = ArrayFromJSON(int16(), "[10, 11, 12, 13]");
+
+ ASSERT_OK_AND_ASSIGN(auto expected, arr->View(fixed_size_binary(2)));
+ CheckView(arr->Slice(1), expected->Slice(1));
+}
+
+TEST(TestArrayView, NonZeroNestedOffset) {
+ auto list_values = ArrayFromJSON(int16(), "[10, 11, 12, 13, 14]");
+ auto view_values = ArrayFromJSON(uint16(), "[10, 11, 12, 13, 14]");
+
+ auto list_offsets = ArrayFromJSON(int32(), "[0, 2, 3]");
+
+ ASSERT_OK_AND_ASSIGN(auto arr,
+ ListArray::FromArrays(*list_offsets, *list_values->Slice(2)));
+ ASSERT_OK_AND_ASSIGN(auto expected,
+ ListArray::FromArrays(*list_offsets, *view_values->Slice(2)));
+ ASSERT_OK(arr->ValidateFull());
+ CheckView(arr->Slice(1), expected->Slice(1));
+
+ // Be extra paranoid about checking offsets
+ ASSERT_OK_AND_ASSIGN(auto result, arr->Slice(1)->View(expected->type()));
+ ASSERT_EQ(1, result->offset());
+ ASSERT_EQ(2, static_cast<const ListArray&>(*result).values()->offset());
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_adaptive.cc b/src/arrow/cpp/src/arrow/array/builder_adaptive.cc
new file mode 100644
index 000000000..36e5546a7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_adaptive.cc
@@ -0,0 +1,380 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/builder_adaptive.h"
+
+#include <algorithm>
+#include <cstdint>
+
+#include "arrow/array/data.h"
+#include "arrow/buffer.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/int_util.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::AdaptiveIntBuilderBase;
+
+AdaptiveIntBuilderBase::AdaptiveIntBuilderBase(uint8_t start_int_size, MemoryPool* pool)
+ : ArrayBuilder(pool), start_int_size_(start_int_size), int_size_(start_int_size) {}
+
+void AdaptiveIntBuilderBase::Reset() {
+ ArrayBuilder::Reset();
+ data_.reset();
+ raw_data_ = nullptr;
+ pending_pos_ = 0;
+ pending_has_nulls_ = false;
+ int_size_ = start_int_size_;
+}
+
+Status AdaptiveIntBuilderBase::Resize(int64_t capacity) {
+ RETURN_NOT_OK(CheckCapacity(capacity));
+ capacity = std::max(capacity, kMinBuilderCapacity);
+
+ int64_t nbytes = capacity * int_size_;
+ if (capacity_ == 0) {
+ ARROW_ASSIGN_OR_RAISE(data_, AllocateResizableBuffer(nbytes, pool_));
+ } else {
+ RETURN_NOT_OK(data_->Resize(nbytes));
+ }
+ raw_data_ = reinterpret_cast<uint8_t*>(data_->mutable_data());
+
+ return ArrayBuilder::Resize(capacity);
+}
+
+template <typename new_type, typename old_type>
+typename std::enable_if<sizeof(old_type) >= sizeof(new_type), Status>::type
+AdaptiveIntBuilderBase::ExpandIntSizeInternal() {
+ return Status::OK();
+}
+
+template <typename new_type, typename old_type>
+typename std::enable_if<(sizeof(old_type) < sizeof(new_type)), Status>::type
+AdaptiveIntBuilderBase::ExpandIntSizeInternal() {
+ int_size_ = sizeof(new_type);
+ RETURN_NOT_OK(Resize(data_->size() / sizeof(old_type)));
+
+ const old_type* src = reinterpret_cast<old_type*>(raw_data_);
+ new_type* dst = reinterpret_cast<new_type*>(raw_data_);
+ // By doing the backward copy, we ensure that no element is overridden during
+ // the copy process while the copy stays in-place.
+ std::copy_backward(src, src + length_, dst + length_);
+
+ return Status::OK();
+}
+
+std::shared_ptr<DataType> AdaptiveUIntBuilder::type() const {
+ auto int_size = int_size_;
+ if (pending_pos_ != 0) {
+ const uint8_t* valid_bytes = pending_has_nulls_ ? pending_valid_ : nullptr;
+ int_size =
+ internal::DetectUIntWidth(pending_data_, valid_bytes, pending_pos_, int_size_);
+ }
+ switch (int_size) {
+ case 1:
+ return uint8();
+ case 2:
+ return uint16();
+ case 4:
+ return uint32();
+ case 8:
+ return uint64();
+ default:
+ DCHECK(false);
+ }
+ return nullptr;
+}
+
+std::shared_ptr<DataType> AdaptiveIntBuilder::type() const {
+ auto int_size = int_size_;
+ if (pending_pos_ != 0) {
+ const uint8_t* valid_bytes = pending_has_nulls_ ? pending_valid_ : nullptr;
+ int_size = internal::DetectIntWidth(reinterpret_cast<const int64_t*>(pending_data_),
+ valid_bytes, pending_pos_, int_size_);
+ }
+ switch (int_size) {
+ case 1:
+ return int8();
+ case 2:
+ return int16();
+ case 4:
+ return int32();
+ case 8:
+ return int64();
+ default:
+ DCHECK(false);
+ }
+ return nullptr;
+}
+
+AdaptiveIntBuilder::AdaptiveIntBuilder(uint8_t start_int_size, MemoryPool* pool)
+ : AdaptiveIntBuilderBase(start_int_size, pool) {}
+
+Status AdaptiveIntBuilder::FinishInternal(std::shared_ptr<ArrayData>* out) {
+ RETURN_NOT_OK(CommitPendingData());
+
+ std::shared_ptr<Buffer> null_bitmap;
+ RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
+ RETURN_NOT_OK(TrimBuffer(length_ * int_size_, data_.get()));
+
+ *out = ArrayData::Make(type(), length_, {null_bitmap, data_}, null_count_);
+
+ data_ = nullptr;
+ capacity_ = length_ = null_count_ = 0;
+ return Status::OK();
+}
+
+Status AdaptiveIntBuilder::CommitPendingData() {
+ if (pending_pos_ == 0) {
+ return Status::OK();
+ }
+ RETURN_NOT_OK(Reserve(pending_pos_));
+ const uint8_t* valid_bytes = pending_has_nulls_ ? pending_valid_ : nullptr;
+ RETURN_NOT_OK(AppendValuesInternal(reinterpret_cast<const int64_t*>(pending_data_),
+ pending_pos_, valid_bytes));
+ pending_has_nulls_ = false;
+ pending_pos_ = 0;
+ return Status::OK();
+}
+
+static constexpr int64_t kAdaptiveIntChunkSize = 8192;
+
+Status AdaptiveIntBuilder::AppendValuesInternal(const int64_t* values, int64_t length,
+ const uint8_t* valid_bytes) {
+ if (pending_pos_ > 0) {
+ // UnsafeAppendToBitmap expects length_ to be the pre-update value, satisfy it
+ DCHECK_EQ(length, pending_pos_) << "AppendValuesInternal called while data pending";
+ length_ -= pending_pos_;
+ }
+
+ while (length > 0) {
+ // In case `length` is very large, we don't want to trash the cache by
+ // scanning it twice (first to detect int width, second to copy the data).
+ // Instead, process data in L2-cacheable chunks.
+ const int64_t chunk_size = std::min(length, kAdaptiveIntChunkSize);
+
+ uint8_t new_int_size;
+ new_int_size = internal::DetectIntWidth(values, valid_bytes, chunk_size, int_size_);
+
+ DCHECK_GE(new_int_size, int_size_);
+ if (new_int_size > int_size_) {
+ // This updates int_size_
+ RETURN_NOT_OK(ExpandIntSize(new_int_size));
+ }
+
+ switch (int_size_) {
+ case 1:
+ internal::DowncastInts(values, reinterpret_cast<int8_t*>(raw_data_) + length_,
+ chunk_size);
+ break;
+ case 2:
+ internal::DowncastInts(values, reinterpret_cast<int16_t*>(raw_data_) + length_,
+ chunk_size);
+ break;
+ case 4:
+ internal::DowncastInts(values, reinterpret_cast<int32_t*>(raw_data_) + length_,
+ chunk_size);
+ break;
+ case 8:
+ internal::DowncastInts(values, reinterpret_cast<int64_t*>(raw_data_) + length_,
+ chunk_size);
+ break;
+ default:
+ DCHECK(false);
+ }
+
+ // UnsafeAppendToBitmap increments length_ by chunk_size
+ ArrayBuilder::UnsafeAppendToBitmap(valid_bytes, chunk_size);
+ values += chunk_size;
+ if (valid_bytes != nullptr) {
+ valid_bytes += chunk_size;
+ }
+ length -= chunk_size;
+ }
+
+ return Status::OK();
+}
+
+Status AdaptiveUIntBuilder::CommitPendingData() {
+ if (pending_pos_ == 0) {
+ return Status::OK();
+ }
+ RETURN_NOT_OK(Reserve(pending_pos_));
+ const uint8_t* valid_bytes = pending_has_nulls_ ? pending_valid_ : nullptr;
+ RETURN_NOT_OK(AppendValuesInternal(pending_data_, pending_pos_, valid_bytes));
+ pending_has_nulls_ = false;
+ pending_pos_ = 0;
+ return Status::OK();
+}
+
+Status AdaptiveIntBuilder::AppendValues(const int64_t* values, int64_t length,
+ const uint8_t* valid_bytes) {
+ RETURN_NOT_OK(CommitPendingData());
+ RETURN_NOT_OK(Reserve(length));
+
+ return AppendValuesInternal(values, length, valid_bytes);
+}
+
+template <typename new_type>
+Status AdaptiveIntBuilder::ExpandIntSizeN() {
+ switch (int_size_) {
+ case 1:
+ return ExpandIntSizeInternal<new_type, int8_t>();
+ case 2:
+ return ExpandIntSizeInternal<new_type, int16_t>();
+ case 4:
+ return ExpandIntSizeInternal<new_type, int32_t>();
+ case 8:
+ return ExpandIntSizeInternal<new_type, int64_t>();
+ default:
+ DCHECK(false);
+ }
+ return Status::OK();
+}
+
+Status AdaptiveIntBuilder::ExpandIntSize(uint8_t new_int_size) {
+ switch (new_int_size) {
+ case 1:
+ return ExpandIntSizeN<int8_t>();
+ case 2:
+ return ExpandIntSizeN<int16_t>();
+ case 4:
+ return ExpandIntSizeN<int32_t>();
+ case 8:
+ return ExpandIntSizeN<int64_t>();
+ default:
+ DCHECK(false);
+ }
+ return Status::OK();
+}
+
+AdaptiveUIntBuilder::AdaptiveUIntBuilder(uint8_t start_int_size, MemoryPool* pool)
+ : AdaptiveIntBuilderBase(start_int_size, pool) {}
+
+Status AdaptiveUIntBuilder::FinishInternal(std::shared_ptr<ArrayData>* out) {
+ RETURN_NOT_OK(CommitPendingData());
+
+ std::shared_ptr<Buffer> null_bitmap;
+ RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
+ RETURN_NOT_OK(TrimBuffer(length_ * int_size_, data_.get()));
+
+ *out = ArrayData::Make(type(), length_, {null_bitmap, data_}, null_count_);
+
+ data_ = nullptr;
+ capacity_ = length_ = null_count_ = 0;
+ return Status::OK();
+}
+
+Status AdaptiveUIntBuilder::AppendValuesInternal(const uint64_t* values, int64_t length,
+ const uint8_t* valid_bytes) {
+ if (pending_pos_ > 0) {
+ // UnsafeAppendToBitmap expects length_ to be the pre-update value, satisfy it
+ DCHECK_EQ(length, pending_pos_) << "AppendValuesInternal called while data pending";
+ length_ -= pending_pos_;
+ }
+
+ while (length > 0) {
+ // See AdaptiveIntBuilder::AppendValuesInternal
+ const int64_t chunk_size = std::min(length, kAdaptiveIntChunkSize);
+
+ uint8_t new_int_size;
+ new_int_size = internal::DetectUIntWidth(values, valid_bytes, chunk_size, int_size_);
+
+ DCHECK_GE(new_int_size, int_size_);
+ if (new_int_size > int_size_) {
+ // This updates int_size_
+ RETURN_NOT_OK(ExpandIntSize(new_int_size));
+ }
+
+ switch (int_size_) {
+ case 1:
+ internal::DowncastUInts(values, reinterpret_cast<uint8_t*>(raw_data_) + length_,
+ chunk_size);
+ break;
+ case 2:
+ internal::DowncastUInts(values, reinterpret_cast<uint16_t*>(raw_data_) + length_,
+ chunk_size);
+ break;
+ case 4:
+ internal::DowncastUInts(values, reinterpret_cast<uint32_t*>(raw_data_) + length_,
+ chunk_size);
+ break;
+ case 8:
+ internal::DowncastUInts(values, reinterpret_cast<uint64_t*>(raw_data_) + length_,
+ chunk_size);
+ break;
+ default:
+ DCHECK(false);
+ }
+
+ // UnsafeAppendToBitmap increments length_ by chunk_size
+ ArrayBuilder::UnsafeAppendToBitmap(valid_bytes, chunk_size);
+ values += chunk_size;
+ if (valid_bytes != nullptr) {
+ valid_bytes += chunk_size;
+ }
+ length -= chunk_size;
+ }
+
+ return Status::OK();
+}
+
+Status AdaptiveUIntBuilder::AppendValues(const uint64_t* values, int64_t length,
+ const uint8_t* valid_bytes) {
+ RETURN_NOT_OK(Reserve(length));
+
+ return AppendValuesInternal(values, length, valid_bytes);
+}
+
+template <typename new_type>
+Status AdaptiveUIntBuilder::ExpandIntSizeN() {
+ switch (int_size_) {
+ case 1:
+ return ExpandIntSizeInternal<new_type, uint8_t>();
+ case 2:
+ return ExpandIntSizeInternal<new_type, uint16_t>();
+ case 4:
+ return ExpandIntSizeInternal<new_type, uint32_t>();
+ case 8:
+ return ExpandIntSizeInternal<new_type, uint64_t>();
+ default:
+ DCHECK(false);
+ }
+ return Status::OK();
+}
+
+Status AdaptiveUIntBuilder::ExpandIntSize(uint8_t new_int_size) {
+ switch (new_int_size) {
+ case 1:
+ return ExpandIntSizeN<uint8_t>();
+ case 2:
+ return ExpandIntSizeN<uint16_t>();
+ case 4:
+ return ExpandIntSizeN<uint32_t>();
+ case 8:
+ return ExpandIntSizeN<uint64_t>();
+ default:
+ DCHECK(false);
+ }
+ return Status::OK();
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_adaptive.h b/src/arrow/cpp/src/arrow/array/builder_adaptive.h
new file mode 100644
index 000000000..c0df79725
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_adaptive.h
@@ -0,0 +1,203 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <type_traits>
+
+#include "arrow/array/builder_base.h"
+#include "arrow/buffer.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+namespace internal {
+
+class ARROW_EXPORT AdaptiveIntBuilderBase : public ArrayBuilder {
+ public:
+ AdaptiveIntBuilderBase(uint8_t start_int_size, MemoryPool* pool);
+
+ explicit AdaptiveIntBuilderBase(MemoryPool* pool)
+ : AdaptiveIntBuilderBase(sizeof(uint8_t), pool) {}
+
+ /// \brief Append multiple nulls
+ /// \param[in] length the number of nulls to append
+ Status AppendNulls(int64_t length) final {
+ ARROW_RETURN_NOT_OK(CommitPendingData());
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ memset(data_->mutable_data() + length_ * int_size_, 0, int_size_ * length);
+ UnsafeSetNull(length);
+ return Status::OK();
+ }
+
+ Status AppendNull() final {
+ pending_data_[pending_pos_] = 0;
+ pending_valid_[pending_pos_] = 0;
+ pending_has_nulls_ = true;
+ ++pending_pos_;
+ ++length_;
+ ++null_count_;
+
+ if (ARROW_PREDICT_FALSE(pending_pos_ >= pending_size_)) {
+ return CommitPendingData();
+ }
+ return Status::OK();
+ }
+
+ Status AppendEmptyValues(int64_t length) final {
+ ARROW_RETURN_NOT_OK(CommitPendingData());
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ memset(data_->mutable_data() + length_ * int_size_, 0, int_size_ * length);
+ UnsafeSetNotNull(length);
+ return Status::OK();
+ }
+
+ Status AppendEmptyValue() final {
+ pending_data_[pending_pos_] = 0;
+ pending_valid_[pending_pos_] = 1;
+ ++pending_pos_;
+ ++length_;
+
+ if (ARROW_PREDICT_FALSE(pending_pos_ >= pending_size_)) {
+ return CommitPendingData();
+ }
+ return Status::OK();
+ }
+
+ void Reset() override;
+ Status Resize(int64_t capacity) override;
+
+ protected:
+ Status AppendInternal(const uint64_t val) {
+ pending_data_[pending_pos_] = val;
+ pending_valid_[pending_pos_] = 1;
+ ++pending_pos_;
+ ++length_;
+
+ if (ARROW_PREDICT_FALSE(pending_pos_ >= pending_size_)) {
+ return CommitPendingData();
+ }
+ return Status::OK();
+ }
+
+ virtual Status CommitPendingData() = 0;
+
+ template <typename new_type, typename old_type>
+ typename std::enable_if<sizeof(old_type) >= sizeof(new_type), Status>::type
+ ExpandIntSizeInternal();
+ template <typename new_type, typename old_type>
+ typename std::enable_if<(sizeof(old_type) < sizeof(new_type)), Status>::type
+ ExpandIntSizeInternal();
+
+ std::shared_ptr<ResizableBuffer> data_;
+ uint8_t* raw_data_ = NULLPTR;
+
+ const uint8_t start_int_size_;
+ uint8_t int_size_;
+
+ static constexpr int32_t pending_size_ = 1024;
+ uint8_t pending_valid_[pending_size_];
+ uint64_t pending_data_[pending_size_];
+ int32_t pending_pos_ = 0;
+ bool pending_has_nulls_ = false;
+};
+
+} // namespace internal
+
+class ARROW_EXPORT AdaptiveUIntBuilder : public internal::AdaptiveIntBuilderBase {
+ public:
+ explicit AdaptiveUIntBuilder(uint8_t start_int_size,
+ MemoryPool* pool = default_memory_pool());
+
+ explicit AdaptiveUIntBuilder(MemoryPool* pool = default_memory_pool())
+ : AdaptiveUIntBuilder(sizeof(uint8_t), pool) {}
+
+ using ArrayBuilder::Advance;
+ using internal::AdaptiveIntBuilderBase::Reset;
+
+ /// Scalar append
+ Status Append(const uint64_t val) { return AppendInternal(val); }
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values a contiguous C array of values
+ /// \param[in] length the number of values to append
+ /// \param[in] valid_bytes an optional sequence of bytes where non-zero
+ /// indicates a valid (non-null) value
+ /// \return Status
+ Status AppendValues(const uint64_t* values, int64_t length,
+ const uint8_t* valid_bytes = NULLPTR);
+
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
+
+ std::shared_ptr<DataType> type() const override;
+
+ protected:
+ Status CommitPendingData() override;
+ Status ExpandIntSize(uint8_t new_int_size);
+
+ Status AppendValuesInternal(const uint64_t* values, int64_t length,
+ const uint8_t* valid_bytes);
+
+ template <typename new_type>
+ Status ExpandIntSizeN();
+};
+
+class ARROW_EXPORT AdaptiveIntBuilder : public internal::AdaptiveIntBuilderBase {
+ public:
+ explicit AdaptiveIntBuilder(uint8_t start_int_size,
+ MemoryPool* pool = default_memory_pool());
+
+ explicit AdaptiveIntBuilder(MemoryPool* pool = default_memory_pool())
+ : AdaptiveIntBuilder(sizeof(uint8_t), pool) {}
+
+ using ArrayBuilder::Advance;
+ using internal::AdaptiveIntBuilderBase::Reset;
+
+ /// Scalar append
+ Status Append(const int64_t val) { return AppendInternal(static_cast<uint64_t>(val)); }
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values a contiguous C array of values
+ /// \param[in] length the number of values to append
+ /// \param[in] valid_bytes an optional sequence of bytes where non-zero
+ /// indicates a valid (non-null) value
+ /// \return Status
+ Status AppendValues(const int64_t* values, int64_t length,
+ const uint8_t* valid_bytes = NULLPTR);
+
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
+
+ std::shared_ptr<DataType> type() const override;
+
+ protected:
+ Status CommitPendingData() override;
+ Status ExpandIntSize(uint8_t new_int_size);
+
+ Status AppendValuesInternal(const int64_t* values, int64_t length,
+ const uint8_t* valid_bytes);
+
+ template <typename new_type>
+ Status ExpandIntSizeN();
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_base.cc b/src/arrow/cpp/src/arrow/array/builder_base.cc
new file mode 100644
index 000000000..117b9d376
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_base.cc
@@ -0,0 +1,336 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/builder_base.h"
+
+#include <cstdint>
+#include <type_traits>
+#include <vector>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/builder_dict.h"
+#include "arrow/array/data.h"
+#include "arrow/array/util.h"
+#include "arrow/buffer.h"
+#include "arrow/builder.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+Status ArrayBuilder::CheckArrayType(const std::shared_ptr<DataType>& expected_type,
+ const Array& array, const char* message) {
+ if (!expected_type->Equals(*array.type())) {
+ return Status::TypeError(message);
+ }
+ return Status::OK();
+}
+
+Status ArrayBuilder::CheckArrayType(Type::type expected_type, const Array& array,
+ const char* message) {
+ if (array.type_id() != expected_type) {
+ return Status::TypeError(message);
+ }
+ return Status::OK();
+}
+
+Status ArrayBuilder::TrimBuffer(const int64_t bytes_filled, ResizableBuffer* buffer) {
+ if (buffer) {
+ if (bytes_filled < buffer->size()) {
+ // Trim buffer
+ RETURN_NOT_OK(buffer->Resize(bytes_filled));
+ }
+ // zero the padding
+ buffer->ZeroPadding();
+ } else {
+ // Null buffers are allowed in place of 0-byte buffers
+ DCHECK_EQ(bytes_filled, 0);
+ }
+ return Status::OK();
+}
+
+Status ArrayBuilder::AppendToBitmap(bool is_valid) {
+ RETURN_NOT_OK(Reserve(1));
+ UnsafeAppendToBitmap(is_valid);
+ return Status::OK();
+}
+
+Status ArrayBuilder::AppendToBitmap(const uint8_t* valid_bytes, int64_t length) {
+ RETURN_NOT_OK(Reserve(length));
+ UnsafeAppendToBitmap(valid_bytes, length);
+ return Status::OK();
+}
+
+Status ArrayBuilder::AppendToBitmap(int64_t num_bits, bool value) {
+ RETURN_NOT_OK(Reserve(num_bits));
+ UnsafeAppendToBitmap(num_bits, value);
+ return Status::OK();
+}
+
+Status ArrayBuilder::Resize(int64_t capacity) {
+ RETURN_NOT_OK(CheckCapacity(capacity));
+ capacity_ = capacity;
+ return null_bitmap_builder_.Resize(capacity);
+}
+
+Status ArrayBuilder::Advance(int64_t elements) {
+ if (length_ + elements > capacity_) {
+ return Status::Invalid("Builder must be expanded");
+ }
+ length_ += elements;
+ return null_bitmap_builder_.Advance(elements);
+}
+
+namespace {
+
+struct AppendScalarImpl {
+ template <typename T>
+ enable_if_t<has_c_type<T>::value || is_decimal_type<T>::value ||
+ is_fixed_size_binary_type<T>::value,
+ Status>
+ Visit(const T&) {
+ auto builder = internal::checked_cast<typename TypeTraits<T>::BuilderType*>(builder_);
+ RETURN_NOT_OK(builder->Reserve(n_repeats_ * (scalars_end_ - scalars_begin_)));
+
+ for (int64_t i = 0; i < n_repeats_; i++) {
+ for (const std::shared_ptr<Scalar>* raw = scalars_begin_; raw != scalars_end_;
+ raw++) {
+ auto scalar =
+ internal::checked_cast<const typename TypeTraits<T>::ScalarType*>(raw->get());
+ if (scalar->is_valid) {
+ builder->UnsafeAppend(scalar->value);
+ } else {
+ builder->UnsafeAppendNull();
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_base_binary<T, Status> Visit(const T&) {
+ int64_t data_size = 0;
+ for (const std::shared_ptr<Scalar>* raw = scalars_begin_; raw != scalars_end_;
+ raw++) {
+ auto scalar =
+ internal::checked_cast<const typename TypeTraits<T>::ScalarType*>(raw->get());
+ if (scalar->is_valid) {
+ data_size += scalar->value->size();
+ }
+ }
+
+ auto builder = internal::checked_cast<typename TypeTraits<T>::BuilderType*>(builder_);
+ RETURN_NOT_OK(builder->Reserve(n_repeats_ * (scalars_end_ - scalars_begin_)));
+ RETURN_NOT_OK(builder->ReserveData(n_repeats_ * data_size));
+
+ for (int64_t i = 0; i < n_repeats_; i++) {
+ for (const std::shared_ptr<Scalar>* raw = scalars_begin_; raw != scalars_end_;
+ raw++) {
+ auto scalar =
+ internal::checked_cast<const typename TypeTraits<T>::ScalarType*>(raw->get());
+ if (scalar->is_valid) {
+ builder->UnsafeAppend(util::string_view{*scalar->value});
+ } else {
+ builder->UnsafeAppendNull();
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_list_like<T, Status> Visit(const T&) {
+ auto builder = internal::checked_cast<typename TypeTraits<T>::BuilderType*>(builder_);
+ int64_t num_children = 0;
+ for (const std::shared_ptr<Scalar>* scalar = scalars_begin_; scalar != scalars_end_;
+ scalar++) {
+ if (!(*scalar)->is_valid) continue;
+ num_children +=
+ internal::checked_cast<const BaseListScalar&>(**scalar).value->length();
+ }
+ RETURN_NOT_OK(builder->value_builder()->Reserve(num_children * n_repeats_));
+
+ for (int64_t i = 0; i < n_repeats_; i++) {
+ for (const std::shared_ptr<Scalar>* scalar = scalars_begin_; scalar != scalars_end_;
+ scalar++) {
+ if ((*scalar)->is_valid) {
+ RETURN_NOT_OK(builder->Append());
+ const Array& list =
+ *internal::checked_cast<const BaseListScalar&>(**scalar).value;
+ for (int64_t i = 0; i < list.length(); i++) {
+ ARROW_ASSIGN_OR_RAISE(auto scalar, list.GetScalar(i));
+ RETURN_NOT_OK(builder->value_builder()->AppendScalar(*scalar));
+ }
+ } else {
+ RETURN_NOT_OK(builder_->AppendNull());
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const StructType& type) {
+ auto* builder = internal::checked_cast<StructBuilder*>(builder_);
+ auto count = n_repeats_ * (scalars_end_ - scalars_begin_);
+ RETURN_NOT_OK(builder->Reserve(count));
+ for (int field_index = 0; field_index < type.num_fields(); ++field_index) {
+ RETURN_NOT_OK(builder->field_builder(field_index)->Reserve(count));
+ }
+ for (int64_t i = 0; i < n_repeats_; i++) {
+ for (const std::shared_ptr<Scalar>* s = scalars_begin_; s != scalars_end_; s++) {
+ const auto& scalar = internal::checked_cast<const StructScalar&>(**s);
+ for (int field_index = 0; field_index < type.num_fields(); ++field_index) {
+ if (!scalar.is_valid || !scalar.value[field_index]) {
+ RETURN_NOT_OK(builder->field_builder(field_index)->AppendNull());
+ } else {
+ RETURN_NOT_OK(builder->field_builder(field_index)
+ ->AppendScalar(*scalar.value[field_index]));
+ }
+ }
+ RETURN_NOT_OK(builder->Append(scalar.is_valid));
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const SparseUnionType& type) { return MakeUnionArray(type); }
+
+ Status Visit(const DenseUnionType& type) { return MakeUnionArray(type); }
+
+ template <typename T>
+ Status MakeUnionArray(const T& type) {
+ using BuilderType = typename TypeTraits<T>::BuilderType;
+ constexpr bool is_dense = std::is_same<T, DenseUnionType>::value;
+
+ auto* builder = internal::checked_cast<BuilderType*>(builder_);
+ const auto count = n_repeats_ * (scalars_end_ - scalars_begin_);
+
+ RETURN_NOT_OK(builder->Reserve(count));
+
+ DCHECK_EQ(type.num_fields(), builder->num_children());
+ for (int field_index = 0; field_index < type.num_fields(); ++field_index) {
+ RETURN_NOT_OK(builder->child_builder(field_index)->Reserve(count));
+ }
+
+ for (int64_t i = 0; i < n_repeats_; i++) {
+ for (const std::shared_ptr<Scalar>* s = scalars_begin_; s != scalars_end_; s++) {
+ // For each scalar,
+ // 1. append the type code,
+ // 2. append the value to the corresponding child,
+ // 3. if the union is sparse, append null to the other children.
+ const auto& scalar = internal::checked_cast<const UnionScalar&>(**s);
+ const auto scalar_field_index = type.child_ids()[scalar.type_code];
+ RETURN_NOT_OK(builder->Append(scalar.type_code));
+
+ for (int field_index = 0; field_index < type.num_fields(); ++field_index) {
+ auto* child_builder = builder->child_builder(field_index).get();
+ if (field_index == scalar_field_index) {
+ if (scalar.is_valid) {
+ RETURN_NOT_OK(child_builder->AppendScalar(*scalar.value));
+ } else {
+ RETURN_NOT_OK(child_builder->AppendNull());
+ }
+ } else if (!is_dense) {
+ RETURN_NOT_OK(child_builder->AppendNull());
+ }
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const DataType& type) {
+ return Status::NotImplemented("AppendScalar for type ", type);
+ }
+
+ Status Convert() { return VisitTypeInline(*(*scalars_begin_)->type, this); }
+
+ const std::shared_ptr<Scalar>* scalars_begin_;
+ const std::shared_ptr<Scalar>* scalars_end_;
+ int64_t n_repeats_;
+ ArrayBuilder* builder_;
+};
+
+} // namespace
+
+Status ArrayBuilder::AppendScalar(const Scalar& scalar, int64_t n_repeats) {
+ if (!scalar.type->Equals(type())) {
+ return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(),
+ " to builder for type ", type()->ToString());
+ }
+ std::shared_ptr<Scalar> shared{const_cast<Scalar*>(&scalar), [](Scalar*) {}};
+ return AppendScalarImpl{&shared, &shared + 1, n_repeats, this}.Convert();
+}
+
+Status ArrayBuilder::AppendScalars(const ScalarVector& scalars) {
+ if (scalars.empty()) return Status::OK();
+ const auto ty = type();
+ for (const auto& scalar : scalars) {
+ if (!scalar->type->Equals(ty)) {
+ return Status::Invalid("Cannot append scalar of type ", scalar->type->ToString(),
+ " to builder for type ", type()->ToString());
+ }
+ }
+ return AppendScalarImpl{scalars.data(), scalars.data() + scalars.size(),
+ /*n_repeats=*/1, this}
+ .Convert();
+}
+
+Status ArrayBuilder::Finish(std::shared_ptr<Array>* out) {
+ std::shared_ptr<ArrayData> internal_data;
+ RETURN_NOT_OK(FinishInternal(&internal_data));
+ *out = MakeArray(internal_data);
+ return Status::OK();
+}
+
+Result<std::shared_ptr<Array>> ArrayBuilder::Finish() {
+ std::shared_ptr<Array> out;
+ RETURN_NOT_OK(Finish(&out));
+ return out;
+}
+
+void ArrayBuilder::Reset() {
+ capacity_ = length_ = null_count_ = 0;
+ null_bitmap_builder_.Reset();
+}
+
+Status ArrayBuilder::SetNotNull(int64_t length) {
+ RETURN_NOT_OK(Reserve(length));
+ UnsafeSetNotNull(length);
+ return Status::OK();
+}
+
+void ArrayBuilder::UnsafeAppendToBitmap(const std::vector<bool>& is_valid) {
+ for (bool element_valid : is_valid) {
+ UnsafeAppendToBitmap(element_valid);
+ }
+}
+
+void ArrayBuilder::UnsafeSetNotNull(int64_t length) {
+ length_ += length;
+ null_bitmap_builder_.UnsafeAppend(length, true);
+}
+
+void ArrayBuilder::UnsafeSetNull(int64_t length) {
+ length_ += length;
+ null_count_ += length;
+ null_bitmap_builder_.UnsafeAppend(length, false);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_base.h b/src/arrow/cpp/src/arrow/array/builder_base.h
new file mode 100644
index 000000000..a513bf0f4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_base.h
@@ -0,0 +1,307 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm> // IWYU pragma: keep
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_primitive.h"
+#include "arrow/buffer.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+constexpr int64_t kMinBuilderCapacity = 1 << 5;
+constexpr int64_t kListMaximumElements = std::numeric_limits<int32_t>::max() - 1;
+
+/// Base class for all data array builders.
+///
+/// This class provides a facilities for incrementally building the null bitmap
+/// (see Append methods) and as a side effect the current number of slots and
+/// the null count.
+///
+/// \note Users are expected to use builders as one of the concrete types below.
+/// For example, ArrayBuilder* pointing to BinaryBuilder should be downcast before use.
+class ARROW_EXPORT ArrayBuilder {
+ public:
+ explicit ArrayBuilder(MemoryPool* pool) : pool_(pool), null_bitmap_builder_(pool) {}
+
+ ARROW_DEFAULT_MOVE_AND_ASSIGN(ArrayBuilder);
+
+ virtual ~ArrayBuilder() = default;
+
+ /// For nested types. Since the objects are owned by this class instance, we
+ /// skip shared pointers and just return a raw pointer
+ ArrayBuilder* child(int i) { return children_[i].get(); }
+
+ const std::shared_ptr<ArrayBuilder>& child_builder(int i) const { return children_[i]; }
+
+ int num_children() const { return static_cast<int>(children_.size()); }
+
+ virtual int64_t length() const { return length_; }
+ int64_t null_count() const { return null_count_; }
+ int64_t capacity() const { return capacity_; }
+
+ /// \brief Ensure that enough memory has been allocated to fit the indicated
+ /// number of total elements in the builder, including any that have already
+ /// been appended. Does not account for reallocations that may be due to
+ /// variable size data, like binary values. To make space for incremental
+ /// appends, use Reserve instead.
+ ///
+ /// \param[in] capacity the minimum number of total array values to
+ /// accommodate. Must be greater than the current capacity.
+ /// \return Status
+ virtual Status Resize(int64_t capacity);
+
+ /// \brief Ensure that there is enough space allocated to append the indicated
+ /// number of elements without any further reallocation. Overallocation is
+ /// used in order to minimize the impact of incremental Reserve() calls.
+ /// Note that additional_capacity is relative to the current number of elements
+ /// rather than to the current capacity, so calls to Reserve() which are not
+ /// interspersed with addition of new elements may not increase the capacity.
+ ///
+ /// \param[in] additional_capacity the number of additional array values
+ /// \return Status
+ Status Reserve(int64_t additional_capacity) {
+ auto current_capacity = capacity();
+ auto min_capacity = length() + additional_capacity;
+ if (min_capacity <= current_capacity) return Status::OK();
+
+ // leave growth factor up to BufferBuilder
+ auto new_capacity = BufferBuilder::GrowByFactor(current_capacity, min_capacity);
+ return Resize(new_capacity);
+ }
+
+ /// Reset the builder.
+ virtual void Reset();
+
+ /// \brief Append a null value to builder
+ virtual Status AppendNull() = 0;
+ /// \brief Append a number of null values to builder
+ virtual Status AppendNulls(int64_t length) = 0;
+
+ /// \brief Append a non-null value to builder
+ ///
+ /// The appended value is an implementation detail, but the corresponding
+ /// memory slot is guaranteed to be initialized.
+ /// This method is useful when appending a null value to a parent nested type.
+ virtual Status AppendEmptyValue() = 0;
+
+ /// \brief Append a number of non-null values to builder
+ ///
+ /// The appended values are an implementation detail, but the corresponding
+ /// memory slot is guaranteed to be initialized.
+ /// This method is useful when appending null values to a parent nested type.
+ virtual Status AppendEmptyValues(int64_t length) = 0;
+
+ /// \brief Append a value from a scalar
+ Status AppendScalar(const Scalar& scalar) { return AppendScalar(scalar, 1); }
+ virtual Status AppendScalar(const Scalar& scalar, int64_t n_repeats);
+ virtual Status AppendScalars(const ScalarVector& scalars);
+
+ /// \brief Append a range of values from an array.
+ ///
+ /// The given array must be the same type as the builder.
+ virtual Status AppendArraySlice(const ArrayData& array, int64_t offset,
+ int64_t length) {
+ return Status::NotImplemented("AppendArraySlice for builder for ", *type());
+ }
+
+ /// For cases where raw data was memcpy'd into the internal buffers, allows us
+ /// to advance the length of the builder. It is your responsibility to use
+ /// this function responsibly.
+ ARROW_DEPRECATED(
+ "Deprecated in 6.0.0. ArrayBuilder::Advance is poorly supported and mostly "
+ "untested.\nFor low-level control over buffer construction, use BufferBuilder "
+ "or TypedBufferBuilder directly.")
+ Status Advance(int64_t elements);
+
+ /// \brief Return result of builder as an internal generic ArrayData
+ /// object. Resets builder except for dictionary builder
+ ///
+ /// \param[out] out the finalized ArrayData object
+ /// \return Status
+ virtual Status FinishInternal(std::shared_ptr<ArrayData>* out) = 0;
+
+ /// \brief Return result of builder as an Array object.
+ ///
+ /// The builder is reset except for DictionaryBuilder.
+ ///
+ /// \param[out] out the finalized Array object
+ /// \return Status
+ Status Finish(std::shared_ptr<Array>* out);
+
+ /// \brief Return result of builder as an Array object.
+ ///
+ /// The builder is reset except for DictionaryBuilder.
+ ///
+ /// \return The finalized Array object
+ Result<std::shared_ptr<Array>> Finish();
+
+ /// \brief Return the type of the built Array
+ virtual std::shared_ptr<DataType> type() const = 0;
+
+ protected:
+ /// Append to null bitmap
+ Status AppendToBitmap(bool is_valid);
+
+ /// Vector append. Treat each zero byte as a null. If valid_bytes is null
+ /// assume all of length bits are valid.
+ Status AppendToBitmap(const uint8_t* valid_bytes, int64_t length);
+
+ /// Uniform append. Append N times the same validity bit.
+ Status AppendToBitmap(int64_t num_bits, bool value);
+
+ /// Set the next length bits to not null (i.e. valid).
+ Status SetNotNull(int64_t length);
+
+ // Unsafe operations (don't check capacity/don't resize)
+
+ void UnsafeAppendNull() { UnsafeAppendToBitmap(false); }
+
+ // Append to null bitmap, update the length
+ void UnsafeAppendToBitmap(bool is_valid) {
+ null_bitmap_builder_.UnsafeAppend(is_valid);
+ ++length_;
+ if (!is_valid) ++null_count_;
+ }
+
+ // Vector append. Treat each zero byte as a nullzero. If valid_bytes is null
+ // assume all of length bits are valid.
+ void UnsafeAppendToBitmap(const uint8_t* valid_bytes, int64_t length) {
+ if (valid_bytes == NULLPTR) {
+ return UnsafeSetNotNull(length);
+ }
+ null_bitmap_builder_.UnsafeAppend(valid_bytes, length);
+ length_ += length;
+ null_count_ = null_bitmap_builder_.false_count();
+ }
+
+ // Vector append. Copy from a given bitmap. If bitmap is null assume
+ // all of length bits are valid.
+ void UnsafeAppendToBitmap(const uint8_t* bitmap, int64_t offset, int64_t length) {
+ if (bitmap == NULLPTR) {
+ return UnsafeSetNotNull(length);
+ }
+ null_bitmap_builder_.UnsafeAppend(bitmap, offset, length);
+ length_ += length;
+ null_count_ = null_bitmap_builder_.false_count();
+ }
+
+ // Append the same validity value a given number of times.
+ void UnsafeAppendToBitmap(const int64_t num_bits, bool value) {
+ if (value) {
+ UnsafeSetNotNull(num_bits);
+ } else {
+ UnsafeSetNull(num_bits);
+ }
+ }
+
+ void UnsafeAppendToBitmap(const std::vector<bool>& is_valid);
+
+ // Set the next validity bits to not null (i.e. valid).
+ void UnsafeSetNotNull(int64_t length);
+
+ // Set the next validity bits to null (i.e. invalid).
+ void UnsafeSetNull(int64_t length);
+
+ static Status TrimBuffer(const int64_t bytes_filled, ResizableBuffer* buffer);
+
+ /// \brief Finish to an array of the specified ArrayType
+ template <typename ArrayType>
+ Status FinishTyped(std::shared_ptr<ArrayType>* out) {
+ std::shared_ptr<Array> out_untyped;
+ ARROW_RETURN_NOT_OK(Finish(&out_untyped));
+ *out = std::static_pointer_cast<ArrayType>(std::move(out_untyped));
+ return Status::OK();
+ }
+
+ // Check the requested capacity for validity
+ Status CheckCapacity(int64_t new_capacity) {
+ if (ARROW_PREDICT_FALSE(new_capacity < 0)) {
+ return Status::Invalid(
+ "Resize capacity must be positive (requested: ", new_capacity, ")");
+ }
+
+ if (ARROW_PREDICT_FALSE(new_capacity < length_)) {
+ return Status::Invalid("Resize cannot downsize (requested: ", new_capacity,
+ ", current length: ", length_, ")");
+ }
+
+ return Status::OK();
+ }
+
+ // Check for array type
+ Status CheckArrayType(const std::shared_ptr<DataType>& expected_type,
+ const Array& array, const char* message);
+ Status CheckArrayType(Type::type expected_type, const Array& array,
+ const char* message);
+
+ MemoryPool* pool_;
+
+ TypedBufferBuilder<bool> null_bitmap_builder_;
+ int64_t null_count_ = 0;
+
+ // Array length, so far. Also, the index of the next element to be added
+ int64_t length_ = 0;
+ int64_t capacity_ = 0;
+
+ // Child value array builders. These are owned by this class
+ std::vector<std::shared_ptr<ArrayBuilder>> children_;
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ArrayBuilder);
+};
+
+/// \brief Construct an empty ArrayBuilder corresponding to the data
+/// type
+/// \param[in] pool the MemoryPool to use for allocations
+/// \param[in] type the data type to create the builder for
+/// \param[out] out the created ArrayBuilder
+ARROW_EXPORT
+Status MakeBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
+ std::unique_ptr<ArrayBuilder>* out);
+
+/// \brief Construct an empty ArrayBuilder corresponding to the data
+/// type, where any top-level or nested dictionary builders return the
+/// exact index type specified by the type.
+ARROW_EXPORT
+Status MakeBuilderExactIndex(MemoryPool* pool, const std::shared_ptr<DataType>& type,
+ std::unique_ptr<ArrayBuilder>* out);
+
+/// \brief Construct an empty DictionaryBuilder initialized optionally
+/// with a pre-existing dictionary
+/// \param[in] pool the MemoryPool to use for allocations
+/// \param[in] type the dictionary type to create the builder for
+/// \param[in] dictionary the initial dictionary, if any. May be nullptr
+/// \param[out] out the created ArrayBuilder
+ARROW_EXPORT
+Status MakeDictionaryBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<Array>& dictionary,
+ std::unique_ptr<ArrayBuilder>* out);
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_binary.cc b/src/arrow/cpp/src/arrow/array/builder_binary.cc
new file mode 100644
index 000000000..fd1be1798
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_binary.cc
@@ -0,0 +1,207 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/builder_binary.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <numeric>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+// ----------------------------------------------------------------------
+// Fixed width binary
+
+FixedSizeBinaryBuilder::FixedSizeBinaryBuilder(const std::shared_ptr<DataType>& type,
+ MemoryPool* pool)
+ : ArrayBuilder(pool),
+ byte_width_(checked_cast<const FixedSizeBinaryType&>(*type).byte_width()),
+ byte_builder_(pool) {}
+
+void FixedSizeBinaryBuilder::CheckValueSize(int64_t size) {
+ DCHECK_EQ(size, byte_width_) << "Appending wrong size to FixedSizeBinaryBuilder";
+}
+
+Status FixedSizeBinaryBuilder::AppendValues(const uint8_t* data, int64_t length,
+ const uint8_t* valid_bytes) {
+ RETURN_NOT_OK(Reserve(length));
+ UnsafeAppendToBitmap(valid_bytes, length);
+ return byte_builder_.Append(data, length * byte_width_);
+}
+
+Status FixedSizeBinaryBuilder::AppendValues(const uint8_t* data, int64_t length,
+ const uint8_t* validity,
+ int64_t bitmap_offset) {
+ RETURN_NOT_OK(Reserve(length));
+ UnsafeAppendToBitmap(validity, bitmap_offset, length);
+ return byte_builder_.Append(data, length * byte_width_);
+}
+
+Status FixedSizeBinaryBuilder::AppendNull() {
+ RETURN_NOT_OK(Reserve(1));
+ UnsafeAppendNull();
+ return Status::OK();
+}
+
+Status FixedSizeBinaryBuilder::AppendNulls(int64_t length) {
+ RETURN_NOT_OK(Reserve(length));
+ UnsafeAppendToBitmap(length, false);
+ byte_builder_.UnsafeAppend(/*num_copies=*/length * byte_width_, 0);
+ return Status::OK();
+}
+
+Status FixedSizeBinaryBuilder::AppendEmptyValue() {
+ RETURN_NOT_OK(Reserve(1));
+ UnsafeAppendToBitmap(true);
+ byte_builder_.UnsafeAppend(/*num_copies=*/byte_width_, 0);
+ return Status::OK();
+}
+
+Status FixedSizeBinaryBuilder::AppendEmptyValues(int64_t length) {
+ RETURN_NOT_OK(Reserve(length));
+ UnsafeAppendToBitmap(length, true);
+ byte_builder_.UnsafeAppend(/*num_copies=*/length * byte_width_, 0);
+ return Status::OK();
+}
+
+void FixedSizeBinaryBuilder::Reset() {
+ ArrayBuilder::Reset();
+ byte_builder_.Reset();
+}
+
+Status FixedSizeBinaryBuilder::Resize(int64_t capacity) {
+ RETURN_NOT_OK(CheckCapacity(capacity));
+ RETURN_NOT_OK(byte_builder_.Resize(capacity * byte_width_));
+ return ArrayBuilder::Resize(capacity);
+}
+
+Status FixedSizeBinaryBuilder::FinishInternal(std::shared_ptr<ArrayData>* out) {
+ std::shared_ptr<Buffer> data;
+ RETURN_NOT_OK(byte_builder_.Finish(&data));
+
+ std::shared_ptr<Buffer> null_bitmap;
+ RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
+ *out = ArrayData::Make(type(), length_, {null_bitmap, data}, null_count_);
+
+ capacity_ = length_ = null_count_ = 0;
+ return Status::OK();
+}
+
+const uint8_t* FixedSizeBinaryBuilder::GetValue(int64_t i) const {
+ const uint8_t* data_ptr = byte_builder_.data();
+ return data_ptr + i * byte_width_;
+}
+
+util::string_view FixedSizeBinaryBuilder::GetView(int64_t i) const {
+ const uint8_t* data_ptr = byte_builder_.data();
+ return util::string_view(reinterpret_cast<const char*>(data_ptr + i * byte_width_),
+ byte_width_);
+}
+
+// ----------------------------------------------------------------------
+// ChunkedArray builders
+
+namespace internal {
+
+ChunkedBinaryBuilder::ChunkedBinaryBuilder(int32_t max_chunk_value_length,
+ MemoryPool* pool)
+ : max_chunk_value_length_(max_chunk_value_length), builder_(new BinaryBuilder(pool)) {
+ DCHECK_LE(max_chunk_value_length, kBinaryMemoryLimit);
+}
+
+ChunkedBinaryBuilder::ChunkedBinaryBuilder(int32_t max_chunk_value_length,
+ int32_t max_chunk_length, MemoryPool* pool)
+ : ChunkedBinaryBuilder(max_chunk_value_length, pool) {
+ max_chunk_length_ = max_chunk_length;
+}
+
+Status ChunkedBinaryBuilder::Finish(ArrayVector* out) {
+ if (builder_->length() > 0 || chunks_.size() == 0) {
+ std::shared_ptr<Array> chunk;
+ RETURN_NOT_OK(builder_->Finish(&chunk));
+ chunks_.emplace_back(std::move(chunk));
+ }
+ *out = std::move(chunks_);
+ return Status::OK();
+}
+
+Status ChunkedBinaryBuilder::NextChunk() {
+ std::shared_ptr<Array> chunk;
+ RETURN_NOT_OK(builder_->Finish(&chunk));
+ chunks_.emplace_back(std::move(chunk));
+
+ if (auto capacity = extra_capacity_) {
+ extra_capacity_ = 0;
+ return Reserve(capacity);
+ }
+
+ return Status::OK();
+}
+
+Status ChunkedStringBuilder::Finish(ArrayVector* out) {
+ RETURN_NOT_OK(ChunkedBinaryBuilder::Finish(out));
+
+ // Change data type to string/utf8
+ for (size_t i = 0; i < out->size(); ++i) {
+ std::shared_ptr<ArrayData> data = (*out)[i]->data();
+ data->type = ::arrow::utf8();
+ (*out)[i] = std::make_shared<StringArray>(data);
+ }
+ return Status::OK();
+}
+
+Status ChunkedBinaryBuilder::Reserve(int64_t values) {
+ if (ARROW_PREDICT_FALSE(extra_capacity_ != 0)) {
+ extra_capacity_ += values;
+ return Status::OK();
+ }
+
+ auto current_capacity = builder_->capacity();
+ auto min_capacity = builder_->length() + values;
+ if (current_capacity >= min_capacity) {
+ return Status::OK();
+ }
+
+ auto new_capacity = BufferBuilder::GrowByFactor(current_capacity, min_capacity);
+ if (ARROW_PREDICT_TRUE(new_capacity <= max_chunk_length_)) {
+ return builder_->Resize(new_capacity);
+ }
+
+ extra_capacity_ = new_capacity - max_chunk_length_;
+ return builder_->Resize(max_chunk_length_);
+}
+
+} // namespace internal
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_binary.h b/src/arrow/cpp/src/arrow/array/builder_binary.h
new file mode 100644
index 000000000..6ca65113f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_binary.h
@@ -0,0 +1,697 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_binary.h"
+#include "arrow/array/builder_base.h"
+#include "arrow/array/data.h"
+#include "arrow/buffer.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_view.h" // IWYU pragma: export
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+// ----------------------------------------------------------------------
+// Binary and String
+
+template <typename TYPE>
+class BaseBinaryBuilder : public ArrayBuilder {
+ public:
+ using TypeClass = TYPE;
+ using offset_type = typename TypeClass::offset_type;
+
+ explicit BaseBinaryBuilder(MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool), offsets_builder_(pool), value_data_builder_(pool) {}
+
+ BaseBinaryBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool)
+ : BaseBinaryBuilder(pool) {}
+
+ Status Append(const uint8_t* value, offset_type length) {
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ ARROW_RETURN_NOT_OK(AppendNextOffset());
+ // Safety check for UBSAN.
+ if (ARROW_PREDICT_TRUE(length > 0)) {
+ ARROW_RETURN_NOT_OK(ValidateOverflow(length));
+ ARROW_RETURN_NOT_OK(value_data_builder_.Append(value, length));
+ }
+
+ UnsafeAppendToBitmap(true);
+ return Status::OK();
+ }
+
+ Status Append(const char* value, offset_type length) {
+ return Append(reinterpret_cast<const uint8_t*>(value), length);
+ }
+
+ Status Append(util::string_view value) {
+ return Append(value.data(), static_cast<offset_type>(value.size()));
+ }
+
+ /// Extend the last appended value by appending more data at the end
+ ///
+ /// Unlike Append, this does not create a new offset.
+ Status ExtendCurrent(const uint8_t* value, offset_type length) {
+ // Safety check for UBSAN.
+ if (ARROW_PREDICT_TRUE(length > 0)) {
+ ARROW_RETURN_NOT_OK(ValidateOverflow(length));
+ ARROW_RETURN_NOT_OK(value_data_builder_.Append(value, length));
+ }
+ return Status::OK();
+ }
+
+ Status ExtendCurrent(util::string_view value) {
+ return ExtendCurrent(reinterpret_cast<const uint8_t*>(value.data()),
+ static_cast<offset_type>(value.size()));
+ }
+
+ Status AppendNulls(int64_t length) final {
+ const int64_t num_bytes = value_data_builder_.length();
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ for (int64_t i = 0; i < length; ++i) {
+ offsets_builder_.UnsafeAppend(static_cast<offset_type>(num_bytes));
+ }
+ UnsafeAppendToBitmap(length, false);
+ return Status::OK();
+ }
+
+ Status AppendNull() final {
+ ARROW_RETURN_NOT_OK(AppendNextOffset());
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ UnsafeAppendToBitmap(false);
+ return Status::OK();
+ }
+
+ Status AppendEmptyValue() final {
+ ARROW_RETURN_NOT_OK(AppendNextOffset());
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ UnsafeAppendToBitmap(true);
+ return Status::OK();
+ }
+
+ Status AppendEmptyValues(int64_t length) final {
+ const int64_t num_bytes = value_data_builder_.length();
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ for (int64_t i = 0; i < length; ++i) {
+ offsets_builder_.UnsafeAppend(static_cast<offset_type>(num_bytes));
+ }
+ UnsafeAppendToBitmap(length, true);
+ return Status::OK();
+ }
+
+ /// \brief Append without checking capacity
+ ///
+ /// Offsets and data should have been presized using Reserve() and
+ /// ReserveData(), respectively.
+ void UnsafeAppend(const uint8_t* value, offset_type length) {
+ UnsafeAppendNextOffset();
+ value_data_builder_.UnsafeAppend(value, length);
+ UnsafeAppendToBitmap(true);
+ }
+
+ void UnsafeAppend(const char* value, offset_type length) {
+ UnsafeAppend(reinterpret_cast<const uint8_t*>(value), length);
+ }
+
+ void UnsafeAppend(const std::string& value) {
+ UnsafeAppend(value.c_str(), static_cast<offset_type>(value.size()));
+ }
+
+ void UnsafeAppend(util::string_view value) {
+ UnsafeAppend(value.data(), static_cast<offset_type>(value.size()));
+ }
+
+ /// Like ExtendCurrent, but do not check capacity
+ void UnsafeExtendCurrent(const uint8_t* value, offset_type length) {
+ value_data_builder_.UnsafeAppend(value, length);
+ }
+
+ void UnsafeExtendCurrent(util::string_view value) {
+ UnsafeExtendCurrent(reinterpret_cast<const uint8_t*>(value.data()),
+ static_cast<offset_type>(value.size()));
+ }
+
+ void UnsafeAppendNull() {
+ const int64_t num_bytes = value_data_builder_.length();
+ offsets_builder_.UnsafeAppend(static_cast<offset_type>(num_bytes));
+ UnsafeAppendToBitmap(false);
+ }
+
+ void UnsafeAppendEmptyValue() {
+ const int64_t num_bytes = value_data_builder_.length();
+ offsets_builder_.UnsafeAppend(static_cast<offset_type>(num_bytes));
+ UnsafeAppendToBitmap(true);
+ }
+
+ /// \brief Append a sequence of strings in one shot.
+ ///
+ /// \param[in] values a vector of strings
+ /// \param[in] valid_bytes an optional sequence of bytes where non-zero
+ /// indicates a valid (non-null) value
+ /// \return Status
+ Status AppendValues(const std::vector<std::string>& values,
+ const uint8_t* valid_bytes = NULLPTR) {
+ std::size_t total_length = std::accumulate(
+ values.begin(), values.end(), 0ULL,
+ [](uint64_t sum, const std::string& str) { return sum + str.size(); });
+ ARROW_RETURN_NOT_OK(Reserve(values.size()));
+ ARROW_RETURN_NOT_OK(value_data_builder_.Reserve(total_length));
+ ARROW_RETURN_NOT_OK(offsets_builder_.Reserve(values.size()));
+
+ if (valid_bytes != NULLPTR) {
+ for (std::size_t i = 0; i < values.size(); ++i) {
+ UnsafeAppendNextOffset();
+ if (valid_bytes[i]) {
+ value_data_builder_.UnsafeAppend(
+ reinterpret_cast<const uint8_t*>(values[i].data()), values[i].size());
+ }
+ }
+ } else {
+ for (std::size_t i = 0; i < values.size(); ++i) {
+ UnsafeAppendNextOffset();
+ value_data_builder_.UnsafeAppend(
+ reinterpret_cast<const uint8_t*>(values[i].data()), values[i].size());
+ }
+ }
+
+ UnsafeAppendToBitmap(valid_bytes, values.size());
+ return Status::OK();
+ }
+
+ /// \brief Append a sequence of nul-terminated strings in one shot.
+ /// If one of the values is NULL, it is processed as a null
+ /// value even if the corresponding valid_bytes entry is 1.
+ ///
+ /// \param[in] values a contiguous C array of nul-terminated char *
+ /// \param[in] length the number of values to append
+ /// \param[in] valid_bytes an optional sequence of bytes where non-zero
+ /// indicates a valid (non-null) value
+ /// \return Status
+ Status AppendValues(const char** values, int64_t length,
+ const uint8_t* valid_bytes = NULLPTR) {
+ std::size_t total_length = 0;
+ std::vector<std::size_t> value_lengths(length);
+ bool have_null_value = false;
+ for (int64_t i = 0; i < length; ++i) {
+ if (values[i] != NULLPTR) {
+ auto value_length = strlen(values[i]);
+ value_lengths[i] = value_length;
+ total_length += value_length;
+ } else {
+ have_null_value = true;
+ }
+ }
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ ARROW_RETURN_NOT_OK(ReserveData(total_length));
+
+ if (valid_bytes) {
+ int64_t valid_bytes_offset = 0;
+ for (int64_t i = 0; i < length; ++i) {
+ UnsafeAppendNextOffset();
+ if (valid_bytes[i]) {
+ if (values[i]) {
+ value_data_builder_.UnsafeAppend(reinterpret_cast<const uint8_t*>(values[i]),
+ value_lengths[i]);
+ } else {
+ UnsafeAppendToBitmap(valid_bytes + valid_bytes_offset,
+ i - valid_bytes_offset);
+ UnsafeAppendToBitmap(false);
+ valid_bytes_offset = i + 1;
+ }
+ }
+ }
+ UnsafeAppendToBitmap(valid_bytes + valid_bytes_offset, length - valid_bytes_offset);
+ } else {
+ if (have_null_value) {
+ std::vector<uint8_t> valid_vector(length, 0);
+ for (int64_t i = 0; i < length; ++i) {
+ UnsafeAppendNextOffset();
+ if (values[i]) {
+ value_data_builder_.UnsafeAppend(reinterpret_cast<const uint8_t*>(values[i]),
+ value_lengths[i]);
+ valid_vector[i] = 1;
+ }
+ }
+ UnsafeAppendToBitmap(valid_vector.data(), length);
+ } else {
+ for (int64_t i = 0; i < length; ++i) {
+ UnsafeAppendNextOffset();
+ value_data_builder_.UnsafeAppend(reinterpret_cast<const uint8_t*>(values[i]),
+ value_lengths[i]);
+ }
+ UnsafeAppendToBitmap(NULLPTR, length);
+ }
+ }
+ return Status::OK();
+ }
+
+ Status AppendArraySlice(const ArrayData& array, int64_t offset,
+ int64_t length) override {
+ auto bitmap = array.GetValues<uint8_t>(0, 0);
+ auto offsets = array.GetValues<offset_type>(1);
+ auto data = array.GetValues<uint8_t>(2, 0);
+ for (int64_t i = 0; i < length; i++) {
+ if (!bitmap || BitUtil::GetBit(bitmap, array.offset + offset + i)) {
+ const offset_type start = offsets[offset + i];
+ const offset_type end = offsets[offset + i + 1];
+ ARROW_RETURN_NOT_OK(Append(data + start, end - start));
+ } else {
+ ARROW_RETURN_NOT_OK(AppendNull());
+ }
+ }
+ return Status::OK();
+ }
+
+ void Reset() override {
+ ArrayBuilder::Reset();
+ offsets_builder_.Reset();
+ value_data_builder_.Reset();
+ }
+
+ Status ValidateOverflow(int64_t new_bytes) {
+ auto new_size = value_data_builder_.length() + new_bytes;
+ if (ARROW_PREDICT_FALSE(new_size > memory_limit())) {
+ return Status::CapacityError("array cannot contain more than ", memory_limit(),
+ " bytes, have ", new_size);
+ } else {
+ return Status::OK();
+ }
+ }
+
+ Status Resize(int64_t capacity) override {
+ ARROW_RETURN_NOT_OK(CheckCapacity(capacity));
+ // One more than requested for offsets
+ ARROW_RETURN_NOT_OK(offsets_builder_.Resize(capacity + 1));
+ return ArrayBuilder::Resize(capacity);
+ }
+
+ /// \brief Ensures there is enough allocated capacity to append the indicated
+ /// number of bytes to the value data buffer without additional allocations
+ Status ReserveData(int64_t elements) {
+ ARROW_RETURN_NOT_OK(ValidateOverflow(elements));
+ return value_data_builder_.Reserve(elements);
+ }
+
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
+ // Write final offset (values length)
+ ARROW_RETURN_NOT_OK(AppendNextOffset());
+
+ // These buffers' padding zeroed by BufferBuilder
+ std::shared_ptr<Buffer> offsets, value_data, null_bitmap;
+ ARROW_RETURN_NOT_OK(offsets_builder_.Finish(&offsets));
+ ARROW_RETURN_NOT_OK(value_data_builder_.Finish(&value_data));
+ ARROW_RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
+
+ *out = ArrayData::Make(type(), length_, {null_bitmap, offsets, value_data},
+ null_count_, 0);
+ Reset();
+ return Status::OK();
+ }
+
+ /// \return data pointer of the value date builder
+ const uint8_t* value_data() const { return value_data_builder_.data(); }
+ /// \return size of values buffer so far
+ int64_t value_data_length() const { return value_data_builder_.length(); }
+ /// \return capacity of values buffer
+ int64_t value_data_capacity() const { return value_data_builder_.capacity(); }
+
+ /// \return data pointer of the value date builder
+ const offset_type* offsets_data() const { return offsets_builder_.data(); }
+
+ /// Temporary access to a value.
+ ///
+ /// This pointer becomes invalid on the next modifying operation.
+ const uint8_t* GetValue(int64_t i, offset_type* out_length) const {
+ const offset_type* offsets = offsets_builder_.data();
+ const auto offset = offsets[i];
+ if (i == (length_ - 1)) {
+ *out_length = static_cast<offset_type>(value_data_builder_.length()) - offset;
+ } else {
+ *out_length = offsets[i + 1] - offset;
+ }
+ return value_data_builder_.data() + offset;
+ }
+
+ offset_type offset(int64_t i) const { return offsets_data()[i]; }
+
+ /// Temporary access to a value.
+ ///
+ /// This view becomes invalid on the next modifying operation.
+ util::string_view GetView(int64_t i) const {
+ offset_type value_length;
+ const uint8_t* value_data = GetValue(i, &value_length);
+ return util::string_view(reinterpret_cast<const char*>(value_data), value_length);
+ }
+
+ // Cannot make this a static attribute because of linking issues
+ static constexpr int64_t memory_limit() {
+ return std::numeric_limits<offset_type>::max() - 1;
+ }
+
+ protected:
+ TypedBufferBuilder<offset_type> offsets_builder_;
+ TypedBufferBuilder<uint8_t> value_data_builder_;
+
+ Status AppendNextOffset() {
+ const int64_t num_bytes = value_data_builder_.length();
+ return offsets_builder_.Append(static_cast<offset_type>(num_bytes));
+ }
+
+ void UnsafeAppendNextOffset() {
+ const int64_t num_bytes = value_data_builder_.length();
+ offsets_builder_.UnsafeAppend(static_cast<offset_type>(num_bytes));
+ }
+};
+
+/// \class BinaryBuilder
+/// \brief Builder class for variable-length binary data
+class ARROW_EXPORT BinaryBuilder : public BaseBinaryBuilder<BinaryType> {
+ public:
+ using BaseBinaryBuilder::BaseBinaryBuilder;
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<BinaryArray>* out) { return FinishTyped(out); }
+
+ std::shared_ptr<DataType> type() const override { return binary(); }
+};
+
+/// \class StringBuilder
+/// \brief Builder class for UTF8 strings
+class ARROW_EXPORT StringBuilder : public BinaryBuilder {
+ public:
+ using BinaryBuilder::BinaryBuilder;
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<StringArray>* out) { return FinishTyped(out); }
+
+ std::shared_ptr<DataType> type() const override { return utf8(); }
+};
+
+/// \class LargeBinaryBuilder
+/// \brief Builder class for large variable-length binary data
+class ARROW_EXPORT LargeBinaryBuilder : public BaseBinaryBuilder<LargeBinaryType> {
+ public:
+ using BaseBinaryBuilder::BaseBinaryBuilder;
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<LargeBinaryArray>* out) { return FinishTyped(out); }
+
+ std::shared_ptr<DataType> type() const override { return large_binary(); }
+};
+
+/// \class LargeStringBuilder
+/// \brief Builder class for large UTF8 strings
+class ARROW_EXPORT LargeStringBuilder : public LargeBinaryBuilder {
+ public:
+ using LargeBinaryBuilder::LargeBinaryBuilder;
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<LargeStringArray>* out) { return FinishTyped(out); }
+
+ std::shared_ptr<DataType> type() const override { return large_utf8(); }
+};
+
+// ----------------------------------------------------------------------
+// FixedSizeBinaryBuilder
+
+class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder {
+ public:
+ using TypeClass = FixedSizeBinaryType;
+
+ explicit FixedSizeBinaryBuilder(const std::shared_ptr<DataType>& type,
+ MemoryPool* pool = default_memory_pool());
+
+ Status Append(const uint8_t* value) {
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ UnsafeAppend(value);
+ return Status::OK();
+ }
+
+ Status Append(const char* value) {
+ return Append(reinterpret_cast<const uint8_t*>(value));
+ }
+
+ Status Append(const util::string_view& view) {
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ UnsafeAppend(view);
+ return Status::OK();
+ }
+
+ Status Append(const std::string& s) {
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ UnsafeAppend(s);
+ return Status::OK();
+ }
+
+ Status Append(const Buffer& s) {
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ UnsafeAppend(util::string_view(s));
+ return Status::OK();
+ }
+
+ Status Append(const std::shared_ptr<Buffer>& s) { return Append(*s); }
+
+ template <size_t NBYTES>
+ Status Append(const std::array<uint8_t, NBYTES>& value) {
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ UnsafeAppend(
+ util::string_view(reinterpret_cast<const char*>(value.data()), value.size()));
+ return Status::OK();
+ }
+
+ Status AppendValues(const uint8_t* data, int64_t length,
+ const uint8_t* valid_bytes = NULLPTR);
+
+ Status AppendValues(const uint8_t* data, int64_t length, const uint8_t* validity,
+ int64_t bitmap_offset);
+
+ Status AppendNull() final;
+ Status AppendNulls(int64_t length) final;
+
+ Status AppendEmptyValue() final;
+ Status AppendEmptyValues(int64_t length) final;
+
+ Status AppendArraySlice(const ArrayData& array, int64_t offset,
+ int64_t length) override {
+ return AppendValues(
+ array.GetValues<uint8_t>(1, 0) + ((array.offset + offset) * byte_width_), length,
+ array.GetValues<uint8_t>(0, 0), array.offset + offset);
+ }
+
+ void UnsafeAppend(const uint8_t* value) {
+ UnsafeAppendToBitmap(true);
+ if (ARROW_PREDICT_TRUE(byte_width_ > 0)) {
+ byte_builder_.UnsafeAppend(value, byte_width_);
+ }
+ }
+
+ void UnsafeAppend(const char* value) {
+ UnsafeAppend(reinterpret_cast<const uint8_t*>(value));
+ }
+
+ void UnsafeAppend(util::string_view value) {
+#ifndef NDEBUG
+ CheckValueSize(static_cast<size_t>(value.size()));
+#endif
+ UnsafeAppend(reinterpret_cast<const uint8_t*>(value.data()));
+ }
+
+ void UnsafeAppend(const Buffer& s) { UnsafeAppend(util::string_view(s)); }
+
+ void UnsafeAppend(const std::shared_ptr<Buffer>& s) { UnsafeAppend(*s); }
+
+ void UnsafeAppendNull() {
+ UnsafeAppendToBitmap(false);
+ byte_builder_.UnsafeAppend(/*num_copies=*/byte_width_, 0);
+ }
+
+ Status ValidateOverflow(int64_t new_bytes) const {
+ auto new_size = byte_builder_.length() + new_bytes;
+ if (ARROW_PREDICT_FALSE(new_size > memory_limit())) {
+ return Status::CapacityError("array cannot contain more than ", memory_limit(),
+ " bytes, have ", new_size);
+ } else {
+ return Status::OK();
+ }
+ }
+
+ /// \brief Ensures there is enough allocated capacity to append the indicated
+ /// number of bytes to the value data buffer without additional allocations
+ Status ReserveData(int64_t elements) {
+ ARROW_RETURN_NOT_OK(ValidateOverflow(elements));
+ return byte_builder_.Reserve(elements);
+ }
+
+ void Reset() override;
+ Status Resize(int64_t capacity) override;
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<FixedSizeBinaryArray>* out) { return FinishTyped(out); }
+
+ /// \return size of values buffer so far
+ int64_t value_data_length() const { return byte_builder_.length(); }
+
+ int32_t byte_width() const { return byte_width_; }
+
+ /// Temporary access to a value.
+ ///
+ /// This pointer becomes invalid on the next modifying operation.
+ const uint8_t* GetValue(int64_t i) const;
+
+ /// Temporary access to a value.
+ ///
+ /// This view becomes invalid on the next modifying operation.
+ util::string_view GetView(int64_t i) const;
+
+ static constexpr int64_t memory_limit() {
+ return std::numeric_limits<int64_t>::max() - 1;
+ }
+
+ std::shared_ptr<DataType> type() const override {
+ return fixed_size_binary(byte_width_);
+ }
+
+ protected:
+ int32_t byte_width_;
+ BufferBuilder byte_builder_;
+
+ /// Temporary access to a value.
+ ///
+ /// This pointer becomes invalid on the next modifying operation.
+ uint8_t* GetMutableValue(int64_t i) {
+ uint8_t* data_ptr = byte_builder_.mutable_data();
+ return data_ptr + i * byte_width_;
+ }
+
+ void CheckValueSize(int64_t size);
+};
+
+// ----------------------------------------------------------------------
+// Chunked builders: build a sequence of BinaryArray or StringArray that are
+// limited to a particular size (to the upper limit of 2GB)
+
+namespace internal {
+
+class ARROW_EXPORT ChunkedBinaryBuilder {
+ public:
+ explicit ChunkedBinaryBuilder(int32_t max_chunk_value_length,
+ MemoryPool* pool = default_memory_pool());
+
+ ChunkedBinaryBuilder(int32_t max_chunk_value_length, int32_t max_chunk_length,
+ MemoryPool* pool = default_memory_pool());
+
+ virtual ~ChunkedBinaryBuilder() = default;
+
+ Status Append(const uint8_t* value, int32_t length) {
+ if (ARROW_PREDICT_FALSE(length + builder_->value_data_length() >
+ max_chunk_value_length_)) {
+ if (builder_->value_data_length() == 0) {
+ // The current item is larger than max_chunk_size_;
+ // this chunk will be oversize and hold *only* this item
+ ARROW_RETURN_NOT_OK(builder_->Append(value, length));
+ return NextChunk();
+ }
+ // The current item would cause builder_->value_data_length() to exceed
+ // max_chunk_size_, so finish this chunk and append the current item to the next
+ // chunk
+ ARROW_RETURN_NOT_OK(NextChunk());
+ return Append(value, length);
+ }
+
+ if (ARROW_PREDICT_FALSE(builder_->length() == max_chunk_length_)) {
+ // The current item would cause builder_->length() to exceed max_chunk_length_, so
+ // finish this chunk and append the current item to the next chunk
+ ARROW_RETURN_NOT_OK(NextChunk());
+ }
+
+ return builder_->Append(value, length);
+ }
+
+ Status Append(const util::string_view& value) {
+ return Append(reinterpret_cast<const uint8_t*>(value.data()),
+ static_cast<int32_t>(value.size()));
+ }
+
+ Status AppendNull() {
+ if (ARROW_PREDICT_FALSE(builder_->length() == max_chunk_length_)) {
+ ARROW_RETURN_NOT_OK(NextChunk());
+ }
+ return builder_->AppendNull();
+ }
+
+ Status Reserve(int64_t values);
+
+ virtual Status Finish(ArrayVector* out);
+
+ protected:
+ Status NextChunk();
+
+ // maximum total character data size per chunk
+ int64_t max_chunk_value_length_;
+
+ // maximum elements allowed per chunk
+ int64_t max_chunk_length_ = kListMaximumElements;
+
+ // when Reserve() would cause builder_ to exceed its max_chunk_length_,
+ // add to extra_capacity_ instead and wait to reserve until the next chunk
+ int64_t extra_capacity_ = 0;
+
+ std::unique_ptr<BinaryBuilder> builder_;
+ std::vector<std::shared_ptr<Array>> chunks_;
+};
+
+class ARROW_EXPORT ChunkedStringBuilder : public ChunkedBinaryBuilder {
+ public:
+ using ChunkedBinaryBuilder::ChunkedBinaryBuilder;
+
+ Status Finish(ArrayVector* out) override;
+};
+
+} // namespace internal
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_decimal.cc b/src/arrow/cpp/src/arrow/array/builder_decimal.cc
new file mode 100644
index 000000000..bd7615a73
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_decimal.cc
@@ -0,0 +1,105 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/builder_decimal.h"
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/array/data.h"
+#include "arrow/buffer.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/status.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+
+namespace arrow {
+
+class Buffer;
+class MemoryPool;
+
+// ----------------------------------------------------------------------
+// Decimal128Builder
+
+Decimal128Builder::Decimal128Builder(const std::shared_ptr<DataType>& type,
+ MemoryPool* pool)
+ : FixedSizeBinaryBuilder(type, pool),
+ decimal_type_(internal::checked_pointer_cast<Decimal128Type>(type)) {}
+
+Status Decimal128Builder::Append(Decimal128 value) {
+ RETURN_NOT_OK(FixedSizeBinaryBuilder::Reserve(1));
+ UnsafeAppend(value);
+ return Status::OK();
+}
+
+void Decimal128Builder::UnsafeAppend(Decimal128 value) {
+ value.ToBytes(GetMutableValue(length()));
+ byte_builder_.UnsafeAdvance(16);
+ UnsafeAppendToBitmap(true);
+}
+
+void Decimal128Builder::UnsafeAppend(util::string_view value) {
+ FixedSizeBinaryBuilder::UnsafeAppend(value);
+}
+
+Status Decimal128Builder::FinishInternal(std::shared_ptr<ArrayData>* out) {
+ std::shared_ptr<Buffer> data;
+ RETURN_NOT_OK(byte_builder_.Finish(&data));
+ std::shared_ptr<Buffer> null_bitmap;
+ RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
+
+ *out = ArrayData::Make(type(), length_, {null_bitmap, data}, null_count_);
+ capacity_ = length_ = null_count_ = 0;
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Decimal256Builder
+
+Decimal256Builder::Decimal256Builder(const std::shared_ptr<DataType>& type,
+ MemoryPool* pool)
+ : FixedSizeBinaryBuilder(type, pool),
+ decimal_type_(internal::checked_pointer_cast<Decimal256Type>(type)) {}
+
+Status Decimal256Builder::Append(const Decimal256& value) {
+ RETURN_NOT_OK(FixedSizeBinaryBuilder::Reserve(1));
+ UnsafeAppend(value);
+ return Status::OK();
+}
+
+void Decimal256Builder::UnsafeAppend(const Decimal256& value) {
+ value.ToBytes(GetMutableValue(length()));
+ byte_builder_.UnsafeAdvance(32);
+ UnsafeAppendToBitmap(true);
+}
+
+void Decimal256Builder::UnsafeAppend(util::string_view value) {
+ FixedSizeBinaryBuilder::UnsafeAppend(value);
+}
+
+Status Decimal256Builder::FinishInternal(std::shared_ptr<ArrayData>* out) {
+ std::shared_ptr<Buffer> data;
+ RETURN_NOT_OK(byte_builder_.Finish(&data));
+ std::shared_ptr<Buffer> null_bitmap;
+ RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
+
+ *out = ArrayData::Make(type(), length_, {null_bitmap, data}, null_count_);
+ capacity_ = length_ = null_count_ = 0;
+ return Status::OK();
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_decimal.h b/src/arrow/cpp/src/arrow/array/builder_decimal.h
new file mode 100644
index 000000000..f48392ed0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_decimal.h
@@ -0,0 +1,94 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/builder_base.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/data.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class ARROW_EXPORT Decimal128Builder : public FixedSizeBinaryBuilder {
+ public:
+ using TypeClass = Decimal128Type;
+ using ValueType = Decimal128;
+
+ explicit Decimal128Builder(const std::shared_ptr<DataType>& type,
+ MemoryPool* pool = default_memory_pool());
+
+ using FixedSizeBinaryBuilder::Append;
+ using FixedSizeBinaryBuilder::AppendValues;
+ using FixedSizeBinaryBuilder::Reset;
+
+ Status Append(Decimal128 val);
+ void UnsafeAppend(Decimal128 val);
+ void UnsafeAppend(util::string_view val);
+
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<Decimal128Array>* out) { return FinishTyped(out); }
+
+ std::shared_ptr<DataType> type() const override { return decimal_type_; }
+
+ protected:
+ std::shared_ptr<Decimal128Type> decimal_type_;
+};
+
+class ARROW_EXPORT Decimal256Builder : public FixedSizeBinaryBuilder {
+ public:
+ using TypeClass = Decimal256Type;
+ using ValueType = Decimal256;
+
+ explicit Decimal256Builder(const std::shared_ptr<DataType>& type,
+ MemoryPool* pool = default_memory_pool());
+
+ using FixedSizeBinaryBuilder::Append;
+ using FixedSizeBinaryBuilder::AppendValues;
+ using FixedSizeBinaryBuilder::Reset;
+
+ Status Append(const Decimal256& val);
+ void UnsafeAppend(const Decimal256& val);
+ void UnsafeAppend(util::string_view val);
+
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<Decimal256Array>* out) { return FinishTyped(out); }
+
+ std::shared_ptr<DataType> type() const override { return decimal_type_; }
+
+ protected:
+ std::shared_ptr<Decimal256Type> decimal_type_;
+};
+
+using DecimalBuilder = Decimal128Builder;
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_dict.cc b/src/arrow/cpp/src/arrow/array/builder_dict.cc
new file mode 100644
index 000000000..d24731699
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_dict.cc
@@ -0,0 +1,213 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/builder_dict.h"
+
+#include <cstdint>
+#include <utility>
+
+#include "arrow/array/dict_internal.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/hashing.h"
+#include "arrow/util/logging.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+// ----------------------------------------------------------------------
+// DictionaryBuilder
+
+namespace internal {
+
+class DictionaryMemoTable::DictionaryMemoTableImpl {
+ // Type-dependent visitor for memo table initialization
+ struct MemoTableInitializer {
+ std::shared_ptr<DataType> value_type_;
+ MemoryPool* pool_;
+ std::unique_ptr<MemoTable>* memo_table_;
+
+ template <typename T>
+ enable_if_no_memoize<T, Status> Visit(const T&) {
+ return Status::NotImplemented("Initialization of ", value_type_->ToString(),
+ " memo table is not implemented");
+ }
+
+ template <typename T>
+ enable_if_memoize<T, Status> Visit(const T&) {
+ using MemoTable = typename DictionaryTraits<T>::MemoTableType;
+ memo_table_->reset(new MemoTable(pool_, 0));
+ return Status::OK();
+ }
+ };
+
+ // Type-dependent visitor for memo table insertion
+ struct ArrayValuesInserter {
+ DictionaryMemoTableImpl* impl_;
+ const Array& values_;
+
+ template <typename T>
+ Status Visit(const T& type) {
+ using ArrayType = typename TypeTraits<T>::ArrayType;
+ return InsertValues(type, checked_cast<const ArrayType&>(values_));
+ }
+
+ private:
+ template <typename T, typename ArrayType>
+ enable_if_no_memoize<T, Status> InsertValues(const T& type, const ArrayType&) {
+ return Status::NotImplemented("Inserting array values of ", type,
+ " is not implemented");
+ }
+
+ template <typename T, typename ArrayType>
+ enable_if_memoize<T, Status> InsertValues(const T&, const ArrayType& array) {
+ if (array.null_count() > 0) {
+ return Status::Invalid("Cannot insert dictionary values containing nulls");
+ }
+ for (int64_t i = 0; i < array.length(); ++i) {
+ int32_t unused_memo_index;
+ RETURN_NOT_OK(impl_->GetOrInsert<T>(array.GetView(i), &unused_memo_index));
+ }
+ return Status::OK();
+ }
+ };
+
+ // Type-dependent visitor for building ArrayData from memo table
+ struct ArrayDataGetter {
+ std::shared_ptr<DataType> value_type_;
+ MemoTable* memo_table_;
+ MemoryPool* pool_;
+ int64_t start_offset_;
+ std::shared_ptr<ArrayData>* out_;
+
+ template <typename T>
+ enable_if_no_memoize<T, Status> Visit(const T&) {
+ return Status::NotImplemented("Getting array data of ", value_type_,
+ " is not implemented");
+ }
+
+ template <typename T>
+ enable_if_memoize<T, Status> Visit(const T&) {
+ using ConcreteMemoTable = typename DictionaryTraits<T>::MemoTableType;
+ auto memo_table = checked_cast<ConcreteMemoTable*>(memo_table_);
+ return DictionaryTraits<T>::GetDictionaryArrayData(pool_, value_type_, *memo_table,
+ start_offset_, out_);
+ }
+ };
+
+ public:
+ DictionaryMemoTableImpl(MemoryPool* pool, std::shared_ptr<DataType> type)
+ : pool_(pool), type_(std::move(type)), memo_table_(nullptr) {
+ MemoTableInitializer visitor{type_, pool_, &memo_table_};
+ ARROW_CHECK_OK(VisitTypeInline(*type_, &visitor));
+ }
+
+ Status InsertValues(const Array& array) {
+ if (!array.type()->Equals(*type_)) {
+ return Status::Invalid("Array value type does not match memo type: ",
+ array.type()->ToString());
+ }
+ ArrayValuesInserter visitor{this, array};
+ return VisitTypeInline(*array.type(), &visitor);
+ }
+
+ template <typename PhysicalType,
+ typename CType = typename DictionaryValue<PhysicalType>::type>
+ Status GetOrInsert(CType value, int32_t* out) {
+ using ConcreteMemoTable = typename DictionaryTraits<PhysicalType>::MemoTableType;
+ return checked_cast<ConcreteMemoTable*>(memo_table_.get())->GetOrInsert(value, out);
+ }
+
+ Status GetArrayData(int64_t start_offset, std::shared_ptr<ArrayData>* out) {
+ ArrayDataGetter visitor{type_, memo_table_.get(), pool_, start_offset, out};
+ return VisitTypeInline(*type_, &visitor);
+ }
+
+ int32_t size() const { return memo_table_->size(); }
+
+ private:
+ MemoryPool* pool_;
+ std::shared_ptr<DataType> type_;
+ std::unique_ptr<MemoTable> memo_table_;
+};
+
+DictionaryMemoTable::DictionaryMemoTable(MemoryPool* pool,
+ const std::shared_ptr<DataType>& type)
+ : impl_(new DictionaryMemoTableImpl(pool, type)) {}
+
+DictionaryMemoTable::DictionaryMemoTable(MemoryPool* pool,
+ const std::shared_ptr<Array>& dictionary)
+ : impl_(new DictionaryMemoTableImpl(pool, dictionary->type())) {
+ ARROW_CHECK_OK(impl_->InsertValues(*dictionary));
+}
+
+DictionaryMemoTable::~DictionaryMemoTable() = default;
+
+#define GET_OR_INSERT(ARROW_TYPE) \
+ Status DictionaryMemoTable::GetOrInsert( \
+ const ARROW_TYPE*, typename ARROW_TYPE::c_type value, int32_t* out) { \
+ return impl_->GetOrInsert<ARROW_TYPE>(value, out); \
+ }
+
+GET_OR_INSERT(BooleanType)
+GET_OR_INSERT(Int8Type)
+GET_OR_INSERT(Int16Type)
+GET_OR_INSERT(Int32Type)
+GET_OR_INSERT(Int64Type)
+GET_OR_INSERT(UInt8Type)
+GET_OR_INSERT(UInt16Type)
+GET_OR_INSERT(UInt32Type)
+GET_OR_INSERT(UInt64Type)
+GET_OR_INSERT(FloatType)
+GET_OR_INSERT(DoubleType)
+GET_OR_INSERT(DurationType);
+GET_OR_INSERT(TimestampType);
+GET_OR_INSERT(Date32Type);
+GET_OR_INSERT(Date64Type);
+GET_OR_INSERT(Time32Type);
+GET_OR_INSERT(Time64Type);
+GET_OR_INSERT(MonthDayNanoIntervalType);
+GET_OR_INSERT(DayTimeIntervalType);
+GET_OR_INSERT(MonthIntervalType);
+
+#undef GET_OR_INSERT
+
+Status DictionaryMemoTable::GetOrInsert(const BinaryType*, util::string_view value,
+ int32_t* out) {
+ return impl_->GetOrInsert<BinaryType>(value, out);
+}
+
+Status DictionaryMemoTable::GetOrInsert(const LargeBinaryType*, util::string_view value,
+ int32_t* out) {
+ return impl_->GetOrInsert<LargeBinaryType>(value, out);
+}
+
+Status DictionaryMemoTable::GetArrayData(int64_t start_offset,
+ std::shared_ptr<ArrayData>* out) {
+ return impl_->GetArrayData(start_offset, out);
+}
+
+Status DictionaryMemoTable::InsertValues(const Array& array) {
+ return impl_->InsertValues(array);
+}
+
+int32_t DictionaryMemoTable::size() const { return impl_->size(); }
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_dict.h b/src/arrow/cpp/src/arrow/array/builder_dict.h
new file mode 100644
index 000000000..76199deac
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_dict.h
@@ -0,0 +1,712 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <type_traits>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_binary.h"
+#include "arrow/array/builder_adaptive.h" // IWYU pragma: export
+#include "arrow/array/builder_base.h" // IWYU pragma: export
+#include "arrow/array/builder_primitive.h" // IWYU pragma: export
+#include "arrow/array/data.h"
+#include "arrow/array/util.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+// ----------------------------------------------------------------------
+// Dictionary builder
+
+namespace internal {
+
+template <typename T, typename Enable = void>
+struct DictionaryValue {
+ using type = typename T::c_type;
+ using PhysicalType = T;
+};
+
+template <typename T>
+struct DictionaryValue<T, enable_if_base_binary<T>> {
+ using type = util::string_view;
+ using PhysicalType =
+ typename std::conditional<std::is_same<typename T::offset_type, int32_t>::value,
+ BinaryType, LargeBinaryType>::type;
+};
+
+template <typename T>
+struct DictionaryValue<T, enable_if_fixed_size_binary<T>> {
+ using type = util::string_view;
+ using PhysicalType = BinaryType;
+};
+
+class ARROW_EXPORT DictionaryMemoTable {
+ public:
+ DictionaryMemoTable(MemoryPool* pool, const std::shared_ptr<DataType>& type);
+ DictionaryMemoTable(MemoryPool* pool, const std::shared_ptr<Array>& dictionary);
+ ~DictionaryMemoTable();
+
+ Status GetArrayData(int64_t start_offset, std::shared_ptr<ArrayData>* out);
+
+ /// \brief Insert new memo values
+ Status InsertValues(const Array& values);
+
+ int32_t size() const;
+
+ template <typename T>
+ Status GetOrInsert(typename DictionaryValue<T>::type value, int32_t* out) {
+ // We want to keep the DictionaryMemoTable implementation private, also we can't
+ // use extern template classes because of compiler issues (MinGW?). Instead,
+ // we expose explicit function overrides for each supported physical type.
+ const typename DictionaryValue<T>::PhysicalType* physical_type = NULLPTR;
+ return GetOrInsert(physical_type, value, out);
+ }
+
+ private:
+ Status GetOrInsert(const BooleanType*, bool value, int32_t* out);
+ Status GetOrInsert(const Int8Type*, int8_t value, int32_t* out);
+ Status GetOrInsert(const Int16Type*, int16_t value, int32_t* out);
+ Status GetOrInsert(const Int32Type*, int32_t value, int32_t* out);
+ Status GetOrInsert(const Int64Type*, int64_t value, int32_t* out);
+ Status GetOrInsert(const UInt8Type*, uint8_t value, int32_t* out);
+ Status GetOrInsert(const UInt16Type*, uint16_t value, int32_t* out);
+ Status GetOrInsert(const UInt32Type*, uint32_t value, int32_t* out);
+ Status GetOrInsert(const UInt64Type*, uint64_t value, int32_t* out);
+ Status GetOrInsert(const DurationType*, int64_t value, int32_t* out);
+ Status GetOrInsert(const TimestampType*, int64_t value, int32_t* out);
+ Status GetOrInsert(const Date32Type*, int32_t value, int32_t* out);
+ Status GetOrInsert(const Date64Type*, int64_t value, int32_t* out);
+ Status GetOrInsert(const Time32Type*, int32_t value, int32_t* out);
+ Status GetOrInsert(const Time64Type*, int64_t value, int32_t* out);
+ Status GetOrInsert(const MonthDayNanoIntervalType*,
+ MonthDayNanoIntervalType::MonthDayNanos value, int32_t* out);
+ Status GetOrInsert(const DayTimeIntervalType*,
+ DayTimeIntervalType::DayMilliseconds value, int32_t* out);
+ Status GetOrInsert(const MonthIntervalType*, int32_t value, int32_t* out);
+ Status GetOrInsert(const FloatType*, float value, int32_t* out);
+ Status GetOrInsert(const DoubleType*, double value, int32_t* out);
+
+ Status GetOrInsert(const BinaryType*, util::string_view value, int32_t* out);
+ Status GetOrInsert(const LargeBinaryType*, util::string_view value, int32_t* out);
+
+ class DictionaryMemoTableImpl;
+ std::unique_ptr<DictionaryMemoTableImpl> impl_;
+};
+
+/// \brief Array builder for created encoded DictionaryArray from
+/// dense array
+///
+/// Unlike other builders, dictionary builder does not completely
+/// reset the state on Finish calls.
+template <typename BuilderType, typename T>
+class DictionaryBuilderBase : public ArrayBuilder {
+ public:
+ using TypeClass = DictionaryType;
+ using Value = typename DictionaryValue<T>::type;
+
+ // WARNING: the type given below is the value type, not the DictionaryType.
+ // The DictionaryType is instantiated on the Finish() call.
+ template <typename B = BuilderType, typename T1 = T>
+ DictionaryBuilderBase(uint8_t start_int_size,
+ enable_if_t<std::is_base_of<AdaptiveIntBuilderBase, B>::value &&
+ !is_fixed_size_binary_type<T1>::value,
+ const std::shared_ptr<DataType>&>
+ value_type,
+ MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool),
+ memo_table_(new internal::DictionaryMemoTable(pool, value_type)),
+ delta_offset_(0),
+ byte_width_(-1),
+ indices_builder_(start_int_size, pool),
+ value_type_(value_type) {}
+
+ template <typename T1 = T>
+ explicit DictionaryBuilderBase(
+ enable_if_t<!is_fixed_size_binary_type<T1>::value, const std::shared_ptr<DataType>&>
+ value_type,
+ MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool),
+ memo_table_(new internal::DictionaryMemoTable(pool, value_type)),
+ delta_offset_(0),
+ byte_width_(-1),
+ indices_builder_(pool),
+ value_type_(value_type) {}
+
+ template <typename T1 = T>
+ explicit DictionaryBuilderBase(
+ const std::shared_ptr<DataType>& index_type,
+ enable_if_t<!is_fixed_size_binary_type<T1>::value, const std::shared_ptr<DataType>&>
+ value_type,
+ MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool),
+ memo_table_(new internal::DictionaryMemoTable(pool, value_type)),
+ delta_offset_(0),
+ byte_width_(-1),
+ indices_builder_(index_type, pool),
+ value_type_(value_type) {}
+
+ template <typename B = BuilderType, typename T1 = T>
+ DictionaryBuilderBase(uint8_t start_int_size,
+ enable_if_t<std::is_base_of<AdaptiveIntBuilderBase, B>::value &&
+ is_fixed_size_binary_type<T1>::value,
+ const std::shared_ptr<DataType>&>
+ value_type,
+ MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool),
+ memo_table_(new internal::DictionaryMemoTable(pool, value_type)),
+ delta_offset_(0),
+ byte_width_(static_cast<const T1&>(*value_type).byte_width()),
+ indices_builder_(start_int_size, pool),
+ value_type_(value_type) {}
+
+ template <typename T1 = T>
+ explicit DictionaryBuilderBase(
+ enable_if_fixed_size_binary<T1, const std::shared_ptr<DataType>&> value_type,
+ MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool),
+ memo_table_(new internal::DictionaryMemoTable(pool, value_type)),
+ delta_offset_(0),
+ byte_width_(static_cast<const T1&>(*value_type).byte_width()),
+ indices_builder_(pool),
+ value_type_(value_type) {}
+
+ template <typename T1 = T>
+ explicit DictionaryBuilderBase(
+ const std::shared_ptr<DataType>& index_type,
+ enable_if_fixed_size_binary<T1, const std::shared_ptr<DataType>&> value_type,
+ MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool),
+ memo_table_(new internal::DictionaryMemoTable(pool, value_type)),
+ delta_offset_(0),
+ byte_width_(static_cast<const T1&>(*value_type).byte_width()),
+ indices_builder_(index_type, pool),
+ value_type_(value_type) {}
+
+ template <typename T1 = T>
+ explicit DictionaryBuilderBase(
+ enable_if_parameter_free<T1, MemoryPool*> pool = default_memory_pool())
+ : DictionaryBuilderBase<BuilderType, T1>(TypeTraits<T1>::type_singleton(), pool) {}
+
+ // This constructor doesn't check for errors. Use InsertMemoValues instead.
+ explicit DictionaryBuilderBase(const std::shared_ptr<Array>& dictionary,
+ MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool),
+ memo_table_(new internal::DictionaryMemoTable(pool, dictionary)),
+ delta_offset_(0),
+ byte_width_(-1),
+ indices_builder_(pool),
+ value_type_(dictionary->type()) {}
+
+ ~DictionaryBuilderBase() override = default;
+
+ /// \brief The current number of entries in the dictionary
+ int64_t dictionary_length() const { return memo_table_->size(); }
+
+ /// \brief The value byte width (for FixedSizeBinaryType)
+ template <typename T1 = T>
+ enable_if_fixed_size_binary<T1, int32_t> byte_width() const {
+ return byte_width_;
+ }
+
+ /// \brief Append a scalar value
+ Status Append(Value value) {
+ ARROW_RETURN_NOT_OK(Reserve(1));
+
+ int32_t memo_index;
+ ARROW_RETURN_NOT_OK(memo_table_->GetOrInsert<T>(value, &memo_index));
+ ARROW_RETURN_NOT_OK(indices_builder_.Append(memo_index));
+ length_ += 1;
+
+ return Status::OK();
+ }
+
+ /// \brief Append a fixed-width string (only for FixedSizeBinaryType)
+ template <typename T1 = T>
+ enable_if_fixed_size_binary<T1, Status> Append(const uint8_t* value) {
+ return Append(util::string_view(reinterpret_cast<const char*>(value), byte_width_));
+ }
+
+ /// \brief Append a fixed-width string (only for FixedSizeBinaryType)
+ template <typename T1 = T>
+ enable_if_fixed_size_binary<T1, Status> Append(const char* value) {
+ return Append(util::string_view(value, byte_width_));
+ }
+
+ /// \brief Append a string (only for binary types)
+ template <typename T1 = T>
+ enable_if_binary_like<T1, Status> Append(const uint8_t* value, int32_t length) {
+ return Append(reinterpret_cast<const char*>(value), length);
+ }
+
+ /// \brief Append a string (only for binary types)
+ template <typename T1 = T>
+ enable_if_binary_like<T1, Status> Append(const char* value, int32_t length) {
+ return Append(util::string_view(value, length));
+ }
+
+ /// \brief Append a string (only for string types)
+ template <typename T1 = T>
+ enable_if_string_like<T1, Status> Append(const char* value, int32_t length) {
+ return Append(util::string_view(value, length));
+ }
+
+ /// \brief Append a decimal (only for Decimal128Type)
+ template <typename T1 = T>
+ enable_if_decimal128<T1, Status> Append(const Decimal128& value) {
+ uint8_t data[16];
+ value.ToBytes(data);
+ return Append(data, 16);
+ }
+
+ /// \brief Append a decimal (only for Decimal128Type)
+ template <typename T1 = T>
+ enable_if_decimal256<T1, Status> Append(const Decimal256& value) {
+ uint8_t data[32];
+ value.ToBytes(data);
+ return Append(data, 32);
+ }
+
+ /// \brief Append a scalar null value
+ Status AppendNull() final {
+ length_ += 1;
+ null_count_ += 1;
+
+ return indices_builder_.AppendNull();
+ }
+
+ Status AppendNulls(int64_t length) final {
+ length_ += length;
+ null_count_ += length;
+
+ return indices_builder_.AppendNulls(length);
+ }
+
+ Status AppendEmptyValue() final {
+ length_ += 1;
+
+ return indices_builder_.AppendEmptyValue();
+ }
+
+ Status AppendEmptyValues(int64_t length) final {
+ length_ += length;
+
+ return indices_builder_.AppendEmptyValues(length);
+ }
+
+ Status AppendScalar(const Scalar& scalar, int64_t n_repeats) override {
+ if (!scalar.is_valid) return AppendNulls(n_repeats);
+
+ const auto& dict_ty = internal::checked_cast<const DictionaryType&>(*scalar.type);
+ const DictionaryScalar& dict_scalar =
+ internal::checked_cast<const DictionaryScalar&>(scalar);
+ const auto& dict = internal::checked_cast<const typename TypeTraits<T>::ArrayType&>(
+ *dict_scalar.value.dictionary);
+ ARROW_RETURN_NOT_OK(Reserve(n_repeats));
+ switch (dict_ty.index_type()->id()) {
+ case Type::UINT8:
+ return AppendScalarImpl<UInt8Type>(dict, *dict_scalar.value.index, n_repeats);
+ case Type::INT8:
+ return AppendScalarImpl<Int8Type>(dict, *dict_scalar.value.index, n_repeats);
+ case Type::UINT16:
+ return AppendScalarImpl<UInt16Type>(dict, *dict_scalar.value.index, n_repeats);
+ case Type::INT16:
+ return AppendScalarImpl<Int16Type>(dict, *dict_scalar.value.index, n_repeats);
+ case Type::UINT32:
+ return AppendScalarImpl<UInt32Type>(dict, *dict_scalar.value.index, n_repeats);
+ case Type::INT32:
+ return AppendScalarImpl<Int32Type>(dict, *dict_scalar.value.index, n_repeats);
+ case Type::UINT64:
+ return AppendScalarImpl<UInt64Type>(dict, *dict_scalar.value.index, n_repeats);
+ case Type::INT64:
+ return AppendScalarImpl<Int64Type>(dict, *dict_scalar.value.index, n_repeats);
+ default:
+ return Status::TypeError("Invalid index type: ", dict_ty);
+ }
+ return Status::OK();
+ }
+
+ Status AppendScalars(const ScalarVector& scalars) override {
+ for (const auto& scalar : scalars) {
+ ARROW_RETURN_NOT_OK(AppendScalar(*scalar, /*n_repeats=*/1));
+ }
+ return Status::OK();
+ }
+
+ Status AppendArraySlice(const ArrayData& array, int64_t offset, int64_t length) final {
+ // Visit the indices and insert the unpacked values.
+ const auto& dict_ty = internal::checked_cast<const DictionaryType&>(*array.type);
+ const typename TypeTraits<T>::ArrayType dict(array.dictionary);
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ switch (dict_ty.index_type()->id()) {
+ case Type::UINT8:
+ return AppendArraySliceImpl<uint8_t>(dict, array, offset, length);
+ case Type::INT8:
+ return AppendArraySliceImpl<int8_t>(dict, array, offset, length);
+ case Type::UINT16:
+ return AppendArraySliceImpl<uint16_t>(dict, array, offset, length);
+ case Type::INT16:
+ return AppendArraySliceImpl<int16_t>(dict, array, offset, length);
+ case Type::UINT32:
+ return AppendArraySliceImpl<uint32_t>(dict, array, offset, length);
+ case Type::INT32:
+ return AppendArraySliceImpl<int32_t>(dict, array, offset, length);
+ case Type::UINT64:
+ return AppendArraySliceImpl<uint64_t>(dict, array, offset, length);
+ case Type::INT64:
+ return AppendArraySliceImpl<int64_t>(dict, array, offset, length);
+ default:
+ return Status::TypeError("Invalid index type: ", dict_ty);
+ }
+ return Status::OK();
+ }
+
+ /// \brief Insert values into the dictionary's memo, but do not append any
+ /// indices. Can be used to initialize a new builder with known dictionary
+ /// values
+ /// \param[in] values dictionary values to add to memo. Type must match
+ /// builder type
+ Status InsertMemoValues(const Array& values) {
+ return memo_table_->InsertValues(values);
+ }
+
+ /// \brief Append a whole dense array to the builder
+ template <typename T1 = T>
+ enable_if_t<!is_fixed_size_binary_type<T1>::value, Status> AppendArray(
+ const Array& array) {
+ using ArrayType = typename TypeTraits<T>::ArrayType;
+
+#ifndef NDEBUG
+ ARROW_RETURN_NOT_OK(ArrayBuilder::CheckArrayType(
+ value_type_, array, "Wrong value type of array to be appended"));
+#endif
+
+ const auto& concrete_array = static_cast<const ArrayType&>(array);
+ for (int64_t i = 0; i < array.length(); i++) {
+ if (array.IsNull(i)) {
+ ARROW_RETURN_NOT_OK(AppendNull());
+ } else {
+ ARROW_RETURN_NOT_OK(Append(concrete_array.GetView(i)));
+ }
+ }
+ return Status::OK();
+ }
+
+ template <typename T1 = T>
+ enable_if_fixed_size_binary<T1, Status> AppendArray(const Array& array) {
+#ifndef NDEBUG
+ ARROW_RETURN_NOT_OK(ArrayBuilder::CheckArrayType(
+ value_type_, array, "Wrong value type of array to be appended"));
+#endif
+
+ const auto& concrete_array = static_cast<const FixedSizeBinaryArray&>(array);
+ for (int64_t i = 0; i < array.length(); i++) {
+ if (array.IsNull(i)) {
+ ARROW_RETURN_NOT_OK(AppendNull());
+ } else {
+ ARROW_RETURN_NOT_OK(Append(concrete_array.GetValue(i)));
+ }
+ }
+ return Status::OK();
+ }
+
+ void Reset() override {
+ // Perform a partial reset. Call ResetFull to also reset the accumulated
+ // dictionary values
+ ArrayBuilder::Reset();
+ indices_builder_.Reset();
+ }
+
+ /// \brief Reset and also clear accumulated dictionary values in memo table
+ void ResetFull() {
+ Reset();
+ memo_table_.reset(new internal::DictionaryMemoTable(pool_, value_type_));
+ }
+
+ Status Resize(int64_t capacity) override {
+ ARROW_RETURN_NOT_OK(CheckCapacity(capacity));
+ capacity = std::max(capacity, kMinBuilderCapacity);
+ ARROW_RETURN_NOT_OK(indices_builder_.Resize(capacity));
+ capacity_ = indices_builder_.capacity();
+ return Status::OK();
+ }
+
+ /// \brief Return dictionary indices and a delta dictionary since the last
+ /// time that Finish or FinishDelta were called, and reset state of builder
+ /// (except the memo table)
+ Status FinishDelta(std::shared_ptr<Array>* out_indices,
+ std::shared_ptr<Array>* out_delta) {
+ std::shared_ptr<ArrayData> indices_data;
+ std::shared_ptr<ArrayData> delta_data;
+ ARROW_RETURN_NOT_OK(FinishWithDictOffset(delta_offset_, &indices_data, &delta_data));
+ *out_indices = MakeArray(indices_data);
+ *out_delta = MakeArray(delta_data);
+ return Status::OK();
+ }
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<DictionaryArray>* out) { return FinishTyped(out); }
+
+ std::shared_ptr<DataType> type() const override {
+ return ::arrow::dictionary(indices_builder_.type(), value_type_);
+ }
+
+ protected:
+ template <typename c_type>
+ Status AppendArraySliceImpl(const typename TypeTraits<T>::ArrayType& dict,
+ const ArrayData& array, int64_t offset, int64_t length) {
+ const c_type* values = array.GetValues<c_type>(1) + offset;
+ return VisitBitBlocks(
+ array.buffers[0], array.offset + offset, length,
+ [&](const int64_t position) {
+ const int64_t index = static_cast<int64_t>(values[position]);
+ if (dict.IsValid(index)) {
+ return Append(dict.GetView(index));
+ }
+ return AppendNull();
+ },
+ [&]() { return AppendNull(); });
+ }
+
+ template <typename IndexType>
+ Status AppendScalarImpl(const typename TypeTraits<T>::ArrayType& dict,
+ const Scalar& index_scalar, int64_t n_repeats) {
+ using ScalarType = typename TypeTraits<IndexType>::ScalarType;
+ const auto index = internal::checked_cast<const ScalarType&>(index_scalar).value;
+ if (index_scalar.is_valid && dict.IsValid(index)) {
+ const auto& value = dict.GetView(index);
+ for (int64_t i = 0; i < n_repeats; i++) {
+ ARROW_RETURN_NOT_OK(Append(value));
+ }
+ return Status::OK();
+ }
+ return AppendNulls(n_repeats);
+ }
+
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
+ std::shared_ptr<ArrayData> dictionary;
+ ARROW_RETURN_NOT_OK(FinishWithDictOffset(/*offset=*/0, out, &dictionary));
+
+ // Set type of array data to the right dictionary type
+ (*out)->type = type();
+ (*out)->dictionary = dictionary;
+ return Status::OK();
+ }
+
+ Status FinishWithDictOffset(int64_t dict_offset,
+ std::shared_ptr<ArrayData>* out_indices,
+ std::shared_ptr<ArrayData>* out_dictionary) {
+ // Finalize indices array
+ ARROW_RETURN_NOT_OK(indices_builder_.FinishInternal(out_indices));
+
+ // Generate dictionary array from hash table contents
+ ARROW_RETURN_NOT_OK(memo_table_->GetArrayData(dict_offset, out_dictionary));
+ delta_offset_ = memo_table_->size();
+
+ // Update internals for further uses of this DictionaryBuilder
+ ArrayBuilder::Reset();
+ return Status::OK();
+ }
+
+ std::unique_ptr<DictionaryMemoTable> memo_table_;
+
+ // The size of the dictionary memo at last invocation of Finish, to use in
+ // FinishDelta for computing dictionary deltas
+ int32_t delta_offset_;
+
+ // Only used for FixedSizeBinaryType
+ int32_t byte_width_;
+
+ BuilderType indices_builder_;
+ std::shared_ptr<DataType> value_type_;
+};
+
+template <typename BuilderType>
+class DictionaryBuilderBase<BuilderType, NullType> : public ArrayBuilder {
+ public:
+ template <typename B = BuilderType>
+ DictionaryBuilderBase(
+ enable_if_t<std::is_base_of<AdaptiveIntBuilderBase, B>::value, uint8_t>
+ start_int_size,
+ const std::shared_ptr<DataType>& value_type,
+ MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool), indices_builder_(start_int_size, pool) {}
+
+ explicit DictionaryBuilderBase(const std::shared_ptr<DataType>& value_type,
+ MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool), indices_builder_(pool) {}
+
+ explicit DictionaryBuilderBase(const std::shared_ptr<DataType>& index_type,
+ const std::shared_ptr<DataType>& value_type,
+ MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool), indices_builder_(index_type, pool) {}
+
+ template <typename B = BuilderType>
+ explicit DictionaryBuilderBase(
+ enable_if_t<std::is_base_of<AdaptiveIntBuilderBase, B>::value, uint8_t>
+ start_int_size,
+ MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool), indices_builder_(start_int_size, pool) {}
+
+ explicit DictionaryBuilderBase(MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool), indices_builder_(pool) {}
+
+ explicit DictionaryBuilderBase(const std::shared_ptr<Array>& dictionary,
+ MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool), indices_builder_(pool) {}
+
+ /// \brief Append a scalar null value
+ Status AppendNull() final {
+ length_ += 1;
+ null_count_ += 1;
+
+ return indices_builder_.AppendNull();
+ }
+
+ Status AppendNulls(int64_t length) final {
+ length_ += length;
+ null_count_ += length;
+
+ return indices_builder_.AppendNulls(length);
+ }
+
+ Status AppendEmptyValue() final {
+ length_ += 1;
+
+ return indices_builder_.AppendEmptyValue();
+ }
+
+ Status AppendEmptyValues(int64_t length) final {
+ length_ += length;
+
+ return indices_builder_.AppendEmptyValues(length);
+ }
+
+ /// \brief Append a whole dense array to the builder
+ Status AppendArray(const Array& array) {
+#ifndef NDEBUG
+ ARROW_RETURN_NOT_OK(ArrayBuilder::CheckArrayType(
+ Type::NA, array, "Wrong value type of array to be appended"));
+#endif
+ for (int64_t i = 0; i < array.length(); i++) {
+ ARROW_RETURN_NOT_OK(AppendNull());
+ }
+ return Status::OK();
+ }
+
+ Status Resize(int64_t capacity) override {
+ ARROW_RETURN_NOT_OK(CheckCapacity(capacity));
+ capacity = std::max(capacity, kMinBuilderCapacity);
+
+ ARROW_RETURN_NOT_OK(indices_builder_.Resize(capacity));
+ capacity_ = indices_builder_.capacity();
+ return Status::OK();
+ }
+
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
+ ARROW_RETURN_NOT_OK(indices_builder_.FinishInternal(out));
+ (*out)->type = dictionary((*out)->type, null());
+ (*out)->dictionary = NullArray(0).data();
+ return Status::OK();
+ }
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<DictionaryArray>* out) { return FinishTyped(out); }
+
+ std::shared_ptr<DataType> type() const override {
+ return ::arrow::dictionary(indices_builder_.type(), null());
+ }
+
+ protected:
+ BuilderType indices_builder_;
+};
+
+} // namespace internal
+
+/// \brief A DictionaryArray builder that uses AdaptiveIntBuilder to return the
+/// smallest index size that can accommodate the dictionary indices
+template <typename T>
+class DictionaryBuilder : public internal::DictionaryBuilderBase<AdaptiveIntBuilder, T> {
+ public:
+ using BASE = internal::DictionaryBuilderBase<AdaptiveIntBuilder, T>;
+ using BASE::BASE;
+
+ /// \brief Append dictionary indices directly without modifying memo
+ ///
+ /// NOTE: Experimental API
+ Status AppendIndices(const int64_t* values, int64_t length,
+ const uint8_t* valid_bytes = NULLPTR) {
+ int64_t null_count_before = this->indices_builder_.null_count();
+ ARROW_RETURN_NOT_OK(this->indices_builder_.AppendValues(values, length, valid_bytes));
+ this->capacity_ = this->indices_builder_.capacity();
+ this->length_ += length;
+ this->null_count_ += this->indices_builder_.null_count() - null_count_before;
+ return Status::OK();
+ }
+};
+
+/// \brief A DictionaryArray builder that always returns int32 dictionary
+/// indices so that data cast to dictionary form will have a consistent index
+/// type, e.g. for creating a ChunkedArray
+template <typename T>
+class Dictionary32Builder : public internal::DictionaryBuilderBase<Int32Builder, T> {
+ public:
+ using BASE = internal::DictionaryBuilderBase<Int32Builder, T>;
+ using BASE::BASE;
+
+ /// \brief Append dictionary indices directly without modifying memo
+ ///
+ /// NOTE: Experimental API
+ Status AppendIndices(const int32_t* values, int64_t length,
+ const uint8_t* valid_bytes = NULLPTR) {
+ int64_t null_count_before = this->indices_builder_.null_count();
+ ARROW_RETURN_NOT_OK(this->indices_builder_.AppendValues(values, length, valid_bytes));
+ this->capacity_ = this->indices_builder_.capacity();
+ this->length_ += length;
+ this->null_count_ += this->indices_builder_.null_count() - null_count_before;
+ return Status::OK();
+ }
+};
+
+// ----------------------------------------------------------------------
+// Binary / Unicode builders
+// (compatibility aliases; those used to be derived classes with additional
+// Append() overloads, but they have been folded into DictionaryBuilderBase)
+
+using BinaryDictionaryBuilder = DictionaryBuilder<BinaryType>;
+using StringDictionaryBuilder = DictionaryBuilder<StringType>;
+using BinaryDictionary32Builder = Dictionary32Builder<BinaryType>;
+using StringDictionary32Builder = Dictionary32Builder<StringType>;
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_nested.cc b/src/arrow/cpp/src/arrow/array/builder_nested.cc
new file mode 100644
index 000000000..a3bcde038
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_nested.cc
@@ -0,0 +1,294 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/builder_nested.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+// ----------------------------------------------------------------------
+// MapBuilder
+
+MapBuilder::MapBuilder(MemoryPool* pool, const std::shared_ptr<ArrayBuilder>& key_builder,
+ std::shared_ptr<ArrayBuilder> const& item_builder,
+ const std::shared_ptr<DataType>& type)
+ : ArrayBuilder(pool), key_builder_(key_builder), item_builder_(item_builder) {
+ auto map_type = internal::checked_cast<const MapType*>(type.get());
+ keys_sorted_ = map_type->keys_sorted();
+
+ std::vector<std::shared_ptr<ArrayBuilder>> child_builders{key_builder, item_builder};
+ auto struct_builder =
+ std::make_shared<StructBuilder>(map_type->value_type(), pool, child_builders);
+
+ list_builder_ =
+ std::make_shared<ListBuilder>(pool, struct_builder, struct_builder->type());
+}
+
+MapBuilder::MapBuilder(MemoryPool* pool, const std::shared_ptr<ArrayBuilder>& key_builder,
+ const std::shared_ptr<ArrayBuilder>& item_builder,
+ bool keys_sorted)
+ : MapBuilder(pool, key_builder, item_builder,
+ map(key_builder->type(), item_builder->type(), keys_sorted)) {}
+
+MapBuilder::MapBuilder(MemoryPool* pool,
+ const std::shared_ptr<ArrayBuilder>& struct_builder,
+ const std::shared_ptr<DataType>& type)
+ : ArrayBuilder(pool) {
+ auto map_type = internal::checked_cast<const MapType*>(type.get());
+ keys_sorted_ = map_type->keys_sorted();
+ key_builder_ = struct_builder->child_builder(0);
+ item_builder_ = struct_builder->child_builder(1);
+ list_builder_ =
+ std::make_shared<ListBuilder>(pool, struct_builder, struct_builder->type());
+}
+
+Status MapBuilder::Resize(int64_t capacity) {
+ RETURN_NOT_OK(list_builder_->Resize(capacity));
+ capacity_ = list_builder_->capacity();
+ return Status::OK();
+}
+
+void MapBuilder::Reset() {
+ list_builder_->Reset();
+ ArrayBuilder::Reset();
+}
+
+Status MapBuilder::FinishInternal(std::shared_ptr<ArrayData>* out) {
+ ARROW_CHECK_EQ(item_builder_->length(), key_builder_->length())
+ << "keys and items builders don't have the same size in MapBuilder";
+ RETURN_NOT_OK(AdjustStructBuilderLength());
+ RETURN_NOT_OK(list_builder_->FinishInternal(out));
+ (*out)->type = type();
+ ArrayBuilder::Reset();
+ return Status::OK();
+}
+
+Status MapBuilder::AppendValues(const int32_t* offsets, int64_t length,
+ const uint8_t* valid_bytes) {
+ DCHECK_EQ(item_builder_->length(), key_builder_->length());
+ RETURN_NOT_OK(AdjustStructBuilderLength());
+ RETURN_NOT_OK(list_builder_->AppendValues(offsets, length, valid_bytes));
+ length_ = list_builder_->length();
+ null_count_ = list_builder_->null_count();
+ return Status::OK();
+}
+
+Status MapBuilder::Append() {
+ DCHECK_EQ(item_builder_->length(), key_builder_->length());
+ RETURN_NOT_OK(AdjustStructBuilderLength());
+ RETURN_NOT_OK(list_builder_->Append());
+ length_ = list_builder_->length();
+ return Status::OK();
+}
+
+Status MapBuilder::AppendNull() {
+ DCHECK_EQ(item_builder_->length(), key_builder_->length());
+ RETURN_NOT_OK(AdjustStructBuilderLength());
+ RETURN_NOT_OK(list_builder_->AppendNull());
+ length_ = list_builder_->length();
+ null_count_ = list_builder_->null_count();
+ return Status::OK();
+}
+
+Status MapBuilder::AppendNulls(int64_t length) {
+ DCHECK_EQ(item_builder_->length(), key_builder_->length());
+ RETURN_NOT_OK(AdjustStructBuilderLength());
+ RETURN_NOT_OK(list_builder_->AppendNulls(length));
+ length_ = list_builder_->length();
+ null_count_ = list_builder_->null_count();
+ return Status::OK();
+}
+
+Status MapBuilder::AppendEmptyValue() {
+ DCHECK_EQ(item_builder_->length(), key_builder_->length());
+ RETURN_NOT_OK(AdjustStructBuilderLength());
+ RETURN_NOT_OK(list_builder_->AppendEmptyValue());
+ length_ = list_builder_->length();
+ null_count_ = list_builder_->null_count();
+ return Status::OK();
+}
+
+Status MapBuilder::AppendEmptyValues(int64_t length) {
+ DCHECK_EQ(item_builder_->length(), key_builder_->length());
+ RETURN_NOT_OK(AdjustStructBuilderLength());
+ RETURN_NOT_OK(list_builder_->AppendEmptyValues(length));
+ length_ = list_builder_->length();
+ null_count_ = list_builder_->null_count();
+ return Status::OK();
+}
+
+Status MapBuilder::AdjustStructBuilderLength() {
+ // If key/item builders have been appended, adjust struct builder length
+ // to match. Struct and key are non-nullable, append all valid values.
+ auto struct_builder =
+ internal::checked_cast<StructBuilder*>(list_builder_->value_builder());
+ if (struct_builder->length() < key_builder_->length()) {
+ int64_t length_diff = key_builder_->length() - struct_builder->length();
+ RETURN_NOT_OK(struct_builder->AppendValues(length_diff, NULLPTR));
+ }
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// FixedSizeListBuilder
+
+FixedSizeListBuilder::FixedSizeListBuilder(
+ MemoryPool* pool, const std::shared_ptr<ArrayBuilder>& value_builder,
+ const std::shared_ptr<DataType>& type)
+ : ArrayBuilder(pool),
+ value_field_(type->field(0)),
+ list_size_(
+ internal::checked_cast<const FixedSizeListType*>(type.get())->list_size()),
+ value_builder_(value_builder) {}
+
+FixedSizeListBuilder::FixedSizeListBuilder(
+ MemoryPool* pool, const std::shared_ptr<ArrayBuilder>& value_builder,
+ int32_t list_size)
+ : FixedSizeListBuilder(pool, value_builder,
+ fixed_size_list(value_builder->type(), list_size)) {}
+
+void FixedSizeListBuilder::Reset() {
+ ArrayBuilder::Reset();
+ value_builder_->Reset();
+}
+
+Status FixedSizeListBuilder::Append() {
+ RETURN_NOT_OK(Reserve(1));
+ UnsafeAppendToBitmap(true);
+ return Status::OK();
+}
+
+Status FixedSizeListBuilder::AppendValues(int64_t length, const uint8_t* valid_bytes) {
+ RETURN_NOT_OK(Reserve(length));
+ UnsafeAppendToBitmap(valid_bytes, length);
+ return Status::OK();
+}
+
+Status FixedSizeListBuilder::AppendNull() {
+ RETURN_NOT_OK(Reserve(1));
+ UnsafeAppendToBitmap(false);
+ return value_builder_->AppendNulls(list_size_);
+}
+
+Status FixedSizeListBuilder::AppendNulls(int64_t length) {
+ RETURN_NOT_OK(Reserve(length));
+ UnsafeAppendToBitmap(length, false);
+ return value_builder_->AppendNulls(list_size_ * length);
+}
+
+Status FixedSizeListBuilder::ValidateOverflow(int64_t new_elements) {
+ auto new_length = value_builder_->length() + new_elements;
+ if (new_elements != list_size_) {
+ return Status::Invalid("Length of item not correct: expected ", list_size_,
+ " but got array of size ", new_elements);
+ }
+ if (new_length > maximum_elements()) {
+ return Status::CapacityError("array cannot contain more than ", maximum_elements(),
+ " elements, have ", new_elements);
+ }
+ return Status::OK();
+}
+
+Status FixedSizeListBuilder::AppendEmptyValue() {
+ RETURN_NOT_OK(Reserve(1));
+ UnsafeAppendToBitmap(true);
+ return value_builder_->AppendEmptyValues(list_size_);
+}
+
+Status FixedSizeListBuilder::AppendEmptyValues(int64_t length) {
+ RETURN_NOT_OK(Reserve(length));
+ UnsafeAppendToBitmap(length, true);
+ return value_builder_->AppendEmptyValues(list_size_ * length);
+}
+
+Status FixedSizeListBuilder::Resize(int64_t capacity) {
+ RETURN_NOT_OK(CheckCapacity(capacity));
+ return ArrayBuilder::Resize(capacity);
+}
+
+Status FixedSizeListBuilder::FinishInternal(std::shared_ptr<ArrayData>* out) {
+ std::shared_ptr<ArrayData> items;
+
+ if (value_builder_->length() == 0) {
+ // Try to make sure we get a non-null values buffer (ARROW-2744)
+ RETURN_NOT_OK(value_builder_->Resize(0));
+ }
+ RETURN_NOT_OK(value_builder_->FinishInternal(&items));
+
+ std::shared_ptr<Buffer> null_bitmap;
+ RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
+ *out = ArrayData::Make(type(), length_, {null_bitmap}, {std::move(items)}, null_count_);
+ Reset();
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Struct
+
+StructBuilder::StructBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool,
+ std::vector<std::shared_ptr<ArrayBuilder>> field_builders)
+ : ArrayBuilder(pool), type_(type) {
+ children_ = std::move(field_builders);
+}
+
+void StructBuilder::Reset() {
+ ArrayBuilder::Reset();
+ for (const auto& field_builder : children_) {
+ field_builder->Reset();
+ }
+}
+
+Status StructBuilder::FinishInternal(std::shared_ptr<ArrayData>* out) {
+ std::shared_ptr<Buffer> null_bitmap;
+ RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
+
+ std::vector<std::shared_ptr<ArrayData>> child_data(children_.size());
+ for (size_t i = 0; i < children_.size(); ++i) {
+ if (length_ == 0) {
+ // Try to make sure the child buffers are initialized
+ RETURN_NOT_OK(children_[i]->Resize(0));
+ }
+ RETURN_NOT_OK(children_[i]->FinishInternal(&child_data[i]));
+ }
+
+ *out = ArrayData::Make(type(), length_, {null_bitmap}, null_count_);
+ (*out)->child_data = std::move(child_data);
+
+ capacity_ = length_ = null_count_ = 0;
+ return Status::OK();
+}
+
+std::shared_ptr<DataType> StructBuilder::type() const {
+ DCHECK_EQ(type_->fields().size(), children_.size());
+ std::vector<std::shared_ptr<Field>> fields(children_.size());
+ for (int i = 0; i < static_cast<int>(fields.size()); ++i) {
+ fields[i] = type_->field(i)->WithType(children_[i]->type());
+ }
+ return struct_(std::move(fields));
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_nested.h b/src/arrow/cpp/src/arrow/array/builder_nested.h
new file mode 100644
index 000000000..e53b758ef
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_nested.h
@@ -0,0 +1,544 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/array_nested.h"
+#include "arrow/array/builder_base.h"
+#include "arrow/array/data.h"
+#include "arrow/buffer.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+// ----------------------------------------------------------------------
+// List builder
+
+template <typename TYPE>
+class BaseListBuilder : public ArrayBuilder {
+ public:
+ using TypeClass = TYPE;
+ using offset_type = typename TypeClass::offset_type;
+
+ /// Use this constructor to incrementally build the value array along with offsets and
+ /// null bitmap.
+ BaseListBuilder(MemoryPool* pool, std::shared_ptr<ArrayBuilder> const& value_builder,
+ const std::shared_ptr<DataType>& type)
+ : ArrayBuilder(pool),
+ offsets_builder_(pool),
+ value_builder_(value_builder),
+ value_field_(type->field(0)->WithType(NULLPTR)) {}
+
+ BaseListBuilder(MemoryPool* pool, std::shared_ptr<ArrayBuilder> const& value_builder)
+ : BaseListBuilder(pool, value_builder, list(value_builder->type())) {}
+
+ Status Resize(int64_t capacity) override {
+ if (capacity > maximum_elements()) {
+ return Status::CapacityError("List array cannot reserve space for more than ",
+ maximum_elements(), " got ", capacity);
+ }
+ ARROW_RETURN_NOT_OK(CheckCapacity(capacity));
+
+ // One more than requested for offsets
+ ARROW_RETURN_NOT_OK(offsets_builder_.Resize(capacity + 1));
+ return ArrayBuilder::Resize(capacity);
+ }
+
+ void Reset() override {
+ ArrayBuilder::Reset();
+ offsets_builder_.Reset();
+ value_builder_->Reset();
+ }
+
+ /// \brief Vector append
+ ///
+ /// If passed, valid_bytes is of equal length to values, and any zero byte
+ /// will be considered as a null for that slot
+ Status AppendValues(const offset_type* offsets, int64_t length,
+ const uint8_t* valid_bytes = NULLPTR) {
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ UnsafeAppendToBitmap(valid_bytes, length);
+ offsets_builder_.UnsafeAppend(offsets, length);
+ return Status::OK();
+ }
+
+ /// \brief Start a new variable-length list slot
+ ///
+ /// This function should be called before beginning to append elements to the
+ /// value builder
+ Status Append(bool is_valid = true) {
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ UnsafeAppendToBitmap(is_valid);
+ return AppendNextOffset();
+ }
+
+ Status AppendNull() final { return Append(false); }
+
+ Status AppendNulls(int64_t length) final {
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ ARROW_RETURN_NOT_OK(ValidateOverflow(0));
+ UnsafeAppendToBitmap(length, false);
+ const int64_t num_values = value_builder_->length();
+ for (int64_t i = 0; i < length; ++i) {
+ offsets_builder_.UnsafeAppend(static_cast<offset_type>(num_values));
+ }
+ return Status::OK();
+ }
+
+ Status AppendEmptyValue() final { return Append(true); }
+
+ Status AppendEmptyValues(int64_t length) final {
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ ARROW_RETURN_NOT_OK(ValidateOverflow(0));
+ UnsafeAppendToBitmap(length, true);
+ const int64_t num_values = value_builder_->length();
+ for (int64_t i = 0; i < length; ++i) {
+ offsets_builder_.UnsafeAppend(static_cast<offset_type>(num_values));
+ }
+ return Status::OK();
+ }
+
+ Status AppendArraySlice(const ArrayData& array, int64_t offset,
+ int64_t length) override {
+ const offset_type* offsets = array.GetValues<offset_type>(1);
+ const uint8_t* validity = array.MayHaveNulls() ? array.buffers[0]->data() : NULLPTR;
+ for (int64_t row = offset; row < offset + length; row++) {
+ if (!validity || BitUtil::GetBit(validity, array.offset + row)) {
+ ARROW_RETURN_NOT_OK(Append());
+ int64_t slot_length = offsets[row + 1] - offsets[row];
+ ARROW_RETURN_NOT_OK(value_builder_->AppendArraySlice(*array.child_data[0],
+ offsets[row], slot_length));
+ } else {
+ ARROW_RETURN_NOT_OK(AppendNull());
+ }
+ }
+ return Status::OK();
+ }
+
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
+ ARROW_RETURN_NOT_OK(AppendNextOffset());
+
+ // Offset padding zeroed by BufferBuilder
+ std::shared_ptr<Buffer> offsets, null_bitmap;
+ ARROW_RETURN_NOT_OK(offsets_builder_.Finish(&offsets));
+ ARROW_RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
+
+ if (value_builder_->length() == 0) {
+ // Try to make sure we get a non-null values buffer (ARROW-2744)
+ ARROW_RETURN_NOT_OK(value_builder_->Resize(0));
+ }
+
+ std::shared_ptr<ArrayData> items;
+ ARROW_RETURN_NOT_OK(value_builder_->FinishInternal(&items));
+
+ *out = ArrayData::Make(type(), length_, {null_bitmap, offsets}, {std::move(items)},
+ null_count_);
+ Reset();
+ return Status::OK();
+ }
+
+ Status ValidateOverflow(int64_t new_elements) const {
+ auto new_length = value_builder_->length() + new_elements;
+ if (ARROW_PREDICT_FALSE(new_length > maximum_elements())) {
+ return Status::CapacityError("List array cannot contain more than ",
+ maximum_elements(), " elements, have ", new_elements);
+ } else {
+ return Status::OK();
+ }
+ }
+
+ ArrayBuilder* value_builder() const { return value_builder_.get(); }
+
+ // Cannot make this a static attribute because of linking issues
+ static constexpr int64_t maximum_elements() {
+ return std::numeric_limits<offset_type>::max() - 1;
+ }
+
+ std::shared_ptr<DataType> type() const override {
+ return std::make_shared<TYPE>(value_field_->WithType(value_builder_->type()));
+ }
+
+ protected:
+ TypedBufferBuilder<offset_type> offsets_builder_;
+ std::shared_ptr<ArrayBuilder> value_builder_;
+ std::shared_ptr<Field> value_field_;
+
+ Status AppendNextOffset() {
+ ARROW_RETURN_NOT_OK(ValidateOverflow(0));
+ const int64_t num_values = value_builder_->length();
+ return offsets_builder_.Append(static_cast<offset_type>(num_values));
+ }
+};
+
+/// \class ListBuilder
+/// \brief Builder class for variable-length list array value types
+///
+/// To use this class, you must append values to the child array builder and use
+/// the Append function to delimit each distinct list value (once the values
+/// have been appended to the child array) or use the bulk API to append
+/// a sequence of offsets and null values.
+///
+/// A note on types. Per arrow/type.h all types in the c++ implementation are
+/// logical so even though this class always builds list array, this can
+/// represent multiple different logical types. If no logical type is provided
+/// at construction time, the class defaults to List<T> where t is taken from the
+/// value_builder/values that the object is constructed with.
+class ARROW_EXPORT ListBuilder : public BaseListBuilder<ListType> {
+ public:
+ using BaseListBuilder::BaseListBuilder;
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<ListArray>* out) { return FinishTyped(out); }
+};
+
+/// \class LargeListBuilder
+/// \brief Builder class for large variable-length list array value types
+///
+/// Like ListBuilder, but to create large list arrays (with 64-bit offsets).
+class ARROW_EXPORT LargeListBuilder : public BaseListBuilder<LargeListType> {
+ public:
+ using BaseListBuilder::BaseListBuilder;
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<LargeListArray>* out) { return FinishTyped(out); }
+};
+
+// ----------------------------------------------------------------------
+// Map builder
+
+/// \class MapBuilder
+/// \brief Builder class for arrays of variable-size maps
+///
+/// To use this class, you must append values to the key and item array builders
+/// and use the Append function to delimit each distinct map (once the keys and items
+/// have been appended) or use the bulk API to append a sequence of offsets and null
+/// maps.
+///
+/// Key uniqueness and ordering are not validated.
+class ARROW_EXPORT MapBuilder : public ArrayBuilder {
+ public:
+ /// Use this constructor to define the built array's type explicitly. If key_builder
+ /// or item_builder has indeterminate type, this builder will also.
+ MapBuilder(MemoryPool* pool, const std::shared_ptr<ArrayBuilder>& key_builder,
+ const std::shared_ptr<ArrayBuilder>& item_builder,
+ const std::shared_ptr<DataType>& type);
+
+ /// Use this constructor to infer the built array's type. If key_builder or
+ /// item_builder has indeterminate type, this builder will also.
+ MapBuilder(MemoryPool* pool, const std::shared_ptr<ArrayBuilder>& key_builder,
+ const std::shared_ptr<ArrayBuilder>& item_builder, bool keys_sorted = false);
+
+ MapBuilder(MemoryPool* pool, const std::shared_ptr<ArrayBuilder>& item_builder,
+ const std::shared_ptr<DataType>& type);
+
+ Status Resize(int64_t capacity) override;
+ void Reset() override;
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<MapArray>* out) { return FinishTyped(out); }
+
+ /// \brief Vector append
+ ///
+ /// If passed, valid_bytes is of equal length to values, and any zero byte
+ /// will be considered as a null for that slot
+ Status AppendValues(const int32_t* offsets, int64_t length,
+ const uint8_t* valid_bytes = NULLPTR);
+
+ /// \brief Start a new variable-length map slot
+ ///
+ /// This function should be called before beginning to append elements to the
+ /// key and item builders
+ Status Append();
+
+ Status AppendNull() final;
+
+ Status AppendNulls(int64_t length) final;
+
+ Status AppendEmptyValue() final;
+
+ Status AppendEmptyValues(int64_t length) final;
+
+ Status AppendArraySlice(const ArrayData& array, int64_t offset,
+ int64_t length) override {
+ const int32_t* offsets = array.GetValues<int32_t>(1);
+ const uint8_t* validity = array.MayHaveNulls() ? array.buffers[0]->data() : NULLPTR;
+ for (int64_t row = offset; row < offset + length; row++) {
+ if (!validity || BitUtil::GetBit(validity, array.offset + row)) {
+ ARROW_RETURN_NOT_OK(Append());
+ const int64_t slot_length = offsets[row + 1] - offsets[row];
+ ARROW_RETURN_NOT_OK(key_builder_->AppendArraySlice(
+ *array.child_data[0]->child_data[0], offsets[row], slot_length));
+ ARROW_RETURN_NOT_OK(item_builder_->AppendArraySlice(
+ *array.child_data[0]->child_data[1], offsets[row], slot_length));
+ } else {
+ ARROW_RETURN_NOT_OK(AppendNull());
+ }
+ }
+ return Status::OK();
+ }
+
+ /// \brief Get builder to append keys.
+ ///
+ /// Append a key with this builder should be followed by appending
+ /// an item or null value with item_builder().
+ ArrayBuilder* key_builder() const { return key_builder_.get(); }
+
+ /// \brief Get builder to append items
+ ///
+ /// Appending an item with this builder should have been preceded
+ /// by appending a key with key_builder().
+ ArrayBuilder* item_builder() const { return item_builder_.get(); }
+
+ /// \brief Get builder to add Map entries as struct values.
+ ///
+ /// This is used instead of key_builder()/item_builder() and allows
+ /// the Map to be built as a list of struct values.
+ ArrayBuilder* value_builder() const { return list_builder_->value_builder(); }
+
+ std::shared_ptr<DataType> type() const override {
+ return map(key_builder_->type(), item_builder_->type(), keys_sorted_);
+ }
+
+ Status ValidateOverflow(int64_t new_elements) {
+ return list_builder_->ValidateOverflow(new_elements);
+ }
+
+ protected:
+ inline Status AdjustStructBuilderLength();
+
+ protected:
+ bool keys_sorted_ = false;
+ std::shared_ptr<ListBuilder> list_builder_;
+ std::shared_ptr<ArrayBuilder> key_builder_;
+ std::shared_ptr<ArrayBuilder> item_builder_;
+};
+
+// ----------------------------------------------------------------------
+// FixedSizeList builder
+
+/// \class FixedSizeListBuilder
+/// \brief Builder class for fixed-length list array value types
+class ARROW_EXPORT FixedSizeListBuilder : public ArrayBuilder {
+ public:
+ /// Use this constructor to define the built array's type explicitly. If value_builder
+ /// has indeterminate type, this builder will also.
+ FixedSizeListBuilder(MemoryPool* pool,
+ std::shared_ptr<ArrayBuilder> const& value_builder,
+ int32_t list_size);
+
+ /// Use this constructor to infer the built array's type. If value_builder has
+ /// indeterminate type, this builder will also.
+ FixedSizeListBuilder(MemoryPool* pool,
+ std::shared_ptr<ArrayBuilder> const& value_builder,
+ const std::shared_ptr<DataType>& type);
+
+ Status Resize(int64_t capacity) override;
+ void Reset() override;
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<FixedSizeListArray>* out) { return FinishTyped(out); }
+
+ /// \brief Append a valid fixed length list.
+ ///
+ /// This function affects only the validity bitmap; the child values must be appended
+ /// using the child array builder.
+ Status Append();
+
+ /// \brief Vector append
+ ///
+ /// If passed, valid_bytes wil be read and any zero byte
+ /// will cause the corresponding slot to be null
+ ///
+ /// This function affects only the validity bitmap; the child values must be appended
+ /// using the child array builder. This includes appending nulls for null lists.
+ /// XXX this restriction is confusing, should this method be omitted?
+ Status AppendValues(int64_t length, const uint8_t* valid_bytes = NULLPTR);
+
+ /// \brief Append a null fixed length list.
+ ///
+ /// The child array builder will have the appropriate number of nulls appended
+ /// automatically.
+ Status AppendNull() final;
+
+ /// \brief Append length null fixed length lists.
+ ///
+ /// The child array builder will have the appropriate number of nulls appended
+ /// automatically.
+ Status AppendNulls(int64_t length) final;
+
+ Status ValidateOverflow(int64_t new_elements);
+
+ Status AppendEmptyValue() final;
+
+ Status AppendEmptyValues(int64_t length) final;
+
+ Status AppendArraySlice(const ArrayData& array, int64_t offset, int64_t length) final {
+ const uint8_t* validity = array.MayHaveNulls() ? array.buffers[0]->data() : NULLPTR;
+ for (int64_t row = offset; row < offset + length; row++) {
+ if (!validity || BitUtil::GetBit(validity, array.offset + row)) {
+ ARROW_RETURN_NOT_OK(value_builder_->AppendArraySlice(
+ *array.child_data[0], list_size_ * (array.offset + row), list_size_));
+ ARROW_RETURN_NOT_OK(Append());
+ } else {
+ ARROW_RETURN_NOT_OK(AppendNull());
+ }
+ }
+ return Status::OK();
+ }
+
+ ArrayBuilder* value_builder() const { return value_builder_.get(); }
+
+ std::shared_ptr<DataType> type() const override {
+ return fixed_size_list(value_field_->WithType(value_builder_->type()), list_size_);
+ }
+
+ // Cannot make this a static attribute because of linking issues
+ static constexpr int64_t maximum_elements() {
+ return std::numeric_limits<FixedSizeListType::offset_type>::max() - 1;
+ }
+
+ protected:
+ std::shared_ptr<Field> value_field_;
+ const int32_t list_size_;
+ std::shared_ptr<ArrayBuilder> value_builder_;
+};
+
+// ----------------------------------------------------------------------
+// Struct
+
+// ---------------------------------------------------------------------------------
+// StructArray builder
+/// Append, Resize and Reserve methods are acting on StructBuilder.
+/// Please make sure all these methods of all child-builders' are consistently
+/// called to maintain data-structure consistency.
+class ARROW_EXPORT StructBuilder : public ArrayBuilder {
+ public:
+ /// If any of field_builders has indeterminate type, this builder will also
+ StructBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool,
+ std::vector<std::shared_ptr<ArrayBuilder>> field_builders);
+
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<StructArray>* out) { return FinishTyped(out); }
+
+ /// Null bitmap is of equal length to every child field, and any zero byte
+ /// will be considered as a null for that field, but users must using app-
+ /// end methods or advance methods of the child builders' independently to
+ /// insert data.
+ Status AppendValues(int64_t length, const uint8_t* valid_bytes) {
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ UnsafeAppendToBitmap(valid_bytes, length);
+ return Status::OK();
+ }
+
+ /// Append an element to the Struct. All child-builders' Append method must
+ /// be called independently to maintain data-structure consistency.
+ Status Append(bool is_valid = true) {
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ UnsafeAppendToBitmap(is_valid);
+ return Status::OK();
+ }
+
+ /// \brief Append a null value. Automatically appends an empty value to each child
+ /// builder.
+ Status AppendNull() final {
+ for (const auto& field : children_) {
+ ARROW_RETURN_NOT_OK(field->AppendEmptyValue());
+ }
+ return Append(false);
+ }
+
+ /// \brief Append multiple null values. Automatically appends empty values to each
+ /// child builder.
+ Status AppendNulls(int64_t length) final {
+ for (const auto& field : children_) {
+ ARROW_RETURN_NOT_OK(field->AppendEmptyValues(length));
+ }
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ UnsafeAppendToBitmap(length, false);
+ return Status::OK();
+ }
+
+ Status AppendEmptyValue() final {
+ for (const auto& field : children_) {
+ ARROW_RETURN_NOT_OK(field->AppendEmptyValue());
+ }
+ return Append(true);
+ }
+
+ Status AppendEmptyValues(int64_t length) final {
+ for (const auto& field : children_) {
+ ARROW_RETURN_NOT_OK(field->AppendEmptyValues(length));
+ }
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ UnsafeAppendToBitmap(length, true);
+ return Status::OK();
+ }
+
+ Status AppendArraySlice(const ArrayData& array, int64_t offset,
+ int64_t length) override {
+ for (int i = 0; static_cast<size_t>(i) < children_.size(); i++) {
+ ARROW_RETURN_NOT_OK(children_[i]->AppendArraySlice(*array.child_data[i],
+ array.offset + offset, length));
+ }
+ const uint8_t* validity = array.MayHaveNulls() ? array.buffers[0]->data() : NULLPTR;
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ UnsafeAppendToBitmap(validity, array.offset + offset, length);
+ return Status::OK();
+ }
+
+ void Reset() override;
+
+ ArrayBuilder* field_builder(int i) const { return children_[i].get(); }
+
+ int num_fields() const { return static_cast<int>(children_.size()); }
+
+ std::shared_ptr<DataType> type() const override;
+
+ private:
+ std::shared_ptr<DataType> type_;
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_primitive.cc b/src/arrow/cpp/src/arrow/array/builder_primitive.cc
new file mode 100644
index 000000000..769c2f7d0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_primitive.cc
@@ -0,0 +1,145 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/builder_primitive.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/int_util.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+// ----------------------------------------------------------------------
+// Null builder
+
+Status NullBuilder::FinishInternal(std::shared_ptr<ArrayData>* out) {
+ *out = ArrayData::Make(null(), length_, {nullptr}, length_);
+ length_ = null_count_ = 0;
+ return Status::OK();
+}
+
+BooleanBuilder::BooleanBuilder(MemoryPool* pool)
+ : ArrayBuilder(pool), data_builder_(pool) {}
+
+BooleanBuilder::BooleanBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool)
+ : BooleanBuilder(pool) {
+ ARROW_CHECK_EQ(Type::BOOL, type->id());
+}
+
+void BooleanBuilder::Reset() {
+ ArrayBuilder::Reset();
+ data_builder_.Reset();
+}
+
+Status BooleanBuilder::Resize(int64_t capacity) {
+ RETURN_NOT_OK(CheckCapacity(capacity));
+ capacity = std::max(capacity, kMinBuilderCapacity);
+ RETURN_NOT_OK(data_builder_.Resize(capacity));
+ return ArrayBuilder::Resize(capacity);
+}
+
+Status BooleanBuilder::FinishInternal(std::shared_ptr<ArrayData>* out) {
+ ARROW_ASSIGN_OR_RAISE(auto null_bitmap, null_bitmap_builder_.FinishWithLength(length_));
+ ARROW_ASSIGN_OR_RAISE(auto data, data_builder_.FinishWithLength(length_));
+
+ *out = ArrayData::Make(boolean(), length_, {null_bitmap, data}, null_count_);
+
+ capacity_ = length_ = null_count_ = 0;
+ return Status::OK();
+}
+
+Status BooleanBuilder::AppendValues(const uint8_t* values, int64_t length,
+ const uint8_t* valid_bytes) {
+ RETURN_NOT_OK(Reserve(length));
+
+ int64_t i = 0;
+ data_builder_.UnsafeAppend<false>(length,
+ [values, &i]() -> bool { return values[i++] != 0; });
+ ArrayBuilder::UnsafeAppendToBitmap(valid_bytes, length);
+ return Status::OK();
+}
+
+Status BooleanBuilder::AppendValues(const uint8_t* values, int64_t length,
+ const uint8_t* validity, int64_t offset) {
+ RETURN_NOT_OK(Reserve(length));
+ data_builder_.UnsafeAppend(values, offset, length);
+ ArrayBuilder::UnsafeAppendToBitmap(validity, offset, length);
+ return Status::OK();
+}
+
+Status BooleanBuilder::AppendValues(const uint8_t* values, int64_t length,
+ const std::vector<bool>& is_valid) {
+ RETURN_NOT_OK(Reserve(length));
+ DCHECK_EQ(length, static_cast<int64_t>(is_valid.size()));
+ int64_t i = 0;
+ data_builder_.UnsafeAppend<false>(length,
+ [values, &i]() -> bool { return values[i++]; });
+ ArrayBuilder::UnsafeAppendToBitmap(is_valid);
+ return Status::OK();
+}
+
+Status BooleanBuilder::AppendValues(const std::vector<uint8_t>& values,
+ const std::vector<bool>& is_valid) {
+ return AppendValues(values.data(), static_cast<int64_t>(values.size()), is_valid);
+}
+
+Status BooleanBuilder::AppendValues(const std::vector<uint8_t>& values) {
+ return AppendValues(values.data(), static_cast<int64_t>(values.size()));
+}
+
+Status BooleanBuilder::AppendValues(const std::vector<bool>& values,
+ const std::vector<bool>& is_valid) {
+ const int64_t length = static_cast<int64_t>(values.size());
+ RETURN_NOT_OK(Reserve(length));
+ DCHECK_EQ(length, static_cast<int64_t>(is_valid.size()));
+ int64_t i = 0;
+ data_builder_.UnsafeAppend<false>(length,
+ [&values, &i]() -> bool { return values[i++]; });
+ ArrayBuilder::UnsafeAppendToBitmap(is_valid);
+ return Status::OK();
+}
+
+Status BooleanBuilder::AppendValues(const std::vector<bool>& values) {
+ const int64_t length = static_cast<int64_t>(values.size());
+ RETURN_NOT_OK(Reserve(length));
+ int64_t i = 0;
+ data_builder_.UnsafeAppend<false>(length,
+ [&values, &i]() -> bool { return values[i++]; });
+ ArrayBuilder::UnsafeSetNotNull(length);
+ return Status::OK();
+}
+
+Status BooleanBuilder::AppendValues(int64_t length, bool value) {
+ RETURN_NOT_OK(Reserve(length));
+ data_builder_.UnsafeAppend(length, value);
+ ArrayBuilder::UnsafeSetNotNull(length);
+ return Status::OK();
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_primitive.h b/src/arrow/cpp/src/arrow/array/builder_primitive.h
new file mode 100644
index 000000000..67d58fc9d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_primitive.h
@@ -0,0 +1,519 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+#include "arrow/array/builder_base.h"
+#include "arrow/array/data.h"
+#include "arrow/result.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+class ARROW_EXPORT NullBuilder : public ArrayBuilder {
+ public:
+ explicit NullBuilder(MemoryPool* pool = default_memory_pool()) : ArrayBuilder(pool) {}
+ explicit NullBuilder(const std::shared_ptr<DataType>& type,
+ MemoryPool* pool = default_memory_pool())
+ : NullBuilder(pool) {}
+
+ /// \brief Append the specified number of null elements
+ Status AppendNulls(int64_t length) final {
+ if (length < 0) return Status::Invalid("length must be positive");
+ null_count_ += length;
+ length_ += length;
+ return Status::OK();
+ }
+
+ /// \brief Append a single null element
+ Status AppendNull() final { return AppendNulls(1); }
+
+ Status AppendEmptyValues(int64_t length) final { return AppendNulls(length); }
+
+ Status AppendEmptyValue() final { return AppendEmptyValues(1); }
+
+ Status Append(std::nullptr_t) { return AppendNull(); }
+
+ Status AppendArraySlice(const ArrayData&, int64_t, int64_t length) override {
+ return AppendNulls(length);
+ }
+
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ std::shared_ptr<DataType> type() const override { return null(); }
+
+ Status Finish(std::shared_ptr<NullArray>* out) { return FinishTyped(out); }
+};
+
+/// Base class for all Builders that emit an Array of a scalar numerical type.
+template <typename T>
+class NumericBuilder : public ArrayBuilder {
+ public:
+ using TypeClass = T;
+ using value_type = typename T::c_type;
+ using ArrayType = typename TypeTraits<T>::ArrayType;
+
+ template <typename T1 = T>
+ explicit NumericBuilder(
+ enable_if_parameter_free<T1, MemoryPool*> pool = default_memory_pool())
+ : ArrayBuilder(pool), type_(TypeTraits<T>::type_singleton()), data_builder_(pool) {}
+
+ NumericBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool)
+ : ArrayBuilder(pool), type_(type), data_builder_(pool) {}
+
+ /// Append a single scalar and increase the size if necessary.
+ Status Append(const value_type val) {
+ ARROW_RETURN_NOT_OK(ArrayBuilder::Reserve(1));
+ UnsafeAppend(val);
+ return Status::OK();
+ }
+
+ /// Write nulls as uint8_t* (0 value indicates null) into pre-allocated memory
+ /// The memory at the corresponding data slot is set to 0 to prevent
+ /// uninitialized memory access
+ Status AppendNulls(int64_t length) final {
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ data_builder_.UnsafeAppend(length, value_type{}); // zero
+ UnsafeSetNull(length);
+ return Status::OK();
+ }
+
+ /// \brief Append a single null element
+ Status AppendNull() final {
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ data_builder_.UnsafeAppend(value_type{}); // zero
+ UnsafeAppendToBitmap(false);
+ return Status::OK();
+ }
+
+ /// \brief Append a empty element
+ Status AppendEmptyValue() final {
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ data_builder_.UnsafeAppend(value_type{}); // zero
+ UnsafeAppendToBitmap(true);
+ return Status::OK();
+ }
+
+ /// \brief Append several empty elements
+ Status AppendEmptyValues(int64_t length) final {
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ data_builder_.UnsafeAppend(length, value_type{}); // zero
+ UnsafeSetNotNull(length);
+ return Status::OK();
+ }
+
+ value_type GetValue(int64_t index) const { return data_builder_.data()[index]; }
+
+ void Reset() override { data_builder_.Reset(); }
+
+ Status Resize(int64_t capacity) override {
+ ARROW_RETURN_NOT_OK(CheckCapacity(capacity));
+ capacity = std::max(capacity, kMinBuilderCapacity);
+ ARROW_RETURN_NOT_OK(data_builder_.Resize(capacity));
+ return ArrayBuilder::Resize(capacity);
+ }
+
+ value_type operator[](int64_t index) const { return GetValue(index); }
+
+ value_type& operator[](int64_t index) {
+ return reinterpret_cast<value_type*>(data_builder_.mutable_data())[index];
+ }
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values a contiguous C array of values
+ /// \param[in] length the number of values to append
+ /// \param[in] valid_bytes an optional sequence of bytes where non-zero
+ /// indicates a valid (non-null) value
+ /// \return Status
+ Status AppendValues(const value_type* values, int64_t length,
+ const uint8_t* valid_bytes = NULLPTR) {
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ data_builder_.UnsafeAppend(values, length);
+ // length_ is update by these
+ ArrayBuilder::UnsafeAppendToBitmap(valid_bytes, length);
+ return Status::OK();
+ }
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values a contiguous C array of values
+ /// \param[in] length the number of values to append
+ /// \param[in] bitmap a validity bitmap to copy (may be null)
+ /// \param[in] bitmap_offset an offset into the validity bitmap
+ /// \return Status
+ Status AppendValues(const value_type* values, int64_t length, const uint8_t* bitmap,
+ int64_t bitmap_offset) {
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ data_builder_.UnsafeAppend(values, length);
+ // length_ is update by these
+ ArrayBuilder::UnsafeAppendToBitmap(bitmap, bitmap_offset, length);
+ return Status::OK();
+ }
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values a contiguous C array of values
+ /// \param[in] length the number of values to append
+ /// \param[in] is_valid an std::vector<bool> indicating valid (1) or null
+ /// (0). Equal in length to values
+ /// \return Status
+ Status AppendValues(const value_type* values, int64_t length,
+ const std::vector<bool>& is_valid) {
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ data_builder_.UnsafeAppend(values, length);
+ // length_ is update by these
+ ArrayBuilder::UnsafeAppendToBitmap(is_valid);
+ return Status::OK();
+ }
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values a std::vector of values
+ /// \param[in] is_valid an std::vector<bool> indicating valid (1) or null
+ /// (0). Equal in length to values
+ /// \return Status
+ Status AppendValues(const std::vector<value_type>& values,
+ const std::vector<bool>& is_valid) {
+ return AppendValues(values.data(), static_cast<int64_t>(values.size()), is_valid);
+ }
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values a std::vector of values
+ /// \return Status
+ Status AppendValues(const std::vector<value_type>& values) {
+ return AppendValues(values.data(), static_cast<int64_t>(values.size()));
+ }
+
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
+ ARROW_ASSIGN_OR_RAISE(auto null_bitmap,
+ null_bitmap_builder_.FinishWithLength(length_));
+ ARROW_ASSIGN_OR_RAISE(auto data, data_builder_.FinishWithLength(length_));
+ *out = ArrayData::Make(type(), length_, {null_bitmap, data}, null_count_);
+ capacity_ = length_ = null_count_ = 0;
+ return Status::OK();
+ }
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<ArrayType>* out) { return FinishTyped(out); }
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values_begin InputIterator to the beginning of the values
+ /// \param[in] values_end InputIterator pointing to the end of the values
+ /// \return Status
+ template <typename ValuesIter>
+ Status AppendValues(ValuesIter values_begin, ValuesIter values_end) {
+ int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ data_builder_.UnsafeAppend(values_begin, values_end);
+ // this updates the length_
+ UnsafeSetNotNull(length);
+ return Status::OK();
+ }
+
+ /// \brief Append a sequence of elements in one shot, with a specified nullmap
+ /// \param[in] values_begin InputIterator to the beginning of the values
+ /// \param[in] values_end InputIterator pointing to the end of the values
+ /// \param[in] valid_begin InputIterator with elements indication valid(1)
+ /// or null(0) values.
+ /// \return Status
+ template <typename ValuesIter, typename ValidIter>
+ enable_if_t<!std::is_pointer<ValidIter>::value, Status> AppendValues(
+ ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) {
+ static_assert(!internal::is_null_pointer<ValidIter>::value,
+ "Don't pass a NULLPTR directly as valid_begin, use the 2-argument "
+ "version instead");
+ int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ data_builder_.UnsafeAppend(values_begin, values_end);
+ null_bitmap_builder_.UnsafeAppend<true>(
+ length, [&valid_begin]() -> bool { return *valid_begin++; });
+ length_ = null_bitmap_builder_.length();
+ null_count_ = null_bitmap_builder_.false_count();
+ return Status::OK();
+ }
+
+ // Same as above, with a pointer type ValidIter
+ template <typename ValuesIter, typename ValidIter>
+ enable_if_t<std::is_pointer<ValidIter>::value, Status> AppendValues(
+ ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) {
+ int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ data_builder_.UnsafeAppend(values_begin, values_end);
+ // this updates the length_
+ if (valid_begin == NULLPTR) {
+ UnsafeSetNotNull(length);
+ } else {
+ null_bitmap_builder_.UnsafeAppend<true>(
+ length, [&valid_begin]() -> bool { return *valid_begin++; });
+ length_ = null_bitmap_builder_.length();
+ null_count_ = null_bitmap_builder_.false_count();
+ }
+
+ return Status::OK();
+ }
+
+ Status AppendArraySlice(const ArrayData& array, int64_t offset,
+ int64_t length) override {
+ return AppendValues(array.GetValues<value_type>(1) + offset, length,
+ array.GetValues<uint8_t>(0, 0), array.offset + offset);
+ }
+
+ /// Append a single scalar under the assumption that the underlying Buffer is
+ /// large enough.
+ ///
+ /// This method does not capacity-check; make sure to call Reserve
+ /// beforehand.
+ void UnsafeAppend(const value_type val) {
+ ArrayBuilder::UnsafeAppendToBitmap(true);
+ data_builder_.UnsafeAppend(val);
+ }
+
+ void UnsafeAppendNull() {
+ ArrayBuilder::UnsafeAppendToBitmap(false);
+ data_builder_.UnsafeAppend(value_type{}); // zero
+ }
+
+ std::shared_ptr<DataType> type() const override { return type_; }
+
+ protected:
+ std::shared_ptr<DataType> type_;
+ TypedBufferBuilder<value_type> data_builder_;
+};
+
+// Builders
+
+using UInt8Builder = NumericBuilder<UInt8Type>;
+using UInt16Builder = NumericBuilder<UInt16Type>;
+using UInt32Builder = NumericBuilder<UInt32Type>;
+using UInt64Builder = NumericBuilder<UInt64Type>;
+
+using Int8Builder = NumericBuilder<Int8Type>;
+using Int16Builder = NumericBuilder<Int16Type>;
+using Int32Builder = NumericBuilder<Int32Type>;
+using Int64Builder = NumericBuilder<Int64Type>;
+
+using HalfFloatBuilder = NumericBuilder<HalfFloatType>;
+using FloatBuilder = NumericBuilder<FloatType>;
+using DoubleBuilder = NumericBuilder<DoubleType>;
+
+class ARROW_EXPORT BooleanBuilder : public ArrayBuilder {
+ public:
+ using TypeClass = BooleanType;
+ using value_type = bool;
+
+ explicit BooleanBuilder(MemoryPool* pool = default_memory_pool());
+
+ BooleanBuilder(const std::shared_ptr<DataType>& type,
+ MemoryPool* pool = default_memory_pool());
+
+ /// Write nulls as uint8_t* (0 value indicates null) into pre-allocated memory
+ Status AppendNulls(int64_t length) final {
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ data_builder_.UnsafeAppend(length, false);
+ UnsafeSetNull(length);
+ return Status::OK();
+ }
+
+ Status AppendNull() final {
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ UnsafeAppendNull();
+ return Status::OK();
+ }
+
+ Status AppendEmptyValue() final {
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ data_builder_.UnsafeAppend(false);
+ UnsafeSetNotNull(1);
+ return Status::OK();
+ }
+
+ Status AppendEmptyValues(int64_t length) final {
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ data_builder_.UnsafeAppend(length, false);
+ UnsafeSetNotNull(length);
+ return Status::OK();
+ }
+
+ /// Scalar append
+ Status Append(const bool val) {
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ UnsafeAppend(val);
+ return Status::OK();
+ }
+
+ Status Append(const uint8_t val) { return Append(val != 0); }
+
+ /// Scalar append, without checking for capacity
+ void UnsafeAppend(const bool val) {
+ data_builder_.UnsafeAppend(val);
+ UnsafeAppendToBitmap(true);
+ }
+
+ void UnsafeAppendNull() {
+ data_builder_.UnsafeAppend(false);
+ UnsafeAppendToBitmap(false);
+ }
+
+ void UnsafeAppend(const uint8_t val) { UnsafeAppend(val != 0); }
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values a contiguous array of bytes (non-zero is 1)
+ /// \param[in] length the number of values to append
+ /// \param[in] valid_bytes an optional sequence of bytes where non-zero
+ /// indicates a valid (non-null) value
+ /// \return Status
+ Status AppendValues(const uint8_t* values, int64_t length,
+ const uint8_t* valid_bytes = NULLPTR);
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values a bitmap of values
+ /// \param[in] length the number of values to append
+ /// \param[in] validity a validity bitmap to copy (may be null)
+ /// \param[in] offset an offset into the values and validity bitmaps
+ /// \return Status
+ Status AppendValues(const uint8_t* values, int64_t length, const uint8_t* validity,
+ int64_t offset);
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values a contiguous C array of values
+ /// \param[in] length the number of values to append
+ /// \param[in] is_valid an std::vector<bool> indicating valid (1) or null
+ /// (0). Equal in length to values
+ /// \return Status
+ Status AppendValues(const uint8_t* values, int64_t length,
+ const std::vector<bool>& is_valid);
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values a std::vector of bytes
+ /// \param[in] is_valid an std::vector<bool> indicating valid (1) or null
+ /// (0). Equal in length to values
+ /// \return Status
+ Status AppendValues(const std::vector<uint8_t>& values,
+ const std::vector<bool>& is_valid);
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values a std::vector of bytes
+ /// \return Status
+ Status AppendValues(const std::vector<uint8_t>& values);
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values an std::vector<bool> indicating true (1) or false
+ /// \param[in] is_valid an std::vector<bool> indicating valid (1) or null
+ /// (0). Equal in length to values
+ /// \return Status
+ Status AppendValues(const std::vector<bool>& values, const std::vector<bool>& is_valid);
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values an std::vector<bool> indicating true (1) or false
+ /// \return Status
+ Status AppendValues(const std::vector<bool>& values);
+
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values_begin InputIterator to the beginning of the values
+ /// \param[in] values_end InputIterator pointing to the end of the values
+ /// or null(0) values
+ /// \return Status
+ template <typename ValuesIter>
+ Status AppendValues(ValuesIter values_begin, ValuesIter values_end) {
+ int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ data_builder_.UnsafeAppend<false>(
+ length, [&values_begin]() -> bool { return *values_begin++; });
+ // this updates length_
+ UnsafeSetNotNull(length);
+ return Status::OK();
+ }
+
+ /// \brief Append a sequence of elements in one shot, with a specified nullmap
+ /// \param[in] values_begin InputIterator to the beginning of the values
+ /// \param[in] values_end InputIterator pointing to the end of the values
+ /// \param[in] valid_begin InputIterator with elements indication valid(1)
+ /// or null(0) values
+ /// \return Status
+ template <typename ValuesIter, typename ValidIter>
+ enable_if_t<!std::is_pointer<ValidIter>::value, Status> AppendValues(
+ ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) {
+ static_assert(!internal::is_null_pointer<ValidIter>::value,
+ "Don't pass a NULLPTR directly as valid_begin, use the 2-argument "
+ "version instead");
+ int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
+ ARROW_RETURN_NOT_OK(Reserve(length));
+
+ data_builder_.UnsafeAppend<false>(
+ length, [&values_begin]() -> bool { return *values_begin++; });
+ null_bitmap_builder_.UnsafeAppend<true>(
+ length, [&valid_begin]() -> bool { return *valid_begin++; });
+ length_ = null_bitmap_builder_.length();
+ null_count_ = null_bitmap_builder_.false_count();
+ return Status::OK();
+ }
+
+ // Same as above, for a pointer type ValidIter
+ template <typename ValuesIter, typename ValidIter>
+ enable_if_t<std::is_pointer<ValidIter>::value, Status> AppendValues(
+ ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) {
+ int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ data_builder_.UnsafeAppend<false>(
+ length, [&values_begin]() -> bool { return *values_begin++; });
+
+ if (valid_begin == NULLPTR) {
+ UnsafeSetNotNull(length);
+ } else {
+ null_bitmap_builder_.UnsafeAppend<true>(
+ length, [&valid_begin]() -> bool { return *valid_begin++; });
+ }
+ length_ = null_bitmap_builder_.length();
+ null_count_ = null_bitmap_builder_.false_count();
+ return Status::OK();
+ }
+
+ Status AppendValues(int64_t length, bool value);
+
+ Status AppendArraySlice(const ArrayData& array, int64_t offset,
+ int64_t length) override {
+ return AppendValues(array.GetValues<uint8_t>(1, 0), length,
+ array.GetValues<uint8_t>(0, 0), array.offset + offset);
+ }
+
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<BooleanArray>* out) { return FinishTyped(out); }
+
+ void Reset() override;
+ Status Resize(int64_t capacity) override;
+
+ std::shared_ptr<DataType> type() const override { return boolean(); }
+
+ protected:
+ TypedBufferBuilder<bool> data_builder_;
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_time.h b/src/arrow/cpp/src/arrow/array/builder_time.h
new file mode 100644
index 000000000..55a7beaaa
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_time.h
@@ -0,0 +1,56 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Contains declarations of time related Arrow builder types.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/array/builder_base.h"
+#include "arrow/array/builder_primitive.h"
+
+namespace arrow {
+
+// TODO(ARROW-7938): this class is untested
+
+class ARROW_EXPORT DayTimeIntervalBuilder : public NumericBuilder<DayTimeIntervalType> {
+ public:
+ using DayMilliseconds = DayTimeIntervalType::DayMilliseconds;
+
+ explicit DayTimeIntervalBuilder(MemoryPool* pool = default_memory_pool())
+ : DayTimeIntervalBuilder(day_time_interval(), pool) {}
+
+ explicit DayTimeIntervalBuilder(std::shared_ptr<DataType> type,
+ MemoryPool* pool = default_memory_pool())
+ : NumericBuilder<DayTimeIntervalType>(type, pool) {}
+};
+
+class ARROW_EXPORT MonthDayNanoIntervalBuilder
+ : public NumericBuilder<MonthDayNanoIntervalType> {
+ public:
+ using MonthDayNanos = MonthDayNanoIntervalType::MonthDayNanos;
+
+ explicit MonthDayNanoIntervalBuilder(MemoryPool* pool = default_memory_pool())
+ : MonthDayNanoIntervalBuilder(month_day_nano_interval(), pool) {}
+
+ explicit MonthDayNanoIntervalBuilder(std::shared_ptr<DataType> type,
+ MemoryPool* pool = default_memory_pool())
+ : NumericBuilder<MonthDayNanoIntervalType>(type, pool) {}
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_union.cc b/src/arrow/cpp/src/arrow/array/builder_union.cc
new file mode 100644
index 000000000..6096b76ff
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_union.cc
@@ -0,0 +1,151 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/builder_union.h"
+
+#include <cstddef>
+#include <utility>
+
+#include "arrow/buffer.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+Status BasicUnionBuilder::FinishInternal(std::shared_ptr<ArrayData>* out) {
+ int64_t length = types_builder_.length();
+
+ std::shared_ptr<Buffer> types;
+ RETURN_NOT_OK(types_builder_.Finish(&types));
+
+ std::vector<std::shared_ptr<ArrayData>> child_data(children_.size());
+ for (size_t i = 0; i < children_.size(); ++i) {
+ RETURN_NOT_OK(children_[i]->FinishInternal(&child_data[i]));
+ }
+
+ *out = ArrayData::Make(type(), length, {nullptr, types}, /*null_count=*/0);
+ (*out)->child_data = std::move(child_data);
+ return Status::OK();
+}
+
+Status DenseUnionBuilder::AppendArraySlice(const ArrayData& array, const int64_t offset,
+ const int64_t length) {
+ const int8_t* type_codes = array.GetValues<int8_t>(1);
+ const int32_t* offsets = array.GetValues<int32_t>(2);
+ for (int64_t row = offset; row < offset + length; row++) {
+ const int8_t type_code = type_codes[row];
+ const int child_id = type_id_to_child_id_[type_code];
+ const int32_t union_offset = offsets[row];
+ RETURN_NOT_OK(Append(type_code));
+ RETURN_NOT_OK(type_id_to_children_[type_code]->AppendArraySlice(
+ *array.child_data[child_id], union_offset, /*length=*/1));
+ }
+ return Status::OK();
+}
+
+Status DenseUnionBuilder::FinishInternal(std::shared_ptr<ArrayData>* out) {
+ ARROW_RETURN_NOT_OK(BasicUnionBuilder::FinishInternal(out));
+ (*out)->buffers.resize(3);
+ ARROW_RETURN_NOT_OK(offsets_builder_.Finish(&(*out)->buffers[2]));
+ return Status::OK();
+}
+
+BasicUnionBuilder::BasicUnionBuilder(
+ MemoryPool* pool, const std::vector<std::shared_ptr<ArrayBuilder>>& children,
+ const std::shared_ptr<DataType>& type)
+ : ArrayBuilder(pool), child_fields_(children.size()), types_builder_(pool) {
+ const auto& union_type = checked_cast<const UnionType&>(*type);
+ mode_ = union_type.mode();
+
+ DCHECK_EQ(children.size(), union_type.type_codes().size());
+
+ type_codes_ = union_type.type_codes();
+ children_ = children;
+
+ type_id_to_child_id_.resize(union_type.max_type_code() + 1, -1);
+ type_id_to_children_.resize(union_type.max_type_code() + 1, nullptr);
+ DCHECK_LE(
+ type_id_to_children_.size() - 1,
+ static_cast<decltype(type_id_to_children_)::size_type>(UnionType::kMaxTypeCode));
+
+ for (size_t i = 0; i < children.size(); ++i) {
+ child_fields_[i] = union_type.field(static_cast<int>(i));
+
+ auto type_id = union_type.type_codes()[i];
+ type_id_to_child_id_[type_id] = static_cast<int>(i);
+ type_id_to_children_[type_id] = children[i].get();
+ }
+}
+
+int8_t BasicUnionBuilder::AppendChild(const std::shared_ptr<ArrayBuilder>& new_child,
+ const std::string& field_name) {
+ children_.push_back(new_child);
+ auto new_type_id = NextTypeId();
+
+ type_id_to_child_id_[new_type_id] = static_cast<int>(children_.size() - 1);
+ type_id_to_children_[new_type_id] = new_child.get();
+ child_fields_.push_back(field(field_name, nullptr));
+ type_codes_.push_back(static_cast<int8_t>(new_type_id));
+
+ return new_type_id;
+}
+
+std::shared_ptr<DataType> BasicUnionBuilder::type() const {
+ std::vector<std::shared_ptr<Field>> child_fields(child_fields_.size());
+ for (size_t i = 0; i < child_fields.size(); ++i) {
+ child_fields[i] = child_fields_[i]->WithType(children_[i]->type());
+ }
+ return mode_ == UnionMode::SPARSE ? sparse_union(std::move(child_fields), type_codes_)
+ : dense_union(std::move(child_fields), type_codes_);
+}
+
+int8_t BasicUnionBuilder::NextTypeId() {
+ // Find type_id such that type_id_to_children_[type_id] == nullptr
+ // and use that for the new child. Start searching at dense_type_id_
+ // since type_id_to_children_ is densely packed up at least up to dense_type_id_
+ for (; static_cast<size_t>(dense_type_id_) < type_id_to_children_.size();
+ ++dense_type_id_) {
+ if (type_id_to_children_[dense_type_id_] == nullptr) {
+ return dense_type_id_++;
+ }
+ }
+
+ DCHECK_LT(
+ type_id_to_children_.size(),
+ static_cast<decltype(type_id_to_children_)::size_type>(UnionType::kMaxTypeCode));
+
+ // type_id_to_children_ is already densely packed, so just append the new child
+ type_id_to_child_id_.resize(type_id_to_child_id_.size() + 1);
+ type_id_to_children_.resize(type_id_to_children_.size() + 1);
+ return dense_type_id_++;
+}
+
+Status SparseUnionBuilder::AppendArraySlice(const ArrayData& array, const int64_t offset,
+ const int64_t length) {
+ for (size_t i = 0; i < type_codes_.size(); i++) {
+ RETURN_NOT_OK(type_id_to_children_[type_codes_[i]]->AppendArraySlice(
+ *array.child_data[i], array.offset + offset, length));
+ }
+ const int8_t* type_codes = array.GetValues<int8_t>(1);
+ RETURN_NOT_OK(types_builder_.Append(type_codes + offset, length));
+ return Status::OK();
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/builder_union.h b/src/arrow/cpp/src/arrow/array/builder_union.h
new file mode 100644
index 000000000..c1a799e56
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/builder_union.h
@@ -0,0 +1,242 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_nested.h"
+#include "arrow/array/builder_base.h"
+#include "arrow/array/data.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+/// \brief Base class for union array builds.
+///
+/// Note that while we subclass ArrayBuilder, as union types do not have a
+/// validity bitmap, the bitmap builder member of ArrayBuilder is not used.
+class ARROW_EXPORT BasicUnionBuilder : public ArrayBuilder {
+ public:
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
+
+ /// \cond FALSE
+ using ArrayBuilder::Finish;
+ /// \endcond
+
+ Status Finish(std::shared_ptr<UnionArray>* out) { return FinishTyped(out); }
+
+ /// \brief Make a new child builder available to the UnionArray
+ ///
+ /// \param[in] new_child the child builder
+ /// \param[in] field_name the name of the field in the union array type
+ /// if type inference is used
+ /// \return child index, which is the "type" argument that needs
+ /// to be passed to the "Append" method to add a new element to
+ /// the union array.
+ int8_t AppendChild(const std::shared_ptr<ArrayBuilder>& new_child,
+ const std::string& field_name = "");
+
+ std::shared_ptr<DataType> type() const override;
+
+ int64_t length() const override { return types_builder_.length(); }
+
+ protected:
+ BasicUnionBuilder(MemoryPool* pool,
+ const std::vector<std::shared_ptr<ArrayBuilder>>& children,
+ const std::shared_ptr<DataType>& type);
+
+ int8_t NextTypeId();
+
+ std::vector<std::shared_ptr<Field>> child_fields_;
+ std::vector<int8_t> type_codes_;
+ UnionMode::type mode_;
+
+ std::vector<ArrayBuilder*> type_id_to_children_;
+ std::vector<int> type_id_to_child_id_;
+ // for all type_id < dense_type_id_, type_id_to_children_[type_id] != nullptr
+ int8_t dense_type_id_ = 0;
+ TypedBufferBuilder<int8_t> types_builder_;
+};
+
+/// \class DenseUnionBuilder
+///
+/// This API is EXPERIMENTAL.
+class ARROW_EXPORT DenseUnionBuilder : public BasicUnionBuilder {
+ public:
+ /// Use this constructor to initialize the UnionBuilder with no child builders,
+ /// allowing type to be inferred. You will need to call AppendChild for each of the
+ /// children builders you want to use.
+ explicit DenseUnionBuilder(MemoryPool* pool)
+ : BasicUnionBuilder(pool, {}, dense_union(FieldVector{})), offsets_builder_(pool) {}
+
+ /// Use this constructor to specify the type explicitly.
+ /// You can still add child builders to the union after using this constructor
+ DenseUnionBuilder(MemoryPool* pool,
+ const std::vector<std::shared_ptr<ArrayBuilder>>& children,
+ const std::shared_ptr<DataType>& type)
+ : BasicUnionBuilder(pool, children, type), offsets_builder_(pool) {}
+
+ Status AppendNull() final {
+ const int8_t first_child_code = type_codes_[0];
+ ArrayBuilder* child_builder = type_id_to_children_[first_child_code];
+ ARROW_RETURN_NOT_OK(types_builder_.Append(first_child_code));
+ ARROW_RETURN_NOT_OK(
+ offsets_builder_.Append(static_cast<int32_t>(child_builder->length())));
+ // Append a null arbitrarily to the first child
+ return child_builder->AppendNull();
+ }
+
+ Status AppendNulls(int64_t length) final {
+ const int8_t first_child_code = type_codes_[0];
+ ArrayBuilder* child_builder = type_id_to_children_[first_child_code];
+ ARROW_RETURN_NOT_OK(types_builder_.Append(length, first_child_code));
+ ARROW_RETURN_NOT_OK(
+ offsets_builder_.Append(length, static_cast<int32_t>(child_builder->length())));
+ // Append just a single null to the first child
+ return child_builder->AppendNull();
+ }
+
+ Status AppendEmptyValue() final {
+ const int8_t first_child_code = type_codes_[0];
+ ArrayBuilder* child_builder = type_id_to_children_[first_child_code];
+ ARROW_RETURN_NOT_OK(types_builder_.Append(first_child_code));
+ ARROW_RETURN_NOT_OK(
+ offsets_builder_.Append(static_cast<int32_t>(child_builder->length())));
+ // Append an empty value arbitrarily to the first child
+ return child_builder->AppendEmptyValue();
+ }
+
+ Status AppendEmptyValues(int64_t length) final {
+ const int8_t first_child_code = type_codes_[0];
+ ArrayBuilder* child_builder = type_id_to_children_[first_child_code];
+ ARROW_RETURN_NOT_OK(types_builder_.Append(length, first_child_code));
+ ARROW_RETURN_NOT_OK(
+ offsets_builder_.Append(length, static_cast<int32_t>(child_builder->length())));
+ // Append just a single empty value to the first child
+ return child_builder->AppendEmptyValue();
+ }
+
+ /// \brief Append an element to the UnionArray. This must be followed
+ /// by an append to the appropriate child builder.
+ ///
+ /// \param[in] next_type type_id of the child to which the next value will be appended.
+ ///
+ /// The corresponding child builder must be appended to independently after this method
+ /// is called.
+ Status Append(int8_t next_type) {
+ ARROW_RETURN_NOT_OK(types_builder_.Append(next_type));
+ if (type_id_to_children_[next_type]->length() == kListMaximumElements) {
+ return Status::CapacityError(
+ "a dense UnionArray cannot contain more than 2^31 - 1 elements from a single "
+ "child");
+ }
+ auto offset = static_cast<int32_t>(type_id_to_children_[next_type]->length());
+ return offsets_builder_.Append(offset);
+ }
+
+ Status AppendArraySlice(const ArrayData& array, int64_t offset,
+ int64_t length) override;
+
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
+
+ private:
+ TypedBufferBuilder<int32_t> offsets_builder_;
+};
+
+/// \class SparseUnionBuilder
+///
+/// This API is EXPERIMENTAL.
+class ARROW_EXPORT SparseUnionBuilder : public BasicUnionBuilder {
+ public:
+ /// Use this constructor to initialize the UnionBuilder with no child builders,
+ /// allowing type to be inferred. You will need to call AppendChild for each of the
+ /// children builders you want to use.
+ explicit SparseUnionBuilder(MemoryPool* pool)
+ : BasicUnionBuilder(pool, {}, sparse_union(FieldVector{})) {}
+
+ /// Use this constructor to specify the type explicitly.
+ /// You can still add child builders to the union after using this constructor
+ SparseUnionBuilder(MemoryPool* pool,
+ const std::vector<std::shared_ptr<ArrayBuilder>>& children,
+ const std::shared_ptr<DataType>& type)
+ : BasicUnionBuilder(pool, children, type) {}
+
+ /// \brief Append a null value.
+ ///
+ /// A null is appended to the first child, empty values to the other children.
+ Status AppendNull() final {
+ const auto first_child_code = type_codes_[0];
+ ARROW_RETURN_NOT_OK(types_builder_.Append(first_child_code));
+ ARROW_RETURN_NOT_OK(type_id_to_children_[first_child_code]->AppendNull());
+ for (int i = 1; i < static_cast<int>(type_codes_.size()); ++i) {
+ ARROW_RETURN_NOT_OK(type_id_to_children_[type_codes_[i]]->AppendEmptyValue());
+ }
+ return Status::OK();
+ }
+
+ /// \brief Append multiple null values.
+ ///
+ /// Nulls are appended to the first child, empty values to the other children.
+ Status AppendNulls(int64_t length) final {
+ const auto first_child_code = type_codes_[0];
+ ARROW_RETURN_NOT_OK(types_builder_.Append(length, first_child_code));
+ ARROW_RETURN_NOT_OK(type_id_to_children_[first_child_code]->AppendNulls(length));
+ for (int i = 1; i < static_cast<int>(type_codes_.size()); ++i) {
+ ARROW_RETURN_NOT_OK(
+ type_id_to_children_[type_codes_[i]]->AppendEmptyValues(length));
+ }
+ return Status::OK();
+ }
+
+ Status AppendEmptyValue() final {
+ ARROW_RETURN_NOT_OK(types_builder_.Append(type_codes_[0]));
+ for (int8_t code : type_codes_) {
+ ARROW_RETURN_NOT_OK(type_id_to_children_[code]->AppendEmptyValue());
+ }
+ return Status::OK();
+ }
+
+ Status AppendEmptyValues(int64_t length) final {
+ ARROW_RETURN_NOT_OK(types_builder_.Append(length, type_codes_[0]));
+ for (int8_t code : type_codes_) {
+ ARROW_RETURN_NOT_OK(type_id_to_children_[code]->AppendEmptyValues(length));
+ }
+ return Status::OK();
+ }
+
+ /// \brief Append an element to the UnionArray. This must be followed
+ /// by an append to the appropriate child builder.
+ ///
+ /// \param[in] next_type type_id of the child to which the next value will be appended.
+ ///
+ /// The corresponding child builder must be appended to independently after this method
+ /// is called, and all other child builders must have null or empty value appended.
+ Status Append(int8_t next_type) { return types_builder_.Append(next_type); }
+
+ Status AppendArraySlice(const ArrayData& array, int64_t offset,
+ int64_t length) override;
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/concatenate.cc b/src/arrow/cpp/src/arrow/array/concatenate.cc
new file mode 100644
index 000000000..54a75f06c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/concatenate.cc
@@ -0,0 +1,510 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/concatenate.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/data.h"
+#include "arrow/array/util.h"
+#include "arrow/buffer.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/int_util.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/logging.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::SafeSignedAdd;
+
+namespace {
+/// offset, length pair for representing a Range of a buffer or array
+struct Range {
+ int64_t offset = -1, length = 0;
+
+ Range() = default;
+ Range(int64_t o, int64_t l) : offset(o), length(l) {}
+};
+
+/// non-owning view into a range of bits
+struct Bitmap {
+ Bitmap() = default;
+ Bitmap(const uint8_t* d, Range r) : data(d), range(r) {}
+ explicit Bitmap(const std::shared_ptr<Buffer>& buffer, Range r)
+ : Bitmap(buffer ? buffer->data() : nullptr, r) {}
+
+ const uint8_t* data = nullptr;
+ Range range;
+
+ bool AllSet() const { return data == nullptr; }
+};
+
+// Allocate a buffer and concatenate bitmaps into it.
+Status ConcatenateBitmaps(const std::vector<Bitmap>& bitmaps, MemoryPool* pool,
+ std::shared_ptr<Buffer>* out) {
+ int64_t out_length = 0;
+ for (const auto& bitmap : bitmaps) {
+ if (internal::AddWithOverflow(out_length, bitmap.range.length, &out_length)) {
+ return Status::Invalid("Length overflow when concatenating arrays");
+ }
+ }
+ ARROW_ASSIGN_OR_RAISE(*out, AllocateBitmap(out_length, pool));
+ uint8_t* dst = (*out)->mutable_data();
+
+ int64_t bitmap_offset = 0;
+ for (auto bitmap : bitmaps) {
+ if (bitmap.AllSet()) {
+ BitUtil::SetBitsTo(dst, bitmap_offset, bitmap.range.length, true);
+ } else {
+ internal::CopyBitmap(bitmap.data, bitmap.range.offset, bitmap.range.length, dst,
+ bitmap_offset);
+ }
+ bitmap_offset += bitmap.range.length;
+ }
+
+ return Status::OK();
+}
+
+// Write offsets in src into dst, adjusting them such that first_offset
+// will be the first offset written.
+template <typename Offset>
+Status PutOffsets(const std::shared_ptr<Buffer>& src, Offset first_offset, Offset* dst,
+ Range* values_range);
+
+// Concatenate buffers holding offsets into a single buffer of offsets,
+// also computing the ranges of values spanned by each buffer of offsets.
+template <typename Offset>
+Status ConcatenateOffsets(const BufferVector& buffers, MemoryPool* pool,
+ std::shared_ptr<Buffer>* out,
+ std::vector<Range>* values_ranges) {
+ values_ranges->resize(buffers.size());
+
+ // allocate output buffer
+ int64_t out_length = 0;
+ for (const auto& buffer : buffers) {
+ out_length += buffer->size() / sizeof(Offset);
+ }
+ ARROW_ASSIGN_OR_RAISE(*out, AllocateBuffer((out_length + 1) * sizeof(Offset), pool));
+ auto dst = reinterpret_cast<Offset*>((*out)->mutable_data());
+
+ int64_t elements_length = 0;
+ Offset values_length = 0;
+ for (size_t i = 0; i < buffers.size(); ++i) {
+ // the first offset from buffers[i] will be adjusted to values_length
+ // (the cumulative length of values spanned by offsets in previous buffers)
+ RETURN_NOT_OK(PutOffsets<Offset>(buffers[i], values_length, &dst[elements_length],
+ &values_ranges->at(i)));
+ elements_length += buffers[i]->size() / sizeof(Offset);
+ values_length += static_cast<Offset>(values_ranges->at(i).length);
+ }
+
+ // the final element in dst is the length of all values spanned by the offsets
+ dst[out_length] = values_length;
+ return Status::OK();
+}
+
+template <typename Offset>
+Status PutOffsets(const std::shared_ptr<Buffer>& src, Offset first_offset, Offset* dst,
+ Range* values_range) {
+ if (src->size() == 0) {
+ // It's allowed to have an empty offsets buffer for a 0-length array
+ // (see Array::Validate)
+ values_range->offset = 0;
+ values_range->length = 0;
+ return Status::OK();
+ }
+
+ // Get the range of offsets to transfer from src
+ auto src_begin = reinterpret_cast<const Offset*>(src->data());
+ auto src_end = reinterpret_cast<const Offset*>(src->data() + src->size());
+
+ // Compute the range of values which is spanned by this range of offsets
+ values_range->offset = src_begin[0];
+ values_range->length = *src_end - values_range->offset;
+ if (first_offset > std::numeric_limits<Offset>::max() - values_range->length) {
+ return Status::Invalid("offset overflow while concatenating arrays");
+ }
+
+ // Write offsets into dst, ensuring that the first offset written is
+ // first_offset
+ auto adjustment = first_offset - src_begin[0];
+ // NOTE: Concatenate can be called during IPC reads to append delta dictionaries.
+ // Avoid UB on non-validated input by doing the addition in the unsigned domain.
+ // (the result can later be validated using Array::ValidateFull)
+ std::transform(src_begin, src_end, dst, [adjustment](Offset offset) {
+ return SafeSignedAdd(offset, adjustment);
+ });
+ return Status::OK();
+}
+
+class ConcatenateImpl {
+ public:
+ ConcatenateImpl(const ArrayDataVector& in, MemoryPool* pool)
+ : in_(std::move(in)), pool_(pool), out_(std::make_shared<ArrayData>()) {
+ out_->type = in[0]->type;
+ for (size_t i = 0; i < in_.size(); ++i) {
+ out_->length = SafeSignedAdd(out_->length, in[i]->length);
+ if (out_->null_count == kUnknownNullCount ||
+ in[i]->null_count == kUnknownNullCount) {
+ out_->null_count = kUnknownNullCount;
+ continue;
+ }
+ out_->null_count = SafeSignedAdd(out_->null_count.load(), in[i]->null_count.load());
+ }
+ out_->buffers.resize(in[0]->buffers.size());
+ out_->child_data.resize(in[0]->child_data.size());
+ for (auto& data : out_->child_data) {
+ data = std::make_shared<ArrayData>();
+ }
+ }
+
+ Status Concatenate(std::shared_ptr<ArrayData>* out) && {
+ if (out_->null_count != 0 && internal::HasValidityBitmap(out_->type->id())) {
+ RETURN_NOT_OK(ConcatenateBitmaps(Bitmaps(0), pool_, &out_->buffers[0]));
+ }
+ RETURN_NOT_OK(VisitTypeInline(*out_->type, this));
+ *out = std::move(out_);
+ return Status::OK();
+ }
+
+ Status Visit(const NullType&) { return Status::OK(); }
+
+ Status Visit(const BooleanType&) {
+ return ConcatenateBitmaps(Bitmaps(1), pool_, &out_->buffers[1]);
+ }
+
+ Status Visit(const FixedWidthType& fixed) {
+ // Handles numbers, decimal128, decimal256, fixed_size_binary
+ ARROW_ASSIGN_OR_RAISE(auto buffers, Buffers(1, fixed));
+ return ConcatenateBuffers(buffers, pool_).Value(&out_->buffers[1]);
+ }
+
+ Status Visit(const BinaryType&) {
+ std::vector<Range> value_ranges;
+ ARROW_ASSIGN_OR_RAISE(auto index_buffers, Buffers(1, sizeof(int32_t)));
+ RETURN_NOT_OK(ConcatenateOffsets<int32_t>(index_buffers, pool_, &out_->buffers[1],
+ &value_ranges));
+ ARROW_ASSIGN_OR_RAISE(auto value_buffers, Buffers(2, value_ranges));
+ return ConcatenateBuffers(value_buffers, pool_).Value(&out_->buffers[2]);
+ }
+
+ Status Visit(const LargeBinaryType&) {
+ std::vector<Range> value_ranges;
+ ARROW_ASSIGN_OR_RAISE(auto index_buffers, Buffers(1, sizeof(int64_t)));
+ RETURN_NOT_OK(ConcatenateOffsets<int64_t>(index_buffers, pool_, &out_->buffers[1],
+ &value_ranges));
+ ARROW_ASSIGN_OR_RAISE(auto value_buffers, Buffers(2, value_ranges));
+ return ConcatenateBuffers(value_buffers, pool_).Value(&out_->buffers[2]);
+ }
+
+ Status Visit(const ListType&) {
+ std::vector<Range> value_ranges;
+ ARROW_ASSIGN_OR_RAISE(auto index_buffers, Buffers(1, sizeof(int32_t)));
+ RETURN_NOT_OK(ConcatenateOffsets<int32_t>(index_buffers, pool_, &out_->buffers[1],
+ &value_ranges));
+ ARROW_ASSIGN_OR_RAISE(auto child_data, ChildData(0, value_ranges));
+ return ConcatenateImpl(child_data, pool_).Concatenate(&out_->child_data[0]);
+ }
+
+ Status Visit(const LargeListType&) {
+ std::vector<Range> value_ranges;
+ ARROW_ASSIGN_OR_RAISE(auto index_buffers, Buffers(1, sizeof(int64_t)));
+ RETURN_NOT_OK(ConcatenateOffsets<int64_t>(index_buffers, pool_, &out_->buffers[1],
+ &value_ranges));
+ ARROW_ASSIGN_OR_RAISE(auto child_data, ChildData(0, value_ranges));
+ return ConcatenateImpl(child_data, pool_).Concatenate(&out_->child_data[0]);
+ }
+
+ Status Visit(const FixedSizeListType& fixed_size_list) {
+ ARROW_ASSIGN_OR_RAISE(auto child_data, ChildData(0, fixed_size_list.list_size()));
+ return ConcatenateImpl(child_data, pool_).Concatenate(&out_->child_data[0]);
+ }
+
+ Status Visit(const StructType& s) {
+ for (int i = 0; i < s.num_fields(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto child_data, ChildData(i));
+ RETURN_NOT_OK(ConcatenateImpl(child_data, pool_).Concatenate(&out_->child_data[i]));
+ }
+ return Status::OK();
+ }
+
+ Result<BufferVector> UnifyDictionaries(const DictionaryType& d) {
+ BufferVector new_index_lookup;
+ ARROW_ASSIGN_OR_RAISE(auto unifier, DictionaryUnifier::Make(d.value_type()));
+ new_index_lookup.resize(in_.size());
+ for (size_t i = 0; i < in_.size(); i++) {
+ auto item = in_[i];
+ auto dictionary_array = MakeArray(item->dictionary);
+ RETURN_NOT_OK(unifier->Unify(*dictionary_array, &new_index_lookup[i]));
+ }
+ std::shared_ptr<Array> out_dictionary;
+ RETURN_NOT_OK(unifier->GetResultWithIndexType(d.index_type(), &out_dictionary));
+ out_->dictionary = out_dictionary->data();
+ return new_index_lookup;
+ }
+
+ // Transpose and concatenate dictionary indices
+ Result<std::shared_ptr<Buffer>> ConcatenateDictionaryIndices(
+ const DataType& index_type, const BufferVector& index_transpositions) {
+ const auto index_width =
+ internal::checked_cast<const FixedWidthType&>(index_type).bit_width() / 8;
+ int64_t out_length = 0;
+ for (const auto& data : in_) {
+ out_length += data->length;
+ }
+ ARROW_ASSIGN_OR_RAISE(auto out, AllocateBuffer(out_length * index_width, pool_));
+ uint8_t* out_data = out->mutable_data();
+ for (size_t i = 0; i < in_.size(); i++) {
+ const auto& data = in_[i];
+ auto transpose_map =
+ reinterpret_cast<const int32_t*>(index_transpositions[i]->data());
+ const uint8_t* src = data->GetValues<uint8_t>(1, 0);
+ if (!data->buffers[0]) {
+ RETURN_NOT_OK(internal::TransposeInts(index_type, index_type,
+ /*src=*/data->GetValues<uint8_t>(1, 0),
+ /*dest=*/out_data,
+ /*src_offset=*/data->offset,
+ /*dest_offset=*/0, /*length=*/data->length,
+ transpose_map));
+ } else {
+ internal::BitRunReader reader(data->buffers[0]->data(), data->offset,
+ data->length);
+ int64_t position = 0;
+ while (true) {
+ internal::BitRun run = reader.NextRun();
+ if (run.length == 0) break;
+
+ if (run.set) {
+ RETURN_NOT_OK(internal::TransposeInts(index_type, index_type, src,
+ /*dest=*/out_data,
+ /*src_offset=*/data->offset + position,
+ /*dest_offset=*/position, run.length,
+ transpose_map));
+ } else {
+ std::fill(out_data + position,
+ out_data + position + (run.length * index_width), 0x00);
+ }
+
+ position += run.length;
+ }
+ }
+ out_data += data->length * index_width;
+ }
+ return std::move(out);
+ }
+
+ Status Visit(const DictionaryType& d) {
+ auto fixed = internal::checked_cast<const FixedWidthType*>(d.index_type().get());
+
+ // Two cases: all the dictionaries are the same, or unification is
+ // required
+ bool dictionaries_same = true;
+ std::shared_ptr<Array> dictionary0 = MakeArray(in_[0]->dictionary);
+ for (size_t i = 1; i < in_.size(); ++i) {
+ if (!MakeArray(in_[i]->dictionary)->Equals(dictionary0)) {
+ dictionaries_same = false;
+ break;
+ }
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto index_buffers, Buffers(1, *fixed));
+ if (dictionaries_same) {
+ out_->dictionary = in_[0]->dictionary;
+ return ConcatenateBuffers(index_buffers, pool_).Value(&out_->buffers[1]);
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto index_lookup, UnifyDictionaries(d));
+ ARROW_ASSIGN_OR_RAISE(out_->buffers[1],
+ ConcatenateDictionaryIndices(*fixed, index_lookup));
+ return Status::OK();
+ }
+ }
+
+ Status Visit(const UnionType& u) {
+ return Status::NotImplemented("concatenation of ", u);
+ }
+
+ Status Visit(const ExtensionType& e) {
+ // XXX can we just concatenate their storage?
+ return Status::NotImplemented("concatenation of ", e);
+ }
+
+ private:
+ // NOTE: Concatenate() can be called during IPC reads to append delta dictionaries
+ // on non-validated input. Therefore, the input-checking SliceBufferSafe and
+ // ArrayData::SliceSafe are used below.
+
+ // Gather the index-th buffer of each input into a vector.
+ // Bytes are sliced with that input's offset and length.
+ // Note that BufferVector will not contain the buffer of in_[i] if it's
+ // nullptr.
+ Result<BufferVector> Buffers(size_t index) {
+ BufferVector buffers;
+ buffers.reserve(in_.size());
+ for (const auto& array_data : in_) {
+ const auto& buffer = array_data->buffers[index];
+ if (buffer != nullptr) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto sliced_buffer,
+ SliceBufferSafe(buffer, array_data->offset, array_data->length));
+ buffers.push_back(std::move(sliced_buffer));
+ }
+ }
+ return buffers;
+ }
+
+ // Gather the index-th buffer of each input into a vector.
+ // Bytes are sliced with the explicitly passed ranges.
+ // Note that BufferVector will not contain the buffer of in_[i] if it's
+ // nullptr.
+ Result<BufferVector> Buffers(size_t index, const std::vector<Range>& ranges) {
+ DCHECK_EQ(in_.size(), ranges.size());
+ BufferVector buffers;
+ buffers.reserve(in_.size());
+ for (size_t i = 0; i < in_.size(); ++i) {
+ const auto& buffer = in_[i]->buffers[index];
+ if (buffer != nullptr) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto sliced_buffer,
+ SliceBufferSafe(buffer, ranges[i].offset, ranges[i].length));
+ buffers.push_back(std::move(sliced_buffer));
+ } else {
+ DCHECK_EQ(ranges[i].length, 0);
+ }
+ }
+ return buffers;
+ }
+
+ // Gather the index-th buffer of each input into a vector.
+ // Buffers are assumed to contain elements of the given byte_width,
+ // those elements are sliced with that input's offset and length.
+ // Note that BufferVector will not contain the buffer of in_[i] if it's
+ // nullptr.
+ Result<BufferVector> Buffers(size_t index, int byte_width) {
+ BufferVector buffers;
+ buffers.reserve(in_.size());
+ for (const auto& array_data : in_) {
+ const auto& buffer = array_data->buffers[index];
+ if (buffer != nullptr) {
+ ARROW_ASSIGN_OR_RAISE(auto sliced_buffer,
+ SliceBufferSafe(buffer, array_data->offset * byte_width,
+ array_data->length * byte_width));
+ buffers.push_back(std::move(sliced_buffer));
+ }
+ }
+ return buffers;
+ }
+
+ // Gather the index-th buffer of each input into a vector.
+ // Buffers are assumed to contain elements of fixed.bit_width(),
+ // those elements are sliced with that input's offset and length.
+ // Note that BufferVector will not contain the buffer of in_[i] if it's
+ // nullptr.
+ Result<BufferVector> Buffers(size_t index, const FixedWidthType& fixed) {
+ DCHECK_EQ(fixed.bit_width() % 8, 0);
+ return Buffers(index, fixed.bit_width() / 8);
+ }
+
+ // Gather the index-th buffer of each input as a Bitmap
+ // into a vector of Bitmaps.
+ std::vector<Bitmap> Bitmaps(size_t index) {
+ std::vector<Bitmap> bitmaps(in_.size());
+ for (size_t i = 0; i < in_.size(); ++i) {
+ Range range(in_[i]->offset, in_[i]->length);
+ bitmaps[i] = Bitmap(in_[i]->buffers[index], range);
+ }
+ return bitmaps;
+ }
+
+ // Gather the index-th child_data of each input into a vector.
+ // Elements are sliced with that input's offset and length.
+ Result<ArrayDataVector> ChildData(size_t index) {
+ ArrayDataVector child_data(in_.size());
+ for (size_t i = 0; i < in_.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(child_data[i], in_[i]->child_data[index]->SliceSafe(
+ in_[i]->offset, in_[i]->length));
+ }
+ return child_data;
+ }
+
+ // Gather the index-th child_data of each input into a vector.
+ // Elements are sliced with that input's offset and length multiplied by multiplier.
+ Result<ArrayDataVector> ChildData(size_t index, size_t multiplier) {
+ ArrayDataVector child_data(in_.size());
+ for (size_t i = 0; i < in_.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(
+ child_data[i], in_[i]->child_data[index]->SliceSafe(
+ in_[i]->offset * multiplier, in_[i]->length * multiplier));
+ }
+ return child_data;
+ }
+
+ // Gather the index-th child_data of each input into a vector.
+ // Elements are sliced with the explicitly passed ranges.
+ Result<ArrayDataVector> ChildData(size_t index, const std::vector<Range>& ranges) {
+ DCHECK_EQ(in_.size(), ranges.size());
+ ArrayDataVector child_data(in_.size());
+ for (size_t i = 0; i < in_.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(child_data[i], in_[i]->child_data[index]->SliceSafe(
+ ranges[i].offset, ranges[i].length));
+ }
+ return child_data;
+ }
+
+ const ArrayDataVector& in_;
+ MemoryPool* pool_;
+ std::shared_ptr<ArrayData> out_;
+};
+
+} // namespace
+
+Result<std::shared_ptr<Array>> Concatenate(const ArrayVector& arrays, MemoryPool* pool) {
+ if (arrays.size() == 0) {
+ return Status::Invalid("Must pass at least one array");
+ }
+
+ // gather ArrayData of input arrays
+ ArrayDataVector data(arrays.size());
+ for (size_t i = 0; i < arrays.size(); ++i) {
+ if (!arrays[i]->type()->Equals(*arrays[0]->type())) {
+ return Status::Invalid("arrays to be concatenated must be identically typed, but ",
+ *arrays[0]->type(), " and ", *arrays[i]->type(),
+ " were encountered.");
+ }
+ data[i] = arrays[i]->data();
+ }
+
+ std::shared_ptr<ArrayData> out_data;
+ RETURN_NOT_OK(ConcatenateImpl(data, pool).Concatenate(&out_data));
+ return MakeArray(std::move(out_data));
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/concatenate.h b/src/arrow/cpp/src/arrow/array/concatenate.h
new file mode 100644
index 000000000..e7597aad8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/concatenate.h
@@ -0,0 +1,37 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+/// \brief Concatenate arrays
+///
+/// \param[in] arrays a vector of arrays to be concatenated
+/// \param[in] pool memory to store the result will be allocated from this memory pool
+/// \return the concatenated array
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> Concatenate(const ArrayVector& arrays,
+ MemoryPool* pool = default_memory_pool());
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/concatenate_test.cc b/src/arrow/cpp/src/arrow/array/concatenate_test.cc
new file mode 100644
index 000000000..305910c24
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/concatenate_test.cc
@@ -0,0 +1,398 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <array>
+#include <cstdint>
+#include <cstring>
+#include <iterator>
+#include <limits>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/buffer.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+
+namespace arrow {
+
+class ConcatenateTest : public ::testing::Test {
+ protected:
+ ConcatenateTest()
+ : rng_(seed_),
+ sizes_({0, 1, 2, 4, 16, 31, 1234}),
+ null_probabilities_({0.0, 0.1, 0.5, 0.9, 1.0}) {}
+
+ template <typename OffsetType>
+ std::vector<OffsetType> Offsets(int32_t length, int32_t slice_count) {
+ std::vector<OffsetType> offsets(static_cast<std::size_t>(slice_count + 1));
+ std::default_random_engine gen(seed_);
+ std::uniform_int_distribution<OffsetType> dist(0, length);
+ std::generate(offsets.begin(), offsets.end(), [&] { return dist(gen); });
+ std::sort(offsets.begin(), offsets.end());
+ return offsets;
+ }
+
+ ArrayVector Slices(const std::shared_ptr<Array>& array,
+ const std::vector<int32_t>& offsets) {
+ ArrayVector slices(offsets.size() - 1);
+ for (size_t i = 0; i != slices.size(); ++i) {
+ slices[i] = array->Slice(offsets[i], offsets[i + 1] - offsets[i]);
+ }
+ return slices;
+ }
+
+ template <typename PrimitiveType>
+ std::shared_ptr<Array> GeneratePrimitive(int64_t size, double null_probability) {
+ if (std::is_same<PrimitiveType, BooleanType>::value) {
+ return rng_.Boolean(size, 0.5, null_probability);
+ }
+ return rng_.Numeric<PrimitiveType, uint8_t>(size, 0, 127, null_probability);
+ }
+
+ void CheckTrailingBitsAreZeroed(const std::shared_ptr<Buffer>& bitmap, int64_t length) {
+ if (auto preceding_bits = BitUtil::kPrecedingBitmask[length % 8]) {
+ auto last_byte = bitmap->data()[length / 8];
+ ASSERT_EQ(static_cast<uint8_t>(last_byte & preceding_bits), last_byte)
+ << length << " " << int(preceding_bits);
+ }
+ }
+
+ template <typename ArrayFactory>
+ void Check(ArrayFactory&& factory) {
+ for (auto size : this->sizes_) {
+ auto offsets = this->Offsets<int32_t>(size, 3);
+ for (auto null_probability : this->null_probabilities_) {
+ std::shared_ptr<Array> array;
+ factory(size, null_probability, &array);
+ auto expected = array->Slice(offsets.front(), offsets.back() - offsets.front());
+ auto slices = this->Slices(array, offsets);
+ ASSERT_OK_AND_ASSIGN(auto actual, Concatenate(slices));
+ AssertArraysEqual(*expected, *actual);
+ if (actual->data()->buffers[0]) {
+ CheckTrailingBitsAreZeroed(actual->data()->buffers[0], actual->length());
+ }
+ if (actual->type_id() == Type::BOOL) {
+ CheckTrailingBitsAreZeroed(actual->data()->buffers[1], actual->length());
+ }
+ }
+ }
+ }
+
+ random::SeedType seed_ = 0xdeadbeef;
+ random::RandomArrayGenerator rng_;
+ std::vector<int32_t> sizes_;
+ std::vector<double> null_probabilities_;
+};
+
+TEST(ConcatenateEmptyArraysTest, TestValueBuffersNullPtr) {
+ ArrayVector inputs;
+
+ std::shared_ptr<Array> binary_array;
+ BinaryBuilder builder;
+ ASSERT_OK(builder.Finish(&binary_array));
+ inputs.push_back(std::move(binary_array));
+
+ builder.Reset();
+ ASSERT_OK(builder.AppendNull());
+ ASSERT_OK(builder.Finish(&binary_array));
+ inputs.push_back(std::move(binary_array));
+
+ ASSERT_OK_AND_ASSIGN(auto actual, Concatenate(inputs));
+ AssertArraysEqual(*actual, *inputs[1]);
+}
+
+template <typename PrimitiveType>
+class PrimitiveConcatenateTest : public ConcatenateTest {
+ public:
+};
+
+TYPED_TEST_SUITE(PrimitiveConcatenateTest, PrimitiveArrowTypes);
+
+TYPED_TEST(PrimitiveConcatenateTest, Primitives) {
+ this->Check([this](int64_t size, double null_probability, std::shared_ptr<Array>* out) {
+ *out = this->template GeneratePrimitive<TypeParam>(size, null_probability);
+ });
+}
+
+TEST_F(ConcatenateTest, NullType) {
+ Check([](int32_t size, double null_probability, std::shared_ptr<Array>* out) {
+ *out = std::make_shared<NullArray>(size);
+ });
+}
+
+TEST_F(ConcatenateTest, StringType) {
+ Check([this](int32_t size, double null_probability, std::shared_ptr<Array>* out) {
+ *out = rng_.String(size, /*min_length =*/0, /*max_length =*/15, null_probability);
+ ASSERT_OK((**out).ValidateFull());
+ });
+}
+
+TEST_F(ConcatenateTest, LargeStringType) {
+ Check([this](int32_t size, double null_probability, std::shared_ptr<Array>* out) {
+ *out =
+ rng_.LargeString(size, /*min_length =*/0, /*max_length =*/15, null_probability);
+ ASSERT_OK((**out).ValidateFull());
+ });
+}
+
+TEST_F(ConcatenateTest, FixedSizeListType) {
+ Check([this](int32_t size, double null_probability, std::shared_ptr<Array>* out) {
+ auto list_size = 3;
+ auto values_size = size * list_size;
+ auto values = this->GeneratePrimitive<Int8Type>(values_size, null_probability);
+ ASSERT_OK_AND_ASSIGN(*out, FixedSizeListArray::FromArrays(values, list_size));
+ ASSERT_OK((**out).ValidateFull());
+ });
+}
+
+TEST_F(ConcatenateTest, ListType) {
+ Check([this](int32_t size, double null_probability, std::shared_ptr<Array>* out) {
+ auto values_size = size * 4;
+ auto values = this->GeneratePrimitive<Int8Type>(values_size, null_probability);
+ auto offsets_vector = this->Offsets<int32_t>(values_size, size);
+ // Ensure first and last offsets encompass the whole values array
+ offsets_vector.front() = 0;
+ offsets_vector.back() = static_cast<int32_t>(values_size);
+ std::shared_ptr<Array> offsets;
+ ArrayFromVector<Int32Type>(offsets_vector, &offsets);
+ ASSERT_OK_AND_ASSIGN(*out, ListArray::FromArrays(*offsets, *values));
+ ASSERT_OK((**out).ValidateFull());
+ });
+}
+
+TEST_F(ConcatenateTest, LargeListType) {
+ Check([this](int32_t size, double null_probability, std::shared_ptr<Array>* out) {
+ auto values_size = size * 4;
+ auto values = this->GeneratePrimitive<Int8Type>(values_size, null_probability);
+ auto offsets_vector = this->Offsets<int64_t>(values_size, size);
+ // Ensure first and last offsets encompass the whole values array
+ offsets_vector.front() = 0;
+ offsets_vector.back() = static_cast<int64_t>(values_size);
+ std::shared_ptr<Array> offsets;
+ ArrayFromVector<Int64Type>(offsets_vector, &offsets);
+ ASSERT_OK_AND_ASSIGN(*out, LargeListArray::FromArrays(*offsets, *values));
+ ASSERT_OK((**out).ValidateFull());
+ });
+}
+
+TEST_F(ConcatenateTest, StructType) {
+ Check([this](int32_t size, double null_probability, std::shared_ptr<Array>* out) {
+ auto foo = this->GeneratePrimitive<Int8Type>(size, null_probability);
+ auto bar = this->GeneratePrimitive<DoubleType>(size, null_probability);
+ auto baz = this->GeneratePrimitive<BooleanType>(size, null_probability);
+ *out = std::make_shared<StructArray>(
+ struct_({field("foo", int8()), field("bar", float64()), field("baz", boolean())}),
+ size, ArrayVector{foo, bar, baz});
+ });
+}
+
+TEST_F(ConcatenateTest, DictionaryType) {
+ Check([this](int32_t size, double null_probability, std::shared_ptr<Array>* out) {
+ auto indices = this->GeneratePrimitive<Int32Type>(size, null_probability);
+ auto dict = this->GeneratePrimitive<DoubleType>(128, 0);
+ auto type = dictionary(int32(), dict->type());
+ *out = std::make_shared<DictionaryArray>(type, indices, dict);
+ });
+}
+
+TEST_F(ConcatenateTest, DictionaryTypeDifferentDictionaries) {
+ {
+ auto dict_type = dictionary(uint8(), utf8());
+ auto dict_one = DictArrayFromJSON(dict_type, "[1, 2, null, 3, 0]",
+ "[\"A0\", \"A1\", \"A2\", \"A3\"]");
+ auto dict_two = DictArrayFromJSON(dict_type, "[null, 4, 2, 1]",
+ "[\"B0\", \"B1\", \"B2\", \"B3\", \"B4\"]");
+ auto concat_expected = DictArrayFromJSON(
+ dict_type, "[1, 2, null, 3, 0, null, 8, 6, 5]",
+ "[\"A0\", \"A1\", \"A2\", \"A3\", \"B0\", \"B1\", \"B2\", \"B3\", \"B4\"]");
+ ASSERT_OK_AND_ASSIGN(auto concat_actual, Concatenate({dict_one, dict_two}));
+ AssertArraysEqual(*concat_expected, *concat_actual);
+ }
+ {
+ const int SIZE = 500;
+ auto dict_type = dictionary(uint16(), utf8());
+
+ UInt16Builder index_builder;
+ UInt16Builder expected_index_builder;
+ ASSERT_OK(index_builder.Reserve(SIZE));
+ ASSERT_OK(expected_index_builder.Reserve(SIZE * 2));
+ for (auto i = 0; i < SIZE; i++) {
+ index_builder.UnsafeAppend(i);
+ expected_index_builder.UnsafeAppend(i);
+ }
+ for (auto i = SIZE; i < 2 * SIZE; i++) {
+ expected_index_builder.UnsafeAppend(i);
+ }
+ ASSERT_OK_AND_ASSIGN(auto indices, index_builder.Finish());
+ ASSERT_OK_AND_ASSIGN(auto expected_indices, expected_index_builder.Finish());
+
+ // Creates three dictionaries. The first maps i->"{i}" the second maps i->"{500+i}",
+ // each for 500 values and the third maps i->"{i}" but for 1000 values.
+ // The first and second concatenated should end up equaling the third. All strings
+ // are padded to length 8 so we can know the size ahead of time.
+ StringBuilder values_one_builder;
+ StringBuilder values_two_builder;
+ ASSERT_OK(values_one_builder.Resize(SIZE));
+ ASSERT_OK(values_two_builder.Resize(SIZE));
+ ASSERT_OK(values_one_builder.ReserveData(8 * SIZE));
+ ASSERT_OK(values_two_builder.ReserveData(8 * SIZE));
+ for (auto i = 0; i < SIZE; i++) {
+ auto i_str = std::to_string(i);
+ auto padded = i_str.insert(0, 8 - i_str.length(), '0');
+ values_one_builder.UnsafeAppend(padded);
+ auto upper_i_str = std::to_string(i + SIZE);
+ auto upper_padded = upper_i_str.insert(0, 8 - i_str.length(), '0');
+ values_two_builder.UnsafeAppend(upper_padded);
+ }
+ ASSERT_OK_AND_ASSIGN(auto dictionary_one, values_one_builder.Finish());
+ ASSERT_OK_AND_ASSIGN(auto dictionary_two, values_two_builder.Finish());
+ ASSERT_OK_AND_ASSIGN(auto expected_dictionary,
+ Concatenate({dictionary_one, dictionary_two}))
+
+ auto one = std::make_shared<DictionaryArray>(dict_type, indices, dictionary_one);
+ auto two = std::make_shared<DictionaryArray>(dict_type, indices, dictionary_two);
+ auto expected = std::make_shared<DictionaryArray>(dict_type, expected_indices,
+ expected_dictionary);
+ ASSERT_OK_AND_ASSIGN(auto combined, Concatenate({one, two}));
+ AssertArraysEqual(*combined, *expected);
+ }
+}
+
+TEST_F(ConcatenateTest, DictionaryTypePartialOverlapDictionaries) {
+ auto dict_type = dictionary(uint8(), utf8());
+ auto dict_one = DictArrayFromJSON(dict_type, "[1, 2, null, 3, 0]",
+ "[\"A0\", \"A1\", \"C2\", \"C3\"]");
+ auto dict_two = DictArrayFromJSON(dict_type, "[null, 4, 2, 1]",
+ "[\"B0\", \"B1\", \"C2\", \"C3\", \"B4\"]");
+ auto concat_expected =
+ DictArrayFromJSON(dict_type, "[1, 2, null, 3, 0, null, 6, 2, 5]",
+ "[\"A0\", \"A1\", \"C2\", \"C3\", \"B0\", \"B1\", \"B4\"]");
+ ASSERT_OK_AND_ASSIGN(auto concat_actual, Concatenate({dict_one, dict_two}));
+ AssertArraysEqual(*concat_expected, *concat_actual);
+}
+
+TEST_F(ConcatenateTest, DictionaryTypeDifferentSizeIndex) {
+ auto dict_type = dictionary(uint8(), utf8());
+ auto bigger_dict_type = dictionary(uint16(), utf8());
+ auto dict_one = DictArrayFromJSON(dict_type, "[0]", "[\"A0\"]");
+ auto dict_two = DictArrayFromJSON(bigger_dict_type, "[0]", "[\"B0\"]");
+ ASSERT_RAISES(Invalid, Concatenate({dict_one, dict_two}).status());
+}
+
+TEST_F(ConcatenateTest, DictionaryTypeCantUnifyNullInDictionary) {
+ auto dict_type = dictionary(uint8(), utf8());
+ auto dict_one = DictArrayFromJSON(dict_type, "[0, 1]", "[null, \"A\"]");
+ auto dict_two = DictArrayFromJSON(dict_type, "[0, 1]", "[null, \"B\"]");
+ ASSERT_RAISES(Invalid, Concatenate({dict_one, dict_two}).status());
+}
+
+TEST_F(ConcatenateTest, DictionaryTypeEnlargedIndices) {
+ auto size = std::numeric_limits<uint8_t>::max() + 1;
+ auto dict_type = dictionary(uint8(), uint16());
+
+ UInt8Builder index_builder;
+ ASSERT_OK(index_builder.Reserve(size));
+ for (auto i = 0; i < size; i++) {
+ index_builder.UnsafeAppend(i);
+ }
+ ASSERT_OK_AND_ASSIGN(auto indices, index_builder.Finish());
+
+ UInt16Builder values_builder;
+ ASSERT_OK(values_builder.Reserve(size));
+ UInt16Builder values_builder_two;
+ ASSERT_OK(values_builder_two.Reserve(size));
+ for (auto i = 0; i < size; i++) {
+ values_builder.UnsafeAppend(i);
+ values_builder_two.UnsafeAppend(i + size);
+ }
+ ASSERT_OK_AND_ASSIGN(auto dictionary_one, values_builder.Finish());
+ ASSERT_OK_AND_ASSIGN(auto dictionary_two, values_builder_two.Finish());
+
+ auto dict_one = std::make_shared<DictionaryArray>(dict_type, indices, dictionary_one);
+ auto dict_two = std::make_shared<DictionaryArray>(dict_type, indices, dictionary_two);
+ ASSERT_RAISES(Invalid, Concatenate({dict_one, dict_two}).status());
+
+ auto bigger_dict_type = dictionary(uint16(), uint16());
+
+ auto bigger_one =
+ std::make_shared<DictionaryArray>(bigger_dict_type, dictionary_one, dictionary_one);
+ auto bigger_two =
+ std::make_shared<DictionaryArray>(bigger_dict_type, dictionary_one, dictionary_two);
+ ASSERT_OK_AND_ASSIGN(auto combined, Concatenate({bigger_one, bigger_two}));
+ ASSERT_EQ(size * 2, combined->length());
+}
+
+TEST_F(ConcatenateTest, DictionaryTypeNullSlots) {
+ // Regression test for ARROW-13639
+ auto dict_type = dictionary(uint32(), utf8());
+ auto dict_one = DictArrayFromJSON(dict_type, "[null, null, null, null]", "[]");
+ auto dict_two =
+ DictArrayFromJSON(dict_type, "[null, null, null, null, 0, 1]", R"(["a", "b"])");
+ auto expected = DictArrayFromJSON(
+ dict_type, "[null, null, null, null, null, null, null, null, 0, 1]",
+ R"(["a", "b"])");
+ ASSERT_OK_AND_ASSIGN(auto concat_actual, Concatenate({dict_one, dict_two}));
+ ASSERT_OK(concat_actual->ValidateFull());
+ TestInitialized(*concat_actual);
+ AssertArraysEqual(*expected, *concat_actual);
+}
+
+TEST_F(ConcatenateTest, DISABLED_UnionType) {
+ // sparse mode
+ Check([this](int32_t size, double null_probability, std::shared_ptr<Array>* out) {
+ auto foo = this->GeneratePrimitive<Int8Type>(size, null_probability);
+ auto bar = this->GeneratePrimitive<DoubleType>(size, null_probability);
+ auto baz = this->GeneratePrimitive<BooleanType>(size, null_probability);
+ auto type_ids = rng_.Numeric<Int8Type>(size, 0, 2, null_probability);
+ ASSERT_OK_AND_ASSIGN(*out, SparseUnionArray::Make(*type_ids, {foo, bar, baz}));
+ });
+ // dense mode
+ Check([this](int32_t size, double null_probability, std::shared_ptr<Array>* out) {
+ auto foo = this->GeneratePrimitive<Int8Type>(size, null_probability);
+ auto bar = this->GeneratePrimitive<DoubleType>(size, null_probability);
+ auto baz = this->GeneratePrimitive<BooleanType>(size, null_probability);
+ auto type_ids = rng_.Numeric<Int8Type>(size, 0, 2, null_probability);
+ auto value_offsets = rng_.Numeric<Int32Type>(size, 0, size, 0);
+ ASSERT_OK_AND_ASSIGN(
+ *out, DenseUnionArray::Make(*type_ids, *value_offsets, {foo, bar, baz}));
+ });
+}
+
+TEST_F(ConcatenateTest, OffsetOverflow) {
+ auto fake_long = ArrayFromJSON(utf8(), "[\"\"]");
+ fake_long->data()->GetMutableValues<int32_t>(1)[1] =
+ std::numeric_limits<int32_t>::max();
+ std::shared_ptr<Array> concatenated;
+ // XX since the data fake_long claims to own isn't there, this will segfault if
+ // Concatenate doesn't detect overflow and raise an error.
+ ASSERT_RAISES(Invalid, Concatenate({fake_long, fake_long}).status());
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/data.cc b/src/arrow/cpp/src/arrow/array/data.cc
new file mode 100644
index 000000000..5a2144739
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/data.cc
@@ -0,0 +1,331 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/data.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+
+using internal::CountSetBits;
+
+static inline void AdjustNonNullable(Type::type type_id, int64_t length,
+ std::vector<std::shared_ptr<Buffer>>* buffers,
+ int64_t* null_count) {
+ if (type_id == Type::NA) {
+ *null_count = length;
+ (*buffers)[0] = nullptr;
+ } else if (internal::HasValidityBitmap(type_id)) {
+ if (*null_count == 0) {
+ // In case there are no nulls, don't keep an allocated null bitmap around
+ (*buffers)[0] = nullptr;
+ } else if (*null_count == kUnknownNullCount && buffers->at(0) == nullptr) {
+ // Conversely, if no null bitmap is provided, set the null count to 0
+ *null_count = 0;
+ }
+ } else {
+ *null_count = 0;
+ }
+}
+
+std::shared_ptr<ArrayData> ArrayData::Make(std::shared_ptr<DataType> type, int64_t length,
+ std::vector<std::shared_ptr<Buffer>> buffers,
+ int64_t null_count, int64_t offset) {
+ AdjustNonNullable(type->id(), length, &buffers, &null_count);
+ return std::make_shared<ArrayData>(std::move(type), length, std::move(buffers),
+ null_count, offset);
+}
+
+std::shared_ptr<ArrayData> ArrayData::Make(
+ std::shared_ptr<DataType> type, int64_t length,
+ std::vector<std::shared_ptr<Buffer>> buffers,
+ std::vector<std::shared_ptr<ArrayData>> child_data, int64_t null_count,
+ int64_t offset) {
+ AdjustNonNullable(type->id(), length, &buffers, &null_count);
+ return std::make_shared<ArrayData>(std::move(type), length, std::move(buffers),
+ std::move(child_data), null_count, offset);
+}
+
+std::shared_ptr<ArrayData> ArrayData::Make(
+ std::shared_ptr<DataType> type, int64_t length,
+ std::vector<std::shared_ptr<Buffer>> buffers,
+ std::vector<std::shared_ptr<ArrayData>> child_data,
+ std::shared_ptr<ArrayData> dictionary, int64_t null_count, int64_t offset) {
+ AdjustNonNullable(type->id(), length, &buffers, &null_count);
+ auto data = std::make_shared<ArrayData>(std::move(type), length, std::move(buffers),
+ std::move(child_data), null_count, offset);
+ data->dictionary = std::move(dictionary);
+ return data;
+}
+
+std::shared_ptr<ArrayData> ArrayData::Make(std::shared_ptr<DataType> type, int64_t length,
+ int64_t null_count, int64_t offset) {
+ return std::make_shared<ArrayData>(std::move(type), length, null_count, offset);
+}
+
+std::shared_ptr<ArrayData> ArrayData::Slice(int64_t off, int64_t len) const {
+ ARROW_CHECK_LE(off, length) << "Slice offset greater than array length";
+ len = std::min(length - off, len);
+ off += offset;
+
+ auto copy = this->Copy();
+ copy->length = len;
+ copy->offset = off;
+ if (null_count == length) {
+ copy->null_count = len;
+ } else if (off == offset && len == length) { // A copy of current.
+ copy->null_count = null_count.load();
+ } else {
+ copy->null_count = null_count != 0 ? kUnknownNullCount : 0;
+ }
+ return copy;
+}
+
+Result<std::shared_ptr<ArrayData>> ArrayData::SliceSafe(int64_t off, int64_t len) const {
+ RETURN_NOT_OK(internal::CheckSliceParams(length, off, len, "array"));
+ return Slice(off, len);
+}
+
+int64_t ArrayData::GetNullCount() const {
+ int64_t precomputed = this->null_count.load();
+ if (ARROW_PREDICT_FALSE(precomputed == kUnknownNullCount)) {
+ if (this->buffers[0]) {
+ precomputed = this->length -
+ CountSetBits(this->buffers[0]->data(), this->offset, this->length);
+ } else {
+ precomputed = 0;
+ }
+ this->null_count.store(precomputed);
+ }
+ return precomputed;
+}
+
+// ----------------------------------------------------------------------
+// Implement ArrayData::View
+
+namespace {
+
+void AccumulateLayouts(const std::shared_ptr<DataType>& type,
+ std::vector<DataTypeLayout>* layouts) {
+ layouts->push_back(type->layout());
+ for (const auto& child : type->fields()) {
+ AccumulateLayouts(child->type(), layouts);
+ }
+}
+
+void AccumulateArrayData(const std::shared_ptr<ArrayData>& data,
+ std::vector<std::shared_ptr<ArrayData>>* out) {
+ out->push_back(data);
+ for (const auto& child : data->child_data) {
+ AccumulateArrayData(child, out);
+ }
+}
+
+struct ViewDataImpl {
+ std::shared_ptr<DataType> root_in_type;
+ std::shared_ptr<DataType> root_out_type;
+ std::vector<DataTypeLayout> in_layouts;
+ std::vector<std::shared_ptr<ArrayData>> in_data;
+ int64_t in_data_length;
+ size_t in_layout_idx = 0;
+ size_t in_buffer_idx = 0;
+ bool input_exhausted = false;
+
+ Status InvalidView(const std::string& msg) {
+ return Status::Invalid("Can't view array of type ", root_in_type->ToString(), " as ",
+ root_out_type->ToString(), ": ", msg);
+ }
+
+ void AdjustInputPointer() {
+ if (input_exhausted) {
+ return;
+ }
+ while (true) {
+ // Skip exhausted layout (might be empty layout)
+ while (in_buffer_idx >= in_layouts[in_layout_idx].buffers.size()) {
+ in_buffer_idx = 0;
+ ++in_layout_idx;
+ if (in_layout_idx >= in_layouts.size()) {
+ input_exhausted = true;
+ return;
+ }
+ }
+ const auto& in_spec = in_layouts[in_layout_idx].buffers[in_buffer_idx];
+ if (in_spec.kind != DataTypeLayout::ALWAYS_NULL) {
+ return;
+ }
+ // Skip always-null input buffers
+ // (e.g. buffer 0 of a null type or buffer 2 of a sparse union)
+ ++in_buffer_idx;
+ }
+ }
+
+ Status CheckInputAvailable() {
+ if (input_exhausted) {
+ return InvalidView("not enough buffers for view type");
+ }
+ return Status::OK();
+ }
+
+ Status CheckInputExhausted() {
+ if (!input_exhausted) {
+ return InvalidView("too many buffers for view type");
+ }
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<ArrayData>> GetDictionaryView(const DataType& out_type) {
+ if (in_data[in_layout_idx]->type->id() != Type::DICTIONARY) {
+ return InvalidView("Cannot get view as dictionary type");
+ }
+ const auto& dict_out_type = static_cast<const DictionaryType&>(out_type);
+ return internal::GetArrayView(in_data[in_layout_idx]->dictionary,
+ dict_out_type.value_type());
+ }
+
+ Status MakeDataView(const std::shared_ptr<Field>& out_field,
+ std::shared_ptr<ArrayData>* out) {
+ const auto& out_type = out_field->type();
+ const auto out_layout = out_type->layout();
+
+ AdjustInputPointer();
+ int64_t out_length = in_data_length;
+ int64_t out_offset = 0;
+ int64_t out_null_count;
+
+ std::shared_ptr<ArrayData> dictionary;
+ if (out_type->id() == Type::DICTIONARY) {
+ ARROW_ASSIGN_OR_RAISE(dictionary, GetDictionaryView(*out_type));
+ }
+
+ // No type has a purely empty layout
+ DCHECK_GT(out_layout.buffers.size(), 0);
+
+ std::vector<std::shared_ptr<Buffer>> out_buffers;
+
+ // Process null bitmap
+ if (in_buffer_idx == 0 && out_layout.buffers[0].kind == DataTypeLayout::BITMAP) {
+ // Copy input null bitmap
+ RETURN_NOT_OK(CheckInputAvailable());
+ const auto& in_data_item = in_data[in_layout_idx];
+ if (!out_field->nullable() && in_data_item->GetNullCount() != 0) {
+ return InvalidView("nulls in input cannot be viewed as non-nullable");
+ }
+ DCHECK_GT(in_data_item->buffers.size(), in_buffer_idx);
+ out_buffers.push_back(in_data_item->buffers[in_buffer_idx]);
+ out_length = in_data_item->length;
+ out_offset = in_data_item->offset;
+ out_null_count = in_data_item->null_count;
+ ++in_buffer_idx;
+ AdjustInputPointer();
+ } else {
+ // No null bitmap in input, append no-nulls bitmap
+ out_buffers.push_back(nullptr);
+ if (out_type->id() == Type::NA) {
+ out_null_count = out_length;
+ } else {
+ out_null_count = 0;
+ }
+ }
+
+ // Process other buffers in output layout
+ for (size_t out_buffer_idx = 1; out_buffer_idx < out_layout.buffers.size();
+ ++out_buffer_idx) {
+ const auto& out_spec = out_layout.buffers[out_buffer_idx];
+ // If always-null buffer is expected, just construct it
+ if (out_spec.kind == DataTypeLayout::ALWAYS_NULL) {
+ out_buffers.push_back(nullptr);
+ continue;
+ }
+
+ // If input buffer is null bitmap, try to ignore it
+ while (in_buffer_idx == 0) {
+ RETURN_NOT_OK(CheckInputAvailable());
+ if (in_data[in_layout_idx]->GetNullCount() != 0) {
+ return InvalidView("cannot represent nested nulls");
+ }
+ ++in_buffer_idx;
+ AdjustInputPointer();
+ }
+
+ RETURN_NOT_OK(CheckInputAvailable());
+ const auto& in_spec = in_layouts[in_layout_idx].buffers[in_buffer_idx];
+ if (out_spec != in_spec) {
+ return InvalidView("incompatible layouts");
+ }
+ // Copy input buffer
+ const auto& in_data_item = in_data[in_layout_idx];
+ out_length = in_data_item->length;
+ out_offset = in_data_item->offset;
+ DCHECK_GT(in_data_item->buffers.size(), in_buffer_idx);
+ out_buffers.push_back(in_data_item->buffers[in_buffer_idx]);
+ ++in_buffer_idx;
+ AdjustInputPointer();
+ }
+
+ std::shared_ptr<ArrayData> out_data = ArrayData::Make(
+ out_type, out_length, std::move(out_buffers), out_null_count, out_offset);
+ out_data->dictionary = dictionary;
+
+ // Process children recursively, depth-first
+ for (const auto& child_field : out_type->fields()) {
+ std::shared_ptr<ArrayData> child_data;
+ RETURN_NOT_OK(MakeDataView(child_field, &child_data));
+ out_data->child_data.push_back(std::move(child_data));
+ }
+ *out = std::move(out_data);
+ return Status::OK();
+ }
+};
+
+} // namespace
+
+namespace internal {
+
+Result<std::shared_ptr<ArrayData>> GetArrayView(
+ const std::shared_ptr<ArrayData>& data, const std::shared_ptr<DataType>& out_type) {
+ ViewDataImpl impl;
+ impl.root_in_type = data->type;
+ impl.root_out_type = out_type;
+ AccumulateLayouts(impl.root_in_type, &impl.in_layouts);
+ AccumulateArrayData(data, &impl.in_data);
+ impl.in_data_length = data->length;
+
+ std::shared_ptr<ArrayData> out_data;
+ // Dummy field for output type
+ auto out_field = field("", out_type);
+ RETURN_NOT_OK(impl.MakeDataView(out_field, &out_data));
+ RETURN_NOT_OK(impl.CheckInputExhausted());
+ return out_data;
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/data.h b/src/arrow/cpp/src/arrow/array/data.h
new file mode 100644
index 000000000..418d09def
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/data.h
@@ -0,0 +1,258 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <atomic> // IWYU pragma: export
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/result.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+// When slicing, we do not know the null count of the sliced range without
+// doing some computation. To avoid doing this eagerly, we set the null count
+// to -1 (any negative number will do). When Array::null_count is called the
+// first time, the null count will be computed. See ARROW-33
+constexpr int64_t kUnknownNullCount = -1;
+
+// ----------------------------------------------------------------------
+// Generic array data container
+
+/// \class ArrayData
+/// \brief Mutable container for generic Arrow array data
+///
+/// This data structure is a self-contained representation of the memory and
+/// metadata inside an Arrow array data structure (called vectors in Java). The
+/// classes arrow::Array and its subclasses provide strongly-typed accessors
+/// with support for the visitor pattern and other affordances.
+///
+/// This class is designed for easy internal data manipulation, analytical data
+/// processing, and data transport to and from IPC messages. For example, we
+/// could cast from int64 to float64 like so:
+///
+/// Int64Array arr = GetMyData();
+/// auto new_data = arr.data()->Copy();
+/// new_data->type = arrow::float64();
+/// DoubleArray double_arr(new_data);
+///
+/// This object is also useful in an analytics setting where memory may be
+/// reused. For example, if we had a group of operations all returning doubles,
+/// say:
+///
+/// Log(Sqrt(Expr(arr)))
+///
+/// Then the low-level implementations of each of these functions could have
+/// the signatures
+///
+/// void Log(const ArrayData& values, ArrayData* out);
+///
+/// As another example a function may consume one or more memory buffers in an
+/// input array and replace them with newly-allocated data, changing the output
+/// data type as well.
+struct ARROW_EXPORT ArrayData {
+ ArrayData() = default;
+
+ ArrayData(std::shared_ptr<DataType> type, int64_t length,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0)
+ : type(std::move(type)), length(length), null_count(null_count), offset(offset) {}
+
+ ArrayData(std::shared_ptr<DataType> type, int64_t length,
+ std::vector<std::shared_ptr<Buffer>> buffers,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0)
+ : ArrayData(std::move(type), length, null_count, offset) {
+ this->buffers = std::move(buffers);
+ }
+
+ ArrayData(std::shared_ptr<DataType> type, int64_t length,
+ std::vector<std::shared_ptr<Buffer>> buffers,
+ std::vector<std::shared_ptr<ArrayData>> child_data,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0)
+ : ArrayData(std::move(type), length, null_count, offset) {
+ this->buffers = std::move(buffers);
+ this->child_data = std::move(child_data);
+ }
+
+ static std::shared_ptr<ArrayData> Make(std::shared_ptr<DataType> type, int64_t length,
+ std::vector<std::shared_ptr<Buffer>> buffers,
+ int64_t null_count = kUnknownNullCount,
+ int64_t offset = 0);
+
+ static std::shared_ptr<ArrayData> Make(
+ std::shared_ptr<DataType> type, int64_t length,
+ std::vector<std::shared_ptr<Buffer>> buffers,
+ std::vector<std::shared_ptr<ArrayData>> child_data,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ static std::shared_ptr<ArrayData> Make(
+ std::shared_ptr<DataType> type, int64_t length,
+ std::vector<std::shared_ptr<Buffer>> buffers,
+ std::vector<std::shared_ptr<ArrayData>> child_data,
+ std::shared_ptr<ArrayData> dictionary, int64_t null_count = kUnknownNullCount,
+ int64_t offset = 0);
+
+ static std::shared_ptr<ArrayData> Make(std::shared_ptr<DataType> type, int64_t length,
+ int64_t null_count = kUnknownNullCount,
+ int64_t offset = 0);
+
+ // Move constructor
+ ArrayData(ArrayData&& other) noexcept
+ : type(std::move(other.type)),
+ length(other.length),
+ offset(other.offset),
+ buffers(std::move(other.buffers)),
+ child_data(std::move(other.child_data)),
+ dictionary(std::move(other.dictionary)) {
+ SetNullCount(other.null_count);
+ }
+
+ // Copy constructor
+ ArrayData(const ArrayData& other) noexcept
+ : type(other.type),
+ length(other.length),
+ offset(other.offset),
+ buffers(other.buffers),
+ child_data(other.child_data),
+ dictionary(other.dictionary) {
+ SetNullCount(other.null_count);
+ }
+
+ // Move assignment
+ ArrayData& operator=(ArrayData&& other) {
+ type = std::move(other.type);
+ length = other.length;
+ SetNullCount(other.null_count);
+ offset = other.offset;
+ buffers = std::move(other.buffers);
+ child_data = std::move(other.child_data);
+ dictionary = std::move(other.dictionary);
+ return *this;
+ }
+
+ // Copy assignment
+ ArrayData& operator=(const ArrayData& other) {
+ type = other.type;
+ length = other.length;
+ SetNullCount(other.null_count);
+ offset = other.offset;
+ buffers = other.buffers;
+ child_data = other.child_data;
+ dictionary = other.dictionary;
+ return *this;
+ }
+
+ std::shared_ptr<ArrayData> Copy() const { return std::make_shared<ArrayData>(*this); }
+
+ // Access a buffer's data as a typed C pointer
+ template <typename T>
+ inline const T* GetValues(int i, int64_t absolute_offset) const {
+ if (buffers[i]) {
+ return reinterpret_cast<const T*>(buffers[i]->data()) + absolute_offset;
+ } else {
+ return NULLPTR;
+ }
+ }
+
+ template <typename T>
+ inline const T* GetValues(int i) const {
+ return GetValues<T>(i, offset);
+ }
+
+ // Like GetValues, but returns NULLPTR instead of aborting if the underlying
+ // buffer is not a CPU buffer.
+ template <typename T>
+ inline const T* GetValuesSafe(int i, int64_t absolute_offset) const {
+ if (buffers[i] && buffers[i]->is_cpu()) {
+ return reinterpret_cast<const T*>(buffers[i]->data()) + absolute_offset;
+ } else {
+ return NULLPTR;
+ }
+ }
+
+ template <typename T>
+ inline const T* GetValuesSafe(int i) const {
+ return GetValuesSafe<T>(i, offset);
+ }
+
+ // Access a buffer's data as a typed C pointer
+ template <typename T>
+ inline T* GetMutableValues(int i, int64_t absolute_offset) {
+ if (buffers[i]) {
+ return reinterpret_cast<T*>(buffers[i]->mutable_data()) + absolute_offset;
+ } else {
+ return NULLPTR;
+ }
+ }
+
+ template <typename T>
+ inline T* GetMutableValues(int i) {
+ return GetMutableValues<T>(i, offset);
+ }
+
+ /// \brief Construct a zero-copy slice of the data with the given offset and length
+ std::shared_ptr<ArrayData> Slice(int64_t offset, int64_t length) const;
+
+ /// \brief Input-checking variant of Slice
+ ///
+ /// An Invalid Status is returned if the requested slice falls out of bounds.
+ /// Note that unlike Slice, `length` isn't clamped to the available buffer size.
+ Result<std::shared_ptr<ArrayData>> SliceSafe(int64_t offset, int64_t length) const;
+
+ void SetNullCount(int64_t v) { null_count.store(v); }
+
+ /// \brief Return null count, or compute and set it if it's not known
+ int64_t GetNullCount() const;
+
+ bool MayHaveNulls() const {
+ // If an ArrayData is slightly malformed it may have kUnknownNullCount set
+ // but no buffer
+ return null_count.load() != 0 && buffers[0] != NULLPTR;
+ }
+
+ std::shared_ptr<DataType> type;
+ int64_t length = 0;
+ mutable std::atomic<int64_t> null_count{0};
+ // The logical start point into the physical buffers (in values, not bytes).
+ // Note that, for child data, this must be *added* to the child data's own offset.
+ int64_t offset = 0;
+ std::vector<std::shared_ptr<Buffer>> buffers;
+ std::vector<std::shared_ptr<ArrayData>> child_data;
+
+ // The dictionary for this Array, if any. Only used for dictionary type
+ std::shared_ptr<ArrayData> dictionary;
+};
+
+namespace internal {
+
+/// Construct a zero-copy view of this ArrayData with the given type.
+///
+/// This method checks if the types are layout-compatible.
+/// Nested types are traversed in depth-first order. Data buffers must have
+/// the same item sizes, even though the logical types may be different.
+/// An error is returned if the types are not layout-compatible.
+ARROW_EXPORT
+Result<std::shared_ptr<ArrayData>> GetArrayView(const std::shared_ptr<ArrayData>& data,
+ const std::shared_ptr<DataType>& type);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/dict_internal.h b/src/arrow/cpp/src/arrow/array/dict_internal.h
new file mode 100644
index 000000000..aa027ac22
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/dict_internal.h
@@ -0,0 +1,193 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/array/builder_dict.h"
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/hashing.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string_view.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+namespace internal {
+
+template <typename T, typename Enable = void>
+struct DictionaryTraits {
+ using MemoTableType = void;
+};
+
+} // namespace internal
+
+template <typename T, typename Out = void>
+using enable_if_memoize = enable_if_t<
+ !std::is_same<typename internal::DictionaryTraits<T>::MemoTableType, void>::value,
+ Out>;
+
+template <typename T, typename Out = void>
+using enable_if_no_memoize = enable_if_t<
+ std::is_same<typename internal::DictionaryTraits<T>::MemoTableType, void>::value,
+ Out>;
+
+namespace internal {
+
+template <>
+struct DictionaryTraits<BooleanType> {
+ using T = BooleanType;
+ using MemoTableType = typename HashTraits<T>::MemoTableType;
+
+ static Status GetDictionaryArrayData(MemoryPool* pool,
+ const std::shared_ptr<DataType>& type,
+ const MemoTableType& memo_table,
+ int64_t start_offset,
+ std::shared_ptr<ArrayData>* out) {
+ if (start_offset < 0) {
+ return Status::Invalid("invalid start_offset ", start_offset);
+ }
+
+ BooleanBuilder builder(pool);
+ const auto& bool_values = memo_table.values();
+ const auto null_index = memo_table.GetNull();
+
+ // Will iterate up to 3 times.
+ for (int64_t i = start_offset; i < memo_table.size(); i++) {
+ RETURN_NOT_OK(i == null_index ? builder.AppendNull()
+ : builder.Append(bool_values[i]));
+ }
+
+ return builder.FinishInternal(out);
+ }
+}; // namespace internal
+
+template <typename T>
+struct DictionaryTraits<T, enable_if_has_c_type<T>> {
+ using c_type = typename T::c_type;
+ using MemoTableType = typename HashTraits<T>::MemoTableType;
+
+ static Status GetDictionaryArrayData(MemoryPool* pool,
+ const std::shared_ptr<DataType>& type,
+ const MemoTableType& memo_table,
+ int64_t start_offset,
+ std::shared_ptr<ArrayData>* out) {
+ auto dict_length = static_cast<int64_t>(memo_table.size()) - start_offset;
+ // This makes a copy, but we assume a dictionary array is usually small
+ // compared to the size of the dictionary-using array.
+ // (also, copying the dictionary values is cheap compared to the cost
+ // of building the memo table)
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr<Buffer> dict_buffer,
+ AllocateBuffer(TypeTraits<T>::bytes_required(dict_length), pool));
+ memo_table.CopyValues(static_cast<int32_t>(start_offset),
+ reinterpret_cast<c_type*>(dict_buffer->mutable_data()));
+
+ int64_t null_count = 0;
+ std::shared_ptr<Buffer> null_bitmap = nullptr;
+ RETURN_NOT_OK(
+ ComputeNullBitmap(pool, memo_table, start_offset, &null_count, &null_bitmap));
+
+ *out = ArrayData::Make(type, dict_length, {null_bitmap, dict_buffer}, null_count);
+ return Status::OK();
+ }
+};
+
+template <typename T>
+struct DictionaryTraits<T, enable_if_base_binary<T>> {
+ using MemoTableType = typename HashTraits<T>::MemoTableType;
+
+ static Status GetDictionaryArrayData(MemoryPool* pool,
+ const std::shared_ptr<DataType>& type,
+ const MemoTableType& memo_table,
+ int64_t start_offset,
+ std::shared_ptr<ArrayData>* out) {
+ using offset_type = typename T::offset_type;
+
+ // Create the offsets buffer
+ auto dict_length = static_cast<int64_t>(memo_table.size() - start_offset);
+ ARROW_ASSIGN_OR_RAISE(auto dict_offsets,
+ AllocateBuffer(sizeof(offset_type) * (dict_length + 1), pool));
+ auto raw_offsets = reinterpret_cast<offset_type*>(dict_offsets->mutable_data());
+ memo_table.CopyOffsets(static_cast<int32_t>(start_offset), raw_offsets);
+
+ // Create the data buffer
+ auto values_size = memo_table.values_size();
+ ARROW_ASSIGN_OR_RAISE(auto dict_data, AllocateBuffer(values_size, pool));
+ if (values_size > 0) {
+ memo_table.CopyValues(static_cast<int32_t>(start_offset), dict_data->size(),
+ dict_data->mutable_data());
+ }
+
+ int64_t null_count = 0;
+ std::shared_ptr<Buffer> null_bitmap = nullptr;
+ RETURN_NOT_OK(
+ ComputeNullBitmap(pool, memo_table, start_offset, &null_count, &null_bitmap));
+
+ *out = ArrayData::Make(type, dict_length,
+ {null_bitmap, std::move(dict_offsets), std::move(dict_data)},
+ null_count);
+
+ return Status::OK();
+ }
+};
+
+template <typename T>
+struct DictionaryTraits<T, enable_if_fixed_size_binary<T>> {
+ using MemoTableType = typename HashTraits<T>::MemoTableType;
+
+ static Status GetDictionaryArrayData(MemoryPool* pool,
+ const std::shared_ptr<DataType>& type,
+ const MemoTableType& memo_table,
+ int64_t start_offset,
+ std::shared_ptr<ArrayData>* out) {
+ const T& concrete_type = internal::checked_cast<const T&>(*type);
+
+ // Create the data buffer
+ auto dict_length = static_cast<int64_t>(memo_table.size() - start_offset);
+ auto width_length = concrete_type.byte_width();
+ auto data_length = dict_length * width_length;
+ ARROW_ASSIGN_OR_RAISE(auto dict_data, AllocateBuffer(data_length, pool));
+ auto data = dict_data->mutable_data();
+
+ memo_table.CopyFixedWidthValues(static_cast<int32_t>(start_offset), width_length,
+ data_length, data);
+
+ int64_t null_count = 0;
+ std::shared_ptr<Buffer> null_bitmap = nullptr;
+ RETURN_NOT_OK(
+ ComputeNullBitmap(pool, memo_table, start_offset, &null_count, &null_bitmap));
+
+ *out = ArrayData::Make(type, dict_length, {null_bitmap, std::move(dict_data)},
+ null_count);
+ return Status::OK();
+ }
+};
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/diff.cc b/src/arrow/cpp/src/arrow/array/diff.cc
new file mode 100644
index 000000000..0a50de0f1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/diff.cc
@@ -0,0 +1,794 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/diff.h"
+
+#include <algorithm>
+#include <chrono>
+#include <functional>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/array/array_primitive.h"
+#include "arrow/buffer.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/extension_type.h"
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/range.h"
+#include "arrow/util/string.h"
+#include "arrow/util/string_view.h"
+#include "arrow/vendored/datetime.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+using internal::MakeLazyRange;
+
+template <typename ArrayType>
+auto GetView(const ArrayType& array, int64_t index) -> decltype(array.GetView(index)) {
+ return array.GetView(index);
+}
+
+struct Slice {
+ const Array* array_;
+ int64_t offset_, length_;
+
+ bool operator==(const Slice& other) const {
+ return length_ == other.length_ &&
+ array_->RangeEquals(offset_, offset_ + length_, other.offset_, *other.array_);
+ }
+ bool operator!=(const Slice& other) const { return !(*this == other); }
+};
+
+template <typename ArrayType, typename T = typename ArrayType::TypeClass,
+ typename = enable_if_list_like<T>>
+static Slice GetView(const ArrayType& array, int64_t index) {
+ return Slice{array.values().get(), array.value_offset(index),
+ array.value_length(index)};
+}
+
+struct UnitSlice {
+ const Array* array_;
+ int64_t offset_;
+
+ bool operator==(const UnitSlice& other) const {
+ return array_->RangeEquals(offset_, offset_ + 1, other.offset_, *other.array_);
+ }
+ bool operator!=(const UnitSlice& other) const { return !(*this == other); }
+};
+
+// FIXME(bkietz) this is inefficient;
+// StructArray's fields can be diffed independently then merged
+static UnitSlice GetView(const StructArray& array, int64_t index) {
+ return UnitSlice{&array, index};
+}
+
+static UnitSlice GetView(const UnionArray& array, int64_t index) {
+ return UnitSlice{&array, index};
+}
+
+using ValueComparator = std::function<bool(const Array&, int64_t, const Array&, int64_t)>;
+
+struct ValueComparatorVisitor {
+ template <typename T>
+ Status Visit(const T&) {
+ using ArrayType = typename TypeTraits<T>::ArrayType;
+ out = [](const Array& base, int64_t base_index, const Array& target,
+ int64_t target_index) {
+ return (GetView(checked_cast<const ArrayType&>(base), base_index) ==
+ GetView(checked_cast<const ArrayType&>(target), target_index));
+ };
+ return Status::OK();
+ }
+
+ Status Visit(const NullType&) { return Status::NotImplemented("null type"); }
+
+ Status Visit(const ExtensionType&) { return Status::NotImplemented("extension type"); }
+
+ Status Visit(const DictionaryType&) {
+ return Status::NotImplemented("dictionary type");
+ }
+
+ ValueComparator Create(const DataType& type) {
+ DCHECK_OK(VisitTypeInline(type, this));
+ return out;
+ }
+
+ ValueComparator out;
+};
+
+ValueComparator GetValueComparator(const DataType& type) {
+ ValueComparatorVisitor type_visitor;
+ return type_visitor.Create(type);
+}
+
+// represents an intermediate state in the comparison of two arrays
+struct EditPoint {
+ int64_t base, target;
+ bool operator==(EditPoint other) const {
+ return base == other.base && target == other.target;
+ }
+};
+
+/// A generic sequence difference algorithm, based on
+///
+/// E. W. Myers, "An O(ND) difference algorithm and its variations,"
+/// Algorithmica, vol. 1, no. 1-4, pp. 251–266, 1986.
+///
+/// To summarize, an edit script is computed by maintaining the furthest set of EditPoints
+/// which are reachable in a given number of edits D. This is used to compute the furthest
+/// set reachable with D+1 edits, and the process continues inductively until a complete
+/// edit script is discovered.
+///
+/// From each edit point a single deletion and insertion is made then as many shared
+/// elements as possible are skipped, recording only the endpoint of the run. This
+/// representation is minimal in the common case where the sequences differ only slightly,
+/// since most of the elements are shared between base and target and are represented
+/// implicitly.
+class QuadraticSpaceMyersDiff {
+ public:
+ QuadraticSpaceMyersDiff(const Array& base, const Array& target, MemoryPool* pool)
+ : base_(base),
+ target_(target),
+ pool_(pool),
+ value_comparator_(GetValueComparator(*base.type())),
+ base_begin_(0),
+ base_end_(base.length()),
+ target_begin_(0),
+ target_end_(target.length()),
+ endpoint_base_({ExtendFrom({base_begin_, target_begin_}).base}),
+ insert_({true}) {
+ if ((base_end_ - base_begin_ == target_end_ - target_begin_) &&
+ endpoint_base_[0] == base_end_) {
+ // trivial case: base == target
+ finish_index_ = 0;
+ }
+ }
+
+ bool ValuesEqual(int64_t base_index, int64_t target_index) const {
+ bool base_null = base_.IsNull(base_index);
+ bool target_null = target_.IsNull(target_index);
+ if (base_null || target_null) {
+ // If only one is null, then this is false, otherwise true
+ return base_null && target_null;
+ }
+ return value_comparator_(base_, base_index, target_, target_index);
+ }
+
+ // increment the position within base (the element pointed to was deleted)
+ // then extend maximally
+ EditPoint DeleteOne(EditPoint p) const {
+ if (p.base != base_end_) {
+ ++p.base;
+ }
+ return ExtendFrom(p);
+ }
+
+ // increment the position within target (the element pointed to was inserted)
+ // then extend maximally
+ EditPoint InsertOne(EditPoint p) const {
+ if (p.target != target_end_) {
+ ++p.target;
+ }
+ return ExtendFrom(p);
+ }
+
+ // increment the position within base and target (the elements skipped in this way were
+ // present in both sequences)
+ EditPoint ExtendFrom(EditPoint p) const {
+ for (; p.base != base_end_ && p.target != target_end_; ++p.base, ++p.target) {
+ if (!ValuesEqual(p.base, p.target)) {
+ break;
+ }
+ }
+ return p;
+ }
+
+ // beginning of a range for storing per-edit state in endpoint_base_ and insert_
+ int64_t StorageOffset(int64_t edit_count) const {
+ return edit_count * (edit_count + 1) / 2;
+ }
+
+ // given edit_count and index, augment endpoint_base_[index] with the corresponding
+ // position in target (which is only implicitly represented in edit_count, index)
+ EditPoint GetEditPoint(int64_t edit_count, int64_t index) const {
+ DCHECK_GE(index, StorageOffset(edit_count));
+ DCHECK_LT(index, StorageOffset(edit_count + 1));
+ auto insertions_minus_deletions =
+ 2 * (index - StorageOffset(edit_count)) - edit_count;
+ auto maximal_base = endpoint_base_[index];
+ auto maximal_target = std::min(
+ target_begin_ + ((maximal_base - base_begin_) + insertions_minus_deletions),
+ target_end_);
+ return {maximal_base, maximal_target};
+ }
+
+ void Next() {
+ ++edit_count_;
+ // base_begin_ is used as a dummy value here since Iterator may not be default
+ // constructible. The newly allocated range is completely overwritten below.
+ endpoint_base_.resize(StorageOffset(edit_count_ + 1), base_begin_);
+ insert_.resize(StorageOffset(edit_count_ + 1), false);
+
+ auto previous_offset = StorageOffset(edit_count_ - 1);
+ auto current_offset = StorageOffset(edit_count_);
+
+ // try deleting from base first
+ for (int64_t i = 0, i_out = 0; i < edit_count_; ++i, ++i_out) {
+ auto previous_endpoint = GetEditPoint(edit_count_ - 1, i + previous_offset);
+ endpoint_base_[i_out + current_offset] = DeleteOne(previous_endpoint).base;
+ }
+
+ // check if inserting from target could do better
+ for (int64_t i = 0, i_out = 1; i < edit_count_; ++i, ++i_out) {
+ // retrieve the previously computed best endpoint for (edit_count_, i_out)
+ // for comparison with the best endpoint achievable with an insertion
+ auto endpoint_after_deletion = GetEditPoint(edit_count_, i_out + current_offset);
+
+ auto previous_endpoint = GetEditPoint(edit_count_ - 1, i + previous_offset);
+ auto endpoint_after_insertion = InsertOne(previous_endpoint);
+
+ if (endpoint_after_insertion.base - endpoint_after_deletion.base >= 0) {
+ // insertion was more efficient; keep it and mark the insertion in insert_
+ insert_[i_out + current_offset] = true;
+ endpoint_base_[i_out + current_offset] = endpoint_after_insertion.base;
+ }
+ }
+
+ // check for completion
+ EditPoint finish = {base_end_, target_end_};
+ for (int64_t i_out = 0; i_out < edit_count_ + 1; ++i_out) {
+ if (GetEditPoint(edit_count_, i_out + current_offset) == finish) {
+ finish_index_ = i_out + current_offset;
+ return;
+ }
+ }
+ }
+
+ bool Done() { return finish_index_ != -1; }
+
+ Result<std::shared_ptr<StructArray>> GetEdits(MemoryPool* pool) {
+ DCHECK(Done());
+
+ int64_t length = edit_count_ + 1;
+ ARROW_ASSIGN_OR_RAISE(auto insert_buf, AllocateEmptyBitmap(length, pool));
+ ARROW_ASSIGN_OR_RAISE(auto run_length_buf,
+ AllocateBuffer(length * sizeof(int64_t), pool));
+ auto run_length = reinterpret_cast<int64_t*>(run_length_buf->mutable_data());
+
+ auto index = finish_index_;
+ auto endpoint = GetEditPoint(edit_count_, finish_index_);
+
+ for (int64_t i = edit_count_; i > 0; --i) {
+ bool insert = insert_[index];
+ BitUtil::SetBitTo(insert_buf->mutable_data(), i, insert);
+
+ auto insertions_minus_deletions =
+ (endpoint.base - base_begin_) - (endpoint.target - target_begin_);
+ if (insert) {
+ ++insertions_minus_deletions;
+ } else {
+ --insertions_minus_deletions;
+ }
+ index = (i - 1 - insertions_minus_deletions) / 2 + StorageOffset(i - 1);
+
+ // endpoint of previous edit
+ auto previous = GetEditPoint(i - 1, index);
+ run_length[i] = endpoint.base - previous.base - !insert;
+ DCHECK_GE(run_length[i], 0);
+
+ endpoint = previous;
+ }
+ BitUtil::SetBitTo(insert_buf->mutable_data(), 0, false);
+ run_length[0] = endpoint.base - base_begin_;
+
+ return StructArray::Make(
+ {std::make_shared<BooleanArray>(length, std::move(insert_buf)),
+ std::make_shared<Int64Array>(length, std::move(run_length_buf))},
+ {field("insert", boolean()), field("run_length", int64())});
+ }
+
+ Result<std::shared_ptr<StructArray>> Diff() {
+ while (!Done()) {
+ Next();
+ }
+ return GetEdits(pool_);
+ }
+
+ private:
+ const Array& base_;
+ const Array& target_;
+ MemoryPool* pool_;
+ ValueComparator value_comparator_;
+ int64_t finish_index_ = -1;
+ int64_t edit_count_ = 0;
+ int64_t base_begin_, base_end_;
+ int64_t target_begin_, target_end_;
+ // each element of endpoint_base_ is the furthest position in base reachable given an
+ // edit_count and (# insertions) - (# deletions). Each bit of insert_ records whether
+ // the corresponding furthest position was reached via an insertion or a deletion
+ // (followed by a run of shared elements). See StorageOffset for the
+ // layout of these vectors
+ std::vector<int64_t> endpoint_base_;
+ std::vector<bool> insert_;
+};
+
+Result<std::shared_ptr<StructArray>> NullDiff(const Array& base, const Array& target,
+ MemoryPool* pool) {
+ bool insert = base.length() < target.length();
+ auto run_length = std::min(base.length(), target.length());
+ auto edit_count = std::max(base.length(), target.length()) - run_length;
+
+ TypedBufferBuilder<bool> insert_builder(pool);
+ RETURN_NOT_OK(insert_builder.Resize(edit_count + 1));
+ insert_builder.UnsafeAppend(false);
+ TypedBufferBuilder<int64_t> run_length_builder(pool);
+ RETURN_NOT_OK(run_length_builder.Resize(edit_count + 1));
+ run_length_builder.UnsafeAppend(run_length);
+ if (edit_count > 0) {
+ insert_builder.UnsafeAppend(edit_count, insert);
+ run_length_builder.UnsafeAppend(edit_count, 0);
+ }
+
+ std::shared_ptr<Buffer> insert_buf, run_length_buf;
+ RETURN_NOT_OK(insert_builder.Finish(&insert_buf));
+ RETURN_NOT_OK(run_length_builder.Finish(&run_length_buf));
+
+ return StructArray::Make({std::make_shared<BooleanArray>(edit_count + 1, insert_buf),
+ std::make_shared<Int64Array>(edit_count + 1, run_length_buf)},
+ {field("insert", boolean()), field("run_length", int64())});
+}
+
+Result<std::shared_ptr<StructArray>> Diff(const Array& base, const Array& target,
+ MemoryPool* pool) {
+ if (!base.type()->Equals(target.type())) {
+ return Status::TypeError("only taking the diff of like-typed arrays is supported.");
+ }
+
+ if (base.type()->id() == Type::NA) {
+ return NullDiff(base, target, pool);
+ } else if (base.type()->id() == Type::EXTENSION) {
+ auto base_storage = checked_cast<const ExtensionArray&>(base).storage();
+ auto target_storage = checked_cast<const ExtensionArray&>(target).storage();
+ return Diff(*base_storage, *target_storage, pool);
+ } else if (base.type()->id() == Type::DICTIONARY) {
+ return Status::NotImplemented("diffing arrays of type ", *base.type());
+ } else {
+ return QuadraticSpaceMyersDiff(base, target, pool).Diff();
+ }
+}
+
+using Formatter = std::function<void(const Array&, int64_t index, std::ostream*)>;
+
+static Result<Formatter> MakeFormatter(const DataType& type);
+
+class MakeFormatterImpl {
+ public:
+ Result<Formatter> Make(const DataType& type) && {
+ RETURN_NOT_OK(VisitTypeInline(type, this));
+ return std::move(impl_);
+ }
+
+ private:
+ template <typename VISITOR>
+ friend Status VisitTypeInline(const DataType&, VISITOR*);
+
+ // factory implementation
+ Status Visit(const BooleanType&) {
+ impl_ = [](const Array& array, int64_t index, std::ostream* os) {
+ *os << (checked_cast<const BooleanArray&>(array).Value(index) ? "true" : "false");
+ };
+ return Status::OK();
+ }
+
+ // format Numerics with std::ostream defaults
+ template <typename T>
+ enable_if_number<T, Status> Visit(const T&) {
+ impl_ = [](const Array& array, int64_t index, std::ostream* os) {
+ const auto& numeric = checked_cast<const NumericArray<T>&>(array);
+ if (sizeof(decltype(numeric.Value(index))) == sizeof(char)) {
+ // override std::ostream defaults for /(u|)int8_t/ since they are
+ // formatted as potentially unprintable/tty borking characters
+ *os << static_cast<int16_t>(numeric.Value(index));
+ } else {
+ *os << numeric.Value(index);
+ }
+ };
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_date<T, Status> Visit(const T&) {
+ using unit = typename std::conditional<std::is_same<T, Date32Type>::value,
+ arrow_vendored::date::days,
+ std::chrono::milliseconds>::type;
+
+ static arrow_vendored::date::sys_days epoch{arrow_vendored::date::jan / 1 / 1970};
+
+ impl_ = [](const Array& array, int64_t index, std::ostream* os) {
+ unit value(checked_cast<const NumericArray<T>&>(array).Value(index));
+ *os << arrow_vendored::date::format("%F", value + epoch);
+ };
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_time<T, Status> Visit(const T&) {
+ impl_ = MakeTimeFormatter<T, false>("%T");
+ return Status::OK();
+ }
+
+ Status Visit(const TimestampType&) {
+ impl_ = MakeTimeFormatter<TimestampType, true>("%F %T");
+ return Status::OK();
+ }
+
+ Status Visit(const DayTimeIntervalType&) {
+ impl_ = [](const Array& array, int64_t index, std::ostream* os) {
+ auto day_millis = checked_cast<const DayTimeIntervalArray&>(array).Value(index);
+ *os << day_millis.days << "d" << day_millis.milliseconds << "ms";
+ };
+ return Status::OK();
+ }
+
+ Status Visit(const MonthDayNanoIntervalType&) {
+ impl_ = [](const Array& array, int64_t index, std::ostream* os) {
+ auto month_day_nanos =
+ checked_cast<const MonthDayNanoIntervalArray&>(array).Value(index);
+ *os << month_day_nanos.months << "M" << month_day_nanos.days << "d"
+ << month_day_nanos.nanoseconds << "ns";
+ };
+ return Status::OK();
+ }
+
+ // format Binary, LargeBinary and FixedSizeBinary in hexadecimal
+ template <typename T>
+ enable_if_binary_like<T, Status> Visit(const T&) {
+ using ArrayType = typename TypeTraits<T>::ArrayType;
+ impl_ = [](const Array& array, int64_t index, std::ostream* os) {
+ *os << HexEncode(checked_cast<const ArrayType&>(array).GetView(index));
+ };
+ return Status::OK();
+ }
+
+ // format Strings with \"\n\r\t\\ escaped
+ template <typename T>
+ enable_if_string_like<T, Status> Visit(const T&) {
+ using ArrayType = typename TypeTraits<T>::ArrayType;
+ impl_ = [](const Array& array, int64_t index, std::ostream* os) {
+ *os << "\"" << Escape(checked_cast<const ArrayType&>(array).GetView(index)) << "\"";
+ };
+ return Status::OK();
+ }
+
+ // format Decimals with Decimal128Array::FormatValue
+ Status Visit(const Decimal128Type&) {
+ impl_ = [](const Array& array, int64_t index, std::ostream* os) {
+ *os << checked_cast<const Decimal128Array&>(array).FormatValue(index);
+ };
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_list_like<T, Status> Visit(const T& t) {
+ struct ListImpl {
+ explicit ListImpl(Formatter f) : values_formatter_(std::move(f)) {}
+
+ void operator()(const Array& array, int64_t index, std::ostream* os) {
+ const auto& list_array =
+ checked_cast<const typename TypeTraits<T>::ArrayType&>(array);
+ *os << "[";
+ for (int32_t i = 0; i < list_array.value_length(index); ++i) {
+ if (i != 0) {
+ *os << ", ";
+ }
+ values_formatter_(*list_array.values(), i + list_array.value_offset(index), os);
+ }
+ *os << "]";
+ }
+
+ Formatter values_formatter_;
+ };
+
+ ARROW_ASSIGN_OR_RAISE(auto values_formatter, MakeFormatter(*t.value_type()));
+ impl_ = ListImpl(std::move(values_formatter));
+ return Status::OK();
+ }
+
+ // TODO(bkietz) format maps better
+
+ Status Visit(const StructType& t) {
+ struct StructImpl {
+ explicit StructImpl(std::vector<Formatter> f) : field_formatters_(std::move(f)) {}
+
+ void operator()(const Array& array, int64_t index, std::ostream* os) {
+ const auto& struct_array = checked_cast<const StructArray&>(array);
+ *os << "{";
+ for (int i = 0, printed = 0; i < struct_array.num_fields(); ++i) {
+ if (printed != 0) {
+ *os << ", ";
+ }
+ if (struct_array.field(i)->IsNull(index)) {
+ continue;
+ }
+ ++printed;
+ *os << struct_array.struct_type()->field(i)->name() << ": ";
+ field_formatters_[i](*struct_array.field(i), index, os);
+ }
+ *os << "}";
+ }
+
+ std::vector<Formatter> field_formatters_;
+ };
+
+ std::vector<Formatter> field_formatters(t.num_fields());
+ for (int i = 0; i < t.num_fields(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(field_formatters[i], MakeFormatter(*t.field(i)->type()));
+ }
+
+ impl_ = StructImpl(std::move(field_formatters));
+ return Status::OK();
+ }
+
+ Status Visit(const UnionType& t) {
+ struct UnionImpl {
+ explicit UnionImpl(std::vector<Formatter> f) : field_formatters_(std::move(f)) {}
+
+ void DoFormat(const UnionArray& array, int64_t index, int64_t child_index,
+ std::ostream* os) {
+ auto type_code = array.raw_type_codes()[index];
+ auto child = array.field(array.child_id(index));
+
+ *os << "{" << static_cast<int16_t>(type_code) << ": ";
+ if (child->IsNull(child_index)) {
+ *os << "null";
+ } else {
+ field_formatters_[type_code](*child, child_index, os);
+ }
+ *os << "}";
+ }
+
+ std::vector<Formatter> field_formatters_;
+ };
+
+ struct SparseImpl : UnionImpl {
+ using UnionImpl::UnionImpl;
+
+ void operator()(const Array& array, int64_t index, std::ostream* os) {
+ const auto& union_array = checked_cast<const SparseUnionArray&>(array);
+ DoFormat(union_array, index, index, os);
+ }
+ };
+
+ struct DenseImpl : UnionImpl {
+ using UnionImpl::UnionImpl;
+
+ void operator()(const Array& array, int64_t index, std::ostream* os) {
+ const auto& union_array = checked_cast<const DenseUnionArray&>(array);
+ DoFormat(union_array, index, union_array.raw_value_offsets()[index], os);
+ }
+ };
+
+ std::vector<Formatter> field_formatters(t.max_type_code() + 1);
+ for (int i = 0; i < t.num_fields(); ++i) {
+ auto type_id = t.type_codes()[i];
+ ARROW_ASSIGN_OR_RAISE(field_formatters[type_id],
+ MakeFormatter(*t.field(i)->type()));
+ }
+
+ if (t.mode() == UnionMode::SPARSE) {
+ impl_ = SparseImpl(std::move(field_formatters));
+ } else {
+ impl_ = DenseImpl(std::move(field_formatters));
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const NullType& t) {
+ return Status::NotImplemented("formatting diffs between arrays of type ", t);
+ }
+
+ Status Visit(const DictionaryType& t) {
+ return Status::NotImplemented("formatting diffs between arrays of type ", t);
+ }
+
+ Status Visit(const ExtensionType& t) {
+ return Status::NotImplemented("formatting diffs between arrays of type ", t);
+ }
+
+ Status Visit(const DurationType& t) {
+ return Status::NotImplemented("formatting diffs between arrays of type ", t);
+ }
+
+ Status Visit(const MonthIntervalType& t) {
+ return Status::NotImplemented("formatting diffs between arrays of type ", t);
+ }
+
+ template <typename T, bool AddEpoch>
+ Formatter MakeTimeFormatter(const std::string& fmt_str) {
+ return [fmt_str](const Array& array, int64_t index, std::ostream* os) {
+ auto fmt = fmt_str.c_str();
+ auto unit = checked_cast<const T&>(*array.type()).unit();
+ auto value = checked_cast<const NumericArray<T>&>(array).Value(index);
+ using arrow_vendored::date::format;
+ using std::chrono::nanoseconds;
+ using std::chrono::microseconds;
+ using std::chrono::milliseconds;
+ using std::chrono::seconds;
+ if (AddEpoch) {
+ static arrow_vendored::date::sys_days epoch{arrow_vendored::date::jan / 1 / 1970};
+
+ switch (unit) {
+ case TimeUnit::NANO:
+ *os << format(fmt, static_cast<nanoseconds>(value) + epoch);
+ break;
+ case TimeUnit::MICRO:
+ *os << format(fmt, static_cast<microseconds>(value) + epoch);
+ break;
+ case TimeUnit::MILLI:
+ *os << format(fmt, static_cast<milliseconds>(value) + epoch);
+ break;
+ case TimeUnit::SECOND:
+ *os << format(fmt, static_cast<seconds>(value) + epoch);
+ break;
+ }
+ return;
+ }
+ switch (unit) {
+ case TimeUnit::NANO:
+ *os << format(fmt, static_cast<nanoseconds>(value));
+ break;
+ case TimeUnit::MICRO:
+ *os << format(fmt, static_cast<microseconds>(value));
+ break;
+ case TimeUnit::MILLI:
+ *os << format(fmt, static_cast<milliseconds>(value));
+ break;
+ case TimeUnit::SECOND:
+ *os << format(fmt, static_cast<seconds>(value));
+ break;
+ }
+ };
+ }
+
+ Formatter impl_;
+};
+
+static Result<Formatter> MakeFormatter(const DataType& type) {
+ return MakeFormatterImpl{}.Make(type);
+}
+
+Status VisitEditScript(
+ const Array& edits,
+ const std::function<Status(int64_t delete_begin, int64_t delete_end,
+ int64_t insert_begin, int64_t insert_end)>& visitor) {
+ static const auto edits_type =
+ struct_({field("insert", boolean()), field("run_length", int64())});
+ DCHECK(edits.type()->Equals(*edits_type));
+ DCHECK_GE(edits.length(), 1);
+
+ auto insert = checked_pointer_cast<BooleanArray>(
+ checked_cast<const StructArray&>(edits).field(0));
+ auto run_lengths =
+ checked_pointer_cast<Int64Array>(checked_cast<const StructArray&>(edits).field(1));
+
+ DCHECK(!insert->Value(0));
+
+ auto length = run_lengths->Value(0);
+ int64_t base_begin, base_end, target_begin, target_end;
+ base_begin = base_end = target_begin = target_end = length;
+ for (int64_t i = 1; i < edits.length(); ++i) {
+ if (insert->Value(i)) {
+ ++target_end;
+ } else {
+ ++base_end;
+ }
+ length = run_lengths->Value(i);
+ if (length != 0) {
+ RETURN_NOT_OK(visitor(base_begin, base_end, target_begin, target_end));
+ base_begin = base_end = base_end + length;
+ target_begin = target_end = target_end + length;
+ }
+ }
+ if (length == 0) {
+ return visitor(base_begin, base_end, target_begin, target_end);
+ }
+ return Status::OK();
+}
+
+class UnifiedDiffFormatter {
+ public:
+ UnifiedDiffFormatter(std::ostream* os, Formatter formatter)
+ : os_(os), formatter_(std::move(formatter)) {}
+
+ Status operator()(int64_t delete_begin, int64_t delete_end, int64_t insert_begin,
+ int64_t insert_end) {
+ *os_ << "@@ -" << delete_begin << ", +" << insert_begin << " @@" << std::endl;
+
+ for (int64_t i = delete_begin; i < delete_end; ++i) {
+ *os_ << "-";
+ if (base_->IsValid(i)) {
+ formatter_(*base_, i, &*os_);
+ } else {
+ *os_ << "null";
+ }
+ *os_ << std::endl;
+ }
+
+ for (int64_t i = insert_begin; i < insert_end; ++i) {
+ *os_ << "+";
+ if (target_->IsValid(i)) {
+ formatter_(*target_, i, &*os_);
+ } else {
+ *os_ << "null";
+ }
+ *os_ << std::endl;
+ }
+
+ return Status::OK();
+ }
+
+ Status operator()(const Array& edits, const Array& base, const Array& target) {
+ if (edits.length() == 1) {
+ return Status::OK();
+ }
+ base_ = &base;
+ target_ = &target;
+ *os_ << std::endl;
+ return VisitEditScript(edits, *this);
+ }
+
+ private:
+ std::ostream* os_ = nullptr;
+ const Array* base_ = nullptr;
+ const Array* target_ = nullptr;
+ Formatter formatter_;
+};
+
+Result<std::function<Status(const Array& edits, const Array& base, const Array& target)>>
+MakeUnifiedDiffFormatter(const DataType& type, std::ostream* os) {
+ if (type.id() == Type::NA) {
+ return [os](const Array& edits, const Array& base, const Array& target) {
+ if (base.length() != target.length()) {
+ *os << "# Null arrays differed" << std::endl
+ << "-" << base.length() << " nulls" << std::endl
+ << "+" << target.length() << " nulls" << std::endl;
+ }
+ return Status::OK();
+ };
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto formatter, MakeFormatter(type));
+ return UnifiedDiffFormatter(os, std::move(formatter));
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/diff.h b/src/arrow/cpp/src/arrow/array/diff.h
new file mode 100644
index 000000000..a405164b3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/diff.h
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <functional>
+#include <iosfwd>
+#include <memory>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+/// \brief Compare two arrays, returning an edit script which expresses the difference
+/// between them
+///
+/// An edit script is an array of struct(insert: bool, run_length: int64_t).
+/// Each element of "insert" determines whether an element was inserted into (true)
+/// or deleted from (false) base. Each insertion or deletion is followed by a run of
+/// elements which are unchanged from base to target; the length of this run is stored
+/// in "run_length". (Note that the edit script begins and ends with a run of shared
+/// elements but both fields of the struct must have the same length. To accommodate this
+/// the first element of "insert" should be ignored.)
+///
+/// For example for base "hlloo" and target "hello", the edit script would be
+/// [
+/// {"insert": false, "run_length": 1}, // leading run of length 1 ("h")
+/// {"insert": true, "run_length": 3}, // insert("e") then a run of length 3 ("llo")
+/// {"insert": false, "run_length": 0} // delete("o") then an empty run
+/// ]
+///
+/// Diffing arrays containing nulls is not currently supported.
+///
+/// \param[in] base baseline for comparison
+/// \param[in] target an array of identical type to base whose elements differ from base's
+/// \param[in] pool memory to store the result will be allocated from this memory pool
+/// \return an edit script array which can be applied to base to produce target
+ARROW_EXPORT
+Result<std::shared_ptr<StructArray>> Diff(const Array& base, const Array& target,
+ MemoryPool* pool = default_memory_pool());
+
+/// \brief visitor interface for easy traversal of an edit script
+///
+/// visitor will be called for each hunk of insertions and deletions.
+ARROW_EXPORT Status VisitEditScript(
+ const Array& edits,
+ const std::function<Status(int64_t delete_begin, int64_t delete_end,
+ int64_t insert_begin, int64_t insert_end)>& visitor);
+
+/// \brief return a function which will format an edit script in unified
+/// diff format to os, given base and target arrays of type
+ARROW_EXPORT Result<
+ std::function<Status(const Array& edits, const Array& base, const Array& target)>>
+MakeUnifiedDiffFormatter(const DataType& type, std::ostream* os);
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/diff_test.cc b/src/arrow/cpp/src/arrow/array/diff_test.cc
new file mode 100644
index 000000000..d802a52cd
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/diff_test.cc
@@ -0,0 +1,696 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/array/diff.h"
+#include "arrow/compute/api.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+constexpr random::SeedType kSeed = 0xdeadbeef;
+static const auto edits_type =
+ struct_({field("insert", boolean()), field("run_length", int64())});
+
+Status ValidateEditScript(const Array& edits, const Array& base, const Array& target) {
+ // beginning (in base) of the run before the current hunk
+ int64_t base_run_begin = 0;
+ return VisitEditScript(edits, [&](int64_t delete_begin, int64_t delete_end,
+ int64_t insert_begin, int64_t insert_end) {
+ auto target_run_begin = insert_begin - (delete_begin - base_run_begin);
+ if (!base.RangeEquals(base_run_begin, delete_begin, target_run_begin, target)) {
+ return Status::Invalid("base and target were unequal in a run");
+ }
+
+ base_run_begin = delete_end;
+ for (int64_t i = insert_begin; i < insert_end; ++i) {
+ for (int64_t d = delete_begin; d < delete_end; ++d) {
+ if (target.RangeEquals(i, i + 1, d, base)) {
+ return Status::Invalid("a deleted element was simultaneously inserted");
+ }
+ }
+ }
+
+ return Status::OK();
+ });
+}
+
+class DiffTest : public ::testing::Test {
+ protected:
+ DiffTest() : rng_(kSeed) {}
+
+ void DoDiff() {
+ auto edits = Diff(*base_, *target_, default_memory_pool());
+ ASSERT_OK(edits.status());
+ edits_ = edits.ValueOrDie();
+ ASSERT_OK(edits_->ValidateFull());
+ ASSERT_TRUE(edits_->type()->Equals(edits_type));
+ insert_ = checked_pointer_cast<BooleanArray>(edits_->field(0));
+ run_lengths_ = checked_pointer_cast<Int64Array>(edits_->field(1));
+ }
+
+ void DoDiffAndFormat(std::stringstream* out) {
+ DoDiff();
+ auto formatter = MakeUnifiedDiffFormatter(*base_->type(), out);
+ ASSERT_OK(formatter.status());
+ ASSERT_OK(formatter.ValueOrDie()(*edits_, *base_, *target_));
+ }
+
+ // validate diff and assert that it formats as expected, both directly
+ // and through Array::Equals
+ void AssertDiffAndFormat(const std::string& formatted_expected) {
+ std::stringstream formatted;
+
+ DoDiffAndFormat(&formatted);
+ ASSERT_EQ(formatted.str(), formatted_expected) << "formatted diff incorrectly";
+ formatted.str("");
+
+ ASSERT_EQ(edits_->length() == 1,
+ base_->Equals(*target_, EqualOptions().diff_sink(&formatted)));
+ ASSERT_EQ(formatted.str(), formatted_expected)
+ << "Array::Equals formatted diff incorrectly";
+ }
+
+ void AssertInsertIs(const std::string& insert_json) {
+ AssertArraysEqual(*ArrayFromJSON(boolean(), insert_json), *insert_, /*verbose=*/true);
+ }
+
+ void AssertRunLengthIs(const std::string& run_lengths_json) {
+ AssertArraysEqual(*ArrayFromJSON(int64(), run_lengths_json), *run_lengths_,
+ /*verbose=*/true);
+ }
+
+ void BaseAndTargetFromRandomFilter(std::shared_ptr<Array> values,
+ double filter_probability) {
+ std::shared_ptr<Array> base_filter, target_filter;
+ do {
+ base_filter = this->rng_.Boolean(values->length(), filter_probability, 0.0);
+ target_filter = this->rng_.Boolean(values->length(), filter_probability, 0.0);
+ } while (base_filter->Equals(target_filter));
+
+ ASSERT_OK_AND_ASSIGN(Datum out_datum, compute::Filter(values, base_filter));
+ base_ = out_datum.make_array();
+
+ ASSERT_OK_AND_ASSIGN(out_datum, compute::Filter(values, target_filter));
+ target_ = out_datum.make_array();
+ }
+
+ void TestBasicsWithUnions(UnionMode::type mode) {
+ ASSERT_OK_AND_ASSIGN(
+ auto type,
+ UnionType::Make({field("foo", utf8()), field("bar", int32())}, {2, 5}, mode));
+
+ // insert one
+ base_ = ArrayFromJSON(type, R"([[2, "!"], [5, 3], [5, 13]])");
+ target_ = ArrayFromJSON(type, R"([[2, "!"], [2, "?"], [5, 3], [5, 13]])");
+ DoDiff();
+ AssertInsertIs("[false, true]");
+ AssertRunLengthIs("[1, 2]");
+
+ // delete one
+ base_ = ArrayFromJSON(type, R"([[2, "!"], [2, "?"], [5, 3], [5, 13]])");
+ target_ = ArrayFromJSON(type, R"([[2, "!"], [5, 3], [5, 13]])");
+ DoDiff();
+ AssertInsertIs("[false, false]");
+ AssertRunLengthIs("[1, 2]");
+
+ // change one
+ base_ = ArrayFromJSON(type, R"([[5, 3], [2, "!"], [5, 13]])");
+ target_ = ArrayFromJSON(type, R"([[2, "3"], [2, "!"], [5, 13]])");
+ DoDiff();
+ AssertInsertIs("[false, false, true]");
+ AssertRunLengthIs("[0, 0, 2]");
+
+ // null out one
+ base_ = ArrayFromJSON(type, R"([[2, "!"], [5, 3], [5, 13]])");
+ target_ = ArrayFromJSON(type, R"([[2, "!"], [5, 3], null])");
+ DoDiff();
+ AssertInsertIs("[false, false, true]");
+ AssertRunLengthIs("[2, 0, 0]");
+ }
+
+ random::RandomArrayGenerator rng_;
+ std::shared_ptr<StructArray> edits_;
+ std::shared_ptr<Array> base_, target_;
+ std::shared_ptr<BooleanArray> insert_;
+ std::shared_ptr<Int64Array> run_lengths_;
+};
+
+TEST_F(DiffTest, Trivial) {
+ base_ = ArrayFromJSON(int32(), "[]");
+ target_ = ArrayFromJSON(int32(), "[]");
+ DoDiff();
+ AssertInsertIs("[false]");
+ AssertRunLengthIs("[0]");
+
+ base_ = ArrayFromJSON(null(), "[null, null]");
+ target_ = ArrayFromJSON(null(), "[null, null, null, null]");
+ DoDiff();
+ AssertInsertIs("[false, true, true]");
+ AssertRunLengthIs("[2, 0, 0]");
+
+ base_ = ArrayFromJSON(int32(), "[1, 2, 3]");
+ target_ = ArrayFromJSON(int32(), "[1, 2, 3]");
+ DoDiff();
+ AssertInsertIs("[false]");
+ AssertRunLengthIs("[3]");
+}
+
+TEST_F(DiffTest, Errors) {
+ std::stringstream formatted;
+
+ base_ = ArrayFromJSON(int32(), "[]");
+ target_ = ArrayFromJSON(utf8(), "[]");
+ ASSERT_RAISES(TypeError, Diff(*base_, *target_, default_memory_pool()));
+
+ ASSERT_FALSE(base_->Equals(*target_, EqualOptions().diff_sink(&formatted)));
+ ASSERT_EQ(formatted.str(), "# Array types differed: int32 vs string\n");
+}
+
+template <typename ArrowType>
+class DiffTestWithNumeric : public DiffTest {
+ protected:
+ std::shared_ptr<DataType> type_singleton() {
+ return TypeTraits<ArrowType>::type_singleton();
+ }
+};
+
+TYPED_TEST_SUITE(DiffTestWithNumeric, NumericArrowTypes);
+
+TYPED_TEST(DiffTestWithNumeric, Basics) {
+ // insert one
+ this->base_ = ArrayFromJSON(this->type_singleton(), "[1, 2, null, 5]");
+ this->target_ = ArrayFromJSON(this->type_singleton(), "[1, 2, 3, null, 5]");
+ this->DoDiff();
+ this->AssertInsertIs("[false, true]");
+ this->AssertRunLengthIs("[2, 2]");
+
+ // delete one
+ this->base_ = ArrayFromJSON(this->type_singleton(), "[1, 2, 3, null, 5]");
+ this->target_ = ArrayFromJSON(this->type_singleton(), "[1, 2, null, 5]");
+ this->DoDiff();
+ this->AssertInsertIs("[false, false]");
+ this->AssertRunLengthIs("[2, 2]");
+
+ // change one
+ this->base_ = ArrayFromJSON(this->type_singleton(), "[1, 2, 3, null, 5]");
+ this->target_ = ArrayFromJSON(this->type_singleton(), "[1, 2, 23, null, 5]");
+ this->DoDiff();
+ this->AssertInsertIs("[false, false, true]");
+ this->AssertRunLengthIs("[2, 0, 2]");
+
+ // null out one
+ this->base_ = ArrayFromJSON(this->type_singleton(), "[1, 2, 3, null, 5]");
+ this->target_ = ArrayFromJSON(this->type_singleton(), "[1, 2, null, null, 5]");
+ this->DoDiff();
+ this->AssertInsertIs("[false, false, true]");
+ this->AssertRunLengthIs("[2, 1, 1]");
+
+ // append some
+ this->base_ = ArrayFromJSON(this->type_singleton(), "[1, 2, 3, null, 5]");
+ this->target_ = ArrayFromJSON(this->type_singleton(), "[1, 2, 3, null, 5, 6, 7, 8, 9]");
+ this->DoDiff();
+ this->AssertInsertIs("[false, true, true, true, true]");
+ this->AssertRunLengthIs("[5, 0, 0, 0, 0]");
+
+ // prepend some
+ this->base_ = ArrayFromJSON(this->type_singleton(), "[1, 2, 3, null, 5]");
+ this->target_ = ArrayFromJSON(this->type_singleton(), "[6, 4, 2, 0, 1, 2, 3, null, 5]");
+ this->DoDiff();
+ this->AssertInsertIs("[false, true, true, true, true]");
+ this->AssertRunLengthIs("[0, 0, 0, 0, 5]");
+}
+
+TEST_F(DiffTest, CompareRandomInt64) {
+ for (auto null_probability : {0.0, 0.25}) {
+ auto values = this->rng_.Int64(1 << 10, 0, 127, null_probability);
+ for (const double filter_probability : {0.99, 0.75, 0.5}) {
+ this->BaseAndTargetFromRandomFilter(values, filter_probability);
+
+ std::stringstream formatted;
+ this->DoDiffAndFormat(&formatted);
+ auto st = ValidateEditScript(*this->edits_, *this->base_, *this->target_);
+ if (!st.ok()) {
+ ASSERT_OK(Status(st.code(), st.message() + "\n" + formatted.str()));
+ }
+ }
+ }
+}
+
+TEST_F(DiffTest, CompareRandomStrings) {
+ for (auto null_probability : {0.0, 0.25}) {
+ auto values = this->rng_.StringWithRepeats(1 << 10, 1 << 8, 0, 32, null_probability);
+ for (const double filter_probability : {0.99, 0.75, 0.5}) {
+ this->BaseAndTargetFromRandomFilter(values, filter_probability);
+
+ std::stringstream formatted;
+ this->DoDiffAndFormat(&formatted);
+ auto st = ValidateEditScript(*this->edits_, *this->base_, *this->target_);
+ if (!st.ok()) {
+ ASSERT_OK(Status(st.code(), st.message() + "\n" + formatted.str()));
+ }
+ }
+ }
+}
+
+TEST_F(DiffTest, BasicsWithBooleans) {
+ // insert one
+ base_ = ArrayFromJSON(boolean(), R"([true, true, true])");
+ target_ = ArrayFromJSON(boolean(), R"([true, false, true, true])");
+ DoDiff();
+ AssertInsertIs("[false, true]");
+ AssertRunLengthIs("[1, 2]");
+
+ // delete one
+ base_ = ArrayFromJSON(boolean(), R"([true, false, true, true])");
+ target_ = ArrayFromJSON(boolean(), R"([true, true, true])");
+ DoDiff();
+ AssertInsertIs("[false, false]");
+ AssertRunLengthIs("[1, 2]");
+
+ // change one
+ base_ = ArrayFromJSON(boolean(), R"([false, false, true])");
+ target_ = ArrayFromJSON(boolean(), R"([true, false, true])");
+ DoDiff();
+ AssertInsertIs("[false, false, true]");
+ AssertRunLengthIs("[0, 0, 2]");
+
+ // null out one
+ base_ = ArrayFromJSON(boolean(), R"([true, false, true])");
+ target_ = ArrayFromJSON(boolean(), R"([true, false, null])");
+ DoDiff();
+ AssertInsertIs("[false, false, true]");
+ AssertRunLengthIs("[2, 0, 0]");
+}
+
+TEST_F(DiffTest, BasicsWithStrings) {
+ // insert one
+ base_ = ArrayFromJSON(utf8(), R"(["give", "a", "break"])");
+ target_ = ArrayFromJSON(utf8(), R"(["give", "me", "a", "break"])");
+ DoDiff();
+ AssertInsertIs("[false, true]");
+ AssertRunLengthIs("[1, 2]");
+
+ // delete one
+ base_ = ArrayFromJSON(utf8(), R"(["give", "me", "a", "break"])");
+ target_ = ArrayFromJSON(utf8(), R"(["give", "a", "break"])");
+ DoDiff();
+ AssertInsertIs("[false, false]");
+ AssertRunLengthIs("[1, 2]");
+
+ // change one
+ base_ = ArrayFromJSON(utf8(), R"(["give", "a", "break"])");
+ target_ = ArrayFromJSON(utf8(), R"(["gimme", "a", "break"])");
+ DoDiff();
+ AssertInsertIs("[false, false, true]");
+ AssertRunLengthIs("[0, 0, 2]");
+
+ // null out one
+ base_ = ArrayFromJSON(utf8(), R"(["give", "a", "break"])");
+ target_ = ArrayFromJSON(utf8(), R"(["give", "a", null])");
+ DoDiff();
+ AssertInsertIs("[false, false, true]");
+ AssertRunLengthIs("[2, 0, 0]");
+}
+
+TEST_F(DiffTest, BasicsWithLists) {
+ // insert one
+ base_ = ArrayFromJSON(list(int32()), R"([[2, 3, 1], [], [13]])");
+ target_ = ArrayFromJSON(list(int32()), R"([[2, 3, 1], [5, 9], [], [13]])");
+ DoDiff();
+ AssertInsertIs("[false, true]");
+ AssertRunLengthIs("[1, 2]");
+
+ // delete one
+ base_ = ArrayFromJSON(list(int32()), R"([[2, 3, 1], [5, 9], [], [13]])");
+ target_ = ArrayFromJSON(list(int32()), R"([[2, 3, 1], [], [13]])");
+ DoDiff();
+ AssertInsertIs("[false, false]");
+ AssertRunLengthIs("[1, 2]");
+
+ // change one
+ base_ = ArrayFromJSON(list(int32()), R"([[2, 3, 1], [], [13]])");
+ target_ = ArrayFromJSON(list(int32()), R"([[3, 3, 3], [], [13]])");
+ DoDiff();
+ AssertInsertIs("[false, false, true]");
+ AssertRunLengthIs("[0, 0, 2]");
+
+ // null out one
+ base_ = ArrayFromJSON(list(int32()), R"([[2, 3, 1], [], [13]])");
+ target_ = ArrayFromJSON(list(int32()), R"([[2, 3, 1], [], null])");
+ DoDiff();
+ AssertInsertIs("[false, false, true]");
+ AssertRunLengthIs("[2, 0, 0]");
+}
+
+TEST_F(DiffTest, BasicsWithStructs) {
+ auto type = struct_({field("foo", utf8()), field("bar", int32())});
+
+ // insert one
+ base_ = ArrayFromJSON(type, R"([{"foo": "!", "bar": 3}, {}, {"bar": 13}])");
+ target_ =
+ ArrayFromJSON(type, R"([{"foo": "!", "bar": 3}, {"foo": "?"}, {}, {"bar": 13}])");
+ DoDiff();
+ AssertInsertIs("[false, true]");
+ AssertRunLengthIs("[1, 2]");
+
+ // delete one
+ base_ =
+ ArrayFromJSON(type, R"([{"foo": "!", "bar": 3}, {"foo": "?"}, {}, {"bar": 13}])");
+ target_ = ArrayFromJSON(type, R"([{"foo": "!", "bar": 3}, {}, {"bar": 13}])");
+ DoDiff();
+ AssertInsertIs("[false, false]");
+ AssertRunLengthIs("[1, 2]");
+
+ // change one
+ base_ = ArrayFromJSON(type, R"([{"foo": "!", "bar": 3}, {}, {"bar": 13}])");
+ target_ = ArrayFromJSON(type, R"([{"foo": "!", "bar": 2}, {}, {"bar": 13}])");
+ DoDiff();
+ AssertInsertIs("[false, false, true]");
+ AssertRunLengthIs("[0, 0, 2]");
+
+ // null out one
+ base_ = ArrayFromJSON(type, R"([{"foo": "!", "bar": 3}, {}, {"bar": 13}])");
+ target_ = ArrayFromJSON(type, R"([{"foo": "!", "bar": 3}, {}, null])");
+ DoDiff();
+ AssertInsertIs("[false, false, true]");
+ AssertRunLengthIs("[2, 0, 0]");
+}
+
+TEST_F(DiffTest, BasicsWithSparseUnions) { TestBasicsWithUnions(UnionMode::SPARSE); }
+
+TEST_F(DiffTest, BasicsWithDenseUnions) { TestBasicsWithUnions(UnionMode::DENSE); }
+
+TEST_F(DiffTest, UnifiedDiffFormatter) {
+ // no changes
+ base_ = ArrayFromJSON(utf8(), R"(["give", "me", "a", "break"])");
+ target_ = ArrayFromJSON(utf8(), R"(["give", "me", "a", "break"])");
+ AssertDiffAndFormat(R"()");
+
+ // insert one
+ base_ = ArrayFromJSON(utf8(), R"(["give", "a", "break"])");
+ target_ = ArrayFromJSON(utf8(), R"(["give", "me", "a", "break"])");
+ AssertDiffAndFormat(R"(
+@@ -1, +1 @@
++"me"
+)");
+
+ // delete one
+ base_ = ArrayFromJSON(utf8(), R"(["give", "me", "a", "break"])");
+ target_ = ArrayFromJSON(utf8(), R"(["give", "a", "break"])");
+ AssertDiffAndFormat(R"(
+@@ -1, +1 @@
+-"me"
+)");
+
+ // change one
+ base_ = ArrayFromJSON(utf8(), R"(["give", "a", "break"])");
+ target_ = ArrayFromJSON(utf8(), R"(["gimme", "a", "break"])");
+ AssertDiffAndFormat(R"(
+@@ -0, +0 @@
+-"give"
++"gimme"
+)");
+
+ // null out one
+ base_ = ArrayFromJSON(utf8(), R"(["give", "a", "break"])");
+ target_ = ArrayFromJSON(utf8(), R"(["give", "a", null])");
+ AssertDiffAndFormat(R"(
+@@ -2, +2 @@
+-"break"
++null
+)");
+
+ // strings with escaped chars
+ base_ = ArrayFromJSON(utf8(), R"(["newline:\n", "quote:'", "backslash:\\"])");
+ target_ =
+ ArrayFromJSON(utf8(), R"(["newline:\n", "tab:\t", "quote:\"", "backslash:\\"])");
+ AssertDiffAndFormat(R"(
+@@ -1, +1 @@
+-"quote:'"
++"tab:\t"
++"quote:\""
+)");
+
+ // date32
+ base_ = ArrayFromJSON(date32(), R"([0, 1, 2, 31, 4])");
+ target_ = ArrayFromJSON(date32(), R"([0, 1, 31, 2, 4])");
+ AssertDiffAndFormat(R"(
+@@ -2, +2 @@
+-1970-01-03
+@@ -4, +3 @@
++1970-01-03
+)");
+
+ // date64
+ constexpr int64_t ms_per_day = 24 * 60 * 60 * 1000;
+ ArrayFromVector<Date64Type>(
+ {0 * ms_per_day, 1 * ms_per_day, 2 * ms_per_day, 31 * ms_per_day, 4 * ms_per_day},
+ &base_);
+ ArrayFromVector<Date64Type>(
+ {0 * ms_per_day, 1 * ms_per_day, 31 * ms_per_day, 2 * ms_per_day, 4 * ms_per_day},
+ &target_);
+ AssertDiffAndFormat(R"(
+@@ -2, +2 @@
+-1970-01-03
+@@ -4, +3 @@
++1970-01-03
+)");
+
+ // timestamp
+ auto x = 678 + 1000000 * (5 + 60 * (4 + 60 * (3 + 24 * int64_t(1))));
+ ArrayFromVector<TimestampType>(timestamp(TimeUnit::MICRO), {0, 1, x, 2, 4}, &base_);
+ ArrayFromVector<TimestampType>(timestamp(TimeUnit::MICRO), {0, 1, 2, x, 4}, &target_);
+ AssertDiffAndFormat(R"(
+@@ -2, +2 @@
+-1970-01-02 03:04:05.000678
+@@ -4, +3 @@
++1970-01-02 03:04:05.000678
+)");
+
+ // Month, Day, Nano Intervals
+ base_ = ArrayFromJSON(month_day_nano_interval(), R"([[2, 3, 1]])");
+ target_ = ArrayFromJSON(month_day_nano_interval(), R"([])");
+ AssertDiffAndFormat(R"(
+@@ -0, +0 @@
+-2M3d1ns
+)");
+
+ // lists
+ base_ = ArrayFromJSON(list(int32()), R"([[2, 3, 1], [], [13], []])");
+ target_ = ArrayFromJSON(list(int32()), R"([[2, 3, 1], [5, 9], [], [13]])");
+ AssertDiffAndFormat(R"(
+@@ -1, +1 @@
++[5, 9]
+@@ -3, +4 @@
+-[]
+)");
+
+ // maps
+ base_ = ArrayFromJSON(map(utf8(), int32()), R"([
+ [["foo", 2], ["bar", 3], ["baz", 1]],
+ [],
+ [["quux", 13]],
+ []
+ ])");
+ target_ = ArrayFromJSON(map(utf8(), int32()), R"([
+ [["foo", 2], ["bar", 3], ["baz", 1]],
+ [["ytho", 11]],
+ [],
+ [["quux", 13]]
+ ])");
+ AssertDiffAndFormat(R"(
+@@ -1, +1 @@
++[{key: "ytho", value: 11}]
+@@ -3, +4 @@
+-[]
+)");
+
+ // structs
+ auto type = struct_({field("foo", utf8()), field("bar", int32())});
+ base_ = ArrayFromJSON(type, R"([{"foo": "!", "bar": 3}, {}, {"bar": 13}])");
+ target_ = ArrayFromJSON(type, R"([{"foo": null, "bar": 2}, {}, {"bar": 13}])");
+ AssertDiffAndFormat(R"(
+@@ -0, +0 @@
+-{foo: "!", bar: 3}
++{bar: 2}
+)");
+
+ // unions
+ for (auto union_ : UnionTypeFactories()) {
+ type = union_({field("foo", utf8()), field("bar", int32())}, {2, 5});
+ base_ = ArrayFromJSON(type, R"([[2, "!"], [5, 3], [5, 13]])");
+ target_ = ArrayFromJSON(type, R"([[2, "!"], [2, "3"], [5, 13]])");
+ AssertDiffAndFormat(R"(
+@@ -1, +1 @@
+-{5: 3}
++{2: "3"}
+)");
+ }
+
+ for (auto type : {int8(), uint8(), // verify that these are printed as numbers rather
+ // than their ascii characters
+ int16(), uint16()}) {
+ // small difference
+ base_ = ArrayFromJSON(type, "[0, 1, 2, 3, 5, 8, 11, 13, 17]");
+ target_ = ArrayFromJSON(type, "[2, 3, 5, 7, 11, 13, 17, 19]");
+ AssertDiffAndFormat(R"(
+@@ -0, +0 @@
+-0
+-1
+@@ -5, +3 @@
+-8
++7
+@@ -9, +7 @@
++19
+)");
+
+ // large difference
+ base_ = ArrayFromJSON(type, "[57, 10, 22, 126, 42]");
+ target_ = ArrayFromJSON(type, "[58, 57, 75, 93, 53, 8, 22, 42, 79, 11]");
+ AssertDiffAndFormat(R"(
+@@ -0, +0 @@
++58
+@@ -1, +2 @@
+-10
++75
++93
++53
++8
+@@ -3, +7 @@
+-126
+@@ -5, +8 @@
++79
++11
+)");
+ }
+}
+
+TEST_F(DiffTest, DictionaryDiffFormatter) {
+ std::stringstream formatted;
+
+ // differing indices
+ auto base_dict = ArrayFromJSON(utf8(), R"(["a", "b", "c"])");
+ auto base_indices = ArrayFromJSON(int8(), "[0, 1, 2, 2, 0, 1]");
+ ASSERT_OK_AND_ASSIGN(base_, DictionaryArray::FromArrays(
+ dictionary(base_indices->type(), base_dict->type()),
+ base_indices, base_dict));
+
+ auto target_dict = base_dict;
+ auto target_indices = ArrayFromJSON(int8(), "[0, 1, 2, 2, 1, 1]");
+ ASSERT_OK_AND_ASSIGN(
+ target_,
+ DictionaryArray::FromArrays(dictionary(target_indices->type(), target_dict->type()),
+ target_indices, target_dict));
+
+ base_->Equals(*target_, EqualOptions().diff_sink(&formatted));
+ auto formatted_expected_indices = R"(# Dictionary arrays differed
+## dictionary diff
+## indices diff
+@@ -4, +4 @@
+-0
+@@ -6, +5 @@
++1
+)";
+ ASSERT_EQ(formatted.str(), formatted_expected_indices);
+
+ // Note: Diff doesn't work at the moment with dictionary arrays
+ ASSERT_RAISES(NotImplemented, Diff(*base_, *target_));
+
+ // differing dictionaries
+ target_dict = ArrayFromJSON(utf8(), R"(["b", "c", "a"])");
+ target_indices = base_indices;
+ ASSERT_OK_AND_ASSIGN(
+ target_,
+ DictionaryArray::FromArrays(dictionary(target_indices->type(), target_dict->type()),
+ target_indices, target_dict));
+
+ formatted.str("");
+ base_->Equals(*target_, EqualOptions().diff_sink(&formatted));
+ auto formatted_expected_values = R"(# Dictionary arrays differed
+## dictionary diff
+@@ -0, +0 @@
+-"a"
+@@ -3, +2 @@
++"a"
+## indices diff
+)";
+ ASSERT_EQ(formatted.str(), formatted_expected_values);
+}
+
+void MakeSameLength(std::shared_ptr<Array>* a, std::shared_ptr<Array>* b) {
+ auto length = std::min((*a)->length(), (*b)->length());
+ *a = (*a)->Slice(0, length);
+ *b = (*b)->Slice(0, length);
+}
+
+TEST_F(DiffTest, CompareRandomStruct) {
+ for (auto null_probability : {0.0, 0.25}) {
+ constexpr auto length = 1 << 10;
+ auto int32_values = this->rng_.Int32(length, 0, 127, null_probability);
+ auto utf8_values = this->rng_.String(length, 0, 16, null_probability);
+ for (const double filter_probability : {0.9999, 0.75}) {
+ this->BaseAndTargetFromRandomFilter(int32_values, filter_probability);
+ auto int32_base = this->base_;
+ auto int32_target = this->base_;
+
+ this->BaseAndTargetFromRandomFilter(utf8_values, filter_probability);
+ auto utf8_base = this->base_;
+ auto utf8_target = this->base_;
+
+ MakeSameLength(&int32_base, &utf8_base);
+ MakeSameLength(&int32_target, &utf8_target);
+
+ auto type = struct_({field("i", int32()), field("s", utf8())});
+ auto base_res = StructArray::Make({int32_base, utf8_base}, type->fields());
+ ASSERT_OK(base_res.status());
+ base_ = base_res.ValueOrDie();
+ auto target_res = StructArray::Make({int32_target, utf8_target}, type->fields());
+ ASSERT_OK(target_res.status());
+ target_ = target_res.ValueOrDie();
+
+ std::stringstream formatted;
+ this->DoDiffAndFormat(&formatted);
+ auto st = ValidateEditScript(*this->edits_, *this->base_, *this->target_);
+ if (!st.ok()) {
+ ASSERT_OK(Status(st.code(), st.message() + "\n" + formatted.str()));
+ }
+ }
+ }
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/util.cc b/src/arrow/cpp/src/arrow/array/util.cc
new file mode 100644
index 000000000..d639830f4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/util.cc
@@ -0,0 +1,860 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/util.h"
+
+#include <algorithm>
+#include <array>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_dict.h"
+#include "arrow/array/array_primitive.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/buffer.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/extension_type.h"
+#include "arrow/result.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/logging.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+// ----------------------------------------------------------------------
+// Loading from ArrayData
+
+namespace {
+
+class ArrayDataWrapper {
+ public:
+ ArrayDataWrapper(const std::shared_ptr<ArrayData>& data, std::shared_ptr<Array>* out)
+ : data_(data), out_(out) {}
+
+ template <typename T>
+ Status Visit(const T&) {
+ using ArrayType = typename TypeTraits<T>::ArrayType;
+ *out_ = std::make_shared<ArrayType>(data_);
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionType& type) {
+ *out_ = type.MakeArray(data_);
+ return Status::OK();
+ }
+
+ const std::shared_ptr<ArrayData>& data_;
+ std::shared_ptr<Array>* out_;
+};
+
+class ArrayDataEndianSwapper {
+ public:
+ explicit ArrayDataEndianSwapper(const std::shared_ptr<ArrayData>& data) : data_(data) {
+ out_ = data->Copy();
+ }
+
+ // WARNING: this facility can be called on invalid Array data by the IPC reader.
+ // Do not rely on the advertised ArrayData length, instead use the physical
+ // buffer sizes to avoid accessing memory out of bounds.
+ //
+ // (If this guarantee turns out to be difficult to maintain, we should call
+ // Validate() instead)
+ Status SwapType(const DataType& type) {
+ RETURN_NOT_OK(VisitTypeInline(type, this));
+ RETURN_NOT_OK(SwapChildren(type.fields()));
+ if (internal::HasValidityBitmap(type.id())) {
+ // Copy null bitmap
+ out_->buffers[0] = data_->buffers[0];
+ }
+ return Status::OK();
+ }
+
+ Status SwapChildren(const FieldVector& child_fields) {
+ for (size_t i = 0; i < child_fields.size(); i++) {
+ ARROW_ASSIGN_OR_RAISE(out_->child_data[i],
+ internal::SwapEndianArrayData(data_->child_data[i]));
+ }
+ return Status::OK();
+ }
+
+ template <typename T>
+ Result<std::shared_ptr<Buffer>> ByteSwapBuffer(
+ const std::shared_ptr<Buffer>& in_buffer) {
+ if (sizeof(T) == 1) {
+ // if data size is 1, element is not swapped. We can use the original buffer
+ return in_buffer;
+ }
+ auto in_data = reinterpret_cast<const T*>(in_buffer->data());
+ ARROW_ASSIGN_OR_RAISE(auto out_buffer, AllocateBuffer(in_buffer->size()));
+ auto out_data = reinterpret_cast<T*>(out_buffer->mutable_data());
+ // NOTE: data_->length not trusted (see warning above)
+ int64_t length = in_buffer->size() / sizeof(T);
+ for (int64_t i = 0; i < length; i++) {
+ out_data[i] = BitUtil::ByteSwap(in_data[i]);
+ }
+ return std::move(out_buffer);
+ }
+
+ template <typename VALUE_TYPE>
+ Status SwapOffsets(int index) {
+ if (data_->buffers[index] == nullptr || data_->buffers[index]->size() == 0) {
+ out_->buffers[index] = data_->buffers[index];
+ return Status::OK();
+ }
+ // Except union, offset has one more element rather than data->length
+ ARROW_ASSIGN_OR_RAISE(out_->buffers[index],
+ ByteSwapBuffer<VALUE_TYPE>(data_->buffers[index]));
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_t<std::is_base_of<FixedWidthType, T>::value &&
+ !std::is_base_of<FixedSizeBinaryType, T>::value &&
+ !std::is_base_of<DictionaryType, T>::value,
+ Status>
+ Visit(const T& type) {
+ using value_type = typename T::c_type;
+ ARROW_ASSIGN_OR_RAISE(out_->buffers[1],
+ ByteSwapBuffer<value_type>(data_->buffers[1]));
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal128Type& type) {
+ auto data = reinterpret_cast<const uint64_t*>(data_->buffers[1]->data());
+ ARROW_ASSIGN_OR_RAISE(auto new_buffer, AllocateBuffer(data_->buffers[1]->size()));
+ auto new_data = reinterpret_cast<uint64_t*>(new_buffer->mutable_data());
+ // NOTE: data_->length not trusted (see warning above)
+ const int64_t length = data_->buffers[1]->size() / Decimal128Type::kByteWidth;
+ for (int64_t i = 0; i < length; i++) {
+ uint64_t tmp;
+ auto idx = i * 2;
+#if ARROW_LITTLE_ENDIAN
+ tmp = BitUtil::FromBigEndian(data[idx]);
+ new_data[idx] = BitUtil::FromBigEndian(data[idx + 1]);
+ new_data[idx + 1] = tmp;
+#else
+ tmp = BitUtil::FromLittleEndian(data[idx]);
+ new_data[idx] = BitUtil::FromLittleEndian(data[idx + 1]);
+ new_data[idx + 1] = tmp;
+#endif
+ }
+ out_->buffers[1] = std::move(new_buffer);
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal256Type& type) {
+ auto data = reinterpret_cast<const uint64_t*>(data_->buffers[1]->data());
+ ARROW_ASSIGN_OR_RAISE(auto new_buffer, AllocateBuffer(data_->buffers[1]->size()));
+ auto new_data = reinterpret_cast<uint64_t*>(new_buffer->mutable_data());
+ // NOTE: data_->length not trusted (see warning above)
+ const int64_t length = data_->buffers[1]->size() / Decimal256Type::kByteWidth;
+ for (int64_t i = 0; i < length; i++) {
+ uint64_t tmp0, tmp1, tmp2;
+ auto idx = i * 4;
+#if ARROW_LITTLE_ENDIAN
+ tmp0 = BitUtil::FromBigEndian(data[idx]);
+ tmp1 = BitUtil::FromBigEndian(data[idx + 1]);
+ tmp2 = BitUtil::FromBigEndian(data[idx + 2]);
+ new_data[idx] = BitUtil::FromBigEndian(data[idx + 3]);
+ new_data[idx + 1] = tmp2;
+ new_data[idx + 2] = tmp1;
+ new_data[idx + 3] = tmp0;
+#else
+ tmp0 = BitUtil::FromLittleEndian(data[idx]);
+ tmp1 = BitUtil::FromLittleEndian(data[idx + 1]);
+ tmp2 = BitUtil::FromLittleEndian(data[idx + 2]);
+ new_data[idx] = BitUtil::FromLittleEndian(data[idx + 3]);
+ new_data[idx + 1] = tmp2;
+ new_data[idx + 2] = tmp1;
+ new_data[idx + 3] = tmp0;
+#endif
+ }
+ out_->buffers[1] = std::move(new_buffer);
+ return Status::OK();
+ }
+
+ Status Visit(const DayTimeIntervalType& type) {
+ ARROW_ASSIGN_OR_RAISE(out_->buffers[1], ByteSwapBuffer<uint32_t>(data_->buffers[1]));
+ return Status::OK();
+ }
+
+ Status Visit(const MonthDayNanoIntervalType& type) {
+ using MonthDayNanos = MonthDayNanoIntervalType::MonthDayNanos;
+ auto data = reinterpret_cast<const MonthDayNanos*>(data_->buffers[1]->data());
+ ARROW_ASSIGN_OR_RAISE(auto new_buffer, AllocateBuffer(data_->buffers[1]->size()));
+ auto new_data = reinterpret_cast<MonthDayNanos*>(new_buffer->mutable_data());
+ // NOTE: data_->length not trusted (see warning above)
+ const int64_t length = data_->buffers[1]->size() / sizeof(MonthDayNanos);
+ for (int64_t i = 0; i < length; i++) {
+ MonthDayNanos tmp = data[i];
+#if ARROW_LITTLE_ENDIAN
+ tmp.months = BitUtil::FromBigEndian(tmp.months);
+ tmp.days = BitUtil::FromBigEndian(tmp.days);
+ tmp.nanoseconds = BitUtil::FromBigEndian(tmp.nanoseconds);
+#else
+ tmp.months = BitUtil::FromLittleEndian(tmp.months);
+ tmp.days = BitUtil::FromLittleEndian(tmp.days);
+ tmp.nanoseconds = BitUtil::FromLittleEndian(tmp.nanoseconds);
+#endif
+ new_data[i] = tmp;
+ }
+ out_->buffers[1] = std::move(new_buffer);
+ return Status::OK();
+ }
+
+ Status Visit(const NullType& type) { return Status::OK(); }
+ Status Visit(const BooleanType& type) { return Status::OK(); }
+ Status Visit(const Int8Type& type) { return Status::OK(); }
+ Status Visit(const UInt8Type& type) { return Status::OK(); }
+ Status Visit(const FixedSizeBinaryType& type) { return Status::OK(); }
+ Status Visit(const FixedSizeListType& type) { return Status::OK(); }
+ Status Visit(const StructType& type) { return Status::OK(); }
+ Status Visit(const UnionType& type) {
+ out_->buffers[1] = data_->buffers[1];
+ if (type.mode() == UnionMode::DENSE) {
+ RETURN_NOT_OK(SwapOffsets<int32_t>(2));
+ }
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_t<std::is_same<BinaryType, T>::value || std::is_same<StringType, T>::value,
+ Status>
+ Visit(const T& type) {
+ RETURN_NOT_OK(SwapOffsets<int32_t>(1));
+ out_->buffers[2] = data_->buffers[2];
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_t<std::is_same<LargeBinaryType, T>::value ||
+ std::is_same<LargeStringType, T>::value,
+ Status>
+ Visit(const T& type) {
+ RETURN_NOT_OK(SwapOffsets<int64_t>(1));
+ out_->buffers[2] = data_->buffers[2];
+ return Status::OK();
+ }
+
+ Status Visit(const ListType& type) {
+ RETURN_NOT_OK(SwapOffsets<int32_t>(1));
+ return Status::OK();
+ }
+ Status Visit(const LargeListType& type) {
+ RETURN_NOT_OK(SwapOffsets<int64_t>(1));
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryType& type) {
+ // dictionary was already swapped in ReadDictionary() in ipc/reader.cc
+ RETURN_NOT_OK(SwapType(*type.index_type()));
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionType& type) {
+ RETURN_NOT_OK(SwapType(*type.storage_type()));
+ return Status::OK();
+ }
+
+ const std::shared_ptr<ArrayData>& data_;
+ std::shared_ptr<ArrayData> out_;
+};
+
+} // namespace
+
+namespace internal {
+
+Result<std::shared_ptr<ArrayData>> SwapEndianArrayData(
+ const std::shared_ptr<ArrayData>& data) {
+ if (data->offset != 0) {
+ return Status::Invalid("Unsupported data format: data.offset != 0");
+ }
+ ArrayDataEndianSwapper swapper(data);
+ RETURN_NOT_OK(swapper.SwapType(*data->type));
+ return std::move(swapper.out_);
+}
+
+} // namespace internal
+
+std::shared_ptr<Array> MakeArray(const std::shared_ptr<ArrayData>& data) {
+ std::shared_ptr<Array> out;
+ ArrayDataWrapper wrapper_visitor(data, &out);
+ DCHECK_OK(VisitTypeInline(*data->type, &wrapper_visitor));
+ DCHECK(out);
+ return out;
+}
+
+// ----------------------------------------------------------------------
+// Misc APIs
+
+namespace {
+
+// get the maximum buffer length required, then allocate a single zeroed buffer
+// to use anywhere a buffer is required
+class NullArrayFactory {
+ public:
+ struct GetBufferLength {
+ GetBufferLength(const std::shared_ptr<DataType>& type, int64_t length)
+ : type_(*type), length_(length), buffer_length_(BitUtil::BytesForBits(length)) {}
+
+ Result<int64_t> Finish() && {
+ RETURN_NOT_OK(VisitTypeInline(type_, this));
+ return buffer_length_;
+ }
+
+ template <typename T, typename = decltype(TypeTraits<T>::bytes_required(0))>
+ Status Visit(const T&) {
+ return MaxOf(TypeTraits<T>::bytes_required(length_));
+ }
+
+ template <typename T>
+ enable_if_var_size_list<T, Status> Visit(const T&) {
+ // values array may be empty, but there must be at least one offset of 0
+ return MaxOf(sizeof(typename T::offset_type) * (length_ + 1));
+ }
+
+ template <typename T>
+ enable_if_base_binary<T, Status> Visit(const T&) {
+ // values buffer may be empty, but there must be at least one offset of 0
+ return MaxOf(sizeof(typename T::offset_type) * (length_ + 1));
+ }
+
+ Status Visit(const FixedSizeListType& type) {
+ return MaxOf(GetBufferLength(type.value_type(), type.list_size() * length_));
+ }
+
+ Status Visit(const FixedSizeBinaryType& type) {
+ return MaxOf(type.byte_width() * length_);
+ }
+
+ Status Visit(const StructType& type) {
+ for (const auto& child : type.fields()) {
+ RETURN_NOT_OK(MaxOf(GetBufferLength(child->type(), length_)));
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const UnionType& type) {
+ // type codes
+ RETURN_NOT_OK(MaxOf(length_));
+ if (type.mode() == UnionMode::DENSE) {
+ // offsets
+ RETURN_NOT_OK(MaxOf(sizeof(int32_t) * length_));
+ }
+ for (const auto& child : type.fields()) {
+ RETURN_NOT_OK(MaxOf(GetBufferLength(child->type(), length_)));
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryType& type) {
+ RETURN_NOT_OK(MaxOf(GetBufferLength(type.value_type(), length_)));
+ return MaxOf(GetBufferLength(type.index_type(), length_));
+ }
+
+ Status Visit(const ExtensionType& type) {
+ // XXX is an extension array's length always == storage length
+ return MaxOf(GetBufferLength(type.storage_type(), length_));
+ }
+
+ Status Visit(const DataType& type) {
+ return Status::NotImplemented("construction of all-null ", type);
+ }
+
+ private:
+ Status MaxOf(GetBufferLength&& other) {
+ ARROW_ASSIGN_OR_RAISE(int64_t buffer_length, std::move(other).Finish());
+ return MaxOf(buffer_length);
+ }
+
+ Status MaxOf(int64_t buffer_length) {
+ if (buffer_length > buffer_length_) {
+ buffer_length_ = buffer_length;
+ }
+ return Status::OK();
+ }
+
+ const DataType& type_;
+ int64_t length_, buffer_length_;
+ };
+
+ NullArrayFactory(MemoryPool* pool, const std::shared_ptr<DataType>& type,
+ int64_t length)
+ : pool_(pool), type_(type), length_(length) {}
+
+ Status CreateBuffer() {
+ ARROW_ASSIGN_OR_RAISE(int64_t buffer_length,
+ GetBufferLength(type_, length_).Finish());
+ ARROW_ASSIGN_OR_RAISE(buffer_, AllocateBuffer(buffer_length, pool_));
+ std::memset(buffer_->mutable_data(), 0, buffer_->size());
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<ArrayData>> Create() {
+ if (buffer_ == nullptr) {
+ RETURN_NOT_OK(CreateBuffer());
+ }
+ std::vector<std::shared_ptr<ArrayData>> child_data(type_->num_fields());
+ out_ = ArrayData::Make(type_, length_, {buffer_}, child_data, length_, 0);
+ RETURN_NOT_OK(VisitTypeInline(*type_, this));
+ return out_;
+ }
+
+ Status Visit(const NullType&) {
+ out_->buffers.resize(1, nullptr);
+ return Status::OK();
+ }
+
+ Status Visit(const FixedWidthType&) {
+ out_->buffers.resize(2, buffer_);
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_base_binary<T, Status> Visit(const T&) {
+ out_->buffers.resize(3, buffer_);
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_var_size_list<T, Status> Visit(const T& type) {
+ out_->buffers.resize(2, buffer_);
+ ARROW_ASSIGN_OR_RAISE(out_->child_data[0], CreateChild(0, /*length=*/0));
+ return Status::OK();
+ }
+
+ Status Visit(const FixedSizeListType& type) {
+ ARROW_ASSIGN_OR_RAISE(out_->child_data[0],
+ CreateChild(0, length_ * type.list_size()));
+ return Status::OK();
+ }
+
+ Status Visit(const StructType& type) {
+ for (int i = 0; i < type_->num_fields(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(i, length_));
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const UnionType& type) {
+ out_->buffers.resize(2);
+
+ // First buffer is always null
+ out_->buffers[0] = nullptr;
+
+ out_->buffers[1] = buffer_;
+ // buffer_ is zeroed, but 0 may not be a valid type code
+ if (type.type_codes()[0] != 0) {
+ ARROW_ASSIGN_OR_RAISE(out_->buffers[1], AllocateBuffer(length_, pool_));
+ std::memset(out_->buffers[1]->mutable_data(), type.type_codes()[0], length_);
+ }
+
+ // For sparse unions, we now create children with the same length as the
+ // parent
+ int64_t child_length = length_;
+ if (type.mode() == UnionMode::DENSE) {
+ // For dense unions, we set the offsets to all zero and create children
+ // with length 1
+ out_->buffers.resize(3);
+ out_->buffers[2] = buffer_;
+
+ child_length = 1;
+ }
+ for (int i = 0; i < type_->num_fields(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(i, child_length));
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryType& type) {
+ out_->buffers.resize(2, buffer_);
+ ARROW_ASSIGN_OR_RAISE(auto typed_null_dict, MakeArrayOfNull(type.value_type(), 0));
+ out_->dictionary = typed_null_dict->data();
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionType& type) {
+ RETURN_NOT_OK(VisitTypeInline(*type.storage_type(), this));
+ return Status::OK();
+ }
+
+ Status Visit(const DataType& type) {
+ return Status::NotImplemented("construction of all-null ", type);
+ }
+
+ Result<std::shared_ptr<ArrayData>> CreateChild(int i, int64_t length) {
+ NullArrayFactory child_factory(pool_, type_->field(i)->type(), length);
+ child_factory.buffer_ = buffer_;
+ return child_factory.Create();
+ }
+
+ MemoryPool* pool_;
+ std::shared_ptr<DataType> type_;
+ int64_t length_;
+ std::shared_ptr<ArrayData> out_;
+ std::shared_ptr<Buffer> buffer_;
+};
+
+class RepeatedArrayFactory {
+ public:
+ RepeatedArrayFactory(MemoryPool* pool, const Scalar& scalar, int64_t length)
+ : pool_(pool), scalar_(scalar), length_(length) {}
+
+ Result<std::shared_ptr<Array>> Create() {
+ RETURN_NOT_OK(VisitTypeInline(*scalar_.type, this));
+ return out_;
+ }
+
+ Status Visit(const NullType& type) {
+ DCHECK(false); // already forwarded to MakeArrayOfNull
+ return Status::OK();
+ }
+
+ Status Visit(const BooleanType&) {
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBitmap(length_, pool_));
+ BitUtil::SetBitsTo(buffer->mutable_data(), 0, length_,
+ checked_cast<const BooleanScalar&>(scalar_).value);
+ out_ = std::make_shared<BooleanArray>(length_, buffer);
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_t<is_number_type<T>::value || is_temporal_type<T>::value, Status> Visit(
+ const T&) {
+ auto value = checked_cast<const typename TypeTraits<T>::ScalarType&>(scalar_).value;
+ return FinishFixedWidth(&value, sizeof(value));
+ }
+
+ Status Visit(const FixedSizeBinaryType& type) {
+ auto value = checked_cast<const FixedSizeBinaryScalar&>(scalar_).value;
+ return FinishFixedWidth(value->data(), type.byte_width());
+ }
+
+ template <typename T>
+ enable_if_decimal<T, Status> Visit(const T&) {
+ using ScalarType = typename TypeTraits<T>::ScalarType;
+ auto value = checked_cast<const ScalarType&>(scalar_).value.ToBytes();
+ return FinishFixedWidth(value.data(), value.size());
+ }
+
+ Status Visit(const Decimal256Type&) {
+ auto value = checked_cast<const Decimal256Scalar&>(scalar_).value.ToBytes();
+ return FinishFixedWidth(value.data(), value.size());
+ }
+
+ template <typename T>
+ enable_if_base_binary<T, Status> Visit(const T&) {
+ std::shared_ptr<Buffer> value =
+ checked_cast<const typename TypeTraits<T>::ScalarType&>(scalar_).value;
+ std::shared_ptr<Buffer> values_buffer, offsets_buffer;
+ RETURN_NOT_OK(CreateBufferOf(value->data(), value->size(), &values_buffer));
+ auto size = static_cast<typename T::offset_type>(value->size());
+ RETURN_NOT_OK(CreateOffsetsBuffer(size, &offsets_buffer));
+ out_ = std::make_shared<typename TypeTraits<T>::ArrayType>(length_, offsets_buffer,
+ values_buffer);
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_var_size_list<T, Status> Visit(const T& type) {
+ using ScalarType = typename TypeTraits<T>::ScalarType;
+ using ArrayType = typename TypeTraits<T>::ArrayType;
+
+ auto value = checked_cast<const ScalarType&>(scalar_).value;
+
+ ArrayVector values(length_, value);
+ ARROW_ASSIGN_OR_RAISE(auto value_array, Concatenate(values, pool_));
+
+ std::shared_ptr<Buffer> offsets_buffer;
+ auto size = static_cast<typename T::offset_type>(value->length());
+ RETURN_NOT_OK(CreateOffsetsBuffer(size, &offsets_buffer));
+
+ out_ =
+ std::make_shared<ArrayType>(scalar_.type, length_, offsets_buffer, value_array);
+ return Status::OK();
+ }
+
+ Status Visit(const FixedSizeListType& type) {
+ auto value = checked_cast<const FixedSizeListScalar&>(scalar_).value;
+
+ ArrayVector values(length_, value);
+ ARROW_ASSIGN_OR_RAISE(auto value_array, Concatenate(values, pool_));
+
+ out_ = std::make_shared<FixedSizeListArray>(scalar_.type, length_, value_array);
+ return Status::OK();
+ }
+
+ Status Visit(const MapType& type) {
+ auto map_scalar = checked_cast<const MapScalar&>(scalar_);
+ auto struct_array = checked_cast<const StructArray*>(map_scalar.value.get());
+
+ ArrayVector keys(length_, struct_array->field(0));
+ ArrayVector values(length_, struct_array->field(1));
+
+ ARROW_ASSIGN_OR_RAISE(auto key_array, Concatenate(keys, pool_));
+ ARROW_ASSIGN_OR_RAISE(auto value_array, Concatenate(values, pool_));
+
+ std::shared_ptr<Buffer> offsets_buffer;
+ auto size = static_cast<typename MapType::offset_type>(struct_array->length());
+ RETURN_NOT_OK(CreateOffsetsBuffer(size, &offsets_buffer));
+
+ out_ = std::make_shared<MapArray>(scalar_.type, length_, std::move(offsets_buffer),
+ std::move(key_array), std::move(value_array));
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryType& type) {
+ const auto& value = checked_cast<const DictionaryScalar&>(scalar_).value;
+ ARROW_ASSIGN_OR_RAISE(auto indices,
+ MakeArrayFromScalar(*value.index, length_, pool_));
+ out_ = std::make_shared<DictionaryArray>(scalar_.type, std::move(indices),
+ value.dictionary);
+ return Status::OK();
+ }
+
+ Status Visit(const StructType& type) {
+ ArrayVector fields;
+ for (const auto& value : checked_cast<const StructScalar&>(scalar_).value) {
+ fields.emplace_back();
+ ARROW_ASSIGN_OR_RAISE(fields.back(), MakeArrayFromScalar(*value, length_, pool_));
+ }
+ out_ = std::make_shared<StructArray>(scalar_.type, length_, std::move(fields));
+ return Status::OK();
+ }
+
+ Status Visit(const SparseUnionType& type) {
+ const auto& union_scalar = checked_cast<const UnionScalar&>(scalar_);
+ const auto& union_type = checked_cast<const UnionType&>(*scalar_.type);
+ const auto scalar_type_code = union_scalar.type_code;
+ const auto scalar_child_id = union_type.child_ids()[scalar_type_code];
+
+ // Create child arrays: most of them are all-null, except for the child array
+ // for the given type code (if the scalar is valid).
+ ArrayVector fields;
+ for (int i = 0; i < type.num_fields(); ++i) {
+ fields.emplace_back();
+ if (i == scalar_child_id && scalar_.is_valid) {
+ ARROW_ASSIGN_OR_RAISE(fields.back(),
+ MakeArrayFromScalar(*union_scalar.value, length_, pool_));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ fields.back(), MakeArrayOfNull(union_type.field(i)->type(), length_, pool_));
+ }
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto type_codes_buffer, CreateUnionTypeCodes(scalar_type_code));
+
+ out_ = std::make_shared<SparseUnionArray>(scalar_.type, length_, std::move(fields),
+ std::move(type_codes_buffer));
+ return Status::OK();
+ }
+
+ Status Visit(const DenseUnionType& type) {
+ const auto& union_scalar = checked_cast<const UnionScalar&>(scalar_);
+ const auto& union_type = checked_cast<const UnionType&>(*scalar_.type);
+ const auto scalar_type_code = union_scalar.type_code;
+ const auto scalar_child_id = union_type.child_ids()[scalar_type_code];
+
+ // Create child arrays: all of them are empty, except for the child array
+ // for the given type code (if length > 0).
+ ArrayVector fields;
+ for (int i = 0; i < type.num_fields(); ++i) {
+ fields.emplace_back();
+ if (i == scalar_child_id && length_ > 0) {
+ if (scalar_.is_valid) {
+ // One valid element (will be referenced by multiple offsets)
+ ARROW_ASSIGN_OR_RAISE(fields.back(),
+ MakeArrayFromScalar(*union_scalar.value, 1, pool_));
+ } else {
+ // One null element (will be referenced by multiple offsets)
+ ARROW_ASSIGN_OR_RAISE(fields.back(),
+ MakeArrayOfNull(union_type.field(i)->type(), 1, pool_));
+ }
+ } else {
+ // Zero element (will not be referenced by any offset)
+ ARROW_ASSIGN_OR_RAISE(fields.back(),
+ MakeArrayOfNull(union_type.field(i)->type(), 0, pool_));
+ }
+ }
+
+ // Create an offsets buffer with all offsets equal to 0
+ ARROW_ASSIGN_OR_RAISE(auto offsets_buffer,
+ AllocateBuffer(length_ * sizeof(int32_t), pool_));
+ memset(offsets_buffer->mutable_data(), 0, offsets_buffer->size());
+
+ ARROW_ASSIGN_OR_RAISE(auto type_codes_buffer, CreateUnionTypeCodes(scalar_type_code));
+
+ out_ = std::make_shared<DenseUnionArray>(scalar_.type, length_, std::move(fields),
+ std::move(type_codes_buffer),
+ std::move(offsets_buffer));
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionType& type) {
+ return Status::NotImplemented("construction from scalar of type ", *scalar_.type);
+ }
+
+ Result<std::shared_ptr<Buffer>> CreateUnionTypeCodes(int8_t type_code) {
+ TypedBufferBuilder<int8_t> builder(pool_);
+ RETURN_NOT_OK(builder.Resize(length_));
+ builder.UnsafeAppend(length_, type_code);
+ return builder.Finish();
+ }
+
+ template <typename OffsetType>
+ Status CreateOffsetsBuffer(OffsetType value_length, std::shared_ptr<Buffer>* out) {
+ TypedBufferBuilder<OffsetType> builder(pool_);
+ RETURN_NOT_OK(builder.Resize(length_ + 1));
+ OffsetType offset = 0;
+ for (int64_t i = 0; i < length_ + 1; ++i, offset += value_length) {
+ builder.UnsafeAppend(offset);
+ }
+ return builder.Finish(out);
+ }
+
+ Status CreateBufferOf(const void* data, size_t data_length,
+ std::shared_ptr<Buffer>* out) {
+ BufferBuilder builder(pool_);
+ RETURN_NOT_OK(builder.Resize(length_ * data_length));
+ for (int64_t i = 0; i < length_; ++i) {
+ builder.UnsafeAppend(data, data_length);
+ }
+ return builder.Finish(out);
+ }
+
+ Status FinishFixedWidth(const void* data, size_t data_length) {
+ std::shared_ptr<Buffer> buffer;
+ RETURN_NOT_OK(CreateBufferOf(data, data_length, &buffer));
+ out_ = MakeArray(
+ ArrayData::Make(scalar_.type, length_, {nullptr, std::move(buffer)}, 0));
+ return Status::OK();
+ }
+
+ MemoryPool* pool_;
+ const Scalar& scalar_;
+ int64_t length_;
+ std::shared_ptr<Array> out_;
+};
+
+} // namespace
+
+Result<std::shared_ptr<Array>> MakeArrayOfNull(const std::shared_ptr<DataType>& type,
+ int64_t length, MemoryPool* pool) {
+ ARROW_ASSIGN_OR_RAISE(auto data, NullArrayFactory(pool, type, length).Create());
+ return MakeArray(data);
+}
+
+Result<std::shared_ptr<Array>> MakeArrayFromScalar(const Scalar& scalar, int64_t length,
+ MemoryPool* pool) {
+ // Null union scalars still have a type code associated
+ if (!scalar.is_valid && !is_union(scalar.type->id())) {
+ return MakeArrayOfNull(scalar.type, length, pool);
+ }
+ return RepeatedArrayFactory(pool, scalar, length).Create();
+}
+
+namespace internal {
+
+std::vector<ArrayVector> RechunkArraysConsistently(
+ const std::vector<ArrayVector>& groups) {
+ if (groups.size() <= 1) {
+ return groups;
+ }
+ int64_t total_length = 0;
+ for (const auto& array : groups.front()) {
+ total_length += array->length();
+ }
+#ifndef NDEBUG
+ for (const auto& group : groups) {
+ int64_t group_length = 0;
+ for (const auto& array : group) {
+ group_length += array->length();
+ }
+ DCHECK_EQ(group_length, total_length)
+ << "Array groups should have the same total number of elements";
+ }
+#endif
+ if (total_length == 0) {
+ return groups;
+ }
+
+ // Set up result vectors
+ std::vector<ArrayVector> rechunked_groups(groups.size());
+
+ // Set up progress counters
+ std::vector<ArrayVector::const_iterator> current_arrays;
+ std::vector<int64_t> array_offsets;
+ for (const auto& group : groups) {
+ current_arrays.emplace_back(group.cbegin());
+ array_offsets.emplace_back(0);
+ }
+
+ // Scan all array vectors at once, rechunking along the way
+ int64_t start = 0;
+ while (start < total_length) {
+ // First compute max possible length for next chunk
+ int64_t chunk_length = std::numeric_limits<int64_t>::max();
+ for (size_t i = 0; i < groups.size(); i++) {
+ auto& arr_it = current_arrays[i];
+ auto& offset = array_offsets[i];
+ // Skip any done arrays (including 0-length arrays)
+ while (offset == (*arr_it)->length()) {
+ ++arr_it;
+ offset = 0;
+ }
+ const auto& array = *arr_it;
+ DCHECK_GT(array->length(), offset);
+ chunk_length = std::min(chunk_length, array->length() - offset);
+ }
+ DCHECK_GT(chunk_length, 0);
+
+ // Then slice all arrays along this chunk size
+ for (size_t i = 0; i < groups.size(); i++) {
+ const auto& array = *current_arrays[i];
+ auto& offset = array_offsets[i];
+ if (offset == 0 && array->length() == chunk_length) {
+ // Slice spans entire array
+ rechunked_groups[i].emplace_back(array);
+ } else {
+ DCHECK_LT(chunk_length - offset, array->length());
+ rechunked_groups[i].emplace_back(array->Slice(offset, chunk_length));
+ }
+ offset += chunk_length;
+ }
+ start += chunk_length;
+ }
+
+ return rechunked_groups;
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/util.h b/src/arrow/cpp/src/arrow/array/util.h
new file mode 100644
index 000000000..3ef4e0882
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/util.h
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "arrow/array/data.h"
+#include "arrow/compare.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+/// \brief Create a strongly-typed Array instance from generic ArrayData
+/// \param[in] data the array contents
+/// \return the resulting Array instance
+ARROW_EXPORT
+std::shared_ptr<Array> MakeArray(const std::shared_ptr<ArrayData>& data);
+
+/// \brief Create a strongly-typed Array instance with all elements null
+/// \param[in] type the array type
+/// \param[in] length the array length
+/// \param[in] pool the memory pool to allocate memory from
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> MakeArrayOfNull(const std::shared_ptr<DataType>& type,
+ int64_t length,
+ MemoryPool* pool = default_memory_pool());
+
+/// \brief Create an Array instance whose slots are the given scalar
+/// \param[in] scalar the value with which to fill the array
+/// \param[in] length the array length
+/// \param[in] pool the memory pool to allocate memory from
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> MakeArrayFromScalar(
+ const Scalar& scalar, int64_t length, MemoryPool* pool = default_memory_pool());
+
+namespace internal {
+
+/// \brief Swap endian of each element in a generic ArrayData
+///
+/// As dictionaries are often shared between different arrays, dictionaries
+/// are not swapped by this function and should be handled separately.
+///
+/// \param[in] data the array contents
+/// \return the resulting ArrayData whose elements were swapped
+ARROW_EXPORT
+Result<std::shared_ptr<ArrayData>> SwapEndianArrayData(
+ const std::shared_ptr<ArrayData>& data);
+
+/// Given a number of ArrayVectors, treat each ArrayVector as the
+/// chunks of a chunked array. Then rechunk each ArrayVector such that
+/// all ArrayVectors are chunked identically. It is mandatory that
+/// all ArrayVectors contain the same total number of elements.
+ARROW_EXPORT
+std::vector<ArrayVector> RechunkArraysConsistently(const std::vector<ArrayVector>&);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/validate.cc b/src/arrow/cpp/src/arrow/array/validate.cc
new file mode 100644
index 000000000..c66c4f53b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/validate.cc
@@ -0,0 +1,679 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/validate.h"
+
+#include <vector>
+
+#include "arrow/array.h" // IWYU pragma: keep
+#include "arrow/extension_type.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/utf8.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+namespace internal {
+
+///////////////////////////////////////////////////////////////////////////
+// ValidateArray: cheap validation checks
+
+namespace {
+
+struct ValidateArrayImpl {
+ const ArrayData& data;
+
+ Status Validate() { return ValidateWithType(*data.type); }
+
+ Status ValidateWithType(const DataType& type) { return VisitTypeInline(type, this); }
+
+ Status Visit(const NullType&) {
+ if (data.null_count != data.length) {
+ return Status::Invalid("Null array null_count unequal to its length");
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const FixedWidthType&) {
+ if (data.length > 0) {
+ if (!IsBufferValid(1)) {
+ return Status::Invalid("Missing values buffer in non-empty array");
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const StringType& type) { return ValidateBinaryLike(type); }
+
+ Status Visit(const BinaryType& type) { return ValidateBinaryLike(type); }
+
+ Status Visit(const LargeStringType& type) { return ValidateBinaryLike(type); }
+
+ Status Visit(const LargeBinaryType& type) { return ValidateBinaryLike(type); }
+
+ Status Visit(const ListType& type) { return ValidateListLike(type); }
+
+ Status Visit(const LargeListType& type) { return ValidateListLike(type); }
+
+ Status Visit(const MapType& type) { return ValidateListLike(type); }
+
+ Status Visit(const FixedSizeListType& type) {
+ const ArrayData& values = *data.child_data[0];
+ const int64_t list_size = type.list_size();
+ if (list_size < 0) {
+ return Status::Invalid("Fixed size list has negative list size");
+ }
+
+ int64_t expected_values_length = -1;
+ if (MultiplyWithOverflow(data.length, list_size, &expected_values_length) ||
+ values.length < expected_values_length) {
+ return Status::Invalid("Values length (", values.length,
+ ") is less than the length (", data.length,
+ ") multiplied by the value size (", list_size, ")");
+ }
+
+ const Status child_valid = ValidateArray(values);
+ if (!child_valid.ok()) {
+ return Status::Invalid("Fixed size list child array invalid: ",
+ child_valid.ToString());
+ }
+
+ return Status::OK();
+ }
+
+ Status Visit(const StructType& type) {
+ for (int i = 0; i < type.num_fields(); ++i) {
+ const auto& field_data = *data.child_data[i];
+
+ // Validate child first, to catch nonsensical length / offset etc.
+ const Status field_valid = ValidateArray(field_data);
+ if (!field_valid.ok()) {
+ return Status::Invalid("Struct child array #", i,
+ " invalid: ", field_valid.ToString());
+ }
+
+ if (field_data.length < data.length + data.offset) {
+ return Status::Invalid("Struct child array #", i,
+ " has length smaller than expected for struct array (",
+ field_data.length, " < ", data.length + data.offset, ")");
+ }
+
+ const auto& field_type = type.field(i)->type();
+ if (!field_data.type->Equals(*field_type)) {
+ return Status::Invalid("Struct child array #", i, " does not match type field: ",
+ field_data.type->ToString(), " vs ",
+ field_type->ToString());
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const UnionType& type) {
+ for (int i = 0; i < type.num_fields(); ++i) {
+ const auto& field_data = *data.child_data[i];
+
+ // Validate child first, to catch nonsensical length / offset etc.
+ const Status field_valid = ValidateArray(field_data);
+ if (!field_valid.ok()) {
+ return Status::Invalid("Union child array #", i,
+ " invalid: ", field_valid.ToString());
+ }
+
+ if (type.mode() == UnionMode::SPARSE &&
+ field_data.length < data.length + data.offset) {
+ return Status::Invalid("Sparse union child array #", i,
+ " has length smaller than expected for union array (",
+ field_data.length, " < ", data.length + data.offset, ")");
+ }
+
+ const auto& field_type = type.field(i)->type();
+ if (!field_data.type->Equals(*field_type)) {
+ return Status::Invalid("Union child array #", i, " does not match type field: ",
+ field_data.type->ToString(), " vs ",
+ field_type->ToString());
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryType& type) {
+ Type::type index_type_id = type.index_type()->id();
+ if (!is_integer(index_type_id)) {
+ return Status::Invalid("Dictionary indices must be integer type");
+ }
+ if (!data.dictionary) {
+ return Status::Invalid("Dictionary values must be non-null");
+ }
+ const Status dict_valid = ValidateArray(*data.dictionary);
+ if (!dict_valid.ok()) {
+ return Status::Invalid("Dictionary array invalid: ", dict_valid.ToString());
+ }
+ // Visit indices
+ return ValidateWithType(*type.index_type());
+ }
+
+ Status Visit(const ExtensionType& type) {
+ // Visit storage
+ return ValidateWithType(*type.storage_type());
+ }
+
+ private:
+ bool IsBufferValid(int index) { return IsBufferValid(data, index); }
+
+ static bool IsBufferValid(const ArrayData& data, int index) {
+ return data.buffers[index] != nullptr && data.buffers[index]->address() != 0;
+ }
+
+ template <typename BinaryType>
+ Status ValidateBinaryLike(const BinaryType& type) {
+ if (!IsBufferValid(2)) {
+ return Status::Invalid("Value data buffer is null");
+ }
+ // First validate offsets, to make sure the accesses below are valid
+ RETURN_NOT_OK(ValidateOffsets(type));
+
+ if (data.length > 0 && data.buffers[1]->is_cpu()) {
+ using offset_type = typename BinaryType::offset_type;
+
+ const auto offsets = data.GetValues<offset_type>(1);
+ const Buffer& values = *data.buffers[2];
+
+ const auto first_offset = offsets[0];
+ const auto last_offset = offsets[data.length];
+ // This early test avoids undefined behaviour when computing `data_extent`
+ if (first_offset < 0 || last_offset < 0) {
+ return Status::Invalid("Negative offsets in binary array");
+ }
+ const auto data_extent = last_offset - first_offset;
+ const auto values_length = values.size();
+ if (values_length < data_extent) {
+ return Status::Invalid("Length spanned by binary offsets (", data_extent,
+ ") larger than values array (size ", values_length, ")");
+ }
+ // These tests ensure that array concatenation is safe if Validate() succeeds
+ // (for delta dictionaries)
+ if (first_offset > values_length || last_offset > values_length) {
+ return Status::Invalid("First or last binary offset out of bounds");
+ }
+ if (first_offset > last_offset) {
+ return Status::Invalid("First offset larger than last offset in binary array");
+ }
+ }
+ return Status::OK();
+ }
+
+ template <typename ListType>
+ Status ValidateListLike(const ListType& type) {
+ // First validate offsets, to make sure the accesses below are valid
+ RETURN_NOT_OK(ValidateOffsets(type));
+
+ const ArrayData& values = *data.child_data[0];
+
+ // An empty list array can have 0 offsets
+ if (data.length > 0 && data.buffers[1]->is_cpu()) {
+ using offset_type = typename ListType::offset_type;
+
+ const auto offsets = data.GetValues<offset_type>(1);
+
+ const auto first_offset = offsets[0];
+ const auto last_offset = offsets[data.length];
+ // This early test avoids undefined behaviour when computing `data_extent`
+ if (first_offset < 0 || last_offset < 0) {
+ return Status::Invalid("Negative offsets in list array");
+ }
+ const auto data_extent = last_offset - first_offset;
+ const auto values_length = values.length;
+ if (values_length < data_extent) {
+ return Status::Invalid("Length spanned by list offsets (", data_extent,
+ ") larger than values array (length ", values_length, ")");
+ }
+ // These tests ensure that array concatenation is safe if Validate() succeeds
+ // (for delta dictionaries)
+ if (first_offset > values_length || last_offset > values_length) {
+ return Status::Invalid("First or last list offset out of bounds");
+ }
+ if (first_offset > last_offset) {
+ return Status::Invalid("First offset larger than last offset in list array");
+ }
+ }
+
+ const Status child_valid = ValidateArray(values);
+ if (!child_valid.ok()) {
+ return Status::Invalid("List child array invalid: ", child_valid.ToString());
+ }
+ return Status::OK();
+ }
+
+ template <typename TypeClass>
+ Status ValidateOffsets(const TypeClass& type) {
+ using offset_type = typename TypeClass::offset_type;
+
+ const Buffer* offsets = data.buffers[1].get();
+ if (offsets == nullptr) {
+ // For length 0, an empty offsets buffer seems accepted as a special case
+ // (ARROW-544)
+ if (data.length > 0) {
+ return Status::Invalid("Non-empty array but offsets are null");
+ }
+ return Status::OK();
+ }
+
+ // An empty list array can have 0 offsets
+ auto required_offsets = (data.length > 0) ? data.length + data.offset + 1 : 0;
+ if (offsets->size() / static_cast<int32_t>(sizeof(offset_type)) < required_offsets) {
+ return Status::Invalid("Offsets buffer size (bytes): ", offsets->size(),
+ " isn't large enough for length: ", data.length);
+ }
+
+ return Status::OK();
+ }
+};
+
+} // namespace
+
+ARROW_EXPORT
+Status ValidateArray(const ArrayData& data) {
+ if (data.type == nullptr) {
+ return Status::Invalid("Array type is absent");
+ }
+
+ // First check the data layout conforms to the spec
+ const DataType& type = *data.type;
+ const auto layout = type.layout();
+
+ if (data.length < 0) {
+ return Status::Invalid("Array length is negative");
+ }
+
+ if (data.buffers.size() != layout.buffers.size()) {
+ return Status::Invalid("Expected ", layout.buffers.size(),
+ " buffers in array "
+ "of type ",
+ type.ToString(), ", got ", data.buffers.size());
+ }
+
+ // This check is required to avoid addition overflow below
+ int64_t length_plus_offset = -1;
+ if (AddWithOverflow(data.length, data.offset, &length_plus_offset)) {
+ return Status::Invalid("Array of type ", type.ToString(),
+ " has impossibly large length and offset");
+ }
+
+ for (int i = 0; i < static_cast<int>(data.buffers.size()); ++i) {
+ const auto& buffer = data.buffers[i];
+ const auto& spec = layout.buffers[i];
+
+ if (buffer == nullptr) {
+ continue;
+ }
+ int64_t min_buffer_size = -1;
+ switch (spec.kind) {
+ case DataTypeLayout::BITMAP:
+ min_buffer_size = BitUtil::BytesForBits(length_plus_offset);
+ break;
+ case DataTypeLayout::FIXED_WIDTH:
+ if (MultiplyWithOverflow(length_plus_offset, spec.byte_width, &min_buffer_size)) {
+ return Status::Invalid("Array of type ", type.ToString(),
+ " has impossibly large length and offset");
+ }
+ break;
+ case DataTypeLayout::ALWAYS_NULL:
+ // XXX Should we raise on non-null buffer?
+ continue;
+ default:
+ continue;
+ }
+ if (buffer->size() < min_buffer_size) {
+ return Status::Invalid("Buffer #", i, " too small in array of type ",
+ type.ToString(), " and length ", data.length,
+ ": expected at least ", min_buffer_size, " byte(s), got ",
+ buffer->size());
+ }
+ }
+ if (type.id() != Type::NA && data.null_count > 0 && data.buffers[0] == nullptr) {
+ return Status::Invalid("Array of type ", type.ToString(), " has ", data.null_count,
+ " nulls but no null bitmap");
+ }
+
+ // Check null_count() *after* validating the buffer sizes, to avoid
+ // reading out of bounds.
+ if (data.null_count > data.length) {
+ return Status::Invalid("Null count exceeds array length");
+ }
+ if (data.null_count < 0 && data.null_count != kUnknownNullCount) {
+ return Status::Invalid("Negative null count");
+ }
+
+ if (type.id() != Type::EXTENSION) {
+ if (data.child_data.size() != static_cast<size_t>(type.num_fields())) {
+ return Status::Invalid("Expected ", type.num_fields(),
+ " child arrays in array "
+ "of type ",
+ type.ToString(), ", got ", data.child_data.size());
+ }
+ }
+ if (layout.has_dictionary && !data.dictionary) {
+ return Status::Invalid("Array of type ", type.ToString(),
+ " must have dictionary values");
+ }
+ if (!layout.has_dictionary && data.dictionary) {
+ return Status::Invalid("Unexpected dictionary values in array of type ",
+ type.ToString());
+ }
+
+ ValidateArrayImpl validator{data};
+ return validator.Validate();
+}
+
+ARROW_EXPORT
+Status ValidateArray(const Array& array) { return ValidateArray(*array.data()); }
+
+///////////////////////////////////////////////////////////////////////////
+// ValidateArrayFull: expensive validation checks
+
+namespace {
+
+struct UTF8DataValidator {
+ const ArrayData& data;
+
+ Status Visit(const DataType&) {
+ // Default, should be unreachable
+ return Status::NotImplemented("");
+ }
+
+ template <typename StringType>
+ enable_if_string<StringType, Status> Visit(const StringType&) {
+ util::InitializeUTF8();
+
+ int64_t i = 0;
+ return VisitArrayDataInline<StringType>(
+ data,
+ [&](util::string_view v) {
+ if (ARROW_PREDICT_FALSE(!util::ValidateUTF8(v))) {
+ return Status::Invalid("Invalid UTF8 sequence at string index ", i);
+ }
+ ++i;
+ return Status::OK();
+ },
+ [&]() {
+ ++i;
+ return Status::OK();
+ });
+ }
+};
+
+struct BoundsChecker {
+ const ArrayData& data;
+ int64_t min_value;
+ int64_t max_value;
+
+ Status Visit(const DataType&) {
+ // Default, should be unreachable
+ return Status::NotImplemented("");
+ }
+
+ template <typename IntegerType>
+ enable_if_integer<IntegerType, Status> Visit(const IntegerType&) {
+ using c_type = typename IntegerType::c_type;
+
+ int64_t i = 0;
+ return VisitArrayDataInline<IntegerType>(
+ data,
+ [&](c_type value) {
+ const auto v = static_cast<int64_t>(value);
+ if (ARROW_PREDICT_FALSE(v < min_value || v > max_value)) {
+ return Status::Invalid("Value at position ", i, " out of bounds: ", v,
+ " (should be in [", min_value, ", ", max_value, "])");
+ }
+ ++i;
+ return Status::OK();
+ },
+ [&]() {
+ ++i;
+ return Status::OK();
+ });
+ }
+};
+
+struct ValidateArrayFullImpl {
+ const ArrayData& data;
+
+ Status Validate() { return ValidateWithType(*data.type); }
+
+ Status ValidateWithType(const DataType& type) { return VisitTypeInline(type, this); }
+
+ Status Visit(const NullType& type) { return Status::OK(); }
+
+ Status Visit(const FixedWidthType& type) { return Status::OK(); }
+
+ Status Visit(const StringType& type) {
+ RETURN_NOT_OK(ValidateBinaryLike(type));
+ return ValidateUTF8(data);
+ }
+
+ Status Visit(const LargeStringType& type) {
+ RETURN_NOT_OK(ValidateBinaryLike(type));
+ return ValidateUTF8(data);
+ }
+
+ Status Visit(const BinaryType& type) { return ValidateBinaryLike(type); }
+
+ Status Visit(const LargeBinaryType& type) { return ValidateBinaryLike(type); }
+
+ Status Visit(const ListType& type) { return ValidateListLike(type); }
+
+ Status Visit(const LargeListType& type) { return ValidateListLike(type); }
+
+ Status Visit(const MapType& type) { return ValidateListLike(type); }
+
+ Status Visit(const FixedSizeListType& type) {
+ const ArrayData& child = *data.child_data[0];
+ const Status child_valid = ValidateArrayFull(child);
+ if (!child_valid.ok()) {
+ return Status::Invalid("Fixed size list child array invalid: ",
+ child_valid.ToString());
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const StructType& type) {
+ // Validate children
+ for (int64_t i = 0; i < type.num_fields(); ++i) {
+ const ArrayData& field = *data.child_data[i];
+ const Status field_valid = ValidateArrayFull(field);
+ if (!field_valid.ok()) {
+ return Status::Invalid("Struct child array #", i,
+ " invalid: ", field_valid.ToString());
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const UnionType& type) {
+ const auto& child_ids = type.child_ids();
+ const auto& type_codes_map = type.type_codes();
+
+ const int8_t* type_codes = data.GetValues<int8_t>(1);
+
+ for (int64_t i = 0; i < data.length; ++i) {
+ // Note that union arrays never have top-level nulls
+ const int32_t code = type_codes[i];
+ if (code < 0 || child_ids[code] == UnionType::kInvalidChildId) {
+ return Status::Invalid("Union value at position ", i, " has invalid type id ",
+ code);
+ }
+ }
+
+ if (type.mode() == UnionMode::DENSE) {
+ // Map logical type id to child length
+ std::vector<int64_t> child_lengths(256);
+ for (int child_id = 0; child_id < type.num_fields(); ++child_id) {
+ child_lengths[type_codes_map[child_id]] = data.child_data[child_id]->length;
+ }
+
+ // Check offsets are in bounds
+ std::vector<int64_t> last_child_offsets(256, 0);
+ const int32_t* offsets = data.GetValues<int32_t>(2);
+ for (int64_t i = 0; i < data.length; ++i) {
+ const int32_t code = type_codes[i];
+ const int32_t offset = offsets[i];
+ if (offset < 0) {
+ return Status::Invalid("Union value at position ", i, " has negative offset ",
+ offset);
+ }
+ if (offset >= child_lengths[code]) {
+ return Status::Invalid("Union value at position ", i,
+ " has offset larger "
+ "than child length (",
+ offset, " >= ", child_lengths[code], ")");
+ }
+ if (offset < last_child_offsets[code]) {
+ return Status::Invalid("Union value at position ", i,
+ " has non-monotonic offset ", offset);
+ }
+ last_child_offsets[code] = offset;
+ }
+ }
+
+ // Validate children
+ for (int64_t i = 0; i < type.num_fields(); ++i) {
+ const ArrayData& field = *data.child_data[i];
+ const Status field_valid = ValidateArrayFull(field);
+ if (!field_valid.ok()) {
+ return Status::Invalid("Union child array #", i,
+ " invalid: ", field_valid.ToString());
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryType& type) {
+ const Status indices_status =
+ CheckBounds(*type.index_type(), 0, data.dictionary->length - 1);
+ if (!indices_status.ok()) {
+ return Status::Invalid("Dictionary indices invalid: ", indices_status.ToString());
+ }
+ return ValidateArrayFull(*data.dictionary);
+ }
+
+ Status Visit(const ExtensionType& type) {
+ return ValidateWithType(*type.storage_type());
+ }
+
+ protected:
+ template <typename BinaryType>
+ Status ValidateBinaryLike(const BinaryType& type) {
+ const auto& data_buffer = data.buffers[2];
+ if (data_buffer == nullptr) {
+ return Status::Invalid("Binary data buffer is null");
+ }
+ return ValidateOffsets(type, data_buffer->size());
+ }
+
+ template <typename ListType>
+ Status ValidateListLike(const ListType& type) {
+ const ArrayData& child = *data.child_data[0];
+ const Status child_valid = ValidateArrayFull(child);
+ if (!child_valid.ok()) {
+ return Status::Invalid("List child array invalid: ", child_valid.ToString());
+ }
+ return ValidateOffsets(type, child.offset + child.length);
+ }
+
+ template <typename TypeClass>
+ Status ValidateOffsets(const TypeClass& type, int64_t offset_limit) {
+ using offset_type = typename TypeClass::offset_type;
+ if (data.length == 0) {
+ return Status::OK();
+ }
+
+ const offset_type* offsets = data.GetValues<offset_type>(1);
+ if (offsets == nullptr) {
+ return Status::Invalid("Non-empty array but offsets are null");
+ }
+
+ auto prev_offset = offsets[0];
+ if (prev_offset < 0) {
+ return Status::Invalid("Offset invariant failure: array starts at negative offset ",
+ prev_offset);
+ }
+ for (int64_t i = 1; i <= data.length; ++i) {
+ const auto current_offset = offsets[i];
+ if (current_offset < prev_offset) {
+ return Status::Invalid("Offset invariant failure: non-monotonic offset at slot ",
+ i, ": ", current_offset, " < ", prev_offset);
+ }
+ if (current_offset > offset_limit) {
+ return Status::Invalid("Offset invariant failure: offset for slot ", i,
+ " out of bounds: ", current_offset, " > ", offset_limit);
+ }
+ prev_offset = current_offset;
+ }
+ return Status::OK();
+ }
+
+ Status CheckBounds(const DataType& type, int64_t min_value, int64_t max_value) {
+ BoundsChecker checker{data, min_value, max_value};
+ return VisitTypeInline(type, &checker);
+ }
+};
+
+} // namespace
+
+ARROW_EXPORT
+Status ValidateArrayFull(const ArrayData& data) {
+ if (data.null_count != kUnknownNullCount) {
+ int64_t actual_null_count;
+ if (HasValidityBitmap(data.type->id()) && data.buffers[0]) {
+ // Do not call GetNullCount() as it would also set the `null_count` member
+ actual_null_count =
+ data.length - CountSetBits(data.buffers[0]->data(), data.offset, data.length);
+ } else if (data.type->id() == Type::NA) {
+ actual_null_count = data.length;
+ } else {
+ actual_null_count = 0;
+ }
+ if (actual_null_count != data.null_count) {
+ return Status::Invalid("null_count value (", data.null_count,
+ ") doesn't match actual number of nulls in array (",
+ actual_null_count, ")");
+ }
+ }
+ return ValidateArrayFullImpl{data}.Validate();
+}
+
+ARROW_EXPORT
+Status ValidateArrayFull(const Array& array) { return ValidateArrayFull(*array.data()); }
+
+ARROW_EXPORT
+Status ValidateUTF8(const ArrayData& data) {
+ DCHECK(data.type->id() == Type::STRING || data.type->id() == Type::LARGE_STRING);
+ UTF8DataValidator validator{data};
+ return VisitTypeInline(*data.type, &validator);
+}
+
+ARROW_EXPORT
+Status ValidateUTF8(const Array& array) { return ValidateUTF8(*array.data()); }
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/array/validate.h b/src/arrow/cpp/src/arrow/array/validate.h
new file mode 100644
index 000000000..cae3e16b3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/array/validate.h
@@ -0,0 +1,55 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace internal {
+
+// Internal functions implementing Array::Validate() and friends.
+
+// O(1) array metadata validation
+
+ARROW_EXPORT
+Status ValidateArray(const Array& array);
+
+ARROW_EXPORT
+Status ValidateArray(const ArrayData& data);
+
+// O(N) array data validation.
+// Note the "full" routines don't validate metadata. It should be done
+// beforehand using ValidateArray(), otherwise invalid memory accesses
+// may occur.
+
+ARROW_EXPORT
+Status ValidateArrayFull(const Array& array);
+
+ARROW_EXPORT
+Status ValidateArrayFull(const ArrayData& data);
+
+ARROW_EXPORT
+Status ValidateUTF8(const Array& array);
+
+ARROW_EXPORT
+Status ValidateUTF8(const ArrayData& data);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/arrow-config.cmake b/src/arrow/cpp/src/arrow/arrow-config.cmake
new file mode 100644
index 000000000..8c9173c17
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/arrow-config.cmake
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+message(WARNING "find_package(arrow) is deprecated. Use find_package(Arrow) instead.")
+find_package(Arrow CONFIG)
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(arrow
+ REQUIRED_VARS
+ ARROW_INCLUDE_DIR
+ VERSION_VAR
+ ARROW_VERSION)
diff --git a/src/arrow/cpp/src/arrow/arrow-testing.pc.in b/src/arrow/cpp/src/arrow/arrow-testing.pc.in
new file mode 100644
index 000000000..39c08fcf0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/arrow-testing.pc.in
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+gtest_includedir=@GTEST_INCLUDE_DIR@
+
+Name: Apache Arrow testing
+Description: Library for testing Apache Arrow related programs.
+Version: @ARROW_VERSION@
+Requires: arrow
+Libs: -L${libdir} -larrow_testing
+Cflags: -I${gtest_includedir}
diff --git a/src/arrow/cpp/src/arrow/arrow.pc.in b/src/arrow/cpp/src/arrow/arrow.pc.in
new file mode 100644
index 000000000..ef995fdc3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/arrow.pc.in
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+so_version=@ARROW_SO_VERSION@
+abi_version=@ARROW_SO_VERSION@
+full_so_version=@ARROW_FULL_SO_VERSION@
+
+Name: Apache Arrow
+Description: Arrow is a set of technologies that enable big-data systems to process and move data fast.
+Version: @ARROW_VERSION@
+Requires.private:@ARROW_PC_REQUIRES_PRIVATE@
+Libs: -L${libdir} -larrow
+Libs.private:@ARROW_PC_LIBS_PRIVATE@
+Cflags: -I${includedir}
diff --git a/src/arrow/cpp/src/arrow/buffer.cc b/src/arrow/cpp/src/arrow/buffer.cc
new file mode 100644
index 000000000..b1b2945d0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/buffer.cc
@@ -0,0 +1,207 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/buffer.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <utility>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string.h"
+
+namespace arrow {
+
+Result<std::shared_ptr<Buffer>> Buffer::CopySlice(const int64_t start,
+ const int64_t nbytes,
+ MemoryPool* pool) const {
+ // Sanity checks
+ ARROW_CHECK_LE(start, size_);
+ ARROW_CHECK_LE(nbytes, size_ - start);
+ DCHECK_GE(nbytes, 0);
+
+ ARROW_ASSIGN_OR_RAISE(auto new_buffer, AllocateResizableBuffer(nbytes, pool));
+ std::memcpy(new_buffer->mutable_data(), data() + start, static_cast<size_t>(nbytes));
+ return std::move(new_buffer);
+}
+
+namespace {
+
+Status CheckBufferSlice(const Buffer& buffer, int64_t offset, int64_t length) {
+ return internal::CheckSliceParams(buffer.size(), offset, length, "buffer");
+}
+
+Status CheckBufferSlice(const Buffer& buffer, int64_t offset) {
+ if (ARROW_PREDICT_FALSE(offset < 0)) {
+ // Avoid UBSAN in subtraction below
+ return Status::Invalid("Negative buffer slice offset");
+ }
+ return CheckBufferSlice(buffer, offset, buffer.size() - offset);
+}
+
+} // namespace
+
+Result<std::shared_ptr<Buffer>> SliceBufferSafe(const std::shared_ptr<Buffer>& buffer,
+ int64_t offset) {
+ RETURN_NOT_OK(CheckBufferSlice(*buffer, offset));
+ return SliceBuffer(buffer, offset);
+}
+
+Result<std::shared_ptr<Buffer>> SliceBufferSafe(const std::shared_ptr<Buffer>& buffer,
+ int64_t offset, int64_t length) {
+ RETURN_NOT_OK(CheckBufferSlice(*buffer, offset, length));
+ return SliceBuffer(buffer, offset, length);
+}
+
+Result<std::shared_ptr<Buffer>> SliceMutableBufferSafe(
+ const std::shared_ptr<Buffer>& buffer, int64_t offset) {
+ RETURN_NOT_OK(CheckBufferSlice(*buffer, offset));
+ return SliceMutableBuffer(buffer, offset);
+}
+
+Result<std::shared_ptr<Buffer>> SliceMutableBufferSafe(
+ const std::shared_ptr<Buffer>& buffer, int64_t offset, int64_t length) {
+ RETURN_NOT_OK(CheckBufferSlice(*buffer, offset, length));
+ return SliceMutableBuffer(buffer, offset, length);
+}
+
+std::string Buffer::ToHexString() {
+ return HexEncode(data(), static_cast<size_t>(size()));
+}
+
+bool Buffer::Equals(const Buffer& other, const int64_t nbytes) const {
+ return this == &other || (size_ >= nbytes && other.size_ >= nbytes &&
+ (data_ == other.data_ ||
+ !memcmp(data_, other.data_, static_cast<size_t>(nbytes))));
+}
+
+bool Buffer::Equals(const Buffer& other) const {
+ return this == &other || (size_ == other.size_ &&
+ (data_ == other.data_ ||
+ !memcmp(data_, other.data_, static_cast<size_t>(size_))));
+}
+
+std::string Buffer::ToString() const {
+ return std::string(reinterpret_cast<const char*>(data_), static_cast<size_t>(size_));
+}
+
+void Buffer::CheckMutable() const { DCHECK(is_mutable()) << "buffer not mutable"; }
+
+void Buffer::CheckCPU() const {
+ DCHECK(is_cpu()) << "not a CPU buffer (device: " << device()->ToString() << ")";
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> Buffer::GetReader(
+ std::shared_ptr<Buffer> buf) {
+ return buf->memory_manager_->GetBufferReader(buf);
+}
+
+Result<std::shared_ptr<io::OutputStream>> Buffer::GetWriter(std::shared_ptr<Buffer> buf) {
+ if (!buf->is_mutable()) {
+ return Status::Invalid("Expected mutable buffer");
+ }
+ return buf->memory_manager_->GetBufferWriter(buf);
+}
+
+Result<std::shared_ptr<Buffer>> Buffer::Copy(std::shared_ptr<Buffer> source,
+ const std::shared_ptr<MemoryManager>& to) {
+ return MemoryManager::CopyBuffer(source, to);
+}
+
+Result<std::shared_ptr<Buffer>> Buffer::View(std::shared_ptr<Buffer> source,
+ const std::shared_ptr<MemoryManager>& to) {
+ return MemoryManager::ViewBuffer(source, to);
+}
+
+Result<std::shared_ptr<Buffer>> Buffer::ViewOrCopy(
+ std::shared_ptr<Buffer> source, const std::shared_ptr<MemoryManager>& to) {
+ auto maybe_buffer = MemoryManager::ViewBuffer(source, to);
+ if (maybe_buffer.ok()) {
+ return maybe_buffer;
+ }
+ return MemoryManager::CopyBuffer(source, to);
+}
+
+class StlStringBuffer : public Buffer {
+ public:
+ explicit StlStringBuffer(std::string data)
+ : Buffer(nullptr, 0), input_(std::move(data)) {
+ data_ = reinterpret_cast<const uint8_t*>(input_.c_str());
+ size_ = static_cast<int64_t>(input_.size());
+ capacity_ = size_;
+ }
+
+ private:
+ std::string input_;
+};
+
+std::shared_ptr<Buffer> Buffer::FromString(std::string data) {
+ return std::make_shared<StlStringBuffer>(std::move(data));
+}
+
+std::shared_ptr<Buffer> SliceMutableBuffer(const std::shared_ptr<Buffer>& buffer,
+ const int64_t offset, const int64_t length) {
+ return std::make_shared<MutableBuffer>(buffer, offset, length);
+}
+
+MutableBuffer::MutableBuffer(const std::shared_ptr<Buffer>& parent, const int64_t offset,
+ const int64_t size)
+ : MutableBuffer(reinterpret_cast<uint8_t*>(parent->mutable_address()) + offset,
+ size) {
+ DCHECK(parent->is_mutable()) << "Must pass mutable buffer";
+ parent_ = parent;
+}
+
+Result<std::shared_ptr<Buffer>> AllocateBitmap(int64_t length, MemoryPool* pool) {
+ ARROW_ASSIGN_OR_RAISE(auto buf, AllocateBuffer(BitUtil::BytesForBits(length), pool));
+ // Zero out any trailing bits
+ if (buf->size() > 0) {
+ buf->mutable_data()[buf->size() - 1] = 0;
+ }
+ return std::move(buf);
+}
+
+Result<std::shared_ptr<Buffer>> AllocateEmptyBitmap(int64_t length, MemoryPool* pool) {
+ ARROW_ASSIGN_OR_RAISE(auto buf, AllocateBuffer(BitUtil::BytesForBits(length), pool));
+ memset(buf->mutable_data(), 0, static_cast<size_t>(buf->size()));
+ return std::move(buf);
+}
+
+Status AllocateEmptyBitmap(int64_t length, std::shared_ptr<Buffer>* out) {
+ return AllocateEmptyBitmap(length).Value(out);
+}
+
+Result<std::shared_ptr<Buffer>> ConcatenateBuffers(
+ const std::vector<std::shared_ptr<Buffer>>& buffers, MemoryPool* pool) {
+ int64_t out_length = 0;
+ for (const auto& buffer : buffers) {
+ out_length += buffer->size();
+ }
+ ARROW_ASSIGN_OR_RAISE(auto out, AllocateBuffer(out_length, pool));
+ auto out_data = out->mutable_data();
+ for (const auto& buffer : buffers) {
+ std::memcpy(out_data, buffer->data(), buffer->size());
+ out_data += buffer->size();
+ }
+ return std::move(out);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/buffer.h b/src/arrow/cpp/src/arrow/buffer.h
new file mode 100644
index 000000000..cfd525ab2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/buffer.h
@@ -0,0 +1,499 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/device.h"
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+// ----------------------------------------------------------------------
+// Buffer classes
+
+/// \class Buffer
+/// \brief Object containing a pointer to a piece of contiguous memory with a
+/// particular size.
+///
+/// Buffers have two related notions of length: size and capacity. Size is
+/// the number of bytes that might have valid data. Capacity is the number
+/// of bytes that were allocated for the buffer in total.
+///
+/// The Buffer base class does not own its memory, but subclasses often do.
+///
+/// The following invariant is always true: Size <= Capacity
+class ARROW_EXPORT Buffer {
+ public:
+ /// \brief Construct from buffer and size without copying memory
+ ///
+ /// \param[in] data a memory buffer
+ /// \param[in] size buffer size
+ ///
+ /// \note The passed memory must be kept alive through some other means
+ Buffer(const uint8_t* data, int64_t size)
+ : is_mutable_(false), is_cpu_(true), data_(data), size_(size), capacity_(size) {
+ SetMemoryManager(default_cpu_memory_manager());
+ }
+
+ Buffer(const uint8_t* data, int64_t size, std::shared_ptr<MemoryManager> mm,
+ std::shared_ptr<Buffer> parent = NULLPTR)
+ : is_mutable_(false), data_(data), size_(size), capacity_(size), parent_(parent) {
+ SetMemoryManager(std::move(mm));
+ }
+
+ Buffer(uintptr_t address, int64_t size, std::shared_ptr<MemoryManager> mm,
+ std::shared_ptr<Buffer> parent = NULLPTR)
+ : Buffer(reinterpret_cast<const uint8_t*>(address), size, std::move(mm),
+ std::move(parent)) {}
+
+ /// \brief Construct from string_view without copying memory
+ ///
+ /// \param[in] data a string_view object
+ ///
+ /// \note The memory viewed by data must not be deallocated in the lifetime of the
+ /// Buffer; temporary rvalue strings must be stored in an lvalue somewhere
+ explicit Buffer(util::string_view data)
+ : Buffer(reinterpret_cast<const uint8_t*>(data.data()),
+ static_cast<int64_t>(data.size())) {}
+
+ virtual ~Buffer() = default;
+
+ /// An offset into data that is owned by another buffer, but we want to be
+ /// able to retain a valid pointer to it even after other shared_ptr's to the
+ /// parent buffer have been destroyed
+ ///
+ /// This method makes no assertions about alignment or padding of the buffer but
+ /// in general we expected buffers to be aligned and padded to 64 bytes. In the future
+ /// we might add utility methods to help determine if a buffer satisfies this contract.
+ Buffer(const std::shared_ptr<Buffer>& parent, const int64_t offset, const int64_t size)
+ : Buffer(parent->data_ + offset, size) {
+ parent_ = parent;
+ SetMemoryManager(parent->memory_manager_);
+ }
+
+ uint8_t operator[](std::size_t i) const { return data_[i]; }
+
+ /// \brief Construct a new std::string with a hexadecimal representation of the buffer.
+ /// \return std::string
+ std::string ToHexString();
+
+ /// Return true if both buffers are the same size and contain the same bytes
+ /// up to the number of compared bytes
+ bool Equals(const Buffer& other, int64_t nbytes) const;
+
+ /// Return true if both buffers are the same size and contain the same bytes
+ bool Equals(const Buffer& other) const;
+
+ /// Copy a section of the buffer into a new Buffer.
+ Result<std::shared_ptr<Buffer>> CopySlice(
+ const int64_t start, const int64_t nbytes,
+ MemoryPool* pool = default_memory_pool()) const;
+
+ /// Zero bytes in padding, i.e. bytes between size_ and capacity_.
+ void ZeroPadding() {
+#ifndef NDEBUG
+ CheckMutable();
+#endif
+ // A zero-capacity buffer can have a null data pointer
+ if (capacity_ != 0) {
+ memset(mutable_data() + size_, 0, static_cast<size_t>(capacity_ - size_));
+ }
+ }
+
+ /// \brief Construct an immutable buffer that takes ownership of the contents
+ /// of an std::string (without copying it).
+ ///
+ /// \param[in] data a string to own
+ /// \return a new Buffer instance
+ static std::shared_ptr<Buffer> FromString(std::string data);
+
+ /// \brief Create buffer referencing typed memory with some length without
+ /// copying
+ /// \param[in] data the typed memory as C array
+ /// \param[in] length the number of values in the array
+ /// \return a new shared_ptr<Buffer>
+ template <typename T, typename SizeType = int64_t>
+ static std::shared_ptr<Buffer> Wrap(const T* data, SizeType length) {
+ return std::make_shared<Buffer>(reinterpret_cast<const uint8_t*>(data),
+ static_cast<int64_t>(sizeof(T) * length));
+ }
+
+ /// \brief Create buffer referencing std::vector with some length without
+ /// copying
+ /// \param[in] data the vector to be referenced. If this vector is changed,
+ /// the buffer may become invalid
+ /// \return a new shared_ptr<Buffer>
+ template <typename T>
+ static std::shared_ptr<Buffer> Wrap(const std::vector<T>& data) {
+ return std::make_shared<Buffer>(reinterpret_cast<const uint8_t*>(data.data()),
+ static_cast<int64_t>(sizeof(T) * data.size()));
+ }
+
+ /// \brief Copy buffer contents into a new std::string
+ /// \return std::string
+ /// \note Can throw std::bad_alloc if buffer is large
+ std::string ToString() const;
+
+ /// \brief View buffer contents as a util::string_view
+ /// \return util::string_view
+ explicit operator util::string_view() const {
+ return util::string_view(reinterpret_cast<const char*>(data_), size_);
+ }
+
+ /// \brief View buffer contents as a util::bytes_view
+ /// \return util::bytes_view
+ explicit operator util::bytes_view() const { return util::bytes_view(data_, size_); }
+
+ /// \brief Return a pointer to the buffer's data
+ ///
+ /// The buffer has to be a CPU buffer (`is_cpu()` is true).
+ /// Otherwise, an assertion may be thrown or a null pointer may be returned.
+ ///
+ /// To get the buffer's data address regardless of its device, call `address()`.
+ const uint8_t* data() const {
+#ifndef NDEBUG
+ CheckCPU();
+#endif
+ return ARROW_PREDICT_TRUE(is_cpu_) ? data_ : NULLPTR;
+ }
+
+ /// \brief Return a writable pointer to the buffer's data
+ ///
+ /// The buffer has to be a mutable CPU buffer (`is_cpu()` and `is_mutable()`
+ /// are true). Otherwise, an assertion may be thrown or a null pointer may
+ /// be returned.
+ ///
+ /// To get the buffer's mutable data address regardless of its device, call
+ /// `mutable_address()`.
+ uint8_t* mutable_data() {
+#ifndef NDEBUG
+ CheckCPU();
+ CheckMutable();
+#endif
+ return ARROW_PREDICT_TRUE(is_cpu_ && is_mutable_) ? const_cast<uint8_t*>(data_)
+ : NULLPTR;
+ }
+
+ /// \brief Return the device address of the buffer's data
+ uintptr_t address() const { return reinterpret_cast<uintptr_t>(data_); }
+
+ /// \brief Return a writable device address to the buffer's data
+ ///
+ /// The buffer has to be a mutable buffer (`is_mutable()` is true).
+ /// Otherwise, an assertion may be thrown or 0 may be returned.
+ uintptr_t mutable_address() const {
+#ifndef NDEBUG
+ CheckMutable();
+#endif
+ return ARROW_PREDICT_TRUE(is_mutable_) ? reinterpret_cast<uintptr_t>(data_) : 0;
+ }
+
+ /// \brief Return the buffer's size in bytes
+ int64_t size() const { return size_; }
+
+ /// \brief Return the buffer's capacity (number of allocated bytes)
+ int64_t capacity() const { return capacity_; }
+
+ /// \brief Whether the buffer is directly CPU-accessible
+ ///
+ /// If this function returns true, you can read directly from the buffer's
+ /// `data()` pointer. Otherwise, you'll have to `View()` or `Copy()` it.
+ bool is_cpu() const { return is_cpu_; }
+
+ /// \brief Whether the buffer is mutable
+ ///
+ /// If this function returns true, you are allowed to modify buffer contents
+ /// using the pointer returned by `mutable_data()` or `mutable_address()`.
+ bool is_mutable() const { return is_mutable_; }
+
+ const std::shared_ptr<Device>& device() const { return memory_manager_->device(); }
+
+ const std::shared_ptr<MemoryManager>& memory_manager() const { return memory_manager_; }
+
+ std::shared_ptr<Buffer> parent() const { return parent_; }
+
+ /// \brief Get a RandomAccessFile for reading a buffer
+ ///
+ /// The returned file object reads from this buffer's underlying memory.
+ static Result<std::shared_ptr<io::RandomAccessFile>> GetReader(std::shared_ptr<Buffer>);
+
+ /// \brief Get a OutputStream for writing to a buffer
+ ///
+ /// The buffer must be mutable. The returned stream object writes into the buffer's
+ /// underlying memory (but it won't resize it).
+ static Result<std::shared_ptr<io::OutputStream>> GetWriter(std::shared_ptr<Buffer>);
+
+ /// \brief Copy buffer
+ ///
+ /// The buffer contents will be copied into a new buffer allocated by the
+ /// given MemoryManager. This function supports cross-device copies.
+ static Result<std::shared_ptr<Buffer>> Copy(std::shared_ptr<Buffer> source,
+ const std::shared_ptr<MemoryManager>& to);
+
+ /// \brief View buffer
+ ///
+ /// Return a Buffer that reflects this buffer, seen potentially from another
+ /// device, without making an explicit copy of the contents. The underlying
+ /// mechanism is typically implemented by the kernel or device driver, and may
+ /// involve lazy caching of parts of the buffer contents on the destination
+ /// device's memory.
+ ///
+ /// If a non-copy view is unsupported for the buffer on the given device,
+ /// nullptr is returned. An error can be returned if some low-level
+ /// operation fails (such as an out-of-memory condition).
+ static Result<std::shared_ptr<Buffer>> View(std::shared_ptr<Buffer> source,
+ const std::shared_ptr<MemoryManager>& to);
+
+ /// \brief View or copy buffer
+ ///
+ /// Try to view buffer contents on the given MemoryManager's device, but
+ /// fall back to copying if a no-copy view isn't supported.
+ static Result<std::shared_ptr<Buffer>> ViewOrCopy(
+ std::shared_ptr<Buffer> source, const std::shared_ptr<MemoryManager>& to);
+
+ protected:
+ bool is_mutable_;
+ bool is_cpu_;
+ const uint8_t* data_;
+ int64_t size_;
+ int64_t capacity_;
+
+ // null by default, but may be set
+ std::shared_ptr<Buffer> parent_;
+
+ private:
+ // private so that subclasses are forced to call SetMemoryManager()
+ std::shared_ptr<MemoryManager> memory_manager_;
+
+ protected:
+ void CheckMutable() const;
+ void CheckCPU() const;
+
+ void SetMemoryManager(std::shared_ptr<MemoryManager> mm) {
+ memory_manager_ = std::move(mm);
+ is_cpu_ = memory_manager_->is_cpu();
+ }
+
+ private:
+ Buffer() = delete;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Buffer);
+};
+
+/// \defgroup buffer-slicing-functions Functions for slicing buffers
+///
+/// @{
+
+/// \brief Construct a view on a buffer at the given offset and length.
+///
+/// This function cannot fail and does not check for errors (except in debug builds)
+static inline std::shared_ptr<Buffer> SliceBuffer(const std::shared_ptr<Buffer>& buffer,
+ const int64_t offset,
+ const int64_t length) {
+ return std::make_shared<Buffer>(buffer, offset, length);
+}
+
+/// \brief Construct a view on a buffer at the given offset, up to the buffer's end.
+///
+/// This function cannot fail and does not check for errors (except in debug builds)
+static inline std::shared_ptr<Buffer> SliceBuffer(const std::shared_ptr<Buffer>& buffer,
+ const int64_t offset) {
+ int64_t length = buffer->size() - offset;
+ return SliceBuffer(buffer, offset, length);
+}
+
+/// \brief Input-checking version of SliceBuffer
+///
+/// An Invalid Status is returned if the requested slice falls out of bounds.
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> SliceBufferSafe(const std::shared_ptr<Buffer>& buffer,
+ int64_t offset);
+/// \brief Input-checking version of SliceBuffer
+///
+/// An Invalid Status is returned if the requested slice falls out of bounds.
+/// Note that unlike SliceBuffer, `length` isn't clamped to the available buffer size.
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> SliceBufferSafe(const std::shared_ptr<Buffer>& buffer,
+ int64_t offset, int64_t length);
+
+/// \brief Like SliceBuffer, but construct a mutable buffer slice.
+///
+/// If the parent buffer is not mutable, behavior is undefined (it may abort
+/// in debug builds).
+ARROW_EXPORT
+std::shared_ptr<Buffer> SliceMutableBuffer(const std::shared_ptr<Buffer>& buffer,
+ const int64_t offset, const int64_t length);
+
+/// \brief Like SliceBuffer, but construct a mutable buffer slice.
+///
+/// If the parent buffer is not mutable, behavior is undefined (it may abort
+/// in debug builds).
+static inline std::shared_ptr<Buffer> SliceMutableBuffer(
+ const std::shared_ptr<Buffer>& buffer, const int64_t offset) {
+ int64_t length = buffer->size() - offset;
+ return SliceMutableBuffer(buffer, offset, length);
+}
+
+/// \brief Input-checking version of SliceMutableBuffer
+///
+/// An Invalid Status is returned if the requested slice falls out of bounds.
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> SliceMutableBufferSafe(
+ const std::shared_ptr<Buffer>& buffer, int64_t offset);
+/// \brief Input-checking version of SliceMutableBuffer
+///
+/// An Invalid Status is returned if the requested slice falls out of bounds.
+/// Note that unlike SliceBuffer, `length` isn't clamped to the available buffer size.
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> SliceMutableBufferSafe(
+ const std::shared_ptr<Buffer>& buffer, int64_t offset, int64_t length);
+
+/// @}
+
+/// \class MutableBuffer
+/// \brief A Buffer whose contents can be mutated. May or may not own its data.
+class ARROW_EXPORT MutableBuffer : public Buffer {
+ public:
+ MutableBuffer(uint8_t* data, const int64_t size) : Buffer(data, size) {
+ is_mutable_ = true;
+ }
+
+ MutableBuffer(uint8_t* data, const int64_t size, std::shared_ptr<MemoryManager> mm)
+ : Buffer(data, size, std::move(mm)) {
+ is_mutable_ = true;
+ }
+
+ MutableBuffer(const std::shared_ptr<Buffer>& parent, const int64_t offset,
+ const int64_t size);
+
+ /// \brief Create buffer referencing typed memory with some length
+ /// \param[in] data the typed memory as C array
+ /// \param[in] length the number of values in the array
+ /// \return a new shared_ptr<Buffer>
+ template <typename T, typename SizeType = int64_t>
+ static std::shared_ptr<Buffer> Wrap(T* data, SizeType length) {
+ return std::make_shared<MutableBuffer>(reinterpret_cast<uint8_t*>(data),
+ static_cast<int64_t>(sizeof(T) * length));
+ }
+
+ protected:
+ MutableBuffer() : Buffer(NULLPTR, 0) {}
+};
+
+/// \class ResizableBuffer
+/// \brief A mutable buffer that can be resized
+class ARROW_EXPORT ResizableBuffer : public MutableBuffer {
+ public:
+ /// Change buffer reported size to indicated size, allocating memory if
+ /// necessary. This will ensure that the capacity of the buffer is a multiple
+ /// of 64 bytes as defined in Layout.md.
+ /// Consider using ZeroPadding afterwards, to conform to the Arrow layout
+ /// specification.
+ ///
+ /// @param new_size The new size for the buffer.
+ /// @param shrink_to_fit Whether to shrink the capacity if new size < current size
+ virtual Status Resize(const int64_t new_size, bool shrink_to_fit) = 0;
+ Status Resize(const int64_t new_size) {
+ return Resize(new_size, /*shrink_to_fit=*/true);
+ }
+
+ /// Ensure that buffer has enough memory allocated to fit the indicated
+ /// capacity (and meets the 64 byte padding requirement in Layout.md).
+ /// It does not change buffer's reported size and doesn't zero the padding.
+ virtual Status Reserve(const int64_t new_capacity) = 0;
+
+ template <class T>
+ Status TypedResize(const int64_t new_nb_elements, bool shrink_to_fit = true) {
+ return Resize(sizeof(T) * new_nb_elements, shrink_to_fit);
+ }
+
+ template <class T>
+ Status TypedReserve(const int64_t new_nb_elements) {
+ return Reserve(sizeof(T) * new_nb_elements);
+ }
+
+ protected:
+ ResizableBuffer(uint8_t* data, int64_t size) : MutableBuffer(data, size) {}
+ ResizableBuffer(uint8_t* data, int64_t size, std::shared_ptr<MemoryManager> mm)
+ : MutableBuffer(data, size, std::move(mm)) {}
+};
+
+/// \defgroup buffer-allocation-functions Functions for allocating buffers
+///
+/// @{
+
+/// \brief Allocate a fixed size mutable buffer from a memory pool, zero its padding.
+///
+/// \param[in] size size of buffer to allocate
+/// \param[in] pool a memory pool
+ARROW_EXPORT
+Result<std::unique_ptr<Buffer>> AllocateBuffer(const int64_t size,
+ MemoryPool* pool = NULLPTR);
+
+/// \brief Allocate a resizeable buffer from a memory pool, zero its padding.
+///
+/// \param[in] size size of buffer to allocate
+/// \param[in] pool a memory pool
+ARROW_EXPORT
+Result<std::unique_ptr<ResizableBuffer>> AllocateResizableBuffer(
+ const int64_t size, MemoryPool* pool = NULLPTR);
+
+/// \brief Allocate a bitmap buffer from a memory pool
+/// no guarantee on values is provided.
+///
+/// \param[in] length size in bits of bitmap to allocate
+/// \param[in] pool memory pool to allocate memory from
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> AllocateBitmap(int64_t length,
+ MemoryPool* pool = NULLPTR);
+
+ARROW_EXPORT
+Status AllocateBitmap(MemoryPool* pool, int64_t length, std::shared_ptr<Buffer>* out);
+
+/// \brief Allocate a zero-initialized bitmap buffer from a memory pool
+///
+/// \param[in] length size in bits of bitmap to allocate
+/// \param[in] pool memory pool to allocate memory from
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> AllocateEmptyBitmap(int64_t length,
+ MemoryPool* pool = NULLPTR);
+
+/// \brief Concatenate multiple buffers into a single buffer
+///
+/// \param[in] buffers to be concatenated
+/// \param[in] pool memory pool to allocate the new buffer from
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> ConcatenateBuffers(const BufferVector& buffers,
+ MemoryPool* pool = NULLPTR);
+
+ARROW_EXPORT
+Status ConcatenateBuffers(const BufferVector& buffers, MemoryPool* pool,
+ std::shared_ptr<Buffer>* out);
+
+/// @}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/buffer_builder.h b/src/arrow/cpp/src/arrow/buffer_builder.h
new file mode 100644
index 000000000..7b02ad09a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/buffer_builder.h
@@ -0,0 +1,459 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "arrow/buffer.h"
+#include "arrow/status.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_generate.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/ubsan.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+// ----------------------------------------------------------------------
+// Buffer builder classes
+
+/// \class BufferBuilder
+/// \brief A class for incrementally building a contiguous chunk of in-memory
+/// data
+class ARROW_EXPORT BufferBuilder {
+ public:
+ explicit BufferBuilder(MemoryPool* pool = default_memory_pool())
+ : pool_(pool),
+ data_(/*ensure never null to make ubsan happy and avoid check penalties below*/
+ util::MakeNonNull<uint8_t>()),
+ capacity_(0),
+ size_(0) {}
+
+ /// \brief Constructs new Builder that will start using
+ /// the provided buffer until Finish/Reset are called.
+ /// The buffer is not resized.
+ explicit BufferBuilder(std::shared_ptr<ResizableBuffer> buffer,
+ MemoryPool* pool = default_memory_pool())
+ : buffer_(std::move(buffer)),
+ pool_(pool),
+ data_(buffer_->mutable_data()),
+ capacity_(buffer_->capacity()),
+ size_(buffer_->size()) {}
+
+ /// \brief Resize the buffer to the nearest multiple of 64 bytes
+ ///
+ /// \param new_capacity the new capacity of the of the builder. Will be
+ /// rounded up to a multiple of 64 bytes for padding
+ /// \param shrink_to_fit if new capacity is smaller than the existing,
+ /// reallocate internal buffer. Set to false to avoid reallocations when
+ /// shrinking the builder.
+ /// \return Status
+ Status Resize(const int64_t new_capacity, bool shrink_to_fit = true) {
+ if (buffer_ == NULLPTR) {
+ ARROW_ASSIGN_OR_RAISE(buffer_, AllocateResizableBuffer(new_capacity, pool_));
+ } else {
+ ARROW_RETURN_NOT_OK(buffer_->Resize(new_capacity, shrink_to_fit));
+ }
+ capacity_ = buffer_->capacity();
+ data_ = buffer_->mutable_data();
+ return Status::OK();
+ }
+
+ /// \brief Ensure that builder can accommodate the additional number of bytes
+ /// without the need to perform allocations
+ ///
+ /// \param[in] additional_bytes number of additional bytes to make space for
+ /// \return Status
+ Status Reserve(const int64_t additional_bytes) {
+ auto min_capacity = size_ + additional_bytes;
+ if (min_capacity <= capacity_) {
+ return Status::OK();
+ }
+ return Resize(GrowByFactor(capacity_, min_capacity), false);
+ }
+
+ /// \brief Return a capacity expanded by the desired growth factor
+ static int64_t GrowByFactor(int64_t current_capacity, int64_t new_capacity) {
+ // Doubling capacity except for large Reserve requests. 2x growth strategy
+ // (versus 1.5x) seems to have slightly better performance when using
+ // jemalloc, but significantly better performance when using the system
+ // allocator. See ARROW-6450 for further discussion
+ return std::max(new_capacity, current_capacity * 2);
+ }
+
+ /// \brief Append the given data to the buffer
+ ///
+ /// The buffer is automatically expanded if necessary.
+ Status Append(const void* data, const int64_t length) {
+ if (ARROW_PREDICT_FALSE(size_ + length > capacity_)) {
+ ARROW_RETURN_NOT_OK(Resize(GrowByFactor(capacity_, size_ + length), false));
+ }
+ UnsafeAppend(data, length);
+ return Status::OK();
+ }
+
+ /// \brief Append copies of a value to the buffer
+ ///
+ /// The buffer is automatically expanded if necessary.
+ Status Append(const int64_t num_copies, uint8_t value) {
+ ARROW_RETURN_NOT_OK(Reserve(num_copies));
+ UnsafeAppend(num_copies, value);
+ return Status::OK();
+ }
+
+ // Advance pointer and zero out memory
+ Status Advance(const int64_t length) { return Append(length, 0); }
+
+ // Advance pointer, but don't allocate or zero memory
+ void UnsafeAdvance(const int64_t length) { size_ += length; }
+
+ // Unsafe methods don't check existing size
+ void UnsafeAppend(const void* data, const int64_t length) {
+ memcpy(data_ + size_, data, static_cast<size_t>(length));
+ size_ += length;
+ }
+
+ void UnsafeAppend(const int64_t num_copies, uint8_t value) {
+ memset(data_ + size_, value, static_cast<size_t>(num_copies));
+ size_ += num_copies;
+ }
+
+ /// \brief Return result of builder as a Buffer object.
+ ///
+ /// The builder is reset and can be reused afterwards.
+ ///
+ /// \param[out] out the finalized Buffer object
+ /// \param shrink_to_fit if the buffer size is smaller than its capacity,
+ /// reallocate to fit more tightly in memory. Set to false to avoid
+ /// a reallocation, at the expense of potentially more memory consumption.
+ /// \return Status
+ Status Finish(std::shared_ptr<Buffer>* out, bool shrink_to_fit = true) {
+ ARROW_RETURN_NOT_OK(Resize(size_, shrink_to_fit));
+ if (size_ != 0) buffer_->ZeroPadding();
+ *out = buffer_;
+ if (*out == NULLPTR) {
+ ARROW_ASSIGN_OR_RAISE(*out, AllocateBuffer(0, pool_));
+ }
+ Reset();
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<Buffer>> Finish(bool shrink_to_fit = true) {
+ std::shared_ptr<Buffer> out;
+ ARROW_RETURN_NOT_OK(Finish(&out, shrink_to_fit));
+ return out;
+ }
+
+ /// \brief Like Finish, but override the final buffer size
+ ///
+ /// This is useful after writing data directly into the builder memory
+ /// without calling the Append methods (basically, when using BufferBuilder
+ /// mostly for memory allocation).
+ Result<std::shared_ptr<Buffer>> FinishWithLength(int64_t final_length,
+ bool shrink_to_fit = true) {
+ size_ = final_length;
+ return Finish(shrink_to_fit);
+ }
+
+ void Reset() {
+ buffer_ = NULLPTR;
+ capacity_ = size_ = 0;
+ }
+
+ /// \brief Set size to a smaller value without modifying builder
+ /// contents. For reusable BufferBuilder classes
+ /// \param[in] position must be non-negative and less than or equal
+ /// to the current length()
+ void Rewind(int64_t position) { size_ = position; }
+
+ int64_t capacity() const { return capacity_; }
+ int64_t length() const { return size_; }
+ const uint8_t* data() const { return data_; }
+ uint8_t* mutable_data() { return data_; }
+
+ private:
+ std::shared_ptr<ResizableBuffer> buffer_;
+ MemoryPool* pool_;
+ uint8_t* data_;
+ int64_t capacity_;
+ int64_t size_;
+};
+
+template <typename T, typename Enable = void>
+class TypedBufferBuilder;
+
+/// \brief A BufferBuilder for building a buffer of arithmetic elements
+template <typename T>
+class TypedBufferBuilder<
+ T, typename std::enable_if<std::is_arithmetic<T>::value ||
+ std::is_standard_layout<T>::value>::type> {
+ public:
+ explicit TypedBufferBuilder(MemoryPool* pool = default_memory_pool())
+ : bytes_builder_(pool) {}
+
+ explicit TypedBufferBuilder(std::shared_ptr<ResizableBuffer> buffer,
+ MemoryPool* pool = default_memory_pool())
+ : bytes_builder_(std::move(buffer), pool) {}
+
+ explicit TypedBufferBuilder(BufferBuilder builder)
+ : bytes_builder_(std::move(builder)) {}
+
+ BufferBuilder* bytes_builder() { return &bytes_builder_; }
+
+ Status Append(T value) {
+ return bytes_builder_.Append(reinterpret_cast<uint8_t*>(&value), sizeof(T));
+ }
+
+ Status Append(const T* values, int64_t num_elements) {
+ return bytes_builder_.Append(reinterpret_cast<const uint8_t*>(values),
+ num_elements * sizeof(T));
+ }
+
+ Status Append(const int64_t num_copies, T value) {
+ ARROW_RETURN_NOT_OK(Reserve(num_copies + length()));
+ UnsafeAppend(num_copies, value);
+ return Status::OK();
+ }
+
+ void UnsafeAppend(T value) {
+ bytes_builder_.UnsafeAppend(reinterpret_cast<uint8_t*>(&value), sizeof(T));
+ }
+
+ void UnsafeAppend(const T* values, int64_t num_elements) {
+ bytes_builder_.UnsafeAppend(reinterpret_cast<const uint8_t*>(values),
+ num_elements * sizeof(T));
+ }
+
+ template <typename Iter>
+ void UnsafeAppend(Iter values_begin, Iter values_end) {
+ int64_t num_elements = static_cast<int64_t>(std::distance(values_begin, values_end));
+ auto data = mutable_data() + length();
+ bytes_builder_.UnsafeAdvance(num_elements * sizeof(T));
+ std::copy(values_begin, values_end, data);
+ }
+
+ void UnsafeAppend(const int64_t num_copies, T value) {
+ auto data = mutable_data() + length();
+ bytes_builder_.UnsafeAdvance(num_copies * sizeof(T));
+ std::fill(data, data + num_copies, value);
+ }
+
+ Status Resize(const int64_t new_capacity, bool shrink_to_fit = true) {
+ return bytes_builder_.Resize(new_capacity * sizeof(T), shrink_to_fit);
+ }
+
+ Status Reserve(const int64_t additional_elements) {
+ return bytes_builder_.Reserve(additional_elements * sizeof(T));
+ }
+
+ Status Advance(const int64_t length) {
+ return bytes_builder_.Advance(length * sizeof(T));
+ }
+
+ Status Finish(std::shared_ptr<Buffer>* out, bool shrink_to_fit = true) {
+ return bytes_builder_.Finish(out, shrink_to_fit);
+ }
+
+ Result<std::shared_ptr<Buffer>> Finish(bool shrink_to_fit = true) {
+ std::shared_ptr<Buffer> out;
+ ARROW_RETURN_NOT_OK(Finish(&out, shrink_to_fit));
+ return out;
+ }
+
+ /// \brief Like Finish, but override the final buffer size
+ ///
+ /// This is useful after writing data directly into the builder memory
+ /// without calling the Append methods (basically, when using TypedBufferBuilder
+ /// only for memory allocation).
+ Result<std::shared_ptr<Buffer>> FinishWithLength(int64_t final_length,
+ bool shrink_to_fit = true) {
+ return bytes_builder_.FinishWithLength(final_length * sizeof(T), shrink_to_fit);
+ }
+
+ void Reset() { bytes_builder_.Reset(); }
+
+ int64_t length() const { return bytes_builder_.length() / sizeof(T); }
+ int64_t capacity() const { return bytes_builder_.capacity() / sizeof(T); }
+ const T* data() const { return reinterpret_cast<const T*>(bytes_builder_.data()); }
+ T* mutable_data() { return reinterpret_cast<T*>(bytes_builder_.mutable_data()); }
+
+ private:
+ BufferBuilder bytes_builder_;
+};
+
+/// \brief A BufferBuilder for building a buffer containing a bitmap
+template <>
+class TypedBufferBuilder<bool> {
+ public:
+ explicit TypedBufferBuilder(MemoryPool* pool = default_memory_pool())
+ : bytes_builder_(pool) {}
+
+ explicit TypedBufferBuilder(BufferBuilder builder)
+ : bytes_builder_(std::move(builder)) {}
+
+ BufferBuilder* bytes_builder() { return &bytes_builder_; }
+
+ Status Append(bool value) {
+ ARROW_RETURN_NOT_OK(Reserve(1));
+ UnsafeAppend(value);
+ return Status::OK();
+ }
+
+ Status Append(const uint8_t* valid_bytes, int64_t num_elements) {
+ ARROW_RETURN_NOT_OK(Reserve(num_elements));
+ UnsafeAppend(valid_bytes, num_elements);
+ return Status::OK();
+ }
+
+ Status Append(const int64_t num_copies, bool value) {
+ ARROW_RETURN_NOT_OK(Reserve(num_copies));
+ UnsafeAppend(num_copies, value);
+ return Status::OK();
+ }
+
+ void UnsafeAppend(bool value) {
+ BitUtil::SetBitTo(mutable_data(), bit_length_, value);
+ if (!value) {
+ ++false_count_;
+ }
+ ++bit_length_;
+ }
+
+ /// \brief Append bits from an array of bytes (one value per byte)
+ void UnsafeAppend(const uint8_t* bytes, int64_t num_elements) {
+ if (num_elements == 0) return;
+ int64_t i = 0;
+ internal::GenerateBitsUnrolled(mutable_data(), bit_length_, num_elements, [&] {
+ bool value = bytes[i++];
+ false_count_ += !value;
+ return value;
+ });
+ bit_length_ += num_elements;
+ }
+
+ /// \brief Append bits from a packed bitmap
+ void UnsafeAppend(const uint8_t* bitmap, int64_t offset, int64_t num_elements) {
+ if (num_elements == 0) return;
+ internal::CopyBitmap(bitmap, offset, num_elements, mutable_data(), bit_length_);
+ false_count_ += num_elements - internal::CountSetBits(bitmap, offset, num_elements);
+ bit_length_ += num_elements;
+ }
+
+ void UnsafeAppend(const int64_t num_copies, bool value) {
+ BitUtil::SetBitsTo(mutable_data(), bit_length_, num_copies, value);
+ false_count_ += num_copies * !value;
+ bit_length_ += num_copies;
+ }
+
+ template <bool count_falses, typename Generator>
+ void UnsafeAppend(const int64_t num_elements, Generator&& gen) {
+ if (num_elements == 0) return;
+
+ if (count_falses) {
+ internal::GenerateBitsUnrolled(mutable_data(), bit_length_, num_elements, [&] {
+ bool value = gen();
+ false_count_ += !value;
+ return value;
+ });
+ } else {
+ internal::GenerateBitsUnrolled(mutable_data(), bit_length_, num_elements,
+ std::forward<Generator>(gen));
+ }
+ bit_length_ += num_elements;
+ }
+
+ Status Resize(const int64_t new_capacity, bool shrink_to_fit = true) {
+ const int64_t old_byte_capacity = bytes_builder_.capacity();
+ ARROW_RETURN_NOT_OK(
+ bytes_builder_.Resize(BitUtil::BytesForBits(new_capacity), shrink_to_fit));
+ // Resize() may have chosen a larger capacity (e.g. for padding),
+ // so ask it again before calling memset().
+ const int64_t new_byte_capacity = bytes_builder_.capacity();
+ if (new_byte_capacity > old_byte_capacity) {
+ // The additional buffer space is 0-initialized for convenience,
+ // so that other methods can simply bump the length.
+ memset(mutable_data() + old_byte_capacity, 0,
+ static_cast<size_t>(new_byte_capacity - old_byte_capacity));
+ }
+ return Status::OK();
+ }
+
+ Status Reserve(const int64_t additional_elements) {
+ return Resize(
+ BufferBuilder::GrowByFactor(bit_length_, bit_length_ + additional_elements),
+ false);
+ }
+
+ Status Advance(const int64_t length) {
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ bit_length_ += length;
+ false_count_ += length;
+ return Status::OK();
+ }
+
+ Status Finish(std::shared_ptr<Buffer>* out, bool shrink_to_fit = true) {
+ // set bytes_builder_.size_ == byte size of data
+ bytes_builder_.UnsafeAdvance(BitUtil::BytesForBits(bit_length_) -
+ bytes_builder_.length());
+ bit_length_ = false_count_ = 0;
+ return bytes_builder_.Finish(out, shrink_to_fit);
+ }
+
+ Result<std::shared_ptr<Buffer>> Finish(bool shrink_to_fit = true) {
+ std::shared_ptr<Buffer> out;
+ ARROW_RETURN_NOT_OK(Finish(&out, shrink_to_fit));
+ return out;
+ }
+
+ /// \brief Like Finish, but override the final buffer size
+ ///
+ /// This is useful after writing data directly into the builder memory
+ /// without calling the Append methods (basically, when using TypedBufferBuilder
+ /// only for memory allocation).
+ Result<std::shared_ptr<Buffer>> FinishWithLength(int64_t final_length,
+ bool shrink_to_fit = true) {
+ const auto final_byte_length = BitUtil::BytesForBits(final_length);
+ bytes_builder_.UnsafeAdvance(final_byte_length - bytes_builder_.length());
+ bit_length_ = false_count_ = 0;
+ return bytes_builder_.FinishWithLength(final_byte_length, shrink_to_fit);
+ }
+
+ void Reset() {
+ bytes_builder_.Reset();
+ bit_length_ = false_count_ = 0;
+ }
+
+ int64_t length() const { return bit_length_; }
+ int64_t capacity() const { return bytes_builder_.capacity() * 8; }
+ const uint8_t* data() const { return bytes_builder_.data(); }
+ uint8_t* mutable_data() { return bytes_builder_.mutable_data(); }
+ int64_t false_count() const { return false_count_; }
+
+ private:
+ BufferBuilder bytes_builder_;
+ int64_t bit_length_ = 0;
+ int64_t false_count_ = 0;
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/buffer_test.cc b/src/arrow/cpp/src/arrow/buffer_test.cc
new file mode 100644
index 000000000..4295d4ca6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/buffer_test.cc
@@ -0,0 +1,926 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/buffer.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/device.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+static const char kMyDeviceTypeName[] = "arrowtest::MyDevice";
+
+static const int kMyDeviceAllowCopy = 1;
+static const int kMyDeviceAllowView = 2;
+static const int kMyDeviceDisallowCopyView = 3;
+
+class MyDevice : public Device {
+ public:
+ explicit MyDevice(int value) : Device(), value_(value) {}
+
+ const char* type_name() const override { return kMyDeviceTypeName; }
+
+ std::string ToString() const override {
+ switch (value_) {
+ case kMyDeviceAllowCopy:
+ return "MyDevice[noview]";
+ case kMyDeviceAllowView:
+ return "MyDevice[nocopy]";
+ default:
+ return "MyDevice[nocopy][noview]";
+ }
+ }
+
+ bool Equals(const Device& other) const override {
+ if (other.type_name() != kMyDeviceTypeName) {
+ return false;
+ }
+ return checked_cast<const MyDevice&>(other).value_ == value_;
+ }
+
+ std::shared_ptr<MemoryManager> default_memory_manager() override;
+
+ int value() const { return value_; }
+
+ bool allow_copy() const { return value_ == kMyDeviceAllowCopy; }
+
+ bool allow_view() const { return value_ == kMyDeviceAllowView; }
+
+ protected:
+ int value_;
+};
+
+class MyMemoryManager : public MemoryManager {
+ public:
+ explicit MyMemoryManager(std::shared_ptr<Device> device) : MemoryManager(device) {}
+
+ bool allow_copy() const {
+ return checked_cast<const MyDevice&>(*device()).allow_copy();
+ }
+
+ bool allow_view() const {
+ return checked_cast<const MyDevice&>(*device()).allow_view();
+ }
+
+ Result<std::shared_ptr<io::RandomAccessFile>> GetBufferReader(
+ std::shared_ptr<Buffer> buf) override {
+ return Status::NotImplemented("");
+ }
+
+ Result<std::shared_ptr<Buffer>> AllocateBuffer(int64_t size) override {
+ return Status::NotImplemented("");
+ }
+
+ Result<std::shared_ptr<io::OutputStream>> GetBufferWriter(
+ std::shared_ptr<Buffer> buf) override {
+ return Status::NotImplemented("");
+ }
+
+ protected:
+ Result<std::shared_ptr<Buffer>> CopyBufferFrom(
+ const std::shared_ptr<Buffer>& buf,
+ const std::shared_ptr<MemoryManager>& from) override;
+ Result<std::shared_ptr<Buffer>> CopyBufferTo(
+ const std::shared_ptr<Buffer>& buf,
+ const std::shared_ptr<MemoryManager>& to) override;
+ Result<std::shared_ptr<Buffer>> ViewBufferFrom(
+ const std::shared_ptr<Buffer>& buf,
+ const std::shared_ptr<MemoryManager>& from) override;
+ Result<std::shared_ptr<Buffer>> ViewBufferTo(
+ const std::shared_ptr<Buffer>& buf,
+ const std::shared_ptr<MemoryManager>& to) override;
+};
+
+class MyBuffer : public Buffer {
+ public:
+ MyBuffer(std::shared_ptr<MemoryManager> mm, const std::shared_ptr<Buffer>& parent)
+ : Buffer(parent->data(), parent->size()) {
+ parent_ = parent;
+ SetMemoryManager(mm);
+ }
+};
+
+std::shared_ptr<MemoryManager> MyDevice::default_memory_manager() {
+ return std::make_shared<MyMemoryManager>(shared_from_this());
+}
+
+Result<std::shared_ptr<Buffer>> MyMemoryManager::CopyBufferFrom(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& from) {
+ if (!allow_copy()) {
+ return nullptr;
+ }
+ if (from->is_cpu()) {
+ // CPU to MyDevice:
+ // 1. CPU to CPU
+ ARROW_ASSIGN_OR_RAISE(auto dest,
+ MemoryManager::CopyBuffer(buf, default_cpu_memory_manager()));
+ // 2. Wrap CPU buffer result
+ return std::make_shared<MyBuffer>(shared_from_this(), dest);
+ }
+ return nullptr;
+}
+
+Result<std::shared_ptr<Buffer>> MyMemoryManager::CopyBufferTo(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& to) {
+ if (!allow_copy()) {
+ return nullptr;
+ }
+ if (to->is_cpu() && buf->parent()) {
+ // MyDevice to CPU
+ return MemoryManager::CopyBuffer(buf->parent(), to);
+ }
+ return nullptr;
+}
+
+Result<std::shared_ptr<Buffer>> MyMemoryManager::ViewBufferFrom(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& from) {
+ if (!allow_view()) {
+ return nullptr;
+ }
+ if (from->is_cpu()) {
+ // CPU on MyDevice: wrap CPU buffer
+ return std::make_shared<MyBuffer>(shared_from_this(), buf);
+ }
+ return nullptr;
+}
+
+Result<std::shared_ptr<Buffer>> MyMemoryManager::ViewBufferTo(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& to) {
+ if (!allow_view()) {
+ return nullptr;
+ }
+ if (to->is_cpu() && buf->parent()) {
+ // MyDevice on CPU: unwrap buffer
+ return buf->parent();
+ }
+ return nullptr;
+}
+
+// Like AssertBufferEqual, but doesn't call Buffer::data()
+void AssertMyBufferEqual(const Buffer& buffer, util::string_view expected) {
+ ASSERT_EQ(util::string_view(buffer), expected);
+}
+
+void AssertIsCPUBuffer(const Buffer& buf) {
+ ASSERT_TRUE(buf.is_cpu());
+ ASSERT_EQ(*buf.device(), *CPUDevice::Instance());
+}
+
+class TestDevice : public ::testing::Test {
+ public:
+ void SetUp() {
+ cpu_device_ = CPUDevice::Instance();
+ my_copy_device_ = std::make_shared<MyDevice>(kMyDeviceAllowCopy);
+ my_view_device_ = std::make_shared<MyDevice>(kMyDeviceAllowView);
+ my_other_device_ = std::make_shared<MyDevice>(kMyDeviceDisallowCopyView);
+
+ cpu_mm_ = cpu_device_->default_memory_manager();
+ my_copy_mm_ = my_copy_device_->default_memory_manager();
+ my_view_mm_ = my_view_device_->default_memory_manager();
+ my_other_mm_ = my_other_device_->default_memory_manager();
+
+ cpu_src_ = Buffer::FromString("some data");
+ my_copy_src_ = std::make_shared<MyBuffer>(my_copy_mm_, cpu_src_);
+ my_view_src_ = std::make_shared<MyBuffer>(my_view_mm_, cpu_src_);
+ my_other_src_ = std::make_shared<MyBuffer>(my_other_mm_, cpu_src_);
+ }
+
+ protected:
+ std::shared_ptr<Device> cpu_device_, my_copy_device_, my_view_device_, my_other_device_;
+ std::shared_ptr<MemoryManager> cpu_mm_, my_copy_mm_, my_view_mm_, my_other_mm_;
+ std::shared_ptr<Buffer> cpu_src_, my_copy_src_, my_view_src_, my_other_src_;
+};
+
+TEST_F(TestDevice, Basics) {
+ ASSERT_TRUE(cpu_device_->is_cpu());
+
+ ASSERT_EQ(*cpu_device_, *cpu_device_);
+ ASSERT_EQ(*my_copy_device_, *my_copy_device_);
+ ASSERT_NE(*cpu_device_, *my_copy_device_);
+ ASSERT_NE(*my_copy_device_, *cpu_device_);
+
+ ASSERT_TRUE(cpu_mm_->is_cpu());
+ ASSERT_FALSE(my_copy_mm_->is_cpu());
+ ASSERT_FALSE(my_other_mm_->is_cpu());
+}
+
+TEST_F(TestDevice, Copy) {
+ // CPU-to-CPU
+ ASSERT_OK_AND_ASSIGN(auto buffer, MemoryManager::CopyBuffer(cpu_src_, cpu_mm_));
+ ASSERT_EQ(buffer->device(), cpu_device_);
+ ASSERT_TRUE(buffer->is_cpu());
+ ASSERT_NE(buffer->address(), cpu_src_->address());
+ ASSERT_NE(buffer->data(), nullptr);
+ AssertBufferEqual(*buffer, "some data");
+
+ // CPU-to-device
+ ASSERT_OK_AND_ASSIGN(buffer, MemoryManager::CopyBuffer(cpu_src_, my_copy_mm_));
+ ASSERT_EQ(buffer->device(), my_copy_device_);
+ ASSERT_FALSE(buffer->is_cpu());
+ ASSERT_NE(buffer->address(), cpu_src_->address());
+#ifdef NDEBUG
+ ASSERT_EQ(buffer->data(), nullptr);
+#endif
+ AssertMyBufferEqual(*buffer, "some data");
+
+ // Device-to-CPU
+ ASSERT_OK_AND_ASSIGN(buffer, MemoryManager::CopyBuffer(my_copy_src_, cpu_mm_));
+ ASSERT_EQ(buffer->device(), cpu_device_);
+ ASSERT_TRUE(buffer->is_cpu());
+ ASSERT_NE(buffer->address(), my_copy_src_->address());
+ ASSERT_NE(buffer->data(), nullptr);
+ AssertBufferEqual(*buffer, "some data");
+
+ // Device-to-device with an intermediate CPU copy
+ ASSERT_OK_AND_ASSIGN(buffer, MemoryManager::CopyBuffer(my_copy_src_, my_copy_mm_));
+ ASSERT_EQ(buffer->device(), my_copy_device_);
+ ASSERT_FALSE(buffer->is_cpu());
+ ASSERT_NE(buffer->address(), my_copy_src_->address());
+#ifdef NDEBUG
+ ASSERT_EQ(buffer->data(), nullptr);
+#endif
+ AssertMyBufferEqual(*buffer, "some data");
+
+ // Device-to-device with an intermediate view on CPU, then a copy from CPU to device
+ ASSERT_OK_AND_ASSIGN(buffer, MemoryManager::CopyBuffer(my_view_src_, my_copy_mm_));
+ ASSERT_EQ(buffer->device(), my_copy_device_);
+ ASSERT_FALSE(buffer->is_cpu());
+ ASSERT_NE(buffer->address(), my_copy_src_->address());
+#ifdef NDEBUG
+ ASSERT_EQ(buffer->data(), nullptr);
+#endif
+ AssertMyBufferEqual(*buffer, "some data");
+
+ ASSERT_RAISES(NotImplemented, MemoryManager::CopyBuffer(cpu_src_, my_other_mm_));
+ ASSERT_RAISES(NotImplemented, MemoryManager::CopyBuffer(my_other_src_, cpu_mm_));
+}
+
+TEST_F(TestDevice, View) {
+ // CPU-on-CPU
+ ASSERT_OK_AND_ASSIGN(auto buffer, MemoryManager::ViewBuffer(cpu_src_, cpu_mm_));
+ ASSERT_EQ(buffer->device(), cpu_device_);
+ ASSERT_TRUE(buffer->is_cpu());
+ ASSERT_EQ(buffer->address(), cpu_src_->address());
+ ASSERT_NE(buffer->data(), nullptr);
+ AssertBufferEqual(*buffer, "some data");
+
+ // CPU-on-device
+ ASSERT_OK_AND_ASSIGN(buffer, MemoryManager::ViewBuffer(cpu_src_, my_view_mm_));
+ ASSERT_EQ(buffer->device(), my_view_device_);
+ ASSERT_FALSE(buffer->is_cpu());
+ ASSERT_EQ(buffer->address(), cpu_src_->address());
+#ifdef NDEBUG
+ ASSERT_EQ(buffer->data(), nullptr);
+#endif
+ AssertMyBufferEqual(*buffer, "some data");
+
+ // Device-on-CPU
+ ASSERT_OK_AND_ASSIGN(buffer, MemoryManager::ViewBuffer(my_view_src_, cpu_mm_));
+ ASSERT_EQ(buffer->device(), cpu_device_);
+ ASSERT_TRUE(buffer->is_cpu());
+ ASSERT_EQ(buffer->address(), my_copy_src_->address());
+ ASSERT_NE(buffer->data(), nullptr);
+ AssertBufferEqual(*buffer, "some data");
+
+ ASSERT_RAISES(NotImplemented, MemoryManager::CopyBuffer(cpu_src_, my_other_mm_));
+ ASSERT_RAISES(NotImplemented, MemoryManager::CopyBuffer(my_other_src_, cpu_mm_));
+}
+
+TEST(TestAllocate, Basics) {
+ ASSERT_OK_AND_ASSIGN(auto new_buffer, AllocateBuffer(1024));
+ auto mm = new_buffer->memory_manager();
+ ASSERT_TRUE(mm->is_cpu());
+ ASSERT_EQ(mm.get(), default_cpu_memory_manager().get());
+ auto cpu_mm = checked_pointer_cast<CPUMemoryManager>(mm);
+ ASSERT_EQ(cpu_mm->pool(), default_memory_pool());
+
+ auto pool = std::make_shared<ProxyMemoryPool>(default_memory_pool());
+ ASSERT_OK_AND_ASSIGN(new_buffer, AllocateBuffer(1024, pool.get()));
+ mm = new_buffer->memory_manager();
+ ASSERT_TRUE(mm->is_cpu());
+ cpu_mm = checked_pointer_cast<CPUMemoryManager>(mm);
+ ASSERT_EQ(cpu_mm->pool(), pool.get());
+ new_buffer.reset(); // Destroy before pool
+}
+
+TEST(TestAllocate, Bitmap) {
+ ASSERT_OK_AND_ASSIGN(auto new_buffer, AllocateBitmap(100));
+ AssertIsCPUBuffer(*new_buffer);
+ EXPECT_GE(new_buffer->size(), 13);
+ EXPECT_EQ(new_buffer->capacity() % 8, 0);
+}
+
+TEST(TestAllocate, EmptyBitmap) {
+ ASSERT_OK_AND_ASSIGN(auto new_buffer, AllocateEmptyBitmap(100));
+ AssertIsCPUBuffer(*new_buffer);
+ EXPECT_EQ(new_buffer->size(), 13);
+ EXPECT_EQ(new_buffer->capacity() % 8, 0);
+ EXPECT_TRUE(std::all_of(new_buffer->data(), new_buffer->data() + new_buffer->capacity(),
+ [](int8_t byte) { return byte == 0; }));
+}
+
+TEST(TestBuffer, FromStdString) {
+ std::string val = "hello, world";
+
+ Buffer buf(val);
+ AssertIsCPUBuffer(buf);
+ ASSERT_EQ(0, memcmp(buf.data(), val.c_str(), val.size()));
+ ASSERT_EQ(static_cast<int64_t>(val.size()), buf.size());
+}
+
+TEST(TestBuffer, FromStdStringWithMemory) {
+ std::string expected = "hello, world";
+ std::shared_ptr<Buffer> buf;
+
+ {
+ std::string temp = "hello, world";
+ buf = Buffer::FromString(temp);
+ AssertIsCPUBuffer(*buf);
+ ASSERT_EQ(0, memcmp(buf->data(), temp.c_str(), temp.size()));
+ ASSERT_EQ(static_cast<int64_t>(temp.size()), buf->size());
+ }
+
+ // Now temp goes out of scope and we check if created buffer
+ // is still valid to make sure it actually owns its space
+ ASSERT_EQ(0, memcmp(buf->data(), expected.c_str(), expected.size()));
+ ASSERT_EQ(static_cast<int64_t>(expected.size()), buf->size());
+}
+
+TEST(TestBuffer, EqualsWithSameContent) {
+ MemoryPool* pool = default_memory_pool();
+ const int32_t bufferSize = 128 * 1024;
+ uint8_t* rawBuffer1;
+ ASSERT_OK(pool->Allocate(bufferSize, &rawBuffer1));
+ memset(rawBuffer1, 12, bufferSize);
+ uint8_t* rawBuffer2;
+ ASSERT_OK(pool->Allocate(bufferSize, &rawBuffer2));
+ memset(rawBuffer2, 12, bufferSize);
+ uint8_t* rawBuffer3;
+ ASSERT_OK(pool->Allocate(bufferSize, &rawBuffer3));
+ memset(rawBuffer3, 3, bufferSize);
+
+ Buffer buffer1(rawBuffer1, bufferSize);
+ Buffer buffer2(rawBuffer2, bufferSize);
+ Buffer buffer3(rawBuffer3, bufferSize);
+ ASSERT_TRUE(buffer1.Equals(buffer2));
+ ASSERT_FALSE(buffer1.Equals(buffer3));
+
+ pool->Free(rawBuffer1, bufferSize);
+ pool->Free(rawBuffer2, bufferSize);
+ pool->Free(rawBuffer3, bufferSize);
+}
+
+TEST(TestBuffer, EqualsWithSameBuffer) {
+ MemoryPool* pool = default_memory_pool();
+ const int32_t bufferSize = 128 * 1024;
+ uint8_t* rawBuffer;
+ ASSERT_OK(pool->Allocate(bufferSize, &rawBuffer));
+ memset(rawBuffer, 111, bufferSize);
+
+ Buffer buffer1(rawBuffer, bufferSize);
+ Buffer buffer2(rawBuffer, bufferSize);
+ ASSERT_TRUE(buffer1.Equals(buffer2));
+
+ const int64_t nbytes = bufferSize / 2;
+ Buffer buffer3(rawBuffer, nbytes);
+ ASSERT_TRUE(buffer1.Equals(buffer3, nbytes));
+ ASSERT_FALSE(buffer1.Equals(buffer3, nbytes + 1));
+
+ pool->Free(rawBuffer, bufferSize);
+}
+
+TEST(TestBuffer, CopySlice) {
+ std::string data_str = "some data to copy";
+
+ auto data = reinterpret_cast<const uint8_t*>(data_str.c_str());
+
+ Buffer buf(data, data_str.size());
+
+ ASSERT_OK_AND_ASSIGN(auto out, buf.CopySlice(5, 4));
+ AssertIsCPUBuffer(*out);
+
+ Buffer expected(data + 5, 4);
+ ASSERT_TRUE(out->Equals(expected));
+ // assert the padding is zeroed
+ std::vector<uint8_t> zeros(out->capacity() - out->size());
+ ASSERT_EQ(0, memcmp(out->data() + out->size(), zeros.data(), zeros.size()));
+}
+
+TEST(TestBuffer, CopySliceEmpty) {
+ auto buf = std::make_shared<Buffer>("");
+ ASSERT_OK_AND_ASSIGN(auto out, buf->CopySlice(0, 0));
+ AssertBufferEqual(*out, "");
+
+ buf = std::make_shared<Buffer>("1234");
+ ASSERT_OK_AND_ASSIGN(out, buf->CopySlice(0, 0));
+ AssertBufferEqual(*out, "");
+ ASSERT_OK_AND_ASSIGN(out, buf->CopySlice(4, 0));
+ AssertBufferEqual(*out, "");
+}
+
+TEST(TestBuffer, ToHexString) {
+ const uint8_t data_array[] = "\a0hex string\xa9";
+ std::basic_string<uint8_t> data_str = data_array;
+
+ auto data = reinterpret_cast<const uint8_t*>(data_str.c_str());
+
+ Buffer buf(data, data_str.size());
+
+ ASSERT_EQ(buf.ToHexString(), std::string("073068657820737472696E67A9"));
+}
+
+TEST(TestBuffer, SliceBuffer) {
+ std::string data_str = "some data to slice";
+ auto data = reinterpret_cast<const uint8_t*>(data_str.c_str());
+
+ auto buf = std::make_shared<Buffer>(data, data_str.size());
+
+ std::shared_ptr<Buffer> out = SliceBuffer(buf, 5, 4);
+ AssertIsCPUBuffer(*out);
+ Buffer expected(data + 5, 4);
+ ASSERT_TRUE(out->Equals(expected));
+
+ ASSERT_EQ(2, buf.use_count());
+}
+
+TEST(TestBuffer, SliceBufferSafe) {
+ std::string data_str = "some data to slice";
+ auto data = reinterpret_cast<const uint8_t*>(data_str.c_str());
+
+ auto buf = std::make_shared<Buffer>(data, data_str.size());
+
+ ASSERT_OK_AND_ASSIGN(auto sliced, SliceBufferSafe(buf, 5, 4));
+ AssertBufferEqual(*sliced, "data");
+ ASSERT_OK_AND_ASSIGN(sliced, SliceBufferSafe(buf, 0, 4));
+ AssertBufferEqual(*sliced, "some");
+ ASSERT_OK_AND_ASSIGN(sliced, SliceBufferSafe(buf, 0, 0));
+ AssertBufferEqual(*sliced, "");
+ ASSERT_OK_AND_ASSIGN(sliced, SliceBufferSafe(buf, 4, 0));
+ AssertBufferEqual(*sliced, "");
+ ASSERT_OK_AND_ASSIGN(sliced, SliceBufferSafe(buf, buf->size(), 0));
+ AssertBufferEqual(*sliced, "");
+
+ ASSERT_RAISES(Invalid, SliceBufferSafe(buf, -1, 0));
+ ASSERT_RAISES(Invalid, SliceBufferSafe(buf, 0, -1));
+ ASSERT_RAISES(Invalid, SliceBufferSafe(buf, 0, buf->size() + 1));
+ ASSERT_RAISES(Invalid, SliceBufferSafe(buf, 2, buf->size() - 1));
+ ASSERT_RAISES(Invalid, SliceBufferSafe(buf, buf->size() + 1, 0));
+ ASSERT_RAISES(Invalid,
+ SliceBufferSafe(buf, 3, std::numeric_limits<int64_t>::max() - 2));
+
+ ASSERT_OK_AND_ASSIGN(sliced, SliceBufferSafe(buf, 0));
+ AssertBufferEqual(*sliced, "some data to slice");
+ ASSERT_OK_AND_ASSIGN(sliced, SliceBufferSafe(buf, 5));
+ AssertBufferEqual(*sliced, "data to slice");
+ ASSERT_OK_AND_ASSIGN(sliced, SliceBufferSafe(buf, buf->size()));
+ AssertBufferEqual(*sliced, "");
+
+ ASSERT_RAISES(Invalid, SliceBufferSafe(buf, -1));
+ ASSERT_RAISES(Invalid, SliceBufferSafe(buf, buf->size() + 1));
+}
+
+TEST(TestMutableBuffer, Wrap) {
+ std::vector<int32_t> values = {1, 2, 3};
+
+ auto buf = MutableBuffer::Wrap(values.data(), values.size());
+ AssertIsCPUBuffer(*buf);
+ reinterpret_cast<int32_t*>(buf->mutable_data())[1] = 4;
+
+ ASSERT_EQ(4, values[1]);
+}
+
+TEST(TestBuffer, FromStringRvalue) {
+ std::string expected = "input data";
+
+ std::shared_ptr<Buffer> buffer;
+ {
+ std::string data_str = "input data";
+ buffer = Buffer::FromString(std::move(data_str));
+ AssertIsCPUBuffer(*buffer);
+ }
+
+ ASSERT_FALSE(buffer->is_mutable());
+
+ ASSERT_EQ(0, memcmp(buffer->data(), expected.c_str(), expected.size()));
+ ASSERT_EQ(static_cast<int64_t>(expected.size()), buffer->size());
+}
+
+TEST(TestBuffer, SliceMutableBuffer) {
+ std::string data_str = "some data to slice";
+ auto data = reinterpret_cast<const uint8_t*>(data_str.c_str());
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> buffer, AllocateBuffer(50));
+
+ memcpy(buffer->mutable_data(), data, data_str.size());
+
+ std::shared_ptr<Buffer> slice = SliceMutableBuffer(buffer, 5, 10);
+ AssertIsCPUBuffer(*slice);
+ ASSERT_TRUE(slice->is_mutable());
+ ASSERT_EQ(10, slice->size());
+
+ Buffer expected(data + 5, 10);
+ ASSERT_TRUE(slice->Equals(expected));
+}
+
+TEST(TestBuffer, GetReader) {
+ const std::string data_str = "some data to read";
+ auto data = reinterpret_cast<const uint8_t*>(data_str.c_str());
+
+ auto buf = std::make_shared<Buffer>(data, data_str.size());
+ ASSERT_OK_AND_ASSIGN(auto reader, Buffer::GetReader(buf));
+ ASSERT_OK_AND_EQ(static_cast<int64_t>(data_str.size()), reader->GetSize());
+ ASSERT_OK_AND_ASSIGN(auto read_buf, reader->ReadAt(5, 4));
+ AssertBufferEqual(*read_buf, "data");
+}
+
+TEST(TestBuffer, GetWriter) {
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> buf, AllocateBuffer(9));
+ ASSERT_OK_AND_ASSIGN(auto writer, Buffer::GetWriter(buf));
+ ASSERT_OK(writer->Write(reinterpret_cast<const uint8_t*>("some data"), 9));
+ AssertBufferEqual(*buf, "some data");
+
+ // Non-mutable buffer
+ buf = std::make_shared<Buffer>(reinterpret_cast<const uint8_t*>("xxx"), 3);
+ ASSERT_RAISES(Invalid, Buffer::GetWriter(buf));
+}
+
+template <typename AllocateFunction>
+void TestZeroSizeAllocateBuffer(MemoryPool* pool, AllocateFunction&& allocate_func) {
+ auto allocated_bytes = pool->bytes_allocated();
+ {
+ std::shared_ptr<Buffer> buffer, buffer2;
+
+ ASSERT_OK(allocate_func(pool, 0, &buffer));
+ AssertIsCPUBuffer(*buffer);
+ ASSERT_EQ(buffer->size(), 0);
+ // Even 0-sized buffers should not have a null data pointer
+ auto data = buffer->data();
+ ASSERT_NE(data, nullptr);
+ ASSERT_EQ(buffer->mutable_data(), data);
+
+ // As an optimization, another 0-size buffer should share the same memory "area"
+ ASSERT_OK(allocate_func(pool, 0, &buffer2));
+ AssertIsCPUBuffer(*buffer2);
+ ASSERT_EQ(buffer2->size(), 0);
+ ASSERT_EQ(buffer2->data(), data);
+
+ ASSERT_GE(pool->bytes_allocated(), allocated_bytes);
+ }
+ ASSERT_EQ(pool->bytes_allocated(), allocated_bytes);
+}
+
+TEST(TestAllocateBuffer, ZeroSize) {
+ MemoryPool* pool = default_memory_pool();
+ auto allocate_func = [](MemoryPool* pool, int64_t size, std::shared_ptr<Buffer>* out) {
+ return AllocateBuffer(size, pool).Value(out);
+ };
+ TestZeroSizeAllocateBuffer(pool, allocate_func);
+}
+
+TEST(TestAllocateResizableBuffer, ZeroSize) {
+ MemoryPool* pool = default_memory_pool();
+ auto allocate_func = [](MemoryPool* pool, int64_t size, std::shared_ptr<Buffer>* out) {
+ ARROW_ASSIGN_OR_RAISE(auto resizable, AllocateResizableBuffer(size, pool));
+ *out = std::move(resizable);
+ return Status::OK();
+ };
+ TestZeroSizeAllocateBuffer(pool, allocate_func);
+}
+
+TEST(TestAllocateResizableBuffer, ZeroResize) {
+ MemoryPool* pool = default_memory_pool();
+ auto allocated_bytes = pool->bytes_allocated();
+ {
+ std::shared_ptr<ResizableBuffer> buffer;
+
+ ASSERT_OK_AND_ASSIGN(buffer, AllocateResizableBuffer(1000, pool));
+ ASSERT_EQ(buffer->size(), 1000);
+ ASSERT_NE(buffer->data(), nullptr);
+ ASSERT_EQ(buffer->mutable_data(), buffer->data());
+
+ ASSERT_GE(pool->bytes_allocated(), allocated_bytes + 1000);
+
+ ASSERT_OK(buffer->Resize(0));
+ ASSERT_NE(buffer->data(), nullptr);
+ ASSERT_EQ(buffer->mutable_data(), buffer->data());
+
+ ASSERT_GE(pool->bytes_allocated(), allocated_bytes);
+ ASSERT_LT(pool->bytes_allocated(), allocated_bytes + 1000);
+ }
+ ASSERT_EQ(pool->bytes_allocated(), allocated_bytes);
+}
+
+TEST(TestBufferBuilder, ResizeReserve) {
+ const std::string data = "some data";
+ auto data_ptr = data.c_str();
+
+ BufferBuilder builder;
+
+ ASSERT_OK(builder.Append(data_ptr, 9));
+ ASSERT_EQ(9, builder.length());
+
+ ASSERT_OK(builder.Resize(128));
+ ASSERT_EQ(128, builder.capacity());
+ ASSERT_EQ(9, builder.length());
+
+ // Do not shrink to fit
+ ASSERT_OK(builder.Resize(64, false));
+ ASSERT_EQ(128, builder.capacity());
+ ASSERT_EQ(9, builder.length());
+
+ // Shrink to fit
+ ASSERT_OK(builder.Resize(64));
+ ASSERT_EQ(64, builder.capacity());
+ ASSERT_EQ(9, builder.length());
+
+ // Reserve elements
+ ASSERT_OK(builder.Reserve(60));
+ ASSERT_EQ(128, builder.capacity());
+ ASSERT_EQ(9, builder.length());
+}
+
+TEST(TestBufferBuilder, Finish) {
+ const std::string data = "some data";
+ auto data_ptr = data.c_str();
+
+ for (const bool shrink_to_fit : {true, false}) {
+ ARROW_SCOPED_TRACE("shrink_to_fit = ", shrink_to_fit);
+ BufferBuilder builder;
+ ASSERT_OK(builder.Append(data_ptr, 9));
+ ASSERT_OK(builder.Append(data_ptr, 9));
+ ASSERT_EQ(18, builder.length());
+ ASSERT_EQ(64, builder.capacity());
+
+ ASSERT_OK_AND_ASSIGN(auto buf, builder.Finish(shrink_to_fit));
+ ASSERT_EQ(buf->size(), 18);
+ ASSERT_EQ(buf->capacity(), 64);
+ }
+ for (const bool shrink_to_fit : {true, false}) {
+ ARROW_SCOPED_TRACE("shrink_to_fit = ", shrink_to_fit);
+ BufferBuilder builder;
+ ASSERT_OK(builder.Reserve(1024));
+ builder.UnsafeAppend(data_ptr, 9);
+ builder.UnsafeAppend(data_ptr, 9);
+ ASSERT_EQ(18, builder.length());
+ ASSERT_EQ(builder.capacity(), 1024);
+
+ ASSERT_OK_AND_ASSIGN(auto buf, builder.Finish(shrink_to_fit));
+ ASSERT_EQ(buf->size(), 18);
+ ASSERT_EQ(buf->capacity(), shrink_to_fit ? 64 : 1024);
+ }
+}
+
+TEST(TestBufferBuilder, FinishEmpty) {
+ for (const bool shrink_to_fit : {true, false}) {
+ ARROW_SCOPED_TRACE("shrink_to_fit = ", shrink_to_fit);
+ BufferBuilder builder;
+ ASSERT_EQ(0, builder.length());
+ ASSERT_EQ(0, builder.capacity());
+
+ ASSERT_OK_AND_ASSIGN(auto buf, builder.Finish(shrink_to_fit));
+ ASSERT_EQ(buf->size(), 0);
+ ASSERT_EQ(buf->capacity(), 0);
+ }
+ for (const bool shrink_to_fit : {true, false}) {
+ ARROW_SCOPED_TRACE("shrink_to_fit = ", shrink_to_fit);
+ BufferBuilder builder;
+ ASSERT_OK(builder.Reserve(1024));
+ ASSERT_EQ(0, builder.length());
+ ASSERT_EQ(1024, builder.capacity());
+
+ ASSERT_OK_AND_ASSIGN(auto buf, builder.Finish(shrink_to_fit));
+ ASSERT_EQ(buf->size(), 0);
+ ASSERT_EQ(buf->capacity(), shrink_to_fit ? 0 : 1024);
+ }
+}
+
+template <typename T>
+class TypedTestBufferBuilder : public ::testing::Test {};
+
+using BufferBuilderElements = ::testing::Types<int16_t, uint32_t, double>;
+
+TYPED_TEST_SUITE(TypedTestBufferBuilder, BufferBuilderElements);
+
+TYPED_TEST(TypedTestBufferBuilder, BasicTypedBufferBuilderUsage) {
+ TypedBufferBuilder<TypeParam> builder;
+
+ ASSERT_OK(builder.Append(static_cast<TypeParam>(0)));
+ ASSERT_EQ(builder.length(), 1);
+ ASSERT_EQ(builder.capacity(), 64 / sizeof(TypeParam));
+
+ constexpr int nvalues = 4;
+ TypeParam values[nvalues];
+ for (int i = 0; i != nvalues; ++i) {
+ values[i] = static_cast<TypeParam>(i);
+ }
+ ASSERT_OK(builder.Append(values, nvalues));
+ ASSERT_EQ(builder.length(), nvalues + 1);
+
+ std::shared_ptr<Buffer> built;
+ ASSERT_OK(builder.Finish(&built));
+ AssertIsCPUBuffer(*built);
+
+ auto data = reinterpret_cast<const TypeParam*>(built->data());
+ ASSERT_EQ(data[0], static_cast<TypeParam>(0));
+ for (auto value : values) {
+ ++data;
+ ASSERT_EQ(*data, value);
+ }
+}
+
+TYPED_TEST(TypedTestBufferBuilder, AppendCopies) {
+ TypedBufferBuilder<TypeParam> builder;
+
+ ASSERT_OK(builder.Append(13, static_cast<TypeParam>(1)));
+ ASSERT_OK(builder.Append(17, static_cast<TypeParam>(0)));
+ ASSERT_EQ(builder.length(), 13 + 17);
+
+ std::shared_ptr<Buffer> built;
+ ASSERT_OK(builder.Finish(&built));
+
+ auto data = reinterpret_cast<const TypeParam*>(built->data());
+ for (int i = 0; i != 13 + 17; ++i, ++data) {
+ ASSERT_EQ(*data, static_cast<TypeParam>(i < 13)) << "index = " << i;
+ }
+}
+
+TEST(TestBoolBufferBuilder, Basics) {
+ TypedBufferBuilder<bool> builder;
+
+ ASSERT_OK(builder.Append(false));
+ ASSERT_EQ(builder.length(), 1);
+ ASSERT_EQ(builder.capacity(), 64 * 8);
+
+ constexpr int nvalues = 4;
+ uint8_t values[nvalues];
+ for (int i = 0; i != nvalues; ++i) {
+ values[i] = static_cast<uint8_t>(i);
+ }
+ ASSERT_OK(builder.Append(values, nvalues));
+ ASSERT_EQ(builder.length(), nvalues + 1);
+
+ ASSERT_EQ(builder.false_count(), 2);
+
+ std::shared_ptr<Buffer> built;
+ ASSERT_OK(builder.Finish(&built));
+ AssertIsCPUBuffer(*built);
+
+ ASSERT_EQ(BitUtil::GetBit(built->data(), 0), false);
+ for (int i = 0; i != nvalues; ++i) {
+ ASSERT_EQ(BitUtil::GetBit(built->data(), i + 1), static_cast<bool>(values[i]));
+ }
+
+ ASSERT_EQ(built->size(), BitUtil::BytesForBits(nvalues + 1));
+}
+
+TEST(TestBoolBufferBuilder, AppendCopies) {
+ TypedBufferBuilder<bool> builder;
+
+ ASSERT_OK(builder.Append(13, true));
+ ASSERT_OK(builder.Append(17, false));
+ ASSERT_EQ(builder.length(), 13 + 17);
+ ASSERT_EQ(builder.capacity(), 64 * 8);
+ ASSERT_EQ(builder.false_count(), 17);
+
+ std::shared_ptr<Buffer> built;
+ ASSERT_OK(builder.Finish(&built));
+ AssertIsCPUBuffer(*built);
+
+ for (int i = 0; i != 13 + 17; ++i) {
+ EXPECT_EQ(BitUtil::GetBit(built->data(), i), i < 13) << "index = " << i;
+ }
+
+ ASSERT_EQ(built->size(), BitUtil::BytesForBits(13 + 17));
+}
+
+TEST(TestBoolBufferBuilder, Reserve) {
+ TypedBufferBuilder<bool> builder;
+
+ ASSERT_OK(builder.Reserve(13 + 17));
+ builder.UnsafeAppend(13, true);
+ builder.UnsafeAppend(17, false);
+ ASSERT_EQ(builder.length(), 13 + 17);
+ ASSERT_EQ(builder.capacity(), 64 * 8);
+ ASSERT_EQ(builder.false_count(), 17);
+
+ ASSERT_OK_AND_ASSIGN(auto built, builder.Finish());
+ AssertIsCPUBuffer(*built);
+ ASSERT_EQ(built->size(), BitUtil::BytesForBits(13 + 17));
+}
+
+template <typename T>
+class TypedTestBuffer : public ::testing::Test {};
+
+using BufferPtrs =
+ ::testing::Types<std::shared_ptr<ResizableBuffer>, std::unique_ptr<ResizableBuffer>>;
+
+TYPED_TEST_SUITE(TypedTestBuffer, BufferPtrs);
+
+TYPED_TEST(TypedTestBuffer, IsMutableFlag) {
+ Buffer buf(nullptr, 0);
+
+ ASSERT_FALSE(buf.is_mutable());
+
+ MutableBuffer mbuf(nullptr, 0);
+ ASSERT_TRUE(mbuf.is_mutable());
+ AssertIsCPUBuffer(mbuf);
+
+ TypeParam pool_buf;
+ ASSERT_OK_AND_ASSIGN(pool_buf, AllocateResizableBuffer(0));
+ ASSERT_TRUE(pool_buf->is_mutable());
+ AssertIsCPUBuffer(*pool_buf);
+}
+
+TYPED_TEST(TypedTestBuffer, Resize) {
+ TypeParam buf;
+ ASSERT_OK_AND_ASSIGN(buf, AllocateResizableBuffer(0));
+ AssertIsCPUBuffer(*buf);
+
+ ASSERT_EQ(0, buf->size());
+ ASSERT_OK(buf->Resize(100));
+ ASSERT_EQ(100, buf->size());
+ ASSERT_OK(buf->Resize(200));
+ ASSERT_EQ(200, buf->size());
+
+ // Make it smaller, too
+ ASSERT_OK(buf->Resize(50, true));
+ ASSERT_EQ(50, buf->size());
+ // We have actually shrunken in size
+ // The spec requires that capacity is a multiple of 64
+ ASSERT_EQ(64, buf->capacity());
+
+ // Resize to a larger capacity again to test shrink_to_fit = false
+ ASSERT_OK(buf->Resize(100));
+ ASSERT_EQ(128, buf->capacity());
+ ASSERT_OK(buf->Resize(50, false));
+ ASSERT_EQ(128, buf->capacity());
+}
+
+TYPED_TEST(TypedTestBuffer, TypedResize) {
+ TypeParam buf;
+ ASSERT_OK_AND_ASSIGN(buf, AllocateResizableBuffer(0));
+
+ ASSERT_EQ(0, buf->size());
+ ASSERT_OK(buf->template TypedResize<double>(100));
+ ASSERT_EQ(800, buf->size());
+ ASSERT_OK(buf->template TypedResize<double>(200));
+ ASSERT_EQ(1600, buf->size());
+
+ ASSERT_OK(buf->template TypedResize<double>(50, true));
+ ASSERT_EQ(400, buf->size());
+ ASSERT_EQ(448, buf->capacity());
+
+ ASSERT_OK(buf->template TypedResize<double>(100));
+ ASSERT_EQ(832, buf->capacity());
+ ASSERT_OK(buf->template TypedResize<double>(50, false));
+ ASSERT_EQ(832, buf->capacity());
+}
+
+TYPED_TEST(TypedTestBuffer, ResizeOOM) {
+// This test doesn't play nice with AddressSanitizer
+#ifndef ADDRESS_SANITIZER
+ // realloc fails, even though there may be no explicit limit
+ TypeParam buf;
+ ASSERT_OK_AND_ASSIGN(buf, AllocateResizableBuffer(0));
+ ASSERT_OK(buf->Resize(100));
+ int64_t to_alloc = std::min<uint64_t>(std::numeric_limits<int64_t>::max(),
+ std::numeric_limits<size_t>::max());
+ // subtract 63 to prevent overflow after the size is aligned
+ to_alloc -= 63;
+ ASSERT_RAISES(OutOfMemory, buf->Resize(to_alloc));
+#endif
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/builder.cc b/src/arrow/cpp/src/arrow/builder.cc
new file mode 100644
index 000000000..7b7ec1706
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/builder.cc
@@ -0,0 +1,312 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/builder.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/hashing.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+class MemoryPool;
+
+// ----------------------------------------------------------------------
+// Helper functions
+
+using arrow::internal::checked_cast;
+
+// Generic int builder that delegates to the builder for a specific
+// type. Used to reduce the number of template instantiations in the
+// exact_index_type case below, to reduce build time and memory usage.
+class ARROW_EXPORT TypeErasedIntBuilder : public ArrayBuilder {
+ public:
+ explicit TypeErasedIntBuilder(MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool) {
+ // Not intended to be used, but adding this is easier than adding a bunch of enable_if
+ // magic to builder_dict.h
+ DCHECK(false);
+ }
+ explicit TypeErasedIntBuilder(const std::shared_ptr<DataType>& type,
+ MemoryPool* pool = default_memory_pool())
+ : ArrayBuilder(pool), type_id_(type->id()) {
+ DCHECK(is_integer(type_id_));
+ switch (type_id_) {
+ case Type::UINT8:
+ builder_ = internal::make_unique<UInt8Builder>(pool);
+ break;
+ case Type::INT8:
+ builder_ = internal::make_unique<Int8Builder>(pool);
+ break;
+ case Type::UINT16:
+ builder_ = internal::make_unique<UInt16Builder>(pool);
+ break;
+ case Type::INT16:
+ builder_ = internal::make_unique<Int16Builder>(pool);
+ break;
+ case Type::UINT32:
+ builder_ = internal::make_unique<UInt32Builder>(pool);
+ break;
+ case Type::INT32:
+ builder_ = internal::make_unique<Int32Builder>(pool);
+ break;
+ case Type::UINT64:
+ builder_ = internal::make_unique<UInt64Builder>(pool);
+ break;
+ case Type::INT64:
+ builder_ = internal::make_unique<Int64Builder>(pool);
+ break;
+ default:
+ DCHECK(false);
+ }
+ }
+
+ void Reset() override { return builder_->Reset(); }
+ Status Append(int32_t value) {
+ switch (type_id_) {
+ case Type::UINT8:
+ return checked_cast<UInt8Builder*>(builder_.get())->Append(value);
+ case Type::INT8:
+ return checked_cast<Int8Builder*>(builder_.get())->Append(value);
+ case Type::UINT16:
+ return checked_cast<UInt16Builder*>(builder_.get())->Append(value);
+ case Type::INT16:
+ return checked_cast<Int16Builder*>(builder_.get())->Append(value);
+ case Type::UINT32:
+ return checked_cast<UInt32Builder*>(builder_.get())->Append(value);
+ case Type::INT32:
+ return checked_cast<Int32Builder*>(builder_.get())->Append(value);
+ case Type::UINT64:
+ return checked_cast<UInt64Builder*>(builder_.get())->Append(value);
+ case Type::INT64:
+ return checked_cast<Int64Builder*>(builder_.get())->Append(value);
+ default:
+ DCHECK(false);
+ }
+ return Status::NotImplemented("Internal implementation error");
+ }
+ Status AppendNull() override { return builder_->AppendNull(); }
+ Status AppendNulls(int64_t length) override { return builder_->AppendNulls(length); }
+ Status AppendEmptyValue() override { return builder_->AppendEmptyValue(); }
+ Status AppendEmptyValues(int64_t length) override {
+ return builder_->AppendEmptyValues(length);
+ }
+
+ Status AppendScalar(const Scalar& scalar, int64_t n_repeats) override {
+ return builder_->AppendScalar(scalar, n_repeats);
+ }
+ Status AppendScalars(const ScalarVector& scalars) override {
+ return builder_->AppendScalars(scalars);
+ }
+ Status AppendArraySlice(const ArrayData& array, int64_t offset,
+ int64_t length) override {
+ return builder_->AppendArraySlice(array, offset, length);
+ }
+
+ Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
+ return builder_->FinishInternal(out);
+ }
+
+ std::shared_ptr<DataType> type() const override { return builder_->type(); }
+
+ private:
+ std::unique_ptr<ArrayBuilder> builder_;
+ Type::type type_id_;
+};
+
+struct DictionaryBuilderCase {
+ template <typename ValueType, typename Enable = typename ValueType::c_type>
+ Status Visit(const ValueType&) {
+ return CreateFor<ValueType>();
+ }
+
+ Status Visit(const NullType&) { return CreateFor<NullType>(); }
+ Status Visit(const BinaryType&) { return CreateFor<BinaryType>(); }
+ Status Visit(const StringType&) { return CreateFor<StringType>(); }
+ Status Visit(const LargeBinaryType&) { return CreateFor<LargeBinaryType>(); }
+ Status Visit(const LargeStringType&) { return CreateFor<LargeStringType>(); }
+ Status Visit(const FixedSizeBinaryType&) { return CreateFor<FixedSizeBinaryType>(); }
+ Status Visit(const Decimal128Type&) { return CreateFor<Decimal128Type>(); }
+ Status Visit(const Decimal256Type&) { return CreateFor<Decimal256Type>(); }
+
+ Status Visit(const DataType& value_type) { return NotImplemented(value_type); }
+ Status Visit(const HalfFloatType& value_type) { return NotImplemented(value_type); }
+ Status NotImplemented(const DataType& value_type) {
+ return Status::NotImplemented(
+ "MakeBuilder: cannot construct builder for dictionaries with value type ",
+ value_type);
+ }
+
+ template <typename ValueType>
+ Status CreateFor() {
+ using AdaptiveBuilderType = DictionaryBuilder<ValueType>;
+ if (dictionary != nullptr) {
+ out->reset(new AdaptiveBuilderType(dictionary, pool));
+ } else if (exact_index_type) {
+ if (!is_integer(index_type->id())) {
+ return Status::TypeError("MakeBuilder: invalid index type ", *index_type);
+ }
+ out->reset(new internal::DictionaryBuilderBase<TypeErasedIntBuilder, ValueType>(
+ index_type, value_type, pool));
+ } else {
+ auto start_int_size = internal::GetByteWidth(*index_type);
+ out->reset(new AdaptiveBuilderType(start_int_size, value_type, pool));
+ }
+ return Status::OK();
+ }
+
+ Status Make() { return VisitTypeInline(*value_type, this); }
+
+ MemoryPool* pool;
+ const std::shared_ptr<DataType>& index_type;
+ const std::shared_ptr<DataType>& value_type;
+ const std::shared_ptr<Array>& dictionary;
+ bool exact_index_type;
+ std::unique_ptr<ArrayBuilder>* out;
+};
+
+struct MakeBuilderImpl {
+ template <typename T>
+ enable_if_not_nested<T, Status> Visit(const T&) {
+ out.reset(new typename TypeTraits<T>::BuilderType(type, pool));
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryType& dict_type) {
+ DictionaryBuilderCase visitor = {pool,
+ dict_type.index_type(),
+ dict_type.value_type(),
+ /*dictionary=*/nullptr,
+ exact_index_type,
+ &out};
+ return visitor.Make();
+ }
+
+ Status Visit(const ListType& list_type) {
+ std::shared_ptr<DataType> value_type = list_type.value_type();
+ ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type));
+ out.reset(new ListBuilder(pool, std::move(value_builder), type));
+ return Status::OK();
+ }
+
+ Status Visit(const LargeListType& list_type) {
+ std::shared_ptr<DataType> value_type = list_type.value_type();
+ ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type));
+ out.reset(new LargeListBuilder(pool, std::move(value_builder), type));
+ return Status::OK();
+ }
+
+ Status Visit(const MapType& map_type) {
+ ARROW_ASSIGN_OR_RAISE(auto key_builder, ChildBuilder(map_type.key_type()));
+ ARROW_ASSIGN_OR_RAISE(auto item_builder, ChildBuilder(map_type.item_type()));
+ out.reset(
+ new MapBuilder(pool, std::move(key_builder), std::move(item_builder), type));
+ return Status::OK();
+ }
+
+ Status Visit(const FixedSizeListType& list_type) {
+ auto value_type = list_type.value_type();
+ ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type));
+ out.reset(new FixedSizeListBuilder(pool, std::move(value_builder), type));
+ return Status::OK();
+ }
+
+ Status Visit(const StructType& struct_type) {
+ ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
+ out.reset(new StructBuilder(type, pool, std::move(field_builders)));
+ return Status::OK();
+ }
+
+ Status Visit(const SparseUnionType&) {
+ ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
+ out.reset(new SparseUnionBuilder(pool, std::move(field_builders), type));
+ return Status::OK();
+ }
+
+ Status Visit(const DenseUnionType&) {
+ ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
+ out.reset(new DenseUnionBuilder(pool, std::move(field_builders), type));
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionType&) { return NotImplemented(); }
+ Status Visit(const DataType&) { return NotImplemented(); }
+
+ Status NotImplemented() {
+ return Status::NotImplemented("MakeBuilder: cannot construct builder for type ",
+ type->ToString());
+ }
+
+ Result<std::unique_ptr<ArrayBuilder>> ChildBuilder(
+ const std::shared_ptr<DataType>& type) {
+ MakeBuilderImpl impl{pool, type, exact_index_type, /*out=*/nullptr};
+ RETURN_NOT_OK(VisitTypeInline(*type, &impl));
+ return std::move(impl.out);
+ }
+
+ Result<std::vector<std::shared_ptr<ArrayBuilder>>> FieldBuilders(const DataType& type,
+ MemoryPool* pool) {
+ std::vector<std::shared_ptr<ArrayBuilder>> field_builders;
+ for (const auto& field : type.fields()) {
+ std::unique_ptr<ArrayBuilder> builder;
+ MakeBuilderImpl impl{pool, field->type(), exact_index_type, /*out=*/nullptr};
+ RETURN_NOT_OK(VisitTypeInline(*field->type(), &impl));
+ field_builders.emplace_back(std::move(impl.out));
+ }
+ return field_builders;
+ }
+
+ MemoryPool* pool;
+ const std::shared_ptr<DataType>& type;
+ bool exact_index_type;
+ std::unique_ptr<ArrayBuilder> out;
+};
+
+Status MakeBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
+ std::unique_ptr<ArrayBuilder>* out) {
+ MakeBuilderImpl impl{pool, type, /*exact_index_type=*/false, /*out=*/nullptr};
+ RETURN_NOT_OK(VisitTypeInline(*type, &impl));
+ *out = std::move(impl.out);
+ return Status::OK();
+}
+
+Status MakeBuilderExactIndex(MemoryPool* pool, const std::shared_ptr<DataType>& type,
+ std::unique_ptr<ArrayBuilder>* out) {
+ MakeBuilderImpl impl{pool, type, /*exact_index_type=*/true, /*out=*/nullptr};
+ RETURN_NOT_OK(VisitTypeInline(*type, &impl));
+ *out = std::move(impl.out);
+ return Status::OK();
+}
+
+Status MakeDictionaryBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<Array>& dictionary,
+ std::unique_ptr<ArrayBuilder>* out) {
+ const auto& dict_type = static_cast<const DictionaryType&>(*type);
+ DictionaryBuilderCase visitor = {
+ pool, dict_type.index_type(), dict_type.value_type(),
+ dictionary, /*exact_index_type=*/false, out};
+ return visitor.Make();
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/builder.h b/src/arrow/cpp/src/arrow/builder.h
new file mode 100644
index 000000000..4b80e5580
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/builder.h
@@ -0,0 +1,32 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/array/builder_adaptive.h" // IWYU pragma: keep
+#include "arrow/array/builder_base.h" // IWYU pragma: keep
+#include "arrow/array/builder_binary.h" // IWYU pragma: keep
+#include "arrow/array/builder_decimal.h" // IWYU pragma: keep
+#include "arrow/array/builder_dict.h" // IWYU pragma: keep
+#include "arrow/array/builder_nested.h" // IWYU pragma: keep
+#include "arrow/array/builder_primitive.h" // IWYU pragma: keep
+#include "arrow/array/builder_time.h" // IWYU pragma: keep
+#include "arrow/array/builder_union.h" // IWYU pragma: keep
+#include "arrow/status.h"
+#include "arrow/util/visibility.h"
diff --git a/src/arrow/cpp/src/arrow/builder_benchmark.cc b/src/arrow/cpp/src/arrow/builder_benchmark.cc
new file mode 100644
index 000000000..d0edb4b2d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/builder_benchmark.cc
@@ -0,0 +1,453 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <numeric>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/builder.h"
+#include "arrow/memory_pool.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+
+using ValueType = int64_t;
+using VectorType = std::vector<ValueType>;
+constexpr int64_t kNumberOfElements = 256 * 512;
+
+static VectorType AlmostU8CompressibleVector() {
+ VectorType data(kNumberOfElements, 64);
+
+ // Insert an element late in the game that does not fit in the 8bit
+ // representation. This forces AdaptiveIntBuilder's to resize.
+ data[kNumberOfElements - 2] = 1L << 13;
+
+ return data;
+}
+
+constexpr int64_t kRounds = 256;
+static VectorType kData = AlmostU8CompressibleVector();
+constexpr int64_t kBytesProcessPerRound = kNumberOfElements * sizeof(ValueType);
+constexpr int64_t kBytesProcessed = kRounds * kBytesProcessPerRound;
+
+static const char* kBinaryString = "12345678";
+static arrow::util::string_view kBinaryView(kBinaryString);
+
+static void BuildIntArrayNoNulls(benchmark::State& state) { // NOLINT non-const reference
+ for (auto _ : state) {
+ Int64Builder builder;
+
+ for (int i = 0; i < kRounds; i++) {
+ ABORT_NOT_OK(builder.AppendValues(kData.data(), kData.size(), nullptr));
+ }
+
+ std::shared_ptr<Array> out;
+ ABORT_NOT_OK(builder.Finish(&out));
+ }
+
+ state.SetBytesProcessed(state.iterations() * kBytesProcessed);
+}
+
+static void BuildAdaptiveIntNoNulls(
+ benchmark::State& state) { // NOLINT non-const reference
+ for (auto _ : state) {
+ AdaptiveIntBuilder builder;
+
+ for (int i = 0; i < kRounds; i++) {
+ ABORT_NOT_OK(builder.AppendValues(kData.data(), kData.size(), nullptr));
+ }
+
+ std::shared_ptr<Array> out;
+ ABORT_NOT_OK(builder.Finish(&out));
+ }
+
+ state.SetBytesProcessed(state.iterations() * kBytesProcessed);
+}
+
+static void BuildAdaptiveIntNoNullsScalarAppend(
+ benchmark::State& state) { // NOLINT non-const reference
+ for (auto _ : state) {
+ AdaptiveIntBuilder builder;
+
+ for (int i = 0; i < kRounds; i++) {
+ for (size_t j = 0; j < kData.size(); j++) {
+ ABORT_NOT_OK(builder.Append(kData[i]))
+ }
+ }
+
+ std::shared_ptr<Array> out;
+ ABORT_NOT_OK(builder.Finish(&out));
+ }
+
+ state.SetBytesProcessed(state.iterations() * kBytesProcessed);
+}
+
+static void BuildBooleanArrayNoNulls(
+ benchmark::State& state) { // NOLINT non-const reference
+
+ size_t n_bytes = kBytesProcessPerRound;
+ const uint8_t* data = reinterpret_cast<const uint8_t*>(kData.data());
+
+ for (auto _ : state) {
+ BooleanBuilder builder;
+
+ for (int i = 0; i < kRounds; i++) {
+ ABORT_NOT_OK(builder.AppendValues(data, n_bytes));
+ }
+
+ std::shared_ptr<Array> out;
+ ABORT_NOT_OK(builder.Finish(&out));
+ }
+
+ state.SetBytesProcessed(state.iterations() * kBytesProcessed);
+}
+
+static void BuildBinaryArray(benchmark::State& state) { // NOLINT non-const reference
+ for (auto _ : state) {
+ BinaryBuilder builder;
+
+ for (int64_t i = 0; i < kRounds * kNumberOfElements; i++) {
+ ABORT_NOT_OK(builder.Append(kBinaryView));
+ }
+
+ std::shared_ptr<Array> out;
+ ABORT_NOT_OK(builder.Finish(&out));
+ }
+
+ state.SetBytesProcessed(state.iterations() * kBytesProcessed);
+}
+
+static void BuildChunkedBinaryArray(
+ benchmark::State& state) { // NOLINT non-const reference
+ // 1MB chunks
+ const int32_t kChunkSize = 1 << 20;
+
+ for (auto _ : state) {
+ internal::ChunkedBinaryBuilder builder(kChunkSize);
+
+ for (int64_t i = 0; i < kRounds * kNumberOfElements; i++) {
+ ABORT_NOT_OK(builder.Append(kBinaryView));
+ }
+
+ ArrayVector out;
+ ABORT_NOT_OK(builder.Finish(&out));
+ }
+
+ state.SetBytesProcessed(state.iterations() * kBytesProcessed);
+}
+
+static void BuildFixedSizeBinaryArray(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto type = fixed_size_binary(static_cast<int32_t>(kBinaryView.size()));
+
+ for (auto _ : state) {
+ FixedSizeBinaryBuilder builder(type);
+
+ for (int64_t i = 0; i < kRounds * kNumberOfElements; i++) {
+ ABORT_NOT_OK(builder.Append(kBinaryView));
+ }
+
+ std::shared_ptr<Array> out;
+ ABORT_NOT_OK(builder.Finish(&out));
+ }
+
+ state.SetBytesProcessed(state.iterations() * kBytesProcessed);
+}
+
+static void BuildDecimalArray(benchmark::State& state) { // NOLINT non-const reference
+ auto type = decimal(10, 5);
+ Decimal128 value;
+ int32_t precision = 0;
+ int32_t scale = 0;
+ ABORT_NOT_OK(Decimal128::FromString("1234.1234", &value, &precision, &scale));
+ for (auto _ : state) {
+ Decimal128Builder builder(type);
+
+ for (int64_t i = 0; i < kRounds * kNumberOfElements; i++) {
+ ABORT_NOT_OK(builder.Append(value));
+ }
+
+ std::shared_ptr<Array> out;
+ ABORT_NOT_OK(builder.Finish(&out));
+ }
+
+ state.SetBytesProcessed(state.iterations() * kRounds * kNumberOfElements * 16);
+}
+
+// ----------------------------------------------------------------------
+// DictionaryBuilder benchmarks
+
+size_t kDistinctElements = kNumberOfElements / 100;
+
+// Testing with different distributions of integer values helps stress
+// the hash table's robustness.
+
+// Make a vector out of `n_distinct` sequential int values
+template <class Integer = ValueType>
+static std::vector<Integer> MakeSequentialIntDictFodder() {
+ std::default_random_engine gen(42);
+ std::vector<Integer> values(kNumberOfElements);
+ {
+ std::uniform_int_distribution<Integer> values_dist(0, kDistinctElements - 1);
+ std::generate(values.begin(), values.end(), [&]() { return values_dist(gen); });
+ }
+ return values;
+}
+
+// Make a vector out of `n_distinct` int values with potentially colliding hash
+// entries as only their highest bits differ.
+template <class Integer = ValueType>
+static std::vector<Integer> MakeSimilarIntDictFodder() {
+ std::default_random_engine gen(42);
+ std::vector<Integer> values(kNumberOfElements);
+ {
+ std::uniform_int_distribution<Integer> values_dist(0, kDistinctElements - 1);
+ auto max_int = std::numeric_limits<Integer>::max();
+ auto multiplier =
+ static_cast<Integer>(BitUtil::NextPower2(max_int / kDistinctElements / 2));
+ std::generate(values.begin(), values.end(),
+ [&]() { return multiplier * values_dist(gen); });
+ }
+ return values;
+}
+
+// Make a vector out of `n_distinct` random int values
+template <class Integer = ValueType>
+static std::vector<Integer> MakeRandomIntDictFodder() {
+ std::default_random_engine gen(42);
+ std::vector<Integer> values_dict(kDistinctElements);
+ std::vector<Integer> values(kNumberOfElements);
+
+ {
+ std::uniform_int_distribution<Integer> values_dist(
+ 0, std::numeric_limits<Integer>::max());
+ std::generate(values_dict.begin(), values_dict.end(),
+ [&]() { return static_cast<Integer>(values_dist(gen)); });
+ }
+ {
+ std::uniform_int_distribution<int32_t> indices_dist(
+ 0, static_cast<int32_t>(kDistinctElements - 1));
+ std::generate(values.begin(), values.end(),
+ [&]() { return values_dict[indices_dist(gen)]; });
+ }
+ return values;
+}
+
+// Make a vector out of `kDistinctElements` string values
+static std::vector<std::string> MakeStringDictFodder() {
+ std::default_random_engine gen(42);
+ std::vector<std::string> values_dict(kDistinctElements);
+ std::vector<std::string> values(kNumberOfElements);
+
+ {
+ auto it = values_dict.begin();
+ // Add empty string
+ *it++ = "";
+ // Add a few similar strings
+ *it++ = "abc";
+ *it++ = "abcdef";
+ *it++ = "abcfgh";
+ // Add random strings
+ std::uniform_int_distribution<int32_t> length_dist(2, 20);
+ std::independent_bits_engine<std::default_random_engine, 8, uint16_t> bytes_gen(42);
+
+ std::generate(it, values_dict.end(), [&] {
+ auto length = length_dist(gen);
+ std::string s(length, 'X');
+ for (int32_t i = 0; i < length; ++i) {
+ s[i] = static_cast<char>(bytes_gen());
+ }
+ return s;
+ });
+ }
+ {
+ std::uniform_int_distribution<int32_t> indices_dist(
+ 0, static_cast<int32_t>(kDistinctElements - 1));
+ std::generate(values.begin(), values.end(),
+ [&] { return values_dict[indices_dist(gen)]; });
+ }
+ return values;
+}
+
+template <class DictionaryBuilderType, class Scalar>
+static void BenchmarkDictionaryArray(
+ benchmark::State& state, // NOLINT non-const reference
+ const std::vector<Scalar>& fodder, size_t fodder_nbytes = 0) {
+ for (auto _ : state) {
+ DictionaryBuilderType builder(default_memory_pool());
+
+ for (int64_t i = 0; i < kRounds; i++) {
+ for (const auto& value : fodder) {
+ ABORT_NOT_OK(builder.Append(value));
+ }
+ }
+
+ std::shared_ptr<Array> out;
+ ABORT_NOT_OK(builder.Finish(&out));
+ }
+
+ if (fodder_nbytes == 0) {
+ fodder_nbytes = fodder.size() * sizeof(Scalar);
+ }
+ state.SetBytesProcessed(state.iterations() * fodder_nbytes * kRounds);
+}
+
+static void BuildInt64DictionaryArrayRandom(
+ benchmark::State& state) { // NOLINT non-const reference
+ const auto fodder = MakeRandomIntDictFodder();
+ BenchmarkDictionaryArray<DictionaryBuilder<Int64Type>>(state, fodder);
+}
+
+static void BuildInt64DictionaryArraySequential(
+ benchmark::State& state) { // NOLINT non-const reference
+ const auto fodder = MakeSequentialIntDictFodder();
+ BenchmarkDictionaryArray<DictionaryBuilder<Int64Type>>(state, fodder);
+}
+
+static void BuildInt64DictionaryArraySimilar(
+ benchmark::State& state) { // NOLINT non-const reference
+ const auto fodder = MakeSimilarIntDictFodder();
+ BenchmarkDictionaryArray<DictionaryBuilder<Int64Type>>(state, fodder);
+}
+
+static void BuildStringDictionaryArray(
+ benchmark::State& state) { // NOLINT non-const reference
+ const auto fodder = MakeStringDictFodder();
+ auto fodder_nbytes =
+ std::accumulate(fodder.begin(), fodder.end(), 0ULL,
+ [&](size_t acc, const std::string& s) { return acc + s.size(); });
+ BenchmarkDictionaryArray<BinaryDictionaryBuilder>(state, fodder, fodder_nbytes);
+}
+
+static void ArrayDataConstructDestruct(
+ benchmark::State& state) { // NOLINT non-const reference
+ std::vector<std::shared_ptr<ArrayData>> arrays;
+
+ const int kNumArrays = 1000;
+ auto InitArrays = [&]() {
+ for (int i = 0; i < kNumArrays; ++i) {
+ arrays.emplace_back(new ArrayData);
+ }
+ };
+
+ for (auto _ : state) {
+ InitArrays();
+ arrays.clear();
+ }
+}
+
+// ----------------------------------------------------------------------
+// BufferBuilder benchmarks
+
+static void BenchmarkBufferBuilder(
+ const std::string& datum,
+ benchmark::State& state) { // NOLINT non-const reference
+ const void* raw_data = datum.data();
+ int64_t raw_nbytes = static_cast<int64_t>(datum.size());
+ // Write approx. 256 MB to BufferBuilder
+ int64_t num_raw_values = (1 << 28) / raw_nbytes;
+ for (auto _ : state) {
+ BufferBuilder builder;
+ std::shared_ptr<Buffer> buf;
+ for (int64_t i = 0; i < num_raw_values; ++i) {
+ ABORT_NOT_OK(builder.Append(raw_data, raw_nbytes));
+ }
+ ABORT_NOT_OK(builder.Finish(&buf));
+ }
+ state.SetBytesProcessed(int64_t(state.iterations()) * num_raw_values * raw_nbytes);
+}
+
+static void BufferBuilderTinyWrites(
+ benchmark::State& state) { // NOLINT non-const reference
+ // A 8-byte datum
+ return BenchmarkBufferBuilder("abdefghi", state);
+}
+
+static void BufferBuilderSmallWrites(
+ benchmark::State& state) { // NOLINT non-const reference
+ // A 700-byte datum
+ std::string datum;
+ for (int i = 0; i < 100; ++i) {
+ datum += "abcdefg";
+ }
+ return BenchmarkBufferBuilder(datum, state);
+}
+
+static void BufferBuilderLargeWrites(
+ benchmark::State& state) { // NOLINT non-const reference
+ // A 1.5MB datum
+ std::string datum(1500000, 'x');
+ return BenchmarkBufferBuilder(datum, state);
+}
+
+BENCHMARK(BufferBuilderTinyWrites)->UseRealTime();
+BENCHMARK(BufferBuilderSmallWrites)->UseRealTime();
+BENCHMARK(BufferBuilderLargeWrites)->UseRealTime();
+
+// ----------------------------------------------------------------------
+// Benchmark declarations
+//
+
+#ifdef ARROW_WITH_BENCHMARKS_REFERENCE
+
+// This benchmarks acts as a reference to the native std::vector
+// implementation. It appends kRounds chunks into a vector.
+static void ReferenceBuildVectorNoNulls(
+ benchmark::State& state) { // NOLINT non-const reference
+ for (auto _ : state) {
+ std::vector<int64_t> builder;
+
+ for (int i = 0; i < kRounds; i++) {
+ builder.insert(builder.end(), kData.cbegin(), kData.cend());
+ }
+ }
+
+ state.SetBytesProcessed(state.iterations() * kBytesProcessed);
+}
+
+BENCHMARK(ReferenceBuildVectorNoNulls);
+
+#endif
+
+BENCHMARK(BuildBooleanArrayNoNulls);
+
+BENCHMARK(BuildIntArrayNoNulls);
+BENCHMARK(BuildAdaptiveIntNoNulls);
+BENCHMARK(BuildAdaptiveIntNoNullsScalarAppend);
+
+BENCHMARK(BuildBinaryArray);
+BENCHMARK(BuildChunkedBinaryArray);
+BENCHMARK(BuildFixedSizeBinaryArray);
+BENCHMARK(BuildDecimalArray);
+
+BENCHMARK(BuildInt64DictionaryArrayRandom);
+BENCHMARK(BuildInt64DictionaryArraySequential);
+BENCHMARK(BuildInt64DictionaryArraySimilar);
+BENCHMARK(BuildStringDictionaryArray);
+
+BENCHMARK(ArrayDataConstructDestruct);
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/c/CMakeLists.txt b/src/arrow/cpp/src/arrow/c/CMakeLists.txt
new file mode 100644
index 000000000..3765477ba
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/c/CMakeLists.txt
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+add_arrow_test(bridge_test PREFIX "arrow-c")
+
+add_arrow_benchmark(bridge_benchmark)
+
+arrow_install_all_headers("arrow/c")
diff --git a/src/arrow/cpp/src/arrow/c/abi.h b/src/arrow/cpp/src/arrow/c/abi.h
new file mode 100644
index 000000000..a78170dbd
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/c/abi.h
@@ -0,0 +1,103 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define ARROW_FLAG_DICTIONARY_ORDERED 1
+#define ARROW_FLAG_NULLABLE 2
+#define ARROW_FLAG_MAP_KEYS_SORTED 4
+
+struct ArrowSchema {
+ // Array type description
+ const char* format;
+ const char* name;
+ const char* metadata;
+ int64_t flags;
+ int64_t n_children;
+ struct ArrowSchema** children;
+ struct ArrowSchema* dictionary;
+
+ // Release callback
+ void (*release)(struct ArrowSchema*);
+ // Opaque producer-specific data
+ void* private_data;
+};
+
+struct ArrowArray {
+ // Array data description
+ int64_t length;
+ int64_t null_count;
+ int64_t offset;
+ int64_t n_buffers;
+ int64_t n_children;
+ const void** buffers;
+ struct ArrowArray** children;
+ struct ArrowArray* dictionary;
+
+ // Release callback
+ void (*release)(struct ArrowArray*);
+ // Opaque producer-specific data
+ void* private_data;
+};
+
+// EXPERIMENTAL: C stream interface
+
+struct ArrowArrayStream {
+ // Callback to get the stream type
+ // (will be the same for all arrays in the stream).
+ //
+ // Return value: 0 if successful, an `errno`-compatible error code otherwise.
+ //
+ // If successful, the ArrowSchema must be released independently from the stream.
+ int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out);
+
+ // Callback to get the next array
+ // (if no error and the array is released, the stream has ended)
+ //
+ // Return value: 0 if successful, an `errno`-compatible error code otherwise.
+ //
+ // If successful, the ArrowArray must be released independently from the stream.
+ int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out);
+
+ // Callback to get optional detailed error information.
+ // This must only be called if the last stream operation failed
+ // with a non-0 return code.
+ //
+ // Return value: pointer to a null-terminated character array describing
+ // the last error, or NULL if no description is available.
+ //
+ // The returned pointer is only valid until the next operation on this stream
+ // (including release).
+ const char* (*get_last_error)(struct ArrowArrayStream*);
+
+ // Release callback: release the stream's own resources.
+ // Note that arrays returned by `get_next` must be individually released.
+ void (*release)(struct ArrowArrayStream*);
+
+ // Opaque producer-specific data
+ void* private_data;
+};
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/src/arrow/cpp/src/arrow/c/bridge.cc b/src/arrow/cpp/src/arrow/c/bridge.cc
new file mode 100644
index 000000000..e5bfad810
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/c/bridge.cc
@@ -0,0 +1,1818 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/c/bridge.h"
+
+#include <algorithm>
+#include <cerrno>
+#include <cstring>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/c/helpers.h"
+#include "arrow/c/util_internal.h"
+#include "arrow/extension_type.h"
+#include "arrow/memory_pool.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/stl_allocator.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/small_vector.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/value_parsing.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+using internal::SmallVector;
+using internal::StaticVector;
+
+using internal::ArrayExportGuard;
+using internal::ArrayExportTraits;
+using internal::SchemaExportGuard;
+using internal::SchemaExportTraits;
+
+namespace {
+
+Status ExportingNotImplemented(const DataType& type) {
+ return Status::NotImplemented("Exporting ", type.ToString(), " array not supported");
+}
+
+// Allocate exported private data using MemoryPool,
+// to allow accounting memory and checking for memory leaks.
+
+// XXX use Gandiva's SimpleArena?
+
+template <typename Derived>
+struct PoolAllocationMixin {
+ static void* operator new(size_t size) {
+ DCHECK_EQ(size, sizeof(Derived));
+ uint8_t* data;
+ ARROW_CHECK_OK(default_memory_pool()->Allocate(static_cast<int64_t>(size), &data));
+ return data;
+ }
+
+ static void operator delete(void* ptr) {
+ default_memory_pool()->Free(reinterpret_cast<uint8_t*>(ptr), sizeof(Derived));
+ }
+};
+
+//////////////////////////////////////////////////////////////////////////
+// C schema export
+
+struct ExportedSchemaPrivateData : PoolAllocationMixin<ExportedSchemaPrivateData> {
+ std::string format_;
+ std::string name_;
+ std::string metadata_;
+ struct ArrowSchema dictionary_;
+ SmallVector<struct ArrowSchema, 1> children_;
+ SmallVector<struct ArrowSchema*, 4> child_pointers_;
+
+ ExportedSchemaPrivateData() = default;
+ ARROW_DEFAULT_MOVE_AND_ASSIGN(ExportedSchemaPrivateData);
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ExportedSchemaPrivateData);
+};
+
+void ReleaseExportedSchema(struct ArrowSchema* schema) {
+ if (ArrowSchemaIsReleased(schema)) {
+ return;
+ }
+ for (int64_t i = 0; i < schema->n_children; ++i) {
+ struct ArrowSchema* child = schema->children[i];
+ ArrowSchemaRelease(child);
+ DCHECK(ArrowSchemaIsReleased(child))
+ << "Child release callback should have marked it released";
+ }
+ struct ArrowSchema* dict = schema->dictionary;
+ if (dict != nullptr) {
+ ArrowSchemaRelease(dict);
+ DCHECK(ArrowSchemaIsReleased(dict))
+ << "Dictionary release callback should have marked it released";
+ }
+ DCHECK_NE(schema->private_data, nullptr);
+ delete reinterpret_cast<ExportedSchemaPrivateData*>(schema->private_data);
+
+ ArrowSchemaMarkReleased(schema);
+}
+
+template <typename SizeType>
+Result<int32_t> DowncastMetadataSize(SizeType size) {
+ auto res = static_cast<int32_t>(size);
+ if (res < 0 || static_cast<SizeType>(res) != size) {
+ return Status::Invalid("Metadata too large (more than 2**31 items or bytes)");
+ }
+ return res;
+}
+
+Result<std::string> EncodeMetadata(const KeyValueMetadata& metadata) {
+ ARROW_ASSIGN_OR_RAISE(auto npairs, DowncastMetadataSize(metadata.size()));
+ std::string exported;
+
+ // Pre-compute total string size
+ size_t total_size = 4;
+ for (int32_t i = 0; i < npairs; ++i) {
+ total_size += 8 + metadata.key(i).length() + metadata.value(i).length();
+ }
+ exported.resize(total_size);
+
+ char* data_start = &exported[0];
+ char* data = data_start;
+ auto write_int32 = [&](int32_t v) -> void {
+ memcpy(data, &v, 4);
+ data += 4;
+ };
+ auto write_string = [&](const std::string& s) -> Status {
+ ARROW_ASSIGN_OR_RAISE(auto len, DowncastMetadataSize(s.length()));
+ write_int32(len);
+ if (len > 0) {
+ memcpy(data, s.data(), len);
+ data += len;
+ }
+ return Status::OK();
+ };
+
+ write_int32(npairs);
+ for (int32_t i = 0; i < npairs; ++i) {
+ RETURN_NOT_OK(write_string(metadata.key(i)));
+ RETURN_NOT_OK(write_string(metadata.value(i)));
+ }
+ DCHECK_EQ(static_cast<size_t>(data - data_start), total_size);
+ return exported;
+}
+
+struct SchemaExporter {
+ Status ExportField(const Field& field) {
+ export_.name_ = field.name();
+ flags_ = field.nullable() ? ARROW_FLAG_NULLABLE : 0;
+
+ const DataType* type = UnwrapExtension(field.type().get());
+ RETURN_NOT_OK(ExportFormat(*type));
+ RETURN_NOT_OK(ExportChildren(type->fields()));
+ RETURN_NOT_OK(ExportMetadata(field.metadata().get()));
+ return Status::OK();
+ }
+
+ Status ExportType(const DataType& orig_type) {
+ flags_ = ARROW_FLAG_NULLABLE;
+
+ const DataType* type = UnwrapExtension(&orig_type);
+ RETURN_NOT_OK(ExportFormat(*type));
+ RETURN_NOT_OK(ExportChildren(type->fields()));
+ // There may be additional metadata to export
+ RETURN_NOT_OK(ExportMetadata(nullptr));
+ return Status::OK();
+ }
+
+ Status ExportSchema(const Schema& schema) {
+ static const StructType dummy_struct_type({});
+ flags_ = 0;
+
+ RETURN_NOT_OK(ExportFormat(dummy_struct_type));
+ RETURN_NOT_OK(ExportChildren(schema.fields()));
+ RETURN_NOT_OK(ExportMetadata(schema.metadata().get()));
+ return Status::OK();
+ }
+
+ // Finalize exporting by setting C struct fields and allocating
+ // autonomous private data for each schema node.
+ //
+ // This function can't fail, as properly reclaiming memory in case of error
+ // would be too fragile. After this function returns, memory is reclaimed
+ // by calling the release() pointer in the top level ArrowSchema struct.
+ void Finish(struct ArrowSchema* c_struct) {
+ // First, create permanent ExportedSchemaPrivateData
+ auto pdata = new ExportedSchemaPrivateData(std::move(export_));
+
+ // Second, finish dictionary and children.
+ if (dict_exporter_) {
+ dict_exporter_->Finish(&pdata->dictionary_);
+ }
+ pdata->child_pointers_.resize(child_exporters_.size(), nullptr);
+ for (size_t i = 0; i < child_exporters_.size(); ++i) {
+ auto ptr = pdata->child_pointers_[i] = &pdata->children_[i];
+ child_exporters_[i].Finish(ptr);
+ }
+
+ // Third, fill C struct.
+ DCHECK_NE(c_struct, nullptr);
+ memset(c_struct, 0, sizeof(*c_struct));
+
+ c_struct->format = pdata->format_.c_str();
+ c_struct->name = pdata->name_.c_str();
+ c_struct->metadata = pdata->metadata_.empty() ? nullptr : pdata->metadata_.c_str();
+ c_struct->flags = flags_;
+
+ c_struct->n_children = static_cast<int64_t>(child_exporters_.size());
+ c_struct->children = c_struct->n_children ? pdata->child_pointers_.data() : nullptr;
+ c_struct->dictionary = dict_exporter_ ? &pdata->dictionary_ : nullptr;
+ c_struct->private_data = pdata;
+ c_struct->release = ReleaseExportedSchema;
+ }
+
+ const DataType* UnwrapExtension(const DataType* type) {
+ if (type->id() == Type::EXTENSION) {
+ const auto& ext_type = checked_cast<const ExtensionType&>(*type);
+ additional_metadata_.reserve(2);
+ additional_metadata_.emplace_back(kExtensionTypeKeyName, ext_type.extension_name());
+ additional_metadata_.emplace_back(kExtensionMetadataKeyName, ext_type.Serialize());
+ return ext_type.storage_type().get();
+ }
+ return type;
+ }
+
+ Status ExportFormat(const DataType& type) {
+ if (type.id() == Type::DICTIONARY) {
+ const auto& dict_type = checked_cast<const DictionaryType&>(type);
+ if (dict_type.ordered()) {
+ flags_ |= ARROW_FLAG_DICTIONARY_ORDERED;
+ }
+ // Dictionary type: parent struct describes index type,
+ // child dictionary struct describes value type.
+ RETURN_NOT_OK(VisitTypeInline(*dict_type.index_type(), this));
+ dict_exporter_.reset(new SchemaExporter());
+ RETURN_NOT_OK(dict_exporter_->ExportType(*dict_type.value_type()));
+ } else {
+ RETURN_NOT_OK(VisitTypeInline(type, this));
+ }
+ DCHECK(!export_.format_.empty());
+ return Status::OK();
+ }
+
+ Status ExportChildren(const std::vector<std::shared_ptr<Field>>& fields) {
+ export_.children_.resize(fields.size());
+ child_exporters_.resize(fields.size());
+ for (size_t i = 0; i < fields.size(); ++i) {
+ RETURN_NOT_OK(child_exporters_[i].ExportField(*fields[i]));
+ }
+ return Status::OK();
+ }
+
+ Status ExportMetadata(const KeyValueMetadata* orig_metadata) {
+ static const KeyValueMetadata empty_metadata;
+
+ if (orig_metadata == nullptr) {
+ orig_metadata = &empty_metadata;
+ }
+ if (additional_metadata_.empty()) {
+ if (orig_metadata->size() > 0) {
+ ARROW_ASSIGN_OR_RAISE(export_.metadata_, EncodeMetadata(*orig_metadata));
+ }
+ return Status::OK();
+ }
+ // Additional metadata needs to be appended to the existing
+ // (for extension types)
+ KeyValueMetadata metadata(orig_metadata->keys(), orig_metadata->values());
+ for (const auto& kv : additional_metadata_) {
+ // The metadata may already be there => ignore
+ if (metadata.Contains(kv.first)) {
+ continue;
+ }
+ metadata.Append(kv.first, kv.second);
+ }
+ ARROW_ASSIGN_OR_RAISE(export_.metadata_, EncodeMetadata(metadata));
+ return Status::OK();
+ }
+
+ Status SetFormat(std::string s) {
+ export_.format_ = std::move(s);
+ return Status::OK();
+ }
+
+ // Type-specific visitors
+
+ Status Visit(const DataType& type) { return ExportingNotImplemented(type); }
+
+ Status Visit(const NullType& type) { return SetFormat("n"); }
+
+ Status Visit(const BooleanType& type) { return SetFormat("b"); }
+
+ Status Visit(const Int8Type& type) { return SetFormat("c"); }
+
+ Status Visit(const UInt8Type& type) { return SetFormat("C"); }
+
+ Status Visit(const Int16Type& type) { return SetFormat("s"); }
+
+ Status Visit(const UInt16Type& type) { return SetFormat("S"); }
+
+ Status Visit(const Int32Type& type) { return SetFormat("i"); }
+
+ Status Visit(const UInt32Type& type) { return SetFormat("I"); }
+
+ Status Visit(const Int64Type& type) { return SetFormat("l"); }
+
+ Status Visit(const UInt64Type& type) { return SetFormat("L"); }
+
+ Status Visit(const HalfFloatType& type) { return SetFormat("e"); }
+
+ Status Visit(const FloatType& type) { return SetFormat("f"); }
+
+ Status Visit(const DoubleType& type) { return SetFormat("g"); }
+
+ Status Visit(const FixedSizeBinaryType& type) {
+ return SetFormat("w:" + std::to_string(type.byte_width()));
+ }
+
+ Status Visit(const DecimalType& type) {
+ if (type.bit_width() == 128) {
+ // 128 is the default bit-width
+ return SetFormat("d:" + std::to_string(type.precision()) + "," +
+ std::to_string(type.scale()));
+ } else {
+ return SetFormat("d:" + std::to_string(type.precision()) + "," +
+ std::to_string(type.scale()) + "," +
+ std::to_string(type.bit_width()));
+ }
+ }
+
+ Status Visit(const BinaryType& type) { return SetFormat("z"); }
+
+ Status Visit(const LargeBinaryType& type) { return SetFormat("Z"); }
+
+ Status Visit(const StringType& type) { return SetFormat("u"); }
+
+ Status Visit(const LargeStringType& type) { return SetFormat("U"); }
+
+ Status Visit(const Date32Type& type) { return SetFormat("tdD"); }
+
+ Status Visit(const Date64Type& type) { return SetFormat("tdm"); }
+
+ Status Visit(const Time32Type& type) {
+ switch (type.unit()) {
+ case TimeUnit::SECOND:
+ export_.format_ = "tts";
+ break;
+ case TimeUnit::MILLI:
+ export_.format_ = "ttm";
+ break;
+ default:
+ return Status::Invalid("Invalid time unit for Time32: ", type.unit());
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const Time64Type& type) {
+ switch (type.unit()) {
+ case TimeUnit::MICRO:
+ export_.format_ = "ttu";
+ break;
+ case TimeUnit::NANO:
+ export_.format_ = "ttn";
+ break;
+ default:
+ return Status::Invalid("Invalid time unit for Time64: ", type.unit());
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const TimestampType& type) {
+ switch (type.unit()) {
+ case TimeUnit::SECOND:
+ export_.format_ = "tss:";
+ break;
+ case TimeUnit::MILLI:
+ export_.format_ = "tsm:";
+ break;
+ case TimeUnit::MICRO:
+ export_.format_ = "tsu:";
+ break;
+ case TimeUnit::NANO:
+ export_.format_ = "tsn:";
+ break;
+ default:
+ return Status::Invalid("Invalid time unit for Timestamp: ", type.unit());
+ }
+ export_.format_ += type.timezone();
+ return Status::OK();
+ }
+
+ Status Visit(const DurationType& type) {
+ switch (type.unit()) {
+ case TimeUnit::SECOND:
+ export_.format_ = "tDs";
+ break;
+ case TimeUnit::MILLI:
+ export_.format_ = "tDm";
+ break;
+ case TimeUnit::MICRO:
+ export_.format_ = "tDu";
+ break;
+ case TimeUnit::NANO:
+ export_.format_ = "tDn";
+ break;
+ default:
+ return Status::Invalid("Invalid time unit for Duration: ", type.unit());
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const MonthIntervalType& type) { return SetFormat("tiM"); }
+
+ Status Visit(const DayTimeIntervalType& type) { return SetFormat("tiD"); }
+
+ Status Visit(const MonthDayNanoIntervalType& type) { return SetFormat("tin"); }
+
+ Status Visit(const ListType& type) { return SetFormat("+l"); }
+
+ Status Visit(const LargeListType& type) { return SetFormat("+L"); }
+
+ Status Visit(const FixedSizeListType& type) {
+ return SetFormat("+w:" + std::to_string(type.list_size()));
+ }
+
+ Status Visit(const StructType& type) { return SetFormat("+s"); }
+
+ Status Visit(const MapType& type) {
+ export_.format_ = "+m";
+ if (type.keys_sorted()) {
+ flags_ |= ARROW_FLAG_MAP_KEYS_SORTED;
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const UnionType& type) {
+ std::string& s = export_.format_;
+ s = "+u";
+ if (type.mode() == UnionMode::DENSE) {
+ s += "d:";
+ } else {
+ DCHECK_EQ(type.mode(), UnionMode::SPARSE);
+ s += "s:";
+ }
+ bool first = true;
+ for (const auto code : type.type_codes()) {
+ if (!first) {
+ s += ",";
+ }
+ s += std::to_string(code);
+ first = false;
+ }
+ return Status::OK();
+ }
+
+ ExportedSchemaPrivateData export_;
+ int64_t flags_ = 0;
+ std::vector<std::pair<std::string, std::string>> additional_metadata_;
+ std::unique_ptr<SchemaExporter> dict_exporter_;
+ std::vector<SchemaExporter> child_exporters_;
+};
+
+} // namespace
+
+Status ExportType(const DataType& type, struct ArrowSchema* out) {
+ SchemaExporter exporter;
+ RETURN_NOT_OK(exporter.ExportType(type));
+ exporter.Finish(out);
+ return Status::OK();
+}
+
+Status ExportField(const Field& field, struct ArrowSchema* out) {
+ SchemaExporter exporter;
+ RETURN_NOT_OK(exporter.ExportField(field));
+ exporter.Finish(out);
+ return Status::OK();
+}
+
+Status ExportSchema(const Schema& schema, struct ArrowSchema* out) {
+ SchemaExporter exporter;
+ RETURN_NOT_OK(exporter.ExportSchema(schema));
+ exporter.Finish(out);
+ return Status::OK();
+}
+
+//////////////////////////////////////////////////////////////////////////
+// C data export
+
+namespace {
+
+struct ExportedArrayPrivateData : PoolAllocationMixin<ExportedArrayPrivateData> {
+ // The buffers are owned by the ArrayData member
+ StaticVector<const void*, 3> buffers_;
+ struct ArrowArray dictionary_;
+ SmallVector<struct ArrowArray, 1> children_;
+ SmallVector<struct ArrowArray*, 4> child_pointers_;
+
+ std::shared_ptr<ArrayData> data_;
+
+ ExportedArrayPrivateData() = default;
+ ARROW_DEFAULT_MOVE_AND_ASSIGN(ExportedArrayPrivateData);
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ExportedArrayPrivateData);
+};
+
+void ReleaseExportedArray(struct ArrowArray* array) {
+ if (ArrowArrayIsReleased(array)) {
+ return;
+ }
+ for (int64_t i = 0; i < array->n_children; ++i) {
+ struct ArrowArray* child = array->children[i];
+ ArrowArrayRelease(child);
+ DCHECK(ArrowArrayIsReleased(child))
+ << "Child release callback should have marked it released";
+ }
+ struct ArrowArray* dict = array->dictionary;
+ if (dict != nullptr) {
+ ArrowArrayRelease(dict);
+ DCHECK(ArrowArrayIsReleased(dict))
+ << "Dictionary release callback should have marked it released";
+ }
+ DCHECK_NE(array->private_data, nullptr);
+ delete reinterpret_cast<ExportedArrayPrivateData*>(array->private_data);
+
+ ArrowArrayMarkReleased(array);
+}
+
+struct ArrayExporter {
+ Status Export(const std::shared_ptr<ArrayData>& data) {
+ // Force computing null count.
+ // This is because ARROW-9037 is in version 0.17 and 0.17.1, and they are
+ // not able to import arrays without a null bitmap and null_count == -1.
+ data->GetNullCount();
+ // Store buffer pointers
+ size_t n_buffers = data->buffers.size();
+ auto buffers_begin = data->buffers.begin();
+ if (n_buffers > 0 && !internal::HasValidityBitmap(data->type->id())) {
+ --n_buffers;
+ ++buffers_begin;
+ }
+ export_.buffers_.resize(n_buffers);
+ std::transform(buffers_begin, data->buffers.end(), export_.buffers_.begin(),
+ [](const std::shared_ptr<Buffer>& buffer) -> const void* {
+ return buffer ? buffer->data() : nullptr;
+ });
+
+ // Export dictionary
+ if (data->dictionary != nullptr) {
+ dict_exporter_.reset(new ArrayExporter());
+ RETURN_NOT_OK(dict_exporter_->Export(data->dictionary));
+ }
+
+ // Export children
+ export_.children_.resize(data->child_data.size());
+ child_exporters_.resize(data->child_data.size());
+ for (size_t i = 0; i < data->child_data.size(); ++i) {
+ RETURN_NOT_OK(child_exporters_[i].Export(data->child_data[i]));
+ }
+
+ // Store owning pointer to ArrayData
+ export_.data_ = data;
+
+ return Status::OK();
+ }
+
+ // Finalize exporting by setting C struct fields and allocating
+ // autonomous private data for each array node.
+ //
+ // This function can't fail, as properly reclaiming memory in case of error
+ // would be too fragile. After this function returns, memory is reclaimed
+ // by calling the release() pointer in the top level ArrowArray struct.
+ void Finish(struct ArrowArray* c_struct_) {
+ // First, create permanent ExportedArrayPrivateData, to make sure that
+ // child ArrayData pointers don't get invalidated.
+ auto pdata = new ExportedArrayPrivateData(std::move(export_));
+ const ArrayData& data = *pdata->data_;
+
+ // Second, finish dictionary and children.
+ if (dict_exporter_) {
+ dict_exporter_->Finish(&pdata->dictionary_);
+ }
+ pdata->child_pointers_.resize(data.child_data.size(), nullptr);
+ for (size_t i = 0; i < data.child_data.size(); ++i) {
+ auto ptr = &pdata->children_[i];
+ pdata->child_pointers_[i] = ptr;
+ child_exporters_[i].Finish(ptr);
+ }
+
+ // Third, fill C struct.
+ DCHECK_NE(c_struct_, nullptr);
+ memset(c_struct_, 0, sizeof(*c_struct_));
+
+ c_struct_->length = data.length;
+ c_struct_->null_count = data.null_count;
+ c_struct_->offset = data.offset;
+ c_struct_->n_buffers = static_cast<int64_t>(pdata->buffers_.size());
+ c_struct_->n_children = static_cast<int64_t>(pdata->child_pointers_.size());
+ c_struct_->buffers = pdata->buffers_.data();
+ c_struct_->children = c_struct_->n_children ? pdata->child_pointers_.data() : nullptr;
+ c_struct_->dictionary = dict_exporter_ ? &pdata->dictionary_ : nullptr;
+ c_struct_->private_data = pdata;
+ c_struct_->release = ReleaseExportedArray;
+ }
+
+ ExportedArrayPrivateData export_;
+ std::unique_ptr<ArrayExporter> dict_exporter_;
+ std::vector<ArrayExporter> child_exporters_;
+};
+
+} // namespace
+
+Status ExportArray(const Array& array, struct ArrowArray* out,
+ struct ArrowSchema* out_schema) {
+ SchemaExportGuard guard(out_schema);
+ if (out_schema != nullptr) {
+ RETURN_NOT_OK(ExportType(*array.type(), out_schema));
+ }
+ ArrayExporter exporter;
+ RETURN_NOT_OK(exporter.Export(array.data()));
+ exporter.Finish(out);
+ guard.Detach();
+ return Status::OK();
+}
+
+Status ExportRecordBatch(const RecordBatch& batch, struct ArrowArray* out,
+ struct ArrowSchema* out_schema) {
+ // XXX perhaps bypass ToStructArray() for speed?
+ ARROW_ASSIGN_OR_RAISE(auto array, batch.ToStructArray());
+
+ SchemaExportGuard guard(out_schema);
+ if (out_schema != nullptr) {
+ // Export the schema, not the struct type, so as not to lose top-level metadata
+ RETURN_NOT_OK(ExportSchema(*batch.schema(), out_schema));
+ }
+ ArrayExporter exporter;
+ RETURN_NOT_OK(exporter.Export(array->data()));
+ exporter.Finish(out);
+ guard.Detach();
+ return Status::OK();
+}
+
+//////////////////////////////////////////////////////////////////////////
+// C schema import
+
+namespace {
+
+static constexpr int64_t kMaxImportRecursionLevel = 64;
+
+Status InvalidFormatString(util::string_view v) {
+ return Status::Invalid("Invalid or unsupported format string: '", v, "'");
+}
+
+class FormatStringParser {
+ public:
+ FormatStringParser() {}
+
+ explicit FormatStringParser(util::string_view v) : view_(v), index_(0) {}
+
+ bool AtEnd() const { return index_ >= view_.length(); }
+
+ char Next() { return view_[index_++]; }
+
+ util::string_view Rest() { return view_.substr(index_); }
+
+ Status CheckNext(char c) {
+ if (AtEnd() || Next() != c) {
+ return Invalid();
+ }
+ return Status::OK();
+ }
+
+ Status CheckHasNext() {
+ if (AtEnd()) {
+ return Invalid();
+ }
+ return Status::OK();
+ }
+
+ Status CheckAtEnd() {
+ if (!AtEnd()) {
+ return Invalid();
+ }
+ return Status::OK();
+ }
+
+ template <typename IntType = int32_t>
+ Result<IntType> ParseInt(util::string_view v) {
+ using ArrowIntType = typename CTypeTraits<IntType>::ArrowType;
+ IntType value;
+ if (!internal::ParseValue<ArrowIntType>(v.data(), v.size(), &value)) {
+ return Invalid();
+ }
+ return value;
+ }
+
+ Result<TimeUnit::type> ParseTimeUnit() {
+ RETURN_NOT_OK(CheckHasNext());
+ switch (Next()) {
+ case 's':
+ return TimeUnit::SECOND;
+ case 'm':
+ return TimeUnit::MILLI;
+ case 'u':
+ return TimeUnit::MICRO;
+ case 'n':
+ return TimeUnit::NANO;
+ default:
+ return Invalid();
+ }
+ }
+
+ SmallVector<util::string_view, 2> Split(util::string_view v, char delim = ',') {
+ SmallVector<util::string_view, 2> parts;
+ size_t start = 0, end;
+ while (true) {
+ end = v.find_first_of(delim, start);
+ parts.push_back(v.substr(start, end - start));
+ if (end == util::string_view::npos) {
+ break;
+ }
+ start = end + 1;
+ }
+ return parts;
+ }
+
+ template <typename IntType = int32_t>
+ Result<std::vector<IntType>> ParseInts(util::string_view v) {
+ auto parts = Split(v);
+ std::vector<IntType> result;
+ result.reserve(parts.size());
+ for (const auto& p : parts) {
+ ARROW_ASSIGN_OR_RAISE(auto i, ParseInt<IntType>(p));
+ result.push_back(i);
+ }
+ return result;
+ }
+
+ Status Invalid() { return InvalidFormatString(view_); }
+
+ protected:
+ util::string_view view_;
+ size_t index_;
+};
+
+struct DecodedMetadata {
+ std::shared_ptr<KeyValueMetadata> metadata;
+ std::string extension_name;
+ std::string extension_serialized;
+};
+
+Result<DecodedMetadata> DecodeMetadata(const char* metadata) {
+ auto read_int32 = [&](int32_t* out) -> Status {
+ int32_t v;
+ memcpy(&v, metadata, 4);
+ metadata += 4;
+ *out = v;
+ if (*out < 0) {
+ return Status::Invalid("Invalid encoded metadata string");
+ }
+ return Status::OK();
+ };
+
+ auto read_string = [&](std::string* out) -> Status {
+ int32_t len;
+ RETURN_NOT_OK(read_int32(&len));
+ out->resize(len);
+ if (len > 0) {
+ memcpy(&(*out)[0], metadata, len);
+ metadata += len;
+ }
+ return Status::OK();
+ };
+
+ DecodedMetadata decoded;
+
+ if (metadata == nullptr) {
+ return decoded;
+ }
+ int32_t npairs;
+ RETURN_NOT_OK(read_int32(&npairs));
+ if (npairs == 0) {
+ return decoded;
+ }
+ std::vector<std::string> keys(npairs);
+ std::vector<std::string> values(npairs);
+ for (int32_t i = 0; i < npairs; ++i) {
+ RETURN_NOT_OK(read_string(&keys[i]));
+ RETURN_NOT_OK(read_string(&values[i]));
+ if (keys[i] == kExtensionTypeKeyName) {
+ decoded.extension_name = values[i];
+ } else if (keys[i] == kExtensionMetadataKeyName) {
+ decoded.extension_serialized = values[i];
+ }
+ }
+ decoded.metadata = key_value_metadata(std::move(keys), std::move(values));
+ return decoded;
+}
+
+struct SchemaImporter {
+ SchemaImporter() : c_struct_(nullptr), guard_(nullptr) {}
+
+ Status Import(struct ArrowSchema* src) {
+ if (ArrowSchemaIsReleased(src)) {
+ return Status::Invalid("Cannot import released ArrowSchema");
+ }
+ guard_.Reset(src);
+ recursion_level_ = 0;
+ c_struct_ = src;
+ return DoImport();
+ }
+
+ Result<std::shared_ptr<Field>> MakeField() const {
+ const char* name = c_struct_->name ? c_struct_->name : "";
+ bool nullable = (c_struct_->flags & ARROW_FLAG_NULLABLE) != 0;
+ return field(name, type_, nullable, std::move(metadata_.metadata));
+ }
+
+ Result<std::shared_ptr<Schema>> MakeSchema() const {
+ if (type_->id() != Type::STRUCT) {
+ return Status::Invalid(
+ "Cannot import schema: ArrowSchema describes non-struct type ",
+ type_->ToString());
+ }
+ return schema(type_->fields(), std::move(metadata_.metadata));
+ }
+
+ Result<std::shared_ptr<DataType>> MakeType() const { return type_; }
+
+ protected:
+ Status ImportChild(const SchemaImporter* parent, struct ArrowSchema* src) {
+ if (ArrowSchemaIsReleased(src)) {
+ return Status::Invalid("Cannot import released ArrowSchema");
+ }
+ recursion_level_ = parent->recursion_level_ + 1;
+ if (recursion_level_ >= kMaxImportRecursionLevel) {
+ return Status::Invalid("Recursion level in ArrowSchema struct exceeded");
+ }
+ // The ArrowSchema is owned by its parent, so don't release it ourselves
+ c_struct_ = src;
+ return DoImport();
+ }
+
+ Status ImportDict(const SchemaImporter* parent, struct ArrowSchema* src) {
+ return ImportChild(parent, src);
+ }
+
+ Status DoImport() {
+ // First import children (required for reconstituting parent type)
+ child_importers_.resize(c_struct_->n_children);
+ for (int64_t i = 0; i < c_struct_->n_children; ++i) {
+ DCHECK_NE(c_struct_->children[i], nullptr);
+ RETURN_NOT_OK(child_importers_[i].ImportChild(this, c_struct_->children[i]));
+ }
+
+ // Import main type
+ RETURN_NOT_OK(ProcessFormat());
+ DCHECK_NE(type_, nullptr);
+
+ // Import dictionary type
+ if (c_struct_->dictionary != nullptr) {
+ // Check this index type
+ if (!is_integer(type_->id())) {
+ return Status::Invalid(
+ "ArrowSchema struct has a dictionary but is not an integer type: ",
+ type_->ToString());
+ }
+ SchemaImporter dict_importer;
+ RETURN_NOT_OK(dict_importer.ImportDict(this, c_struct_->dictionary));
+ bool ordered = (c_struct_->flags & ARROW_FLAG_DICTIONARY_ORDERED) != 0;
+ type_ = dictionary(type_, dict_importer.type_, ordered);
+ }
+
+ // Import metadata
+ ARROW_ASSIGN_OR_RAISE(metadata_, DecodeMetadata(c_struct_->metadata));
+
+ // Detect extension type
+ if (!metadata_.extension_name.empty()) {
+ const auto registered_ext_type = GetExtensionType(metadata_.extension_name);
+ if (registered_ext_type) {
+ ARROW_ASSIGN_OR_RAISE(
+ type_, registered_ext_type->Deserialize(std::move(type_),
+ metadata_.extension_serialized));
+ }
+ }
+
+ return Status::OK();
+ }
+
+ Status ProcessFormat() {
+ f_parser_ = FormatStringParser(c_struct_->format);
+ RETURN_NOT_OK(f_parser_.CheckHasNext());
+ switch (f_parser_.Next()) {
+ case 'n':
+ return ProcessPrimitive(null());
+ case 'b':
+ return ProcessPrimitive(boolean());
+ case 'c':
+ return ProcessPrimitive(int8());
+ case 'C':
+ return ProcessPrimitive(uint8());
+ case 's':
+ return ProcessPrimitive(int16());
+ case 'S':
+ return ProcessPrimitive(uint16());
+ case 'i':
+ return ProcessPrimitive(int32());
+ case 'I':
+ return ProcessPrimitive(uint32());
+ case 'l':
+ return ProcessPrimitive(int64());
+ case 'L':
+ return ProcessPrimitive(uint64());
+ case 'e':
+ return ProcessPrimitive(float16());
+ case 'f':
+ return ProcessPrimitive(float32());
+ case 'g':
+ return ProcessPrimitive(float64());
+ case 'u':
+ return ProcessPrimitive(utf8());
+ case 'U':
+ return ProcessPrimitive(large_utf8());
+ case 'z':
+ return ProcessPrimitive(binary());
+ case 'Z':
+ return ProcessPrimitive(large_binary());
+ case 'w':
+ return ProcessFixedSizeBinary();
+ case 'd':
+ return ProcessDecimal();
+ case 't':
+ return ProcessTemporal();
+ case '+':
+ return ProcessNested();
+ }
+ return f_parser_.Invalid();
+ }
+
+ Status ProcessTemporal() {
+ RETURN_NOT_OK(f_parser_.CheckHasNext());
+ switch (f_parser_.Next()) {
+ case 'd':
+ return ProcessDate();
+ case 't':
+ return ProcessTime();
+ case 'D':
+ return ProcessDuration();
+ case 'i':
+ return ProcessInterval();
+ case 's':
+ return ProcessTimestamp();
+ }
+ return f_parser_.Invalid();
+ }
+
+ Status ProcessNested() {
+ RETURN_NOT_OK(f_parser_.CheckHasNext());
+ switch (f_parser_.Next()) {
+ case 'l':
+ return ProcessListLike<ListType>();
+ case 'L':
+ return ProcessListLike<LargeListType>();
+ case 'w':
+ return ProcessFixedSizeList();
+ case 's':
+ return ProcessStruct();
+ case 'm':
+ return ProcessMap();
+ case 'u':
+ return ProcessUnion();
+ }
+ return f_parser_.Invalid();
+ }
+
+ Status ProcessDate() {
+ RETURN_NOT_OK(f_parser_.CheckHasNext());
+ switch (f_parser_.Next()) {
+ case 'D':
+ return ProcessPrimitive(date32());
+ case 'm':
+ return ProcessPrimitive(date64());
+ }
+ return f_parser_.Invalid();
+ }
+
+ Status ProcessInterval() {
+ RETURN_NOT_OK(f_parser_.CheckHasNext());
+ switch (f_parser_.Next()) {
+ case 'D':
+ return ProcessPrimitive(day_time_interval());
+ case 'M':
+ return ProcessPrimitive(month_interval());
+ case 'n':
+ return ProcessPrimitive(month_day_nano_interval());
+ }
+ return f_parser_.Invalid();
+ }
+
+ Status ProcessTime() {
+ ARROW_ASSIGN_OR_RAISE(auto unit, f_parser_.ParseTimeUnit());
+ if (unit == TimeUnit::SECOND || unit == TimeUnit::MILLI) {
+ return ProcessPrimitive(time32(unit));
+ } else {
+ return ProcessPrimitive(time64(unit));
+ }
+ }
+
+ Status ProcessDuration() {
+ ARROW_ASSIGN_OR_RAISE(auto unit, f_parser_.ParseTimeUnit());
+ return ProcessPrimitive(duration(unit));
+ }
+
+ Status ProcessTimestamp() {
+ ARROW_ASSIGN_OR_RAISE(auto unit, f_parser_.ParseTimeUnit());
+ RETURN_NOT_OK(f_parser_.CheckNext(':'));
+ type_ = timestamp(unit, std::string(f_parser_.Rest()));
+ return Status::OK();
+ }
+
+ Status ProcessFixedSizeBinary() {
+ RETURN_NOT_OK(f_parser_.CheckNext(':'));
+ ARROW_ASSIGN_OR_RAISE(auto byte_width, f_parser_.ParseInt(f_parser_.Rest()));
+ if (byte_width < 0) {
+ return f_parser_.Invalid();
+ }
+ type_ = fixed_size_binary(byte_width);
+ return Status::OK();
+ }
+
+ Status ProcessDecimal() {
+ RETURN_NOT_OK(f_parser_.CheckNext(':'));
+ ARROW_ASSIGN_OR_RAISE(auto prec_scale, f_parser_.ParseInts(f_parser_.Rest()));
+ // 3 elements indicates bit width was communicated as well.
+ if (prec_scale.size() != 2 && prec_scale.size() != 3) {
+ return f_parser_.Invalid();
+ }
+ if (prec_scale[0] <= 0) {
+ return f_parser_.Invalid();
+ }
+ if (prec_scale.size() == 2 || prec_scale[2] == 128) {
+ type_ = decimal128(prec_scale[0], prec_scale[1]);
+ } else if (prec_scale[2] == 256) {
+ type_ = decimal256(prec_scale[0], prec_scale[1]);
+ } else {
+ return f_parser_.Invalid();
+ }
+ return Status::OK();
+ }
+
+ Status ProcessPrimitive(const std::shared_ptr<DataType>& type) {
+ RETURN_NOT_OK(f_parser_.CheckAtEnd());
+ type_ = type;
+ return CheckNoChildren(type);
+ }
+
+ template <typename ListType>
+ Status ProcessListLike() {
+ RETURN_NOT_OK(f_parser_.CheckAtEnd());
+ RETURN_NOT_OK(CheckNumChildren(1));
+ ARROW_ASSIGN_OR_RAISE(auto field, MakeChildField(0));
+ type_ = std::make_shared<ListType>(field);
+ return Status::OK();
+ }
+
+ Status ProcessMap() {
+ RETURN_NOT_OK(f_parser_.CheckAtEnd());
+ RETURN_NOT_OK(CheckNumChildren(1));
+ ARROW_ASSIGN_OR_RAISE(auto field, MakeChildField(0));
+ const auto& value_type = field->type();
+ if (value_type->id() != Type::STRUCT) {
+ return Status::Invalid("Imported map array has unexpected child field type: ",
+ field->ToString());
+ }
+ if (value_type->num_fields() != 2) {
+ return Status::Invalid("Imported map array has unexpected child field type: ",
+ field->ToString());
+ }
+
+ bool keys_sorted = (c_struct_->flags & ARROW_FLAG_MAP_KEYS_SORTED);
+ type_ = map(value_type->field(0)->type(), value_type->field(1)->type(), keys_sorted);
+ return Status::OK();
+ }
+
+ Status ProcessFixedSizeList() {
+ RETURN_NOT_OK(f_parser_.CheckNext(':'));
+ ARROW_ASSIGN_OR_RAISE(auto list_size, f_parser_.ParseInt(f_parser_.Rest()));
+ if (list_size < 0) {
+ return f_parser_.Invalid();
+ }
+ RETURN_NOT_OK(CheckNumChildren(1));
+ ARROW_ASSIGN_OR_RAISE(auto field, MakeChildField(0));
+ type_ = fixed_size_list(field, list_size);
+ return Status::OK();
+ }
+
+ Status ProcessStruct() {
+ RETURN_NOT_OK(f_parser_.CheckAtEnd());
+ ARROW_ASSIGN_OR_RAISE(auto fields, MakeChildFields());
+ type_ = struct_(std::move(fields));
+ return Status::OK();
+ }
+
+ Status ProcessUnion() {
+ RETURN_NOT_OK(f_parser_.CheckHasNext());
+ UnionMode::type mode;
+ switch (f_parser_.Next()) {
+ case 'd':
+ mode = UnionMode::DENSE;
+ break;
+ case 's':
+ mode = UnionMode::SPARSE;
+ break;
+ default:
+ return f_parser_.Invalid();
+ }
+ RETURN_NOT_OK(f_parser_.CheckNext(':'));
+ ARROW_ASSIGN_OR_RAISE(auto type_codes, f_parser_.ParseInts<int8_t>(f_parser_.Rest()));
+ ARROW_ASSIGN_OR_RAISE(auto fields, MakeChildFields());
+ if (fields.size() != type_codes.size()) {
+ return Status::Invalid(
+ "ArrowArray struct number of children incompatible with format string "
+ "(mismatching number of union type codes) ",
+ "'", c_struct_->format, "'");
+ }
+ for (const auto code : type_codes) {
+ if (code < 0) {
+ return Status::Invalid("Negative type code in union: format string '",
+ c_struct_->format, "'");
+ }
+ }
+ if (mode == UnionMode::SPARSE) {
+ type_ = sparse_union(std::move(fields), std::move(type_codes));
+ } else {
+ type_ = dense_union(std::move(fields), std::move(type_codes));
+ }
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<Field>> MakeChildField(int64_t child_id) {
+ const auto& child = child_importers_[child_id];
+ if (child.c_struct_->name == nullptr) {
+ return Status::Invalid("Expected non-null name in imported array child");
+ }
+ return child.MakeField();
+ }
+
+ Result<std::vector<std::shared_ptr<Field>>> MakeChildFields() {
+ std::vector<std::shared_ptr<Field>> fields(child_importers_.size());
+ for (int64_t i = 0; i < static_cast<int64_t>(child_importers_.size()); ++i) {
+ ARROW_ASSIGN_OR_RAISE(fields[i], MakeChildField(i));
+ }
+ return fields;
+ }
+
+ Status CheckNoChildren(const std::shared_ptr<DataType>& type) {
+ return CheckNumChildren(type, 0);
+ }
+
+ Status CheckNumChildren(const std::shared_ptr<DataType>& type, int64_t n_children) {
+ if (c_struct_->n_children != n_children) {
+ return Status::Invalid("Expected ", n_children, " children for imported type ",
+ *type, ", ArrowArray struct has ", c_struct_->n_children);
+ }
+ return Status::OK();
+ }
+
+ Status CheckNumChildren(int64_t n_children) {
+ if (c_struct_->n_children != n_children) {
+ return Status::Invalid("Expected ", n_children, " children for imported format '",
+ c_struct_->format, "', ArrowArray struct has ",
+ c_struct_->n_children);
+ }
+ return Status::OK();
+ }
+
+ struct ArrowSchema* c_struct_;
+ SchemaExportGuard guard_;
+ FormatStringParser f_parser_;
+ int64_t recursion_level_;
+ std::vector<SchemaImporter> child_importers_;
+ std::shared_ptr<DataType> type_;
+ DecodedMetadata metadata_;
+};
+
+} // namespace
+
+Result<std::shared_ptr<DataType>> ImportType(struct ArrowSchema* schema) {
+ SchemaImporter importer;
+ RETURN_NOT_OK(importer.Import(schema));
+ return importer.MakeType();
+}
+
+Result<std::shared_ptr<Field>> ImportField(struct ArrowSchema* schema) {
+ SchemaImporter importer;
+ RETURN_NOT_OK(importer.Import(schema));
+ return importer.MakeField();
+}
+
+Result<std::shared_ptr<Schema>> ImportSchema(struct ArrowSchema* schema) {
+ SchemaImporter importer;
+ RETURN_NOT_OK(importer.Import(schema));
+ return importer.MakeSchema();
+}
+
+//////////////////////////////////////////////////////////////////////////
+// C data import
+
+namespace {
+
+// A wrapper struct for an imported C ArrowArray.
+// The ArrowArray is released on destruction.
+struct ImportedArrayData {
+ struct ArrowArray array_;
+
+ ImportedArrayData() {
+ ArrowArrayMarkReleased(&array_); // Initially released
+ }
+
+ void Release() {
+ if (!ArrowArrayIsReleased(&array_)) {
+ ArrowArrayRelease(&array_);
+ DCHECK(ArrowArrayIsReleased(&array_));
+ }
+ }
+
+ ~ImportedArrayData() { Release(); }
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ImportedArrayData);
+};
+
+// A buffer wrapping an imported piece of data.
+class ImportedBuffer : public Buffer {
+ public:
+ ImportedBuffer(const uint8_t* data, int64_t size,
+ std::shared_ptr<ImportedArrayData> import)
+ : Buffer(data, size), import_(std::move(import)) {}
+
+ ~ImportedBuffer() override {}
+
+ protected:
+ std::shared_ptr<ImportedArrayData> import_;
+};
+
+struct ArrayImporter {
+ explicit ArrayImporter(const std::shared_ptr<DataType>& type) : type_(type) {}
+
+ Status Import(struct ArrowArray* src) {
+ if (ArrowArrayIsReleased(src)) {
+ return Status::Invalid("Cannot import released ArrowArray");
+ }
+ recursion_level_ = 0;
+ import_ = std::make_shared<ImportedArrayData>();
+ c_struct_ = &import_->array_;
+ ArrowArrayMove(src, c_struct_);
+ return DoImport();
+ }
+
+ Result<std::shared_ptr<Array>> MakeArray() {
+ DCHECK_NE(data_, nullptr);
+ return ::arrow::MakeArray(data_);
+ }
+
+ std::shared_ptr<ArrayData> GetArrayData() {
+ DCHECK_NE(data_, nullptr);
+ return data_;
+ }
+
+ Result<std::shared_ptr<RecordBatch>> MakeRecordBatch(std::shared_ptr<Schema> schema) {
+ DCHECK_NE(data_, nullptr);
+ if (data_->GetNullCount() != 0) {
+ return Status::Invalid(
+ "ArrowArray struct has non-zero null count, "
+ "cannot be imported as RecordBatch");
+ }
+ if (data_->offset != 0) {
+ return Status::Invalid(
+ "ArrowArray struct has non-zero offset, "
+ "cannot be imported as RecordBatch");
+ }
+ return RecordBatch::Make(std::move(schema), data_->length,
+ std::move(data_->child_data));
+ }
+
+ Status ImportChild(const ArrayImporter* parent, struct ArrowArray* src) {
+ if (ArrowArrayIsReleased(src)) {
+ return Status::Invalid("Cannot import released ArrowArray");
+ }
+ recursion_level_ = parent->recursion_level_ + 1;
+ if (recursion_level_ >= kMaxImportRecursionLevel) {
+ return Status::Invalid("Recursion level in ArrowArray struct exceeded");
+ }
+ // Child buffers will keep the entire parent import alive.
+ // Perhaps we can move the child structs to an owned area
+ // when the parent ImportedArrayData::Release() gets called,
+ // but that is another level of complication.
+ import_ = parent->import_;
+ // The ArrowArray shouldn't be moved, it's owned by its parent
+ c_struct_ = src;
+ return DoImport();
+ }
+
+ Status ImportDict(const ArrayImporter* parent, struct ArrowArray* src) {
+ return ImportChild(parent, src);
+ }
+
+ Status DoImport() {
+ // Unwrap extension type
+ const DataType* storage_type = type_.get();
+ if (storage_type->id() == Type::EXTENSION) {
+ storage_type =
+ checked_cast<const ExtensionType&>(*storage_type).storage_type().get();
+ }
+
+ // First import children (required for reconstituting parent array data)
+ const auto& fields = storage_type->fields();
+ if (c_struct_->n_children != static_cast<int64_t>(fields.size())) {
+ return Status::Invalid("ArrowArray struct has ", c_struct_->n_children,
+ " children, expected ", fields.size(), " for type ",
+ type_->ToString());
+ }
+ child_importers_.reserve(fields.size());
+ for (int64_t i = 0; i < c_struct_->n_children; ++i) {
+ DCHECK_NE(c_struct_->children[i], nullptr);
+ child_importers_.emplace_back(fields[i]->type());
+ RETURN_NOT_OK(child_importers_.back().ImportChild(this, c_struct_->children[i]));
+ }
+
+ // Import main data
+ RETURN_NOT_OK(VisitTypeInline(*storage_type, this));
+
+ bool is_dict_type = (storage_type->id() == Type::DICTIONARY);
+ if (c_struct_->dictionary != nullptr) {
+ if (!is_dict_type) {
+ return Status::Invalid("Import type is ", type_->ToString(),
+ " but dictionary field in ArrowArray struct is not null");
+ }
+ const auto& dict_type = checked_cast<const DictionaryType&>(*storage_type);
+ // Import dictionary values
+ ArrayImporter dict_importer(dict_type.value_type());
+ RETURN_NOT_OK(dict_importer.ImportDict(this, c_struct_->dictionary));
+ data_->dictionary = dict_importer.GetArrayData();
+ } else {
+ if (is_dict_type) {
+ return Status::Invalid("Import type is ", type_->ToString(),
+ " but dictionary field in ArrowArray struct is null");
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const DataType& type) {
+ return Status::NotImplemented("Cannot import array of type ", type_->ToString());
+ }
+
+ Status Visit(const FixedWidthType& type) { return ImportFixedSizePrimitive(type); }
+
+ Status Visit(const NullType& type) {
+ RETURN_NOT_OK(CheckNoChildren());
+ if (c_struct_->n_buffers == 1) {
+ // Legacy format exported by older Arrow C++ versions
+ RETURN_NOT_OK(AllocateArrayData());
+ } else {
+ RETURN_NOT_OK(CheckNumBuffers(0));
+ RETURN_NOT_OK(AllocateArrayData());
+ data_->buffers.insert(data_->buffers.begin(), nullptr);
+ }
+ data_->null_count = data_->length;
+ return Status::OK();
+ }
+
+ Status Visit(const StringType& type) { return ImportStringLike(type); }
+
+ Status Visit(const BinaryType& type) { return ImportStringLike(type); }
+
+ Status Visit(const LargeStringType& type) { return ImportStringLike(type); }
+
+ Status Visit(const LargeBinaryType& type) { return ImportStringLike(type); }
+
+ Status Visit(const ListType& type) { return ImportListLike(type); }
+
+ Status Visit(const LargeListType& type) { return ImportListLike(type); }
+
+ Status Visit(const FixedSizeListType& type) {
+ RETURN_NOT_OK(CheckNumChildren(1));
+ RETURN_NOT_OK(CheckNumBuffers(1));
+ RETURN_NOT_OK(AllocateArrayData());
+ RETURN_NOT_OK(ImportNullBitmap());
+ return Status::OK();
+ }
+
+ Status Visit(const StructType& type) {
+ RETURN_NOT_OK(CheckNumBuffers(1));
+ RETURN_NOT_OK(AllocateArrayData());
+ RETURN_NOT_OK(ImportNullBitmap());
+ return Status::OK();
+ }
+
+ Status Visit(const SparseUnionType& type) {
+ RETURN_NOT_OK(CheckNoNulls());
+ if (c_struct_->n_buffers == 2) {
+ // ARROW-14179: legacy format exported by older Arrow C++ versions
+ RETURN_NOT_OK(AllocateArrayData());
+ RETURN_NOT_OK(ImportFixedSizeBuffer(1, sizeof(int8_t)));
+ } else {
+ RETURN_NOT_OK(CheckNumBuffers(1));
+ RETURN_NOT_OK(AllocateArrayData());
+ RETURN_NOT_OK(ImportFixedSizeBuffer(0, sizeof(int8_t)));
+ // Prepend a null bitmap buffer, as expected by SparseUnionArray
+ data_->buffers.insert(data_->buffers.begin(), nullptr);
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const DenseUnionType& type) {
+ RETURN_NOT_OK(CheckNoNulls());
+ if (c_struct_->n_buffers == 3) {
+ // ARROW-14179: legacy format exported by older Arrow C++ versions
+ RETURN_NOT_OK(AllocateArrayData());
+ RETURN_NOT_OK(ImportFixedSizeBuffer(1, sizeof(int8_t)));
+ RETURN_NOT_OK(ImportFixedSizeBuffer(2, sizeof(int32_t)));
+ } else {
+ RETURN_NOT_OK(CheckNumBuffers(2));
+ RETURN_NOT_OK(AllocateArrayData());
+ RETURN_NOT_OK(ImportFixedSizeBuffer(0, sizeof(int8_t)));
+ RETURN_NOT_OK(ImportFixedSizeBuffer(1, sizeof(int32_t)));
+ // Prepend a null bitmap pointer, as expected by DenseUnionArray
+ data_->buffers.insert(data_->buffers.begin(), nullptr);
+ }
+ return Status::OK();
+ }
+
+ Status ImportFixedSizePrimitive(const FixedWidthType& type) {
+ RETURN_NOT_OK(CheckNoChildren());
+ RETURN_NOT_OK(CheckNumBuffers(2));
+ RETURN_NOT_OK(AllocateArrayData());
+ RETURN_NOT_OK(ImportNullBitmap());
+ if (BitUtil::IsMultipleOf8(type.bit_width())) {
+ RETURN_NOT_OK(ImportFixedSizeBuffer(1, type.bit_width() / 8));
+ } else {
+ DCHECK_EQ(type.bit_width(), 1);
+ RETURN_NOT_OK(ImportBitsBuffer(1));
+ }
+ return Status::OK();
+ }
+
+ template <typename StringType>
+ Status ImportStringLike(const StringType& type) {
+ RETURN_NOT_OK(CheckNoChildren());
+ RETURN_NOT_OK(CheckNumBuffers(3));
+ RETURN_NOT_OK(AllocateArrayData());
+ RETURN_NOT_OK(ImportNullBitmap());
+ RETURN_NOT_OK(ImportOffsetsBuffer<typename StringType::offset_type>(1));
+ RETURN_NOT_OK(ImportStringValuesBuffer<typename StringType::offset_type>(1, 2));
+ return Status::OK();
+ }
+
+ template <typename ListType>
+ Status ImportListLike(const ListType& type) {
+ RETURN_NOT_OK(CheckNumChildren(1));
+ RETURN_NOT_OK(CheckNumBuffers(2));
+ RETURN_NOT_OK(AllocateArrayData());
+ RETURN_NOT_OK(ImportNullBitmap());
+ RETURN_NOT_OK(ImportOffsetsBuffer<typename ListType::offset_type>(1));
+ return Status::OK();
+ }
+
+ Status CheckNoChildren() { return CheckNumChildren(0); }
+
+ Status CheckNumChildren(int64_t n_children) {
+ if (c_struct_->n_children != n_children) {
+ return Status::Invalid("Expected ", n_children, " children for imported type ",
+ type_->ToString(), ", ArrowArray struct has ",
+ c_struct_->n_children);
+ }
+ return Status::OK();
+ }
+
+ Status CheckNumBuffers(int64_t n_buffers) {
+ if (n_buffers != c_struct_->n_buffers) {
+ return Status::Invalid("Expected ", n_buffers, " buffers for imported type ",
+ type_->ToString(), ", ArrowArray struct has ",
+ c_struct_->n_buffers);
+ }
+ return Status::OK();
+ }
+
+ Status CheckNoNulls() {
+ if (c_struct_->null_count != 0) {
+ return Status::Invalid("Unexpected non-zero null count for imported type ",
+ type_->ToString());
+ }
+ return Status::OK();
+ }
+
+ Status AllocateArrayData() {
+ DCHECK_EQ(data_, nullptr);
+ data_ = std::make_shared<ArrayData>(type_, c_struct_->length, c_struct_->null_count,
+ c_struct_->offset);
+ data_->buffers.resize(static_cast<size_t>(c_struct_->n_buffers));
+ data_->child_data.resize(static_cast<size_t>(c_struct_->n_children));
+ DCHECK_EQ(child_importers_.size(), data_->child_data.size());
+ std::transform(child_importers_.begin(), child_importers_.end(),
+ data_->child_data.begin(),
+ [](const ArrayImporter& child) { return child.data_; });
+ return Status::OK();
+ }
+
+ Status ImportNullBitmap(int32_t buffer_id = 0) {
+ RETURN_NOT_OK(ImportBitsBuffer(buffer_id));
+ if (data_->null_count > 0 && data_->buffers[buffer_id] == nullptr) {
+ return Status::Invalid(
+ "ArrowArray struct has null bitmap buffer but non-zero null_count ",
+ data_->null_count);
+ }
+ return Status::OK();
+ }
+
+ Status ImportBitsBuffer(int32_t buffer_id) {
+ // Compute visible size of buffer
+ int64_t buffer_size = BitUtil::BytesForBits(c_struct_->length + c_struct_->offset);
+ return ImportBuffer(buffer_id, buffer_size);
+ }
+
+ Status ImportFixedSizeBuffer(int32_t buffer_id, int64_t byte_width) {
+ // Compute visible size of buffer
+ int64_t buffer_size = byte_width * (c_struct_->length + c_struct_->offset);
+ return ImportBuffer(buffer_id, buffer_size);
+ }
+
+ template <typename OffsetType>
+ Status ImportOffsetsBuffer(int32_t buffer_id) {
+ // Compute visible size of buffer
+ int64_t buffer_size =
+ sizeof(OffsetType) * (c_struct_->length + c_struct_->offset + 1);
+ return ImportBuffer(buffer_id, buffer_size);
+ }
+
+ template <typename OffsetType>
+ Status ImportStringValuesBuffer(int32_t offsets_buffer_id, int32_t buffer_id,
+ int64_t byte_width = 1) {
+ auto offsets = data_->GetValues<OffsetType>(offsets_buffer_id);
+ // Compute visible size of buffer
+ int64_t buffer_size = byte_width * offsets[c_struct_->length];
+ return ImportBuffer(buffer_id, buffer_size);
+ }
+
+ Status ImportBuffer(int32_t buffer_id, int64_t buffer_size) {
+ std::shared_ptr<Buffer>* out = &data_->buffers[buffer_id];
+ auto data = reinterpret_cast<const uint8_t*>(c_struct_->buffers[buffer_id]);
+ if (data != nullptr) {
+ *out = std::make_shared<ImportedBuffer>(data, buffer_size, import_);
+ } else {
+ out->reset();
+ }
+ return Status::OK();
+ }
+
+ struct ArrowArray* c_struct_;
+ int64_t recursion_level_;
+ const std::shared_ptr<DataType>& type_;
+
+ std::shared_ptr<ImportedArrayData> import_;
+ std::shared_ptr<ArrayData> data_;
+ std::vector<ArrayImporter> child_importers_;
+};
+
+} // namespace
+
+Result<std::shared_ptr<Array>> ImportArray(struct ArrowArray* array,
+ std::shared_ptr<DataType> type) {
+ ArrayImporter importer(type);
+ RETURN_NOT_OK(importer.Import(array));
+ return importer.MakeArray();
+}
+
+Result<std::shared_ptr<Array>> ImportArray(struct ArrowArray* array,
+ struct ArrowSchema* type) {
+ auto maybe_type = ImportType(type);
+ if (!maybe_type.ok()) {
+ ArrowArrayRelease(array);
+ return maybe_type.status();
+ }
+ return ImportArray(array, *maybe_type);
+}
+
+Result<std::shared_ptr<RecordBatch>> ImportRecordBatch(struct ArrowArray* array,
+ std::shared_ptr<Schema> schema) {
+ auto type = struct_(schema->fields());
+ ArrayImporter importer(type);
+ RETURN_NOT_OK(importer.Import(array));
+ return importer.MakeRecordBatch(std::move(schema));
+}
+
+Result<std::shared_ptr<RecordBatch>> ImportRecordBatch(struct ArrowArray* array,
+ struct ArrowSchema* schema) {
+ auto maybe_schema = ImportSchema(schema);
+ if (!maybe_schema.ok()) {
+ ArrowArrayRelease(array);
+ return maybe_schema.status();
+ }
+ return ImportRecordBatch(array, *maybe_schema);
+}
+
+//////////////////////////////////////////////////////////////////////////
+// C stream export
+
+namespace {
+
+class ExportedArrayStream {
+ public:
+ struct PrivateData {
+ explicit PrivateData(std::shared_ptr<RecordBatchReader> reader)
+ : reader_(std::move(reader)) {}
+
+ std::shared_ptr<RecordBatchReader> reader_;
+ std::string last_error_;
+
+ PrivateData() = default;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData);
+ };
+
+ explicit ExportedArrayStream(struct ArrowArrayStream* stream) : stream_(stream) {}
+
+ Status GetSchema(struct ArrowSchema* out_schema) {
+ return ExportSchema(*reader()->schema(), out_schema);
+ }
+
+ Status GetNext(struct ArrowArray* out_array) {
+ std::shared_ptr<RecordBatch> batch;
+ RETURN_NOT_OK(reader()->ReadNext(&batch));
+ if (batch == nullptr) {
+ // End of stream
+ ArrowArrayMarkReleased(out_array);
+ return Status::OK();
+ } else {
+ return ExportRecordBatch(*batch, out_array);
+ }
+ }
+
+ const char* GetLastError() {
+ const auto& last_error = private_data()->last_error_;
+ return last_error.empty() ? nullptr : last_error.c_str();
+ }
+
+ void Release() {
+ if (ArrowArrayStreamIsReleased(stream_)) {
+ return;
+ }
+ DCHECK_NE(private_data(), nullptr);
+ delete private_data();
+
+ ArrowArrayStreamMarkReleased(stream_);
+ }
+
+ // C-compatible callbacks
+
+ static int StaticGetSchema(struct ArrowArrayStream* stream,
+ struct ArrowSchema* out_schema) {
+ ExportedArrayStream self{stream};
+ return self.ToCError(self.GetSchema(out_schema));
+ }
+
+ static int StaticGetNext(struct ArrowArrayStream* stream,
+ struct ArrowArray* out_array) {
+ ExportedArrayStream self{stream};
+ return self.ToCError(self.GetNext(out_array));
+ }
+
+ static void StaticRelease(struct ArrowArrayStream* stream) {
+ ExportedArrayStream{stream}.Release();
+ }
+
+ static const char* StaticGetLastError(struct ArrowArrayStream* stream) {
+ return ExportedArrayStream{stream}.GetLastError();
+ }
+
+ private:
+ int ToCError(const Status& status) {
+ if (ARROW_PREDICT_TRUE(status.ok())) {
+ private_data()->last_error_.clear();
+ return 0;
+ }
+ private_data()->last_error_ = status.ToString();
+ switch (status.code()) {
+ case StatusCode::IOError:
+ return EIO;
+ case StatusCode::NotImplemented:
+ return ENOSYS;
+ case StatusCode::OutOfMemory:
+ return ENOMEM;
+ default:
+ return EINVAL; // Fallback for Invalid, TypeError, etc.
+ }
+ }
+
+ PrivateData* private_data() {
+ return reinterpret_cast<PrivateData*>(stream_->private_data);
+ }
+
+ const std::shared_ptr<RecordBatchReader>& reader() { return private_data()->reader_; }
+
+ struct ArrowArrayStream* stream_;
+};
+
+} // namespace
+
+Status ExportRecordBatchReader(std::shared_ptr<RecordBatchReader> reader,
+ struct ArrowArrayStream* out) {
+ out->get_schema = ExportedArrayStream::StaticGetSchema;
+ out->get_next = ExportedArrayStream::StaticGetNext;
+ out->get_last_error = ExportedArrayStream::StaticGetLastError;
+ out->release = ExportedArrayStream::StaticRelease;
+ out->private_data = new ExportedArrayStream::PrivateData{std::move(reader)};
+ return Status::OK();
+}
+
+//////////////////////////////////////////////////////////////////////////
+// C stream import
+
+namespace {
+
+class ArrayStreamBatchReader : public RecordBatchReader {
+ public:
+ explicit ArrayStreamBatchReader(struct ArrowArrayStream* stream) {
+ ArrowArrayStreamMove(stream, &stream_);
+ DCHECK(!ArrowArrayStreamIsReleased(&stream_));
+ }
+
+ ~ArrayStreamBatchReader() {
+ ArrowArrayStreamRelease(&stream_);
+ DCHECK(ArrowArrayStreamIsReleased(&stream_));
+ }
+
+ std::shared_ptr<Schema> schema() const override { return CacheSchema(); }
+
+ Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
+ struct ArrowArray c_array;
+ RETURN_NOT_OK(StatusFromCError(stream_.get_next(&stream_, &c_array)));
+ if (ArrowArrayIsReleased(&c_array)) {
+ // End of stream
+ batch->reset();
+ return Status::OK();
+ } else {
+ return ImportRecordBatch(&c_array, CacheSchema()).Value(batch);
+ }
+ }
+
+ private:
+ std::shared_ptr<Schema> CacheSchema() const {
+ if (!schema_) {
+ struct ArrowSchema c_schema;
+ ARROW_CHECK_OK(StatusFromCError(stream_.get_schema(&stream_, &c_schema)));
+ schema_ = ImportSchema(&c_schema).ValueOrDie();
+ }
+ return schema_;
+ }
+
+ Status StatusFromCError(int errno_like) const {
+ if (ARROW_PREDICT_TRUE(errno_like == 0)) {
+ return Status::OK();
+ }
+ StatusCode code;
+ switch (errno_like) {
+ case EDOM:
+ case EINVAL:
+ case ERANGE:
+ code = StatusCode::Invalid;
+ break;
+ case ENOMEM:
+ code = StatusCode::OutOfMemory;
+ break;
+ case ENOSYS:
+ code = StatusCode::NotImplemented;
+ default:
+ code = StatusCode::IOError;
+ break;
+ }
+ const char* last_error = stream_.get_last_error(&stream_);
+ return Status(code, last_error ? std::string(last_error) : "");
+ }
+
+ mutable struct ArrowArrayStream stream_;
+ mutable std::shared_ptr<Schema> schema_;
+};
+
+} // namespace
+
+Result<std::shared_ptr<RecordBatchReader>> ImportRecordBatchReader(
+ struct ArrowArrayStream* stream) {
+ if (ArrowArrayStreamIsReleased(stream)) {
+ return Status::Invalid("Cannot import released ArrowArrayStream");
+ }
+ // XXX should we call get_schema() here to avoid crashing on error?
+ return std::make_shared<ArrayStreamBatchReader>(stream);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/c/bridge.h b/src/arrow/cpp/src/arrow/c/bridge.h
new file mode 100644
index 000000000..294f53e49
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/c/bridge.h
@@ -0,0 +1,197 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/c/abi.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+/// \defgroup c-data-interface Functions for working with the C data interface.
+///
+/// @{
+
+/// \brief Export C++ DataType using the C data interface format.
+///
+/// The root type is considered to have empty name and metadata.
+/// If you want the root type to have a name and/or metadata, pass
+/// a Field instead.
+///
+/// \param[in] type DataType object to export
+/// \param[out] out C struct where to export the datatype
+ARROW_EXPORT
+Status ExportType(const DataType& type, struct ArrowSchema* out);
+
+/// \brief Export C++ Field using the C data interface format.
+///
+/// \param[in] field Field object to export
+/// \param[out] out C struct where to export the field
+ARROW_EXPORT
+Status ExportField(const Field& field, struct ArrowSchema* out);
+
+/// \brief Export C++ Schema using the C data interface format.
+///
+/// \param[in] schema Schema object to export
+/// \param[out] out C struct where to export the field
+ARROW_EXPORT
+Status ExportSchema(const Schema& schema, struct ArrowSchema* out);
+
+/// \brief Export C++ Array using the C data interface format.
+///
+/// The resulting ArrowArray struct keeps the array data and buffers alive
+/// until its release callback is called by the consumer.
+///
+/// \param[in] array Array object to export
+/// \param[out] out C struct where to export the array
+/// \param[out] out_schema optional C struct where to export the array type
+ARROW_EXPORT
+Status ExportArray(const Array& array, struct ArrowArray* out,
+ struct ArrowSchema* out_schema = NULLPTR);
+
+/// \brief Export C++ RecordBatch using the C data interface format.
+///
+/// The record batch is exported as if it were a struct array.
+/// The resulting ArrowArray struct keeps the record batch data and buffers alive
+/// until its release callback is called by the consumer.
+///
+/// \param[in] batch Record batch to export
+/// \param[out] out C struct where to export the record batch
+/// \param[out] out_schema optional C struct where to export the record batch schema
+ARROW_EXPORT
+Status ExportRecordBatch(const RecordBatch& batch, struct ArrowArray* out,
+ struct ArrowSchema* out_schema = NULLPTR);
+
+/// \brief Import C++ DataType from the C data interface.
+///
+/// The given ArrowSchema struct is released (as per the C data interface
+/// specification), even if this function fails.
+///
+/// \param[in,out] schema C data interface struct representing the data type
+/// \return Imported type object
+ARROW_EXPORT
+Result<std::shared_ptr<DataType>> ImportType(struct ArrowSchema* schema);
+
+/// \brief Import C++ Field from the C data interface.
+///
+/// The given ArrowSchema struct is released (as per the C data interface
+/// specification), even if this function fails.
+///
+/// \param[in,out] schema C data interface struct representing the field
+/// \return Imported field object
+ARROW_EXPORT
+Result<std::shared_ptr<Field>> ImportField(struct ArrowSchema* schema);
+
+/// \brief Import C++ Schema from the C data interface.
+///
+/// The given ArrowSchema struct is released (as per the C data interface
+/// specification), even if this function fails.
+///
+/// \param[in,out] schema C data interface struct representing the field
+/// \return Imported field object
+ARROW_EXPORT
+Result<std::shared_ptr<Schema>> ImportSchema(struct ArrowSchema* schema);
+
+/// \brief Import C++ array from the C data interface.
+///
+/// The ArrowArray struct has its contents moved (as per the C data interface
+/// specification) to a private object held alive by the resulting array.
+///
+/// \param[in,out] array C data interface struct holding the array data
+/// \param[in] type type of the imported array
+/// \return Imported array object
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> ImportArray(struct ArrowArray* array,
+ std::shared_ptr<DataType> type);
+
+/// \brief Import C++ array and its type from the C data interface.
+///
+/// The ArrowArray struct has its contents moved (as per the C data interface
+/// specification) to a private object held alive by the resulting array.
+/// The ArrowSchema struct is released, even if this function fails.
+///
+/// \param[in,out] array C data interface struct holding the array data
+/// \param[in,out] type C data interface struct holding the array type
+/// \return Imported array object
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> ImportArray(struct ArrowArray* array,
+ struct ArrowSchema* type);
+
+/// \brief Import C++ record batch from the C data interface.
+///
+/// The ArrowArray struct has its contents moved (as per the C data interface
+/// specification) to a private object held alive by the resulting record batch.
+///
+/// \param[in,out] array C data interface struct holding the record batch data
+/// \param[in] schema schema of the imported record batch
+/// \return Imported record batch object
+ARROW_EXPORT
+Result<std::shared_ptr<RecordBatch>> ImportRecordBatch(struct ArrowArray* array,
+ std::shared_ptr<Schema> schema);
+
+/// \brief Import C++ record batch and its schema from the C data interface.
+///
+/// The type represented by the ArrowSchema struct must be a struct type array.
+/// The ArrowArray struct has its contents moved (as per the C data interface
+/// specification) to a private object held alive by the resulting record batch.
+/// The ArrowSchema struct is released, even if this function fails.
+///
+/// \param[in,out] array C data interface struct holding the record batch data
+/// \param[in,out] schema C data interface struct holding the record batch schema
+/// \return Imported record batch object
+ARROW_EXPORT
+Result<std::shared_ptr<RecordBatch>> ImportRecordBatch(struct ArrowArray* array,
+ struct ArrowSchema* schema);
+
+/// @}
+
+/// \defgroup c-stream-interface Functions for working with the C data interface.
+///
+/// @{
+
+/// \brief EXPERIMENTAL: Export C++ RecordBatchReader using the C stream interface.
+///
+/// The resulting ArrowArrayStream struct keeps the record batch reader alive
+/// until its release callback is called by the consumer.
+///
+/// \param[in] reader RecordBatchReader object to export
+/// \param[out] out C struct where to export the stream
+ARROW_EXPORT
+Status ExportRecordBatchReader(std::shared_ptr<RecordBatchReader> reader,
+ struct ArrowArrayStream* out);
+
+/// \brief EXPERIMENTAL: Import C++ RecordBatchReader from the C stream interface.
+///
+/// The ArrowArrayStream struct has its contents moved to a private object
+/// held alive by the resulting record batch reader.
+///
+/// \param[in,out] stream C stream interface struct
+/// \return Imported RecordBatchReader object
+ARROW_EXPORT
+Result<std::shared_ptr<RecordBatchReader>> ImportRecordBatchReader(
+ struct ArrowArrayStream* stream);
+
+/// @}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/c/bridge_benchmark.cc b/src/arrow/cpp/src/arrow/c/bridge_benchmark.cc
new file mode 100644
index 000000000..1ae4657fc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/c/bridge_benchmark.cc
@@ -0,0 +1,159 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/array.h"
+#include "arrow/c/bridge.h"
+#include "arrow/c/helpers.h"
+#include "arrow/ipc/json_simple.h"
+#include "arrow/record_batch.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+
+std::shared_ptr<Schema> ExampleSchema() {
+ auto f0 = field("f0", utf8());
+ auto f1 = field("f1", timestamp(TimeUnit::MICRO, "UTC"));
+ auto f2 = field("f2", int64());
+ auto f3 = field("f3", int16());
+ auto f4 = field("f4", int16());
+ auto f5 = field("f5", float32());
+ auto f6 = field("f6", float32());
+ auto f7 = field("f7", float32());
+ auto f8 = field("f8", decimal(19, 10));
+ return schema({f0, f1, f2, f3, f4, f5, f6, f7, f8});
+}
+
+std::shared_ptr<RecordBatch> ExampleRecordBatch() {
+ // We don't care about the actual data, since it's exported as raw buffer pointers
+ auto schema = ExampleSchema();
+ int64_t length = 1000;
+ std::vector<std::shared_ptr<Array>> columns;
+ for (const auto& field : schema->fields()) {
+ auto array = *MakeArrayOfNull(field->type(), length);
+ columns.push_back(array);
+ }
+ return RecordBatch::Make(schema, length, columns);
+}
+
+static void ExportType(benchmark::State& state) { // NOLINT non-const reference
+ struct ArrowSchema c_export;
+ auto type = utf8();
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(ExportType(*type, &c_export));
+ ArrowSchemaRelease(&c_export);
+ }
+ state.SetItemsProcessed(state.iterations());
+}
+
+static void ExportSchema(benchmark::State& state) { // NOLINT non-const reference
+ struct ArrowSchema c_export;
+ auto schema = ExampleSchema();
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(ExportSchema(*schema, &c_export));
+ ArrowSchemaRelease(&c_export);
+ }
+ state.SetItemsProcessed(state.iterations());
+}
+
+static void ExportArray(benchmark::State& state) { // NOLINT non-const reference
+ struct ArrowArray c_export;
+ auto array = ArrayFromJSON(utf8(), R"(["foo", "bar", null])");
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(ExportArray(*array, &c_export));
+ ArrowArrayRelease(&c_export);
+ }
+ state.SetItemsProcessed(state.iterations());
+}
+
+static void ExportRecordBatch(benchmark::State& state) { // NOLINT non-const reference
+ struct ArrowArray c_export;
+ auto batch = ExampleRecordBatch();
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(ExportRecordBatch(*batch, &c_export));
+ ArrowArrayRelease(&c_export);
+ }
+ state.SetItemsProcessed(state.iterations());
+}
+
+static void ExportImportType(benchmark::State& state) { // NOLINT non-const reference
+ struct ArrowSchema c_export;
+ auto type = utf8();
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(ExportType(*type, &c_export));
+ ImportType(&c_export).ValueOrDie();
+ }
+ state.SetItemsProcessed(state.iterations());
+}
+
+static void ExportImportSchema(benchmark::State& state) { // NOLINT non-const reference
+ struct ArrowSchema c_export;
+ auto schema = ExampleSchema();
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(ExportSchema(*schema, &c_export));
+ ImportSchema(&c_export).ValueOrDie();
+ }
+ state.SetItemsProcessed(state.iterations());
+}
+
+static void ExportImportArray(benchmark::State& state) { // NOLINT non-const reference
+ struct ArrowArray c_export;
+ auto array = ArrayFromJSON(utf8(), R"(["foo", "bar", null])");
+ auto type = array->type();
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(ExportArray(*array, &c_export));
+ ImportArray(&c_export, type).ValueOrDie();
+ }
+ state.SetItemsProcessed(state.iterations());
+}
+
+static void ExportImportRecordBatch(
+ benchmark::State& state) { // NOLINT non-const reference
+ struct ArrowArray c_export;
+ auto batch = ExampleRecordBatch();
+ auto schema = batch->schema();
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(ExportRecordBatch(*batch, &c_export));
+ ImportRecordBatch(&c_export, schema).ValueOrDie();
+ }
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(ExportType);
+BENCHMARK(ExportSchema);
+BENCHMARK(ExportArray);
+BENCHMARK(ExportRecordBatch);
+
+BENCHMARK(ExportImportType);
+BENCHMARK(ExportImportSchema);
+BENCHMARK(ExportImportArray);
+BENCHMARK(ExportImportRecordBatch);
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/c/bridge_test.cc b/src/arrow/cpp/src/arrow/c/bridge_test.cc
new file mode 100644
index 000000000..fd2beca82
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/c/bridge_test.cc
@@ -0,0 +1,3226 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cerrno>
+#include <deque>
+#include <functional>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
+
+#include "arrow/c/bridge.h"
+#include "arrow/c/helpers.h"
+#include "arrow/c/util_internal.h"
+#include "arrow/ipc/json_simple.h"
+#include "arrow/memory_pool.h"
+#include "arrow/testing/extension_type.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+
+using internal::ArrayExportGuard;
+using internal::ArrayExportTraits;
+using internal::ArrayStreamExportGuard;
+using internal::ArrayStreamExportTraits;
+using internal::checked_cast;
+using internal::SchemaExportGuard;
+using internal::SchemaExportTraits;
+
+template <typename T>
+struct ExportTraits {};
+
+template <typename T>
+using Exporter = std::function<Status(const T&, struct ArrowSchema*)>;
+
+template <>
+struct ExportTraits<DataType> {
+ static Exporter<DataType> ExportFunc;
+};
+
+template <>
+struct ExportTraits<Field> {
+ static Exporter<Field> ExportFunc;
+};
+
+template <>
+struct ExportTraits<Schema> {
+ static Exporter<Schema> ExportFunc;
+};
+
+Exporter<DataType> ExportTraits<DataType>::ExportFunc = ExportType;
+Exporter<Field> ExportTraits<Field>::ExportFunc = ExportField;
+Exporter<Schema> ExportTraits<Schema>::ExportFunc = ExportSchema;
+
+// An interceptor that checks whether a release callback was called.
+// (for import tests)
+template <typename Traits>
+class ReleaseCallback {
+ public:
+ using CType = typename Traits::CType;
+
+ explicit ReleaseCallback(CType* c_struct) : called_(false) {
+ orig_release_ = c_struct->release;
+ orig_private_data_ = c_struct->private_data;
+ c_struct->release = StaticRelease;
+ c_struct->private_data = this;
+ }
+
+ static void StaticRelease(CType* c_struct) {
+ reinterpret_cast<ReleaseCallback*>(c_struct->private_data)->Release(c_struct);
+ }
+
+ void Release(CType* c_struct) {
+ ASSERT_FALSE(called_) << "ReleaseCallback called twice";
+ called_ = true;
+ ASSERT_FALSE(Traits::IsReleasedFunc(c_struct))
+ << "ReleaseCallback called with released Arrow"
+ << (std::is_same<CType, ArrowSchema>::value ? "Schema" : "Array");
+ // Call original release callback
+ c_struct->release = orig_release_;
+ c_struct->private_data = orig_private_data_;
+ Traits::ReleaseFunc(c_struct);
+ ASSERT_TRUE(Traits::IsReleasedFunc(c_struct))
+ << "ReleaseCallback did not release ArrowSchema";
+ }
+
+ void AssertCalled() { ASSERT_TRUE(called_) << "ReleaseCallback was not called"; }
+
+ void AssertNotCalled() { ASSERT_FALSE(called_) << "ReleaseCallback was called"; }
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ReleaseCallback);
+
+ bool called_;
+ void (*orig_release_)(CType*);
+ void* orig_private_data_;
+};
+
+using SchemaReleaseCallback = ReleaseCallback<SchemaExportTraits>;
+using ArrayReleaseCallback = ReleaseCallback<ArrayExportTraits>;
+
+static const std::vector<std::string> kMetadataKeys1{"key1", "key2"};
+static const std::vector<std::string> kMetadataValues1{"", "bar"};
+
+static const std::vector<std::string> kMetadataKeys2{"key"};
+static const std::vector<std::string> kMetadataValues2{"abcde"};
+
+// clang-format off
+static const std::string kEncodedMetadata1{ // NOLINT: runtime/string
+#if ARROW_LITTLE_ENDIAN
+ 2, 0, 0, 0,
+ 4, 0, 0, 0, 'k', 'e', 'y', '1', 0, 0, 0, 0,
+ 4, 0, 0, 0, 'k', 'e', 'y', '2', 3, 0, 0, 0, 'b', 'a', 'r'};
+#else
+ 0, 0, 0, 2,
+ 0, 0, 0, 4, 'k', 'e', 'y', '1', 0, 0, 0, 0,
+ 0, 0, 0, 4, 'k', 'e', 'y', '2', 0, 0, 0, 3, 'b', 'a', 'r'};
+#endif
+
+static const std::string kEncodedMetadata2{ // NOLINT: runtime/string
+#if ARROW_LITTLE_ENDIAN
+ 1, 0, 0, 0,
+ 3, 0, 0, 0, 'k', 'e', 'y', 5, 0, 0, 0, 'a', 'b', 'c', 'd', 'e'};
+#else
+ 0, 0, 0, 1,
+ 0, 0, 0, 3, 'k', 'e', 'y', 0, 0, 0, 5, 'a', 'b', 'c', 'd', 'e'};
+#endif
+
+static const std::string kEncodedUuidMetadata = // NOLINT: runtime/string
+#if ARROW_LITTLE_ENDIAN
+ std::string {2, 0, 0, 0} +
+ std::string {20, 0, 0, 0} + kExtensionTypeKeyName +
+ std::string {4, 0, 0, 0} + "uuid" +
+ std::string {24, 0, 0, 0} + kExtensionMetadataKeyName +
+ std::string {15, 0, 0, 0} + "uuid-serialized";
+#else
+ std::string {0, 0, 0, 2} +
+ std::string {0, 0, 0, 20} + kExtensionTypeKeyName +
+ std::string {0, 0, 0, 4} + "uuid" +
+ std::string {0, 0, 0, 24} + kExtensionMetadataKeyName +
+ std::string {0, 0, 0, 15} + "uuid-serialized";
+#endif
+
+static const std::string kEncodedDictExtensionMetadata = // NOLINT: runtime/string
+#if ARROW_LITTLE_ENDIAN
+ std::string {2, 0, 0, 0} +
+ std::string {20, 0, 0, 0} + kExtensionTypeKeyName +
+ std::string {14, 0, 0, 0} + "dict-extension" +
+ std::string {24, 0, 0, 0} + kExtensionMetadataKeyName +
+ std::string {25, 0, 0, 0} + "dict-extension-serialized";
+#else
+ std::string {0, 0, 0, 2} +
+ std::string {0, 0, 0, 20} + kExtensionTypeKeyName +
+ std::string {0, 0, 0, 14} + "dict-extension" +
+ std::string {0, 0, 0, 24} + kExtensionMetadataKeyName +
+ std::string {0, 0, 0, 25} + "dict-extension-serialized";
+#endif
+
+static const std::string kEncodedComplex128Metadata = // NOLINT: runtime/string
+#if ARROW_LITTLE_ENDIAN
+ std::string {2, 0, 0, 0} +
+ std::string {20, 0, 0, 0} + kExtensionTypeKeyName +
+ std::string {10, 0, 0, 0} + "complex128" +
+ std::string {24, 0, 0, 0} + kExtensionMetadataKeyName +
+ std::string {21, 0, 0, 0} + "complex128-serialized";
+#else
+ std::string {0, 0, 0, 2} +
+ std::string {0, 0, 0, 20} + kExtensionTypeKeyName +
+ std::string {0, 0, 0, 10} + "complex128" +
+ std::string {0, 0, 0, 24} + kExtensionMetadataKeyName +
+ std::string {0, 0, 0, 21} + "complex128-serialized";
+#endif
+// clang-format on
+
+static constexpr int64_t kDefaultFlags = ARROW_FLAG_NULLABLE;
+
+////////////////////////////////////////////////////////////////////////////
+// Schema export tests
+
+struct SchemaExportChecker {
+ SchemaExportChecker(std::vector<std::string> flattened_formats,
+ std::vector<std::string> flattened_names,
+ std::vector<int64_t> flattened_flags = {},
+ std::vector<std::string> flattened_metadata = {})
+ : flattened_formats_(std::move(flattened_formats)),
+ flattened_names_(std::move(flattened_names)),
+ flattened_flags_(
+ flattened_flags.empty()
+ ? std::vector<int64_t>(flattened_formats_.size(), kDefaultFlags)
+ : std::move(flattened_flags)),
+ flattened_metadata_(std::move(flattened_metadata)),
+ flattened_index_(0) {}
+
+ void operator()(struct ArrowSchema* c_export, bool inner = false) {
+ ASSERT_LT(flattened_index_, flattened_formats_.size());
+ ASSERT_LT(flattened_index_, flattened_names_.size());
+ ASSERT_LT(flattened_index_, flattened_flags_.size());
+ ASSERT_EQ(std::string(c_export->format), flattened_formats_[flattened_index_]);
+ ASSERT_EQ(std::string(c_export->name), flattened_names_[flattened_index_]);
+ std::string expected_md;
+ if (!flattened_metadata_.empty()) {
+ expected_md = flattened_metadata_[flattened_index_];
+ }
+ if (!expected_md.empty()) {
+ ASSERT_NE(c_export->metadata, nullptr);
+ ASSERT_EQ(std::string(c_export->metadata, expected_md.size()), expected_md);
+ } else {
+ ASSERT_EQ(c_export->metadata, nullptr);
+ }
+ ASSERT_EQ(c_export->flags, flattened_flags_[flattened_index_]);
+ ++flattened_index_;
+
+ if (c_export->dictionary != nullptr) {
+ // Recurse into dictionary
+ operator()(c_export->dictionary, true);
+ }
+
+ if (c_export->n_children > 0) {
+ ASSERT_NE(c_export->children, nullptr);
+ // Recurse into children
+ for (int64_t i = 0; i < c_export->n_children; ++i) {
+ ASSERT_NE(c_export->children[i], nullptr);
+ operator()(c_export->children[i], true);
+ }
+ } else {
+ ASSERT_EQ(c_export->children, nullptr);
+ }
+
+ if (!inner) {
+ // Caller gave the right number of names and format strings
+ ASSERT_EQ(flattened_index_, flattened_formats_.size());
+ ASSERT_EQ(flattened_index_, flattened_names_.size());
+ ASSERT_EQ(flattened_index_, flattened_flags_.size());
+ }
+ }
+
+ const std::vector<std::string> flattened_formats_;
+ const std::vector<std::string> flattened_names_;
+ std::vector<int64_t> flattened_flags_;
+ const std::vector<std::string> flattened_metadata_;
+ size_t flattened_index_;
+};
+
+class TestSchemaExport : public ::testing::Test {
+ public:
+ void SetUp() override { pool_ = default_memory_pool(); }
+
+ template <typename T>
+ void TestNested(const std::shared_ptr<T>& schema_like,
+ std::vector<std::string> flattened_formats,
+ std::vector<std::string> flattened_names,
+ std::vector<int64_t> flattened_flags = {},
+ std::vector<std::string> flattened_metadata = {}) {
+ SchemaExportChecker checker(std::move(flattened_formats), std::move(flattened_names),
+ std::move(flattened_flags),
+ std::move(flattened_metadata));
+
+ auto orig_bytes = pool_->bytes_allocated();
+
+ struct ArrowSchema c_export;
+ ASSERT_OK(ExportTraits<T>::ExportFunc(*schema_like, &c_export));
+
+ SchemaExportGuard guard(&c_export);
+ auto new_bytes = pool_->bytes_allocated();
+ ASSERT_GT(new_bytes, orig_bytes);
+
+ checker(&c_export);
+
+ // Release the ArrowSchema, underlying data should be destroyed
+ guard.Release();
+ ASSERT_EQ(pool_->bytes_allocated(), orig_bytes);
+ }
+
+ template <typename T>
+ void TestPrimitive(const std::shared_ptr<T>& schema_like, const char* format,
+ const std::string& name = "", int64_t flags = kDefaultFlags,
+ const std::string& metadata = "") {
+ TestNested(schema_like, {format}, {name}, {flags}, {metadata});
+ }
+
+ protected:
+ MemoryPool* pool_;
+};
+
+TEST_F(TestSchemaExport, Primitive) {
+ TestPrimitive(int8(), "c");
+ TestPrimitive(int16(), "s");
+ TestPrimitive(int32(), "i");
+ TestPrimitive(int64(), "l");
+ TestPrimitive(uint8(), "C");
+ TestPrimitive(uint16(), "S");
+ TestPrimitive(uint32(), "I");
+ TestPrimitive(uint64(), "L");
+
+ TestPrimitive(boolean(), "b");
+ TestPrimitive(null(), "n");
+
+ TestPrimitive(float16(), "e");
+ TestPrimitive(float32(), "f");
+ TestPrimitive(float64(), "g");
+
+ TestPrimitive(fixed_size_binary(3), "w:3");
+ TestPrimitive(binary(), "z");
+ TestPrimitive(large_binary(), "Z");
+ TestPrimitive(utf8(), "u");
+ TestPrimitive(large_utf8(), "U");
+
+ TestPrimitive(decimal(16, 4), "d:16,4");
+ TestPrimitive(decimal256(16, 4), "d:16,4,256");
+
+ TestPrimitive(decimal(15, 0), "d:15,0");
+ TestPrimitive(decimal256(15, 0), "d:15,0,256");
+
+ TestPrimitive(decimal(15, -4), "d:15,-4");
+ TestPrimitive(decimal256(15, -4), "d:15,-4,256");
+}
+
+TEST_F(TestSchemaExport, Temporal) {
+ TestPrimitive(date32(), "tdD");
+ TestPrimitive(date64(), "tdm");
+ TestPrimitive(time32(TimeUnit::SECOND), "tts");
+ TestPrimitive(time32(TimeUnit::MILLI), "ttm");
+ TestPrimitive(time64(TimeUnit::MICRO), "ttu");
+ TestPrimitive(time64(TimeUnit::NANO), "ttn");
+ TestPrimitive(duration(TimeUnit::SECOND), "tDs");
+ TestPrimitive(duration(TimeUnit::MILLI), "tDm");
+ TestPrimitive(duration(TimeUnit::MICRO), "tDu");
+ TestPrimitive(duration(TimeUnit::NANO), "tDn");
+ TestPrimitive(month_interval(), "tiM");
+ TestPrimitive(month_day_nano_interval(), "tin");
+ TestPrimitive(day_time_interval(), "tiD");
+
+ TestPrimitive(timestamp(TimeUnit::SECOND), "tss:");
+ TestPrimitive(timestamp(TimeUnit::SECOND, "Europe/Paris"), "tss:Europe/Paris");
+ TestPrimitive(timestamp(TimeUnit::MILLI), "tsm:");
+ TestPrimitive(timestamp(TimeUnit::MILLI, "Europe/Paris"), "tsm:Europe/Paris");
+ TestPrimitive(timestamp(TimeUnit::MICRO), "tsu:");
+ TestPrimitive(timestamp(TimeUnit::MICRO, "Europe/Paris"), "tsu:Europe/Paris");
+ TestPrimitive(timestamp(TimeUnit::NANO), "tsn:");
+ TestPrimitive(timestamp(TimeUnit::NANO, "Europe/Paris"), "tsn:Europe/Paris");
+}
+
+TEST_F(TestSchemaExport, List) {
+ TestNested(list(int8()), {"+l", "c"}, {"", "item"});
+ TestNested(large_list(uint16()), {"+L", "S"}, {"", "item"});
+ TestNested(fixed_size_list(int64(), 2), {"+w:2", "l"}, {"", "item"});
+
+ TestNested(list(large_list(int32())), {"+l", "+L", "i"}, {"", "item", "item"});
+}
+
+TEST_F(TestSchemaExport, Struct) {
+ auto type = struct_({field("a", int8()), field("b", utf8())});
+ TestNested(type, {"+s", "c", "u"}, {"", "a", "b"},
+ {ARROW_FLAG_NULLABLE, ARROW_FLAG_NULLABLE, ARROW_FLAG_NULLABLE});
+
+ // With nullable = false
+ type = struct_({field("a", int8(), /*nullable=*/false), field("b", utf8())});
+ TestNested(type, {"+s", "c", "u"}, {"", "a", "b"},
+ {ARROW_FLAG_NULLABLE, 0, ARROW_FLAG_NULLABLE});
+
+ // With metadata
+ auto f0 = type->field(0);
+ auto f1 =
+ type->field(1)->WithMetadata(key_value_metadata(kMetadataKeys1, kMetadataValues1));
+ type = struct_({f0, f1});
+ TestNested(type, {"+s", "c", "u"}, {"", "a", "b"},
+ {ARROW_FLAG_NULLABLE, 0, ARROW_FLAG_NULLABLE}, {"", "", kEncodedMetadata1});
+}
+
+TEST_F(TestSchemaExport, Map) {
+ TestNested(map(int8(), utf8()), {"+m", "+s", "c", "u"}, {"", "entries", "key", "value"},
+ {ARROW_FLAG_NULLABLE, 0, 0, ARROW_FLAG_NULLABLE});
+ TestNested(
+ map(int8(), utf8(), /*keys_sorted=*/true), {"+m", "+s", "c", "u"},
+ {"", "entries", "key", "value"},
+ {ARROW_FLAG_NULLABLE | ARROW_FLAG_MAP_KEYS_SORTED, 0, 0, ARROW_FLAG_NULLABLE});
+}
+
+TEST_F(TestSchemaExport, Union) {
+ // Dense
+ auto field_a = field("a", int8());
+ auto field_b = field("b", boolean(), /*nullable=*/false);
+ auto type = dense_union({field_a, field_b}, {42, 43});
+ TestNested(type, {"+ud:42,43", "c", "b"}, {"", "a", "b"},
+ {ARROW_FLAG_NULLABLE, ARROW_FLAG_NULLABLE, 0});
+ // Sparse
+ field_a = field("a", int8(), /*nullable=*/false);
+ field_b = field("b", boolean());
+ type = sparse_union({field_a, field_b}, {42, 43});
+ TestNested(type, {"+us:42,43", "c", "b"}, {"", "a", "b"},
+ {ARROW_FLAG_NULLABLE, 0, ARROW_FLAG_NULLABLE});
+}
+
+std::string GetIndexFormat(Type::type type_id) {
+ switch (type_id) {
+ case Type::UINT8:
+ return "C";
+ case Type::INT8:
+ return "c";
+ case Type::UINT16:
+ return "S";
+ case Type::INT16:
+ return "s";
+ case Type::UINT32:
+ return "I";
+ case Type::INT32:
+ return "i";
+ case Type::UINT64:
+ return "L";
+ case Type::INT64:
+ return "l";
+ default:
+ DCHECK(false);
+ return "";
+ }
+}
+
+TEST_F(TestSchemaExport, Dictionary) {
+ for (auto index_ty : all_dictionary_index_types()) {
+ std::string index_fmt = GetIndexFormat(index_ty->id());
+ TestNested(dictionary(index_ty, utf8()), {index_fmt, "u"}, {"", ""});
+ TestNested(dictionary(index_ty, list(utf8()), /*ordered=*/true),
+ {index_fmt, "+l", "u"}, {"", "", "item"},
+ {ARROW_FLAG_NULLABLE | ARROW_FLAG_DICTIONARY_ORDERED, ARROW_FLAG_NULLABLE,
+ ARROW_FLAG_NULLABLE});
+ TestNested(large_list(dictionary(index_ty, list(utf8()))),
+ {"+L", index_fmt, "+l", "u"}, {"", "item", "", "item"});
+ }
+}
+
+TEST_F(TestSchemaExport, Extension) {
+ TestPrimitive(uuid(), "w:16", "", kDefaultFlags, kEncodedUuidMetadata);
+
+ TestNested(dict_extension_type(), {"c", "u"}, {"", ""}, {kDefaultFlags, kDefaultFlags},
+ {kEncodedDictExtensionMetadata, ""});
+
+ TestNested(complex128(), {"+s", "g", "g"}, {"", "real", "imag"},
+ {ARROW_FLAG_NULLABLE, 0, 0}, {kEncodedComplex128Metadata, "", ""});
+}
+
+TEST_F(TestSchemaExport, ExportField) {
+ TestPrimitive(field("thing", null()), "n", "thing", ARROW_FLAG_NULLABLE);
+ // With nullable = false
+ TestPrimitive(field("thing", null(), /*nullable=*/false), "n", "thing", 0);
+ // With metadata
+ auto f = field("thing", null(), /*nullable=*/false);
+ f = f->WithMetadata(key_value_metadata(kMetadataKeys1, kMetadataValues1));
+ TestPrimitive(f, "n", "thing", 0, kEncodedMetadata1);
+}
+
+TEST_F(TestSchemaExport, ExportSchema) {
+ // A schema is exported as an equivalent struct type (+ top-level metadata)
+ auto f1 = field("nulls", null(), /*nullable=*/false);
+ auto f2 = field("lists", list(int64()));
+ auto schema = ::arrow::schema({f1, f2});
+ TestNested(schema, {"+s", "n", "+l", "l"}, {"", "nulls", "lists", "item"},
+ {0, 0, ARROW_FLAG_NULLABLE, ARROW_FLAG_NULLABLE});
+
+ // With field metadata
+ f2 = f2->WithMetadata(key_value_metadata(kMetadataKeys1, kMetadataValues1));
+ schema = ::arrow::schema({f1, f2});
+ TestNested(schema, {"+s", "n", "+l", "l"}, {"", "nulls", "lists", "item"},
+ {0, 0, ARROW_FLAG_NULLABLE, ARROW_FLAG_NULLABLE},
+ {"", "", kEncodedMetadata1, ""});
+
+ // With field metadata and schema metadata
+ schema = schema->WithMetadata(key_value_metadata(kMetadataKeys2, kMetadataValues2));
+ TestNested(schema, {"+s", "n", "+l", "l"}, {"", "nulls", "lists", "item"},
+ {0, 0, ARROW_FLAG_NULLABLE, ARROW_FLAG_NULLABLE},
+ {kEncodedMetadata2, "", kEncodedMetadata1, ""});
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Array export tests
+
+struct ArrayExportChecker {
+ void operator()(struct ArrowArray* c_export, const ArrayData& expected_data) {
+ ASSERT_EQ(c_export->length, expected_data.length);
+ ASSERT_EQ(c_export->null_count, expected_data.null_count);
+ ASSERT_EQ(c_export->offset, expected_data.offset);
+
+ auto expected_n_buffers = static_cast<int64_t>(expected_data.buffers.size());
+ auto expected_buffers = expected_data.buffers.data();
+ if (!internal::HasValidityBitmap(expected_data.type->id())) {
+ --expected_n_buffers;
+ ++expected_buffers;
+ }
+ ASSERT_EQ(c_export->n_buffers, expected_n_buffers);
+ ASSERT_NE(c_export->buffers, nullptr);
+ for (int64_t i = 0; i < c_export->n_buffers; ++i) {
+ auto expected_ptr = expected_buffers[i] ? expected_buffers[i]->data() : nullptr;
+ ASSERT_EQ(c_export->buffers[i], expected_ptr);
+ }
+
+ if (expected_data.dictionary != nullptr) {
+ // Recurse into dictionary
+ ASSERT_NE(c_export->dictionary, nullptr);
+ operator()(c_export->dictionary, *expected_data.dictionary);
+ } else {
+ ASSERT_EQ(c_export->dictionary, nullptr);
+ }
+
+ ASSERT_EQ(c_export->n_children,
+ static_cast<int64_t>(expected_data.child_data.size()));
+ if (c_export->n_children > 0) {
+ ASSERT_NE(c_export->children, nullptr);
+ // Recurse into children
+ for (int64_t i = 0; i < c_export->n_children; ++i) {
+ ASSERT_NE(c_export->children[i], nullptr);
+ operator()(c_export->children[i], *expected_data.child_data[i]);
+ }
+ } else {
+ ASSERT_EQ(c_export->children, nullptr);
+ }
+ }
+};
+
+struct RecordBatchExportChecker {
+ void operator()(struct ArrowArray* c_export, const RecordBatch& expected_batch) {
+ ASSERT_EQ(c_export->length, expected_batch.num_rows());
+ ASSERT_EQ(c_export->null_count, 0);
+ ASSERT_EQ(c_export->offset, 0);
+
+ ASSERT_EQ(c_export->n_buffers, 1); // Like a struct array
+ ASSERT_NE(c_export->buffers, nullptr);
+ ASSERT_EQ(c_export->buffers[0], nullptr); // No null bitmap
+ ASSERT_EQ(c_export->dictionary, nullptr);
+
+ ASSERT_EQ(c_export->n_children, expected_batch.num_columns());
+ if (c_export->n_children > 0) {
+ ArrayExportChecker array_checker{};
+
+ ASSERT_NE(c_export->children, nullptr);
+ // Recurse into children
+ for (int i = 0; i < expected_batch.num_columns(); ++i) {
+ ASSERT_NE(c_export->children[i], nullptr);
+ array_checker(c_export->children[i], *expected_batch.column(i)->data());
+ }
+ } else {
+ ASSERT_EQ(c_export->children, nullptr);
+ }
+ }
+};
+
+class TestArrayExport : public ::testing::Test {
+ public:
+ void SetUp() override { pool_ = default_memory_pool(); }
+
+ static std::function<Result<std::shared_ptr<Array>>()> JSONArrayFactory(
+ std::shared_ptr<DataType> type, const char* json) {
+ return [=]() { return ArrayFromJSON(type, json); };
+ }
+
+ template <typename ArrayFactory, typename ExportCheckFunc>
+ void TestWithArrayFactory(ArrayFactory&& factory, ExportCheckFunc&& check_func) {
+ auto orig_bytes = pool_->bytes_allocated();
+
+ std::shared_ptr<Array> arr;
+ ASSERT_OK_AND_ASSIGN(arr, ToResult(factory()));
+ ARROW_SCOPED_TRACE("type = ", arr->type()->ToString(),
+ ", array data = ", arr->ToString());
+ const ArrayData& data = *arr->data(); // non-owning reference
+ struct ArrowArray c_export;
+ ASSERT_OK(ExportArray(*arr, &c_export));
+
+ ArrayExportGuard guard(&c_export);
+ auto new_bytes = pool_->bytes_allocated();
+ ASSERT_GT(new_bytes, orig_bytes);
+
+ // Release the shared_ptr<Array>, underlying data should be held alive
+ arr.reset();
+ ASSERT_EQ(pool_->bytes_allocated(), new_bytes);
+ check_func(&c_export, data);
+
+ // Release the ArrowArray, underlying data should be destroyed
+ guard.Release();
+ ASSERT_EQ(pool_->bytes_allocated(), orig_bytes);
+ }
+
+ template <typename ArrayFactory>
+ void TestNested(ArrayFactory&& factory) {
+ ArrayExportChecker checker;
+ TestWithArrayFactory(std::forward<ArrayFactory>(factory), checker);
+ }
+
+ void TestNested(const std::shared_ptr<DataType>& type, const char* json) {
+ TestNested(JSONArrayFactory(type, json));
+ }
+
+ template <typename ArrayFactory>
+ void TestPrimitive(ArrayFactory&& factory) {
+ TestNested(std::forward<ArrayFactory>(factory));
+ }
+
+ void TestPrimitive(const std::shared_ptr<DataType>& type, const char* json) {
+ TestNested(type, json);
+ }
+
+ template <typename ArrayFactory, typename ExportCheckFunc>
+ void TestMoveWithArrayFactory(ArrayFactory&& factory, ExportCheckFunc&& check_func) {
+ auto orig_bytes = pool_->bytes_allocated();
+
+ std::shared_ptr<Array> arr;
+ ASSERT_OK_AND_ASSIGN(arr, ToResult(factory()));
+ const ArrayData& data = *arr->data(); // non-owning reference
+ struct ArrowArray c_export_temp, c_export_final;
+ ASSERT_OK(ExportArray(*arr, &c_export_temp));
+
+ // Move the ArrowArray to its final location
+ ArrowArrayMove(&c_export_temp, &c_export_final);
+ ASSERT_TRUE(ArrowArrayIsReleased(&c_export_temp));
+
+ ArrayExportGuard guard(&c_export_final);
+ auto new_bytes = pool_->bytes_allocated();
+ ASSERT_GT(new_bytes, orig_bytes);
+ check_func(&c_export_final, data);
+
+ // Release the shared_ptr<Array>, underlying data should be held alive
+ arr.reset();
+ ASSERT_EQ(pool_->bytes_allocated(), new_bytes);
+ check_func(&c_export_final, data);
+
+ // Release the ArrowArray, underlying data should be destroyed
+ guard.Release();
+ ASSERT_EQ(pool_->bytes_allocated(), orig_bytes);
+ }
+
+ template <typename ArrayFactory>
+ void TestMoveNested(ArrayFactory&& factory) {
+ ArrayExportChecker checker;
+
+ TestMoveWithArrayFactory(std::forward<ArrayFactory>(factory), checker);
+ }
+
+ void TestMoveNested(const std::shared_ptr<DataType>& type, const char* json) {
+ TestMoveNested(JSONArrayFactory(type, json));
+ }
+
+ void TestMovePrimitive(const std::shared_ptr<DataType>& type, const char* json) {
+ TestMoveNested(type, json);
+ }
+
+ template <typename ArrayFactory, typename ExportCheckFunc>
+ void TestMoveChildWithArrayFactory(ArrayFactory&& factory, int64_t child_id,
+ ExportCheckFunc&& check_func) {
+ auto orig_bytes = pool_->bytes_allocated();
+
+ std::shared_ptr<Array> arr;
+ ASSERT_OK_AND_ASSIGN(arr, ToResult(factory()));
+ struct ArrowArray c_export_parent, c_export_child;
+ ASSERT_OK(ExportArray(*arr, &c_export_parent));
+
+ auto bytes_with_parent = pool_->bytes_allocated();
+ ASSERT_GT(bytes_with_parent, orig_bytes);
+
+ // Move the child ArrowArray to its final location
+ {
+ ArrayExportGuard parent_guard(&c_export_parent);
+ ASSERT_LT(child_id, c_export_parent.n_children);
+ ArrowArrayMove(c_export_parent.children[child_id], &c_export_child);
+ }
+ ArrayExportGuard child_guard(&c_export_child);
+
+ // Now parent is released
+ ASSERT_TRUE(ArrowArrayIsReleased(&c_export_parent));
+ auto bytes_with_child = pool_->bytes_allocated();
+ ASSERT_LT(bytes_with_child, bytes_with_parent);
+ ASSERT_GT(bytes_with_child, orig_bytes);
+
+ const ArrayData& data = *arr->data()->child_data[child_id]; // non-owning reference
+ check_func(&c_export_child, data);
+
+ // Release the shared_ptr<Array>, some underlying data should be held alive
+ arr.reset();
+ ASSERT_LT(pool_->bytes_allocated(), bytes_with_child);
+ ASSERT_GT(pool_->bytes_allocated(), orig_bytes);
+ check_func(&c_export_child, data);
+
+ // Release the ArrowArray, underlying data should be destroyed
+ child_guard.Release();
+ ASSERT_EQ(pool_->bytes_allocated(), orig_bytes);
+ }
+
+ template <typename ArrayFactory>
+ void TestMoveChild(ArrayFactory&& factory, int64_t child_id) {
+ ArrayExportChecker checker;
+
+ TestMoveChildWithArrayFactory(std::forward<ArrayFactory>(factory), child_id, checker);
+ }
+
+ void TestMoveChild(const std::shared_ptr<DataType>& type, const char* json,
+ int64_t child_id) {
+ TestMoveChild(JSONArrayFactory(type, json), child_id);
+ }
+
+ template <typename ArrayFactory, typename ExportCheckFunc>
+ void TestMoveChildrenWithArrayFactory(ArrayFactory&& factory,
+ const std::vector<int64_t> children_ids,
+ ExportCheckFunc&& check_func) {
+ auto orig_bytes = pool_->bytes_allocated();
+
+ std::shared_ptr<Array> arr;
+ ASSERT_OK_AND_ASSIGN(arr, ToResult(factory()));
+ struct ArrowArray c_export_parent;
+ ASSERT_OK(ExportArray(*arr, &c_export_parent));
+
+ auto bytes_with_parent = pool_->bytes_allocated();
+ ASSERT_GT(bytes_with_parent, orig_bytes);
+
+ // Move the children ArrowArrays to their final locations
+ std::vector<struct ArrowArray> c_export_children(children_ids.size());
+ std::vector<ArrayExportGuard> child_guards;
+ std::vector<const ArrayData*> child_data;
+ {
+ ArrayExportGuard parent_guard(&c_export_parent);
+ for (size_t i = 0; i < children_ids.size(); ++i) {
+ const auto child_id = children_ids[i];
+ ASSERT_LT(child_id, c_export_parent.n_children);
+ ArrowArrayMove(c_export_parent.children[child_id], &c_export_children[i]);
+ child_guards.emplace_back(&c_export_children[i]);
+ // Keep non-owning pointer to the child ArrayData
+ child_data.push_back(arr->data()->child_data[child_id].get());
+ }
+ }
+
+ // Now parent is released
+ ASSERT_TRUE(ArrowArrayIsReleased(&c_export_parent));
+ auto bytes_with_child = pool_->bytes_allocated();
+ ASSERT_LT(bytes_with_child, bytes_with_parent);
+ ASSERT_GT(bytes_with_child, orig_bytes);
+ for (size_t i = 0; i < children_ids.size(); ++i) {
+ check_func(&c_export_children[i], *child_data[i]);
+ }
+
+ // Release the shared_ptr<Array>, the children data should be held alive
+ arr.reset();
+ ASSERT_LT(pool_->bytes_allocated(), bytes_with_child);
+ ASSERT_GT(pool_->bytes_allocated(), orig_bytes);
+ for (size_t i = 0; i < children_ids.size(); ++i) {
+ check_func(&c_export_children[i], *child_data[i]);
+ }
+
+ // Release the ArrowArrays, underlying data should be destroyed
+ for (auto& child_guard : child_guards) {
+ child_guard.Release();
+ }
+ ASSERT_EQ(pool_->bytes_allocated(), orig_bytes);
+ }
+
+ template <typename ArrayFactory>
+ void TestMoveChildren(ArrayFactory&& factory, const std::vector<int64_t> children_ids) {
+ ArrayExportChecker checker;
+
+ TestMoveChildrenWithArrayFactory(std::forward<ArrayFactory>(factory), children_ids,
+ checker);
+ }
+
+ void TestMoveChildren(const std::shared_ptr<DataType>& type, const char* json,
+ const std::vector<int64_t> children_ids) {
+ TestMoveChildren(JSONArrayFactory(type, json), children_ids);
+ }
+
+ protected:
+ MemoryPool* pool_;
+};
+
+TEST_F(TestArrayExport, Primitive) {
+ TestPrimitive(int8(), "[1, 2, null, -3]");
+ TestPrimitive(int16(), "[1, 2, -3]");
+ TestPrimitive(int32(), "[1, 2, null, -3]");
+ TestPrimitive(int64(), "[1, 2, -3]");
+ TestPrimitive(uint8(), "[1, 2, 3]");
+ TestPrimitive(uint16(), "[1, 2, null, 3]");
+ TestPrimitive(uint32(), "[1, 2, 3]");
+ TestPrimitive(uint64(), "[1, 2, null, 3]");
+
+ TestPrimitive(boolean(), "[true, false, null]");
+ TestPrimitive(null(), "[null, null]");
+
+ TestPrimitive(float32(), "[1.5, null]");
+ TestPrimitive(float64(), "[1.5, null]");
+
+ TestPrimitive(fixed_size_binary(3), R"(["foo", "bar", null])");
+ TestPrimitive(binary(), R"(["foo", "bar", null])");
+ TestPrimitive(large_binary(), R"(["foo", "bar", null])");
+ TestPrimitive(utf8(), R"(["foo", "bar", null])");
+ TestPrimitive(large_utf8(), R"(["foo", "bar", null])");
+
+ TestPrimitive(decimal(16, 4), R"(["1234.5670", null])");
+ TestPrimitive(decimal256(16, 4), R"(["1234.5670", null])");
+
+ TestPrimitive(month_day_nano_interval(), R"([[-1, 5, 20], null])");
+}
+
+TEST_F(TestArrayExport, PrimitiveSliced) {
+ auto factory = []() { return ArrayFromJSON(int16(), "[1, 2, null, -3]")->Slice(1, 2); };
+
+ TestPrimitive(factory);
+}
+
+TEST_F(TestArrayExport, Null) {
+ TestPrimitive(null(), "[null, null, null]");
+ TestPrimitive(null(), "[]");
+}
+
+TEST_F(TestArrayExport, Temporal) {
+ const char* json = "[1, 2, null, 42]";
+ TestPrimitive(date32(), json);
+ TestPrimitive(date64(), json);
+ TestPrimitive(time32(TimeUnit::SECOND), json);
+ TestPrimitive(time32(TimeUnit::MILLI), json);
+ TestPrimitive(time64(TimeUnit::MICRO), json);
+ TestPrimitive(time64(TimeUnit::NANO), json);
+ TestPrimitive(duration(TimeUnit::SECOND), json);
+ TestPrimitive(duration(TimeUnit::MILLI), json);
+ TestPrimitive(duration(TimeUnit::MICRO), json);
+ TestPrimitive(duration(TimeUnit::NANO), json);
+ TestPrimitive(month_interval(), json);
+
+ TestPrimitive(day_time_interval(), "[[7, 600], null]");
+
+ json = R"(["1970-01-01","2000-02-29","1900-02-28"])";
+ TestPrimitive(timestamp(TimeUnit::SECOND), json);
+ TestPrimitive(timestamp(TimeUnit::SECOND, "Europe/Paris"), json);
+ TestPrimitive(timestamp(TimeUnit::MILLI), json);
+ TestPrimitive(timestamp(TimeUnit::MILLI, "Europe/Paris"), json);
+ TestPrimitive(timestamp(TimeUnit::MICRO), json);
+ TestPrimitive(timestamp(TimeUnit::MICRO, "Europe/Paris"), json);
+ TestPrimitive(timestamp(TimeUnit::NANO), json);
+ TestPrimitive(timestamp(TimeUnit::NANO, "Europe/Paris"), json);
+}
+
+TEST_F(TestArrayExport, List) {
+ TestNested(list(int8()), "[[1, 2], [3, null], null]");
+ TestNested(large_list(uint16()), "[[1, 2], [3, null], null]");
+ TestNested(fixed_size_list(int64(), 2), "[[1, 2], [3, null], null]");
+
+ TestNested(list(large_list(int32())), "[[[1, 2], [3], null], null]");
+}
+
+TEST_F(TestArrayExport, ListSliced) {
+ {
+ auto factory = []() {
+ return ArrayFromJSON(list(int8()), "[[1, 2], [3, null], [4, 5, 6], null]")
+ ->Slice(1, 2);
+ };
+ TestNested(factory);
+ }
+ {
+ auto factory = []() {
+ auto values = ArrayFromJSON(int16(), "[1, 2, 3, 4, null, 5, 6, 7, 8]")->Slice(1, 6);
+ auto offsets = ArrayFromJSON(int32(), "[0, 2, 3, 5, 6]")->Slice(2, 4);
+ return ListArray::FromArrays(*offsets, *values);
+ };
+ TestNested(factory);
+ }
+}
+
+TEST_F(TestArrayExport, Struct) {
+ const char* data = R"([[1, "foo"], [2, null]])";
+ auto type = struct_({field("a", int8()), field("b", utf8())});
+ TestNested(type, data);
+}
+
+TEST_F(TestArrayExport, Map) {
+ const char* json = R"([[[1, "foo"], [2, null]], [[3, "bar"]]])";
+ TestNested(map(int8(), utf8()), json);
+ TestNested(map(int8(), utf8(), /*keys_sorted=*/true), json);
+}
+
+TEST_F(TestArrayExport, Union) {
+ const char* data = "[null, [42, 1], [43, true], [42, null], [42, 2]]";
+ // Dense
+ auto field_a = field("a", int8());
+ auto field_b = field("b", boolean(), /*nullable=*/false);
+ auto type = dense_union({field_a, field_b}, {42, 43});
+ TestNested(type, data);
+ // Sparse
+ field_a = field("a", int8(), /*nullable=*/false);
+ field_b = field("b", boolean());
+ type = sparse_union({field_a, field_b}, {42, 43});
+ TestNested(type, data);
+}
+
+TEST_F(TestArrayExport, Dictionary) {
+ {
+ auto factory = []() {
+ auto values = ArrayFromJSON(utf8(), R"(["foo", "bar", "quux"])");
+ auto indices = ArrayFromJSON(uint16(), "[0, 2, 1, null, 1]");
+ return DictionaryArray::FromArrays(dictionary(indices->type(), values->type()),
+ indices, values);
+ };
+ TestNested(factory);
+ }
+ {
+ auto factory = []() {
+ auto values = ArrayFromJSON(list(utf8()), R"([["abc", "def"], ["efg"], []])");
+ auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]");
+ return DictionaryArray::FromArrays(
+ dictionary(indices->type(), values->type(), /*ordered=*/true), indices, values);
+ };
+ TestNested(factory);
+ }
+ {
+ auto factory = []() -> Result<std::shared_ptr<Array>> {
+ auto values = ArrayFromJSON(list(utf8()), R"([["abc", "def"], ["efg"], []])");
+ auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]");
+ ARROW_ASSIGN_OR_RAISE(
+ auto dict_array,
+ DictionaryArray::FromArrays(dictionary(indices->type(), values->type()),
+ indices, values));
+ auto offsets = ArrayFromJSON(int64(), "[0, 2, 5]");
+ ARROW_ASSIGN_OR_RAISE(auto arr, LargeListArray::FromArrays(*offsets, *dict_array));
+ RETURN_NOT_OK(arr->ValidateFull());
+ return arr;
+ };
+ TestNested(factory);
+ }
+}
+
+TEST_F(TestArrayExport, Extension) {
+ TestPrimitive(ExampleUuid);
+ TestPrimitive(ExampleSmallint);
+ TestPrimitive(ExampleComplex128);
+}
+
+TEST_F(TestArrayExport, MovePrimitive) {
+ TestMovePrimitive(int8(), "[1, 2, null, -3]");
+ TestMovePrimitive(fixed_size_binary(3), R"(["foo", "bar", null])");
+ TestMovePrimitive(binary(), R"(["foo", "bar", null])");
+}
+
+TEST_F(TestArrayExport, MoveNested) {
+ TestMoveNested(list(int8()), "[[1, 2], [3, null], null]");
+ TestMoveNested(list(large_list(int32())), "[[[1, 2], [3], null], null]");
+ TestMoveNested(struct_({field("a", int8()), field("b", utf8())}),
+ R"([[1, "foo"], [2, null]])");
+}
+
+TEST_F(TestArrayExport, MoveDictionary) {
+ {
+ auto factory = []() {
+ auto values = ArrayFromJSON(utf8(), R"(["foo", "bar", "quux"])");
+ auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]");
+ return DictionaryArray::FromArrays(dictionary(indices->type(), values->type()),
+ indices, values);
+ };
+ TestMoveNested(factory);
+ }
+ {
+ auto factory = []() -> Result<std::shared_ptr<Array>> {
+ auto values = ArrayFromJSON(list(utf8()), R"([["abc", "def"], ["efg"], []])");
+ auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]");
+ ARROW_ASSIGN_OR_RAISE(
+ auto dict_array,
+ DictionaryArray::FromArrays(dictionary(indices->type(), values->type()),
+ indices, values));
+ auto offsets = ArrayFromJSON(int64(), "[0, 2, 5]");
+ ARROW_ASSIGN_OR_RAISE(auto arr, LargeListArray::FromArrays(*offsets, *dict_array));
+ RETURN_NOT_OK(arr->ValidateFull());
+ return arr;
+ };
+ TestMoveNested(factory);
+ }
+}
+
+TEST_F(TestArrayExport, MoveChild) {
+ TestMoveChild(list(int8()), "[[1, 2], [3, null], null]", /*child_id=*/0);
+ TestMoveChild(list(large_list(int32())), "[[[1, 2], [3], null], null]",
+ /*child_id=*/0);
+ TestMoveChild(struct_({field("ints", int8()), field("strs", utf8())}),
+ R"([[1, "foo"], [2, null]])",
+ /*child_id=*/0);
+ TestMoveChild(struct_({field("ints", int8()), field("strs", utf8())}),
+ R"([[1, "foo"], [2, null]])",
+ /*child_id=*/1);
+ {
+ auto factory = []() -> Result<std::shared_ptr<Array>> {
+ auto values = ArrayFromJSON(list(utf8()), R"([["abc", "def"], ["efg"], []])");
+ auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]");
+ ARROW_ASSIGN_OR_RAISE(
+ auto dict_array,
+ DictionaryArray::FromArrays(dictionary(indices->type(), values->type()),
+ indices, values));
+ auto offsets = ArrayFromJSON(int64(), "[0, 2, 5]");
+ ARROW_ASSIGN_OR_RAISE(auto arr, LargeListArray::FromArrays(*offsets, *dict_array));
+ RETURN_NOT_OK(arr->ValidateFull());
+ return arr;
+ };
+ TestMoveChild(factory, /*child_id=*/0);
+ }
+}
+
+TEST_F(TestArrayExport, MoveSeveralChildren) {
+ TestMoveChildren(
+ struct_({field("ints", int8()), field("floats", float64()), field("strs", utf8())}),
+ R"([[1, 1.5, "foo"], [2, 0.0, null]])", /*children_ids=*/{0, 2});
+}
+
+TEST_F(TestArrayExport, ExportArrayAndType) {
+ struct ArrowSchema c_schema {};
+ struct ArrowArray c_array {};
+ SchemaExportGuard schema_guard(&c_schema);
+ ArrayExportGuard array_guard(&c_array);
+
+ auto array = ArrayFromJSON(int8(), "[1, 2, 3]");
+ ASSERT_OK(ExportArray(*array, &c_array, &c_schema));
+ const ArrayData& data = *array->data();
+ array.reset();
+ ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
+ ASSERT_FALSE(ArrowArrayIsReleased(&c_array));
+ ASSERT_EQ(c_schema.format, std::string("c"));
+ ASSERT_EQ(c_schema.n_children, 0);
+ ArrayExportChecker checker{};
+ checker(&c_array, data);
+}
+
+TEST_F(TestArrayExport, ExportRecordBatch) {
+ struct ArrowSchema c_schema {};
+ struct ArrowArray c_array {};
+
+ auto schema = ::arrow::schema(
+ {field("ints", int16()), field("bools", boolean(), /*nullable=*/false)});
+ schema = schema->WithMetadata(key_value_metadata(kMetadataKeys2, kMetadataValues2));
+ auto arr0 = ArrayFromJSON(int16(), "[1, 2, null]");
+ auto arr1 = ArrayFromJSON(boolean(), "[false, true, false]");
+
+ auto batch_factory = [&]() { return RecordBatch::Make(schema, 3, {arr0, arr1}); };
+
+ {
+ auto batch = batch_factory();
+
+ ASSERT_OK(ExportRecordBatch(*batch, &c_array));
+ ArrayExportGuard array_guard(&c_array);
+ RecordBatchExportChecker checker{};
+ checker(&c_array, *batch);
+
+ // Create batch anew, with the same buffer pointers
+ batch = batch_factory();
+ checker(&c_array, *batch);
+ }
+ {
+ // Check one can export both schema and record batch at once
+ auto batch = batch_factory();
+
+ ASSERT_OK(ExportRecordBatch(*batch, &c_array, &c_schema));
+ SchemaExportGuard schema_guard(&c_schema);
+ ArrayExportGuard array_guard(&c_array);
+ ASSERT_EQ(c_schema.format, std::string("+s"));
+ ASSERT_EQ(c_schema.n_children, 2);
+ ASSERT_NE(c_schema.metadata, nullptr);
+ ASSERT_EQ(kEncodedMetadata2,
+ std::string(c_schema.metadata, kEncodedMetadata2.size()));
+ RecordBatchExportChecker checker{};
+ checker(&c_array, *batch);
+
+ // Create batch anew, with the same buffer pointers
+ batch = batch_factory();
+ checker(&c_array, *batch);
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Schema import tests
+
+void NoOpSchemaRelease(struct ArrowSchema* schema) { ArrowSchemaMarkReleased(schema); }
+
+class SchemaStructBuilder {
+ public:
+ SchemaStructBuilder() { Reset(); }
+
+ void Reset() {
+ memset(&c_struct_, 0, sizeof(c_struct_));
+ c_struct_.release = NoOpSchemaRelease;
+ nested_structs_.clear();
+ children_arrays_.clear();
+ }
+
+ // Create a new ArrowSchema struct with a stable C pointer
+ struct ArrowSchema* AddChild() {
+ nested_structs_.emplace_back();
+ struct ArrowSchema* result = &nested_structs_.back();
+ memset(result, 0, sizeof(*result));
+ result->release = NoOpSchemaRelease;
+ return result;
+ }
+
+ // Create a stable C pointer to the N last structs in nested_structs_
+ struct ArrowSchema** NLastChildren(int64_t n_children, struct ArrowSchema* parent) {
+ children_arrays_.emplace_back(n_children);
+ struct ArrowSchema** children = children_arrays_.back().data();
+ int64_t nested_offset;
+ // If parent is itself at the end of nested_structs_, skip it
+ if (parent != nullptr && &nested_structs_.back() == parent) {
+ nested_offset = static_cast<int64_t>(nested_structs_.size()) - n_children - 1;
+ } else {
+ nested_offset = static_cast<int64_t>(nested_structs_.size()) - n_children;
+ }
+ for (int64_t i = 0; i < n_children; ++i) {
+ children[i] = &nested_structs_[nested_offset + i];
+ }
+ return children;
+ }
+
+ struct ArrowSchema* LastChild(struct ArrowSchema* parent = nullptr) {
+ return *NLastChildren(1, parent);
+ }
+
+ void FillPrimitive(struct ArrowSchema* c, const char* format,
+ const char* name = nullptr, int64_t flags = kDefaultFlags) {
+ c->flags = flags;
+ c->format = format;
+ c->name = name;
+ }
+
+ void FillDictionary(struct ArrowSchema* c) { c->dictionary = LastChild(c); }
+
+ void FillListLike(struct ArrowSchema* c, const char* format, const char* name = nullptr,
+ int64_t flags = kDefaultFlags) {
+ c->flags = flags;
+ c->format = format;
+ c->name = name;
+ c->n_children = 1;
+ c->children = NLastChildren(1, c);
+ c->children[0]->name = "item";
+ }
+
+ void FillStructLike(struct ArrowSchema* c, const char* format, int64_t n_children,
+ const char* name = nullptr, int64_t flags = kDefaultFlags) {
+ c->flags = flags;
+ c->format = format;
+ c->name = name;
+ c->n_children = n_children;
+ c->children = NLastChildren(c->n_children, c);
+ }
+
+ void FillPrimitive(const char* format, const char* name = nullptr,
+ int64_t flags = kDefaultFlags) {
+ FillPrimitive(&c_struct_, format, name, flags);
+ }
+
+ void FillDictionary() { FillDictionary(&c_struct_); }
+
+ void FillListLike(const char* format, const char* name = nullptr,
+ int64_t flags = kDefaultFlags) {
+ FillListLike(&c_struct_, format, name, flags);
+ }
+
+ void FillStructLike(const char* format, int64_t n_children, const char* name = nullptr,
+ int64_t flags = kDefaultFlags) {
+ FillStructLike(&c_struct_, format, n_children, name, flags);
+ }
+
+ struct ArrowSchema c_struct_;
+ // Deque elements don't move when the deque is appended to, which allows taking
+ // stable C pointers to them.
+ std::deque<struct ArrowSchema> nested_structs_;
+ std::deque<std::vector<struct ArrowSchema*>> children_arrays_;
+};
+
+class TestSchemaImport : public ::testing::Test, public SchemaStructBuilder {
+ public:
+ void SetUp() override { Reset(); }
+
+ void CheckImport(const std::shared_ptr<DataType>& expected) {
+ SchemaReleaseCallback cb(&c_struct_);
+
+ ASSERT_OK_AND_ASSIGN(auto type, ImportType(&c_struct_));
+ ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_));
+ Reset(); // for further tests
+ cb.AssertCalled(); // was released
+ AssertTypeEqual(*expected, *type);
+ }
+
+ void CheckImport(const std::shared_ptr<Field>& expected) {
+ SchemaReleaseCallback cb(&c_struct_);
+
+ ASSERT_OK_AND_ASSIGN(auto field, ImportField(&c_struct_));
+ ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_));
+ Reset(); // for further tests
+ cb.AssertCalled(); // was released
+ AssertFieldEqual(*expected, *field);
+ }
+
+ void CheckImport(const std::shared_ptr<Schema>& expected) {
+ SchemaReleaseCallback cb(&c_struct_);
+
+ ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_struct_));
+ ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_));
+ Reset(); // for further tests
+ cb.AssertCalled(); // was released
+ AssertSchemaEqual(*expected, *schema);
+ }
+
+ void CheckImportError() {
+ SchemaReleaseCallback cb(&c_struct_);
+
+ ASSERT_RAISES(Invalid, ImportField(&c_struct_));
+ ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_));
+ cb.AssertCalled(); // was released
+ }
+
+ void CheckSchemaImportError() {
+ SchemaReleaseCallback cb(&c_struct_);
+
+ ASSERT_RAISES(Invalid, ImportSchema(&c_struct_));
+ ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_));
+ cb.AssertCalled(); // was released
+ }
+};
+
+TEST_F(TestSchemaImport, Primitive) {
+ FillPrimitive("c");
+ CheckImport(int8());
+ FillPrimitive("c");
+ CheckImport(field("", int8()));
+ FillPrimitive("C");
+ CheckImport(field("", uint8()));
+ FillPrimitive("s");
+ CheckImport(field("", int16()));
+ FillPrimitive("S");
+ CheckImport(field("", uint16()));
+ FillPrimitive("i");
+ CheckImport(field("", int32()));
+ FillPrimitive("I");
+ CheckImport(field("", uint32()));
+ FillPrimitive("l");
+ CheckImport(field("", int64()));
+ FillPrimitive("L");
+ CheckImport(field("", uint64()));
+
+ FillPrimitive("b");
+ CheckImport(field("", boolean()));
+ FillPrimitive("e");
+ CheckImport(field("", float16()));
+ FillPrimitive("f");
+ CheckImport(field("", float32()));
+ FillPrimitive("g");
+ CheckImport(field("", float64()));
+
+ FillPrimitive("d:16,4");
+ CheckImport(field("", decimal128(16, 4)));
+ FillPrimitive("d:16,4,128");
+ CheckImport(field("", decimal128(16, 4)));
+ FillPrimitive("d:16,4,256");
+ CheckImport(field("", decimal256(16, 4)));
+
+ FillPrimitive("d:16,0");
+ CheckImport(field("", decimal128(16, 0)));
+ FillPrimitive("d:16,0,128");
+ CheckImport(field("", decimal128(16, 0)));
+ FillPrimitive("d:16,0,256");
+ CheckImport(field("", decimal256(16, 0)));
+
+ FillPrimitive("d:16,-4");
+ CheckImport(field("", decimal128(16, -4)));
+ FillPrimitive("d:16,-4,128");
+ CheckImport(field("", decimal128(16, -4)));
+ FillPrimitive("d:16,-4,256");
+ CheckImport(field("", decimal256(16, -4)));
+}
+
+TEST_F(TestSchemaImport, Temporal) {
+ FillPrimitive("tdD");
+ CheckImport(date32());
+ FillPrimitive("tdm");
+ CheckImport(date64());
+
+ FillPrimitive("tts");
+ CheckImport(time32(TimeUnit::SECOND));
+ FillPrimitive("ttm");
+ CheckImport(time32(TimeUnit::MILLI));
+ FillPrimitive("ttu");
+ CheckImport(time64(TimeUnit::MICRO));
+ FillPrimitive("ttn");
+ CheckImport(time64(TimeUnit::NANO));
+
+ FillPrimitive("tDs");
+ CheckImport(duration(TimeUnit::SECOND));
+ FillPrimitive("tDm");
+ CheckImport(duration(TimeUnit::MILLI));
+ FillPrimitive("tDu");
+ CheckImport(duration(TimeUnit::MICRO));
+ FillPrimitive("tDn");
+ CheckImport(duration(TimeUnit::NANO));
+
+ FillPrimitive("tiM");
+ CheckImport(month_interval());
+ FillPrimitive("tiD");
+ CheckImport(day_time_interval());
+ FillPrimitive("tin");
+ CheckImport(month_day_nano_interval());
+
+ FillPrimitive("tss:");
+ CheckImport(timestamp(TimeUnit::SECOND));
+ FillPrimitive("tsm:");
+ CheckImport(timestamp(TimeUnit::MILLI));
+ FillPrimitive("tsu:");
+ CheckImport(timestamp(TimeUnit::MICRO));
+ FillPrimitive("tsn:");
+ CheckImport(timestamp(TimeUnit::NANO));
+
+ FillPrimitive("tss:Europe/Paris");
+ CheckImport(timestamp(TimeUnit::SECOND, "Europe/Paris"));
+ FillPrimitive("tsm:Europe/Paris");
+ CheckImport(timestamp(TimeUnit::MILLI, "Europe/Paris"));
+ FillPrimitive("tsu:Europe/Paris");
+ CheckImport(timestamp(TimeUnit::MICRO, "Europe/Paris"));
+ FillPrimitive("tsn:Europe/Paris");
+ CheckImport(timestamp(TimeUnit::NANO, "Europe/Paris"));
+}
+
+TEST_F(TestSchemaImport, String) {
+ FillPrimitive("u");
+ CheckImport(utf8());
+ FillPrimitive("z");
+ CheckImport(binary());
+ FillPrimitive("U");
+ CheckImport(large_utf8());
+ FillPrimitive("Z");
+ CheckImport(large_binary());
+
+ FillPrimitive("w:3");
+ CheckImport(fixed_size_binary(3));
+ FillPrimitive("d:15,4");
+ CheckImport(decimal(15, 4));
+}
+
+TEST_F(TestSchemaImport, List) {
+ FillPrimitive(AddChild(), "c");
+ FillListLike("+l");
+ CheckImport(list(int8()));
+
+ FillPrimitive(AddChild(), "s", "item", 0);
+ FillListLike("+l");
+ CheckImport(list(field("item", int16(), /*nullable=*/false)));
+
+ // Large list
+ FillPrimitive(AddChild(), "s");
+ FillListLike("+L");
+ CheckImport(large_list(int16()));
+
+ // Fixed-size list
+ FillPrimitive(AddChild(), "c");
+ FillListLike("+w:3");
+ CheckImport(fixed_size_list(int8(), 3));
+}
+
+TEST_F(TestSchemaImport, NestedList) {
+ FillPrimitive(AddChild(), "c");
+ FillListLike(AddChild(), "+l");
+ FillListLike("+L");
+ CheckImport(large_list(list(int8())));
+
+ FillPrimitive(AddChild(), "c");
+ FillListLike(AddChild(), "+w:3");
+ FillListLike("+l");
+ CheckImport(list(fixed_size_list(int8(), 3)));
+}
+
+TEST_F(TestSchemaImport, Struct) {
+ FillPrimitive(AddChild(), "u", "strs");
+ FillPrimitive(AddChild(), "S", "ints");
+ FillStructLike("+s", 2);
+ auto expected = struct_({field("strs", utf8()), field("ints", uint16())});
+ CheckImport(expected);
+
+ FillPrimitive(AddChild(), "u", "strs", 0);
+ FillPrimitive(AddChild(), "S", "ints", kDefaultFlags);
+ FillStructLike("+s", 2);
+ expected =
+ struct_({field("strs", utf8(), /*nullable=*/false), field("ints", uint16())});
+ CheckImport(expected);
+
+ // With metadata
+ auto c = AddChild();
+ FillPrimitive(c, "u", "strs", 0);
+ c->metadata = kEncodedMetadata2.c_str();
+ FillPrimitive(AddChild(), "S", "ints", kDefaultFlags);
+ FillStructLike("+s", 2);
+ expected = struct_({field("strs", utf8(), /*nullable=*/false,
+ key_value_metadata(kMetadataKeys2, kMetadataValues2)),
+ field("ints", uint16())});
+ CheckImport(expected);
+}
+
+TEST_F(TestSchemaImport, Union) {
+ // Sparse
+ FillPrimitive(AddChild(), "u", "strs");
+ FillPrimitive(AddChild(), "c", "ints");
+ FillStructLike("+us:43,42", 2);
+ auto expected = sparse_union({field("strs", utf8()), field("ints", int8())}, {43, 42});
+ CheckImport(expected);
+
+ // Dense
+ FillPrimitive(AddChild(), "u", "strs");
+ FillPrimitive(AddChild(), "c", "ints");
+ FillStructLike("+ud:43,42", 2);
+ expected = dense_union({field("strs", utf8()), field("ints", int8())}, {43, 42});
+ CheckImport(expected);
+}
+
+TEST_F(TestSchemaImport, Map) {
+ FillPrimitive(AddChild(), "u", "key");
+ FillPrimitive(AddChild(), "i", "value");
+ FillStructLike(AddChild(), "+s", 2, "entries");
+ FillListLike("+m");
+ auto expected = map(utf8(), int32());
+ CheckImport(expected);
+
+ FillPrimitive(AddChild(), "u", "key");
+ FillPrimitive(AddChild(), "i", "value");
+ FillStructLike(AddChild(), "+s", 2, "entries");
+ FillListLike("+m", "", ARROW_FLAG_MAP_KEYS_SORTED);
+ expected = map(utf8(), int32(), /*keys_sorted=*/true);
+ CheckImport(expected);
+}
+
+TEST_F(TestSchemaImport, Dictionary) {
+ FillPrimitive(AddChild(), "u");
+ FillPrimitive("c");
+ FillDictionary();
+ auto expected = dictionary(int8(), utf8());
+ CheckImport(expected);
+
+ FillPrimitive(AddChild(), "u");
+ FillPrimitive("c", "", ARROW_FLAG_NULLABLE | ARROW_FLAG_DICTIONARY_ORDERED);
+ FillDictionary();
+ expected = dictionary(int8(), utf8(), /*ordered=*/true);
+ CheckImport(expected);
+
+ FillPrimitive(AddChild(), "u");
+ FillListLike(AddChild(), "+L");
+ FillPrimitive("c");
+ FillDictionary();
+ expected = dictionary(int8(), large_list(utf8()));
+ CheckImport(expected);
+
+ FillPrimitive(AddChild(), "u");
+ FillPrimitive(AddChild(), "c");
+ FillDictionary(LastChild());
+ FillListLike("+l");
+ expected = list(dictionary(int8(), utf8()));
+ CheckImport(expected);
+}
+
+TEST_F(TestSchemaImport, UnregisteredExtension) {
+ FillPrimitive("w:16");
+ c_struct_.metadata = kEncodedUuidMetadata.c_str();
+ auto expected = fixed_size_binary(16);
+ CheckImport(expected);
+}
+
+TEST_F(TestSchemaImport, RegisteredExtension) {
+ {
+ ExtensionTypeGuard guard(uuid());
+ FillPrimitive("w:16");
+ c_struct_.metadata = kEncodedUuidMetadata.c_str();
+ auto expected = uuid();
+ CheckImport(expected);
+ }
+ {
+ ExtensionTypeGuard guard(dict_extension_type());
+ FillPrimitive(AddChild(), "u");
+ FillPrimitive("c");
+ FillDictionary();
+ c_struct_.metadata = kEncodedDictExtensionMetadata.c_str();
+ auto expected = dict_extension_type();
+ CheckImport(expected);
+ }
+}
+
+TEST_F(TestSchemaImport, FormatStringError) {
+ FillPrimitive("");
+ CheckImportError();
+ FillPrimitive("cc");
+ CheckImportError();
+ FillPrimitive("w3");
+ CheckImportError();
+ FillPrimitive("w:three");
+ CheckImportError();
+ FillPrimitive("w:3,5");
+ CheckImportError();
+ FillPrimitive("d:15");
+ CheckImportError();
+ FillPrimitive("d:15.4");
+ CheckImportError();
+ FillPrimitive("d:15,z");
+ CheckImportError();
+ FillPrimitive("t");
+ CheckImportError();
+ FillPrimitive("td");
+ CheckImportError();
+ FillPrimitive("tz");
+ CheckImportError();
+ FillPrimitive("tdd");
+ CheckImportError();
+ FillPrimitive("tdDd");
+ CheckImportError();
+ FillPrimitive("tss");
+ CheckImportError();
+ FillPrimitive("tss;UTC");
+ CheckImportError();
+ FillPrimitive("+");
+ CheckImportError();
+ FillPrimitive("+mm");
+ CheckImportError();
+ FillPrimitive("+u");
+ CheckImportError();
+}
+
+TEST_F(TestSchemaImport, UnionError) {
+ FillPrimitive(AddChild(), "u", "strs");
+ FillStructLike("+uz", 1);
+ CheckImportError();
+
+ FillPrimitive(AddChild(), "u", "strs");
+ FillStructLike("+uz:", 1);
+ CheckImportError();
+
+ FillPrimitive(AddChild(), "u", "strs");
+ FillStructLike("+uz:1", 1);
+ CheckImportError();
+
+ FillPrimitive(AddChild(), "u", "strs");
+ FillStructLike("+us:1.2", 1);
+ CheckImportError();
+
+ FillPrimitive(AddChild(), "u", "strs");
+ FillStructLike("+ud:-1", 1);
+ CheckImportError();
+
+ FillPrimitive(AddChild(), "u", "strs");
+ FillStructLike("+ud:1,2", 1);
+ CheckImportError();
+}
+
+TEST_F(TestSchemaImport, DictionaryError) {
+ // Bad index type
+ FillPrimitive(AddChild(), "c");
+ FillPrimitive("u");
+ FillDictionary();
+ CheckImportError();
+
+ // Nested dictionary
+ FillPrimitive(AddChild(), "c");
+ FillPrimitive(AddChild(), "u");
+ FillDictionary(LastChild());
+ FillPrimitive("u");
+ FillDictionary();
+ CheckImportError();
+}
+
+TEST_F(TestSchemaImport, ExtensionError) {
+ ExtensionTypeGuard guard(uuid());
+
+ // Storage type doesn't match
+ FillPrimitive("w:15");
+ c_struct_.metadata = kEncodedUuidMetadata.c_str();
+ CheckImportError();
+
+ // Invalid serialization
+ std::string bogus_metadata = kEncodedUuidMetadata;
+ bogus_metadata[bogus_metadata.size() - 5] += 1;
+ FillPrimitive("w:16");
+ c_struct_.metadata = bogus_metadata.c_str();
+ CheckImportError();
+}
+
+TEST_F(TestSchemaImport, RecursionError) {
+ FillPrimitive(AddChild(), "c", "unused");
+ auto c = AddChild();
+ FillStructLike(c, "+s", 1, "child");
+ FillStructLike("+s", 1, "parent");
+ c->children[0] = &c_struct_;
+ CheckImportError();
+}
+
+TEST_F(TestSchemaImport, ImportField) {
+ FillPrimitive("c", "thing", kDefaultFlags);
+ CheckImport(field("thing", int8()));
+ FillPrimitive("c", "thing", 0);
+ CheckImport(field("thing", int8(), /*nullable=*/false));
+ // With metadata
+ FillPrimitive("c", "thing", kDefaultFlags);
+ c_struct_.metadata = kEncodedMetadata1.c_str();
+ CheckImport(field("thing", int8(), /*nullable=*/true,
+ key_value_metadata(kMetadataKeys1, kMetadataValues1)));
+}
+
+TEST_F(TestSchemaImport, ImportSchema) {
+ FillPrimitive(AddChild(), "l");
+ FillListLike(AddChild(), "+l", "int_lists");
+ FillPrimitive(AddChild(), "u", "strs");
+ FillStructLike("+s", 2);
+ auto f1 = field("int_lists", list(int64()));
+ auto f2 = field("strs", utf8());
+ auto expected = schema({f1, f2});
+ CheckImport(expected);
+
+ // With metadata
+ FillPrimitive(AddChild(), "l");
+ FillListLike(AddChild(), "+l", "int_lists");
+ LastChild()->metadata = kEncodedMetadata2.c_str();
+ FillPrimitive(AddChild(), "u", "strs");
+ FillStructLike("+s", 2);
+ c_struct_.metadata = kEncodedMetadata1.c_str();
+ f1 = f1->WithMetadata(key_value_metadata(kMetadataKeys2, kMetadataValues2));
+ expected = schema({f1, f2}, key_value_metadata(kMetadataKeys1, kMetadataValues1));
+ CheckImport(expected);
+}
+
+TEST_F(TestSchemaImport, ImportSchemaError) {
+ // Not a struct type
+ FillPrimitive("n");
+ CheckSchemaImportError();
+
+ FillPrimitive(AddChild(), "l", "ints");
+ FillPrimitive(AddChild(), "u", "strs");
+ FillStructLike("+us:43,42", 2);
+ CheckSchemaImportError();
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Data import tests
+
+// [true, false, true, true, false, true, true, true] * 2
+static const uint8_t bits_buffer1[] = {0xed, 0xed};
+
+static const void* buffers_no_nulls_no_data[1] = {nullptr};
+static const void* buffers_nulls_no_data1[1] = {bits_buffer1};
+
+static const uint8_t data_buffer1[] = {1, 2, 3, 4, 5, 6, 7, 8,
+ 9, 10, 11, 12, 13, 14, 15, 16};
+static const uint8_t data_buffer2[] = "abcdefghijklmnopqrstuvwxyz";
+#if ARROW_LITTLE_ENDIAN
+static const uint64_t data_buffer3[] = {123456789, 0, 987654321, 0};
+#else
+static const uint64_t data_buffer3[] = {0, 123456789, 0, 987654321};
+#endif
+static const uint8_t data_buffer4[] = {1, 2, 0, 1, 3, 0};
+static const float data_buffer5[] = {0.0f, 1.5f, -2.0f, 3.0f, 4.0f, 5.0f};
+static const double data_buffer6[] = {0.0, 1.5, -2.0, 3.0, 4.0, 5.0};
+static const int32_t data_buffer7[] = {1234, 5678, 9012, 3456};
+static const int64_t data_buffer8[] = {123456789, 987654321, -123456789, -987654321};
+#if ARROW_LITTLE_ENDIAN
+static const void* primitive_buffers_no_nulls1_8[2] = {nullptr, data_buffer1};
+static const void* primitive_buffers_no_nulls1_16[2] = {nullptr, data_buffer1};
+static const void* primitive_buffers_no_nulls1_32[2] = {nullptr, data_buffer1};
+static const void* primitive_buffers_no_nulls1_64[2] = {nullptr, data_buffer1};
+static const void* primitive_buffers_nulls1_8[2] = {bits_buffer1, data_buffer1};
+static const void* primitive_buffers_nulls1_16[2] = {bits_buffer1, data_buffer1};
+#else
+static const uint8_t data_buffer1_16[] = {2, 1, 4, 3, 6, 5, 8, 7,
+ 10, 9, 12, 11, 14, 13, 16, 15};
+static const uint8_t data_buffer1_32[] = {4, 3, 2, 1, 8, 7, 6, 5,
+ 12, 11, 10, 9, 16, 15, 14, 13};
+static const uint8_t data_buffer1_64[] = {8, 7, 6, 5, 4, 3, 2, 1,
+ 16, 15, 14, 13, 12, 11, 10, 9};
+static const void* primitive_buffers_no_nulls1_8[2] = {nullptr, data_buffer1};
+static const void* primitive_buffers_no_nulls1_16[2] = {nullptr, data_buffer1_16};
+static const void* primitive_buffers_no_nulls1_32[2] = {nullptr, data_buffer1_32};
+static const void* primitive_buffers_no_nulls1_64[2] = {nullptr, data_buffer1_64};
+static const void* primitive_buffers_nulls1_8[2] = {bits_buffer1, data_buffer1};
+static const void* primitive_buffers_nulls1_16[2] = {bits_buffer1, data_buffer1_16};
+#endif
+static const void* primitive_buffers_no_nulls2[2] = {nullptr, data_buffer2};
+static const void* primitive_buffers_no_nulls3[2] = {nullptr, data_buffer3};
+static const void* primitive_buffers_no_nulls4[2] = {nullptr, data_buffer4};
+static const void* primitive_buffers_no_nulls5[2] = {nullptr, data_buffer5};
+static const void* primitive_buffers_no_nulls6[2] = {nullptr, data_buffer6};
+static const void* primitive_buffers_no_nulls7[2] = {nullptr, data_buffer7};
+static const void* primitive_buffers_nulls7[2] = {bits_buffer1, data_buffer7};
+static const void* primitive_buffers_no_nulls8[2] = {nullptr, data_buffer8};
+static const void* primitive_buffers_nulls8[2] = {bits_buffer1, data_buffer8};
+
+static const int64_t timestamp_data_buffer1[] = {0, 951782400, -2203977600LL};
+static const int64_t timestamp_data_buffer2[] = {0, 951782400000LL, -2203977600000LL};
+static const int64_t timestamp_data_buffer3[] = {0, 951782400000000LL,
+ -2203977600000000LL};
+static const int64_t timestamp_data_buffer4[] = {0, 951782400000000000LL,
+ -2203977600000000000LL};
+static const void* timestamp_buffers_no_nulls1[2] = {nullptr, timestamp_data_buffer1};
+static const void* timestamp_buffers_nulls1[2] = {bits_buffer1, timestamp_data_buffer1};
+static const void* timestamp_buffers_no_nulls2[2] = {nullptr, timestamp_data_buffer2};
+static const void* timestamp_buffers_no_nulls3[2] = {nullptr, timestamp_data_buffer3};
+static const void* timestamp_buffers_no_nulls4[2] = {nullptr, timestamp_data_buffer4};
+
+static const uint8_t string_data_buffer1[] = "foobarquuxxyzzy";
+
+static const int32_t string_offsets_buffer1[] = {0, 3, 3, 6, 10, 15};
+static const void* string_buffers_no_nulls1[3] = {nullptr, string_offsets_buffer1,
+ string_data_buffer1};
+
+static const int64_t large_string_offsets_buffer1[] = {0, 3, 3, 6, 10};
+static const void* large_string_buffers_no_nulls1[3] = {
+ nullptr, large_string_offsets_buffer1, string_data_buffer1};
+
+static const int32_t list_offsets_buffer1[] = {0, 2, 2, 5, 6, 8};
+static const void* list_buffers_no_nulls1[2] = {nullptr, list_offsets_buffer1};
+static const void* list_buffers_nulls1[2] = {bits_buffer1, list_offsets_buffer1};
+
+static const int64_t large_list_offsets_buffer1[] = {0, 2, 2, 5, 6, 8};
+static const void* large_list_buffers_no_nulls1[2] = {nullptr,
+ large_list_offsets_buffer1};
+
+static const int8_t type_codes_buffer1[] = {42, 42, 43, 43, 42};
+static const int32_t union_offsets_buffer1[] = {0, 1, 0, 1, 2};
+static const void* sparse_union_buffers1_legacy[2] = {nullptr, type_codes_buffer1};
+static const void* dense_union_buffers1_legacy[3] = {nullptr, type_codes_buffer1,
+ union_offsets_buffer1};
+static const void* sparse_union_buffers1[1] = {type_codes_buffer1};
+static const void* dense_union_buffers1[2] = {type_codes_buffer1, union_offsets_buffer1};
+
+void NoOpArrayRelease(struct ArrowArray* schema) { ArrowArrayMarkReleased(schema); }
+
+class TestArrayImport : public ::testing::Test {
+ public:
+ void SetUp() override { Reset(); }
+
+ void Reset() {
+ memset(&c_struct_, 0, sizeof(c_struct_));
+ c_struct_.release = NoOpArrayRelease;
+ nested_structs_.clear();
+ children_arrays_.clear();
+ }
+
+ // Create a new ArrowArray struct with a stable C pointer
+ struct ArrowArray* AddChild() {
+ nested_structs_.emplace_back();
+ struct ArrowArray* result = &nested_structs_.back();
+ memset(result, 0, sizeof(*result));
+ result->release = NoOpArrayRelease;
+ return result;
+ }
+
+ // Create a stable C pointer to the N last structs in nested_structs_
+ struct ArrowArray** NLastChildren(int64_t n_children, struct ArrowArray* parent) {
+ children_arrays_.emplace_back(n_children);
+ struct ArrowArray** children = children_arrays_.back().data();
+ int64_t nested_offset;
+ // If parent is itself at the end of nested_structs_, skip it
+ if (parent != nullptr && &nested_structs_.back() == parent) {
+ nested_offset = static_cast<int64_t>(nested_structs_.size()) - n_children - 1;
+ } else {
+ nested_offset = static_cast<int64_t>(nested_structs_.size()) - n_children;
+ }
+ for (int64_t i = 0; i < n_children; ++i) {
+ children[i] = &nested_structs_[nested_offset + i];
+ }
+ return children;
+ }
+
+ struct ArrowArray* LastChild(struct ArrowArray* parent = nullptr) {
+ return *NLastChildren(1, parent);
+ }
+
+ void FillPrimitive(struct ArrowArray* c, int64_t length, int64_t null_count,
+ int64_t offset, const void** buffers) {
+ c->length = length;
+ c->null_count = null_count;
+ c->offset = offset;
+ c->n_buffers = 2;
+ c->buffers = buffers;
+ }
+
+ void FillDictionary(struct ArrowArray* c) { c->dictionary = LastChild(c); }
+
+ void FillStringLike(struct ArrowArray* c, int64_t length, int64_t null_count,
+ int64_t offset, const void** buffers) {
+ c->length = length;
+ c->null_count = null_count;
+ c->offset = offset;
+ c->n_buffers = 3;
+ c->buffers = buffers;
+ }
+
+ void FillListLike(struct ArrowArray* c, int64_t length, int64_t null_count,
+ int64_t offset, const void** buffers) {
+ c->length = length;
+ c->null_count = null_count;
+ c->offset = offset;
+ c->n_buffers = 2;
+ c->buffers = buffers;
+ c->n_children = 1;
+ c->children = NLastChildren(1, c);
+ }
+
+ void FillFixedSizeListLike(struct ArrowArray* c, int64_t length, int64_t null_count,
+ int64_t offset, const void** buffers) {
+ c->length = length;
+ c->null_count = null_count;
+ c->offset = offset;
+ c->n_buffers = 1;
+ c->buffers = buffers;
+ c->n_children = 1;
+ c->children = NLastChildren(1, c);
+ }
+
+ void FillStructLike(struct ArrowArray* c, int64_t length, int64_t null_count,
+ int64_t offset, int64_t n_children, const void** buffers) {
+ c->length = length;
+ c->null_count = null_count;
+ c->offset = offset;
+ c->n_buffers = 1;
+ c->buffers = buffers;
+ c->n_children = n_children;
+ c->children = NLastChildren(c->n_children, c);
+ }
+
+ // `legacy` selects pre-ARROW-14179 behaviour
+ void FillUnionLike(struct ArrowArray* c, UnionMode::type mode, int64_t length,
+ int64_t null_count, int64_t offset, int64_t n_children,
+ const void** buffers, bool legacy) {
+ c->length = length;
+ c->null_count = null_count;
+ c->offset = offset;
+ if (mode == UnionMode::SPARSE) {
+ c->n_buffers = legacy ? 2 : 1;
+ } else {
+ c->n_buffers = legacy ? 3 : 2;
+ }
+ c->buffers = buffers;
+ c->n_children = n_children;
+ c->children = NLastChildren(c->n_children, c);
+ }
+
+ void FillPrimitive(int64_t length, int64_t null_count, int64_t offset,
+ const void** buffers) {
+ FillPrimitive(&c_struct_, length, null_count, offset, buffers);
+ }
+
+ void FillDictionary() { FillDictionary(&c_struct_); }
+
+ void FillStringLike(int64_t length, int64_t null_count, int64_t offset,
+ const void** buffers) {
+ FillStringLike(&c_struct_, length, null_count, offset, buffers);
+ }
+
+ void FillListLike(int64_t length, int64_t null_count, int64_t offset,
+ const void** buffers) {
+ FillListLike(&c_struct_, length, null_count, offset, buffers);
+ }
+
+ void FillFixedSizeListLike(int64_t length, int64_t null_count, int64_t offset,
+ const void** buffers) {
+ FillFixedSizeListLike(&c_struct_, length, null_count, offset, buffers);
+ }
+
+ void FillStructLike(int64_t length, int64_t null_count, int64_t offset,
+ int64_t n_children, const void** buffers) {
+ FillStructLike(&c_struct_, length, null_count, offset, n_children, buffers);
+ }
+
+ void FillUnionLike(UnionMode::type mode, int64_t length, int64_t null_count,
+ int64_t offset, int64_t n_children, const void** buffers,
+ bool legacy) {
+ FillUnionLike(&c_struct_, mode, length, null_count, offset, n_children, buffers,
+ legacy);
+ }
+
+ void CheckImport(const std::shared_ptr<Array>& expected) {
+ ArrayReleaseCallback cb(&c_struct_);
+
+ auto type = expected->type();
+ ASSERT_OK_AND_ASSIGN(auto array, ImportArray(&c_struct_, type));
+ ASSERT_TRUE(ArrowArrayIsReleased(&c_struct_)); // was moved
+ Reset(); // for further tests
+
+ ASSERT_OK(array->ValidateFull());
+ // Special case: Null array doesn't have any data, so it needn't
+ // keep the ArrowArray struct alive.
+ if (type->id() != Type::NA) {
+ cb.AssertNotCalled();
+ }
+ AssertArraysEqual(*expected, *array, true);
+ array.reset();
+ cb.AssertCalled();
+ }
+
+ void CheckImport(const std::shared_ptr<RecordBatch>& expected) {
+ ArrayReleaseCallback cb(&c_struct_);
+
+ auto schema = expected->schema();
+ ASSERT_OK_AND_ASSIGN(auto batch, ImportRecordBatch(&c_struct_, schema));
+ ASSERT_TRUE(ArrowArrayIsReleased(&c_struct_)); // was moved
+ Reset(); // for further tests
+
+ ASSERT_OK(batch->ValidateFull());
+ AssertBatchesEqual(*expected, *batch);
+ cb.AssertNotCalled();
+ batch.reset();
+ cb.AssertCalled();
+ }
+
+ void CheckImportError(const std::shared_ptr<DataType>& type) {
+ ArrayReleaseCallback cb(&c_struct_);
+
+ ASSERT_RAISES(Invalid, ImportArray(&c_struct_, type));
+ ASSERT_TRUE(ArrowArrayIsReleased(&c_struct_));
+ Reset(); // for further tests
+ cb.AssertCalled(); // was released
+ }
+
+ void CheckImportError(const std::shared_ptr<Schema>& schema) {
+ ArrayReleaseCallback cb(&c_struct_);
+
+ ASSERT_RAISES(Invalid, ImportRecordBatch(&c_struct_, schema));
+ ASSERT_TRUE(ArrowArrayIsReleased(&c_struct_));
+ Reset(); // for further tests
+ cb.AssertCalled(); // was released
+ }
+
+ protected:
+ struct ArrowArray c_struct_;
+ // Deque elements don't move when the deque is appended to, which allows taking
+ // stable C pointers to them.
+ std::deque<struct ArrowArray> nested_structs_;
+ std::deque<std::vector<struct ArrowArray*>> children_arrays_;
+};
+
+TEST_F(TestArrayImport, Primitive) {
+ FillPrimitive(3, 0, 0, primitive_buffers_no_nulls1_8);
+ CheckImport(ArrayFromJSON(int8(), "[1, 2, 3]"));
+ FillPrimitive(5, 0, 0, primitive_buffers_no_nulls1_8);
+ CheckImport(ArrayFromJSON(uint8(), "[1, 2, 3, 4, 5]"));
+ FillPrimitive(3, 0, 0, primitive_buffers_no_nulls1_16);
+ CheckImport(ArrayFromJSON(int16(), "[513, 1027, 1541]"));
+ FillPrimitive(3, 0, 0, primitive_buffers_no_nulls1_16);
+ CheckImport(ArrayFromJSON(uint16(), "[513, 1027, 1541]"));
+ FillPrimitive(2, 0, 0, primitive_buffers_no_nulls1_32);
+ CheckImport(ArrayFromJSON(int32(), "[67305985, 134678021]"));
+ FillPrimitive(2, 0, 0, primitive_buffers_no_nulls1_32);
+ CheckImport(ArrayFromJSON(uint32(), "[67305985, 134678021]"));
+ FillPrimitive(2, 0, 0, primitive_buffers_no_nulls1_64);
+ CheckImport(ArrayFromJSON(int64(), "[578437695752307201, 1157159078456920585]"));
+ FillPrimitive(2, 0, 0, primitive_buffers_no_nulls1_64);
+ CheckImport(ArrayFromJSON(uint64(), "[578437695752307201, 1157159078456920585]"));
+
+ FillPrimitive(3, 0, 0, primitive_buffers_no_nulls1_8);
+ CheckImport(ArrayFromJSON(boolean(), "[true, false, false]"));
+ FillPrimitive(6, 0, 0, primitive_buffers_no_nulls5);
+ CheckImport(ArrayFromJSON(float32(), "[0.0, 1.5, -2.0, 3.0, 4.0, 5.0]"));
+ FillPrimitive(6, 0, 0, primitive_buffers_no_nulls6);
+ CheckImport(ArrayFromJSON(float64(), "[0.0, 1.5, -2.0, 3.0, 4.0, 5.0]"));
+
+ // With nulls
+ FillPrimitive(9, -1, 0, primitive_buffers_nulls1_8);
+ CheckImport(ArrayFromJSON(int8(), "[1, null, 3, 4, null, 6, 7, 8, 9]"));
+ FillPrimitive(9, 2, 0, primitive_buffers_nulls1_8);
+ CheckImport(ArrayFromJSON(int8(), "[1, null, 3, 4, null, 6, 7, 8, 9]"));
+ FillPrimitive(3, -1, 0, primitive_buffers_nulls1_16);
+ CheckImport(ArrayFromJSON(int16(), "[513, null, 1541]"));
+ FillPrimitive(3, 1, 0, primitive_buffers_nulls1_16);
+ CheckImport(ArrayFromJSON(int16(), "[513, null, 1541]"));
+ FillPrimitive(3, -1, 0, primitive_buffers_nulls1_8);
+ CheckImport(ArrayFromJSON(boolean(), "[true, null, false]"));
+ FillPrimitive(3, 1, 0, primitive_buffers_nulls1_8);
+ CheckImport(ArrayFromJSON(boolean(), "[true, null, false]"));
+}
+
+TEST_F(TestArrayImport, Temporal) {
+ FillPrimitive(3, 0, 0, primitive_buffers_no_nulls7);
+ CheckImport(ArrayFromJSON(date32(), "[1234, 5678, 9012]"));
+ FillPrimitive(3, 0, 0, primitive_buffers_no_nulls8);
+ CheckImport(ArrayFromJSON(date64(), "[123456789, 987654321, -123456789]"));
+
+ FillPrimitive(2, 0, 0, primitive_buffers_no_nulls7);
+ CheckImport(ArrayFromJSON(time32(TimeUnit::SECOND), "[1234, 5678]"));
+ FillPrimitive(2, 0, 0, primitive_buffers_no_nulls7);
+ CheckImport(ArrayFromJSON(time32(TimeUnit::MILLI), "[1234, 5678]"));
+ FillPrimitive(2, 0, 0, primitive_buffers_no_nulls8);
+ CheckImport(ArrayFromJSON(time64(TimeUnit::MICRO), "[123456789, 987654321]"));
+ FillPrimitive(2, 0, 0, primitive_buffers_no_nulls8);
+ CheckImport(ArrayFromJSON(time64(TimeUnit::NANO), "[123456789, 987654321]"));
+
+ FillPrimitive(2, 0, 0, primitive_buffers_no_nulls8);
+ CheckImport(ArrayFromJSON(duration(TimeUnit::SECOND), "[123456789, 987654321]"));
+ FillPrimitive(2, 0, 0, primitive_buffers_no_nulls8);
+ CheckImport(ArrayFromJSON(duration(TimeUnit::MILLI), "[123456789, 987654321]"));
+ FillPrimitive(2, 0, 0, primitive_buffers_no_nulls8);
+ CheckImport(ArrayFromJSON(duration(TimeUnit::MICRO), "[123456789, 987654321]"));
+ FillPrimitive(2, 0, 0, primitive_buffers_no_nulls8);
+ CheckImport(ArrayFromJSON(duration(TimeUnit::NANO), "[123456789, 987654321]"));
+
+ FillPrimitive(3, 0, 0, primitive_buffers_no_nulls7);
+ CheckImport(ArrayFromJSON(month_interval(), "[1234, 5678, 9012]"));
+ FillPrimitive(2, 0, 0, primitive_buffers_no_nulls7);
+ CheckImport(ArrayFromJSON(day_time_interval(), "[[1234, 5678], [9012, 3456]]"));
+
+ const char* json = R"(["1970-01-01","2000-02-29","1900-02-28"])";
+ FillPrimitive(3, 0, 0, timestamp_buffers_no_nulls1);
+ CheckImport(ArrayFromJSON(timestamp(TimeUnit::SECOND), json));
+ FillPrimitive(3, 0, 0, timestamp_buffers_no_nulls2);
+ CheckImport(ArrayFromJSON(timestamp(TimeUnit::MILLI), json));
+ FillPrimitive(3, 0, 0, timestamp_buffers_no_nulls3);
+ CheckImport(ArrayFromJSON(timestamp(TimeUnit::MICRO), json));
+ FillPrimitive(3, 0, 0, timestamp_buffers_no_nulls4);
+ CheckImport(ArrayFromJSON(timestamp(TimeUnit::NANO), json));
+
+ // With nulls
+ FillPrimitive(3, -1, 0, primitive_buffers_nulls7);
+ CheckImport(ArrayFromJSON(date32(), "[1234, null, 9012]"));
+ FillPrimitive(3, -1, 0, primitive_buffers_nulls8);
+ CheckImport(ArrayFromJSON(date64(), "[123456789, null, -123456789]"));
+ FillPrimitive(2, -1, 0, primitive_buffers_nulls8);
+ CheckImport(ArrayFromJSON(time64(TimeUnit::NANO), "[123456789, null]"));
+ FillPrimitive(2, -1, 0, primitive_buffers_nulls8);
+ CheckImport(ArrayFromJSON(duration(TimeUnit::NANO), "[123456789, null]"));
+ FillPrimitive(3, -1, 0, primitive_buffers_nulls7);
+ CheckImport(ArrayFromJSON(month_interval(), "[1234, null, 9012]"));
+ FillPrimitive(2, -1, 0, primitive_buffers_nulls7);
+ CheckImport(ArrayFromJSON(day_time_interval(), "[[1234, 5678], null]"));
+ FillPrimitive(3, -1, 0, timestamp_buffers_nulls1);
+ CheckImport(ArrayFromJSON(timestamp(TimeUnit::SECOND, "UTC+2"),
+ R"(["1970-01-01",null,"1900-02-28"])"));
+}
+
+TEST_F(TestArrayImport, Null) {
+ // Arrow C++ used to export null arrays with a null bitmap buffer
+ for (const int64_t n_buffers : {0, 1}) {
+ const void* buffers[] = {nullptr};
+ c_struct_.length = 3;
+ c_struct_.null_count = 3;
+ c_struct_.offset = 0;
+ c_struct_.buffers = buffers;
+ c_struct_.n_buffers = n_buffers;
+ CheckImport(ArrayFromJSON(null(), "[null, null, null]"));
+ }
+}
+
+TEST_F(TestArrayImport, PrimitiveWithOffset) {
+ FillPrimitive(3, 0, 2, primitive_buffers_no_nulls1_8);
+ CheckImport(ArrayFromJSON(int8(), "[3, 4, 5]"));
+ FillPrimitive(3, 0, 1, primitive_buffers_no_nulls1_16);
+ CheckImport(ArrayFromJSON(uint16(), "[1027, 1541, 2055]"));
+
+ FillPrimitive(4, 0, 7, primitive_buffers_no_nulls1_8);
+ CheckImport(ArrayFromJSON(boolean(), "[false, false, true, false]"));
+}
+
+TEST_F(TestArrayImport, NullWithOffset) {
+ const void* buffers[] = {nullptr};
+ c_struct_.length = 3;
+ c_struct_.null_count = 3;
+ c_struct_.offset = 5;
+ c_struct_.n_buffers = 1;
+ c_struct_.buffers = buffers;
+ CheckImport(ArrayFromJSON(null(), "[null, null, null]"));
+}
+
+TEST_F(TestArrayImport, String) {
+ FillStringLike(4, 0, 0, string_buffers_no_nulls1);
+ CheckImport(ArrayFromJSON(utf8(), R"(["foo", "", "bar", "quux"])"));
+ FillStringLike(4, 0, 0, string_buffers_no_nulls1);
+ CheckImport(ArrayFromJSON(binary(), R"(["foo", "", "bar", "quux"])"));
+ FillStringLike(4, 0, 0, large_string_buffers_no_nulls1);
+ CheckImport(ArrayFromJSON(large_utf8(), R"(["foo", "", "bar", "quux"])"));
+ FillStringLike(4, 0, 0, large_string_buffers_no_nulls1);
+ CheckImport(ArrayFromJSON(large_binary(), R"(["foo", "", "bar", "quux"])"));
+
+ FillPrimitive(2, 0, 0, primitive_buffers_no_nulls2);
+ CheckImport(ArrayFromJSON(fixed_size_binary(3), R"(["abc", "def"])"));
+ FillPrimitive(2, 0, 0, primitive_buffers_no_nulls3);
+ CheckImport(ArrayFromJSON(decimal(15, 4), R"(["12345.6789", "98765.4321"])"));
+}
+
+TEST_F(TestArrayImport, List) {
+ FillPrimitive(AddChild(), 8, 0, 0, primitive_buffers_no_nulls1_8);
+ FillListLike(5, 0, 0, list_buffers_no_nulls1);
+ CheckImport(ArrayFromJSON(list(int8()), "[[1, 2], [], [3, 4, 5], [6], [7, 8]]"));
+ FillPrimitive(AddChild(), 5, 0, 0, primitive_buffers_no_nulls1_16);
+ FillListLike(3, 1, 0, list_buffers_nulls1);
+ CheckImport(ArrayFromJSON(list(int16()), "[[513, 1027], null, [1541, 2055, 2569]]"));
+
+ // Large list
+ FillPrimitive(AddChild(), 5, 0, 0, primitive_buffers_no_nulls1_16);
+ FillListLike(3, 0, 0, large_list_buffers_no_nulls1);
+ CheckImport(
+ ArrayFromJSON(large_list(int16()), "[[513, 1027], [], [1541, 2055, 2569]]"));
+
+ // Fixed-size list
+ FillPrimitive(AddChild(), 9, 0, 0, primitive_buffers_no_nulls1_8);
+ FillFixedSizeListLike(3, 0, 0, buffers_no_nulls_no_data);
+ CheckImport(
+ ArrayFromJSON(fixed_size_list(int8(), 3), "[[1, 2, 3], [4, 5, 6], [7, 8, 9]]"));
+}
+
+TEST_F(TestArrayImport, NestedList) {
+ FillPrimitive(AddChild(), 8, 0, 0, primitive_buffers_no_nulls1_8);
+ FillListLike(AddChild(), 5, 0, 0, list_buffers_no_nulls1);
+ FillListLike(3, 0, 0, large_list_buffers_no_nulls1);
+ CheckImport(ArrayFromJSON(large_list(list(int8())),
+ "[[[1, 2], []], [], [[3, 4, 5], [6], [7, 8]]]"));
+
+ FillPrimitive(AddChild(), 6, 0, 0, primitive_buffers_no_nulls1_8);
+ FillFixedSizeListLike(AddChild(), 2, 0, 0, buffers_no_nulls_no_data);
+ FillListLike(2, 0, 0, list_buffers_no_nulls1);
+ CheckImport(
+ ArrayFromJSON(list(fixed_size_list(int8(), 3)), "[[[1, 2, 3], [4, 5, 6]], []]"));
+}
+
+TEST_F(TestArrayImport, ListWithOffset) {
+ // Offset in child
+ FillPrimitive(AddChild(), 8, 0, 1, primitive_buffers_no_nulls1_8);
+ FillListLike(5, 0, 0, list_buffers_no_nulls1);
+ CheckImport(ArrayFromJSON(list(int8()), "[[2, 3], [], [4, 5, 6], [7], [8, 9]]"));
+
+ FillPrimitive(AddChild(), 9, 0, 1, primitive_buffers_no_nulls1_8);
+ FillFixedSizeListLike(3, 0, 0, buffers_no_nulls_no_data);
+ CheckImport(
+ ArrayFromJSON(fixed_size_list(int8(), 3), "[[2, 3, 4], [5, 6, 7], [8, 9, 10]]"));
+
+ // Offset in parent
+ FillPrimitive(AddChild(), 8, 0, 0, primitive_buffers_no_nulls1_8);
+ FillListLike(4, 0, 1, list_buffers_no_nulls1);
+ CheckImport(ArrayFromJSON(list(int8()), "[[], [3, 4, 5], [6], [7, 8]]"));
+
+ FillPrimitive(AddChild(), 9, 0, 0, primitive_buffers_no_nulls1_8);
+ FillFixedSizeListLike(3, 0, 1, buffers_no_nulls_no_data);
+ CheckImport(
+ ArrayFromJSON(fixed_size_list(int8(), 3), "[[4, 5, 6], [7, 8, 9], [10, 11, 12]]"));
+
+ // Both
+ FillPrimitive(AddChild(), 8, 0, 2, primitive_buffers_no_nulls1_8);
+ FillListLike(4, 0, 1, list_buffers_no_nulls1);
+ CheckImport(ArrayFromJSON(list(int8()), "[[], [5, 6, 7], [8], [9, 10]]"));
+
+ FillPrimitive(AddChild(), 9, 0, 2, primitive_buffers_no_nulls1_8);
+ FillFixedSizeListLike(3, 0, 1, buffers_no_nulls_no_data);
+ CheckImport(ArrayFromJSON(fixed_size_list(int8(), 3),
+ "[[6, 7, 8], [9, 10, 11], [12, 13, 14]]"));
+}
+
+TEST_F(TestArrayImport, Struct) {
+ FillStringLike(AddChild(), 3, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 3, -1, 0, primitive_buffers_nulls1_16);
+ FillStructLike(3, 0, 0, 2, buffers_no_nulls_no_data);
+ auto expected = ArrayFromJSON(struct_({field("strs", utf8()), field("ints", uint16())}),
+ R"([["foo", 513], ["", null], ["bar", 1541]])");
+ CheckImport(expected);
+
+ FillStringLike(AddChild(), 3, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 3, 0, 0, primitive_buffers_no_nulls1_16);
+ FillStructLike(3, -1, 0, 2, buffers_nulls_no_data1);
+ expected = ArrayFromJSON(struct_({field("strs", utf8()), field("ints", uint16())}),
+ R"([["foo", 513], null, ["bar", 1541]])");
+ CheckImport(expected);
+
+ FillStringLike(AddChild(), 3, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 3, 0, 0, primitive_buffers_no_nulls1_16);
+ FillStructLike(3, -1, 0, 2, buffers_nulls_no_data1);
+ expected = ArrayFromJSON(
+ struct_({field("strs", utf8(), /*nullable=*/false), field("ints", uint16())}),
+ R"([["foo", 513], null, ["bar", 1541]])");
+ CheckImport(expected);
+}
+
+TEST_F(TestArrayImport, SparseUnion) {
+ auto type = sparse_union({field("strs", utf8()), field("ints", int8())}, {43, 42});
+ auto expected =
+ ArrayFromJSON(type, R"([[42, 1], [42, null], [43, "bar"], [43, "quux"]])");
+
+ FillStringLike(AddChild(), 4, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 4, -1, 0, primitive_buffers_nulls1_8);
+ FillUnionLike(UnionMode::SPARSE, 4, 0, 0, 2, sparse_union_buffers1, /*legacy=*/false);
+ CheckImport(expected);
+
+ // Legacy format with null bitmap (ARROW-14179)
+ FillStringLike(AddChild(), 4, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 4, -1, 0, primitive_buffers_nulls1_8);
+ FillUnionLike(UnionMode::SPARSE, 4, 0, 0, 2, sparse_union_buffers1_legacy,
+ /*legacy=*/true);
+ CheckImport(expected);
+}
+
+TEST_F(TestArrayImport, DenseUnion) {
+ auto type = dense_union({field("strs", utf8()), field("ints", int8())}, {43, 42});
+ auto expected =
+ ArrayFromJSON(type, R"([[42, 1], [42, null], [43, "foo"], [43, ""], [42, 3]])");
+
+ FillStringLike(AddChild(), 2, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 3, -1, 0, primitive_buffers_nulls1_8);
+ FillUnionLike(UnionMode::DENSE, 5, 0, 0, 2, dense_union_buffers1, /*legacy=*/false);
+ CheckImport(expected);
+
+ // Legacy format with null bitmap (ARROW-14179)
+ FillStringLike(AddChild(), 2, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 3, -1, 0, primitive_buffers_nulls1_8);
+ FillUnionLike(UnionMode::DENSE, 5, 0, 0, 2, dense_union_buffers1_legacy,
+ /*legacy=*/true);
+ CheckImport(expected);
+}
+
+TEST_F(TestArrayImport, StructWithOffset) {
+ // Child
+ FillStringLike(AddChild(), 3, 0, 1, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 3, 0, 2, primitive_buffers_no_nulls1_8);
+ FillStructLike(3, 0, 0, 2, buffers_no_nulls_no_data);
+ auto expected = ArrayFromJSON(struct_({field("strs", utf8()), field("ints", int8())}),
+ R"([["", 3], ["bar", 4], ["quux", 5]])");
+ CheckImport(expected);
+
+ // Parent and child
+ FillStringLike(AddChild(), 4, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 4, 0, 2, primitive_buffers_no_nulls1_8);
+ FillStructLike(3, 0, 1, 2, buffers_no_nulls_no_data);
+ expected = ArrayFromJSON(struct_({field("strs", utf8()), field("ints", int8())}),
+ R"([["", 4], ["bar", 5], ["quux", 6]])");
+ CheckImport(expected);
+}
+
+TEST_F(TestArrayImport, Map) {
+ FillStringLike(AddChild(), 5, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 5, 0, 0, primitive_buffers_no_nulls1_8);
+ FillStructLike(AddChild(), 5, 0, 0, 2, buffers_no_nulls_no_data);
+ FillListLike(3, 1, 0, list_buffers_nulls1);
+ auto expected = ArrayFromJSON(
+ map(utf8(), uint8()),
+ R"([[["foo", 1], ["", 2]], null, [["bar", 3], ["quux", 4], ["xyzzy", 5]]])");
+ CheckImport(expected);
+}
+
+TEST_F(TestArrayImport, Dictionary) {
+ FillStringLike(AddChild(), 4, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(6, 0, 0, primitive_buffers_no_nulls4);
+ FillDictionary();
+
+ auto dict_values = ArrayFromJSON(utf8(), R"(["foo", "", "bar", "quux"])");
+ auto indices = ArrayFromJSON(int8(), "[1, 2, 0, 1, 3, 0]");
+ ASSERT_OK_AND_ASSIGN(
+ auto expected,
+ DictionaryArray::FromArrays(dictionary(int8(), utf8()), indices, dict_values));
+ CheckImport(expected);
+
+ FillStringLike(AddChild(), 4, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(6, 0, 0, primitive_buffers_no_nulls4);
+ FillDictionary();
+
+ ASSERT_OK_AND_ASSIGN(
+ expected, DictionaryArray::FromArrays(dictionary(int8(), utf8(), /*ordered=*/true),
+ indices, dict_values));
+ CheckImport(expected);
+}
+
+TEST_F(TestArrayImport, NestedDictionary) {
+ FillPrimitive(AddChild(), 6, 0, 0, primitive_buffers_no_nulls1_8);
+ FillListLike(AddChild(), 4, 0, 0, list_buffers_no_nulls1);
+ FillPrimitive(6, 0, 0, primitive_buffers_no_nulls4);
+ FillDictionary();
+
+ auto dict_values = ArrayFromJSON(list(int8()), "[[1, 2], [], [3, 4, 5], [6]]");
+ auto indices = ArrayFromJSON(int8(), "[1, 2, 0, 1, 3, 0]");
+ ASSERT_OK_AND_ASSIGN(auto expected,
+ DictionaryArray::FromArrays(dictionary(int8(), list(int8())),
+ indices, dict_values));
+ CheckImport(expected);
+
+ FillStringLike(AddChild(), 4, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 6, 0, 0, primitive_buffers_no_nulls4);
+ FillDictionary(LastChild());
+ FillListLike(3, 0, 0, list_buffers_no_nulls1);
+
+ dict_values = ArrayFromJSON(utf8(), R"(["foo", "", "bar", "quux"])");
+ indices = ArrayFromJSON(int8(), "[1, 2, 0, 1, 3, 0]");
+ ASSERT_OK_AND_ASSIGN(
+ auto dict_array,
+ DictionaryArray::FromArrays(dictionary(int8(), utf8()), indices, dict_values));
+ auto offsets = ArrayFromJSON(int32(), "[0, 2, 2, 5]");
+ ASSERT_OK_AND_ASSIGN(expected, ListArray::FromArrays(*offsets, *dict_array));
+ CheckImport(expected);
+}
+
+TEST_F(TestArrayImport, DictionaryWithOffset) {
+ FillStringLike(AddChild(), 3, 0, 1, string_buffers_no_nulls1);
+ FillPrimitive(3, 0, 0, primitive_buffers_no_nulls4);
+ FillDictionary();
+
+ auto expected = DictArrayFromJSON(dictionary(int8(), utf8()), "[1, 2, 0]",
+ R"(["", "bar", "quux"])");
+ CheckImport(expected);
+
+ FillStringLike(AddChild(), 4, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(4, 0, 2, primitive_buffers_no_nulls4);
+ FillDictionary();
+
+ expected = DictArrayFromJSON(dictionary(int8(), utf8()), "[0, 1, 3, 0]",
+ R"(["foo", "", "bar", "quux"])");
+ CheckImport(expected);
+}
+
+TEST_F(TestArrayImport, RegisteredExtension) {
+ ExtensionTypeGuard guard({smallint(), dict_extension_type(), complex128()});
+
+ // smallint
+ FillPrimitive(3, 0, 0, primitive_buffers_no_nulls1_16);
+ auto expected =
+ ExtensionType::WrapArray(smallint(), ArrayFromJSON(int16(), "[513, 1027, 1541]"));
+ CheckImport(expected);
+
+ // dict_extension_type
+ FillStringLike(AddChild(), 4, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(6, 0, 0, primitive_buffers_no_nulls4);
+ FillDictionary();
+
+ auto storage = DictArrayFromJSON(dictionary(int8(), utf8()), "[1, 2, 0, 1, 3, 0]",
+ R"(["foo", "", "bar", "quux"])");
+ expected = ExtensionType::WrapArray(dict_extension_type(), storage);
+ CheckImport(expected);
+
+ // complex128
+ FillPrimitive(AddChild(), 3, 0, /*offset=*/0, primitive_buffers_no_nulls6);
+ FillPrimitive(AddChild(), 3, 0, /*offset=*/3, primitive_buffers_no_nulls6);
+ FillStructLike(3, 0, 0, 2, buffers_no_nulls_no_data);
+ expected = MakeComplex128(ArrayFromJSON(float64(), "[0.0, 1.5, -2.0]"),
+ ArrayFromJSON(float64(), "[3.0, 4.0, 5.0]"));
+ CheckImport(expected);
+}
+
+TEST_F(TestArrayImport, PrimitiveError) {
+ // Bad number of buffers
+ FillPrimitive(3, 0, 0, primitive_buffers_no_nulls1_8);
+ c_struct_.n_buffers = 1;
+ CheckImportError(int8());
+
+ // Zero null bitmap but non-zero null_count
+ FillPrimitive(3, 1, 0, primitive_buffers_no_nulls1_8);
+ CheckImportError(int8());
+}
+
+TEST_F(TestArrayImport, StructError) {
+ // Bad number of children
+ FillStringLike(AddChild(), 3, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 3, -1, 0, primitive_buffers_nulls1_8);
+ FillStructLike(3, 0, 0, 2, buffers_no_nulls_no_data);
+ CheckImportError(struct_({field("strs", utf8())}));
+}
+
+TEST_F(TestArrayImport, MapError) {
+ // Bad number of (struct) children in map child
+ FillStringLike(AddChild(), 5, 0, 0, string_buffers_no_nulls1);
+ FillStructLike(AddChild(), 5, 0, 0, 1, buffers_no_nulls_no_data);
+ FillListLike(3, 1, 0, list_buffers_nulls1);
+ CheckImportError(map(utf8(), uint8()));
+}
+
+TEST_F(TestArrayImport, UnionError) {
+ // Non-zero null count
+ auto type = sparse_union({field("strs", utf8()), field("ints", int8())}, {43, 42});
+ FillStringLike(AddChild(), 4, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 4, -1, 0, primitive_buffers_nulls1_8);
+ FillUnionLike(UnionMode::SPARSE, 4, -1, 0, 2, sparse_union_buffers1, /*legacy=*/false);
+ CheckImportError(type);
+
+ type = dense_union({field("strs", utf8()), field("ints", int8())}, {43, 42});
+ FillStringLike(AddChild(), 2, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 3, -1, 0, primitive_buffers_nulls1_8);
+ FillUnionLike(UnionMode::DENSE, 5, -1, 0, 2, dense_union_buffers1, /*legacy=*/false);
+ CheckImportError(type);
+}
+
+TEST_F(TestArrayImport, DictionaryError) {
+ // Missing dictionary field
+ FillPrimitive(3, 0, 0, primitive_buffers_no_nulls4);
+ CheckImportError(dictionary(int8(), utf8()));
+
+ // Unexpected dictionary field
+ FillStringLike(AddChild(), 4, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(6, 0, 0, primitive_buffers_no_nulls4);
+ FillDictionary();
+ CheckImportError(int8());
+}
+
+TEST_F(TestArrayImport, RecursionError) {
+ // Infinite loop through children
+ FillStringLike(AddChild(), 3, 0, 0, string_buffers_no_nulls1);
+ FillStructLike(AddChild(), 3, 0, 0, 1, buffers_no_nulls_no_data);
+ FillStructLike(3, 0, 0, 1, buffers_no_nulls_no_data);
+ c_struct_.children[0] = &c_struct_;
+ CheckImportError(struct_({field("ints", struct_({field("ints", int8())}))}));
+}
+
+TEST_F(TestArrayImport, ImportRecordBatch) {
+ auto schema = ::arrow::schema(
+ {field("strs", utf8(), /*nullable=*/false), field("ints", uint16())});
+ auto expected_strs = ArrayFromJSON(utf8(), R"(["", "bar", "quux"])");
+ auto expected_ints = ArrayFromJSON(uint16(), "[513, null, 1541]");
+
+ FillStringLike(AddChild(), 3, 0, 1, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 3, -1, 0, primitive_buffers_nulls1_16);
+ FillStructLike(3, 0, 0, 2, buffers_no_nulls_no_data);
+
+ auto expected = RecordBatch::Make(schema, 3, {expected_strs, expected_ints});
+ CheckImport(expected);
+}
+
+TEST_F(TestArrayImport, ImportRecordBatchError) {
+ // Struct with non-zero parent offset
+ FillStringLike(AddChild(), 4, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 4, 0, 0, primitive_buffers_no_nulls1_16);
+ FillStructLike(3, 0, 1, 2, buffers_no_nulls_no_data);
+ auto schema = ::arrow::schema({field("strs", utf8()), field("ints", uint16())});
+ CheckImportError(schema);
+
+ // Struct with nulls in parent
+ FillStringLike(AddChild(), 3, 0, 0, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 3, 0, 0, primitive_buffers_no_nulls1_8);
+ FillStructLike(3, 1, 0, 2, buffers_nulls_no_data1);
+ CheckImportError(schema);
+}
+
+TEST_F(TestArrayImport, ImportArrayAndType) {
+ // Test importing both array and its type at the same time
+ SchemaStructBuilder schema_builder;
+ schema_builder.FillPrimitive("c");
+ SchemaReleaseCallback schema_cb(&schema_builder.c_struct_);
+
+ FillPrimitive(3, 0, 0, primitive_buffers_no_nulls1_8);
+ ArrayReleaseCallback array_cb(&c_struct_);
+
+ ASSERT_OK_AND_ASSIGN(auto array, ImportArray(&c_struct_, &schema_builder.c_struct_));
+ AssertArraysEqual(*array, *ArrayFromJSON(int8(), "[1, 2, 3]"));
+ schema_cb.AssertCalled(); // was released
+ array_cb.AssertNotCalled();
+ ASSERT_TRUE(ArrowArrayIsReleased(&c_struct_)); // was moved
+ array.reset();
+ array_cb.AssertCalled();
+}
+
+TEST_F(TestArrayImport, ImportArrayAndTypeError) {
+ // On error, both structs are released
+ SchemaStructBuilder schema_builder;
+ schema_builder.FillPrimitive("cc");
+ SchemaReleaseCallback schema_cb(&schema_builder.c_struct_);
+
+ FillPrimitive(3, 0, 0, primitive_buffers_no_nulls1_8);
+ ArrayReleaseCallback array_cb(&c_struct_);
+
+ ASSERT_RAISES(Invalid, ImportArray(&c_struct_, &schema_builder.c_struct_));
+ schema_cb.AssertCalled();
+ array_cb.AssertCalled();
+}
+
+TEST_F(TestArrayImport, ImportRecordBatchAndSchema) {
+ // Test importing both record batch and its schema at the same time
+ auto schema = ::arrow::schema({field("strs", utf8()), field("ints", uint16())});
+ auto expected_strs = ArrayFromJSON(utf8(), R"(["", "bar", "quux"])");
+ auto expected_ints = ArrayFromJSON(uint16(), "[513, null, 1541]");
+
+ SchemaStructBuilder schema_builder;
+ schema_builder.FillPrimitive(schema_builder.AddChild(), "u", "strs");
+ schema_builder.FillPrimitive(schema_builder.AddChild(), "S", "ints");
+ schema_builder.FillStructLike("+s", 2);
+ SchemaReleaseCallback schema_cb(&schema_builder.c_struct_);
+
+ FillStringLike(AddChild(), 3, 0, 1, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 3, -1, 0, primitive_buffers_nulls1_16);
+ FillStructLike(3, 0, 0, 2, buffers_no_nulls_no_data);
+ ArrayReleaseCallback array_cb(&c_struct_);
+
+ ASSERT_OK_AND_ASSIGN(auto batch,
+ ImportRecordBatch(&c_struct_, &schema_builder.c_struct_));
+ auto expected = RecordBatch::Make(schema, 3, {expected_strs, expected_ints});
+ AssertBatchesEqual(*batch, *expected);
+ schema_cb.AssertCalled(); // was released
+ array_cb.AssertNotCalled();
+ ASSERT_TRUE(ArrowArrayIsReleased(&c_struct_)); // was moved
+ batch.reset();
+ array_cb.AssertCalled();
+}
+
+TEST_F(TestArrayImport, ImportRecordBatchAndSchemaError) {
+ // On error, both structs are released
+ SchemaStructBuilder schema_builder;
+ schema_builder.FillPrimitive("cc");
+ SchemaReleaseCallback schema_cb(&schema_builder.c_struct_);
+
+ FillStringLike(AddChild(), 3, 0, 1, string_buffers_no_nulls1);
+ FillPrimitive(AddChild(), 3, -1, 0, primitive_buffers_nulls1_8);
+ FillStructLike(3, 0, 0, 2, buffers_no_nulls_no_data);
+ ArrayReleaseCallback array_cb(&c_struct_);
+
+ ASSERT_RAISES(Invalid, ImportRecordBatch(&c_struct_, &schema_builder.c_struct_));
+ schema_cb.AssertCalled();
+ array_cb.AssertCalled();
+}
+
+////////////////////////////////////////////////////////////////////////////
+// C++ -> C -> C++ schema roundtripping tests
+
+class TestSchemaRoundtrip : public ::testing::Test {
+ public:
+ void SetUp() override { pool_ = default_memory_pool(); }
+
+ template <typename TypeFactory, typename ExpectedTypeFactory>
+ void TestWithTypeFactory(TypeFactory&& factory,
+ ExpectedTypeFactory&& factory_expected) {
+ std::shared_ptr<DataType> type, actual;
+ struct ArrowSchema c_schema {}; // zeroed
+ SchemaExportGuard schema_guard(&c_schema);
+
+ auto orig_bytes = pool_->bytes_allocated();
+
+ type = factory();
+ auto type_use_count = type.use_count();
+ ASSERT_OK(ExportType(*type, &c_schema));
+ ASSERT_GT(pool_->bytes_allocated(), orig_bytes);
+ // Export stores no reference to the type
+ ASSERT_EQ(type_use_count, type.use_count());
+ type.reset();
+
+ // Recreate the type
+ ASSERT_OK_AND_ASSIGN(actual, ImportType(&c_schema));
+ type = factory_expected();
+ AssertTypeEqual(*type, *actual);
+ type.reset();
+ actual.reset();
+
+ ASSERT_EQ(pool_->bytes_allocated(), orig_bytes);
+ }
+
+ template <typename TypeFactory>
+ void TestWithTypeFactory(TypeFactory&& factory) {
+ TestWithTypeFactory(factory, factory);
+ }
+
+ template <typename SchemaFactory>
+ void TestWithSchemaFactory(SchemaFactory&& factory) {
+ std::shared_ptr<Schema> schema, actual;
+ struct ArrowSchema c_schema {}; // zeroed
+ SchemaExportGuard schema_guard(&c_schema);
+
+ auto orig_bytes = pool_->bytes_allocated();
+
+ schema = factory();
+ auto schema_use_count = schema.use_count();
+ ASSERT_OK(ExportSchema(*schema, &c_schema));
+ ASSERT_GT(pool_->bytes_allocated(), orig_bytes);
+ // Export stores no reference to the schema
+ ASSERT_EQ(schema_use_count, schema.use_count());
+ schema.reset();
+
+ // Recreate the schema
+ ASSERT_OK_AND_ASSIGN(actual, ImportSchema(&c_schema));
+ schema = factory();
+ AssertSchemaEqual(*schema, *actual);
+ schema.reset();
+ actual.reset();
+
+ ASSERT_EQ(pool_->bytes_allocated(), orig_bytes);
+ }
+
+ protected:
+ MemoryPool* pool_;
+};
+
+TEST_F(TestSchemaRoundtrip, Null) { TestWithTypeFactory(null); }
+
+TEST_F(TestSchemaRoundtrip, Primitive) {
+ TestWithTypeFactory(int32);
+ TestWithTypeFactory(boolean);
+ TestWithTypeFactory(float16);
+
+ TestWithTypeFactory(std::bind(decimal128, 19, 4));
+ TestWithTypeFactory(std::bind(decimal256, 19, 4));
+ TestWithTypeFactory(std::bind(decimal128, 19, 0));
+ TestWithTypeFactory(std::bind(decimal256, 19, 0));
+ TestWithTypeFactory(std::bind(decimal128, 19, -5));
+ TestWithTypeFactory(std::bind(decimal256, 19, -5));
+ TestWithTypeFactory(std::bind(fixed_size_binary, 3));
+ TestWithTypeFactory(binary);
+ TestWithTypeFactory(large_utf8);
+}
+
+TEST_F(TestSchemaRoundtrip, Temporal) {
+ TestWithTypeFactory(date32);
+ TestWithTypeFactory(day_time_interval);
+ TestWithTypeFactory(month_interval);
+ TestWithTypeFactory(month_day_nano_interval);
+ TestWithTypeFactory(std::bind(time64, TimeUnit::NANO));
+ TestWithTypeFactory(std::bind(duration, TimeUnit::MICRO));
+ TestWithTypeFactory([]() { return arrow::timestamp(TimeUnit::MICRO, "Europe/Paris"); });
+}
+
+TEST_F(TestSchemaRoundtrip, List) {
+ TestWithTypeFactory([]() { return list(utf8()); });
+ TestWithTypeFactory([]() { return large_list(list(utf8())); });
+ TestWithTypeFactory([]() { return fixed_size_list(utf8(), 5); });
+ TestWithTypeFactory([]() { return list(fixed_size_list(utf8(), 5)); });
+}
+
+TEST_F(TestSchemaRoundtrip, Struct) {
+ auto f1 = field("f1", utf8(), /*nullable=*/false);
+ auto f2 = field("f2", list(decimal(19, 4)));
+
+ TestWithTypeFactory([&]() { return struct_({f1, f2}); });
+ f2 = f2->WithMetadata(key_value_metadata(kMetadataKeys2, kMetadataValues2));
+ TestWithTypeFactory([&]() { return struct_({f1, f2}); });
+}
+
+TEST_F(TestSchemaRoundtrip, Union) {
+ auto f1 = field("f1", utf8(), /*nullable=*/false);
+ auto f2 = field("f2", list(decimal(19, 4)));
+ auto type_codes = std::vector<int8_t>{42, 43};
+
+ TestWithTypeFactory([&]() { return sparse_union({f1, f2}, type_codes); });
+ f2 = f2->WithMetadata(key_value_metadata(kMetadataKeys2, kMetadataValues2));
+ TestWithTypeFactory([&]() { return dense_union({f1, f2}, type_codes); });
+}
+
+TEST_F(TestSchemaRoundtrip, Dictionary) {
+ for (auto index_ty : all_dictionary_index_types()) {
+ TestWithTypeFactory([&]() { return dictionary(index_ty, utf8()); });
+ TestWithTypeFactory([&]() { return dictionary(index_ty, utf8(), /*ordered=*/true); });
+ TestWithTypeFactory([&]() { return dictionary(index_ty, list(utf8())); });
+ TestWithTypeFactory([&]() { return list(dictionary(index_ty, list(utf8()))); });
+ }
+}
+
+TEST_F(TestSchemaRoundtrip, UnregisteredExtension) {
+ TestWithTypeFactory(uuid, []() { return fixed_size_binary(16); });
+ TestWithTypeFactory(dict_extension_type, []() { return dictionary(int8(), utf8()); });
+
+ // Inside nested type
+ TestWithTypeFactory([]() { return list(dict_extension_type()); },
+ []() { return list(dictionary(int8(), utf8())); });
+}
+
+TEST_F(TestSchemaRoundtrip, RegisteredExtension) {
+ ExtensionTypeGuard guard({uuid(), dict_extension_type(), complex128()});
+ TestWithTypeFactory(uuid);
+ TestWithTypeFactory(dict_extension_type);
+ TestWithTypeFactory(complex128);
+
+ // Inside nested type
+ TestWithTypeFactory([]() { return list(uuid()); });
+ TestWithTypeFactory([]() { return list(dict_extension_type()); });
+ TestWithTypeFactory([]() { return list(complex128()); });
+}
+
+TEST_F(TestSchemaRoundtrip, Map) {
+ TestWithTypeFactory([&]() { return map(utf8(), int32()); });
+ TestWithTypeFactory([&]() { return map(list(utf8()), int32()); });
+ TestWithTypeFactory([&]() { return list(map(list(utf8()), int32())); });
+}
+
+TEST_F(TestSchemaRoundtrip, Schema) {
+ auto f1 = field("f1", utf8(), /*nullable=*/false);
+ auto f2 = field("f2", list(decimal256(19, 4)));
+ auto md1 = key_value_metadata(kMetadataKeys1, kMetadataValues1);
+ auto md2 = key_value_metadata(kMetadataKeys2, kMetadataValues2);
+
+ TestWithSchemaFactory([&]() { return schema({f1, f2}); });
+ f2 = f2->WithMetadata(md2);
+ TestWithSchemaFactory([&]() { return schema({f1, f2}); });
+ TestWithSchemaFactory([&]() { return schema({f1, f2}, md1); });
+}
+
+////////////////////////////////////////////////////////////////////////////
+// C++ -> C -> C++ data roundtripping tests
+
+class TestArrayRoundtrip : public ::testing::Test {
+ public:
+ using ArrayFactory = std::function<Result<std::shared_ptr<Array>>()>;
+
+ void SetUp() override { pool_ = default_memory_pool(); }
+
+ static ArrayFactory JSONArrayFactory(std::shared_ptr<DataType> type, const char* json) {
+ return [=]() { return ArrayFromJSON(type, json); };
+ }
+
+ static ArrayFactory SlicedArrayFactory(ArrayFactory factory) {
+ return [=]() -> Result<std::shared_ptr<Array>> {
+ ARROW_ASSIGN_OR_RAISE(auto arr, factory());
+ DCHECK_GE(arr->length(), 2);
+ return arr->Slice(1, arr->length() - 2);
+ };
+ }
+
+ template <typename ArrayFactory>
+ void TestWithArrayFactory(ArrayFactory&& factory) {
+ TestWithArrayFactory(factory, factory);
+ }
+
+ template <typename ArrayFactory, typename ExpectedArrayFactory>
+ void TestWithArrayFactory(ArrayFactory&& factory,
+ ExpectedArrayFactory&& factory_expected) {
+ std::shared_ptr<Array> array;
+ struct ArrowArray c_array {};
+ struct ArrowSchema c_schema {};
+ ArrayExportGuard array_guard(&c_array);
+ SchemaExportGuard schema_guard(&c_schema);
+
+ auto orig_bytes = pool_->bytes_allocated();
+
+ ASSERT_OK_AND_ASSIGN(array, ToResult(factory()));
+ ASSERT_OK(ExportType(*array->type(), &c_schema));
+ ASSERT_OK(ExportArray(*array, &c_array));
+
+ auto new_bytes = pool_->bytes_allocated();
+ if (array->type_id() != Type::NA) {
+ ASSERT_GT(new_bytes, orig_bytes);
+ }
+
+ array.reset();
+ ASSERT_EQ(pool_->bytes_allocated(), new_bytes);
+ ASSERT_OK_AND_ASSIGN(array, ImportArray(&c_array, &c_schema));
+ ASSERT_OK(array->ValidateFull());
+ ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
+ ASSERT_TRUE(ArrowArrayIsReleased(&c_array));
+
+ // Re-export and re-import, now both at once
+ ASSERT_OK(ExportArray(*array, &c_array, &c_schema));
+ array.reset();
+ ASSERT_OK_AND_ASSIGN(array, ImportArray(&c_array, &c_schema));
+ ASSERT_OK(array->ValidateFull());
+ ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
+ ASSERT_TRUE(ArrowArrayIsReleased(&c_array));
+
+ // Check value of imported array
+ {
+ std::shared_ptr<Array> expected;
+ ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected()));
+ AssertTypeEqual(*expected->type(), *array->type());
+ AssertArraysEqual(*expected, *array, true);
+ }
+ array.reset();
+ ASSERT_EQ(pool_->bytes_allocated(), orig_bytes);
+ }
+
+ template <typename BatchFactory>
+ void TestWithBatchFactory(BatchFactory&& factory) {
+ std::shared_ptr<RecordBatch> batch;
+ struct ArrowArray c_array {};
+ struct ArrowSchema c_schema {};
+ ArrayExportGuard array_guard(&c_array);
+ SchemaExportGuard schema_guard(&c_schema);
+
+ auto orig_bytes = pool_->bytes_allocated();
+ ASSERT_OK_AND_ASSIGN(batch, ToResult(factory()));
+ ASSERT_OK(ExportSchema(*batch->schema(), &c_schema));
+ ASSERT_OK(ExportRecordBatch(*batch, &c_array));
+
+ auto new_bytes = pool_->bytes_allocated();
+ batch.reset();
+ ASSERT_EQ(pool_->bytes_allocated(), new_bytes);
+ ASSERT_OK_AND_ASSIGN(batch, ImportRecordBatch(&c_array, &c_schema));
+ ASSERT_OK(batch->ValidateFull());
+ ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
+ ASSERT_TRUE(ArrowArrayIsReleased(&c_array));
+
+ // Re-export and re-import, now both at once
+ ASSERT_OK(ExportRecordBatch(*batch, &c_array, &c_schema));
+ batch.reset();
+ ASSERT_OK_AND_ASSIGN(batch, ImportRecordBatch(&c_array, &c_schema));
+ ASSERT_OK(batch->ValidateFull());
+ ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
+ ASSERT_TRUE(ArrowArrayIsReleased(&c_array));
+
+ // Check value of imported record batch
+ {
+ std::shared_ptr<RecordBatch> expected;
+ ASSERT_OK_AND_ASSIGN(expected, ToResult(factory()));
+ AssertSchemaEqual(*expected->schema(), *batch->schema());
+ AssertBatchesEqual(*expected, *batch);
+ }
+ batch.reset();
+ ASSERT_EQ(pool_->bytes_allocated(), orig_bytes);
+ }
+
+ void TestWithJSON(std::shared_ptr<DataType> type, const char* json) {
+ TestWithArrayFactory(JSONArrayFactory(type, json));
+ }
+
+ void TestWithJSONSliced(std::shared_ptr<DataType> type, const char* json) {
+ TestWithArrayFactory(SlicedArrayFactory(JSONArrayFactory(type, json)));
+ }
+
+ protected:
+ MemoryPool* pool_;
+};
+
+TEST_F(TestArrayRoundtrip, Null) {
+ TestWithJSON(null(), "[]");
+ TestWithJSON(null(), "[null, null]");
+
+ TestWithJSONSliced(null(), "[null, null]");
+ TestWithJSONSliced(null(), "[null, null, null]");
+}
+
+TEST_F(TestArrayRoundtrip, Primitive) {
+ TestWithJSON(int32(), "[]");
+ TestWithJSON(int32(), "[4, 5, null]");
+
+ TestWithJSON(decimal128(16, 4), R"(["0.4759", "1234.5670", null])");
+ TestWithJSON(decimal256(16, 4), R"(["0.4759", "1234.5670", null])");
+
+ TestWithJSON(month_day_nano_interval(), R"([[1, -600, 5000], null])");
+
+ TestWithJSONSliced(int32(), "[4, 5]");
+ TestWithJSONSliced(int32(), "[4, 5, 6, null]");
+ TestWithJSONSliced(decimal128(16, 4), R"(["0.4759", "1234.5670", null])");
+ TestWithJSONSliced(decimal256(16, 4), R"(["0.4759", "1234.5670", null])");
+ TestWithJSONSliced(month_day_nano_interval(),
+ R"([[4, 5, 6], [1, -600, 5000], null, null])");
+}
+
+TEST_F(TestArrayRoundtrip, UnknownNullCount) {
+ TestWithArrayFactory([]() -> Result<std::shared_ptr<Array>> {
+ auto arr = ArrayFromJSON(int32(), "[0, 1, 2]");
+ if (arr->null_bitmap()) {
+ return Status::Invalid(
+ "Failed precondition: "
+ "the array shouldn't have a null bitmap.");
+ }
+ arr->data()->SetNullCount(kUnknownNullCount);
+ return arr;
+ });
+}
+
+TEST_F(TestArrayRoundtrip, List) {
+ TestWithJSON(list(int32()), "[]");
+ TestWithJSON(list(int32()), "[[4, 5], [6, null], null]");
+
+ TestWithJSONSliced(list(int32()), "[[4, 5], [6, null], null]");
+}
+
+TEST_F(TestArrayRoundtrip, Struct) {
+ auto type = struct_({field("ints", int16()), field("bools", boolean())});
+ TestWithJSON(type, "[]");
+ TestWithJSON(type, "[[4, true], [5, false]]");
+ TestWithJSON(type, "[[4, null], null, [5, false]]");
+
+ TestWithJSONSliced(type, "[[4, null], null, [5, false]]");
+
+ // With nullable = false and metadata
+ auto f0 = field("ints", int16(), /*nullable=*/false);
+ auto f1 = field("bools", boolean(), /*nullable=*/true,
+ key_value_metadata(kMetadataKeys1, kMetadataValues1));
+ type = struct_({f0, f1});
+ TestWithJSON(type, "[]");
+ TestWithJSON(type, "[[4, true], [5, null]]");
+
+ TestWithJSONSliced(type, "[[4, true], [5, null], [6, false]]");
+}
+
+TEST_F(TestArrayRoundtrip, Map) {
+ // Map type
+ auto type = map(utf8(), int32());
+ const char* json = R"([[["foo", 123], ["bar", -456]], null,
+ [["foo", null]], []])";
+ TestWithJSON(type, json);
+ TestWithJSONSliced(type, json);
+
+ type = map(utf8(), int32(), /*keys_sorted=*/true);
+ TestWithJSON(type, json);
+ TestWithJSONSliced(type, json);
+}
+
+TEST_F(TestArrayRoundtrip, Union) {
+ FieldVector fields = {field("strs", utf8()), field("ints", int8())};
+ std::vector<int8_t> type_codes = {43, 42};
+ DataTypeVector union_types = {sparse_union(fields, type_codes),
+ dense_union(fields, type_codes)};
+ const char* json = R"([[42, 1], [42, null], [43, "foo"], [43, ""], [42, 3]])";
+
+ for (const auto& type : union_types) {
+ TestWithJSON(type, "[]");
+ TestWithJSON(type, json);
+ TestWithJSONSliced(type, json);
+ }
+}
+
+TEST_F(TestArrayRoundtrip, Dictionary) {
+ {
+ auto factory = []() {
+ auto values = ArrayFromJSON(utf8(), R"(["foo", "bar", "quux"])");
+ auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]");
+ return DictionaryArray::FromArrays(dictionary(indices->type(), values->type()),
+ indices, values);
+ };
+ TestWithArrayFactory(factory);
+ TestWithArrayFactory(SlicedArrayFactory(factory));
+ }
+ {
+ auto factory = []() {
+ auto values = ArrayFromJSON(list(utf8()), R"([["abc", "def"], ["efg"], []])");
+ auto indices = ArrayFromJSON(int32(), "[0, 2, 1, null, 1]");
+ return DictionaryArray::FromArrays(
+ dictionary(indices->type(), values->type(), /*ordered=*/true), indices, values);
+ };
+ TestWithArrayFactory(factory);
+ TestWithArrayFactory(SlicedArrayFactory(factory));
+ }
+}
+
+TEST_F(TestArrayRoundtrip, RegisteredExtension) {
+ ExtensionTypeGuard guard({smallint(), complex128(), dict_extension_type(), uuid()});
+
+ TestWithArrayFactory(ExampleSmallint);
+ TestWithArrayFactory(ExampleUuid);
+ TestWithArrayFactory(ExampleComplex128);
+ TestWithArrayFactory(ExampleDictExtension);
+
+ // Nested inside outer array
+ auto NestedFactory = [](ArrayFactory factory) {
+ return [factory]() -> Result<std::shared_ptr<Array>> {
+ ARROW_ASSIGN_OR_RAISE(auto arr, ToResult(factory()));
+ return FixedSizeListArray::FromArrays(arr, /*list_size=*/1);
+ };
+ };
+ TestWithArrayFactory(NestedFactory(ExampleSmallint));
+ TestWithArrayFactory(NestedFactory(ExampleUuid));
+ TestWithArrayFactory(NestedFactory(ExampleComplex128));
+ TestWithArrayFactory(NestedFactory(ExampleDictExtension));
+}
+
+TEST_F(TestArrayRoundtrip, UnregisteredExtension) {
+ auto StorageExtractor = [](ArrayFactory factory) {
+ return [factory]() -> Result<std::shared_ptr<Array>> {
+ ARROW_ASSIGN_OR_RAISE(auto arr, ToResult(factory()));
+ return checked_cast<const ExtensionArray&>(*arr).storage();
+ };
+ };
+
+ TestWithArrayFactory(ExampleSmallint, StorageExtractor(ExampleSmallint));
+ TestWithArrayFactory(ExampleUuid, StorageExtractor(ExampleUuid));
+ TestWithArrayFactory(ExampleComplex128, StorageExtractor(ExampleComplex128));
+ TestWithArrayFactory(ExampleDictExtension, StorageExtractor(ExampleDictExtension));
+}
+
+TEST_F(TestArrayRoundtrip, RecordBatch) {
+ auto schema = ::arrow::schema(
+ {field("ints", int16()), field("bools", boolean(), /*nullable=*/false)});
+ auto arr0 = ArrayFromJSON(int16(), "[1, 2, null]");
+ auto arr1 = ArrayFromJSON(boolean(), "[false, true, false]");
+
+ {
+ auto factory = [&]() { return RecordBatch::Make(schema, 3, {arr0, arr1}); };
+ TestWithBatchFactory(factory);
+ }
+ {
+ // With schema and field metadata
+ auto factory = [&]() {
+ auto f0 = schema->field(0);
+ auto f1 = schema->field(1);
+ f1 = f1->WithMetadata(key_value_metadata(kMetadataKeys1, kMetadataValues1));
+ auto schema_with_md =
+ ::arrow::schema({f0, f1}, key_value_metadata(kMetadataKeys2, kMetadataValues2));
+ return RecordBatch::Make(schema_with_md, 3, {arr0, arr1});
+ };
+ TestWithBatchFactory(factory);
+ }
+}
+
+// TODO C -> C++ -> C roundtripping tests?
+
+////////////////////////////////////////////////////////////////////////////
+// Array stream export tests
+
+class FailingRecordBatchReader : public RecordBatchReader {
+ public:
+ explicit FailingRecordBatchReader(Status error) : error_(std::move(error)) {}
+
+ static std::shared_ptr<Schema> expected_schema() { return arrow::schema({}); }
+
+ std::shared_ptr<Schema> schema() const override { return expected_schema(); }
+
+ Status ReadNext(std::shared_ptr<RecordBatch>* batch) override { return error_; }
+
+ protected:
+ Status error_;
+};
+
+class BaseArrayStreamTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ pool_ = default_memory_pool();
+ orig_allocated_ = pool_->bytes_allocated();
+ }
+
+ void TearDown() override { ASSERT_EQ(pool_->bytes_allocated(), orig_allocated_); }
+
+ RecordBatchVector MakeBatches(std::shared_ptr<Schema> schema, ArrayVector arrays) {
+ DCHECK_EQ(schema->num_fields(), 1);
+ RecordBatchVector batches;
+ for (const auto& array : arrays) {
+ batches.push_back(RecordBatch::Make(schema, array->length(), {array}));
+ }
+ return batches;
+ }
+
+ protected:
+ MemoryPool* pool_;
+ int64_t orig_allocated_;
+};
+
+class TestArrayStreamExport : public BaseArrayStreamTest {
+ public:
+ void AssertStreamSchema(struct ArrowArrayStream* c_stream, const Schema& expected) {
+ struct ArrowSchema c_schema;
+ ASSERT_EQ(0, c_stream->get_schema(c_stream, &c_schema));
+
+ SchemaExportGuard schema_guard(&c_schema);
+ ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
+ ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
+ AssertSchemaEqual(expected, *schema);
+ }
+
+ void AssertStreamEnd(struct ArrowArrayStream* c_stream) {
+ struct ArrowArray c_array;
+ ASSERT_EQ(0, c_stream->get_next(c_stream, &c_array));
+
+ ArrayExportGuard guard(&c_array);
+ ASSERT_TRUE(ArrowArrayIsReleased(&c_array));
+ }
+
+ void AssertStreamNext(struct ArrowArrayStream* c_stream, const RecordBatch& expected) {
+ struct ArrowArray c_array;
+ ASSERT_EQ(0, c_stream->get_next(c_stream, &c_array));
+
+ ArrayExportGuard guard(&c_array);
+ ASSERT_FALSE(ArrowArrayIsReleased(&c_array));
+
+ ASSERT_OK_AND_ASSIGN(auto batch, ImportRecordBatch(&c_array, expected.schema()));
+ AssertBatchesEqual(expected, *batch);
+ }
+};
+
+TEST_F(TestArrayStreamExport, Empty) {
+ auto schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(schema, {});
+ ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make(batches, schema));
+
+ struct ArrowArrayStream c_stream;
+
+ ASSERT_OK(ExportRecordBatchReader(reader, &c_stream));
+ ArrayStreamExportGuard guard(&c_stream);
+
+ ASSERT_FALSE(ArrowArrayStreamIsReleased(&c_stream));
+ AssertStreamSchema(&c_stream, *schema);
+ AssertStreamEnd(&c_stream);
+ AssertStreamEnd(&c_stream);
+}
+
+TEST_F(TestArrayStreamExport, Simple) {
+ auto schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(
+ schema, {ArrayFromJSON(int32(), "[1, 2]"), ArrayFromJSON(int32(), "[4, 5, null]")});
+ ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make(batches, schema));
+
+ struct ArrowArrayStream c_stream;
+
+ ASSERT_OK(ExportRecordBatchReader(reader, &c_stream));
+ ArrayStreamExportGuard guard(&c_stream);
+
+ ASSERT_FALSE(ArrowArrayStreamIsReleased(&c_stream));
+ AssertStreamSchema(&c_stream, *schema);
+ AssertStreamNext(&c_stream, *batches[0]);
+ AssertStreamNext(&c_stream, *batches[1]);
+ AssertStreamEnd(&c_stream);
+ AssertStreamEnd(&c_stream);
+}
+
+TEST_F(TestArrayStreamExport, ArrayLifetime) {
+ auto schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(
+ schema, {ArrayFromJSON(int32(), "[1, 2]"), ArrayFromJSON(int32(), "[4, 5, null]")});
+ ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make(batches, schema));
+
+ struct ArrowArrayStream c_stream;
+ struct ArrowSchema c_schema;
+ struct ArrowArray c_array0, c_array1;
+
+ ASSERT_OK(ExportRecordBatchReader(reader, &c_stream));
+ {
+ ArrayStreamExportGuard guard(&c_stream);
+ ASSERT_FALSE(ArrowArrayStreamIsReleased(&c_stream));
+
+ ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema));
+ ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array0));
+ ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array1));
+ AssertStreamEnd(&c_stream);
+ }
+
+ ArrayExportGuard guard0(&c_array0), guard1(&c_array1);
+
+ {
+ SchemaExportGuard schema_guard(&c_schema);
+ ASSERT_OK_AND_ASSIGN(auto got_schema, ImportSchema(&c_schema));
+ AssertSchemaEqual(*schema, *got_schema);
+ }
+
+ ASSERT_GT(pool_->bytes_allocated(), orig_allocated_);
+ ASSERT_OK_AND_ASSIGN(auto batch, ImportRecordBatch(&c_array1, schema));
+ AssertBatchesEqual(*batches[1], *batch);
+ ASSERT_OK_AND_ASSIGN(batch, ImportRecordBatch(&c_array0, schema));
+ AssertBatchesEqual(*batches[0], *batch);
+}
+
+TEST_F(TestArrayStreamExport, Errors) {
+ auto reader =
+ std::make_shared<FailingRecordBatchReader>(Status::Invalid("some example error"));
+
+ struct ArrowArrayStream c_stream;
+
+ ASSERT_OK(ExportRecordBatchReader(reader, &c_stream));
+ ArrayStreamExportGuard guard(&c_stream);
+
+ struct ArrowSchema c_schema;
+ ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema));
+ ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
+ {
+ SchemaExportGuard schema_guard(&c_schema);
+ ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
+ AssertSchemaEqual(schema, arrow::schema({}));
+ }
+
+ struct ArrowArray c_array;
+ ASSERT_EQ(EINVAL, c_stream.get_next(&c_stream, &c_array));
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Array stream roundtrip tests
+
+class TestArrayStreamRoundtrip : public BaseArrayStreamTest {
+ public:
+ void Roundtrip(std::shared_ptr<RecordBatchReader>* reader,
+ struct ArrowArrayStream* c_stream) {
+ ASSERT_OK(ExportRecordBatchReader(*reader, c_stream));
+ ASSERT_FALSE(ArrowArrayStreamIsReleased(c_stream));
+
+ ASSERT_OK_AND_ASSIGN(auto got_reader, ImportRecordBatchReader(c_stream));
+ *reader = std::move(got_reader);
+ }
+
+ void Roundtrip(
+ std::shared_ptr<RecordBatchReader> reader,
+ std::function<void(const std::shared_ptr<RecordBatchReader>&)> check_func) {
+ ArrowArrayStream c_stream;
+
+ // NOTE: ReleaseCallback<> is not immediately usable with ArrowArrayStream,
+ // because get_next and get_schema need the original private_data.
+ std::weak_ptr<RecordBatchReader> weak_reader(reader);
+ ASSERT_EQ(weak_reader.use_count(), 1); // Expiration check will fail otherwise
+
+ ASSERT_OK(ExportRecordBatchReader(std::move(reader), &c_stream));
+ ASSERT_FALSE(ArrowArrayStreamIsReleased(&c_stream));
+
+ {
+ ASSERT_OK_AND_ASSIGN(auto new_reader, ImportRecordBatchReader(&c_stream));
+ // Stream was moved
+ ASSERT_TRUE(ArrowArrayStreamIsReleased(&c_stream));
+ ASSERT_FALSE(weak_reader.expired());
+
+ check_func(new_reader);
+ }
+ // Stream was released when `new_reader` was destroyed
+ ASSERT_TRUE(weak_reader.expired());
+ }
+
+ void AssertReaderNext(const std::shared_ptr<RecordBatchReader>& reader,
+ const RecordBatch& expected) {
+ ASSERT_OK_AND_ASSIGN(auto batch, reader->Next());
+ ASSERT_NE(batch, nullptr);
+ AssertBatchesEqual(expected, *batch);
+ }
+
+ void AssertReaderEnd(const std::shared_ptr<RecordBatchReader>& reader) {
+ ASSERT_OK_AND_ASSIGN(auto batch, reader->Next());
+ ASSERT_EQ(batch, nullptr);
+ }
+};
+
+TEST_F(TestArrayStreamRoundtrip, Simple) {
+ auto orig_schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(orig_schema, {ArrayFromJSON(int32(), "[1, 2]"),
+ ArrayFromJSON(int32(), "[4, 5, null]")});
+
+ ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make(batches, orig_schema));
+
+ Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>& reader) {
+ AssertSchemaEqual(*orig_schema, *reader->schema());
+ AssertReaderNext(reader, *batches[0]);
+ AssertReaderNext(reader, *batches[1]);
+ AssertReaderEnd(reader);
+ AssertReaderEnd(reader);
+ });
+}
+
+TEST_F(TestArrayStreamRoundtrip, Errors) {
+ auto reader = std::make_shared<FailingRecordBatchReader>(
+ Status::Invalid("roundtrip error example"));
+
+ Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>& reader) {
+ auto status = reader->Next().status();
+ ASSERT_RAISES(Invalid, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("roundtrip error example"));
+ });
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/c/helpers.h b/src/arrow/cpp/src/arrow/c/helpers.h
new file mode 100644
index 000000000..a5c1f6fe4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/c/helpers.h
@@ -0,0 +1,117 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <assert.h>
+#include <string.h>
+
+#include "arrow/c/abi.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// Query whether the C schema is released
+inline int ArrowSchemaIsReleased(const struct ArrowSchema* schema) {
+ return schema->release == NULL;
+}
+
+/// Mark the C schema released (for use in release callbacks)
+inline void ArrowSchemaMarkReleased(struct ArrowSchema* schema) {
+ schema->release = NULL;
+}
+
+/// Move the C schema from `src` to `dest`
+///
+/// Note `dest` must *not* point to a valid schema already, otherwise there
+/// will be a memory leak.
+inline void ArrowSchemaMove(struct ArrowSchema* src, struct ArrowSchema* dest) {
+ assert(dest != src);
+ assert(!ArrowSchemaIsReleased(src));
+ memcpy(dest, src, sizeof(struct ArrowSchema));
+ ArrowSchemaMarkReleased(src);
+}
+
+/// Release the C schema, if necessary, by calling its release callback
+inline void ArrowSchemaRelease(struct ArrowSchema* schema) {
+ if (!ArrowSchemaIsReleased(schema)) {
+ schema->release(schema);
+ assert(ArrowSchemaIsReleased(schema));
+ }
+}
+
+/// Query whether the C array is released
+inline int ArrowArrayIsReleased(const struct ArrowArray* array) {
+ return array->release == NULL;
+}
+
+/// Mark the C array released (for use in release callbacks)
+inline void ArrowArrayMarkReleased(struct ArrowArray* array) { array->release = NULL; }
+
+/// Move the C array from `src` to `dest`
+///
+/// Note `dest` must *not* point to a valid array already, otherwise there
+/// will be a memory leak.
+inline void ArrowArrayMove(struct ArrowArray* src, struct ArrowArray* dest) {
+ assert(dest != src);
+ assert(!ArrowArrayIsReleased(src));
+ memcpy(dest, src, sizeof(struct ArrowArray));
+ ArrowArrayMarkReleased(src);
+}
+
+/// Release the C array, if necessary, by calling its release callback
+inline void ArrowArrayRelease(struct ArrowArray* array) {
+ if (!ArrowArrayIsReleased(array)) {
+ array->release(array);
+ assert(ArrowArrayIsReleased(array));
+ }
+}
+
+/// Query whether the C array stream is released
+inline int ArrowArrayStreamIsReleased(const struct ArrowArrayStream* stream) {
+ return stream->release == NULL;
+}
+
+/// Mark the C array stream released (for use in release callbacks)
+inline void ArrowArrayStreamMarkReleased(struct ArrowArrayStream* stream) {
+ stream->release = NULL;
+}
+
+/// Move the C array stream from `src` to `dest`
+///
+/// Note `dest` must *not* point to a valid stream already, otherwise there
+/// will be a memory leak.
+inline void ArrowArrayStreamMove(struct ArrowArrayStream* src,
+ struct ArrowArrayStream* dest) {
+ assert(dest != src);
+ assert(!ArrowArrayStreamIsReleased(src));
+ memcpy(dest, src, sizeof(struct ArrowArrayStream));
+ ArrowArrayStreamMarkReleased(src);
+}
+
+/// Release the C array stream, if necessary, by calling its release callback
+inline void ArrowArrayStreamRelease(struct ArrowArrayStream* stream) {
+ if (!ArrowArrayStreamIsReleased(stream)) {
+ stream->release(stream);
+ assert(ArrowArrayStreamIsReleased(stream));
+ }
+}
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/src/arrow/cpp/src/arrow/c/util_internal.h b/src/arrow/cpp/src/arrow/c/util_internal.h
new file mode 100644
index 000000000..6a33be9b0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/c/util_internal.h
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/c/helpers.h"
+
+namespace arrow {
+namespace internal {
+
+struct SchemaExportTraits {
+ typedef struct ArrowSchema CType;
+ static constexpr auto IsReleasedFunc = &ArrowSchemaIsReleased;
+ static constexpr auto ReleaseFunc = &ArrowSchemaRelease;
+};
+
+struct ArrayExportTraits {
+ typedef struct ArrowArray CType;
+ static constexpr auto IsReleasedFunc = &ArrowArrayIsReleased;
+ static constexpr auto ReleaseFunc = &ArrowArrayRelease;
+};
+
+struct ArrayStreamExportTraits {
+ typedef struct ArrowArrayStream CType;
+ static constexpr auto IsReleasedFunc = &ArrowArrayStreamIsReleased;
+ static constexpr auto ReleaseFunc = &ArrowArrayStreamRelease;
+};
+
+// A RAII-style object to release a C Array / Schema struct at block scope exit.
+template <typename Traits>
+class ExportGuard {
+ public:
+ using CType = typename Traits::CType;
+
+ explicit ExportGuard(CType* c_export) : c_export_(c_export) {}
+
+ ExportGuard(ExportGuard&& other) : c_export_(other.c_export_) {
+ other.c_export_ = nullptr;
+ }
+
+ ExportGuard& operator=(ExportGuard&& other) {
+ Release();
+ c_export_ = other.c_export_;
+ other.c_export_ = nullptr;
+ }
+
+ ~ExportGuard() { Release(); }
+
+ void Detach() { c_export_ = nullptr; }
+
+ void Reset(CType* c_export) { c_export_ = c_export; }
+
+ void Release() {
+ if (c_export_) {
+ Traits::ReleaseFunc(c_export_);
+ c_export_ = nullptr;
+ }
+ }
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ExportGuard);
+
+ CType* c_export_;
+};
+
+using SchemaExportGuard = ExportGuard<SchemaExportTraits>;
+using ArrayExportGuard = ExportGuard<ArrayExportTraits>;
+using ArrayStreamExportGuard = ExportGuard<ArrayStreamExportTraits>;
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/chunked_array.cc b/src/arrow/cpp/src/arrow/chunked_array.cc
new file mode 100644
index 000000000..0c954e72e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/chunked_array.cc
@@ -0,0 +1,304 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/chunked_array.h"
+
+#include <algorithm>
+#include <cstdlib>
+#include <memory>
+#include <sstream>
+#include <utility>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/array/validate.h"
+#include "arrow/pretty_print.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+class MemoryPool;
+
+// ----------------------------------------------------------------------
+// ChunkedArray methods
+
+ChunkedArray::ChunkedArray(ArrayVector chunks) : chunks_(std::move(chunks)) {
+ length_ = 0;
+ null_count_ = 0;
+
+ ARROW_CHECK_GT(chunks_.size(), 0)
+ << "cannot construct ChunkedArray from empty vector and omitted type";
+ type_ = chunks_[0]->type();
+ for (const std::shared_ptr<Array>& chunk : chunks_) {
+ length_ += chunk->length();
+ null_count_ += chunk->null_count();
+ }
+}
+
+ChunkedArray::ChunkedArray(ArrayVector chunks, std::shared_ptr<DataType> type)
+ : chunks_(std::move(chunks)), type_(std::move(type)) {
+ length_ = 0;
+ null_count_ = 0;
+ for (const std::shared_ptr<Array>& chunk : chunks_) {
+ length_ += chunk->length();
+ null_count_ += chunk->null_count();
+ }
+}
+
+Result<std::shared_ptr<ChunkedArray>> ChunkedArray::Make(ArrayVector chunks,
+ std::shared_ptr<DataType> type) {
+ if (type == nullptr) {
+ if (chunks.size() == 0) {
+ return Status::Invalid(
+ "cannot construct ChunkedArray from empty vector "
+ "and omitted type");
+ }
+ type = chunks[0]->type();
+ }
+ for (size_t i = 0; i < chunks.size(); ++i) {
+ if (!chunks[i]->type()->Equals(*type)) {
+ return Status::Invalid("Array chunks must all be same type");
+ }
+ }
+ return std::make_shared<ChunkedArray>(std::move(chunks), std::move(type));
+}
+
+bool ChunkedArray::Equals(const ChunkedArray& other) const {
+ if (length_ != other.length()) {
+ return false;
+ }
+ if (null_count_ != other.null_count()) {
+ return false;
+ }
+ // We cannot toggle check_metadata here yet, so we don't check it
+ if (!type_->Equals(*other.type_, /*check_metadata=*/false)) {
+ return false;
+ }
+
+ // Check contents of the underlying arrays. This checks for equality of
+ // the underlying data independently of the chunk size.
+ return internal::ApplyBinaryChunked(
+ *this, other,
+ [](const Array& left_piece, const Array& right_piece,
+ int64_t ARROW_ARG_UNUSED(position)) {
+ if (!left_piece.Equals(right_piece)) {
+ return Status::Invalid("Unequal piece");
+ }
+ return Status::OK();
+ })
+ .ok();
+}
+
+bool ChunkedArray::Equals(const std::shared_ptr<ChunkedArray>& other) const {
+ if (this == other.get()) {
+ return true;
+ }
+ if (!other) {
+ return false;
+ }
+ return Equals(*other.get());
+}
+
+bool ChunkedArray::ApproxEquals(const ChunkedArray& other,
+ const EqualOptions& equal_options) const {
+ if (length_ != other.length()) {
+ return false;
+ }
+ if (null_count_ != other.null_count()) {
+ return false;
+ }
+ // We cannot toggle check_metadata here yet, so we don't check it
+ if (!type_->Equals(*other.type_, /*check_metadata=*/false)) {
+ return false;
+ }
+
+ // Check contents of the underlying arrays. This checks for equality of
+ // the underlying data independently of the chunk size.
+ return internal::ApplyBinaryChunked(
+ *this, other,
+ [&](const Array& left_piece, const Array& right_piece,
+ int64_t ARROW_ARG_UNUSED(position)) {
+ if (!left_piece.ApproxEquals(right_piece, equal_options)) {
+ return Status::Invalid("Unequal piece");
+ }
+ return Status::OK();
+ })
+ .ok();
+}
+
+Result<std::shared_ptr<Scalar>> ChunkedArray::GetScalar(int64_t index) const {
+ for (const auto& chunk : chunks_) {
+ if (index < chunk->length()) {
+ return chunk->GetScalar(index);
+ }
+ index -= chunk->length();
+ }
+ return Status::Invalid("index out of bounds");
+}
+
+std::shared_ptr<ChunkedArray> ChunkedArray::Slice(int64_t offset, int64_t length) const {
+ ARROW_CHECK_LE(offset, length_) << "Slice offset greater than array length";
+ bool offset_equals_length = offset == length_;
+ int curr_chunk = 0;
+ while (curr_chunk < num_chunks() && offset >= chunk(curr_chunk)->length()) {
+ offset -= chunk(curr_chunk)->length();
+ curr_chunk++;
+ }
+
+ ArrayVector new_chunks;
+ if (num_chunks() > 0 && (offset_equals_length || length == 0)) {
+ // Special case the zero-length slice to make sure there is at least 1 Array
+ // in the result. When there are zero chunks we return zero chunks
+ new_chunks.push_back(chunk(std::min(curr_chunk, num_chunks() - 1))->Slice(0, 0));
+ } else {
+ while (curr_chunk < num_chunks() && length > 0) {
+ new_chunks.push_back(chunk(curr_chunk)->Slice(offset, length));
+ length -= chunk(curr_chunk)->length() - offset;
+ offset = 0;
+ curr_chunk++;
+ }
+ }
+
+ return std::make_shared<ChunkedArray>(new_chunks, type_);
+}
+
+std::shared_ptr<ChunkedArray> ChunkedArray::Slice(int64_t offset) const {
+ return Slice(offset, length_);
+}
+
+Result<std::vector<std::shared_ptr<ChunkedArray>>> ChunkedArray::Flatten(
+ MemoryPool* pool) const {
+ if (type()->id() != Type::STRUCT) {
+ // Emulate nonexistent copy constructor
+ return std::vector<std::shared_ptr<ChunkedArray>>{
+ std::make_shared<ChunkedArray>(chunks_, type_)};
+ }
+
+ std::vector<ArrayVector> flattened_chunks(type()->num_fields());
+ for (const auto& chunk : chunks_) {
+ ARROW_ASSIGN_OR_RAISE(auto arrays,
+ checked_cast<const StructArray&>(*chunk).Flatten(pool));
+ DCHECK_EQ(arrays.size(), flattened_chunks.size());
+ for (size_t i = 0; i < arrays.size(); ++i) {
+ flattened_chunks[i].push_back(arrays[i]);
+ }
+ }
+
+ std::vector<std::shared_ptr<ChunkedArray>> flattened(type()->num_fields());
+ for (size_t i = 0; i < flattened.size(); ++i) {
+ auto child_type = type()->field(static_cast<int>(i))->type();
+ flattened[i] =
+ std::make_shared<ChunkedArray>(std::move(flattened_chunks[i]), child_type);
+ }
+ return flattened;
+}
+
+Result<std::shared_ptr<ChunkedArray>> ChunkedArray::View(
+ const std::shared_ptr<DataType>& type) const {
+ ArrayVector out_chunks(this->num_chunks());
+ for (int i = 0; i < this->num_chunks(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(out_chunks[i], chunks_[i]->View(type));
+ }
+ return std::make_shared<ChunkedArray>(out_chunks, type);
+}
+
+std::string ChunkedArray::ToString() const {
+ std::stringstream ss;
+ ARROW_CHECK_OK(PrettyPrint(*this, 0, &ss));
+ return ss.str();
+}
+
+Status ChunkedArray::Validate() const {
+ if (chunks_.size() == 0) {
+ return Status::OK();
+ }
+
+ const auto& type = *chunks_[0]->type();
+ // Make sure chunks all have the same type
+ for (size_t i = 1; i < chunks_.size(); ++i) {
+ const Array& chunk = *chunks_[i];
+ if (!chunk.type()->Equals(type)) {
+ return Status::Invalid("In chunk ", i, " expected type ", type.ToString(),
+ " but saw ", chunk.type()->ToString());
+ }
+ }
+ // Validate the chunks themselves
+ for (size_t i = 0; i < chunks_.size(); ++i) {
+ const Array& chunk = *chunks_[i];
+ const Status st = internal::ValidateArray(chunk);
+ if (!st.ok()) {
+ return Status::Invalid("In chunk ", i, ": ", st.ToString());
+ }
+ }
+ return Status::OK();
+}
+
+Status ChunkedArray::ValidateFull() const {
+ RETURN_NOT_OK(Validate());
+ for (size_t i = 0; i < chunks_.size(); ++i) {
+ const Array& chunk = *chunks_[i];
+ const Status st = internal::ValidateArrayFull(chunk);
+ if (!st.ok()) {
+ return Status::Invalid("In chunk ", i, ": ", st.ToString());
+ }
+ }
+ return Status::OK();
+}
+
+namespace internal {
+
+bool MultipleChunkIterator::Next(std::shared_ptr<Array>* next_left,
+ std::shared_ptr<Array>* next_right) {
+ if (pos_ == length_) return false;
+
+ // Find non-empty chunk
+ std::shared_ptr<Array> chunk_left, chunk_right;
+ while (true) {
+ chunk_left = left_.chunk(chunk_idx_left_);
+ chunk_right = right_.chunk(chunk_idx_right_);
+ if (chunk_pos_left_ == chunk_left->length()) {
+ chunk_pos_left_ = 0;
+ ++chunk_idx_left_;
+ continue;
+ }
+ if (chunk_pos_right_ == chunk_right->length()) {
+ chunk_pos_right_ = 0;
+ ++chunk_idx_right_;
+ continue;
+ }
+ break;
+ }
+ // Determine how big of a section to return
+ int64_t iteration_size = std::min(chunk_left->length() - chunk_pos_left_,
+ chunk_right->length() - chunk_pos_right_);
+
+ *next_left = chunk_left->Slice(chunk_pos_left_, iteration_size);
+ *next_right = chunk_right->Slice(chunk_pos_right_, iteration_size);
+
+ pos_ += iteration_size;
+ chunk_pos_left_ += iteration_size;
+ chunk_pos_right_ += iteration_size;
+ return true;
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/chunked_array.h b/src/arrow/cpp/src/arrow/chunked_array.h
new file mode 100644
index 000000000..0bf0c66c1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/chunked_array.h
@@ -0,0 +1,255 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/compare.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Array;
+class DataType;
+class MemoryPool;
+
+/// \class ChunkedArray
+/// \brief A data structure managing a list of primitive Arrow arrays logically
+/// as one large array
+///
+/// Data chunking is treated throughout this project largely as an
+/// implementation detail for performance and memory use optimization.
+/// ChunkedArray allows Array objects to be collected and interpreted
+/// as a single logical array without requiring an expensive concatenation
+/// step.
+///
+/// In some cases, data produced by a function may exceed the capacity of an
+/// Array (like BinaryArray or StringArray) and so returning multiple Arrays is
+/// the only possibility. In these cases, we recommend returning a ChunkedArray
+/// instead of vector of Arrays or some alternative.
+///
+/// When data is processed in parallel, it may not be practical or possible to
+/// create large contiguous memory allocations and write output into them. With
+/// some data types, like binary and string types, it is not possible at all to
+/// produce non-chunked array outputs without requiring a concatenation step at
+/// the end of processing.
+///
+/// Application developers may tune chunk sizes based on analysis of
+/// performance profiles but many developer-users will not need to be
+/// especially concerned with the chunking details.
+///
+/// Preserving the chunk layout/sizes in processing steps is generally not
+/// considered to be a contract in APIs. A function may decide to alter the
+/// chunking of its result. Similarly, APIs accepting multiple ChunkedArray
+/// inputs should not expect the chunk layout to be the same in each input.
+class ARROW_EXPORT ChunkedArray {
+ public:
+ /// \brief Construct a chunked array from a vector of arrays
+ ///
+ /// The vector must be non-empty and all its elements must have the same
+ /// data type.
+ explicit ChunkedArray(ArrayVector chunks);
+
+ ChunkedArray(ChunkedArray&&) = default;
+ ChunkedArray& operator=(ChunkedArray&&) = default;
+
+ /// \brief Construct a chunked array from a single Array
+ explicit ChunkedArray(std::shared_ptr<Array> chunk)
+ : ChunkedArray(ArrayVector{std::move(chunk)}) {}
+
+ /// \brief Construct a chunked array from a vector of arrays and a data type
+ ///
+ /// As the data type is passed explicitly, the vector may be empty.
+ ChunkedArray(ArrayVector chunks, std::shared_ptr<DataType> type);
+
+ // \brief Constructor with basic input validation.
+ static Result<std::shared_ptr<ChunkedArray>> Make(
+ ArrayVector chunks, std::shared_ptr<DataType> type = NULLPTR);
+
+ /// \return the total length of the chunked array; computed on construction
+ int64_t length() const { return length_; }
+
+ /// \return the total number of nulls among all chunks
+ int64_t null_count() const { return null_count_; }
+
+ int num_chunks() const { return static_cast<int>(chunks_.size()); }
+
+ /// \return chunk a particular chunk from the chunked array
+ std::shared_ptr<Array> chunk(int i) const { return chunks_[i]; }
+
+ const ArrayVector& chunks() const { return chunks_; }
+
+ /// \brief Construct a zero-copy slice of the chunked array with the
+ /// indicated offset and length
+ ///
+ /// \param[in] offset the position of the first element in the constructed
+ /// slice
+ /// \param[in] length the length of the slice. If there are not enough
+ /// elements in the chunked array, the length will be adjusted accordingly
+ ///
+ /// \return a new object wrapped in std::shared_ptr<ChunkedArray>
+ std::shared_ptr<ChunkedArray> Slice(int64_t offset, int64_t length) const;
+
+ /// \brief Slice from offset until end of the chunked array
+ std::shared_ptr<ChunkedArray> Slice(int64_t offset) const;
+
+ /// \brief Flatten this chunked array as a vector of chunked arrays, one
+ /// for each struct field
+ ///
+ /// \param[in] pool The pool for buffer allocations, if any
+ Result<std::vector<std::shared_ptr<ChunkedArray>>> Flatten(
+ MemoryPool* pool = default_memory_pool()) const;
+
+ /// Construct a zero-copy view of this chunked array with the given
+ /// type. Calls Array::View on each constituent chunk. Always succeeds if
+ /// there are zero chunks
+ Result<std::shared_ptr<ChunkedArray>> View(const std::shared_ptr<DataType>& type) const;
+
+ const std::shared_ptr<DataType>& type() const { return type_; }
+
+ /// \brief Return a Scalar containing the value of this array at index
+ Result<std::shared_ptr<Scalar>> GetScalar(int64_t index) const;
+
+ /// \brief Determine if two chunked arrays are equal.
+ ///
+ /// Two chunked arrays can be equal only if they have equal datatypes.
+ /// However, they may be equal even if they have different chunkings.
+ bool Equals(const ChunkedArray& other) const;
+ /// \brief Determine if two chunked arrays are equal.
+ bool Equals(const std::shared_ptr<ChunkedArray>& other) const;
+ /// \brief Determine if two chunked arrays approximately equal
+ bool ApproxEquals(const ChunkedArray& other,
+ const EqualOptions& = EqualOptions::Defaults()) const;
+
+ /// \return PrettyPrint representation suitable for debugging
+ std::string ToString() const;
+
+ /// \brief Perform cheap validation checks to determine obvious inconsistencies
+ /// within the chunk array's internal data.
+ ///
+ /// This is O(k*m) where k is the number of array descendents,
+ /// and m is the number of chunks.
+ ///
+ /// \return Status
+ Status Validate() const;
+
+ /// \brief Perform extensive validation checks to determine inconsistencies
+ /// within the chunk array's internal data.
+ ///
+ /// This is O(k*n) where k is the number of array descendents,
+ /// and n is the length in elements.
+ ///
+ /// \return Status
+ Status ValidateFull() const;
+
+ protected:
+ ArrayVector chunks_;
+ int64_t length_;
+ int64_t null_count_;
+ std::shared_ptr<DataType> type_;
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ChunkedArray);
+};
+
+namespace internal {
+
+/// \brief EXPERIMENTAL: Utility for incremental iteration over contiguous
+/// pieces of potentially differently-chunked ChunkedArray objects
+class ARROW_EXPORT MultipleChunkIterator {
+ public:
+ MultipleChunkIterator(const ChunkedArray& left, const ChunkedArray& right)
+ : left_(left),
+ right_(right),
+ pos_(0),
+ length_(left.length()),
+ chunk_idx_left_(0),
+ chunk_idx_right_(0),
+ chunk_pos_left_(0),
+ chunk_pos_right_(0) {}
+
+ bool Next(std::shared_ptr<Array>* next_left, std::shared_ptr<Array>* next_right);
+
+ int64_t position() const { return pos_; }
+
+ private:
+ const ChunkedArray& left_;
+ const ChunkedArray& right_;
+
+ // The amount of the entire ChunkedArray consumed
+ int64_t pos_;
+
+ // Length of the chunked array(s)
+ int64_t length_;
+
+ // Current left chunk
+ int chunk_idx_left_;
+
+ // Current right chunk
+ int chunk_idx_right_;
+
+ // Offset into the current left chunk
+ int64_t chunk_pos_left_;
+
+ // Offset into the current right chunk
+ int64_t chunk_pos_right_;
+};
+
+/// \brief Evaluate binary function on two ChunkedArray objects having possibly
+/// different chunk layouts. The passed binary function / functor should have
+/// the following signature.
+///
+/// Status(const Array&, const Array&, int64_t)
+///
+/// The third argument is the absolute position relative to the start of each
+/// ChunkedArray. The function is executed against each contiguous pair of
+/// array segments, slicing if necessary.
+///
+/// For example, if two arrays have chunk sizes
+///
+/// left: [10, 10, 20]
+/// right: [15, 10, 15]
+///
+/// Then the following invocations take place (pseudocode)
+///
+/// func(left.chunk[0][0:10], right.chunk[0][0:10], 0)
+/// func(left.chunk[1][0:5], right.chunk[0][10:15], 10)
+/// func(left.chunk[1][5:10], right.chunk[1][0:5], 15)
+/// func(left.chunk[2][0:5], right.chunk[1][5:10], 20)
+/// func(left.chunk[2][5:20], right.chunk[2][:], 25)
+template <typename Action>
+Status ApplyBinaryChunked(const ChunkedArray& left, const ChunkedArray& right,
+ Action&& action) {
+ MultipleChunkIterator iterator(left, right);
+ std::shared_ptr<Array> left_piece, right_piece;
+ while (iterator.Next(&left_piece, &right_piece)) {
+ ARROW_RETURN_NOT_OK(action(*left_piece, *right_piece, iterator.position()));
+ }
+ return Status::OK();
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/chunked_array_test.cc b/src/arrow/cpp/src/arrow/chunked_array_test.cc
new file mode 100644
index 000000000..c41a4c2bd
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/chunked_array_test.cc
@@ -0,0 +1,266 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/chunked_array.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/type.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+
+class TestChunkedArray : public TestBase {
+ protected:
+ virtual void Construct() {
+ one_ = std::make_shared<ChunkedArray>(arrays_one_);
+ if (!arrays_another_.empty()) {
+ another_ = std::make_shared<ChunkedArray>(arrays_another_);
+ }
+ }
+
+ ArrayVector arrays_one_;
+ ArrayVector arrays_another_;
+
+ std::shared_ptr<ChunkedArray> one_;
+ std::shared_ptr<ChunkedArray> another_;
+};
+
+TEST_F(TestChunkedArray, Make) {
+ ASSERT_RAISES(Invalid, ChunkedArray::Make({}));
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<ChunkedArray> result,
+ ChunkedArray::Make({}, int64()));
+ AssertTypeEqual(*int64(), *result->type());
+ ASSERT_EQ(result->num_chunks(), 0);
+
+ auto chunk0 = ArrayFromJSON(int8(), "[0, 1, 2]");
+ auto chunk1 = ArrayFromJSON(int16(), "[3, 4, 5]");
+
+ ASSERT_OK_AND_ASSIGN(result, ChunkedArray::Make({chunk0, chunk0}));
+ ASSERT_OK_AND_ASSIGN(auto result2, ChunkedArray::Make({chunk0, chunk0}, int8()));
+ AssertChunkedEqual(*result, *result2);
+
+ ASSERT_RAISES(Invalid, ChunkedArray::Make({chunk0, chunk1}));
+ ASSERT_RAISES(Invalid, ChunkedArray::Make({chunk0}, int16()));
+}
+
+TEST_F(TestChunkedArray, BasicEquals) {
+ std::vector<bool> null_bitmap(100, true);
+ std::vector<int32_t> data(100, 1);
+ std::shared_ptr<Array> array;
+ ArrayFromVector<Int32Type, int32_t>(null_bitmap, data, &array);
+ arrays_one_.push_back(array);
+ arrays_another_.push_back(array);
+
+ Construct();
+ ASSERT_TRUE(one_->Equals(one_));
+ ASSERT_FALSE(one_->Equals(nullptr));
+ ASSERT_TRUE(one_->Equals(another_));
+ ASSERT_TRUE(one_->Equals(*another_.get()));
+}
+
+TEST_F(TestChunkedArray, EqualsDifferingTypes) {
+ std::vector<bool> null_bitmap(100, true);
+ std::vector<int32_t> data32(100, 1);
+ std::vector<int64_t> data64(100, 1);
+ std::shared_ptr<Array> array;
+ ArrayFromVector<Int32Type, int32_t>(null_bitmap, data32, &array);
+ arrays_one_.push_back(array);
+ ArrayFromVector<Int64Type, int64_t>(null_bitmap, data64, &array);
+ arrays_another_.push_back(array);
+
+ Construct();
+ ASSERT_FALSE(one_->Equals(another_));
+ ASSERT_FALSE(one_->Equals(*another_.get()));
+}
+
+TEST_F(TestChunkedArray, EqualsDifferingLengths) {
+ std::vector<bool> null_bitmap100(100, true);
+ std::vector<bool> null_bitmap101(101, true);
+ std::vector<int32_t> data100(100, 1);
+ std::vector<int32_t> data101(101, 1);
+ std::shared_ptr<Array> array;
+ ArrayFromVector<Int32Type, int32_t>(null_bitmap100, data100, &array);
+ arrays_one_.push_back(array);
+ ArrayFromVector<Int32Type, int32_t>(null_bitmap101, data101, &array);
+ arrays_another_.push_back(array);
+
+ Construct();
+ ASSERT_FALSE(one_->Equals(another_));
+ ASSERT_FALSE(one_->Equals(*another_.get()));
+
+ std::vector<bool> null_bitmap1(1, true);
+ std::vector<int32_t> data1(1, 1);
+ ArrayFromVector<Int32Type, int32_t>(null_bitmap1, data1, &array);
+ arrays_one_.push_back(array);
+
+ Construct();
+ ASSERT_TRUE(one_->Equals(another_));
+ ASSERT_TRUE(one_->Equals(*another_.get()));
+}
+
+TEST_F(TestChunkedArray, EqualsDifferingMetadata) {
+ auto left_ty = list(field("item", int32()));
+
+ auto metadata = key_value_metadata({"foo"}, {"bar"});
+ auto right_ty = list(field("item", int32(), true, metadata));
+
+ std::vector<std::shared_ptr<Array>> left_chunks = {ArrayFromJSON(left_ty, "[[]]")};
+ std::vector<std::shared_ptr<Array>> right_chunks = {ArrayFromJSON(right_ty, "[[]]")};
+
+ ChunkedArray left(left_chunks);
+ ChunkedArray right(right_chunks);
+ ASSERT_TRUE(left.Equals(right));
+}
+
+TEST_F(TestChunkedArray, SliceEquals) {
+ arrays_one_.push_back(MakeRandomArray<Int32Array>(100));
+ arrays_one_.push_back(MakeRandomArray<Int32Array>(50));
+ arrays_one_.push_back(MakeRandomArray<Int32Array>(50));
+ Construct();
+
+ std::shared_ptr<ChunkedArray> slice = one_->Slice(125, 50);
+ ASSERT_EQ(slice->length(), 50);
+ AssertChunkedEqual(*one_->Slice(125, 50), *slice);
+
+ std::shared_ptr<ChunkedArray> slice2 = one_->Slice(75)->Slice(25)->Slice(25, 50);
+ ASSERT_EQ(slice2->length(), 50);
+ AssertChunkedEqual(*slice, *slice2);
+
+ // Making empty slices of a ChunkedArray
+ std::shared_ptr<ChunkedArray> slice3 = one_->Slice(one_->length(), 99);
+ ASSERT_EQ(slice3->length(), 0);
+ ASSERT_EQ(slice3->num_chunks(), 1);
+ ASSERT_TRUE(slice3->type()->Equals(one_->type()));
+
+ std::shared_ptr<ChunkedArray> slice4 = one_->Slice(10, 0);
+ ASSERT_EQ(slice4->length(), 0);
+ ASSERT_EQ(slice4->num_chunks(), 1);
+ ASSERT_TRUE(slice4->type()->Equals(one_->type()));
+
+ // Slicing an empty ChunkedArray
+ std::shared_ptr<ChunkedArray> slice5 = slice4->Slice(0, 10);
+ ASSERT_EQ(slice5->length(), 0);
+ ASSERT_EQ(slice5->num_chunks(), 1);
+ ASSERT_TRUE(slice5->type()->Equals(one_->type()));
+}
+
+TEST_F(TestChunkedArray, ZeroChunksIssues) {
+ ArrayVector empty = {};
+ auto no_chunks = std::make_shared<ChunkedArray>(empty, int8());
+
+ // ARROW-8911, assert that slicing is a no-op when there are zero-chunks
+ auto sliced = no_chunks->Slice(0, 0);
+ auto sliced2 = no_chunks->Slice(0, 5);
+ AssertChunkedEqual(*no_chunks, *sliced);
+ AssertChunkedEqual(*no_chunks, *sliced2);
+}
+
+TEST_F(TestChunkedArray, Validate) {
+ // Valid if empty
+ ArrayVector empty = {};
+ auto no_chunks = std::make_shared<ChunkedArray>(empty, utf8());
+ ASSERT_OK(no_chunks->ValidateFull());
+
+ random::RandomArrayGenerator gen(0);
+ arrays_one_.push_back(gen.Int32(50, 0, 100, 0.1));
+ Construct();
+ ASSERT_OK(one_->ValidateFull());
+
+ arrays_one_.push_back(gen.Int32(50, 0, 100, 0.1));
+ Construct();
+ ASSERT_OK(one_->ValidateFull());
+
+ arrays_one_.push_back(gen.String(50, 0, 10, 0.1));
+ Construct();
+ ASSERT_RAISES(Invalid, one_->ValidateFull());
+}
+
+TEST_F(TestChunkedArray, PrintDiff) {
+ random::RandomArrayGenerator gen(0);
+ arrays_one_.push_back(gen.Int32(50, 0, 100, 0.1));
+ Construct();
+
+ auto other = one_->Slice(25);
+ ASSERT_OK_AND_ASSIGN(auto diff, PrintArrayDiff(*one_, *other));
+ ASSERT_EQ(*diff, "Expected length 50 but was actually 25");
+
+ ASSERT_OK_AND_ASSIGN(diff, PrintArrayDiff(*other, *one_));
+ ASSERT_EQ(*diff, "Expected length 25 but was actually 50");
+}
+
+TEST_F(TestChunkedArray, View) {
+ auto in_ty = int32();
+ auto out_ty = fixed_size_binary(4);
+#if ARROW_LITTLE_ENDIAN
+ auto arr = ArrayFromJSON(in_ty, "[2020568934, 2054316386, null]");
+ auto arr2 = ArrayFromJSON(in_ty, "[2020568934, 2054316386]");
+#else
+ auto arr = ArrayFromJSON(in_ty, "[1718579064, 1650553466, null]");
+ auto arr2 = ArrayFromJSON(in_ty, "[1718579064, 1650553466]");
+#endif
+ auto ex = ArrayFromJSON(out_ty, R"(["foox", "barz", null])");
+ auto ex2 = ArrayFromJSON(out_ty, R"(["foox", "barz"])");
+
+ ArrayVector chunks = {arr, arr2};
+ ArrayVector ex_chunks = {ex, ex2};
+ auto carr = std::make_shared<ChunkedArray>(chunks);
+ auto expected = std::make_shared<ChunkedArray>(ex_chunks);
+
+ ASSERT_OK_AND_ASSIGN(auto result, carr->View(out_ty));
+ AssertChunkedEqual(*expected, *result);
+
+ // Zero length
+ ArrayVector empty = {};
+ carr = std::make_shared<ChunkedArray>(empty, in_ty);
+ expected = std::make_shared<ChunkedArray>(empty, out_ty);
+ ASSERT_OK_AND_ASSIGN(result, carr->View(out_ty));
+ AssertChunkedEqual(*expected, *result);
+}
+
+TEST_F(TestChunkedArray, GetScalar) {
+ auto ty = int32();
+ ArrayVector chunks{ArrayFromJSON(ty, "[6, 7, null]"), ArrayFromJSON(ty, "[]"),
+ ArrayFromJSON(ty, "[null]"), ArrayFromJSON(ty, "[3, 4, 5]")};
+ ChunkedArray carr(chunks);
+
+ auto check_scalar = [](const ChunkedArray& array, int64_t index,
+ const Scalar& expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, array.GetScalar(index));
+ AssertScalarsEqual(expected, *actual, /*verbose=*/true);
+ };
+
+ check_scalar(carr, 0, **MakeScalar(ty, 6));
+ check_scalar(carr, 2, *MakeNullScalar(ty));
+ check_scalar(carr, 3, *MakeNullScalar(ty));
+ check_scalar(carr, 4, **MakeScalar(ty, 3));
+ check_scalar(carr, 6, **MakeScalar(ty, 5));
+
+ ASSERT_RAISES(Invalid, carr.GetScalar(7));
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compare.cc b/src/arrow/cpp/src/arrow/compare.cc
new file mode 100644
index 000000000..4ecb00a3f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compare.cc
@@ -0,0 +1,1300 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Functions for comparing Arrow data structures
+
+#include "arrow/compare.h"
+
+#include <climits>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/array/diff.h"
+#include "arrow/buffer.h"
+#include "arrow/scalar.h"
+#include "arrow/sparse_tensor.h"
+#include "arrow/status.h"
+#include "arrow/tensor.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/memory.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::BitmapEquals;
+using internal::BitmapReader;
+using internal::BitmapUInt64Reader;
+using internal::checked_cast;
+using internal::OptionalBitmapEquals;
+
+// ----------------------------------------------------------------------
+// Public method implementations
+
+namespace {
+
+// TODO also handle HALF_FLOAT NaNs
+
+enum FloatingEqualityFlags : int8_t { Approximate = 1, NansEqual = 2 };
+
+template <typename T, int8_t Flags>
+struct FloatingEquality {
+ bool operator()(T x, T y) { return x == y; }
+};
+
+template <typename T>
+struct FloatingEquality<T, NansEqual> {
+ bool operator()(T x, T y) { return (x == y) || (std::isnan(x) && std::isnan(y)); }
+};
+
+template <typename T>
+struct FloatingEquality<T, Approximate> {
+ explicit FloatingEquality(const EqualOptions& options)
+ : epsilon(static_cast<T>(options.atol())) {}
+
+ bool operator()(T x, T y) { return (fabs(x - y) <= epsilon) || (x == y); }
+
+ const T epsilon;
+};
+
+template <typename T>
+struct FloatingEquality<T, Approximate | NansEqual> {
+ explicit FloatingEquality(const EqualOptions& options)
+ : epsilon(static_cast<T>(options.atol())) {}
+
+ bool operator()(T x, T y) {
+ return (fabs(x - y) <= epsilon) || (x == y) || (std::isnan(x) && std::isnan(y));
+ }
+
+ const T epsilon;
+};
+
+template <typename T, typename Visitor>
+void VisitFloatingEquality(const EqualOptions& options, bool floating_approximate,
+ Visitor&& visit) {
+ if (options.nans_equal()) {
+ if (floating_approximate) {
+ visit(FloatingEquality<T, NansEqual | Approximate>{options});
+ } else {
+ visit(FloatingEquality<T, NansEqual>{});
+ }
+ } else {
+ if (floating_approximate) {
+ visit(FloatingEquality<T, Approximate>{options});
+ } else {
+ visit(FloatingEquality<T, 0>{});
+ }
+ }
+}
+
+inline bool IdentityImpliesEqualityNansNotEqual(const DataType& type) {
+ if (type.id() == Type::FLOAT || type.id() == Type::DOUBLE) {
+ return false;
+ }
+ for (const auto& child : type.fields()) {
+ if (!IdentityImpliesEqualityNansNotEqual(*child->type())) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline bool IdentityImpliesEquality(const DataType& type, const EqualOptions& options) {
+ if (options.nans_equal()) {
+ return true;
+ }
+ return IdentityImpliesEqualityNansNotEqual(type);
+}
+
+bool CompareArrayRanges(const ArrayData& left, const ArrayData& right,
+ int64_t left_start_idx, int64_t left_end_idx,
+ int64_t right_start_idx, const EqualOptions& options,
+ bool floating_approximate);
+
+class RangeDataEqualsImpl {
+ public:
+ // PRE-CONDITIONS:
+ // - the types are equal
+ // - the ranges are in bounds
+ RangeDataEqualsImpl(const EqualOptions& options, bool floating_approximate,
+ const ArrayData& left, const ArrayData& right,
+ int64_t left_start_idx, int64_t right_start_idx,
+ int64_t range_length)
+ : options_(options),
+ floating_approximate_(floating_approximate),
+ left_(left),
+ right_(right),
+ left_start_idx_(left_start_idx),
+ right_start_idx_(right_start_idx),
+ range_length_(range_length),
+ result_(false) {}
+
+ bool Compare() {
+ // Compare null bitmaps
+ if (left_start_idx_ == 0 && right_start_idx_ == 0 && range_length_ == left_.length &&
+ range_length_ == right_.length) {
+ // If we're comparing entire arrays, we can first compare the cached null counts
+ if (left_.GetNullCount() != right_.GetNullCount()) {
+ return false;
+ }
+ }
+ if (!OptionalBitmapEquals(left_.buffers[0], left_.offset + left_start_idx_,
+ right_.buffers[0], right_.offset + right_start_idx_,
+ range_length_)) {
+ return false;
+ }
+ // Compare values
+ return CompareWithType(*left_.type);
+ }
+
+ bool CompareWithType(const DataType& type) {
+ result_ = true;
+ if (range_length_ != 0) {
+ ARROW_CHECK_OK(VisitTypeInline(type, this));
+ }
+ return result_;
+ }
+
+ Status Visit(const NullType&) { return Status::OK(); }
+
+ template <typename TypeClass>
+ enable_if_primitive_ctype<TypeClass, Status> Visit(const TypeClass& type) {
+ return ComparePrimitive(type);
+ }
+
+ template <typename TypeClass>
+ enable_if_t<is_temporal_type<TypeClass>::value, Status> Visit(const TypeClass& type) {
+ return ComparePrimitive(type);
+ }
+
+ Status Visit(const BooleanType&) {
+ const uint8_t* left_bits = left_.GetValues<uint8_t>(1, 0);
+ const uint8_t* right_bits = right_.GetValues<uint8_t>(1, 0);
+ auto compare_runs = [&](int64_t i, int64_t length) -> bool {
+ if (length <= 8) {
+ // Avoid the BitmapUInt64Reader overhead for very small runs
+ for (int64_t j = i; j < i + length; ++j) {
+ if (BitUtil::GetBit(left_bits, left_start_idx_ + left_.offset + j) !=
+ BitUtil::GetBit(right_bits, right_start_idx_ + right_.offset + j)) {
+ return false;
+ }
+ }
+ return true;
+ } else if (length <= 1024) {
+ BitmapUInt64Reader left_reader(left_bits, left_start_idx_ + left_.offset + i,
+ length);
+ BitmapUInt64Reader right_reader(right_bits, right_start_idx_ + right_.offset + i,
+ length);
+ while (left_reader.position() < length) {
+ if (left_reader.NextWord() != right_reader.NextWord()) {
+ return false;
+ }
+ }
+ DCHECK_EQ(right_reader.position(), length);
+ } else {
+ // BitmapEquals is the fastest method on large runs
+ return BitmapEquals(left_bits, left_start_idx_ + left_.offset + i, right_bits,
+ right_start_idx_ + right_.offset + i, length);
+ }
+ return true;
+ };
+ VisitValidRuns(compare_runs);
+ return Status::OK();
+ }
+
+ Status Visit(const FloatType& type) { return CompareFloating(type); }
+
+ Status Visit(const DoubleType& type) { return CompareFloating(type); }
+
+ // Also matches StringType
+ Status Visit(const BinaryType& type) { return CompareBinary(type); }
+
+ // Also matches LargeStringType
+ Status Visit(const LargeBinaryType& type) { return CompareBinary(type); }
+
+ Status Visit(const FixedSizeBinaryType& type) {
+ const auto byte_width = type.byte_width();
+ const uint8_t* left_data = left_.GetValues<uint8_t>(1, 0);
+ const uint8_t* right_data = right_.GetValues<uint8_t>(1, 0);
+
+ if (left_data != nullptr && right_data != nullptr) {
+ auto compare_runs = [&](int64_t i, int64_t length) -> bool {
+ return memcmp(left_data + (left_start_idx_ + left_.offset + i) * byte_width,
+ right_data + (right_start_idx_ + right_.offset + i) * byte_width,
+ length * byte_width) == 0;
+ };
+ VisitValidRuns(compare_runs);
+ } else {
+ auto compare_runs = [&](int64_t i, int64_t length) -> bool { return true; };
+ VisitValidRuns(compare_runs);
+ }
+ return Status::OK();
+ }
+
+ // Also matches MapType
+ Status Visit(const ListType& type) { return CompareList(type); }
+
+ Status Visit(const LargeListType& type) { return CompareList(type); }
+
+ Status Visit(const FixedSizeListType& type) {
+ const auto list_size = type.list_size();
+ const ArrayData& left_data = *left_.child_data[0];
+ const ArrayData& right_data = *right_.child_data[0];
+
+ auto compare_runs = [&](int64_t i, int64_t length) -> bool {
+ RangeDataEqualsImpl impl(options_, floating_approximate_, left_data, right_data,
+ (left_start_idx_ + left_.offset + i) * list_size,
+ (right_start_idx_ + right_.offset + i) * list_size,
+ length * list_size);
+ return impl.Compare();
+ };
+ VisitValidRuns(compare_runs);
+ return Status::OK();
+ }
+
+ Status Visit(const StructType& type) {
+ const int32_t num_fields = type.num_fields();
+
+ auto compare_runs = [&](int64_t i, int64_t length) -> bool {
+ for (int32_t f = 0; f < num_fields; ++f) {
+ RangeDataEqualsImpl impl(options_, floating_approximate_, *left_.child_data[f],
+ *right_.child_data[f],
+ left_start_idx_ + left_.offset + i,
+ right_start_idx_ + right_.offset + i, length);
+ if (!impl.Compare()) {
+ return false;
+ }
+ }
+ return true;
+ };
+ VisitValidRuns(compare_runs);
+ return Status::OK();
+ }
+
+ Status Visit(const SparseUnionType& type) {
+ const auto& child_ids = type.child_ids();
+ const int8_t* left_codes = left_.GetValues<int8_t>(1);
+ const int8_t* right_codes = right_.GetValues<int8_t>(1);
+
+ // Unions don't have a null bitmap
+ for (int64_t i = 0; i < range_length_; ++i) {
+ const auto type_id = left_codes[left_start_idx_ + i];
+ if (type_id != right_codes[right_start_idx_ + i]) {
+ result_ = false;
+ break;
+ }
+ const auto child_num = child_ids[type_id];
+ // XXX can we instead detect runs of same-child union values?
+ RangeDataEqualsImpl impl(
+ options_, floating_approximate_, *left_.child_data[child_num],
+ *right_.child_data[child_num], left_start_idx_ + left_.offset + i,
+ right_start_idx_ + right_.offset + i, 1);
+ if (!impl.Compare()) {
+ result_ = false;
+ break;
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const DenseUnionType& type) {
+ const auto& child_ids = type.child_ids();
+ const int8_t* left_codes = left_.GetValues<int8_t>(1);
+ const int8_t* right_codes = right_.GetValues<int8_t>(1);
+ const int32_t* left_offsets = left_.GetValues<int32_t>(2);
+ const int32_t* right_offsets = right_.GetValues<int32_t>(2);
+
+ for (int64_t i = 0; i < range_length_; ++i) {
+ const auto type_id = left_codes[left_start_idx_ + i];
+ if (type_id != right_codes[right_start_idx_ + i]) {
+ result_ = false;
+ break;
+ }
+ const auto child_num = child_ids[type_id];
+ RangeDataEqualsImpl impl(
+ options_, floating_approximate_, *left_.child_data[child_num],
+ *right_.child_data[child_num], left_offsets[left_start_idx_ + i],
+ right_offsets[right_start_idx_ + i], 1);
+ if (!impl.Compare()) {
+ result_ = false;
+ break;
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryType& type) {
+ // Compare dictionaries
+ result_ &= CompareArrayRanges(
+ *left_.dictionary, *right_.dictionary,
+ /*left_start_idx=*/0,
+ /*left_end_idx=*/std::max(left_.dictionary->length, right_.dictionary->length),
+ /*right_start_idx=*/0, options_, floating_approximate_);
+ if (result_) {
+ // Compare indices
+ result_ &= CompareWithType(*type.index_type());
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionType& type) {
+ // Compare storages
+ result_ &= CompareWithType(*type.storage_type());
+ return Status::OK();
+ }
+
+ protected:
+ // For CompareFloating (templated local classes or lambdas not supported in C++11)
+ template <typename CType>
+ struct ComparatorVisitor {
+ RangeDataEqualsImpl* impl;
+ const CType* left_values;
+ const CType* right_values;
+
+ template <typename CompareFunction>
+ void operator()(CompareFunction&& compare) {
+ impl->VisitValues([&](int64_t i) {
+ const CType x = left_values[i + impl->left_start_idx_];
+ const CType y = right_values[i + impl->right_start_idx_];
+ return compare(x, y);
+ });
+ }
+ };
+
+ template <typename CType>
+ friend struct ComparatorVisitor;
+
+ template <typename TypeClass, typename CType = typename TypeClass::c_type>
+ Status ComparePrimitive(const TypeClass&) {
+ const CType* left_values = left_.GetValues<CType>(1);
+ const CType* right_values = right_.GetValues<CType>(1);
+ VisitValidRuns([&](int64_t i, int64_t length) {
+ return memcmp(left_values + left_start_idx_ + i,
+ right_values + right_start_idx_ + i, length * sizeof(CType)) == 0;
+ });
+ return Status::OK();
+ }
+
+ template <typename TypeClass>
+ Status CompareFloating(const TypeClass&) {
+ using CType = typename TypeClass::c_type;
+ const CType* left_values = left_.GetValues<CType>(1);
+ const CType* right_values = right_.GetValues<CType>(1);
+
+ ComparatorVisitor<CType> visitor{this, left_values, right_values};
+ VisitFloatingEquality<CType>(options_, floating_approximate_, visitor);
+ return Status::OK();
+ }
+
+ template <typename TypeClass>
+ Status CompareBinary(const TypeClass&) {
+ const uint8_t* left_data = left_.GetValues<uint8_t>(2, 0);
+ const uint8_t* right_data = right_.GetValues<uint8_t>(2, 0);
+
+ if (left_data != nullptr && right_data != nullptr) {
+ const auto compare_ranges = [&](int64_t left_offset, int64_t right_offset,
+ int64_t length) -> bool {
+ return memcmp(left_data + left_offset, right_data + right_offset, length) == 0;
+ };
+ CompareWithOffsets<typename TypeClass::offset_type>(1, compare_ranges);
+ } else {
+ // One of the arrays is an array of empty strings and nulls.
+ // We just need to compare the offsets.
+ // (note we must not call memcmp() with null data pointers)
+ CompareWithOffsets<typename TypeClass::offset_type>(1, [](...) { return true; });
+ }
+ return Status::OK();
+ }
+
+ template <typename TypeClass>
+ Status CompareList(const TypeClass&) {
+ const ArrayData& left_data = *left_.child_data[0];
+ const ArrayData& right_data = *right_.child_data[0];
+
+ const auto compare_ranges = [&](int64_t left_offset, int64_t right_offset,
+ int64_t length) -> bool {
+ RangeDataEqualsImpl impl(options_, floating_approximate_, left_data, right_data,
+ left_offset, right_offset, length);
+ return impl.Compare();
+ };
+
+ CompareWithOffsets<typename TypeClass::offset_type>(1, compare_ranges);
+ return Status::OK();
+ }
+
+ template <typename offset_type, typename CompareRanges>
+ void CompareWithOffsets(int offsets_buffer_index, CompareRanges&& compare_ranges) {
+ const offset_type* left_offsets =
+ left_.GetValues<offset_type>(offsets_buffer_index) + left_start_idx_;
+ const offset_type* right_offsets =
+ right_.GetValues<offset_type>(offsets_buffer_index) + right_start_idx_;
+
+ const auto compare_runs = [&](int64_t i, int64_t length) {
+ for (int64_t j = i; j < i + length; ++j) {
+ if (left_offsets[j + 1] - left_offsets[j] !=
+ right_offsets[j + 1] - right_offsets[j]) {
+ return false;
+ }
+ }
+ if (!compare_ranges(left_offsets[i], right_offsets[i],
+ left_offsets[i + length] - left_offsets[i])) {
+ return false;
+ }
+ return true;
+ };
+
+ VisitValidRuns(compare_runs);
+ }
+
+ template <typename CompareValues>
+ void VisitValues(CompareValues&& compare_values) {
+ internal::VisitSetBitRunsVoid(left_.buffers[0], left_.offset + left_start_idx_,
+ range_length_, [&](int64_t position, int64_t length) {
+ for (int64_t i = 0; i < length; ++i) {
+ result_ &= compare_values(position + i);
+ }
+ });
+ }
+
+ // Visit and compare runs of non-null values
+ template <typename CompareRuns>
+ void VisitValidRuns(CompareRuns&& compare_runs) {
+ const uint8_t* left_null_bitmap = left_.GetValues<uint8_t>(0, 0);
+ if (left_null_bitmap == nullptr) {
+ result_ = compare_runs(0, range_length_);
+ return;
+ }
+ internal::SetBitRunReader reader(left_null_bitmap, left_.offset + left_start_idx_,
+ range_length_);
+ while (true) {
+ const auto run = reader.NextRun();
+ if (run.length == 0) {
+ return;
+ }
+ if (!compare_runs(run.position, run.length)) {
+ result_ = false;
+ return;
+ }
+ }
+ }
+
+ const EqualOptions& options_;
+ const bool floating_approximate_;
+ const ArrayData& left_;
+ const ArrayData& right_;
+ const int64_t left_start_idx_;
+ const int64_t right_start_idx_;
+ const int64_t range_length_;
+
+ bool result_;
+};
+
+bool CompareArrayRanges(const ArrayData& left, const ArrayData& right,
+ int64_t left_start_idx, int64_t left_end_idx,
+ int64_t right_start_idx, const EqualOptions& options,
+ bool floating_approximate) {
+ if (left.type->id() != right.type->id() ||
+ !TypeEquals(*left.type, *right.type, false /* check_metadata */)) {
+ return false;
+ }
+
+ const int64_t range_length = left_end_idx - left_start_idx;
+ DCHECK_GE(range_length, 0);
+ if (left_start_idx + range_length > left.length) {
+ // Left range too small
+ return false;
+ }
+ if (right_start_idx + range_length > right.length) {
+ // Right range too small
+ return false;
+ }
+ if (&left == &right && left_start_idx == right_start_idx &&
+ IdentityImpliesEquality(*left.type, options)) {
+ return true;
+ }
+ // Compare values
+ RangeDataEqualsImpl impl(options, floating_approximate, left, right, left_start_idx,
+ right_start_idx, range_length);
+ return impl.Compare();
+}
+
+class TypeEqualsVisitor {
+ public:
+ explicit TypeEqualsVisitor(const DataType& right, bool check_metadata)
+ : right_(right), check_metadata_(check_metadata), result_(false) {}
+
+ Status VisitChildren(const DataType& left) {
+ if (left.num_fields() != right_.num_fields()) {
+ result_ = false;
+ return Status::OK();
+ }
+
+ for (int i = 0; i < left.num_fields(); ++i) {
+ if (!left.field(i)->Equals(right_.field(i), check_metadata_)) {
+ result_ = false;
+ return Status::OK();
+ }
+ }
+ result_ = true;
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_t<is_null_type<T>::value || is_primitive_ctype<T>::value ||
+ is_base_binary_type<T>::value,
+ Status>
+ Visit(const T&) {
+ result_ = true;
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_interval<T, Status> Visit(const T& left) {
+ const auto& right = checked_cast<const IntervalType&>(right_);
+ result_ = right.interval_type() == left.interval_type();
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_t<is_time_type<T>::value || is_date_type<T>::value ||
+ is_duration_type<T>::value,
+ Status>
+ Visit(const T& left) {
+ const auto& right = checked_cast<const T&>(right_);
+ result_ = left.unit() == right.unit();
+ return Status::OK();
+ }
+
+ Status Visit(const TimestampType& left) {
+ const auto& right = checked_cast<const TimestampType&>(right_);
+ result_ = left.unit() == right.unit() && left.timezone() == right.timezone();
+ return Status::OK();
+ }
+
+ Status Visit(const FixedSizeBinaryType& left) {
+ const auto& right = checked_cast<const FixedSizeBinaryType&>(right_);
+ result_ = left.byte_width() == right.byte_width();
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal128Type& left) {
+ const auto& right = checked_cast<const Decimal128Type&>(right_);
+ result_ = left.precision() == right.precision() && left.scale() == right.scale();
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal256Type& left) {
+ const auto& right = checked_cast<const Decimal256Type&>(right_);
+ result_ = left.precision() == right.precision() && left.scale() == right.scale();
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_t<is_list_like_type<T>::value || is_struct_type<T>::value, Status> Visit(
+ const T& left) {
+ return VisitChildren(left);
+ }
+
+ Status Visit(const MapType& left) {
+ const auto& right = checked_cast<const MapType&>(right_);
+ if (left.keys_sorted() != right.keys_sorted()) {
+ result_ = false;
+ return Status::OK();
+ }
+ result_ = left.key_type()->Equals(*right.key_type(), check_metadata_) &&
+ left.item_type()->Equals(*right.item_type(), check_metadata_);
+ return Status::OK();
+ }
+
+ Status Visit(const UnionType& left) {
+ const auto& right = checked_cast<const UnionType&>(right_);
+
+ if (left.mode() != right.mode() || left.type_codes() != right.type_codes()) {
+ result_ = false;
+ return Status::OK();
+ }
+
+ result_ = std::equal(
+ left.fields().begin(), left.fields().end(), right.fields().begin(),
+ [this](const std::shared_ptr<Field>& l, const std::shared_ptr<Field>& r) {
+ return l->Equals(r, check_metadata_);
+ });
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryType& left) {
+ const auto& right = checked_cast<const DictionaryType&>(right_);
+ result_ = left.index_type()->Equals(right.index_type()) &&
+ left.value_type()->Equals(right.value_type()) &&
+ (left.ordered() == right.ordered());
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionType& left) {
+ result_ = left.ExtensionEquals(static_cast<const ExtensionType&>(right_));
+ return Status::OK();
+ }
+
+ bool result() const { return result_; }
+
+ protected:
+ const DataType& right_;
+ bool check_metadata_;
+ bool result_;
+};
+
+bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts,
+ bool floating_approximate);
+bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options,
+ bool floating_approximate);
+
+class ScalarEqualsVisitor {
+ public:
+ // PRE-CONDITIONS:
+ // - the types are equal
+ // - the scalars are non-null
+ explicit ScalarEqualsVisitor(const Scalar& right, const EqualOptions& opts,
+ bool floating_approximate)
+ : right_(right),
+ options_(opts),
+ floating_approximate_(floating_approximate),
+ result_(false) {}
+
+ Status Visit(const NullScalar& left) {
+ result_ = true;
+ return Status::OK();
+ }
+
+ Status Visit(const BooleanScalar& left) {
+ const auto& right = checked_cast<const BooleanScalar&>(right_);
+ result_ = left.value == right.value;
+ return Status::OK();
+ }
+
+ template <typename T>
+ typename std::enable_if<(is_primitive_ctype<typename T::TypeClass>::value ||
+ is_temporal_type<typename T::TypeClass>::value),
+ Status>::type
+ Visit(const T& left_) {
+ const auto& right = checked_cast<const T&>(right_);
+ result_ = right.value == left_.value;
+ return Status::OK();
+ }
+
+ Status Visit(const FloatScalar& left) { return CompareFloating(left); }
+
+ Status Visit(const DoubleScalar& left) { return CompareFloating(left); }
+
+ template <typename T>
+ typename std::enable_if<std::is_base_of<BaseBinaryScalar, T>::value, Status>::type
+ Visit(const T& left) {
+ const auto& right = checked_cast<const BaseBinaryScalar&>(right_);
+ result_ = internal::SharedPtrEquals(left.value, right.value);
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal128Scalar& left) {
+ const auto& right = checked_cast<const Decimal128Scalar&>(right_);
+ result_ = left.value == right.value;
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal256Scalar& left) {
+ const auto& right = checked_cast<const Decimal256Scalar&>(right_);
+ result_ = left.value == right.value;
+ return Status::OK();
+ }
+
+ Status Visit(const ListScalar& left) {
+ const auto& right = checked_cast<const ListScalar&>(right_);
+ result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_);
+ return Status::OK();
+ }
+
+ Status Visit(const LargeListScalar& left) {
+ const auto& right = checked_cast<const LargeListScalar&>(right_);
+ result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_);
+ return Status::OK();
+ }
+
+ Status Visit(const MapScalar& left) {
+ const auto& right = checked_cast<const MapScalar&>(right_);
+ result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_);
+ return Status::OK();
+ }
+
+ Status Visit(const FixedSizeListScalar& left) {
+ const auto& right = checked_cast<const FixedSizeListScalar&>(right_);
+ result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_);
+ return Status::OK();
+ }
+
+ Status Visit(const StructScalar& left) {
+ const auto& right = checked_cast<const StructScalar&>(right_);
+
+ if (right.value.size() != left.value.size()) {
+ result_ = false;
+ } else {
+ bool all_equals = true;
+ for (size_t i = 0; i < left.value.size() && all_equals; i++) {
+ all_equals &= ScalarEquals(*left.value[i], *right.value[i], options_,
+ floating_approximate_);
+ }
+ result_ = all_equals;
+ }
+
+ return Status::OK();
+ }
+
+ Status Visit(const UnionScalar& left) {
+ const auto& right = checked_cast<const UnionScalar&>(right_);
+ result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_);
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryScalar& left) {
+ const auto& right = checked_cast<const DictionaryScalar&>(right_);
+ result_ = ScalarEquals(*left.value.index, *right.value.index, options_,
+ floating_approximate_) &&
+ ArrayEquals(*left.value.dictionary, *right.value.dictionary, options_,
+ floating_approximate_);
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionScalar& left) {
+ const auto& right = checked_cast<const ExtensionScalar&>(right_);
+ result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_);
+ return Status::OK();
+ }
+
+ bool result() const { return result_; }
+
+ protected:
+ // For CompareFloating (templated local classes or lambdas not supported in C++11)
+ template <typename ScalarType>
+ struct ComparatorVisitor {
+ const ScalarType& left;
+ const ScalarType& right;
+ bool* result;
+
+ template <typename CompareFunction>
+ void operator()(CompareFunction&& compare) {
+ *result = compare(left.value, right.value);
+ }
+ };
+
+ template <typename ScalarType>
+ Status CompareFloating(const ScalarType& left) {
+ using CType = decltype(left.value);
+
+ ComparatorVisitor<ScalarType> visitor{left, checked_cast<const ScalarType&>(right_),
+ &result_};
+ VisitFloatingEquality<CType>(options_, floating_approximate_, visitor);
+ return Status::OK();
+ }
+
+ const Scalar& right_;
+ const EqualOptions options_;
+ const bool floating_approximate_;
+ bool result_;
+};
+
+Status PrintDiff(const Array& left, const Array& right, std::ostream* os);
+
+Status PrintDiff(const Array& left, const Array& right, int64_t left_offset,
+ int64_t left_length, int64_t right_offset, int64_t right_length,
+ std::ostream* os) {
+ if (os == nullptr) {
+ return Status::OK();
+ }
+
+ if (!left.type()->Equals(right.type())) {
+ *os << "# Array types differed: " << *left.type() << " vs " << *right.type()
+ << std::endl;
+ return Status::OK();
+ }
+
+ if (left.type()->id() == Type::DICTIONARY) {
+ *os << "# Dictionary arrays differed" << std::endl;
+
+ const auto& left_dict = checked_cast<const DictionaryArray&>(left);
+ const auto& right_dict = checked_cast<const DictionaryArray&>(right);
+
+ *os << "## dictionary diff";
+ auto pos = os->tellp();
+ RETURN_NOT_OK(PrintDiff(*left_dict.dictionary(), *right_dict.dictionary(), os));
+ if (os->tellp() == pos) {
+ *os << std::endl;
+ }
+
+ *os << "## indices diff";
+ pos = os->tellp();
+ RETURN_NOT_OK(PrintDiff(*left_dict.indices(), *right_dict.indices(), os));
+ if (os->tellp() == pos) {
+ *os << std::endl;
+ }
+ return Status::OK();
+ }
+
+ const auto left_slice = left.Slice(left_offset, left_length);
+ const auto right_slice = right.Slice(right_offset, right_length);
+ ARROW_ASSIGN_OR_RAISE(auto edits,
+ Diff(*left_slice, *right_slice, default_memory_pool()));
+ ARROW_ASSIGN_OR_RAISE(auto formatter, MakeUnifiedDiffFormatter(*left.type(), os));
+ return formatter(*edits, *left_slice, *right_slice);
+}
+
+Status PrintDiff(const Array& left, const Array& right, std::ostream* os) {
+ return PrintDiff(left, right, 0, left.length(), 0, right.length(), os);
+}
+
+bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx,
+ int64_t left_end_idx, int64_t right_start_idx,
+ const EqualOptions& options, bool floating_approximate) {
+ bool are_equal =
+ CompareArrayRanges(*left.data(), *right.data(), left_start_idx, left_end_idx,
+ right_start_idx, options, floating_approximate);
+ if (!are_equal) {
+ ARROW_IGNORE_EXPR(PrintDiff(
+ left, right, left_start_idx, left_end_idx, right_start_idx,
+ right_start_idx + (left_end_idx - left_start_idx), options.diff_sink()));
+ }
+ return are_equal;
+}
+
+bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts,
+ bool floating_approximate) {
+ if (left.length() != right.length()) {
+ ARROW_IGNORE_EXPR(PrintDiff(left, right, opts.diff_sink()));
+ return false;
+ }
+ return ArrayRangeEquals(left, right, 0, left.length(), 0, opts, floating_approximate);
+}
+
+bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options,
+ bool floating_approximate) {
+ if (&left == &right && IdentityImpliesEquality(*left.type, options)) {
+ return true;
+ }
+ if (!left.type->Equals(right.type)) {
+ return false;
+ }
+ if (left.is_valid != right.is_valid) {
+ return false;
+ }
+ if (!left.is_valid) {
+ return true;
+ }
+ ScalarEqualsVisitor visitor(right, options, floating_approximate);
+ auto error = VisitScalarInline(left, &visitor);
+ DCHECK_OK(error);
+ return visitor.result();
+}
+
+} // namespace
+
+bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx,
+ int64_t left_end_idx, int64_t right_start_idx,
+ const EqualOptions& options) {
+ const bool floating_approximate = false;
+ return ArrayRangeEquals(left, right, left_start_idx, left_end_idx, right_start_idx,
+ options, floating_approximate);
+}
+
+bool ArrayRangeApproxEquals(const Array& left, const Array& right, int64_t left_start_idx,
+ int64_t left_end_idx, int64_t right_start_idx,
+ const EqualOptions& options) {
+ const bool floating_approximate = true;
+ return ArrayRangeEquals(left, right, left_start_idx, left_end_idx, right_start_idx,
+ options, floating_approximate);
+}
+
+bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts) {
+ const bool floating_approximate = false;
+ return ArrayEquals(left, right, opts, floating_approximate);
+}
+
+bool ArrayApproxEquals(const Array& left, const Array& right, const EqualOptions& opts) {
+ const bool floating_approximate = true;
+ return ArrayEquals(left, right, opts, floating_approximate);
+}
+
+bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options) {
+ const bool floating_approximate = false;
+ return ScalarEquals(left, right, options, floating_approximate);
+}
+
+bool ScalarApproxEquals(const Scalar& left, const Scalar& right,
+ const EqualOptions& options) {
+ const bool floating_approximate = true;
+ return ScalarEquals(left, right, options, floating_approximate);
+}
+
+namespace {
+
+bool StridedIntegerTensorContentEquals(const int dim_index, int64_t left_offset,
+ int64_t right_offset, int elem_size,
+ const Tensor& left, const Tensor& right) {
+ const auto n = left.shape()[dim_index];
+ const auto left_stride = left.strides()[dim_index];
+ const auto right_stride = right.strides()[dim_index];
+ if (dim_index == left.ndim() - 1) {
+ for (int64_t i = 0; i < n; ++i) {
+ if (memcmp(left.raw_data() + left_offset + i * left_stride,
+ right.raw_data() + right_offset + i * right_stride, elem_size) != 0) {
+ return false;
+ }
+ }
+ return true;
+ }
+ for (int64_t i = 0; i < n; ++i) {
+ if (!StridedIntegerTensorContentEquals(dim_index + 1, left_offset, right_offset,
+ elem_size, left, right)) {
+ return false;
+ }
+ left_offset += left_stride;
+ right_offset += right_stride;
+ }
+ return true;
+}
+
+bool IntegerTensorEquals(const Tensor& left, const Tensor& right) {
+ bool are_equal;
+ // The arrays are the same object
+ if (&left == &right) {
+ are_equal = true;
+ } else {
+ const bool left_row_major_p = left.is_row_major();
+ const bool left_column_major_p = left.is_column_major();
+ const bool right_row_major_p = right.is_row_major();
+ const bool right_column_major_p = right.is_column_major();
+
+ if (!(left_row_major_p && right_row_major_p) &&
+ !(left_column_major_p && right_column_major_p)) {
+ const auto& type = checked_cast<const FixedWidthType&>(*left.type());
+ are_equal = StridedIntegerTensorContentEquals(0, 0, 0, internal::GetByteWidth(type),
+ left, right);
+ } else {
+ const int byte_width = internal::GetByteWidth(*left.type());
+ DCHECK_GT(byte_width, 0);
+
+ const uint8_t* left_data = left.data()->data();
+ const uint8_t* right_data = right.data()->data();
+
+ are_equal = memcmp(left_data, right_data,
+ static_cast<size_t>(byte_width * left.size())) == 0;
+ }
+ }
+ return are_equal;
+}
+
+template <typename DataType>
+bool StridedFloatTensorContentEquals(const int dim_index, int64_t left_offset,
+ int64_t right_offset, const Tensor& left,
+ const Tensor& right, const EqualOptions& opts) {
+ using c_type = typename DataType::c_type;
+ static_assert(std::is_floating_point<c_type>::value,
+ "DataType must be a floating point type");
+
+ const auto n = left.shape()[dim_index];
+ const auto left_stride = left.strides()[dim_index];
+ const auto right_stride = right.strides()[dim_index];
+ if (dim_index == left.ndim() - 1) {
+ auto left_data = left.raw_data();
+ auto right_data = right.raw_data();
+ if (opts.nans_equal()) {
+ for (int64_t i = 0; i < n; ++i) {
+ c_type left_value =
+ *reinterpret_cast<const c_type*>(left_data + left_offset + i * left_stride);
+ c_type right_value = *reinterpret_cast<const c_type*>(right_data + right_offset +
+ i * right_stride);
+ if (left_value != right_value &&
+ !(std::isnan(left_value) && std::isnan(right_value))) {
+ return false;
+ }
+ }
+ } else {
+ for (int64_t i = 0; i < n; ++i) {
+ c_type left_value =
+ *reinterpret_cast<const c_type*>(left_data + left_offset + i * left_stride);
+ c_type right_value = *reinterpret_cast<const c_type*>(right_data + right_offset +
+ i * right_stride);
+ if (left_value != right_value) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+ for (int64_t i = 0; i < n; ++i) {
+ if (!StridedFloatTensorContentEquals<DataType>(dim_index + 1, left_offset,
+ right_offset, left, right, opts)) {
+ return false;
+ }
+ left_offset += left_stride;
+ right_offset += right_stride;
+ }
+ return true;
+}
+
+template <typename DataType>
+bool FloatTensorEquals(const Tensor& left, const Tensor& right,
+ const EqualOptions& opts) {
+ return StridedFloatTensorContentEquals<DataType>(0, 0, 0, left, right, opts);
+}
+
+} // namespace
+
+bool TensorEquals(const Tensor& left, const Tensor& right, const EqualOptions& opts) {
+ if (left.type_id() != right.type_id()) {
+ return false;
+ } else if (left.size() == 0 && right.size() == 0) {
+ return true;
+ } else if (left.shape() != right.shape()) {
+ return false;
+ }
+
+ switch (left.type_id()) {
+ // TODO: Support half-float tensors
+ // case Type::HALF_FLOAT:
+ case Type::FLOAT:
+ return FloatTensorEquals<FloatType>(left, right, opts);
+
+ case Type::DOUBLE:
+ return FloatTensorEquals<DoubleType>(left, right, opts);
+
+ default:
+ return IntegerTensorEquals(left, right);
+ }
+}
+
+namespace {
+
+template <typename LeftSparseIndexType, typename RightSparseIndexType>
+struct SparseTensorEqualsImpl {
+ static bool Compare(const SparseTensorImpl<LeftSparseIndexType>& left,
+ const SparseTensorImpl<RightSparseIndexType>& right,
+ const EqualOptions&) {
+ // TODO(mrkn): should we support the equality among different formats?
+ return false;
+ }
+};
+
+bool IntegerSparseTensorDataEquals(const uint8_t* left_data, const uint8_t* right_data,
+ const int byte_width, const int64_t length) {
+ if (left_data == right_data) {
+ return true;
+ }
+ return memcmp(left_data, right_data, static_cast<size_t>(byte_width * length)) == 0;
+}
+
+template <typename DataType>
+bool FloatSparseTensorDataEquals(const typename DataType::c_type* left_data,
+ const typename DataType::c_type* right_data,
+ const int64_t length, const EqualOptions& opts) {
+ using c_type = typename DataType::c_type;
+ static_assert(std::is_floating_point<c_type>::value,
+ "DataType must be a floating point type");
+ if (opts.nans_equal()) {
+ if (left_data == right_data) {
+ return true;
+ }
+
+ for (int64_t i = 0; i < length; ++i) {
+ const auto left = left_data[i];
+ const auto right = right_data[i];
+ if (left != right && !(std::isnan(left) && std::isnan(right))) {
+ return false;
+ }
+ }
+ } else {
+ for (int64_t i = 0; i < length; ++i) {
+ if (left_data[i] != right_data[i]) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+template <typename SparseIndexType>
+struct SparseTensorEqualsImpl<SparseIndexType, SparseIndexType> {
+ static bool Compare(const SparseTensorImpl<SparseIndexType>& left,
+ const SparseTensorImpl<SparseIndexType>& right,
+ const EqualOptions& opts) {
+ DCHECK(left.type()->id() == right.type()->id());
+ DCHECK(left.shape() == right.shape());
+
+ const auto length = left.non_zero_length();
+ DCHECK(length == right.non_zero_length());
+
+ const auto& left_index = checked_cast<const SparseIndexType&>(*left.sparse_index());
+ const auto& right_index = checked_cast<const SparseIndexType&>(*right.sparse_index());
+
+ if (!left_index.Equals(right_index)) {
+ return false;
+ }
+
+ const int byte_width = internal::GetByteWidth(*left.type());
+ DCHECK_GT(byte_width, 0);
+
+ const uint8_t* left_data = left.data()->data();
+ const uint8_t* right_data = right.data()->data();
+ switch (left.type()->id()) {
+ // TODO: Support half-float tensors
+ // case Type::HALF_FLOAT:
+ case Type::FLOAT:
+ return FloatSparseTensorDataEquals<FloatType>(
+ reinterpret_cast<const float*>(left_data),
+ reinterpret_cast<const float*>(right_data), length, opts);
+
+ case Type::DOUBLE:
+ return FloatSparseTensorDataEquals<DoubleType>(
+ reinterpret_cast<const double*>(left_data),
+ reinterpret_cast<const double*>(right_data), length, opts);
+
+ default: // Integer cases
+ return IntegerSparseTensorDataEquals(left_data, right_data, byte_width, length);
+ }
+ }
+};
+
+template <typename SparseIndexType>
+inline bool SparseTensorEqualsImplDispatch(const SparseTensorImpl<SparseIndexType>& left,
+ const SparseTensor& right,
+ const EqualOptions& opts) {
+ switch (right.format_id()) {
+ case SparseTensorFormat::COO: {
+ const auto& right_coo =
+ checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(right);
+ return SparseTensorEqualsImpl<SparseIndexType, SparseCOOIndex>::Compare(
+ left, right_coo, opts);
+ }
+
+ case SparseTensorFormat::CSR: {
+ const auto& right_csr =
+ checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(right);
+ return SparseTensorEqualsImpl<SparseIndexType, SparseCSRIndex>::Compare(
+ left, right_csr, opts);
+ }
+
+ case SparseTensorFormat::CSC: {
+ const auto& right_csc =
+ checked_cast<const SparseTensorImpl<SparseCSCIndex>&>(right);
+ return SparseTensorEqualsImpl<SparseIndexType, SparseCSCIndex>::Compare(
+ left, right_csc, opts);
+ }
+
+ case SparseTensorFormat::CSF: {
+ const auto& right_csf =
+ checked_cast<const SparseTensorImpl<SparseCSFIndex>&>(right);
+ return SparseTensorEqualsImpl<SparseIndexType, SparseCSFIndex>::Compare(
+ left, right_csf, opts);
+ }
+
+ default:
+ return false;
+ }
+}
+
+} // namespace
+
+bool SparseTensorEquals(const SparseTensor& left, const SparseTensor& right,
+ const EqualOptions& opts) {
+ if (left.type()->id() != right.type()->id()) {
+ return false;
+ } else if (left.size() == 0 && right.size() == 0) {
+ return true;
+ } else if (left.shape() != right.shape()) {
+ return false;
+ } else if (left.non_zero_length() != right.non_zero_length()) {
+ return false;
+ }
+
+ switch (left.format_id()) {
+ case SparseTensorFormat::COO: {
+ const auto& left_coo = checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(left);
+ return SparseTensorEqualsImplDispatch(left_coo, right, opts);
+ }
+
+ case SparseTensorFormat::CSR: {
+ const auto& left_csr = checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(left);
+ return SparseTensorEqualsImplDispatch(left_csr, right, opts);
+ }
+
+ case SparseTensorFormat::CSC: {
+ const auto& left_csc = checked_cast<const SparseTensorImpl<SparseCSCIndex>&>(left);
+ return SparseTensorEqualsImplDispatch(left_csc, right, opts);
+ }
+
+ case SparseTensorFormat::CSF: {
+ const auto& left_csf = checked_cast<const SparseTensorImpl<SparseCSFIndex>&>(left);
+ return SparseTensorEqualsImplDispatch(left_csf, right, opts);
+ }
+
+ default:
+ return false;
+ }
+}
+
+bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata) {
+ // The arrays are the same object
+ if (&left == &right) {
+ return true;
+ } else if (left.id() != right.id()) {
+ return false;
+ } else {
+ // First try to compute fingerprints
+ if (check_metadata) {
+ const auto& left_metadata_fp = left.metadata_fingerprint();
+ const auto& right_metadata_fp = right.metadata_fingerprint();
+ if (left_metadata_fp != right_metadata_fp) {
+ return false;
+ }
+ }
+
+ const auto& left_fp = left.fingerprint();
+ const auto& right_fp = right.fingerprint();
+ if (!left_fp.empty() && !right_fp.empty()) {
+ return left_fp == right_fp;
+ }
+
+ // TODO remove check_metadata here?
+ TypeEqualsVisitor visitor(right, check_metadata);
+ auto error = VisitTypeInline(left, &visitor);
+ if (!error.ok()) {
+ DCHECK(false) << "Types are not comparable: " << error.ToString();
+ }
+ return visitor.result();
+ }
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compare.h b/src/arrow/cpp/src/arrow/compare.h
new file mode 100644
index 000000000..6769b2386
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compare.h
@@ -0,0 +1,133 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Functions for comparing Arrow data structures
+
+#pragma once
+
+#include <cstdint>
+#include <iosfwd>
+
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Array;
+class DataType;
+class Tensor;
+class SparseTensor;
+struct Scalar;
+
+static constexpr double kDefaultAbsoluteTolerance = 1E-5;
+
+/// A container of options for equality comparisons
+class EqualOptions {
+ public:
+ /// Whether or not NaNs are considered equal.
+ bool nans_equal() const { return nans_equal_; }
+
+ /// Return a new EqualOptions object with the "nans_equal" property changed.
+ EqualOptions nans_equal(bool v) const {
+ auto res = EqualOptions(*this);
+ res.nans_equal_ = v;
+ return res;
+ }
+
+ /// The absolute tolerance for approximate comparisons of floating-point values.
+ double atol() const { return atol_; }
+
+ /// Return a new EqualOptions object with the "atol" property changed.
+ EqualOptions atol(double v) const {
+ auto res = EqualOptions(*this);
+ res.atol_ = v;
+ return res;
+ }
+
+ /// The ostream to which a diff will be formatted if arrays disagree.
+ /// If this is null (the default) no diff will be formatted.
+ std::ostream* diff_sink() const { return diff_sink_; }
+
+ /// Return a new EqualOptions object with the "diff_sink" property changed.
+ /// This option will be ignored if diff formatting of the types of compared arrays is
+ /// not supported.
+ EqualOptions diff_sink(std::ostream* diff_sink) const {
+ auto res = EqualOptions(*this);
+ res.diff_sink_ = diff_sink;
+ return res;
+ }
+
+ static EqualOptions Defaults() { return {}; }
+
+ protected:
+ double atol_ = kDefaultAbsoluteTolerance;
+ bool nans_equal_ = false;
+ std::ostream* diff_sink_ = NULLPTR;
+};
+
+/// Returns true if the arrays are exactly equal
+bool ARROW_EXPORT ArrayEquals(const Array& left, const Array& right,
+ const EqualOptions& = EqualOptions::Defaults());
+
+/// Returns true if the arrays are approximately equal. For non-floating point
+/// types, this is equivalent to ArrayEquals(left, right)
+bool ARROW_EXPORT ArrayApproxEquals(const Array& left, const Array& right,
+ const EqualOptions& = EqualOptions::Defaults());
+
+/// Returns true if indicated equal-length segment of arrays are exactly equal
+bool ARROW_EXPORT ArrayRangeEquals(const Array& left, const Array& right,
+ int64_t start_idx, int64_t end_idx,
+ int64_t other_start_idx,
+ const EqualOptions& = EqualOptions::Defaults());
+
+/// Returns true if indicated equal-length segment of arrays are approximately equal
+bool ARROW_EXPORT ArrayRangeApproxEquals(const Array& left, const Array& right,
+ int64_t start_idx, int64_t end_idx,
+ int64_t other_start_idx,
+ const EqualOptions& = EqualOptions::Defaults());
+
+bool ARROW_EXPORT TensorEquals(const Tensor& left, const Tensor& right,
+ const EqualOptions& = EqualOptions::Defaults());
+
+/// EXPERIMENTAL: Returns true if the given sparse tensors are exactly equal
+bool ARROW_EXPORT SparseTensorEquals(const SparseTensor& left, const SparseTensor& right,
+ const EqualOptions& = EqualOptions::Defaults());
+
+/// Returns true if the type metadata are exactly equal
+/// \param[in] left a DataType
+/// \param[in] right a DataType
+/// \param[in] check_metadata whether to compare KeyValueMetadata for child
+/// fields
+bool ARROW_EXPORT TypeEquals(const DataType& left, const DataType& right,
+ bool check_metadata = true);
+
+/// Returns true if scalars are equal
+/// \param[in] left a Scalar
+/// \param[in] right a Scalar
+/// \param[in] options comparison options
+bool ARROW_EXPORT ScalarEquals(const Scalar& left, const Scalar& right,
+ const EqualOptions& options = EqualOptions::Defaults());
+
+/// Returns true if scalars are approximately equal
+/// \param[in] left a Scalar
+/// \param[in] right a Scalar
+/// \param[in] options comparison options
+bool ARROW_EXPORT
+ScalarApproxEquals(const Scalar& left, const Scalar& right,
+ const EqualOptions& options = EqualOptions::Defaults());
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compare_benchmark.cc b/src/arrow/cpp/src/arrow/compare_benchmark.cc
new file mode 100644
index 000000000..2699f90f6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compare_benchmark.cc
@@ -0,0 +1,164 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/array.h"
+#include "arrow/compare.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/benchmark_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+
+constexpr auto kSeed = 0x94378165;
+
+static void BenchmarkArrayRangeEquals(const std::shared_ptr<Array>& array,
+ benchmark::State& state) {
+ const auto left_array = array;
+ // Make sure pointer equality can't be used as a shortcut
+ // (should we would deep-copy child_data and buffers?)
+ const auto right_array =
+ MakeArray(array->data()->Copy())->Slice(1, array->length() - 1);
+
+ for (auto _ : state) {
+ const bool are_ok = ArrayRangeEquals(*left_array, *right_array,
+ /*left_start_idx=*/1,
+ /*left_end_idx=*/array->length() - 2,
+ /*right_start_idx=*/0);
+ if (ARROW_PREDICT_FALSE(!are_ok)) {
+ ARROW_LOG(FATAL) << "Arrays should have compared equal";
+ }
+ }
+}
+
+static void ArrayRangeEqualsInt32(benchmark::State& state) {
+ RegressionArgs args(state, /*size_is_bytes=*/false);
+
+ auto rng = random::RandomArrayGenerator(kSeed);
+ auto array = rng.Int32(args.size, 0, 100, args.null_proportion);
+
+ BenchmarkArrayRangeEquals(array, state);
+}
+
+static void ArrayRangeEqualsFloat32(benchmark::State& state) {
+ RegressionArgs args(state, /*size_is_bytes=*/false);
+
+ auto rng = random::RandomArrayGenerator(kSeed);
+ auto array = rng.Float32(args.size, 0, 100, args.null_proportion);
+
+ BenchmarkArrayRangeEquals(array, state);
+}
+
+static void ArrayRangeEqualsBoolean(benchmark::State& state) {
+ RegressionArgs args(state, /*size_is_bytes=*/false);
+
+ auto rng = random::RandomArrayGenerator(kSeed);
+ auto array = rng.Boolean(args.size, /*true_probability=*/0.3, args.null_proportion);
+
+ BenchmarkArrayRangeEquals(array, state);
+}
+
+static void ArrayRangeEqualsString(benchmark::State& state) {
+ RegressionArgs args(state, /*size_is_bytes=*/false);
+
+ auto rng = random::RandomArrayGenerator(kSeed);
+ auto array =
+ rng.String(args.size, /*min_length=*/0, /*max_length=*/15, args.null_proportion);
+
+ BenchmarkArrayRangeEquals(array, state);
+}
+
+static void ArrayRangeEqualsFixedSizeBinary(benchmark::State& state) {
+ RegressionArgs args(state, /*size_is_bytes=*/false);
+
+ auto rng = random::RandomArrayGenerator(kSeed);
+ auto array = rng.FixedSizeBinary(args.size, /*byte_width=*/8, args.null_proportion);
+
+ BenchmarkArrayRangeEquals(array, state);
+}
+
+static void ArrayRangeEqualsListOfInt32(benchmark::State& state) {
+ RegressionArgs args(state, /*size_is_bytes=*/false);
+
+ auto rng = random::RandomArrayGenerator(kSeed);
+ auto values = rng.Int32(args.size * 10, 0, 100, args.null_proportion);
+ // Force empty list null entries, since it is overwhelmingly the common case.
+ auto array = rng.List(*values, /*size=*/args.size, args.null_proportion,
+ /*force_empty_nulls=*/true);
+
+ BenchmarkArrayRangeEquals(array, state);
+}
+
+static void ArrayRangeEqualsStruct(benchmark::State& state) {
+ // struct<int32, utf8>
+ RegressionArgs args(state, /*size_is_bytes=*/false);
+
+ auto rng = random::RandomArrayGenerator(kSeed);
+ auto values1 = rng.Int32(args.size, 0, 100, args.null_proportion);
+ auto values2 =
+ rng.String(args.size, /*min_length=*/0, /*max_length=*/15, args.null_proportion);
+ auto null_bitmap = rng.NullBitmap(args.size, args.null_proportion);
+ auto array = *StructArray::Make({values1, values2},
+ std::vector<std::string>{"ints", "strs"}, null_bitmap);
+
+ BenchmarkArrayRangeEquals(array, state);
+}
+
+static void ArrayRangeEqualsSparseUnion(benchmark::State& state) {
+ // sparse_union<int32, utf8>
+ RegressionArgs args(state, /*size_is_bytes=*/false);
+
+ auto rng = random::RandomArrayGenerator(kSeed);
+ auto values1 = rng.Int32(args.size, 0, 100, args.null_proportion);
+ auto values2 =
+ rng.String(args.size, /*min_length=*/0, /*max_length=*/15, args.null_proportion);
+ auto array = rng.SparseUnion({values1, values2}, args.size);
+
+ BenchmarkArrayRangeEquals(array, state);
+}
+
+static void ArrayRangeEqualsDenseUnion(benchmark::State& state) {
+ // dense_union<int32, utf8>
+ RegressionArgs args(state, /*size_is_bytes=*/false);
+
+ auto rng = random::RandomArrayGenerator(kSeed);
+ auto values1 = rng.Int32(args.size, 0, 100, args.null_proportion);
+ auto values2 =
+ rng.String(args.size, /*min_length=*/0, /*max_length=*/15, args.null_proportion);
+ auto array = rng.DenseUnion({values1, values2}, args.size);
+
+ BenchmarkArrayRangeEquals(array, state);
+}
+
+BENCHMARK(ArrayRangeEqualsInt32)->Apply(RegressionSetArgs);
+BENCHMARK(ArrayRangeEqualsFloat32)->Apply(RegressionSetArgs);
+BENCHMARK(ArrayRangeEqualsBoolean)->Apply(RegressionSetArgs);
+BENCHMARK(ArrayRangeEqualsString)->Apply(RegressionSetArgs);
+BENCHMARK(ArrayRangeEqualsFixedSizeBinary)->Apply(RegressionSetArgs);
+BENCHMARK(ArrayRangeEqualsListOfInt32)->Apply(RegressionSetArgs);
+BENCHMARK(ArrayRangeEqualsStruct)->Apply(RegressionSetArgs);
+BENCHMARK(ArrayRangeEqualsSparseUnion)->Apply(RegressionSetArgs);
+BENCHMARK(ArrayRangeEqualsDenseUnion)->Apply(RegressionSetArgs);
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/CMakeLists.txt b/src/arrow/cpp/src/arrow/compute/CMakeLists.txt
new file mode 100644
index 000000000..897dc32f3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/CMakeLists.txt
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+add_custom_target(arrow_compute)
+
+arrow_install_all_headers("arrow/compute")
+
+# pkg-config support
+arrow_add_pkg_config("arrow-compute")
+
+#
+# Unit tests
+#
+
+function(ADD_ARROW_COMPUTE_TEST REL_TEST_NAME)
+ set(options)
+ set(one_value_args PREFIX)
+ set(multi_value_args LABELS)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+
+ if(ARG_PREFIX)
+ set(PREFIX ${ARG_PREFIX})
+ else()
+ set(PREFIX "arrow-compute")
+ endif()
+
+ if(ARG_LABELS)
+ set(LABELS ${ARG_LABELS})
+ else()
+ set(LABELS "arrow_compute")
+ endif()
+
+ add_arrow_test(${REL_TEST_NAME}
+ EXTRA_LINK_LIBS
+ ${ARROW_DATASET_TEST_LINK_LIBS}
+ PREFIX
+ ${PREFIX}
+ LABELS
+ ${LABELS}
+ ${ARG_UNPARSED_ARGUMENTS})
+endfunction()
+
+add_arrow_compute_test(internals_test
+ SOURCES
+ function_test.cc
+ exec_test.cc
+ kernel_test.cc
+ registry_test.cc)
+
+add_arrow_benchmark(function_benchmark PREFIX "arrow-compute")
+
+add_subdirectory(kernels)
+
+add_subdirectory(exec)
diff --git a/src/arrow/cpp/src/arrow/compute/README.md b/src/arrow/cpp/src/arrow/compute/README.md
new file mode 100644
index 000000000..80d8918e3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/README.md
@@ -0,0 +1,58 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+## Apache Arrow C++ Compute Functions
+
+This submodule contains analytical functions that process primarily Arrow
+columnar data; some functions can process scalar or Arrow-based array
+inputs. These are intended for use inside query engines, data frame libraries,
+etc.
+
+Many functions have SQL-like semantics in that they perform elementwise or
+scalar operations on whole arrays at a time. Other functions are not SQL-like
+and compute results that may be a different length or whose results depend on
+the order of the values.
+
+Some basic terminology:
+
+* We use the term "function" to refer to particular general operation that may
+ have many different implementations corresponding to different combinations
+ of types or function behavior options.
+* We call a specific implementation of a function a "kernel". When executing a
+ function on inputs, we must first select a suitable kernel (kernel selection
+ is called "dispatching") corresponding to the value types of the inputs
+* Functions along with their kernel implementations are collected in a
+ "function registry". Given a function name and argument types, we can look up
+ that function and dispatch to a compatible kernel.
+
+Types of functions
+
+* Scalar functions: elementwise functions that perform scalar operations in a
+ vectorized manner. These functions are generally valid for SQL-like
+ context. These are called "scalar" in that the functions executed consider
+ each value in an array independently, and the output array or arrays have the
+ same length as the input arrays. The result for each array cell is generally
+ independent of its position in the array.
+* Vector functions, which produce a result whose output is generally dependent
+ on the entire contents of the input arrays. These functions **are generally
+ not valid** for SQL-like processing because the output size may be different
+ than the input size, and the result may change based on the order of the
+ values in the array. This includes things like array subselection, sorting,
+ hashing, and more.
+* Scalar aggregate functions of which can be used in a SQL-like context \ No newline at end of file
diff --git a/src/arrow/cpp/src/arrow/compute/api.h b/src/arrow/cpp/src/arrow/compute/api.h
new file mode 100644
index 000000000..a890cd362
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/api.h
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// NOTE: API is EXPERIMENTAL and will change without going through a
+// deprecation cycle
+
+#pragma once
+
+/// \defgroup compute-concrete-options Concrete option classes for compute functions
+/// @{
+/// @}
+
+#include "arrow/compute/api_aggregate.h" // IWYU pragma: export
+#include "arrow/compute/api_scalar.h" // IWYU pragma: export
+#include "arrow/compute/api_vector.h" // IWYU pragma: export
+#include "arrow/compute/cast.h" // IWYU pragma: export
+#include "arrow/compute/exec.h" // IWYU pragma: export
+#include "arrow/compute/function.h" // IWYU pragma: export
+#include "arrow/compute/kernel.h" // IWYU pragma: export
+#include "arrow/compute/registry.h" // IWYU pragma: export
+#include "arrow/datum.h" // IWYU pragma: export
diff --git a/src/arrow/cpp/src/arrow/compute/api_aggregate.cc b/src/arrow/cpp/src/arrow/compute/api_aggregate.cc
new file mode 100644
index 000000000..8cd3a8d2a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/api_aggregate.cc
@@ -0,0 +1,250 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/api_aggregate.h"
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/function_internal.h"
+#include "arrow/compute/registry.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+namespace internal {
+template <>
+struct EnumTraits<compute::CountOptions::CountMode>
+ : BasicEnumTraits<compute::CountOptions::CountMode, compute::CountOptions::ONLY_VALID,
+ compute::CountOptions::ONLY_NULL, compute::CountOptions::ALL> {
+ static std::string name() { return "CountOptions::CountMode"; }
+ static std::string value_name(compute::CountOptions::CountMode value) {
+ switch (value) {
+ case compute::CountOptions::ONLY_VALID:
+ return "NON_NULL";
+ case compute::CountOptions::ONLY_NULL:
+ return "NULLS";
+ case compute::CountOptions::ALL:
+ return "ALL";
+ }
+ return "<INVALID>";
+ }
+};
+
+template <>
+struct EnumTraits<compute::QuantileOptions::Interpolation>
+ : BasicEnumTraits<compute::QuantileOptions::Interpolation,
+ compute::QuantileOptions::LINEAR, compute::QuantileOptions::LOWER,
+ compute::QuantileOptions::HIGHER, compute::QuantileOptions::NEAREST,
+ compute::QuantileOptions::MIDPOINT> {
+ static std::string name() { return "QuantileOptions::Interpolation"; }
+ static std::string value_name(compute::QuantileOptions::Interpolation value) {
+ switch (value) {
+ case compute::QuantileOptions::LINEAR:
+ return "LINEAR";
+ case compute::QuantileOptions::LOWER:
+ return "LOWER";
+ case compute::QuantileOptions::HIGHER:
+ return "HIGHER";
+ case compute::QuantileOptions::NEAREST:
+ return "NEAREST";
+ case compute::QuantileOptions::MIDPOINT:
+ return "MIDPOINT";
+ }
+ return "<INVALID>";
+ }
+};
+} // namespace internal
+
+namespace compute {
+
+// ----------------------------------------------------------------------
+// Function options
+
+using ::arrow::internal::checked_cast;
+
+namespace internal {
+namespace {
+using ::arrow::internal::DataMember;
+static auto kScalarAggregateOptionsType = GetFunctionOptionsType<ScalarAggregateOptions>(
+ DataMember("skip_nulls", &ScalarAggregateOptions::skip_nulls),
+ DataMember("min_count", &ScalarAggregateOptions::min_count));
+static auto kCountOptionsType =
+ GetFunctionOptionsType<CountOptions>(DataMember("mode", &CountOptions::mode));
+static auto kModeOptionsType = GetFunctionOptionsType<ModeOptions>(
+ DataMember("n", &ModeOptions::n), DataMember("skip_nulls", &ModeOptions::skip_nulls),
+ DataMember("min_count", &ModeOptions::min_count));
+static auto kVarianceOptionsType = GetFunctionOptionsType<VarianceOptions>(
+ DataMember("ddof", &VarianceOptions::ddof),
+ DataMember("skip_nulls", &VarianceOptions::skip_nulls),
+ DataMember("min_count", &VarianceOptions::min_count));
+static auto kQuantileOptionsType = GetFunctionOptionsType<QuantileOptions>(
+ DataMember("q", &QuantileOptions::q),
+ DataMember("interpolation", &QuantileOptions::interpolation),
+ DataMember("skip_nulls", &QuantileOptions::skip_nulls),
+ DataMember("min_count", &QuantileOptions::min_count));
+static auto kTDigestOptionsType = GetFunctionOptionsType<TDigestOptions>(
+ DataMember("q", &TDigestOptions::q), DataMember("delta", &TDigestOptions::delta),
+ DataMember("buffer_size", &TDigestOptions::buffer_size),
+ DataMember("skip_nulls", &TDigestOptions::skip_nulls),
+ DataMember("min_count", &TDigestOptions::min_count));
+static auto kIndexOptionsType =
+ GetFunctionOptionsType<IndexOptions>(DataMember("value", &IndexOptions::value));
+} // namespace
+} // namespace internal
+
+ScalarAggregateOptions::ScalarAggregateOptions(bool skip_nulls, uint32_t min_count)
+ : FunctionOptions(internal::kScalarAggregateOptionsType),
+ skip_nulls(skip_nulls),
+ min_count(min_count) {}
+constexpr char ScalarAggregateOptions::kTypeName[];
+
+CountOptions::CountOptions(CountMode mode)
+ : FunctionOptions(internal::kCountOptionsType), mode(mode) {}
+constexpr char CountOptions::kTypeName[];
+
+ModeOptions::ModeOptions(int64_t n, bool skip_nulls, uint32_t min_count)
+ : FunctionOptions(internal::kModeOptionsType),
+ n{n},
+ skip_nulls{skip_nulls},
+ min_count{min_count} {}
+constexpr char ModeOptions::kTypeName[];
+
+VarianceOptions::VarianceOptions(int ddof, bool skip_nulls, uint32_t min_count)
+ : FunctionOptions(internal::kVarianceOptionsType),
+ ddof(ddof),
+ skip_nulls(skip_nulls),
+ min_count(min_count) {}
+constexpr char VarianceOptions::kTypeName[];
+
+QuantileOptions::QuantileOptions(double q, enum Interpolation interpolation,
+ bool skip_nulls, uint32_t min_count)
+ : FunctionOptions(internal::kQuantileOptionsType),
+ q{q},
+ interpolation{interpolation},
+ skip_nulls{skip_nulls},
+ min_count{min_count} {}
+QuantileOptions::QuantileOptions(std::vector<double> q, enum Interpolation interpolation,
+ bool skip_nulls, uint32_t min_count)
+ : FunctionOptions(internal::kQuantileOptionsType),
+ q{std::move(q)},
+ interpolation{interpolation},
+ skip_nulls{skip_nulls},
+ min_count{min_count} {}
+constexpr char QuantileOptions::kTypeName[];
+
+TDigestOptions::TDigestOptions(double q, uint32_t delta, uint32_t buffer_size,
+ bool skip_nulls, uint32_t min_count)
+ : FunctionOptions(internal::kTDigestOptionsType),
+ q{q},
+ delta{delta},
+ buffer_size{buffer_size},
+ skip_nulls{skip_nulls},
+ min_count{min_count} {}
+TDigestOptions::TDigestOptions(std::vector<double> q, uint32_t delta,
+ uint32_t buffer_size, bool skip_nulls, uint32_t min_count)
+ : FunctionOptions(internal::kTDigestOptionsType),
+ q{std::move(q)},
+ delta{delta},
+ buffer_size{buffer_size},
+ skip_nulls{skip_nulls},
+ min_count{min_count} {}
+constexpr char TDigestOptions::kTypeName[];
+
+IndexOptions::IndexOptions(std::shared_ptr<Scalar> value)
+ : FunctionOptions(internal::kIndexOptionsType), value{std::move(value)} {}
+IndexOptions::IndexOptions() : IndexOptions(std::make_shared<NullScalar>()) {}
+constexpr char IndexOptions::kTypeName[];
+
+namespace internal {
+void RegisterAggregateOptions(FunctionRegistry* registry) {
+ DCHECK_OK(registry->AddFunctionOptionsType(kScalarAggregateOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kCountOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kModeOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kVarianceOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kQuantileOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kTDigestOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kIndexOptionsType));
+}
+} // namespace internal
+
+// ----------------------------------------------------------------------
+// Scalar aggregates
+
+Result<Datum> Count(const Datum& value, const CountOptions& options, ExecContext* ctx) {
+ return CallFunction("count", {value}, &options, ctx);
+}
+
+Result<Datum> Mean(const Datum& value, const ScalarAggregateOptions& options,
+ ExecContext* ctx) {
+ return CallFunction("mean", {value}, &options, ctx);
+}
+
+Result<Datum> Product(const Datum& value, const ScalarAggregateOptions& options,
+ ExecContext* ctx) {
+ return CallFunction("product", {value}, &options, ctx);
+}
+
+Result<Datum> Sum(const Datum& value, const ScalarAggregateOptions& options,
+ ExecContext* ctx) {
+ return CallFunction("sum", {value}, &options, ctx);
+}
+
+Result<Datum> MinMax(const Datum& value, const ScalarAggregateOptions& options,
+ ExecContext* ctx) {
+ return CallFunction("min_max", {value}, &options, ctx);
+}
+
+Result<Datum> Any(const Datum& value, const ScalarAggregateOptions& options,
+ ExecContext* ctx) {
+ return CallFunction("any", {value}, &options, ctx);
+}
+
+Result<Datum> All(const Datum& value, const ScalarAggregateOptions& options,
+ ExecContext* ctx) {
+ return CallFunction("all", {value}, &options, ctx);
+}
+
+Result<Datum> Mode(const Datum& value, const ModeOptions& options, ExecContext* ctx) {
+ return CallFunction("mode", {value}, &options, ctx);
+}
+
+Result<Datum> Stddev(const Datum& value, const VarianceOptions& options,
+ ExecContext* ctx) {
+ return CallFunction("stddev", {value}, &options, ctx);
+}
+
+Result<Datum> Variance(const Datum& value, const VarianceOptions& options,
+ ExecContext* ctx) {
+ return CallFunction("variance", {value}, &options, ctx);
+}
+
+Result<Datum> Quantile(const Datum& value, const QuantileOptions& options,
+ ExecContext* ctx) {
+ return CallFunction("quantile", {value}, &options, ctx);
+}
+
+Result<Datum> TDigest(const Datum& value, const TDigestOptions& options,
+ ExecContext* ctx) {
+ return CallFunction("tdigest", {value}, &options, ctx);
+}
+
+Result<Datum> Index(const Datum& value, const IndexOptions& options, ExecContext* ctx) {
+ return CallFunction("index", {value}, &options, ctx);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/api_aggregate.h b/src/arrow/cpp/src/arrow/compute/api_aggregate.h
new file mode 100644
index 000000000..c8df81773
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/api_aggregate.h
@@ -0,0 +1,494 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Eager evaluation convenience APIs for invoking common functions, including
+// necessary memory allocations
+
+#pragma once
+
+#include "arrow/compute/function.h"
+#include "arrow/datum.h"
+#include "arrow/result.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Array;
+
+namespace compute {
+
+class ExecContext;
+
+// ----------------------------------------------------------------------
+// Aggregate functions
+
+/// \addtogroup compute-concrete-options
+/// @{
+
+/// \brief Control general scalar aggregate kernel behavior
+///
+/// By default, null values are ignored (skip_nulls = true).
+class ARROW_EXPORT ScalarAggregateOptions : public FunctionOptions {
+ public:
+ explicit ScalarAggregateOptions(bool skip_nulls = true, uint32_t min_count = 1);
+ constexpr static char const kTypeName[] = "ScalarAggregateOptions";
+ static ScalarAggregateOptions Defaults() { return ScalarAggregateOptions{}; }
+
+ /// If true (the default), null values are ignored. Otherwise, if any value is null,
+ /// emit null.
+ bool skip_nulls;
+ /// If less than this many non-null values are observed, emit null.
+ uint32_t min_count;
+};
+
+/// \brief Control count aggregate kernel behavior.
+///
+/// By default, only non-null values are counted.
+class ARROW_EXPORT CountOptions : public FunctionOptions {
+ public:
+ enum CountMode {
+ /// Count only non-null values.
+ ONLY_VALID = 0,
+ /// Count only null values.
+ ONLY_NULL,
+ /// Count both non-null and null values.
+ ALL,
+ };
+ explicit CountOptions(CountMode mode = CountMode::ONLY_VALID);
+ constexpr static char const kTypeName[] = "CountOptions";
+ static CountOptions Defaults() { return CountOptions{}; }
+
+ CountMode mode;
+};
+
+/// \brief Control Mode kernel behavior
+///
+/// Returns top-n common values and counts.
+/// By default, returns the most common value and count.
+class ARROW_EXPORT ModeOptions : public FunctionOptions {
+ public:
+ explicit ModeOptions(int64_t n = 1, bool skip_nulls = true, uint32_t min_count = 0);
+ constexpr static char const kTypeName[] = "ModeOptions";
+ static ModeOptions Defaults() { return ModeOptions{}; }
+
+ int64_t n = 1;
+ /// If true (the default), null values are ignored. Otherwise, if any value is null,
+ /// emit null.
+ bool skip_nulls;
+ /// If less than this many non-null values are observed, emit null.
+ uint32_t min_count;
+};
+
+/// \brief Control Delta Degrees of Freedom (ddof) of Variance and Stddev kernel
+///
+/// The divisor used in calculations is N - ddof, where N is the number of elements.
+/// By default, ddof is zero, and population variance or stddev is returned.
+class ARROW_EXPORT VarianceOptions : public FunctionOptions {
+ public:
+ explicit VarianceOptions(int ddof = 0, bool skip_nulls = true, uint32_t min_count = 0);
+ constexpr static char const kTypeName[] = "VarianceOptions";
+ static VarianceOptions Defaults() { return VarianceOptions{}; }
+
+ int ddof = 0;
+ /// If true (the default), null values are ignored. Otherwise, if any value is null,
+ /// emit null.
+ bool skip_nulls;
+ /// If less than this many non-null values are observed, emit null.
+ uint32_t min_count;
+};
+
+/// \brief Control Quantile kernel behavior
+///
+/// By default, returns the median value.
+class ARROW_EXPORT QuantileOptions : public FunctionOptions {
+ public:
+ /// Interpolation method to use when quantile lies between two data points
+ enum Interpolation {
+ LINEAR = 0,
+ LOWER,
+ HIGHER,
+ NEAREST,
+ MIDPOINT,
+ };
+
+ explicit QuantileOptions(double q = 0.5, enum Interpolation interpolation = LINEAR,
+ bool skip_nulls = true, uint32_t min_count = 0);
+
+ explicit QuantileOptions(std::vector<double> q,
+ enum Interpolation interpolation = LINEAR,
+ bool skip_nulls = true, uint32_t min_count = 0);
+
+ constexpr static char const kTypeName[] = "QuantileOptions";
+ static QuantileOptions Defaults() { return QuantileOptions{}; }
+
+ /// quantile must be between 0 and 1 inclusive
+ std::vector<double> q;
+ enum Interpolation interpolation;
+ /// If true (the default), null values are ignored. Otherwise, if any value is null,
+ /// emit null.
+ bool skip_nulls;
+ /// If less than this many non-null values are observed, emit null.
+ uint32_t min_count;
+};
+
+/// \brief Control TDigest approximate quantile kernel behavior
+///
+/// By default, returns the median value.
+class ARROW_EXPORT TDigestOptions : public FunctionOptions {
+ public:
+ explicit TDigestOptions(double q = 0.5, uint32_t delta = 100,
+ uint32_t buffer_size = 500, bool skip_nulls = true,
+ uint32_t min_count = 0);
+ explicit TDigestOptions(std::vector<double> q, uint32_t delta = 100,
+ uint32_t buffer_size = 500, bool skip_nulls = true,
+ uint32_t min_count = 0);
+ constexpr static char const kTypeName[] = "TDigestOptions";
+ static TDigestOptions Defaults() { return TDigestOptions{}; }
+
+ /// quantile must be between 0 and 1 inclusive
+ std::vector<double> q;
+ /// compression parameter, default 100
+ uint32_t delta;
+ /// input buffer size, default 500
+ uint32_t buffer_size;
+ /// If true (the default), null values are ignored. Otherwise, if any value is null,
+ /// emit null.
+ bool skip_nulls;
+ /// If less than this many non-null values are observed, emit null.
+ uint32_t min_count;
+};
+
+/// \brief Control Index kernel behavior
+class ARROW_EXPORT IndexOptions : public FunctionOptions {
+ public:
+ explicit IndexOptions(std::shared_ptr<Scalar> value);
+ // Default constructor for serialization
+ IndexOptions();
+ constexpr static char const kTypeName[] = "IndexOptions";
+
+ std::shared_ptr<Scalar> value;
+};
+
+/// @}
+
+/// \brief Count values in an array.
+///
+/// \param[in] options counting options, see CountOptions for more information
+/// \param[in] datum to count
+/// \param[in] ctx the function execution context, optional
+/// \return out resulting datum
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Count(const Datum& datum,
+ const CountOptions& options = CountOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Compute the mean of a numeric array.
+///
+/// \param[in] value datum to compute the mean, expecting Array
+/// \param[in] options see ScalarAggregateOptions for more information
+/// \param[in] ctx the function execution context, optional
+/// \return datum of the computed mean as a DoubleScalar
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Mean(
+ const Datum& value,
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Compute the product of values of a numeric array.
+///
+/// \param[in] value datum to compute product of, expecting Array or ChunkedArray
+/// \param[in] options see ScalarAggregateOptions for more information
+/// \param[in] ctx the function execution context, optional
+/// \return datum of the computed sum as a Scalar
+///
+/// \since 6.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Product(
+ const Datum& value,
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Sum values of a numeric array.
+///
+/// \param[in] value datum to sum, expecting Array or ChunkedArray
+/// \param[in] options see ScalarAggregateOptions for more information
+/// \param[in] ctx the function execution context, optional
+/// \return datum of the computed sum as a Scalar
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Sum(
+ const Datum& value,
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Calculate the min / max of a numeric array
+///
+/// This function returns both the min and max as a struct scalar, with type
+/// struct<min: T, max: T>, where T is the input type
+///
+/// \param[in] value input datum, expecting Array or ChunkedArray
+/// \param[in] options see ScalarAggregateOptions for more information
+/// \param[in] ctx the function execution context, optional
+/// \return resulting datum as a struct<min: T, max: T> scalar
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> MinMax(
+ const Datum& value,
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Test whether any element in a boolean array evaluates to true.
+///
+/// This function returns true if any of the elements in the array evaluates
+/// to true and false otherwise. Null values are ignored by default.
+/// If null values are taken into account by setting ScalarAggregateOptions
+/// parameter skip_nulls = false then Kleene logic is used.
+/// See KleeneOr for more details on Kleene logic.
+///
+/// \param[in] value input datum, expecting a boolean array
+/// \param[in] options see ScalarAggregateOptions for more information
+/// \param[in] ctx the function execution context, optional
+/// \return resulting datum as a BooleanScalar
+///
+/// \since 3.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Any(
+ const Datum& value,
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Test whether all elements in a boolean array evaluate to true.
+///
+/// This function returns true if all of the elements in the array evaluate
+/// to true and false otherwise. Null values are ignored by default.
+/// If null values are taken into account by setting ScalarAggregateOptions
+/// parameter skip_nulls = false then Kleene logic is used.
+/// See KleeneAnd for more details on Kleene logic.
+///
+/// \param[in] value input datum, expecting a boolean array
+/// \param[in] options see ScalarAggregateOptions for more information
+/// \param[in] ctx the function execution context, optional
+/// \return resulting datum as a BooleanScalar
+
+/// \since 3.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> All(
+ const Datum& value,
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Calculate the modal (most common) value of a numeric array
+///
+/// This function returns top-n most common values and number of times they occur as
+/// an array of `struct<mode: T, count: int64>`, where T is the input type.
+/// Values with larger counts are returned before smaller ones.
+/// If there are more than one values with same count, smaller value is returned first.
+///
+/// \param[in] value input datum, expecting Array or ChunkedArray
+/// \param[in] options see ModeOptions for more information
+/// \param[in] ctx the function execution context, optional
+/// \return resulting datum as an array of struct<mode: T, count: int64>
+///
+/// \since 2.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Mode(const Datum& value,
+ const ModeOptions& options = ModeOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Calculate the standard deviation of a numeric array
+///
+/// \param[in] value input datum, expecting Array or ChunkedArray
+/// \param[in] options see VarianceOptions for more information
+/// \param[in] ctx the function execution context, optional
+/// \return datum of the computed standard deviation as a DoubleScalar
+///
+/// \since 2.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Stddev(const Datum& value,
+ const VarianceOptions& options = VarianceOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Calculate the variance of a numeric array
+///
+/// \param[in] value input datum, expecting Array or ChunkedArray
+/// \param[in] options see VarianceOptions for more information
+/// \param[in] ctx the function execution context, optional
+/// \return datum of the computed variance as a DoubleScalar
+///
+/// \since 2.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Variance(const Datum& value,
+ const VarianceOptions& options = VarianceOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Calculate the quantiles of a numeric array
+///
+/// \param[in] value input datum, expecting Array or ChunkedArray
+/// \param[in] options see QuantileOptions for more information
+/// \param[in] ctx the function execution context, optional
+/// \return resulting datum as an array
+///
+/// \since 4.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Quantile(const Datum& value,
+ const QuantileOptions& options = QuantileOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Calculate the approximate quantiles of a numeric array with T-Digest algorithm
+///
+/// \param[in] value input datum, expecting Array or ChunkedArray
+/// \param[in] options see TDigestOptions for more information
+/// \param[in] ctx the function execution context, optional
+/// \return resulting datum as an array
+///
+/// \since 4.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> TDigest(const Datum& value,
+ const TDigestOptions& options = TDigestOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Find the first index of a value in an array.
+///
+/// \param[in] value The array to search.
+/// \param[in] options The array to search for. See IndexOoptions.
+/// \param[in] ctx the function execution context, optional
+/// \return out a Scalar containing the index (or -1 if not found).
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Index(const Datum& value, const IndexOptions& options,
+ ExecContext* ctx = NULLPTR);
+
+namespace internal {
+
+/// Internal use only: streaming group identifier.
+/// Consumes batches of keys and yields batches of the group ids.
+class ARROW_EXPORT Grouper {
+ public:
+ virtual ~Grouper() = default;
+
+ /// Construct a Grouper which receives the specified key types
+ static Result<std::unique_ptr<Grouper>> Make(const std::vector<ValueDescr>& descrs,
+ ExecContext* ctx = default_exec_context());
+
+ /// Consume a batch of keys, producing the corresponding group ids as an integer array.
+ /// Currently only uint32 indices will be produced, eventually the bit width will only
+ /// be as wide as necessary.
+ virtual Result<Datum> Consume(const ExecBatch& batch) = 0;
+
+ /// Get current unique keys. May be called multiple times.
+ virtual Result<ExecBatch> GetUniques() = 0;
+
+ /// Get the current number of groups.
+ virtual uint32_t num_groups() const = 0;
+
+ /// \brief Assemble lists of indices of identical elements.
+ ///
+ /// \param[in] ids An unsigned, all-valid integral array which will be
+ /// used as grouping criteria.
+ /// \param[in] num_groups An upper bound for the elements of ids
+ /// \return A num_groups-long ListArray where the slot at i contains a
+ /// list of indices where i appears in ids.
+ ///
+ /// MakeGroupings([
+ /// 2,
+ /// 2,
+ /// 5,
+ /// 5,
+ /// 2,
+ /// 3
+ /// ], 8) == [
+ /// [],
+ /// [],
+ /// [0, 1, 4],
+ /// [5],
+ /// [],
+ /// [2, 3],
+ /// [],
+ /// []
+ /// ]
+ static Result<std::shared_ptr<ListArray>> MakeGroupings(
+ const UInt32Array& ids, uint32_t num_groups,
+ ExecContext* ctx = default_exec_context());
+
+ /// \brief Produce a ListArray whose slots are selections of `array` which correspond to
+ /// the provided groupings.
+ ///
+ /// For example,
+ /// ApplyGroupings([
+ /// [],
+ /// [],
+ /// [0, 1, 4],
+ /// [5],
+ /// [],
+ /// [2, 3],
+ /// [],
+ /// []
+ /// ], [2, 2, 5, 5, 2, 3]) == [
+ /// [],
+ /// [],
+ /// [2, 2, 2],
+ /// [3],
+ /// [],
+ /// [5, 5],
+ /// [],
+ /// []
+ /// ]
+ static Result<std::shared_ptr<ListArray>> ApplyGroupings(
+ const ListArray& groupings, const Array& array,
+ ExecContext* ctx = default_exec_context());
+};
+
+/// \brief Configure a grouped aggregation
+struct ARROW_EXPORT Aggregate {
+ /// the name of the aggregation function
+ std::string function;
+
+ /// options for the aggregation function
+ const FunctionOptions* options;
+};
+
+/// Internal use only: helper function for testing HashAggregateKernels.
+/// This will be replaced by streaming execution operators.
+ARROW_EXPORT
+Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Datum>& keys,
+ const std::vector<Aggregate>& aggregates, bool use_threads = false,
+ ExecContext* ctx = default_exec_context());
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/api_scalar.cc b/src/arrow/cpp/src/arrow/compute/api_scalar.cc
new file mode 100644
index 000000000..e3fe1bdf7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/api_scalar.cc
@@ -0,0 +1,676 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/api_scalar.h"
+
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include "arrow/array/array_base.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/function_internal.h"
+#include "arrow/compute/registry.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+namespace internal {
+template <>
+struct EnumTraits<compute::JoinOptions::NullHandlingBehavior>
+ : BasicEnumTraits<compute::JoinOptions::NullHandlingBehavior,
+ compute::JoinOptions::NullHandlingBehavior::EMIT_NULL,
+ compute::JoinOptions::NullHandlingBehavior::SKIP,
+ compute::JoinOptions::NullHandlingBehavior::REPLACE> {
+ static std::string name() { return "JoinOptions::NullHandlingBehavior"; }
+ static std::string value_name(compute::JoinOptions::NullHandlingBehavior value) {
+ switch (value) {
+ case compute::JoinOptions::NullHandlingBehavior::EMIT_NULL:
+ return "EMIT_NULL";
+ case compute::JoinOptions::NullHandlingBehavior::SKIP:
+ return "SKIP";
+ case compute::JoinOptions::NullHandlingBehavior::REPLACE:
+ return "REPLACE";
+ }
+ return "<INVALID>";
+ }
+};
+
+template <>
+struct EnumTraits<TimeUnit::type>
+ : BasicEnumTraits<TimeUnit::type, TimeUnit::type::SECOND, TimeUnit::type::MILLI,
+ TimeUnit::type::MICRO, TimeUnit::type::NANO> {
+ static std::string name() { return "TimeUnit::type"; }
+ static std::string value_name(TimeUnit::type value) {
+ switch (value) {
+ case TimeUnit::type::SECOND:
+ return "SECOND";
+ case TimeUnit::type::MILLI:
+ return "MILLI";
+ case TimeUnit::type::MICRO:
+ return "MICRO";
+ case TimeUnit::type::NANO:
+ return "NANO";
+ }
+ return "<INVALID>";
+ }
+};
+
+template <>
+struct EnumTraits<compute::CompareOperator>
+ : BasicEnumTraits<
+ compute::CompareOperator, compute::CompareOperator::EQUAL,
+ compute::CompareOperator::NOT_EQUAL, compute::CompareOperator::GREATER,
+ compute::CompareOperator::GREATER_EQUAL, compute::CompareOperator::LESS,
+ compute::CompareOperator::LESS_EQUAL> {
+ static std::string name() { return "compute::CompareOperator"; }
+ static std::string value_name(compute::CompareOperator value) {
+ switch (value) {
+ case compute::CompareOperator::EQUAL:
+ return "EQUAL";
+ case compute::CompareOperator::NOT_EQUAL:
+ return "NOT_EQUAL";
+ case compute::CompareOperator::GREATER:
+ return "GREATER";
+ case compute::CompareOperator::GREATER_EQUAL:
+ return "GREATER_EQUAL";
+ case compute::CompareOperator::LESS:
+ return "LESS";
+ case compute::CompareOperator::LESS_EQUAL:
+ return "LESS_EQUAL";
+ }
+ return "<INVALID>";
+ }
+};
+template <>
+struct EnumTraits<compute::AssumeTimezoneOptions::Ambiguous>
+ : BasicEnumTraits<compute::AssumeTimezoneOptions::Ambiguous,
+ compute::AssumeTimezoneOptions::Ambiguous::AMBIGUOUS_RAISE,
+ compute::AssumeTimezoneOptions::Ambiguous::AMBIGUOUS_EARLIEST,
+ compute::AssumeTimezoneOptions::Ambiguous::AMBIGUOUS_LATEST> {
+ static std::string name() { return "AssumeTimezoneOptions::Ambiguous"; }
+ static std::string value_name(compute::AssumeTimezoneOptions::Ambiguous value) {
+ switch (value) {
+ case compute::AssumeTimezoneOptions::Ambiguous::AMBIGUOUS_RAISE:
+ return "AMBIGUOUS_RAISE";
+ case compute::AssumeTimezoneOptions::Ambiguous::AMBIGUOUS_EARLIEST:
+ return "AMBIGUOUS_EARLIEST";
+ case compute::AssumeTimezoneOptions::Ambiguous::AMBIGUOUS_LATEST:
+ return "AMBIGUOUS_LATEST";
+ }
+ return "<INVALID>";
+ }
+};
+template <>
+struct EnumTraits<compute::AssumeTimezoneOptions::Nonexistent>
+ : BasicEnumTraits<compute::AssumeTimezoneOptions::Nonexistent,
+ compute::AssumeTimezoneOptions::Nonexistent::NONEXISTENT_RAISE,
+ compute::AssumeTimezoneOptions::Nonexistent::NONEXISTENT_EARLIEST,
+ compute::AssumeTimezoneOptions::Nonexistent::NONEXISTENT_LATEST> {
+ static std::string name() { return "AssumeTimezoneOptions::Nonexistent"; }
+ static std::string value_name(compute::AssumeTimezoneOptions::Nonexistent value) {
+ switch (value) {
+ case compute::AssumeTimezoneOptions::Nonexistent::NONEXISTENT_RAISE:
+ return "NONEXISTENT_RAISE";
+ case compute::AssumeTimezoneOptions::Nonexistent::NONEXISTENT_EARLIEST:
+ return "NONEXISTENT_EARLIEST";
+ case compute::AssumeTimezoneOptions::Nonexistent::NONEXISTENT_LATEST:
+ return "NONEXISTENT_LATEST";
+ }
+ return "<INVALID>";
+ }
+};
+
+template <>
+struct EnumTraits<compute::RoundMode>
+ : BasicEnumTraits<compute::RoundMode, compute::RoundMode::DOWN,
+ compute::RoundMode::UP, compute::RoundMode::TOWARDS_ZERO,
+ compute::RoundMode::TOWARDS_INFINITY, compute::RoundMode::HALF_DOWN,
+ compute::RoundMode::HALF_UP, compute::RoundMode::HALF_TOWARDS_ZERO,
+ compute::RoundMode::HALF_TOWARDS_INFINITY,
+ compute::RoundMode::HALF_TO_EVEN, compute::RoundMode::HALF_TO_ODD> {
+ static std::string name() { return "compute::RoundMode"; }
+ static std::string value_name(compute::RoundMode value) {
+ switch (value) {
+ case compute::RoundMode::DOWN:
+ return "DOWN";
+ case compute::RoundMode::UP:
+ return "UP";
+ case compute::RoundMode::TOWARDS_ZERO:
+ return "TOWARDS_ZERO";
+ case compute::RoundMode::TOWARDS_INFINITY:
+ return "TOWARDS_INFINITY";
+ case compute::RoundMode::HALF_DOWN:
+ return "HALF_DOWN";
+ case compute::RoundMode::HALF_UP:
+ return "HALF_UP";
+ case compute::RoundMode::HALF_TOWARDS_ZERO:
+ return "HALF_TOWARDS_ZERO";
+ case compute::RoundMode::HALF_TOWARDS_INFINITY:
+ return "HALF_TOWARDS_INFINITY";
+ case compute::RoundMode::HALF_TO_EVEN:
+ return "HALF_TO_EVEN";
+ case compute::RoundMode::HALF_TO_ODD:
+ return "HALF_TO_ODD";
+ }
+ return "<INVALID>";
+ }
+};
+} // namespace internal
+
+namespace compute {
+
+// ----------------------------------------------------------------------
+// Function options
+
+using ::arrow::internal::checked_cast;
+
+namespace internal {
+namespace {
+using ::arrow::internal::DataMember;
+static auto kArithmeticOptionsType = GetFunctionOptionsType<ArithmeticOptions>(
+ DataMember("check_overflow", &ArithmeticOptions::check_overflow));
+static auto kElementWiseAggregateOptionsType =
+ GetFunctionOptionsType<ElementWiseAggregateOptions>(
+ DataMember("skip_nulls", &ElementWiseAggregateOptions::skip_nulls));
+static auto kRoundOptionsType = GetFunctionOptionsType<RoundOptions>(
+ DataMember("ndigits", &RoundOptions::ndigits),
+ DataMember("round_mode", &RoundOptions::round_mode));
+static auto kRoundToMultipleOptionsType = GetFunctionOptionsType<RoundToMultipleOptions>(
+ DataMember("multiple", &RoundToMultipleOptions::multiple),
+ DataMember("round_mode", &RoundToMultipleOptions::round_mode));
+static auto kJoinOptionsType = GetFunctionOptionsType<JoinOptions>(
+ DataMember("null_handling", &JoinOptions::null_handling),
+ DataMember("null_replacement", &JoinOptions::null_replacement));
+static auto kMatchSubstringOptionsType = GetFunctionOptionsType<MatchSubstringOptions>(
+ DataMember("pattern", &MatchSubstringOptions::pattern),
+ DataMember("ignore_case", &MatchSubstringOptions::ignore_case));
+static auto kSplitOptionsType = GetFunctionOptionsType<SplitOptions>(
+ DataMember("max_splits", &SplitOptions::max_splits),
+ DataMember("reverse", &SplitOptions::reverse));
+static auto kSplitPatternOptionsType = GetFunctionOptionsType<SplitPatternOptions>(
+ DataMember("pattern", &SplitPatternOptions::pattern),
+ DataMember("max_splits", &SplitPatternOptions::max_splits),
+ DataMember("reverse", &SplitPatternOptions::reverse));
+static auto kReplaceSliceOptionsType = GetFunctionOptionsType<ReplaceSliceOptions>(
+ DataMember("start", &ReplaceSliceOptions::start),
+ DataMember("stop", &ReplaceSliceOptions::stop),
+ DataMember("replacement", &ReplaceSliceOptions::replacement));
+static auto kReplaceSubstringOptionsType =
+ GetFunctionOptionsType<ReplaceSubstringOptions>(
+ DataMember("pattern", &ReplaceSubstringOptions::pattern),
+ DataMember("replacement", &ReplaceSubstringOptions::replacement),
+ DataMember("max_replacements", &ReplaceSubstringOptions::max_replacements));
+static auto kExtractRegexOptionsType = GetFunctionOptionsType<ExtractRegexOptions>(
+ DataMember("pattern", &ExtractRegexOptions::pattern));
+static auto kSetLookupOptionsType = GetFunctionOptionsType<SetLookupOptions>(
+ DataMember("value_set", &SetLookupOptions::value_set),
+ DataMember("skip_nulls", &SetLookupOptions::skip_nulls));
+static auto kStrptimeOptionsType = GetFunctionOptionsType<StrptimeOptions>(
+ DataMember("format", &StrptimeOptions::format),
+ DataMember("unit", &StrptimeOptions::unit));
+static auto kStrftimeOptionsType = GetFunctionOptionsType<StrftimeOptions>(
+ DataMember("format", &StrftimeOptions::format));
+static auto kAssumeTimezoneOptionsType = GetFunctionOptionsType<AssumeTimezoneOptions>(
+ DataMember("timezone", &AssumeTimezoneOptions::timezone),
+ DataMember("ambiguous", &AssumeTimezoneOptions::ambiguous),
+ DataMember("nonexistent", &AssumeTimezoneOptions::nonexistent));
+static auto kPadOptionsType = GetFunctionOptionsType<PadOptions>(
+ DataMember("width", &PadOptions::width), DataMember("padding", &PadOptions::padding));
+static auto kTrimOptionsType = GetFunctionOptionsType<TrimOptions>(
+ DataMember("characters", &TrimOptions::characters));
+static auto kSliceOptionsType = GetFunctionOptionsType<SliceOptions>(
+ DataMember("start", &SliceOptions::start), DataMember("stop", &SliceOptions::stop),
+ DataMember("step", &SliceOptions::step));
+static auto kMakeStructOptionsType = GetFunctionOptionsType<MakeStructOptions>(
+ DataMember("field_names", &MakeStructOptions::field_names),
+ DataMember("field_nullability", &MakeStructOptions::field_nullability),
+ DataMember("field_metadata", &MakeStructOptions::field_metadata));
+static auto kDayOfWeekOptionsType = GetFunctionOptionsType<DayOfWeekOptions>(
+ DataMember("count_from_zero", &DayOfWeekOptions::count_from_zero),
+ DataMember("week_start", &DayOfWeekOptions::week_start));
+static auto kWeekOptionsType = GetFunctionOptionsType<WeekOptions>(
+ DataMember("week_starts_monday", &WeekOptions::week_starts_monday),
+ DataMember("count_from_zero", &WeekOptions::count_from_zero),
+ DataMember("first_week_is_fully_in_year", &WeekOptions::first_week_is_fully_in_year));
+static auto kNullOptionsType = GetFunctionOptionsType<NullOptions>(
+ DataMember("nan_is_null", &NullOptions::nan_is_null));
+} // namespace
+} // namespace internal
+
+ArithmeticOptions::ArithmeticOptions(bool check_overflow)
+ : FunctionOptions(internal::kArithmeticOptionsType), check_overflow(check_overflow) {}
+constexpr char ArithmeticOptions::kTypeName[];
+
+ElementWiseAggregateOptions::ElementWiseAggregateOptions(bool skip_nulls)
+ : FunctionOptions(internal::kElementWiseAggregateOptionsType),
+ skip_nulls(skip_nulls) {}
+constexpr char ElementWiseAggregateOptions::kTypeName[];
+
+RoundOptions::RoundOptions(int64_t ndigits, RoundMode round_mode)
+ : FunctionOptions(internal::kRoundOptionsType),
+ ndigits(ndigits),
+ round_mode(round_mode) {
+ static_assert(RoundMode::HALF_DOWN > RoundMode::DOWN &&
+ RoundMode::HALF_DOWN > RoundMode::UP &&
+ RoundMode::HALF_DOWN > RoundMode::TOWARDS_ZERO &&
+ RoundMode::HALF_DOWN > RoundMode::TOWARDS_INFINITY &&
+ RoundMode::HALF_DOWN < RoundMode::HALF_UP &&
+ RoundMode::HALF_DOWN < RoundMode::HALF_TOWARDS_ZERO &&
+ RoundMode::HALF_DOWN < RoundMode::HALF_TOWARDS_INFINITY &&
+ RoundMode::HALF_DOWN < RoundMode::HALF_TO_EVEN &&
+ RoundMode::HALF_DOWN < RoundMode::HALF_TO_ODD,
+ "Invalid order of round modes. Modes prefixed with HALF need to be "
+ "enumerated last with HALF_DOWN being the first among them.");
+}
+constexpr char RoundOptions::kTypeName[];
+
+RoundToMultipleOptions::RoundToMultipleOptions(double multiple, RoundMode round_mode)
+ : RoundToMultipleOptions(std::make_shared<DoubleScalar>(multiple), round_mode) {}
+RoundToMultipleOptions::RoundToMultipleOptions(std::shared_ptr<Scalar> multiple,
+ RoundMode round_mode)
+ : FunctionOptions(internal::kRoundToMultipleOptionsType),
+ multiple(std::move(multiple)),
+ round_mode(round_mode) {}
+constexpr char RoundToMultipleOptions::kTypeName[];
+
+JoinOptions::JoinOptions(NullHandlingBehavior null_handling, std::string null_replacement)
+ : FunctionOptions(internal::kJoinOptionsType),
+ null_handling(null_handling),
+ null_replacement(std::move(null_replacement)) {}
+constexpr char JoinOptions::kTypeName[];
+
+MatchSubstringOptions::MatchSubstringOptions(std::string pattern, bool ignore_case)
+ : FunctionOptions(internal::kMatchSubstringOptionsType),
+ pattern(std::move(pattern)),
+ ignore_case(ignore_case) {}
+MatchSubstringOptions::MatchSubstringOptions() : MatchSubstringOptions("", false) {}
+constexpr char MatchSubstringOptions::kTypeName[];
+
+SplitOptions::SplitOptions(int64_t max_splits, bool reverse)
+ : FunctionOptions(internal::kSplitOptionsType),
+ max_splits(max_splits),
+ reverse(reverse) {}
+constexpr char SplitOptions::kTypeName[];
+
+SplitPatternOptions::SplitPatternOptions(std::string pattern, int64_t max_splits,
+ bool reverse)
+ : FunctionOptions(internal::kSplitPatternOptionsType),
+ pattern(std::move(pattern)),
+ max_splits(max_splits),
+ reverse(reverse) {}
+SplitPatternOptions::SplitPatternOptions() : SplitPatternOptions("", -1, false) {}
+constexpr char SplitPatternOptions::kTypeName[];
+
+ReplaceSliceOptions::ReplaceSliceOptions(int64_t start, int64_t stop,
+ std::string replacement)
+ : FunctionOptions(internal::kReplaceSliceOptionsType),
+ start(start),
+ stop(stop),
+ replacement(std::move(replacement)) {}
+ReplaceSliceOptions::ReplaceSliceOptions() : ReplaceSliceOptions(0, 0, "") {}
+constexpr char ReplaceSliceOptions::kTypeName[];
+
+ReplaceSubstringOptions::ReplaceSubstringOptions(std::string pattern,
+ std::string replacement,
+ int64_t max_replacements)
+ : FunctionOptions(internal::kReplaceSubstringOptionsType),
+ pattern(std::move(pattern)),
+ replacement(std::move(replacement)),
+ max_replacements(max_replacements) {}
+ReplaceSubstringOptions::ReplaceSubstringOptions()
+ : ReplaceSubstringOptions("", "", -1) {}
+constexpr char ReplaceSubstringOptions::kTypeName[];
+
+ExtractRegexOptions::ExtractRegexOptions(std::string pattern)
+ : FunctionOptions(internal::kExtractRegexOptionsType), pattern(std::move(pattern)) {}
+ExtractRegexOptions::ExtractRegexOptions() : ExtractRegexOptions("") {}
+constexpr char ExtractRegexOptions::kTypeName[];
+
+SetLookupOptions::SetLookupOptions(Datum value_set, bool skip_nulls)
+ : FunctionOptions(internal::kSetLookupOptionsType),
+ value_set(std::move(value_set)),
+ skip_nulls(skip_nulls) {}
+SetLookupOptions::SetLookupOptions() : SetLookupOptions({}, false) {}
+constexpr char SetLookupOptions::kTypeName[];
+
+StrptimeOptions::StrptimeOptions(std::string format, TimeUnit::type unit)
+ : FunctionOptions(internal::kStrptimeOptionsType),
+ format(std::move(format)),
+ unit(unit) {}
+StrptimeOptions::StrptimeOptions() : StrptimeOptions("", TimeUnit::SECOND) {}
+constexpr char StrptimeOptions::kTypeName[];
+
+StrftimeOptions::StrftimeOptions(std::string format, std::string locale)
+ : FunctionOptions(internal::kStrftimeOptionsType),
+ format(std::move(format)),
+ locale(std::move(locale)) {}
+StrftimeOptions::StrftimeOptions() : StrftimeOptions(kDefaultFormat) {}
+constexpr char StrftimeOptions::kTypeName[];
+constexpr const char* StrftimeOptions::kDefaultFormat;
+
+AssumeTimezoneOptions::AssumeTimezoneOptions(std::string timezone, Ambiguous ambiguous,
+ Nonexistent nonexistent)
+ : FunctionOptions(internal::kAssumeTimezoneOptionsType),
+ timezone(std::move(timezone)),
+ ambiguous(ambiguous),
+ nonexistent(nonexistent) {}
+AssumeTimezoneOptions::AssumeTimezoneOptions() : AssumeTimezoneOptions("UTC") {}
+constexpr char AssumeTimezoneOptions::kTypeName[];
+
+PadOptions::PadOptions(int64_t width, std::string padding)
+ : FunctionOptions(internal::kPadOptionsType),
+ width(width),
+ padding(std::move(padding)) {}
+PadOptions::PadOptions() : PadOptions(0, " ") {}
+constexpr char PadOptions::kTypeName[];
+
+TrimOptions::TrimOptions(std::string characters)
+ : FunctionOptions(internal::kTrimOptionsType), characters(std::move(characters)) {}
+TrimOptions::TrimOptions() : TrimOptions("") {}
+constexpr char TrimOptions::kTypeName[];
+
+SliceOptions::SliceOptions(int64_t start, int64_t stop, int64_t step)
+ : FunctionOptions(internal::kSliceOptionsType),
+ start(start),
+ stop(stop),
+ step(step) {}
+SliceOptions::SliceOptions() : SliceOptions(0, 0, 1) {}
+constexpr char SliceOptions::kTypeName[];
+
+MakeStructOptions::MakeStructOptions(
+ std::vector<std::string> n, std::vector<bool> r,
+ std::vector<std::shared_ptr<const KeyValueMetadata>> m)
+ : FunctionOptions(internal::kMakeStructOptionsType),
+ field_names(std::move(n)),
+ field_nullability(std::move(r)),
+ field_metadata(std::move(m)) {}
+
+MakeStructOptions::MakeStructOptions(std::vector<std::string> n)
+ : FunctionOptions(internal::kMakeStructOptionsType),
+ field_names(std::move(n)),
+ field_nullability(field_names.size(), true),
+ field_metadata(field_names.size(), NULLPTR) {}
+
+MakeStructOptions::MakeStructOptions() : MakeStructOptions(std::vector<std::string>()) {}
+constexpr char MakeStructOptions::kTypeName[];
+
+DayOfWeekOptions::DayOfWeekOptions(bool count_from_zero, uint32_t week_start)
+ : FunctionOptions(internal::kDayOfWeekOptionsType),
+ count_from_zero(count_from_zero),
+ week_start(week_start) {}
+constexpr char DayOfWeekOptions::kTypeName[];
+
+WeekOptions::WeekOptions(bool week_starts_monday, bool count_from_zero,
+ bool first_week_is_fully_in_year)
+ : FunctionOptions(internal::kWeekOptionsType),
+ week_starts_monday(week_starts_monday),
+ count_from_zero(count_from_zero),
+ first_week_is_fully_in_year(first_week_is_fully_in_year) {}
+constexpr char WeekOptions::kTypeName[];
+
+NullOptions::NullOptions(bool nan_is_null)
+ : FunctionOptions(internal::kNullOptionsType), nan_is_null(nan_is_null) {}
+constexpr char NullOptions::kTypeName[];
+
+namespace internal {
+void RegisterScalarOptions(FunctionRegistry* registry) {
+ DCHECK_OK(registry->AddFunctionOptionsType(kArithmeticOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kElementWiseAggregateOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kRoundOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kRoundToMultipleOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kJoinOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kMatchSubstringOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kSplitOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kSplitPatternOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kReplaceSliceOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kReplaceSubstringOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kSetLookupOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kStrptimeOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kStrftimeOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kAssumeTimezoneOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kPadOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kTrimOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kSliceOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kMakeStructOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kDayOfWeekOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kWeekOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kNullOptionsType));
+}
+} // namespace internal
+
+#define SCALAR_EAGER_UNARY(NAME, REGISTRY_NAME) \
+ Result<Datum> NAME(const Datum& value, ExecContext* ctx) { \
+ return CallFunction(REGISTRY_NAME, {value}, ctx); \
+ }
+
+#define SCALAR_EAGER_BINARY(NAME, REGISTRY_NAME) \
+ Result<Datum> NAME(const Datum& left, const Datum& right, ExecContext* ctx) { \
+ return CallFunction(REGISTRY_NAME, {left, right}, ctx); \
+ }
+
+// ----------------------------------------------------------------------
+// Arithmetic
+
+#define SCALAR_ARITHMETIC_UNARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME) \
+ Result<Datum> NAME(const Datum& arg, ArithmeticOptions options, ExecContext* ctx) { \
+ auto func_name = (options.check_overflow) ? REGISTRY_CHECKED_NAME : REGISTRY_NAME; \
+ return CallFunction(func_name, {arg}, ctx); \
+ }
+
+SCALAR_ARITHMETIC_UNARY(AbsoluteValue, "abs", "abs_checked")
+SCALAR_ARITHMETIC_UNARY(Negate, "negate", "negate_checked")
+SCALAR_EAGER_UNARY(Sign, "sign")
+SCALAR_ARITHMETIC_UNARY(Sin, "sin", "sin_checked")
+SCALAR_ARITHMETIC_UNARY(Cos, "cos", "cos_checked")
+SCALAR_ARITHMETIC_UNARY(Asin, "asin", "asin_checked")
+SCALAR_ARITHMETIC_UNARY(Acos, "acos", "acos_checked")
+SCALAR_ARITHMETIC_UNARY(Tan, "tan", "tan_checked")
+SCALAR_EAGER_UNARY(Atan, "atan")
+SCALAR_ARITHMETIC_UNARY(Ln, "ln", "ln_checked")
+SCALAR_ARITHMETIC_UNARY(Log10, "log10", "log10_checked")
+SCALAR_ARITHMETIC_UNARY(Log2, "log2", "log2_checked")
+SCALAR_ARITHMETIC_UNARY(Log1p, "log1p", "log1p_checked")
+
+Result<Datum> Round(const Datum& arg, RoundOptions options, ExecContext* ctx) {
+ return CallFunction("round", {arg}, &options, ctx);
+}
+
+Result<Datum> RoundToMultiple(const Datum& arg, RoundToMultipleOptions options,
+ ExecContext* ctx) {
+ return CallFunction("round_to_multiple", {arg}, &options, ctx);
+}
+
+#define SCALAR_ARITHMETIC_BINARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME) \
+ Result<Datum> NAME(const Datum& left, const Datum& right, ArithmeticOptions options, \
+ ExecContext* ctx) { \
+ auto func_name = (options.check_overflow) ? REGISTRY_CHECKED_NAME : REGISTRY_NAME; \
+ return CallFunction(func_name, {left, right}, ctx); \
+ }
+
+SCALAR_ARITHMETIC_BINARY(Add, "add", "add_checked")
+SCALAR_ARITHMETIC_BINARY(Subtract, "subtract", "subtract_checked")
+SCALAR_ARITHMETIC_BINARY(Multiply, "multiply", "multiply_checked")
+SCALAR_ARITHMETIC_BINARY(Divide, "divide", "divide_checked")
+SCALAR_ARITHMETIC_BINARY(Power, "power", "power_checked")
+SCALAR_ARITHMETIC_BINARY(ShiftLeft, "shift_left", "shift_left_checked")
+SCALAR_ARITHMETIC_BINARY(ShiftRight, "shift_right", "shift_right_checked")
+SCALAR_ARITHMETIC_BINARY(Logb, "logb", "logb_checked")
+SCALAR_EAGER_BINARY(Atan2, "atan2")
+SCALAR_EAGER_UNARY(Floor, "floor")
+SCALAR_EAGER_UNARY(Ceil, "ceil")
+SCALAR_EAGER_UNARY(Trunc, "trunc")
+
+Result<Datum> MaxElementWise(const std::vector<Datum>& args,
+ ElementWiseAggregateOptions options, ExecContext* ctx) {
+ return CallFunction("max_element_wise", args, &options, ctx);
+}
+
+Result<Datum> MinElementWise(const std::vector<Datum>& args,
+ ElementWiseAggregateOptions options, ExecContext* ctx) {
+ return CallFunction("min_element_wise", args, &options, ctx);
+}
+
+// ----------------------------------------------------------------------
+// Set-related operations
+
+static Result<Datum> ExecSetLookup(const std::string& func_name, const Datum& data,
+ const SetLookupOptions& options, ExecContext* ctx) {
+ if (!options.value_set.is_arraylike()) {
+ return Status::Invalid("Set lookup value set must be Array or ChunkedArray");
+ }
+ std::shared_ptr<DataType> data_type;
+ if (data.type()->id() == Type::DICTIONARY) {
+ data_type =
+ arrow::internal::checked_pointer_cast<DictionaryType>(data.type())->value_type();
+ } else {
+ data_type = data.type();
+ }
+
+ if (options.value_set.length() > 0 && !data_type->Equals(options.value_set.type())) {
+ std::stringstream ss;
+ ss << "Array type didn't match type of values set: " << data_type->ToString()
+ << " vs " << options.value_set.type()->ToString();
+ return Status::Invalid(ss.str());
+ }
+ return CallFunction(func_name, {data}, &options, ctx);
+}
+
+Result<Datum> IsIn(const Datum& values, const SetLookupOptions& options,
+ ExecContext* ctx) {
+ return ExecSetLookup("is_in", values, options, ctx);
+}
+
+Result<Datum> IsIn(const Datum& values, const Datum& value_set, ExecContext* ctx) {
+ return ExecSetLookup("is_in", values, SetLookupOptions{value_set}, ctx);
+}
+
+Result<Datum> IndexIn(const Datum& values, const SetLookupOptions& options,
+ ExecContext* ctx) {
+ return ExecSetLookup("index_in", values, options, ctx);
+}
+
+Result<Datum> IndexIn(const Datum& values, const Datum& value_set, ExecContext* ctx) {
+ return ExecSetLookup("index_in", values, SetLookupOptions{value_set}, ctx);
+}
+
+// ----------------------------------------------------------------------
+// Boolean functions
+
+SCALAR_EAGER_UNARY(Invert, "invert")
+SCALAR_EAGER_BINARY(And, "and")
+SCALAR_EAGER_BINARY(KleeneAnd, "and_kleene")
+SCALAR_EAGER_BINARY(Or, "or")
+SCALAR_EAGER_BINARY(KleeneOr, "or_kleene")
+SCALAR_EAGER_BINARY(Xor, "xor")
+SCALAR_EAGER_BINARY(AndNot, "and_not")
+SCALAR_EAGER_BINARY(KleeneAndNot, "and_not_kleene")
+
+// ----------------------------------------------------------------------
+
+Result<Datum> Compare(const Datum& left, const Datum& right, CompareOptions options,
+ ExecContext* ctx) {
+ std::string func_name;
+ switch (options.op) {
+ case CompareOperator::EQUAL:
+ func_name = "equal";
+ break;
+ case CompareOperator::NOT_EQUAL:
+ func_name = "not_equal";
+ break;
+ case CompareOperator::GREATER:
+ func_name = "greater";
+ break;
+ case CompareOperator::GREATER_EQUAL:
+ func_name = "greater_equal";
+ break;
+ case CompareOperator::LESS:
+ func_name = "less";
+ break;
+ case CompareOperator::LESS_EQUAL:
+ func_name = "less_equal";
+ break;
+ }
+ return CallFunction(func_name, {left, right}, nullptr, ctx);
+}
+
+// ----------------------------------------------------------------------
+// Validity functions
+
+SCALAR_EAGER_UNARY(IsValid, "is_valid")
+SCALAR_EAGER_UNARY(IsNan, "is_nan")
+
+Result<Datum> IfElse(const Datum& cond, const Datum& if_true, const Datum& if_false,
+ ExecContext* ctx) {
+ return CallFunction("if_else", {cond, if_true, if_false}, ctx);
+}
+
+Result<Datum> CaseWhen(const Datum& cond, const std::vector<Datum>& cases,
+ ExecContext* ctx) {
+ std::vector<Datum> args = {cond};
+ args.reserve(cases.size() + 1);
+ args.insert(args.end(), cases.begin(), cases.end());
+ return CallFunction("case_when", args, ctx);
+}
+
+Result<Datum> IsNull(const Datum& arg, NullOptions options, ExecContext* ctx) {
+ return CallFunction("is_null", {arg}, &options, ctx);
+}
+
+// ----------------------------------------------------------------------
+// Temporal functions
+
+SCALAR_EAGER_UNARY(Year, "year")
+SCALAR_EAGER_UNARY(Month, "month")
+SCALAR_EAGER_UNARY(Day, "day")
+SCALAR_EAGER_UNARY(DayOfYear, "day_of_year")
+SCALAR_EAGER_UNARY(ISOYear, "iso_year")
+SCALAR_EAGER_UNARY(ISOWeek, "iso_week")
+SCALAR_EAGER_UNARY(USWeek, "us_week")
+SCALAR_EAGER_UNARY(ISOCalendar, "iso_calendar")
+SCALAR_EAGER_UNARY(Quarter, "quarter")
+SCALAR_EAGER_UNARY(Hour, "hour")
+SCALAR_EAGER_UNARY(Minute, "minute")
+SCALAR_EAGER_UNARY(Second, "second")
+SCALAR_EAGER_UNARY(Millisecond, "millisecond")
+SCALAR_EAGER_UNARY(Microsecond, "microsecond")
+SCALAR_EAGER_UNARY(Nanosecond, "nanosecond")
+SCALAR_EAGER_UNARY(Subsecond, "subsecond")
+
+Result<Datum> DayOfWeek(const Datum& arg, DayOfWeekOptions options, ExecContext* ctx) {
+ return CallFunction("day_of_week", {arg}, &options, ctx);
+}
+
+Result<Datum> AssumeTimezone(const Datum& arg, AssumeTimezoneOptions options,
+ ExecContext* ctx) {
+ return CallFunction("assume_timezone", {arg}, &options, ctx);
+}
+
+Result<Datum> Week(const Datum& arg, WeekOptions options, ExecContext* ctx) {
+ return CallFunction("week", {arg}, &options, ctx);
+}
+
+Result<Datum> Strftime(const Datum& arg, StrftimeOptions options, ExecContext* ctx) {
+ return CallFunction("strftime", {arg}, &options, ctx);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/api_scalar.h b/src/arrow/cpp/src/arrow/compute/api_scalar.h
new file mode 100644
index 000000000..4bb18b375
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/api_scalar.h
@@ -0,0 +1,1219 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Eager evaluation convenience APIs for invoking common functions, including
+// necessary memory allocations
+
+#pragma once
+
+#include <string>
+#include <utility>
+
+#include "arrow/compute/function.h"
+#include "arrow/compute/type_fwd.h"
+#include "arrow/datum.h"
+#include "arrow/result.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace compute {
+
+/// \addtogroup compute-concrete-options
+///
+/// @{
+
+class ARROW_EXPORT ArithmeticOptions : public FunctionOptions {
+ public:
+ explicit ArithmeticOptions(bool check_overflow = false);
+ constexpr static char const kTypeName[] = "ArithmeticOptions";
+ bool check_overflow;
+};
+
+class ARROW_EXPORT ElementWiseAggregateOptions : public FunctionOptions {
+ public:
+ explicit ElementWiseAggregateOptions(bool skip_nulls = true);
+ constexpr static char const kTypeName[] = "ElementWiseAggregateOptions";
+ static ElementWiseAggregateOptions Defaults() { return ElementWiseAggregateOptions{}; }
+ bool skip_nulls;
+};
+
+/// Rounding and tie-breaking modes for round compute functions.
+/// Additional details and examples are provided in compute.rst.
+enum class RoundMode : int8_t {
+ /// Round to nearest integer less than or equal in magnitude (aka "floor")
+ DOWN,
+ /// Round to nearest integer greater than or equal in magnitude (aka "ceil")
+ UP,
+ /// Get the integral part without fractional digits (aka "trunc")
+ TOWARDS_ZERO,
+ /// Round negative values with DOWN rule and positive values with UP rule
+ TOWARDS_INFINITY,
+ /// Round ties with DOWN rule
+ HALF_DOWN,
+ /// Round ties with UP rule
+ HALF_UP,
+ /// Round ties with TOWARDS_ZERO rule
+ HALF_TOWARDS_ZERO,
+ /// Round ties with TOWARDS_INFINITY rule
+ HALF_TOWARDS_INFINITY,
+ /// Round ties to nearest even integer
+ HALF_TO_EVEN,
+ /// Round ties to nearest odd integer
+ HALF_TO_ODD,
+};
+
+class ARROW_EXPORT RoundOptions : public FunctionOptions {
+ public:
+ explicit RoundOptions(int64_t ndigits = 0,
+ RoundMode round_mode = RoundMode::HALF_TO_EVEN);
+ constexpr static char const kTypeName[] = "RoundOptions";
+ static RoundOptions Defaults() { return RoundOptions(); }
+ /// Rounding precision (number of digits to round to)
+ int64_t ndigits;
+ /// Rounding and tie-breaking mode
+ RoundMode round_mode;
+};
+
+class ARROW_EXPORT RoundToMultipleOptions : public FunctionOptions {
+ public:
+ explicit RoundToMultipleOptions(double multiple = 1.0,
+ RoundMode round_mode = RoundMode::HALF_TO_EVEN);
+ explicit RoundToMultipleOptions(std::shared_ptr<Scalar> multiple,
+ RoundMode round_mode = RoundMode::HALF_TO_EVEN);
+ constexpr static char const kTypeName[] = "RoundToMultipleOptions";
+ static RoundToMultipleOptions Defaults() { return RoundToMultipleOptions(); }
+ /// Rounding scale (multiple to round to).
+ ///
+ /// Should be a scalar of a type compatible with the argument to be rounded.
+ /// For example, rounding a decimal value means a decimal multiple is
+ /// required. Rounding a floating point or integer value means a floating
+ /// point scalar is required.
+ std::shared_ptr<Scalar> multiple;
+ /// Rounding and tie-breaking mode
+ RoundMode round_mode;
+};
+
+/// Options for var_args_join.
+class ARROW_EXPORT JoinOptions : public FunctionOptions {
+ public:
+ /// How to handle null values. (A null separator always results in a null output.)
+ enum NullHandlingBehavior {
+ /// A null in any input results in a null in the output.
+ EMIT_NULL,
+ /// Nulls in inputs are skipped.
+ SKIP,
+ /// Nulls in inputs are replaced with the replacement string.
+ REPLACE,
+ };
+ explicit JoinOptions(NullHandlingBehavior null_handling = EMIT_NULL,
+ std::string null_replacement = "");
+ constexpr static char const kTypeName[] = "JoinOptions";
+ static JoinOptions Defaults() { return JoinOptions(); }
+ NullHandlingBehavior null_handling;
+ std::string null_replacement;
+};
+
+class ARROW_EXPORT MatchSubstringOptions : public FunctionOptions {
+ public:
+ explicit MatchSubstringOptions(std::string pattern, bool ignore_case = false);
+ MatchSubstringOptions();
+ constexpr static char const kTypeName[] = "MatchSubstringOptions";
+
+ /// The exact substring (or regex, depending on kernel) to look for inside input values.
+ std::string pattern;
+ /// Whether to perform a case-insensitive match.
+ bool ignore_case;
+};
+
+class ARROW_EXPORT SplitOptions : public FunctionOptions {
+ public:
+ explicit SplitOptions(int64_t max_splits = -1, bool reverse = false);
+ constexpr static char const kTypeName[] = "SplitOptions";
+
+ /// Maximum number of splits allowed, or unlimited when -1
+ int64_t max_splits;
+ /// Start splitting from the end of the string (only relevant when max_splits != -1)
+ bool reverse;
+};
+
+class ARROW_EXPORT SplitPatternOptions : public FunctionOptions {
+ public:
+ explicit SplitPatternOptions(std::string pattern, int64_t max_splits = -1,
+ bool reverse = false);
+ SplitPatternOptions();
+ constexpr static char const kTypeName[] = "SplitPatternOptions";
+
+ /// The exact substring to split on.
+ std::string pattern;
+ /// Maximum number of splits allowed, or unlimited when -1
+ int64_t max_splits;
+ /// Start splitting from the end of the string (only relevant when max_splits != -1)
+ bool reverse;
+};
+
+class ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions {
+ public:
+ explicit ReplaceSliceOptions(int64_t start, int64_t stop, std::string replacement);
+ ReplaceSliceOptions();
+ constexpr static char const kTypeName[] = "ReplaceSliceOptions";
+
+ /// Index to start slicing at
+ int64_t start;
+ /// Index to stop slicing at
+ int64_t stop;
+ /// String to replace the slice with
+ std::string replacement;
+};
+
+class ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions {
+ public:
+ explicit ReplaceSubstringOptions(std::string pattern, std::string replacement,
+ int64_t max_replacements = -1);
+ ReplaceSubstringOptions();
+ constexpr static char const kTypeName[] = "ReplaceSubstringOptions";
+
+ /// Pattern to match, literal, or regular expression depending on which kernel is used
+ std::string pattern;
+ /// String to replace the pattern with
+ std::string replacement;
+ /// Max number of substrings to replace (-1 means unbounded)
+ int64_t max_replacements;
+};
+
+class ARROW_EXPORT ExtractRegexOptions : public FunctionOptions {
+ public:
+ explicit ExtractRegexOptions(std::string pattern);
+ ExtractRegexOptions();
+ constexpr static char const kTypeName[] = "ExtractRegexOptions";
+
+ /// Regular expression with named capture fields
+ std::string pattern;
+};
+
+/// Options for IsIn and IndexIn functions
+class ARROW_EXPORT SetLookupOptions : public FunctionOptions {
+ public:
+ explicit SetLookupOptions(Datum value_set, bool skip_nulls = false);
+ SetLookupOptions();
+ constexpr static char const kTypeName[] = "SetLookupOptions";
+
+ /// The set of values to look up input values into.
+ Datum value_set;
+ /// Whether nulls in `value_set` count for lookup.
+ ///
+ /// If true, any null in `value_set` is ignored and nulls in the input
+ /// produce null (IndexIn) or false (IsIn) values in the output.
+ /// If false, any null in `value_set` is successfully matched in
+ /// the input.
+ bool skip_nulls;
+};
+
+class ARROW_EXPORT StrptimeOptions : public FunctionOptions {
+ public:
+ explicit StrptimeOptions(std::string format, TimeUnit::type unit);
+ StrptimeOptions();
+ constexpr static char const kTypeName[] = "StrptimeOptions";
+
+ std::string format;
+ TimeUnit::type unit;
+};
+
+class ARROW_EXPORT StrftimeOptions : public FunctionOptions {
+ public:
+ explicit StrftimeOptions(std::string format, std::string locale = "C");
+ StrftimeOptions();
+
+ constexpr static char const kTypeName[] = "StrftimeOptions";
+
+ constexpr static const char* kDefaultFormat = "%Y-%m-%dT%H:%M:%S";
+
+ /// The desired format string.
+ std::string format;
+ /// The desired output locale string.
+ std::string locale;
+};
+
+class ARROW_EXPORT PadOptions : public FunctionOptions {
+ public:
+ explicit PadOptions(int64_t width, std::string padding = " ");
+ PadOptions();
+ constexpr static char const kTypeName[] = "PadOptions";
+
+ /// The desired string length.
+ int64_t width;
+ /// What to pad the string with. Should be one codepoint (Unicode)/byte (ASCII).
+ std::string padding;
+};
+
+class ARROW_EXPORT TrimOptions : public FunctionOptions {
+ public:
+ explicit TrimOptions(std::string characters);
+ TrimOptions();
+ constexpr static char const kTypeName[] = "TrimOptions";
+
+ /// The individual characters that can be trimmed from the string.
+ std::string characters;
+};
+
+class ARROW_EXPORT SliceOptions : public FunctionOptions {
+ public:
+ explicit SliceOptions(int64_t start, int64_t stop = std::numeric_limits<int64_t>::max(),
+ int64_t step = 1);
+ SliceOptions();
+ constexpr static char const kTypeName[] = "SliceOptions";
+ int64_t start, stop, step;
+};
+
+class ARROW_EXPORT NullOptions : public FunctionOptions {
+ public:
+ explicit NullOptions(bool nan_is_null = false);
+ constexpr static char const kTypeName[] = "NullOptions";
+ static NullOptions Defaults() { return NullOptions{}; }
+
+ bool nan_is_null;
+};
+
+enum CompareOperator : int8_t {
+ EQUAL,
+ NOT_EQUAL,
+ GREATER,
+ GREATER_EQUAL,
+ LESS,
+ LESS_EQUAL,
+};
+
+struct ARROW_EXPORT CompareOptions {
+ explicit CompareOptions(CompareOperator op) : op(op) {}
+ CompareOptions() : CompareOptions(CompareOperator::EQUAL) {}
+ enum CompareOperator op;
+};
+
+class ARROW_EXPORT MakeStructOptions : public FunctionOptions {
+ public:
+ MakeStructOptions(std::vector<std::string> n, std::vector<bool> r,
+ std::vector<std::shared_ptr<const KeyValueMetadata>> m);
+ explicit MakeStructOptions(std::vector<std::string> n);
+ MakeStructOptions();
+ constexpr static char const kTypeName[] = "MakeStructOptions";
+
+ /// Names for wrapped columns
+ std::vector<std::string> field_names;
+
+ /// Nullability bits for wrapped columns
+ std::vector<bool> field_nullability;
+
+ /// Metadata attached to wrapped columns
+ std::vector<std::shared_ptr<const KeyValueMetadata>> field_metadata;
+};
+
+struct ARROW_EXPORT DayOfWeekOptions : public FunctionOptions {
+ public:
+ explicit DayOfWeekOptions(bool count_from_zero = true, uint32_t week_start = 1);
+ constexpr static char const kTypeName[] = "DayOfWeekOptions";
+ static DayOfWeekOptions Defaults() { return DayOfWeekOptions(); }
+
+ /// Number days from 0 if true and from 1 if false
+ bool count_from_zero;
+ /// What day does the week start with (Monday=1, Sunday=7).
+ /// The numbering is unaffected by the count_from_zero parameter.
+ uint32_t week_start;
+};
+
+/// Used to control timestamp timezone conversion and handling ambiguous/nonexistent
+/// times.
+struct ARROW_EXPORT AssumeTimezoneOptions : public FunctionOptions {
+ public:
+ /// \brief How to interpret ambiguous local times that can be interpreted as
+ /// multiple instants (normally two) due to DST shifts.
+ ///
+ /// AMBIGUOUS_EARLIEST emits the earliest instant amongst possible interpretations.
+ /// AMBIGUOUS_LATEST emits the latest instant amongst possible interpretations.
+ enum Ambiguous { AMBIGUOUS_RAISE, AMBIGUOUS_EARLIEST, AMBIGUOUS_LATEST };
+
+ /// \brief How to handle local times that do not exist due to DST shifts.
+ ///
+ /// NONEXISTENT_EARLIEST emits the instant "just before" the DST shift instant
+ /// in the given timestamp precision (for example, for a nanoseconds precision
+ /// timestamp, this is one nanosecond before the DST shift instant).
+ /// NONEXISTENT_LATEST emits the DST shift instant.
+ enum Nonexistent { NONEXISTENT_RAISE, NONEXISTENT_EARLIEST, NONEXISTENT_LATEST };
+
+ explicit AssumeTimezoneOptions(std::string timezone,
+ Ambiguous ambiguous = AMBIGUOUS_RAISE,
+ Nonexistent nonexistent = NONEXISTENT_RAISE);
+ AssumeTimezoneOptions();
+ constexpr static char const kTypeName[] = "AssumeTimezoneOptions";
+
+ /// Timezone to convert timestamps from
+ std::string timezone;
+
+ /// How to interpret ambiguous local times (due to DST shifts)
+ Ambiguous ambiguous;
+ /// How to interpret non-existent local times (due to DST shifts)
+ Nonexistent nonexistent;
+};
+
+struct ARROW_EXPORT WeekOptions : public FunctionOptions {
+ public:
+ explicit WeekOptions(bool week_starts_monday = true, bool count_from_zero = false,
+ bool first_week_is_fully_in_year = false);
+ constexpr static char const kTypeName[] = "WeekOptions";
+ static WeekOptions Defaults() { return WeekOptions{}; }
+ static WeekOptions ISODefaults() {
+ return WeekOptions{/*week_starts_monday*/ true,
+ /*count_from_zero=*/false,
+ /*first_week_is_fully_in_year=*/false};
+ }
+ static WeekOptions USDefaults() {
+ return WeekOptions{/*week_starts_monday*/ false,
+ /*count_from_zero=*/false,
+ /*first_week_is_fully_in_year=*/false};
+ }
+
+ /// What day does the week start with (Monday=true, Sunday=false)
+ bool week_starts_monday;
+ /// Dates from current year that fall into last ISO week of the previous year return
+ /// 0 if true and 52 or 53 if false.
+ bool count_from_zero;
+ /// Must the first week be fully in January (true), or is a week that begins on
+ /// December 29, 30, or 31 considered to be the first week of the new year (false)?
+ bool first_week_is_fully_in_year;
+};
+
+/// @}
+
+/// \brief Get the absolute value of a value.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg the value transformed
+/// \param[in] options arithmetic options (overflow handling), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise absolute value
+ARROW_EXPORT
+Result<Datum> AbsoluteValue(const Datum& arg,
+ ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Add two values together. Array values must be the same length. If
+/// either addend is null the result will be null.
+///
+/// \param[in] left the first addend
+/// \param[in] right the second addend
+/// \param[in] options arithmetic options (overflow handling), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise sum
+ARROW_EXPORT
+Result<Datum> Add(const Datum& left, const Datum& right,
+ ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Subtract two values. Array values must be the same length. If the
+/// minuend or subtrahend is null the result will be null.
+///
+/// \param[in] left the value subtracted from (minuend)
+/// \param[in] right the value by which the minuend is reduced (subtrahend)
+/// \param[in] options arithmetic options (overflow handling), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise difference
+ARROW_EXPORT
+Result<Datum> Subtract(const Datum& left, const Datum& right,
+ ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Multiply two values. Array values must be the same length. If either
+/// factor is null the result will be null.
+///
+/// \param[in] left the first factor
+/// \param[in] right the second factor
+/// \param[in] options arithmetic options (overflow handling), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise product
+ARROW_EXPORT
+Result<Datum> Multiply(const Datum& left, const Datum& right,
+ ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Divide two values. Array values must be the same length. If either
+/// argument is null the result will be null. For integer types, if there is
+/// a zero divisor, an error will be raised.
+///
+/// \param[in] left the dividend
+/// \param[in] right the divisor
+/// \param[in] options arithmetic options (enable/disable overflow checking), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise quotient
+ARROW_EXPORT
+Result<Datum> Divide(const Datum& left, const Datum& right,
+ ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Negate values.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg the value negated
+/// \param[in] options arithmetic options (overflow handling), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise negation
+ARROW_EXPORT
+Result<Datum> Negate(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Raise the values of base array to the power of the exponent array values.
+/// Array values must be the same length. If either base or exponent is null the result
+/// will be null.
+///
+/// \param[in] left the base
+/// \param[in] right the exponent
+/// \param[in] options arithmetic options (enable/disable overflow checking), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise base value raised to the power of exponent
+ARROW_EXPORT
+Result<Datum> Power(const Datum& left, const Datum& right,
+ ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Left shift the left array by the right array. Array values must be the
+/// same length. If either operand is null, the result will be null.
+///
+/// \param[in] left the value to shift
+/// \param[in] right the value to shift by
+/// \param[in] options arithmetic options (enable/disable overflow checking), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise left value shifted left by the right value
+ARROW_EXPORT
+Result<Datum> ShiftLeft(const Datum& left, const Datum& right,
+ ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Right shift the left array by the right array. Array values must be the
+/// same length. If either operand is null, the result will be null. Performs a
+/// logical shift for unsigned values, and an arithmetic shift for signed values.
+///
+/// \param[in] left the value to shift
+/// \param[in] right the value to shift by
+/// \param[in] options arithmetic options (enable/disable overflow checking), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise left value shifted right by the right value
+ARROW_EXPORT
+Result<Datum> ShiftRight(const Datum& left, const Datum& right,
+ ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Compute the sine of the array values.
+/// \param[in] arg The values to compute the sine for.
+/// \param[in] options arithmetic options (enable/disable overflow checking), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise sine of the values
+ARROW_EXPORT
+Result<Datum> Sin(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Compute the cosine of the array values.
+/// \param[in] arg The values to compute the cosine for.
+/// \param[in] options arithmetic options (enable/disable overflow checking), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise cosine of the values
+ARROW_EXPORT
+Result<Datum> Cos(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Compute the inverse sine (arcsine) of the array values.
+/// \param[in] arg The values to compute the inverse sine for.
+/// \param[in] options arithmetic options (enable/disable overflow checking), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise inverse sine of the values
+ARROW_EXPORT
+Result<Datum> Asin(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Compute the inverse cosine (arccosine) of the array values.
+/// \param[in] arg The values to compute the inverse cosine for.
+/// \param[in] options arithmetic options (enable/disable overflow checking), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise inverse cosine of the values
+ARROW_EXPORT
+Result<Datum> Acos(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Compute the tangent of the array values.
+/// \param[in] arg The values to compute the tangent for.
+/// \param[in] options arithmetic options (enable/disable overflow checking), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise tangent of the values
+ARROW_EXPORT
+Result<Datum> Tan(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Compute the inverse tangent (arctangent) of the array values.
+/// \param[in] arg The values to compute the inverse tangent for.
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise inverse tangent of the values
+ARROW_EXPORT
+Result<Datum> Atan(const Datum& arg, ExecContext* ctx = NULLPTR);
+
+/// \brief Compute the inverse tangent (arctangent) of y/x, using the
+/// argument signs to determine the correct quadrant.
+/// \param[in] y The y-values to compute the inverse tangent for.
+/// \param[in] x The x-values to compute the inverse tangent for.
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise inverse tangent of the values
+ARROW_EXPORT
+Result<Datum> Atan2(const Datum& y, const Datum& x, ExecContext* ctx = NULLPTR);
+
+/// \brief Get the natural log of a value.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg The values to compute the logarithm for.
+/// \param[in] options arithmetic options (overflow handling), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise natural log
+ARROW_EXPORT
+Result<Datum> Ln(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Get the log base 10 of a value.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg The values to compute the logarithm for.
+/// \param[in] options arithmetic options (overflow handling), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise log base 10
+ARROW_EXPORT
+Result<Datum> Log10(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Get the log base 2 of a value.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg The values to compute the logarithm for.
+/// \param[in] options arithmetic options (overflow handling), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise log base 2
+ARROW_EXPORT
+Result<Datum> Log2(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Get the natural log of (1 + value).
+///
+/// If argument is null the result will be null.
+/// This function may be more accurate than Log(1 + value) for values close to zero.
+///
+/// \param[in] arg The values to compute the logarithm for.
+/// \param[in] options arithmetic options (overflow handling), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise natural log
+ARROW_EXPORT
+Result<Datum> Log1p(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Get the log of a value to the given base.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg The values to compute the logarithm for.
+/// \param[in] base The given base.
+/// \param[in] options arithmetic options (overflow handling), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise log to the given base
+ARROW_EXPORT
+Result<Datum> Logb(const Datum& arg, const Datum& base,
+ ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Round to the nearest integer less than or equal in magnitude to the
+/// argument.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg the value to round
+/// \param[in] ctx the function execution context, optional
+/// \return the rounded value
+ARROW_EXPORT
+Result<Datum> Floor(const Datum& arg, ExecContext* ctx = NULLPTR);
+
+/// \brief Round to the nearest integer greater than or equal in magnitude to the
+/// argument.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg the value to round
+/// \param[in] ctx the function execution context, optional
+/// \return the rounded value
+ARROW_EXPORT
+Result<Datum> Ceil(const Datum& arg, ExecContext* ctx = NULLPTR);
+
+/// \brief Get the integral part without fractional digits.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg the value to truncate
+/// \param[in] ctx the function execution context, optional
+/// \return the truncated value
+ARROW_EXPORT
+Result<Datum> Trunc(const Datum& arg, ExecContext* ctx = NULLPTR);
+
+/// \brief Find the element-wise maximum of any number of arrays or scalars.
+/// Array values must be the same length.
+///
+/// \param[in] args arrays or scalars to operate on.
+/// \param[in] options options for handling nulls, optional
+/// \param[in] ctx the function execution context, optional
+/// \return the element-wise maximum
+ARROW_EXPORT
+Result<Datum> MaxElementWise(
+ const std::vector<Datum>& args,
+ ElementWiseAggregateOptions options = ElementWiseAggregateOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Find the element-wise minimum of any number of arrays or scalars.
+/// Array values must be the same length.
+///
+/// \param[in] args arrays or scalars to operate on.
+/// \param[in] options options for handling nulls, optional
+/// \param[in] ctx the function execution context, optional
+/// \return the element-wise minimum
+ARROW_EXPORT
+Result<Datum> MinElementWise(
+ const std::vector<Datum>& args,
+ ElementWiseAggregateOptions options = ElementWiseAggregateOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Get the sign of a value. Array values can be of arbitrary length. If argument
+/// is null the result will be null.
+///
+/// \param[in] arg the value to extract sign from
+/// \param[in] ctx the function execution context, optional
+/// \return the element-wise sign function
+ARROW_EXPORT
+Result<Datum> Sign(const Datum& arg, ExecContext* ctx = NULLPTR);
+
+/// \brief Round a value to a given precision.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg the value rounded
+/// \param[in] options rounding options (rounding mode and number of digits), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the element-wise rounded value
+ARROW_EXPORT
+Result<Datum> Round(const Datum& arg, RoundOptions options = RoundOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Round a value to a given multiple.
+///
+/// If argument is null the result will be null.
+///
+/// \param[in] arg the value to round
+/// \param[in] options rounding options (rounding mode and multiple), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the element-wise rounded value
+ARROW_EXPORT
+Result<Datum> RoundToMultiple(
+ const Datum& arg, RoundToMultipleOptions options = RoundToMultipleOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Compare a numeric array with a scalar.
+///
+/// \param[in] left datum to compare, must be an Array
+/// \param[in] right datum to compare, must be a Scalar of the same type than
+/// left Datum.
+/// \param[in] options compare options
+/// \param[in] ctx the function execution context, optional
+/// \return resulting datum
+///
+/// Note on floating point arrays, this uses ieee-754 compare semantics.
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_DEPRECATED("Deprecated in 5.0.0. Use each compare function directly")
+ARROW_EXPORT
+Result<Datum> Compare(const Datum& left, const Datum& right, CompareOptions options,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Invert the values of a boolean datum
+/// \param[in] value datum to invert
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Invert(const Datum& value, ExecContext* ctx = NULLPTR);
+
+/// \brief Element-wise AND of two boolean datums which always propagates nulls
+/// (null and false is null).
+///
+/// \param[in] left left operand
+/// \param[in] right right operand
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> And(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR);
+
+/// \brief Element-wise AND of two boolean datums with a Kleene truth table
+/// (null and false is false).
+///
+/// \param[in] left left operand
+/// \param[in] right right operand
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> KleeneAnd(const Datum& left, const Datum& right,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Element-wise OR of two boolean datums which always propagates nulls
+/// (null and true is null).
+///
+/// \param[in] left left operand
+/// \param[in] right right operand
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Or(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR);
+
+/// \brief Element-wise OR of two boolean datums with a Kleene truth table
+/// (null or true is true).
+///
+/// \param[in] left left operand
+/// \param[in] right right operand
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> KleeneOr(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR);
+
+/// \brief Element-wise XOR of two boolean datums
+/// \param[in] left left operand
+/// \param[in] right right operand
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Xor(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR);
+
+/// \brief Element-wise AND NOT of two boolean datums which always propagates nulls
+/// (null and not true is null).
+///
+/// \param[in] left left operand
+/// \param[in] right right operand
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 3.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> AndNot(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR);
+
+/// \brief Element-wise AND NOT of two boolean datums with a Kleene truth table
+/// (false and not null is false, null and not true is false).
+///
+/// \param[in] left left operand
+/// \param[in] right right operand
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 3.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> KleeneAndNot(const Datum& left, const Datum& right,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief IsIn returns true for each element of `values` that is contained in
+/// `value_set`
+///
+/// Behaviour of nulls is governed by SetLookupOptions::skip_nulls.
+///
+/// \param[in] values array-like input to look up in value_set
+/// \param[in] options SetLookupOptions
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> IsIn(const Datum& values, const SetLookupOptions& options,
+ ExecContext* ctx = NULLPTR);
+ARROW_EXPORT
+Result<Datum> IsIn(const Datum& values, const Datum& value_set,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief IndexIn examines each slot in the values against a value_set array.
+/// If the value is not found in value_set, null will be output.
+/// If found, the index of occurrence within value_set (ignoring duplicates)
+/// will be output.
+///
+/// For example given values = [99, 42, 3, null] and
+/// value_set = [3, 3, 99], the output will be = [2, null, 0, null]
+///
+/// Behaviour of nulls is governed by SetLookupOptions::skip_nulls.
+///
+/// \param[in] values array-like input
+/// \param[in] options SetLookupOptions
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> IndexIn(const Datum& values, const SetLookupOptions& options,
+ ExecContext* ctx = NULLPTR);
+ARROW_EXPORT
+Result<Datum> IndexIn(const Datum& values, const Datum& value_set,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief IsValid returns true for each element of `values` that is not null,
+/// false otherwise
+///
+/// \param[in] values input to examine for validity
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> IsValid(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief IsNull returns true for each element of `values` that is null,
+/// false otherwise
+///
+/// \param[in] values input to examine for nullity
+/// \param[in] options NullOptions
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> IsNull(const Datum& values, NullOptions options = NullOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief IsNan returns true for each element of `values` that is NaN,
+/// false otherwise
+///
+/// \param[in] values input to look for NaN
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 3.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> IsNan(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief IfElse returns elements chosen from `left` or `right`
+/// depending on `cond`. `null` values in `cond` will be promoted to the result
+///
+/// \param[in] cond `Boolean` condition Scalar/ Array
+/// \param[in] left Scalar/ Array
+/// \param[in] right Scalar/ Array
+/// \param[in] ctx the function execution context, optional
+///
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> IfElse(const Datum& cond, const Datum& left, const Datum& right,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief CaseWhen behaves like a switch/case or if-else if-else statement: for
+/// each row, select the first value for which the corresponding condition is
+/// true, or (if given) select the 'else' value, else emit null. Note that a
+/// null condition is the same as false.
+///
+/// \param[in] cond Conditions (Boolean)
+/// \param[in] cases Values (any type), along with an optional 'else' value.
+/// \param[in] ctx the function execution context, optional
+///
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> CaseWhen(const Datum& cond, const std::vector<Datum>& cases,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Year returns year for each element of `values`
+///
+/// \param[in] values input to extract year from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Year(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief Month returns month for each element of `values`.
+/// Month is encoded as January=1, December=12
+///
+/// \param[in] values input to extract month from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Month(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief Day returns day number for each element of `values`
+///
+/// \param[in] values input to extract day from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Day(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief DayOfWeek returns number of the day of the week value for each element of
+/// `values`.
+///
+/// By default week starts on Monday denoted by 0 and ends on Sunday denoted
+/// by 6. Start day of the week (Monday=1, Sunday=7) and numbering base (0 or 1) can be
+/// set using DayOfWeekOptions
+///
+/// \param[in] values input to extract number of the day of the week from
+/// \param[in] options for setting start of the week and day numbering
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT Result<Datum> DayOfWeek(const Datum& values,
+ DayOfWeekOptions options = DayOfWeekOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief DayOfYear returns number of day of the year for each element of `values`.
+/// January 1st maps to day number 1, February 1st to 32, etc.
+///
+/// \param[in] values input to extract number of day of the year from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT Result<Datum> DayOfYear(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief ISOYear returns ISO year number for each element of `values`.
+/// First week of an ISO year has the majority (4 or more) of its days in January.
+///
+/// \param[in] values input to extract ISO year from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> ISOYear(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief ISOWeek returns ISO week of year number for each element of `values`.
+/// First ISO week has the majority (4 or more) of its days in January.
+/// ISO week starts on Monday. Year can have 52 or 53 weeks.
+/// Week numbering can start with 1.
+///
+/// \param[in] values input to extract ISO week of year from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT Result<Datum> ISOWeek(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief USWeek returns US week of year number for each element of `values`.
+/// First US week has the majority (4 or more) of its days in January.
+/// US week starts on Sunday. Year can have 52 or 53 weeks.
+/// Week numbering starts with 1.
+///
+/// \param[in] values input to extract US week of year from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 6.0.0
+/// \note API not yet finalized
+ARROW_EXPORT Result<Datum> USWeek(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief Week returns week of year number for each element of `values`.
+/// First ISO week has the majority (4 or more) of its days in January.
+/// Year can have 52 or 53 weeks. Week numbering can start with 0 or 1
+/// depending on DayOfWeekOptions.count_from_zero.
+///
+/// \param[in] values input to extract week of year from
+/// \param[in] options for setting numbering start
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 6.0.0
+/// \note API not yet finalized
+ARROW_EXPORT Result<Datum> Week(const Datum& values, WeekOptions options = WeekOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief ISOCalendar returns a (ISO year, ISO week, ISO day of week) struct for
+/// each element of `values`.
+/// ISO week starts on Monday denoted by 1 and ends on Sunday denoted by 7.
+///
+/// \param[in] values input to ISO calendar struct from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT Result<Datum> ISOCalendar(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief Quarter returns the quarter of year number for each element of `values`
+/// First quarter maps to 1 and fourth quarter maps to 4.
+///
+/// \param[in] values input to extract quarter of year from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT Result<Datum> Quarter(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief Hour returns hour value for each element of `values`
+///
+/// \param[in] values input to extract hour from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Hour(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief Minute returns minutes value for each element of `values`
+///
+/// \param[in] values input to extract minutes from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Minute(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief Second returns seconds value for each element of `values`
+///
+/// \param[in] values input to extract seconds from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Second(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief Millisecond returns number of milliseconds since the last full second
+/// for each element of `values`
+///
+/// \param[in] values input to extract milliseconds from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Millisecond(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief Microsecond returns number of microseconds since the last full millisecond
+/// for each element of `values`
+///
+/// \param[in] values input to extract microseconds from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Microsecond(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief Nanosecond returns number of nanoseconds since the last full millisecond
+/// for each element of `values`
+///
+/// \param[in] values input to extract nanoseconds from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Nanosecond(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief Subsecond returns the fraction of second elapsed since last full second
+/// as a float for each element of `values`
+///
+/// \param[in] values input to extract subsecond from
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT Result<Datum> Subsecond(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief Format timestamps according to a format string
+///
+/// Return formatted time strings according to the format string
+/// `StrftimeOptions::format` and to the locale specifier `Strftime::locale`.
+///
+/// \param[in] values input timestamps
+/// \param[in] options for setting format string and locale
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 6.0.0
+/// \note API not yet finalized
+ARROW_EXPORT Result<Datum> Strftime(const Datum& values, StrftimeOptions options,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Converts timestamps from local timestamp without a timezone to a timestamp with
+/// timezone, interpreting the local timestamp as being in the specified timezone for each
+/// element of `values`
+///
+/// \param[in] values input to convert
+/// \param[in] options for setting source timezone, exception and ambiguous timestamp
+/// handling.
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 6.0.0
+/// \note API not yet finalized
+ARROW_EXPORT Result<Datum> AssumeTimezone(const Datum& values,
+ AssumeTimezoneOptions options,
+ ExecContext* ctx = NULLPTR);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/api_vector.cc b/src/arrow/cpp/src/arrow/compute/api_vector.cc
new file mode 100644
index 000000000..1fc6b7874
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/api_vector.cc
@@ -0,0 +1,328 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/api_vector.h"
+
+#include <memory>
+#include <sstream>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/array_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/function_internal.h"
+#include "arrow/compute/registry.h"
+#include "arrow/datum.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace internal {
+
+using compute::DictionaryEncodeOptions;
+using compute::FilterOptions;
+using compute::NullPlacement;
+
+template <>
+struct EnumTraits<FilterOptions::NullSelectionBehavior>
+ : BasicEnumTraits<FilterOptions::NullSelectionBehavior, FilterOptions::DROP,
+ FilterOptions::EMIT_NULL> {
+ static std::string name() { return "FilterOptions::NullSelectionBehavior"; }
+ static std::string value_name(FilterOptions::NullSelectionBehavior value) {
+ switch (value) {
+ case FilterOptions::DROP:
+ return "DROP";
+ case FilterOptions::EMIT_NULL:
+ return "EMIT_NULL";
+ }
+ return "<INVALID>";
+ }
+};
+template <>
+struct EnumTraits<DictionaryEncodeOptions::NullEncodingBehavior>
+ : BasicEnumTraits<DictionaryEncodeOptions::NullEncodingBehavior,
+ DictionaryEncodeOptions::ENCODE, DictionaryEncodeOptions::MASK> {
+ static std::string name() { return "DictionaryEncodeOptions::NullEncodingBehavior"; }
+ static std::string value_name(DictionaryEncodeOptions::NullEncodingBehavior value) {
+ switch (value) {
+ case DictionaryEncodeOptions::ENCODE:
+ return "ENCODE";
+ case DictionaryEncodeOptions::MASK:
+ return "MASK";
+ }
+ return "<INVALID>";
+ }
+};
+template <>
+struct EnumTraits<NullPlacement>
+ : BasicEnumTraits<NullPlacement, NullPlacement::AtStart, NullPlacement::AtEnd> {
+ static std::string name() { return "NullPlacement"; }
+ static std::string value_name(NullPlacement value) {
+ switch (value) {
+ case NullPlacement::AtStart:
+ return "AtStart";
+ case NullPlacement::AtEnd:
+ return "AtEnd";
+ }
+ return "<INVALID>";
+ }
+};
+
+} // namespace internal
+
+namespace compute {
+
+// ----------------------------------------------------------------------
+// Function options
+
+bool SortKey::Equals(const SortKey& other) const {
+ return name == other.name && order == other.order;
+}
+std::string SortKey::ToString() const {
+ std::stringstream ss;
+ ss << name << ' ';
+ switch (order) {
+ case SortOrder::Ascending:
+ ss << "ASC";
+ break;
+ case SortOrder::Descending:
+ ss << "DESC";
+ break;
+ }
+ return ss.str();
+}
+
+namespace internal {
+namespace {
+using ::arrow::internal::DataMember;
+static auto kFilterOptionsType = GetFunctionOptionsType<FilterOptions>(
+ DataMember("null_selection_behavior", &FilterOptions::null_selection_behavior));
+static auto kTakeOptionsType = GetFunctionOptionsType<TakeOptions>(
+ DataMember("boundscheck", &TakeOptions::boundscheck));
+static auto kDictionaryEncodeOptionsType =
+ GetFunctionOptionsType<DictionaryEncodeOptions>(DataMember(
+ "null_encoding_behavior", &DictionaryEncodeOptions::null_encoding_behavior));
+static auto kArraySortOptionsType = GetFunctionOptionsType<ArraySortOptions>(
+ DataMember("order", &ArraySortOptions::order),
+ DataMember("null_placement", &ArraySortOptions::null_placement));
+static auto kSortOptionsType = GetFunctionOptionsType<SortOptions>(
+ DataMember("sort_keys", &SortOptions::sort_keys),
+ DataMember("null_placement", &SortOptions::null_placement));
+static auto kPartitionNthOptionsType = GetFunctionOptionsType<PartitionNthOptions>(
+ DataMember("pivot", &PartitionNthOptions::pivot),
+ DataMember("null_placement", &PartitionNthOptions::null_placement));
+static auto kSelectKOptionsType = GetFunctionOptionsType<SelectKOptions>(
+ DataMember("k", &SelectKOptions::k),
+ DataMember("sort_keys", &SelectKOptions::sort_keys));
+} // namespace
+} // namespace internal
+
+FilterOptions::FilterOptions(NullSelectionBehavior null_selection)
+ : FunctionOptions(internal::kFilterOptionsType),
+ null_selection_behavior(null_selection) {}
+constexpr char FilterOptions::kTypeName[];
+
+TakeOptions::TakeOptions(bool boundscheck)
+ : FunctionOptions(internal::kTakeOptionsType), boundscheck(boundscheck) {}
+constexpr char TakeOptions::kTypeName[];
+
+DictionaryEncodeOptions::DictionaryEncodeOptions(NullEncodingBehavior null_encoding)
+ : FunctionOptions(internal::kDictionaryEncodeOptionsType),
+ null_encoding_behavior(null_encoding) {}
+constexpr char DictionaryEncodeOptions::kTypeName[];
+
+ArraySortOptions::ArraySortOptions(SortOrder order, NullPlacement null_placement)
+ : FunctionOptions(internal::kArraySortOptionsType),
+ order(order),
+ null_placement(null_placement) {}
+constexpr char ArraySortOptions::kTypeName[];
+
+SortOptions::SortOptions(std::vector<SortKey> sort_keys, NullPlacement null_placement)
+ : FunctionOptions(internal::kSortOptionsType),
+ sort_keys(std::move(sort_keys)),
+ null_placement(null_placement) {}
+constexpr char SortOptions::kTypeName[];
+
+PartitionNthOptions::PartitionNthOptions(int64_t pivot, NullPlacement null_placement)
+ : FunctionOptions(internal::kPartitionNthOptionsType),
+ pivot(pivot),
+ null_placement(null_placement) {}
+constexpr char PartitionNthOptions::kTypeName[];
+
+SelectKOptions::SelectKOptions(int64_t k, std::vector<SortKey> sort_keys)
+ : FunctionOptions(internal::kSelectKOptionsType),
+ k(k),
+ sort_keys(std::move(sort_keys)) {}
+constexpr char SelectKOptions::kTypeName[];
+
+namespace internal {
+void RegisterVectorOptions(FunctionRegistry* registry) {
+ DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kTakeOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kDictionaryEncodeOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kArraySortOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kSortOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kPartitionNthOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType));
+}
+} // namespace internal
+
+// ----------------------------------------------------------------------
+// Direct exec interface to kernels
+
+Result<std::shared_ptr<Array>> NthToIndices(const Array& values,
+ const PartitionNthOptions& options,
+ ExecContext* ctx) {
+ ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("partition_nth_indices",
+ {Datum(values)}, &options, ctx));
+ return result.make_array();
+}
+
+Result<std::shared_ptr<Array>> NthToIndices(const Array& values, int64_t n,
+ ExecContext* ctx) {
+ PartitionNthOptions options(/*pivot=*/n);
+ ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("partition_nth_indices",
+ {Datum(values)}, &options, ctx));
+ return result.make_array();
+}
+
+Result<std::shared_ptr<Array>> SelectKUnstable(const Datum& datum,
+ const SelectKOptions& options,
+ ExecContext* ctx) {
+ ARROW_ASSIGN_OR_RAISE(Datum result,
+ CallFunction("select_k_unstable", {datum}, &options, ctx));
+ return result.make_array();
+}
+
+Result<Datum> ReplaceWithMask(const Datum& values, const Datum& mask,
+ const Datum& replacements, ExecContext* ctx) {
+ return CallFunction("replace_with_mask", {values, mask, replacements}, ctx);
+}
+
+Result<std::shared_ptr<Array>> SortIndices(const Array& values,
+ const ArraySortOptions& options,
+ ExecContext* ctx) {
+ ARROW_ASSIGN_OR_RAISE(
+ Datum result, CallFunction("array_sort_indices", {Datum(values)}, &options, ctx));
+ return result.make_array();
+}
+
+Result<std::shared_ptr<Array>> SortIndices(const Array& values, SortOrder order,
+ ExecContext* ctx) {
+ ArraySortOptions options(order);
+ ARROW_ASSIGN_OR_RAISE(
+ Datum result, CallFunction("array_sort_indices", {Datum(values)}, &options, ctx));
+ return result.make_array();
+}
+
+Result<std::shared_ptr<Array>> SortIndices(const ChunkedArray& chunked_array,
+ const ArraySortOptions& array_options,
+ ExecContext* ctx) {
+ SortOptions options({SortKey("", array_options.order)}, array_options.null_placement);
+ ARROW_ASSIGN_OR_RAISE(
+ Datum result, CallFunction("sort_indices", {Datum(chunked_array)}, &options, ctx));
+ return result.make_array();
+}
+
+Result<std::shared_ptr<Array>> SortIndices(const ChunkedArray& chunked_array,
+ SortOrder order, ExecContext* ctx) {
+ SortOptions options({SortKey("not-used", order)});
+ ARROW_ASSIGN_OR_RAISE(
+ Datum result, CallFunction("sort_indices", {Datum(chunked_array)}, &options, ctx));
+ return result.make_array();
+}
+
+Result<std::shared_ptr<Array>> SortIndices(const Datum& datum, const SortOptions& options,
+ ExecContext* ctx) {
+ ARROW_ASSIGN_OR_RAISE(Datum result,
+ CallFunction("sort_indices", {datum}, &options, ctx));
+ return result.make_array();
+}
+
+Result<std::shared_ptr<Array>> Unique(const Datum& value, ExecContext* ctx) {
+ ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("unique", {value}, ctx));
+ return result.make_array();
+}
+
+Result<Datum> DictionaryEncode(const Datum& value, const DictionaryEncodeOptions& options,
+ ExecContext* ctx) {
+ return CallFunction("dictionary_encode", {value}, &options, ctx);
+}
+
+const char kValuesFieldName[] = "values";
+const char kCountsFieldName[] = "counts";
+const int32_t kValuesFieldIndex = 0;
+const int32_t kCountsFieldIndex = 1;
+
+Result<std::shared_ptr<StructArray>> ValueCounts(const Datum& value, ExecContext* ctx) {
+ ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("value_counts", {value}, ctx));
+ return checked_pointer_cast<StructArray>(result.make_array());
+}
+
+// ----------------------------------------------------------------------
+// Filter- and take-related selection functions
+
+Result<Datum> Filter(const Datum& values, const Datum& filter,
+ const FilterOptions& options, ExecContext* ctx) {
+ // Invoke metafunction which deals with Datum kinds other than just Array,
+ // ChunkedArray.
+ return CallFunction("filter", {values, filter}, &options, ctx);
+}
+
+Result<Datum> Take(const Datum& values, const Datum& filter, const TakeOptions& options,
+ ExecContext* ctx) {
+ // Invoke metafunction which deals with Datum kinds other than just Array,
+ // ChunkedArray.
+ return CallFunction("take", {values, filter}, &options, ctx);
+}
+
+Result<std::shared_ptr<Array>> Take(const Array& values, const Array& indices,
+ const TakeOptions& options, ExecContext* ctx) {
+ ARROW_ASSIGN_OR_RAISE(Datum out, Take(Datum(values), Datum(indices), options, ctx));
+ return out.make_array();
+}
+
+// ----------------------------------------------------------------------
+// Dropnull functions
+
+Result<Datum> DropNull(const Datum& values, ExecContext* ctx) {
+ // Invoke metafunction which deals with Datum kinds other than just Array,
+ // ChunkedArray.
+ return CallFunction("drop_null", {values}, ctx);
+}
+
+Result<std::shared_ptr<Array>> DropNull(const Array& values, ExecContext* ctx) {
+ ARROW_ASSIGN_OR_RAISE(Datum out, DropNull(Datum(values), ctx));
+ return out.make_array();
+}
+
+// ----------------------------------------------------------------------
+// Deprecated functions
+
+Result<std::shared_ptr<Array>> SortToIndices(const Array& values, ExecContext* ctx) {
+ return SortIndices(values, SortOrder::Ascending, ctx);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/api_vector.h b/src/arrow/cpp/src/arrow/compute/api_vector.h
new file mode 100644
index 000000000..a91cf91df
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/api_vector.h
@@ -0,0 +1,506 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <utility>
+
+#include "arrow/compute/function.h"
+#include "arrow/datum.h"
+#include "arrow/result.h"
+#include "arrow/type_fwd.h"
+
+namespace arrow {
+namespace compute {
+
+class ExecContext;
+
+/// \addtogroup compute-concrete-options
+/// @{
+
+class ARROW_EXPORT FilterOptions : public FunctionOptions {
+ public:
+ /// Configure the action taken when a slot of the selection mask is null
+ enum NullSelectionBehavior {
+ /// The corresponding filtered value will be removed in the output.
+ DROP,
+ /// The corresponding filtered value will be null in the output.
+ EMIT_NULL,
+ };
+
+ explicit FilterOptions(NullSelectionBehavior null_selection = DROP);
+ constexpr static char const kTypeName[] = "FilterOptions";
+ static FilterOptions Defaults() { return FilterOptions(); }
+
+ NullSelectionBehavior null_selection_behavior = DROP;
+};
+
+class ARROW_EXPORT TakeOptions : public FunctionOptions {
+ public:
+ explicit TakeOptions(bool boundscheck = true);
+ constexpr static char const kTypeName[] = "TakeOptions";
+ static TakeOptions BoundsCheck() { return TakeOptions(true); }
+ static TakeOptions NoBoundsCheck() { return TakeOptions(false); }
+ static TakeOptions Defaults() { return BoundsCheck(); }
+
+ bool boundscheck = true;
+};
+
+/// \brief Options for the dictionary encode function
+class ARROW_EXPORT DictionaryEncodeOptions : public FunctionOptions {
+ public:
+ /// Configure how null values will be encoded
+ enum NullEncodingBehavior {
+ /// The null value will be added to the dictionary with a proper index.
+ ENCODE,
+ /// The null value will be masked in the indices array.
+ MASK
+ };
+
+ explicit DictionaryEncodeOptions(NullEncodingBehavior null_encoding = MASK);
+ constexpr static char const kTypeName[] = "DictionaryEncodeOptions";
+ static DictionaryEncodeOptions Defaults() { return DictionaryEncodeOptions(); }
+
+ NullEncodingBehavior null_encoding_behavior = MASK;
+};
+
+enum class SortOrder {
+ /// Arrange values in increasing order
+ Ascending,
+ /// Arrange values in decreasing order
+ Descending,
+};
+
+enum class NullPlacement {
+ /// Place nulls and NaNs before any non-null values.
+ /// NaNs will come after nulls.
+ AtStart,
+ /// Place nulls and NaNs after any non-null values.
+ /// NaNs will come before nulls.
+ AtEnd,
+};
+
+/// \brief One sort key for PartitionNthIndices (TODO) and SortIndices
+class ARROW_EXPORT SortKey : public util::EqualityComparable<SortKey> {
+ public:
+ explicit SortKey(std::string name, SortOrder order = SortOrder::Ascending)
+ : name(std::move(name)), order(order) {}
+
+ using util::EqualityComparable<SortKey>::Equals;
+ using util::EqualityComparable<SortKey>::operator==;
+ using util::EqualityComparable<SortKey>::operator!=;
+ bool Equals(const SortKey& other) const;
+ std::string ToString() const;
+
+ /// The name of the sort column.
+ std::string name;
+ /// How to order by this sort key.
+ SortOrder order;
+};
+
+class ARROW_EXPORT ArraySortOptions : public FunctionOptions {
+ public:
+ explicit ArraySortOptions(SortOrder order = SortOrder::Ascending,
+ NullPlacement null_placement = NullPlacement::AtEnd);
+ constexpr static char const kTypeName[] = "ArraySortOptions";
+ static ArraySortOptions Defaults() { return ArraySortOptions(); }
+
+ /// Sorting order
+ SortOrder order;
+ /// Whether nulls and NaNs are placed at the start or at the end
+ NullPlacement null_placement;
+};
+
+class ARROW_EXPORT SortOptions : public FunctionOptions {
+ public:
+ explicit SortOptions(std::vector<SortKey> sort_keys = {},
+ NullPlacement null_placement = NullPlacement::AtEnd);
+ constexpr static char const kTypeName[] = "SortOptions";
+ static SortOptions Defaults() { return SortOptions(); }
+
+ /// Column key(s) to order by and how to order by these sort keys.
+ std::vector<SortKey> sort_keys;
+ /// Whether nulls and NaNs are placed at the start or at the end
+ NullPlacement null_placement;
+};
+
+/// \brief SelectK options
+class ARROW_EXPORT SelectKOptions : public FunctionOptions {
+ public:
+ explicit SelectKOptions(int64_t k = -1, std::vector<SortKey> sort_keys = {});
+ constexpr static char const kTypeName[] = "SelectKOptions";
+ static SelectKOptions Defaults() { return SelectKOptions(); }
+
+ static SelectKOptions TopKDefault(int64_t k, std::vector<std::string> key_names = {}) {
+ std::vector<SortKey> keys;
+ for (const auto& name : key_names) {
+ keys.emplace_back(SortKey(name, SortOrder::Descending));
+ }
+ if (key_names.empty()) {
+ keys.emplace_back(SortKey("not-used", SortOrder::Descending));
+ }
+ return SelectKOptions{k, keys};
+ }
+ static SelectKOptions BottomKDefault(int64_t k,
+ std::vector<std::string> key_names = {}) {
+ std::vector<SortKey> keys;
+ for (const auto& name : key_names) {
+ keys.emplace_back(SortKey(name, SortOrder::Ascending));
+ }
+ if (key_names.empty()) {
+ keys.emplace_back(SortKey("not-used", SortOrder::Ascending));
+ }
+ return SelectKOptions{k, keys};
+ }
+
+ /// The number of `k` elements to keep.
+ int64_t k;
+ /// Column key(s) to order by and how to order by these sort keys.
+ std::vector<SortKey> sort_keys;
+};
+
+/// \brief Partitioning options for NthToIndices
+class ARROW_EXPORT PartitionNthOptions : public FunctionOptions {
+ public:
+ explicit PartitionNthOptions(int64_t pivot,
+ NullPlacement null_placement = NullPlacement::AtEnd);
+ PartitionNthOptions() : PartitionNthOptions(0) {}
+ constexpr static char const kTypeName[] = "PartitionNthOptions";
+
+ /// The index into the equivalent sorted array of the partition pivot element.
+ int64_t pivot;
+ /// Whether nulls and NaNs are partitioned at the start or at the end
+ NullPlacement null_placement;
+};
+
+/// @}
+
+/// \brief Filter with a boolean selection filter
+///
+/// The output will be populated with values from the input at positions
+/// where the selection filter is not 0. Nulls in the filter will be handled
+/// based on options.null_selection_behavior.
+///
+/// For example given values = ["a", "b", "c", null, "e", "f"] and
+/// filter = [0, 1, 1, 0, null, 1], the output will be
+/// (null_selection_behavior == DROP) = ["b", "c", "f"]
+/// (null_selection_behavior == EMIT_NULL) = ["b", "c", null, "f"]
+///
+/// \param[in] values array to filter
+/// \param[in] filter indicates which values should be filtered out
+/// \param[in] options configures null_selection_behavior
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+ARROW_EXPORT
+Result<Datum> Filter(const Datum& values, const Datum& filter,
+ const FilterOptions& options = FilterOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+namespace internal {
+
+// These internal functions are implemented in kernels/vector_selection.cc
+
+/// \brief Return the number of selected indices in the boolean filter
+ARROW_EXPORT
+int64_t GetFilterOutputSize(const ArrayData& filter,
+ FilterOptions::NullSelectionBehavior null_selection);
+
+/// \brief Compute uint64 selection indices for use with Take given a boolean
+/// filter
+ARROW_EXPORT
+Result<std::shared_ptr<ArrayData>> GetTakeIndices(
+ const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection,
+ MemoryPool* memory_pool = default_memory_pool());
+
+} // namespace internal
+
+/// \brief ReplaceWithMask replaces each value in the array corresponding
+/// to a true value in the mask with the next element from `replacements`.
+///
+/// \param[in] values Array input to replace
+/// \param[in] mask Array or Scalar of Boolean mask values
+/// \param[in] replacements The replacement values to draw from. There must
+/// be as many replacement values as true values in the mask.
+/// \param[in] ctx the function execution context, optional
+///
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> ReplaceWithMask(const Datum& values, const Datum& mask,
+ const Datum& replacements, ExecContext* ctx = NULLPTR);
+
+/// \brief Take from an array of values at indices in another array
+///
+/// The output array will be of the same type as the input values
+/// array, with elements taken from the values array at the given
+/// indices. If an index is null then the taken element will be null.
+///
+/// For example given values = ["a", "b", "c", null, "e", "f"] and
+/// indices = [2, 1, null, 3], the output will be
+/// = [values[2], values[1], null, values[3]]
+/// = ["c", "b", null, null]
+///
+/// \param[in] values datum from which to take
+/// \param[in] indices which values to take
+/// \param[in] options options
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+ARROW_EXPORT
+Result<Datum> Take(const Datum& values, const Datum& indices,
+ const TakeOptions& options = TakeOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Take with Array inputs and output
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> Take(const Array& values, const Array& indices,
+ const TakeOptions& options = TakeOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Drop Null from an array of values
+///
+/// The output array will be of the same type as the input values
+/// array, with elements taken from the values array without nulls.
+///
+/// For example given values = ["a", "b", "c", null, "e", "f"],
+/// the output will be = ["a", "b", "c", "e", "f"]
+///
+/// \param[in] values datum from which to take
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+ARROW_EXPORT
+Result<Datum> DropNull(const Datum& values, ExecContext* ctx = NULLPTR);
+
+/// \brief DropNull with Array inputs and output
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> DropNull(const Array& values, ExecContext* ctx = NULLPTR);
+
+/// \brief Return indices that partition an array around n-th sorted element.
+///
+/// Find index of n-th(0 based) smallest value and perform indirect
+/// partition of an array around that element. Output indices[0 ~ n-1]
+/// holds values no greater than n-th element, and indices[n+1 ~ end]
+/// holds values no less than n-th element. Elements in each partition
+/// is not sorted. Nulls will be partitioned to the end of the output.
+/// Output is not guaranteed to be stable.
+///
+/// \param[in] values array to be partitioned
+/// \param[in] n pivot array around sorted n-th element
+/// \param[in] ctx the function execution context, optional
+/// \return offsets indices that would partition an array
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> NthToIndices(const Array& values, int64_t n,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Return indices that partition an array around n-th sorted element.
+///
+/// This overload takes a PartitionNthOptions specifiying the pivot index
+/// and the null handling.
+///
+/// \param[in] values array to be partitioned
+/// \param[in] options options including pivot index and null handling
+/// \param[in] ctx the function execution context, optional
+/// \return offsets indices that would partition an array
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> NthToIndices(const Array& values,
+ const PartitionNthOptions& options,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Return indices that would select the first `k` elements.
+///
+/// Perform an indirect sort of the datum, keeping only the first `k` elements. The output
+/// array will contain indices such that the item indicated by the k-th index will be in
+/// the position it would be if the datum were sorted by `options.sort_keys`. However,
+/// indices of null values will not be part of the output. The sort is not guaranteed to
+/// be stable.
+///
+/// \param[in] datum datum to be partitioned
+/// \param[in] options options
+/// \param[in] ctx the function execution context, optional
+/// \return a datum with the same schema as the input
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> SelectKUnstable(const Datum& datum,
+ const SelectKOptions& options,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Return the indices that would sort an array.
+///
+/// Perform an indirect sort of array. The output array will contain
+/// indices that would sort an array, which would be the same length
+/// as input. Nulls will be stably partitioned to the end of the output
+/// regardless of order.
+///
+/// For example given array = [null, 1, 3.3, null, 2, 5.3] and order
+/// = SortOrder::DESCENDING, the output will be [5, 2, 4, 1, 0,
+/// 3].
+///
+/// \param[in] array array to sort
+/// \param[in] order ascending or descending
+/// \param[in] ctx the function execution context, optional
+/// \return offsets indices that would sort an array
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> SortIndices(const Array& array,
+ SortOrder order = SortOrder::Ascending,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Return the indices that would sort an array.
+///
+/// This overload takes a ArraySortOptions specifiying the sort order
+/// and the null handling.
+///
+/// \param[in] array array to sort
+/// \param[in] options options including sort order and null handling
+/// \param[in] ctx the function execution context, optional
+/// \return offsets indices that would sort an array
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> SortIndices(const Array& array,
+ const ArraySortOptions& options,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Return the indices that would sort a chunked array.
+///
+/// Perform an indirect sort of chunked array. The output array will
+/// contain indices that would sort a chunked array, which would be
+/// the same length as input. Nulls will be stably partitioned to the
+/// end of the output regardless of order.
+///
+/// For example given chunked_array = [[null, 1], [3.3], [null, 2,
+/// 5.3]] and order = SortOrder::DESCENDING, the output will be [5, 2,
+/// 4, 1, 0, 3].
+///
+/// \param[in] chunked_array chunked array to sort
+/// \param[in] order ascending or descending
+/// \param[in] ctx the function execution context, optional
+/// \return offsets indices that would sort an array
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> SortIndices(const ChunkedArray& chunked_array,
+ SortOrder order = SortOrder::Ascending,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Return the indices that would sort a chunked array.
+///
+/// This overload takes a ArraySortOptions specifiying the sort order
+/// and the null handling.
+///
+/// \param[in] chunked_array chunked array to sort
+/// \param[in] options options including sort order and null handling
+/// \param[in] ctx the function execution context, optional
+/// \return offsets indices that would sort an array
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> SortIndices(const ChunkedArray& chunked_array,
+ const ArraySortOptions& options,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Return the indices that would sort an input in the
+/// specified order. Input is one of array, chunked array record batch
+/// or table.
+///
+/// Perform an indirect sort of input. The output array will contain
+/// indices that would sort an input, which would be the same length
+/// as input. Nulls will be stably partitioned to the start or to the end
+/// of the output depending on SortOrder::null_placement.
+///
+/// For example given input (table) = {
+/// "column1": [[null, 1], [ 3, null, 2, 1]],
+/// "column2": [[ 5], [3, null, null, 5, 5]],
+/// } and options = {
+/// {"column1", SortOrder::Ascending},
+/// {"column2", SortOrder::Descending},
+/// }, the output will be [5, 1, 4, 2, 0, 3].
+///
+/// \param[in] datum array, chunked array, record batch or table to sort
+/// \param[in] options options
+/// \param[in] ctx the function execution context, optional
+/// \return offsets indices that would sort a table
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> SortIndices(const Datum& datum, const SortOptions& options,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Compute unique elements from an array-like object
+///
+/// Note if a null occurs in the input it will NOT be included in the output.
+///
+/// \param[in] datum array-like input
+/// \param[in] ctx the function execution context, optional
+/// \return result as Array
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> Unique(const Datum& datum, ExecContext* ctx = NULLPTR);
+
+// Constants for accessing the output of ValueCounts
+ARROW_EXPORT extern const char kValuesFieldName[];
+ARROW_EXPORT extern const char kCountsFieldName[];
+ARROW_EXPORT extern const int32_t kValuesFieldIndex;
+ARROW_EXPORT extern const int32_t kCountsFieldIndex;
+
+/// \brief Return counts of unique elements from an array-like object.
+///
+/// Note that the counts do not include counts for nulls in the array. These can be
+/// obtained separately from metadata.
+///
+/// For floating point arrays there is no attempt to normalize -0.0, 0.0 and NaN values
+/// which can lead to unexpected results if the input Array has these values.
+///
+/// \param[in] value array-like input
+/// \param[in] ctx the function execution context, optional
+/// \return counts An array of <input type "Values", int64_t "Counts"> structs.
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<std::shared_ptr<StructArray>> ValueCounts(const Datum& value,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Dictionary-encode values in an array-like object
+///
+/// Any nulls encountered in the dictionary will be handled according to the
+/// specified null encoding behavior.
+///
+/// For example, given values ["a", "b", null, "a", null] the output will be
+/// (null_encoding == ENCODE) Indices: [0, 1, 2, 0, 2] / Dict: ["a", "b", null]
+/// (null_encoding == MASK) Indices: [0, 1, null, 0, null] / Dict: ["a", "b"]
+///
+/// If the input is already dictionary encoded this function is a no-op unless
+/// it needs to modify the null_encoding (TODO)
+///
+/// \param[in] data array-like input
+/// \param[in] ctx the function execution context, optional
+/// \param[in] options configures null encoding behavior
+/// \return result with same shape and type as input
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> DictionaryEncode(
+ const Datum& data,
+ const DictionaryEncodeOptions& options = DictionaryEncodeOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
+
+// ----------------------------------------------------------------------
+// Deprecated functions
+
+ARROW_DEPRECATED("Deprecated in 3.0.0. Use SortIndices()")
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> SortToIndices(const Array& values,
+ ExecContext* ctx = NULLPTR);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/arrow-compute.pc.in b/src/arrow/cpp/src/arrow/compute/arrow-compute.pc.in
new file mode 100644
index 000000000..bbdb12c47
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/arrow-compute.pc.in
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+Name: Apache Arrow Compute
+Description: Compute modules for Apache Arrow
+Version: @ARROW_VERSION@
+Requires: arrow
diff --git a/src/arrow/cpp/src/arrow/compute/cast.cc b/src/arrow/cpp/src/arrow/compute/cast.cc
new file mode 100644
index 000000000..4de68ba8d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/cast.cc
@@ -0,0 +1,273 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/cast.h"
+
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "arrow/compute/cast_internal.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/function_internal.h"
+#include "arrow/compute/kernel.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/compute/registry.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/reflection_internal.h"
+
+namespace arrow {
+
+using internal::ToTypeName;
+
+namespace compute {
+namespace internal {
+
+// ----------------------------------------------------------------------
+// Function options
+
+namespace {
+
+std::unordered_map<int, std::shared_ptr<CastFunction>> g_cast_table;
+std::once_flag cast_table_initialized;
+
+void AddCastFunctions(const std::vector<std::shared_ptr<CastFunction>>& funcs) {
+ for (const auto& func : funcs) {
+ g_cast_table[static_cast<int>(func->out_type_id())] = func;
+ }
+}
+
+void InitCastTable() {
+ AddCastFunctions(GetBooleanCasts());
+ AddCastFunctions(GetBinaryLikeCasts());
+ AddCastFunctions(GetNestedCasts());
+ AddCastFunctions(GetNumericCasts());
+ AddCastFunctions(GetTemporalCasts());
+ AddCastFunctions(GetDictionaryCasts());
+}
+
+void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTable); }
+
+// Private version of GetCastFunction with better error reporting
+// if the input type is known.
+Result<std::shared_ptr<CastFunction>> GetCastFunctionInternal(
+ const std::shared_ptr<DataType>& to_type, const DataType* from_type = nullptr) {
+ internal::EnsureInitCastTable();
+ auto it = internal::g_cast_table.find(static_cast<int>(to_type->id()));
+ if (it == internal::g_cast_table.end()) {
+ if (from_type != nullptr) {
+ return Status::NotImplemented("Unsupported cast from ", *from_type, " to ",
+ *to_type,
+ " (no available cast function for target type)");
+ } else {
+ return Status::NotImplemented("Unsupported cast to ", *to_type,
+ " (no available cast function for target type)");
+ }
+ }
+ return it->second;
+}
+
+const FunctionDoc cast_doc{"Cast values to another data type",
+ ("Behavior when values wouldn't fit in the target type\n"
+ "can be controlled through CastOptions."),
+ {"input"},
+ "CastOptions"};
+
+// Metafunction for dispatching to appropriate CastFunction. This corresponds
+// to the standard SQL CAST(expr AS target_type)
+class CastMetaFunction : public MetaFunction {
+ public:
+ CastMetaFunction() : MetaFunction("cast", Arity::Unary(), &cast_doc) {}
+
+ Result<const CastOptions*> ValidateOptions(const FunctionOptions* options) const {
+ auto cast_options = static_cast<const CastOptions*>(options);
+
+ if (cast_options == nullptr || cast_options->to_type == nullptr) {
+ return Status::Invalid(
+ "Cast requires that options be passed with "
+ "the to_type populated");
+ }
+
+ return cast_options;
+ }
+
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options,
+ ExecContext* ctx) const override {
+ ARROW_ASSIGN_OR_RAISE(auto cast_options, ValidateOptions(options));
+ if (args[0].type()->Equals(*cast_options->to_type)) {
+ return args[0];
+ }
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr<CastFunction> cast_func,
+ GetCastFunctionInternal(cast_options->to_type, args[0].type().get()));
+ return cast_func->Execute(args, options, ctx);
+ }
+};
+
+static auto kCastOptionsType = GetFunctionOptionsType<CastOptions>(
+ arrow::internal::DataMember("to_type", &CastOptions::to_type),
+ arrow::internal::DataMember("allow_int_overflow", &CastOptions::allow_int_overflow),
+ arrow::internal::DataMember("allow_time_truncate", &CastOptions::allow_time_truncate),
+ arrow::internal::DataMember("allow_time_overflow", &CastOptions::allow_time_overflow),
+ arrow::internal::DataMember("allow_decimal_truncate",
+ &CastOptions::allow_decimal_truncate),
+ arrow::internal::DataMember("allow_float_truncate",
+ &CastOptions::allow_float_truncate),
+ arrow::internal::DataMember("allow_invalid_utf8", &CastOptions::allow_invalid_utf8));
+} // namespace
+
+void RegisterScalarCast(FunctionRegistry* registry) {
+ DCHECK_OK(registry->AddFunction(std::make_shared<CastMetaFunction>()));
+ DCHECK_OK(registry->AddFunctionOptionsType(kCastOptionsType));
+}
+} // namespace internal
+
+CastOptions::CastOptions(bool safe)
+ : FunctionOptions(internal::kCastOptionsType),
+ allow_int_overflow(!safe),
+ allow_time_truncate(!safe),
+ allow_time_overflow(!safe),
+ allow_decimal_truncate(!safe),
+ allow_float_truncate(!safe),
+ allow_invalid_utf8(!safe) {}
+
+constexpr char CastOptions::kTypeName[];
+
+CastFunction::CastFunction(std::string name, Type::type out_type_id)
+ : ScalarFunction(std::move(name), Arity::Unary(), /*doc=*/nullptr),
+ out_type_id_(out_type_id) {}
+
+Status CastFunction::AddKernel(Type::type in_type_id, ScalarKernel kernel) {
+ // We use the same KernelInit for every cast
+ kernel.init = internal::CastState::Init;
+ RETURN_NOT_OK(ScalarFunction::AddKernel(kernel));
+ in_type_ids_.push_back(in_type_id);
+ return Status::OK();
+}
+
+Status CastFunction::AddKernel(Type::type in_type_id, std::vector<InputType> in_types,
+ OutputType out_type, ArrayKernelExec exec,
+ NullHandling::type null_handling,
+ MemAllocation::type mem_allocation) {
+ ScalarKernel kernel;
+ kernel.signature = KernelSignature::Make(std::move(in_types), std::move(out_type));
+ kernel.exec = exec;
+ kernel.null_handling = null_handling;
+ kernel.mem_allocation = mem_allocation;
+ return AddKernel(in_type_id, std::move(kernel));
+}
+
+Result<const Kernel*> CastFunction::DispatchExact(
+ const std::vector<ValueDescr>& values) const {
+ RETURN_NOT_OK(CheckArity(values));
+
+ std::vector<const ScalarKernel*> candidate_kernels;
+ for (const auto& kernel : kernels_) {
+ if (kernel.signature->MatchesInputs(values)) {
+ candidate_kernels.push_back(&kernel);
+ }
+ }
+
+ if (candidate_kernels.size() == 0) {
+ return Status::NotImplemented("Unsupported cast from ", values[0].type->ToString(),
+ " to ", ToTypeName(out_type_id_), " using function ",
+ this->name());
+ }
+
+ if (candidate_kernels.size() == 1) {
+ // One match, return it
+ return candidate_kernels[0];
+ }
+
+ // Now we are in a casting scenario where we may have both a EXACT_TYPE and
+ // a SAME_TYPE_ID. So we will see if there is an exact match among the
+ // candidate kernels and if not we will just return the first one
+ for (auto kernel : candidate_kernels) {
+ const InputType& arg0 = kernel->signature->in_types()[0];
+ if (arg0.kind() == InputType::EXACT_TYPE) {
+ // Bingo. Return it
+ return kernel;
+ }
+ }
+
+ // We didn't find an exact match. So just return some kernel that matches
+ return candidate_kernels[0];
+}
+
+Result<Datum> Cast(const Datum& value, const CastOptions& options, ExecContext* ctx) {
+ return CallFunction("cast", {value}, &options, ctx);
+}
+
+Result<Datum> Cast(const Datum& value, std::shared_ptr<DataType> to_type,
+ const CastOptions& options, ExecContext* ctx) {
+ CastOptions options_with_to_type = options;
+ options_with_to_type.to_type = to_type;
+ return Cast(value, options_with_to_type, ctx);
+}
+
+Result<std::shared_ptr<Array>> Cast(const Array& value, std::shared_ptr<DataType> to_type,
+ const CastOptions& options, ExecContext* ctx) {
+ ARROW_ASSIGN_OR_RAISE(Datum result, Cast(Datum(value), to_type, options, ctx));
+ return result.make_array();
+}
+
+Result<std::shared_ptr<CastFunction>> GetCastFunction(
+ const std::shared_ptr<DataType>& to_type) {
+ return internal::GetCastFunctionInternal(to_type);
+}
+
+bool CanCast(const DataType& from_type, const DataType& to_type) {
+ internal::EnsureInitCastTable();
+ auto it = internal::g_cast_table.find(static_cast<int>(to_type.id()));
+ if (it == internal::g_cast_table.end()) {
+ return false;
+ }
+
+ const CastFunction* function = it->second.get();
+ DCHECK_EQ(function->out_type_id(), to_type.id());
+
+ for (auto from_id : function->in_type_ids()) {
+ // XXX should probably check the output type as well
+ if (from_type.id() == from_id) return true;
+ }
+
+ return false;
+}
+
+Result<std::vector<Datum>> Cast(std::vector<Datum> datums, std::vector<ValueDescr> descrs,
+ ExecContext* ctx) {
+ for (size_t i = 0; i != datums.size(); ++i) {
+ if (descrs[i] != datums[i].descr()) {
+ if (descrs[i].shape != datums[i].shape()) {
+ return Status::NotImplemented("casting between Datum shapes");
+ }
+
+ ARROW_ASSIGN_OR_RAISE(datums[i],
+ Cast(datums[i], CastOptions::Safe(descrs[i].type), ctx));
+ }
+ }
+
+ return datums;
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/cast.h b/src/arrow/cpp/src/arrow/compute/cast.h
new file mode 100644
index 000000000..131f57f89
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/cast.h
@@ -0,0 +1,167 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/compute/function.h"
+#include "arrow/compute/kernel.h"
+#include "arrow/datum.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Array;
+
+namespace compute {
+
+class ExecContext;
+
+/// \addtogroup compute-concrete-options
+/// @{
+
+class ARROW_EXPORT CastOptions : public FunctionOptions {
+ public:
+ explicit CastOptions(bool safe = true);
+
+ constexpr static char const kTypeName[] = "CastOptions";
+ static CastOptions Safe(std::shared_ptr<DataType> to_type = NULLPTR) {
+ CastOptions safe(true);
+ safe.to_type = std::move(to_type);
+ return safe;
+ }
+
+ static CastOptions Unsafe(std::shared_ptr<DataType> to_type = NULLPTR) {
+ CastOptions unsafe(false);
+ unsafe.to_type = std::move(to_type);
+ return unsafe;
+ }
+
+ // Type being casted to. May be passed separate to eager function
+ // compute::Cast
+ std::shared_ptr<DataType> to_type;
+
+ bool allow_int_overflow;
+ bool allow_time_truncate;
+ bool allow_time_overflow;
+ bool allow_decimal_truncate;
+ bool allow_float_truncate;
+ // Indicate if conversions from Binary/FixedSizeBinary to string must
+ // validate the utf8 payload.
+ bool allow_invalid_utf8;
+};
+
+/// @}
+
+// Cast functions are _not_ registered in the FunctionRegistry, though they use
+// the same execution machinery
+class CastFunction : public ScalarFunction {
+ public:
+ CastFunction(std::string name, Type::type out_type_id);
+
+ Type::type out_type_id() const { return out_type_id_; }
+ const std::vector<Type::type>& in_type_ids() const { return in_type_ids_; }
+
+ Status AddKernel(Type::type in_type_id, std::vector<InputType> in_types,
+ OutputType out_type, ArrayKernelExec exec,
+ NullHandling::type = NullHandling::INTERSECTION,
+ MemAllocation::type = MemAllocation::PREALLOCATE);
+
+ // Note, this function toggles off memory allocation and sets the init
+ // function to CastInit
+ Status AddKernel(Type::type in_type_id, ScalarKernel kernel);
+
+ Result<const Kernel*> DispatchExact(
+ const std::vector<ValueDescr>& values) const override;
+
+ private:
+ std::vector<Type::type> in_type_ids_;
+ const Type::type out_type_id_;
+};
+
+ARROW_EXPORT
+Result<std::shared_ptr<CastFunction>> GetCastFunction(
+ const std::shared_ptr<DataType>& to_type);
+
+/// \brief Return true if a cast function is defined
+ARROW_EXPORT
+bool CanCast(const DataType& from_type, const DataType& to_type);
+
+// ----------------------------------------------------------------------
+// Convenience invocation APIs for a number of kernels
+
+/// \brief Cast from one array type to another
+/// \param[in] value array to cast
+/// \param[in] to_type type to cast to
+/// \param[in] options casting options
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting array
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> Cast(const Array& value, std::shared_ptr<DataType> to_type,
+ const CastOptions& options = CastOptions::Safe(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Cast from one array type to another
+/// \param[in] value array to cast
+/// \param[in] options casting options. The "to_type" field must be populated
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting array
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Cast(const Datum& value, const CastOptions& options,
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Cast from one value to another
+/// \param[in] value datum to cast
+/// \param[in] to_type type to cast to
+/// \param[in] options casting options
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datum
+///
+/// \since 1.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> Cast(const Datum& value, std::shared_ptr<DataType> to_type,
+ const CastOptions& options = CastOptions::Safe(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Cast several values simultaneously. Safe cast options are used.
+/// \param[in] values datums to cast
+/// \param[in] descrs ValueDescrs to cast to
+/// \param[in] ctx the function execution context, optional
+/// \return the resulting datums
+///
+/// \since 4.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<std::vector<Datum>> Cast(std::vector<Datum> values, std::vector<ValueDescr> descrs,
+ ExecContext* ctx = NULLPTR);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/cast_internal.h b/src/arrow/cpp/src/arrow/compute/cast_internal.h
new file mode 100644
index 000000000..0105d08a5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/cast_internal.h
@@ -0,0 +1,43 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "arrow/compute/cast.h" // IWYU pragma: keep
+#include "arrow/compute/kernel.h" // IWYU pragma: keep
+#include "arrow/compute/kernels/codegen_internal.h" // IWYU pragma: keep
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+using CastState = OptionsWrapper<CastOptions>;
+
+// See kernels/scalar_cast_*.cc for these
+std::vector<std::shared_ptr<CastFunction>> GetBooleanCasts();
+std::vector<std::shared_ptr<CastFunction>> GetNumericCasts();
+std::vector<std::shared_ptr<CastFunction>> GetTemporalCasts();
+std::vector<std::shared_ptr<CastFunction>> GetBinaryLikeCasts();
+std::vector<std::shared_ptr<CastFunction>> GetNestedCasts();
+std::vector<std::shared_ptr<CastFunction>> GetDictionaryCasts();
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec.cc b/src/arrow/cpp/src/arrow/compute/exec.cc
new file mode 100644
index 000000000..50f1ad4fd
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec.cc
@@ -0,0 +1,1061 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <sstream>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_primitive.h"
+#include "arrow/array/data.h"
+#include "arrow/array/util.h"
+#include "arrow/buffer.h"
+#include "arrow/chunked_array.h"
+#include "arrow/compute/exec_internal.h"
+#include "arrow/compute/function.h"
+#include "arrow/compute/kernel.h"
+#include "arrow/compute/registry.h"
+#include "arrow/datum.h"
+#include "arrow/pretty_print.h"
+#include "arrow/record_batch.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/cpu_info.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/vector.h"
+
+namespace arrow {
+
+using internal::BitmapAnd;
+using internal::checked_cast;
+using internal::CopyBitmap;
+using internal::CpuInfo;
+
+namespace compute {
+
+ExecContext* default_exec_context() {
+ static ExecContext default_ctx;
+ return &default_ctx;
+}
+
+ExecBatch::ExecBatch(const RecordBatch& batch)
+ : values(batch.num_columns()), length(batch.num_rows()) {
+ auto columns = batch.column_data();
+ std::move(columns.begin(), columns.end(), values.begin());
+}
+
+bool ExecBatch::Equals(const ExecBatch& other) const {
+ return guarantee == other.guarantee && values == other.values;
+}
+
+void PrintTo(const ExecBatch& batch, std::ostream* os) {
+ *os << "ExecBatch\n";
+
+ static const std::string indent = " ";
+
+ *os << indent << "# Rows: " << batch.length << "\n";
+ if (batch.guarantee != literal(true)) {
+ *os << indent << "Guarantee: " << batch.guarantee.ToString() << "\n";
+ }
+
+ int i = 0;
+ for (const Datum& value : batch.values) {
+ *os << indent << "" << i++ << ": ";
+
+ if (value.is_scalar()) {
+ *os << "Scalar[" << value.scalar()->ToString() << "]\n";
+ continue;
+ }
+
+ auto array = value.make_array();
+ PrettyPrintOptions options;
+ options.skip_new_lines = true;
+ *os << "Array";
+ ARROW_CHECK_OK(PrettyPrint(*array, options, os));
+ *os << "\n";
+ }
+}
+
+std::string ExecBatch::ToString() const {
+ std::stringstream ss;
+ PrintTo(*this, &ss);
+ return ss.str();
+}
+
+ExecBatch ExecBatch::Slice(int64_t offset, int64_t length) const {
+ ExecBatch out = *this;
+ for (auto& value : out.values) {
+ if (value.is_scalar()) continue;
+ value = value.array()->Slice(offset, length);
+ }
+ out.length = std::min(length, this->length - offset);
+ return out;
+}
+
+Result<ExecBatch> ExecBatch::Make(std::vector<Datum> values) {
+ if (values.empty()) {
+ return Status::Invalid("Cannot infer ExecBatch length without at least one value");
+ }
+
+ int64_t length = -1;
+ for (const auto& value : values) {
+ if (value.is_scalar()) {
+ continue;
+ }
+
+ if (length == -1) {
+ length = value.length();
+ continue;
+ }
+
+ if (length != value.length()) {
+ return Status::Invalid(
+ "Arrays used to construct an ExecBatch must have equal length");
+ }
+ }
+
+ if (length == -1) {
+ length = 1;
+ }
+
+ return ExecBatch(std::move(values), length);
+}
+
+Result<std::shared_ptr<RecordBatch>> ExecBatch::ToRecordBatch(
+ std::shared_ptr<Schema> schema, MemoryPool* pool) const {
+ ArrayVector columns(schema->num_fields());
+
+ for (size_t i = 0; i < columns.size(); ++i) {
+ const Datum& value = values[i];
+ if (value.is_array()) {
+ columns[i] = value.make_array();
+ continue;
+ }
+ ARROW_ASSIGN_OR_RAISE(columns[i], MakeArrayFromScalar(*value.scalar(), length, pool));
+ }
+
+ return RecordBatch::Make(std::move(schema), length, std::move(columns));
+}
+
+namespace {
+
+Result<std::shared_ptr<Buffer>> AllocateDataBuffer(KernelContext* ctx, int64_t length,
+ int bit_width) {
+ if (bit_width == 1) {
+ return ctx->AllocateBitmap(length);
+ } else {
+ int64_t buffer_size = BitUtil::BytesForBits(length * bit_width);
+ return ctx->Allocate(buffer_size);
+ }
+}
+
+struct BufferPreallocation {
+ explicit BufferPreallocation(int bit_width = -1, int added_length = 0)
+ : bit_width(bit_width), added_length(added_length) {}
+
+ int bit_width;
+ int added_length;
+};
+
+void ComputeDataPreallocate(const DataType& type,
+ std::vector<BufferPreallocation>* widths) {
+ if (is_fixed_width(type.id()) && type.id() != Type::NA) {
+ widths->emplace_back(checked_cast<const FixedWidthType&>(type).bit_width());
+ return;
+ }
+ // Preallocate binary and list offsets
+ switch (type.id()) {
+ case Type::BINARY:
+ case Type::STRING:
+ case Type::LIST:
+ case Type::MAP:
+ widths->emplace_back(32, /*added_length=*/1);
+ return;
+ case Type::LARGE_BINARY:
+ case Type::LARGE_STRING:
+ case Type::LARGE_LIST:
+ widths->emplace_back(64, /*added_length=*/1);
+ return;
+ default:
+ break;
+ }
+}
+
+} // namespace
+
+namespace detail {
+
+Status CheckAllValues(const std::vector<Datum>& values) {
+ for (const auto& value : values) {
+ if (!value.is_value()) {
+ return Status::Invalid("Tried executing function with non-value type: ",
+ value.ToString());
+ }
+ }
+ return Status::OK();
+}
+
+ExecBatchIterator::ExecBatchIterator(std::vector<Datum> args, int64_t length,
+ int64_t max_chunksize)
+ : args_(std::move(args)),
+ position_(0),
+ length_(length),
+ max_chunksize_(max_chunksize) {
+ chunk_indexes_.resize(args_.size(), 0);
+ chunk_positions_.resize(args_.size(), 0);
+}
+
+Result<std::unique_ptr<ExecBatchIterator>> ExecBatchIterator::Make(
+ std::vector<Datum> args, int64_t max_chunksize) {
+ for (const auto& arg : args) {
+ if (!(arg.is_arraylike() || arg.is_scalar())) {
+ return Status::Invalid(
+ "ExecBatchIterator only works with Scalar, Array, and "
+ "ChunkedArray arguments");
+ }
+ }
+
+ // If the arguments are all scalars, then the length is 1
+ int64_t length = 1;
+
+ bool length_set = false;
+ for (auto& arg : args) {
+ if (arg.is_scalar()) {
+ continue;
+ }
+ if (!length_set) {
+ length = arg.length();
+ length_set = true;
+ } else {
+ if (arg.length() != length) {
+ return Status::Invalid("Array arguments must all be the same length");
+ }
+ }
+ }
+
+ max_chunksize = std::min(length, max_chunksize);
+
+ return std::unique_ptr<ExecBatchIterator>(
+ new ExecBatchIterator(std::move(args), length, max_chunksize));
+}
+
+bool ExecBatchIterator::Next(ExecBatch* batch) {
+ if (position_ == length_) {
+ return false;
+ }
+
+ // Determine how large the common contiguous "slice" of all the arguments is
+ int64_t iteration_size = std::min(length_ - position_, max_chunksize_);
+
+ // If length_ is 0, then this loop will never execute
+ for (size_t i = 0; i < args_.size() && iteration_size > 0; ++i) {
+ // If the argument is not a chunked array, it's either a Scalar or Array,
+ // in which case it doesn't influence the size of this batch. Note that if
+ // the args are all scalars the batch length is 1
+ if (args_[i].kind() != Datum::CHUNKED_ARRAY) {
+ continue;
+ }
+ const ChunkedArray& arg = *args_[i].chunked_array();
+ std::shared_ptr<Array> current_chunk;
+ while (true) {
+ current_chunk = arg.chunk(chunk_indexes_[i]);
+ if (chunk_positions_[i] == current_chunk->length()) {
+ // Chunk is zero-length, or was exhausted in the previous iteration
+ chunk_positions_[i] = 0;
+ ++chunk_indexes_[i];
+ continue;
+ }
+ break;
+ }
+ iteration_size =
+ std::min(current_chunk->length() - chunk_positions_[i], iteration_size);
+ }
+
+ // Now, fill the batch
+ batch->values.resize(args_.size());
+ batch->length = iteration_size;
+ for (size_t i = 0; i < args_.size(); ++i) {
+ if (args_[i].is_scalar()) {
+ batch->values[i] = args_[i].scalar();
+ } else if (args_[i].is_array()) {
+ batch->values[i] = args_[i].array()->Slice(position_, iteration_size);
+ } else {
+ const ChunkedArray& carr = *args_[i].chunked_array();
+ const auto& chunk = carr.chunk(chunk_indexes_[i]);
+ batch->values[i] = chunk->data()->Slice(chunk_positions_[i], iteration_size);
+ chunk_positions_[i] += iteration_size;
+ }
+ }
+ position_ += iteration_size;
+ DCHECK_LE(position_, length_);
+ return true;
+}
+
+namespace {
+
+struct NullGeneralization {
+ enum type { PERHAPS_NULL, ALL_VALID, ALL_NULL };
+
+ static type Get(const Datum& datum) {
+ if (datum.type()->id() == Type::NA) {
+ return ALL_NULL;
+ }
+
+ if (datum.is_scalar()) {
+ return datum.scalar()->is_valid ? ALL_VALID : ALL_NULL;
+ }
+
+ const auto& arr = *datum.array();
+
+ // Do not count the bits if they haven't been counted already
+ const int64_t known_null_count = arr.null_count.load();
+ if ((known_null_count == 0) || (arr.buffers[0] == NULLPTR)) {
+ return ALL_VALID;
+ }
+
+ if (known_null_count == arr.length) {
+ return ALL_NULL;
+ }
+
+ return PERHAPS_NULL;
+ }
+};
+
+// Null propagation implementation that deals both with preallocated bitmaps
+// and maybe-to-be allocated bitmaps
+//
+// If the bitmap is preallocated, it MUST be populated (since it might be a
+// view of a much larger bitmap). If it isn't preallocated, then we have
+// more flexibility.
+//
+// * If the batch has no nulls, then we do nothing
+// * If only a single array has nulls, and its offset is a multiple of 8,
+// then we can zero-copy the bitmap into the output
+// * Otherwise, we allocate the bitmap and populate it
+class NullPropagator {
+ public:
+ NullPropagator(KernelContext* ctx, const ExecBatch& batch, ArrayData* output)
+ : ctx_(ctx), batch_(batch), output_(output) {
+ for (const Datum& datum : batch_.values) {
+ auto null_generalization = NullGeneralization::Get(datum);
+
+ if (null_generalization == NullGeneralization::ALL_NULL) {
+ is_all_null_ = true;
+ }
+
+ if (null_generalization != NullGeneralization::ALL_VALID &&
+ datum.kind() == Datum::ARRAY) {
+ arrays_with_nulls_.push_back(datum.array().get());
+ }
+ }
+
+ if (output->buffers[0] != nullptr) {
+ bitmap_preallocated_ = true;
+ SetBitmap(output_->buffers[0].get());
+ }
+ }
+
+ void SetBitmap(Buffer* bitmap) { bitmap_ = bitmap->mutable_data(); }
+
+ Status EnsureAllocated() {
+ if (bitmap_preallocated_) {
+ return Status::OK();
+ }
+ ARROW_ASSIGN_OR_RAISE(output_->buffers[0], ctx_->AllocateBitmap(output_->length));
+ SetBitmap(output_->buffers[0].get());
+ return Status::OK();
+ }
+
+ Status AllNullShortCircuit() {
+ // OK, the output should be all null
+ output_->null_count = output_->length;
+
+ if (bitmap_preallocated_) {
+ BitUtil::SetBitsTo(bitmap_, output_->offset, output_->length, false);
+ return Status::OK();
+ }
+
+ // Walk all the values with nulls instead of breaking on the first in case
+ // we find a bitmap that can be reused in the non-preallocated case
+ for (const ArrayData* arr : arrays_with_nulls_) {
+ if (arr->null_count.load() == arr->length && arr->buffers[0] != nullptr) {
+ // Reuse this all null bitmap
+ output_->buffers[0] = arr->buffers[0];
+ return Status::OK();
+ }
+ }
+
+ RETURN_NOT_OK(EnsureAllocated());
+ BitUtil::SetBitsTo(bitmap_, output_->offset, output_->length, false);
+ return Status::OK();
+ }
+
+ Status PropagateSingle() {
+ // One array
+ const ArrayData& arr = *arrays_with_nulls_[0];
+ const std::shared_ptr<Buffer>& arr_bitmap = arr.buffers[0];
+
+ // Reuse the null count if it's known
+ output_->null_count = arr.null_count.load();
+
+ if (bitmap_preallocated_) {
+ CopyBitmap(arr_bitmap->data(), arr.offset, arr.length, bitmap_, output_->offset);
+ return Status::OK();
+ }
+
+ // Two cases when memory was not pre-allocated:
+ //
+ // * Offset is zero: we reuse the bitmap as is
+ // * Offset is nonzero but a multiple of 8: we can slice the bitmap
+ // * Offset is not a multiple of 8: we must allocate and use CopyBitmap
+ //
+ // Keep in mind that output_->offset is not permitted to be nonzero when
+ // the bitmap is not preallocated, and that precondition is asserted
+ // higher in the call stack.
+ if (arr.offset == 0) {
+ output_->buffers[0] = arr_bitmap;
+ } else if (arr.offset % 8 == 0) {
+ output_->buffers[0] =
+ SliceBuffer(arr_bitmap, arr.offset / 8, BitUtil::BytesForBits(arr.length));
+ } else {
+ RETURN_NOT_OK(EnsureAllocated());
+ CopyBitmap(arr_bitmap->data(), arr.offset, arr.length, bitmap_,
+ /*dst_offset=*/0);
+ }
+ return Status::OK();
+ }
+
+ Status PropagateMultiple() {
+ // More than one array. We use BitmapAnd to intersect their bitmaps
+
+ // Do not compute the intersection null count until it's needed
+ RETURN_NOT_OK(EnsureAllocated());
+
+ auto Accumulate = [&](const ArrayData& left, const ArrayData& right) {
+ DCHECK(left.buffers[0]);
+ DCHECK(right.buffers[0]);
+ BitmapAnd(left.buffers[0]->data(), left.offset, right.buffers[0]->data(),
+ right.offset, output_->length, output_->offset,
+ output_->buffers[0]->mutable_data());
+ };
+
+ DCHECK_GT(arrays_with_nulls_.size(), 1);
+
+ // Seed the output bitmap with the & of the first two bitmaps
+ Accumulate(*arrays_with_nulls_[0], *arrays_with_nulls_[1]);
+
+ // Accumulate the rest
+ for (size_t i = 2; i < arrays_with_nulls_.size(); ++i) {
+ Accumulate(*output_, *arrays_with_nulls_[i]);
+ }
+ return Status::OK();
+ }
+
+ Status Execute() {
+ if (is_all_null_) {
+ // An all-null value (scalar null or all-null array) gives us a short
+ // circuit opportunity
+ return AllNullShortCircuit();
+ }
+
+ // At this point, by construction we know that all of the values in
+ // arrays_with_nulls_ are arrays that are not all null. So there are a
+ // few cases:
+ //
+ // * No arrays. This is a no-op w/o preallocation but when the bitmap is
+ // pre-allocated we have to fill it with 1's
+ // * One array, whose bitmap can be zero-copied (w/o preallocation, and
+ // when no byte is split) or copied (split byte or w/ preallocation)
+ // * More than one array, we must compute the intersection of all the
+ // bitmaps
+ //
+ // BUT, if the output offset is nonzero for some reason, we copy into the
+ // output unconditionally
+
+ output_->null_count = kUnknownNullCount;
+
+ if (arrays_with_nulls_.empty()) {
+ // No arrays with nulls case
+ output_->null_count = 0;
+ if (bitmap_preallocated_) {
+ BitUtil::SetBitsTo(bitmap_, output_->offset, output_->length, true);
+ }
+ return Status::OK();
+ }
+
+ if (arrays_with_nulls_.size() == 1) {
+ return PropagateSingle();
+ }
+
+ return PropagateMultiple();
+ }
+
+ private:
+ KernelContext* ctx_;
+ const ExecBatch& batch_;
+ std::vector<const ArrayData*> arrays_with_nulls_;
+ bool is_all_null_ = false;
+ ArrayData* output_;
+ uint8_t* bitmap_;
+ bool bitmap_preallocated_ = false;
+};
+
+std::shared_ptr<ChunkedArray> ToChunkedArray(const std::vector<Datum>& values,
+ const std::shared_ptr<DataType>& type) {
+ std::vector<std::shared_ptr<Array>> arrays;
+ arrays.reserve(values.size());
+ for (const Datum& val : values) {
+ if (val.length() == 0) {
+ // Skip empty chunks
+ continue;
+ }
+ arrays.emplace_back(val.make_array());
+ }
+ return std::make_shared<ChunkedArray>(std::move(arrays), type);
+}
+
+bool HaveChunkedArray(const std::vector<Datum>& values) {
+ for (const auto& value : values) {
+ if (value.kind() == Datum::CHUNKED_ARRAY) {
+ return true;
+ }
+ }
+ return false;
+}
+
+template <typename KernelType>
+class KernelExecutorImpl : public KernelExecutor {
+ public:
+ Status Init(KernelContext* kernel_ctx, KernelInitArgs args) override {
+ kernel_ctx_ = kernel_ctx;
+ kernel_ = static_cast<const KernelType*>(args.kernel);
+
+ // Resolve the output descriptor for this kernel
+ ARROW_ASSIGN_OR_RAISE(
+ output_descr_, kernel_->signature->out_type().Resolve(kernel_ctx_, args.inputs));
+
+ return Status::OK();
+ }
+
+ protected:
+ // This is overridden by the VectorExecutor
+ virtual Status SetupArgIteration(const std::vector<Datum>& args) {
+ ARROW_ASSIGN_OR_RAISE(
+ batch_iterator_, ExecBatchIterator::Make(args, exec_context()->exec_chunksize()));
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<ArrayData>> PrepareOutput(int64_t length) {
+ auto out = std::make_shared<ArrayData>(output_descr_.type, length);
+ out->buffers.resize(output_num_buffers_);
+
+ if (validity_preallocated_) {
+ ARROW_ASSIGN_OR_RAISE(out->buffers[0], kernel_ctx_->AllocateBitmap(length));
+ }
+ if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) {
+ out->null_count = 0;
+ }
+ for (size_t i = 0; i < data_preallocated_.size(); ++i) {
+ const auto& prealloc = data_preallocated_[i];
+ if (prealloc.bit_width >= 0) {
+ ARROW_ASSIGN_OR_RAISE(
+ out->buffers[i + 1],
+ AllocateDataBuffer(kernel_ctx_, length + prealloc.added_length,
+ prealloc.bit_width));
+ }
+ }
+ return out;
+ }
+
+ Status CheckResultType(const Datum& out, const char* function_name) override {
+ const auto& type = out.type();
+ if (type != nullptr && !type->Equals(output_descr_.type)) {
+ return Status::TypeError(
+ "kernel type result mismatch for function '", function_name, "': declared as ",
+ output_descr_.type->ToString(), ", actual is ", type->ToString());
+ }
+ return Status::OK();
+ }
+
+ ExecContext* exec_context() { return kernel_ctx_->exec_context(); }
+ KernelState* state() { return kernel_ctx_->state(); }
+
+ // Not all of these members are used for every executor type
+
+ KernelContext* kernel_ctx_;
+ const KernelType* kernel_;
+ std::unique_ptr<ExecBatchIterator> batch_iterator_;
+ ValueDescr output_descr_;
+
+ int output_num_buffers_;
+
+ // If true, then memory is preallocated for the validity bitmap with the same
+ // strategy as the data buffer(s).
+ bool validity_preallocated_ = false;
+
+ // The kernel writes into data buffers preallocated for these bit widths
+ // (0 indicates no preallocation);
+ std::vector<BufferPreallocation> data_preallocated_;
+};
+
+class ScalarExecutor : public KernelExecutorImpl<ScalarKernel> {
+ public:
+ Status Execute(const std::vector<Datum>& args, ExecListener* listener) override {
+ RETURN_NOT_OK(PrepareExecute(args));
+ ExecBatch batch;
+ while (batch_iterator_->Next(&batch)) {
+ RETURN_NOT_OK(ExecuteBatch(batch, listener));
+ }
+ if (preallocate_contiguous_) {
+ // If we preallocated one big chunk, since the kernel execution is
+ // completed, we can now emit it
+ RETURN_NOT_OK(listener->OnResult(std::move(preallocated_)));
+ }
+ return Status::OK();
+ }
+
+ Datum WrapResults(const std::vector<Datum>& inputs,
+ const std::vector<Datum>& outputs) override {
+ if (output_descr_.shape == ValueDescr::SCALAR) {
+ DCHECK_GT(outputs.size(), 0);
+ if (outputs.size() == 1) {
+ // Return as SCALAR
+ return outputs[0];
+ } else {
+ // Return as COLLECTION
+ return outputs;
+ }
+ } else {
+ // If execution yielded multiple chunks (because large arrays were split
+ // based on the ExecContext parameters, then the result is a ChunkedArray
+ if (HaveChunkedArray(inputs) || outputs.size() > 1) {
+ return ToChunkedArray(outputs, output_descr_.type);
+ } else if (outputs.size() == 1) {
+ // Outputs have just one element
+ return outputs[0];
+ } else {
+ // XXX: In the case where no outputs are omitted, is returning a 0-length
+ // array always the correct move?
+ return MakeArrayOfNull(output_descr_.type, /*length=*/0,
+ exec_context()->memory_pool())
+ .ValueOrDie();
+ }
+ }
+ }
+
+ protected:
+ Status ExecuteBatch(const ExecBatch& batch, ExecListener* listener) {
+ Datum out;
+ RETURN_NOT_OK(PrepareNextOutput(batch, &out));
+
+ if (output_descr_.shape == ValueDescr::ARRAY) {
+ ArrayData* out_arr = out.mutable_array();
+ if (kernel_->null_handling == NullHandling::INTERSECTION) {
+ RETURN_NOT_OK(PropagateNulls(kernel_ctx_, batch, out_arr));
+ } else if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) {
+ out_arr->null_count = 0;
+ }
+ } else {
+ if (kernel_->null_handling == NullHandling::INTERSECTION) {
+ // set scalar validity
+ out.scalar()->is_valid =
+ std::all_of(batch.values.begin(), batch.values.end(),
+ [](const Datum& input) { return input.scalar()->is_valid; });
+ } else if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) {
+ out.scalar()->is_valid = true;
+ }
+ }
+
+ RETURN_NOT_OK(kernel_->exec(kernel_ctx_, batch, &out));
+ if (!preallocate_contiguous_) {
+ // If we are producing chunked output rather than one big array, then
+ // emit each chunk as soon as it's available
+ RETURN_NOT_OK(listener->OnResult(std::move(out)));
+ }
+ return Status::OK();
+ }
+
+ Status PrepareExecute(const std::vector<Datum>& args) {
+ RETURN_NOT_OK(this->SetupArgIteration(args));
+
+ if (output_descr_.shape == ValueDescr::ARRAY) {
+ // If the executor is configured to produce a single large Array output for
+ // kernels supporting preallocation, then we do so up front and then
+ // iterate over slices of that large array. Otherwise, we preallocate prior
+ // to processing each batch emitted from the ExecBatchIterator
+ RETURN_NOT_OK(SetupPreallocation(batch_iterator_->length()));
+ }
+ return Status::OK();
+ }
+
+ // We must accommodate two different modes of execution for preallocated
+ // execution
+ //
+ // * A single large ("contiguous") allocation that we populate with results
+ // on a chunkwise basis according to the ExecBatchIterator. This permits
+ // parallelization even if the objective is to obtain a single Array or
+ // ChunkedArray at the end
+ // * A standalone buffer preallocation for each chunk emitted from the
+ // ExecBatchIterator
+ //
+ // When data buffer preallocation is not possible (e.g. with BINARY / STRING
+ // outputs), then contiguous results are only possible if the input is
+ // contiguous.
+
+ Status PrepareNextOutput(const ExecBatch& batch, Datum* out) {
+ if (output_descr_.shape == ValueDescr::ARRAY) {
+ if (preallocate_contiguous_) {
+ // The output is already fully preallocated
+ const int64_t batch_start_position = batch_iterator_->position() - batch.length;
+
+ if (batch.length < batch_iterator_->length()) {
+ // If this is a partial execution, then we write into a slice of
+ // preallocated_
+ out->value = preallocated_->Slice(batch_start_position, batch.length);
+ } else {
+ // Otherwise write directly into preallocated_. The main difference
+ // computationally (versus the Slice approach) is that the null_count
+ // may not need to be recomputed in the result
+ out->value = preallocated_;
+ }
+ } else {
+ // We preallocate (maybe) only for the output of processing the current
+ // batch
+ ARROW_ASSIGN_OR_RAISE(out->value, PrepareOutput(batch.length));
+ }
+ } else {
+ // For scalar outputs, we set a null scalar of the correct type to
+ // communicate the output type to the kernel if needed
+ //
+ // XXX: Is there some way to avoid this step?
+ out->value = MakeNullScalar(output_descr_.type);
+ }
+ return Status::OK();
+ }
+
+ Status SetupPreallocation(int64_t total_length) {
+ output_num_buffers_ = static_cast<int>(output_descr_.type->layout().buffers.size());
+
+ // Decide if we need to preallocate memory for this kernel
+ validity_preallocated_ =
+ (kernel_->null_handling != NullHandling::COMPUTED_NO_PREALLOCATE &&
+ kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL &&
+ output_descr_.type->id() != Type::NA);
+ if (kernel_->mem_allocation == MemAllocation::PREALLOCATE) {
+ ComputeDataPreallocate(*output_descr_.type, &data_preallocated_);
+ }
+
+ // Contiguous preallocation only possible on non-nested types if all
+ // buffers are preallocated. Otherwise, we must go chunk-by-chunk.
+ //
+ // Some kernels are also unable to write into sliced outputs, so we respect the
+ // kernel's attributes.
+ preallocate_contiguous_ =
+ (exec_context()->preallocate_contiguous() && kernel_->can_write_into_slices &&
+ validity_preallocated_ && !is_nested(output_descr_.type->id()) &&
+ !is_dictionary(output_descr_.type->id()) &&
+ data_preallocated_.size() == static_cast<size_t>(output_num_buffers_ - 1) &&
+ std::all_of(data_preallocated_.begin(), data_preallocated_.end(),
+ [](const BufferPreallocation& prealloc) {
+ return prealloc.bit_width >= 0;
+ }));
+ if (preallocate_contiguous_) {
+ ARROW_ASSIGN_OR_RAISE(preallocated_, PrepareOutput(total_length));
+ }
+ return Status::OK();
+ }
+
+ // If true, and the kernel and output type supports preallocation (for both
+ // the validity and data buffers), then we allocate one big array and then
+ // iterate through it while executing the kernel in chunks
+ bool preallocate_contiguous_ = false;
+
+ // For storing a contiguous preallocation per above. Unused otherwise
+ std::shared_ptr<ArrayData> preallocated_;
+};
+
+Status PackBatchNoChunks(const std::vector<Datum>& args, ExecBatch* out) {
+ int64_t length = 0;
+ for (const auto& arg : args) {
+ switch (arg.kind()) {
+ case Datum::SCALAR:
+ case Datum::ARRAY:
+ case Datum::CHUNKED_ARRAY:
+ length = std::max(arg.length(), length);
+ break;
+ default:
+ DCHECK(false);
+ break;
+ }
+ }
+ out->length = length;
+ out->values = args;
+ return Status::OK();
+}
+
+class VectorExecutor : public KernelExecutorImpl<VectorKernel> {
+ public:
+ Status Execute(const std::vector<Datum>& args, ExecListener* listener) override {
+ RETURN_NOT_OK(PrepareExecute(args));
+ ExecBatch batch;
+ if (kernel_->can_execute_chunkwise) {
+ while (batch_iterator_->Next(&batch)) {
+ RETURN_NOT_OK(ExecuteBatch(batch, listener));
+ }
+ } else {
+ RETURN_NOT_OK(PackBatchNoChunks(args, &batch));
+ RETURN_NOT_OK(ExecuteBatch(batch, listener));
+ }
+ return Finalize(listener);
+ }
+
+ Datum WrapResults(const std::vector<Datum>& inputs,
+ const std::vector<Datum>& outputs) override {
+ // If execution yielded multiple chunks (because large arrays were split
+ // based on the ExecContext parameters, then the result is a ChunkedArray
+ if (kernel_->output_chunked && (HaveChunkedArray(inputs) || outputs.size() > 1)) {
+ return ToChunkedArray(outputs, output_descr_.type);
+ } else if (outputs.size() == 1) {
+ // Outputs have just one element
+ return outputs[0];
+ } else {
+ // XXX: In the case where no outputs are omitted, is returning a 0-length
+ // array always the correct move?
+ return MakeArrayOfNull(output_descr_.type, /*length=*/0).ValueOrDie();
+ }
+ }
+
+ protected:
+ Status ExecuteBatch(const ExecBatch& batch, ExecListener* listener) {
+ Datum out;
+ if (output_descr_.shape == ValueDescr::ARRAY) {
+ // We preallocate (maybe) only for the output of processing the current
+ // batch
+ ARROW_ASSIGN_OR_RAISE(out.value, PrepareOutput(batch.length));
+ }
+
+ if (kernel_->null_handling == NullHandling::INTERSECTION &&
+ output_descr_.shape == ValueDescr::ARRAY) {
+ RETURN_NOT_OK(PropagateNulls(kernel_ctx_, batch, out.mutable_array()));
+ }
+ RETURN_NOT_OK(kernel_->exec(kernel_ctx_, batch, &out));
+ if (!kernel_->finalize) {
+ // If there is no result finalizer (e.g. for hash-based functions, we can
+ // emit the processed batch right away rather than waiting
+ RETURN_NOT_OK(listener->OnResult(std::move(out)));
+ } else {
+ results_.emplace_back(std::move(out));
+ }
+ return Status::OK();
+ }
+
+ Status Finalize(ExecListener* listener) {
+ if (kernel_->finalize) {
+ // Intermediate results require post-processing after the execution is
+ // completed (possibly involving some accumulated state)
+ RETURN_NOT_OK(kernel_->finalize(kernel_ctx_, &results_));
+ for (const auto& result : results_) {
+ RETURN_NOT_OK(listener->OnResult(result));
+ }
+ }
+ return Status::OK();
+ }
+
+ Status SetupArgIteration(const std::vector<Datum>& args) override {
+ if (kernel_->can_execute_chunkwise) {
+ ARROW_ASSIGN_OR_RAISE(batch_iterator_, ExecBatchIterator::Make(
+ args, exec_context()->exec_chunksize()));
+ }
+ return Status::OK();
+ }
+
+ Status PrepareExecute(const std::vector<Datum>& args) {
+ RETURN_NOT_OK(this->SetupArgIteration(args));
+ output_num_buffers_ = static_cast<int>(output_descr_.type->layout().buffers.size());
+
+ // Decide if we need to preallocate memory for this kernel
+ validity_preallocated_ =
+ (kernel_->null_handling != NullHandling::COMPUTED_NO_PREALLOCATE &&
+ kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL);
+ if (kernel_->mem_allocation == MemAllocation::PREALLOCATE) {
+ ComputeDataPreallocate(*output_descr_.type, &data_preallocated_);
+ }
+ return Status::OK();
+ }
+
+ std::vector<Datum> results_;
+};
+
+class ScalarAggExecutor : public KernelExecutorImpl<ScalarAggregateKernel> {
+ public:
+ Status Init(KernelContext* ctx, KernelInitArgs args) override {
+ input_descrs_ = &args.inputs;
+ options_ = args.options;
+ return KernelExecutorImpl<ScalarAggregateKernel>::Init(ctx, args);
+ }
+
+ Status Execute(const std::vector<Datum>& args, ExecListener* listener) override {
+ RETURN_NOT_OK(this->SetupArgIteration(args));
+
+ ExecBatch batch;
+ while (batch_iterator_->Next(&batch)) {
+ // TODO: implement parallelism
+ if (batch.length > 0) {
+ RETURN_NOT_OK(Consume(batch));
+ }
+ }
+
+ Datum out;
+ RETURN_NOT_OK(kernel_->finalize(kernel_ctx_, &out));
+ RETURN_NOT_OK(listener->OnResult(std::move(out)));
+ return Status::OK();
+ }
+
+ Datum WrapResults(const std::vector<Datum>&,
+ const std::vector<Datum>& outputs) override {
+ DCHECK_EQ(1, outputs.size());
+ return outputs[0];
+ }
+
+ private:
+ Status Consume(const ExecBatch& batch) {
+ // FIXME(ARROW-11840) don't merge *any* aggegates for every batch
+ ARROW_ASSIGN_OR_RAISE(
+ auto batch_state,
+ kernel_->init(kernel_ctx_, {kernel_, *input_descrs_, options_}));
+
+ if (batch_state == nullptr) {
+ return Status::Invalid("ScalarAggregation requires non-null kernel state");
+ }
+
+ KernelContext batch_ctx(exec_context());
+ batch_ctx.SetState(batch_state.get());
+
+ RETURN_NOT_OK(kernel_->consume(&batch_ctx, batch));
+ RETURN_NOT_OK(kernel_->merge(kernel_ctx_, std::move(*batch_state), state()));
+ return Status::OK();
+ }
+
+ const std::vector<ValueDescr>* input_descrs_;
+ const FunctionOptions* options_;
+};
+
+template <typename ExecutorType,
+ typename FunctionType = typename ExecutorType::FunctionType>
+Result<std::unique_ptr<KernelExecutor>> MakeExecutor(ExecContext* ctx,
+ const Function* func,
+ const FunctionOptions* options) {
+ DCHECK_EQ(ExecutorType::function_kind, func->kind());
+ auto typed_func = checked_cast<const FunctionType*>(func);
+ return std::unique_ptr<KernelExecutor>(new ExecutorType(ctx, typed_func, options));
+}
+
+} // namespace
+
+Status PropagateNulls(KernelContext* ctx, const ExecBatch& batch, ArrayData* output) {
+ DCHECK_NE(nullptr, output);
+ DCHECK_GT(output->buffers.size(), 0);
+
+ if (output->type->id() == Type::NA) {
+ // Null output type is a no-op (rare when this would happen but we at least
+ // will test for it)
+ return Status::OK();
+ }
+
+ // This function is ONLY able to write into output with non-zero offset
+ // when the bitmap is preallocated. This could be a DCHECK but returning
+ // error Status for now for emphasis
+ if (output->offset != 0 && output->buffers[0] == nullptr) {
+ return Status::Invalid(
+ "Can only propagate nulls into pre-allocated memory "
+ "when the output offset is non-zero");
+ }
+ NullPropagator propagator(ctx, batch, output);
+ return propagator.Execute();
+}
+
+std::unique_ptr<KernelExecutor> KernelExecutor::MakeScalar() {
+ return ::arrow::internal::make_unique<detail::ScalarExecutor>();
+}
+
+std::unique_ptr<KernelExecutor> KernelExecutor::MakeVector() {
+ return ::arrow::internal::make_unique<detail::VectorExecutor>();
+}
+
+std::unique_ptr<KernelExecutor> KernelExecutor::MakeScalarAggregate() {
+ return ::arrow::internal::make_unique<detail::ScalarAggExecutor>();
+}
+
+} // namespace detail
+
+ExecContext::ExecContext(MemoryPool* pool, ::arrow::internal::Executor* executor,
+ FunctionRegistry* func_registry)
+ : pool_(pool), executor_(executor) {
+ this->func_registry_ = func_registry == nullptr ? GetFunctionRegistry() : func_registry;
+}
+
+CpuInfo* ExecContext::cpu_info() const { return CpuInfo::GetInstance(); }
+
+// ----------------------------------------------------------------------
+// SelectionVector
+
+SelectionVector::SelectionVector(std::shared_ptr<ArrayData> data)
+ : data_(std::move(data)) {
+ DCHECK_EQ(Type::INT32, data_->type->id());
+ DCHECK_EQ(0, data_->GetNullCount());
+ indices_ = data_->GetValues<int32_t>(1);
+}
+
+SelectionVector::SelectionVector(const Array& arr) : SelectionVector(arr.data()) {}
+
+int32_t SelectionVector::length() const { return static_cast<int32_t>(data_->length); }
+
+Result<std::shared_ptr<SelectionVector>> SelectionVector::FromMask(
+ const BooleanArray& arr) {
+ return Status::NotImplemented("FromMask");
+}
+
+Result<Datum> CallFunction(const std::string& func_name, const std::vector<Datum>& args,
+ const FunctionOptions* options, ExecContext* ctx) {
+ if (ctx == nullptr) {
+ ExecContext default_ctx;
+ return CallFunction(func_name, args, options, &default_ctx);
+ }
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<const Function> func,
+ ctx->func_registry()->GetFunction(func_name));
+ return func->Execute(args, options, ctx);
+}
+
+Result<Datum> CallFunction(const std::string& func_name, const std::vector<Datum>& args,
+ ExecContext* ctx) {
+ return CallFunction(func_name, args, /*options=*/nullptr, ctx);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec.h b/src/arrow/cpp/src/arrow/compute/exec.h
new file mode 100644
index 000000000..7707622bc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec.h
@@ -0,0 +1,268 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// NOTE: API is EXPERIMENTAL and will change without going through a
+// deprecation cycle
+
+#pragma once
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/data.h"
+#include "arrow/compute/exec/expression.h"
+#include "arrow/datum.h"
+#include "arrow/memory_pool.h"
+#include "arrow/result.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace internal {
+
+class CpuInfo;
+
+} // namespace internal
+
+namespace compute {
+
+class FunctionOptions;
+class FunctionRegistry;
+
+// It seems like 64K might be a good default chunksize to use for execution
+// based on the experience of other query processing systems. The current
+// default is not to chunk contiguous arrays, though, but this may change in
+// the future once parallel execution is implemented
+static constexpr int64_t kDefaultExecChunksize = UINT16_MAX;
+
+/// \brief Context for expression-global variables and options used by
+/// function evaluation
+class ARROW_EXPORT ExecContext {
+ public:
+ // If no function registry passed, the default is used.
+ explicit ExecContext(MemoryPool* pool = default_memory_pool(),
+ ::arrow::internal::Executor* executor = NULLPTR,
+ FunctionRegistry* func_registry = NULLPTR);
+
+ /// \brief The MemoryPool used for allocations, default is
+ /// default_memory_pool().
+ MemoryPool* memory_pool() const { return pool_; }
+
+ ::arrow::internal::CpuInfo* cpu_info() const;
+
+ /// \brief An Executor which may be used to parallelize execution.
+ ::arrow::internal::Executor* executor() const { return executor_; }
+
+ /// \brief The FunctionRegistry for looking up functions by name and
+ /// selecting kernels for execution. Defaults to the library-global function
+ /// registry provided by GetFunctionRegistry.
+ FunctionRegistry* func_registry() const { return func_registry_; }
+
+ // \brief Set maximum length unit of work for kernel execution. Larger
+ // contiguous array inputs will be split into smaller chunks, and, if
+ // possible and enabled, processed in parallel. The default chunksize is
+ // INT64_MAX, so contiguous arrays are not split.
+ void set_exec_chunksize(int64_t chunksize) { exec_chunksize_ = chunksize; }
+
+ // \brief Maximum length for ExecBatch data chunks processed by
+ // kernels. Contiguous array inputs with longer length will be split into
+ // smaller chunks.
+ int64_t exec_chunksize() const { return exec_chunksize_; }
+
+ /// \brief Set whether to use multiple threads for function execution. This
+ /// is not yet used.
+ void set_use_threads(bool use_threads = true) { use_threads_ = use_threads; }
+
+ /// \brief If true, then utilize multiple threads where relevant for function
+ /// execution. This is not yet used.
+ bool use_threads() const { return use_threads_; }
+
+ // Set the preallocation strategy for kernel execution as it relates to
+ // chunked execution. For chunked execution, whether via ChunkedArray inputs
+ // or splitting larger Array arguments into smaller pieces, contiguous
+ // allocation (if permitted by the kernel) will allocate one large array to
+ // write output into yielding it to the caller at the end. If this option is
+ // set to off, then preallocations will be performed independently for each
+ // chunk of execution
+ //
+ // TODO: At some point we might want the limit the size of contiguous
+ // preallocations. For example, even if the exec_chunksize is 64K or less, we
+ // might limit contiguous allocations to 1M records, say.
+ void set_preallocate_contiguous(bool preallocate) {
+ preallocate_contiguous_ = preallocate;
+ }
+
+ /// \brief If contiguous preallocations should be used when doing chunked
+ /// execution as specified by exec_chunksize(). See
+ /// set_preallocate_contiguous() for more information.
+ bool preallocate_contiguous() const { return preallocate_contiguous_; }
+
+ private:
+ MemoryPool* pool_;
+ ::arrow::internal::Executor* executor_;
+ FunctionRegistry* func_registry_;
+ int64_t exec_chunksize_ = std::numeric_limits<int64_t>::max();
+ bool preallocate_contiguous_ = true;
+ bool use_threads_ = true;
+};
+
+ARROW_EXPORT ExecContext* default_exec_context();
+
+// TODO: Consider standardizing on uint16 selection vectors and only use them
+// when we can ensure that each value is 64K length or smaller
+
+/// \brief Container for an array of value selection indices that were
+/// materialized from a filter.
+///
+/// Columnar query engines (see e.g. [1]) have found that rather than
+/// materializing filtered data, the filter can instead be converted to an
+/// array of the "on" indices and then "fusing" these indices in operator
+/// implementations. This is especially relevant for aggregations but also
+/// applies to scalar operations.
+///
+/// We are not yet using this so this is mostly a placeholder for now.
+///
+/// [1]: http://cidrdb.org/cidr2005/papers/P19.pdf
+class ARROW_EXPORT SelectionVector {
+ public:
+ explicit SelectionVector(std::shared_ptr<ArrayData> data);
+
+ explicit SelectionVector(const Array& arr);
+
+ /// \brief Create SelectionVector from boolean mask
+ static Result<std::shared_ptr<SelectionVector>> FromMask(const BooleanArray& arr);
+
+ const int32_t* indices() const { return indices_; }
+ int32_t length() const;
+
+ private:
+ std::shared_ptr<ArrayData> data_;
+ const int32_t* indices_;
+};
+
+/// \brief A unit of work for kernel execution. It contains a collection of
+/// Array and Scalar values and an optional SelectionVector indicating that
+/// there is an unmaterialized filter that either must be materialized, or (if
+/// the kernel supports it) pushed down into the kernel implementation.
+///
+/// ExecBatch is semantically similar to RecordBatch in that in a SQL context
+/// it represents a collection of records, but constant "columns" are
+/// represented by Scalar values rather than having to be converted into arrays
+/// with repeated values.
+///
+/// TODO: Datum uses arrow/util/variant.h which may be a bit heavier-weight
+/// than is desirable for this class. Microbenchmarks would help determine for
+/// sure. See ARROW-8928.
+struct ARROW_EXPORT ExecBatch {
+ ExecBatch() = default;
+ ExecBatch(std::vector<Datum> values, int64_t length)
+ : values(std::move(values)), length(length) {}
+
+ explicit ExecBatch(const RecordBatch& batch);
+
+ static Result<ExecBatch> Make(std::vector<Datum> values);
+
+ Result<std::shared_ptr<RecordBatch>> ToRecordBatch(
+ std::shared_ptr<Schema> schema, MemoryPool* pool = default_memory_pool()) const;
+
+ /// The values representing positional arguments to be passed to a kernel's
+ /// exec function for processing.
+ std::vector<Datum> values;
+
+ /// A deferred filter represented as an array of indices into the values.
+ ///
+ /// For example, the filter [true, true, false, true] would be represented as
+ /// the selection vector [0, 1, 3]. When the selection vector is set,
+ /// ExecBatch::length is equal to the length of this array.
+ std::shared_ptr<SelectionVector> selection_vector;
+
+ /// A predicate Expression guaranteed to evaluate to true for all rows in this batch.
+ Expression guarantee = literal(true);
+
+ /// The semantic length of the ExecBatch. When the values are all scalars,
+ /// the length should be set to 1 for non-aggregate kernels, otherwise the
+ /// length is taken from the array values, except when there is a selection
+ /// vector. When there is a selection vector set, the length of the batch is
+ /// the length of the selection. Aggregate kernels can have an ExecBatch
+ /// formed by projecting just the partition columns from a batch in which
+ /// case, it would have scalar rows with length greater than 1.
+ ///
+ /// If the array values are of length 0 then the length is 0 regardless of
+ /// whether any values are Scalar. In general ExecBatch objects are produced
+ /// by ExecBatchIterator which by design does not yield length-0 batches.
+ int64_t length;
+
+ /// \brief Return the value at the i-th index
+ template <typename index_type>
+ inline const Datum& operator[](index_type i) const {
+ return values[i];
+ }
+
+ bool Equals(const ExecBatch& other) const;
+
+ /// \brief A convenience for the number of values / arguments.
+ int num_values() const { return static_cast<int>(values.size()); }
+
+ ExecBatch Slice(int64_t offset, int64_t length) const;
+
+ /// \brief A convenience for returning the ValueDescr objects (types and
+ /// shapes) from the batch.
+ std::vector<ValueDescr> GetDescriptors() const {
+ std::vector<ValueDescr> result;
+ for (const auto& value : this->values) {
+ result.emplace_back(value.descr());
+ }
+ return result;
+ }
+
+ std::string ToString() const;
+
+ ARROW_EXPORT friend void PrintTo(const ExecBatch&, std::ostream*);
+};
+
+inline bool operator==(const ExecBatch& l, const ExecBatch& r) { return l.Equals(r); }
+inline bool operator!=(const ExecBatch& l, const ExecBatch& r) { return !l.Equals(r); }
+
+/// \defgroup compute-call-function One-shot calls to compute functions
+///
+/// @{
+
+/// \brief One-shot invoker for all types of functions.
+///
+/// Does kernel dispatch, argument checking, iteration of ChunkedArray inputs,
+/// and wrapping of outputs.
+ARROW_EXPORT
+Result<Datum> CallFunction(const std::string& func_name, const std::vector<Datum>& args,
+ const FunctionOptions* options, ExecContext* ctx = NULLPTR);
+
+/// \brief Variant of CallFunction which uses a function's default options.
+///
+/// NB: Some functions require FunctionOptions be provided.
+ARROW_EXPORT
+Result<Datum> CallFunction(const std::string& func_name, const std::vector<Datum>& args,
+ ExecContext* ctx = NULLPTR);
+
+/// @}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/CMakeLists.txt b/src/arrow/cpp/src/arrow/compute/exec/CMakeLists.txt
new file mode 100644
index 000000000..ccc36c093
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/CMakeLists.txt
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+arrow_install_all_headers("arrow/compute/exec")
+
+add_arrow_compute_test(expression_test
+ PREFIX
+ "arrow-compute"
+ SOURCES
+ expression_test.cc
+ subtree_test.cc)
+
+add_arrow_compute_test(plan_test PREFIX "arrow-compute")
+add_arrow_compute_test(hash_join_node_test PREFIX "arrow-compute")
+add_arrow_compute_test(union_node_test PREFIX "arrow-compute")
+
+add_arrow_compute_test(util_test PREFIX "arrow-compute")
+
+add_arrow_benchmark(expression_benchmark PREFIX "arrow-compute")
diff --git a/src/arrow/cpp/src/arrow/compute/exec/aggregate_node.cc b/src/arrow/cpp/src/arrow/compute/exec/aggregate_node.cc
new file mode 100644
index 000000000..ddf6f7934
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/aggregate_node.cc
@@ -0,0 +1,644 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/exec_plan.h"
+
+#include <mutex>
+#include <sstream>
+#include <thread>
+#include <unordered_map>
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/compute/exec_internal.h"
+#include "arrow/compute/registry.h"
+#include "arrow/datum.h"
+#include "arrow/result.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+namespace internal {
+
+Result<std::vector<const HashAggregateKernel*>> GetKernels(
+ ExecContext* ctx, const std::vector<internal::Aggregate>& aggregates,
+ const std::vector<ValueDescr>& in_descrs);
+
+Result<std::vector<std::unique_ptr<KernelState>>> InitKernels(
+ const std::vector<const HashAggregateKernel*>& kernels, ExecContext* ctx,
+ const std::vector<internal::Aggregate>& aggregates,
+ const std::vector<ValueDescr>& in_descrs);
+
+Result<FieldVector> ResolveKernels(
+ const std::vector<internal::Aggregate>& aggregates,
+ const std::vector<const HashAggregateKernel*>& kernels,
+ const std::vector<std::unique_ptr<KernelState>>& states, ExecContext* ctx,
+ const std::vector<ValueDescr>& descrs);
+
+} // namespace internal
+
+namespace {
+
+void AggregatesToString(
+ std::stringstream* ss, const Schema& input_schema,
+ const std::vector<internal::Aggregate>& aggs,
+ const std::vector<int>& target_field_ids,
+ const std::vector<std::unique_ptr<FunctionOptions>>& owned_options) {
+ *ss << "aggregates=[" << std::endl;
+ for (size_t i = 0; i < aggs.size(); i++) {
+ *ss << '\t' << aggs[i].function << '('
+ << input_schema.field(target_field_ids[i])->name();
+ if (owned_options[i]) {
+ *ss << ", " << owned_options[i]->ToString();
+ }
+ *ss << ")," << std::endl;
+ }
+ *ss << ']';
+}
+
+class ScalarAggregateNode : public ExecNode {
+ public:
+ ScalarAggregateNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ std::shared_ptr<Schema> output_schema,
+ std::vector<int> target_field_ids,
+ std::vector<internal::Aggregate> aggs,
+ std::vector<const ScalarAggregateKernel*> kernels,
+ std::vector<std::vector<std::unique_ptr<KernelState>>> states,
+ std::vector<std::unique_ptr<FunctionOptions>> owned_options)
+ : ExecNode(plan, std::move(inputs), {"target"},
+ /*output_schema=*/std::move(output_schema),
+ /*num_outputs=*/1),
+ target_field_ids_(std::move(target_field_ids)),
+ aggs_(std::move(aggs)),
+ kernels_(std::move(kernels)),
+ states_(std::move(states)),
+ owned_options_(std::move(owned_options)) {}
+
+ static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options) {
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "ScalarAggregateNode"));
+
+ const auto& aggregate_options = checked_cast<const AggregateNodeOptions&>(options);
+ auto aggregates = aggregate_options.aggregates;
+
+ const auto& input_schema = *inputs[0]->output_schema();
+ auto exec_ctx = plan->exec_context();
+
+ std::vector<const ScalarAggregateKernel*> kernels(aggregates.size());
+ std::vector<std::vector<std::unique_ptr<KernelState>>> states(kernels.size());
+ FieldVector fields(kernels.size());
+ const auto& field_names = aggregate_options.names;
+ std::vector<int> target_field_ids(kernels.size());
+ std::vector<std::unique_ptr<FunctionOptions>> owned_options(aggregates.size());
+
+ for (size_t i = 0; i < kernels.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto match,
+ FieldRef(aggregate_options.targets[i]).FindOne(input_schema));
+ target_field_ids[i] = match[0];
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto function, exec_ctx->func_registry()->GetFunction(aggregates[i].function));
+
+ if (function->kind() != Function::SCALAR_AGGREGATE) {
+ return Status::Invalid("Provided non ScalarAggregateFunction ",
+ aggregates[i].function);
+ }
+
+ auto in_type = ValueDescr::Array(input_schema.field(target_field_ids[i])->type());
+
+ ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, function->DispatchExact({in_type}));
+ kernels[i] = static_cast<const ScalarAggregateKernel*>(kernel);
+
+ if (aggregates[i].options == nullptr) {
+ aggregates[i].options = function->default_options();
+ }
+ if (aggregates[i].options) {
+ owned_options[i] = aggregates[i].options->Copy();
+ aggregates[i].options = owned_options[i].get();
+ }
+
+ KernelContext kernel_ctx{exec_ctx};
+ states[i].resize(ThreadIndexer::Capacity());
+ RETURN_NOT_OK(Kernel::InitAll(&kernel_ctx,
+ KernelInitArgs{kernels[i],
+ {
+ in_type,
+ },
+ aggregates[i].options},
+ &states[i]));
+
+ // pick one to resolve the kernel signature
+ kernel_ctx.SetState(states[i][0].get());
+ ARROW_ASSIGN_OR_RAISE(
+ auto descr, kernels[i]->signature->out_type().Resolve(&kernel_ctx, {in_type}));
+
+ fields[i] = field(field_names[i], std::move(descr.type));
+ }
+
+ return plan->EmplaceNode<ScalarAggregateNode>(
+ plan, std::move(inputs), schema(std::move(fields)), std::move(target_field_ids),
+ std::move(aggregates), std::move(kernels), std::move(states),
+ std::move(owned_options));
+ }
+
+ const char* kind_name() const override { return "ScalarAggregateNode"; }
+
+ Status DoConsume(const ExecBatch& batch, size_t thread_index) {
+ for (size_t i = 0; i < kernels_.size(); ++i) {
+ KernelContext batch_ctx{plan()->exec_context()};
+ batch_ctx.SetState(states_[i][thread_index].get());
+
+ ExecBatch single_column_batch{{batch.values[target_field_ids_[i]]}, batch.length};
+ RETURN_NOT_OK(kernels_[i]->consume(&batch_ctx, single_column_batch));
+ }
+ return Status::OK();
+ }
+
+ void InputReceived(ExecNode* input, ExecBatch batch) override {
+ DCHECK_EQ(input, inputs_[0]);
+
+ auto thread_index = get_thread_index_();
+
+ if (ErrorIfNotOk(DoConsume(std::move(batch), thread_index))) return;
+
+ if (input_counter_.Increment()) {
+ ErrorIfNotOk(Finish());
+ }
+ }
+
+ void ErrorReceived(ExecNode* input, Status error) override {
+ DCHECK_EQ(input, inputs_[0]);
+ outputs_[0]->ErrorReceived(this, std::move(error));
+ }
+
+ void InputFinished(ExecNode* input, int total_batches) override {
+ DCHECK_EQ(input, inputs_[0]);
+
+ if (input_counter_.SetTotal(total_batches)) {
+ ErrorIfNotOk(Finish());
+ }
+ }
+
+ Status StartProducing() override {
+ finished_ = Future<>::Make();
+ // Scalar aggregates will only output a single batch
+ outputs_[0]->InputFinished(this, 1);
+ return Status::OK();
+ }
+
+ void PauseProducing(ExecNode* output) override {}
+
+ void ResumeProducing(ExecNode* output) override {}
+
+ void StopProducing(ExecNode* output) override {
+ DCHECK_EQ(output, outputs_[0]);
+ StopProducing();
+ }
+
+ void StopProducing() override {
+ if (input_counter_.Cancel()) {
+ finished_.MarkFinished();
+ }
+ inputs_[0]->StopProducing(this);
+ }
+
+ Future<> finished() override { return finished_; }
+
+ protected:
+ std::string ToStringExtra() const override {
+ std::stringstream ss;
+ const auto input_schema = inputs_[0]->output_schema();
+ AggregatesToString(&ss, *input_schema, aggs_, target_field_ids_, owned_options_);
+ return ss.str();
+ }
+
+ private:
+ Status Finish() {
+ ExecBatch batch{{}, 1};
+ batch.values.resize(kernels_.size());
+
+ for (size_t i = 0; i < kernels_.size(); ++i) {
+ KernelContext ctx{plan()->exec_context()};
+ ARROW_ASSIGN_OR_RAISE(auto merged, ScalarAggregateKernel::MergeAll(
+ kernels_[i], &ctx, std::move(states_[i])));
+ RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i]));
+ }
+
+ outputs_[0]->InputReceived(this, std::move(batch));
+ finished_.MarkFinished();
+ return Status::OK();
+ }
+
+ Future<> finished_ = Future<>::MakeFinished();
+ const std::vector<int> target_field_ids_;
+ const std::vector<internal::Aggregate> aggs_;
+ const std::vector<const ScalarAggregateKernel*> kernels_;
+
+ std::vector<std::vector<std::unique_ptr<KernelState>>> states_;
+ const std::vector<std::unique_ptr<FunctionOptions>> owned_options_;
+
+ ThreadIndexer get_thread_index_;
+ AtomicCounter input_counter_;
+};
+
+class GroupByNode : public ExecNode {
+ public:
+ GroupByNode(ExecNode* input, std::shared_ptr<Schema> output_schema, ExecContext* ctx,
+ std::vector<int> key_field_ids, std::vector<int> agg_src_field_ids,
+ std::vector<internal::Aggregate> aggs,
+ std::vector<const HashAggregateKernel*> agg_kernels,
+ std::vector<std::unique_ptr<FunctionOptions>> owned_options)
+ : ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema),
+ /*num_outputs=*/1),
+ ctx_(ctx),
+ key_field_ids_(std::move(key_field_ids)),
+ agg_src_field_ids_(std::move(agg_src_field_ids)),
+ aggs_(std::move(aggs)),
+ agg_kernels_(std::move(agg_kernels)),
+ owned_options_(std::move(owned_options)) {}
+
+ static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options) {
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "GroupByNode"));
+
+ auto input = inputs[0];
+ const auto& aggregate_options = checked_cast<const AggregateNodeOptions&>(options);
+ const auto& keys = aggregate_options.keys;
+ // Copy (need to modify options pointer below)
+ auto aggs = aggregate_options.aggregates;
+ const auto& field_names = aggregate_options.names;
+
+ // Get input schema
+ auto input_schema = input->output_schema();
+
+ // Find input field indices for key fields
+ std::vector<int> key_field_ids(keys.size());
+ for (size_t i = 0; i < keys.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(*input_schema));
+ key_field_ids[i] = match[0];
+ }
+
+ // Find input field indices for aggregates
+ std::vector<int> agg_src_field_ids(aggs.size());
+ for (size_t i = 0; i < aggs.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto match,
+ aggregate_options.targets[i].FindOne(*input_schema));
+ agg_src_field_ids[i] = match[0];
+ }
+
+ // Build vector of aggregate source field data types
+ DCHECK_EQ(aggregate_options.targets.size(), aggs.size());
+ std::vector<ValueDescr> agg_src_descrs(aggs.size());
+ for (size_t i = 0; i < aggs.size(); ++i) {
+ auto agg_src_field_id = agg_src_field_ids[i];
+ agg_src_descrs[i] =
+ ValueDescr(input_schema->field(agg_src_field_id)->type(), ValueDescr::ARRAY);
+ }
+
+ auto ctx = input->plan()->exec_context();
+
+ // Construct aggregates
+ ARROW_ASSIGN_OR_RAISE(auto agg_kernels,
+ internal::GetKernels(ctx, aggs, agg_src_descrs));
+
+ ARROW_ASSIGN_OR_RAISE(auto agg_states,
+ internal::InitKernels(agg_kernels, ctx, aggs, agg_src_descrs));
+
+ ARROW_ASSIGN_OR_RAISE(
+ FieldVector agg_result_fields,
+ internal::ResolveKernels(aggs, agg_kernels, agg_states, ctx, agg_src_descrs));
+
+ // Build field vector for output schema
+ FieldVector output_fields{keys.size() + aggs.size()};
+
+ // Aggregate fields come before key fields to match the behavior of GroupBy function
+ for (size_t i = 0; i < aggs.size(); ++i) {
+ output_fields[i] = agg_result_fields[i]->WithName(field_names[i]);
+ }
+ size_t base = aggs.size();
+ for (size_t i = 0; i < keys.size(); ++i) {
+ int key_field_id = key_field_ids[i];
+ output_fields[base + i] = input_schema->field(key_field_id);
+ }
+
+ std::vector<std::unique_ptr<FunctionOptions>> owned_options;
+ owned_options.reserve(aggs.size());
+ for (auto& agg : aggs) {
+ owned_options.push_back(agg.options ? agg.options->Copy() : nullptr);
+ agg.options = owned_options.back().get();
+ }
+
+ return input->plan()->EmplaceNode<GroupByNode>(
+ input, schema(std::move(output_fields)), ctx, std::move(key_field_ids),
+ std::move(agg_src_field_ids), std::move(aggs), std::move(agg_kernels),
+ std::move(owned_options));
+ }
+
+ const char* kind_name() const override { return "GroupByNode"; }
+
+ Status Consume(ExecBatch batch) {
+ size_t thread_index = get_thread_index_();
+ if (thread_index >= local_states_.size()) {
+ return Status::IndexError("thread index ", thread_index, " is out of range [0, ",
+ local_states_.size(), ")");
+ }
+
+ auto state = &local_states_[thread_index];
+ RETURN_NOT_OK(InitLocalStateIfNeeded(state));
+
+ // Create a batch with key columns
+ std::vector<Datum> keys(key_field_ids_.size());
+ for (size_t i = 0; i < key_field_ids_.size(); ++i) {
+ keys[i] = batch.values[key_field_ids_[i]];
+ }
+ ExecBatch key_batch(std::move(keys), batch.length);
+
+ // Create a batch with group ids
+ ARROW_ASSIGN_OR_RAISE(Datum id_batch, state->grouper->Consume(key_batch));
+
+ // Execute aggregate kernels
+ for (size_t i = 0; i < agg_kernels_.size(); ++i) {
+ KernelContext kernel_ctx{ctx_};
+ kernel_ctx.SetState(state->agg_states[i].get());
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto agg_batch,
+ ExecBatch::Make({batch.values[agg_src_field_ids_[i]], id_batch}));
+
+ RETURN_NOT_OK(agg_kernels_[i]->resize(&kernel_ctx, state->grouper->num_groups()));
+ RETURN_NOT_OK(agg_kernels_[i]->consume(&kernel_ctx, agg_batch));
+ }
+
+ return Status::OK();
+ }
+
+ Status Merge() {
+ ThreadLocalState* state0 = &local_states_[0];
+ for (size_t i = 1; i < local_states_.size(); ++i) {
+ ThreadLocalState* state = &local_states_[i];
+ if (!state->grouper) {
+ continue;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, state->grouper->GetUniques());
+ ARROW_ASSIGN_OR_RAISE(Datum transposition, state0->grouper->Consume(other_keys));
+ state->grouper.reset();
+
+ for (size_t i = 0; i < agg_kernels_.size(); ++i) {
+ KernelContext batch_ctx{ctx_};
+ DCHECK(state0->agg_states[i]);
+ batch_ctx.SetState(state0->agg_states[i].get());
+
+ RETURN_NOT_OK(agg_kernels_[i]->resize(&batch_ctx, state0->grouper->num_groups()));
+ RETURN_NOT_OK(agg_kernels_[i]->merge(&batch_ctx, std::move(*state->agg_states[i]),
+ *transposition.array()));
+ state->agg_states[i].reset();
+ }
+ }
+ return Status::OK();
+ }
+
+ Result<ExecBatch> Finalize() {
+ ThreadLocalState* state = &local_states_[0];
+ // If we never got any batches, then state won't have been initialized
+ RETURN_NOT_OK(InitLocalStateIfNeeded(state));
+
+ ExecBatch out_data{{}, state->grouper->num_groups()};
+ out_data.values.resize(agg_kernels_.size() + key_field_ids_.size());
+
+ // Aggregate fields come before key fields to match the behavior of GroupBy function
+ for (size_t i = 0; i < agg_kernels_.size(); ++i) {
+ KernelContext batch_ctx{ctx_};
+ batch_ctx.SetState(state->agg_states[i].get());
+ RETURN_NOT_OK(agg_kernels_[i]->finalize(&batch_ctx, &out_data.values[i]));
+ state->agg_states[i].reset();
+ }
+
+ ARROW_ASSIGN_OR_RAISE(ExecBatch out_keys, state->grouper->GetUniques());
+ std::move(out_keys.values.begin(), out_keys.values.end(),
+ out_data.values.begin() + agg_kernels_.size());
+ state->grouper.reset();
+
+ if (output_counter_.SetTotal(
+ static_cast<int>(BitUtil::CeilDiv(out_data.length, output_batch_size())))) {
+ // this will be hit if out_data.length == 0
+ finished_.MarkFinished();
+ }
+ return out_data;
+ }
+
+ void OutputNthBatch(int n) {
+ // bail if StopProducing was called
+ if (finished_.is_finished()) return;
+
+ int64_t batch_size = output_batch_size();
+ outputs_[0]->InputReceived(this, out_data_.Slice(batch_size * n, batch_size));
+
+ if (output_counter_.Increment()) {
+ finished_.MarkFinished();
+ }
+ }
+
+ Status OutputResult() {
+ RETURN_NOT_OK(Merge());
+ ARROW_ASSIGN_OR_RAISE(out_data_, Finalize());
+
+ int num_output_batches = *output_counter_.total();
+ outputs_[0]->InputFinished(this, num_output_batches);
+
+ auto executor = ctx_->executor();
+ for (int i = 0; i < num_output_batches; ++i) {
+ if (executor) {
+ // bail if StopProducing was called
+ if (finished_.is_finished()) break;
+
+ auto plan = this->plan()->shared_from_this();
+ RETURN_NOT_OK(executor->Spawn([plan, this, i] { OutputNthBatch(i); }));
+ } else {
+ OutputNthBatch(i);
+ }
+ }
+
+ return Status::OK();
+ }
+
+ void InputReceived(ExecNode* input, ExecBatch batch) override {
+ // bail if StopProducing was called
+ if (finished_.is_finished()) return;
+
+ DCHECK_EQ(input, inputs_[0]);
+
+ if (ErrorIfNotOk(Consume(std::move(batch)))) return;
+
+ if (input_counter_.Increment()) {
+ ErrorIfNotOk(OutputResult());
+ }
+ }
+
+ void ErrorReceived(ExecNode* input, Status error) override {
+ DCHECK_EQ(input, inputs_[0]);
+
+ outputs_[0]->ErrorReceived(this, std::move(error));
+ }
+
+ void InputFinished(ExecNode* input, int total_batches) override {
+ // bail if StopProducing was called
+ if (finished_.is_finished()) return;
+
+ DCHECK_EQ(input, inputs_[0]);
+
+ if (input_counter_.SetTotal(total_batches)) {
+ ErrorIfNotOk(OutputResult());
+ }
+ }
+
+ Status StartProducing() override {
+ finished_ = Future<>::Make();
+
+ local_states_.resize(ThreadIndexer::Capacity());
+ return Status::OK();
+ }
+
+ void PauseProducing(ExecNode* output) override {}
+
+ void ResumeProducing(ExecNode* output) override {}
+
+ void StopProducing(ExecNode* output) override {
+ DCHECK_EQ(output, outputs_[0]);
+
+ ARROW_UNUSED(input_counter_.Cancel());
+ if (output_counter_.Cancel()) {
+ finished_.MarkFinished();
+ }
+ inputs_[0]->StopProducing(this);
+ }
+
+ void StopProducing() override { StopProducing(outputs_[0]); }
+
+ Future<> finished() override { return finished_; }
+
+ protected:
+ std::string ToStringExtra() const override {
+ std::stringstream ss;
+ const auto input_schema = inputs_[0]->output_schema();
+ ss << "keys=[";
+ for (size_t i = 0; i < key_field_ids_.size(); i++) {
+ if (i > 0) ss << ", ";
+ ss << '"' << input_schema->field(key_field_ids_[i])->name() << '"';
+ }
+ ss << "], ";
+ AggregatesToString(&ss, *input_schema, aggs_, agg_src_field_ids_, owned_options_);
+ return ss.str();
+ }
+
+ private:
+ struct ThreadLocalState {
+ std::unique_ptr<internal::Grouper> grouper;
+ std::vector<std::unique_ptr<KernelState>> agg_states;
+ };
+
+ ThreadLocalState* GetLocalState() {
+ size_t thread_index = get_thread_index_();
+ return &local_states_[thread_index];
+ }
+
+ Status InitLocalStateIfNeeded(ThreadLocalState* state) {
+ // Get input schema
+ auto input_schema = inputs_[0]->output_schema();
+
+ if (state->grouper != nullptr) return Status::OK();
+
+ // Build vector of key field data types
+ std::vector<ValueDescr> key_descrs(key_field_ids_.size());
+ for (size_t i = 0; i < key_field_ids_.size(); ++i) {
+ auto key_field_id = key_field_ids_[i];
+ key_descrs[i] = ValueDescr(input_schema->field(key_field_id)->type());
+ }
+
+ // Construct grouper
+ ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs, ctx_));
+
+ // Build vector of aggregate source field data types
+ std::vector<ValueDescr> agg_src_descrs(agg_kernels_.size());
+ for (size_t i = 0; i < agg_kernels_.size(); ++i) {
+ auto agg_src_field_id = agg_src_field_ids_[i];
+ agg_src_descrs[i] =
+ ValueDescr(input_schema->field(agg_src_field_id)->type(), ValueDescr::ARRAY);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ state->agg_states,
+ internal::InitKernels(agg_kernels_, ctx_, aggs_, agg_src_descrs));
+
+ return Status::OK();
+ }
+
+ int output_batch_size() const {
+ int result = static_cast<int>(ctx_->exec_chunksize());
+ if (result < 0) {
+ result = 32 * 1024;
+ }
+ return result;
+ }
+
+ ExecContext* ctx_;
+ Future<> finished_ = Future<>::MakeFinished();
+
+ const std::vector<int> key_field_ids_;
+ const std::vector<int> agg_src_field_ids_;
+ const std::vector<internal::Aggregate> aggs_;
+ const std::vector<const HashAggregateKernel*> agg_kernels_;
+ // ARROW-13638: must hold owned copy of function options
+ const std::vector<std::unique_ptr<FunctionOptions>> owned_options_;
+
+ ThreadIndexer get_thread_index_;
+ AtomicCounter input_counter_, output_counter_;
+
+ std::vector<ThreadLocalState> local_states_;
+ ExecBatch out_data_;
+};
+
+} // namespace
+
+namespace internal {
+
+void RegisterAggregateNode(ExecFactoryRegistry* registry) {
+ DCHECK_OK(registry->AddFactory(
+ "aggregate",
+ [](ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options) -> Result<ExecNode*> {
+ const auto& aggregate_options =
+ checked_cast<const AggregateNodeOptions&>(options);
+
+ if (aggregate_options.keys.empty()) {
+ // construct scalar agg node
+ return ScalarAggregateNode::Make(plan, std::move(inputs), options);
+ }
+ return GroupByNode::Make(plan, std::move(inputs), options);
+ }));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_1.jpg b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_1.jpg
new file mode 100644
index 000000000..814ad8a69
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_1.jpg
Binary files differ
diff --git a/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_10.jpg b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_10.jpg
new file mode 100644
index 000000000..7a75c96df
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_10.jpg
Binary files differ
diff --git a/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_11.jpg b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_11.jpg
new file mode 100644
index 000000000..59bcc167e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_11.jpg
Binary files differ
diff --git a/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_2.jpg b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_2.jpg
new file mode 100644
index 000000000..4484c57a8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_2.jpg
Binary files differ
diff --git a/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_3.jpg b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_3.jpg
new file mode 100644
index 000000000..afd33aba2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_3.jpg
Binary files differ
diff --git a/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_4.jpg b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_4.jpg
new file mode 100644
index 000000000..f026aebe9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_4.jpg
Binary files differ
diff --git a/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_5.jpg b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_5.jpg
new file mode 100644
index 000000000..8e1981b65
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_5.jpg
Binary files differ
diff --git a/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_6.jpg b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_6.jpg
new file mode 100644
index 000000000..e976a4614
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_6.jpg
Binary files differ
diff --git a/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_7.jpg b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_7.jpg
new file mode 100644
index 000000000..7552d5af6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_7.jpg
Binary files differ
diff --git a/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_8.jpg b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_8.jpg
new file mode 100644
index 000000000..242f13053
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_8.jpg
Binary files differ
diff --git a/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_9.jpg b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_9.jpg
new file mode 100644
index 000000000..4c064595c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/doc/img/key_map_9.jpg
Binary files differ
diff --git a/src/arrow/cpp/src/arrow/compute/exec/doc/key_map.md b/src/arrow/cpp/src/arrow/compute/exec/doc/key_map.md
new file mode 100644
index 000000000..fdedc88c4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/doc/key_map.md
@@ -0,0 +1,223 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# Swiss Table
+
+A specialized hash table implementation used to dynamically map combinations of key field values to a dense set of integer ids. Ids can later be used in place of keys to identify groups of rows with equal keys.
+
+## Introduction
+
+Hash group-by in Arrow uses a variant of a hash table based on a data structure called Swiss table. Swiss table uses linear probing. There is an array of slots and the information related to inserted keys is stored in these slots. A hash function determines the slot where the search for a matching key will start during hash table lookup. Then the slots are visited sequentially, wrapping around the end of an array, until either a match or an empty slot is found, the latter case meaning that there is no match. Swiss table organizes the slots in blocks of 8 and has a design that enables data level parallelism at the block level. More precisely, it allows for visiting all slots within a block at once during lookups, by simply using 64-bit arithmetic. SIMD instructions can further enhance this data level parallelism allowing to process multiple blocks related to multiple input keys together using SIMD vectors of 64-bit elements. Occupied slots within a block are always clustered together. The name Swiss table comes from likening resulting sequences of empty slots to holes in a one dimensional cheese.
+
+## Interface
+
+Hash table used in query processing for implementing join and group-by operators does not need to provide all of the operations that a general purpose hash table would. Simplified requirements can help achieve a simpler and more efficient design. For instance we do not need to be able to remove previously inserted keys. It’s an append-only data structure: new keys can be added but old keys are never erased. Also, only a single copy of each key can be inserted - it is like `std::map` in that sense and not `std::multimap`.
+
+Our Swiss table is fully vectorized. That means that all methods work on vectors of input keys processing them in batches. Specialized SIMD implementations of processing functions are almost always provided for performance critical operations. All callback interfaces used from the core hash table code are also designed to work on batches of inputs instead of individual keys. The batch size can be almost arbitrary and is selected by the client of the hash table. Batch size should be the smallest number of input items, big enough so that the benefits of vectorization and SIMD can be fully experienced. Keeping it small means less memory used for temporary arrays storing intermediate results of computation (vector equivalent of some temporary variables kept on the stack). That in turn means smaller space in CPU caches, which also means less impact on other memory access intensive operations. We pick 1024 as the default size of the batch. We will call it a **mini-batch** to distinguish it from potentially other forms of batches used at higher levels in the code, e.g. when scheduling work for worker threads or relational operators inside an analytic query.
+
+The main functionality provided by Swiss table is mapping of arbitrarily complex keys to unique integer ids. Let us call it **lookup-or-insert**. Given a sequence of key values, return a corresponding sequence of integer ids, such that all keys that are equal receive the same id and for K distinct keys the integer ids will be assigned from the set of numbers 0 to (K-1). If we find a matching key in a hash table for a given input, we return the **key id** assigned when the key was first inserted into a hash table. If we fail to find an already inserted match, we assign the first unused integer as a key id and add a new entry to a hash table. Due to vectorized processing, which may result in out-of-order processing of individual inputs, it is not guaranteed that if there are two new key values in the same input batch and one of them appears earlier in the input sequence, then it will receive a smaller key id. Additional mapping functionality can be built on top of basic mapping to integer key id, for instance if we want to assign and perhaps keep updating some values to all unique keys, we can keep these values in a resizable vector indexed by obtained key id.
+
+The implementation of Swiss table does not need to have any information related to the domain of the keys. It does not use their logical data type or information about their physical representation and does not even use pointers to keys. All access to keys is delegated to a separate class or classes that provide callback functions for three operations:
+- computing hashes of keys;
+- checking equality for given pairs of keys;
+- appending a given sequence of keys to a stack maintained outside of Swiss table object, so that they can be referenced later on by key ids (key ids will be equal to their positions in the stack).
+
+
+When passing arguments to callback functions the keys are referenced using integer ids. For the left side - that is the keys present in the input mini-batch - ordinal positions within that mini-batch are used. For the right side - that is the keys inserted into the hash table - these are identified by key ids assigned to them and stored inside Swiss table when they were first encountered and processed.
+
+Diagram with logical view of information passing in callbacks:
+
+![alt text](img/key_map_1.jpg)
+
+Hash table values for inserted keys are also stored inside Swiss table. Because of that, hash table logic does not need to ever re-evaluate the hash, and there is actually no need for a hash function callback. It is enough that the caller provides hash values for all entries in the batch when calling lookup-or-insert.
+
+## Basic architecture and organization of data
+The hash table is an array of **slots**. Slots are grouped in groups of 8 called **blocks**. The number of blocks is a power of 2. The empty hash table starts with a single block, with all slots empty. Then, as the keys are getting inserted and the amount of empty slots is shrinking, at some point resizing of the hash table is triggered. The data stored in slots is moved to a new hash table that has the double of the number of blocks.
+
+The diagram below shows the basic organization of data in our implementation of Swiss table:
+
+![alt text](img/key_map_2.jpg)
+
+N is the log of the number of blocks, 2<sup>n+3</sup> is the number of slots and also the maximum number of inserted keys and hence (N + 3) is the number of bits required to store a key id. We will refer to N as the **size of the hash table**.
+
+Index of a block within an array will be called **block id**, and similarly index of a slot will be **slot id**. Sometimes we will focus on a single block and refer to slots that belong to it by using a **local slot id**, which is an index from 0 to 7.
+
+Every slot can either be **empty** or store data related to a single inserted key. There are three pieces of information stored inside a slot:
+- status byte,
+- key id,
+- key hash.
+
+Status byte, as the name suggests, stores 8 bits. The highest bit indicates if the slot is empty (the highest bit is set) or corresponds to one of inserted keys (the highest bit is zero). The remaining 7 bits contain 7 bits of key hash that we call a **stamp**. The stamp is used to eliminate some false positives when searching for a matching key for a given input. Slot also stores **key id**, which is a non-negative integer smaller than the number of inserted keys, that is used as a reference to the actual inserted key. The last piece of information related to an inserted key is its **hash** value. We store hashes for all keys, so that they never need to be re-computed. That greatly simplifies some operations, like resizing of a hash table, that may not even need to look at the keys at all. For an empty slot, the status byte is 0x80, key id is zero and the hash is not used and can be set to any number.
+
+A single block contains 8 slots and can be viewed as a micro-stack of up to 8 inserted keys. When the first key is inserted into an empty block, it will occupy a slot with local id 0. The second inserted key will go into slot number 1 and so on. We use N highest bits of hash to get an index of a **start block**, when searching for a match or an empty slot to insert a previously not seen key when that is the case. If the start block contains any empty slots, then the search for either a match or place to insert a key will end at that block. We will call such a block an **open block**. A block that is not open is a full block. In the case of full block, the input key related search may continue in the next block module the number of blocks. If the key is not inserted into its start block, we will refer to it as an **overflow** entry, other entries being **non-overflow**. Overflow entries are slower to process, since they require visiting more than one block, so we want to keep their percentage low. This is done by choosing the right **load factor** (percentage of occupied slots in the hash table) at which the hash table gets resized and the number of blocks gets doubled. By tuning this value we can control the probability of encountering an overflow entry.
+
+The most interesting part of each block is the set of status bytes of its slots, which is simply a single 64-bit word. The implementation of efficient searches across these bytes during lookups require using either leading zero count or trailing zero count intrinsic. Since there are cases when only the first one is available, in order to take advantage of it, we order the bytes in the 64-bit status word so that the first slot within a block uses the highest byte and the last one uses the lowest byte (slots are in reversed bytes order). The diagram below shows how the information about slots is stored within a 64-bit status word:
+
+![alt text](img/key_map_3.jpg)
+
+Each status byte has a 7-bit fragment of hash value - a **stamp** - and an empty slot bit. Empty slots have status byte equal to 0x80 - the highest bit is set to 1 to indicate an empty slot and the lowest bits, which are used by a stamp, are set to zero.
+
+The diagram below shows which bits of hash value are used by hash table:
+
+![alt text](img/key_map_4.jpg)
+
+If a hash table has 2<sup>N</sup> blocks, then we use N highest bits of a hash to select a start block when searching for a match. The next 7 bits are used as a stamp. Using the highest bits to pick a start block means that a range of hash values can be easily mapped to a range of block ids of start blocks for hashes in that range. This is useful when resizing a hash table or merging two hash tables together.
+
+### Interleaving status bytes and key ids
+
+Status bytes and key ids for all slots are stored in a single array of bytes. They are first grouped by 8 into blocks, then each block of status bytes is interleaved with a corresponding block of key ids. Finally key ids are represented using the smallest possible number of bits and bit-packed (bits representing each next key id start right after the last bit of the previous key id). Note that regardless of the chosen number of bits, a block of bit-packed key ids (that is 8 of them) will start and end on the byte boundary.
+
+The diagram below shows the organization of bytes and bits of a single block in interleaved array:
+![alt text](img/key_map_5.jpg)
+
+From the size of the hash table we can derive the number K of bits needed in the worst case to encode any key id. K is equal to the number of bits needed to represent slot id (number of keys is not greater than the number of slots and any key id is strictly less than the number of keys), which for a hash table of size N (N blocks) equals (N+3). To simplify bit packing and unpacking and avoid handling of special cases, we will round up K to full bytes for K > 24 bits.
+
+Status bytes are stored in a single 64-bit word in reverse byte order (the last byte corresponds to the slot with local id 0). On the other hand key ids are stored in the normal order (the order of slot ids).
+
+Since both status byte and key id for a given slot are stored in the same array close to each other, we can expect that most of the lookups will read only one CPU cache-line from memory inside Swiss table code (then at least another one outside Swiss table to access the bytes of the key for the purpose of comparison). Even if we hit an overflow entry, it is still likely to reside on the same cache-line as the start block data. Hash values, which are stored separately from status byte and key id, are only used when resizing and do not impact the lookups outside these events.
+
+> Improvement to consider:
+> In addition to the Swiss table data, we need to store an array of inserted keys, one for each key id. If keys are of fixed length, then the address of the bytes of the key can be calculated by multiplying key id by the common length of the key. If keys are of varying length, then there will be an additional array with an offset of each key within the array of concatenated bytes of keys. That means that any key comparison during lookup will involve 3 arrays: one to get key id, one to get key offset and final one with bytes of the key. This could be reduced to 2 array lookups if we stored key offset instead of key id interleaved with slot status bytes. Offset indexed by key id and stored in its own array becomes offset indexed by slot id and stored interleaved with slot status bytes. At the same time key id indexed by slot id and interleaved with slot status bytes before becomes key id referenced using offset and stored with key bytes. There may be a slight increase in the total size of memory needed by the hash table, equal to the difference in the number of bits used to store offset and those used to store key id, multiplied by the number of slots, but that should be a small fraction of the total size.
+
+### 32-bit hash vs 64-bit hash
+
+Currently we use 32-bit hash values in Swiss table code and 32-bit integers as key ids. For the robust implementation, sooner or later we will need to support 64-bit hash and 64-bit key ids. When we use 32-bit hash, it means that we run out of hash bits when hash table size N is greater than 25 (25 bits of hash needed to select a block and 7 bits needed to generate a stamp byte reach 32 total bits). When the number of inserted keys exceeds the maximal number of keys stored in a hash table of size 25 (which is at least 2<sup>24</sup>), the chance of false positives during lookups will start quickly growing. 32-bit hash should not be used with more than about 16 million inserted keys.
+
+### Low memory footprint and low chance of hash collisions
+
+Swiss table is a good choice of a hash table for modern hardware, because it combines lookups that can take advantage of special CPU instructions with space efficiency and low chance of hash collisions.
+
+Space efficiency is important for performance, because the cost of random array accesses, often dominating the lookup cost for larger hash tables, increases with the size of the arrays. This happens due to limited space of CPU caches. Let us look at what is the amortized additional storage cost for a key in a hash table apart from the essential cost of storing data of all those keys. Furthermore, we can skip the storage of hash values, since these are only used during infrequent hash table resize operations (should not have a big impact on CPU cache usage in normal cases).
+
+Half full hash table of size N will use 2 status bytes per inserted key (because for every filled slot there is one empty slot) and 2\*(N+3) bits for key id (again, one for the occupied slot and one for the empty). For N = 16 for instance this is slightly under 7 bytes per inserted key.
+
+Swiss table also has a low probability of false positives leading to wasted key comparisons. Here is some rationale behind why this should be the case. Hash table of size N can contain up to 2<sup>N+3</sup> keys. Search for a match involves (N + 7) hash bits: N to select a start block and 7 to use as a stamp. There are always at least 16 times more combinations of used hash bits than there are keys in the hash table (32 times more if the hash table is half full). These numbers mean that the probability of false positives resulting from a search for a matching slot should be low. That corresponds to an expected number of comparisons per lookup being close to 1 for keys already present and 0 for new keys.
+
+## Lookup
+
+Lookup-or-insert operation, given a hash of a key, finds a list of candidate slots with corresponding keys that are likely to be equal to the input key. The list may be empty, which means that the key does not exist yet in the hash table. If it is not empty, then the callback function for key comparison is called for each next candidate to verify that there is indeed a match. False positives get rejected and we end up either finding an actual match or an empty slot, which means that the key is new to the hash table. New keys get assigned next available integers as key ids, and are appended to the set of keys stored in the hash table. As a result of inserting new keys to the hash table, the density of occupied slots may reach an upper limit, at which point the hash table will be resized and will afterwards have twice as many slots. That is in summary lookup-or-insert functionality, but the actual implementation is a bit more involved, because of vectorization of the processing and various optimizations for common cases.
+
+### Search within a single block
+
+There are three possible cases that can occur when searching for a match for a given key (that is, for a given stamp of a key) within a single block, illustrated below.
+
+ 1. There is a matching stamp in the block of status bytes:
+
+![alt text](img/key_map_6.jpg)
+
+ 2. There is no matching stamp in the block, but there is an empty slot in the block:
+
+![alt text](img/key_map_7.jpg)
+
+ 3. There is no matching stamp in the block and the block is full (there are no empty slots left):
+
+![alt text](img/key_map_8.jpg)
+
+64-bit arithmetic can be used to search for a matching slot within the entire single block at once, without iterating over all slots in it. Following is an example of a sequence of steps to find the first status byte for a given stamp, returning the first empty slot on miss if the block is not full or 8 (one past maximum local slot id) otherwise.
+
+Following is a sketch of the possible steps to execute when searching for the matching stamp in a single block.
+
+*Example will use input stamp 0x5E and a 64-bit status bytes word with one empty slot:
+0x 4B17 5E3A 5E2B 1180*.
+
+1. [1 instruction] Replicate stamp to all bytes by multiplying it by 0x 0101 0101 0101 0101.
+
+ *We obtain: 0x 5E5E 5E5E 5E5E 5E5E.*
+
+2. [1 instruction] XOR replicated stamp with status bytes word. Bytes corresponding to a matching stamp will be 0, bytes corresponding to empty slots will have a value between 128 and 255, bytes corresponding to non-matching non-empty slots will have a value between 1 and 127.
+
+ *We obtain: 0x 1549 0064 0075 4FDE.*
+
+3. [2 instructions] In the next step we want to have information about a match in the highest bit of each byte. We can ignore here empty slot bytes, because they will be taken care of at a later step. Set the highest bit in each byte (OR with 0x 8080 8080 8080 8080) and then subtract 1 from each byte (subtract 0x 0101 0101 0101 0101 from 64-bit word). Now if a byte corresponds to a non-empty slot then the highest bit 0 indicates a match and 1 indicates a miss.
+
+ *We obtain: 0x 95C9 80E4 80F5 CFDE,
+ then 0x 94C8 7FE3 7FF4 CEDD.*
+
+4. [3 instructions] In the next step we want to obtain in each byte one of two values: 0x80 if it is either an empty slot or a match, 0x00 otherwise. We do it in three steps: NOT the result of the previous step to change the meaning of the highest bit; OR with the original status word to set highest bit in a byte to 1 for empty slots; mask out everything other than the highest bits in all bytes (AND with 0x 8080 8080 8080 8080).
+
+ *We obtain: 6B37 801C 800B 3122,
+ then 6B37 DE3E DE2B 31A2,
+ finally 0x0000 8000 8000 0080.*
+
+5. [2 instructions] Finally, use leading zero bits count and divide it by 8 to find an index of the last byte that corresponds either to a match or an empty slot. If the leading zero count intrinsic returns 64 for a 64-bit input zero, then after dividing by 8 we will also get the desired answer in case of a full block without any matches.
+
+ *We obtain: 16,
+ then 2 (index of the first slot within the block that matches the stamp).*
+
+If SIMD instructions with 64-bit lanes are available, multiple single block searches for different keys can be executed together. For instance AVX2 instruction set allows to process quadruplets of 64-bit values in a single instruction, four searches at once.
+
+### Complete search potentially across multiple blocks
+
+Full implementation of a search for a matching key may involve visiting multiple blocks beginning with the start block selected based on the hash of the key. We move to the next block modulo the number of blocks, whenever we do not find a match in the current block and the current block is full. The search may also involve visiting one or more slots in each block. Visiting in this case means calling a comparison callback to verify the match whenever a slot with a matching stamp is encountered. Eventually the search stops when either:
+- the matching key is found in one of the slots matching the stamp, or
+
+- an empty slot is reached. This is illustrated in the diagram below:
+![alt text](img/key_map_9.jpg)
+
+
+### Optimistic processing with two passes
+
+Hash table lookups may have high cost in the pessimistic case, when we encounter cases of hash collisions and full blocks that lead to visiting further blocks. In the majority of cases we can expect an optimistic situation - the start block is not full, so we will only visit this one block, and all stamps in the block are different, so we will need at most one comparison to find a match. We can expect about 90% of the key lookups for an existing key to go through the optimistic path of processing. For that reason it pays off to optimize especially for this 90% of inputs.
+
+Lookups in Swiss table are split into two passes over an input batch of keys. The **first pass: fast-path lookup** , is a highly optimized, vectorized, SIMD-friendly, branch-free code that fully handles optimistic cases. The **second pass: slow-path lookup** , is normally executed only for the selection of inputs that have not been finished in the first pass, although it can also be called directly on all of the inputs, skipping fast-path lookup. It handles all special cases and inserts but in order to be robust it is not as efficient as fast-path. Slow-path lookup does not need to repeat the work done in fast-path lookup - it can use the state reached at the end of fast-path lookup as a starting point.
+
+Fast-path lookup implements search only for the first stamp match and only within the start block. It only makes sense when we already have at least one key inserted into the hash table, since it does not handle inserts. It takes a vector of key hashes as an input and based on it outputs three pieces of information for each key:
+
+- Key id corresponding to the slot in which a matching stamp was found. Any valid key id if a matching stamp was not found.
+- A flag indicating if a match was found or not.
+- Slot id of a slot from which slow-path should pick up the search if the first match was either not found or it turns out to be false positive after evaluating key comparison.
+
+> Improvement to consider:
+> precomputing 1st pass lookup results.
+>
+> If the hash table is small, the number of inserted keys is small, we could further simplify and speed-up the first pass by storing in a lookup table pre-computed results for all combinations of hash bits. Let us consider the case of Swiss table of size 5 that has 256 slots and up to 128 inserted keys. Only 12 bits of hash are used by lookup in that case: 5 to select a block, 7 to create a stamp. For all 2<sup>12</sup> combinations of those bits we could keep the result of first pass lookup in an array. Key id and a match indicating flag can use one byte: 7 bits for key id and 1 bit for the flag. Note that slot id is only needed if we go into 2nd pass lookup, so it can be stored separately and likely only accessed by a small subset of keys. Fast-path lookup becomes almost a single fetch of result from a 4KB array. Lookup arrays used to implement this need to be kept in sync with the main copy of data about slots, which requires extra care during inserts. Since the number of entries in lookup arrays is much higher than the number of slots, this technique only makes sense for small hash tables.
+
+### Dense comparisons
+
+If there is at least one key inserted into a hash table, then every slot contains a key id value that corresponds to some actual key that can be used in comparison. That is because empty slots are initialized with 0 as their key id. After the fast-path lookup we get a match-found flag for each input. If it is set, then we need to run a comparison of the input key with the key in the hash table identified by key id returned by fast-path code. The comparison will verify that there is a true match between the keys. We only need to do this for a subset of inputs that have a match candidate, but since we have key id values corresponding to some real key for all inputs, we may as well execute comparisons on all inputs unconditionally. If the majority (e.g. more than 80%) of the keys have a match candidate, the cost of evaluating comparison for the remaining fraction of keys but without filtering may actually be cheaper than the cost of running evaluation only for required keys while referencing filter information. This can be seen as a variant of general preconditioning techniques used to avoid diverging conditional branches in the code. It may be used, based on some heuristic, to verify matches reported by fast-path lookups and is referred to as **dense comparisons**.
+
+## Resizing
+
+New hash table is initialized as empty and has only a single block with a space for only a few key entries. Doubling of the hash table size becomes necessary as more keys get inserted. It is invoked during the 2nd pass of the lookups, which also handles inserts. It happens immediately after the number of inserted keys reaches a specific upper limit decided based on a current size of the hash table. There may still be unprocessed entries from the input mini-batch after resizing, so the 2nd pass of the lookup is restarted right after, with the bigger hash table and the remaining subset of unprocessed entries.
+
+Current policy, that should work reasonably well, is to resize a small hash table (up to 8KB) when it is 50% full. Larger hash tables are resized when 75% full. We want to keep size in memory as small as possible, while maintaining a low probability of blocks becoming full.
+
+When discussing resizing we will be talking about **resize source** and **resize target** tables. The diagram below shows how the same hash bits are interpreted differently by the source and the target.
+
+![alt text](img/key_map_10.jpg)
+
+For a given hash, if a start block id was L in the source table, it will be either (2\*L+0) or (2\*L+1) in the target table. Based on that we can expect data access locality when migrating the data between the tables.
+
+Resizing is cheap also thanks to the fact that hash values for keys in the hash table are kept together with other slot data and do not need to be recomputed. That means that resizing procedure does not ever need to access the actual bytes of the key.
+
+### 1st pass
+
+Based on the hash value for a given slot we can tell whether this slot contains an overflow or non-overflow entry. In the first pass we go over all source slots in sequence, filter out overflow entries and move to the target table all other entries. Non-overflow entries from a block L will be distributed between blocks (2\*L+0) and (2\*L+1) of the target table. None of these target blocks can overflow, since they will be accommodating at most 8 input entries during this pass.
+
+For every non-overflow entry, the highest bit of a stamp in the source slot decides whether it will go to the left or to the right target block. It is further possible to avoid any conditional branches in this partitioning code, so that the result is friendly to the CPU execution pipeline.
+
+![alt text](img/key_map_11.jpg)
+
+
+### 2nd pass
+
+In the second pass of resizing, we scan all source slots again, this time focusing only on the overflow entries that were all skipped in the 1st pass. We simply reinsert them in the target table using generic insertion code with one exception. Since we know that all the source keys are different, there is no need to search for a matching stamp or run key comparisons (or look at the key values). We just need to find the first open block beginning with the start block in the target table and use its first empty slot as the insert destination.
+
+We expect overflow entries to be rare and therefore the relative cost of that pass should stay low.
+
diff --git a/src/arrow/cpp/src/arrow/compute/exec/exec_plan.cc b/src/arrow/cpp/src/arrow/compute/exec/exec_plan.cc
new file mode 100644
index 000000000..7e7824d85
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -0,0 +1,523 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/exec_plan.h"
+
+#include <sstream>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/expression.h"
+#include "arrow/compute/exec_internal.h"
+#include "arrow/compute/registry.h"
+#include "arrow/datum.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/optional.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+namespace {
+
+struct ExecPlanImpl : public ExecPlan {
+ explicit ExecPlanImpl(ExecContext* exec_context) : ExecPlan(exec_context) {}
+
+ ~ExecPlanImpl() override {
+ if (started_ && !finished_.is_finished()) {
+ ARROW_LOG(WARNING) << "Plan was destroyed before finishing";
+ StopProducing();
+ finished().Wait();
+ }
+ }
+
+ ExecNode* AddNode(std::unique_ptr<ExecNode> node) {
+ if (node->label().empty()) {
+ node->SetLabel(std::to_string(auto_label_counter_++));
+ }
+ if (node->num_inputs() == 0) {
+ sources_.push_back(node.get());
+ }
+ if (node->num_outputs() == 0) {
+ sinks_.push_back(node.get());
+ }
+ nodes_.push_back(std::move(node));
+ return nodes_.back().get();
+ }
+
+ Status Validate() const {
+ if (nodes_.empty()) {
+ return Status::Invalid("ExecPlan has no node");
+ }
+ for (const auto& node : nodes_) {
+ RETURN_NOT_OK(node->Validate());
+ }
+ return Status::OK();
+ }
+
+ Status StartProducing() {
+ if (started_) {
+ return Status::Invalid("restarted ExecPlan");
+ }
+ started_ = true;
+
+ // producers precede consumers
+ sorted_nodes_ = TopoSort();
+
+ std::vector<Future<>> futures;
+
+ Status st = Status::OK();
+
+ using rev_it = std::reverse_iterator<NodeVector::iterator>;
+ for (rev_it it(sorted_nodes_.end()), end(sorted_nodes_.begin()); it != end; ++it) {
+ auto node = *it;
+
+ st = node->StartProducing();
+ if (!st.ok()) {
+ // Stop nodes that successfully started, in reverse order
+ stopped_ = true;
+ StopProducingImpl(it.base(), sorted_nodes_.end());
+ break;
+ }
+
+ futures.push_back(node->finished());
+ }
+
+ finished_ = AllFinished(futures);
+ return st;
+ }
+
+ void StopProducing() {
+ DCHECK(started_) << "stopped an ExecPlan which never started";
+ stopped_ = true;
+
+ StopProducingImpl(sorted_nodes_.begin(), sorted_nodes_.end());
+ }
+
+ template <typename It>
+ void StopProducingImpl(It begin, It end) {
+ for (auto it = begin; it != end; ++it) {
+ auto node = *it;
+ node->StopProducing();
+ }
+ }
+
+ NodeVector TopoSort() const {
+ struct Impl {
+ const std::vector<std::unique_ptr<ExecNode>>& nodes;
+ std::unordered_set<ExecNode*> visited;
+ NodeVector sorted;
+
+ explicit Impl(const std::vector<std::unique_ptr<ExecNode>>& nodes) : nodes(nodes) {
+ visited.reserve(nodes.size());
+ sorted.resize(nodes.size());
+
+ for (const auto& node : nodes) {
+ Visit(node.get());
+ }
+
+ DCHECK_EQ(visited.size(), nodes.size());
+ }
+
+ void Visit(ExecNode* node) {
+ if (visited.count(node) != 0) return;
+
+ for (auto input : node->inputs()) {
+ // Ensure that producers are inserted before this consumer
+ Visit(input);
+ }
+
+ sorted[visited.size()] = node;
+ visited.insert(node);
+ }
+ };
+
+ return std::move(Impl{nodes_}.sorted);
+ }
+
+ std::string ToString() const {
+ std::stringstream ss;
+ ss << "ExecPlan with " << nodes_.size() << " nodes:" << std::endl;
+ for (const auto& node : TopoSort()) {
+ ss << node->ToString() << std::endl;
+ }
+ return ss.str();
+ }
+
+ Future<> finished_ = Future<>::MakeFinished();
+ bool started_ = false, stopped_ = false;
+ std::vector<std::unique_ptr<ExecNode>> nodes_;
+ NodeVector sources_, sinks_;
+ NodeVector sorted_nodes_;
+ uint32_t auto_label_counter_ = 0;
+};
+
+ExecPlanImpl* ToDerived(ExecPlan* ptr) { return checked_cast<ExecPlanImpl*>(ptr); }
+
+const ExecPlanImpl* ToDerived(const ExecPlan* ptr) {
+ return checked_cast<const ExecPlanImpl*>(ptr);
+}
+
+util::optional<int> GetNodeIndex(const std::vector<ExecNode*>& nodes,
+ const ExecNode* node) {
+ for (int i = 0; i < static_cast<int>(nodes.size()); ++i) {
+ if (nodes[i] == node) return i;
+ }
+ return util::nullopt;
+}
+
+} // namespace
+
+Result<std::shared_ptr<ExecPlan>> ExecPlan::Make(ExecContext* ctx) {
+ return std::shared_ptr<ExecPlan>(new ExecPlanImpl{ctx});
+}
+
+ExecNode* ExecPlan::AddNode(std::unique_ptr<ExecNode> node) {
+ return ToDerived(this)->AddNode(std::move(node));
+}
+
+const ExecPlan::NodeVector& ExecPlan::sources() const {
+ return ToDerived(this)->sources_;
+}
+
+const ExecPlan::NodeVector& ExecPlan::sinks() const { return ToDerived(this)->sinks_; }
+
+Status ExecPlan::Validate() { return ToDerived(this)->Validate(); }
+
+Status ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); }
+
+void ExecPlan::StopProducing() { ToDerived(this)->StopProducing(); }
+
+Future<> ExecPlan::finished() { return ToDerived(this)->finished_; }
+
+std::string ExecPlan::ToString() const { return ToDerived(this)->ToString(); }
+
+ExecNode::ExecNode(ExecPlan* plan, NodeVector inputs,
+ std::vector<std::string> input_labels,
+ std::shared_ptr<Schema> output_schema, int num_outputs)
+ : plan_(plan),
+ inputs_(std::move(inputs)),
+ input_labels_(std::move(input_labels)),
+ output_schema_(std::move(output_schema)),
+ num_outputs_(num_outputs) {
+ for (auto input : inputs_) {
+ input->outputs_.push_back(this);
+ }
+}
+
+Status ExecNode::Validate() const {
+ if (inputs_.size() != input_labels_.size()) {
+ return Status::Invalid("Invalid number of inputs for '", label(), "' (expected ",
+ num_inputs(), ", actual ", input_labels_.size(), ")");
+ }
+
+ if (static_cast<int>(outputs_.size()) != num_outputs_) {
+ return Status::Invalid("Invalid number of outputs for '", label(), "' (expected ",
+ num_outputs(), ", actual ", outputs_.size(), ")");
+ }
+
+ for (auto out : outputs_) {
+ auto input_index = GetNodeIndex(out->inputs(), this);
+ if (!input_index) {
+ return Status::Invalid("Node '", label(), "' outputs to node '", out->label(),
+ "' but is not listed as an input.");
+ }
+ }
+
+ return Status::OK();
+}
+
+std::string ExecNode::ToString() const {
+ std::stringstream ss;
+ ss << kind_name() << "{\"" << label_ << '"';
+ if (!inputs_.empty()) {
+ ss << ", inputs=[";
+ for (size_t i = 0; i < inputs_.size(); i++) {
+ if (i > 0) ss << ", ";
+ ss << input_labels_[i] << ": \"" << inputs_[i]->label() << '"';
+ }
+ ss << ']';
+ }
+
+ if (!outputs_.empty()) {
+ ss << ", outputs=[";
+ for (size_t i = 0; i < outputs_.size(); i++) {
+ if (i > 0) ss << ", ";
+ ss << "\"" << outputs_[i]->label() << "\"";
+ }
+ ss << ']';
+ }
+
+ const std::string extra = ToStringExtra();
+ if (!extra.empty()) ss << ", " << extra;
+
+ ss << '}';
+ return ss.str();
+}
+
+std::string ExecNode::ToStringExtra() const { return ""; }
+
+bool ExecNode::ErrorIfNotOk(Status status) {
+ if (status.ok()) return false;
+
+ for (auto out : outputs_) {
+ out->ErrorReceived(this, out == outputs_.back() ? std::move(status) : status);
+ }
+ return true;
+}
+
+MapNode::MapNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ std::shared_ptr<Schema> output_schema, bool async_mode)
+ : ExecNode(plan, std::move(inputs), /*input_labels=*/{"target"},
+ std::move(output_schema),
+ /*num_outputs=*/1) {
+ if (async_mode) {
+ executor_ = plan_->exec_context()->executor();
+ } else {
+ executor_ = nullptr;
+ }
+}
+
+void MapNode::ErrorReceived(ExecNode* input, Status error) {
+ DCHECK_EQ(input, inputs_[0]);
+ outputs_[0]->ErrorReceived(this, std::move(error));
+}
+
+void MapNode::InputFinished(ExecNode* input, int total_batches) {
+ DCHECK_EQ(input, inputs_[0]);
+ outputs_[0]->InputFinished(this, total_batches);
+ if (input_counter_.SetTotal(total_batches)) {
+ this->Finish();
+ }
+}
+
+Status MapNode::StartProducing() { return Status::OK(); }
+
+void MapNode::PauseProducing(ExecNode* output) {}
+
+void MapNode::ResumeProducing(ExecNode* output) {}
+
+void MapNode::StopProducing(ExecNode* output) {
+ DCHECK_EQ(output, outputs_[0]);
+ StopProducing();
+}
+
+void MapNode::StopProducing() {
+ if (executor_) {
+ this->stop_source_.RequestStop();
+ }
+ if (input_counter_.Cancel()) {
+ this->Finish();
+ }
+ inputs_[0]->StopProducing(this);
+}
+
+Future<> MapNode::finished() { return finished_; }
+
+void MapNode::SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn,
+ ExecBatch batch) {
+ Status status;
+ if (finished_.is_finished()) {
+ return;
+ }
+ auto task = [this, map_fn, batch]() {
+ auto guarantee = batch.guarantee;
+ auto output_batch = map_fn(std::move(batch));
+ if (ErrorIfNotOk(output_batch.status())) {
+ return output_batch.status();
+ }
+ output_batch->guarantee = guarantee;
+ outputs_[0]->InputReceived(this, output_batch.MoveValueUnsafe());
+ return Status::OK();
+ };
+
+ if (executor_) {
+ status = task_group_.AddTask([this, task]() -> Result<Future<>> {
+ return this->executor_->Submit(this->stop_source_.token(), [this, task]() {
+ auto status = task();
+ if (this->input_counter_.Increment()) {
+ this->Finish(status);
+ }
+ return status;
+ });
+ });
+ } else {
+ status = task();
+ if (input_counter_.Increment()) {
+ this->Finish(status);
+ }
+ }
+ if (!status.ok()) {
+ if (input_counter_.Cancel()) {
+ this->Finish(status);
+ }
+ inputs_[0]->StopProducing(this);
+ return;
+ }
+}
+
+void MapNode::Finish(Status finish_st /*= Status::OK()*/) {
+ if (executor_) {
+ task_group_.End().AddCallback([this, finish_st](const Status& st) {
+ Status final_status = finish_st & st;
+ this->finished_.MarkFinished(final_status);
+ });
+ } else {
+ this->finished_.MarkFinished(finish_st);
+ }
+}
+
+std::shared_ptr<RecordBatchReader> MakeGeneratorReader(
+ std::shared_ptr<Schema> schema,
+ std::function<Future<util::optional<ExecBatch>>()> gen, MemoryPool* pool) {
+ struct Impl : RecordBatchReader {
+ std::shared_ptr<Schema> schema() const override { return schema_; }
+
+ Status ReadNext(std::shared_ptr<RecordBatch>* record_batch) override {
+ ARROW_ASSIGN_OR_RAISE(auto batch, iterator_.Next());
+ if (batch) {
+ ARROW_ASSIGN_OR_RAISE(*record_batch, batch->ToRecordBatch(schema_, pool_));
+ } else {
+ *record_batch = IterationEnd<std::shared_ptr<RecordBatch>>();
+ }
+ return Status::OK();
+ }
+
+ MemoryPool* pool_;
+ std::shared_ptr<Schema> schema_;
+ Iterator<util::optional<ExecBatch>> iterator_;
+ };
+
+ auto out = std::make_shared<Impl>();
+ out->pool_ = pool;
+ out->schema_ = std::move(schema);
+ out->iterator_ = MakeGeneratorIterator(std::move(gen));
+ return out;
+}
+
+Result<ExecNode*> Declaration::AddToPlan(ExecPlan* plan,
+ ExecFactoryRegistry* registry) const {
+ std::vector<ExecNode*> inputs(this->inputs.size());
+
+ size_t i = 0;
+ for (const Input& input : this->inputs) {
+ if (auto node = util::get_if<ExecNode*>(&input)) {
+ inputs[i++] = *node;
+ continue;
+ }
+ ARROW_ASSIGN_OR_RAISE(inputs[i++],
+ util::get<Declaration>(input).AddToPlan(plan, registry));
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto node, MakeExecNode(this->factory_name, plan, std::move(inputs), *this->options,
+ registry));
+ node->SetLabel(this->label);
+ return node;
+}
+
+Declaration Declaration::Sequence(std::vector<Declaration> decls) {
+ DCHECK(!decls.empty());
+
+ Declaration out = std::move(decls.back());
+ decls.pop_back();
+ auto receiver = &out;
+ while (!decls.empty()) {
+ Declaration input = std::move(decls.back());
+ decls.pop_back();
+
+ receiver->inputs.emplace_back(std::move(input));
+ receiver = &util::get<Declaration>(receiver->inputs.front());
+ }
+ return out;
+}
+
+namespace internal {
+
+void RegisterSourceNode(ExecFactoryRegistry*);
+void RegisterFilterNode(ExecFactoryRegistry*);
+void RegisterProjectNode(ExecFactoryRegistry*);
+void RegisterUnionNode(ExecFactoryRegistry*);
+void RegisterAggregateNode(ExecFactoryRegistry*);
+void RegisterSinkNode(ExecFactoryRegistry*);
+void RegisterHashJoinNode(ExecFactoryRegistry*);
+
+} // namespace internal
+
+ExecFactoryRegistry* default_exec_factory_registry() {
+ class DefaultRegistry : public ExecFactoryRegistry {
+ public:
+ DefaultRegistry() {
+ internal::RegisterSourceNode(this);
+ internal::RegisterFilterNode(this);
+ internal::RegisterProjectNode(this);
+ internal::RegisterUnionNode(this);
+ internal::RegisterAggregateNode(this);
+ internal::RegisterSinkNode(this);
+ internal::RegisterHashJoinNode(this);
+ }
+
+ Result<Factory> GetFactory(const std::string& factory_name) override {
+ auto it = factories_.find(factory_name);
+ if (it == factories_.end()) {
+ return Status::KeyError("ExecNode factory named ", factory_name,
+ " not present in registry.");
+ }
+ return it->second;
+ }
+
+ Status AddFactory(std::string factory_name, Factory factory) override {
+ auto it_success = factories_.emplace(std::move(factory_name), std::move(factory));
+
+ if (!it_success.second) {
+ const auto& factory_name = it_success.first->first;
+ return Status::KeyError("ExecNode factory named ", factory_name,
+ " already registered.");
+ }
+
+ return Status::OK();
+ }
+
+ private:
+ std::unordered_map<std::string, Factory> factories_;
+ };
+
+ static DefaultRegistry instance;
+ return &instance;
+}
+
+Result<std::function<Future<util::optional<ExecBatch>>()>> MakeReaderGenerator(
+ std::shared_ptr<RecordBatchReader> reader, ::arrow::internal::Executor* io_executor,
+ int max_q, int q_restart) {
+ auto batch_it = MakeMapIterator(
+ [](std::shared_ptr<RecordBatch> batch) {
+ return util::make_optional(ExecBatch(*batch));
+ },
+ MakeIteratorFromReader(reader));
+
+ return MakeBackgroundGenerator(std::move(batch_it), io_executor, max_q, q_restart);
+}
+} // namespace compute
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/exec_plan.h b/src/arrow/cpp/src/arrow/compute/exec/exec_plan.h
new file mode 100644
index 000000000..b5e59fe8d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/exec_plan.h
@@ -0,0 +1,422 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/compute/type_fwd.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/async_util.h"
+#include "arrow/util/cancel.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+namespace compute {
+
+class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this<ExecPlan> {
+ public:
+ using NodeVector = std::vector<ExecNode*>;
+
+ virtual ~ExecPlan() = default;
+
+ ExecContext* exec_context() const { return exec_context_; }
+
+ /// Make an empty exec plan
+ static Result<std::shared_ptr<ExecPlan>> Make(ExecContext* = default_exec_context());
+
+ ExecNode* AddNode(std::unique_ptr<ExecNode> node);
+
+ template <typename Node, typename... Args>
+ Node* EmplaceNode(Args&&... args) {
+ std::unique_ptr<Node> node{new Node{std::forward<Args>(args)...}};
+ auto out = node.get();
+ AddNode(std::move(node));
+ return out;
+ }
+
+ /// The initial inputs
+ const NodeVector& sources() const;
+
+ /// The final outputs
+ const NodeVector& sinks() const;
+
+ Status Validate();
+
+ /// \brief Start producing on all nodes
+ ///
+ /// Nodes are started in reverse topological order, such that any node
+ /// is started before all of its inputs.
+ Status StartProducing();
+
+ /// \brief Stop producing on all nodes
+ ///
+ /// Nodes are stopped in topological order, such that any node
+ /// is stopped before all of its outputs.
+ void StopProducing();
+
+ /// \brief A future which will be marked finished when all nodes have stopped producing.
+ Future<> finished();
+
+ std::string ToString() const;
+
+ protected:
+ ExecContext* exec_context_;
+ explicit ExecPlan(ExecContext* exec_context) : exec_context_(exec_context) {}
+};
+
+class ARROW_EXPORT ExecNode {
+ public:
+ using NodeVector = std::vector<ExecNode*>;
+
+ virtual ~ExecNode() = default;
+
+ virtual const char* kind_name() const = 0;
+
+ // The number of inputs/outputs expected by this node
+ int num_inputs() const { return static_cast<int>(inputs_.size()); }
+ int num_outputs() const { return num_outputs_; }
+
+ /// This node's predecessors in the exec plan
+ const NodeVector& inputs() const { return inputs_; }
+
+ /// \brief Labels identifying the function of each input.
+ const std::vector<std::string>& input_labels() const { return input_labels_; }
+
+ /// This node's successors in the exec plan
+ const NodeVector& outputs() const { return outputs_; }
+
+ /// The datatypes for batches produced by this node
+ const std::shared_ptr<Schema>& output_schema() const { return output_schema_; }
+
+ /// This node's exec plan
+ ExecPlan* plan() { return plan_; }
+
+ /// \brief An optional label, for display and debugging
+ ///
+ /// There is no guarantee that this value is non-empty or unique.
+ const std::string& label() const { return label_; }
+ void SetLabel(std::string label) { label_ = std::move(label); }
+
+ Status Validate() const;
+
+ /// Upstream API:
+ /// These functions are called by input nodes that want to inform this node
+ /// about an updated condition (a new input batch, an error, an impeding
+ /// end of stream).
+ ///
+ /// Implementation rules:
+ /// - these may be called anytime after StartProducing() has succeeded
+ /// (and even during or after StopProducing())
+ /// - these may be called concurrently
+ /// - these are allowed to call back into PauseProducing(), ResumeProducing()
+ /// and StopProducing()
+
+ /// Transfer input batch to ExecNode
+ virtual void InputReceived(ExecNode* input, ExecBatch batch) = 0;
+
+ /// Signal error to ExecNode
+ virtual void ErrorReceived(ExecNode* input, Status error) = 0;
+
+ /// Mark the inputs finished after the given number of batches.
+ ///
+ /// This may be called before all inputs are received. This simply fixes
+ /// the total number of incoming batches for an input, so that the ExecNode
+ /// knows when it has received all input, regardless of order.
+ virtual void InputFinished(ExecNode* input, int total_batches) = 0;
+
+ /// Lifecycle API:
+ /// - start / stop to initiate and terminate production
+ /// - pause / resume to apply backpressure
+ ///
+ /// Implementation rules:
+ /// - StartProducing() should not recurse into the inputs, as it is
+ /// handled by ExecPlan::StartProducing()
+ /// - PauseProducing(), ResumeProducing(), StopProducing() may be called
+ /// concurrently (but only after StartProducing() has returned successfully)
+ /// - PauseProducing(), ResumeProducing(), StopProducing() may be called
+ /// by the downstream nodes' InputReceived(), ErrorReceived(), InputFinished()
+ /// methods
+ /// - StopProducing() should recurse into the inputs
+ /// - StopProducing() must be idempotent
+
+ // XXX What happens if StartProducing() calls an output's InputReceived()
+ // synchronously, and InputReceived() decides to call back into StopProducing()
+ // (or PauseProducing()) because it received enough data?
+ //
+ // Right now, since synchronous calls happen in both directions (input to
+ // output and then output to input), a node must be careful to be reentrant
+ // against synchronous calls from its output, *and* also concurrent calls from
+ // other threads. The most reliable solution is to update the internal state
+ // first, and notify outputs only at the end.
+ //
+ // Alternate rules:
+ // - StartProducing(), ResumeProducing() can call synchronously into
+ // its ouputs' consuming methods (InputReceived() etc.)
+ // - InputReceived(), ErrorReceived(), InputFinished() can call asynchronously
+ // into its inputs' PauseProducing(), StopProducing()
+ //
+ // Alternate API:
+ // - InputReceived(), ErrorReceived(), InputFinished() return a ProductionHint
+ // enum: either None (default), PauseProducing, ResumeProducing, StopProducing
+ // - A method allows passing a ProductionHint asynchronously from an output node
+ // (replacing PauseProducing(), ResumeProducing(), StopProducing())
+
+ /// \brief Start producing
+ ///
+ /// This must only be called once. If this fails, then other lifecycle
+ /// methods must not be called.
+ ///
+ /// This is typically called automatically by ExecPlan::StartProducing().
+ virtual Status StartProducing() = 0;
+
+ /// \brief Pause producing temporarily
+ ///
+ /// This call is a hint that an output node is currently not willing
+ /// to receive data.
+ ///
+ /// This may be called any number of times after StartProducing() succeeds.
+ /// However, the node is still free to produce data (which may be difficult
+ /// to prevent anyway if data is produced using multiple threads).
+ virtual void PauseProducing(ExecNode* output) = 0;
+
+ /// \brief Resume producing after a temporary pause
+ ///
+ /// This call is a hint that an output node is willing to receive data again.
+ ///
+ /// This may be called any number of times after StartProducing() succeeds.
+ /// This may also be called concurrently with PauseProducing(), which suggests
+ /// the implementation may use an atomic counter.
+ virtual void ResumeProducing(ExecNode* output) = 0;
+
+ /// \brief Stop producing definitively to a single output
+ ///
+ /// This call is a hint that an output node has completed and is not willing
+ /// to receive any further data.
+ virtual void StopProducing(ExecNode* output) = 0;
+
+ /// \brief Stop producing definitively to all outputs
+ virtual void StopProducing() = 0;
+
+ /// \brief A future which will be marked finished when this node has stopped producing.
+ virtual Future<> finished() = 0;
+
+ std::string ToString() const;
+
+ protected:
+ ExecNode(ExecPlan* plan, NodeVector inputs, std::vector<std::string> input_labels,
+ std::shared_ptr<Schema> output_schema, int num_outputs);
+
+ // A helper method to send an error status to all outputs.
+ // Returns true if the status was an error.
+ bool ErrorIfNotOk(Status status);
+
+ /// Provide extra info to include in the string representation.
+ virtual std::string ToStringExtra() const;
+
+ ExecPlan* plan_;
+ std::string label_;
+
+ NodeVector inputs_;
+ std::vector<std::string> input_labels_;
+
+ std::shared_ptr<Schema> output_schema_;
+ int num_outputs_;
+ NodeVector outputs_;
+};
+
+/// \brief MapNode is an ExecNode type class which process a task like filter/project
+/// (See SubmitTask method) to each given ExecBatch object, which have one input, one
+/// output, and are pure functions on the input
+///
+/// A simple parallel runner is created with a "map_fn" which is just a function that
+/// takes a batch in and returns a batch. This simple parallel runner also needs an
+/// executor (use simple synchronous runner if there is no executor)
+
+class MapNode : public ExecNode {
+ public:
+ MapNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ std::shared_ptr<Schema> output_schema, bool async_mode);
+
+ void ErrorReceived(ExecNode* input, Status error) override;
+
+ void InputFinished(ExecNode* input, int total_batches) override;
+
+ Status StartProducing() override;
+
+ void PauseProducing(ExecNode* output) override;
+
+ void ResumeProducing(ExecNode* output) override;
+
+ void StopProducing(ExecNode* output) override;
+
+ void StopProducing() override;
+
+ Future<> finished() override;
+
+ protected:
+ void SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn, ExecBatch batch);
+
+ void Finish(Status finish_st = Status::OK());
+
+ protected:
+ // Counter for the number of batches received
+ AtomicCounter input_counter_;
+
+ // Future to sync finished
+ Future<> finished_ = Future<>::Make();
+
+ // The task group for the corresponding batches
+ util::AsyncTaskGroup task_group_;
+
+ ::arrow::internal::Executor* executor_;
+
+ // Variable used to cancel remaining tasks in the executor
+ StopSource stop_source_;
+};
+
+/// \brief An extensible registry for factories of ExecNodes
+class ARROW_EXPORT ExecFactoryRegistry {
+ public:
+ using Factory = std::function<Result<ExecNode*>(ExecPlan*, std::vector<ExecNode*>,
+ const ExecNodeOptions&)>;
+
+ virtual ~ExecFactoryRegistry() = default;
+
+ /// \brief Get the named factory from this registry
+ ///
+ /// will raise if factory_name is not found
+ virtual Result<Factory> GetFactory(const std::string& factory_name) = 0;
+
+ /// \brief Add a factory to this registry with the provided name
+ ///
+ /// will raise if factory_name is already in the registry
+ virtual Status AddFactory(std::string factory_name, Factory factory) = 0;
+};
+
+/// The default registry, which includes built-in factories.
+ARROW_EXPORT
+ExecFactoryRegistry* default_exec_factory_registry();
+
+/// \brief Construct an ExecNode using the named factory
+inline Result<ExecNode*> MakeExecNode(
+ const std::string& factory_name, ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options,
+ ExecFactoryRegistry* registry = default_exec_factory_registry()) {
+ ARROW_ASSIGN_OR_RAISE(auto factory, registry->GetFactory(factory_name));
+ return factory(plan, std::move(inputs), options);
+}
+
+/// \brief Helper class for declaring sets of ExecNodes efficiently
+///
+/// A Declaration represents an unconstructed ExecNode (and potentially more since its
+/// inputs may also be Declarations). The node can be constructed and added to a plan
+/// with Declaration::AddToPlan, which will recursively construct any inputs as necessary.
+struct ARROW_EXPORT Declaration {
+ using Input = util::Variant<ExecNode*, Declaration>;
+
+ Declaration(std::string factory_name, std::vector<Input> inputs,
+ std::shared_ptr<ExecNodeOptions> options, std::string label)
+ : factory_name{std::move(factory_name)},
+ inputs{std::move(inputs)},
+ options{std::move(options)},
+ label{std::move(label)} {}
+
+ template <typename Options>
+ Declaration(std::string factory_name, std::vector<Input> inputs, Options options)
+ : factory_name{std::move(factory_name)},
+ inputs{std::move(inputs)},
+ options{std::make_shared<Options>(std::move(options))},
+ label{this->factory_name} {}
+
+ template <typename Options>
+ Declaration(std::string factory_name, Options options)
+ : factory_name{std::move(factory_name)},
+ inputs{},
+ options{std::make_shared<Options>(std::move(options))},
+ label{this->factory_name} {}
+
+ /// \brief Convenience factory for the common case of a simple sequence of nodes.
+ ///
+ /// Each of decls will be appended to the inputs of the subsequent declaration,
+ /// and the final modified declaration will be returned.
+ ///
+ /// Without this convenience factory, constructing a sequence would require explicit,
+ /// difficult-to-read nesting:
+ ///
+ /// Declaration{"n3",
+ /// {
+ /// Declaration{"n2",
+ /// {
+ /// Declaration{"n1",
+ /// {
+ /// Declaration{"n0", N0Opts{}},
+ /// },
+ /// N1Opts{}},
+ /// },
+ /// N2Opts{}},
+ /// },
+ /// N3Opts{}};
+ ///
+ /// An equivalent Declaration can be constructed more tersely using Sequence:
+ ///
+ /// Declaration::Sequence({
+ /// {"n0", N0Opts{}},
+ /// {"n1", N1Opts{}},
+ /// {"n2", N2Opts{}},
+ /// {"n3", N3Opts{}},
+ /// });
+ static Declaration Sequence(std::vector<Declaration> decls);
+
+ Result<ExecNode*> AddToPlan(ExecPlan* plan, ExecFactoryRegistry* registry =
+ default_exec_factory_registry()) const;
+
+ std::string factory_name;
+ std::vector<Input> inputs;
+ std::shared_ptr<ExecNodeOptions> options;
+ std::string label;
+};
+
+/// \brief Wrap an ExecBatch generator in a RecordBatchReader.
+///
+/// The RecordBatchReader does not impose any ordering on emitted batches.
+ARROW_EXPORT
+std::shared_ptr<RecordBatchReader> MakeGeneratorReader(
+ std::shared_ptr<Schema>, std::function<Future<util::optional<ExecBatch>>()>,
+ MemoryPool*);
+
+constexpr int kDefaultBackgroundMaxQ = 32;
+constexpr int kDefaultBackgroundQRestart = 16;
+
+/// \brief Make a generator of RecordBatchReaders
+///
+/// Useful as a source node for an Exec plan
+ARROW_EXPORT
+Result<std::function<Future<util::optional<ExecBatch>>()>> MakeReaderGenerator(
+ std::shared_ptr<RecordBatchReader> reader, arrow::internal::Executor* io_executor,
+ int max_q = kDefaultBackgroundMaxQ, int q_restart = kDefaultBackgroundQRestart);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/expression.cc b/src/arrow/cpp/src/arrow/compute/exec/expression.cc
new file mode 100644
index 000000000..64e330582
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/expression.cc
@@ -0,0 +1,1192 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/expression.h"
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "arrow/chunked_array.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/exec/expression_internal.h"
+#include "arrow/compute/exec_internal.h"
+#include "arrow/compute/function_internal.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/util/hash_util.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/string.h"
+#include "arrow/util/value_parsing.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+void Expression::Call::ComputeHash() {
+ hash = std::hash<std::string>{}(function_name);
+ for (const auto& arg : arguments) {
+ arrow::internal::hash_combine(hash, arg.hash());
+ }
+}
+
+Expression::Expression(Call call) {
+ call.ComputeHash();
+ impl_ = std::make_shared<Impl>(std::move(call));
+}
+
+Expression::Expression(Datum literal)
+ : impl_(std::make_shared<Impl>(std::move(literal))) {}
+
+Expression::Expression(Parameter parameter)
+ : impl_(std::make_shared<Impl>(std::move(parameter))) {}
+
+Expression literal(Datum lit) { return Expression(std::move(lit)); }
+
+Expression field_ref(FieldRef ref) {
+ return Expression(Expression::Parameter{std::move(ref), ValueDescr{}, -1});
+}
+
+Expression call(std::string function, std::vector<Expression> arguments,
+ std::shared_ptr<compute::FunctionOptions> options) {
+ Expression::Call call;
+ call.function_name = std::move(function);
+ call.arguments = std::move(arguments);
+ call.options = std::move(options);
+ return Expression(std::move(call));
+}
+
+const Datum* Expression::literal() const { return util::get_if<Datum>(impl_.get()); }
+
+const Expression::Parameter* Expression::parameter() const {
+ return util::get_if<Parameter>(impl_.get());
+}
+
+const FieldRef* Expression::field_ref() const {
+ if (auto parameter = this->parameter()) {
+ return &parameter->ref;
+ }
+ return nullptr;
+}
+
+const Expression::Call* Expression::call() const {
+ return util::get_if<Call>(impl_.get());
+}
+
+ValueDescr Expression::descr() const {
+ if (impl_ == nullptr) return {};
+
+ if (auto lit = literal()) {
+ return lit->descr();
+ }
+
+ if (auto parameter = this->parameter()) {
+ return parameter->descr;
+ }
+
+ return CallNotNull(*this)->descr;
+}
+
+namespace {
+
+std::string PrintDatum(const Datum& datum) {
+ if (datum.is_scalar()) {
+ if (!datum.scalar()->is_valid) return "null";
+
+ switch (datum.type()->id()) {
+ case Type::STRING:
+ case Type::LARGE_STRING:
+ return '"' +
+ Escape(util::string_view(*datum.scalar_as<BaseBinaryScalar>().value)) +
+ '"';
+
+ case Type::BINARY:
+ case Type::FIXED_SIZE_BINARY:
+ case Type::LARGE_BINARY:
+ return '"' + datum.scalar_as<BaseBinaryScalar>().value->ToHexString() + '"';
+
+ default:
+ break;
+ }
+
+ return datum.scalar()->ToString();
+ }
+ return datum.ToString();
+}
+
+} // namespace
+
+std::string Expression::ToString() const {
+ if (auto lit = literal()) {
+ return PrintDatum(*lit);
+ }
+
+ if (auto ref = field_ref()) {
+ if (auto name = ref->name()) {
+ return *name;
+ }
+ if (auto path = ref->field_path()) {
+ return path->ToString();
+ }
+ return ref->ToString();
+ }
+
+ auto call = CallNotNull(*this);
+ auto binary = [&](std::string op) {
+ return "(" + call->arguments[0].ToString() + " " + op + " " +
+ call->arguments[1].ToString() + ")";
+ };
+
+ if (auto cmp = Comparison::Get(call->function_name)) {
+ return binary(Comparison::GetOp(*cmp));
+ }
+
+ constexpr util::string_view kleene = "_kleene";
+ if (util::string_view{call->function_name}.ends_with(kleene)) {
+ auto op = call->function_name.substr(0, call->function_name.size() - kleene.size());
+ return binary(std::move(op));
+ }
+
+ if (auto options = GetMakeStructOptions(*call)) {
+ std::string out = "{";
+ auto argument = call->arguments.begin();
+ for (const auto& field_name : options->field_names) {
+ out += field_name + "=" + argument++->ToString() + ", ";
+ }
+ out.resize(out.size() - 1);
+ out.back() = '}';
+ return out;
+ }
+
+ std::string out = call->function_name + "(";
+ for (const auto& arg : call->arguments) {
+ out += arg.ToString() + ", ";
+ }
+
+ if (call->options) {
+ out += call->options->ToString();
+ out.resize(out.size() + 1);
+ } else {
+ out.resize(out.size() - 1);
+ }
+ out.back() = ')';
+ return out;
+}
+
+void PrintTo(const Expression& expr, std::ostream* os) {
+ *os << expr.ToString();
+ if (expr.IsBound()) {
+ *os << "[bound]";
+ }
+}
+
+bool Expression::Equals(const Expression& other) const {
+ if (Identical(*this, other)) return true;
+
+ if (impl_->index() != other.impl_->index()) {
+ return false;
+ }
+
+ if (auto lit = literal()) {
+ return lit->Equals(*other.literal());
+ }
+
+ if (auto ref = field_ref()) {
+ return ref->Equals(*other.field_ref());
+ }
+
+ auto call = CallNotNull(*this);
+ auto other_call = CallNotNull(other);
+
+ if (call->function_name != other_call->function_name ||
+ call->kernel != other_call->kernel) {
+ return false;
+ }
+
+ for (size_t i = 0; i < call->arguments.size(); ++i) {
+ if (!call->arguments[i].Equals(other_call->arguments[i])) {
+ return false;
+ }
+ }
+
+ if (call->options == other_call->options) return true;
+ if (call->options && other_call->options) {
+ return call->options->Equals(other_call->options);
+ }
+ return false;
+}
+
+bool Identical(const Expression& l, const Expression& r) { return l.impl_ == r.impl_; }
+
+size_t Expression::hash() const {
+ if (auto lit = literal()) {
+ if (lit->is_scalar()) {
+ return lit->scalar()->hash();
+ }
+ return 0;
+ }
+
+ if (auto ref = field_ref()) {
+ return ref->hash();
+ }
+
+ return CallNotNull(*this)->hash;
+}
+
+bool Expression::IsBound() const {
+ if (type() == nullptr) return false;
+
+ if (auto call = this->call()) {
+ if (call->kernel == nullptr) return false;
+
+ for (const Expression& arg : call->arguments) {
+ if (!arg.IsBound()) return false;
+ }
+ }
+
+ return true;
+}
+
+bool Expression::IsScalarExpression() const {
+ if (auto lit = literal()) {
+ return lit->is_scalar();
+ }
+
+ if (field_ref()) return true;
+
+ auto call = CallNotNull(*this);
+
+ for (const Expression& arg : call->arguments) {
+ if (!arg.IsScalarExpression()) return false;
+ }
+
+ if (call->function) {
+ return call->function->kind() == compute::Function::SCALAR;
+ }
+
+ // this expression is not bound; make a best guess based on
+ // the default function registry
+ if (auto function = compute::GetFunctionRegistry()
+ ->GetFunction(call->function_name)
+ .ValueOr(nullptr)) {
+ return function->kind() == compute::Function::SCALAR;
+ }
+
+ // unknown function or other error; conservatively return false
+ return false;
+}
+
+bool Expression::IsNullLiteral() const {
+ if (auto lit = literal()) {
+ if (lit->null_count() == lit->length()) {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+bool Expression::IsSatisfiable() const {
+ if (type() && type()->id() == Type::NA) {
+ return false;
+ }
+
+ if (auto lit = literal()) {
+ if (lit->null_count() == lit->length()) {
+ return false;
+ }
+
+ if (lit->is_scalar() && lit->type()->id() == Type::BOOL) {
+ return lit->scalar_as<BooleanScalar>().value;
+ }
+ }
+
+ return true;
+}
+
+namespace {
+
+// Produce a bound Expression from unbound Call and bound arguments.
+Result<Expression> BindNonRecursive(Expression::Call call, bool insert_implicit_casts,
+ compute::ExecContext* exec_context) {
+ DCHECK(std::all_of(call.arguments.begin(), call.arguments.end(),
+ [](const Expression& argument) { return argument.IsBound(); }));
+
+ auto descrs = GetDescriptors(call.arguments);
+ ARROW_ASSIGN_OR_RAISE(call.function, GetFunction(call, exec_context));
+
+ if (!insert_implicit_casts) {
+ ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchExact(descrs));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&descrs));
+
+ for (size_t i = 0; i < descrs.size(); ++i) {
+ if (descrs[i] == call.arguments[i].descr()) continue;
+
+ if (descrs[i].shape != call.arguments[i].descr().shape) {
+ return Status::NotImplemented(
+ "Automatic broadcasting of scalars arguments to arrays in ",
+ Expression(std::move(call)).ToString());
+ }
+
+ if (auto lit = call.arguments[i].literal()) {
+ ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, descrs[i].type));
+ call.arguments[i] = literal(std::move(new_lit));
+ continue;
+ }
+
+ // construct an implicit cast Expression with which to replace this argument
+ Expression::Call implicit_cast;
+ implicit_cast.function_name = "cast";
+ implicit_cast.arguments = {std::move(call.arguments[i])};
+ implicit_cast.options = std::make_shared<compute::CastOptions>(
+ compute::CastOptions::Safe(descrs[i].type));
+
+ ARROW_ASSIGN_OR_RAISE(
+ call.arguments[i],
+ BindNonRecursive(std::move(implicit_cast),
+ /*insert_implicit_casts=*/false, exec_context));
+ }
+ }
+
+ compute::KernelContext kernel_context(exec_context);
+ if (call.kernel->init) {
+ ARROW_ASSIGN_OR_RAISE(
+ call.kernel_state,
+ call.kernel->init(&kernel_context, {call.kernel, descrs, call.options.get()}));
+
+ kernel_context.SetState(call.kernel_state.get());
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ call.descr, call.kernel->signature->out_type().Resolve(&kernel_context, descrs));
+
+ return Expression(std::move(call));
+}
+
+template <typename TypeOrSchema>
+Result<Expression> BindImpl(Expression expr, const TypeOrSchema& in,
+ ValueDescr::Shape shape, compute::ExecContext* exec_context) {
+ if (exec_context == nullptr) {
+ compute::ExecContext exec_context;
+ return BindImpl(std::move(expr), in, shape, &exec_context);
+ }
+
+ if (expr.literal()) return expr;
+
+ if (auto ref = expr.field_ref()) {
+ if (ref->IsNested()) {
+ return Status::NotImplemented("nested field references");
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto path, ref->FindOne(in));
+
+ auto bound = *expr.parameter();
+ bound.index = path[0];
+ ARROW_ASSIGN_OR_RAISE(auto field, path.Get(in));
+ bound.descr.type = field->type();
+ bound.descr.shape = shape;
+ return Expression{std::move(bound)};
+ }
+
+ auto call = *CallNotNull(expr);
+ for (auto& argument : call.arguments) {
+ ARROW_ASSIGN_OR_RAISE(argument,
+ BindImpl(std::move(argument), in, shape, exec_context));
+ }
+ return BindNonRecursive(std::move(call),
+ /*insert_implicit_casts=*/true, exec_context);
+}
+
+} // namespace
+
+Result<Expression> Expression::Bind(const ValueDescr& in,
+ compute::ExecContext* exec_context) const {
+ return BindImpl(*this, *in.type, in.shape, exec_context);
+}
+
+Result<Expression> Expression::Bind(const Schema& in_schema,
+ compute::ExecContext* exec_context) const {
+ return BindImpl(*this, in_schema, ValueDescr::ARRAY, exec_context);
+}
+
+Result<ExecBatch> MakeExecBatch(const Schema& full_schema, const Datum& partial) {
+ ExecBatch out;
+
+ if (partial.kind() == Datum::RECORD_BATCH) {
+ const auto& partial_batch = *partial.record_batch();
+ out.length = partial_batch.num_rows();
+
+ for (const auto& field : full_schema.fields()) {
+ ARROW_ASSIGN_OR_RAISE(auto column,
+ FieldRef(field->name()).GetOneOrNone(partial_batch));
+
+ if (column) {
+ if (!column->type()->Equals(field->type())) {
+ // Referenced field was present but didn't have the expected type.
+ // This *should* be handled by readers, and will just be an error in the future.
+ ARROW_ASSIGN_OR_RAISE(
+ auto converted,
+ compute::Cast(column, field->type(), compute::CastOptions::Safe()));
+ column = converted.make_array();
+ }
+ out.values.emplace_back(std::move(column));
+ } else {
+ out.values.emplace_back(MakeNullScalar(field->type()));
+ }
+ }
+ return out;
+ }
+
+ // wasteful but useful for testing:
+ if (partial.type()->id() == Type::STRUCT) {
+ if (partial.is_array()) {
+ ARROW_ASSIGN_OR_RAISE(auto partial_batch,
+ RecordBatch::FromStructArray(partial.make_array()));
+
+ return MakeExecBatch(full_schema, partial_batch);
+ }
+
+ if (partial.is_scalar()) {
+ ARROW_ASSIGN_OR_RAISE(auto partial_array,
+ MakeArrayFromScalar(*partial.scalar(), 1));
+ ARROW_ASSIGN_OR_RAISE(auto out, MakeExecBatch(full_schema, partial_array));
+
+ for (Datum& value : out.values) {
+ if (value.is_scalar()) continue;
+ ARROW_ASSIGN_OR_RAISE(value, value.make_array()->GetScalar(0));
+ }
+ return out;
+ }
+ }
+
+ return Status::NotImplemented("MakeExecBatch from ", PrintDatum(partial));
+}
+
+Result<Datum> ExecuteScalarExpression(const Expression& expr, const Schema& full_schema,
+ const Datum& partial_input,
+ compute::ExecContext* exec_context) {
+ ARROW_ASSIGN_OR_RAISE(auto input, MakeExecBatch(full_schema, partial_input));
+ return ExecuteScalarExpression(expr, input, exec_context);
+}
+
+Result<Datum> ExecuteScalarExpression(const Expression& expr, const ExecBatch& input,
+ compute::ExecContext* exec_context) {
+ if (exec_context == nullptr) {
+ compute::ExecContext exec_context;
+ return ExecuteScalarExpression(expr, input, &exec_context);
+ }
+
+ if (!expr.IsBound()) {
+ return Status::Invalid("Cannot Execute unbound expression.");
+ }
+
+ if (!expr.IsScalarExpression()) {
+ return Status::Invalid(
+ "ExecuteScalarExpression cannot Execute non-scalar expression ", expr.ToString());
+ }
+
+ if (auto lit = expr.literal()) return *lit;
+
+ if (auto param = expr.parameter()) {
+ if (param->descr.type->id() == Type::NA) {
+ return MakeNullScalar(null());
+ }
+
+ const Datum& field = input[param->index];
+ if (!field.type()->Equals(param->descr.type)) {
+ return Status::Invalid("Referenced field ", expr.ToString(), " was ",
+ field.type()->ToString(), " but should have been ",
+ param->descr.type->ToString());
+ }
+
+ return field;
+ }
+
+ auto call = CallNotNull(expr);
+
+ std::vector<Datum> arguments(call->arguments.size());
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(
+ arguments[i], ExecuteScalarExpression(call->arguments[i], input, exec_context));
+ }
+
+ auto executor = compute::detail::KernelExecutor::MakeScalar();
+
+ compute::KernelContext kernel_context(exec_context);
+ kernel_context.SetState(call->kernel_state.get());
+
+ auto kernel = call->kernel;
+ auto descrs = GetDescriptors(arguments);
+ auto options = call->options.get();
+ RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, descrs, options}));
+
+ compute::detail::DatumAccumulator listener;
+ RETURN_NOT_OK(executor->Execute(arguments, &listener));
+ const auto out = executor->WrapResults(arguments, listener.values());
+#ifndef NDEBUG
+ DCHECK_OK(executor->CheckResultType(out, call->function_name.c_str()));
+#endif
+ return out;
+}
+
+namespace {
+
+std::array<std::pair<const Expression&, const Expression&>, 2>
+ArgumentsAndFlippedArguments(const Expression::Call& call) {
+ DCHECK_EQ(call.arguments.size(), 2);
+ return {std::pair<const Expression&, const Expression&>{call.arguments[0],
+ call.arguments[1]},
+ std::pair<const Expression&, const Expression&>{call.arguments[1],
+ call.arguments[0]}};
+}
+
+template <typename BinOp, typename It,
+ typename Out = typename std::iterator_traits<It>::value_type>
+util::optional<Out> FoldLeft(It begin, It end, const BinOp& bin_op) {
+ if (begin == end) return util::nullopt;
+
+ Out folded = std::move(*begin++);
+ while (begin != end) {
+ folded = bin_op(std::move(folded), std::move(*begin++));
+ }
+ return folded;
+}
+
+util::optional<compute::NullHandling::type> GetNullHandling(
+ const Expression::Call& call) {
+ if (call.function && call.function->kind() == compute::Function::SCALAR) {
+ return static_cast<const compute::ScalarKernel*>(call.kernel)->null_handling;
+ }
+ return util::nullopt;
+}
+
+} // namespace
+
+std::vector<FieldRef> FieldsInExpression(const Expression& expr) {
+ if (expr.literal()) return {};
+
+ if (auto ref = expr.field_ref()) {
+ return {*ref};
+ }
+
+ std::vector<FieldRef> fields;
+ for (const Expression& arg : CallNotNull(expr)->arguments) {
+ auto argument_fields = FieldsInExpression(arg);
+ std::move(argument_fields.begin(), argument_fields.end(), std::back_inserter(fields));
+ }
+ return fields;
+}
+
+bool ExpressionHasFieldRefs(const Expression& expr) {
+ if (expr.literal()) return false;
+
+ if (expr.field_ref()) return true;
+
+ for (const Expression& arg : CallNotNull(expr)->arguments) {
+ if (ExpressionHasFieldRefs(arg)) return true;
+ }
+ return false;
+}
+
+Result<Expression> FoldConstants(Expression expr) {
+ return Modify(
+ std::move(expr), [](Expression expr) { return expr; },
+ [](Expression expr, ...) -> Result<Expression> {
+ auto call = CallNotNull(expr);
+ if (std::all_of(call->arguments.begin(), call->arguments.end(),
+ [](const Expression& argument) { return argument.literal(); })) {
+ // all arguments are literal; we can evaluate this subexpression *now*
+ static const ExecBatch ignored_input = ExecBatch{};
+ ARROW_ASSIGN_OR_RAISE(Datum constant,
+ ExecuteScalarExpression(expr, ignored_input));
+
+ return literal(std::move(constant));
+ }
+
+ // XXX the following should probably be in a registry of passes instead
+ // of inline
+
+ if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) {
+ // kernels which always produce intersected validity can be resolved
+ // to null *now* if any of their inputs is a null literal
+ for (const auto& argument : call->arguments) {
+ if (argument.IsNullLiteral()) {
+ return argument;
+ }
+ }
+ }
+
+ if (call->function_name == "and_kleene") {
+ for (auto args : ArgumentsAndFlippedArguments(*call)) {
+ // true and x == x
+ if (args.first == literal(true)) return args.second;
+
+ // false and x == false
+ if (args.first == literal(false)) return args.first;
+
+ // x and x == x
+ if (args.first == args.second) return args.first;
+ }
+ return expr;
+ }
+
+ if (call->function_name == "or_kleene") {
+ for (auto args : ArgumentsAndFlippedArguments(*call)) {
+ // false or x == x
+ if (args.first == literal(false)) return args.second;
+
+ // true or x == true
+ if (args.first == literal(true)) return args.first;
+
+ // x or x == x
+ if (args.first == args.second) return args.first;
+ }
+ return expr;
+ }
+
+ return expr;
+ });
+}
+
+namespace {
+
+std::vector<Expression> GuaranteeConjunctionMembers(
+ const Expression& guaranteed_true_predicate) {
+ auto guarantee = guaranteed_true_predicate.call();
+ if (!guarantee || guarantee->function_name != "and_kleene") {
+ return {guaranteed_true_predicate};
+ }
+ return FlattenedAssociativeChain(guaranteed_true_predicate).fringe;
+}
+
+// Conjunction members which are represented in known_values are erased from
+// conjunction_members
+Status ExtractKnownFieldValuesImpl(
+ std::vector<Expression>* conjunction_members,
+ std::unordered_map<FieldRef, Datum, FieldRef::Hash>* known_values) {
+ auto unconsumed_end =
+ std::partition(conjunction_members->begin(), conjunction_members->end(),
+ [](const Expression& expr) {
+ // search for an equality conditions between a field and a literal
+ auto call = expr.call();
+ if (!call) return true;
+
+ if (call->function_name == "equal") {
+ auto ref = call->arguments[0].field_ref();
+ auto lit = call->arguments[1].literal();
+ return !(ref && lit);
+ }
+
+ if (call->function_name == "is_null") {
+ auto ref = call->arguments[0].field_ref();
+ return !ref;
+ }
+
+ return true;
+ });
+
+ for (auto it = unconsumed_end; it != conjunction_members->end(); ++it) {
+ auto call = CallNotNull(*it);
+
+ if (call->function_name == "equal") {
+ auto ref = call->arguments[0].field_ref();
+ auto lit = call->arguments[1].literal();
+ known_values->emplace(*ref, *lit);
+ } else if (call->function_name == "is_null") {
+ auto ref = call->arguments[0].field_ref();
+ known_values->emplace(*ref, Datum(std::make_shared<NullScalar>()));
+ }
+ }
+
+ conjunction_members->erase(unconsumed_end, conjunction_members->end());
+
+ return Status::OK();
+}
+
+} // namespace
+
+Result<KnownFieldValues> ExtractKnownFieldValues(
+ const Expression& guaranteed_true_predicate) {
+ auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate);
+ KnownFieldValues known_values;
+ RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map));
+ return known_values;
+}
+
+Result<Expression> ReplaceFieldsWithKnownValues(const KnownFieldValues& known_values,
+ Expression expr) {
+ if (!expr.IsBound()) {
+ return Status::Invalid(
+ "ReplaceFieldsWithKnownValues called on an unbound Expression");
+ }
+
+ return Modify(
+ std::move(expr),
+ [&known_values](Expression expr) -> Result<Expression> {
+ if (auto ref = expr.field_ref()) {
+ auto it = known_values.map.find(*ref);
+ if (it != known_values.map.end()) {
+ Datum lit = it->second;
+ if (lit.descr() == expr.descr()) return literal(std::move(lit));
+ // type mismatch, try casting the known value to the correct type
+
+ if (expr.type()->id() == Type::DICTIONARY &&
+ lit.type()->id() != Type::DICTIONARY) {
+ // the known value must be dictionary encoded
+
+ const auto& dict_type = checked_cast<const DictionaryType&>(*expr.type());
+ if (!lit.type()->Equals(dict_type.value_type())) {
+ ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, dict_type.value_type()));
+ }
+
+ if (lit.is_scalar()) {
+ ARROW_ASSIGN_OR_RAISE(auto dictionary,
+ MakeArrayFromScalar(*lit.scalar(), 1));
+
+ lit = Datum{DictionaryScalar::Make(MakeScalar<int32_t>(0),
+ std::move(dictionary))};
+ }
+ }
+
+ ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, expr.type()));
+ return literal(std::move(lit));
+ }
+ }
+ return expr;
+ },
+ [](Expression expr, ...) { return expr; });
+}
+
+namespace {
+
+bool IsBinaryAssociativeCommutative(const Expression::Call& call) {
+ static std::unordered_set<std::string> binary_associative_commutative{
+ "and", "or", "and_kleene", "or_kleene", "xor",
+ "multiply", "add", "multiply_checked", "add_checked"};
+
+ auto it = binary_associative_commutative.find(call.function_name);
+ return it != binary_associative_commutative.end();
+}
+
+} // namespace
+
+Result<Expression> Canonicalize(Expression expr, compute::ExecContext* exec_context) {
+ if (exec_context == nullptr) {
+ compute::ExecContext exec_context;
+ return Canonicalize(std::move(expr), &exec_context);
+ }
+
+ // If potentially reconstructing more deeply than a call's immediate arguments
+ // (for example, when reorganizing an associative chain), add expressions to this set to
+ // avoid unnecessary work
+ struct {
+ std::unordered_set<Expression, Expression::Hash> set_;
+
+ bool operator()(const Expression& expr) const {
+ return set_.find(expr) != set_.end();
+ }
+
+ void Add(std::vector<Expression> exprs) {
+ std::move(exprs.begin(), exprs.end(), std::inserter(set_, set_.end()));
+ }
+ } AlreadyCanonicalized;
+
+ return Modify(
+ std::move(expr),
+ [&AlreadyCanonicalized, exec_context](Expression expr) -> Result<Expression> {
+ auto call = expr.call();
+ if (!call) return expr;
+
+ if (AlreadyCanonicalized(expr)) return expr;
+
+ if (IsBinaryAssociativeCommutative(*call)) {
+ struct {
+ int Priority(const Expression& operand) const {
+ // order literals first, starting with nulls
+ if (operand.IsNullLiteral()) return 0;
+ if (operand.literal()) return 1;
+ return 2;
+ }
+ bool operator()(const Expression& l, const Expression& r) const {
+ return Priority(l) < Priority(r);
+ }
+ } CanonicalOrdering;
+
+ FlattenedAssociativeChain chain(expr);
+ if (chain.was_left_folded &&
+ std::is_sorted(chain.fringe.begin(), chain.fringe.end(),
+ CanonicalOrdering)) {
+ AlreadyCanonicalized.Add(std::move(chain.exprs));
+ return expr;
+ }
+
+ std::stable_sort(chain.fringe.begin(), chain.fringe.end(), CanonicalOrdering);
+
+ // fold the chain back up
+ auto folded =
+ FoldLeft(chain.fringe.begin(), chain.fringe.end(),
+ [call, &AlreadyCanonicalized](Expression l, Expression r) {
+ auto canonicalized_call = *call;
+ canonicalized_call.arguments = {std::move(l), std::move(r)};
+ Expression expr(std::move(canonicalized_call));
+ AlreadyCanonicalized.Add({expr});
+ return expr;
+ });
+ return std::move(*folded);
+ }
+
+ if (auto cmp = Comparison::Get(call->function_name)) {
+ if (call->arguments[0].literal() && !call->arguments[1].literal()) {
+ // ensure that literals are on comparisons' RHS
+ auto flipped_call = *call;
+
+ std::swap(flipped_call.arguments[0], flipped_call.arguments[1]);
+ flipped_call.function_name =
+ Comparison::GetName(Comparison::GetFlipped(*cmp));
+
+ return BindNonRecursive(flipped_call,
+ /*insert_implicit_casts=*/false, exec_context);
+ }
+ }
+
+ return expr;
+ },
+ [](Expression expr, ...) { return expr; });
+}
+
+namespace {
+
+Result<Expression> DirectComparisonSimplification(Expression expr,
+ const Expression::Call& guarantee) {
+ return Modify(
+ std::move(expr), [](Expression expr) { return expr; },
+ [&guarantee](Expression expr, ...) -> Result<Expression> {
+ auto call = expr.call();
+ if (!call) return expr;
+
+ // Ensure both calls are comparisons with equal LHS and scalar RHS
+ auto cmp = Comparison::Get(expr);
+ auto cmp_guarantee = Comparison::Get(guarantee.function_name);
+
+ if (!cmp) return expr;
+ if (!cmp_guarantee) return expr;
+
+ const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]);
+ const auto& guarantee_lhs = guarantee.arguments[0];
+ if (lhs != guarantee_lhs) return expr;
+
+ auto rhs = call->arguments[1].literal();
+ auto guarantee_rhs = guarantee.arguments[1].literal();
+
+ if (!rhs) return expr;
+ if (!rhs->is_scalar()) return expr;
+
+ if (!guarantee_rhs) return expr;
+ if (!guarantee_rhs->is_scalar()) return expr;
+
+ ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs,
+ Comparison::Execute(*rhs, *guarantee_rhs));
+ DCHECK_NE(cmp_rhs_guarantee_rhs, Comparison::NA);
+
+ if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) {
+ // RHS of filter is equal to RHS of guarantee
+
+ if ((*cmp & *cmp_guarantee) == *cmp_guarantee) {
+ // guarantee is a subset of filter, so all data will be included
+ // x > 1, x >= 1, x != 1 guaranteed by x > 1
+ return literal(true);
+ }
+
+ if ((*cmp & *cmp_guarantee) == 0) {
+ // guarantee disjoint with filter, so all data will be excluded
+ // x > 1, x >= 1, x != 1 unsatisfiable if x == 1
+ return literal(false);
+ }
+
+ return expr;
+ }
+
+ if (*cmp_guarantee & cmp_rhs_guarantee_rhs) {
+ // x > 1, x >= 1, x != 1 cannot use guarantee x >= 3
+ return expr;
+ }
+
+ if (*cmp & Comparison::GetFlipped(cmp_rhs_guarantee_rhs)) {
+ // x > 1, x >= 1, x != 1 guaranteed by x >= 3
+ return literal(true);
+ } else {
+ // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3
+ return literal(false);
+ }
+ });
+}
+
+} // namespace
+
+Result<Expression> SimplifyWithGuarantee(Expression expr,
+ const Expression& guaranteed_true_predicate) {
+ auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate);
+
+ KnownFieldValues known_values;
+ RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map));
+
+ ARROW_ASSIGN_OR_RAISE(expr,
+ ReplaceFieldsWithKnownValues(known_values, std::move(expr)));
+
+ auto CanonicalizeAndFoldConstants = [&expr] {
+ ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(expr)));
+ ARROW_ASSIGN_OR_RAISE(expr, FoldConstants(std::move(expr)));
+ return Status::OK();
+ };
+ RETURN_NOT_OK(CanonicalizeAndFoldConstants());
+
+ for (const auto& guarantee : conjunction_members) {
+ if (Comparison::Get(guarantee) && guarantee.call()->arguments[1].literal()) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto simplified, DirectComparisonSimplification(expr, *CallNotNull(guarantee)));
+
+ if (Identical(simplified, expr)) continue;
+
+ expr = std::move(simplified);
+ RETURN_NOT_OK(CanonicalizeAndFoldConstants());
+ }
+ }
+
+ return expr;
+}
+
+// Serialization is accomplished by converting expressions to KeyValueMetadata and storing
+// this in the schema of a RecordBatch. Embedded arrays and scalars are stored in its
+// columns. Finally, the RecordBatch is written to an IPC file.
+Result<std::shared_ptr<Buffer>> Serialize(const Expression& expr) {
+ struct {
+ std::shared_ptr<KeyValueMetadata> metadata_ = std::make_shared<KeyValueMetadata>();
+ ArrayVector columns_;
+
+ Result<std::string> AddScalar(const Scalar& scalar) {
+ auto ret = columns_.size();
+ ARROW_ASSIGN_OR_RAISE(auto array, MakeArrayFromScalar(scalar, 1));
+ columns_.push_back(std::move(array));
+ return std::to_string(ret);
+ }
+
+ Status Visit(const Expression& expr) {
+ if (auto lit = expr.literal()) {
+ if (!lit->is_scalar()) {
+ return Status::NotImplemented("Serialization of non-scalar literals");
+ }
+ ARROW_ASSIGN_OR_RAISE(auto value, AddScalar(*lit->scalar()));
+ metadata_->Append("literal", std::move(value));
+ return Status::OK();
+ }
+
+ if (auto ref = expr.field_ref()) {
+ if (!ref->name()) {
+ return Status::NotImplemented("Serialization of non-name field_refs");
+ }
+ metadata_->Append("field_ref", *ref->name());
+ return Status::OK();
+ }
+
+ auto call = CallNotNull(expr);
+ metadata_->Append("call", call->function_name);
+
+ for (const auto& argument : call->arguments) {
+ RETURN_NOT_OK(Visit(argument));
+ }
+
+ if (call->options) {
+ ARROW_ASSIGN_OR_RAISE(auto options_scalar,
+ internal::FunctionOptionsToStructScalar(*call->options));
+ ARROW_ASSIGN_OR_RAISE(auto value, AddScalar(*options_scalar));
+ metadata_->Append("options", std::move(value));
+ }
+
+ metadata_->Append("end", call->function_name);
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<RecordBatch>> operator()(const Expression& expr) {
+ RETURN_NOT_OK(Visit(expr));
+ FieldVector fields(columns_.size());
+ for (size_t i = 0; i < fields.size(); ++i) {
+ fields[i] = field("", columns_[i]->type());
+ }
+ return RecordBatch::Make(schema(std::move(fields), std::move(metadata_)), 1,
+ std::move(columns_));
+ }
+ } ToRecordBatch;
+
+ ARROW_ASSIGN_OR_RAISE(auto batch, ToRecordBatch(expr));
+ ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create());
+ ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(stream, batch->schema()));
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+ RETURN_NOT_OK(writer->Close());
+ return stream->Finish();
+}
+
+Result<Expression> Deserialize(std::shared_ptr<Buffer> buffer) {
+ io::BufferReader stream(std::move(buffer));
+ ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream));
+ ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0));
+ if (batch->schema()->metadata() == nullptr) {
+ return Status::Invalid("serialized Expression's batch repr had null metadata");
+ }
+ if (batch->num_rows() != 1) {
+ return Status::Invalid(
+ "serialized Expression's batch repr was not a single row - had ",
+ batch->num_rows());
+ }
+
+ struct FromRecordBatch {
+ const RecordBatch& batch_;
+ int index_;
+
+ const KeyValueMetadata& metadata() { return *batch_.schema()->metadata(); }
+
+ Result<std::shared_ptr<Scalar>> GetScalar(const std::string& i) {
+ int32_t column_index;
+ if (!::arrow::internal::ParseValue<Int32Type>(i.data(), i.length(),
+ &column_index)) {
+ return Status::Invalid("Couldn't parse column_index");
+ }
+ if (column_index >= batch_.num_columns()) {
+ return Status::Invalid("column_index out of bounds");
+ }
+ return batch_.column(column_index)->GetScalar(0);
+ }
+
+ Result<Expression> GetOne() {
+ if (index_ >= metadata().size()) {
+ return Status::Invalid("unterminated serialized Expression");
+ }
+
+ const std::string& key = metadata().key(index_);
+ const std::string& value = metadata().value(index_);
+ ++index_;
+
+ if (key == "literal") {
+ ARROW_ASSIGN_OR_RAISE(auto scalar, GetScalar(value));
+ return literal(std::move(scalar));
+ }
+
+ if (key == "field_ref") {
+ return field_ref(value);
+ }
+
+ if (key != "call") {
+ return Status::Invalid("Unrecognized serialized Expression key ", key);
+ }
+
+ std::vector<Expression> arguments;
+ while (metadata().key(index_) != "end") {
+ if (metadata().key(index_) == "options") {
+ ARROW_ASSIGN_OR_RAISE(auto options_scalar, GetScalar(metadata().value(index_)));
+ std::shared_ptr<compute::FunctionOptions> options;
+ if (options_scalar) {
+ ARROW_ASSIGN_OR_RAISE(
+ options, internal::FunctionOptionsFromStructScalar(
+ checked_cast<const StructScalar&>(*options_scalar)));
+ }
+ auto expr = call(value, std::move(arguments), std::move(options));
+ index_ += 2;
+ return expr;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto argument, GetOne());
+ arguments.push_back(std::move(argument));
+ }
+
+ ++index_;
+ return call(value, std::move(arguments));
+ }
+ };
+
+ return FromRecordBatch{*batch, 0}.GetOne();
+}
+
+Expression project(std::vector<Expression> values, std::vector<std::string> names) {
+ return call("make_struct", std::move(values),
+ compute::MakeStructOptions{std::move(names)});
+}
+
+Expression equal(Expression lhs, Expression rhs) {
+ return call("equal", {std::move(lhs), std::move(rhs)});
+}
+
+Expression not_equal(Expression lhs, Expression rhs) {
+ return call("not_equal", {std::move(lhs), std::move(rhs)});
+}
+
+Expression less(Expression lhs, Expression rhs) {
+ return call("less", {std::move(lhs), std::move(rhs)});
+}
+
+Expression less_equal(Expression lhs, Expression rhs) {
+ return call("less_equal", {std::move(lhs), std::move(rhs)});
+}
+
+Expression greater(Expression lhs, Expression rhs) {
+ return call("greater", {std::move(lhs), std::move(rhs)});
+}
+
+Expression greater_equal(Expression lhs, Expression rhs) {
+ return call("greater_equal", {std::move(lhs), std::move(rhs)});
+}
+
+Expression is_null(Expression lhs, bool nan_is_null) {
+ return call("is_null", {std::move(lhs)}, compute::NullOptions(std::move(nan_is_null)));
+}
+
+Expression is_valid(Expression lhs) { return call("is_valid", {std::move(lhs)}); }
+
+Expression and_(Expression lhs, Expression rhs) {
+ return call("and_kleene", {std::move(lhs), std::move(rhs)});
+}
+
+Expression and_(const std::vector<Expression>& operands) {
+ auto folded = FoldLeft<Expression(Expression, Expression)>(operands.begin(),
+ operands.end(), and_);
+ if (folded) {
+ return std::move(*folded);
+ }
+ return literal(true);
+}
+
+Expression or_(Expression lhs, Expression rhs) {
+ return call("or_kleene", {std::move(lhs), std::move(rhs)});
+}
+
+Expression or_(const std::vector<Expression>& operands) {
+ auto folded =
+ FoldLeft<Expression(Expression, Expression)>(operands.begin(), operands.end(), or_);
+ if (folded) {
+ return std::move(*folded);
+ }
+ return literal(false);
+}
+
+Expression not_(Expression operand) { return call("invert", {std::move(operand)}); }
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/expression.h b/src/arrow/cpp/src/arrow/compute/exec/expression.h
new file mode 100644
index 000000000..dac5728ab
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/expression.h
@@ -0,0 +1,269 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/compute/type_fwd.h"
+#include "arrow/datum.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/variant.h"
+
+namespace arrow {
+namespace compute {
+
+/// An unbound expression which maps a single Datum to another Datum.
+/// An expression is one of
+/// - A literal Datum.
+/// - A reference to a single (potentially nested) field of the input Datum.
+/// - A call to a compute function, with arguments specified by other Expressions.
+class ARROW_EXPORT Expression {
+ public:
+ struct Call {
+ std::string function_name;
+ std::vector<Expression> arguments;
+ std::shared_ptr<FunctionOptions> options;
+ // Cached hash value
+ size_t hash;
+
+ // post-Bind properties:
+ std::shared_ptr<Function> function;
+ const Kernel* kernel = NULLPTR;
+ std::shared_ptr<KernelState> kernel_state;
+ ValueDescr descr;
+
+ void ComputeHash();
+ };
+
+ std::string ToString() const;
+ bool Equals(const Expression& other) const;
+ size_t hash() const;
+ struct Hash {
+ size_t operator()(const Expression& expr) const { return expr.hash(); }
+ };
+
+ /// Bind this expression to the given input type, looking up Kernels and field types.
+ /// Some expression simplification may be performed and implicit casts will be inserted.
+ /// Any state necessary for execution will be initialized and returned.
+ Result<Expression> Bind(const ValueDescr& in, ExecContext* = NULLPTR) const;
+ Result<Expression> Bind(const Schema& in_schema, ExecContext* = NULLPTR) const;
+
+ // XXX someday
+ // Clone all KernelState in this bound expression. If any function referenced by this
+ // expression has mutable KernelState, it is not safe to execute or apply simplification
+ // passes to it (or copies of it!) from multiple threads. Cloning state produces new
+ // KernelStates where necessary to ensure that Expressions may be manipulated safely
+ // on multiple threads.
+ // Result<ExpressionState> CloneState() const;
+ // Status SetState(ExpressionState);
+
+ /// Return true if all an expression's field references have explicit ValueDescr and all
+ /// of its functions' kernels are looked up.
+ bool IsBound() const;
+
+ /// Return true if this expression is composed only of Scalar literals, field
+ /// references, and calls to ScalarFunctions.
+ bool IsScalarExpression() const;
+
+ /// Return true if this expression is literal and entirely null.
+ bool IsNullLiteral() const;
+
+ /// Return true if this expression could evaluate to true.
+ bool IsSatisfiable() const;
+
+ // XXX someday
+ // Result<PipelineGraph> GetPipelines();
+
+ /// Access a Call or return nullptr if this expression is not a call
+ const Call* call() const;
+ /// Access a Datum or return nullptr if this expression is not a literal
+ const Datum* literal() const;
+ /// Access a FieldRef or return nullptr if this expression is not a field_ref
+ const FieldRef* field_ref() const;
+
+ /// The type and shape to which this expression will evaluate
+ ValueDescr descr() const;
+ std::shared_ptr<DataType> type() const { return descr().type; }
+ // XXX someday
+ // NullGeneralization::type nullable() const;
+
+ struct Parameter {
+ FieldRef ref;
+
+ // post-bind properties
+ ValueDescr descr;
+ int index;
+ };
+ const Parameter* parameter() const;
+
+ Expression() = default;
+ explicit Expression(Call call);
+ explicit Expression(Datum literal);
+ explicit Expression(Parameter parameter);
+
+ private:
+ using Impl = util::Variant<Datum, Parameter, Call>;
+ std::shared_ptr<Impl> impl_;
+
+ ARROW_EXPORT friend bool Identical(const Expression& l, const Expression& r);
+
+ ARROW_EXPORT friend void PrintTo(const Expression&, std::ostream*);
+};
+
+inline bool operator==(const Expression& l, const Expression& r) { return l.Equals(r); }
+inline bool operator!=(const Expression& l, const Expression& r) { return !l.Equals(r); }
+
+// Factories
+
+ARROW_EXPORT
+Expression literal(Datum lit);
+
+template <typename Arg>
+Expression literal(Arg&& arg) {
+ return literal(Datum(std::forward<Arg>(arg)));
+}
+
+ARROW_EXPORT
+Expression field_ref(FieldRef ref);
+
+ARROW_EXPORT
+Expression call(std::string function, std::vector<Expression> arguments,
+ std::shared_ptr<FunctionOptions> options = NULLPTR);
+
+template <typename Options, typename = typename std::enable_if<
+ std::is_base_of<FunctionOptions, Options>::value>::type>
+Expression call(std::string function, std::vector<Expression> arguments,
+ Options options) {
+ return call(std::move(function), std::move(arguments),
+ std::make_shared<Options>(std::move(options)));
+}
+
+/// Assemble a list of all fields referenced by an Expression at any depth.
+ARROW_EXPORT
+std::vector<FieldRef> FieldsInExpression(const Expression&);
+
+/// Check if the expression references any fields.
+ARROW_EXPORT
+bool ExpressionHasFieldRefs(const Expression&);
+
+/// Assemble a mapping from field references to known values.
+struct ARROW_EXPORT KnownFieldValues;
+ARROW_EXPORT
+Result<KnownFieldValues> ExtractKnownFieldValues(
+ const Expression& guaranteed_true_predicate);
+
+/// \defgroup expression-passes Functions for modification of Expressions
+///
+/// @{
+///
+/// These transform bound expressions. Some transforms utilize a guarantee, which is
+/// provided as an Expression which is guaranteed to evaluate to true. The
+/// guaranteed_true_predicate need not be bound, but canonicalization is currently
+/// deferred to producers of guarantees. For example in order to be recognized as a
+/// guarantee on a field value, an Expression must be a call to "equal" with field_ref LHS
+/// and literal RHS. Flipping the arguments, "is_in" with a one-long value_set, ... or
+/// other semantically identical Expressions will not be recognized.
+
+/// Weak canonicalization which establishes guarantees for subsequent passes. Even
+/// equivalent Expressions may result in different canonicalized expressions.
+/// TODO this could be a strong canonicalization
+ARROW_EXPORT
+Result<Expression> Canonicalize(Expression, ExecContext* = NULLPTR);
+
+/// Simplify Expressions based on literal arguments (for example, add(null, x) will always
+/// be null so replace the call with a null literal). Includes early evaluation of all
+/// calls whose arguments are entirely literal.
+ARROW_EXPORT
+Result<Expression> FoldConstants(Expression);
+
+/// Simplify Expressions by replacing with known values of the fields which it references.
+ARROW_EXPORT
+Result<Expression> ReplaceFieldsWithKnownValues(const KnownFieldValues& known_values,
+ Expression);
+
+/// Simplify an expression by replacing subexpressions based on a guarantee:
+/// a boolean expression which is guaranteed to evaluate to `true`. For example, this is
+/// used to remove redundant function calls from a filter expression or to replace a
+/// reference to a constant-value field with a literal.
+ARROW_EXPORT
+Result<Expression> SimplifyWithGuarantee(Expression,
+ const Expression& guaranteed_true_predicate);
+
+/// @}
+
+// Execution
+
+/// Create an ExecBatch suitable for passing to ExecuteScalarExpression() from a
+/// RecordBatch which may have missing or incorrectly ordered columns.
+/// Missing fields will be replaced with null scalars.
+ARROW_EXPORT Result<ExecBatch> MakeExecBatch(const Schema& full_schema,
+ const Datum& partial);
+
+/// Execute a scalar expression against the provided state and input ExecBatch. This
+/// expression must be bound.
+ARROW_EXPORT
+Result<Datum> ExecuteScalarExpression(const Expression&, const ExecBatch& input,
+ ExecContext* = NULLPTR);
+
+/// Convenience function for invoking against a RecordBatch
+ARROW_EXPORT
+Result<Datum> ExecuteScalarExpression(const Expression&, const Schema& full_schema,
+ const Datum& partial_input, ExecContext* = NULLPTR);
+
+// Serialization
+
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> Serialize(const Expression&);
+
+ARROW_EXPORT
+Result<Expression> Deserialize(std::shared_ptr<Buffer>);
+
+// Convenience aliases for factories
+
+ARROW_EXPORT Expression project(std::vector<Expression> values,
+ std::vector<std::string> names);
+
+ARROW_EXPORT Expression equal(Expression lhs, Expression rhs);
+
+ARROW_EXPORT Expression not_equal(Expression lhs, Expression rhs);
+
+ARROW_EXPORT Expression less(Expression lhs, Expression rhs);
+
+ARROW_EXPORT Expression less_equal(Expression lhs, Expression rhs);
+
+ARROW_EXPORT Expression greater(Expression lhs, Expression rhs);
+
+ARROW_EXPORT Expression greater_equal(Expression lhs, Expression rhs);
+
+ARROW_EXPORT Expression is_null(Expression lhs, bool nan_is_null = false);
+
+ARROW_EXPORT Expression is_valid(Expression lhs);
+
+ARROW_EXPORT Expression and_(Expression lhs, Expression rhs);
+ARROW_EXPORT Expression and_(const std::vector<Expression>&);
+ARROW_EXPORT Expression or_(Expression lhs, Expression rhs);
+ARROW_EXPORT Expression or_(const std::vector<Expression>&);
+ARROW_EXPORT Expression not_(Expression operand);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/expression_benchmark.cc b/src/arrow/cpp/src/arrow/compute/exec/expression_benchmark.cc
new file mode 100644
index 000000000..1899b7caa
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/expression_benchmark.cc
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/compute/cast.h"
+#include "arrow/compute/exec/expression.h"
+#include "arrow/dataset/partition.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+
+namespace arrow {
+namespace compute {
+
+std::shared_ptr<Scalar> ninety_nine_dict =
+ DictionaryScalar::Make(MakeScalar(0), ArrayFromJSON(int64(), "[99]"));
+
+// A benchmark of SimplifyWithGuarantee using expressions arising from partitioning.
+static void SimplifyFilterWithGuarantee(benchmark::State& state, Expression filter,
+ Expression guarantee) {
+ auto dataset_schema = schema({field("a", int64()), field("b", int64())});
+ ASSIGN_OR_ABORT(filter, filter.Bind(*dataset_schema));
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(SimplifyWithGuarantee(filter, guarantee));
+ }
+}
+
+auto to_int64 = compute::CastOptions::Safe(int64());
+// A fully simplified filter.
+auto filter_simple_negative = and_(equal(field_ref("a"), literal(int64_t(99))),
+ equal(field_ref("b"), literal(int64_t(98))));
+auto filter_simple_positive = and_(equal(field_ref("a"), literal(int64_t(99))),
+ equal(field_ref("b"), literal(int64_t(99))));
+// A filter with casts inserted due to converting between the
+// assumed-by-default type and the inferred partition schema.
+auto filter_cast_negative =
+ and_(equal(call("cast", {field_ref("a")}, to_int64), literal(99)),
+ equal(call("cast", {field_ref("b")}, to_int64), literal(98)));
+auto filter_cast_positive =
+ and_(equal(call("cast", {field_ref("a")}, to_int64), literal(99)),
+ equal(call("cast", {field_ref("b")}, to_int64), literal(99)));
+
+// An unencoded partition expression for "a=99/b=99".
+auto guarantee = and_(equal(field_ref("a"), literal(int64_t(99))),
+ equal(field_ref("b"), literal(int64_t(99))));
+
+// A partition expression for "a=99/b=99" that uses dictionaries (inferred by default).
+auto guarantee_dictionary = and_(equal(field_ref("a"), literal(ninety_nine_dict)),
+ equal(field_ref("b"), literal(ninety_nine_dict)));
+
+// Negative queries (partition expressions that fail the filter)
+BENCHMARK_CAPTURE(SimplifyFilterWithGuarantee, negative_filter_simple_guarantee_simple,
+ filter_simple_negative, guarantee);
+BENCHMARK_CAPTURE(SimplifyFilterWithGuarantee, negative_filter_cast_guarantee_simple,
+ filter_cast_negative, guarantee);
+BENCHMARK_CAPTURE(SimplifyFilterWithGuarantee,
+ negative_filter_simple_guarantee_dictionary, filter_simple_negative,
+ guarantee_dictionary);
+BENCHMARK_CAPTURE(SimplifyFilterWithGuarantee, negative_filter_cast_guarantee_dictionary,
+ filter_cast_negative, guarantee_dictionary);
+// Positive queries (partition expressions that pass the filter)
+BENCHMARK_CAPTURE(SimplifyFilterWithGuarantee, positive_filter_simple_guarantee_simple,
+ filter_simple_positive, guarantee);
+BENCHMARK_CAPTURE(SimplifyFilterWithGuarantee, positive_filter_cast_guarantee_simple,
+ filter_cast_positive, guarantee);
+BENCHMARK_CAPTURE(SimplifyFilterWithGuarantee,
+ positive_filter_simple_guarantee_dictionary, filter_simple_positive,
+ guarantee_dictionary);
+BENCHMARK_CAPTURE(SimplifyFilterWithGuarantee, positive_filter_cast_guarantee_dictionary,
+ filter_cast_positive, guarantee_dictionary);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/expression_internal.h b/src/arrow/cpp/src/arrow/compute/exec/expression_internal.h
new file mode 100644
index 000000000..dc38924d9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/expression_internal.h
@@ -0,0 +1,336 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/expression.h"
+
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/registry.h"
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+struct KnownFieldValues {
+ std::unordered_map<FieldRef, Datum, FieldRef::Hash> map;
+};
+
+inline const Expression::Call* CallNotNull(const Expression& expr) {
+ auto call = expr.call();
+ DCHECK_NE(call, nullptr);
+ return call;
+}
+
+inline std::vector<ValueDescr> GetDescriptors(const std::vector<Expression>& exprs) {
+ std::vector<ValueDescr> descrs(exprs.size());
+ for (size_t i = 0; i < exprs.size(); ++i) {
+ DCHECK(exprs[i].IsBound());
+ descrs[i] = exprs[i].descr();
+ }
+ return descrs;
+}
+
+inline std::vector<ValueDescr> GetDescriptors(const std::vector<Datum>& values) {
+ std::vector<ValueDescr> descrs(values.size());
+ for (size_t i = 0; i < values.size(); ++i) {
+ descrs[i] = values[i].descr();
+ }
+ return descrs;
+}
+
+struct Comparison {
+ enum type {
+ NA = 0,
+ EQUAL = 1,
+ LESS = 2,
+ GREATER = 4,
+ NOT_EQUAL = LESS | GREATER,
+ LESS_EQUAL = LESS | EQUAL,
+ GREATER_EQUAL = GREATER | EQUAL,
+ };
+
+ static const type* Get(const std::string& function) {
+ static std::unordered_map<std::string, type> map{
+ {"equal", EQUAL}, {"not_equal", NOT_EQUAL},
+ {"less", LESS}, {"less_equal", LESS_EQUAL},
+ {"greater", GREATER}, {"greater_equal", GREATER_EQUAL},
+ };
+
+ auto it = map.find(function);
+ return it != map.end() ? &it->second : nullptr;
+ }
+
+ static const type* Get(const Expression& expr) {
+ if (auto call = expr.call()) {
+ return Comparison::Get(call->function_name);
+ }
+ return nullptr;
+ }
+
+ // Execute a simple Comparison between scalars
+ static Result<type> Execute(Datum l, Datum r) {
+ if (!l.is_scalar() || !r.is_scalar()) {
+ return Status::Invalid("Cannot Execute Comparison on non-scalars");
+ }
+
+ std::vector<Datum> arguments{std::move(l), std::move(r)};
+
+ ARROW_ASSIGN_OR_RAISE(auto equal, compute::CallFunction("equal", arguments));
+
+ if (!equal.scalar()->is_valid) return NA;
+ if (equal.scalar_as<BooleanScalar>().value) return EQUAL;
+
+ ARROW_ASSIGN_OR_RAISE(auto less, compute::CallFunction("less", arguments));
+
+ if (!less.scalar()->is_valid) return NA;
+ return less.scalar_as<BooleanScalar>().value ? LESS : GREATER;
+ }
+
+ // Given an Expression wrapped in casts which preserve ordering
+ // (for example, cast(field_ref("i16"), to_type=int32())), unwrap the inner Expression.
+ // This is used to destructure implicitly cast field_refs during Expression
+ // simplification.
+ static const Expression& StripOrderPreservingCasts(const Expression& expr) {
+ auto call = expr.call();
+ if (!call) return expr;
+ if (call->function_name != "cast") return expr;
+
+ const Expression& from = call->arguments[0];
+
+ auto from_id = from.type()->id();
+ auto to_id = expr.type()->id();
+
+ if (is_floating(to_id)) {
+ if (is_integer(from_id) || is_floating(from_id)) {
+ return StripOrderPreservingCasts(from);
+ }
+ return expr;
+ }
+
+ if (is_unsigned_integer(to_id)) {
+ if (is_unsigned_integer(from_id) && bit_width(to_id) >= bit_width(from_id)) {
+ return StripOrderPreservingCasts(from);
+ }
+ return expr;
+ }
+
+ if (is_signed_integer(to_id)) {
+ if (is_integer(from_id) && bit_width(to_id) >= bit_width(from_id)) {
+ return StripOrderPreservingCasts(from);
+ }
+ return expr;
+ }
+
+ return expr;
+ }
+
+ static type GetFlipped(type op) {
+ switch (op) {
+ case NA:
+ return NA;
+ case EQUAL:
+ return EQUAL;
+ case LESS:
+ return GREATER;
+ case GREATER:
+ return LESS;
+ case NOT_EQUAL:
+ return NOT_EQUAL;
+ case LESS_EQUAL:
+ return GREATER_EQUAL;
+ case GREATER_EQUAL:
+ return LESS_EQUAL;
+ }
+ DCHECK(false);
+ return NA;
+ }
+
+ static std::string GetName(type op) {
+ switch (op) {
+ case NA:
+ break;
+ case EQUAL:
+ return "equal";
+ case LESS:
+ return "less";
+ case GREATER:
+ return "greater";
+ case NOT_EQUAL:
+ return "not_equal";
+ case LESS_EQUAL:
+ return "less_equal";
+ case GREATER_EQUAL:
+ return "greater_equal";
+ }
+ return "na";
+ }
+
+ static std::string GetOp(type op) {
+ switch (op) {
+ case NA:
+ DCHECK(false) << "unreachable";
+ break;
+ case EQUAL:
+ return "==";
+ case LESS:
+ return "<";
+ case GREATER:
+ return ">";
+ case NOT_EQUAL:
+ return "!=";
+ case LESS_EQUAL:
+ return "<=";
+ case GREATER_EQUAL:
+ return ">=";
+ }
+ DCHECK(false);
+ return "";
+ }
+};
+
+inline const compute::CastOptions* GetCastOptions(const Expression::Call& call) {
+ if (call.function_name != "cast") return nullptr;
+ return checked_cast<const compute::CastOptions*>(call.options.get());
+}
+
+inline bool IsSetLookup(const std::string& function) {
+ return function == "is_in" || function == "index_in";
+}
+
+inline const compute::MakeStructOptions* GetMakeStructOptions(
+ const Expression::Call& call) {
+ if (call.function_name != "make_struct") return nullptr;
+ return checked_cast<const compute::MakeStructOptions*>(call.options.get());
+}
+
+/// A helper for unboxing an Expression composed of associative function calls.
+/// Such expressions can frequently be rearranged to a semantically equivalent
+/// expression for more optimal execution or more straightforward manipulation.
+/// For example, (a + ((b + 3) + 4)) is equivalent to (((4 + 3) + a) + b) and the latter
+/// can be trivially constant-folded to ((7 + a) + b).
+struct FlattenedAssociativeChain {
+ /// True if a chain was already a left fold.
+ bool was_left_folded = true;
+
+ /// All "branch" expressions in a flattened chain. For example given (a + ((b + 3) + 4))
+ /// exprs would be [(a + ((b + 3) + 4)), ((b + 3) + 4), (b + 3)]
+ std::vector<Expression> exprs;
+
+ /// All "leaf" expressions in a flattened chain. For example given (a + ((b + 3) + 4))
+ /// the fringe would be [a, b, 3, 4]
+ std::vector<Expression> fringe;
+
+ explicit FlattenedAssociativeChain(Expression expr) : exprs{std::move(expr)} {
+ auto call = CallNotNull(exprs.back());
+ fringe = call->arguments;
+
+ auto it = fringe.begin();
+
+ while (it != fringe.end()) {
+ auto sub_call = it->call();
+ if (!sub_call || sub_call->function_name != call->function_name) {
+ ++it;
+ continue;
+ }
+
+ if (it != fringe.begin()) {
+ was_left_folded = false;
+ }
+
+ exprs.push_back(std::move(*it));
+ it = fringe.erase(it);
+
+ auto index = it - fringe.begin();
+ fringe.insert(it, sub_call->arguments.begin(), sub_call->arguments.end());
+ it = fringe.begin() + index;
+ // NB: no increment so we hit sub_call's first argument next iteration
+ }
+
+ DCHECK(std::all_of(exprs.begin(), exprs.end(), [](const Expression& expr) {
+ return CallNotNull(expr)->options == nullptr;
+ }));
+ }
+};
+
+inline Result<std::shared_ptr<compute::Function>> GetFunction(
+ const Expression::Call& call, compute::ExecContext* exec_context) {
+ if (call.function_name != "cast") {
+ return exec_context->func_registry()->GetFunction(call.function_name);
+ }
+ // XXX this special case is strange; why not make "cast" a ScalarFunction?
+ const auto& to_type = checked_cast<const compute::CastOptions&>(*call.options).to_type;
+ return compute::GetCastFunction(to_type);
+}
+
+/// Modify an Expression with pre-order and post-order visitation.
+/// `pre` will be invoked on each Expression. `pre` will visit Calls before their
+/// arguments, `post_call` will visit Calls (and no other Expressions) after their
+/// arguments. Visitors should return the Identical expression to indicate no change; this
+/// will prevent unnecessary construction in the common case where a modification is not
+/// possible/necessary/...
+///
+/// If an argument was modified, `post_call` visits a reconstructed Call with the modified
+/// arguments but also receives a pointer to the unmodified Expression as a second
+/// argument. If no arguments were modified the unmodified Expression* will be nullptr.
+template <typename PreVisit, typename PostVisitCall>
+Result<Expression> Modify(Expression expr, const PreVisit& pre,
+ const PostVisitCall& post_call) {
+ ARROW_ASSIGN_OR_RAISE(expr, Result<Expression>(pre(std::move(expr))));
+
+ auto call = expr.call();
+ if (!call) return expr;
+
+ bool at_least_one_modified = false;
+ std::vector<Expression> modified_arguments;
+
+ for (size_t i = 0; i < call->arguments.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto modified_argument,
+ Modify(call->arguments[i], pre, post_call));
+
+ if (Identical(modified_argument, call->arguments[i])) {
+ continue;
+ }
+
+ if (!at_least_one_modified) {
+ modified_arguments = call->arguments;
+ at_least_one_modified = true;
+ }
+
+ modified_arguments[i] = std::move(modified_argument);
+ }
+
+ if (at_least_one_modified) {
+ // reconstruct the call expression with the modified arguments
+ auto modified_call = *call;
+ modified_call.arguments = std::move(modified_arguments);
+ return post_call(Expression(std::move(modified_call)), &expr);
+ }
+
+ return post_call(std::move(expr), nullptr);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/expression_test.cc b/src/arrow/cpp/src/arrow/compute/exec/expression_test.cc
new file mode 100644
index 000000000..88b94e804
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/expression_test.cc
@@ -0,0 +1,1414 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/expression.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/compute/exec/expression_internal.h"
+#include "arrow/compute/function_internal.h"
+#include "arrow/compute/registry.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/make_unique.h"
+
+using testing::HasSubstr;
+using testing::UnorderedElementsAreArray;
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+const std::shared_ptr<Schema> kBoringSchema = schema({
+ field("bool", boolean()),
+ field("i8", int8()),
+ field("i32", int32()),
+ field("i32_req", int32(), /*nullable=*/false),
+ field("u32", uint32()),
+ field("i64", int64()),
+ field("f32", float32()),
+ field("f32_req", float32(), /*nullable=*/false),
+ field("f64", float64()),
+ field("date64", date64()),
+ field("str", utf8()),
+ field("dict_str", dictionary(int32(), utf8())),
+ field("dict_i32", dictionary(int32(), int32())),
+ field("ts_ns", timestamp(TimeUnit::NANO)),
+});
+
+#define EXPECT_OK ARROW_EXPECT_OK
+
+Expression cast(Expression argument, std::shared_ptr<DataType> to_type) {
+ return call("cast", {std::move(argument)},
+ compute::CastOptions::Safe(std::move(to_type)));
+}
+
+template <typename Actual, typename Expected>
+void ExpectResultsEqual(Actual&& actual, Expected&& expected) {
+ using MaybeActual = typename EnsureResult<typename std::decay<Actual>::type>::type;
+ using MaybeExpected = typename EnsureResult<typename std::decay<Expected>::type>::type;
+
+ MaybeActual maybe_actual(std::forward<Actual>(actual));
+ MaybeExpected maybe_expected(std::forward<Expected>(expected));
+
+ if (maybe_expected.ok()) {
+ EXPECT_EQ(maybe_actual, maybe_expected);
+ } else {
+ EXPECT_RAISES_WITH_CODE_AND_MESSAGE_THAT(
+ expected.status().code(), HasSubstr(expected.status().message()), maybe_actual);
+ }
+}
+
+const auto no_change = util::nullopt;
+
+TEST(ExpressionUtils, Comparison) {
+ auto Expect = [](Result<std::string> expected, Datum l, Datum r) {
+ ExpectResultsEqual(Comparison::Execute(l, r).Map(Comparison::GetName), expected);
+ };
+
+ Datum zero(0), one(1), two(2), null(std::make_shared<Int32Scalar>());
+ Datum str("hello"), bin(std::make_shared<BinaryScalar>(Buffer::FromString("hello")));
+ Datum dict_str(DictionaryScalar::Make(std::make_shared<Int32Scalar>(0),
+ ArrayFromJSON(utf8(), R"(["a", "b", "c"])")));
+
+ Status not_impl = Status::NotImplemented("no kernel matching input types");
+
+ Expect("equal", one, one);
+ Expect("less", one, two);
+ Expect("greater", one, zero);
+
+ Expect("na", one, null);
+ Expect("na", null, one);
+
+ // strings and ints are not comparable without explicit casts
+ Expect(not_impl, str, one);
+ Expect(not_impl, one, str);
+ Expect(not_impl, str, null); // not even null ints
+
+ // string -> binary implicit cast allowed
+ Expect("equal", str, bin);
+ Expect("equal", bin, str);
+
+ // dict_str -> string, implicit casts allowed
+ Expect("less", dict_str, str);
+ Expect("less", dict_str, bin);
+}
+
+TEST(ExpressionUtils, StripOrderPreservingCasts) {
+ auto Expect = [](Expression expr, util::optional<Expression> expected_stripped) {
+ ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema));
+ if (!expected_stripped) {
+ expected_stripped = expr;
+ } else {
+ ASSERT_OK_AND_ASSIGN(expected_stripped, expected_stripped->Bind(*kBoringSchema));
+ }
+ EXPECT_EQ(Comparison::StripOrderPreservingCasts(expr), *expected_stripped);
+ };
+
+ // Casting int to float preserves ordering.
+ // For example, let
+ // a = 3, b = 2, assert(a > b)
+ // After injecting a cast to float, this ordering still holds
+ // float(a) == 3.0, float(b) == 2.0, assert(float(a) > float(b))
+ Expect(cast(field_ref("i32"), float32()), field_ref("i32"));
+
+ // Casting an integral type to a wider integral type preserves ordering.
+ Expect(cast(field_ref("i32"), int64()), field_ref("i32"));
+ Expect(cast(field_ref("i32"), int32()), field_ref("i32"));
+ Expect(cast(field_ref("i32"), int16()), no_change);
+ Expect(cast(field_ref("i32"), int8()), no_change);
+
+ Expect(cast(field_ref("u32"), uint64()), field_ref("u32"));
+ Expect(cast(field_ref("u32"), uint32()), field_ref("u32"));
+ Expect(cast(field_ref("u32"), uint16()), no_change);
+ Expect(cast(field_ref("u32"), uint8()), no_change);
+
+ Expect(cast(field_ref("u32"), int64()), field_ref("u32"));
+ Expect(cast(field_ref("u32"), int32()), field_ref("u32"));
+ Expect(cast(field_ref("u32"), int16()), no_change);
+ Expect(cast(field_ref("u32"), int8()), no_change);
+
+ // Casting float to int can affect ordering.
+ // For example, let
+ // a = 3.5, b = 3.0, assert(a > b)
+ // After injecting a cast to integer, this ordering may no longer hold
+ // int(a) == 3, int(b) == 3, assert(!(int(a) > int(b)))
+ Expect(cast(field_ref("f32"), int32()), no_change);
+
+ // casting any float type to another preserves ordering
+ Expect(cast(field_ref("f32"), float64()), field_ref("f32"));
+ Expect(cast(field_ref("f64"), float32()), field_ref("f64"));
+
+ // casting signed integer to unsigned can alter ordering
+ Expect(cast(field_ref("i32"), uint32()), no_change);
+ Expect(cast(field_ref("i32"), uint64()), no_change);
+}
+
+TEST(ExpressionUtils, MakeExecBatch) {
+ auto Expect = [](std::shared_ptr<RecordBatch> partial_batch) {
+ SCOPED_TRACE(partial_batch->ToString());
+ ASSERT_OK_AND_ASSIGN(auto batch, MakeExecBatch(*kBoringSchema, partial_batch));
+
+ ASSERT_EQ(batch.num_values(), kBoringSchema->num_fields());
+ for (int i = 0; i < kBoringSchema->num_fields(); ++i) {
+ const auto& field = *kBoringSchema->field(i);
+
+ SCOPED_TRACE("Field#" + std::to_string(i) + " " + field.ToString());
+
+ EXPECT_TRUE(batch[i].type()->Equals(field.type()))
+ << "Incorrect type " << batch[i].type()->ToString();
+
+ ASSERT_OK_AND_ASSIGN(auto col, FieldRef(field.name()).GetOneOrNone(*partial_batch));
+
+ if (batch[i].is_scalar()) {
+ EXPECT_FALSE(batch[i].scalar()->is_valid)
+ << "Non-null placeholder scalar was injected";
+
+ EXPECT_EQ(col, nullptr)
+ << "Placeholder scalar overwrote column " << col->ToString();
+ } else {
+ AssertDatumsEqual(col, batch[i]);
+ }
+ }
+ };
+
+ auto GetField = [](std::string name) { return kBoringSchema->GetFieldByName(name); };
+
+ constexpr int64_t kNumRows = 3;
+ auto i32 = ArrayFromJSON(int32(), "[1, 2, 3]");
+ auto f32 = ArrayFromJSON(float32(), "[1.5, 2.25, 3.125]");
+
+ // empty
+ Expect(RecordBatchFromJSON(kBoringSchema, "[]"));
+
+ // subset
+ Expect(RecordBatch::Make(schema({GetField("i32"), GetField("f32")}), kNumRows,
+ {i32, f32}));
+
+ // flipped subset
+ Expect(RecordBatch::Make(schema({GetField("f32"), GetField("i32")}), kNumRows,
+ {f32, i32}));
+
+ auto duplicated_names =
+ RecordBatch::Make(schema({GetField("i32"), GetField("i32")}), kNumRows, {i32, i32});
+ ASSERT_RAISES(Invalid, MakeExecBatch(*kBoringSchema, duplicated_names));
+}
+
+class WidgetifyOptions : public compute::FunctionOptions {
+ public:
+ explicit WidgetifyOptions(bool really = true);
+ bool really;
+};
+class WidgetifyOptionsType : public FunctionOptionsType {
+ public:
+ static const FunctionOptionsType* GetInstance() {
+ static std::unique_ptr<FunctionOptionsType> instance(new WidgetifyOptionsType());
+ return instance.get();
+ }
+ const char* type_name() const override { return "widgetify"; }
+ std::string Stringify(const FunctionOptions& options) const override {
+ return type_name();
+ }
+ bool Compare(const FunctionOptions& options,
+ const FunctionOptions& other) const override {
+ return true;
+ }
+ std::unique_ptr<FunctionOptions> Copy(const FunctionOptions& options) const override {
+ const auto& opts = static_cast<const WidgetifyOptions&>(options);
+ return arrow::internal::make_unique<WidgetifyOptions>(opts.really);
+ }
+};
+WidgetifyOptions::WidgetifyOptions(bool really)
+ : FunctionOptions(WidgetifyOptionsType::GetInstance()), really(really) {}
+
+TEST(Expression, ToString) {
+ EXPECT_EQ(field_ref("alpha").ToString(), "alpha");
+
+ EXPECT_EQ(literal(3).ToString(), "3");
+ EXPECT_EQ(literal("a").ToString(), "\"a\"");
+ EXPECT_EQ(literal("a\nb").ToString(), "\"a\\nb\"");
+ EXPECT_EQ(literal(std::make_shared<BooleanScalar>()).ToString(), "null");
+ EXPECT_EQ(literal(std::make_shared<Int64Scalar>()).ToString(), "null");
+ EXPECT_EQ(literal(std::make_shared<BinaryScalar>(Buffer::FromString("az"))).ToString(),
+ "\"617A\"");
+
+ auto ts = *MakeScalar("1990-10-23 10:23:33")->CastTo(timestamp(TimeUnit::NANO));
+ EXPECT_EQ(literal(ts).ToString(), "1990-10-23 10:23:33.000000000");
+
+ EXPECT_EQ(call("add", {literal(3), field_ref("beta")}).ToString(), "add(3, beta)");
+
+ auto in_12 = call("index_in", {field_ref("beta")},
+ compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2]")});
+
+ EXPECT_EQ(in_12.ToString(),
+ "index_in(beta, {value_set=int32:[\n 1,\n 2\n], skip_nulls=false})");
+
+ EXPECT_EQ(and_(field_ref("a"), field_ref("b")).ToString(), "(a and b)");
+ EXPECT_EQ(or_(field_ref("a"), field_ref("b")).ToString(), "(a or b)");
+ EXPECT_EQ(not_(field_ref("a")).ToString(), "invert(a)");
+
+ EXPECT_EQ(
+ cast(field_ref("a"), int32()).ToString(),
+ "cast(a, {to_type=int32, allow_int_overflow=false, allow_time_truncate=false, "
+ "allow_time_overflow=false, allow_decimal_truncate=false, "
+ "allow_float_truncate=false, allow_invalid_utf8=false})");
+ EXPECT_EQ(
+ cast(field_ref("a"), nullptr).ToString(),
+ "cast(a, {to_type=<NULLPTR>, allow_int_overflow=false, allow_time_truncate=false, "
+ "allow_time_overflow=false, allow_decimal_truncate=false, "
+ "allow_float_truncate=false, allow_invalid_utf8=false})");
+
+ // NB: corrupted for nullary functions but we don't have any of those
+ EXPECT_EQ(call("widgetify", {}).ToString(), "widgetif)");
+ EXPECT_EQ(
+ call("widgetify", {literal(1)}, std::make_shared<WidgetifyOptions>()).ToString(),
+ "widgetify(1, widgetify)");
+
+ EXPECT_EQ(equal(field_ref("a"), literal(1)).ToString(), "(a == 1)");
+ EXPECT_EQ(less(field_ref("a"), literal(2)).ToString(), "(a < 2)");
+ EXPECT_EQ(greater(field_ref("a"), literal(3)).ToString(), "(a > 3)");
+ EXPECT_EQ(not_equal(field_ref("a"), literal("a")).ToString(), "(a != \"a\")");
+ EXPECT_EQ(less_equal(field_ref("a"), literal("b")).ToString(), "(a <= \"b\")");
+ EXPECT_EQ(greater_equal(field_ref("a"), literal("c")).ToString(), "(a >= \"c\")");
+
+ EXPECT_EQ(project(
+ {
+ field_ref("a"),
+ field_ref("a"),
+ literal(3),
+ in_12,
+ },
+ {
+ "a",
+ "renamed_a",
+ "three",
+ "b",
+ })
+ .ToString(),
+ "{a=a, renamed_a=a, three=3, b=" + in_12.ToString() + "}");
+}
+
+TEST(Expression, Equality) {
+ EXPECT_EQ(literal(1), literal(1));
+ EXPECT_NE(literal(1), literal(2));
+
+ EXPECT_EQ(field_ref("a"), field_ref("a"));
+ EXPECT_NE(field_ref("a"), field_ref("b"));
+ EXPECT_NE(field_ref("a"), literal(2));
+
+ EXPECT_EQ(call("add", {literal(3), field_ref("a")}),
+ call("add", {literal(3), field_ref("a")}));
+ EXPECT_NE(call("add", {literal(3), field_ref("a")}),
+ call("add", {literal(2), field_ref("a")}));
+ EXPECT_NE(call("add", {field_ref("a"), literal(3)}),
+ call("add", {literal(3), field_ref("a")}));
+
+ auto in_123 = compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]")};
+ EXPECT_EQ(call("add", {literal(3), call("index_in", {field_ref("beta")}, in_123)}),
+ call("add", {literal(3), call("index_in", {field_ref("beta")}, in_123)}));
+
+ auto in_12 = compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2]")};
+ EXPECT_NE(call("add", {literal(3), call("index_in", {field_ref("beta")}, in_12)}),
+ call("add", {literal(3), call("index_in", {field_ref("beta")}, in_123)}));
+
+ EXPECT_EQ(cast(field_ref("a"), int32()), cast(field_ref("a"), int32()));
+ EXPECT_NE(cast(field_ref("a"), int32()), cast(field_ref("a"), int64()));
+ EXPECT_NE(cast(field_ref("a"), int32()),
+ call("cast", {field_ref("a")}, compute::CastOptions::Unsafe(int32())));
+}
+
+Expression null_literal(const std::shared_ptr<DataType>& type) {
+ return Expression(MakeNullScalar(type));
+}
+
+TEST(Expression, Hash) {
+ std::unordered_set<Expression, Expression::Hash> set;
+
+ EXPECT_TRUE(set.emplace(field_ref("alpha")).second);
+ EXPECT_TRUE(set.emplace(field_ref("beta")).second);
+ EXPECT_FALSE(set.emplace(field_ref("beta")).second) << "already inserted";
+ EXPECT_TRUE(set.emplace(literal(1)).second);
+ EXPECT_FALSE(set.emplace(literal(1)).second) << "already inserted";
+ EXPECT_TRUE(set.emplace(literal(3)).second);
+
+ EXPECT_TRUE(set.emplace(null_literal(int32())).second);
+ EXPECT_FALSE(set.emplace(null_literal(int32())).second) << "already inserted";
+ EXPECT_TRUE(set.emplace(null_literal(float32())).second);
+ // NB: no validation on construction; we couldn't execute
+ // add with zero arguments
+ EXPECT_TRUE(set.emplace(call("add", {})).second);
+ EXPECT_FALSE(set.emplace(call("add", {})).second) << "already inserted";
+
+ // NB: unbound expressions don't check for availability in any registry
+ EXPECT_TRUE(set.emplace(call("widgetify", {})).second);
+
+ EXPECT_EQ(set.size(), 8);
+}
+
+TEST(Expression, IsScalarExpression) {
+ EXPECT_TRUE(literal(true).IsScalarExpression());
+
+ auto arr = ArrayFromJSON(int8(), "[]");
+ EXPECT_FALSE(literal(arr).IsScalarExpression());
+
+ EXPECT_TRUE(field_ref("a").IsScalarExpression());
+
+ EXPECT_TRUE(equal(field_ref("a"), literal(1)).IsScalarExpression());
+
+ EXPECT_FALSE(equal(field_ref("a"), literal(arr)).IsScalarExpression());
+
+ EXPECT_TRUE(call("is_in", {field_ref("a")}, compute::SetLookupOptions{arr, true})
+ .IsScalarExpression());
+
+ // non scalar function
+ EXPECT_FALSE(call("take", {field_ref("a"), literal(arr)}).IsScalarExpression());
+}
+
+TEST(Expression, IsSatisfiable) {
+ EXPECT_TRUE(literal(true).IsSatisfiable());
+ EXPECT_FALSE(literal(false).IsSatisfiable());
+
+ auto null = std::make_shared<BooleanScalar>();
+ EXPECT_FALSE(literal(null).IsSatisfiable());
+
+ EXPECT_TRUE(field_ref("a").IsSatisfiable());
+
+ EXPECT_TRUE(equal(field_ref("a"), literal(1)).IsSatisfiable());
+
+ // NB: no constant folding here
+ EXPECT_TRUE(equal(literal(0), literal(1)).IsSatisfiable());
+
+ // When a top level conjunction contains an Expression which is certain to evaluate to
+ // null, it can only evaluate to null or false.
+ auto never_true = and_(literal(null), field_ref("a"));
+ // This may appear in satisfiable filters if coalesced (for example, wrapped in fill_na)
+ EXPECT_TRUE(call("is_null", {never_true}).IsSatisfiable());
+ // ... but at the top level it is not satisfiable.
+ // This special case arises when (for example) an absent column has made
+ // one member of the conjunction always-null. This is fairly common and
+ // would be a worthwhile optimization to support.
+ // EXPECT_FALSE(null_or_false).IsSatisfiable());
+}
+
+TEST(Expression, FieldsInExpression) {
+ auto ExpectFieldsAre = [](Expression expr, std::vector<FieldRef> expected) {
+ EXPECT_THAT(FieldsInExpression(expr), testing::ContainerEq(expected));
+ };
+
+ ExpectFieldsAre(literal(true), {});
+
+ ExpectFieldsAre(field_ref("a"), {"a"});
+
+ ExpectFieldsAre(equal(field_ref("a"), literal(1)), {"a"});
+
+ ExpectFieldsAre(equal(field_ref("a"), field_ref("b")), {"a", "b"});
+
+ ExpectFieldsAre(
+ or_(equal(field_ref("a"), literal(1)), equal(field_ref("a"), literal(2))),
+ {"a", "a"});
+
+ ExpectFieldsAre(
+ or_(equal(field_ref("a"), literal(1)), equal(field_ref("b"), literal(2))),
+ {"a", "b"});
+
+ ExpectFieldsAre(or_(and_(not_(equal(field_ref("a"), literal(1))),
+ equal(field_ref("b"), literal(2))),
+ not_(less(field_ref("c"), literal(3)))),
+ {"a", "b", "c"});
+}
+
+TEST(Expression, ExpressionHasFieldRefs) {
+ EXPECT_FALSE(ExpressionHasFieldRefs(literal(true)));
+
+ EXPECT_FALSE(ExpressionHasFieldRefs(call("add", {literal(1), literal(3)})));
+
+ EXPECT_TRUE(ExpressionHasFieldRefs(field_ref("a")));
+
+ EXPECT_TRUE(ExpressionHasFieldRefs(equal(field_ref("a"), literal(1))));
+
+ EXPECT_TRUE(ExpressionHasFieldRefs(equal(field_ref("a"), field_ref("b"))));
+
+ EXPECT_TRUE(ExpressionHasFieldRefs(
+ or_(equal(field_ref("a"), literal(1)), equal(field_ref("a"), literal(2)))));
+
+ EXPECT_TRUE(ExpressionHasFieldRefs(
+ or_(equal(field_ref("a"), literal(1)), equal(field_ref("b"), literal(2)))));
+
+ EXPECT_TRUE(ExpressionHasFieldRefs(or_(
+ and_(not_(equal(field_ref("a"), literal(1))), equal(field_ref("b"), literal(2))),
+ not_(less(field_ref("c"), literal(3))))));
+}
+
+TEST(Expression, BindLiteral) {
+ for (Datum dat : {
+ Datum(3),
+ Datum(3.5),
+ Datum(ArrayFromJSON(int32(), "[1,2,3]")),
+ }) {
+ // literals are always considered bound
+ auto expr = literal(dat);
+ EXPECT_EQ(expr.descr(), dat.descr());
+ EXPECT_TRUE(expr.IsBound());
+ }
+}
+
+void ExpectBindsTo(Expression expr, util::optional<Expression> expected,
+ Expression* bound_out = nullptr) {
+ if (!expected) {
+ expected = expr;
+ }
+
+ ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema));
+ EXPECT_TRUE(bound.IsBound());
+
+ ASSERT_OK_AND_ASSIGN(expected, expected->Bind(*kBoringSchema));
+ EXPECT_EQ(bound, *expected) << " unbound: " << expr.ToString();
+
+ if (bound_out) {
+ *bound_out = bound;
+ }
+}
+
+TEST(Expression, BindFieldRef) {
+ // an unbound field_ref does not have the output ValueDescr set
+ auto expr = field_ref("alpha");
+ EXPECT_EQ(expr.descr(), ValueDescr{});
+ EXPECT_FALSE(expr.IsBound());
+
+ ExpectBindsTo(field_ref("i32"), no_change, &expr);
+ EXPECT_EQ(expr.descr(), ValueDescr::Array(int32()));
+
+ // if the field is not found, an error will be raised
+ ASSERT_RAISES(Invalid, field_ref("no such field").Bind(*kBoringSchema));
+
+ // referencing a field by name is not supported if that name is not unique
+ // in the input schema
+ ASSERT_RAISES(Invalid, field_ref("alpha").Bind(Schema(
+ {field("alpha", int32()), field("alpha", float32())})));
+
+ // referencing nested fields is not supported
+ ASSERT_RAISES(NotImplemented,
+ field_ref(FieldRef("a", "b"))
+ .Bind(Schema({field("a", struct_({field("b", int32())}))})));
+}
+
+TEST(Expression, BindCall) {
+ auto expr = call("add", {field_ref("i32"), field_ref("i32_req")});
+ EXPECT_FALSE(expr.IsBound());
+
+ ExpectBindsTo(expr, no_change, &expr);
+ EXPECT_EQ(expr.descr(), ValueDescr::Array(int32()));
+
+ ExpectBindsTo(call("add", {field_ref("f32"), literal(3)}),
+ call("add", {field_ref("f32"), literal(3.0F)}));
+
+ ExpectBindsTo(call("add", {field_ref("i32"), literal(3.5F)}),
+ call("add", {cast(field_ref("i32"), float32()), literal(3.5F)}));
+}
+
+TEST(Expression, BindWithImplicitCasts) {
+ for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) {
+ // cast arguments to common numeric type
+ ExpectBindsTo(cmp(field_ref("i64"), field_ref("i32")),
+ cmp(field_ref("i64"), cast(field_ref("i32"), int64())));
+
+ ExpectBindsTo(cmp(field_ref("i64"), field_ref("f32")),
+ cmp(cast(field_ref("i64"), float32()), field_ref("f32")));
+
+ ExpectBindsTo(cmp(field_ref("i32"), field_ref("i64")),
+ cmp(cast(field_ref("i32"), int64()), field_ref("i64")));
+
+ ExpectBindsTo(cmp(field_ref("i8"), field_ref("u32")),
+ cmp(cast(field_ref("i8"), int64()), cast(field_ref("u32"), int64())));
+
+ // cast dictionary to value type
+ ExpectBindsTo(cmp(field_ref("dict_str"), field_ref("str")),
+ cmp(cast(field_ref("dict_str"), utf8()), field_ref("str")));
+
+ ExpectBindsTo(cmp(field_ref("dict_i32"), literal(int64_t(4))),
+ cmp(cast(field_ref("dict_i32"), int64()), literal(int64_t(4))));
+ }
+
+ compute::SetLookupOptions in_a{ArrayFromJSON(utf8(), R"(["a"])")};
+
+ // cast dictionary to value type
+ ExpectBindsTo(call("is_in", {field_ref("dict_str")}, in_a),
+ call("is_in", {cast(field_ref("dict_str"), utf8())}, in_a));
+}
+
+TEST(Expression, BindNestedCall) {
+ auto expr =
+ call("add", {field_ref("a"),
+ call("subtract", {call("multiply", {field_ref("b"), field_ref("c")}),
+ field_ref("d")})});
+ EXPECT_FALSE(expr.IsBound());
+
+ ASSERT_OK_AND_ASSIGN(expr,
+ expr.Bind(Schema({field("a", int32()), field("b", int32()),
+ field("c", int32()), field("d", int32())})));
+ EXPECT_EQ(expr.descr(), ValueDescr::Array(int32()));
+ EXPECT_TRUE(expr.IsBound());
+}
+
+TEST(Expression, ExecuteFieldRef) {
+ auto ExpectRefIs = [](FieldRef ref, Datum in, Datum expected) {
+ auto expr = field_ref(ref);
+
+ ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr()));
+ ASSERT_OK_AND_ASSIGN(Datum actual,
+ ExecuteScalarExpression(expr, Schema(in.type()->fields()), in));
+
+ AssertDatumsEqual(actual, expected, /*verbose=*/true);
+ };
+
+ ExpectRefIs("a", ArrayFromJSON(struct_({field("a", float64())}), R"([
+ {"a": 6.125},
+ {"a": 0.0},
+ {"a": -1}
+ ])"),
+ ArrayFromJSON(float64(), R"([6.125, 0.0, -1])"));
+
+ ExpectRefIs("a",
+ ArrayFromJSON(struct_({
+ field("a", float64()),
+ field("b", float64()),
+ }),
+ R"([
+ {"a": 6.125, "b": 7.5},
+ {"a": 0.0, "b": 2.125},
+ {"a": -1, "b": 4.0}
+ ])"),
+ ArrayFromJSON(float64(), R"([6.125, 0.0, -1])"));
+
+ ExpectRefIs("b",
+ ArrayFromJSON(struct_({
+ field("a", float64()),
+ field("b", float64()),
+ }),
+ R"([
+ {"a": 6.125, "b": 7.5},
+ {"a": 0.0, "b": 2.125},
+ {"a": -1, "b": 4.0}
+ ])"),
+ ArrayFromJSON(float64(), R"([7.5, 2.125, 4.0])"));
+}
+
+Result<Datum> NaiveExecuteScalarExpression(const Expression& expr, const Datum& input) {
+ if (auto lit = expr.literal()) {
+ return *lit;
+ }
+
+ if (auto ref = expr.field_ref()) {
+ if (input.type()) {
+ return ref->GetOneOrNone(*input.make_array());
+ }
+ return ref->GetOneOrNone(*input.record_batch());
+ }
+
+ auto call = CallNotNull(expr);
+
+ std::vector<Datum> arguments(call->arguments.size());
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(arguments[i],
+ NaiveExecuteScalarExpression(call->arguments[i], input));
+ }
+
+ compute::ExecContext exec_context;
+ ARROW_ASSIGN_OR_RAISE(auto function, GetFunction(*call, &exec_context));
+
+ auto descrs = GetDescriptors(call->arguments);
+ ARROW_ASSIGN_OR_RAISE(auto expected_kernel, function->DispatchExact(descrs));
+
+ EXPECT_EQ(call->kernel, expected_kernel);
+ return function->Execute(arguments, call->options.get(), &exec_context);
+}
+
+void ExpectExecute(Expression expr, Datum in, Datum* actual_out = NULLPTR) {
+ std::shared_ptr<Schema> schm;
+ if (in.is_value()) {
+ ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr()));
+ schm = schema(in.type()->fields());
+ } else {
+ ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*in.schema()));
+ schm = in.schema();
+ }
+
+ ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, *schm, in));
+
+ ASSERT_OK_AND_ASSIGN(Datum expected, NaiveExecuteScalarExpression(expr, in));
+
+ AssertDatumsEqual(actual, expected, /*verbose=*/true);
+
+ if (actual_out) {
+ *actual_out = actual;
+ }
+}
+
+TEST(Expression, ExecuteCall) {
+ ExpectExecute(call("add", {field_ref("a"), literal(3.5)}),
+ ArrayFromJSON(struct_({field("a", float64())}), R"([
+ {"a": 6.125},
+ {"a": 0.0},
+ {"a": -1}
+ ])"));
+
+ ExpectExecute(
+ call("add", {field_ref("a"), call("subtract", {literal(3.5), field_ref("b")})}),
+ ArrayFromJSON(struct_({field("a", float64()), field("b", float64())}), R"([
+ {"a": 6.125, "b": 3.375},
+ {"a": 0.0, "b": 1},
+ {"a": -1, "b": 4.75}
+ ])"));
+
+ ExpectExecute(call("strptime", {field_ref("a")},
+ compute::StrptimeOptions("%m/%d/%Y", TimeUnit::MICRO)),
+ ArrayFromJSON(struct_({field("a", utf8())}), R"([
+ {"a": "5/1/2020"},
+ {"a": null},
+ {"a": "12/11/1900"}
+ ])"));
+
+ ExpectExecute(project({call("add", {field_ref("a"), literal(3.5)})}, {"a + 3.5"}),
+ ArrayFromJSON(struct_({field("a", float64())}), R"([
+ {"a": 6.125},
+ {"a": 0.0},
+ {"a": -1}
+ ])"));
+}
+
+TEST(Expression, ExecuteDictionaryTransparent) {
+ ExpectExecute(
+ equal(field_ref("a"), field_ref("b")),
+ ArrayFromJSON(
+ struct_({field("a", dictionary(int32(), utf8())), field("b", utf8())}), R"([
+ {"a": "hi", "b": "hi"},
+ {"a": "", "b": ""},
+ {"a": "hi", "b": "hello"}
+ ])"));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto expr, project({field_ref("i32"), field_ref("dict_str")}, {"i32", "dict_str"})
+ .Bind(*kBoringSchema));
+
+ ASSERT_OK_AND_ASSIGN(
+ expr, SimplifyWithGuarantee(expr, equal(field_ref("dict_str"), literal("eh"))));
+
+ ASSERT_OK_AND_ASSIGN(auto res, ExecuteScalarExpression(
+ expr, *kBoringSchema,
+ ArrayFromJSON(struct_({field("i32", int32())}), R"([
+ {"i32": 0},
+ {"i32": 1},
+ {"i32": 2}
+ ])")));
+
+ AssertDatumsEqual(
+ res, ArrayFromJSON(struct_({field("i32", int32()),
+ field("dict_str", dictionary(int32(), utf8()))}),
+ R"([
+ {"i32": 0, "dict_str": "eh"},
+ {"i32": 1, "dict_str": "eh"},
+ {"i32": 2, "dict_str": "eh"}
+ ])"));
+}
+
+void ExpectIdenticalIfUnchanged(Expression modified, Expression original) {
+ if (modified == original) {
+ // no change -> must be identical
+ EXPECT_TRUE(Identical(modified, original)) << " " << original.ToString();
+ }
+}
+
+struct {
+ void operator()(Expression expr, Expression expected) {
+ ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema));
+ ASSERT_OK_AND_ASSIGN(expected, expected.Bind(*kBoringSchema));
+
+ ASSERT_OK_AND_ASSIGN(auto folded, FoldConstants(expr));
+
+ EXPECT_EQ(folded, expected);
+ ExpectIdenticalIfUnchanged(folded, expr);
+ }
+} ExpectFoldsTo;
+
+TEST(Expression, FoldConstants) {
+ // literals are unchanged
+ ExpectFoldsTo(literal(3), literal(3));
+
+ // field_refs are unchanged
+ ExpectFoldsTo(field_ref("i32"), field_ref("i32"));
+
+ // call against literals (3 + 2 == 5)
+ ExpectFoldsTo(call("add", {literal(3), literal(2)}), literal(5));
+
+ ExpectFoldsTo(call("equal", {literal(3), literal(3)}), literal(true));
+
+ // call against literal and field_ref
+ ExpectFoldsTo(call("add", {literal(3), field_ref("i32")}),
+ call("add", {literal(3), field_ref("i32")}));
+
+ // nested call against literals ((8 - (2 * 3)) + 2 == 4)
+ ExpectFoldsTo(call("add",
+ {
+ call("subtract",
+ {
+ literal(8),
+ call("multiply", {literal(2), literal(3)}),
+ }),
+ literal(2),
+ }),
+ literal(4));
+
+ // nested call against literals with one field_ref
+ // (i32 - (2 * 3)) + 2 == (i32 - 6) + 2
+ // NB this could be improved further by using associativity of addition; another pass
+ ExpectFoldsTo(call("add",
+ {
+ call("subtract",
+ {
+ field_ref("i32"),
+ call("multiply", {literal(2), literal(3)}),
+ }),
+ literal(2),
+ }),
+ call("add", {
+ call("subtract",
+ {
+ field_ref("i32"),
+ literal(6),
+ }),
+ literal(2),
+ }));
+
+ compute::SetLookupOptions in_123(ArrayFromJSON(int32(), "[1,2,3]"));
+
+ ExpectFoldsTo(call("is_in", {literal(2)}, in_123), literal(true));
+
+ ExpectFoldsTo(
+ call("is_in",
+ {call("add", {field_ref("i32"), call("multiply", {literal(2), literal(3)})})},
+ in_123),
+ call("is_in", {call("add", {field_ref("i32"), literal(6)})}, in_123));
+}
+
+TEST(Expression, FoldConstantsBoolean) {
+ // test and_kleene/or_kleene-specific optimizations
+ auto one = literal(1);
+ auto two = literal(2);
+ auto whatever = equal(call("add", {one, field_ref("i32")}), two);
+
+ auto true_ = literal(true);
+ auto false_ = literal(false);
+
+ ExpectFoldsTo(and_(false_, whatever), false_);
+ ExpectFoldsTo(and_(true_, whatever), whatever);
+ ExpectFoldsTo(and_(whatever, whatever), whatever);
+
+ ExpectFoldsTo(or_(true_, whatever), true_);
+ ExpectFoldsTo(or_(false_, whatever), whatever);
+ ExpectFoldsTo(or_(whatever, whatever), whatever);
+}
+
+TEST(Expression, ExtractKnownFieldValues) {
+ struct {
+ void operator()(Expression guarantee,
+ std::unordered_map<FieldRef, Datum, FieldRef::Hash> expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee));
+ EXPECT_THAT(actual.map, UnorderedElementsAreArray(expected))
+ << " guarantee: " << guarantee.ToString();
+ }
+ } ExpectKnown;
+
+ ExpectKnown(equal(field_ref("i32"), literal(3)), {{"i32", Datum(3)}});
+
+ ExpectKnown(greater(field_ref("i32"), literal(3)), {});
+
+ // FIXME known null should be expressed with is_null rather than equality
+ auto null_int32 = std::make_shared<Int32Scalar>();
+ ExpectKnown(equal(field_ref("i32"), literal(null_int32)), {{"i32", Datum(null_int32)}});
+
+ ExpectKnown(
+ and_({equal(field_ref("i32"), literal(3)), equal(field_ref("f32"), literal(1.5F))}),
+ {{"i32", Datum(3)}, {"f32", Datum(1.5F)}});
+
+ // NB: guarantees are *not* automatically canonicalized
+ ExpectKnown(
+ and_({equal(field_ref("i32"), literal(3)), equal(literal(1.5F), field_ref("f32"))}),
+ {{"i32", Datum(3)}});
+
+ // NB: guarantees are *not* automatically simplified
+ // (the below could be constant folded to a usable guarantee)
+ ExpectKnown(or_({equal(field_ref("i32"), literal(3)), literal(false)}), {});
+
+ // NB: guarantees are unbound; applying them may require casts
+ ExpectKnown(equal(field_ref("i32"), literal("1234324")), {{"i32", Datum("1234324")}});
+
+ ExpectKnown(
+ and_({equal(field_ref("i32"), literal(3)), equal(field_ref("f32"), literal(2.F)),
+ equal(field_ref("i32_req"), literal(1))}),
+ {{"i32", Datum(3)}, {"f32", Datum(2.F)}, {"i32_req", Datum(1)}});
+
+ ExpectKnown(
+ and_(or_(equal(field_ref("i32"), literal(3)), equal(field_ref("i32"), literal(4))),
+ equal(field_ref("f32"), literal(2.F))),
+ {{"f32", Datum(2.F)}});
+
+ ExpectKnown(and_({equal(field_ref("i32"), literal(3)),
+ equal(field_ref("f32"), field_ref("f32_req")),
+ equal(field_ref("i32_req"), literal(1))}),
+ {{"i32", Datum(3)}, {"i32_req", Datum(1)}});
+}
+
+TEST(Expression, ReplaceFieldsWithKnownValues) {
+ auto ExpectReplacesTo =
+ [](Expression expr,
+ const std::unordered_map<FieldRef, Datum, FieldRef::Hash>& known_values,
+ Expression unbound_expected) {
+ ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema));
+ ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema));
+ ASSERT_OK_AND_ASSIGN(auto replaced, ReplaceFieldsWithKnownValues(
+ KnownFieldValues{known_values}, expr));
+
+ EXPECT_EQ(replaced, expected);
+ ExpectIdenticalIfUnchanged(replaced, expr);
+ };
+
+ std::unordered_map<FieldRef, Datum, FieldRef::Hash> i32_is_3{{"i32", Datum(3)}};
+
+ ExpectReplacesTo(literal(1), i32_is_3, literal(1));
+
+ ExpectReplacesTo(field_ref("i32"), i32_is_3, literal(3));
+
+ // NB: known_values will be cast
+ ExpectReplacesTo(field_ref("i32"), {{"i32", Datum("3")}}, literal(3));
+
+ ExpectReplacesTo(field_ref("f32"), i32_is_3, field_ref("f32"));
+
+ ExpectReplacesTo(equal(field_ref("i32"), literal(1)), i32_is_3,
+ equal(literal(3), literal(1)));
+
+ Datum dict_str{
+ DictionaryScalar::Make(MakeScalar(0), ArrayFromJSON(utf8(), R"(["3"])"))};
+ ExpectReplacesTo(field_ref("dict_str"), {{"dict_str", dict_str}}, literal(dict_str));
+
+ ExpectReplacesTo(call("add",
+ {
+ call("subtract",
+ {
+ field_ref("i32"),
+ call("multiply", {literal(2), literal(3)}),
+ }),
+ literal(2),
+ }),
+ i32_is_3,
+ call("add", {
+ call("subtract",
+ {
+ literal(3),
+ call("multiply", {literal(2), literal(3)}),
+ }),
+ literal(2),
+ }));
+
+ std::unordered_map<FieldRef, Datum, FieldRef::Hash> i32_valid_str_null{
+ {"i32", Datum(3)}, {"str", MakeNullScalar(utf8())}};
+
+ ExpectReplacesTo(is_null(field_ref("i32")), i32_valid_str_null, is_null(literal(3)));
+
+ ExpectReplacesTo(is_valid(field_ref("i32")), i32_valid_str_null, is_valid(literal(3)));
+
+ ExpectReplacesTo(is_null(field_ref("str")), i32_valid_str_null,
+ is_null(null_literal(utf8())));
+
+ ExpectReplacesTo(is_valid(field_ref("str")), i32_valid_str_null,
+ is_valid(null_literal(utf8())));
+
+ Datum dict_i32{
+ DictionaryScalar::Make(MakeScalar<int32_t>(0), ArrayFromJSON(int32(), R"([3])"))};
+ // cast dictionary(int32(), int32()) -> dictionary(int32(), utf8())
+ ExpectReplacesTo(field_ref("dict_str"), {{"dict_str", dict_i32}}, literal(dict_str));
+
+ // cast dictionary(int8(), utf8()) -> dictionary(int32(), utf8())
+ auto dict_int8_str = Datum{
+ DictionaryScalar::Make(MakeScalar<int8_t>(0), ArrayFromJSON(utf8(), R"(["3"])"))};
+ ExpectReplacesTo(field_ref("dict_str"), {{"dict_str", dict_int8_str}},
+ literal(dict_str));
+}
+
+struct {
+ void operator()(Expression expr, Expression unbound_expected) const {
+ ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema));
+ ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema));
+ ASSERT_OK_AND_ASSIGN(auto actual, Canonicalize(bound));
+
+ EXPECT_EQ(actual, expected);
+ ExpectIdenticalIfUnchanged(actual, bound);
+ }
+} ExpectCanonicalizesTo;
+
+TEST(Expression, CanonicalizeTrivial) {
+ ExpectCanonicalizesTo(literal(1), literal(1));
+
+ ExpectCanonicalizesTo(field_ref("i32"), field_ref("i32"));
+
+ ExpectCanonicalizesTo(equal(field_ref("i32"), field_ref("i32_req")),
+ equal(field_ref("i32"), field_ref("i32_req")));
+}
+
+TEST(Expression, CanonicalizeAnd) {
+ // some aliases for brevity:
+ auto true_ = literal(true);
+ auto null_ = literal(std::make_shared<BooleanScalar>());
+
+ auto b = field_ref("bool");
+ auto c = equal(literal(1), literal(2));
+
+ // no change possible:
+ ExpectCanonicalizesTo(and_(b, c), and_(b, c));
+
+ // literals are placed innermost
+ ExpectCanonicalizesTo(and_(b, true_), and_(true_, b));
+ ExpectCanonicalizesTo(and_(true_, b), and_(true_, b));
+
+ ExpectCanonicalizesTo(and_(b, and_(true_, c)), and_(and_(true_, b), c));
+ ExpectCanonicalizesTo(and_(b, and_(and_(true_, true_), c)),
+ and_(and_(and_(true_, true_), b), c));
+ ExpectCanonicalizesTo(and_(b, and_(and_(true_, null_), c)),
+ and_(and_(and_(null_, true_), b), c));
+ ExpectCanonicalizesTo(and_(b, and_(and_(true_, null_), and_(c, null_))),
+ and_(and_(and_(and_(null_, null_), true_), b), c));
+
+ // catches and_kleene even when it's a subexpression
+ ExpectCanonicalizesTo(call("is_valid", {and_(b, true_)}),
+ call("is_valid", {and_(true_, b)}));
+}
+
+TEST(Expression, CanonicalizeComparison) {
+ ExpectCanonicalizesTo(equal(literal(1), field_ref("i32")),
+ equal(field_ref("i32"), literal(1)));
+
+ ExpectCanonicalizesTo(equal(field_ref("i32"), literal(1)),
+ equal(field_ref("i32"), literal(1)));
+
+ ExpectCanonicalizesTo(less(literal(1), field_ref("i32")),
+ greater(field_ref("i32"), literal(1)));
+
+ ExpectCanonicalizesTo(less(field_ref("i32"), literal(1)),
+ less(field_ref("i32"), literal(1)));
+}
+
+struct Simplify {
+ Expression expr;
+
+ struct Expectable {
+ Expression expr, guarantee;
+
+ void Expect(Expression unbound_expected) {
+ ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema));
+
+ ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, guarantee));
+
+ ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema));
+ EXPECT_EQ(simplified, expected) << " original: " << expr.ToString() << "\n"
+ << " guarantee: " << guarantee.ToString() << "\n"
+ << (simplified == bound ? " (no change)\n" : "");
+
+ ExpectIdenticalIfUnchanged(simplified, bound);
+ }
+ void ExpectUnchanged() { Expect(expr); }
+ void Expect(bool constant) { Expect(literal(constant)); }
+ };
+
+ Expectable WithGuarantee(Expression guarantee) { return {expr, guarantee}; }
+};
+
+TEST(Expression, SingleComparisonGuarantees) {
+ auto i32 = field_ref("i32");
+
+ // i32 is guaranteed equal to 3, so the projection can just materialize that constant
+ // and need not incur IO
+ Simplify{project({call("add", {i32, literal(1)})}, {"i32 + 1"})}
+ .WithGuarantee(equal(i32, literal(3)))
+ .Expect(literal(
+ std::make_shared<StructScalar>(ScalarVector{std::make_shared<Int32Scalar>(4)},
+ struct_({field("i32 + 1", int32())}))));
+
+ // i32 is guaranteed equal to 5 everywhere, so filtering i32==5 is redundant and the
+ // filter can be simplified to true (== select everything)
+ Simplify{
+ equal(i32, literal(5)),
+ }
+ .WithGuarantee(equal(i32, literal(5)))
+ .Expect(true);
+
+ Simplify{
+ equal(i32, literal(5)),
+ }
+ .WithGuarantee(equal(i32, literal(5)))
+ .Expect(true);
+
+ Simplify{
+ less_equal(i32, literal(5)),
+ }
+ .WithGuarantee(equal(i32, literal(5)))
+ .Expect(true);
+
+ Simplify{
+ less(i32, literal(5)),
+ }
+ .WithGuarantee(equal(i32, literal(3)))
+ .Expect(true);
+
+ Simplify{
+ greater_equal(i32, literal(5)),
+ }
+ .WithGuarantee(greater(i32, literal(5)))
+ .Expect(true);
+
+ // i32 is guaranteed less than 3 everywhere, so filtering i32==5 is redundant and the
+ // filter can be simplified to false (== select nothing)
+ Simplify{
+ equal(i32, literal(5)),
+ }
+ .WithGuarantee(less(i32, literal(3)))
+ .Expect(false);
+
+ Simplify{
+ less(i32, literal(5)),
+ }
+ .WithGuarantee(equal(i32, literal(5)))
+ .Expect(false);
+
+ Simplify{
+ less_equal(i32, literal(3)),
+ }
+ .WithGuarantee(equal(i32, literal(5)))
+ .Expect(false);
+
+ Simplify{
+ equal(i32, literal(0.5)),
+ }
+ .WithGuarantee(greater_equal(i32, literal(1)))
+ .Expect(false);
+
+ // no simplification possible:
+ Simplify{
+ not_equal(i32, literal(3)),
+ }
+ .WithGuarantee(less(i32, literal(5)))
+ .ExpectUnchanged();
+
+ // exhaustive coverage of all single comparison simplifications
+ for (std::string filter_op :
+ {"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) {
+ for (auto filter_rhs : {literal(5), literal(3), literal(7)}) {
+ auto filter = call(filter_op, {i32, filter_rhs});
+ for (std::string guarantee_op :
+ {"equal", "less", "less_equal", "greater", "greater_equal"}) {
+ auto guarantee = call(guarantee_op, {i32, literal(5)});
+
+ // generate data which satisfies the guarantee
+ static std::unordered_map<std::string, std::string> satisfying_i32{
+ {"equal", "[5]"},
+ {"less", "[4, 3, 2, 1]"},
+ {"less_equal", "[5, 4, 3, 2, 1]"},
+ {"greater", "[6, 7, 8, 9]"},
+ {"greater_equal", "[5, 6, 7, 8, 9]"},
+ };
+
+ ASSERT_OK_AND_ASSIGN(
+ Datum input,
+ StructArray::Make({ArrayFromJSON(int32(), satisfying_i32[guarantee_op])},
+ {"i32"}));
+
+ ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema));
+ ASSERT_OK_AND_ASSIGN(Datum evaluated,
+ ExecuteScalarExpression(filter, *kBoringSchema, input));
+
+ // ensure that the simplified filter is as simplified as it could be
+ // (this is always possible for single comparisons)
+ bool all = true, none = true;
+ for (int64_t i = 0; i < input.length(); ++i) {
+ if (evaluated.array_as<BooleanArray>()->Value(i)) {
+ none = false;
+ } else {
+ all = false;
+ }
+ }
+ Simplify{filter}.WithGuarantee(guarantee).Expect(
+ all ? literal(true) : none ? literal(false) : filter);
+ }
+ }
+ }
+}
+
+TEST(Expression, SimplifyWithGuarantee) {
+ // drop both members of a conjunctive filter
+ Simplify{
+ and_(equal(field_ref("i32"), literal(2)), equal(field_ref("f32"), literal(3.5F)))}
+ .WithGuarantee(and_(greater_equal(field_ref("i32"), literal(0)),
+ less_equal(field_ref("i32"), literal(1))))
+ .Expect(false);
+
+ // drop one member of a conjunctive filter
+ Simplify{
+ and_(equal(field_ref("i32"), literal(0)), equal(field_ref("f32"), literal(3.5F)))}
+ .WithGuarantee(equal(field_ref("i32"), literal(0)))
+ .Expect(equal(field_ref("f32"), literal(3.5F)));
+
+ // drop both members of a disjunctive filter
+ Simplify{
+ or_(equal(field_ref("i32"), literal(0)), equal(field_ref("f32"), literal(3.5F)))}
+ .WithGuarantee(equal(field_ref("i32"), literal(0)))
+ .Expect(true);
+
+ // drop one member of a disjunctive filter
+ Simplify{or_(equal(field_ref("i32"), literal(0)), equal(field_ref("i32"), literal(3)))}
+ .WithGuarantee(and_(greater_equal(field_ref("i32"), literal(0)),
+ less_equal(field_ref("i32"), literal(1))))
+ .Expect(equal(field_ref("i32"), literal(0)));
+
+ Simplify{or_(equal(field_ref("f32"), literal(0)), equal(field_ref("i32"), literal(3)))}
+ .WithGuarantee(greater(field_ref("f32"), literal(0.0)))
+ .Expect(equal(field_ref("i32"), literal(3)));
+
+ // simplification can see through implicit casts
+ compute::SetLookupOptions in_123{ArrayFromJSON(int32(), "[1,2,3]"), true};
+ Simplify{or_({equal(field_ref("f32"), literal(0)),
+ call("is_in", {field_ref("i64")}, in_123)})}
+ .WithGuarantee(greater(field_ref("f32"), literal(0.F)))
+ .Expect(call("is_in", {field_ref("i64")}, in_123));
+
+ Simplify{greater(field_ref("dict_i32"), literal(int64_t(1)))}
+ .WithGuarantee(equal(field_ref("dict_i32"), literal(0)))
+ .Expect(false);
+
+ Simplify{equal(field_ref("i32"), literal(7))}
+ .WithGuarantee(equal(field_ref("i32"), literal(7)))
+ .Expect(literal(true));
+
+ Simplify{equal(field_ref("i32"), literal(7))}
+ .WithGuarantee(not_(equal(field_ref("i32"), literal(7))))
+ .Expect(equal(field_ref("i32"), literal(7)));
+
+ Simplify{is_null(field_ref("i32"))}
+ .WithGuarantee(is_null(field_ref("i32")))
+ .Expect(literal(true));
+
+ Simplify{is_valid(field_ref("i32"))}
+ .WithGuarantee(is_valid(field_ref("i32")))
+ .Expect(is_valid(field_ref("i32")));
+}
+
+TEST(Expression, SimplifyThenExecute) {
+ auto filter =
+ or_({equal(field_ref("f32"), literal(0)),
+ call("is_in", {field_ref("i64")},
+ compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true})});
+
+ ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema));
+ auto guarantee = greater(field_ref("f32"), literal(0.0));
+
+ ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(filter, guarantee));
+
+ auto input = RecordBatchFromJSON(kBoringSchema, R"([
+ {"i64": 0, "f32": 0.1},
+ {"i64": 0, "f32": 0.3},
+ {"i64": 1, "f32": 0.5},
+ {"i64": 2, "f32": 0.1},
+ {"i64": 0, "f32": 0.1},
+ {"i64": 0, "f32": 0.4},
+ {"i64": 0, "f32": 1.0}
+ ])");
+
+ Datum evaluated, simplified_evaluated;
+ ExpectExecute(filter, input, &evaluated);
+ ExpectExecute(simplified, input, &simplified_evaluated);
+ AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true);
+}
+
+TEST(Expression, Filter) {
+ auto ExpectFilter = [](Expression filter, std::string batch_json) {
+ ASSERT_OK_AND_ASSIGN(auto s, kBoringSchema->AddField(0, field("in", boolean())));
+ auto batch = RecordBatchFromJSON(s, batch_json);
+ auto expected_mask = batch->column(0);
+
+ ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema));
+ ASSERT_OK_AND_ASSIGN(Datum mask,
+ ExecuteScalarExpression(filter, *kBoringSchema, batch));
+
+ AssertDatumsEqual(expected_mask, mask);
+ };
+
+ ExpectFilter(equal(field_ref("i32"), literal(0)), R"([
+ {"i32": 0, "f32": -0.1, "in": 1},
+ {"i32": 0, "f32": 0.3, "in": 1},
+ {"i32": 1, "f32": 0.2, "in": 0},
+ {"i32": 2, "f32": -0.1, "in": 0},
+ {"i32": 0, "f32": 0.1, "in": 1},
+ {"i32": 0, "f32": null, "in": 1},
+ {"i32": 0, "f32": 1.0, "in": 1}
+ ])");
+
+ ExpectFilter(
+ greater(call("multiply", {field_ref("f32"), field_ref("f64")}), literal(0)), R"([
+ {"f64": 0.3, "f32": 0.1, "in": 1},
+ {"f64": -0.1, "f32": 0.3, "in": 0},
+ {"f64": 0.1, "f32": 0.2, "in": 1},
+ {"f64": 0.0, "f32": -0.1, "in": 0},
+ {"f64": 1.0, "f32": 0.1, "in": 1},
+ {"f64": -2.0, "f32": null, "in": null},
+ {"f64": 3.0, "f32": 1.0, "in": 1}
+ ])");
+}
+
+TEST(Expression, SerializationRoundTrips) {
+ auto ExpectRoundTrips = [](const Expression& expr) {
+ ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(expr));
+ ASSERT_OK_AND_ASSIGN(Expression roundtripped, Deserialize(serialized));
+ EXPECT_EQ(expr, roundtripped);
+ };
+
+ ExpectRoundTrips(literal(MakeNullScalar(null())));
+
+ ExpectRoundTrips(literal(MakeNullScalar(int32())));
+
+ ExpectRoundTrips(
+ literal(MakeNullScalar(struct_({field("i", int32()), field("s", utf8())}))));
+
+ ExpectRoundTrips(literal(true));
+
+ ExpectRoundTrips(literal(false));
+
+ ExpectRoundTrips(literal(1));
+
+ ExpectRoundTrips(literal(1.125));
+
+ ExpectRoundTrips(literal("stringy strings"));
+
+ ExpectRoundTrips(field_ref("field"));
+
+ ExpectRoundTrips(greater(field_ref("a"), literal(0.25)));
+
+ ExpectRoundTrips(
+ or_({equal(field_ref("a"), literal(1)), not_equal(field_ref("b"), literal("hello")),
+ equal(field_ref("b"), literal("foo bar"))}));
+
+ ExpectRoundTrips(not_(field_ref("alpha")));
+
+ ExpectRoundTrips(call("is_in", {literal(1)},
+ compute::SetLookupOptions{ArrayFromJSON(int32(), "[1, 2, 3]")}));
+
+ ExpectRoundTrips(
+ call("is_in",
+ {call("cast", {field_ref("version")}, compute::CastOptions::Safe(float64()))},
+ compute::SetLookupOptions{ArrayFromJSON(float64(), "[0.5, 1.0, 2.0]"), true}));
+
+ ExpectRoundTrips(call("is_valid", {field_ref("validity")}));
+
+ ExpectRoundTrips(and_({and_(greater_equal(field_ref("x"), literal(-1.5)),
+ less(field_ref("x"), literal(0.0))),
+ and_(greater_equal(field_ref("y"), literal(0.0)),
+ less(field_ref("y"), literal(1.5))),
+ and_(greater(field_ref("z"), literal(1.5)),
+ less_equal(field_ref("z"), literal(3.0)))}));
+
+ ExpectRoundTrips(and_({equal(field_ref("year"), literal(int16_t(1999))),
+ equal(field_ref("month"), literal(int8_t(12))),
+ equal(field_ref("day"), literal(int8_t(31))),
+ equal(field_ref("hour"), literal(int8_t(0))),
+ equal(field_ref("alpha"), literal(int32_t(0))),
+ equal(field_ref("beta"), literal(3.25f))}));
+}
+
+TEST(Projection, AugmentWithNull) {
+ // NB: input contains *no columns* except i32
+ auto input = ArrayFromJSON(struct_({kBoringSchema->GetFieldByName("i32")}),
+ R"([{"i32": 0}, {"i32": 1}, {"i32": 2}])");
+
+ auto ExpectProject = [&](Expression proj, Datum expected) {
+ ASSERT_OK_AND_ASSIGN(proj, proj.Bind(*kBoringSchema));
+ ASSERT_OK_AND_ASSIGN(auto actual,
+ ExecuteScalarExpression(proj, *kBoringSchema, input));
+ AssertDatumsEqual(Datum(expected), actual);
+ };
+
+ ExpectProject(project({field_ref("f64"), field_ref("i32")},
+ {"projected double", "projected int"}),
+ // "projected double" is materialized as a column of nulls
+ ArrayFromJSON(struct_({field("projected double", float64()),
+ field("projected int", int32())}),
+ R"([
+ [null, 0],
+ [null, 1],
+ [null, 2]
+ ])"));
+
+ ExpectProject(
+ project({field_ref("f64")}, {"projected double"}),
+ // NB: only a scalar was projected, this is *not* automatically broadcast
+ // to an array. "projected double" is materialized as a null scalar
+ Datum(*StructScalar::Make({MakeNullScalar(float64())}, {"projected double"})));
+}
+
+TEST(Projection, AugmentWithKnownValues) {
+ auto input = ArrayFromJSON(struct_({kBoringSchema->GetFieldByName("i32")}),
+ R"([{"i32": 0}, {"i32": 1}, {"i32": 2}])");
+
+ auto ExpectSimplifyAndProject = [&](Expression proj, Datum expected,
+ Expression guarantee) {
+ ASSERT_OK_AND_ASSIGN(proj, proj.Bind(*kBoringSchema));
+ ASSERT_OK_AND_ASSIGN(proj, SimplifyWithGuarantee(proj, guarantee));
+ ASSERT_OK_AND_ASSIGN(auto actual,
+ ExecuteScalarExpression(proj, *kBoringSchema, input));
+ AssertDatumsEqual(Datum(expected), actual);
+ };
+
+ ExpectSimplifyAndProject(
+ project({field_ref("str"), field_ref("f64"), field_ref("i64"), field_ref("i32")},
+ {"str", "f64", "i64", "i32"}),
+ ArrayFromJSON(struct_({
+ field("str", utf8()),
+ field("f64", float64()),
+ field("i64", int64()),
+ field("i32", int32()),
+ }),
+ // str is explicitly null
+ // f64 is explicitly 3.5
+ // i64 is not specified in the guarantee and implicitly null
+ // i32 is present in the input and passed through
+ R"([
+ {"str": null, "f64": 3.5, "i64": null, "i32": 0},
+ {"str": null, "f64": 3.5, "i64": null, "i32": 1},
+ {"str": null, "f64": 3.5, "i64": null, "i32": 2}
+ ])"),
+ and_({
+ equal(field_ref("f64"), literal(3.5)),
+ is_null(field_ref("str")),
+ }));
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/filter_node.cc b/src/arrow/cpp/src/arrow/compute/exec/filter_node.cc
new file mode 100644
index 000000000..2e6d974dc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/filter_node.cc
@@ -0,0 +1,116 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/exec_plan.h"
+
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/expression.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/datum.h"
+#include "arrow/result.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+namespace {
+
+class FilterNode : public MapNode {
+ public:
+ FilterNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ std::shared_ptr<Schema> output_schema, Expression filter, bool async_mode)
+ : MapNode(plan, std::move(inputs), std::move(output_schema), async_mode),
+ filter_(std::move(filter)) {}
+
+ static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options) {
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "FilterNode"));
+ auto schema = inputs[0]->output_schema();
+
+ const auto& filter_options = checked_cast<const FilterNodeOptions&>(options);
+
+ auto filter_expression = filter_options.filter_expression;
+ if (!filter_expression.IsBound()) {
+ ARROW_ASSIGN_OR_RAISE(filter_expression, filter_expression.Bind(*schema));
+ }
+
+ if (filter_expression.type()->id() != Type::BOOL) {
+ return Status::TypeError("Filter expression must evaluate to bool, but ",
+ filter_expression.ToString(), " evaluates to ",
+ filter_expression.type()->ToString());
+ }
+ return plan->EmplaceNode<FilterNode>(plan, std::move(inputs), std::move(schema),
+ std::move(filter_expression),
+ filter_options.async_mode);
+ }
+
+ const char* kind_name() const override { return "FilterNode"; }
+
+ Result<ExecBatch> DoFilter(const ExecBatch& target) {
+ ARROW_ASSIGN_OR_RAISE(Expression simplified_filter,
+ SimplifyWithGuarantee(filter_, target.guarantee));
+
+ ARROW_ASSIGN_OR_RAISE(Datum mask, ExecuteScalarExpression(simplified_filter, target,
+ plan()->exec_context()));
+
+ if (mask.is_scalar()) {
+ const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
+ if (mask_scalar.is_valid && mask_scalar.value) {
+ return target;
+ }
+
+ return target.Slice(0, 0);
+ }
+
+ // if the values are all scalar then the mask must also be
+ DCHECK(!std::all_of(target.values.begin(), target.values.end(),
+ [](const Datum& value) { return value.is_scalar(); }));
+
+ auto values = target.values;
+ for (auto& value : values) {
+ if (value.is_scalar()) continue;
+ ARROW_ASSIGN_OR_RAISE(value, Filter(value, mask, FilterOptions::Defaults()));
+ }
+ return ExecBatch::Make(std::move(values));
+ }
+
+ void InputReceived(ExecNode* input, ExecBatch batch) override {
+ DCHECK_EQ(input, inputs_[0]);
+ auto func = [this](ExecBatch batch) { return DoFilter(std::move(batch)); };
+ this->SubmitTask(std::move(func), std::move(batch));
+ }
+
+ protected:
+ std::string ToStringExtra() const override { return "filter=" + filter_.ToString(); }
+
+ private:
+ Expression filter_;
+};
+} // namespace
+
+namespace internal {
+void RegisterFilterNode(ExecFactoryRegistry* registry) {
+ DCHECK_OK(registry->AddFactory("filter", FilterNode::Make));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/forest_internal.h b/src/arrow/cpp/src/arrow/compute/exec/forest_internal.h
new file mode 100644
index 000000000..7b55a0aab
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/forest_internal.h
@@ -0,0 +1,125 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+
+namespace arrow {
+namespace compute {
+
+/// A Forest is a view of a sorted range which carries an ancestry relation in addition
+/// to an ordering relation: each element's descendants appear directly after it.
+/// This can be used to efficiently skip subtrees when iterating through the range.
+class Forest {
+ public:
+ Forest() = default;
+
+ /// \brief Construct a Forest viewing the range [0, size).
+ Forest(int size, std::function<bool(int, int)> is_ancestor) : size_(size) {
+ std::vector<int> descendant_counts(size, 0);
+
+ std::vector<int> parent_stack;
+
+ for (int i = 0; i < size; ++i) {
+ while (parent_stack.size() != 0) {
+ if (is_ancestor(parent_stack.back(), i)) break;
+
+ // parent_stack.back() has no more descendants; finalize count and pop
+ descendant_counts[parent_stack.back()] = i - 1 - parent_stack.back();
+ parent_stack.pop_back();
+ }
+
+ parent_stack.push_back(i);
+ }
+
+ // finalize descendant_counts for anything left in the stack
+ while (parent_stack.size() != 0) {
+ descendant_counts[parent_stack.back()] = size - 1 - parent_stack.back();
+ parent_stack.pop_back();
+ }
+
+ descendant_counts_ = std::make_shared<std::vector<int>>(std::move(descendant_counts));
+ }
+
+ /// \brief Returns the number of nodes in this forest.
+ int size() const { return size_; }
+
+ bool Equals(const Forest& other) const {
+ auto it = descendant_counts_->begin();
+ return size_ == other.size_ &&
+ std::equal(it, it + size_, other.descendant_counts_->begin());
+ }
+
+ struct Ref {
+ int num_descendants() const { return forest->descendant_counts_->at(i); }
+
+ bool IsAncestorOf(const Ref& ref) const {
+ return i < ref.i && i + 1 + num_descendants() > ref.i;
+ }
+
+ explicit operator bool() const { return forest != NULLPTR; }
+
+ const Forest* forest;
+ int i;
+ };
+
+ /// \brief Visit with eager pruning. Visitors must return Result<bool>, using
+ /// true to indicate a subtree should be visited and false to indicate that the
+ /// subtree should be skipped.
+ template <typename PreVisitor, typename PostVisitor>
+ Status Visit(PreVisitor&& pre, PostVisitor&& post) const {
+ std::vector<Ref> parent_stack;
+
+ for (int i = 0; i < size_; ++i) {
+ Ref ref = {this, i};
+
+ while (parent_stack.size() > 0) {
+ if (parent_stack.back().IsAncestorOf(ref)) break;
+
+ post(parent_stack.back());
+ parent_stack.pop_back();
+ }
+
+ ARROW_ASSIGN_OR_RAISE(bool visit_subtree, pre(ref));
+
+ if (!visit_subtree) {
+ // skip descendants
+ i += ref.num_descendants();
+ continue;
+ }
+
+ parent_stack.push_back(ref);
+ }
+
+ return Status::OK();
+ }
+
+ Ref operator[](int i) const { return Ref{this, i}; }
+
+ private:
+ int size_ = 0;
+ std::shared_ptr<std::vector<int>> descendant_counts_;
+};
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/hash_join.cc b/src/arrow/cpp/src/arrow/compute/exec/hash_join.cc
new file mode 100644
index 000000000..a89e23796
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/hash_join.cc
@@ -0,0 +1,795 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/hash_join.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+#include <vector>
+
+#include "arrow/compute/exec/hash_join_dict.h"
+#include "arrow/compute/exec/task_util.h"
+#include "arrow/compute/kernels/row_encoder.h"
+
+namespace arrow {
+namespace compute {
+
+using internal::RowEncoder;
+
+class HashJoinBasicImpl : public HashJoinImpl {
+ private:
+ struct ThreadLocalState;
+
+ public:
+ Status InputReceived(size_t thread_index, int side, ExecBatch batch) override {
+ if (cancelled_) {
+ return Status::Cancelled("Hash join cancelled");
+ }
+ if (QueueBatchIfNeeded(side, batch)) {
+ return Status::OK();
+ } else {
+ ARROW_DCHECK(side == 0);
+ return ProbeBatch(thread_index, batch);
+ }
+ }
+
+ Status InputFinished(size_t thread_index, int side) override {
+ if (cancelled_) {
+ return Status::Cancelled("Hash join cancelled");
+ }
+ if (side == 0) {
+ bool proceed;
+ {
+ std::lock_guard<std::mutex> lock(finished_mutex_);
+ proceed = !left_side_finished_ && left_queue_finished_;
+ left_side_finished_ = true;
+ }
+ if (proceed) {
+ RETURN_NOT_OK(OnLeftSideAndQueueFinished(thread_index));
+ }
+ } else {
+ bool proceed;
+ {
+ std::lock_guard<std::mutex> lock(finished_mutex_);
+ proceed = !right_side_finished_;
+ right_side_finished_ = true;
+ }
+ if (proceed) {
+ RETURN_NOT_OK(OnRightSideFinished(thread_index));
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
+ size_t num_threads, HashJoinSchema* schema_mgr,
+ std::vector<JoinKeyCmp> key_cmp, OutputBatchCallback output_batch_callback,
+ FinishedCallback finished_callback,
+ TaskScheduler::ScheduleImpl schedule_task_callback) override {
+ num_threads = std::max(num_threads, static_cast<size_t>(1));
+
+ ctx_ = ctx;
+ join_type_ = join_type;
+ num_threads_ = num_threads;
+ schema_mgr_ = schema_mgr;
+ key_cmp_ = std::move(key_cmp);
+ output_batch_callback_ = std::move(output_batch_callback);
+ finished_callback_ = std::move(finished_callback);
+ local_states_.resize(num_threads);
+ for (size_t i = 0; i < local_states_.size(); ++i) {
+ local_states_[i].is_initialized = false;
+ local_states_[i].is_has_match_initialized = false;
+ }
+ dict_probe_.Init(num_threads);
+
+ has_hash_table_ = false;
+ num_batches_produced_.store(0);
+ cancelled_ = false;
+ right_side_finished_ = false;
+ left_side_finished_ = false;
+ left_queue_finished_ = false;
+
+ scheduler_ = TaskScheduler::Make();
+ RegisterBuildHashTable();
+ RegisterProbeQueuedBatches();
+ RegisterScanHashTable();
+ scheduler_->RegisterEnd();
+ RETURN_NOT_OK(scheduler_->StartScheduling(
+ 0 /*thread index*/, std::move(schedule_task_callback),
+ static_cast<int>(2 * num_threads) /*concurrent tasks*/, use_sync_execution));
+
+ return Status::OK();
+ }
+
+ void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) override {
+ cancelled_ = true;
+ scheduler_->Abort(std::move(pos_abort_callback));
+ }
+
+ private:
+ void InitEncoder(int side, HashJoinProjection projection_handle, RowEncoder* encoder) {
+ std::vector<ValueDescr> data_types;
+ int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle);
+ data_types.resize(num_cols);
+ for (int icol = 0; icol < num_cols; ++icol) {
+ data_types[icol] =
+ ValueDescr(schema_mgr_->proj_maps[side].data_type(projection_handle, icol),
+ ValueDescr::ARRAY);
+ }
+ encoder->Init(data_types, ctx_);
+ encoder->Clear();
+ }
+
+ void InitLocalStateIfNeeded(size_t thread_index) {
+ ThreadLocalState& local_state = local_states_[thread_index];
+ if (!local_state.is_initialized) {
+ InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys);
+ bool has_payload =
+ (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
+ if (has_payload) {
+ InitEncoder(0, HashJoinProjection::PAYLOAD, &local_state.exec_batch_payloads);
+ }
+
+ local_state.is_initialized = true;
+ }
+ }
+
+ Status EncodeBatch(int side, HashJoinProjection projection_handle, RowEncoder* encoder,
+ const ExecBatch& batch, ExecBatch* opt_projected_batch = nullptr) {
+ ExecBatch projected({}, batch.length);
+ int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle);
+ projected.values.resize(num_cols);
+
+ auto to_input =
+ schema_mgr_->proj_maps[side].map(projection_handle, HashJoinProjection::INPUT);
+ for (int icol = 0; icol < num_cols; ++icol) {
+ projected.values[icol] = batch.values[to_input.get(icol)];
+ }
+
+ if (opt_projected_batch) {
+ *opt_projected_batch = projected;
+ }
+
+ return encoder->EncodeAndAppend(projected);
+ }
+
+ void ProbeBatch_Lookup(ThreadLocalState* local_state, const RowEncoder& exec_batch_keys,
+ const std::vector<const uint8_t*>& non_null_bit_vectors,
+ const std::vector<int64_t>& non_null_bit_vector_offsets,
+ std::vector<int32_t>* output_match,
+ std::vector<int32_t>* output_no_match,
+ std::vector<int32_t>* output_match_left,
+ std::vector<int32_t>* output_match_right) {
+ InitHasMatchIfNeeded(local_state);
+
+ ARROW_DCHECK(has_hash_table_);
+
+ InitHasMatchIfNeeded(local_state);
+
+ int num_cols = static_cast<int>(non_null_bit_vectors.size());
+ for (int32_t irow = 0; irow < exec_batch_keys.num_rows(); ++irow) {
+ // Apply null key filtering
+ bool no_match = hash_table_empty_;
+ for (int icol = 0; icol < num_cols; ++icol) {
+ bool is_null = non_null_bit_vectors[icol] &&
+ !BitUtil::GetBit(non_null_bit_vectors[icol],
+ non_null_bit_vector_offsets[icol] + irow);
+ if (key_cmp_[icol] == JoinKeyCmp::EQ && is_null) {
+ no_match = true;
+ break;
+ }
+ }
+ if (no_match) {
+ output_no_match->push_back(irow);
+ continue;
+ }
+ // Get all matches from hash table
+ bool has_match = false;
+
+ auto range = hash_table_.equal_range(exec_batch_keys.encoded_row(irow));
+ for (auto it = range.first; it != range.second; ++it) {
+ output_match_left->push_back(irow);
+ output_match_right->push_back(it->second);
+ // Mark row in hash table as having a match
+ BitUtil::SetBit(local_state->has_match.data(), it->second);
+ has_match = true;
+ }
+ if (!has_match) {
+ output_no_match->push_back(irow);
+ } else {
+ output_match->push_back(irow);
+ }
+ }
+ }
+
+ void ProbeBatch_OutputOne(int64_t batch_size_next, ExecBatch* opt_left_key,
+ ExecBatch* opt_left_payload, ExecBatch* opt_right_key,
+ ExecBatch* opt_right_payload) {
+ ExecBatch result({}, batch_size_next);
+ int num_out_cols_left =
+ schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT);
+ int num_out_cols_right =
+ schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT);
+ ARROW_DCHECK((opt_left_payload == nullptr) ==
+ (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) == 0));
+ ARROW_DCHECK((opt_right_payload == nullptr) ==
+ (schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) == 0));
+ result.values.resize(num_out_cols_left + num_out_cols_right);
+ auto from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
+ HashJoinProjection::KEY);
+ auto from_payload = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
+ HashJoinProjection::PAYLOAD);
+ for (int icol = 0; icol < num_out_cols_left; ++icol) {
+ bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
+ bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
+ ARROW_DCHECK(is_from_key != is_from_payload);
+ ARROW_DCHECK(!is_from_key ||
+ (opt_left_key &&
+ from_key.get(icol) < static_cast<int>(opt_left_key->values.size()) &&
+ opt_left_key->length == batch_size_next));
+ ARROW_DCHECK(
+ !is_from_payload ||
+ (opt_left_payload &&
+ from_payload.get(icol) < static_cast<int>(opt_left_payload->values.size()) &&
+ opt_left_payload->length == batch_size_next));
+ result.values[icol] = is_from_key
+ ? opt_left_key->values[from_key.get(icol)]
+ : opt_left_payload->values[from_payload.get(icol)];
+ }
+ from_key = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
+ HashJoinProjection::KEY);
+ from_payload = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
+ HashJoinProjection::PAYLOAD);
+ for (int icol = 0; icol < num_out_cols_right; ++icol) {
+ bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
+ bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
+ ARROW_DCHECK(is_from_key != is_from_payload);
+ ARROW_DCHECK(!is_from_key ||
+ (opt_right_key &&
+ from_key.get(icol) < static_cast<int>(opt_right_key->values.size()) &&
+ opt_right_key->length == batch_size_next));
+ ARROW_DCHECK(
+ !is_from_payload ||
+ (opt_right_payload &&
+ from_payload.get(icol) < static_cast<int>(opt_right_payload->values.size()) &&
+ opt_right_payload->length == batch_size_next));
+ result.values[num_out_cols_left + icol] =
+ is_from_key ? opt_right_key->values[from_key.get(icol)]
+ : opt_right_payload->values[from_payload.get(icol)];
+ }
+
+ output_batch_callback_(std::move(result));
+
+ // Update the counter of produced batches
+ //
+ num_batches_produced_++;
+ }
+
+ Status ProbeBatch_OutputOne(size_t thread_index, int64_t batch_size_next,
+ const int32_t* opt_left_ids, const int32_t* opt_right_ids) {
+ if (batch_size_next == 0 || (!opt_left_ids && !opt_right_ids)) {
+ return Status::OK();
+ }
+
+ bool has_left =
+ (join_type_ != JoinType::RIGHT_SEMI && join_type_ != JoinType::RIGHT_ANTI &&
+ schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT) > 0);
+ bool has_right =
+ (join_type_ != JoinType::LEFT_SEMI && join_type_ != JoinType::LEFT_ANTI &&
+ schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT) > 0);
+ bool has_left_payload =
+ has_left && (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
+ bool has_right_payload =
+ has_right &&
+ (schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0);
+
+ ThreadLocalState& local_state = local_states_[thread_index];
+ InitLocalStateIfNeeded(thread_index);
+
+ ExecBatch left_key;
+ ExecBatch left_payload;
+ ExecBatch right_key;
+ ExecBatch right_payload;
+ if (has_left) {
+ ARROW_DCHECK(opt_left_ids);
+ ARROW_ASSIGN_OR_RAISE(
+ left_key, local_state.exec_batch_keys.Decode(batch_size_next, opt_left_ids));
+ }
+ if (has_left_payload) {
+ ARROW_ASSIGN_OR_RAISE(left_payload, local_state.exec_batch_payloads.Decode(
+ batch_size_next, opt_left_ids));
+ }
+ if (has_right) {
+ ARROW_DCHECK(opt_right_ids);
+ ARROW_ASSIGN_OR_RAISE(right_key,
+ hash_table_keys_.Decode(batch_size_next, opt_right_ids));
+ // Post process build side keys that use dictionary
+ RETURN_NOT_OK(dict_build_.PostDecode(schema_mgr_->proj_maps[1], &right_key, ctx_));
+ }
+ if (has_right_payload) {
+ ARROW_ASSIGN_OR_RAISE(right_payload,
+ hash_table_payloads_.Decode(batch_size_next, opt_right_ids));
+ }
+
+ ProbeBatch_OutputOne(batch_size_next, has_left ? &left_key : nullptr,
+ has_left_payload ? &left_payload : nullptr,
+ has_right ? &right_key : nullptr,
+ has_right_payload ? &right_payload : nullptr);
+
+ return Status::OK();
+ }
+
+ Status ProbeBatch_OutputAll(size_t thread_index, const RowEncoder& exec_batch_keys,
+ const RowEncoder& exec_batch_payloads,
+ const std::vector<int32_t>& match,
+ const std::vector<int32_t>& no_match,
+ std::vector<int32_t>& match_left,
+ std::vector<int32_t>& match_right) {
+ if (join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI) {
+ // Nothing to output
+ return Status::OK();
+ }
+
+ if (join_type_ == JoinType::LEFT_ANTI || join_type_ == JoinType::LEFT_SEMI) {
+ const std::vector<int32_t>& out_ids =
+ (join_type_ == JoinType::LEFT_SEMI) ? match : no_match;
+
+ for (size_t start = 0; start < out_ids.size(); start += output_batch_size_) {
+ int64_t batch_size_next = std::min(static_cast<int64_t>(out_ids.size() - start),
+ static_cast<int64_t>(output_batch_size_));
+ RETURN_NOT_OK(ProbeBatch_OutputOne(thread_index, batch_size_next,
+ out_ids.data() + start, nullptr));
+ }
+ } else {
+ if (join_type_ == JoinType::LEFT_OUTER || join_type_ == JoinType::FULL_OUTER) {
+ for (size_t i = 0; i < no_match.size(); ++i) {
+ match_left.push_back(no_match[i]);
+ match_right.push_back(RowEncoder::kRowIdForNulls());
+ }
+ }
+
+ ARROW_DCHECK(match_left.size() == match_right.size());
+
+ for (size_t start = 0; start < match_left.size(); start += output_batch_size_) {
+ int64_t batch_size_next =
+ std::min(static_cast<int64_t>(match_left.size() - start),
+ static_cast<int64_t>(output_batch_size_));
+ RETURN_NOT_OK(ProbeBatch_OutputOne(thread_index, batch_size_next,
+ match_left.data() + start,
+ match_right.data() + start));
+ }
+ }
+ return Status::OK();
+ }
+
+ void NullInfoFromBatch(const ExecBatch& batch,
+ std::vector<const uint8_t*>* nn_bit_vectors,
+ std::vector<int64_t>* nn_offsets,
+ std::vector<uint8_t>* nn_bit_vector_all_nulls) {
+ int num_cols = static_cast<int>(batch.values.size());
+ nn_bit_vectors->resize(num_cols);
+ nn_offsets->resize(num_cols);
+ nn_bit_vector_all_nulls->clear();
+ for (int64_t i = 0; i < num_cols; ++i) {
+ const uint8_t* nn = nullptr;
+ int64_t offset = 0;
+ if (batch[i].is_array()) {
+ if (batch[i].array()->buffers[0] != NULLPTR) {
+ nn = batch[i].array()->buffers[0]->data();
+ offset = batch[i].array()->offset;
+ }
+ } else {
+ ARROW_DCHECK(batch[i].is_scalar());
+ if (!batch[i].scalar_as<arrow::internal::PrimitiveScalarBase>().is_valid) {
+ if (nn_bit_vector_all_nulls->empty()) {
+ nn_bit_vector_all_nulls->resize(BitUtil::BytesForBits(batch.length));
+ memset(nn_bit_vector_all_nulls->data(), 0,
+ BitUtil::BytesForBits(batch.length));
+ }
+ nn = nn_bit_vector_all_nulls->data();
+ }
+ }
+ (*nn_bit_vectors)[i] = nn;
+ (*nn_offsets)[i] = offset;
+ }
+ }
+
+ Status ProbeBatch(size_t thread_index, const ExecBatch& batch) {
+ ThreadLocalState& local_state = local_states_[thread_index];
+ InitLocalStateIfNeeded(thread_index);
+
+ local_state.exec_batch_keys.Clear();
+
+ ExecBatch batch_key_for_lookups;
+
+ RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::KEY, &local_state.exec_batch_keys,
+ batch, &batch_key_for_lookups));
+ bool has_left_payload =
+ (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
+ if (has_left_payload) {
+ local_state.exec_batch_payloads.Clear();
+ RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::PAYLOAD,
+ &local_state.exec_batch_payloads, batch));
+ }
+
+ local_state.match.clear();
+ local_state.no_match.clear();
+ local_state.match_left.clear();
+ local_state.match_right.clear();
+
+ bool use_key_batch_for_dicts = dict_probe_.BatchRemapNeeded(
+ thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], ctx_);
+ RowEncoder* row_encoder_for_lookups = &local_state.exec_batch_keys;
+ if (use_key_batch_for_dicts) {
+ RETURN_NOT_OK(dict_probe_.EncodeBatch(
+ thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], dict_build_,
+ batch, &row_encoder_for_lookups, &batch_key_for_lookups, ctx_));
+ }
+
+ // Collect information about all nulls in key columns.
+ //
+ std::vector<const uint8_t*> non_null_bit_vectors;
+ std::vector<int64_t> non_null_bit_vector_offsets;
+ std::vector<uint8_t> all_nulls;
+ NullInfoFromBatch(batch_key_for_lookups, &non_null_bit_vectors,
+ &non_null_bit_vector_offsets, &all_nulls);
+
+ ProbeBatch_Lookup(&local_state, *row_encoder_for_lookups, non_null_bit_vectors,
+ non_null_bit_vector_offsets, &local_state.match,
+ &local_state.no_match, &local_state.match_left,
+ &local_state.match_right);
+
+ RETURN_NOT_OK(ProbeBatch_OutputAll(thread_index, local_state.exec_batch_keys,
+ local_state.exec_batch_payloads, local_state.match,
+ local_state.no_match, local_state.match_left,
+ local_state.match_right));
+
+ return Status::OK();
+ }
+
+ int64_t BuildHashTable_num_tasks() { return 1; }
+
+ Status BuildHashTable_exec_task(size_t thread_index, int64_t /*task_id*/) {
+ const std::vector<ExecBatch>& batches = right_batches_;
+ if (batches.empty()) {
+ hash_table_empty_ = true;
+ } else {
+ dict_build_.InitEncoder(schema_mgr_->proj_maps[1], &hash_table_keys_, ctx_);
+ bool has_payload =
+ (schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0);
+ if (has_payload) {
+ InitEncoder(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_);
+ }
+ hash_table_empty_ = true;
+ for (size_t ibatch = 0; ibatch < batches.size(); ++ibatch) {
+ if (cancelled_) {
+ return Status::Cancelled("Hash join cancelled");
+ }
+ const ExecBatch& batch = batches[ibatch];
+ if (batch.length == 0) {
+ continue;
+ } else if (hash_table_empty_) {
+ hash_table_empty_ = false;
+
+ RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], &batch, ctx_));
+ }
+ int32_t num_rows_before = hash_table_keys_.num_rows();
+ RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, schema_mgr_->proj_maps[1],
+ batch, &hash_table_keys_, ctx_));
+ if (has_payload) {
+ RETURN_NOT_OK(
+ EncodeBatch(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_, batch));
+ }
+ int32_t num_rows_after = hash_table_keys_.num_rows();
+ for (int32_t irow = num_rows_before; irow < num_rows_after; ++irow) {
+ hash_table_.insert(std::make_pair(hash_table_keys_.encoded_row(irow), irow));
+ }
+ }
+ }
+
+ if (hash_table_empty_) {
+ RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], nullptr, ctx_));
+ }
+
+ return Status::OK();
+ }
+
+ Status BuildHashTable_on_finished(size_t thread_index) {
+ if (cancelled_) {
+ return Status::Cancelled("Hash join cancelled");
+ }
+
+ {
+ std::lock_guard<std::mutex> lock(left_batches_mutex_);
+ has_hash_table_ = true;
+ }
+
+ right_batches_.clear();
+
+ RETURN_NOT_OK(ProbeQueuedBatches(thread_index));
+
+ return Status::OK();
+ }
+
+ void RegisterBuildHashTable() {
+ task_group_build_ = scheduler_->RegisterTaskGroup(
+ [this](size_t thread_index, int64_t task_id) -> Status {
+ return BuildHashTable_exec_task(thread_index, task_id);
+ },
+ [this](size_t thread_index) -> Status {
+ return BuildHashTable_on_finished(thread_index);
+ });
+ }
+
+ Status BuildHashTable(size_t thread_index) {
+ return scheduler_->StartTaskGroup(thread_index, task_group_build_,
+ BuildHashTable_num_tasks());
+ }
+
+ int64_t ProbeQueuedBatches_num_tasks() {
+ return static_cast<int64_t>(left_batches_.size());
+ }
+
+ Status ProbeQueuedBatches_exec_task(size_t thread_index, int64_t task_id) {
+ if (cancelled_) {
+ return Status::Cancelled("Hash join cancelled");
+ }
+ return ProbeBatch(thread_index, std::move(left_batches_[task_id]));
+ }
+
+ Status ProbeQueuedBatches_on_finished(size_t thread_index) {
+ if (cancelled_) {
+ return Status::Cancelled("Hash join cancelled");
+ }
+
+ left_batches_.clear();
+
+ bool proceed;
+ {
+ std::lock_guard<std::mutex> lock(finished_mutex_);
+ proceed = left_side_finished_ && !left_queue_finished_;
+ left_queue_finished_ = true;
+ }
+ if (proceed) {
+ RETURN_NOT_OK(OnLeftSideAndQueueFinished(thread_index));
+ }
+
+ return Status::OK();
+ }
+
+ void RegisterProbeQueuedBatches() {
+ task_group_queued_ = scheduler_->RegisterTaskGroup(
+ [this](size_t thread_index, int64_t task_id) -> Status {
+ return ProbeQueuedBatches_exec_task(thread_index, task_id);
+ },
+ [this](size_t thread_index) -> Status {
+ return ProbeQueuedBatches_on_finished(thread_index);
+ });
+ }
+
+ Status ProbeQueuedBatches(size_t thread_index) {
+ return scheduler_->StartTaskGroup(thread_index, task_group_queued_,
+ ProbeQueuedBatches_num_tasks());
+ }
+
+ int64_t ScanHashTable_num_tasks() {
+ if (!has_hash_table_ || hash_table_empty_) {
+ return 0;
+ }
+ if (join_type_ != JoinType::RIGHT_SEMI && join_type_ != JoinType::RIGHT_ANTI &&
+ join_type_ != JoinType::RIGHT_OUTER && join_type_ != JoinType::FULL_OUTER) {
+ return 0;
+ }
+ return BitUtil::CeilDiv(hash_table_keys_.num_rows(), hash_table_scan_unit_);
+ }
+
+ Status ScanHashTable_exec_task(size_t thread_index, int64_t task_id) {
+ if (cancelled_) {
+ return Status::Cancelled("Hash join cancelled");
+ }
+
+ int32_t start_row_id = static_cast<int32_t>(hash_table_scan_unit_ * task_id);
+ int32_t end_row_id =
+ static_cast<int32_t>(std::min(static_cast<int64_t>(hash_table_keys_.num_rows()),
+ hash_table_scan_unit_ * (task_id + 1)));
+
+ ThreadLocalState& local_state = local_states_[thread_index];
+ InitLocalStateIfNeeded(thread_index);
+
+ std::vector<int32_t>& id_left = local_state.no_match;
+ std::vector<int32_t>& id_right = local_state.match;
+ id_left.clear();
+ id_right.clear();
+ bool use_left = false;
+
+ bool match_search_value = (join_type_ == JoinType::RIGHT_SEMI);
+ for (int32_t row_id = start_row_id; row_id < end_row_id; ++row_id) {
+ if (BitUtil::GetBit(has_match_.data(), row_id) == match_search_value) {
+ id_right.push_back(row_id);
+ }
+ }
+
+ if (id_right.empty()) {
+ return Status::OK();
+ }
+
+ if (join_type_ != JoinType::RIGHT_SEMI && join_type_ != JoinType::RIGHT_ANTI) {
+ use_left = true;
+ id_left.resize(id_right.size());
+ for (size_t i = 0; i < id_left.size(); ++i) {
+ id_left[i] = RowEncoder::kRowIdForNulls();
+ }
+ }
+
+ RETURN_NOT_OK(
+ ProbeBatch_OutputOne(thread_index, static_cast<int64_t>(id_right.size()),
+ use_left ? id_left.data() : nullptr, id_right.data()));
+ return Status::OK();
+ }
+
+ Status ScanHashTable_on_finished(size_t thread_index) {
+ if (cancelled_) {
+ return Status::Cancelled("Hash join cancelled");
+ }
+ finished_callback_(num_batches_produced_.load());
+ return Status::OK();
+ }
+
+ void RegisterScanHashTable() {
+ task_group_scan_ = scheduler_->RegisterTaskGroup(
+ [this](size_t thread_index, int64_t task_id) -> Status {
+ return ScanHashTable_exec_task(thread_index, task_id);
+ },
+ [this](size_t thread_index) -> Status {
+ return ScanHashTable_on_finished(thread_index);
+ });
+ }
+
+ Status ScanHashTable(size_t thread_index) {
+ MergeHasMatch();
+ return scheduler_->StartTaskGroup(thread_index, task_group_scan_,
+ ScanHashTable_num_tasks());
+ }
+
+ bool QueueBatchIfNeeded(int side, ExecBatch batch) {
+ if (side == 0) {
+ std::lock_guard<std::mutex> lock(left_batches_mutex_);
+ if (has_hash_table_) {
+ return false;
+ }
+ left_batches_.emplace_back(std::move(batch));
+ return true;
+ } else {
+ std::lock_guard<std::mutex> lock(right_batches_mutex_);
+ right_batches_.emplace_back(std::move(batch));
+ return true;
+ }
+ }
+
+ Status OnRightSideFinished(size_t thread_index) { return BuildHashTable(thread_index); }
+
+ Status OnLeftSideAndQueueFinished(size_t thread_index) {
+ return ScanHashTable(thread_index);
+ }
+
+ void InitHasMatchIfNeeded(ThreadLocalState* local_state) {
+ if (local_state->is_has_match_initialized) {
+ return;
+ }
+ if (!hash_table_empty_) {
+ int32_t num_rows = hash_table_keys_.num_rows();
+ local_state->has_match.resize(BitUtil::BytesForBits(num_rows));
+ memset(local_state->has_match.data(), 0, BitUtil::BytesForBits(num_rows));
+ }
+ local_state->is_has_match_initialized = true;
+ }
+
+ void MergeHasMatch() {
+ if (hash_table_empty_) {
+ return;
+ }
+
+ int32_t num_rows = hash_table_keys_.num_rows();
+ has_match_.resize(BitUtil::BytesForBits(num_rows));
+ memset(has_match_.data(), 0, BitUtil::BytesForBits(num_rows));
+
+ for (size_t tid = 0; tid < local_states_.size(); ++tid) {
+ if (!local_states_[tid].is_initialized) {
+ continue;
+ }
+ if (!local_states_[tid].is_has_match_initialized) {
+ continue;
+ }
+ arrow::internal::BitmapOr(has_match_.data(), 0, local_states_[tid].has_match.data(),
+ 0, num_rows, 0, has_match_.data());
+ }
+ }
+
+ static constexpr int64_t hash_table_scan_unit_ = 32 * 1024;
+ static constexpr int64_t output_batch_size_ = 32 * 1024;
+
+ // Metadata
+ //
+ ExecContext* ctx_;
+ JoinType join_type_;
+ size_t num_threads_;
+ HashJoinSchema* schema_mgr_;
+ std::vector<JoinKeyCmp> key_cmp_;
+ std::unique_ptr<TaskScheduler> scheduler_;
+ int task_group_build_;
+ int task_group_queued_;
+ int task_group_scan_;
+
+ // Callbacks
+ //
+ OutputBatchCallback output_batch_callback_;
+ FinishedCallback finished_callback_;
+
+ // Thread local runtime state
+ //
+ struct ThreadLocalState {
+ bool is_initialized;
+ RowEncoder exec_batch_keys;
+ RowEncoder exec_batch_payloads;
+ std::vector<int32_t> match;
+ std::vector<int32_t> no_match;
+ std::vector<int32_t> match_left;
+ std::vector<int32_t> match_right;
+ bool is_has_match_initialized;
+ std::vector<uint8_t> has_match;
+ };
+ std::vector<ThreadLocalState> local_states_;
+
+ // Shared runtime state
+ //
+ RowEncoder hash_table_keys_;
+ RowEncoder hash_table_payloads_;
+ std::unordered_multimap<std::string, int32_t> hash_table_;
+ std::vector<uint8_t> has_match_;
+ bool hash_table_empty_;
+
+ // Dictionary handling
+ //
+ HashJoinDictBuildMulti dict_build_;
+ HashJoinDictProbeMulti dict_probe_;
+
+ std::vector<ExecBatch> left_batches_;
+ bool has_hash_table_;
+ std::mutex left_batches_mutex_;
+
+ std::vector<ExecBatch> right_batches_;
+ std::mutex right_batches_mutex_;
+
+ std::atomic<int64_t> num_batches_produced_;
+ bool cancelled_;
+
+ bool right_side_finished_;
+ bool left_side_finished_;
+ bool left_queue_finished_;
+ std::mutex finished_mutex_;
+};
+
+Result<std::unique_ptr<HashJoinImpl>> HashJoinImpl::MakeBasic() {
+ std::unique_ptr<HashJoinImpl> impl{new HashJoinBasicImpl()};
+ return std::move(impl);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/hash_join.h b/src/arrow/cpp/src/arrow/compute/exec/hash_join.h
new file mode 100644
index 000000000..6520e4ae4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/hash_join.h
@@ -0,0 +1,95 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/schema_util.h"
+#include "arrow/compute/exec/task_util.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+
+namespace arrow {
+namespace compute {
+
+class ARROW_EXPORT HashJoinSchema {
+ public:
+ Status Init(JoinType join_type, const Schema& left_schema,
+ const std::vector<FieldRef>& left_keys, const Schema& right_schema,
+ const std::vector<FieldRef>& right_keys,
+ const std::string& left_field_name_prefix,
+ const std::string& right_field_name_prefix);
+
+ Status Init(JoinType join_type, const Schema& left_schema,
+ const std::vector<FieldRef>& left_keys,
+ const std::vector<FieldRef>& left_output, const Schema& right_schema,
+ const std::vector<FieldRef>& right_keys,
+ const std::vector<FieldRef>& right_output,
+ const std::string& left_field_name_prefix,
+ const std::string& right_field_name_prefix);
+
+ static Status ValidateSchemas(JoinType join_type, const Schema& left_schema,
+ const std::vector<FieldRef>& left_keys,
+ const std::vector<FieldRef>& left_output,
+ const Schema& right_schema,
+ const std::vector<FieldRef>& right_keys,
+ const std::vector<FieldRef>& right_output,
+ const std::string& left_field_name_prefix,
+ const std::string& right_field_name_prefix);
+
+ std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_prefix,
+ const std::string& right_field_name_prefix);
+
+ static int kMissingField() {
+ return SchemaProjectionMaps<HashJoinProjection>::kMissingField;
+ }
+
+ SchemaProjectionMaps<HashJoinProjection> proj_maps[2];
+
+ private:
+ static bool IsTypeSupported(const DataType& type);
+ static Result<std::vector<FieldRef>> VectorDiff(const Schema& schema,
+ const std::vector<FieldRef>& a,
+ const std::vector<FieldRef>& b);
+};
+
+class HashJoinImpl {
+ public:
+ using OutputBatchCallback = std::function<void(ExecBatch)>;
+ using FinishedCallback = std::function<void(int64_t)>;
+
+ virtual ~HashJoinImpl() = default;
+ virtual Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
+ size_t num_threads, HashJoinSchema* schema_mgr,
+ std::vector<JoinKeyCmp> key_cmp,
+ OutputBatchCallback output_batch_callback,
+ FinishedCallback finished_callback,
+ TaskScheduler::ScheduleImpl schedule_task_callback) = 0;
+ virtual Status InputReceived(size_t thread_index, int side, ExecBatch batch) = 0;
+ virtual Status InputFinished(size_t thread_index, int side) = 0;
+ virtual void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) = 0;
+
+ static Result<std::unique_ptr<HashJoinImpl>> MakeBasic();
+};
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/hash_join_dict.cc b/src/arrow/cpp/src/arrow/compute/exec/hash_join_dict.cc
new file mode 100644
index 000000000..195331a59
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/hash_join_dict.cc
@@ -0,0 +1,665 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/hash_join_dict.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+namespace compute {
+
+bool HashJoinDictUtil::KeyDataTypesValid(
+ const std::shared_ptr<DataType>& probe_data_type,
+ const std::shared_ptr<DataType>& build_data_type) {
+ bool l_is_dict = (probe_data_type->id() == Type::DICTIONARY);
+ bool r_is_dict = (build_data_type->id() == Type::DICTIONARY);
+ DataType* l_type;
+ if (l_is_dict) {
+ const auto& dict_type = checked_cast<const DictionaryType&>(*probe_data_type);
+ l_type = dict_type.value_type().get();
+ } else {
+ l_type = probe_data_type.get();
+ }
+ DataType* r_type;
+ if (r_is_dict) {
+ const auto& dict_type = checked_cast<const DictionaryType&>(*build_data_type);
+ r_type = dict_type.value_type().get();
+ } else {
+ r_type = build_data_type.get();
+ }
+ return l_type->Equals(*r_type);
+}
+
+Result<std::shared_ptr<ArrayData>> HashJoinDictUtil::IndexRemapUsingLUT(
+ ExecContext* ctx, const Datum& indices, int64_t batch_length,
+ const std::shared_ptr<ArrayData>& map_array,
+ const std::shared_ptr<DataType>& data_type) {
+ ARROW_DCHECK(indices.is_array() || indices.is_scalar());
+
+ const uint8_t* map_non_nulls = map_array->buffers[0]->data();
+ const int32_t* map = reinterpret_cast<const int32_t*>(map_array->buffers[1]->data());
+
+ ARROW_DCHECK(data_type->id() == Type::DICTIONARY);
+ const auto& dict_type = checked_cast<const DictionaryType&>(*data_type);
+
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr<ArrayData> result,
+ ConvertToInt32(dict_type.index_type(), indices, batch_length, ctx));
+
+ uint8_t* nns = result->buffers[0]->mutable_data();
+ int32_t* ids = reinterpret_cast<int32_t*>(result->buffers[1]->mutable_data());
+ for (int64_t i = 0; i < batch_length; ++i) {
+ bool is_null = !BitUtil::GetBit(nns, i);
+ if (is_null) {
+ ids[i] = kNullId;
+ } else {
+ ARROW_DCHECK(ids[i] >= 0 && ids[i] < map_array->length);
+ if (!BitUtil::GetBit(map_non_nulls, ids[i])) {
+ BitUtil::ClearBit(nns, i);
+ ids[i] = kNullId;
+ } else {
+ ids[i] = map[ids[i]];
+ }
+ }
+ }
+
+ return result;
+}
+
+namespace {
+template <typename FROM, typename TO>
+static Result<std::shared_ptr<ArrayData>> ConvertImp(
+ const std::shared_ptr<DataType>& to_type, const Datum& input, int64_t batch_length,
+ ExecContext* ctx) {
+ ARROW_DCHECK(input.is_array() || input.is_scalar());
+ bool is_scalar = input.is_scalar();
+
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> to_buf,
+ AllocateBuffer(batch_length * sizeof(TO), ctx->memory_pool()));
+ TO* to = reinterpret_cast<TO*>(to_buf->mutable_data());
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> to_nn_buf,
+ AllocateBitmap(batch_length, ctx->memory_pool()));
+ uint8_t* to_nn = to_nn_buf->mutable_data();
+ memset(to_nn, 0xff, BitUtil::BytesForBits(batch_length));
+
+ if (!is_scalar) {
+ const ArrayData& arr = *input.array();
+ const FROM* from = arr.GetValues<FROM>(1);
+ DCHECK_EQ(arr.length, batch_length);
+
+ for (int64_t i = 0; i < arr.length; ++i) {
+ to[i] = static_cast<TO>(from[i]);
+ // Make sure we did not lose information during cast
+ ARROW_DCHECK(static_cast<FROM>(to[i]) == from[i]);
+
+ bool is_null = (arr.buffers[0] != NULLPTR) &&
+ !BitUtil::GetBit(arr.buffers[0]->data(), arr.offset + i);
+ if (is_null) {
+ BitUtil::ClearBit(to_nn, i);
+ }
+ }
+
+ // Pass null buffer unchanged
+ return ArrayData::Make(to_type, arr.length,
+ {std::move(to_nn_buf), std::move(to_buf)});
+ } else {
+ const auto& scalar = input.scalar_as<arrow::internal::PrimitiveScalarBase>();
+ if (scalar.is_valid) {
+ const util::string_view data = scalar.view();
+ DCHECK_EQ(data.size(), sizeof(FROM));
+ const FROM from = *reinterpret_cast<const FROM*>(data.data());
+ const TO to_value = static_cast<TO>(from);
+ // Make sure we did not lose information during cast
+ ARROW_DCHECK(static_cast<FROM>(to_value) == from);
+
+ for (int64_t i = 0; i < batch_length; ++i) {
+ to[i] = to_value;
+ }
+
+ memset(to_nn, 0xff, BitUtil::BytesForBits(batch_length));
+ return ArrayData::Make(to_type, batch_length,
+ {std::move(to_nn_buf), std::move(to_buf)});
+ } else {
+ memset(to_nn, 0, BitUtil::BytesForBits(batch_length));
+ return ArrayData::Make(to_type, batch_length,
+ {std::move(to_nn_buf), std::move(to_buf)});
+ }
+ }
+}
+} // namespace
+
+Result<std::shared_ptr<ArrayData>> HashJoinDictUtil::ConvertToInt32(
+ const std::shared_ptr<DataType>& from_type, const Datum& input, int64_t batch_length,
+ ExecContext* ctx) {
+ switch (from_type->id()) {
+ case Type::UINT8:
+ return ConvertImp<uint8_t, int32_t>(int32(), input, batch_length, ctx);
+ case Type::INT8:
+ return ConvertImp<int8_t, int32_t>(int32(), input, batch_length, ctx);
+ case Type::UINT16:
+ return ConvertImp<uint16_t, int32_t>(int32(), input, batch_length, ctx);
+ case Type::INT16:
+ return ConvertImp<int16_t, int32_t>(int32(), input, batch_length, ctx);
+ case Type::UINT32:
+ return ConvertImp<uint32_t, int32_t>(int32(), input, batch_length, ctx);
+ case Type::INT32:
+ return ConvertImp<int32_t, int32_t>(int32(), input, batch_length, ctx);
+ case Type::UINT64:
+ return ConvertImp<uint64_t, int32_t>(int32(), input, batch_length, ctx);
+ case Type::INT64:
+ return ConvertImp<int64_t, int32_t>(int32(), input, batch_length, ctx);
+ default:
+ ARROW_DCHECK(false);
+ return nullptr;
+ }
+}
+
+Result<std::shared_ptr<ArrayData>> HashJoinDictUtil::ConvertFromInt32(
+ const std::shared_ptr<DataType>& to_type, const Datum& input, int64_t batch_length,
+ ExecContext* ctx) {
+ switch (to_type->id()) {
+ case Type::UINT8:
+ return ConvertImp<int32_t, uint8_t>(to_type, input, batch_length, ctx);
+ case Type::INT8:
+ return ConvertImp<int32_t, int8_t>(to_type, input, batch_length, ctx);
+ case Type::UINT16:
+ return ConvertImp<int32_t, uint16_t>(to_type, input, batch_length, ctx);
+ case Type::INT16:
+ return ConvertImp<int32_t, int16_t>(to_type, input, batch_length, ctx);
+ case Type::UINT32:
+ return ConvertImp<int32_t, uint32_t>(to_type, input, batch_length, ctx);
+ case Type::INT32:
+ return ConvertImp<int32_t, int32_t>(to_type, input, batch_length, ctx);
+ case Type::UINT64:
+ return ConvertImp<int32_t, uint64_t>(to_type, input, batch_length, ctx);
+ case Type::INT64:
+ return ConvertImp<int32_t, int64_t>(to_type, input, batch_length, ctx);
+ default:
+ ARROW_DCHECK(false);
+ return nullptr;
+ }
+}
+
+std::shared_ptr<Array> HashJoinDictUtil::ExtractDictionary(const Datum& data) {
+ return data.is_array() ? MakeArray(data.array()->dictionary)
+ : data.scalar_as<DictionaryScalar>().value.dictionary;
+}
+
+Status HashJoinDictBuild::Init(ExecContext* ctx, std::shared_ptr<Array> dictionary,
+ std::shared_ptr<DataType> index_type,
+ std::shared_ptr<DataType> value_type) {
+ index_type_ = std::move(index_type);
+ value_type_ = std::move(value_type);
+ hash_table_.clear();
+
+ if (!dictionary) {
+ ARROW_ASSIGN_OR_RAISE(auto dict, MakeArrayOfNull(value_type_, 0));
+ unified_dictionary_ = dict->data();
+ return Status::OK();
+ }
+
+ dictionary_ = dictionary;
+
+ // Initialize encoder
+ internal::RowEncoder encoder;
+ std::vector<ValueDescr> encoder_types;
+ encoder_types.emplace_back(value_type_, ValueDescr::ARRAY);
+ encoder.Init(encoder_types, ctx);
+
+ // Encode all dictionary values
+ int64_t length = dictionary->data()->length;
+ if (length >= std::numeric_limits<int32_t>::max()) {
+ return Status::Invalid(
+ "Dictionary length in hash join must fit into signed 32-bit integer.");
+ }
+ ExecBatch batch({dictionary->data()}, length);
+ RETURN_NOT_OK(encoder.EncodeAndAppend(batch));
+
+ std::vector<int32_t> entries_to_take;
+
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> non_nulls_buf,
+ AllocateBitmap(length, ctx->memory_pool()));
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> ids_buf,
+ AllocateBuffer(length * sizeof(int32_t), ctx->memory_pool()));
+ uint8_t* non_nulls = non_nulls_buf->mutable_data();
+ int32_t* ids = reinterpret_cast<int32_t*>(ids_buf->mutable_data());
+ memset(non_nulls, 0xff, BitUtil::BytesForBits(length));
+
+ int32_t num_entries = 0;
+ for (int64_t i = 0; i < length; ++i) {
+ std::string str = encoder.encoded_row(static_cast<int32_t>(i));
+
+ // Do not insert null values into resulting dictionary.
+ // Null values will always be represented as null not an id pointing to a
+ // dictionary entry for null.
+ //
+ if (internal::KeyEncoder::IsNull(reinterpret_cast<const uint8_t*>(str.data()))) {
+ ids[i] = HashJoinDictUtil::kNullId;
+ BitUtil::ClearBit(non_nulls, i);
+ continue;
+ }
+
+ auto iter = hash_table_.find(str);
+ if (iter == hash_table_.end()) {
+ hash_table_.insert(std::make_pair(str, num_entries));
+ ids[i] = num_entries;
+ entries_to_take.push_back(static_cast<int32_t>(i));
+ ++num_entries;
+ } else {
+ ids[i] = iter->second;
+ }
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto out, encoder.Decode(num_entries, entries_to_take.data()));
+
+ unified_dictionary_ = out[0].array();
+ remapped_ids_ = ArrayData::Make(DataTypeAfterRemapping(), length,
+ {std::move(non_nulls_buf), std::move(ids_buf)});
+
+ return Status::OK();
+}
+
+Result<std::shared_ptr<ArrayData>> HashJoinDictBuild::RemapInputValues(
+ ExecContext* ctx, const Datum& values, int64_t batch_length) const {
+ // Initialize encoder
+ //
+ internal::RowEncoder encoder;
+ std::vector<ValueDescr> encoder_types;
+ encoder_types.emplace_back(value_type_, ValueDescr::ARRAY);
+ encoder.Init(encoder_types, ctx);
+
+ // Encode all
+ //
+ ARROW_DCHECK(values.is_array() || values.is_scalar());
+ bool is_scalar = values.is_scalar();
+ int64_t encoded_length = is_scalar ? 1 : batch_length;
+ ExecBatch batch({values}, encoded_length);
+ RETURN_NOT_OK(encoder.EncodeAndAppend(batch));
+
+ // Allocate output buffers
+ //
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> non_nulls_buf,
+ AllocateBitmap(batch_length, ctx->memory_pool()));
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr<Buffer> ids_buf,
+ AllocateBuffer(batch_length * sizeof(int32_t), ctx->memory_pool()));
+ uint8_t* non_nulls = non_nulls_buf->mutable_data();
+ int32_t* ids = reinterpret_cast<int32_t*>(ids_buf->mutable_data());
+ memset(non_nulls, 0xff, BitUtil::BytesForBits(batch_length));
+
+ // Populate output buffers (for scalar only the first entry is populated)
+ //
+ for (int64_t i = 0; i < encoded_length; ++i) {
+ std::string str = encoder.encoded_row(static_cast<int32_t>(i));
+ if (internal::KeyEncoder::IsNull(reinterpret_cast<const uint8_t*>(str.data()))) {
+ // Map nulls to nulls
+ BitUtil::ClearBit(non_nulls, i);
+ ids[i] = HashJoinDictUtil::kNullId;
+ } else {
+ auto iter = hash_table_.find(str);
+ if (iter == hash_table_.end()) {
+ ids[i] = HashJoinDictUtil::kMissingValueId;
+ } else {
+ ids[i] = iter->second;
+ }
+ }
+ }
+
+ // Generate array of repeated values for scalar input
+ //
+ if (is_scalar) {
+ if (!BitUtil::GetBit(non_nulls, 0)) {
+ memset(non_nulls, 0, BitUtil::BytesForBits(batch_length));
+ }
+ for (int64_t i = 1; i < batch_length; ++i) {
+ ids[i] = ids[0];
+ }
+ }
+
+ return ArrayData::Make(DataTypeAfterRemapping(), batch_length,
+ {std::move(non_nulls_buf), std::move(ids_buf)});
+}
+
+Result<std::shared_ptr<ArrayData>> HashJoinDictBuild::RemapInput(
+ ExecContext* ctx, const Datum& indices, int64_t batch_length,
+ const std::shared_ptr<DataType>& data_type) const {
+ auto dict = HashJoinDictUtil::ExtractDictionary(indices);
+
+ if (!dictionary_->Equals(dict)) {
+ return Status::NotImplemented("Unifying differing dictionaries");
+ }
+
+ return HashJoinDictUtil::IndexRemapUsingLUT(ctx, indices, batch_length, remapped_ids_,
+ data_type);
+}
+
+Result<std::shared_ptr<ArrayData>> HashJoinDictBuild::RemapOutput(
+ const ArrayData& indices32Bit, ExecContext* ctx) const {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ArrayData> indices,
+ HashJoinDictUtil::ConvertFromInt32(
+ index_type_, Datum(indices32Bit), indices32Bit.length, ctx));
+
+ auto type = std::make_shared<DictionaryType>(index_type_, value_type_);
+ return ArrayData::Make(type, indices->length, indices->buffers, {},
+ unified_dictionary_);
+}
+
+void HashJoinDictBuild::CleanUp() {
+ index_type_.reset();
+ value_type_.reset();
+ hash_table_.clear();
+ remapped_ids_.reset();
+ unified_dictionary_.reset();
+}
+
+bool HashJoinDictProbe::KeyNeedsProcessing(
+ const std::shared_ptr<DataType>& probe_data_type,
+ const std::shared_ptr<DataType>& build_data_type) {
+ bool l_is_dict = (probe_data_type->id() == Type::DICTIONARY);
+ bool r_is_dict = (build_data_type->id() == Type::DICTIONARY);
+ return l_is_dict || r_is_dict;
+}
+
+std::shared_ptr<DataType> HashJoinDictProbe::DataTypeAfterRemapping(
+ const std::shared_ptr<DataType>& build_data_type) {
+ bool r_is_dict = (build_data_type->id() == Type::DICTIONARY);
+ if (r_is_dict) {
+ return HashJoinDictBuild::DataTypeAfterRemapping();
+ } else {
+ return build_data_type;
+ }
+}
+
+Result<std::shared_ptr<ArrayData>> HashJoinDictProbe::RemapInput(
+ const HashJoinDictBuild* opt_build_side, const Datum& data, int64_t batch_length,
+ const std::shared_ptr<DataType>& probe_data_type,
+ const std::shared_ptr<DataType>& build_data_type, ExecContext* ctx) {
+ // Cases:
+ // 1. Dictionary(probe)-Dictionary(build)
+ // 2. Dictionary(probe)-Value(build)
+ // 3. Value(probe)-Dictionary(build)
+ //
+ bool l_is_dict = (probe_data_type->id() == Type::DICTIONARY);
+ bool r_is_dict = (build_data_type->id() == Type::DICTIONARY);
+ if (l_is_dict) {
+ auto dict = HashJoinDictUtil::ExtractDictionary(data);
+ const auto& dict_type = checked_cast<const DictionaryType&>(*probe_data_type);
+
+ // Verify that the dictionary is always the same.
+ if (dictionary_) {
+ if (!dictionary_->Equals(dict)) {
+ return Status::NotImplemented(
+ "Unifying differing dictionaries for probe key of hash join");
+ }
+ } else {
+ dictionary_ = dict;
+
+ // Precompute helper data for the given dictionary if this is the first call.
+ if (r_is_dict) {
+ ARROW_DCHECK(opt_build_side);
+ ARROW_ASSIGN_OR_RAISE(
+ remapped_ids_,
+ opt_build_side->RemapInputValues(ctx, Datum(dict->data()), dict->length()));
+ } else {
+ std::vector<ValueDescr> encoder_types;
+ encoder_types.emplace_back(dict_type.value_type(), ValueDescr::ARRAY);
+ encoder_.Init(encoder_types, ctx);
+ ExecBatch batch({dict->data()}, dict->length());
+ RETURN_NOT_OK(encoder_.EncodeAndAppend(batch));
+ }
+ }
+
+ if (r_is_dict) {
+ // CASE 1:
+ // Remap dictionary ids
+ return HashJoinDictUtil::IndexRemapUsingLUT(ctx, data, batch_length, remapped_ids_,
+ probe_data_type);
+ } else {
+ // CASE 2:
+ // Decode selected rows from encoder.
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ArrayData> row_ids_arr,
+ HashJoinDictUtil::ConvertToInt32(dict_type.index_type(), data,
+ batch_length, ctx));
+ // Change nulls to internal::RowEncoder::kRowIdForNulls() in index.
+ int32_t* row_ids =
+ reinterpret_cast<int32_t*>(row_ids_arr->buffers[1]->mutable_data());
+ const uint8_t* non_nulls = row_ids_arr->buffers[0]->data();
+ for (int64_t i = 0; i < batch_length; ++i) {
+ if (!BitUtil::GetBit(non_nulls, i)) {
+ row_ids[i] = internal::RowEncoder::kRowIdForNulls();
+ }
+ }
+
+ ARROW_ASSIGN_OR_RAISE(ExecBatch batch, encoder_.Decode(batch_length, row_ids));
+ return batch.values[0].array();
+ }
+ } else {
+ // CASE 3:
+ // Map values to dictionary ids from build side.
+ // Values missing in the dictionary will get assigned a special constant
+ // HashJoinDictUtil::kMissingValueId (different than any valid id).
+ //
+ ARROW_DCHECK(r_is_dict);
+ ARROW_DCHECK(opt_build_side);
+ return opt_build_side->RemapInputValues(ctx, data, batch_length);
+ }
+}
+
+void HashJoinDictProbe::CleanUp() {
+ dictionary_.reset();
+ remapped_ids_.reset();
+ encoder_.Clear();
+}
+
+Status HashJoinDictBuildMulti::Init(
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map,
+ const ExecBatch* opt_non_empty_batch, ExecContext* ctx) {
+ int num_keys = proj_map.num_cols(HashJoinProjection::KEY);
+ needs_remap_.resize(num_keys);
+ remap_imp_.resize(num_keys);
+ for (int i = 0; i < num_keys; ++i) {
+ needs_remap_[i] = HashJoinDictBuild::KeyNeedsProcessing(
+ proj_map.data_type(HashJoinProjection::KEY, i));
+ }
+
+ bool build_side_empty = (opt_non_empty_batch == nullptr);
+
+ if (!build_side_empty) {
+ auto key_to_input = proj_map.map(HashJoinProjection::KEY, HashJoinProjection::INPUT);
+ for (int i = 0; i < num_keys; ++i) {
+ const std::shared_ptr<DataType>& data_type =
+ proj_map.data_type(HashJoinProjection::KEY, i);
+ if (data_type->id() == Type::DICTIONARY) {
+ const auto& dict_type = checked_cast<const DictionaryType&>(*data_type);
+ const auto& dict = HashJoinDictUtil::ExtractDictionary(
+ opt_non_empty_batch->values[key_to_input.get(i)]);
+ RETURN_NOT_OK(remap_imp_[i].Init(ctx, dict, dict_type.index_type(),
+ dict_type.value_type()));
+ }
+ }
+ } else {
+ for (int i = 0; i < num_keys; ++i) {
+ const std::shared_ptr<DataType>& data_type =
+ proj_map.data_type(HashJoinProjection::KEY, i);
+ if (data_type->id() == Type::DICTIONARY) {
+ const auto& dict_type = checked_cast<const DictionaryType&>(*data_type);
+ RETURN_NOT_OK(remap_imp_[i].Init(ctx, nullptr, dict_type.index_type(),
+ dict_type.value_type()));
+ }
+ }
+ }
+ return Status::OK();
+}
+
+void HashJoinDictBuildMulti::InitEncoder(
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map, RowEncoder* encoder,
+ ExecContext* ctx) {
+ int num_cols = proj_map.num_cols(HashJoinProjection::KEY);
+ std::vector<ValueDescr> data_types(num_cols);
+ for (int icol = 0; icol < num_cols; ++icol) {
+ std::shared_ptr<DataType> data_type =
+ proj_map.data_type(HashJoinProjection::KEY, icol);
+ if (HashJoinDictBuild::KeyNeedsProcessing(data_type)) {
+ data_type = HashJoinDictBuild::DataTypeAfterRemapping();
+ }
+ data_types[icol] = ValueDescr(data_type, ValueDescr::ARRAY);
+ }
+ encoder->Init(data_types, ctx);
+}
+
+Status HashJoinDictBuildMulti::EncodeBatch(
+ size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map,
+ const ExecBatch& batch, RowEncoder* encoder, ExecContext* ctx) const {
+ ExecBatch projected({}, batch.length);
+ int num_cols = proj_map.num_cols(HashJoinProjection::KEY);
+ projected.values.resize(num_cols);
+
+ auto to_input = proj_map.map(HashJoinProjection::KEY, HashJoinProjection::INPUT);
+ for (int icol = 0; icol < num_cols; ++icol) {
+ projected.values[icol] = batch.values[to_input.get(icol)];
+
+ if (needs_remap_[icol]) {
+ ARROW_ASSIGN_OR_RAISE(
+ projected.values[icol],
+ remap_imp_[icol].RemapInput(ctx, projected.values[icol], batch.length,
+ proj_map.data_type(HashJoinProjection::KEY, icol)));
+ }
+ }
+ return encoder->EncodeAndAppend(projected);
+}
+
+Status HashJoinDictBuildMulti::PostDecode(
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map,
+ ExecBatch* decoded_key_batch, ExecContext* ctx) {
+ // Post process build side keys that use dictionary
+ int num_keys = proj_map.num_cols(HashJoinProjection::KEY);
+ for (int i = 0; i < num_keys; ++i) {
+ if (needs_remap_[i]) {
+ ARROW_ASSIGN_OR_RAISE(
+ decoded_key_batch->values[i],
+ remap_imp_[i].RemapOutput(*decoded_key_batch->values[i].array(), ctx));
+ }
+ }
+ return Status::OK();
+}
+
+void HashJoinDictProbeMulti::Init(size_t num_threads) {
+ local_states_.resize(num_threads);
+ for (size_t i = 0; i < local_states_.size(); ++i) {
+ local_states_[i].is_initialized = false;
+ }
+}
+
+bool HashJoinDictProbeMulti::BatchRemapNeeded(
+ size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, ExecContext* ctx) {
+ InitLocalStateIfNeeded(thread_index, proj_map_probe, proj_map_build, ctx);
+ return local_states_[thread_index].any_needs_remap;
+}
+
+void HashJoinDictProbeMulti::InitLocalStateIfNeeded(
+ size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, ExecContext* ctx) {
+ ThreadLocalState& local_state = local_states_[thread_index];
+
+ // Check if we need to remap any of the input keys because of dictionary encoding
+ // on either side of the join
+ //
+ int num_cols = proj_map_probe.num_cols(HashJoinProjection::KEY);
+ local_state.any_needs_remap = false;
+ local_state.needs_remap.resize(num_cols);
+ local_state.remap_imp.resize(num_cols);
+ for (int i = 0; i < num_cols; ++i) {
+ local_state.needs_remap[i] = HashJoinDictProbe::KeyNeedsProcessing(
+ proj_map_probe.data_type(HashJoinProjection::KEY, i),
+ proj_map_build.data_type(HashJoinProjection::KEY, i));
+ if (local_state.needs_remap[i]) {
+ local_state.any_needs_remap = true;
+ }
+ }
+
+ if (local_state.any_needs_remap) {
+ InitEncoder(proj_map_probe, proj_map_build, &local_state.post_remap_encoder, ctx);
+ }
+}
+
+void HashJoinDictProbeMulti::InitEncoder(
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, RowEncoder* encoder,
+ ExecContext* ctx) {
+ int num_cols = proj_map_probe.num_cols(HashJoinProjection::KEY);
+ std::vector<ValueDescr> data_types(num_cols);
+ for (int icol = 0; icol < num_cols; ++icol) {
+ std::shared_ptr<DataType> data_type =
+ proj_map_probe.data_type(HashJoinProjection::KEY, icol);
+ std::shared_ptr<DataType> build_data_type =
+ proj_map_build.data_type(HashJoinProjection::KEY, icol);
+ if (HashJoinDictProbe::KeyNeedsProcessing(data_type, build_data_type)) {
+ data_type = HashJoinDictProbe::DataTypeAfterRemapping(build_data_type);
+ }
+ data_types[icol] = ValueDescr(data_type, ValueDescr::ARRAY);
+ }
+ encoder->Init(data_types, ctx);
+}
+
+Status HashJoinDictProbeMulti::EncodeBatch(
+ size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map_build,
+ const HashJoinDictBuildMulti& dict_build, const ExecBatch& batch,
+ RowEncoder** out_encoder, ExecBatch* opt_out_key_batch, ExecContext* ctx) {
+ ThreadLocalState& local_state = local_states_[thread_index];
+ InitLocalStateIfNeeded(thread_index, proj_map_probe, proj_map_build, ctx);
+
+ ExecBatch projected({}, batch.length);
+ int num_cols = proj_map_probe.num_cols(HashJoinProjection::KEY);
+ projected.values.resize(num_cols);
+
+ auto to_input = proj_map_probe.map(HashJoinProjection::KEY, HashJoinProjection::INPUT);
+ for (int icol = 0; icol < num_cols; ++icol) {
+ projected.values[icol] = batch.values[to_input.get(icol)];
+
+ if (local_state.needs_remap[icol]) {
+ ARROW_ASSIGN_OR_RAISE(
+ projected.values[icol],
+ local_state.remap_imp[icol].RemapInput(
+ &(dict_build.get_dict_build(icol)), projected.values[icol], batch.length,
+ proj_map_probe.data_type(HashJoinProjection::KEY, icol),
+ proj_map_build.data_type(HashJoinProjection::KEY, icol), ctx));
+ }
+ }
+
+ if (opt_out_key_batch) {
+ *opt_out_key_batch = projected;
+ }
+
+ local_state.post_remap_encoder.Clear();
+ RETURN_NOT_OK(local_state.post_remap_encoder.EncodeAndAppend(projected));
+ *out_encoder = &local_state.post_remap_encoder;
+
+ return Status::OK();
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/hash_join_dict.h b/src/arrow/cpp/src/arrow/compute/exec/hash_join_dict.h
new file mode 100644
index 000000000..26605cc44
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/hash_join_dict.h
@@ -0,0 +1,315 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <unordered_map>
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/schema_util.h"
+#include "arrow/compute/kernels/row_encoder.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+
+// This file contains hash join logic related to handling of dictionary encoded key
+// columns.
+//
+// A key column from probe side of the join can be matched against a key column from build
+// side of the join, as long as the underlying value types are equal. That means that:
+// - both scalars and arrays can be used and even mixed in the same column
+// - dictionary column can be matched against non-dictionary column if underlying value
+// types are equal
+// - dictionary column can be matched against dictionary column with a different index
+// type, and potentially using a different dictionary, if underlying value types are equal
+//
+// We currently require in hash join that for all dictionary encoded columns, the same
+// dictionary is used in all input exec batches.
+//
+// In order to allow matching columns with different dictionaries, different dictionary
+// index types, and dictionary key against non-dictionary key, internally comparisons will
+// be evaluated after remapping values on both sides of the join to a common
+// representation (which will be called "unified representation"). This common
+// representation is a column of int32() type (not a dictionary column). It represents an
+// index in the unified dictionary computed for the (only) dictionary present on build
+// side (an empty dictionary is still created for an empty build side). Null value is
+// always represented in this common representation as null int32 value, unified
+// dictionary will never contain a null value (so there is no ambiguity of representing
+// nulls as either index to a null entry in the dictionary or null index).
+//
+// Unified dictionary represents values present on build side. There may be values on
+// probe side that are not present in it. All such values, that are not null, are mapped
+// in the common representation to a special constant kMissingValueId.
+//
+
+namespace arrow {
+namespace compute {
+
+using internal::RowEncoder;
+
+/// Helper class with operations that are stateless and common to processing of dictionary
+/// keys on both build and probe side.
+class HashJoinDictUtil {
+ public:
+ // Null values in unified representation are always represented as null that has
+ // corresponding integer set to this constant
+ static constexpr int32_t kNullId = 0;
+ // Constant representing a value, that is not null, missing on the build side, in
+ // unified representation.
+ static constexpr int32_t kMissingValueId = -1;
+
+ // Check if data types of corresponding pair of key column on build and probe side are
+ // compatible
+ static bool KeyDataTypesValid(const std::shared_ptr<DataType>& probe_data_type,
+ const std::shared_ptr<DataType>& build_data_type);
+
+ // Input must be dictionary array or dictionary scalar.
+ // A precomputed and provided here lookup table in the form of int32() array will be
+ // used to remap input indices to unified representation.
+ //
+ static Result<std::shared_ptr<ArrayData>> IndexRemapUsingLUT(
+ ExecContext* ctx, const Datum& indices, int64_t batch_length,
+ const std::shared_ptr<ArrayData>& map_array,
+ const std::shared_ptr<DataType>& data_type);
+
+ // Return int32() array that contains indices of input dictionary array or scalar after
+ // type casting.
+ static Result<std::shared_ptr<ArrayData>> ConvertToInt32(
+ const std::shared_ptr<DataType>& from_type, const Datum& input,
+ int64_t batch_length, ExecContext* ctx);
+
+ // Return an array that contains elements of input int32() array after casting to a
+ // given integer type. This is used for mapping unified representation stored in the
+ // hash table on build side back to original input data type of hash join, when
+ // outputting hash join results to parent exec node.
+ //
+ static Result<std::shared_ptr<ArrayData>> ConvertFromInt32(
+ const std::shared_ptr<DataType>& to_type, const Datum& input, int64_t batch_length,
+ ExecContext* ctx);
+
+ // Return dictionary referenced in either dictionary array or dictionary scalar
+ static std::shared_ptr<Array> ExtractDictionary(const Datum& data);
+};
+
+/// Implements processing of dictionary arrays/scalars in key columns on the build side of
+/// a hash join.
+/// Each instance of this class corresponds to a single column and stores and
+/// processes only the information related to that column.
+/// Const methods are thread-safe, non-const methods are not (the caller must make sure
+/// that only one thread at any time will access them).
+///
+class HashJoinDictBuild {
+ public:
+ // Returns true if the key column (described in input by its data type) requires any
+ // pre- or post-processing related to handling dictionaries.
+ //
+ static bool KeyNeedsProcessing(const std::shared_ptr<DataType>& build_data_type) {
+ return (build_data_type->id() == Type::DICTIONARY);
+ }
+
+ // Data type of unified representation
+ static std::shared_ptr<DataType> DataTypeAfterRemapping() { return int32(); }
+
+ // Should be called only once in hash join, before processing any build or probe
+ // batches.
+ //
+ // Takes a pointer to the dictionary for a corresponding key column on the build side as
+ // an input. If the build side is empty, it still needs to be called, but with
+ // dictionary pointer set to null.
+ //
+ // Currently it is required that all input batches on build side share the same
+ // dictionary. For each input batch during its pre-processing, dictionary will be
+ // checked and error will be returned if it is different then the one provided in the
+ // call to this method.
+ //
+ // Unifies the dictionary. The order of the values is still preserved.
+ // Null and duplicate entries are removed. If the dictionary is already unified, its
+ // copy will be produced and stored within this class.
+ //
+ // Prepares the mapping from ids within original dictionary to the ids in the resulting
+ // dictionary. This is used later on to pre-process (map to unified representation) key
+ // column on build side.
+ //
+ // Prepares the reverse mapping (in the form of hash table) from values to the ids in
+ // the resulting dictionary. This will be used later on to pre-process (map to unified
+ // representation) key column on probe side. Values on probe side that are not present
+ // in the original dictionary will be mapped to a special constant kMissingValueId. The
+ // exception is made for nulls, which get always mapped to nulls (both when null is
+ // represented as a dictionary id pointing to a null and a null dictionary id).
+ //
+ Status Init(ExecContext* ctx, std::shared_ptr<Array> dictionary,
+ std::shared_ptr<DataType> index_type, std::shared_ptr<DataType> value_type);
+
+ // Remap array or scalar values into unified representation (array of int32()).
+ // Outputs kMissingValueId if input value is not found in the unified dictionary.
+ // Outputs null for null input value (with corresponding data set to kNullId).
+ //
+ Result<std::shared_ptr<ArrayData>> RemapInputValues(ExecContext* ctx,
+ const Datum& values,
+ int64_t batch_length) const;
+
+ // Remap dictionary array or dictionary scalar on build side to unified representation.
+ // Dictionary referenced in the input must match the dictionary that was
+ // given during initialization.
+ // The output is a dictionary array that references unified dictionary.
+ //
+ Result<std::shared_ptr<ArrayData>> RemapInput(
+ ExecContext* ctx, const Datum& indices, int64_t batch_length,
+ const std::shared_ptr<DataType>& data_type) const;
+
+ // Outputs dictionary array referencing unified dictionary, given an array with 32-bit
+ // ids.
+ // Used to post-process values looked up in a hash table on build side of the hash join
+ // before outputting to the parent exec node.
+ //
+ Result<std::shared_ptr<ArrayData>> RemapOutput(const ArrayData& indices32Bit,
+ ExecContext* ctx) const;
+
+ // Release shared pointers and memory
+ void CleanUp();
+
+ private:
+ // Data type of dictionary ids for the input dictionary on build side
+ std::shared_ptr<DataType> index_type_;
+ // Data type of values for the input dictionary on build side
+ std::shared_ptr<DataType> value_type_;
+ // Mapping from (encoded as string) values to the ids in unified dictionary
+ std::unordered_map<std::string, int32_t> hash_table_;
+ // Mapping from input dictionary ids to unified dictionary ids
+ std::shared_ptr<ArrayData> remapped_ids_;
+ // Input dictionary
+ std::shared_ptr<Array> dictionary_;
+ // Unified dictionary
+ std::shared_ptr<ArrayData> unified_dictionary_;
+};
+
+/// Implements processing of dictionary arrays/scalars in key columns on the probe side of
+/// a hash join.
+/// Each instance of this class corresponds to a single column and stores and
+/// processes only the information related to that column.
+/// It is not thread-safe - every participating thread should use its own instance of
+/// this class.
+///
+class HashJoinDictProbe {
+ public:
+ static bool KeyNeedsProcessing(const std::shared_ptr<DataType>& probe_data_type,
+ const std::shared_ptr<DataType>& build_data_type);
+
+ // Data type of the result of remapping input key column.
+ //
+ // The result of remapping is what is used in hash join for matching keys on build and
+ // probe side. The exact data types may be different, as described below, and therefore
+ // a common representation is needed for simplifying comparisons of pairs of keys on
+ // both sides.
+ //
+ // We support matching key that is of non-dictionary type with key that is of dictionary
+ // type, as long as the underlying value types are equal. We support matching when both
+ // keys are of dictionary type, regardless whether underlying dictionary index types are
+ // the same or not.
+ //
+ static std::shared_ptr<DataType> DataTypeAfterRemapping(
+ const std::shared_ptr<DataType>& build_data_type);
+
+ // Should only be called if KeyNeedsProcessing method returns true for a pair of
+ // corresponding key columns from build and probe side.
+ // Converts values in order to match the common representation for
+ // both build and probe side used in hash table comparison.
+ // Supports arrays and scalars as input.
+ // Argument opt_build_side should be null if dictionary key on probe side is matched
+ // with non-dictionary key on build side.
+ //
+ Result<std::shared_ptr<ArrayData>> RemapInput(
+ const HashJoinDictBuild* opt_build_side, const Datum& data, int64_t batch_length,
+ const std::shared_ptr<DataType>& probe_data_type,
+ const std::shared_ptr<DataType>& build_data_type, ExecContext* ctx);
+
+ void CleanUp();
+
+ private:
+ // May be null if probe side key is non-dictionary. Otherwise it is used to verify that
+ // only a single dictionary is referenced in exec batch on probe side of hash join.
+ std::shared_ptr<Array> dictionary_;
+ // Mapping from dictionary on probe side of hash join (if it is used) to unified
+ // representation.
+ std::shared_ptr<ArrayData> remapped_ids_;
+ // Encoder of key columns that uses unified representation instead of original data type
+ // for key columns that need to use it (have dictionaries on either side of the join).
+ internal::RowEncoder encoder_;
+};
+
+// Encapsulates dictionary handling logic for build side of hash join.
+//
+class HashJoinDictBuildMulti {
+ public:
+ Status Init(const SchemaProjectionMaps<HashJoinProjection>& proj_map,
+ const ExecBatch* opt_non_empty_batch, ExecContext* ctx);
+ static void InitEncoder(const SchemaProjectionMaps<HashJoinProjection>& proj_map,
+ RowEncoder* encoder, ExecContext* ctx);
+ Status EncodeBatch(size_t thread_index,
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map,
+ const ExecBatch& batch, RowEncoder* encoder, ExecContext* ctx) const;
+ Status PostDecode(const SchemaProjectionMaps<HashJoinProjection>& proj_map,
+ ExecBatch* decoded_key_batch, ExecContext* ctx);
+ const HashJoinDictBuild& get_dict_build(int icol) const { return remap_imp_[icol]; }
+
+ private:
+ std::vector<bool> needs_remap_;
+ std::vector<HashJoinDictBuild> remap_imp_;
+};
+
+// Encapsulates dictionary handling logic for probe side of hash join
+//
+class HashJoinDictProbeMulti {
+ public:
+ void Init(size_t num_threads);
+ bool BatchRemapNeeded(size_t thread_index,
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map_build,
+ ExecContext* ctx);
+ Status EncodeBatch(size_t thread_index,
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map_build,
+ const HashJoinDictBuildMulti& dict_build, const ExecBatch& batch,
+ RowEncoder** out_encoder, ExecBatch* opt_out_key_batch,
+ ExecContext* ctx);
+
+ private:
+ void InitLocalStateIfNeeded(
+ size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, ExecContext* ctx);
+ static void InitEncoder(const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map_build,
+ RowEncoder* encoder, ExecContext* ctx);
+ struct ThreadLocalState {
+ bool is_initialized;
+ // Whether any key column needs remapping (because of dictionaries used) before doing
+ // join hash table lookups
+ bool any_needs_remap;
+ // Whether each key column needs remapping before doing join hash table lookups
+ std::vector<bool> needs_remap;
+ std::vector<HashJoinDictProbe> remap_imp;
+ // Encoder of key columns that uses unified representation instead of original data
+ // type for key columns that need to use it (have dictionaries on either side of the
+ // join).
+ RowEncoder post_remap_encoder;
+ };
+ std::vector<ThreadLocalState> local_states_;
+};
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/hash_join_node.cc b/src/arrow/cpp/src/arrow/compute/exec/hash_join_node.cc
new file mode 100644
index 000000000..4bccb7610
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/hash_join_node.cc
@@ -0,0 +1,469 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <unordered_set>
+
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/hash_join.h"
+#include "arrow/compute/exec/hash_join_dict.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/schema_util.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+// Check if a type is supported in a join (as either a key or non-key column)
+bool HashJoinSchema::IsTypeSupported(const DataType& type) {
+ const Type::type id = type.id();
+ if (id == Type::DICTIONARY) {
+ return IsTypeSupported(*checked_cast<const DictionaryType&>(type).value_type());
+ }
+ return is_fixed_width(id) || is_binary_like(id) || is_large_binary_like(id);
+}
+
+Result<std::vector<FieldRef>> HashJoinSchema::VectorDiff(const Schema& schema,
+ const std::vector<FieldRef>& a,
+ const std::vector<FieldRef>& b) {
+ std::unordered_set<int> b_paths;
+ for (size_t i = 0; i < b.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto match, b[i].FindOne(schema));
+ b_paths.insert(match[0]);
+ }
+
+ std::vector<FieldRef> result;
+
+ for (size_t i = 0; i < a.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto match, a[i].FindOne(schema));
+ bool is_found = (b_paths.find(match[0]) != b_paths.end());
+ if (!is_found) {
+ result.push_back(a[i]);
+ }
+ }
+
+ return result;
+}
+
+Status HashJoinSchema::Init(JoinType join_type, const Schema& left_schema,
+ const std::vector<FieldRef>& left_keys,
+ const Schema& right_schema,
+ const std::vector<FieldRef>& right_keys,
+ const std::string& left_field_name_prefix,
+ const std::string& right_field_name_prefix) {
+ std::vector<FieldRef> left_output;
+ if (join_type != JoinType::RIGHT_SEMI && join_type != JoinType::RIGHT_ANTI) {
+ const FieldVector& left_fields = left_schema.fields();
+ left_output.resize(left_fields.size());
+ for (size_t i = 0; i < left_fields.size(); ++i) {
+ left_output[i] = FieldRef(static_cast<int>(i));
+ }
+ }
+ // Repeat the same for the right side
+ std::vector<FieldRef> right_output;
+ if (join_type != JoinType::LEFT_SEMI && join_type != JoinType::LEFT_ANTI) {
+ const FieldVector& right_fields = right_schema.fields();
+ right_output.resize(right_fields.size());
+ for (size_t i = 0; i < right_fields.size(); ++i) {
+ right_output[i] = FieldRef(static_cast<int>(i));
+ }
+ }
+ return Init(join_type, left_schema, left_keys, left_output, right_schema, right_keys,
+ right_output, left_field_name_prefix, right_field_name_prefix);
+}
+
+Status HashJoinSchema::Init(JoinType join_type, const Schema& left_schema,
+ const std::vector<FieldRef>& left_keys,
+ const std::vector<FieldRef>& left_output,
+ const Schema& right_schema,
+ const std::vector<FieldRef>& right_keys,
+ const std::vector<FieldRef>& right_output,
+ const std::string& left_field_name_prefix,
+ const std::string& right_field_name_prefix) {
+ RETURN_NOT_OK(ValidateSchemas(join_type, left_schema, left_keys, left_output,
+ right_schema, right_keys, right_output,
+ left_field_name_prefix, right_field_name_prefix));
+
+ std::vector<HashJoinProjection> handles;
+ std::vector<const std::vector<FieldRef>*> field_refs;
+
+ handles.push_back(HashJoinProjection::KEY);
+ field_refs.push_back(&left_keys);
+ ARROW_ASSIGN_OR_RAISE(auto left_payload,
+ VectorDiff(left_schema, left_output, left_keys));
+ handles.push_back(HashJoinProjection::PAYLOAD);
+ field_refs.push_back(&left_payload);
+ handles.push_back(HashJoinProjection::OUTPUT);
+ field_refs.push_back(&left_output);
+
+ RETURN_NOT_OK(
+ proj_maps[0].Init(HashJoinProjection::INPUT, left_schema, handles, field_refs));
+
+ handles.clear();
+ field_refs.clear();
+
+ handles.push_back(HashJoinProjection::KEY);
+ field_refs.push_back(&right_keys);
+ ARROW_ASSIGN_OR_RAISE(auto right_payload,
+ VectorDiff(right_schema, right_output, right_keys));
+ handles.push_back(HashJoinProjection::PAYLOAD);
+ field_refs.push_back(&right_payload);
+ handles.push_back(HashJoinProjection::OUTPUT);
+ field_refs.push_back(&right_output);
+
+ RETURN_NOT_OK(
+ proj_maps[1].Init(HashJoinProjection::INPUT, right_schema, handles, field_refs));
+
+ return Status::OK();
+}
+
+Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_schema,
+ const std::vector<FieldRef>& left_keys,
+ const std::vector<FieldRef>& left_output,
+ const Schema& right_schema,
+ const std::vector<FieldRef>& right_keys,
+ const std::vector<FieldRef>& right_output,
+ const std::string& left_field_name_prefix,
+ const std::string& right_field_name_prefix) {
+ // Checks for key fields:
+ // 1. Key field refs must match exactly one input field
+ // 2. Same number of key fields on left and right
+ // 3. At least one key field
+ // 4. Equal data types for corresponding key fields
+ // 5. Some data types may not be allowed in a key field or non-key field
+ //
+ if (left_keys.size() != right_keys.size()) {
+ return Status::Invalid("Different number of key fields on left (", left_keys.size(),
+ ") and right (", right_keys.size(), ") side of the join");
+ }
+ if (left_keys.size() < 1) {
+ return Status::Invalid("Join key cannot be empty");
+ }
+ for (size_t i = 0; i < left_keys.size() + right_keys.size(); ++i) {
+ bool left_side = i < left_keys.size();
+ const FieldRef& field_ref =
+ left_side ? left_keys[i] : right_keys[i - left_keys.size()];
+ Result<FieldPath> result = field_ref.FindOne(left_side ? left_schema : right_schema);
+ if (!result.ok()) {
+ return Status::Invalid("No match or multiple matches for key field reference ",
+ field_ref.ToString(), left_side ? " on left " : " on right ",
+ "side of the join");
+ }
+ const FieldPath& match = result.ValueUnsafe();
+ const std::shared_ptr<DataType>& type =
+ (left_side ? left_schema.fields() : right_schema.fields())[match[0]]->type();
+ if (!IsTypeSupported(*type)) {
+ return Status::Invalid("Data type ", *type, " is not supported in join key field");
+ }
+ }
+ for (size_t i = 0; i < left_keys.size(); ++i) {
+ const FieldRef& left_ref = left_keys[i];
+ const FieldRef& right_ref = right_keys[i];
+ int left_id = left_ref.FindOne(left_schema).ValueUnsafe()[0];
+ int right_id = right_ref.FindOne(right_schema).ValueUnsafe()[0];
+ const std::shared_ptr<DataType>& left_type = left_schema.fields()[left_id]->type();
+ const std::shared_ptr<DataType>& right_type = right_schema.fields()[right_id]->type();
+ if (!HashJoinDictUtil::KeyDataTypesValid(left_type, right_type)) {
+ return Status::Invalid(
+ "Incompatible data types for corresponding join field keys: ",
+ left_ref.ToString(), " of type ", left_type->ToString(), " and ",
+ right_ref.ToString(), " of type ", right_type->ToString());
+ }
+ }
+ for (const auto& field : left_schema.fields()) {
+ const auto& type = *field->type();
+ if (!IsTypeSupported(type)) {
+ return Status::Invalid("Data type ", type,
+ " is not supported in join non-key field");
+ }
+ }
+ for (const auto& field : right_schema.fields()) {
+ const auto& type = *field->type();
+ if (!IsTypeSupported(type)) {
+ return Status::Invalid("Data type ", type,
+ " is not supported in join non-key field");
+ }
+ }
+
+ // Check for output fields:
+ // 1. Output field refs must match exactly one input field
+ // 2. At least one output field
+ // 3. Dictionary type is not supported in an output field
+ // 4. Left semi/anti join (right semi/anti join) must not output fields from right
+ // (left)
+ // 5. No name collisions in output fields after adding (potentially empty)
+ // prefixes to left and right output
+ //
+ if (left_output.empty() && right_output.empty()) {
+ return Status::Invalid("Join must output at least one field");
+ }
+ if (join_type == JoinType::LEFT_SEMI || join_type == JoinType::LEFT_ANTI) {
+ if (!right_output.empty()) {
+ return Status::Invalid(
+ join_type == JoinType::LEFT_SEMI ? "Left semi join " : "Left anti-semi join ",
+ "may not output fields from right side");
+ }
+ }
+ if (join_type == JoinType::RIGHT_SEMI || join_type == JoinType::RIGHT_ANTI) {
+ if (!left_output.empty()) {
+ return Status::Invalid(join_type == JoinType::RIGHT_SEMI ? "Right semi join "
+ : "Right anti-semi join ",
+ "may not output fields from left side");
+ }
+ }
+ for (size_t i = 0; i < left_output.size() + right_output.size(); ++i) {
+ bool left_side = i < left_output.size();
+ const FieldRef& field_ref =
+ left_side ? left_output[i] : right_output[i - left_output.size()];
+ Result<FieldPath> result = field_ref.FindOne(left_side ? left_schema : right_schema);
+ if (!result.ok()) {
+ return Status::Invalid("No match or multiple matches for output field reference ",
+ field_ref.ToString(), left_side ? " on left " : " on right ",
+ "side of the join");
+ }
+ }
+ return Status::OK();
+}
+
+std::shared_ptr<Schema> HashJoinSchema::MakeOutputSchema(
+ const std::string& left_field_name_prefix,
+ const std::string& right_field_name_prefix) {
+ std::vector<std::shared_ptr<Field>> fields;
+ int left_size = proj_maps[0].num_cols(HashJoinProjection::OUTPUT);
+ int right_size = proj_maps[1].num_cols(HashJoinProjection::OUTPUT);
+ fields.resize(left_size + right_size);
+
+ for (int i = 0; i < left_size + right_size; ++i) {
+ bool is_left = (i < left_size);
+ int side = (is_left ? 0 : 1);
+ int input_field_id = proj_maps[side]
+ .map(HashJoinProjection::OUTPUT, HashJoinProjection::INPUT)
+ .get(is_left ? i : i - left_size);
+ const std::string& input_field_name =
+ proj_maps[side].field_name(HashJoinProjection::INPUT, input_field_id);
+ const std::shared_ptr<DataType>& input_data_type =
+ proj_maps[side].data_type(HashJoinProjection::INPUT, input_field_id);
+
+ std::string output_field_name =
+ (is_left ? left_field_name_prefix : right_field_name_prefix) + input_field_name;
+
+ // All fields coming out of join are marked as nullable.
+ fields[i] =
+ std::make_shared<Field>(output_field_name, input_data_type, true /*nullable*/);
+ }
+ return std::make_shared<Schema>(std::move(fields));
+}
+
+class HashJoinNode : public ExecNode {
+ public:
+ HashJoinNode(ExecPlan* plan, NodeVector inputs, const HashJoinNodeOptions& join_options,
+ std::shared_ptr<Schema> output_schema,
+ std::unique_ptr<HashJoinSchema> schema_mgr,
+ std::unique_ptr<HashJoinImpl> impl)
+ : ExecNode(plan, inputs, {"left", "right"},
+ /*output_schema=*/std::move(output_schema),
+ /*num_outputs=*/1),
+ join_type_(join_options.join_type),
+ key_cmp_(join_options.key_cmp),
+ schema_mgr_(std::move(schema_mgr)),
+ impl_(std::move(impl)) {
+ complete_.store(false);
+ }
+
+ static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options) {
+ // Number of input exec nodes must be 2
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 2, "HashJoinNode"));
+
+ std::unique_ptr<HashJoinSchema> schema_mgr =
+ ::arrow::internal::make_unique<HashJoinSchema>();
+
+ const auto& join_options = checked_cast<const HashJoinNodeOptions&>(options);
+
+ // This will also validate input schemas
+ if (join_options.output_all) {
+ RETURN_NOT_OK(schema_mgr->Init(
+ join_options.join_type, *(inputs[0]->output_schema()), join_options.left_keys,
+ *(inputs[1]->output_schema()), join_options.right_keys,
+ join_options.output_prefix_for_left, join_options.output_prefix_for_right));
+ } else {
+ RETURN_NOT_OK(schema_mgr->Init(
+ join_options.join_type, *(inputs[0]->output_schema()), join_options.left_keys,
+ join_options.left_output, *(inputs[1]->output_schema()),
+ join_options.right_keys, join_options.right_output,
+ join_options.output_prefix_for_left, join_options.output_prefix_for_right));
+ }
+
+ // Generate output schema
+ std::shared_ptr<Schema> output_schema = schema_mgr->MakeOutputSchema(
+ join_options.output_prefix_for_left, join_options.output_prefix_for_right);
+
+ // Create hash join implementation object
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<HashJoinImpl> impl, HashJoinImpl::MakeBasic());
+
+ return plan->EmplaceNode<HashJoinNode>(plan, inputs, join_options,
+ std::move(output_schema),
+ std::move(schema_mgr), std::move(impl));
+ }
+
+ const char* kind_name() const override { return "HashJoinNode"; }
+
+ void InputReceived(ExecNode* input, ExecBatch batch) override {
+ ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end());
+
+ if (complete_.load()) {
+ return;
+ }
+
+ size_t thread_index = thread_indexer_();
+ int side = (input == inputs_[0]) ? 0 : 1;
+ {
+ Status status = impl_->InputReceived(thread_index, side, std::move(batch));
+ if (!status.ok()) {
+ StopProducing();
+ ErrorIfNotOk(status);
+ return;
+ }
+ }
+ if (batch_count_[side].Increment()) {
+ Status status = impl_->InputFinished(thread_index, side);
+ if (!status.ok()) {
+ StopProducing();
+ ErrorIfNotOk(status);
+ return;
+ }
+ }
+ }
+
+ void ErrorReceived(ExecNode* input, Status error) override {
+ DCHECK_EQ(input, inputs_[0]);
+ StopProducing();
+ outputs_[0]->ErrorReceived(this, std::move(error));
+ }
+
+ void InputFinished(ExecNode* input, int total_batches) override {
+ ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end());
+
+ size_t thread_index = thread_indexer_();
+ int side = (input == inputs_[0]) ? 0 : 1;
+
+ if (batch_count_[side].SetTotal(total_batches)) {
+ Status status = impl_->InputFinished(thread_index, side);
+ if (!status.ok()) {
+ StopProducing();
+ ErrorIfNotOk(status);
+ return;
+ }
+ }
+ }
+
+ Status StartProducing() override {
+ finished_ = Future<>::Make();
+
+ bool use_sync_execution = !(plan_->exec_context()->executor());
+ size_t num_threads = use_sync_execution ? 1 : thread_indexer_.Capacity();
+
+ RETURN_NOT_OK(impl_->Init(
+ plan_->exec_context(), join_type_, use_sync_execution, num_threads,
+ schema_mgr_.get(), key_cmp_,
+ [this](ExecBatch batch) { this->OutputBatchCallback(batch); },
+ [this](int64_t total_num_batches) { this->FinishedCallback(total_num_batches); },
+ [this](std::function<Status(size_t)> func) -> Status {
+ return this->ScheduleTaskCallback(std::move(func));
+ }));
+ return Status::OK();
+ }
+
+ void PauseProducing(ExecNode* output) override {}
+
+ void ResumeProducing(ExecNode* output) override {}
+
+ void StopProducing(ExecNode* output) override {
+ DCHECK_EQ(output, outputs_[0]);
+ StopProducing();
+ }
+
+ void StopProducing() override {
+ bool expected = false;
+ if (complete_.compare_exchange_strong(expected, true)) {
+ for (auto&& input : inputs_) {
+ input->StopProducing(this);
+ }
+ impl_->Abort([this]() { finished_.MarkFinished(); });
+ }
+ }
+
+ Future<> finished() override { return finished_; }
+
+ private:
+ void OutputBatchCallback(ExecBatch batch) {
+ outputs_[0]->InputReceived(this, std::move(batch));
+ }
+
+ void FinishedCallback(int64_t total_num_batches) {
+ bool expected = false;
+ if (complete_.compare_exchange_strong(expected, true)) {
+ outputs_[0]->InputFinished(this, static_cast<int>(total_num_batches));
+ finished_.MarkFinished();
+ }
+ }
+
+ Status ScheduleTaskCallback(std::function<Status(size_t)> func) {
+ auto executor = plan_->exec_context()->executor();
+ if (executor) {
+ RETURN_NOT_OK(executor->Spawn([this, func] {
+ size_t thread_index = thread_indexer_();
+ Status status = func(thread_index);
+ if (!status.ok()) {
+ StopProducing();
+ ErrorIfNotOk(status);
+ return;
+ }
+ }));
+ } else {
+ // We should not get here in serial execution mode
+ ARROW_DCHECK(false);
+ }
+ return Status::OK();
+ }
+
+ private:
+ AtomicCounter batch_count_[2];
+ std::atomic<bool> complete_;
+ Future<> finished_ = Future<>::MakeFinished();
+ JoinType join_type_;
+ std::vector<JoinKeyCmp> key_cmp_;
+ ThreadIndexer thread_indexer_;
+ std::unique_ptr<HashJoinSchema> schema_mgr_;
+ std::unique_ptr<HashJoinImpl> impl_;
+};
+
+namespace internal {
+
+void RegisterHashJoinNode(ExecFactoryRegistry* registry) {
+ DCHECK_OK(registry->AddFactory("hashjoin", HashJoinNode::Make));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/src/arrow/cpp/src/arrow/compute/exec/hash_join_node_test.cc
new file mode 100644
index 000000000..9afddf3c5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/hash_join_node_test.cc
@@ -0,0 +1,1693 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gmock/gmock-matchers.h>
+
+#include <random>
+#include <unordered_set>
+
+#include "arrow/api.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/test_util.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/compute/kernels/row_encoder.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/pcg_random.h"
+#include "arrow/util/thread_pool.h"
+
+using testing::UnorderedElementsAreArray;
+
+namespace arrow {
+namespace compute {
+
+BatchesWithSchema GenerateBatchesFromString(
+ const std::shared_ptr<Schema>& schema,
+ const std::vector<util::string_view>& json_strings, int multiplicity = 1) {
+ BatchesWithSchema out_batches{{}, schema};
+
+ std::vector<ValueDescr> descrs;
+ for (auto&& field : schema->fields()) {
+ descrs.emplace_back(field->type());
+ }
+
+ for (auto&& s : json_strings) {
+ out_batches.batches.push_back(ExecBatchFromJSON(descrs, s));
+ }
+
+ size_t batch_count = out_batches.batches.size();
+ for (int repeat = 1; repeat < multiplicity; ++repeat) {
+ for (size_t i = 0; i < batch_count; ++i) {
+ out_batches.batches.push_back(out_batches.batches[i]);
+ }
+ }
+
+ return out_batches;
+}
+
+void CheckRunOutput(JoinType type, const BatchesWithSchema& l_batches,
+ const BatchesWithSchema& r_batches,
+ const std::vector<FieldRef>& left_keys,
+ const std::vector<FieldRef>& right_keys,
+ const BatchesWithSchema& exp_batches, bool parallel = false) {
+ auto exec_ctx = arrow::internal::make_unique<ExecContext>(
+ default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
+
+ HashJoinNodeOptions join_options{type, left_keys, right_keys};
+ Declaration join{"hashjoin", join_options};
+
+ // add left source
+ join.inputs.emplace_back(Declaration{
+ "source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel,
+ /*slow=*/false)}});
+ // add right source
+ join.inputs.emplace_back(Declaration{
+ "source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel,
+ /*slow=*/false)}});
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}})
+ .AddToPlan(plan.get()));
+
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));
+
+ ASSERT_OK_AND_ASSIGN(auto exp_table,
+ TableFromExecBatches(exp_batches.schema, exp_batches.batches));
+
+ ASSERT_OK_AND_ASSIGN(auto out_table, TableFromExecBatches(exp_batches.schema, res));
+
+ if (exp_table->num_rows() == 0) {
+ ASSERT_EQ(exp_table->num_rows(), out_table->num_rows());
+ } else {
+ std::vector<SortKey> sort_keys;
+ for (auto&& f : exp_batches.schema->fields()) {
+ sort_keys.emplace_back(f->name());
+ }
+ ASSERT_OK_AND_ASSIGN(auto exp_table_sort_ids,
+ SortIndices(exp_table, SortOptions(sort_keys)));
+ ASSERT_OK_AND_ASSIGN(auto exp_table_sorted, Take(exp_table, exp_table_sort_ids));
+ ASSERT_OK_AND_ASSIGN(auto out_table_sort_ids,
+ SortIndices(out_table, SortOptions(sort_keys)));
+ ASSERT_OK_AND_ASSIGN(auto out_table_sorted, Take(out_table, out_table_sort_ids));
+
+ AssertTablesEqual(*exp_table_sorted.table(), *out_table_sorted.table(),
+ /*same_chunk_layout=*/false, /*flatten=*/true);
+ }
+}
+
+void RunNonEmptyTest(JoinType type, bool parallel) {
+ auto l_schema = schema({field("l_i32", int32()), field("l_str", utf8())});
+ auto r_schema = schema({field("r_str", utf8()), field("r_i32", int32())});
+ BatchesWithSchema l_batches, r_batches, exp_batches;
+
+ int multiplicity = parallel ? 100 : 1;
+
+ l_batches = GenerateBatchesFromString(
+ l_schema,
+ {R"([[0,"d"], [1,"b"]])", R"([[2,"d"], [3,"a"], [4,"a"]])",
+ R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"},
+ multiplicity);
+
+ r_batches = GenerateBatchesFromString(
+ r_schema,
+ {R"([["f", 0], ["b", 1], ["b", 2]])", R"([["c", 3], ["g", 4]])", R"([["e", 5]])"},
+ multiplicity);
+
+ switch (type) {
+ case JoinType::LEFT_SEMI:
+ exp_batches = GenerateBatchesFromString(
+ l_schema, {R"([[1,"b"]])", R"([])", R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"},
+ multiplicity);
+ break;
+ case JoinType::RIGHT_SEMI:
+ exp_batches = GenerateBatchesFromString(
+ r_schema, {R"([["b", 1], ["b", 2]])", R"([["c", 3]])", R"([["e", 5]])"},
+ multiplicity);
+ break;
+ case JoinType::LEFT_ANTI:
+ exp_batches = GenerateBatchesFromString(
+ l_schema, {R"([[0,"d"]])", R"([[2,"d"], [3,"a"], [4,"a"]])", R"([])"},
+ multiplicity);
+ break;
+ case JoinType::RIGHT_ANTI:
+ exp_batches = GenerateBatchesFromString(
+ r_schema, {R"([["f", 0]])", R"([["g", 4]])", R"([])"}, multiplicity);
+ break;
+ case JoinType::INNER:
+ case JoinType::LEFT_OUTER:
+ case JoinType::RIGHT_OUTER:
+ case JoinType::FULL_OUTER:
+ default:
+ FAIL() << "join type not implemented!";
+ }
+
+ CheckRunOutput(type, l_batches, r_batches,
+ /*left_keys=*/{{"l_str"}}, /*right_keys=*/{{"r_str"}}, exp_batches,
+ parallel);
+}
+
+void RunEmptyTest(JoinType type, bool parallel) {
+ auto l_schema = schema({field("l_i32", int32()), field("l_str", utf8())});
+ auto r_schema = schema({field("r_str", utf8()), field("r_i32", int32())});
+
+ int multiplicity = parallel ? 100 : 1;
+
+ BatchesWithSchema l_empty, r_empty, l_n_empty, r_n_empty;
+
+ l_empty = GenerateBatchesFromString(l_schema, {R"([])"}, multiplicity);
+ r_empty = GenerateBatchesFromString(r_schema, {R"([])"}, multiplicity);
+
+ l_n_empty =
+ GenerateBatchesFromString(l_schema, {R"([[0,"d"], [1,"b"]])"}, multiplicity);
+ r_n_empty = GenerateBatchesFromString(r_schema, {R"([["f", 0], ["b", 1], ["b", 2]])"},
+ multiplicity);
+
+ std::vector<FieldRef> l_keys{{"l_str"}};
+ std::vector<FieldRef> r_keys{{"r_str"}};
+
+ switch (type) {
+ case JoinType::LEFT_SEMI:
+ // both empty
+ CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, l_empty, parallel);
+ // right empty
+ CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, l_empty, parallel);
+ // left empty
+ CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, l_empty, parallel);
+ break;
+ case JoinType::RIGHT_SEMI:
+ // both empty
+ CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, r_empty, parallel);
+ // right empty
+ CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, r_empty, parallel);
+ // left empty
+ CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, r_empty, parallel);
+ break;
+ case JoinType::LEFT_ANTI:
+ // both empty
+ CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, l_empty, parallel);
+ // right empty
+ CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, l_n_empty, parallel);
+ // left empty
+ CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, l_empty, parallel);
+ break;
+ case JoinType::RIGHT_ANTI:
+ // both empty
+ CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, r_empty, parallel);
+ // right empty
+ CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, r_empty, parallel);
+ // left empty
+ CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, r_n_empty, parallel);
+ break;
+ case JoinType::INNER:
+ case JoinType::LEFT_OUTER:
+ case JoinType::RIGHT_OUTER:
+ case JoinType::FULL_OUTER:
+ default:
+ FAIL() << "join type not implemented!";
+ }
+}
+
+class HashJoinTest : public testing::TestWithParam<std::tuple<JoinType, bool>> {};
+
+INSTANTIATE_TEST_SUITE_P(
+ HashJoinTest, HashJoinTest,
+ ::testing::Combine(::testing::Values(JoinType::LEFT_SEMI, JoinType::RIGHT_SEMI,
+ JoinType::LEFT_ANTI, JoinType::RIGHT_ANTI),
+ ::testing::Values(false, true)));
+
+TEST_P(HashJoinTest, TestSemiJoins) {
+ RunNonEmptyTest(std::get<0>(GetParam()), std::get<1>(GetParam()));
+}
+
+TEST_P(HashJoinTest, TestSemiJoinsEmpty) {
+ RunEmptyTest(std::get<0>(GetParam()), std::get<1>(GetParam()));
+}
+
+class Random64Bit {
+ public:
+ explicit Random64Bit(random::SeedType seed) : rng_(seed) {}
+ uint64_t next() { return dist_(rng_); }
+ template <typename T>
+ inline T from_range(const T& min_val, const T& max_val) {
+ return static_cast<T>(min_val + (next() % (max_val - min_val + 1)));
+ }
+
+ private:
+ random::pcg32_fast rng_;
+ std::uniform_int_distribution<uint64_t> dist_;
+};
+
+struct RandomDataTypeConstraints {
+ int64_t data_type_enabled_mask;
+ // Null related
+ double min_null_probability;
+ double max_null_probability;
+ // Binary related
+ int min_binary_length;
+ int max_binary_length;
+ // String related
+ int min_string_length;
+ int max_string_length;
+
+ void Default() {
+ data_type_enabled_mask = kInt1 | kInt2 | kInt4 | kInt8 | kBool | kBinary | kString;
+ min_null_probability = 0.0;
+ max_null_probability = 0.2;
+ min_binary_length = 1;
+ max_binary_length = 40;
+ min_string_length = 0;
+ max_string_length = 40;
+ }
+
+ void OnlyInt(int int_size, bool allow_nulls) {
+ Default();
+ data_type_enabled_mask =
+ int_size == 8 ? kInt8 : int_size == 4 ? kInt4 : int_size == 2 ? kInt2 : kInt1;
+ if (!allow_nulls) {
+ max_null_probability = 0.0;
+ }
+ }
+
+ void OnlyString(bool allow_nulls) {
+ Default();
+ data_type_enabled_mask = kString;
+ if (!allow_nulls) {
+ max_null_probability = 0.0;
+ }
+ }
+
+ // Data type mask constants
+ static constexpr int64_t kInt1 = 1;
+ static constexpr int64_t kInt2 = 2;
+ static constexpr int64_t kInt4 = 4;
+ static constexpr int64_t kInt8 = 8;
+ static constexpr int64_t kBool = 16;
+ static constexpr int64_t kBinary = 32;
+ static constexpr int64_t kString = 64;
+};
+
+struct RandomDataType {
+ double null_probability;
+ bool is_fixed_length;
+ int fixed_length;
+ int min_string_length;
+ int max_string_length;
+
+ static RandomDataType Random(Random64Bit& rng,
+ const RandomDataTypeConstraints& constraints) {
+ RandomDataType result;
+ if ((constraints.data_type_enabled_mask & constraints.kString) != 0) {
+ if (constraints.data_type_enabled_mask != constraints.kString) {
+ // Both string and fixed length types enabled
+ // 50% chance of string
+ result.is_fixed_length = ((rng.next() % 2) == 0);
+ } else {
+ result.is_fixed_length = false;
+ }
+ } else {
+ result.is_fixed_length = true;
+ }
+ if (constraints.max_null_probability > 0.0) {
+ // 25% chance of no nulls
+ // Uniform distribution of null probability from min to max
+ result.null_probability = ((rng.next() % 4) == 0)
+ ? 0.0
+ : static_cast<double>(rng.next() % 1025) / 1024.0 *
+ (constraints.max_null_probability -
+ constraints.min_null_probability) +
+ constraints.min_null_probability;
+ } else {
+ result.null_probability = 0.0;
+ }
+ // Pick data type for fixed length
+ if (result.is_fixed_length) {
+ int log_type;
+ for (;;) {
+ log_type = rng.next() % 6;
+ if (constraints.data_type_enabled_mask & (1ULL << log_type)) {
+ break;
+ }
+ }
+ if ((1ULL << log_type) == constraints.kBinary) {
+ for (;;) {
+ result.fixed_length = rng.from_range(constraints.min_binary_length,
+ constraints.max_binary_length);
+ if (result.fixed_length != 1 && result.fixed_length != 2 &&
+ result.fixed_length != 4 && result.fixed_length != 8) {
+ break;
+ }
+ }
+ } else {
+ result.fixed_length =
+ ((1ULL << log_type) == constraints.kBool) ? 0 : (1ULL << log_type);
+ }
+ } else {
+ // Pick parameters for string
+ result.min_string_length =
+ rng.from_range(constraints.min_string_length, constraints.max_string_length);
+ result.max_string_length =
+ rng.from_range(constraints.min_string_length, constraints.max_string_length);
+ if (result.min_string_length > result.max_string_length) {
+ std::swap(result.min_string_length, result.max_string_length);
+ }
+ }
+ return result;
+ }
+};
+
+struct RandomDataTypeVector {
+ std::vector<RandomDataType> data_types;
+
+ void AddRandom(Random64Bit& rng, const RandomDataTypeConstraints& constraints) {
+ data_types.push_back(RandomDataType::Random(rng, constraints));
+ }
+
+ void Print() {
+ for (size_t i = 0; i < data_types.size(); ++i) {
+ if (!data_types[i].is_fixed_length) {
+ std::cout << "str[" << data_types[i].min_string_length << ".."
+ << data_types[i].max_string_length << "]";
+ SCOPED_TRACE("str[" + std::to_string(data_types[i].min_string_length) + ".." +
+ std::to_string(data_types[i].max_string_length) + "]");
+ } else {
+ std::cout << "int[" << data_types[i].fixed_length << "]";
+ SCOPED_TRACE("int[" + std::to_string(data_types[i].fixed_length) + "]");
+ }
+ }
+ std::cout << std::endl;
+ }
+};
+
+std::vector<std::shared_ptr<Array>> GenRandomRecords(
+ Random64Bit& rng, const std::vector<RandomDataType>& data_types, int num_rows) {
+ std::vector<std::shared_ptr<Array>> result;
+ random::RandomArrayGenerator rag(static_cast<random::SeedType>(rng.next()));
+ for (size_t i = 0; i < data_types.size(); ++i) {
+ if (data_types[i].is_fixed_length) {
+ switch (data_types[i].fixed_length) {
+ case 0:
+ result.push_back(rag.Boolean(num_rows, 0.5, data_types[i].null_probability));
+ break;
+ case 1:
+ result.push_back(rag.UInt8(num_rows, std::numeric_limits<uint8_t>::min(),
+ std::numeric_limits<uint8_t>::max(),
+ data_types[i].null_probability));
+ break;
+ case 2:
+ result.push_back(rag.UInt16(num_rows, std::numeric_limits<uint16_t>::min(),
+ std::numeric_limits<uint16_t>::max(),
+ data_types[i].null_probability));
+ break;
+ case 4:
+ result.push_back(rag.UInt32(num_rows, std::numeric_limits<uint32_t>::min(),
+ std::numeric_limits<uint32_t>::max(),
+ data_types[i].null_probability));
+ break;
+ case 8:
+ result.push_back(rag.UInt64(num_rows, std::numeric_limits<uint64_t>::min(),
+ std::numeric_limits<uint64_t>::max(),
+ data_types[i].null_probability));
+ break;
+ default:
+ result.push_back(rag.FixedSizeBinary(num_rows, data_types[i].fixed_length,
+ data_types[i].null_probability));
+ break;
+ }
+ } else {
+ result.push_back(rag.String(num_rows, data_types[i].min_string_length,
+ data_types[i].max_string_length,
+ data_types[i].null_probability));
+ }
+ }
+ return result;
+}
+
+// Index < 0 means appending null values to all columns.
+//
+void TakeUsingVector(ExecContext* ctx, const std::vector<std::shared_ptr<Array>>& input,
+ const std::vector<int32_t> indices,
+ std::vector<std::shared_ptr<Array>>* result) {
+ ASSERT_OK_AND_ASSIGN(
+ std::shared_ptr<Buffer> buf,
+ AllocateBuffer(indices.size() * sizeof(int32_t), ctx->memory_pool()));
+ int32_t* buf_indices = reinterpret_cast<int32_t*>(buf->mutable_data());
+ bool has_null_rows = false;
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < 0) {
+ buf_indices[i] = 0;
+ has_null_rows = true;
+ } else {
+ buf_indices[i] = indices[i];
+ }
+ }
+ std::shared_ptr<Array> indices_array = MakeArray(ArrayData::Make(
+ int32(), indices.size(), {nullptr, std::move(buf)}, /*null_count=*/0));
+
+ result->resize(input.size());
+ for (size_t i = 0; i < result->size(); ++i) {
+ ASSERT_OK_AND_ASSIGN(Datum new_array, Take(input[i], indices_array));
+ (*result)[i] = new_array.make_array();
+ }
+ if (has_null_rows) {
+ for (size_t i = 0; i < result->size(); ++i) {
+ if ((*result)[i]->data()->buffers[0] == NULLPTR) {
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> null_buf,
+ AllocateBitmap(indices.size(), ctx->memory_pool()));
+ uint8_t* non_nulls = null_buf->mutable_data();
+ memset(non_nulls, 0xFF, BitUtil::BytesForBits(indices.size()));
+ if ((*result)[i]->data()->buffers.size() == 2) {
+ (*result)[i] = MakeArray(
+ ArrayData::Make((*result)[i]->type(), indices.size(),
+ {std::move(null_buf), (*result)[i]->data()->buffers[1]}));
+ } else {
+ (*result)[i] = MakeArray(
+ ArrayData::Make((*result)[i]->type(), indices.size(),
+ {std::move(null_buf), (*result)[i]->data()->buffers[1],
+ (*result)[i]->data()->buffers[2]}));
+ }
+ }
+ (*result)[i]->data()->SetNullCount(kUnknownNullCount);
+ }
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < 0) {
+ for (size_t col = 0; col < result->size(); ++col) {
+ uint8_t* non_nulls = (*result)[col]->data()->buffers[0]->mutable_data();
+ BitUtil::ClearBit(non_nulls, i);
+ }
+ }
+ }
+ }
+}
+
+// Generate random arrays given list of data type descriptions and null probabilities.
+// Make sure that all generated records are unique.
+// The actual number of generated records may be lower than desired because duplicates
+// will be removed without replacement.
+//
+std::vector<std::shared_ptr<Array>> GenRandomUniqueRecords(
+ Random64Bit& rng, const RandomDataTypeVector& data_types, int num_desired,
+ int* num_actual) {
+ std::vector<std::shared_ptr<Array>> result =
+ GenRandomRecords(rng, data_types.data_types, num_desired);
+
+ ExecContext* ctx = default_exec_context();
+ std::vector<ValueDescr> val_descrs;
+ for (size_t i = 0; i < result.size(); ++i) {
+ val_descrs.push_back(ValueDescr(result[i]->type(), ValueDescr::ARRAY));
+ }
+ internal::RowEncoder encoder;
+ encoder.Init(val_descrs, ctx);
+ ExecBatch batch({}, num_desired);
+ batch.values.resize(result.size());
+ for (size_t i = 0; i < result.size(); ++i) {
+ batch.values[i] = result[i];
+ }
+ Status status = encoder.EncodeAndAppend(batch);
+ ARROW_DCHECK(status.ok());
+
+ std::unordered_map<std::string, int> uniques;
+ std::vector<int32_t> ids;
+ for (int i = 0; i < num_desired; ++i) {
+ if (uniques.find(encoder.encoded_row(i)) == uniques.end()) {
+ uniques.insert(std::make_pair(encoder.encoded_row(i), i));
+ ids.push_back(i);
+ }
+ }
+ *num_actual = static_cast<int>(uniques.size());
+
+ std::vector<std::shared_ptr<Array>> output;
+ TakeUsingVector(ctx, result, ids, &output);
+ return output;
+}
+
+std::vector<bool> NullInKey(const std::vector<JoinKeyCmp>& cmp,
+ const std::vector<std::shared_ptr<Array>>& key) {
+ ARROW_DCHECK(cmp.size() <= key.size());
+ ARROW_DCHECK(key.size() > 0);
+ std::vector<bool> result;
+ result.resize(key[0]->length());
+ for (size_t i = 0; i < result.size(); ++i) {
+ result[i] = false;
+ }
+ for (size_t i = 0; i < cmp.size(); ++i) {
+ if (cmp[i] != JoinKeyCmp::EQ) {
+ continue;
+ }
+ if (key[i]->data()->buffers[0] == NULLPTR) {
+ continue;
+ }
+ const uint8_t* nulls = key[i]->data()->buffers[0]->data();
+ if (!nulls) {
+ continue;
+ }
+ for (size_t j = 0; j < result.size(); ++j) {
+ if (!BitUtil::GetBit(nulls, j)) {
+ result[j] = true;
+ }
+ }
+ }
+ return result;
+}
+
+void GenRandomJoinTables(ExecContext* ctx, Random64Bit& rng, int num_rows_l,
+ int num_rows_r, int num_keys_common, int num_keys_left,
+ int num_keys_right, const RandomDataTypeVector& key_types,
+ const RandomDataTypeVector& payload_left_types,
+ const RandomDataTypeVector& payload_right_types,
+ std::vector<int32_t>* key_id_l, std::vector<int32_t>* key_id_r,
+ std::vector<std::shared_ptr<Array>>* left,
+ std::vector<std::shared_ptr<Array>>* right) {
+ // Generate random keys dictionary
+ //
+ int num_keys_desired = num_keys_left + num_keys_right - num_keys_common;
+ int num_keys_actual = 0;
+ std::vector<std::shared_ptr<Array>> keys =
+ GenRandomUniqueRecords(rng, key_types, num_keys_desired, &num_keys_actual);
+
+ // There will be three dictionary id ranges:
+ // - common keys [0..num_keys_common-1]
+ // - keys on right that are not on left [num_keys_common..num_keys_right-1]
+ // - keys on left that are not on right [num_keys_right..num_keys_actual-1]
+ //
+ num_keys_common = static_cast<int>(static_cast<int64_t>(num_keys_common) *
+ num_keys_actual / num_keys_desired);
+ num_keys_right = static_cast<int>(static_cast<int64_t>(num_keys_right) *
+ num_keys_actual / num_keys_desired);
+ ARROW_DCHECK(num_keys_right >= num_keys_common);
+ num_keys_left = num_keys_actual - num_keys_right + num_keys_common;
+ if (num_keys_left == 0) {
+ ARROW_DCHECK(num_keys_common == 0 && num_keys_right > 0);
+ ++num_keys_left;
+ ++num_keys_common;
+ }
+ if (num_keys_right == 0) {
+ ARROW_DCHECK(num_keys_common == 0 && num_keys_left > 0);
+ ++num_keys_right;
+ ++num_keys_common;
+ }
+ ARROW_DCHECK(num_keys_left >= num_keys_common);
+ ARROW_DCHECK(num_keys_left + num_keys_right - num_keys_common == num_keys_actual);
+
+ key_id_l->resize(num_rows_l);
+ for (int i = 0; i < num_rows_l; ++i) {
+ (*key_id_l)[i] = rng.from_range(0, num_keys_left - 1);
+ if ((*key_id_l)[i] >= num_keys_common) {
+ (*key_id_l)[i] += num_keys_right - num_keys_common;
+ }
+ }
+
+ key_id_r->resize(num_rows_r);
+ for (int i = 0; i < num_rows_r; ++i) {
+ (*key_id_r)[i] = rng.from_range(0, num_keys_right - 1);
+ }
+
+ std::vector<std::shared_ptr<Array>> key_l;
+ std::vector<std::shared_ptr<Array>> key_r;
+ TakeUsingVector(ctx, keys, *key_id_l, &key_l);
+ TakeUsingVector(ctx, keys, *key_id_r, &key_r);
+ std::vector<std::shared_ptr<Array>> payload_l =
+ GenRandomRecords(rng, payload_left_types.data_types, num_rows_l);
+ std::vector<std::shared_ptr<Array>> payload_r =
+ GenRandomRecords(rng, payload_right_types.data_types, num_rows_r);
+
+ left->resize(key_l.size() + payload_l.size());
+ for (size_t i = 0; i < key_l.size(); ++i) {
+ (*left)[i] = key_l[i];
+ }
+ for (size_t i = 0; i < payload_l.size(); ++i) {
+ (*left)[key_l.size() + i] = payload_l[i];
+ }
+ right->resize(key_r.size() + payload_r.size());
+ for (size_t i = 0; i < key_r.size(); ++i) {
+ (*right)[i] = key_r[i];
+ }
+ for (size_t i = 0; i < payload_r.size(); ++i) {
+ (*right)[key_r.size() + i] = payload_r[i];
+ }
+}
+
+std::vector<std::shared_ptr<Array>> ConstructJoinOutputFromRowIds(
+ ExecContext* ctx, const std::vector<int32_t>& row_ids_l,
+ const std::vector<int32_t>& row_ids_r, const std::vector<std::shared_ptr<Array>>& l,
+ const std::vector<std::shared_ptr<Array>>& r,
+ const std::vector<int>& shuffle_output_l, const std::vector<int>& shuffle_output_r) {
+ std::vector<std::shared_ptr<Array>> full_output_l;
+ std::vector<std::shared_ptr<Array>> full_output_r;
+ TakeUsingVector(ctx, l, row_ids_l, &full_output_l);
+ TakeUsingVector(ctx, r, row_ids_r, &full_output_r);
+ std::vector<std::shared_ptr<Array>> result;
+ result.resize(shuffle_output_l.size() + shuffle_output_r.size());
+ for (size_t i = 0; i < shuffle_output_l.size(); ++i) {
+ result[i] = full_output_l[shuffle_output_l[i]];
+ }
+ for (size_t i = 0; i < shuffle_output_r.size(); ++i) {
+ result[shuffle_output_l.size() + i] = full_output_r[shuffle_output_r[i]];
+ }
+ return result;
+}
+
+BatchesWithSchema TableToBatches(Random64Bit& rng, int num_batches,
+ const std::vector<std::shared_ptr<Array>>& table,
+ const std::string& column_name_prefix) {
+ BatchesWithSchema result;
+
+ std::vector<std::shared_ptr<Field>> fields;
+ fields.resize(table.size());
+ for (size_t i = 0; i < table.size(); ++i) {
+ fields[i] = std::make_shared<Field>(column_name_prefix + std::to_string(i),
+ table[i]->type(), true);
+ }
+ result.schema = std::make_shared<Schema>(std::move(fields));
+
+ int64_t length = table[0]->length();
+ num_batches = std::min(num_batches, static_cast<int>(length));
+
+ std::vector<int64_t> batch_offsets;
+ batch_offsets.push_back(0);
+ batch_offsets.push_back(length);
+ std::unordered_set<int64_t> batch_offset_set;
+ for (int i = 0; i < num_batches - 1; ++i) {
+ for (;;) {
+ int64_t offset = rng.from_range(static_cast<int64_t>(1), length - 1);
+ if (batch_offset_set.find(offset) == batch_offset_set.end()) {
+ batch_offset_set.insert(offset);
+ batch_offsets.push_back(offset);
+ break;
+ }
+ }
+ }
+ std::sort(batch_offsets.begin(), batch_offsets.end());
+
+ for (int i = 0; i < num_batches; ++i) {
+ int64_t batch_offset = batch_offsets[i];
+ int64_t batch_length = batch_offsets[i + 1] - batch_offsets[i];
+ ExecBatch batch({}, batch_length);
+ batch.values.resize(table.size());
+ for (size_t col = 0; col < table.size(); ++col) {
+ batch.values[col] = table[col]->data()->Slice(batch_offset, batch_length);
+ }
+ result.batches.push_back(batch);
+ }
+
+ return result;
+}
+
+// -1 in result means outputting all corresponding fields as nulls
+//
+void HashJoinSimpleInt(JoinType join_type, const std::vector<int32_t>& l,
+ const std::vector<bool>& null_in_key_l,
+ const std::vector<int32_t>& r,
+ const std::vector<bool>& null_in_key_r,
+ std::vector<int32_t>* result_l, std::vector<int32_t>* result_r,
+ int64_t output_length_limit, bool* length_limit_reached) {
+ *length_limit_reached = false;
+
+ bool switch_sides = false;
+ switch (join_type) {
+ case JoinType::RIGHT_SEMI:
+ join_type = JoinType::LEFT_SEMI;
+ switch_sides = true;
+ break;
+ case JoinType::RIGHT_ANTI:
+ join_type = JoinType::LEFT_ANTI;
+ switch_sides = true;
+ break;
+ case JoinType::RIGHT_OUTER:
+ join_type = JoinType::LEFT_OUTER;
+ switch_sides = true;
+ break;
+ default:
+ break;
+ }
+ const std::vector<int32_t>& build = switch_sides ? l : r;
+ const std::vector<int32_t>& probe = switch_sides ? r : l;
+ const std::vector<bool>& null_in_key_build =
+ switch_sides ? null_in_key_l : null_in_key_r;
+ const std::vector<bool>& null_in_key_probe =
+ switch_sides ? null_in_key_r : null_in_key_l;
+ std::vector<int32_t>* result_build = switch_sides ? result_l : result_r;
+ std::vector<int32_t>* result_probe = switch_sides ? result_r : result_l;
+
+ std::unordered_multimap<int64_t, int64_t> map_build;
+ for (size_t i = 0; i < build.size(); ++i) {
+ map_build.insert(std::make_pair(build[i], i));
+ }
+ std::vector<bool> match_build;
+ match_build.resize(build.size());
+ for (size_t i = 0; i < build.size(); ++i) {
+ match_build[i] = false;
+ }
+
+ for (int32_t i = 0; i < static_cast<int32_t>(probe.size()); ++i) {
+ std::vector<int32_t> match_probe;
+ if (!null_in_key_probe[i]) {
+ auto range = map_build.equal_range(probe[i]);
+ for (auto it = range.first; it != range.second; ++it) {
+ if (!null_in_key_build[it->second]) {
+ match_probe.push_back(static_cast<int32_t>(it->second));
+ match_build[it->second] = true;
+ }
+ }
+ }
+ switch (join_type) {
+ case JoinType::LEFT_SEMI:
+ if (!match_probe.empty()) {
+ result_probe->push_back(i);
+ result_build->push_back(-1);
+ }
+ break;
+ case JoinType::LEFT_ANTI:
+ if (match_probe.empty()) {
+ result_probe->push_back(i);
+ result_build->push_back(-1);
+ }
+ break;
+ case JoinType::INNER:
+ for (size_t j = 0; j < match_probe.size(); ++j) {
+ result_probe->push_back(i);
+ result_build->push_back(match_probe[j]);
+ }
+ break;
+ case JoinType::LEFT_OUTER:
+ case JoinType::FULL_OUTER:
+ if (match_probe.empty()) {
+ result_probe->push_back(i);
+ result_build->push_back(-1);
+ } else {
+ for (size_t j = 0; j < match_probe.size(); ++j) {
+ result_probe->push_back(i);
+ result_build->push_back(match_probe[j]);
+ }
+ }
+ break;
+ default:
+ ARROW_DCHECK(false);
+ break;
+ }
+
+ if (static_cast<int64_t>(result_probe->size()) >= output_length_limit) {
+ *length_limit_reached = true;
+ return;
+ }
+ }
+
+ if (join_type == JoinType::FULL_OUTER) {
+ for (int32_t i = 0; i < static_cast<int32_t>(build.size()); ++i) {
+ if (!match_build[i]) {
+ result_probe->push_back(-1);
+ result_build->push_back(i);
+ }
+ }
+ }
+}
+
+std::vector<int> GenShuffle(Random64Bit& rng, int length) {
+ std::vector<int> shuffle(length);
+ std::iota(shuffle.begin(), shuffle.end(), 0);
+ for (int i = 0; i < length * 2; ++i) {
+ int from = rng.from_range(0, length - 1);
+ int to = rng.from_range(0, length - 1);
+ if (from != to) {
+ std::swap(shuffle[from], shuffle[to]);
+ }
+ }
+ return shuffle;
+}
+
+void GenJoinFieldRefs(Random64Bit& rng, int num_key_fields, bool no_output,
+ const std::vector<std::shared_ptr<Array>>& original_input,
+ const std::string& field_name_prefix,
+ std::vector<std::shared_ptr<Array>>* new_input,
+ std::vector<FieldRef>* keys, std::vector<FieldRef>* output,
+ std::vector<int>* output_field_ids) {
+ // Permute input
+ std::vector<int> shuffle = GenShuffle(rng, static_cast<int>(original_input.size()));
+ new_input->resize(original_input.size());
+ for (size_t i = 0; i < original_input.size(); ++i) {
+ (*new_input)[i] = original_input[shuffle[i]];
+ }
+
+ // Compute key field refs
+ keys->resize(num_key_fields);
+ for (size_t i = 0; i < shuffle.size(); ++i) {
+ if (shuffle[i] < num_key_fields) {
+ bool use_by_name_ref = (rng.from_range(0, 1) == 0);
+ if (use_by_name_ref) {
+ (*keys)[shuffle[i]] = FieldRef(field_name_prefix + std::to_string(i));
+ } else {
+ (*keys)[shuffle[i]] = FieldRef(static_cast<int>(i));
+ }
+ }
+ }
+
+ // Compute output field refs
+ if (!no_output) {
+ int num_output = rng.from_range(1, static_cast<int>(original_input.size() + 1));
+ output_field_ids->resize(num_output);
+ output->resize(num_output);
+ for (int i = 0; i < num_output; ++i) {
+ int col_id = rng.from_range(0, static_cast<int>(original_input.size() - 1));
+ (*output_field_ids)[i] = col_id;
+ (*output)[i] = (rng.from_range(0, 1) == 0)
+ ? FieldRef(field_name_prefix + std::to_string(col_id))
+ : FieldRef(col_id);
+ }
+ }
+}
+
+std::shared_ptr<Table> HashJoinSimple(
+ ExecContext* ctx, JoinType join_type, const std::vector<JoinKeyCmp>& cmp,
+ int num_key_fields, const std::vector<int32_t>& key_id_l,
+ const std::vector<int32_t>& key_id_r,
+ const std::vector<std::shared_ptr<Array>>& original_l,
+ const std::vector<std::shared_ptr<Array>>& original_r,
+ const std::vector<std::shared_ptr<Array>>& l,
+ const std::vector<std::shared_ptr<Array>>& r, const std::vector<int>& output_ids_l,
+ const std::vector<int>& output_ids_r, int64_t output_length_limit,
+ bool* length_limit_reached) {
+ std::vector<std::shared_ptr<Array>> key_l(num_key_fields);
+ std::vector<std::shared_ptr<Array>> key_r(num_key_fields);
+ for (int i = 0; i < num_key_fields; ++i) {
+ key_l[i] = original_l[i];
+ key_r[i] = original_r[i];
+ }
+ std::vector<bool> null_key_l = NullInKey(cmp, key_l);
+ std::vector<bool> null_key_r = NullInKey(cmp, key_r);
+
+ std::vector<int32_t> row_ids_l;
+ std::vector<int32_t> row_ids_r;
+ HashJoinSimpleInt(join_type, key_id_l, null_key_l, key_id_r, null_key_r, &row_ids_l,
+ &row_ids_r, output_length_limit, length_limit_reached);
+
+ std::vector<std::shared_ptr<Array>> result = ConstructJoinOutputFromRowIds(
+ ctx, row_ids_l, row_ids_r, l, r, output_ids_l, output_ids_r);
+
+ std::vector<std::shared_ptr<Field>> fields(result.size());
+ for (size_t i = 0; i < result.size(); ++i) {
+ fields[i] = std::make_shared<Field>("a" + std::to_string(i), result[i]->type(), true);
+ }
+ std::shared_ptr<Schema> schema = std::make_shared<Schema>(std::move(fields));
+ return Table::Make(schema, result, result[0]->length());
+}
+
+void HashJoinWithExecPlan(Random64Bit& rng, bool parallel,
+ const HashJoinNodeOptions& join_options,
+ const std::shared_ptr<Schema>& output_schema,
+ const std::vector<std::shared_ptr<Array>>& l,
+ const std::vector<std::shared_ptr<Array>>& r, int num_batches_l,
+ int num_batches_r, std::shared_ptr<Table>* output) {
+ auto exec_ctx = arrow::internal::make_unique<ExecContext>(
+ default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
+
+ // add left source
+ BatchesWithSchema l_batches = TableToBatches(rng, num_batches_l, l, "l_");
+ ASSERT_OK_AND_ASSIGN(
+ ExecNode * l_source,
+ MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions{l_batches.schema, l_batches.gen(parallel,
+ /*slow=*/false)}));
+
+ // add right source
+ BatchesWithSchema r_batches = TableToBatches(rng, num_batches_r, r, "r_");
+ ASSERT_OK_AND_ASSIGN(
+ ExecNode * r_source,
+ MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions{r_batches.schema, r_batches.gen(parallel,
+ /*slow=*/false)}));
+
+ ASSERT_OK_AND_ASSIGN(ExecNode * join, MakeExecNode("hashjoin", plan.get(),
+ {l_source, r_source}, join_options));
+
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+ ASSERT_OK_AND_ASSIGN(
+ std::ignore, MakeExecNode("sink", plan.get(), {join}, SinkNodeOptions{&sink_gen}));
+
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));
+
+ ASSERT_OK_AND_ASSIGN(*output, TableFromExecBatches(output_schema, res));
+}
+
+TEST(HashJoin, Random) {
+ Random64Bit rng(42);
+
+ int num_tests = 100;
+ for (int test_id = 0; test_id < num_tests; ++test_id) {
+ bool parallel = (rng.from_range(0, 1) == 1);
+ auto exec_ctx = arrow::internal::make_unique<ExecContext>(
+ default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
+
+ // Constraints
+ RandomDataTypeConstraints type_constraints;
+ type_constraints.Default();
+ // type_constraints.OnlyInt(1, true);
+ constexpr int max_num_key_fields = 3;
+ constexpr int max_num_payload_fields = 3;
+ const char* join_type_names[] = {"LEFT_SEMI", "RIGHT_SEMI", "LEFT_ANTI",
+ "RIGHT_ANTI", "INNER", "LEFT_OUTER",
+ "RIGHT_OUTER", "FULL_OUTER"};
+ std::vector<JoinType> join_type_options{JoinType::LEFT_SEMI, JoinType::RIGHT_SEMI,
+ JoinType::LEFT_ANTI, JoinType::RIGHT_ANTI,
+ JoinType::INNER, JoinType::LEFT_OUTER,
+ JoinType::RIGHT_OUTER, JoinType::FULL_OUTER};
+ constexpr int join_type_mask = 0xFF;
+ // for INNER join only:
+ // constexpr int join_type_mask = 0x10;
+ std::vector<JoinKeyCmp> key_cmp_options{JoinKeyCmp::EQ, JoinKeyCmp::IS};
+ constexpr int key_cmp_mask = 0x03;
+ // for EQ only:
+ // constexpr int key_cmp_mask = 0x01;
+ constexpr int min_num_rows = 1;
+ const int max_num_rows = parallel ? 20000 : 2000;
+ constexpr int min_batch_size = 10;
+ constexpr int max_batch_size = 100;
+
+ // Generate list of key field data types
+ int num_key_fields = rng.from_range(1, max_num_key_fields);
+ RandomDataTypeVector key_types;
+ for (int i = 0; i < num_key_fields; ++i) {
+ key_types.AddRandom(rng, type_constraints);
+ }
+
+ // Generate lists of payload data types
+ int num_payload_fields[2];
+ RandomDataTypeVector payload_types[2];
+ for (int i = 0; i < 2; ++i) {
+ num_payload_fields[i] = rng.from_range(0, max_num_payload_fields);
+ for (int j = 0; j < num_payload_fields[i]; ++j) {
+ payload_types[i].AddRandom(rng, type_constraints);
+ }
+ }
+
+ // Generate join type and comparison functions
+ std::vector<JoinKeyCmp> key_cmp(num_key_fields);
+ std::string key_cmp_str;
+ for (int i = 0; i < num_key_fields; ++i) {
+ for (;;) {
+ int pos = rng.from_range(0, 1);
+ if ((key_cmp_mask & (1 << pos)) > 0) {
+ key_cmp[i] = key_cmp_options[pos];
+ if (i > 0) {
+ key_cmp_str += "_";
+ }
+ key_cmp_str += key_cmp[i] == JoinKeyCmp::EQ ? "EQ" : "IS";
+ break;
+ }
+ }
+ }
+ JoinType join_type;
+ std::string join_type_name;
+ for (;;) {
+ int pos = rng.from_range(0, 7);
+ if ((join_type_mask & (1 << pos)) > 0) {
+ join_type = join_type_options[pos];
+ join_type_name = join_type_names[pos];
+ break;
+ }
+ }
+
+ // Generate input records
+ int num_rows_l = rng.from_range(min_num_rows, max_num_rows);
+ int num_rows_r = rng.from_range(min_num_rows, max_num_rows);
+ int num_rows = std::min(num_rows_l, num_rows_r);
+ int batch_size = rng.from_range(min_batch_size, max_batch_size);
+ int num_keys = rng.from_range(std::max(1, num_rows / 10), num_rows);
+ int num_keys_r = rng.from_range(std::max(1, num_keys / 2), num_keys);
+ int num_keys_common = rng.from_range(std::max(1, num_keys_r / 2), num_keys_r);
+ int num_keys_l = num_keys_common + (num_keys - num_keys_r);
+ std::vector<int> key_id_vectors[2];
+ std::vector<std::shared_ptr<Array>> input_arrays[2];
+ GenRandomJoinTables(exec_ctx.get(), rng, num_rows_l, num_rows_r, num_keys_common,
+ num_keys_l, num_keys_r, key_types, payload_types[0],
+ payload_types[1], &(key_id_vectors[0]), &(key_id_vectors[1]),
+ &(input_arrays[0]), &(input_arrays[1]));
+ std::vector<std::shared_ptr<Array>> shuffled_input_arrays[2];
+ std::vector<FieldRef> key_fields[2];
+ std::vector<FieldRef> output_fields[2];
+ std::vector<int> output_field_ids[2];
+ for (int i = 0; i < 2; ++i) {
+ bool no_output = false;
+ if (i == 0) {
+ no_output =
+ join_type == JoinType::RIGHT_SEMI || join_type == JoinType::RIGHT_ANTI;
+ } else {
+ no_output = join_type == JoinType::LEFT_SEMI || join_type == JoinType::LEFT_ANTI;
+ }
+ GenJoinFieldRefs(rng, num_key_fields, no_output, input_arrays[i],
+ std::string((i == 0) ? "l_" : "r_"), &(shuffled_input_arrays[i]),
+ &(key_fields[i]), &(output_fields[i]), &(output_field_ids[i]));
+ }
+
+ // Print test case parameters
+ // print num_rows, batch_size, join_type, join_cmp
+ std::cout << join_type_name << " " << key_cmp_str << " ";
+ key_types.Print();
+ std::cout << " payload_l: ";
+ payload_types[0].Print();
+ std::cout << " payload_r: ";
+ payload_types[1].Print();
+ std::cout << " num_rows_l = " << num_rows_l << " num_rows_r = " << num_rows_r
+ << " batch size = " << batch_size
+ << " parallel = " << (parallel ? "true" : "false");
+ std::cout << std::endl;
+
+ // Run reference join implementation
+ std::vector<bool> null_in_key_vectors[2];
+ for (int i = 0; i < 2; ++i) {
+ null_in_key_vectors[i] = NullInKey(key_cmp, input_arrays[i]);
+ }
+ int64_t output_length_limit = 100000;
+ bool length_limit_reached = false;
+ std::shared_ptr<Table> output_rows_ref = HashJoinSimple(
+ exec_ctx.get(), join_type, key_cmp, num_key_fields, key_id_vectors[0],
+ key_id_vectors[1], input_arrays[0], input_arrays[1], shuffled_input_arrays[0],
+ shuffled_input_arrays[1], output_field_ids[0], output_field_ids[1],
+ output_length_limit, &length_limit_reached);
+ if (length_limit_reached) {
+ continue;
+ }
+
+ // Run tested join implementation
+ HashJoinNodeOptions join_options{join_type, key_fields[0], key_fields[1],
+ output_fields[0], output_fields[1], key_cmp};
+ std::vector<std::shared_ptr<Field>> output_schema_fields;
+ for (int i = 0; i < 2; ++i) {
+ for (size_t col = 0; col < output_fields[i].size(); ++col) {
+ output_schema_fields.push_back(std::make_shared<Field>(
+ std::string((i == 0) ? "l_" : "r_") + std::to_string(col),
+ shuffled_input_arrays[i][output_field_ids[i][col]]->type(), true));
+ }
+ }
+ std::shared_ptr<Schema> output_schema =
+ std::make_shared<Schema>(std::move(output_schema_fields));
+ std::shared_ptr<Table> output_rows_test;
+ HashJoinWithExecPlan(rng, parallel, join_options, output_schema,
+ shuffled_input_arrays[0], shuffled_input_arrays[1],
+ static_cast<int>(BitUtil::CeilDiv(num_rows_l, batch_size)),
+ static_cast<int>(BitUtil::CeilDiv(num_rows_r, batch_size)),
+ &output_rows_test);
+
+ // Compare results
+ AssertTablesEqual(output_rows_ref, output_rows_test);
+ }
+}
+
+void DecodeScalarsAndDictionariesInBatch(ExecBatch* batch, MemoryPool* pool) {
+ for (size_t i = 0; i < batch->values.size(); ++i) {
+ if (batch->values[i].is_scalar()) {
+ ASSERT_OK_AND_ASSIGN(
+ std::shared_ptr<Array> col,
+ MakeArrayFromScalar(*(batch->values[i].scalar()), batch->length, pool));
+ batch->values[i] = Datum(col);
+ }
+ if (batch->values[i].type()->id() == Type::DICTIONARY) {
+ const auto& dict_type =
+ checked_cast<const DictionaryType&>(*batch->values[i].type());
+ std::shared_ptr<ArrayData> indices =
+ ArrayData::Make(dict_type.index_type(), batch->values[i].array()->length,
+ batch->values[i].array()->buffers);
+ const std::shared_ptr<ArrayData>& dictionary = batch->values[i].array()->dictionary;
+ ASSERT_OK_AND_ASSIGN(Datum col, Take(*dictionary, *indices));
+ batch->values[i] = col;
+ }
+ }
+}
+
+std::shared_ptr<Schema> UpdateSchemaAfterDecodingDictionaries(
+ const std::shared_ptr<Schema>& schema) {
+ std::vector<std::shared_ptr<Field>> output_fields(schema->num_fields());
+ for (int i = 0; i < schema->num_fields(); ++i) {
+ const std::shared_ptr<Field>& field = schema->field(i);
+ if (field->type()->id() == Type::DICTIONARY) {
+ const auto& dict_type = checked_cast<const DictionaryType&>(*field->type());
+ output_fields[i] = std::make_shared<Field>(field->name(), dict_type.value_type(),
+ true /* nullable */);
+ } else {
+ output_fields[i] = field->Copy();
+ }
+ }
+ return std::make_shared<Schema>(std::move(output_fields));
+}
+
+void TestHashJoinDictionaryHelper(
+ JoinType join_type, JoinKeyCmp cmp,
+ // Whether to run parallel hash join.
+ // This requires generating multiple copies of each input batch on one side of the
+ // join. Expected results will be automatically adjusted to reflect the multiplication
+ // of input batches.
+ bool parallel, Datum l_key, Datum l_payload, Datum r_key, Datum r_payload,
+ Datum l_out_key, Datum l_out_payload, Datum r_out_key, Datum r_out_payload,
+ // Number of rows at the end of expected output that represent rows from the right
+ // side that do not have a match on the left side. This number is needed to
+ // automatically adjust expected result when multiplying input batches on the left
+ // side.
+ int expected_num_r_no_match,
+ // Whether to swap two inputs to the hash join
+ bool swap_sides) {
+ int64_t l_length = l_key.is_array()
+ ? l_key.array()->length
+ : l_payload.is_array() ? l_payload.array()->length : -1;
+ int64_t r_length = r_key.is_array()
+ ? r_key.array()->length
+ : r_payload.is_array() ? r_payload.array()->length : -1;
+ ARROW_DCHECK(l_length >= 0 && r_length >= 0);
+
+ constexpr int batch_multiplicity_for_parallel = 2;
+
+ // Split both sides into exactly two batches
+ int64_t l_first_length = l_length / 2;
+ int64_t r_first_length = r_length / 2;
+ BatchesWithSchema l_batches, r_batches;
+ l_batches.batches.resize(2);
+ r_batches.batches.resize(2);
+ ASSERT_OK_AND_ASSIGN(
+ l_batches.batches[0],
+ ExecBatch::Make({l_key.is_array() ? l_key.array()->Slice(0, l_first_length) : l_key,
+ l_payload.is_array() ? l_payload.array()->Slice(0, l_first_length)
+ : l_payload}));
+ ASSERT_OK_AND_ASSIGN(
+ l_batches.batches[1],
+ ExecBatch::Make(
+ {l_key.is_array()
+ ? l_key.array()->Slice(l_first_length, l_length - l_first_length)
+ : l_key,
+ l_payload.is_array()
+ ? l_payload.array()->Slice(l_first_length, l_length - l_first_length)
+ : l_payload}));
+ ASSERT_OK_AND_ASSIGN(
+ r_batches.batches[0],
+ ExecBatch::Make({r_key.is_array() ? r_key.array()->Slice(0, r_first_length) : r_key,
+ r_payload.is_array() ? r_payload.array()->Slice(0, r_first_length)
+ : r_payload}));
+ ASSERT_OK_AND_ASSIGN(
+ r_batches.batches[1],
+ ExecBatch::Make(
+ {r_key.is_array()
+ ? r_key.array()->Slice(r_first_length, r_length - r_first_length)
+ : r_key,
+ r_payload.is_array()
+ ? r_payload.array()->Slice(r_first_length, r_length - r_first_length)
+ : r_payload}));
+ l_batches.schema =
+ schema({field("l_key", l_key.type()), field("l_payload", l_payload.type())});
+ r_batches.schema =
+ schema({field("r_key", r_key.type()), field("r_payload", r_payload.type())});
+
+ // Add copies of input batches on originally left side of the hash join
+ if (parallel) {
+ for (int i = 0; i < batch_multiplicity_for_parallel - 1; ++i) {
+ l_batches.batches.push_back(l_batches.batches[0]);
+ l_batches.batches.push_back(l_batches.batches[1]);
+ }
+ }
+
+ auto exec_ctx = arrow::internal::make_unique<ExecContext>(
+ default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
+ ASSERT_OK_AND_ASSIGN(
+ ExecNode * l_source,
+ MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions{l_batches.schema, l_batches.gen(parallel,
+ /*slow=*/false)}));
+ ASSERT_OK_AND_ASSIGN(
+ ExecNode * r_source,
+ MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions{r_batches.schema, r_batches.gen(parallel,
+ /*slow=*/false)}));
+ HashJoinNodeOptions join_options{join_type,
+ {FieldRef(swap_sides ? "r_key" : "l_key")},
+ {FieldRef(swap_sides ? "l_key" : "r_key")},
+ {FieldRef(swap_sides ? "r_key" : "l_key"),
+ FieldRef(swap_sides ? "r_payload" : "l_payload")},
+ {FieldRef(swap_sides ? "l_key" : "r_key"),
+ FieldRef(swap_sides ? "l_payload" : "r_payload")},
+ {cmp}};
+ ASSERT_OK_AND_ASSIGN(ExecNode * join, MakeExecNode("hashjoin", plan.get(),
+ {(swap_sides ? r_source : l_source),
+ (swap_sides ? l_source : r_source)},
+ join_options));
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+ ASSERT_OK_AND_ASSIGN(
+ std::ignore, MakeExecNode("sink", plan.get(), {join}, SinkNodeOptions{&sink_gen}));
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));
+
+ for (auto& batch : res) {
+ DecodeScalarsAndDictionariesInBatch(&batch, exec_ctx->memory_pool());
+ }
+ std::shared_ptr<Schema> output_schema =
+ UpdateSchemaAfterDecodingDictionaries(join->output_schema());
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Table> output,
+ TableFromExecBatches(output_schema, res));
+
+ ExecBatch expected_batch;
+ if (swap_sides) {
+ ASSERT_OK_AND_ASSIGN(expected_batch, ExecBatch::Make({r_out_key, r_out_payload,
+ l_out_key, l_out_payload}));
+ } else {
+ ASSERT_OK_AND_ASSIGN(expected_batch, ExecBatch::Make({l_out_key, l_out_payload,
+ r_out_key, r_out_payload}));
+ }
+
+ DecodeScalarsAndDictionariesInBatch(&expected_batch, exec_ctx->memory_pool());
+
+ // Slice expected batch into two to separate rows on right side with no matches from
+ // everything else.
+ //
+ std::vector<ExecBatch> expected_batches;
+ ASSERT_OK_AND_ASSIGN(
+ auto prefix_batch,
+ ExecBatch::Make({expected_batch.values[0].array()->Slice(
+ 0, expected_batch.length - expected_num_r_no_match),
+ expected_batch.values[1].array()->Slice(
+ 0, expected_batch.length - expected_num_r_no_match),
+ expected_batch.values[2].array()->Slice(
+ 0, expected_batch.length - expected_num_r_no_match),
+ expected_batch.values[3].array()->Slice(
+ 0, expected_batch.length - expected_num_r_no_match)}));
+ for (int i = 0; i < (parallel ? batch_multiplicity_for_parallel : 1); ++i) {
+ expected_batches.push_back(prefix_batch);
+ }
+ if (expected_num_r_no_match > 0) {
+ ASSERT_OK_AND_ASSIGN(
+ auto suffix_batch,
+ ExecBatch::Make({expected_batch.values[0].array()->Slice(
+ expected_batch.length - expected_num_r_no_match,
+ expected_num_r_no_match),
+ expected_batch.values[1].array()->Slice(
+ expected_batch.length - expected_num_r_no_match,
+ expected_num_r_no_match),
+ expected_batch.values[2].array()->Slice(
+ expected_batch.length - expected_num_r_no_match,
+ expected_num_r_no_match),
+ expected_batch.values[3].array()->Slice(
+ expected_batch.length - expected_num_r_no_match,
+ expected_num_r_no_match)}));
+ expected_batches.push_back(suffix_batch);
+ }
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Table> expected,
+ TableFromExecBatches(output_schema, expected_batches));
+
+ // Compare results
+ AssertTablesEqual(expected, output);
+}
+
+TEST(HashJoin, Dictionary) {
+ auto int8_utf8 = dictionary(int8(), utf8());
+ auto uint8_utf8 = arrow::dictionary(uint8(), utf8());
+ auto int16_utf8 = arrow::dictionary(int16(), utf8());
+ auto uint16_utf8 = arrow::dictionary(uint16(), utf8());
+ auto int32_utf8 = arrow::dictionary(int32(), utf8());
+ auto uint32_utf8 = arrow::dictionary(uint32(), utf8());
+ auto int64_utf8 = arrow::dictionary(int64(), utf8());
+ auto uint64_utf8 = arrow::dictionary(uint64(), utf8());
+ std::shared_ptr<DataType> dict_types[] = {int8_utf8, uint8_utf8, int16_utf8,
+ uint16_utf8, int32_utf8, uint32_utf8,
+ int64_utf8, uint64_utf8};
+
+ Random64Bit rng(43);
+
+ // Dictionaries in payload columns
+ for (auto parallel : {false, true}) {
+ for (auto swap_sides : {false, true}) {
+ TestHashJoinDictionaryHelper(
+ JoinType::FULL_OUTER, JoinKeyCmp::EQ, parallel,
+ // Input
+ ArrayFromJSON(utf8(), R"(["a", "c", "c", "d"])"),
+ DictArrayFromJSON(int8_utf8, R"([4, 2, 3, 0])",
+ R"(["p", "q", "r", null, "r"])"),
+ ArrayFromJSON(utf8(), R"(["a", "a", "b", "c"])"),
+ DictArrayFromJSON(int16_utf8, R"([0, 1, 0, 2])", R"(["r", null, "r", "q"])"),
+ // Expected output
+ ArrayFromJSON(utf8(), R"(["a", "a", "c", "c", "d", null])"),
+ DictArrayFromJSON(int8_utf8, R"([4, 4, 2, 3, 0, null])",
+ R"(["p", "q", "r", null, "r"])"),
+ ArrayFromJSON(utf8(), R"(["a", "a", "c", "c", null, "b"])"),
+ DictArrayFromJSON(int16_utf8, R"([0, 1, 2, 2, null, 0])",
+ R"(["r", null, "r", "q"])"),
+ 1, swap_sides);
+ }
+ }
+
+ // Dictionaries in key columns
+ for (auto parallel : {false, true}) {
+ for (auto swap_sides : {false, true}) {
+ for (auto l_key_dict : {true, false}) {
+ for (auto r_key_dict : {true, false}) {
+ auto l_key_dict_type = dict_types[rng.from_range(0, 7)];
+ auto r_key_dict_type = dict_types[rng.from_range(0, 7)];
+
+ auto l_key = l_key_dict ? DictArrayFromJSON(l_key_dict_type, R"([2, 2, 0, 1])",
+ R"(["b", null, "a"])")
+ : ArrayFromJSON(utf8(), R"(["a", "a", "b", null])");
+ auto l_payload = ArrayFromJSON(utf8(), R"(["x", "y", "z", "y"])");
+ auto r_key = r_key_dict
+ ? DictArrayFromJSON(int16_utf8, R"([1, 0, null, 1, 2])",
+ R"([null, "b", "c"])")
+ : ArrayFromJSON(utf8(), R"(["b", null, null, "b", "c"])");
+ auto r_payload = ArrayFromJSON(utf8(), R"(["p", "r", "p", "q", "s"])");
+
+ // IS comparison function (null is equal to null when matching keys)
+ TestHashJoinDictionaryHelper(
+ JoinType::FULL_OUTER, JoinKeyCmp::IS, parallel,
+ // Input
+ l_key, l_payload, r_key, r_payload,
+ // Expected
+ l_key_dict ? DictArrayFromJSON(l_key_dict_type, R"([2, 2, 0, 0, 1, 1,
+ null])",
+ R"(["b", null, "a"])")
+ : ArrayFromJSON(utf8(), R"(["a", "a", "b", "b", null, null,
+ null])"),
+ ArrayFromJSON(utf8(), R"(["x", "y", "z", "z", "y", "y", null])"),
+ r_key_dict
+ ? DictArrayFromJSON(r_key_dict_type, R"([null, null, 0, 0, null, null,
+ 1])",
+ R"(["b", "c"])")
+ : ArrayFromJSON(utf8(), R"([null, null, "b", "b", null, null, "c"])"),
+ ArrayFromJSON(utf8(), R"([null, null, "p", "q", "r", "p", "s"])"), 1,
+ swap_sides);
+
+ // EQ comparison function (null is not matching null)
+ TestHashJoinDictionaryHelper(
+ JoinType::FULL_OUTER, JoinKeyCmp::EQ, parallel,
+ // Input
+ l_key, l_payload, r_key, r_payload,
+ // Expected
+ l_key_dict ? DictArrayFromJSON(l_key_dict_type,
+ R"([2, 2, 0, 0, 1, null, null, null])",
+ R"(["b", null, "a"])")
+ : ArrayFromJSON(
+ utf8(), R"(["a", "a", "b", "b", null, null, null, null])"),
+ ArrayFromJSON(utf8(), R"(["x", "y", "z", "z", "y", null, null, null])"),
+ r_key_dict
+ ? DictArrayFromJSON(r_key_dict_type,
+ R"([null, null, 0, 0, null, null, null, 1])",
+ R"(["b", "c"])")
+ : ArrayFromJSON(utf8(),
+ R"([null, null, "b", "b", null, null, null, "c"])"),
+ ArrayFromJSON(utf8(), R"([null, null, "p", "q", null, "r", "p", "s"])"), 3,
+ swap_sides);
+ }
+ }
+ }
+ }
+
+ // Empty build side
+ {
+ auto l_key_dict_type = dict_types[rng.from_range(0, 7)];
+ auto l_payload_dict_type = dict_types[rng.from_range(0, 7)];
+ auto r_key_dict_type = dict_types[rng.from_range(0, 7)];
+ auto r_payload_dict_type = dict_types[rng.from_range(0, 7)];
+
+ for (auto parallel : {false, true}) {
+ for (auto swap_sides : {false, true}) {
+ for (auto cmp : {JoinKeyCmp::IS, JoinKeyCmp::EQ}) {
+ TestHashJoinDictionaryHelper(
+ JoinType::FULL_OUTER, cmp, parallel,
+ // Input
+ DictArrayFromJSON(l_key_dict_type, R"([2, 0, 1])", R"(["b", null, "a"])"),
+ DictArrayFromJSON(l_payload_dict_type, R"([2, 2, 0])",
+ R"(["x", "y", "z"])"),
+ DictArrayFromJSON(r_key_dict_type, R"([])", R"([null, "b", "c"])"),
+ DictArrayFromJSON(r_payload_dict_type, R"([])", R"(["p", "r", "s"])"),
+ // Expected
+ DictArrayFromJSON(l_key_dict_type, R"([2, 0, 1])", R"(["b", null, "a"])"),
+ DictArrayFromJSON(l_payload_dict_type, R"([2, 2, 0])",
+ R"(["x", "y", "z"])"),
+ DictArrayFromJSON(r_key_dict_type, R"([null, null, null])",
+ R"(["b", "c"])"),
+ DictArrayFromJSON(r_payload_dict_type, R"([null, null, null])",
+ R"(["p", "r", "s"])"),
+ 0, swap_sides);
+ }
+ }
+ }
+ }
+
+ // Empty probe side
+ {
+ auto l_key_dict_type = dict_types[rng.from_range(0, 7)];
+ auto l_payload_dict_type = dict_types[rng.from_range(0, 7)];
+ auto r_key_dict_type = dict_types[rng.from_range(0, 7)];
+ auto r_payload_dict_type = dict_types[rng.from_range(0, 7)];
+
+ for (auto parallel : {false, true}) {
+ for (auto swap_sides : {false, true}) {
+ for (auto cmp : {JoinKeyCmp::IS, JoinKeyCmp::EQ}) {
+ TestHashJoinDictionaryHelper(
+ JoinType::FULL_OUTER, cmp, parallel,
+ // Input
+ DictArrayFromJSON(l_key_dict_type, R"([])", R"(["b", null, "a"])"),
+ DictArrayFromJSON(l_payload_dict_type, R"([])", R"(["x", "y", "z"])"),
+ DictArrayFromJSON(r_key_dict_type, R"([2, 0, 1, null])",
+ R"([null, "b", "c"])"),
+ DictArrayFromJSON(r_payload_dict_type, R"([1, 1, null, 0])",
+ R"(["p", "r", "s"])"),
+ // Expected
+ DictArrayFromJSON(l_key_dict_type, R"([null, null, null, null])",
+ R"(["b", null, "a"])"),
+ DictArrayFromJSON(l_payload_dict_type, R"([null, null, null, null])",
+ R"(["x", "y", "z"])"),
+ DictArrayFromJSON(r_key_dict_type, R"([1, null, 0, null])",
+ R"(["b", "c"])"),
+ DictArrayFromJSON(r_payload_dict_type, R"([1, 1, null, 0])",
+ R"(["p", "r", "s"])"),
+ 4, swap_sides);
+ }
+ }
+ }
+ }
+}
+
+TEST(HashJoin, Scalars) {
+ auto int8_utf8 = std::make_shared<DictionaryType>(int8(), utf8());
+ auto int16_utf8 = std::make_shared<DictionaryType>(int16(), utf8());
+ auto int32_utf8 = std::make_shared<DictionaryType>(int32(), utf8());
+
+ // Scalars in payload columns
+ for (auto use_scalar_dict : {false, true}) {
+ TestHashJoinDictionaryHelper(
+ JoinType::FULL_OUTER, JoinKeyCmp::EQ, false /*parallel*/,
+ // Input
+ ArrayFromJSON(utf8(), R"(["a", "c", "c", "d"])"),
+ use_scalar_dict ? DictScalarFromJSON(int16_utf8, "1", R"(["z", "x", "y"])")
+ : ScalarFromJSON(utf8(), "\"x\""),
+ ArrayFromJSON(utf8(), R"(["a", "a", "b", "c"])"),
+ use_scalar_dict ? DictScalarFromJSON(int32_utf8, "0", R"(["z", "x", "y"])")
+ : ScalarFromJSON(utf8(), "\"z\""),
+ // Expected output
+ ArrayFromJSON(utf8(), R"(["a", "a", "c", "c", "d", null])"),
+ ArrayFromJSON(utf8(), R"(["x", "x", "x", "x", "x", null])"),
+ ArrayFromJSON(utf8(), R"(["a", "a", "c", "c", null, "b"])"),
+ ArrayFromJSON(utf8(), R"(["z", "z", "z", "z", null, "z"])"), 1,
+ false /*swap sides*/);
+ }
+
+ // Scalars in key columns
+ for (auto use_scalar_dict : {false, true}) {
+ for (auto swap_sides : {false, true}) {
+ TestHashJoinDictionaryHelper(
+ JoinType::FULL_OUTER, JoinKeyCmp::EQ, false /*parallel*/,
+ // Input
+ use_scalar_dict ? DictScalarFromJSON(int8_utf8, "1", R"(["b", "a", "c"])")
+ : ScalarFromJSON(utf8(), "\"a\""),
+ ArrayFromJSON(utf8(), R"(["x", "y"])"),
+ ArrayFromJSON(utf8(), R"(["a", null, "b"])"),
+ ArrayFromJSON(utf8(), R"(["p", "q", "r"])"),
+ // Expected output
+ ArrayFromJSON(utf8(), R"(["a", "a", null, null])"),
+ ArrayFromJSON(utf8(), R"(["x", "y", null, null])"),
+ ArrayFromJSON(utf8(), R"(["a", "a", null, "b"])"),
+ ArrayFromJSON(utf8(), R"(["p", "p", "q", "r"])"), 2, swap_sides);
+ }
+ }
+
+ // Null scalars in key columns
+ for (auto use_scalar_dict : {false, true}) {
+ for (auto swap_sides : {false, true}) {
+ TestHashJoinDictionaryHelper(
+ JoinType::FULL_OUTER, JoinKeyCmp::EQ, false /*parallel*/,
+ // Input
+ use_scalar_dict ? DictScalarFromJSON(int16_utf8, "2", R"(["a", "b", null])")
+ : ScalarFromJSON(utf8(), "null"),
+ ArrayFromJSON(utf8(), R"(["x", "y"])"),
+ ArrayFromJSON(utf8(), R"(["a", null, "b"])"),
+ ArrayFromJSON(utf8(), R"(["p", "q", "r"])"),
+ // Expected output
+ ArrayFromJSON(utf8(), R"([null, null, null, null, null])"),
+ ArrayFromJSON(utf8(), R"(["x", "y", null, null, null])"),
+ ArrayFromJSON(utf8(), R"([null, null, "a", null, "b"])"),
+ ArrayFromJSON(utf8(), R"([null, null, "p", "q", "r"])"), 3, swap_sides);
+ TestHashJoinDictionaryHelper(
+ JoinType::FULL_OUTER, JoinKeyCmp::IS, false /*parallel*/,
+ // Input
+ use_scalar_dict ? DictScalarFromJSON(int16_utf8, "null", R"(["a", "b", null])")
+ : ScalarFromJSON(utf8(), "null"),
+ ArrayFromJSON(utf8(), R"(["x", "y"])"),
+ ArrayFromJSON(utf8(), R"(["a", null, "b"])"),
+ ArrayFromJSON(utf8(), R"(["p", "q", "r"])"),
+ // Expected output
+ ArrayFromJSON(utf8(), R"([null, null, null, null])"),
+ ArrayFromJSON(utf8(), R"(["x", "y", null, null])"),
+ ArrayFromJSON(utf8(), R"([null, null, "a", "b"])"),
+ ArrayFromJSON(utf8(), R"(["q", "q", "p", "r"])"), 2, swap_sides);
+ }
+ }
+
+ // Scalars with the empty build/probe side
+ for (auto use_scalar_dict : {false, true}) {
+ for (auto swap_sides : {false, true}) {
+ TestHashJoinDictionaryHelper(
+ JoinType::FULL_OUTER, JoinKeyCmp::EQ, false /*parallel*/,
+ // Input
+ use_scalar_dict ? DictScalarFromJSON(int8_utf8, "1", R"(["b", "a", "c"])")
+ : ScalarFromJSON(utf8(), "\"a\""),
+ ArrayFromJSON(utf8(), R"(["x", "y"])"), ArrayFromJSON(utf8(), R"([])"),
+ ArrayFromJSON(utf8(), R"([])"),
+ // Expected output
+ ArrayFromJSON(utf8(), R"(["a", "a"])"), ArrayFromJSON(utf8(), R"(["x", "y"])"),
+ ArrayFromJSON(utf8(), R"([null, null])"),
+ ArrayFromJSON(utf8(), R"([null, null])"), 0, swap_sides);
+ }
+ }
+
+ // Scalars vs dictionaries in key columns
+ for (auto use_scalar_dict : {false, true}) {
+ for (auto swap_sides : {false, true}) {
+ TestHashJoinDictionaryHelper(
+ JoinType::FULL_OUTER, JoinKeyCmp::EQ, false /*parallel*/,
+ // Input
+ use_scalar_dict ? DictScalarFromJSON(int32_utf8, "1", R"(["b", "a", "c"])")
+ : ScalarFromJSON(utf8(), "\"a\""),
+ ArrayFromJSON(utf8(), R"(["x", "y"])"),
+ DictArrayFromJSON(int32_utf8, R"([2, 2, 1])", R"(["b", null, "a"])"),
+ ArrayFromJSON(utf8(), R"(["p", "q", "r"])"),
+ // Expected output
+ ArrayFromJSON(utf8(), R"(["a", "a", "a", "a", null])"),
+ ArrayFromJSON(utf8(), R"(["x", "x", "y", "y", null])"),
+ ArrayFromJSON(utf8(), R"(["a", "a", "a", "a", null])"),
+ ArrayFromJSON(utf8(), R"(["p", "q", "p", "q", "r"])"), 1, swap_sides);
+ }
+ }
+}
+
+TEST(HashJoin, DictNegative) {
+ // For dictionary keys, all batches must share a single dictionary.
+ // Eventually, differing dictionaries will be unified and indices transposed
+ // during encoding to relieve this restriction.
+ const auto dictA = ArrayFromJSON(utf8(), R"(["ex", "why", "zee", null])");
+ const auto dictB = ArrayFromJSON(utf8(), R"(["different", "dictionary"])");
+
+ Datum datumFirst = Datum(
+ *DictionaryArray::FromArrays(ArrayFromJSON(int32(), R"([0, 1, 2, 3])"), dictA));
+ Datum datumSecondA = Datum(
+ *DictionaryArray::FromArrays(ArrayFromJSON(int32(), R"([3, 2, 2, 3])"), dictA));
+ Datum datumSecondB = Datum(
+ *DictionaryArray::FromArrays(ArrayFromJSON(int32(), R"([0, 1, 1, 0])"), dictB));
+
+ for (int i = 0; i < 4; ++i) {
+ BatchesWithSchema l, r;
+ l.schema = schema({field("l_key", dictionary(int32(), utf8())),
+ field("l_payload", dictionary(int32(), utf8()))});
+ r.schema = schema({field("r_key", dictionary(int32(), utf8())),
+ field("r_payload", dictionary(int32(), utf8()))});
+ l.batches.resize(2);
+ r.batches.resize(2);
+ ASSERT_OK_AND_ASSIGN(l.batches[0], ExecBatch::Make({datumFirst, datumFirst}));
+ ASSERT_OK_AND_ASSIGN(r.batches[0], ExecBatch::Make({datumFirst, datumFirst}));
+ ASSERT_OK_AND_ASSIGN(l.batches[1],
+ ExecBatch::Make({i == 0 ? datumSecondB : datumSecondA,
+ i == 1 ? datumSecondB : datumSecondA}));
+ ASSERT_OK_AND_ASSIGN(r.batches[1],
+ ExecBatch::Make({i == 2 ? datumSecondB : datumSecondA,
+ i == 3 ? datumSecondB : datumSecondA}));
+
+ auto exec_ctx =
+ arrow::internal::make_unique<ExecContext>(default_memory_pool(), nullptr);
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
+ ASSERT_OK_AND_ASSIGN(
+ ExecNode * l_source,
+ MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions{l.schema, l.gen(/*parallel=*/false,
+ /*slow=*/false)}));
+ ASSERT_OK_AND_ASSIGN(
+ ExecNode * r_source,
+ MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions{r.schema, r.gen(/*parallel=*/false,
+ /*slow=*/false)}));
+ HashJoinNodeOptions join_options{JoinType::INNER,
+ {FieldRef("l_key")},
+ {FieldRef("r_key")},
+ {FieldRef("l_key"), FieldRef("l_payload")},
+ {FieldRef("r_key"), FieldRef("r_payload")},
+ {JoinKeyCmp::EQ}};
+ ASSERT_OK_AND_ASSIGN(
+ ExecNode * join,
+ MakeExecNode("hashjoin", plan.get(), {l_source, r_source}, join_options));
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+ ASSERT_OK_AND_ASSIGN(std::ignore, MakeExecNode("sink", plan.get(), {join},
+ SinkNodeOptions{&sink_gen}));
+
+ EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(
+ NotImplemented, ::testing::HasSubstr("Unifying differing dictionaries"),
+ StartAndCollect(plan.get(), sink_gen));
+ }
+}
+
+TEST(HashJoin, UnsupportedTypes) {
+ // ARROW-14519
+ const bool parallel = false;
+ const bool slow = false;
+
+ auto l_schema = schema({field("l_i32", int32()), field("l_list", list(int32()))});
+ auto l_schema_nolist = schema({field("l_i32", int32())});
+ auto r_schema = schema({field("r_i32", int32()), field("r_list", list(int32()))});
+ auto r_schema_nolist = schema({field("r_i32", int32())});
+
+ std::vector<std::pair<std::shared_ptr<Schema>, std::shared_ptr<Schema>>> cases{
+ {l_schema, r_schema}, {l_schema_nolist, r_schema}, {l_schema, r_schema_nolist}};
+ std::vector<FieldRef> l_keys{{"l_i32"}};
+ std::vector<FieldRef> r_keys{{"r_i32"}};
+
+ for (const auto& schemas : cases) {
+ BatchesWithSchema l_batches = GenerateBatchesFromString(schemas.first, {R"([])"});
+ BatchesWithSchema r_batches = GenerateBatchesFromString(schemas.second, {R"([])"});
+
+ ExecContext exec_ctx;
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx));
+
+ HashJoinNodeOptions join_options{JoinType::LEFT_SEMI, l_keys, r_keys};
+ Declaration join{"hashjoin", join_options};
+ join.inputs.emplace_back(Declaration{
+ "source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel, slow)}});
+ join.inputs.emplace_back(Declaration{
+ "source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, slow)}});
+
+ ASSERT_RAISES(Invalid, join.AddToPlan(plan.get()));
+ }
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/key_compare.cc b/src/arrow/cpp/src/arrow/compute/exec/key_compare.cc
new file mode 100644
index 000000000..55b0e5e99
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/key_compare.cc
@@ -0,0 +1,424 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/key_compare.h"
+
+#include <memory.h>
+
+#include <algorithm>
+#include <cstdint>
+
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/ubsan.h"
+
+namespace arrow {
+namespace compute {
+
+template <bool use_selection>
+void KeyCompare::NullUpdateColumnToRow(uint32_t id_col, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null,
+ const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx,
+ const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows,
+ uint8_t* match_bytevector) {
+ if (!rows.has_any_nulls(ctx) && !col.data(0)) {
+ return;
+ }
+ uint32_t num_processed = 0;
+#if defined(ARROW_HAVE_AVX2)
+ if (ctx->has_avx2()) {
+ num_processed = NullUpdateColumnToRow_avx2(use_selection, id_col, num_rows_to_compare,
+ sel_left_maybe_null, left_to_right_map,
+ ctx, col, rows, match_bytevector);
+ }
+#endif
+
+ if (!col.data(0)) {
+ // Remove rows from the result for which the column value is a null
+ const uint8_t* null_masks = rows.null_masks();
+ uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
+ for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) {
+ uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+ uint32_t irow_right = left_to_right_map[irow_left];
+ int64_t bitid = irow_right * null_mask_num_bytes * 8 + id_col;
+ match_bytevector[i] &= (BitUtil::GetBit(null_masks, bitid) ? 0 : 0xff);
+ }
+ } else if (!rows.has_any_nulls(ctx)) {
+ // Remove rows from the result for which the column value on left side is null
+ const uint8_t* non_nulls = col.data(0);
+ ARROW_DCHECK(non_nulls);
+ for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) {
+ uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+ match_bytevector[i] &=
+ BitUtil::GetBit(non_nulls, irow_left + col.bit_offset(0)) ? 0xff : 0;
+ }
+ } else {
+ const uint8_t* null_masks = rows.null_masks();
+ uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
+ const uint8_t* non_nulls = col.data(0);
+ ARROW_DCHECK(non_nulls);
+ for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) {
+ uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+ uint32_t irow_right = left_to_right_map[irow_left];
+ int64_t bitid_right = irow_right * null_mask_num_bytes * 8 + id_col;
+ int right_null = BitUtil::GetBit(null_masks, bitid_right) ? 0xff : 0;
+ int left_null =
+ BitUtil::GetBit(non_nulls, irow_left + col.bit_offset(0)) ? 0 : 0xff;
+ match_bytevector[i] |= left_null & right_null;
+ match_bytevector[i] &= ~(left_null ^ right_null);
+ }
+ }
+}
+
+template <bool use_selection, class COMPARE_FN>
+void KeyCompare::CompareBinaryColumnToRowHelper(
+ uint32_t offset_within_row, uint32_t first_row_to_compare,
+ uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null,
+ const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx,
+ const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows,
+ uint8_t* match_bytevector, COMPARE_FN compare_fn) {
+ bool is_fixed_length = rows.metadata().is_fixed_length;
+ if (is_fixed_length) {
+ uint32_t fixed_length = rows.metadata().fixed_length;
+ const uint8_t* rows_left = col.data(1);
+ const uint8_t* rows_right = rows.data(1);
+ for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) {
+ uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+ uint32_t irow_right = left_to_right_map[irow_left];
+ uint32_t offset_right = irow_right * fixed_length + offset_within_row;
+ match_bytevector[i] = compare_fn(rows_left, rows_right, irow_left, offset_right);
+ }
+ } else {
+ const uint8_t* rows_left = col.data(1);
+ const uint32_t* offsets_right = rows.offsets();
+ const uint8_t* rows_right = rows.data(2);
+ for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) {
+ uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+ uint32_t irow_right = left_to_right_map[irow_left];
+ uint32_t offset_right = offsets_right[irow_right] + offset_within_row;
+ match_bytevector[i] = compare_fn(rows_left, rows_right, irow_left, offset_right);
+ }
+ }
+}
+
+template <bool use_selection>
+void KeyCompare::CompareBinaryColumnToRow(
+ uint32_t offset_within_row, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) {
+ uint32_t num_processed = 0;
+#if defined(ARROW_HAVE_AVX2)
+ if (ctx->has_avx2()) {
+ num_processed = CompareBinaryColumnToRow_avx2(
+ use_selection, offset_within_row, num_rows_to_compare, sel_left_maybe_null,
+ left_to_right_map, ctx, col, rows, match_bytevector);
+ }
+#endif
+
+ uint32_t col_width = col.metadata().fixed_length;
+ if (col_width == 0) {
+ int bit_offset = col.bit_offset(1);
+ CompareBinaryColumnToRowHelper<use_selection>(
+ offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null,
+ left_to_right_map, ctx, col, rows, match_bytevector,
+ [bit_offset](const uint8_t* left_base, const uint8_t* right_base,
+ uint32_t irow_left, uint32_t offset_right) {
+ uint8_t left = BitUtil::GetBit(left_base, irow_left + bit_offset) ? 0xff : 0x00;
+ uint8_t right = right_base[offset_right];
+ return left == right ? 0xff : 0;
+ });
+ } else if (col_width == 1) {
+ CompareBinaryColumnToRowHelper<use_selection>(
+ offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null,
+ left_to_right_map, ctx, col, rows, match_bytevector,
+ [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left,
+ uint32_t offset_right) {
+ uint8_t left = left_base[irow_left];
+ uint8_t right = right_base[offset_right];
+ return left == right ? 0xff : 0;
+ });
+ } else if (col_width == 2) {
+ CompareBinaryColumnToRowHelper<use_selection>(
+ offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null,
+ left_to_right_map, ctx, col, rows, match_bytevector,
+ [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left,
+ uint32_t offset_right) {
+ util::CheckAlignment<uint16_t>(left_base);
+ util::CheckAlignment<uint16_t>(right_base + offset_right);
+ uint16_t left = reinterpret_cast<const uint16_t*>(left_base)[irow_left];
+ uint16_t right = *reinterpret_cast<const uint16_t*>(right_base + offset_right);
+ return left == right ? 0xff : 0;
+ });
+ } else if (col_width == 4) {
+ CompareBinaryColumnToRowHelper<use_selection>(
+ offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null,
+ left_to_right_map, ctx, col, rows, match_bytevector,
+ [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left,
+ uint32_t offset_right) {
+ util::CheckAlignment<uint32_t>(left_base);
+ util::CheckAlignment<uint32_t>(right_base + offset_right);
+ uint32_t left = reinterpret_cast<const uint32_t*>(left_base)[irow_left];
+ uint32_t right = *reinterpret_cast<const uint32_t*>(right_base + offset_right);
+ return left == right ? 0xff : 0;
+ });
+ } else if (col_width == 8) {
+ CompareBinaryColumnToRowHelper<use_selection>(
+ offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null,
+ left_to_right_map, ctx, col, rows, match_bytevector,
+ [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left,
+ uint32_t offset_right) {
+ util::CheckAlignment<uint64_t>(left_base);
+ util::CheckAlignment<uint64_t>(right_base + offset_right);
+ uint64_t left = reinterpret_cast<const uint64_t*>(left_base)[irow_left];
+ uint64_t right = *reinterpret_cast<const uint64_t*>(right_base + offset_right);
+ return left == right ? 0xff : 0;
+ });
+ } else {
+ CompareBinaryColumnToRowHelper<use_selection>(
+ offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null,
+ left_to_right_map, ctx, col, rows, match_bytevector,
+ [&col](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left,
+ uint32_t offset_right) {
+ uint32_t length = col.metadata().fixed_length;
+
+ // Non-zero length guarantees no underflow
+ int32_t num_loops_less_one =
+ static_cast<int32_t>(BitUtil::CeilDiv(length, 8)) - 1;
+
+ uint64_t tail_mask = ~0ULL >> (64 - 8 * (length - num_loops_less_one * 8));
+
+ const uint64_t* key_left_ptr =
+ reinterpret_cast<const uint64_t*>(left_base + irow_left * length);
+ util::CheckAlignment<uint64_t>(right_base + offset_right);
+ const uint64_t* key_right_ptr =
+ reinterpret_cast<const uint64_t*>(right_base + offset_right);
+ uint64_t result_or = 0;
+ int32_t i;
+ // length cannot be zero
+ for (i = 0; i < num_loops_less_one; ++i) {
+ uint64_t key_left = util::SafeLoad(key_left_ptr + i);
+ uint64_t key_right = key_right_ptr[i];
+ result_or |= key_left ^ key_right;
+ }
+ uint64_t key_left = util::SafeLoad(key_left_ptr + i);
+ uint64_t key_right = key_right_ptr[i];
+ result_or |= tail_mask & (key_left ^ key_right);
+ return result_or == 0 ? 0xff : 0;
+ });
+ }
+}
+
+// Overwrites the match_bytevector instead of updating it
+template <bool use_selection, bool is_first_varbinary_col>
+void KeyCompare::CompareVarBinaryColumnToRow(
+ uint32_t id_varbinary_col, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) {
+#if defined(ARROW_HAVE_AVX2)
+ if (ctx->has_avx2()) {
+ CompareVarBinaryColumnToRow_avx2(
+ use_selection, is_first_varbinary_col, id_varbinary_col, num_rows_to_compare,
+ sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector);
+ return;
+ }
+#endif
+
+ const uint32_t* offsets_left = col.offsets();
+ const uint32_t* offsets_right = rows.offsets();
+ const uint8_t* rows_left = col.data(2);
+ const uint8_t* rows_right = rows.data(2);
+ for (uint32_t i = 0; i < num_rows_to_compare; ++i) {
+ uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+ uint32_t irow_right = left_to_right_map[irow_left];
+ uint32_t begin_left = offsets_left[irow_left];
+ uint32_t length_left = offsets_left[irow_left + 1] - begin_left;
+ uint32_t begin_right = offsets_right[irow_right];
+ uint32_t length_right;
+ uint32_t offset_within_row;
+ if (!is_first_varbinary_col) {
+ rows.metadata().nth_varbinary_offset_and_length(
+ rows_right + begin_right, id_varbinary_col, &offset_within_row, &length_right);
+ } else {
+ rows.metadata().first_varbinary_offset_and_length(
+ rows_right + begin_right, &offset_within_row, &length_right);
+ }
+ begin_right += offset_within_row;
+ uint32_t length = std::min(length_left, length_right);
+ const uint64_t* key_left_ptr =
+ reinterpret_cast<const uint64_t*>(rows_left + begin_left);
+ util::CheckAlignment<uint64_t>(rows_right + begin_right);
+ const uint64_t* key_right_ptr =
+ reinterpret_cast<const uint64_t*>(rows_right + begin_right);
+ uint64_t result_or = 0;
+ if (length > 0) {
+ int32_t j;
+ // length can be zero
+ for (j = 0; j < static_cast<int32_t>(BitUtil::CeilDiv(length, 8)) - 1; ++j) {
+ uint64_t key_left = util::SafeLoad(key_left_ptr + j);
+ uint64_t key_right = key_right_ptr[j];
+ result_or |= key_left ^ key_right;
+ }
+ uint64_t tail_mask = ~0ULL >> (64 - 8 * (length - j * 8));
+ uint64_t key_left = util::SafeLoad(key_left_ptr + j);
+ uint64_t key_right = key_right_ptr[j];
+ result_or |= tail_mask & (key_left ^ key_right);
+ }
+ int result = result_or == 0 ? 0xff : 0;
+ result *= (length_left == length_right ? 1 : 0);
+ match_bytevector[i] = result;
+ }
+}
+
+void KeyCompare::AndByteVectors(KeyEncoder::KeyEncoderContext* ctx, uint32_t num_elements,
+ uint8_t* bytevector_A, const uint8_t* bytevector_B) {
+ uint32_t num_processed = 0;
+#if defined(ARROW_HAVE_AVX2)
+ if (ctx->has_avx2()) {
+ num_processed = AndByteVectors_avx2(num_elements, bytevector_A, bytevector_B);
+ }
+#endif
+
+ for (uint32_t i = num_processed / 8; i < BitUtil::CeilDiv(num_elements, 8); ++i) {
+ uint64_t* a = reinterpret_cast<uint64_t*>(bytevector_A);
+ const uint64_t* b = reinterpret_cast<const uint64_t*>(bytevector_B);
+ a[i] &= b[i];
+ }
+}
+
+void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null,
+ const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx,
+ uint32_t* out_num_rows,
+ uint16_t* out_sel_left_maybe_same,
+ const std::vector<KeyEncoder::KeyColumnArray>& cols,
+ const KeyEncoder::KeyRowArray& rows) {
+ if (num_rows_to_compare == 0) {
+ *out_num_rows = 0;
+ return;
+ }
+
+ // Allocate temporary byte and bit vectors
+ auto bytevector_A_holder =
+ util::TempVectorHolder<uint8_t>(ctx->stack, num_rows_to_compare);
+ auto bytevector_B_holder =
+ util::TempVectorHolder<uint8_t>(ctx->stack, num_rows_to_compare);
+ auto bitvector_holder =
+ util::TempVectorHolder<uint8_t>(ctx->stack, num_rows_to_compare);
+
+ uint8_t* match_bytevector_A = bytevector_A_holder.mutable_data();
+ uint8_t* match_bytevector_B = bytevector_B_holder.mutable_data();
+ uint8_t* match_bitvector = bitvector_holder.mutable_data();
+
+ bool is_first_column = true;
+ for (size_t icol = 0; icol < cols.size(); ++icol) {
+ const KeyEncoder::KeyColumnArray& col = cols[icol];
+ uint32_t offset_within_row =
+ rows.metadata().encoded_field_offset(static_cast<uint32_t>(icol));
+ if (col.metadata().is_fixed_length) {
+ if (sel_left_maybe_null) {
+ CompareBinaryColumnToRow<true>(
+ offset_within_row, num_rows_to_compare, sel_left_maybe_null,
+ left_to_right_map, ctx, col, rows,
+ is_first_column ? match_bytevector_A : match_bytevector_B);
+ NullUpdateColumnToRow<true>(
+ static_cast<uint32_t>(icol), num_rows_to_compare, sel_left_maybe_null,
+ left_to_right_map, ctx, col, rows,
+ is_first_column ? match_bytevector_A : match_bytevector_B);
+ } else {
+ // Version without using selection vector
+ CompareBinaryColumnToRow<false>(
+ offset_within_row, num_rows_to_compare, sel_left_maybe_null,
+ left_to_right_map, ctx, col, rows,
+ is_first_column ? match_bytevector_A : match_bytevector_B);
+ NullUpdateColumnToRow<false>(
+ static_cast<uint32_t>(icol), num_rows_to_compare, sel_left_maybe_null,
+ left_to_right_map, ctx, col, rows,
+ is_first_column ? match_bytevector_A : match_bytevector_B);
+ }
+ if (!is_first_column) {
+ AndByteVectors(ctx, num_rows_to_compare, match_bytevector_A, match_bytevector_B);
+ }
+ is_first_column = false;
+ }
+ }
+
+ uint32_t ivarbinary = 0;
+ for (size_t icol = 0; icol < cols.size(); ++icol) {
+ const KeyEncoder::KeyColumnArray& col = cols[icol];
+ if (!col.metadata().is_fixed_length) {
+ // Process varbinary and nulls
+ if (sel_left_maybe_null) {
+ if (ivarbinary == 0) {
+ CompareVarBinaryColumnToRow<true, true>(
+ ivarbinary, num_rows_to_compare, sel_left_maybe_null, left_to_right_map,
+ ctx, col, rows, is_first_column ? match_bytevector_A : match_bytevector_B);
+ } else {
+ CompareVarBinaryColumnToRow<true, false>(ivarbinary, num_rows_to_compare,
+ sel_left_maybe_null, left_to_right_map,
+ ctx, col, rows, match_bytevector_B);
+ }
+ NullUpdateColumnToRow<true>(
+ static_cast<uint32_t>(icol), num_rows_to_compare, sel_left_maybe_null,
+ left_to_right_map, ctx, col, rows,
+ is_first_column ? match_bytevector_A : match_bytevector_B);
+ } else {
+ if (ivarbinary == 0) {
+ CompareVarBinaryColumnToRow<false, true>(
+ ivarbinary, num_rows_to_compare, sel_left_maybe_null, left_to_right_map,
+ ctx, col, rows, is_first_column ? match_bytevector_A : match_bytevector_B);
+ } else {
+ CompareVarBinaryColumnToRow<false, false>(
+ ivarbinary, num_rows_to_compare, sel_left_maybe_null, left_to_right_map,
+ ctx, col, rows, match_bytevector_B);
+ }
+ NullUpdateColumnToRow<false>(
+ static_cast<uint32_t>(icol), num_rows_to_compare, sel_left_maybe_null,
+ left_to_right_map, ctx, col, rows,
+ is_first_column ? match_bytevector_A : match_bytevector_B);
+ }
+ if (!is_first_column) {
+ AndByteVectors(ctx, num_rows_to_compare, match_bytevector_A, match_bytevector_B);
+ }
+ is_first_column = false;
+ ++ivarbinary;
+ }
+ }
+
+ util::BitUtil::bytes_to_bits(ctx->hardware_flags, num_rows_to_compare,
+ match_bytevector_A, match_bitvector);
+ if (sel_left_maybe_null) {
+ int out_num_rows_int;
+ util::BitUtil::bits_filter_indexes(0, ctx->hardware_flags, num_rows_to_compare,
+ match_bitvector, sel_left_maybe_null,
+ &out_num_rows_int, out_sel_left_maybe_same);
+ *out_num_rows = out_num_rows_int;
+ } else {
+ int out_num_rows_int;
+ util::BitUtil::bits_to_indexes(0, ctx->hardware_flags, num_rows_to_compare,
+ match_bitvector, &out_num_rows_int,
+ out_sel_left_maybe_same);
+ *out_num_rows = out_num_rows_int;
+ }
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/key_compare.h b/src/arrow/cpp/src/arrow/compute/exec/key_compare.h
new file mode 100644
index 000000000..aeb5abbdd
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/key_compare.h
@@ -0,0 +1,137 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+
+#include "arrow/compute/exec/key_encode.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/memory_pool.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+
+namespace arrow {
+namespace compute {
+
+class KeyCompare {
+ public:
+ // Returns a single 16-bit selection vector of rows that failed comparison.
+ // If there is input selection on the left, the resulting selection is a filtered image
+ // of input selection.
+ static void CompareColumnsToRows(uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null,
+ const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx,
+ uint32_t* out_num_rows,
+ uint16_t* out_sel_left_maybe_same,
+ const std::vector<KeyEncoder::KeyColumnArray>& cols,
+ const KeyEncoder::KeyRowArray& rows);
+
+ private:
+ template <bool use_selection>
+ static void NullUpdateColumnToRow(uint32_t id_col, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null,
+ const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx,
+ const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows,
+ uint8_t* match_bytevector);
+
+ template <bool use_selection, class COMPARE_FN>
+ static void CompareBinaryColumnToRowHelper(
+ uint32_t offset_within_row, uint32_t first_row_to_compare,
+ uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null,
+ const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx,
+ const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows,
+ uint8_t* match_bytevector, COMPARE_FN compare_fn);
+
+ template <bool use_selection>
+ static void CompareBinaryColumnToRow(
+ uint32_t offset_within_row, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector);
+
+ template <bool use_selection, bool is_first_varbinary_col>
+ static void CompareVarBinaryColumnToRow(
+ uint32_t id_varlen_col, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector);
+
+ static void AndByteVectors(KeyEncoder::KeyEncoderContext* ctx, uint32_t num_elements,
+ uint8_t* bytevector_A, const uint8_t* bytevector_B);
+
+#if defined(ARROW_HAVE_AVX2)
+
+ template <bool use_selection>
+ static uint32_t NullUpdateColumnToRowImp_avx2(
+ uint32_t id_col, uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null,
+ const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx,
+ const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows,
+ uint8_t* match_bytevector);
+
+ template <bool use_selection, class COMPARE8_FN>
+ static uint32_t CompareBinaryColumnToRowHelper_avx2(
+ uint32_t offset_within_row, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector,
+ COMPARE8_FN compare8_fn);
+
+ template <bool use_selection>
+ static uint32_t CompareBinaryColumnToRowImp_avx2(
+ uint32_t offset_within_row, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector);
+
+ template <bool use_selection, bool is_first_varbinary_col>
+ static void CompareVarBinaryColumnToRowImp_avx2(
+ uint32_t id_varlen_col, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector);
+
+ static uint32_t AndByteVectors_avx2(uint32_t num_elements, uint8_t* bytevector_A,
+ const uint8_t* bytevector_B);
+
+ static uint32_t NullUpdateColumnToRow_avx2(
+ bool use_selection, uint32_t id_col, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector);
+
+ static uint32_t CompareBinaryColumnToRow_avx2(
+ bool use_selection, uint32_t offset_within_row, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector);
+
+ static void CompareVarBinaryColumnToRow_avx2(
+ bool use_selection, bool is_first_varbinary_col, uint32_t id_varlen_col,
+ uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null,
+ const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx,
+ const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows,
+ uint8_t* match_bytevector);
+
+#endif
+};
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/key_compare_avx2.cc b/src/arrow/cpp/src/arrow/compute/exec/key_compare_avx2.cc
new file mode 100644
index 000000000..df13e8cae
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/key_compare_avx2.cc
@@ -0,0 +1,633 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <immintrin.h>
+
+#include "arrow/compute/exec/key_compare.h"
+#include "arrow/util/bit_util.h"
+
+namespace arrow {
+namespace compute {
+
+#if defined(ARROW_HAVE_AVX2)
+
+inline __m256i set_first_n_bytes_avx2(int n) {
+ constexpr uint64_t kByteSequence0To7 = 0x0706050403020100ULL;
+ constexpr uint64_t kByteSequence8To15 = 0x0f0e0d0c0b0a0908ULL;
+ constexpr uint64_t kByteSequence16To23 = 0x1716151413121110ULL;
+ constexpr uint64_t kByteSequence24To31 = 0x1f1e1d1c1b1a1918ULL;
+
+ return _mm256_cmpgt_epi8(_mm256_set1_epi8(n),
+ _mm256_setr_epi64x(kByteSequence0To7, kByteSequence8To15,
+ kByteSequence16To23, kByteSequence24To31));
+}
+
+template <bool use_selection>
+uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2(
+ uint32_t id_col, uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null,
+ const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx,
+ const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows,
+ uint8_t* match_bytevector) {
+ if (!rows.has_any_nulls(ctx) && !col.data(0)) {
+ return num_rows_to_compare;
+ }
+ if (!col.data(0)) {
+ // Remove rows from the result for which the column value is a null
+ const uint8_t* null_masks = rows.null_masks();
+ uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
+
+ uint32_t num_processed = 0;
+ constexpr uint32_t unroll = 8;
+ for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) {
+ __m256i irow_right;
+ if (use_selection) {
+ __m256i irow_left = _mm256_cvtepu16_epi32(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(sel_left_maybe_null) + i));
+ irow_right = _mm256_i32gather_epi32((const int*)left_to_right_map, irow_left, 4);
+ } else {
+ irow_right =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_to_right_map) + i);
+ }
+ __m256i bitid =
+ _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(null_mask_num_bytes * 8));
+ bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(id_col));
+ __m256i right =
+ _mm256_i32gather_epi32((const int*)null_masks, _mm256_srli_epi32(bitid, 3), 1);
+ right = _mm256_and_si256(
+ _mm256_set1_epi32(1),
+ _mm256_srlv_epi32(right, _mm256_and_si256(bitid, _mm256_set1_epi32(7))));
+ __m256i cmp = _mm256_cmpeq_epi32(right, _mm256_setzero_si256());
+ uint32_t result_lo =
+ _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp)));
+ uint32_t result_hi =
+ _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1)));
+ reinterpret_cast<uint64_t*>(match_bytevector)[i] &=
+ result_lo | (static_cast<uint64_t>(result_hi) << 32);
+ }
+ num_processed = num_rows_to_compare / unroll * unroll;
+ return num_processed;
+ } else if (!rows.has_any_nulls(ctx)) {
+ // Remove rows from the result for which the column value on left side is null
+ const uint8_t* non_nulls = col.data(0);
+ ARROW_DCHECK(non_nulls);
+ uint32_t num_processed = 0;
+ constexpr uint32_t unroll = 8;
+ for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) {
+ __m256i cmp;
+ if (use_selection) {
+ __m256i irow_left = _mm256_cvtepu16_epi32(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(sel_left_maybe_null) + i));
+ irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(col.bit_offset(0)));
+ __m256i left = _mm256_i32gather_epi32((const int*)non_nulls,
+ _mm256_srli_epi32(irow_left, 3), 1);
+ left = _mm256_and_si256(
+ _mm256_set1_epi32(1),
+ _mm256_srlv_epi32(left, _mm256_and_si256(irow_left, _mm256_set1_epi32(7))));
+ cmp = _mm256_cmpeq_epi32(left, _mm256_set1_epi32(1));
+ } else {
+ __m256i left = _mm256_cvtepu8_epi32(_mm_set1_epi8(static_cast<uint8_t>(
+ reinterpret_cast<const uint16_t*>(non_nulls + i)[0] >> col.bit_offset(0))));
+ __m256i bits = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128);
+ cmp = _mm256_cmpeq_epi32(_mm256_and_si256(left, bits), bits);
+ }
+ uint32_t result_lo =
+ _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp)));
+ uint32_t result_hi =
+ _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1)));
+ reinterpret_cast<uint64_t*>(match_bytevector)[i] &=
+ result_lo | (static_cast<uint64_t>(result_hi) << 32);
+ num_processed = num_rows_to_compare / unroll * unroll;
+ }
+ return num_processed;
+ } else {
+ const uint8_t* null_masks = rows.null_masks();
+ uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
+ const uint8_t* non_nulls = col.data(0);
+ ARROW_DCHECK(non_nulls);
+
+ uint32_t num_processed = 0;
+ constexpr uint32_t unroll = 8;
+ for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) {
+ __m256i left_null;
+ __m256i irow_right;
+ if (use_selection) {
+ __m256i irow_left = _mm256_cvtepu16_epi32(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(sel_left_maybe_null) + i));
+ irow_right = _mm256_i32gather_epi32((const int*)left_to_right_map, irow_left, 4);
+ irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(col.bit_offset(0)));
+ __m256i left = _mm256_i32gather_epi32((const int*)non_nulls,
+ _mm256_srli_epi32(irow_left, 3), 1);
+ left = _mm256_and_si256(
+ _mm256_set1_epi32(1),
+ _mm256_srlv_epi32(left, _mm256_and_si256(irow_left, _mm256_set1_epi32(7))));
+ left_null = _mm256_cmpeq_epi32(left, _mm256_setzero_si256());
+ } else {
+ irow_right =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_to_right_map) + i);
+ __m256i left = _mm256_cvtepu8_epi32(_mm_set1_epi8(static_cast<uint8_t>(
+ reinterpret_cast<const uint16_t*>(non_nulls + i)[0] >> col.bit_offset(0))));
+ __m256i bits = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128);
+ left_null =
+ _mm256_cmpeq_epi32(_mm256_and_si256(left, bits), _mm256_setzero_si256());
+ }
+ __m256i bitid =
+ _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(null_mask_num_bytes * 8));
+ bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(id_col));
+ __m256i right =
+ _mm256_i32gather_epi32((const int*)null_masks, _mm256_srli_epi32(bitid, 3), 1);
+ right = _mm256_and_si256(
+ _mm256_set1_epi32(1),
+ _mm256_srlv_epi32(right, _mm256_and_si256(bitid, _mm256_set1_epi32(7))));
+ __m256i right_null = _mm256_cmpeq_epi32(right, _mm256_set1_epi32(1));
+
+ uint64_t left_null_64 =
+ static_cast<uint32_t>(_mm256_movemask_epi8(
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(left_null)))) |
+ (static_cast<uint64_t>(static_cast<uint32_t>(_mm256_movemask_epi8(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(left_null, 1)))))
+ << 32);
+
+ uint64_t right_null_64 =
+ static_cast<uint32_t>(_mm256_movemask_epi8(
+ _mm256_cvtepi32_epi64(_mm256_castsi256_si128(right_null)))) |
+ (static_cast<uint64_t>(static_cast<uint32_t>(_mm256_movemask_epi8(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_null, 1)))))
+ << 32);
+
+ reinterpret_cast<uint64_t*>(match_bytevector)[i] |= left_null_64 & right_null_64;
+ reinterpret_cast<uint64_t*>(match_bytevector)[i] &= ~(left_null_64 ^ right_null_64);
+ }
+ num_processed = num_rows_to_compare / unroll * unroll;
+ return num_processed;
+ }
+}
+
+template <bool use_selection, class COMPARE8_FN>
+uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2(
+ uint32_t offset_within_row, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector,
+ COMPARE8_FN compare8_fn) {
+ bool is_fixed_length = rows.metadata().is_fixed_length;
+ if (is_fixed_length) {
+ uint32_t fixed_length = rows.metadata().fixed_length;
+ const uint8_t* rows_left = col.data(1);
+ const uint8_t* rows_right = rows.data(1);
+ constexpr uint32_t unroll = 8;
+ __m256i irow_left = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+ for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) {
+ if (use_selection) {
+ irow_left = _mm256_cvtepu16_epi32(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(sel_left_maybe_null) + i));
+ }
+ __m256i irow_right;
+ if (use_selection) {
+ irow_right = _mm256_i32gather_epi32((const int*)left_to_right_map, irow_left, 4);
+ } else {
+ irow_right =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_to_right_map) + i);
+ }
+
+ __m256i offset_right =
+ _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(fixed_length));
+ offset_right = _mm256_add_epi32(offset_right, _mm256_set1_epi32(offset_within_row));
+
+ reinterpret_cast<uint64_t*>(match_bytevector)[i] =
+ compare8_fn(rows_left, rows_right, i * unroll, irow_left, offset_right);
+
+ if (!use_selection) {
+ irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(8));
+ }
+ }
+ return num_rows_to_compare - (num_rows_to_compare % unroll);
+ } else {
+ const uint8_t* rows_left = col.data(1);
+ const uint32_t* offsets_right = rows.offsets();
+ const uint8_t* rows_right = rows.data(2);
+ constexpr uint32_t unroll = 8;
+ __m256i irow_left = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+ for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) {
+ if (use_selection) {
+ irow_left = _mm256_cvtepu16_epi32(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(sel_left_maybe_null) + i));
+ }
+ __m256i irow_right;
+ if (use_selection) {
+ irow_right = _mm256_i32gather_epi32((const int*)left_to_right_map, irow_left, 4);
+ } else {
+ irow_right =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_to_right_map) + i);
+ }
+ __m256i offset_right =
+ _mm256_i32gather_epi32((const int*)offsets_right, irow_right, 4);
+ offset_right = _mm256_add_epi32(offset_right, _mm256_set1_epi32(offset_within_row));
+
+ reinterpret_cast<uint64_t*>(match_bytevector)[i] =
+ compare8_fn(rows_left, rows_right, i * unroll, irow_left, offset_right);
+
+ if (!use_selection) {
+ irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(8));
+ }
+ }
+ return num_rows_to_compare - (num_rows_to_compare % unroll);
+ }
+}
+
+template <int column_width>
+inline uint64_t CompareSelected8_avx2(const uint8_t* left_base, const uint8_t* right_base,
+ __m256i irow_left, __m256i offset_right,
+ int bit_offset = 0) {
+ __m256i left;
+ switch (column_width) {
+ case 0:
+ irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(bit_offset));
+ left = _mm256_i32gather_epi32((const int*)left_base,
+ _mm256_srli_epi32(irow_left, 3), 1);
+ left = _mm256_and_si256(
+ _mm256_set1_epi32(1),
+ _mm256_srlv_epi32(left, _mm256_and_si256(irow_left, _mm256_set1_epi32(7))));
+ left = _mm256_mullo_epi32(left, _mm256_set1_epi32(0xff));
+ break;
+ case 1:
+ left = _mm256_i32gather_epi32((const int*)left_base, irow_left, 1);
+ left = _mm256_and_si256(left, _mm256_set1_epi32(0xff));
+ break;
+ case 2:
+ left = _mm256_i32gather_epi32((const int*)left_base, irow_left, 2);
+ left = _mm256_and_si256(left, _mm256_set1_epi32(0xff));
+ break;
+ case 4:
+ left = _mm256_i32gather_epi32((const int*)left_base, irow_left, 4);
+ break;
+ default:
+ ARROW_DCHECK(false);
+ }
+
+ __m256i right = _mm256_i32gather_epi32((const int*)right_base, offset_right, 1);
+ if (column_width != sizeof(uint32_t)) {
+ constexpr uint32_t mask = column_width == 0 || column_width == 1 ? 0xff : 0xffff;
+ right = _mm256_and_si256(right, _mm256_set1_epi32(mask));
+ }
+
+ __m256i cmp = _mm256_cmpeq_epi32(left, right);
+
+ uint32_t result_lo =
+ _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp)));
+ uint32_t result_hi =
+ _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1)));
+
+ return result_lo | (static_cast<uint64_t>(result_hi) << 32);
+}
+
+template <int column_width>
+inline uint64_t Compare8_avx2(const uint8_t* left_base, const uint8_t* right_base,
+ uint32_t irow_left_first, __m256i offset_right,
+ int bit_offset = 0) {
+ __m256i left;
+ switch (column_width) {
+ case 0: {
+ __m256i bits = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128);
+ uint32_t start_bit_index = irow_left_first + bit_offset;
+ uint8_t left_bits_8 =
+ (reinterpret_cast<const uint16_t*>(left_base + start_bit_index / 8)[0] >>
+ (start_bit_index % 8)) &
+ 0xff;
+ left =
+ _mm256_cmpeq_epi32(_mm256_and_si256(bits, _mm256_set1_epi8(left_bits_8)), bits);
+ left = _mm256_and_si256(left, _mm256_set1_epi32(0xff));
+ } break;
+ case 1:
+ left = _mm256_cvtepu8_epi32(_mm_set1_epi64x(
+ reinterpret_cast<const uint64_t*>(left_base)[irow_left_first / 8]));
+ break;
+ case 2:
+ left = _mm256_cvtepu16_epi32(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(left_base) + irow_left_first / 8));
+ break;
+ case 4:
+ left = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_base) +
+ irow_left_first / 8);
+ break;
+ default:
+ ARROW_DCHECK(false);
+ }
+
+ __m256i right = _mm256_i32gather_epi32((const int*)right_base, offset_right, 1);
+ if (column_width != sizeof(uint32_t)) {
+ constexpr uint32_t mask = column_width == 0 || column_width == 1 ? 0xff : 0xffff;
+ right = _mm256_and_si256(right, _mm256_set1_epi32(mask));
+ }
+
+ __m256i cmp = _mm256_cmpeq_epi32(left, right);
+
+ uint32_t result_lo =
+ _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp)));
+ uint32_t result_hi =
+ _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1)));
+
+ return result_lo | (static_cast<uint64_t>(result_hi) << 32);
+}
+
+template <bool use_selection>
+inline uint64_t Compare8_64bit_avx2(const uint8_t* left_base, const uint8_t* right_base,
+ __m256i irow_left, uint32_t irow_left_first,
+ __m256i offset_right) {
+ auto left_base_i64 =
+ reinterpret_cast<const arrow::util::int64_for_gather_t*>(left_base);
+ __m256i left_lo =
+ _mm256_i32gather_epi64(left_base_i64, _mm256_castsi256_si128(irow_left), 8);
+ __m256i left_hi =
+ _mm256_i32gather_epi64(left_base_i64, _mm256_extracti128_si256(irow_left, 1), 8);
+ if (use_selection) {
+ left_lo = _mm256_i32gather_epi64(left_base_i64, _mm256_castsi256_si128(irow_left), 8);
+ left_hi =
+ _mm256_i32gather_epi64(left_base_i64, _mm256_extracti128_si256(irow_left, 1), 8);
+ } else {
+ left_lo = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_base) +
+ irow_left_first / 4);
+ left_hi = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(left_base) +
+ irow_left_first / 4 + 1);
+ }
+ auto right_base_i64 =
+ reinterpret_cast<const arrow::util::int64_for_gather_t*>(right_base);
+ __m256i right_lo =
+ _mm256_i32gather_epi64(right_base_i64, _mm256_castsi256_si128(offset_right), 1);
+ __m256i right_hi = _mm256_i32gather_epi64(right_base_i64,
+ _mm256_extracti128_si256(offset_right, 1), 1);
+ uint32_t result_lo = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_lo, right_lo));
+ uint32_t result_hi = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_hi, right_hi));
+ return result_lo | (static_cast<uint64_t>(result_hi) << 32);
+}
+
+template <bool use_selection>
+inline uint64_t Compare8_Binary_avx2(uint32_t length, const uint8_t* left_base,
+ const uint8_t* right_base, __m256i irow_left,
+ uint32_t irow_left_first, __m256i offset_right) {
+ uint32_t irow_left_array[8];
+ uint32_t offset_right_array[8];
+ if (use_selection) {
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(irow_left_array), irow_left);
+ }
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(offset_right_array), offset_right);
+
+ // Non-zero length guarantees no underflow
+ int32_t num_loops_less_one = (static_cast<int32_t>(length) + 31) / 32 - 1;
+
+ __m256i tail_mask = set_first_n_bytes_avx2(length - num_loops_less_one * 32);
+
+ uint64_t result = 0;
+ for (uint32_t irow = 0; irow < 8; ++irow) {
+ const __m256i* key_left_ptr = reinterpret_cast<const __m256i*>(
+ left_base +
+ (use_selection ? irow_left_array[irow] : irow_left_first + irow) * length);
+ const __m256i* key_right_ptr =
+ reinterpret_cast<const __m256i*>(right_base + offset_right_array[irow]);
+ __m256i result_or = _mm256_setzero_si256();
+ int32_t i;
+ // length cannot be zero
+ for (i = 0; i < num_loops_less_one; ++i) {
+ __m256i key_left = _mm256_loadu_si256(key_left_ptr + i);
+ __m256i key_right = _mm256_loadu_si256(key_right_ptr + i);
+ result_or = _mm256_or_si256(result_or, _mm256_xor_si256(key_left, key_right));
+ }
+ __m256i key_left = _mm256_loadu_si256(key_left_ptr + i);
+ __m256i key_right = _mm256_loadu_si256(key_right_ptr + i);
+ result_or = _mm256_or_si256(
+ result_or, _mm256_and_si256(tail_mask, _mm256_xor_si256(key_left, key_right)));
+ uint64_t result_single = _mm256_testz_si256(result_or, result_or) * 0xff;
+ result |= result_single << (8 * irow);
+ }
+ return result;
+}
+
+template <bool use_selection>
+uint32_t KeyCompare::CompareBinaryColumnToRowImp_avx2(
+ uint32_t offset_within_row, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) {
+ uint32_t col_width = col.metadata().fixed_length;
+ if (col_width == 0) {
+ int bit_offset = col.bit_offset(1);
+ return CompareBinaryColumnToRowHelper_avx2<use_selection>(
+ offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map,
+ ctx, col, rows, match_bytevector,
+ [bit_offset](const uint8_t* left_base, const uint8_t* right_base,
+ uint32_t irow_left_base, __m256i irow_left, __m256i offset_right) {
+ if (use_selection) {
+ return CompareSelected8_avx2<0>(left_base, right_base, irow_left,
+ offset_right, bit_offset);
+ } else {
+ return Compare8_avx2<0>(left_base, right_base, irow_left_base, offset_right,
+ bit_offset);
+ }
+ });
+ } else if (col_width == 1) {
+ return CompareBinaryColumnToRowHelper_avx2<use_selection>(
+ offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map,
+ ctx, col, rows, match_bytevector,
+ [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base,
+ __m256i irow_left, __m256i offset_right) {
+ if (use_selection) {
+ return CompareSelected8_avx2<1>(left_base, right_base, irow_left,
+ offset_right);
+ } else {
+ return Compare8_avx2<1>(left_base, right_base, irow_left_base, offset_right);
+ }
+ });
+ } else if (col_width == 2) {
+ return CompareBinaryColumnToRowHelper_avx2<use_selection>(
+ offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map,
+ ctx, col, rows, match_bytevector,
+ [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base,
+ __m256i irow_left, __m256i offset_right) {
+ if (use_selection) {
+ return CompareSelected8_avx2<2>(left_base, right_base, irow_left,
+ offset_right);
+ } else {
+ return Compare8_avx2<2>(left_base, right_base, irow_left_base, offset_right);
+ }
+ });
+ } else if (col_width == 4) {
+ return CompareBinaryColumnToRowHelper_avx2<use_selection>(
+ offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map,
+ ctx, col, rows, match_bytevector,
+ [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base,
+ __m256i irow_left, __m256i offset_right) {
+ if (use_selection) {
+ return CompareSelected8_avx2<4>(left_base, right_base, irow_left,
+ offset_right);
+ } else {
+ return Compare8_avx2<4>(left_base, right_base, irow_left_base, offset_right);
+ }
+ });
+ } else if (col_width == 8) {
+ return CompareBinaryColumnToRowHelper_avx2<use_selection>(
+ offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map,
+ ctx, col, rows, match_bytevector,
+ [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base,
+ __m256i irow_left, __m256i offset_right) {
+ return Compare8_64bit_avx2<use_selection>(left_base, right_base, irow_left,
+ irow_left_base, offset_right);
+ });
+ } else {
+ return CompareBinaryColumnToRowHelper_avx2<use_selection>(
+ offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map,
+ ctx, col, rows, match_bytevector,
+ [&col](const uint8_t* left_base, const uint8_t* right_base,
+ uint32_t irow_left_base, __m256i irow_left, __m256i offset_right) {
+ uint32_t length = col.metadata().fixed_length;
+ return Compare8_Binary_avx2<use_selection>(
+ length, left_base, right_base, irow_left, irow_left_base, offset_right);
+ });
+ }
+}
+
+// Overwrites the match_bytevector instead of updating it
+template <bool use_selection, bool is_first_varbinary_col>
+void KeyCompare::CompareVarBinaryColumnToRowImp_avx2(
+ uint32_t id_varbinary_col, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) {
+ const uint32_t* offsets_left = col.offsets();
+ const uint32_t* offsets_right = rows.offsets();
+ const uint8_t* rows_left = col.data(2);
+ const uint8_t* rows_right = rows.data(2);
+ for (uint32_t i = 0; i < num_rows_to_compare; ++i) {
+ uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+ uint32_t irow_right = left_to_right_map[irow_left];
+ uint32_t begin_left = offsets_left[irow_left];
+ uint32_t length_left = offsets_left[irow_left + 1] - begin_left;
+ uint32_t begin_right = offsets_right[irow_right];
+ uint32_t length_right;
+ uint32_t offset_within_row;
+ if (!is_first_varbinary_col) {
+ rows.metadata().nth_varbinary_offset_and_length(
+ rows_right + begin_right, id_varbinary_col, &offset_within_row, &length_right);
+ } else {
+ rows.metadata().first_varbinary_offset_and_length(
+ rows_right + begin_right, &offset_within_row, &length_right);
+ }
+ begin_right += offset_within_row;
+
+ __m256i result_or = _mm256_setzero_si256();
+ uint32_t length = std::min(length_left, length_right);
+ if (length > 0) {
+ const __m256i* key_left_ptr =
+ reinterpret_cast<const __m256i*>(rows_left + begin_left);
+ const __m256i* key_right_ptr =
+ reinterpret_cast<const __m256i*>(rows_right + begin_right);
+ int32_t j;
+ // length can be zero
+ for (j = 0; j < (static_cast<int32_t>(length) + 31) / 32 - 1; ++j) {
+ __m256i key_left = _mm256_loadu_si256(key_left_ptr + j);
+ __m256i key_right = _mm256_loadu_si256(key_right_ptr + j);
+ result_or = _mm256_or_si256(result_or, _mm256_xor_si256(key_left, key_right));
+ }
+
+ __m256i tail_mask = set_first_n_bytes_avx2(length - j * 32);
+
+ __m256i key_left = _mm256_loadu_si256(key_left_ptr + j);
+ __m256i key_right = _mm256_loadu_si256(key_right_ptr + j);
+ result_or = _mm256_or_si256(
+ result_or, _mm256_and_si256(tail_mask, _mm256_xor_si256(key_left, key_right)));
+ }
+ int result = _mm256_testz_si256(result_or, result_or) * 0xff;
+ result *= (length_left == length_right ? 1 : 0);
+ match_bytevector[i] = result;
+ }
+}
+
+uint32_t KeyCompare::AndByteVectors_avx2(uint32_t num_elements, uint8_t* bytevector_A,
+ const uint8_t* bytevector_B) {
+ constexpr int unroll = 32;
+ for (uint32_t i = 0; i < num_elements / unroll; ++i) {
+ __m256i result = _mm256_and_si256(
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bytevector_A) + i),
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bytevector_B) + i));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(bytevector_A) + i, result);
+ }
+ return (num_elements - (num_elements % unroll));
+}
+
+uint32_t KeyCompare::NullUpdateColumnToRow_avx2(
+ bool use_selection, uint32_t id_col, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) {
+ if (use_selection) {
+ return NullUpdateColumnToRowImp_avx2<true>(id_col, num_rows_to_compare,
+ sel_left_maybe_null, left_to_right_map,
+ ctx, col, rows, match_bytevector);
+ } else {
+ return NullUpdateColumnToRowImp_avx2<false>(id_col, num_rows_to_compare,
+ sel_left_maybe_null, left_to_right_map,
+ ctx, col, rows, match_bytevector);
+ }
+}
+
+uint32_t KeyCompare::CompareBinaryColumnToRow_avx2(
+ bool use_selection, uint32_t offset_within_row, uint32_t num_rows_to_compare,
+ const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+ KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+ const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) {
+ if (use_selection) {
+ return CompareBinaryColumnToRowImp_avx2<true>(offset_within_row, num_rows_to_compare,
+ sel_left_maybe_null, left_to_right_map,
+ ctx, col, rows, match_bytevector);
+ } else {
+ return CompareBinaryColumnToRowImp_avx2<false>(offset_within_row, num_rows_to_compare,
+ sel_left_maybe_null, left_to_right_map,
+ ctx, col, rows, match_bytevector);
+ }
+}
+
+void KeyCompare::CompareVarBinaryColumnToRow_avx2(
+ bool use_selection, bool is_first_varbinary_col, uint32_t id_varlen_col,
+ uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null,
+ const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx,
+ const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows,
+ uint8_t* match_bytevector) {
+ if (use_selection) {
+ if (is_first_varbinary_col) {
+ CompareVarBinaryColumnToRowImp_avx2<true, true>(
+ id_varlen_col, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx,
+ col, rows, match_bytevector);
+ } else {
+ CompareVarBinaryColumnToRowImp_avx2<true, false>(
+ id_varlen_col, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx,
+ col, rows, match_bytevector);
+ }
+ } else {
+ if (is_first_varbinary_col) {
+ CompareVarBinaryColumnToRowImp_avx2<false, true>(
+ id_varlen_col, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx,
+ col, rows, match_bytevector);
+ } else {
+ CompareVarBinaryColumnToRowImp_avx2<false, false>(
+ id_varlen_col, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx,
+ col, rows, match_bytevector);
+ }
+ }
+}
+
+#endif
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/key_encode.cc b/src/arrow/cpp/src/arrow/compute/exec/key_encode.cc
new file mode 100644
index 000000000..8ab76cd27
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/key_encode.cc
@@ -0,0 +1,1341 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/key_encode.h"
+
+#include <memory.h>
+
+#include <algorithm>
+
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/ubsan.h"
+
+namespace arrow {
+namespace compute {
+
+KeyEncoder::KeyRowArray::KeyRowArray()
+ : pool_(nullptr), rows_capacity_(0), bytes_capacity_(0) {}
+
+Status KeyEncoder::KeyRowArray::Init(MemoryPool* pool, const KeyRowMetadata& metadata) {
+ pool_ = pool;
+ metadata_ = metadata;
+
+ DCHECK(!null_masks_ && !offsets_ && !rows_);
+
+ constexpr int64_t rows_capacity = 8;
+ constexpr int64_t bytes_capacity = 1024;
+
+ // Null masks
+ ARROW_ASSIGN_OR_RAISE(auto null_masks,
+ AllocateResizableBuffer(size_null_masks(rows_capacity), pool_));
+ null_masks_ = std::move(null_masks);
+ memset(null_masks_->mutable_data(), 0, size_null_masks(rows_capacity));
+
+ // Offsets and rows
+ if (!metadata.is_fixed_length) {
+ ARROW_ASSIGN_OR_RAISE(auto offsets,
+ AllocateResizableBuffer(size_offsets(rows_capacity), pool_));
+ offsets_ = std::move(offsets);
+ memset(offsets_->mutable_data(), 0, size_offsets(rows_capacity));
+ reinterpret_cast<uint32_t*>(offsets_->mutable_data())[0] = 0;
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto rows,
+ AllocateResizableBuffer(size_rows_varying_length(bytes_capacity), pool_));
+ rows_ = std::move(rows);
+ memset(rows_->mutable_data(), 0, size_rows_varying_length(bytes_capacity));
+ bytes_capacity_ = size_rows_varying_length(bytes_capacity) - padding_for_vectors;
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ auto rows, AllocateResizableBuffer(size_rows_fixed_length(rows_capacity), pool_));
+ rows_ = std::move(rows);
+ memset(rows_->mutable_data(), 0, size_rows_fixed_length(rows_capacity));
+ bytes_capacity_ = size_rows_fixed_length(rows_capacity) - padding_for_vectors;
+ }
+
+ update_buffer_pointers();
+
+ rows_capacity_ = rows_capacity;
+
+ num_rows_ = 0;
+ num_rows_for_has_any_nulls_ = 0;
+ has_any_nulls_ = false;
+
+ return Status::OK();
+}
+
+void KeyEncoder::KeyRowArray::Clean() {
+ num_rows_ = 0;
+ num_rows_for_has_any_nulls_ = 0;
+ has_any_nulls_ = false;
+
+ if (!metadata_.is_fixed_length) {
+ reinterpret_cast<uint32_t*>(offsets_->mutable_data())[0] = 0;
+ }
+}
+
+int64_t KeyEncoder::KeyRowArray::size_null_masks(int64_t num_rows) {
+ return num_rows * metadata_.null_masks_bytes_per_row + padding_for_vectors;
+}
+
+int64_t KeyEncoder::KeyRowArray::size_offsets(int64_t num_rows) {
+ return (num_rows + 1) * sizeof(uint32_t) + padding_for_vectors;
+}
+
+int64_t KeyEncoder::KeyRowArray::size_rows_fixed_length(int64_t num_rows) {
+ return num_rows * metadata_.fixed_length + padding_for_vectors;
+}
+
+int64_t KeyEncoder::KeyRowArray::size_rows_varying_length(int64_t num_bytes) {
+ return num_bytes + padding_for_vectors;
+}
+
+void KeyEncoder::KeyRowArray::update_buffer_pointers() {
+ buffers_[0] = mutable_buffers_[0] = null_masks_->mutable_data();
+ if (metadata_.is_fixed_length) {
+ buffers_[1] = mutable_buffers_[1] = rows_->mutable_data();
+ buffers_[2] = mutable_buffers_[2] = nullptr;
+ } else {
+ buffers_[1] = mutable_buffers_[1] = offsets_->mutable_data();
+ buffers_[2] = mutable_buffers_[2] = rows_->mutable_data();
+ }
+}
+
+Status KeyEncoder::KeyRowArray::ResizeFixedLengthBuffers(int64_t num_extra_rows) {
+ if (rows_capacity_ >= num_rows_ + num_extra_rows) {
+ return Status::OK();
+ }
+
+ int64_t rows_capacity_new = std::max(static_cast<int64_t>(1), 2 * rows_capacity_);
+ while (rows_capacity_new < num_rows_ + num_extra_rows) {
+ rows_capacity_new *= 2;
+ }
+
+ // Null masks
+ RETURN_NOT_OK(null_masks_->Resize(size_null_masks(rows_capacity_new), false));
+ memset(null_masks_->mutable_data() + size_null_masks(rows_capacity_), 0,
+ size_null_masks(rows_capacity_new) - size_null_masks(rows_capacity_));
+
+ // Either offsets or rows
+ if (!metadata_.is_fixed_length) {
+ RETURN_NOT_OK(offsets_->Resize(size_offsets(rows_capacity_new), false));
+ memset(offsets_->mutable_data() + size_offsets(rows_capacity_), 0,
+ size_offsets(rows_capacity_new) - size_offsets(rows_capacity_));
+ } else {
+ RETURN_NOT_OK(rows_->Resize(size_rows_fixed_length(rows_capacity_new), false));
+ memset(rows_->mutable_data() + size_rows_fixed_length(rows_capacity_), 0,
+ size_rows_fixed_length(rows_capacity_new) -
+ size_rows_fixed_length(rows_capacity_));
+ bytes_capacity_ = size_rows_fixed_length(rows_capacity_new) - padding_for_vectors;
+ }
+
+ update_buffer_pointers();
+
+ rows_capacity_ = rows_capacity_new;
+
+ return Status::OK();
+}
+
+Status KeyEncoder::KeyRowArray::ResizeOptionalVaryingLengthBuffer(
+ int64_t num_extra_bytes) {
+ int64_t num_bytes = offsets()[num_rows_];
+ if (bytes_capacity_ >= num_bytes + num_extra_bytes || metadata_.is_fixed_length) {
+ return Status::OK();
+ }
+
+ int64_t bytes_capacity_new = std::max(static_cast<int64_t>(1), 2 * bytes_capacity_);
+ while (bytes_capacity_new < num_bytes + num_extra_bytes) {
+ bytes_capacity_new *= 2;
+ }
+
+ RETURN_NOT_OK(rows_->Resize(size_rows_varying_length(bytes_capacity_new), false));
+ memset(rows_->mutable_data() + size_rows_varying_length(bytes_capacity_), 0,
+ size_rows_varying_length(bytes_capacity_new) -
+ size_rows_varying_length(bytes_capacity_));
+
+ update_buffer_pointers();
+
+ bytes_capacity_ = bytes_capacity_new;
+
+ return Status::OK();
+}
+
+Status KeyEncoder::KeyRowArray::AppendSelectionFrom(const KeyRowArray& from,
+ uint32_t num_rows_to_append,
+ const uint16_t* source_row_ids) {
+ DCHECK(metadata_.is_compatible(from.metadata()));
+
+ RETURN_NOT_OK(ResizeFixedLengthBuffers(num_rows_to_append));
+
+ if (!metadata_.is_fixed_length) {
+ // Varying-length rows
+ auto from_offsets = reinterpret_cast<const uint32_t*>(from.offsets_->data());
+ auto to_offsets = reinterpret_cast<uint32_t*>(offsets_->mutable_data());
+ uint32_t total_length = to_offsets[num_rows_];
+ uint32_t total_length_to_append = 0;
+ for (uint32_t i = 0; i < num_rows_to_append; ++i) {
+ uint16_t row_id = source_row_ids ? source_row_ids[i] : i;
+ uint32_t length = from_offsets[row_id + 1] - from_offsets[row_id];
+ total_length_to_append += length;
+ to_offsets[num_rows_ + i + 1] = total_length + total_length_to_append;
+ }
+
+ RETURN_NOT_OK(ResizeOptionalVaryingLengthBuffer(total_length_to_append));
+
+ const uint8_t* src = from.rows_->data();
+ uint8_t* dst = rows_->mutable_data() + total_length;
+ for (uint32_t i = 0; i < num_rows_to_append; ++i) {
+ uint16_t row_id = source_row_ids ? source_row_ids[i] : i;
+ uint32_t length = from_offsets[row_id + 1] - from_offsets[row_id];
+ auto src64 = reinterpret_cast<const uint64_t*>(src + from_offsets[row_id]);
+ auto dst64 = reinterpret_cast<uint64_t*>(dst);
+ for (uint32_t j = 0; j < BitUtil::CeilDiv(length, 8); ++j) {
+ dst64[j] = src64[j];
+ }
+ dst += length;
+ }
+ } else {
+ // Fixed-length rows
+ const uint8_t* src = from.rows_->data();
+ uint8_t* dst = rows_->mutable_data() + num_rows_ * metadata_.fixed_length;
+ for (uint32_t i = 0; i < num_rows_to_append; ++i) {
+ uint16_t row_id = source_row_ids ? source_row_ids[i] : i;
+ uint32_t length = metadata_.fixed_length;
+ auto src64 = reinterpret_cast<const uint64_t*>(src + length * row_id);
+ auto dst64 = reinterpret_cast<uint64_t*>(dst);
+ for (uint32_t j = 0; j < BitUtil::CeilDiv(length, 8); ++j) {
+ dst64[j] = src64[j];
+ }
+ dst += length;
+ }
+ }
+
+ // Null masks
+ uint32_t byte_length = metadata_.null_masks_bytes_per_row;
+ uint64_t dst_byte_offset = num_rows_ * byte_length;
+ const uint8_t* src_base = from.null_masks_->data();
+ uint8_t* dst_base = null_masks_->mutable_data();
+ for (uint32_t i = 0; i < num_rows_to_append; ++i) {
+ uint32_t row_id = source_row_ids ? source_row_ids[i] : i;
+ int64_t src_byte_offset = row_id * byte_length;
+ const uint8_t* src = src_base + src_byte_offset;
+ uint8_t* dst = dst_base + dst_byte_offset;
+ for (uint32_t ibyte = 0; ibyte < byte_length; ++ibyte) {
+ dst[ibyte] = src[ibyte];
+ }
+ dst_byte_offset += byte_length;
+ }
+
+ num_rows_ += num_rows_to_append;
+
+ return Status::OK();
+}
+
+Status KeyEncoder::KeyRowArray::AppendEmpty(uint32_t num_rows_to_append,
+ uint32_t num_extra_bytes_to_append) {
+ RETURN_NOT_OK(ResizeFixedLengthBuffers(num_rows_to_append));
+ RETURN_NOT_OK(ResizeOptionalVaryingLengthBuffer(num_extra_bytes_to_append));
+ num_rows_ += num_rows_to_append;
+ if (metadata_.row_alignment > 1 || metadata_.string_alignment > 1) {
+ memset(rows_->mutable_data(), 0, bytes_capacity_);
+ }
+ return Status::OK();
+}
+
+bool KeyEncoder::KeyRowArray::has_any_nulls(const KeyEncoderContext* ctx) const {
+ if (has_any_nulls_) {
+ return true;
+ }
+ if (num_rows_for_has_any_nulls_ < num_rows_) {
+ auto size_per_row = metadata().null_masks_bytes_per_row;
+ has_any_nulls_ = !util::BitUtil::are_all_bytes_zero(
+ ctx->hardware_flags, null_masks() + size_per_row * num_rows_for_has_any_nulls_,
+ static_cast<uint32_t>(size_per_row * (num_rows_ - num_rows_for_has_any_nulls_)));
+ num_rows_for_has_any_nulls_ = num_rows_;
+ }
+ return has_any_nulls_;
+}
+
+KeyEncoder::KeyColumnArray::KeyColumnArray(const KeyColumnMetadata& metadata,
+ const KeyColumnArray& left,
+ const KeyColumnArray& right,
+ int buffer_id_to_replace) {
+ metadata_ = metadata;
+ length_ = left.length();
+ for (int i = 0; i < max_buffers_; ++i) {
+ buffers_[i] = left.buffers_[i];
+ mutable_buffers_[i] = left.mutable_buffers_[i];
+ }
+ buffers_[buffer_id_to_replace] = right.buffers_[buffer_id_to_replace];
+ mutable_buffers_[buffer_id_to_replace] = right.mutable_buffers_[buffer_id_to_replace];
+ bit_offset_[0] = left.bit_offset_[0];
+ bit_offset_[1] = left.bit_offset_[1];
+ if (buffer_id_to_replace < max_buffers_ - 1) {
+ bit_offset_[buffer_id_to_replace] = right.bit_offset_[buffer_id_to_replace];
+ }
+}
+
+KeyEncoder::KeyColumnArray::KeyColumnArray(const KeyColumnMetadata& metadata,
+ int64_t length, const uint8_t* buffer0,
+ const uint8_t* buffer1, const uint8_t* buffer2,
+ int bit_offset0, int bit_offset1) {
+ metadata_ = metadata;
+ length_ = length;
+ buffers_[0] = buffer0;
+ buffers_[1] = buffer1;
+ buffers_[2] = buffer2;
+ mutable_buffers_[0] = mutable_buffers_[1] = mutable_buffers_[2] = nullptr;
+ bit_offset_[0] = bit_offset0;
+ bit_offset_[1] = bit_offset1;
+}
+
+KeyEncoder::KeyColumnArray::KeyColumnArray(const KeyColumnMetadata& metadata,
+ int64_t length, uint8_t* buffer0,
+ uint8_t* buffer1, uint8_t* buffer2,
+ int bit_offset0, int bit_offset1) {
+ metadata_ = metadata;
+ length_ = length;
+ buffers_[0] = mutable_buffers_[0] = buffer0;
+ buffers_[1] = mutable_buffers_[1] = buffer1;
+ buffers_[2] = mutable_buffers_[2] = buffer2;
+ bit_offset_[0] = bit_offset0;
+ bit_offset_[1] = bit_offset1;
+}
+
+KeyEncoder::KeyColumnArray::KeyColumnArray(const KeyColumnArray& from, int64_t start,
+ int64_t length) {
+ metadata_ = from.metadata_;
+ length_ = length;
+ uint32_t fixed_size =
+ !metadata_.is_fixed_length ? sizeof(uint32_t) : metadata_.fixed_length;
+
+ buffers_[0] =
+ from.buffers_[0] ? from.buffers_[0] + (from.bit_offset_[0] + start) / 8 : nullptr;
+ mutable_buffers_[0] = from.mutable_buffers_[0]
+ ? from.mutable_buffers_[0] + (from.bit_offset_[0] + start) / 8
+ : nullptr;
+ bit_offset_[0] = (from.bit_offset_[0] + start) % 8;
+
+ if (fixed_size == 0) {
+ buffers_[1] =
+ from.buffers_[1] ? from.buffers_[1] + (from.bit_offset_[1] + start) / 8 : nullptr;
+ mutable_buffers_[1] = from.mutable_buffers_[1] ? from.mutable_buffers_[1] +
+ (from.bit_offset_[1] + start) / 8
+ : nullptr;
+ bit_offset_[1] = (from.bit_offset_[1] + start) % 8;
+ } else {
+ buffers_[1] = from.buffers_[1] ? from.buffers_[1] + start * fixed_size : nullptr;
+ mutable_buffers_[1] = from.mutable_buffers_[1]
+ ? from.mutable_buffers_[1] + start * fixed_size
+ : nullptr;
+ bit_offset_[1] = 0;
+ }
+
+ buffers_[2] = from.buffers_[2];
+ mutable_buffers_[2] = from.mutable_buffers_[2];
+}
+
+KeyEncoder::KeyColumnArray KeyEncoder::TransformBoolean::ArrayReplace(
+ const KeyColumnArray& column, const KeyColumnArray& temp) {
+ // Make sure that the temp buffer is large enough
+ DCHECK(temp.length() >= column.length() && temp.metadata().is_fixed_length &&
+ temp.metadata().fixed_length >= sizeof(uint8_t));
+ KeyColumnMetadata metadata;
+ metadata.is_fixed_length = true;
+ metadata.fixed_length = sizeof(uint8_t);
+ constexpr int buffer_index = 1;
+ KeyColumnArray result = KeyColumnArray(metadata, column, temp, buffer_index);
+ return result;
+}
+
+void KeyEncoder::TransformBoolean::PostDecode(const KeyColumnArray& input,
+ KeyColumnArray* output,
+ KeyEncoderContext* ctx) {
+ // Make sure that metadata and lengths are compatible.
+ DCHECK(output->metadata().is_fixed_length == input.metadata().is_fixed_length);
+ DCHECK(output->metadata().fixed_length == 0 && input.metadata().fixed_length == 1);
+ DCHECK(output->length() == input.length());
+ constexpr int buffer_index = 1;
+ DCHECK(input.data(buffer_index) != nullptr);
+ DCHECK(output->mutable_data(buffer_index) != nullptr);
+
+ util::BitUtil::bytes_to_bits(
+ ctx->hardware_flags, static_cast<int>(input.length()), input.data(buffer_index),
+ output->mutable_data(buffer_index), output->bit_offset(buffer_index));
+}
+
+bool KeyEncoder::EncoderInteger::IsBoolean(const KeyColumnMetadata& metadata) {
+ return metadata.is_fixed_length && metadata.fixed_length == 0;
+}
+
+bool KeyEncoder::EncoderInteger::UsesTransform(const KeyColumnArray& column) {
+ return IsBoolean(column.metadata());
+}
+
+KeyEncoder::KeyColumnArray KeyEncoder::EncoderInteger::ArrayReplace(
+ const KeyColumnArray& column, const KeyColumnArray& temp) {
+ if (IsBoolean(column.metadata())) {
+ return TransformBoolean::ArrayReplace(column, temp);
+ }
+ return column;
+}
+
+void KeyEncoder::EncoderInteger::PostDecode(const KeyColumnArray& input,
+ KeyColumnArray* output,
+ KeyEncoderContext* ctx) {
+ if (IsBoolean(output->metadata())) {
+ TransformBoolean::PostDecode(input, output, ctx);
+ }
+}
+
+void KeyEncoder::EncoderInteger::Decode(uint32_t start_row, uint32_t num_rows,
+ uint32_t offset_within_row,
+ const KeyRowArray& rows, KeyColumnArray* col,
+ KeyEncoderContext* ctx, KeyColumnArray* temp) {
+ KeyColumnArray col_prep;
+ if (UsesTransform(*col)) {
+ col_prep = ArrayReplace(*col, *temp);
+ } else {
+ col_prep = *col;
+ }
+
+ // When we have a single fixed length column we can just do memcpy
+ if (rows.metadata().is_fixed_length &&
+ col_prep.metadata().fixed_length == rows.metadata().fixed_length) {
+ DCHECK_EQ(offset_within_row, 0);
+ uint32_t row_size = rows.metadata().fixed_length;
+ memcpy(col_prep.mutable_data(1), rows.data(1) + start_row * row_size,
+ num_rows * row_size);
+ } else if (rows.metadata().is_fixed_length) {
+ uint32_t row_size = rows.metadata().fixed_length;
+ const uint8_t* row_base = rows.data(1) + start_row * row_size;
+ row_base += offset_within_row;
+ uint8_t* col_base = col_prep.mutable_data(1);
+ switch (col_prep.metadata().fixed_length) {
+ case 1:
+ for (uint32_t i = 0; i < num_rows; ++i) {
+ col_base[i] = row_base[i * row_size];
+ }
+ break;
+ case 2:
+ for (uint32_t i = 0; i < num_rows; ++i) {
+ reinterpret_cast<uint16_t*>(col_base)[i] =
+ *reinterpret_cast<const uint16_t*>(row_base + i * row_size);
+ }
+ break;
+ case 4:
+ for (uint32_t i = 0; i < num_rows; ++i) {
+ reinterpret_cast<uint32_t*>(col_base)[i] =
+ *reinterpret_cast<const uint32_t*>(row_base + i * row_size);
+ }
+ break;
+ case 8:
+ for (uint32_t i = 0; i < num_rows; ++i) {
+ reinterpret_cast<uint64_t*>(col_base)[i] =
+ *reinterpret_cast<const uint64_t*>(row_base + i * row_size);
+ }
+ break;
+ default:
+ DCHECK(false);
+ }
+ } else {
+ const uint32_t* row_offsets = rows.offsets() + start_row;
+ const uint8_t* row_base = rows.data(2);
+ row_base += offset_within_row;
+ uint8_t* col_base = col_prep.mutable_data(1);
+ switch (col_prep.metadata().fixed_length) {
+ case 1:
+ for (uint32_t i = 0; i < num_rows; ++i) {
+ col_base[i] = row_base[row_offsets[i]];
+ }
+ break;
+ case 2:
+ for (uint32_t i = 0; i < num_rows; ++i) {
+ reinterpret_cast<uint16_t*>(col_base)[i] =
+ *reinterpret_cast<const uint16_t*>(row_base + row_offsets[i]);
+ }
+ break;
+ case 4:
+ for (uint32_t i = 0; i < num_rows; ++i) {
+ reinterpret_cast<uint32_t*>(col_base)[i] =
+ *reinterpret_cast<const uint32_t*>(row_base + row_offsets[i]);
+ }
+ break;
+ case 8:
+ for (uint32_t i = 0; i < num_rows; ++i) {
+ reinterpret_cast<uint64_t*>(col_base)[i] =
+ *reinterpret_cast<const uint64_t*>(row_base + row_offsets[i]);
+ }
+ break;
+ default:
+ DCHECK(false);
+ }
+ }
+
+ if (UsesTransform(*col)) {
+ PostDecode(col_prep, col, ctx);
+ }
+}
+
+bool KeyEncoder::EncoderBinary::IsInteger(const KeyColumnMetadata& metadata) {
+ bool is_fixed_length = metadata.is_fixed_length;
+ auto size = metadata.fixed_length;
+ return is_fixed_length &&
+ (size == 0 || size == 1 || size == 2 || size == 4 || size == 8);
+}
+
+void KeyEncoder::EncoderBinary::Decode(uint32_t start_row, uint32_t num_rows,
+ uint32_t offset_within_row,
+ const KeyRowArray& rows, KeyColumnArray* col,
+ KeyEncoderContext* ctx, KeyColumnArray* temp) {
+ if (IsInteger(col->metadata())) {
+ EncoderInteger::Decode(start_row, num_rows, offset_within_row, rows, col, ctx, temp);
+ } else {
+ KeyColumnArray col_prep;
+ if (EncoderInteger::UsesTransform(*col)) {
+ col_prep = EncoderInteger::ArrayReplace(*col, *temp);
+ } else {
+ col_prep = *col;
+ }
+
+ bool is_row_fixed_length = rows.metadata().is_fixed_length;
+
+#if defined(ARROW_HAVE_AVX2)
+ if (ctx->has_avx2()) {
+ DecodeHelper_avx2(is_row_fixed_length, start_row, num_rows, offset_within_row, rows,
+ col);
+ } else {
+#endif
+ if (is_row_fixed_length) {
+ DecodeImp<true>(start_row, num_rows, offset_within_row, rows, col);
+ } else {
+ DecodeImp<false>(start_row, num_rows, offset_within_row, rows, col);
+ }
+#if defined(ARROW_HAVE_AVX2)
+ }
+#endif
+
+ if (EncoderInteger::UsesTransform(*col)) {
+ EncoderInteger::PostDecode(col_prep, col, ctx);
+ }
+ }
+}
+
+template <bool is_row_fixed_length>
+void KeyEncoder::EncoderBinary::DecodeImp(uint32_t start_row, uint32_t num_rows,
+ uint32_t offset_within_row,
+ const KeyRowArray& rows, KeyColumnArray* col) {
+ DecodeHelper<is_row_fixed_length>(
+ start_row, num_rows, offset_within_row, &rows, nullptr, col, col,
+ [](uint8_t* dst, const uint8_t* src, int64_t length) {
+ for (uint32_t istripe = 0; istripe < BitUtil::CeilDiv(length, 8); ++istripe) {
+ auto dst64 = reinterpret_cast<uint64_t*>(dst);
+ auto src64 = reinterpret_cast<const uint64_t*>(src);
+ util::SafeStore(dst64 + istripe, src64[istripe]);
+ }
+ });
+}
+
+void KeyEncoder::EncoderBinaryPair::Decode(uint32_t start_row, uint32_t num_rows,
+ uint32_t offset_within_row,
+ const KeyRowArray& rows, KeyColumnArray* col1,
+ KeyColumnArray* col2, KeyEncoderContext* ctx,
+ KeyColumnArray* temp1, KeyColumnArray* temp2) {
+ DCHECK(CanProcessPair(col1->metadata(), col2->metadata()));
+
+ KeyColumnArray col_prep[2];
+ if (EncoderInteger::UsesTransform(*col1)) {
+ col_prep[0] = EncoderInteger::ArrayReplace(*col1, *temp1);
+ } else {
+ col_prep[0] = *col1;
+ }
+ if (EncoderInteger::UsesTransform(*col2)) {
+ col_prep[1] = EncoderInteger::ArrayReplace(*col2, *temp2);
+ } else {
+ col_prep[1] = *col2;
+ }
+
+ uint32_t col_width1 = col_prep[0].metadata().fixed_length;
+ uint32_t col_width2 = col_prep[1].metadata().fixed_length;
+ int log_col_width1 =
+ col_width1 == 8 ? 3 : col_width1 == 4 ? 2 : col_width1 == 2 ? 1 : 0;
+ int log_col_width2 =
+ col_width2 == 8 ? 3 : col_width2 == 4 ? 2 : col_width2 == 2 ? 1 : 0;
+
+ bool is_row_fixed_length = rows.metadata().is_fixed_length;
+
+ uint32_t num_processed = 0;
+#if defined(ARROW_HAVE_AVX2)
+ if (ctx->has_avx2() && col_width1 == col_width2) {
+ num_processed =
+ DecodeHelper_avx2(is_row_fixed_length, col_width1, start_row, num_rows,
+ offset_within_row, rows, &col_prep[0], &col_prep[1]);
+ }
+#endif
+ if (num_processed < num_rows) {
+ using DecodeImp_t = void (*)(uint32_t, uint32_t, uint32_t, uint32_t,
+ const KeyRowArray&, KeyColumnArray*, KeyColumnArray*);
+ static const DecodeImp_t DecodeImp_fn[] = {
+ DecodeImp<false, uint8_t, uint8_t>, DecodeImp<false, uint16_t, uint8_t>,
+ DecodeImp<false, uint32_t, uint8_t>, DecodeImp<false, uint64_t, uint8_t>,
+ DecodeImp<false, uint8_t, uint16_t>, DecodeImp<false, uint16_t, uint16_t>,
+ DecodeImp<false, uint32_t, uint16_t>, DecodeImp<false, uint64_t, uint16_t>,
+ DecodeImp<false, uint8_t, uint32_t>, DecodeImp<false, uint16_t, uint32_t>,
+ DecodeImp<false, uint32_t, uint32_t>, DecodeImp<false, uint64_t, uint32_t>,
+ DecodeImp<false, uint8_t, uint64_t>, DecodeImp<false, uint16_t, uint64_t>,
+ DecodeImp<false, uint32_t, uint64_t>, DecodeImp<false, uint64_t, uint64_t>,
+ DecodeImp<true, uint8_t, uint8_t>, DecodeImp<true, uint16_t, uint8_t>,
+ DecodeImp<true, uint32_t, uint8_t>, DecodeImp<true, uint64_t, uint8_t>,
+ DecodeImp<true, uint8_t, uint16_t>, DecodeImp<true, uint16_t, uint16_t>,
+ DecodeImp<true, uint32_t, uint16_t>, DecodeImp<true, uint64_t, uint16_t>,
+ DecodeImp<true, uint8_t, uint32_t>, DecodeImp<true, uint16_t, uint32_t>,
+ DecodeImp<true, uint32_t, uint32_t>, DecodeImp<true, uint64_t, uint32_t>,
+ DecodeImp<true, uint8_t, uint64_t>, DecodeImp<true, uint16_t, uint64_t>,
+ DecodeImp<true, uint32_t, uint64_t>, DecodeImp<true, uint64_t, uint64_t>};
+ int dispatch_const =
+ (log_col_width2 << 2) | log_col_width1 | (is_row_fixed_length ? 16 : 0);
+ DecodeImp_fn[dispatch_const](num_processed, start_row, num_rows, offset_within_row,
+ rows, &(col_prep[0]), &(col_prep[1]));
+ }
+
+ if (EncoderInteger::UsesTransform(*col1)) {
+ EncoderInteger::PostDecode(col_prep[0], col1, ctx);
+ }
+ if (EncoderInteger::UsesTransform(*col2)) {
+ EncoderInteger::PostDecode(col_prep[1], col2, ctx);
+ }
+}
+
+template <bool is_row_fixed_length, typename col1_type, typename col2_type>
+void KeyEncoder::EncoderBinaryPair::DecodeImp(uint32_t num_rows_to_skip,
+ uint32_t start_row, uint32_t num_rows,
+ uint32_t offset_within_row,
+ const KeyRowArray& rows,
+ KeyColumnArray* col1,
+ KeyColumnArray* col2) {
+ DCHECK(rows.length() >= start_row + num_rows);
+ DCHECK(col1->length() == num_rows && col2->length() == num_rows);
+
+ uint8_t* dst_A = col1->mutable_data(1);
+ uint8_t* dst_B = col2->mutable_data(1);
+
+ uint32_t fixed_length = rows.metadata().fixed_length;
+ const uint32_t* offsets;
+ const uint8_t* src_base;
+ if (is_row_fixed_length) {
+ src_base = rows.data(1) + fixed_length * start_row + offset_within_row;
+ offsets = nullptr;
+ } else {
+ src_base = rows.data(2) + offset_within_row;
+ offsets = rows.offsets() + start_row;
+ }
+
+ using col1_type_const = typename std::add_const<col1_type>::type;
+ using col2_type_const = typename std::add_const<col2_type>::type;
+
+ if (is_row_fixed_length) {
+ const uint8_t* src = src_base + num_rows_to_skip * fixed_length;
+ for (uint32_t i = num_rows_to_skip; i < num_rows; ++i) {
+ reinterpret_cast<col1_type*>(dst_A)[i] = *reinterpret_cast<col1_type_const*>(src);
+ reinterpret_cast<col2_type*>(dst_B)[i] =
+ *reinterpret_cast<col2_type_const*>(src + sizeof(col1_type));
+ src += fixed_length;
+ }
+ } else {
+ for (uint32_t i = num_rows_to_skip; i < num_rows; ++i) {
+ const uint8_t* src = src_base + offsets[i];
+ reinterpret_cast<col1_type*>(dst_A)[i] = *reinterpret_cast<col1_type_const*>(src);
+ reinterpret_cast<col2_type*>(dst_B)[i] =
+ *reinterpret_cast<col2_type_const*>(src + sizeof(col1_type));
+ }
+ }
+}
+
+void KeyEncoder::EncoderOffsets::Decode(
+ uint32_t start_row, uint32_t num_rows, const KeyRowArray& rows,
+ std::vector<KeyColumnArray>* varbinary_cols,
+ const std::vector<uint32_t>& varbinary_cols_base_offset, KeyEncoderContext* ctx) {
+ DCHECK(!varbinary_cols->empty());
+ DCHECK(varbinary_cols->size() == varbinary_cols_base_offset.size());
+
+ DCHECK(!rows.metadata().is_fixed_length);
+ DCHECK(rows.length() >= start_row + num_rows);
+ for (const auto& col : *varbinary_cols) {
+ // Rows and columns must all be varying-length
+ DCHECK(!col.metadata().is_fixed_length);
+ // The space in columns must be exactly equal to a subset of rows selected
+ DCHECK(col.length() == num_rows);
+ }
+
+ // Offsets of varbinary columns data within each encoded row are stored
+ // in the same encoded row as an array of 32-bit integers.
+ // This array follows immediately the data of fixed-length columns.
+ // There is one element for each varying-length column.
+ // The Nth element is the sum of all the lengths of varbinary columns data in
+ // that row, up to and including Nth varbinary column.
+
+ const uint32_t* row_offsets = rows.offsets() + start_row;
+
+ // Set the base offset for each column
+ for (size_t col = 0; col < varbinary_cols->size(); ++col) {
+ uint32_t* col_offsets = (*varbinary_cols)[col].mutable_offsets();
+ col_offsets[0] = varbinary_cols_base_offset[col];
+ }
+
+ int string_alignment = rows.metadata().string_alignment;
+
+ for (uint32_t i = 0; i < num_rows; ++i) {
+ // Find the beginning of cumulative lengths array for next row
+ const uint8_t* row = rows.data(2) + row_offsets[i];
+ const uint32_t* varbinary_ends = rows.metadata().varbinary_end_array(row);
+
+ // Update the offset of each column
+ uint32_t offset_within_row = rows.metadata().fixed_length;
+ for (size_t col = 0; col < varbinary_cols->size(); ++col) {
+ offset_within_row +=
+ KeyRowMetadata::padding_for_alignment(offset_within_row, string_alignment);
+ uint32_t length = varbinary_ends[col] - offset_within_row;
+ offset_within_row = varbinary_ends[col];
+ uint32_t* col_offsets = (*varbinary_cols)[col].mutable_offsets();
+ col_offsets[i + 1] = col_offsets[i] + length;
+ }
+ }
+}
+
+void KeyEncoder::EncoderVarBinary::Decode(uint32_t start_row, uint32_t num_rows,
+ uint32_t varbinary_col_id,
+ const KeyRowArray& rows, KeyColumnArray* col,
+ KeyEncoderContext* ctx) {
+ // Output column varbinary buffer needs an extra 32B
+ // at the end in avx2 version and 8B otherwise.
+#if defined(ARROW_HAVE_AVX2)
+ if (ctx->has_avx2()) {
+ DecodeHelper_avx2(start_row, num_rows, varbinary_col_id, rows, col);
+ } else {
+#endif
+ if (varbinary_col_id == 0) {
+ DecodeImp<true>(start_row, num_rows, varbinary_col_id, rows, col);
+ } else {
+ DecodeImp<false>(start_row, num_rows, varbinary_col_id, rows, col);
+ }
+#if defined(ARROW_HAVE_AVX2)
+ }
+#endif
+}
+
+template <bool first_varbinary_col>
+void KeyEncoder::EncoderVarBinary::DecodeImp(uint32_t start_row, uint32_t num_rows,
+ uint32_t varbinary_col_id,
+ const KeyRowArray& rows,
+ KeyColumnArray* col) {
+ DecodeHelper<first_varbinary_col>(
+ start_row, num_rows, varbinary_col_id, &rows, nullptr, col, col,
+ [](uint8_t* dst, const uint8_t* src, int64_t length) {
+ for (uint32_t istripe = 0; istripe < BitUtil::CeilDiv(length, 8); ++istripe) {
+ auto dst64 = reinterpret_cast<uint64_t*>(dst);
+ auto src64 = reinterpret_cast<const uint64_t*>(src);
+ util::SafeStore(dst64 + istripe, src64[istripe]);
+ }
+ });
+}
+
+void KeyEncoder::EncoderNulls::Decode(uint32_t start_row, uint32_t num_rows,
+ const KeyRowArray& rows,
+ std::vector<KeyColumnArray>* cols) {
+ // Every output column needs to have a space for exactly the required number
+ // of rows. It also needs to have non-nulls bit-vector allocated and mutable.
+ DCHECK_GT(cols->size(), 0);
+ for (auto& col : *cols) {
+ DCHECK(col.length() == num_rows);
+ DCHECK(col.mutable_data(0));
+ }
+
+ const uint8_t* null_masks = rows.null_masks();
+ uint32_t null_masks_bytes_per_row = rows.metadata().null_masks_bytes_per_row;
+ for (size_t col = 0; col < cols->size(); ++col) {
+ uint8_t* non_nulls = (*cols)[col].mutable_data(0);
+ const int bit_offset = (*cols)[col].bit_offset(0);
+ DCHECK_LT(bit_offset, 8);
+ non_nulls[0] |= 0xff << (bit_offset);
+ if (bit_offset + num_rows > 8) {
+ int bits_in_first_byte = 8 - bit_offset;
+ memset(non_nulls + 1, 0xff, BitUtil::BytesForBits(num_rows - bits_in_first_byte));
+ }
+ for (uint32_t row = 0; row < num_rows; ++row) {
+ uint32_t null_masks_bit_id =
+ (start_row + row) * null_masks_bytes_per_row * 8 + static_cast<uint32_t>(col);
+ bool is_set = BitUtil::GetBit(null_masks, null_masks_bit_id);
+ if (is_set) {
+ BitUtil::ClearBit(non_nulls, bit_offset + row);
+ }
+ }
+ }
+}
+
+uint32_t KeyEncoder::KeyRowMetadata::num_varbinary_cols() const {
+ uint32_t result = 0;
+ for (auto column_metadata : column_metadatas) {
+ if (!column_metadata.is_fixed_length) {
+ ++result;
+ }
+ }
+ return result;
+}
+
+bool KeyEncoder::KeyRowMetadata::is_compatible(const KeyRowMetadata& other) const {
+ if (other.num_cols() != num_cols()) {
+ return false;
+ }
+ if (row_alignment != other.row_alignment ||
+ string_alignment != other.string_alignment) {
+ return false;
+ }
+ for (size_t i = 0; i < column_metadatas.size(); ++i) {
+ if (column_metadatas[i].is_fixed_length !=
+ other.column_metadatas[i].is_fixed_length) {
+ return false;
+ }
+ if (column_metadatas[i].fixed_length != other.column_metadatas[i].fixed_length) {
+ return false;
+ }
+ }
+ return true;
+}
+
+void KeyEncoder::KeyRowMetadata::FromColumnMetadataVector(
+ const std::vector<KeyColumnMetadata>& cols, int in_row_alignment,
+ int in_string_alignment) {
+ column_metadatas.resize(cols.size());
+ for (size_t i = 0; i < cols.size(); ++i) {
+ column_metadatas[i] = cols[i];
+ }
+
+ const auto num_cols = static_cast<uint32_t>(cols.size());
+
+ // Sort columns.
+ //
+ // Columns are sorted based on the size in bytes of their fixed-length part.
+ // For the varying-length column, the fixed-length part is the 32-bit field storing
+ // cumulative length of varying-length fields.
+ //
+ // The rules are:
+ //
+ // a) Boolean column, marked with fixed-length 0, is considered to have fixed-length
+ // part of 1 byte.
+ //
+ // b) Columns with fixed-length part being power of 2 or multiple of row
+ // alignment precede other columns. They are sorted in decreasing order of the size of
+ // their fixed-length part.
+ //
+ // c) Fixed-length columns precede varying-length columns when
+ // both have the same size fixed-length part.
+ //
+ column_order.resize(num_cols);
+ for (uint32_t i = 0; i < num_cols; ++i) {
+ column_order[i] = i;
+ }
+ std::sort(
+ column_order.begin(), column_order.end(), [&cols](uint32_t left, uint32_t right) {
+ bool is_left_pow2 =
+ !cols[left].is_fixed_length || ARROW_POPCOUNT64(cols[left].fixed_length) <= 1;
+ bool is_right_pow2 = !cols[right].is_fixed_length ||
+ ARROW_POPCOUNT64(cols[right].fixed_length) <= 1;
+ bool is_left_fixedlen = cols[left].is_fixed_length;
+ bool is_right_fixedlen = cols[right].is_fixed_length;
+ uint32_t width_left =
+ cols[left].is_fixed_length ? cols[left].fixed_length : sizeof(uint32_t);
+ uint32_t width_right =
+ cols[right].is_fixed_length ? cols[right].fixed_length : sizeof(uint32_t);
+ if (is_left_pow2 != is_right_pow2) {
+ return is_left_pow2;
+ }
+ if (!is_left_pow2) {
+ return left < right;
+ }
+ if (width_left != width_right) {
+ return width_left > width_right;
+ }
+ if (is_left_fixedlen != is_right_fixedlen) {
+ return is_left_fixedlen;
+ }
+ return left < right;
+ });
+
+ row_alignment = in_row_alignment;
+ string_alignment = in_string_alignment;
+ varbinary_end_array_offset = 0;
+
+ column_offsets.resize(num_cols);
+ uint32_t num_varbinary_cols = 0;
+ uint32_t offset_within_row = 0;
+ for (uint32_t i = 0; i < num_cols; ++i) {
+ const KeyColumnMetadata& col = cols[column_order[i]];
+ if (col.is_fixed_length && col.fixed_length != 0 &&
+ ARROW_POPCOUNT64(col.fixed_length) != 1) {
+ offset_within_row +=
+ KeyRowMetadata::padding_for_alignment(offset_within_row, string_alignment, col);
+ }
+ column_offsets[i] = offset_within_row;
+ if (!col.is_fixed_length) {
+ if (num_varbinary_cols == 0) {
+ varbinary_end_array_offset = offset_within_row;
+ }
+ DCHECK(column_offsets[i] - varbinary_end_array_offset ==
+ num_varbinary_cols * sizeof(uint32_t));
+ ++num_varbinary_cols;
+ offset_within_row += sizeof(uint32_t);
+ } else {
+ // Boolean column is a bit-vector, which is indicated by
+ // setting fixed length in column metadata to zero.
+ // It will be stored as a byte in output row.
+ if (col.fixed_length == 0) {
+ offset_within_row += 1;
+ } else {
+ offset_within_row += col.fixed_length;
+ }
+ }
+ }
+
+ is_fixed_length = (num_varbinary_cols == 0);
+ fixed_length =
+ offset_within_row +
+ KeyRowMetadata::padding_for_alignment(
+ offset_within_row, num_varbinary_cols == 0 ? row_alignment : string_alignment);
+
+ // We set the number of bytes per row storing null masks of individual key columns
+ // to be a power of two. This is not required. It could be also set to the minimal
+ // number of bytes required for a given number of bits (one bit per column).
+ null_masks_bytes_per_row = 1;
+ while (static_cast<uint32_t>(null_masks_bytes_per_row * 8) < num_cols) {
+ null_masks_bytes_per_row *= 2;
+ }
+}
+
+void KeyEncoder::Init(const std::vector<KeyColumnMetadata>& cols, KeyEncoderContext* ctx,
+ int row_alignment, int string_alignment) {
+ ctx_ = ctx;
+ row_metadata_.FromColumnMetadataVector(cols, row_alignment, string_alignment);
+ uint32_t num_cols = row_metadata_.num_cols();
+ uint32_t num_varbinary_cols = row_metadata_.num_varbinary_cols();
+ batch_all_cols_.resize(num_cols);
+ batch_varbinary_cols_.resize(num_varbinary_cols);
+ batch_varbinary_cols_base_offsets_.resize(num_varbinary_cols);
+}
+
+void KeyEncoder::PrepareKeyColumnArrays(int64_t start_row, int64_t num_rows,
+ const std::vector<KeyColumnArray>& cols_in) {
+ const auto num_cols = static_cast<uint32_t>(cols_in.size());
+ DCHECK(batch_all_cols_.size() == num_cols);
+
+ uint32_t num_varbinary_visited = 0;
+ for (uint32_t i = 0; i < num_cols; ++i) {
+ const KeyColumnArray& col = cols_in[row_metadata_.column_order[i]];
+ KeyColumnArray col_window(col, start_row, num_rows);
+ batch_all_cols_[i] = col_window;
+ if (!col.metadata().is_fixed_length) {
+ DCHECK(num_varbinary_visited < batch_varbinary_cols_.size());
+ // If start row is zero, then base offset of varbinary column is also zero.
+ if (start_row == 0) {
+ batch_varbinary_cols_base_offsets_[num_varbinary_visited] = 0;
+ } else {
+ batch_varbinary_cols_base_offsets_[num_varbinary_visited] =
+ col.offsets()[start_row];
+ }
+ batch_varbinary_cols_[num_varbinary_visited++] = col_window;
+ }
+ }
+}
+
+void KeyEncoder::DecodeFixedLengthBuffers(int64_t start_row_input,
+ int64_t start_row_output, int64_t num_rows,
+ const KeyRowArray& rows,
+ std::vector<KeyColumnArray>* cols) {
+ // Prepare column array vectors
+ PrepareKeyColumnArrays(start_row_output, num_rows, *cols);
+
+ // Create two temp vectors with 16-bit elements
+ auto temp_buffer_holder_A =
+ util::TempVectorHolder<uint16_t>(ctx_->stack, static_cast<uint32_t>(num_rows));
+ auto temp_buffer_A = KeyColumnArray(
+ KeyColumnMetadata(true, sizeof(uint16_t)), num_rows, nullptr,
+ reinterpret_cast<uint8_t*>(temp_buffer_holder_A.mutable_data()), nullptr);
+ auto temp_buffer_holder_B =
+ util::TempVectorHolder<uint16_t>(ctx_->stack, static_cast<uint32_t>(num_rows));
+ auto temp_buffer_B = KeyColumnArray(
+ KeyColumnMetadata(true, sizeof(uint16_t)), num_rows, nullptr,
+ reinterpret_cast<uint8_t*>(temp_buffer_holder_B.mutable_data()), nullptr);
+
+ bool is_row_fixed_length = row_metadata_.is_fixed_length;
+ if (!is_row_fixed_length) {
+ EncoderOffsets::Decode(static_cast<uint32_t>(start_row_input),
+ static_cast<uint32_t>(num_rows), rows, &batch_varbinary_cols_,
+ batch_varbinary_cols_base_offsets_, ctx_);
+ }
+
+ // Process fixed length columns
+ const auto num_cols = static_cast<uint32_t>(batch_all_cols_.size());
+ for (uint32_t i = 0; i < num_cols;) {
+ if (!batch_all_cols_[i].metadata().is_fixed_length) {
+ i += 1;
+ continue;
+ }
+ bool can_process_pair =
+ (i + 1 < num_cols) && batch_all_cols_[i + 1].metadata().is_fixed_length &&
+ EncoderBinaryPair::CanProcessPair(batch_all_cols_[i].metadata(),
+ batch_all_cols_[i + 1].metadata());
+ if (!can_process_pair) {
+ EncoderBinary::Decode(static_cast<uint32_t>(start_row_input),
+ static_cast<uint32_t>(num_rows),
+ row_metadata_.column_offsets[i], rows, &batch_all_cols_[i],
+ ctx_, &temp_buffer_A);
+ i += 1;
+ } else {
+ EncoderBinaryPair::Decode(
+ static_cast<uint32_t>(start_row_input), static_cast<uint32_t>(num_rows),
+ row_metadata_.column_offsets[i], rows, &batch_all_cols_[i],
+ &batch_all_cols_[i + 1], ctx_, &temp_buffer_A, &temp_buffer_B);
+ i += 2;
+ }
+ }
+
+ // Process nulls
+ EncoderNulls::Decode(static_cast<uint32_t>(start_row_input),
+ static_cast<uint32_t>(num_rows), rows, &batch_all_cols_);
+}
+
+void KeyEncoder::DecodeVaryingLengthBuffers(int64_t start_row_input,
+ int64_t start_row_output, int64_t num_rows,
+ const KeyRowArray& rows,
+ std::vector<KeyColumnArray>* cols) {
+ // Prepare column array vectors
+ PrepareKeyColumnArrays(start_row_output, num_rows, *cols);
+
+ bool is_row_fixed_length = row_metadata_.is_fixed_length;
+ if (!is_row_fixed_length) {
+ for (size_t i = 0; i < batch_varbinary_cols_.size(); ++i) {
+ // Memcpy varbinary fields into precomputed in the previous step
+ // positions in the output row buffer.
+ EncoderVarBinary::Decode(static_cast<uint32_t>(start_row_input),
+ static_cast<uint32_t>(num_rows), static_cast<uint32_t>(i),
+ rows, &batch_varbinary_cols_[i], ctx_);
+ }
+ }
+}
+
+template <class COPY_FN, class SET_NULL_FN>
+void KeyEncoder::EncoderBinary::EncodeSelectedImp(
+ uint32_t offset_within_row, KeyRowArray* rows, const KeyColumnArray& col,
+ uint32_t num_selected, const uint16_t* selection, COPY_FN copy_fn,
+ SET_NULL_FN set_null_fn) {
+ bool is_fixed_length = rows->metadata().is_fixed_length;
+ if (is_fixed_length) {
+ uint32_t row_width = rows->metadata().fixed_length;
+ const uint8_t* src_base = col.data(1);
+ uint8_t* dst = rows->mutable_data(1) + offset_within_row;
+ for (uint32_t i = 0; i < num_selected; ++i) {
+ copy_fn(dst, src_base, selection[i]);
+ dst += row_width;
+ }
+ if (col.data(0)) {
+ const uint8_t* non_null_bits = col.data(0);
+ uint8_t* dst = rows->mutable_data(1) + offset_within_row;
+ for (uint32_t i = 0; i < num_selected; ++i) {
+ bool is_null = !BitUtil::GetBit(non_null_bits, selection[i] + col.bit_offset(0));
+ if (is_null) {
+ set_null_fn(dst);
+ }
+ dst += row_width;
+ }
+ }
+ } else {
+ const uint8_t* src_base = col.data(1);
+ uint8_t* dst = rows->mutable_data(2) + offset_within_row;
+ const uint32_t* offsets = rows->offsets();
+ for (uint32_t i = 0; i < num_selected; ++i) {
+ copy_fn(dst + offsets[i], src_base, selection[i]);
+ }
+ if (col.data(0)) {
+ const uint8_t* non_null_bits = col.data(0);
+ uint8_t* dst = rows->mutable_data(2) + offset_within_row;
+ const uint32_t* offsets = rows->offsets();
+ for (uint32_t i = 0; i < num_selected; ++i) {
+ bool is_null = !BitUtil::GetBit(non_null_bits, selection[i] + col.bit_offset(0));
+ if (is_null) {
+ set_null_fn(dst + offsets[i]);
+ }
+ }
+ }
+ }
+}
+
+void KeyEncoder::EncoderBinary::EncodeSelected(uint32_t offset_within_row,
+ KeyRowArray* rows,
+ const KeyColumnArray& col,
+ uint32_t num_selected,
+ const uint16_t* selection) {
+ uint32_t col_width = col.metadata().fixed_length;
+ if (col_width == 0) {
+ int bit_offset = col.bit_offset(1);
+ EncodeSelectedImp(
+ offset_within_row, rows, col, num_selected, selection,
+ [bit_offset](uint8_t* dst, const uint8_t* src_base, uint16_t irow) {
+ *dst = BitUtil::GetBit(src_base, irow + bit_offset) ? 0xff : 0x00;
+ },
+ [](uint8_t* dst) { *dst = 0xae; });
+ } else if (col_width == 1) {
+ EncodeSelectedImp(
+ offset_within_row, rows, col, num_selected, selection,
+ [](uint8_t* dst, const uint8_t* src_base, uint16_t irow) {
+ *dst = src_base[irow];
+ },
+ [](uint8_t* dst) { *dst = 0xae; });
+ } else if (col_width == 2) {
+ EncodeSelectedImp(
+ offset_within_row, rows, col, num_selected, selection,
+ [](uint8_t* dst, const uint8_t* src_base, uint16_t irow) {
+ *reinterpret_cast<uint16_t*>(dst) =
+ reinterpret_cast<const uint16_t*>(src_base)[irow];
+ },
+ [](uint8_t* dst) { *reinterpret_cast<uint16_t*>(dst) = 0xaeae; });
+ } else if (col_width == 4) {
+ EncodeSelectedImp(
+ offset_within_row, rows, col, num_selected, selection,
+ [](uint8_t* dst, const uint8_t* src_base, uint16_t irow) {
+ *reinterpret_cast<uint32_t*>(dst) =
+ reinterpret_cast<const uint32_t*>(src_base)[irow];
+ },
+ [](uint8_t* dst) {
+ *reinterpret_cast<uint32_t*>(dst) = static_cast<uint32_t>(0xaeaeaeae);
+ });
+ } else if (col_width == 8) {
+ EncodeSelectedImp(
+ offset_within_row, rows, col, num_selected, selection,
+ [](uint8_t* dst, const uint8_t* src_base, uint16_t irow) {
+ *reinterpret_cast<uint64_t*>(dst) =
+ reinterpret_cast<const uint64_t*>(src_base)[irow];
+ },
+ [](uint8_t* dst) { *reinterpret_cast<uint64_t*>(dst) = 0xaeaeaeaeaeaeaeaeULL; });
+ } else {
+ EncodeSelectedImp(
+ offset_within_row, rows, col, num_selected, selection,
+ [col_width](uint8_t* dst, const uint8_t* src_base, uint16_t irow) {
+ memcpy(dst, src_base + col_width * irow, col_width);
+ },
+ [col_width](uint8_t* dst) { memset(dst, 0xae, col_width); });
+ }
+}
+
+void KeyEncoder::EncoderOffsets::GetRowOffsetsSelected(
+ KeyRowArray* rows, const std::vector<KeyColumnArray>& cols, uint32_t num_selected,
+ const uint16_t* selection) {
+ if (rows->metadata().is_fixed_length) {
+ return;
+ }
+
+ uint32_t* row_offsets = rows->mutable_offsets();
+ for (uint32_t i = 0; i < num_selected; ++i) {
+ row_offsets[i] = rows->metadata().fixed_length;
+ }
+
+ for (size_t icol = 0; icol < cols.size(); ++icol) {
+ bool is_fixed_length = (cols[icol].metadata().is_fixed_length);
+ if (!is_fixed_length) {
+ const uint32_t* col_offsets = cols[icol].offsets();
+ for (uint32_t i = 0; i < num_selected; ++i) {
+ uint32_t irow = selection[i];
+ uint32_t length = col_offsets[irow + 1] - col_offsets[irow];
+ row_offsets[i] += KeyRowMetadata::padding_for_alignment(
+ row_offsets[i], rows->metadata().string_alignment);
+ row_offsets[i] += length;
+ }
+ const uint8_t* non_null_bits = cols[icol].data(0);
+ if (non_null_bits) {
+ const uint32_t* col_offsets = cols[icol].offsets();
+ for (uint32_t i = 0; i < num_selected; ++i) {
+ uint32_t irow = selection[i];
+ bool is_null = !BitUtil::GetBit(non_null_bits, irow + cols[icol].bit_offset(0));
+ if (is_null) {
+ uint32_t length = col_offsets[irow + 1] - col_offsets[irow];
+ row_offsets[i] -= length;
+ }
+ }
+ }
+ }
+ }
+
+ uint32_t sum = 0;
+ int row_alignment = rows->metadata().row_alignment;
+ for (uint32_t i = 0; i < num_selected; ++i) {
+ uint32_t length = row_offsets[i];
+ length += KeyRowMetadata::padding_for_alignment(length, row_alignment);
+ row_offsets[i] = sum;
+ sum += length;
+ }
+ row_offsets[num_selected] = sum;
+}
+
+template <bool has_nulls, bool is_first_varbinary>
+void KeyEncoder::EncoderOffsets::EncodeSelectedImp(
+ uint32_t ivarbinary, KeyRowArray* rows, const std::vector<KeyColumnArray>& cols,
+ uint32_t num_selected, const uint16_t* selection) {
+ const uint32_t* row_offsets = rows->offsets();
+ uint8_t* row_base = rows->mutable_data(2) +
+ rows->metadata().varbinary_end_array_offset +
+ ivarbinary * sizeof(uint32_t);
+ const uint32_t* col_offsets = cols[ivarbinary].offsets();
+ const uint8_t* col_non_null_bits = cols[ivarbinary].data(0);
+
+ for (uint32_t i = 0; i < num_selected; ++i) {
+ uint32_t irow = selection[i];
+ uint32_t length = col_offsets[irow + 1] - col_offsets[irow];
+ if (has_nulls) {
+ uint32_t null_multiplier =
+ BitUtil::GetBit(col_non_null_bits, irow + cols[ivarbinary].bit_offset(0)) ? 1
+ : 0;
+ length *= null_multiplier;
+ }
+ uint32_t* row = reinterpret_cast<uint32_t*>(row_base + row_offsets[i]);
+ if (is_first_varbinary) {
+ row[0] = rows->metadata().fixed_length + length;
+ } else {
+ row[0] = row[-1] +
+ KeyRowMetadata::padding_for_alignment(row[-1],
+ rows->metadata().string_alignment) +
+ length;
+ }
+ }
+}
+
+void KeyEncoder::EncoderOffsets::EncodeSelected(KeyRowArray* rows,
+ const std::vector<KeyColumnArray>& cols,
+ uint32_t num_selected,
+ const uint16_t* selection) {
+ if (rows->metadata().is_fixed_length) {
+ return;
+ }
+ uint32_t ivarbinary = 0;
+ for (size_t icol = 0; icol < cols.size(); ++icol) {
+ if (!cols[icol].metadata().is_fixed_length) {
+ const uint8_t* non_null_bits = cols[icol].data(0);
+ if (non_null_bits && ivarbinary == 0) {
+ EncodeSelectedImp<true, true>(ivarbinary, rows, cols, num_selected, selection);
+ } else if (non_null_bits && ivarbinary > 0) {
+ EncodeSelectedImp<true, false>(ivarbinary, rows, cols, num_selected, selection);
+ } else if (!non_null_bits && ivarbinary == 0) {
+ EncodeSelectedImp<false, true>(ivarbinary, rows, cols, num_selected, selection);
+ } else {
+ EncodeSelectedImp<false, false>(ivarbinary, rows, cols, num_selected, selection);
+ }
+ ivarbinary++;
+ }
+ }
+}
+
+void KeyEncoder::EncoderVarBinary::EncodeSelected(uint32_t ivarbinary, KeyRowArray* rows,
+ const KeyColumnArray& cols,
+ uint32_t num_selected,
+ const uint16_t* selection) {
+ const uint32_t* row_offsets = rows->offsets();
+ uint8_t* row_base = rows->mutable_data(2);
+ const uint32_t* col_offsets = cols.offsets();
+ const uint8_t* col_base = cols.data(2);
+
+ if (ivarbinary == 0) {
+ for (uint32_t i = 0; i < num_selected; ++i) {
+ uint8_t* row = row_base + row_offsets[i];
+ uint32_t row_offset;
+ uint32_t length;
+ rows->metadata().first_varbinary_offset_and_length(row, &row_offset, &length);
+ uint32_t irow = selection[i];
+ memcpy(row + row_offset, col_base + col_offsets[irow], length);
+ }
+ } else {
+ for (uint32_t i = 0; i < num_selected; ++i) {
+ uint8_t* row = row_base + row_offsets[i];
+ uint32_t row_offset;
+ uint32_t length;
+ rows->metadata().nth_varbinary_offset_and_length(row, ivarbinary, &row_offset,
+ &length);
+ uint32_t irow = selection[i];
+ memcpy(row + row_offset, col_base + col_offsets[irow], length);
+ }
+ }
+}
+
+void KeyEncoder::EncoderNulls::EncodeSelected(KeyRowArray* rows,
+ const std::vector<KeyColumnArray>& cols,
+ uint32_t num_selected,
+ const uint16_t* selection) {
+ uint8_t* null_masks = rows->null_masks();
+ uint32_t null_mask_num_bytes = rows->metadata().null_masks_bytes_per_row;
+ memset(null_masks, 0, null_mask_num_bytes * num_selected);
+ for (size_t icol = 0; icol < cols.size(); ++icol) {
+ const uint8_t* non_null_bits = cols[icol].data(0);
+ if (non_null_bits) {
+ for (uint32_t i = 0; i < num_selected; ++i) {
+ uint32_t irow = selection[i];
+ bool is_null = !BitUtil::GetBit(non_null_bits, irow + cols[icol].bit_offset(0));
+ if (is_null) {
+ BitUtil::SetBit(null_masks, i * null_mask_num_bytes * 8 + icol);
+ }
+ }
+ }
+ }
+}
+
+void KeyEncoder::PrepareEncodeSelected(int64_t start_row, int64_t num_rows,
+ const std::vector<KeyColumnArray>& cols) {
+ // Prepare column array vectors
+ PrepareKeyColumnArrays(start_row, num_rows, cols);
+}
+
+Status KeyEncoder::EncodeSelected(KeyRowArray* rows, uint32_t num_selected,
+ const uint16_t* selection) {
+ rows->Clean();
+ RETURN_NOT_OK(
+ rows->AppendEmpty(static_cast<uint32_t>(num_selected), static_cast<uint32_t>(0)));
+
+ EncoderOffsets::GetRowOffsetsSelected(rows, batch_varbinary_cols_, num_selected,
+ selection);
+
+ RETURN_NOT_OK(rows->AppendEmpty(static_cast<uint32_t>(0),
+ static_cast<uint32_t>(rows->offsets()[num_selected])));
+
+ for (size_t icol = 0; icol < batch_all_cols_.size(); ++icol) {
+ if (batch_all_cols_[icol].metadata().is_fixed_length) {
+ uint32_t offset_within_row = rows->metadata().column_offsets[icol];
+ EncoderBinary::EncodeSelected(offset_within_row, rows, batch_all_cols_[icol],
+ num_selected, selection);
+ }
+ }
+
+ EncoderOffsets::EncodeSelected(rows, batch_varbinary_cols_, num_selected, selection);
+
+ for (size_t icol = 0; icol < batch_varbinary_cols_.size(); ++icol) {
+ EncoderVarBinary::EncodeSelected(static_cast<uint32_t>(icol), rows,
+ batch_varbinary_cols_[icol], num_selected,
+ selection);
+ }
+
+ EncoderNulls::EncodeSelected(rows, batch_all_cols_, num_selected, selection);
+
+ return Status::OK();
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/key_encode.h b/src/arrow/cpp/src/arrow/compute/exec/key_encode.h
new file mode 100644
index 000000000..69f4a1694
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/key_encode.h
@@ -0,0 +1,567 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "arrow/compute/exec/util.h"
+#include "arrow/memory_pool.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/bit_util.h"
+
+namespace arrow {
+namespace compute {
+
+class KeyColumnMetadata;
+
+/// Converts between key representation as a collection of arrays for
+/// individual columns and another representation as a single array of rows
+/// combining data from all columns into one value.
+/// This conversion is reversible.
+/// Row-oriented storage is beneficial when there is a need for random access
+/// of individual rows and at the same time all included columns are likely to
+/// be accessed together, as in the case of hash table key.
+class KeyEncoder {
+ public:
+ struct KeyEncoderContext {
+ bool has_avx2() const {
+ return (hardware_flags & arrow::internal::CpuInfo::AVX2) > 0;
+ }
+ int64_t hardware_flags;
+ util::TempVectorStack* stack;
+ };
+
+ /// Description of a storage format of a single key column as needed
+ /// for the purpose of row encoding.
+ struct KeyColumnMetadata {
+ KeyColumnMetadata() = default;
+ KeyColumnMetadata(bool is_fixed_length_in, uint32_t fixed_length_in)
+ : is_fixed_length(is_fixed_length_in), fixed_length(fixed_length_in) {}
+ /// Is column storing a varying-length binary, using offsets array
+ /// to find a beginning of a value, or is it a fixed-length binary.
+ bool is_fixed_length;
+ /// For a fixed-length binary column: number of bytes per value.
+ /// Zero has a special meaning, indicating a bit vector with one bit per value.
+ /// For a varying-length binary column: number of bytes per offset.
+ uint32_t fixed_length;
+ };
+
+ /// Description of a storage format for rows produced by encoder.
+ struct KeyRowMetadata {
+ /// Is row a varying-length binary, using offsets array to find a beginning of a row,
+ /// or is it a fixed-length binary.
+ bool is_fixed_length;
+
+ /// For a fixed-length binary row, common size of rows in bytes,
+ /// rounded up to the multiple of alignment.
+ ///
+ /// For a varying-length binary, size of all encoded fixed-length key columns,
+ /// including lengths of varying-length columns, rounded up to the multiple of string
+ /// alignment.
+ uint32_t fixed_length;
+
+ /// Offset within a row to the array of 32-bit offsets within a row of
+ /// ends of varbinary fields.
+ /// Used only when the row is not fixed-length, zero for fixed-length row.
+ /// There are N elements for N varbinary fields.
+ /// Each element is the offset within a row of the first byte after
+ /// the corresponding varbinary field bytes in that row.
+ /// If varbinary fields begin at aligned addresses, than the end of the previous
+ /// varbinary field needs to be rounded up according to the specified alignment
+ /// to obtain the beginning of the next varbinary field.
+ /// The first varbinary field starts at offset specified by fixed_length,
+ /// which should already be aligned.
+ uint32_t varbinary_end_array_offset;
+
+ /// Fixed number of bytes per row that are used to encode null masks.
+ /// Null masks indicate for a single row which of its key columns are null.
+ /// Nth bit in the sequence of bytes assigned to a row represents null
+ /// information for Nth field according to the order in which they are encoded.
+ int null_masks_bytes_per_row;
+
+ /// Power of 2. Every row will start at the offset aligned to that number of bytes.
+ int row_alignment;
+
+ /// Power of 2. Must be no greater than row alignment.
+ /// Every non-power-of-2 binary field and every varbinary field bytes
+ /// will start aligned to that number of bytes.
+ int string_alignment;
+
+ /// Metadata of encoded columns in their original order.
+ std::vector<KeyColumnMetadata> column_metadatas;
+
+ /// Order in which fields are encoded.
+ std::vector<uint32_t> column_order;
+
+ /// Offsets within a row to fields in their encoding order.
+ std::vector<uint32_t> column_offsets;
+
+ /// Rounding up offset to the nearest multiple of alignment value.
+ /// Alignment must be a power of 2.
+ static inline uint32_t padding_for_alignment(uint32_t offset,
+ int required_alignment) {
+ ARROW_DCHECK(ARROW_POPCOUNT64(required_alignment) == 1);
+ return static_cast<uint32_t>((-static_cast<int32_t>(offset)) &
+ (required_alignment - 1));
+ }
+
+ /// Rounding up offset to the beginning of next column,
+ /// chosing required alignment based on the data type of that column.
+ static inline uint32_t padding_for_alignment(uint32_t offset, int string_alignment,
+ const KeyColumnMetadata& col_metadata) {
+ if (!col_metadata.is_fixed_length ||
+ ARROW_POPCOUNT64(col_metadata.fixed_length) <= 1) {
+ return 0;
+ } else {
+ return padding_for_alignment(offset, string_alignment);
+ }
+ }
+
+ /// Returns an array of offsets within a row of ends of varbinary fields.
+ inline const uint32_t* varbinary_end_array(const uint8_t* row) const {
+ ARROW_DCHECK(!is_fixed_length);
+ return reinterpret_cast<const uint32_t*>(row + varbinary_end_array_offset);
+ }
+ inline uint32_t* varbinary_end_array(uint8_t* row) const {
+ ARROW_DCHECK(!is_fixed_length);
+ return reinterpret_cast<uint32_t*>(row + varbinary_end_array_offset);
+ }
+
+ /// Returns the offset within the row and length of the first varbinary field.
+ inline void first_varbinary_offset_and_length(const uint8_t* row, uint32_t* offset,
+ uint32_t* length) const {
+ ARROW_DCHECK(!is_fixed_length);
+ *offset = fixed_length;
+ *length = varbinary_end_array(row)[0] - fixed_length;
+ }
+
+ /// Returns the offset within the row and length of the second and further varbinary
+ /// fields.
+ inline void nth_varbinary_offset_and_length(const uint8_t* row, int varbinary_id,
+ uint32_t* out_offset,
+ uint32_t* out_length) const {
+ ARROW_DCHECK(!is_fixed_length);
+ ARROW_DCHECK(varbinary_id > 0);
+ const uint32_t* varbinary_end = varbinary_end_array(row);
+ uint32_t offset = varbinary_end[varbinary_id - 1];
+ offset += padding_for_alignment(offset, string_alignment);
+ *out_offset = offset;
+ *out_length = varbinary_end[varbinary_id] - offset;
+ }
+
+ uint32_t encoded_field_order(uint32_t icol) const { return column_order[icol]; }
+
+ uint32_t encoded_field_offset(uint32_t icol) const { return column_offsets[icol]; }
+
+ uint32_t num_cols() const { return static_cast<uint32_t>(column_metadatas.size()); }
+
+ uint32_t num_varbinary_cols() const;
+
+ void FromColumnMetadataVector(const std::vector<KeyColumnMetadata>& cols,
+ int in_row_alignment, int in_string_alignment);
+
+ bool is_compatible(const KeyRowMetadata& other) const;
+ };
+
+ class KeyRowArray {
+ public:
+ KeyRowArray();
+ Status Init(MemoryPool* pool, const KeyRowMetadata& metadata);
+ void Clean();
+ Status AppendEmpty(uint32_t num_rows_to_append, uint32_t num_extra_bytes_to_append);
+ Status AppendSelectionFrom(const KeyRowArray& from, uint32_t num_rows_to_append,
+ const uint16_t* source_row_ids);
+ const KeyRowMetadata& metadata() const { return metadata_; }
+ int64_t length() const { return num_rows_; }
+ const uint8_t* data(int i) const {
+ ARROW_DCHECK(i >= 0 && i <= max_buffers_);
+ return buffers_[i];
+ }
+ uint8_t* mutable_data(int i) {
+ ARROW_DCHECK(i >= 0 && i <= max_buffers_);
+ return mutable_buffers_[i];
+ }
+ const uint32_t* offsets() const { return reinterpret_cast<const uint32_t*>(data(1)); }
+ uint32_t* mutable_offsets() { return reinterpret_cast<uint32_t*>(mutable_data(1)); }
+ const uint8_t* null_masks() const { return null_masks_->data(); }
+ uint8_t* null_masks() { return null_masks_->mutable_data(); }
+
+ bool has_any_nulls(const KeyEncoderContext* ctx) const;
+
+ private:
+ Status ResizeFixedLengthBuffers(int64_t num_extra_rows);
+ Status ResizeOptionalVaryingLengthBuffer(int64_t num_extra_bytes);
+
+ int64_t size_null_masks(int64_t num_rows);
+ int64_t size_offsets(int64_t num_rows);
+ int64_t size_rows_fixed_length(int64_t num_rows);
+ int64_t size_rows_varying_length(int64_t num_bytes);
+ void update_buffer_pointers();
+
+ static constexpr int64_t padding_for_vectors = 64;
+ MemoryPool* pool_;
+ KeyRowMetadata metadata_;
+ /// Buffers can only expand during lifetime and never shrink.
+ std::unique_ptr<ResizableBuffer> null_masks_;
+ std::unique_ptr<ResizableBuffer> offsets_;
+ std::unique_ptr<ResizableBuffer> rows_;
+ static constexpr int max_buffers_ = 3;
+ const uint8_t* buffers_[max_buffers_];
+ uint8_t* mutable_buffers_[max_buffers_];
+ int64_t num_rows_;
+ int64_t rows_capacity_;
+ int64_t bytes_capacity_;
+
+ // Mutable to allow lazy evaluation
+ mutable int64_t num_rows_for_has_any_nulls_;
+ mutable bool has_any_nulls_;
+ };
+
+ /// A lightweight description of an array representing one of key columns.
+ class KeyColumnArray {
+ public:
+ KeyColumnArray() = default;
+ /// Create as a mix of buffers according to the mask from two descriptions
+ /// (Nth bit is set to 0 if Nth buffer from the first input
+ /// should be used and is set to 1 otherwise).
+ /// Metadata is inherited from the first input.
+ KeyColumnArray(const KeyColumnMetadata& metadata, const KeyColumnArray& left,
+ const KeyColumnArray& right, int buffer_id_to_replace);
+ /// Create for reading
+ KeyColumnArray(const KeyColumnMetadata& metadata, int64_t length,
+ const uint8_t* buffer0, const uint8_t* buffer1, const uint8_t* buffer2,
+ int bit_offset0 = 0, int bit_offset1 = 0);
+ /// Create for writing
+ KeyColumnArray(const KeyColumnMetadata& metadata, int64_t length, uint8_t* buffer0,
+ uint8_t* buffer1, uint8_t* buffer2, int bit_offset0 = 0,
+ int bit_offset1 = 0);
+ /// Create as a window view of original description that is offset
+ /// by a given number of rows.
+ /// The number of rows used in offset must be divisible by 8
+ /// in order to not split bit vectors within a single byte.
+ KeyColumnArray(const KeyColumnArray& from, int64_t start, int64_t length);
+ uint8_t* mutable_data(int i) {
+ ARROW_DCHECK(i >= 0 && i <= max_buffers_);
+ return mutable_buffers_[i];
+ }
+ const uint8_t* data(int i) const {
+ ARROW_DCHECK(i >= 0 && i <= max_buffers_);
+ return buffers_[i];
+ }
+ uint32_t* mutable_offsets() { return reinterpret_cast<uint32_t*>(mutable_data(1)); }
+ const uint32_t* offsets() const { return reinterpret_cast<const uint32_t*>(data(1)); }
+ const KeyColumnMetadata& metadata() const { return metadata_; }
+ int64_t length() const { return length_; }
+ int bit_offset(int i) const {
+ ARROW_DCHECK(i >= 0 && i < max_buffers_);
+ return bit_offset_[i];
+ }
+
+ private:
+ static constexpr int max_buffers_ = 3;
+ const uint8_t* buffers_[max_buffers_];
+ uint8_t* mutable_buffers_[max_buffers_];
+ KeyColumnMetadata metadata_;
+ int64_t length_;
+ // Starting bit offset within the first byte (between 0 and 7)
+ // to be used when accessing buffers that store bit vectors.
+ int bit_offset_[max_buffers_ - 1];
+ };
+
+ void Init(const std::vector<KeyColumnMetadata>& cols, KeyEncoderContext* ctx,
+ int row_alignment, int string_alignment);
+
+ const KeyRowMetadata& row_metadata() { return row_metadata_; }
+
+ void PrepareEncodeSelected(int64_t start_row, int64_t num_rows,
+ const std::vector<KeyColumnArray>& cols);
+ Status EncodeSelected(KeyRowArray* rows, uint32_t num_selected,
+ const uint16_t* selection);
+
+ /// Decode a window of row oriented data into a corresponding
+ /// window of column oriented storage.
+ /// The output buffers need to be correctly allocated and sized before
+ /// calling each method.
+ /// For that reason decoding is split into two functions.
+ /// The output of the first one, that processes everything except for
+ /// varying length buffers, can be used to find out required varying
+ /// length buffers sizes.
+ void DecodeFixedLengthBuffers(int64_t start_row_input, int64_t start_row_output,
+ int64_t num_rows, const KeyRowArray& rows,
+ std::vector<KeyColumnArray>* cols);
+
+ void DecodeVaryingLengthBuffers(int64_t start_row_input, int64_t start_row_output,
+ int64_t num_rows, const KeyRowArray& rows,
+ std::vector<KeyColumnArray>* cols);
+
+ const std::vector<KeyColumnArray>& GetBatchColumns() const { return batch_all_cols_; }
+
+ private:
+ /// Prepare column array vectors.
+ /// Output column arrays represent a range of input column arrays
+ /// specified by starting row and number of rows.
+ /// Three vectors are generated:
+ /// - all columns
+ /// - fixed-length columns only
+ /// - varying-length columns only
+ void PrepareKeyColumnArrays(int64_t start_row, int64_t num_rows,
+ const std::vector<KeyColumnArray>& cols_in);
+
+ class TransformBoolean {
+ public:
+ static KeyColumnArray ArrayReplace(const KeyColumnArray& column,
+ const KeyColumnArray& temp);
+ static void PostDecode(const KeyColumnArray& input, KeyColumnArray* output,
+ KeyEncoderContext* ctx);
+ };
+
+ class EncoderInteger {
+ public:
+ static void Decode(uint32_t start_row, uint32_t num_rows, uint32_t offset_within_row,
+ const KeyRowArray& rows, KeyColumnArray* col,
+ KeyEncoderContext* ctx, KeyColumnArray* temp);
+ static bool UsesTransform(const KeyColumnArray& column);
+ static KeyColumnArray ArrayReplace(const KeyColumnArray& column,
+ const KeyColumnArray& temp);
+ static void PostDecode(const KeyColumnArray& input, KeyColumnArray* output,
+ KeyEncoderContext* ctx);
+
+ private:
+ static bool IsBoolean(const KeyColumnMetadata& metadata);
+ };
+
+ class EncoderBinary {
+ public:
+ static void EncodeSelected(uint32_t offset_within_row, KeyRowArray* rows,
+ const KeyColumnArray& col, uint32_t num_selected,
+ const uint16_t* selection);
+ static void Decode(uint32_t start_row, uint32_t num_rows, uint32_t offset_within_row,
+ const KeyRowArray& rows, KeyColumnArray* col,
+ KeyEncoderContext* ctx, KeyColumnArray* temp);
+ static bool IsInteger(const KeyColumnMetadata& metadata);
+
+ private:
+ template <class COPY_FN, class SET_NULL_FN>
+ static void EncodeSelectedImp(uint32_t offset_within_row, KeyRowArray* rows,
+ const KeyColumnArray& col, uint32_t num_selected,
+ const uint16_t* selection, COPY_FN copy_fn,
+ SET_NULL_FN set_null_fn);
+
+ template <bool is_row_fixed_length, class COPY_FN>
+ static inline void DecodeHelper(uint32_t start_row, uint32_t num_rows,
+ uint32_t offset_within_row,
+ const KeyRowArray* rows_const,
+ KeyRowArray* rows_mutable_maybe_null,
+ const KeyColumnArray* col_const,
+ KeyColumnArray* col_mutable_maybe_null,
+ COPY_FN copy_fn);
+ template <bool is_row_fixed_length>
+ static void DecodeImp(uint32_t start_row, uint32_t num_rows,
+ uint32_t offset_within_row, const KeyRowArray& rows,
+ KeyColumnArray* col);
+#if defined(ARROW_HAVE_AVX2)
+ static void DecodeHelper_avx2(bool is_row_fixed_length, uint32_t start_row,
+ uint32_t num_rows, uint32_t offset_within_row,
+ const KeyRowArray& rows, KeyColumnArray* col);
+ template <bool is_row_fixed_length>
+ static void DecodeImp_avx2(uint32_t start_row, uint32_t num_rows,
+ uint32_t offset_within_row, const KeyRowArray& rows,
+ KeyColumnArray* col);
+#endif
+ };
+
+ class EncoderBinaryPair {
+ public:
+ static bool CanProcessPair(const KeyColumnMetadata& col1,
+ const KeyColumnMetadata& col2) {
+ return EncoderBinary::IsInteger(col1) && EncoderBinary::IsInteger(col2);
+ }
+ static void Decode(uint32_t start_row, uint32_t num_rows, uint32_t offset_within_row,
+ const KeyRowArray& rows, KeyColumnArray* col1,
+ KeyColumnArray* col2, KeyEncoderContext* ctx,
+ KeyColumnArray* temp1, KeyColumnArray* temp2);
+
+ private:
+ template <bool is_row_fixed_length, typename col1_type, typename col2_type>
+ static void DecodeImp(uint32_t num_rows_to_skip, uint32_t start_row,
+ uint32_t num_rows, uint32_t offset_within_row,
+ const KeyRowArray& rows, KeyColumnArray* col1,
+ KeyColumnArray* col2);
+#if defined(ARROW_HAVE_AVX2)
+ static uint32_t DecodeHelper_avx2(bool is_row_fixed_length, uint32_t col_width,
+ uint32_t start_row, uint32_t num_rows,
+ uint32_t offset_within_row, const KeyRowArray& rows,
+ KeyColumnArray* col1, KeyColumnArray* col2);
+ template <bool is_row_fixed_length, uint32_t col_width>
+ static uint32_t DecodeImp_avx2(uint32_t start_row, uint32_t num_rows,
+ uint32_t offset_within_row, const KeyRowArray& rows,
+ KeyColumnArray* col1, KeyColumnArray* col2);
+#endif
+ };
+
+ class EncoderOffsets {
+ public:
+ static void GetRowOffsetsSelected(KeyRowArray* rows,
+ const std::vector<KeyColumnArray>& cols,
+ uint32_t num_selected, const uint16_t* selection);
+ static void EncodeSelected(KeyRowArray* rows, const std::vector<KeyColumnArray>& cols,
+ uint32_t num_selected, const uint16_t* selection);
+
+ static void Decode(uint32_t start_row, uint32_t num_rows, const KeyRowArray& rows,
+ std::vector<KeyColumnArray>* varbinary_cols,
+ const std::vector<uint32_t>& varbinary_cols_base_offset,
+ KeyEncoderContext* ctx);
+
+ private:
+ template <bool has_nulls, bool is_first_varbinary>
+ static void EncodeSelectedImp(uint32_t ivarbinary, KeyRowArray* rows,
+ const std::vector<KeyColumnArray>& cols,
+ uint32_t num_selected, const uint16_t* selection);
+ };
+
+ class EncoderVarBinary {
+ public:
+ static void EncodeSelected(uint32_t ivarbinary, KeyRowArray* rows,
+ const KeyColumnArray& cols, uint32_t num_selected,
+ const uint16_t* selection);
+
+ static void Decode(uint32_t start_row, uint32_t num_rows, uint32_t varbinary_col_id,
+ const KeyRowArray& rows, KeyColumnArray* col,
+ KeyEncoderContext* ctx);
+
+ private:
+ template <bool first_varbinary_col, class COPY_FN>
+ static inline void DecodeHelper(uint32_t start_row, uint32_t num_rows,
+ uint32_t varbinary_col_id,
+ const KeyRowArray* rows_const,
+ KeyRowArray* rows_mutable_maybe_null,
+ const KeyColumnArray* col_const,
+ KeyColumnArray* col_mutable_maybe_null,
+ COPY_FN copy_fn);
+ template <bool first_varbinary_col>
+ static void DecodeImp(uint32_t start_row, uint32_t num_rows,
+ uint32_t varbinary_col_id, const KeyRowArray& rows,
+ KeyColumnArray* col);
+#if defined(ARROW_HAVE_AVX2)
+ static void DecodeHelper_avx2(uint32_t start_row, uint32_t num_rows,
+ uint32_t varbinary_col_id, const KeyRowArray& rows,
+ KeyColumnArray* col);
+ template <bool first_varbinary_col>
+ static void DecodeImp_avx2(uint32_t start_row, uint32_t num_rows,
+ uint32_t varbinary_col_id, const KeyRowArray& rows,
+ KeyColumnArray* col);
+#endif
+ };
+
+ class EncoderNulls {
+ public:
+ static void EncodeSelected(KeyRowArray* rows, const std::vector<KeyColumnArray>& cols,
+ uint32_t num_selected, const uint16_t* selection);
+
+ static void Decode(uint32_t start_row, uint32_t num_rows, const KeyRowArray& rows,
+ std::vector<KeyColumnArray>* cols);
+ };
+
+ KeyEncoderContext* ctx_;
+
+ // Data initialized once, based on data types of key columns
+ KeyRowMetadata row_metadata_;
+
+ // Data initialized for each input batch.
+ // All elements are ordered according to the order of encoded fields in a row.
+ std::vector<KeyColumnArray> batch_all_cols_;
+ std::vector<KeyColumnArray> batch_varbinary_cols_;
+ std::vector<uint32_t> batch_varbinary_cols_base_offsets_;
+};
+
+template <bool is_row_fixed_length, class COPY_FN>
+inline void KeyEncoder::EncoderBinary::DecodeHelper(
+ uint32_t start_row, uint32_t num_rows, uint32_t offset_within_row,
+ const KeyRowArray* rows_const, KeyRowArray* rows_mutable_maybe_null,
+ const KeyColumnArray* col_const, KeyColumnArray* col_mutable_maybe_null,
+ COPY_FN copy_fn) {
+ ARROW_DCHECK(col_const && col_const->metadata().is_fixed_length);
+ uint32_t col_width = col_const->metadata().fixed_length;
+
+ if (is_row_fixed_length) {
+ uint32_t row_width = rows_const->metadata().fixed_length;
+ for (uint32_t i = 0; i < num_rows; ++i) {
+ const uint8_t* src;
+ uint8_t* dst;
+ src = rows_const->data(1) + row_width * (start_row + i) + offset_within_row;
+ dst = col_mutable_maybe_null->mutable_data(1) + col_width * i;
+ copy_fn(dst, src, col_width);
+ }
+ } else {
+ const uint32_t* row_offsets = rows_const->offsets();
+ for (uint32_t i = 0; i < num_rows; ++i) {
+ const uint8_t* src;
+ uint8_t* dst;
+ src = rows_const->data(2) + row_offsets[start_row + i] + offset_within_row;
+ dst = col_mutable_maybe_null->mutable_data(1) + col_width * i;
+ copy_fn(dst, src, col_width);
+ }
+ }
+}
+
+template <bool first_varbinary_col, class COPY_FN>
+inline void KeyEncoder::EncoderVarBinary::DecodeHelper(
+ uint32_t start_row, uint32_t num_rows, uint32_t varbinary_col_id,
+ const KeyRowArray* rows_const, KeyRowArray* rows_mutable_maybe_null,
+ const KeyColumnArray* col_const, KeyColumnArray* col_mutable_maybe_null,
+ COPY_FN copy_fn) {
+ // Column and rows need to be varying length
+ ARROW_DCHECK(!rows_const->metadata().is_fixed_length &&
+ !col_const->metadata().is_fixed_length);
+
+ const uint32_t* row_offsets_for_batch = rows_const->offsets() + start_row;
+ const uint32_t* col_offsets = col_const->offsets();
+
+ uint32_t col_offset_next = col_offsets[0];
+ for (uint32_t i = 0; i < num_rows; ++i) {
+ uint32_t col_offset = col_offset_next;
+ col_offset_next = col_offsets[i + 1];
+
+ uint32_t row_offset = row_offsets_for_batch[i];
+ const uint8_t* row = rows_const->data(2) + row_offset;
+
+ uint32_t offset_within_row;
+ uint32_t length;
+ if (first_varbinary_col) {
+ rows_const->metadata().first_varbinary_offset_and_length(row, &offset_within_row,
+ &length);
+ } else {
+ rows_const->metadata().nth_varbinary_offset_and_length(row, varbinary_col_id,
+ &offset_within_row, &length);
+ }
+
+ row_offset += offset_within_row;
+
+ const uint8_t* src;
+ uint8_t* dst;
+ src = rows_const->data(2) + row_offset;
+ dst = col_mutable_maybe_null->mutable_data(2) + col_offset;
+ copy_fn(dst, src, length);
+ }
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/key_encode_avx2.cc b/src/arrow/cpp/src/arrow/compute/exec/key_encode_avx2.cc
new file mode 100644
index 000000000..832bb0361
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/key_encode_avx2.cc
@@ -0,0 +1,241 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <immintrin.h>
+
+#include "arrow/compute/exec/key_encode.h"
+
+namespace arrow {
+namespace compute {
+
+#if defined(ARROW_HAVE_AVX2)
+
+void KeyEncoder::EncoderBinary::DecodeHelper_avx2(bool is_row_fixed_length,
+ uint32_t start_row, uint32_t num_rows,
+ uint32_t offset_within_row,
+ const KeyRowArray& rows,
+ KeyColumnArray* col) {
+ if (is_row_fixed_length) {
+ DecodeImp_avx2<true>(start_row, num_rows, offset_within_row, rows, col);
+ } else {
+ DecodeImp_avx2<false>(start_row, num_rows, offset_within_row, rows, col);
+ }
+}
+
+template <bool is_row_fixed_length>
+void KeyEncoder::EncoderBinary::DecodeImp_avx2(uint32_t start_row, uint32_t num_rows,
+ uint32_t offset_within_row,
+ const KeyRowArray& rows,
+ KeyColumnArray* col) {
+ DecodeHelper<is_row_fixed_length>(
+ start_row, num_rows, offset_within_row, &rows, nullptr, col, col,
+ [](uint8_t* dst, const uint8_t* src, int64_t length) {
+ for (uint32_t istripe = 0; istripe < (length + 31) / 32; ++istripe) {
+ __m256i* dst256 = reinterpret_cast<__m256i*>(dst);
+ const __m256i* src256 = reinterpret_cast<const __m256i*>(src);
+ _mm256_storeu_si256(dst256 + istripe, _mm256_loadu_si256(src256 + istripe));
+ }
+ });
+}
+
+uint32_t KeyEncoder::EncoderBinaryPair::DecodeHelper_avx2(
+ bool is_row_fixed_length, uint32_t col_width, uint32_t start_row, uint32_t num_rows,
+ uint32_t offset_within_row, const KeyRowArray& rows, KeyColumnArray* col1,
+ KeyColumnArray* col2) {
+ using DecodeImp_avx2_t =
+ uint32_t (*)(uint32_t start_row, uint32_t num_rows, uint32_t offset_within_row,
+ const KeyRowArray& rows, KeyColumnArray* col1, KeyColumnArray* col2);
+ static const DecodeImp_avx2_t DecodeImp_avx2_fn[] = {
+ DecodeImp_avx2<false, 1>, DecodeImp_avx2<false, 2>, DecodeImp_avx2<false, 4>,
+ DecodeImp_avx2<false, 8>, DecodeImp_avx2<true, 1>, DecodeImp_avx2<true, 2>,
+ DecodeImp_avx2<true, 4>, DecodeImp_avx2<true, 8>};
+ int log_col_width = col_width == 8 ? 3 : col_width == 4 ? 2 : col_width == 2 ? 1 : 0;
+ int dispatch_const = log_col_width | (is_row_fixed_length ? 4 : 0);
+ return DecodeImp_avx2_fn[dispatch_const](start_row, num_rows, offset_within_row, rows,
+ col1, col2);
+}
+
+template <bool is_row_fixed_length, uint32_t col_width>
+uint32_t KeyEncoder::EncoderBinaryPair::DecodeImp_avx2(
+ uint32_t start_row, uint32_t num_rows, uint32_t offset_within_row,
+ const KeyRowArray& rows, KeyColumnArray* col1, KeyColumnArray* col2) {
+ ARROW_DCHECK(col_width == 1 || col_width == 2 || col_width == 4 || col_width == 8);
+
+ uint8_t* col_vals_A = col1->mutable_data(1);
+ uint8_t* col_vals_B = col2->mutable_data(1);
+
+ uint32_t fixed_length = rows.metadata().fixed_length;
+ const uint32_t* offsets;
+ const uint8_t* src_base;
+ if (is_row_fixed_length) {
+ src_base = rows.data(1) + fixed_length * start_row + offset_within_row;
+ offsets = nullptr;
+ } else {
+ src_base = rows.data(2) + offset_within_row;
+ offsets = rows.offsets() + start_row;
+ }
+
+ constexpr int unroll = 32 / col_width;
+
+ uint32_t num_processed = num_rows / unroll * unroll;
+
+ if (col_width == 8) {
+ for (uint32_t i = 0; i < num_rows / unroll; ++i) {
+ const __m128i *src0, *src1, *src2, *src3;
+ if (is_row_fixed_length) {
+ const uint8_t* src = src_base + (i * unroll) * fixed_length;
+ src0 = reinterpret_cast<const __m128i*>(src);
+ src1 = reinterpret_cast<const __m128i*>(src + fixed_length);
+ src2 = reinterpret_cast<const __m128i*>(src + fixed_length * 2);
+ src3 = reinterpret_cast<const __m128i*>(src + fixed_length * 3);
+ } else {
+ const uint32_t* row_offsets = offsets + i * unroll;
+ const uint8_t* src = src_base;
+ src0 = reinterpret_cast<const __m128i*>(src + row_offsets[0]);
+ src1 = reinterpret_cast<const __m128i*>(src + row_offsets[1]);
+ src2 = reinterpret_cast<const __m128i*>(src + row_offsets[2]);
+ src3 = reinterpret_cast<const __m128i*>(src + row_offsets[3]);
+ }
+
+ __m256i r0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_si128(src0)),
+ _mm_loadu_si128(src1), 1);
+ __m256i r1 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_si128(src2)),
+ _mm_loadu_si128(src3), 1);
+
+ r0 = _mm256_permute4x64_epi64(r0, 0xd8); // 0b11011000
+ r1 = _mm256_permute4x64_epi64(r1, 0xd8);
+
+ // First 128-bit lanes from both inputs
+ __m256i c1 = _mm256_permute2x128_si256(r0, r1, 0x20);
+ // Second 128-bit lanes from both inputs
+ __m256i c2 = _mm256_permute2x128_si256(r0, r1, 0x31);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(col_vals_A) + i, c1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(col_vals_B) + i, c2);
+ }
+ } else {
+ uint8_t buffer[64];
+ for (uint32_t i = 0; i < num_rows / unroll; ++i) {
+ if (is_row_fixed_length) {
+ const uint8_t* src = src_base + (i * unroll) * fixed_length;
+ for (int j = 0; j < unroll; ++j) {
+ if (col_width == 1) {
+ reinterpret_cast<uint16_t*>(buffer)[j] =
+ *reinterpret_cast<const uint16_t*>(src + fixed_length * j);
+ } else if (col_width == 2) {
+ reinterpret_cast<uint32_t*>(buffer)[j] =
+ *reinterpret_cast<const uint32_t*>(src + fixed_length * j);
+ } else if (col_width == 4) {
+ reinterpret_cast<uint64_t*>(buffer)[j] =
+ *reinterpret_cast<const uint64_t*>(src + fixed_length * j);
+ }
+ }
+ } else {
+ const uint32_t* row_offsets = offsets + i * unroll;
+ const uint8_t* src = src_base;
+ for (int j = 0; j < unroll; ++j) {
+ if (col_width == 1) {
+ reinterpret_cast<uint16_t*>(buffer)[j] =
+ *reinterpret_cast<const uint16_t*>(src + row_offsets[j]);
+ } else if (col_width == 2) {
+ reinterpret_cast<uint32_t*>(buffer)[j] =
+ *reinterpret_cast<const uint32_t*>(src + row_offsets[j]);
+ } else if (col_width == 4) {
+ reinterpret_cast<uint64_t*>(buffer)[j] =
+ *reinterpret_cast<const uint64_t*>(src + row_offsets[j]);
+ }
+ }
+ }
+
+ __m256i r0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer));
+ __m256i r1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer) + 1);
+
+ constexpr uint64_t kByteSequence_0_2_4_6_8_10_12_14 = 0x0e0c0a0806040200ULL;
+ constexpr uint64_t kByteSequence_1_3_5_7_9_11_13_15 = 0x0f0d0b0907050301ULL;
+ constexpr uint64_t kByteSequence_0_1_4_5_8_9_12_13 = 0x0d0c090805040100ULL;
+ constexpr uint64_t kByteSequence_2_3_6_7_10_11_14_15 = 0x0f0e0b0a07060302ULL;
+
+ if (col_width == 1) {
+ // Collect every second byte next to each other
+ const __m256i shuffle_const = _mm256_setr_epi64x(
+ kByteSequence_0_2_4_6_8_10_12_14, kByteSequence_1_3_5_7_9_11_13_15,
+ kByteSequence_0_2_4_6_8_10_12_14, kByteSequence_1_3_5_7_9_11_13_15);
+ r0 = _mm256_shuffle_epi8(r0, shuffle_const);
+ r1 = _mm256_shuffle_epi8(r1, shuffle_const);
+ // 0b11011000 swapping second and third 64-bit lane
+ r0 = _mm256_permute4x64_epi64(r0, 0xd8);
+ r1 = _mm256_permute4x64_epi64(r1, 0xd8);
+ } else if (col_width == 2) {
+ // Collect every second 16-bit word next to each other
+ const __m256i shuffle_const = _mm256_setr_epi64x(
+ kByteSequence_0_1_4_5_8_9_12_13, kByteSequence_2_3_6_7_10_11_14_15,
+ kByteSequence_0_1_4_5_8_9_12_13, kByteSequence_2_3_6_7_10_11_14_15);
+ r0 = _mm256_shuffle_epi8(r0, shuffle_const);
+ r1 = _mm256_shuffle_epi8(r1, shuffle_const);
+ // 0b11011000 swapping second and third 64-bit lane
+ r0 = _mm256_permute4x64_epi64(r0, 0xd8);
+ r1 = _mm256_permute4x64_epi64(r1, 0xd8);
+ } else if (col_width == 4) {
+ // Collect every second 32-bit word next to each other
+ const __m256i permute_const = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
+ r0 = _mm256_permutevar8x32_epi32(r0, permute_const);
+ r1 = _mm256_permutevar8x32_epi32(r1, permute_const);
+ }
+
+ // First 128-bit lanes from both inputs
+ __m256i c1 = _mm256_permute2x128_si256(r0, r1, 0x20);
+ // Second 128-bit lanes from both inputs
+ __m256i c2 = _mm256_permute2x128_si256(r0, r1, 0x31);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(col_vals_A) + i, c1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(col_vals_B) + i, c2);
+ }
+ }
+
+ return num_processed;
+}
+
+void KeyEncoder::EncoderVarBinary::DecodeHelper_avx2(uint32_t start_row,
+ uint32_t num_rows,
+ uint32_t varbinary_col_id,
+ const KeyRowArray& rows,
+ KeyColumnArray* col) {
+ if (varbinary_col_id == 0) {
+ DecodeImp_avx2<true>(start_row, num_rows, varbinary_col_id, rows, col);
+ } else {
+ DecodeImp_avx2<false>(start_row, num_rows, varbinary_col_id, rows, col);
+ }
+}
+
+template <bool first_varbinary_col>
+void KeyEncoder::EncoderVarBinary::DecodeImp_avx2(uint32_t start_row, uint32_t num_rows,
+ uint32_t varbinary_col_id,
+ const KeyRowArray& rows,
+ KeyColumnArray* col) {
+ DecodeHelper<first_varbinary_col>(
+ start_row, num_rows, varbinary_col_id, &rows, nullptr, col, col,
+ [](uint8_t* dst, const uint8_t* src, int64_t length) {
+ for (uint32_t istripe = 0; istripe < (length + 31) / 32; ++istripe) {
+ __m256i* dst256 = reinterpret_cast<__m256i*>(dst);
+ const __m256i* src256 = reinterpret_cast<const __m256i*>(src);
+ _mm256_storeu_si256(dst256 + istripe, _mm256_loadu_si256(src256 + istripe));
+ }
+ });
+}
+
+#endif
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/key_hash.cc b/src/arrow/cpp/src/arrow/compute/exec/key_hash.cc
new file mode 100644
index 000000000..76c8ed1ef
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/key_hash.cc
@@ -0,0 +1,319 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/key_hash.h"
+
+#include <memory.h>
+
+#include <algorithm>
+#include <cstdint>
+
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/ubsan.h"
+
+namespace arrow {
+namespace compute {
+
+inline uint32_t Hashing::avalanche_helper(uint32_t acc) {
+ acc ^= (acc >> 15);
+ acc *= PRIME32_2;
+ acc ^= (acc >> 13);
+ acc *= PRIME32_3;
+ acc ^= (acc >> 16);
+ return acc;
+}
+
+void Hashing::avalanche(int64_t hardware_flags, uint32_t num_keys, uint32_t* hashes) {
+ uint32_t processed = 0;
+#if defined(ARROW_HAVE_AVX2)
+ if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
+ int tail = num_keys % 8;
+ avalanche_avx2(num_keys - tail, hashes);
+ processed = num_keys - tail;
+ }
+#endif
+ for (uint32_t i = processed; i < num_keys; ++i) {
+ hashes[i] = avalanche_helper(hashes[i]);
+ }
+}
+
+inline uint32_t Hashing::combine_accumulators(const uint32_t acc1, const uint32_t acc2,
+ const uint32_t acc3, const uint32_t acc4) {
+ return ROTL(acc1, 1) + ROTL(acc2, 7) + ROTL(acc3, 12) + ROTL(acc4, 18);
+}
+
+template <typename T>
+inline void Hashing::helper_8B(uint32_t key_length, uint32_t num_keys, const T* keys,
+ uint32_t* hashes) {
+ ARROW_DCHECK(key_length <= 8);
+ constexpr uint64_t multiplier = 14029467366897019727ULL;
+ for (uint32_t ikey = 0; ikey < num_keys; ++ikey) {
+ uint64_t x = static_cast<uint64_t>(keys[ikey]);
+ hashes[ikey] = static_cast<uint32_t>(BYTESWAP(x * multiplier));
+ }
+}
+
+inline void Hashing::helper_stripe(uint32_t offset, uint64_t mask_hi, const uint8_t* keys,
+ uint32_t& acc1, uint32_t& acc2, uint32_t& acc3,
+ uint32_t& acc4) {
+ uint64_t v1 = util::SafeLoadAs<const uint64_t>(keys + offset);
+ // We do not need to mask v1, because we will not process a stripe
+ // unless at least 9 bytes of it are part of the key.
+ uint64_t v2 = util::SafeLoadAs<const uint64_t>(keys + offset + 8);
+ v2 &= mask_hi;
+ uint32_t x1 = static_cast<uint32_t>(v1);
+ uint32_t x2 = static_cast<uint32_t>(v1 >> 32);
+ uint32_t x3 = static_cast<uint32_t>(v2);
+ uint32_t x4 = static_cast<uint32_t>(v2 >> 32);
+ acc1 += x1 * PRIME32_2;
+ acc1 = ROTL(acc1, 13) * PRIME32_1;
+ acc2 += x2 * PRIME32_2;
+ acc2 = ROTL(acc2, 13) * PRIME32_1;
+ acc3 += x3 * PRIME32_2;
+ acc3 = ROTL(acc3, 13) * PRIME32_1;
+ acc4 += x4 * PRIME32_2;
+ acc4 = ROTL(acc4, 13) * PRIME32_1;
+}
+
+void Hashing::helper_stripes(int64_t hardware_flags, uint32_t num_keys,
+ uint32_t key_length, const uint8_t* keys, uint32_t* hash) {
+ uint32_t processed = 0;
+#if defined(ARROW_HAVE_AVX2)
+ if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
+ int tail = num_keys % 2;
+ helper_stripes_avx2(num_keys - tail, key_length, keys, hash);
+ processed = num_keys - tail;
+ }
+#endif
+
+ // If length modulo stripe length is less than or equal 8, round down to the nearest 16B
+ // boundary (8B ending will be processed in a separate function), otherwise round up.
+ const uint32_t num_stripes = (key_length + 7) / 16;
+ uint64_t mask_hi =
+ ~0ULL >>
+ (8 * ((num_stripes * 16 > key_length) ? num_stripes * 16 - key_length : 0));
+
+ for (uint32_t i = processed; i < num_keys; ++i) {
+ uint32_t acc1, acc2, acc3, acc4;
+ acc1 = static_cast<uint32_t>(
+ (static_cast<uint64_t>(PRIME32_1) + static_cast<uint64_t>(PRIME32_2)) &
+ 0xffffffff);
+ acc2 = PRIME32_2;
+ acc3 = 0;
+ acc4 = static_cast<uint32_t>(-static_cast<int32_t>(PRIME32_1));
+ uint32_t offset = i * key_length;
+ for (uint32_t stripe = 0; stripe < num_stripes - 1; ++stripe) {
+ helper_stripe(offset, ~0ULL, keys, acc1, acc2, acc3, acc4);
+ offset += 16;
+ }
+ helper_stripe(offset, mask_hi, keys, acc1, acc2, acc3, acc4);
+ hash[i] = combine_accumulators(acc1, acc2, acc3, acc4);
+ }
+}
+
+inline uint32_t Hashing::helper_tail(uint32_t offset, uint64_t mask, const uint8_t* keys,
+ uint32_t acc) {
+ uint64_t v = util::SafeLoadAs<const uint64_t>(keys + offset);
+ v &= mask;
+ uint32_t x1 = static_cast<uint32_t>(v);
+ uint32_t x2 = static_cast<uint32_t>(v >> 32);
+ acc += x1 * PRIME32_3;
+ acc = ROTL(acc, 17) * PRIME32_4;
+ acc += x2 * PRIME32_3;
+ acc = ROTL(acc, 17) * PRIME32_4;
+ return acc;
+}
+
+void Hashing::helper_tails(int64_t hardware_flags, uint32_t num_keys, uint32_t key_length,
+ const uint8_t* keys, uint32_t* hash) {
+ uint32_t processed = 0;
+#if defined(ARROW_HAVE_AVX2)
+ if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
+ int tail = num_keys % 8;
+ helper_tails_avx2(num_keys - tail, key_length, keys, hash);
+ processed = num_keys - tail;
+ }
+#endif
+ uint64_t mask = ~0ULL >> (8 * (((key_length % 8) == 0) ? 0 : 8 - (key_length % 8)));
+ uint32_t offset = key_length / 16 * 16;
+ offset += processed * key_length;
+ for (uint32_t i = processed; i < num_keys; ++i) {
+ hash[i] = helper_tail(offset, mask, keys, hash[i]);
+ offset += key_length;
+ }
+}
+
+void Hashing::hash_fixed(int64_t hardware_flags, uint32_t num_keys, uint32_t length_key,
+ const uint8_t* keys, uint32_t* hashes) {
+ ARROW_DCHECK(length_key > 0);
+
+ if (length_key <= 8 && ARROW_POPCOUNT64(length_key) == 1) {
+ switch (length_key) {
+ case 1:
+ helper_8B(length_key, num_keys, keys, hashes);
+ break;
+ case 2:
+ helper_8B(length_key, num_keys, reinterpret_cast<const uint16_t*>(keys), hashes);
+ break;
+ case 4:
+ helper_8B(length_key, num_keys, reinterpret_cast<const uint32_t*>(keys), hashes);
+ break;
+ case 8:
+ helper_8B(length_key, num_keys, reinterpret_cast<const uint64_t*>(keys), hashes);
+ break;
+ default:
+ ARROW_DCHECK(false);
+ }
+ return;
+ }
+ helper_stripes(hardware_flags, num_keys, length_key, keys, hashes);
+ if ((length_key % 16) > 0 && (length_key % 16) <= 8) {
+ helper_tails(hardware_flags, num_keys, length_key, keys, hashes);
+ }
+ avalanche(hardware_flags, num_keys, hashes);
+}
+
+void Hashing::hash_varlen(int64_t hardware_flags, uint32_t num_rows,
+ const uint32_t* offsets, const uint8_t* concatenated_keys,
+ uint32_t* temp_buffer, // Needs to hold 4 x 32-bit per row
+ uint32_t* hashes) {
+#if defined(ARROW_HAVE_AVX2)
+ if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
+ hash_varlen_avx2(num_rows, offsets, concatenated_keys, temp_buffer, hashes);
+ return;
+ }
+#endif
+ static const uint64_t masks[9] = {0,
+ 0xffULL,
+ 0xffffULL,
+ 0xffffffULL,
+ 0xffffffffULL,
+ 0xffffffffffULL,
+ 0xffffffffffffULL,
+ 0xffffffffffffffULL,
+ ~0ULL};
+
+ for (uint32_t i = 0; i < num_rows; ++i) {
+ uint32_t offset = offsets[i];
+ uint32_t key_length = offsets[i + 1] - offsets[i];
+ const uint32_t num_stripes = key_length / 16;
+
+ uint32_t acc1, acc2, acc3, acc4;
+ acc1 = static_cast<uint32_t>(
+ (static_cast<uint64_t>(PRIME32_1) + static_cast<uint64_t>(PRIME32_2)) &
+ 0xffffffff);
+ acc2 = PRIME32_2;
+ acc3 = 0;
+ acc4 = static_cast<uint32_t>(-static_cast<int32_t>(PRIME32_1));
+
+ for (uint32_t stripe = 0; stripe < num_stripes; ++stripe) {
+ helper_stripe(offset, ~0ULL, concatenated_keys, acc1, acc2, acc3, acc4);
+ offset += 16;
+ }
+ uint32_t key_length_remaining = key_length - num_stripes * 16;
+ if (key_length_remaining > 8) {
+ helper_stripe(offset, masks[key_length_remaining - 8], concatenated_keys, acc1,
+ acc2, acc3, acc4);
+ hashes[i] = combine_accumulators(acc1, acc2, acc3, acc4);
+ } else if (key_length > 0) {
+ uint32_t acc_combined = combine_accumulators(acc1, acc2, acc3, acc4);
+ hashes[i] = helper_tail(offset, masks[key_length_remaining], concatenated_keys,
+ acc_combined);
+ } else {
+ hashes[i] = combine_accumulators(acc1, acc2, acc3, acc4);
+ }
+ }
+ avalanche(hardware_flags, num_rows, hashes);
+}
+
+// From:
+// https://www.boost.org/doc/libs/1_37_0/doc/html/hash/reference.html#boost.hash_combine
+// template <class T>
+// inline void hash_combine(std::size_t& seed, const T& v)
+//{
+// std::hash<T> hasher;
+// seed ^= hasher(v) + 0x9e3779b9 + (seed<<6) + (seed>>2);
+//}
+void Hashing::HashCombine(KeyEncoder::KeyEncoderContext* ctx, uint32_t num_rows,
+ uint32_t* accumulated_hash, const uint32_t* next_column_hash) {
+ uint32_t num_processed = 0;
+#if defined(ARROW_HAVE_AVX2)
+ if (ctx->has_avx2()) {
+ num_processed = HashCombine_avx2(num_rows, accumulated_hash, next_column_hash);
+ }
+#endif
+ for (uint32_t i = num_processed; i < num_rows; ++i) {
+ uint32_t acc = accumulated_hash[i];
+ uint32_t next = next_column_hash[i];
+ next += 0x9e3779b9 + (acc << 6) + (acc >> 2);
+ acc ^= next;
+ accumulated_hash[i] = acc;
+ }
+}
+
+void Hashing::HashMultiColumn(const std::vector<KeyEncoder::KeyColumnArray>& cols,
+ KeyEncoder::KeyEncoderContext* ctx, uint32_t* out_hash) {
+ uint32_t num_rows = static_cast<uint32_t>(cols[0].length());
+
+ auto hash_temp_buf = util::TempVectorHolder<uint32_t>(ctx->stack, num_rows);
+ auto hash_null_index_buf = util::TempVectorHolder<uint16_t>(ctx->stack, num_rows);
+ auto byte_temp_buf = util::TempVectorHolder<uint8_t>(ctx->stack, num_rows);
+ auto varbin_temp_buf = util::TempVectorHolder<uint32_t>(ctx->stack, 4 * num_rows);
+
+ bool is_first = true;
+
+ for (size_t icol = 0; icol < cols.size(); ++icol) {
+ if (cols[icol].metadata().is_fixed_length) {
+ uint32_t col_width = cols[icol].metadata().fixed_length;
+ if (col_width == 0) {
+ util::BitUtil::bits_to_bytes(ctx->hardware_flags, num_rows, cols[icol].data(1),
+ byte_temp_buf.mutable_data(),
+ cols[icol].bit_offset(1));
+ }
+ Hashing::hash_fixed(
+ ctx->hardware_flags, num_rows, col_width == 0 ? 1 : col_width,
+ col_width == 0 ? byte_temp_buf.mutable_data() : cols[icol].data(1),
+ is_first ? out_hash : hash_temp_buf.mutable_data());
+ } else {
+ Hashing::hash_varlen(
+ ctx->hardware_flags, num_rows, cols[icol].offsets(), cols[icol].data(2),
+ varbin_temp_buf.mutable_data(), // Needs to hold 4 x 32-bit per row
+ is_first ? out_hash : hash_temp_buf.mutable_data());
+ }
+
+ // Zero hash for nulls
+ if (cols[icol].data(0)) {
+ uint32_t* dst_hash = is_first ? out_hash : hash_temp_buf.mutable_data();
+ int num_nulls;
+ util::BitUtil::bits_to_indexes(0, ctx->hardware_flags, num_rows, cols[icol].data(0),
+ &num_nulls, hash_null_index_buf.mutable_data(),
+ cols[icol].bit_offset(0));
+ for (int i = 0; i < num_nulls; ++i) {
+ uint16_t row_id = hash_null_index_buf.mutable_data()[i];
+ dst_hash[row_id] = 0;
+ }
+ }
+
+ if (!is_first) {
+ HashCombine(ctx, num_rows, out_hash, hash_temp_buf.mutable_data());
+ }
+ is_first = false;
+ }
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/key_hash.h b/src/arrow/cpp/src/arrow/compute/exec/key_hash.h
new file mode 100644
index 000000000..a0ed42cf8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/key_hash.h
@@ -0,0 +1,106 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#if defined(ARROW_HAVE_AVX2)
+#include <immintrin.h>
+#endif
+
+#include <cstdint>
+
+#include "arrow/compute/exec/key_encode.h"
+#include "arrow/compute/exec/util.h"
+
+namespace arrow {
+namespace compute {
+
+// Implementations are based on xxh3 32-bit algorithm description from:
+// https://github.com/Cyan4973/xxHash/blob/dev/doc/xxhash_spec.md
+//
+class Hashing {
+ public:
+ static void hash_fixed(int64_t hardware_flags, uint32_t num_keys, uint32_t length_key,
+ const uint8_t* keys, uint32_t* hashes);
+
+ static void hash_varlen(int64_t hardware_flags, uint32_t num_rows,
+ const uint32_t* offsets, const uint8_t* concatenated_keys,
+ uint32_t* temp_buffer, // Needs to hold 4 x 32-bit per row
+ uint32_t* hashes);
+
+ static void HashMultiColumn(const std::vector<KeyEncoder::KeyColumnArray>& cols,
+ KeyEncoder::KeyEncoderContext* ctx, uint32_t* out_hash);
+
+ private:
+ static const uint32_t PRIME32_1 = 0x9E3779B1; // 0b10011110001101110111100110110001
+ static const uint32_t PRIME32_2 = 0x85EBCA77; // 0b10000101111010111100101001110111
+ static const uint32_t PRIME32_3 = 0xC2B2AE3D; // 0b11000010101100101010111000111101
+ static const uint32_t PRIME32_4 = 0x27D4EB2F; // 0b00100111110101001110101100101111
+ static const uint32_t PRIME32_5 = 0x165667B1; // 0b00010110010101100110011110110001
+
+ static void HashCombine(KeyEncoder::KeyEncoderContext* ctx, uint32_t num_rows,
+ uint32_t* accumulated_hash, const uint32_t* next_column_hash);
+
+#if defined(ARROW_HAVE_AVX2)
+ static uint32_t HashCombine_avx2(uint32_t num_rows, uint32_t* accumulated_hash,
+ const uint32_t* next_column_hash);
+#endif
+
+ // Avalanche
+ static inline uint32_t avalanche_helper(uint32_t acc);
+#if defined(ARROW_HAVE_AVX2)
+ static void avalanche_avx2(uint32_t num_keys, uint32_t* hashes);
+#endif
+ static void avalanche(int64_t hardware_flags, uint32_t num_keys, uint32_t* hashes);
+
+ // Accumulator combine
+ static inline uint32_t combine_accumulators(const uint32_t acc1, const uint32_t acc2,
+ const uint32_t acc3, const uint32_t acc4);
+#if defined(ARROW_HAVE_AVX2)
+ static inline uint64_t combine_accumulators_avx2(__m256i acc);
+#endif
+
+ // Helpers
+ template <typename T>
+ static inline void helper_8B(uint32_t key_length, uint32_t num_keys, const T* keys,
+ uint32_t* hashes);
+ static inline void helper_stripe(uint32_t offset, uint64_t mask_hi, const uint8_t* keys,
+ uint32_t& acc1, uint32_t& acc2, uint32_t& acc3,
+ uint32_t& acc4);
+ static inline uint32_t helper_tail(uint32_t offset, uint64_t mask, const uint8_t* keys,
+ uint32_t acc);
+#if defined(ARROW_HAVE_AVX2)
+ static void helper_stripes_avx2(uint32_t num_keys, uint32_t key_length,
+ const uint8_t* keys, uint32_t* hash);
+ static void helper_tails_avx2(uint32_t num_keys, uint32_t key_length,
+ const uint8_t* keys, uint32_t* hash);
+#endif
+ static void helper_stripes(int64_t hardware_flags, uint32_t num_keys,
+ uint32_t key_length, const uint8_t* keys, uint32_t* hash);
+ static void helper_tails(int64_t hardware_flags, uint32_t num_keys, uint32_t key_length,
+ const uint8_t* keys, uint32_t* hash);
+
+#if defined(ARROW_HAVE_AVX2)
+ static void hash_varlen_avx2(uint32_t num_rows, const uint32_t* offsets,
+ const uint8_t* concatenated_keys,
+ uint32_t* temp_buffer, // Needs to hold 4 x 32-bit per row
+ uint32_t* hashes);
+#endif
+};
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/key_hash_avx2.cc b/src/arrow/cpp/src/arrow/compute/exec/key_hash_avx2.cc
new file mode 100644
index 000000000..3804afe10
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/key_hash_avx2.cc
@@ -0,0 +1,268 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <immintrin.h>
+
+#include "arrow/compute/exec/key_hash.h"
+
+namespace arrow {
+namespace compute {
+
+#if defined(ARROW_HAVE_AVX2)
+
+void Hashing::avalanche_avx2(uint32_t num_keys, uint32_t* hashes) {
+ constexpr int unroll = 8;
+ ARROW_DCHECK(num_keys % unroll == 0);
+ for (uint32_t i = 0; i < num_keys / unroll; ++i) {
+ __m256i hash = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(hashes) + i);
+ hash = _mm256_xor_si256(hash, _mm256_srli_epi32(hash, 15));
+ hash = _mm256_mullo_epi32(hash, _mm256_set1_epi32(PRIME32_2));
+ hash = _mm256_xor_si256(hash, _mm256_srli_epi32(hash, 13));
+ hash = _mm256_mullo_epi32(hash, _mm256_set1_epi32(PRIME32_3));
+ hash = _mm256_xor_si256(hash, _mm256_srli_epi32(hash, 16));
+ _mm256_storeu_si256((reinterpret_cast<__m256i*>(hashes)) + i, hash);
+ }
+}
+
+inline uint64_t Hashing::combine_accumulators_avx2(__m256i acc) {
+ acc = _mm256_or_si256(
+ _mm256_sllv_epi32(acc, _mm256_setr_epi32(1, 7, 12, 18, 1, 7, 12, 18)),
+ _mm256_srlv_epi32(acc, _mm256_setr_epi32(32 - 1, 32 - 7, 32 - 12, 32 - 18, 32 - 1,
+ 32 - 7, 32 - 12, 32 - 18)));
+ acc = _mm256_add_epi32(acc, _mm256_shuffle_epi32(acc, 0xee)); // 0b11101110
+ acc = _mm256_add_epi32(acc, _mm256_srli_epi64(acc, 32));
+ acc = _mm256_permutevar8x32_epi32(acc, _mm256_setr_epi32(0, 4, 0, 0, 0, 0, 0, 0));
+ uint64_t result = _mm256_extract_epi64(acc, 0);
+ return result;
+}
+
+void Hashing::helper_stripes_avx2(uint32_t num_keys, uint32_t key_length,
+ const uint8_t* keys, uint32_t* hash) {
+ constexpr int unroll = 2;
+ ARROW_DCHECK(num_keys % unroll == 0);
+
+ constexpr uint64_t kByteSequence0To7 = 0x0706050403020100ULL;
+ constexpr uint64_t kByteSequence8To15 = 0x0f0e0d0c0b0a0908ULL;
+
+ const __m256i mask_last_stripe =
+ (key_length % 16) <= 8
+ ? _mm256_set1_epi8(static_cast<char>(0xffU))
+ : _mm256_cmpgt_epi8(_mm256_set1_epi8(key_length % 16),
+ _mm256_setr_epi64x(kByteSequence0To7, kByteSequence8To15,
+ kByteSequence0To7, kByteSequence8To15));
+
+ // If length modulo stripe length is less than or equal 8, round down to the nearest 16B
+ // boundary (8B ending will be processed in a separate function), otherwise round up.
+ const uint32_t num_stripes = (key_length + 7) / 16;
+ for (uint32_t i = 0; i < num_keys / unroll; ++i) {
+ __m256i acc = _mm256_setr_epi32(
+ static_cast<uint32_t>((static_cast<uint64_t>(PRIME32_1) + PRIME32_2) &
+ 0xffffffff),
+ PRIME32_2, 0, static_cast<uint32_t>(-static_cast<int32_t>(PRIME32_1)),
+ static_cast<uint32_t>((static_cast<uint64_t>(PRIME32_1) + PRIME32_2) &
+ 0xffffffff),
+ PRIME32_2, 0, static_cast<uint32_t>(-static_cast<int32_t>(PRIME32_1)));
+ auto key0 = reinterpret_cast<const __m128i*>(keys + key_length * 2 * i);
+ auto key1 = reinterpret_cast<const __m128i*>(keys + key_length * 2 * i + key_length);
+ for (uint32_t stripe = 0; stripe < num_stripes - 1; ++stripe) {
+ auto key_stripe =
+ _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_si128(key0 + stripe)),
+ _mm_loadu_si128(key1 + stripe), 1);
+ acc = _mm256_add_epi32(
+ acc, _mm256_mullo_epi32(key_stripe, _mm256_set1_epi32(PRIME32_2)));
+ acc = _mm256_or_si256(_mm256_slli_epi32(acc, 13), _mm256_srli_epi32(acc, 32 - 13));
+ acc = _mm256_mullo_epi32(acc, _mm256_set1_epi32(PRIME32_1));
+ }
+ auto key_stripe = _mm256_inserti128_si256(
+ _mm256_castsi128_si256(_mm_loadu_si128(key0 + num_stripes - 1)),
+ _mm_loadu_si128(key1 + num_stripes - 1), 1);
+ key_stripe = _mm256_and_si256(key_stripe, mask_last_stripe);
+ acc = _mm256_add_epi32(acc,
+ _mm256_mullo_epi32(key_stripe, _mm256_set1_epi32(PRIME32_2)));
+ acc = _mm256_or_si256(_mm256_slli_epi32(acc, 13), _mm256_srli_epi32(acc, 32 - 13));
+ acc = _mm256_mullo_epi32(acc, _mm256_set1_epi32(PRIME32_1));
+ uint64_t result = combine_accumulators_avx2(acc);
+ reinterpret_cast<uint64_t*>(hash)[i] = result;
+ }
+}
+
+void Hashing::helper_tails_avx2(uint32_t num_keys, uint32_t key_length,
+ const uint8_t* keys, uint32_t* hash) {
+ constexpr int unroll = 8;
+ ARROW_DCHECK(num_keys % unroll == 0);
+ auto keys_i64 = reinterpret_cast<arrow::util::int64_for_gather_t*>(keys);
+
+ // Process between 1 and 8 last bytes of each key, starting from 16B boundary.
+ // The caller needs to make sure that there are no more than 8 bytes to process after
+ // that 16B boundary.
+ uint32_t first_offset = key_length - (key_length % 16);
+ __m256i mask = _mm256_set1_epi64x((~0ULL) >> (8 * (8 - (key_length % 16))));
+ __m256i offset =
+ _mm256_setr_epi32(0, key_length, key_length * 2, key_length * 3, key_length * 4,
+ key_length * 5, key_length * 6, key_length * 7);
+ offset = _mm256_add_epi32(offset, _mm256_set1_epi32(first_offset));
+ __m256i offset_incr = _mm256_set1_epi32(key_length * 8);
+
+ for (uint32_t i = 0; i < num_keys / unroll; ++i) {
+ auto v1 = _mm256_i32gather_epi64(keys_i64, _mm256_castsi256_si128(offset), 1);
+ auto v2 = _mm256_i32gather_epi64(keys_i64, _mm256_extracti128_si256(offset, 1), 1);
+ v1 = _mm256_and_si256(v1, mask);
+ v2 = _mm256_and_si256(v2, mask);
+ v1 = _mm256_permutevar8x32_epi32(v1, _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7));
+ v2 = _mm256_permutevar8x32_epi32(v2, _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7));
+ auto x1 = _mm256_permute2x128_si256(v1, v2, 0x20);
+ auto x2 = _mm256_permute2x128_si256(v1, v2, 0x31);
+ __m256i acc = _mm256_loadu_si256((reinterpret_cast<const __m256i*>(hash)) + i);
+
+ acc = _mm256_add_epi32(acc, _mm256_mullo_epi32(x1, _mm256_set1_epi32(PRIME32_3)));
+ acc = _mm256_or_si256(_mm256_slli_epi32(acc, 17), _mm256_srli_epi32(acc, 32 - 17));
+ acc = _mm256_mullo_epi32(acc, _mm256_set1_epi32(PRIME32_4));
+
+ acc = _mm256_add_epi32(acc, _mm256_mullo_epi32(x2, _mm256_set1_epi32(PRIME32_3)));
+ acc = _mm256_or_si256(_mm256_slli_epi32(acc, 17), _mm256_srli_epi32(acc, 32 - 17));
+ acc = _mm256_mullo_epi32(acc, _mm256_set1_epi32(PRIME32_4));
+
+ _mm256_storeu_si256((reinterpret_cast<__m256i*>(hash)) + i, acc);
+
+ offset = _mm256_add_epi32(offset, offset_incr);
+ }
+}
+
+void Hashing::hash_varlen_avx2(uint32_t num_rows, const uint32_t* offsets,
+ const uint8_t* concatenated_keys,
+ uint32_t* temp_buffer, // Needs to hold 4 x 32-bit per row
+ uint32_t* hashes) {
+ constexpr uint64_t kByteSequence0To7 = 0x0706050403020100ULL;
+ constexpr uint64_t kByteSequence8To15 = 0x0f0e0d0c0b0a0908ULL;
+
+ const __m128i sequence = _mm_set_epi64x(kByteSequence8To15, kByteSequence0To7);
+ const __m128i acc_init = _mm_setr_epi32(
+ static_cast<uint32_t>((static_cast<uint64_t>(PRIME32_1) + PRIME32_2) & 0xffffffff),
+ PRIME32_2, 0, static_cast<uint32_t>(-static_cast<int32_t>(PRIME32_1)));
+
+ // Variable length keys are always processed as a sequence of 16B stripes,
+ // with the last stripe, if extending past the end of the key, having extra bytes set to
+ // 0 on the fly.
+ for (uint32_t ikey = 0; ikey < num_rows; ++ikey) {
+ uint32_t begin = offsets[ikey];
+ uint32_t end = offsets[ikey + 1];
+ uint32_t length = end - begin;
+ const uint8_t* base = concatenated_keys + begin;
+
+ __m128i acc = acc_init;
+
+ if (length) {
+ uint32_t i;
+ for (i = 0; i < (length - 1) / 16; ++i) {
+ __m128i key_stripe = _mm_loadu_si128(reinterpret_cast<const __m128i*>(base) + i);
+ acc = _mm_add_epi32(acc, _mm_mullo_epi32(key_stripe, _mm_set1_epi32(PRIME32_2)));
+ acc = _mm_or_si128(_mm_slli_epi32(acc, 13), _mm_srli_epi32(acc, 32 - 13));
+ acc = _mm_mullo_epi32(acc, _mm_set1_epi32(PRIME32_1));
+ }
+ __m128i key_stripe = _mm_loadu_si128(reinterpret_cast<const __m128i*>(base) + i);
+ __m128i mask = _mm_cmpgt_epi8(_mm_set1_epi8(((length - 1) % 16) + 1), sequence);
+ key_stripe = _mm_and_si128(key_stripe, mask);
+ acc = _mm_add_epi32(acc, _mm_mullo_epi32(key_stripe, _mm_set1_epi32(PRIME32_2)));
+ acc = _mm_or_si128(_mm_slli_epi32(acc, 13), _mm_srli_epi32(acc, 32 - 13));
+ acc = _mm_mullo_epi32(acc, _mm_set1_epi32(PRIME32_1));
+ }
+
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(temp_buffer) + ikey, acc);
+ }
+
+ // Combine accumulators and perform avalanche
+ constexpr int unroll = 8;
+ for (uint32_t i = 0; i < num_rows / unroll; ++i) {
+ __m256i accA =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(temp_buffer) + 4 * i + 0);
+ __m256i accB =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(temp_buffer) + 4 * i + 1);
+ __m256i accC =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(temp_buffer) + 4 * i + 2);
+ __m256i accD =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(temp_buffer) + 4 * i + 3);
+ // Transpose 2x 4x4 32-bit matrices
+ __m256i r0 = _mm256_unpacklo_epi32(accA, accB);
+ __m256i r1 = _mm256_unpackhi_epi32(accA, accB);
+ __m256i r2 = _mm256_unpacklo_epi32(accC, accD);
+ __m256i r3 = _mm256_unpackhi_epi32(accC, accD);
+ accA = _mm256_unpacklo_epi64(r0, r2);
+ accB = _mm256_unpackhi_epi64(r0, r2);
+ accC = _mm256_unpacklo_epi64(r1, r3);
+ accD = _mm256_unpackhi_epi64(r1, r3);
+ // _rotl(accA, 1)
+ // _rotl(accB, 7)
+ // _rotl(accC, 12)
+ // _rotl(accD, 18)
+ accA = _mm256_or_si256(_mm256_slli_epi32(accA, 1), _mm256_srli_epi32(accA, 32 - 1));
+ accB = _mm256_or_si256(_mm256_slli_epi32(accB, 7), _mm256_srli_epi32(accB, 32 - 7));
+ accC = _mm256_or_si256(_mm256_slli_epi32(accC, 12), _mm256_srli_epi32(accC, 32 - 12));
+ accD = _mm256_or_si256(_mm256_slli_epi32(accD, 18), _mm256_srli_epi32(accD, 32 - 18));
+ accA = _mm256_add_epi32(_mm256_add_epi32(accA, accB), _mm256_add_epi32(accC, accD));
+ // avalanche
+ __m256i hash = accA;
+ hash = _mm256_xor_si256(hash, _mm256_srli_epi32(hash, 15));
+ hash = _mm256_mullo_epi32(hash, _mm256_set1_epi32(PRIME32_2));
+ hash = _mm256_xor_si256(hash, _mm256_srli_epi32(hash, 13));
+ hash = _mm256_mullo_epi32(hash, _mm256_set1_epi32(PRIME32_3));
+ hash = _mm256_xor_si256(hash, _mm256_srli_epi32(hash, 16));
+ // Store.
+ // At this point, because of way 2x 4x4 transposition was done, output hashes are in
+ // order: 0, 2, 4, 6, 1, 3, 5, 7. Bring back the original order.
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(hashes) + i,
+ _mm256_permutevar8x32_epi32(hash, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7)));
+ }
+ // Process the tail of up to 7 hashes
+ for (uint32_t i = num_rows - num_rows % unroll; i < num_rows; ++i) {
+ uint32_t* temp_buffer_base = temp_buffer + i * 4;
+ uint32_t acc = ROTL(temp_buffer_base[0], 1) + ROTL(temp_buffer_base[1], 7) +
+ ROTL(temp_buffer_base[2], 12) + ROTL(temp_buffer_base[3], 18);
+
+ // avalanche
+ acc ^= (acc >> 15);
+ acc *= PRIME32_2;
+ acc ^= (acc >> 13);
+ acc *= PRIME32_3;
+ acc ^= (acc >> 16);
+
+ hashes[i] = acc;
+ }
+}
+
+uint32_t Hashing::HashCombine_avx2(uint32_t num_rows, uint32_t* accumulated_hash,
+ const uint32_t* next_column_hash) {
+ constexpr uint32_t unroll = 8;
+ for (uint32_t i = 0; i < num_rows / unroll; ++i) {
+ __m256i acc =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(accumulated_hash) + i);
+ __m256i next =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(next_column_hash) + i);
+ next = _mm256_add_epi32(next, _mm256_set1_epi32(0x9e3779b9));
+ next = _mm256_add_epi32(next, _mm256_slli_epi32(acc, 6));
+ next = _mm256_add_epi32(next, _mm256_srli_epi32(acc, 2));
+ acc = _mm256_xor_si256(acc, next);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(accumulated_hash) + i, acc);
+ }
+ uint32_t num_processed = num_rows / unroll * unroll;
+ return num_processed;
+}
+
+#endif
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/key_map.cc b/src/arrow/cpp/src/arrow/compute/exec/key_map.cc
new file mode 100644
index 000000000..bff352e01
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/key_map.cc
@@ -0,0 +1,862 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/key_map.h"
+
+#include <memory.h>
+
+#include <algorithm>
+#include <cstdint>
+
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/ubsan.h"
+
+namespace arrow {
+
+using BitUtil::CountLeadingZeros;
+
+namespace compute {
+
+constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL;
+
+// Scan bytes in block in reverse and stop as soon
+// as a position of interest is found.
+//
+// Positions of interest:
+// a) slot with a matching stamp is encountered,
+// b) first empty slot is encountered,
+// c) we reach the end of the block.
+//
+// Optionally an index of the first slot to start the search from can be specified.
+// In this case slots before it will be ignored.
+//
+template <bool use_start_slot>
+inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot,
+ int* out_slot, int* out_match_found) const {
+ // Filled slot bytes have the highest bit set to 0 and empty slots are equal to 0x80.
+ uint64_t block_high_bits = block & kHighBitOfEachByte;
+
+ // Replicate 7-bit stamp to all non-empty slots, leaving zeroes for empty slots.
+ uint64_t stamp_pattern = stamp * ((block_high_bits ^ kHighBitOfEachByte) >> 7);
+
+ // If we xor this pattern with block status bytes we get in individual bytes:
+ // a) 0x00, for filled slots matching the stamp,
+ // b) 0x00 < x < 0x80, for filled slots not matching the stamp,
+ // c) 0x80, for empty slots.
+ uint64_t block_xor_pattern = block ^ stamp_pattern;
+
+ // If we then add 0x7f to every byte, we get:
+ // a) 0x7F
+ // b) 0x80 <= x < 0xFF
+ // c) 0xFF
+ uint64_t match_base = block_xor_pattern + ~kHighBitOfEachByte;
+
+ // The highest bit now tells us if we have a match (0) or not (1).
+ // We will negate the bits so that match is represented by a set bit.
+ uint64_t matches = ~match_base;
+
+ // Clear 7 non-relevant bits in each byte.
+ // Also clear bytes that correspond to slots that we were supposed to
+ // skip due to provided start slot index.
+ // Note: the highest byte corresponds to the first slot.
+ if (use_start_slot) {
+ matches &= kHighBitOfEachByte >> (8 * start_slot);
+ } else {
+ matches &= kHighBitOfEachByte;
+ }
+
+ // In case when there are no matches in slots and the block is full (no empty slots),
+ // pretend that there is a match in the last slot.
+ //
+ matches |= (~block_high_bits & 0x80);
+
+ // We get 0 if there are no matches
+ *out_match_found = (matches == 0 ? 0 : 1);
+
+ // Now if we or with the highest bits of the block and scan zero bits in reverse,
+ // we get 8x slot index that we were looking for.
+ // This formula works in all three cases a), b) and c).
+ *out_slot = static_cast<int>(CountLeadingZeros(matches | block_high_bits) >> 3);
+}
+
+inline uint64_t SwissTable::extract_group_id(const uint8_t* block_ptr, int slot,
+ uint64_t group_id_mask) const {
+ // Group id values for all 8 slots in the block are bit-packed and follow the status
+ // bytes. We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In
+ // that case we can extract group id using aligned 64-bit word access.
+ int num_group_id_bits = static_cast<int>(ARROW_POPCOUNT64(group_id_mask));
+ ARROW_DCHECK(num_group_id_bits == 8 || num_group_id_bits == 16 ||
+ num_group_id_bits == 32 || num_group_id_bits == 64);
+
+ int bit_offset = slot * num_group_id_bits;
+ const uint64_t* group_id_bytes =
+ reinterpret_cast<const uint64_t*>(block_ptr) + 1 + (bit_offset >> 6);
+ uint64_t group_id = (*group_id_bytes >> (bit_offset & 63)) & group_id_mask;
+
+ return group_id;
+}
+
+template <typename T, bool use_selection>
+void SwissTable::extract_group_ids_imp(const int num_keys, const uint16_t* selection,
+ const uint32_t* hashes, const uint8_t* local_slots,
+ uint32_t* out_group_ids, int element_offset,
+ int element_multiplier) const {
+ const T* elements = reinterpret_cast<const T*>(blocks_) + element_offset;
+ if (log_blocks_ == 0) {
+ ARROW_DCHECK(sizeof(T) == sizeof(uint8_t));
+ for (int i = 0; i < num_keys; ++i) {
+ uint32_t id = use_selection ? selection[i] : i;
+ uint32_t group_id = blocks_[8 + local_slots[id]];
+ out_group_ids[id] = group_id;
+ }
+ } else {
+ for (int i = 0; i < num_keys; ++i) {
+ uint32_t id = use_selection ? selection[i] : i;
+ uint32_t hash = hashes[id];
+ int64_t pos =
+ (hash >> (bits_hash_ - log_blocks_)) * element_multiplier + local_slots[id];
+ uint32_t group_id = static_cast<uint32_t>(elements[pos]);
+ ARROW_DCHECK(group_id < num_inserted_ || num_inserted_ == 0);
+ out_group_ids[id] = group_id;
+ }
+ }
+}
+
+void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_selection,
+ const uint32_t* hashes, const uint8_t* local_slots,
+ uint32_t* out_group_ids) const {
+ // Group id values for all 8 slots in the block are bit-packed and follow the status
+ // bytes. We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In
+ // that case we can extract group id using aligned 64-bit word access.
+ int num_group_id_bits = num_groupid_bits_from_log_blocks(log_blocks_);
+ ARROW_DCHECK(num_group_id_bits == 8 || num_group_id_bits == 16 ||
+ num_group_id_bits == 32);
+
+ // Optimistically use simplified lookup involving only a start block to find
+ // a single group id candidate for every input.
+#if defined(ARROW_HAVE_AVX2)
+ int num_group_id_bytes = num_group_id_bits / 8;
+ if ((hardware_flags_ & arrow::internal::CpuInfo::AVX2) && !optional_selection) {
+ extract_group_ids_avx2(num_keys, hashes, local_slots, out_group_ids, sizeof(uint64_t),
+ 8 + 8 * num_group_id_bytes, num_group_id_bytes);
+ return;
+ }
+#endif
+ switch (num_group_id_bits) {
+ case 8:
+ if (optional_selection) {
+ extract_group_ids_imp<uint8_t, true>(num_keys, optional_selection, hashes,
+ local_slots, out_group_ids, 8, 16);
+ } else {
+ extract_group_ids_imp<uint8_t, false>(num_keys, nullptr, hashes, local_slots,
+ out_group_ids, 8, 16);
+ }
+ break;
+ case 16:
+ if (optional_selection) {
+ extract_group_ids_imp<uint16_t, true>(num_keys, optional_selection, hashes,
+ local_slots, out_group_ids, 4, 12);
+ } else {
+ extract_group_ids_imp<uint16_t, false>(num_keys, nullptr, hashes, local_slots,
+ out_group_ids, 4, 12);
+ }
+ break;
+ case 32:
+ if (optional_selection) {
+ extract_group_ids_imp<uint32_t, true>(num_keys, optional_selection, hashes,
+ local_slots, out_group_ids, 2, 10);
+ } else {
+ extract_group_ids_imp<uint32_t, false>(num_keys, nullptr, hashes, local_slots,
+ out_group_ids, 2, 10);
+ }
+ break;
+ default:
+ ARROW_DCHECK(false);
+ }
+}
+
+void SwissTable::init_slot_ids(const int num_keys, const uint16_t* selection,
+ const uint32_t* hashes, const uint8_t* local_slots,
+ const uint8_t* match_bitvector,
+ uint32_t* out_slot_ids) const {
+ ARROW_DCHECK(selection);
+ if (log_blocks_ == 0) {
+ for (int i = 0; i < num_keys; ++i) {
+ uint16_t id = selection[i];
+ uint32_t match = ::arrow::BitUtil::GetBit(match_bitvector, id) ? 1 : 0;
+ uint32_t slot_id = local_slots[id] + match;
+ out_slot_ids[id] = slot_id;
+ }
+ } else {
+ for (int i = 0; i < num_keys; ++i) {
+ uint16_t id = selection[i];
+ uint32_t hash = hashes[id];
+ uint32_t iblock = (hash >> (bits_hash_ - log_blocks_));
+ uint32_t match = ::arrow::BitUtil::GetBit(match_bitvector, id) ? 1 : 0;
+ uint32_t slot_id = iblock * 8 + local_slots[id] + match;
+ out_slot_ids[id] = slot_id;
+ }
+ }
+}
+
+void SwissTable::init_slot_ids_for_new_keys(uint32_t num_ids, const uint16_t* ids,
+ const uint32_t* hashes,
+ uint32_t* slot_ids) const {
+ int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
+ uint32_t num_block_bytes = num_groupid_bits + 8;
+ if (log_blocks_ == 0) {
+ uint64_t block = *reinterpret_cast<const uint64_t*>(blocks_);
+ uint32_t empty_slot =
+ static_cast<uint32_t>(8 - ARROW_POPCOUNT64(block & kHighBitOfEachByte));
+ for (uint32_t i = 0; i < num_ids; ++i) {
+ int id = ids[i];
+ slot_ids[id] = empty_slot;
+ }
+ } else {
+ for (uint32_t i = 0; i < num_ids; ++i) {
+ int id = ids[i];
+ uint32_t hash = hashes[id];
+ uint32_t iblock = hash >> (bits_hash_ - log_blocks_);
+ uint64_t block;
+ for (;;) {
+ block = *reinterpret_cast<const uint64_t*>(blocks_ + num_block_bytes * iblock);
+ block &= kHighBitOfEachByte;
+ if (block) {
+ break;
+ }
+ iblock = (iblock + 1) & ((1 << log_blocks_) - 1);
+ }
+ uint32_t empty_slot = static_cast<int>(8 - ARROW_POPCOUNT64(block));
+ slot_ids[id] = iblock * 8 + empty_slot;
+ }
+ }
+}
+
+// Quickly filter out keys that have no matches based only on hash value and the
+// corresponding starting 64-bit block of slot status bytes. May return false positives.
+//
+void SwissTable::early_filter_imp(const int num_keys, const uint32_t* hashes,
+ uint8_t* out_match_bitvector,
+ uint8_t* out_local_slots) const {
+ // Clear the output bit vector
+ memset(out_match_bitvector, 0, (num_keys + 7) / 8);
+
+ // Based on the size of the table, prepare bit number constants.
+ uint32_t stamp_mask = (1 << bits_stamp_) - 1;
+ int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
+
+ for (int i = 0; i < num_keys; ++i) {
+ // Extract from hash: block index and stamp
+ //
+ uint32_t hash = hashes[i];
+ uint32_t iblock = hash >> (bits_hash_ - bits_stamp_ - log_blocks_);
+ uint32_t stamp = iblock & stamp_mask;
+ iblock >>= bits_stamp_;
+
+ uint32_t num_block_bytes = num_groupid_bits + 8;
+ const uint8_t* blockbase = reinterpret_cast<const uint8_t*>(blocks_) +
+ static_cast<uint64_t>(iblock) * num_block_bytes;
+ ARROW_DCHECK(num_block_bytes % sizeof(uint64_t) == 0);
+ uint64_t block = *reinterpret_cast<const uint64_t*>(blockbase);
+
+ // Call helper functions to obtain the output triplet:
+ // - match (of a stamp) found flag
+ // - number of slots to skip before resuming further search, in case of no match or
+ // false positive
+ int match_found;
+ int islot_in_block;
+ search_block<false>(block, stamp, 0, &islot_in_block, &match_found);
+
+ out_match_bitvector[i / 8] |= match_found << (i & 7);
+ out_local_slots[i] = static_cast<uint8_t>(islot_in_block);
+ }
+}
+
+// How many groups we can keep in the hash table without the need for resizing.
+// When we reach this limit, we need to break processing of any further rows and resize.
+//
+uint64_t SwissTable::num_groups_for_resize() const {
+ // Resize small hash tables when 50% full (up to 12KB).
+ // Resize large hash tables when 75% full.
+ constexpr int log_blocks_small_ = 9;
+ uint64_t num_slots = 1ULL << (log_blocks_ + 3);
+ if (log_blocks_ <= log_blocks_small_) {
+ return num_slots / 2;
+ } else {
+ return num_slots * 3 / 4;
+ }
+}
+
+uint64_t SwissTable::wrap_global_slot_id(uint64_t global_slot_id) const {
+ uint64_t global_slot_id_mask = (1 << (log_blocks_ + 3)) - 1;
+ return global_slot_id & global_slot_id_mask;
+}
+
+void SwissTable::early_filter(const int num_keys, const uint32_t* hashes,
+ uint8_t* out_match_bitvector,
+ uint8_t* out_local_slots) const {
+ // Optimistically use simplified lookup involving only a start block to find
+ // a single group id candidate for every input.
+#if defined(ARROW_HAVE_AVX2)
+ if (hardware_flags_ & arrow::internal::CpuInfo::AVX2) {
+ if (log_blocks_ <= 4) {
+ int tail = num_keys % 32;
+ int delta = num_keys - tail;
+ early_filter_imp_avx2_x32(num_keys - tail, hashes, out_match_bitvector,
+ out_local_slots);
+ early_filter_imp_avx2_x8(tail, hashes + delta, out_match_bitvector + delta / 8,
+ out_local_slots + delta);
+ } else {
+ early_filter_imp_avx2_x8(num_keys, hashes, out_match_bitvector, out_local_slots);
+ }
+ } else {
+#endif
+ early_filter_imp(num_keys, hashes, out_match_bitvector, out_local_slots);
+#if defined(ARROW_HAVE_AVX2)
+ }
+#endif
+}
+
+// Input selection may be:
+// - a range of all ids from 0 to num_keys - 1
+// - a selection vector with list of ids
+// - a bit-vector marking ids that are included
+// Either selection index vector or selection bit-vector must be provided
+// but both cannot be set at the same time (one must be null).
+//
+// Input and output selection index vectors are allowed to point to the same buffer
+// (in-place filtering of ids).
+//
+// Output selection vector needs to have enough space for num_keys entries.
+//
+void SwissTable::run_comparisons(const int num_keys,
+ const uint16_t* optional_selection_ids,
+ const uint8_t* optional_selection_bitvector,
+ const uint32_t* groupids, int* out_num_not_equal,
+ uint16_t* out_not_equal_selection) const {
+ ARROW_DCHECK(optional_selection_ids || optional_selection_bitvector);
+ ARROW_DCHECK(!optional_selection_ids || !optional_selection_bitvector);
+
+ if (!optional_selection_ids && optional_selection_bitvector) {
+ // Count rows with matches (based on stamp comparison)
+ // and decide based on their percentage whether to call dense or sparse comparison
+ // function. Dense comparison means evaluating it for all inputs, even if the
+ // matching stamp was not found. It may be cheaper to evaluate comparison for all
+ // inputs if the extra cost of filtering is higher than the wasted processing of
+ // rows with no match.
+ //
+ // Dense comparison can only be used if there is at least one inserted key,
+ // because otherwise there is no key to compare to.
+ //
+ int64_t num_matches = arrow::internal::CountSetBits(optional_selection_bitvector,
+ /*offset=*/0, num_keys);
+
+ if (num_inserted_ > 0 && num_matches > 0 && num_matches > 3 * num_keys / 4) {
+ uint32_t out_num;
+ equal_impl_(num_keys, nullptr, groupids, &out_num, out_not_equal_selection);
+ *out_num_not_equal = static_cast<int>(out_num);
+ } else {
+ util::BitUtil::bits_to_indexes(1, hardware_flags_, num_keys,
+ optional_selection_bitvector, out_num_not_equal,
+ out_not_equal_selection);
+ uint32_t out_num;
+ equal_impl_(*out_num_not_equal, out_not_equal_selection, groupids, &out_num,
+ out_not_equal_selection);
+ *out_num_not_equal = static_cast<int>(out_num);
+ }
+ } else {
+ uint32_t out_num;
+ equal_impl_(num_keys, optional_selection_ids, groupids, &out_num,
+ out_not_equal_selection);
+ *out_num_not_equal = static_cast<int>(out_num);
+ }
+}
+
+// Given starting slot index, search blocks for a matching stamp
+// until one is found or an empty slot is reached.
+// If the search stopped on a non-empty slot, output corresponding
+// group id from that slot.
+//
+// Return true if a match was found.
+//
+bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_slot_id,
+ uint32_t* out_slot_id,
+ uint32_t* out_group_id) const {
+ const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
+ constexpr uint64_t stamp_mask = 0x7f;
+ const int stamp =
+ static_cast<int>((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask);
+ uint64_t start_slot_id = wrap_global_slot_id(in_slot_id);
+ int match_found;
+ int local_slot;
+ uint8_t* blockbase;
+ for (;;) {
+ const uint64_t num_block_bytes = (8 + num_groupid_bits);
+ blockbase = blocks_ + num_block_bytes * (start_slot_id >> 3);
+ uint64_t block = *reinterpret_cast<uint64_t*>(blockbase);
+
+ search_block<true>(block, stamp, (start_slot_id & 7), &local_slot, &match_found);
+
+ start_slot_id =
+ wrap_global_slot_id((start_slot_id & ~7ULL) + local_slot + match_found);
+
+ // Match found can be 1 in two cases:
+ // - match was found
+ // - match was not found in a full block
+ // In the second case search needs to continue in the next block.
+ if (match_found == 0 || blockbase[7 - local_slot] == stamp) {
+ break;
+ }
+ }
+
+ const uint64_t groupid_mask = (1ULL << num_groupid_bits) - 1;
+ *out_group_id =
+ static_cast<uint32_t>(extract_group_id(blockbase, local_slot, groupid_mask));
+ *out_slot_id = static_cast<uint32_t>(start_slot_id);
+
+ return match_found;
+}
+
+void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash,
+ uint32_t group_id) {
+ const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
+
+ // We assume here that the number of bits is rounded up to 8, 16, 32 or 64.
+ // In that case we can insert group id value using aligned 64-bit word access.
+ ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 ||
+ num_groupid_bits == 32 || num_groupid_bits == 64);
+
+ const uint64_t num_block_bytes = (8 + num_groupid_bits);
+ constexpr uint64_t stamp_mask = 0x7f;
+
+ int start_slot = (slot_id & 7);
+ int stamp =
+ static_cast<int>((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask);
+ uint64_t block_id = slot_id >> 3;
+ uint8_t* blockbase = blocks_ + num_block_bytes * block_id;
+
+ blockbase[7 - start_slot] = static_cast<uint8_t>(stamp);
+ int groupid_bit_offset = static_cast<int>(start_slot * num_groupid_bits);
+
+ // Block status bytes should start at an address aligned to 8 bytes
+ ARROW_DCHECK((reinterpret_cast<uint64_t>(blockbase) & 7) == 0);
+ uint64_t* ptr = reinterpret_cast<uint64_t*>(blockbase) + 1 + (groupid_bit_offset >> 6);
+ *ptr |= (static_cast<uint64_t>(group_id) << (groupid_bit_offset & 63));
+
+ hashes_[slot_id] = hash;
+}
+
+// Find method is the continuation of processing from early_filter.
+// Its input consists of hash values and the output of early_filter.
+// It updates match bit-vector, clearing it from any false positives
+// that might have been left by early_filter.
+// It also outputs group ids, which are needed to be able to execute
+// key comparisons. The caller may discard group ids if only the
+// match flag is of interest.
+//
+void SwissTable::find(const int num_keys, const uint32_t* hashes,
+ uint8_t* inout_match_bitvector, const uint8_t* local_slots,
+ uint32_t* out_group_ids) const {
+ // Temporary selection vector.
+ // It will hold ids of keys for which we do not know yet
+ // if they have a match in hash table or not.
+ //
+ // Initially the set of these keys is represented by input
+ // match bit-vector. Eventually we switch from this bit-vector
+ // to array of ids.
+ //
+ ARROW_DCHECK(num_keys <= (1 << log_minibatch_));
+ auto ids_buf = util::TempVectorHolder<uint16_t>(temp_stack_, num_keys);
+ uint16_t* ids = ids_buf.mutable_data();
+ int num_ids;
+
+ int64_t num_matches =
+ arrow::internal::CountSetBits(inout_match_bitvector, /*offset=*/0, num_keys);
+
+ // If there is a high density of selected input rows
+ // (majority of them are present in the selection),
+ // we may run some computation on all of the input rows ignoring
+ // selection and then filter the output of this computation
+ // (pre-filtering vs post-filtering).
+ //
+ bool visit_all = num_matches > 0 && num_matches > 3 * num_keys / 4;
+ if (visit_all) {
+ extract_group_ids(num_keys, nullptr, hashes, local_slots, out_group_ids);
+ run_comparisons(num_keys, nullptr, inout_match_bitvector, out_group_ids, &num_ids,
+ ids);
+ } else {
+ util::BitUtil::bits_to_indexes(1, hardware_flags_, num_keys, inout_match_bitvector,
+ &num_ids, ids);
+ extract_group_ids(num_ids, ids, hashes, local_slots, out_group_ids);
+ run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids);
+ }
+
+ if (num_ids == 0) {
+ return;
+ }
+
+ auto slot_ids_buf = util::TempVectorHolder<uint32_t>(temp_stack_, num_keys);
+ uint32_t* slot_ids = slot_ids_buf.mutable_data();
+ init_slot_ids(num_ids, ids, hashes, local_slots, inout_match_bitvector, slot_ids);
+
+ while (num_ids > 0) {
+ int num_ids_last_iteration = num_ids;
+ num_ids = 0;
+ for (int i = 0; i < num_ids_last_iteration; ++i) {
+ int id = ids[i];
+ uint32_t next_slot_id;
+ bool match_found = find_next_stamp_match(hashes[id], slot_ids[id], &next_slot_id,
+ &(out_group_ids[id]));
+ slot_ids[id] = next_slot_id;
+ // If next match was not found then clear match bit in a bit vector
+ if (!match_found) {
+ ::arrow::BitUtil::ClearBit(inout_match_bitvector, id);
+ } else {
+ ids[num_ids++] = id;
+ }
+ }
+
+ run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids);
+ }
+} // namespace compute
+
+// Slow processing of input keys in the most generic case.
+// Handles inserting new keys.
+// Pre-existing keys will be handled correctly, although the intended use is for this
+// call to follow a call to find() method, which would only pass on new keys that were
+// not present in the hash table.
+//
+// Run a single round of slot search - comparison or insert - filter unprocessed.
+// Update selection vector to reflect which items have been processed.
+// Ids in selection vector do not have to be sorted.
+//
+Status SwissTable::map_new_keys_helper(const uint32_t* hashes,
+ uint32_t* inout_num_selected,
+ uint16_t* inout_selection, bool* out_need_resize,
+ uint32_t* out_group_ids,
+ uint32_t* inout_next_slot_ids) {
+ auto num_groups_limit = num_groups_for_resize();
+ ARROW_DCHECK(num_inserted_ < num_groups_limit);
+
+ // Temporary arrays are of limited size.
+ // The input needs to be split into smaller portions if it exceeds that limit.
+ //
+ ARROW_DCHECK(*inout_num_selected <= static_cast<uint32_t>(1 << log_minibatch_));
+
+ size_t num_bytes_for_bits = (*inout_num_selected + 7) / 8 + sizeof(uint64_t);
+ auto match_bitvector_buf = util::TempVectorHolder<uint8_t>(
+ temp_stack_, static_cast<uint32_t>(num_bytes_for_bits));
+ uint8_t* match_bitvector = match_bitvector_buf.mutable_data();
+ memset(match_bitvector, 0xff, num_bytes_for_bits);
+
+ // Check the alignment of the input selection vector
+ ARROW_DCHECK((reinterpret_cast<uint64_t>(inout_selection) & 1) == 0);
+
+ uint32_t num_inserted_new = 0;
+ uint32_t num_processed;
+ for (num_processed = 0; num_processed < *inout_num_selected; ++num_processed) {
+ // row id in original batch
+ int id = inout_selection[num_processed];
+ bool match_found =
+ find_next_stamp_match(hashes[id], inout_next_slot_ids[id],
+ &inout_next_slot_ids[id], &out_group_ids[id]);
+ if (!match_found) {
+ // If we reach the empty slot we insert key for new group
+ //
+ out_group_ids[id] = num_inserted_ + num_inserted_new;
+ insert_into_empty_slot(inout_next_slot_ids[id], hashes[id], out_group_ids[id]);
+ ::arrow::BitUtil::ClearBit(match_bitvector, num_processed);
+ ++num_inserted_new;
+
+ // We need to break processing and have the caller of this function
+ // resize hash table if we reach the limit of the number of groups present.
+ //
+ if (num_inserted_ + num_inserted_new == num_groups_limit) {
+ ++num_processed;
+ break;
+ }
+ }
+ }
+
+ auto temp_ids_buffer =
+ util::TempVectorHolder<uint16_t>(temp_stack_, *inout_num_selected);
+ uint16_t* temp_ids = temp_ids_buffer.mutable_data();
+ int num_temp_ids = 0;
+
+ // Copy keys for newly inserted rows using callback
+ //
+ util::BitUtil::bits_filter_indexes(0, hardware_flags_, num_processed, match_bitvector,
+ inout_selection, &num_temp_ids, temp_ids);
+ ARROW_DCHECK(static_cast<int>(num_inserted_new) == num_temp_ids);
+ RETURN_NOT_OK(append_impl_(num_inserted_new, temp_ids));
+ num_inserted_ += num_inserted_new;
+
+ // Evaluate comparisons and append ids of rows that failed it to the non-match set.
+ util::BitUtil::bits_filter_indexes(1, hardware_flags_, num_processed, match_bitvector,
+ inout_selection, &num_temp_ids, temp_ids);
+ run_comparisons(num_temp_ids, temp_ids, nullptr, out_group_ids, &num_temp_ids,
+ temp_ids);
+
+ memcpy(inout_selection, temp_ids, sizeof(uint16_t) * num_temp_ids);
+ // Append ids of any unprocessed entries if we aborted processing due to the need
+ // to resize.
+ if (num_processed < *inout_num_selected) {
+ memmove(inout_selection + num_temp_ids, inout_selection + num_processed,
+ sizeof(uint16_t) * (*inout_num_selected - num_processed));
+ }
+ *inout_num_selected = num_temp_ids + (*inout_num_selected - num_processed);
+
+ *out_need_resize = (num_inserted_ == num_groups_limit);
+ return Status::OK();
+}
+
+// Do inserts and find group ids for a set of new keys (with possible duplicates within
+// this set).
+//
+Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* hashes,
+ uint32_t* group_ids) {
+ if (num_ids == 0) {
+ return Status::OK();
+ }
+
+ uint16_t max_id = ids[0];
+ for (uint32_t i = 1; i < num_ids; ++i) {
+ max_id = std::max(max_id, ids[i]);
+ }
+
+ // Temporary buffers have limited size.
+ // Caller is responsible for splitting larger input arrays into smaller chunks.
+ ARROW_DCHECK(static_cast<int>(num_ids) <= (1 << log_minibatch_));
+ ARROW_DCHECK(static_cast<int>(max_id + 1) <= (1 << log_minibatch_));
+
+ // Allocate temporary buffers for slot ids and intialize them
+ auto slot_ids_buf = util::TempVectorHolder<uint32_t>(temp_stack_, max_id + 1);
+ uint32_t* slot_ids = slot_ids_buf.mutable_data();
+ init_slot_ids_for_new_keys(num_ids, ids, hashes, slot_ids);
+
+ do {
+ // A single round of slow-pass (robust) lookup or insert.
+ // A single round ends with either a single comparison verifying the match
+ // candidate or inserting a new key. A single round of slow-pass may return early
+ // if we reach the limit of the number of groups due to inserts of new keys. In
+ // that case we need to resize and recalculating starting global slot ids for new
+ // bigger hash table.
+ bool out_of_capacity;
+ RETURN_NOT_OK(map_new_keys_helper(hashes, &num_ids, ids, &out_of_capacity, group_ids,
+ slot_ids));
+ if (out_of_capacity) {
+ RETURN_NOT_OK(grow_double());
+ // Reset start slot ids for still unprocessed input keys.
+ //
+ for (uint32_t i = 0; i < num_ids; ++i) {
+ // First slot in the new starting block
+ const int16_t id = ids[i];
+ slot_ids[id] = (hashes[id] >> (bits_hash_ - log_blocks_)) * 8;
+ }
+ }
+ } while (num_ids > 0);
+
+ return Status::OK();
+}
+
+Status SwissTable::grow_double() {
+ // Before and after metadata
+ int num_group_id_bits_before = num_groupid_bits_from_log_blocks(log_blocks_);
+ int num_group_id_bits_after = num_groupid_bits_from_log_blocks(log_blocks_ + 1);
+ uint64_t group_id_mask_before = ~0ULL >> (64 - num_group_id_bits_before);
+ int log_blocks_before = log_blocks_;
+ int log_blocks_after = log_blocks_ + 1;
+ uint64_t block_size_before = (8 + num_group_id_bits_before);
+ uint64_t block_size_after = (8 + num_group_id_bits_after);
+ uint64_t block_size_total_before = (block_size_before << log_blocks_before) + padding_;
+ uint64_t block_size_total_after = (block_size_after << log_blocks_after) + padding_;
+ uint64_t hashes_size_total_before =
+ (bits_hash_ / 8 * (1 << (log_blocks_before + 3))) + padding_;
+ uint64_t hashes_size_total_after =
+ (bits_hash_ / 8 * (1 << (log_blocks_after + 3))) + padding_;
+ constexpr uint32_t stamp_mask = (1 << bits_stamp_) - 1;
+
+ // Allocate new buffers
+ uint8_t* blocks_new;
+ RETURN_NOT_OK(pool_->Allocate(block_size_total_after, &blocks_new));
+ memset(blocks_new, 0, block_size_total_after);
+ uint8_t* hashes_new_8B;
+ uint32_t* hashes_new;
+ RETURN_NOT_OK(pool_->Allocate(hashes_size_total_after, &hashes_new_8B));
+ hashes_new = reinterpret_cast<uint32_t*>(hashes_new_8B);
+
+ // First pass over all old blocks.
+ // Reinsert entries that were not in the overflow block
+ // (block other than selected by hash bits corresponding to the entry).
+ for (int i = 0; i < (1 << log_blocks_); ++i) {
+ // How many full slots in this block
+ uint8_t* block_base = blocks_ + i * block_size_before;
+ uint8_t* double_block_base_new = blocks_new + 2 * i * block_size_after;
+ uint64_t block = *reinterpret_cast<const uint64_t*>(block_base);
+
+ auto full_slots =
+ static_cast<int>(CountLeadingZeros(block & kHighBitOfEachByte) >> 3);
+ int full_slots_new[2];
+ full_slots_new[0] = full_slots_new[1] = 0;
+ util::SafeStore(double_block_base_new, kHighBitOfEachByte);
+ util::SafeStore(double_block_base_new + block_size_after, kHighBitOfEachByte);
+
+ for (int j = 0; j < full_slots; ++j) {
+ uint64_t slot_id = i * 8 + j;
+ uint32_t hash = hashes_[slot_id];
+ uint64_t block_id_new = hash >> (bits_hash_ - log_blocks_after);
+ bool is_overflow_entry = ((block_id_new >> 1) != static_cast<uint64_t>(i));
+ if (is_overflow_entry) {
+ continue;
+ }
+
+ int ihalf = block_id_new & 1;
+ uint8_t stamp_new =
+ hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask;
+ uint64_t group_id_bit_offs = j * num_group_id_bits_before;
+ uint64_t group_id =
+ (util::SafeLoadAs<uint64_t>(block_base + 8 + (group_id_bit_offs >> 3)) >>
+ (group_id_bit_offs & 7)) &
+ group_id_mask_before;
+
+ uint64_t slot_id_new = i * 16 + ihalf * 8 + full_slots_new[ihalf];
+ hashes_new[slot_id_new] = hash;
+ uint8_t* block_base_new = double_block_base_new + ihalf * block_size_after;
+ block_base_new[7 - full_slots_new[ihalf]] = stamp_new;
+ int group_id_bit_offs_new = full_slots_new[ihalf] * num_group_id_bits_after;
+ uint64_t* ptr =
+ reinterpret_cast<uint64_t*>(block_base_new + 8 + (group_id_bit_offs_new >> 3));
+ util::SafeStore(ptr,
+ util::SafeLoad(ptr) | (group_id << (group_id_bit_offs_new & 7)));
+ full_slots_new[ihalf]++;
+ }
+ }
+
+ // Second pass over all old blocks.
+ // Reinsert entries that were in an overflow block.
+ for (int i = 0; i < (1 << log_blocks_); ++i) {
+ // How many full slots in this block
+ uint8_t* block_base = blocks_ + i * block_size_before;
+ uint64_t block = util::SafeLoadAs<uint64_t>(block_base);
+ int full_slots = static_cast<int>(CountLeadingZeros(block & kHighBitOfEachByte) >> 3);
+
+ for (int j = 0; j < full_slots; ++j) {
+ uint64_t slot_id = i * 8 + j;
+ uint32_t hash = hashes_[slot_id];
+ uint64_t block_id_new = hash >> (bits_hash_ - log_blocks_after);
+ bool is_overflow_entry = ((block_id_new >> 1) != static_cast<uint64_t>(i));
+ if (!is_overflow_entry) {
+ continue;
+ }
+
+ uint64_t group_id_bit_offs = j * num_group_id_bits_before;
+ uint64_t group_id =
+ (util::SafeLoadAs<uint64_t>(block_base + 8 + (group_id_bit_offs >> 3)) >>
+ (group_id_bit_offs & 7)) &
+ group_id_mask_before;
+ uint8_t stamp_new =
+ hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask;
+
+ uint8_t* block_base_new = blocks_new + block_id_new * block_size_after;
+ uint64_t block_new = util::SafeLoadAs<uint64_t>(block_base_new);
+ int full_slots_new =
+ static_cast<int>(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3);
+ while (full_slots_new == 8) {
+ block_id_new = (block_id_new + 1) & ((1 << log_blocks_after) - 1);
+ block_base_new = blocks_new + block_id_new * block_size_after;
+ block_new = util::SafeLoadAs<uint64_t>(block_base_new);
+ full_slots_new =
+ static_cast<int>(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3);
+ }
+
+ hashes_new[block_id_new * 8 + full_slots_new] = hash;
+ block_base_new[7 - full_slots_new] = stamp_new;
+ int group_id_bit_offs_new = full_slots_new * num_group_id_bits_after;
+ uint64_t* ptr =
+ reinterpret_cast<uint64_t*>(block_base_new + 8 + (group_id_bit_offs_new >> 3));
+ util::SafeStore(ptr,
+ util::SafeLoad(ptr) | (group_id << (group_id_bit_offs_new & 7)));
+ }
+ }
+
+ pool_->Free(blocks_, block_size_total_before);
+ pool_->Free(reinterpret_cast<uint8_t*>(hashes_), hashes_size_total_before);
+ log_blocks_ = log_blocks_after;
+ blocks_ = blocks_new;
+ hashes_ = hashes_new;
+
+ return Status::OK();
+}
+
+Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool,
+ util::TempVectorStack* temp_stack, int log_minibatch,
+ EqualImpl equal_impl, AppendImpl append_impl) {
+ hardware_flags_ = hardware_flags;
+ pool_ = pool;
+ temp_stack_ = temp_stack;
+ log_minibatch_ = log_minibatch;
+ equal_impl_ = equal_impl;
+ append_impl_ = append_impl;
+
+ log_blocks_ = 0;
+ int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
+ num_inserted_ = 0;
+
+ const uint64_t block_bytes = 8 + num_groupid_bits;
+ const uint64_t slot_bytes = (block_bytes << log_blocks_) + padding_;
+ RETURN_NOT_OK(pool_->Allocate(slot_bytes, &blocks_));
+
+ // Make sure group ids are initially set to zero for all slots.
+ memset(blocks_, 0, slot_bytes);
+
+ // Initialize all status bytes to represent an empty slot.
+ for (uint64_t i = 0; i < (static_cast<uint64_t>(1) << log_blocks_); ++i) {
+ util::SafeStore(blocks_ + i * block_bytes, kHighBitOfEachByte);
+ }
+
+ uint64_t num_slots = 1ULL << (log_blocks_ + 3);
+ const uint64_t hash_size = sizeof(uint32_t);
+ const uint64_t hash_bytes = hash_size * num_slots + padding_;
+ uint8_t* hashes8;
+ RETURN_NOT_OK(pool_->Allocate(hash_bytes, &hashes8));
+ hashes_ = reinterpret_cast<uint32_t*>(hashes8);
+
+ return Status::OK();
+}
+
+void SwissTable::cleanup() {
+ if (blocks_) {
+ int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
+ const uint64_t block_bytes = 8 + num_groupid_bits;
+ const uint64_t slot_bytes = (block_bytes << log_blocks_) + padding_;
+ pool_->Free(blocks_, slot_bytes);
+ blocks_ = nullptr;
+ }
+ if (hashes_) {
+ uint64_t num_slots = 1ULL << (log_blocks_ + 3);
+ const uint64_t hash_size = sizeof(uint32_t);
+ const uint64_t hash_bytes = hash_size * num_slots + padding_;
+ pool_->Free(reinterpret_cast<uint8_t*>(hashes_), hash_bytes);
+ hashes_ = nullptr;
+ }
+ log_blocks_ = 0;
+ num_inserted_ = 0;
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/key_map.h b/src/arrow/cpp/src/arrow/compute/exec/key_map.h
new file mode 100644
index 000000000..cf539f4a9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/key_map.h
@@ -0,0 +1,206 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <functional>
+
+#include "arrow/compute/exec/util.h"
+#include "arrow/memory_pool.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+
+namespace arrow {
+namespace compute {
+
+class SwissTable {
+ public:
+ SwissTable() = default;
+ ~SwissTable() { cleanup(); }
+
+ using EqualImpl =
+ std::function<void(int num_keys, const uint16_t* selection /* may be null */,
+ const uint32_t* group_ids, uint32_t* out_num_keys_mismatch,
+ uint16_t* out_selection_mismatch)>;
+ using AppendImpl = std::function<Status(int num_keys, const uint16_t* selection)>;
+
+ Status init(int64_t hardware_flags, MemoryPool* pool, util::TempVectorStack* temp_stack,
+ int log_minibatch, EqualImpl equal_impl, AppendImpl append_impl);
+
+ void cleanup();
+
+ void early_filter(const int num_keys, const uint32_t* hashes,
+ uint8_t* out_match_bitvector, uint8_t* out_local_slots) const;
+
+ void find(const int num_keys, const uint32_t* hashes, uint8_t* inout_match_bitvector,
+ const uint8_t* local_slots, uint32_t* out_group_ids) const;
+
+ Status map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* hashes,
+ uint32_t* group_ids);
+
+ private:
+ // Lookup helpers
+
+ /// \brief Scan bytes in block in reverse and stop as soon
+ /// as a position of interest is found.
+ ///
+ /// Positions of interest:
+ /// a) slot with a matching stamp is encountered,
+ /// b) first empty slot is encountered,
+ /// c) we reach the end of the block.
+ ///
+ /// Optionally an index of the first slot to start the search from can be specified.
+ /// In this case slots before it will be ignored.
+ ///
+ /// \param[in] block 8 byte block of hash table
+ /// \param[in] stamp 7 bits of hash used as a stamp
+ /// \param[in] start_slot Index of the first slot in the block to start search from. We
+ /// assume that this index always points to a non-empty slot, equivalently
+ /// that it comes before any empty slots. (Used only by one template
+ /// variant.)
+ /// \param[out] out_slot index corresponding to the discovered position of interest (8
+ /// represents end of block).
+ /// \param[out] out_match_found an integer flag (0 or 1) indicating if we reached an
+ /// empty slot (0) or not (1). Therefore 1 can mean that either actual match was found
+ /// (case a) above) or we reached the end of full block (case b) above).
+ ///
+ template <bool use_start_slot>
+ inline void search_block(uint64_t block, int stamp, int start_slot, int* out_slot,
+ int* out_match_found) const;
+
+ /// \brief Extract group id for a given slot in a given block.
+ ///
+ inline uint64_t extract_group_id(const uint8_t* block_ptr, int slot,
+ uint64_t group_id_mask) const;
+ void extract_group_ids(const int num_keys, const uint16_t* optional_selection,
+ const uint32_t* hashes, const uint8_t* local_slots,
+ uint32_t* out_group_ids) const;
+
+ template <typename T, bool use_selection>
+ void extract_group_ids_imp(const int num_keys, const uint16_t* selection,
+ const uint32_t* hashes, const uint8_t* local_slots,
+ uint32_t* out_group_ids, int elements_offset,
+ int element_mutltiplier) const;
+
+ inline uint64_t next_slot_to_visit(uint64_t block_index, int slot,
+ int match_found) const;
+
+ inline uint64_t num_groups_for_resize() const;
+
+ inline uint64_t wrap_global_slot_id(uint64_t global_slot_id) const;
+
+ void init_slot_ids(const int num_keys, const uint16_t* selection,
+ const uint32_t* hashes, const uint8_t* local_slots,
+ const uint8_t* match_bitvector, uint32_t* out_slot_ids) const;
+
+ void init_slot_ids_for_new_keys(uint32_t num_ids, const uint16_t* ids,
+ const uint32_t* hashes, uint32_t* slot_ids) const;
+
+ // Quickly filter out keys that have no matches based only on hash value and the
+ // corresponding starting 64-bit block of slot status bytes. May return false positives.
+ //
+ void early_filter_imp(const int num_keys, const uint32_t* hashes,
+ uint8_t* out_match_bitvector, uint8_t* out_local_slots) const;
+#if defined(ARROW_HAVE_AVX2)
+ void early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes,
+ uint8_t* out_match_bitvector,
+ uint8_t* out_local_slots) const;
+ void early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes,
+ uint8_t* out_match_bitvector,
+ uint8_t* out_local_slots) const;
+ void extract_group_ids_avx2(const int num_keys, const uint32_t* hashes,
+ const uint8_t* local_slots, uint32_t* out_group_ids,
+ int byte_offset, int byte_multiplier, int byte_size) const;
+#endif
+
+ void run_comparisons(const int num_keys, const uint16_t* optional_selection_ids,
+ const uint8_t* optional_selection_bitvector,
+ const uint32_t* groupids, int* out_num_not_equal,
+ uint16_t* out_not_equal_selection) const;
+
+ inline bool find_next_stamp_match(const uint32_t hash, const uint32_t in_slot_id,
+ uint32_t* out_slot_id, uint32_t* out_group_id) const;
+
+ inline void insert_into_empty_slot(uint32_t slot_id, uint32_t hash, uint32_t group_id);
+
+ // Slow processing of input keys in the most generic case.
+ // Handles inserting new keys.
+ // Pre-existing keys will be handled correctly, although the intended use is for this
+ // call to follow a call to find() method, which would only pass on new keys that were
+ // not present in the hash table.
+ //
+ Status map_new_keys_helper(const uint32_t* hashes, uint32_t* inout_num_selected,
+ uint16_t* inout_selection, bool* out_need_resize,
+ uint32_t* out_group_ids, uint32_t* out_next_slot_ids);
+
+ // Resize small hash tables when 50% full (up to 8KB).
+ // Resize large hash tables when 75% full.
+ Status grow_double();
+
+ static int num_groupid_bits_from_log_blocks(int log_blocks) {
+ int required_bits = log_blocks + 3;
+ return required_bits <= 8 ? 8
+ : required_bits <= 16 ? 16 : required_bits <= 32 ? 32 : 64;
+ }
+
+ // Use 32-bit hash for now
+ static constexpr int bits_hash_ = 32;
+
+ // Number of hash bits stored in slots in a block.
+ // The highest bits of hash determine block id.
+ // The next set of highest bits is a "stamp" stored in a slot in a block.
+ static constexpr int bits_stamp_ = 7;
+
+ // Padding bytes added at the end of buffers for ease of SIMD access
+ static constexpr int padding_ = 64;
+
+ int log_minibatch_;
+ // Base 2 log of the number of blocks
+ int log_blocks_ = 0;
+ // Number of keys inserted into hash table
+ uint32_t num_inserted_ = 0;
+
+ // Data for blocks.
+ // Each block has 8 status bytes for 8 slots, followed by 8 bit packed group ids for
+ // these slots. In 8B status word, the order of bytes is reversed. Group ids are in
+ // normal order. There is 64B padding at the end.
+ //
+ // 0 byte - 7 bucket | 1. byte - 6 bucket | ...
+ // ---------------------------------------------------
+ // | Empty bit* | Empty bit |
+ // ---------------------------------------------------
+ // | 7-bit hash | 7-bit hash |
+ // ---------------------------------------------------
+ // * Empty bucket has value 0x80. Non-empty bucket has highest bit set to 0.
+ //
+ uint8_t* blocks_;
+
+ // Array of hashes of values inserted into slots.
+ // Undefined if the corresponding slot is empty.
+ // There is 64B padding at the end.
+ uint32_t* hashes_;
+
+ int64_t hardware_flags_;
+ MemoryPool* pool_;
+ util::TempVectorStack* temp_stack_;
+
+ EqualImpl equal_impl_;
+ AppendImpl append_impl_;
+};
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/key_map_avx2.cc b/src/arrow/cpp/src/arrow/compute/exec/key_map_avx2.cc
new file mode 100644
index 000000000..2fca6bf6c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/key_map_avx2.cc
@@ -0,0 +1,414 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <immintrin.h>
+
+#include "arrow/compute/exec/key_map.h"
+
+namespace arrow {
+namespace compute {
+
+#if defined(ARROW_HAVE_AVX2)
+
+// Why it is OK to round up number of rows internally:
+// All of the buffers: hashes, out_match_bitvector, out_group_ids, out_next_slot_ids
+// are temporary buffers of group id mapping.
+// Temporary buffers are buffers that live only within the boundaries of a single
+// minibatch. Temporary buffers add 64B at the end, so that SIMD code does not have to
+// worry about reading and writing outside of the end of the buffer up to 64B. If the
+// hashes array contains garbage after the last element, it cannot cause computation to
+// fail, since any random data is a valid hash for the purpose of lookup.
+//
+// This is more or less translation of equivalent scalar code, adjusted for a different
+// instruction set (e.g. missing leading zero count instruction).
+//
+void SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes,
+ uint8_t* out_match_bitvector,
+ uint8_t* out_local_slots) const {
+ // Number of inputs processed together in a loop
+ constexpr int unroll = 8;
+
+ const int num_group_id_bits = num_groupid_bits_from_log_blocks(log_blocks_);
+ const __m256i* vhash_ptr = reinterpret_cast<const __m256i*>(hashes);
+ const __m256i vstamp_mask = _mm256_set1_epi32((1 << bits_stamp_) - 1);
+
+ // TODO: explain why it is ok to process hashes outside of buffer boundaries
+ for (int i = 0; i < ((num_hashes + unroll - 1) / unroll); ++i) {
+ constexpr uint64_t kEachByteIs8 = 0x0808080808080808ULL;
+ constexpr uint64_t kByteSequenceOfPowersOf2 = 0x8040201008040201ULL;
+
+ // Calculate block index and hash stamp for a byte in a block
+ //
+ __m256i vhash = _mm256_loadu_si256(vhash_ptr + i);
+ __m256i vblock_id = _mm256_srlv_epi32(
+ vhash, _mm256_set1_epi32(bits_hash_ - bits_stamp_ - log_blocks_));
+ __m256i vstamp = _mm256_and_si256(vblock_id, vstamp_mask);
+ vblock_id = _mm256_srli_epi32(vblock_id, bits_stamp_);
+
+ // We now split inputs and process 4 at a time,
+ // in order to process 64-bit blocks
+ //
+ __m256i vblock_offset =
+ _mm256_mullo_epi32(vblock_id, _mm256_set1_epi32(num_group_id_bits + 8));
+ __m256i voffset_A = _mm256_and_si256(vblock_offset, _mm256_set1_epi64x(0xffffffff));
+ __m256i vstamp_A = _mm256_and_si256(vstamp, _mm256_set1_epi64x(0xffffffff));
+ __m256i voffset_B = _mm256_srli_epi64(vblock_offset, 32);
+ __m256i vstamp_B = _mm256_srli_epi64(vstamp, 32);
+
+ auto blocks_i64 = reinterpret_cast<arrow::util::int64_for_gather_t*>(blocks_);
+ auto vblock_A = _mm256_i64gather_epi64(blocks_i64, voffset_A, 1);
+ auto vblock_B = _mm256_i64gather_epi64(blocks_i64, voffset_B, 1);
+ __m256i vblock_highbits_A =
+ _mm256_cmpeq_epi8(vblock_A, _mm256_set1_epi8(static_cast<unsigned char>(0x80)));
+ __m256i vblock_highbits_B =
+ _mm256_cmpeq_epi8(vblock_B, _mm256_set1_epi8(static_cast<unsigned char>(0x80)));
+ __m256i vbyte_repeat_pattern =
+ _mm256_setr_epi64x(0ULL, kEachByteIs8, 0ULL, kEachByteIs8);
+ vstamp_A = _mm256_shuffle_epi8(
+ vstamp_A, _mm256_or_si256(vbyte_repeat_pattern, vblock_highbits_A));
+ vstamp_B = _mm256_shuffle_epi8(
+ vstamp_B, _mm256_or_si256(vbyte_repeat_pattern, vblock_highbits_B));
+ __m256i vmatches_A = _mm256_cmpeq_epi8(vblock_A, vstamp_A);
+ __m256i vmatches_B = _mm256_cmpeq_epi8(vblock_B, vstamp_B);
+
+ // In case when there are no matches in slots and the block is full (no empty slots),
+ // pretend that there is a match in the last slot.
+ //
+ vmatches_A = _mm256_or_si256(
+ vmatches_A, _mm256_andnot_si256(vblock_highbits_A, _mm256_set1_epi64x(0xff)));
+ vmatches_B = _mm256_or_si256(
+ vmatches_B, _mm256_andnot_si256(vblock_highbits_B, _mm256_set1_epi64x(0xff)));
+
+ __m256i vmatch_found = _mm256_andnot_si256(
+ _mm256_blend_epi32(_mm256_cmpeq_epi64(vmatches_A, _mm256_setzero_si256()),
+ _mm256_cmpeq_epi64(vmatches_B, _mm256_setzero_si256()),
+ 0xaa), // 0b10101010
+ _mm256_set1_epi8(static_cast<unsigned char>(0xff)));
+ vmatches_A =
+ _mm256_sad_epu8(_mm256_and_si256(_mm256_or_si256(vmatches_A, vblock_highbits_A),
+ _mm256_set1_epi64x(kByteSequenceOfPowersOf2)),
+ _mm256_setzero_si256());
+ vmatches_B =
+ _mm256_sad_epu8(_mm256_and_si256(_mm256_or_si256(vmatches_B, vblock_highbits_B),
+ _mm256_set1_epi64x(kByteSequenceOfPowersOf2)),
+ _mm256_setzero_si256());
+ __m256i vmatches = _mm256_or_si256(vmatches_A, _mm256_slli_epi64(vmatches_B, 32));
+
+ // We are now back to processing 8 at a time.
+ // Each lane contains 8-bit bit vector marking slots that are matches.
+ // We need to find leading zeroes count for all slots.
+ //
+ // Emulating lzcnt in lowest bytes of 32-bit elements
+ __m256i vgt = _mm256_cmpgt_epi32(_mm256_set1_epi32(16), vmatches);
+ __m256i vlocal_slot =
+ _mm256_blendv_epi8(_mm256_srli_epi32(vmatches, 4),
+ _mm256_and_si256(vmatches, _mm256_set1_epi32(0x0f)), vgt);
+ vlocal_slot = _mm256_shuffle_epi8(
+ _mm256_setr_epi8(4, 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2, 2, 1, 1,
+ 1, 1, 0, 0, 0, 0, 0, 0, 0, 0),
+ vlocal_slot);
+ vlocal_slot = _mm256_add_epi32(_mm256_and_si256(vlocal_slot, _mm256_set1_epi32(0xff)),
+ _mm256_and_si256(vgt, _mm256_set1_epi32(4)));
+
+ // Convert slot id relative to the block to slot id relative to the beginnning of the
+ // table
+ //
+ uint64_t local_slot = _mm256_extract_epi64(
+ _mm256_permutevar8x32_epi32(
+ _mm256_shuffle_epi8(
+ vlocal_slot, _mm256_setr_epi32(0x0c080400, 0, 0, 0, 0x0c080400, 0, 0, 0)),
+ _mm256_setr_epi32(0, 4, 0, 0, 0, 0, 0, 0)),
+ 0);
+ (reinterpret_cast<uint64_t*>(out_local_slots))[i] = local_slot;
+
+ // Convert match found vector from 32-bit elements to bit vector
+ out_match_bitvector[i] = _pext_u32(_mm256_movemask_epi8(vmatch_found),
+ 0x11111111); // 0b00010001 repeated 4x
+ }
+}
+
+// Take a set of 16 64-bit elements,
+// Output one AVX2 register per byte (0 to 7), containing a sequence of 16 bytes,
+// one from each input 64-bit word, all from the same position in 64-bit word.
+// 16 bytes are replicated in lower and upper half of each output register.
+//
+inline void split_bytes_avx2(__m256i word0, __m256i word1, __m256i word2, __m256i word3,
+ __m256i& byte0, __m256i& byte1, __m256i& byte2,
+ __m256i& byte3, __m256i& byte4, __m256i& byte5,
+ __m256i& byte6, __m256i& byte7) {
+ __m256i word01lo = _mm256_unpacklo_epi8(
+ word0, word1); // {a0, e0, a1, e1, ... a7, e7, c0, g0, c1, g1, ... c7, g7}
+ __m256i word23lo = _mm256_unpacklo_epi8(
+ word2, word3); // {i0, m0, i1, m1, ... i7, m7, k0, o0, k1, o1, ... k7, o7}
+ __m256i word01hi = _mm256_unpackhi_epi8(
+ word0, word1); // {b0, f0, b1, f1, ... b7, f1, d0, h0, d1, h1, ... d7, h7}
+ __m256i word23hi = _mm256_unpackhi_epi8(
+ word2, word3); // {j0, n0, j1, n1, ... j7, n7, l0, p0, l1, p1, ... l7, p7}
+
+ __m256i a =
+ _mm256_unpacklo_epi16(word01lo, word01hi); // {a0, e0, b0, f0, ... a3, e3, b3, f3,
+ // c0, g0, d0, h0, ... c3, g3, d3, h3}
+ __m256i b =
+ _mm256_unpacklo_epi16(word23lo, word23hi); // {i0, m0, j0, n0, ... i3, m3, j3, n3,
+ // k0, o0, l0, p0, ... k3, o3, l3, p3}
+ __m256i c =
+ _mm256_unpackhi_epi16(word01lo, word01hi); // {a4, e4, b4, f4, ... a7, e7, b7, f7,
+ // c4, g4, d4, h4, ... c7, g7, d7, h7}
+ __m256i d =
+ _mm256_unpackhi_epi16(word23lo, word23hi); // {i4, m4, j4, n4, ... i7, m7, j7, n7,
+ // k4, o4, l4, p4, ... k7, o7, l7, p7}
+
+ __m256i byte01 = _mm256_unpacklo_epi32(
+ a, b); // {a0, e0, b0, f0, i0, m0, j0, n0, a1, e1, b1, f1, i1, m1, j1, n1, c0, g0,
+ // d0, h0, k0, o0, l0, p0, ...}
+ __m256i shuffle_const =
+ _mm256_setr_epi8(0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15, 0, 2, 8, 10,
+ 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15);
+ byte01 = _mm256_permute4x64_epi64(
+ byte01, 0xd8); // 11011000 b - swapping middle two 64-bit elements
+ byte01 = _mm256_shuffle_epi8(byte01, shuffle_const);
+ __m256i byte23 = _mm256_unpackhi_epi32(a, b);
+ byte23 = _mm256_permute4x64_epi64(byte23, 0xd8);
+ byte23 = _mm256_shuffle_epi8(byte23, shuffle_const);
+ __m256i byte45 = _mm256_unpacklo_epi32(c, d);
+ byte45 = _mm256_permute4x64_epi64(byte45, 0xd8);
+ byte45 = _mm256_shuffle_epi8(byte45, shuffle_const);
+ __m256i byte67 = _mm256_unpackhi_epi32(c, d);
+ byte67 = _mm256_permute4x64_epi64(byte67, 0xd8);
+ byte67 = _mm256_shuffle_epi8(byte67, shuffle_const);
+
+ byte0 = _mm256_permute4x64_epi64(byte01, 0x44); // 01000100 b
+ byte1 = _mm256_permute4x64_epi64(byte01, 0xee); // 11101110 b
+ byte2 = _mm256_permute4x64_epi64(byte23, 0x44); // 01000100 b
+ byte3 = _mm256_permute4x64_epi64(byte23, 0xee); // 11101110 b
+ byte4 = _mm256_permute4x64_epi64(byte45, 0x44); // 01000100 b
+ byte5 = _mm256_permute4x64_epi64(byte45, 0xee); // 11101110 b
+ byte6 = _mm256_permute4x64_epi64(byte67, 0x44); // 01000100 b
+ byte7 = _mm256_permute4x64_epi64(byte67, 0xee); // 11101110 b
+}
+
+// This one can only process a multiple of 32 values.
+// The caller needs to process the remaining tail, if the input is not divisible by 32,
+// using a different method.
+// TODO: Explain the idea behind storing arrays in SIMD registers.
+// Explain why it is faster with SIMD than using memory loads.
+void SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes,
+ uint8_t* out_match_bitvector,
+ uint8_t* out_local_slots) const {
+ constexpr int unroll = 32;
+
+ // There is a limit on the number of input blocks,
+ // because we want to store all their data in a set of AVX2 registers.
+ ARROW_DCHECK(log_blocks_ <= 4);
+
+ // Remember that block bytes and group id bytes are in opposite orders in memory of hash
+ // table. We put them in the same order.
+ __m256i vblock_byte0, vblock_byte1, vblock_byte2, vblock_byte3, vblock_byte4,
+ vblock_byte5, vblock_byte6, vblock_byte7;
+ // What we output if there is no match in the block
+ __m256i vslot_empty_or_end;
+
+ constexpr uint32_t k4ByteSequence_0_4_8_12 = 0x0c080400;
+ constexpr uint32_t k4ByteSequence_1_5_9_13 = 0x0d090501;
+ constexpr uint32_t k4ByteSequence_2_6_10_14 = 0x0e0a0602;
+ constexpr uint32_t k4ByteSequence_3_7_11_15 = 0x0f0b0703;
+ constexpr uint64_t kByteSequence7DownTo0 = 0x0001020304050607ULL;
+ constexpr uint64_t kByteSequence15DownTo8 = 0x08090A0B0C0D0E0FULL;
+
+ // Bit unpack group ids into 1B.
+ // Assemble the sequence of block bytes.
+ uint64_t block_bytes[16];
+ const int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
+ for (int i = 0; i < (1 << log_blocks_); ++i) {
+ uint64_t in_blockbytes =
+ *reinterpret_cast<const uint64_t*>(blocks_ + (8 + num_groupid_bits) * i);
+ block_bytes[i] = in_blockbytes;
+ }
+
+ // Split a sequence of 64-bit words into SIMD vectors holding individual bytes
+ __m256i vblock_words0 =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(block_bytes) + 0);
+ __m256i vblock_words1 =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(block_bytes) + 1);
+ __m256i vblock_words2 =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(block_bytes) + 2);
+ __m256i vblock_words3 =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(block_bytes) + 3);
+ // Reverse the bytes in blocks
+ __m256i vshuffle_const =
+ _mm256_setr_epi64x(kByteSequence7DownTo0, kByteSequence15DownTo8,
+ kByteSequence7DownTo0, kByteSequence15DownTo8);
+ vblock_words0 = _mm256_shuffle_epi8(vblock_words0, vshuffle_const);
+ vblock_words1 = _mm256_shuffle_epi8(vblock_words1, vshuffle_const);
+ vblock_words2 = _mm256_shuffle_epi8(vblock_words2, vshuffle_const);
+ vblock_words3 = _mm256_shuffle_epi8(vblock_words3, vshuffle_const);
+ split_bytes_avx2(vblock_words0, vblock_words1, vblock_words2, vblock_words3,
+ vblock_byte0, vblock_byte1, vblock_byte2, vblock_byte3, vblock_byte4,
+ vblock_byte5, vblock_byte6, vblock_byte7);
+
+ // Calculate the slot to output when there is no match in a block.
+ // It will be the index of the first empty slot or 7 (the number of slots in block)
+ // if there are no empty slots.
+ vslot_empty_or_end = _mm256_set1_epi8(7);
+ {
+ __m256i vis_empty;
+#define CMP(VBLOCKBYTE, BYTENUM) \
+ vis_empty = \
+ _mm256_cmpeq_epi8(VBLOCKBYTE, _mm256_set1_epi8(static_cast<unsigned char>(0x80))); \
+ vslot_empty_or_end = \
+ _mm256_blendv_epi8(vslot_empty_or_end, _mm256_set1_epi8(BYTENUM), vis_empty);
+ CMP(vblock_byte7, 7);
+ CMP(vblock_byte6, 6);
+ CMP(vblock_byte5, 5);
+ CMP(vblock_byte4, 4);
+ CMP(vblock_byte3, 3);
+ CMP(vblock_byte2, 2);
+ CMP(vblock_byte1, 1);
+ CMP(vblock_byte0, 0);
+#undef CMP
+ }
+ __m256i vblock_is_full = _mm256_andnot_si256(
+ _mm256_cmpeq_epi8(vblock_byte7, _mm256_set1_epi8(static_cast<unsigned char>(0x80))),
+ _mm256_set1_epi8(static_cast<unsigned char>(0xff)));
+
+ const int block_id_mask = (1 << log_blocks_) - 1;
+
+ for (int i = 0; i < num_hashes / unroll; ++i) {
+ __m256i vhash0 =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(hashes) + 4 * i + 0);
+ __m256i vhash1 =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(hashes) + 4 * i + 1);
+ __m256i vhash2 =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(hashes) + 4 * i + 2);
+ __m256i vhash3 =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(hashes) + 4 * i + 3);
+
+ // We will get input in byte lanes in the order: [0, 8, 16, 24, 1, 9, 17, 25, 2, 10,
+ // 18, 26, ...]
+ vhash0 = _mm256_or_si256(_mm256_srli_epi32(vhash0, 16),
+ _mm256_and_si256(vhash2, _mm256_set1_epi32(0xffff0000)));
+ vhash1 = _mm256_or_si256(_mm256_srli_epi32(vhash1, 16),
+ _mm256_and_si256(vhash3, _mm256_set1_epi32(0xffff0000)));
+ __m256i vstamp_A = _mm256_and_si256(
+ _mm256_srlv_epi32(vhash0, _mm256_set1_epi32(16 - log_blocks_ - 7)),
+ _mm256_set1_epi16(0x7f));
+ __m256i vstamp_B = _mm256_and_si256(
+ _mm256_srlv_epi32(vhash1, _mm256_set1_epi32(16 - log_blocks_ - 7)),
+ _mm256_set1_epi16(0x7f));
+ __m256i vstamp = _mm256_or_si256(vstamp_A, _mm256_slli_epi16(vstamp_B, 8));
+ __m256i vblock_id_A =
+ _mm256_and_si256(_mm256_srlv_epi32(vhash0, _mm256_set1_epi32(16 - log_blocks_)),
+ _mm256_set1_epi16(block_id_mask));
+ __m256i vblock_id_B =
+ _mm256_and_si256(_mm256_srlv_epi32(vhash1, _mm256_set1_epi32(16 - log_blocks_)),
+ _mm256_set1_epi16(block_id_mask));
+ __m256i vblock_id = _mm256_or_si256(vblock_id_A, _mm256_slli_epi16(vblock_id_B, 8));
+
+ // Visit all block bytes in reverse order (overwriting data on multiple matches)
+ //
+ // Always set match found to true for full blocks.
+ //
+ __m256i vmatch_found = _mm256_shuffle_epi8(vblock_is_full, vblock_id);
+ __m256i vslot_id = _mm256_shuffle_epi8(vslot_empty_or_end, vblock_id);
+#define CMP(VBLOCK_BYTE, BYTENUM) \
+ { \
+ __m256i vcmp = \
+ _mm256_cmpeq_epi8(_mm256_shuffle_epi8(VBLOCK_BYTE, vblock_id), vstamp); \
+ vmatch_found = _mm256_or_si256(vmatch_found, vcmp); \
+ vslot_id = _mm256_blendv_epi8(vslot_id, _mm256_set1_epi8(BYTENUM), vcmp); \
+ }
+ CMP(vblock_byte7, 7);
+ CMP(vblock_byte6, 6);
+ CMP(vblock_byte5, 5);
+ CMP(vblock_byte4, 4);
+ CMP(vblock_byte3, 3);
+ CMP(vblock_byte2, 2);
+ CMP(vblock_byte1, 1);
+ CMP(vblock_byte0, 0);
+#undef CMP
+
+ // So far the output is in the order: [0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, ...]
+ vmatch_found = _mm256_shuffle_epi8(
+ vmatch_found,
+ _mm256_setr_epi32(k4ByteSequence_0_4_8_12, k4ByteSequence_1_5_9_13,
+ k4ByteSequence_2_6_10_14, k4ByteSequence_3_7_11_15,
+ k4ByteSequence_0_4_8_12, k4ByteSequence_1_5_9_13,
+ k4ByteSequence_2_6_10_14, k4ByteSequence_3_7_11_15));
+ // Now it is: [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, | 4, 5, 6, 7,
+ // 12, 13, 14, 15, ...]
+ vmatch_found = _mm256_permutevar8x32_epi32(vmatch_found,
+ _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+
+ // Repeat the same permutation for slot ids
+ vslot_id = _mm256_shuffle_epi8(
+ vslot_id, _mm256_setr_epi32(k4ByteSequence_0_4_8_12, k4ByteSequence_1_5_9_13,
+ k4ByteSequence_2_6_10_14, k4ByteSequence_3_7_11_15,
+ k4ByteSequence_0_4_8_12, k4ByteSequence_1_5_9_13,
+ k4ByteSequence_2_6_10_14, k4ByteSequence_3_7_11_15));
+ vslot_id =
+ _mm256_permutevar8x32_epi32(vslot_id, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_local_slots) + i, vslot_id);
+
+ reinterpret_cast<uint32_t*>(out_match_bitvector)[i] =
+ _mm256_movemask_epi8(vmatch_found);
+ }
+}
+
+void SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashes,
+ const uint8_t* local_slots,
+ uint32_t* out_group_ids, int byte_offset,
+ int byte_multiplier, int byte_size) const {
+ ARROW_DCHECK(byte_size == 1 || byte_size == 2 || byte_size == 4);
+ uint32_t mask = byte_size == 1 ? 0xFF : byte_size == 2 ? 0xFFFF : 0xFFFFFFFF;
+ auto elements = reinterpret_cast<const int*>(blocks_ + byte_offset);
+ constexpr int unroll = 8;
+ if (log_blocks_ == 0) {
+ ARROW_DCHECK(byte_size == 1 && byte_offset == 8 && byte_multiplier == 16);
+ __m256i block_group_ids =
+ _mm256_set1_epi64x(reinterpret_cast<const uint64_t*>(blocks_)[1]);
+ for (int i = 0; i < (num_keys + unroll - 1) / unroll; ++i) {
+ __m256i local_slot =
+ _mm256_set1_epi64x(reinterpret_cast<const uint64_t*>(local_slots)[i]);
+ __m256i group_id = _mm256_shuffle_epi8(block_group_ids, local_slot);
+ group_id = _mm256_shuffle_epi8(
+ group_id, _mm256_setr_epi32(0x80808000, 0x80808001, 0x80808002, 0x80808003,
+ 0x80808004, 0x80808005, 0x80808006, 0x80808007));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id);
+ }
+ } else {
+ for (int i = 0; i < (num_keys + unroll - 1) / unroll; ++i) {
+ __m256i hash = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(hashes) + i);
+ __m256i local_slot =
+ _mm256_set1_epi64x(reinterpret_cast<const uint64_t*>(local_slots)[i]);
+ local_slot = _mm256_shuffle_epi8(
+ local_slot, _mm256_setr_epi32(0x80808000, 0x80808001, 0x80808002, 0x80808003,
+ 0x80808004, 0x80808005, 0x80808006, 0x80808007));
+ local_slot = _mm256_mullo_epi32(local_slot, _mm256_set1_epi32(byte_size));
+ __m256i pos = _mm256_srlv_epi32(hash, _mm256_set1_epi32(bits_hash_ - log_blocks_));
+ pos = _mm256_mullo_epi32(pos, _mm256_set1_epi32(byte_multiplier));
+ pos = _mm256_add_epi32(pos, local_slot);
+ __m256i group_id = _mm256_i32gather_epi32(elements, pos, 1);
+ group_id = _mm256_and_si256(group_id, _mm256_set1_epi32(mask));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id);
+ }
+ }
+}
+
+#endif
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/options.h b/src/arrow/cpp/src/arrow/compute/exec/options.h
new file mode 100644
index 000000000..87349191e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/options.h
@@ -0,0 +1,265 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/expression.h"
+#include "arrow/util/async_util.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace compute {
+
+class ARROW_EXPORT ExecNodeOptions {
+ public:
+ virtual ~ExecNodeOptions() = default;
+};
+
+/// \brief Adapt an AsyncGenerator<ExecBatch> as a source node
+///
+/// plan->exec_context()->executor() will be used to parallelize pushing to
+/// outputs, if provided.
+class ARROW_EXPORT SourceNodeOptions : public ExecNodeOptions {
+ public:
+ SourceNodeOptions(std::shared_ptr<Schema> output_schema,
+ std::function<Future<util::optional<ExecBatch>>()> generator)
+ : output_schema(std::move(output_schema)), generator(std::move(generator)) {}
+
+ std::shared_ptr<Schema> output_schema;
+ std::function<Future<util::optional<ExecBatch>>()> generator;
+};
+
+/// \brief Make a node which excludes some rows from batches passed through it
+///
+/// filter_expression will be evaluated against each batch which is pushed to
+/// this node. Any rows for which filter_expression does not evaluate to `true` will be
+/// excluded in the batch emitted by this node.
+class ARROW_EXPORT FilterNodeOptions : public ExecNodeOptions {
+ public:
+ explicit FilterNodeOptions(Expression filter_expression, bool async_mode = true)
+ : filter_expression(std::move(filter_expression)), async_mode(async_mode) {}
+
+ Expression filter_expression;
+ bool async_mode;
+};
+
+/// \brief Make a node which executes expressions on input batches, producing new batches.
+///
+/// Each expression will be evaluated against each batch which is pushed to
+/// this node to produce a corresponding output column.
+///
+/// If names are not provided, the string representations of exprs will be used.
+class ARROW_EXPORT ProjectNodeOptions : public ExecNodeOptions {
+ public:
+ explicit ProjectNodeOptions(std::vector<Expression> expressions,
+ std::vector<std::string> names = {}, bool async_mode = true)
+ : expressions(std::move(expressions)),
+ names(std::move(names)),
+ async_mode(async_mode) {}
+
+ std::vector<Expression> expressions;
+ std::vector<std::string> names;
+ bool async_mode;
+};
+
+/// \brief Make a node which aggregates input batches, optionally grouped by keys.
+class ARROW_EXPORT AggregateNodeOptions : public ExecNodeOptions {
+ public:
+ AggregateNodeOptions(std::vector<internal::Aggregate> aggregates,
+ std::vector<FieldRef> targets, std::vector<std::string> names,
+ std::vector<FieldRef> keys = {})
+ : aggregates(std::move(aggregates)),
+ targets(std::move(targets)),
+ names(std::move(names)),
+ keys(std::move(keys)) {}
+
+ // aggregations which will be applied to the targetted fields
+ std::vector<internal::Aggregate> aggregates;
+ // fields to which aggregations will be applied
+ std::vector<FieldRef> targets;
+ // output field names for aggregations
+ std::vector<std::string> names;
+ // keys by which aggregations will be grouped
+ std::vector<FieldRef> keys;
+};
+
+/// \brief Add a sink node which forwards to an AsyncGenerator<ExecBatch>
+///
+/// Emitted batches will not be ordered.
+class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions {
+ public:
+ explicit SinkNodeOptions(std::function<Future<util::optional<ExecBatch>>()>* generator,
+ util::BackpressureOptions backpressure = {})
+ : generator(generator), backpressure(std::move(backpressure)) {}
+
+ std::function<Future<util::optional<ExecBatch>>()>* generator;
+ util::BackpressureOptions backpressure;
+};
+
+class ARROW_EXPORT SinkNodeConsumer {
+ public:
+ virtual ~SinkNodeConsumer() = default;
+ /// \brief Consume a batch of data
+ virtual Status Consume(ExecBatch batch) = 0;
+ /// \brief Signal to the consumer that the last batch has been delivered
+ ///
+ /// The returned future should only finish when all outstanding tasks have completed
+ virtual Future<> Finish() = 0;
+};
+
+/// \brief Add a sink node which consumes data within the exec plan run
+class ARROW_EXPORT ConsumingSinkNodeOptions : public ExecNodeOptions {
+ public:
+ explicit ConsumingSinkNodeOptions(std::shared_ptr<SinkNodeConsumer> consumer)
+ : consumer(std::move(consumer)) {}
+
+ std::shared_ptr<SinkNodeConsumer> consumer;
+};
+
+/// \brief Make a node which sorts rows passed through it
+///
+/// All batches pushed to this node will be accumulated, then sorted, by the given
+/// fields. Then sorted batches will be forwarded to the generator in sorted order.
+class ARROW_EXPORT OrderBySinkNodeOptions : public SinkNodeOptions {
+ public:
+ explicit OrderBySinkNodeOptions(
+ SortOptions sort_options,
+ std::function<Future<util::optional<ExecBatch>>()>* generator)
+ : SinkNodeOptions(generator), sort_options(std::move(sort_options)) {}
+
+ SortOptions sort_options;
+};
+
+enum class JoinType {
+ LEFT_SEMI,
+ RIGHT_SEMI,
+ LEFT_ANTI,
+ RIGHT_ANTI,
+ INNER,
+ LEFT_OUTER,
+ RIGHT_OUTER,
+ FULL_OUTER
+};
+
+enum class JoinKeyCmp { EQ, IS };
+
+/// \brief Make a node which implements join operation using hash join strategy.
+class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions {
+ public:
+ static constexpr const char* default_output_prefix_for_left = "";
+ static constexpr const char* default_output_prefix_for_right = "";
+ HashJoinNodeOptions(
+ JoinType in_join_type, std::vector<FieldRef> in_left_keys,
+ std::vector<FieldRef> in_right_keys,
+ std::string output_prefix_for_left = default_output_prefix_for_left,
+ std::string output_prefix_for_right = default_output_prefix_for_right)
+ : join_type(in_join_type),
+ left_keys(std::move(in_left_keys)),
+ right_keys(std::move(in_right_keys)),
+ output_all(true),
+ output_prefix_for_left(std::move(output_prefix_for_left)),
+ output_prefix_for_right(std::move(output_prefix_for_right)) {
+ this->key_cmp.resize(this->left_keys.size());
+ for (size_t i = 0; i < this->left_keys.size(); ++i) {
+ this->key_cmp[i] = JoinKeyCmp::EQ;
+ }
+ }
+ HashJoinNodeOptions(
+ JoinType join_type, std::vector<FieldRef> left_keys,
+ std::vector<FieldRef> right_keys, std::vector<FieldRef> left_output,
+ std::vector<FieldRef> right_output,
+ std::string output_prefix_for_left = default_output_prefix_for_left,
+ std::string output_prefix_for_right = default_output_prefix_for_right)
+ : join_type(join_type),
+ left_keys(std::move(left_keys)),
+ right_keys(std::move(right_keys)),
+ output_all(false),
+ left_output(std::move(left_output)),
+ right_output(std::move(right_output)),
+ output_prefix_for_left(std::move(output_prefix_for_left)),
+ output_prefix_for_right(std::move(output_prefix_for_right)) {
+ this->key_cmp.resize(this->left_keys.size());
+ for (size_t i = 0; i < this->left_keys.size(); ++i) {
+ this->key_cmp[i] = JoinKeyCmp::EQ;
+ }
+ }
+ HashJoinNodeOptions(
+ JoinType join_type, std::vector<FieldRef> left_keys,
+ std::vector<FieldRef> right_keys, std::vector<FieldRef> left_output,
+ std::vector<FieldRef> right_output, std::vector<JoinKeyCmp> key_cmp,
+ std::string output_prefix_for_left = default_output_prefix_for_left,
+ std::string output_prefix_for_right = default_output_prefix_for_right)
+ : join_type(join_type),
+ left_keys(std::move(left_keys)),
+ right_keys(std::move(right_keys)),
+ output_all(false),
+ left_output(std::move(left_output)),
+ right_output(std::move(right_output)),
+ key_cmp(std::move(key_cmp)),
+ output_prefix_for_left(std::move(output_prefix_for_left)),
+ output_prefix_for_right(std::move(output_prefix_for_right)) {}
+
+ // type of join (inner, left, semi...)
+ JoinType join_type;
+ // key fields from left input
+ std::vector<FieldRef> left_keys;
+ // key fields from right input
+ std::vector<FieldRef> right_keys;
+ // if set all valid fields from both left and right input will be output
+ // (and field ref vectors for output fields will be ignored)
+ bool output_all;
+ // output fields passed from left input
+ std::vector<FieldRef> left_output;
+ // output fields passed from right input
+ std::vector<FieldRef> right_output;
+ // key comparison function (determines whether a null key is equal another null key or
+ // not)
+ std::vector<JoinKeyCmp> key_cmp;
+ // prefix added to names of output fields coming from left input (used to distinguish,
+ // if necessary, between fields of the same name in left and right input and can be left
+ // empty if there are no name collisions)
+ std::string output_prefix_for_left;
+ // prefix added to names of output fields coming from right input
+ std::string output_prefix_for_right;
+};
+
+/// \brief Make a node which select top_k/bottom_k rows passed through it
+///
+/// All batches pushed to this node will be accumulated, then selected, by the given
+/// fields. Then sorted batches will be forwarded to the generator in sorted order.
+class ARROW_EXPORT SelectKSinkNodeOptions : public SinkNodeOptions {
+ public:
+ explicit SelectKSinkNodeOptions(
+ SelectKOptions select_k_options,
+ std::function<Future<util::optional<ExecBatch>>()>* generator)
+ : SinkNodeOptions(generator), select_k_options(std::move(select_k_options)) {}
+
+ /// SelectK options
+ SelectKOptions select_k_options;
+};
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/order_by_impl.cc b/src/arrow/cpp/src/arrow/compute/exec/order_by_impl.cc
new file mode 100644
index 000000000..4afcf884f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/order_by_impl.cc
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/order_by_impl.h"
+
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <vector>
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+using internal::checked_cast;
+
+namespace compute {
+
+class SortBasicImpl : public OrderByImpl {
+ public:
+ SortBasicImpl(ExecContext* ctx, const std::shared_ptr<Schema>& output_schema,
+ const SortOptions& options = SortOptions{})
+ : ctx_(ctx), output_schema_(output_schema), options_(options) {}
+
+ void InputReceived(const std::shared_ptr<RecordBatch>& batch) override {
+ std::unique_lock<std::mutex> lock(mutex_);
+ batches_.push_back(batch);
+ }
+
+ Result<Datum> DoFinish() override {
+ std::unique_lock<std::mutex> lock(mutex_);
+ ARROW_ASSIGN_OR_RAISE(auto table,
+ Table::FromRecordBatches(output_schema_, std::move(batches_)));
+ ARROW_ASSIGN_OR_RAISE(auto indices, SortIndices(table, options_, ctx_));
+ return Take(table, indices, TakeOptions::NoBoundsCheck(), ctx_);
+ }
+
+ std::string ToString() const override { return options_.ToString(); }
+
+ protected:
+ ExecContext* ctx_;
+ std::shared_ptr<Schema> output_schema_;
+ std::mutex mutex_;
+ std::vector<std::shared_ptr<RecordBatch>> batches_;
+
+ private:
+ const SortOptions options_;
+}; // namespace compute
+
+class SelectKBasicImpl : public SortBasicImpl {
+ public:
+ SelectKBasicImpl(ExecContext* ctx, const std::shared_ptr<Schema>& output_schema,
+ const SelectKOptions& options)
+ : SortBasicImpl(ctx, output_schema), options_(options) {}
+
+ Result<Datum> DoFinish() override {
+ std::unique_lock<std::mutex> lock(mutex_);
+ ARROW_ASSIGN_OR_RAISE(auto table,
+ Table::FromRecordBatches(output_schema_, std::move(batches_)));
+ ARROW_ASSIGN_OR_RAISE(auto indices, SelectKUnstable(table, options_, ctx_));
+ return Take(table, indices, TakeOptions::NoBoundsCheck(), ctx_);
+ }
+
+ std::string ToString() const override { return options_.ToString(); }
+
+ private:
+ const SelectKOptions options_;
+};
+
+Result<std::unique_ptr<OrderByImpl>> OrderByImpl::MakeSort(
+ ExecContext* ctx, const std::shared_ptr<Schema>& output_schema,
+ const SortOptions& options) {
+ std::unique_ptr<OrderByImpl> impl{new SortBasicImpl(ctx, output_schema, options)};
+ return std::move(impl);
+}
+
+Result<std::unique_ptr<OrderByImpl>> OrderByImpl::MakeSelectK(
+ ExecContext* ctx, const std::shared_ptr<Schema>& output_schema,
+ const SelectKOptions& options) {
+ std::unique_ptr<OrderByImpl> impl{new SelectKBasicImpl(ctx, output_schema, options)};
+ return std::move(impl);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/order_by_impl.h b/src/arrow/cpp/src/arrow/compute/exec/order_by_impl.h
new file mode 100644
index 000000000..afc92aedd
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/order_by_impl.h
@@ -0,0 +1,53 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "arrow/compute/exec/options.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+
+namespace arrow {
+namespace compute {
+
+class OrderByImpl {
+ public:
+ virtual ~OrderByImpl() = default;
+
+ virtual void InputReceived(const std::shared_ptr<RecordBatch>& batch) = 0;
+
+ virtual Result<Datum> DoFinish() = 0;
+
+ virtual std::string ToString() const = 0;
+
+ static Result<std::unique_ptr<OrderByImpl>> MakeSort(
+ ExecContext* ctx, const std::shared_ptr<Schema>& output_schema,
+ const SortOptions& options);
+
+ static Result<std::unique_ptr<OrderByImpl>> MakeSelectK(
+ ExecContext* ctx, const std::shared_ptr<Schema>& output_schema,
+ const SelectKOptions& options);
+};
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/plan_test.cc b/src/arrow/cpp/src/arrow/compute/exec/plan_test.cc
new file mode 100644
index 000000000..54d807ef9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/plan_test.cc
@@ -0,0 +1,1226 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gmock/gmock-matchers.h>
+
+#include <functional>
+#include <memory>
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/expression.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/test_util.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/io/util_internal.h"
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/thread_pool.h"
+#include "arrow/util/vector.h"
+
+using testing::ElementsAre;
+using testing::ElementsAreArray;
+using testing::HasSubstr;
+using testing::Optional;
+using testing::UnorderedElementsAreArray;
+
+namespace arrow {
+
+namespace compute {
+
+TEST(ExecPlanConstruction, Empty) {
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+
+ ASSERT_THAT(plan->Validate(), Raises(StatusCode::Invalid));
+}
+
+TEST(ExecPlanConstruction, SingleNode) {
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ auto node = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}, /*num_outputs=*/0);
+ ASSERT_OK(plan->Validate());
+ ASSERT_THAT(plan->sources(), ElementsAre(node));
+ ASSERT_THAT(plan->sinks(), ElementsAre(node));
+
+ ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make());
+ node = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}, /*num_outputs=*/1);
+ // Output not bound
+ ASSERT_THAT(plan->Validate(), Raises(StatusCode::Invalid));
+}
+
+TEST(ExecPlanConstruction, SourceSink) {
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ auto source = MakeDummyNode(plan.get(), "source", /*inputs=*/{}, /*num_outputs=*/1);
+ auto sink = MakeDummyNode(plan.get(), "sink", /*inputs=*/{source}, /*num_outputs=*/0);
+
+ ASSERT_OK(plan->Validate());
+ EXPECT_THAT(plan->sources(), ElementsAre(source));
+ EXPECT_THAT(plan->sinks(), ElementsAre(sink));
+}
+
+TEST(ExecPlanConstruction, MultipleNode) {
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+
+ auto source1 = MakeDummyNode(plan.get(), "source1", /*inputs=*/{}, /*num_outputs=*/2);
+
+ auto source2 = MakeDummyNode(plan.get(), "source2", /*inputs=*/{}, /*num_outputs=*/1);
+
+ auto process1 =
+ MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1}, /*num_outputs=*/2);
+
+ auto process2 = MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1, source2},
+ /*num_outputs=*/1);
+
+ auto process3 =
+ MakeDummyNode(plan.get(), "process3", /*inputs=*/{process1, process2, process1},
+ /*num_outputs=*/1);
+
+ auto sink = MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}, /*num_outputs=*/0);
+
+ ASSERT_OK(plan->Validate());
+ ASSERT_THAT(plan->sources(), ElementsAre(source1, source2));
+ ASSERT_THAT(plan->sinks(), ElementsAre(sink));
+}
+
+TEST(ExecPlanConstruction, AutoLabel) {
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ auto source1 = MakeDummyNode(plan.get(), "", /*inputs=*/{}, /*num_outputs=*/2);
+ auto source2 =
+ MakeDummyNode(plan.get(), "some_label", /*inputs=*/{}, /*num_outputs=*/1);
+ auto source3 = MakeDummyNode(plan.get(), "", /*inputs=*/{}, /*num_outputs=*/2);
+
+ ASSERT_EQ("0", source1->label());
+ ASSERT_EQ("some_label", source2->label());
+ ASSERT_EQ("2", source3->label());
+}
+
+struct StartStopTracker {
+ std::vector<std::string> started, stopped;
+
+ StartProducingFunc start_producing_func(Status st = Status::OK()) {
+ return [this, st](ExecNode* node) {
+ started.push_back(node->label());
+ return st;
+ };
+ }
+
+ StopProducingFunc stop_producing_func() {
+ return [this](ExecNode* node) { stopped.push_back(node->label()); };
+ }
+};
+
+TEST(ExecPlan, DummyStartProducing) {
+ StartStopTracker t;
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+
+ auto source1 = MakeDummyNode(plan.get(), "source1", /*inputs=*/{}, /*num_outputs=*/2,
+ t.start_producing_func(), t.stop_producing_func());
+
+ auto source2 = MakeDummyNode(plan.get(), "source2", /*inputs=*/{}, /*num_outputs=*/1,
+ t.start_producing_func(), t.stop_producing_func());
+
+ auto process1 =
+ MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1}, /*num_outputs=*/2,
+ t.start_producing_func(), t.stop_producing_func());
+
+ auto process2 =
+ MakeDummyNode(plan.get(), "process2", /*inputs=*/{process1, source2},
+ /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func());
+
+ auto process3 =
+ MakeDummyNode(plan.get(), "process3", /*inputs=*/{process1, source1, process2},
+ /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func());
+
+ MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}, /*num_outputs=*/0,
+ t.start_producing_func(), t.stop_producing_func());
+
+ ASSERT_OK(plan->Validate());
+ ASSERT_EQ(t.started.size(), 0);
+ ASSERT_EQ(t.stopped.size(), 0);
+
+ ASSERT_OK(plan->StartProducing());
+ // Note that any correct reverse topological order may do
+ ASSERT_THAT(t.started, ElementsAre("sink", "process3", "process2", "process1",
+ "source2", "source1"));
+
+ plan->StopProducing();
+ ASSERT_THAT(plan->finished(), Finishes(Ok()));
+ // Note that any correct topological order may do
+ ASSERT_THAT(t.stopped, ElementsAre("source1", "source2", "process1", "process2",
+ "process3", "sink"));
+
+ ASSERT_THAT(plan->StartProducing(),
+ Raises(StatusCode::Invalid, HasSubstr("restarted")));
+}
+
+TEST(ExecPlan, DummyStartProducingError) {
+ StartStopTracker t;
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ auto source1 = MakeDummyNode(
+ plan.get(), "source1", /*num_inputs=*/{}, /*num_outputs=*/2,
+ t.start_producing_func(Status::NotImplemented("zzz")), t.stop_producing_func());
+
+ auto source2 =
+ MakeDummyNode(plan.get(), "source2", /*num_inputs=*/{}, /*num_outputs=*/1,
+ t.start_producing_func(), t.stop_producing_func());
+
+ auto process1 = MakeDummyNode(
+ plan.get(), "process1", /*num_inputs=*/{source1}, /*num_outputs=*/2,
+ t.start_producing_func(Status::IOError("xxx")), t.stop_producing_func());
+
+ auto process2 =
+ MakeDummyNode(plan.get(), "process2", /*num_inputs=*/{process1, source2},
+ /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func());
+
+ auto process3 =
+ MakeDummyNode(plan.get(), "process3", /*num_inputs=*/{process1, source1, process2},
+ /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func());
+
+ MakeDummyNode(plan.get(), "sink", /*num_inputs=*/{process3}, /*num_outputs=*/0,
+ t.start_producing_func(), t.stop_producing_func());
+
+ ASSERT_OK(plan->Validate());
+ ASSERT_EQ(t.started.size(), 0);
+ ASSERT_EQ(t.stopped.size(), 0);
+
+ // `process1` raises IOError
+ ASSERT_THAT(plan->StartProducing(), Raises(StatusCode::IOError));
+ ASSERT_THAT(t.started, ElementsAre("sink", "process3", "process2", "process1"));
+ // Nodes that started successfully were stopped in reverse order
+ ASSERT_THAT(t.stopped, ElementsAre("process2", "process3", "sink"));
+}
+
+TEST(ExecPlanExecution, SourceSink) {
+ for (bool slow : {false, true}) {
+ SCOPED_TRACE(slow ? "slowed" : "unslowed");
+
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel" : "single threaded");
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ auto basic_data = MakeBasicBatches();
+
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{basic_data.schema,
+ basic_data.gen(parallel, slow)}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(UnorderedElementsAreArray(basic_data.batches))));
+ }
+ }
+}
+
+TEST(ExecPlanExecution, SinkNodeBackpressure) {
+ constexpr uint32_t kPauseIfAbove = 4;
+ constexpr uint32_t kResumeIfBelow = 2;
+ EXPECT_OK_AND_ASSIGN(std::shared_ptr<ExecPlan> plan, ExecPlan::Make());
+ PushGenerator<util::optional<ExecBatch>> batch_producer;
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+ util::BackpressureOptions backpressure_options =
+ util::BackpressureOptions::Make(kResumeIfBelow, kPauseIfAbove);
+ std::shared_ptr<Schema> schema_ = schema({field("data", uint32())});
+ ARROW_EXPECT_OK(compute::Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions(schema_, batch_producer)},
+ {"sink", SinkNodeOptions{&sink_gen, backpressure_options}},
+ })
+ .AddToPlan(plan.get()));
+ ARROW_EXPECT_OK(plan->StartProducing());
+
+ EXPECT_OK_AND_ASSIGN(util::optional<ExecBatch> batch, ExecBatch::Make({MakeScalar(0)}));
+ ASSERT_TRUE(backpressure_options.toggle->IsOpen());
+
+ // Should be able to push kPauseIfAbove batches without triggering back pressure
+ for (uint32_t i = 0; i < kPauseIfAbove; i++) {
+ batch_producer.producer().Push(batch);
+ }
+ SleepABit();
+ ASSERT_TRUE(backpressure_options.toggle->IsOpen());
+
+ // One more batch should trigger back pressure
+ batch_producer.producer().Push(batch);
+ BusyWait(10, [&] { return !backpressure_options.toggle->IsOpen(); });
+ ASSERT_FALSE(backpressure_options.toggle->IsOpen());
+
+ // Reading as much as we can while keeping it paused
+ for (uint32_t i = kPauseIfAbove; i >= kResumeIfBelow; i--) {
+ ASSERT_FINISHES_OK(sink_gen());
+ }
+ SleepABit();
+ ASSERT_FALSE(backpressure_options.toggle->IsOpen());
+
+ // Reading one more item should open up backpressure
+ ASSERT_FINISHES_OK(sink_gen());
+ BusyWait(10, [&] { return backpressure_options.toggle->IsOpen(); });
+ ASSERT_TRUE(backpressure_options.toggle->IsOpen());
+
+ // Cleanup
+ batch_producer.producer().Push(IterationEnd<util::optional<ExecBatch>>());
+ plan->StopProducing();
+ ASSERT_FINISHES_OK(plan->finished());
+}
+
+TEST(ExecPlan, ToString) {
+ auto basic_data = MakeBasicBatches();
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{basic_data.schema,
+ basic_data.gen(/*parallel=*/false,
+ /*slow=*/false)}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+ EXPECT_EQ(plan->sources()[0]->ToString(), R"(SourceNode{"source", outputs=["sink"]})");
+ EXPECT_EQ(plan->sinks()[0]->ToString(),
+ R"(SinkNode{"sink", inputs=[collected: "source"]})");
+ EXPECT_EQ(plan->ToString(), R"(ExecPlan with 2 nodes:
+SourceNode{"source", outputs=["sink"]}
+SinkNode{"sink", inputs=[collected: "source"]}
+)");
+
+ ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make());
+ CountOptions options(CountOptions::ONLY_VALID);
+ ASSERT_OK(
+ Declaration::Sequence(
+ {
+ {"source",
+ SourceNodeOptions{basic_data.schema,
+ basic_data.gen(/*parallel=*/false, /*slow=*/false)}},
+ {"filter", FilterNodeOptions{greater_equal(field_ref("i32"), literal(0))}},
+ {"project", ProjectNodeOptions{{
+ field_ref("bool"),
+ call("multiply", {field_ref("i32"), literal(2)}),
+ }}},
+ {"aggregate",
+ AggregateNodeOptions{
+ /*aggregates=*/{{"hash_sum", nullptr}, {"hash_count", &options}},
+ /*targets=*/{"multiply(i32, 2)", "multiply(i32, 2)"},
+ /*names=*/{"sum(multiply(i32, 2))", "count(multiply(i32, 2))"},
+ /*keys=*/{"bool"}}},
+ {"filter", FilterNodeOptions{greater(field_ref("sum(multiply(i32, 2))"),
+ literal(10))}},
+ {"order_by_sink",
+ OrderBySinkNodeOptions{
+ SortOptions({SortKey{"sum(multiply(i32, 2))", SortOrder::Ascending}}),
+ &sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+ EXPECT_EQ(plan->ToString(), R"a(ExecPlan with 6 nodes:
+SourceNode{"source", outputs=["filter"]}
+FilterNode{"filter", inputs=[target: "source"], outputs=["project"], filter=(i32 >= 0)}
+ProjectNode{"project", inputs=[target: "filter"], outputs=["aggregate"], projection=[bool, multiply(i32, 2)]}
+GroupByNode{"aggregate", inputs=[groupby: "project"], outputs=["filter"], keys=["bool"], aggregates=[
+ hash_sum(multiply(i32, 2)),
+ hash_count(multiply(i32, 2), {mode=NON_NULL}),
+]}
+FilterNode{"filter", inputs=[target: "aggregate"], outputs=["order_by_sink"], filter=(sum(multiply(i32, 2)) > 10)}
+OrderBySinkNode{"order_by_sink", inputs=[collected: "filter"], by={sort_keys=[sum(multiply(i32, 2)) ASC], null_placement=AtEnd}}
+)a");
+
+ ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make());
+ Declaration union_node{"union", ExecNodeOptions{}};
+ Declaration lhs{"source",
+ SourceNodeOptions{basic_data.schema,
+ basic_data.gen(/*parallel=*/false, /*slow=*/false)}};
+ lhs.label = "lhs";
+ Declaration rhs{"source",
+ SourceNodeOptions{basic_data.schema,
+ basic_data.gen(/*parallel=*/false, /*slow=*/false)}};
+ rhs.label = "rhs";
+ union_node.inputs.emplace_back(lhs);
+ union_node.inputs.emplace_back(rhs);
+ ASSERT_OK(
+ Declaration::Sequence(
+ {
+ union_node,
+ {"aggregate", AggregateNodeOptions{/*aggregates=*/{{"count", &options}},
+ /*targets=*/{"i32"},
+ /*names=*/{"count(i32)"},
+ /*keys=*/{}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+ EXPECT_EQ(plan->ToString(), R"a(ExecPlan with 5 nodes:
+SourceNode{"lhs", outputs=["union"]}
+SourceNode{"rhs", outputs=["union"]}
+UnionNode{"union", inputs=[input_0_label: "lhs", input_1_label: "rhs"], outputs=["aggregate"]}
+ScalarAggregateNode{"aggregate", inputs=[target: "union"], outputs=["sink"], aggregates=[
+ count(i32, {mode=NON_NULL}),
+]}
+SinkNode{"sink", inputs=[collected: "aggregate"]}
+)a");
+}
+
+TEST(ExecPlanExecution, SourceOrderBy) {
+ std::vector<ExecBatch> expected = {
+ ExecBatchFromJSON({int32(), boolean()},
+ "[[4, false], [5, null], [6, false], [7, false], [null, true]]")};
+ for (bool slow : {false, true}) {
+ SCOPED_TRACE(slow ? "slowed" : "unslowed");
+
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel" : "single threaded");
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ auto basic_data = MakeBasicBatches();
+
+ SortOptions options({SortKey("i32", SortOrder::Ascending)});
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{basic_data.schema,
+ basic_data.gen(parallel, slow)}},
+ {"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(ElementsAreArray(expected))));
+ }
+ }
+}
+
+TEST(ExecPlanExecution, SourceSinkError) {
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ auto basic_data = MakeBasicBatches();
+ auto it = basic_data.batches.begin();
+ AsyncGenerator<util::optional<ExecBatch>> error_source_gen =
+ [&]() -> Result<util::optional<ExecBatch>> {
+ if (it == basic_data.batches.end()) {
+ return Status::Invalid("Artificial error");
+ }
+ return util::make_optional(*it++);
+ };
+
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{basic_data.schema, error_source_gen}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(Raises(StatusCode::Invalid, HasSubstr("Artificial"))));
+}
+
+TEST(ExecPlanExecution, SourceConsumingSink) {
+ for (bool slow : {false, true}) {
+ SCOPED_TRACE(slow ? "slowed" : "unslowed");
+
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel" : "single threaded");
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ std::atomic<uint32_t> batches_seen{0};
+ Future<> finish = Future<>::Make();
+ struct TestConsumer : public SinkNodeConsumer {
+ TestConsumer(std::atomic<uint32_t>* batches_seen, Future<> finish)
+ : batches_seen(batches_seen), finish(std::move(finish)) {}
+
+ Status Consume(ExecBatch batch) override {
+ (*batches_seen)++;
+ return Status::OK();
+ }
+
+ Future<> Finish() override { return finish; }
+
+ std::atomic<uint32_t>* batches_seen;
+ Future<> finish;
+ };
+ std::shared_ptr<TestConsumer> consumer =
+ std::make_shared<TestConsumer>(&batches_seen, finish);
+
+ auto basic_data = MakeBasicBatches();
+ ASSERT_OK_AND_ASSIGN(
+ auto source, MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions(basic_data.schema,
+ basic_data.gen(parallel, slow))));
+ ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source},
+ ConsumingSinkNodeOptions(consumer)));
+ ASSERT_OK(plan->StartProducing());
+ // Source should finish fairly quickly
+ ASSERT_FINISHES_OK(source->finished());
+ SleepABit();
+ ASSERT_EQ(2, batches_seen);
+ // Consumer isn't finished and so plan shouldn't have finished
+ AssertNotFinished(plan->finished());
+ // Mark consumption complete, plan should finish
+ finish.MarkFinished();
+ ASSERT_FINISHES_OK(plan->finished());
+ }
+ }
+}
+
+TEST(ExecPlanExecution, ConsumingSinkError) {
+ struct ConsumeErrorConsumer : public SinkNodeConsumer {
+ Status Consume(ExecBatch batch) override { return Status::Invalid("XYZ"); }
+ Future<> Finish() override { return Future<>::MakeFinished(); }
+ };
+ struct FinishErrorConsumer : public SinkNodeConsumer {
+ Status Consume(ExecBatch batch) override { return Status::OK(); }
+ Future<> Finish() override { return Future<>::MakeFinished(Status::Invalid("XYZ")); }
+ };
+ std::vector<std::shared_ptr<SinkNodeConsumer>> consumers{
+ std::make_shared<ConsumeErrorConsumer>(), std::make_shared<FinishErrorConsumer>()};
+
+ for (auto& consumer : consumers) {
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ auto basic_data = MakeBasicBatches();
+ ASSERT_OK(Declaration::Sequence(
+ {{"source",
+ SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))},
+ {"consuming_sink", ConsumingSinkNodeOptions(consumer)}})
+ .AddToPlan(plan.get()));
+ ASSERT_OK_AND_ASSIGN(
+ auto source,
+ MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))));
+ ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source},
+ ConsumingSinkNodeOptions(consumer)));
+ ASSERT_OK(plan->StartProducing());
+ ASSERT_FINISHES_AND_RAISES(Invalid, plan->finished());
+ }
+}
+
+TEST(ExecPlanExecution, ConsumingSinkErrorFinish) {
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ struct FinishErrorConsumer : public SinkNodeConsumer {
+ Status Consume(ExecBatch batch) override { return Status::OK(); }
+ Future<> Finish() override { return Future<>::MakeFinished(Status::Invalid("XYZ")); }
+ };
+ std::shared_ptr<FinishErrorConsumer> consumer = std::make_shared<FinishErrorConsumer>();
+
+ auto basic_data = MakeBasicBatches();
+ ASSERT_OK(
+ Declaration::Sequence(
+ {{"source", SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))},
+ {"consuming_sink", ConsumingSinkNodeOptions(consumer)}})
+ .AddToPlan(plan.get()));
+ ASSERT_OK_AND_ASSIGN(
+ auto source,
+ MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))));
+ ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source},
+ ConsumingSinkNodeOptions(consumer)));
+ ASSERT_OK(plan->StartProducing());
+ ASSERT_FINISHES_AND_RAISES(Invalid, plan->finished());
+}
+
+TEST(ExecPlanExecution, StressSourceSink) {
+ for (bool slow : {false, true}) {
+ SCOPED_TRACE(slow ? "slowed" : "unslowed");
+
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel" : "single threaded");
+
+ int num_batches = (slow && !parallel) ? 30 : 300;
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ auto random_data = MakeRandomBatches(
+ schema({field("a", int32()), field("b", boolean())}), num_batches);
+
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{random_data.schema,
+ random_data.gen(parallel, slow)}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(UnorderedElementsAreArray(random_data.batches))));
+ }
+ }
+}
+
+TEST(ExecPlanExecution, StressSourceOrderBy) {
+ auto input_schema = schema({field("a", int32()), field("b", boolean())});
+ for (bool slow : {false, true}) {
+ SCOPED_TRACE(slow ? "slowed" : "unslowed");
+
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel" : "single threaded");
+
+ int num_batches = (slow && !parallel) ? 30 : 300;
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ auto random_data = MakeRandomBatches(input_schema, num_batches);
+
+ SortOptions options({SortKey("a", SortOrder::Ascending)});
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{random_data.schema,
+ random_data.gen(parallel, slow)}},
+ {"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ // Check that data is sorted appropriately
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto exec_batches,
+ StartAndCollect(plan.get(), sink_gen));
+ ASSERT_OK_AND_ASSIGN(auto actual, TableFromExecBatches(input_schema, exec_batches));
+ ASSERT_OK_AND_ASSIGN(auto original,
+ TableFromExecBatches(input_schema, random_data.batches));
+ ASSERT_OK_AND_ASSIGN(auto sort_indices, SortIndices(original, options));
+ ASSERT_OK_AND_ASSIGN(auto expected, Take(original, sort_indices));
+ AssertTablesEqual(*actual, *expected.table());
+ }
+ }
+}
+
+TEST(ExecPlanExecution, StressSourceGroupedSumStop) {
+ auto input_schema = schema({field("a", int32()), field("b", boolean())});
+ for (bool slow : {false, true}) {
+ SCOPED_TRACE(slow ? "slowed" : "unslowed");
+
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel" : "single threaded");
+
+ int num_batches = (slow && !parallel) ? 30 : 300;
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ auto random_data = MakeRandomBatches(input_schema, num_batches);
+
+ SortOptions options({SortKey("a", SortOrder::Ascending)});
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{random_data.schema,
+ random_data.gen(parallel, slow)}},
+ {"aggregate",
+ AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
+ /*targets=*/{"a"}, /*names=*/{"sum(a)"},
+ /*keys=*/{"b"}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_OK(plan->Validate());
+ ASSERT_OK(plan->StartProducing());
+ plan->StopProducing();
+ ASSERT_FINISHES_OK(plan->finished());
+ }
+ }
+}
+
+TEST(ExecPlanExecution, StressSourceSinkStopped) {
+ for (bool slow : {false, true}) {
+ SCOPED_TRACE(slow ? "slowed" : "unslowed");
+
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel" : "single threaded");
+
+ int num_batches = (slow && !parallel) ? 30 : 300;
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ auto random_data = MakeRandomBatches(
+ schema({field("a", int32()), field("b", boolean())}), num_batches);
+
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{random_data.schema,
+ random_data.gen(parallel, slow)}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_OK(plan->Validate());
+ ASSERT_OK(plan->StartProducing());
+
+ EXPECT_THAT(sink_gen(), Finishes(ResultWith(Optional(random_data.batches[0]))));
+
+ plan->StopProducing();
+ ASSERT_THAT(plan->finished(), Finishes(Ok()));
+ }
+ }
+}
+
+TEST(ExecPlanExecution, SourceFilterSink) {
+ auto basic_data = MakeBasicBatches();
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{basic_data.schema,
+ basic_data.gen(/*parallel=*/false,
+ /*slow=*/false)}},
+ {"filter", FilterNodeOptions{equal(field_ref("i32"), literal(6))}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(UnorderedElementsAreArray(
+ {ExecBatchFromJSON({int32(), boolean()}, "[]"),
+ ExecBatchFromJSON({int32(), boolean()}, "[[6, false]]")}))));
+}
+
+TEST(ExecPlanExecution, SourceProjectSink) {
+ auto basic_data = MakeBasicBatches();
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{basic_data.schema,
+ basic_data.gen(/*parallel=*/false,
+ /*slow=*/false)}},
+ {"project",
+ ProjectNodeOptions{{
+ not_(field_ref("bool")),
+ call("add", {field_ref("i32"), literal(1)}),
+ },
+ {"!bool", "i32 + 1"}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(UnorderedElementsAreArray(
+ {ExecBatchFromJSON({boolean(), int32()}, "[[false, null], [true, 5]]"),
+ ExecBatchFromJSON({boolean(), int32()},
+ "[[null, 6], [true, 7], [true, 8]]")}))));
+}
+
+namespace {
+
+BatchesWithSchema MakeGroupableBatches(int multiplicity = 1) {
+ BatchesWithSchema out;
+
+ out.batches = {ExecBatchFromJSON({int32(), utf8()}, R"([
+ [12, "alfa"],
+ [7, "beta"],
+ [3, "alfa"]
+ ])"),
+ ExecBatchFromJSON({int32(), utf8()}, R"([
+ [-2, "alfa"],
+ [-1, "gama"],
+ [3, "alfa"]
+ ])"),
+ ExecBatchFromJSON({int32(), utf8()}, R"([
+ [5, "gama"],
+ [3, "beta"],
+ [-8, "alfa"]
+ ])")};
+
+ size_t batch_count = out.batches.size();
+ for (int repeat = 1; repeat < multiplicity; ++repeat) {
+ for (size_t i = 0; i < batch_count; ++i) {
+ out.batches.push_back(out.batches[i]);
+ }
+ }
+
+ out.schema = schema({field("i32", int32()), field("str", utf8())});
+
+ return out;
+}
+
+} // namespace
+
+TEST(ExecPlanExecution, SourceGroupedSum) {
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel/merged" : "serial");
+
+ auto input = MakeGroupableBatches(/*multiplicity=*/parallel ? 100 : 1);
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{input.schema,
+ input.gen(parallel, /*slow=*/false)}},
+ {"aggregate",
+ AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
+ /*targets=*/{"i32"}, /*names=*/{"sum(i32)"},
+ /*keys=*/{"str"}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(UnorderedElementsAreArray({ExecBatchFromJSON(
+ {int64(), utf8()},
+ parallel ? R"([[800, "alfa"], [1000, "beta"], [400, "gama"]])"
+ : R"([[8, "alfa"], [10, "beta"], [4, "gama"]])")}))));
+ }
+}
+
+TEST(ExecPlanExecution, SourceFilterProjectGroupedSumFilter) {
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel/merged" : "serial");
+
+ int batch_multiplicity = parallel ? 100 : 1;
+ auto input = MakeGroupableBatches(/*multiplicity=*/batch_multiplicity);
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ ASSERT_OK(
+ Declaration::Sequence(
+ {
+ {"source",
+ SourceNodeOptions{input.schema, input.gen(parallel, /*slow=*/false)}},
+ {"filter",
+ FilterNodeOptions{greater_equal(field_ref("i32"), literal(0))}},
+ {"project", ProjectNodeOptions{{
+ field_ref("str"),
+ call("multiply", {field_ref("i32"), literal(2)}),
+ }}},
+ {"aggregate", AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
+ /*targets=*/{"multiply(i32, 2)"},
+ /*names=*/{"sum(multiply(i32, 2))"},
+ /*keys=*/{"str"}}},
+ {"filter", FilterNodeOptions{greater(field_ref("sum(multiply(i32, 2))"),
+ literal(10 * batch_multiplicity))}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(UnorderedElementsAreArray({ExecBatchFromJSON(
+ {int64(), utf8()}, parallel ? R"([[3600, "alfa"], [2000, "beta"]])"
+ : R"([[36, "alfa"], [20, "beta"]])")}))));
+ }
+}
+
+TEST(ExecPlanExecution, SourceFilterProjectGroupedSumOrderBy) {
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel/merged" : "serial");
+
+ int batch_multiplicity = parallel ? 100 : 1;
+ auto input = MakeGroupableBatches(/*multiplicity=*/batch_multiplicity);
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ SortOptions options({SortKey("str", SortOrder::Descending)});
+ ASSERT_OK(
+ Declaration::Sequence(
+ {
+ {"source",
+ SourceNodeOptions{input.schema, input.gen(parallel, /*slow=*/false)}},
+ {"filter",
+ FilterNodeOptions{greater_equal(field_ref("i32"), literal(0))}},
+ {"project", ProjectNodeOptions{{
+ field_ref("str"),
+ call("multiply", {field_ref("i32"), literal(2)}),
+ }}},
+ {"aggregate", AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
+ /*targets=*/{"multiply(i32, 2)"},
+ /*names=*/{"sum(multiply(i32, 2))"},
+ /*keys=*/{"str"}}},
+ {"filter", FilterNodeOptions{greater(field_ref("sum(multiply(i32, 2))"),
+ literal(10 * batch_multiplicity))}},
+ {"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(ElementsAreArray({ExecBatchFromJSON(
+ {int64(), utf8()}, parallel ? R"([[2000, "beta"], [3600, "alfa"]])"
+ : R"([[20, "beta"], [36, "alfa"]])")}))));
+ }
+}
+
+TEST(ExecPlanExecution, SourceFilterProjectGroupedSumTopK) {
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel/merged" : "serial");
+
+ int batch_multiplicity = parallel ? 100 : 1;
+ auto input = MakeGroupableBatches(/*multiplicity=*/batch_multiplicity);
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ SelectKOptions options = SelectKOptions::TopKDefault(/*k=*/1, {"str"});
+ ASSERT_OK(
+ Declaration::Sequence(
+ {
+ {"source",
+ SourceNodeOptions{input.schema, input.gen(parallel, /*slow=*/false)}},
+ {"project", ProjectNodeOptions{{
+ field_ref("str"),
+ call("multiply", {field_ref("i32"), literal(2)}),
+ }}},
+ {"aggregate", AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
+ /*targets=*/{"multiply(i32, 2)"},
+ /*names=*/{"sum(multiply(i32, 2))"},
+ /*keys=*/{"str"}}},
+ {"select_k_sink", SelectKSinkNodeOptions{options, &sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(
+ StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(ElementsAreArray({ExecBatchFromJSON(
+ {int64(), utf8()}, parallel ? R"([[800, "gama"]])" : R"([[8, "gama"]])")}))));
+ }
+}
+
+TEST(ExecPlanExecution, SourceScalarAggSink) {
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ auto basic_data = MakeBasicBatches();
+
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{basic_data.schema,
+ basic_data.gen(/*parallel=*/false,
+ /*slow=*/false)}},
+ {"aggregate", AggregateNodeOptions{
+ /*aggregates=*/{{"sum", nullptr}, {"any", nullptr}},
+ /*targets=*/{"i32", "bool"},
+ /*names=*/{"sum(i32)", "any(bool)"}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(
+ StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(UnorderedElementsAreArray({
+ ExecBatchFromJSON({ValueDescr::Scalar(int64()), ValueDescr::Scalar(boolean())},
+ "[[22, true]]"),
+ }))));
+}
+
+TEST(ExecPlanExecution, AggregationPreservesOptions) {
+ // ARROW-13638: aggregation nodes initialize per-thread kernel state lazily
+ // and need to keep a copy/strong reference to function options
+ {
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ auto basic_data = MakeBasicBatches();
+
+ {
+ auto options = std::make_shared<TDigestOptions>(TDigestOptions::Defaults());
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{basic_data.schema,
+ basic_data.gen(/*parallel=*/false,
+ /*slow=*/false)}},
+ {"aggregate",
+ AggregateNodeOptions{/*aggregates=*/{{"tdigest", options.get()}},
+ /*targets=*/{"i32"},
+ /*names=*/{"tdigest(i32)"}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+ }
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(UnorderedElementsAreArray({
+ ExecBatchFromJSON({ValueDescr::Array(float64())}, "[[5.5]]"),
+ }))));
+ }
+ {
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ auto data = MakeGroupableBatches(/*multiplicity=*/100);
+
+ {
+ auto options = std::make_shared<CountOptions>(CountOptions::Defaults());
+ ASSERT_OK(
+ Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{data.schema, data.gen(/*parallel=*/false,
+ /*slow=*/false)}},
+ {"aggregate",
+ AggregateNodeOptions{/*aggregates=*/{{"hash_count", options.get()}},
+ /*targets=*/{"i32"},
+ /*names=*/{"count(i32)"},
+ /*keys=*/{"str"}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+ }
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(UnorderedElementsAreArray({
+ ExecBatchFromJSON({int64(), utf8()},
+ R"([[500, "alfa"], [200, "beta"], [200, "gama"]])"),
+ }))));
+ }
+}
+
+TEST(ExecPlanExecution, ScalarSourceScalarAggSink) {
+ // ARROW-9056: scalar aggregation can be done over scalars, taking
+ // into account batch.length > 1 (e.g. a partition column)
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ BatchesWithSchema scalar_data;
+ scalar_data.batches = {
+ ExecBatchFromJSON({ValueDescr::Scalar(int32()), ValueDescr::Scalar(boolean())},
+ "[[5, false], [5, false], [5, false]]"),
+ ExecBatchFromJSON({int32(), boolean()}, "[[5, true], [6, false], [7, true]]")};
+ scalar_data.schema = schema({field("a", int32()), field("b", boolean())});
+
+ // index can't be tested as it's order-dependent
+ // mode/quantile can't be tested as they're technically vector kernels
+ ASSERT_OK(
+ Declaration::Sequence(
+ {
+ {"source",
+ SourceNodeOptions{scalar_data.schema, scalar_data.gen(/*parallel=*/false,
+ /*slow=*/false)}},
+ {"aggregate", AggregateNodeOptions{
+ /*aggregates=*/{{"all", nullptr},
+ {"any", nullptr},
+ {"count", nullptr},
+ {"mean", nullptr},
+ {"product", nullptr},
+ {"stddev", nullptr},
+ {"sum", nullptr},
+ {"tdigest", nullptr},
+ {"variance", nullptr}},
+ /*targets=*/{"b", "b", "a", "a", "a", "a", "a", "a", "a"},
+ /*names=*/
+ {"all(b)", "any(b)", "count(a)", "mean(a)", "product(a)",
+ "stddev(a)", "sum(a)", "tdigest(a)", "variance(a)"}}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(
+ StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(UnorderedElementsAreArray({
+ ExecBatchFromJSON(
+ {ValueDescr::Scalar(boolean()), ValueDescr::Scalar(boolean()),
+ ValueDescr::Scalar(int64()), ValueDescr::Scalar(float64()),
+ ValueDescr::Scalar(int64()), ValueDescr::Scalar(float64()),
+ ValueDescr::Scalar(int64()), ValueDescr::Array(float64()),
+ ValueDescr::Scalar(float64())},
+ R"([[false, true, 6, 5.5, 26250, 0.7637626158259734, 33, 5.0, 0.5833333333333334]])"),
+ }))));
+}
+
+TEST(ExecPlanExecution, ScalarSourceGroupedSum) {
+ // ARROW-14630: ensure grouped aggregation with a scalar key/array input doesn't error
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ BatchesWithSchema scalar_data;
+ scalar_data.batches = {
+ ExecBatchFromJSON({int32(), ValueDescr::Scalar(boolean())},
+ "[[5, false], [6, false], [7, false]]"),
+ ExecBatchFromJSON({int32(), ValueDescr::Scalar(boolean())},
+ "[[1, true], [2, true], [3, true]]"),
+ };
+ scalar_data.schema = schema({field("a", int32()), field("b", boolean())});
+
+ SortOptions options({SortKey("b", SortOrder::Descending)});
+ ASSERT_OK(Declaration::Sequence(
+ {
+ {"source", SourceNodeOptions{scalar_data.schema,
+ scalar_data.gen(/*parallel=*/false,
+ /*slow=*/false)}},
+ {"aggregate",
+ AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr}},
+ /*targets=*/{"a"}, /*names=*/{"hash_sum(a)"},
+ /*keys=*/{"b"}}},
+ {"order_by_sink", OrderBySinkNodeOptions{options, &sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(UnorderedElementsAreArray({
+ ExecBatchFromJSON({int64(), boolean()}, R"([[6, true], [18, false]])"),
+ }))));
+}
+
+TEST(ExecPlanExecution, SelfInnerHashJoinSink) {
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel/merged" : "serial");
+
+ auto input = MakeGroupableBatches();
+
+ auto exec_ctx = arrow::internal::make_unique<ExecContext>(
+ default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ ExecNode* left_source;
+ ExecNode* right_source;
+ for (auto source : {&left_source, &right_source}) {
+ ASSERT_OK_AND_ASSIGN(
+ *source, MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions{input.schema,
+ input.gen(parallel, /*slow=*/false)}));
+ }
+ ASSERT_OK_AND_ASSIGN(
+ auto left_filter,
+ MakeExecNode("filter", plan.get(), {left_source},
+ FilterNodeOptions{greater_equal(field_ref("i32"), literal(-1))}));
+ ASSERT_OK_AND_ASSIGN(
+ auto right_filter,
+ MakeExecNode("filter", plan.get(), {right_source},
+ FilterNodeOptions{less_equal(field_ref("i32"), literal(2))}));
+
+ // left side: [3, "alfa"], [3, "alfa"], [12, "alfa"], [3, "beta"], [7, "beta"],
+ // [-1, "gama"], [5, "gama"]
+ // right side: [-2, "alfa"], [-8, "alfa"], [-1, "gama"]
+
+ HashJoinNodeOptions join_opts{JoinType::INNER,
+ /*left_keys=*/{"str"},
+ /*right_keys=*/{"str"}, "l_", "r_"};
+
+ ASSERT_OK_AND_ASSIGN(
+ auto hashjoin,
+ MakeExecNode("hashjoin", plan.get(), {left_filter, right_filter}, join_opts));
+
+ ASSERT_OK_AND_ASSIGN(std::ignore, MakeExecNode("sink", plan.get(), {hashjoin},
+ SinkNodeOptions{&sink_gen}));
+
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto result, StartAndCollect(plan.get(), sink_gen));
+
+ std::vector<ExecBatch> expected = {
+ ExecBatchFromJSON({int32(), utf8(), int32(), utf8()}, R"([
+ [3, "alfa", -2, "alfa"], [3, "alfa", -8, "alfa"],
+ [3, "alfa", -2, "alfa"], [3, "alfa", -8, "alfa"],
+ [12, "alfa", -2, "alfa"], [12, "alfa", -8, "alfa"],
+ [-1, "gama", -1, "gama"], [5, "gama", -1, "gama"]])")};
+
+ AssertExecBatchesEqual(hashjoin->output_schema(), result, expected);
+ }
+}
+
+TEST(ExecPlanExecution, SelfOuterHashJoinSink) {
+ for (bool parallel : {false, true}) {
+ SCOPED_TRACE(parallel ? "parallel/merged" : "serial");
+
+ auto input = MakeGroupableBatches();
+
+ auto exec_ctx = arrow::internal::make_unique<ExecContext>(
+ default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ ExecNode* left_source;
+ ExecNode* right_source;
+ for (auto source : {&left_source, &right_source}) {
+ ASSERT_OK_AND_ASSIGN(
+ *source, MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions{input.schema,
+ input.gen(parallel, /*slow=*/false)}));
+ }
+ ASSERT_OK_AND_ASSIGN(
+ auto left_filter,
+ MakeExecNode("filter", plan.get(), {left_source},
+ FilterNodeOptions{greater_equal(field_ref("i32"), literal(-1))}));
+ ASSERT_OK_AND_ASSIGN(
+ auto right_filter,
+ MakeExecNode("filter", plan.get(), {right_source},
+ FilterNodeOptions{less_equal(field_ref("i32"), literal(2))}));
+
+ // left side: [3, "alfa"], [3, "alfa"], [12, "alfa"], [3, "beta"], [7, "beta"],
+ // [-1, "gama"], [5, "gama"]
+ // right side: [-2, "alfa"], [-8, "alfa"], [-1, "gama"]
+
+ HashJoinNodeOptions join_opts{JoinType::FULL_OUTER,
+ /*left_keys=*/{"str"},
+ /*right_keys=*/{"str"}, "l_", "r_"};
+
+ ASSERT_OK_AND_ASSIGN(
+ auto hashjoin,
+ MakeExecNode("hashjoin", plan.get(), {left_filter, right_filter}, join_opts));
+
+ ASSERT_OK_AND_ASSIGN(std::ignore, MakeExecNode("sink", plan.get(), {hashjoin},
+ SinkNodeOptions{&sink_gen}));
+
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto result, StartAndCollect(plan.get(), sink_gen));
+
+ std::vector<ExecBatch> expected = {
+ ExecBatchFromJSON({int32(), utf8(), int32(), utf8()}, R"([
+ [3, "alfa", -2, "alfa"], [3, "alfa", -8, "alfa"],
+ [3, "alfa", -2, "alfa"], [3, "alfa", -8, "alfa"],
+ [12, "alfa", -2, "alfa"], [12, "alfa", -8, "alfa"],
+ [3, "beta", null, null], [7, "beta", null, null],
+ [-1, "gama", -1, "gama"], [5, "gama", -1, "gama"]])")};
+
+ AssertExecBatchesEqual(hashjoin->output_schema(), result, expected);
+ }
+}
+
+TEST(ExecPlan, RecordBatchReaderSourceSink) {
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ // set up a RecordBatchReader:
+ auto input = MakeBasicBatches();
+
+ RecordBatchVector batches;
+ for (const ExecBatch& exec_batch : input.batches) {
+ ASSERT_OK_AND_ASSIGN(auto batch, exec_batch.ToRecordBatch(input.schema));
+ batches.push_back(batch);
+ }
+
+ ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches(batches));
+ std::shared_ptr<RecordBatchReader> reader = std::make_shared<TableBatchReader>(*table);
+
+ // Map the RecordBatchReader to a SourceNode
+ ASSERT_OK_AND_ASSIGN(
+ auto batch_gen,
+ MakeReaderGenerator(std::move(reader), arrow::io::internal::GetIOThreadPool()));
+
+ ASSERT_OK(
+ Declaration::Sequence({
+ {"source", SourceNodeOptions{table->schema(), batch_gen}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+ Finishes(ResultWith(UnorderedElementsAreArray(input.batches))));
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/project_node.cc b/src/arrow/cpp/src/arrow/compute/exec/project_node.cc
new file mode 100644
index 000000000..c675acb3d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/project_node.cc
@@ -0,0 +1,127 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/exec_plan.h"
+
+#include <sstream>
+
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/expression.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/datum.h"
+#include "arrow/result.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+namespace {
+
+class ProjectNode : public MapNode {
+ public:
+ ProjectNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ std::shared_ptr<Schema> output_schema, std::vector<Expression> exprs,
+ bool async_mode)
+ : MapNode(plan, std::move(inputs), std::move(output_schema), async_mode),
+ exprs_(std::move(exprs)) {}
+
+ static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options) {
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "ProjectNode"));
+
+ const auto& project_options = checked_cast<const ProjectNodeOptions&>(options);
+ auto exprs = project_options.expressions;
+ auto names = project_options.names;
+
+ if (names.size() == 0) {
+ names.resize(exprs.size());
+ for (size_t i = 0; i < exprs.size(); ++i) {
+ names[i] = exprs[i].ToString();
+ }
+ }
+
+ FieldVector fields(exprs.size());
+ int i = 0;
+ for (auto& expr : exprs) {
+ if (!expr.IsBound()) {
+ ARROW_ASSIGN_OR_RAISE(expr, expr.Bind(*inputs[0]->output_schema()));
+ }
+ fields[i] = field(std::move(names[i]), expr.type());
+ ++i;
+ }
+ return plan->EmplaceNode<ProjectNode>(plan, std::move(inputs),
+ schema(std::move(fields)), std::move(exprs),
+ project_options.async_mode);
+ }
+
+ const char* kind_name() const override { return "ProjectNode"; }
+
+ Result<ExecBatch> DoProject(const ExecBatch& target) {
+ std::vector<Datum> values{exprs_.size()};
+ for (size_t i = 0; i < exprs_.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(Expression simplified_expr,
+ SimplifyWithGuarantee(exprs_[i], target.guarantee));
+
+ ARROW_ASSIGN_OR_RAISE(values[i], ExecuteScalarExpression(simplified_expr, target,
+ plan()->exec_context()));
+ }
+ return ExecBatch{std::move(values), target.length};
+ }
+
+ void InputReceived(ExecNode* input, ExecBatch batch) override {
+ DCHECK_EQ(input, inputs_[0]);
+ auto func = [this](ExecBatch batch) { return DoProject(std::move(batch)); };
+ this->SubmitTask(std::move(func), std::move(batch));
+ }
+
+ protected:
+ std::string ToStringExtra() const override {
+ std::stringstream ss;
+ ss << "projection=[";
+ for (int i = 0; static_cast<size_t>(i) < exprs_.size(); i++) {
+ if (i > 0) ss << ", ";
+ auto repr = exprs_[i].ToString();
+ if (repr != output_schema_->field(i)->name()) {
+ ss << '"' << output_schema_->field(i)->name() << "\": ";
+ }
+ ss << repr;
+ }
+ ss << ']';
+ return ss.str();
+ }
+
+ private:
+ std::vector<Expression> exprs_;
+};
+
+} // namespace
+
+namespace internal {
+
+void RegisterProjectNode(ExecFactoryRegistry* registry) {
+ DCHECK_OK(registry->AddFactory("project", ProjectNode::Make));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/schema_util.h b/src/arrow/cpp/src/arrow/compute/exec/schema_util.h
new file mode 100644
index 000000000..279cbb806
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/schema_util.h
@@ -0,0 +1,209 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/compute/exec/key_encode.h" // for KeyColumnMetadata
+#include "arrow/type.h" // for DataType, FieldRef, Field and Schema
+#include "arrow/util/mutex.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+// Identifiers for all different row schemas that are used in a join
+//
+enum class HashJoinProjection : int { INPUT = 0, KEY = 1, PAYLOAD = 2, OUTPUT = 3 };
+
+struct SchemaProjectionMap {
+ static constexpr int kMissingField = -1;
+ int num_cols;
+ const int* source_to_base;
+ const int* base_to_target;
+ inline int get(int i) const {
+ ARROW_DCHECK(i >= 0 && i < num_cols);
+ ARROW_DCHECK(source_to_base[i] != kMissingField);
+ return base_to_target[source_to_base[i]];
+ }
+};
+
+/// Helper class for managing different projections of the same row schema.
+/// Used to efficiently map any field in one projection to a corresponding field in
+/// another projection.
+/// Materialized mappings are generated lazily at the time of the first access.
+/// Thread-safe apart from initialization.
+template <typename ProjectionIdEnum>
+class SchemaProjectionMaps {
+ public:
+ static constexpr int kMissingField = -1;
+
+ Status Init(ProjectionIdEnum full_schema_handle, const Schema& schema,
+ const std::vector<ProjectionIdEnum>& projection_handles,
+ const std::vector<const std::vector<FieldRef>*>& projections) {
+ ARROW_DCHECK(projection_handles.size() == projections.size());
+ ARROW_RETURN_NOT_OK(RegisterSchema(full_schema_handle, schema));
+ for (size_t i = 0; i < projections.size(); ++i) {
+ ARROW_RETURN_NOT_OK(
+ RegisterProjectedSchema(projection_handles[i], *(projections[i]), schema));
+ }
+ RegisterEnd();
+ return Status::OK();
+ }
+
+ int num_cols(ProjectionIdEnum schema_handle) const {
+ int id = schema_id(schema_handle);
+ return static_cast<int>(schemas_[id].second.size());
+ }
+
+ const std::string& field_name(ProjectionIdEnum schema_handle, int field_id) const {
+ return field(schema_handle, field_id).field_name;
+ }
+
+ const std::shared_ptr<DataType>& data_type(ProjectionIdEnum schema_handle,
+ int field_id) const {
+ return field(schema_handle, field_id).data_type;
+ }
+
+ SchemaProjectionMap map(ProjectionIdEnum from, ProjectionIdEnum to) const {
+ int id_from = schema_id(from);
+ int id_to = schema_id(to);
+ SchemaProjectionMap result;
+ result.num_cols = num_cols(from);
+ result.source_to_base = mappings_[id_from].data();
+ result.base_to_target = inverse_mappings_[id_to].data();
+ return result;
+ }
+
+ protected:
+ struct FieldInfo {
+ int field_path;
+ std::string field_name;
+ std::shared_ptr<DataType> data_type;
+ };
+
+ Status RegisterSchema(ProjectionIdEnum handle, const Schema& schema) {
+ std::vector<FieldInfo> out_fields;
+ const FieldVector& in_fields = schema.fields();
+ out_fields.resize(in_fields.size());
+ for (size_t i = 0; i < in_fields.size(); ++i) {
+ const std::string& name = in_fields[i]->name();
+ const std::shared_ptr<DataType>& type = in_fields[i]->type();
+ out_fields[i].field_path = static_cast<int>(i);
+ out_fields[i].field_name = name;
+ out_fields[i].data_type = type;
+ }
+ schemas_.push_back(std::make_pair(handle, out_fields));
+ return Status::OK();
+ }
+
+ Status RegisterProjectedSchema(ProjectionIdEnum handle,
+ const std::vector<FieldRef>& selected_fields,
+ const Schema& full_schema) {
+ std::vector<FieldInfo> out_fields;
+ const FieldVector& in_fields = full_schema.fields();
+ out_fields.resize(selected_fields.size());
+ for (size_t i = 0; i < selected_fields.size(); ++i) {
+ // All fields must be found in schema without ambiguity
+ ARROW_ASSIGN_OR_RAISE(auto match, selected_fields[i].FindOne(full_schema));
+ const std::string& name = in_fields[match[0]]->name();
+ const std::shared_ptr<DataType>& type = in_fields[match[0]]->type();
+ out_fields[i].field_path = match[0];
+ out_fields[i].field_name = name;
+ out_fields[i].data_type = type;
+ }
+ schemas_.push_back(std::make_pair(handle, out_fields));
+ return Status::OK();
+ }
+
+ void RegisterEnd() {
+ size_t size = schemas_.size();
+ mappings_.resize(size);
+ inverse_mappings_.resize(size);
+ int id_base = 0;
+ for (size_t i = 0; i < size; ++i) {
+ GenerateMapForProjection(static_cast<int>(i), id_base);
+ }
+ }
+
+ int schema_id(ProjectionIdEnum schema_handle) const {
+ for (size_t i = 0; i < schemas_.size(); ++i) {
+ if (schemas_[i].first == schema_handle) {
+ return static_cast<int>(i);
+ }
+ }
+ // We should never get here
+ ARROW_DCHECK(false);
+ return -1;
+ }
+
+ const FieldInfo& field(ProjectionIdEnum schema_handle, int field_id) const {
+ int id = schema_id(schema_handle);
+ const std::vector<FieldInfo>& field_infos = schemas_[id].second;
+ return field_infos[field_id];
+ }
+
+ void GenerateMapForProjection(int id_proj, int id_base) {
+ int num_cols_proj = static_cast<int>(schemas_[id_proj].second.size());
+ int num_cols_base = static_cast<int>(schemas_[id_base].second.size());
+
+ std::vector<int>& mapping = mappings_[id_proj];
+ std::vector<int>& inverse_mapping = inverse_mappings_[id_proj];
+ mapping.resize(num_cols_proj);
+ inverse_mapping.resize(num_cols_base);
+
+ if (id_proj == id_base) {
+ for (int i = 0; i < num_cols_base; ++i) {
+ mapping[i] = inverse_mapping[i] = i;
+ }
+ } else {
+ const std::vector<FieldInfo>& fields_proj = schemas_[id_proj].second;
+ const std::vector<FieldInfo>& fields_base = schemas_[id_base].second;
+ for (int i = 0; i < num_cols_base; ++i) {
+ inverse_mapping[i] = SchemaProjectionMap::kMissingField;
+ }
+ for (int i = 0; i < num_cols_proj; ++i) {
+ int field_id = SchemaProjectionMap::kMissingField;
+ for (int j = 0; j < num_cols_base; ++j) {
+ if (fields_proj[i].field_path == fields_base[j].field_path) {
+ field_id = j;
+ // If there are multiple matches for the same input field,
+ // it will be mapped to the first match.
+ break;
+ }
+ }
+ ARROW_DCHECK(field_id != SchemaProjectionMap::kMissingField);
+ mapping[i] = field_id;
+ inverse_mapping[field_id] = i;
+ }
+ }
+ }
+
+ // vector used as a mapping from ProjectionIdEnum to fields
+ std::vector<std::pair<ProjectionIdEnum, std::vector<FieldInfo>>> schemas_;
+ std::vector<std::vector<int>> mappings_;
+ std::vector<std::vector<int>> inverse_mappings_;
+};
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/sink_node.cc b/src/arrow/cpp/src/arrow/compute/exec/sink_node.cc
new file mode 100644
index 000000000..1bb268038
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/sink_node.cc
@@ -0,0 +1,341 @@
+
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/exec_plan.h"
+
+#include <mutex>
+
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/expression.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/order_by_impl.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/compute/exec_internal.h"
+#include "arrow/datum.h"
+#include "arrow/result.h"
+#include "arrow/table.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/async_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/thread_pool.h"
+#include "arrow/util/unreachable.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+namespace {
+
+class SinkNode : public ExecNode {
+ public:
+ SinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ AsyncGenerator<util::optional<ExecBatch>>* generator,
+ util::BackpressureOptions backpressure)
+ : ExecNode(plan, std::move(inputs), {"collected"}, {},
+ /*num_outputs=*/0),
+ producer_(MakeProducer(generator, std::move(backpressure))) {}
+
+ static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options) {
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "SinkNode"));
+
+ const auto& sink_options = checked_cast<const SinkNodeOptions&>(options);
+ return plan->EmplaceNode<SinkNode>(plan, std::move(inputs), sink_options.generator,
+ sink_options.backpressure);
+ }
+
+ static PushGenerator<util::optional<ExecBatch>>::Producer MakeProducer(
+ AsyncGenerator<util::optional<ExecBatch>>* out_gen,
+ util::BackpressureOptions backpressure) {
+ PushGenerator<util::optional<ExecBatch>> push_gen(std::move(backpressure));
+ auto out = push_gen.producer();
+ *out_gen = std::move(push_gen);
+ return out;
+ }
+
+ const char* kind_name() const override { return "SinkNode"; }
+
+ Status StartProducing() override {
+ finished_ = Future<>::Make();
+ return Status::OK();
+ }
+
+ // sink nodes have no outputs from which to feel backpressure
+ [[noreturn]] static void NoOutputs() {
+ Unreachable("no outputs; this should never be called");
+ }
+ [[noreturn]] void ResumeProducing(ExecNode* output) override { NoOutputs(); }
+ [[noreturn]] void PauseProducing(ExecNode* output) override { NoOutputs(); }
+ [[noreturn]] void StopProducing(ExecNode* output) override { NoOutputs(); }
+
+ void StopProducing() override {
+ Finish();
+ inputs_[0]->StopProducing(this);
+ }
+
+ Future<> finished() override { return finished_; }
+
+ void InputReceived(ExecNode* input, ExecBatch batch) override {
+ DCHECK_EQ(input, inputs_[0]);
+
+ bool did_push = producer_.Push(std::move(batch));
+ if (!did_push) return; // producer_ was Closed already
+
+ if (input_counter_.Increment()) {
+ Finish();
+ }
+ }
+
+ void ErrorReceived(ExecNode* input, Status error) override {
+ DCHECK_EQ(input, inputs_[0]);
+
+ producer_.Push(std::move(error));
+
+ if (input_counter_.Cancel()) {
+ Finish();
+ }
+ inputs_[0]->StopProducing(this);
+ }
+
+ void InputFinished(ExecNode* input, int total_batches) override {
+ if (input_counter_.SetTotal(total_batches)) {
+ Finish();
+ }
+ }
+
+ protected:
+ virtual void Finish() {
+ if (producer_.Close()) {
+ finished_.MarkFinished();
+ }
+ }
+
+ AtomicCounter input_counter_;
+ Future<> finished_ = Future<>::MakeFinished();
+
+ PushGenerator<util::optional<ExecBatch>>::Producer producer_;
+};
+
+// A sink node that owns consuming the data and will not finish until the consumption
+// is finished. Use SinkNode if you are transferring the ownership of the data to another
+// system. Use ConsumingSinkNode if the data is being consumed within the exec plan (i.e.
+// the exec plan should not complete until the consumption has completed).
+class ConsumingSinkNode : public ExecNode {
+ public:
+ ConsumingSinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ std::shared_ptr<SinkNodeConsumer> consumer)
+ : ExecNode(plan, std::move(inputs), {"to_consume"}, {},
+ /*num_outputs=*/0),
+ consumer_(std::move(consumer)) {}
+
+ static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options) {
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "SinkNode"));
+
+ const auto& sink_options = checked_cast<const ConsumingSinkNodeOptions&>(options);
+ return plan->EmplaceNode<ConsumingSinkNode>(plan, std::move(inputs),
+ std::move(sink_options.consumer));
+ }
+
+ const char* kind_name() const override { return "ConsumingSinkNode"; }
+
+ Status StartProducing() override {
+ finished_ = Future<>::Make();
+ return Status::OK();
+ }
+
+ // sink nodes have no outputs from which to feel backpressure
+ [[noreturn]] static void NoOutputs() {
+ Unreachable("no outputs; this should never be called");
+ }
+ [[noreturn]] void ResumeProducing(ExecNode* output) override { NoOutputs(); }
+ [[noreturn]] void PauseProducing(ExecNode* output) override { NoOutputs(); }
+ [[noreturn]] void StopProducing(ExecNode* output) override { NoOutputs(); }
+
+ void StopProducing() override {
+ Finish(Status::Invalid("ExecPlan was stopped early"));
+ inputs_[0]->StopProducing(this);
+ }
+
+ Future<> finished() override { return finished_; }
+
+ void InputReceived(ExecNode* input, ExecBatch batch) override {
+ DCHECK_EQ(input, inputs_[0]);
+
+ // This can happen if an error was received and the source hasn't yet stopped. Since
+ // we have already called consumer_->Finish we don't want to call consumer_->Consume
+ if (input_counter_.Completed()) {
+ return;
+ }
+
+ Status consumption_status = consumer_->Consume(std::move(batch));
+ if (!consumption_status.ok()) {
+ if (input_counter_.Cancel()) {
+ Finish(std::move(consumption_status));
+ }
+ inputs_[0]->StopProducing(this);
+ return;
+ }
+
+ if (input_counter_.Increment()) {
+ Finish(Status::OK());
+ }
+ }
+
+ void ErrorReceived(ExecNode* input, Status error) override {
+ DCHECK_EQ(input, inputs_[0]);
+
+ if (input_counter_.Cancel()) {
+ Finish(std::move(error));
+ }
+
+ inputs_[0]->StopProducing(this);
+ }
+
+ void InputFinished(ExecNode* input, int total_batches) override {
+ if (input_counter_.SetTotal(total_batches)) {
+ Finish(Status::OK());
+ }
+ }
+
+ protected:
+ virtual void Finish(const Status& finish_st) {
+ consumer_->Finish().AddCallback([this, finish_st](const Status& st) {
+ // Prefer the plan error over the consumer error
+ Status final_status = finish_st & st;
+ finished_.MarkFinished(std::move(final_status));
+ });
+ }
+
+ AtomicCounter input_counter_;
+
+ Future<> finished_ = Future<>::MakeFinished();
+ std::shared_ptr<SinkNodeConsumer> consumer_;
+};
+
+// A sink node that accumulates inputs, then sorts them before emitting them.
+struct OrderBySinkNode final : public SinkNode {
+ OrderBySinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ std::unique_ptr<OrderByImpl> impl,
+ AsyncGenerator<util::optional<ExecBatch>>* generator,
+ util::BackpressureOptions backpressure)
+ : SinkNode(plan, std::move(inputs), generator, std::move(backpressure)),
+ impl_{std::move(impl)} {}
+
+ const char* kind_name() const override { return "OrderBySinkNode"; }
+
+ // A sink node that accumulates inputs, then sorts them before emitting them.
+ static Result<ExecNode*> MakeSort(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options) {
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "OrderBySinkNode"));
+
+ const auto& sink_options = checked_cast<const OrderBySinkNodeOptions&>(options);
+ ARROW_ASSIGN_OR_RAISE(
+ std::unique_ptr<OrderByImpl> impl,
+ OrderByImpl::MakeSort(plan->exec_context(), inputs[0]->output_schema(),
+ sink_options.sort_options));
+ return plan->EmplaceNode<OrderBySinkNode>(plan, std::move(inputs), std::move(impl),
+ sink_options.generator,
+ sink_options.backpressure);
+ }
+
+ // A sink node that receives inputs and then compute top_k/bottom_k.
+ static Result<ExecNode*> MakeSelectK(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options) {
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "OrderBySinkNode"));
+
+ const auto& sink_options = checked_cast<const SelectKSinkNodeOptions&>(options);
+ ARROW_ASSIGN_OR_RAISE(
+ std::unique_ptr<OrderByImpl> impl,
+ OrderByImpl::MakeSelectK(plan->exec_context(), inputs[0]->output_schema(),
+ sink_options.select_k_options));
+ return plan->EmplaceNode<OrderBySinkNode>(plan, std::move(inputs), std::move(impl),
+ sink_options.generator,
+ sink_options.backpressure);
+ }
+
+ void InputReceived(ExecNode* input, ExecBatch batch) override {
+ DCHECK_EQ(input, inputs_[0]);
+
+ auto maybe_batch = batch.ToRecordBatch(inputs_[0]->output_schema(),
+ plan()->exec_context()->memory_pool());
+ if (ErrorIfNotOk(maybe_batch.status())) {
+ StopProducing();
+ if (input_counter_.Cancel()) {
+ finished_.MarkFinished(maybe_batch.status());
+ }
+ return;
+ }
+ auto record_batch = maybe_batch.MoveValueUnsafe();
+
+ impl_->InputReceived(std::move(record_batch));
+ if (input_counter_.Increment()) {
+ Finish();
+ }
+ }
+
+ protected:
+ Status DoFinish() {
+ ARROW_ASSIGN_OR_RAISE(Datum sorted, impl_->DoFinish());
+ TableBatchReader reader(*sorted.table());
+ while (true) {
+ std::shared_ptr<RecordBatch> batch;
+ RETURN_NOT_OK(reader.ReadNext(&batch));
+ if (!batch) break;
+ bool did_push = producer_.Push(ExecBatch(*batch));
+ if (!did_push) break; // producer_ was Closed already
+ }
+ return Status::OK();
+ }
+
+ void Finish() override {
+ Status st = DoFinish();
+ if (ErrorIfNotOk(st)) {
+ producer_.Push(std::move(st));
+ }
+ SinkNode::Finish();
+ }
+
+ protected:
+ std::string ToStringExtra() const override {
+ return std::string("by=") + impl_->ToString();
+ }
+
+ private:
+ std::unique_ptr<OrderByImpl> impl_;
+};
+
+} // namespace
+
+namespace internal {
+
+void RegisterSinkNode(ExecFactoryRegistry* registry) {
+ DCHECK_OK(registry->AddFactory("select_k_sink", OrderBySinkNode::MakeSelectK));
+ DCHECK_OK(registry->AddFactory("order_by_sink", OrderBySinkNode::MakeSort));
+ DCHECK_OK(registry->AddFactory("consuming_sink", ConsumingSinkNode::Make));
+ DCHECK_OK(registry->AddFactory("sink", SinkNode::Make));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/source_node.cc b/src/arrow/cpp/src/arrow/compute/exec/source_node.cc
new file mode 100644
index 000000000..46bba5609
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/source_node.cc
@@ -0,0 +1,182 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <mutex>
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/expression.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/compute/exec_internal.h"
+#include "arrow/datum.h"
+#include "arrow/result.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/async_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/thread_pool.h"
+#include "arrow/util/unreachable.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+namespace {
+
+struct SourceNode : ExecNode {
+ SourceNode(ExecPlan* plan, std::shared_ptr<Schema> output_schema,
+ AsyncGenerator<util::optional<ExecBatch>> generator)
+ : ExecNode(plan, {}, {}, std::move(output_schema),
+ /*num_outputs=*/1),
+ generator_(std::move(generator)) {}
+
+ static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options) {
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 0, "SourceNode"));
+ const auto& source_options = checked_cast<const SourceNodeOptions&>(options);
+ return plan->EmplaceNode<SourceNode>(plan, source_options.output_schema,
+ source_options.generator);
+ }
+
+ const char* kind_name() const override { return "SourceNode"; }
+
+ [[noreturn]] static void NoInputs() {
+ Unreachable("no inputs; this should never be called");
+ }
+ [[noreturn]] void InputReceived(ExecNode*, ExecBatch) override { NoInputs(); }
+ [[noreturn]] void ErrorReceived(ExecNode*, Status) override { NoInputs(); }
+ [[noreturn]] void InputFinished(ExecNode*, int) override { NoInputs(); }
+
+ Status StartProducing() override {
+ {
+ // If another exec node encountered an error during its StartProducing call
+ // it might have already called StopProducing on all of its inputs (including this
+ // node).
+ //
+ std::unique_lock<std::mutex> lock(mutex_);
+ if (stop_requested_) {
+ return Status::OK();
+ }
+ }
+
+ CallbackOptions options;
+ auto executor = plan()->exec_context()->executor();
+ if (executor) {
+ // These options will transfer execution to the desired Executor if necessary.
+ // This can happen for in-memory scans where batches didn't require
+ // any CPU work to decode. Otherwise, parsing etc should have already
+ // been placed us on the desired Executor and no queues will be pushed to.
+ options.executor = executor;
+ options.should_schedule = ShouldSchedule::IfDifferentExecutor;
+ }
+ finished_ =
+ Loop([this, executor, options] {
+ std::unique_lock<std::mutex> lock(mutex_);
+ int total_batches = batch_count_++;
+ if (stop_requested_) {
+ return Future<ControlFlow<int>>::MakeFinished(Break(total_batches));
+ }
+ lock.unlock();
+
+ return generator_().Then(
+ [=](const util::optional<ExecBatch>& maybe_batch) -> ControlFlow<int> {
+ std::unique_lock<std::mutex> lock(mutex_);
+ if (IsIterationEnd(maybe_batch) || stop_requested_) {
+ stop_requested_ = true;
+ return Break(total_batches);
+ }
+ lock.unlock();
+ ExecBatch batch = std::move(*maybe_batch);
+
+ if (executor) {
+ auto status =
+ task_group_.AddTask([this, executor, batch]() -> Result<Future<>> {
+ return executor->Submit([=]() {
+ outputs_[0]->InputReceived(this, std::move(batch));
+ return Status::OK();
+ });
+ });
+ if (!status.ok()) {
+ outputs_[0]->ErrorReceived(this, std::move(status));
+ return Break(total_batches);
+ }
+ } else {
+ outputs_[0]->InputReceived(this, std::move(batch));
+ }
+ return Continue();
+ },
+ [=](const Status& error) -> ControlFlow<int> {
+ // NB: ErrorReceived is independent of InputFinished, but
+ // ErrorReceived will usually prompt StopProducing which will
+ // prompt InputFinished. ErrorReceived may still be called from a
+ // node which was requested to stop (indeed, the request to stop
+ // may prompt an error).
+ std::unique_lock<std::mutex> lock(mutex_);
+ stop_requested_ = true;
+ lock.unlock();
+ outputs_[0]->ErrorReceived(this, error);
+ return Break(total_batches);
+ },
+ options);
+ }).Then([&](int total_batches) {
+ outputs_[0]->InputFinished(this, total_batches);
+ return task_group_.End();
+ });
+
+ return Status::OK();
+ }
+
+ void PauseProducing(ExecNode* output) override {}
+
+ void ResumeProducing(ExecNode* output) override {}
+
+ void StopProducing(ExecNode* output) override {
+ DCHECK_EQ(output, outputs_[0]);
+ StopProducing();
+ }
+
+ void StopProducing() override {
+ std::unique_lock<std::mutex> lock(mutex_);
+ stop_requested_ = true;
+ }
+
+ Future<> finished() override { return finished_; }
+
+ private:
+ std::mutex mutex_;
+ bool stop_requested_{false};
+ int batch_count_{0};
+ Future<> finished_ = Future<>::MakeFinished();
+ util::AsyncTaskGroup task_group_;
+ AsyncGenerator<util::optional<ExecBatch>> generator_;
+};
+
+} // namespace
+
+namespace internal {
+
+void RegisterSourceNode(ExecFactoryRegistry* registry) {
+ DCHECK_OK(registry->AddFactory("source", SourceNode::Make));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/subtree_internal.h b/src/arrow/cpp/src/arrow/compute/exec/subtree_internal.h
new file mode 100644
index 000000000..72d419df2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/subtree_internal.h
@@ -0,0 +1,178 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <stdint.h>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "arrow/compute/exec/expression.h"
+#include "arrow/util/optional.h"
+
+namespace arrow {
+namespace compute {
+// Helper class for efficiently detecting subtrees given expressions.
+//
+// Using fragment partition expressions as an example:
+// Partition expressions are broken into conjunction members and each member dictionary
+// encoded to impose a sortable ordering. In addition, subtrees are generated which span
+// groups of fragments and nested subtrees. After encoding each fragment is guaranteed to
+// be a descendant of at least one subtree. For example, given fragments in a
+// HivePartitioning with paths:
+//
+// /num=0/al=eh/dat.par
+// /num=0/al=be/dat.par
+// /num=1/al=eh/dat.par
+// /num=1/al=be/dat.par
+//
+// The following subtrees will be introduced:
+//
+// /num=0/
+// /num=0/al=eh/
+// /num=0/al=eh/dat.par
+// /num=0/al=be/
+// /num=0/al=be/dat.par
+// /num=1/
+// /num=1/al=eh/
+// /num=1/al=eh/dat.par
+// /num=1/al=be/
+// /num=1/al=be/dat.par
+struct SubtreeImpl {
+ // Each unique conjunction member is mapped to an integer.
+ using expression_code = char32_t;
+ // Partition expressions are mapped to strings of codes; strings give us lexicographic
+ // ordering (and potentially useful optimizations).
+ using expression_codes = std::basic_string<expression_code>;
+ // An encoded guarantee (if index is set) or subtree.
+ struct Encoded {
+ // An external index identifying the corresponding object (e.g. a Fragment) of the
+ // guarantee.
+ util::optional<int> index;
+ // An encoded expression representing a guarantee.
+ expression_codes guarantee;
+ };
+
+ std::unordered_map<compute::Expression, expression_code, compute::Expression::Hash>
+ expr_to_code_;
+ std::vector<compute::Expression> code_to_expr_;
+ std::unordered_set<expression_codes> subtree_exprs_;
+
+ // Encode a subexpression (returning the existing code if possible).
+ expression_code GetOrInsert(const compute::Expression& expr) {
+ auto next_code = static_cast<int>(expr_to_code_.size());
+ auto it_success = expr_to_code_.emplace(expr, next_code);
+
+ if (it_success.second) {
+ code_to_expr_.push_back(expr);
+ }
+ return it_success.first->second;
+ }
+
+ // Encode an expression (recursively breaking up conjunction members if possible).
+ void EncodeConjunctionMembers(const compute::Expression& expr,
+ expression_codes* codes) {
+ if (auto call = expr.call()) {
+ if (call->function_name == "and_kleene") {
+ // expr is a conjunction, encode its arguments
+ EncodeConjunctionMembers(call->arguments[0], codes);
+ EncodeConjunctionMembers(call->arguments[1], codes);
+ return;
+ }
+ }
+ // expr is not a conjunction, encode it whole
+ codes->push_back(GetOrInsert(expr));
+ }
+
+ // Convert an encoded subtree or guarantee back into an expression.
+ compute::Expression GetSubtreeExpression(const Encoded& encoded_subtree) {
+ // Filters will already be simplified by all of a subtree's ancestors, so
+ // we only need to simplify the filter by the trailing conjunction member
+ // of each subtree.
+ return code_to_expr_[encoded_subtree.guarantee.back()];
+ }
+
+ // Insert subtrees for each component of an encoded partition expression.
+ void GenerateSubtrees(expression_codes guarantee, std::vector<Encoded>* encoded) {
+ while (!guarantee.empty()) {
+ if (subtree_exprs_.insert(guarantee).second) {
+ Encoded encoded_subtree{/*index=*/util::nullopt, guarantee};
+ encoded->push_back(std::move(encoded_subtree));
+ }
+ guarantee.resize(guarantee.size() - 1);
+ }
+ }
+
+ // Encode a guarantee, and generate subtrees for it as well.
+ void EncodeOneGuarantee(int index, const Expression& guarantee,
+ std::vector<Encoded>* encoded) {
+ Encoded encoded_guarantee{index, {}};
+ EncodeConjunctionMembers(guarantee, &encoded_guarantee.guarantee);
+ GenerateSubtrees(encoded_guarantee.guarantee, encoded);
+ encoded->push_back(std::move(encoded_guarantee));
+ }
+
+ template <typename GetGuarantee>
+ std::vector<Encoded> EncodeGuarantees(const GetGuarantee& get, int count) {
+ std::vector<Encoded> encoded;
+ for (int i = 0; i < count; ++i) {
+ EncodeOneGuarantee(i, get(i), &encoded);
+ }
+ return encoded;
+ }
+
+ // Comparator for sort
+ struct ByGuarantee {
+ bool operator()(const Encoded& l, const Encoded& r) {
+ const auto cmp = l.guarantee.compare(r.guarantee);
+ if (cmp != 0) {
+ return cmp < 0;
+ }
+ // Equal guarantees; sort encodings with indices after encodings without
+ return (l.index ? 1 : 0) < (r.index ? 1 : 0);
+ }
+ };
+
+ // Comparator for building a Forest
+ struct IsAncestor {
+ const std::vector<Encoded> encoded;
+
+ bool operator()(int l, int r) const {
+ if (encoded[l].index) {
+ // Leaf-level object (e.g. a Fragment): not an ancestor.
+ return false;
+ }
+
+ const auto& ancestor = encoded[l].guarantee;
+ const auto& descendant = encoded[r].guarantee;
+
+ if (descendant.size() >= ancestor.size()) {
+ return std::equal(ancestor.begin(), ancestor.end(), descendant.begin());
+ }
+ return false;
+ }
+ };
+};
+
+inline bool operator==(const SubtreeImpl::Encoded& l, const SubtreeImpl::Encoded& r) {
+ return l.index == r.index && l.guarantee == r.guarantee;
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/subtree_test.cc b/src/arrow/cpp/src/arrow/compute/exec/subtree_test.cc
new file mode 100644
index 000000000..972131044
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/subtree_test.cc
@@ -0,0 +1,377 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/compute/exec/forest_internal.h"
+#include "arrow/compute/exec/subtree_internal.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace compute {
+
+using testing::ContainerEq;
+
+// Tests of subtree pruning
+
+// Don't depend on FileSystem - port just enough to be useful here
+struct FileInfo {
+ bool is_dir;
+ std::string path;
+
+ bool operator==(const FileInfo& other) const {
+ return is_dir == other.is_dir && path == other.path;
+ }
+
+ static FileInfo Dir(std::string path) { return FileInfo{true, std::move(path)}; }
+
+ static FileInfo File(std::string path) { return FileInfo{false, std::move(path)}; }
+
+ static bool ByPath(const FileInfo& l, const FileInfo& r) { return l.path < r.path; }
+};
+
+struct TestPathTree {
+ FileInfo info;
+ std::vector<TestPathTree> subtrees;
+
+ explicit TestPathTree(std::string file_path)
+ : info(FileInfo::File(std::move(file_path))) {}
+
+ TestPathTree(std::string dir_path, std::vector<TestPathTree> subtrees)
+ : info(FileInfo::Dir(std::move(dir_path))), subtrees(std::move(subtrees)) {}
+
+ TestPathTree(Forest::Ref ref, const std::vector<FileInfo>& infos) : info(infos[ref.i]) {
+ const Forest& forest = *ref.forest;
+
+ int begin = ref.i + 1;
+ int end = begin + ref.num_descendants();
+
+ for (int i = begin; i < end; ++i) {
+ subtrees.emplace_back(forest[i], infos);
+ i += forest[i].num_descendants();
+ }
+ }
+
+ bool operator==(const TestPathTree& other) const {
+ return info == other.info && subtrees == other.subtrees;
+ }
+
+ std::string ToString() const {
+ auto out = "\n" + info.path;
+ if (info.is_dir) out += "/";
+
+ for (const auto& subtree : subtrees) {
+ out += subtree.ToString();
+ }
+ return out;
+ }
+
+ friend std::ostream& operator<<(std::ostream& os, const TestPathTree& tree) {
+ return os << tree.ToString();
+ }
+};
+
+using PT = TestPathTree;
+
+util::string_view RemoveTrailingSlash(util::string_view key) {
+ while (!key.empty() && key.back() == '/') {
+ key.remove_suffix(1);
+ }
+ return key;
+}
+bool IsAncestorOf(util::string_view ancestor, util::string_view descendant) {
+ // See filesystem/path_util.h
+ ancestor = RemoveTrailingSlash(ancestor);
+ if (ancestor == "") return true;
+ descendant = RemoveTrailingSlash(descendant);
+ if (!descendant.starts_with(ancestor)) return false;
+ descendant.remove_prefix(ancestor.size());
+ if (descendant.empty()) return true;
+ return descendant.front() == '/';
+}
+
+Forest MakeForest(std::vector<FileInfo>* infos) {
+ std::sort(infos->begin(), infos->end(), FileInfo::ByPath);
+
+ return Forest(static_cast<int>(infos->size()), [&](int i, int j) {
+ return IsAncestorOf(infos->at(i).path, infos->at(j).path);
+ });
+}
+
+void ExpectForestIs(std::vector<FileInfo> infos, std::vector<PT> expected_roots) {
+ auto forest = MakeForest(&infos);
+
+ std::vector<PT> actual_roots;
+ ASSERT_OK(forest.Visit(
+ [&](Forest::Ref ref) -> Result<bool> {
+ actual_roots.emplace_back(ref, infos);
+ return false; // only vist roots
+ },
+ [](Forest::Ref) {}));
+
+ // visit expected and assert equality
+ EXPECT_THAT(actual_roots, ContainerEq(expected_roots));
+}
+
+TEST(Forest, Basic) {
+ ExpectForestIs({}, {});
+
+ ExpectForestIs({FileInfo::File("aa")}, {PT("aa")});
+ ExpectForestIs({FileInfo::Dir("AA")}, {PT("AA", {})});
+ ExpectForestIs({FileInfo::Dir("AA"), FileInfo::File("AA/aa")},
+ {PT("AA", {PT("AA/aa")})});
+ ExpectForestIs({FileInfo::Dir("AA"), FileInfo::Dir("AA/BB"), FileInfo::File("AA/BB/0")},
+ {PT("AA", {PT("AA/BB", {PT("AA/BB/0")})})});
+
+ // Missing parent can still find ancestor.
+ ExpectForestIs({FileInfo::Dir("AA"), FileInfo::File("AA/BB/bb")},
+ {PT("AA", {PT("AA/BB/bb")})});
+
+ // Ancestors should link to parent regardless of ordering.
+ ExpectForestIs({FileInfo::File("AA/aa"), FileInfo::Dir("AA")},
+ {PT("AA", {PT("AA/aa")})});
+
+ // Multiple roots are supported.
+ ExpectForestIs({FileInfo::File("aa"), FileInfo::File("bb")}, {PT("aa"), PT("bb")});
+ ExpectForestIs({FileInfo::File("00"), FileInfo::Dir("AA"), FileInfo::File("AA/aa"),
+ FileInfo::File("BB/bb")},
+ {PT("00"), PT("AA", {PT("AA/aa")}), PT("BB/bb")});
+ ExpectForestIs({FileInfo::Dir("AA"), FileInfo::Dir("AA/BB"), FileInfo::File("AA/BB/0"),
+ FileInfo::Dir("CC"), FileInfo::Dir("CC/BB"), FileInfo::File("CC/BB/0")},
+ {PT("AA", {PT("AA/BB", {PT("AA/BB/0")})}),
+ PT("CC", {PT("CC/BB", {PT("CC/BB/0")})})});
+}
+
+TEST(Forest, HourlyETL) {
+ // This test mimics a scenario where an ETL dumps hourly files in a structure
+ // `$year/$month/$day/$hour/*.parquet`.
+ constexpr int64_t kYears = 3;
+ constexpr int64_t kMonthsPerYear = 12;
+ constexpr int64_t kDaysPerMonth = 31;
+ constexpr int64_t kHoursPerDay = 24;
+ constexpr int64_t kFilesPerHour = 2;
+
+ // Avoid constructing strings
+ std::vector<std::string> numbers{kDaysPerMonth + 1};
+ for (size_t i = 0; i < numbers.size(); i++) {
+ numbers[i] = std::to_string(i);
+ if (numbers[i].size() == 1) {
+ numbers[i] = "0" + numbers[i];
+ }
+ }
+
+ auto join = [](const std::vector<std::string>& path) {
+ if (path.empty()) return std::string("");
+ std::string result = path[0];
+ for (const auto& part : path) {
+ result += '/';
+ result += part;
+ }
+ return result;
+ };
+
+ std::vector<FileInfo> infos;
+
+ std::vector<PT> forest;
+ for (int64_t year = 0; year < kYears; year++) {
+ auto year_str = std::to_string(year + 2000);
+ auto year_dir = FileInfo::Dir(year_str);
+ infos.push_back(year_dir);
+
+ std::vector<PT> months;
+ for (int64_t month = 0; month < kMonthsPerYear; month++) {
+ auto month_str = join({year_str, numbers[month + 1]});
+ auto month_dir = FileInfo::Dir(month_str);
+ infos.push_back(month_dir);
+
+ std::vector<PT> days;
+ for (int64_t day = 0; day < kDaysPerMonth; day++) {
+ auto day_str = join({month_str, numbers[day + 1]});
+ auto day_dir = FileInfo::Dir(day_str);
+ infos.push_back(day_dir);
+
+ std::vector<PT> hours;
+ for (int64_t hour = 0; hour < kHoursPerDay; hour++) {
+ auto hour_str = join({day_str, numbers[hour]});
+ auto hour_dir = FileInfo::Dir(hour_str);
+ infos.push_back(hour_dir);
+
+ std::vector<PT> files;
+ for (int64_t file = 0; file < kFilesPerHour; file++) {
+ auto file_str = join({hour_str, numbers[file] + ".parquet"});
+ auto file_fd = FileInfo::File(file_str);
+ infos.push_back(file_fd);
+ files.emplace_back(file_str);
+ }
+
+ auto hour_pt = PT(hour_str, std::move(files));
+ hours.push_back(hour_pt);
+ }
+
+ auto day_pt = PT(day_str, std::move(hours));
+ days.push_back(day_pt);
+ }
+
+ auto month_pt = PT(month_str, std::move(days));
+ months.push_back(month_pt);
+ }
+
+ auto year_pt = PT(year_str, std::move(months));
+ forest.push_back(year_pt);
+ }
+
+ ExpectForestIs(infos, forest);
+}
+
+TEST(Forest, Visit) {
+ using Infos = std::vector<FileInfo>;
+
+ for (auto infos :
+ {Infos{}, Infos{FileInfo::Dir("A"), FileInfo::File("A/a")},
+ Infos{FileInfo::Dir("AA"), FileInfo::Dir("AA/BB"), FileInfo::File("AA/BB/0"),
+ FileInfo::Dir("CC"), FileInfo::Dir("CC/BB"), FileInfo::File("CC/BB/0")}}) {
+ ASSERT_TRUE(std::is_sorted(infos.begin(), infos.end(), FileInfo::ByPath));
+
+ auto forest = MakeForest(&infos);
+
+ auto ignore_post = [](Forest::Ref) {};
+
+ // noop is fine
+ ASSERT_OK(
+ forest.Visit([](Forest::Ref) -> Result<bool> { return false; }, ignore_post));
+
+ // Should propagate failure
+ if (forest.size() != 0) {
+ ASSERT_RAISES(
+ Invalid,
+ forest.Visit([](Forest::Ref) -> Result<bool> { return Status::Invalid(""); },
+ ignore_post));
+ }
+
+ // Ensure basic visit of all nodes
+ int i = 0;
+ ASSERT_OK(forest.Visit(
+ [&](Forest::Ref ref) -> Result<bool> {
+ EXPECT_EQ(ref.i, i);
+ ++i;
+ return true;
+ },
+ ignore_post));
+
+ // Visit only directories
+ Infos actual_dirs;
+ ASSERT_OK(forest.Visit(
+ [&](Forest::Ref ref) -> Result<bool> {
+ if (!infos[ref.i].is_dir) {
+ return false;
+ }
+ actual_dirs.push_back(infos[ref.i]);
+ return true;
+ },
+ ignore_post));
+
+ Infos expected_dirs;
+ for (const auto& info : infos) {
+ if (info.is_dir) {
+ expected_dirs.push_back(info);
+ }
+ }
+ EXPECT_THAT(actual_dirs, ContainerEq(expected_dirs));
+ }
+}
+
+TEST(Subtree, EncodeExpression) {
+ SubtreeImpl tree;
+ ASSERT_EQ(0, tree.GetOrInsert(equal(field_ref("a"), literal("1"))));
+ // Should be idempotent
+ ASSERT_EQ(0, tree.GetOrInsert(equal(field_ref("a"), literal("1"))));
+ ASSERT_EQ(equal(field_ref("a"), literal("1")), tree.code_to_expr_[0]);
+
+ SubtreeImpl::expression_codes codes;
+ auto conj =
+ and_(equal(field_ref("a"), literal("1")), equal(field_ref("b"), literal("2")));
+ tree.EncodeConjunctionMembers(conj, &codes);
+ ASSERT_EQ(SubtreeImpl::expression_codes({0, 1}), codes);
+
+ codes.clear();
+ conj = or_(equal(field_ref("a"), literal("1")), equal(field_ref("b"), literal("2")));
+ tree.EncodeConjunctionMembers(conj, &codes);
+ ASSERT_EQ(SubtreeImpl::expression_codes({2}), codes);
+}
+
+TEST(Subtree, GetSubtreeExpression) {
+ SubtreeImpl tree;
+ const auto expr_a = equal(field_ref("a"), literal("1"));
+ const auto expr_b = equal(field_ref("b"), literal("2"));
+ const auto code_a = tree.GetOrInsert(expr_a);
+ const auto code_b = tree.GetOrInsert(expr_b);
+ ASSERT_EQ(expr_a,
+ tree.GetSubtreeExpression(SubtreeImpl::Encoded{util::nullopt, {code_a}}));
+ ASSERT_EQ(expr_b, tree.GetSubtreeExpression(
+ SubtreeImpl::Encoded{util::nullopt, {code_a, code_b}}));
+}
+
+class FakeFragment {
+ public:
+ explicit FakeFragment(Expression partition_expression)
+ : partition_expression_(partition_expression) {}
+ const Expression& partition_expression() const { return partition_expression_; }
+
+ private:
+ Expression partition_expression_;
+};
+
+TEST(Subtree, EncodeFragments) {
+ const auto expr_a =
+ and_(equal(field_ref("a"), literal("1")), equal(field_ref("b"), literal("2")));
+ const auto expr_b =
+ and_(equal(field_ref("a"), literal("2")), equal(field_ref("b"), literal("3")));
+ std::vector<std::shared_ptr<FakeFragment>> fragments;
+ fragments.push_back(std::make_shared<FakeFragment>(expr_a));
+ fragments.push_back(std::make_shared<FakeFragment>(expr_b));
+
+ SubtreeImpl tree;
+ auto encoded = tree.EncodeGuarantees(
+ [&](int index) { return fragments[index]->partition_expression(); },
+ static_cast<int>(fragments.size()));
+ EXPECT_THAT(
+ tree.code_to_expr_,
+ ContainerEq(std::vector<compute::Expression>{
+ equal(field_ref("a"), literal("1")), equal(field_ref("b"), literal("2")),
+ equal(field_ref("a"), literal("2")), equal(field_ref("b"), literal("3"))}));
+ EXPECT_THAT(
+ encoded,
+ testing::UnorderedElementsAreArray({
+ SubtreeImpl::Encoded{util::make_optional<int>(0),
+ SubtreeImpl::expression_codes({0, 1})},
+ SubtreeImpl::Encoded{util::make_optional<int>(1),
+ SubtreeImpl::expression_codes({2, 3})},
+ SubtreeImpl::Encoded{util::nullopt, SubtreeImpl::expression_codes({0})},
+ SubtreeImpl::Encoded{util::nullopt, SubtreeImpl::expression_codes({2})},
+ SubtreeImpl::Encoded{util::nullopt, SubtreeImpl::expression_codes({0, 1})},
+ SubtreeImpl::Encoded{util::nullopt, SubtreeImpl::expression_codes({2, 3})},
+ }));
+}
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/task_util.cc b/src/arrow/cpp/src/arrow/compute/exec/task_util.cc
new file mode 100644
index 000000000..e5e714d34
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/task_util.cc
@@ -0,0 +1,409 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/task_util.h"
+
+#include <algorithm>
+#include <mutex>
+
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace compute {
+
+class TaskSchedulerImpl : public TaskScheduler {
+ public:
+ TaskSchedulerImpl();
+ int RegisterTaskGroup(TaskImpl task_impl, TaskGroupContinuationImpl cont_impl) override;
+ void RegisterEnd() override;
+ Status StartTaskGroup(size_t thread_id, int group_id, int64_t total_num_tasks) override;
+ Status ExecuteMore(size_t thread_id, int num_tasks_to_execute,
+ bool execute_all) override;
+ Status StartScheduling(size_t thread_id, ScheduleImpl schedule_impl,
+ int num_concurrent_tasks, bool use_sync_execution) override;
+ void Abort(AbortContinuationImpl impl) override;
+
+ private:
+ // Task group state transitions progress one way.
+ // Seeing an old version of the state by a thread is a valid situation.
+ //
+ enum class TaskGroupState : int {
+ NOT_READY,
+ READY,
+ ALL_TASKS_STARTED,
+ ALL_TASKS_FINISHED
+ };
+
+ struct TaskGroup {
+ TaskGroup(TaskImpl task_impl, TaskGroupContinuationImpl cont_impl)
+ : task_impl_(std::move(task_impl)),
+ cont_impl_(std::move(cont_impl)),
+ state_(TaskGroupState::NOT_READY),
+ num_tasks_present_(0) {
+ num_tasks_started_.value.store(0);
+ num_tasks_finished_.value.store(0);
+ }
+ TaskGroup(const TaskGroup& src)
+ : task_impl_(src.task_impl_),
+ cont_impl_(src.cont_impl_),
+ state_(TaskGroupState::NOT_READY),
+ num_tasks_present_(0) {
+ ARROW_DCHECK(src.state_ == TaskGroupState::NOT_READY);
+ num_tasks_started_.value.store(0);
+ num_tasks_finished_.value.store(0);
+ }
+ TaskImpl task_impl_;
+ TaskGroupContinuationImpl cont_impl_;
+
+ TaskGroupState state_;
+ int64_t num_tasks_present_;
+
+ AtomicWithPadding<int64_t> num_tasks_started_;
+ AtomicWithPadding<int64_t> num_tasks_finished_;
+ };
+
+ std::vector<std::pair<int, int64_t>> PickTasks(int num_tasks, int start_task_group = 0);
+ Status ExecuteTask(size_t thread_id, int group_id, int64_t task_id,
+ bool* task_group_finished);
+ bool PostExecuteTask(size_t thread_id, int group_id);
+ Status OnTaskGroupFinished(size_t thread_id, int group_id,
+ bool* all_task_groups_finished);
+ Status ScheduleMore(size_t thread_id, int num_tasks_finished = 0);
+
+ bool use_sync_execution_;
+ int num_concurrent_tasks_;
+ ScheduleImpl schedule_impl_;
+ AbortContinuationImpl abort_cont_impl_;
+
+ std::vector<TaskGroup> task_groups_;
+ bool aborted_;
+ bool register_finished_;
+ std::mutex mutex_; // Mutex protecting task_groups_ (state_ and num_tasks_present_
+ // fields), aborted_ flag and register_finished_ flag
+
+ AtomicWithPadding<int> num_tasks_to_schedule_;
+};
+
+TaskSchedulerImpl::TaskSchedulerImpl()
+ : use_sync_execution_(false),
+ num_concurrent_tasks_(0),
+ aborted_(false),
+ register_finished_(false) {
+ num_tasks_to_schedule_.value.store(0);
+}
+
+int TaskSchedulerImpl::RegisterTaskGroup(TaskImpl task_impl,
+ TaskGroupContinuationImpl cont_impl) {
+ int result = static_cast<int>(task_groups_.size());
+ task_groups_.emplace_back(std::move(task_impl), std::move(cont_impl));
+ return result;
+}
+
+void TaskSchedulerImpl::RegisterEnd() {
+ std::lock_guard<std::mutex> lock(mutex_);
+
+ register_finished_ = true;
+}
+
+Status TaskSchedulerImpl::StartTaskGroup(size_t thread_id, int group_id,
+ int64_t total_num_tasks) {
+ ARROW_DCHECK(group_id >= 0 && group_id < static_cast<int>(task_groups_.size()));
+ TaskGroup& task_group = task_groups_[group_id];
+
+ bool aborted = false;
+ bool all_tasks_finished = false;
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+
+ aborted = aborted_;
+
+ if (task_group.state_ == TaskGroupState::NOT_READY) {
+ task_group.num_tasks_present_ = total_num_tasks;
+ if (total_num_tasks == 0) {
+ task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
+ all_tasks_finished = true;
+ }
+ task_group.state_ = TaskGroupState::READY;
+ }
+ }
+
+ if (!aborted && all_tasks_finished) {
+ bool all_task_groups_finished = false;
+ RETURN_NOT_OK(OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished));
+ if (all_task_groups_finished) {
+ return Status::OK();
+ }
+ }
+
+ if (!aborted) {
+ return ScheduleMore(thread_id);
+ } else {
+ return Status::Cancelled("Scheduler cancelled");
+ }
+}
+
+std::vector<std::pair<int, int64_t>> TaskSchedulerImpl::PickTasks(int num_tasks,
+ int start_task_group) {
+ std::vector<std::pair<int, int64_t>> result;
+ for (size_t i = 0; i < task_groups_.size(); ++i) {
+ int task_group_id = static_cast<int>((start_task_group + i) % (task_groups_.size()));
+ TaskGroup& task_group = task_groups_[task_group_id];
+
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ if (task_group.state_ != TaskGroupState::READY) {
+ continue;
+ }
+ }
+
+ int num_tasks_remaining = num_tasks - static_cast<int>(result.size());
+ int64_t start_task =
+ task_group.num_tasks_started_.value.fetch_add(num_tasks_remaining);
+ if (start_task >= task_group.num_tasks_present_) {
+ continue;
+ }
+
+ int num_tasks_current_group = num_tasks_remaining;
+ if (start_task + num_tasks_current_group >= task_group.num_tasks_present_) {
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ if (task_group.state_ == TaskGroupState::READY) {
+ task_group.state_ = TaskGroupState::ALL_TASKS_STARTED;
+ }
+ }
+ num_tasks_current_group =
+ static_cast<int>(task_group.num_tasks_present_ - start_task);
+ }
+
+ for (int64_t task_id = start_task; task_id < start_task + num_tasks_current_group;
+ ++task_id) {
+ result.push_back(std::make_pair(task_group_id, task_id));
+ }
+
+ if (static_cast<int>(result.size()) == num_tasks) {
+ break;
+ }
+ }
+
+ return result;
+}
+
+Status TaskSchedulerImpl::ExecuteTask(size_t thread_id, int group_id, int64_t task_id,
+ bool* task_group_finished) {
+ if (!aborted_) {
+ RETURN_NOT_OK(task_groups_[group_id].task_impl_(thread_id, task_id));
+ }
+ *task_group_finished = PostExecuteTask(thread_id, group_id);
+ return Status::OK();
+}
+
+bool TaskSchedulerImpl::PostExecuteTask(size_t thread_id, int group_id) {
+ int64_t total = task_groups_[group_id].num_tasks_present_;
+ int64_t prev_finished = task_groups_[group_id].num_tasks_finished_.value.fetch_add(1);
+ bool all_tasks_finished = (prev_finished + 1 == total);
+ return all_tasks_finished;
+}
+
+Status TaskSchedulerImpl::OnTaskGroupFinished(size_t thread_id, int group_id,
+ bool* all_task_groups_finished) {
+ bool aborted = false;
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+
+ aborted = aborted_;
+ TaskGroup& task_group = task_groups_[group_id];
+ task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
+ *all_task_groups_finished = true;
+ for (size_t i = 0; i < task_groups_.size(); ++i) {
+ if (task_groups_[i].state_ != TaskGroupState::ALL_TASKS_FINISHED) {
+ *all_task_groups_finished = false;
+ break;
+ }
+ }
+ }
+
+ if (aborted && *all_task_groups_finished) {
+ abort_cont_impl_();
+ return Status::Cancelled("Scheduler cancelled");
+ }
+ if (!aborted) {
+ RETURN_NOT_OK(task_groups_[group_id].cont_impl_(thread_id));
+ }
+ return Status::OK();
+}
+
+Status TaskSchedulerImpl::ExecuteMore(size_t thread_id, int num_tasks_to_execute,
+ bool execute_all) {
+ num_tasks_to_execute = std::max(1, num_tasks_to_execute);
+
+ int last_id = 0;
+ for (;;) {
+ if (aborted_) {
+ return Status::Cancelled("Scheduler cancelled");
+ }
+
+ // Pick next bundle of tasks
+ const auto& tasks = PickTasks(num_tasks_to_execute, last_id);
+ if (tasks.empty()) {
+ break;
+ }
+ last_id = tasks.back().first;
+
+ // Execute picked tasks immediately
+ for (size_t i = 0; i < tasks.size(); ++i) {
+ int group_id = tasks[i].first;
+ int64_t task_id = tasks[i].second;
+ bool task_group_finished = false;
+ Status status = ExecuteTask(thread_id, group_id, task_id, &task_group_finished);
+ if (!status.ok()) {
+ // Mark the remaining picked tasks as finished
+ for (size_t j = i + 1; j < tasks.size(); ++j) {
+ if (PostExecuteTask(thread_id, tasks[j].first)) {
+ bool all_task_groups_finished = false;
+ RETURN_NOT_OK(
+ OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished));
+ if (all_task_groups_finished) {
+ return Status::OK();
+ }
+ }
+ }
+ return status;
+ } else {
+ if (task_group_finished) {
+ bool all_task_groups_finished = false;
+ RETURN_NOT_OK(
+ OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished));
+ if (all_task_groups_finished) {
+ return Status::OK();
+ }
+ }
+ }
+ }
+
+ if (!execute_all) {
+ num_tasks_to_execute -= static_cast<int>(tasks.size());
+ if (num_tasks_to_execute == 0) {
+ break;
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+Status TaskSchedulerImpl::StartScheduling(size_t thread_id, ScheduleImpl schedule_impl,
+ int num_concurrent_tasks,
+ bool use_sync_execution) {
+ schedule_impl_ = std::move(schedule_impl);
+ use_sync_execution_ = use_sync_execution;
+ num_concurrent_tasks_ = num_concurrent_tasks;
+ num_tasks_to_schedule_.value += num_concurrent_tasks;
+ return ScheduleMore(thread_id);
+}
+
+Status TaskSchedulerImpl::ScheduleMore(size_t thread_id, int num_tasks_finished) {
+ if (aborted_) {
+ return Status::Cancelled("Scheduler cancelled");
+ }
+
+ ARROW_DCHECK(register_finished_);
+
+ if (use_sync_execution_) {
+ return ExecuteMore(thread_id, 1, true);
+ }
+
+ int num_new_tasks = num_tasks_finished;
+ for (;;) {
+ int expected = num_tasks_to_schedule_.value.load();
+ if (num_tasks_to_schedule_.value.compare_exchange_strong(expected, 0)) {
+ num_new_tasks += expected;
+ break;
+ }
+ }
+ if (num_new_tasks == 0) {
+ return Status::OK();
+ }
+
+ const auto& tasks = PickTasks(num_new_tasks);
+ if (static_cast<int>(tasks.size()) < num_new_tasks) {
+ num_tasks_to_schedule_.value += num_new_tasks - static_cast<int>(tasks.size());
+ }
+
+ for (size_t i = 0; i < tasks.size(); ++i) {
+ int group_id = tasks[i].first;
+ int64_t task_id = tasks[i].second;
+ RETURN_NOT_OK(schedule_impl_([this, group_id, task_id](size_t thread_id) -> Status {
+ RETURN_NOT_OK(ScheduleMore(thread_id, 1));
+
+ bool task_group_finished = false;
+ RETURN_NOT_OK(ExecuteTask(thread_id, group_id, task_id, &task_group_finished));
+
+ if (task_group_finished) {
+ bool all_task_groups_finished = false;
+ return OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished);
+ }
+
+ return Status::OK();
+ }));
+ }
+
+ return Status::OK();
+}
+
+void TaskSchedulerImpl::Abort(AbortContinuationImpl impl) {
+ bool all_finished = true;
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ aborted_ = true;
+ abort_cont_impl_ = std::move(impl);
+ if (register_finished_) {
+ for (size_t i = 0; i < task_groups_.size(); ++i) {
+ TaskGroup& task_group = task_groups_[i];
+ if (task_group.state_ == TaskGroupState::NOT_READY) {
+ task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
+ } else if (task_group.state_ == TaskGroupState::READY) {
+ int64_t expected = task_group.num_tasks_started_.value.load();
+ for (;;) {
+ if (task_group.num_tasks_started_.value.compare_exchange_strong(
+ expected, task_group.num_tasks_present_)) {
+ break;
+ }
+ }
+ int64_t before_add = task_group.num_tasks_finished_.value.fetch_add(
+ task_group.num_tasks_present_ - expected);
+ if (before_add >= expected) {
+ task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
+ } else {
+ all_finished = false;
+ task_group.state_ = TaskGroupState::ALL_TASKS_STARTED;
+ }
+ }
+ }
+ }
+ }
+ if (all_finished) {
+ abort_cont_impl_();
+ }
+}
+
+std::unique_ptr<TaskScheduler> TaskScheduler::Make() {
+ std::unique_ptr<TaskSchedulerImpl> impl{new TaskSchedulerImpl()};
+ return std::move(impl);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/task_util.h b/src/arrow/cpp/src/arrow/compute/exec/task_util.h
new file mode 100644
index 000000000..44540d255
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/task_util.h
@@ -0,0 +1,100 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <atomic>
+#include <cstdint>
+#include <functional>
+#include <vector>
+
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace compute {
+
+// Atomic value surrounded by padding bytes to avoid cache line invalidation
+// whenever it is modified by a concurrent thread on a different CPU core.
+//
+template <typename T>
+class AtomicWithPadding {
+ private:
+ static constexpr int kCacheLineSize = 64;
+ uint8_t padding_before[kCacheLineSize];
+
+ public:
+ std::atomic<T> value;
+
+ private:
+ uint8_t padding_after[kCacheLineSize];
+};
+
+// Used for asynchronous execution of operations that can be broken into
+// a fixed number of symmetric tasks that can be executed concurrently.
+//
+// Implements priorities between multiple such operations, called task groups.
+//
+// Allows to specify the maximum number of in-flight tasks at any moment.
+//
+// Also allows for executing next pending tasks immediately using a caller thread.
+//
+class TaskScheduler {
+ public:
+ using TaskImpl = std::function<Status(size_t, int64_t)>;
+ using TaskGroupContinuationImpl = std::function<Status(size_t)>;
+ using ScheduleImpl = std::function<Status(TaskGroupContinuationImpl)>;
+ using AbortContinuationImpl = std::function<void()>;
+
+ virtual ~TaskScheduler() = default;
+
+ // Order in which task groups are registered represents priorities of their tasks
+ // (the first group has the highest priority).
+ //
+ // Returns task group identifier that is used to request operations on the task group.
+ virtual int RegisterTaskGroup(TaskImpl task_impl,
+ TaskGroupContinuationImpl cont_impl) = 0;
+
+ virtual void RegisterEnd() = 0;
+
+ // total_num_tasks may be zero, in which case task group continuation will be executed
+ // immediately
+ virtual Status StartTaskGroup(size_t thread_id, int group_id,
+ int64_t total_num_tasks) = 0;
+
+ // Execute given number of tasks immediately using caller thread
+ virtual Status ExecuteMore(size_t thread_id, int num_tasks_to_execute,
+ bool execute_all) = 0;
+
+ // Begin scheduling tasks using provided callback and
+ // the limit on the number of in-flight tasks at any moment.
+ //
+ // Scheduling will continue as long as there are waiting tasks.
+ //
+ // It will automatically resume whenever new task group gets started.
+ virtual Status StartScheduling(size_t thread_id, ScheduleImpl schedule_impl,
+ int num_concurrent_tasks, bool use_sync_execution) = 0;
+
+ // Abort scheduling and execution.
+ // Used in case of being notified about unrecoverable error for the entire query.
+ virtual void Abort(AbortContinuationImpl impl) = 0;
+
+ static std::unique_ptr<TaskScheduler> Make();
+};
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/test_util.cc b/src/arrow/cpp/src/arrow/compute/exec/test_util.cc
new file mode 100644
index 000000000..964c09398
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/test_util.cc
@@ -0,0 +1,239 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/test_util.h"
+
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/datum.h"
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/type.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/vector.h"
+
+namespace arrow {
+
+using internal::Executor;
+
+namespace compute {
+namespace {
+
+struct DummyNode : ExecNode {
+ DummyNode(ExecPlan* plan, NodeVector inputs, int num_outputs,
+ StartProducingFunc start_producing, StopProducingFunc stop_producing)
+ : ExecNode(plan, std::move(inputs), {}, dummy_schema(), num_outputs),
+ start_producing_(std::move(start_producing)),
+ stop_producing_(std::move(stop_producing)) {
+ input_labels_.resize(inputs_.size());
+ for (size_t i = 0; i < input_labels_.size(); ++i) {
+ input_labels_[i] = std::to_string(i);
+ }
+ }
+
+ const char* kind_name() const override { return "Dummy"; }
+
+ void InputReceived(ExecNode* input, ExecBatch batch) override {}
+
+ void ErrorReceived(ExecNode* input, Status error) override {}
+
+ void InputFinished(ExecNode* input, int total_batches) override {}
+
+ Status StartProducing() override {
+ if (start_producing_) {
+ RETURN_NOT_OK(start_producing_(this));
+ }
+ started_ = true;
+ return Status::OK();
+ }
+
+ void PauseProducing(ExecNode* output) override {
+ ASSERT_GE(num_outputs(), 0) << "Sink nodes should not experience backpressure";
+ AssertIsOutput(output);
+ }
+
+ void ResumeProducing(ExecNode* output) override {
+ ASSERT_GE(num_outputs(), 0) << "Sink nodes should not experience backpressure";
+ AssertIsOutput(output);
+ }
+
+ void StopProducing(ExecNode* output) override {
+ EXPECT_GE(num_outputs(), 0) << "Sink nodes should not experience backpressure";
+ AssertIsOutput(output);
+ }
+
+ void StopProducing() override {
+ if (started_) {
+ for (const auto& input : inputs_) {
+ input->StopProducing(this);
+ }
+ if (stop_producing_) {
+ stop_producing_(this);
+ }
+ }
+ }
+
+ Future<> finished() override { return Future<>::MakeFinished(); }
+
+ private:
+ void AssertIsOutput(ExecNode* output) {
+ auto it = std::find(outputs_.begin(), outputs_.end(), output);
+ ASSERT_NE(it, outputs_.end());
+ }
+
+ std::shared_ptr<Schema> dummy_schema() const {
+ return schema({field("dummy", null())});
+ }
+
+ StartProducingFunc start_producing_;
+ StopProducingFunc stop_producing_;
+ std::unordered_set<ExecNode*> requested_stop_;
+ bool started_ = false;
+};
+
+} // namespace
+
+ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector<ExecNode*> inputs,
+ int num_outputs, StartProducingFunc start_producing,
+ StopProducingFunc stop_producing) {
+ auto node =
+ plan->EmplaceNode<DummyNode>(plan, std::move(inputs), num_outputs,
+ std::move(start_producing), std::move(stop_producing));
+ if (!label.empty()) {
+ node->SetLabel(std::move(label));
+ }
+ return node;
+}
+
+ExecBatch ExecBatchFromJSON(const std::vector<ValueDescr>& descrs,
+ util::string_view json) {
+ auto fields = ::arrow::internal::MapVector(
+ [](const ValueDescr& descr) { return field("", descr.type); }, descrs);
+
+ ExecBatch batch{*RecordBatchFromJSON(schema(std::move(fields)), json)};
+
+ auto value_it = batch.values.begin();
+ for (const auto& descr : descrs) {
+ if (descr.shape == ValueDescr::SCALAR) {
+ if (batch.length == 0) {
+ *value_it = MakeNullScalar(value_it->type());
+ } else {
+ *value_it = value_it->make_array()->GetScalar(0).ValueOrDie();
+ }
+ }
+ ++value_it;
+ }
+
+ return batch;
+}
+
+Future<std::vector<ExecBatch>> StartAndCollect(
+ ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen) {
+ RETURN_NOT_OK(plan->Validate());
+ RETURN_NOT_OK(plan->StartProducing());
+
+ auto collected_fut = CollectAsyncGenerator(gen);
+
+ return AllComplete({plan->finished(), Future<>(collected_fut)})
+ .Then([collected_fut]() -> Result<std::vector<ExecBatch>> {
+ ARROW_ASSIGN_OR_RAISE(auto collected, collected_fut.result());
+ return ::arrow::internal::MapVector(
+ [](util::optional<ExecBatch> batch) { return std::move(*batch); },
+ std::move(collected));
+ });
+}
+
+BatchesWithSchema MakeBasicBatches() {
+ BatchesWithSchema out;
+ out.batches = {
+ ExecBatchFromJSON({int32(), boolean()}, "[[null, true], [4, false]]"),
+ ExecBatchFromJSON({int32(), boolean()}, "[[5, null], [6, false], [7, false]]")};
+ out.schema = schema({field("i32", int32()), field("bool", boolean())});
+ return out;
+}
+
+BatchesWithSchema MakeRandomBatches(const std::shared_ptr<Schema>& schema,
+ int num_batches, int batch_size) {
+ BatchesWithSchema out;
+
+ random::RandomArrayGenerator rng(42);
+ out.batches.resize(num_batches);
+
+ for (int i = 0; i < num_batches; ++i) {
+ out.batches[i] = ExecBatch(*rng.BatchOf(schema->fields(), batch_size));
+ // add a tag scalar to ensure the batches are unique
+ out.batches[i].values.emplace_back(i);
+ }
+
+ out.schema = schema;
+ return out;
+}
+
+Result<std::shared_ptr<Table>> SortTableOnAllFields(const std::shared_ptr<Table>& tab) {
+ std::vector<SortKey> sort_keys;
+ for (auto&& f : tab->schema()->fields()) {
+ sort_keys.emplace_back(f->name());
+ }
+ ARROW_ASSIGN_OR_RAISE(auto sort_ids, SortIndices(tab, SortOptions(sort_keys)));
+ ARROW_ASSIGN_OR_RAISE(auto tab_sorted, Take(tab, sort_ids));
+ return tab_sorted.table();
+}
+
+void AssertTablesEqual(const std::shared_ptr<Table>& exp,
+ const std::shared_ptr<Table>& act) {
+ ASSERT_EQ(exp->num_columns(), act->num_columns());
+ if (exp->num_rows() == 0) {
+ ASSERT_EQ(exp->num_rows(), act->num_rows());
+ } else {
+ ASSERT_OK_AND_ASSIGN(auto exp_sorted, SortTableOnAllFields(exp));
+ ASSERT_OK_AND_ASSIGN(auto act_sorted, SortTableOnAllFields(act));
+
+ AssertTablesEqual(*exp_sorted, *act_sorted,
+ /*same_chunk_layout=*/false, /*flatten=*/true);
+ }
+}
+
+void AssertExecBatchesEqual(const std::shared_ptr<Schema>& schema,
+ const std::vector<ExecBatch>& exp,
+ const std::vector<ExecBatch>& act) {
+ ASSERT_OK_AND_ASSIGN(auto exp_tab, TableFromExecBatches(schema, exp));
+ ASSERT_OK_AND_ASSIGN(auto act_tab, TableFromExecBatches(schema, act));
+ AssertTablesEqual(exp_tab, act_tab);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/test_util.h b/src/arrow/cpp/src/arrow/compute/exec/test_util.h
new file mode 100644
index 000000000..dad55fc36
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/test_util.h
@@ -0,0 +1,107 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <arrow/testing/gtest_util.h>
+#include <arrow/util/vector.h>
+
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/testing/visibility.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace compute {
+
+using StartProducingFunc = std::function<Status(ExecNode*)>;
+using StopProducingFunc = std::function<void(ExecNode*)>;
+
+// Make a dummy node that has no execution behaviour
+ARROW_TESTING_EXPORT
+ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector<ExecNode*> inputs,
+ int num_outputs, StartProducingFunc = {}, StopProducingFunc = {});
+
+ARROW_TESTING_EXPORT
+ExecBatch ExecBatchFromJSON(const std::vector<ValueDescr>& descrs,
+ util::string_view json);
+
+struct BatchesWithSchema {
+ std::vector<ExecBatch> batches;
+ std::shared_ptr<Schema> schema;
+
+ AsyncGenerator<util::optional<ExecBatch>> gen(bool parallel, bool slow) const {
+ auto opt_batches = ::arrow::internal::MapVector(
+ [](ExecBatch batch) { return util::make_optional(std::move(batch)); }, batches);
+
+ AsyncGenerator<util::optional<ExecBatch>> gen;
+
+ if (parallel) {
+ // emulate batches completing initial decode-after-scan on a cpu thread
+ gen = MakeBackgroundGenerator(MakeVectorIterator(std::move(opt_batches)),
+ ::arrow::internal::GetCpuThreadPool())
+ .ValueOrDie();
+
+ // ensure that callbacks are not executed immediately on a background thread
+ gen =
+ MakeTransferredGenerator(std::move(gen), ::arrow::internal::GetCpuThreadPool());
+ } else {
+ gen = MakeVectorGenerator(std::move(opt_batches));
+ }
+
+ if (slow) {
+ gen =
+ MakeMappedGenerator(std::move(gen), [](const util::optional<ExecBatch>& batch) {
+ SleepABit();
+ return batch;
+ });
+ }
+
+ return gen;
+ }
+};
+
+ARROW_TESTING_EXPORT
+Future<std::vector<ExecBatch>> StartAndCollect(
+ ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen);
+
+ARROW_TESTING_EXPORT
+BatchesWithSchema MakeBasicBatches();
+
+ARROW_TESTING_EXPORT
+BatchesWithSchema MakeRandomBatches(const std::shared_ptr<Schema>& schema,
+ int num_batches = 10, int batch_size = 4);
+
+ARROW_TESTING_EXPORT
+Result<std::shared_ptr<Table>> SortTableOnAllFields(const std::shared_ptr<Table>& tab);
+
+ARROW_TESTING_EXPORT
+void AssertTablesEqual(const std::shared_ptr<Table>& exp,
+ const std::shared_ptr<Table>& act);
+
+ARROW_TESTING_EXPORT
+void AssertExecBatchesEqual(const std::shared_ptr<Schema>& schema,
+ const std::vector<ExecBatch>& exp,
+ const std::vector<ExecBatch>& act);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/union_node.cc b/src/arrow/cpp/src/arrow/compute/exec/union_node.cc
new file mode 100644
index 000000000..fef2f4e18
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/union_node.cc
@@ -0,0 +1,154 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <mutex>
+
+#include "arrow/api.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+namespace {
+std::vector<std::string> GetInputLabels(const ExecNode::NodeVector& inputs) {
+ std::vector<std::string> labels(inputs.size());
+ for (size_t i = 0; i < inputs.size(); i++) {
+ labels[i] = "input_" + std::to_string(i) + "_label";
+ }
+ return labels;
+}
+} // namespace
+
+class UnionNode : public ExecNode {
+ public:
+ UnionNode(ExecPlan* plan, std::vector<ExecNode*> inputs)
+ : ExecNode(plan, inputs, GetInputLabels(inputs),
+ /*output_schema=*/inputs[0]->output_schema(),
+ /*num_outputs=*/1) {
+ bool counter_completed = input_count_.SetTotal(static_cast<int>(inputs.size()));
+ ARROW_DCHECK(counter_completed == false);
+ }
+
+ const char* kind_name() const override { return "UnionNode"; }
+
+ static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options) {
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, static_cast<int>(inputs.size()),
+ "UnionNode"));
+ if (inputs.size() < 1) {
+ return Status::Invalid("Constructing a `UnionNode` with inputs size less than 1");
+ }
+ auto schema = inputs.at(0)->output_schema();
+ for (auto input : inputs) {
+ if (!input->output_schema()->Equals(schema)) {
+ return Status::Invalid(
+ "UnionNode input schemas must all match, first schema was: ",
+ schema->ToString(), " got schema: ", input->output_schema()->ToString());
+ }
+ }
+ return plan->EmplaceNode<UnionNode>(plan, std::move(inputs));
+ }
+
+ void InputReceived(ExecNode* input, ExecBatch batch) override {
+ ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end());
+
+ if (finished_.is_finished()) {
+ return;
+ }
+ outputs_[0]->InputReceived(this, std::move(batch));
+ if (batch_count_.Increment()) {
+ finished_.MarkFinished();
+ }
+ }
+
+ void ErrorReceived(ExecNode* input, Status error) override {
+ DCHECK_EQ(input, inputs_[0]);
+ outputs_[0]->ErrorReceived(this, std::move(error));
+
+ StopProducing();
+ }
+
+ void InputFinished(ExecNode* input, int total_batches) override {
+ ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end());
+
+ total_batches_.fetch_add(total_batches);
+
+ if (input_count_.Increment()) {
+ outputs_[0]->InputFinished(this, total_batches_.load());
+ if (batch_count_.SetTotal(total_batches_.load())) {
+ finished_.MarkFinished();
+ }
+ }
+ }
+
+ Status StartProducing() override {
+ finished_ = Future<>::Make();
+ return Status::OK();
+ }
+
+ void PauseProducing(ExecNode* output) override {}
+
+ void ResumeProducing(ExecNode* output) override {}
+
+ void StopProducing(ExecNode* output) override {
+ DCHECK_EQ(output, outputs_[0]);
+ if (batch_count_.Cancel()) {
+ finished_.MarkFinished();
+ }
+ for (auto&& input : inputs_) {
+ input->StopProducing(this);
+ }
+ }
+
+ void StopProducing() override {
+ if (batch_count_.Cancel()) {
+ finished_.MarkFinished();
+ }
+ for (auto&& input : inputs_) {
+ input->StopProducing(this);
+ }
+ }
+
+ Future<> finished() override { return finished_; }
+
+ private:
+ AtomicCounter batch_count_;
+ AtomicCounter input_count_;
+ std::atomic<int> total_batches_{0};
+ Future<> finished_ = Future<>::MakeFinished();
+};
+
+namespace internal {
+
+void RegisterUnionNode(ExecFactoryRegistry* registry) {
+ DCHECK_OK(registry->AddFactory("union", UnionNode::Make));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/union_node_test.cc b/src/arrow/cpp/src/arrow/compute/exec/union_node_test.cc
new file mode 100644
index 000000000..41aaac26d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/union_node_test.cc
@@ -0,0 +1,150 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gmock/gmock-matchers.h>
+#include <random>
+
+#include "arrow/api.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+#include "arrow/testing/random.h"
+
+using testing::UnorderedElementsAreArray;
+
+namespace arrow {
+namespace compute {
+
+struct TestUnionNode : public ::testing::Test {
+ static constexpr int kNumBatches = 10;
+ static constexpr int kBatchSize = 10;
+
+ TestUnionNode() : rng_(0) {}
+
+ std::shared_ptr<Schema> GenerateRandomSchema(size_t num_inputs) {
+ static std::vector<std::shared_ptr<DataType>> some_arrow_types = {
+ arrow::null(), arrow::boolean(), arrow::int8(), arrow::int16(),
+ arrow::int32(), arrow::int64(), arrow::float16(), arrow::float32(),
+ arrow::float64(), arrow::utf8(), arrow::binary(), arrow::date32()};
+
+ std::vector<std::shared_ptr<Field>> fields(num_inputs);
+ std::default_random_engine gen(42);
+ std::uniform_int_distribution<int> types_dist(
+ 0, static_cast<int>(some_arrow_types.size()) - 1);
+ for (size_t i = 0; i < num_inputs; i++) {
+ int random_index = types_dist(gen);
+ auto col_type = some_arrow_types.at(random_index);
+ fields[i] =
+ field("column_" + std::to_string(i) + "_" + col_type->ToString(), col_type);
+ }
+ return schema(fields);
+ }
+
+ void GenerateBatchesFromSchema(const std::shared_ptr<Schema>& schema,
+ size_t num_batches, BatchesWithSchema* out_batches,
+ int multiplicity = 1, int64_t batch_size = 4) {
+ if (num_batches == 0) {
+ auto empty_record_batch = ExecBatch(*rng_.BatchOf(schema->fields(), 0));
+ out_batches->batches.push_back(empty_record_batch);
+ } else {
+ for (size_t j = 0; j < num_batches; j++) {
+ out_batches->batches.push_back(
+ ExecBatch(*rng_.BatchOf(schema->fields(), batch_size)));
+ }
+ }
+
+ size_t batch_count = out_batches->batches.size();
+ for (int repeat = 1; repeat < multiplicity; ++repeat) {
+ for (size_t i = 0; i < batch_count; ++i) {
+ out_batches->batches.push_back(out_batches->batches[i]);
+ }
+ }
+ out_batches->schema = schema;
+ }
+
+ void CheckRunOutput(const std::vector<BatchesWithSchema>& batches,
+ const BatchesWithSchema& exp_batches, bool parallel = false) {
+ SCOPED_TRACE(parallel ? "parallel" : "single threaded");
+
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+
+ Declaration union_decl{"union", ExecNodeOptions{}};
+
+ for (const auto& batch : batches) {
+ union_decl.inputs.emplace_back(Declaration{
+ "source", SourceNodeOptions{batch.schema, batch.gen(parallel,
+ /*slow=*/false)}});
+ }
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+ // Test UnionNode::Make with zero inputs
+ if (batches.size() == 0) {
+ ASSERT_RAISES(Invalid, Declaration::Sequence(
+ {union_decl, {"sink", SinkNodeOptions{&sink_gen}}})
+ .AddToPlan(plan.get()));
+ return;
+ } else {
+ ASSERT_OK(Declaration::Sequence({union_decl, {"sink", SinkNodeOptions{&sink_gen}}})
+ .AddToPlan(plan.get()));
+ }
+
+ Future<std::vector<ExecBatch>> actual = StartAndCollect(plan.get(), sink_gen);
+
+ auto expected_matcher =
+ Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)));
+ ASSERT_THAT(actual, expected_matcher);
+ }
+
+ void CheckUnionExecNode(size_t num_input_nodes, size_t num_batches, bool parallel) {
+ auto random_schema = GenerateRandomSchema(num_input_nodes);
+
+ int multiplicity = parallel ? 10 : 1;
+ std::vector<std::shared_ptr<RecordBatch>> all_record_batches;
+ std::vector<BatchesWithSchema> input_batches(num_input_nodes);
+ BatchesWithSchema exp_batches;
+ exp_batches.schema = random_schema;
+ for (size_t i = 0; i < num_input_nodes; i++) {
+ GenerateBatchesFromSchema(random_schema, num_batches, &input_batches[i],
+ multiplicity, kBatchSize);
+ for (const auto& batch : input_batches[i].batches) {
+ exp_batches.batches.push_back(batch);
+ }
+ }
+ CheckRunOutput(input_batches, exp_batches, parallel);
+ }
+
+ ::arrow::random::RandomArrayGenerator rng_;
+};
+
+TEST_F(TestUnionNode, TestNonEmpty) {
+ for (bool parallel : {false, true}) {
+ for (int64_t num_input_nodes : {1, 2, 4, 8}) {
+ this->CheckUnionExecNode(num_input_nodes, kNumBatches, parallel);
+ }
+ }
+}
+TEST_F(TestUnionNode, TestWithAnEmptyBatch) {
+ this->CheckUnionExecNode(/*num_input_nodes*/ 2, /*num_batches=*/0, /*parallel=*/false);
+}
+
+TEST_F(TestUnionNode, TestEmpty) {
+ this->CheckUnionExecNode(/*num_input_nodes*/ 0, /*num_batches=*/0, /*parallel=*/false);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/util.cc b/src/arrow/cpp/src/arrow/compute/exec/util.cc
new file mode 100644
index 000000000..64060d445
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/util.cc
@@ -0,0 +1,336 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/util.h"
+
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/table.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/ubsan.h"
+
+namespace arrow {
+
+using BitUtil::CountTrailingZeros;
+
+namespace util {
+
+inline void BitUtil::bits_to_indexes_helper(uint64_t word, uint16_t base_index,
+ int* num_indexes, uint16_t* indexes) {
+ int n = *num_indexes;
+ while (word) {
+ indexes[n++] = base_index + static_cast<uint16_t>(CountTrailingZeros(word));
+ word &= word - 1;
+ }
+ *num_indexes = n;
+}
+
+inline void BitUtil::bits_filter_indexes_helper(uint64_t word,
+ const uint16_t* input_indexes,
+ int* num_indexes, uint16_t* indexes) {
+ int n = *num_indexes;
+ while (word) {
+ indexes[n++] = input_indexes[CountTrailingZeros(word)];
+ word &= word - 1;
+ }
+ *num_indexes = n;
+}
+
+template <int bit_to_search, bool filter_input_indexes>
+void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bits,
+ const uint8_t* bits, const uint16_t* input_indexes,
+ int* num_indexes, uint16_t* indexes,
+ uint16_t base_index) {
+ // 64 bits at a time
+ constexpr int unroll = 64;
+ int tail = num_bits % unroll;
+#if defined(ARROW_HAVE_AVX2)
+ if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
+ if (filter_input_indexes) {
+ bits_filter_indexes_avx2(bit_to_search, num_bits - tail, bits, input_indexes,
+ num_indexes, indexes);
+ } else {
+ bits_to_indexes_avx2(bit_to_search, num_bits - tail, bits, num_indexes, indexes,
+ base_index);
+ }
+ } else {
+#endif
+ *num_indexes = 0;
+ for (int i = 0; i < num_bits / unroll; ++i) {
+ uint64_t word = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bits)[i]);
+ if (bit_to_search == 0) {
+ word = ~word;
+ }
+ if (filter_input_indexes) {
+ bits_filter_indexes_helper(word, input_indexes + i * 64, num_indexes, indexes);
+ } else {
+ bits_to_indexes_helper(word, i * 64 + base_index, num_indexes, indexes);
+ }
+ }
+#if defined(ARROW_HAVE_AVX2)
+ }
+#endif
+ // Optionally process the last partial word with masking out bits outside range
+ if (tail) {
+ uint64_t word =
+ util::SafeLoad(&reinterpret_cast<const uint64_t*>(bits)[num_bits / unroll]);
+ if (bit_to_search == 0) {
+ word = ~word;
+ }
+ word &= ~0ULL >> (64 - tail);
+ if (filter_input_indexes) {
+ bits_filter_indexes_helper(word, input_indexes + num_bits - tail, num_indexes,
+ indexes);
+ } else {
+ bits_to_indexes_helper(word, num_bits - tail + base_index, num_indexes, indexes);
+ }
+ }
+}
+
+void BitUtil::bits_to_indexes(int bit_to_search, int64_t hardware_flags, int num_bits,
+ const uint8_t* bits, int* num_indexes, uint16_t* indexes,
+ int bit_offset) {
+ bits += bit_offset / 8;
+ bit_offset %= 8;
+ *num_indexes = 0;
+ uint16_t base_index = 0;
+ if (bit_offset != 0) {
+ uint64_t bits_head =
+ util::SafeLoad(reinterpret_cast<const uint64_t*>(bits)) >> bit_offset;
+ int bits_in_first_byte = std::min(num_bits, 8 - bit_offset);
+ bits_to_indexes(bit_to_search, hardware_flags, bits_in_first_byte,
+ reinterpret_cast<const uint8_t*>(&bits_head), num_indexes, indexes);
+ if (num_bits <= bits_in_first_byte) {
+ return;
+ }
+ num_bits -= bits_in_first_byte;
+ indexes += *num_indexes;
+ bits += 1;
+ base_index = bits_in_first_byte;
+ }
+
+ int num_indexes_new = 0;
+ if (bit_to_search == 0) {
+ bits_to_indexes_internal<0, false>(hardware_flags, num_bits, bits, nullptr,
+ &num_indexes_new, indexes, base_index);
+ } else {
+ ARROW_DCHECK(bit_to_search == 1);
+ bits_to_indexes_internal<1, false>(hardware_flags, num_bits, bits, nullptr,
+ &num_indexes_new, indexes, base_index);
+ }
+ *num_indexes += num_indexes_new;
+}
+
+void BitUtil::bits_filter_indexes(int bit_to_search, int64_t hardware_flags,
+ const int num_bits, const uint8_t* bits,
+ const uint16_t* input_indexes, int* num_indexes,
+ uint16_t* indexes, int bit_offset) {
+ bits += bit_offset / 8;
+ bit_offset %= 8;
+ if (bit_offset != 0) {
+ int num_indexes_head = 0;
+ uint64_t bits_head =
+ util::SafeLoad(reinterpret_cast<const uint64_t*>(bits)) >> bit_offset;
+ int bits_in_first_byte = std::min(num_bits, 8 - bit_offset);
+ bits_filter_indexes(bit_to_search, hardware_flags, bits_in_first_byte,
+ reinterpret_cast<const uint8_t*>(&bits_head), input_indexes,
+ &num_indexes_head, indexes);
+ int num_indexes_tail = 0;
+ if (num_bits > bits_in_first_byte) {
+ bits_filter_indexes(bit_to_search, hardware_flags, num_bits - bits_in_first_byte,
+ bits + 1, input_indexes + bits_in_first_byte, &num_indexes_tail,
+ indexes + num_indexes_head);
+ }
+ *num_indexes = num_indexes_head + num_indexes_tail;
+ return;
+ }
+
+ if (bit_to_search == 0) {
+ bits_to_indexes_internal<0, true>(hardware_flags, num_bits, bits, input_indexes,
+ num_indexes, indexes);
+ } else {
+ ARROW_DCHECK(bit_to_search == 1);
+ bits_to_indexes_internal<1, true>(hardware_flags, num_bits, bits, input_indexes,
+ num_indexes, indexes);
+ }
+}
+
+void BitUtil::bits_split_indexes(int64_t hardware_flags, const int num_bits,
+ const uint8_t* bits, int* num_indexes_bit0,
+ uint16_t* indexes_bit0, uint16_t* indexes_bit1,
+ int bit_offset) {
+ bits_to_indexes(0, hardware_flags, num_bits, bits, num_indexes_bit0, indexes_bit0,
+ bit_offset);
+ int num_indexes_bit1;
+ bits_to_indexes(1, hardware_flags, num_bits, bits, &num_indexes_bit1, indexes_bit1,
+ bit_offset);
+}
+
+void BitUtil::bits_to_bytes(int64_t hardware_flags, const int num_bits,
+ const uint8_t* bits, uint8_t* bytes, int bit_offset) {
+ bits += bit_offset / 8;
+ bit_offset %= 8;
+ if (bit_offset != 0) {
+ uint64_t bits_head =
+ util::SafeLoad(reinterpret_cast<const uint64_t*>(bits)) >> bit_offset;
+ int bits_in_first_byte = std::min(num_bits, 8 - bit_offset);
+ bits_to_bytes(hardware_flags, bits_in_first_byte,
+ reinterpret_cast<const uint8_t*>(&bits_head), bytes);
+ if (num_bits > bits_in_first_byte) {
+ bits_to_bytes(hardware_flags, num_bits - bits_in_first_byte, bits + 1,
+ bytes + bits_in_first_byte);
+ }
+ return;
+ }
+
+ int num_processed = 0;
+#if defined(ARROW_HAVE_AVX2)
+ if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
+ // The function call below processes whole 32 bit chunks together.
+ num_processed = num_bits - (num_bits % 32);
+ bits_to_bytes_avx2(num_processed, bits, bytes);
+ }
+#endif
+ // Processing 8 bits at a time
+ constexpr int unroll = 8;
+ for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) {
+ uint8_t bits_next = bits[i];
+ // Clear the lowest bit and then make 8 copies of remaining 7 bits, each 7 bits apart
+ // from the previous.
+ uint64_t unpacked = static_cast<uint64_t>(bits_next & 0xfe) *
+ ((1ULL << 7) | (1ULL << 14) | (1ULL << 21) | (1ULL << 28) |
+ (1ULL << 35) | (1ULL << 42) | (1ULL << 49));
+ unpacked |= (bits_next & 1);
+ unpacked &= 0x0101010101010101ULL;
+ unpacked *= 255;
+ util::SafeStore(&reinterpret_cast<uint64_t*>(bytes)[i], unpacked);
+ }
+}
+
+void BitUtil::bytes_to_bits(int64_t hardware_flags, const int num_bits,
+ const uint8_t* bytes, uint8_t* bits, int bit_offset) {
+ bits += bit_offset / 8;
+ bit_offset %= 8;
+ if (bit_offset != 0) {
+ uint64_t bits_head;
+ int bits_in_first_byte = std::min(num_bits, 8 - bit_offset);
+ bytes_to_bits(hardware_flags, bits_in_first_byte, bytes,
+ reinterpret_cast<uint8_t*>(&bits_head));
+ uint8_t mask = (1 << bit_offset) - 1;
+ *bits = static_cast<uint8_t>((*bits & mask) | (bits_head << bit_offset));
+
+ if (num_bits > bits_in_first_byte) {
+ bytes_to_bits(hardware_flags, num_bits - bits_in_first_byte,
+ bytes + bits_in_first_byte, bits + 1);
+ }
+ return;
+ }
+
+ int num_processed = 0;
+#if defined(ARROW_HAVE_AVX2)
+ if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
+ // The function call below processes whole 32 bit chunks together.
+ num_processed = num_bits - (num_bits % 32);
+ bytes_to_bits_avx2(num_processed, bytes, bits);
+ }
+#endif
+ // Process 8 bits at a time
+ constexpr int unroll = 8;
+ for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) {
+ uint64_t bytes_next = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bytes)[i]);
+ bytes_next &= 0x0101010101010101ULL;
+ bytes_next |= (bytes_next >> 7); // Pairs of adjacent output bits in individual bytes
+ bytes_next |= (bytes_next >> 14); // 4 adjacent output bits in individual bytes
+ bytes_next |= (bytes_next >> 28); // All 8 output bits in the lowest byte
+ bits[i] = static_cast<uint8_t>(bytes_next & 0xff);
+ }
+}
+
+bool BitUtil::are_all_bytes_zero(int64_t hardware_flags, const uint8_t* bytes,
+ uint32_t num_bytes) {
+#if defined(ARROW_HAVE_AVX2)
+ if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
+ return are_all_bytes_zero_avx2(bytes, num_bytes);
+ }
+#endif
+ uint64_t result_or = 0;
+ uint32_t i;
+ for (i = 0; i < num_bytes / 8; ++i) {
+ uint64_t x = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bytes)[i]);
+ result_or |= x;
+ }
+ if (num_bytes % 8 > 0) {
+ uint64_t tail = 0;
+ result_or |= memcmp(bytes + i * 8, &tail, num_bytes % 8);
+ }
+ return result_or == 0;
+}
+
+} // namespace util
+
+namespace compute {
+
+Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector<ExecNode*>& inputs,
+ int expected_num_inputs, const char* kind_name) {
+ if (static_cast<int>(inputs.size()) != expected_num_inputs) {
+ return Status::Invalid(kind_name, " node requires ", expected_num_inputs,
+ " inputs but got ", inputs.size());
+ }
+
+ for (auto input : inputs) {
+ if (input->plan() != plan) {
+ return Status::Invalid("Constructing a ", kind_name,
+ " node in a different plan from its input");
+ }
+ }
+
+ return Status::OK();
+}
+
+Result<std::shared_ptr<Table>> TableFromExecBatches(
+ const std::shared_ptr<Schema>& schema, const std::vector<ExecBatch>& exec_batches) {
+ RecordBatchVector batches;
+ for (const auto& batch : exec_batches) {
+ ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema));
+ batches.push_back(std::move(rb));
+ }
+ return Table::FromRecordBatches(schema, batches);
+}
+
+size_t ThreadIndexer::operator()() {
+ auto id = std::this_thread::get_id();
+
+ auto guard = mutex_.Lock(); // acquire the lock
+ const auto& id_index = *id_to_index_.emplace(id, id_to_index_.size()).first;
+
+ return Check(id_index.second);
+}
+
+size_t ThreadIndexer::Capacity() {
+ static size_t max_size = arrow::internal::ThreadPool::DefaultCapacity() + 1;
+ return max_size;
+}
+
+size_t ThreadIndexer::Check(size_t thread_index) {
+ DCHECK_LT(thread_index, Capacity())
+ << "thread index " << thread_index << " is out of range [0, " << Capacity() << ")";
+
+ return thread_index;
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/util.h b/src/arrow/cpp/src/arrow/compute/exec/util.h
new file mode 100644
index 000000000..800c6f0e9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/util.h
@@ -0,0 +1,277 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <atomic>
+#include <cstdint>
+#include <thread>
+#include <unordered_map>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/compute/type_fwd.h"
+#include "arrow/memory_pool.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/cpu_info.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/mutex.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/thread_pool.h"
+
+#if defined(__clang__) || defined(__GNUC__)
+#define BYTESWAP(x) __builtin_bswap64(x)
+#define ROTL(x, n) (((x) << (n)) | ((x) >> (32 - (n))))
+#elif defined(_MSC_VER)
+#include <intrin.h>
+#define BYTESWAP(x) _byteswap_uint64(x)
+#define ROTL(x, n) _rotl((x), (n))
+#endif
+
+namespace arrow {
+namespace util {
+
+template <typename T>
+inline void CheckAlignment(const void* ptr) {
+ ARROW_DCHECK(reinterpret_cast<uint64_t>(ptr) % sizeof(T) == 0);
+}
+
+// Some platforms typedef int64_t as long int instead of long long int,
+// which breaks the _mm256_i64gather_epi64 and _mm256_i32gather_epi64 intrinsics
+// which need long long.
+// We use the cast to the type below in these intrinsics to make the code
+// compile in all cases.
+//
+using int64_for_gather_t = const long long int; // NOLINT runtime-int
+
+/// Storage used to allocate temporary vectors of a batch size.
+/// Temporary vectors should resemble allocating temporary variables on the stack
+/// but in the context of vectorized processing where we need to store a vector of
+/// temporaries instead of a single value.
+class TempVectorStack {
+ template <typename>
+ friend class TempVectorHolder;
+
+ public:
+ Status Init(MemoryPool* pool, int64_t size) {
+ num_vectors_ = 0;
+ top_ = 0;
+ buffer_size_ = size;
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateResizableBuffer(size, pool));
+ // Ensure later operations don't accidentally read uninitialized memory.
+ std::memset(buffer->mutable_data(), 0xFF, size);
+ buffer_ = std::move(buffer);
+ return Status::OK();
+ }
+
+ private:
+ int64_t PaddedAllocationSize(int64_t num_bytes) {
+ // Round up allocation size to multiple of 8 bytes
+ // to avoid returning temp vectors with unaligned address.
+ //
+ // Also add padding at the end to facilitate loads and stores
+ // using SIMD when number of vector elements is not divisible
+ // by the number of SIMD lanes.
+ //
+ return ::arrow::BitUtil::RoundUp(num_bytes, sizeof(int64_t)) + kPadding;
+ }
+ void alloc(uint32_t num_bytes, uint8_t** data, int* id) {
+ int64_t old_top = top_;
+ top_ += PaddedAllocationSize(num_bytes) + 2 * sizeof(uint64_t);
+ // Stack overflow check
+ ARROW_DCHECK(top_ <= buffer_size_);
+ *data = buffer_->mutable_data() + old_top + sizeof(uint64_t);
+ // We set 8 bytes before the beginning of the allocated range and
+ // 8 bytes after the end to check for stack overflow (which would
+ // result in those known bytes being corrupted).
+ reinterpret_cast<uint64_t*>(buffer_->mutable_data() + old_top)[0] = kGuard1;
+ reinterpret_cast<uint64_t*>(buffer_->mutable_data() + top_)[-1] = kGuard2;
+ *id = num_vectors_++;
+ }
+ void release(int id, uint32_t num_bytes) {
+ ARROW_DCHECK(num_vectors_ == id + 1);
+ int64_t size = PaddedAllocationSize(num_bytes) + 2 * sizeof(uint64_t);
+ ARROW_DCHECK(reinterpret_cast<const uint64_t*>(buffer_->mutable_data() + top_)[-1] ==
+ kGuard2);
+ ARROW_DCHECK(top_ >= size);
+ top_ -= size;
+ ARROW_DCHECK(reinterpret_cast<const uint64_t*>(buffer_->mutable_data() + top_)[0] ==
+ kGuard1);
+ --num_vectors_;
+ }
+ static constexpr uint64_t kGuard1 = 0x3141592653589793ULL;
+ static constexpr uint64_t kGuard2 = 0x0577215664901532ULL;
+ static constexpr int64_t kPadding = 64;
+ int num_vectors_;
+ int64_t top_;
+ std::unique_ptr<Buffer> buffer_;
+ int64_t buffer_size_;
+};
+
+template <typename T>
+class TempVectorHolder {
+ friend class TempVectorStack;
+
+ public:
+ ~TempVectorHolder() { stack_->release(id_, num_elements_ * sizeof(T)); }
+ T* mutable_data() { return reinterpret_cast<T*>(data_); }
+ TempVectorHolder(TempVectorStack* stack, uint32_t num_elements) {
+ stack_ = stack;
+ num_elements_ = num_elements;
+ stack_->alloc(num_elements * sizeof(T), &data_, &id_);
+ }
+
+ private:
+ TempVectorStack* stack_;
+ uint8_t* data_;
+ int id_;
+ uint32_t num_elements_;
+};
+
+class BitUtil {
+ public:
+ static void bits_to_indexes(int bit_to_search, int64_t hardware_flags,
+ const int num_bits, const uint8_t* bits, int* num_indexes,
+ uint16_t* indexes, int bit_offset = 0);
+
+ static void bits_filter_indexes(int bit_to_search, int64_t hardware_flags,
+ const int num_bits, const uint8_t* bits,
+ const uint16_t* input_indexes, int* num_indexes,
+ uint16_t* indexes, int bit_offset = 0);
+
+ // Input and output indexes may be pointing to the same data (in-place filtering).
+ static void bits_split_indexes(int64_t hardware_flags, const int num_bits,
+ const uint8_t* bits, int* num_indexes_bit0,
+ uint16_t* indexes_bit0, uint16_t* indexes_bit1,
+ int bit_offset = 0);
+
+ // Bit 1 is replaced with byte 0xFF.
+ static void bits_to_bytes(int64_t hardware_flags, const int num_bits,
+ const uint8_t* bits, uint8_t* bytes, int bit_offset = 0);
+
+ // Return highest bit of each byte.
+ static void bytes_to_bits(int64_t hardware_flags, const int num_bits,
+ const uint8_t* bytes, uint8_t* bits, int bit_offset = 0);
+
+ static bool are_all_bytes_zero(int64_t hardware_flags, const uint8_t* bytes,
+ uint32_t num_bytes);
+
+ private:
+ inline static void bits_to_indexes_helper(uint64_t word, uint16_t base_index,
+ int* num_indexes, uint16_t* indexes);
+ inline static void bits_filter_indexes_helper(uint64_t word,
+ const uint16_t* input_indexes,
+ int* num_indexes, uint16_t* indexes);
+ template <int bit_to_search, bool filter_input_indexes>
+ static void bits_to_indexes_internal(int64_t hardware_flags, const int num_bits,
+ const uint8_t* bits, const uint16_t* input_indexes,
+ int* num_indexes, uint16_t* indexes,
+ uint16_t base_index = 0);
+
+#if defined(ARROW_HAVE_AVX2)
+ static void bits_to_indexes_avx2(int bit_to_search, const int num_bits,
+ const uint8_t* bits, int* num_indexes,
+ uint16_t* indexes, uint16_t base_index = 0);
+ static void bits_filter_indexes_avx2(int bit_to_search, const int num_bits,
+ const uint8_t* bits, const uint16_t* input_indexes,
+ int* num_indexes, uint16_t* indexes);
+ template <int bit_to_search>
+ static void bits_to_indexes_imp_avx2(const int num_bits, const uint8_t* bits,
+ int* num_indexes, uint16_t* indexes,
+ uint16_t base_index = 0);
+ template <int bit_to_search>
+ static void bits_filter_indexes_imp_avx2(const int num_bits, const uint8_t* bits,
+ const uint16_t* input_indexes,
+ int* num_indexes, uint16_t* indexes);
+ static void bits_to_bytes_avx2(const int num_bits, const uint8_t* bits, uint8_t* bytes);
+ static void bytes_to_bits_avx2(const int num_bits, const uint8_t* bytes, uint8_t* bits);
+ static bool are_all_bytes_zero_avx2(const uint8_t* bytes, uint32_t num_bytes);
+#endif
+};
+
+} // namespace util
+namespace compute {
+
+ARROW_EXPORT
+Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector<ExecNode*>& inputs,
+ int expected_num_inputs, const char* kind_name);
+
+ARROW_EXPORT
+Result<std::shared_ptr<Table>> TableFromExecBatches(
+ const std::shared_ptr<Schema>& schema, const std::vector<ExecBatch>& exec_batches);
+
+class AtomicCounter {
+ public:
+ AtomicCounter() = default;
+
+ int count() const { return count_.load(); }
+
+ util::optional<int> total() const {
+ int total = total_.load();
+ if (total == -1) return {};
+ return total;
+ }
+
+ // return true if the counter is complete
+ bool Increment() {
+ DCHECK_NE(count_.load(), total_.load());
+ int count = count_.fetch_add(1) + 1;
+ if (count != total_.load()) return false;
+ return DoneOnce();
+ }
+
+ // return true if the counter is complete
+ bool SetTotal(int total) {
+ total_.store(total);
+ if (count_.load() != total) return false;
+ return DoneOnce();
+ }
+
+ // return true if the counter has not already been completed
+ bool Cancel() { return DoneOnce(); }
+
+ // return true if the counter has finished or been cancelled
+ bool Completed() { return complete_.load(); }
+
+ private:
+ // ensure there is only one true return from Increment(), SetTotal(), or Cancel()
+ bool DoneOnce() {
+ bool expected = false;
+ return complete_.compare_exchange_strong(expected, true);
+ }
+
+ std::atomic<int> count_{0}, total_{-1};
+ std::atomic<bool> complete_{false};
+};
+
+class ThreadIndexer {
+ public:
+ size_t operator()();
+
+ static size_t Capacity();
+
+ private:
+ static size_t Check(size_t thread_index);
+
+ util::Mutex mutex_;
+ std::unordered_map<std::thread::id, size_t> id_to_index_;
+};
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/util_avx2.cc b/src/arrow/cpp/src/arrow/compute/exec/util_avx2.cc
new file mode 100644
index 000000000..bdc0e41f5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/util_avx2.cc
@@ -0,0 +1,221 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <immintrin.h>
+
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/bit_util.h"
+
+namespace arrow {
+namespace util {
+
+#if defined(ARROW_HAVE_AVX2)
+
+void BitUtil::bits_to_indexes_avx2(int bit_to_search, const int num_bits,
+ const uint8_t* bits, int* num_indexes,
+ uint16_t* indexes, uint16_t base_index) {
+ if (bit_to_search == 0) {
+ bits_to_indexes_imp_avx2<0>(num_bits, bits, num_indexes, indexes, base_index);
+ } else {
+ ARROW_DCHECK(bit_to_search == 1);
+ bits_to_indexes_imp_avx2<1>(num_bits, bits, num_indexes, indexes, base_index);
+ }
+}
+
+template <int bit_to_search>
+void BitUtil::bits_to_indexes_imp_avx2(const int num_bits, const uint8_t* bits,
+ int* num_indexes, uint16_t* indexes,
+ uint16_t base_index) {
+ // 64 bits at a time
+ constexpr int unroll = 64;
+
+ // The caller takes care of processing the remaining bits at the end outside of the
+ // multiples of 64
+ ARROW_DCHECK(num_bits % unroll == 0);
+
+ constexpr uint64_t kEachByteIs1 = 0X0101010101010101ULL;
+ constexpr uint64_t kEachByteIs8 = 0x0808080808080808ULL;
+ constexpr uint64_t kByteSequence0To7 = 0x0706050403020100ULL;
+
+ uint8_t byte_indexes[64];
+ const uint64_t incr = kEachByteIs8;
+ const uint64_t mask = kByteSequence0To7;
+ *num_indexes = 0;
+ for (int i = 0; i < num_bits / unroll; ++i) {
+ uint64_t word = reinterpret_cast<const uint64_t*>(bits)[i];
+ if (bit_to_search == 0) {
+ word = ~word;
+ }
+ uint64_t base = 0;
+ int num_indexes_loop = 0;
+ while (word) {
+ uint64_t byte_indexes_next =
+ _pext_u64(mask, _pdep_u64(word, kEachByteIs1) * 0xff) + base;
+ *reinterpret_cast<uint64_t*>(byte_indexes + num_indexes_loop) = byte_indexes_next;
+ base += incr;
+ num_indexes_loop += static_cast<int>(arrow::BitUtil::PopCount(word & 0xff));
+ word >>= 8;
+ }
+ // Unpack indexes to 16-bits and either add the base of i * 64 or shuffle input
+ // indexes
+ for (int j = 0; j < (num_indexes_loop + 15) / 16; ++j) {
+ __m256i output = _mm256_cvtepi8_epi16(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(byte_indexes) + j));
+ output = _mm256_add_epi16(output, _mm256_set1_epi16(i * 64 + base_index));
+ _mm256_storeu_si256(((__m256i*)(indexes + *num_indexes)) + j, output);
+ }
+ *num_indexes += num_indexes_loop;
+ }
+}
+
+void BitUtil::bits_filter_indexes_avx2(int bit_to_search, const int num_bits,
+ const uint8_t* bits, const uint16_t* input_indexes,
+ int* num_indexes, uint16_t* indexes) {
+ if (bit_to_search == 0) {
+ bits_filter_indexes_imp_avx2<0>(num_bits, bits, input_indexes, num_indexes, indexes);
+ } else {
+ bits_filter_indexes_imp_avx2<1>(num_bits, bits, input_indexes, num_indexes, indexes);
+ }
+}
+
+template <int bit_to_search>
+void BitUtil::bits_filter_indexes_imp_avx2(const int num_bits, const uint8_t* bits,
+ const uint16_t* input_indexes,
+ int* out_num_indexes, uint16_t* indexes) {
+ // 64 bits at a time
+ constexpr int unroll = 64;
+
+ // The caller takes care of processing the remaining bits at the end outside of the
+ // multiples of 64
+ ARROW_DCHECK(num_bits % unroll == 0);
+
+ constexpr uint64_t kRepeatedBitPattern0001 = 0x1111111111111111ULL;
+ constexpr uint64_t k4BitSequence0To15 = 0xfedcba9876543210ULL;
+ constexpr uint64_t kByteSequence_0_0_1_1_2_2_3_3 = 0x0303020201010000ULL;
+ constexpr uint64_t kByteSequence_4_4_5_5_6_6_7_7 = 0x0707060605050404ULL;
+ constexpr uint64_t kByteSequence_0_2_4_6_8_10_12_14 = 0x0e0c0a0806040200ULL;
+ constexpr uint64_t kByteSequence_1_3_5_7_9_11_13_15 = 0x0f0d0b0907050301ULL;
+ constexpr uint64_t kByteSequence_0_8_1_9_2_10_3_11 = 0x0b030a0209010800ULL;
+ constexpr uint64_t kByteSequence_4_12_5_13_6_14_7_15 = 0x0f070e060d050c04ULL;
+
+ const uint64_t mask = k4BitSequence0To15;
+ int num_indexes = 0;
+ for (int i = 0; i < num_bits / unroll; ++i) {
+ uint64_t word = reinterpret_cast<const uint64_t*>(bits)[i];
+ if (bit_to_search == 0) {
+ word = ~word;
+ }
+
+ int loop_id = 0;
+ while (word) {
+ uint64_t indexes_4bit =
+ _pext_u64(mask, _pdep_u64(word, kRepeatedBitPattern0001) * 0xf);
+ // Unpack 4 bit indexes to 8 bits
+ __m256i indexes_8bit = _mm256_set1_epi64x(indexes_4bit);
+ indexes_8bit = _mm256_shuffle_epi8(
+ indexes_8bit,
+ _mm256_setr_epi64x(kByteSequence_0_0_1_1_2_2_3_3, kByteSequence_4_4_5_5_6_6_7_7,
+ kByteSequence_0_0_1_1_2_2_3_3,
+ kByteSequence_4_4_5_5_6_6_7_7));
+ indexes_8bit = _mm256_blendv_epi8(
+ _mm256_and_si256(indexes_8bit, _mm256_set1_epi8(0x0f)),
+ _mm256_and_si256(_mm256_srli_epi32(indexes_8bit, 4), _mm256_set1_epi8(0x0f)),
+ _mm256_set1_epi16(static_cast<uint16_t>(0xff00)));
+ __m256i input =
+ _mm256_loadu_si256(((const __m256i*)input_indexes) + 4 * i + loop_id);
+ // Shuffle bytes to get low bytes in the first 128-bit lane and high bytes in the
+ // second
+ input = _mm256_shuffle_epi8(
+ input, _mm256_setr_epi64x(
+ kByteSequence_0_2_4_6_8_10_12_14, kByteSequence_1_3_5_7_9_11_13_15,
+ kByteSequence_0_2_4_6_8_10_12_14, kByteSequence_1_3_5_7_9_11_13_15));
+ input = _mm256_permute4x64_epi64(input, 0xd8); // 0b11011000
+ // Apply permutation
+ __m256i output = _mm256_shuffle_epi8(input, indexes_8bit);
+ // Move low and high bytes across 128-bit lanes to assemble back 16-bit indexes.
+ // (This is the reverse of the byte permutation we did on the input)
+ output = _mm256_permute4x64_epi64(output,
+ 0xd8); // The reverse of swapping 2nd and 3rd
+ // 64-bit element is the same permutation
+ output = _mm256_shuffle_epi8(output,
+ _mm256_setr_epi64x(kByteSequence_0_8_1_9_2_10_3_11,
+ kByteSequence_4_12_5_13_6_14_7_15,
+ kByteSequence_0_8_1_9_2_10_3_11,
+ kByteSequence_4_12_5_13_6_14_7_15));
+ _mm256_storeu_si256((__m256i*)(indexes + num_indexes), output);
+ num_indexes += static_cast<int>(arrow::BitUtil::PopCount(word & 0xffff));
+ word >>= 16;
+ ++loop_id;
+ }
+ }
+
+ *out_num_indexes = num_indexes;
+}
+
+void BitUtil::bits_to_bytes_avx2(const int num_bits, const uint8_t* bits,
+ uint8_t* bytes) {
+ constexpr int unroll = 32;
+
+ constexpr uint64_t kEachByteIs1 = 0x0101010101010101ULL;
+ constexpr uint64_t kEachByteIs2 = 0x0202020202020202ULL;
+ constexpr uint64_t kEachByteIs3 = 0x0303030303030303ULL;
+ constexpr uint64_t kByteSequencePowersOf2 = 0x8040201008040201ULL;
+
+ // Processing 32 bits at a time
+ for (int i = 0; i < num_bits / unroll; ++i) {
+ __m256i unpacked = _mm256_set1_epi32(reinterpret_cast<const uint32_t*>(bits)[i]);
+ unpacked = _mm256_shuffle_epi8(
+ unpacked, _mm256_setr_epi64x(0ULL, kEachByteIs1, kEachByteIs2, kEachByteIs3));
+ __m256i bits_in_bytes = _mm256_set1_epi64x(kByteSequencePowersOf2);
+ unpacked =
+ _mm256_cmpeq_epi8(bits_in_bytes, _mm256_and_si256(unpacked, bits_in_bytes));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(bytes) + i, unpacked);
+ }
+}
+
+void BitUtil::bytes_to_bits_avx2(const int num_bits, const uint8_t* bytes,
+ uint8_t* bits) {
+ constexpr int unroll = 32;
+ // Processing 32 bits at a time
+ for (int i = 0; i < num_bits / unroll; ++i) {
+ reinterpret_cast<uint32_t*>(bits)[i] = _mm256_movemask_epi8(
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bytes) + i));
+ }
+}
+
+bool BitUtil::are_all_bytes_zero_avx2(const uint8_t* bytes, uint32_t num_bytes) {
+ __m256i result_or = _mm256_setzero_si256();
+ uint32_t i;
+ for (i = 0; i < num_bytes / 32; ++i) {
+ __m256i x = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bytes) + i);
+ result_or = _mm256_or_si256(result_or, x);
+ }
+ result_or = _mm256_cmpeq_epi8(result_or, _mm256_set1_epi8(0));
+ result_or =
+ _mm256_andnot_si256(result_or, _mm256_set1_epi8(static_cast<uint8_t>(0xff)));
+ uint32_t result_or32 = _mm256_movemask_epi8(result_or);
+ if (num_bytes % 32 > 0) {
+ uint64_t tail[4] = {0, 0, 0, 0};
+ result_or32 |= memcmp(bytes + i * 32, tail, num_bytes % 32);
+ }
+ return result_or32 == 0;
+}
+
+#endif // ARROW_HAVE_AVX2
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec/util_test.cc b/src/arrow/cpp/src/arrow/compute/exec/util_test.cc
new file mode 100644
index 000000000..7acf8228d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec/util_test.cc
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/hash_join.h"
+#include "arrow/compute/exec/schema_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+
+using testing::Eq;
+
+namespace arrow {
+namespace compute {
+
+const char* kLeftPrefix = "left.";
+const char* kRightPrefix = "right.";
+
+TEST(FieldMap, Trivial) {
+ HashJoinSchema schema_mgr;
+
+ auto left = schema({field("i32", int32())});
+ auto right = schema({field("i32", int32())});
+
+ ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"}, kLeftPrefix,
+ kRightPrefix));
+
+ auto output = schema_mgr.MakeOutputSchema(kLeftPrefix, kRightPrefix);
+ EXPECT_THAT(*output, Eq(Schema({
+ field("left.i32", int32()),
+ field("right.i32", int32()),
+ })));
+
+ auto i =
+ schema_mgr.proj_maps[0].map(HashJoinProjection::INPUT, HashJoinProjection::OUTPUT);
+ EXPECT_EQ(i.get(0), 0);
+}
+
+TEST(FieldMap, TrivialDuplicates) {
+ HashJoinSchema schema_mgr;
+
+ auto left = schema({field("i32", int32())});
+ auto right = schema({field("i32", int32())});
+
+ ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"}, "", ""));
+
+ auto output = schema_mgr.MakeOutputSchema("", "");
+ EXPECT_THAT(*output, Eq(Schema({
+ field("i32", int32()),
+ field("i32", int32()),
+ })));
+
+ auto i =
+ schema_mgr.proj_maps[0].map(HashJoinProjection::INPUT, HashJoinProjection::OUTPUT);
+ EXPECT_EQ(i.get(0), 0);
+}
+
+TEST(FieldMap, SingleKeyField) {
+ HashJoinSchema schema_mgr;
+
+ auto left = schema({field("i32", int32()), field("str", utf8())});
+ auto right = schema({field("f32", float32()), field("i32", int32())});
+
+ ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"}, kLeftPrefix,
+ kRightPrefix));
+
+ EXPECT_EQ(schema_mgr.proj_maps[0].num_cols(HashJoinProjection::INPUT), 2);
+ EXPECT_EQ(schema_mgr.proj_maps[1].num_cols(HashJoinProjection::INPUT), 2);
+ EXPECT_EQ(schema_mgr.proj_maps[0].num_cols(HashJoinProjection::KEY), 1);
+ EXPECT_EQ(schema_mgr.proj_maps[1].num_cols(HashJoinProjection::KEY), 1);
+ EXPECT_EQ(schema_mgr.proj_maps[0].num_cols(HashJoinProjection::OUTPUT), 2);
+ EXPECT_EQ(schema_mgr.proj_maps[1].num_cols(HashJoinProjection::OUTPUT), 2);
+
+ auto output = schema_mgr.MakeOutputSchema(kLeftPrefix, kRightPrefix);
+ EXPECT_THAT(*output, Eq(Schema({
+ field("left.i32", int32()),
+ field("left.str", utf8()),
+ field("right.f32", float32()),
+ field("right.i32", int32()),
+ })));
+
+ auto i =
+ schema_mgr.proj_maps[0].map(HashJoinProjection::INPUT, HashJoinProjection::OUTPUT);
+ EXPECT_EQ(i.get(0), 0);
+}
+
+TEST(FieldMap, TwoKeyFields) {
+ HashJoinSchema schema_mgr;
+
+ auto left = schema({
+ field("i32", int32()),
+ field("str", utf8()),
+ field("bool", boolean()),
+ });
+ auto right = schema({
+ field("i32", int32()),
+ field("str", utf8()),
+ field("f32", float32()),
+ field("f64", float64()),
+ });
+
+ ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32", "str"}, *right,
+ {"i32", "str"}, kLeftPrefix, kRightPrefix));
+
+ auto output = schema_mgr.MakeOutputSchema(kLeftPrefix, kRightPrefix);
+ EXPECT_THAT(*output, Eq(Schema({
+ field("left.i32", int32()),
+ field("left.str", utf8()),
+ field("left.bool", boolean()),
+
+ field("right.i32", int32()),
+ field("right.str", utf8()),
+ field("right.f32", float32()),
+ field("right.f64", float64()),
+ })));
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec_internal.h b/src/arrow/cpp/src/arrow/compute/exec_internal.h
new file mode 100644
index 000000000..74124f022
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec_internal.h
@@ -0,0 +1,145 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/kernel.h"
+#include "arrow/status.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace compute {
+
+class Function;
+
+static constexpr int64_t kDefaultMaxChunksize = std::numeric_limits<int64_t>::max();
+
+namespace detail {
+
+/// \brief Break std::vector<Datum> into a sequence of ExecBatch for kernel
+/// execution
+class ARROW_EXPORT ExecBatchIterator {
+ public:
+ /// \brief Construct iterator and do basic argument validation
+ ///
+ /// \param[in] args the Datum argument, must be all array-like or scalar
+ /// \param[in] max_chunksize the maximum length of each ExecBatch. Depending
+ /// on the chunk layout of ChunkedArray.
+ static Result<std::unique_ptr<ExecBatchIterator>> Make(
+ std::vector<Datum> args, int64_t max_chunksize = kDefaultMaxChunksize);
+
+ /// \brief Compute the next batch. Always returns at least one batch. Return
+ /// false if the iterator is exhausted
+ bool Next(ExecBatch* batch);
+
+ int64_t length() const { return length_; }
+
+ int64_t position() const { return position_; }
+
+ int64_t max_chunksize() const { return max_chunksize_; }
+
+ private:
+ ExecBatchIterator(std::vector<Datum> args, int64_t length, int64_t max_chunksize);
+
+ std::vector<Datum> args_;
+ std::vector<int> chunk_indexes_;
+ std::vector<int64_t> chunk_positions_;
+ int64_t position_;
+ int64_t length_;
+ int64_t max_chunksize_;
+};
+
+// "Push" / listener API like IPC reader so that consumers can receive
+// processed chunks as soon as they're available.
+
+class ARROW_EXPORT ExecListener {
+ public:
+ virtual ~ExecListener() = default;
+
+ virtual Status OnResult(Datum) { return Status::NotImplemented("OnResult"); }
+};
+
+class DatumAccumulator : public ExecListener {
+ public:
+ DatumAccumulator() = default;
+
+ Status OnResult(Datum value) override {
+ values_.emplace_back(value);
+ return Status::OK();
+ }
+
+ std::vector<Datum> values() { return std::move(values_); }
+
+ private:
+ std::vector<Datum> values_;
+};
+
+/// \brief Check that each Datum is of a "value" type, which means either
+/// SCALAR, ARRAY, or CHUNKED_ARRAY. If there are chunked inputs, then these
+/// inputs will be split into non-chunked ExecBatch values for execution
+Status CheckAllValues(const std::vector<Datum>& values);
+
+class ARROW_EXPORT KernelExecutor {
+ public:
+ virtual ~KernelExecutor() = default;
+
+ /// The Kernel's `init` method must be called and any KernelState set in the
+ /// KernelContext *before* KernelExecutor::Init is called. This is to facilitate
+ /// the case where init may be expensive and does not need to be called again for
+ /// each execution of the kernel, for example the same lookup table can be re-used
+ /// for all scanned batches in a dataset filter.
+ virtual Status Init(KernelContext*, KernelInitArgs) = 0;
+
+ /// XXX: Better configurability for listener
+ /// Not thread-safe
+ virtual Status Execute(const std::vector<Datum>& args, ExecListener* listener) = 0;
+
+ virtual Datum WrapResults(const std::vector<Datum>& args,
+ const std::vector<Datum>& outputs) = 0;
+
+ /// \brief Check the actual result type against the resolved output type
+ virtual Status CheckResultType(const Datum& out, const char* function_name) = 0;
+
+ static std::unique_ptr<KernelExecutor> MakeScalar();
+ static std::unique_ptr<KernelExecutor> MakeVector();
+ static std::unique_ptr<KernelExecutor> MakeScalarAggregate();
+};
+
+/// \brief Populate validity bitmap with the intersection of the nullity of the
+/// arguments. If a preallocated bitmap is not provided, then one will be
+/// allocated if needed (in some cases a bitmap can be zero-copied from the
+/// arguments). If any Scalar value is null, then the entire validity bitmap
+/// will be set to null.
+///
+/// \param[in] ctx kernel execution context, for memory allocation etc.
+/// \param[in] batch the data batch
+/// \param[in] out the output ArrayData, must not be null
+ARROW_EXPORT
+Status PropagateNulls(KernelContext* ctx, const ExecBatch& batch, ArrayData* out);
+
+} // namespace detail
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/exec_test.cc b/src/arrow/cpp/src/arrow/compute/exec_test.cc
new file mode 100644
index 000000000..3769517a9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/exec_test.cc
@@ -0,0 +1,891 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstring>
+#include <memory>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/data.h"
+#include "arrow/buffer.h"
+#include "arrow/chunked_array.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec_internal.h"
+#include "arrow/compute/function.h"
+#include "arrow/compute/function_internal.h"
+#include "arrow/compute/kernel.h"
+#include "arrow/compute/registry.h"
+#include "arrow/memory_pool.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/cpu_info.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+namespace detail {
+
+using ::arrow::internal::BitmapEquals;
+using ::arrow::internal::CopyBitmap;
+using ::arrow::internal::CountSetBits;
+
+TEST(ExecContext, BasicWorkings) {
+ {
+ ExecContext ctx;
+ ASSERT_EQ(GetFunctionRegistry(), ctx.func_registry());
+ ASSERT_EQ(default_memory_pool(), ctx.memory_pool());
+ ASSERT_EQ(std::numeric_limits<int64_t>::max(), ctx.exec_chunksize());
+
+ ASSERT_TRUE(ctx.use_threads());
+ ASSERT_EQ(arrow::internal::CpuInfo::GetInstance(), ctx.cpu_info());
+ }
+
+ // Now, let's customize all the things
+ LoggingMemoryPool my_pool(default_memory_pool());
+ std::unique_ptr<FunctionRegistry> custom_reg = FunctionRegistry::Make();
+ ExecContext ctx(&my_pool, /*executor=*/nullptr, custom_reg.get());
+
+ ASSERT_EQ(custom_reg.get(), ctx.func_registry());
+ ASSERT_EQ(&my_pool, ctx.memory_pool());
+
+ ctx.set_exec_chunksize(1 << 20);
+ ASSERT_EQ(1 << 20, ctx.exec_chunksize());
+
+ ctx.set_use_threads(false);
+ ASSERT_FALSE(ctx.use_threads());
+}
+
+TEST(SelectionVector, Basics) {
+ auto indices = ArrayFromJSON(int32(), "[0, 3]");
+ auto sel_vector = std::make_shared<SelectionVector>(*indices);
+
+ ASSERT_EQ(indices->length(), sel_vector->length());
+ ASSERT_EQ(3, sel_vector->indices()[1]);
+}
+
+void AssertValidityZeroExtraBits(const ArrayData& arr) {
+ const Buffer& buf = *arr.buffers[0];
+
+ const int64_t bit_extent = ((arr.offset + arr.length + 7) / 8) * 8;
+ for (int64_t i = arr.offset + arr.length; i < bit_extent; ++i) {
+ EXPECT_FALSE(BitUtil::GetBit(buf.data(), i)) << i;
+ }
+}
+
+class TestComputeInternals : public ::testing::Test {
+ public:
+ void SetUp() {
+ rng_.reset(new random::RandomArrayGenerator(/*seed=*/0));
+ ResetContexts();
+ }
+
+ void ResetContexts() {
+ exec_ctx_.reset(new ExecContext(default_memory_pool()));
+ ctx_.reset(new KernelContext(exec_ctx_.get()));
+ }
+
+ std::shared_ptr<Array> GetUInt8Array(int64_t size, double null_probability = 0.1) {
+ return rng_->UInt8(size, /*min=*/0, /*max=*/100, null_probability);
+ }
+
+ std::shared_ptr<Array> GetInt32Array(int64_t size, double null_probability = 0.1) {
+ return rng_->Int32(size, /*min=*/0, /*max=*/1000, null_probability);
+ }
+
+ std::shared_ptr<Array> GetFloat64Array(int64_t size, double null_probability = 0.1) {
+ return rng_->Float64(size, /*min=*/0, /*max=*/1000, null_probability);
+ }
+
+ std::shared_ptr<ChunkedArray> GetInt32Chunked(const std::vector<int>& sizes) {
+ std::vector<std::shared_ptr<Array>> chunks;
+ for (auto size : sizes) {
+ chunks.push_back(GetInt32Array(size));
+ }
+ return std::make_shared<ChunkedArray>(std::move(chunks));
+ }
+
+ protected:
+ std::unique_ptr<ExecContext> exec_ctx_;
+ std::unique_ptr<KernelContext> ctx_;
+ std::unique_ptr<random::RandomArrayGenerator> rng_;
+};
+
+class TestPropagateNulls : public TestComputeInternals {};
+
+TEST_F(TestPropagateNulls, UnknownNullCountWithNullsZeroCopies) {
+ const int64_t length = 16;
+
+ constexpr uint8_t validity_bitmap[8] = {254, 0, 0, 0, 0, 0, 0, 0};
+ auto nulls = std::make_shared<Buffer>(validity_bitmap, 8);
+
+ ArrayData output(boolean(), length, {nullptr, nullptr});
+ ArrayData input(boolean(), length, {nulls, nullptr}, kUnknownNullCount);
+
+ ExecBatch batch({input}, length);
+ ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+ ASSERT_EQ(nulls.get(), output.buffers[0].get());
+ ASSERT_EQ(kUnknownNullCount, output.null_count);
+ ASSERT_EQ(9, output.GetNullCount());
+}
+
+TEST_F(TestPropagateNulls, UnknownNullCountWithoutNulls) {
+ const int64_t length = 16;
+ constexpr uint8_t validity_bitmap[8] = {255, 255, 0, 0, 0, 0, 0, 0};
+ auto nulls = std::make_shared<Buffer>(validity_bitmap, 8);
+
+ ArrayData output(boolean(), length, {nullptr, nullptr});
+ ArrayData input(boolean(), length, {nulls, nullptr}, kUnknownNullCount);
+
+ ExecBatch batch({input}, length);
+ ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+ EXPECT_EQ(-1, output.null_count);
+ EXPECT_EQ(nulls.get(), output.buffers[0].get());
+}
+
+TEST_F(TestPropagateNulls, SetAllNulls) {
+ const int64_t length = 16;
+
+ auto CheckSetAllNull = [&](std::vector<Datum> values, bool preallocate) {
+ // Make fresh bitmap with all 1's
+ uint8_t bitmap_data[2] = {255, 255};
+ auto preallocated_mem = std::make_shared<MutableBuffer>(bitmap_data, 2);
+
+ std::vector<std::shared_ptr<Buffer>> buffers(2);
+ if (preallocate) {
+ buffers[0] = preallocated_mem;
+ }
+
+ ArrayData output(boolean(), length, buffers);
+
+ ExecBatch batch(values, length);
+ ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+
+ if (preallocate) {
+ // Ensure that buffer object the same when we pass in preallocated memory
+ ASSERT_EQ(preallocated_mem.get(), output.buffers[0].get());
+ }
+ ASSERT_NE(nullptr, output.buffers[0]);
+ uint8_t expected[2] = {0, 0};
+ const Buffer& out_buf = *output.buffers[0];
+ ASSERT_EQ(0, std::memcmp(out_buf.data(), expected, out_buf.size()));
+ };
+
+ // There is a null scalar
+ std::shared_ptr<Scalar> i32_val = std::make_shared<Int32Scalar>(3);
+ std::vector<Datum> vals = {i32_val, MakeNullScalar(boolean())};
+ CheckSetAllNull(vals, true);
+ CheckSetAllNull(vals, false);
+
+ const double true_prob = 0.5;
+
+ vals[0] = rng_->Boolean(length, true_prob);
+ CheckSetAllNull(vals, true);
+ CheckSetAllNull(vals, false);
+
+ auto arr_all_nulls = rng_->Boolean(length, true_prob, /*null_probability=*/1);
+
+ // One value is all null
+ vals = {rng_->Boolean(length, true_prob, /*null_probability=*/0.5), arr_all_nulls};
+ CheckSetAllNull(vals, true);
+ CheckSetAllNull(vals, false);
+
+ // A value is NullType
+ std::shared_ptr<Array> null_arr = std::make_shared<NullArray>(length);
+ vals = {rng_->Boolean(length, true_prob), null_arr};
+ CheckSetAllNull(vals, true);
+ CheckSetAllNull(vals, false);
+
+ // Other nitty-gritty scenarios
+ {
+ // An all-null bitmap is zero-copied over, even though there is a
+ // null-scalar earlier in the batch
+ ArrayData output(boolean(), length, {nullptr, nullptr});
+ ExecBatch batch({MakeNullScalar(boolean()), arr_all_nulls}, length);
+ ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+ ASSERT_EQ(arr_all_nulls->data()->buffers[0].get(), output.buffers[0].get());
+ }
+}
+
+TEST_F(TestPropagateNulls, SingleValueWithNulls) {
+ // Input offset is non-zero (0 mod 8 and nonzero mod 8 cases)
+ const int64_t length = 100;
+ auto arr = rng_->Boolean(length, 0.5, /*null_probability=*/0.5);
+
+ auto CheckSliced = [&](int64_t offset, bool preallocate = false,
+ int64_t out_offset = 0) {
+ // Unaligned bitmap, zero copy not possible
+ auto sliced = arr->Slice(offset);
+ std::vector<Datum> vals = {sliced};
+
+ ArrayData output(boolean(), vals[0].length(), {nullptr, nullptr});
+ output.offset = out_offset;
+
+ ExecBatch batch(vals, vals[0].length());
+
+ std::shared_ptr<Buffer> preallocated_bitmap;
+ if (preallocate) {
+ ASSERT_OK_AND_ASSIGN(
+ preallocated_bitmap,
+ AllocateBuffer(BitUtil::BytesForBits(sliced->length() + out_offset)));
+ std::memset(preallocated_bitmap->mutable_data(), 0, preallocated_bitmap->size());
+ output.buffers[0] = preallocated_bitmap;
+ } else {
+ ASSERT_EQ(0, output.offset);
+ }
+
+ ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+
+ if (!preallocate) {
+ const Buffer* parent_buf = arr->data()->buffers[0].get();
+ if (offset == 0) {
+ // Validity bitmap same, no slice
+ ASSERT_EQ(parent_buf, output.buffers[0].get());
+ } else if (offset % 8 == 0) {
+ // Validity bitmap sliced
+ ASSERT_NE(parent_buf, output.buffers[0].get());
+ ASSERT_EQ(parent_buf, output.buffers[0]->parent().get());
+ } else {
+ // New memory for offset not 0 mod 8
+ ASSERT_NE(parent_buf, output.buffers[0].get());
+ ASSERT_EQ(nullptr, output.buffers[0]->parent());
+ }
+ } else {
+ // preallocated, so check that the validity bitmap is unbothered
+ ASSERT_EQ(preallocated_bitmap.get(), output.buffers[0].get());
+ }
+
+ ASSERT_EQ(arr->Slice(offset)->null_count(), output.GetNullCount());
+
+ ASSERT_TRUE(BitmapEquals(output.buffers[0]->data(), output.offset,
+ sliced->null_bitmap_data(), sliced->offset(),
+ output.length));
+ AssertValidityZeroExtraBits(output);
+ };
+
+ CheckSliced(8);
+ CheckSliced(7);
+ CheckSliced(8, /*preallocated=*/true);
+ CheckSliced(7, true);
+ CheckSliced(8, true, /*offset=*/4);
+ CheckSliced(7, true, 4);
+}
+
+TEST_F(TestPropagateNulls, ZeroCopyWhenZeroNullsOnOneInput) {
+ const int64_t length = 16;
+
+ constexpr uint8_t validity_bitmap[8] = {254, 0, 0, 0, 0, 0, 0, 0};
+ auto nulls = std::make_shared<Buffer>(validity_bitmap, 8);
+
+ ArrayData some_nulls(boolean(), 16, {nulls, nullptr}, /*null_count=*/9);
+ ArrayData no_nulls(boolean(), length, {nullptr, nullptr}, /*null_count=*/0);
+
+ ArrayData output(boolean(), length, {nullptr, nullptr});
+ ExecBatch batch({some_nulls, no_nulls}, length);
+ ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+ ASSERT_EQ(nulls.get(), output.buffers[0].get());
+ ASSERT_EQ(9, output.null_count);
+
+ // Flip order of args
+ output = ArrayData(boolean(), length, {nullptr, nullptr});
+ batch.values = {no_nulls, no_nulls, some_nulls};
+ ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+ ASSERT_EQ(nulls.get(), output.buffers[0].get());
+ ASSERT_EQ(9, output.null_count);
+
+ // Check that preallocated memory is not clobbered
+ uint8_t bitmap_data[2] = {0, 0};
+ auto preallocated_mem = std::make_shared<MutableBuffer>(bitmap_data, 2);
+ output.null_count = kUnknownNullCount;
+ output.buffers[0] = preallocated_mem;
+ ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+
+ ASSERT_EQ(preallocated_mem.get(), output.buffers[0].get());
+ ASSERT_EQ(9, output.null_count);
+ ASSERT_EQ(254, bitmap_data[0]);
+ ASSERT_EQ(0, bitmap_data[1]);
+}
+
+TEST_F(TestPropagateNulls, IntersectsNulls) {
+ const int64_t length = 16;
+
+ // 0b01111111 0b11001111
+ constexpr uint8_t bitmap1[8] = {127, 207, 0, 0, 0, 0, 0, 0};
+
+ // 0b11111110 0b01111111
+ constexpr uint8_t bitmap2[8] = {254, 127, 0, 0, 0, 0, 0, 0};
+
+ // 0b11101111 0b11111110
+ constexpr uint8_t bitmap3[8] = {239, 254, 0, 0, 0, 0, 0, 0};
+
+ ArrayData arr1(boolean(), length, {std::make_shared<Buffer>(bitmap1, 8), nullptr});
+ ArrayData arr2(boolean(), length, {std::make_shared<Buffer>(bitmap2, 8), nullptr});
+ ArrayData arr3(boolean(), length, {std::make_shared<Buffer>(bitmap3, 8), nullptr});
+
+ auto CheckCase = [&](std::vector<Datum> values, int64_t ex_null_count,
+ const uint8_t* ex_bitmap, bool preallocate = false,
+ int64_t output_offset = 0) {
+ ExecBatch batch(values, length);
+
+ std::shared_ptr<Buffer> nulls;
+ if (preallocate) {
+ // Make the buffer one byte bigger so we can have non-zero offsets
+ ASSERT_OK_AND_ASSIGN(nulls, AllocateBuffer(3));
+ std::memset(nulls->mutable_data(), 0, nulls->size());
+ } else {
+ // non-zero output offset not permitted unless the output memory is
+ // preallocated
+ ASSERT_EQ(0, output_offset);
+ }
+ ArrayData output(boolean(), length, {nulls, nullptr});
+ output.offset = output_offset;
+
+ ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+
+ // Preallocated memory used
+ if (preallocate) {
+ ASSERT_EQ(nulls.get(), output.buffers[0].get());
+ }
+
+ EXPECT_EQ(kUnknownNullCount, output.null_count);
+ EXPECT_EQ(ex_null_count, output.GetNullCount());
+
+ const auto& out_buffer = *output.buffers[0];
+
+ ASSERT_TRUE(BitmapEquals(out_buffer.data(), output_offset, ex_bitmap,
+ /*ex_offset=*/0, length));
+
+ // Now check that the rest of the bits in out_buffer are still 0
+ AssertValidityZeroExtraBits(output);
+ };
+
+ // 0b01101110 0b01001110
+ uint8_t expected1[2] = {110, 78};
+ CheckCase({arr1, arr2, arr3}, 7, expected1);
+ CheckCase({arr1, arr2, arr3}, 7, expected1, /*preallocate=*/true);
+ CheckCase({arr1, arr2, arr3}, 7, expected1, /*preallocate=*/true,
+ /*output_offset=*/4);
+
+ // 0b01111110 0b01001111
+ uint8_t expected2[2] = {126, 79};
+ CheckCase({arr1, arr2}, 5, expected2);
+ CheckCase({arr1, arr2}, 5, expected2, /*preallocate=*/true,
+ /*output_offset=*/4);
+}
+
+TEST_F(TestPropagateNulls, NullOutputTypeNoop) {
+ // Ensure we leave the buffers alone when the output type is null()
+ const int64_t length = 100;
+ ExecBatch batch({rng_->Boolean(100, 0.5, 0.5)}, length);
+
+ ArrayData output(null(), length, {nullptr});
+ ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+ ASSERT_EQ(nullptr, output.buffers[0]);
+}
+
+// ----------------------------------------------------------------------
+// ExecBatchIterator
+
+class TestExecBatchIterator : public TestComputeInternals {
+ public:
+ void SetupIterator(std::vector<Datum> args,
+ int64_t max_chunksize = kDefaultMaxChunksize) {
+ ASSERT_OK_AND_ASSIGN(iterator_,
+ ExecBatchIterator::Make(std::move(args), max_chunksize));
+ }
+ void CheckIteration(const std::vector<Datum>& args, int chunksize,
+ const std::vector<int>& ex_batch_sizes) {
+ SetupIterator(args, chunksize);
+ ExecBatch batch;
+ int64_t position = 0;
+ for (size_t i = 0; i < ex_batch_sizes.size(); ++i) {
+ ASSERT_EQ(position, iterator_->position());
+ ASSERT_TRUE(iterator_->Next(&batch));
+ ASSERT_EQ(ex_batch_sizes[i], batch.length);
+
+ for (size_t j = 0; j < args.size(); ++j) {
+ switch (args[j].kind()) {
+ case Datum::SCALAR:
+ ASSERT_TRUE(args[j].scalar()->Equals(batch[j].scalar()));
+ break;
+ case Datum::ARRAY:
+ AssertArraysEqual(*args[j].make_array()->Slice(position, batch.length),
+ *batch[j].make_array());
+ break;
+ case Datum::CHUNKED_ARRAY: {
+ const ChunkedArray& carr = *args[j].chunked_array();
+ if (batch.length == 0) {
+ ASSERT_EQ(0, carr.length());
+ } else {
+ auto arg_slice = carr.Slice(position, batch.length);
+ // The sliced ChunkedArrays should only ever be 1 chunk
+ ASSERT_EQ(1, arg_slice->num_chunks());
+ AssertArraysEqual(*arg_slice->chunk(0), *batch[j].make_array());
+ }
+ } break;
+ default:
+ break;
+ }
+ }
+ position += ex_batch_sizes[i];
+ }
+ // Ensure that the iterator is exhausted
+ ASSERT_FALSE(iterator_->Next(&batch));
+
+ ASSERT_EQ(iterator_->length(), iterator_->position());
+ }
+
+ protected:
+ std::unique_ptr<ExecBatchIterator> iterator_;
+};
+
+TEST_F(TestExecBatchIterator, Basics) {
+ const int64_t length = 100;
+
+ // Simple case with a single chunk
+ std::vector<Datum> args = {Datum(GetInt32Array(length)), Datum(GetFloat64Array(length)),
+ Datum(std::make_shared<Int32Scalar>(3))};
+ SetupIterator(args);
+
+ ExecBatch batch;
+ ASSERT_TRUE(iterator_->Next(&batch));
+ ASSERT_EQ(3, batch.values.size());
+ ASSERT_EQ(3, batch.num_values());
+ ASSERT_EQ(length, batch.length);
+
+ std::vector<ValueDescr> descrs = batch.GetDescriptors();
+ ASSERT_EQ(ValueDescr::Array(int32()), descrs[0]);
+ ASSERT_EQ(ValueDescr::Array(float64()), descrs[1]);
+ ASSERT_EQ(ValueDescr::Scalar(int32()), descrs[2]);
+
+ AssertArraysEqual(*args[0].make_array(), *batch[0].make_array());
+ AssertArraysEqual(*args[1].make_array(), *batch[1].make_array());
+ ASSERT_TRUE(args[2].scalar()->Equals(batch[2].scalar()));
+
+ ASSERT_EQ(length, iterator_->position());
+ ASSERT_FALSE(iterator_->Next(&batch));
+
+ // Split into chunks of size 16
+ CheckIteration(args, /*chunksize=*/16, {16, 16, 16, 16, 16, 16, 4});
+}
+
+TEST_F(TestExecBatchIterator, InputValidation) {
+ std::vector<Datum> args = {Datum(GetInt32Array(10)), Datum(GetInt32Array(9))};
+ ASSERT_RAISES(Invalid, ExecBatchIterator::Make(args));
+
+ args = {Datum(GetInt32Array(9)), Datum(GetInt32Array(10))};
+ ASSERT_RAISES(Invalid, ExecBatchIterator::Make(args));
+
+ args = {Datum(GetInt32Array(10))};
+ ASSERT_OK_AND_ASSIGN(auto iterator, ExecBatchIterator::Make(args));
+ ASSERT_EQ(10, iterator->max_chunksize());
+}
+
+TEST_F(TestExecBatchIterator, ChunkedArrays) {
+ std::vector<Datum> args = {Datum(GetInt32Chunked({0, 20, 10})),
+ Datum(GetInt32Chunked({15, 15})), Datum(GetInt32Array(30)),
+ Datum(std::make_shared<Int32Scalar>(5)),
+ Datum(MakeNullScalar(boolean()))};
+
+ CheckIteration(args, /*chunksize=*/10, {10, 5, 5, 10});
+ CheckIteration(args, /*chunksize=*/20, {15, 5, 10});
+ CheckIteration(args, /*chunksize=*/30, {15, 5, 10});
+}
+
+TEST_F(TestExecBatchIterator, ZeroLengthInputs) {
+ auto carr = std::shared_ptr<ChunkedArray>(new ChunkedArray({}, int32()));
+
+ auto CheckArgs = [&](const std::vector<Datum>& args) {
+ auto iterator = ExecBatchIterator::Make(args).ValueOrDie();
+ ExecBatch batch;
+ ASSERT_FALSE(iterator->Next(&batch));
+ };
+
+ // Zero-length ChunkedArray with zero chunks
+ std::vector<Datum> args = {Datum(carr)};
+ CheckArgs(args);
+
+ // Zero-length array
+ args = {Datum(GetInt32Array(0))};
+ CheckArgs(args);
+
+ // ChunkedArray with single empty chunk
+ args = {Datum(GetInt32Chunked({0}))};
+ CheckArgs(args);
+}
+
+// ----------------------------------------------------------------------
+// Scalar function execution
+
+Status ExecCopy(KernelContext*, const ExecBatch& batch, Datum* out) {
+ DCHECK_EQ(1, batch.num_values());
+ const auto& type = checked_cast<const FixedWidthType&>(*batch[0].type());
+ int value_size = type.bit_width() / 8;
+
+ const ArrayData& arg0 = *batch[0].array();
+ ArrayData* out_arr = out->mutable_array();
+ uint8_t* dst = out_arr->buffers[1]->mutable_data() + out_arr->offset * value_size;
+ const uint8_t* src = arg0.buffers[1]->data() + arg0.offset * value_size;
+ std::memcpy(dst, src, batch.length * value_size);
+ return Status::OK();
+}
+
+Status ExecComputedBitmap(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // Propagate nulls not used. Check that the out bitmap isn't the same already
+ // as the input bitmap
+ const ArrayData& arg0 = *batch[0].array();
+ ArrayData* out_arr = out->mutable_array();
+
+ if (CountSetBits(arg0.buffers[0]->data(), arg0.offset, batch.length) > 0) {
+ // Check that the bitmap has not been already copied over
+ DCHECK(!BitmapEquals(arg0.buffers[0]->data(), arg0.offset,
+ out_arr->buffers[0]->data(), out_arr->offset, batch.length));
+ }
+
+ CopyBitmap(arg0.buffers[0]->data(), arg0.offset, batch.length,
+ out_arr->buffers[0]->mutable_data(), out_arr->offset);
+ return ExecCopy(ctx, batch, out);
+}
+
+Status ExecNoPreallocatedData(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // Validity preallocated, but not the data
+ ArrayData* out_arr = out->mutable_array();
+ DCHECK_EQ(0, out_arr->offset);
+ const auto& type = checked_cast<const FixedWidthType&>(*batch[0].type());
+ int value_size = type.bit_width() / 8;
+ Status s = (ctx->Allocate(out_arr->length * value_size).Value(&out_arr->buffers[1]));
+ DCHECK_OK(s);
+ return ExecCopy(ctx, batch, out);
+}
+
+Status ExecNoPreallocatedAnything(KernelContext* ctx, const ExecBatch& batch,
+ Datum* out) {
+ // Neither validity nor data preallocated
+ ArrayData* out_arr = out->mutable_array();
+ DCHECK_EQ(0, out_arr->offset);
+ Status s = (ctx->AllocateBitmap(out_arr->length).Value(&out_arr->buffers[0]));
+ DCHECK_OK(s);
+ const ArrayData& arg0 = *batch[0].array();
+ CopyBitmap(arg0.buffers[0]->data(), arg0.offset, batch.length,
+ out_arr->buffers[0]->mutable_data(), /*offset=*/0);
+
+ // Reuse the kernel that allocates the data
+ return ExecNoPreallocatedData(ctx, batch, out);
+}
+
+class ExampleOptions : public FunctionOptions {
+ public:
+ explicit ExampleOptions(std::shared_ptr<Scalar> value);
+ std::shared_ptr<Scalar> value;
+};
+
+class ExampleOptionsType : public FunctionOptionsType {
+ public:
+ static const FunctionOptionsType* GetInstance() {
+ static std::unique_ptr<FunctionOptionsType> instance(new ExampleOptionsType());
+ return instance.get();
+ }
+ const char* type_name() const override { return "example"; }
+ std::string Stringify(const FunctionOptions& options) const override {
+ return type_name();
+ }
+ bool Compare(const FunctionOptions& options,
+ const FunctionOptions& other) const override {
+ return true;
+ }
+ std::unique_ptr<FunctionOptions> Copy(const FunctionOptions& options) const override {
+ const auto& opts = static_cast<const ExampleOptions&>(options);
+ return arrow::internal::make_unique<ExampleOptions>(opts.value);
+ }
+};
+ExampleOptions::ExampleOptions(std::shared_ptr<Scalar> value)
+ : FunctionOptions(ExampleOptionsType::GetInstance()), value(std::move(value)) {}
+
+struct ExampleState : public KernelState {
+ std::shared_ptr<Scalar> value;
+ explicit ExampleState(std::shared_ptr<Scalar> value) : value(std::move(value)) {}
+};
+
+Result<std::unique_ptr<KernelState>> InitStateful(KernelContext*,
+ const KernelInitArgs& args) {
+ auto func_options = static_cast<const ExampleOptions*>(args.options);
+ return std::unique_ptr<KernelState>(new ExampleState{func_options->value});
+}
+
+Status ExecStateful(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // We take the value from the state and multiply the data in batch[0] with it
+ ExampleState* state = static_cast<ExampleState*>(ctx->state());
+ int32_t multiplier = checked_cast<const Int32Scalar&>(*state->value).value;
+
+ const ArrayData& arg0 = *batch[0].array();
+ ArrayData* out_arr = out->mutable_array();
+ const int32_t* arg0_data = arg0.GetValues<int32_t>(1);
+ int32_t* dst = out_arr->GetMutableValues<int32_t>(1);
+ for (int64_t i = 0; i < arg0.length; ++i) {
+ dst[i] = arg0_data[i] * multiplier;
+ }
+ return Status::OK();
+}
+
+Status ExecAddInt32(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const Int32Scalar& arg0 = batch[0].scalar_as<Int32Scalar>();
+ const Int32Scalar& arg1 = batch[1].scalar_as<Int32Scalar>();
+ out->value = std::make_shared<Int32Scalar>(arg0.value + arg1.value);
+ return Status::OK();
+}
+
+class TestCallScalarFunction : public TestComputeInternals {
+ protected:
+ static bool initialized_;
+
+ void SetUp() {
+ TestComputeInternals::SetUp();
+
+ if (!initialized_) {
+ initialized_ = true;
+ AddCopyFunctions();
+ AddNoPreallocateFunctions();
+ AddStatefulFunction();
+ AddScalarFunction();
+ }
+ }
+
+ void AddCopyFunctions() {
+ auto registry = GetFunctionRegistry();
+
+ // This function simply copies memory from the input argument into the
+ // (preallocated) output
+ auto func =
+ std::make_shared<ScalarFunction>("test_copy", Arity::Unary(), /*doc=*/nullptr);
+
+ // Add a few kernels. Our implementation only accepts arrays
+ ASSERT_OK(func->AddKernel({InputType::Array(uint8())}, uint8(), ExecCopy));
+ ASSERT_OK(func->AddKernel({InputType::Array(int32())}, int32(), ExecCopy));
+ ASSERT_OK(func->AddKernel({InputType::Array(float64())}, float64(), ExecCopy));
+ ASSERT_OK(registry->AddFunction(func));
+
+ // A version which doesn't want the executor to call PropagateNulls
+ auto func2 = std::make_shared<ScalarFunction>("test_copy_computed_bitmap",
+ Arity::Unary(), /*doc=*/nullptr);
+ ScalarKernel kernel({InputType::Array(uint8())}, uint8(), ExecComputedBitmap);
+ kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
+ ASSERT_OK(func2->AddKernel(kernel));
+ ASSERT_OK(registry->AddFunction(func2));
+ }
+
+ void AddNoPreallocateFunctions() {
+ auto registry = GetFunctionRegistry();
+
+ // A function that allocates its own output memory. We have cases for both
+ // non-preallocated data and non-preallocated validity bitmap
+ auto f1 = std::make_shared<ScalarFunction>("test_nopre_data", Arity::Unary(),
+ /*doc=*/nullptr);
+ auto f2 = std::make_shared<ScalarFunction>("test_nopre_validity_or_data",
+ Arity::Unary(), /*doc=*/nullptr);
+
+ ScalarKernel kernel({InputType::Array(uint8())}, uint8(), ExecNoPreallocatedData);
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ ASSERT_OK(f1->AddKernel(kernel));
+
+ kernel.exec = ExecNoPreallocatedAnything;
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ ASSERT_OK(f2->AddKernel(kernel));
+
+ ASSERT_OK(registry->AddFunction(f1));
+ ASSERT_OK(registry->AddFunction(f2));
+ }
+
+ void AddStatefulFunction() {
+ auto registry = GetFunctionRegistry();
+
+ // This function's behavior depends on a static parameter that is made
+ // available to the kernel's execution function through its Options object
+ auto func = std::make_shared<ScalarFunction>("test_stateful", Arity::Unary(),
+ /*doc=*/nullptr);
+
+ ScalarKernel kernel({InputType::Array(int32())}, int32(), ExecStateful, InitStateful);
+ ASSERT_OK(func->AddKernel(kernel));
+ ASSERT_OK(registry->AddFunction(func));
+ }
+
+ void AddScalarFunction() {
+ auto registry = GetFunctionRegistry();
+
+ auto func = std::make_shared<ScalarFunction>("test_scalar_add_int32", Arity::Binary(),
+ /*doc=*/nullptr);
+ ASSERT_OK(func->AddKernel({InputType::Scalar(int32()), InputType::Scalar(int32())},
+ int32(), ExecAddInt32));
+ ASSERT_OK(registry->AddFunction(func));
+ }
+};
+
+bool TestCallScalarFunction::initialized_ = false;
+
+TEST_F(TestCallScalarFunction, ArgumentValidation) {
+ // Copy accepts only a single array argument
+ Datum d1(GetInt32Array(10));
+
+ // Too many args
+ std::vector<Datum> args = {d1, d1};
+ ASSERT_RAISES(Invalid, CallFunction("test_copy", args));
+
+ // Too few
+ args = {};
+ ASSERT_RAISES(Invalid, CallFunction("test_copy", args));
+
+ // Cannot do scalar
+ args = {Datum(std::make_shared<Int32Scalar>(5))};
+ ASSERT_RAISES(NotImplemented, CallFunction("test_copy", args));
+}
+
+TEST_F(TestCallScalarFunction, PreallocationCases) {
+ double null_prob = 0.2;
+
+ auto arr = GetUInt8Array(1000, null_prob);
+
+ auto CheckFunction = [&](std::string func_name) {
+ ResetContexts();
+
+ // The default should be a single array output
+ {
+ std::vector<Datum> args = {Datum(arr)};
+ ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(func_name, args));
+ ASSERT_EQ(Datum::ARRAY, result.kind());
+ AssertArraysEqual(*arr, *result.make_array());
+ }
+
+ // Set the exec_chunksize to be smaller, so now we have several invocations
+ // of the kernel, but still the output is onee array
+ {
+ std::vector<Datum> args = {Datum(arr)};
+ exec_ctx_->set_exec_chunksize(80);
+ ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(func_name, args, exec_ctx_.get()));
+ AssertArraysEqual(*arr, *result.make_array());
+ }
+
+ {
+ // Chunksize not multiple of 8
+ std::vector<Datum> args = {Datum(arr)};
+ exec_ctx_->set_exec_chunksize(111);
+ ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(func_name, args, exec_ctx_.get()));
+ AssertArraysEqual(*arr, *result.make_array());
+ }
+
+ // Input is chunked, output has one big chunk
+ {
+ auto carr = std::shared_ptr<ChunkedArray>(
+ new ChunkedArray({arr->Slice(0, 100), arr->Slice(100)}));
+ std::vector<Datum> args = {Datum(carr)};
+ ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(func_name, args, exec_ctx_.get()));
+ std::shared_ptr<ChunkedArray> actual = result.chunked_array();
+ ASSERT_EQ(1, actual->num_chunks());
+ AssertChunkedEquivalent(*carr, *actual);
+ }
+
+ // Preallocate independently for each batch
+ {
+ std::vector<Datum> args = {Datum(arr)};
+ exec_ctx_->set_preallocate_contiguous(false);
+ exec_ctx_->set_exec_chunksize(400);
+ ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(func_name, args, exec_ctx_.get()));
+ ASSERT_EQ(Datum::CHUNKED_ARRAY, result.kind());
+ const ChunkedArray& carr = *result.chunked_array();
+ ASSERT_EQ(3, carr.num_chunks());
+ AssertArraysEqual(*arr->Slice(0, 400), *carr.chunk(0));
+ AssertArraysEqual(*arr->Slice(400, 400), *carr.chunk(1));
+ AssertArraysEqual(*arr->Slice(800), *carr.chunk(2));
+ }
+ };
+
+ CheckFunction("test_copy");
+ CheckFunction("test_copy_computed_bitmap");
+}
+
+TEST_F(TestCallScalarFunction, BasicNonStandardCases) {
+ // Test a handful of cases
+ //
+ // * Validity bitmap computed by kernel rather than using PropagateNulls
+ // * Data not pre-allocated
+ // * Validity bitmap not pre-allocated
+
+ double null_prob = 0.2;
+
+ auto arr = GetUInt8Array(1000, null_prob);
+ std::vector<Datum> args = {Datum(arr)};
+
+ auto CheckFunction = [&](std::string func_name) {
+ ResetContexts();
+
+ // The default should be a single array output
+ {
+ ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(func_name, args));
+ AssertArraysEqual(*arr, *result.make_array(), true);
+ }
+
+ // Split execution into 3 chunks
+ {
+ exec_ctx_->set_exec_chunksize(400);
+ ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(func_name, args, exec_ctx_.get()));
+ ASSERT_EQ(Datum::CHUNKED_ARRAY, result.kind());
+ const ChunkedArray& carr = *result.chunked_array();
+ ASSERT_EQ(3, carr.num_chunks());
+ AssertArraysEqual(*arr->Slice(0, 400), *carr.chunk(0));
+ AssertArraysEqual(*arr->Slice(400, 400), *carr.chunk(1));
+ AssertArraysEqual(*arr->Slice(800), *carr.chunk(2));
+ }
+ };
+
+ CheckFunction("test_nopre_data");
+ CheckFunction("test_nopre_validity_or_data");
+}
+
+TEST_F(TestCallScalarFunction, StatefulKernel) {
+ auto input = ArrayFromJSON(int32(), "[1, 2, 3, null, 5]");
+ auto multiplier = std::make_shared<Int32Scalar>(2);
+ auto expected = ArrayFromJSON(int32(), "[2, 4, 6, null, 10]");
+
+ ExampleOptions options(multiplier);
+ std::vector<Datum> args = {Datum(input)};
+ ASSERT_OK_AND_ASSIGN(Datum result, CallFunction("test_stateful", args, &options));
+ AssertArraysEqual(*expected, *result.make_array());
+}
+
+TEST_F(TestCallScalarFunction, ScalarFunction) {
+ std::vector<Datum> args = {Datum(std::make_shared<Int32Scalar>(5)),
+ Datum(std::make_shared<Int32Scalar>(7))};
+ ASSERT_OK_AND_ASSIGN(Datum result, CallFunction("test_scalar_add_int32", args));
+ ASSERT_EQ(Datum::SCALAR, result.kind());
+
+ auto expected = std::make_shared<Int32Scalar>(12);
+ ASSERT_TRUE(expected->Equals(*result.scalar()));
+}
+
+} // namespace detail
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/function.cc b/src/arrow/cpp/src/arrow/compute/function.cc
new file mode 100644
index 000000000..dda5788c5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/function.cc
@@ -0,0 +1,339 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/function.h"
+
+#include <cstddef>
+#include <memory>
+#include <sstream>
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec_internal.h"
+#include "arrow/compute/function_internal.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/registry.h"
+#include "arrow/datum.h"
+#include "arrow/util/cpu_info.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+Result<std::shared_ptr<Buffer>> FunctionOptionsType::Serialize(
+ const FunctionOptions&) const {
+ return Status::NotImplemented("Serialize for ", type_name());
+}
+
+Result<std::unique_ptr<FunctionOptions>> FunctionOptionsType::Deserialize(
+ const Buffer& buffer) const {
+ return Status::NotImplemented("Deserialize for ", type_name());
+}
+
+std::string FunctionOptions::ToString() const { return options_type()->Stringify(*this); }
+
+bool FunctionOptions::Equals(const FunctionOptions& other) const {
+ if (this == &other) return true;
+ if (options_type() != other.options_type()) return false;
+ return options_type()->Compare(*this, other);
+}
+
+std::unique_ptr<FunctionOptions> FunctionOptions::Copy() const {
+ return options_type()->Copy(*this);
+}
+
+Result<std::shared_ptr<Buffer>> FunctionOptions::Serialize() const {
+ return options_type()->Serialize(*this);
+}
+
+Result<std::unique_ptr<FunctionOptions>> FunctionOptions::Deserialize(
+ const std::string& type_name, const Buffer& buffer) {
+ ARROW_ASSIGN_OR_RAISE(auto options,
+ GetFunctionRegistry()->GetFunctionOptionsType(type_name));
+ return options->Deserialize(buffer);
+}
+
+void PrintTo(const FunctionOptions& options, std::ostream* os) {
+ *os << options.ToString();
+}
+
+static const FunctionDoc kEmptyFunctionDoc{};
+
+const FunctionDoc& FunctionDoc::Empty() { return kEmptyFunctionDoc; }
+
+static Status CheckArityImpl(const Function* function, int passed_num_args,
+ const char* passed_num_args_label) {
+ if (function->arity().is_varargs && passed_num_args < function->arity().num_args) {
+ return Status::Invalid("VarArgs function ", function->name(), " needs at least ",
+ function->arity().num_args, " arguments but ",
+ passed_num_args_label, " only ", passed_num_args);
+ }
+
+ if (!function->arity().is_varargs && passed_num_args != function->arity().num_args) {
+ return Status::Invalid("Function ", function->name(), " accepts ",
+ function->arity().num_args, " arguments but ",
+ passed_num_args_label, " ", passed_num_args);
+ }
+
+ return Status::OK();
+}
+
+Status Function::CheckArity(const std::vector<InputType>& in_types) const {
+ return CheckArityImpl(this, static_cast<int>(in_types.size()), "kernel accepts");
+}
+
+Status Function::CheckArity(const std::vector<ValueDescr>& descrs) const {
+ return CheckArityImpl(this, static_cast<int>(descrs.size()),
+ "attempted to look up kernel(s) with");
+}
+
+namespace detail {
+
+Status NoMatchingKernel(const Function* func, const std::vector<ValueDescr>& descrs) {
+ return Status::NotImplemented("Function ", func->name(),
+ " has no kernel matching input types ",
+ ValueDescr::ToString(descrs));
+}
+
+template <typename KernelType>
+const KernelType* DispatchExactImpl(const std::vector<KernelType*>& kernels,
+ const std::vector<ValueDescr>& values) {
+ const KernelType* kernel_matches[SimdLevel::MAX] = {nullptr};
+
+ // Validate arity
+ for (const auto& kernel : kernels) {
+ if (kernel->signature->MatchesInputs(values)) {
+ kernel_matches[kernel->simd_level] = kernel;
+ }
+ }
+
+ // Dispatch as the CPU feature
+#if defined(ARROW_HAVE_RUNTIME_AVX512) || defined(ARROW_HAVE_RUNTIME_AVX2)
+ auto cpu_info = arrow::internal::CpuInfo::GetInstance();
+#endif
+#if defined(ARROW_HAVE_RUNTIME_AVX512)
+ if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) {
+ if (kernel_matches[SimdLevel::AVX512]) {
+ return kernel_matches[SimdLevel::AVX512];
+ }
+ }
+#endif
+#if defined(ARROW_HAVE_RUNTIME_AVX2)
+ if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) {
+ if (kernel_matches[SimdLevel::AVX2]) {
+ return kernel_matches[SimdLevel::AVX2];
+ }
+ }
+#endif
+ if (kernel_matches[SimdLevel::NONE]) {
+ return kernel_matches[SimdLevel::NONE];
+ }
+
+ return nullptr;
+}
+
+const Kernel* DispatchExactImpl(const Function* func,
+ const std::vector<ValueDescr>& values) {
+ if (func->kind() == Function::SCALAR) {
+ return DispatchExactImpl(checked_cast<const ScalarFunction*>(func)->kernels(),
+ values);
+ }
+
+ if (func->kind() == Function::VECTOR) {
+ return DispatchExactImpl(checked_cast<const VectorFunction*>(func)->kernels(),
+ values);
+ }
+
+ if (func->kind() == Function::SCALAR_AGGREGATE) {
+ return DispatchExactImpl(
+ checked_cast<const ScalarAggregateFunction*>(func)->kernels(), values);
+ }
+
+ if (func->kind() == Function::HASH_AGGREGATE) {
+ return DispatchExactImpl(checked_cast<const HashAggregateFunction*>(func)->kernels(),
+ values);
+ }
+
+ return nullptr;
+}
+
+} // namespace detail
+
+Result<const Kernel*> Function::DispatchExact(
+ const std::vector<ValueDescr>& values) const {
+ if (kind_ == Function::META) {
+ return Status::NotImplemented("Dispatch for a MetaFunction's Kernels");
+ }
+ RETURN_NOT_OK(CheckArity(values));
+
+ if (auto kernel = detail::DispatchExactImpl(this, values)) {
+ return kernel;
+ }
+ return detail::NoMatchingKernel(this, values);
+}
+
+Result<const Kernel*> Function::DispatchBest(std::vector<ValueDescr>* values) const {
+ // TODO(ARROW-11508) permit generic conversions here
+ return DispatchExact(*values);
+}
+
+Result<Datum> Function::Execute(const std::vector<Datum>& args,
+ const FunctionOptions* options, ExecContext* ctx) const {
+ if (options == nullptr) {
+ options = default_options();
+ }
+ if (ctx == nullptr) {
+ ExecContext default_ctx;
+ return Execute(args, options, &default_ctx);
+ }
+
+ // type-check Datum arguments here. Really we'd like to avoid this as much as
+ // possible
+ RETURN_NOT_OK(detail::CheckAllValues(args));
+ std::vector<ValueDescr> inputs(args.size());
+ for (size_t i = 0; i != args.size(); ++i) {
+ inputs[i] = args[i].descr();
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto kernel, DispatchBest(&inputs));
+ ARROW_ASSIGN_OR_RAISE(auto implicitly_cast_args, Cast(args, inputs, ctx));
+
+ std::unique_ptr<KernelState> state;
+
+ KernelContext kernel_ctx{ctx};
+ if (kernel->init) {
+ ARROW_ASSIGN_OR_RAISE(state, kernel->init(&kernel_ctx, {kernel, inputs, options}));
+ kernel_ctx.SetState(state.get());
+ }
+
+ std::unique_ptr<detail::KernelExecutor> executor;
+ if (kind() == Function::SCALAR) {
+ executor = detail::KernelExecutor::MakeScalar();
+ } else if (kind() == Function::VECTOR) {
+ executor = detail::KernelExecutor::MakeVector();
+ } else if (kind() == Function::SCALAR_AGGREGATE) {
+ executor = detail::KernelExecutor::MakeScalarAggregate();
+ } else {
+ return Status::NotImplemented("Direct execution of HASH_AGGREGATE functions");
+ }
+ RETURN_NOT_OK(executor->Init(&kernel_ctx, {kernel, inputs, options}));
+
+ detail::DatumAccumulator listener;
+ RETURN_NOT_OK(executor->Execute(implicitly_cast_args, &listener));
+ const auto out = executor->WrapResults(implicitly_cast_args, listener.values());
+#ifndef NDEBUG
+ DCHECK_OK(executor->CheckResultType(out, name_.c_str()));
+#endif
+ return out;
+}
+
+Status Function::Validate() const {
+ if (!doc_->summary.empty()) {
+ // Documentation given, check its contents
+ int arg_count = static_cast<int>(doc_->arg_names.size());
+ if (arg_count == arity_.num_args) {
+ return Status::OK();
+ }
+ if (arity_.is_varargs && arg_count == arity_.num_args + 1) {
+ return Status::OK();
+ }
+ return Status::Invalid(
+ "In function '", name_,
+ "': ", "number of argument names for function documentation != function arity");
+ }
+ return Status::OK();
+}
+
+Status ScalarFunction::AddKernel(std::vector<InputType> in_types, OutputType out_type,
+ ArrayKernelExec exec, KernelInit init) {
+ RETURN_NOT_OK(CheckArity(in_types));
+
+ if (arity_.is_varargs && in_types.size() != 1) {
+ return Status::Invalid("VarArgs signatures must have exactly one input type");
+ }
+ auto sig =
+ KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs);
+ kernels_.emplace_back(std::move(sig), exec, init);
+ return Status::OK();
+}
+
+Status ScalarFunction::AddKernel(ScalarKernel kernel) {
+ RETURN_NOT_OK(CheckArity(kernel.signature->in_types()));
+ if (arity_.is_varargs && !kernel.signature->is_varargs()) {
+ return Status::Invalid("Function accepts varargs but kernel signature does not");
+ }
+ kernels_.emplace_back(std::move(kernel));
+ return Status::OK();
+}
+
+Status VectorFunction::AddKernel(std::vector<InputType> in_types, OutputType out_type,
+ ArrayKernelExec exec, KernelInit init) {
+ RETURN_NOT_OK(CheckArity(in_types));
+
+ if (arity_.is_varargs && in_types.size() != 1) {
+ return Status::Invalid("VarArgs signatures must have exactly one input type");
+ }
+ auto sig =
+ KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs);
+ kernels_.emplace_back(std::move(sig), exec, init);
+ return Status::OK();
+}
+
+Status VectorFunction::AddKernel(VectorKernel kernel) {
+ RETURN_NOT_OK(CheckArity(kernel.signature->in_types()));
+ if (arity_.is_varargs && !kernel.signature->is_varargs()) {
+ return Status::Invalid("Function accepts varargs but kernel signature does not");
+ }
+ kernels_.emplace_back(std::move(kernel));
+ return Status::OK();
+}
+
+Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) {
+ RETURN_NOT_OK(CheckArity(kernel.signature->in_types()));
+ if (arity_.is_varargs && !kernel.signature->is_varargs()) {
+ return Status::Invalid("Function accepts varargs but kernel signature does not");
+ }
+ kernels_.emplace_back(std::move(kernel));
+ return Status::OK();
+}
+
+Status HashAggregateFunction::AddKernel(HashAggregateKernel kernel) {
+ RETURN_NOT_OK(CheckArity(kernel.signature->in_types()));
+ if (arity_.is_varargs && !kernel.signature->is_varargs()) {
+ return Status::Invalid("Function accepts varargs but kernel signature does not");
+ }
+ kernels_.emplace_back(std::move(kernel));
+ return Status::OK();
+}
+
+Result<Datum> MetaFunction::Execute(const std::vector<Datum>& args,
+ const FunctionOptions* options,
+ ExecContext* ctx) const {
+ RETURN_NOT_OK(
+ CheckArityImpl(this, static_cast<int>(args.size()), "attempted to Execute with"));
+
+ if (options == nullptr) {
+ options = default_options();
+ }
+ return ExecuteImpl(args, options, ctx);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/function.h b/src/arrow/cpp/src/arrow/compute/function.h
new file mode 100644
index 000000000..f08b50699
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/function.h
@@ -0,0 +1,395 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// NOTE: API is EXPERIMENTAL and will change without going through a
+// deprecation cycle.
+
+#pragma once
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/compute/kernel.h"
+#include "arrow/compute/type_fwd.h"
+#include "arrow/datum.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/compare.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace compute {
+
+/// \defgroup compute-functions Abstract compute function API
+///
+/// @{
+
+/// \brief Extension point for defining options outside libarrow (but
+/// still within this project).
+class ARROW_EXPORT FunctionOptionsType {
+ public:
+ virtual ~FunctionOptionsType() = default;
+
+ virtual const char* type_name() const = 0;
+ virtual std::string Stringify(const FunctionOptions&) const = 0;
+ virtual bool Compare(const FunctionOptions&, const FunctionOptions&) const = 0;
+ virtual Result<std::shared_ptr<Buffer>> Serialize(const FunctionOptions&) const;
+ virtual Result<std::unique_ptr<FunctionOptions>> Deserialize(
+ const Buffer& buffer) const;
+ virtual std::unique_ptr<FunctionOptions> Copy(const FunctionOptions&) const = 0;
+};
+
+/// \brief Base class for specifying options configuring a function's behavior,
+/// such as error handling.
+class ARROW_EXPORT FunctionOptions : public util::EqualityComparable<FunctionOptions> {
+ public:
+ virtual ~FunctionOptions() = default;
+
+ const FunctionOptionsType* options_type() const { return options_type_; }
+ const char* type_name() const { return options_type()->type_name(); }
+
+ bool Equals(const FunctionOptions& other) const;
+ using util::EqualityComparable<FunctionOptions>::Equals;
+ using util::EqualityComparable<FunctionOptions>::operator==;
+ using util::EqualityComparable<FunctionOptions>::operator!=;
+ std::string ToString() const;
+ std::unique_ptr<FunctionOptions> Copy() const;
+ /// \brief Serialize an options struct to a buffer.
+ Result<std::shared_ptr<Buffer>> Serialize() const;
+ /// \brief Deserialize an options struct from a buffer.
+ /// Note: this will only look for `type_name` in the default FunctionRegistry;
+ /// to use a custom FunctionRegistry, look up the FunctionOptionsType, then
+ /// call FunctionOptionsType::Deserialize().
+ static Result<std::unique_ptr<FunctionOptions>> Deserialize(
+ const std::string& type_name, const Buffer& buffer);
+
+ protected:
+ explicit FunctionOptions(const FunctionOptionsType* type) : options_type_(type) {}
+ const FunctionOptionsType* options_type_;
+};
+
+ARROW_EXPORT void PrintTo(const FunctionOptions&, std::ostream*);
+
+/// \brief Contains the number of required arguments for the function.
+///
+/// Naming conventions taken from https://en.wikipedia.org/wiki/Arity.
+struct ARROW_EXPORT Arity {
+ /// \brief A function taking no arguments
+ static Arity Nullary() { return Arity(0, false); }
+
+ /// \brief A function taking 1 argument
+ static Arity Unary() { return Arity(1, false); }
+
+ /// \brief A function taking 2 arguments
+ static Arity Binary() { return Arity(2, false); }
+
+ /// \brief A function taking 3 arguments
+ static Arity Ternary() { return Arity(3, false); }
+
+ /// \brief A function taking a variable number of arguments
+ ///
+ /// \param[in] min_args the minimum number of arguments required when
+ /// invoking the function
+ static Arity VarArgs(int min_args = 0) { return Arity(min_args, true); }
+
+ // NOTE: the 0-argument form (default constructor) is required for Cython
+ explicit Arity(int num_args = 0, bool is_varargs = false)
+ : num_args(num_args), is_varargs(is_varargs) {}
+
+ /// The number of required arguments (or the minimum number for varargs
+ /// functions).
+ int num_args;
+
+ /// If true, then the num_args is the minimum number of required arguments.
+ bool is_varargs = false;
+};
+
+struct ARROW_EXPORT FunctionDoc {
+ /// \brief A one-line summary of the function, using a verb.
+ ///
+ /// For example, "Add two numeric arrays or scalars".
+ std::string summary;
+
+ /// \brief A detailed description of the function, meant to follow the summary.
+ std::string description;
+
+ /// \brief Symbolic names (identifiers) for the function arguments.
+ ///
+ /// Some bindings may use this to generate nicer function signatures.
+ std::vector<std::string> arg_names;
+
+ // TODO add argument descriptions?
+
+ /// \brief Name of the options class, if any.
+ std::string options_class;
+
+ FunctionDoc() = default;
+
+ FunctionDoc(std::string summary, std::string description,
+ std::vector<std::string> arg_names, std::string options_class = "")
+ : summary(std::move(summary)),
+ description(std::move(description)),
+ arg_names(std::move(arg_names)),
+ options_class(std::move(options_class)) {}
+
+ static const FunctionDoc& Empty();
+};
+
+/// \brief Base class for compute functions. Function implementations contain a
+/// collection of "kernels" which are implementations of the function for
+/// specific argument types. Selecting a viable kernel for executing a function
+/// is referred to as "dispatching".
+class ARROW_EXPORT Function {
+ public:
+ /// \brief The kind of function, which indicates in what contexts it is
+ /// valid for use.
+ enum Kind {
+ /// A function that performs scalar data operations on whole arrays of
+ /// data. Can generally process Array or Scalar values. The size of the
+ /// output will be the same as the size (or broadcasted size, in the case
+ /// of mixing Array and Scalar inputs) of the input.
+ SCALAR,
+
+ /// A function with array input and output whose behavior depends on the
+ /// values of the entire arrays passed, rather than the value of each scalar
+ /// value.
+ VECTOR,
+
+ /// A function that computes scalar summary statistics from array input.
+ SCALAR_AGGREGATE,
+
+ /// A function that computes grouped summary statistics from array input
+ /// and an array of group identifiers.
+ HASH_AGGREGATE,
+
+ /// A function that dispatches to other functions and does not contain its
+ /// own kernels.
+ META
+ };
+
+ virtual ~Function() = default;
+
+ /// \brief The name of the kernel. The registry enforces uniqueness of names.
+ const std::string& name() const { return name_; }
+
+ /// \brief The kind of kernel, which indicates in what contexts it is valid
+ /// for use.
+ Function::Kind kind() const { return kind_; }
+
+ /// \brief Contains the number of arguments the function requires, or if the
+ /// function accepts variable numbers of arguments.
+ const Arity& arity() const { return arity_; }
+
+ /// \brief Return the function documentation
+ const FunctionDoc& doc() const { return *doc_; }
+
+ /// \brief Returns the number of registered kernels for this function.
+ virtual int num_kernels() const = 0;
+
+ /// \brief Return a kernel that can execute the function given the exact
+ /// argument types (without implicit type casts or scalar->array promotions).
+ ///
+ /// NB: This function is overridden in CastFunction.
+ virtual Result<const Kernel*> DispatchExact(
+ const std::vector<ValueDescr>& values) const;
+
+ /// \brief Return a best-match kernel that can execute the function given the argument
+ /// types, after implicit casts are applied.
+ ///
+ /// \param[in,out] values Argument types. An element may be modified to indicate that
+ /// the returned kernel only approximately matches the input value descriptors; callers
+ /// are responsible for casting inputs to the type and shape required by the kernel.
+ virtual Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const;
+
+ /// \brief Execute the function eagerly with the passed input arguments with
+ /// kernel dispatch, batch iteration, and memory allocation details taken
+ /// care of.
+ ///
+ /// If the `options` pointer is null, then `default_options()` will be used.
+ ///
+ /// This function can be overridden in subclasses.
+ virtual Result<Datum> Execute(const std::vector<Datum>& args,
+ const FunctionOptions* options, ExecContext* ctx) const;
+
+ /// \brief Returns the default options for this function.
+ ///
+ /// Whatever option semantics a Function has, implementations must guarantee
+ /// that default_options() is valid to pass to Execute as options.
+ const FunctionOptions* default_options() const { return default_options_; }
+
+ virtual Status Validate() const;
+
+ protected:
+ Function(std::string name, Function::Kind kind, const Arity& arity,
+ const FunctionDoc* doc, const FunctionOptions* default_options)
+ : name_(std::move(name)),
+ kind_(kind),
+ arity_(arity),
+ doc_(doc ? doc : &FunctionDoc::Empty()),
+ default_options_(default_options) {}
+
+ Status CheckArity(const std::vector<InputType>&) const;
+ Status CheckArity(const std::vector<ValueDescr>&) const;
+
+ std::string name_;
+ Function::Kind kind_;
+ Arity arity_;
+ const FunctionDoc* doc_;
+ const FunctionOptions* default_options_ = NULLPTR;
+};
+
+namespace detail {
+
+template <typename KernelType>
+class FunctionImpl : public Function {
+ public:
+ /// \brief Return pointers to current-available kernels for inspection
+ std::vector<const KernelType*> kernels() const {
+ std::vector<const KernelType*> result;
+ for (const auto& kernel : kernels_) {
+ result.push_back(&kernel);
+ }
+ return result;
+ }
+
+ int num_kernels() const override { return static_cast<int>(kernels_.size()); }
+
+ protected:
+ FunctionImpl(std::string name, Function::Kind kind, const Arity& arity,
+ const FunctionDoc* doc, const FunctionOptions* default_options)
+ : Function(std::move(name), kind, arity, doc, default_options) {}
+
+ std::vector<KernelType> kernels_;
+};
+
+/// \brief Look up a kernel in a function. If no Kernel is found, nullptr is returned.
+ARROW_EXPORT
+const Kernel* DispatchExactImpl(const Function* func, const std::vector<ValueDescr>&);
+
+/// \brief Return an error message if no Kernel is found.
+ARROW_EXPORT
+Status NoMatchingKernel(const Function* func, const std::vector<ValueDescr>&);
+
+} // namespace detail
+
+/// \brief A function that executes elementwise operations on arrays or
+/// scalars, and therefore whose results generally do not depend on the order
+/// of the values in the arguments. Accepts and returns arrays that are all of
+/// the same size. These functions roughly correspond to the functions used in
+/// SQL expressions.
+class ARROW_EXPORT ScalarFunction : public detail::FunctionImpl<ScalarKernel> {
+ public:
+ using KernelType = ScalarKernel;
+
+ ScalarFunction(std::string name, const Arity& arity, const FunctionDoc* doc,
+ const FunctionOptions* default_options = NULLPTR)
+ : detail::FunctionImpl<ScalarKernel>(std::move(name), Function::SCALAR, arity, doc,
+ default_options) {}
+
+ /// \brief Add a kernel with given input/output types, no required state
+ /// initialization, preallocation for fixed-width types, and default null
+ /// handling (intersect validity bitmaps of inputs).
+ Status AddKernel(std::vector<InputType> in_types, OutputType out_type,
+ ArrayKernelExec exec, KernelInit init = NULLPTR);
+
+ /// \brief Add a kernel (function implementation). Returns error if the
+ /// kernel's signature does not match the function's arity.
+ Status AddKernel(ScalarKernel kernel);
+};
+
+/// \brief A function that executes general array operations that may yield
+/// outputs of different sizes or have results that depend on the whole array
+/// contents. These functions roughly correspond to the functions found in
+/// non-SQL array languages like APL and its derivatives.
+class ARROW_EXPORT VectorFunction : public detail::FunctionImpl<VectorKernel> {
+ public:
+ using KernelType = VectorKernel;
+
+ VectorFunction(std::string name, const Arity& arity, const FunctionDoc* doc,
+ const FunctionOptions* default_options = NULLPTR)
+ : detail::FunctionImpl<VectorKernel>(std::move(name), Function::VECTOR, arity, doc,
+ default_options) {}
+
+ /// \brief Add a simple kernel with given input/output types, no required
+ /// state initialization, no data preallocation, and no preallocation of the
+ /// validity bitmap.
+ Status AddKernel(std::vector<InputType> in_types, OutputType out_type,
+ ArrayKernelExec exec, KernelInit init = NULLPTR);
+
+ /// \brief Add a kernel (function implementation). Returns error if the
+ /// kernel's signature does not match the function's arity.
+ Status AddKernel(VectorKernel kernel);
+};
+
+class ARROW_EXPORT ScalarAggregateFunction
+ : public detail::FunctionImpl<ScalarAggregateKernel> {
+ public:
+ using KernelType = ScalarAggregateKernel;
+
+ ScalarAggregateFunction(std::string name, const Arity& arity, const FunctionDoc* doc,
+ const FunctionOptions* default_options = NULLPTR)
+ : detail::FunctionImpl<ScalarAggregateKernel>(
+ std::move(name), Function::SCALAR_AGGREGATE, arity, doc, default_options) {}
+
+ /// \brief Add a kernel (function implementation). Returns error if the
+ /// kernel's signature does not match the function's arity.
+ Status AddKernel(ScalarAggregateKernel kernel);
+};
+
+class ARROW_EXPORT HashAggregateFunction
+ : public detail::FunctionImpl<HashAggregateKernel> {
+ public:
+ using KernelType = HashAggregateKernel;
+
+ HashAggregateFunction(std::string name, const Arity& arity, const FunctionDoc* doc,
+ const FunctionOptions* default_options = NULLPTR)
+ : detail::FunctionImpl<HashAggregateKernel>(
+ std::move(name), Function::HASH_AGGREGATE, arity, doc, default_options) {}
+
+ /// \brief Add a kernel (function implementation). Returns error if the
+ /// kernel's signature does not match the function's arity.
+ Status AddKernel(HashAggregateKernel kernel);
+};
+
+/// \brief A function that dispatches to other functions. Must implement
+/// MetaFunction::ExecuteImpl.
+///
+/// For Array, ChunkedArray, and Scalar Datum kinds, may rely on the execution
+/// of concrete Function types, but must handle other Datum kinds on its own.
+class ARROW_EXPORT MetaFunction : public Function {
+ public:
+ int num_kernels() const override { return 0; }
+
+ Result<Datum> Execute(const std::vector<Datum>& args, const FunctionOptions* options,
+ ExecContext* ctx) const override;
+
+ protected:
+ virtual Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options,
+ ExecContext* ctx) const = 0;
+
+ MetaFunction(std::string name, const Arity& arity, const FunctionDoc* doc,
+ const FunctionOptions* default_options = NULLPTR)
+ : Function(std::move(name), Function::META, arity, doc, default_options) {}
+};
+
+/// @}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/function_benchmark.cc b/src/arrow/cpp/src/arrow/compute/function_benchmark.cc
new file mode 100644
index 000000000..a29a766be
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/function_benchmark.cc
@@ -0,0 +1,218 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/array/array_base.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/exec_internal.h"
+#include "arrow/memory_pool.h"
+#include "arrow/scalar.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/benchmark_util.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+constexpr int32_t kSeed = 0xfede4a7e;
+constexpr int64_t kScalarCount = 1 << 10;
+
+inline ScalarVector ToScalars(std::shared_ptr<Array> arr) {
+ ScalarVector scalars{static_cast<size_t>(arr->length())};
+ int64_t i = 0;
+ for (auto& scalar : scalars) {
+ scalar = arr->GetScalar(i++).ValueOrDie();
+ }
+ return scalars;
+}
+
+void BM_CastDispatch(benchmark::State& state) {
+ // Repeatedly invoke a trivial Cast: the main cost should be dispatch
+ random::RandomArrayGenerator rag(kSeed);
+
+ auto int_scalars = ToScalars(rag.Int64(kScalarCount, 0, 1 << 20));
+
+ auto double_type = float64();
+ for (auto _ : state) {
+ Datum timestamp_scalar;
+ for (Datum int_scalar : int_scalars) {
+ ASSERT_OK_AND_ASSIGN(timestamp_scalar, Cast(int_scalar, double_type));
+ }
+ benchmark::DoNotOptimize(timestamp_scalar);
+ }
+
+ state.SetItemsProcessed(state.iterations() * kScalarCount);
+}
+
+void BM_CastDispatchBaseline(benchmark::State& state) {
+ // Repeatedly invoke a trivial Cast with all dispatch outside the hot loop
+ random::RandomArrayGenerator rag(kSeed);
+
+ auto int_scalars = ToScalars(rag.Int64(kScalarCount, 0, 1 << 20));
+
+ auto double_type = float64();
+ CastOptions cast_options;
+ cast_options.to_type = double_type;
+ ASSERT_OK_AND_ASSIGN(auto cast_function, GetCastFunction(double_type));
+ ASSERT_OK_AND_ASSIGN(auto cast_kernel,
+ cast_function->DispatchExact({int_scalars[0]->type}));
+ const auto& exec = static_cast<const ScalarKernel*>(cast_kernel)->exec;
+
+ ExecContext exec_context;
+ KernelContext kernel_context(&exec_context);
+ auto cast_state = cast_kernel
+ ->init(&kernel_context,
+ KernelInitArgs{cast_kernel, {double_type}, &cast_options})
+ .ValueOrDie();
+ kernel_context.SetState(cast_state.get());
+
+ for (auto _ : state) {
+ Datum timestamp_scalar = MakeNullScalar(double_type);
+ for (Datum int_scalar : int_scalars) {
+ ABORT_NOT_OK(
+ exec(&kernel_context, {{std::move(int_scalar)}, 1}, &timestamp_scalar));
+ }
+ benchmark::DoNotOptimize(timestamp_scalar);
+ }
+
+ state.SetItemsProcessed(state.iterations() * kScalarCount);
+}
+
+void BM_AddDispatch(benchmark::State& state) {
+ ExecContext exec_context;
+ KernelContext kernel_context(&exec_context);
+
+ for (auto _ : state) {
+ ASSERT_OK_AND_ASSIGN(auto add_function, GetFunctionRegistry()->GetFunction("add"));
+ ASSERT_OK_AND_ASSIGN(auto add_kernel,
+ checked_cast<const ScalarFunction&>(*add_function)
+ .DispatchExact({int64(), int64()}));
+ benchmark::DoNotOptimize(add_kernel);
+ }
+
+ state.SetItemsProcessed(state.iterations());
+}
+
+static ScalarVector MakeScalarsForIsValid(int64_t nitems) {
+ std::vector<std::shared_ptr<Scalar>> scalars;
+ scalars.reserve(nitems);
+ for (int64_t i = 0; i < nitems; ++i) {
+ if (i & 0x10) {
+ scalars.emplace_back(MakeNullScalar(int64()));
+ } else {
+ scalars.emplace_back(*MakeScalar(int64(), i));
+ }
+ }
+ return scalars;
+}
+
+void BM_ExecuteScalarFunctionOnScalar(benchmark::State& state) {
+ // Execute a trivial function, with argument dispatch in the hot path
+ const int64_t N = 10000;
+
+ auto function = checked_pointer_cast<ScalarFunction>(
+ *GetFunctionRegistry()->GetFunction("is_valid"));
+ const auto scalars = MakeScalarsForIsValid(N);
+
+ ExecContext exec_context;
+ KernelContext kernel_context(&exec_context);
+
+ for (auto _ : state) {
+ int64_t total = 0;
+ for (const auto& scalar : scalars) {
+ const Datum result =
+ *function->Execute({Datum(scalar)}, function->default_options(), &exec_context);
+ total += result.scalar()->is_valid;
+ }
+ benchmark::DoNotOptimize(total);
+ }
+
+ state.SetItemsProcessed(state.iterations() * N);
+}
+
+void BM_ExecuteScalarKernelOnScalar(benchmark::State& state) {
+ // Execute a trivial function, with argument dispatch outside the hot path
+ const int64_t N = 10000;
+
+ auto function = *GetFunctionRegistry()->GetFunction("is_valid");
+ auto kernel = *function->DispatchExact({ValueDescr::Scalar(int64())});
+ const auto& exec = static_cast<const ScalarKernel&>(*kernel).exec;
+
+ const auto scalars = MakeScalarsForIsValid(N);
+
+ ExecContext exec_context;
+ KernelContext kernel_context(&exec_context);
+
+ for (auto _ : state) {
+ int64_t total = 0;
+ for (const auto& scalar : scalars) {
+ Datum result{MakeNullScalar(int64())};
+ ABORT_NOT_OK(exec(&kernel_context, ExecBatch{{scalar}, /*length=*/1}, &result));
+ total += result.scalar()->is_valid;
+ }
+ benchmark::DoNotOptimize(total);
+ }
+
+ state.SetItemsProcessed(state.iterations() * N);
+}
+
+void BM_ExecBatchIterator(benchmark::State& state) {
+ // Measure overhead related to splitting ExecBatch into smaller ExecBatches
+ // for parallelism or more optimal CPU cache affinity
+ random::RandomArrayGenerator rag(kSeed);
+
+ const int64_t length = 1 << 20;
+ const int num_fields = 32;
+
+ std::vector<Datum> args(num_fields);
+ for (int i = 0; i < num_fields; ++i) {
+ args[i] = rag.Int64(length, 0, 100)->data();
+ }
+
+ const int64_t blocksize = state.range(0);
+ for (auto _ : state) {
+ std::unique_ptr<detail::ExecBatchIterator> it =
+ *detail::ExecBatchIterator::Make(args, blocksize);
+ ExecBatch batch;
+ while (it->Next(&batch)) {
+ for (int i = 0; i < num_fields; ++i) {
+ auto data = batch.values[i].array()->buffers[1]->data();
+ benchmark::DoNotOptimize(data);
+ }
+ }
+ benchmark::DoNotOptimize(batch);
+ }
+ // Provides comparability across blocksizes by looking at the iterations per
+ // second. So 1000 iterations/second means that input splitting associated
+ // with ExecBatchIterator takes up 1ms every time.
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(BM_CastDispatch);
+BENCHMARK(BM_CastDispatchBaseline);
+BENCHMARK(BM_AddDispatch);
+BENCHMARK(BM_ExecuteScalarFunctionOnScalar);
+BENCHMARK(BM_ExecuteScalarKernelOnScalar);
+BENCHMARK(BM_ExecBatchIterator)->RangeMultiplier(4)->Range(1024, 64 * 1024);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/function_internal.cc b/src/arrow/cpp/src/arrow/compute/function_internal.cc
new file mode 100644
index 000000000..0a926e0a3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/function_internal.cc
@@ -0,0 +1,113 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/function_internal.h"
+
+#include "arrow/array/util.h"
+#include "arrow/compute/function.h"
+#include "arrow/compute/registry.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/scalar.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+using ::arrow::internal::checked_cast;
+
+constexpr char kTypeNameField[] = "_type_name";
+
+Result<std::shared_ptr<StructScalar>> FunctionOptionsToStructScalar(
+ const FunctionOptions& options) {
+ std::vector<std::string> field_names;
+ std::vector<std::shared_ptr<Scalar>> values;
+ const auto* options_type =
+ dynamic_cast<const GenericOptionsType*>(options.options_type());
+ if (!options_type) {
+ return Status::NotImplemented("serializing ", options.type_name(),
+ " to StructScalar");
+ }
+ RETURN_NOT_OK(options_type->ToStructScalar(options, &field_names, &values));
+ field_names.push_back(kTypeNameField);
+ const char* options_name = options.type_name();
+ values.emplace_back(
+ new BinaryScalar(Buffer::Wrap(options_name, std::strlen(options_name))));
+ return StructScalar::Make(std::move(values), std::move(field_names));
+}
+
+Result<std::unique_ptr<FunctionOptions>> FunctionOptionsFromStructScalar(
+ const StructScalar& scalar) {
+ ARROW_ASSIGN_OR_RAISE(auto type_name_holder, scalar.field(kTypeNameField));
+ const std::string type_name =
+ checked_cast<const BinaryScalar&>(*type_name_holder).value->ToString();
+ ARROW_ASSIGN_OR_RAISE(auto raw_options_type,
+ GetFunctionRegistry()->GetFunctionOptionsType(type_name));
+ const auto* options_type = checked_cast<const GenericOptionsType*>(raw_options_type);
+ return options_type->FromStructScalar(scalar);
+}
+
+Result<std::shared_ptr<Buffer>> GenericOptionsType::Serialize(
+ const FunctionOptions& options) const {
+ ARROW_ASSIGN_OR_RAISE(auto scalar, FunctionOptionsToStructScalar(options));
+ ARROW_ASSIGN_OR_RAISE(auto array, MakeArrayFromScalar(*scalar, 1));
+ auto batch =
+ RecordBatch::Make(schema({field("", array->type())}), /*num_rows=*/1, {array});
+ ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create());
+ ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(stream, batch->schema()));
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+ RETURN_NOT_OK(writer->Close());
+ return stream->Finish();
+}
+
+Result<std::unique_ptr<FunctionOptions>> GenericOptionsType::Deserialize(
+ const Buffer& buffer) const {
+ return DeserializeFunctionOptions(buffer);
+}
+
+Result<std::unique_ptr<FunctionOptions>> DeserializeFunctionOptions(
+ const Buffer& buffer) {
+ io::BufferReader stream(buffer);
+ ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream));
+ ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0));
+ if (batch->num_rows() != 1) {
+ return Status::Invalid(
+ "serialized FunctionOptions's batch repr was not a single row - had ",
+ batch->num_rows());
+ }
+ if (batch->num_columns() != 1) {
+ return Status::Invalid(
+ "serialized FunctionOptions's batch repr was not a single column - had ",
+ batch->num_columns());
+ }
+ auto column = batch->column(0);
+ if (column->type()->id() != Type::STRUCT) {
+ return Status::Invalid(
+ "serialized FunctionOptions's batch repr was not a struct column - was ",
+ column->type()->ToString());
+ }
+ ARROW_ASSIGN_OR_RAISE(auto raw_scalar,
+ checked_cast<const StructArray&>(*column).GetScalar(0));
+ auto scalar = checked_cast<const StructScalar&>(*raw_scalar);
+ return FunctionOptionsFromStructScalar(scalar);
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/function_internal.h b/src/arrow/cpp/src/arrow/compute/function_internal.h
new file mode 100644
index 000000000..587b9396b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/function_internal.h
@@ -0,0 +1,648 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "arrow/array/builder_base.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_nested.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/function.h"
+#include "arrow/compute/type_fwd.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/reflection_internal.h"
+#include "arrow/util/string.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+struct Scalar;
+struct StructScalar;
+using ::arrow::internal::checked_cast;
+
+namespace internal {
+template <>
+struct EnumTraits<compute::SortOrder>
+ : BasicEnumTraits<compute::SortOrder, compute::SortOrder::Ascending,
+ compute::SortOrder::Descending> {
+ static std::string name() { return "SortOrder"; }
+ static std::string value_name(compute::SortOrder value) {
+ switch (value) {
+ case compute::SortOrder::Ascending:
+ return "Ascending";
+ case compute::SortOrder::Descending:
+ return "Descending";
+ }
+ return "<INVALID>";
+ }
+};
+} // namespace internal
+
+namespace compute {
+namespace internal {
+
+using arrow::internal::EnumTraits;
+using arrow::internal::has_enum_traits;
+
+template <typename Enum, typename CType = typename std::underlying_type<Enum>::type>
+Result<Enum> ValidateEnumValue(CType raw) {
+ for (auto valid : EnumTraits<Enum>::values()) {
+ if (raw == static_cast<CType>(valid)) {
+ return static_cast<Enum>(raw);
+ }
+ }
+ return Status::Invalid("Invalid value for ", EnumTraits<Enum>::name(), ": ", raw);
+}
+
+class GenericOptionsType : public FunctionOptionsType {
+ public:
+ Result<std::shared_ptr<Buffer>> Serialize(const FunctionOptions&) const override;
+ Result<std::unique_ptr<FunctionOptions>> Deserialize(
+ const Buffer& buffer) const override;
+ virtual Status ToStructScalar(const FunctionOptions& options,
+ std::vector<std::string>* field_names,
+ std::vector<std::shared_ptr<Scalar>>* values) const = 0;
+ virtual Result<std::unique_ptr<FunctionOptions>> FromStructScalar(
+ const StructScalar& scalar) const = 0;
+};
+
+ARROW_EXPORT
+Result<std::shared_ptr<StructScalar>> FunctionOptionsToStructScalar(
+ const FunctionOptions&);
+ARROW_EXPORT
+Result<std::unique_ptr<FunctionOptions>> FunctionOptionsFromStructScalar(
+ const StructScalar&);
+ARROW_EXPORT
+Result<std::unique_ptr<FunctionOptions>> DeserializeFunctionOptions(const Buffer& buffer);
+
+template <typename T>
+static inline enable_if_t<!has_enum_traits<T>::value, std::string> GenericToString(
+ const T& value) {
+ std::stringstream ss;
+ ss << value;
+ return ss.str();
+}
+
+static inline std::string GenericToString(bool value) { return value ? "true" : "false"; }
+
+static inline std::string GenericToString(const std::string& value) {
+ std::stringstream ss;
+ ss << '"' << value << '"';
+ return ss.str();
+}
+
+template <typename T>
+static inline enable_if_t<has_enum_traits<T>::value, std::string> GenericToString(
+ const T value) {
+ return EnumTraits<T>::value_name(value);
+}
+
+template <typename T>
+static inline std::string GenericToString(const std::shared_ptr<T>& value) {
+ std::stringstream ss;
+ return value ? value->ToString() : "<NULLPTR>";
+}
+
+static inline std::string GenericToString(const std::shared_ptr<Scalar>& value) {
+ std::stringstream ss;
+ ss << value->type->ToString() << ":" << value->ToString();
+ return ss.str();
+}
+
+static inline std::string GenericToString(
+ const std::shared_ptr<const KeyValueMetadata>& value) {
+ std::stringstream ss;
+ ss << "KeyValueMetadata{";
+ if (value) {
+ bool first = true;
+ for (const auto& pair : value->sorted_pairs()) {
+ if (!first) ss << ", ";
+ first = false;
+ ss << pair.first << ':' << pair.second;
+ }
+ }
+ ss << '}';
+ return ss.str();
+}
+
+static inline std::string GenericToString(const Datum& value) {
+ switch (value.kind()) {
+ case Datum::NONE:
+ return "<NULL DATUM>";
+ case Datum::SCALAR:
+ return GenericToString(value.scalar());
+ case Datum::ARRAY: {
+ std::stringstream ss;
+ ss << value.type()->ToString() << ':' << value.make_array()->ToString();
+ return ss.str();
+ }
+ case Datum::CHUNKED_ARRAY:
+ case Datum::RECORD_BATCH:
+ case Datum::TABLE:
+ case Datum::COLLECTION:
+ return value.ToString();
+ }
+ return value.ToString();
+}
+
+template <typename T>
+static inline std::string GenericToString(const std::vector<T>& value) {
+ std::stringstream ss;
+ ss << "[";
+ bool first = true;
+ // Don't use range-for with auto& to avoid Clang -Wrange-loop-analysis
+ for (auto it = value.begin(); it != value.end(); it++) {
+ if (!first) ss << ", ";
+ first = false;
+ ss << GenericToString(*it);
+ }
+ ss << ']';
+ return ss.str();
+}
+
+static inline std::string GenericToString(SortOrder value) {
+ switch (value) {
+ case SortOrder::Ascending:
+ return "Ascending";
+ case SortOrder::Descending:
+ return "Descending";
+ }
+ return "<INVALID SORT ORDER>";
+}
+
+static inline std::string GenericToString(const std::vector<SortKey>& value) {
+ std::stringstream ss;
+ ss << '[';
+ bool first = true;
+ for (const auto& key : value) {
+ if (!first) {
+ ss << ", ";
+ }
+ first = false;
+ ss << key.ToString();
+ }
+ ss << ']';
+ return ss.str();
+}
+
+template <typename T>
+static inline bool GenericEquals(const T& left, const T& right) {
+ return left == right;
+}
+
+template <typename T>
+static inline bool GenericEquals(const std::shared_ptr<T>& left,
+ const std::shared_ptr<T>& right) {
+ if (left && right) {
+ return left->Equals(*right);
+ }
+ return left == right;
+}
+
+static inline bool IsEmpty(const std::shared_ptr<const KeyValueMetadata>& meta) {
+ return !meta || meta->size() == 0;
+}
+
+static inline bool GenericEquals(const std::shared_ptr<const KeyValueMetadata>& left,
+ const std::shared_ptr<const KeyValueMetadata>& right) {
+ // Special case since null metadata is considered equivalent to empty
+ if (IsEmpty(left) || IsEmpty(right)) {
+ return IsEmpty(left) && IsEmpty(right);
+ }
+ return left->Equals(*right);
+}
+
+template <typename T>
+static inline bool GenericEquals(const std::vector<T>& left,
+ const std::vector<T>& right) {
+ if (left.size() != right.size()) return false;
+ for (size_t i = 0; i < left.size(); i++) {
+ if (!GenericEquals(left[i], right[i])) return false;
+ }
+ return true;
+}
+
+template <typename T>
+static inline decltype(TypeTraits<typename CTypeTraits<T>::ArrowType>::type_singleton())
+GenericTypeSingleton() {
+ return TypeTraits<typename CTypeTraits<T>::ArrowType>::type_singleton();
+}
+
+template <typename T>
+static inline enable_if_same<T, std::shared_ptr<const KeyValueMetadata>,
+ std::shared_ptr<DataType>>
+GenericTypeSingleton() {
+ return map(binary(), binary());
+}
+
+template <typename T>
+static inline enable_if_t<has_enum_traits<T>::value, std::shared_ptr<DataType>>
+GenericTypeSingleton() {
+ return TypeTraits<typename EnumTraits<T>::Type>::type_singleton();
+}
+
+template <typename T>
+static inline enable_if_same<T, SortKey, std::shared_ptr<DataType>>
+GenericTypeSingleton() {
+ std::vector<std::shared_ptr<Field>> fields;
+ fields.emplace_back(new Field("name", GenericTypeSingleton<std::string>()));
+ fields.emplace_back(new Field("order", GenericTypeSingleton<SortOrder>()));
+ return std::make_shared<StructType>(std::move(fields));
+}
+
+// N.B. ordering of overloads is relatively fragile
+template <typename T>
+static inline Result<decltype(MakeScalar(std::declval<T>()))> GenericToScalar(
+ const T& value) {
+ return MakeScalar(value);
+}
+
+// For Clang/libc++: when iterating through vector<bool>, we can't
+// pass it by reference so the overload above doesn't apply
+static inline Result<std::shared_ptr<Scalar>> GenericToScalar(bool value) {
+ return MakeScalar(value);
+}
+
+template <typename T, typename Enable = enable_if_t<has_enum_traits<T>::value>>
+static inline Result<std::shared_ptr<Scalar>> GenericToScalar(const T value) {
+ using CType = typename EnumTraits<T>::CType;
+ return GenericToScalar(static_cast<CType>(value));
+}
+
+static inline Result<std::shared_ptr<Scalar>> GenericToScalar(const SortKey& value) {
+ ARROW_ASSIGN_OR_RAISE(auto name, GenericToScalar(value.name));
+ ARROW_ASSIGN_OR_RAISE(auto order, GenericToScalar(value.order));
+ return StructScalar::Make({name, order}, {"name", "order"});
+}
+
+static inline Result<std::shared_ptr<Scalar>> GenericToScalar(
+ const std::shared_ptr<const KeyValueMetadata>& value) {
+ auto ty = GenericTypeSingleton<std::shared_ptr<const KeyValueMetadata>>();
+ std::unique_ptr<ArrayBuilder> builder;
+ RETURN_NOT_OK(MakeBuilder(default_memory_pool(), ty, &builder));
+ auto* map_builder = checked_cast<MapBuilder*>(builder.get());
+ auto* key_builder = checked_cast<BinaryBuilder*>(map_builder->key_builder());
+ auto* item_builder = checked_cast<BinaryBuilder*>(map_builder->item_builder());
+ RETURN_NOT_OK(map_builder->Append());
+ if (value) {
+ RETURN_NOT_OK(key_builder->AppendValues(value->keys()));
+ RETURN_NOT_OK(item_builder->AppendValues(value->values()));
+ }
+ std::shared_ptr<Array> arr;
+ RETURN_NOT_OK(map_builder->Finish(&arr));
+ return arr->GetScalar(0);
+}
+
+template <typename T>
+static inline Result<std::shared_ptr<Scalar>> GenericToScalar(
+ const std::vector<T>& value) {
+ std::shared_ptr<DataType> type = GenericTypeSingleton<T>();
+ std::vector<std::shared_ptr<Scalar>> scalars;
+ scalars.reserve(value.size());
+ // Don't use range-for with auto& to avoid Clang -Wrange-loop-analysis
+ for (auto it = value.begin(); it != value.end(); it++) {
+ ARROW_ASSIGN_OR_RAISE(auto scalar, GenericToScalar(*it));
+ scalars.push_back(std::move(scalar));
+ }
+ std::unique_ptr<ArrayBuilder> builder;
+ RETURN_NOT_OK(
+ MakeBuilder(default_memory_pool(), type ? type : scalars[0]->type, &builder));
+ RETURN_NOT_OK(builder->AppendScalars(scalars));
+ std::shared_ptr<Array> out;
+ RETURN_NOT_OK(builder->Finish(&out));
+ return std::make_shared<ListScalar>(std::move(out));
+}
+
+static inline Result<std::shared_ptr<Scalar>> GenericToScalar(
+ const std::shared_ptr<DataType>& value) {
+ if (!value) {
+ return Status::Invalid("shared_ptr<DataType> is nullptr");
+ }
+ return MakeNullScalar(value);
+}
+
+static inline Result<std::shared_ptr<Scalar>> GenericToScalar(
+ const std::shared_ptr<Scalar>& value) {
+ return value;
+}
+
+static inline Result<std::shared_ptr<Scalar>> GenericToScalar(
+ const std::shared_ptr<Array>& value) {
+ return std::make_shared<ListScalar>(value);
+}
+
+static inline Result<std::shared_ptr<Scalar>> GenericToScalar(const Datum& value) {
+ // TODO(ARROW-9434): store in a union instead.
+ switch (value.kind()) {
+ case Datum::ARRAY:
+ return GenericToScalar(value.make_array());
+ break;
+ default:
+ return Status::NotImplemented("Cannot serialize Datum kind ", value.kind());
+ }
+}
+
+template <typename T>
+static inline enable_if_primitive_ctype<typename CTypeTraits<T>::ArrowType, Result<T>>
+GenericFromScalar(const std::shared_ptr<Scalar>& value) {
+ using ArrowType = typename CTypeTraits<T>::ArrowType;
+ using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+ if (value->type->id() != ArrowType::type_id) {
+ return Status::Invalid("Expected type ", ArrowType::type_id, " but got ",
+ value->type->ToString());
+ }
+ const auto& holder = checked_cast<const ScalarType&>(*value);
+ if (!holder.is_valid) return Status::Invalid("Got null scalar");
+ return holder.value;
+}
+
+template <typename T>
+static inline enable_if_primitive_ctype<typename EnumTraits<T>::Type, Result<T>>
+GenericFromScalar(const std::shared_ptr<Scalar>& value) {
+ ARROW_ASSIGN_OR_RAISE(auto raw_val,
+ GenericFromScalar<typename EnumTraits<T>::CType>(value));
+ return ValidateEnumValue<T>(raw_val);
+}
+
+template <typename T, typename U>
+using enable_if_same_result = enable_if_same<T, U, Result<T>>;
+
+template <typename T>
+static inline enable_if_same_result<T, std::string> GenericFromScalar(
+ const std::shared_ptr<Scalar>& value) {
+ if (!is_base_binary_like(value->type->id())) {
+ return Status::Invalid("Expected binary-like type but got ", value->type->ToString());
+ }
+ const auto& holder = checked_cast<const BaseBinaryScalar&>(*value);
+ if (!holder.is_valid) return Status::Invalid("Got null scalar");
+ return holder.value->ToString();
+}
+
+template <typename T>
+static inline enable_if_same_result<T, SortKey> GenericFromScalar(
+ const std::shared_ptr<Scalar>& value) {
+ if (value->type->id() != Type::STRUCT) {
+ return Status::Invalid("Expected type STRUCT but got ", value->type->id());
+ }
+ if (!value->is_valid) return Status::Invalid("Got null scalar");
+ const auto& holder = checked_cast<const StructScalar&>(*value);
+ ARROW_ASSIGN_OR_RAISE(auto name_holder, holder.field("name"));
+ ARROW_ASSIGN_OR_RAISE(auto order_holder, holder.field("order"));
+ ARROW_ASSIGN_OR_RAISE(auto name, GenericFromScalar<std::string>(name_holder));
+ ARROW_ASSIGN_OR_RAISE(auto order, GenericFromScalar<SortOrder>(order_holder));
+ return SortKey{std::move(name), order};
+}
+
+template <typename T>
+static inline enable_if_same_result<T, std::shared_ptr<DataType>> GenericFromScalar(
+ const std::shared_ptr<Scalar>& value) {
+ return value->type;
+}
+
+template <typename T>
+static inline enable_if_same_result<T, std::shared_ptr<Scalar>> GenericFromScalar(
+ const std::shared_ptr<Scalar>& value) {
+ return value;
+}
+
+template <typename T>
+static inline enable_if_same_result<T, std::shared_ptr<const KeyValueMetadata>>
+GenericFromScalar(const std::shared_ptr<Scalar>& value) {
+ auto ty = GenericTypeSingleton<std::shared_ptr<const KeyValueMetadata>>();
+ if (!value->type->Equals(ty)) {
+ return Status::Invalid("Expected ", ty->ToString(), " but got ",
+ value->type->ToString());
+ }
+ const auto& holder = checked_cast<const MapScalar&>(*value);
+ std::vector<std::string> keys;
+ std::vector<std::string> values;
+ const auto& list = checked_cast<const StructArray&>(*holder.value);
+ const auto& key_arr = checked_cast<const BinaryArray&>(*list.field(0));
+ const auto& value_arr = checked_cast<const BinaryArray&>(*list.field(1));
+ for (int64_t i = 0; i < list.length(); i++) {
+ keys.push_back(key_arr.GetString(i));
+ values.push_back(value_arr.GetString(i));
+ }
+ return key_value_metadata(std::move(keys), std::move(values));
+}
+
+template <typename T>
+static inline enable_if_same_result<T, Datum> GenericFromScalar(
+ const std::shared_ptr<Scalar>& value) {
+ if (value->type->id() == Type::LIST) {
+ const auto& holder = checked_cast<const BaseListScalar&>(*value);
+ return holder.value;
+ }
+ // TODO(ARROW-9434): handle other possible datum kinds by looking for a union
+ return Status::Invalid("Cannot deserialize Datum from ", value->ToString());
+}
+
+template <typename T>
+static enable_if_same<typename CTypeTraits<T>::ArrowType, ListType, Result<T>>
+GenericFromScalar(const std::shared_ptr<Scalar>& value) {
+ using ValueType = typename T::value_type;
+ if (value->type->id() != Type::LIST) {
+ return Status::Invalid("Expected type LIST but got ", value->type->ToString());
+ }
+ const auto& holder = checked_cast<const BaseListScalar&>(*value);
+ if (!holder.is_valid) return Status::Invalid("Got null scalar");
+ std::vector<ValueType> result;
+ for (int i = 0; i < holder.value->length(); i++) {
+ ARROW_ASSIGN_OR_RAISE(auto scalar, holder.value->GetScalar(i));
+ ARROW_ASSIGN_OR_RAISE(auto v, GenericFromScalar<ValueType>(scalar));
+ result.push_back(std::move(v));
+ }
+ return result;
+}
+
+template <typename Options>
+struct StringifyImpl {
+ template <typename Tuple>
+ StringifyImpl(const Options& obj, const Tuple& props)
+ : obj_(obj), members_(props.size()) {
+ props.ForEach(*this);
+ }
+
+ template <typename Property>
+ void operator()(const Property& prop, size_t i) {
+ std::stringstream ss;
+ ss << prop.name() << '=' << GenericToString(prop.get(obj_));
+ members_[i] = ss.str();
+ }
+
+ std::string Finish() {
+ return "{" + arrow::internal::JoinStrings(members_, ", ") + "}";
+ }
+
+ const Options& obj_;
+ std::vector<std::string> members_;
+};
+
+template <typename Options>
+struct CompareImpl {
+ template <typename Tuple>
+ CompareImpl(const Options& l, const Options& r, const Tuple& props)
+ : left_(l), right_(r) {
+ props.ForEach(*this);
+ }
+
+ template <typename Property>
+ void operator()(const Property& prop, size_t) {
+ equal_ &= GenericEquals(prop.get(left_), prop.get(right_));
+ }
+
+ const Options& left_;
+ const Options& right_;
+ bool equal_ = true;
+};
+
+template <typename Options>
+struct ToStructScalarImpl {
+ template <typename Tuple>
+ ToStructScalarImpl(const Options& obj, const Tuple& props,
+ std::vector<std::string>* field_names,
+ std::vector<std::shared_ptr<Scalar>>* values)
+ : obj_(obj), field_names_(field_names), values_(values) {
+ props.ForEach(*this);
+ }
+
+ template <typename Property>
+ void operator()(const Property& prop, size_t) {
+ if (!status_.ok()) return;
+ auto result = GenericToScalar(prop.get(obj_));
+ if (!result.ok()) {
+ status_ = result.status().WithMessage("Could not serialize field ", prop.name(),
+ " of options type ", Options::kTypeName, ": ",
+ result.status().message());
+ return;
+ }
+ field_names_->emplace_back(prop.name());
+ values_->push_back(result.MoveValueUnsafe());
+ }
+
+ const Options& obj_;
+ Status status_;
+ std::vector<std::string>* field_names_;
+ std::vector<std::shared_ptr<Scalar>>* values_;
+};
+
+template <typename Options>
+struct FromStructScalarImpl {
+ template <typename Tuple>
+ FromStructScalarImpl(Options* obj, const StructScalar& scalar, const Tuple& props)
+ : obj_(obj), scalar_(scalar) {
+ props.ForEach(*this);
+ }
+
+ template <typename Property>
+ void operator()(const Property& prop, size_t) {
+ if (!status_.ok()) return;
+ auto maybe_holder = scalar_.field(std::string(prop.name()));
+ if (!maybe_holder.ok()) {
+ status_ = maybe_holder.status().WithMessage(
+ "Cannot deserialize field ", prop.name(), " of options type ",
+ Options::kTypeName, ": ", maybe_holder.status().message());
+ return;
+ }
+ auto holder = maybe_holder.MoveValueUnsafe();
+ auto result = GenericFromScalar<typename Property::Type>(holder);
+ if (!result.ok()) {
+ status_ = result.status().WithMessage("Cannot deserialize field ", prop.name(),
+ " of options type ", Options::kTypeName, ": ",
+ result.status().message());
+ return;
+ }
+ prop.set(obj_, result.MoveValueUnsafe());
+ }
+
+ Options* obj_;
+ Status status_;
+ const StructScalar& scalar_;
+};
+
+template <typename Options>
+struct CopyImpl {
+ template <typename Tuple>
+ CopyImpl(Options* obj, const Options& options, const Tuple& props)
+ : obj_(obj), options_(options) {
+ props.ForEach(*this);
+ }
+
+ template <typename Property>
+ void operator()(const Property& prop, size_t) {
+ prop.set(obj_, prop.get(options_));
+ }
+
+ Options* obj_;
+ const Options& options_;
+};
+
+template <typename Options, typename... Properties>
+const FunctionOptionsType* GetFunctionOptionsType(const Properties&... properties) {
+ static const class OptionsType : public GenericOptionsType {
+ public:
+ explicit OptionsType(const arrow::internal::PropertyTuple<Properties...> properties)
+ : properties_(properties) {}
+
+ const char* type_name() const override { return Options::kTypeName; }
+
+ std::string Stringify(const FunctionOptions& options) const override {
+ const auto& self = checked_cast<const Options&>(options);
+ return StringifyImpl<Options>(self, properties_).Finish();
+ }
+ bool Compare(const FunctionOptions& options,
+ const FunctionOptions& other) const override {
+ const auto& lhs = checked_cast<const Options&>(options);
+ const auto& rhs = checked_cast<const Options&>(other);
+ return CompareImpl<Options>(lhs, rhs, properties_).equal_;
+ }
+ Status ToStructScalar(const FunctionOptions& options,
+ std::vector<std::string>* field_names,
+ std::vector<std::shared_ptr<Scalar>>* values) const override {
+ const auto& self = checked_cast<const Options&>(options);
+ RETURN_NOT_OK(
+ ToStructScalarImpl<Options>(self, properties_, field_names, values).status_);
+ return Status::OK();
+ }
+ Result<std::unique_ptr<FunctionOptions>> FromStructScalar(
+ const StructScalar& scalar) const override {
+ auto options = std::unique_ptr<Options>(new Options());
+ RETURN_NOT_OK(
+ FromStructScalarImpl<Options>(options.get(), scalar, properties_).status_);
+ return std::move(options);
+ }
+ std::unique_ptr<FunctionOptions> Copy(const FunctionOptions& options) const override {
+ auto out = std::unique_ptr<Options>(new Options());
+ CopyImpl<Options>(out.get(), checked_cast<const Options&>(options), properties_);
+ return std::move(out);
+ }
+
+ private:
+ const arrow::internal::PropertyTuple<Properties...> properties_;
+ } instance(arrow::internal::MakeProperties(properties...));
+ return &instance;
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/function_test.cc b/src/arrow/cpp/src/arrow/compute/function_test.cc
new file mode 100644
index 000000000..626824d73
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/function_test.cc
@@ -0,0 +1,351 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/function.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/kernel.h"
+#include "arrow/datum.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+namespace compute {
+
+TEST(FunctionOptions, Equality) {
+ std::vector<std::shared_ptr<FunctionOptions>> options;
+ options.emplace_back(new ScalarAggregateOptions());
+ options.emplace_back(new ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1));
+ options.emplace_back(new CountOptions());
+ options.emplace_back(new CountOptions(CountOptions::ALL));
+ options.emplace_back(new ModeOptions());
+ options.emplace_back(new ModeOptions(/*n=*/2));
+ options.emplace_back(new VarianceOptions());
+ options.emplace_back(new VarianceOptions(/*ddof=*/2));
+ options.emplace_back(new QuantileOptions());
+ options.emplace_back(
+ new QuantileOptions(/*q=*/0.75, QuantileOptions::Interpolation::MIDPOINT));
+ options.emplace_back(new TDigestOptions());
+ options.emplace_back(
+ new TDigestOptions(/*q=*/0.75, /*delta=*/50, /*buffer_size=*/1024));
+ options.emplace_back(new IndexOptions(ScalarFromJSON(int64(), "16")));
+ options.emplace_back(new IndexOptions(ScalarFromJSON(boolean(), "true")));
+ options.emplace_back(new IndexOptions(ScalarFromJSON(boolean(), "null")));
+ options.emplace_back(new ArithmeticOptions());
+ options.emplace_back(new ArithmeticOptions(/*check_overflow=*/true));
+ options.emplace_back(new RoundOptions());
+ options.emplace_back(
+ new RoundOptions(/*ndigits=*/2, /*round_mode=*/RoundMode::TOWARDS_INFINITY));
+ options.emplace_back(new RoundToMultipleOptions());
+ options.emplace_back(new RoundToMultipleOptions(
+ /*multiple=*/100, /*round_mode=*/RoundMode::TOWARDS_INFINITY));
+ options.emplace_back(new ElementWiseAggregateOptions());
+ options.emplace_back(new ElementWiseAggregateOptions(/*skip_nulls=*/false));
+ options.emplace_back(new JoinOptions());
+ options.emplace_back(new JoinOptions(JoinOptions::REPLACE, "replacement"));
+ options.emplace_back(new MatchSubstringOptions("pattern"));
+ options.emplace_back(new MatchSubstringOptions("pattern", /*ignore_case=*/true));
+ options.emplace_back(new SplitOptions());
+ options.emplace_back(new SplitOptions(/*max_splits=*/2, /*reverse=*/true));
+ options.emplace_back(new SplitPatternOptions("pattern"));
+ options.emplace_back(
+ new SplitPatternOptions("pattern", /*max_splits=*/2, /*reverse=*/true));
+ options.emplace_back(new ReplaceSubstringOptions("pattern", "replacement"));
+ options.emplace_back(
+ new ReplaceSubstringOptions("pattern", "replacement", /*max_replacements=*/2));
+ options.emplace_back(new ReplaceSliceOptions(0, 1, "foo"));
+ options.emplace_back(new ReplaceSliceOptions(1, -1, "bar"));
+ options.emplace_back(new ExtractRegexOptions("pattern"));
+ options.emplace_back(new ExtractRegexOptions("pattern2"));
+ options.emplace_back(new SetLookupOptions(ArrayFromJSON(int64(), "[1, 2, 3, 4]")));
+ options.emplace_back(new SetLookupOptions(ArrayFromJSON(boolean(), "[true, false]")));
+ options.emplace_back(new StrptimeOptions("%Y", TimeUnit::type::MILLI));
+ options.emplace_back(new StrptimeOptions("%Y", TimeUnit::type::NANO));
+ options.emplace_back(new StrftimeOptions("%Y-%m-%dT%H:%M:%SZ", "C"));
+#ifndef _WIN32
+ options.emplace_back(new AssumeTimezoneOptions(
+ "Europe/Amsterdam", AssumeTimezoneOptions::Ambiguous::AMBIGUOUS_RAISE,
+ AssumeTimezoneOptions::Nonexistent::NONEXISTENT_RAISE));
+#endif
+ options.emplace_back(new PadOptions(5, " "));
+ options.emplace_back(new PadOptions(10, "A"));
+ options.emplace_back(new TrimOptions(" "));
+ options.emplace_back(new TrimOptions("abc"));
+ options.emplace_back(new SliceOptions(/*start=*/1));
+ options.emplace_back(new SliceOptions(/*start=*/1, /*stop=*/-5, /*step=*/-2));
+ // N.B. we never actually use field_nullability or field_metadata in Arrow
+ options.emplace_back(new MakeStructOptions({"col1"}, {true}, {}));
+ options.emplace_back(new MakeStructOptions({"col1"}, {false}, {}));
+ options.emplace_back(
+ new MakeStructOptions({"col1"}, {false}, {key_value_metadata({{"key", "val"}})}));
+ options.emplace_back(new DayOfWeekOptions(false, 1));
+ options.emplace_back(new WeekOptions(true, false, false));
+ options.emplace_back(new CastOptions(CastOptions::Safe(boolean())));
+ options.emplace_back(new CastOptions(CastOptions::Unsafe(int64())));
+ options.emplace_back(new FilterOptions());
+ options.emplace_back(
+ new FilterOptions(FilterOptions::NullSelectionBehavior::EMIT_NULL));
+ options.emplace_back(new TakeOptions());
+ options.emplace_back(new TakeOptions(/*boundscheck=*/false));
+ options.emplace_back(new DictionaryEncodeOptions());
+ options.emplace_back(
+ new DictionaryEncodeOptions(DictionaryEncodeOptions::NullEncodingBehavior::ENCODE));
+ options.emplace_back(new ArraySortOptions());
+ options.emplace_back(new ArraySortOptions(SortOrder::Descending));
+ options.emplace_back(new SortOptions());
+ options.emplace_back(new SortOptions({SortKey("key", SortOrder::Ascending)}));
+ options.emplace_back(new SortOptions(
+ {SortKey("key", SortOrder::Descending), SortKey("value", SortOrder::Descending)}));
+ options.emplace_back(new PartitionNthOptions(/*pivot=*/0));
+ options.emplace_back(new PartitionNthOptions(/*pivot=*/42));
+ options.emplace_back(new SelectKOptions(0, {}));
+ options.emplace_back(new SelectKOptions(5, {{SortKey("key", SortOrder::Ascending)}}));
+
+ for (size_t i = 0; i < options.size(); i++) {
+ const size_t prev_i = i == 0 ? options.size() - 1 : i - 1;
+ const FunctionOptions& cur = *options[i];
+ const FunctionOptions& prev = *options[prev_i];
+ SCOPED_TRACE(cur.type_name());
+ SCOPED_TRACE(cur.ToString());
+ ASSERT_EQ(cur, cur);
+ ASSERT_NE(cur, prev);
+ ASSERT_NE(prev, cur);
+ ASSERT_NE("", cur.ToString());
+
+ ASSERT_OK_AND_ASSIGN(auto serialized, cur.Serialize());
+ const auto* type_name = cur.type_name();
+ ASSERT_OK_AND_ASSIGN(
+ auto deserialized,
+ FunctionOptions::Deserialize(std::string(type_name, std::strlen(type_name)),
+ *serialized));
+ ASSERT_TRUE(cur.Equals(*deserialized));
+ }
+}
+
+struct ExecBatch;
+
+TEST(Arity, Basics) {
+ auto nullary = Arity::Nullary();
+ ASSERT_EQ(0, nullary.num_args);
+ ASSERT_FALSE(nullary.is_varargs);
+
+ auto unary = Arity::Unary();
+ ASSERT_EQ(1, unary.num_args);
+
+ auto binary = Arity::Binary();
+ ASSERT_EQ(2, binary.num_args);
+
+ auto ternary = Arity::Ternary();
+ ASSERT_EQ(3, ternary.num_args);
+
+ auto varargs = Arity::VarArgs();
+ ASSERT_EQ(0, varargs.num_args);
+ ASSERT_TRUE(varargs.is_varargs);
+
+ auto varargs2 = Arity::VarArgs(2);
+ ASSERT_EQ(2, varargs2.num_args);
+ ASSERT_TRUE(varargs2.is_varargs);
+}
+
+TEST(ScalarFunction, Basics) {
+ ScalarFunction func("scalar_test", Arity::Binary(), /*doc=*/nullptr);
+ ScalarFunction varargs_func("varargs_test", Arity::VarArgs(1), /*doc=*/nullptr);
+
+ ASSERT_EQ("scalar_test", func.name());
+ ASSERT_EQ(2, func.arity().num_args);
+ ASSERT_FALSE(func.arity().is_varargs);
+ ASSERT_EQ(Function::SCALAR, func.kind());
+
+ ASSERT_EQ("varargs_test", varargs_func.name());
+ ASSERT_EQ(1, varargs_func.arity().num_args);
+ ASSERT_TRUE(varargs_func.arity().is_varargs);
+ ASSERT_EQ(Function::SCALAR, varargs_func.kind());
+}
+
+TEST(VectorFunction, Basics) {
+ VectorFunction func("vector_test", Arity::Binary(), /*doc=*/nullptr);
+ VectorFunction varargs_func("varargs_test", Arity::VarArgs(1), /*doc=*/nullptr);
+
+ ASSERT_EQ("vector_test", func.name());
+ ASSERT_EQ(2, func.arity().num_args);
+ ASSERT_FALSE(func.arity().is_varargs);
+ ASSERT_EQ(Function::VECTOR, func.kind());
+
+ ASSERT_EQ("varargs_test", varargs_func.name());
+ ASSERT_EQ(1, varargs_func.arity().num_args);
+ ASSERT_TRUE(varargs_func.arity().is_varargs);
+ ASSERT_EQ(Function::VECTOR, varargs_func.kind());
+}
+
+auto ExecNYI = [](KernelContext* ctx, const ExecBatch& args, Datum* out) {
+ return Status::NotImplemented("NYI");
+};
+
+template <typename FunctionType>
+void CheckAddDispatch(FunctionType* func) {
+ using KernelType = typename FunctionType::KernelType;
+
+ ASSERT_EQ(0, func->num_kernels());
+ ASSERT_EQ(0, func->kernels().size());
+
+ std::vector<InputType> in_types1 = {int32(), int32()};
+ OutputType out_type1 = int32();
+
+ ASSERT_OK(func->AddKernel(in_types1, out_type1, ExecNYI));
+ ASSERT_OK(func->AddKernel({int32(), int8()}, int32(), ExecNYI));
+
+ // Duplicate sig is okay
+ ASSERT_OK(func->AddKernel(in_types1, out_type1, ExecNYI));
+
+ // Add given a descr
+ KernelType descr({float64(), float64()}, float64(), ExecNYI);
+ ASSERT_OK(func->AddKernel(descr));
+
+ ASSERT_EQ(4, func->num_kernels());
+ ASSERT_EQ(4, func->kernels().size());
+
+ // Try adding some invalid kernels
+ ASSERT_RAISES(Invalid, func->AddKernel({}, int32(), ExecNYI));
+ ASSERT_RAISES(Invalid, func->AddKernel({int32()}, int32(), ExecNYI));
+ ASSERT_RAISES(Invalid, func->AddKernel({int8(), int8(), int8()}, int32(), ExecNYI));
+
+ // Add valid and invalid kernel using kernel struct directly
+ KernelType valid_kernel({boolean(), boolean()}, boolean(), ExecNYI);
+ ASSERT_OK(func->AddKernel(valid_kernel));
+
+ KernelType invalid_kernel({boolean()}, boolean(), ExecNYI);
+ ASSERT_RAISES(Invalid, func->AddKernel(invalid_kernel));
+
+ ASSERT_OK_AND_ASSIGN(const Kernel* kernel, func->DispatchExact({int32(), int32()}));
+ KernelSignature expected_sig(in_types1, out_type1);
+ ASSERT_TRUE(kernel->signature->Equals(expected_sig));
+
+ // No kernel available
+ ASSERT_RAISES(NotImplemented, func->DispatchExact({utf8(), utf8()}));
+
+ // Wrong arity
+ ASSERT_RAISES(Invalid, func->DispatchExact({}));
+ ASSERT_RAISES(Invalid, func->DispatchExact({int32(), int32(), int32()}));
+}
+
+TEST(ScalarVectorFunction, DispatchExact) {
+ ScalarFunction func1("scalar_test", Arity::Binary(), /*doc=*/nullptr);
+ VectorFunction func2("vector_test", Arity::Binary(), /*doc=*/nullptr);
+
+ CheckAddDispatch(&func1);
+ CheckAddDispatch(&func2);
+}
+
+TEST(ArrayFunction, VarArgs) {
+ ScalarFunction va_func("va_test", Arity::VarArgs(1), /*doc=*/nullptr);
+
+ std::vector<InputType> va_args = {int8()};
+
+ ASSERT_OK(va_func.AddKernel(va_args, int8(), ExecNYI));
+
+ // No input type passed
+ ASSERT_RAISES(Invalid, va_func.AddKernel({}, int8(), ExecNYI));
+
+ // VarArgs function expect a single input type
+ ASSERT_RAISES(Invalid, va_func.AddKernel({int8(), int8()}, int8(), ExecNYI));
+
+ // Invalid sig
+ ScalarKernel non_va_kernel(std::make_shared<KernelSignature>(va_args, int8()), ExecNYI);
+ ASSERT_RAISES(Invalid, va_func.AddKernel(non_va_kernel));
+
+ std::vector<ValueDescr> args = {ValueDescr::Scalar(int8()), int8(), int8()};
+ ASSERT_OK_AND_ASSIGN(const Kernel* kernel, va_func.DispatchExact(args));
+ ASSERT_TRUE(kernel->signature->MatchesInputs(args));
+
+ // No dispatch possible because args incompatible
+ args[2] = int32();
+ ASSERT_RAISES(NotImplemented, va_func.DispatchExact(args));
+}
+
+TEST(ScalarAggregateFunction, Basics) {
+ ScalarAggregateFunction func("agg_test", Arity::Unary(), /*doc=*/nullptr);
+
+ ASSERT_EQ("agg_test", func.name());
+ ASSERT_EQ(1, func.arity().num_args);
+ ASSERT_FALSE(func.arity().is_varargs);
+ ASSERT_EQ(Function::SCALAR_AGGREGATE, func.kind());
+}
+
+Result<std::unique_ptr<KernelState>> NoopInit(KernelContext*, const KernelInitArgs&) {
+ return nullptr;
+}
+
+Status NoopConsume(KernelContext*, const ExecBatch&) { return Status::OK(); }
+Status NoopMerge(KernelContext*, const KernelState&, KernelState*) {
+ return Status::OK();
+}
+Status NoopFinalize(KernelContext*, Datum*) { return Status::OK(); }
+
+TEST(ScalarAggregateFunction, DispatchExact) {
+ ScalarAggregateFunction func("agg_test", Arity::Unary(), /*doc=*/nullptr);
+
+ std::vector<InputType> in_args = {ValueDescr::Array(int8())};
+ ScalarAggregateKernel kernel(std::move(in_args), int64(), NoopInit, NoopConsume,
+ NoopMerge, NoopFinalize);
+ ASSERT_OK(func.AddKernel(kernel));
+
+ in_args = {float64()};
+ kernel.signature = std::make_shared<KernelSignature>(in_args, float64());
+ ASSERT_OK(func.AddKernel(kernel));
+
+ ASSERT_EQ(2, func.num_kernels());
+ ASSERT_EQ(2, func.kernels().size());
+ ASSERT_TRUE(func.kernels()[1]->signature->Equals(*kernel.signature));
+
+ // Invalid arity
+ in_args = {};
+ kernel.signature = std::make_shared<KernelSignature>(in_args, float64());
+ ASSERT_RAISES(Invalid, func.AddKernel(kernel));
+
+ in_args = {float32(), float64()};
+ kernel.signature = std::make_shared<KernelSignature>(in_args, float64());
+ ASSERT_RAISES(Invalid, func.AddKernel(kernel));
+
+ std::vector<ValueDescr> dispatch_args = {ValueDescr::Array(int8())};
+ ASSERT_OK_AND_ASSIGN(const Kernel* selected_kernel, func.DispatchExact(dispatch_args));
+ ASSERT_EQ(func.kernels()[0], selected_kernel);
+ ASSERT_TRUE(selected_kernel->signature->MatchesInputs(dispatch_args));
+
+ // We declared that only arrays are accepted
+ dispatch_args[0] = {ValueDescr::Scalar(int8())};
+ ASSERT_RAISES(NotImplemented, func.DispatchExact(dispatch_args));
+
+ // Didn't qualify the float64() kernel so this actually dispatches (even
+ // though that may not be what you want)
+ dispatch_args[0] = {ValueDescr::Scalar(float64())};
+ ASSERT_OK_AND_ASSIGN(selected_kernel, func.DispatchExact(dispatch_args));
+ ASSERT_TRUE(selected_kernel->signature->MatchesInputs(dispatch_args));
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernel.cc b/src/arrow/cpp/src/arrow/compute/kernel.cc
new file mode 100644
index 000000000..666b73e41
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernel.cc
@@ -0,0 +1,507 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/kernel.h"
+
+#include <cstddef>
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include "arrow/buffer.h"
+#include "arrow/compute/exec.h"
+#include "arrow/result.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/hash_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::hash_combine;
+
+static constexpr size_t kHashSeed = 0;
+
+namespace compute {
+
+// ----------------------------------------------------------------------
+// KernelContext
+
+Result<std::shared_ptr<ResizableBuffer>> KernelContext::Allocate(int64_t nbytes) {
+ return AllocateResizableBuffer(nbytes, exec_ctx_->memory_pool());
+}
+
+Result<std::shared_ptr<ResizableBuffer>> KernelContext::AllocateBitmap(int64_t num_bits) {
+ const int64_t nbytes = BitUtil::BytesForBits(num_bits);
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ResizableBuffer> result,
+ AllocateResizableBuffer(nbytes, exec_ctx_->memory_pool()));
+ // Since bitmaps are typically written bit by bit, we could leak uninitialized bits.
+ // Make sure all memory is initialized (this also appeases Valgrind).
+ std::memset(result->mutable_data(), 0, result->size());
+ return result;
+}
+
+Status Kernel::InitAll(KernelContext* ctx, const KernelInitArgs& args,
+ std::vector<std::unique_ptr<KernelState>>* states) {
+ for (auto& state : *states) {
+ ARROW_ASSIGN_OR_RAISE(state, args.kernel->init(ctx, args));
+ }
+ return Status::OK();
+}
+
+Result<std::unique_ptr<KernelState>> ScalarAggregateKernel::MergeAll(
+ const ScalarAggregateKernel* kernel, KernelContext* ctx,
+ std::vector<std::unique_ptr<KernelState>> states) {
+ auto out = std::move(states.back());
+ states.pop_back();
+ ctx->SetState(out.get());
+ for (auto& state : states) {
+ RETURN_NOT_OK(kernel->merge(ctx, std::move(*state), out.get()));
+ }
+ return std::move(out);
+}
+
+// ----------------------------------------------------------------------
+// Some basic TypeMatcher implementations
+
+namespace match {
+
+class SameTypeIdMatcher : public TypeMatcher {
+ public:
+ explicit SameTypeIdMatcher(Type::type accepted_id) : accepted_id_(accepted_id) {}
+
+ bool Matches(const DataType& type) const override { return type.id() == accepted_id_; }
+
+ std::string ToString() const override {
+ std::stringstream ss;
+ ss << "Type::" << ::arrow::internal::ToString(accepted_id_);
+ return ss.str();
+ }
+
+ bool Equals(const TypeMatcher& other) const override {
+ if (this == &other) {
+ return true;
+ }
+ auto casted = dynamic_cast<const SameTypeIdMatcher*>(&other);
+ if (casted == nullptr) {
+ return false;
+ }
+ return this->accepted_id_ == casted->accepted_id_;
+ }
+
+ private:
+ Type::type accepted_id_;
+};
+
+std::shared_ptr<TypeMatcher> SameTypeId(Type::type type_id) {
+ return std::make_shared<SameTypeIdMatcher>(type_id);
+}
+
+template <typename ArrowType>
+class TimeUnitMatcher : public TypeMatcher {
+ using ThisType = TimeUnitMatcher<ArrowType>;
+
+ public:
+ explicit TimeUnitMatcher(TimeUnit::type accepted_unit)
+ : accepted_unit_(accepted_unit) {}
+
+ bool Matches(const DataType& type) const override {
+ if (type.id() != ArrowType::type_id) {
+ return false;
+ }
+ const auto& time_type = checked_cast<const ArrowType&>(type);
+ return time_type.unit() == accepted_unit_;
+ }
+
+ bool Equals(const TypeMatcher& other) const override {
+ if (this == &other) {
+ return true;
+ }
+ auto casted = dynamic_cast<const ThisType*>(&other);
+ if (casted == nullptr) {
+ return false;
+ }
+ return this->accepted_unit_ == casted->accepted_unit_;
+ }
+
+ std::string ToString() const override {
+ std::stringstream ss;
+ ss << ArrowType::type_name() << "(" << ::arrow::internal::ToString(accepted_unit_)
+ << ")";
+ return ss.str();
+ }
+
+ private:
+ TimeUnit::type accepted_unit_;
+};
+
+using DurationTypeUnitMatcher = TimeUnitMatcher<DurationType>;
+using Time32TypeUnitMatcher = TimeUnitMatcher<Time32Type>;
+using Time64TypeUnitMatcher = TimeUnitMatcher<Time64Type>;
+using TimestampTypeUnitMatcher = TimeUnitMatcher<TimestampType>;
+
+std::shared_ptr<TypeMatcher> TimestampTypeUnit(TimeUnit::type unit) {
+ return std::make_shared<TimestampTypeUnitMatcher>(unit);
+}
+
+std::shared_ptr<TypeMatcher> Time32TypeUnit(TimeUnit::type unit) {
+ return std::make_shared<Time32TypeUnitMatcher>(unit);
+}
+
+std::shared_ptr<TypeMatcher> Time64TypeUnit(TimeUnit::type unit) {
+ return std::make_shared<Time64TypeUnitMatcher>(unit);
+}
+
+std::shared_ptr<TypeMatcher> DurationTypeUnit(TimeUnit::type unit) {
+ return std::make_shared<DurationTypeUnitMatcher>(unit);
+}
+
+class IntegerMatcher : public TypeMatcher {
+ public:
+ IntegerMatcher() {}
+
+ bool Matches(const DataType& type) const override { return is_integer(type.id()); }
+
+ bool Equals(const TypeMatcher& other) const override {
+ if (this == &other) {
+ return true;
+ }
+ auto casted = dynamic_cast<const IntegerMatcher*>(&other);
+ return casted != nullptr;
+ }
+
+ std::string ToString() const override { return "integer"; }
+};
+
+std::shared_ptr<TypeMatcher> Integer() { return std::make_shared<IntegerMatcher>(); }
+
+class PrimitiveMatcher : public TypeMatcher {
+ public:
+ PrimitiveMatcher() {}
+
+ bool Matches(const DataType& type) const override { return is_primitive(type.id()); }
+
+ bool Equals(const TypeMatcher& other) const override {
+ if (this == &other) {
+ return true;
+ }
+ auto casted = dynamic_cast<const PrimitiveMatcher*>(&other);
+ return casted != nullptr;
+ }
+
+ std::string ToString() const override { return "primitive"; }
+};
+
+std::shared_ptr<TypeMatcher> Primitive() { return std::make_shared<PrimitiveMatcher>(); }
+
+class BinaryLikeMatcher : public TypeMatcher {
+ public:
+ BinaryLikeMatcher() {}
+
+ bool Matches(const DataType& type) const override { return is_binary_like(type.id()); }
+
+ bool Equals(const TypeMatcher& other) const override {
+ if (this == &other) {
+ return true;
+ }
+ auto casted = dynamic_cast<const BinaryLikeMatcher*>(&other);
+ return casted != nullptr;
+ }
+ std::string ToString() const override { return "binary-like"; }
+};
+
+std::shared_ptr<TypeMatcher> BinaryLike() {
+ return std::make_shared<BinaryLikeMatcher>();
+}
+
+class LargeBinaryLikeMatcher : public TypeMatcher {
+ public:
+ LargeBinaryLikeMatcher() {}
+
+ bool Matches(const DataType& type) const override {
+ return is_large_binary_like(type.id());
+ }
+
+ bool Equals(const TypeMatcher& other) const override {
+ if (this == &other) {
+ return true;
+ }
+ auto casted = dynamic_cast<const LargeBinaryLikeMatcher*>(&other);
+ return casted != nullptr;
+ }
+ std::string ToString() const override { return "large-binary-like"; }
+};
+
+class FixedSizeBinaryLikeMatcher : public TypeMatcher {
+ public:
+ FixedSizeBinaryLikeMatcher() {}
+
+ bool Matches(const DataType& type) const override {
+ return is_fixed_size_binary(type.id());
+ }
+
+ bool Equals(const TypeMatcher& other) const override {
+ if (this == &other) {
+ return true;
+ }
+ auto casted = dynamic_cast<const FixedSizeBinaryLikeMatcher*>(&other);
+ return casted != nullptr;
+ }
+ std::string ToString() const override { return "fixed-size-binary-like"; }
+};
+
+std::shared_ptr<TypeMatcher> LargeBinaryLike() {
+ return std::make_shared<LargeBinaryLikeMatcher>();
+}
+
+std::shared_ptr<TypeMatcher> FixedSizeBinaryLike() {
+ return std::make_shared<FixedSizeBinaryLikeMatcher>();
+}
+
+} // namespace match
+
+// ----------------------------------------------------------------------
+// InputType
+
+size_t InputType::Hash() const {
+ size_t result = kHashSeed;
+ hash_combine(result, static_cast<int>(shape_));
+ hash_combine(result, static_cast<int>(kind_));
+ switch (kind_) {
+ case InputType::EXACT_TYPE:
+ hash_combine(result, type_->Hash());
+ break;
+ default:
+ break;
+ }
+ return result;
+}
+
+std::string InputType::ToString() const {
+ std::stringstream ss;
+ switch (shape_) {
+ case ValueDescr::ANY:
+ ss << "any";
+ break;
+ case ValueDescr::ARRAY:
+ ss << "array";
+ break;
+ case ValueDescr::SCALAR:
+ ss << "scalar";
+ break;
+ default:
+ DCHECK(false);
+ break;
+ }
+ ss << "[";
+ switch (kind_) {
+ case InputType::ANY_TYPE:
+ ss << "any";
+ break;
+ case InputType::EXACT_TYPE:
+ ss << type_->ToString();
+ break;
+ case InputType::USE_TYPE_MATCHER: {
+ ss << type_matcher_->ToString();
+ } break;
+ default:
+ DCHECK(false);
+ break;
+ }
+ ss << "]";
+ return ss.str();
+}
+
+bool InputType::Equals(const InputType& other) const {
+ if (this == &other) {
+ return true;
+ }
+ if (kind_ != other.kind_ || shape_ != other.shape_) {
+ return false;
+ }
+ switch (kind_) {
+ case InputType::ANY_TYPE:
+ return true;
+ case InputType::EXACT_TYPE:
+ return type_->Equals(*other.type_);
+ case InputType::USE_TYPE_MATCHER:
+ return type_matcher_->Equals(*other.type_matcher_);
+ default:
+ return false;
+ }
+}
+
+bool InputType::Matches(const ValueDescr& descr) const {
+ if (shape_ != ValueDescr::ANY && descr.shape != shape_) {
+ return false;
+ }
+ switch (kind_) {
+ case InputType::EXACT_TYPE:
+ return type_->Equals(*descr.type);
+ case InputType::USE_TYPE_MATCHER:
+ return type_matcher_->Matches(*descr.type);
+ default:
+ // ANY_TYPE
+ return true;
+ }
+}
+
+bool InputType::Matches(const Datum& value) const { return Matches(value.descr()); }
+
+const std::shared_ptr<DataType>& InputType::type() const {
+ DCHECK_EQ(InputType::EXACT_TYPE, kind_);
+ return type_;
+}
+
+const TypeMatcher& InputType::type_matcher() const {
+ DCHECK_EQ(InputType::USE_TYPE_MATCHER, kind_);
+ return *type_matcher_;
+}
+
+// ----------------------------------------------------------------------
+// OutputType
+
+OutputType::OutputType(ValueDescr descr) : OutputType(descr.type) {
+ shape_ = descr.shape;
+}
+
+Result<ValueDescr> OutputType::Resolve(KernelContext* ctx,
+ const std::vector<ValueDescr>& args) const {
+ ValueDescr::Shape broadcasted_shape = GetBroadcastShape(args);
+ if (kind_ == OutputType::FIXED) {
+ return ValueDescr(type_, shape_ == ValueDescr::ANY ? broadcasted_shape : shape_);
+ } else {
+ ARROW_ASSIGN_OR_RAISE(ValueDescr resolved_descr, resolver_(ctx, args));
+ if (resolved_descr.shape == ValueDescr::ANY) {
+ resolved_descr.shape = broadcasted_shape;
+ }
+ return resolved_descr;
+ }
+}
+
+const std::shared_ptr<DataType>& OutputType::type() const {
+ DCHECK_EQ(FIXED, kind_);
+ return type_;
+}
+
+const OutputType::Resolver& OutputType::resolver() const {
+ DCHECK_EQ(COMPUTED, kind_);
+ return resolver_;
+}
+
+std::string OutputType::ToString() const {
+ if (kind_ == OutputType::FIXED) {
+ return type_->ToString();
+ } else {
+ return "computed";
+ }
+}
+
+// ----------------------------------------------------------------------
+// KernelSignature
+
+KernelSignature::KernelSignature(std::vector<InputType> in_types, OutputType out_type,
+ bool is_varargs)
+ : in_types_(std::move(in_types)),
+ out_type_(std::move(out_type)),
+ is_varargs_(is_varargs),
+ hash_code_(0) {
+ DCHECK(!is_varargs || (is_varargs && (in_types_.size() >= 1)));
+}
+
+std::shared_ptr<KernelSignature> KernelSignature::Make(std::vector<InputType> in_types,
+ OutputType out_type,
+ bool is_varargs) {
+ return std::make_shared<KernelSignature>(std::move(in_types), std::move(out_type),
+ is_varargs);
+}
+
+bool KernelSignature::Equals(const KernelSignature& other) const {
+ if (is_varargs_ != other.is_varargs_) {
+ return false;
+ }
+ if (in_types_.size() != other.in_types_.size()) {
+ return false;
+ }
+ for (size_t i = 0; i < in_types_.size(); ++i) {
+ if (!in_types_[i].Equals(other.in_types_[i])) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool KernelSignature::MatchesInputs(const std::vector<ValueDescr>& args) const {
+ if (is_varargs_) {
+ for (size_t i = 0; i < args.size(); ++i) {
+ if (!in_types_[std::min(i, in_types_.size() - 1)].Matches(args[i])) {
+ return false;
+ }
+ }
+ } else {
+ if (args.size() != in_types_.size()) {
+ return false;
+ }
+ for (size_t i = 0; i < in_types_.size(); ++i) {
+ if (!in_types_[i].Matches(args[i])) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+size_t KernelSignature::Hash() const {
+ if (hash_code_ != 0) {
+ return hash_code_;
+ }
+ size_t result = kHashSeed;
+ for (const auto& in_type : in_types_) {
+ hash_combine(result, in_type.Hash());
+ }
+ hash_code_ = result;
+ return result;
+}
+
+std::string KernelSignature::ToString() const {
+ std::stringstream ss;
+
+ if (is_varargs_) {
+ ss << "varargs[";
+ } else {
+ ss << "(";
+ }
+ for (size_t i = 0; i < in_types_.size(); ++i) {
+ if (i > 0) {
+ ss << ", ";
+ }
+ ss << in_types_[i].ToString();
+ }
+ if (is_varargs_) {
+ ss << "]";
+ } else {
+ ss << ")";
+ }
+ ss << " -> " << out_type_.ToString();
+ return ss.str();
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernel.h b/src/arrow/cpp/src/arrow/compute/kernel.h
new file mode 100644
index 000000000..27fb83163
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernel.h
@@ -0,0 +1,752 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// NOTE: API is EXPERIMENTAL and will change without going through a
+// deprecation cycle
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/compute/exec.h"
+#include "arrow/datum.h"
+#include "arrow/memory_pool.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace compute {
+
+class FunctionOptions;
+
+/// \brief Base class for opaque kernel-specific state. For example, if there
+/// is some kind of initialization required.
+struct ARROW_EXPORT KernelState {
+ virtual ~KernelState() = default;
+};
+
+/// \brief Context/state for the execution of a particular kernel.
+class ARROW_EXPORT KernelContext {
+ public:
+ explicit KernelContext(ExecContext* exec_ctx) : exec_ctx_(exec_ctx) {}
+
+ /// \brief Allocate buffer from the context's memory pool. The contents are
+ /// not initialized.
+ Result<std::shared_ptr<ResizableBuffer>> Allocate(int64_t nbytes);
+
+ /// \brief Allocate buffer for bitmap from the context's memory pool. Like
+ /// Allocate, the contents of the buffer are not initialized but the last
+ /// byte is preemptively zeroed to help avoid ASAN or valgrind issues.
+ Result<std::shared_ptr<ResizableBuffer>> AllocateBitmap(int64_t num_bits);
+
+ /// \brief Assign the active KernelState to be utilized for each stage of
+ /// kernel execution. Ownership and memory lifetime of the KernelState must
+ /// be minded separately.
+ void SetState(KernelState* state) { state_ = state; }
+
+ KernelState* state() { return state_; }
+
+ /// \brief Configuration related to function execution that is to be shared
+ /// across multiple kernels.
+ ExecContext* exec_context() { return exec_ctx_; }
+
+ /// \brief The memory pool to use for allocations. For now, it uses the
+ /// MemoryPool contained in the ExecContext used to create the KernelContext.
+ MemoryPool* memory_pool() { return exec_ctx_->memory_pool(); }
+
+ private:
+ ExecContext* exec_ctx_;
+ KernelState* state_ = NULLPTR;
+};
+
+/// \brief The standard kernel execution API that must be implemented for
+/// SCALAR and VECTOR kernel types. This includes both stateless and stateful
+/// kernels. Kernels depending on some execution state access that state via
+/// subclasses of KernelState set on the KernelContext object. May be used for
+/// SCALAR and VECTOR kernel kinds. Implementations should endeavor to write
+/// into pre-allocated memory if they are able, though for some kernels
+/// (e.g. in cases when a builder like StringBuilder) must be employed this may
+/// not be possible.
+using ArrayKernelExec = std::function<Status(KernelContext*, const ExecBatch&, Datum*)>;
+
+/// \brief An type-checking interface to permit customizable validation rules
+/// for use with InputType and KernelSignature. This is for scenarios where the
+/// acceptance is not an exact type instance, such as a TIMESTAMP type for a
+/// specific TimeUnit, but permitting any time zone.
+struct ARROW_EXPORT TypeMatcher {
+ virtual ~TypeMatcher() = default;
+
+ /// \brief Return true if this matcher accepts the data type.
+ virtual bool Matches(const DataType& type) const = 0;
+
+ /// \brief A human-interpretable string representation of what the type
+ /// matcher checks for, usable when printing KernelSignature or formatting
+ /// error messages.
+ virtual std::string ToString() const = 0;
+
+ /// \brief Return true if this TypeMatcher contains the same matching rule as
+ /// the other. Currently depends on RTTI.
+ virtual bool Equals(const TypeMatcher& other) const = 0;
+};
+
+namespace match {
+
+/// \brief Match any DataType instance having the same DataType::id.
+ARROW_EXPORT std::shared_ptr<TypeMatcher> SameTypeId(Type::type type_id);
+
+/// \brief Match any TimestampType instance having the same unit, but the time
+/// zones can be different.
+ARROW_EXPORT std::shared_ptr<TypeMatcher> TimestampTypeUnit(TimeUnit::type unit);
+ARROW_EXPORT std::shared_ptr<TypeMatcher> Time32TypeUnit(TimeUnit::type unit);
+ARROW_EXPORT std::shared_ptr<TypeMatcher> Time64TypeUnit(TimeUnit::type unit);
+ARROW_EXPORT std::shared_ptr<TypeMatcher> DurationTypeUnit(TimeUnit::type unit);
+
+// \brief Match any integer type
+ARROW_EXPORT std::shared_ptr<TypeMatcher> Integer();
+
+// Match types using 32-bit varbinary representation
+ARROW_EXPORT std::shared_ptr<TypeMatcher> BinaryLike();
+
+// Match types using 64-bit varbinary representation
+ARROW_EXPORT std::shared_ptr<TypeMatcher> LargeBinaryLike();
+
+// Match any fixed binary type
+ARROW_EXPORT std::shared_ptr<TypeMatcher> FixedSizeBinaryLike();
+
+// \brief Match any primitive type (boolean or any type representable as a C
+// Type)
+ARROW_EXPORT std::shared_ptr<TypeMatcher> Primitive();
+
+} // namespace match
+
+/// \brief An object used for type- and shape-checking arguments to be passed
+/// to a kernel and stored in a KernelSignature. Distinguishes between ARRAY
+/// and SCALAR arguments using ValueDescr::Shape. The type-checking rule can be
+/// supplied either with an exact DataType instance or a custom TypeMatcher.
+class ARROW_EXPORT InputType {
+ public:
+ /// \brief The kind of type-checking rule that the InputType contains.
+ enum Kind {
+ /// \brief Accept any value type.
+ ANY_TYPE,
+
+ /// \brief A fixed arrow::DataType and will only exact match having this
+ /// exact type (e.g. same TimestampType unit, same decimal scale and
+ /// precision, or same nested child types).
+ EXACT_TYPE,
+
+ /// \brief Uses a TypeMatcher implementation to check the type.
+ USE_TYPE_MATCHER
+ };
+
+ /// \brief Accept any value type but with a specific shape (e.g. any Array or
+ /// any Scalar).
+ InputType(ValueDescr::Shape shape = ValueDescr::ANY) // NOLINT implicit construction
+ : kind_(ANY_TYPE), shape_(shape) {}
+
+ /// \brief Accept an exact value type.
+ InputType(std::shared_ptr<DataType> type, // NOLINT implicit construction
+ ValueDescr::Shape shape = ValueDescr::ANY)
+ : kind_(EXACT_TYPE), shape_(shape), type_(std::move(type)) {}
+
+ /// \brief Accept an exact value type and shape provided by a ValueDescr.
+ InputType(const ValueDescr& descr) // NOLINT implicit construction
+ : InputType(descr.type, descr.shape) {}
+
+ /// \brief Use the passed TypeMatcher to type check.
+ InputType(std::shared_ptr<TypeMatcher> type_matcher, // NOLINT implicit construction
+ ValueDescr::Shape shape = ValueDescr::ANY)
+ : kind_(USE_TYPE_MATCHER), shape_(shape), type_matcher_(std::move(type_matcher)) {}
+
+ /// \brief Match any type with the given Type::type. Uses a TypeMatcher for
+ /// its implementation.
+ explicit InputType(Type::type type_id, ValueDescr::Shape shape = ValueDescr::ANY)
+ : InputType(match::SameTypeId(type_id), shape) {}
+
+ InputType(const InputType& other) { CopyInto(other); }
+
+ void operator=(const InputType& other) { CopyInto(other); }
+
+ InputType(InputType&& other) { MoveInto(std::forward<InputType>(other)); }
+
+ void operator=(InputType&& other) { MoveInto(std::forward<InputType>(other)); }
+
+ // \brief Match an array with the given exact type. Convenience constructor.
+ static InputType Array(std::shared_ptr<DataType> type) {
+ return InputType(std::move(type), ValueDescr::ARRAY);
+ }
+
+ // \brief Match a scalar with the given exact type. Convenience constructor.
+ static InputType Scalar(std::shared_ptr<DataType> type) {
+ return InputType(std::move(type), ValueDescr::SCALAR);
+ }
+
+ // \brief Match an array with the given Type::type id. Convenience
+ // constructor.
+ static InputType Array(Type::type id) { return InputType(id, ValueDescr::ARRAY); }
+
+ // \brief Match a scalar with the given Type::type id. Convenience
+ // constructor.
+ static InputType Scalar(Type::type id) { return InputType(id, ValueDescr::SCALAR); }
+
+ /// \brief Return true if this input type matches the same type cases as the
+ /// other.
+ bool Equals(const InputType& other) const;
+
+ bool operator==(const InputType& other) const { return this->Equals(other); }
+
+ bool operator!=(const InputType& other) const { return !(*this == other); }
+
+ /// \brief Return hash code.
+ size_t Hash() const;
+
+ /// \brief Render a human-readable string representation.
+ std::string ToString() const;
+
+ /// \brief Return true if the value matches this argument kind in type
+ /// and shape.
+ bool Matches(const Datum& value) const;
+
+ /// \brief Return true if the value descriptor matches this argument kind in
+ /// type and shape.
+ bool Matches(const ValueDescr& value) const;
+
+ /// \brief The type matching rule that this InputType uses.
+ Kind kind() const { return kind_; }
+
+ /// \brief Indicates whether this InputType matches Array (ValueDescr::ARRAY),
+ /// Scalar (ValueDescr::SCALAR) values, or both (ValueDescr::ANY).
+ ValueDescr::Shape shape() const { return shape_; }
+
+ /// \brief For InputType::EXACT_TYPE kind, the exact type that this InputType
+ /// must match. Otherwise this function should not be used and will assert in
+ /// debug builds.
+ const std::shared_ptr<DataType>& type() const;
+
+ /// \brief For InputType::USE_TYPE_MATCHER, the TypeMatcher to be used for
+ /// checking the type of a value. Otherwise this function should not be used
+ /// and will assert in debug builds.
+ const TypeMatcher& type_matcher() const;
+
+ private:
+ void CopyInto(const InputType& other) {
+ this->kind_ = other.kind_;
+ this->shape_ = other.shape_;
+ this->type_ = other.type_;
+ this->type_matcher_ = other.type_matcher_;
+ }
+
+ void MoveInto(InputType&& other) {
+ this->kind_ = other.kind_;
+ this->shape_ = other.shape_;
+ this->type_ = std::move(other.type_);
+ this->type_matcher_ = std::move(other.type_matcher_);
+ }
+
+ Kind kind_;
+
+ ValueDescr::Shape shape_ = ValueDescr::ANY;
+
+ // For EXACT_TYPE Kind
+ std::shared_ptr<DataType> type_;
+
+ // For USE_TYPE_MATCHER Kind
+ std::shared_ptr<TypeMatcher> type_matcher_;
+};
+
+/// \brief Container to capture both exact and input-dependent output types.
+///
+/// The value shape returned by Resolve will be determined by broadcasting the
+/// shapes of the input arguments, otherwise this is handled by the
+/// user-defined resolver function:
+///
+/// * Any ARRAY shape -> output shape is ARRAY
+/// * All SCALAR shapes -> output shape is SCALAR
+class ARROW_EXPORT OutputType {
+ public:
+ /// \brief An enum indicating whether the value type is an invariant fixed
+ /// value or one that's computed by a kernel-defined resolver function.
+ enum ResolveKind { FIXED, COMPUTED };
+
+ /// Type resolution function. Given input types and shapes, return output
+ /// type and shape. This function MAY may use the kernel state to decide
+ /// the output type based on the functionoptions.
+ ///
+ /// This function SHOULD _not_ be used to check for arity, that is to be
+ /// performed one or more layers above.
+ using Resolver =
+ std::function<Result<ValueDescr>(KernelContext*, const std::vector<ValueDescr>&)>;
+
+ /// \brief Output an exact type, but with shape determined by promoting the
+ /// shapes of the inputs (any ARRAY argument yields ARRAY).
+ OutputType(std::shared_ptr<DataType> type) // NOLINT implicit construction
+ : kind_(FIXED), type_(std::move(type)) {}
+
+ /// \brief Output the exact type and shape provided by a ValueDescr
+ OutputType(ValueDescr descr); // NOLINT implicit construction
+
+ /// \brief Output a computed type depending on actual input types
+ OutputType(Resolver resolver) // NOLINT implicit construction
+ : kind_(COMPUTED), resolver_(std::move(resolver)) {}
+
+ OutputType(const OutputType& other) {
+ this->kind_ = other.kind_;
+ this->shape_ = other.shape_;
+ this->type_ = other.type_;
+ this->resolver_ = other.resolver_;
+ }
+
+ OutputType(OutputType&& other) {
+ this->kind_ = other.kind_;
+ this->type_ = std::move(other.type_);
+ this->shape_ = other.shape_;
+ this->resolver_ = other.resolver_;
+ }
+
+ OutputType& operator=(const OutputType&) = default;
+ OutputType& operator=(OutputType&&) = default;
+
+ /// \brief Return the shape and type of the expected output value of the
+ /// kernel given the value descriptors (shapes and types) of the input
+ /// arguments. The resolver may make use of state information kept in the
+ /// KernelContext.
+ Result<ValueDescr> Resolve(KernelContext* ctx,
+ const std::vector<ValueDescr>& args) const;
+
+ /// \brief The exact output value type for the FIXED kind.
+ const std::shared_ptr<DataType>& type() const;
+
+ /// \brief For use with COMPUTED resolution strategy. It may be more
+ /// convenient to invoke this with OutputType::Resolve returned from this
+ /// method.
+ const Resolver& resolver() const;
+
+ /// \brief Render a human-readable string representation.
+ std::string ToString() const;
+
+ /// \brief Return the kind of type resolution of this output type, whether
+ /// fixed/invariant or computed by a resolver.
+ ResolveKind kind() const { return kind_; }
+
+ /// \brief If the shape is ANY, then Resolve will compute the shape based on
+ /// the input arguments.
+ ValueDescr::Shape shape() const { return shape_; }
+
+ private:
+ ResolveKind kind_;
+
+ // For FIXED resolution
+ std::shared_ptr<DataType> type_;
+
+ /// \brief The shape of the output type to return when using Resolve. If ANY
+ /// will promote the input shapes.
+ ValueDescr::Shape shape_ = ValueDescr::ANY;
+
+ // For COMPUTED resolution
+ Resolver resolver_;
+};
+
+/// \brief Holds the input types and output type of the kernel.
+///
+/// VarArgs functions with minimum N arguments should pass up to N input types to be
+/// used to validate the input types of a function invocation. The first N-1 types
+/// will be matched against the first N-1 arguments, and the last type will be
+/// matched against the remaining arguments.
+class ARROW_EXPORT KernelSignature {
+ public:
+ KernelSignature(std::vector<InputType> in_types, OutputType out_type,
+ bool is_varargs = false);
+
+ /// \brief Convenience ctor since make_shared can be awkward
+ static std::shared_ptr<KernelSignature> Make(std::vector<InputType> in_types,
+ OutputType out_type,
+ bool is_varargs = false);
+
+ /// \brief Return true if the signature if compatible with the list of input
+ /// value descriptors.
+ bool MatchesInputs(const std::vector<ValueDescr>& descriptors) const;
+
+ /// \brief Returns true if the input types of each signature are
+ /// equal. Well-formed functions should have a deterministic output type
+ /// given input types, but currently it is the responsibility of the
+ /// developer to ensure this.
+ bool Equals(const KernelSignature& other) const;
+
+ bool operator==(const KernelSignature& other) const { return this->Equals(other); }
+
+ bool operator!=(const KernelSignature& other) const { return !(*this == other); }
+
+ /// \brief Compute a hash code for the signature
+ size_t Hash() const;
+
+ /// \brief The input types for the kernel. For VarArgs functions, this should
+ /// generally contain a single validator to use for validating all of the
+ /// function arguments.
+ const std::vector<InputType>& in_types() const { return in_types_; }
+
+ /// \brief The output type for the kernel. Use Resolve to return the exact
+ /// output given input argument ValueDescrs, since many kernels' output types
+ /// depend on their input types (or their type metadata).
+ const OutputType& out_type() const { return out_type_; }
+
+ /// \brief Render a human-readable string representation
+ std::string ToString() const;
+
+ bool is_varargs() const { return is_varargs_; }
+
+ private:
+ std::vector<InputType> in_types_;
+ OutputType out_type_;
+ bool is_varargs_;
+
+ // For caching the hash code after it's computed the first time
+ mutable uint64_t hash_code_;
+};
+
+/// \brief A function may contain multiple variants of a kernel for a given
+/// type combination for different SIMD levels. Based on the active system's
+/// CPU info or the user's preferences, we can elect to use one over the other.
+struct SimdLevel {
+ enum type { NONE = 0, SSE4_2, AVX, AVX2, AVX512, NEON, MAX };
+};
+
+/// \brief The strategy to use for propagating or otherwise populating the
+/// validity bitmap of a kernel output.
+struct NullHandling {
+ enum type {
+ /// Compute the output validity bitmap by intersecting the validity bitmaps
+ /// of the arguments using bitwise-and operations. This means that values
+ /// in the output are valid/non-null only if the corresponding values in
+ /// all input arguments were valid/non-null. Kernel generally need not
+ /// touch the bitmap thereafter, but a kernel's exec function is permitted
+ /// to alter the bitmap after the null intersection is computed if it needs
+ /// to.
+ INTERSECTION,
+
+ /// Kernel expects a pre-allocated buffer to write the result bitmap
+ /// into. The preallocated memory is not zeroed (except for the last byte),
+ /// so the kernel should ensure to completely populate the bitmap.
+ COMPUTED_PREALLOCATE,
+
+ /// Kernel allocates and sets the validity bitmap of the output.
+ COMPUTED_NO_PREALLOCATE,
+
+ /// Kernel output is never null and a validity bitmap does not need to be
+ /// allocated.
+ OUTPUT_NOT_NULL
+ };
+};
+
+/// \brief The preference for memory preallocation of fixed-width type outputs
+/// in kernel execution.
+struct MemAllocation {
+ enum type {
+ // For data types that support pre-allocation (i.e. fixed-width), the
+ // kernel expects to be provided a pre-allocated data buffer to write
+ // into. Non-fixed-width types must always allocate their own data
+ // buffers. The allocation made for the same length as the execution batch,
+ // so vector kernels yielding differently sized output should not use this.
+ //
+ // It is valid for the data to not be preallocated but the validity bitmap
+ // is (or is computed using the intersection/bitwise-and method).
+ //
+ // For variable-size output types like BinaryType or StringType, or for
+ // nested types, this option has no effect.
+ PREALLOCATE,
+
+ // The kernel is responsible for allocating its own data buffer for
+ // fixed-width type outputs.
+ NO_PREALLOCATE
+ };
+};
+
+struct Kernel;
+
+/// \brief Arguments to pass to a KernelInit function. A struct is used to help
+/// avoid API breakage should the arguments passed need to be expanded.
+struct KernelInitArgs {
+ /// \brief A pointer to the kernel being initialized. The init function may
+ /// depend on the kernel's KernelSignature or other data contained there.
+ const Kernel* kernel;
+
+ /// \brief The types and shapes of the input arguments that the kernel is
+ /// about to be executed against.
+ ///
+ /// TODO: should this be const std::vector<ValueDescr>*? const-ref is being
+ /// used to avoid the cost of copying the struct into the args struct.
+ const std::vector<ValueDescr>& inputs;
+
+ /// \brief Opaque options specific to this kernel. May be nullptr for functions
+ /// that do not require options.
+ const FunctionOptions* options;
+};
+
+/// \brief Common initializer function for all kernel types.
+using KernelInit = std::function<Result<std::unique_ptr<KernelState>>(
+ KernelContext*, const KernelInitArgs&)>;
+
+/// \brief Base type for kernels. Contains the function signature and
+/// optionally the state initialization function, along with some common
+/// attributes
+struct Kernel {
+ Kernel() = default;
+
+ Kernel(std::shared_ptr<KernelSignature> sig, KernelInit init)
+ : signature(std::move(sig)), init(std::move(init)) {}
+
+ Kernel(std::vector<InputType> in_types, OutputType out_type, KernelInit init)
+ : Kernel(KernelSignature::Make(std::move(in_types), std::move(out_type)),
+ std::move(init)) {}
+
+ /// \brief The "signature" of the kernel containing the InputType input
+ /// argument validators and OutputType output type and shape resolver.
+ std::shared_ptr<KernelSignature> signature;
+
+ /// \brief Create a new KernelState for invocations of this kernel, e.g. to
+ /// set up any options or state relevant for execution.
+ KernelInit init;
+
+ /// \brief Create a vector of new KernelState for invocations of this kernel.
+ static Status InitAll(KernelContext*, const KernelInitArgs&,
+ std::vector<std::unique_ptr<KernelState>>*);
+
+ /// \brief Indicates whether execution can benefit from parallelization
+ /// (splitting large chunks into smaller chunks and using multiple
+ /// threads). Some kernels may not support parallel execution at
+ /// all. Synchronization and concurrency-related issues are currently the
+ /// responsibility of the Kernel's implementation.
+ bool parallelizable = true;
+
+ /// \brief Indicates the level of SIMD instruction support in the host CPU is
+ /// required to use the function. The intention is for functions to be able to
+ /// contain multiple kernels with the same signature but different levels of SIMD,
+ /// so that the most optimized kernel supported on a host's processor can be chosen.
+ SimdLevel::type simd_level = SimdLevel::NONE;
+};
+
+/// \brief Common kernel base data structure for ScalarKernel and
+/// VectorKernel. It is called "ArrayKernel" in that the functions generally
+/// output array values (as opposed to scalar values in the case of aggregate
+/// functions).
+struct ArrayKernel : public Kernel {
+ ArrayKernel() = default;
+
+ ArrayKernel(std::shared_ptr<KernelSignature> sig, ArrayKernelExec exec,
+ KernelInit init = NULLPTR)
+ : Kernel(std::move(sig), init), exec(std::move(exec)) {}
+
+ ArrayKernel(std::vector<InputType> in_types, OutputType out_type, ArrayKernelExec exec,
+ KernelInit init = NULLPTR)
+ : Kernel(std::move(in_types), std::move(out_type), std::move(init)),
+ exec(std::move(exec)) {}
+
+ /// \brief Perform a single invocation of this kernel. Depending on the
+ /// implementation, it may only write into preallocated memory, while in some
+ /// cases it will allocate its own memory. Any required state is managed
+ /// through the KernelContext.
+ ArrayKernelExec exec;
+
+ /// \brief Writing execution results into larger contiguous allocations
+ /// requires that the kernel be able to write into sliced output ArrayData*,
+ /// including sliced output validity bitmaps. Some kernel implementations may
+ /// not be able to do this, so setting this to false disables this
+ /// functionality.
+ bool can_write_into_slices = true;
+};
+
+/// \brief Kernel data structure for implementations of ScalarFunction. In
+/// addition to the members found in ArrayKernel, contains the null handling
+/// and memory pre-allocation preferences.
+struct ScalarKernel : public ArrayKernel {
+ using ArrayKernel::ArrayKernel;
+
+ // For scalar functions preallocated data and intersecting arg validity
+ // bitmaps is a reasonable default
+ NullHandling::type null_handling = NullHandling::INTERSECTION;
+ MemAllocation::type mem_allocation = MemAllocation::PREALLOCATE;
+};
+
+// ----------------------------------------------------------------------
+// VectorKernel (for VectorFunction)
+
+/// \brief See VectorKernel::finalize member for usage
+using VectorFinalize = std::function<Status(KernelContext*, std::vector<Datum>*)>;
+
+/// \brief Kernel data structure for implementations of VectorFunction. In
+/// addition to the members found in ArrayKernel, contains an optional
+/// finalizer function, the null handling and memory pre-allocation preferences
+/// (which have different defaults from ScalarKernel), and some other
+/// execution-related options.
+struct VectorKernel : public ArrayKernel {
+ VectorKernel() = default;
+
+ VectorKernel(std::shared_ptr<KernelSignature> sig, ArrayKernelExec exec)
+ : ArrayKernel(std::move(sig), std::move(exec)) {}
+
+ VectorKernel(std::vector<InputType> in_types, OutputType out_type, ArrayKernelExec exec,
+ KernelInit init = NULLPTR, VectorFinalize finalize = NULLPTR)
+ : ArrayKernel(std::move(in_types), std::move(out_type), std::move(exec),
+ std::move(init)),
+ finalize(std::move(finalize)) {}
+
+ VectorKernel(std::shared_ptr<KernelSignature> sig, ArrayKernelExec exec,
+ KernelInit init = NULLPTR, VectorFinalize finalize = NULLPTR)
+ : ArrayKernel(std::move(sig), std::move(exec), std::move(init)),
+ finalize(std::move(finalize)) {}
+
+ /// \brief For VectorKernel, convert intermediate results into finalized
+ /// results. Mutates input argument. Some kernels may accumulate state
+ /// (example: hashing-related functions) through processing chunked inputs, and
+ /// then need to attach some accumulated state to each of the outputs of
+ /// processing each chunk of data.
+ VectorFinalize finalize;
+
+ /// Since vector kernels generally are implemented rather differently from
+ /// scalar/elementwise kernels (and they may not even yield arrays of the same
+ /// size), so we make the developer opt-in to any memory preallocation rather
+ /// than having to turn it off.
+ NullHandling::type null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ MemAllocation::type mem_allocation = MemAllocation::NO_PREALLOCATE;
+
+ /// Some vector kernels can do chunkwise execution using ExecBatchIterator,
+ /// in some cases accumulating some state. Other kernels (like Take) need to
+ /// be passed whole arrays and don't work on ChunkedArray inputs
+ bool can_execute_chunkwise = true;
+
+ /// Some kernels (like unique and value_counts) yield non-chunked output from
+ /// chunked-array inputs. This option controls how the results are boxed when
+ /// returned from ExecVectorFunction
+ ///
+ /// true -> ChunkedArray
+ /// false -> Array
+ bool output_chunked = true;
+};
+
+// ----------------------------------------------------------------------
+// ScalarAggregateKernel (for ScalarAggregateFunction)
+
+using ScalarAggregateConsume = std::function<Status(KernelContext*, const ExecBatch&)>;
+
+using ScalarAggregateMerge =
+ std::function<Status(KernelContext*, KernelState&&, KernelState*)>;
+
+// Finalize returns Datum to permit multiple return values
+using ScalarAggregateFinalize = std::function<Status(KernelContext*, Datum*)>;
+
+/// \brief Kernel data structure for implementations of
+/// ScalarAggregateFunction. The four necessary components of an aggregation
+/// kernel are the init, consume, merge, and finalize functions.
+///
+/// * init: creates a new KernelState for a kernel.
+/// * consume: processes an ExecBatch and updates the KernelState found in the
+/// KernelContext.
+/// * merge: combines one KernelState with another.
+/// * finalize: produces the end result of the aggregation using the
+/// KernelState in the KernelContext.
+struct ScalarAggregateKernel : public Kernel {
+ ScalarAggregateKernel() = default;
+
+ ScalarAggregateKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
+ ScalarAggregateConsume consume, ScalarAggregateMerge merge,
+ ScalarAggregateFinalize finalize)
+ : Kernel(std::move(sig), std::move(init)),
+ consume(std::move(consume)),
+ merge(std::move(merge)),
+ finalize(std::move(finalize)) {}
+
+ ScalarAggregateKernel(std::vector<InputType> in_types, OutputType out_type,
+ KernelInit init, ScalarAggregateConsume consume,
+ ScalarAggregateMerge merge, ScalarAggregateFinalize finalize)
+ : ScalarAggregateKernel(
+ KernelSignature::Make(std::move(in_types), std::move(out_type)),
+ std::move(init), std::move(consume), std::move(merge), std::move(finalize)) {}
+
+ /// \brief Merge a vector of KernelStates into a single KernelState.
+ /// The merged state will be returned and will be set on the KernelContext.
+ static Result<std::unique_ptr<KernelState>> MergeAll(
+ const ScalarAggregateKernel* kernel, KernelContext* ctx,
+ std::vector<std::unique_ptr<KernelState>> states);
+
+ ScalarAggregateConsume consume;
+ ScalarAggregateMerge merge;
+ ScalarAggregateFinalize finalize;
+};
+
+// ----------------------------------------------------------------------
+// HashAggregateKernel (for HashAggregateFunction)
+
+using HashAggregateResize = std::function<Status(KernelContext*, int64_t)>;
+
+using HashAggregateConsume = std::function<Status(KernelContext*, const ExecBatch&)>;
+
+using HashAggregateMerge =
+ std::function<Status(KernelContext*, KernelState&&, const ArrayData&)>;
+
+// Finalize returns Datum to permit multiple return values
+using HashAggregateFinalize = std::function<Status(KernelContext*, Datum*)>;
+
+/// \brief Kernel data structure for implementations of
+/// HashAggregateFunction. The four necessary components of an aggregation
+/// kernel are the init, consume, merge, and finalize functions.
+///
+/// * init: creates a new KernelState for a kernel.
+/// * resize: ensure that the KernelState can accommodate the specified number of groups.
+/// * consume: processes an ExecBatch (which includes the argument as well
+/// as an array of group identifiers) and updates the KernelState found in the
+/// KernelContext.
+/// * merge: combines one KernelState with another.
+/// * finalize: produces the end result of the aggregation using the
+/// KernelState in the KernelContext.
+struct HashAggregateKernel : public Kernel {
+ HashAggregateKernel() = default;
+
+ HashAggregateKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
+ HashAggregateResize resize, HashAggregateConsume consume,
+ HashAggregateMerge merge, HashAggregateFinalize finalize)
+ : Kernel(std::move(sig), std::move(init)),
+ resize(std::move(resize)),
+ consume(std::move(consume)),
+ merge(std::move(merge)),
+ finalize(std::move(finalize)) {}
+
+ HashAggregateKernel(std::vector<InputType> in_types, OutputType out_type,
+ KernelInit init, HashAggregateConsume consume,
+ HashAggregateResize resize, HashAggregateMerge merge,
+ HashAggregateFinalize finalize)
+ : HashAggregateKernel(
+ KernelSignature::Make(std::move(in_types), std::move(out_type)),
+ std::move(init), std::move(resize), std::move(consume), std::move(merge),
+ std::move(finalize)) {}
+
+ HashAggregateResize resize;
+ HashAggregateConsume consume;
+ HashAggregateMerge merge;
+ HashAggregateFinalize finalize;
+};
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernel_test.cc b/src/arrow/cpp/src/arrow/compute/kernel_test.cc
new file mode 100644
index 000000000..a63c42d4f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernel_test.cc
@@ -0,0 +1,516 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/compute/kernel.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+namespace compute {
+
+// ----------------------------------------------------------------------
+// TypeMatcher
+
+TEST(TypeMatcher, SameTypeId) {
+ std::shared_ptr<TypeMatcher> matcher = match::SameTypeId(Type::DECIMAL);
+ ASSERT_TRUE(matcher->Matches(*decimal(12, 2)));
+ ASSERT_FALSE(matcher->Matches(*int8()));
+
+ ASSERT_EQ("Type::DECIMAL128", matcher->ToString());
+
+ ASSERT_TRUE(matcher->Equals(*matcher));
+ ASSERT_TRUE(matcher->Equals(*match::SameTypeId(Type::DECIMAL)));
+ ASSERT_FALSE(matcher->Equals(*match::SameTypeId(Type::TIMESTAMP)));
+}
+
+TEST(TypeMatcher, TimestampTypeUnit) {
+ auto matcher = match::TimestampTypeUnit(TimeUnit::MILLI);
+ auto matcher2 = match::Time32TypeUnit(TimeUnit::MILLI);
+
+ ASSERT_TRUE(matcher->Matches(*timestamp(TimeUnit::MILLI)));
+ ASSERT_TRUE(matcher->Matches(*timestamp(TimeUnit::MILLI, "utc")));
+ ASSERT_FALSE(matcher->Matches(*timestamp(TimeUnit::SECOND)));
+ ASSERT_FALSE(matcher->Matches(*time32(TimeUnit::MILLI)));
+ ASSERT_TRUE(matcher2->Matches(*time32(TimeUnit::MILLI)));
+
+ // Check ToString representation
+ ASSERT_EQ("timestamp(s)", match::TimestampTypeUnit(TimeUnit::SECOND)->ToString());
+ ASSERT_EQ("timestamp(ms)", match::TimestampTypeUnit(TimeUnit::MILLI)->ToString());
+ ASSERT_EQ("timestamp(us)", match::TimestampTypeUnit(TimeUnit::MICRO)->ToString());
+ ASSERT_EQ("timestamp(ns)", match::TimestampTypeUnit(TimeUnit::NANO)->ToString());
+
+ // Equals implementation
+ ASSERT_TRUE(matcher->Equals(*matcher));
+ ASSERT_TRUE(matcher->Equals(*match::TimestampTypeUnit(TimeUnit::MILLI)));
+ ASSERT_FALSE(matcher->Equals(*match::TimestampTypeUnit(TimeUnit::MICRO)));
+ ASSERT_FALSE(matcher->Equals(*match::Time32TypeUnit(TimeUnit::MILLI)));
+}
+
+// ----------------------------------------------------------------------
+// InputType
+
+TEST(InputType, AnyTypeConstructor) {
+ // Check the ANY_TYPE ctors
+ InputType ty;
+ ASSERT_EQ(InputType::ANY_TYPE, ty.kind());
+ ASSERT_EQ(ValueDescr::ANY, ty.shape());
+
+ ty = InputType(ValueDescr::SCALAR);
+ ASSERT_EQ(ValueDescr::SCALAR, ty.shape());
+
+ ty = InputType(ValueDescr::ARRAY);
+ ASSERT_EQ(ValueDescr::ARRAY, ty.shape());
+}
+
+TEST(InputType, Constructors) {
+ // Exact type constructor
+ InputType ty1(int8());
+ ASSERT_EQ(InputType::EXACT_TYPE, ty1.kind());
+ ASSERT_EQ(ValueDescr::ANY, ty1.shape());
+ AssertTypeEqual(*int8(), *ty1.type());
+
+ InputType ty1_implicit = int8();
+ ASSERT_TRUE(ty1.Equals(ty1_implicit));
+
+ InputType ty1_array(int8(), ValueDescr::ARRAY);
+ ASSERT_EQ(ValueDescr::ARRAY, ty1_array.shape());
+
+ InputType ty1_scalar(int8(), ValueDescr::SCALAR);
+ ASSERT_EQ(ValueDescr::SCALAR, ty1_scalar.shape());
+
+ // Same type id constructor
+ InputType ty2(Type::DECIMAL);
+ ASSERT_EQ(InputType::USE_TYPE_MATCHER, ty2.kind());
+ ASSERT_EQ("any[Type::DECIMAL128]", ty2.ToString());
+ ASSERT_TRUE(ty2.type_matcher().Matches(*decimal(12, 2)));
+ ASSERT_FALSE(ty2.type_matcher().Matches(*int16()));
+
+ InputType ty2_array(Type::DECIMAL, ValueDescr::ARRAY);
+ ASSERT_EQ(ValueDescr::ARRAY, ty2_array.shape());
+
+ InputType ty2_scalar(Type::DECIMAL, ValueDescr::SCALAR);
+ ASSERT_EQ(ValueDescr::SCALAR, ty2_scalar.shape());
+
+ // Implicit construction in a vector
+ std::vector<InputType> types = {int8(), InputType(Type::DECIMAL)};
+ ASSERT_TRUE(types[0].Equals(ty1));
+ ASSERT_TRUE(types[1].Equals(ty2));
+
+ // Copy constructor
+ InputType ty3 = ty1;
+ InputType ty4 = ty2;
+ ASSERT_TRUE(ty3.Equals(ty1));
+ ASSERT_TRUE(ty4.Equals(ty2));
+
+ // Move constructor
+ InputType ty5 = std::move(ty3);
+ InputType ty6 = std::move(ty4);
+ ASSERT_TRUE(ty5.Equals(ty1));
+ ASSERT_TRUE(ty6.Equals(ty2));
+
+ // ToString
+ ASSERT_EQ("any[int8]", ty1.ToString());
+ ASSERT_EQ("array[int8]", ty1_array.ToString());
+ ASSERT_EQ("scalar[int8]", ty1_scalar.ToString());
+
+ ASSERT_EQ("any[Type::DECIMAL128]", ty2.ToString());
+ ASSERT_EQ("array[Type::DECIMAL128]", ty2_array.ToString());
+ ASSERT_EQ("scalar[Type::DECIMAL128]", ty2_scalar.ToString());
+
+ InputType ty7(match::TimestampTypeUnit(TimeUnit::MICRO));
+ ASSERT_EQ("any[timestamp(us)]", ty7.ToString());
+
+ InputType ty8;
+ InputType ty9(ValueDescr::ANY);
+ InputType ty10(ValueDescr::ARRAY);
+ InputType ty11(ValueDescr::SCALAR);
+ ASSERT_EQ("any[any]", ty8.ToString());
+ ASSERT_EQ("any[any]", ty9.ToString());
+ ASSERT_EQ("array[any]", ty10.ToString());
+ ASSERT_EQ("scalar[any]", ty11.ToString());
+}
+
+TEST(InputType, Equals) {
+ InputType t1 = int8();
+ InputType t2 = int8();
+ InputType t3(int8(), ValueDescr::ARRAY);
+ InputType t3_i32(int32(), ValueDescr::ARRAY);
+ InputType t3_scalar(int8(), ValueDescr::SCALAR);
+ InputType t4(int8(), ValueDescr::ARRAY);
+ InputType t4_i32(int32(), ValueDescr::ARRAY);
+
+ InputType t5(Type::DECIMAL);
+ InputType t6(Type::DECIMAL);
+ InputType t7(Type::DECIMAL, ValueDescr::SCALAR);
+ InputType t7_i32(Type::INT32, ValueDescr::SCALAR);
+ InputType t8(Type::DECIMAL, ValueDescr::SCALAR);
+ InputType t8_i32(Type::INT32, ValueDescr::SCALAR);
+
+ ASSERT_TRUE(t1.Equals(t2));
+ ASSERT_EQ(t1, t2);
+
+ // ANY vs SCALAR
+ ASSERT_NE(t1, t3);
+
+ ASSERT_EQ(t3, t4);
+
+ // both ARRAY, but different type
+ ASSERT_NE(t3, t3_i32);
+
+ // ARRAY vs SCALAR
+ ASSERT_NE(t3, t3_scalar);
+
+ ASSERT_EQ(t3_i32, t4_i32);
+
+ ASSERT_FALSE(t1.Equals(t5));
+ ASSERT_NE(t1, t5);
+
+ ASSERT_EQ(t5, t5);
+ ASSERT_EQ(t5, t6);
+ ASSERT_NE(t5, t7);
+ ASSERT_EQ(t7, t8);
+ ASSERT_EQ(t7, t8);
+ ASSERT_NE(t7, t7_i32);
+ ASSERT_EQ(t7_i32, t8_i32);
+
+ // NOTE: For the time being, we treat int32() and Type::INT32 as being
+ // different. This could obviously be fixed later to make these equivalent
+ ASSERT_NE(InputType(int8()), InputType(Type::INT32));
+
+ // Check that field metadata excluded from equality checks
+ InputType t9 = list(
+ field("item", utf8(), /*nullable=*/true, key_value_metadata({"foo"}, {"bar"})));
+ InputType t10 = list(field("item", utf8()));
+ ASSERT_TRUE(t9.Equals(t10));
+}
+
+TEST(InputType, Hash) {
+ InputType t0;
+ InputType t0_scalar(ValueDescr::SCALAR);
+ InputType t0_array(ValueDescr::ARRAY);
+
+ InputType t1 = int8();
+ InputType t2(Type::DECIMAL);
+
+ // These checks try to determine first of all whether Hash always returns the
+ // same value, and whether the elements of the type are all incorporated into
+ // the Hash
+ ASSERT_EQ(t0.Hash(), t0.Hash());
+ ASSERT_NE(t0.Hash(), t0_scalar.Hash());
+ ASSERT_NE(t0.Hash(), t0_array.Hash());
+ ASSERT_NE(t0_scalar.Hash(), t0_array.Hash());
+
+ ASSERT_EQ(t1.Hash(), t1.Hash());
+ ASSERT_EQ(t2.Hash(), t2.Hash());
+
+ ASSERT_NE(t0.Hash(), t1.Hash());
+ ASSERT_NE(t0.Hash(), t2.Hash());
+ ASSERT_NE(t1.Hash(), t2.Hash());
+}
+
+TEST(InputType, Matches) {
+ InputType ty1 = int8();
+
+ ASSERT_TRUE(ty1.Matches(ValueDescr::Scalar(int8())));
+ ASSERT_TRUE(ty1.Matches(ValueDescr::Array(int8())));
+ ASSERT_TRUE(ty1.Matches(ValueDescr::Any(int8())));
+ ASSERT_FALSE(ty1.Matches(ValueDescr::Any(int16())));
+
+ InputType ty2(Type::DECIMAL);
+ ASSERT_TRUE(ty2.Matches(ValueDescr::Scalar(decimal(12, 2))));
+ ASSERT_TRUE(ty2.Matches(ValueDescr::Array(decimal(12, 2))));
+ ASSERT_FALSE(ty2.Matches(ValueDescr::Any(float64())));
+
+ InputType ty3(int64(), ValueDescr::SCALAR);
+ ASSERT_FALSE(ty3.Matches(ValueDescr::Array(int64())));
+ ASSERT_TRUE(ty3.Matches(ValueDescr::Scalar(int64())));
+ ASSERT_FALSE(ty3.Matches(ValueDescr::Scalar(int32())));
+ ASSERT_FALSE(ty3.Matches(ValueDescr::Any(int64())));
+}
+
+// ----------------------------------------------------------------------
+// OutputType
+
+TEST(OutputType, Constructors) {
+ OutputType ty1 = int8();
+ ASSERT_EQ(OutputType::FIXED, ty1.kind());
+ AssertTypeEqual(*int8(), *ty1.type());
+
+ auto DummyResolver = [](KernelContext*,
+ const std::vector<ValueDescr>& args) -> Result<ValueDescr> {
+ return ValueDescr(int32(), GetBroadcastShape(args));
+ };
+ OutputType ty2(DummyResolver);
+ ASSERT_EQ(OutputType::COMPUTED, ty2.kind());
+
+ ASSERT_OK_AND_ASSIGN(ValueDescr out_descr2, ty2.Resolve(nullptr, {}));
+ ASSERT_EQ(ValueDescr::Scalar(int32()), out_descr2);
+
+ // Copy constructor
+ OutputType ty3 = ty1;
+ ASSERT_EQ(OutputType::FIXED, ty3.kind());
+ AssertTypeEqual(*ty1.type(), *ty3.type());
+
+ OutputType ty4 = ty2;
+ ASSERT_EQ(OutputType::COMPUTED, ty4.kind());
+ ASSERT_OK_AND_ASSIGN(ValueDescr out_descr4, ty4.Resolve(nullptr, {}));
+ ASSERT_EQ(ValueDescr::Scalar(int32()), out_descr4);
+
+ // Move constructor
+ OutputType ty5 = std::move(ty1);
+ ASSERT_EQ(OutputType::FIXED, ty5.kind());
+ AssertTypeEqual(*int8(), *ty5.type());
+
+ OutputType ty6 = std::move(ty4);
+ ASSERT_EQ(OutputType::COMPUTED, ty6.kind());
+ ASSERT_OK_AND_ASSIGN(ValueDescr out_descr6, ty6.Resolve(nullptr, {}));
+ ASSERT_EQ(ValueDescr::Scalar(int32()), out_descr6);
+
+ // ToString
+
+ // ty1 was copied to ty3
+ ASSERT_EQ("int8", ty3.ToString());
+ ASSERT_EQ("computed", ty2.ToString());
+}
+
+TEST(OutputType, Resolve) {
+ // Check shape promotion rules for FIXED kind
+ OutputType ty1(int32());
+
+ ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty1.Resolve(nullptr, {}));
+ ASSERT_EQ(ValueDescr::Scalar(int32()), descr);
+
+ ASSERT_OK_AND_ASSIGN(descr,
+ ty1.Resolve(nullptr, {ValueDescr(int8(), ValueDescr::SCALAR)}));
+ ASSERT_EQ(ValueDescr::Scalar(int32()), descr);
+
+ ASSERT_OK_AND_ASSIGN(descr,
+ ty1.Resolve(nullptr, {ValueDescr(int8(), ValueDescr::SCALAR),
+ ValueDescr(int8(), ValueDescr::ARRAY)}));
+ ASSERT_EQ(ValueDescr::Array(int32()), descr);
+
+ OutputType ty2([](KernelContext*, const std::vector<ValueDescr>& args) {
+ return ValueDescr(args[0].type, GetBroadcastShape(args));
+ });
+
+ ASSERT_OK_AND_ASSIGN(descr, ty2.Resolve(nullptr, {ValueDescr::Array(utf8())}));
+ ASSERT_EQ(ValueDescr::Array(utf8()), descr);
+
+ // Type resolver that returns an error
+ OutputType ty3(
+ [](KernelContext* ctx, const std::vector<ValueDescr>& args) -> Result<ValueDescr> {
+ // NB: checking the value types versus the function arity should be
+ // validated elsewhere, so this is just for illustration purposes
+ if (args.size() == 0) {
+ return Status::Invalid("Need at least one argument");
+ }
+ return ValueDescr(args[0]);
+ });
+ ASSERT_RAISES(Invalid, ty3.Resolve(nullptr, {}));
+
+ // Type resolver that returns ValueDescr::ANY and needs type promotion
+ OutputType ty4(
+ [](KernelContext* ctx, const std::vector<ValueDescr>& args) -> Result<ValueDescr> {
+ return int32();
+ });
+
+ ASSERT_OK_AND_ASSIGN(descr, ty4.Resolve(nullptr, {ValueDescr::Array(int8())}));
+ ASSERT_EQ(ValueDescr::Array(int32()), descr);
+ ASSERT_OK_AND_ASSIGN(descr, ty4.Resolve(nullptr, {ValueDescr::Scalar(int8())}));
+ ASSERT_EQ(ValueDescr::Scalar(int32()), descr);
+}
+
+TEST(OutputType, ResolveDescr) {
+ ValueDescr d1 = ValueDescr::Scalar(int32());
+ ValueDescr d2 = ValueDescr::Array(int32());
+
+ OutputType ty1(d1);
+ OutputType ty2(d2);
+
+ ASSERT_EQ(ValueDescr::SCALAR, ty1.shape());
+ ASSERT_EQ(ValueDescr::ARRAY, ty2.shape());
+
+ {
+ ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty1.Resolve(nullptr, {}));
+ ASSERT_EQ(d1, descr);
+ }
+
+ {
+ ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty2.Resolve(nullptr, {}));
+ ASSERT_EQ(d2, descr);
+ }
+}
+
+// ----------------------------------------------------------------------
+// KernelSignature
+
+TEST(KernelSignature, Basics) {
+ // (any[int8], scalar[decimal]) -> utf8
+ std::vector<InputType> in_types({int8(), InputType(Type::DECIMAL, ValueDescr::SCALAR)});
+ OutputType out_type(utf8());
+
+ KernelSignature sig(in_types, out_type);
+ ASSERT_EQ(2, sig.in_types().size());
+ ASSERT_TRUE(sig.in_types()[0].type()->Equals(*int8()));
+ ASSERT_TRUE(sig.in_types()[0].Matches(ValueDescr::Scalar(int8())));
+ ASSERT_TRUE(sig.in_types()[0].Matches(ValueDescr::Array(int8())));
+
+ ASSERT_TRUE(sig.in_types()[1].Matches(ValueDescr::Scalar(decimal(12, 2))));
+ ASSERT_FALSE(sig.in_types()[1].Matches(ValueDescr::Array(decimal(12, 2))));
+}
+
+TEST(KernelSignature, Equals) {
+ KernelSignature sig1({}, utf8());
+ KernelSignature sig1_copy({}, utf8());
+ KernelSignature sig2({int8()}, utf8());
+
+ // Output type doesn't matter (for now)
+ KernelSignature sig3({int8()}, int32());
+
+ KernelSignature sig4({int8(), int16()}, utf8());
+ KernelSignature sig4_copy({int8(), int16()}, utf8());
+ KernelSignature sig5({int8(), int16(), int32()}, utf8());
+
+ // Differ in shape
+ KernelSignature sig6({ValueDescr::Scalar(int8())}, utf8());
+ KernelSignature sig7({ValueDescr::Array(int8())}, utf8());
+
+ ASSERT_EQ(sig1, sig1);
+
+ ASSERT_EQ(sig2, sig3);
+ ASSERT_NE(sig3, sig4);
+
+ // Different sig objects, but same sig
+ ASSERT_EQ(sig1, sig1_copy);
+ ASSERT_EQ(sig4, sig4_copy);
+
+ // Match first 2 args, but not third
+ ASSERT_NE(sig4, sig5);
+
+ ASSERT_NE(sig6, sig7);
+}
+
+TEST(KernelSignature, VarArgsEquals) {
+ KernelSignature sig1({int8()}, utf8(), /*is_varargs=*/true);
+ KernelSignature sig2({int8()}, utf8(), /*is_varargs=*/true);
+ KernelSignature sig3({int8()}, utf8());
+
+ ASSERT_EQ(sig1, sig2);
+ ASSERT_NE(sig2, sig3);
+}
+
+TEST(KernelSignature, Hash) {
+ // Some basic tests to ensure that the hashes are deterministic and that all
+ // input arguments are incorporated
+ KernelSignature sig1({}, utf8());
+ KernelSignature sig2({int8()}, utf8());
+ KernelSignature sig3({int8(), int32()}, utf8());
+
+ ASSERT_EQ(sig1.Hash(), sig1.Hash());
+ ASSERT_EQ(sig2.Hash(), sig2.Hash());
+ ASSERT_NE(sig1.Hash(), sig2.Hash());
+ ASSERT_NE(sig2.Hash(), sig3.Hash());
+}
+
+TEST(KernelSignature, MatchesInputs) {
+ // () -> boolean
+ KernelSignature sig1({}, boolean());
+
+ ASSERT_TRUE(sig1.MatchesInputs({}));
+ ASSERT_FALSE(sig1.MatchesInputs({int8()}));
+
+ // (any[int8], any[decimal]) -> boolean
+ KernelSignature sig2({int8(), InputType(Type::DECIMAL)}, boolean());
+
+ ASSERT_FALSE(sig2.MatchesInputs({}));
+ ASSERT_FALSE(sig2.MatchesInputs({int8()}));
+ ASSERT_TRUE(sig2.MatchesInputs({int8(), decimal(12, 2)}));
+ ASSERT_TRUE(sig2.MatchesInputs(
+ {ValueDescr::Scalar(int8()), ValueDescr::Scalar(decimal(12, 2))}));
+ ASSERT_TRUE(
+ sig2.MatchesInputs({ValueDescr::Array(int8()), ValueDescr::Array(decimal(12, 2))}));
+
+ // (scalar[int8], array[int32]) -> boolean
+ KernelSignature sig3({ValueDescr::Scalar(int8()), ValueDescr::Array(int32())},
+ boolean());
+
+ ASSERT_FALSE(sig3.MatchesInputs({}));
+
+ // Unqualified, these are ANY type and do not match because the kernel
+ // requires a scalar and an array
+ ASSERT_FALSE(sig3.MatchesInputs({int8(), int32()}));
+ ASSERT_TRUE(
+ sig3.MatchesInputs({ValueDescr::Scalar(int8()), ValueDescr::Array(int32())}));
+ ASSERT_FALSE(
+ sig3.MatchesInputs({ValueDescr::Array(int8()), ValueDescr::Array(int32())}));
+}
+
+TEST(KernelSignature, VarArgsMatchesInputs) {
+ {
+ KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true);
+
+ std::vector<ValueDescr> args = {int8()};
+ ASSERT_TRUE(sig.MatchesInputs(args));
+ args.push_back(ValueDescr::Scalar(int8()));
+ args.push_back(ValueDescr::Array(int8()));
+ ASSERT_TRUE(sig.MatchesInputs(args));
+ args.push_back(int32());
+ ASSERT_FALSE(sig.MatchesInputs(args));
+ }
+ {
+ KernelSignature sig({int8(), utf8()}, utf8(), /*is_varargs=*/true);
+
+ std::vector<ValueDescr> args = {int8()};
+ ASSERT_TRUE(sig.MatchesInputs(args));
+ args.push_back(ValueDescr::Scalar(utf8()));
+ args.push_back(ValueDescr::Array(utf8()));
+ ASSERT_TRUE(sig.MatchesInputs(args));
+ args.push_back(int32());
+ ASSERT_FALSE(sig.MatchesInputs(args));
+ }
+}
+
+TEST(KernelSignature, ToString) {
+ std::vector<InputType> in_types = {InputType(int8(), ValueDescr::SCALAR),
+ InputType(Type::DECIMAL, ValueDescr::ARRAY),
+ InputType(utf8())};
+ KernelSignature sig(in_types, utf8());
+ ASSERT_EQ("(scalar[int8], array[Type::DECIMAL128], any[string]) -> string",
+ sig.ToString());
+
+ OutputType out_type([](KernelContext*, const std::vector<ValueDescr>& args) {
+ return Status::Invalid("NYI");
+ });
+ KernelSignature sig2({int8(), InputType(Type::DECIMAL)}, out_type);
+ ASSERT_EQ("(any[int8], any[Type::DECIMAL128]) -> computed", sig2.ToString());
+}
+
+TEST(KernelSignature, VarArgsToString) {
+ KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true);
+ ASSERT_EQ("varargs[any[int8]] -> string", sig.ToString());
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/CMakeLists.txt b/src/arrow/cpp/src/arrow/compute/kernels/CMakeLists.txt
new file mode 100644
index 000000000..28686a9ca
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/CMakeLists.txt
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# ----------------------------------------------------------------------
+# Scalar kernels
+
+add_arrow_compute_test(scalar_test
+ SOURCES
+ scalar_arithmetic_test.cc
+ scalar_boolean_test.cc
+ scalar_cast_test.cc
+ scalar_compare_test.cc
+ scalar_nested_test.cc
+ scalar_set_lookup_test.cc
+ scalar_string_test.cc
+ scalar_temporal_test.cc
+ scalar_validity_test.cc
+ scalar_if_else_test.cc
+ test_util.cc)
+
+add_arrow_benchmark(scalar_arithmetic_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(scalar_boolean_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(scalar_cast_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(scalar_compare_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(scalar_if_else_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(scalar_set_lookup_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(scalar_string_benchmark PREFIX "arrow-compute")
+
+# ----------------------------------------------------------------------
+# Vector kernels
+
+add_arrow_compute_test(vector_test
+ SOURCES
+ vector_hash_test.cc
+ vector_nested_test.cc
+ vector_replace_test.cc
+ vector_selection_test.cc
+ vector_sort_test.cc
+ select_k_test.cc
+ test_util.cc)
+
+add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(vector_topk_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(vector_replace_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(vector_selection_benchmark PREFIX "arrow-compute")
+
+# ----------------------------------------------------------------------
+# Aggregate kernels
+
+# Aggregates
+
+add_arrow_compute_test(aggregate_test
+ SOURCES
+ aggregate_test.cc
+ hash_aggregate_test.cc
+ test_util.cc)
+add_arrow_benchmark(aggregate_benchmark PREFIX "arrow-compute")
+
+# ----------------------------------------------------------------------
+# Utilities
+
+add_arrow_compute_test(kernel_utility_test SOURCES codegen_internal_test.cc)
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic.cc
new file mode 100644
index 000000000..25697f7d3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic.cc
@@ -0,0 +1,1011 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/kernels/aggregate_basic_internal.h"
+#include "arrow/compute/kernels/aggregate_internal.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/util/cpu_info.h"
+#include "arrow/util/hashing.h"
+#include "arrow/util/make_unique.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+namespace {
+
+Status AggregateConsume(KernelContext* ctx, const ExecBatch& batch) {
+ return checked_cast<ScalarAggregator*>(ctx->state())->Consume(ctx, batch);
+}
+
+Status AggregateMerge(KernelContext* ctx, KernelState&& src, KernelState* dst) {
+ return checked_cast<ScalarAggregator*>(dst)->MergeFrom(ctx, std::move(src));
+}
+
+Status AggregateFinalize(KernelContext* ctx, Datum* out) {
+ return checked_cast<ScalarAggregator*>(ctx->state())->Finalize(ctx, out);
+}
+
+} // namespace
+
+void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
+ ScalarAggregateFunction* func, SimdLevel::type simd_level) {
+ ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateConsume,
+ AggregateMerge, AggregateFinalize);
+ // Set the simd level
+ kernel.simd_level = simd_level;
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+}
+
+void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
+ ScalarAggregateFinalize finalize, ScalarAggregateFunction* func,
+ SimdLevel::type simd_level) {
+ ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateConsume,
+ AggregateMerge, std::move(finalize));
+ // Set the simd level
+ kernel.simd_level = simd_level;
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+}
+
+namespace {
+
+// ----------------------------------------------------------------------
+// Count implementation
+
+struct CountImpl : public ScalarAggregator {
+ explicit CountImpl(CountOptions options) : options(std::move(options)) {}
+
+ Status Consume(KernelContext*, const ExecBatch& batch) override {
+ if (options.mode == CountOptions::ALL) {
+ this->non_nulls += batch.length;
+ } else if (batch[0].is_array()) {
+ const ArrayData& input = *batch[0].array();
+ const int64_t nulls = input.GetNullCount();
+ this->nulls += nulls;
+ this->non_nulls += input.length - nulls;
+ } else {
+ const Scalar& input = *batch[0].scalar();
+ this->nulls += !input.is_valid * batch.length;
+ this->non_nulls += input.is_valid * batch.length;
+ }
+ return Status::OK();
+ }
+
+ Status MergeFrom(KernelContext*, KernelState&& src) override {
+ const auto& other_state = checked_cast<const CountImpl&>(src);
+ this->non_nulls += other_state.non_nulls;
+ this->nulls += other_state.nulls;
+ return Status::OK();
+ }
+
+ Status Finalize(KernelContext* ctx, Datum* out) override {
+ const auto& state = checked_cast<const CountImpl&>(*ctx->state());
+ switch (state.options.mode) {
+ case CountOptions::ONLY_VALID:
+ case CountOptions::ALL:
+ // ALL is equivalent since we don't count the null/non-null
+ // separately to avoid potentially computing null count
+ *out = Datum(state.non_nulls);
+ break;
+ case CountOptions::ONLY_NULL:
+ *out = Datum(state.nulls);
+ break;
+ default:
+ DCHECK(false) << "unreachable";
+ }
+ return Status::OK();
+ }
+
+ CountOptions options;
+ int64_t non_nulls = 0;
+ int64_t nulls = 0;
+};
+
+Result<std::unique_ptr<KernelState>> CountInit(KernelContext*,
+ const KernelInitArgs& args) {
+ return ::arrow::internal::make_unique<CountImpl>(
+ static_cast<const CountOptions&>(*args.options));
+}
+
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+template <typename Type, typename VisitorArgType>
+struct CountDistinctImpl : public ScalarAggregator {
+ using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+ explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options)
+ : options(std::move(options)), memo_table_(new MemoTable(memory_pool, 0)) {}
+
+ Status Consume(KernelContext*, const ExecBatch& batch) override {
+ if (batch[0].is_array()) {
+ const ArrayData& arr = *batch[0].array();
+ auto visit_null = []() { return Status::OK(); };
+ auto visit_value = [&](VisitorArgType arg) {
+ int y;
+ return memo_table_->GetOrInsert(arg, &y);
+ };
+ RETURN_NOT_OK(VisitArrayDataInline<Type>(arr, visit_value, visit_null));
+ this->non_nulls += memo_table_->size();
+ this->has_nulls = arr.GetNullCount() > 0;
+ } else {
+ const Scalar& input = *batch[0].scalar();
+ this->has_nulls = !input.is_valid;
+ if (input.is_valid) {
+ this->non_nulls += batch.length;
+ }
+ }
+ return Status::OK();
+ }
+
+ Status MergeFrom(KernelContext*, KernelState&& src) override {
+ const auto& other_state = checked_cast<const CountDistinctImpl&>(src);
+ this->non_nulls += other_state.non_nulls;
+ this->has_nulls = this->has_nulls || other_state.has_nulls;
+ return Status::OK();
+ }
+
+ Status Finalize(KernelContext* ctx, Datum* out) override {
+ const auto& state = checked_cast<const CountDistinctImpl&>(*ctx->state());
+ const int64_t nulls = state.has_nulls ? 1 : 0;
+ switch (state.options.mode) {
+ case CountOptions::ONLY_VALID:
+ *out = Datum(state.non_nulls);
+ break;
+ case CountOptions::ALL:
+ *out = Datum(state.non_nulls + nulls);
+ break;
+ case CountOptions::ONLY_NULL:
+ *out = Datum(nulls);
+ break;
+ default:
+ DCHECK(false) << "unreachable";
+ }
+ return Status::OK();
+ }
+
+ const CountOptions options;
+ int64_t non_nulls = 0;
+ bool has_nulls = false;
+ std::unique_ptr<MemoTable> memo_table_;
+};
+
+template <typename Type, typename VisitorArgType>
+Result<std::unique_ptr<KernelState>> CountDistinctInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ return ::arrow::internal::make_unique<CountDistinctImpl<Type, VisitorArgType>>(
+ ctx->memory_pool(), static_cast<const CountOptions&>(*args.options));
+}
+
+template <typename Type, typename VisitorArgType = typename Type::c_type>
+void AddCountDistinctKernel(InputType type, ScalarAggregateFunction* func) {
+ AddAggKernel(KernelSignature::Make({type}, ValueDescr::Scalar(int64())),
+ CountDistinctInit<Type, VisitorArgType>, func);
+}
+
+void AddCountDistinctKernels(ScalarAggregateFunction* func) {
+ // Boolean
+ AddCountDistinctKernel<BooleanType>(boolean(), func);
+ // Number
+ AddCountDistinctKernel<Int8Type>(int8(), func);
+ AddCountDistinctKernel<Int16Type>(int16(), func);
+ AddCountDistinctKernel<Int32Type>(int32(), func);
+ AddCountDistinctKernel<Int64Type>(int64(), func);
+ AddCountDistinctKernel<UInt8Type>(uint8(), func);
+ AddCountDistinctKernel<UInt16Type>(uint16(), func);
+ AddCountDistinctKernel<UInt32Type>(uint32(), func);
+ AddCountDistinctKernel<UInt64Type>(uint64(), func);
+ AddCountDistinctKernel<HalfFloatType>(float16(), func);
+ AddCountDistinctKernel<FloatType>(float32(), func);
+ AddCountDistinctKernel<DoubleType>(float64(), func);
+ // Date
+ AddCountDistinctKernel<Date32Type>(date32(), func);
+ AddCountDistinctKernel<Date64Type>(date64(), func);
+ // Time
+ AddCountDistinctKernel<Time32Type>(match::SameTypeId(Type::TIME32), func);
+ AddCountDistinctKernel<Time64Type>(match::SameTypeId(Type::TIME64), func);
+ // Timestamp & Duration
+ AddCountDistinctKernel<TimestampType>(match::SameTypeId(Type::TIMESTAMP), func);
+ AddCountDistinctKernel<DurationType>(match::SameTypeId(Type::DURATION), func);
+ // Interval
+ AddCountDistinctKernel<MonthIntervalType>(month_interval(), func);
+ AddCountDistinctKernel<DayTimeIntervalType>(day_time_interval(), func);
+ AddCountDistinctKernel<MonthDayNanoIntervalType>(month_day_nano_interval(), func);
+ // Binary & String
+ AddCountDistinctKernel<BinaryType, util::string_view>(match::BinaryLike(), func);
+ AddCountDistinctKernel<LargeBinaryType, util::string_view>(match::LargeBinaryLike(),
+ func);
+ // Fixed binary & Decimal
+ AddCountDistinctKernel<FixedSizeBinaryType, util::string_view>(
+ match::FixedSizeBinaryLike(), func);
+}
+
+// ----------------------------------------------------------------------
+// Sum implementation
+
+template <typename ArrowType>
+struct SumImplDefault : public SumImpl<ArrowType, SimdLevel::NONE> {
+ using SumImpl<ArrowType, SimdLevel::NONE>::SumImpl;
+};
+
+template <typename ArrowType>
+struct MeanImplDefault : public MeanImpl<ArrowType, SimdLevel::NONE> {
+ using MeanImpl<ArrowType, SimdLevel::NONE>::MeanImpl;
+};
+
+Result<std::unique_ptr<KernelState>> SumInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ SumLikeInit<SumImplDefault> visitor(
+ ctx, args.inputs[0].type,
+ static_cast<const ScalarAggregateOptions&>(*args.options));
+ return visitor.Create();
+}
+
+Result<std::unique_ptr<KernelState>> MeanInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ SumLikeInit<MeanImplDefault> visitor(
+ ctx, args.inputs[0].type,
+ static_cast<const ScalarAggregateOptions&>(*args.options));
+ return visitor.Create();
+}
+
+// ----------------------------------------------------------------------
+// Product implementation
+
+using arrow::compute::internal::to_unsigned;
+
+template <typename ArrowType>
+struct ProductImpl : public ScalarAggregator {
+ using ThisType = ProductImpl<ArrowType>;
+ using AccType = typename FindAccumulatorType<ArrowType>::Type;
+ using ProductType = typename TypeTraits<AccType>::CType;
+ using OutputType = typename TypeTraits<AccType>::ScalarType;
+
+ explicit ProductImpl(const std::shared_ptr<DataType>& out_type,
+ const ScalarAggregateOptions& options)
+ : out_type(out_type),
+ options(options),
+ count(0),
+ product(MultiplyTraits<AccType>::one(*out_type)),
+ nulls_observed(false) {}
+
+ Status Consume(KernelContext*, const ExecBatch& batch) override {
+ if (batch[0].is_array()) {
+ const auto& data = batch[0].array();
+ this->count += data->length - data->GetNullCount();
+ this->nulls_observed = this->nulls_observed || data->GetNullCount();
+
+ if (!options.skip_nulls && this->nulls_observed) {
+ // Short-circuit
+ return Status::OK();
+ }
+
+ internal::VisitArrayValuesInline<ArrowType>(
+ *data,
+ [&](typename TypeTraits<ArrowType>::CType value) {
+ this->product =
+ MultiplyTraits<AccType>::Multiply(*out_type, this->product, value);
+ },
+ [] {});
+ } else {
+ const auto& data = *batch[0].scalar();
+ this->count += data.is_valid * batch.length;
+ this->nulls_observed = this->nulls_observed || !data.is_valid;
+ if (data.is_valid) {
+ for (int64_t i = 0; i < batch.length; i++) {
+ auto value = internal::UnboxScalar<ArrowType>::Unbox(data);
+ this->product =
+ MultiplyTraits<AccType>::Multiply(*out_type, this->product, value);
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ Status MergeFrom(KernelContext*, KernelState&& src) override {
+ const auto& other = checked_cast<const ThisType&>(src);
+ this->count += other.count;
+ this->product =
+ MultiplyTraits<AccType>::Multiply(*out_type, this->product, other.product);
+ this->nulls_observed = this->nulls_observed || other.nulls_observed;
+ return Status::OK();
+ }
+
+ Status Finalize(KernelContext*, Datum* out) override {
+ if ((!options.skip_nulls && this->nulls_observed) ||
+ (this->count < options.min_count)) {
+ out->value = std::make_shared<OutputType>(out_type);
+ } else {
+ out->value = std::make_shared<OutputType>(this->product, out_type);
+ }
+ return Status::OK();
+ }
+
+ std::shared_ptr<DataType> out_type;
+ ScalarAggregateOptions options;
+ size_t count;
+ ProductType product;
+ bool nulls_observed;
+};
+
+struct ProductInit {
+ std::unique_ptr<KernelState> state;
+ KernelContext* ctx;
+ const std::shared_ptr<DataType>& type;
+ const ScalarAggregateOptions& options;
+
+ ProductInit(KernelContext* ctx, const std::shared_ptr<DataType>& type,
+ const ScalarAggregateOptions& options)
+ : ctx(ctx), type(type), options(options) {}
+
+ Status Visit(const DataType&) {
+ return Status::NotImplemented("No product implemented");
+ }
+
+ Status Visit(const HalfFloatType&) {
+ return Status::NotImplemented("No product implemented");
+ }
+
+ Status Visit(const BooleanType&) {
+ auto ty = TypeTraits<typename ProductImpl<BooleanType>::AccType>::type_singleton();
+ state.reset(new ProductImpl<BooleanType>(ty, options));
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_number<Type, Status> Visit(const Type&) {
+ auto ty = TypeTraits<typename ProductImpl<Type>::AccType>::type_singleton();
+ state.reset(new ProductImpl<Type>(ty, options));
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_decimal<Type, Status> Visit(const Type&) {
+ state.reset(new ProductImpl<Type>(type, options));
+ return Status::OK();
+ }
+
+ Result<std::unique_ptr<KernelState>> Create() {
+ RETURN_NOT_OK(VisitTypeInline(*type, this));
+ return std::move(state);
+ }
+
+ static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ ProductInit visitor(ctx, args.inputs[0].type,
+ static_cast<const ScalarAggregateOptions&>(*args.options));
+ return visitor.Create();
+ }
+};
+
+// ----------------------------------------------------------------------
+// MinMax implementation
+
+Result<std::unique_ptr<KernelState>> MinMaxInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ ARROW_ASSIGN_OR_RAISE(auto out_type,
+ args.kernel->signature->out_type().Resolve(ctx, args.inputs));
+ MinMaxInitState<SimdLevel::NONE> visitor(
+ ctx, *args.inputs[0].type, std::move(out_type.type),
+ static_cast<const ScalarAggregateOptions&>(*args.options));
+ return visitor.Create();
+}
+
+// For "min" and "max" functions: override finalize and return the actual value
+template <MinOrMax min_or_max>
+void AddMinOrMaxAggKernel(ScalarAggregateFunction* func,
+ ScalarAggregateFunction* min_max_func) {
+ auto sig = KernelSignature::Make(
+ {InputType(ValueDescr::ANY)},
+ OutputType([](KernelContext*,
+ const std::vector<ValueDescr>& descrs) -> Result<ValueDescr> {
+ // any[T] -> scalar[T]
+ return ValueDescr::Scalar(descrs.front().type);
+ }));
+
+ auto init = [min_max_func](
+ KernelContext* ctx,
+ const KernelInitArgs& args) -> Result<std::unique_ptr<KernelState>> {
+ std::vector<ValueDescr> inputs = args.inputs;
+ ARROW_ASSIGN_OR_RAISE(auto kernel, min_max_func->DispatchBest(&inputs));
+ KernelInitArgs new_args{kernel, inputs, args.options};
+ return kernel->init(ctx, new_args);
+ };
+
+ auto finalize = [](KernelContext* ctx, Datum* out) -> Status {
+ Datum temp;
+ RETURN_NOT_OK(checked_cast<ScalarAggregator*>(ctx->state())->Finalize(ctx, &temp));
+ const auto& result = temp.scalar_as<StructScalar>();
+ DCHECK(result.is_valid);
+ *out = result.value[static_cast<uint8_t>(min_or_max)];
+ return Status::OK();
+ };
+
+ // Note SIMD level is always NONE, but the convenience kernel will
+ // dispatch to an appropriate implementation
+ AddAggKernel(std::move(sig), std::move(init), std::move(finalize), func);
+}
+
+// ----------------------------------------------------------------------
+// Any implementation
+
+struct BooleanAnyImpl : public ScalarAggregator {
+ explicit BooleanAnyImpl(ScalarAggregateOptions options) : options(std::move(options)) {}
+
+ Status Consume(KernelContext*, const ExecBatch& batch) override {
+ // short-circuit if seen a True already
+ if (this->any == true && this->count >= options.min_count) {
+ return Status::OK();
+ }
+ if (batch[0].is_scalar()) {
+ const auto& scalar = *batch[0].scalar();
+ this->has_nulls = !scalar.is_valid;
+ this->any = scalar.is_valid && checked_cast<const BooleanScalar&>(scalar).value;
+ this->count += scalar.is_valid;
+ return Status::OK();
+ }
+ const auto& data = *batch[0].array();
+ this->has_nulls = data.GetNullCount() > 0;
+ this->count += data.length - data.GetNullCount();
+ arrow::internal::OptionalBinaryBitBlockCounter counter(
+ data.buffers[0], data.offset, data.buffers[1], data.offset, data.length);
+ int64_t position = 0;
+ while (position < data.length) {
+ const auto block = counter.NextAndBlock();
+ if (block.popcount > 0) {
+ this->any = true;
+ break;
+ }
+ position += block.length;
+ }
+ return Status::OK();
+ }
+
+ Status MergeFrom(KernelContext*, KernelState&& src) override {
+ const auto& other = checked_cast<const BooleanAnyImpl&>(src);
+ this->any |= other.any;
+ this->has_nulls |= other.has_nulls;
+ this->count += other.count;
+ return Status::OK();
+ }
+
+ Status Finalize(KernelContext* ctx, Datum* out) override {
+ if ((!options.skip_nulls && !this->any && this->has_nulls) ||
+ this->count < options.min_count) {
+ out->value = std::make_shared<BooleanScalar>();
+ } else {
+ out->value = std::make_shared<BooleanScalar>(this->any);
+ }
+ return Status::OK();
+ }
+
+ bool any = false;
+ bool has_nulls = false;
+ int64_t count = 0;
+ ScalarAggregateOptions options;
+};
+
+Result<std::unique_ptr<KernelState>> AnyInit(KernelContext*, const KernelInitArgs& args) {
+ const ScalarAggregateOptions options =
+ static_cast<const ScalarAggregateOptions&>(*args.options);
+ return ::arrow::internal::make_unique<BooleanAnyImpl>(
+ static_cast<const ScalarAggregateOptions&>(*args.options));
+}
+
+// ----------------------------------------------------------------------
+// All implementation
+
+struct BooleanAllImpl : public ScalarAggregator {
+ explicit BooleanAllImpl(ScalarAggregateOptions options) : options(std::move(options)) {}
+
+ Status Consume(KernelContext*, const ExecBatch& batch) override {
+ // short-circuit if seen a false already
+ if (this->all == false && this->count >= options.min_count) {
+ return Status::OK();
+ }
+ // short-circuit if seen a null already
+ if (!options.skip_nulls && this->has_nulls) {
+ return Status::OK();
+ }
+ if (batch[0].is_scalar()) {
+ const auto& scalar = *batch[0].scalar();
+ this->has_nulls = !scalar.is_valid;
+ this->count += scalar.is_valid;
+ this->all = !scalar.is_valid || checked_cast<const BooleanScalar&>(scalar).value;
+ return Status::OK();
+ }
+ const auto& data = *batch[0].array();
+ this->has_nulls = data.GetNullCount() > 0;
+ this->count += data.length - data.GetNullCount();
+ arrow::internal::OptionalBinaryBitBlockCounter counter(
+ data.buffers[1], data.offset, data.buffers[0], data.offset, data.length);
+ int64_t position = 0;
+ while (position < data.length) {
+ const auto block = counter.NextOrNotBlock();
+ if (!block.AllSet()) {
+ this->all = false;
+ break;
+ }
+ position += block.length;
+ }
+
+ return Status::OK();
+ }
+
+ Status MergeFrom(KernelContext*, KernelState&& src) override {
+ const auto& other = checked_cast<const BooleanAllImpl&>(src);
+ this->all &= other.all;
+ this->has_nulls |= other.has_nulls;
+ this->count += other.count;
+ return Status::OK();
+ }
+
+ Status Finalize(KernelContext*, Datum* out) override {
+ if ((!options.skip_nulls && this->all && this->has_nulls) ||
+ this->count < options.min_count) {
+ out->value = std::make_shared<BooleanScalar>();
+ } else {
+ out->value = std::make_shared<BooleanScalar>(this->all);
+ }
+ return Status::OK();
+ }
+
+ bool all = true;
+ bool has_nulls = false;
+ int64_t count = 0;
+ ScalarAggregateOptions options;
+};
+
+Result<std::unique_ptr<KernelState>> AllInit(KernelContext*, const KernelInitArgs& args) {
+ return ::arrow::internal::make_unique<BooleanAllImpl>(
+ static_cast<const ScalarAggregateOptions&>(*args.options));
+}
+
+// ----------------------------------------------------------------------
+// Index implementation
+
+template <typename ArgType>
+struct IndexImpl : public ScalarAggregator {
+ using ArgValue = typename internal::GetViewType<ArgType>::T;
+
+ explicit IndexImpl(IndexOptions options, KernelState* raw_state)
+ : options(std::move(options)), seen(0), index(-1) {
+ if (auto state = static_cast<IndexImpl<ArgType>*>(raw_state)) {
+ seen = state->seen;
+ index = state->index;
+ }
+ }
+
+ Status Consume(KernelContext* ctx, const ExecBatch& batch) override {
+ // short-circuit
+ if (index >= 0 || !options.value->is_valid) {
+ return Status::OK();
+ }
+
+ const ArgValue desired = internal::UnboxScalar<ArgType>::Unbox(*options.value);
+
+ if (batch[0].is_scalar()) {
+ seen = batch.length;
+ if (batch[0].scalar()->is_valid) {
+ const ArgValue v = internal::UnboxScalar<ArgType>::Unbox(*batch[0].scalar());
+ if (v == desired) {
+ index = 0;
+ return Status::Cancelled("Found");
+ }
+ }
+ return Status::OK();
+ }
+
+ auto input = batch[0].array();
+ seen = input->length;
+ int64_t i = 0;
+
+ ARROW_UNUSED(internal::VisitArrayValuesInline<ArgType>(
+ *input,
+ [&](ArgValue v) -> Status {
+ if (v == desired) {
+ index = i;
+ return Status::Cancelled("Found");
+ } else {
+ ++i;
+ return Status::OK();
+ }
+ },
+ [&]() -> Status {
+ ++i;
+ return Status::OK();
+ }));
+
+ return Status::OK();
+ }
+
+ Status MergeFrom(KernelContext*, KernelState&& src) override {
+ const auto& other = checked_cast<const IndexImpl&>(src);
+ if (index < 0 && other.index >= 0) {
+ index = seen + other.index;
+ }
+ seen += other.seen;
+ return Status::OK();
+ }
+
+ Status Finalize(KernelContext*, Datum* out) override {
+ out->value = std::make_shared<Int64Scalar>(index >= 0 ? index : -1);
+ return Status::OK();
+ }
+
+ const IndexOptions options;
+ int64_t seen = 0;
+ int64_t index = -1;
+};
+
+struct IndexInit {
+ std::unique_ptr<KernelState> state;
+ KernelContext* ctx;
+ const IndexOptions& options;
+ const DataType& type;
+
+ IndexInit(KernelContext* ctx, const IndexOptions& options, const DataType& type)
+ : ctx(ctx), options(options), type(type) {}
+
+ Status Visit(const DataType& type) {
+ return Status::NotImplemented("Index kernel not implemented for ", type.ToString());
+ }
+
+ Status Visit(const BooleanType&) {
+ state.reset(new IndexImpl<BooleanType>(options, ctx->state()));
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_number<Type, Status> Visit(const Type&) {
+ state.reset(new IndexImpl<Type>(options, ctx->state()));
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_base_binary<Type, Status> Visit(const Type&) {
+ state.reset(new IndexImpl<Type>(options, ctx->state()));
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_date<Type, Status> Visit(const Type&) {
+ state.reset(new IndexImpl<Type>(options, ctx->state()));
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_time<Type, Status> Visit(const Type&) {
+ state.reset(new IndexImpl<Type>(options, ctx->state()));
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_timestamp<Type, Status> Visit(const Type&) {
+ state.reset(new IndexImpl<Type>(options, ctx->state()));
+ return Status::OK();
+ }
+
+ Result<std::unique_ptr<KernelState>> Create() {
+ RETURN_NOT_OK(VisitTypeInline(type, this));
+ return std::move(state);
+ }
+
+ static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ if (!args.options) {
+ return Status::Invalid("Must provide IndexOptions for index kernel");
+ }
+ IndexInit visitor(ctx, static_cast<const IndexOptions&>(*args.options),
+ *args.inputs[0].type);
+ return visitor.Create();
+ }
+};
+
+} // namespace
+
+void AddBasicAggKernels(KernelInit init,
+ const std::vector<std::shared_ptr<DataType>>& types,
+ std::shared_ptr<DataType> out_ty, ScalarAggregateFunction* func,
+ SimdLevel::type simd_level) {
+ for (const auto& ty : types) {
+ // array[InT] -> scalar[OutT]
+ auto sig =
+ KernelSignature::Make({InputType::Array(ty->id())}, ValueDescr::Scalar(out_ty));
+ AddAggKernel(std::move(sig), init, func, simd_level);
+ }
+}
+
+void AddScalarAggKernels(KernelInit init,
+ const std::vector<std::shared_ptr<DataType>>& types,
+ std::shared_ptr<DataType> out_ty,
+ ScalarAggregateFunction* func) {
+ for (const auto& ty : types) {
+ // scalar[InT] -> scalar[OutT]
+ auto sig =
+ KernelSignature::Make({InputType::Scalar(ty->id())}, ValueDescr::Scalar(out_ty));
+ AddAggKernel(std::move(sig), init, func, SimdLevel::NONE);
+ }
+}
+
+void AddArrayScalarAggKernels(KernelInit init,
+ const std::vector<std::shared_ptr<DataType>>& types,
+ std::shared_ptr<DataType> out_ty,
+ ScalarAggregateFunction* func,
+ SimdLevel::type simd_level = SimdLevel::NONE) {
+ AddBasicAggKernels(init, types, out_ty, func, simd_level);
+ AddScalarAggKernels(init, types, out_ty, func);
+}
+
+namespace {
+
+Result<ValueDescr> MinMaxType(KernelContext*, const std::vector<ValueDescr>& descrs) {
+ // any[T] -> scalar[struct<min: T, max: T>]
+ auto ty = descrs.front().type;
+ return ValueDescr::Scalar(struct_({field("min", ty), field("max", ty)}));
+}
+
+} // namespace
+
+void AddMinMaxKernel(KernelInit init, internal::detail::GetTypeId get_id,
+ ScalarAggregateFunction* func, SimdLevel::type simd_level) {
+ auto sig = KernelSignature::Make({InputType(get_id.id)}, OutputType(MinMaxType));
+ AddAggKernel(std::move(sig), init, func, simd_level);
+}
+
+void AddMinMaxKernels(KernelInit init,
+ const std::vector<std::shared_ptr<DataType>>& types,
+ ScalarAggregateFunction* func, SimdLevel::type simd_level) {
+ for (const auto& ty : types) {
+ AddMinMaxKernel(init, ty, func, simd_level);
+ }
+}
+
+namespace {
+
+Result<ValueDescr> ScalarFirstType(KernelContext*,
+ const std::vector<ValueDescr>& descrs) {
+ ValueDescr result = descrs.front();
+ result.shape = ValueDescr::SCALAR;
+ return result;
+}
+
+const FunctionDoc count_doc{"Count the number of null / non-null values",
+ ("By default, only non-null values are counted.\n"
+ "This can be changed through CountOptions."),
+ {"array"},
+ "CountOptions"};
+
+const FunctionDoc count_distinct_doc{"Count the number of unique values",
+ ("By default, only non-null values are counted.\n"
+ "This can be changed through CountOptions."),
+ {"array"},
+ "CountOptions"};
+
+const FunctionDoc sum_doc{
+ "Compute the sum of a numeric array",
+ ("Null values are ignored by default. Minimum count of non-null\n"
+ "values can be set and null is returned if too few are present.\n"
+ "This can be changed through ScalarAggregateOptions."),
+ {"array"},
+ "ScalarAggregateOptions"};
+
+const FunctionDoc product_doc{
+ "Compute the product of values in a numeric array",
+ ("Null values are ignored by default. Minimum count of non-null\n"
+ "values can be set and null is returned if too few are present.\n"
+ "This can be changed through ScalarAggregateOptions."),
+ {"array"},
+ "ScalarAggregateOptions"};
+
+const FunctionDoc mean_doc{
+ "Compute the mean of a numeric array",
+ ("Null values are ignored by default. Minimum count of non-null\n"
+ "values can be set and null is returned if too few are present.\n"
+ "This can be changed through ScalarAggregateOptions.\n"
+ "The result is a double for integer and floating point arguments,\n"
+ "and a decimal with the same bit-width/precision/scale for decimal arguments.\n"
+ "For integers and floats, NaN is returned if min_count = 0 and\n"
+ "there are no values. For decimals, null is returned instead."),
+ {"array"},
+ "ScalarAggregateOptions"};
+
+const FunctionDoc min_max_doc{"Compute the minimum and maximum values of a numeric array",
+ ("Null values are ignored by default.\n"
+ "This can be changed through ScalarAggregateOptions."),
+ {"array"},
+ "ScalarAggregateOptions"};
+
+const FunctionDoc min_or_max_doc{
+ "Compute the minimum or maximum values of a numeric array",
+ ("Null values are ignored by default.\n"
+ "This can be changed through ScalarAggregateOptions."),
+ {"array"},
+ "ScalarAggregateOptions"};
+
+const FunctionDoc any_doc{"Test whether any element in a boolean array evaluates to true",
+ ("Null values are ignored by default.\n"
+ "If null values are taken into account by setting "
+ "ScalarAggregateOptions parameter skip_nulls = false then "
+ "Kleene logic is used.\n"
+ "See KleeneOr for more details on Kleene logic."),
+ {"array"},
+ "ScalarAggregateOptions"};
+
+const FunctionDoc all_doc{"Test whether all elements in a boolean array evaluate to true",
+ ("Null values are ignored by default.\n"
+ "If null values are taken into account by setting "
+ "ScalarAggregateOptions parameter skip_nulls = false then "
+ "Kleene logic is used.\n"
+ "See KleeneAnd for more details on Kleene logic."),
+ {"array"},
+ "ScalarAggregateOptions"};
+
+const FunctionDoc index_doc{"Find the index of the first occurrence of a given value",
+ ("The result is always computed as an int64_t, regardless\n"
+ "of the offset type of the input array."),
+ {"array"},
+ "IndexOptions"};
+
+} // namespace
+
+void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
+ static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults();
+ static auto default_count_options = CountOptions::Defaults();
+
+ auto func = std::make_shared<ScalarAggregateFunction>(
+ "count", Arity::Unary(), &count_doc, &default_count_options);
+
+ // Takes any input, outputs int64 scalar
+ InputType any_input;
+ AddAggKernel(KernelSignature::Make({any_input}, ValueDescr::Scalar(int64())), CountInit,
+ func.get());
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+
+ func = std::make_shared<ScalarAggregateFunction>(
+ "count_distinct", Arity::Unary(), &count_distinct_doc, &default_count_options);
+ // Takes any input, outputs int64 scalar
+ AddCountDistinctKernels(func.get());
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+
+ func = std::make_shared<ScalarAggregateFunction>("sum", Arity::Unary(), &sum_doc,
+ &default_scalar_aggregate_options);
+ AddArrayScalarAggKernels(SumInit, {boolean()}, uint64(), func.get());
+ AddAggKernel(
+ KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ScalarFirstType)),
+ SumInit, func.get(), SimdLevel::NONE);
+ AddAggKernel(
+ KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)),
+ SumInit, func.get(), SimdLevel::NONE);
+ AddArrayScalarAggKernels(SumInit, SignedIntTypes(), int64(), func.get());
+ AddArrayScalarAggKernels(SumInit, UnsignedIntTypes(), uint64(), func.get());
+ AddArrayScalarAggKernels(SumInit, FloatingPointTypes(), float64(), func.get());
+ // Add the SIMD variants for sum
+#if defined(ARROW_HAVE_RUNTIME_AVX2) || defined(ARROW_HAVE_RUNTIME_AVX512)
+ auto cpu_info = arrow::internal::CpuInfo::GetInstance();
+#endif
+#if defined(ARROW_HAVE_RUNTIME_AVX2)
+ if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) {
+ AddSumAvx2AggKernels(func.get());
+ }
+#endif
+#if defined(ARROW_HAVE_RUNTIME_AVX512)
+ if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) {
+ AddSumAvx512AggKernels(func.get());
+ }
+#endif
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+
+ func = std::make_shared<ScalarAggregateFunction>("mean", Arity::Unary(), &mean_doc,
+ &default_scalar_aggregate_options);
+ AddArrayScalarAggKernels(MeanInit, {boolean()}, float64(), func.get());
+ AddArrayScalarAggKernels(MeanInit, NumericTypes(), float64(), func.get());
+ AddAggKernel(
+ KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ScalarFirstType)),
+ MeanInit, func.get(), SimdLevel::NONE);
+ AddAggKernel(
+ KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)),
+ MeanInit, func.get(), SimdLevel::NONE);
+ // Add the SIMD variants for mean
+#if defined(ARROW_HAVE_RUNTIME_AVX2)
+ if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) {
+ AddMeanAvx2AggKernels(func.get());
+ }
+#endif
+#if defined(ARROW_HAVE_RUNTIME_AVX512)
+ if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) {
+ AddMeanAvx512AggKernels(func.get());
+ }
+#endif
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+
+ func = std::make_shared<ScalarAggregateFunction>(
+ "min_max", Arity::Unary(), &min_max_doc, &default_scalar_aggregate_options);
+ AddMinMaxKernels(MinMaxInit, {null(), boolean()}, func.get());
+ AddMinMaxKernels(MinMaxInit, NumericTypes(), func.get());
+ AddMinMaxKernels(MinMaxInit, TemporalTypes(), func.get());
+ AddMinMaxKernels(MinMaxInit, BaseBinaryTypes(), func.get());
+ AddMinMaxKernel(MinMaxInit, Type::FIXED_SIZE_BINARY, func.get());
+ AddMinMaxKernel(MinMaxInit, Type::INTERVAL_MONTHS, func.get());
+ AddMinMaxKernel(MinMaxInit, Type::DECIMAL128, func.get());
+ AddMinMaxKernel(MinMaxInit, Type::DECIMAL256, func.get());
+ // Add the SIMD variants for min max
+#if defined(ARROW_HAVE_RUNTIME_AVX2)
+ if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) {
+ AddMinMaxAvx2AggKernels(func.get());
+ }
+#endif
+#if defined(ARROW_HAVE_RUNTIME_AVX512)
+ if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) {
+ AddMinMaxAvx512AggKernels(func.get());
+ }
+#endif
+
+ auto min_max_func = func.get();
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+
+ // Add min/max as convenience functions
+ func = std::make_shared<ScalarAggregateFunction>("min", Arity::Unary(), &min_or_max_doc,
+ &default_scalar_aggregate_options);
+ AddMinOrMaxAggKernel<MinOrMax::Min>(func.get(), min_max_func);
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+
+ func = std::make_shared<ScalarAggregateFunction>("max", Arity::Unary(), &min_or_max_doc,
+ &default_scalar_aggregate_options);
+ AddMinOrMaxAggKernel<MinOrMax::Max>(func.get(), min_max_func);
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+
+ func = std::make_shared<ScalarAggregateFunction>(
+ "product", Arity::Unary(), &product_doc, &default_scalar_aggregate_options);
+ AddArrayScalarAggKernels(ProductInit::Init, {boolean()}, uint64(), func.get());
+ AddArrayScalarAggKernels(ProductInit::Init, SignedIntTypes(), int64(), func.get());
+ AddArrayScalarAggKernels(ProductInit::Init, UnsignedIntTypes(), uint64(), func.get());
+ AddArrayScalarAggKernels(ProductInit::Init, FloatingPointTypes(), float64(),
+ func.get());
+ AddAggKernel(
+ KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ScalarFirstType)),
+ ProductInit::Init, func.get(), SimdLevel::NONE);
+ AddAggKernel(
+ KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)),
+ ProductInit::Init, func.get(), SimdLevel::NONE);
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+
+ // any
+ func = std::make_shared<ScalarAggregateFunction>("any", Arity::Unary(), &any_doc,
+ &default_scalar_aggregate_options);
+ AddArrayScalarAggKernels(AnyInit, {boolean()}, boolean(), func.get());
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+
+ // all
+ func = std::make_shared<ScalarAggregateFunction>("all", Arity::Unary(), &all_doc,
+ &default_scalar_aggregate_options);
+ AddArrayScalarAggKernels(AllInit, {boolean()}, boolean(), func.get());
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+
+ // index
+ func = std::make_shared<ScalarAggregateFunction>("index", Arity::Unary(), &index_doc);
+ AddBasicAggKernels(IndexInit::Init, BaseBinaryTypes(), int64(), func.get());
+ AddBasicAggKernels(IndexInit::Init, PrimitiveTypes(), int64(), func.get());
+ AddBasicAggKernels(IndexInit::Init, TemporalTypes(), int64(), func.get());
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc
new file mode 100644
index 000000000..00e3e2e5f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/kernels/aggregate_basic_internal.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+// ----------------------------------------------------------------------
+// Sum implementation
+
+template <typename ArrowType>
+struct SumImplAvx2 : public SumImpl<ArrowType, SimdLevel::AVX2> {
+ using SumImpl<ArrowType, SimdLevel::AVX2>::SumImpl;
+};
+
+template <typename ArrowType>
+struct MeanImplAvx2 : public MeanImpl<ArrowType, SimdLevel::AVX2> {
+ using MeanImpl<ArrowType, SimdLevel::AVX2>::MeanImpl;
+};
+
+Result<std::unique_ptr<KernelState>> SumInitAvx2(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ SumLikeInit<SumImplAvx2> visitor(
+ ctx, args.inputs[0].type,
+ static_cast<const ScalarAggregateOptions&>(*args.options));
+ return visitor.Create();
+}
+
+Result<std::unique_ptr<KernelState>> MeanInitAvx2(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ SumLikeInit<MeanImplAvx2> visitor(
+ ctx, args.inputs[0].type,
+ static_cast<const ScalarAggregateOptions&>(*args.options));
+ return visitor.Create();
+}
+
+// ----------------------------------------------------------------------
+// MinMax implementation
+
+Result<std::unique_ptr<KernelState>> MinMaxInitAvx2(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ ARROW_ASSIGN_OR_RAISE(auto out_type,
+ args.kernel->signature->out_type().Resolve(ctx, args.inputs));
+ MinMaxInitState<SimdLevel::AVX2> visitor(
+ ctx, *args.inputs[0].type, std::move(out_type.type),
+ static_cast<const ScalarAggregateOptions&>(*args.options));
+ return visitor.Create();
+}
+
+void AddSumAvx2AggKernels(ScalarAggregateFunction* func) {
+ AddBasicAggKernels(SumInitAvx2, SignedIntTypes(), int64(), func, SimdLevel::AVX2);
+ AddBasicAggKernels(SumInitAvx2, UnsignedIntTypes(), uint64(), func, SimdLevel::AVX2);
+ AddBasicAggKernels(SumInitAvx2, FloatingPointTypes(), float64(), func, SimdLevel::AVX2);
+}
+
+void AddMeanAvx2AggKernels(ScalarAggregateFunction* func) {
+ AddBasicAggKernels(MeanInitAvx2, NumericTypes(), float64(), func, SimdLevel::AVX2);
+}
+
+void AddMinMaxAvx2AggKernels(ScalarAggregateFunction* func) {
+ // Enable int types for AVX2 variants.
+ // No auto vectorize for float/double as it use fmin/fmax which has NaN handling.
+ AddMinMaxKernels(MinMaxInitAvx2, IntTypes(), func, SimdLevel::AVX2);
+ AddMinMaxKernels(MinMaxInitAvx2, TemporalTypes(), func, SimdLevel::AVX2);
+ AddMinMaxKernels(MinMaxInitAvx2, BaseBinaryTypes(), func, SimdLevel::AVX2);
+ AddMinMaxKernel(MinMaxInitAvx2, Type::FIXED_SIZE_BINARY, func, SimdLevel::AVX2);
+ AddMinMaxKernel(MinMaxInitAvx2, Type::INTERVAL_MONTHS, func, SimdLevel::AVX2);
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc
new file mode 100644
index 000000000..8c10eb19b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc
@@ -0,0 +1,90 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/kernels/aggregate_basic_internal.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+// ----------------------------------------------------------------------
+// Sum implementation
+
+template <typename ArrowType>
+struct SumImplAvx512 : public SumImpl<ArrowType, SimdLevel::AVX512> {
+ using SumImpl<ArrowType, SimdLevel::AVX512>::SumImpl;
+};
+
+template <typename ArrowType>
+struct MeanImplAvx512 : public MeanImpl<ArrowType, SimdLevel::AVX512> {
+ using MeanImpl<ArrowType, SimdLevel::AVX512>::MeanImpl;
+};
+
+Result<std::unique_ptr<KernelState>> SumInitAvx512(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ SumLikeInit<SumImplAvx512> visitor(
+ ctx, args.inputs[0].type,
+ static_cast<const ScalarAggregateOptions&>(*args.options));
+ return visitor.Create();
+}
+
+Result<std::unique_ptr<KernelState>> MeanInitAvx512(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ SumLikeInit<MeanImplAvx512> visitor(
+ ctx, args.inputs[0].type,
+ static_cast<const ScalarAggregateOptions&>(*args.options));
+ return visitor.Create();
+}
+
+// ----------------------------------------------------------------------
+// MinMax implementation
+
+Result<std::unique_ptr<KernelState>> MinMaxInitAvx512(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ ARROW_ASSIGN_OR_RAISE(auto out_type,
+ args.kernel->signature->out_type().Resolve(ctx, args.inputs));
+ MinMaxInitState<SimdLevel::AVX512> visitor(
+ ctx, *args.inputs[0].type, std::move(out_type.type),
+ static_cast<const ScalarAggregateOptions&>(*args.options));
+ return visitor.Create();
+}
+
+void AddSumAvx512AggKernels(ScalarAggregateFunction* func) {
+ AddBasicAggKernels(SumInitAvx512, SignedIntTypes(), int64(), func, SimdLevel::AVX512);
+ AddBasicAggKernels(SumInitAvx512, UnsignedIntTypes(), uint64(), func,
+ SimdLevel::AVX512);
+ AddBasicAggKernels(SumInitAvx512, FloatingPointTypes(), float64(), func,
+ SimdLevel::AVX512);
+}
+
+void AddMeanAvx512AggKernels(ScalarAggregateFunction* func) {
+ AddBasicAggKernels(MeanInitAvx512, NumericTypes(), float64(), func, SimdLevel::AVX512);
+}
+
+void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func) {
+ // Enable 32/64 int types for avx512 variants, no advantage on 8/16 int.
+ AddMinMaxKernels(MinMaxInitAvx512, {int32(), uint32(), int64(), uint64()}, func,
+ SimdLevel::AVX512);
+ AddMinMaxKernels(MinMaxInitAvx512, TemporalTypes(), func, SimdLevel::AVX512);
+ AddMinMaxKernels(MinMaxInitAvx512, BaseBinaryTypes(), func, SimdLevel::AVX2);
+ AddMinMaxKernel(MinMaxInitAvx512, Type::FIXED_SIZE_BINARY, func, SimdLevel::AVX2);
+ AddMinMaxKernel(MinMaxInitAvx512, Type::INTERVAL_MONTHS, func, SimdLevel::AVX512);
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
new file mode 100644
index 000000000..156e908ea
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
@@ -0,0 +1,626 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cmath>
+#include <utility>
+
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/kernels/aggregate_internal.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/util/align_util.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/decimal.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+void AddBasicAggKernels(KernelInit init,
+ const std::vector<std::shared_ptr<DataType>>& types,
+ std::shared_ptr<DataType> out_ty, ScalarAggregateFunction* func,
+ SimdLevel::type simd_level = SimdLevel::NONE);
+
+void AddMinMaxKernels(KernelInit init,
+ const std::vector<std::shared_ptr<DataType>>& types,
+ ScalarAggregateFunction* func,
+ SimdLevel::type simd_level = SimdLevel::NONE);
+void AddMinMaxKernel(KernelInit init, internal::detail::GetTypeId get_id,
+ ScalarAggregateFunction* func,
+ SimdLevel::type simd_level = SimdLevel::NONE);
+
+// SIMD variants for kernels
+void AddSumAvx2AggKernels(ScalarAggregateFunction* func);
+void AddMeanAvx2AggKernels(ScalarAggregateFunction* func);
+void AddMinMaxAvx2AggKernels(ScalarAggregateFunction* func);
+
+void AddSumAvx512AggKernels(ScalarAggregateFunction* func);
+void AddMeanAvx512AggKernels(ScalarAggregateFunction* func);
+void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func);
+
+// ----------------------------------------------------------------------
+// Sum implementation
+
+template <typename ArrowType, SimdLevel::type SimdLevel>
+struct SumImpl : public ScalarAggregator {
+ using ThisType = SumImpl<ArrowType, SimdLevel>;
+ using CType = typename TypeTraits<ArrowType>::CType;
+ using SumType = typename FindAccumulatorType<ArrowType>::Type;
+ using SumCType = typename TypeTraits<SumType>::CType;
+ using OutputType = typename TypeTraits<SumType>::ScalarType;
+
+ SumImpl(const std::shared_ptr<DataType>& out_type,
+ const ScalarAggregateOptions& options_)
+ : out_type(out_type), options(options_) {}
+
+ Status Consume(KernelContext*, const ExecBatch& batch) override {
+ if (batch[0].is_array()) {
+ const auto& data = batch[0].array();
+ this->count += data->length - data->GetNullCount();
+ this->nulls_observed = this->nulls_observed || data->GetNullCount();
+
+ if (!options.skip_nulls && this->nulls_observed) {
+ // Short-circuit
+ return Status::OK();
+ }
+
+ if (is_boolean_type<ArrowType>::value) {
+ this->sum += static_cast<SumCType>(BooleanArray(data).true_count());
+ } else {
+ this->sum += SumArray<CType, SumCType, SimdLevel>(*data);
+ }
+ } else {
+ const auto& data = *batch[0].scalar();
+ this->count += data.is_valid * batch.length;
+ this->nulls_observed = this->nulls_observed || !data.is_valid;
+ if (data.is_valid) {
+ this->sum += internal::UnboxScalar<ArrowType>::Unbox(data) * batch.length;
+ }
+ }
+ return Status::OK();
+ }
+
+ Status MergeFrom(KernelContext*, KernelState&& src) override {
+ const auto& other = checked_cast<const ThisType&>(src);
+ this->count += other.count;
+ this->sum += other.sum;
+ this->nulls_observed = this->nulls_observed || other.nulls_observed;
+ return Status::OK();
+ }
+
+ Status Finalize(KernelContext*, Datum* out) override {
+ if ((!options.skip_nulls && this->nulls_observed) ||
+ (this->count < options.min_count)) {
+ out->value = std::make_shared<OutputType>(out_type);
+ } else {
+ out->value = std::make_shared<OutputType>(this->sum, out_type);
+ }
+ return Status::OK();
+ }
+
+ size_t count = 0;
+ bool nulls_observed = false;
+ SumCType sum = 0;
+ std::shared_ptr<DataType> out_type;
+ ScalarAggregateOptions options;
+};
+
+template <typename ArrowType, SimdLevel::type SimdLevel>
+struct MeanImpl : public SumImpl<ArrowType, SimdLevel> {
+ using SumImpl<ArrowType, SimdLevel>::SumImpl;
+
+ template <typename T = ArrowType>
+ enable_if_decimal<T, Status> FinalizeImpl(Datum* out) {
+ using SumCType = typename SumImpl<ArrowType, SimdLevel>::SumCType;
+ using OutputType = typename SumImpl<ArrowType, SimdLevel>::OutputType;
+ if ((!options.skip_nulls && this->nulls_observed) ||
+ (this->count < options.min_count) || (this->count == 0)) {
+ out->value = std::make_shared<OutputType>(this->out_type);
+ } else {
+ const SumCType mean = this->sum / this->count;
+ out->value = std::make_shared<OutputType>(mean, this->out_type);
+ }
+ return Status::OK();
+ }
+ template <typename T = ArrowType>
+ enable_if_t<!is_decimal_type<T>::value, Status> FinalizeImpl(Datum* out) {
+ if ((!options.skip_nulls && this->nulls_observed) ||
+ (this->count < options.min_count)) {
+ out->value = std::make_shared<DoubleScalar>();
+ } else {
+ const double mean = static_cast<double>(this->sum) / this->count;
+ out->value = std::make_shared<DoubleScalar>(mean);
+ }
+ return Status::OK();
+ }
+ Status Finalize(KernelContext*, Datum* out) override { return FinalizeImpl(out); }
+
+ using SumImpl<ArrowType, SimdLevel>::options;
+};
+
+template <template <typename> class KernelClass>
+struct SumLikeInit {
+ std::unique_ptr<KernelState> state;
+ KernelContext* ctx;
+ const std::shared_ptr<DataType> type;
+ const ScalarAggregateOptions& options;
+
+ SumLikeInit(KernelContext* ctx, const std::shared_ptr<DataType>& type,
+ const ScalarAggregateOptions& options)
+ : ctx(ctx), type(type), options(options) {}
+
+ Status Visit(const DataType&) { return Status::NotImplemented("No sum implemented"); }
+
+ Status Visit(const HalfFloatType&) {
+ return Status::NotImplemented("No sum implemented");
+ }
+
+ Status Visit(const BooleanType&) {
+ auto ty = TypeTraits<typename KernelClass<BooleanType>::SumType>::type_singleton();
+ state.reset(new KernelClass<BooleanType>(ty, options));
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_number<Type, Status> Visit(const Type&) {
+ auto ty = TypeTraits<typename KernelClass<Type>::SumType>::type_singleton();
+ state.reset(new KernelClass<Type>(ty, options));
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_decimal<Type, Status> Visit(const Type&) {
+ state.reset(new KernelClass<Type>(type, options));
+ return Status::OK();
+ }
+
+ Result<std::unique_ptr<KernelState>> Create() {
+ RETURN_NOT_OK(VisitTypeInline(*type, this));
+ return std::move(state);
+ }
+};
+
+// ----------------------------------------------------------------------
+// MinMax implementation
+
+template <typename ArrowType, SimdLevel::type SimdLevel, typename Enable = void>
+struct MinMaxState {};
+
+template <typename ArrowType, SimdLevel::type SimdLevel>
+struct MinMaxState<ArrowType, SimdLevel, enable_if_boolean<ArrowType>> {
+ using ThisType = MinMaxState<ArrowType, SimdLevel>;
+ using T = typename ArrowType::c_type;
+
+ ThisType& operator+=(const ThisType& rhs) {
+ this->has_nulls |= rhs.has_nulls;
+ this->min = this->min && rhs.min;
+ this->max = this->max || rhs.max;
+ return *this;
+ }
+
+ void MergeOne(T value) {
+ this->min = this->min && value;
+ this->max = this->max || value;
+ }
+
+ T min = true;
+ T max = false;
+ bool has_nulls = false;
+};
+
+template <typename ArrowType, SimdLevel::type SimdLevel>
+struct MinMaxState<ArrowType, SimdLevel, enable_if_integer<ArrowType>> {
+ using ThisType = MinMaxState<ArrowType, SimdLevel>;
+ using T = typename ArrowType::c_type;
+ using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+
+ ThisType& operator+=(const ThisType& rhs) {
+ this->has_nulls |= rhs.has_nulls;
+ this->min = std::min(this->min, rhs.min);
+ this->max = std::max(this->max, rhs.max);
+ return *this;
+ }
+
+ void MergeOne(T value) {
+ this->min = std::min(this->min, value);
+ this->max = std::max(this->max, value);
+ }
+
+ T min = std::numeric_limits<T>::max();
+ T max = std::numeric_limits<T>::min();
+ bool has_nulls = false;
+};
+
+template <typename ArrowType, SimdLevel::type SimdLevel>
+struct MinMaxState<ArrowType, SimdLevel, enable_if_floating_point<ArrowType>> {
+ using ThisType = MinMaxState<ArrowType, SimdLevel>;
+ using T = typename ArrowType::c_type;
+ using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+
+ ThisType& operator+=(const ThisType& rhs) {
+ this->has_nulls |= rhs.has_nulls;
+ this->min = std::fmin(this->min, rhs.min);
+ this->max = std::fmax(this->max, rhs.max);
+ return *this;
+ }
+
+ void MergeOne(T value) {
+ this->min = std::fmin(this->min, value);
+ this->max = std::fmax(this->max, value);
+ }
+
+ T min = std::numeric_limits<T>::infinity();
+ T max = -std::numeric_limits<T>::infinity();
+ bool has_nulls = false;
+};
+
+template <typename ArrowType, SimdLevel::type SimdLevel>
+struct MinMaxState<ArrowType, SimdLevel, enable_if_decimal<ArrowType>> {
+ using ThisType = MinMaxState<ArrowType, SimdLevel>;
+ using T = typename TypeTraits<ArrowType>::CType;
+ using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+
+ MinMaxState() : min(T::GetMaxSentinel()), max(T::GetMinSentinel()) {}
+
+ ThisType& operator+=(const ThisType& rhs) {
+ this->has_nulls |= rhs.has_nulls;
+ this->min = std::min(this->min, rhs.min);
+ this->max = std::max(this->max, rhs.max);
+ return *this;
+ }
+
+ void MergeOne(util::string_view value) {
+ MergeOne(T(reinterpret_cast<const uint8_t*>(value.data())));
+ }
+
+ void MergeOne(const T value) {
+ this->min = std::min(this->min, value);
+ this->max = std::max(this->max, value);
+ }
+
+ T min;
+ T max;
+ bool has_nulls = false;
+};
+
+template <typename ArrowType, SimdLevel::type SimdLevel>
+struct MinMaxState<ArrowType, SimdLevel,
+ enable_if_t<is_base_binary_type<ArrowType>::value ||
+ std::is_same<ArrowType, FixedSizeBinaryType>::value>> {
+ using ThisType = MinMaxState<ArrowType, SimdLevel>;
+ using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+
+ ThisType& operator+=(const ThisType& rhs) {
+ if (!this->seen && rhs.seen) {
+ this->min = rhs.min;
+ this->max = rhs.max;
+ } else if (this->seen && rhs.seen) {
+ if (this->min > rhs.min) {
+ this->min = rhs.min;
+ }
+ if (this->max < rhs.max) {
+ this->max = rhs.max;
+ }
+ }
+ this->has_nulls |= rhs.has_nulls;
+ this->seen |= rhs.seen;
+ return *this;
+ }
+
+ void MergeOne(util::string_view value) {
+ if (!seen) {
+ this->min = std::string(value);
+ this->max = std::string(value);
+ } else {
+ if (value < util::string_view(this->min)) {
+ this->min = std::string(value);
+ } else if (value > util::string_view(this->max)) {
+ this->max = std::string(value);
+ }
+ }
+ this->seen = true;
+ }
+
+ std::string min;
+ std::string max;
+ bool has_nulls = false;
+ bool seen = false;
+};
+
+template <typename ArrowType, SimdLevel::type SimdLevel>
+struct MinMaxImpl : public ScalarAggregator {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+ using ThisType = MinMaxImpl<ArrowType, SimdLevel>;
+ using StateType = MinMaxState<ArrowType, SimdLevel>;
+
+ MinMaxImpl(std::shared_ptr<DataType> out_type, ScalarAggregateOptions options)
+ : out_type(std::move(out_type)), options(std::move(options)), count(0) {
+ this->options.min_count = std::max<uint32_t>(1, this->options.min_count);
+ }
+
+ Status Consume(KernelContext*, const ExecBatch& batch) override {
+ if (batch[0].is_array()) {
+ return ConsumeArray(ArrayType(batch[0].array()));
+ }
+ return ConsumeScalar(*batch[0].scalar());
+ }
+
+ Status ConsumeScalar(const Scalar& scalar) {
+ StateType local;
+ local.has_nulls = !scalar.is_valid;
+ this->count += scalar.is_valid;
+
+ if (local.has_nulls && !options.skip_nulls) {
+ this->state = local;
+ return Status::OK();
+ }
+
+ local.MergeOne(internal::UnboxScalar<ArrowType>::Unbox(scalar));
+ this->state = local;
+ return Status::OK();
+ }
+
+ Status ConsumeArray(const ArrayType& arr) {
+ StateType local;
+
+ const auto null_count = arr.null_count();
+ local.has_nulls = null_count > 0;
+ this->count += arr.length() - null_count;
+
+ if (local.has_nulls && !options.skip_nulls) {
+ this->state = local;
+ return Status::OK();
+ }
+
+ if (local.has_nulls) {
+ local += ConsumeWithNulls(arr);
+ } else { // All true values
+ for (int64_t i = 0; i < arr.length(); i++) {
+ local.MergeOne(arr.GetView(i));
+ }
+ }
+ this->state = local;
+ return Status::OK();
+ }
+
+ Status MergeFrom(KernelContext*, KernelState&& src) override {
+ const auto& other = checked_cast<const ThisType&>(src);
+ this->state += other.state;
+ this->count += other.count;
+ return Status::OK();
+ }
+
+ Status Finalize(KernelContext*, Datum* out) override {
+ const auto& struct_type = checked_cast<const StructType&>(*out_type);
+ const auto& child_type = struct_type.field(0)->type();
+
+ std::vector<std::shared_ptr<Scalar>> values;
+ // Physical type != result type
+ if ((state.has_nulls && !options.skip_nulls) || (this->count < options.min_count)) {
+ // (null, null)
+ auto null_scalar = MakeNullScalar(child_type);
+ values = {null_scalar, null_scalar};
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto min_scalar,
+ MakeScalar(child_type, std::move(state.min)));
+ ARROW_ASSIGN_OR_RAISE(auto max_scalar,
+ MakeScalar(child_type, std::move(state.max)));
+ values = {std::move(min_scalar), std::move(max_scalar)};
+ }
+ out->value = std::make_shared<StructScalar>(std::move(values), this->out_type);
+ return Status::OK();
+ }
+
+ std::shared_ptr<DataType> out_type;
+ ScalarAggregateOptions options;
+ int64_t count;
+ MinMaxState<ArrowType, SimdLevel> state;
+
+ private:
+ StateType ConsumeWithNulls(const ArrayType& arr) const {
+ StateType local;
+ const int64_t length = arr.length();
+ int64_t offset = arr.offset();
+ const uint8_t* bitmap = arr.null_bitmap_data();
+ int64_t idx = 0;
+
+ const auto p = arrow::internal::BitmapWordAlign<1>(bitmap, offset, length);
+ // First handle the leading bits
+ const int64_t leading_bits = p.leading_bits;
+ while (idx < leading_bits) {
+ if (BitUtil::GetBit(bitmap, offset)) {
+ local.MergeOne(arr.GetView(idx));
+ }
+ idx++;
+ offset++;
+ }
+
+ // The aligned parts scanned with BitBlockCounter
+ arrow::internal::BitBlockCounter data_counter(bitmap, offset, length - leading_bits);
+ auto current_block = data_counter.NextWord();
+ while (idx < length) {
+ if (current_block.AllSet()) { // All true values
+ int run_length = 0;
+ // Scan forward until a block that has some false values (or the end)
+ while (current_block.length > 0 && current_block.AllSet()) {
+ run_length += current_block.length;
+ current_block = data_counter.NextWord();
+ }
+ for (int64_t i = 0; i < run_length; i++) {
+ local.MergeOne(arr.GetView(idx + i));
+ }
+ idx += run_length;
+ offset += run_length;
+ // The current_block already computed, advance to next loop
+ continue;
+ } else if (!current_block.NoneSet()) { // Some values are null
+ BitmapReader reader(arr.null_bitmap_data(), offset, current_block.length);
+ for (int64_t i = 0; i < current_block.length; i++) {
+ if (reader.IsSet()) {
+ local.MergeOne(arr.GetView(idx + i));
+ }
+ reader.Next();
+ }
+
+ idx += current_block.length;
+ offset += current_block.length;
+ } else { // All null values
+ idx += current_block.length;
+ offset += current_block.length;
+ }
+ current_block = data_counter.NextWord();
+ }
+
+ return local;
+ }
+};
+
+template <SimdLevel::type SimdLevel>
+struct BooleanMinMaxImpl : public MinMaxImpl<BooleanType, SimdLevel> {
+ using StateType = MinMaxState<BooleanType, SimdLevel>;
+ using ArrayType = typename TypeTraits<BooleanType>::ArrayType;
+ using MinMaxImpl<BooleanType, SimdLevel>::MinMaxImpl;
+ using MinMaxImpl<BooleanType, SimdLevel>::options;
+
+ Status Consume(KernelContext*, const ExecBatch& batch) override {
+ if (ARROW_PREDICT_FALSE(batch[0].is_scalar())) {
+ return ConsumeScalar(checked_cast<const BooleanScalar&>(*batch[0].scalar()));
+ }
+ StateType local;
+ ArrayType arr(batch[0].array());
+
+ const auto arr_length = arr.length();
+ const auto null_count = arr.null_count();
+ const auto valid_count = arr_length - null_count;
+
+ local.has_nulls = null_count > 0;
+ this->count += valid_count;
+ if (local.has_nulls && !options.skip_nulls) {
+ this->state = local;
+ return Status::OK();
+ }
+
+ const auto true_count = arr.true_count();
+ const auto false_count = valid_count - true_count;
+ local.max = true_count > 0;
+ local.min = false_count == 0;
+
+ this->state = local;
+ return Status::OK();
+ }
+
+ Status ConsumeScalar(const BooleanScalar& scalar) {
+ StateType local;
+
+ local.has_nulls = !scalar.is_valid;
+ this->count += scalar.is_valid;
+ if (local.has_nulls && !options.skip_nulls) {
+ this->state = local;
+ return Status::OK();
+ }
+
+ const int true_count = scalar.is_valid && scalar.value;
+ const int false_count = scalar.is_valid && !scalar.value;
+ local.max = true_count > 0;
+ local.min = false_count == 0;
+
+ this->state = local;
+ return Status::OK();
+ }
+};
+
+struct NullMinMaxImpl : public ScalarAggregator {
+ Status Consume(KernelContext*, const ExecBatch& batch) override { return Status::OK(); }
+
+ Status MergeFrom(KernelContext*, KernelState&& src) override { return Status::OK(); }
+
+ Status Finalize(KernelContext*, Datum* out) override {
+ std::vector<std::shared_ptr<Scalar>> values{std::make_shared<NullScalar>(),
+ std::make_shared<NullScalar>()};
+ out->value = std::make_shared<StructScalar>(
+ std::move(values), struct_({field("min", null()), field("max", null())}));
+ return Status::OK();
+ }
+};
+
+template <SimdLevel::type SimdLevel>
+struct MinMaxInitState {
+ std::unique_ptr<KernelState> state;
+ KernelContext* ctx;
+ const DataType& in_type;
+ const std::shared_ptr<DataType>& out_type;
+ const ScalarAggregateOptions& options;
+
+ MinMaxInitState(KernelContext* ctx, const DataType& in_type,
+ const std::shared_ptr<DataType>& out_type,
+ const ScalarAggregateOptions& options)
+ : ctx(ctx), in_type(in_type), out_type(out_type), options(options) {}
+
+ Status Visit(const DataType& ty) {
+ return Status::NotImplemented("No min/max implemented for ", ty);
+ }
+
+ Status Visit(const HalfFloatType& ty) {
+ return Status::NotImplemented("No min/max implemented for ", ty);
+ }
+
+ Status Visit(const NullType&) {
+ state.reset(new NullMinMaxImpl());
+ return Status::OK();
+ }
+
+ Status Visit(const BooleanType&) {
+ state.reset(new BooleanMinMaxImpl<SimdLevel>(out_type, options));
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_physical_integer<Type, Status> Visit(const Type&) {
+ using PhysicalType = typename Type::PhysicalType;
+ state.reset(new MinMaxImpl<PhysicalType, SimdLevel>(out_type, options));
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_floating_point<Type, Status> Visit(const Type&) {
+ state.reset(new MinMaxImpl<Type, SimdLevel>(out_type, options));
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_base_binary<Type, Status> Visit(const Type&) {
+ state.reset(new MinMaxImpl<Type, SimdLevel>(out_type, options));
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_fixed_size_binary<Type, Status> Visit(const Type&) {
+ state.reset(new MinMaxImpl<Type, SimdLevel>(out_type, options));
+ return Status::OK();
+ }
+
+ Result<std::unique_ptr<KernelState>> Create() {
+ RETURN_NOT_OK(VisitTypeInline(in_type, this));
+ return std::move(state);
+ }
+};
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc
new file mode 100644
index 000000000..39cfeb039
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc
@@ -0,0 +1,752 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <vector>
+
+#include "arrow/compute/api.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/benchmark_util.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_reader.h"
+
+namespace arrow {
+namespace compute {
+
+#include <cassert>
+#include <cmath>
+#include <iostream>
+#include <random>
+
+#ifdef ARROW_WITH_BENCHMARKS_REFERENCE
+
+namespace BitUtil = arrow::BitUtil;
+using arrow::internal::BitmapReader;
+
+template <typename T>
+struct SumState {
+ using ValueType = T;
+
+ SumState() : total(0), valid_count(0) {}
+
+ T total = 0;
+ int64_t valid_count = 0;
+};
+
+template <typename T>
+struct Traits {};
+
+template <>
+struct Traits<int64_t> {
+ using ArrayType = typename CTypeTraits<int64_t>::ArrayType;
+ static constexpr int64_t null_sentinel = std::numeric_limits<int64_t>::lowest();
+
+ static void FixSentinel(std::shared_ptr<ArrayType>& array) {
+ auto data = array->data();
+ for (int64_t i = 0; i < array->length(); i++)
+ if (array->IsNull(i)) {
+ int64_t* val_ptr = data->GetMutableValues<int64_t>(1, i);
+ *val_ptr = null_sentinel;
+ }
+ }
+
+ static inline bool IsNull(int64_t val) { return val == null_sentinel; }
+
+ static inline bool NotNull(int64_t val) { return val != null_sentinel; }
+};
+
+template <typename T>
+struct Summer {
+ public:
+ using ValueType = T;
+ using ArrowType = typename CTypeTraits<T>::ArrowType;
+};
+
+template <typename T>
+struct SumNoNulls : public Summer<T> {
+ using ArrayType = typename CTypeTraits<T>::ArrayType;
+
+ static void Sum(const ArrayType& array, SumState<T>* state) {
+ SumState<T> local;
+
+ const auto values = array.raw_values();
+ for (int64_t i = 0; i < array.length(); ++i) {
+ local.total += values[i];
+ }
+
+ local.valid_count = array.length();
+ *state = local;
+ }
+};
+
+template <typename T>
+struct SumNoNullsUnrolled : public Summer<T> {
+ using ArrayType = typename CTypeTraits<T>::ArrayType;
+
+ static void Sum(const ArrayType& array, SumState<T>* state) {
+ SumState<T> local;
+
+ const auto values = array.raw_values();
+ const auto length = array.length();
+ const int64_t length_rounded = BitUtil::RoundDown(length, 8);
+ for (int64_t i = 0; i < length_rounded; i += 8) {
+ local.total += values[i + 0] + values[i + 1] + values[i + 2] + values[i + 3] +
+ values[i + 4] + values[i + 5] + values[i + 6] + values[i + 7];
+ }
+
+ for (int64_t i = length_rounded; i < length; ++i) {
+ local.total += values[i];
+ }
+
+ local.valid_count = length;
+
+ *state = local;
+ }
+};
+
+template <typename T>
+struct SumSentinel : public Summer<T> {
+ using ArrayType = typename CTypeTraits<T>::ArrayType;
+
+ static void Sum(const ArrayType& array, SumState<T>* state) {
+ SumState<T> local;
+
+ const auto values = array.raw_values();
+ const auto length = array.length();
+ for (int64_t i = 0; i < length; i++) {
+ // NaN is not equal to itself
+ local.total += values[i] * Traits<T>::NotNull(values[i]);
+ local.valid_count++;
+ }
+
+ *state = local;
+ }
+};
+
+template <typename T>
+struct SumSentinelUnrolled : public Summer<T> {
+ using ArrayType = typename CTypeTraits<T>::ArrayType;
+
+ static void Sum(const ArrayType& array, SumState<T>* state) {
+ SumState<T> local;
+
+#define SUM_NOT_NULL(ITEM) \
+ do { \
+ local.total += values[i + ITEM] * Traits<T>::NotNull(values[i + ITEM]); \
+ local.valid_count++; \
+ } while (0)
+
+ const auto values = array.raw_values();
+ const auto length = array.length();
+ const int64_t length_rounded = BitUtil::RoundDown(length, 8);
+ for (int64_t i = 0; i < length_rounded; i += 8) {
+ SUM_NOT_NULL(0);
+ SUM_NOT_NULL(1);
+ SUM_NOT_NULL(2);
+ SUM_NOT_NULL(3);
+ SUM_NOT_NULL(4);
+ SUM_NOT_NULL(5);
+ SUM_NOT_NULL(6);
+ SUM_NOT_NULL(7);
+ }
+
+#undef SUM_NOT_NULL
+
+ for (int64_t i = length_rounded * 8; i < length; ++i) {
+ local.total += values[i] * Traits<T>::NotNull(values[i]);
+ ++local.valid_count;
+ }
+
+ *state = local;
+ }
+};
+
+template <typename T>
+struct SumBitmapNaive : public Summer<T> {
+ using ArrayType = typename CTypeTraits<T>::ArrayType;
+
+ static void Sum(const ArrayType& array, SumState<T>* state) {
+ SumState<T> local;
+
+ const auto values = array.raw_values();
+ const auto bitmap = array.null_bitmap_data();
+ const auto length = array.length();
+
+ for (int64_t i = 0; i < length; ++i) {
+ if (BitUtil::GetBit(bitmap, i)) {
+ local.total += values[i];
+ ++local.valid_count;
+ }
+ }
+
+ *state = local;
+ }
+};
+
+template <typename T>
+struct SumBitmapReader : public Summer<T> {
+ using ArrayType = typename CTypeTraits<T>::ArrayType;
+
+ static void Sum(const ArrayType& array, SumState<T>* state) {
+ SumState<T> local;
+
+ const auto values = array.raw_values();
+ const auto bitmap = array.null_bitmap_data();
+ const auto length = array.length();
+ BitmapReader bit_reader(bitmap, 0, length);
+ for (int64_t i = 0; i < length; ++i) {
+ if (bit_reader.IsSet()) {
+ local.total += values[i];
+ ++local.valid_count;
+ }
+
+ bit_reader.Next();
+ }
+
+ *state = local;
+ }
+};
+
+template <typename T>
+struct SumBitmapVectorizeUnroll : public Summer<T> {
+ using ArrayType = typename CTypeTraits<T>::ArrayType;
+
+ static void Sum(const ArrayType& array, SumState<T>* state) {
+ SumState<T> local;
+
+ const auto values = array.raw_values();
+ const auto bitmap = array.null_bitmap_data();
+ const auto length = array.length();
+ const int64_t length_rounded = BitUtil::RoundDown(length, 8);
+ for (int64_t i = 0; i < length_rounded; i += 8) {
+ const uint8_t valid_byte = bitmap[i / 8];
+
+#define SUM_SHIFT(ITEM) (values[i + ITEM] * ((valid_byte >> ITEM) & 1))
+
+ if (valid_byte < 0xFF) {
+ // Some nulls
+ local.total += SUM_SHIFT(0);
+ local.total += SUM_SHIFT(1);
+ local.total += SUM_SHIFT(2);
+ local.total += SUM_SHIFT(3);
+ local.total += SUM_SHIFT(4);
+ local.total += SUM_SHIFT(5);
+ local.total += SUM_SHIFT(6);
+ local.total += SUM_SHIFT(7);
+ local.valid_count += BitUtil::kBytePopcount[valid_byte];
+ } else {
+ // No nulls
+ local.total += values[i + 0] + values[i + 1] + values[i + 2] + values[i + 3] +
+ values[i + 4] + values[i + 5] + values[i + 6] + values[i + 7];
+ local.valid_count += 8;
+ }
+ }
+
+#undef SUM_SHIFT
+
+ for (int64_t i = length_rounded; i < length; ++i) {
+ if (BitUtil::GetBit(bitmap, i)) {
+ local.total = values[i];
+ ++local.valid_count;
+ }
+ }
+
+ *state = local;
+ }
+};
+
+template <typename Functor>
+void ReferenceSum(benchmark::State& state) {
+ using T = typename Functor::ValueType;
+
+ RegressionArgs args(state);
+ const int64_t array_size = args.size / sizeof(int64_t);
+ auto rand = random::RandomArrayGenerator(1923);
+ auto array = std::static_pointer_cast<NumericArray<Int64Type>>(
+ rand.Int64(array_size, -100, 100, args.null_proportion));
+
+ Traits<T>::FixSentinel(array);
+
+ for (auto _ : state) {
+ SumState<T> sum_state;
+ Functor::Sum(*array, &sum_state);
+ benchmark::DoNotOptimize(sum_state);
+ }
+}
+
+BENCHMARK_TEMPLATE(ReferenceSum, SumNoNulls<int64_t>)->Apply(BenchmarkSetArgs);
+BENCHMARK_TEMPLATE(ReferenceSum, SumNoNullsUnrolled<int64_t>)->Apply(BenchmarkSetArgs);
+BENCHMARK_TEMPLATE(ReferenceSum, SumSentinel<int64_t>)->Apply(BenchmarkSetArgs);
+BENCHMARK_TEMPLATE(ReferenceSum, SumSentinelUnrolled<int64_t>)->Apply(BenchmarkSetArgs);
+BENCHMARK_TEMPLATE(ReferenceSum, SumBitmapNaive<int64_t>)->Apply(BenchmarkSetArgs);
+BENCHMARK_TEMPLATE(ReferenceSum, SumBitmapReader<int64_t>)->Apply(BenchmarkSetArgs);
+BENCHMARK_TEMPLATE(ReferenceSum, SumBitmapVectorizeUnroll<int64_t>)
+ ->Apply(BenchmarkSetArgs);
+#endif // ARROW_WITH_BENCHMARKS_REFERENCE
+
+//
+// GroupBy
+//
+
+static void BenchmarkGroupBy(benchmark::State& state,
+ std::vector<internal::Aggregate> aggregates,
+ std::vector<Datum> arguments, std::vector<Datum> keys) {
+ for (auto _ : state) {
+ ABORT_NOT_OK(GroupBy(arguments, keys, aggregates).status());
+ }
+}
+
+#define GROUP_BY_BENCHMARK(Name, Impl) \
+ static void Name(benchmark::State& state) { \
+ RegressionArgs args(state, false); \
+ auto rng = random::RandomArrayGenerator(1923); \
+ (Impl)(); \
+ } \
+ BENCHMARK(Name)->Apply([](benchmark::internal::Benchmark* bench) { \
+ BenchmarkSetArgsWithSizes(bench, {1 * 1024 * 1024}); \
+ })
+
+GROUP_BY_BENCHMARK(SumDoublesGroupedByTinyStringSet, [&] {
+ auto summand = rng.Float64(args.size,
+ /*min=*/0.0,
+ /*max=*/1.0e14,
+ /*null_probability=*/args.null_proportion,
+ /*nan_probability=*/args.null_proportion / 10);
+
+ auto key = rng.StringWithRepeats(args.size,
+ /*unique=*/16,
+ /*min_length=*/3,
+ /*max_length=*/32);
+
+ BenchmarkGroupBy(state, {{"hash_sum", NULLPTR}}, {summand}, {key});
+});
+
+GROUP_BY_BENCHMARK(SumDoublesGroupedBySmallStringSet, [&] {
+ auto summand = rng.Float64(args.size,
+ /*min=*/0.0,
+ /*max=*/1.0e14,
+ /*null_probability=*/args.null_proportion,
+ /*nan_probability=*/args.null_proportion / 10);
+
+ auto key = rng.StringWithRepeats(args.size,
+ /*unique=*/256,
+ /*min_length=*/3,
+ /*max_length=*/32);
+
+ BenchmarkGroupBy(state, {{"hash_sum", NULLPTR}}, {summand}, {key});
+});
+
+GROUP_BY_BENCHMARK(SumDoublesGroupedByMediumStringSet, [&] {
+ auto summand = rng.Float64(args.size,
+ /*min=*/0.0,
+ /*max=*/1.0e14,
+ /*null_probability=*/args.null_proportion,
+ /*nan_probability=*/args.null_proportion / 10);
+
+ auto key = rng.StringWithRepeats(args.size,
+ /*unique=*/4096,
+ /*min_length=*/3,
+ /*max_length=*/32);
+
+ BenchmarkGroupBy(state, {{"hash_sum", NULLPTR}}, {summand}, {key});
+});
+
+GROUP_BY_BENCHMARK(SumDoublesGroupedByTinyIntegerSet, [&] {
+ auto summand = rng.Float64(args.size,
+ /*min=*/0.0,
+ /*max=*/1.0e14,
+ /*null_probability=*/args.null_proportion,
+ /*nan_probability=*/args.null_proportion / 10);
+
+ auto key = rng.Int64(args.size,
+ /*min=*/0,
+ /*max=*/15);
+
+ BenchmarkGroupBy(state, {{"hash_sum", NULLPTR}}, {summand}, {key});
+});
+
+GROUP_BY_BENCHMARK(SumDoublesGroupedBySmallIntegerSet, [&] {
+ auto summand = rng.Float64(args.size,
+ /*min=*/0.0,
+ /*max=*/1.0e14,
+ /*null_probability=*/args.null_proportion,
+ /*nan_probability=*/args.null_proportion / 10);
+
+ auto key = rng.Int64(args.size,
+ /*min=*/0,
+ /*max=*/255);
+
+ BenchmarkGroupBy(state, {{"hash_sum", NULLPTR}}, {summand}, {key});
+});
+
+GROUP_BY_BENCHMARK(SumDoublesGroupedByMediumIntegerSet, [&] {
+ auto summand = rng.Float64(args.size,
+ /*min=*/0.0,
+ /*max=*/1.0e14,
+ /*null_probability=*/args.null_proportion,
+ /*nan_probability=*/args.null_proportion / 10);
+
+ auto key = rng.Int64(args.size,
+ /*min=*/0,
+ /*max=*/4095);
+
+ BenchmarkGroupBy(state, {{"hash_sum", NULLPTR}}, {summand}, {key});
+});
+
+GROUP_BY_BENCHMARK(SumDoublesGroupedByTinyIntStringPairSet, [&] {
+ auto summand = rng.Float64(args.size,
+ /*min=*/0.0,
+ /*max=*/1.0e14,
+ /*null_probability=*/args.null_proportion,
+ /*nan_probability=*/args.null_proportion / 10);
+
+ auto int_key = rng.Int64(args.size,
+ /*min=*/0,
+ /*max=*/4);
+ auto str_key = rng.StringWithRepeats(args.size,
+ /*unique=*/4,
+ /*min_length=*/3,
+ /*max_length=*/32);
+
+ BenchmarkGroupBy(state, {{"hash_sum", NULLPTR}}, {summand}, {int_key, str_key});
+});
+
+GROUP_BY_BENCHMARK(SumDoublesGroupedBySmallIntStringPairSet, [&] {
+ auto summand = rng.Float64(args.size,
+ /*min=*/0.0,
+ /*max=*/1.0e14,
+ /*null_probability=*/args.null_proportion,
+ /*nan_probability=*/args.null_proportion / 10);
+
+ auto int_key = rng.Int64(args.size,
+ /*min=*/0,
+ /*max=*/15);
+ auto str_key = rng.StringWithRepeats(args.size,
+ /*unique=*/16,
+ /*min_length=*/3,
+ /*max_length=*/32);
+
+ BenchmarkGroupBy(state, {{"hash_sum", NULLPTR}}, {summand}, {int_key, str_key});
+});
+
+GROUP_BY_BENCHMARK(SumDoublesGroupedByMediumIntStringPairSet, [&] {
+ auto summand = rng.Float64(args.size,
+ /*min=*/0.0,
+ /*max=*/1.0e14,
+ /*null_probability=*/args.null_proportion,
+ /*nan_probability=*/args.null_proportion / 10);
+
+ auto int_key = rng.Int64(args.size,
+ /*min=*/0,
+ /*max=*/63);
+ auto str_key = rng.StringWithRepeats(args.size,
+ /*unique=*/64,
+ /*min_length=*/3,
+ /*max_length=*/32);
+
+ BenchmarkGroupBy(state, {{"hash_sum", NULLPTR}}, {summand}, {int_key, str_key});
+});
+
+//
+// Sum
+//
+
+template <typename ArrowType>
+static void SumKernel(benchmark::State& state) {
+ using CType = typename TypeTraits<ArrowType>::CType;
+
+ RegressionArgs args(state);
+ const int64_t array_size = args.size / sizeof(CType);
+ auto rand = random::RandomArrayGenerator(1923);
+ auto array = rand.Numeric<ArrowType>(array_size, -100, 100, args.null_proportion);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(Sum(array).status());
+ }
+}
+
+static void SumKernelArgs(benchmark::internal::Benchmark* bench) {
+ BenchmarkSetArgsWithSizes(bench, {1 * 1024 * 1024}); // 1M
+}
+
+#define SUM_KERNEL_BENCHMARK(FuncName, Type) \
+ static void FuncName(benchmark::State& state) { SumKernel<Type>(state); } \
+ BENCHMARK(FuncName)->Apply(SumKernelArgs)
+
+SUM_KERNEL_BENCHMARK(SumKernelFloat, FloatType);
+SUM_KERNEL_BENCHMARK(SumKernelDouble, DoubleType);
+SUM_KERNEL_BENCHMARK(SumKernelInt8, Int8Type);
+SUM_KERNEL_BENCHMARK(SumKernelInt16, Int16Type);
+SUM_KERNEL_BENCHMARK(SumKernelInt32, Int32Type);
+SUM_KERNEL_BENCHMARK(SumKernelInt64, Int64Type);
+
+//
+// Mode
+//
+
+template <typename ArrowType>
+void ModeKernel(benchmark::State& state, int min, int max) {
+ using CType = typename TypeTraits<ArrowType>::CType;
+
+ RegressionArgs args(state);
+ const int64_t array_size = args.size / sizeof(CType);
+ auto rand = random::RandomArrayGenerator(1924);
+ auto array = rand.Numeric<ArrowType>(array_size, min, max, args.null_proportion);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(Mode(array).status());
+ }
+}
+
+template <typename ArrowType>
+void ModeKernelNarrow(benchmark::State& state) {
+ ModeKernel<ArrowType>(state, -5000, 8000); // max - min < 16384
+}
+
+template <>
+void ModeKernelNarrow<Int8Type>(benchmark::State& state) {
+ ModeKernel<Int8Type>(state, -128, 127);
+}
+
+template <>
+void ModeKernelNarrow<BooleanType>(benchmark::State& state) {
+ RegressionArgs args(state);
+ auto rand = random::RandomArrayGenerator(1924);
+ auto array = rand.Boolean(args.size * 8, 0.5, args.null_proportion);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(Mode(array).status());
+ }
+}
+
+template <typename ArrowType>
+void ModeKernelWide(benchmark::State& state) {
+ ModeKernel<ArrowType>(state, -1234567, 7654321);
+}
+
+static void ModeKernelArgs(benchmark::internal::Benchmark* bench) {
+ BenchmarkSetArgsWithSizes(bench, {1 * 1024 * 1024}); // 1M
+}
+
+BENCHMARK_TEMPLATE(ModeKernelNarrow, BooleanType)->Apply(ModeKernelArgs);
+BENCHMARK_TEMPLATE(ModeKernelNarrow, Int8Type)->Apply(ModeKernelArgs);
+BENCHMARK_TEMPLATE(ModeKernelNarrow, Int32Type)->Apply(ModeKernelArgs);
+BENCHMARK_TEMPLATE(ModeKernelNarrow, Int64Type)->Apply(ModeKernelArgs);
+BENCHMARK_TEMPLATE(ModeKernelWide, Int32Type)->Apply(ModeKernelArgs);
+BENCHMARK_TEMPLATE(ModeKernelWide, Int64Type)->Apply(ModeKernelArgs);
+BENCHMARK_TEMPLATE(ModeKernelWide, FloatType)->Apply(ModeKernelArgs);
+BENCHMARK_TEMPLATE(ModeKernelWide, DoubleType)->Apply(ModeKernelArgs);
+
+//
+// MinMax
+//
+
+template <typename ArrowType>
+static void MinMaxKernelBench(benchmark::State& state) {
+ using CType = typename TypeTraits<ArrowType>::CType;
+
+ RegressionArgs args(state);
+ const int64_t array_size = args.size / sizeof(CType);
+ auto rand = random::RandomArrayGenerator(1923);
+ auto array = rand.Numeric<ArrowType>(array_size, -100, 100, args.null_proportion);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(MinMax(array).status());
+ }
+}
+
+static void MinMaxKernelBenchArgs(benchmark::internal::Benchmark* bench) {
+ BenchmarkSetArgsWithSizes(bench, {1 * 1024 * 1024}); // 1M
+}
+
+#define MINMAX_KERNEL_BENCHMARK(FuncName, Type) \
+ static void FuncName(benchmark::State& state) { MinMaxKernelBench<Type>(state); } \
+ BENCHMARK(FuncName)->Apply(MinMaxKernelBenchArgs)
+
+MINMAX_KERNEL_BENCHMARK(MinMaxKernelFloat, FloatType);
+MINMAX_KERNEL_BENCHMARK(MinMaxKernelDouble, DoubleType);
+MINMAX_KERNEL_BENCHMARK(MinMaxKernelInt8, Int8Type);
+MINMAX_KERNEL_BENCHMARK(MinMaxKernelInt16, Int16Type);
+MINMAX_KERNEL_BENCHMARK(MinMaxKernelInt32, Int32Type);
+MINMAX_KERNEL_BENCHMARK(MinMaxKernelInt64, Int64Type);
+
+//
+// Count
+//
+
+static void CountKernelBenchInt64(benchmark::State& state) {
+ RegressionArgs args(state);
+ const int64_t array_size = args.size / sizeof(int64_t);
+ auto rand = random::RandomArrayGenerator(1923);
+ auto array = rand.Numeric<Int64Type>(array_size, -100, 100, args.null_proportion);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(Count(array->Slice(1, array_size)).status());
+ }
+}
+BENCHMARK(CountKernelBenchInt64)->Args({1 * 1024 * 1024, 2}); // 1M with 50% null.
+
+//
+// Variance
+//
+
+template <typename ArrowType>
+void VarianceKernelBench(benchmark::State& state) {
+ using CType = typename TypeTraits<ArrowType>::CType;
+
+ VarianceOptions options;
+ RegressionArgs args(state);
+ const int64_t array_size = args.size / sizeof(CType);
+ auto rand = random::RandomArrayGenerator(1925);
+ auto array = rand.Numeric<ArrowType>(array_size, -100000, 100000, args.null_proportion);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(Variance(array, options).status());
+ }
+}
+
+static void VarianceKernelBenchArgs(benchmark::internal::Benchmark* bench) {
+ BenchmarkSetArgsWithSizes(bench, {1 * 1024 * 1024});
+}
+
+#define VARIANCE_KERNEL_BENCHMARK(FuncName, Type) \
+ static void FuncName(benchmark::State& state) { VarianceKernelBench<Type>(state); } \
+ BENCHMARK(FuncName)->Apply(VarianceKernelBenchArgs)
+
+VARIANCE_KERNEL_BENCHMARK(VarianceKernelInt32, Int32Type);
+VARIANCE_KERNEL_BENCHMARK(VarianceKernelInt64, Int64Type);
+VARIANCE_KERNEL_BENCHMARK(VarianceKernelFloat, FloatType);
+VARIANCE_KERNEL_BENCHMARK(VarianceKernelDouble, DoubleType);
+
+//
+// Quantile
+//
+
+static std::vector<double> deciles() {
+ return {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0};
+}
+
+static std::vector<double> centiles() {
+ std::vector<double> q(101);
+ for (int i = 0; i <= 100; ++i) {
+ q[i] = i / 100.0;
+ }
+ return q;
+}
+
+template <typename ArrowType>
+void QuantileKernel(benchmark::State& state, int min, int max, std::vector<double> q) {
+ using CType = typename TypeTraits<ArrowType>::CType;
+
+ QuantileOptions options(std::move(q));
+ RegressionArgs args(state);
+ const int64_t array_size = args.size / sizeof(CType);
+ auto rand = random::RandomArrayGenerator(1926);
+ auto array = rand.Numeric<ArrowType>(array_size, min, max, args.null_proportion);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(Quantile(array, options).status());
+ }
+ state.SetItemsProcessed(state.iterations() * array_size);
+}
+
+template <typename ArrowType>
+void QuantileKernelMedian(benchmark::State& state, int min, int max) {
+ QuantileKernel<ArrowType>(state, min, max, {0.5});
+}
+
+template <typename ArrowType>
+void QuantileKernelMedianWide(benchmark::State& state) {
+ QuantileKernel<ArrowType>(state, 0, 1 << 24, {0.5});
+}
+
+template <typename ArrowType>
+void QuantileKernelMedianNarrow(benchmark::State& state) {
+ QuantileKernel<ArrowType>(state, -30000, 30000, {0.5});
+}
+
+template <typename ArrowType>
+void QuantileKernelDecilesWide(benchmark::State& state) {
+ QuantileKernel<ArrowType>(state, 0, 1 << 24, deciles());
+}
+
+template <typename ArrowType>
+void QuantileKernelDecilesNarrow(benchmark::State& state) {
+ QuantileKernel<ArrowType>(state, -30000, 30000, deciles());
+}
+
+template <typename ArrowType>
+void QuantileKernelCentilesWide(benchmark::State& state) {
+ QuantileKernel<ArrowType>(state, 0, 1 << 24, centiles());
+}
+
+template <typename ArrowType>
+void QuantileKernelCentilesNarrow(benchmark::State& state) {
+ QuantileKernel<ArrowType>(state, -30000, 30000, centiles());
+}
+
+static void QuantileKernelArgs(benchmark::internal::Benchmark* bench) {
+ BenchmarkSetArgsWithSizes(bench, {1 * 1024 * 1024});
+}
+
+BENCHMARK_TEMPLATE(QuantileKernelMedianNarrow, Int32Type)->Apply(QuantileKernelArgs);
+BENCHMARK_TEMPLATE(QuantileKernelMedianWide, Int32Type)->Apply(QuantileKernelArgs);
+BENCHMARK_TEMPLATE(QuantileKernelMedianNarrow, Int64Type)->Apply(QuantileKernelArgs);
+BENCHMARK_TEMPLATE(QuantileKernelMedianWide, Int64Type)->Apply(QuantileKernelArgs);
+BENCHMARK_TEMPLATE(QuantileKernelMedianWide, DoubleType)->Apply(QuantileKernelArgs);
+
+BENCHMARK_TEMPLATE(QuantileKernelDecilesNarrow, Int32Type)->Apply(QuantileKernelArgs);
+BENCHMARK_TEMPLATE(QuantileKernelDecilesWide, Int32Type)->Apply(QuantileKernelArgs);
+BENCHMARK_TEMPLATE(QuantileKernelDecilesWide, DoubleType)->Apply(QuantileKernelArgs);
+
+BENCHMARK_TEMPLATE(QuantileKernelCentilesNarrow, Int32Type)->Apply(QuantileKernelArgs);
+BENCHMARK_TEMPLATE(QuantileKernelCentilesWide, Int32Type)->Apply(QuantileKernelArgs);
+BENCHMARK_TEMPLATE(QuantileKernelCentilesWide, DoubleType)->Apply(QuantileKernelArgs);
+
+static void TDigestKernelDouble(benchmark::State& state, std::vector<double> q) {
+ TDigestOptions options{std::move(q)};
+ RegressionArgs args(state);
+ const int64_t array_size = args.size / sizeof(double);
+ auto rand = random::RandomArrayGenerator(1926);
+ auto array = rand.Numeric<DoubleType>(array_size, 0, 1 << 24, args.null_proportion);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(TDigest(array, options).status());
+ }
+ state.SetItemsProcessed(state.iterations() * array_size);
+}
+
+static void TDigestKernelDoubleMedian(benchmark::State& state) {
+ TDigestKernelDouble(state, {0.5});
+}
+
+static void TDigestKernelDoubleDeciles(benchmark::State& state) {
+ TDigestKernelDouble(state, deciles());
+}
+
+static void TDigestKernelDoubleCentiles(benchmark::State& state) {
+ TDigestKernelDouble(state, centiles());
+}
+
+BENCHMARK(TDigestKernelDoubleMedian)->Apply(QuantileKernelArgs);
+BENCHMARK(TDigestKernelDoubleDeciles)->Apply(QuantileKernelArgs);
+BENCHMARK(TDigestKernelDoubleCentiles)->Apply(QuantileKernelArgs);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/aggregate_internal.h b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_internal.h
new file mode 100644
index 000000000..22a54558f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_internal.h
@@ -0,0 +1,223 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+// Find the largest compatible primitive type for a primitive type.
+template <typename I, typename Enable = void>
+struct FindAccumulatorType {};
+
+template <typename I>
+struct FindAccumulatorType<I, enable_if_boolean<I>> {
+ using Type = UInt64Type;
+};
+
+template <typename I>
+struct FindAccumulatorType<I, enable_if_signed_integer<I>> {
+ using Type = Int64Type;
+};
+
+template <typename I>
+struct FindAccumulatorType<I, enable_if_unsigned_integer<I>> {
+ using Type = UInt64Type;
+};
+
+template <typename I>
+struct FindAccumulatorType<I, enable_if_floating_point<I>> {
+ using Type = DoubleType;
+};
+
+template <typename I>
+struct FindAccumulatorType<I, enable_if_decimal128<I>> {
+ using Type = Decimal128Type;
+};
+
+template <typename I>
+struct FindAccumulatorType<I, enable_if_decimal256<I>> {
+ using Type = Decimal256Type;
+};
+
+// Helpers for implementing aggregations on decimals
+
+template <typename Type, typename Enable = void>
+struct MultiplyTraits {
+ using CType = typename TypeTraits<Type>::CType;
+
+ constexpr static CType one(const DataType&) { return static_cast<CType>(1); }
+
+ constexpr static CType Multiply(const DataType&, CType lhs, CType rhs) {
+ return static_cast<CType>(internal::to_unsigned(lhs) * internal::to_unsigned(rhs));
+ }
+};
+
+template <typename Type>
+struct MultiplyTraits<Type, enable_if_decimal<Type>> {
+ using CType = typename TypeTraits<Type>::CType;
+
+ constexpr static CType one(const DataType& ty) {
+ // Return 1 scaled to output type scale
+ return CType(1).IncreaseScaleBy(static_cast<const Type&>(ty).scale());
+ }
+
+ constexpr static CType Multiply(const DataType& ty, CType lhs, CType rhs) {
+ // Multiply then rescale down to output scale
+ return (lhs * rhs).ReduceScaleBy(static_cast<const Type&>(ty).scale());
+ }
+};
+
+struct ScalarAggregator : public KernelState {
+ virtual Status Consume(KernelContext* ctx, const ExecBatch& batch) = 0;
+ virtual Status MergeFrom(KernelContext* ctx, KernelState&& src) = 0;
+ virtual Status Finalize(KernelContext* ctx, Datum* out) = 0;
+};
+
+// Helper to differentiate between var/std calculation so we can fold
+// kernel implementations together
+enum class VarOrStd : bool { Var, Std };
+
+// Helper to differentiate between min/max calculation so we can fold
+// kernel implementations together
+enum class MinOrMax : uint8_t { Min = 0, Max };
+
+void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
+ ScalarAggregateFunction* func,
+ SimdLevel::type simd_level = SimdLevel::NONE);
+
+void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
+ ScalarAggregateFinalize finalize, ScalarAggregateFunction* func,
+ SimdLevel::type simd_level = SimdLevel::NONE);
+
+// SumArray must be parameterized with the SIMD level since it's called both from
+// translation units with and without vectorization. Normally it gets inlined but
+// if not, without the parameter, we'll have multiple definitions of the same
+// symbol and we'll get unexpected results.
+
+// non-recursive pairwise summation for floating points
+// https://en.wikipedia.org/wiki/Pairwise_summation
+template <typename ValueType, typename SumType, SimdLevel::type SimdLevel,
+ typename ValueFunc>
+enable_if_t<std::is_floating_point<SumType>::value, SumType> SumArray(
+ const ArrayData& data, ValueFunc&& func) {
+ using arrow::internal::VisitSetBitRunsVoid;
+
+ const int64_t data_size = data.length - data.GetNullCount();
+ if (data_size == 0) {
+ return 0;
+ }
+
+ // number of inputs to accumulate before merging with another block
+ constexpr int kBlockSize = 16; // same as numpy
+ // levels (tree depth) = ceil(log2(len)) + 1, a bit larger than necessary
+ const int levels = BitUtil::Log2(static_cast<uint64_t>(data_size)) + 1;
+ // temporary summation per level
+ std::vector<SumType> sum(levels);
+ // whether two summations are ready and should be reduced to upper level
+ // one bit for each level, bit0 -> level0, ...
+ uint64_t mask = 0;
+ // level of root node holding the final summation
+ int root_level = 0;
+
+ // reduce summation of one block (may be smaller than kBlockSize) from leaf node
+ // continue reducing to upper level if two summations are ready for non-leaf node
+ auto reduce = [&](SumType block_sum) {
+ int cur_level = 0;
+ uint64_t cur_level_mask = 1ULL;
+ sum[cur_level] += block_sum;
+ mask ^= cur_level_mask;
+ while ((mask & cur_level_mask) == 0) {
+ block_sum = sum[cur_level];
+ sum[cur_level] = 0;
+ ++cur_level;
+ DCHECK_LT(cur_level, levels);
+ cur_level_mask <<= 1;
+ sum[cur_level] += block_sum;
+ mask ^= cur_level_mask;
+ }
+ root_level = std::max(root_level, cur_level);
+ };
+
+ const ValueType* values = data.GetValues<ValueType>(1);
+ VisitSetBitRunsVoid(data.buffers[0], data.offset, data.length,
+ [&](int64_t pos, int64_t len) {
+ const ValueType* v = &values[pos];
+ // unsigned division by constant is cheaper than signed one
+ const uint64_t blocks = static_cast<uint64_t>(len) / kBlockSize;
+ const uint64_t remains = static_cast<uint64_t>(len) % kBlockSize;
+
+ for (uint64_t i = 0; i < blocks; ++i) {
+ SumType block_sum = 0;
+ for (int j = 0; j < kBlockSize; ++j) {
+ block_sum += func(v[j]);
+ }
+ reduce(block_sum);
+ v += kBlockSize;
+ }
+
+ if (remains > 0) {
+ SumType block_sum = 0;
+ for (uint64_t i = 0; i < remains; ++i) {
+ block_sum += func(v[i]);
+ }
+ reduce(block_sum);
+ }
+ });
+
+ // reduce intermediate summations from all non-leaf nodes
+ for (int i = 1; i <= root_level; ++i) {
+ sum[i] += sum[i - 1];
+ }
+
+ return sum[root_level];
+}
+
+// naive summation for integers and decimals
+template <typename ValueType, typename SumType, SimdLevel::type SimdLevel,
+ typename ValueFunc>
+enable_if_t<!std::is_floating_point<SumType>::value, SumType> SumArray(
+ const ArrayData& data, ValueFunc&& func) {
+ using arrow::internal::VisitSetBitRunsVoid;
+
+ SumType sum = 0;
+ const ValueType* values = data.GetValues<ValueType>(1);
+ VisitSetBitRunsVoid(data.buffers[0], data.offset, data.length,
+ [&](int64_t pos, int64_t len) {
+ for (int64_t i = 0; i < len; ++i) {
+ sum += func(values[pos + i]);
+ }
+ });
+ return sum;
+}
+
+template <typename ValueType, typename SumType, SimdLevel::type SimdLevel>
+SumType SumArray(const ArrayData& data) {
+ return SumArray<ValueType, SumType, SimdLevel>(
+ data, [](ValueType v) { return static_cast<SumType>(v); });
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/aggregate_mode.cc b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_mode.cc
new file mode 100644
index 000000000..f225f6bf5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_mode.cc
@@ -0,0 +1,419 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cmath>
+#include <queue>
+#include <utility>
+
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/kernels/aggregate_internal.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/result.h"
+#include "arrow/stl_allocator.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+namespace {
+
+using ModeState = OptionsWrapper<ModeOptions>;
+
+constexpr char kModeFieldName[] = "mode";
+constexpr char kCountFieldName[] = "count";
+
+constexpr uint64_t kCountEOF = ~0ULL;
+
+template <typename InType, typename CType = typename InType::c_type>
+Result<std::pair<CType*, int64_t*>> PrepareOutput(int64_t n, KernelContext* ctx,
+ Datum* out) {
+ const auto& mode_type = TypeTraits<InType>::type_singleton();
+ const auto& count_type = int64();
+
+ auto mode_data = ArrayData::Make(mode_type, /*length=*/n, /*null_count=*/0);
+ mode_data->buffers.resize(2, nullptr);
+ auto count_data = ArrayData::Make(count_type, n, 0);
+ count_data->buffers.resize(2, nullptr);
+
+ CType* mode_buffer = nullptr;
+ int64_t* count_buffer = nullptr;
+
+ if (n > 0) {
+ ARROW_ASSIGN_OR_RAISE(mode_data->buffers[1], ctx->Allocate(n * sizeof(CType)));
+ ARROW_ASSIGN_OR_RAISE(count_data->buffers[1], ctx->Allocate(n * sizeof(int64_t)));
+ mode_buffer = mode_data->template GetMutableValues<CType>(1);
+ count_buffer = count_data->template GetMutableValues<int64_t>(1);
+ }
+
+ const auto& out_type =
+ struct_({field(kModeFieldName, mode_type), field(kCountFieldName, count_type)});
+ *out = Datum(ArrayData::Make(out_type, n, {nullptr}, {mode_data, count_data}, 0));
+
+ return std::make_pair(mode_buffer, count_buffer);
+}
+
+// find top-n value:count pairs with minimal heap
+// suboptimal for tiny or large n, possibly okay as we're not in hot path
+template <typename InType, typename Generator>
+Status Finalize(KernelContext* ctx, Datum* out, Generator&& gen) {
+ using CType = typename InType::c_type;
+
+ using ValueCountPair = std::pair<CType, uint64_t>;
+ auto gt = [](const ValueCountPair& lhs, const ValueCountPair& rhs) {
+ const bool rhs_is_nan = rhs.first != rhs.first; // nan as largest value
+ return lhs.second > rhs.second ||
+ (lhs.second == rhs.second && (lhs.first < rhs.first || rhs_is_nan));
+ };
+
+ std::priority_queue<ValueCountPair, std::vector<ValueCountPair>, decltype(gt)> min_heap(
+ std::move(gt));
+
+ const ModeOptions& options = ModeState::Get(ctx);
+ while (true) {
+ const ValueCountPair& value_count = gen();
+ DCHECK_NE(value_count.second, 0);
+ if (value_count.second == kCountEOF) break;
+ if (static_cast<int64_t>(min_heap.size()) < options.n) {
+ min_heap.push(value_count);
+ } else if (gt(value_count, min_heap.top())) {
+ min_heap.pop();
+ min_heap.push(value_count);
+ }
+ }
+ const int64_t n = min_heap.size();
+
+ CType* mode_buffer;
+ int64_t* count_buffer;
+ ARROW_ASSIGN_OR_RAISE(std::tie(mode_buffer, count_buffer),
+ PrepareOutput<InType>(n, ctx, out));
+
+ for (int64_t i = n - 1; i >= 0; --i) {
+ std::tie(mode_buffer[i], count_buffer[i]) = min_heap.top();
+ min_heap.pop();
+ }
+
+ return Status::OK();
+}
+
+// count value occurances for integers with narrow value range
+// O(1) space, O(n) time
+template <typename T>
+struct CountModer {
+ using CType = typename T::c_type;
+
+ CType min;
+ std::vector<uint64_t> counts;
+
+ CountModer(CType min, CType max) {
+ uint32_t value_range = static_cast<uint32_t>(max - min) + 1;
+ DCHECK_LT(value_range, 1 << 20);
+ this->min = min;
+ this->counts.resize(value_range, 0);
+ }
+
+ Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // count values in all chunks, ignore nulls
+ const Datum& datum = batch[0];
+
+ const ModeOptions& options = ModeState::Get(ctx);
+ if ((!options.skip_nulls && datum.null_count() > 0) ||
+ (datum.length() - datum.null_count() < options.min_count)) {
+ return PrepareOutput<T>(/*n=*/0, ctx, out).status();
+ }
+
+ CountValues<CType>(this->counts.data(), datum, this->min);
+
+ // generator to emit next value:count pair
+ int index = 0;
+ auto gen = [&]() {
+ for (; index < static_cast<int>(counts.size()); ++index) {
+ if (counts[index] != 0) {
+ auto value_count =
+ std::make_pair(static_cast<CType>(index + this->min), counts[index]);
+ ++index;
+ return value_count;
+ }
+ }
+ return std::pair<CType, uint64_t>(0, kCountEOF);
+ };
+
+ return Finalize<T>(ctx, out, std::move(gen));
+ }
+};
+
+// booleans can be handled more straightforward
+template <>
+struct CountModer<BooleanType> {
+ Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const Datum& datum = batch[0];
+
+ const ModeOptions& options = ModeState::Get(ctx);
+ if ((!options.skip_nulls && datum.null_count() > 0) ||
+ (datum.length() - datum.null_count() < options.min_count)) {
+ return PrepareOutput<BooleanType>(/*n=*/0, ctx, out).status();
+ }
+
+ int64_t counts[2]{};
+
+ for (const auto& array : datum.chunks()) {
+ if (array->length() > array->null_count()) {
+ const int64_t true_count =
+ arrow::internal::checked_pointer_cast<BooleanArray>(array)->true_count();
+ const int64_t false_count = array->length() - array->null_count() - true_count;
+ counts[true] += true_count;
+ counts[false] += false_count;
+ }
+ }
+
+ const int64_t distinct_values = (counts[0] != 0) + (counts[1] != 0);
+ const int64_t n = std::min(options.n, distinct_values);
+
+ bool* mode_buffer;
+ int64_t* count_buffer;
+ ARROW_ASSIGN_OR_RAISE(std::tie(mode_buffer, count_buffer),
+ PrepareOutput<BooleanType>(n, ctx, out));
+
+ if (n >= 1) {
+ const bool index = counts[1] > counts[0];
+ mode_buffer[0] = index;
+ count_buffer[0] = counts[index];
+ if (n == 2) {
+ mode_buffer[1] = !index;
+ count_buffer[1] = counts[!index];
+ }
+ }
+
+ return Status::OK();
+ }
+};
+
+// copy and sort approach for floating points or integers with wide value range
+// O(n) space, O(nlogn) time
+template <typename T>
+struct SortModer {
+ using CType = typename T::c_type;
+ using Allocator = arrow::stl::allocator<CType>;
+
+ Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const Datum& datum = batch[0];
+ const int64_t in_length = datum.length() - datum.null_count();
+
+ const ModeOptions& options = ModeState::Get(ctx);
+ if ((!options.skip_nulls && datum.null_count() > 0) ||
+ (in_length < options.min_count)) {
+ return PrepareOutput<T>(/*n=*/0, ctx, out).status();
+ }
+
+ // copy all chunks to a buffer, ignore nulls and nans
+ std::vector<CType, Allocator> in_buffer(Allocator(ctx->memory_pool()));
+
+ uint64_t nan_count = 0;
+ if (in_length > 0) {
+ in_buffer.resize(in_length);
+ CopyNonNullValues(datum, in_buffer.data());
+
+ // drop nan
+ if (is_floating_type<T>::value) {
+ const auto& it = std::remove_if(in_buffer.begin(), in_buffer.end(),
+ [](CType v) { return v != v; });
+ nan_count = in_buffer.end() - it;
+ in_buffer.resize(it - in_buffer.begin());
+ }
+ }
+
+ // sort the input data to count same values
+ std::sort(in_buffer.begin(), in_buffer.end());
+
+ // generator to emit next value:count pair
+ auto it = in_buffer.cbegin();
+ auto gen = [&]() {
+ if (ARROW_PREDICT_FALSE(it == in_buffer.cend())) {
+ // handle NAN at last
+ if (nan_count > 0) {
+ auto value_count = std::make_pair(static_cast<CType>(NAN), nan_count);
+ nan_count = 0;
+ return value_count;
+ }
+ return std::pair<CType, uint64_t>(static_cast<CType>(0), kCountEOF);
+ }
+ // count same values
+ const CType value = *it;
+ uint64_t count = 0;
+ do {
+ ++it;
+ ++count;
+ } while (it != in_buffer.cend() && *it == value);
+ return std::make_pair(value, count);
+ };
+
+ return Finalize<T>(ctx, out, std::move(gen));
+ }
+};
+
+// pick counting or sorting approach per integers value range
+template <typename T>
+struct CountOrSortModer {
+ using CType = typename T::c_type;
+
+ Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // cross point to benefit from counting approach
+ // about 2x improvement for int32/64 from micro-benchmarking
+ static constexpr int kMinArraySize = 8192;
+ static constexpr int kMaxValueRange = 32768;
+
+ const Datum& datum = batch[0];
+ if (datum.length() - datum.null_count() >= kMinArraySize) {
+ CType min, max;
+ std::tie(min, max) = GetMinMax<CType>(datum);
+
+ if (static_cast<uint64_t>(max) - static_cast<uint64_t>(min) <= kMaxValueRange) {
+ return CountModer<T>(min, max).Exec(ctx, batch, out);
+ }
+ }
+
+ return SortModer<T>().Exec(ctx, batch, out);
+ }
+};
+
+template <typename InType, typename Enable = void>
+struct Moder;
+
+template <>
+struct Moder<Int8Type> {
+ CountModer<Int8Type> impl;
+ Moder() : impl(-128, 127) {}
+};
+
+template <>
+struct Moder<UInt8Type> {
+ CountModer<UInt8Type> impl;
+ Moder() : impl(0, 255) {}
+};
+
+template <>
+struct Moder<BooleanType> {
+ CountModer<BooleanType> impl;
+};
+
+template <typename InType>
+struct Moder<InType, enable_if_t<(is_integer_type<InType>::value &&
+ (sizeof(typename InType::c_type) > 1))>> {
+ CountOrSortModer<InType> impl;
+};
+
+template <typename InType>
+struct Moder<InType, enable_if_t<is_floating_type<InType>::value>> {
+ SortModer<InType> impl;
+};
+
+template <typename T>
+Status ScalarMode(KernelContext* ctx, const Scalar& scalar, Datum* out) {
+ using CType = typename T::c_type;
+
+ const ModeOptions& options = ModeState::Get(ctx);
+ if ((!options.skip_nulls && !scalar.is_valid) ||
+ (static_cast<uint32_t>(scalar.is_valid) < options.min_count)) {
+ return PrepareOutput<T>(/*n=*/0, ctx, out).status();
+ }
+
+ if (scalar.is_valid) {
+ bool called = false;
+ return Finalize<T>(ctx, out, [&]() {
+ if (!called) {
+ called = true;
+ return std::pair<CType, uint64_t>(UnboxScalar<T>::Unbox(scalar), 1);
+ }
+ return std::pair<CType, uint64_t>(static_cast<CType>(0), kCountEOF);
+ });
+ }
+ return Finalize<T>(ctx, out, []() {
+ return std::pair<CType, uint64_t>(static_cast<CType>(0), kCountEOF);
+ });
+}
+
+template <typename _, typename InType>
+struct ModeExecutor {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (ctx->state() == nullptr) {
+ return Status::Invalid("Mode requires ModeOptions");
+ }
+ const ModeOptions& options = ModeState::Get(ctx);
+ if (options.n <= 0) {
+ return Status::Invalid("ModeOption::n must be strictly positive");
+ }
+
+ if (batch[0].is_scalar()) {
+ return ScalarMode<InType>(ctx, *batch[0].scalar(), out);
+ }
+
+ return Moder<InType>().impl.Exec(ctx, batch, out);
+ }
+};
+
+VectorKernel NewModeKernel(const std::shared_ptr<DataType>& in_type) {
+ VectorKernel kernel;
+ kernel.init = ModeState::Init;
+ kernel.can_execute_chunkwise = false;
+ kernel.output_chunked = false;
+ auto out_type =
+ struct_({field(kModeFieldName, in_type), field(kCountFieldName, int64())});
+ kernel.signature =
+ KernelSignature::Make({InputType(in_type)}, ValueDescr::Array(out_type));
+ return kernel;
+}
+
+void AddBooleanModeKernel(VectorFunction* func) {
+ VectorKernel kernel = NewModeKernel(boolean());
+ kernel.exec = ModeExecutor<StructType, BooleanType>::Exec;
+ DCHECK_OK(func->AddKernel(kernel));
+}
+
+void AddNumericModeKernels(VectorFunction* func) {
+ for (const auto& type : NumericTypes()) {
+ VectorKernel kernel = NewModeKernel(type);
+ kernel.exec = GenerateNumeric<ModeExecutor, StructType>(*type);
+ DCHECK_OK(func->AddKernel(kernel));
+ }
+}
+
+const FunctionDoc mode_doc{
+ "Calculate the modal (most common) values of a numeric array",
+ ("Returns top-n most common values and number of times they occur in an array.\n"
+ "Result is an array of `struct<mode: T, count: int64>`, where T is the input type.\n"
+ "Values with larger counts are returned before smaller counts.\n"
+ "If there are more than one values with same count, smaller one is returned first.\n"
+ "Nulls are ignored. If there are no non-null values in the array,\n"
+ "empty array is returned."),
+ {"array"},
+ "ModeOptions"};
+
+} // namespace
+
+void RegisterScalarAggregateMode(FunctionRegistry* registry) {
+ static auto default_options = ModeOptions::Defaults();
+ auto func = std::make_shared<VectorFunction>("mode", Arity::Unary(), &mode_doc,
+ &default_options);
+ AddBooleanModeKernel(func.get());
+ AddNumericModeKernels(func.get());
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/aggregate_quantile.cc b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_quantile.cc
new file mode 100644
index 000000000..62e375e69
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_quantile.cc
@@ -0,0 +1,513 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cmath>
+#include <vector>
+
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/stl_allocator.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+namespace {
+
+using QuantileState = internal::OptionsWrapper<QuantileOptions>;
+
+// output is at some input data point, not interpolated
+bool IsDataPoint(const QuantileOptions& options) {
+ // some interpolation methods return exact data point
+ return options.interpolation == QuantileOptions::LOWER ||
+ options.interpolation == QuantileOptions::HIGHER ||
+ options.interpolation == QuantileOptions::NEAREST;
+}
+
+// quantile to exact datapoint index (IsDataPoint == true)
+uint64_t QuantileToDataPoint(size_t length, double q,
+ enum QuantileOptions::Interpolation interpolation) {
+ const double index = (length - 1) * q;
+ uint64_t datapoint_index = static_cast<uint64_t>(index);
+ const double fraction = index - datapoint_index;
+
+ if (interpolation == QuantileOptions::LINEAR ||
+ interpolation == QuantileOptions::MIDPOINT) {
+ DCHECK_EQ(fraction, 0);
+ }
+
+ // convert NEAREST interpolation method to LOWER or HIGHER
+ if (interpolation == QuantileOptions::NEAREST) {
+ if (fraction < 0.5) {
+ interpolation = QuantileOptions::LOWER;
+ } else if (fraction > 0.5) {
+ interpolation = QuantileOptions::HIGHER;
+ } else {
+ // round 0.5 to nearest even number, similar to numpy.around
+ interpolation =
+ (datapoint_index & 1) ? QuantileOptions::HIGHER : QuantileOptions::LOWER;
+ }
+ }
+
+ if (interpolation == QuantileOptions::HIGHER && fraction != 0) {
+ ++datapoint_index;
+ }
+
+ return datapoint_index;
+}
+
+// copy and nth_element approach, large memory footprint
+template <typename InType>
+struct SortQuantiler {
+ using CType = typename InType::c_type;
+ using Allocator = arrow::stl::allocator<CType>;
+
+ Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const QuantileOptions& options = QuantileState::Get(ctx);
+ const Datum& datum = batch[0];
+
+ // copy all chunks to a buffer, ignore nulls and nans
+ std::vector<CType, Allocator> in_buffer(Allocator(ctx->memory_pool()));
+ int64_t in_length = 0;
+ if ((!options.skip_nulls && datum.null_count() > 0) ||
+ (datum.length() - datum.null_count() < options.min_count)) {
+ in_length = 0;
+ } else {
+ in_length = datum.length() - datum.null_count();
+ }
+
+ if (in_length > 0) {
+ in_buffer.resize(in_length);
+ CopyNonNullValues(datum, in_buffer.data());
+
+ // drop nan
+ if (is_floating_type<InType>::value) {
+ const auto& it = std::remove_if(in_buffer.begin(), in_buffer.end(),
+ [](CType v) { return v != v; });
+ in_buffer.resize(it - in_buffer.begin());
+ }
+ }
+
+ // prepare out array
+ // out type depends on options
+ const bool is_datapoint = IsDataPoint(options);
+ const std::shared_ptr<DataType> out_type =
+ is_datapoint ? TypeTraits<InType>::type_singleton() : float64();
+ int64_t out_length = options.q.size();
+ if (in_buffer.empty()) {
+ return MakeArrayOfNull(out_type, out_length, ctx->memory_pool()).Value(out);
+ }
+ auto out_data = ArrayData::Make(out_type, out_length, 0);
+ out_data->buffers.resize(2, nullptr);
+
+ // calculate quantiles
+ if (out_length > 0) {
+ ARROW_ASSIGN_OR_RAISE(out_data->buffers[1],
+ ctx->Allocate(out_length * GetBitWidth(*out_type) / 8));
+
+ // find quantiles in descending order
+ std::vector<int64_t> q_indices(out_length);
+ std::iota(q_indices.begin(), q_indices.end(), 0);
+ std::sort(q_indices.begin(), q_indices.end(),
+ [&options](int64_t left_index, int64_t right_index) {
+ return options.q[right_index] < options.q[left_index];
+ });
+
+ // input array is partitioned around data point at `last_index` (pivot)
+ // for next quatile which is smaller, we only consider inputs left of the pivot
+ uint64_t last_index = in_buffer.size();
+ if (is_datapoint) {
+ CType* out_buffer = out_data->template GetMutableValues<CType>(1);
+ for (int64_t i = 0; i < out_length; ++i) {
+ const int64_t q_index = q_indices[i];
+ out_buffer[q_index] = GetQuantileAtDataPoint(
+ in_buffer, &last_index, options.q[q_index], options.interpolation);
+ }
+ } else {
+ double* out_buffer = out_data->template GetMutableValues<double>(1);
+ for (int64_t i = 0; i < out_length; ++i) {
+ const int64_t q_index = q_indices[i];
+ out_buffer[q_index] = GetQuantileByInterp(
+ in_buffer, &last_index, options.q[q_index], options.interpolation);
+ }
+ }
+ }
+
+ *out = Datum(std::move(out_data));
+ return Status::OK();
+ }
+
+ // return quantile located exactly at some input data point
+ CType GetQuantileAtDataPoint(std::vector<CType, Allocator>& in, uint64_t* last_index,
+ double q,
+ enum QuantileOptions::Interpolation interpolation) {
+ const uint64_t datapoint_index = QuantileToDataPoint(in.size(), q, interpolation);
+
+ if (datapoint_index != *last_index) {
+ DCHECK_LT(datapoint_index, *last_index);
+ std::nth_element(in.begin(), in.begin() + datapoint_index,
+ in.begin() + *last_index);
+ *last_index = datapoint_index;
+ }
+
+ return in[datapoint_index];
+ }
+
+ // return quantile interpolated from adjacent input data points
+ double GetQuantileByInterp(std::vector<CType, Allocator>& in, uint64_t* last_index,
+ double q,
+ enum QuantileOptions::Interpolation interpolation) {
+ const double index = (in.size() - 1) * q;
+ const uint64_t lower_index = static_cast<uint64_t>(index);
+ const double fraction = index - lower_index;
+
+ if (lower_index != *last_index) {
+ DCHECK_LT(lower_index, *last_index);
+ std::nth_element(in.begin(), in.begin() + lower_index, in.begin() + *last_index);
+ }
+
+ const double lower_value = static_cast<double>(in[lower_index]);
+ if (fraction == 0) {
+ *last_index = lower_index;
+ return lower_value;
+ }
+
+ const uint64_t higher_index = lower_index + 1;
+ DCHECK_LT(higher_index, in.size());
+ if (lower_index != *last_index && higher_index != *last_index) {
+ DCHECK_LT(higher_index, *last_index);
+ // higher value must be the minimal value after lower_index
+ auto min = std::min_element(in.begin() + higher_index, in.begin() + *last_index);
+ std::iter_swap(in.begin() + higher_index, min);
+ }
+ *last_index = lower_index;
+
+ const double higher_value = static_cast<double>(in[higher_index]);
+
+ if (interpolation == QuantileOptions::LINEAR) {
+ // more stable than naive linear interpolation
+ return fraction * higher_value + (1 - fraction) * lower_value;
+ } else if (interpolation == QuantileOptions::MIDPOINT) {
+ return lower_value / 2 + higher_value / 2;
+ } else {
+ DCHECK(false);
+ return NAN;
+ }
+ }
+};
+
+// histogram approach with constant memory, only for integers within limited value range
+template <typename InType>
+struct CountQuantiler {
+ using CType = typename InType::c_type;
+
+ CType min;
+ std::vector<uint64_t> counts; // counts[i]: # of values equals i + min
+
+ // indices to adjacent non-empty bins covering current quantile
+ struct AdjacentBins {
+ int left_index;
+ int right_index;
+ uint64_t total_count; // accumulated counts till left_index (inclusive)
+ };
+
+ CountQuantiler(CType min, CType max) {
+ uint32_t value_range = static_cast<uint32_t>(max - min) + 1;
+ DCHECK_LT(value_range, 1 << 30);
+ this->min = min;
+ this->counts.resize(value_range, 0);
+ }
+
+ Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const QuantileOptions& options = QuantileState::Get(ctx);
+
+ // count values in all chunks, ignore nulls
+ const Datum& datum = batch[0];
+ int64_t in_length = 0;
+ if ((options.skip_nulls || (!options.skip_nulls && datum.null_count() == 0)) &&
+ (datum.length() - datum.null_count() >= options.min_count)) {
+ in_length = CountValues<CType>(this->counts.data(), datum, this->min);
+ }
+
+ // prepare out array
+ // out type depends on options
+ const bool is_datapoint = IsDataPoint(options);
+ const std::shared_ptr<DataType> out_type =
+ is_datapoint ? TypeTraits<InType>::type_singleton() : float64();
+ int64_t out_length = options.q.size();
+ if (in_length == 0) {
+ return MakeArrayOfNull(out_type, out_length, ctx->memory_pool()).Value(out);
+ }
+ auto out_data = ArrayData::Make(out_type, out_length, 0);
+ out_data->buffers.resize(2, nullptr);
+
+ // calculate quantiles
+ if (out_length > 0) {
+ ARROW_ASSIGN_OR_RAISE(out_data->buffers[1],
+ ctx->Allocate(out_length * GetBitWidth(*out_type) / 8));
+
+ // find quantiles in ascending order
+ std::vector<int64_t> q_indices(out_length);
+ std::iota(q_indices.begin(), q_indices.end(), 0);
+ std::sort(q_indices.begin(), q_indices.end(),
+ [&options](int64_t left_index, int64_t right_index) {
+ return options.q[left_index] < options.q[right_index];
+ });
+
+ AdjacentBins bins{0, 0, this->counts[0]};
+ if (is_datapoint) {
+ CType* out_buffer = out_data->template GetMutableValues<CType>(1);
+ for (int64_t i = 0; i < out_length; ++i) {
+ const int64_t q_index = q_indices[i];
+ out_buffer[q_index] = GetQuantileAtDataPoint(
+ in_length, &bins, options.q[q_index], options.interpolation);
+ }
+ } else {
+ double* out_buffer = out_data->template GetMutableValues<double>(1);
+ for (int64_t i = 0; i < out_length; ++i) {
+ const int64_t q_index = q_indices[i];
+ out_buffer[q_index] = GetQuantileByInterp(in_length, &bins, options.q[q_index],
+ options.interpolation);
+ }
+ }
+ }
+
+ *out = Datum(std::move(out_data));
+ return Status::OK();
+ }
+
+ // return quantile located exactly at some input data point
+ CType GetQuantileAtDataPoint(int64_t in_length, AdjacentBins* bins, double q,
+ enum QuantileOptions::Interpolation interpolation) {
+ const uint64_t datapoint_index = QuantileToDataPoint(in_length, q, interpolation);
+ while (datapoint_index >= bins->total_count &&
+ static_cast<size_t>(bins->left_index) < this->counts.size() - 1) {
+ ++bins->left_index;
+ bins->total_count += this->counts[bins->left_index];
+ }
+ DCHECK_LT(datapoint_index, bins->total_count);
+ return static_cast<CType>(bins->left_index + this->min);
+ }
+
+ // return quantile interpolated from adjacent input data points
+ double GetQuantileByInterp(int64_t in_length, AdjacentBins* bins, double q,
+ enum QuantileOptions::Interpolation interpolation) {
+ const double index = (in_length - 1) * q;
+ const uint64_t index_floor = static_cast<uint64_t>(index);
+ const double fraction = index - index_floor;
+
+ while (index_floor >= bins->total_count &&
+ static_cast<size_t>(bins->left_index) < this->counts.size() - 1) {
+ ++bins->left_index;
+ bins->total_count += this->counts[bins->left_index];
+ }
+ DCHECK_LT(index_floor, bins->total_count);
+ const double lower_value = static_cast<double>(bins->left_index + this->min);
+
+ // quantile lies in this bin, no interpolation needed
+ if (index <= bins->total_count - 1) {
+ return lower_value;
+ }
+
+ // quantile lies across two bins, locate next bin if not already done
+ DCHECK_EQ(index_floor, bins->total_count - 1);
+ if (bins->right_index <= bins->left_index) {
+ bins->right_index = bins->left_index + 1;
+ while (static_cast<size_t>(bins->right_index) < this->counts.size() - 1 &&
+ this->counts[bins->right_index] == 0) {
+ ++bins->right_index;
+ }
+ }
+ DCHECK_LT(static_cast<size_t>(bins->right_index), this->counts.size());
+ DCHECK_GT(this->counts[bins->right_index], 0);
+ const double higher_value = static_cast<double>(bins->right_index + this->min);
+
+ if (interpolation == QuantileOptions::LINEAR) {
+ return fraction * higher_value + (1 - fraction) * lower_value;
+ } else if (interpolation == QuantileOptions::MIDPOINT) {
+ return lower_value / 2 + higher_value / 2;
+ } else {
+ DCHECK(false);
+ return NAN;
+ }
+ }
+};
+
+// histogram or 'copy & nth_element' approach per value range and size, only for integers
+template <typename InType>
+struct CountOrSortQuantiler {
+ using CType = typename InType::c_type;
+
+ Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // cross point to benefit from histogram approach
+ // parameters estimated from ad-hoc benchmarks manually
+ static constexpr int kMinArraySize = 65536;
+ static constexpr int kMaxValueRange = 65536;
+
+ const Datum& datum = batch[0];
+ if (datum.length() - datum.null_count() >= kMinArraySize) {
+ CType min, max;
+ std::tie(min, max) = GetMinMax<CType>(datum);
+
+ if (static_cast<uint64_t>(max) - static_cast<uint64_t>(min) <= kMaxValueRange) {
+ return CountQuantiler<InType>(min, max).Exec(ctx, batch, out);
+ }
+ }
+
+ return SortQuantiler<InType>().Exec(ctx, batch, out);
+ }
+};
+
+template <typename InType, typename Enable = void>
+struct ExactQuantiler;
+
+template <>
+struct ExactQuantiler<UInt8Type> {
+ CountQuantiler<UInt8Type> impl;
+ ExactQuantiler() : impl(0, 255) {}
+};
+
+template <>
+struct ExactQuantiler<Int8Type> {
+ CountQuantiler<Int8Type> impl;
+ ExactQuantiler() : impl(-128, 127) {}
+};
+
+template <typename InType>
+struct ExactQuantiler<InType, enable_if_t<(is_integer_type<InType>::value &&
+ (sizeof(typename InType::c_type) > 1))>> {
+ CountOrSortQuantiler<InType> impl;
+};
+
+template <typename InType>
+struct ExactQuantiler<InType, enable_if_t<is_floating_type<InType>::value>> {
+ SortQuantiler<InType> impl;
+};
+
+template <typename T>
+Status ScalarQuantile(KernelContext* ctx, const QuantileOptions& options,
+ const Scalar& scalar, Datum* out) {
+ using CType = typename T::c_type;
+ ArrayData* output = out->mutable_array();
+ output->length = options.q.size();
+ auto out_type = IsDataPoint(options) ? scalar.type : float64();
+ ARROW_ASSIGN_OR_RAISE(
+ output->buffers[1],
+ ctx->Allocate(output->length * BitUtil::BytesForBits(GetBitWidth(*out_type))));
+
+ if (!scalar.is_valid || options.min_count > 1) {
+ output->null_count = output->length;
+ ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(output->length));
+ BitUtil::SetBitsTo(output->buffers[0]->mutable_data(), /*offset=*/0, output->length,
+ false);
+ if (IsDataPoint(options)) {
+ CType* out_buffer = output->template GetMutableValues<CType>(1);
+ std::fill(out_buffer, out_buffer + output->length, CType(0));
+ } else {
+ double* out_buffer = output->template GetMutableValues<double>(1);
+ std::fill(out_buffer, out_buffer + output->length, 0.0);
+ }
+ return Status::OK();
+ }
+ output->null_count = 0;
+ if (IsDataPoint(options)) {
+ CType* out_buffer = output->template GetMutableValues<CType>(1);
+ for (int64_t i = 0; i < output->length; i++) {
+ out_buffer[i] = UnboxScalar<T>::Unbox(scalar);
+ }
+ } else {
+ double* out_buffer = output->template GetMutableValues<double>(1);
+ for (int64_t i = 0; i < output->length; i++) {
+ out_buffer[i] = static_cast<double>(UnboxScalar<T>::Unbox(scalar));
+ }
+ }
+ return Status::OK();
+}
+
+template <typename _, typename InType>
+struct QuantileExecutor {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (ctx->state() == nullptr) {
+ return Status::Invalid("Quantile requires QuantileOptions");
+ }
+
+ const QuantileOptions& options = QuantileState::Get(ctx);
+ if (options.q.empty()) {
+ return Status::Invalid("Requires quantile argument");
+ }
+ for (double q : options.q) {
+ if (q < 0 || q > 1) {
+ return Status::Invalid("Quantile must be between 0 and 1");
+ }
+ }
+
+ if (batch[0].is_scalar()) {
+ return ScalarQuantile<InType>(ctx, options, *batch[0].scalar(), out);
+ }
+
+ return ExactQuantiler<InType>().impl.Exec(ctx, batch, out);
+ }
+};
+
+Result<ValueDescr> ResolveOutput(KernelContext* ctx,
+ const std::vector<ValueDescr>& args) {
+ const QuantileOptions& options = QuantileState::Get(ctx);
+ if (IsDataPoint(options)) {
+ return ValueDescr::Array(args[0].type);
+ } else {
+ return ValueDescr::Array(float64());
+ }
+}
+
+void AddQuantileKernels(VectorFunction* func) {
+ VectorKernel base;
+ base.init = QuantileState::Init;
+ base.can_execute_chunkwise = false;
+ base.output_chunked = false;
+
+ for (const auto& ty : NumericTypes()) {
+ base.signature = KernelSignature::Make({InputType(ty)}, OutputType(ResolveOutput));
+ // output type is determined at runtime, set template argument to nulltype
+ base.exec = GenerateNumeric<QuantileExecutor, NullType>(*ty);
+ DCHECK_OK(func->AddKernel(base));
+ }
+}
+
+const FunctionDoc quantile_doc{
+ "Compute an array of quantiles of a numeric array or chunked array",
+ ("By default, 0.5 quantile (median) is returned.\n"
+ "If quantile lies between two data points, an interpolated value is\n"
+ "returned based on selected interpolation method.\n"
+ "Nulls and NaNs are ignored.\n"
+ "An array of nulls is returned if there is no valid data point."),
+ {"array"},
+ "QuantileOptions"};
+
+} // namespace
+
+void RegisterScalarAggregateQuantile(FunctionRegistry* registry) {
+ static QuantileOptions default_options;
+ auto func = std::make_shared<VectorFunction>("quantile", Arity::Unary(), &quantile_doc,
+ &default_options);
+ AddQuantileKernels(func.get());
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc
new file mode 100644
index 000000000..0fddf38f5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc
@@ -0,0 +1,235 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/kernels/aggregate_internal.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/tdigest.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+namespace {
+
+using arrow::internal::TDigest;
+using arrow::internal::VisitSetBitRunsVoid;
+
+template <typename ArrowType>
+struct TDigestImpl : public ScalarAggregator {
+ using ThisType = TDigestImpl<ArrowType>;
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+ using CType = typename ArrowType::c_type;
+
+ explicit TDigestImpl(const TDigestOptions& options)
+ : options{options},
+ tdigest{options.delta, options.buffer_size},
+ count{0},
+ all_valid{true} {}
+
+ Status Consume(KernelContext*, const ExecBatch& batch) override {
+ if (!this->all_valid) return Status::OK();
+ if (!options.skip_nulls && batch[0].null_count() > 0) {
+ this->all_valid = false;
+ return Status::OK();
+ }
+ if (batch[0].is_array()) {
+ const ArrayData& data = *batch[0].array();
+ const CType* values = data.GetValues<CType>(1);
+
+ if (data.length > data.GetNullCount()) {
+ this->count += data.length - data.GetNullCount();
+ VisitSetBitRunsVoid(data.buffers[0], data.offset, data.length,
+ [&](int64_t pos, int64_t len) {
+ for (int64_t i = 0; i < len; ++i) {
+ this->tdigest.NanAdd(values[pos + i]);
+ }
+ });
+ }
+ } else {
+ const CType value = UnboxScalar<ArrowType>::Unbox(*batch[0].scalar());
+ if (batch[0].scalar()->is_valid) {
+ this->count += 1;
+ for (int64_t i = 0; i < batch.length; i++) {
+ this->tdigest.NanAdd(value);
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ Status MergeFrom(KernelContext*, KernelState&& src) override {
+ const auto& other = checked_cast<const ThisType&>(src);
+ if (!this->all_valid || !other.all_valid) {
+ this->all_valid = false;
+ return Status::OK();
+ }
+ this->tdigest.Merge(other.tdigest);
+ this->count += other.count;
+ return Status::OK();
+ }
+
+ Status Finalize(KernelContext* ctx, Datum* out) override {
+ const int64_t out_length = options.q.size();
+ auto out_data = ArrayData::Make(float64(), out_length, 0);
+ out_data->buffers.resize(2, nullptr);
+ ARROW_ASSIGN_OR_RAISE(out_data->buffers[1],
+ ctx->Allocate(out_length * sizeof(double)));
+ double* out_buffer = out_data->template GetMutableValues<double>(1);
+
+ if (this->tdigest.is_empty() || !this->all_valid || this->count < options.min_count) {
+ ARROW_ASSIGN_OR_RAISE(out_data->buffers[0], ctx->AllocateBitmap(out_length));
+ std::memset(out_data->buffers[0]->mutable_data(), 0x00,
+ out_data->buffers[0]->size());
+ std::fill(out_buffer, out_buffer + out_length, 0.0);
+ out_data->null_count = out_length;
+ } else {
+ for (int64_t i = 0; i < out_length; ++i) {
+ out_buffer[i] = this->tdigest.Quantile(this->options.q[i]);
+ }
+ }
+ *out = Datum(std::move(out_data));
+ return Status::OK();
+ }
+
+ const TDigestOptions options;
+ TDigest tdigest;
+ int64_t count;
+ bool all_valid;
+};
+
+struct TDigestInitState {
+ std::unique_ptr<KernelState> state;
+ KernelContext* ctx;
+ const DataType& in_type;
+ const TDigestOptions& options;
+
+ TDigestInitState(KernelContext* ctx, const DataType& in_type,
+ const TDigestOptions& options)
+ : ctx(ctx), in_type(in_type), options(options) {}
+
+ Status Visit(const DataType&) {
+ return Status::NotImplemented("No tdigest implemented");
+ }
+
+ Status Visit(const HalfFloatType&) {
+ return Status::NotImplemented("No tdigest implemented");
+ }
+
+ template <typename Type>
+ enable_if_t<is_number_type<Type>::value, Status> Visit(const Type&) {
+ state.reset(new TDigestImpl<Type>(options));
+ return Status::OK();
+ }
+
+ Result<std::unique_ptr<KernelState>> Create() {
+ RETURN_NOT_OK(VisitTypeInline(in_type, this));
+ return std::move(state);
+ }
+};
+
+Result<std::unique_ptr<KernelState>> TDigestInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ TDigestInitState visitor(ctx, *args.inputs[0].type,
+ static_cast<const TDigestOptions&>(*args.options));
+ return visitor.Create();
+}
+
+void AddTDigestKernels(KernelInit init,
+ const std::vector<std::shared_ptr<DataType>>& types,
+ ScalarAggregateFunction* func) {
+ for (const auto& ty : types) {
+ auto sig = KernelSignature::Make({InputType(ty)}, float64());
+ AddAggKernel(std::move(sig), init, func);
+ }
+}
+
+const FunctionDoc tdigest_doc{
+ "Approximate quantiles of a numeric array with T-Digest algorithm",
+ ("By default, 0.5 quantile (median) is returned.\n"
+ "Nulls and NaNs are ignored.\n"
+ "An array of nulls is returned if there is no valid data point."),
+ {"array"},
+ "TDigestOptions"};
+
+const FunctionDoc approximate_median_doc{
+ "Approximate median of a numeric array with T-Digest algorithm",
+ ("Nulls and NaNs are ignored.\n"
+ "A null scalar is returned if there is no valid data point."),
+ {"array"},
+ "ScalarAggregateOptions"};
+
+std::shared_ptr<ScalarAggregateFunction> AddTDigestAggKernels() {
+ static auto default_tdigest_options = TDigestOptions::Defaults();
+ auto func = std::make_shared<ScalarAggregateFunction>(
+ "tdigest", Arity::Unary(), &tdigest_doc, &default_tdigest_options);
+ AddTDigestKernels(TDigestInit, NumericTypes(), func.get());
+ return func;
+}
+
+std::shared_ptr<ScalarAggregateFunction> AddApproximateMedianAggKernels(
+ const ScalarAggregateFunction* tdigest_func) {
+ static ScalarAggregateOptions default_scalar_aggregate_options;
+
+ auto median = std::make_shared<ScalarAggregateFunction>(
+ "approximate_median", Arity::Unary(), &approximate_median_doc,
+ &default_scalar_aggregate_options);
+
+ auto sig =
+ KernelSignature::Make({InputType(ValueDescr::ANY)}, ValueDescr::Scalar(float64()));
+
+ auto init = [tdigest_func](
+ KernelContext* ctx,
+ const KernelInitArgs& args) -> Result<std::unique_ptr<KernelState>> {
+ std::vector<ValueDescr> inputs = args.inputs;
+ ARROW_ASSIGN_OR_RAISE(auto kernel, tdigest_func->DispatchBest(&inputs));
+ const auto& scalar_options =
+ checked_cast<const ScalarAggregateOptions&>(*args.options);
+ TDigestOptions options;
+ // Default q = 0.5
+ options.min_count = scalar_options.min_count;
+ options.skip_nulls = scalar_options.skip_nulls;
+ KernelInitArgs new_args{kernel, inputs, &options};
+ return kernel->init(ctx, new_args);
+ };
+
+ auto finalize = [](KernelContext* ctx, Datum* out) -> Status {
+ Datum temp;
+ RETURN_NOT_OK(checked_cast<ScalarAggregator*>(ctx->state())->Finalize(ctx, &temp));
+ const auto arr = temp.make_array();
+ DCHECK_EQ(arr->length(), 1);
+ return arr->GetScalar(0).Value(out);
+ };
+
+ AddAggKernel(std::move(sig), std::move(init), std::move(finalize), median.get());
+ return median;
+}
+
+} // namespace
+
+void RegisterScalarAggregateTDigest(FunctionRegistry* registry) {
+ auto tdigest = AddTDigestAggKernels();
+ DCHECK_OK(registry->AddFunction(tdigest));
+
+ auto approx_median = AddApproximateMedianAggKernels(tdigest.get());
+ DCHECK_OK(registry->AddFunction(approx_median));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/aggregate_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_test.cc
new file mode 100644
index 000000000..992f73698
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_test.cc
@@ -0,0 +1,3670 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <limits>
+#include <memory>
+#include <type_traits>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/chunked_array.h"
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/kernels/aggregate_internal.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/compute/registry.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/int_util_internal.h"
+
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::BitmapReader;
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+using internal::FindAccumulatorType;
+
+//
+// Sum
+//
+
+template <typename ArrowType>
+using SumResult =
+ std::pair<typename FindAccumulatorType<ArrowType>::Type::c_type, size_t>;
+
+template <typename ArrowType>
+static SumResult<ArrowType> NaiveSumPartial(const Array& array) {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+ using ResultType = SumResult<ArrowType>;
+
+ ResultType result;
+
+ auto data = array.data();
+ const auto& array_numeric = reinterpret_cast<const ArrayType&>(array);
+ const auto values = array_numeric.raw_values();
+
+ if (array.null_count() != 0) {
+ BitmapReader reader(array.null_bitmap_data(), array.offset(), array.length());
+ for (int64_t i = 0; i < array.length(); i++) {
+ if (reader.IsSet()) {
+ result.first += values[i];
+ result.second++;
+ }
+
+ reader.Next();
+ }
+ } else {
+ for (int64_t i = 0; i < array.length(); i++) {
+ result.first += values[i];
+ result.second++;
+ }
+ }
+
+ return result;
+}
+
+template <typename ArrowType>
+static Datum NaiveSum(const Array& array) {
+ using SumType = typename FindAccumulatorType<ArrowType>::Type;
+ using SumScalarType = typename TypeTraits<SumType>::ScalarType;
+
+ auto result = NaiveSumPartial<ArrowType>(array);
+ bool is_valid = result.second > 0;
+
+ if (!is_valid) return Datum(std::make_shared<SumScalarType>());
+ return Datum(std::make_shared<SumScalarType>(result.first));
+}
+
+void ValidateSum(
+ const Datum input, Datum expected,
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) {
+ ASSERT_OK_AND_ASSIGN(Datum result, Sum(input, options));
+ AssertDatumsApproxEqual(expected, result, /*verbose=*/true);
+}
+
+template <typename ArrowType>
+void ValidateSum(
+ const char* json, Datum expected,
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) {
+ auto array = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), json);
+ ValidateSum(*array, expected, options);
+}
+
+template <typename ArrowType>
+void ValidateSum(
+ const std::vector<std::string>& json, Datum expected,
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) {
+ auto array = ChunkedArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), json);
+ ValidateSum(array, expected, options);
+}
+
+template <typename ArrowType>
+void ValidateSum(const Array& array, const ScalarAggregateOptions& options =
+ ScalarAggregateOptions::Defaults()) {
+ ValidateSum(array, NaiveSum<ArrowType>(array), options);
+}
+
+using UnaryOp = Result<Datum>(const Datum&, const ScalarAggregateOptions&, ExecContext*);
+
+template <UnaryOp& Op, typename ScalarAggregateOptions, typename ScalarType>
+void ValidateBooleanAgg(const std::string& json,
+ const std::shared_ptr<ScalarType>& expected,
+ const ScalarAggregateOptions& options) {
+ SCOPED_TRACE(json);
+ auto array = ArrayFromJSON(boolean(), json);
+ ASSERT_OK_AND_ASSIGN(Datum result, Op(array, options, nullptr));
+
+ auto equal_options = EqualOptions::Defaults().nans_equal(true);
+ AssertScalarsEqual(*expected, *result.scalar(), /*verbose=*/true, equal_options);
+}
+
+TEST(TestBooleanAggregation, Sum) {
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults();
+ ValidateBooleanAgg<Sum>("[]", std::make_shared<UInt64Scalar>(), options);
+ ValidateBooleanAgg<Sum>("[null]", std::make_shared<UInt64Scalar>(), options);
+ ValidateBooleanAgg<Sum>("[null, false]", std::make_shared<UInt64Scalar>(0), options);
+ ValidateBooleanAgg<Sum>("[true]", std::make_shared<UInt64Scalar>(1), options);
+ ValidateBooleanAgg<Sum>("[true, false, true]", std::make_shared<UInt64Scalar>(2),
+ options);
+ ValidateBooleanAgg<Sum>("[true, false, true, true, null]",
+ std::make_shared<UInt64Scalar>(3), options);
+
+ const ScalarAggregateOptions& options_min_count_zero =
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0);
+ ValidateBooleanAgg<Sum>("[]", std::make_shared<UInt64Scalar>(0),
+ options_min_count_zero);
+ ValidateBooleanAgg<Sum>("[null]", std::make_shared<UInt64Scalar>(0),
+ options_min_count_zero);
+
+ std::string json = "[true, null, false, null]";
+ ValidateBooleanAgg<Sum>(json, std::make_shared<UInt64Scalar>(1),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1));
+ ValidateBooleanAgg<Sum>(json, std::make_shared<UInt64Scalar>(1),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2));
+ ValidateBooleanAgg<Sum>(json, std::make_shared<UInt64Scalar>(),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3));
+ ValidateBooleanAgg<Sum>("[]", std::make_shared<UInt64Scalar>(0),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
+ ValidateBooleanAgg<Sum>(json, std::make_shared<UInt64Scalar>(),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
+ ValidateBooleanAgg<Sum>("[]", std::make_shared<UInt64Scalar>(),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
+ ValidateBooleanAgg<Sum>(json, std::make_shared<UInt64Scalar>(),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
+
+ json = "[true, false]";
+ ValidateBooleanAgg<Sum>(json, std::make_shared<UInt64Scalar>(1),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1));
+ ValidateBooleanAgg<Sum>(json, std::make_shared<UInt64Scalar>(1),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/2));
+ ValidateBooleanAgg<Sum>(json, std::make_shared<UInt64Scalar>(),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
+
+ EXPECT_THAT(Sum(MakeScalar(true)),
+ ResultWith(Datum(std::make_shared<UInt64Scalar>(1))));
+ EXPECT_THAT(Sum(MakeScalar(false)),
+ ResultWith(Datum(std::make_shared<UInt64Scalar>(0))));
+ EXPECT_THAT(Sum(MakeNullScalar(boolean())),
+ ResultWith(Datum(MakeNullScalar(uint64()))));
+ EXPECT_THAT(Sum(MakeNullScalar(boolean()),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ ResultWith(ScalarFromJSON(uint64(), "0")));
+ EXPECT_THAT(Sum(MakeNullScalar(boolean()),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ ResultWith(ScalarFromJSON(uint64(), "null")));
+}
+
+TEST(TestBooleanAggregation, Product) {
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults();
+ ValidateBooleanAgg<Product>("[]", std::make_shared<UInt64Scalar>(), options);
+ ValidateBooleanAgg<Product>("[null]", std::make_shared<UInt64Scalar>(), options);
+ ValidateBooleanAgg<Product>("[null, false]", std::make_shared<UInt64Scalar>(0),
+ options);
+ ValidateBooleanAgg<Product>("[true]", std::make_shared<UInt64Scalar>(1), options);
+ ValidateBooleanAgg<Product>("[true, false, true]", std::make_shared<UInt64Scalar>(0),
+ options);
+ ValidateBooleanAgg<Product>("[true, false, true, true, null]",
+ std::make_shared<UInt64Scalar>(0), options);
+
+ const ScalarAggregateOptions& options_min_count_zero =
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0);
+ ValidateBooleanAgg<Product>("[]", std::make_shared<UInt64Scalar>(1),
+ options_min_count_zero);
+ ValidateBooleanAgg<Product>("[null]", std::make_shared<UInt64Scalar>(1),
+ options_min_count_zero);
+
+ const char* json = "[true, null, true, null]";
+ ValidateBooleanAgg<Product>(
+ json, std::make_shared<UInt64Scalar>(1),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1));
+ ValidateBooleanAgg<Product>(
+ json, std::make_shared<UInt64Scalar>(1),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2));
+ ValidateBooleanAgg<Product>(
+ json, std::make_shared<UInt64Scalar>(),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3));
+ ValidateBooleanAgg<Product>(
+ "[]", std::make_shared<UInt64Scalar>(1),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
+ ValidateBooleanAgg<Product>(
+ json, std::make_shared<UInt64Scalar>(),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
+ ValidateBooleanAgg<Product>(
+ "[]", std::make_shared<UInt64Scalar>(),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
+ ValidateBooleanAgg<Product>(
+ json, std::make_shared<UInt64Scalar>(),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
+
+ EXPECT_THAT(Product(MakeScalar(true)),
+ ResultWith(Datum(std::make_shared<UInt64Scalar>(1))));
+ EXPECT_THAT(Product(MakeScalar(false)),
+ ResultWith(Datum(std::make_shared<UInt64Scalar>(0))));
+ EXPECT_THAT(Product(MakeNullScalar(boolean())),
+ ResultWith(Datum(MakeNullScalar(uint64()))));
+ EXPECT_THAT(Product(MakeNullScalar(boolean()),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ ResultWith(ScalarFromJSON(uint64(), "1")));
+ EXPECT_THAT(Product(MakeNullScalar(boolean()),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ ResultWith(ScalarFromJSON(uint64(), "null")));
+}
+
+TEST(TestBooleanAggregation, Mean) {
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults();
+ ValidateBooleanAgg<Mean>("[]", std::make_shared<DoubleScalar>(), options);
+ ValidateBooleanAgg<Mean>("[null]", std::make_shared<DoubleScalar>(), options);
+ ValidateBooleanAgg<Mean>("[null, false]", std::make_shared<DoubleScalar>(0), options);
+ ValidateBooleanAgg<Mean>("[true]", std::make_shared<DoubleScalar>(1), options);
+ ValidateBooleanAgg<Mean>("[true, false, true, false]",
+ std::make_shared<DoubleScalar>(0.5), options);
+ ValidateBooleanAgg<Mean>("[true, null]", std::make_shared<DoubleScalar>(1), options);
+ ValidateBooleanAgg<Mean>("[true, null, false, true, true]",
+ std::make_shared<DoubleScalar>(0.75), options);
+ ValidateBooleanAgg<Mean>("[true, null, false, false, false]",
+ std::make_shared<DoubleScalar>(0.25), options);
+
+ const ScalarAggregateOptions& options_min_count_zero =
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0);
+ ValidateBooleanAgg<Mean>("[]", std::make_shared<DoubleScalar>(NAN),
+ options_min_count_zero);
+ ValidateBooleanAgg<Mean>("[null]", std::make_shared<DoubleScalar>(NAN),
+ options_min_count_zero);
+
+ const char* json = "[true, null, false, null]";
+ ValidateBooleanAgg<Mean>(json, std::make_shared<DoubleScalar>(0.5),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1));
+ ValidateBooleanAgg<Mean>(json, std::make_shared<DoubleScalar>(0.5),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2));
+ ValidateBooleanAgg<Mean>(json, std::make_shared<DoubleScalar>(),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3));
+ ValidateBooleanAgg<Mean>("[]", std::make_shared<DoubleScalar>(NAN),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
+ ValidateBooleanAgg<Mean>(json, std::make_shared<DoubleScalar>(),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
+ ValidateBooleanAgg<Mean>("[]", std::make_shared<DoubleScalar>(),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
+ ValidateBooleanAgg<Mean>(json, std::make_shared<DoubleScalar>(),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
+
+ EXPECT_THAT(Mean(MakeScalar(true)), ResultWith(ScalarFromJSON(float64(), "1.0")));
+ EXPECT_THAT(Mean(MakeScalar(false)), ResultWith(ScalarFromJSON(float64(), "0.0")));
+ EXPECT_THAT(Mean(MakeNullScalar(boolean())),
+ ResultWith(ScalarFromJSON(float64(), "null")));
+ ASSERT_OK_AND_ASSIGN(
+ auto result, Mean(MakeNullScalar(boolean()),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)));
+ AssertDatumsApproxEqual(result, ScalarFromJSON(float64(), "NaN"), /*detailed=*/true,
+ EqualOptions::Defaults().nans_equal(true));
+ EXPECT_THAT(Mean(MakeNullScalar(boolean()),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ ResultWith(ScalarFromJSON(float64(), "null")));
+}
+
+template <typename ArrowType>
+class TestNumericSumKernel : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestNumericSumKernel, NumericArrowTypes);
+TYPED_TEST(TestNumericSumKernel, SimpleSum) {
+ using SumType = typename FindAccumulatorType<TypeParam>::Type;
+ using ScalarType = typename TypeTraits<SumType>::ScalarType;
+ using InputScalarType = typename TypeTraits<TypeParam>::ScalarType;
+ using T = typename TypeParam::c_type;
+
+ ValidateSum<TypeParam>("[]", Datum(std::make_shared<ScalarType>()));
+
+ ValidateSum<TypeParam>("[null]", Datum(std::make_shared<ScalarType>()));
+
+ ValidateSum<TypeParam>("[0, 1, 2, 3, 4, 5]",
+ Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
+
+ std::vector<std::string> chunks = {"[0, 1, 2, 3, 4, 5]"};
+ ValidateSum<TypeParam>(chunks,
+ Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
+
+ chunks = {"[0, 1, 2]", "[3, 4, 5]"};
+ ValidateSum<TypeParam>(chunks,
+ Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
+
+ chunks = {"[0, 1, 2]", "[]", "[3, 4, 5]"};
+ ValidateSum<TypeParam>(chunks,
+ Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
+
+ ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
+ ValidateSum<TypeParam>("[]", Datum(std::make_shared<ScalarType>(static_cast<T>(0))),
+ options);
+ ValidateSum<TypeParam>("[null]", Datum(std::make_shared<ScalarType>(static_cast<T>(0))),
+ options);
+ chunks = {};
+ ValidateSum<TypeParam>(chunks, Datum(std::make_shared<ScalarType>(static_cast<T>(0))),
+ options);
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0);
+ ValidateSum<TypeParam>("[]", Datum(std::make_shared<ScalarType>(static_cast<T>(0))),
+ options);
+ ValidateSum<TypeParam>("[null]", Datum(std::make_shared<ScalarType>()), options);
+ ValidateSum<TypeParam>("[1, null, 3, null, 3, null, 7]",
+ Datum(std::make_shared<ScalarType>()), options);
+ ValidateSum<TypeParam>("[1, null, 3, null, 3, null, 7]",
+ Datum(std::make_shared<ScalarType>(14)));
+
+ EXPECT_THAT(Sum(Datum(std::make_shared<InputScalarType>(static_cast<T>(5)))),
+ ResultWith(Datum(std::make_shared<ScalarType>(static_cast<T>(5)))));
+ EXPECT_THAT(Sum(MakeNullScalar(TypeTraits<TypeParam>::type_singleton())),
+ ResultWith(Datum(MakeNullScalar(TypeTraits<SumType>::type_singleton()))));
+}
+
+TYPED_TEST(TestNumericSumKernel, ScalarAggregateOptions) {
+ using SumType = typename FindAccumulatorType<TypeParam>::Type;
+ using ScalarType = typename TypeTraits<SumType>::ScalarType;
+ using InputScalarType = typename TypeTraits<TypeParam>::ScalarType;
+ using T = typename TypeParam::c_type;
+
+ const T expected_result = static_cast<T>(14);
+ auto null_result = Datum(std::make_shared<ScalarType>());
+ auto zero_result = Datum(std::make_shared<ScalarType>(static_cast<T>(0)));
+ auto result = Datum(std::make_shared<ScalarType>(expected_result));
+ const char* json = "[1, null, 3, null, 3, null, 7]";
+
+ ValidateSum<TypeParam>("[]", zero_result,
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0));
+ ValidateSum<TypeParam>("[null]", zero_result,
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0));
+ ValidateSum<TypeParam>(json, result,
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3));
+ ValidateSum<TypeParam>(json, result,
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/4));
+ ValidateSum<TypeParam>(json, null_result,
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/5));
+ ValidateSum<TypeParam>("[]", null_result,
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1));
+ ValidateSum<TypeParam>("[null]", null_result,
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1));
+ ValidateSum<TypeParam>("[]", zero_result,
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
+ ValidateSum<TypeParam>(json, null_result,
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
+
+ EXPECT_THAT(Sum(Datum(std::make_shared<InputScalarType>(static_cast<T>(5))),
+ ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(Datum(std::make_shared<ScalarType>(static_cast<T>(5)))));
+ EXPECT_THAT(Sum(Datum(std::make_shared<InputScalarType>(static_cast<T>(5))),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2)),
+ ResultWith(Datum(MakeNullScalar(TypeTraits<SumType>::type_singleton()))));
+ EXPECT_THAT(Sum(MakeNullScalar(TypeTraits<TypeParam>::type_singleton()),
+ ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(Datum(MakeNullScalar(TypeTraits<SumType>::type_singleton()))));
+}
+
+template <typename ArrowType>
+class TestRandomNumericSumKernel : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestRandomNumericSumKernel, NumericArrowTypes);
+TYPED_TEST(TestRandomNumericSumKernel, RandomArraySum) {
+ auto rand = random::RandomArrayGenerator(0x5487655);
+ const ScalarAggregateOptions& options =
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1);
+ // Test size up to 1<<13 (8192).
+ for (size_t i = 3; i < 14; i += 2) {
+ for (auto null_probability : {0.0, 0.001, 0.1, 0.5, 0.999, 1.0}) {
+ for (auto length_adjust : {-2, -1, 0, 1, 2}) {
+ int64_t length = (1UL << i) + length_adjust;
+ auto array = rand.Numeric<TypeParam>(length, 0, 100, null_probability);
+ ValidateSum<TypeParam>(*array, options);
+ }
+ }
+ }
+}
+
+TYPED_TEST_SUITE(TestRandomNumericSumKernel, NumericArrowTypes);
+TYPED_TEST(TestRandomNumericSumKernel, RandomArraySumOverflow) {
+ using CType = typename TypeParam::c_type;
+ using SumCType = typename FindAccumulatorType<TypeParam>::Type::c_type;
+ if (sizeof(CType) == sizeof(SumCType)) {
+ // Skip if accumulator type is same to original type
+ return;
+ }
+
+ CType max = std::numeric_limits<CType>::max();
+ CType min = std::numeric_limits<CType>::min();
+ int64_t length = 1024;
+
+ auto rand = random::RandomArrayGenerator(0x5487655);
+ const ScalarAggregateOptions& options =
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1);
+ for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
+ // Test overflow on the original type
+ auto array = rand.Numeric<TypeParam>(length, max - 200, max - 100, null_probability);
+ ValidateSum<TypeParam>(*array, options);
+ array = rand.Numeric<TypeParam>(length, min + 100, min + 200, null_probability);
+ ValidateSum<TypeParam>(*array, options);
+ }
+}
+
+TYPED_TEST(TestRandomNumericSumKernel, RandomSliceArraySum) {
+ auto arithmetic = ArrayFromJSON(TypeTraits<TypeParam>::type_singleton(),
+ "[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]");
+ ValidateSum<TypeParam>(*arithmetic);
+ for (size_t i = 1; i < 15; i++) {
+ auto slice = arithmetic->Slice(i, 16);
+ ValidateSum<TypeParam>(*slice);
+ }
+
+ // Trigger ConsumeSparse with different slice offsets.
+ auto rand = random::RandomArrayGenerator(0xfa432643);
+ const int64_t length = 1U << 5;
+ auto array = rand.Numeric<TypeParam>(length, 0, 10, 0.5);
+ for (size_t i = 1; i < 16; i++) {
+ for (size_t j = 1; j < 16; j++) {
+ auto slice = array->Slice(i, length - j);
+ ValidateSum<TypeParam>(*slice);
+ }
+ }
+}
+
+// Test round-off error
+class TestSumKernelRoundOff : public ::testing::Test {};
+
+TEST_F(TestSumKernelRoundOff, Basics) {
+ using ScalarType = typename TypeTraits<DoubleType>::ScalarType;
+
+ // array = np.arange(321000, dtype='float64')
+ // array -= np.mean(array)
+ // array *= arrray
+ double index = 0;
+ ASSERT_OK_AND_ASSIGN(
+ auto array, ArrayFromBuilderVisitor(
+ float64(), 321000, [&](NumericBuilder<DoubleType>* builder) {
+ builder->UnsafeAppend((index - 160499.5) * (index - 160499.5));
+ ++index;
+ }));
+
+ // reference value from numpy.sum()
+ ASSERT_OK_AND_ASSIGN(Datum result, Sum(array));
+ auto sum = checked_cast<const ScalarType*>(result.scalar().get());
+ ASSERT_EQ(sum->value, 2756346749973250.0);
+}
+
+TEST(TestDecimalSumKernel, SimpleSum) {
+ for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, R"([])")),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, R"([null])")),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+ EXPECT_THAT(
+ Sum(ArrayFromJSON(ty, R"(["0.00", "1.01", "2.02", "3.03", "4.04", "5.05"])")),
+ ResultWith(ScalarFromJSON(ty, R"("15.15")")));
+ Datum chunks =
+ ChunkedArrayFromJSON(ty, {R"(["0.00", "1.01", "2.02", "3.03", "4.04", "5.05"])"});
+ EXPECT_THAT(Sum(chunks), ResultWith(ScalarFromJSON(ty, R"("15.15")")));
+ chunks = ChunkedArrayFromJSON(
+ ty, {R"(["0.00", "1.01", "2.02"])", R"(["3.03", "4.04", "5.05"])"});
+ EXPECT_THAT(Sum(chunks), ResultWith(ScalarFromJSON(ty, R"("15.15")")));
+ chunks = ChunkedArrayFromJSON(
+ ty, {R"(["0.00", "1.01", "2.02"])", "[]", R"(["3.03", "4.04", "5.05"])"});
+ EXPECT_THAT(Sum(chunks), ResultWith(ScalarFromJSON(ty, R"("15.15")")));
+
+ ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, R"([])"), options),
+ ResultWith(ScalarFromJSON(ty, R"("0.00")")));
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, R"([null])"), options),
+ ResultWith(ScalarFromJSON(ty, R"("0.00")")));
+ chunks = ChunkedArrayFromJSON(ty, {});
+ EXPECT_THAT(Sum(chunks, options), ResultWith(ScalarFromJSON(ty, R"("0.00")")));
+
+ EXPECT_THAT(
+ Sum(ArrayFromJSON(ty, R"(["1.01", null, "3.03", null, "5.05", null, "7.07"])"),
+ options),
+ ResultWith(ScalarFromJSON(ty, R"("16.16")")));
+
+ EXPECT_THAT(Sum(ScalarFromJSON(ty, R"("5.05")")),
+ ResultWith(ScalarFromJSON(ty, R"("5.05")")));
+ EXPECT_THAT(Sum(ScalarFromJSON(ty, R"(null)")),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+ EXPECT_THAT(Sum(ScalarFromJSON(ty, R"(null)"), options),
+ ResultWith(ScalarFromJSON(ty, R"("0.00")")));
+ }
+}
+
+TEST(TestDecimalSumKernel, ScalarAggregateOptions) {
+ for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ Datum null = ScalarFromJSON(ty, R"(null)");
+ Datum zero = ScalarFromJSON(ty, R"("0.00")");
+ Datum result = ScalarFromJSON(ty, R"("14.14")");
+ Datum arr =
+ ArrayFromJSON(ty, R"(["1.01", null, "3.03", null, "3.03", null, "7.07"])");
+
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, "[]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ ResultWith(zero));
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, "[null]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ ResultWith(zero));
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3)),
+ ResultWith(result));
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/4)),
+ ResultWith(result));
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/5)),
+ ResultWith(null));
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, "[]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)),
+ ResultWith(null));
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, "[null]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)),
+ ResultWith(null));
+
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)),
+ ResultWith(null));
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)),
+ ResultWith(null));
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)),
+ ResultWith(null));
+
+ arr = ArrayFromJSON(ty, R"(["1.01", "3.03", "3.03", "7.07"])");
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)),
+ ResultWith(result));
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)),
+ ResultWith(result));
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)),
+ ResultWith(null));
+
+ EXPECT_THAT(Sum(ScalarFromJSON(ty, R"("5.05")"),
+ ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(ScalarFromJSON(ty, R"("5.05")")));
+ EXPECT_THAT(Sum(ScalarFromJSON(ty, R"("5.05")"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2)),
+ ResultWith(null));
+ EXPECT_THAT(Sum(null, ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(null));
+ }
+}
+
+//
+// Product
+//
+
+template <typename ArrowType>
+class TestNumericProductKernel : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestNumericProductKernel, NumericArrowTypes);
+TYPED_TEST(TestNumericProductKernel, SimpleProduct) {
+ using ProductType = typename FindAccumulatorType<TypeParam>::Type;
+ using T = typename TypeParam::c_type;
+ using ProductT = typename ProductType::c_type;
+
+ Datum null_result(std::make_shared<typename TypeTraits<ProductType>::ScalarType>());
+
+ auto ty = TypeTraits<TypeParam>::type_singleton();
+
+ EXPECT_THAT(Product(ArrayFromJSON(ty, "[]")), ResultWith(null_result));
+ EXPECT_THAT(Product(ArrayFromJSON(ty, "[null]")), ResultWith(null_result));
+ EXPECT_THAT(Product(ArrayFromJSON(ty, "[0, 1, 2, 3, 4, 5]")),
+ ResultWith(Datum(static_cast<ProductT>(0))));
+ Datum chunks = ChunkedArrayFromJSON(ty, {"[1, 2, 3, 4, 5]"});
+ EXPECT_THAT(Product(chunks), ResultWith(Datum(static_cast<ProductT>(120))));
+ chunks = ChunkedArrayFromJSON(ty, {"[1, 2]", "[3, 4, 5]"});
+ EXPECT_THAT(Product(chunks), ResultWith(Datum(static_cast<ProductT>(120))));
+ chunks = ChunkedArrayFromJSON(ty, {"[1, 2]", "[]", "[3, 4, 5]"});
+ EXPECT_THAT(Product(chunks), ResultWith(Datum(static_cast<ProductT>(120))));
+
+ ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
+ EXPECT_THAT(Product(ArrayFromJSON(ty, "[]"), options), Datum(static_cast<ProductT>(1)));
+ EXPECT_THAT(Product(ArrayFromJSON(ty, "[null]"), options),
+ Datum(static_cast<ProductT>(1)));
+ chunks = ChunkedArrayFromJSON(ty, {});
+ EXPECT_THAT(Product(chunks, options), Datum(static_cast<ProductT>(1)));
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0);
+ EXPECT_THAT(Product(ArrayFromJSON(ty, "[]"), options),
+ ResultWith(Datum(static_cast<ProductT>(1))));
+ EXPECT_THAT(Product(ArrayFromJSON(ty, "[null]"), options), ResultWith(null_result));
+ EXPECT_THAT(Product(ArrayFromJSON(ty, "[1, null, 3, null, 3, null, 7]"), options),
+ ResultWith(null_result));
+ EXPECT_THAT(Product(ArrayFromJSON(ty, "[1, null, 3, null, 3, null, 7]")),
+ Datum(static_cast<ProductT>(63)));
+
+ EXPECT_THAT(Product(Datum(static_cast<T>(5))),
+ ResultWith(Datum(static_cast<ProductT>(5))));
+ EXPECT_THAT(Product(MakeNullScalar(TypeTraits<TypeParam>::type_singleton())),
+ ResultWith(null_result));
+}
+
+TYPED_TEST(TestNumericProductKernel, ScalarAggregateOptions) {
+ using ProductType = typename FindAccumulatorType<TypeParam>::Type;
+ using T = typename TypeParam::c_type;
+ using ProductT = typename ProductType::c_type;
+
+ Datum null_result(std::make_shared<typename TypeTraits<ProductType>::ScalarType>());
+ Datum one_result(static_cast<ProductT>(1));
+ Datum result(static_cast<ProductT>(63));
+
+ auto ty = TypeTraits<TypeParam>::type_singleton();
+ Datum empty = ArrayFromJSON(ty, "[]");
+ Datum null = ArrayFromJSON(ty, "[null]");
+ Datum arr = ArrayFromJSON(ty, "[1, null, 3, null, 3, null, 7]");
+
+ EXPECT_THAT(
+ Product(empty, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ ResultWith(one_result));
+ EXPECT_THAT(Product(null, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ ResultWith(one_result));
+ EXPECT_THAT(Product(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3)),
+ ResultWith(result));
+ EXPECT_THAT(Product(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/4)),
+ ResultWith(result));
+ EXPECT_THAT(Product(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/5)),
+ ResultWith(null_result));
+ EXPECT_THAT(
+ Product(empty, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)),
+ ResultWith(null_result));
+ EXPECT_THAT(Product(null, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)),
+ ResultWith(null_result));
+ EXPECT_THAT(
+ Product(empty, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ ResultWith(one_result));
+ EXPECT_THAT(Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ ResultWith(null_result));
+
+ EXPECT_THAT(
+ Product(Datum(static_cast<T>(5)), ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(Datum(static_cast<ProductT>(5))));
+ EXPECT_THAT(Product(Datum(static_cast<T>(5)),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2)),
+ ResultWith(null_result));
+ EXPECT_THAT(Product(MakeNullScalar(TypeTraits<TypeParam>::type_singleton()),
+ ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(null_result));
+}
+
+TEST(TestDecimalProductKernel, SimpleProduct) {
+ for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ Datum null = ScalarFromJSON(ty, R"(null)");
+
+ EXPECT_THAT(Product(ArrayFromJSON(ty, R"([])")), ResultWith(null));
+ EXPECT_THAT(Product(ArrayFromJSON(ty, R"([null])")), ResultWith(null));
+ EXPECT_THAT(
+ Product(ArrayFromJSON(ty, R"(["0.00", "1.00", "2.00", "3.00", "4.00", "5.00"])")),
+ ResultWith(ScalarFromJSON(ty, R"("0.00")")));
+ Datum chunks =
+ ChunkedArrayFromJSON(ty, {R"(["1.00", "2.00", "3.00", "4.00", "5.00"])"});
+ EXPECT_THAT(Product(chunks), ResultWith(ScalarFromJSON(ty, R"("120.00")")));
+ chunks =
+ ChunkedArrayFromJSON(ty, {R"(["1.00", "2.00"])", R"(["-3.00", "4.00", "5.00"])"});
+ EXPECT_THAT(Product(chunks), ResultWith(ScalarFromJSON(ty, R"("-120.00")")));
+ chunks = ChunkedArrayFromJSON(
+ ty, {R"(["1.00", "2.00"])", R"([])", R"(["-3.00", "4.00", "-5.00"])"});
+ EXPECT_THAT(Product(chunks), ResultWith(ScalarFromJSON(ty, R"("120.00")")));
+
+ const ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
+
+ EXPECT_THAT(Product(ArrayFromJSON(ty, R"([])"), options),
+ ResultWith(ScalarFromJSON(ty, R"("1.00")")));
+ EXPECT_THAT(Product(ArrayFromJSON(ty, R"([null])"), options),
+ ResultWith(ScalarFromJSON(ty, R"("1.00")")));
+ chunks = ChunkedArrayFromJSON(ty, {});
+ EXPECT_THAT(Product(chunks, options), ResultWith(ScalarFromJSON(ty, R"("1.00")")));
+
+ EXPECT_THAT(Product(ArrayFromJSON(
+ ty, R"(["1.00", null, "-3.00", null, "3.00", null, "7.00"])"),
+ options),
+ ResultWith(ScalarFromJSON(ty, R"("-63.00")")));
+
+ EXPECT_THAT(Product(ScalarFromJSON(ty, R"("5.00")")),
+ ResultWith(ScalarFromJSON(ty, R"("5.00")")));
+ EXPECT_THAT(Product(null), ResultWith(null));
+ }
+}
+
+TEST(TestDecimalProductKernel, ScalarAggregateOptions) {
+ for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ Datum null = ScalarFromJSON(ty, R"(null)");
+ Datum one = ScalarFromJSON(ty, R"("1.00")");
+ Datum result = ScalarFromJSON(ty, R"("63.00")");
+
+ Datum empty = ArrayFromJSON(ty, R"([])");
+ Datum null_arr = ArrayFromJSON(ty, R"([null])");
+ Datum arr =
+ ArrayFromJSON(ty, R"(["1.00", null, "3.00", null, "3.00", null, "7.00"])");
+
+ EXPECT_THAT(
+ Product(empty, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ ResultWith(one));
+ EXPECT_THAT(
+ Product(null_arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ ResultWith(one));
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3)),
+ ResultWith(result));
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/4)),
+ ResultWith(result));
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/5)),
+ ResultWith(null));
+ EXPECT_THAT(
+ Product(empty, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)),
+ ResultWith(null));
+ EXPECT_THAT(
+ Product(null_arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)),
+ ResultWith(null));
+
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)),
+ ResultWith(null));
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)),
+ ResultWith(null));
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)),
+ ResultWith(null));
+
+ arr = ArrayFromJSON(ty, R"(["1.00", "3.00", "3.00", "7.00"])");
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)),
+ ResultWith(result));
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)),
+ ResultWith(result));
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)),
+ ResultWith(null));
+
+ EXPECT_THAT(Product(ScalarFromJSON(ty, R"("5.00")"),
+ ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(ScalarFromJSON(ty, R"("5.00")")));
+ EXPECT_THAT(Product(ScalarFromJSON(ty, R"("5.00")"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2)),
+ ResultWith(null));
+ EXPECT_THAT(Product(null, ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(null));
+ }
+}
+
+TEST(TestProductKernel, Overflow) {
+ EXPECT_THAT(Product(ArrayFromJSON(int64(), "[8589934592, 8589934593]")),
+ ResultWith(Datum(static_cast<int64_t>(8589934592))));
+}
+
+//
+// Count
+//
+
+using CountPair = std::pair<int64_t, int64_t>;
+
+static CountPair NaiveCount(const Array& array) {
+ CountPair count;
+
+ count.first = array.length() - array.null_count();
+ count.second = array.null_count();
+
+ return count;
+}
+
+void ValidateCount(const Array& input, CountPair expected) {
+ CountOptions non_null;
+ CountOptions nulls(CountOptions::ONLY_NULL);
+ CountOptions all(CountOptions::ALL);
+
+ ASSERT_OK_AND_ASSIGN(Datum result, Count(input, non_null));
+ AssertDatumsEqual(result, Datum(expected.first));
+
+ ASSERT_OK_AND_ASSIGN(result, Count(input, nulls));
+ AssertDatumsEqual(result, Datum(expected.second));
+
+ ASSERT_OK_AND_ASSIGN(result, Count(input, all));
+ AssertDatumsEqual(result, Datum(expected.first + expected.second));
+}
+
+template <typename ArrowType>
+void ValidateCount(const char* json, CountPair expected) {
+ auto array = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), json);
+ ValidateCount(*array, expected);
+}
+
+void ValidateCount(const Array& input) { ValidateCount(input, NaiveCount(input)); }
+
+template <typename ArrowType>
+class TestCountKernel : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestCountKernel, NumericArrowTypes);
+TYPED_TEST(TestCountKernel, SimpleCount) {
+ ValidateCount<TypeParam>("[]", {0, 0});
+ ValidateCount<TypeParam>("[null]", {0, 1});
+ ValidateCount<TypeParam>("[1, null, 2]", {2, 1});
+ ValidateCount<TypeParam>("[null, null, null]", {0, 3});
+ ValidateCount<TypeParam>("[1, 2, 3, 4, 5, 6, 7, 8, 9]", {9, 0});
+
+ auto ty = TypeTraits<TypeParam>::type_singleton();
+ EXPECT_THAT(Count(MakeNullScalar(ty)), ResultWith(Datum(int64_t(0))));
+ EXPECT_THAT(Count(MakeNullScalar(ty), CountOptions(CountOptions::ONLY_NULL)),
+ ResultWith(Datum(int64_t(1))));
+ EXPECT_THAT(Count(*MakeScalar(ty, 1)), ResultWith(Datum(int64_t(1))));
+ EXPECT_THAT(Count(*MakeScalar(ty, 1), CountOptions(CountOptions::ONLY_NULL)),
+ ResultWith(Datum(int64_t(0))));
+
+ CountOptions all(CountOptions::ALL);
+ EXPECT_THAT(Count(MakeNullScalar(ty), all), ResultWith(Datum(int64_t(1))));
+ EXPECT_THAT(Count(*MakeScalar(ty, 1), all), ResultWith(Datum(int64_t(1))));
+}
+
+template <typename ArrowType>
+class TestRandomNumericCountKernel : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestRandomNumericCountKernel, NumericArrowTypes);
+TYPED_TEST(TestRandomNumericCountKernel, RandomArrayCount) {
+ auto rand = random::RandomArrayGenerator(0x1205643);
+ for (size_t i = 3; i < 10; i++) {
+ for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) {
+ for (auto length_adjust : {-2, -1, 0, 1, 2}) {
+ int64_t length = (1UL << i) + length_adjust;
+ auto array = rand.Numeric<TypeParam>(length, 0, 100, null_probability);
+ ValidateCount(*array);
+ }
+ }
+ }
+}
+
+//
+// Count Distinct
+//
+
+class TestCountDistinctKernel : public ::testing::Test {
+ protected:
+ Datum Expected(int64_t value) { return MakeScalar(static_cast<int64_t>(value)); }
+
+ void Check(Datum input, int64_t expected_all, bool has_nulls = true) {
+ int64_t expected_valid = has_nulls ? expected_all - 1 : expected_all;
+ int64_t expected_null = has_nulls ? 1 : 0;
+ CheckScalar("count_distinct", {input}, Expected(expected_valid), &only_valid);
+ CheckScalar("count_distinct", {input}, Expected(expected_null), &only_null);
+ CheckScalar("count_distinct", {input}, Expected(expected_all), &all);
+ }
+
+ void Check(const std::shared_ptr<DataType>& type, util::string_view json,
+ int64_t expected_all, bool has_nulls = true) {
+ Check(ArrayFromJSON(type, json), expected_all, has_nulls);
+ }
+
+ void Check(const std::shared_ptr<DataType>& type, util::string_view json) {
+ auto input = ScalarFromJSON(type, json);
+ auto zero = ResultWith(Expected(0));
+ auto one = ResultWith(Expected(1));
+ // non null scalar
+ EXPECT_THAT(CallFunction("count_distinct", {input}, &only_valid), one);
+ EXPECT_THAT(CallFunction("count_distinct", {input}, &only_null), zero);
+ EXPECT_THAT(CallFunction("count_distinct", {input}, &all), one);
+ // null scalar
+ input = MakeNullScalar(input->type);
+ EXPECT_THAT(CallFunction("count_distinct", {input}, &only_valid), zero);
+ EXPECT_THAT(CallFunction("count_distinct", {input}, &only_null), one);
+ EXPECT_THAT(CallFunction("count_distinct", {input}, &all), one);
+ }
+
+ CountOptions only_valid{CountOptions::ONLY_VALID};
+ CountOptions only_null{CountOptions::ONLY_NULL};
+ CountOptions all{CountOptions::ALL};
+};
+
+TEST_F(TestCountDistinctKernel, AllArrayTypesWithNulls) {
+ // Boolean
+ Check(boolean(), "[]", 0, /*has_nulls=*/false);
+ Check(boolean(), "[true, null, false, null, false, true]", 3);
+ // Number
+ for (auto ty : NumericTypes()) {
+ Check(ty, "[1, 1, null, 2, 5, 8, 9, 9, null, 10, 6, 6]", 8);
+ Check(ty, "[1, 1, 8, 2, 5, 8, 9, 9, 10, 10, 6, 6]", 7, /*has_nulls=*/false);
+ }
+ // Date
+ Check(date32(), "[0, 11016, 0, null, 14241, 14241, null]", 4);
+ Check(date64(), "[0, null, 0, null, 0, 0, 1262217600000]", 3);
+ // Time
+ Check(time32(TimeUnit::SECOND), "[0, 11, 0, null, 14, 14, null]", 4);
+ Check(time32(TimeUnit::MILLI), "[0, 11000, 0, null, 11000, 11000]", 3);
+ Check(time64(TimeUnit::MICRO), "[84203999999, 0, null, 84203999999, 0]", 3);
+ Check(time64(TimeUnit::NANO), "[11715003000000, 0, null, 0, 0]", 3);
+ // Timestamp & Duration
+ for (auto u : TimeUnit::values()) {
+ Check(duration(u), "[123456789, null, 987654321, 123456789, null]", 3);
+ Check(duration(u), "[123456789, 987654321, 123456789, 123456789]", 2,
+ /*has_nulls=*/false);
+ auto ts = R"(["2009-12-31T04:20:20", "2020-01-01", null, "2009-12-31T04:20:20"])";
+ Check(timestamp(u), ts, 3);
+ Check(timestamp(u, "Pacific/Marquesas"), ts, 3);
+ }
+ // Interval
+ Check(month_interval(), "[9012, 5678, null, 9012, 5678, null, 9012]", 3);
+ Check(day_time_interval(), "[[0, 1], [0, 1], null, [0, 1], [1234, 5678]]", 3);
+ Check(month_day_nano_interval(), "[[0, 1, 2], [0, 1, 2], null, [0, 1, 2]]", 2);
+ // Binary & String & Fixed binary
+ auto samples = R"([null, "abc", null, "abc", "abc", "cba", "bca", "cba", null])";
+ Check(binary(), samples, 4);
+ Check(large_binary(), samples, 4);
+ Check(utf8(), samples, 4);
+ Check(large_utf8(), samples, 4);
+ Check(fixed_size_binary(3), samples, 4);
+ // Decimal
+ samples = R"(["12345.679", "98765.421", null, "12345.679", "98765.421"])";
+ Check(decimal128(21, 3), samples, 3);
+ Check(decimal256(13, 3), samples, 3);
+}
+
+TEST_F(TestCountDistinctKernel, AllScalarTypesWithNulls) {
+ // Boolean
+ Check(boolean(), "true");
+ // Number
+ for (auto ty : NumericTypes()) {
+ Check(ty, "91");
+ }
+ // Date
+ Check(date32(), "11016");
+ Check(date64(), "1262217600000");
+ // Time
+ Check(time32(TimeUnit::SECOND), "14");
+ Check(time32(TimeUnit::MILLI), "11000");
+ Check(time64(TimeUnit::MICRO), "84203999999");
+ Check(time64(TimeUnit::NANO), "11715003000000");
+ // Timestamp & Duration
+ for (auto u : TimeUnit::values()) {
+ Check(duration(u), "987654321");
+ Check(duration(u), "123456789");
+ auto ts = R"("2009-12-31T04:20:20")";
+ Check(timestamp(u), ts);
+ Check(timestamp(u, "Pacific/Marquesas"), ts);
+ }
+ // Interval
+ Check(month_interval(), "5678");
+ Check(day_time_interval(), "[1234, 5678]");
+ Check(month_day_nano_interval(), "[0, 1, 2]");
+ // Binary & String & Fixed binary
+ auto sample = R"("cba")";
+ Check(binary(), sample);
+ Check(large_binary(), sample);
+ Check(utf8(), sample);
+ Check(large_utf8(), sample);
+ Check(fixed_size_binary(3), sample);
+ // Decimal
+ sample = R"("98765.421")";
+ Check(decimal128(21, 3), sample);
+ Check(decimal256(13, 3), sample);
+}
+
+TEST_F(TestCountDistinctKernel, Random) {
+ UInt32Builder builder;
+ std::unordered_set<uint32_t> memo;
+ auto visit_null = []() { return Status::OK(); };
+ auto visit_value = [&](uint32_t arg) {
+ const bool inserted = memo.insert(arg).second;
+ if (inserted) {
+ return builder.Append(arg);
+ }
+ return Status::OK();
+ };
+ auto rand = random::RandomArrayGenerator(0x1205643);
+ auto arr = rand.Numeric<UInt32Type>(1024, 0, 100, 0.0)->data();
+ auto r = VisitArrayDataInline<UInt32Type>(*arr, visit_value, visit_null);
+ auto input = builder.Finish().ValueOrDie();
+ Check(input, memo.size(), false);
+}
+
+//
+// Mean
+//
+
+template <typename ArrowType>
+static Datum NaiveMean(const Array& array) {
+ using MeanScalarType = typename TypeTraits<DoubleType>::ScalarType;
+
+ const auto result = NaiveSumPartial<ArrowType>(array);
+ const double mean = static_cast<double>(result.first) /
+ static_cast<double>(result.second ? result.second : 1UL);
+ const bool is_valid = result.second > 0;
+
+ if (!is_valid) return Datum(std::make_shared<MeanScalarType>());
+ return Datum(std::make_shared<MeanScalarType>(mean));
+}
+
+template <typename ArrowType>
+void ValidateMean(const Array& input, Datum expected,
+ const ScalarAggregateOptions& options) {
+ ASSERT_OK_AND_ASSIGN(Datum result, Mean(input, options, nullptr));
+ auto equal_options = EqualOptions::Defaults().nans_equal(true);
+ AssertDatumsApproxEqual(expected, result, /*verbose=*/true, equal_options);
+}
+
+template <typename ArrowType>
+void ValidateMean(
+ const std::string& json, Datum expected,
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) {
+ auto array = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), json);
+ ValidateMean<ArrowType>(*array, expected, options);
+}
+
+template <typename ArrowType>
+void ValidateMean(const Array& array, const ScalarAggregateOptions& options =
+ ScalarAggregateOptions::Defaults()) {
+ ValidateMean<ArrowType>(array, NaiveMean<ArrowType>(array), options);
+}
+
+template <typename ArrowType>
+class TestNumericMeanKernel : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestNumericMeanKernel, NumericArrowTypes);
+TYPED_TEST(TestNumericMeanKernel, SimpleMean) {
+ using ScalarType = typename TypeTraits<DoubleType>::ScalarType;
+ using InputScalarType = typename TypeTraits<TypeParam>::ScalarType;
+ using T = typename TypeParam::c_type;
+
+ const ScalarAggregateOptions& options =
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0);
+
+ ValidateMean<TypeParam>("[]", Datum(std::make_shared<ScalarType>(NAN)), options);
+
+ ValidateMean<TypeParam>("[null]", Datum(std::make_shared<ScalarType>(NAN)), options);
+
+ ValidateMean<TypeParam>("[]", Datum(std::make_shared<ScalarType>()));
+
+ ValidateMean<TypeParam>("[null]", Datum(std::make_shared<ScalarType>()));
+
+ ValidateMean<TypeParam>("[1, null, 1]", Datum(std::make_shared<ScalarType>(1.0)));
+
+ ValidateMean<TypeParam>("[1, 2, 3, 4, 5, 6, 7, 8]",
+ Datum(std::make_shared<ScalarType>(4.5)));
+
+ ValidateMean<TypeParam>("[0, 0, 0, 0, 0, 0, 0, 0]",
+ Datum(std::make_shared<ScalarType>(0.0)));
+
+ ValidateMean<TypeParam>("[1, 1, 1, 1, 1, 1, 1, 1]",
+ Datum(std::make_shared<ScalarType>(1.0)));
+
+ EXPECT_THAT(Mean(Datum(std::make_shared<InputScalarType>(static_cast<T>(5)))),
+ ResultWith(Datum(std::make_shared<ScalarType>(5.0))));
+ EXPECT_THAT(Mean(MakeNullScalar(TypeTraits<TypeParam>::type_singleton())),
+ ResultWith(Datum(MakeNullScalar(float64()))));
+}
+
+TYPED_TEST(TestNumericMeanKernel, ScalarAggregateOptions) {
+ using ScalarType = typename TypeTraits<DoubleType>::ScalarType;
+ using InputScalarType = typename TypeTraits<TypeParam>::ScalarType;
+ using T = typename TypeParam::c_type;
+ auto expected_result = Datum(std::make_shared<ScalarType>(3));
+ auto null_result = Datum(std::make_shared<ScalarType>());
+ auto nan_result = Datum(std::make_shared<ScalarType>(NAN));
+ std::string json = "[1, null, 2, 2, null, 7]";
+
+ ValidateMean<TypeParam>("[]", nan_result,
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0));
+ ValidateMean<TypeParam>("[null]", nan_result,
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0));
+ ValidateMean<TypeParam>("[]", null_result,
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1));
+ ValidateMean<TypeParam>("[null]", null_result,
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1));
+ ValidateMean<TypeParam>(json, expected_result,
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0));
+ ValidateMean<TypeParam>(json, expected_result,
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3));
+ ValidateMean<TypeParam>(json, expected_result,
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/4));
+ ValidateMean<TypeParam>(json, null_result,
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/5));
+
+ ValidateMean<TypeParam>("[]", nan_result,
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
+ ValidateMean<TypeParam>("[null]", null_result,
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
+ ValidateMean<TypeParam>(json, null_result,
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
+
+ ValidateMean<TypeParam>("[]", null_result,
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1));
+ ValidateMean<TypeParam>("[null]", null_result,
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1));
+ ValidateMean<TypeParam>(json, null_result,
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1));
+
+ json = "[1, 2, 2, 7]";
+ ValidateMean<TypeParam>(json, expected_result,
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1));
+ ValidateMean<TypeParam>(json, expected_result,
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4));
+ ValidateMean<TypeParam>(json, null_result,
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5));
+
+ EXPECT_THAT(Mean(Datum(std::make_shared<InputScalarType>(static_cast<T>(5))),
+ ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(Datum(std::make_shared<ScalarType>(5.0))));
+ EXPECT_THAT(Mean(Datum(std::make_shared<InputScalarType>(static_cast<T>(5))),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2)),
+ ResultWith(Datum(MakeNullScalar(float64()))));
+ EXPECT_THAT(Mean(MakeNullScalar(TypeTraits<TypeParam>::type_singleton()),
+ ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(Datum(MakeNullScalar(float64()))));
+}
+
+template <typename ArrowType>
+class TestRandomNumericMeanKernel : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestRandomNumericMeanKernel, NumericArrowTypes);
+TYPED_TEST(TestRandomNumericMeanKernel, RandomArrayMean) {
+ auto rand = random::RandomArrayGenerator(0x8afc055);
+ const ScalarAggregateOptions& options =
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1);
+ // Test size up to 1<<13 (8192).
+ for (size_t i = 3; i < 14; i += 2) {
+ for (auto null_probability : {0.0, 0.001, 0.1, 0.5, 0.999, 1.0}) {
+ for (auto length_adjust : {-2, -1, 0, 1, 2}) {
+ int64_t length = (1UL << i) + length_adjust;
+ auto array = rand.Numeric<TypeParam>(length, 0, 100, null_probability);
+ ValidateMean<TypeParam>(*array, options);
+ }
+ }
+ }
+}
+
+TYPED_TEST_SUITE(TestRandomNumericMeanKernel, NumericArrowTypes);
+TYPED_TEST(TestRandomNumericMeanKernel, RandomArrayMeanOverflow) {
+ using CType = typename TypeParam::c_type;
+ using SumCType = typename FindAccumulatorType<TypeParam>::Type::c_type;
+ if (sizeof(CType) == sizeof(SumCType)) {
+ // Skip if accumulator type is same to original type
+ return;
+ }
+
+ CType max = std::numeric_limits<CType>::max();
+ CType min = std::numeric_limits<CType>::min();
+ int64_t length = 1024;
+
+ auto rand = random::RandomArrayGenerator(0x8afc055);
+ const ScalarAggregateOptions& options =
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1);
+ for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
+ // Test overflow on the original type
+ auto array = rand.Numeric<TypeParam>(length, max - 200, max - 100, null_probability);
+ ValidateMean<TypeParam>(*array, options);
+ array = rand.Numeric<TypeParam>(length, min + 100, min + 200, null_probability);
+ ValidateMean<TypeParam>(*array, options);
+ }
+}
+
+TEST(TestDecimalMeanKernel, SimpleMean) {
+ ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
+
+ for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ // Decimal doesn't have NaN
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, R"([])"), options),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, R"([null])"), options),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, R"([])")),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, R"([null])")),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, R"(["1.01", null, "1.01"])")),
+ ResultWith(ScalarFromJSON(ty, R"("1.01")")));
+ EXPECT_THAT(
+ Mean(ArrayFromJSON(
+ ty, R"(["1.01", "2.02", "3.03", "4.04", "5.05", "6.06", "7.07", "8.08"])")),
+ ResultWith(ScalarFromJSON(ty, R"("4.54")")));
+ EXPECT_THAT(
+ Mean(ArrayFromJSON(
+ ty, R"(["0.00", "0.00", "0.00", "0.00", "0.00", "0.00", "0.00", "0.00"])")),
+ ResultWith(ScalarFromJSON(ty, R"("0.00")")));
+ EXPECT_THAT(
+ Mean(ArrayFromJSON(
+ ty, R"(["1.01", "1.01", "1.01", "1.01", "1.01", "1.01", "1.01", "1.01"])")),
+ ResultWith(ScalarFromJSON(ty, R"("1.01")")));
+
+ EXPECT_THAT(Mean(ScalarFromJSON(ty, R"("5.05")")),
+ ResultWith(ScalarFromJSON(ty, R"("5.05")")));
+ EXPECT_THAT(Mean(ScalarFromJSON(ty, R"(null)")),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+ }
+}
+
+TEST(TestDecimalMeanKernel, ScalarAggregateOptions) {
+ for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ Datum result = ScalarFromJSON(ty, R"("3.03")");
+ Datum null = ScalarFromJSON(ty, R"(null)");
+ Datum arr = ArrayFromJSON(ty, R"(["1.01", null, "2.02", "2.02", null, "7.07"])");
+
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ null);
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[null]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ null);
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)),
+ null);
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[null]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)),
+ null);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ result);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3)),
+ result);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/4)),
+ result);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/5)),
+ null);
+
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[]"),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ null);
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[null]"),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ null);
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[]"),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1)),
+ null);
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[null]"),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1)),
+ null);
+
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ null);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)),
+ null);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)),
+ null);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)),
+ null);
+
+ arr = ArrayFromJSON(ty, R"(["1.01", "2.02", "2.02", "7.07"])");
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ result);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)),
+ ResultWith(result));
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)),
+ ResultWith(result));
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)),
+ ResultWith(null));
+
+ EXPECT_THAT(Mean(ScalarFromJSON(ty, R"("5.05")"),
+ ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(ScalarFromJSON(ty, R"("5.05")")));
+ EXPECT_THAT(Mean(ScalarFromJSON(ty, R"("5.05")"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2)),
+ ResultWith(null));
+ EXPECT_THAT(Mean(null, ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(null));
+ }
+}
+
+//
+// Min / Max
+//
+
+template <typename ArrowType>
+class TestPrimitiveMinMaxKernel : public ::testing::Test {
+ using Traits = TypeTraits<ArrowType>;
+ using ArrayType = typename Traits::ArrayType;
+ using c_type = typename ArrowType::c_type;
+ using ScalarType = typename Traits::ScalarType;
+
+ public:
+ void AssertMinMaxIs(const Datum& array, c_type expected_min, c_type expected_max,
+ const ScalarAggregateOptions& options) {
+ ASSERT_OK_AND_ASSIGN(Datum out, MinMax(array, options));
+ const StructScalar& value = out.scalar_as<StructScalar>();
+
+ {
+ const auto& out_min = checked_cast<const ScalarType&>(*value.value[0]);
+ ASSERT_EQ(expected_min, out_min.value);
+
+ const auto& out_max = checked_cast<const ScalarType&>(*value.value[1]);
+ ASSERT_EQ(expected_max, out_max.value);
+ }
+
+ {
+ ASSERT_OK_AND_ASSIGN(out, CallFunction("min", {array}, &options));
+ const auto& out_min = out.scalar_as<ScalarType>();
+ ASSERT_EQ(expected_min, out_min.value);
+
+ ASSERT_OK_AND_ASSIGN(out, CallFunction("max", {array}, &options));
+ const auto& out_max = out.scalar_as<ScalarType>();
+ ASSERT_EQ(expected_max, out_max.value);
+ }
+ }
+
+ void AssertMinMaxIs(const std::string& json, c_type expected_min, c_type expected_max,
+ const ScalarAggregateOptions& options) {
+ auto array = ArrayFromJSON(type_singleton(), json);
+ AssertMinMaxIs(array, expected_min, expected_max, options);
+ }
+
+ void AssertMinMaxIs(const std::vector<std::string>& json, c_type expected_min,
+ c_type expected_max, const ScalarAggregateOptions& options) {
+ auto array = ChunkedArrayFromJSON(type_singleton(), json);
+ AssertMinMaxIs(array, expected_min, expected_max, options);
+ }
+
+ void AssertMinMaxIsNull(const Datum& array, const ScalarAggregateOptions& options) {
+ ASSERT_OK_AND_ASSIGN(Datum out, MinMax(array, options));
+ for (const auto& val : out.scalar_as<StructScalar>().value) {
+ ASSERT_FALSE(val->is_valid);
+ }
+ }
+
+ void AssertMinMaxIsNull(const std::string& json,
+ const ScalarAggregateOptions& options) {
+ auto array = ArrayFromJSON(type_singleton(), json);
+ AssertMinMaxIsNull(array, options);
+ }
+
+ void AssertMinMaxIsNull(const std::vector<std::string>& json,
+ const ScalarAggregateOptions& options) {
+ auto array = ChunkedArrayFromJSON(type_singleton(), json);
+ AssertMinMaxIsNull(array, options);
+ }
+
+ std::shared_ptr<DataType> type_singleton() {
+ return default_type_instance<ArrowType>();
+ }
+};
+
+template <typename ArrowType>
+class TestIntegerMinMaxKernel : public TestPrimitiveMinMaxKernel<ArrowType> {};
+
+template <typename ArrowType>
+class TestFloatingMinMaxKernel : public TestPrimitiveMinMaxKernel<ArrowType> {};
+
+class TestBooleanMinMaxKernel : public TestPrimitiveMinMaxKernel<BooleanType> {};
+class TestDayTimeIntervalMinMaxKernel
+ : public TestPrimitiveMinMaxKernel<DayTimeIntervalType> {};
+
+TEST_F(TestBooleanMinMaxKernel, Basics) {
+ ScalarAggregateOptions options;
+ std::vector<std::string> chunked_input0 = {"[]", "[]"};
+ std::vector<std::string> chunked_input1 = {"[true, true, null]", "[true, null]"};
+ std::vector<std::string> chunked_input2 = {"[false, false, false]", "[false]"};
+ std::vector<std::string> chunked_input3 = {"[true, null]", "[null, false]"};
+ auto ty = struct_({field("min", boolean()), field("max", boolean())});
+
+ // SKIP nulls by default
+ options = ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1);
+ this->AssertMinMaxIsNull("[]", options);
+ this->AssertMinMaxIsNull("[null, null, null]", options);
+ this->AssertMinMaxIs("[false, false, false]", false, false, options);
+ this->AssertMinMaxIs("[false, false, false, null]", false, false, options);
+ this->AssertMinMaxIs("[true, null, true, true]", true, true, options);
+ this->AssertMinMaxIs("[true, null, true, true]", true, true, options);
+ this->AssertMinMaxIs("[true, null, false, true]", false, true, options);
+ this->AssertMinMaxIsNull(chunked_input0, options);
+ this->AssertMinMaxIs(chunked_input1, true, true, options);
+ this->AssertMinMaxIs(chunked_input2, false, false, options);
+ this->AssertMinMaxIs(chunked_input3, false, true, options);
+
+ Datum null_min_max = ScalarFromJSON(ty, "[null, null]");
+ Datum true_min_max = ScalarFromJSON(ty, "[true, true]");
+ Datum false_min_max = ScalarFromJSON(ty, "[false, false]");
+ EXPECT_THAT(MinMax(MakeNullScalar(boolean())), ResultWith(null_min_max));
+ EXPECT_THAT(MinMax(MakeScalar(true)), ResultWith(true_min_max));
+ EXPECT_THAT(MinMax(MakeScalar(false)), ResultWith(false_min_max));
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1);
+ this->AssertMinMaxIsNull("[]", options);
+ this->AssertMinMaxIsNull("[null, null, null]", options);
+ this->AssertMinMaxIsNull("[false, null, false]", options);
+ this->AssertMinMaxIsNull("[true, null]", options);
+ this->AssertMinMaxIs("[true, true, true]", true, true, options);
+ this->AssertMinMaxIs("[false, false]", false, false, options);
+ this->AssertMinMaxIs("[false, true]", false, true, options);
+ this->AssertMinMaxIsNull(chunked_input0, options);
+ this->AssertMinMaxIsNull(chunked_input1, options);
+ this->AssertMinMaxIs(chunked_input2, false, false, options);
+ this->AssertMinMaxIsNull(chunked_input3, options);
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/2);
+ EXPECT_THAT(MinMax(MakeNullScalar(boolean()), options), ResultWith(null_min_max));
+ EXPECT_THAT(MinMax(MakeScalar(true), options), ResultWith(null_min_max));
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0);
+ this->AssertMinMaxIsNull("[]", options);
+ this->AssertMinMaxIsNull("[null]", options);
+}
+
+TYPED_TEST_SUITE(TestIntegerMinMaxKernel, PhysicalIntegralArrowTypes);
+TYPED_TEST(TestIntegerMinMaxKernel, Basics) {
+ ScalarAggregateOptions options;
+ std::vector<std::string> chunked_input1 = {"[5, 1, 2, 3, 4]", "[9, 1, null, 3, 4]"};
+ std::vector<std::string> chunked_input2 = {"[5, null, 2, 3, 4]", "[9, 1, 2, 3, 4]"};
+ std::vector<std::string> chunked_input3 = {"[5, 1, 2, 3, null]", "[9, 1, null, 3, 4]"};
+ auto item_ty = default_type_instance<TypeParam>();
+ auto ty = struct_({field("min", item_ty), field("max", item_ty)});
+
+ // SKIP nulls by default
+ this->AssertMinMaxIsNull("[]", options);
+ this->AssertMinMaxIsNull("[null, null, null]", options);
+ this->AssertMinMaxIs("[5, 1, 2, 3, 4]", 1, 5, options);
+ this->AssertMinMaxIs("[5, null, 2, 3, 4]", 2, 5, options);
+ this->AssertMinMaxIs(chunked_input1, 1, 9, options);
+ this->AssertMinMaxIs(chunked_input2, 1, 9, options);
+ this->AssertMinMaxIs(chunked_input3, 1, 9, options);
+
+ Datum null_min_max(std::make_shared<StructScalar>(
+ ScalarVector{MakeNullScalar(item_ty), MakeNullScalar(item_ty)}, ty));
+ auto one_scalar = *MakeScalar(item_ty, static_cast<typename TypeParam::c_type>(1));
+ Datum one_min_max(
+ std::make_shared<StructScalar>(ScalarVector{one_scalar, one_scalar}, ty));
+ EXPECT_THAT(MinMax(MakeNullScalar(item_ty)), ResultWith(null_min_max));
+ EXPECT_THAT(MinMax(one_scalar), ResultWith(one_min_max));
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1);
+ this->AssertMinMaxIs("[5, 1, 2, 3, 4]", 1, 5, options);
+ this->AssertMinMaxIsNull("[5, null, 2, 3, 4]", options);
+ this->AssertMinMaxIsNull(chunked_input1, options);
+ this->AssertMinMaxIsNull(chunked_input2, options);
+ this->AssertMinMaxIsNull(chunked_input3, options);
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/5);
+ this->AssertMinMaxIs("[5, 1, 2, 3, 4]", 1, 5, options);
+ this->AssertMinMaxIsNull("[5, null, 2, 3, 4]", options);
+ this->AssertMinMaxIs(chunked_input1, 1, 9, options);
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5);
+ this->AssertMinMaxIs("[5, 1, 2, 3, 4]", 1, 5, options);
+ this->AssertMinMaxIsNull("[5, null, 2, 3, 4]", options);
+ this->AssertMinMaxIsNull(chunked_input1, options);
+}
+
+TYPED_TEST_SUITE(TestFloatingMinMaxKernel, RealArrowTypes);
+TYPED_TEST(TestFloatingMinMaxKernel, Floats) {
+ ScalarAggregateOptions options;
+ std::vector<std::string> chunked_input1 = {"[5, 1, 2, 3, 4]", "[9, 1, null, 3, 4]"};
+ std::vector<std::string> chunked_input2 = {"[5, null, 2, 3, 4]", "[9, 1, 2, 3, 4]"};
+ std::vector<std::string> chunked_input3 = {"[5, 1, 2, 3, null]", "[9, 1, null, 3, 4]"};
+ auto item_ty = TypeTraits<TypeParam>::type_singleton();
+ auto ty = struct_({field("min", item_ty), field("max", item_ty)});
+
+ this->AssertMinMaxIs("[5, 1, 2, 3, 4]", 1, 5, options);
+ this->AssertMinMaxIs("[5, null, 2, 3, 4]", 2, 5, options);
+ this->AssertMinMaxIs("[5, Inf, 2, 3, 4]", 2.0, INFINITY, options);
+ this->AssertMinMaxIs("[5, NaN, 2, 3, 4]", 2, 5, options);
+ this->AssertMinMaxIs("[5, -Inf, 2, 3, 4]", -INFINITY, 5, options);
+ this->AssertMinMaxIs(chunked_input1, 1, 9, options);
+ this->AssertMinMaxIs(chunked_input2, 1, 9, options);
+ this->AssertMinMaxIs(chunked_input3, 1, 9, options);
+
+ Datum null_min_max(std::make_shared<StructScalar>(
+ ScalarVector{MakeNullScalar(item_ty), MakeNullScalar(item_ty)}, ty));
+ auto one_scalar = *MakeScalar(item_ty, static_cast<typename TypeParam::c_type>(1));
+ Datum one_min_max(
+ std::make_shared<StructScalar>(ScalarVector{one_scalar, one_scalar}, ty));
+ EXPECT_THAT(MinMax(MakeNullScalar(item_ty)), ResultWith(null_min_max));
+ EXPECT_THAT(MinMax(one_scalar), ResultWith(one_min_max));
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/false);
+ this->AssertMinMaxIs("[5, 1, 2, 3, 4]", 1, 5, options);
+ this->AssertMinMaxIs("[5, -Inf, 2, 3, 4]", -INFINITY, 5, options);
+ this->AssertMinMaxIsNull("[5, null, 2, 3, 4]", options);
+ this->AssertMinMaxIsNull("[5, -Inf, null, 3, 4]", options);
+ this->AssertMinMaxIsNull(chunked_input1, options);
+ this->AssertMinMaxIsNull(chunked_input2, options);
+ this->AssertMinMaxIsNull(chunked_input3, options);
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0);
+ this->AssertMinMaxIsNull("[]", options);
+ this->AssertMinMaxIsNull("[null]", options);
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1);
+ this->AssertMinMaxIsNull("[]", options);
+ this->AssertMinMaxIsNull("[null]", options);
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/5);
+ this->AssertMinMaxIs("[5, 1, 2, 3, 4]", 1, 5, options);
+ this->AssertMinMaxIsNull("[5, null, 2, 3, 4]", options);
+ this->AssertMinMaxIs(chunked_input1, 1, 9, options);
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5);
+ this->AssertMinMaxIs("[5, 1, 2, 3, 4]", 1, 5, options);
+ this->AssertMinMaxIsNull("[5, null, 2, 3, 4]", options);
+ this->AssertMinMaxIsNull(chunked_input1, options);
+}
+
+TYPED_TEST(TestFloatingMinMaxKernel, DefaultOptions) {
+ auto values = ArrayFromJSON(this->type_singleton(), "[0, 1, 2, 3, 4]");
+
+ ASSERT_OK_AND_ASSIGN(auto no_options_provided, CallFunction("min_max", {values}));
+
+ auto default_options = ScalarAggregateOptions::Defaults();
+ ASSERT_OK_AND_ASSIGN(auto explicit_defaults,
+ CallFunction("min_max", {values}, &default_options));
+
+ AssertDatumsEqual(explicit_defaults, no_options_provided);
+}
+
+TEST(TestDecimalMinMaxKernel, Decimals) {
+ for (const auto& item_ty : {decimal128(5, 2), decimal256(5, 2)}) {
+ auto ty = struct_({field("min", item_ty), field("max", item_ty)});
+
+ Datum chunked_input1 =
+ ChunkedArrayFromJSON(item_ty, {R"(["5.10", "1.23", "2.00", "3.45", "4.56"])",
+ R"(["9.42", "1.01", null, "3.14", "4.00"])"});
+ Datum chunked_input2 =
+ ChunkedArrayFromJSON(item_ty, {R"(["5.10", null, "2.00", "3.45", "4.56"])",
+ R"(["9.42", "1.01", "2.52", "3.14", "4.00"])"});
+ Datum chunked_input3 =
+ ChunkedArrayFromJSON(item_ty, {R"(["5.10", "1.23", "2.00", "3.45", null])",
+ R"(["9.42", "1.01", null, "3.14", "4.00"])"});
+
+ ScalarAggregateOptions options;
+
+ EXPECT_THAT(
+ MinMax(ArrayFromJSON(item_ty, R"(["5.10", "-1.23", "2.00", "3.45", "4.56"])"),
+ options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": "-1.23", "max": "5.10"})")));
+ EXPECT_THAT(
+ MinMax(ArrayFromJSON(item_ty, R"(["-5.10", "-1.23", "-2.00", "-3.45", "-4.56"])"),
+ options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": "-5.10", "max": "-1.23"})")));
+ EXPECT_THAT(
+ MinMax(ArrayFromJSON(item_ty, R"(["5.10", null, "2.00", "3.45", "4.56"])"),
+ options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": "2.00", "max": "5.10"})")));
+
+ EXPECT_THAT(MinMax(chunked_input1, options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": "1.01", "max": "9.42"})")));
+ EXPECT_THAT(MinMax(chunked_input2, options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": "1.01", "max": "9.42"})")));
+ EXPECT_THAT(MinMax(chunked_input3, options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": "1.01", "max": "9.42"})")));
+
+ EXPECT_THAT(CallFunction("min", {chunked_input1}, &options),
+ ResultWith(ScalarFromJSON(item_ty, R"("1.01")")));
+ EXPECT_THAT(CallFunction("max", {chunked_input1}, &options),
+ ResultWith(ScalarFromJSON(item_ty, R"("9.42")")));
+
+ EXPECT_THAT(MinMax(ScalarFromJSON(item_ty, "null"), options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": null, "max": null})")));
+ EXPECT_THAT(MinMax(ScalarFromJSON(item_ty, R"("1.00")"), options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": "1.00", "max": "1.00"})")));
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/false);
+ EXPECT_THAT(
+ MinMax(ArrayFromJSON(item_ty, R"(["5.10", "-1.23", "2.00", "3.45", "4.56"])"),
+ options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": "-1.23", "max": "5.10"})")));
+ EXPECT_THAT(
+ MinMax(ArrayFromJSON(item_ty, R"(["5.10", null, "2.00", "3.45", "4.56"])"),
+ options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": null, "max": null})")));
+ EXPECT_THAT(MinMax(chunked_input1, options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": null, "max": null})")));
+ EXPECT_THAT(MinMax(chunked_input2, options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": null, "max": null})")));
+ EXPECT_THAT(MinMax(chunked_input3, options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": null, "max": null})")));
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0);
+ EXPECT_THAT(MinMax(ArrayFromJSON(item_ty, R"([])"), options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": null, "max": null})")));
+ EXPECT_THAT(MinMax(ArrayFromJSON(item_ty, R"([null])"), options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": null, "max": null})")));
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/5);
+ EXPECT_THAT(MinMax(ArrayFromJSON(
+ item_ty, R"(["5.10", "-1.23", "2.00", "3.45", "4.56", null])"),
+ options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": "-1.23", "max": "5.10"})")));
+ EXPECT_THAT(
+ MinMax(ArrayFromJSON(item_ty, R"(["5.10", "-1.23", "2.00", "3.45", null])"),
+ options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": null, "max": null})")));
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1);
+ EXPECT_THAT(MinMax(ArrayFromJSON(item_ty, R"([])"), options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": null, "max": null})")));
+ EXPECT_THAT(MinMax(ArrayFromJSON(item_ty, R"([null])"), options),
+ ResultWith(ScalarFromJSON(ty, R"({"min": null, "max": null})")));
+ }
+}
+
+TEST(TestNullMinMaxKernel, Basics) {
+ auto item_ty = null();
+ auto ty = struct_({field("min", item_ty), field("max", item_ty)});
+ Datum result = ScalarFromJSON(ty, "[null, null]");
+ EXPECT_THAT(MinMax(ScalarFromJSON(item_ty, "null")), ResultWith(result));
+ EXPECT_THAT(MinMax(ArrayFromJSON(item_ty, "[]")), ResultWith(result));
+ EXPECT_THAT(MinMax(ArrayFromJSON(item_ty, "[null]")), ResultWith(result));
+ EXPECT_THAT(MinMax(ChunkedArrayFromJSON(item_ty, {"[null]", "[]", "[null, null]"})),
+ ResultWith(result));
+}
+
+template <typename ArrowType>
+class TestBaseBinaryMinMaxKernel : public ::testing::Test {};
+TYPED_TEST_SUITE(TestBaseBinaryMinMaxKernel, BinaryArrowTypes);
+TYPED_TEST(TestBaseBinaryMinMaxKernel, Basics) {
+ std::vector<std::string> chunked_input1 = {R"(["cc", "", "aa", "b", "c"])",
+ R"(["d", "", null, "b", "c"])"};
+ std::vector<std::string> chunked_input2 = {R"(["cc", null, "aa", "b", "c"])",
+ R"(["d", "", "aa", "b", "c"])"};
+ std::vector<std::string> chunked_input3 = {R"(["cc", "", "aa", "b", null])",
+ R"(["d", "", null, "b", "c"])"};
+ auto ty = std::make_shared<TypeParam>();
+ auto res_ty = struct_({field("min", ty), field("max", ty)});
+ Datum null = ScalarFromJSON(res_ty, R"([null, null])");
+
+ // SKIP nulls by default
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, R"([])")), ResultWith(null));
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, R"([null, null, null])")), ResultWith(null));
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input1[0])),
+ ResultWith(ScalarFromJSON(res_ty, R"(["", "cc"])")));
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input2[0])),
+ ResultWith(ScalarFromJSON(res_ty, R"(["aa", "cc"])")));
+ EXPECT_THAT(MinMax(ChunkedArrayFromJSON(ty, chunked_input1)),
+ ResultWith(ScalarFromJSON(res_ty, R"(["", "d"])")));
+ EXPECT_THAT(MinMax(ChunkedArrayFromJSON(ty, chunked_input2)),
+ ResultWith(ScalarFromJSON(res_ty, R"(["", "d"])")));
+ EXPECT_THAT(MinMax(ChunkedArrayFromJSON(ty, chunked_input3)),
+ ResultWith(ScalarFromJSON(res_ty, R"(["", "d"])")));
+
+ EXPECT_THAT(MinMax(MakeNullScalar(ty)), ResultWith(null));
+ EXPECT_THAT(MinMax(ScalarFromJSON(ty, R"("one")")),
+ ResultWith(ScalarFromJSON(res_ty, R"(["one", "one"])")));
+
+ ScalarAggregateOptions options(/*skip_nulls=*/false);
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input1[0]), options),
+ ResultWith(ScalarFromJSON(res_ty, R"(["", "cc"])")));
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input2[0]), options), ResultWith(null));
+ EXPECT_THAT(MinMax(ChunkedArrayFromJSON(ty, chunked_input1), options),
+ ResultWith(null));
+ EXPECT_THAT(MinMax(MakeNullScalar(ty), options), ResultWith(null));
+ EXPECT_THAT(MinMax(ScalarFromJSON(ty, R"("one")"), options),
+ ResultWith(ScalarFromJSON(res_ty, R"(["one", "one"])")));
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/9);
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input1[0]), options), ResultWith(null));
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input2[0]), options), ResultWith(null));
+ EXPECT_THAT(MinMax(ChunkedArrayFromJSON(ty, chunked_input1), options),
+ ResultWith(ScalarFromJSON(res_ty, R"(["", "d"])")));
+ EXPECT_THAT(MinMax(MakeNullScalar(ty), options), ResultWith(null));
+ EXPECT_THAT(MinMax(ScalarFromJSON(ty, R"("one")"), options), ResultWith(null));
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4);
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input1[0]), options),
+ ResultWith(ScalarFromJSON(res_ty, R"(["", "cc"])")));
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input2[0]), options), ResultWith(null));
+ EXPECT_THAT(MinMax(ChunkedArrayFromJSON(ty, chunked_input1), options),
+ ResultWith(null));
+ EXPECT_THAT(MinMax(MakeNullScalar(ty), options), ResultWith(null));
+ EXPECT_THAT(MinMax(ScalarFromJSON(ty, R"("one")"), options), ResultWith(null));
+}
+
+TEST(TestFixedSizeBinaryMinMaxKernel, Basics) {
+ auto ty = fixed_size_binary(2);
+ std::vector<std::string> chunked_input1 = {R"(["cd", "aa", "ab", "bb", "cc"])",
+ R"(["da", "aa", null, "bb", "cc"])"};
+ std::vector<std::string> chunked_input2 = {R"(["cd", null, "ab", "bb", "cc"])",
+ R"(["da", "aa", "ab", "bb", "cc"])"};
+ std::vector<std::string> chunked_input3 = {R"(["cd", "aa", "ab", "bb", null])",
+ R"(["da", "aa", null, "bb", "cc"])"};
+ auto res_ty = struct_({field("min", ty), field("max", ty)});
+ Datum null = ScalarFromJSON(res_ty, R"([null, null])");
+
+ // SKIP nulls by default
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, R"([])")), ResultWith(null));
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, R"([null, null, null])")), ResultWith(null));
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input1[0])),
+ ResultWith(ScalarFromJSON(res_ty, R"(["aa", "cd"])")));
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input2[0])),
+ ResultWith(ScalarFromJSON(res_ty, R"(["ab", "cd"])")));
+ EXPECT_THAT(MinMax(ChunkedArrayFromJSON(ty, chunked_input1)),
+ ResultWith(ScalarFromJSON(res_ty, R"(["aa", "da"])")));
+ EXPECT_THAT(MinMax(ChunkedArrayFromJSON(ty, chunked_input2)),
+ ResultWith(ScalarFromJSON(res_ty, R"(["aa", "da"])")));
+ EXPECT_THAT(MinMax(ChunkedArrayFromJSON(ty, chunked_input3)),
+ ResultWith(ScalarFromJSON(res_ty, R"(["aa", "da"])")));
+
+ EXPECT_THAT(MinMax(MakeNullScalar(ty)), ResultWith(null));
+ EXPECT_THAT(MinMax(ScalarFromJSON(ty, R"("aa")")),
+ ResultWith(ScalarFromJSON(res_ty, R"(["aa", "aa"])")));
+
+ ScalarAggregateOptions options(/*skip_nulls=*/false);
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input1[0]), options),
+ ResultWith(ScalarFromJSON(res_ty, R"(["aa", "cd"])")));
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input2[0]), options), ResultWith(null));
+ EXPECT_THAT(MinMax(ChunkedArrayFromJSON(ty, chunked_input1), options),
+ ResultWith(null));
+ EXPECT_THAT(MinMax(MakeNullScalar(ty), options), ResultWith(null));
+ EXPECT_THAT(MinMax(ScalarFromJSON(ty, R"("aa")"), options),
+ ResultWith(ScalarFromJSON(res_ty, R"(["aa", "aa"])")));
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/9);
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input1[0]), options), ResultWith(null));
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input2[0]), options), ResultWith(null));
+ EXPECT_THAT(MinMax(ChunkedArrayFromJSON(ty, chunked_input1), options),
+ ResultWith(ScalarFromJSON(res_ty, R"(["aa", "da"])")));
+ EXPECT_THAT(MinMax(MakeNullScalar(ty), options), ResultWith(null));
+ EXPECT_THAT(MinMax(ScalarFromJSON(ty, R"("aa")"), options), ResultWith(null));
+
+ options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4);
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input1[0]), options),
+ ResultWith(ScalarFromJSON(res_ty, R"(["aa", "cd"])")));
+ EXPECT_THAT(MinMax(ArrayFromJSON(ty, chunked_input2[0]), options), ResultWith(null));
+ EXPECT_THAT(MinMax(ChunkedArrayFromJSON(ty, chunked_input1), options),
+ ResultWith(null));
+ EXPECT_THAT(MinMax(MakeNullScalar(ty), options), ResultWith(null));
+ EXPECT_THAT(MinMax(ScalarFromJSON(ty, R"("aa")"), options), ResultWith(null));
+}
+
+template <typename ArrowType>
+struct MinMaxResult {
+ using T = typename ArrowType::c_type;
+
+ T min = 0;
+ T max = 0;
+ bool is_valid = false;
+};
+
+template <typename ArrowType>
+static enable_if_integer<ArrowType, MinMaxResult<ArrowType>> NaiveMinMax(
+ const Array& array) {
+ using T = typename ArrowType::c_type;
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+ MinMaxResult<ArrowType> result;
+
+ const auto& array_numeric = reinterpret_cast<const ArrayType&>(array);
+ const auto values = array_numeric.raw_values();
+
+ if (array.length() <= array.null_count()) { // All null values
+ return result;
+ }
+
+ T min = std::numeric_limits<T>::max();
+ T max = std::numeric_limits<T>::min();
+ if (array.null_count() != 0) { // Some values are null
+ BitmapReader reader(array.null_bitmap_data(), array.offset(), array.length());
+ for (int64_t i = 0; i < array.length(); i++) {
+ if (reader.IsSet()) {
+ min = std::min(min, values[i]);
+ max = std::max(max, values[i]);
+ }
+ reader.Next();
+ }
+ } else { // All true values
+ for (int64_t i = 0; i < array.length(); i++) {
+ min = std::min(min, values[i]);
+ max = std::max(max, values[i]);
+ }
+ }
+
+ result.min = min;
+ result.max = max;
+ result.is_valid = true;
+ return result;
+}
+
+template <typename ArrowType>
+static enable_if_floating_point<ArrowType, MinMaxResult<ArrowType>> NaiveMinMax(
+ const Array& array) {
+ using T = typename ArrowType::c_type;
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+ MinMaxResult<ArrowType> result;
+
+ const auto& array_numeric = reinterpret_cast<const ArrayType&>(array);
+ const auto values = array_numeric.raw_values();
+
+ if (array.length() <= array.null_count()) { // All null values
+ return result;
+ }
+
+ T min = std::numeric_limits<T>::infinity();
+ T max = -std::numeric_limits<T>::infinity();
+ if (array.null_count() != 0) { // Some values are null
+ BitmapReader reader(array.null_bitmap_data(), array.offset(), array.length());
+ for (int64_t i = 0; i < array.length(); i++) {
+ if (reader.IsSet()) {
+ min = std::fmin(min, values[i]);
+ max = std::fmax(max, values[i]);
+ }
+ reader.Next();
+ }
+ } else { // All true values
+ for (int64_t i = 0; i < array.length(); i++) {
+ min = std::fmin(min, values[i]);
+ max = std::fmax(max, values[i]);
+ }
+ }
+
+ result.min = min;
+ result.max = max;
+ result.is_valid = true;
+ return result;
+}
+
+template <typename ArrowType>
+void ValidateMinMax(const Array& array, const ScalarAggregateOptions& options) {
+ using Traits = TypeTraits<ArrowType>;
+ using ScalarType = typename Traits::ScalarType;
+
+ ASSERT_OK_AND_ASSIGN(Datum out, MinMax(array, options));
+ const StructScalar& value = out.scalar_as<StructScalar>();
+
+ auto expected = NaiveMinMax<ArrowType>(array);
+ const auto& out_min = checked_cast<const ScalarType&>(*value.value[0]);
+ const auto& out_max = checked_cast<const ScalarType&>(*value.value[1]);
+
+ if (expected.is_valid) {
+ ASSERT_TRUE(out_min.is_valid);
+ ASSERT_TRUE(out_max.is_valid);
+ ASSERT_EQ(expected.min, out_min.value);
+ ASSERT_EQ(expected.max, out_max.value);
+ } else { // All null values
+ ASSERT_FALSE(out_min.is_valid);
+ ASSERT_FALSE(out_max.is_valid);
+ }
+}
+
+template <typename ArrowType>
+class TestRandomNumericMinMaxKernel : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestRandomNumericMinMaxKernel, NumericArrowTypes);
+TYPED_TEST(TestRandomNumericMinMaxKernel, RandomArrayMinMax) {
+ auto rand = random::RandomArrayGenerator(0x8afc055);
+ const ScalarAggregateOptions& options =
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1);
+ // Test size up to 1<<11 (2048).
+ for (size_t i = 3; i < 12; i += 2) {
+ for (auto null_probability : {0.0, 0.01, 0.1, 0.5, 0.99, 1.0}) {
+ int64_t base_length = (1UL << i) + 2;
+ auto array = rand.Numeric<TypeParam>(base_length, 0, 100, null_probability);
+ for (auto length_adjust : {-2, -1, 0, 1, 2}) {
+ int64_t length = (1UL << i) + length_adjust;
+ ValidateMinMax<TypeParam>(*array->Slice(0, length), options);
+ }
+ }
+ }
+}
+
+//
+// Any
+//
+
+class TestAnyKernel : public ::testing::Test {
+ public:
+ void AssertAnyIs(const Datum& array, const std::shared_ptr<BooleanScalar>& expected,
+ const ScalarAggregateOptions& options) {
+ SCOPED_TRACE(options.ToString());
+ ASSERT_OK_AND_ASSIGN(Datum out, Any(array, options, nullptr));
+ const BooleanScalar& out_any = out.scalar_as<BooleanScalar>();
+ AssertScalarsEqual(*expected, out_any, /*verbose=*/true);
+ }
+
+ void AssertAnyIs(
+ const std::string& json, const std::shared_ptr<BooleanScalar>& expected,
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) {
+ SCOPED_TRACE(json);
+ auto array = ArrayFromJSON(boolean(), json);
+ AssertAnyIs(array, expected, options);
+ }
+
+ void AssertAnyIs(
+ const std::vector<std::string>& json,
+ const std::shared_ptr<BooleanScalar>& expected,
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) {
+ auto array = ChunkedArrayFromJSON(boolean(), json);
+ AssertAnyIs(array, expected, options);
+ }
+};
+
+TEST_F(TestAnyKernel, Basics) {
+ auto true_value = std::make_shared<BooleanScalar>(true);
+ auto false_value = std::make_shared<BooleanScalar>(false);
+ auto null_value = std::make_shared<BooleanScalar>();
+ null_value->is_valid = false;
+
+ std::vector<std::string> chunked_input0 = {"[]", "[true]"};
+ std::vector<std::string> chunked_input1 = {"[true, true, null]", "[true, null]"};
+ std::vector<std::string> chunked_input2 = {"[false, false, false]", "[false]"};
+ std::vector<std::string> chunked_input3 = {"[false, null]", "[null, false]"};
+ std::vector<std::string> chunked_input4 = {"[true, null]", "[null, false]"};
+
+ const ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
+ this->AssertAnyIs("[]", false_value, options);
+ this->AssertAnyIs("[false]", false_value, options);
+ this->AssertAnyIs("[true, false]", true_value, options);
+ this->AssertAnyIs("[null, null, null]", false_value, options);
+ this->AssertAnyIs("[false, false, false]", false_value, options);
+ this->AssertAnyIs("[false, false, false, null]", false_value, options);
+ this->AssertAnyIs("[true, null, true, true]", true_value, options);
+ this->AssertAnyIs("[false, null, false, true]", true_value, options);
+ this->AssertAnyIs("[true, null, false, true]", true_value, options);
+ this->AssertAnyIs(chunked_input0, true_value, options);
+ this->AssertAnyIs(chunked_input1, true_value, options);
+ this->AssertAnyIs(chunked_input2, false_value, options);
+ this->AssertAnyIs(chunked_input3, false_value, options);
+ this->AssertAnyIs(chunked_input4, true_value, options);
+
+ EXPECT_THAT(Any(Datum(true), options), ResultWith(Datum(true)));
+ EXPECT_THAT(Any(Datum(false), options), ResultWith(Datum(false)));
+ EXPECT_THAT(Any(Datum(null_value), options), ResultWith(Datum(false)));
+
+ const ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false, /*min_count=*/0);
+ this->AssertAnyIs("[]", false_value, keep_nulls);
+ this->AssertAnyIs("[false]", false_value, keep_nulls);
+ this->AssertAnyIs("[true, false]", true_value, keep_nulls);
+ this->AssertAnyIs("[null, null, null]", null_value, keep_nulls);
+ this->AssertAnyIs("[false, false, false]", false_value, keep_nulls);
+ this->AssertAnyIs("[false, false, false, null]", null_value, keep_nulls);
+ this->AssertAnyIs("[true, null, true, true]", true_value, keep_nulls);
+ this->AssertAnyIs("[false, null, false, true]", true_value, keep_nulls);
+ this->AssertAnyIs("[true, null, false, true]", true_value, keep_nulls);
+ this->AssertAnyIs(chunked_input0, true_value, keep_nulls);
+ this->AssertAnyIs(chunked_input1, true_value, keep_nulls);
+ this->AssertAnyIs(chunked_input2, false_value, keep_nulls);
+ this->AssertAnyIs(chunked_input3, null_value, keep_nulls);
+ this->AssertAnyIs(chunked_input4, true_value, keep_nulls);
+
+ EXPECT_THAT(Any(Datum(true), keep_nulls), ResultWith(Datum(true)));
+ EXPECT_THAT(Any(Datum(false), keep_nulls), ResultWith(Datum(false)));
+ EXPECT_THAT(Any(Datum(null_value), keep_nulls), ResultWith(Datum(null_value)));
+
+ const ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/2);
+ this->AssertAnyIs("[]", null_value, min_count);
+ this->AssertAnyIs("[false]", null_value, min_count);
+ this->AssertAnyIs("[true, false]", true_value, min_count);
+ this->AssertAnyIs("[null, null, null]", null_value, min_count);
+ this->AssertAnyIs("[false, false, false]", false_value, min_count);
+ this->AssertAnyIs("[false, false, false, null]", false_value, min_count);
+ this->AssertAnyIs("[true, null, true, true]", true_value, min_count);
+ this->AssertAnyIs("[false, null, false, true]", true_value, min_count);
+ this->AssertAnyIs("[true, null, false, true]", true_value, min_count);
+ this->AssertAnyIs(chunked_input0, null_value, min_count);
+ this->AssertAnyIs(chunked_input1, true_value, min_count);
+ this->AssertAnyIs(chunked_input2, false_value, min_count);
+ this->AssertAnyIs(chunked_input3, false_value, min_count);
+ this->AssertAnyIs(chunked_input4, true_value, min_count);
+
+ EXPECT_THAT(Any(Datum(true), min_count), ResultWith(Datum(null_value)));
+ EXPECT_THAT(Any(Datum(false), min_count), ResultWith(Datum(null_value)));
+ EXPECT_THAT(Any(Datum(null_value), min_count), ResultWith(Datum(null_value)));
+}
+
+//
+// All
+//
+
+class TestAllKernel : public ::testing::Test {
+ public:
+ void AssertAllIs(const Datum& array, const std::shared_ptr<BooleanScalar>& expected,
+ const ScalarAggregateOptions& options) {
+ SCOPED_TRACE(options.ToString());
+ ASSERT_OK_AND_ASSIGN(Datum out, All(array, options, nullptr));
+ const BooleanScalar& out_all = out.scalar_as<BooleanScalar>();
+ AssertScalarsEqual(*expected, out_all, /*verbose=*/true);
+ }
+
+ void AssertAllIs(
+ const std::string& json, const std::shared_ptr<BooleanScalar>& expected,
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) {
+ SCOPED_TRACE(json);
+ auto array = ArrayFromJSON(boolean(), json);
+ AssertAllIs(array, expected, options);
+ }
+
+ void AssertAllIs(
+ const std::vector<std::string>& json,
+ const std::shared_ptr<BooleanScalar>& expected,
+ const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) {
+ auto array = ChunkedArrayFromJSON(boolean(), json);
+ AssertAllIs(array, expected, options);
+ }
+};
+
+TEST_F(TestAllKernel, Basics) {
+ auto true_value = std::make_shared<BooleanScalar>(true);
+ auto false_value = std::make_shared<BooleanScalar>(false);
+ auto null_value = std::make_shared<BooleanScalar>();
+ null_value->is_valid = false;
+
+ std::vector<std::string> chunked_input0 = {"[]", "[true]"};
+ std::vector<std::string> chunked_input1 = {"[true, true, null]", "[true, null]"};
+ std::vector<std::string> chunked_input2 = {"[false, false, false]", "[false]"};
+ std::vector<std::string> chunked_input3 = {"[false, null]", "[null, false]"};
+ std::vector<std::string> chunked_input4 = {"[true, null]", "[null, false]"};
+ std::vector<std::string> chunked_input5 = {"[false, null]", "[null, true]"};
+
+ const ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
+ this->AssertAllIs("[]", true_value, options);
+ this->AssertAllIs("[false]", false_value, options);
+ this->AssertAllIs("[true, false]", false_value, options);
+ this->AssertAllIs("[null, null, null]", true_value, options);
+ this->AssertAllIs("[false, false, false]", false_value, options);
+ this->AssertAllIs("[false, false, false, null]", false_value, options);
+ this->AssertAllIs("[true, null, true, true]", true_value, options);
+ this->AssertAllIs("[false, null, false, true]", false_value, options);
+ this->AssertAllIs("[true, null, false, true]", false_value, options);
+ this->AssertAllIs(chunked_input0, true_value, options);
+ this->AssertAllIs(chunked_input1, true_value, options);
+ this->AssertAllIs(chunked_input2, false_value, options);
+ this->AssertAllIs(chunked_input3, false_value, options);
+ this->AssertAllIs(chunked_input4, false_value, options);
+ this->AssertAllIs(chunked_input5, false_value, options);
+
+ EXPECT_THAT(All(Datum(true), options), ResultWith(Datum(true)));
+ EXPECT_THAT(All(Datum(false), options), ResultWith(Datum(false)));
+ EXPECT_THAT(All(Datum(null_value), options), ResultWith(Datum(true)));
+
+ const ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false, /*min_count=*/0);
+ this->AssertAllIs("[]", true_value, keep_nulls);
+ this->AssertAllIs("[false]", false_value, keep_nulls);
+ this->AssertAllIs("[true, false]", false_value, keep_nulls);
+ this->AssertAllIs("[null, null, null]", null_value, keep_nulls);
+ this->AssertAllIs("[false, false, false]", false_value, keep_nulls);
+ this->AssertAllIs("[false, false, false, null]", false_value, keep_nulls);
+ this->AssertAllIs("[true, null, true, true]", null_value, keep_nulls);
+ this->AssertAllIs("[false, null, false, true]", false_value, keep_nulls);
+ this->AssertAllIs("[true, null, false, true]", false_value, keep_nulls);
+ this->AssertAllIs(chunked_input0, true_value, keep_nulls);
+ this->AssertAllIs(chunked_input1, null_value, keep_nulls);
+ this->AssertAllIs(chunked_input2, false_value, keep_nulls);
+ this->AssertAllIs(chunked_input3, false_value, keep_nulls);
+ this->AssertAllIs(chunked_input4, false_value, keep_nulls);
+ this->AssertAllIs(chunked_input5, false_value, keep_nulls);
+
+ EXPECT_THAT(All(Datum(true), keep_nulls), ResultWith(Datum(true)));
+ EXPECT_THAT(All(Datum(false), keep_nulls), ResultWith(Datum(false)));
+ EXPECT_THAT(All(Datum(null_value), keep_nulls), ResultWith(Datum(null_value)));
+
+ const ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/2);
+ this->AssertAllIs("[]", null_value, min_count);
+ this->AssertAllIs("[false]", null_value, min_count);
+ this->AssertAllIs("[true, false]", false_value, min_count);
+ this->AssertAllIs("[null, null, null]", null_value, min_count);
+ this->AssertAllIs("[false, false, false]", false_value, min_count);
+ this->AssertAllIs("[false, false, false, null]", false_value, min_count);
+ this->AssertAllIs("[true, null, true, true]", true_value, min_count);
+ this->AssertAllIs("[false, null, false, true]", false_value, min_count);
+ this->AssertAllIs("[true, null, false, true]", false_value, min_count);
+ this->AssertAllIs(chunked_input0, null_value, min_count);
+ this->AssertAllIs(chunked_input1, true_value, min_count);
+ this->AssertAllIs(chunked_input2, false_value, min_count);
+ this->AssertAllIs(chunked_input3, false_value, min_count);
+ this->AssertAllIs(chunked_input4, false_value, min_count);
+ this->AssertAllIs(chunked_input5, false_value, min_count);
+
+ EXPECT_THAT(All(Datum(true), min_count), ResultWith(Datum(null_value)));
+ EXPECT_THAT(All(Datum(false), min_count), ResultWith(Datum(null_value)));
+ EXPECT_THAT(All(Datum(null_value), min_count), ResultWith(Datum(null_value)));
+}
+
+//
+// Index
+//
+
+template <typename ArrowType>
+class TestIndexKernel : public ::testing::Test {
+ public:
+ using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+ void AssertIndexIs(const Datum& array, const std::shared_ptr<ScalarType>& value,
+ int64_t expected) {
+ IndexOptions options(value);
+ ASSERT_OK_AND_ASSIGN(Datum out, Index(array, options));
+ const Int64Scalar& out_index = out.scalar_as<Int64Scalar>();
+ ASSERT_EQ(out_index.value, expected);
+ }
+
+ void AssertIndexIs(const std::string& json, const std::shared_ptr<ScalarType>& value,
+ int64_t expected) {
+ SCOPED_TRACE("Value: " + value->ToString());
+ SCOPED_TRACE("Input: " + json);
+ auto array = ArrayFromJSON(type_singleton(), json);
+ AssertIndexIs(array, value, expected);
+ }
+
+ void AssertIndexIs(const std::vector<std::string>& json,
+ const std::shared_ptr<ScalarType>& value, int64_t expected) {
+ SCOPED_TRACE("Value: " + value->ToString());
+ auto array = ChunkedArrayFromJSON(type_singleton(), json);
+ SCOPED_TRACE("Input: " + array->ToString());
+ AssertIndexIs(array, value, expected);
+ }
+
+ std::shared_ptr<DataType> type_singleton() { return std::make_shared<ArrowType>(); }
+};
+
+template <typename ArrowType>
+class TestNumericIndexKernel : public TestIndexKernel<ArrowType> {
+ public:
+ using CType = typename TypeTraits<ArrowType>::CType;
+};
+TYPED_TEST_SUITE(TestNumericIndexKernel, NumericArrowTypes);
+TYPED_TEST(TestNumericIndexKernel, Basics) {
+ std::vector<std::string> chunked_input0 = {"[]", "[0]"};
+ std::vector<std::string> chunked_input1 = {"[1, 0, null]", "[0, 0]"};
+ std::vector<std::string> chunked_input2 = {"[1, 1, 1]", "[1, 0]", "[0, 1]"};
+ std::vector<std::string> chunked_input3 = {"[1, 1, 1]", "[1, 1]"};
+ std::vector<std::string> chunked_input4 = {"[1, 1, 1]", "[1, 1]", "[0]"};
+
+ auto value = std::make_shared<typename TestFixture::ScalarType>(
+ static_cast<typename TestFixture::CType>(0));
+ auto null_value = std::make_shared<typename TestFixture::ScalarType>(
+ static_cast<typename TestFixture::CType>(0));
+ null_value->is_valid = false;
+
+ this->AssertIndexIs("[]", value, -1);
+ this->AssertIndexIs("[0]", value, 0);
+ this->AssertIndexIs("[1, 2, 3, 4]", value, -1);
+ this->AssertIndexIs("[1, 2, 3, 4, 0]", value, 4);
+ this->AssertIndexIs("[null, null, null]", value, -1);
+ this->AssertIndexIs("[null, null, null]", null_value, -1);
+ this->AssertIndexIs("[0, null, null]", null_value, -1);
+ this->AssertIndexIs(chunked_input0, value, 0);
+ this->AssertIndexIs(chunked_input1, value, 1);
+ this->AssertIndexIs(chunked_input2, value, 4);
+ this->AssertIndexIs(chunked_input3, value, -1);
+ this->AssertIndexIs(chunked_input4, value, 5);
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Must provide IndexOptions"),
+ CallFunction("index", {ArrayFromJSON(this->type_singleton(), "[0]")}));
+}
+TYPED_TEST(TestNumericIndexKernel, Random) {
+ constexpr auto kChunks = 4;
+ auto rand = random::RandomArrayGenerator(0x5487655);
+ auto value = std::make_shared<typename TestFixture::ScalarType>(
+ static_cast<typename TestFixture::CType>(0));
+
+ // Test chunked array sizes from 32 to 2048
+ for (size_t i = 3; i <= 9; i += 2) {
+ const int64_t chunk_length = static_cast<int64_t>(1) << i;
+ ArrayVector chunks;
+ for (int i = 0; i < kChunks; i++) {
+ chunks.push_back(
+ rand.ArrayOf(this->type_singleton(), chunk_length, /*null_probability=*/0.1));
+ }
+ ChunkedArray chunked_array(std::move(chunks));
+
+ int64_t expected = -1;
+ int64_t index = 0;
+ for (auto chunk : chunked_array.chunks()) {
+ auto typed_chunk = arrow::internal::checked_pointer_cast<
+ typename TypeTraits<TypeParam>::ArrayType>(chunk);
+ for (auto value : *typed_chunk) {
+ if (value.has_value() &&
+ value.value() == static_cast<typename TestFixture::CType>(0)) {
+ expected = index;
+ break;
+ }
+ index++;
+ }
+ if (expected >= 0) break;
+ }
+
+ this->AssertIndexIs(Datum(chunked_array), value, expected);
+ }
+}
+
+template <typename ArrowType>
+class TestDateTimeIndexKernel : public TestIndexKernel<ArrowType> {};
+TYPED_TEST_SUITE(TestDateTimeIndexKernel, TemporalArrowTypes);
+TYPED_TEST(TestDateTimeIndexKernel, Basics) {
+ auto type = this->type_singleton();
+ auto value = std::make_shared<typename TestFixture::ScalarType>(42, type);
+ auto null_value = std::make_shared<typename TestFixture::ScalarType>(42, type);
+ null_value->is_valid = false;
+
+ this->AssertIndexIs("[]", value, -1);
+ this->AssertIndexIs("[42]", value, 0);
+ this->AssertIndexIs("[84, 84, 84, 84]", value, -1);
+ this->AssertIndexIs("[84, 84, 84, 84, 42]", value, 4);
+ this->AssertIndexIs("[null, null, null]", value, -1);
+ this->AssertIndexIs("[null, null, null]", null_value, -1);
+ this->AssertIndexIs("[42, null, null]", null_value, -1);
+}
+
+template <typename ArrowType>
+class TestBooleanIndexKernel : public TestIndexKernel<ArrowType> {};
+TYPED_TEST_SUITE(TestBooleanIndexKernel, ::testing::Types<BooleanType>);
+TYPED_TEST(TestBooleanIndexKernel, Basics) {
+ auto value = std::make_shared<typename TestFixture::ScalarType>(true);
+ auto null_value = std::make_shared<typename TestFixture::ScalarType>(true);
+ null_value->is_valid = false;
+
+ this->AssertIndexIs("[]", value, -1);
+ this->AssertIndexIs("[true]", value, 0);
+ this->AssertIndexIs("[false, false, false, false]", value, -1);
+ this->AssertIndexIs("[false, false, false, false, true]", value, 4);
+ this->AssertIndexIs("[null, null, null]", value, -1);
+ this->AssertIndexIs("[null, null, null]", null_value, -1);
+ this->AssertIndexIs("[true, null, null]", null_value, -1);
+}
+
+template <typename ArrowType>
+class TestStringIndexKernel : public TestIndexKernel<ArrowType> {};
+TYPED_TEST_SUITE(TestStringIndexKernel, BinaryArrowTypes);
+TYPED_TEST(TestStringIndexKernel, Basics) {
+ auto buffer = Buffer::FromString("foo");
+ auto value = std::make_shared<typename TestFixture::ScalarType>(buffer);
+ auto null_value = std::make_shared<typename TestFixture::ScalarType>(buffer);
+ null_value->is_valid = false;
+
+ this->AssertIndexIs(R"([])", value, -1);
+ this->AssertIndexIs(R"(["foo"])", value, 0);
+ this->AssertIndexIs(R"(["bar", "bar", "bar", "bar"])", value, -1);
+ this->AssertIndexIs(R"(["bar", "bar", "bar", "bar", "foo"])", value, 4);
+ this->AssertIndexIs(R"([null, null, null])", value, -1);
+ this->AssertIndexIs(R"([null, null, null])", null_value, -1);
+ this->AssertIndexIs(R"(["foo", null, null])", null_value, -1);
+}
+
+//
+// Mode
+//
+
+template <typename T>
+class TestPrimitiveModeKernel : public ::testing::Test {
+ public:
+ using ArrowType = T;
+ using Traits = TypeTraits<ArrowType>;
+ using CType = typename ArrowType::c_type;
+
+ void AssertModesAre(const Datum& array, const ModeOptions options,
+ const std::vector<CType>& expected_modes,
+ const std::vector<int64_t>& expected_counts) {
+ ASSERT_OK_AND_ASSIGN(Datum out, Mode(array, options));
+ ValidateOutput(out);
+ const StructArray out_array(out.array());
+ ASSERT_EQ(out_array.length(), expected_modes.size());
+ ASSERT_EQ(out_array.num_fields(), 2);
+
+ const CType* out_modes = out_array.field(0)->data()->GetValues<CType>(1);
+ const int64_t* out_counts = out_array.field(1)->data()->GetValues<int64_t>(1);
+ for (int i = 0; i < out_array.length(); ++i) {
+ // equal or nan equal
+ ASSERT_TRUE(
+ (expected_modes[i] == out_modes[i]) ||
+ (expected_modes[i] != expected_modes[i] && out_modes[i] != out_modes[i]));
+ ASSERT_EQ(expected_counts[i], out_counts[i]);
+ }
+ }
+
+ void AssertModesAre(const std::string& json, const int n,
+ const std::vector<CType>& expected_modes,
+ const std::vector<int64_t>& expected_counts) {
+ auto array = ArrayFromJSON(type_singleton(), json);
+ AssertModesAre(array, ModeOptions(n), expected_modes, expected_counts);
+ }
+
+ void AssertModesAre(const std::string& json, const ModeOptions options,
+ const std::vector<CType>& expected_modes,
+ const std::vector<int64_t>& expected_counts) {
+ auto array = ArrayFromJSON(type_singleton(), json);
+ AssertModesAre(array, options, expected_modes, expected_counts);
+ }
+
+ void AssertModeIs(const Datum& array, CType expected_mode, int64_t expected_count) {
+ AssertModesAre(array, ModeOptions(1), {expected_mode}, {expected_count});
+ }
+
+ void AssertModeIs(const std::string& json, CType expected_mode,
+ int64_t expected_count) {
+ auto array = ArrayFromJSON(type_singleton(), json);
+ AssertModeIs(array, expected_mode, expected_count);
+ }
+
+ void AssertModeIs(const std::vector<std::string>& json, CType expected_mode,
+ int64_t expected_count) {
+ auto chunked = ChunkedArrayFromJSON(type_singleton(), json);
+ AssertModeIs(chunked, expected_mode, expected_count);
+ }
+
+ void AssertModesEmpty(const Datum& array, ModeOptions options) {
+ ASSERT_OK_AND_ASSIGN(Datum out, Mode(array, options));
+ auto out_array = out.make_array();
+ ValidateOutput(*out_array);
+ ASSERT_EQ(out.array()->length, 0);
+ }
+
+ void AssertModesEmpty(const std::string& json, int n = 1) {
+ auto array = ArrayFromJSON(type_singleton(), json);
+ AssertModesEmpty(array, ModeOptions(n));
+ }
+
+ void AssertModesEmpty(const std::vector<std::string>& json, int n = 1) {
+ auto chunked = ChunkedArrayFromJSON(type_singleton(), json);
+ AssertModesEmpty(chunked, ModeOptions(n));
+ }
+
+ void AssertModesEmpty(const std::string& json, ModeOptions options) {
+ auto array = ArrayFromJSON(type_singleton(), json);
+ AssertModesEmpty(array, options);
+ }
+
+ std::shared_ptr<DataType> type_singleton() { return Traits::type_singleton(); }
+};
+
+template <typename ArrowType>
+class TestIntegerModeKernel : public TestPrimitiveModeKernel<ArrowType> {};
+
+template <typename ArrowType>
+class TestFloatingModeKernel : public TestPrimitiveModeKernel<ArrowType> {};
+
+class TestBooleanModeKernel : public TestPrimitiveModeKernel<BooleanType> {};
+
+class TestInt8ModeKernelValueRange : public TestPrimitiveModeKernel<Int8Type> {};
+
+class TestInt32ModeKernel : public TestPrimitiveModeKernel<Int32Type> {};
+
+TEST_F(TestBooleanModeKernel, Basics) {
+ this->AssertModeIs("[false, false]", false, 2);
+ this->AssertModeIs("[false, false, true, true, true]", true, 3);
+ this->AssertModeIs("[true, false, false, true, true]", true, 3);
+ this->AssertModeIs("[false, false, true, true, true, false]", false, 3);
+
+ this->AssertModeIs("[true, null, false, false, null, true, null, null, true]", true, 3);
+ this->AssertModesEmpty("[null, null, null]");
+ this->AssertModesEmpty("[]");
+
+ this->AssertModeIs({"[true, false]", "[true, true]", "[false, false]"}, false, 3);
+ this->AssertModeIs({"[true, null]", "[]", "[null, false]"}, false, 1);
+ this->AssertModesEmpty({"[null, null]", "[]", "[null]"});
+
+ this->AssertModesAre("[false, false, true, true, true, false]", 2, {false, true},
+ {3, 3});
+ this->AssertModesAre("[true, null, false, false, null, true, null, null, true]", 100,
+ {true, false}, {3, 2});
+ this->AssertModesEmpty({"[null, null]", "[]", "[null]"}, 4);
+
+ auto in_ty = boolean();
+ this->AssertModesAre("[true, false, false, null]", ModeOptions(/*n=*/1), {false}, {2});
+ this->AssertModesEmpty("[true, false, false, null]",
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false));
+ this->AssertModesAre("[true, false, false, null]",
+ ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/3),
+ {false}, {2});
+ this->AssertModesEmpty("[false, false, null]",
+ ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/3));
+ this->AssertModesAre("[true, false, false]",
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3),
+ {false}, {2});
+ this->AssertModesEmpty("[true, false, false, null]",
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3));
+ this->AssertModesEmpty("[true, false]",
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3));
+ this->AssertModesAre(ScalarFromJSON(in_ty, "true"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false), {true}, {1});
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false));
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "true"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/2));
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/2));
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "true"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/2));
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/2));
+
+ this->AssertModesAre(ScalarFromJSON(in_ty, "true"), ModeOptions(/*n=*/1), {true}, {1});
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"), ModeOptions(/*n=*/1));
+}
+
+TYPED_TEST_SUITE(TestIntegerModeKernel, IntegralArrowTypes);
+TYPED_TEST(TestIntegerModeKernel, Basics) {
+ this->AssertModeIs("[5, 1, 1, 5, 5]", 5, 3);
+ this->AssertModeIs("[5, 1, 1, 5, 5, 1]", 1, 3);
+ this->AssertModeIs("[127, 0, 127, 127, 0, 1, 0, 127]", 127, 4);
+
+ this->AssertModeIs("[null, null, 2, null, 1]", 1, 1);
+ this->AssertModesEmpty("[null, null, null]");
+ this->AssertModesEmpty("[]");
+
+ this->AssertModeIs({"[5]", "[1, 1, 5]", "[5]"}, 5, 3);
+ this->AssertModeIs({"[5]", "[1, 1, 5]", "[5, 1]"}, 1, 3);
+ this->AssertModesEmpty({"[null, null]", "[]", "[null]"});
+
+ this->AssertModesAre("[127, 0, 127, 127, 0, 1, 0, 127]", 2, {127, 0}, {4, 3});
+ this->AssertModesAre("[null, null, 2, null, 1]", 3, {1, 2}, {1, 1});
+ this->AssertModesEmpty("[null, null, null]", 10);
+
+ auto in_ty = this->type_singleton();
+
+ this->AssertModesAre("[1, 2, 2, null]", ModeOptions(/*n=*/1), {2}, {2});
+ this->AssertModesEmpty("[1, 2, 2, null]", ModeOptions(/*n=*/1, /*skip_nulls=*/false));
+ this->AssertModesAre("[1, 2, 2, null]",
+ ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/3), {2},
+ {2});
+ this->AssertModesEmpty("[2, 2, null]",
+ ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/3));
+ this->AssertModesAre(
+ "[1, 2, 2]", ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3), {2}, {2});
+ this->AssertModesEmpty("[1, 2, 2, null]",
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3));
+ this->AssertModesEmpty("[1, 2]",
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3));
+ this->AssertModesAre(ScalarFromJSON(in_ty, "1"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false), {1}, {1});
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false));
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "1"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/2));
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/2));
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "1"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/2));
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/2));
+
+ this->AssertModesAre(ScalarFromJSON(in_ty, "5"), ModeOptions(/*n=*/1), {5}, {1});
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"), ModeOptions(/*n=*/1));
+}
+
+TYPED_TEST_SUITE(TestFloatingModeKernel, RealArrowTypes);
+TYPED_TEST(TestFloatingModeKernel, Floats) {
+ this->AssertModeIs("[5, 1, 1, 5, 5]", 5, 3);
+ this->AssertModeIs("[5, 1, 1, 5, 5, 1]", 1, 3);
+ this->AssertModeIs("[Inf, 100, Inf, 100, Inf]", INFINITY, 3);
+ this->AssertModeIs("[Inf, -Inf, Inf, -Inf]", -INFINITY, 2);
+
+ this->AssertModeIs("[null, null, 2, null, 1]", 1, 1);
+ this->AssertModeIs("[NaN, NaN, 1, null, 1]", 1, 2);
+
+ this->AssertModesEmpty("[null, null, null]");
+ this->AssertModesEmpty("[]");
+
+ this->AssertModeIs("[NaN, NaN, 1]", NAN, 2);
+ this->AssertModeIs("[NaN, NaN, null]", NAN, 2);
+ this->AssertModeIs("[NaN, NaN, NaN]", NAN, 3);
+
+ this->AssertModeIs({"[Inf, 100]", "[Inf, 100]", "[Inf]"}, INFINITY, 3);
+ this->AssertModeIs({"[NaN, 1]", "[NaN, 1]", "[NaN]"}, NAN, 3);
+ this->AssertModesEmpty({"[null, null]", "[]", "[null]"});
+
+ this->AssertModesAre("[Inf, 100, Inf, 100, Inf]", 2, {INFINITY, 100}, {3, 2});
+ this->AssertModesAre("[NaN, NaN, 1, null, 1, 2, 2]", 3, {1, 2, NAN}, {2, 2, 2});
+
+ auto in_ty = this->type_singleton();
+
+ this->AssertModesAre("[1, 2, 2, null]", ModeOptions(/*n=*/1), {2}, {2});
+ this->AssertModesEmpty("[1, 2, 2, null]", ModeOptions(/*n=*/1, /*skip_nulls=*/false));
+ this->AssertModesAre("[1, 2, 2, null]",
+ ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/3), {2},
+ {2});
+ this->AssertModesEmpty("[2, 2, null]",
+ ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/3));
+ this->AssertModesAre(
+ "[1, 2, 2]", ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3), {2}, {2});
+ this->AssertModesEmpty("[1, 2, 2, null]",
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3));
+ this->AssertModesEmpty("[1, 2]",
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/3));
+ this->AssertModesAre(ScalarFromJSON(in_ty, "1"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false), {1}, {1});
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false));
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "1"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/2));
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/true, /*min_count=*/2));
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "1"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/2));
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"),
+ ModeOptions(/*n=*/1, /*skip_nulls=*/false, /*min_count=*/2));
+
+ this->AssertModesAre(ScalarFromJSON(in_ty, "5"), ModeOptions(/*n=*/1), {5}, {1});
+ this->AssertModesEmpty(ScalarFromJSON(in_ty, "null"), ModeOptions(/*n=*/1));
+}
+
+TEST_F(TestInt8ModeKernelValueRange, Basics) {
+ this->AssertModeIs("[0, 127, -128, -128]", -128, 2);
+ this->AssertModeIs("[127, 127, 127]", 127, 3);
+}
+
+template <typename ArrowType>
+struct ModeResult {
+ using T = typename ArrowType::c_type;
+
+ T mode = std::numeric_limits<T>::min();
+ int64_t count = 0;
+};
+
+template <typename ArrowType>
+ModeResult<ArrowType> NaiveMode(const Array& array) {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+ using CTYPE = typename ArrowType::c_type;
+
+ std::unordered_map<CTYPE, int64_t> value_counts;
+
+ const auto& array_numeric = reinterpret_cast<const ArrayType&>(array);
+ const auto values = array_numeric.raw_values();
+ BitmapReader reader(array.null_bitmap_data(), array.offset(), array.length());
+ for (int64_t i = 0; i < array.length(); ++i) {
+ if (reader.IsSet()) {
+ ++value_counts[values[i]];
+ }
+ reader.Next();
+ }
+
+ ModeResult<ArrowType> result;
+ for (const auto& value_count : value_counts) {
+ auto value = value_count.first;
+ auto count = value_count.second;
+ if (count > result.count || (count == result.count && value < result.mode)) {
+ result.count = count;
+ result.mode = value;
+ }
+ }
+
+ return result;
+}
+
+template <typename ArrowType, typename CTYPE = typename ArrowType::c_type>
+void VerifyMode(const std::shared_ptr<Array>& array) {
+ auto expected = NaiveMode<ArrowType>(*array);
+ ASSERT_OK_AND_ASSIGN(Datum out, Mode(array));
+ const StructArray out_array(out.array());
+ ValidateOutput(out_array);
+ ASSERT_EQ(out_array.length(), 1);
+ ASSERT_EQ(out_array.num_fields(), 2);
+
+ const CTYPE* out_modes = out_array.field(0)->data()->GetValues<CTYPE>(1);
+ const int64_t* out_counts = out_array.field(1)->data()->GetValues<int64_t>(1);
+ ASSERT_EQ(out_modes[0], expected.mode);
+ ASSERT_EQ(out_counts[0], expected.count);
+}
+
+template <typename ArrowType, typename CTYPE = typename ArrowType::c_type>
+void CheckModeWithRange(CTYPE range_min, CTYPE range_max) {
+ auto rand = random::RandomArrayGenerator(0x5487655);
+ // 32K items (>= counting mode cutoff) within range, 10% null
+ auto array = rand.Numeric<ArrowType>(32 * 1024, range_min, range_max, 0.1);
+ VerifyMode<ArrowType>(array);
+}
+
+template <typename ArrowType, typename CTYPE = typename ArrowType::c_type>
+void CheckModeWithRangeSliced(CTYPE range_min, CTYPE range_max) {
+ auto rand = random::RandomArrayGenerator(0x5487655);
+ auto array = rand.Numeric<ArrowType>(32 * 1024, range_min, range_max, 0.1);
+
+ const int64_t array_size = array->length();
+ const std::vector<std::array<int64_t, 2>> offset_size{
+ {0, 40},
+ {array_size - 40, 40},
+ {array_size / 3, array_size / 6},
+ {array_size * 9 / 10, array_size / 10},
+ };
+ for (const auto& os : offset_size) {
+ VerifyMode<ArrowType>(array->Slice(os[0], os[1]));
+ }
+}
+
+TEST_F(TestInt32ModeKernel, SmallValueRange) {
+ // Small value range => should exercise counter-based Mode implementation
+ CheckModeWithRange<ArrowType>(-100, 100);
+}
+
+TEST_F(TestInt32ModeKernel, LargeValueRange) {
+ // Large value range => should exercise sorter-based Mode implementation
+ CheckModeWithRange<ArrowType>(-10000000, 10000000);
+}
+
+TEST_F(TestInt32ModeKernel, Sliced) {
+ CheckModeWithRangeSliced<ArrowType>(-100, 100);
+ CheckModeWithRangeSliced<ArrowType>(-10000000, 10000000);
+}
+
+//
+// Variance/Stddev
+//
+
+template <typename ArrowType>
+class TestPrimitiveVarStdKernel : public ::testing::Test {
+ public:
+ using Traits = TypeTraits<ArrowType>;
+ using ScalarType = typename TypeTraits<DoubleType>::ScalarType;
+
+ void AssertVarStdIs(const Array& array, const VarianceOptions& options,
+ double expected_var) {
+ AssertVarStdIsInternal(array, options, expected_var);
+ }
+
+ void AssertVarStdIs(const std::shared_ptr<ChunkedArray>& array,
+ const VarianceOptions& options, double expected_var) {
+ AssertVarStdIsInternal(array, options, expected_var);
+ }
+
+ void AssertVarStdIs(const std::string& json, const VarianceOptions& options,
+ double expected_var) {
+ auto array = ArrayFromJSON(type_singleton(), json);
+ AssertVarStdIs(*array, options, expected_var);
+ }
+
+ void AssertVarStdIs(const std::vector<std::string>& json,
+ const VarianceOptions& options, double expected_var) {
+ auto chunked = ChunkedArrayFromJSON(type_singleton(), json);
+ AssertVarStdIs(chunked, options, expected_var);
+ }
+
+ void AssertVarStdIsInvalid(const Array& array, const VarianceOptions& options) {
+ AssertVarStdIsInvalidInternal(array, options);
+ }
+
+ void AssertVarStdIsInvalid(const std::shared_ptr<ChunkedArray>& array,
+ const VarianceOptions& options) {
+ AssertVarStdIsInvalidInternal(array, options);
+ }
+
+ void AssertVarStdIsInvalid(const std::string& json, const VarianceOptions& options) {
+ auto array = ArrayFromJSON(type_singleton(), json);
+ AssertVarStdIsInvalid(*array, options);
+ }
+
+ void AssertVarStdIsInvalid(const std::vector<std::string>& json,
+ const VarianceOptions& options) {
+ auto array = ChunkedArrayFromJSON(type_singleton(), json);
+ AssertVarStdIsInvalid(array, options);
+ }
+
+ std::shared_ptr<DataType> type_singleton() { return Traits::type_singleton(); }
+
+ private:
+ void AssertVarStdIsInternal(const Datum& array, const VarianceOptions& options,
+ double expected_var) {
+ ASSERT_OK_AND_ASSIGN(Datum out_var, Variance(array, options));
+ ASSERT_OK_AND_ASSIGN(Datum out_std, Stddev(array, options));
+ auto var = checked_cast<const ScalarType*>(out_var.scalar().get());
+ auto std = checked_cast<const ScalarType*>(out_std.scalar().get());
+ ASSERT_TRUE(var->is_valid && std->is_valid);
+ ASSERT_DOUBLE_EQ(std->value * std->value, var->value);
+ ASSERT_DOUBLE_EQ(var->value, expected_var); // < 4ULP
+ }
+
+ void AssertVarStdIsInvalidInternal(const Datum& array, const VarianceOptions& options) {
+ ASSERT_OK_AND_ASSIGN(Datum out_var, Variance(array, options));
+ ASSERT_OK_AND_ASSIGN(Datum out_std, Stddev(array, options));
+ auto var = checked_cast<const ScalarType*>(out_var.scalar().get());
+ auto std = checked_cast<const ScalarType*>(out_std.scalar().get());
+ ASSERT_FALSE(var->is_valid || std->is_valid);
+ }
+};
+
+template <typename ArrowType>
+class TestNumericVarStdKernel : public TestPrimitiveVarStdKernel<ArrowType> {};
+
+// Reference value from numpy.var
+TYPED_TEST_SUITE(TestNumericVarStdKernel, NumericArrowTypes);
+TYPED_TEST(TestNumericVarStdKernel, Basics) {
+ VarianceOptions options; // ddof = 0, population variance/stddev
+
+ this->AssertVarStdIs("[100]", options, 0);
+ this->AssertVarStdIs("[1, 2, 3]", options, 0.6666666666666666);
+ this->AssertVarStdIs("[null, 1, 2, null, 3]", options, 0.6666666666666666);
+
+ std::vector<std::string> chunks;
+ chunks = {"[]", "[1]", "[2]", "[null]", "[3]"};
+ this->AssertVarStdIs(chunks, options, 0.6666666666666666);
+ chunks = {"[1, 2, 3]", "[4, 5, 6]", "[7, 8]"};
+ this->AssertVarStdIs(chunks, options, 5.25);
+ chunks = {"[1, 2, 3, 4, 5, 6, 7]", "[8]"};
+ this->AssertVarStdIs(chunks, options, 5.25);
+
+ this->AssertVarStdIsInvalid("[null, null, null]", options);
+ this->AssertVarStdIsInvalid("[]", options);
+ this->AssertVarStdIsInvalid("[]", options);
+
+ options.ddof = 1; // sample variance/stddev
+
+ this->AssertVarStdIs("[1, 2]", options, 0.5);
+
+ chunks = {"[1]", "[2]"};
+ this->AssertVarStdIs(chunks, options, 0.5);
+ chunks = {"[1, 2, 3]", "[4, 5, 6]", "[7, 8]"};
+ this->AssertVarStdIs(chunks, options, 6.0);
+ chunks = {"[1, 2, 3, 4, 5, 6, 7]", "[8]"};
+ this->AssertVarStdIs(chunks, options, 6.0);
+
+ this->AssertVarStdIsInvalid("[100]", options);
+ this->AssertVarStdIsInvalid("[100, null, null]", options);
+ chunks = {"[100]", "[null]", "[]"};
+ this->AssertVarStdIsInvalid(chunks, options);
+
+ auto ty = this->type_singleton();
+ EXPECT_THAT(Stddev(*MakeScalar(ty, 5)), ResultWith(Datum(0.0)));
+ EXPECT_THAT(Variance(*MakeScalar(ty, 5)), ResultWith(Datum(0.0)));
+ EXPECT_THAT(Stddev(*MakeScalar(ty, 5), options),
+ ResultWith(Datum(MakeNullScalar(float64()))));
+ EXPECT_THAT(Variance(*MakeScalar(ty, 5), options),
+ ResultWith(Datum(MakeNullScalar(float64()))));
+ EXPECT_THAT(Stddev(MakeNullScalar(ty)), ResultWith(Datum(MakeNullScalar(float64()))));
+ EXPECT_THAT(Variance(MakeNullScalar(ty)), ResultWith(Datum(MakeNullScalar(float64()))));
+
+ // skip_nulls and min_count
+ options.ddof = 0;
+ options.min_count = 3;
+ this->AssertVarStdIs("[1, 2, 3]", options, 0.6666666666666666);
+ this->AssertVarStdIsInvalid("[1, 2, null]", options);
+
+ options.min_count = 0;
+ options.skip_nulls = false;
+ this->AssertVarStdIs("[1, 2, 3]", options, 0.6666666666666666);
+ this->AssertVarStdIsInvalid("[1, 2, 3, null]", options);
+
+ options.min_count = 4;
+ options.skip_nulls = false;
+ this->AssertVarStdIsInvalid("[1, 2, 3]", options);
+ this->AssertVarStdIsInvalid("[1, 2, 3, null]", options);
+}
+
+// Test numerical stability
+template <typename ArrowType>
+class TestVarStdKernelStability : public TestPrimitiveVarStdKernel<ArrowType> {};
+
+typedef ::testing::Types<Int32Type, UInt32Type, Int64Type, UInt64Type, DoubleType>
+ VarStdStabilityTypes;
+
+TYPED_TEST_SUITE(TestVarStdKernelStability, VarStdStabilityTypes);
+TYPED_TEST(TestVarStdKernelStability, Basics) {
+ VarianceOptions options{1}; // ddof = 1
+ this->AssertVarStdIs("[100000004, 100000007, 100000013, 100000016]", options, 30.0);
+ this->AssertVarStdIs("[1000000004, 1000000007, 1000000013, 1000000016]", options, 30.0);
+ if (!is_unsigned_integer_type<TypeParam>::value) {
+ this->AssertVarStdIs("[-1000000016, -1000000013, -1000000007, -1000000004]", options,
+ 30.0);
+ }
+}
+
+// Test numerical stability of variance merging code
+class TestVarStdKernelMergeStability : public TestPrimitiveVarStdKernel<DoubleType> {};
+
+TEST_F(TestVarStdKernelMergeStability, Basics) {
+ VarianceOptions options{1}; // ddof = 1
+
+#ifndef __MINGW32__ // MinGW has precision issues
+ // XXX: The reference value from numpy is actually wrong due to floating
+ // point limits. The correct result should equals variance(90, 0) = 4050.
+ std::vector<std::string> chunks = {"[40000008000000490]", "[40000008000000400]"};
+ this->AssertVarStdIs(chunks, options, 3904.0);
+#endif
+}
+
+// Test round-off error
+template <typename ArrowType>
+class TestVarStdKernelRoundOff : public TestPrimitiveVarStdKernel<ArrowType> {};
+
+typedef ::testing::Types<Int32Type, Int64Type, FloatType, DoubleType> VarStdRoundOffTypes;
+
+TYPED_TEST_SUITE(TestVarStdKernelRoundOff, VarStdRoundOffTypes);
+TYPED_TEST(TestVarStdKernelRoundOff, Basics) {
+ // build array: np.arange(321000, dtype='xxx')
+ typename TypeParam::c_type value = 0;
+ ASSERT_OK_AND_ASSIGN(
+ auto array, ArrayFromBuilderVisitor(TypeTraits<TypeParam>::type_singleton(), 321000,
+ [&](NumericBuilder<TypeParam>* builder) {
+ builder->UnsafeAppend(value++);
+ }));
+
+ // reference value from numpy.var()
+ this->AssertVarStdIs(*array, VarianceOptions{0}, 8586749999.916667);
+}
+
+// Test integer arithmetic code
+class TestVarStdKernelInt32 : public TestPrimitiveVarStdKernel<Int32Type> {};
+
+TEST_F(TestVarStdKernelInt32, Basics) {
+ VarianceOptions options{1};
+ this->AssertVarStdIs("[-2147483648, -2147483647, -2147483646]", options, 1.0);
+ this->AssertVarStdIs("[2147483645, 2147483646, 2147483647]", options, 1.0);
+ this->AssertVarStdIs("[-2147483648, -2147483648, 2147483647]", options,
+ 6.148914688373205e+18);
+}
+
+class TestVarStdKernelUInt32 : public TestPrimitiveVarStdKernel<UInt32Type> {};
+
+TEST_F(TestVarStdKernelUInt32, Basics) {
+ VarianceOptions options{1};
+ this->AssertVarStdIs("[4294967293, 4294967294, 4294967295]", options, 1.0);
+ this->AssertVarStdIs("[0, 0, 4294967295]", options, 6.148914688373205e+18);
+}
+
+// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
+void KahanSum(double& sum, double& adjust, double addend) {
+ double y = addend - adjust;
+ double t = sum + y;
+ adjust = (t - sum) - y;
+ sum = t;
+}
+
+// Calculate reference variance with Welford's online algorithm + Kahan summation
+// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+// XXX: not stable for long array with very small `stddev / average`
+template <typename ArrayType>
+std::pair<double, double> WelfordVar(const ArrayType& array) {
+ const auto values = array.raw_values();
+ BitmapReader reader(array.null_bitmap_data(), array.offset(), array.length());
+ double count = 0, mean = 0, m2 = 0;
+ double mean_adjust = 0, m2_adjust = 0;
+ for (int64_t i = 0; i < array.length(); ++i) {
+ if (reader.IsSet()) {
+ ++count;
+ double delta = static_cast<double>(values[i]) - mean;
+ KahanSum(mean, mean_adjust, delta / count);
+ double delta2 = static_cast<double>(values[i]) - mean;
+ KahanSum(m2, m2_adjust, delta * delta2);
+ }
+ reader.Next();
+ }
+ return std::make_pair(m2 / count, m2 / (count - 1));
+}
+
+// Test random chunked array
+template <typename ArrowType>
+class TestVarStdKernelRandom : public TestPrimitiveVarStdKernel<ArrowType> {};
+
+using VarStdRandomTypes =
+ ::testing::Types<Int32Type, UInt32Type, Int64Type, UInt64Type, FloatType, DoubleType>;
+
+TYPED_TEST_SUITE(TestVarStdKernelRandom, VarStdRandomTypes);
+
+TYPED_TEST(TestVarStdKernelRandom, Basics) {
+#if defined(__MINGW32__) && !defined(__MINGW64__)
+ if (TypeParam::type_id == Type::FLOAT) {
+ GTEST_SKIP() << "Precision issues on MinGW32 with float32";
+ }
+#endif
+ // Cut array into small chunks
+ constexpr int array_size = 5000;
+ constexpr int chunk_size_max = 50;
+ constexpr int chunk_count = array_size / chunk_size_max;
+
+ std::shared_ptr<Array> array;
+ auto rand = random::RandomArrayGenerator(0x5487656);
+ if (is_floating_type<TypeParam>::value) {
+ array = rand.Numeric<TypeParam>(array_size, -10000.0, 100000.0, 0.1);
+ } else {
+ using CType = typename TypeParam::c_type;
+ constexpr CType min = std::numeric_limits<CType>::min();
+ constexpr CType max = std::numeric_limits<CType>::max();
+ array = rand.Numeric<TypeParam>(array_size, min, max, 0.1);
+ }
+ auto chunk_size_array = rand.Numeric<Int32Type>(chunk_count, 0, chunk_size_max);
+ const int* chunk_size = chunk_size_array->data()->GetValues<int>(1);
+ int total_size = 0;
+
+ ArrayVector array_vector;
+ for (int i = 0; i < chunk_count; ++i) {
+ array_vector.emplace_back(array->Slice(total_size, chunk_size[i]));
+ total_size += chunk_size[i];
+ }
+ auto chunked = *ChunkedArray::Make(array_vector);
+
+ double var_population, var_sample;
+ using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+ auto typed_array = checked_pointer_cast<ArrayType>(array->Slice(0, total_size));
+ std::tie(var_population, var_sample) = WelfordVar(*typed_array);
+
+ this->AssertVarStdIs(chunked, VarianceOptions{0}, var_population);
+ this->AssertVarStdIs(chunked, VarianceOptions{1}, var_sample);
+}
+
+// This test is too heavy to run in CI, should be checked manually
+#if 0
+class TestVarStdKernelIntegerLength : public TestPrimitiveVarStdKernel<Int32Type> {};
+
+TEST_F(TestVarStdKernelIntegerLength, Basics) {
+ constexpr int32_t min = std::numeric_limits<int32_t>::min();
+ constexpr int32_t max = std::numeric_limits<int32_t>::max();
+ auto rand = random::RandomArrayGenerator(0x5487657);
+ // large data volume
+ auto array = rand.Numeric<Int32Type>(4000000000, min, max, 0.1);
+ // biased distribution
+ // auto array = rand.Numeric<Int32Type>(4000000000, min, min + 100000, 0.1);
+
+ double var_population, var_sample;
+ auto int32_array = checked_pointer_cast<Int32Array>(array);
+ std::tie(var_population, var_sample) = WelfordVar(*int32_array);
+
+ this->AssertVarStdIs(*array, VarianceOptions{0}, var_population);
+ this->AssertVarStdIs(*array, VarianceOptions{1}, var_sample);
+}
+#endif
+
+//
+// Quantile
+//
+
+template <typename ArrowType>
+class TestPrimitiveQuantileKernel : public ::testing::Test {
+ public:
+ using Traits = TypeTraits<ArrowType>;
+ using CType = typename ArrowType::c_type;
+
+ void AssertQuantilesAre(const Datum& array, QuantileOptions options,
+ const std::vector<std::vector<Datum>>& expected) {
+ ASSERT_EQ(options.q.size(), expected.size());
+
+ for (size_t i = 0; i < this->interpolations_.size(); ++i) {
+ options.interpolation = this->interpolations_[i];
+
+ ASSERT_OK_AND_ASSIGN(Datum out, Quantile(array, options));
+ const auto& out_array = out.make_array();
+ ValidateOutput(*out_array);
+ ASSERT_EQ(out_array->length(), options.q.size());
+ ASSERT_EQ(out_array->null_count(), 0);
+ AssertTypeEqual(out_array->type(), expected[0][i].type());
+
+ if (out_array->type()->Equals(float64())) {
+ const double* quantiles = out_array->data()->GetValues<double>(1);
+ for (int64_t j = 0; j < out_array->length(); ++j) {
+ const auto& numeric_scalar =
+ checked_pointer_cast<DoubleScalar>(expected[j][i].scalar());
+ ASSERT_TRUE((quantiles[j] == numeric_scalar->value) ||
+ (std::isnan(quantiles[j]) && std::isnan(numeric_scalar->value)));
+ }
+ } else {
+ AssertTypeEqual(out_array->type(), type_singleton());
+ const CType* quantiles = out_array->data()->GetValues<CType>(1);
+ for (int64_t j = 0; j < out_array->length(); ++j) {
+ const auto& numeric_scalar =
+ checked_pointer_cast<NumericScalar<ArrowType>>(expected[j][i].scalar());
+ ASSERT_EQ(quantiles[j], numeric_scalar->value);
+ }
+ }
+ }
+ }
+
+ void AssertQuantilesAre(const std::string& json, const std::vector<double>& q,
+ const std::vector<std::vector<Datum>>& expected) {
+ auto array = ArrayFromJSON(type_singleton(), json);
+ AssertQuantilesAre(array, QuantileOptions{q}, expected);
+ }
+
+ void AssertQuantilesAre(const std::vector<std::string>& json,
+ const std::vector<double>& q,
+ const std::vector<std::vector<Datum>>& expected) {
+ auto chunked = ChunkedArrayFromJSON(type_singleton(), json);
+ AssertQuantilesAre(chunked, QuantileOptions{q}, expected);
+ }
+
+ void AssertQuantileIs(const Datum& array, double q,
+ const std::vector<Datum>& expected) {
+ AssertQuantilesAre(array, QuantileOptions{q}, {expected});
+ }
+
+ void AssertQuantileIs(const std::string& json, double q,
+ const std::vector<Datum>& expected) {
+ auto array = ArrayFromJSON(type_singleton(), json);
+ AssertQuantileIs(array, q, expected);
+ }
+
+ void AssertQuantileIs(const std::vector<std::string>& json, double q,
+ const std::vector<Datum>& expected) {
+ auto chunked = ChunkedArrayFromJSON(type_singleton(), json);
+ AssertQuantileIs(chunked, q, expected);
+ }
+
+ void AssertQuantilesEmpty(const Datum& array, const std::vector<double>& q) {
+ QuantileOptions options{q};
+ for (auto interpolation : this->interpolations_) {
+ options.interpolation = interpolation;
+ ASSERT_OK_AND_ASSIGN(Datum out, Quantile(array, options));
+ auto out_array = out.make_array();
+ ValidateOutput(*out_array);
+ ASSERT_EQ(out.array()->length, q.size());
+ ASSERT_EQ(out.array()->null_count, q.size());
+ }
+ }
+
+ void AssertQuantilesEmpty(const std::string& json, const std::vector<double>& q) {
+ auto array = ArrayFromJSON(type_singleton(), json);
+ AssertQuantilesEmpty(array, q);
+ }
+
+ void AssertQuantilesEmpty(const std::vector<std::string>& json,
+ const std::vector<double>& q) {
+ auto chunked = ChunkedArrayFromJSON(type_singleton(), json);
+ AssertQuantilesEmpty(chunked, q);
+ }
+
+ std::shared_ptr<DataType> type_singleton() { return Traits::type_singleton(); }
+
+ std::vector<enum QuantileOptions::Interpolation> interpolations_ = {
+ QuantileOptions::LINEAR, QuantileOptions::LOWER, QuantileOptions::HIGHER,
+ QuantileOptions::NEAREST, QuantileOptions::MIDPOINT};
+};
+
+#define INTYPE(x) Datum(static_cast<typename TypeParam::c_type>(x))
+#define DOUBLE(x) Datum(static_cast<double>(x))
+// output type per interplation: linear, lower, higher, nearest, midpoint
+#define O(a, b, c, d, e) \
+ { DOUBLE(a), INTYPE(b), INTYPE(c), INTYPE(d), DOUBLE(e) }
+
+template <typename ArrowType>
+class TestIntegerQuantileKernel : public TestPrimitiveQuantileKernel<ArrowType> {};
+
+TYPED_TEST_SUITE(TestIntegerQuantileKernel, IntegralArrowTypes);
+TYPED_TEST(TestIntegerQuantileKernel, Basics) {
+ // reference values from numpy
+ // ordered by interpolation method: {linear, lower, higher, nearest, midpoint}
+ this->AssertQuantileIs("[1]", 0.1, O(1, 1, 1, 1, 1));
+ this->AssertQuantileIs("[1, 2]", 0.5, O(1.5, 1, 2, 1, 1.5));
+ this->AssertQuantileIs("[3, 5, 2, 9, 0, 1, 8]", 0.5, O(3, 3, 3, 3, 3));
+ this->AssertQuantileIs("[3, 5, 2, 9, 0, 1, 8]", 0.33, O(1.98, 1, 2, 2, 1.5));
+ this->AssertQuantileIs("[3, 5, 2, 9, 0, 1, 8]", 0.9, O(8.4, 8, 9, 8, 8.5));
+ this->AssertQuantilesAre("[3, 5, 2, 9, 0, 1, 8]", {0.5, 0.9},
+ {O(3, 3, 3, 3, 3), O(8.4, 8, 9, 8, 8.5)});
+ this->AssertQuantilesAre("[3, 5, 2, 9, 0, 1, 8]", {1, 0.5},
+ {O(9, 9, 9, 9, 9), O(3, 3, 3, 3, 3)});
+ this->AssertQuantileIs("[3, 5, 2, 9, 0, 1, 8]", 0, O(0, 0, 0, 0, 0));
+ this->AssertQuantileIs("[3, 5, 2, 9, 0, 1, 8]", 1, O(9, 9, 9, 9, 9));
+
+ this->AssertQuantileIs("[5, null, null, 3, 9, null, 8, 1, 2, 0]", 0.21,
+ O(1.26, 1, 2, 1, 1.5));
+ this->AssertQuantilesAre("[5, null, null, 3, 9, null, 8, 1, 2, 0]", {0.5, 0.9},
+ {O(3, 3, 3, 3, 3), O(8.4, 8, 9, 8, 8.5)});
+ this->AssertQuantilesAre("[5, null, null, 3, 9, null, 8, 1, 2, 0]", {0.9, 0.5},
+ {O(8.4, 8, 9, 8, 8.5), O(3, 3, 3, 3, 3)});
+
+ this->AssertQuantileIs({"[5]", "[null, null]", "[3, 9, null]", "[8, 1, 2, 0]"}, 0.33,
+ O(1.98, 1, 2, 2, 1.5));
+ this->AssertQuantilesAre({"[5]", "[null, null]", "[3, 9, null]", "[8, 1, 2, 0]"},
+ {0.21, 1}, {O(1.26, 1, 2, 1, 1.5), O(9, 9, 9, 9, 9)});
+
+ this->AssertQuantilesEmpty("[]", {0.5});
+ this->AssertQuantilesEmpty("[null, null, null]", {0.1, 0.2});
+ this->AssertQuantilesEmpty({"[null, null]", "[]", "[null]"}, {0.3, 0.4});
+
+ auto ty = this->type_singleton();
+
+ QuantileOptions keep_nulls(/*q=*/0.5, QuantileOptions::LINEAR, /*skip_nulls=*/false,
+ /*min_count=*/0);
+ QuantileOptions min_count(/*q=*/0.5, QuantileOptions::LINEAR, /*skip_nulls=*/true,
+ /*min_count=*/3);
+ QuantileOptions keep_nulls_min_count(/*q=*/0.5, QuantileOptions::LINEAR,
+ /*skip_nulls=*/false, /*min_count=*/3);
+ auto not_empty = ResultWith(ArrayFromJSON(float64(), "[3.0]"));
+ auto empty = ResultWith(ArrayFromJSON(float64(), "[null]"));
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5]"), keep_nulls), not_empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5, null]"), keep_nulls), empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5]"), keep_nulls), not_empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5, null]"), keep_nulls), empty);
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "3"), keep_nulls), not_empty);
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "null"), keep_nulls), empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5]"), min_count), not_empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5, null]"), min_count), not_empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5]"), min_count), empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5, null]"), min_count), empty);
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "3"), min_count), empty);
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "null"), min_count), empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5]"), keep_nulls_min_count),
+ not_empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5, null]"), keep_nulls_min_count),
+ empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5]"), keep_nulls_min_count), empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5, null]"), keep_nulls_min_count), empty);
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "3"), keep_nulls_min_count), empty);
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "null"), keep_nulls_min_count), empty);
+
+ for (const auto interpolation : this->interpolations_) {
+ QuantileOptions options({0.0, 0.5, 1.0}, interpolation);
+ auto expected_ty = (interpolation == QuantileOptions::LINEAR ||
+ interpolation == QuantileOptions::MIDPOINT)
+ ? float64()
+ : ty;
+ EXPECT_THAT(Quantile(*MakeScalar(ty, 1), options),
+ ResultWith(ArrayFromJSON(expected_ty, "[1, 1, 1]")));
+ EXPECT_THAT(Quantile(MakeNullScalar(ty), options),
+ ResultWith(ArrayFromJSON(expected_ty, "[null, null, null]")));
+ }
+}
+
+template <typename ArrowType>
+class TestFloatingQuantileKernel : public TestPrimitiveQuantileKernel<ArrowType> {};
+
+#ifndef __MINGW32__
+TYPED_TEST_SUITE(TestFloatingQuantileKernel, RealArrowTypes);
+TYPED_TEST(TestFloatingQuantileKernel, Floats) {
+ // ordered by interpolation method: {linear, lower, higher, nearest, midpoint}
+ this->AssertQuantileIs("[-9, 7, Inf, -Inf, 2, 11]", 0.5, O(4.5, 2, 7, 2, 4.5));
+ this->AssertQuantileIs("[-9, 7, Inf, -Inf, 2, 11]", 0.1,
+ O(-INFINITY, -INFINITY, -9, -INFINITY, -INFINITY));
+ this->AssertQuantileIs("[-9, 7, Inf, -Inf, 2, 11]", 0.9,
+ O(INFINITY, 11, INFINITY, 11, INFINITY));
+ this->AssertQuantilesAre("[-9, 7, Inf, -Inf, 2, 11]", {0.3, 0.6},
+ {O(-3.5, -9, 2, 2, -3.5), O(7, 7, 7, 7, 7)});
+ this->AssertQuantileIs("[-Inf, Inf]", 0.2, O(NAN, -INFINITY, INFINITY, -INFINITY, NAN));
+
+ this->AssertQuantileIs("[NaN, -9, 7, Inf, null, null, -Inf, NaN, 2, 11]", 0.5,
+ O(4.5, 2, 7, 2, 4.5));
+ this->AssertQuantilesAre("[null, -9, 7, Inf, NaN, NaN, -Inf, null, 2, 11]", {0.3, 0.6},
+ {O(-3.5, -9, 2, 2, -3.5), O(7, 7, 7, 7, 7)});
+ this->AssertQuantilesAre("[null, -9, 7, Inf, NaN, NaN, -Inf, null, 2, 11]", {0.6, 0.3},
+ {O(7, 7, 7, 7, 7), O(-3.5, -9, 2, 2, -3.5)});
+
+ this->AssertQuantileIs({"[NaN, -9, 7, Inf]", "[null, NaN]", "[-Inf, NaN, 2, 11]"}, 0.5,
+ O(4.5, 2, 7, 2, 4.5));
+ this->AssertQuantilesAre({"[null, -9, 7, Inf]", "[NaN, NaN]", "[-Inf, null, 2, 11]"},
+ {0.3, 0.6}, {O(-3.5, -9, 2, 2, -3.5), O(7, 7, 7, 7, 7)});
+
+ this->AssertQuantilesEmpty("[]", {0.5, 0.6});
+ this->AssertQuantilesEmpty("[null, NaN, null]", {0.1});
+ this->AssertQuantilesEmpty({"[NaN, NaN]", "[]", "[null]"}, {0.3, 0.4});
+
+ auto ty = this->type_singleton();
+
+ QuantileOptions keep_nulls(/*q=*/0.5, QuantileOptions::LINEAR, /*skip_nulls=*/false,
+ /*min_count=*/0);
+ QuantileOptions min_count(/*q=*/0.5, QuantileOptions::LINEAR, /*skip_nulls=*/true,
+ /*min_count=*/3);
+ QuantileOptions keep_nulls_min_count(/*q=*/0.5, QuantileOptions::LINEAR,
+ /*skip_nulls=*/false, /*min_count=*/3);
+ auto not_empty = ResultWith(ArrayFromJSON(float64(), "[3.0]"));
+ auto empty = ResultWith(ArrayFromJSON(float64(), "[null]"));
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5]"), keep_nulls), not_empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5, null]"), keep_nulls), empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5]"), keep_nulls), not_empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5, null]"), keep_nulls), empty);
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "3"), keep_nulls), not_empty);
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "null"), keep_nulls), empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5]"), min_count), not_empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5, null]"), min_count), not_empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5]"), min_count), empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5, null]"), min_count), empty);
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "3"), min_count), empty);
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "null"), min_count), empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5]"), keep_nulls_min_count),
+ not_empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 2, 4, 5, null]"), keep_nulls_min_count),
+ empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5]"), keep_nulls_min_count), empty);
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1, 5, null]"), keep_nulls_min_count), empty);
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "3"), keep_nulls_min_count), empty);
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "null"), keep_nulls_min_count), empty);
+
+ for (const auto interpolation : this->interpolations_) {
+ QuantileOptions options({0.0, 0.5, 1.0}, interpolation);
+ auto expected_ty = (interpolation == QuantileOptions::LINEAR ||
+ interpolation == QuantileOptions::MIDPOINT)
+ ? float64()
+ : ty;
+ EXPECT_THAT(Quantile(*MakeScalar(ty, 1), options),
+ ResultWith(ArrayFromJSON(expected_ty, "[1, 1, 1]")));
+ EXPECT_THAT(Quantile(MakeNullScalar(ty), options),
+ ResultWith(ArrayFromJSON(expected_ty, "[null, null, null]")));
+ }
+}
+
+class TestInt8QuantileKernel : public TestPrimitiveQuantileKernel<Int8Type> {};
+
+// Test histogram approach
+TEST_F(TestInt8QuantileKernel, Int8) {
+ using TypeParam = Int8Type;
+ this->AssertQuantilesAre(
+ "[127, -128, null, -128, 66, -88, 127]", {0, 0.3, 0.7, 1},
+ {O(-128, -128, -128, -128, -128), O(-108, -128, -88, -88, -108),
+ O(96.5, 66, 127, 127, 96.5), O(127, 127, 127, 127, 127)});
+ this->AssertQuantilesAre(
+ {"[null]", "[-88, 127]", "[]", "[66, -128, null, -128]", "[127]"}, {0, 0.3, 0.7, 1},
+ {O(-128, -128, -128, -128, -128), O(-108, -128, -88, -88, -108),
+ O(96.5, 66, 127, 127, 96.5), O(127, 127, 127, 127, 127)});
+}
+#endif
+
+class TestInt64QuantileKernel : public TestPrimitiveQuantileKernel<Int64Type> {};
+
+// Test big int64 numbers cannot be precisely presented by double
+TEST_F(TestInt64QuantileKernel, Int64) {
+ using TypeParam = Int64Type;
+ this->AssertQuantileIs(
+ "[9223372036854775806, 9223372036854775807]", 0.5,
+ O(9.223372036854776e+18, 9223372036854775806, 9223372036854775807,
+ 9223372036854775806, 9.223372036854776e+18));
+}
+
+#undef INTYPE
+#undef DOUBLE
+#undef O
+
+#ifndef __MINGW32__
+template <typename ArrowType>
+class TestRandomQuantileKernel : public TestPrimitiveQuantileKernel<ArrowType> {
+ using CType = typename ArrowType::c_type;
+
+ public:
+ void CheckQuantiles(int64_t array_size, int64_t num_quantiles) {
+ std::shared_ptr<Array> array;
+ std::vector<double> quantiles;
+ // small value range to exercise input array with equal values and histogram approach
+ GenerateTestData(array_size, num_quantiles, -100, 200, &array, &quantiles);
+
+ this->AssertQuantilesAre(array, QuantileOptions{quantiles},
+ NaiveQuantile(array, quantiles, this->interpolations_));
+ }
+
+ void CheckQuantilesSliced(int64_t array_size, int64_t num_quantiles) {
+ std::shared_ptr<Array> array;
+ std::vector<double> quantiles;
+ GenerateTestData(array_size, num_quantiles, -100, 200, &array, &quantiles);
+
+ const std::vector<std::array<int64_t, 2>> offset_size{
+ {0, array_size - 1},
+ {1, array_size - 1},
+ {array_size / 3, array_size / 2},
+ {array_size * 9 / 10, array_size / 10},
+ };
+ for (const auto& os : offset_size) {
+ auto sliced = array->Slice(os[0], os[1]);
+ this->AssertQuantilesAre(sliced, QuantileOptions{quantiles},
+ NaiveQuantile(sliced, quantiles, this->interpolations_));
+ }
+ }
+
+ void CheckTDigests(const std::vector<int>& chunk_sizes, int64_t num_quantiles) {
+ std::shared_ptr<ChunkedArray> chunked;
+ std::vector<double> quantiles;
+ GenerateChunked(chunk_sizes, num_quantiles, &chunked, &quantiles);
+
+ VerifyTDigest(chunked, quantiles);
+ }
+
+ void CheckTDigestsSliced(const std::vector<int>& chunk_sizes, int64_t num_quantiles) {
+ std::shared_ptr<ChunkedArray> chunked;
+ std::vector<double> quantiles;
+ GenerateChunked(chunk_sizes, num_quantiles, &chunked, &quantiles);
+
+ const int64_t size = chunked->length();
+ const std::vector<std::array<int64_t, 2>> offset_size{
+ {0, size - 1},
+ {1, size - 1},
+ {size / 3, size / 2},
+ {size * 9 / 10, size / 10},
+ };
+ for (const auto& os : offset_size) {
+ VerifyTDigest(chunked->Slice(os[0], os[1]), quantiles);
+ }
+ }
+
+ private:
+ void GenerateTestData(int64_t array_size, int64_t num_quantiles, int min, int max,
+ std::shared_ptr<Array>* array, std::vector<double>* quantiles) {
+ auto rand = random::RandomArrayGenerator(0x5487658);
+ if (is_floating_type<ArrowType>::value) {
+ *array = rand.Float64(array_size, min, max, /*null_prob=*/0.1, /*nan_prob=*/0.2);
+ } else {
+ *array = rand.Int64(array_size, min, max, /*null_prob=*/0.1);
+ }
+
+ random_real(num_quantiles, 0x5487658, 0.0, 1.0, quantiles);
+ // make sure to exercise 0 and 1 quantiles
+ *std::min_element(quantiles->begin(), quantiles->end()) = 0;
+ *std::max_element(quantiles->begin(), quantiles->end()) = 1;
+ }
+
+ void GenerateChunked(const std::vector<int>& chunk_sizes, int64_t num_quantiles,
+ std::shared_ptr<ChunkedArray>* chunked,
+ std::vector<double>* quantiles) {
+ int total_size = 0;
+ for (int size : chunk_sizes) {
+ total_size += size;
+ }
+ std::shared_ptr<Array> array;
+ GenerateTestData(total_size, num_quantiles, 100, 123456789, &array, quantiles);
+
+ total_size = 0;
+ ArrayVector array_vector;
+ for (int size : chunk_sizes) {
+ array_vector.emplace_back(array->Slice(total_size, size));
+ total_size += size;
+ }
+ *chunked = ChunkedArray::Make(array_vector).ValueOrDie();
+ }
+
+ void VerifyTDigest(const std::shared_ptr<ChunkedArray>& chunked,
+ std::vector<double>& quantiles) {
+ TDigestOptions options(quantiles);
+ ASSERT_OK_AND_ASSIGN(Datum out, TDigest(chunked, options));
+ const auto& out_array = out.make_array();
+ ValidateOutput(*out_array);
+ ASSERT_EQ(out_array->length(), quantiles.size());
+ ASSERT_EQ(out_array->null_count(), 0);
+ AssertTypeEqual(out_array->type(), float64());
+
+ // linear interpolated exact quantile as reference
+ std::vector<std::vector<Datum>> exact =
+ NaiveQuantile(*chunked, quantiles, {QuantileOptions::LINEAR});
+ const double* approx = out_array->data()->GetValues<double>(1);
+ for (size_t i = 0; i < quantiles.size(); ++i) {
+ const auto& exact_scalar = checked_pointer_cast<DoubleScalar>(exact[i][0].scalar());
+ const double tolerance = std::fabs(exact_scalar->value) * 0.05;
+ EXPECT_NEAR(approx[i], exact_scalar->value, tolerance) << quantiles[i];
+ }
+ }
+
+ std::vector<std::vector<Datum>> NaiveQuantile(
+ const std::shared_ptr<Array>& array, const std::vector<double>& quantiles,
+ const std::vector<enum QuantileOptions::Interpolation>& interpolations) {
+ return NaiveQuantile(ChunkedArray(array), quantiles, interpolations);
+ }
+
+ std::vector<std::vector<Datum>> NaiveQuantile(
+ const ChunkedArray& chunked, const std::vector<double>& quantiles,
+ const std::vector<enum QuantileOptions::Interpolation>& interpolations) {
+ // copy and sort input chunked array
+ int64_t index = 0;
+ std::vector<CType> input(chunked.length() - chunked.null_count());
+ for (const auto& array : chunked.chunks()) {
+ const CType* values = array->data()->GetValues<CType>(1);
+ const auto bitmap = array->null_bitmap_data();
+ for (int64_t i = 0; i < array->length(); ++i) {
+ if ((!bitmap || BitUtil::GetBit(bitmap, array->data()->offset + i)) &&
+ !std::isnan(static_cast<double>(values[i]))) {
+ input[index++] = values[i];
+ }
+ }
+ }
+ input.resize(index);
+ std::sort(input.begin(), input.end());
+
+ std::vector<std::vector<Datum>> output(quantiles.size(),
+ std::vector<Datum>(interpolations.size()));
+ for (uint64_t i = 0; i < interpolations.size(); ++i) {
+ const auto interp = interpolations[i];
+ for (uint64_t j = 0; j < quantiles.size(); ++j) {
+ output[j][i] = GetQuantile(input, quantiles[j], interp);
+ }
+ }
+ return output;
+ }
+
+ Datum GetQuantile(const std::vector<CType>& input, double q,
+ enum QuantileOptions::Interpolation interp) {
+ const double index = (input.size() - 1) * q;
+ const uint64_t lower_index = static_cast<uint64_t>(index);
+ const double fraction = index - lower_index;
+
+ switch (interp) {
+ case QuantileOptions::LOWER:
+ return Datum(input[lower_index]);
+ case QuantileOptions::HIGHER:
+ return Datum(input[lower_index + (fraction != 0)]);
+ case QuantileOptions::NEAREST:
+ if (fraction < 0.5) {
+ return Datum(input[lower_index]);
+ } else if (fraction > 0.5) {
+ return Datum(input[lower_index + 1]);
+ } else {
+ return Datum(input[lower_index + (lower_index & 1)]);
+ }
+ case QuantileOptions::LINEAR:
+ if (fraction == 0) {
+ return Datum(input[lower_index] * 1.0);
+ } else {
+ return Datum(fraction * input[lower_index + 1] +
+ (1 - fraction) * input[lower_index]);
+ }
+ case QuantileOptions::MIDPOINT:
+ if (fraction == 0) {
+ return Datum(input[lower_index] * 1.0);
+ } else {
+ return Datum(input[lower_index] / 2.0 + input[lower_index + 1] / 2.0);
+ }
+ default:
+ return Datum(NAN);
+ }
+ }
+};
+
+class TestRandomInt64QuantileKernel : public TestRandomQuantileKernel<Int64Type> {};
+
+TEST_F(TestRandomInt64QuantileKernel, Normal) {
+ // exercise copy and sort approach: size < 65536
+ this->CheckQuantiles(/*array_size=*/10000, /*num_quantiles=*/100);
+}
+
+TEST_F(TestRandomInt64QuantileKernel, Overlapped) {
+ // much more quantiles than array size => many overlaps
+ this->CheckQuantiles(/*array_size=*/999, /*num_quantiles=*/9999);
+}
+
+TEST_F(TestRandomInt64QuantileKernel, Histogram) {
+ // exercise histogram approach: size >= 65536, range <= 65536
+ this->CheckQuantiles(/*array_size=*/80000, /*num_quantiles=*/100);
+}
+
+TEST_F(TestRandomInt64QuantileKernel, Sliced) {
+ this->CheckQuantilesSliced(1000, 10); // sort
+ this->CheckQuantilesSliced(66000, 10); // count
+}
+
+class TestRandomFloatQuantileKernel : public TestRandomQuantileKernel<DoubleType> {};
+
+TEST_F(TestRandomFloatQuantileKernel, Exact) {
+ this->CheckQuantiles(/*array_size=*/1000, /*num_quantiles=*/100);
+}
+
+TEST_F(TestRandomFloatQuantileKernel, TDigest) {
+ this->CheckTDigests(/*chunk_sizes=*/{12345, 6789, 8765, 4321}, /*num_quantiles=*/100);
+}
+
+TEST_F(TestRandomFloatQuantileKernel, Sliced) {
+ this->CheckQuantilesSliced(1000, 10);
+ this->CheckTDigestsSliced({200, 600}, 10);
+}
+#endif
+
+TEST(TestQuantileKernel, AllNullsOrNaNs) {
+ const std::vector<std::vector<std::string>> tests = {
+ {"[]"},
+ {"[null, null]", "[]", "[null]"},
+ {"[NaN]", "[NaN, NaN]", "[]"},
+ {"[null, NaN, null]"},
+ {"[NaN, NaN]", "[]", "[null]"},
+ };
+
+ for (const auto& json : tests) {
+ auto chunked = ChunkedArrayFromJSON(float64(), json);
+ ASSERT_OK_AND_ASSIGN(Datum out, Quantile(chunked, QuantileOptions()));
+ auto out_array = out.make_array();
+ ValidateOutput(*out_array);
+ AssertArraysEqual(*ArrayFromJSON(float64(), "[null]"), *out_array, /*verbose=*/true);
+ }
+}
+
+TEST(TestQuantileKernel, Scalar) {
+ for (const auto& ty : {float64(), int64(), uint64()}) {
+ QuantileOptions options(std::vector<double>{0.0, 0.5, 1.0});
+ EXPECT_THAT(Quantile(*MakeScalar(ty, 1), options),
+ ResultWith(ArrayFromJSON(float64(), "[1.0, 1.0, 1.0]")));
+ EXPECT_THAT(Quantile(MakeNullScalar(ty), options),
+ ResultWith(ArrayFromJSON(float64(), "[null, null, null]")));
+ }
+}
+
+TEST(TestQuantileKernel, Options) {
+ auto ty = float64();
+ QuantileOptions keep_nulls(/*q=*/0.5, QuantileOptions::LINEAR,
+ /*skip_nulls=*/false, /*min_count=*/0);
+ QuantileOptions min_count(/*q=*/0.5, QuantileOptions::LINEAR,
+ /*skip_nulls=*/true, /*min_count=*/3);
+ QuantileOptions keep_nulls_min_count(/*q=*/0.5, QuantileOptions::NEAREST,
+ /*skip_nulls=*/false, /*min_count=*/3);
+
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1.0, 2.0, 3.0]"), keep_nulls),
+ ResultWith(ArrayFromJSON(ty, "[2.0]")));
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1.0, 2.0, 3.0, null]"), keep_nulls),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "1.0"), keep_nulls),
+ ResultWith(ArrayFromJSON(ty, "[1.0]")));
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "null"), keep_nulls),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1.0, 2.0, 3.0, null]"), min_count),
+ ResultWith(ArrayFromJSON(ty, "[2.0]")));
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1.0, 2.0, null]"), min_count),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "1.0"), min_count),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "null"), min_count),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1.0, 2.0, 3.0]"), keep_nulls_min_count),
+ ResultWith(ArrayFromJSON(ty, "[2.0]")));
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1.0, 2.0]"), keep_nulls_min_count),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+ EXPECT_THAT(Quantile(ArrayFromJSON(ty, "[1.0, 2.0, 3.0, null]"), keep_nulls_min_count),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "1.0"), keep_nulls_min_count),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+ EXPECT_THAT(Quantile(ScalarFromJSON(ty, "null"), keep_nulls_min_count),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+}
+
+TEST(TestTDigestKernel, AllNullsOrNaNs) {
+ const std::vector<std::vector<std::string>> tests = {
+ {"[]"},
+ {"[null, null]", "[]", "[null]"},
+ {"[NaN]", "[NaN, NaN]", "[]"},
+ {"[null, NaN, null]"},
+ {"[NaN, NaN]", "[]", "[null]"},
+ };
+
+ for (const auto& json : tests) {
+ auto chunked = ChunkedArrayFromJSON(float64(), json);
+ ASSERT_OK_AND_ASSIGN(Datum out, TDigest(chunked, TDigestOptions()));
+ auto out_array = out.make_array();
+ ValidateOutput(*out_array);
+ AssertArraysEqual(*ArrayFromJSON(float64(), "[null]"), *out_array, /*verbose=*/true);
+ }
+}
+
+TEST(TestTDigestKernel, Scalar) {
+ for (const auto& ty : {float64(), int64(), uint64()}) {
+ TDigestOptions options(std::vector<double>{0.0, 0.5, 1.0});
+ EXPECT_THAT(TDigest(*MakeScalar(ty, 1), options),
+ ResultWith(ArrayFromJSON(float64(), "[1, 1, 1]")));
+ EXPECT_THAT(TDigest(MakeNullScalar(ty), options),
+ ResultWith(ArrayFromJSON(float64(), "[null, null, null]")));
+ }
+}
+
+TEST(TestTDigestKernel, Options) {
+ auto ty = float64();
+ TDigestOptions keep_nulls(/*q=*/0.5, /*delta=*/100, /*buffer_size=*/500,
+ /*skip_nulls=*/false, /*min_count=*/0);
+ TDigestOptions min_count(/*q=*/0.5, /*delta=*/100, /*buffer_size=*/500,
+ /*skip_nulls=*/true, /*min_count=*/3);
+ TDigestOptions keep_nulls_min_count(/*q=*/0.5, /*delta=*/100, /*buffer_size=*/500,
+ /*skip_nulls=*/false, /*min_count=*/3);
+
+ EXPECT_THAT(TDigest(ArrayFromJSON(ty, "[1.0, 2.0, 3.0]"), keep_nulls),
+ ResultWith(ArrayFromJSON(ty, "[2.0]")));
+ EXPECT_THAT(TDigest(ArrayFromJSON(ty, "[1.0, 2.0, 3.0, null]"), keep_nulls),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+ EXPECT_THAT(TDigest(ScalarFromJSON(ty, "1.0"), keep_nulls),
+ ResultWith(ArrayFromJSON(ty, "[1.0]")));
+ EXPECT_THAT(TDigest(ScalarFromJSON(ty, "null"), keep_nulls),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+
+ EXPECT_THAT(TDigest(ArrayFromJSON(ty, "[1.0, 2.0, 3.0, null]"), min_count),
+ ResultWith(ArrayFromJSON(ty, "[2.0]")));
+ EXPECT_THAT(TDigest(ArrayFromJSON(ty, "[1.0, 2.0, null]"), min_count),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+ EXPECT_THAT(TDigest(ScalarFromJSON(ty, "1.0"), min_count),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+ EXPECT_THAT(TDigest(ScalarFromJSON(ty, "null"), min_count),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+
+ EXPECT_THAT(TDigest(ArrayFromJSON(ty, "[1.0, 2.0, 3.0]"), keep_nulls_min_count),
+ ResultWith(ArrayFromJSON(ty, "[2.0]")));
+ EXPECT_THAT(TDigest(ArrayFromJSON(ty, "[1.0, 2.0]"), keep_nulls_min_count),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+ EXPECT_THAT(TDigest(ArrayFromJSON(ty, "[1.0, 2.0, 3.0, null]"), keep_nulls_min_count),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+ EXPECT_THAT(TDigest(ScalarFromJSON(ty, "1.0"), keep_nulls_min_count),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+ EXPECT_THAT(TDigest(ScalarFromJSON(ty, "null"), keep_nulls_min_count),
+ ResultWith(ArrayFromJSON(ty, "[null]")));
+}
+
+TEST(TestTDigestKernel, ApproximateMedian) {
+ // This is a wrapper for TDigest
+ for (const auto& ty : {float64(), int64(), uint16()}) {
+ ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false, /*min_count=*/0);
+ ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/3);
+ ScalarAggregateOptions keep_nulls_min_count(/*skip_nulls=*/false, /*min_count=*/3);
+
+ EXPECT_THAT(
+ CallFunction("approximate_median", {ArrayFromJSON(ty, "[1, 2, 3]")}, &keep_nulls),
+ ResultWith(ScalarFromJSON(float64(), "2.0")));
+ EXPECT_THAT(CallFunction("approximate_median", {ArrayFromJSON(ty, "[1, 2, 3, null]")},
+ &keep_nulls),
+ ResultWith(ScalarFromJSON(float64(), "null")));
+ EXPECT_THAT(
+ CallFunction("approximate_median",
+ {ChunkedArrayFromJSON(ty, {"[1, 2]", "[]", "[3]"})}, &keep_nulls),
+ ResultWith(ScalarFromJSON(float64(), "2.0")));
+ EXPECT_THAT(CallFunction("approximate_median",
+ {ChunkedArrayFromJSON(ty, {"[1, 2]", "[null]", "[3]"})},
+ &keep_nulls),
+ ResultWith(ScalarFromJSON(float64(), "null")));
+ EXPECT_THAT(
+ CallFunction("approximate_median", {ScalarFromJSON(ty, "1")}, &keep_nulls),
+ ResultWith(ScalarFromJSON(float64(), "1.0")));
+ EXPECT_THAT(
+ CallFunction("approximate_median", {ScalarFromJSON(ty, "null")}, &keep_nulls),
+ ResultWith(ScalarFromJSON(float64(), "null")));
+
+ EXPECT_THAT(CallFunction("approximate_median", {ArrayFromJSON(ty, "[1, 2, 3, null]")},
+ &min_count),
+ ResultWith(ScalarFromJSON(float64(), "2.0")));
+ EXPECT_THAT(CallFunction("approximate_median", {ArrayFromJSON(ty, "[1, 2, null]")},
+ &min_count),
+ ResultWith(ScalarFromJSON(float64(), "null")));
+ EXPECT_THAT(
+ CallFunction("approximate_median",
+ {ChunkedArrayFromJSON(ty, {"[1, 2]", "[]", "[3]"})}, &keep_nulls),
+ ResultWith(ScalarFromJSON(float64(), "2.0")));
+ EXPECT_THAT(CallFunction("approximate_median",
+ {ChunkedArrayFromJSON(ty, {"[1, 2]", "[null]", "[3]"})},
+ &keep_nulls),
+ ResultWith(ScalarFromJSON(float64(), "null")));
+ EXPECT_THAT(CallFunction("approximate_median", {ScalarFromJSON(ty, "1")}, &min_count),
+ ResultWith(ScalarFromJSON(float64(), "null")));
+ EXPECT_THAT(
+ CallFunction("approximate_median", {ScalarFromJSON(ty, "null")}, &min_count),
+ ResultWith(ScalarFromJSON(float64(), "null")));
+
+ EXPECT_THAT(CallFunction("approximate_median", {ArrayFromJSON(ty, "[1, 2, 3]")},
+ &keep_nulls_min_count),
+ ResultWith(ScalarFromJSON(float64(), "2.0")));
+ EXPECT_THAT(CallFunction("approximate_median", {ArrayFromJSON(ty, "[1, 2]")},
+ &keep_nulls_min_count),
+ ResultWith(ScalarFromJSON(float64(), "null")));
+ EXPECT_THAT(CallFunction("approximate_median",
+ {ChunkedArrayFromJSON(ty, {"[1, 2]", "[]", "[3]"})},
+ &keep_nulls_min_count),
+ ResultWith(ScalarFromJSON(float64(), "2.0")));
+ EXPECT_THAT(CallFunction("approximate_median",
+ {ChunkedArrayFromJSON(ty, {"[1, 2]", "[null]", "[3]"})},
+ &keep_nulls_min_count),
+ ResultWith(ScalarFromJSON(float64(), "null")));
+ EXPECT_THAT(CallFunction("approximate_median", {ArrayFromJSON(ty, "[1, 2, 3, null]")},
+ &keep_nulls_min_count),
+ ResultWith(ScalarFromJSON(float64(), "null")));
+ EXPECT_THAT(CallFunction("approximate_median", {ScalarFromJSON(ty, "1")},
+ &keep_nulls_min_count),
+ ResultWith(ScalarFromJSON(float64(), "null")));
+ EXPECT_THAT(CallFunction("approximate_median", {ScalarFromJSON(ty, "null")},
+ &keep_nulls_min_count),
+ ResultWith(ScalarFromJSON(float64(), "null")));
+ }
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/aggregate_var_std.cc b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_var_std.cc
new file mode 100644
index 000000000..d0d3c514f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_var_std.cc
@@ -0,0 +1,298 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cmath>
+
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/kernels/aggregate_internal.h"
+#include "arrow/compute/kernels/aggregate_var_std_internal.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/int128_internal.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+namespace {
+
+using arrow::internal::int128_t;
+using arrow::internal::VisitSetBitRunsVoid;
+
+template <typename ArrowType>
+struct VarStdState {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+ using CType = typename ArrowType::c_type;
+ using ThisType = VarStdState<ArrowType>;
+
+ explicit VarStdState(VarianceOptions options) : options(options) {}
+
+ // float/double/int64: calculate `m2` (sum((X-mean)^2)) with `two pass algorithm`
+ // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm
+ template <typename T = ArrowType>
+ enable_if_t<is_floating_type<T>::value || (sizeof(CType) > 4)> Consume(
+ const ArrayType& array) {
+ this->all_valid = array.null_count() == 0;
+ int64_t count = array.length() - array.null_count();
+ if (count == 0 || (!this->all_valid && !options.skip_nulls)) {
+ return;
+ }
+
+ using SumType =
+ typename std::conditional<is_floating_type<T>::value, double, int128_t>::type;
+ SumType sum = SumArray<CType, SumType, SimdLevel::NONE>(*array.data());
+
+ const double mean = static_cast<double>(sum) / count;
+ const double m2 =
+ SumArray<CType, double, SimdLevel::NONE>(*array.data(), [mean](CType value) {
+ const double v = static_cast<double>(value);
+ return (v - mean) * (v - mean);
+ });
+
+ this->count = count;
+ this->mean = mean;
+ this->m2 = m2;
+ }
+
+ // int32/16/8: textbook one pass algorithm with integer arithmetic
+ template <typename T = ArrowType>
+ enable_if_t<is_integer_type<T>::value && (sizeof(CType) <= 4)> Consume(
+ const ArrayType& array) {
+ // max number of elements that sum will not overflow int64 (2Gi int32 elements)
+ // for uint32: 0 <= sum < 2^63 (int64 >= 0)
+ // for int32: -2^62 <= sum < 2^62
+ constexpr int64_t max_length = 1ULL << (63 - sizeof(CType) * 8);
+
+ this->all_valid = array.null_count() == 0;
+ if (!this->all_valid && !options.skip_nulls) return;
+ int64_t start_index = 0;
+ int64_t valid_count = array.length() - array.null_count();
+
+ while (valid_count > 0) {
+ // process in chunks that overflow will never happen
+ const auto slice = array.Slice(start_index, max_length);
+ const int64_t count = slice->length() - slice->null_count();
+ start_index += max_length;
+ valid_count -= count;
+
+ if (count > 0) {
+ IntegerVarStd<ArrowType> var_std;
+ const ArrayData& data = *slice->data();
+ const CType* values = data.GetValues<CType>(1);
+ VisitSetBitRunsVoid(data.buffers[0], data.offset, data.length,
+ [&](int64_t pos, int64_t len) {
+ for (int64_t i = 0; i < len; ++i) {
+ const auto value = values[pos + i];
+ var_std.ConsumeOne(value);
+ }
+ });
+
+ // merge variance
+ ThisType state(options);
+ state.count = var_std.count;
+ state.mean = var_std.mean();
+ state.m2 = var_std.m2();
+ this->MergeFrom(state);
+ }
+ }
+ }
+
+ // Scalar: textbook algorithm
+ void Consume(const Scalar& scalar, const int64_t count) {
+ this->m2 = 0;
+ if (scalar.is_valid) {
+ this->count = count;
+ this->mean = static_cast<double>(UnboxScalar<ArrowType>::Unbox(scalar));
+ } else {
+ this->count = 0;
+ this->mean = 0;
+ this->all_valid = false;
+ }
+ }
+
+ // Combine `m2` from two chunks (m2 = n*s2)
+ // https://www.emathzone.com/tutorials/basic-statistics/combined-variance.html
+ void MergeFrom(const ThisType& state) {
+ this->all_valid = this->all_valid && state.all_valid;
+ if (state.count == 0) {
+ return;
+ }
+ if (this->count == 0) {
+ this->count = state.count;
+ this->mean = state.mean;
+ this->m2 = state.m2;
+ return;
+ }
+ MergeVarStd(this->count, this->mean, state.count, state.mean, state.m2, &this->count,
+ &this->mean, &this->m2);
+ }
+
+ const VarianceOptions options;
+ int64_t count = 0;
+ double mean = 0;
+ double m2 = 0; // m2 = count*s2 = sum((X-mean)^2)
+ bool all_valid = true;
+};
+
+template <typename ArrowType>
+struct VarStdImpl : public ScalarAggregator {
+ using ThisType = VarStdImpl<ArrowType>;
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+ explicit VarStdImpl(const std::shared_ptr<DataType>& out_type,
+ const VarianceOptions& options, VarOrStd return_type)
+ : out_type(out_type), state(options), return_type(return_type) {}
+
+ Status Consume(KernelContext*, const ExecBatch& batch) override {
+ if (batch[0].is_array()) {
+ ArrayType array(batch[0].array());
+ this->state.Consume(array);
+ } else {
+ this->state.Consume(*batch[0].scalar(), batch.length);
+ }
+ return Status::OK();
+ }
+
+ Status MergeFrom(KernelContext*, KernelState&& src) override {
+ const auto& other = checked_cast<const ThisType&>(src);
+ this->state.MergeFrom(other.state);
+ return Status::OK();
+ }
+
+ Status Finalize(KernelContext*, Datum* out) override {
+ if (state.count <= state.options.ddof || state.count < state.options.min_count ||
+ (!state.all_valid && !state.options.skip_nulls)) {
+ out->value = std::make_shared<DoubleScalar>();
+ } else {
+ double var = state.m2 / (state.count - state.options.ddof);
+ out->value =
+ std::make_shared<DoubleScalar>(return_type == VarOrStd::Var ? var : sqrt(var));
+ }
+ return Status::OK();
+ }
+
+ std::shared_ptr<DataType> out_type;
+ VarStdState<ArrowType> state;
+ VarOrStd return_type;
+};
+
+struct VarStdInitState {
+ std::unique_ptr<KernelState> state;
+ KernelContext* ctx;
+ const DataType& in_type;
+ const std::shared_ptr<DataType>& out_type;
+ const VarianceOptions& options;
+ VarOrStd return_type;
+
+ VarStdInitState(KernelContext* ctx, const DataType& in_type,
+ const std::shared_ptr<DataType>& out_type,
+ const VarianceOptions& options, VarOrStd return_type)
+ : ctx(ctx),
+ in_type(in_type),
+ out_type(out_type),
+ options(options),
+ return_type(return_type) {}
+
+ Status Visit(const DataType&) {
+ return Status::NotImplemented("No variance/stddev implemented");
+ }
+
+ Status Visit(const HalfFloatType&) {
+ return Status::NotImplemented("No variance/stddev implemented");
+ }
+
+ template <typename Type>
+ enable_if_t<is_number_type<Type>::value, Status> Visit(const Type&) {
+ state.reset(new VarStdImpl<Type>(out_type, options, return_type));
+ return Status::OK();
+ }
+
+ Result<std::unique_ptr<KernelState>> Create() {
+ RETURN_NOT_OK(VisitTypeInline(in_type, this));
+ return std::move(state);
+ }
+};
+
+Result<std::unique_ptr<KernelState>> StddevInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ VarStdInitState visitor(
+ ctx, *args.inputs[0].type, args.kernel->signature->out_type().type(),
+ static_cast<const VarianceOptions&>(*args.options), VarOrStd::Std);
+ return visitor.Create();
+}
+
+Result<std::unique_ptr<KernelState>> VarianceInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ VarStdInitState visitor(
+ ctx, *args.inputs[0].type, args.kernel->signature->out_type().type(),
+ static_cast<const VarianceOptions&>(*args.options), VarOrStd::Var);
+ return visitor.Create();
+}
+
+void AddVarStdKernels(KernelInit init,
+ const std::vector<std::shared_ptr<DataType>>& types,
+ ScalarAggregateFunction* func) {
+ for (const auto& ty : types) {
+ auto sig = KernelSignature::Make({InputType(ty)}, float64());
+ AddAggKernel(std::move(sig), init, func);
+ }
+}
+
+const FunctionDoc stddev_doc{
+ "Calculate the standard deviation of a numeric array",
+ ("The number of degrees of freedom can be controlled using VarianceOptions.\n"
+ "By default (`ddof` = 0), the population standard deviation is calculated.\n"
+ "Nulls are ignored. If there are not enough non-null values in the array\n"
+ "to satisfy `ddof`, null is returned."),
+ {"array"},
+ "VarianceOptions"};
+
+const FunctionDoc variance_doc{
+ "Calculate the variance of a numeric array",
+ ("The number of degrees of freedom can be controlled using VarianceOptions.\n"
+ "By default (`ddof` = 0), the population variance is calculated.\n"
+ "Nulls are ignored. If there are not enough non-null values in the array\n"
+ "to satisfy `ddof`, null is returned."),
+ {"array"},
+ "VarianceOptions"};
+
+std::shared_ptr<ScalarAggregateFunction> AddStddevAggKernels() {
+ static auto default_std_options = VarianceOptions::Defaults();
+ auto func = std::make_shared<ScalarAggregateFunction>(
+ "stddev", Arity::Unary(), &stddev_doc, &default_std_options);
+ AddVarStdKernels(StddevInit, NumericTypes(), func.get());
+ return func;
+}
+
+std::shared_ptr<ScalarAggregateFunction> AddVarianceAggKernels() {
+ static auto default_var_options = VarianceOptions::Defaults();
+ auto func = std::make_shared<ScalarAggregateFunction>(
+ "variance", Arity::Unary(), &variance_doc, &default_var_options);
+ AddVarStdKernels(VarianceInit, NumericTypes(), func.get());
+ return func;
+}
+
+} // namespace
+
+void RegisterScalarAggregateVariance(FunctionRegistry* registry) {
+ DCHECK_OK(registry->AddFunction(AddVarianceAggKernels()));
+ DCHECK_OK(registry->AddFunction(AddStddevAggKernels()));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/aggregate_var_std_internal.h b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_var_std_internal.h
new file mode 100644
index 000000000..675ebfd91
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_var_std_internal.h
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/util/int128_internal.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+using arrow::internal::int128_t;
+
+// Accumulate sum/squared sum (using naive summation)
+// Shared implementation between scalar/hash aggregate variance/stddev kernels
+template <typename ArrowType>
+struct IntegerVarStd {
+ using c_type = typename ArrowType::c_type;
+
+ int64_t count = 0;
+ int64_t sum = 0;
+ int128_t square_sum = 0;
+
+ void ConsumeOne(const c_type value) {
+ sum += value;
+ square_sum += static_cast<uint64_t>(value) * value;
+ count++;
+ }
+
+ double mean() const { return static_cast<double>(sum) / count; }
+
+ double m2() const {
+ // calculate m2 = square_sum - sum * sum / count
+ // decompose `sum * sum / count` into integers and fractions
+ const int128_t sum_square = static_cast<int128_t>(sum) * sum;
+ const int128_t integers = sum_square / count;
+ const double fractions = static_cast<double>(sum_square % count) / count;
+ return static_cast<double>(square_sum - integers) - fractions;
+ }
+};
+
+static inline void MergeVarStd(int64_t count1, double mean1, int64_t count2, double mean2,
+ double m22, int64_t* out_count, double* out_mean,
+ double* out_m2) {
+ double mean = (mean1 * count1 + mean2 * count2) / (count1 + count2);
+ *out_m2 += m22 + count1 * (mean1 - mean) * (mean1 - mean) +
+ count2 * (mean2 - mean) * (mean2 - mean);
+ *out_count += count2;
+ *out_mean = mean;
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/chunked_internal.h b/src/arrow/cpp/src/arrow/compute/kernels/chunked_internal.h
new file mode 100644
index 000000000..b007d6cbf
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/chunked_internal.h
@@ -0,0 +1,167 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/record_batch.h"
+#include "arrow/type.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+// The target chunk in a chunked array.
+template <typename ArrayType>
+struct ResolvedChunk {
+ using V = GetViewType<typename ArrayType::TypeClass>;
+ using LogicalValueType = typename V::T;
+
+ // The target array in chunked array.
+ const ArrayType* array;
+ // The index in the target array.
+ const int64_t index;
+
+ ResolvedChunk(const ArrayType* array, int64_t index) : array(array), index(index) {}
+
+ bool IsNull() const { return array->IsNull(index); }
+
+ LogicalValueType Value() const { return V::LogicalValue(array->GetView(index)); }
+};
+
+// ResolvedChunk specialization for untyped arrays when all is needed is null lookup
+template <>
+struct ResolvedChunk<Array> {
+ // The target array in chunked array.
+ const Array* array;
+ // The index in the target array.
+ const int64_t index;
+
+ ResolvedChunk(const Array* array, int64_t index) : array(array), index(index) {}
+
+ bool IsNull() const { return array->IsNull(index); }
+};
+
+struct ChunkLocation {
+ int64_t chunk_index, index_in_chunk;
+};
+
+// An object that resolves an array chunk depending on the index.
+struct ChunkResolver {
+ explicit ChunkResolver(std::vector<int64_t> lengths)
+ : num_chunks_(static_cast<int64_t>(lengths.size())),
+ offsets_(MakeEndOffsets(std::move(lengths))),
+ cached_chunk_(0) {}
+
+ ChunkLocation Resolve(int64_t index) const {
+ // It is common for the algorithms below to make consecutive accesses at
+ // a relatively small distance from each other, hence often falling in
+ // the same chunk.
+ // This is trivial when merging (assuming each side of the merge uses
+ // its own resolver), but also in the inner recursive invocations of
+ // partitioning.
+ const bool cache_hit =
+ (index >= offsets_[cached_chunk_] && index < offsets_[cached_chunk_ + 1]);
+ if (ARROW_PREDICT_TRUE(cache_hit)) {
+ return {cached_chunk_, index - offsets_[cached_chunk_]};
+ } else {
+ return ResolveMissBisect(index);
+ }
+ }
+
+ static ChunkResolver FromBatches(const RecordBatchVector& batches) {
+ std::vector<int64_t> lengths(batches.size());
+ std::transform(
+ batches.begin(), batches.end(), lengths.begin(),
+ [](const std::shared_ptr<RecordBatch>& batch) { return batch->num_rows(); });
+ return ChunkResolver(std::move(lengths));
+ }
+
+ protected:
+ ChunkLocation ResolveMissBisect(int64_t index) const {
+ // Like std::upper_bound(), but hand-written as it can help the compiler.
+ const int64_t* raw_offsets = offsets_.data();
+ // Search [lo, lo + n)
+ int64_t lo = 0, n = num_chunks_;
+ while (n > 1) {
+ int64_t m = n >> 1;
+ int64_t mid = lo + m;
+ if (index >= raw_offsets[mid]) {
+ lo = mid;
+ n -= m;
+ } else {
+ n = m;
+ }
+ }
+ cached_chunk_ = lo;
+ return {lo, index - offsets_[lo]};
+ }
+
+ static std::vector<int64_t> MakeEndOffsets(std::vector<int64_t> lengths) {
+ int64_t offset = 0;
+ for (auto& v : lengths) {
+ const auto this_length = v;
+ v = offset;
+ offset += this_length;
+ }
+ lengths.push_back(offset);
+ return lengths;
+ }
+
+ int64_t num_chunks_;
+ std::vector<int64_t> offsets_;
+
+ mutable int64_t cached_chunk_;
+};
+
+struct ChunkedArrayResolver : protected ChunkResolver {
+ explicit ChunkedArrayResolver(const std::vector<const Array*>& chunks)
+ : ChunkResolver(MakeLengths(chunks)), chunks_(chunks) {}
+
+ template <typename ArrayType>
+ ResolvedChunk<ArrayType> Resolve(int64_t index) const {
+ const auto loc = ChunkResolver::Resolve(index);
+ return ResolvedChunk<ArrayType>(
+ checked_cast<const ArrayType*>(chunks_[loc.chunk_index]), loc.index_in_chunk);
+ }
+
+ protected:
+ static std::vector<int64_t> MakeLengths(const std::vector<const Array*>& chunks) {
+ std::vector<int64_t> lengths(chunks.size());
+ std::transform(chunks.begin(), chunks.end(), lengths.begin(),
+ [](const Array* arr) { return arr->length(); });
+ return lengths;
+ }
+
+ const std::vector<const Array*> chunks_;
+};
+
+inline std::vector<const Array*> GetArrayPointers(const ArrayVector& arrays) {
+ std::vector<const Array*> pointers(arrays.size());
+ std::transform(arrays.begin(), arrays.end(), pointers.begin(),
+ [&](const std::shared_ptr<Array>& array) { return array.get(); });
+ return pointers;
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/codegen_internal.cc b/src/arrow/cpp/src/arrow/compute/kernels/codegen_internal.cc
new file mode 100644
index 000000000..209c433db
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -0,0 +1,420 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/kernels/codegen_internal.h"
+
+#include <cmath>
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <vector>
+
+#include "arrow/type_fwd.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+Status ExecFail(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return Status::NotImplemented("This kernel is malformed");
+}
+
+ArrayKernelExec MakeFlippedBinaryExec(ArrayKernelExec exec) {
+ return [exec](KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ ExecBatch flipped_batch = batch;
+ std::swap(flipped_batch.values[0], flipped_batch.values[1]);
+ return exec(ctx, flipped_batch, out);
+ };
+}
+
+const std::vector<std::shared_ptr<DataType>>& ExampleParametricTypes() {
+ static DataTypeVector example_parametric_types = {
+ decimal128(12, 2),
+ duration(TimeUnit::SECOND),
+ timestamp(TimeUnit::SECOND),
+ time32(TimeUnit::SECOND),
+ time64(TimeUnit::MICRO),
+ fixed_size_binary(0),
+ list(null()),
+ large_list(null()),
+ fixed_size_list(field("dummy", null()), 0),
+ struct_({}),
+ sparse_union(FieldVector{}),
+ dense_union(FieldVector{}),
+ dictionary(int32(), null()),
+ map(null(), null())};
+ return example_parametric_types;
+}
+
+Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& descrs) {
+ ValueDescr result = descrs.front();
+ result.shape = GetBroadcastShape(descrs);
+ return result;
+}
+
+Result<ValueDescr> LastType(KernelContext*, const std::vector<ValueDescr>& descrs) {
+ ValueDescr result = descrs.back();
+ result.shape = GetBroadcastShape(descrs);
+ return result;
+}
+
+Result<ValueDescr> ListValuesType(KernelContext*, const std::vector<ValueDescr>& args) {
+ const auto& list_type = checked_cast<const BaseListType&>(*args[0].type);
+ return ValueDescr(list_type.value_type(), GetBroadcastShape(args));
+}
+
+void EnsureDictionaryDecoded(std::vector<ValueDescr>* descrs) {
+ EnsureDictionaryDecoded(descrs->data(), descrs->size());
+}
+
+void EnsureDictionaryDecoded(ValueDescr* begin, size_t count) {
+ auto* end = begin + count;
+ for (auto it = begin; it != end; it++) {
+ if (it->type->id() == Type::DICTIONARY) {
+ it->type = checked_cast<const DictionaryType&>(*it->type).value_type();
+ }
+ }
+}
+
+void ReplaceNullWithOtherType(std::vector<ValueDescr>* descrs) {
+ ReplaceNullWithOtherType(descrs->data(), descrs->size());
+}
+
+void ReplaceNullWithOtherType(ValueDescr* first, size_t count) {
+ DCHECK_EQ(count, 2);
+
+ ValueDescr* second = first++;
+ if (first->type->id() == Type::NA) {
+ first->type = second->type;
+ return;
+ }
+
+ if (second->type->id() == Type::NA) {
+ second->type = first->type;
+ return;
+ }
+}
+
+void ReplaceTypes(const std::shared_ptr<DataType>& type,
+ std::vector<ValueDescr>* descrs) {
+ ReplaceTypes(type, descrs->data(), descrs->size());
+}
+
+void ReplaceTypes(const std::shared_ptr<DataType>& type, ValueDescr* begin,
+ size_t count) {
+ auto* end = begin + count;
+ for (auto* it = begin; it != end; it++) {
+ it->type = type;
+ }
+}
+
+std::shared_ptr<DataType> CommonNumeric(const std::vector<ValueDescr>& descrs) {
+ return CommonNumeric(descrs.data(), descrs.size());
+}
+
+std::shared_ptr<DataType> CommonNumeric(const ValueDescr* begin, size_t count) {
+ DCHECK_GT(count, 0) << "tried to find CommonNumeric type of an empty set";
+
+ for (size_t i = 0; i < count; i++) {
+ const auto& descr = *(begin + i);
+ auto id = descr.type->id();
+ if (!is_floating(id) && !is_integer(id)) {
+ // a common numeric type is only possible if all types are numeric
+ return nullptr;
+ }
+ if (id == Type::HALF_FLOAT) {
+ // float16 arithmetic is not currently supported
+ return nullptr;
+ }
+ }
+
+ for (size_t i = 0; i < count; i++) {
+ const auto& descr = *(begin + i);
+ if (descr.type->id() == Type::DOUBLE) return float64();
+ }
+
+ for (size_t i = 0; i < count; i++) {
+ const auto& descr = *(begin + i);
+ if (descr.type->id() == Type::FLOAT) return float32();
+ }
+
+ int max_width_signed = 0, max_width_unsigned = 0;
+
+ for (size_t i = 0; i < count; i++) {
+ const auto& descr = *(begin + i);
+ auto id = descr.type->id();
+ auto max_width = &(is_signed_integer(id) ? max_width_signed : max_width_unsigned);
+ *max_width = std::max(bit_width(id), *max_width);
+ }
+
+ if (max_width_signed == 0) {
+ if (max_width_unsigned >= 64) return uint64();
+ if (max_width_unsigned == 32) return uint32();
+ if (max_width_unsigned == 16) return uint16();
+ DCHECK_EQ(max_width_unsigned, 8);
+ return uint8();
+ }
+
+ if (max_width_signed <= max_width_unsigned) {
+ max_width_signed = static_cast<int>(BitUtil::NextPower2(max_width_unsigned + 1));
+ }
+
+ if (max_width_signed >= 64) return int64();
+ if (max_width_signed == 32) return int32();
+ if (max_width_signed == 16) return int16();
+ DCHECK_EQ(max_width_signed, 8);
+ return int8();
+}
+
+std::shared_ptr<DataType> CommonTemporal(const ValueDescr* begin, size_t count) {
+ TimeUnit::type finest_unit = TimeUnit::SECOND;
+ const std::string* timezone = nullptr;
+ bool saw_date32 = false;
+ bool saw_date64 = false;
+
+ const ValueDescr* end = begin + count;
+ for (auto it = begin; it != end; it++) {
+ auto id = it->type->id();
+ // a common timestamp is only possible if all types are timestamp like
+ switch (id) {
+ case Type::DATE32:
+ // Date32's unit is days, but the coarsest we have is seconds
+ saw_date32 = true;
+ continue;
+ case Type::DATE64:
+ finest_unit = std::max(finest_unit, TimeUnit::MILLI);
+ saw_date64 = true;
+ continue;
+ case Type::TIMESTAMP: {
+ const auto& ty = checked_cast<const TimestampType&>(*it->type);
+ if (timezone && *timezone != ty.timezone()) return nullptr;
+ timezone = &ty.timezone();
+ finest_unit = std::max(finest_unit, ty.unit());
+ continue;
+ }
+ default:
+ return nullptr;
+ }
+ }
+
+ if (timezone) {
+ // At least one timestamp seen
+ return timestamp(finest_unit, *timezone);
+ } else if (saw_date64) {
+ return date64();
+ } else if (saw_date32) {
+ return date32();
+ }
+ return nullptr;
+}
+
+std::shared_ptr<DataType> CommonBinary(const ValueDescr* begin, size_t count) {
+ bool all_utf8 = true, all_offset32 = true, all_fixed_width = true;
+
+ const ValueDescr* end = begin + count;
+ for (auto it = begin; it != end; ++it) {
+ auto id = it->type->id();
+ // a common varbinary type is only possible if all types are binary like
+ switch (id) {
+ case Type::STRING:
+ all_fixed_width = false;
+ continue;
+ case Type::BINARY:
+ all_fixed_width = false;
+ all_utf8 = false;
+ continue;
+ case Type::FIXED_SIZE_BINARY:
+ all_utf8 = false;
+ continue;
+ case Type::LARGE_STRING:
+ all_offset32 = false;
+ all_fixed_width = false;
+ continue;
+ case Type::LARGE_BINARY:
+ all_offset32 = false;
+ all_fixed_width = false;
+ all_utf8 = false;
+ continue;
+ default:
+ return nullptr;
+ }
+ }
+
+ if (all_fixed_width) {
+ // At least for the purposes of comparison, no need to cast.
+ return nullptr;
+ }
+
+ if (all_utf8) {
+ if (all_offset32) return utf8();
+ return large_utf8();
+ }
+
+ if (all_offset32) return binary();
+ return large_binary();
+}
+
+Status CastBinaryDecimalArgs(DecimalPromotion promotion,
+ std::vector<ValueDescr>* descrs) {
+ auto& left_type = (*descrs)[0].type;
+ auto& right_type = (*descrs)[1].type;
+ DCHECK(is_decimal(left_type->id()) || is_decimal(right_type->id()));
+
+ // decimal + float = float
+ if (is_floating(left_type->id())) {
+ right_type = left_type;
+ return Status::OK();
+ } else if (is_floating(right_type->id())) {
+ left_type = right_type;
+ return Status::OK();
+ }
+
+ // precision, scale of left and right args
+ int32_t p1, s1, p2, s2;
+
+ // decimal + integer = decimal
+ if (is_decimal(left_type->id())) {
+ auto decimal = checked_cast<const DecimalType*>(left_type.get());
+ p1 = decimal->precision();
+ s1 = decimal->scale();
+ } else {
+ DCHECK(is_integer(left_type->id()));
+ ARROW_ASSIGN_OR_RAISE(p1, MaxDecimalDigitsForInteger(left_type->id()));
+ s1 = 0;
+ }
+ if (is_decimal(right_type->id())) {
+ auto decimal = checked_cast<const DecimalType*>(right_type.get());
+ p2 = decimal->precision();
+ s2 = decimal->scale();
+ } else {
+ DCHECK(is_integer(right_type->id()));
+ ARROW_ASSIGN_OR_RAISE(p2, MaxDecimalDigitsForInteger(right_type->id()));
+ s2 = 0;
+ }
+ if (s1 < 0 || s2 < 0) {
+ return Status::NotImplemented("Decimals with negative scales not supported");
+ }
+
+ // decimal128 + decimal256 = decimal256
+ Type::type casted_type_id = Type::DECIMAL128;
+ if (left_type->id() == Type::DECIMAL256 || right_type->id() == Type::DECIMAL256) {
+ casted_type_id = Type::DECIMAL256;
+ }
+
+ // decimal promotion rules compatible with amazon redshift
+ // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html
+ int32_t left_scaleup = 0;
+ int32_t right_scaleup = 0;
+
+ switch (promotion) {
+ case DecimalPromotion::kAdd: {
+ left_scaleup = std::max(s1, s2) - s1;
+ right_scaleup = std::max(s1, s2) - s2;
+ break;
+ }
+ case DecimalPromotion::kMultiply: {
+ left_scaleup = right_scaleup = 0;
+ break;
+ }
+ case DecimalPromotion::kDivide: {
+ left_scaleup = std::max(4, s1 + p2 - s2 + 1) + s2 - s1;
+ right_scaleup = 0;
+ break;
+ }
+ default:
+ DCHECK(false) << "Invalid DecimalPromotion value " << static_cast<int>(promotion);
+ }
+ ARROW_ASSIGN_OR_RAISE(
+ left_type, DecimalType::Make(casted_type_id, p1 + left_scaleup, s1 + left_scaleup));
+ ARROW_ASSIGN_OR_RAISE(right_type, DecimalType::Make(casted_type_id, p2 + right_scaleup,
+ s2 + right_scaleup));
+ return Status::OK();
+}
+
+Status CastDecimalArgs(ValueDescr* begin, size_t count) {
+ Type::type casted_type_id = Type::DECIMAL128;
+ auto* end = begin + count;
+
+ int32_t max_scale = 0;
+ bool any_floating = false;
+ for (auto* it = begin; it != end; ++it) {
+ const auto& ty = *it->type;
+ if (is_floating(ty.id())) {
+ // Decimal + float = float
+ any_floating = true;
+ } else if (is_integer(ty.id())) {
+ // Nothing to do here
+ } else if (is_decimal(ty.id())) {
+ max_scale = std::max(max_scale, checked_cast<const DecimalType&>(ty).scale());
+ if (ty.id() == Type::DECIMAL256) {
+ casted_type_id = Type::DECIMAL256;
+ }
+ } else {
+ // Non-numeric, can't cast
+ return Status::OK();
+ }
+ }
+ if (any_floating) {
+ ReplaceTypes(float64(), begin, count);
+ return Status::OK();
+ }
+
+ // All integer and decimal, rescale
+ int32_t common_precision = 0;
+ for (auto* it = begin; it != end; ++it) {
+ const auto& ty = *it->type;
+ if (is_integer(ty.id())) {
+ ARROW_ASSIGN_OR_RAISE(auto precision, MaxDecimalDigitsForInteger(ty.id()));
+ precision += max_scale;
+ common_precision = std::max(common_precision, precision);
+ } else if (is_decimal(ty.id())) {
+ const auto& decimal_ty = checked_cast<const DecimalType&>(ty);
+ auto precision = decimal_ty.precision();
+ const auto scale = decimal_ty.scale();
+ precision += max_scale - scale;
+ common_precision = std::max(common_precision, precision);
+ }
+ }
+
+ if (common_precision > BasicDecimal256::kMaxPrecision) {
+ return Status::Invalid("Result precision (", common_precision,
+ ") exceeds max precision of Decimal256 (",
+ BasicDecimal256::kMaxPrecision, ")");
+ } else if (common_precision > BasicDecimal128::kMaxPrecision) {
+ casted_type_id = Type::DECIMAL256;
+ }
+
+ for (auto* it = begin; it != end; ++it) {
+ ARROW_ASSIGN_OR_RAISE(it->type,
+ DecimalType::Make(casted_type_id, common_precision, max_scale));
+ }
+
+ return Status::OK();
+}
+
+bool HasDecimal(const std::vector<ValueDescr>& descrs) {
+ for (const auto& descr : descrs) {
+ if (is_decimal(descr.type->id())) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h b/src/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h
new file mode 100644
index 000000000..438362585
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -0,0 +1,1353 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/data.h"
+#include "arrow/buffer.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/kernel.h"
+#include "arrow/datum.h"
+#include "arrow/result.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_generate.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/bitmap_writer.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/string_view.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::BinaryBitBlockCounter;
+using internal::BitBlockCount;
+using internal::BitmapReader;
+using internal::checked_cast;
+using internal::FirstTimeBitmapWriter;
+using internal::GenerateBitsUnrolled;
+using internal::VisitBitBlocksVoid;
+using internal::VisitTwoBitBlocksVoid;
+
+namespace compute {
+namespace internal {
+
+/// KernelState adapter for the common case of kernels whose only
+/// state is an instance of a subclass of FunctionOptions.
+/// Default FunctionOptions are *not* handled here.
+template <typename OptionsType>
+struct OptionsWrapper : public KernelState {
+ explicit OptionsWrapper(OptionsType options) : options(std::move(options)) {}
+
+ static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ if (auto options = static_cast<const OptionsType*>(args.options)) {
+ return ::arrow::internal::make_unique<OptionsWrapper>(*options);
+ }
+
+ return Status::Invalid(
+ "Attempted to initialize KernelState from null FunctionOptions");
+ }
+
+ static const OptionsType& Get(const KernelState& state) {
+ return ::arrow::internal::checked_cast<const OptionsWrapper&>(state).options;
+ }
+
+ static const OptionsType& Get(KernelContext* ctx) { return Get(*ctx->state()); }
+
+ OptionsType options;
+};
+
+/// KernelState adapter for when the state is an instance constructed with the
+/// KernelContext and the FunctionOptions as argument
+template <typename StateType, typename OptionsType>
+struct KernelStateFromFunctionOptions : public KernelState {
+ explicit KernelStateFromFunctionOptions(KernelContext* ctx, OptionsType options)
+ : state(StateType(ctx, std::move(options))) {}
+
+ static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ if (auto options = static_cast<const OptionsType*>(args.options)) {
+ return ::arrow::internal::make_unique<KernelStateFromFunctionOptions>(ctx,
+ *options);
+ }
+
+ return Status::Invalid(
+ "Attempted to initialize KernelState from null FunctionOptions");
+ }
+
+ static const StateType& Get(const KernelState& state) {
+ return ::arrow::internal::checked_cast<const KernelStateFromFunctionOptions&>(state)
+ .state;
+ }
+
+ static const StateType& Get(KernelContext* ctx) { return Get(*ctx->state()); }
+
+ StateType state;
+};
+
+// ----------------------------------------------------------------------
+// Input and output value type definitions
+
+template <typename Type, typename Enable = void>
+struct GetViewType;
+
+template <typename Type>
+struct GetViewType<Type, enable_if_has_c_type<Type>> {
+ using T = typename Type::c_type;
+ using PhysicalType = T;
+
+ static T LogicalValue(PhysicalType value) { return value; }
+};
+
+template <typename Type>
+struct GetViewType<Type, enable_if_t<is_base_binary_type<Type>::value ||
+ is_fixed_size_binary_type<Type>::value>> {
+ using T = util::string_view;
+ using PhysicalType = T;
+
+ static T LogicalValue(PhysicalType value) { return value; }
+};
+
+template <>
+struct GetViewType<Decimal128Type> {
+ using T = Decimal128;
+ using PhysicalType = util::string_view;
+
+ static T LogicalValue(PhysicalType value) {
+ return Decimal128(reinterpret_cast<const uint8_t*>(value.data()));
+ }
+
+ static T LogicalValue(T value) { return value; }
+};
+
+template <>
+struct GetViewType<Decimal256Type> {
+ using T = Decimal256;
+ using PhysicalType = util::string_view;
+
+ static T LogicalValue(PhysicalType value) {
+ return Decimal256(reinterpret_cast<const uint8_t*>(value.data()));
+ }
+
+ static T LogicalValue(T value) { return value; }
+};
+
+template <typename Type, typename Enable = void>
+struct GetOutputType;
+
+template <typename Type>
+struct GetOutputType<Type, enable_if_has_c_type<Type>> {
+ using T = typename Type::c_type;
+};
+
+template <typename Type>
+struct GetOutputType<Type, enable_if_t<is_string_like_type<Type>::value>> {
+ using T = std::string;
+};
+
+template <>
+struct GetOutputType<Decimal128Type> {
+ using T = Decimal128;
+};
+
+template <>
+struct GetOutputType<Decimal256Type> {
+ using T = Decimal256;
+};
+
+// ----------------------------------------------------------------------
+// Iteration / value access utilities
+
+template <typename T, typename R = void>
+using enable_if_c_number_or_decimal = enable_if_t<
+ (has_c_type<T>::value && !is_boolean_type<T>::value) || is_decimal_type<T>::value, R>;
+
+// Iterator over various input array types, yielding a GetViewType<Type>
+
+template <typename Type, typename Enable = void>
+struct ArrayIterator;
+
+template <typename Type>
+struct ArrayIterator<Type, enable_if_c_number_or_decimal<Type>> {
+ using T = typename TypeTraits<Type>::ScalarType::ValueType;
+ const T* values;
+
+ explicit ArrayIterator(const ArrayData& data) : values(data.GetValues<T>(1)) {}
+ T operator()() { return *values++; }
+};
+
+template <typename Type>
+struct ArrayIterator<Type, enable_if_boolean<Type>> {
+ BitmapReader reader;
+
+ explicit ArrayIterator(const ArrayData& data)
+ : reader(data.buffers[1]->data(), data.offset, data.length) {}
+ bool operator()() {
+ bool out = reader.IsSet();
+ reader.Next();
+ return out;
+ }
+};
+
+template <typename Type>
+struct ArrayIterator<Type, enable_if_base_binary<Type>> {
+ using offset_type = typename Type::offset_type;
+ const ArrayData& arr;
+ const offset_type* offsets;
+ offset_type cur_offset;
+ const char* data;
+ int64_t position;
+
+ explicit ArrayIterator(const ArrayData& arr)
+ : arr(arr),
+ offsets(reinterpret_cast<const offset_type*>(arr.buffers[1]->data()) +
+ arr.offset),
+ cur_offset(offsets[0]),
+ data(reinterpret_cast<const char*>(arr.buffers[2]->data())),
+ position(0) {}
+
+ util::string_view operator()() {
+ offset_type next_offset = offsets[++position];
+ auto result = util::string_view(data + cur_offset, next_offset - cur_offset);
+ cur_offset = next_offset;
+ return result;
+ }
+};
+
+template <>
+struct ArrayIterator<FixedSizeBinaryType> {
+ const ArrayData& arr;
+ const char* data;
+ const int32_t width;
+ int64_t position;
+
+ explicit ArrayIterator(const ArrayData& arr)
+ : arr(arr),
+ data(reinterpret_cast<const char*>(arr.buffers[1]->data())),
+ width(checked_cast<const FixedSizeBinaryType&>(*arr.type).byte_width()),
+ position(arr.offset) {}
+
+ util::string_view operator()() {
+ auto result = util::string_view(data + position * width, width);
+ position++;
+ return result;
+ }
+};
+
+// Iterator over various output array types, taking a GetOutputType<Type>
+
+template <typename Type, typename Enable = void>
+struct OutputArrayWriter;
+
+template <typename Type>
+struct OutputArrayWriter<Type, enable_if_c_number_or_decimal<Type>> {
+ using T = typename TypeTraits<Type>::ScalarType::ValueType;
+ T* values;
+
+ explicit OutputArrayWriter(ArrayData* data) : values(data->GetMutableValues<T>(1)) {}
+
+ void Write(T value) { *values++ = value; }
+
+ // Note that this doesn't write the null bitmap, which should be consistent
+ // with Write / WriteNull calls
+ void WriteNull() { *values++ = T{}; }
+
+ void WriteAllNull(int64_t length) {
+ std::memset(static_cast<void*>(values), 0, sizeof(T) * length);
+ }
+};
+
+// (Un)box Scalar to / from C++ value
+
+template <typename Type, typename Enable = void>
+struct UnboxScalar;
+
+template <typename Type>
+struct UnboxScalar<Type, enable_if_has_c_type<Type>> {
+ using T = typename Type::c_type;
+ static T Unbox(const Scalar& val) {
+ util::string_view view =
+ checked_cast<const ::arrow::internal::PrimitiveScalarBase&>(val).view();
+ DCHECK_EQ(view.size(), sizeof(T));
+ return *reinterpret_cast<const T*>(view.data());
+ }
+};
+
+template <typename Type>
+struct UnboxScalar<Type, enable_if_has_string_view<Type>> {
+ static util::string_view Unbox(const Scalar& val) {
+ if (!val.is_valid) return util::string_view();
+ return util::string_view(*checked_cast<const BaseBinaryScalar&>(val).value);
+ }
+};
+
+template <>
+struct UnboxScalar<Decimal128Type> {
+ static const Decimal128& Unbox(const Scalar& val) {
+ return checked_cast<const Decimal128Scalar&>(val).value;
+ }
+};
+
+template <>
+struct UnboxScalar<Decimal256Type> {
+ static const Decimal256& Unbox(const Scalar& val) {
+ return checked_cast<const Decimal256Scalar&>(val).value;
+ }
+};
+
+template <typename Type, typename Enable = void>
+struct BoxScalar;
+
+template <typename Type>
+struct BoxScalar<Type, enable_if_has_c_type<Type>> {
+ using T = typename GetOutputType<Type>::T;
+ static void Box(T val, Scalar* out) {
+ // Enables BoxScalar<Int64Type> to work on a (for example) Time64Scalar
+ T* mutable_data = reinterpret_cast<T*>(
+ checked_cast<::arrow::internal::PrimitiveScalarBase*>(out)->mutable_data());
+ *mutable_data = val;
+ }
+};
+
+template <typename Type>
+struct BoxScalar<Type, enable_if_base_binary<Type>> {
+ using T = typename GetOutputType<Type>::T;
+ using ScalarType = typename TypeTraits<Type>::ScalarType;
+ static void Box(T val, Scalar* out) {
+ checked_cast<ScalarType*>(out)->value = std::make_shared<Buffer>(val);
+ }
+};
+
+template <>
+struct BoxScalar<Decimal128Type> {
+ using T = Decimal128;
+ using ScalarType = Decimal128Scalar;
+ static void Box(T val, Scalar* out) { checked_cast<ScalarType*>(out)->value = val; }
+};
+
+template <>
+struct BoxScalar<Decimal256Type> {
+ using T = Decimal256;
+ using ScalarType = Decimal256Scalar;
+ static void Box(T val, Scalar* out) { checked_cast<ScalarType*>(out)->value = val; }
+};
+
+// A VisitArrayDataInline variant that calls its visitor function with logical
+// values, such as Decimal128 rather than util::string_view.
+
+template <typename T, typename VisitFunc, typename NullFunc>
+static typename arrow::internal::call_traits::enable_if_return<VisitFunc, void>::type
+VisitArrayValuesInline(const ArrayData& arr, VisitFunc&& valid_func,
+ NullFunc&& null_func) {
+ VisitArrayDataInline<T>(
+ arr,
+ [&](typename GetViewType<T>::PhysicalType v) {
+ valid_func(GetViewType<T>::LogicalValue(std::move(v)));
+ },
+ std::forward<NullFunc>(null_func));
+}
+
+template <typename T, typename VisitFunc, typename NullFunc>
+static typename arrow::internal::call_traits::enable_if_return<VisitFunc, Status>::type
+VisitArrayValuesInline(const ArrayData& arr, VisitFunc&& valid_func,
+ NullFunc&& null_func) {
+ return VisitArrayDataInline<T>(
+ arr,
+ [&](typename GetViewType<T>::PhysicalType v) {
+ return valid_func(GetViewType<T>::LogicalValue(std::move(v)));
+ },
+ std::forward<NullFunc>(null_func));
+}
+
+// Like VisitArrayValuesInline, but for binary functions.
+
+template <typename Arg0Type, typename Arg1Type, typename VisitFunc, typename NullFunc>
+static void VisitTwoArrayValuesInline(const ArrayData& arr0, const ArrayData& arr1,
+ VisitFunc&& valid_func, NullFunc&& null_func) {
+ ArrayIterator<Arg0Type> arr0_it(arr0);
+ ArrayIterator<Arg1Type> arr1_it(arr1);
+
+ auto visit_valid = [&](int64_t i) {
+ valid_func(GetViewType<Arg0Type>::LogicalValue(arr0_it()),
+ GetViewType<Arg1Type>::LogicalValue(arr1_it()));
+ };
+ auto visit_null = [&]() {
+ arr0_it();
+ arr1_it();
+ null_func();
+ };
+ VisitTwoBitBlocksVoid(arr0.buffers[0], arr0.offset, arr1.buffers[0], arr1.offset,
+ arr0.length, std::move(visit_valid), std::move(visit_null));
+}
+
+// ----------------------------------------------------------------------
+// Reusable type resolvers
+
+Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& descrs);
+Result<ValueDescr> LastType(KernelContext*, const std::vector<ValueDescr>& descrs);
+Result<ValueDescr> ListValuesType(KernelContext*, const std::vector<ValueDescr>& args);
+
+// ----------------------------------------------------------------------
+// Generate an array kernel given template classes
+
+Status ExecFail(KernelContext* ctx, const ExecBatch& batch, Datum* out);
+
+ArrayKernelExec MakeFlippedBinaryExec(ArrayKernelExec exec);
+
+// ----------------------------------------------------------------------
+// Helpers for iterating over common DataType instances for adding kernels to
+// functions
+
+// Returns a vector of example instances of parametric types such as
+//
+// * Decimal
+// * Timestamp (requiring unit)
+// * Time32 (requiring unit)
+// * Time64 (requiring unit)
+// * Duration (requiring unit)
+// * List, LargeList, FixedSizeList
+// * Struct
+// * Union
+// * Dictionary
+// * Map
+//
+// Generally kernels will use the "FirstType" OutputType::Resolver above for
+// the OutputType of the kernel's signature and match::SameTypeId for the
+// corresponding InputType
+const std::vector<std::shared_ptr<DataType>>& ExampleParametricTypes();
+
+// ----------------------------------------------------------------------
+// "Applicators" take an operator definition (which may be scalar-valued or
+// array-valued) and creates an ArrayKernelExec which can be used to add an
+// ArrayKernel to a Function.
+
+namespace applicator {
+
+// Generate an ArrayKernelExec given a functor that handles all of its own
+// iteration, etc.
+//
+// Operator must implement
+//
+// static Status Call(KernelContext*, const ArrayData& in, ArrayData* out)
+// static Status Call(KernelContext*, const Scalar& in, Scalar* out)
+template <typename Operator>
+static Status SimpleUnary(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].kind() == Datum::SCALAR) {
+ return Operator::Call(ctx, *batch[0].scalar(), out->scalar().get());
+ } else if (batch.length > 0) {
+ return Operator::Call(ctx, *batch[0].array(), out->mutable_array());
+ }
+ return Status::OK();
+}
+
+// Generate an ArrayKernelExec given a functor that handles all of its own
+// iteration, etc.
+//
+// Operator must implement
+//
+// static Status Call(KernelContext*, const ArrayData& arg0, const ArrayData& arg1,
+// ArrayData* out)
+// static Status Call(KernelContext*, const ArrayData& arg0, const Scalar& arg1,
+// ArrayData* out)
+// static Status Call(KernelContext*, const Scalar& arg0, const ArrayData& arg1,
+// ArrayData* out)
+// static Status Call(KernelContext*, const Scalar& arg0, const Scalar& arg1,
+// Scalar* out)
+template <typename Operator>
+static Status SimpleBinary(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch.length == 0) return Status::OK();
+
+ if (batch[0].kind() == Datum::ARRAY) {
+ if (batch[1].kind() == Datum::ARRAY) {
+ return Operator::Call(ctx, *batch[0].array(), *batch[1].array(),
+ out->mutable_array());
+ } else {
+ return Operator::Call(ctx, *batch[0].array(), *batch[1].scalar(),
+ out->mutable_array());
+ }
+ } else {
+ if (batch[1].kind() == Datum::ARRAY) {
+ return Operator::Call(ctx, *batch[0].scalar(), *batch[1].array(),
+ out->mutable_array());
+ } else {
+ return Operator::Call(ctx, *batch[0].scalar(), *batch[1].scalar(),
+ out->scalar().get());
+ }
+ }
+}
+
+// OutputAdapter allows passing an inlineable lambda that provides a sequence
+// of output values to write into output memory. Boolean and primitive outputs
+// are currently implemented, and the validity bitmap is presumed to be handled
+// at a higher level, so this writes into every output slot, null or not.
+template <typename Type, typename Enable = void>
+struct OutputAdapter;
+
+template <typename Type>
+struct OutputAdapter<Type, enable_if_boolean<Type>> {
+ template <typename Generator>
+ static Status Write(KernelContext*, Datum* out, Generator&& generator) {
+ ArrayData* out_arr = out->mutable_array();
+ auto out_bitmap = out_arr->buffers[1]->mutable_data();
+ GenerateBitsUnrolled(out_bitmap, out_arr->offset, out_arr->length,
+ std::forward<Generator>(generator));
+ return Status::OK();
+ }
+};
+
+template <typename Type>
+struct OutputAdapter<Type, enable_if_c_number_or_decimal<Type>> {
+ using T = typename TypeTraits<Type>::ScalarType::ValueType;
+
+ template <typename Generator>
+ static Status Write(KernelContext*, Datum* out, Generator&& generator) {
+ ArrayData* out_arr = out->mutable_array();
+ auto out_data = out_arr->GetMutableValues<T>(1);
+ // TODO: Is this as fast as a more explicitly inlined function?
+ for (int64_t i = 0; i < out_arr->length; ++i) {
+ *out_data++ = generator();
+ }
+ return Status::OK();
+ }
+};
+
+template <typename Type>
+struct OutputAdapter<Type, enable_if_base_binary<Type>> {
+ template <typename Generator>
+ static Status Write(KernelContext* ctx, Datum* out, Generator&& generator) {
+ return Status::NotImplemented("NYI");
+ }
+};
+
+// A kernel exec generator for unary functions that addresses both array and
+// scalar inputs and dispatches input iteration and output writing to other
+// templates
+//
+// This template executes the operator even on the data behind null values,
+// therefore it is generally only suitable for operators that are safe to apply
+// even on the null slot values.
+//
+// The "Op" functor should have the form
+//
+// struct Op {
+// template <typename OutValue, typename Arg0Value>
+// static OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) {
+// // implementation
+// // NOTE: "status" should only populated with errors,
+// // leave it unmodified to indicate Status::OK()
+// }
+// };
+template <typename OutType, typename Arg0Type, typename Op>
+struct ScalarUnary {
+ using OutValue = typename GetOutputType<OutType>::T;
+ using Arg0Value = typename GetViewType<Arg0Type>::T;
+
+ static Status ExecArray(KernelContext* ctx, const ArrayData& arg0, Datum* out) {
+ Status st = Status::OK();
+ ArrayIterator<Arg0Type> arg0_it(arg0);
+ RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue {
+ return Op::template Call<OutValue, Arg0Value>(ctx, arg0_it(), &st);
+ }));
+ return st;
+ }
+
+ static Status ExecScalar(KernelContext* ctx, const Scalar& arg0, Datum* out) {
+ Status st = Status::OK();
+ Scalar* out_scalar = out->scalar().get();
+ if (arg0.is_valid) {
+ Arg0Value arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0);
+ out_scalar->is_valid = true;
+ BoxScalar<OutType>::Box(Op::template Call<OutValue, Arg0Value>(ctx, arg0_val, &st),
+ out_scalar);
+ } else {
+ out_scalar->is_valid = false;
+ }
+ return st;
+ }
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].kind() == Datum::ARRAY) {
+ return ExecArray(ctx, *batch[0].array(), out);
+ } else {
+ return ExecScalar(ctx, *batch[0].scalar(), out);
+ }
+ }
+};
+
+// An alternative to ScalarUnary that Applies a scalar operation with state on
+// only the not-null values of a single array
+template <typename OutType, typename Arg0Type, typename Op>
+struct ScalarUnaryNotNullStateful {
+ using ThisType = ScalarUnaryNotNullStateful<OutType, Arg0Type, Op>;
+ using OutValue = typename GetOutputType<OutType>::T;
+ using Arg0Value = typename GetViewType<Arg0Type>::T;
+
+ Op op;
+ explicit ScalarUnaryNotNullStateful(Op op) : op(std::move(op)) {}
+
+ // NOTE: In ArrayExec<Type>, Type is really OutputType
+
+ template <typename Type, typename Enable = void>
+ struct ArrayExec {
+ static Status Exec(const ThisType& functor, KernelContext* ctx,
+ const ExecBatch& batch, Datum* out) {
+ ARROW_LOG(FATAL) << "Missing ArrayExec specialization for output type "
+ << out->type();
+ return Status::NotImplemented("NYI");
+ }
+ };
+
+ template <typename Type>
+ struct ArrayExec<Type, enable_if_c_number_or_decimal<Type>> {
+ static Status Exec(const ThisType& functor, KernelContext* ctx, const ArrayData& arg0,
+ Datum* out) {
+ Status st = Status::OK();
+ ArrayData* out_arr = out->mutable_array();
+ auto out_data = out_arr->GetMutableValues<OutValue>(1);
+ VisitArrayValuesInline<Arg0Type>(
+ arg0,
+ [&](Arg0Value v) {
+ *out_data++ = functor.op.template Call<OutValue, Arg0Value>(ctx, v, &st);
+ },
+ [&]() {
+ // null
+ *out_data++ = OutValue{};
+ });
+ return st;
+ }
+ };
+
+ template <typename Type>
+ struct ArrayExec<Type, enable_if_base_binary<Type>> {
+ static Status Exec(const ThisType& functor, KernelContext* ctx, const ArrayData& arg0,
+ Datum* out) {
+ // NOTE: This code is not currently used by any kernels and has
+ // suboptimal performance because it's recomputing the validity bitmap
+ // that is already computed by the kernel execution layer. Consider
+ // writing a lower-level "output adapter" for base binary types.
+ typename TypeTraits<Type>::BuilderType builder;
+ Status st = Status::OK();
+ RETURN_NOT_OK(VisitArrayValuesInline<Arg0Type>(
+ arg0, [&](Arg0Value v) { return builder.Append(functor.op.Call(ctx, v, &st)); },
+ [&]() { return builder.AppendNull(); }));
+ if (st.ok()) {
+ std::shared_ptr<ArrayData> result;
+ RETURN_NOT_OK(builder.FinishInternal(&result));
+ out->value = std::move(result);
+ }
+ return st;
+ }
+ };
+
+ template <typename Type>
+ struct ArrayExec<Type, enable_if_t<is_boolean_type<Type>::value>> {
+ static Status Exec(const ThisType& functor, KernelContext* ctx, const ArrayData& arg0,
+ Datum* out) {
+ Status st = Status::OK();
+ ArrayData* out_arr = out->mutable_array();
+ FirstTimeBitmapWriter out_writer(out_arr->buffers[1]->mutable_data(),
+ out_arr->offset, out_arr->length);
+ VisitArrayValuesInline<Arg0Type>(
+ arg0,
+ [&](Arg0Value v) {
+ if (functor.op.template Call<OutValue, Arg0Value>(ctx, v, &st)) {
+ out_writer.Set();
+ }
+ out_writer.Next();
+ },
+ [&]() {
+ // null
+ out_writer.Clear();
+ out_writer.Next();
+ });
+ out_writer.Finish();
+ return st;
+ }
+ };
+
+ Status Scalar(KernelContext* ctx, const Scalar& arg0, Datum* out) {
+ Status st = Status::OK();
+ if (arg0.is_valid) {
+ Arg0Value arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0);
+ BoxScalar<OutType>::Box(
+ this->op.template Call<OutValue, Arg0Value>(ctx, arg0_val, &st),
+ out->scalar().get());
+ }
+ return st;
+ }
+
+ Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].kind() == Datum::ARRAY) {
+ return ArrayExec<OutType>::Exec(*this, ctx, *batch[0].array(), out);
+ } else {
+ return Scalar(ctx, *batch[0].scalar(), out);
+ }
+ }
+};
+
+// An alternative to ScalarUnary that Applies a scalar operation on only the
+// not-null values of a single array. The operator is not stateful; if the
+// operator requires some initialization use ScalarUnaryNotNullStateful
+template <typename OutType, typename Arg0Type, typename Op>
+struct ScalarUnaryNotNull {
+ using OutValue = typename GetOutputType<OutType>::T;
+ using Arg0Value = typename GetViewType<Arg0Type>::T;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // Seed kernel with dummy state
+ ScalarUnaryNotNullStateful<OutType, Arg0Type, Op> kernel({});
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+// A kernel exec generator for binary functions that addresses both array and
+// scalar inputs and dispatches input iteration and output writing to other
+// templates
+//
+// This template executes the operator even on the data behind null values,
+// therefore it is generally only suitable for operators that are safe to apply
+// even on the null slot values.
+//
+// The "Op" functor should have the form
+//
+// struct Op {
+// template <typename OutValue, typename Arg0Value, typename Arg1Value>
+// static OutValue Call(KernelContext* ctx, Arg0Value arg0, Arg1Value arg1, Status* st)
+// {
+// // implementation
+// // NOTE: "status" should only populated with errors,
+// // leave it unmodified to indicate Status::OK()
+// }
+// };
+template <typename OutType, typename Arg0Type, typename Arg1Type, typename Op>
+struct ScalarBinary {
+ using OutValue = typename GetOutputType<OutType>::T;
+ using Arg0Value = typename GetViewType<Arg0Type>::T;
+ using Arg1Value = typename GetViewType<Arg1Type>::T;
+
+ static Status ArrayArray(KernelContext* ctx, const ArrayData& arg0,
+ const ArrayData& arg1, Datum* out) {
+ Status st = Status::OK();
+ ArrayIterator<Arg0Type> arg0_it(arg0);
+ ArrayIterator<Arg1Type> arg1_it(arg1);
+ RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue {
+ return Op::template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_it(), arg1_it(),
+ &st);
+ }));
+ return st;
+ }
+
+ static Status ArrayScalar(KernelContext* ctx, const ArrayData& arg0, const Scalar& arg1,
+ Datum* out) {
+ Status st = Status::OK();
+ ArrayIterator<Arg0Type> arg0_it(arg0);
+ auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1);
+ RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue {
+ return Op::template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_it(), arg1_val,
+ &st);
+ }));
+ return st;
+ }
+
+ static Status ScalarArray(KernelContext* ctx, const Scalar& arg0, const ArrayData& arg1,
+ Datum* out) {
+ Status st = Status::OK();
+ auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0);
+ ArrayIterator<Arg1Type> arg1_it(arg1);
+ RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue {
+ return Op::template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_val, arg1_it(),
+ &st);
+ }));
+ return st;
+ }
+
+ static Status ScalarScalar(KernelContext* ctx, const Scalar& arg0, const Scalar& arg1,
+ Datum* out) {
+ Status st = Status::OK();
+ if (out->scalar()->is_valid) {
+ auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0);
+ auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1);
+ BoxScalar<OutType>::Box(
+ Op::template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_val, arg1_val, &st),
+ out->scalar().get());
+ }
+ return st;
+ }
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].kind() == Datum::ARRAY) {
+ if (batch[1].kind() == Datum::ARRAY) {
+ return ArrayArray(ctx, *batch[0].array(), *batch[1].array(), out);
+ } else {
+ return ArrayScalar(ctx, *batch[0].array(), *batch[1].scalar(), out);
+ }
+ } else {
+ if (batch[1].kind() == Datum::ARRAY) {
+ return ScalarArray(ctx, *batch[0].scalar(), *batch[1].array(), out);
+ } else {
+ return ScalarScalar(ctx, *batch[0].scalar(), *batch[1].scalar(), out);
+ }
+ }
+ }
+};
+
+// An alternative to ScalarBinary that Applies a scalar operation with state on
+// only the value pairs which are not-null in both arrays
+template <typename OutType, typename Arg0Type, typename Arg1Type, typename Op>
+struct ScalarBinaryNotNullStateful {
+ using ThisType = ScalarBinaryNotNullStateful<OutType, Arg0Type, Arg1Type, Op>;
+ using OutValue = typename GetOutputType<OutType>::T;
+ using Arg0Value = typename GetViewType<Arg0Type>::T;
+ using Arg1Value = typename GetViewType<Arg1Type>::T;
+
+ Op op;
+ explicit ScalarBinaryNotNullStateful(Op op) : op(std::move(op)) {}
+
+ // NOTE: In ArrayExec<Type>, Type is really OutputType
+
+ Status ArrayArray(KernelContext* ctx, const ArrayData& arg0, const ArrayData& arg1,
+ Datum* out) {
+ Status st = Status::OK();
+ OutputArrayWriter<OutType> writer(out->mutable_array());
+ VisitTwoArrayValuesInline<Arg0Type, Arg1Type>(
+ arg0, arg1,
+ [&](Arg0Value u, Arg1Value v) {
+ writer.Write(op.template Call<OutValue, Arg0Value, Arg1Value>(ctx, u, v, &st));
+ },
+ [&]() { writer.WriteNull(); });
+ return st;
+ }
+
+ Status ArrayScalar(KernelContext* ctx, const ArrayData& arg0, const Scalar& arg1,
+ Datum* out) {
+ Status st = Status::OK();
+ OutputArrayWriter<OutType> writer(out->mutable_array());
+ if (arg1.is_valid) {
+ const auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1);
+ VisitArrayValuesInline<Arg0Type>(
+ arg0,
+ [&](Arg0Value u) {
+ writer.Write(
+ op.template Call<OutValue, Arg0Value, Arg1Value>(ctx, u, arg1_val, &st));
+ },
+ [&]() { writer.WriteNull(); });
+ } else {
+ writer.WriteAllNull(out->mutable_array()->length);
+ }
+ return st;
+ }
+
+ Status ScalarArray(KernelContext* ctx, const Scalar& arg0, const ArrayData& arg1,
+ Datum* out) {
+ Status st = Status::OK();
+ OutputArrayWriter<OutType> writer(out->mutable_array());
+ if (arg0.is_valid) {
+ const auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0);
+ VisitArrayValuesInline<Arg1Type>(
+ arg1,
+ [&](Arg1Value v) {
+ writer.Write(
+ op.template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_val, v, &st));
+ },
+ [&]() { writer.WriteNull(); });
+ } else {
+ writer.WriteAllNull(out->mutable_array()->length);
+ }
+ return st;
+ }
+
+ Status ScalarScalar(KernelContext* ctx, const Scalar& arg0, const Scalar& arg1,
+ Datum* out) {
+ Status st = Status::OK();
+ if (arg0.is_valid && arg1.is_valid) {
+ const auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0);
+ const auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1);
+ BoxScalar<OutType>::Box(
+ op.template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_val, arg1_val, &st),
+ out->scalar().get());
+ }
+ return st;
+ }
+
+ Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].kind() == Datum::ARRAY) {
+ if (batch[1].kind() == Datum::ARRAY) {
+ return ArrayArray(ctx, *batch[0].array(), *batch[1].array(), out);
+ } else {
+ return ArrayScalar(ctx, *batch[0].array(), *batch[1].scalar(), out);
+ }
+ } else {
+ if (batch[1].kind() == Datum::ARRAY) {
+ return ScalarArray(ctx, *batch[0].scalar(), *batch[1].array(), out);
+ } else {
+ return ScalarScalar(ctx, *batch[0].scalar(), *batch[1].scalar(), out);
+ }
+ }
+ }
+};
+
+// An alternative to ScalarBinary that Applies a scalar operation on only
+// the value pairs which are not-null in both arrays.
+// The operator is not stateful; if the operator requires some initialization
+// use ScalarBinaryNotNullStateful.
+template <typename OutType, typename Arg0Type, typename Arg1Type, typename Op>
+struct ScalarBinaryNotNull {
+ using OutValue = typename GetOutputType<OutType>::T;
+ using Arg0Value = typename GetViewType<Arg0Type>::T;
+ using Arg1Value = typename GetViewType<Arg1Type>::T;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // Seed kernel with dummy state
+ ScalarBinaryNotNullStateful<OutType, Arg0Type, Arg1Type, Op> kernel({});
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+// A kernel exec generator for binary kernels where both input types are the
+// same
+template <typename OutType, typename ArgType, typename Op>
+using ScalarBinaryEqualTypes = ScalarBinary<OutType, ArgType, ArgType, Op>;
+
+// A kernel exec generator for non-null binary kernels where both input types are the
+// same
+template <typename OutType, typename ArgType, typename Op>
+using ScalarBinaryNotNullEqualTypes = ScalarBinaryNotNull<OutType, ArgType, ArgType, Op>;
+
+template <typename OutType, typename ArgType, typename Op>
+using ScalarBinaryNotNullStatefulEqualTypes =
+ ScalarBinaryNotNullStateful<OutType, ArgType, ArgType, Op>;
+
+} // namespace applicator
+
+// ----------------------------------------------------------------------
+// BEGIN of kernel generator-dispatchers ("GD")
+//
+// These GD functions instantiate kernel functor templates and select one of
+// the instantiated kernels dynamically based on the data type or Type::type id
+// that is passed. This enables functions to be populated with kernels by
+// looping over vectors of data types rather than using macros or other
+// approaches.
+//
+// The kernel functor must be of the form:
+//
+// template <typename Type0, typename Type1, Args...>
+// struct FUNCTOR {
+// static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+// // IMPLEMENTATION
+// }
+// };
+//
+// When you pass FUNCTOR to a GD function, you must pass at least one static
+// type along with the functor -- this is often the fixed return type of the
+// functor. This Type0 argument is passed as the first argument to the functor
+// during instantiation. The 2nd type passed to the functor is the DataType
+// subclass corresponding to the type passed as argument (not template type) to
+// the function.
+//
+// For example, GenerateNumeric<FUNCTOR, Type0>(int32()) will select a kernel
+// instantiated like FUNCTOR<Type0, Int32Type>. Any additional variadic
+// template arguments will be passed as additional template arguments to the
+// kernel template.
+
+namespace detail {
+
+// Convenience so we can pass DataType or Type::type for the GD's
+struct GetTypeId {
+ Type::type id;
+ GetTypeId(const std::shared_ptr<DataType>& type) // NOLINT implicit construction
+ : id(type->id()) {}
+ GetTypeId(const DataType& type) // NOLINT implicit construction
+ : id(type.id()) {}
+ GetTypeId(Type::type id) // NOLINT implicit construction
+ : id(id) {}
+};
+
+} // namespace detail
+
+// GD for numeric types (integer and floating point)
+template <template <typename...> class Generator, typename Type0, typename... Args>
+ArrayKernelExec GenerateNumeric(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::INT8:
+ return Generator<Type0, Int8Type, Args...>::Exec;
+ case Type::UINT8:
+ return Generator<Type0, UInt8Type, Args...>::Exec;
+ case Type::INT16:
+ return Generator<Type0, Int16Type, Args...>::Exec;
+ case Type::UINT16:
+ return Generator<Type0, UInt16Type, Args...>::Exec;
+ case Type::INT32:
+ return Generator<Type0, Int32Type, Args...>::Exec;
+ case Type::UINT32:
+ return Generator<Type0, UInt32Type, Args...>::Exec;
+ case Type::INT64:
+ return Generator<Type0, Int64Type, Args...>::Exec;
+ case Type::UINT64:
+ return Generator<Type0, UInt64Type, Args...>::Exec;
+ case Type::FLOAT:
+ return Generator<Type0, FloatType, Args...>::Exec;
+ case Type::DOUBLE:
+ return Generator<Type0, DoubleType, Args...>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+// Generate a kernel given a templated functor for floating point types
+//
+// See "Numeric" above for description of the generator functor
+template <template <typename...> class Generator, typename Type0, typename... Args>
+ArrayKernelExec GenerateFloatingPoint(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::FLOAT:
+ return Generator<Type0, FloatType, Args...>::Exec;
+ case Type::DOUBLE:
+ return Generator<Type0, DoubleType, Args...>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+// Generate a kernel given a templated functor for integer types
+//
+// See "Numeric" above for description of the generator functor
+template <template <typename...> class Generator, typename Type0, typename... Args>
+ArrayKernelExec GenerateInteger(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::INT8:
+ return Generator<Type0, Int8Type, Args...>::Exec;
+ case Type::INT16:
+ return Generator<Type0, Int16Type, Args...>::Exec;
+ case Type::INT32:
+ return Generator<Type0, Int32Type, Args...>::Exec;
+ case Type::INT64:
+ return Generator<Type0, Int64Type, Args...>::Exec;
+ case Type::UINT8:
+ return Generator<Type0, UInt8Type, Args...>::Exec;
+ case Type::UINT16:
+ return Generator<Type0, UInt16Type, Args...>::Exec;
+ case Type::UINT32:
+ return Generator<Type0, UInt32Type, Args...>::Exec;
+ case Type::UINT64:
+ return Generator<Type0, UInt64Type, Args...>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+template <template <typename...> class Generator, typename Type0, typename... Args>
+ArrayKernelExec GeneratePhysicalInteger(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::INT8:
+ return Generator<Type0, Int8Type, Args...>::Exec;
+ case Type::INT16:
+ return Generator<Type0, Int16Type, Args...>::Exec;
+ case Type::INT32:
+ case Type::DATE32:
+ case Type::TIME32:
+ return Generator<Type0, Int32Type, Args...>::Exec;
+ case Type::INT64:
+ case Type::DATE64:
+ case Type::TIMESTAMP:
+ case Type::TIME64:
+ case Type::DURATION:
+ return Generator<Type0, Int64Type, Args...>::Exec;
+ case Type::UINT8:
+ return Generator<Type0, UInt8Type, Args...>::Exec;
+ case Type::UINT16:
+ return Generator<Type0, UInt16Type, Args...>::Exec;
+ case Type::UINT32:
+ return Generator<Type0, UInt32Type, Args...>::Exec;
+ case Type::UINT64:
+ return Generator<Type0, UInt64Type, Args...>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+template <template <typename... Args> class Generator, typename... Args>
+ArrayKernelExec GeneratePhysicalNumeric(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::INT8:
+ return Generator<Int8Type, Args...>::Exec;
+ case Type::INT16:
+ return Generator<Int16Type, Args...>::Exec;
+ case Type::INT32:
+ case Type::DATE32:
+ case Type::TIME32:
+ return Generator<Int32Type, Args...>::Exec;
+ case Type::INT64:
+ case Type::DATE64:
+ case Type::TIMESTAMP:
+ case Type::TIME64:
+ case Type::DURATION:
+ return Generator<Int64Type, Args...>::Exec;
+ case Type::UINT8:
+ return Generator<UInt8Type, Args...>::Exec;
+ case Type::UINT16:
+ return Generator<UInt16Type, Args...>::Exec;
+ case Type::UINT32:
+ return Generator<UInt32Type, Args...>::Exec;
+ case Type::UINT64:
+ return Generator<UInt64Type, Args...>::Exec;
+ case Type::FLOAT:
+ return Generator<FloatType, Args...>::Exec;
+ case Type::DOUBLE:
+ return Generator<DoubleType, Args...>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+// Generate a kernel given a templated functor for integer types
+//
+// See "Numeric" above for description of the generator functor
+template <template <typename...> class Generator, typename Type0, typename... Args>
+ArrayKernelExec GenerateSignedInteger(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::INT8:
+ return Generator<Type0, Int8Type, Args...>::Exec;
+ case Type::INT16:
+ return Generator<Type0, Int16Type, Args...>::Exec;
+ case Type::INT32:
+ return Generator<Type0, Int32Type, Args...>::Exec;
+ case Type::INT64:
+ return Generator<Type0, Int64Type, Args...>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+// Generate a kernel given a templated functor. Only a single template is
+// instantiated for each bit width, and the functor is expected to treat types
+// of the same bit width the same without utilizing any type-specific behavior
+// (e.g. int64 should be handled equivalent to uint64 or double -- all 64
+// bits).
+//
+// See "Numeric" above for description of the generator functor
+template <template <typename...> class Generator, typename... Args>
+ArrayKernelExec GenerateTypeAgnosticPrimitive(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::NA:
+ return Generator<NullType, Args...>::Exec;
+ case Type::BOOL:
+ return Generator<BooleanType, Args...>::Exec;
+ case Type::UINT8:
+ case Type::INT8:
+ return Generator<UInt8Type, Args...>::Exec;
+ case Type::UINT16:
+ case Type::INT16:
+ return Generator<UInt16Type, Args...>::Exec;
+ case Type::UINT32:
+ case Type::INT32:
+ case Type::FLOAT:
+ case Type::DATE32:
+ case Type::TIME32:
+ case Type::INTERVAL_MONTHS:
+ return Generator<UInt32Type, Args...>::Exec;
+ case Type::UINT64:
+ case Type::INT64:
+ case Type::DOUBLE:
+ case Type::DATE64:
+ case Type::TIMESTAMP:
+ case Type::TIME64:
+ case Type::DURATION:
+ case Type::INTERVAL_DAY_TIME:
+ return Generator<UInt64Type, Args...>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+// similar to GenerateTypeAgnosticPrimitive, but for variable types
+template <template <typename...> class Generator, typename... Args>
+ArrayKernelExec GenerateTypeAgnosticVarBinaryBase(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::BINARY:
+ case Type::STRING:
+ return Generator<BinaryType, Args...>::Exec;
+ case Type::LARGE_BINARY:
+ case Type::LARGE_STRING:
+ return Generator<LargeBinaryType, Args...>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+// Generate a kernel given a templated functor for base binary types. Generates
+// a single kernel for binary/string and large binary / large string. If your
+// kernel implementation needs access to the specific type at compile time,
+// please use BaseBinarySpecific.
+//
+// See "Numeric" above for description of the generator functor
+template <template <typename...> class Generator, typename Type0, typename... Args>
+ArrayKernelExec GenerateVarBinaryBase(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::BINARY:
+ case Type::STRING:
+ return Generator<Type0, BinaryType, Args...>::Exec;
+ case Type::LARGE_BINARY:
+ case Type::LARGE_STRING:
+ return Generator<Type0, LargeBinaryType, Args...>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+// See BaseBinary documentation
+template <template <typename...> class Generator, typename Type0, typename... Args>
+ArrayKernelExec GenerateVarBinary(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::BINARY:
+ return Generator<Type0, BinaryType, Args...>::Exec;
+ case Type::STRING:
+ return Generator<Type0, StringType, Args...>::Exec;
+ case Type::LARGE_BINARY:
+ return Generator<Type0, LargeBinaryType, Args...>::Exec;
+ case Type::LARGE_STRING:
+ return Generator<Type0, LargeStringType, Args...>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+// Generate a kernel given a templated functor for temporal types
+//
+// See "Numeric" above for description of the generator functor
+template <template <typename...> class Generator, typename Type0, typename... Args>
+ArrayKernelExec GenerateTemporal(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::DATE32:
+ return Generator<Type0, Date32Type, Args...>::Exec;
+ case Type::DATE64:
+ return Generator<Type0, Date64Type, Args...>::Exec;
+ case Type::DURATION:
+ return Generator<Type0, DurationType, Args...>::Exec;
+ case Type::TIME32:
+ return Generator<Type0, Time32Type, Args...>::Exec;
+ case Type::TIME64:
+ return Generator<Type0, Time64Type, Args...>::Exec;
+ case Type::TIMESTAMP:
+ return Generator<Type0, TimestampType, Args...>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+// Generate a kernel given a templated functor for decimal types
+//
+// See "Numeric" above for description of the generator functor
+template <template <typename...> class Generator, typename Type0, typename... Args>
+ArrayKernelExec GenerateDecimal(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::DECIMAL128:
+ return Generator<Type0, Decimal128Type, Args...>::Exec;
+ case Type::DECIMAL256:
+ return Generator<Type0, Decimal256Type, Args...>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+// END of kernel generator-dispatchers
+// ----------------------------------------------------------------------
+
+ARROW_EXPORT
+void EnsureDictionaryDecoded(std::vector<ValueDescr>* descrs);
+
+ARROW_EXPORT
+void EnsureDictionaryDecoded(ValueDescr* begin, size_t count);
+
+ARROW_EXPORT
+void ReplaceNullWithOtherType(std::vector<ValueDescr>* descrs);
+
+ARROW_EXPORT
+void ReplaceNullWithOtherType(ValueDescr* begin, size_t count);
+
+ARROW_EXPORT
+void ReplaceTypes(const std::shared_ptr<DataType>&, std::vector<ValueDescr>* descrs);
+
+ARROW_EXPORT
+void ReplaceTypes(const std::shared_ptr<DataType>&, ValueDescr* descrs, size_t count);
+
+ARROW_EXPORT
+std::shared_ptr<DataType> CommonNumeric(const std::vector<ValueDescr>& descrs);
+
+ARROW_EXPORT
+std::shared_ptr<DataType> CommonNumeric(const ValueDescr* begin, size_t count);
+
+ARROW_EXPORT
+std::shared_ptr<DataType> CommonTemporal(const ValueDescr* begin, size_t count);
+
+ARROW_EXPORT
+std::shared_ptr<DataType> CommonBinary(const ValueDescr* begin, size_t count);
+
+/// How to promote decimal precision/scale in CastBinaryDecimalArgs.
+enum class DecimalPromotion : uint8_t {
+ kAdd,
+ kMultiply,
+ kDivide,
+};
+
+/// Given two arguments, at least one of which is decimal, promote all
+/// to not necessarily identical types, but types which are compatible
+/// for the given operator (add/multiply/divide).
+ARROW_EXPORT
+Status CastBinaryDecimalArgs(DecimalPromotion promotion, std::vector<ValueDescr>* descrs);
+
+/// Given one or more arguments, at least one of which is decimal,
+/// promote all to an identical type.
+ARROW_EXPORT
+Status CastDecimalArgs(ValueDescr* begin, size_t count);
+
+ARROW_EXPORT
+bool HasDecimal(const std::vector<ValueDescr>& descrs);
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
new file mode 100644
index 000000000..d64143dea
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
@@ -0,0 +1,163 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+TEST(TestDispatchBest, CastBinaryDecimalArgs) {
+ std::vector<ValueDescr> args;
+ std::vector<DecimalPromotion> modes = {
+ DecimalPromotion::kAdd, DecimalPromotion::kMultiply, DecimalPromotion::kDivide};
+
+ // Any float -> all float
+ for (auto mode : modes) {
+ args = {decimal128(3, 2), float64()};
+ ASSERT_OK(CastBinaryDecimalArgs(mode, &args));
+ AssertTypeEqual(args[0].type, float64());
+ AssertTypeEqual(args[1].type, float64());
+ }
+
+ // Integer -> decimal with common scale
+ args = {decimal128(1, 0), int64()};
+ ASSERT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args));
+ AssertTypeEqual(args[0].type, decimal128(1, 0));
+ AssertTypeEqual(args[1].type, decimal128(19, 0));
+
+ // Add: rescale so all have common scale
+ args = {decimal128(3, 2), decimal128(3, -2)};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ NotImplemented, ::testing::HasSubstr("Decimals with negative scales not supported"),
+ CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args));
+}
+
+TEST(TestDispatchBest, CastDecimalArgs) {
+ std::vector<ValueDescr> args;
+
+ // Any float -> all float
+ args = {decimal128(3, 2), float64()};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, float64());
+ AssertTypeEqual(args[1].type, float64());
+
+ args = {float32(), float64(), decimal128(3, 2)};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, float64());
+ AssertTypeEqual(args[1].type, float64());
+ AssertTypeEqual(args[2].type, float64());
+
+ // Promote to common decimal width
+ args = {decimal128(3, 2), decimal256(3, 2)};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, decimal256(3, 2));
+ AssertTypeEqual(args[1].type, decimal256(3, 2));
+
+ // Rescale so all have common scale/precision
+ args = {decimal128(3, 2), decimal128(3, 0)};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, decimal128(5, 2));
+ AssertTypeEqual(args[1].type, decimal128(5, 2));
+
+ args = {decimal128(3, 2), decimal128(3, -2)};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, decimal128(7, 2));
+ AssertTypeEqual(args[1].type, decimal128(7, 2));
+
+ args = {decimal128(3, 0), decimal128(3, 1), decimal128(3, 2)};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, decimal128(5, 2));
+ AssertTypeEqual(args[1].type, decimal128(5, 2));
+ AssertTypeEqual(args[2].type, decimal128(5, 2));
+
+ // Integer -> decimal with appropriate precision
+ args = {decimal128(3, 0), int64()};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, decimal128(19, 0));
+ AssertTypeEqual(args[1].type, decimal128(19, 0));
+
+ args = {decimal128(3, 1), int64()};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, decimal128(20, 1));
+ AssertTypeEqual(args[1].type, decimal128(20, 1));
+
+ args = {decimal128(3, -1), int64()};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, decimal128(19, 0));
+ AssertTypeEqual(args[1].type, decimal128(19, 0));
+
+ // Overflow decimal128 max precision -> promote to decimal256
+ args = {decimal128(38, 0), decimal128(37, 2)};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, decimal256(40, 2));
+ AssertTypeEqual(args[1].type, decimal256(40, 2));
+
+ // Overflow decimal256 max precision
+ args = {decimal256(76, 0), decimal256(75, 1)};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr(
+ "Result precision (77) exceeds max precision of Decimal256 (76)"),
+ CastDecimalArgs(args.data(), args.size()));
+
+ // Incompatible, no cast
+ args = {decimal256(3, 2), float64(), utf8()};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, decimal256(3, 2));
+ AssertTypeEqual(args[1].type, float64());
+ AssertTypeEqual(args[2].type, utf8());
+}
+
+TEST(TestDispatchBest, CommonTemporal) {
+ std::vector<ValueDescr> args;
+
+ args = {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::NANO)};
+ AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTemporal(args.data(), args.size()));
+ args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::NANO, "UTC")};
+ AssertTypeEqual(timestamp(TimeUnit::NANO, "UTC"),
+ CommonTemporal(args.data(), args.size()));
+ args = {date32(), timestamp(TimeUnit::NANO)};
+ AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTemporal(args.data(), args.size()));
+ args = {date64(), timestamp(TimeUnit::SECOND)};
+ AssertTypeEqual(timestamp(TimeUnit::MILLI), CommonTemporal(args.data(), args.size()));
+ args = {date32(), date32()};
+ AssertTypeEqual(date32(), CommonTemporal(args.data(), args.size()));
+ args = {date64(), date64()};
+ AssertTypeEqual(date64(), CommonTemporal(args.data(), args.size()));
+ args = {date32(), date64()};
+ AssertTypeEqual(date64(), CommonTemporal(args.data(), args.size()));
+ args = {};
+ ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size()));
+ args = {float64(), int32()};
+ ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size()));
+ args = {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND, "UTC")};
+ ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size()));
+ args = {timestamp(TimeUnit::SECOND, "America/Phoenix"),
+ timestamp(TimeUnit::SECOND, "UTC")};
+ ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size()));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/common.h b/src/arrow/cpp/src/arrow/compute/kernels/common.h
new file mode 100644
index 000000000..21244320f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/common.h
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+// IWYU pragma: begin_exports
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/data.h"
+#include "arrow/buffer.h"
+#include "arrow/chunked_array.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/function.h"
+#include "arrow/compute/kernel.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/compute/registry.h"
+#include "arrow/datum.h"
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_view.h"
+
+// IWYU pragma: end_exports
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/src/arrow/cpp/src/arrow/compute/kernels/hash_aggregate.cc
new file mode 100644
index 000000000..73c8f9d26
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -0,0 +1,2659 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cmath>
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <vector>
+
+#include "arrow/buffer_builder.h"
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/exec/key_compare.h"
+#include "arrow/compute/exec/key_encode.h"
+#include "arrow/compute/exec/key_hash.h"
+#include "arrow/compute/exec/key_map.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/compute/exec_internal.h"
+#include "arrow/compute/kernel.h"
+#include "arrow/compute/kernels/aggregate_internal.h"
+#include "arrow/compute/kernels/aggregate_var_std_internal.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/row_encoder.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/record_batch.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/bitmap_writer.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/cpu_info.h"
+#include "arrow/util/int128_internal.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/task_group.h"
+#include "arrow/util/tdigest.h"
+#include "arrow/util/thread_pool.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::FirstTimeBitmapWriter;
+
+namespace compute {
+namespace internal {
+namespace {
+
+struct GrouperImpl : Grouper {
+ static Result<std::unique_ptr<GrouperImpl>> Make(const std::vector<ValueDescr>& keys,
+ ExecContext* ctx) {
+ auto impl = ::arrow::internal::make_unique<GrouperImpl>();
+
+ impl->encoders_.resize(keys.size());
+ impl->ctx_ = ctx;
+
+ for (size_t i = 0; i < keys.size(); ++i) {
+ const auto& key = keys[i].type;
+
+ if (key->id() == Type::BOOL) {
+ impl->encoders_[i] = ::arrow::internal::make_unique<BooleanKeyEncoder>();
+ continue;
+ }
+
+ if (key->id() == Type::DICTIONARY) {
+ impl->encoders_[i] =
+ ::arrow::internal::make_unique<DictionaryKeyEncoder>(key, ctx->memory_pool());
+ continue;
+ }
+
+ if (is_fixed_width(key->id())) {
+ impl->encoders_[i] = ::arrow::internal::make_unique<FixedWidthKeyEncoder>(key);
+ continue;
+ }
+
+ if (is_binary_like(key->id())) {
+ impl->encoders_[i] =
+ ::arrow::internal::make_unique<VarLengthKeyEncoder<BinaryType>>(key);
+ continue;
+ }
+
+ if (is_large_binary_like(key->id())) {
+ impl->encoders_[i] =
+ ::arrow::internal::make_unique<VarLengthKeyEncoder<LargeBinaryType>>(key);
+ continue;
+ }
+
+ return Status::NotImplemented("Keys of type ", *key);
+ }
+
+ return std::move(impl);
+ }
+
+ Result<Datum> Consume(const ExecBatch& batch) override {
+ std::vector<int32_t> offsets_batch(batch.length + 1);
+ for (int i = 0; i < batch.num_values(); ++i) {
+ encoders_[i]->AddLength(batch[i], batch.length, offsets_batch.data());
+ }
+
+ int32_t total_length = 0;
+ for (int64_t i = 0; i < batch.length; ++i) {
+ auto total_length_before = total_length;
+ total_length += offsets_batch[i];
+ offsets_batch[i] = total_length_before;
+ }
+ offsets_batch[batch.length] = total_length;
+
+ std::vector<uint8_t> key_bytes_batch(total_length);
+ std::vector<uint8_t*> key_buf_ptrs(batch.length);
+ for (int64_t i = 0; i < batch.length; ++i) {
+ key_buf_ptrs[i] = key_bytes_batch.data() + offsets_batch[i];
+ }
+
+ for (int i = 0; i < batch.num_values(); ++i) {
+ RETURN_NOT_OK(encoders_[i]->Encode(batch[i], batch.length, key_buf_ptrs.data()));
+ }
+
+ TypedBufferBuilder<uint32_t> group_ids_batch(ctx_->memory_pool());
+ RETURN_NOT_OK(group_ids_batch.Resize(batch.length));
+
+ for (int64_t i = 0; i < batch.length; ++i) {
+ int32_t key_length = offsets_batch[i + 1] - offsets_batch[i];
+ std::string key(
+ reinterpret_cast<const char*>(key_bytes_batch.data() + offsets_batch[i]),
+ key_length);
+
+ auto it_success = map_.emplace(key, num_groups_);
+ auto group_id = it_success.first->second;
+
+ if (it_success.second) {
+ // new key; update offsets and key_bytes
+ ++num_groups_;
+ auto next_key_offset = static_cast<int32_t>(key_bytes_.size());
+ key_bytes_.resize(next_key_offset + key_length);
+ offsets_.push_back(next_key_offset + key_length);
+ memcpy(key_bytes_.data() + next_key_offset, key.c_str(), key_length);
+ }
+
+ group_ids_batch.UnsafeAppend(group_id);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto group_ids, group_ids_batch.Finish());
+ return Datum(UInt32Array(batch.length, std::move(group_ids)));
+ }
+
+ uint32_t num_groups() const override { return num_groups_; }
+
+ Result<ExecBatch> GetUniques() override {
+ ExecBatch out({}, num_groups_);
+
+ std::vector<uint8_t*> key_buf_ptrs(num_groups_);
+ for (int64_t i = 0; i < num_groups_; ++i) {
+ key_buf_ptrs[i] = key_bytes_.data() + offsets_[i];
+ }
+
+ out.values.resize(encoders_.size());
+ for (size_t i = 0; i < encoders_.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(
+ out.values[i],
+ encoders_[i]->Decode(key_buf_ptrs.data(), static_cast<int32_t>(num_groups_),
+ ctx_->memory_pool()));
+ }
+
+ return out;
+ }
+
+ ExecContext* ctx_;
+ std::unordered_map<std::string, uint32_t> map_;
+ std::vector<int32_t> offsets_ = {0};
+ std::vector<uint8_t> key_bytes_;
+ uint32_t num_groups_ = 0;
+ std::vector<std::unique_ptr<KeyEncoder>> encoders_;
+};
+
+struct GrouperFastImpl : Grouper {
+ static constexpr int kBitmapPaddingForSIMD = 64; // bits
+ static constexpr int kPaddingForSIMD = 32; // bytes
+
+ static bool CanUse(const std::vector<ValueDescr>& keys) {
+#if ARROW_LITTLE_ENDIAN
+ for (size_t i = 0; i < keys.size(); ++i) {
+ const auto& key = keys[i].type;
+ if (is_large_binary_like(key->id())) {
+ return false;
+ }
+ }
+ return true;
+#else
+ return false;
+#endif
+ }
+
+ static Result<std::unique_ptr<GrouperFastImpl>> Make(
+ const std::vector<ValueDescr>& keys, ExecContext* ctx) {
+ auto impl = ::arrow::internal::make_unique<GrouperFastImpl>();
+ impl->ctx_ = ctx;
+
+ RETURN_NOT_OK(impl->temp_stack_.Init(ctx->memory_pool(), 64 * minibatch_size_max_));
+ impl->encode_ctx_.hardware_flags =
+ arrow::internal::CpuInfo::GetInstance()->hardware_flags();
+ impl->encode_ctx_.stack = &impl->temp_stack_;
+
+ auto num_columns = keys.size();
+ impl->col_metadata_.resize(num_columns);
+ impl->key_types_.resize(num_columns);
+ impl->dictionaries_.resize(num_columns);
+ for (size_t icol = 0; icol < num_columns; ++icol) {
+ const auto& key = keys[icol].type;
+ if (key->id() == Type::DICTIONARY) {
+ auto bit_width = checked_cast<const FixedWidthType&>(*key).bit_width();
+ ARROW_DCHECK(bit_width % 8 == 0);
+ impl->col_metadata_[icol] =
+ arrow::compute::KeyEncoder::KeyColumnMetadata(true, bit_width / 8);
+ } else if (key->id() == Type::BOOL) {
+ impl->col_metadata_[icol] =
+ arrow::compute::KeyEncoder::KeyColumnMetadata(true, 0);
+ } else if (is_fixed_width(key->id())) {
+ impl->col_metadata_[icol] = arrow::compute::KeyEncoder::KeyColumnMetadata(
+ true, checked_cast<const FixedWidthType&>(*key).bit_width() / 8);
+ } else if (is_binary_like(key->id())) {
+ impl->col_metadata_[icol] =
+ arrow::compute::KeyEncoder::KeyColumnMetadata(false, sizeof(uint32_t));
+ } else {
+ return Status::NotImplemented("Keys of type ", *key);
+ }
+ impl->key_types_[icol] = key;
+ }
+
+ impl->encoder_.Init(impl->col_metadata_, &impl->encode_ctx_,
+ /* row_alignment = */ sizeof(uint64_t),
+ /* string_alignment = */ sizeof(uint64_t));
+ RETURN_NOT_OK(impl->rows_.Init(ctx->memory_pool(), impl->encoder_.row_metadata()));
+ RETURN_NOT_OK(
+ impl->rows_minibatch_.Init(ctx->memory_pool(), impl->encoder_.row_metadata()));
+ impl->minibatch_size_ = impl->minibatch_size_min_;
+ GrouperFastImpl* impl_ptr = impl.get();
+ auto equal_func = [impl_ptr](
+ int num_keys_to_compare, const uint16_t* selection_may_be_null,
+ const uint32_t* group_ids, uint32_t* out_num_keys_mismatch,
+ uint16_t* out_selection_mismatch) {
+ arrow::compute::KeyCompare::CompareColumnsToRows(
+ num_keys_to_compare, selection_may_be_null, group_ids, &impl_ptr->encode_ctx_,
+ out_num_keys_mismatch, out_selection_mismatch,
+ impl_ptr->encoder_.GetBatchColumns(), impl_ptr->rows_);
+ };
+ auto append_func = [impl_ptr](int num_keys, const uint16_t* selection) {
+ RETURN_NOT_OK(impl_ptr->encoder_.EncodeSelected(&impl_ptr->rows_minibatch_,
+ num_keys, selection));
+ return impl_ptr->rows_.AppendSelectionFrom(impl_ptr->rows_minibatch_, num_keys,
+ nullptr);
+ };
+ RETURN_NOT_OK(impl->map_.init(impl->encode_ctx_.hardware_flags, ctx->memory_pool(),
+ impl->encode_ctx_.stack, impl->log_minibatch_max_,
+ equal_func, append_func));
+ impl->cols_.resize(num_columns);
+ impl->minibatch_hashes_.resize(impl->minibatch_size_max_ +
+ kPaddingForSIMD / sizeof(uint32_t));
+
+ return std::move(impl);
+ }
+
+ ~GrouperFastImpl() { map_.cleanup(); }
+
+ Result<Datum> Consume(const ExecBatch& batch) override {
+ // ARROW-14027: broadcast scalar arguments for now
+ for (int i = 0; i < batch.num_values(); i++) {
+ if (batch.values[i].is_scalar()) {
+ ExecBatch expanded = batch;
+ for (int j = i; j < expanded.num_values(); j++) {
+ if (expanded.values[j].is_scalar()) {
+ ARROW_ASSIGN_OR_RAISE(
+ expanded.values[j],
+ MakeArrayFromScalar(*expanded.values[j].scalar(), expanded.length,
+ ctx_->memory_pool()));
+ }
+ }
+ return ConsumeImpl(expanded);
+ }
+ }
+ return ConsumeImpl(batch);
+ }
+
+ Result<Datum> ConsumeImpl(const ExecBatch& batch) {
+ int64_t num_rows = batch.length;
+ int num_columns = batch.num_values();
+ // Process dictionaries
+ for (int icol = 0; icol < num_columns; ++icol) {
+ if (key_types_[icol]->id() == Type::DICTIONARY) {
+ auto data = batch[icol].array();
+ auto dict = MakeArray(data->dictionary);
+ if (dictionaries_[icol]) {
+ if (!dictionaries_[icol]->Equals(dict)) {
+ // TODO(bkietz) unify if necessary. For now, just error if any batch's
+ // dictionary differs from the first we saw for this key
+ return Status::NotImplemented("Unifying differing dictionaries");
+ }
+ } else {
+ dictionaries_[icol] = std::move(dict);
+ }
+ }
+ }
+
+ std::shared_ptr<arrow::Buffer> group_ids;
+ ARROW_ASSIGN_OR_RAISE(
+ group_ids, AllocateBuffer(sizeof(uint32_t) * num_rows, ctx_->memory_pool()));
+
+ for (int icol = 0; icol < num_columns; ++icol) {
+ const uint8_t* non_nulls = nullptr;
+ if (batch[icol].array()->buffers[0] != NULLPTR) {
+ non_nulls = batch[icol].array()->buffers[0]->data();
+ }
+ const uint8_t* fixedlen = batch[icol].array()->buffers[1]->data();
+ const uint8_t* varlen = nullptr;
+ if (!col_metadata_[icol].is_fixed_length) {
+ varlen = batch[icol].array()->buffers[2]->data();
+ }
+
+ int64_t offset = batch[icol].array()->offset;
+
+ auto col_base = arrow::compute::KeyEncoder::KeyColumnArray(
+ col_metadata_[icol], offset + num_rows, non_nulls, fixedlen, varlen);
+
+ cols_[icol] =
+ arrow::compute::KeyEncoder::KeyColumnArray(col_base, offset, num_rows);
+ }
+
+ // Split into smaller mini-batches
+ //
+ for (uint32_t start_row = 0; start_row < num_rows;) {
+ uint32_t batch_size_next = std::min(static_cast<uint32_t>(minibatch_size_),
+ static_cast<uint32_t>(num_rows) - start_row);
+
+ // Encode
+ rows_minibatch_.Clean();
+ encoder_.PrepareEncodeSelected(start_row, batch_size_next, cols_);
+
+ // Compute hash
+ Hashing::HashMultiColumn(encoder_.GetBatchColumns(), &encode_ctx_,
+ minibatch_hashes_.data());
+
+ // Map
+ auto match_bitvector =
+ util::TempVectorHolder<uint8_t>(&temp_stack_, (batch_size_next + 7) / 8);
+ {
+ auto local_slots = util::TempVectorHolder<uint8_t>(&temp_stack_, batch_size_next);
+ map_.early_filter(batch_size_next, minibatch_hashes_.data(),
+ match_bitvector.mutable_data(), local_slots.mutable_data());
+ map_.find(batch_size_next, minibatch_hashes_.data(),
+ match_bitvector.mutable_data(), local_slots.mutable_data(),
+ reinterpret_cast<uint32_t*>(group_ids->mutable_data()) + start_row);
+ }
+ auto ids = util::TempVectorHolder<uint16_t>(&temp_stack_, batch_size_next);
+ int num_ids;
+ util::BitUtil::bits_to_indexes(0, encode_ctx_.hardware_flags, batch_size_next,
+ match_bitvector.mutable_data(), &num_ids,
+ ids.mutable_data());
+
+ RETURN_NOT_OK(map_.map_new_keys(
+ num_ids, ids.mutable_data(), minibatch_hashes_.data(),
+ reinterpret_cast<uint32_t*>(group_ids->mutable_data()) + start_row));
+
+ start_row += batch_size_next;
+
+ if (minibatch_size_ * 2 <= minibatch_size_max_) {
+ minibatch_size_ *= 2;
+ }
+ }
+
+ return Datum(UInt32Array(batch.length, std::move(group_ids)));
+ }
+
+ uint32_t num_groups() const override { return static_cast<uint32_t>(rows_.length()); }
+
+ // Make sure padded buffers end up with the right logical size
+
+ Result<std::shared_ptr<Buffer>> AllocatePaddedBitmap(int64_t length) {
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr<Buffer> buf,
+ AllocateBitmap(length + kBitmapPaddingForSIMD, ctx_->memory_pool()));
+ return SliceMutableBuffer(buf, 0, BitUtil::BytesForBits(length));
+ }
+
+ Result<std::shared_ptr<Buffer>> AllocatePaddedBuffer(int64_t size) {
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr<Buffer> buf,
+ AllocateBuffer(size + kBitmapPaddingForSIMD, ctx_->memory_pool()));
+ return SliceMutableBuffer(buf, 0, size);
+ }
+
+ Result<ExecBatch> GetUniques() override {
+ auto num_columns = static_cast<uint32_t>(col_metadata_.size());
+ int64_t num_groups = rows_.length();
+
+ std::vector<std::shared_ptr<Buffer>> non_null_bufs(num_columns);
+ std::vector<std::shared_ptr<Buffer>> fixedlen_bufs(num_columns);
+ std::vector<std::shared_ptr<Buffer>> varlen_bufs(num_columns);
+
+ for (size_t i = 0; i < num_columns; ++i) {
+ ARROW_ASSIGN_OR_RAISE(non_null_bufs[i], AllocatePaddedBitmap(num_groups));
+ if (col_metadata_[i].is_fixed_length) {
+ if (col_metadata_[i].fixed_length == 0) {
+ ARROW_ASSIGN_OR_RAISE(fixedlen_bufs[i], AllocatePaddedBitmap(num_groups));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ fixedlen_bufs[i],
+ AllocatePaddedBuffer(num_groups * col_metadata_[i].fixed_length));
+ }
+ } else {
+ ARROW_ASSIGN_OR_RAISE(fixedlen_bufs[i],
+ AllocatePaddedBuffer((num_groups + 1) * sizeof(uint32_t)));
+ }
+ cols_[i] = arrow::compute::KeyEncoder::KeyColumnArray(
+ col_metadata_[i], num_groups, non_null_bufs[i]->mutable_data(),
+ fixedlen_bufs[i]->mutable_data(), nullptr);
+ }
+
+ for (int64_t start_row = 0; start_row < num_groups;) {
+ int64_t batch_size_next =
+ std::min(num_groups - start_row, static_cast<int64_t>(minibatch_size_max_));
+ encoder_.DecodeFixedLengthBuffers(start_row, start_row, batch_size_next, rows_,
+ &cols_);
+ start_row += batch_size_next;
+ }
+
+ if (!rows_.metadata().is_fixed_length) {
+ for (size_t i = 0; i < num_columns; ++i) {
+ if (!col_metadata_[i].is_fixed_length) {
+ auto varlen_size =
+ reinterpret_cast<const uint32_t*>(fixedlen_bufs[i]->data())[num_groups];
+ ARROW_ASSIGN_OR_RAISE(varlen_bufs[i], AllocatePaddedBuffer(varlen_size));
+ cols_[i] = arrow::compute::KeyEncoder::KeyColumnArray(
+ col_metadata_[i], num_groups, non_null_bufs[i]->mutable_data(),
+ fixedlen_bufs[i]->mutable_data(), varlen_bufs[i]->mutable_data());
+ }
+ }
+
+ for (int64_t start_row = 0; start_row < num_groups;) {
+ int64_t batch_size_next =
+ std::min(num_groups - start_row, static_cast<int64_t>(minibatch_size_max_));
+ encoder_.DecodeVaryingLengthBuffers(start_row, start_row, batch_size_next, rows_,
+ &cols_);
+ start_row += batch_size_next;
+ }
+ }
+
+ ExecBatch out({}, num_groups);
+ out.values.resize(num_columns);
+ for (size_t i = 0; i < num_columns; ++i) {
+ auto valid_count = arrow::internal::CountSetBits(
+ non_null_bufs[i]->data(), /*offset=*/0, static_cast<int64_t>(num_groups));
+ int null_count = static_cast<int>(num_groups) - static_cast<int>(valid_count);
+
+ if (col_metadata_[i].is_fixed_length) {
+ out.values[i] = ArrayData::Make(
+ key_types_[i], num_groups,
+ {std::move(non_null_bufs[i]), std::move(fixedlen_bufs[i])}, null_count);
+ } else {
+ out.values[i] =
+ ArrayData::Make(key_types_[i], num_groups,
+ {std::move(non_null_bufs[i]), std::move(fixedlen_bufs[i]),
+ std::move(varlen_bufs[i])},
+ null_count);
+ }
+ }
+
+ // Process dictionaries
+ for (size_t icol = 0; icol < num_columns; ++icol) {
+ if (key_types_[icol]->id() == Type::DICTIONARY) {
+ if (dictionaries_[icol]) {
+ out.values[icol].array()->dictionary = dictionaries_[icol]->data();
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto dict, MakeArrayOfNull(key_types_[icol], 0));
+ out.values[icol].array()->dictionary = dict->data();
+ }
+ }
+ }
+
+ return out;
+ }
+
+ static constexpr int log_minibatch_max_ = 10;
+ static constexpr int minibatch_size_max_ = 1 << log_minibatch_max_;
+ static constexpr int minibatch_size_min_ = 128;
+ int minibatch_size_;
+
+ ExecContext* ctx_;
+ arrow::util::TempVectorStack temp_stack_;
+ arrow::compute::KeyEncoder::KeyEncoderContext encode_ctx_;
+
+ std::vector<std::shared_ptr<arrow::DataType>> key_types_;
+ std::vector<arrow::compute::KeyEncoder::KeyColumnMetadata> col_metadata_;
+ std::vector<arrow::compute::KeyEncoder::KeyColumnArray> cols_;
+ std::vector<uint32_t> minibatch_hashes_;
+
+ std::vector<std::shared_ptr<Array>> dictionaries_;
+
+ arrow::compute::KeyEncoder::KeyRowArray rows_;
+ arrow::compute::KeyEncoder::KeyRowArray rows_minibatch_;
+ arrow::compute::KeyEncoder encoder_;
+ arrow::compute::SwissTable map_;
+};
+
+/// C++ abstract base class for the HashAggregateKernel interface.
+/// Implementations should be default constructible and perform initialization in
+/// Init().
+struct GroupedAggregator : KernelState {
+ virtual Status Init(ExecContext*, const FunctionOptions*) = 0;
+
+ virtual Status Resize(int64_t new_num_groups) = 0;
+
+ virtual Status Consume(const ExecBatch& batch) = 0;
+
+ virtual Status Merge(GroupedAggregator&& other, const ArrayData& group_id_mapping) = 0;
+
+ virtual Result<Datum> Finalize() = 0;
+
+ virtual std::shared_ptr<DataType> out_type() const = 0;
+};
+
+template <typename Impl>
+Result<std::unique_ptr<KernelState>> HashAggregateInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ auto impl = ::arrow::internal::make_unique<Impl>();
+ RETURN_NOT_OK(impl->Init(ctx->exec_context(), args.options));
+ return std::move(impl);
+}
+
+Status HashAggregateResize(KernelContext* ctx, int64_t num_groups) {
+ return checked_cast<GroupedAggregator*>(ctx->state())->Resize(num_groups);
+}
+Status HashAggregateConsume(KernelContext* ctx, const ExecBatch& batch) {
+ return checked_cast<GroupedAggregator*>(ctx->state())->Consume(batch);
+}
+Status HashAggregateMerge(KernelContext* ctx, KernelState&& other,
+ const ArrayData& group_id_mapping) {
+ return checked_cast<GroupedAggregator*>(ctx->state())
+ ->Merge(checked_cast<GroupedAggregator&&>(other), group_id_mapping);
+}
+Status HashAggregateFinalize(KernelContext* ctx, Datum* out) {
+ return checked_cast<GroupedAggregator*>(ctx->state())->Finalize().Value(out);
+}
+
+HashAggregateKernel MakeKernel(InputType argument_type, KernelInit init) {
+ HashAggregateKernel kernel;
+ kernel.init = std::move(init);
+ kernel.signature = KernelSignature::Make(
+ {std::move(argument_type), InputType::Array(Type::UINT32)},
+ OutputType(
+ [](KernelContext* ctx, const std::vector<ValueDescr>&) -> Result<ValueDescr> {
+ return checked_cast<GroupedAggregator*>(ctx->state())->out_type();
+ }));
+ kernel.resize = HashAggregateResize;
+ kernel.consume = HashAggregateConsume;
+ kernel.merge = HashAggregateMerge;
+ kernel.finalize = HashAggregateFinalize;
+ return kernel;
+}
+
+Status AddHashAggKernels(
+ const std::vector<std::shared_ptr<DataType>>& types,
+ Result<HashAggregateKernel> make_kernel(const std::shared_ptr<DataType>&),
+ HashAggregateFunction* function) {
+ for (const auto& ty : types) {
+ ARROW_ASSIGN_OR_RAISE(auto kernel, make_kernel(ty));
+ RETURN_NOT_OK(function->AddKernel(std::move(kernel)));
+ }
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Helpers for more easily implementing hash aggregates
+
+template <typename T>
+struct GroupedValueTraits {
+ using CType = typename TypeTraits<T>::CType;
+
+ static CType Get(const CType* values, uint32_t g) { return values[g]; }
+ static void Set(CType* values, uint32_t g, CType v) { values[g] = v; }
+};
+template <>
+struct GroupedValueTraits<BooleanType> {
+ static bool Get(const uint8_t* values, uint32_t g) {
+ return BitUtil::GetBit(values, g);
+ }
+ static void Set(uint8_t* values, uint32_t g, bool v) {
+ BitUtil::SetBitTo(values, g, v);
+ }
+};
+
+template <typename Type, typename ConsumeValue, typename ConsumeNull>
+void VisitGroupedValues(const ExecBatch& batch, ConsumeValue&& valid_func,
+ ConsumeNull&& null_func) {
+ auto g = batch[1].array()->GetValues<uint32_t>(1);
+ if (batch[0].is_array()) {
+ VisitArrayValuesInline<Type>(
+ *batch[0].array(),
+ [&](typename TypeTraits<Type>::CType val) { valid_func(*g++, val); },
+ [&]() { null_func(*g++); });
+ return;
+ }
+ const auto& input = *batch[0].scalar();
+ if (input.is_valid) {
+ const auto val = UnboxScalar<Type>::Unbox(input);
+ for (int64_t i = 0; i < batch.length; i++) {
+ valid_func(*g++, val);
+ }
+ } else {
+ for (int64_t i = 0; i < batch.length; i++) {
+ null_func(*g++);
+ }
+ }
+}
+
+template <typename Type, typename ConsumeValue>
+void VisitGroupedValuesNonNull(const ExecBatch& batch, ConsumeValue&& valid_func) {
+ VisitGroupedValues<Type>(batch, std::forward<ConsumeValue>(valid_func),
+ [](uint32_t) {});
+}
+
+// ----------------------------------------------------------------------
+// Count implementation
+
+struct GroupedCountImpl : public GroupedAggregator {
+ Status Init(ExecContext* ctx, const FunctionOptions* options) override {
+ options_ = checked_cast<const CountOptions&>(*options);
+ counts_ = BufferBuilder(ctx->memory_pool());
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ auto added_groups = new_num_groups - num_groups_;
+ num_groups_ = new_num_groups;
+ return counts_.Append(added_groups * sizeof(int64_t), 0);
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ auto other = checked_cast<GroupedCountImpl*>(&raw_other);
+
+ auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
+ auto other_counts = reinterpret_cast<const int64_t*>(other->counts_.mutable_data());
+
+ auto g = group_id_mapping.GetValues<uint32_t>(1);
+ for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
+ counts[*g] += other_counts[other_g];
+ }
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override {
+ auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
+ auto g_begin = batch[1].array()->GetValues<uint32_t>(1);
+
+ if (options_.mode == CountOptions::ALL) {
+ for (int64_t i = 0; i < batch.length; ++i, ++g_begin) {
+ counts[*g_begin] += 1;
+ }
+ } else if (batch[0].is_array()) {
+ const auto& input = batch[0].array();
+ if (options_.mode == CountOptions::ONLY_VALID) {
+ arrow::internal::VisitSetBitRunsVoid(input->buffers[0], input->offset,
+ input->length,
+ [&](int64_t offset, int64_t length) {
+ auto g = g_begin + offset;
+ for (int64_t i = 0; i < length; ++i, ++g) {
+ counts[*g] += 1;
+ }
+ });
+ } else { // ONLY_NULL
+ if (input->MayHaveNulls()) {
+ auto end = input->offset + input->length;
+ for (int64_t i = input->offset; i < end; ++i, ++g_begin) {
+ counts[*g_begin] += !BitUtil::GetBit(input->buffers[0]->data(), i);
+ }
+ }
+ }
+ } else {
+ const auto& input = *batch[0].scalar();
+ if (options_.mode == CountOptions::ONLY_VALID) {
+ for (int64_t i = 0; i < batch.length; ++i, ++g_begin) {
+ counts[*g_begin] += input.is_valid;
+ }
+ } else { // ONLY_NULL
+ for (int64_t i = 0; i < batch.length; ++i, ++g_begin) {
+ counts[*g_begin] += !input.is_valid;
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ Result<Datum> Finalize() override {
+ ARROW_ASSIGN_OR_RAISE(auto counts, counts_.Finish());
+ return std::make_shared<Int64Array>(num_groups_, std::move(counts));
+ }
+
+ std::shared_ptr<DataType> out_type() const override { return int64(); }
+
+ int64_t num_groups_ = 0;
+ CountOptions options_;
+ BufferBuilder counts_;
+};
+
+// ----------------------------------------------------------------------
+// Sum/Mean/Product implementation
+
+template <typename Type, typename Impl>
+struct GroupedReducingAggregator : public GroupedAggregator {
+ using AccType = typename FindAccumulatorType<Type>::Type;
+ using CType = typename TypeTraits<AccType>::CType;
+ using InputCType = typename TypeTraits<Type>::CType;
+
+ Status Init(ExecContext* ctx, const FunctionOptions* options) override {
+ pool_ = ctx->memory_pool();
+ options_ = checked_cast<const ScalarAggregateOptions&>(*options);
+ reduced_ = TypedBufferBuilder<CType>(pool_);
+ counts_ = TypedBufferBuilder<int64_t>(pool_);
+ no_nulls_ = TypedBufferBuilder<bool>(pool_);
+ // out_type_ initialized by SumInit
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ auto added_groups = new_num_groups - num_groups_;
+ num_groups_ = new_num_groups;
+ RETURN_NOT_OK(reduced_.Append(added_groups, Impl::NullValue(*out_type_)));
+ RETURN_NOT_OK(counts_.Append(added_groups, 0));
+ RETURN_NOT_OK(no_nulls_.Append(added_groups, true));
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override {
+ CType* reduced = reduced_.mutable_data();
+ int64_t* counts = counts_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
+
+ VisitGroupedValues<Type>(
+ batch,
+ [&](uint32_t g, InputCType value) {
+ reduced[g] = Impl::Reduce(*out_type_, reduced[g], value);
+ counts[g]++;
+ },
+ [&](uint32_t g) { BitUtil::SetBitTo(no_nulls, g, false); });
+ return Status::OK();
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ auto other = checked_cast<GroupedReducingAggregator<Type, Impl>*>(&raw_other);
+
+ CType* reduced = reduced_.mutable_data();
+ int64_t* counts = counts_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
+
+ const CType* other_reduced = other->reduced_.data();
+ const int64_t* other_counts = other->counts_.data();
+ const uint8_t* other_no_nulls = no_nulls_.mutable_data();
+
+ auto g = group_id_mapping.GetValues<uint32_t>(1);
+ for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
+ counts[*g] += other_counts[other_g];
+ reduced[*g] = Impl::Reduce(*out_type_, reduced[*g], other_reduced[other_g]);
+ BitUtil::SetBitTo(
+ no_nulls, *g,
+ BitUtil::GetBit(no_nulls, *g) && BitUtil::GetBit(other_no_nulls, other_g));
+ }
+ return Status::OK();
+ }
+
+ // Generate the values/nulls buffers
+ static Result<std::shared_ptr<Buffer>> Finish(MemoryPool* pool,
+ const ScalarAggregateOptions& options,
+ const int64_t* counts,
+ TypedBufferBuilder<CType>* reduced,
+ int64_t num_groups, int64_t* null_count,
+ std::shared_ptr<Buffer>* null_bitmap) {
+ for (int64_t i = 0; i < num_groups; ++i) {
+ if (counts[i] >= options.min_count) continue;
+
+ if ((*null_bitmap) == nullptr) {
+ ARROW_ASSIGN_OR_RAISE(*null_bitmap, AllocateBitmap(num_groups, pool));
+ BitUtil::SetBitsTo((*null_bitmap)->mutable_data(), 0, num_groups, true);
+ }
+
+ (*null_count)++;
+ BitUtil::SetBitTo((*null_bitmap)->mutable_data(), i, false);
+ }
+ return reduced->Finish();
+ }
+
+ Result<Datum> Finalize() override {
+ std::shared_ptr<Buffer> null_bitmap = nullptr;
+ const int64_t* counts = counts_.data();
+ int64_t null_count = 0;
+
+ ARROW_ASSIGN_OR_RAISE(auto values,
+ Impl::Finish(pool_, options_, counts, &reduced_, num_groups_,
+ &null_count, &null_bitmap));
+
+ if (!options_.skip_nulls) {
+ null_count = kUnknownNullCount;
+ if (null_bitmap) {
+ arrow::internal::BitmapAnd(null_bitmap->data(), /*left_offset=*/0,
+ no_nulls_.data(), /*right_offset=*/0, num_groups_,
+ /*out_offset=*/0, null_bitmap->mutable_data());
+ } else {
+ ARROW_ASSIGN_OR_RAISE(null_bitmap, no_nulls_.Finish());
+ }
+ }
+
+ return ArrayData::Make(out_type(), num_groups_,
+ {std::move(null_bitmap), std::move(values)}, null_count);
+ }
+
+ std::shared_ptr<DataType> out_type() const override { return out_type_; }
+
+ int64_t num_groups_ = 0;
+ ScalarAggregateOptions options_;
+ TypedBufferBuilder<CType> reduced_;
+ TypedBufferBuilder<int64_t> counts_;
+ TypedBufferBuilder<bool> no_nulls_;
+ std::shared_ptr<DataType> out_type_;
+ MemoryPool* pool_;
+};
+
+// ----------------------------------------------------------------------
+// Sum implementation
+
+template <typename Type>
+struct GroupedSumImpl : public GroupedReducingAggregator<Type, GroupedSumImpl<Type>> {
+ using Base = GroupedReducingAggregator<Type, GroupedSumImpl<Type>>;
+ using CType = typename Base::CType;
+ using InputCType = typename Base::InputCType;
+
+ // Default value for a group
+ static CType NullValue(const DataType&) { return CType(0); }
+
+ template <typename T = Type>
+ static enable_if_number<T, CType> Reduce(const DataType&, const CType u,
+ const InputCType v) {
+ return static_cast<CType>(to_unsigned(u) + to_unsigned(static_cast<CType>(v)));
+ }
+
+ static CType Reduce(const DataType&, const CType u, const CType v) {
+ return static_cast<CType>(to_unsigned(u) + to_unsigned(v));
+ }
+
+ using Base::Finish;
+};
+
+template <template <typename T> class Impl, typename T>
+Result<std::unique_ptr<KernelState>> SumInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ ARROW_ASSIGN_OR_RAISE(auto impl, HashAggregateInit<Impl<T>>(ctx, args));
+ static_cast<Impl<T>*>(impl.get())->out_type_ =
+ TypeTraits<typename Impl<T>::AccType>::type_singleton();
+ return std::move(impl);
+}
+
+template <typename Impl>
+Result<std::unique_ptr<KernelState>> DecimalSumInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ ARROW_ASSIGN_OR_RAISE(auto impl, HashAggregateInit<Impl>(ctx, args));
+ static_cast<Impl*>(impl.get())->out_type_ = args.inputs[0].type;
+ return std::move(impl);
+}
+
+struct GroupedSumFactory {
+ template <typename T, typename AccType = typename FindAccumulatorType<T>::Type>
+ Status Visit(const T&) {
+ kernel = MakeKernel(std::move(argument_type), SumInit<GroupedSumImpl, T>);
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal128Type&) {
+ kernel = MakeKernel(std::move(argument_type),
+ DecimalSumInit<GroupedSumImpl<Decimal128Type>>);
+ return Status::OK();
+ }
+ Status Visit(const Decimal256Type&) {
+ kernel = MakeKernel(std::move(argument_type),
+ DecimalSumInit<GroupedSumImpl<Decimal256Type>>);
+ return Status::OK();
+ }
+
+ Status Visit(const HalfFloatType& type) {
+ return Status::NotImplemented("Summing data of type ", type);
+ }
+
+ Status Visit(const DataType& type) {
+ return Status::NotImplemented("Summing data of type ", type);
+ }
+
+ static Result<HashAggregateKernel> Make(const std::shared_ptr<DataType>& type) {
+ GroupedSumFactory factory;
+ factory.argument_type = InputType::Array(type->id());
+ RETURN_NOT_OK(VisitTypeInline(*type, &factory));
+ return std::move(factory.kernel);
+ }
+
+ HashAggregateKernel kernel;
+ InputType argument_type;
+};
+
+// ----------------------------------------------------------------------
+// Product implementation
+
+template <typename Type>
+struct GroupedProductImpl final
+ : public GroupedReducingAggregator<Type, GroupedProductImpl<Type>> {
+ using Base = GroupedReducingAggregator<Type, GroupedProductImpl<Type>>;
+ using AccType = typename Base::AccType;
+ using CType = typename Base::CType;
+ using InputCType = typename Base::InputCType;
+
+ static CType NullValue(const DataType& out_type) {
+ return MultiplyTraits<AccType>::one(out_type);
+ }
+
+ template <typename T = Type>
+ static enable_if_number<T, CType> Reduce(const DataType& out_type, const CType u,
+ const InputCType v) {
+ return MultiplyTraits<AccType>::Multiply(out_type, u, static_cast<CType>(v));
+ }
+
+ static CType Reduce(const DataType& out_type, const CType u, const CType v) {
+ return MultiplyTraits<AccType>::Multiply(out_type, u, v);
+ }
+
+ using Base::Finish;
+};
+
+struct GroupedProductFactory {
+ template <typename T, typename AccType = typename FindAccumulatorType<T>::Type>
+ Status Visit(const T&) {
+ kernel = MakeKernel(std::move(argument_type), SumInit<GroupedProductImpl, T>);
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal128Type&) {
+ kernel = MakeKernel(std::move(argument_type),
+ DecimalSumInit<GroupedProductImpl<Decimal128Type>>);
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal256Type&) {
+ kernel = MakeKernel(std::move(argument_type),
+ DecimalSumInit<GroupedProductImpl<Decimal256Type>>);
+ return Status::OK();
+ }
+
+ Status Visit(const HalfFloatType& type) {
+ return Status::NotImplemented("Taking product of data of type ", type);
+ }
+
+ Status Visit(const DataType& type) {
+ return Status::NotImplemented("Taking product of data of type ", type);
+ }
+
+ static Result<HashAggregateKernel> Make(const std::shared_ptr<DataType>& type) {
+ GroupedProductFactory factory;
+ factory.argument_type = InputType::Array(type->id());
+ RETURN_NOT_OK(VisitTypeInline(*type, &factory));
+ return std::move(factory.kernel);
+ }
+
+ HashAggregateKernel kernel;
+ InputType argument_type;
+};
+
+// ----------------------------------------------------------------------
+// Mean implementation
+
+template <typename Type>
+struct GroupedMeanImpl : public GroupedReducingAggregator<Type, GroupedMeanImpl<Type>> {
+ using Base = GroupedReducingAggregator<Type, GroupedMeanImpl<Type>>;
+ using CType = typename Base::CType;
+ using InputCType = typename Base::InputCType;
+ using MeanType =
+ typename std::conditional<is_decimal_type<Type>::value, CType, double>::type;
+
+ static CType NullValue(const DataType&) { return CType(0); }
+
+ template <typename T = Type>
+ static enable_if_number<T, CType> Reduce(const DataType&, const CType u,
+ const InputCType v) {
+ return static_cast<CType>(to_unsigned(u) + to_unsigned(static_cast<CType>(v)));
+ }
+
+ static CType Reduce(const DataType&, const CType u, const CType v) {
+ return static_cast<CType>(to_unsigned(u) + to_unsigned(v));
+ }
+
+ static Result<std::shared_ptr<Buffer>> Finish(MemoryPool* pool,
+ const ScalarAggregateOptions& options,
+ const int64_t* counts,
+ TypedBufferBuilder<CType>* reduced_,
+ int64_t num_groups, int64_t* null_count,
+ std::shared_ptr<Buffer>* null_bitmap) {
+ const CType* reduced = reduced_->data();
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> values,
+ AllocateBuffer(num_groups * sizeof(MeanType), pool));
+ MeanType* means = reinterpret_cast<MeanType*>(values->mutable_data());
+ for (int64_t i = 0; i < num_groups; ++i) {
+ if (counts[i] >= options.min_count) {
+ means[i] = static_cast<MeanType>(reduced[i]) / counts[i];
+ continue;
+ }
+ means[i] = MeanType(0);
+
+ if ((*null_bitmap) == nullptr) {
+ ARROW_ASSIGN_OR_RAISE(*null_bitmap, AllocateBitmap(num_groups, pool));
+ BitUtil::SetBitsTo((*null_bitmap)->mutable_data(), 0, num_groups, true);
+ }
+
+ (*null_count)++;
+ BitUtil::SetBitTo((*null_bitmap)->mutable_data(), i, false);
+ }
+ return std::move(values);
+ }
+
+ std::shared_ptr<DataType> out_type() const override {
+ if (is_decimal_type<Type>::value) return this->out_type_;
+ return float64();
+ }
+};
+
+struct GroupedMeanFactory {
+ template <typename T, typename AccType = typename FindAccumulatorType<T>::Type>
+ Status Visit(const T&) {
+ kernel = MakeKernel(std::move(argument_type), SumInit<GroupedMeanImpl, T>);
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal128Type&) {
+ kernel = MakeKernel(std::move(argument_type),
+ DecimalSumInit<GroupedMeanImpl<Decimal128Type>>);
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal256Type&) {
+ kernel = MakeKernel(std::move(argument_type),
+ DecimalSumInit<GroupedMeanImpl<Decimal256Type>>);
+ return Status::OK();
+ }
+
+ Status Visit(const HalfFloatType& type) {
+ return Status::NotImplemented("Computing mean of type ", type);
+ }
+
+ Status Visit(const DataType& type) {
+ return Status::NotImplemented("Computing mean of type ", type);
+ }
+
+ static Result<HashAggregateKernel> Make(const std::shared_ptr<DataType>& type) {
+ GroupedMeanFactory factory;
+ factory.argument_type = InputType::Array(type->id());
+ RETURN_NOT_OK(VisitTypeInline(*type, &factory));
+ return std::move(factory.kernel);
+ }
+
+ HashAggregateKernel kernel;
+ InputType argument_type;
+};
+
+// Variance/Stdev implementation
+
+using arrow::internal::int128_t;
+
+template <typename Type>
+struct GroupedVarStdImpl : public GroupedAggregator {
+ using CType = typename Type::c_type;
+
+ Status Init(ExecContext* ctx, const FunctionOptions* options) override {
+ options_ = *checked_cast<const VarianceOptions*>(options);
+ ctx_ = ctx;
+ pool_ = ctx->memory_pool();
+ counts_ = TypedBufferBuilder<int64_t>(pool_);
+ means_ = TypedBufferBuilder<double>(pool_);
+ m2s_ = TypedBufferBuilder<double>(pool_);
+ no_nulls_ = TypedBufferBuilder<bool>(pool_);
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ auto added_groups = new_num_groups - num_groups_;
+ num_groups_ = new_num_groups;
+ RETURN_NOT_OK(counts_.Append(added_groups, 0));
+ RETURN_NOT_OK(means_.Append(added_groups, 0));
+ RETURN_NOT_OK(m2s_.Append(added_groups, 0));
+ RETURN_NOT_OK(no_nulls_.Append(added_groups, true));
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override { return ConsumeImpl(batch); }
+
+ // float/double/int64: calculate `m2` (sum((X-mean)^2)) with `two pass algorithm`
+ // (see aggregate_var_std.cc)
+ template <typename T = Type>
+ enable_if_t<is_floating_type<T>::value || (sizeof(CType) > 4), Status> ConsumeImpl(
+ const ExecBatch& batch) {
+ using SumType =
+ typename std::conditional<is_floating_type<T>::value, double, int128_t>::type;
+
+ GroupedVarStdImpl<Type> state;
+ RETURN_NOT_OK(state.Init(ctx_, &options_));
+ RETURN_NOT_OK(state.Resize(num_groups_));
+ int64_t* counts = state.counts_.mutable_data();
+ double* means = state.means_.mutable_data();
+ double* m2s = state.m2s_.mutable_data();
+ uint8_t* no_nulls = state.no_nulls_.mutable_data();
+
+ // XXX this uses naive summation; we should switch to pairwise summation as was
+ // done for the scalar aggregate kernel in ARROW-11567
+ std::vector<SumType> sums(num_groups_);
+ VisitGroupedValues<Type>(
+ batch,
+ [&](uint32_t g, typename TypeTraits<Type>::CType value) {
+ sums[g] += value;
+ counts[g]++;
+ },
+ [&](uint32_t g) { BitUtil::ClearBit(no_nulls, g); });
+
+ for (int64_t i = 0; i < num_groups_; i++) {
+ means[i] = static_cast<double>(sums[i]) / counts[i];
+ }
+
+ VisitGroupedValuesNonNull<Type>(
+ batch, [&](uint32_t g, typename TypeTraits<Type>::CType value) {
+ const double v = static_cast<double>(value);
+ m2s[g] += (v - means[g]) * (v - means[g]);
+ });
+
+ ARROW_ASSIGN_OR_RAISE(auto mapping,
+ AllocateBuffer(num_groups_ * sizeof(uint32_t), pool_));
+ for (uint32_t i = 0; static_cast<int64_t>(i) < num_groups_; i++) {
+ reinterpret_cast<uint32_t*>(mapping->mutable_data())[i] = i;
+ }
+ ArrayData group_id_mapping(uint32(), num_groups_, {nullptr, std::move(mapping)},
+ /*null_count=*/0);
+ return this->Merge(std::move(state), group_id_mapping);
+ }
+
+ // int32/16/8: textbook one pass algorithm with integer arithmetic (see
+ // aggregate_var_std.cc)
+ template <typename T = Type>
+ enable_if_t<is_integer_type<T>::value && (sizeof(CType) <= 4), Status> ConsumeImpl(
+ const ExecBatch& batch) {
+ // max number of elements that sum will not overflow int64 (2Gi int32 elements)
+ // for uint32: 0 <= sum < 2^63 (int64 >= 0)
+ // for int32: -2^62 <= sum < 2^62
+ constexpr int64_t max_length = 1ULL << (63 - sizeof(CType) * 8);
+
+ const auto g = batch[1].array()->GetValues<uint32_t>(1);
+ if (batch[0].is_scalar() && !batch[0].scalar()->is_valid) {
+ uint8_t* no_nulls = no_nulls_.mutable_data();
+ for (int64_t i = 0; i < batch.length; i++) {
+ BitUtil::ClearBit(no_nulls, g[i]);
+ }
+ return Status::OK();
+ }
+
+ std::vector<IntegerVarStd<Type>> var_std(num_groups_);
+
+ ARROW_ASSIGN_OR_RAISE(auto mapping,
+ AllocateBuffer(num_groups_ * sizeof(uint32_t), pool_));
+ for (uint32_t i = 0; static_cast<int64_t>(i) < num_groups_; i++) {
+ reinterpret_cast<uint32_t*>(mapping->mutable_data())[i] = i;
+ }
+ ArrayData group_id_mapping(uint32(), num_groups_, {nullptr, std::move(mapping)},
+ /*null_count=*/0);
+
+ for (int64_t start_index = 0; start_index < batch.length; start_index += max_length) {
+ // process in chunks that overflow will never happen
+
+ // reset state
+ var_std.clear();
+ var_std.resize(num_groups_);
+ GroupedVarStdImpl<Type> state;
+ RETURN_NOT_OK(state.Init(ctx_, &options_));
+ RETURN_NOT_OK(state.Resize(num_groups_));
+ int64_t* other_counts = state.counts_.mutable_data();
+ double* other_means = state.means_.mutable_data();
+ double* other_m2s = state.m2s_.mutable_data();
+ uint8_t* other_no_nulls = state.no_nulls_.mutable_data();
+
+ if (batch[0].is_array()) {
+ const auto& array = *batch[0].array();
+ const CType* values = array.GetValues<CType>(1);
+ auto visit_values = [&](int64_t pos, int64_t len) {
+ for (int64_t i = 0; i < len; ++i) {
+ const int64_t index = start_index + pos + i;
+ const auto value = values[index];
+ var_std[g[index]].ConsumeOne(value);
+ }
+ };
+
+ if (array.MayHaveNulls()) {
+ arrow::internal::BitRunReader reader(
+ array.buffers[0]->data(), array.offset + start_index,
+ std::min(max_length, batch.length - start_index));
+ int64_t position = 0;
+ while (true) {
+ auto run = reader.NextRun();
+ if (run.length == 0) break;
+ if (run.set) {
+ visit_values(position, run.length);
+ } else {
+ for (int64_t i = 0; i < run.length; ++i) {
+ BitUtil::ClearBit(other_no_nulls, g[start_index + position + i]);
+ }
+ }
+ position += run.length;
+ }
+ } else {
+ visit_values(0, array.length);
+ }
+ } else {
+ const auto value = UnboxScalar<Type>::Unbox(*batch[0].scalar());
+ for (int64_t i = 0; i < std::min(max_length, batch.length - start_index); ++i) {
+ const int64_t index = start_index + i;
+ var_std[g[index]].ConsumeOne(value);
+ }
+ }
+
+ for (int64_t i = 0; i < num_groups_; i++) {
+ if (var_std[i].count == 0) continue;
+
+ other_counts[i] = var_std[i].count;
+ other_means[i] = var_std[i].mean();
+ other_m2s[i] = var_std[i].m2();
+ }
+ RETURN_NOT_OK(this->Merge(std::move(state), group_id_mapping));
+ }
+ return Status::OK();
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ // Combine m2 from two chunks (see aggregate_var_std.cc)
+ auto other = checked_cast<GroupedVarStdImpl*>(&raw_other);
+
+ int64_t* counts = counts_.mutable_data();
+ double* means = means_.mutable_data();
+ double* m2s = m2s_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
+
+ const int64_t* other_counts = other->counts_.data();
+ const double* other_means = other->means_.data();
+ const double* other_m2s = other->m2s_.data();
+ const uint8_t* other_no_nulls = other->no_nulls_.data();
+
+ auto g = group_id_mapping.GetValues<uint32_t>(1);
+ for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
+ if (!BitUtil::GetBit(other_no_nulls, other_g)) {
+ BitUtil::ClearBit(no_nulls, *g);
+ }
+ if (other_counts[other_g] == 0) continue;
+ MergeVarStd(counts[*g], means[*g], other_counts[other_g], other_means[other_g],
+ other_m2s[other_g], &counts[*g], &means[*g], &m2s[*g]);
+ }
+ return Status::OK();
+ }
+
+ Result<Datum> Finalize() override {
+ std::shared_ptr<Buffer> null_bitmap;
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> values,
+ AllocateBuffer(num_groups_ * sizeof(double), pool_));
+ int64_t null_count = 0;
+
+ double* results = reinterpret_cast<double*>(values->mutable_data());
+ const int64_t* counts = counts_.data();
+ const double* m2s = m2s_.data();
+ for (int64_t i = 0; i < num_groups_; ++i) {
+ if (counts[i] > options_.ddof && counts[i] >= options_.min_count) {
+ const double variance = m2s[i] / (counts[i] - options_.ddof);
+ results[i] = result_type_ == VarOrStd::Var ? variance : std::sqrt(variance);
+ continue;
+ }
+
+ results[i] = 0;
+ if (null_bitmap == nullptr) {
+ ARROW_ASSIGN_OR_RAISE(null_bitmap, AllocateBitmap(num_groups_, pool_));
+ BitUtil::SetBitsTo(null_bitmap->mutable_data(), 0, num_groups_, true);
+ }
+
+ null_count += 1;
+ BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false);
+ }
+ if (!options_.skip_nulls) {
+ if (null_bitmap) {
+ arrow::internal::BitmapAnd(null_bitmap->data(), 0, no_nulls_.data(), 0,
+ num_groups_, 0, null_bitmap->mutable_data());
+ } else {
+ ARROW_ASSIGN_OR_RAISE(null_bitmap, no_nulls_.Finish());
+ }
+ null_count = kUnknownNullCount;
+ }
+
+ return ArrayData::Make(float64(), num_groups_,
+ {std::move(null_bitmap), std::move(values)}, null_count);
+ }
+
+ std::shared_ptr<DataType> out_type() const override { return float64(); }
+
+ VarOrStd result_type_;
+ VarianceOptions options_;
+ int64_t num_groups_ = 0;
+ // m2 = count * s2 = sum((X-mean)^2)
+ TypedBufferBuilder<int64_t> counts_;
+ TypedBufferBuilder<double> means_, m2s_;
+ TypedBufferBuilder<bool> no_nulls_;
+ ExecContext* ctx_;
+ MemoryPool* pool_;
+};
+
+template <typename T, VarOrStd result_type>
+Result<std::unique_ptr<KernelState>> VarStdInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ auto impl = ::arrow::internal::make_unique<GroupedVarStdImpl<T>>();
+ impl->result_type_ = result_type;
+ RETURN_NOT_OK(impl->Init(ctx->exec_context(), args.options));
+ return std::move(impl);
+}
+
+template <VarOrStd result_type>
+struct GroupedVarStdFactory {
+ template <typename T, typename Enable = enable_if_t<is_integer_type<T>::value ||
+ is_floating_type<T>::value>>
+ Status Visit(const T&) {
+ kernel = MakeKernel(std::move(argument_type), VarStdInit<T, result_type>);
+ return Status::OK();
+ }
+
+ Status Visit(const HalfFloatType& type) {
+ return Status::NotImplemented("Computing variance/stddev of data of type ", type);
+ }
+
+ Status Visit(const DataType& type) {
+ return Status::NotImplemented("Computing variance/stddev of data of type ", type);
+ }
+
+ static Result<HashAggregateKernel> Make(const std::shared_ptr<DataType>& type) {
+ GroupedVarStdFactory factory;
+ factory.argument_type = InputType::Array(type);
+ RETURN_NOT_OK(VisitTypeInline(*type, &factory));
+ return std::move(factory.kernel);
+ }
+
+ HashAggregateKernel kernel;
+ InputType argument_type;
+};
+
+// ----------------------------------------------------------------------
+// TDigest implementation
+
+using arrow::internal::TDigest;
+
+template <typename Type>
+struct GroupedTDigestImpl : public GroupedAggregator {
+ using CType = typename Type::c_type;
+
+ Status Init(ExecContext* ctx, const FunctionOptions* options) override {
+ options_ = *checked_cast<const TDigestOptions*>(options);
+ ctx_ = ctx;
+ pool_ = ctx->memory_pool();
+ counts_ = TypedBufferBuilder<int64_t>(pool_);
+ no_nulls_ = TypedBufferBuilder<bool>(pool_);
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ const int64_t added_groups = new_num_groups - tdigests_.size();
+ tdigests_.reserve(new_num_groups);
+ for (int64_t i = 0; i < added_groups; i++) {
+ tdigests_.emplace_back(options_.delta, options_.buffer_size);
+ }
+ RETURN_NOT_OK(counts_.Append(new_num_groups, 0));
+ RETURN_NOT_OK(no_nulls_.Append(new_num_groups, true));
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override {
+ int64_t* counts = counts_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
+ VisitGroupedValues<Type>(
+ batch,
+ [&](uint32_t g, CType value) {
+ tdigests_[g].NanAdd(value);
+ counts[g]++;
+ },
+ [&](uint32_t g) { BitUtil::SetBitTo(no_nulls, g, false); });
+ return Status::OK();
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ auto other = checked_cast<GroupedTDigestImpl*>(&raw_other);
+
+ int64_t* counts = counts_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
+
+ const int64_t* other_counts = other->counts_.data();
+ const uint8_t* other_no_nulls = no_nulls_.mutable_data();
+
+ auto g = group_id_mapping.GetValues<uint32_t>(1);
+ for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
+ tdigests_[*g].Merge(other->tdigests_[other_g]);
+ counts[*g] += other_counts[other_g];
+ BitUtil::SetBitTo(
+ no_nulls, *g,
+ BitUtil::GetBit(no_nulls, *g) && BitUtil::GetBit(other_no_nulls, other_g));
+ }
+
+ return Status::OK();
+ }
+
+ Result<Datum> Finalize() override {
+ const int64_t slot_length = options_.q.size();
+ const int64_t num_values = tdigests_.size() * slot_length;
+ const int64_t* counts = counts_.data();
+ std::shared_ptr<Buffer> null_bitmap;
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> values,
+ AllocateBuffer(num_values * sizeof(double), pool_));
+ int64_t null_count = 0;
+
+ double* results = reinterpret_cast<double*>(values->mutable_data());
+ for (int64_t i = 0; static_cast<size_t>(i) < tdigests_.size(); ++i) {
+ if (!tdigests_[i].is_empty() && counts[i] >= options_.min_count &&
+ (options_.skip_nulls || BitUtil::GetBit(no_nulls_.data(), i))) {
+ for (int64_t j = 0; j < slot_length; j++) {
+ results[i * slot_length + j] = tdigests_[i].Quantile(options_.q[j]);
+ }
+ continue;
+ }
+
+ if (!null_bitmap) {
+ ARROW_ASSIGN_OR_RAISE(null_bitmap, AllocateBitmap(num_values, pool_));
+ BitUtil::SetBitsTo(null_bitmap->mutable_data(), 0, num_values, true);
+ }
+ null_count += slot_length;
+ BitUtil::SetBitsTo(null_bitmap->mutable_data(), i * slot_length, slot_length,
+ false);
+ std::fill(&results[i * slot_length], &results[(i + 1) * slot_length], 0.0);
+ }
+
+ auto child = ArrayData::Make(float64(), num_values,
+ {std::move(null_bitmap), std::move(values)}, null_count);
+ return ArrayData::Make(out_type(), tdigests_.size(), {nullptr}, {std::move(child)},
+ /*null_count=*/0);
+ }
+
+ std::shared_ptr<DataType> out_type() const override {
+ return fixed_size_list(float64(), static_cast<int32_t>(options_.q.size()));
+ }
+
+ TDigestOptions options_;
+ std::vector<TDigest> tdigests_;
+ TypedBufferBuilder<int64_t> counts_;
+ TypedBufferBuilder<bool> no_nulls_;
+ ExecContext* ctx_;
+ MemoryPool* pool_;
+};
+
+struct GroupedTDigestFactory {
+ template <typename T>
+ enable_if_number<T, Status> Visit(const T&) {
+ kernel =
+ MakeKernel(std::move(argument_type), HashAggregateInit<GroupedTDigestImpl<T>>);
+ return Status::OK();
+ }
+
+ Status Visit(const HalfFloatType& type) {
+ return Status::NotImplemented("Computing t-digest of data of type ", type);
+ }
+
+ Status Visit(const DataType& type) {
+ return Status::NotImplemented("Computing t-digest of data of type ", type);
+ }
+
+ static Result<HashAggregateKernel> Make(const std::shared_ptr<DataType>& type) {
+ GroupedTDigestFactory factory;
+ factory.argument_type = InputType::Array(type);
+ RETURN_NOT_OK(VisitTypeInline(*type, &factory));
+ return std::move(factory.kernel);
+ }
+
+ HashAggregateKernel kernel;
+ InputType argument_type;
+};
+
+HashAggregateKernel MakeApproximateMedianKernel(HashAggregateFunction* tdigest_func) {
+ HashAggregateKernel kernel;
+ kernel.init = [tdigest_func](
+ KernelContext* ctx,
+ const KernelInitArgs& args) -> Result<std::unique_ptr<KernelState>> {
+ std::vector<ValueDescr> inputs = args.inputs;
+ ARROW_ASSIGN_OR_RAISE(auto kernel, tdigest_func->DispatchBest(&inputs));
+ const auto& scalar_options =
+ checked_cast<const ScalarAggregateOptions&>(*args.options);
+ TDigestOptions options;
+ // Default q = 0.5
+ options.min_count = scalar_options.min_count;
+ options.skip_nulls = scalar_options.skip_nulls;
+ KernelInitArgs new_args{kernel, inputs, &options};
+ return kernel->init(ctx, new_args);
+ };
+ kernel.signature =
+ KernelSignature::Make({InputType(ValueDescr::ANY), InputType::Array(Type::UINT32)},
+ ValueDescr::Array(float64()));
+ kernel.resize = HashAggregateResize;
+ kernel.consume = HashAggregateConsume;
+ kernel.merge = HashAggregateMerge;
+ kernel.finalize = [](KernelContext* ctx, Datum* out) {
+ ARROW_ASSIGN_OR_RAISE(Datum temp,
+ checked_cast<GroupedAggregator*>(ctx->state())->Finalize());
+ *out = temp.array_as<FixedSizeListArray>()->values();
+ return Status::OK();
+ };
+ return kernel;
+}
+
+// ----------------------------------------------------------------------
+// MinMax implementation
+
+template <typename CType>
+struct AntiExtrema {
+ static constexpr CType anti_min() { return std::numeric_limits<CType>::max(); }
+ static constexpr CType anti_max() { return std::numeric_limits<CType>::min(); }
+};
+
+template <>
+struct AntiExtrema<bool> {
+ static constexpr bool anti_min() { return true; }
+ static constexpr bool anti_max() { return false; }
+};
+
+template <>
+struct AntiExtrema<float> {
+ static constexpr float anti_min() { return std::numeric_limits<float>::infinity(); }
+ static constexpr float anti_max() { return -std::numeric_limits<float>::infinity(); }
+};
+
+template <>
+struct AntiExtrema<double> {
+ static constexpr double anti_min() { return std::numeric_limits<double>::infinity(); }
+ static constexpr double anti_max() { return -std::numeric_limits<double>::infinity(); }
+};
+
+template <>
+struct AntiExtrema<Decimal128> {
+ static constexpr Decimal128 anti_min() { return BasicDecimal128::GetMaxSentinel(); }
+ static constexpr Decimal128 anti_max() { return BasicDecimal128::GetMinSentinel(); }
+};
+
+template <>
+struct AntiExtrema<Decimal256> {
+ static constexpr Decimal256 anti_min() { return BasicDecimal256::GetMaxSentinel(); }
+ static constexpr Decimal256 anti_max() { return BasicDecimal256::GetMinSentinel(); }
+};
+
+template <typename Type>
+struct GroupedMinMaxImpl final : public GroupedAggregator {
+ using CType = typename TypeTraits<Type>::CType;
+ using GetSet = GroupedValueTraits<Type>;
+ using ArrType =
+ typename std::conditional<is_boolean_type<Type>::value, uint8_t, CType>::type;
+
+ Status Init(ExecContext* ctx, const FunctionOptions* options) override {
+ options_ = *checked_cast<const ScalarAggregateOptions*>(options);
+ // type_ initialized by MinMaxInit
+ mins_ = TypedBufferBuilder<CType>(ctx->memory_pool());
+ maxes_ = TypedBufferBuilder<CType>(ctx->memory_pool());
+ has_values_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+ has_nulls_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ auto added_groups = new_num_groups - num_groups_;
+ num_groups_ = new_num_groups;
+ RETURN_NOT_OK(mins_.Append(added_groups, AntiExtrema<CType>::anti_min()));
+ RETURN_NOT_OK(maxes_.Append(added_groups, AntiExtrema<CType>::anti_max()));
+ RETURN_NOT_OK(has_values_.Append(added_groups, false));
+ RETURN_NOT_OK(has_nulls_.Append(added_groups, false));
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override {
+ auto raw_mins = mins_.mutable_data();
+ auto raw_maxes = maxes_.mutable_data();
+
+ VisitGroupedValues<Type>(
+ batch,
+ [&](uint32_t g, CType val) {
+ GetSet::Set(raw_mins, g, std::min(GetSet::Get(raw_mins, g), val));
+ GetSet::Set(raw_maxes, g, std::max(GetSet::Get(raw_maxes, g), val));
+ BitUtil::SetBit(has_values_.mutable_data(), g);
+ },
+ [&](uint32_t g) { BitUtil::SetBit(has_nulls_.mutable_data(), g); });
+ return Status::OK();
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ auto other = checked_cast<GroupedMinMaxImpl*>(&raw_other);
+
+ auto raw_mins = mins_.mutable_data();
+ auto raw_maxes = maxes_.mutable_data();
+
+ auto other_raw_mins = other->mins_.mutable_data();
+ auto other_raw_maxes = other->maxes_.mutable_data();
+
+ auto g = group_id_mapping.GetValues<uint32_t>(1);
+ for (uint32_t other_g = 0; static_cast<int64_t>(other_g) < group_id_mapping.length;
+ ++other_g, ++g) {
+ GetSet::Set(
+ raw_mins, *g,
+ std::min(GetSet::Get(raw_mins, *g), GetSet::Get(other_raw_mins, other_g)));
+ GetSet::Set(
+ raw_maxes, *g,
+ std::max(GetSet::Get(raw_maxes, *g), GetSet::Get(other_raw_maxes, other_g)));
+
+ if (BitUtil::GetBit(other->has_values_.data(), other_g)) {
+ BitUtil::SetBit(has_values_.mutable_data(), *g);
+ }
+ if (BitUtil::GetBit(other->has_nulls_.data(), other_g)) {
+ BitUtil::SetBit(has_nulls_.mutable_data(), *g);
+ }
+ }
+ return Status::OK();
+ }
+
+ Result<Datum> Finalize() override {
+ // aggregation for group is valid if there was at least one value in that group
+ ARROW_ASSIGN_OR_RAISE(auto null_bitmap, has_values_.Finish());
+
+ if (!options_.skip_nulls) {
+ // ... and there were no nulls in that group
+ ARROW_ASSIGN_OR_RAISE(auto has_nulls, has_nulls_.Finish());
+ arrow::internal::BitmapAndNot(null_bitmap->data(), 0, has_nulls->data(), 0,
+ num_groups_, 0, null_bitmap->mutable_data());
+ }
+
+ auto mins = ArrayData::Make(type_, num_groups_, {null_bitmap, nullptr});
+ auto maxes = ArrayData::Make(type_, num_groups_, {std::move(null_bitmap), nullptr});
+ ARROW_ASSIGN_OR_RAISE(mins->buffers[1], mins_.Finish());
+ ARROW_ASSIGN_OR_RAISE(maxes->buffers[1], maxes_.Finish());
+
+ return ArrayData::Make(out_type(), num_groups_, {nullptr},
+ {std::move(mins), std::move(maxes)});
+ }
+
+ std::shared_ptr<DataType> out_type() const override {
+ return struct_({field("min", type_), field("max", type_)});
+ }
+
+ int64_t num_groups_;
+ TypedBufferBuilder<CType> mins_, maxes_;
+ TypedBufferBuilder<bool> has_values_, has_nulls_;
+ std::shared_ptr<DataType> type_;
+ ScalarAggregateOptions options_;
+};
+
+struct GroupedNullMinMaxImpl final : public GroupedAggregator {
+ Status Init(ExecContext* ctx, const FunctionOptions*) override { return Status::OK(); }
+
+ Status Resize(int64_t new_num_groups) override {
+ num_groups_ = new_num_groups;
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override { return Status::OK(); }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ return Status::OK();
+ }
+
+ Result<Datum> Finalize() override {
+ return ArrayData::Make(
+ out_type(), num_groups_, {nullptr},
+ {
+ ArrayData::Make(null(), num_groups_, {nullptr}, num_groups_),
+ ArrayData::Make(null(), num_groups_, {nullptr}, num_groups_),
+ });
+ }
+
+ std::shared_ptr<DataType> out_type() const override {
+ return struct_({field("min", null()), field("max", null())});
+ }
+
+ int64_t num_groups_;
+};
+
+template <typename T>
+Result<std::unique_ptr<KernelState>> MinMaxInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ ARROW_ASSIGN_OR_RAISE(auto impl, HashAggregateInit<GroupedMinMaxImpl<T>>(ctx, args));
+ static_cast<GroupedMinMaxImpl<T>*>(impl.get())->type_ = args.inputs[0].type;
+ return std::move(impl);
+}
+
+template <MinOrMax min_or_max>
+HashAggregateKernel MakeMinOrMaxKernel(HashAggregateFunction* min_max_func) {
+ HashAggregateKernel kernel;
+ kernel.init = [min_max_func](
+ KernelContext* ctx,
+ const KernelInitArgs& args) -> Result<std::unique_ptr<KernelState>> {
+ std::vector<ValueDescr> inputs = args.inputs;
+ ARROW_ASSIGN_OR_RAISE(auto kernel, min_max_func->DispatchBest(&inputs));
+ KernelInitArgs new_args{kernel, inputs, args.options};
+ return kernel->init(ctx, new_args);
+ };
+ kernel.signature = KernelSignature::Make(
+ {InputType(ValueDescr::ANY), InputType::Array(Type::UINT32)},
+ OutputType([](KernelContext* ctx,
+ const std::vector<ValueDescr>& descrs) -> Result<ValueDescr> {
+ return ValueDescr::Array(descrs[0].type);
+ }));
+ kernel.resize = HashAggregateResize;
+ kernel.consume = HashAggregateConsume;
+ kernel.merge = HashAggregateMerge;
+ kernel.finalize = [](KernelContext* ctx, Datum* out) {
+ ARROW_ASSIGN_OR_RAISE(Datum temp,
+ checked_cast<GroupedAggregator*>(ctx->state())->Finalize());
+ *out = temp.array_as<StructArray>()->field(static_cast<uint8_t>(min_or_max));
+ return Status::OK();
+ };
+ return kernel;
+}
+
+struct GroupedMinMaxFactory {
+ template <typename T>
+ enable_if_physical_integer<T, Status> Visit(const T&) {
+ using PhysicalType = typename T::PhysicalType;
+ kernel = MakeKernel(std::move(argument_type), MinMaxInit<PhysicalType>);
+ return Status::OK();
+ }
+
+ // MSVC2015 apparently doesn't compile this properly if we use
+ // enable_if_floating_point
+ Status Visit(const FloatType&) {
+ kernel = MakeKernel(std::move(argument_type), MinMaxInit<FloatType>);
+ return Status::OK();
+ }
+
+ Status Visit(const DoubleType&) {
+ kernel = MakeKernel(std::move(argument_type), MinMaxInit<DoubleType>);
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_decimal<T, Status> Visit(const T&) {
+ kernel = MakeKernel(std::move(argument_type), MinMaxInit<T>);
+ return Status::OK();
+ }
+
+ Status Visit(const BooleanType&) {
+ kernel = MakeKernel(std::move(argument_type), MinMaxInit<BooleanType>);
+ return Status::OK();
+ }
+
+ Status Visit(const NullType&) {
+ kernel =
+ MakeKernel(std::move(argument_type), HashAggregateInit<GroupedNullMinMaxImpl>);
+ return Status::OK();
+ }
+
+ Status Visit(const HalfFloatType& type) {
+ return Status::NotImplemented("Computing min/max of data of type ", type);
+ }
+
+ Status Visit(const DataType& type) {
+ return Status::NotImplemented("Computing min/max of data of type ", type);
+ }
+
+ static Result<HashAggregateKernel> Make(const std::shared_ptr<DataType>& type) {
+ GroupedMinMaxFactory factory;
+ factory.argument_type = InputType::Array(type->id());
+ RETURN_NOT_OK(VisitTypeInline(*type, &factory));
+ return std::move(factory.kernel);
+ }
+
+ HashAggregateKernel kernel;
+ InputType argument_type;
+};
+
+// ----------------------------------------------------------------------
+// Any/All implementation
+
+template <typename Impl>
+struct GroupedBooleanAggregator : public GroupedAggregator {
+ Status Init(ExecContext* ctx, const FunctionOptions* options) override {
+ options_ = checked_cast<const ScalarAggregateOptions&>(*options);
+ pool_ = ctx->memory_pool();
+ reduced_ = TypedBufferBuilder<bool>(pool_);
+ no_nulls_ = TypedBufferBuilder<bool>(pool_);
+ counts_ = TypedBufferBuilder<int64_t>(pool_);
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ auto added_groups = new_num_groups - num_groups_;
+ num_groups_ = new_num_groups;
+ RETURN_NOT_OK(reduced_.Append(added_groups, Impl::NullValue()));
+ RETURN_NOT_OK(no_nulls_.Append(added_groups, true));
+ return counts_.Append(added_groups, 0);
+ }
+
+ Status Consume(const ExecBatch& batch) override {
+ uint8_t* reduced = reduced_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
+ int64_t* counts = counts_.mutable_data();
+ auto g = batch[1].array()->GetValues<uint32_t>(1);
+
+ if (batch[0].is_array()) {
+ const auto& input = *batch[0].array();
+ if (input.MayHaveNulls()) {
+ const uint8_t* bitmap = input.buffers[1]->data();
+ arrow::internal::VisitBitBlocksVoid(
+ input.buffers[0], input.offset, input.length,
+ [&](int64_t position) {
+ counts[*g]++;
+ Impl::UpdateGroupWith(reduced, *g, BitUtil::GetBit(bitmap, position));
+ g++;
+ },
+ [&] { BitUtil::SetBitTo(no_nulls, *g++, false); });
+ } else {
+ arrow::internal::VisitBitBlocksVoid(
+ input.buffers[1], input.offset, input.length,
+ [&](int64_t) {
+ Impl::UpdateGroupWith(reduced, *g, true);
+ counts[*g++]++;
+ },
+ [&]() {
+ Impl::UpdateGroupWith(reduced, *g, false);
+ counts[*g++]++;
+ });
+ }
+ } else {
+ const auto& input = *batch[0].scalar();
+ if (input.is_valid) {
+ const bool value = UnboxScalar<BooleanType>::Unbox(input);
+ for (int64_t i = 0; i < batch.length; i++) {
+ Impl::UpdateGroupWith(reduced, *g, value);
+ counts[*g++]++;
+ }
+ } else {
+ for (int64_t i = 0; i < batch.length; i++) {
+ BitUtil::SetBitTo(no_nulls, *g++, false);
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ auto other = checked_cast<GroupedBooleanAggregator<Impl>*>(&raw_other);
+
+ uint8_t* reduced = reduced_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
+ int64_t* counts = counts_.mutable_data();
+
+ const uint8_t* other_reduced = other->reduced_.mutable_data();
+ const uint8_t* other_no_nulls = other->no_nulls_.mutable_data();
+ const int64_t* other_counts = other->counts_.mutable_data();
+
+ auto g = group_id_mapping.GetValues<uint32_t>(1);
+ for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
+ counts[*g] += other_counts[other_g];
+ Impl::UpdateGroupWith(reduced, *g, BitUtil::GetBit(other_reduced, other_g));
+ BitUtil::SetBitTo(
+ no_nulls, *g,
+ BitUtil::GetBit(no_nulls, *g) && BitUtil::GetBit(other_no_nulls, other_g));
+ }
+ return Status::OK();
+ }
+
+ Result<Datum> Finalize() override {
+ std::shared_ptr<Buffer> null_bitmap;
+ const int64_t* counts = counts_.data();
+ int64_t null_count = 0;
+
+ for (int64_t i = 0; i < num_groups_; ++i) {
+ if (counts[i] >= options_.min_count) continue;
+
+ if (null_bitmap == nullptr) {
+ ARROW_ASSIGN_OR_RAISE(null_bitmap, AllocateBitmap(num_groups_, pool_));
+ BitUtil::SetBitsTo(null_bitmap->mutable_data(), 0, num_groups_, true);
+ }
+
+ null_count += 1;
+ BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto reduced, reduced_.Finish());
+ if (!options_.skip_nulls) {
+ null_count = kUnknownNullCount;
+ ARROW_ASSIGN_OR_RAISE(auto no_nulls, no_nulls_.Finish());
+ Impl::AdjustForMinCount(no_nulls->mutable_data(), reduced->data(), num_groups_);
+ if (null_bitmap) {
+ arrow::internal::BitmapAnd(null_bitmap->data(), /*left_offset=*/0,
+ no_nulls->data(), /*right_offset=*/0, num_groups_,
+ /*out_offset=*/0, null_bitmap->mutable_data());
+ } else {
+ null_bitmap = std::move(no_nulls);
+ }
+ }
+
+ return ArrayData::Make(out_type(), num_groups_,
+ {std::move(null_bitmap), std::move(reduced)}, null_count);
+ }
+
+ std::shared_ptr<DataType> out_type() const override { return boolean(); }
+
+ int64_t num_groups_ = 0;
+ ScalarAggregateOptions options_;
+ TypedBufferBuilder<bool> reduced_, no_nulls_;
+ TypedBufferBuilder<int64_t> counts_;
+ MemoryPool* pool_;
+};
+
+struct GroupedAnyImpl : public GroupedBooleanAggregator<GroupedAnyImpl> {
+ // The default value for a group.
+ static bool NullValue() { return false; }
+
+ // Update the value for a group given an observation.
+ static void UpdateGroupWith(uint8_t* seen, uint32_t g, bool value) {
+ if (!BitUtil::GetBit(seen, g) && value) {
+ BitUtil::SetBit(seen, g);
+ }
+ }
+
+ // Combine the array of observed nulls with the array of group values.
+ static void AdjustForMinCount(uint8_t* no_nulls, const uint8_t* seen,
+ int64_t num_groups) {
+ arrow::internal::BitmapOr(no_nulls, /*left_offset=*/0, seen, /*right_offset=*/0,
+ num_groups, /*out_offset=*/0, no_nulls);
+ }
+};
+
+struct GroupedAllImpl : public GroupedBooleanAggregator<GroupedAllImpl> {
+ static bool NullValue() { return true; }
+
+ static void UpdateGroupWith(uint8_t* seen, uint32_t g, bool value) {
+ if (!value) {
+ BitUtil::ClearBit(seen, g);
+ }
+ }
+
+ static void AdjustForMinCount(uint8_t* no_nulls, const uint8_t* seen,
+ int64_t num_groups) {
+ arrow::internal::BitmapOrNot(no_nulls, /*left_offset=*/0, seen, /*right_offset=*/0,
+ num_groups, /*out_offset=*/0, no_nulls);
+ }
+};
+
+// ----------------------------------------------------------------------
+// CountDistinct/Distinct implementation
+
+struct GroupedCountDistinctImpl : public GroupedAggregator {
+ Status Init(ExecContext* ctx, const FunctionOptions* options) override {
+ ctx_ = ctx;
+ pool_ = ctx->memory_pool();
+ options_ = checked_cast<const CountOptions&>(*options);
+ return Status::OK();
+ }
+
+ Status Resize(int64_t new_num_groups) override {
+ num_groups_ = new_num_groups;
+ return Status::OK();
+ }
+
+ Status Consume(const ExecBatch& batch) override {
+ ARROW_ASSIGN_OR_RAISE(std::ignore, grouper_->Consume(batch));
+ return Status::OK();
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ auto other = checked_cast<GroupedCountDistinctImpl*>(&raw_other);
+
+ // Get (value, group_id) pairs, then translate the group IDs and consume them
+ // ourselves
+ ARROW_ASSIGN_OR_RAISE(auto uniques, other->grouper_->GetUniques());
+ ARROW_ASSIGN_OR_RAISE(auto remapped_g,
+ AllocateBuffer(uniques.length * sizeof(uint32_t), pool_));
+
+ const auto* g_mapping = group_id_mapping.GetValues<uint32_t>(1);
+ const auto* other_g = uniques[1].array()->GetValues<uint32_t>(1);
+ auto* g = reinterpret_cast<uint32_t*>(remapped_g->mutable_data());
+
+ for (int64_t i = 0; i < uniques.length; i++) {
+ g[i] = g_mapping[other_g[i]];
+ }
+ uniques.values[1] =
+ ArrayData::Make(uint32(), uniques.length, {nullptr, std::move(remapped_g)});
+
+ return Consume(std::move(uniques));
+ }
+
+ Result<Datum> Finalize() override {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> values,
+ AllocateBuffer(num_groups_ * sizeof(int64_t), pool_));
+ int64_t* counts = reinterpret_cast<int64_t*>(values->mutable_data());
+ std::fill(counts, counts + num_groups_, 0);
+
+ ARROW_ASSIGN_OR_RAISE(auto uniques, grouper_->GetUniques());
+ auto* g = uniques[1].array()->GetValues<uint32_t>(1);
+ const auto& items = *uniques[0].array();
+ const auto* valid = items.GetValues<uint8_t>(0, 0);
+ if (options_.mode == CountOptions::ALL ||
+ (options_.mode == CountOptions::ONLY_VALID && !valid)) {
+ for (int64_t i = 0; i < uniques.length; i++) {
+ counts[g[i]]++;
+ }
+ } else if (options_.mode == CountOptions::ONLY_VALID) {
+ for (int64_t i = 0; i < uniques.length; i++) {
+ counts[g[i]] += BitUtil::GetBit(valid, items.offset + i);
+ }
+ } else if (valid) { // ONLY_NULL
+ for (int64_t i = 0; i < uniques.length; i++) {
+ counts[g[i]] += !BitUtil::GetBit(valid, items.offset + i);
+ }
+ }
+
+ return ArrayData::Make(int64(), num_groups_, {nullptr, std::move(values)},
+ /*null_count=*/0);
+ }
+
+ std::shared_ptr<DataType> out_type() const override { return int64(); }
+
+ ExecContext* ctx_;
+ MemoryPool* pool_;
+ int64_t num_groups_;
+ CountOptions options_;
+ std::unique_ptr<Grouper> grouper_;
+ std::shared_ptr<DataType> out_type_;
+};
+
+struct GroupedDistinctImpl : public GroupedCountDistinctImpl {
+ Result<Datum> Finalize() override {
+ ARROW_ASSIGN_OR_RAISE(auto uniques, grouper_->GetUniques());
+ ARROW_ASSIGN_OR_RAISE(auto groupings, grouper_->MakeGroupings(
+ *uniques[1].array_as<UInt32Array>(),
+ static_cast<uint32_t>(num_groups_), ctx_));
+ ARROW_ASSIGN_OR_RAISE(
+ auto list, grouper_->ApplyGroupings(*groupings, *uniques[0].make_array(), ctx_));
+ auto values = list->values();
+ DCHECK_EQ(values->offset(), 0);
+ int32_t* offsets = reinterpret_cast<int32_t*>(list->value_offsets()->mutable_data());
+ if (options_.mode == CountOptions::ALL ||
+ (options_.mode == CountOptions::ONLY_VALID && values->null_count() == 0)) {
+ return list;
+ } else if (options_.mode == CountOptions::ONLY_VALID) {
+ int32_t prev_offset = offsets[0];
+ for (int64_t i = 0; i < list->length(); i++) {
+ const int32_t slot_length = offsets[i + 1] - prev_offset;
+ const int64_t null_count =
+ slot_length - arrow::internal::CountSetBits(values->null_bitmap()->data(),
+ prev_offset, slot_length);
+ DCHECK_LE(null_count, 1);
+ const int32_t offset = null_count > 0 ? slot_length - 1 : slot_length;
+ prev_offset = offsets[i + 1];
+ offsets[i + 1] = offsets[i] + offset;
+ }
+ auto filter =
+ std::make_shared<BooleanArray>(values->length(), values->null_bitmap());
+ ARROW_ASSIGN_OR_RAISE(
+ auto new_values,
+ Filter(std::move(values), filter, FilterOptions(FilterOptions::DROP), ctx_));
+ return std::make_shared<ListArray>(list->type(), list->length(),
+ list->value_offsets(), new_values.make_array());
+ }
+ // ONLY_NULL
+ if (values->null_count() == 0) {
+ std::fill(offsets + 1, offsets + list->length() + 1, offsets[0]);
+ } else {
+ int32_t prev_offset = offsets[0];
+ for (int64_t i = 0; i < list->length(); i++) {
+ const int32_t slot_length = offsets[i + 1] - prev_offset;
+ const int64_t null_count =
+ slot_length - arrow::internal::CountSetBits(values->null_bitmap()->data(),
+ prev_offset, slot_length);
+ const int32_t offset = null_count > 0 ? 1 : 0;
+ prev_offset = offsets[i + 1];
+ offsets[i + 1] = offsets[i] + offset;
+ }
+ }
+ ARROW_ASSIGN_OR_RAISE(
+ auto new_values,
+ MakeArrayOfNull(out_type_,
+ list->length() > 0 ? offsets[list->length()] - offsets[0] : 0,
+ pool_));
+ return std::make_shared<ListArray>(list->type(), list->length(),
+ list->value_offsets(), std::move(new_values));
+ }
+
+ std::shared_ptr<DataType> out_type() const override { return list(out_type_); }
+};
+
+template <typename Impl>
+Result<std::unique_ptr<KernelState>> GroupedDistinctInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ ARROW_ASSIGN_OR_RAISE(auto impl, HashAggregateInit<Impl>(ctx, args));
+ auto instance = static_cast<Impl*>(impl.get());
+ instance->out_type_ = args.inputs[0].type;
+ ARROW_ASSIGN_OR_RAISE(instance->grouper_,
+ Grouper::Make(args.inputs, ctx->exec_context()));
+ return std::move(impl);
+}
+
+} // namespace
+
+Result<std::vector<const HashAggregateKernel*>> GetKernels(
+ ExecContext* ctx, const std::vector<Aggregate>& aggregates,
+ const std::vector<ValueDescr>& in_descrs) {
+ if (aggregates.size() != in_descrs.size()) {
+ return Status::Invalid(aggregates.size(), " aggregate functions were specified but ",
+ in_descrs.size(), " arguments were provided.");
+ }
+
+ std::vector<const HashAggregateKernel*> kernels(in_descrs.size());
+
+ for (size_t i = 0; i < aggregates.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto function,
+ ctx->func_registry()->GetFunction(aggregates[i].function));
+ ARROW_ASSIGN_OR_RAISE(
+ const Kernel* kernel,
+ function->DispatchExact({in_descrs[i], ValueDescr::Array(uint32())}));
+ kernels[i] = static_cast<const HashAggregateKernel*>(kernel);
+ }
+ return kernels;
+}
+
+Result<std::vector<std::unique_ptr<KernelState>>> InitKernels(
+ const std::vector<const HashAggregateKernel*>& kernels, ExecContext* ctx,
+ const std::vector<Aggregate>& aggregates, const std::vector<ValueDescr>& in_descrs) {
+ std::vector<std::unique_ptr<KernelState>> states(kernels.size());
+
+ for (size_t i = 0; i < aggregates.size(); ++i) {
+ auto options = aggregates[i].options;
+
+ if (options == nullptr) {
+ // use known default options for the named function if possible
+ auto maybe_function = ctx->func_registry()->GetFunction(aggregates[i].function);
+ if (maybe_function.ok()) {
+ options = maybe_function.ValueOrDie()->default_options();
+ }
+ }
+
+ KernelContext kernel_ctx{ctx};
+ ARROW_ASSIGN_OR_RAISE(
+ states[i],
+ kernels[i]->init(&kernel_ctx, KernelInitArgs{kernels[i],
+ {
+ in_descrs[i],
+ ValueDescr::Array(uint32()),
+ },
+ options}));
+ }
+
+ return std::move(states);
+}
+
+Result<FieldVector> ResolveKernels(
+ const std::vector<Aggregate>& aggregates,
+ const std::vector<const HashAggregateKernel*>& kernels,
+ const std::vector<std::unique_ptr<KernelState>>& states, ExecContext* ctx,
+ const std::vector<ValueDescr>& descrs) {
+ FieldVector fields(descrs.size());
+
+ for (size_t i = 0; i < kernels.size(); ++i) {
+ KernelContext kernel_ctx{ctx};
+ kernel_ctx.SetState(states[i].get());
+
+ ARROW_ASSIGN_OR_RAISE(auto descr, kernels[i]->signature->out_type().Resolve(
+ &kernel_ctx, {
+ descrs[i],
+ ValueDescr::Array(uint32()),
+ }));
+ fields[i] = field(aggregates[i].function, std::move(descr.type));
+ }
+ return fields;
+}
+
+Result<std::unique_ptr<Grouper>> Grouper::Make(const std::vector<ValueDescr>& descrs,
+ ExecContext* ctx) {
+ if (GrouperFastImpl::CanUse(descrs)) {
+ return GrouperFastImpl::Make(descrs, ctx);
+ }
+ return GrouperImpl::Make(descrs, ctx);
+}
+
+Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Datum>& keys,
+ const std::vector<Aggregate>& aggregates, bool use_threads,
+ ExecContext* ctx) {
+ auto task_group =
+ use_threads
+ ? arrow::internal::TaskGroup::MakeThreaded(arrow::internal::GetCpuThreadPool())
+ : arrow::internal::TaskGroup::MakeSerial();
+
+ // Construct and initialize HashAggregateKernels
+ ARROW_ASSIGN_OR_RAISE(auto argument_descrs,
+ ExecBatch::Make(arguments).Map(
+ [](ExecBatch batch) { return batch.GetDescriptors(); }));
+
+ ARROW_ASSIGN_OR_RAISE(auto kernels, GetKernels(ctx, aggregates, argument_descrs));
+
+ std::vector<std::vector<std::unique_ptr<KernelState>>> states(
+ task_group->parallelism());
+ for (auto& state : states) {
+ ARROW_ASSIGN_OR_RAISE(state, InitKernels(kernels, ctx, aggregates, argument_descrs));
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ FieldVector out_fields,
+ ResolveKernels(aggregates, kernels, states[0], ctx, argument_descrs));
+
+ using arrow::compute::detail::ExecBatchIterator;
+
+ ARROW_ASSIGN_OR_RAISE(auto argument_batch_iterator,
+ ExecBatchIterator::Make(arguments, ctx->exec_chunksize()));
+
+ // Construct Groupers
+ ARROW_ASSIGN_OR_RAISE(auto key_descrs, ExecBatch::Make(keys).Map([](ExecBatch batch) {
+ return batch.GetDescriptors();
+ }));
+
+ std::vector<std::unique_ptr<Grouper>> groupers(task_group->parallelism());
+ for (auto& grouper : groupers) {
+ ARROW_ASSIGN_OR_RAISE(grouper, Grouper::Make(key_descrs, ctx));
+ }
+
+ std::mutex mutex;
+ std::unordered_map<std::thread::id, size_t> thread_ids;
+
+ int i = 0;
+ for (ValueDescr& key_descr : key_descrs) {
+ out_fields.push_back(field("key_" + std::to_string(i++), std::move(key_descr.type)));
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto key_batch_iterator,
+ ExecBatchIterator::Make(keys, ctx->exec_chunksize()));
+
+ // start "streaming" execution
+ ExecBatch key_batch, argument_batch;
+ while (argument_batch_iterator->Next(&argument_batch) &&
+ key_batch_iterator->Next(&key_batch)) {
+ if (key_batch.length == 0) continue;
+
+ task_group->Append([&, key_batch, argument_batch] {
+ size_t thread_index;
+ {
+ std::unique_lock<std::mutex> lock(mutex);
+ auto it = thread_ids.emplace(std::this_thread::get_id(), thread_ids.size()).first;
+ thread_index = it->second;
+ DCHECK_LT(static_cast<int>(thread_index), task_group->parallelism());
+ }
+
+ auto grouper = groupers[thread_index].get();
+
+ // compute a batch of group ids
+ ARROW_ASSIGN_OR_RAISE(Datum id_batch, grouper->Consume(key_batch));
+
+ // consume group ids with HashAggregateKernels
+ for (size_t i = 0; i < kernels.size(); ++i) {
+ KernelContext batch_ctx{ctx};
+ batch_ctx.SetState(states[thread_index][i].get());
+ ARROW_ASSIGN_OR_RAISE(auto batch, ExecBatch::Make({argument_batch[i], id_batch}));
+ RETURN_NOT_OK(kernels[i]->resize(&batch_ctx, grouper->num_groups()));
+ RETURN_NOT_OK(kernels[i]->consume(&batch_ctx, batch));
+ }
+
+ return Status::OK();
+ });
+ }
+
+ RETURN_NOT_OK(task_group->Finish());
+
+ // Merge if necessary
+ for (size_t thread_index = 1; thread_index < thread_ids.size(); ++thread_index) {
+ ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, groupers[thread_index]->GetUniques());
+ ARROW_ASSIGN_OR_RAISE(Datum transposition, groupers[0]->Consume(other_keys));
+ groupers[thread_index].reset();
+
+ for (size_t i = 0; i < kernels.size(); ++i) {
+ KernelContext batch_ctx{ctx};
+ batch_ctx.SetState(states[0][i].get());
+
+ RETURN_NOT_OK(kernels[i]->resize(&batch_ctx, groupers[0]->num_groups()));
+ RETURN_NOT_OK(kernels[i]->merge(&batch_ctx, std::move(*states[thread_index][i]),
+ *transposition.array()));
+ states[thread_index][i].reset();
+ }
+ }
+
+ // Finalize output
+ ArrayDataVector out_data(arguments.size() + keys.size());
+ auto it = out_data.begin();
+
+ for (size_t i = 0; i < kernels.size(); ++i) {
+ KernelContext batch_ctx{ctx};
+ batch_ctx.SetState(states[0][i].get());
+ Datum out;
+ RETURN_NOT_OK(kernels[i]->finalize(&batch_ctx, &out));
+ *it++ = out.array();
+ }
+
+ ARROW_ASSIGN_OR_RAISE(ExecBatch out_keys, groupers[0]->GetUniques());
+ for (const auto& key : out_keys.values) {
+ *it++ = key.array();
+ }
+
+ int64_t length = out_data[0]->length;
+ return ArrayData::Make(struct_(std::move(out_fields)), length,
+ {/*null_bitmap=*/nullptr}, std::move(out_data),
+ /*null_count=*/0);
+}
+
+Result<std::shared_ptr<ListArray>> Grouper::ApplyGroupings(const ListArray& groupings,
+ const Array& array,
+ ExecContext* ctx) {
+ ARROW_ASSIGN_OR_RAISE(Datum sorted,
+ compute::Take(array, groupings.data()->child_data[0],
+ TakeOptions::NoBoundsCheck(), ctx));
+
+ return std::make_shared<ListArray>(list(array.type()), groupings.length(),
+ groupings.value_offsets(), sorted.make_array());
+}
+
+Result<std::shared_ptr<ListArray>> Grouper::MakeGroupings(const UInt32Array& ids,
+ uint32_t num_groups,
+ ExecContext* ctx) {
+ if (ids.null_count() != 0) {
+ return Status::Invalid("MakeGroupings with null ids");
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto offsets, AllocateBuffer(sizeof(int32_t) * (num_groups + 1),
+ ctx->memory_pool()));
+ auto raw_offsets = reinterpret_cast<int32_t*>(offsets->mutable_data());
+
+ std::memset(raw_offsets, 0, offsets->size());
+ for (int i = 0; i < ids.length(); ++i) {
+ DCHECK_LT(ids.Value(i), num_groups);
+ raw_offsets[ids.Value(i)] += 1;
+ }
+ int32_t length = 0;
+ for (uint32_t id = 0; id < num_groups; ++id) {
+ auto offset = raw_offsets[id];
+ raw_offsets[id] = length;
+ length += offset;
+ }
+ raw_offsets[num_groups] = length;
+ DCHECK_EQ(ids.length(), length);
+
+ ARROW_ASSIGN_OR_RAISE(auto offsets_copy,
+ offsets->CopySlice(0, offsets->size(), ctx->memory_pool()));
+ raw_offsets = reinterpret_cast<int32_t*>(offsets_copy->mutable_data());
+
+ ARROW_ASSIGN_OR_RAISE(auto sort_indices, AllocateBuffer(sizeof(int32_t) * ids.length(),
+ ctx->memory_pool()));
+ auto raw_sort_indices = reinterpret_cast<int32_t*>(sort_indices->mutable_data());
+ for (int i = 0; i < ids.length(); ++i) {
+ raw_sort_indices[raw_offsets[ids.Value(i)]++] = i;
+ }
+
+ return std::make_shared<ListArray>(
+ list(int32()), num_groups, std::move(offsets),
+ std::make_shared<Int32Array>(ids.length(), std::move(sort_indices)));
+}
+
+namespace {
+const FunctionDoc hash_count_doc{"Count the number of null / non-null values",
+ ("By default, non-null values are counted.\n"
+ "This can be changed through ScalarAggregateOptions."),
+ {"array", "group_id_array"},
+ "CountOptions"};
+
+const FunctionDoc hash_sum_doc{"Sum values of a numeric array",
+ ("Null values are ignored."),
+ {"array", "group_id_array"},
+ "ScalarAggregateOptions"};
+
+const FunctionDoc hash_product_doc{
+ "Compute product of values of a numeric array",
+ ("Null values are ignored.\n"
+ "Overflow will wrap around as if the calculation was done with unsigned integers."),
+ {"array", "group_id_array"},
+ "ScalarAggregateOptions"};
+
+const FunctionDoc hash_mean_doc{
+ "Average values of a numeric array",
+ ("Null values are ignored.\n"
+ "For integers and floats, NaN is returned if min_count = 0 and\n"
+ "there are no values. For decimals, null is returned instead."),
+ {"array", "group_id_array"},
+ "ScalarAggregateOptions"};
+
+const FunctionDoc hash_stddev_doc{
+ "Calculate the standard deviation of a numeric array",
+ ("The number of degrees of freedom can be controlled using VarianceOptions.\n"
+ "By default (`ddof` = 0), the population standard deviation is calculated.\n"
+ "Nulls are ignored. If there are not enough non-null values in the array\n"
+ "to satisfy `ddof`, null is returned."),
+ {"array", "group_id_array"}};
+
+const FunctionDoc hash_variance_doc{
+ "Calculate the variance of a numeric array",
+ ("The number of degrees of freedom can be controlled using VarianceOptions.\n"
+ "By default (`ddof` = 0), the population variance is calculated.\n"
+ "Nulls are ignored. If there are not enough non-null values in the array\n"
+ "to satisfy `ddof`, null is returned."),
+ {"array", "group_id_array"}};
+
+const FunctionDoc hash_tdigest_doc{
+ "Calculate approximate quantiles of a numeric array with the T-Digest algorithm",
+ ("By default, the 0.5 quantile (median) is returned.\n"
+ "Nulls and NaNs are ignored.\n"
+ "A array of nulls is returned if there are no valid data points."),
+ {"array", "group_id_array"},
+ "TDigestOptions"};
+
+const FunctionDoc hash_approximate_median_doc{
+ "Calculate approximate medians of a numeric array with the T-Digest algorithm",
+ ("Nulls and NaNs are ignored.\n"
+ "Null is emitted for a group if there are no valid data points."),
+ {"array", "group_id_array"},
+ "ScalarAggregateOptions"};
+
+const FunctionDoc hash_min_max_doc{
+ "Compute the minimum and maximum values of a numeric array",
+ ("Null values are ignored by default.\n"
+ "This can be changed through ScalarAggregateOptions."),
+ {"array", "group_id_array"},
+ "ScalarAggregateOptions"};
+
+const FunctionDoc hash_min_or_max_doc{
+ "Compute the minimum or maximum values of a numeric array",
+ ("Null values are ignored by default.\n"
+ "This can be changed through ScalarAggregateOptions."),
+ {"array", "group_id_array"},
+ "ScalarAggregateOptions"};
+
+const FunctionDoc hash_any_doc{"Test whether any element evaluates to true",
+ ("Null values are ignored."),
+ {"array", "group_id_array"},
+ "ScalarAggregateOptions"};
+
+const FunctionDoc hash_all_doc{"Test whether all elements evaluate to true",
+ ("Null values are ignored."),
+ {"array", "group_id_array"},
+ "ScalarAggregateOptions"};
+
+const FunctionDoc hash_count_distinct_doc{
+ "Count the distinct values in each group",
+ ("Whether nulls/values are counted is controlled by CountOptions.\n"
+ "NaNs and signed zeroes are not normalized."),
+ {"array", "group_id_array"},
+ "CountOptions"};
+
+const FunctionDoc hash_distinct_doc{
+ "Keep the distinct values in each group",
+ ("Whether nulls/values are kept is controlled by CountOptions.\n"
+ "NaNs and signed zeroes are not normalized."),
+ {"array", "group_id_array"},
+ "CountOptions"};
+} // namespace
+
+void RegisterHashAggregateBasic(FunctionRegistry* registry) {
+ static auto default_count_options = CountOptions::Defaults();
+ static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults();
+ static auto default_tdigest_options = TDigestOptions::Defaults();
+ static auto default_variance_options = VarianceOptions::Defaults();
+
+ {
+ auto func = std::make_shared<HashAggregateFunction>(
+ "hash_count", Arity::Binary(), &hash_count_doc, &default_count_options);
+
+ DCHECK_OK(func->AddKernel(
+ MakeKernel(ValueDescr::ARRAY, HashAggregateInit<GroupedCountImpl>)));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ {
+ auto func = std::make_shared<HashAggregateFunction>(
+ "hash_sum", Arity::Binary(), &hash_sum_doc, &default_scalar_aggregate_options);
+ DCHECK_OK(AddHashAggKernels({boolean()}, GroupedSumFactory::Make, func.get()));
+ DCHECK_OK(AddHashAggKernels(SignedIntTypes(), GroupedSumFactory::Make, func.get()));
+ DCHECK_OK(AddHashAggKernels(UnsignedIntTypes(), GroupedSumFactory::Make, func.get()));
+ DCHECK_OK(
+ AddHashAggKernels(FloatingPointTypes(), GroupedSumFactory::Make, func.get()));
+ // Type parameters are ignored
+ DCHECK_OK(AddHashAggKernels({decimal128(1, 1), decimal256(1, 1)},
+ GroupedSumFactory::Make, func.get()));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ {
+ auto func = std::make_shared<HashAggregateFunction>(
+ "hash_product", Arity::Binary(), &hash_product_doc,
+ &default_scalar_aggregate_options);
+ DCHECK_OK(AddHashAggKernels({boolean()}, GroupedProductFactory::Make, func.get()));
+ DCHECK_OK(
+ AddHashAggKernels(SignedIntTypes(), GroupedProductFactory::Make, func.get()));
+ DCHECK_OK(
+ AddHashAggKernels(UnsignedIntTypes(), GroupedProductFactory::Make, func.get()));
+ DCHECK_OK(
+ AddHashAggKernels(FloatingPointTypes(), GroupedProductFactory::Make, func.get()));
+ // Type parameters are ignored
+ DCHECK_OK(AddHashAggKernels({decimal128(1, 1), decimal256(1, 1)},
+ GroupedProductFactory::Make, func.get()));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ {
+ auto func = std::make_shared<HashAggregateFunction>(
+ "hash_mean", Arity::Binary(), &hash_mean_doc, &default_scalar_aggregate_options);
+ DCHECK_OK(AddHashAggKernels({boolean()}, GroupedMeanFactory::Make, func.get()));
+ DCHECK_OK(AddHashAggKernels(SignedIntTypes(), GroupedMeanFactory::Make, func.get()));
+ DCHECK_OK(
+ AddHashAggKernels(UnsignedIntTypes(), GroupedMeanFactory::Make, func.get()));
+ DCHECK_OK(
+ AddHashAggKernels(FloatingPointTypes(), GroupedMeanFactory::Make, func.get()));
+ // Type parameters are ignored
+ DCHECK_OK(AddHashAggKernels({decimal128(1, 1), decimal256(1, 1)},
+ GroupedMeanFactory::Make, func.get()));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ {
+ auto func = std::make_shared<HashAggregateFunction>(
+ "hash_stddev", Arity::Binary(), &hash_stddev_doc, &default_variance_options);
+ DCHECK_OK(AddHashAggKernels(SignedIntTypes(),
+ GroupedVarStdFactory<VarOrStd::Std>::Make, func.get()));
+ DCHECK_OK(AddHashAggKernels(UnsignedIntTypes(),
+ GroupedVarStdFactory<VarOrStd::Std>::Make, func.get()));
+ DCHECK_OK(AddHashAggKernels(FloatingPointTypes(),
+ GroupedVarStdFactory<VarOrStd::Std>::Make, func.get()));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ {
+ auto func = std::make_shared<HashAggregateFunction>(
+ "hash_variance", Arity::Binary(), &hash_variance_doc, &default_variance_options);
+ DCHECK_OK(AddHashAggKernels(SignedIntTypes(),
+ GroupedVarStdFactory<VarOrStd::Var>::Make, func.get()));
+ DCHECK_OK(AddHashAggKernels(UnsignedIntTypes(),
+ GroupedVarStdFactory<VarOrStd::Var>::Make, func.get()));
+ DCHECK_OK(AddHashAggKernels(FloatingPointTypes(),
+ GroupedVarStdFactory<VarOrStd::Var>::Make, func.get()));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ HashAggregateFunction* tdigest_func = nullptr;
+ {
+ auto func = std::make_shared<HashAggregateFunction>(
+ "hash_tdigest", Arity::Binary(), &hash_tdigest_doc, &default_tdigest_options);
+ DCHECK_OK(
+ AddHashAggKernels(SignedIntTypes(), GroupedTDigestFactory::Make, func.get()));
+ DCHECK_OK(
+ AddHashAggKernels(UnsignedIntTypes(), GroupedTDigestFactory::Make, func.get()));
+ DCHECK_OK(
+ AddHashAggKernels(FloatingPointTypes(), GroupedTDigestFactory::Make, func.get()));
+ tdigest_func = func.get();
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ {
+ auto func = std::make_shared<HashAggregateFunction>(
+ "hash_approximate_median", Arity::Binary(), &hash_approximate_median_doc,
+ &default_scalar_aggregate_options);
+ DCHECK_OK(func->AddKernel(MakeApproximateMedianKernel(tdigest_func)));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ HashAggregateFunction* min_max_func = nullptr;
+ {
+ auto func = std::make_shared<HashAggregateFunction>(
+ "hash_min_max", Arity::Binary(), &hash_min_max_doc,
+ &default_scalar_aggregate_options);
+ DCHECK_OK(AddHashAggKernels(NumericTypes(), GroupedMinMaxFactory::Make, func.get()));
+ DCHECK_OK(AddHashAggKernels(TemporalTypes(), GroupedMinMaxFactory::Make, func.get()));
+ // Type parameters are ignored
+ DCHECK_OK(AddHashAggKernels(
+ {null(), boolean(), decimal128(1, 1), decimal256(1, 1), month_interval()},
+ GroupedMinMaxFactory::Make, func.get()));
+ min_max_func = func.get();
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ {
+ auto func = std::make_shared<HashAggregateFunction>(
+ "hash_min", Arity::Binary(), &hash_min_or_max_doc,
+ &default_scalar_aggregate_options);
+ DCHECK_OK(func->AddKernel(MakeMinOrMaxKernel<MinOrMax::Min>(min_max_func)));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ {
+ auto func = std::make_shared<HashAggregateFunction>(
+ "hash_max", Arity::Binary(), &hash_min_or_max_doc,
+ &default_scalar_aggregate_options);
+ DCHECK_OK(func->AddKernel(MakeMinOrMaxKernel<MinOrMax::Max>(min_max_func)));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ {
+ auto func = std::make_shared<HashAggregateFunction>(
+ "hash_any", Arity::Binary(), &hash_any_doc, &default_scalar_aggregate_options);
+ DCHECK_OK(func->AddKernel(MakeKernel(boolean(), HashAggregateInit<GroupedAnyImpl>)));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ {
+ auto func = std::make_shared<HashAggregateFunction>(
+ "hash_all", Arity::Binary(), &hash_all_doc, &default_scalar_aggregate_options);
+ DCHECK_OK(func->AddKernel(MakeKernel(boolean(), HashAggregateInit<GroupedAllImpl>)));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ {
+ auto func = std::make_shared<HashAggregateFunction>(
+ "hash_count_distinct", Arity::Binary(), &hash_count_distinct_doc,
+ &default_count_options);
+ DCHECK_OK(func->AddKernel(
+ MakeKernel(ValueDescr::ARRAY, GroupedDistinctInit<GroupedCountDistinctImpl>)));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ {
+ auto func = std::make_shared<HashAggregateFunction>(
+ "hash_distinct", Arity::Binary(), &hash_distinct_doc, &default_count_options);
+ DCHECK_OK(func->AddKernel(
+ MakeKernel(ValueDescr::ARRAY, GroupedDistinctInit<GroupedDistinctImpl>)));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
new file mode 100644
index 000000000..b98a36909
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -0,0 +1,2612 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <limits>
+#include <memory>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+
+#include "arrow/array.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/chunked_array.h"
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/test_util.h"
+#include "arrow/compute/exec_internal.h"
+#include "arrow/compute/kernels/aggregate_internal.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/compute/registry.h"
+#include "arrow/table.h"
+#include "arrow/testing/generator.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+#include "arrow/testing/random.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/thread_pool.h"
+#include "arrow/util/vector.h"
+
+using testing::Eq;
+using testing::HasSubstr;
+
+namespace arrow {
+
+using internal::BitmapReader;
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+namespace {
+
+Result<Datum> NaiveGroupBy(std::vector<Datum> arguments, std::vector<Datum> keys,
+ const std::vector<internal::Aggregate>& aggregates) {
+ ARROW_ASSIGN_OR_RAISE(auto key_batch, ExecBatch::Make(std::move(keys)));
+
+ ARROW_ASSIGN_OR_RAISE(auto grouper,
+ internal::Grouper::Make(key_batch.GetDescriptors()));
+
+ ARROW_ASSIGN_OR_RAISE(Datum id_batch, grouper->Consume(key_batch));
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto groupings, internal::Grouper::MakeGroupings(*id_batch.array_as<UInt32Array>(),
+ grouper->num_groups()));
+
+ ArrayVector out_columns;
+ std::vector<std::string> out_names;
+
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ out_names.push_back(aggregates[i].function);
+
+ // trim "hash_" prefix
+ auto scalar_agg_function = aggregates[i].function.substr(5);
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto grouped_argument,
+ internal::Grouper::ApplyGroupings(*groupings, *arguments[i].make_array()));
+
+ ScalarVector aggregated_scalars;
+
+ for (int64_t i_group = 0; i_group < grouper->num_groups(); ++i_group) {
+ auto slice = grouped_argument->value_slice(i_group);
+ if (slice->length() == 0) continue;
+ ARROW_ASSIGN_OR_RAISE(
+ Datum d, CallFunction(scalar_agg_function, {slice}, aggregates[i].options));
+ aggregated_scalars.push_back(d.scalar());
+ }
+
+ ARROW_ASSIGN_OR_RAISE(Datum aggregated_column,
+ ScalarVectorToArray(aggregated_scalars));
+ out_columns.push_back(aggregated_column.make_array());
+ }
+
+ int i = 0;
+ ARROW_ASSIGN_OR_RAISE(auto uniques, grouper->GetUniques());
+ for (const Datum& key : uniques.values) {
+ out_columns.push_back(key.make_array());
+ out_names.push_back("key_" + std::to_string(i++));
+ }
+
+ return StructArray::Make(std::move(out_columns), std::move(out_names));
+}
+
+Result<Datum> GroupByUsingExecPlan(const BatchesWithSchema& input,
+ const std::vector<std::string>& key_names,
+ const std::vector<std::string>& arg_names,
+ const std::vector<internal::Aggregate>& aggregates,
+ bool use_threads, ExecContext* ctx) {
+ std::vector<FieldRef> keys(key_names.size());
+ std::vector<FieldRef> targets(aggregates.size());
+ std::vector<std::string> names(aggregates.size());
+ for (size_t i = 0; i < aggregates.size(); ++i) {
+ names[i] = aggregates[i].function;
+ targets[i] = FieldRef(arg_names[i]);
+ }
+ for (size_t i = 0; i < key_names.size(); ++i) {
+ keys[i] = FieldRef(key_names[i]);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make(ctx));
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+ RETURN_NOT_OK(
+ Declaration::Sequence(
+ {
+ {"source",
+ SourceNodeOptions{input.schema, input.gen(use_threads, /*slow=*/false)}},
+ {"aggregate",
+ AggregateNodeOptions{std::move(aggregates), std::move(targets),
+ std::move(names), std::move(keys)}},
+ {"sink", SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ RETURN_NOT_OK(plan->Validate());
+ RETURN_NOT_OK(plan->StartProducing());
+
+ auto collected_fut = CollectAsyncGenerator(sink_gen);
+
+ auto start_and_collect =
+ AllComplete({plan->finished(), Future<>(collected_fut)})
+ .Then([collected_fut]() -> Result<std::vector<ExecBatch>> {
+ ARROW_ASSIGN_OR_RAISE(auto collected, collected_fut.result());
+ return ::arrow::internal::MapVector(
+ [](util::optional<ExecBatch> batch) { return std::move(*batch); },
+ std::move(collected));
+ });
+
+ ARROW_ASSIGN_OR_RAISE(std::vector<ExecBatch> output_batches,
+ start_and_collect.MoveResult());
+
+ ArrayVector out_arrays(aggregates.size() + key_names.size());
+ const auto& output_schema = plan->sources()[0]->outputs()[0]->output_schema();
+ for (size_t i = 0; i < out_arrays.size(); ++i) {
+ std::vector<std::shared_ptr<Array>> arrays(output_batches.size());
+ for (size_t j = 0; j < output_batches.size(); ++j) {
+ arrays[j] = output_batches[j].values[i].make_array();
+ }
+ if (arrays.empty()) {
+ ARROW_ASSIGN_OR_RAISE(
+ out_arrays[i],
+ MakeArrayOfNull(output_schema->field(static_cast<int>(i))->type(),
+ /*length=*/0));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(out_arrays[i], Concatenate(arrays));
+ }
+ }
+
+ return StructArray::Make(std::move(out_arrays), output_schema->fields());
+}
+
+/// Simpler overload where you can give the columns as datums
+Result<Datum> GroupByUsingExecPlan(const std::vector<Datum>& arguments,
+ const std::vector<Datum>& keys,
+ const std::vector<internal::Aggregate>& aggregates,
+ bool use_threads, ExecContext* ctx) {
+ using arrow::compute::detail::ExecBatchIterator;
+
+ FieldVector scan_fields(arguments.size() + keys.size());
+ std::vector<std::string> key_names(keys.size());
+ std::vector<std::string> arg_names(arguments.size());
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ auto name = std::string("agg_") + std::to_string(i);
+ scan_fields[i] = field(name, arguments[i].type());
+ arg_names[i] = std::move(name);
+ }
+ for (size_t i = 0; i < keys.size(); ++i) {
+ auto name = std::string("key_") + std::to_string(i);
+ scan_fields[arguments.size() + i] = field(name, keys[i].type());
+ key_names[i] = std::move(name);
+ }
+
+ std::vector<Datum> inputs = arguments;
+ inputs.reserve(inputs.size() + keys.size());
+ inputs.insert(inputs.end(), keys.begin(), keys.end());
+
+ ARROW_ASSIGN_OR_RAISE(auto batch_iterator,
+ ExecBatchIterator::Make(inputs, ctx->exec_chunksize()));
+ BatchesWithSchema input;
+ input.schema = schema(std::move(scan_fields));
+ ExecBatch batch;
+ while (batch_iterator->Next(&batch)) {
+ if (batch.length == 0) continue;
+ input.batches.push_back(std::move(batch));
+ }
+
+ return GroupByUsingExecPlan(input, key_names, arg_names, aggregates, use_threads, ctx);
+}
+
+void ValidateGroupBy(const std::vector<internal::Aggregate>& aggregates,
+ std::vector<Datum> arguments, std::vector<Datum> keys) {
+ ASSERT_OK_AND_ASSIGN(Datum expected, NaiveGroupBy(arguments, keys, aggregates));
+
+ ASSERT_OK_AND_ASSIGN(Datum actual, GroupBy(arguments, keys, aggregates));
+
+ ASSERT_OK(expected.make_array()->ValidateFull());
+ ValidateOutput(actual);
+
+ AssertDatumsEqual(expected, actual, /*verbose=*/true);
+}
+
+ExecContext* small_chunksize_context(bool use_threads = false) {
+ static ExecContext ctx,
+ ctx_with_threads{default_memory_pool(), arrow::internal::GetCpuThreadPool()};
+ ctx.set_exec_chunksize(2);
+ ctx_with_threads.set_exec_chunksize(2);
+ return use_threads ? &ctx_with_threads : &ctx;
+}
+
+Result<Datum> GroupByTest(
+ const std::vector<Datum>& arguments, const std::vector<Datum>& keys,
+ const std::vector<::arrow::compute::internal::Aggregate>& aggregates,
+ bool use_threads, bool use_exec_plan) {
+ if (use_exec_plan) {
+ return GroupByUsingExecPlan(arguments, keys, aggregates, use_threads,
+ small_chunksize_context(use_threads));
+ } else {
+ return internal::GroupBy(arguments, keys, aggregates, use_threads,
+ default_exec_context());
+ }
+}
+
+} // namespace
+
+TEST(Grouper, SupportedKeys) {
+ ASSERT_OK(internal::Grouper::Make({boolean()}));
+
+ ASSERT_OK(internal::Grouper::Make({int8(), uint16(), int32(), uint64()}));
+
+ ASSERT_OK(internal::Grouper::Make({dictionary(int64(), utf8())}));
+
+ ASSERT_OK(internal::Grouper::Make({float16(), float32(), float64()}));
+
+ ASSERT_OK(internal::Grouper::Make({utf8(), binary(), large_utf8(), large_binary()}));
+
+ ASSERT_OK(internal::Grouper::Make({fixed_size_binary(16), fixed_size_binary(32)}));
+
+ ASSERT_OK(internal::Grouper::Make({decimal128(32, 10), decimal256(76, 20)}));
+
+ ASSERT_OK(internal::Grouper::Make({date32(), date64()}));
+
+ for (auto unit : {
+ TimeUnit::SECOND,
+ TimeUnit::MILLI,
+ TimeUnit::MICRO,
+ TimeUnit::NANO,
+ }) {
+ ASSERT_OK(internal::Grouper::Make({timestamp(unit), duration(unit)}));
+ }
+
+ ASSERT_OK(internal::Grouper::Make({day_time_interval(), month_interval()}));
+
+ ASSERT_RAISES(NotImplemented, internal::Grouper::Make({struct_({field("", int64())})}));
+
+ ASSERT_RAISES(NotImplemented, internal::Grouper::Make({struct_({})}));
+
+ ASSERT_RAISES(NotImplemented, internal::Grouper::Make({list(int32())}));
+
+ ASSERT_RAISES(NotImplemented, internal::Grouper::Make({fixed_size_list(int32(), 5)}));
+
+ ASSERT_RAISES(NotImplemented,
+ internal::Grouper::Make({dense_union({field("", int32())})}));
+}
+
+struct TestGrouper {
+ explicit TestGrouper(std::vector<ValueDescr> descrs) : descrs_(std::move(descrs)) {
+ grouper_ = internal::Grouper::Make(descrs_).ValueOrDie();
+
+ FieldVector fields;
+ for (const auto& descr : descrs_) {
+ fields.push_back(field("", descr.type));
+ }
+ key_schema_ = schema(std::move(fields));
+ }
+
+ void ExpectConsume(const std::string& key_json, const std::string& expected) {
+ ExpectConsume(ExecBatchFromJSON(descrs_, key_json),
+ ArrayFromJSON(uint32(), expected));
+ }
+
+ void ExpectConsume(const std::vector<Datum>& key_values, Datum expected) {
+ ASSERT_OK_AND_ASSIGN(auto key_batch, ExecBatch::Make(key_values));
+ ExpectConsume(key_batch, expected);
+ }
+
+ void ExpectConsume(const ExecBatch& key_batch, Datum expected) {
+ Datum ids;
+ ConsumeAndValidate(key_batch, &ids);
+ AssertEquivalentIds(expected, ids);
+ }
+
+ void ExpectUniques(const ExecBatch& uniques) {
+ EXPECT_THAT(grouper_->GetUniques(), ResultWith(Eq(uniques)));
+ }
+
+ void ExpectUniques(const std::string& uniques_json) {
+ ExpectUniques(ExecBatchFromJSON(descrs_, uniques_json));
+ }
+
+ void AssertEquivalentIds(const Datum& expected, const Datum& actual) {
+ auto left = expected.make_array();
+ auto right = actual.make_array();
+ ASSERT_EQ(left->length(), right->length()) << "#ids unequal";
+ int64_t num_ids = left->length();
+ auto left_data = left->data();
+ auto right_data = right->data();
+ auto left_ids = reinterpret_cast<const uint32_t*>(left_data->buffers[1]->data());
+ auto right_ids = reinterpret_cast<const uint32_t*>(right_data->buffers[1]->data());
+ uint32_t max_left_id = 0;
+ uint32_t max_right_id = 0;
+ for (int64_t i = 0; i < num_ids; ++i) {
+ if (left_ids[i] > max_left_id) {
+ max_left_id = left_ids[i];
+ }
+ if (right_ids[i] > max_right_id) {
+ max_right_id = right_ids[i];
+ }
+ }
+ std::vector<bool> right_to_left_present(max_right_id + 1, false);
+ std::vector<bool> left_to_right_present(max_left_id + 1, false);
+ std::vector<uint32_t> right_to_left(max_right_id + 1);
+ std::vector<uint32_t> left_to_right(max_left_id + 1);
+ for (int64_t i = 0; i < num_ids; ++i) {
+ uint32_t left_id = left_ids[i];
+ uint32_t right_id = right_ids[i];
+ if (!left_to_right_present[left_id]) {
+ left_to_right[left_id] = right_id;
+ left_to_right_present[left_id] = true;
+ }
+ if (!right_to_left_present[right_id]) {
+ right_to_left[right_id] = left_id;
+ right_to_left_present[right_id] = true;
+ }
+ ASSERT_EQ(left_id, right_to_left[right_id]);
+ ASSERT_EQ(right_id, left_to_right[left_id]);
+ }
+ }
+
+ void ConsumeAndValidate(const ExecBatch& key_batch, Datum* ids = nullptr) {
+ ASSERT_OK_AND_ASSIGN(Datum id_batch, grouper_->Consume(key_batch));
+
+ ValidateConsume(key_batch, id_batch);
+
+ if (ids) {
+ *ids = std::move(id_batch);
+ }
+ }
+
+ void ValidateConsume(const ExecBatch& key_batch, const Datum& id_batch) {
+ if (uniques_.length == -1) {
+ ASSERT_OK_AND_ASSIGN(uniques_, grouper_->GetUniques());
+ } else if (static_cast<int64_t>(grouper_->num_groups()) > uniques_.length) {
+ ASSERT_OK_AND_ASSIGN(ExecBatch new_uniques, grouper_->GetUniques());
+
+ // check that uniques_ are prefixes of new_uniques
+ for (int i = 0; i < uniques_.num_values(); ++i) {
+ auto new_unique = new_uniques[i].make_array();
+ ValidateOutput(*new_unique);
+
+ AssertDatumsEqual(uniques_[i], new_unique->Slice(0, uniques_.length),
+ /*verbose=*/true);
+ }
+
+ uniques_ = std::move(new_uniques);
+ }
+
+ // check that the ids encode an equivalent key sequence
+ auto ids = id_batch.make_array();
+ ValidateOutput(*ids);
+
+ for (int i = 0; i < key_batch.num_values(); ++i) {
+ SCOPED_TRACE(std::to_string(i) + "th key array");
+ auto original =
+ key_batch[i].is_array()
+ ? key_batch[i].make_array()
+ : *MakeArrayFromScalar(*key_batch[i].scalar(), key_batch.length);
+ ASSERT_OK_AND_ASSIGN(auto encoded, Take(*uniques_[i].make_array(), *ids));
+ AssertArraysEqual(*original, *encoded, /*verbose=*/true,
+ EqualOptions().nans_equal(true));
+ }
+ }
+
+ std::vector<ValueDescr> descrs_;
+ std::shared_ptr<Schema> key_schema_;
+ std::unique_ptr<internal::Grouper> grouper_;
+ ExecBatch uniques_ = ExecBatch({}, -1);
+};
+
+TEST(Grouper, BooleanKey) {
+ TestGrouper g({boolean()});
+
+ g.ExpectConsume("[[true], [true]]", "[0, 0]");
+
+ g.ExpectConsume("[[true], [true]]", "[0, 0]");
+
+ g.ExpectConsume("[[false], [null]]", "[1, 2]");
+
+ g.ExpectConsume("[[true], [false], [true], [false], [null], [false], [null]]",
+ "[0, 1, 0, 1, 2, 1, 2]");
+}
+
+TEST(Grouper, NumericKey) {
+ for (auto ty : {
+ uint8(),
+ int8(),
+ uint16(),
+ int16(),
+ uint32(),
+ int32(),
+ uint64(),
+ int64(),
+ float16(),
+ float32(),
+ float64(),
+ }) {
+ SCOPED_TRACE("key type: " + ty->ToString());
+
+ TestGrouper g({ty});
+
+ g.ExpectConsume("[[3], [3]]", "[0, 0]");
+ g.ExpectUniques("[[3]]");
+
+ g.ExpectConsume("[[3], [3]]", "[0, 0]");
+ g.ExpectUniques("[[3]]");
+
+ g.ExpectConsume("[[27], [81], [81]]", "[1, 2, 2]");
+ g.ExpectUniques("[[3], [27], [81]]");
+
+ g.ExpectConsume("[[3], [27], [3], [27], [null], [81], [27], [81]]",
+ "[0, 1, 0, 1, 3, 2, 1, 2]");
+ g.ExpectUniques("[[3], [27], [81], [null]]");
+ }
+}
+
+TEST(Grouper, FloatingPointKey) {
+ TestGrouper g({float32()});
+
+ // -0.0 hashes differently from 0.0
+ g.ExpectConsume("[[0.0], [-0.0]]", "[0, 1]");
+
+ g.ExpectConsume("[[Inf], [-Inf]]", "[2, 3]");
+
+ // assert(!(NaN == NaN)) does not cause spurious new groups
+ g.ExpectConsume("[[NaN], [NaN]]", "[4, 4]");
+
+ // TODO(bkietz) test denormal numbers, more NaNs
+}
+
+TEST(Grouper, StringKey) {
+ for (auto ty : {utf8(), large_utf8(), fixed_size_binary(2)}) {
+ SCOPED_TRACE("key type: " + ty->ToString());
+
+ TestGrouper g({ty});
+
+ g.ExpectConsume(R"([["eh"], ["eh"]])", "[0, 0]");
+
+ g.ExpectConsume(R"([["eh"], ["eh"]])", "[0, 0]");
+
+ g.ExpectConsume(R"([["be"], [null]])", "[1, 2]");
+ }
+}
+
+TEST(Grouper, DictKey) {
+ TestGrouper g({dictionary(int32(), utf8())});
+
+ // For dictionary keys, all batches must share a single dictionary.
+ // Eventually, differing dictionaries will be unified and indices transposed
+ // during encoding to relieve this restriction.
+ const auto dict = ArrayFromJSON(utf8(), R"(["ex", "why", "zee", null])");
+
+ auto WithIndices = [&](const std::string& indices) {
+ return Datum(*DictionaryArray::FromArrays(ArrayFromJSON(int32(), indices), dict));
+ };
+
+ // NB: null index is not considered equivalent to index=3 (which encodes null in dict)
+ g.ExpectConsume({WithIndices(" [3, 1, null, 0, 2]")},
+ ArrayFromJSON(uint32(), "[0, 1, 2, 3, 4]"));
+
+ g = TestGrouper({dictionary(int32(), utf8())});
+
+ g.ExpectConsume({WithIndices(" [0, 1, 2, 3, null]")},
+ ArrayFromJSON(uint32(), "[0, 1, 2, 3, 4]"));
+
+ g.ExpectConsume({WithIndices(" [3, 1, null, 0, 2]")},
+ ArrayFromJSON(uint32(), "[3, 1, 4, 0, 2]"));
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ NotImplemented, HasSubstr("Unifying differing dictionaries"),
+ g.grouper_->Consume(*ExecBatch::Make({*DictionaryArray::FromArrays(
+ ArrayFromJSON(int32(), "[0, 1]"),
+ ArrayFromJSON(utf8(), R"(["different", "dictionary"])"))})));
+}
+
+TEST(Grouper, StringInt64Key) {
+ TestGrouper g({utf8(), int64()});
+
+ g.ExpectConsume(R"([["eh", 0], ["eh", 0]])", "[0, 0]");
+
+ g.ExpectConsume(R"([["eh", 0], ["eh", null]])", "[0, 1]");
+
+ g.ExpectConsume(R"([["eh", 1], ["bee", 1]])", "[2, 3]");
+
+ g.ExpectConsume(R"([["eh", null], ["bee", 1]])", "[1, 3]");
+
+ g = TestGrouper({utf8(), int64()});
+
+ g.ExpectConsume(R"([
+ ["ex", 0],
+ ["ex", 0],
+ ["why", 0],
+ ["ex", 1],
+ ["why", 0],
+ ["ex", 1],
+ ["ex", 0],
+ ["why", 1]
+ ])",
+ "[0, 0, 1, 2, 1, 2, 0, 3]");
+
+ g.ExpectConsume(R"([
+ ["ex", 0],
+ [null, 0],
+ [null, 0],
+ ["ex", 1],
+ [null, null],
+ ["ex", 1],
+ ["ex", 0],
+ ["why", null]
+ ])",
+ "[0, 4, 4, 2, 5, 2, 0, 6]");
+}
+
+TEST(Grouper, DoubleStringInt64Key) {
+ TestGrouper g({float64(), utf8(), int64()});
+
+ g.ExpectConsume(R"([[1.5, "eh", 0], [1.5, "eh", 0]])", "[0, 0]");
+
+ g.ExpectConsume(R"([[1.5, "eh", 0], [1.5, "eh", 0]])", "[0, 0]");
+
+ g.ExpectConsume(R"([[1.0, "eh", 0], [1.0, "be", null]])", "[1, 2]");
+
+ // note: -0 and +0 hash differently
+ g.ExpectConsume(R"([[-0.0, "be", 7], [0.0, "be", 7]])", "[3, 4]");
+}
+
+TEST(Grouper, RandomInt64Keys) {
+ TestGrouper g({int64()});
+ for (int i = 0; i < 4; ++i) {
+ SCOPED_TRACE(std::to_string(i) + "th key batch");
+
+ ExecBatch key_batch{
+ *random::GenerateBatch(g.key_schema_->fields(), 1 << 12, 0xDEADBEEF)};
+ g.ConsumeAndValidate(key_batch);
+ }
+}
+
+TEST(Grouper, RandomStringInt64Keys) {
+ TestGrouper g({utf8(), int64()});
+ for (int i = 0; i < 4; ++i) {
+ SCOPED_TRACE(std::to_string(i) + "th key batch");
+
+ ExecBatch key_batch{
+ *random::GenerateBatch(g.key_schema_->fields(), 1 << 12, 0xDEADBEEF)};
+ g.ConsumeAndValidate(key_batch);
+ }
+}
+
+TEST(Grouper, RandomStringInt64DoubleInt32Keys) {
+ TestGrouper g({utf8(), int64(), float64(), int32()});
+ for (int i = 0; i < 4; ++i) {
+ SCOPED_TRACE(std::to_string(i) + "th key batch");
+
+ ExecBatch key_batch{
+ *random::GenerateBatch(g.key_schema_->fields(), 1 << 12, 0xDEADBEEF)};
+ g.ConsumeAndValidate(key_batch);
+ }
+}
+
+TEST(Grouper, MakeGroupings) {
+ auto ExpectGroupings = [](std::string ids_json, std::string expected_json) {
+ auto ids = checked_pointer_cast<UInt32Array>(ArrayFromJSON(uint32(), ids_json));
+ auto expected = ArrayFromJSON(list(int32()), expected_json);
+
+ auto num_groups = static_cast<uint32_t>(expected->length());
+ ASSERT_OK_AND_ASSIGN(auto actual, internal::Grouper::MakeGroupings(*ids, num_groups));
+ AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+
+ // validate ApplyGroupings
+ ASSERT_OK_AND_ASSIGN(auto grouped_ids,
+ internal::Grouper::ApplyGroupings(*actual, *ids));
+
+ for (uint32_t group = 0; group < num_groups; ++group) {
+ auto ids_slice = checked_pointer_cast<UInt32Array>(grouped_ids->value_slice(group));
+ for (auto slot : *ids_slice) {
+ EXPECT_EQ(slot, group);
+ }
+ }
+ };
+
+ ExpectGroupings("[]", "[[]]");
+
+ ExpectGroupings("[0, 0, 0]", "[[0, 1, 2]]");
+
+ ExpectGroupings("[0, 0, 0, 1, 1, 2]", "[[0, 1, 2], [3, 4], [5], []]");
+
+ ExpectGroupings("[2, 1, 2, 1, 1, 2]", "[[], [1, 3, 4], [0, 2, 5], [], []]");
+
+ ExpectGroupings("[2, 2, 5, 5, 2, 3]", "[[], [], [0, 1, 4], [5], [], [2, 3], [], []]");
+
+ auto ids = checked_pointer_cast<UInt32Array>(ArrayFromJSON(uint32(), "[0, null, 1]"));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("MakeGroupings with null ids"),
+ internal::Grouper::MakeGroupings(*ids, 5));
+}
+
+TEST(Grouper, ScalarValues) {
+ // large_utf8 forces GrouperImpl over GrouperFastImpl
+ for (const auto& str_type : {utf8(), large_utf8()}) {
+ {
+ TestGrouper g({ValueDescr::Scalar(boolean()), ValueDescr::Scalar(int32()),
+ ValueDescr::Scalar(decimal128(3, 2)),
+ ValueDescr::Scalar(decimal256(3, 2)),
+ ValueDescr::Scalar(fixed_size_binary(2)),
+ ValueDescr::Scalar(str_type), ValueDescr::Array(int32())});
+ g.ExpectConsume(
+ R"([
+[true, 1, "1.00", "2.00", "ab", "foo", 2],
+[true, 1, "1.00", "2.00", "ab", "foo", 2],
+[true, 1, "1.00", "2.00", "ab", "foo", 3]
+])",
+ "[0, 0, 1]");
+ }
+ {
+ auto dict_type = dictionary(int32(), utf8());
+ TestGrouper g({ValueDescr::Scalar(dict_type), ValueDescr::Scalar(str_type)});
+ const auto dict = R"(["foo", null])";
+ g.ExpectConsume(
+ {DictScalarFromJSON(dict_type, "0", dict), ScalarFromJSON(str_type, R"("")")},
+ ArrayFromJSON(uint32(), "[0]"));
+ g.ExpectConsume(
+ {DictScalarFromJSON(dict_type, "1", dict), ScalarFromJSON(str_type, R"("")")},
+ ArrayFromJSON(uint32(), "[1]"));
+ }
+ }
+}
+
+TEST(GroupBy, Errors) {
+ auto batch = RecordBatchFromJSON(
+ schema({field("argument", float64()), field("group_id", uint32())}), R"([
+ [1.0, 1],
+ [null, 1],
+ [0.0, 2],
+ [null, 3],
+ [4.0, 0],
+ [3.25, 1],
+ [0.125, 2],
+ [-0.25, 2],
+ [0.75, 0],
+ [null, 3]
+ ])");
+
+ EXPECT_THAT(CallFunction("hash_sum", {batch->GetColumnByName("argument"),
+ batch->GetColumnByName("group_id")}),
+ Raises(StatusCode::NotImplemented,
+ HasSubstr("Direct execution of HASH_AGGREGATE functions")));
+}
+
+TEST(GroupBy, NoBatches) {
+ // Regression test for ARROW-14583: handle when no batches are
+ // passed to the group by node before finalizing
+ auto table =
+ TableFromJSON(schema({field("argument", float64()), field("key", int64())}), {});
+ ASSERT_OK_AND_ASSIGN(
+ Datum aggregated_and_grouped,
+ GroupByTest({table->GetColumnByName("argument")}, {table->GetColumnByName("key")},
+ {
+ {"hash_count", nullptr},
+ },
+ /*use_threads=*/true, /*use_exec_plan=*/true));
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_count", int64()),
+ field("key_0", int64()),
+ }),
+ R"([])"),
+ aggregated_and_grouped, /*verbose=*/true);
+}
+
+namespace {
+void SortBy(std::vector<std::string> names, Datum* aggregated_and_grouped) {
+ SortOptions options;
+ for (auto&& name : names) {
+ options.sort_keys.emplace_back(std::move(name), SortOrder::Ascending);
+ }
+
+ ASSERT_OK_AND_ASSIGN(
+ auto batch, RecordBatch::FromStructArray(aggregated_and_grouped->make_array()));
+
+ // decode any dictionary columns:
+ ArrayVector cols = batch->columns();
+ for (auto& col : cols) {
+ if (col->type_id() != Type::DICTIONARY) continue;
+
+ auto dict_col = checked_cast<const DictionaryArray*>(col.get());
+ ASSERT_OK_AND_ASSIGN(col, Take(*dict_col->dictionary(), *dict_col->indices()));
+ }
+ batch = RecordBatch::Make(batch->schema(), batch->num_rows(), std::move(cols));
+
+ ASSERT_OK_AND_ASSIGN(Datum sort_indices, SortIndices(batch, options));
+
+ ASSERT_OK_AND_ASSIGN(*aggregated_and_grouped,
+ Take(*aggregated_and_grouped, sort_indices));
+}
+} // namespace
+
+TEST(GroupBy, CountOnly) {
+ for (bool use_exec_plan : {false, true}) {
+ for (bool use_threads : {true, false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+ auto table = TableFromJSON(
+ schema({field("argument", float64()), field("key", int64())}), {R"([
+ [1.0, 1],
+ [null, 1]
+ ])",
+ R"([
+ [0.0, 2],
+ [null, 3],
+ [4.0, null],
+ [3.25, 1],
+ [0.125, 2]
+ ])",
+ R"([
+ [-0.25, 2],
+ [0.75, null],
+ [null, 3]
+ ])"});
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ GroupByTest({table->GetColumnByName("argument")},
+ {table->GetColumnByName("key")},
+ {
+ {"hash_count", nullptr},
+ },
+ use_threads, use_exec_plan));
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_count", int64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [2, 1],
+ [3, 2],
+ [0, 3],
+ [2, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+ }
+}
+
+TEST(GroupBy, CountScalar) {
+ BatchesWithSchema input;
+ input.batches = {
+ ExecBatchFromJSON({ValueDescr::Scalar(int32()), int64()},
+ "[[1, 1], [1, 1], [1, 2], [1, 3]]"),
+ ExecBatchFromJSON({ValueDescr::Scalar(int32()), int64()},
+ "[[null, 1], [null, 1], [null, 2], [null, 3]]"),
+ ExecBatchFromJSON({int32(), int64()}, "[[2, 1], [3, 2], [4, 3]]"),
+ };
+ input.schema = schema({field("argument", int32()), field("key", int64())});
+
+ CountOptions skip_nulls(CountOptions::ONLY_VALID);
+ CountOptions keep_nulls(CountOptions::ONLY_NULL);
+ CountOptions count_all(CountOptions::ALL);
+ for (bool use_threads : {true, false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+ ASSERT_OK_AND_ASSIGN(
+ Datum actual,
+ GroupByUsingExecPlan(input, {"key"}, {"argument", "argument", "argument"},
+ {
+ {"hash_count", &skip_nulls},
+ {"hash_count", &keep_nulls},
+ {"hash_count", &count_all},
+ },
+ use_threads, default_exec_context()));
+ Datum expected = ArrayFromJSON(struct_({
+ field("hash_count", int64()),
+ field("hash_count", int64()),
+ field("hash_count", int64()),
+ field("key", int64()),
+ }),
+ R"([
+ [3, 2, 5, 1],
+ [2, 1, 3, 2],
+ [2, 1, 3, 3]
+ ])");
+ AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, SumOnly) {
+ for (bool use_exec_plan : {false, true}) {
+ for (bool use_threads : {true, false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+ auto table = TableFromJSON(
+ schema({field("argument", float64()), field("key", int64())}), {R"([
+ [1.0, 1],
+ [null, 1]
+ ])",
+ R"([
+ [0.0, 2],
+ [null, 3],
+ [4.0, null],
+ [3.25, 1],
+ [0.125, 2]
+ ])",
+ R"([
+ [-0.25, 2],
+ [0.75, null],
+ [null, 3]
+ ])"});
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ GroupByTest({table->GetColumnByName("argument")},
+ {table->GetColumnByName("key")},
+ {
+ {"hash_sum", nullptr},
+ },
+ use_threads, use_exec_plan));
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_sum", float64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [4.25, 1],
+ [-0.125, 2],
+ [null, 3],
+ [4.75, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+ }
+}
+
+TEST(GroupBy, SumMeanProductDecimal) {
+ auto in_schema = schema({
+ field("argument0", decimal128(3, 2)),
+ field("argument1", decimal256(3, 2)),
+ field("key", int64()),
+ });
+
+ for (bool use_exec_plan : {false, true}) {
+ for (bool use_threads : {true, false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+ auto table = TableFromJSON(in_schema, {R"([
+ ["1.00", "1.00", 1],
+ [null, null, 1]
+ ])",
+ R"([
+ ["0.00", "0.00", 2],
+ [null, null, 3],
+ ["4.00", "4.00", null],
+ ["3.25", "3.25", 1],
+ ["0.12", "0.12", 2]
+ ])",
+ R"([
+ ["-0.25", "-0.25", 2],
+ ["0.75", "0.75", null],
+ [null, null, 3]
+ ])"});
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ GroupByTest(
+ {
+ table->GetColumnByName("argument0"),
+ table->GetColumnByName("argument1"),
+ table->GetColumnByName("argument0"),
+ table->GetColumnByName("argument1"),
+ table->GetColumnByName("argument0"),
+ table->GetColumnByName("argument1"),
+ },
+ {table->GetColumnByName("key")},
+ {
+ {"hash_sum", nullptr},
+ {"hash_sum", nullptr},
+ {"hash_mean", nullptr},
+ {"hash_mean", nullptr},
+ {"hash_product", nullptr},
+ {"hash_product", nullptr},
+ },
+ use_threads, use_exec_plan));
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_sum", decimal128(3, 2)),
+ field("hash_sum", decimal256(3, 2)),
+ field("hash_mean", decimal128(3, 2)),
+ field("hash_mean", decimal256(3, 2)),
+ field("hash_product", decimal128(3, 2)),
+ field("hash_product", decimal256(3, 2)),
+ field("key_0", int64()),
+ }),
+ R"([
+ ["4.25", "4.25", "2.12", "2.12", "3.25", "3.25", 1],
+ ["-0.13", "-0.13", "-0.04", "-0.04", "0.00", "0.00", 2],
+ [null, null, null, null, null, null, 3],
+ ["4.75", "4.75", "2.37", "2.37", "3.00", "3.00", null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+ }
+}
+
+TEST(GroupBy, MeanOnly) {
+ for (bool use_threads : {true, false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+ auto table =
+ TableFromJSON(schema({field("argument", float64()), field("key", int64())}), {R"([
+ [1.0, 1],
+ [null, 1]
+ ])",
+ R"([
+ [0.0, 2],
+ [null, 3],
+ [4.0, null],
+ [3.25, 1],
+ [0.125, 2]
+ ])",
+ R"([
+ [-0.25, 2],
+ [0.75, null],
+ [null, 3]
+ ])"});
+
+ ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/3);
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ internal::GroupBy({table->GetColumnByName("argument"),
+ table->GetColumnByName("argument")},
+ {table->GetColumnByName("key")},
+ {
+ {"hash_mean", nullptr},
+ {"hash_mean", &min_count},
+ },
+ use_threads));
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsApproxEqual(ArrayFromJSON(struct_({
+ field("hash_mean", float64()),
+ field("hash_mean", float64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [2.125, null, 1],
+ [-0.041666666666666664, -0.041666666666666664, 2],
+ [null, null, 3],
+ [2.375, null, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, SumMeanProductScalar) {
+ BatchesWithSchema input;
+ input.batches = {
+ ExecBatchFromJSON({ValueDescr::Scalar(int32()), int64()},
+ "[[1, 1], [1, 1], [1, 2], [1, 3]]"),
+ ExecBatchFromJSON({ValueDescr::Scalar(int32()), int64()},
+ "[[null, 1], [null, 1], [null, 2], [null, 3]]"),
+ ExecBatchFromJSON({int32(), int64()}, "[[2, 1], [3, 2], [4, 3]]"),
+ };
+ input.schema = schema({field("argument", int32()), field("key", int64())});
+
+ for (bool use_threads : {true, false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+ ASSERT_OK_AND_ASSIGN(
+ Datum actual,
+ GroupByUsingExecPlan(input, {"key"}, {"argument", "argument", "argument"},
+ {
+ {"hash_sum", nullptr},
+ {"hash_mean", nullptr},
+ {"hash_product", nullptr},
+ },
+ use_threads, default_exec_context()));
+ Datum expected = ArrayFromJSON(struct_({
+ field("hash_sum", int64()),
+ field("hash_mean", float64()),
+ field("hash_product", int64()),
+ field("key", int64()),
+ }),
+ R"([
+ [4, 1.333333, 2, 1],
+ [4, 2, 3, 2],
+ [5, 2.5, 4, 3]
+ ])");
+ AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, VarianceAndStddev) {
+ auto batch = RecordBatchFromJSON(
+ schema({field("argument", int32()), field("key", int64())}), R"([
+ [1, 1],
+ [null, 1],
+ [0, 2],
+ [null, 3],
+ [4, null],
+ [3, 1],
+ [0, 2],
+ [-1, 2],
+ [1, null],
+ [null, 3]
+ ])");
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ },
+ {
+ batch->GetColumnByName("key"),
+ },
+ {
+ {"hash_variance", nullptr},
+ {"hash_stddev", nullptr},
+ }));
+
+ AssertDatumsApproxEqual(ArrayFromJSON(struct_({
+ field("hash_variance", float64()),
+ field("hash_stddev", float64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [1.0, 1.0, 1],
+ [0.22222222222222224, 0.4714045207910317, 2],
+ [null, null, 3],
+ [2.25, 1.5, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+
+ batch = RecordBatchFromJSON(
+ schema({field("argument", float64()), field("key", int64())}), R"([
+ [1.0, 1],
+ [null, 1],
+ [0.0, 2],
+ [null, 3],
+ [4.0, null],
+ [3.0, 1],
+ [0.0, 2],
+ [-1.0, 2],
+ [1.0, null],
+ [null, 3]
+ ])");
+
+ ASSERT_OK_AND_ASSIGN(aggregated_and_grouped, internal::GroupBy(
+ {
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ },
+ {
+ batch->GetColumnByName("key"),
+ },
+ {
+ {"hash_variance", nullptr},
+ {"hash_stddev", nullptr},
+ }));
+
+ AssertDatumsApproxEqual(ArrayFromJSON(struct_({
+ field("hash_variance", float64()),
+ field("hash_stddev", float64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [1.0, 1.0, 1],
+ [0.22222222222222224, 0.4714045207910317, 2],
+ [null, null, 3],
+ [2.25, 1.5, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+
+ // Test ddof
+ VarianceOptions variance_options(/*ddof=*/2);
+ ASSERT_OK_AND_ASSIGN(aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ },
+ {
+ batch->GetColumnByName("key"),
+ },
+ {
+ {"hash_variance", &variance_options},
+ {"hash_stddev", &variance_options},
+ }));
+
+ AssertDatumsApproxEqual(ArrayFromJSON(struct_({
+ field("hash_variance", float64()),
+ field("hash_stddev", float64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [null, null, 1],
+ [0.6666666666666667, 0.816496580927726, 2],
+ [null, null, 3],
+ [null, null, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+}
+
+TEST(GroupBy, TDigest) {
+ auto batch = RecordBatchFromJSON(
+ schema({field("argument", float64()), field("key", int64())}), R"([
+ [1, 1],
+ [null, 1],
+ [0, 2],
+ [null, 3],
+ [1, 4],
+ [4, null],
+ [3, 1],
+ [0, 2],
+ [-1, 2],
+ [1, null],
+ [NaN, 3],
+ [1, 4],
+ [1, 4],
+ [null, 4]
+ ])");
+
+ TDigestOptions options1(std::vector<double>{0.5, 0.9, 0.99});
+ TDigestOptions options2(std::vector<double>{0.5, 0.9, 0.99}, /*delta=*/50,
+ /*buffer_size=*/1024);
+ TDigestOptions keep_nulls(/*q=*/0.5, /*delta=*/100, /*buffer_size=*/500,
+ /*skip_nulls=*/false, /*min_count=*/0);
+ TDigestOptions min_count(/*q=*/0.5, /*delta=*/100, /*buffer_size=*/500,
+ /*skip_nulls=*/true, /*min_count=*/3);
+ TDigestOptions keep_nulls_min_count(/*q=*/0.5, /*delta=*/100, /*buffer_size=*/500,
+ /*skip_nulls=*/false, /*min_count=*/3);
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ },
+ {
+ batch->GetColumnByName("key"),
+ },
+ {
+ {"hash_tdigest", nullptr},
+ {"hash_tdigest", &options1},
+ {"hash_tdigest", &options2},
+ {"hash_tdigest", &keep_nulls},
+ {"hash_tdigest", &min_count},
+ {"hash_tdigest", &keep_nulls_min_count},
+ }));
+
+ AssertDatumsApproxEqual(
+ ArrayFromJSON(struct_({
+ field("hash_tdigest", fixed_size_list(float64(), 1)),
+ field("hash_tdigest", fixed_size_list(float64(), 3)),
+ field("hash_tdigest", fixed_size_list(float64(), 3)),
+ field("hash_tdigest", fixed_size_list(float64(), 1)),
+ field("hash_tdigest", fixed_size_list(float64(), 1)),
+ field("hash_tdigest", fixed_size_list(float64(), 1)),
+ field("key_0", int64()),
+ }),
+ R"([
+ [[1.0], [1.0, 3.0, 3.0], [1.0, 3.0, 3.0], [null], [null], [null], 1],
+ [[0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0], [0.0], [0.0], 2],
+ [[null], [null, null, null], [null, null, null], [null], [null], [null], 3],
+ [[1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [null], [1.0], [null], 4],
+ [[1.0], [1.0, 4.0, 4.0], [1.0, 4.0, 4.0], [1.0], [null], [null], null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+}
+
+TEST(GroupBy, ApproximateMedian) {
+ for (const auto& type : {float64(), int8()}) {
+ auto batch =
+ RecordBatchFromJSON(schema({field("argument", type), field("key", int64())}), R"([
+ [1, 1],
+ [null, 1],
+ [0, 2],
+ [null, 3],
+ [1, 4],
+ [4, null],
+ [3, 1],
+ [0, 2],
+ [-1, 2],
+ [1, null],
+ [null, 3],
+ [1, 4],
+ [1, 4],
+ [null, 4]
+ ])");
+
+ ScalarAggregateOptions options;
+ ScalarAggregateOptions keep_nulls(
+ /*skip_nulls=*/false, /*min_count=*/0);
+ ScalarAggregateOptions min_count(
+ /*skip_nulls=*/true, /*min_count=*/3);
+ ScalarAggregateOptions keep_nulls_min_count(
+ /*skip_nulls=*/false, /*min_count=*/3);
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ },
+ {
+ batch->GetColumnByName("key"),
+ },
+ {
+ {"hash_approximate_median", &options},
+ {"hash_approximate_median", &keep_nulls},
+ {"hash_approximate_median", &min_count},
+ {"hash_approximate_median", &keep_nulls_min_count},
+ }));
+
+ AssertDatumsApproxEqual(ArrayFromJSON(struct_({
+ field("hash_approximate_median", float64()),
+ field("hash_approximate_median", float64()),
+ field("hash_approximate_median", float64()),
+ field("hash_approximate_median", float64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [1.0, null, null, null, 1],
+ [0.0, 0.0, 0.0, 0.0, 2],
+ [null, null, null, null, 3],
+ [1.0, null, 1.0, null, 4],
+ [1.0, 1.0, null, null, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, StddevVarianceTDigestScalar) {
+ BatchesWithSchema input;
+ input.batches = {
+ ExecBatchFromJSON(
+ {ValueDescr::Scalar(int32()), ValueDescr::Scalar(float32()), int64()},
+ "[[1, 1.0, 1], [1, 1.0, 1], [1, 1.0, 2], [1, 1.0, 3]]"),
+ ExecBatchFromJSON(
+ {ValueDescr::Scalar(int32()), ValueDescr::Scalar(float32()), int64()},
+ "[[null, null, 1], [null, null, 1], [null, null, 2], [null, null, 3]]"),
+ ExecBatchFromJSON({int32(), float32(), int64()},
+ "[[2, 2.0, 1], [3, 3.0, 2], [4, 4.0, 3]]"),
+ };
+ input.schema = schema(
+ {field("argument", int32()), field("argument1", float32()), field("key", int64())});
+
+ for (bool use_threads : {false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+ ASSERT_OK_AND_ASSIGN(Datum actual,
+ GroupByUsingExecPlan(input, {"key"},
+ {"argument", "argument", "argument",
+ "argument1", "argument1", "argument1"},
+ {
+ {"hash_stddev", nullptr},
+ {"hash_variance", nullptr},
+ {"hash_tdigest", nullptr},
+ {"hash_stddev", nullptr},
+ {"hash_variance", nullptr},
+ {"hash_tdigest", nullptr},
+ },
+ use_threads, default_exec_context()));
+ Datum expected =
+ ArrayFromJSON(struct_({
+ field("hash_stddev", float64()),
+ field("hash_variance", float64()),
+ field("hash_tdigest", fixed_size_list(float64(), 1)),
+ field("hash_stddev", float64()),
+ field("hash_variance", float64()),
+ field("hash_tdigest", fixed_size_list(float64(), 1)),
+ field("key", int64()),
+ }),
+ R"([
+ [0.4714045, 0.222222, [1.0], 0.4714045, 0.222222, [1.0], 1],
+ [1.0, 1.0, [1.0], 1.0, 1.0, [1.0], 2],
+ [1.5, 2.25, [1.0], 1.5, 2.25, [1.0], 3]
+ ])");
+ AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, VarianceOptions) {
+ BatchesWithSchema input;
+ input.batches = {
+ ExecBatchFromJSON(
+ {ValueDescr::Scalar(int32()), ValueDescr::Scalar(float32()), int64()},
+ "[[1, 1.0, 1], [1, 1.0, 1], [1, 1.0, 2], [1, 1.0, 2], [1, 1.0, 3]]"),
+ ExecBatchFromJSON(
+ {ValueDescr::Scalar(int32()), ValueDescr::Scalar(float32()), int64()},
+ "[[1, 1.0, 4], [1, 1.0, 4]]"),
+ ExecBatchFromJSON(
+ {ValueDescr::Scalar(int32()), ValueDescr::Scalar(float32()), int64()},
+ "[[null, null, 1]]"),
+ ExecBatchFromJSON({int32(), float32(), int64()}, "[[2, 2.0, 1], [3, 3.0, 2]]"),
+ ExecBatchFromJSON({int32(), float32(), int64()}, "[[4, 4.0, 2], [2, 2.0, 4]]"),
+ ExecBatchFromJSON({int32(), float32(), int64()}, "[[null, null, 4]]"),
+ };
+ input.schema = schema(
+ {field("argument", int32()), field("argument1", float32()), field("key", int64())});
+
+ VarianceOptions keep_nulls(/*ddof=*/0, /*skip_nulls=*/false, /*min_count=*/0);
+ VarianceOptions min_count(/*ddof=*/0, /*skip_nulls=*/true, /*min_count=*/3);
+ VarianceOptions keep_nulls_min_count(/*ddof=*/0, /*skip_nulls=*/false, /*min_count=*/3);
+
+ for (bool use_threads : {false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+ ASSERT_OK_AND_ASSIGN(
+ Datum actual, GroupByUsingExecPlan(input, {"key"},
+ {
+ "argument",
+ "argument",
+ "argument",
+ "argument",
+ "argument",
+ "argument",
+ },
+ {
+ {"hash_stddev", &keep_nulls},
+ {"hash_stddev", &min_count},
+ {"hash_stddev", &keep_nulls_min_count},
+ {"hash_variance", &keep_nulls},
+ {"hash_variance", &min_count},
+ {"hash_variance", &keep_nulls_min_count},
+ },
+ use_threads, default_exec_context()));
+ Datum expected = ArrayFromJSON(struct_({
+ field("hash_stddev", float64()),
+ field("hash_stddev", float64()),
+ field("hash_stddev", float64()),
+ field("hash_variance", float64()),
+ field("hash_variance", float64()),
+ field("hash_variance", float64()),
+ field("key", int64()),
+ }),
+ R"([
+ [null, 0.471405, null, null, 0.222222, null, 1],
+ [1.29904, 1.29904, 1.29904, 1.6875, 1.6875, 1.6875, 2],
+ [0.0, null, null, 0.0, null, null, 3],
+ [null, 0.471405, null, null, 0.222222, null, 4]
+ ])");
+ ValidateOutput(expected);
+ AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
+
+ ASSERT_OK_AND_ASSIGN(
+ actual, GroupByUsingExecPlan(input, {"key"},
+ {
+ "argument1",
+ "argument1",
+ "argument1",
+ "argument1",
+ "argument1",
+ "argument1",
+ },
+ {
+ {"hash_stddev", &keep_nulls},
+ {"hash_stddev", &min_count},
+ {"hash_stddev", &keep_nulls_min_count},
+ {"hash_variance", &keep_nulls},
+ {"hash_variance", &min_count},
+ {"hash_variance", &keep_nulls_min_count},
+ },
+ use_threads, default_exec_context()));
+ expected = ArrayFromJSON(struct_({
+ field("hash_stddev", float64()),
+ field("hash_stddev", float64()),
+ field("hash_stddev", float64()),
+ field("hash_variance", float64()),
+ field("hash_variance", float64()),
+ field("hash_variance", float64()),
+ field("key", int64()),
+ }),
+ R"([
+ [null, 0.471405, null, null, 0.222222, null, 1],
+ [1.29904, 1.29904, 1.29904, 1.6875, 1.6875, 1.6875, 2],
+ [0.0, null, null, 0.0, null, null, 3],
+ [null, 0.471405, null, null, 0.222222, null, 4]
+ ])");
+ ValidateOutput(expected);
+ AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, MinMaxOnly) {
+ auto in_schema = schema({
+ field("argument", float64()),
+ field("argument1", null()),
+ field("argument2", boolean()),
+ field("key", int64()),
+ });
+ for (bool use_exec_plan : {false, true}) {
+ for (bool use_threads : {true, false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+ auto table = TableFromJSON(in_schema, {R"([
+ [1.0, null, true, 1],
+ [null, null, true, 1]
+])",
+ R"([
+ [0.0, null, false, 2],
+ [null, null, false, 3],
+ [4.0, null, null, null],
+ [3.25, null, true, 1],
+ [0.125, null, false, 2]
+])",
+ R"([
+ [-0.25, null, false, 2],
+ [0.75, null, true, null],
+ [null, null, true, 3]
+])"});
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ GroupByTest(
+ {
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument1"),
+ table->GetColumnByName("argument2"),
+ },
+ {table->GetColumnByName("key")},
+ {
+ {"hash_min_max", nullptr},
+ {"hash_min_max", nullptr},
+ {"hash_min_max", nullptr},
+ },
+ use_threads, use_exec_plan));
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(
+ ArrayFromJSON(struct_({
+ field("hash_min_max", struct_({
+ field("min", float64()),
+ field("max", float64()),
+ })),
+ field("hash_min_max", struct_({
+ field("min", null()),
+ field("max", null()),
+ })),
+ field("hash_min_max", struct_({
+ field("min", boolean()),
+ field("max", boolean()),
+ })),
+ field("key_0", int64()),
+ }),
+ R"([
+ [{"min": 1.0, "max": 3.25}, {"min": null, "max": null}, {"min": true, "max": true}, 1],
+ [{"min": -0.25, "max": 0.125}, {"min": null, "max": null}, {"min": false, "max": false}, 2],
+ [{"min": null, "max": null}, {"min": null, "max": null}, {"min": false, "max": true}, 3],
+ [{"min": 0.75, "max": 4.0}, {"min": null, "max": null}, {"min": true, "max": true}, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+ }
+}
+
+TEST(GroupBy, MinMaxTypes) {
+ std::vector<std::shared_ptr<DataType>> types;
+ types.insert(types.end(), NumericTypes().begin(), NumericTypes().end());
+ types.insert(types.end(), TemporalTypes().begin(), TemporalTypes().end());
+ types.push_back(month_interval());
+ for (const auto& ty : types) {
+ SCOPED_TRACE(ty->ToString());
+ auto in_schema = schema({field("argument0", ty), field("key", int64())});
+ auto table = TableFromJSON(in_schema, {R"([
+ [1, 1],
+ [null, 1]
+])",
+ R"([
+ [0, 2],
+ [null, 3],
+ [3, 4],
+ [5, 4],
+ [4, null],
+ [3, 1],
+ [0, 2]
+])",
+ R"([
+ [0, 2],
+ [1, null],
+ [null, 3]
+])"});
+
+ ASSERT_OK_AND_ASSIGN(
+ Datum aggregated_and_grouped,
+ GroupByTest({table->GetColumnByName("argument0")},
+ {table->GetColumnByName("key")}, {{"hash_min_max", nullptr}},
+ /*use_threads=*/true, /*use_exec_plan=*/true));
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(
+ ArrayFromJSON(
+ struct_({
+ field("hash_min_max", struct_({field("min", ty), field("max", ty)})),
+ field("key_0", int64()),
+ }),
+ R"([
+ [{"min": 1, "max": 3}, 1],
+ [{"min": 0, "max": 0}, 2],
+ [{"min": null, "max": null}, 3],
+ [{"min": 3, "max": 5}, 4],
+ [{"min": 1, "max": 4}, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, MinMaxDecimal) {
+ auto in_schema = schema({
+ field("argument0", decimal128(3, 2)),
+ field("argument1", decimal256(3, 2)),
+ field("key", int64()),
+ });
+ for (bool use_exec_plan : {false, true}) {
+ for (bool use_threads : {true, false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+ auto table = TableFromJSON(in_schema, {R"([
+ ["1.01", "1.01", 1],
+ [null, null, 1]
+ ])",
+ R"([
+ ["0.00", "0.00", 2],
+ [null, null, 3],
+ ["-3.25", "-3.25", 4],
+ ["-5.25", "-5.25", 4],
+ ["4.01", "4.01", null],
+ ["3.25", "3.25", 1],
+ ["0.12", "0.12", 2]
+ ])",
+ R"([
+ ["-0.25", "-0.25", 2],
+ ["0.75", "0.75", null],
+ [null, null, 3]
+ ])"});
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ GroupByTest(
+ {
+ table->GetColumnByName("argument0"),
+ table->GetColumnByName("argument1"),
+ },
+ {table->GetColumnByName("key")},
+ {
+ {"hash_min_max", nullptr},
+ {"hash_min_max", nullptr},
+ },
+ use_threads, use_exec_plan));
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(
+ ArrayFromJSON(struct_({
+ field("hash_min_max", struct_({
+ field("min", decimal128(3, 2)),
+ field("max", decimal128(3, 2)),
+ })),
+ field("hash_min_max", struct_({
+ field("min", decimal256(3, 2)),
+ field("max", decimal256(3, 2)),
+ })),
+ field("key_0", int64()),
+ }),
+ R"([
+ [{"min": "1.01", "max": "3.25"}, {"min": "1.01", "max": "3.25"}, 1],
+ [{"min": "-0.25", "max": "0.12"}, {"min": "-0.25", "max": "0.12"}, 2],
+ [{"min": null, "max": null}, {"min": null, "max": null}, 3],
+ [{"min": "-5.25", "max": "-3.25"}, {"min": "-5.25", "max": "-3.25"}, 4],
+ [{"min": "0.75", "max": "4.01"}, {"min": "0.75", "max": "4.01"}, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+ }
+}
+
+TEST(GroupBy, MinOrMax) {
+ auto table =
+ TableFromJSON(schema({field("argument", float64()), field("key", int64())}), {R"([
+ [1.0, 1],
+ [null, 1]
+])",
+ R"([
+ [0.0, 2],
+ [null, 3],
+ [4.0, null],
+ [3.25, 1],
+ [0.125, 2]
+])",
+ R"([
+ [-0.25, 2],
+ [0.75, null],
+ [null, 3]
+])",
+ R"([
+ [NaN, 4],
+ [null, 4],
+ [Inf, 4],
+ [-Inf, 4],
+ [0.0, 4]
+])"});
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ GroupByTest({table->GetColumnByName("argument"),
+ table->GetColumnByName("argument")},
+ {table->GetColumnByName("key")},
+ {
+ {"hash_min", nullptr},
+ {"hash_max", nullptr},
+ },
+ /*use_threads=*/true, /*use_exec_plan=*/true));
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_min", float64()),
+ field("hash_max", float64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [1.0, 3.25, 1],
+ [-0.25, 0.125, 2],
+ [null, null, 3],
+ [-Inf, Inf, 4],
+ [0.75, 4.0, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+}
+
+TEST(GroupBy, MinMaxScalar) {
+ BatchesWithSchema input;
+ input.batches = {
+ ExecBatchFromJSON({ValueDescr::Scalar(int32()), int64()},
+ "[[-1, 1], [-1, 1], [-1, 2], [-1, 3]]"),
+ ExecBatchFromJSON({ValueDescr::Scalar(int32()), int64()},
+ "[[null, 1], [null, 1], [null, 2], [null, 3]]"),
+ ExecBatchFromJSON({int32(), int64()}, "[[2, 1], [3, 2], [4, 3]]"),
+ };
+ input.schema = schema({field("argument", int32()), field("key", int64())});
+
+ for (bool use_threads : {true, false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+ ASSERT_OK_AND_ASSIGN(
+ Datum actual,
+ GroupByUsingExecPlan(input, {"key"}, {"argument", "argument", "argument"},
+ {{"hash_min_max", nullptr}}, use_threads,
+ default_exec_context()));
+ Datum expected =
+ ArrayFromJSON(struct_({
+ field("hash_min_max",
+ struct_({field("min", int32()), field("max", int32())})),
+ field("key", int64()),
+ }),
+ R"([
+ [{"min": -1, "max": 2}, 1],
+ [{"min": -1, "max": 3}, 2],
+ [{"min": -1, "max": 4}, 3]
+ ])");
+ AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, AnyAndAll) {
+ ScalarAggregateOptions options(/*skip_nulls=*/false);
+ for (bool use_threads : {true, false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+ auto table =
+ TableFromJSON(schema({field("argument", boolean()), field("key", int64())}), {R"([
+ [true, 1],
+ [null, 1]
+ ])",
+ R"([
+ [false, 2],
+ [null, 3],
+ [null, 4],
+ [false, 4],
+ [true, 5],
+ [false, null],
+ [true, 1],
+ [true, 2]
+ ])",
+ R"([
+ [false, 2],
+ [false, null],
+ [null, 3]
+ ])"});
+
+ ScalarAggregateOptions no_min(/*skip_nulls=*/true, /*min_count=*/0);
+ ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/3);
+ ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false, /*min_count=*/0);
+ ScalarAggregateOptions keep_nulls_min_count(/*skip_nulls=*/false, /*min_count=*/3);
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ },
+ {table->GetColumnByName("key")},
+ {
+ {"hash_any", &no_min},
+ {"hash_any", &min_count},
+ {"hash_any", &keep_nulls},
+ {"hash_any", &keep_nulls_min_count},
+ {"hash_all", &no_min},
+ {"hash_all", &min_count},
+ {"hash_all", &keep_nulls},
+ {"hash_all", &keep_nulls_min_count},
+ },
+ use_threads));
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ // Group 1: trues and nulls
+ // Group 2: trues and falses
+ // Group 3: nulls
+ // Group 4: falses and nulls
+ // Group 5: trues
+ // Group null: falses
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_any", boolean()),
+ field("hash_any", boolean()),
+ field("hash_any", boolean()),
+ field("hash_any", boolean()),
+ field("hash_all", boolean()),
+ field("hash_all", boolean()),
+ field("hash_all", boolean()),
+ field("hash_all", boolean()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [true, null, true, null, true, null, null, null, 1],
+ [true, true, true, true, false, false, false, false, 2],
+ [false, null, null, null, true, null, null, null, 3],
+ [false, null, null, null, false, null, false, null, 4],
+ [true, null, true, null, true, null, true, null, 5],
+ [false, null, false, null, false, null, false, null, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, AnyAllScalar) {
+ BatchesWithSchema input;
+ input.batches = {
+ ExecBatchFromJSON({ValueDescr::Scalar(boolean()), int64()},
+ "[[true, 1], [true, 1], [true, 2], [true, 3]]"),
+ ExecBatchFromJSON({ValueDescr::Scalar(boolean()), int64()},
+ "[[null, 1], [null, 1], [null, 2], [null, 3]]"),
+ ExecBatchFromJSON({boolean(), int64()}, "[[true, 1], [false, 2], [null, 3]]"),
+ };
+ input.schema = schema({field("argument", boolean()), field("key", int64())});
+
+ ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false, /*min_count=*/0);
+ for (bool use_threads : {true, false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+ ASSERT_OK_AND_ASSIGN(
+ Datum actual,
+ GroupByUsingExecPlan(input, {"key"},
+ {"argument", "argument", "argument", "argument"},
+ {
+ {"hash_any", nullptr},
+ {"hash_all", nullptr},
+ {"hash_any", &keep_nulls},
+ {"hash_all", &keep_nulls},
+ },
+ use_threads, default_exec_context()));
+ Datum expected = ArrayFromJSON(struct_({
+ field("hash_any", boolean()),
+ field("hash_all", boolean()),
+ field("hash_any", boolean()),
+ field("hash_all", boolean()),
+ field("key", int64()),
+ }),
+ R"([
+ [true, true, true, null, 1],
+ [true, false, true, false, 2],
+ [true, true, true, null, 3]
+ ])");
+ AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, CountDistinct) {
+ CountOptions all(CountOptions::ALL);
+ CountOptions only_valid(CountOptions::ONLY_VALID);
+ CountOptions only_null(CountOptions::ONLY_NULL);
+ for (bool use_threads : {true, false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+ auto table =
+ TableFromJSON(schema({field("argument", float64()), field("key", int64())}), {R"([
+ [1, 1],
+ [1, 1]
+])",
+ R"([
+ [0, 2],
+ [null, 3],
+ [null, 3]
+])",
+ R"([
+ [null, 4],
+ [null, 4]
+])",
+ R"([
+ [4, null],
+ [1, 3]
+])",
+ R"([
+ [0, 2],
+ [-1, 2]
+])",
+ R"([
+ [1, null],
+ [NaN, 3]
+ ])",
+ R"([
+ [2, null],
+ [3, null]
+ ])"});
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_count_distinct", &all},
+ {"hash_count_distinct", &only_valid},
+ {"hash_count_distinct", &only_null},
+ },
+ use_threads));
+ SortBy({"key_0"}, &aggregated_and_grouped);
+ ValidateOutput(aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_count_distinct", int64()),
+ field("hash_count_distinct", int64()),
+ field("hash_count_distinct", int64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [1, 1, 0, 1],
+ [2, 2, 0, 2],
+ [3, 2, 1, 3],
+ [1, 0, 1, 4],
+ [4, 4, 0, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+
+ table =
+ TableFromJSON(schema({field("argument", utf8()), field("key", int64())}), {R"([
+ ["foo", 1],
+ ["foo", 1]
+])",
+ R"([
+ ["bar", 2],
+ [null, 3],
+ [null, 3]
+])",
+ R"([
+ [null, 4],
+ [null, 4]
+])",
+ R"([
+ ["baz", null],
+ ["foo", 3]
+])",
+ R"([
+ ["bar", 2],
+ ["spam", 2]
+])",
+ R"([
+ ["eggs", null],
+ ["ham", 3]
+ ])",
+ R"([
+ ["a", null],
+ ["b", null]
+ ])"});
+
+ ASSERT_OK_AND_ASSIGN(aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_count_distinct", &all},
+ {"hash_count_distinct", &only_valid},
+ {"hash_count_distinct", &only_null},
+ },
+ use_threads));
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_count_distinct", int64()),
+ field("hash_count_distinct", int64()),
+ field("hash_count_distinct", int64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [1, 1, 0, 1],
+ [2, 2, 0, 2],
+ [3, 2, 1, 3],
+ [1, 0, 1, 4],
+ [4, 4, 0, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+
+ table =
+ TableFromJSON(schema({field("argument", utf8()), field("key", int64())}), {
+ R"([
+ ["foo", 1],
+ ["foo", 1],
+ ["bar", 2],
+ ["bar", 2],
+ ["spam", 2]
+])",
+ });
+
+ ASSERT_OK_AND_ASSIGN(aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_count_distinct", &all},
+ {"hash_count_distinct", &only_valid},
+ {"hash_count_distinct", &only_null},
+ },
+ use_threads));
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_count_distinct", int64()),
+ field("hash_count_distinct", int64()),
+ field("hash_count_distinct", int64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [1, 1, 0, 1],
+ [2, 2, 0, 2]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, Distinct) {
+ CountOptions all(CountOptions::ALL);
+ CountOptions only_valid(CountOptions::ONLY_VALID);
+ CountOptions only_null(CountOptions::ONLY_NULL);
+ for (bool use_threads : {false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+ auto table =
+ TableFromJSON(schema({field("argument", utf8()), field("key", int64())}), {R"([
+ ["foo", 1],
+ ["foo", 1]
+])",
+ R"([
+ ["bar", 2],
+ [null, 3],
+ [null, 3]
+])",
+ R"([
+ [null, 4],
+ [null, 4]
+])",
+ R"([
+ ["baz", null],
+ ["foo", 3]
+])",
+ R"([
+ ["bar", 2],
+ ["spam", 2]
+])",
+ R"([
+ ["eggs", null],
+ ["ham", 3]
+ ])",
+ R"([
+ ["a", null],
+ ["b", null]
+ ])"});
+
+ ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_distinct", &all},
+ {"hash_distinct", &only_valid},
+ {"hash_distinct", &only_null},
+ },
+ use_threads));
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ // Order of sub-arrays is not stable
+ auto sort = [](const Array& arr) -> std::shared_ptr<Array> {
+ EXPECT_OK_AND_ASSIGN(auto indices, SortIndices(arr));
+ EXPECT_OK_AND_ASSIGN(auto sorted, Take(arr, indices));
+ return sorted.make_array();
+ };
+
+ auto struct_arr = aggregated_and_grouped.array_as<StructArray>();
+
+ auto all_arr = checked_pointer_cast<ListArray>(struct_arr->field(0));
+ AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["foo"])"), sort(*all_arr->value_slice(0)),
+ /*verbose=*/true);
+ AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["bar", "spam"])"),
+ sort(*all_arr->value_slice(1)), /*verbose=*/true);
+ AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["foo", "ham", null])"),
+ sort(*all_arr->value_slice(2)), /*verbose=*/true);
+ AssertDatumsEqual(ArrayFromJSON(utf8(), R"([null])"), sort(*all_arr->value_slice(3)),
+ /*verbose=*/true);
+ AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["a", "b", "baz", "eggs"])"),
+ sort(*all_arr->value_slice(4)), /*verbose=*/true);
+
+ auto valid_arr = checked_pointer_cast<ListArray>(struct_arr->field(1));
+ AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["foo"])"),
+ sort(*valid_arr->value_slice(0)), /*verbose=*/true);
+ AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["bar", "spam"])"),
+ sort(*valid_arr->value_slice(1)), /*verbose=*/true);
+ AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["foo", "ham"])"),
+ sort(*valid_arr->value_slice(2)), /*verbose=*/true);
+ AssertDatumsEqual(ArrayFromJSON(utf8(), R"([])"), sort(*valid_arr->value_slice(3)),
+ /*verbose=*/true);
+ AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["a", "b", "baz", "eggs"])"),
+ sort(*valid_arr->value_slice(4)), /*verbose=*/true);
+
+ auto null_arr = checked_pointer_cast<ListArray>(struct_arr->field(2));
+ AssertDatumsEqual(ArrayFromJSON(utf8(), R"([])"), sort(*null_arr->value_slice(0)),
+ /*verbose=*/true);
+ AssertDatumsEqual(ArrayFromJSON(utf8(), R"([])"), sort(*null_arr->value_slice(1)),
+ /*verbose=*/true);
+ AssertDatumsEqual(ArrayFromJSON(utf8(), R"([null])"), sort(*null_arr->value_slice(2)),
+ /*verbose=*/true);
+ AssertDatumsEqual(ArrayFromJSON(utf8(), R"([null])"), sort(*null_arr->value_slice(3)),
+ /*verbose=*/true);
+ AssertDatumsEqual(ArrayFromJSON(utf8(), R"([])"), sort(*null_arr->value_slice(4)),
+ /*verbose=*/true);
+
+ table =
+ TableFromJSON(schema({field("argument", utf8()), field("key", int64())}), {
+ R"([
+ ["foo", 1],
+ ["foo", 1],
+ ["bar", 2],
+ ["bar", 2]
+])",
+ });
+ ASSERT_OK_AND_ASSIGN(aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_distinct", &all},
+ {"hash_distinct", &only_valid},
+ {"hash_distinct", &only_null},
+ },
+ use_threads));
+ ValidateOutput(aggregated_and_grouped);
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(
+ ArrayFromJSON(struct_({
+ field("hash_distinct", list(utf8())),
+ field("hash_distinct", list(utf8())),
+ field("hash_distinct", list(utf8())),
+ field("key_0", int64()),
+ }),
+ R"([[["foo"], ["foo"], [], 1], [["bar"], ["bar"], [], 2]])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, CountAndSum) {
+ auto batch = RecordBatchFromJSON(
+ schema({field("argument", float64()), field("key", int64())}), R"([
+ [1.0, 1],
+ [null, 1],
+ [0.0, 2],
+ [null, 3],
+ [4.0, null],
+ [3.25, 1],
+ [0.125, 2],
+ [-0.25, 2],
+ [0.75, null],
+ [null, 3]
+ ])");
+
+ CountOptions count_options;
+ CountOptions count_nulls(CountOptions::ONLY_NULL);
+ CountOptions count_all(CountOptions::ALL);
+ ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/3);
+ ASSERT_OK_AND_ASSIGN(
+ Datum aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ // NB: passing an argument twice or also using it as a key is legal
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("key"),
+ },
+ {
+ batch->GetColumnByName("key"),
+ },
+ {
+ {"hash_count", &count_options},
+ {"hash_count", &count_nulls},
+ {"hash_count", &count_all},
+ {"hash_sum", nullptr},
+ {"hash_sum", &min_count},
+ {"hash_sum", nullptr},
+ }));
+
+ AssertDatumsEqual(
+ ArrayFromJSON(struct_({
+ field("hash_count", int64()),
+ field("hash_count", int64()),
+ field("hash_count", int64()),
+ // NB: summing a float32 array results in float64 sums
+ field("hash_sum", float64()),
+ field("hash_sum", float64()),
+ field("hash_sum", int64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [2, 1, 3, 4.25, null, 3, 1],
+ [3, 0, 3, -0.125, -0.125, 6, 2],
+ [0, 2, 2, null, null, 6, 3],
+ [2, 0, 2, 4.75, null, null, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+}
+
+TEST(GroupBy, Product) {
+ auto batch = RecordBatchFromJSON(
+ schema({field("argument", float64()), field("key", int64())}), R"([
+ [-1.0, 1],
+ [null, 1],
+ [0.0, 2],
+ [null, 3],
+ [4.0, null],
+ [3.25, 1],
+ [0.125, 2],
+ [-0.25, 2],
+ [0.75, null],
+ [null, 3]
+ ])");
+
+ ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/3);
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("key"),
+ batch->GetColumnByName("argument"),
+ },
+ {
+ batch->GetColumnByName("key"),
+ },
+ {
+ {"hash_product", nullptr},
+ {"hash_product", nullptr},
+ {"hash_product", &min_count},
+ }));
+
+ AssertDatumsApproxEqual(ArrayFromJSON(struct_({
+ field("hash_product", float64()),
+ field("hash_product", int64()),
+ field("hash_product", float64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [-3.25, 1, null, 1],
+ [0.0, 8, 0.0, 2],
+ [null, 9, null, 3],
+ [3.0, null, null, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+
+ // Overflow should wrap around
+ batch = RecordBatchFromJSON(schema({field("argument", int64()), field("key", int64())}),
+ R"([
+ [8589934592, 1],
+ [8589934593, 1]
+ ])");
+
+ ASSERT_OK_AND_ASSIGN(aggregated_and_grouped, internal::GroupBy(
+ {
+ batch->GetColumnByName("argument"),
+ },
+ {
+ batch->GetColumnByName("key"),
+ },
+ {
+ {"hash_product", nullptr},
+ }));
+
+ AssertDatumsApproxEqual(ArrayFromJSON(struct_({
+ field("hash_product", int64()),
+ field("key_0", int64()),
+ }),
+ R"([[8589934592, 1]])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+}
+
+TEST(GroupBy, SumMeanProductKeepNulls) {
+ auto batch = RecordBatchFromJSON(
+ schema({field("argument", float64()), field("key", int64())}), R"([
+ [-1.0, 1],
+ [null, 1],
+ [0.0, 2],
+ [null, 3],
+ [4.0, null],
+ [3.25, 1],
+ [0.125, 2],
+ [-0.25, 2],
+ [0.75, null],
+ [null, 3]
+ ])");
+
+ ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false);
+ ScalarAggregateOptions min_count(/*skip_nulls=*/false, /*min_count=*/3);
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ },
+ {
+ batch->GetColumnByName("key"),
+ },
+ {
+ {"hash_sum", &keep_nulls},
+ {"hash_sum", &min_count},
+ {"hash_mean", &keep_nulls},
+ {"hash_mean", &min_count},
+ {"hash_product", &keep_nulls},
+ {"hash_product", &min_count},
+ }));
+
+ AssertDatumsApproxEqual(ArrayFromJSON(struct_({
+ field("hash_sum", float64()),
+ field("hash_sum", float64()),
+ field("hash_mean", float64()),
+ field("hash_mean", float64()),
+ field("hash_product", float64()),
+ field("hash_product", float64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [null, null, null, null, null, null, 1],
+ [-0.125, -0.125, -0.0416667, -0.0416667, 0.0, 0.0, 2],
+ [null, null, null, null, null, null, 3],
+ [4.75, null, 2.375, null, 3.0, null, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+}
+
+TEST(GroupBy, SumOnlyStringAndDictKeys) {
+ for (auto key_type : {utf8(), dictionary(int32(), utf8())}) {
+ SCOPED_TRACE("key type: " + key_type->ToString());
+
+ auto batch = RecordBatchFromJSON(
+ schema({field("argument", float64()), field("key", key_type)}), R"([
+ [1.0, "alfa"],
+ [null, "alfa"],
+ [0.0, "beta"],
+ [null, "gama"],
+ [4.0, null ],
+ [3.25, "alfa"],
+ [0.125, "beta"],
+ [-0.25, "beta"],
+ [0.75, null ],
+ [null, "gama"]
+ ])");
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ internal::GroupBy({batch->GetColumnByName("argument")},
+ {batch->GetColumnByName("key")},
+ {
+ {"hash_sum", nullptr},
+ }));
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_sum", float64()),
+ field("key_0", key_type),
+ }),
+ R"([
+ [4.25, "alfa"],
+ [-0.125, "beta"],
+ [null, "gama"],
+ [4.75, null ]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+}
+
+TEST(GroupBy, ConcreteCaseWithValidateGroupBy) {
+ auto batch = RecordBatchFromJSON(
+ schema({field("argument", float64()), field("key", utf8())}), R"([
+ [1.0, "alfa"],
+ [null, "alfa"],
+ [0.0, "beta"],
+ [null, "gama"],
+ [4.0, null ],
+ [3.25, "alfa"],
+ [0.125, "beta"],
+ [-0.25, "beta"],
+ [0.75, null ],
+ [null, "gama"]
+ ])");
+
+ ScalarAggregateOptions keepna{false, 1};
+ CountOptions nulls(CountOptions::ONLY_NULL);
+ CountOptions non_null(CountOptions::ONLY_VALID);
+
+ using internal::Aggregate;
+ for (auto agg : {
+ Aggregate{"hash_sum", nullptr},
+ Aggregate{"hash_count", &non_null},
+ Aggregate{"hash_count", &nulls},
+ Aggregate{"hash_min_max", nullptr},
+ Aggregate{"hash_min_max", &keepna},
+ }) {
+ SCOPED_TRACE(agg.function);
+ ValidateGroupBy({agg}, {batch->GetColumnByName("argument")},
+ {batch->GetColumnByName("key")});
+ }
+}
+
+// Count nulls/non_nulls from record batch with no nulls
+TEST(GroupBy, CountNull) {
+ auto batch = RecordBatchFromJSON(
+ schema({field("argument", float64()), field("key", utf8())}), R"([
+ [1.0, "alfa"],
+ [2.0, "beta"],
+ [3.0, "gama"]
+ ])");
+
+ CountOptions keepna{CountOptions::ONLY_NULL}, skipna{CountOptions::ONLY_VALID};
+
+ using internal::Aggregate;
+ for (auto agg : {
+ Aggregate{"hash_count", &keepna},
+ Aggregate{"hash_count", &skipna},
+ }) {
+ SCOPED_TRACE(agg.function);
+ ValidateGroupBy({agg}, {batch->GetColumnByName("argument")},
+ {batch->GetColumnByName("key")});
+ }
+}
+
+TEST(GroupBy, RandomArraySum) {
+ ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
+ for (int64_t length : {1 << 10, 1 << 12, 1 << 15}) {
+ for (auto null_probability : {0.0, 0.01, 0.5, 1.0}) {
+ auto batch = random::GenerateBatch(
+ {
+ field("argument", float32(),
+ key_value_metadata(
+ {{"null_probability", std::to_string(null_probability)}})),
+ field("key", int64(), key_value_metadata({{"min", "0"}, {"max", "100"}})),
+ },
+ length, 0xDEADBEEF);
+
+ ValidateGroupBy(
+ {
+ {"hash_sum", &options},
+ },
+ {batch->GetColumnByName("argument")}, {batch->GetColumnByName("key")});
+ }
+ }
+}
+
+TEST(GroupBy, WithChunkedArray) {
+ auto table =
+ TableFromJSON(schema({field("argument", float64()), field("key", int64())}),
+ {R"([{"argument": 1.0, "key": 1},
+ {"argument": null, "key": 1}
+ ])",
+ R"([{"argument": 0.0, "key": 2},
+ {"argument": null, "key": 3},
+ {"argument": 4.0, "key": null},
+ {"argument": 3.25, "key": 1},
+ {"argument": 0.125, "key": 2},
+ {"argument": -0.25, "key": 2},
+ {"argument": 0.75, "key": null},
+ {"argument": null, "key": 3}
+ ])"});
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_count", nullptr},
+ {"hash_sum", nullptr},
+ {"hash_min_max", nullptr},
+ }));
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_count", int64()),
+ field("hash_sum", float64()),
+ field("hash_min_max", struct_({
+ field("min", float64()),
+ field("max", float64()),
+ })),
+ field("key_0", int64()),
+ }),
+ R"([
+ [2, 4.25, {"min": 1.0, "max": 3.25}, 1],
+ [3, -0.125, {"min": -0.25, "max": 0.125}, 2],
+ [0, null, {"min": null, "max": null}, 3],
+ [2, 4.75, {"min": 0.75, "max": 4.0}, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+}
+
+TEST(GroupBy, MinMaxWithNewGroupsInChunkedArray) {
+ auto table = TableFromJSON(
+ schema({field("argument", int64()), field("key", int64())}),
+ {R"([{"argument": 1, "key": 0}])", R"([{"argument": 0, "key": 1}])"});
+ ScalarAggregateOptions count_options;
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ table->GetColumnByName("argument"),
+ },
+ {
+ table->GetColumnByName("key"),
+ },
+ {
+ {"hash_min_max", nullptr},
+ }));
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_min_max", struct_({
+ field("min", int64()),
+ field("max", int64()),
+ })),
+ field("key_0", int64()),
+ }),
+ R"([
+ [{"min": 1, "max": 1}, 0],
+ [{"min": 0, "max": 0}, 1]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+}
+
+TEST(GroupBy, SmallChunkSizeSumOnly) {
+ auto batch = RecordBatchFromJSON(
+ schema({field("argument", float64()), field("key", int64())}), R"([
+ [1.0, 1],
+ [null, 1],
+ [0.0, 2],
+ [null, 3],
+ [4.0, null],
+ [3.25, 1],
+ [0.125, 2],
+ [-0.25, 2],
+ [0.75, null],
+ [null, 3]
+ ])");
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ internal::GroupBy({batch->GetColumnByName("argument")},
+ {batch->GetColumnByName("key")},
+ {
+ {"hash_sum", nullptr},
+ },
+ small_chunksize_context()));
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_sum", float64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [4.25, 1],
+ [-0.125, 2],
+ [null, 3],
+ [4.75, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/row_encoder.cc b/src/arrow/cpp/src/arrow/compute/kernels/row_encoder.cc
new file mode 100644
index 000000000..840e4634f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/row_encoder.cc
@@ -0,0 +1,360 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/kernels/row_encoder.h"
+
+#include "arrow/util/bitmap_writer.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+
+namespace arrow {
+
+using internal::FirstTimeBitmapWriter;
+
+namespace compute {
+namespace internal {
+
+// extract the null bitmap from the leading nullity bytes of encoded keys
+Status KeyEncoder::DecodeNulls(MemoryPool* pool, int32_t length, uint8_t** encoded_bytes,
+ std::shared_ptr<Buffer>* null_bitmap,
+ int32_t* null_count) {
+ // first count nulls to determine if a null bitmap is necessary
+ *null_count = 0;
+ for (int32_t i = 0; i < length; ++i) {
+ *null_count += (encoded_bytes[i][0] == kNullByte);
+ }
+
+ if (*null_count > 0) {
+ ARROW_ASSIGN_OR_RAISE(*null_bitmap, AllocateBitmap(length, pool));
+ uint8_t* validity = (*null_bitmap)->mutable_data();
+
+ FirstTimeBitmapWriter writer(validity, 0, length);
+ for (int32_t i = 0; i < length; ++i) {
+ if (encoded_bytes[i][0] == kValidByte) {
+ writer.Set();
+ } else {
+ writer.Clear();
+ }
+ writer.Next();
+ encoded_bytes[i] += 1;
+ }
+ writer.Finish();
+ } else {
+ for (int32_t i = 0; i < length; ++i) {
+ encoded_bytes[i] += 1;
+ }
+ }
+ return Status ::OK();
+}
+
+void BooleanKeyEncoder::AddLength(const Datum& data, int64_t batch_length,
+ int32_t* lengths) {
+ for (int64_t i = 0; i < batch_length; ++i) {
+ lengths[i] += kByteWidth + kExtraByteForNull;
+ }
+}
+
+void BooleanKeyEncoder::AddLengthNull(int32_t* length) {
+ *length += kByteWidth + kExtraByteForNull;
+}
+
+Status BooleanKeyEncoder::Encode(const Datum& data, int64_t batch_length,
+ uint8_t** encoded_bytes) {
+ if (data.is_array()) {
+ VisitArrayDataInline<BooleanType>(
+ *data.array(),
+ [&](bool value) {
+ auto& encoded_ptr = *encoded_bytes++;
+ *encoded_ptr++ = kValidByte;
+ *encoded_ptr++ = value;
+ },
+ [&] {
+ auto& encoded_ptr = *encoded_bytes++;
+ *encoded_ptr++ = kNullByte;
+ *encoded_ptr++ = 0;
+ });
+ } else {
+ const auto& scalar = data.scalar_as<BooleanScalar>();
+ bool value = scalar.is_valid && scalar.value;
+ for (int64_t i = 0; i < batch_length; i++) {
+ auto& encoded_ptr = *encoded_bytes++;
+ *encoded_ptr++ = kValidByte;
+ *encoded_ptr++ = value;
+ }
+ }
+ return Status::OK();
+}
+
+void BooleanKeyEncoder::EncodeNull(uint8_t** encoded_bytes) {
+ auto& encoded_ptr = *encoded_bytes;
+ *encoded_ptr++ = kNullByte;
+ *encoded_ptr++ = 0;
+}
+
+Result<std::shared_ptr<ArrayData>> BooleanKeyEncoder::Decode(uint8_t** encoded_bytes,
+ int32_t length,
+ MemoryPool* pool) {
+ std::shared_ptr<Buffer> null_buf;
+ int32_t null_count;
+ RETURN_NOT_OK(DecodeNulls(pool, length, encoded_bytes, &null_buf, &null_count));
+
+ ARROW_ASSIGN_OR_RAISE(auto key_buf, AllocateBitmap(length, pool));
+
+ uint8_t* raw_output = key_buf->mutable_data();
+ memset(raw_output, 0, BitUtil::BytesForBits(length));
+ for (int32_t i = 0; i < length; ++i) {
+ auto& encoded_ptr = encoded_bytes[i];
+ BitUtil::SetBitTo(raw_output, i, encoded_ptr[0] != 0);
+ encoded_ptr += 1;
+ }
+
+ return ArrayData::Make(boolean(), length, {std::move(null_buf), std::move(key_buf)},
+ null_count);
+}
+
+void FixedWidthKeyEncoder::AddLength(const Datum& data, int64_t batch_length,
+ int32_t* lengths) {
+ for (int64_t i = 0; i < batch_length; ++i) {
+ lengths[i] += byte_width_ + kExtraByteForNull;
+ }
+}
+
+void FixedWidthKeyEncoder::AddLengthNull(int32_t* length) {
+ *length += byte_width_ + kExtraByteForNull;
+}
+
+Status FixedWidthKeyEncoder::Encode(const Datum& data, int64_t batch_length,
+ uint8_t** encoded_bytes) {
+ if (data.is_array()) {
+ const auto& arr = *data.array();
+ ArrayData viewed(fixed_size_binary(byte_width_), arr.length, arr.buffers,
+ arr.null_count, arr.offset);
+
+ VisitArrayDataInline<FixedSizeBinaryType>(
+ viewed,
+ [&](util::string_view bytes) {
+ auto& encoded_ptr = *encoded_bytes++;
+ *encoded_ptr++ = kValidByte;
+ memcpy(encoded_ptr, bytes.data(), byte_width_);
+ encoded_ptr += byte_width_;
+ },
+ [&] {
+ auto& encoded_ptr = *encoded_bytes++;
+ *encoded_ptr++ = kNullByte;
+ memset(encoded_ptr, 0, byte_width_);
+ encoded_ptr += byte_width_;
+ });
+ } else {
+ const auto& scalar = data.scalar_as<arrow::internal::PrimitiveScalarBase>();
+ if (scalar.is_valid) {
+ const util::string_view data = scalar.view();
+ DCHECK_EQ(data.size(), static_cast<size_t>(byte_width_));
+ for (int64_t i = 0; i < batch_length; i++) {
+ auto& encoded_ptr = *encoded_bytes++;
+ *encoded_ptr++ = kValidByte;
+ memcpy(encoded_ptr, data.data(), data.size());
+ encoded_ptr += byte_width_;
+ }
+ } else {
+ for (int64_t i = 0; i < batch_length; i++) {
+ auto& encoded_ptr = *encoded_bytes++;
+ *encoded_ptr++ = kNullByte;
+ memset(encoded_ptr, 0, byte_width_);
+ encoded_ptr += byte_width_;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+void FixedWidthKeyEncoder::EncodeNull(uint8_t** encoded_bytes) {
+ auto& encoded_ptr = *encoded_bytes;
+ *encoded_ptr++ = kNullByte;
+ memset(encoded_ptr, 0, byte_width_);
+ encoded_ptr += byte_width_;
+}
+
+Result<std::shared_ptr<ArrayData>> FixedWidthKeyEncoder::Decode(uint8_t** encoded_bytes,
+ int32_t length,
+ MemoryPool* pool) {
+ std::shared_ptr<Buffer> null_buf;
+ int32_t null_count;
+ RETURN_NOT_OK(DecodeNulls(pool, length, encoded_bytes, &null_buf, &null_count));
+
+ ARROW_ASSIGN_OR_RAISE(auto key_buf, AllocateBuffer(length * byte_width_, pool));
+
+ uint8_t* raw_output = key_buf->mutable_data();
+ for (int32_t i = 0; i < length; ++i) {
+ auto& encoded_ptr = encoded_bytes[i];
+ std::memcpy(raw_output, encoded_ptr, byte_width_);
+ encoded_ptr += byte_width_;
+ raw_output += byte_width_;
+ }
+
+ return ArrayData::Make(type_, length, {std::move(null_buf), std::move(key_buf)},
+ null_count);
+}
+
+Status DictionaryKeyEncoder::Encode(const Datum& data, int64_t batch_length,
+ uint8_t** encoded_bytes) {
+ auto dict = data.is_array() ? MakeArray(data.array()->dictionary)
+ : data.scalar_as<DictionaryScalar>().value.dictionary;
+ if (dictionary_) {
+ if (!dictionary_->Equals(dict)) {
+ // TODO(bkietz) unify if necessary. For now, just error if any batch's dictionary
+ // differs from the first we saw for this key
+ return Status::NotImplemented("Unifying differing dictionaries");
+ }
+ } else {
+ dictionary_ = std::move(dict);
+ }
+ if (data.is_array()) {
+ return FixedWidthKeyEncoder::Encode(data, batch_length, encoded_bytes);
+ }
+ return FixedWidthKeyEncoder::Encode(data.scalar_as<DictionaryScalar>().value.index,
+ batch_length, encoded_bytes);
+}
+
+Result<std::shared_ptr<ArrayData>> DictionaryKeyEncoder::Decode(uint8_t** encoded_bytes,
+ int32_t length,
+ MemoryPool* pool) {
+ ARROW_ASSIGN_OR_RAISE(auto data,
+ FixedWidthKeyEncoder::Decode(encoded_bytes, length, pool));
+
+ if (dictionary_) {
+ data->dictionary = dictionary_->data();
+ } else {
+ ARROW_DCHECK(type_->id() == Type::DICTIONARY);
+ const auto& dict_type = checked_cast<const DictionaryType&>(*type_);
+ ARROW_ASSIGN_OR_RAISE(auto dict, MakeArrayOfNull(dict_type.value_type(), 0));
+ data->dictionary = dict->data();
+ }
+
+ data->type = type_;
+ return data;
+}
+
+void RowEncoder::Init(const std::vector<ValueDescr>& column_types, ExecContext* ctx) {
+ ctx_ = ctx;
+ encoders_.resize(column_types.size());
+
+ for (size_t i = 0; i < column_types.size(); ++i) {
+ const auto& column_type = column_types[i].type;
+
+ if (column_type->id() == Type::BOOL) {
+ encoders_[i] = std::make_shared<BooleanKeyEncoder>();
+ continue;
+ }
+
+ if (column_type->id() == Type::DICTIONARY) {
+ encoders_[i] =
+ std::make_shared<DictionaryKeyEncoder>(column_type, ctx->memory_pool());
+ continue;
+ }
+
+ if (is_fixed_width(column_type->id())) {
+ encoders_[i] = std::make_shared<FixedWidthKeyEncoder>(column_type);
+ continue;
+ }
+
+ if (is_binary_like(column_type->id())) {
+ encoders_[i] = std::make_shared<VarLengthKeyEncoder<BinaryType>>(column_type);
+ continue;
+ }
+
+ if (is_large_binary_like(column_type->id())) {
+ encoders_[i] = std::make_shared<VarLengthKeyEncoder<LargeBinaryType>>(column_type);
+ continue;
+ }
+
+ // We should not get here
+ ARROW_DCHECK(false);
+ }
+
+ int32_t total_length = 0;
+ for (size_t i = 0; i < column_types.size(); ++i) {
+ encoders_[i]->AddLengthNull(&total_length);
+ }
+ encoded_nulls_.resize(total_length);
+ uint8_t* buf_ptr = encoded_nulls_.data();
+ for (size_t i = 0; i < column_types.size(); ++i) {
+ encoders_[i]->EncodeNull(&buf_ptr);
+ }
+}
+
+void RowEncoder::Clear() {
+ offsets_.clear();
+ bytes_.clear();
+}
+
+Status RowEncoder::EncodeAndAppend(const ExecBatch& batch) {
+ if (offsets_.empty()) {
+ offsets_.resize(1);
+ offsets_[0] = 0;
+ }
+ size_t length_before = offsets_.size() - 1;
+ offsets_.resize(length_before + batch.length + 1);
+ for (int64_t i = 0; i < batch.length; ++i) {
+ offsets_[length_before + 1 + i] = 0;
+ }
+
+ for (int i = 0; i < batch.num_values(); ++i) {
+ encoders_[i]->AddLength(batch[i], batch.length, offsets_.data() + length_before + 1);
+ }
+
+ int32_t total_length = offsets_[length_before];
+ for (int64_t i = 0; i < batch.length; ++i) {
+ total_length += offsets_[length_before + 1 + i];
+ offsets_[length_before + 1 + i] = total_length;
+ }
+
+ bytes_.resize(total_length);
+ std::vector<uint8_t*> buf_ptrs(batch.length);
+ for (int64_t i = 0; i < batch.length; ++i) {
+ buf_ptrs[i] = bytes_.data() + offsets_[length_before + i];
+ }
+
+ for (int i = 0; i < batch.num_values(); ++i) {
+ RETURN_NOT_OK(encoders_[i]->Encode(batch[i], batch.length, buf_ptrs.data()));
+ }
+
+ return Status::OK();
+}
+
+Result<ExecBatch> RowEncoder::Decode(int64_t num_rows, const int32_t* row_ids) {
+ ExecBatch out({}, num_rows);
+
+ std::vector<uint8_t*> buf_ptrs(num_rows);
+ for (int64_t i = 0; i < num_rows; ++i) {
+ buf_ptrs[i] = (row_ids[i] == kRowIdForNulls()) ? encoded_nulls_.data()
+ : bytes_.data() + offsets_[row_ids[i]];
+ }
+
+ out.values.resize(encoders_.size());
+ for (size_t i = 0; i < encoders_.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(
+ out.values[i],
+ encoders_[i]->Decode(buf_ptrs.data(), static_cast<int32_t>(num_rows),
+ ctx_->memory_pool()));
+ }
+
+ return out;
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/row_encoder.h b/src/arrow/cpp/src/arrow/compute/kernels/row_encoder.h
new file mode 100644
index 000000000..40509f2df
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/row_encoder.h
@@ -0,0 +1,267 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+namespace internal {
+
+struct KeyEncoder {
+ // the first byte of an encoded key is used to indicate nullity
+ static constexpr bool kExtraByteForNull = true;
+
+ static constexpr uint8_t kNullByte = 1;
+ static constexpr uint8_t kValidByte = 0;
+
+ virtual ~KeyEncoder() = default;
+
+ virtual void AddLength(const Datum&, int64_t batch_length, int32_t* lengths) = 0;
+
+ virtual void AddLengthNull(int32_t* length) = 0;
+
+ virtual Status Encode(const Datum&, int64_t batch_length, uint8_t** encoded_bytes) = 0;
+
+ virtual void EncodeNull(uint8_t** encoded_bytes) = 0;
+
+ virtual Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes,
+ int32_t length, MemoryPool*) = 0;
+
+ // extract the null bitmap from the leading nullity bytes of encoded keys
+ static Status DecodeNulls(MemoryPool* pool, int32_t length, uint8_t** encoded_bytes,
+ std::shared_ptr<Buffer>* null_bitmap, int32_t* null_count);
+
+ static bool IsNull(const uint8_t* encoded_bytes) {
+ return encoded_bytes[0] == kNullByte;
+ }
+};
+
+struct BooleanKeyEncoder : KeyEncoder {
+ static constexpr int kByteWidth = 1;
+
+ void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override;
+
+ void AddLengthNull(int32_t* length) override;
+
+ Status Encode(const Datum& data, int64_t batch_length,
+ uint8_t** encoded_bytes) override;
+
+ void EncodeNull(uint8_t** encoded_bytes) override;
+
+ Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
+ MemoryPool* pool) override;
+};
+
+struct FixedWidthKeyEncoder : KeyEncoder {
+ explicit FixedWidthKeyEncoder(std::shared_ptr<DataType> type)
+ : type_(std::move(type)),
+ byte_width_(checked_cast<const FixedWidthType&>(*type_).bit_width() / 8) {}
+
+ void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override;
+
+ void AddLengthNull(int32_t* length) override;
+
+ Status Encode(const Datum& data, int64_t batch_length,
+ uint8_t** encoded_bytes) override;
+
+ void EncodeNull(uint8_t** encoded_bytes) override;
+
+ Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
+ MemoryPool* pool) override;
+
+ std::shared_ptr<DataType> type_;
+ int byte_width_;
+};
+
+struct DictionaryKeyEncoder : FixedWidthKeyEncoder {
+ DictionaryKeyEncoder(std::shared_ptr<DataType> type, MemoryPool* pool)
+ : FixedWidthKeyEncoder(std::move(type)), pool_(pool) {}
+
+ Status Encode(const Datum& data, int64_t batch_length,
+ uint8_t** encoded_bytes) override;
+
+ Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
+ MemoryPool* pool) override;
+
+ MemoryPool* pool_;
+ std::shared_ptr<Array> dictionary_;
+};
+
+template <typename T>
+struct VarLengthKeyEncoder : KeyEncoder {
+ using Offset = typename T::offset_type;
+
+ void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override {
+ if (data.is_array()) {
+ int64_t i = 0;
+ VisitArrayDataInline<T>(
+ *data.array(),
+ [&](util::string_view bytes) {
+ lengths[i++] +=
+ kExtraByteForNull + sizeof(Offset) + static_cast<int32_t>(bytes.size());
+ },
+ [&] { lengths[i++] += kExtraByteForNull + sizeof(Offset); });
+ } else {
+ const Scalar& scalar = *data.scalar();
+ const int32_t buffer_size =
+ scalar.is_valid ? static_cast<int32_t>(UnboxScalar<T>::Unbox(scalar).size())
+ : 0;
+ for (int64_t i = 0; i < batch_length; i++) {
+ lengths[i] += kExtraByteForNull + sizeof(Offset) + buffer_size;
+ }
+ }
+ }
+
+ void AddLengthNull(int32_t* length) override {
+ *length += kExtraByteForNull + sizeof(Offset);
+ }
+
+ Status Encode(const Datum& data, int64_t batch_length,
+ uint8_t** encoded_bytes) override {
+ if (data.is_array()) {
+ VisitArrayDataInline<T>(
+ *data.array(),
+ [&](util::string_view bytes) {
+ auto& encoded_ptr = *encoded_bytes++;
+ *encoded_ptr++ = kValidByte;
+ util::SafeStore(encoded_ptr, static_cast<Offset>(bytes.size()));
+ encoded_ptr += sizeof(Offset);
+ memcpy(encoded_ptr, bytes.data(), bytes.size());
+ encoded_ptr += bytes.size();
+ },
+ [&] {
+ auto& encoded_ptr = *encoded_bytes++;
+ *encoded_ptr++ = kNullByte;
+ util::SafeStore(encoded_ptr, static_cast<Offset>(0));
+ encoded_ptr += sizeof(Offset);
+ });
+ } else {
+ const auto& scalar = data.scalar_as<BaseBinaryScalar>();
+ if (scalar.is_valid) {
+ const auto& bytes = *scalar.value;
+ for (int64_t i = 0; i < batch_length; i++) {
+ auto& encoded_ptr = *encoded_bytes++;
+ *encoded_ptr++ = kValidByte;
+ util::SafeStore(encoded_ptr, static_cast<Offset>(bytes.size()));
+ encoded_ptr += sizeof(Offset);
+ memcpy(encoded_ptr, bytes.data(), bytes.size());
+ encoded_ptr += bytes.size();
+ }
+ } else {
+ for (int64_t i = 0; i < batch_length; i++) {
+ auto& encoded_ptr = *encoded_bytes++;
+ *encoded_ptr++ = kNullByte;
+ util::SafeStore(encoded_ptr, static_cast<Offset>(0));
+ encoded_ptr += sizeof(Offset);
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ void EncodeNull(uint8_t** encoded_bytes) override {
+ auto& encoded_ptr = *encoded_bytes;
+ *encoded_ptr++ = kNullByte;
+ util::SafeStore(encoded_ptr, static_cast<Offset>(0));
+ encoded_ptr += sizeof(Offset);
+ }
+
+ Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
+ MemoryPool* pool) override {
+ std::shared_ptr<Buffer> null_buf;
+ int32_t null_count;
+ ARROW_RETURN_NOT_OK(DecodeNulls(pool, length, encoded_bytes, &null_buf, &null_count));
+
+ Offset length_sum = 0;
+ for (int32_t i = 0; i < length; ++i) {
+ length_sum += util::SafeLoadAs<Offset>(encoded_bytes[i]);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto offset_buf,
+ AllocateBuffer(sizeof(Offset) * (1 + length), pool));
+ ARROW_ASSIGN_OR_RAISE(auto key_buf, AllocateBuffer(length_sum));
+
+ auto raw_offsets = reinterpret_cast<Offset*>(offset_buf->mutable_data());
+ auto raw_keys = key_buf->mutable_data();
+
+ Offset current_offset = 0;
+ for (int32_t i = 0; i < length; ++i) {
+ raw_offsets[i] = current_offset;
+
+ auto key_length = util::SafeLoadAs<Offset>(encoded_bytes[i]);
+ encoded_bytes[i] += sizeof(Offset);
+
+ memcpy(raw_keys + current_offset, encoded_bytes[i], key_length);
+ encoded_bytes[i] += key_length;
+
+ current_offset += key_length;
+ }
+ raw_offsets[length] = current_offset;
+
+ return ArrayData::Make(
+ type_, length, {std::move(null_buf), std::move(offset_buf), std::move(key_buf)},
+ null_count);
+ }
+
+ explicit VarLengthKeyEncoder(std::shared_ptr<DataType> type) : type_(std::move(type)) {}
+
+ std::shared_ptr<DataType> type_;
+};
+
+class ARROW_EXPORT RowEncoder {
+ public:
+ static constexpr int kRowIdForNulls() { return -1; }
+
+ void Init(const std::vector<ValueDescr>& column_types, ExecContext* ctx);
+ void Clear();
+ Status EncodeAndAppend(const ExecBatch& batch);
+ Result<ExecBatch> Decode(int64_t num_rows, const int32_t* row_ids);
+
+ inline std::string encoded_row(int32_t i) const {
+ if (i == kRowIdForNulls()) {
+ return std::string(reinterpret_cast<const char*>(encoded_nulls_.data()),
+ encoded_nulls_.size());
+ }
+ int32_t row_length = offsets_[i + 1] - offsets_[i];
+ return std::string(reinterpret_cast<const char*>(bytes_.data() + offsets_[i]),
+ row_length);
+ }
+
+ int32_t num_rows() const {
+ return offsets_.size() == 0 ? 0 : static_cast<int32_t>(offsets_.size() - 1);
+ }
+
+ private:
+ ExecContext* ctx_;
+ std::vector<std::shared_ptr<KeyEncoder>> encoders_;
+ std::vector<int32_t> offsets_;
+ std::vector<uint8_t> bytes_;
+ std::vector<uint8_t> encoded_nulls_;
+};
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
new file mode 100644
index 000000000..763e40e11
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
@@ -0,0 +1,2609 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <utility>
+#include <vector>
+
+#include "arrow/compare.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+
+using internal::AddWithOverflow;
+using internal::DivideWithOverflow;
+using internal::MultiplyWithOverflow;
+using internal::NegateWithOverflow;
+using internal::SubtractWithOverflow;
+
+namespace compute {
+namespace internal {
+
+using applicator::ScalarBinaryEqualTypes;
+using applicator::ScalarBinaryNotNullEqualTypes;
+using applicator::ScalarUnary;
+using applicator::ScalarUnaryNotNull;
+using applicator::ScalarUnaryNotNullStateful;
+
+namespace {
+
+// N.B. take care not to conflict with type_traits.h as that can cause surprises in a
+// unity build
+
+template <typename T>
+using is_unsigned_integer = std::integral_constant<bool, std::is_integral<T>::value &&
+ std::is_unsigned<T>::value>;
+
+template <typename T>
+using is_signed_integer =
+ std::integral_constant<bool, std::is_integral<T>::value && std::is_signed<T>::value>;
+
+template <typename T, typename R = T>
+using enable_if_signed_c_integer = enable_if_t<is_signed_integer<T>::value, R>;
+
+template <typename T, typename R = T>
+using enable_if_unsigned_c_integer = enable_if_t<is_unsigned_integer<T>::value, R>;
+
+template <typename T, typename R = T>
+using enable_if_c_integer =
+ enable_if_t<is_signed_integer<T>::value || is_unsigned_integer<T>::value, R>;
+
+template <typename T, typename R = T>
+using enable_if_floating_point = enable_if_t<std::is_floating_point<T>::value, R>;
+
+template <typename T, typename R = T>
+using enable_if_decimal_value =
+ enable_if_t<std::is_same<Decimal128, T>::value || std::is_same<Decimal256, T>::value,
+ R>;
+
+struct AbsoluteValue {
+ template <typename T, typename Arg>
+ static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return std::fabs(arg);
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_unsigned_c_integer<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return arg;
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_signed_c_integer<Arg, T> Call(KernelContext*, Arg arg,
+ Status* st) {
+ return (arg < 0) ? arrow::internal::SafeSignedNegate(arg) : arg;
+ }
+};
+
+struct AbsoluteValueChecked {
+ template <typename T, typename Arg>
+ static enable_if_signed_c_integer<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ if (arg == std::numeric_limits<Arg>::min()) {
+ *st = Status::Invalid("overflow");
+ return arg;
+ }
+ return std::abs(arg);
+ }
+
+ template <typename T, typename Arg>
+ static enable_if_unsigned_c_integer<Arg, T> Call(KernelContext* ctx, Arg arg,
+ Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ return arg;
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
+ Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ return std::fabs(arg);
+ }
+};
+
+struct Add {
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status*) {
+ return left + right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_unsigned_c_integer<T> Call(KernelContext*, Arg0 left,
+ Arg1 right, Status*) {
+ return left + right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_signed_c_integer<T> Call(KernelContext*, Arg0 left,
+ Arg1 right, Status*) {
+ return arrow::internal::SafeSignedAdd(left, right);
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
+ return left + right;
+ }
+};
+
+struct AddChecked {
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_c_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ T result = 0;
+ if (ARROW_PREDICT_FALSE(AddWithOverflow(left, right, &result))) {
+ *st = Status::Invalid("overflow");
+ }
+ return result;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status*) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ return left + right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
+ return left + right;
+ }
+};
+
+struct Subtract {
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status*) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ return left - right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_unsigned_c_integer<T> Call(KernelContext*, Arg0 left,
+ Arg1 right, Status*) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ return left - right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_signed_c_integer<T> Call(KernelContext*, Arg0 left,
+ Arg1 right, Status*) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ return arrow::internal::SafeSignedSubtract(left, right);
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
+ return left + (-right);
+ }
+};
+
+struct SubtractChecked {
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_c_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ T result = 0;
+ if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) {
+ *st = Status::Invalid("overflow");
+ }
+ return result;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status*) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ return left - right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
+ return left + (-right);
+ }
+};
+
+struct Multiply {
+ static_assert(std::is_same<decltype(int8_t() * int8_t()), int32_t>::value, "");
+ static_assert(std::is_same<decltype(uint8_t() * uint8_t()), int32_t>::value, "");
+ static_assert(std::is_same<decltype(int16_t() * int16_t()), int32_t>::value, "");
+ static_assert(std::is_same<decltype(uint16_t() * uint16_t()), int32_t>::value, "");
+ static_assert(std::is_same<decltype(int32_t() * int32_t()), int32_t>::value, "");
+ static_assert(std::is_same<decltype(uint32_t() * uint32_t()), uint32_t>::value, "");
+ static_assert(std::is_same<decltype(int64_t() * int64_t()), int64_t>::value, "");
+ static_assert(std::is_same<decltype(uint64_t() * uint64_t()), uint64_t>::value, "");
+
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_floating_point<T> Call(KernelContext*, T left, T right,
+ Status*) {
+ return left * right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_t<
+ is_unsigned_integer<T>::value && !std::is_same<T, uint16_t>::value, T>
+ Call(KernelContext*, T left, T right, Status*) {
+ return left * right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_t<
+ is_signed_integer<T>::value && !std::is_same<T, int16_t>::value, T>
+ Call(KernelContext*, T left, T right, Status*) {
+ return to_unsigned(left) * to_unsigned(right);
+ }
+
+ // Multiplication of 16 bit integer types implicitly promotes to signed 32 bit
+ // integer. However, some inputs may nevertheless overflow (which triggers undefined
+ // behaviour). Therefore we first cast to 32 bit unsigned integers where overflow is
+ // well defined.
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_same<T, int16_t, T> Call(KernelContext*, int16_t left,
+ int16_t right, Status*) {
+ return static_cast<uint32_t>(left) * static_cast<uint32_t>(right);
+ }
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr enable_if_same<T, uint16_t, T> Call(KernelContext*, uint16_t left,
+ uint16_t right, Status*) {
+ return static_cast<uint32_t>(left) * static_cast<uint32_t>(right);
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
+ return left * right;
+ }
+};
+
+struct MultiplyChecked {
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_c_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ T result = 0;
+ if (ARROW_PREDICT_FALSE(MultiplyWithOverflow(left, right, &result))) {
+ *st = Status::Invalid("overflow");
+ }
+ return result;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status*) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ return left * right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
+ return left * right;
+ }
+};
+
+struct Divide {
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status*) {
+ return left / right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_c_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
+ T result;
+ if (ARROW_PREDICT_FALSE(DivideWithOverflow(left, right, &result))) {
+ if (right == 0) {
+ *st = Status::Invalid("divide by zero");
+ } else {
+ result = 0;
+ }
+ }
+ return result;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status* st) {
+ if (right == Arg1()) {
+ *st = Status::Invalid("Divide by zero");
+ return T();
+ } else {
+ return left / right;
+ }
+ }
+};
+
+struct DivideChecked {
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_c_integer<T> Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ T result;
+ if (ARROW_PREDICT_FALSE(DivideWithOverflow(left, right, &result))) {
+ if (right == 0) {
+ *st = Status::Invalid("divide by zero");
+ } else {
+ *st = Status::Invalid("overflow");
+ }
+ }
+ return result;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_point<T> Call(KernelContext*, Arg0 left, Arg1 right,
+ Status* st) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ if (ARROW_PREDICT_FALSE(right == 0)) {
+ *st = Status::Invalid("divide by zero");
+ return 0;
+ }
+ return left / right;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_decimal_value<T> Call(KernelContext* ctx, Arg0 left, Arg1 right,
+ Status* st) {
+ return Divide::Call<T>(ctx, left, right, st);
+ }
+};
+
+struct Negate {
+ template <typename T, typename Arg>
+ static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg arg, Status*) {
+ return -arg;
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_unsigned_c_integer<T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return ~arg + 1;
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_signed_c_integer<T> Call(KernelContext*, Arg arg, Status*) {
+ return arrow::internal::SafeSignedNegate(arg);
+ }
+};
+
+struct NegateChecked {
+ template <typename T, typename Arg>
+ static enable_if_signed_c_integer<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ T result = 0;
+ if (ARROW_PREDICT_FALSE(NegateWithOverflow(arg, &result))) {
+ *st = Status::Invalid("overflow");
+ }
+ return result;
+ }
+
+ template <typename T, typename Arg>
+ static enable_if_unsigned_c_integer<Arg, T> Call(KernelContext* ctx, Arg arg,
+ Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ DCHECK(false) << "This is included only for the purposes of instantiability from the "
+ "arithmetic kernel generator";
+ return 0;
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
+ Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ return -arg;
+ }
+};
+
+struct Power {
+ ARROW_NOINLINE
+ static uint64_t IntegerPower(uint64_t base, uint64_t exp) {
+ // right to left O(logn) power
+ uint64_t pow = 1;
+ while (exp) {
+ pow *= (exp & 1) ? base : 1;
+ base *= base;
+ exp >>= 1;
+ }
+ return pow;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_c_integer<T> Call(KernelContext*, T base, T exp, Status* st) {
+ if (exp < 0) {
+ *st = Status::Invalid("integers to negative integer powers are not allowed");
+ return 0;
+ }
+ return static_cast<T>(IntegerPower(base, exp));
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_point<T> Call(KernelContext*, T base, T exp, Status*) {
+ return std::pow(base, exp);
+ }
+};
+
+struct PowerChecked {
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_c_integer<T> Call(KernelContext*, Arg0 base, Arg1 exp, Status* st) {
+ if (exp < 0) {
+ *st = Status::Invalid("integers to negative integer powers are not allowed");
+ return 0;
+ } else if (exp == 0) {
+ return 1;
+ }
+ // left to right O(logn) power with overflow checks
+ bool overflow = false;
+ uint64_t bitmask =
+ 1ULL << (63 - BitUtil::CountLeadingZeros(static_cast<uint64_t>(exp)));
+ T pow = 1;
+ while (bitmask) {
+ overflow |= MultiplyWithOverflow(pow, pow, &pow);
+ if (exp & bitmask) {
+ overflow |= MultiplyWithOverflow(pow, base, &pow);
+ }
+ bitmask >>= 1;
+ }
+ if (overflow) {
+ *st = Status::Invalid("overflow");
+ }
+ return pow;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_point<T> Call(KernelContext*, Arg0 base, Arg1 exp, Status*) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ return std::pow(base, exp);
+ }
+};
+
+struct Sign {
+ template <typename T, typename Arg>
+ static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return std::isnan(arg) ? arg : ((arg == 0) ? 0 : (std::signbit(arg) ? -1 : 1));
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_unsigned_c_integer<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return (arg > 0) ? 1 : 0;
+ }
+
+ template <typename T, typename Arg>
+ static constexpr enable_if_signed_c_integer<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ return (arg > 0) ? 1 : ((arg == 0) ? 0 : -1);
+ }
+};
+
+// Bitwise operations
+
+struct BitWiseNot {
+ template <typename T, typename Arg>
+ static T Call(KernelContext*, Arg arg, Status*) {
+ return ~arg;
+ }
+};
+
+struct BitWiseAnd {
+ template <typename T, typename Arg0, typename Arg1>
+ static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status*) {
+ return lhs & rhs;
+ }
+};
+
+struct BitWiseOr {
+ template <typename T, typename Arg0, typename Arg1>
+ static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status*) {
+ return lhs | rhs;
+ }
+};
+
+struct BitWiseXor {
+ template <typename T, typename Arg0, typename Arg1>
+ static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status*) {
+ return lhs ^ rhs;
+ }
+};
+
+struct ShiftLeft {
+ template <typename T, typename Arg0, typename Arg1>
+ static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status*) {
+ using Unsigned = typename std::make_unsigned<Arg0>::type;
+ static_assert(std::is_same<T, Arg0>::value, "");
+ if (ARROW_PREDICT_FALSE(rhs < 0 || rhs >= std::numeric_limits<Arg0>::digits)) {
+ return lhs;
+ }
+ return static_cast<T>(static_cast<Unsigned>(lhs) << static_cast<Unsigned>(rhs));
+ }
+};
+
+// See SEI CERT C Coding Standard rule INT34-C
+struct ShiftLeftChecked {
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_unsigned_c_integer<T> Call(KernelContext*, Arg0 lhs, Arg1 rhs,
+ Status* st) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ if (ARROW_PREDICT_FALSE(rhs < 0 || rhs >= std::numeric_limits<Arg0>::digits)) {
+ *st = Status::Invalid("shift amount must be >= 0 and less than precision of type");
+ return lhs;
+ }
+ return lhs << rhs;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_signed_c_integer<T> Call(KernelContext*, Arg0 lhs, Arg1 rhs,
+ Status* st) {
+ using Unsigned = typename std::make_unsigned<Arg0>::type;
+ static_assert(std::is_same<T, Arg0>::value, "");
+ if (ARROW_PREDICT_FALSE(rhs < 0 || rhs >= std::numeric_limits<Arg0>::digits)) {
+ *st = Status::Invalid("shift amount must be >= 0 and less than precision of type");
+ return lhs;
+ }
+ // In C/C++ left shift of a negative number is undefined (C++11 standard 5.8.2)
+ // Mimic Java/etc. and treat left shift as based on two's complement representation
+ // Assumes two's complement machine
+ return static_cast<T>(static_cast<Unsigned>(lhs) << static_cast<Unsigned>(rhs));
+ }
+};
+
+struct ShiftRight {
+ template <typename T, typename Arg0, typename Arg1>
+ static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status*) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ // Logical right shift when Arg0 is unsigned
+ // Arithmetic otherwise (this is implementation-defined but GCC and MSVC document this
+ // as arithmetic right shift)
+ // https://gcc.gnu.org/onlinedocs/gcc/Integers-implementation.html#Integers-implementation
+ // https://docs.microsoft.com/en-us/cpp/cpp/left-shift-and-right-shift-operators-input-and-output?view=msvc-160
+ // Clang doesn't document their behavior.
+ if (ARROW_PREDICT_FALSE(rhs < 0 || rhs >= std::numeric_limits<Arg0>::digits)) {
+ return lhs;
+ }
+ return lhs >> rhs;
+ }
+};
+
+struct ShiftRightChecked {
+ template <typename T, typename Arg0, typename Arg1>
+ static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status* st) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ if (ARROW_PREDICT_FALSE(rhs < 0 || rhs >= std::numeric_limits<Arg0>::digits)) {
+ *st = Status::Invalid("shift amount must be >= 0 and less than precision of type");
+ return lhs;
+ }
+ return lhs >> rhs;
+ }
+};
+
+struct Sin {
+ template <typename T, typename Arg0>
+ static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status*) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ return std::sin(val);
+ }
+};
+
+struct SinChecked {
+ template <typename T, typename Arg0>
+ static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status* st) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ if (ARROW_PREDICT_FALSE(std::isinf(val))) {
+ *st = Status::Invalid("domain error");
+ return val;
+ }
+ return std::sin(val);
+ }
+};
+
+struct Cos {
+ template <typename T, typename Arg0>
+ static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status*) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ return std::cos(val);
+ }
+};
+
+struct CosChecked {
+ template <typename T, typename Arg0>
+ static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status* st) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ if (ARROW_PREDICT_FALSE(std::isinf(val))) {
+ *st = Status::Invalid("domain error");
+ return val;
+ }
+ return std::cos(val);
+ }
+};
+
+struct Tan {
+ template <typename T, typename Arg0>
+ static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status*) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ return std::tan(val);
+ }
+};
+
+struct TanChecked {
+ template <typename T, typename Arg0>
+ static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status* st) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ if (ARROW_PREDICT_FALSE(std::isinf(val))) {
+ *st = Status::Invalid("domain error");
+ return val;
+ }
+ // Cannot raise range errors (overflow) since PI/2 is not exactly representable
+ return std::tan(val);
+ }
+};
+
+struct Asin {
+ template <typename T, typename Arg0>
+ static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status*) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ if (ARROW_PREDICT_FALSE(val < -1.0 || val > 1.0)) {
+ return std::numeric_limits<T>::quiet_NaN();
+ }
+ return std::asin(val);
+ }
+};
+
+struct AsinChecked {
+ template <typename T, typename Arg0>
+ static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status* st) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ if (ARROW_PREDICT_FALSE(val < -1.0 || val > 1.0)) {
+ *st = Status::Invalid("domain error");
+ return val;
+ }
+ return std::asin(val);
+ }
+};
+
+struct Acos {
+ template <typename T, typename Arg0>
+ static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status*) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ if (ARROW_PREDICT_FALSE((val < -1.0 || val > 1.0))) {
+ return std::numeric_limits<T>::quiet_NaN();
+ }
+ return std::acos(val);
+ }
+};
+
+struct AcosChecked {
+ template <typename T, typename Arg0>
+ static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status* st) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ if (ARROW_PREDICT_FALSE((val < -1.0 || val > 1.0))) {
+ *st = Status::Invalid("domain error");
+ return val;
+ }
+ return std::acos(val);
+ }
+};
+
+struct Atan {
+ template <typename T, typename Arg0>
+ static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 val, Status*) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ return std::atan(val);
+ }
+};
+
+struct Atan2 {
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_point<Arg0, T> Call(KernelContext*, Arg0 y, Arg1 x, Status*) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ static_assert(std::is_same<Arg0, Arg1>::value, "");
+ return std::atan2(y, x);
+ }
+};
+
+struct LogNatural {
+ template <typename T, typename Arg>
+ static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status*) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ if (arg == 0.0) {
+ return -std::numeric_limits<T>::infinity();
+ } else if (arg < 0.0) {
+ return std::numeric_limits<T>::quiet_NaN();
+ }
+ return std::log(arg);
+ }
+};
+
+struct LogNaturalChecked {
+ template <typename T, typename Arg>
+ static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ if (arg == 0.0) {
+ *st = Status::Invalid("logarithm of zero");
+ return arg;
+ } else if (arg < 0.0) {
+ *st = Status::Invalid("logarithm of negative number");
+ return arg;
+ }
+ return std::log(arg);
+ }
+};
+
+struct Log10 {
+ template <typename T, typename Arg>
+ static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status*) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ if (arg == 0.0) {
+ return -std::numeric_limits<T>::infinity();
+ } else if (arg < 0.0) {
+ return std::numeric_limits<T>::quiet_NaN();
+ }
+ return std::log10(arg);
+ }
+};
+
+struct Log10Checked {
+ template <typename T, typename Arg>
+ static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ if (arg == 0) {
+ *st = Status::Invalid("logarithm of zero");
+ return arg;
+ } else if (arg < 0) {
+ *st = Status::Invalid("logarithm of negative number");
+ return arg;
+ }
+ return std::log10(arg);
+ }
+};
+
+struct Log2 {
+ template <typename T, typename Arg>
+ static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status*) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ if (arg == 0.0) {
+ return -std::numeric_limits<T>::infinity();
+ } else if (arg < 0.0) {
+ return std::numeric_limits<T>::quiet_NaN();
+ }
+ return std::log2(arg);
+ }
+};
+
+struct Log2Checked {
+ template <typename T, typename Arg>
+ static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ if (arg == 0.0) {
+ *st = Status::Invalid("logarithm of zero");
+ return arg;
+ } else if (arg < 0.0) {
+ *st = Status::Invalid("logarithm of negative number");
+ return arg;
+ }
+ return std::log2(arg);
+ }
+};
+
+struct Log1p {
+ template <typename T, typename Arg>
+ static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status*) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ if (arg == -1) {
+ return -std::numeric_limits<T>::infinity();
+ } else if (arg < -1) {
+ return std::numeric_limits<T>::quiet_NaN();
+ }
+ return std::log1p(arg);
+ }
+};
+
+struct Log1pChecked {
+ template <typename T, typename Arg>
+ static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ if (arg == -1) {
+ *st = Status::Invalid("logarithm of zero");
+ return arg;
+ } else if (arg < -1) {
+ *st = Status::Invalid("logarithm of negative number");
+ return arg;
+ }
+ return std::log1p(arg);
+ }
+};
+
+struct Logb {
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_point<T> Call(KernelContext*, Arg0 x, Arg1 base, Status*) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ static_assert(std::is_same<Arg0, Arg1>::value, "");
+ if (x == 0.0) {
+ if (base == 0.0 || base < 0.0) {
+ return std::numeric_limits<T>::quiet_NaN();
+ } else {
+ return -std::numeric_limits<T>::infinity();
+ }
+ } else if (x < 0.0) {
+ return std::numeric_limits<T>::quiet_NaN();
+ }
+ return std::log(x) / std::log(base);
+ }
+};
+
+struct LogbChecked {
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_point<T> Call(KernelContext*, Arg0 x, Arg1 base, Status* st) {
+ static_assert(std::is_same<T, Arg0>::value, "");
+ static_assert(std::is_same<Arg0, Arg1>::value, "");
+ if (x == 0.0 || base == 0.0) {
+ *st = Status::Invalid("logarithm of zero");
+ return x;
+ } else if (x < 0.0 || base < 0.0) {
+ *st = Status::Invalid("logarithm of negative number");
+ return x;
+ }
+ return std::log(x) / std::log(base);
+ }
+};
+
+struct RoundUtil {
+ // Calculate powers of ten with arbitrary integer exponent
+ template <typename T = double>
+ static enable_if_floating_point<T> Pow10(int64_t power) {
+ static constexpr T lut[] = {1e0F, 1e1F, 1e2F, 1e3F, 1e4F, 1e5F, 1e6F, 1e7F,
+ 1e8F, 1e9F, 1e10F, 1e11F, 1e12F, 1e13F, 1e14F, 1e15F};
+ int64_t lut_size = (sizeof(lut) / sizeof(*lut));
+ int64_t abs_power = std::abs(power);
+ auto pow10 = lut[std::min(abs_power, lut_size - 1)];
+ while (abs_power-- >= lut_size) {
+ pow10 *= 1e1F;
+ }
+ return (power >= 0) ? pow10 : (1 / pow10);
+ }
+};
+
+// Specializations of rounding implementations for round kernels
+template <typename Type, RoundMode>
+struct RoundImpl;
+
+template <typename Type>
+struct RoundImpl<Type, RoundMode::DOWN> {
+ template <typename T = Type>
+ static constexpr enable_if_floating_point<T> Round(const T val) {
+ return std::floor(val);
+ }
+
+ template <typename T = Type>
+ static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {
+ (*val) -= remainder;
+ if (remainder.Sign() < 0) {
+ (*val) -= pow10;
+ }
+ }
+};
+
+template <typename Type>
+struct RoundImpl<Type, RoundMode::UP> {
+ template <typename T = Type>
+ static constexpr enable_if_floating_point<T> Round(const T val) {
+ return std::ceil(val);
+ }
+
+ template <typename T = Type>
+ static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {
+ (*val) -= remainder;
+ if (remainder.Sign() > 0 && remainder != 0) {
+ (*val) += pow10;
+ }
+ }
+};
+
+template <typename Type>
+struct RoundImpl<Type, RoundMode::TOWARDS_ZERO> {
+ template <typename T = Type>
+ static constexpr enable_if_floating_point<T> Round(const T val) {
+ return std::trunc(val);
+ }
+
+ template <typename T = Type>
+ static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {
+ (*val) -= remainder;
+ }
+};
+
+template <typename Type>
+struct RoundImpl<Type, RoundMode::TOWARDS_INFINITY> {
+ template <typename T = Type>
+ static constexpr enable_if_floating_point<T> Round(const T val) {
+ return std::signbit(val) ? std::floor(val) : std::ceil(val);
+ }
+
+ template <typename T = Type>
+ static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {
+ (*val) -= remainder;
+ if (remainder.Sign() < 0) {
+ (*val) -= pow10;
+ } else if (remainder.Sign() > 0 && remainder != 0) {
+ (*val) += pow10;
+ }
+ }
+};
+
+// NOTE: RoundImpl variants for the HALF_* rounding modes are only
+// invoked when the fractional part is equal to 0.5 (std::round is invoked
+// otherwise).
+
+template <typename Type>
+struct RoundImpl<Type, RoundMode::HALF_DOWN> {
+ template <typename T = Type>
+ static constexpr enable_if_floating_point<T> Round(const T val) {
+ return RoundImpl<T, RoundMode::DOWN>::Round(val);
+ }
+
+ template <typename T = Type>
+ static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {
+ RoundImpl<T, RoundMode::DOWN>::Round(val, remainder, pow10, scale);
+ }
+};
+
+template <typename Type>
+struct RoundImpl<Type, RoundMode::HALF_UP> {
+ template <typename T = Type>
+ static constexpr enable_if_floating_point<T> Round(const T val) {
+ return RoundImpl<T, RoundMode::UP>::Round(val);
+ }
+
+ template <typename T = Type>
+ static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {
+ RoundImpl<T, RoundMode::UP>::Round(val, remainder, pow10, scale);
+ }
+};
+
+template <typename Type>
+struct RoundImpl<Type, RoundMode::HALF_TOWARDS_ZERO> {
+ template <typename T = Type>
+ static constexpr enable_if_floating_point<T> Round(const T val) {
+ return RoundImpl<T, RoundMode::TOWARDS_ZERO>::Round(val);
+ }
+
+ template <typename T = Type>
+ static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {
+ RoundImpl<T, RoundMode::TOWARDS_ZERO>::Round(val, remainder, pow10, scale);
+ }
+};
+
+template <typename Type>
+struct RoundImpl<Type, RoundMode::HALF_TOWARDS_INFINITY> {
+ template <typename T = Type>
+ static constexpr enable_if_floating_point<T> Round(const T val) {
+ return RoundImpl<T, RoundMode::TOWARDS_INFINITY>::Round(val);
+ }
+
+ template <typename T = Type>
+ static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {
+ RoundImpl<T, RoundMode::TOWARDS_INFINITY>::Round(val, remainder, pow10, scale);
+ }
+};
+
+template <typename Type>
+struct RoundImpl<Type, RoundMode::HALF_TO_EVEN> {
+ template <typename T = Type>
+ static constexpr enable_if_floating_point<T> Round(const T val) {
+ return std::round(val * T(0.5)) * 2;
+ }
+
+ template <typename T = Type>
+ static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {
+ auto scaled = val->ReduceScaleBy(scale, /*round=*/false);
+ if (scaled.low_bits() % 2 != 0) {
+ scaled += remainder.Sign() >= 0 ? 1 : -1;
+ }
+ *val = scaled.IncreaseScaleBy(scale);
+ }
+};
+
+template <typename Type>
+struct RoundImpl<Type, RoundMode::HALF_TO_ODD> {
+ template <typename T = Type>
+ static constexpr enable_if_floating_point<T> Round(const T val) {
+ return std::floor(val * T(0.5)) + std::ceil(val * T(0.5));
+ }
+
+ template <typename T = Type>
+ static enable_if_decimal_value<T, void> Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {
+ auto scaled = val->ReduceScaleBy(scale, /*round=*/false);
+ if (scaled.low_bits() % 2 == 0) {
+ scaled += remainder.Sign() ? 1 : -1;
+ }
+ *val = scaled.IncreaseScaleBy(scale);
+ }
+};
+
+// Specializations of kernel state for round kernels
+template <typename OptionsType>
+struct RoundOptionsWrapper;
+
+template <>
+struct RoundOptionsWrapper<RoundOptions> : public OptionsWrapper<RoundOptions> {
+ using OptionsType = RoundOptions;
+ using State = RoundOptionsWrapper<OptionsType>;
+ double pow10;
+
+ explicit RoundOptionsWrapper(OptionsType options) : OptionsWrapper(std::move(options)) {
+ // Only positive exponents for powers of 10 are used because combining
+ // multiply and division operations produced more stable rounding than
+ // using multiply-only. Refer to NumPy's round implementation:
+ // https://github.com/numpy/numpy/blob/7b2f20b406d27364c812f7a81a9c901afbd3600c/numpy/core/src/multiarray/calculation.c#L589
+ pow10 = RoundUtil::Pow10(std::abs(options.ndigits));
+ }
+
+ static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ if (auto options = static_cast<const OptionsType*>(args.options)) {
+ return ::arrow::internal::make_unique<State>(*options);
+ }
+ return Status::Invalid(
+ "Attempted to initialize KernelState from null FunctionOptions");
+ }
+};
+
+template <>
+struct RoundOptionsWrapper<RoundToMultipleOptions>
+ : public OptionsWrapper<RoundToMultipleOptions> {
+ using OptionsType = RoundToMultipleOptions;
+ using State = RoundOptionsWrapper<OptionsType>;
+ using OptionsWrapper::OptionsWrapper;
+
+ static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ std::unique_ptr<State> state;
+ if (auto options = static_cast<const OptionsType*>(args.options)) {
+ state = ::arrow::internal::make_unique<State>(*options);
+ } else {
+ return Status::Invalid(
+ "Attempted to initialize KernelState from null FunctionOptions");
+ }
+
+ auto options = Get(*state);
+ const auto& type = *args.inputs[0].type;
+ if (!options.multiple || !options.multiple->is_valid) {
+ return Status::Invalid("Rounding multiple must be non-null and valid");
+ }
+ if (is_floating(type.id())) {
+ switch (options.multiple->type->id()) {
+ case Type::FLOAT: {
+ if (UnboxScalar<FloatType>::Unbox(*options.multiple) < 0) {
+ return Status::Invalid("Rounding multiple must be positive");
+ }
+ break;
+ }
+ case Type::DOUBLE: {
+ if (UnboxScalar<DoubleType>::Unbox(*options.multiple) < 0) {
+ return Status::Invalid("Rounding multiple must be positive");
+ }
+ break;
+ }
+ case Type::HALF_FLOAT:
+ return Status::NotImplemented("Half-float values are not supported");
+ default:
+ return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ",
+ *options.multiple->type);
+ }
+ } else {
+ DCHECK(is_decimal(type.id()));
+ if (!type.Equals(*options.multiple->type)) {
+ return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ",
+ *options.multiple->type);
+ }
+ switch (options.multiple->type->id()) {
+ case Type::DECIMAL128: {
+ if (UnboxScalar<Decimal128Type>::Unbox(*options.multiple) <= 0) {
+ return Status::Invalid("Rounding multiple must be positive");
+ }
+ break;
+ }
+ case Type::DECIMAL256: {
+ if (UnboxScalar<Decimal256Type>::Unbox(*options.multiple) <= 0) {
+ return Status::Invalid("Rounding multiple must be positive");
+ }
+ break;
+ }
+ default:
+ // This shouldn't happen
+ return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ",
+ *options.multiple->type);
+ }
+ }
+ return std::move(state);
+ }
+};
+
+template <typename ArrowType, RoundMode RndMode, typename Enable = void>
+struct Round {
+ using CType = typename TypeTraits<ArrowType>::CType;
+ using State = RoundOptionsWrapper<RoundOptions>;
+
+ CType pow10;
+ int64_t ndigits;
+
+ explicit Round(const State& state, const DataType& out_ty)
+ : pow10(static_cast<CType>(state.pow10)), ndigits(state.options.ndigits) {}
+
+ template <typename T = ArrowType, typename CType = typename TypeTraits<T>::CType>
+ enable_if_floating_point<CType> Call(KernelContext* ctx, CType arg, Status* st) const {
+ // Do not process Inf or NaN because they will trigger the overflow error at end of
+ // function.
+ if (!std::isfinite(arg)) {
+ return arg;
+ }
+ auto round_val = ndigits >= 0 ? (arg * pow10) : (arg / pow10);
+ auto frac = round_val - std::floor(round_val);
+ if (frac != T(0)) {
+ // Use std::round() if in tie-breaking mode and scaled value is not 0.5.
+ if ((RndMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) {
+ round_val = std::round(round_val);
+ } else {
+ round_val = RoundImpl<CType, RndMode>::Round(round_val);
+ }
+ // Equality check is ommitted so that the common case of 10^0 (integer rounding)
+ // uses multiply-only
+ round_val = ndigits > 0 ? (round_val / pow10) : (round_val * pow10);
+ if (!std::isfinite(round_val)) {
+ *st = Status::Invalid("overflow occurred during rounding");
+ return arg;
+ }
+ } else {
+ // If scaled value is an integer, then no rounding is needed.
+ round_val = arg;
+ }
+ return round_val;
+ }
+};
+
+template <typename ArrowType, RoundMode kRoundMode>
+struct Round<ArrowType, kRoundMode, enable_if_decimal<ArrowType>> {
+ using CType = typename TypeTraits<ArrowType>::CType;
+ using State = RoundOptionsWrapper<RoundOptions>;
+
+ const ArrowType& ty;
+ int64_t ndigits;
+ int32_t pow;
+ // pow10 is "1" for the given decimal scale. Similarly half_pow10 is "0.5".
+ CType pow10, half_pow10, neg_half_pow10;
+
+ explicit Round(const State& state, const DataType& out_ty)
+ : Round(state.options.ndigits, out_ty) {}
+
+ explicit Round(int64_t ndigits, const DataType& out_ty)
+ : ty(checked_cast<const ArrowType&>(out_ty)),
+ ndigits(ndigits),
+ pow(static_cast<int32_t>(ty.scale() - ndigits)) {
+ if (pow >= ty.precision() || pow < 0) {
+ pow10 = half_pow10 = neg_half_pow10 = 0;
+ } else {
+ pow10 = CType::GetScaleMultiplier(pow);
+ half_pow10 = CType::GetHalfScaleMultiplier(pow);
+ neg_half_pow10 = -half_pow10;
+ }
+ }
+
+ template <typename T = ArrowType, typename CType = typename TypeTraits<T>::CType>
+ enable_if_decimal_value<CType> Call(KernelContext* ctx, CType arg, Status* st) const {
+ if (pow >= ty.precision()) {
+ *st = Status::Invalid("Rounding to ", ndigits,
+ " digits will not fit in precision of ", ty);
+ return arg;
+ } else if (pow < 0) {
+ // no-op, copy output to input
+ return arg;
+ }
+
+ std::pair<CType, CType> pair;
+ *st = arg.Divide(pow10).Value(&pair);
+ if (!st->ok()) return arg;
+ // The remainder is effectively the scaled fractional part after division.
+ const auto& remainder = pair.second;
+ if (remainder == 0) return arg;
+ if (kRoundMode >= RoundMode::HALF_DOWN) {
+ if (remainder == half_pow10 || remainder == neg_half_pow10) {
+ // On the halfway point, use tiebreaker
+ RoundImpl<CType, kRoundMode>::Round(&arg, remainder, pow10, pow);
+ } else if (remainder.Sign() >= 0) {
+ // Positive, round up/down
+ arg -= remainder;
+ if (remainder > half_pow10) {
+ arg += pow10;
+ }
+ } else {
+ // Negative, round up/down
+ arg -= remainder;
+ if (remainder < neg_half_pow10) {
+ arg -= pow10;
+ }
+ }
+ } else {
+ RoundImpl<CType, kRoundMode>::Round(&arg, remainder, pow10, pow);
+ }
+ if (!arg.FitsInPrecision(ty.precision())) {
+ *st = Status::Invalid("Rounded value ", arg.ToString(ty.scale()),
+ " does not fit in precision of ", ty);
+ return 0;
+ }
+ return arg;
+ }
+};
+
+template <typename DecimalType, RoundMode kMode, int32_t kDigits>
+Status FixedRoundDecimalExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ using Op = Round<DecimalType, kMode>;
+ return ScalarUnaryNotNullStateful<DecimalType, DecimalType, Op>(
+ Op(kDigits, *out->type()))
+ .Exec(ctx, batch, out);
+}
+
+template <typename ArrowType, RoundMode kRoundMode, typename Enable = void>
+struct RoundToMultiple {
+ using CType = typename TypeTraits<ArrowType>::CType;
+ using State = RoundOptionsWrapper<RoundToMultipleOptions>;
+
+ CType multiple;
+
+ explicit RoundToMultiple(const State& state, const DataType& out_ty) {
+ const auto& options = state.options;
+ DCHECK(options.multiple);
+ DCHECK(options.multiple->is_valid);
+ DCHECK(is_floating(options.multiple->type->id()));
+ switch (options.multiple->type->id()) {
+ case Type::FLOAT:
+ multiple = static_cast<CType>(UnboxScalar<FloatType>::Unbox(*options.multiple));
+ break;
+ case Type::DOUBLE:
+ multiple = static_cast<CType>(UnboxScalar<DoubleType>::Unbox(*options.multiple));
+ break;
+ default:
+ DCHECK(false);
+ }
+ }
+
+ template <typename T = ArrowType, typename CType = typename TypeTraits<T>::CType>
+ enable_if_floating_point<CType> Call(KernelContext* ctx, CType arg, Status* st) const {
+ // Do not process Inf or NaN because they will trigger the overflow error at end of
+ // function.
+ if (!std::isfinite(arg)) {
+ return arg;
+ }
+ auto round_val = arg / multiple;
+ auto frac = round_val - std::floor(round_val);
+ if (frac != T(0)) {
+ // Use std::round() if in tie-breaking mode and scaled value is not 0.5.
+ if ((kRoundMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) {
+ round_val = std::round(round_val);
+ } else {
+ round_val = RoundImpl<CType, kRoundMode>::Round(round_val);
+ }
+ round_val *= multiple;
+ if (!std::isfinite(round_val)) {
+ *st = Status::Invalid("overflow occurred during rounding");
+ return arg;
+ }
+ } else {
+ // If scaled value is an integer, then no rounding is needed.
+ round_val = arg;
+ }
+ return round_val;
+ }
+};
+
+template <typename ArrowType, RoundMode kRoundMode>
+struct RoundToMultiple<ArrowType, kRoundMode, enable_if_decimal<ArrowType>> {
+ using CType = typename TypeTraits<ArrowType>::CType;
+ using State = RoundOptionsWrapper<RoundToMultipleOptions>;
+
+ const ArrowType& ty;
+ CType multiple, half_multiple, neg_half_multiple;
+ bool has_halfway_point;
+
+ explicit RoundToMultiple(const State& state, const DataType& out_ty)
+ : ty(checked_cast<const ArrowType&>(out_ty)) {
+ const auto& options = state.options;
+ DCHECK(options.multiple);
+ DCHECK(options.multiple->is_valid);
+ DCHECK(options.multiple->type->Equals(out_ty));
+ multiple = UnboxScalar<ArrowType>::Unbox(*options.multiple);
+ half_multiple = multiple;
+ half_multiple /= 2;
+ neg_half_multiple = -half_multiple;
+ has_halfway_point = multiple.low_bits() % 2 == 0;
+ }
+
+ template <typename T = ArrowType, typename CType = typename TypeTraits<T>::CType>
+ enable_if_decimal_value<CType> Call(KernelContext* ctx, CType arg, Status* st) const {
+ std::pair<CType, CType> pair;
+ *st = arg.Divide(multiple).Value(&pair);
+ if (!st->ok()) return arg;
+ const auto& remainder = pair.second;
+ if (remainder == 0) return arg;
+ if (kRoundMode >= RoundMode::HALF_DOWN) {
+ if (has_halfway_point &&
+ (remainder == half_multiple || remainder == neg_half_multiple)) {
+ // On the halfway point, use tiebreaker
+ // Manually implement rounding since we're not actually rounding a
+ // decimal value, but rather manipulating the multiple
+ switch (kRoundMode) {
+ case RoundMode::HALF_DOWN:
+ if (remainder.Sign() < 0) pair.first -= 1;
+ break;
+ case RoundMode::HALF_UP:
+ if (remainder.Sign() >= 0) pair.first += 1;
+ break;
+ case RoundMode::HALF_TOWARDS_ZERO:
+ // Do nothing
+ break;
+ case RoundMode::HALF_TOWARDS_INFINITY:
+ if (remainder.Sign() >= 0) {
+ pair.first += 1;
+ } else {
+ pair.first -= 1;
+ }
+ break;
+ case RoundMode::HALF_TO_EVEN:
+ if (pair.first.low_bits() % 2 != 0) {
+ pair.first += remainder.Sign() >= 0 ? 1 : -1;
+ }
+ break;
+ case RoundMode::HALF_TO_ODD:
+ if (pair.first.low_bits() % 2 == 0) {
+ pair.first += remainder.Sign() >= 0 ? 1 : -1;
+ }
+ break;
+ default:
+ DCHECK(false);
+ }
+ } else if (remainder.Sign() >= 0) {
+ // Positive, round up/down
+ if (remainder > half_multiple) {
+ pair.first += 1;
+ }
+ } else {
+ // Negative, round up/down
+ if (remainder < neg_half_multiple) {
+ pair.first -= 1;
+ }
+ }
+ } else {
+ // Manually implement rounding since we're not actually rounding a
+ // decimal value, but rather manipulating the multiple
+ switch (kRoundMode) {
+ case RoundMode::DOWN:
+ if (remainder.Sign() < 0) pair.first -= 1;
+ break;
+ case RoundMode::UP:
+ if (remainder.Sign() >= 0) pair.first += 1;
+ break;
+ case RoundMode::TOWARDS_ZERO:
+ // Do nothing
+ break;
+ case RoundMode::TOWARDS_INFINITY:
+ if (remainder.Sign() >= 0) {
+ pair.first += 1;
+ } else {
+ pair.first -= 1;
+ }
+ break;
+ default:
+ DCHECK(false);
+ }
+ }
+ CType round_val = pair.first * multiple;
+ if (!round_val.FitsInPrecision(ty.precision())) {
+ *st = Status::Invalid("Rounded value ", round_val.ToString(ty.scale()),
+ " does not fit in precision of ", ty);
+ return 0;
+ }
+ return round_val;
+ }
+};
+
+struct Floor {
+ template <typename T, typename Arg>
+ static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ return RoundImpl<T, RoundMode::DOWN>::Round(arg);
+ }
+};
+
+struct Ceil {
+ template <typename T, typename Arg>
+ static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ return RoundImpl<T, RoundMode::UP>::Round(arg);
+ }
+};
+
+struct Trunc {
+ template <typename T, typename Arg>
+ static constexpr enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg,
+ Status*) {
+ static_assert(std::is_same<T, Arg>::value, "");
+ return RoundImpl<T, RoundMode::TOWARDS_ZERO>::Round(arg);
+ }
+};
+
+// Generate a kernel given an arithmetic functor
+template <template <typename... Args> class KernelGenerator, typename Op>
+ArrayKernelExec ArithmeticExecFromOp(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::INT8:
+ return KernelGenerator<Int8Type, Int8Type, Op>::Exec;
+ case Type::UINT8:
+ return KernelGenerator<UInt8Type, UInt8Type, Op>::Exec;
+ case Type::INT16:
+ return KernelGenerator<Int16Type, Int16Type, Op>::Exec;
+ case Type::UINT16:
+ return KernelGenerator<UInt16Type, UInt16Type, Op>::Exec;
+ case Type::INT32:
+ return KernelGenerator<Int32Type, Int32Type, Op>::Exec;
+ case Type::UINT32:
+ return KernelGenerator<UInt32Type, UInt32Type, Op>::Exec;
+ case Type::INT64:
+ case Type::TIMESTAMP:
+ return KernelGenerator<Int64Type, Int64Type, Op>::Exec;
+ case Type::UINT64:
+ return KernelGenerator<UInt64Type, UInt64Type, Op>::Exec;
+ case Type::FLOAT:
+ return KernelGenerator<FloatType, FloatType, Op>::Exec;
+ case Type::DOUBLE:
+ return KernelGenerator<DoubleType, DoubleType, Op>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+// Generate a kernel given a bitwise arithmetic functor. Assumes the
+// functor treats all integer types of equal width identically
+template <template <typename... Args> class KernelGenerator, typename Op>
+ArrayKernelExec TypeAgnosticBitWiseExecFromOp(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::INT8:
+ case Type::UINT8:
+ return KernelGenerator<UInt8Type, UInt8Type, Op>::Exec;
+ case Type::INT16:
+ case Type::UINT16:
+ return KernelGenerator<UInt16Type, UInt16Type, Op>::Exec;
+ case Type::INT32:
+ case Type::UINT32:
+ return KernelGenerator<UInt32Type, UInt32Type, Op>::Exec;
+ case Type::INT64:
+ case Type::UINT64:
+ return KernelGenerator<UInt64Type, UInt64Type, Op>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+template <template <typename... Args> class KernelGenerator, typename Op>
+ArrayKernelExec ShiftExecFromOp(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::INT8:
+ return KernelGenerator<Int8Type, Int8Type, Op>::Exec;
+ case Type::UINT8:
+ return KernelGenerator<UInt8Type, UInt8Type, Op>::Exec;
+ case Type::INT16:
+ return KernelGenerator<Int16Type, Int16Type, Op>::Exec;
+ case Type::UINT16:
+ return KernelGenerator<UInt16Type, UInt16Type, Op>::Exec;
+ case Type::INT32:
+ return KernelGenerator<Int32Type, Int32Type, Op>::Exec;
+ case Type::UINT32:
+ return KernelGenerator<UInt32Type, UInt32Type, Op>::Exec;
+ case Type::INT64:
+ return KernelGenerator<Int64Type, Int64Type, Op>::Exec;
+ case Type::UINT64:
+ return KernelGenerator<UInt64Type, UInt64Type, Op>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+template <template <typename... Args> class KernelGenerator, typename Op>
+ArrayKernelExec GenerateArithmeticFloatingPoint(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::FLOAT:
+ return KernelGenerator<FloatType, FloatType, Op>::Exec;
+ case Type::DOUBLE:
+ return KernelGenerator<DoubleType, DoubleType, Op>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+// resolve decimal binary operation output type per *casted* args
+template <typename OutputGetter>
+Result<ValueDescr> ResolveDecimalBinaryOperationOutput(
+ const std::vector<ValueDescr>& args, OutputGetter&& getter) {
+ // casted args should be same size decimals
+ auto left_type = checked_cast<const DecimalType*>(args[0].type.get());
+ auto right_type = checked_cast<const DecimalType*>(args[1].type.get());
+ DCHECK_EQ(left_type->id(), right_type->id());
+
+ int32_t precision, scale;
+ std::tie(precision, scale) = getter(left_type->precision(), left_type->scale(),
+ right_type->precision(), right_type->scale());
+ ARROW_ASSIGN_OR_RAISE(auto type, DecimalType::Make(left_type->id(), precision, scale));
+ return ValueDescr(std::move(type), GetBroadcastShape(args));
+}
+
+Result<ValueDescr> ResolveDecimalAdditionOrSubtractionOutput(
+ KernelContext*, const std::vector<ValueDescr>& args) {
+ return ResolveDecimalBinaryOperationOutput(
+ args, [](int32_t p1, int32_t s1, int32_t p2, int32_t s2) {
+ DCHECK_EQ(s1, s2);
+ const int32_t scale = s1;
+ const int32_t precision = std::max(p1 - s1, p2 - s2) + scale + 1;
+ return std::make_pair(precision, scale);
+ });
+}
+
+Result<ValueDescr> ResolveDecimalMultiplicationOutput(
+ KernelContext*, const std::vector<ValueDescr>& args) {
+ return ResolveDecimalBinaryOperationOutput(
+ args, [](int32_t p1, int32_t s1, int32_t p2, int32_t s2) {
+ const int32_t scale = s1 + s2;
+ const int32_t precision = p1 + p2 + 1;
+ return std::make_pair(precision, scale);
+ });
+}
+
+Result<ValueDescr> ResolveDecimalDivisionOutput(KernelContext*,
+ const std::vector<ValueDescr>& args) {
+ return ResolveDecimalBinaryOperationOutput(
+ args, [](int32_t p1, int32_t s1, int32_t p2, int32_t s2) {
+ DCHECK_GE(s1, s2);
+ const int32_t scale = s1 - s2;
+ const int32_t precision = p1;
+ return std::make_pair(precision, scale);
+ });
+}
+
+template <typename Op>
+void AddDecimalBinaryKernels(const std::string& name,
+ std::shared_ptr<ScalarFunction>* func) {
+ OutputType out_type(null());
+ const std::string op = name.substr(0, name.find("_"));
+ if (op == "add" || op == "subtract") {
+ out_type = OutputType(ResolveDecimalAdditionOrSubtractionOutput);
+ } else if (op == "multiply") {
+ out_type = OutputType(ResolveDecimalMultiplicationOutput);
+ } else if (op == "divide") {
+ out_type = OutputType(ResolveDecimalDivisionOutput);
+ } else {
+ DCHECK(false);
+ }
+
+ auto in_type128 = InputType(Type::DECIMAL128);
+ auto in_type256 = InputType(Type::DECIMAL256);
+ auto exec128 = ScalarBinaryNotNullEqualTypes<Decimal128Type, Decimal128Type, Op>::Exec;
+ auto exec256 = ScalarBinaryNotNullEqualTypes<Decimal256Type, Decimal256Type, Op>::Exec;
+ DCHECK_OK((*func)->AddKernel({in_type128, in_type128}, out_type, exec128));
+ DCHECK_OK((*func)->AddKernel({in_type256, in_type256}, out_type, exec256));
+}
+
+// Generate a kernel given an arithmetic functor
+template <template <typename...> class KernelGenerator, typename OutType, typename Op>
+ArrayKernelExec GenerateArithmeticWithFixedIntOutType(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::INT8:
+ return KernelGenerator<OutType, Int8Type, Op>::Exec;
+ case Type::UINT8:
+ return KernelGenerator<OutType, UInt8Type, Op>::Exec;
+ case Type::INT16:
+ return KernelGenerator<OutType, Int16Type, Op>::Exec;
+ case Type::UINT16:
+ return KernelGenerator<OutType, UInt16Type, Op>::Exec;
+ case Type::INT32:
+ return KernelGenerator<OutType, Int32Type, Op>::Exec;
+ case Type::UINT32:
+ return KernelGenerator<OutType, UInt32Type, Op>::Exec;
+ case Type::INT64:
+ case Type::TIMESTAMP:
+ return KernelGenerator<OutType, Int64Type, Op>::Exec;
+ case Type::UINT64:
+ return KernelGenerator<OutType, UInt64Type, Op>::Exec;
+ case Type::FLOAT:
+ return KernelGenerator<FloatType, FloatType, Op>::Exec;
+ case Type::DOUBLE:
+ return KernelGenerator<DoubleType, DoubleType, Op>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
+struct ArithmeticFunction : ScalarFunction {
+ using ScalarFunction::ScalarFunction;
+
+ Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+ RETURN_NOT_OK(CheckArity(*values));
+
+ RETURN_NOT_OK(CheckDecimals(values));
+
+ using arrow::compute::detail::DispatchExactImpl;
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+
+ EnsureDictionaryDecoded(values);
+
+ // Only promote types for binary functions
+ if (values->size() == 2) {
+ ReplaceNullWithOtherType(values);
+
+ if (auto type = CommonNumeric(*values)) {
+ ReplaceTypes(type, values);
+ }
+ }
+
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+ return arrow::compute::detail::NoMatchingKernel(this, *values);
+ }
+
+ Status CheckDecimals(std::vector<ValueDescr>* values) const {
+ if (!HasDecimal(*values)) return Status::OK();
+
+ if (values->size() == 2) {
+ // "add_checked" -> "add"
+ const auto func_name = name();
+ const std::string op = func_name.substr(0, func_name.find("_"));
+ if (op == "add" || op == "subtract") {
+ return CastBinaryDecimalArgs(DecimalPromotion::kAdd, values);
+ } else if (op == "multiply") {
+ return CastBinaryDecimalArgs(DecimalPromotion::kMultiply, values);
+ } else if (op == "divide") {
+ return CastBinaryDecimalArgs(DecimalPromotion::kDivide, values);
+ } else {
+ return Status::Invalid("Invalid decimal function: ", func_name);
+ }
+ }
+ return Status::OK();
+ }
+};
+
+/// An ArithmeticFunction that promotes only integer arguments to double.
+struct ArithmeticIntegerToFloatingPointFunction : public ArithmeticFunction {
+ using ArithmeticFunction::ArithmeticFunction;
+
+ Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+ RETURN_NOT_OK(CheckArity(*values));
+ RETURN_NOT_OK(CheckDecimals(values));
+
+ using arrow::compute::detail::DispatchExactImpl;
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+
+ EnsureDictionaryDecoded(values);
+
+ if (values->size() == 2) {
+ ReplaceNullWithOtherType(values);
+ }
+
+ for (auto& descr : *values) {
+ if (is_integer(descr.type->id())) {
+ descr.type = float64();
+ }
+ }
+ if (auto type = CommonNumeric(*values)) {
+ ReplaceTypes(type, values);
+ }
+
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+ return arrow::compute::detail::NoMatchingKernel(this, *values);
+ }
+};
+
+/// An ArithmeticFunction that promotes integer arguments to double.
+struct ArithmeticFloatingPointFunction : public ArithmeticFunction {
+ using ArithmeticFunction::ArithmeticFunction;
+
+ Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+ RETURN_NOT_OK(CheckArity(*values));
+ RETURN_NOT_OK(CheckDecimals(values));
+
+ using arrow::compute::detail::DispatchExactImpl;
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+
+ EnsureDictionaryDecoded(values);
+
+ if (values->size() == 2) {
+ ReplaceNullWithOtherType(values);
+ }
+
+ for (auto& descr : *values) {
+ if (is_integer(descr.type->id())) {
+ descr.type = float64();
+ }
+ }
+ if (auto type = CommonNumeric(*values)) {
+ ReplaceTypes(type, values);
+ }
+
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+ return arrow::compute::detail::NoMatchingKernel(this, *values);
+ }
+};
+
+// A scalar kernel that ignores (assumed all-null) inputs and returns null.
+Status NullToNullExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return Status::OK();
+}
+
+void AddNullExec(ScalarFunction* func) {
+ std::vector<InputType> input_types(func->arity().num_args, InputType(Type::NA));
+ DCHECK_OK(func->AddKernel(std::move(input_types), OutputType(null()), NullToNullExec));
+}
+
+template <typename Op>
+std::shared_ptr<ScalarFunction> MakeArithmeticFunction(std::string name,
+ const FunctionDoc* doc) {
+ auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc);
+ for (const auto& ty : NumericTypes()) {
+ auto exec = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Op>(ty);
+ DCHECK_OK(func->AddKernel({ty, ty}, ty, exec));
+ }
+ AddNullExec(func.get());
+ return func;
+}
+
+// Like MakeArithmeticFunction, but for arithmetic ops that need to run
+// only on non-null output.
+template <typename Op>
+std::shared_ptr<ScalarFunction> MakeArithmeticFunctionNotNull(std::string name,
+ const FunctionDoc* doc) {
+ auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc);
+ for (const auto& ty : NumericTypes()) {
+ auto exec = ArithmeticExecFromOp<ScalarBinaryNotNullEqualTypes, Op>(ty);
+ DCHECK_OK(func->AddKernel({ty, ty}, ty, exec));
+ }
+ AddNullExec(func.get());
+ return func;
+}
+
+template <typename Op>
+std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunction(std::string name,
+ const FunctionDoc* doc) {
+ auto func = std::make_shared<ArithmeticFunction>(name, Arity::Unary(), doc);
+ for (const auto& ty : NumericTypes()) {
+ auto exec = ArithmeticExecFromOp<ScalarUnary, Op>(ty);
+ DCHECK_OK(func->AddKernel({ty}, ty, exec));
+ }
+ AddNullExec(func.get());
+ return func;
+}
+
+// Like MakeUnaryArithmeticFunction, but for unary arithmetic ops with a fixed
+// output type for integral inputs.
+template <typename Op, typename IntOutType>
+std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionWithFixedIntOutType(
+ std::string name, const FunctionDoc* doc) {
+ auto int_out_ty = TypeTraits<IntOutType>::type_singleton();
+ auto func = std::make_shared<ArithmeticFunction>(name, Arity::Unary(), doc);
+ for (const auto& ty : NumericTypes()) {
+ auto out_ty = arrow::is_floating(ty->id()) ? ty : int_out_ty;
+ auto exec = GenerateArithmeticWithFixedIntOutType<ScalarUnary, IntOutType, Op>(ty);
+ DCHECK_OK(func->AddKernel({ty}, out_ty, exec));
+ }
+ AddNullExec(func.get());
+ return func;
+}
+
+// Like MakeUnaryArithmeticFunction, but for arithmetic ops that need to run
+// only on non-null output.
+template <typename Op>
+std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionNotNull(
+ std::string name, const FunctionDoc* doc) {
+ auto func = std::make_shared<ArithmeticFunction>(name, Arity::Unary(), doc);
+ for (const auto& ty : NumericTypes()) {
+ auto exec = ArithmeticExecFromOp<ScalarUnaryNotNull, Op>(ty);
+ DCHECK_OK(func->AddKernel({ty}, ty, exec));
+ }
+ AddNullExec(func.get());
+ return func;
+}
+
+// Exec the round kernel for the given types
+template <typename Type, typename OptionsType,
+ template <typename, RoundMode, typename...> class OpImpl>
+Status ExecRound(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ using State = RoundOptionsWrapper<OptionsType>;
+ const auto& state = static_cast<const State&>(*ctx->state());
+ switch (state.options.round_mode) {
+ case RoundMode::DOWN: {
+ using Op = OpImpl<Type, RoundMode::DOWN>;
+ return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::UP: {
+ using Op = OpImpl<Type, RoundMode::UP>;
+ return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::TOWARDS_ZERO: {
+ using Op = OpImpl<Type, RoundMode::TOWARDS_ZERO>;
+ return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::TOWARDS_INFINITY: {
+ using Op = OpImpl<Type, RoundMode::TOWARDS_INFINITY>;
+ return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::HALF_DOWN: {
+ using Op = OpImpl<Type, RoundMode::HALF_DOWN>;
+ return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::HALF_UP: {
+ using Op = OpImpl<Type, RoundMode::HALF_UP>;
+ return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::HALF_TOWARDS_ZERO: {
+ using Op = OpImpl<Type, RoundMode::HALF_TOWARDS_ZERO>;
+ return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::HALF_TOWARDS_INFINITY: {
+ using Op = OpImpl<Type, RoundMode::HALF_TOWARDS_INFINITY>;
+ return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::HALF_TO_EVEN: {
+ using Op = OpImpl<Type, RoundMode::HALF_TO_EVEN>;
+ return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::HALF_TO_ODD: {
+ using Op = OpImpl<Type, RoundMode::HALF_TO_ODD>;
+ return ScalarUnaryNotNullStateful<Type, Type, Op>(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ }
+ DCHECK(false);
+ return Status::NotImplemented(
+ "Internal implementation error: round mode not implemented: ",
+ state.options.ToString());
+}
+
+// Like MakeUnaryArithmeticFunction, but for unary rounding functions that control
+// kernel dispatch based on RoundMode, only on non-null output.
+template <template <typename, RoundMode, typename...> class Op, typename OptionsType>
+std::shared_ptr<ScalarFunction> MakeUnaryRoundFunction(std::string name,
+ const FunctionDoc* doc) {
+ using State = RoundOptionsWrapper<OptionsType>;
+ static const OptionsType kDefaultOptions = OptionsType::Defaults();
+ auto func = std::make_shared<ArithmeticIntegerToFloatingPointFunction>(
+ name, Arity::Unary(), doc, &kDefaultOptions);
+ for (const auto& ty : {float32(), float64(), decimal128(1, 0), decimal256(1, 0)}) {
+ auto type_id = ty->id();
+ auto exec = [type_id](KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ switch (type_id) {
+ case Type::FLOAT:
+ return ExecRound<FloatType, OptionsType, Op>(ctx, batch, out);
+ case Type::DOUBLE:
+ return ExecRound<DoubleType, OptionsType, Op>(ctx, batch, out);
+ case Type::DECIMAL128:
+ return ExecRound<Decimal128Type, OptionsType, Op>(ctx, batch, out);
+ case Type::DECIMAL256:
+ return ExecRound<Decimal256Type, OptionsType, Op>(ctx, batch, out);
+ default: {
+ DCHECK(false);
+ return ExecFail(ctx, batch, out);
+ }
+ }
+ };
+ DCHECK_OK(func->AddKernel(
+ {InputType(type_id)},
+ is_decimal(type_id) ? OutputType(FirstType) : OutputType(ty), exec, State::Init));
+ }
+ AddNullExec(func.get());
+ return func;
+}
+
+// Like MakeUnaryArithmeticFunction, but for signed arithmetic ops that need to run
+// only on non-null output.
+template <typename Op>
+std::shared_ptr<ScalarFunction> MakeUnarySignedArithmeticFunctionNotNull(
+ std::string name, const FunctionDoc* doc) {
+ auto func = std::make_shared<ArithmeticFunction>(name, Arity::Unary(), doc);
+ for (const auto& ty : NumericTypes()) {
+ if (!arrow::is_unsigned_integer(ty->id())) {
+ auto exec = ArithmeticExecFromOp<ScalarUnaryNotNull, Op>(ty);
+ DCHECK_OK(func->AddKernel({ty}, ty, exec));
+ }
+ }
+ AddNullExec(func.get());
+ return func;
+}
+
+template <typename Op>
+std::shared_ptr<ScalarFunction> MakeBitWiseFunctionNotNull(std::string name,
+ const FunctionDoc* doc) {
+ auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc);
+ for (const auto& ty : IntTypes()) {
+ auto exec = TypeAgnosticBitWiseExecFromOp<ScalarBinaryNotNullEqualTypes, Op>(ty);
+ DCHECK_OK(func->AddKernel({ty, ty}, ty, exec));
+ }
+ AddNullExec(func.get());
+ return func;
+}
+
+template <typename Op>
+std::shared_ptr<ScalarFunction> MakeShiftFunctionNotNull(std::string name,
+ const FunctionDoc* doc) {
+ auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc);
+ for (const auto& ty : IntTypes()) {
+ auto exec = ShiftExecFromOp<ScalarBinaryNotNullEqualTypes, Op>(ty);
+ DCHECK_OK(func->AddKernel({ty, ty}, ty, exec));
+ }
+ AddNullExec(func.get());
+ return func;
+}
+
+template <typename Op, typename FunctionImpl = ArithmeticFloatingPointFunction>
+std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionFloatingPoint(
+ std::string name, const FunctionDoc* doc) {
+ auto func = std::make_shared<FunctionImpl>(name, Arity::Unary(), doc);
+ for (const auto& ty : FloatingPointTypes()) {
+ auto exec = GenerateArithmeticFloatingPoint<ScalarUnary, Op>(ty);
+ DCHECK_OK(func->AddKernel({ty}, ty, exec));
+ }
+ AddNullExec(func.get());
+ return func;
+}
+
+template <typename Op>
+std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionFloatingPointNotNull(
+ std::string name, const FunctionDoc* doc) {
+ auto func =
+ std::make_shared<ArithmeticFloatingPointFunction>(name, Arity::Unary(), doc);
+ for (const auto& ty : FloatingPointTypes()) {
+ auto exec = GenerateArithmeticFloatingPoint<ScalarUnaryNotNull, Op>(ty);
+ DCHECK_OK(func->AddKernel({ty}, ty, exec));
+ }
+ AddNullExec(func.get());
+ return func;
+}
+
+template <typename Op>
+std::shared_ptr<ScalarFunction> MakeArithmeticFunctionFloatingPoint(
+ std::string name, const FunctionDoc* doc) {
+ auto func =
+ std::make_shared<ArithmeticFloatingPointFunction>(name, Arity::Binary(), doc);
+ for (const auto& ty : FloatingPointTypes()) {
+ auto exec = GenerateArithmeticFloatingPoint<ScalarBinaryEqualTypes, Op>(ty);
+ DCHECK_OK(func->AddKernel({ty, ty}, ty, exec));
+ }
+ AddNullExec(func.get());
+ return func;
+}
+
+template <typename Op>
+std::shared_ptr<ScalarFunction> MakeArithmeticFunctionFloatingPointNotNull(
+ std::string name, const FunctionDoc* doc) {
+ auto func =
+ std::make_shared<ArithmeticFloatingPointFunction>(name, Arity::Binary(), doc);
+ for (const auto& ty : FloatingPointTypes()) {
+ auto output = is_integer(ty->id()) ? float64() : ty;
+ auto exec = GenerateArithmeticFloatingPoint<ScalarBinaryNotNullEqualTypes, Op>(ty);
+ DCHECK_OK(func->AddKernel({ty, ty}, output, exec));
+ }
+ AddNullExec(func.get());
+ return func;
+}
+
+const FunctionDoc absolute_value_doc{
+ "Calculate the absolute value of the argument element-wise",
+ ("Results will wrap around on integer overflow.\n"
+ "Use function \"abs_checked\" if you want overflow\n"
+ "to return an error."),
+ {"x"}};
+
+const FunctionDoc absolute_value_checked_doc{
+ "Calculate the absolute value of the argument element-wise",
+ ("This function returns an error on overflow. For a variant that\n"
+ "doesn't fail on overflow, use function \"abs\"."),
+ {"x"}};
+
+const FunctionDoc add_doc{"Add the arguments element-wise",
+ ("Results will wrap around on integer overflow.\n"
+ "Use function \"add_checked\" if you want overflow\n"
+ "to return an error."),
+ {"x", "y"}};
+
+const FunctionDoc add_checked_doc{
+ "Add the arguments element-wise",
+ ("This function returns an error on overflow. For a variant that\n"
+ "doesn't fail on overflow, use function \"add\"."),
+ {"x", "y"}};
+
+const FunctionDoc sub_doc{"Subtract the arguments element-wise",
+ ("Results will wrap around on integer overflow.\n"
+ "Use function \"subtract_checked\" if you want overflow\n"
+ "to return an error."),
+ {"x", "y"}};
+
+const FunctionDoc sub_checked_doc{
+ "Subtract the arguments element-wise",
+ ("This function returns an error on overflow. For a variant that\n"
+ "doesn't fail on overflow, use function \"subtract\"."),
+ {"x", "y"}};
+
+const FunctionDoc mul_doc{"Multiply the arguments element-wise",
+ ("Results will wrap around on integer overflow.\n"
+ "Use function \"multiply_checked\" if you want overflow\n"
+ "to return an error."),
+ {"x", "y"}};
+
+const FunctionDoc mul_checked_doc{
+ "Multiply the arguments element-wise",
+ ("This function returns an error on overflow. For a variant that\n"
+ "doesn't fail on overflow, use function \"multiply\"."),
+ {"x", "y"}};
+
+const FunctionDoc div_doc{
+ "Divide the arguments element-wise",
+ ("Integer division by zero returns an error. However, integer overflow\n"
+ "wraps around, and floating-point division by zero returns an infinite.\n"
+ "Use function \"divide_checked\" if you want to get an error\n"
+ "in all the aforementioned cases."),
+ {"dividend", "divisor"}};
+
+const FunctionDoc div_checked_doc{
+ "Divide the arguments element-wise",
+ ("An error is returned when trying to divide by zero, or when\n"
+ "integer overflow is encountered."),
+ {"dividend", "divisor"}};
+
+const FunctionDoc negate_doc{"Negate the argument element-wise",
+ ("Results will wrap around on integer overflow.\n"
+ "Use function \"negate_checked\" if you want overflow\n"
+ "to return an error."),
+ {"x"}};
+
+const FunctionDoc negate_checked_doc{
+ "Negate the arguments element-wise",
+ ("This function returns an error on overflow. For a variant that\n"
+ "doesn't fail on overflow, use function \"negate\"."),
+ {"x"}};
+
+const FunctionDoc pow_doc{
+ "Raise arguments to power element-wise",
+ ("Integer to negative integer power returns an error. However, integer overflow\n"
+ "wraps around. If either base or exponent is null the result will be null."),
+ {"base", "exponent"}};
+
+const FunctionDoc pow_checked_doc{
+ "Raise arguments to power element-wise",
+ ("An error is returned when integer to negative integer power is encountered,\n"
+ "or integer overflow is encountered."),
+ {"base", "exponent"}};
+
+const FunctionDoc sign_doc{
+ "Get the signedness of the arguments element-wise",
+ ("Output is any of (-1,1) for nonzero inputs and 0 for zero input.\n"
+ "NaN values return NaN. Integral values return signedness as Int8 and\n"
+ "floating-point values return it with the same type as the input values."),
+ {"x"}};
+
+const FunctionDoc bit_wise_not_doc{
+ "Bit-wise negate the arguments element-wise", "Null values return null.", {"x"}};
+
+const FunctionDoc bit_wise_and_doc{
+ "Bit-wise AND the arguments element-wise", "Null values return null.", {"x", "y"}};
+
+const FunctionDoc bit_wise_or_doc{
+ "Bit-wise OR the arguments element-wise", "Null values return null.", {"x", "y"}};
+
+const FunctionDoc bit_wise_xor_doc{
+ "Bit-wise XOR the arguments element-wise", "Null values return null.", {"x", "y"}};
+
+const FunctionDoc shift_left_doc{
+ "Left shift `x` by `y`",
+ ("This function will return `x` if `y` (the amount to shift by) is: "
+ "(1) negative or (2) greater than or equal to the precision of `x`.\n"
+ "The shift operates as if on the two's complement representation of the number. "
+ "In other words, this is equivalent to multiplying `x` by 2 to the power `y`, "
+ "even if overflow occurs.\n"
+ "Use function \"shift_left_checked\" if you want an invalid shift amount to "
+ "return an error."),
+ {"x", "y"}};
+
+const FunctionDoc shift_left_checked_doc{
+ "Left shift `x` by `y` with invalid shift check",
+ ("This function will raise an error if `y` (the amount to shift by) is: "
+ "(1) negative or (2) greater than or equal to the precision of `x`. "
+ "The shift operates as if on the two's complement representation of the number. "
+ "In other words, this is equivalent to multiplying `x` by 2 to the power `y`, "
+ "even if overflow occurs.\n"
+ "See \"shift_left\" for a variant that doesn't fail for an invalid shift amount."),
+ {"x", "y"}};
+
+const FunctionDoc shift_right_doc{
+ "Right shift `x` by `y`",
+ ("Perform a logical shift for unsigned `x` and an arithmetic shift for signed `x`.\n"
+ "This function will return `x` if `y` (the amount to shift by) is: "
+ "(1) negative or (2) greater than or equal to the precision of `x`.\n"
+ "Use function \"shift_right_checked\" if you want an invalid shift amount to return "
+ "an error."),
+ {"x", "y"}};
+
+const FunctionDoc shift_right_checked_doc{
+ "Right shift `x` by `y` with invalid shift check",
+ ("Perform a logical shift for unsigned `x` and an arithmetic shift for signed `x`.\n"
+ "This function will raise an error if `y` (the amount to shift by) is: "
+ "(1) negative or (2) greater than or equal to the precision of `x`.\n"
+ "See \"shift_right\" for a variant that doesn't fail for an invalid shift amount"),
+ {"x", "y"}};
+
+const FunctionDoc sin_doc{"Compute the sine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function returns NaN on values outside its domain. "
+ "To raise an error instead, see \"sin_checked\"."),
+ {"x"}};
+
+const FunctionDoc sin_checked_doc{
+ "Compute the sine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function raises an error on values outside its domain. "
+ "To return NaN instead, see \"sin\"."),
+ {"x"}};
+
+const FunctionDoc cos_doc{"Compute the cosine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function returns NaN on values outside its domain. "
+ "To raise an error instead, see \"cos_checked\"."),
+ {"x"}};
+
+const FunctionDoc cos_checked_doc{
+ "Compute the cosine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function raises an error on values outside its domain. "
+ "To return NaN instead, see \"cos\"."),
+ {"x"}};
+
+const FunctionDoc tan_doc{"Compute the tangent of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function returns NaN on values outside its domain. "
+ "To raise an error instead, see \"tan_checked\"."),
+ {"x"}};
+
+const FunctionDoc tan_checked_doc{
+ "Compute the tangent of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function raises an error on values outside its domain. "
+ "To return NaN instead, see \"tan\"."),
+ {"x"}};
+
+const FunctionDoc asin_doc{"Compute the inverse sine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function returns NaN on values outside its domain. "
+ "To raise an error instead, see \"asin_checked\"."),
+ {"x"}};
+
+const FunctionDoc asin_checked_doc{
+ "Compute the inverse sine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function raises an error on values outside its domain. "
+ "To return NaN instead, see \"asin\"."),
+ {"x"}};
+
+const FunctionDoc acos_doc{"Compute the inverse cosine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function returns NaN on values outside its domain. "
+ "To raise an error instead, see \"acos_checked\"."),
+ {"x"}};
+
+const FunctionDoc acos_checked_doc{
+ "Compute the inverse cosine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function raises an error on values outside its domain. "
+ "To return NaN instead, see \"acos\"."),
+ {"x"}};
+
+const FunctionDoc atan_doc{"Compute the principal value of the inverse tangent",
+ "Integer arguments return double values.",
+ {"x"}};
+
+const FunctionDoc atan2_doc{
+ "Compute the inverse tangent using argument signs to determine the quadrant",
+ "Integer arguments return double values.",
+ {"y", "x"}};
+
+const FunctionDoc ln_doc{
+ "Compute natural log of arguments element-wise",
+ ("Non-positive values return -inf or NaN. Null values return null.\n"
+ "Use function \"ln_checked\" if you want non-positive values to raise an error."),
+ {"x"}};
+
+const FunctionDoc ln_checked_doc{
+ "Compute natural log of arguments element-wise",
+ ("Non-positive values return -inf or NaN. Null values return null.\n"
+ "Use function \"ln\" if you want non-positive values to return "
+ "-inf or NaN."),
+ {"x"}};
+
+const FunctionDoc log10_doc{
+ "Compute log base 10 of arguments element-wise",
+ ("Non-positive values return -inf or NaN. Null values return null.\n"
+ "Use function \"log10_checked\" if you want non-positive values to raise an error."),
+ {"x"}};
+
+const FunctionDoc log10_checked_doc{
+ "Compute log base 10 of arguments element-wise",
+ ("Non-positive values return -inf or NaN. Null values return null.\n"
+ "Use function \"log10\" if you want non-positive values to return "
+ "-inf or NaN."),
+ {"x"}};
+
+const FunctionDoc log2_doc{
+ "Compute log base 2 of arguments element-wise",
+ ("Non-positive values return -inf or NaN. Null values return null.\n"
+ "Use function \"log2_checked\" if you want non-positive values to raise an error."),
+ {"x"}};
+
+const FunctionDoc log2_checked_doc{
+ "Compute log base 2 of arguments element-wise",
+ ("Non-positive values return -inf or NaN. Null values return null.\n"
+ "Use function \"log2\" if you want non-positive values to return "
+ "-inf or NaN."),
+ {"x"}};
+
+const FunctionDoc log1p_doc{
+ "Compute natural log of (1+x) element-wise",
+ ("Values <= -1 return -inf or NaN. Null values return null.\n"
+ "This function may be more precise than log(1 + x) for x close to zero."
+ "Use function \"log1p_checked\" if you want non-positive values to raise an error."),
+ {"x"}};
+
+const FunctionDoc log1p_checked_doc{
+ "Compute natural log of (1+x) element-wise",
+ ("Values <= -1 return -inf or NaN. Null values return null.\n"
+ "This function may be more precise than log(1 + x) for x close to zero."
+ "Use function \"log1p\" if you want non-positive values to return "
+ "-inf or NaN."),
+ {"x"}};
+
+const FunctionDoc logb_doc{
+ "Compute log of x to base b of arguments element-wise",
+ ("Values <= 0 return -inf or NaN. Null values return null.\n"
+ "Use function \"logb_checked\" if you want non-positive values to raise an error."),
+ {"x", "b"}};
+
+const FunctionDoc logb_checked_doc{
+ "Compute log of x to base b of arguments element-wise",
+ ("Values <= 0 return -inf or NaN. Null values return null.\n"
+ "Use function \"logb\" if you want non-positive values to return "
+ "-inf or NaN."),
+ {"x", "b"}};
+
+const FunctionDoc floor_doc{
+ "Round down to the nearest integer",
+ ("Calculate the nearest integer less than or equal in magnitude to the "
+ "argument element-wise"),
+ {"x"}};
+
+const FunctionDoc ceil_doc{
+ "Round up to the nearest integer",
+ ("Calculate the nearest integer greater than or equal in magnitude to the "
+ "argument element-wise"),
+ {"x"}};
+
+const FunctionDoc trunc_doc{
+ "Get the integral part without fractional digits",
+ ("Calculate the nearest integer not greater in magnitude than to the "
+ "argument element-wise."),
+ {"x"}};
+
+const FunctionDoc round_doc{
+ "Round to a given precision",
+ ("Options are used to control the number of digits and rounding mode.\n"
+ "Default behavior is to round to the nearest integer and use half-to-even "
+ "rule to break ties."),
+ {"x"},
+ "RoundOptions"};
+
+const FunctionDoc round_to_multiple_doc{
+ "Round to a given multiple",
+ ("Options are used to control the rounding multiple and rounding mode.\n"
+ "Default behavior is to round to the nearest integer and use half-to-even "
+ "rule to break ties."),
+ {"x"},
+ "RoundToMultipleOptions"};
+} // namespace
+
+void RegisterScalarArithmetic(FunctionRegistry* registry) {
+ // ----------------------------------------------------------------------
+ auto absolute_value =
+ MakeUnaryArithmeticFunction<AbsoluteValue>("abs", &absolute_value_doc);
+ DCHECK_OK(registry->AddFunction(std::move(absolute_value)));
+
+ // ----------------------------------------------------------------------
+ auto absolute_value_checked = MakeUnaryArithmeticFunctionNotNull<AbsoluteValueChecked>(
+ "abs_checked", &absolute_value_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(absolute_value_checked)));
+
+ // ----------------------------------------------------------------------
+ auto add = MakeArithmeticFunction<Add>("add", &add_doc);
+ AddDecimalBinaryKernels<Add>("add", &add);
+ DCHECK_OK(registry->AddFunction(std::move(add)));
+
+ // ----------------------------------------------------------------------
+ auto add_checked =
+ MakeArithmeticFunctionNotNull<AddChecked>("add_checked", &add_checked_doc);
+ AddDecimalBinaryKernels<AddChecked>("add_checked", &add_checked);
+ DCHECK_OK(registry->AddFunction(std::move(add_checked)));
+
+ // ----------------------------------------------------------------------
+ auto subtract = MakeArithmeticFunction<Subtract>("subtract", &sub_doc);
+ AddDecimalBinaryKernels<Subtract>("subtract", &subtract);
+
+ // Add subtract(timestamp, timestamp) -> duration
+ for (auto unit : TimeUnit::values()) {
+ InputType in_type(match::TimestampTypeUnit(unit));
+ auto exec = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Subtract>(Type::TIMESTAMP);
+ DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
+ }
+
+ DCHECK_OK(registry->AddFunction(std::move(subtract)));
+
+ // ----------------------------------------------------------------------
+ auto subtract_checked = MakeArithmeticFunctionNotNull<SubtractChecked>(
+ "subtract_checked", &sub_checked_doc);
+ AddDecimalBinaryKernels<SubtractChecked>("subtract_checked", &subtract_checked);
+ DCHECK_OK(registry->AddFunction(std::move(subtract_checked)));
+
+ // ----------------------------------------------------------------------
+ auto multiply = MakeArithmeticFunction<Multiply>("multiply", &mul_doc);
+ AddDecimalBinaryKernels<Multiply>("multiply", &multiply);
+ DCHECK_OK(registry->AddFunction(std::move(multiply)));
+
+ // ----------------------------------------------------------------------
+ auto multiply_checked = MakeArithmeticFunctionNotNull<MultiplyChecked>(
+ "multiply_checked", &mul_checked_doc);
+ AddDecimalBinaryKernels<MultiplyChecked>("multiply_checked", &multiply_checked);
+ DCHECK_OK(registry->AddFunction(std::move(multiply_checked)));
+
+ // ----------------------------------------------------------------------
+ auto divide = MakeArithmeticFunctionNotNull<Divide>("divide", &div_doc);
+ AddDecimalBinaryKernels<Divide>("divide", &divide);
+ DCHECK_OK(registry->AddFunction(std::move(divide)));
+
+ // ----------------------------------------------------------------------
+ auto divide_checked =
+ MakeArithmeticFunctionNotNull<DivideChecked>("divide_checked", &div_checked_doc);
+ AddDecimalBinaryKernels<DivideChecked>("divide_checked", &divide_checked);
+ DCHECK_OK(registry->AddFunction(std::move(divide_checked)));
+
+ // ----------------------------------------------------------------------
+ auto negate = MakeUnaryArithmeticFunction<Negate>("negate", &negate_doc);
+ DCHECK_OK(registry->AddFunction(std::move(negate)));
+
+ // ----------------------------------------------------------------------
+ auto negate_checked = MakeUnarySignedArithmeticFunctionNotNull<NegateChecked>(
+ "negate_checked", &negate_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(negate_checked)));
+
+ // ----------------------------------------------------------------------
+ auto power = MakeArithmeticFunction<Power>("power", &pow_doc);
+ DCHECK_OK(registry->AddFunction(std::move(power)));
+
+ // ----------------------------------------------------------------------
+ auto power_checked =
+ MakeArithmeticFunctionNotNull<PowerChecked>("power_checked", &pow_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(power_checked)));
+
+ // ----------------------------------------------------------------------
+ auto sign =
+ MakeUnaryArithmeticFunctionWithFixedIntOutType<Sign, Int8Type>("sign", &sign_doc);
+ DCHECK_OK(registry->AddFunction(std::move(sign)));
+
+ // ----------------------------------------------------------------------
+ // Bitwise functions
+ {
+ auto bit_wise_not = std::make_shared<ArithmeticFunction>(
+ "bit_wise_not", Arity::Unary(), &bit_wise_not_doc);
+ for (const auto& ty : IntTypes()) {
+ auto exec = TypeAgnosticBitWiseExecFromOp<ScalarUnaryNotNull, BitWiseNot>(ty);
+ DCHECK_OK(bit_wise_not->AddKernel({ty}, ty, exec));
+ }
+ AddNullExec(bit_wise_not.get());
+ DCHECK_OK(registry->AddFunction(std::move(bit_wise_not)));
+ }
+
+ auto bit_wise_and =
+ MakeBitWiseFunctionNotNull<BitWiseAnd>("bit_wise_and", &bit_wise_and_doc);
+ DCHECK_OK(registry->AddFunction(std::move(bit_wise_and)));
+
+ auto bit_wise_or =
+ MakeBitWiseFunctionNotNull<BitWiseOr>("bit_wise_or", &bit_wise_or_doc);
+ DCHECK_OK(registry->AddFunction(std::move(bit_wise_or)));
+
+ auto bit_wise_xor =
+ MakeBitWiseFunctionNotNull<BitWiseXor>("bit_wise_xor", &bit_wise_xor_doc);
+ DCHECK_OK(registry->AddFunction(std::move(bit_wise_xor)));
+
+ auto shift_left = MakeShiftFunctionNotNull<ShiftLeft>("shift_left", &shift_left_doc);
+ DCHECK_OK(registry->AddFunction(std::move(shift_left)));
+
+ auto shift_left_checked = MakeShiftFunctionNotNull<ShiftLeftChecked>(
+ "shift_left_checked", &shift_left_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(shift_left_checked)));
+
+ auto shift_right =
+ MakeShiftFunctionNotNull<ShiftRight>("shift_right", &shift_right_doc);
+ DCHECK_OK(registry->AddFunction(std::move(shift_right)));
+
+ auto shift_right_checked = MakeShiftFunctionNotNull<ShiftRightChecked>(
+ "shift_right_checked", &shift_right_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(shift_right_checked)));
+
+ // ----------------------------------------------------------------------
+ // Trig functions
+ auto sin = MakeUnaryArithmeticFunctionFloatingPoint<Sin>("sin", &sin_doc);
+ DCHECK_OK(registry->AddFunction(std::move(sin)));
+
+ auto sin_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<SinChecked>(
+ "sin_checked", &sin_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(sin_checked)));
+
+ auto cos = MakeUnaryArithmeticFunctionFloatingPoint<Cos>("cos", &cos_doc);
+ DCHECK_OK(registry->AddFunction(std::move(cos)));
+
+ auto cos_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<CosChecked>(
+ "cos_checked", &cos_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(cos_checked)));
+
+ auto tan = MakeUnaryArithmeticFunctionFloatingPoint<Tan>("tan", &tan_doc);
+ DCHECK_OK(registry->AddFunction(std::move(tan)));
+
+ auto tan_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<TanChecked>(
+ "tan_checked", &tan_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(tan_checked)));
+
+ auto asin = MakeUnaryArithmeticFunctionFloatingPoint<Asin>("asin", &asin_doc);
+ DCHECK_OK(registry->AddFunction(std::move(asin)));
+
+ auto asin_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<AsinChecked>(
+ "asin_checked", &asin_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(asin_checked)));
+
+ auto acos = MakeUnaryArithmeticFunctionFloatingPoint<Acos>("acos", &acos_doc);
+ DCHECK_OK(registry->AddFunction(std::move(acos)));
+
+ auto acos_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<AcosChecked>(
+ "acos_checked", &acos_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(acos_checked)));
+
+ auto atan = MakeUnaryArithmeticFunctionFloatingPoint<Atan>("atan", &atan_doc);
+ DCHECK_OK(registry->AddFunction(std::move(atan)));
+
+ auto atan2 = MakeArithmeticFunctionFloatingPoint<Atan2>("atan2", &atan2_doc);
+ DCHECK_OK(registry->AddFunction(std::move(atan2)));
+
+ // ----------------------------------------------------------------------
+ // Logarithms
+ auto ln = MakeUnaryArithmeticFunctionFloatingPoint<LogNatural>("ln", &ln_doc);
+ DCHECK_OK(registry->AddFunction(std::move(ln)));
+
+ auto ln_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<LogNaturalChecked>(
+ "ln_checked", &ln_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(ln_checked)));
+
+ auto log10 = MakeUnaryArithmeticFunctionFloatingPoint<Log10>("log10", &log10_doc);
+ DCHECK_OK(registry->AddFunction(std::move(log10)));
+
+ auto log10_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<Log10Checked>(
+ "log10_checked", &log10_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(log10_checked)));
+
+ auto log2 = MakeUnaryArithmeticFunctionFloatingPoint<Log2>("log2", &log2_doc);
+ DCHECK_OK(registry->AddFunction(std::move(log2)));
+
+ auto log2_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<Log2Checked>(
+ "log2_checked", &log2_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(log2_checked)));
+
+ auto log1p = MakeUnaryArithmeticFunctionFloatingPoint<Log1p>("log1p", &log1p_doc);
+ DCHECK_OK(registry->AddFunction(std::move(log1p)));
+
+ auto log1p_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull<Log1pChecked>(
+ "log1p_checked", &log1p_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(log1p_checked)));
+
+ auto logb = MakeArithmeticFunctionFloatingPoint<Logb>("logb", &logb_doc);
+ DCHECK_OK(registry->AddFunction(std::move(logb)));
+
+ auto logb_checked = MakeArithmeticFunctionFloatingPointNotNull<LogbChecked>(
+ "logb_checked", &logb_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(logb_checked)));
+
+ // ----------------------------------------------------------------------
+ // Rounding functions
+ auto floor =
+ MakeUnaryArithmeticFunctionFloatingPoint<Floor,
+ ArithmeticIntegerToFloatingPointFunction>(
+ "floor", &floor_doc);
+ DCHECK_OK(floor->AddKernel(
+ {InputType(Type::DECIMAL128)}, OutputType(FirstType),
+ FixedRoundDecimalExec<Decimal128Type, RoundMode::DOWN, /*ndigits=*/0>));
+ DCHECK_OK(floor->AddKernel(
+ {InputType(Type::DECIMAL256)}, OutputType(FirstType),
+ FixedRoundDecimalExec<Decimal256Type, RoundMode::DOWN, /*ndigits=*/0>));
+ DCHECK_OK(registry->AddFunction(std::move(floor)));
+
+ auto ceil =
+ MakeUnaryArithmeticFunctionFloatingPoint<Ceil,
+ ArithmeticIntegerToFloatingPointFunction>(
+ "ceil", &ceil_doc);
+ DCHECK_OK(ceil->AddKernel(
+ {InputType(Type::DECIMAL128)}, OutputType(FirstType),
+ FixedRoundDecimalExec<Decimal128Type, RoundMode::UP, /*ndigits=*/0>));
+ DCHECK_OK(ceil->AddKernel(
+ {InputType(Type::DECIMAL256)}, OutputType(FirstType),
+ FixedRoundDecimalExec<Decimal256Type, RoundMode::UP, /*ndigits=*/0>));
+ DCHECK_OK(registry->AddFunction(std::move(ceil)));
+
+ auto trunc =
+ MakeUnaryArithmeticFunctionFloatingPoint<Trunc,
+ ArithmeticIntegerToFloatingPointFunction>(
+ "trunc", &trunc_doc);
+ DCHECK_OK(trunc->AddKernel(
+ {InputType(Type::DECIMAL128)}, OutputType(FirstType),
+ FixedRoundDecimalExec<Decimal128Type, RoundMode::TOWARDS_ZERO, /*ndigits=*/0>));
+ DCHECK_OK(trunc->AddKernel(
+ {InputType(Type::DECIMAL256)}, OutputType(FirstType),
+ FixedRoundDecimalExec<Decimal256Type, RoundMode::TOWARDS_ZERO, /*ndigits=*/0>));
+ DCHECK_OK(registry->AddFunction(std::move(trunc)));
+
+ auto round = MakeUnaryRoundFunction<Round, RoundOptions>("round", &round_doc);
+ DCHECK_OK(registry->AddFunction(std::move(round)));
+
+ auto round_to_multiple =
+ MakeUnaryRoundFunction<RoundToMultiple, RoundToMultipleOptions>(
+ "round_to_multiple", &round_to_multiple_doc);
+ DCHECK_OK(registry->AddFunction(std::move(round_to_multiple)));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic_benchmark.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic_benchmark.cc
new file mode 100644
index 000000000..01d9ec944
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic_benchmark.cc
@@ -0,0 +1,159 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <vector>
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/benchmark_util.h"
+
+namespace arrow {
+namespace compute {
+
+constexpr auto kSeed = 0x94378165;
+
+using BinaryOp = Result<Datum>(const Datum&, const Datum&, ArithmeticOptions,
+ ExecContext*);
+
+// Add explicit overflow-checked shortcuts, for easy benchmark parametering.
+static Result<Datum> AddChecked(const Datum& left, const Datum& right,
+ ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR) {
+ options.check_overflow = true;
+ return Add(left, right, std::move(options), ctx);
+}
+
+static Result<Datum> SubtractChecked(const Datum& left, const Datum& right,
+ ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR) {
+ options.check_overflow = true;
+ return Subtract(left, right, std::move(options), ctx);
+}
+
+static Result<Datum> MultiplyChecked(const Datum& left, const Datum& right,
+ ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR) {
+ options.check_overflow = true;
+ return Multiply(left, right, std::move(options), ctx);
+}
+
+static Result<Datum> DivideChecked(const Datum& left, const Datum& right,
+ ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR) {
+ options.check_overflow = true;
+ return Divide(left, right, std::move(options), ctx);
+}
+
+template <BinaryOp& Op, typename ArrowType, typename CType = typename ArrowType::c_type>
+static void ArrayScalarKernel(benchmark::State& state) {
+ RegressionArgs args(state);
+
+ const int64_t array_size = args.size / sizeof(CType);
+
+ // Choose values so as to avoid overflow on all ops and types
+ auto min = static_cast<CType>(6);
+ auto max = static_cast<CType>(min + 15);
+ Datum rhs(static_cast<CType>(6));
+
+ auto rand = random::RandomArrayGenerator(kSeed);
+ auto lhs = std::static_pointer_cast<NumericArray<ArrowType>>(
+ rand.Numeric<ArrowType>(array_size, min, max, args.null_proportion));
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(Op(lhs, rhs, ArithmeticOptions(), nullptr).status());
+ }
+ state.SetItemsProcessed(state.iterations() * array_size);
+}
+
+template <BinaryOp& Op, typename ArrowType, typename CType = typename ArrowType::c_type>
+static void ArrayArrayKernel(benchmark::State& state) {
+ RegressionArgs args(state);
+
+ // Choose values so as to avoid overflow on all ops and types
+ const int64_t array_size = args.size / sizeof(CType);
+ auto rmin = static_cast<CType>(1);
+ auto rmax = static_cast<CType>(rmin + 6); // 7
+ auto lmin = static_cast<CType>(rmax + 1); // 8
+ auto lmax = static_cast<CType>(lmin + 6); // 14
+
+ auto rand = random::RandomArrayGenerator(kSeed);
+ auto lhs = std::static_pointer_cast<NumericArray<ArrowType>>(
+ rand.Numeric<ArrowType>(array_size, lmin, lmax, args.null_proportion));
+ auto rhs = std::static_pointer_cast<NumericArray<ArrowType>>(
+ rand.Numeric<ArrowType>(array_size, rmin, rmax, args.null_proportion));
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(Op(lhs, rhs, ArithmeticOptions(), nullptr).status());
+ }
+ state.SetItemsProcessed(state.iterations() * array_size);
+}
+
+void SetArgs(benchmark::internal::Benchmark* bench) {
+ for (const auto inverse_null_proportion : std::vector<ArgsType>({100, 0})) {
+ bench->Args({static_cast<ArgsType>(kL2Size), inverse_null_proportion});
+ }
+}
+
+#define DECLARE_ARITHMETIC_BENCHMARKS(BENCHMARK, OP) \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, Int64Type)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, Int32Type)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, Int16Type)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, Int8Type)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, UInt64Type)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, UInt32Type)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, UInt16Type)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, UInt8Type)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, FloatType)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, DoubleType)->Apply(SetArgs)
+
+// Checked floating-point variants of arithmetic operations are identical to
+// non-checked variants, so do not bother measuring them.
+
+#define DECLARE_ARITHMETIC_CHECKED_BENCHMARKS(BENCHMARK, OP) \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, Int64Type)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, Int32Type)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, Int16Type)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, Int8Type)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, UInt64Type)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, UInt32Type)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, UInt16Type)->Apply(SetArgs); \
+ BENCHMARK_TEMPLATE(BENCHMARK, OP, UInt8Type)->Apply(SetArgs);
+
+DECLARE_ARITHMETIC_BENCHMARKS(ArrayArrayKernel, Add);
+DECLARE_ARITHMETIC_BENCHMARKS(ArrayScalarKernel, Add);
+DECLARE_ARITHMETIC_BENCHMARKS(ArrayArrayKernel, Subtract);
+DECLARE_ARITHMETIC_BENCHMARKS(ArrayScalarKernel, Subtract);
+DECLARE_ARITHMETIC_BENCHMARKS(ArrayArrayKernel, Multiply);
+DECLARE_ARITHMETIC_BENCHMARKS(ArrayScalarKernel, Multiply);
+DECLARE_ARITHMETIC_BENCHMARKS(ArrayArrayKernel, Divide);
+DECLARE_ARITHMETIC_BENCHMARKS(ArrayScalarKernel, Divide);
+
+DECLARE_ARITHMETIC_CHECKED_BENCHMARKS(ArrayArrayKernel, AddChecked);
+DECLARE_ARITHMETIC_CHECKED_BENCHMARKS(ArrayScalarKernel, AddChecked);
+DECLARE_ARITHMETIC_CHECKED_BENCHMARKS(ArrayArrayKernel, SubtractChecked);
+DECLARE_ARITHMETIC_CHECKED_BENCHMARKS(ArrayScalarKernel, SubtractChecked);
+DECLARE_ARITHMETIC_CHECKED_BENCHMARKS(ArrayArrayKernel, MultiplyChecked);
+DECLARE_ARITHMETIC_CHECKED_BENCHMARKS(ArrayScalarKernel, MultiplyChecked);
+DECLARE_ARITHMETIC_CHECKED_BENCHMARKS(ArrayArrayKernel, DivideChecked);
+DECLARE_ARITHMETIC_CHECKED_BENCHMARKS(ArrayScalarKernel, DivideChecked);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
new file mode 100644
index 000000000..09681b276
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
@@ -0,0 +1,3174 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/math_constants.h"
+#include "arrow/util/string.h"
+
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+
+namespace arrow {
+namespace compute {
+
+using IntegralTypes = testing::Types<Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type,
+ UInt16Type, UInt32Type, UInt64Type>;
+
+using SignedIntegerTypes = testing::Types<Int8Type, Int16Type, Int32Type, Int64Type>;
+
+using UnsignedIntegerTypes =
+ testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type>;
+
+// TODO(kszucs): add half-float
+using FloatingTypes = testing::Types<FloatType, DoubleType>;
+
+// Assert that all-null-type inputs results in a null-type output.
+void AssertNullToNull(const std::string& func_name) {
+ SCOPED_TRACE(func_name);
+ ASSERT_OK_AND_ASSIGN(auto func, GetFunctionRegistry()->GetFunction(func_name));
+ ASSERT_OK_AND_ASSIGN(auto nulls, MakeArrayOfNull(null(), /*length=*/7));
+ const auto n = func->arity().num_args;
+
+ {
+ std::vector<Datum> args(n, nulls);
+ ASSERT_OK_AND_ASSIGN(auto result, CallFunction(func_name, args));
+ AssertArraysEqual(*nulls, *result.make_array(), /*verbose=*/true);
+ }
+
+ {
+ std::vector<Datum> args(n, Datum(std::make_shared<NullScalar>()));
+ ASSERT_OK_AND_ASSIGN(auto result, CallFunction(func_name, args));
+ AssertScalarsEqual(NullScalar(), *result.scalar(), /*verbose=*/true);
+ }
+}
+
+// Construct an array of decimals, where negative scale is allowed.
+//
+// Works around DecimalXXX::FromString intentionally not inferring
+// negative scales.
+std::shared_ptr<Array> DecimalArrayFromJSON(const std::shared_ptr<DataType>& type,
+ const std::string& json) {
+ const auto& ty = checked_cast<const DecimalType&>(*type);
+ if (ty.scale() >= 0) return ArrayFromJSON(type, json);
+ auto p = ty.precision() - ty.scale();
+ auto adjusted_ty = ty.id() == Type::DECIMAL128 ? decimal128(p, 0) : decimal256(p, 0);
+ return Cast(ArrayFromJSON(adjusted_ty, json), type).ValueOrDie().make_array();
+}
+
+template <typename T, typename OptionsType>
+class TestBaseUnaryArithmetic : public TestBase {
+ protected:
+ using ArrowType = T;
+ using CType = typename ArrowType::c_type;
+
+ static std::shared_ptr<DataType> type_singleton() {
+ return TypeTraits<ArrowType>::type_singleton();
+ }
+
+ using UnaryFunction =
+ std::function<Result<Datum>(const Datum&, OptionsType, ExecContext*)>;
+
+ std::shared_ptr<Scalar> MakeNullScalar() {
+ return arrow::MakeNullScalar(type_singleton());
+ }
+
+ std::shared_ptr<Scalar> MakeScalar(CType value) {
+ return *arrow::MakeScalar(type_singleton(), value);
+ }
+
+ void SetUp() override {}
+
+ // (CScalar, CScalar)
+ void AssertUnaryOp(UnaryFunction func, CType argument, CType expected) {
+ auto arg = MakeScalar(argument);
+ auto exp = MakeScalar(expected);
+ ASSERT_OK_AND_ASSIGN(auto actual, func(arg, options_, nullptr));
+ AssertScalarsApproxEqual(*exp, *actual.scalar(), /*verbose=*/true);
+ }
+
+ // (Scalar, Scalar)
+ void AssertUnaryOp(UnaryFunction func, const std::shared_ptr<Scalar>& arg,
+ const std::shared_ptr<Scalar>& expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, func(arg, options_, nullptr));
+ AssertScalarsApproxEqual(*expected, *actual.scalar(), /*verbose=*/true);
+ }
+
+ // (JSON, JSON)
+ void AssertUnaryOp(UnaryFunction func, const std::string& arg_json,
+ const std::string& expected_json) {
+ auto arg = ArrayFromJSON(type_singleton(), arg_json);
+ auto expected = ArrayFromJSON(type_singleton(), expected_json);
+ AssertUnaryOp(func, arg, expected);
+ }
+
+ // (Array, JSON)
+ void AssertUnaryOp(UnaryFunction func, const std::shared_ptr<Array>& arg,
+ const std::string& expected_json) {
+ const auto expected = ArrayFromJSON(type_singleton(), expected_json);
+ AssertUnaryOp(func, arg, expected);
+ }
+
+ // (JSON, Array)
+ void AssertUnaryOp(UnaryFunction func, const std::string& arg_json,
+ const std::shared_ptr<Array>& expected) {
+ auto arg = ArrayFromJSON(type_singleton(), arg_json);
+ AssertUnaryOp(func, arg, expected);
+ }
+
+ // (Array, Array)
+ void AssertUnaryOp(UnaryFunction func, const std::shared_ptr<Array>& arg,
+ const std::shared_ptr<Array>& expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, func(arg, options_, nullptr));
+ ValidateAndAssertApproxEqual(actual.make_array(), expected);
+
+ // Also check (Scalar, Scalar) operations
+ const int64_t length = expected->length();
+ for (int64_t i = 0; i < length; ++i) {
+ const auto expected_scalar = *expected->GetScalar(i);
+ ASSERT_OK_AND_ASSIGN(actual, func(*arg->GetScalar(i), options_, nullptr));
+ AssertScalarsApproxEqual(*expected_scalar, *actual.scalar(), /*verbose=*/true,
+ equal_options_);
+ }
+ }
+
+ void AssertUnaryOpRaises(UnaryFunction func, const std::string& argument,
+ const std::string& expected_msg) {
+ auto arg = ArrayFromJSON(type_singleton(), argument);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(expected_msg),
+ func(arg, options_, nullptr));
+ for (int64_t i = 0; i < arg->length(); i++) {
+ ASSERT_OK_AND_ASSIGN(auto scalar, arg->GetScalar(i));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(expected_msg),
+ func(scalar, options_, nullptr));
+ }
+ }
+
+ void AssertUnaryOpNotImplemented(UnaryFunction func, const std::string& argument) {
+ auto arg = ArrayFromJSON(type_singleton(), argument);
+ const char* expected_msg = "has no kernel matching input types";
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, ::testing::HasSubstr(expected_msg),
+ func(arg, options_, nullptr));
+ }
+
+ void ValidateAndAssertApproxEqual(const std::shared_ptr<Array>& actual,
+ const std::string& expected) {
+ const auto exp = ArrayFromJSON(type_singleton(), expected);
+ ValidateAndAssertApproxEqual(actual, exp);
+ }
+
+ void ValidateAndAssertApproxEqual(const std::shared_ptr<Array>& actual,
+ const std::shared_ptr<Array>& expected) {
+ ValidateOutput(*actual);
+ AssertArraysApproxEqual(*expected, *actual, /*verbose=*/true, equal_options_);
+ }
+
+ void SetNansEqual(bool value = true) {
+ equal_options_ = equal_options_.nans_equal(value);
+ }
+
+ OptionsType options_ = OptionsType();
+ EqualOptions equal_options_ = EqualOptions::Defaults();
+};
+
+// Subclasses of TestBaseUnaryArithmetic for different FunctionOptions.
+template <typename T>
+class TestUnaryArithmetic : public TestBaseUnaryArithmetic<T, ArithmeticOptions> {
+ protected:
+ using Base = TestBaseUnaryArithmetic<T, ArithmeticOptions>;
+ using Base::options_;
+ void SetOverflowCheck(bool value) { options_.check_overflow = value; }
+};
+
+template <typename T>
+class TestUnaryArithmeticIntegral : public TestUnaryArithmetic<T> {};
+
+template <typename T>
+class TestUnaryArithmeticSigned : public TestUnaryArithmeticIntegral<T> {};
+
+template <typename T>
+class TestUnaryArithmeticUnsigned : public TestUnaryArithmeticIntegral<T> {};
+
+template <typename T>
+class TestUnaryArithmeticFloating : public TestUnaryArithmetic<T> {};
+
+template <typename T>
+class TestUnaryRound : public TestBaseUnaryArithmetic<T, RoundOptions> {
+ protected:
+ using Base = TestBaseUnaryArithmetic<T, RoundOptions>;
+ using Base::options_;
+ void SetRoundMode(RoundMode value) { options_.round_mode = value; }
+ void SetRoundNdigits(int64_t value) { options_.ndigits = value; }
+};
+
+template <typename T>
+class TestUnaryRoundIntegral : public TestUnaryRound<T> {};
+
+template <typename T>
+class TestUnaryRoundSigned : public TestUnaryRoundIntegral<T> {};
+
+template <typename T>
+class TestUnaryRoundUnsigned : public TestUnaryRoundIntegral<T> {};
+
+template <typename T>
+class TestUnaryRoundFloating : public TestUnaryRound<T> {};
+
+template <typename T>
+class TestUnaryRoundToMultiple
+ : public TestBaseUnaryArithmetic<T, RoundToMultipleOptions> {
+ protected:
+ using Base = TestBaseUnaryArithmetic<T, RoundToMultipleOptions>;
+ using Base::options_;
+ void SetRoundMode(RoundMode value) { options_.round_mode = value; }
+ void SetRoundMultiple(double value) {
+ options_.multiple = std::make_shared<DoubleScalar>(value);
+ }
+};
+
+template <typename T>
+class TestUnaryRoundToMultipleIntegral : public TestUnaryRoundToMultiple<T> {};
+
+template <typename T>
+class TestUnaryRoundToMultipleSigned : public TestUnaryRoundToMultipleIntegral<T> {};
+
+template <typename T>
+class TestUnaryRoundToMultipleUnsigned : public TestUnaryRoundToMultipleIntegral<T> {};
+
+template <typename T>
+class TestUnaryRoundToMultipleFloating : public TestUnaryRoundToMultiple<T> {};
+
+class TestArithmeticDecimal : public ::testing::Test {
+ protected:
+ std::vector<std::shared_ptr<DataType>> PositiveScaleTypes() {
+ return {decimal128(4, 2), decimal256(4, 2), decimal128(38, 2), decimal256(76, 2)};
+ }
+ std::vector<std::shared_ptr<DataType>> NegativeScaleTypes() {
+ return {decimal128(2, -2), decimal256(2, -2)};
+ }
+
+ // Validate that func(*decimals) is the same as
+ // func([cast(x, float64) x for x in decimals])
+ void CheckDecimalToFloat(const std::string& func, const DatumVector& args) {
+ DatumVector floating_args;
+ for (const auto& arg : args) {
+ if (is_decimal(arg.type()->id())) {
+ ASSERT_OK_AND_ASSIGN(auto casted, Cast(arg, float64()));
+ floating_args.push_back(casted);
+ } else {
+ floating_args.push_back(arg);
+ }
+ }
+ ASSERT_OK_AND_ASSIGN(auto expected, CallFunction(func, floating_args));
+ ASSERT_OK_AND_ASSIGN(auto actual, CallFunction(func, args));
+ auto equal_options = EqualOptions::Defaults().nans_equal(true);
+ AssertDatumsApproxEqual(actual, expected, /*verbose=*/true, equal_options);
+ }
+
+ void CheckRaises(const std::string& func, const DatumVector& args,
+ const std::string& substr, FunctionOptions* options = nullptr) {
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(substr),
+ CallFunction(func, args, options));
+ }
+};
+
+template <typename T>
+class TestBinaryArithmetic : public TestBase {
+ protected:
+ using ArrowType = T;
+ using CType = typename ArrowType::c_type;
+
+ static std::shared_ptr<DataType> type_singleton() {
+ return TypeTraits<ArrowType>::type_singleton();
+ }
+
+ using BinaryFunction = std::function<Result<Datum>(const Datum&, const Datum&,
+ ArithmeticOptions, ExecContext*)>;
+
+ void SetUp() override { options_.check_overflow = false; }
+
+ std::shared_ptr<Scalar> MakeNullScalar() {
+ return arrow::MakeNullScalar(type_singleton());
+ }
+
+ std::shared_ptr<Scalar> MakeScalar(CType value) {
+ return *arrow::MakeScalar(type_singleton(), value);
+ }
+
+ // (Scalar, Scalar)
+ void AssertBinop(BinaryFunction func, CType lhs, CType rhs, CType expected) {
+ auto left = MakeScalar(lhs);
+ auto right = MakeScalar(rhs);
+ auto exp = MakeScalar(expected);
+
+ ASSERT_OK_AND_ASSIGN(auto actual, func(left, right, options_, nullptr));
+ AssertScalarsApproxEqual(*exp, *actual.scalar(), /*verbose=*/true);
+ }
+
+ // (Scalar, Array)
+ void AssertBinop(BinaryFunction func, CType lhs, const std::string& rhs,
+ const std::string& expected) {
+ auto left = MakeScalar(lhs);
+ AssertBinop(func, left, rhs, expected);
+ }
+
+ // (Scalar, Array)
+ void AssertBinop(BinaryFunction func, const std::shared_ptr<Scalar>& left,
+ const std::string& rhs, const std::string& expected) {
+ auto right = ArrayFromJSON(type_singleton(), rhs);
+ auto exp = ArrayFromJSON(type_singleton(), expected);
+
+ ASSERT_OK_AND_ASSIGN(auto actual, func(left, right, options_, nullptr));
+ ValidateAndAssertApproxEqual(actual.make_array(), expected);
+ }
+
+ // (Array, Scalar)
+ void AssertBinop(BinaryFunction func, const std::string& lhs, CType rhs,
+ const std::string& expected) {
+ auto right = MakeScalar(rhs);
+ AssertBinop(func, lhs, right, expected);
+ }
+
+ // (Array, Scalar) => Array
+ void AssertBinop(BinaryFunction func, const std::string& lhs,
+ const std::shared_ptr<Scalar>& right,
+ const std::shared_ptr<Array>& expected) {
+ auto left = ArrayFromJSON(type_singleton(), lhs);
+
+ ASSERT_OK_AND_ASSIGN(auto actual, func(left, right, options_, nullptr));
+ ValidateAndAssertApproxEqual(actual.make_array(), expected);
+ }
+
+ // (Array, Scalar)
+ void AssertBinop(BinaryFunction func, const std::string& lhs,
+ const std::shared_ptr<Scalar>& right, const std::string& expected) {
+ auto left = ArrayFromJSON(type_singleton(), lhs);
+ auto exp = ArrayFromJSON(type_singleton(), expected);
+
+ ASSERT_OK_AND_ASSIGN(auto actual, func(left, right, options_, nullptr));
+ ValidateAndAssertApproxEqual(actual.make_array(), expected);
+ }
+
+ // (Array, Array)
+ void AssertBinop(BinaryFunction func, const std::string& lhs, const std::string& rhs,
+ const std::string& expected) {
+ auto left = ArrayFromJSON(type_singleton(), lhs);
+ auto right = ArrayFromJSON(type_singleton(), rhs);
+
+ AssertBinop(func, left, right, expected);
+ }
+
+ // (Array, Array) => Array
+ void AssertBinop(BinaryFunction func, const std::string& lhs, const std::string& rhs,
+ const std::shared_ptr<Array>& expected) {
+ auto left = ArrayFromJSON(type_singleton(), lhs);
+ auto right = ArrayFromJSON(type_singleton(), rhs);
+
+ AssertBinop(func, left, right, expected);
+ }
+
+ // (Array, Array)
+ void AssertBinop(BinaryFunction func, const std::shared_ptr<Array>& left,
+ const std::shared_ptr<Array>& right,
+ const std::string& expected_json) {
+ const auto expected = ArrayFromJSON(type_singleton(), expected_json);
+ AssertBinop(func, left, right, expected);
+ }
+
+ void AssertBinop(BinaryFunction func, const std::shared_ptr<Array>& left,
+ const std::shared_ptr<Array>& right,
+ const std::shared_ptr<Array>& expected) {
+ ASSERT_OK_AND_ASSIGN(Datum actual, func(left, right, options_, nullptr));
+ ValidateAndAssertApproxEqual(actual.make_array(), expected);
+
+ // Also check (Scalar, Scalar) operations
+ const int64_t length = expected->length();
+ for (int64_t i = 0; i < length; ++i) {
+ const auto expected_scalar = *expected->GetScalar(i);
+ ASSERT_OK_AND_ASSIGN(
+ actual, func(*left->GetScalar(i), *right->GetScalar(i), options_, nullptr));
+ AssertScalarsApproxEqual(*expected_scalar, *actual.scalar(), /*verbose=*/true,
+ equal_options_);
+ }
+ }
+
+ void AssertBinopRaises(BinaryFunction func, const std::string& lhs,
+ const std::string& rhs, const std::string& expected_msg) {
+ auto left = ArrayFromJSON(type_singleton(), lhs);
+ auto right = ArrayFromJSON(type_singleton(), rhs);
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr(expected_msg),
+ func(left, right, options_, nullptr));
+ }
+
+ void ValidateAndAssertApproxEqual(const std::shared_ptr<Array>& actual,
+ const std::string& expected) {
+ ValidateAndAssertApproxEqual(actual, ArrayFromJSON(type_singleton(), expected));
+ }
+
+ void ValidateAndAssertApproxEqual(const std::shared_ptr<Array>& actual,
+ const std::shared_ptr<Array>& expected) {
+ ValidateOutput(*actual);
+ AssertArraysApproxEqual(*expected, *actual, /*verbose=*/true, equal_options_);
+ }
+
+ void SetOverflowCheck(bool value = true) { options_.check_overflow = value; }
+
+ void SetNansEqual(bool value = true) {
+ this->equal_options_ = equal_options_.nans_equal(value);
+ }
+
+ ArithmeticOptions options_ = ArithmeticOptions();
+ EqualOptions equal_options_ = EqualOptions::Defaults();
+};
+
+template <typename... Elements>
+std::string MakeArray(Elements... elements) {
+ std::vector<std::string> elements_as_strings = {std::to_string(elements)...};
+
+ std::vector<util::string_view> elements_as_views(sizeof...(Elements));
+ std::copy(elements_as_strings.begin(), elements_as_strings.end(),
+ elements_as_views.begin());
+
+ return "[" + ::arrow::internal::JoinStrings(elements_as_views, ",") + "]";
+}
+
+template <typename T>
+class TestBinaryArithmeticIntegral : public TestBinaryArithmetic<T> {};
+
+template <typename T>
+class TestBinaryArithmeticSigned : public TestBinaryArithmeticIntegral<T> {};
+
+template <typename T>
+class TestBinaryArithmeticUnsigned : public TestBinaryArithmeticIntegral<T> {};
+
+template <typename T>
+class TestBinaryArithmeticFloating : public TestBinaryArithmetic<T> {};
+
+template <typename T>
+class TestBitWiseArithmetic : public TestBase {
+ protected:
+ using ArrowType = T;
+ using CType = typename ArrowType::c_type;
+
+ static std::shared_ptr<DataType> type_singleton() {
+ return TypeTraits<ArrowType>::type_singleton();
+ }
+
+ void AssertUnaryOp(const std::string& func, const std::vector<uint8_t>& args,
+ const std::vector<uint8_t>& expected) {
+ auto input = ExpandByteArray(args);
+ auto output = ExpandByteArray(expected);
+ ASSERT_OK_AND_ASSIGN(Datum actual, CallFunction(func, {input}));
+ ValidateAndAssertEqual(actual.make_array(), output);
+ for (int64_t i = 0; i < output->length(); i++) {
+ ASSERT_OK_AND_ASSIGN(Datum actual, CallFunction(func, {*input->GetScalar(i)}));
+ const auto expected_scalar = *output->GetScalar(i);
+ AssertScalarsEqual(*expected_scalar, *actual.scalar(), /*verbose=*/true);
+ }
+ }
+
+ void AssertBinaryOp(const std::string& func, const std::vector<uint8_t>& arg0,
+ const std::vector<uint8_t>& arg1,
+ const std::vector<uint8_t>& expected) {
+ auto input0 = ExpandByteArray(arg0);
+ auto input1 = ExpandByteArray(arg1);
+ auto output = ExpandByteArray(expected);
+ ASSERT_OK_AND_ASSIGN(Datum actual, CallFunction(func, {input0, input1}));
+ ValidateAndAssertEqual(actual.make_array(), output);
+ for (int64_t i = 0; i < output->length(); i++) {
+ ASSERT_OK_AND_ASSIGN(Datum actual, CallFunction(func, {*input0->GetScalar(i),
+ *input1->GetScalar(i)}));
+ const auto expected_scalar = *output->GetScalar(i);
+ AssertScalarsEqual(*expected_scalar, *actual.scalar(), /*verbose=*/true);
+ }
+ }
+
+ // To make it easier to test different widths, tests give bytes which get repeated to
+ // make an array of the actual type
+ std::shared_ptr<Array> ExpandByteArray(const std::vector<uint8_t>& values) {
+ std::vector<CType> c_values(values.size() + 1);
+ for (size_t i = 0; i < values.size(); i++) {
+ std::memset(&c_values[i], values[i], sizeof(CType));
+ }
+ std::vector<bool> valid(values.size() + 1, true);
+ valid.back() = false;
+ std::shared_ptr<Array> arr;
+ ArrayFromVector<ArrowType>(valid, c_values, &arr);
+ return arr;
+ }
+
+ void ValidateAndAssertEqual(const std::shared_ptr<Array>& actual,
+ const std::shared_ptr<Array>& expected) {
+ ASSERT_OK(actual->ValidateFull());
+ AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+ }
+};
+
+TYPED_TEST_SUITE(TestUnaryArithmeticIntegral, IntegralTypes);
+TYPED_TEST_SUITE(TestUnaryArithmeticSigned, SignedIntegerTypes);
+TYPED_TEST_SUITE(TestUnaryArithmeticUnsigned, UnsignedIntegerTypes);
+TYPED_TEST_SUITE(TestUnaryArithmeticFloating, FloatingTypes);
+
+TYPED_TEST_SUITE(TestBinaryArithmeticIntegral, IntegralTypes);
+TYPED_TEST_SUITE(TestBinaryArithmeticSigned, SignedIntegerTypes);
+TYPED_TEST_SUITE(TestBinaryArithmeticUnsigned, UnsignedIntegerTypes);
+TYPED_TEST_SUITE(TestBinaryArithmeticFloating, FloatingTypes);
+
+TYPED_TEST_SUITE(TestBitWiseArithmetic, IntegralTypes);
+
+TYPED_TEST(TestBitWiseArithmetic, BitWiseNot) {
+ this->AssertUnaryOp("bit_wise_not", std::vector<uint8_t>{0x00, 0x55, 0xAA, 0xFF},
+ std::vector<uint8_t>{0xFF, 0xAA, 0x55, 0x00});
+}
+
+TYPED_TEST(TestBitWiseArithmetic, BitWiseAnd) {
+ this->AssertBinaryOp("bit_wise_and", std::vector<uint8_t>{0x00, 0xFF, 0x00, 0xFF},
+ std::vector<uint8_t>{0x00, 0x00, 0xFF, 0xFF},
+ std::vector<uint8_t>{0x00, 0x00, 0x00, 0xFF});
+}
+
+TYPED_TEST(TestBitWiseArithmetic, BitWiseOr) {
+ this->AssertBinaryOp("bit_wise_or", std::vector<uint8_t>{0x00, 0xFF, 0x00, 0xFF},
+ std::vector<uint8_t>{0x00, 0x00, 0xFF, 0xFF},
+ std::vector<uint8_t>{0x00, 0xFF, 0xFF, 0xFF});
+}
+
+TYPED_TEST(TestBitWiseArithmetic, BitWiseXor) {
+ this->AssertBinaryOp("bit_wise_xor", std::vector<uint8_t>{0x00, 0xFF, 0x00, 0xFF},
+ std::vector<uint8_t>{0x00, 0x00, 0xFF, 0xFF},
+ std::vector<uint8_t>{0x00, 0xFF, 0xFF, 0x00});
+}
+
+TYPED_TEST(TestBinaryArithmeticIntegral, Add) {
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+
+ this->AssertBinop(Add, "[]", "[]", "[]");
+ this->AssertBinop(Add, "[3, 2, 6]", "[1, 0, 2]", "[4, 2, 8]");
+ // Nulls on left side
+ this->AssertBinop(Add, "[null, 1, null]", "[3, 4, 5]", "[null, 5, null]");
+ this->AssertBinop(Add, "[3, 4, 5]", "[null, 1, null]", "[null, 5, null]");
+ // Nulls on both sides
+ this->AssertBinop(Add, "[null, 1, 2]", "[3, 4, null]", "[null, 5, null]");
+ // All nulls
+ this->AssertBinop(Add, "[null]", "[null]", "[null]");
+
+ // Scalar on the left
+ this->AssertBinop(Add, 3, "[1, 2]", "[4, 5]");
+ this->AssertBinop(Add, 3, "[null, 2]", "[null, 5]");
+ this->AssertBinop(Add, this->MakeNullScalar(), "[1, 2]", "[null, null]");
+ this->AssertBinop(Add, this->MakeNullScalar(), "[null, 2]", "[null, null]");
+ // Scalar on the right
+ this->AssertBinop(Add, "[1, 2]", 3, "[4, 5]");
+ this->AssertBinop(Add, "[null, 2]", 3, "[null, 5]");
+ this->AssertBinop(Add, "[1, 2]", this->MakeNullScalar(), "[null, null]");
+ this->AssertBinop(Add, "[null, 2]", this->MakeNullScalar(), "[null, null]");
+ }
+}
+
+TYPED_TEST(TestBinaryArithmeticIntegral, Sub) {
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+
+ this->AssertBinop(Subtract, "[]", "[]", "[]");
+ this->AssertBinop(Subtract, "[3, 2, 6]", "[1, 0, 2]", "[2, 2, 4]");
+ // Nulls on left side
+ this->AssertBinop(Subtract, "[null, 4, null]", "[2, 1, 0]", "[null, 3, null]");
+ this->AssertBinop(Subtract, "[5, 4, 3]", "[null, 1, null]", "[null, 3, null]");
+ // Nulls on both sides
+ this->AssertBinop(Subtract, "[null, 4, 3]", "[2, 1, null]", "[null, 3, null]");
+ // All nulls
+ this->AssertBinop(Subtract, "[null]", "[null]", "[null]");
+
+ // Scalar on the left
+ this->AssertBinop(Subtract, 3, "[1, 2]", "[2, 1]");
+ this->AssertBinop(Subtract, 3, "[null, 2]", "[null, 1]");
+ this->AssertBinop(Subtract, this->MakeNullScalar(), "[1, 2]", "[null, null]");
+ this->AssertBinop(Subtract, this->MakeNullScalar(), "[null, 2]", "[null, null]");
+ // Scalar on the right
+ this->AssertBinop(Subtract, "[4, 5]", 3, "[1, 2]");
+ this->AssertBinop(Subtract, "[null, 5]", 3, "[null, 2]");
+ this->AssertBinop(Subtract, "[1, 2]", this->MakeNullScalar(), "[null, null]");
+ this->AssertBinop(Subtract, "[null, 2]", this->MakeNullScalar(), "[null, null]");
+ }
+}
+
+TEST(TestBinaryArithmetic, SubtractTimestamps) {
+ random::RandomArrayGenerator rand(kRandomSeed);
+
+ const int64_t length = 100;
+
+ auto lhs = rand.Int64(length, 0, 100000000);
+ auto rhs = rand.Int64(length, 0, 100000000);
+ auto expected_int64 = (*Subtract(lhs, rhs)).make_array();
+
+ for (auto unit : TimeUnit::values()) {
+ auto timestamp_ty = timestamp(unit);
+ auto duration_ty = duration(unit);
+
+ auto lhs_timestamp = *lhs->View(timestamp_ty);
+ auto rhs_timestamp = *rhs->View(timestamp_ty);
+
+ auto result = (*Subtract(lhs_timestamp, rhs_timestamp)).make_array();
+ ASSERT_TRUE(result->type()->Equals(*duration_ty));
+ AssertArraysEqual(**result->View(int64()), *expected_int64);
+ }
+}
+
+TYPED_TEST(TestBinaryArithmeticIntegral, Mul) {
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+
+ this->AssertBinop(Multiply, "[]", "[]", "[]");
+ this->AssertBinop(Multiply, "[3, 2, 6]", "[1, 0, 2]", "[3, 0, 12]");
+ // Nulls on left side
+ this->AssertBinop(Multiply, "[null, 2, null]", "[4, 5, 6]", "[null, 10, null]");
+ this->AssertBinop(Multiply, "[4, 5, 6]", "[null, 2, null]", "[null, 10, null]");
+ // Nulls on both sides
+ this->AssertBinop(Multiply, "[null, 2, 3]", "[4, 5, null]", "[null, 10, null]");
+ // All nulls
+ this->AssertBinop(Multiply, "[null]", "[null]", "[null]");
+
+ // Scalar on the left
+ this->AssertBinop(Multiply, 3, "[4, 5]", "[12, 15]");
+ this->AssertBinop(Multiply, 3, "[null, 5]", "[null, 15]");
+ this->AssertBinop(Multiply, this->MakeNullScalar(), "[1, 2]", "[null, null]");
+ this->AssertBinop(Multiply, this->MakeNullScalar(), "[null, 2]", "[null, null]");
+ // Scalar on the right
+ this->AssertBinop(Multiply, "[4, 5]", 3, "[12, 15]");
+ this->AssertBinop(Multiply, "[null, 5]", 3, "[null, 15]");
+ this->AssertBinop(Multiply, "[1, 2]", this->MakeNullScalar(), "[null, null]");
+ this->AssertBinop(Multiply, "[null, 2]", this->MakeNullScalar(), "[null, null]");
+ }
+}
+
+TYPED_TEST(TestBinaryArithmeticSigned, Add) {
+ this->AssertBinop(Add, "[-7, 6, 5, 4, 3, 2, 1]", "[-6, 5, -4, 3, -2, 1, 0]",
+ "[-13, 11, 1, 7, 1, 3, 1]");
+ this->AssertBinop(Add, -1, "[-6, 5, -4, 3, -2, 1, 0]", "[-7, 4, -5, 2, -3, 0, -1]");
+ this->AssertBinop(Add, -10, 5, -5);
+}
+
+TYPED_TEST(TestBinaryArithmeticSigned, OverflowWraps) {
+ using CType = typename TestFixture::CType;
+
+ auto min = std::numeric_limits<CType>::lowest();
+ auto max = std::numeric_limits<CType>::max();
+
+ this->AssertBinop(Subtract, MakeArray(min, max, min), MakeArray(1, max, max),
+ MakeArray(max, 0, 1));
+ this->AssertBinop(Multiply, MakeArray(min, max, max), MakeArray(max, 2, max),
+ MakeArray(min, CType(-2), 1));
+}
+
+TYPED_TEST(TestBinaryArithmeticIntegral, OverflowRaises) {
+ using CType = typename TestFixture::CType;
+
+ auto min = std::numeric_limits<CType>::lowest();
+ auto max = std::numeric_limits<CType>::max();
+
+ this->SetOverflowCheck(true);
+
+ this->AssertBinopRaises(Add, MakeArray(min, max, max), MakeArray(CType(-1), 1, max),
+ "overflow");
+ this->AssertBinopRaises(Subtract, MakeArray(min, max), MakeArray(1, max), "overflow");
+ this->AssertBinopRaises(Subtract, MakeArray(min), MakeArray(max), "overflow");
+
+ this->AssertBinopRaises(Multiply, MakeArray(min, max, max), MakeArray(max, 2, max),
+ "overflow");
+}
+
+TYPED_TEST(TestBinaryArithmeticSigned, AddOverflowRaises) {
+ using CType = typename TestFixture::CType;
+
+ auto min = std::numeric_limits<CType>::lowest();
+ auto max = std::numeric_limits<CType>::max();
+
+ this->SetOverflowCheck(true);
+
+ this->AssertBinop(Add, MakeArray(max), MakeArray(-1), MakeArray(max - 1));
+ this->AssertBinop(Add, MakeArray(min), MakeArray(1), MakeArray(min + 1));
+ this->AssertBinop(Add, MakeArray(-1), MakeArray(2), MakeArray(1));
+ this->AssertBinop(Add, MakeArray(1), MakeArray(-2), MakeArray(-1));
+
+ this->AssertBinopRaises(Add, MakeArray(max), MakeArray(1), "overflow");
+ this->AssertBinopRaises(Add, MakeArray(min), MakeArray(-1), "overflow");
+
+ // Overflow should not be checked on underlying value slots when output would be null
+ auto left = ArrayFromJSON(this->type_singleton(), MakeArray(1, max, min));
+ auto right = ArrayFromJSON(this->type_singleton(), MakeArray(1, 1, -1));
+ left = TweakValidityBit(left, 1, false);
+ right = TweakValidityBit(right, 2, false);
+ this->AssertBinop(Add, left, right, "[2, null, null]");
+}
+
+TYPED_TEST(TestBinaryArithmeticSigned, SubOverflowRaises) {
+ using CType = typename TestFixture::CType;
+
+ auto min = std::numeric_limits<CType>::lowest();
+ auto max = std::numeric_limits<CType>::max();
+
+ this->SetOverflowCheck(true);
+
+ this->AssertBinop(Subtract, MakeArray(max), MakeArray(1), MakeArray(max - 1));
+ this->AssertBinop(Subtract, MakeArray(min), MakeArray(-1), MakeArray(min + 1));
+ this->AssertBinop(Subtract, MakeArray(-1), MakeArray(-2), MakeArray(1));
+ this->AssertBinop(Subtract, MakeArray(1), MakeArray(2), MakeArray(-1));
+
+ this->AssertBinopRaises(Subtract, MakeArray(max), MakeArray(-1), "overflow");
+ this->AssertBinopRaises(Subtract, MakeArray(min), MakeArray(1), "overflow");
+
+ // Overflow should not be checked on underlying value slots when output would be null
+ auto left = ArrayFromJSON(this->type_singleton(), MakeArray(2, max, min));
+ auto right = ArrayFromJSON(this->type_singleton(), MakeArray(1, -1, 1));
+ left = TweakValidityBit(left, 1, false);
+ right = TweakValidityBit(right, 2, false);
+ this->AssertBinop(Subtract, left, right, "[1, null, null]");
+}
+
+TYPED_TEST(TestBinaryArithmeticSigned, MulOverflowRaises) {
+ using CType = typename TestFixture::CType;
+
+ auto min = std::numeric_limits<CType>::lowest();
+ auto max = std::numeric_limits<CType>::max();
+
+ this->SetOverflowCheck(true);
+
+ this->AssertBinop(Multiply, MakeArray(max), MakeArray(-1), MakeArray(min + 1));
+ this->AssertBinop(Multiply, MakeArray(max / 2), MakeArray(-2), MakeArray(min + 2));
+
+ this->AssertBinopRaises(Multiply, MakeArray(max), MakeArray(2), "overflow");
+ this->AssertBinopRaises(Multiply, MakeArray(max / 2), MakeArray(3), "overflow");
+ this->AssertBinopRaises(Multiply, MakeArray(max / 2), MakeArray(-3), "overflow");
+
+ this->AssertBinopRaises(Multiply, MakeArray(min), MakeArray(2), "overflow");
+ this->AssertBinopRaises(Multiply, MakeArray(min / 2), MakeArray(3), "overflow");
+ this->AssertBinopRaises(Multiply, MakeArray(min), MakeArray(-1), "overflow");
+ this->AssertBinopRaises(Multiply, MakeArray(min / 2), MakeArray(-2), "overflow");
+
+ // Overflow should not be checked on underlying value slots when output would be null
+ auto left = ArrayFromJSON(this->type_singleton(), MakeArray(2, max, min / 2));
+ auto right = ArrayFromJSON(this->type_singleton(), MakeArray(1, 2, 3));
+ left = TweakValidityBit(left, 1, false);
+ right = TweakValidityBit(right, 2, false);
+ this->AssertBinop(Multiply, left, right, "[2, null, null]");
+}
+
+TYPED_TEST(TestBinaryArithmeticUnsigned, OverflowWraps) {
+ using CType = typename TestFixture::CType;
+
+ auto min = std::numeric_limits<CType>::lowest();
+ auto max = std::numeric_limits<CType>::max();
+
+ this->SetOverflowCheck(false);
+ this->AssertBinop(Add, MakeArray(min, max, max), MakeArray(CType(-1), 1, max),
+ MakeArray(max, min, CType(-2)));
+
+ this->AssertBinop(Subtract, MakeArray(min, max, min), MakeArray(1, max, max),
+ MakeArray(max, 0, 1));
+
+ this->AssertBinop(Multiply, MakeArray(min, max, max), MakeArray(max, 2, max),
+ MakeArray(min, CType(-2), 1));
+}
+
+TYPED_TEST(TestBinaryArithmeticSigned, Sub) {
+ this->AssertBinop(Subtract, "[0, 1, 2, 3, 4, 5, 6]", "[1, 2, 3, 4, 5, 6, 7]",
+ "[-1, -1, -1, -1, -1, -1, -1]");
+
+ this->AssertBinop(Subtract, "[0, 0, 0, 0, 0, 0, 0]", "[6, 5, 4, 3, 2, 1, 0]",
+ "[-6, -5, -4, -3, -2, -1, 0]");
+
+ this->AssertBinop(Subtract, "[10, 12, 4, 50, 50, 32, 11]", "[2, 0, 6, 1, 5, 3, 4]",
+ "[8, 12, -2, 49, 45, 29, 7]");
+
+ this->AssertBinop(Subtract, "[null, 1, 3, null, 2, 5]", "[1, 4, 2, 5, 0, 3]",
+ "[null, -3, 1, null, 2, 2]");
+}
+
+TYPED_TEST(TestBinaryArithmeticSigned, Mul) {
+ this->AssertBinop(Multiply, "[-10, 12, 4, 50, -5, 32, 11]", "[-2, 0, -6, 1, 5, 3, 4]",
+ "[20, 0, -24, 50, -25, 96, 44]");
+ this->AssertBinop(Multiply, -2, "[-10, 12, 4, 50, -5, 32, 11]",
+ "[20, -24, -8, -100, 10, -64, -22]");
+ this->AssertBinop(Multiply, -5, -5, 25);
+}
+
+// NOTE: cannot test Inf / -Inf (ARROW-9495)
+
+TYPED_TEST(TestBinaryArithmeticFloating, Add) {
+ this->AssertBinop(Add, "[]", "[]", "[]");
+
+ this->AssertBinop(Add, "[1.5, 0.5]", "[2.0, -3]", "[3.5, -2.5]");
+ // Nulls on the left
+ this->AssertBinop(Add, "[null, 0.5]", "[2.0, -3]", "[null, -2.5]");
+ // Nulls on the right
+ this->AssertBinop(Add, "[1.5, 0.5]", "[null, -3]", "[null, -2.5]");
+ // Nulls on both sides
+ this->AssertBinop(Add, "[null, 1.5, 0.5]", "[2.0, -3, null]", "[null, -1.5, null]");
+
+ // Scalar on the left
+ this->AssertBinop(Add, -1.5f, "[0.0, 2.0]", "[-1.5, 0.5]");
+ this->AssertBinop(Add, -1.5f, "[null, 2.0]", "[null, 0.5]");
+ this->AssertBinop(Add, this->MakeNullScalar(), "[0.0, 2.0]", "[null, null]");
+ this->AssertBinop(Add, this->MakeNullScalar(), "[null, 2.0]", "[null, null]");
+ // Scalar on the right
+ this->AssertBinop(Add, "[0.0, 2.0]", -1.5f, "[-1.5, 0.5]");
+ this->AssertBinop(Add, "[null, 2.0]", -1.5f, "[null, 0.5]");
+ this->AssertBinop(Add, "[0.0, 2.0]", this->MakeNullScalar(), "[null, null]");
+ this->AssertBinop(Add, "[null, 2.0]", this->MakeNullScalar(), "[null, null]");
+}
+
+TYPED_TEST(TestBinaryArithmeticFloating, Div) {
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ // Empty arrays
+ this->AssertBinop(Divide, "[]", "[]", "[]");
+ // Ordinary arrays
+ this->AssertBinop(Divide, "[3.4, 0.64, 1.28]", "[1, 2, 4]", "[3.4, 0.32, 0.32]");
+ // Array with nulls
+ this->AssertBinop(Divide, "[null, 1, 3.3, null, 2]", "[1, 4, 2, 5, 0.1]",
+ "[null, 0.25, 1.65, null, 20]");
+ // Scalar divides by array
+ this->AssertBinop(Divide, 10.0F, "[null, 1, 2.5, null, 2, 5]",
+ "[null, 10, 4, null, 5, 2]");
+ // Array divides by scalar
+ this->AssertBinop(Divide, "[null, 1, 2.5, null, 2, 5]", 10.0F,
+ "[null, 0.1, 0.25, null, 0.2, 0.5]");
+ // Array with infinity
+ this->AssertBinop(Divide, "[3.4, Inf, -Inf]", "[1, 2, 3]", "[3.4, Inf, -Inf]");
+ // Array with NaN
+ this->SetNansEqual(true);
+ this->AssertBinop(Divide, "[3.4, NaN, 2.0]", "[1, 2, 2.0]", "[3.4, NaN, 1.0]");
+ // Scalar divides by scalar
+ this->AssertBinop(Divide, 21.0F, 3.0F, 7.0F);
+ }
+}
+
+TYPED_TEST(TestBinaryArithmeticIntegral, Div) {
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+
+ // Empty arrays
+ this->AssertBinop(Divide, "[]", "[]", "[]");
+ // Ordinary arrays
+ this->AssertBinop(Divide, "[3, 2, 6]", "[1, 1, 2]", "[3, 2, 3]");
+ // Array with nulls
+ this->AssertBinop(Divide, "[null, 10, 30, null, 20]", "[1, 4, 2, 5, 10]",
+ "[null, 2, 15, null, 2]");
+ // Scalar divides by array
+ this->AssertBinop(Divide, 33, "[null, 1, 3, null, 2]", "[null, 33, 11, null, 16]");
+ // Array divides by scalar
+ this->AssertBinop(Divide, "[null, 10, 30, null, 2]", 3, "[null, 3, 10, null, 0]");
+ // Scalar divides by scalar
+ this->AssertBinop(Divide, 16, 7, 2);
+ }
+}
+
+TYPED_TEST(TestBinaryArithmeticSigned, Div) {
+ // Ordinary arrays
+ this->AssertBinop(Divide, "[-3, 2, -6]", "[1, 1, 2]", "[-3, 2, -3]");
+ // Array with nulls
+ this->AssertBinop(Divide, "[null, 10, 30, null, -20]", "[1, 4, 2, 5, 10]",
+ "[null, 2, 15, null, -2]");
+ // Scalar divides by array
+ this->AssertBinop(Divide, 33, "[null, -1, -3, null, 2]", "[null, -33, -11, null, 16]");
+ // Array divides by scalar
+ this->AssertBinop(Divide, "[null, 10, 30, null, 2]", 3, "[null, 3, 10, null, 0]");
+ // Scalar divides by scalar
+ this->AssertBinop(Divide, -16, -8, 2);
+}
+
+TYPED_TEST(TestBinaryArithmeticIntegral, DivideByZero) {
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ this->AssertBinopRaises(Divide, "[3, 2, 6]", "[1, 1, 0]", "divide by zero");
+ }
+}
+
+TYPED_TEST(TestBinaryArithmeticFloating, DivideByZero) {
+ this->SetOverflowCheck(true);
+ this->AssertBinopRaises(Divide, "[3.0, 2.0, 6.0]", "[1.0, 1.0, 0.0]", "divide by zero");
+ this->AssertBinopRaises(Divide, "[3.0, 2.0, 0.0]", "[1.0, 1.0, 0.0]", "divide by zero");
+ this->AssertBinopRaises(Divide, "[3.0, 2.0, -6.0]", "[1.0, 1.0, 0.0]",
+ "divide by zero");
+
+ this->SetOverflowCheck(false);
+ this->SetNansEqual(true);
+ this->AssertBinop(Divide, "[3.0, 2.0, 6.0]", "[1.0, 1.0, 0.0]", "[3.0, 2.0, Inf]");
+ this->AssertBinop(Divide, "[3.0, 2.0, 0.0]", "[1.0, 1.0, 0.0]", "[3.0, 2.0, NaN]");
+ this->AssertBinop(Divide, "[3.0, 2.0, -6.0]", "[1.0, 1.0, 0.0]", "[3.0, 2.0, -Inf]");
+}
+
+TYPED_TEST(TestBinaryArithmeticSigned, DivideOverflowRaises) {
+ using CType = typename TestFixture::CType;
+ auto min = std::numeric_limits<CType>::lowest();
+
+ this->SetOverflowCheck(true);
+ this->AssertBinopRaises(Divide, MakeArray(min), MakeArray(-1), "overflow");
+
+ this->SetOverflowCheck(false);
+ this->AssertBinop(Divide, MakeArray(min), MakeArray(-1), "[0]");
+}
+
+TYPED_TEST(TestBinaryArithmeticFloating, Power) {
+ using CType = typename TestFixture::CType;
+ auto max = std::numeric_limits<CType>::max();
+ this->SetNansEqual(true);
+
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+
+ // Empty arrays
+ this->AssertBinop(Power, "[]", "[]", "[]");
+ // Ordinary arrays
+ this->AssertBinop(Power, "[3.4, 16, 0.64, 1.2, 0]", "[1, 0.5, 2, 4, 0]",
+ "[3.4, 4, 0.4096, 2.0736, 1]");
+ // Array with nulls
+ this->AssertBinop(Power, "[null, 1, 3.3, null, 2]", "[1, 4, 2, 5, 0.1]",
+ "[null, 1, 10.89, null, 1.07177346]");
+ // Scalar exponentiated by array
+ this->AssertBinop(Power, 10.0F, "[null, 1, 2.5, null, 2, 5]",
+ "[null, 10, 316.227766017, null, 100, 100000]");
+ // Array exponentiated by scalar
+ this->AssertBinop(Power, "[null, 1, 2.5, null, 2, 5]", 10.0F,
+ "[null, 1, 9536.74316406, null, 1024, 9765625]");
+ // Array with infinity
+ this->AssertBinop(Power, "[3.4, Inf, -Inf, 1.1, 100000]", "[1, 2, 3, Inf, 100000]",
+ "[3.4, Inf, -Inf, Inf, Inf]");
+ // Array with NaN
+ this->AssertBinop(Power, "[3.4, NaN, 2.0]", "[1, 2, 2.0]", "[3.4, NaN, 4.0]");
+ // Scalar exponentiated by scalar
+ this->AssertBinop(Power, 21.0F, 3.0F, 9261.0F);
+ // Divide by zero
+ this->AssertBinop(Power, "[0.0, 0.0]", "[-1.0, -3.0]", "[Inf, Inf]");
+ // Check overflow behaviour
+ this->AssertBinop(Power, max, 10, INFINITY);
+ }
+
+ // Edge cases - removing NaNs
+ this->AssertBinop(Power, "[1, NaN, 0, null, 1.2, -Inf, Inf, 1.1, 1, 0, 1, 0]",
+ "[NaN, 0, NaN, 1, null, 1, 2, -Inf, Inf, 0, 0, 42]",
+ "[1, 1, NaN, null, null, -Inf, Inf, 0, 1, 1, 1, 0]");
+}
+
+TYPED_TEST(TestBinaryArithmeticIntegral, Power) {
+ using CType = typename TestFixture::CType;
+ auto max = std::numeric_limits<CType>::max();
+
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+
+ // Empty arrays
+ this->AssertBinop(Power, "[]", "[]", "[]");
+ // Ordinary arrays
+ this->AssertBinop(Power, "[3, 2, 6, 2]", "[1, 1, 2, 0]", "[3, 2, 36, 1]");
+ // Array with nulls
+ this->AssertBinop(Power, "[null, 2, 3, null, 20]", "[1, 6, 2, 5, 1]",
+ "[null, 64, 9, null, 20]");
+ // Scalar exponentiated by array
+ this->AssertBinop(Power, 3, "[null, 3, 4, null, 2]", "[null, 27, 81, null, 9]");
+ // Array exponentiated by scalar
+ this->AssertBinop(Power, "[null, 10, 3, null, 2]", 2, "[null, 100, 9, null, 4]");
+ // Scalar exponentiated by scalar
+ this->AssertBinop(Power, 4, 3, 64);
+ // Edge cases
+ this->AssertBinop(Power, "[0, 1, 0]", "[0, 0, 42]", "[1, 1, 0]");
+ }
+
+ // Overflow raises
+ this->SetOverflowCheck(true);
+ this->AssertBinopRaises(Power, MakeArray(max), MakeArray(10), "overflow");
+ // Disable overflow check
+ this->SetOverflowCheck(false);
+ this->AssertBinop(Power, max, 10, 1);
+}
+
+TYPED_TEST(TestBinaryArithmeticSigned, Power) {
+ using CType = typename TestFixture::CType;
+ auto max = std::numeric_limits<CType>::max();
+
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+
+ // Empty arrays
+ this->AssertBinop(Power, "[]", "[]", "[]");
+ // Ordinary arrays
+ this->AssertBinop(Power, "[-3, 2, -6, 2]", "[3, 1, 2, 0]", "[-27, 2, 36, 1]");
+ // Array with nulls
+ this->AssertBinop(Power, "[null, 10, 127, null, -20]", "[1, 2, 1, 5, 1]",
+ "[null, 100, 127, null, -20]");
+ // Scalar exponentiated by array
+ this->AssertBinop(Power, 11, "[null, 1, null, 2]", "[null, 11, null, 121]");
+ // Array exponentiated by scalar
+ this->AssertBinop(Power, "[null, 1, 3, null, 2]", 3, "[null, 1, 27, null, 8]");
+ // Scalar exponentiated by scalar
+ this->AssertBinop(Power, 16, 1, 16);
+ // Edge cases
+ this->AssertBinop(Power, "[1, 0, -1, 2]", "[0, 42, 0, 1]", "[1, 0, 1, 2]");
+ // Divide by zero raises
+ this->AssertBinopRaises(Power, MakeArray(0), MakeArray(-1),
+ "integers to negative integer powers are not allowed");
+ }
+
+ // Overflow raises
+ this->SetOverflowCheck(true);
+ this->AssertBinopRaises(Power, MakeArray(max), MakeArray(10), "overflow");
+ // Disable overflow check
+ this->SetOverflowCheck(false);
+ this->AssertBinop(Power, max, 10, 1);
+}
+
+TYPED_TEST(TestBinaryArithmeticFloating, Sub) {
+ this->AssertBinop(Subtract, "[]", "[]", "[]");
+
+ this->AssertBinop(Subtract, "[1.5, 0.5]", "[2.0, -3]", "[-0.5, 3.5]");
+ // Nulls on the left
+ this->AssertBinop(Subtract, "[null, 0.5]", "[2.0, -3]", "[null, 3.5]");
+ // Nulls on the right
+ this->AssertBinop(Subtract, "[1.5, 0.5]", "[null, -3]", "[null, 3.5]");
+ // Nulls on both sides
+ this->AssertBinop(Subtract, "[null, 1.5, 0.5]", "[2.0, -3, null]", "[null, 4.5, null]");
+
+ // Scalar on the left
+ this->AssertBinop(Subtract, -1.5f, "[0.0, 2.0]", "[-1.5, -3.5]");
+ this->AssertBinop(Subtract, -1.5f, "[null, 2.0]", "[null, -3.5]");
+ this->AssertBinop(Subtract, this->MakeNullScalar(), "[0.0, 2.0]", "[null, null]");
+ this->AssertBinop(Subtract, this->MakeNullScalar(), "[null, 2.0]", "[null, null]");
+ // Scalar on the right
+ this->AssertBinop(Subtract, "[0.0, 2.0]", -1.5f, "[1.5, 3.5]");
+ this->AssertBinop(Subtract, "[null, 2.0]", -1.5f, "[null, 3.5]");
+ this->AssertBinop(Subtract, "[0.0, 2.0]", this->MakeNullScalar(), "[null, null]");
+ this->AssertBinop(Subtract, "[null, 2.0]", this->MakeNullScalar(), "[null, null]");
+}
+
+TYPED_TEST(TestBinaryArithmeticFloating, Mul) {
+ this->AssertBinop(Multiply, "[]", "[]", "[]");
+
+ this->AssertBinop(Multiply, "[1.5, 0.5]", "[2.0, -3]", "[3.0, -1.5]");
+ // Nulls on the left
+ this->AssertBinop(Multiply, "[null, 0.5]", "[2.0, -3]", "[null, -1.5]");
+ // Nulls on the right
+ this->AssertBinop(Multiply, "[1.5, 0.5]", "[null, -3]", "[null, -1.5]");
+ // Nulls on both sides
+ this->AssertBinop(Multiply, "[null, 1.5, 0.5]", "[2.0, -3, null]",
+ "[null, -4.5, null]");
+
+ // Scalar on the left
+ this->AssertBinop(Multiply, -1.5f, "[0.0, 2.0]", "[0.0, -3.0]");
+ this->AssertBinop(Multiply, -1.5f, "[null, 2.0]", "[null, -3.0]");
+ this->AssertBinop(Multiply, this->MakeNullScalar(), "[0.0, 2.0]", "[null, null]");
+ this->AssertBinop(Multiply, this->MakeNullScalar(), "[null, 2.0]", "[null, null]");
+ // Scalar on the right
+ this->AssertBinop(Multiply, "[0.0, 2.0]", -1.5f, "[0.0, -3.0]");
+ this->AssertBinop(Multiply, "[null, 2.0]", -1.5f, "[null, -3.0]");
+ this->AssertBinop(Multiply, "[0.0, 2.0]", this->MakeNullScalar(), "[null, null]");
+ this->AssertBinop(Multiply, "[null, 2.0]", this->MakeNullScalar(), "[null, null]");
+}
+
+TEST(TestBinaryArithmetic, DispatchBest) {
+ for (std::string name : {"add", "subtract", "multiply", "divide", "power"}) {
+ for (std::string suffix : {"", "_checked"}) {
+ name += suffix;
+
+ CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()});
+ CheckDispatchBest(name, {int32(), null()}, {int32(), int32()});
+ CheckDispatchBest(name, {null(), int32()}, {int32(), int32()});
+
+ CheckDispatchBest(name, {int32(), int8()}, {int32(), int32()});
+ CheckDispatchBest(name, {int32(), int16()}, {int32(), int32()});
+ CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()});
+ CheckDispatchBest(name, {int32(), int64()}, {int64(), int64()});
+
+ CheckDispatchBest(name, {int32(), uint8()}, {int32(), int32()});
+ CheckDispatchBest(name, {int32(), uint16()}, {int32(), int32()});
+ CheckDispatchBest(name, {int32(), uint32()}, {int64(), int64()});
+ CheckDispatchBest(name, {int32(), uint64()}, {int64(), int64()});
+
+ CheckDispatchBest(name, {uint8(), uint8()}, {uint8(), uint8()});
+ CheckDispatchBest(name, {uint8(), uint16()}, {uint16(), uint16()});
+
+ CheckDispatchBest(name, {int32(), float32()}, {float32(), float32()});
+ CheckDispatchBest(name, {float32(), int64()}, {float32(), float32()});
+ CheckDispatchBest(name, {float64(), int32()}, {float64(), float64()});
+
+ CheckDispatchBest(name, {dictionary(int8(), float64()), float64()},
+ {float64(), float64()});
+ CheckDispatchBest(name, {dictionary(int8(), float64()), int16()},
+ {float64(), float64()});
+ }
+ }
+
+ CheckDispatchBest("atan2", {int32(), float64()}, {float64(), float64()});
+ CheckDispatchBest("atan2", {int32(), uint8()}, {float64(), float64()});
+ CheckDispatchBest("atan2", {int32(), null()}, {float64(), float64()});
+ CheckDispatchBest("atan2", {float32(), float64()}, {float64(), float64()});
+ // Integer always promotes to double
+ CheckDispatchBest("atan2", {float32(), int8()}, {float64(), float64()});
+}
+
+TEST(TestBinaryArithmetic, Null) {
+ for (std::string name : {"add", "divide", "logb", "multiply", "power", "shift_left",
+ "shift_right", "subtract", "tan"}) {
+ for (std::string suffix : {"", "_checked"}) {
+ name += suffix;
+ AssertNullToNull(name);
+ }
+ }
+
+ for (std::string name : {"atan2", "bit_wise_and", "bit_wise_or", "bit_wise_xor"}) {
+ AssertNullToNull(name);
+ }
+}
+
+TEST(TestBinaryArithmetic, AddWithImplicitCasts) {
+ CheckScalarBinary("add", ArrayFromJSON(int32(), "[0, 1, 2, null]"),
+ ArrayFromJSON(float64(), "[0.25, 0.5, 0.75, 1.0]"),
+ ArrayFromJSON(float64(), "[0.25, 1.5, 2.75, null]"));
+
+ CheckScalarBinary("add", ArrayFromJSON(int8(), "[-16, 0, 16, null]"),
+ ArrayFromJSON(uint32(), "[3, 4, 5, 7]"),
+ ArrayFromJSON(int64(), "[-13, 4, 21, null]"));
+
+ CheckScalarBinary("add",
+ ArrayFromJSON(dictionary(int32(), int32()), "[8, 6, 3, null, 2]"),
+ ArrayFromJSON(uint32(), "[3, 4, 5, 7, 0]"),
+ ArrayFromJSON(int64(), "[11, 10, 8, null, 2]"));
+
+ CheckScalarBinary("add", ArrayFromJSON(int32(), "[0, 1, 2, null]"),
+ std::make_shared<NullArray>(4),
+ ArrayFromJSON(int32(), "[null, null, null, null]"));
+
+ CheckScalarBinary("add", ArrayFromJSON(dictionary(int32(), int8()), "[0, 1, 2, null]"),
+ ArrayFromJSON(uint32(), "[3, 4, 5, 7]"),
+ ArrayFromJSON(int64(), "[3, 5, 7, null]"));
+}
+
+TEST(TestBinaryArithmetic, AddWithImplicitCastsUint64EdgeCase) {
+ // int64 is as wide as we can promote
+ CheckDispatchBest("add", {int8(), uint64()}, {int64(), int64()});
+
+ // this works sometimes
+ CheckScalarBinary("add", ArrayFromJSON(int8(), "[-1]"), ArrayFromJSON(uint64(), "[0]"),
+ ArrayFromJSON(int64(), "[-1]"));
+
+ // ... but it can result in impossible implicit casts in the presence of uint64, since
+ // some uint64 values cannot be cast to int64:
+ ASSERT_RAISES(Invalid,
+ CallFunction("add", {ArrayFromJSON(int64(), "[-1]"),
+ ArrayFromJSON(uint64(), "[18446744073709551615]")}));
+}
+
+TEST(TestUnaryArithmetic, DispatchBest) {
+ // All types (with _checked variant)
+ for (std::string name : {"abs"}) {
+ for (std::string suffix : {"", "_checked"}) {
+ name += suffix;
+ for (const auto& ty : {int8(), int16(), int32(), int64(), uint8(), uint16(),
+ uint32(), uint64(), float32(), float64()}) {
+ CheckDispatchBest(name, {ty}, {ty});
+ CheckDispatchBest(name, {dictionary(int8(), ty)}, {ty});
+ }
+ }
+ }
+
+ // All types
+ for (std::string name : {"negate", "sign"}) {
+ for (const auto& ty : {int8(), int16(), int32(), int64(), uint8(), uint16(), uint32(),
+ uint64(), float32(), float64()}) {
+ CheckDispatchBest(name, {ty}, {ty});
+ CheckDispatchBest(name, {dictionary(int8(), ty)}, {ty});
+ }
+ }
+
+ // Signed types
+ for (std::string name : {"negate_checked"}) {
+ for (const auto& ty : {int8(), int16(), int32(), int64(), float32(), float64()}) {
+ CheckDispatchBest(name, {ty}, {ty});
+ CheckDispatchBest(name, {dictionary(int8(), ty)}, {ty});
+ }
+ }
+
+ // Float types (with _checked variant)
+ for (std::string name :
+ {"ln", "log2", "log10", "log1p", "sin", "cos", "tan", "asin", "acos"}) {
+ for (std::string suffix : {"", "_checked"}) {
+ name += suffix;
+ for (const auto& ty : {float32(), float64()}) {
+ CheckDispatchBest(name, {ty}, {ty});
+ CheckDispatchBest(name, {dictionary(int8(), ty)}, {ty});
+ }
+ }
+ }
+
+ // Float types
+ for (std::string name :
+ {"atan", "sign", "floor", "ceil", "trunc", "round", "round_to_multiple"}) {
+ for (const auto& ty : {float32(), float64()}) {
+ CheckDispatchBest(name, {ty}, {ty});
+ CheckDispatchBest(name, {dictionary(int8(), ty)}, {ty});
+ }
+ }
+
+ // Integer -> Float64 (with _checked variant)
+ for (std::string name :
+ {"ln", "log2", "log10", "log1p", "sin", "cos", "tan", "asin", "acos"}) {
+ for (std::string suffix : {"", "_checked"}) {
+ name += suffix;
+ for (const auto& ty :
+ {int8(), int16(), int32(), int64(), uint8(), uint16(), uint32(), uint64()}) {
+ CheckDispatchBest(name, {ty}, {float64()});
+ CheckDispatchBest(name, {dictionary(int8(), ty)}, {float64()});
+ }
+ }
+ }
+
+ // Integer -> Float64
+ for (std::string name :
+ {"atan", "floor", "ceil", "trunc", "round", "round_to_multiple"}) {
+ for (const auto& ty :
+ {int8(), int16(), int32(), int64(), uint8(), uint16(), uint32(), uint64()}) {
+ CheckDispatchBest(name, {ty}, {float64()});
+ CheckDispatchBest(name, {dictionary(int8(), ty)}, {float64()});
+ }
+ }
+}
+
+TEST(TestUnaryArithmetic, Null) {
+ for (std::string name : {"abs", "acos", "asin", "cos", "ln", "log10", "log1p", "log2",
+ "negate", "sin", "tan"}) {
+ for (std::string suffix : {"", "_checked"}) {
+ name += suffix;
+ AssertNullToNull(name);
+ }
+ }
+
+ for (std::string name : {"atan", "bit_wise_not", "ceil", "floor", "round",
+ "round_to_multiple", "sign", "trunc"}) {
+ AssertNullToNull(name);
+ }
+}
+
+TYPED_TEST(TestUnaryArithmeticSigned, Negate) {
+ using CType = typename TestFixture::CType;
+
+ auto min = std::numeric_limits<CType>::min();
+ auto max = std::numeric_limits<CType>::max();
+
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ // Empty arrays
+ this->AssertUnaryOp(Negate, "[]", "[]");
+ // Array with nulls
+ this->AssertUnaryOp(Negate, "[null]", "[null]");
+ this->AssertUnaryOp(Negate, this->MakeNullScalar(), this->MakeNullScalar());
+ this->AssertUnaryOp(Negate, "[1, null, -10]", "[-1, null, 10]");
+ // Arrays with zeros
+ this->AssertUnaryOp(Negate, "[0, 0, -0]", "[0, -0, 0]");
+ this->AssertUnaryOp(Negate, 0, -0);
+ this->AssertUnaryOp(Negate, -0, 0);
+ this->AssertUnaryOp(Negate, 0, 0);
+ // Ordinary arrays (positive inputs)
+ this->AssertUnaryOp(Negate, "[1, 10, 127]", "[-1, -10, -127]");
+ this->AssertUnaryOp(Negate, 1, -1);
+ this->AssertUnaryOp(Negate, this->MakeScalar(1), this->MakeScalar(-1));
+ // Ordinary arrays (negative inputs)
+ this->AssertUnaryOp(Negate, "[-1, -10, -127]", "[1, 10, 127]");
+ this->AssertUnaryOp(Negate, -1, 1);
+ this->AssertUnaryOp(Negate, MakeArray(-1), "[1]");
+ // Min/max (wrap arounds and overflow)
+ this->AssertUnaryOp(Negate, max, min + 1);
+ if (check_overflow) {
+ this->AssertUnaryOpRaises(Negate, MakeArray(min), "overflow");
+ } else {
+ this->AssertUnaryOp(Negate, min, min);
+ }
+ }
+
+ // Overflow should not be checked on underlying value slots when output would be null
+ this->SetOverflowCheck(true);
+ auto arg = ArrayFromJSON(this->type_singleton(), MakeArray(1, max, min));
+ arg = TweakValidityBit(arg, 1, false);
+ arg = TweakValidityBit(arg, 2, false);
+ this->AssertUnaryOp(Negate, arg, "[-1, null, null]");
+}
+
+TYPED_TEST(TestUnaryArithmeticUnsigned, Negate) {
+ using CType = typename TestFixture::CType;
+
+ auto min = std::numeric_limits<CType>::min();
+ auto max = std::numeric_limits<CType>::max();
+
+ // Empty arrays
+ this->AssertUnaryOp(Negate, "[]", "[]");
+ // Array with nulls
+ this->AssertUnaryOp(Negate, "[null]", "[null]");
+ this->AssertUnaryOp(Negate, this->MakeNullScalar(), this->MakeNullScalar());
+ // Min/max (wrap around)
+ this->AssertUnaryOp(Negate, min, min);
+ this->AssertUnaryOp(Negate, max, 1);
+ this->AssertUnaryOp(Negate, 1, max);
+ // Not implemented kernels
+ this->SetOverflowCheck(true);
+ this->AssertUnaryOpNotImplemented(Negate, "[0]");
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, Negate) {
+ using CType = typename TestFixture::CType;
+
+ auto min = std::numeric_limits<CType>::lowest();
+ auto max = std::numeric_limits<CType>::max();
+
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ // Empty arrays
+ this->AssertUnaryOp(Negate, "[]", "[]");
+ // Array with nulls
+ this->AssertUnaryOp(Negate, "[null]", "[null]");
+ this->AssertUnaryOp(Negate, this->MakeNullScalar(), this->MakeNullScalar());
+ this->AssertUnaryOp(Negate, "[1.3, null, -10.80]", "[-1.3, null, 10.80]");
+ // Arrays with zeros
+ this->AssertUnaryOp(Negate, "[0.0, 0.0, -0.0]", "[0.0, -0.0, 0.0]");
+ this->AssertUnaryOp(Negate, 0.0F, -0.0F);
+ this->AssertUnaryOp(Negate, -0.0F, 0.0F);
+ this->AssertUnaryOp(Negate, 0.0F, 0.0F);
+ // Ordinary arrays (positive inputs)
+ this->AssertUnaryOp(Negate, "[1.3, 10.80, 12748.001]", "[-1.3, -10.80, -12748.001]");
+ this->AssertUnaryOp(Negate, 1.3F, -1.3F);
+ this->AssertUnaryOp(Negate, this->MakeScalar(1.3F), this->MakeScalar(-1.3F));
+ // Ordinary arrays (negative inputs)
+ this->AssertUnaryOp(Negate, "[-1.3, -10.80, -12748.001]", "[1.3, 10.80, 12748.001]");
+ this->AssertUnaryOp(Negate, -1.3F, 1.3F);
+ this->AssertUnaryOp(Negate, MakeArray(-1.3F), "[1.3]");
+ // Arrays with infinites
+ this->AssertUnaryOp(Negate, "[Inf, -Inf]", "[-Inf, Inf]");
+ // Arrays with NaNs
+ this->SetNansEqual(true);
+ this->AssertUnaryOp(Negate, "[NaN]", "[NaN]");
+ this->AssertUnaryOp(Negate, "[NaN]", "[-NaN]");
+ this->AssertUnaryOp(Negate, "[-NaN]", "[NaN]");
+ // Min/max
+ this->AssertUnaryOp(Negate, min, max);
+ this->AssertUnaryOp(Negate, max, min);
+ }
+}
+
+TYPED_TEST(TestUnaryArithmeticSigned, AbsoluteValue) {
+ using CType = typename TestFixture::CType;
+
+ auto min = std::numeric_limits<CType>::min();
+ auto max = std::numeric_limits<CType>::max();
+
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ // Empty array
+ this->AssertUnaryOp(AbsoluteValue, "[]", "[]");
+ // Scalar/arrays with nulls
+ this->AssertUnaryOp(AbsoluteValue, "[null]", "[null]");
+ this->AssertUnaryOp(AbsoluteValue, "[1, null, -10]", "[1, null, 10]");
+ this->AssertUnaryOp(AbsoluteValue, this->MakeNullScalar(), this->MakeNullScalar());
+ // Scalar/arrays with zeros
+ this->AssertUnaryOp(AbsoluteValue, "[0, -0]", "[0, 0]");
+ this->AssertUnaryOp(AbsoluteValue, -0, 0);
+ this->AssertUnaryOp(AbsoluteValue, 0, 0);
+ // Ordinary scalar/arrays (positive inputs)
+ this->AssertUnaryOp(AbsoluteValue, "[1, 10, 127]", "[1, 10, 127]");
+ this->AssertUnaryOp(AbsoluteValue, 1, 1);
+ this->AssertUnaryOp(AbsoluteValue, this->MakeScalar(1), this->MakeScalar(1));
+ // Ordinary scalar/arrays (negative inputs)
+ this->AssertUnaryOp(AbsoluteValue, "[-1, -10, -127]", "[1, 10, 127]");
+ this->AssertUnaryOp(AbsoluteValue, -1, 1);
+ this->AssertUnaryOp(AbsoluteValue, MakeArray(-1), "[1]");
+ // Min/max
+ this->AssertUnaryOp(AbsoluteValue, max, max);
+ if (check_overflow) {
+ this->AssertUnaryOpRaises(AbsoluteValue, MakeArray(min), "overflow");
+ } else {
+ this->AssertUnaryOp(AbsoluteValue, min, min);
+ }
+ }
+
+ // Overflow should not be checked on underlying value slots when output would be null
+ this->SetOverflowCheck(true);
+ auto arg = ArrayFromJSON(this->type_singleton(), MakeArray(-1, max, min));
+ arg = TweakValidityBit(arg, 1, false);
+ arg = TweakValidityBit(arg, 2, false);
+ this->AssertUnaryOp(AbsoluteValue, arg, "[1, null, null]");
+}
+
+TYPED_TEST(TestUnaryArithmeticUnsigned, AbsoluteValue) {
+ using CType = typename TestFixture::CType;
+
+ auto min = std::numeric_limits<CType>::min();
+ auto max = std::numeric_limits<CType>::max();
+
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ // Empty arrays
+ this->AssertUnaryOp(AbsoluteValue, "[]", "[]");
+ // Array with nulls
+ this->AssertUnaryOp(AbsoluteValue, "[null]", "[null]");
+ this->AssertUnaryOp(AbsoluteValue, this->MakeNullScalar(), this->MakeNullScalar());
+ // Ordinary arrays
+ this->AssertUnaryOp(AbsoluteValue, "[0, 1, 10, 127]", "[0, 1, 10, 127]");
+ // Min/max
+ this->AssertUnaryOp(AbsoluteValue, min, min);
+ this->AssertUnaryOp(AbsoluteValue, max, max);
+ }
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, AbsoluteValue) {
+ using CType = typename TestFixture::CType;
+
+ auto min = std::numeric_limits<CType>::lowest();
+ auto max = std::numeric_limits<CType>::max();
+
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ // Empty array
+ this->AssertUnaryOp(AbsoluteValue, "[]", "[]");
+ // Scalar/arrays with nulls
+ this->AssertUnaryOp(AbsoluteValue, "[null]", "[null]");
+ this->AssertUnaryOp(AbsoluteValue, "[1.3, null, -10.80]", "[1.3, null, 10.80]");
+ this->AssertUnaryOp(AbsoluteValue, this->MakeNullScalar(), this->MakeNullScalar());
+ // Scalars/arrays with zeros
+ this->AssertUnaryOp(AbsoluteValue, "[0.0, -0.0]", "[0.0, 0.0]");
+ this->AssertUnaryOp(AbsoluteValue, -0.0F, 0.0F);
+ this->AssertUnaryOp(AbsoluteValue, 0.0F, 0.0F);
+ // Ordinary scalars/arrays (positive inputs)
+ this->AssertUnaryOp(AbsoluteValue, "[1.3, 10.80, 12748.001]",
+ "[1.3, 10.80, 12748.001]");
+ this->AssertUnaryOp(AbsoluteValue, 1.3F, 1.3F);
+ this->AssertUnaryOp(AbsoluteValue, this->MakeScalar(1.3F), this->MakeScalar(1.3F));
+ // Ordinary scalars/arrays (negative inputs)
+ this->AssertUnaryOp(AbsoluteValue, "[-1.3, -10.80, -12748.001]",
+ "[1.3, 10.80, 12748.001]");
+ this->AssertUnaryOp(AbsoluteValue, -1.3F, 1.3F);
+ this->AssertUnaryOp(AbsoluteValue, MakeArray(-1.3F), "[1.3]");
+ // Arrays with infinites
+ this->AssertUnaryOp(AbsoluteValue, "[Inf, -Inf]", "[Inf, Inf]");
+ // Arrays with NaNs
+ this->SetNansEqual(true);
+ this->AssertUnaryOp(AbsoluteValue, "[NaN]", "[NaN]");
+ this->AssertUnaryOp(AbsoluteValue, "[-NaN]", "[NaN]");
+ // Min/max
+ this->AssertUnaryOp(AbsoluteValue, min, max);
+ this->AssertUnaryOp(AbsoluteValue, max, max);
+ }
+}
+
+class TestUnaryArithmeticDecimal : public TestArithmeticDecimal {};
+
+// Check two modes exhaustively, give all modes a simple test
+TEST_F(TestUnaryArithmeticDecimal, Round) {
+ const auto func = "round";
+ RoundOptions options(2, RoundMode::DOWN);
+ for (const auto& ty : {decimal128(4, 3), decimal256(4, 3)}) {
+ auto values = ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.012", "1.015", "1.019", "-1.010", "-1.012", "-1.015", "-1.019", null])");
+ options.round_mode = RoundMode::DOWN;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.010", "1.010", "-1.010", "-1.020", "-1.020", "-1.020", null])"),
+ &options);
+ options.round_mode = RoundMode::UP;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.020", "1.020", "1.020", "-1.010", "-1.010", "-1.010", "-1.010", null])"),
+ &options);
+ options.round_mode = RoundMode::TOWARDS_ZERO;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.010", "1.010", "-1.010", "-1.010", "-1.010", "-1.010", null])"),
+ &options);
+ options.round_mode = RoundMode::TOWARDS_INFINITY;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.020", "1.020", "1.020", "-1.010", "-1.020", "-1.020", "-1.020", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_DOWN;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.010", "1.020", "-1.010", "-1.010", "-1.020", "-1.020", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_UP;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.020", "1.020", "-1.010", "-1.010", "-1.010", "-1.020", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TOWARDS_ZERO;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.010", "1.020", "-1.010", "-1.010", "-1.010", "-1.020", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TOWARDS_INFINITY;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.020", "1.020", "-1.010", "-1.010", "-1.020", "-1.020", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TO_EVEN;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.020", "1.020", "-1.010", "-1.010", "-1.020", "-1.020", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TO_ODD;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.010", "1.020", "-1.010", "-1.010", "-1.010", "-1.020", null])"),
+ &options);
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundTowardsInfinity) {
+ const auto func = "round";
+ RoundOptions options(0, RoundMode::TOWARDS_INFINITY);
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ auto values = ArrayFromJSON(
+ ty, R"(["1.00", "1.99", "1.01", "-42.00", "-42.99", "-42.15", null])");
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"), &options);
+ options.ndigits = 0;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(ty,
+ R"(["1.00", "2.00", "2.00", "-42.00", "-43.00", "-43.00", null])"),
+ &options);
+ options.ndigits = 1;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(ty,
+ R"(["1.00", "2.00", "1.10", "-42.00", "-43.00", "-42.20", null])"),
+ &options);
+ options.ndigits = 2;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 4;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 100;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -1;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty, R"(["10.00", "10.00", "10.00", "-50.00", "-50.00", "-50.00", null])"),
+ &options);
+ options.ndigits = -2;
+ CheckRaises(func, {values}, "Rounding to -2 digits will not fit in precision",
+ &options);
+ options.ndigits = -1;
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")},
+ "Rounded value 100.00 does not fit in precision", &options);
+ }
+ for (const auto& ty : {decimal128(2, -2), decimal256(2, -2)}) {
+ auto values = DecimalArrayFromJSON(
+ ty, R"(["10E2", "12E2", "18E2", "-10E2", "-12E2", "-18E2", null])");
+ options.ndigits = 0;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 2;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 100;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -1;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -2;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -3;
+ CheckScalar(func, {values},
+ DecimalArrayFromJSON(
+ ty, R"(["10E2", "20E2", "20E2", "-10E2", "-20E2", "-20E2", null])"),
+ &options);
+ options.ndigits = -4;
+ CheckRaises(func, {values}, "Rounding to -4 digits will not fit in precision",
+ &options);
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundHalfToEven) {
+ const auto func = "round";
+ RoundOptions options(0, RoundMode::HALF_TO_EVEN);
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ auto values = ArrayFromJSON(
+ ty,
+ R"(["1.00", "5.99", "1.01", "-42.00", "-42.99", "-42.15", "1.50", "2.50", "-5.50", "-2.55", null])");
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"), &options);
+ options.ndigits = 0;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.00", "6.00", "1.00", "-42.00", "-43.00", "-42.00", "2.00", "2.00", "-6.00", "-3.00", null])"),
+ &options);
+ options.ndigits = 1;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.00", "6.00", "1.00", "-42.00", "-43.00", "-42.20", "1.50", "2.50", "-5.50", "-2.60", null])"),
+ &options);
+ options.ndigits = 2;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 4;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 100;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -1;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["0.00", "10.00", "0.00", "-40.00", "-40.00", "-40.00", "0.00", "0.00", "-10.00", "0.00", null])"),
+ &options);
+ options.ndigits = -2;
+ CheckRaises(func, {values}, "Rounding to -2 digits will not fit in precision",
+ &options);
+ options.ndigits = -1;
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")},
+ "Rounded value 100.00 does not fit in precision", &options);
+ }
+ for (const auto& ty : {decimal128(2, -2), decimal256(2, -2)}) {
+ auto values = DecimalArrayFromJSON(
+ ty,
+ R"(["5E2", "10E2", "12E2", "15E2", "18E2", "-10E2", "-12E2", "-15E2", "-18E2", null])");
+ options.ndigits = 0;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 2;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 100;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -1;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -2;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -3;
+ CheckScalar(
+ func, {values},
+ DecimalArrayFromJSON(
+ ty,
+ R"(["0", "10E2", "10E2", "20E2", "20E2", "-10E2", "-10E2", "-20E2", "-20E2", null])"),
+ &options);
+ options.ndigits = -4;
+ CheckRaises(func, {values}, "Rounding to -4 digits will not fit in precision",
+ &options);
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundCeil) {
+ const auto func = "ceil";
+ for (const auto& ty : PositiveScaleTypes()) {
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"));
+ CheckScalar(
+ func,
+ {ArrayFromJSON(
+ ty, R"(["1.00", "1.99", "1.01", "-42.00", "-42.99", "-42.15", null])")},
+ ArrayFromJSON(ty,
+ R"(["1.00", "2.00", "2.00", "-42.00", "-42.00", "-42.00", null])"));
+ }
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ CheckRaises(func, {ScalarFromJSON(ty, R"("99.99")")},
+ "Rounded value 100.00 does not fit in precision of decimal");
+ CheckScalar(func, {ScalarFromJSON(ty, R"("-99.99")")},
+ ScalarFromJSON(ty, R"("-99.00")"));
+ }
+ for (const auto& ty : NegativeScaleTypes()) {
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"));
+ CheckScalar(func, {DecimalArrayFromJSON(ty, R"(["12E2", "-42E2", null])")},
+ DecimalArrayFromJSON(ty, R"(["12E2", "-42E2", null])"));
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundFloor) {
+ const auto func = "floor";
+ for (const auto& ty : PositiveScaleTypes()) {
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"));
+ CheckScalar(
+ func,
+ {ArrayFromJSON(
+ ty, R"(["1.00", "1.99", "1.01", "-42.00", "-42.99", "-42.15", null])")},
+ ArrayFromJSON(ty,
+ R"(["1.00", "1.00", "1.00", "-42.00", "-43.00", "-43.00", null])"));
+ }
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ CheckScalar(func, {ScalarFromJSON(ty, R"("99.99")")},
+ ScalarFromJSON(ty, R"("99.00")"));
+ CheckRaises(func, {ScalarFromJSON(ty, R"("-99.99")")},
+ "Rounded value -100.00 does not fit in precision of decimal");
+ }
+ for (const auto& ty : NegativeScaleTypes()) {
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"));
+ CheckScalar(func, {DecimalArrayFromJSON(ty, R"(["12E2", "-42E2", null])")},
+ DecimalArrayFromJSON(ty, R"(["12E2", "-42E2", null])"));
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundTrunc) {
+ const auto func = "trunc";
+ for (const auto& ty : PositiveScaleTypes()) {
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"));
+ CheckScalar(
+ func,
+ {ArrayFromJSON(
+ ty, R"(["1.00", "1.99", "1.01", "-42.00", "-42.99", "-42.15", null])")},
+ ArrayFromJSON(ty,
+ R"(["1.00", "1.00", "1.00", "-42.00", "-42.00", "-42.00", null])"));
+ }
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ CheckScalar(func, {ScalarFromJSON(ty, R"("99.99")")},
+ ScalarFromJSON(ty, R"("99.00")"));
+ CheckScalar(func, {ScalarFromJSON(ty, R"("-99.99")")},
+ ScalarFromJSON(ty, R"("-99.00")"));
+ }
+ for (const auto& ty : NegativeScaleTypes()) {
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"));
+ CheckScalar(func, {DecimalArrayFromJSON(ty, R"(["12E2", "-42E2", null])")},
+ DecimalArrayFromJSON(ty, R"(["12E2", "-42E2", null])"));
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundToMultiple) {
+ const auto func = "round_to_multiple";
+ RoundToMultipleOptions options(0, RoundMode::DOWN);
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ if (ty->id() == Type::DECIMAL128) {
+ options.multiple = std::make_shared<Decimal128Scalar>(Decimal128(200), ty);
+ } else {
+ options.multiple = std::make_shared<Decimal256Scalar>(Decimal256(200), ty);
+ }
+ auto values = ArrayFromJSON(
+ ty,
+ R"(["-3.50", "-3.00", "-2.50", "-2.00", "-1.50", "-1.00", "-0.50", "0.00",
+ "0.50", "1.00", "1.50", "2.00", "2.50", "3.00", "3.50", null])");
+ options.round_mode = RoundMode::DOWN;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-4.00", "-4.00", "-2.00", "-2.00", "-2.00", "-2.00", "0.00",
+ "0.00", "0.00", "0.00", "2.00", "2.00", "2.00", "2.00", null])"),
+ &options);
+ options.round_mode = RoundMode::UP;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-2.00", "-2.00", "-2.00", "-2.00", "-0.00", "-0.00", "-0.00", "0.00",
+ "2.00", "2.00", "2.00", "2.00", "4.00", "4.00", "4.00", null])"),
+ &options);
+ options.round_mode = RoundMode::TOWARDS_ZERO;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-2.00", "-2.00", "-2.00", "-2.00", "-0.00", "-0.00", "-0.00", "0.00",
+ "0.00", "0.00", "0.00", "2.00", "2.00", "2.00", "2.00", null])"),
+ &options);
+ options.round_mode = RoundMode::TOWARDS_INFINITY;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-4.00", "-4.00", "-2.00", "-2.00", "-2.00", "-2.00", "0.00",
+ "2.00", "2.00", "2.00", "2.00", "4.00", "4.00", "4.00", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_DOWN;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-4.00", "-2.00", "-2.00", "-2.00", "-2.00", "-0.00", "0.00",
+ "0.00", "0.00", "2.00", "2.00", "2.00", "2.00", "4.00", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_UP;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-2.00", "-2.00", "-2.00", "-2.00", "-0.00", "-0.00", "0.00",
+ "0.00", "2.00", "2.00", "2.00", "2.00", "4.00", "4.00", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TOWARDS_ZERO;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-2.00", "-2.00", "-2.00", "-2.00", "-0.00", "-0.00", "0.00",
+ "0.00", "0.00", "2.00", "2.00", "2.00", "2.00", "4.00", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TOWARDS_INFINITY;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-4.00", "-2.00", "-2.00", "-2.00", "-2.00", "-0.00", "0.00",
+ "0.00", "2.00", "2.00", "2.00", "2.00", "4.00", "4.00", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TO_EVEN;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-4.00", "-2.00", "-2.00", "-2.00", "-0.00", "-0.00", "0.00",
+ "0.00", "0.00", "2.00", "2.00", "2.00", "4.00", "4.00", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TO_ODD;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-2.00", "-2.00", "-2.00", "-2.00", "-2.00", "-0.00", "0.00",
+ "0.00", "2.00", "2.00", "2.00", "2.00", "2.00", "4.00", null])"),
+ &options);
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundToMultipleTowardsInfinity) {
+ const auto func = "round_to_multiple";
+ RoundToMultipleOptions options(0, RoundMode::TOWARDS_INFINITY);
+ auto set_multiple = [&](const std::shared_ptr<DataType>& ty, int64_t value) {
+ if (ty->id() == Type::DECIMAL128) {
+ options.multiple = std::make_shared<Decimal128Scalar>(Decimal128(value), ty);
+ } else {
+ options.multiple = std::make_shared<Decimal256Scalar>(Decimal256(value), ty);
+ }
+ };
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ auto values = ArrayFromJSON(
+ ty, R"(["1.00", "1.99", "1.01", "-42.00", "-42.99", "-42.15", null])");
+ set_multiple(ty, 25);
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"), &options);
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(ty,
+ R"(["1.00", "2.00", "1.25", "-42.00", "-43.00", "-42.25", null])"),
+ &options);
+ set_multiple(ty, 1);
+ CheckScalar(func, {values}, values, &options);
+ set_multiple(ty, 0);
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")},
+ "Rounding multiple must be positive", &options);
+ set_multiple(ty, -10);
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")},
+ "Rounding multiple must be positive", &options);
+ set_multiple(ty, 100);
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")},
+ "Rounded value 100.00 does not fit in precision", &options);
+ options.multiple = std::make_shared<DoubleScalar>(1.0);
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")}, "scalar, not double",
+ &options);
+ options.multiple =
+ std::make_shared<Decimal128Scalar>(Decimal128(0), decimal128(3, 0));
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")}, "scalar, not decimal128(3, 0)",
+ &options);
+ options.multiple = std::make_shared<Decimal128Scalar>(decimal128(3, 0));
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")},
+ "Rounding multiple must be non-null and valid", &options);
+ options.multiple = nullptr;
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")},
+ "Rounding multiple must be non-null and valid", &options);
+ }
+ for (const auto& ty : {decimal128(2, -2), decimal256(2, -2)}) {
+ auto values = DecimalArrayFromJSON(
+ ty, R"(["10E2", "12E2", "18E2", "-10E2", "-12E2", "-18E2", null])");
+ set_multiple(ty, 4);
+ CheckScalar(func, {values},
+ DecimalArrayFromJSON(
+ ty, R"(["12E2", "12E2", "20E2", "-12E2", "-12E2", "-20E2", null])"),
+ &options);
+ set_multiple(ty, 1);
+ CheckScalar(func, {values}, values, &options);
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundToMultipleHalfToOdd) {
+ const auto func = "round_to_multiple";
+ RoundToMultipleOptions options(0, RoundMode::HALF_TO_ODD);
+ auto set_multiple = [&](const std::shared_ptr<DataType>& ty, int64_t value) {
+ if (ty->id() == Type::DECIMAL128) {
+ options.multiple = std::make_shared<Decimal128Scalar>(Decimal128(value), ty);
+ } else {
+ options.multiple = std::make_shared<Decimal256Scalar>(Decimal256(value), ty);
+ }
+ };
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ auto values =
+ ArrayFromJSON(ty, R"(["-0.38", "-0.37", "-0.25", "-0.13", "-0.12", "0.00",
+ "0.12", "0.13", "0.25", "0.37", "0.38", null])");
+ // There is no exact halfway point, check what happens
+ set_multiple(ty, 25);
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"), &options);
+ CheckScalar(func, {values},
+ ArrayFromJSON(ty, R"(["-0.50", "-0.25", "-0.25", "-0.25", "-0.00", "0.00",
+ "0.00", "0.25", "0.25", "0.25", "0.50", null])"),
+ &options);
+ set_multiple(ty, 1);
+ CheckScalar(func, {values}, values, &options);
+ set_multiple(ty, 24);
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"), &options);
+ CheckScalar(func, {values},
+ ArrayFromJSON(ty, R"(["-0.48", "-0.48", "-0.24", "-0.24", "-0.24", "0.00",
+ "0.24", "0.24", "0.24", "0.48", "0.48", null])"),
+ &options);
+ }
+ for (const auto& ty : {decimal128(2, -2), decimal256(2, -2)}) {
+ auto values = DecimalArrayFromJSON(
+ ty, R"(["10E2", "12E2", "18E2", "-10E2", "-12E2", "-18E2", null])");
+ set_multiple(ty, 4);
+ CheckScalar(func, {values},
+ DecimalArrayFromJSON(
+ ty, R"(["12E2", "12E2", "20E2", "-12E2", "-12E2", "-20E2", null])"),
+ &options);
+ set_multiple(ty, 5);
+ CheckScalar(func, {values},
+ DecimalArrayFromJSON(
+ ty, R"(["10E2", "10E2", "20E2", "-10E2", "-10E2", "-20E2", null])"),
+ &options);
+ set_multiple(ty, 1);
+ CheckScalar(func, {values}, values, &options);
+ }
+}
+
+TYPED_TEST_SUITE(TestUnaryRoundIntegral, IntegralTypes);
+TYPED_TEST_SUITE(TestUnaryRoundSigned, SignedIntegerTypes);
+TYPED_TEST_SUITE(TestUnaryRoundUnsigned, UnsignedIntegerTypes);
+TYPED_TEST_SUITE(TestUnaryRoundFloating, FloatingTypes);
+
+const std::vector<RoundMode> kRoundModes{
+ RoundMode::DOWN,
+ RoundMode::UP,
+ RoundMode::TOWARDS_ZERO,
+ RoundMode::TOWARDS_INFINITY,
+ RoundMode::HALF_DOWN,
+ RoundMode::HALF_UP,
+ RoundMode::HALF_TOWARDS_ZERO,
+ RoundMode::HALF_TOWARDS_INFINITY,
+ RoundMode::HALF_TO_EVEN,
+ RoundMode::HALF_TO_ODD,
+};
+
+TYPED_TEST(TestUnaryRoundSigned, Round) {
+ // Test different rounding modes for integer rounding
+ std::string values("[0, 1, -13, -50, 115]");
+ this->SetRoundNdigits(0);
+ for (const auto& round_mode : kRoundModes) {
+ this->SetRoundMode(round_mode);
+ this->AssertUnaryOp(Round, values, ArrayFromJSON(float64(), values));
+ }
+
+ // Test different round N-digits for nearest rounding mode
+ std::vector<std::pair<int64_t, std::string>> ndigits_and_expected{{
+ {-2, "[0, 0, -0, -100, 100]"},
+ {-1, "[0, 0, -10, -50, 120]"},
+ {0, values},
+ {1, values},
+ {2, values},
+ }};
+ this->SetRoundMode(RoundMode::HALF_TOWARDS_INFINITY);
+ for (const auto& pair : ndigits_and_expected) {
+ this->SetRoundNdigits(pair.first);
+ this->AssertUnaryOp(Round, values, ArrayFromJSON(float64(), pair.second));
+ }
+}
+
+TYPED_TEST(TestUnaryRoundUnsigned, Round) {
+ // Test different rounding modes for integer rounding
+ std::string values("[0, 1, 13, 50, 115]");
+ this->SetRoundNdigits(0);
+ for (const auto& round_mode : kRoundModes) {
+ this->SetRoundMode(round_mode);
+ this->AssertUnaryOp(Round, values, ArrayFromJSON(float64(), values));
+ }
+
+ // Test different round N-digits for nearest rounding mode
+ std::vector<std::pair<int64_t, std::string>> ndigits_and_expected{{
+ {-2, "[0, 0, 0, 100, 100]"},
+ {-1, "[0, 0, 10, 50, 120]"},
+ {0, values},
+ {1, values},
+ {2, values},
+ }};
+ this->SetRoundMode(RoundMode::HALF_TOWARDS_INFINITY);
+ for (const auto& pair : ndigits_and_expected) {
+ this->SetRoundNdigits(pair.first);
+ this->AssertUnaryOp(Round, values, ArrayFromJSON(float64(), pair.second));
+ }
+}
+
+TYPED_TEST(TestUnaryRoundFloating, Round) {
+ this->SetNansEqual(true);
+
+ // Test different rounding modes
+ std::string values("[3.2, 3.5, 3.7, 4.5, -3.2, -3.5, -3.7]");
+ std::vector<std::pair<RoundMode, std::string>> rmode_and_expected{{
+ {RoundMode::DOWN, "[3, 3, 3, 4, -4, -4, -4]"},
+ {RoundMode::UP, "[4, 4, 4, 5, -3, -3, -3]"},
+ {RoundMode::TOWARDS_ZERO, "[3, 3, 3, 4, -3, -3, -3]"},
+ {RoundMode::TOWARDS_INFINITY, "[4, 4, 4, 5, -4, -4, -4]"},
+ {RoundMode::HALF_DOWN, "[3, 3, 4, 4, -3, -4, -4]"},
+ {RoundMode::HALF_UP, "[3, 4, 4, 5, -3, -3, -4]"},
+ {RoundMode::HALF_TOWARDS_ZERO, "[3, 3, 4, 4, -3, -3, -4]"},
+ {RoundMode::HALF_TOWARDS_INFINITY, "[3, 4, 4, 5, -3, -4, -4]"},
+ {RoundMode::HALF_TO_EVEN, "[3, 4, 4, 4, -3, -4, -4]"},
+ {RoundMode::HALF_TO_ODD, "[3, 3, 4, 5, -3, -3, -4]"},
+ }};
+ this->SetRoundNdigits(0);
+ for (const auto& pair : rmode_and_expected) {
+ this->SetRoundMode(pair.first);
+ this->AssertUnaryOp(Round, "[]", "[]");
+ this->AssertUnaryOp(Round, "[null, 0, Inf, -Inf, NaN, -NaN]",
+ "[null, 0, Inf, -Inf, NaN, -NaN]");
+ this->AssertUnaryOp(Round, values, pair.second);
+ }
+
+ // Test different round N-digits for nearest rounding mode
+ values = "[320, 3.5, 3.075, 4.5, -3.212, -35.1234, -3.045]";
+ std::vector<std::pair<int64_t, std::string>> ndigits_and_expected{{
+ {-2, "[300, 0, 0, 0, -0, -0, -0]"},
+ {-1, "[320, 0, 0, 0, -0, -40, -0]"},
+ {0, "[320, 4, 3, 5, -3, -35, -3]"},
+ {1, "[320, 3.5, 3.1, 4.5, -3.2, -35.1, -3]"},
+ {2, "[320, 3.5, 3.08, 4.5, -3.21, -35.12, -3.05]"},
+ }};
+ this->SetRoundMode(RoundMode::HALF_TOWARDS_INFINITY);
+ for (const auto& pair : ndigits_and_expected) {
+ this->SetRoundNdigits(pair.first);
+ this->AssertUnaryOp(Round, values, pair.second);
+ }
+}
+
+TYPED_TEST_SUITE(TestUnaryRoundToMultipleIntegral, IntegralTypes);
+TYPED_TEST_SUITE(TestUnaryRoundToMultipleSigned, SignedIntegerTypes);
+TYPED_TEST_SUITE(TestUnaryRoundToMultipleUnsigned, UnsignedIntegerTypes);
+TYPED_TEST_SUITE(TestUnaryRoundToMultipleFloating, FloatingTypes);
+
+TYPED_TEST(TestUnaryRoundToMultipleSigned, RoundToMultiple) {
+ // Test different rounding modes for integer rounding
+ std::string values("[0, 1, -13, -50, 115]");
+ this->SetRoundMultiple(1);
+ for (const auto& round_mode : kRoundModes) {
+ this->SetRoundMode(round_mode);
+ this->AssertUnaryOp(RoundToMultiple, values, ArrayFromJSON(float64(), values));
+ }
+
+ // Test different round multiples for nearest rounding mode
+ std::vector<std::pair<double, std::string>> multiple_and_expected{{
+ {2, "[0, 2, -14, -50, 116]"},
+ {0.05, "[0, 1, -13, -50, 115]"},
+ {0.1, values},
+ {10, "[0, 0, -10, -50, 120]"},
+ {100, "[0, 0, -0, -100, 100]"},
+ }};
+ this->SetRoundMode(RoundMode::HALF_TOWARDS_INFINITY);
+ for (const auto& pair : multiple_and_expected) {
+ this->SetRoundMultiple(pair.first);
+ this->AssertUnaryOp(RoundToMultiple, values, ArrayFromJSON(float64(), pair.second));
+ }
+}
+
+TYPED_TEST(TestUnaryRoundToMultipleUnsigned, RoundToMultiple) {
+ // Test different rounding modes for integer rounding
+ std::string values("[0, 1, 13, 50, 115]");
+ this->SetRoundMultiple(1);
+ for (const auto& round_mode : kRoundModes) {
+ this->SetRoundMode(round_mode);
+ this->AssertUnaryOp(RoundToMultiple, values, ArrayFromJSON(float64(), values));
+ }
+
+ // Test different round multiples for nearest rounding mode
+ std::vector<std::pair<double, std::string>> multiple_and_expected{{
+ {2, "[0, 2, 14, 50, 116]"},
+ {0.05, "[0, 1, 13, 50, 115]"},
+ {0.1, values},
+ {10, "[0, 0, 10, 50, 120]"},
+ {100, "[0, 0, 0, 100, 100]"},
+ }};
+ this->SetRoundMode(RoundMode::HALF_TOWARDS_INFINITY);
+ for (const auto& pair : multiple_and_expected) {
+ this->SetRoundMultiple(pair.first);
+ this->AssertUnaryOp(RoundToMultiple, values, ArrayFromJSON(float64(), pair.second));
+ }
+}
+
+TYPED_TEST(TestUnaryRoundToMultipleFloating, RoundToMultiple) {
+ this->SetNansEqual(true);
+
+ // Test different rounding modes for integer rounding
+ std::string values("[3.2, 3.5, 3.7, 4.5, -3.2, -3.5, -3.7]");
+ std::vector<std::pair<RoundMode, std::string>> rmode_and_expected{{
+ {RoundMode::DOWN, "[3, 3, 3, 4, -4, -4, -4]"},
+ {RoundMode::UP, "[4, 4, 4, 5, -3, -3, -3]"},
+ {RoundMode::TOWARDS_ZERO, "[3, 3, 3, 4, -3, -3, -3]"},
+ {RoundMode::TOWARDS_INFINITY, "[4, 4, 4, 5, -4, -4, -4]"},
+ {RoundMode::HALF_DOWN, "[3, 3, 4, 4, -3, -4, -4]"},
+ {RoundMode::HALF_UP, "[3, 4, 4, 5, -3, -3, -4]"},
+ {RoundMode::HALF_TOWARDS_ZERO, "[3, 3, 4, 4, -3, -3, -4]"},
+ {RoundMode::HALF_TOWARDS_INFINITY, "[3, 4, 4, 5, -3, -4, -4]"},
+ {RoundMode::HALF_TO_EVEN, "[3, 4, 4, 4, -3, -4, -4]"},
+ {RoundMode::HALF_TO_ODD, "[3, 3, 4, 5, -3, -3, -4]"},
+ }};
+ this->SetRoundMultiple(1);
+ for (const auto& pair : rmode_and_expected) {
+ this->SetRoundMode(pair.first);
+ this->AssertUnaryOp(RoundToMultiple, "[]", "[]");
+ this->AssertUnaryOp(RoundToMultiple, "[null, 0, Inf, -Inf, NaN, -NaN]",
+ "[null, 0, Inf, -Inf, NaN, -NaN]");
+ this->AssertUnaryOp(RoundToMultiple, values, pair.second);
+ }
+
+ // Test different round multiples for nearest rounding mode
+ values = "[320, 3.5, 3.075, 4.5, -3.212, -35.1234, -3.045]";
+ std::vector<std::pair<double, std::string>> multiple_and_expected{{
+ {2, "[320, 4, 4, 4, -4, -36, -4]"},
+ {0.05, "[320, 3.5, 3.1, 4.5, -3.2, -35.1, -3.05]"},
+ {0.1, "[320, 3.5, 3.1, 4.5, -3.2, -35.1, -3]"},
+ {10, "[320, 0, 0, 0, -0, -40, -0]"},
+ {100, "[300, 0, 0, 0, -0, -0, -0]"},
+ }};
+ this->SetRoundMode(RoundMode::HALF_TOWARDS_INFINITY);
+ for (const auto& pair : multiple_and_expected) {
+ this->SetRoundMultiple(pair.first);
+ this->AssertUnaryOp(RoundToMultiple, values, pair.second);
+ }
+
+ this->SetRoundMultiple(-2);
+ this->AssertUnaryOpRaises(RoundToMultiple, values, "multiple must be positive");
+}
+
+TEST(TestBinaryDecimalArithmetic, DispatchBest) {
+ // decimal, floating point
+ for (std::string name : {"add", "subtract", "multiply", "divide"}) {
+ for (std::string suffix : {"", "_checked"}) {
+ name += suffix;
+
+ CheckDispatchBest(name, {decimal128(1, 0), float32()}, {float32(), float32()});
+ CheckDispatchBest(name, {decimal256(1, 0), float64()}, {float64(), float64()});
+ CheckDispatchBest(name, {float32(), decimal256(1, 0)}, {float32(), float32()});
+ CheckDispatchBest(name, {float64(), decimal128(1, 0)}, {float64(), float64()});
+ }
+ }
+
+ // decimal, decimal -> decimal
+ // decimal, integer -> decimal
+ for (std::string name : {"add", "subtract"}) {
+ for (std::string suffix : {"", "_checked"}) {
+ name += suffix;
+
+ CheckDispatchBest(name, {int64(), decimal128(1, 0)},
+ {decimal128(19, 0), decimal128(1, 0)});
+ CheckDispatchBest(name, {decimal128(1, 0), int64()},
+ {decimal128(1, 0), decimal128(19, 0)});
+
+ CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)},
+ {decimal128(2, 1), decimal128(2, 1)});
+ CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)},
+ {decimal256(2, 1), decimal256(2, 1)});
+ CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)},
+ {decimal256(2, 1), decimal256(2, 1)});
+ CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)},
+ {decimal256(2, 1), decimal256(2, 1)});
+
+ CheckDispatchBest(name, {decimal128(2, 0), decimal128(2, 1)},
+ {decimal128(3, 1), decimal128(2, 1)});
+ CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 0)},
+ {decimal128(2, 1), decimal128(3, 1)});
+ }
+ }
+ {
+ std::string name = "multiply";
+ for (std::string suffix : {"", "_checked"}) {
+ name += suffix;
+
+ CheckDispatchBest(name, {int64(), decimal128(1, 0)},
+ {decimal128(19, 0), decimal128(1, 0)});
+ CheckDispatchBest(name, {decimal128(1, 0), int64()},
+ {decimal128(1, 0), decimal128(19, 0)});
+
+ CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)},
+ {decimal128(2, 1), decimal128(2, 1)});
+ CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)},
+ {decimal256(2, 1), decimal256(2, 1)});
+ CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)},
+ {decimal256(2, 1), decimal256(2, 1)});
+ CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)},
+ {decimal256(2, 1), decimal256(2, 1)});
+
+ CheckDispatchBest(name, {decimal128(2, 0), decimal128(2, 1)},
+ {decimal128(2, 0), decimal128(2, 1)});
+ CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 0)},
+ {decimal128(2, 1), decimal128(2, 0)});
+ }
+ }
+ {
+ std::string name = "divide";
+ for (std::string suffix : {"", "_checked"}) {
+ name += suffix;
+ SCOPED_TRACE(name);
+
+ CheckDispatchBest(name, {int64(), decimal128(1, 0)},
+ {decimal128(23, 4), decimal128(1, 0)});
+ CheckDispatchBest(name, {decimal128(1, 0), int64()},
+ {decimal128(21, 20), decimal128(19, 0)});
+
+ CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)},
+ {decimal128(6, 5), decimal128(2, 1)});
+ CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)},
+ {decimal256(6, 5), decimal256(2, 1)});
+ CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)},
+ {decimal256(6, 5), decimal256(2, 1)});
+ CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)},
+ {decimal256(6, 5), decimal256(2, 1)});
+
+ CheckDispatchBest(name, {decimal128(2, 0), decimal128(2, 1)},
+ {decimal128(7, 5), decimal128(2, 1)});
+ CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 0)},
+ {decimal128(5, 4), decimal128(2, 0)});
+ }
+ }
+}
+
+// reference result from bc (precsion=100, scale=40)
+TEST(TestBinaryArithmeticDecimal, AddSubtract) {
+ // array array, decimal128
+ {
+ auto left = ArrayFromJSON(decimal128(30, 3),
+ R"([
+ "1.000",
+ "-123456789012345678901234567.890",
+ "98765432109876543210.987",
+ "-999999999999999999999999999.999"
+ ])");
+ auto right = ArrayFromJSON(decimal128(20, 9),
+ R"([
+ "-1.000000000",
+ "12345678901.234567890",
+ "98765.432101234",
+ "-99999999999.999999999"
+ ])");
+ auto added = ArrayFromJSON(decimal128(37, 9),
+ R"([
+ "0.000000000",
+ "-123456789012345666555555666.655432110",
+ "98765432109876641976.419101234",
+ "-1000000000000000099999999999.998999999"
+ ])");
+ auto subtracted = ArrayFromJSON(decimal128(37, 9),
+ R"([
+ "2.000000000",
+ "-123456789012345691246913469.124567890",
+ "98765432109876444445.554898766",
+ "-999999999999999899999999999.999000001"
+ ])");
+ CheckScalarBinary("add", left, right, added);
+ CheckScalarBinary("subtract", left, right, subtracted);
+ }
+
+ // array array, decimal256
+ {
+ auto left = ArrayFromJSON(decimal256(30, 20),
+ R"([
+ "-1.00000000000000000001",
+ "1234567890.12345678900000000000",
+ "-9876543210.09876543210987654321",
+ "9999999999.99999999999999999999"
+ ])");
+ auto right = ArrayFromJSON(decimal256(30, 10),
+ R"([
+ "1.0000000000",
+ "-1234567890.1234567890",
+ "6789.5432101234",
+ "99999999999999999999.9999999999"
+ ])");
+ auto added = ArrayFromJSON(decimal256(41, 20),
+ R"([
+ "-0.00000000000000000001",
+ "0.00000000000000000000",
+ "-9876536420.55555530870987654321",
+ "100000000009999999999.99999999989999999999"
+ ])");
+ auto subtracted = ArrayFromJSON(decimal256(41, 20),
+ R"([
+ "-2.00000000000000000001",
+ "2469135780.24691357800000000000",
+ "-9876549999.64197555550987654321",
+ "-99999999989999999999.99999999990000000001"
+ ])");
+ CheckScalarBinary("add", left, right, added);
+ CheckScalarBinary("subtract", left, right, subtracted);
+ }
+
+ // scalar array
+ {
+ auto left = ScalarFromJSON(decimal128(6, 1), R"("12345.6")");
+ auto right = ArrayFromJSON(decimal128(10, 3),
+ R"(["1.234", "1234.000", "-9876.543", "666.888"])");
+ auto added = ArrayFromJSON(decimal128(11, 3),
+ R"(["12346.834", "13579.600", "2469.057", "13012.488"])");
+ auto left_sub_right = ArrayFromJSON(
+ decimal128(11, 3), R"(["12344.366", "11111.600", "22222.143", "11678.712"])");
+ auto right_sub_left = ArrayFromJSON(
+ decimal128(11, 3), R"(["-12344.366", "-11111.600", "-22222.143", "-11678.712"])");
+ CheckScalarBinary("add", left, right, added);
+ CheckScalarBinary("add", right, left, added);
+ CheckScalarBinary("subtract", left, right, left_sub_right);
+ CheckScalarBinary("subtract", right, left, right_sub_left);
+ }
+
+ // scalar scalar
+ {
+ auto left = ScalarFromJSON(decimal256(3, 0), R"("666")");
+ auto right = ScalarFromJSON(decimal256(3, 0), R"("888")");
+ auto added = ScalarFromJSON(decimal256(4, 0), R"("1554")");
+ auto subtracted = ScalarFromJSON(decimal256(4, 0), R"("-222")");
+ CheckScalarBinary("add", left, right, added);
+ CheckScalarBinary("subtract", left, right, subtracted);
+ }
+
+ // decimal128 decimal256
+ {
+ auto left = ScalarFromJSON(decimal128(3, 0), R"("666")");
+ auto right = ScalarFromJSON(decimal256(3, 0), R"("888")");
+ auto added = ScalarFromJSON(decimal256(4, 0), R"("1554")");
+ CheckScalarBinary("add", left, right, added);
+ CheckScalarBinary("add", right, left, added);
+ }
+
+ // decimal float
+ {
+ auto left = ScalarFromJSON(decimal128(3, 0), R"("666")");
+ ASSIGN_OR_ABORT(auto right, arrow::MakeScalar(float64(), 888));
+ ASSIGN_OR_ABORT(auto added, arrow::MakeScalar(float64(), 1554));
+ CheckScalarBinary("add", left, right, added);
+ CheckScalarBinary("add", right, left, added);
+ }
+
+ // TODO: decimal integer
+
+ // failed case: result maybe overflow
+ {
+ std::shared_ptr<Scalar> left, right;
+
+ left = ScalarFromJSON(decimal128(21, 20), R"("0.12345678901234567890")");
+ right = ScalarFromJSON(decimal128(21, 1), R"("1.0")");
+ ASSERT_RAISES(Invalid, CallFunction("add", {left, right}));
+ ASSERT_RAISES(Invalid, CallFunction("subtract", {left, right}));
+
+ left = ScalarFromJSON(decimal256(75, 0), R"("0")");
+ right = ScalarFromJSON(decimal256(2, 1), R"("0.0")");
+ ASSERT_RAISES(Invalid, CallFunction("add", {left, right}));
+ ASSERT_RAISES(Invalid, CallFunction("subtract", {left, right}));
+ }
+}
+
+TEST(TestBinaryArithmeticDecimal, Multiply) {
+ // array array, decimal128
+ {
+ auto left = ArrayFromJSON(decimal128(20, 10),
+ R"([
+ "1234567890.1234567890",
+ "-0.0000000001",
+ "-9999999999.9999999999"
+ ])");
+ auto right = ArrayFromJSON(decimal128(13, 3),
+ R"([
+ "1234567890.123",
+ "0.001",
+ "-9999999999.999"
+ ])");
+ auto expected = ArrayFromJSON(decimal128(34, 13),
+ R"([
+ "1524157875323319737.9870903950470",
+ "-0.0000000000001",
+ "99999999999989999999.0000000000001"
+ ])");
+ CheckScalarBinary("multiply", left, right, expected);
+ }
+
+ // array array, decimal26
+ {
+ auto left = ArrayFromJSON(decimal256(30, 3),
+ R"([
+ "123456789012345678901234567.890",
+ "0.000"
+ ])");
+ auto right = ArrayFromJSON(decimal256(20, 9),
+ R"([
+ "-12345678901.234567890",
+ "99999999999.999999999"
+ ])");
+ auto expected = ArrayFromJSON(decimal256(51, 12),
+ R"([
+ "-1524157875323883675034293577501905199.875019052100",
+ "0.000000000000"
+ ])");
+ CheckScalarBinary("multiply", left, right, expected);
+ }
+
+ // scalar array
+ {
+ auto left = ScalarFromJSON(decimal128(3, 2), R"("3.14")");
+ auto right = ArrayFromJSON(decimal128(1, 0), R"(["1", "2", "3", "4", "5"])");
+ auto expected =
+ ArrayFromJSON(decimal128(5, 2), R"(["3.14", "6.28", "9.42", "12.56", "15.70"])");
+ CheckScalarBinary("multiply", left, right, expected);
+ CheckScalarBinary("multiply", right, left, expected);
+ }
+
+ // scalar scalar
+ {
+ auto left = ScalarFromJSON(decimal128(1, 0), R"("1")");
+ auto right = ScalarFromJSON(decimal128(1, 0), R"("1")");
+ auto expected = ScalarFromJSON(decimal128(3, 0), R"("1")");
+ CheckScalarBinary("multiply", left, right, expected);
+ }
+
+ // decimal128 decimal256
+ {
+ auto left = ScalarFromJSON(decimal128(3, 2), R"("6.66")");
+ auto right = ScalarFromJSON(decimal256(3, 1), R"("88.8")");
+ auto expected = ScalarFromJSON(decimal256(7, 3), R"("591.408")");
+ CheckScalarBinary("multiply", left, right, expected);
+ CheckScalarBinary("multiply", right, left, expected);
+ }
+
+ // decimal float
+ {
+ auto left = ScalarFromJSON(decimal128(3, 0), R"("666")");
+ ASSIGN_OR_ABORT(auto right, arrow::MakeScalar(float64(), 888));
+ ASSIGN_OR_ABORT(auto expected, arrow::MakeScalar(float64(), 591408));
+ CheckScalarBinary("multiply", left, right, expected);
+ CheckScalarBinary("multiply", right, left, expected);
+ }
+
+ // TODO: decimal integer
+
+ // failed case: result maybe overflow
+ {
+ auto left = ScalarFromJSON(decimal128(20, 0), R"("1")");
+ auto right = ScalarFromJSON(decimal128(18, 1), R"("1.0")");
+ ASSERT_RAISES(Invalid, CallFunction("multiply", {left, right}));
+ }
+}
+
+TEST(TestBinaryArithmeticDecimal, Divide) {
+ // array array, decimal128
+ {
+ auto left = ArrayFromJSON(decimal128(13, 3), R"(["1234567890.123", "0.001"])");
+ auto right = ArrayFromJSON(decimal128(3, 0), R"(["-987", "999"])");
+ auto expected =
+ ArrayFromJSON(decimal128(17, 7), R"(["-1250828.6627386", "0.0000010"])");
+ CheckScalarBinary("divide", left, right, expected);
+ }
+
+ // array array, decimal256
+ {
+ auto left = ArrayFromJSON(decimal256(20, 10),
+ R"(["1234567890.1234567890", "9999999999.9999999999"])");
+ auto right = ArrayFromJSON(decimal256(13, 3), R"(["1234567890.123", "0.001"])");
+ auto expected = ArrayFromJSON(
+ decimal256(34, 21),
+ R"(["1.000000000000369999093", "9999999999999.999999900000000000000"])");
+ CheckScalarBinary("divide", left, right, expected);
+ }
+
+ // scalar array
+ {
+ auto left = ScalarFromJSON(decimal128(1, 0), R"("1")");
+ auto right = ArrayFromJSON(decimal128(1, 0), R"(["1", "2", "3", "4"])");
+ auto left_div_right =
+ ArrayFromJSON(decimal128(5, 4), R"(["1.0000", "0.5000", "0.3333", "0.2500"])");
+ auto right_div_left =
+ ArrayFromJSON(decimal128(5, 4), R"(["1.0000", "2.0000", "3.0000", "4.0000"])");
+ CheckScalarBinary("divide", left, right, left_div_right);
+ CheckScalarBinary("divide", right, left, right_div_left);
+ }
+
+ // scalar scalar
+ {
+ auto left = ScalarFromJSON(decimal256(6, 5), R"("2.71828")");
+ auto right = ScalarFromJSON(decimal256(6, 5), R"("3.14159")");
+ auto expected = ScalarFromJSON(decimal256(13, 7), R"("0.8652561")");
+ CheckScalarBinary("divide", left, right, expected);
+ }
+
+ // decimal128 decimal256
+ {
+ auto left = ScalarFromJSON(decimal256(6, 5), R"("2.71828")");
+ auto right = ScalarFromJSON(decimal128(6, 5), R"("3.14159")");
+ auto left_div_right = ScalarFromJSON(decimal256(13, 7), R"("0.8652561")");
+ auto right_div_left = ScalarFromJSON(decimal256(13, 7), R"("1.1557271")");
+ CheckScalarBinary("divide", left, right, left_div_right);
+ CheckScalarBinary("divide", right, left, right_div_left);
+ }
+
+ // decimal float
+ {
+ auto left = ScalarFromJSON(decimal128(3, 0), R"("100")");
+ ASSIGN_OR_ABORT(auto right, arrow::MakeScalar(float64(), 50));
+ ASSIGN_OR_ABORT(auto left_div_right, arrow::MakeScalar(float64(), 2));
+ ASSIGN_OR_ABORT(auto right_div_left, arrow::MakeScalar(float64(), 0.5));
+ CheckScalarBinary("divide", left, right, left_div_right);
+ CheckScalarBinary("divide", right, left, right_div_left);
+ }
+
+ // TODO: decimal integer
+
+ // failed case: result maybe overflow
+ {
+ auto left = ScalarFromJSON(decimal128(20, 20), R"("0.12345678901234567890")");
+ auto right = ScalarFromJSON(decimal128(20, 0), R"("12345678901234567890")");
+ ASSERT_RAISES(Invalid, CallFunction("divide", {left, right}));
+ }
+
+ // failed case: divide by 0
+ {
+ auto left = ScalarFromJSON(decimal256(1, 0), R"("1")");
+ auto right = ScalarFromJSON(decimal256(1, 0), R"("0")");
+ ASSERT_RAISES(Invalid, CallFunction("divide", {left, right}));
+ }
+}
+
+TYPED_TEST(TestBinaryArithmeticIntegral, ShiftLeft) {
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+
+ this->AssertBinop(ShiftLeft, "[]", "[]", "[]");
+ this->AssertBinop(ShiftLeft, "[0, 1, 2, 3]", "[2, 3, 4, 5]", "[0, 8, 32, 96]");
+ // Nulls on one side
+ this->AssertBinop(ShiftLeft, "[0, null, 2, 3]", "[2, 3, 4, 5]", "[0, null, 32, 96]");
+ this->AssertBinop(ShiftLeft, "[0, 1, 2, 3]", "[2, 3, null, 5]", "[0, 8, null, 96]");
+ // Nulls on both sides
+ this->AssertBinop(ShiftLeft, "[0, null, 2, 3]", "[2, 3, null, 5]",
+ "[0, null, null, 96]");
+ // All nulls
+ this->AssertBinop(ShiftLeft, "[null]", "[null]", "[null]");
+
+ // Scalar on the left
+ this->AssertBinop(ShiftLeft, 2, "[null, 5]", "[null, 64]");
+ this->AssertBinop(ShiftLeft, this->MakeNullScalar(), "[null, 5]", "[null, null]");
+ // Scalar on the right
+ this->AssertBinop(ShiftLeft, "[null, 5]", 3, "[null, 40]");
+ this->AssertBinop(ShiftLeft, "[null, 5]", this->MakeNullScalar(), "[null, null]");
+ }
+}
+
+TYPED_TEST(TestBinaryArithmeticIntegral, ShiftRight) {
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+
+ this->AssertBinop(ShiftRight, "[]", "[]", "[]");
+ this->AssertBinop(ShiftRight, "[0, 1, 4, 8]", "[1, 1, 1, 4]", "[0, 0, 2, 0]");
+ // Nulls on one side
+ this->AssertBinop(ShiftRight, "[0, null, 4, 8]", "[1, 1, 1, 4]", "[0, null, 2, 0]");
+ this->AssertBinop(ShiftRight, "[0, 1, 4, 8]", "[1, 1, null, 4]", "[0, 0, null, 0]");
+ // Nulls on both sides
+ this->AssertBinop(ShiftRight, "[0, null, 4, 8]", "[1, 1, null, 4]",
+ "[0, null, null, 0]");
+ // All nulls
+ this->AssertBinop(ShiftRight, "[null]", "[null]", "[null]");
+
+ // Scalar on the left
+ this->AssertBinop(ShiftRight, 64, "[null, 2, 6]", "[null, 16, 1]");
+ this->AssertBinop(ShiftRight, this->MakeNullScalar(), "[null, 2, 6]",
+ "[null, null, null]");
+ // Scalar on the right
+ this->AssertBinop(ShiftRight, "[null, 3, 96]", 3, "[null, 0, 12]");
+ this->AssertBinop(ShiftRight, "[null, 3, 96]", this->MakeNullScalar(),
+ "[null, null, null]");
+ }
+}
+
+TYPED_TEST(TestBinaryArithmeticSigned, ShiftLeftOverflowRaises) {
+ using CType = typename TestFixture::CType;
+ const CType bit_width = static_cast<CType>(std::numeric_limits<CType>::digits);
+ const CType min = std::numeric_limits<CType>::min();
+ this->SetOverflowCheck(true);
+
+ this->AssertBinop(ShiftLeft, "[1]", MakeArray(bit_width - 1),
+ MakeArray(static_cast<CType>(1) << (bit_width - 1)));
+ this->AssertBinop(ShiftLeft, "[2]", MakeArray(bit_width - 2),
+ MakeArray(static_cast<CType>(1) << (bit_width - 1)));
+ // Shift a bit into the sign bit
+ this->AssertBinop(ShiftLeft, "[2]", MakeArray(bit_width - 1), MakeArray(min));
+ // Shift a bit past the sign bit
+ this->AssertBinop(ShiftLeft, "[4]", MakeArray(bit_width - 1), "[0]");
+ this->AssertBinop(ShiftLeft, MakeArray(min), "[1]", "[0]");
+ this->AssertBinopRaises(ShiftLeft, "[1, 2]", "[1, -1]",
+ "shift amount must be >= 0 and less than precision of type");
+ this->AssertBinopRaises(ShiftLeft, "[1]", MakeArray(bit_width),
+ "shift amount must be >= 0 and less than precision of type");
+
+ this->SetOverflowCheck(false);
+ this->AssertBinop(ShiftLeft, "[1, 1]", MakeArray(-1, bit_width), "[1, 1]");
+}
+
+TYPED_TEST(TestBinaryArithmeticSigned, ShiftRightOverflowRaises) {
+ using CType = typename TestFixture::CType;
+ const CType bit_width = static_cast<CType>(std::numeric_limits<CType>::digits);
+ const CType max = std::numeric_limits<CType>::max();
+ const CType min = std::numeric_limits<CType>::min();
+ this->SetOverflowCheck(true);
+
+ this->AssertBinop(ShiftRight, MakeArray(max), MakeArray(bit_width - 1), "[1]");
+ this->AssertBinop(ShiftRight, "[-1, -1]", "[1, 5]", "[-1, -1]");
+ this->AssertBinop(ShiftRight, MakeArray(min), "[1]", MakeArray(min / 2));
+ this->AssertBinopRaises(ShiftRight, "[1, 2]", "[1, -1]",
+ "shift amount must be >= 0 and less than precision of type");
+ this->AssertBinopRaises(ShiftRight, "[1]", MakeArray(bit_width),
+ "shift amount must be >= 0 and less than precision of type");
+
+ this->SetOverflowCheck(false);
+ this->AssertBinop(ShiftRight, "[1, 1]", MakeArray(-1, bit_width), "[1, 1]");
+}
+
+TYPED_TEST(TestBinaryArithmeticUnsigned, ShiftLeftOverflowRaises) {
+ using CType = typename TestFixture::CType;
+ const CType bit_width = static_cast<CType>(std::numeric_limits<CType>::digits);
+ this->SetOverflowCheck(true);
+
+ this->AssertBinop(ShiftLeft, "[1]", MakeArray(bit_width - 1),
+ MakeArray(static_cast<CType>(1) << (bit_width - 1)));
+ this->AssertBinop(ShiftLeft, "[2]", MakeArray(bit_width - 2),
+ MakeArray(static_cast<CType>(1) << (bit_width - 1)));
+ this->AssertBinop(ShiftLeft, "[2]", MakeArray(bit_width - 1), "[0]");
+ this->AssertBinop(ShiftLeft, "[4]", MakeArray(bit_width - 1), "[0]");
+ this->AssertBinopRaises(ShiftLeft, "[1]", MakeArray(bit_width),
+ "shift amount must be >= 0 and less than precision of type");
+}
+
+TYPED_TEST(TestBinaryArithmeticUnsigned, ShiftRightOverflowRaises) {
+ using CType = typename TestFixture::CType;
+ const CType bit_width = static_cast<CType>(std::numeric_limits<CType>::digits);
+ const CType max = std::numeric_limits<CType>::max();
+ this->SetOverflowCheck(true);
+
+ this->AssertBinop(ShiftRight, MakeArray(max), MakeArray(bit_width - 1), "[1]");
+ this->AssertBinopRaises(ShiftRight, "[1]", MakeArray(bit_width),
+ "shift amount must be >= 0 and less than precision of type");
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, TrigSin) {
+ this->SetNansEqual(true);
+ this->AssertUnaryOp(Sin, "[Inf, -Inf]", "[NaN, NaN]");
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ this->AssertUnaryOp(Sin, "[]", "[]");
+ this->AssertUnaryOp(Sin, "[null, NaN]", "[null, NaN]");
+ this->AssertUnaryOp(Sin, MakeArray(0, M_PI_2, M_PI), "[0, 1, 0]");
+ }
+ this->AssertUnaryOpRaises(Sin, "[Inf, -Inf]", "domain error");
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, TrigCos) {
+ this->SetNansEqual(true);
+ this->AssertUnaryOp(Cos, "[Inf, -Inf]", "[NaN, NaN]");
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ this->AssertUnaryOp(Cos, "[]", "[]");
+ this->AssertUnaryOp(Cos, "[null, NaN]", "[null, NaN]");
+ this->AssertUnaryOp(Cos, MakeArray(0, M_PI_2, M_PI), "[1, 0, -1]");
+ }
+ this->AssertUnaryOpRaises(Cos, "[Inf, -Inf]", "domain error");
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, TrigTan) {
+ this->SetNansEqual(true);
+ this->AssertUnaryOp(Tan, "[Inf, -Inf]", "[NaN, NaN]");
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ this->AssertUnaryOp(Tan, "[]", "[]");
+ this->AssertUnaryOp(Tan, "[null, NaN]", "[null, NaN]");
+ // N.B. pi/2 isn't representable exactly -> there are no poles
+ // (i.e. tan(pi/2) is merely a large value and not +Inf)
+ this->AssertUnaryOp(Tan, MakeArray(0, M_PI), "[0, 0]");
+ }
+ this->AssertUnaryOpRaises(Tan, "[Inf, -Inf]", "domain error");
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, TrigAsin) {
+ this->SetNansEqual(true);
+ this->AssertUnaryOp(Asin, "[Inf, -Inf, -2, 2]", "[NaN, NaN, NaN, NaN]");
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ this->AssertUnaryOp(Asin, "[]", "[]");
+ this->AssertUnaryOp(Asin, "[null, NaN]", "[null, NaN]");
+ this->AssertUnaryOp(Asin, "[0, 1, -1]", MakeArray(0, M_PI_2, -M_PI_2));
+ }
+ this->AssertUnaryOpRaises(Asin, "[Inf, -Inf, -2, 2]", "domain error");
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, TrigAcos) {
+ this->SetNansEqual(true);
+ this->AssertUnaryOp(Asin, "[Inf, -Inf, -2, 2]", "[NaN, NaN, NaN, NaN]");
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ this->AssertUnaryOp(Acos, "[]", "[]");
+ this->AssertUnaryOp(Acos, "[null, NaN]", "[null, NaN]");
+ this->AssertUnaryOp(Acos, "[0, 1, -1]", MakeArray(M_PI_2, 0, M_PI));
+ }
+ this->AssertUnaryOpRaises(Acos, "[Inf, -Inf, -2, 2]", "domain error");
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, TrigAtan) {
+ this->SetNansEqual(true);
+ auto atan = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
+ return Atan(arg, ctx);
+ };
+ this->AssertUnaryOp(atan, "[]", "[]");
+ this->AssertUnaryOp(atan, "[null, NaN]", "[null, NaN]");
+ this->AssertUnaryOp(atan, "[0, 1, -1, Inf, -Inf]",
+ MakeArray(0, M_PI_4, -M_PI_4, M_PI_2, -M_PI_2));
+}
+
+TYPED_TEST(TestBinaryArithmeticFloating, TrigAtan2) {
+ this->SetNansEqual(true);
+ auto atan2 = [](const Datum& y, const Datum& x, ArithmeticOptions, ExecContext* ctx) {
+ return Atan2(y, x, ctx);
+ };
+ this->AssertBinop(atan2, "[]", "[]", "[]");
+ this->AssertBinop(atan2, "[0, 0, null, NaN]", "[null, NaN, 0, 0]",
+ "[null, NaN, null, NaN]");
+ this->AssertBinop(atan2, "[0, 0, -0.0, 0, -0.0, 0, 1, 0, -1, Inf, -Inf, 0, 0]",
+ "[0, 0, 0, -0.0, -0.0, 1, 0, -1, 0, 0, 0, Inf, -Inf]",
+ MakeArray(0, 0, -0.0, M_PI, -M_PI, 0, M_PI_2, M_PI, -M_PI_2, M_PI_2,
+ -M_PI_2, 0, M_PI));
+}
+
+TYPED_TEST(TestUnaryArithmeticIntegral, Trig) {
+ // Integer arguments promoted to double, sanity check here
+ auto atan = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
+ return Atan(arg, ctx);
+ };
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ this->AssertUnaryOp(Sin, "[0, 1]",
+ ArrayFromJSON(float64(), "[0, 0.8414709848078965]"));
+ this->AssertUnaryOp(Cos, "[0, 1]",
+ ArrayFromJSON(float64(), "[1, 0.5403023058681398]"));
+ this->AssertUnaryOp(Tan, "[0, 1]",
+ ArrayFromJSON(float64(), "[0, 1.5574077246549023]"));
+ this->AssertUnaryOp(Asin, "[0, 1]", ArrayFromJSON(float64(), MakeArray(0, M_PI_2)));
+ this->AssertUnaryOp(Acos, "[0, 1]", ArrayFromJSON(float64(), MakeArray(M_PI_2, 0)));
+ this->AssertUnaryOp(atan, "[0, 1]", ArrayFromJSON(float64(), MakeArray(0, M_PI_4)));
+ }
+}
+
+TYPED_TEST(TestBinaryArithmeticIntegral, Trig) {
+ // Integer arguments promoted to double, sanity check here
+ auto ty = this->type_singleton();
+ auto atan2 = [](const Datum& y, const Datum& x, ArithmeticOptions, ExecContext* ctx) {
+ return Atan2(y, x, ctx);
+ };
+ this->AssertBinop(atan2, ArrayFromJSON(ty, "[0, 1]"), ArrayFromJSON(ty, "[1, 0]"),
+ ArrayFromJSON(float64(), MakeArray(0, M_PI_2)));
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, Log) {
+ using CType = typename TestFixture::CType;
+ this->SetNansEqual(true);
+ auto min_val = std::numeric_limits<CType>::min();
+ auto max_val = std::numeric_limits<CType>::max();
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ this->AssertUnaryOp(Ln, "[1, 2.718281828459045, null, NaN, Inf]",
+ "[0, 1, null, NaN, Inf]");
+ // N.B. min() for float types is smallest normal number > 0
+ this->AssertUnaryOp(Ln, min_val, std::log(min_val));
+ this->AssertUnaryOp(Ln, max_val, std::log(max_val));
+ this->AssertUnaryOp(Log10, "[1, 10, null, NaN, Inf]", "[0, 1, null, NaN, Inf]");
+ this->AssertUnaryOp(Log10, min_val, std::log10(min_val));
+ this->AssertUnaryOp(Log10, max_val, std::log10(max_val));
+ this->AssertUnaryOp(Log2, "[1, 2, null, NaN, Inf]", "[0, 1, null, NaN, Inf]");
+ this->AssertUnaryOp(Log2, min_val, std::log2(min_val));
+ this->AssertUnaryOp(Log2, max_val, std::log2(max_val));
+ this->AssertUnaryOp(Log1p, "[0, 1.718281828459045, null, NaN, Inf]",
+ "[0, 1, null, NaN, Inf]");
+ this->AssertUnaryOp(Log1p, min_val, std::log1p(min_val));
+ this->AssertUnaryOp(Log1p, max_val, std::log1p(max_val));
+ }
+ this->SetOverflowCheck(false);
+ this->AssertUnaryOp(Ln, "[-Inf, -1, 0, Inf]", "[NaN, NaN, -Inf, Inf]");
+ this->AssertUnaryOp(Log10, "[-Inf, -1, 0, Inf]", "[NaN, NaN, -Inf, Inf]");
+ this->AssertUnaryOp(Log2, "[-Inf, -1, 0, Inf]", "[NaN, NaN, -Inf, Inf]");
+ this->AssertUnaryOp(Log1p, "[-Inf, -2, -1, Inf]", "[NaN, NaN, -Inf, Inf]");
+ this->SetOverflowCheck(true);
+ this->AssertUnaryOpRaises(Ln, "[0]", "logarithm of zero");
+ this->AssertUnaryOpRaises(Ln, "[-1]", "logarithm of negative number");
+ this->AssertUnaryOpRaises(Ln, "[-Inf]", "logarithm of negative number");
+
+ auto lowest_val = MakeScalar(std::numeric_limits<CType>::lowest());
+ // N.B. RapidJSON on some platforms raises "Number too big to be stored in double" so
+ // don't bounce through JSON
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("logarithm of negative number"),
+ Ln(lowest_val, this->options_));
+ this->AssertUnaryOpRaises(Log10, "[0]", "logarithm of zero");
+ this->AssertUnaryOpRaises(Log10, "[-1]", "logarithm of negative number");
+ this->AssertUnaryOpRaises(Log10, "[-Inf]", "logarithm of negative number");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("logarithm of negative number"),
+ Log10(lowest_val, this->options_));
+ this->AssertUnaryOpRaises(Log2, "[0]", "logarithm of zero");
+ this->AssertUnaryOpRaises(Log2, "[-1]", "logarithm of negative number");
+ this->AssertUnaryOpRaises(Log2, "[-Inf]", "logarithm of negative number");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("logarithm of negative number"),
+ Log2(lowest_val, this->options_));
+ this->AssertUnaryOpRaises(Log1p, "[-1]", "logarithm of zero");
+ this->AssertUnaryOpRaises(Log1p, "[-2]", "logarithm of negative number");
+ this->AssertUnaryOpRaises(Log1p, "[-Inf]", "logarithm of negative number");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("logarithm of negative number"),
+ Log1p(lowest_val, this->options_));
+}
+
+TYPED_TEST(TestUnaryArithmeticIntegral, Log) {
+ // Integer arguments promoted to double, sanity check here
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ this->AssertUnaryOp(Ln, "[1, null]", ArrayFromJSON(float64(), "[0, null]"));
+ this->AssertUnaryOp(Log10, "[1, 10, null]", ArrayFromJSON(float64(), "[0, 1, null]"));
+ this->AssertUnaryOp(Log2, "[1, 2, null]", ArrayFromJSON(float64(), "[0, 1, null]"));
+ this->AssertUnaryOp(Log1p, "[0, null]", ArrayFromJSON(float64(), "[0, null]"));
+ }
+}
+
+TYPED_TEST(TestBinaryArithmeticIntegral, Log) {
+ // Integer arguments promoted to double, sanity check here
+ this->AssertBinop(Logb, "[1, 10, null]", "[10, 10, null]",
+ ArrayFromJSON(float64(), "[0, 1, null]"));
+ this->AssertBinop(Logb, "[1, 2, null]", "[2, 2, null]",
+ ArrayFromJSON(float64(), "[0, 1, null]"));
+ this->AssertBinop(Logb, "[10, 100, null]", this->MakeScalar(10),
+ ArrayFromJSON(float64(), "[1, 2, null]"));
+}
+
+TYPED_TEST(TestBinaryArithmeticFloating, Log) {
+ using CType = typename TestFixture::CType;
+ this->SetNansEqual(true);
+ auto min_val = std::numeric_limits<CType>::min();
+ auto max_val = std::numeric_limits<CType>::max();
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ // N.B. min() for float types is smallest normal number > 0
+ this->AssertBinop(Logb, "[1, 10, null, NaN, Inf]", "[100, 10, null, 2, 10]",
+ "[0, 1, null, NaN, Inf]");
+ this->AssertBinop(Logb, min_val, 10,
+ static_cast<CType>(std::log(min_val) / std::log(10)));
+ this->AssertBinop(Logb, max_val, 10,
+ static_cast<CType>(std::log(max_val) / std::log(10)));
+ }
+ this->AssertBinop(Logb, "[1.0, 10.0, null]", "[10.0, 10.0, null]", "[0.0, 1.0, null]");
+ this->AssertBinop(Logb, "[1.0, 2.0, null]", "[2.0, 2.0, null]", "[0.0, 1.0, null]");
+ this->AssertBinop(Logb, "[10.0, 100.0, 1000.0, null]", this->MakeScalar(10),
+ "[1.0, 2.0, 3.0, null]");
+ this->SetOverflowCheck(false);
+ this->AssertBinop(Logb, "[-Inf, -1, 0, Inf]", this->MakeScalar(10),
+ "[NaN, NaN, -Inf, Inf]");
+ this->AssertBinop(Logb, "[-Inf, -1, 0, Inf]", this->MakeScalar(2),
+ "[NaN, NaN, -Inf, Inf]");
+ this->AssertBinop(Logb, "[-Inf, -1, 0, Inf]", "[2, 10, 0, 0]", "[NaN, NaN, NaN, NaN]");
+ this->AssertBinop(Logb, "[-Inf, -1, 0, Inf]", this->MakeScalar(0),
+ "[NaN, NaN, NaN, NaN]");
+ this->AssertBinop(Logb, "[-Inf, -2, -1, Inf]", this->MakeScalar(2),
+ "[NaN, NaN, NaN, Inf]");
+ this->SetOverflowCheck(true);
+ this->AssertBinopRaises(Logb, "[0]", "[2]", "logarithm of zero");
+ this->AssertBinopRaises(Logb, "[-1]", "[2]", "logarithm of negative number");
+ this->AssertBinopRaises(Logb, "[-Inf]", "[2]", "logarithm of negative number");
+}
+
+TYPED_TEST(TestBinaryArithmeticSigned, Log) {
+ // Integer arguments promoted to double, sanity check here
+ this->SetNansEqual(true);
+ this->SetOverflowCheck(false);
+ this->AssertBinop(Logb, "[-1, 0]", this->MakeScalar(10),
+ ArrayFromJSON(float64(), "[NaN, -Inf]"));
+ this->AssertBinop(Logb, "[-1, 0]", this->MakeScalar(2),
+ ArrayFromJSON(float64(), "[NaN, -Inf]"));
+ this->AssertBinop(Logb, "[10, 100]", this->MakeScalar(-1),
+ ArrayFromJSON(float64(), "[NaN, NaN]"));
+ this->AssertBinop(Logb, "[-1, 0, null]", this->MakeScalar(-1),
+ ArrayFromJSON(float64(), "[NaN, NaN, null]"));
+ this->AssertBinop(Logb, "[10, 100]", this->MakeScalar(0),
+ ArrayFromJSON(float64(), "[0, 0]"));
+ this->SetOverflowCheck(true);
+ this->AssertBinopRaises(Logb, "[0]", "[10]", "logarithm of zero");
+ this->AssertBinopRaises(Logb, "[-1]", "[10]", "logarithm of negative number");
+ this->AssertBinopRaises(Logb, "[10]", "[0]", "logarithm of zero");
+ this->AssertBinopRaises(Logb, "[100]", "[-1]", "logarithm of negative number");
+}
+
+TYPED_TEST(TestUnaryArithmeticSigned, Log) {
+ // Integer arguments promoted to double, sanity check here
+ this->SetNansEqual(true);
+ this->SetOverflowCheck(false);
+ this->AssertUnaryOp(Ln, "[-1, 0]", ArrayFromJSON(float64(), "[NaN, -Inf]"));
+ this->AssertUnaryOp(Log10, "[-1, 0]", ArrayFromJSON(float64(), "[NaN, -Inf]"));
+ this->AssertUnaryOp(Log2, "[-1, 0]", ArrayFromJSON(float64(), "[NaN, -Inf]"));
+ this->AssertUnaryOp(Log1p, "[-2, -1]", ArrayFromJSON(float64(), "[NaN, -Inf]"));
+ this->SetOverflowCheck(true);
+ this->AssertUnaryOpRaises(Ln, "[0]", "logarithm of zero");
+ this->AssertUnaryOpRaises(Ln, "[-1]", "logarithm of negative number");
+ this->AssertUnaryOpRaises(Log10, "[0]", "logarithm of zero");
+ this->AssertUnaryOpRaises(Log10, "[-1]", "logarithm of negative number");
+ this->AssertUnaryOpRaises(Log2, "[0]", "logarithm of zero");
+ this->AssertUnaryOpRaises(Log2, "[-1]", "logarithm of negative number");
+ this->AssertUnaryOpRaises(Log1p, "[-1]", "logarithm of zero");
+ this->AssertUnaryOpRaises(Log1p, "[-2]", "logarithm of negative number");
+}
+
+TYPED_TEST(TestUnaryArithmeticSigned, Sign) {
+ using CType = typename TestFixture::CType;
+ auto min = std::numeric_limits<CType>::min();
+ auto max = std::numeric_limits<CType>::max();
+
+ auto sign = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
+ return Sign(arg, ctx);
+ };
+
+ this->AssertUnaryOp(sign, "[]", ArrayFromJSON(int8(), "[]"));
+ this->AssertUnaryOp(sign, "[null]", ArrayFromJSON(int8(), "[null]"));
+ this->AssertUnaryOp(sign, "[1, null, -10]", ArrayFromJSON(int8(), "[1, null, -1]"));
+ this->AssertUnaryOp(sign, "[0]", ArrayFromJSON(int8(), "[0]"));
+ this->AssertUnaryOp(sign, "[1, 10, 127]", ArrayFromJSON(int8(), "[1, 1, 1]"));
+ this->AssertUnaryOp(sign, "[-1, -10, -127]", ArrayFromJSON(int8(), "[-1, -1, -1]"));
+ this->AssertUnaryOp(sign, this->MakeScalar(min), *arrow::MakeScalar(int8(), -1));
+ this->AssertUnaryOp(sign, this->MakeScalar(max), *arrow::MakeScalar(int8(), 1));
+}
+
+TYPED_TEST(TestUnaryArithmeticUnsigned, Sign) {
+ using CType = typename TestFixture::CType;
+ auto min = std::numeric_limits<CType>::min();
+ auto max = std::numeric_limits<CType>::max();
+
+ auto sign = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
+ return Sign(arg, ctx);
+ };
+
+ this->AssertUnaryOp(sign, "[]", ArrayFromJSON(int8(), "[]"));
+ this->AssertUnaryOp(sign, "[null]", ArrayFromJSON(int8(), "[null]"));
+ this->AssertUnaryOp(sign, "[1, null, 10]", ArrayFromJSON(int8(), "[1, null, 1]"));
+ this->AssertUnaryOp(sign, "[0]", ArrayFromJSON(int8(), "[0]"));
+ this->AssertUnaryOp(sign, "[1, 10, 127]", ArrayFromJSON(int8(), "[1, 1, 1]"));
+ this->AssertUnaryOp(sign, this->MakeScalar(min), *arrow::MakeScalar(int8(), 0));
+ this->AssertUnaryOp(sign, this->MakeScalar(max), *arrow::MakeScalar(int8(), 1));
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, Sign) {
+ using CType = typename TestFixture::CType;
+ auto min = std::numeric_limits<CType>::lowest();
+ auto max = std::numeric_limits<CType>::max();
+
+ this->SetNansEqual(true);
+
+ auto sign = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
+ return Sign(arg, ctx);
+ };
+
+ this->AssertUnaryOp(sign, "[]", "[]");
+ this->AssertUnaryOp(sign, "[null]", "[null]");
+ this->AssertUnaryOp(sign, "[1.3, null, -10.80]", "[1, null, -1]");
+ this->AssertUnaryOp(sign, "[0.0, -0.0]", "[0, 0]");
+ this->AssertUnaryOp(sign, "[1.3, 10.80, 12748.001]", "[1, 1, 1]");
+ this->AssertUnaryOp(sign, "[-1.3, -10.80, -12748.001]", "[-1, -1, -1]");
+ this->AssertUnaryOp(sign, "[Inf, -Inf]", "[1, -1]");
+ this->AssertUnaryOp(sign, "[NaN]", "[NaN]");
+ this->AssertUnaryOp(sign, this->MakeScalar(min), this->MakeScalar(-1));
+ this->AssertUnaryOp(sign, this->MakeScalar(max), this->MakeScalar(1));
+}
+
+TYPED_TEST(TestUnaryArithmeticSigned, Floor) {
+ auto floor = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
+ return Floor(arg, ctx);
+ };
+
+ this->AssertUnaryOp(floor, "[]", ArrayFromJSON(float64(), "[]"));
+ this->AssertUnaryOp(floor, "[null]", ArrayFromJSON(float64(), "[null]"));
+ this->AssertUnaryOp(floor, "[1, null, -10]",
+ ArrayFromJSON(float64(), "[1, null, -10]"));
+ this->AssertUnaryOp(floor, "[0]", ArrayFromJSON(float64(), "[0]"));
+ this->AssertUnaryOp(floor, "[1, 10, 127]", ArrayFromJSON(float64(), "[1, 10, 127]"));
+ this->AssertUnaryOp(floor, "[-1, -10, -127]",
+ ArrayFromJSON(float64(), "[-1, -10, -127]"));
+}
+
+TYPED_TEST(TestUnaryArithmeticUnsigned, Floor) {
+ auto floor = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
+ return Floor(arg, ctx);
+ };
+
+ this->AssertUnaryOp(floor, "[]", ArrayFromJSON(float64(), "[]"));
+ this->AssertUnaryOp(floor, "[null]", ArrayFromJSON(float64(), "[null]"));
+ this->AssertUnaryOp(floor, "[1, null, 10]", ArrayFromJSON(float64(), "[1, null, 10]"));
+ this->AssertUnaryOp(floor, "[0]", ArrayFromJSON(float64(), "[0]"));
+ this->AssertUnaryOp(floor, "[1, 10, 127]", ArrayFromJSON(float64(), "[1, 10, 127]"));
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, Floor) {
+ using CType = typename TestFixture::CType;
+ auto min = std::numeric_limits<CType>::lowest();
+ auto max = std::numeric_limits<CType>::max();
+
+ this->SetNansEqual(true);
+
+ auto floor = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
+ return Floor(arg, ctx);
+ };
+
+ this->AssertUnaryOp(floor, "[]", "[]");
+ this->AssertUnaryOp(floor, "[null]", "[null]");
+ this->AssertUnaryOp(floor, "[1.3, null, -10.80]", "[1, null, -11]");
+ this->AssertUnaryOp(floor, "[0.0, -0.0]", "[0, 0]");
+ this->AssertUnaryOp(floor, "[1.3, 10.80, 12748.001]", "[1, 10, 12748]");
+ this->AssertUnaryOp(floor, "[-1.3, -10.80, -12748.001]", "[-2, -11, -12749]");
+ this->AssertUnaryOp(floor, "[Inf, -Inf]", "[Inf, -Inf]");
+ this->AssertUnaryOp(floor, "[NaN]", "[NaN]");
+ this->AssertUnaryOp(floor, this->MakeScalar(min), this->MakeScalar(min));
+ this->AssertUnaryOp(floor, this->MakeScalar(max), this->MakeScalar(max));
+}
+
+TYPED_TEST(TestUnaryArithmeticSigned, Ceil) {
+ auto ceil = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
+ return Ceil(arg, ctx);
+ };
+
+ this->AssertUnaryOp(ceil, "[]", ArrayFromJSON(float64(), "[]"));
+ this->AssertUnaryOp(ceil, "[null]", ArrayFromJSON(float64(), "[null]"));
+ this->AssertUnaryOp(ceil, "[1, null, -10]", ArrayFromJSON(float64(), "[1, null, -10]"));
+ this->AssertUnaryOp(ceil, "[0]", ArrayFromJSON(float64(), "[0]"));
+ this->AssertUnaryOp(ceil, "[1, 10, 127]", ArrayFromJSON(float64(), "[1, 10, 127]"));
+ this->AssertUnaryOp(ceil, "[-1, -10, -127]",
+ ArrayFromJSON(float64(), "[-1, -10, -127]"));
+}
+
+TYPED_TEST(TestUnaryArithmeticUnsigned, Ceil) {
+ auto ceil = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
+ return Ceil(arg, ctx);
+ };
+
+ this->AssertUnaryOp(ceil, "[]", ArrayFromJSON(float64(), "[]"));
+ this->AssertUnaryOp(ceil, "[null]", ArrayFromJSON(float64(), "[null]"));
+ this->AssertUnaryOp(ceil, "[1, null, 10]", ArrayFromJSON(float64(), "[1, null, 10]"));
+ this->AssertUnaryOp(ceil, "[0]", ArrayFromJSON(float64(), "[0]"));
+ this->AssertUnaryOp(ceil, "[1, 10, 127]", ArrayFromJSON(float64(), "[1, 10, 127]"));
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, Ceil) {
+ using CType = typename TestFixture::CType;
+ auto min = std::numeric_limits<CType>::lowest();
+ auto max = std::numeric_limits<CType>::max();
+
+ this->SetNansEqual(true);
+
+ auto ceil = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
+ return Ceil(arg, ctx);
+ };
+
+ this->AssertUnaryOp(ceil, "[]", "[]");
+ this->AssertUnaryOp(ceil, "[null]", "[null]");
+ this->AssertUnaryOp(ceil, "[1.3, null, -10.80]", "[2, null, -10]");
+ this->AssertUnaryOp(ceil, "[0.0, -0.0]", "[0, 0]");
+ this->AssertUnaryOp(ceil, "[1.3, 10.80, 12748.001]", "[2, 11, 12749]");
+ this->AssertUnaryOp(ceil, "[-1.3, -10.80, -12748.001]", "[-1, -10, -12748]");
+ this->AssertUnaryOp(ceil, "[Inf, -Inf]", "[Inf, -Inf]");
+ this->AssertUnaryOp(ceil, "[NaN]", "[NaN]");
+ this->AssertUnaryOp(ceil, this->MakeScalar(min), this->MakeScalar(min));
+ this->AssertUnaryOp(ceil, this->MakeScalar(max), this->MakeScalar(max));
+}
+
+TYPED_TEST(TestUnaryArithmeticSigned, Trunc) {
+ auto trunc = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
+ return Trunc(arg, ctx);
+ };
+
+ this->AssertUnaryOp(trunc, "[]", ArrayFromJSON(float64(), "[]"));
+ this->AssertUnaryOp(trunc, "[null]", ArrayFromJSON(float64(), "[null]"));
+ this->AssertUnaryOp(trunc, "[1, null, -10]",
+ ArrayFromJSON(float64(), "[1, null, -10]"));
+ this->AssertUnaryOp(trunc, "[0]", ArrayFromJSON(float64(), "[0]"));
+ this->AssertUnaryOp(trunc, "[1, 10, 127]", ArrayFromJSON(float64(), "[1, 10, 127]"));
+ this->AssertUnaryOp(trunc, "[-1, -10, -127]",
+ ArrayFromJSON(float64(), "[-1, -10, -127]"));
+}
+
+TYPED_TEST(TestUnaryArithmeticUnsigned, Trunc) {
+ auto trunc = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
+ return Trunc(arg, ctx);
+ };
+
+ this->AssertUnaryOp(trunc, "[]", ArrayFromJSON(float64(), "[]"));
+ this->AssertUnaryOp(trunc, "[null]", ArrayFromJSON(float64(), "[null]"));
+ this->AssertUnaryOp(trunc, "[1, null, 10]", ArrayFromJSON(float64(), "[1, null, 10]"));
+ this->AssertUnaryOp(trunc, "[0]", ArrayFromJSON(float64(), "[0]"));
+ this->AssertUnaryOp(trunc, "[1, 10, 127]", ArrayFromJSON(float64(), "[1, 10, 127]"));
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, Trunc) {
+ using CType = typename TestFixture::CType;
+ auto min = std::numeric_limits<CType>::lowest();
+ auto max = std::numeric_limits<CType>::max();
+
+ this->SetNansEqual(true);
+
+ auto trunc = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
+ return Trunc(arg, ctx);
+ };
+
+ this->AssertUnaryOp(trunc, "[]", "[]");
+ this->AssertUnaryOp(trunc, "[null]", "[null]");
+ this->AssertUnaryOp(trunc, "[1.3, null, -10.80]", "[1, null, -10]");
+ this->AssertUnaryOp(trunc, "[0.0, -0.0]", "[0, 0]");
+ this->AssertUnaryOp(trunc, "[1.3, 10.80, 12748.001]", "[1, 10, 12748]");
+ this->AssertUnaryOp(trunc, "[-1.3, -10.80, -12748.001]", "[-1, -10, -12748]");
+ this->AssertUnaryOp(trunc, "[Inf, -Inf]", "[Inf, -Inf]");
+ this->AssertUnaryOp(trunc, "[NaN]", "[NaN]");
+ this->AssertUnaryOp(trunc, this->MakeScalar(min), this->MakeScalar(min));
+ this->AssertUnaryOp(trunc, this->MakeScalar(max), this->MakeScalar(max));
+}
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_boolean.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_boolean.cc
new file mode 100644
index 000000000..7a0e3654e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_boolean.cc
@@ -0,0 +1,563 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <array>
+
+#include "arrow/compute/kernels/common.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap.h"
+#include "arrow/util/bitmap_ops.h"
+
+namespace arrow {
+
+using internal::Bitmap;
+
+namespace compute {
+
+namespace {
+
+template <typename ComputeWord>
+void ComputeKleene(ComputeWord&& compute_word, KernelContext* ctx, const ArrayData& left,
+ const ArrayData& right, ArrayData* out) {
+ DCHECK(left.null_count != 0 || right.null_count != 0)
+ << "ComputeKleene is unnecessarily expensive for the non-null case";
+
+ Bitmap left_valid_bm{left.buffers[0], left.offset, left.length};
+ Bitmap left_data_bm{left.buffers[1], left.offset, left.length};
+
+ Bitmap right_valid_bm{right.buffers[0], right.offset, right.length};
+ Bitmap right_data_bm{right.buffers[1], right.offset, right.length};
+
+ std::array<Bitmap, 2> out_bms{Bitmap(out->buffers[0], out->offset, out->length),
+ Bitmap(out->buffers[1], out->offset, out->length)};
+
+ auto apply = [&](uint64_t left_valid, uint64_t left_data, uint64_t right_valid,
+ uint64_t right_data, uint64_t* out_validity, uint64_t* out_data) {
+ auto left_true = left_valid & left_data;
+ auto left_false = left_valid & ~left_data;
+
+ auto right_true = right_valid & right_data;
+ auto right_false = right_valid & ~right_data;
+
+ compute_word(left_true, left_false, right_true, right_false, out_validity, out_data);
+ };
+
+ if (right.null_count == 0) {
+ std::array<Bitmap, 3> in_bms{left_valid_bm, left_data_bm, right_data_bm};
+ Bitmap::VisitWordsAndWrite(
+ in_bms, &out_bms,
+ [&](const std::array<uint64_t, 3>& in, std::array<uint64_t, 2>* out) {
+ apply(in[0], in[1], ~uint64_t(0), in[2], &(out->at(0)), &(out->at(1)));
+ });
+ return;
+ }
+
+ if (left.null_count == 0) {
+ std::array<Bitmap, 3> in_bms{left_data_bm, right_valid_bm, right_data_bm};
+ Bitmap::VisitWordsAndWrite(
+ in_bms, &out_bms,
+ [&](const std::array<uint64_t, 3>& in, std::array<uint64_t, 2>* out) {
+ apply(~uint64_t(0), in[0], in[1], in[2], &(out->at(0)), &(out->at(1)));
+ });
+ return;
+ }
+
+ DCHECK(left.null_count != 0 && right.null_count != 0);
+ std::array<Bitmap, 4> in_bms{left_valid_bm, left_data_bm, right_valid_bm,
+ right_data_bm};
+ Bitmap::VisitWordsAndWrite(
+ in_bms, &out_bms,
+ [&](const std::array<uint64_t, 4>& in, std::array<uint64_t, 2>* out) {
+ apply(in[0], in[1], in[2], in[3], &(out->at(0)), &(out->at(1)));
+ });
+}
+
+inline BooleanScalar InvertScalar(const Scalar& in) {
+ return in.is_valid ? BooleanScalar(!checked_cast<const BooleanScalar&>(in).value)
+ : BooleanScalar();
+}
+
+inline Bitmap GetBitmap(const ArrayData& arr, int index) {
+ return Bitmap{arr.buffers[index], arr.offset, arr.length};
+}
+
+struct InvertOp {
+ static Status Call(KernelContext* ctx, const Scalar& in, Scalar* out) {
+ *checked_cast<BooleanScalar*>(out) = InvertScalar(in);
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& in, ArrayData* out) {
+ GetBitmap(*out, 1).CopyFromInverted(GetBitmap(in, 1));
+ return Status::OK();
+ }
+};
+
+template <typename Op>
+struct Commutative {
+ static Status Call(KernelContext* ctx, const Scalar& left, const ArrayData& right,
+ ArrayData* out) {
+ return Op::Call(ctx, right, left, out);
+ }
+};
+
+struct AndOp : Commutative<AndOp> {
+ using Commutative<AndOp>::Call;
+
+ static Status Call(KernelContext* ctx, const Scalar& left, const Scalar& right,
+ Scalar* out) {
+ if (left.is_valid && right.is_valid) {
+ checked_cast<BooleanScalar*>(out)->value =
+ checked_cast<const BooleanScalar&>(left).value &&
+ checked_cast<const BooleanScalar&>(right).value;
+ }
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& left, const Scalar& right,
+ ArrayData* out) {
+ if (right.is_valid) {
+ checked_cast<const BooleanScalar&>(right).value
+ ? GetBitmap(*out, 1).CopyFrom(GetBitmap(left, 1))
+ : GetBitmap(*out, 1).SetBitsTo(false);
+ }
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& left, const ArrayData& right,
+ ArrayData* out) {
+ ::arrow::internal::BitmapAnd(left.buffers[1]->data(), left.offset,
+ right.buffers[1]->data(), right.offset, right.length,
+ out->offset, out->buffers[1]->mutable_data());
+ return Status::OK();
+ }
+};
+
+struct KleeneAndOp : Commutative<KleeneAndOp> {
+ using Commutative<KleeneAndOp>::Call;
+
+ static Status Call(KernelContext* ctx, const Scalar& left, const Scalar& right,
+ Scalar* out) {
+ bool left_true = left.is_valid && checked_cast<const BooleanScalar&>(left).value;
+ bool left_false = left.is_valid && !checked_cast<const BooleanScalar&>(left).value;
+
+ bool right_true = right.is_valid && checked_cast<const BooleanScalar&>(right).value;
+ bool right_false = right.is_valid && !checked_cast<const BooleanScalar&>(right).value;
+
+ checked_cast<BooleanScalar*>(out)->value = left_true && right_true;
+ out->is_valid = left_false || right_false || (left_true && right_true);
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& left, const Scalar& right,
+ ArrayData* out) {
+ bool right_true = right.is_valid && checked_cast<const BooleanScalar&>(right).value;
+ bool right_false = right.is_valid && !checked_cast<const BooleanScalar&>(right).value;
+
+ if (right_false) {
+ out->null_count = 0;
+ out->buffers[0] = nullptr;
+ GetBitmap(*out, 1).SetBitsTo(false); // all false case
+ return Status::OK();
+ }
+
+ if (right_true) {
+ if (left.GetNullCount() == 0) {
+ out->null_count = 0;
+ out->buffers[0] = nullptr;
+ } else {
+ GetBitmap(*out, 0).CopyFrom(GetBitmap(left, 0));
+ }
+ GetBitmap(*out, 1).CopyFrom(GetBitmap(left, 1));
+ return Status::OK();
+ }
+
+ // scalar was null: out[i] is valid iff left[i] was false
+ if (left.GetNullCount() == 0) {
+ ::arrow::internal::InvertBitmap(left.buffers[1]->data(), left.offset, left.length,
+ out->buffers[0]->mutable_data(), out->offset);
+ } else {
+ ::arrow::internal::BitmapAndNot(left.buffers[0]->data(), left.offset,
+ left.buffers[1]->data(), left.offset, left.length,
+ out->offset, out->buffers[0]->mutable_data());
+ }
+ ::arrow::internal::CopyBitmap(left.buffers[1]->data(), left.offset, left.length,
+ out->buffers[1]->mutable_data(), out->offset);
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& left, const ArrayData& right,
+ ArrayData* out) {
+ if (left.GetNullCount() == 0 && right.GetNullCount() == 0) {
+ out->null_count = 0;
+ // Kleene kernels have validity bitmap pre-allocated. Therefore, set it to 1
+ BitUtil::SetBitmap(out->buffers[0]->mutable_data(), out->offset, out->length);
+ return AndOp::Call(ctx, left, right, out);
+ }
+ auto compute_word = [](uint64_t left_true, uint64_t left_false, uint64_t right_true,
+ uint64_t right_false, uint64_t* out_valid,
+ uint64_t* out_data) {
+ *out_data = left_true & right_true;
+ *out_valid = left_false | right_false | (left_true & right_true);
+ };
+ ComputeKleene(compute_word, ctx, left, right, out);
+ return Status::OK();
+ }
+};
+
+struct OrOp : Commutative<OrOp> {
+ using Commutative<OrOp>::Call;
+
+ static Status Call(KernelContext* ctx, const Scalar& left, const Scalar& right,
+ Scalar* out) {
+ if (left.is_valid && right.is_valid) {
+ checked_cast<BooleanScalar*>(out)->value =
+ checked_cast<const BooleanScalar&>(left).value ||
+ checked_cast<const BooleanScalar&>(right).value;
+ }
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& left, const Scalar& right,
+ ArrayData* out) {
+ if (right.is_valid) {
+ checked_cast<const BooleanScalar&>(right).value
+ ? GetBitmap(*out, 1).SetBitsTo(true)
+ : GetBitmap(*out, 1).CopyFrom(GetBitmap(left, 1));
+ }
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& left, const ArrayData& right,
+ ArrayData* out) {
+ ::arrow::internal::BitmapOr(left.buffers[1]->data(), left.offset,
+ right.buffers[1]->data(), right.offset, right.length,
+ out->offset, out->buffers[1]->mutable_data());
+ return Status::OK();
+ }
+};
+
+struct KleeneOrOp : Commutative<KleeneOrOp> {
+ using Commutative<KleeneOrOp>::Call;
+
+ static Status Call(KernelContext* ctx, const Scalar& left, const Scalar& right,
+ Scalar* out) {
+ bool left_true = left.is_valid && checked_cast<const BooleanScalar&>(left).value;
+ bool left_false = left.is_valid && !checked_cast<const BooleanScalar&>(left).value;
+
+ bool right_true = right.is_valid && checked_cast<const BooleanScalar&>(right).value;
+ bool right_false = right.is_valid && !checked_cast<const BooleanScalar&>(right).value;
+
+ checked_cast<BooleanScalar*>(out)->value = left_true || right_true;
+ out->is_valid = left_true || right_true || (left_false && right_false);
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& left, const Scalar& right,
+ ArrayData* out) {
+ bool right_true = right.is_valid && checked_cast<const BooleanScalar&>(right).value;
+ bool right_false = right.is_valid && !checked_cast<const BooleanScalar&>(right).value;
+
+ if (right_true) {
+ out->null_count = 0;
+ out->buffers[0] = nullptr;
+ GetBitmap(*out, 1).SetBitsTo(true); // all true case
+ return Status::OK();
+ }
+
+ if (right_false) {
+ if (left.GetNullCount() == 0) {
+ out->null_count = 0;
+ out->buffers[0] = nullptr;
+ } else {
+ GetBitmap(*out, 0).CopyFrom(GetBitmap(left, 0));
+ }
+ GetBitmap(*out, 1).CopyFrom(GetBitmap(left, 1));
+ return Status::OK();
+ }
+
+ // scalar was null: out[i] is valid iff left[i] was true
+ if (left.GetNullCount() == 0) {
+ ::arrow::internal::CopyBitmap(left.buffers[1]->data(), left.offset, left.length,
+ out->buffers[0]->mutable_data(), out->offset);
+ } else {
+ ::arrow::internal::BitmapAnd(left.buffers[0]->data(), left.offset,
+ left.buffers[1]->data(), left.offset, left.length,
+ out->offset, out->buffers[0]->mutable_data());
+ }
+ ::arrow::internal::CopyBitmap(left.buffers[1]->data(), left.offset, left.length,
+ out->buffers[1]->mutable_data(), out->offset);
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& left, const ArrayData& right,
+ ArrayData* out) {
+ if (left.GetNullCount() == 0 && right.GetNullCount() == 0) {
+ out->null_count = 0;
+ // Kleene kernels have validity bitmap pre-allocated. Therefore, set it to 1
+ BitUtil::SetBitmap(out->buffers[0]->mutable_data(), out->offset, out->length);
+ return OrOp::Call(ctx, left, right, out);
+ }
+
+ static auto compute_word = [](uint64_t left_true, uint64_t left_false,
+ uint64_t right_true, uint64_t right_false,
+ uint64_t* out_valid, uint64_t* out_data) {
+ *out_data = left_true | right_true;
+ *out_valid = left_true | right_true | (left_false & right_false);
+ };
+
+ ComputeKleene(compute_word, ctx, left, right, out);
+ return Status::OK();
+ }
+};
+
+struct XorOp : Commutative<XorOp> {
+ using Commutative<XorOp>::Call;
+
+ static Status Call(KernelContext* ctx, const Scalar& left, const Scalar& right,
+ Scalar* out) {
+ if (left.is_valid && right.is_valid) {
+ checked_cast<BooleanScalar*>(out)->value =
+ checked_cast<const BooleanScalar&>(left).value ^
+ checked_cast<const BooleanScalar&>(right).value;
+ }
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& left, const Scalar& right,
+ ArrayData* out) {
+ if (right.is_valid) {
+ checked_cast<const BooleanScalar&>(right).value
+ ? GetBitmap(*out, 1).CopyFromInverted(GetBitmap(left, 1))
+ : GetBitmap(*out, 1).CopyFrom(GetBitmap(left, 1));
+ }
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& left, const ArrayData& right,
+ ArrayData* out) {
+ ::arrow::internal::BitmapXor(left.buffers[1]->data(), left.offset,
+ right.buffers[1]->data(), right.offset, right.length,
+ out->offset, out->buffers[1]->mutable_data());
+ return Status::OK();
+ }
+};
+
+struct AndNotOp {
+ static Status Call(KernelContext* ctx, const Scalar& left, const Scalar& right,
+ Scalar* out) {
+ return AndOp::Call(ctx, left, InvertScalar(right), out);
+ }
+
+ static Status Call(KernelContext* ctx, const Scalar& left, const ArrayData& right,
+ ArrayData* out) {
+ if (left.is_valid) {
+ checked_cast<const BooleanScalar&>(left).value
+ ? GetBitmap(*out, 1).CopyFromInverted(GetBitmap(right, 1))
+ : GetBitmap(*out, 1).SetBitsTo(false);
+ }
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& left, const Scalar& right,
+ ArrayData* out) {
+ return AndOp::Call(ctx, left, InvertScalar(right), out);
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& left, const ArrayData& right,
+ ArrayData* out) {
+ ::arrow::internal::BitmapAndNot(left.buffers[1]->data(), left.offset,
+ right.buffers[1]->data(), right.offset, right.length,
+ out->offset, out->buffers[1]->mutable_data());
+ return Status::OK();
+ }
+};
+
+struct KleeneAndNotOp {
+ static Status Call(KernelContext* ctx, const Scalar& left, const Scalar& right,
+ Scalar* out) {
+ return KleeneAndOp::Call(ctx, left, InvertScalar(right), out);
+ }
+
+ static Status Call(KernelContext* ctx, const Scalar& left, const ArrayData& right,
+ ArrayData* out) {
+ bool left_true = left.is_valid && checked_cast<const BooleanScalar&>(left).value;
+ bool left_false = left.is_valid && !checked_cast<const BooleanScalar&>(left).value;
+
+ if (left_false) {
+ out->null_count = 0;
+ out->buffers[0] = nullptr;
+ GetBitmap(*out, 1).SetBitsTo(false); // all false case
+ return Status::OK();
+ }
+
+ if (left_true) {
+ if (right.GetNullCount() == 0) {
+ out->null_count = 0;
+ out->buffers[0] = nullptr;
+ } else {
+ GetBitmap(*out, 0).CopyFrom(GetBitmap(right, 0));
+ }
+ GetBitmap(*out, 1).CopyFromInverted(GetBitmap(right, 1));
+ return Status::OK();
+ }
+
+ // scalar was null: out[i] is valid iff right[i] was true
+ if (right.GetNullCount() == 0) {
+ ::arrow::internal::CopyBitmap(right.buffers[1]->data(), right.offset, right.length,
+ out->buffers[0]->mutable_data(), out->offset);
+ } else {
+ ::arrow::internal::BitmapAnd(right.buffers[0]->data(), right.offset,
+ right.buffers[1]->data(), right.offset, right.length,
+ out->offset, out->buffers[0]->mutable_data());
+ }
+ ::arrow::internal::InvertBitmap(right.buffers[1]->data(), right.offset, right.length,
+ out->buffers[1]->mutable_data(), out->offset);
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& left, const Scalar& right,
+ ArrayData* out) {
+ return KleeneAndOp::Call(ctx, left, InvertScalar(right), out);
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& left, const ArrayData& right,
+ ArrayData* out) {
+ if (left.GetNullCount() == 0 && right.GetNullCount() == 0) {
+ out->null_count = 0;
+ // Kleene kernels have validity bitmap pre-allocated. Therefore, set it to 1
+ BitUtil::SetBitmap(out->buffers[0]->mutable_data(), out->offset, out->length);
+ return AndNotOp::Call(ctx, left, right, out);
+ }
+
+ static auto compute_word = [](uint64_t left_true, uint64_t left_false,
+ uint64_t right_true, uint64_t right_false,
+ uint64_t* out_valid, uint64_t* out_data) {
+ *out_data = left_true & right_false;
+ *out_valid = left_false | right_true | (left_true & right_false);
+ };
+
+ ComputeKleene(compute_word, ctx, left, right, out);
+ return Status::OK();
+ }
+};
+
+void MakeFunction(const std::string& name, int arity, ArrayKernelExec exec,
+ const FunctionDoc* doc, FunctionRegistry* registry,
+ NullHandling::type null_handling = NullHandling::INTERSECTION) {
+ auto func = std::make_shared<ScalarFunction>(name, Arity(arity), doc);
+
+ // Scalar arguments not yet supported
+ std::vector<InputType> in_types(arity, InputType(boolean()));
+ ScalarKernel kernel(std::move(in_types), boolean(), exec);
+ kernel.null_handling = null_handling;
+
+ DCHECK_OK(func->AddKernel(kernel));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+const FunctionDoc invert_doc{"Invert boolean values", "", {"values"}};
+
+const FunctionDoc and_doc{
+ "Logical 'and' boolean values",
+ ("When a null is encountered in either input, a null is output.\n"
+ "For a different null behavior, see function \"and_kleene\"."),
+ {"x", "y"}};
+
+const FunctionDoc and_not_doc{
+ "Logical 'and not' boolean values",
+ ("When a null is encountered in either input, a null is output.\n"
+ "For a different null behavior, see function \"and_not_kleene\"."),
+ {"x", "y"}};
+
+const FunctionDoc or_doc{
+ "Logical 'or' boolean values",
+ ("When a null is encountered in either input, a null is output.\n"
+ "For a different null behavior, see function \"or_kleene\"."),
+ {"x", "y"}};
+
+const FunctionDoc xor_doc{
+ "Logical 'xor' boolean values",
+ ("When a null is encountered in either input, a null is output."),
+ {"x", "y"}};
+
+const FunctionDoc and_kleene_doc{
+ "Logical 'and' boolean values (Kleene logic)",
+ ("This function behaves as follows with nulls:\n\n"
+ "- true and null = null\n"
+ "- null and true = null\n"
+ "- false and null = false\n"
+ "- null and false = false\n"
+ "- null and null = null\n"
+ "\n"
+ "In other words, in this context a null value really means \"unknown\",\n"
+ "and an unknown value 'and' false is always false.\n"
+ "For a different null behavior, see function \"and\"."),
+ {"x", "y"}};
+
+const FunctionDoc and_not_kleene_doc{
+ "Logical 'and not' boolean values (Kleene logic)",
+ ("This function behaves as follows with nulls:\n\n"
+ "- true and null = null\n"
+ "- null and false = null\n"
+ "- false and null = false\n"
+ "- null and true = false\n"
+ "- null and null = null\n"
+ "\n"
+ "In other words, in this context a null value really means \"unknown\",\n"
+ "and an unknown value 'and not' true is always false, as is false\n"
+ "'and not' an unknown value.\n"
+ "For a different null behavior, see function \"and_not\"."),
+ {"x", "y"}};
+
+const FunctionDoc or_kleene_doc{
+ "Logical 'or' boolean values (Kleene logic)",
+ ("This function behaves as follows with nulls:\n\n"
+ "- true or null = true\n"
+ "- null and true = true\n"
+ "- false and null = null\n"
+ "- null and false = null\n"
+ "- null and null = null\n"
+ "\n"
+ "In other words, in this context a null value really means \"unknown\",\n"
+ "and an unknown value 'or' true is always true.\n"
+ "For a different null behavior, see function \"and\"."),
+ {"x", "y"}};
+
+} // namespace
+
+namespace internal {
+
+void RegisterScalarBoolean(FunctionRegistry* registry) {
+ // These functions can write into sliced output bitmaps
+ MakeFunction("invert", 1, applicator::SimpleUnary<InvertOp>, &invert_doc, registry);
+ MakeFunction("and", 2, applicator::SimpleBinary<AndOp>, &and_doc, registry);
+ MakeFunction("and_not", 2, applicator::SimpleBinary<AndNotOp>, &and_not_doc, registry);
+ MakeFunction("or", 2, applicator::SimpleBinary<OrOp>, &or_doc, registry);
+ MakeFunction("xor", 2, applicator::SimpleBinary<XorOp>, &xor_doc, registry);
+
+ MakeFunction("and_kleene", 2, applicator::SimpleBinary<KleeneAndOp>, &and_kleene_doc,
+ registry, NullHandling::COMPUTED_PREALLOCATE);
+ MakeFunction("and_not_kleene", 2, applicator::SimpleBinary<KleeneAndNotOp>,
+ &and_not_kleene_doc, registry, NullHandling::COMPUTED_PREALLOCATE);
+ MakeFunction("or_kleene", 2, applicator::SimpleBinary<KleeneOrOp>, &or_kleene_doc,
+ registry, NullHandling::COMPUTED_PREALLOCATE);
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_boolean_benchmark.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_boolean_benchmark.cc
new file mode 100644
index 000000000..969b91a14
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_boolean_benchmark.cc
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <vector>
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/benchmark_util.h"
+
+namespace arrow {
+namespace compute {
+
+constexpr auto kSeed = 0x94378165;
+
+using BooleanBinaryOp = Result<Datum>(const Datum&, const Datum&, ExecContext*);
+
+template <BooleanBinaryOp& Op>
+static void ArrayArrayKernel(benchmark::State& state) {
+ RegressionArgs args(state);
+
+ const int64_t array_size = args.size * 8;
+
+ auto rand = random::RandomArrayGenerator(kSeed);
+ auto lhs = rand.Boolean(array_size, /*true_probability=*/0.5, args.null_proportion);
+ auto rhs = rand.Boolean(array_size, /*true_probability=*/0.5, args.null_proportion);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(Op(lhs, rhs, nullptr).status());
+ }
+ state.SetItemsProcessed(state.iterations() * array_size);
+}
+
+void SetArgs(benchmark::internal::Benchmark* bench) {
+ BenchmarkSetArgsWithSizes(bench, {kL1Size, kL2Size});
+}
+
+BENCHMARK_TEMPLATE(ArrayArrayKernel, And)->Apply(SetArgs);
+BENCHMARK_TEMPLATE(ArrayArrayKernel, KleeneAnd)->Apply(SetArgs);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_boolean_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_boolean_test.cc
new file mode 100644
index 000000000..4c11eb6db
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_boolean_test.cc
@@ -0,0 +1,159 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/chunked_array.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+void CheckBooleanScalarArrayBinary(std::string func_name, Datum array) {
+ for (std::shared_ptr<Scalar> scalar :
+ {std::make_shared<BooleanScalar>(), std::make_shared<BooleanScalar>(true),
+ std::make_shared<BooleanScalar>(false)}) {
+ ASSERT_OK_AND_ASSIGN(Datum actual, CallFunction(func_name, {Datum(scalar), array}));
+
+ ASSERT_OK_AND_ASSIGN(auto constant_array,
+ MakeArrayFromScalar(*scalar, array.length()));
+
+ ASSERT_OK_AND_ASSIGN(Datum expected,
+ CallFunction(func_name, {Datum(constant_array), array}));
+ AssertDatumsEqual(expected, actual);
+
+ ASSERT_OK_AND_ASSIGN(actual, CallFunction(func_name, {array, Datum(scalar)}));
+ ASSERT_OK_AND_ASSIGN(expected,
+ CallFunction(func_name, {array, Datum(constant_array)}));
+ AssertDatumsEqual(expected, actual);
+ }
+}
+
+TEST(TestBooleanKernel, Invert) {
+ auto arr =
+ ArrayFromJSON(boolean(), "[true, false, true, null, false, true, false, null]");
+ auto expected =
+ ArrayFromJSON(boolean(), "[false, true, false, null, true, false, true, null]");
+ CheckScalarUnary("invert", arr, expected);
+}
+
+TEST(TestBooleanKernel, And) {
+ auto left = ArrayFromJSON(boolean(), " [true, true, true, false, false, null]");
+ auto right = ArrayFromJSON(boolean(), " [true, false, null, false, null, null]");
+ auto expected = ArrayFromJSON(boolean(), "[true, false, null, false, null, null]");
+ CheckScalarBinary("and", left, right, expected);
+ CheckBooleanScalarArrayBinary("and", left);
+}
+
+TEST(TestBooleanKernel, Or) {
+ auto left = ArrayFromJSON(boolean(), " [true, true, true, false, false, null]");
+ auto right = ArrayFromJSON(boolean(), " [true, false, null, false, null, null]");
+ auto expected = ArrayFromJSON(boolean(), "[true, true, null, false, null, null]");
+ CheckScalarBinary("or", left, right, expected);
+ CheckBooleanScalarArrayBinary("or", left);
+}
+
+TEST(TestBooleanKernel, Xor) {
+ auto left = ArrayFromJSON(boolean(), " [true, true, true, false, false, null]");
+ auto right = ArrayFromJSON(boolean(), " [true, false, null, false, null, null]");
+ auto expected = ArrayFromJSON(boolean(), "[false, true, null, false, null, null]");
+ CheckScalarBinary("xor", left, right, expected);
+ CheckBooleanScalarArrayBinary("xor", left);
+}
+
+TEST(TestBooleanKernel, AndNot) {
+ auto left = ArrayFromJSON(
+ boolean(), "[true, true, true, false, false, false, null, null, null]");
+ auto right = ArrayFromJSON(
+ boolean(), "[true, false, null, true, false, null, true, false, null]");
+ auto expected = ArrayFromJSON(
+ boolean(), "[false, true, null, false, false, null, null, null, null]");
+ CheckScalarBinary("and_not", left, right, expected);
+ CheckBooleanScalarArrayBinary("and_not", left);
+}
+
+TEST(TestBooleanKernel, KleeneAnd) {
+ auto left = ArrayFromJSON(boolean(), " [true, true, true, false, false, null]");
+ auto right = ArrayFromJSON(boolean(), " [true, false, null, false, null, null]");
+ auto expected = ArrayFromJSON(boolean(), "[true, false, null, false, false, null]");
+ CheckScalarBinary("and_kleene", left, right, expected);
+ CheckBooleanScalarArrayBinary("and_kleene", left);
+
+ left = ArrayFromJSON(boolean(), " [true, true, false, null, null]");
+ right = ArrayFromJSON(boolean(), " [true, false, false, true, false]");
+ expected = ArrayFromJSON(boolean(), "[true, false, false, null, false]");
+ CheckScalarBinary("and_kleene", left, right, expected);
+ CheckBooleanScalarArrayBinary("and_kleene", left);
+
+ left = ArrayFromJSON(boolean(), " [true, true, false, true]");
+ right = ArrayFromJSON(boolean(), " [true, false, false, false]");
+ expected = ArrayFromJSON(boolean(), "[true, false, false, false]");
+ CheckScalarBinary("and_kleene", left, right, expected);
+ CheckBooleanScalarArrayBinary("and_kleene", left);
+}
+
+TEST(TestBooleanKernel, KleeneAndNot) {
+ auto left = ArrayFromJSON(
+ boolean(), "[true, true, true, false, false, false, null, null, null]");
+ auto right = ArrayFromJSON(
+ boolean(), "[true, false, null, true, false, null, true, false, null]");
+ auto expected = ArrayFromJSON(
+ boolean(), "[false, true, null, false, false, false, false, null, null]");
+ CheckScalarBinary("and_not_kleene", left, right, expected);
+ CheckBooleanScalarArrayBinary("and_not_kleene", left);
+
+ left = ArrayFromJSON(boolean(), " [true, true, false, false]");
+ right = ArrayFromJSON(boolean(), " [true, false, true, false]");
+ expected = ArrayFromJSON(boolean(), "[false, true, false, false]");
+ CheckScalarBinary("and_not_kleene", left, right, expected);
+ CheckBooleanScalarArrayBinary("and_not_kleene", left);
+}
+
+TEST(TestBooleanKernel, KleeneOr) {
+ auto left = ArrayFromJSON(boolean(), " [true, true, true, false, false, null]");
+ auto right = ArrayFromJSON(boolean(), " [true, false, null, false, null, null]");
+ auto expected = ArrayFromJSON(boolean(), "[true, true, true, false, null, null]");
+ CheckScalarBinary("or_kleene", left, right, expected);
+ CheckBooleanScalarArrayBinary("or_kleene", left);
+
+ left = ArrayFromJSON(boolean(), " [true, true, false, null, null]");
+ right = ArrayFromJSON(boolean(), " [true, false, false, true, false]");
+ expected = ArrayFromJSON(boolean(), "[true, true, false, true, null]");
+ CheckScalarBinary("or_kleene", left, right, expected);
+ CheckBooleanScalarArrayBinary("or_kleene", left);
+
+ left = ArrayFromJSON(boolean(), " [true, true, false, false]");
+ right = ArrayFromJSON(boolean(), " [true, false, false, true]");
+ expected = ArrayFromJSON(boolean(), "[true, true, false, true]");
+ CheckScalarBinary("or_kleene", left, right, expected);
+ CheckBooleanScalarArrayBinary("or_kleene", left);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_benchmark.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_benchmark.cc
new file mode 100644
index 000000000..8eea8725d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_benchmark.cc
@@ -0,0 +1,117 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <vector>
+
+#include "arrow/compute/cast.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/benchmark_util.h"
+
+namespace arrow {
+namespace compute {
+
+constexpr auto kSeed = 0x94378165;
+
+template <typename InputType, typename CType = typename InputType::c_type>
+static void BenchmarkNumericCast(benchmark::State& state,
+ std::shared_ptr<DataType> to_type,
+ const CastOptions& options, CType min, CType max) {
+ GenericItemsArgs args(state);
+ random::RandomArrayGenerator rand(kSeed);
+ auto array = rand.Numeric<InputType>(args.size, min, max, args.null_proportion);
+ for (auto _ : state) {
+ ABORT_NOT_OK(Cast(array, to_type, options).status());
+ }
+}
+
+template <typename InputType, typename CType = typename InputType::c_type>
+static void BenchmarkFloatingToIntegerCast(benchmark::State& state,
+ std::shared_ptr<DataType> from_type,
+ std::shared_ptr<DataType> to_type,
+ const CastOptions& options, CType min,
+ CType max) {
+ GenericItemsArgs args(state);
+ random::RandomArrayGenerator rand(kSeed);
+ auto array = rand.Numeric<InputType>(args.size, min, max, args.null_proportion);
+
+ std::shared_ptr<Array> values_as_float = *Cast(*array, from_type);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(Cast(values_as_float, to_type, options).status());
+ }
+}
+
+std::vector<int64_t> g_data_sizes = {kL2Size};
+
+void CastSetArgs(benchmark::internal::Benchmark* bench) {
+ for (int64_t size : g_data_sizes) {
+ for (auto nulls : std::vector<ArgsType>({1000, 10, 2, 1, 0})) {
+ bench->Args({static_cast<ArgsType>(size), nulls});
+ }
+ }
+}
+
+static constexpr int32_t kInt32Min = std::numeric_limits<int32_t>::min();
+static constexpr int32_t kInt32Max = std::numeric_limits<int32_t>::max();
+
+static void CastInt64ToInt32Safe(benchmark::State& state) {
+ BenchmarkNumericCast<Int64Type>(state, int32(), CastOptions::Safe(), kInt32Min,
+ kInt32Max);
+}
+
+static void CastInt64ToInt32Unsafe(benchmark::State& state) {
+ BenchmarkNumericCast<Int64Type>(state, int32(), CastOptions::Unsafe(), kInt32Min,
+ kInt32Max);
+}
+
+static void CastUInt32ToInt32Safe(benchmark::State& state) {
+ BenchmarkNumericCast<UInt32Type>(state, int32(), CastOptions::Safe(), 0, kInt32Max);
+}
+
+static void CastInt64ToDoubleSafe(benchmark::State& state) {
+ BenchmarkNumericCast<Int64Type>(state, float64(), CastOptions::Safe(), 0, 1000);
+}
+
+static void CastInt64ToDoubleUnsafe(benchmark::State& state) {
+ BenchmarkNumericCast<Int64Type>(state, float64(), CastOptions::Unsafe(), 0, 1000);
+}
+
+static void CastDoubleToInt32Safe(benchmark::State& state) {
+ BenchmarkFloatingToIntegerCast<Int32Type>(state, float64(), int32(),
+ CastOptions::Safe(), -1000, 1000);
+}
+
+static void CastDoubleToInt32Unsafe(benchmark::State& state) {
+ BenchmarkFloatingToIntegerCast<Int32Type>(state, float64(), int32(),
+ CastOptions::Unsafe(), -1000, 1000);
+}
+
+BENCHMARK(CastInt64ToInt32Safe)->Apply(CastSetArgs);
+BENCHMARK(CastInt64ToInt32Unsafe)->Apply(CastSetArgs);
+BENCHMARK(CastUInt32ToInt32Safe)->Apply(CastSetArgs);
+
+BENCHMARK(CastInt64ToDoubleSafe)->Apply(CastSetArgs);
+BENCHMARK(CastInt64ToDoubleUnsafe)->Apply(CastSetArgs);
+BENCHMARK(CastDoubleToInt32Safe)->Apply(CastSetArgs);
+BENCHMARK(CastDoubleToInt32Unsafe)->Apply(CastSetArgs);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc
new file mode 100644
index 000000000..dad94c1ac
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Cast types to boolean
+
+#include "arrow/array/builder_primitive.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/scalar_cast_internal.h"
+#include "arrow/util/value_parsing.h"
+
+namespace arrow {
+
+using internal::ParseValue;
+
+namespace compute {
+namespace internal {
+
+struct IsNonZero {
+ template <typename OutValue, typename Arg0Value>
+ static OutValue Call(KernelContext*, Arg0Value val, Status*) {
+ return val != 0;
+ }
+};
+
+struct ParseBooleanString {
+ template <typename OutValue, typename Arg0Value>
+ static OutValue Call(KernelContext*, Arg0Value val, Status* st) {
+ bool result = false;
+ if (ARROW_PREDICT_FALSE(!ParseValue<BooleanType>(val.data(), val.size(), &result))) {
+ *st = Status::Invalid("Failed to parse value: ", val);
+ }
+ return result;
+ }
+};
+
+std::vector<std::shared_ptr<CastFunction>> GetBooleanCasts() {
+ auto func = std::make_shared<CastFunction>("cast_boolean", Type::BOOL);
+ AddCommonCasts(Type::BOOL, boolean(), func.get());
+ AddZeroCopyCast(Type::BOOL, boolean(), boolean(), func.get());
+
+ for (const auto& ty : NumericTypes()) {
+ ArrayKernelExec exec =
+ GenerateNumeric<applicator::ScalarUnary, BooleanType, IsNonZero>(*ty);
+ DCHECK_OK(func->AddKernel(ty->id(), {ty}, boolean(), exec));
+ }
+ for (const auto& ty : BaseBinaryTypes()) {
+ ArrayKernelExec exec = GenerateVarBinaryBase<applicator::ScalarUnaryNotNull,
+ BooleanType, ParseBooleanString>(*ty);
+ DCHECK_OK(func->AddKernel(ty->id(), {ty}, boolean(), exec));
+ }
+ return {func};
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc
new file mode 100644
index 000000000..b1e1164fd
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc
@@ -0,0 +1,126 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Implementation of casting to dictionary type
+
+#include <arrow/util/bitmap_ops.h>
+#include <arrow/util/checked_cast.h>
+
+#include "arrow/array/builder_primitive.h"
+#include "arrow/compute/cast_internal.h"
+#include "arrow/compute/kernels/scalar_cast_internal.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/util/int_util.h"
+
+namespace arrow {
+using internal::CopyBitmap;
+
+namespace compute {
+namespace internal {
+
+Status CastDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const CastOptions& options = CastState::Get(ctx);
+ auto out_type = std::static_pointer_cast<DictionaryType>(out->type());
+
+ // if out type is same as in type, return input
+ if (out_type->Equals(batch[0].type())) {
+ *out = batch[0];
+ return Status::OK();
+ }
+
+ if (batch[0].is_scalar()) { // if input is scalar
+ auto in_scalar = checked_cast<const DictionaryScalar&>(*batch[0].scalar());
+
+ // if invalid scalar, return null scalar
+ if (!in_scalar.is_valid) {
+ *out = MakeNullScalar(out_type);
+ return Status::OK();
+ }
+
+ Datum casted_index, casted_dict;
+ if (in_scalar.value.index->type->Equals(out_type->index_type())) {
+ casted_index = in_scalar.value.index;
+ } else {
+ ARROW_ASSIGN_OR_RAISE(casted_index,
+ Cast(in_scalar.value.index, out_type->index_type(), options,
+ ctx->exec_context()));
+ }
+
+ if (in_scalar.value.dictionary->type()->Equals(out_type->value_type())) {
+ casted_dict = in_scalar.value.dictionary;
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ casted_dict, Cast(in_scalar.value.dictionary, out_type->value_type(), options,
+ ctx->exec_context()));
+ }
+
+ *out = std::static_pointer_cast<Scalar>(
+ DictionaryScalar::Make(casted_index.scalar(), casted_dict.make_array()));
+
+ return Status::OK();
+ }
+
+ // if input is array
+ const std::shared_ptr<ArrayData>& in_array = batch[0].array();
+ const auto& in_type = checked_cast<const DictionaryType&>(*in_array->type);
+
+ ArrayData* out_array = out->mutable_array();
+
+ if (in_type.index_type()->Equals(out_type->index_type())) {
+ out_array->buffers[0] = in_array->buffers[0];
+ out_array->buffers[1] = in_array->buffers[1];
+ out_array->null_count = in_array->GetNullCount();
+ out_array->offset = in_array->offset;
+ } else {
+ // for indices, create a dummy ArrayData with index_type()
+ const std::shared_ptr<ArrayData>& indices_arr =
+ ArrayData::Make(in_type.index_type(), in_array->length, in_array->buffers,
+ in_array->GetNullCount(), in_array->offset);
+ ARROW_ASSIGN_OR_RAISE(auto casted_indices, Cast(indices_arr, out_type->index_type(),
+ options, ctx->exec_context()));
+ out_array->buffers[0] = std::move(casted_indices.array()->buffers[0]);
+ out_array->buffers[1] = std::move(casted_indices.array()->buffers[1]);
+ }
+
+ // data (dict)
+ if (in_type.value_type()->Equals(out_type->value_type())) {
+ out_array->dictionary = in_array->dictionary;
+ } else {
+ const std::shared_ptr<Array>& dict_arr = MakeArray(in_array->dictionary);
+ ARROW_ASSIGN_OR_RAISE(auto casted_data, Cast(dict_arr, out_type->value_type(),
+ options, ctx->exec_context()));
+ out_array->dictionary = casted_data.array();
+ }
+ return Status::OK();
+}
+
+std::vector<std::shared_ptr<CastFunction>> GetDictionaryCasts() {
+ auto func = std::make_shared<CastFunction>("cast_dictionary", Type::DICTIONARY);
+
+ AddCommonCasts(Type::DICTIONARY, kOutputTargetType, func.get());
+ ScalarKernel kernel({InputType(Type::DICTIONARY)}, kOutputTargetType, CastDictionary);
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+
+ DCHECK_OK(func->AddKernel(Type::DICTIONARY, std::move(kernel)));
+
+ return {func};
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc
new file mode 100644
index 000000000..5254cb1cc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc
@@ -0,0 +1,299 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/kernels/scalar_cast_internal.h"
+#include "arrow/compute/cast_internal.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/extension_type.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::PrimitiveScalarBase;
+
+namespace compute {
+namespace internal {
+
+// ----------------------------------------------------------------------
+
+namespace {
+
+template <typename OutT, typename InT>
+ARROW_DISABLE_UBSAN("float-cast-overflow")
+void DoStaticCast(const void* in_data, int64_t in_offset, int64_t length,
+ int64_t out_offset, void* out_data) {
+ auto in = reinterpret_cast<const InT*>(in_data) + in_offset;
+ auto out = reinterpret_cast<OutT*>(out_data) + out_offset;
+ for (int64_t i = 0; i < length; ++i) {
+ *out++ = static_cast<OutT>(*in++);
+ }
+}
+
+using StaticCastFunc = std::function<void(const void*, int64_t, int64_t, int64_t, void*)>;
+
+template <typename OutType, typename InType, typename Enable = void>
+struct CastPrimitive {
+ static void Exec(const Datum& input, Datum* out) {
+ using OutT = typename OutType::c_type;
+ using InT = typename InType::c_type;
+
+ StaticCastFunc caster = DoStaticCast<OutT, InT>;
+ if (input.kind() == Datum::ARRAY) {
+ const ArrayData& arr = *input.array();
+ ArrayData* out_arr = out->mutable_array();
+ caster(arr.buffers[1]->data(), arr.offset, arr.length, out_arr->offset,
+ out_arr->buffers[1]->mutable_data());
+ } else {
+ // Scalar path. Use the caster with length 1 to place the casted value into
+ // the output
+ const auto& in_scalar = input.scalar_as<PrimitiveScalarBase>();
+ auto out_scalar = checked_cast<PrimitiveScalarBase*>(out->scalar().get());
+ caster(reinterpret_cast<const void*>(in_scalar.view().data()), /*in_offset=*/0,
+ /*length=*/1, /*out_offset=*/0, out_scalar->mutable_data());
+ }
+ }
+};
+
+template <typename OutType, typename InType>
+struct CastPrimitive<OutType, InType, enable_if_t<std::is_same<OutType, InType>::value>> {
+ // memcpy output
+ static void Exec(const Datum& input, Datum* out) {
+ using T = typename InType::c_type;
+
+ if (input.kind() == Datum::ARRAY) {
+ const ArrayData& arr = *input.array();
+ ArrayData* out_arr = out->mutable_array();
+ std::memcpy(
+ reinterpret_cast<T*>(out_arr->buffers[1]->mutable_data()) + out_arr->offset,
+ reinterpret_cast<const T*>(arr.buffers[1]->data()) + arr.offset,
+ arr.length * sizeof(T));
+ } else {
+ // Scalar path. Use the caster with length 1 to place the casted value into
+ // the output
+ const auto& in_scalar = input.scalar_as<PrimitiveScalarBase>();
+ auto out_scalar = checked_cast<PrimitiveScalarBase*>(out->scalar().get());
+ *reinterpret_cast<T*>(out_scalar->mutable_data()) =
+ *reinterpret_cast<const T*>(in_scalar.view().data());
+ }
+ }
+};
+
+template <typename InType>
+void CastNumberImpl(Type::type out_type, const Datum& input, Datum* out) {
+ switch (out_type) {
+ case Type::INT8:
+ return CastPrimitive<Int8Type, InType>::Exec(input, out);
+ case Type::INT16:
+ return CastPrimitive<Int16Type, InType>::Exec(input, out);
+ case Type::INT32:
+ return CastPrimitive<Int32Type, InType>::Exec(input, out);
+ case Type::INT64:
+ return CastPrimitive<Int64Type, InType>::Exec(input, out);
+ case Type::UINT8:
+ return CastPrimitive<UInt8Type, InType>::Exec(input, out);
+ case Type::UINT16:
+ return CastPrimitive<UInt16Type, InType>::Exec(input, out);
+ case Type::UINT32:
+ return CastPrimitive<UInt32Type, InType>::Exec(input, out);
+ case Type::UINT64:
+ return CastPrimitive<UInt64Type, InType>::Exec(input, out);
+ case Type::FLOAT:
+ return CastPrimitive<FloatType, InType>::Exec(input, out);
+ case Type::DOUBLE:
+ return CastPrimitive<DoubleType, InType>::Exec(input, out);
+ default:
+ break;
+ }
+}
+
+} // namespace
+
+void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, const Datum& input,
+ Datum* out) {
+ switch (in_type) {
+ case Type::INT8:
+ return CastNumberImpl<Int8Type>(out_type, input, out);
+ case Type::INT16:
+ return CastNumberImpl<Int16Type>(out_type, input, out);
+ case Type::INT32:
+ return CastNumberImpl<Int32Type>(out_type, input, out);
+ case Type::INT64:
+ return CastNumberImpl<Int64Type>(out_type, input, out);
+ case Type::UINT8:
+ return CastNumberImpl<UInt8Type>(out_type, input, out);
+ case Type::UINT16:
+ return CastNumberImpl<UInt16Type>(out_type, input, out);
+ case Type::UINT32:
+ return CastNumberImpl<UInt32Type>(out_type, input, out);
+ case Type::UINT64:
+ return CastNumberImpl<UInt64Type>(out_type, input, out);
+ case Type::FLOAT:
+ return CastNumberImpl<FloatType>(out_type, input, out);
+ case Type::DOUBLE:
+ return CastNumberImpl<DoubleType>(out_type, input, out);
+ default:
+ DCHECK(false);
+ break;
+ }
+}
+
+// ----------------------------------------------------------------------
+
+Status UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK(out->is_array());
+
+ DictionaryArray dict_arr(batch[0].array());
+ const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options;
+
+ const auto& dict_type = *dict_arr.dictionary()->type();
+ if (!dict_type.Equals(options.to_type) && !CanCast(dict_type, *options.to_type)) {
+ return Status::Invalid("Cast type ", options.to_type->ToString(),
+ " incompatible with dictionary type ", dict_type.ToString());
+ }
+
+ ARROW_ASSIGN_OR_RAISE(*out,
+ Take(Datum(dict_arr.dictionary()), Datum(dict_arr.indices()),
+ TakeOptions::Defaults(), ctx->exec_context()));
+
+ if (!dict_type.Equals(options.to_type)) {
+ ARROW_ASSIGN_OR_RAISE(*out, Cast(*out, options));
+ }
+ return Status::OK();
+}
+
+Status OutputAllNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (out->is_scalar()) {
+ out->scalar()->is_valid = false;
+ } else {
+ ArrayData* output = out->mutable_array();
+ output->buffers = {nullptr};
+ output->null_count = batch.length;
+ }
+ return Status::OK();
+}
+
+Status CastFromExtension(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const CastOptions& options = checked_cast<const CastState*>(ctx->state())->options;
+
+ if (batch[0].kind() == Datum::SCALAR) {
+ const auto& ext_scalar = checked_cast<const ExtensionScalar&>(*batch[0].scalar());
+ Datum casted_storage;
+
+ if (ext_scalar.is_valid) {
+ return Cast(ext_scalar.value, out->type(), options, ctx->exec_context()).Value(out);
+ } else {
+ const auto& storage_type =
+ checked_cast<const ExtensionType&>(*ext_scalar.type).storage_type();
+ return Cast(MakeNullScalar(storage_type), out->type(), options, ctx->exec_context())
+ .Value(out);
+ }
+ } else {
+ DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
+ ExtensionArray extension(batch[0].array());
+ return Cast(*extension.storage(), out->type(), options, ctx->exec_context())
+ .Value(out);
+ }
+}
+
+Status CastFromNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (!batch[0].is_scalar()) {
+ ArrayData* output = out->mutable_array();
+ std::shared_ptr<Array> nulls;
+ RETURN_NOT_OK(MakeArrayOfNull(output->type, batch.length).Value(&nulls));
+ out->value = nulls->data();
+ }
+ return Status::OK();
+}
+
+Result<ValueDescr> ResolveOutputFromOptions(KernelContext* ctx,
+ const std::vector<ValueDescr>& args) {
+ const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options;
+ return ValueDescr(options.to_type, args[0].shape);
+}
+
+/// You will see some of kernels with
+///
+/// kOutputTargetType
+///
+/// for their output type resolution. This is somewhat of an eyesore but the
+/// easiest initial way to get the requested cast type including the TimeUnit
+/// to the kernel (which is needed to compute the output) was through
+/// CastOptions
+
+OutputType kOutputTargetType(ResolveOutputFromOptions);
+
+Status ZeroCopyCastExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
+ // Make a copy of the buffers into a destination array without carrying
+ // the type
+ const ArrayData& input = *batch[0].array();
+ ArrayData* output = out->mutable_array();
+ output->length = input.length;
+ output->SetNullCount(input.null_count);
+ output->buffers = input.buffers;
+ output->offset = input.offset;
+ output->child_data = input.child_data;
+ return Status::OK();
+}
+
+void AddZeroCopyCast(Type::type in_type_id, InputType in_type, OutputType out_type,
+ CastFunction* func) {
+ auto sig = KernelSignature::Make({in_type}, out_type);
+ ScalarKernel kernel;
+ kernel.exec = TrivialScalarUnaryAsArraysExec(ZeroCopyCastExec);
+ kernel.signature = sig;
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ DCHECK_OK(func->AddKernel(in_type_id, std::move(kernel)));
+}
+
+static bool CanCastFromDictionary(Type::type type_id) {
+ return (is_primitive(type_id) || is_base_binary_like(type_id) ||
+ is_fixed_size_binary(type_id));
+}
+
+void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* func) {
+ // From null to this type
+ ScalarKernel kernel;
+ kernel.exec = CastFromNull;
+ kernel.signature = KernelSignature::Make({null()}, out_ty);
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ DCHECK_OK(func->AddKernel(Type::NA, std::move(kernel)));
+
+ // From dictionary to this type
+ if (CanCastFromDictionary(out_type_id)) {
+ // Dictionary unpacking not implemented for boolean or nested types.
+ //
+ // XXX: Uses Take and does its own memory allocation for the moment. We can
+ // fix this later.
+ DCHECK_OK(func->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, out_ty,
+ TrivialScalarUnaryAsArraysExec(UnpackDictionary),
+ NullHandling::COMPUTED_NO_PREALLOCATE,
+ MemAllocation::NO_PREALLOCATE));
+ }
+
+ // From extension type to this type
+ DCHECK_OK(func->AddKernel(Type::EXTENSION, {InputType(Type::EXTENSION)}, out_ty,
+ CastFromExtension, NullHandling::COMPUTED_NO_PREALLOCATE,
+ MemAllocation::NO_PREALLOCATE));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.h b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.h
new file mode 100644
index 000000000..2419d898a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_internal.h
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/cast.h" // IWYU pragma: export
+#include "arrow/compute/cast_internal.h" // IWYU pragma: export
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/util_internal.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+namespace internal {
+
+template <typename OutType, typename InType, typename Enable = void>
+struct CastFunctor {};
+
+// No-op functor for identity casts
+template <typename O, typename I>
+struct CastFunctor<
+ O, I, enable_if_t<std::is_same<O, I>::value && is_parameter_free_type<I>::value>> {
+ static Status Exec(KernelContext*, const ExecBatch&, Datum*) { return Status::OK(); }
+};
+
+Status CastFromExtension(KernelContext* ctx, const ExecBatch& batch, Datum* out);
+
+// Utility for numeric casts
+void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, const Datum& input,
+ Datum* out);
+
+// ----------------------------------------------------------------------
+// Dictionary to other things
+
+Status UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out);
+
+Status OutputAllNull(KernelContext* ctx, const ExecBatch& batch, Datum* out);
+
+Status CastFromNull(KernelContext* ctx, const ExecBatch& batch, Datum* out);
+
+// Adds a cast function where CastFunctor is specialized and the input and output
+// types are parameter free (have a type_singleton). Scalar inputs are handled by
+// wrapping with TrivialScalarUnaryAsArraysExec.
+template <typename InType, typename OutType>
+void AddSimpleCast(InputType in_ty, OutputType out_ty, CastFunction* func) {
+ DCHECK_OK(func->AddKernel(
+ InType::type_id, {in_ty}, out_ty,
+ TrivialScalarUnaryAsArraysExec(CastFunctor<OutType, InType>::Exec)));
+}
+
+Status ZeroCopyCastExec(KernelContext* ctx, const ExecBatch& batch, Datum* out);
+
+void AddZeroCopyCast(Type::type in_type_id, InputType in_type, OutputType out_type,
+ CastFunction* func);
+
+// OutputType::Resolver that returns a descr with the shape of the input
+// argument and the type from CastOptions
+Result<ValueDescr> ResolveOutputFromOptions(KernelContext* ctx,
+ const std::vector<ValueDescr>& args);
+
+ARROW_EXPORT extern OutputType kOutputTargetType;
+
+// Add generic casts to out_ty from:
+// - the null type
+// - dictionary with out_ty as given value type
+// - extension types with a compatible storage type
+void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* func);
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
new file mode 100644
index 000000000..ab583bbbe
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
@@ -0,0 +1,188 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Implementation of casting to (or between) list types
+
+#include <limits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/builder_nested.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/scalar_cast_internal.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/int_util.h"
+
+namespace arrow {
+
+using internal::CopyBitmap;
+
+namespace compute {
+namespace internal {
+
+namespace {
+
+// (Large)List<T> -> (Large)List<U>
+
+template <typename SrcType, typename DestType>
+typename std::enable_if<SrcType::type_id == DestType::type_id, Status>::type
+CastListOffsets(KernelContext* ctx, const ArrayData& in_array, ArrayData* out_array) {
+ return Status::OK();
+}
+
+template <typename SrcType, typename DestType>
+typename std::enable_if<SrcType::type_id != DestType::type_id, Status>::type
+CastListOffsets(KernelContext* ctx, const ArrayData& in_array, ArrayData* out_array) {
+ using src_offset_type = typename SrcType::offset_type;
+ using dest_offset_type = typename DestType::offset_type;
+
+ ARROW_ASSIGN_OR_RAISE(out_array->buffers[1],
+ ctx->Allocate(sizeof(dest_offset_type) * (in_array.length + 1)));
+ ::arrow::internal::CastInts(in_array.GetValues<src_offset_type>(1),
+ out_array->GetMutableValues<dest_offset_type>(1),
+ in_array.length + 1);
+ return Status::OK();
+}
+
+template <typename SrcType, typename DestType>
+struct CastList {
+ using src_offset_type = typename SrcType::offset_type;
+ using dest_offset_type = typename DestType::offset_type;
+
+ static constexpr bool is_upcast = sizeof(src_offset_type) < sizeof(dest_offset_type);
+ static constexpr bool is_downcast = sizeof(src_offset_type) > sizeof(dest_offset_type);
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const CastOptions& options = CastState::Get(ctx);
+
+ auto child_type = checked_cast<const DestType&>(*out->type()).value_type();
+
+ if (out->kind() == Datum::SCALAR) {
+ // The scalar case is simple, as only the underlying values must be cast
+ const auto& in_scalar = checked_cast<const BaseListScalar&>(*batch[0].scalar());
+ auto out_scalar = checked_cast<BaseListScalar*>(out->scalar().get());
+
+ DCHECK(!out_scalar->is_valid);
+ if (in_scalar.is_valid) {
+ ARROW_ASSIGN_OR_RAISE(out_scalar->value, Cast(*in_scalar.value, child_type,
+ options, ctx->exec_context()));
+
+ out_scalar->is_valid = true;
+ }
+ return Status::OK();
+ }
+
+ const ArrayData& in_array = *batch[0].array();
+ auto offsets = in_array.GetValues<src_offset_type>(1);
+ Datum values = in_array.child_data[0];
+
+ ArrayData* out_array = out->mutable_array();
+ out_array->buffers = in_array.buffers;
+
+ // Shift bitmap in case the source offset is non-zero
+ if (in_array.offset != 0 && in_array.buffers[0]) {
+ ARROW_ASSIGN_OR_RAISE(out_array->buffers[0],
+ CopyBitmap(ctx->memory_pool(), in_array.buffers[0]->data(),
+ in_array.offset, in_array.length));
+ }
+
+ // Handle list offsets
+ // Several cases can arise:
+ // - the source offset is non-zero, in which case we slice the underlying values
+ // and shift the list offsets (regardless of their respective types)
+ // - the source offset is zero but source and destination types have
+ // different list offset types, in which case we cast the list offsets
+ // - otherwise, we simply keep the original list offsets
+ if (is_downcast) {
+ if (offsets[in_array.length] > std::numeric_limits<dest_offset_type>::max()) {
+ return Status::Invalid("Array of type ", in_array.type->ToString(),
+ " too large to convert to ", out_array->type->ToString());
+ }
+ }
+
+ if (in_array.offset != 0) {
+ ARROW_ASSIGN_OR_RAISE(
+ out_array->buffers[1],
+ ctx->Allocate(sizeof(dest_offset_type) * (in_array.length + 1)));
+
+ auto shifted_offsets = out_array->GetMutableValues<dest_offset_type>(1);
+ for (int64_t i = 0; i < in_array.length + 1; ++i) {
+ shifted_offsets[i] = static_cast<dest_offset_type>(offsets[i] - offsets[0]);
+ }
+ values = in_array.child_data[0]->Slice(offsets[0], offsets[in_array.length]);
+ } else {
+ RETURN_NOT_OK((CastListOffsets<SrcType, DestType>(ctx, in_array, out_array)));
+ }
+
+ // Handle values
+ ARROW_ASSIGN_OR_RAISE(Datum cast_values,
+ Cast(values, child_type, options, ctx->exec_context()));
+
+ DCHECK_EQ(Datum::ARRAY, cast_values.kind());
+ out_array->child_data.push_back(cast_values.array());
+ return Status::OK();
+ }
+};
+
+template <typename SrcType, typename DestType>
+void AddListCast(CastFunction* func) {
+ ScalarKernel kernel;
+ kernel.exec = CastList<SrcType, DestType>::Exec;
+ kernel.signature =
+ KernelSignature::Make({InputType(SrcType::type_id)}, kOutputTargetType);
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ DCHECK_OK(func->AddKernel(SrcType::type_id, std::move(kernel)));
+}
+
+} // namespace
+
+std::vector<std::shared_ptr<CastFunction>> GetNestedCasts() {
+ // We use the list<T> from the CastOptions when resolving the output type
+
+ auto cast_list = std::make_shared<CastFunction>("cast_list", Type::LIST);
+ AddCommonCasts(Type::LIST, kOutputTargetType, cast_list.get());
+ AddListCast<ListType, ListType>(cast_list.get());
+ AddListCast<LargeListType, ListType>(cast_list.get());
+
+ auto cast_large_list =
+ std::make_shared<CastFunction>("cast_large_list", Type::LARGE_LIST);
+ AddCommonCasts(Type::LARGE_LIST, kOutputTargetType, cast_large_list.get());
+ AddListCast<ListType, LargeListType>(cast_large_list.get());
+ AddListCast<LargeListType, LargeListType>(cast_large_list.get());
+
+ // FSL is a bit incomplete at the moment
+ auto cast_fsl =
+ std::make_shared<CastFunction>("cast_fixed_size_list", Type::FIXED_SIZE_LIST);
+ AddCommonCasts(Type::FIXED_SIZE_LIST, kOutputTargetType, cast_fsl.get());
+
+ // So is struct
+ auto cast_struct = std::make_shared<CastFunction>("cast_struct", Type::STRUCT);
+ AddCommonCasts(Type::STRUCT, kOutputTargetType, cast_struct.get());
+
+ // So is dictionary
+ auto cast_dictionary =
+ std::make_shared<CastFunction>("cast_dictionary", Type::DICTIONARY);
+ AddCommonCasts(Type::DICTIONARY, kOutputTargetType, cast_dictionary.get());
+
+ return {cast_list, cast_large_list, cast_fsl, cast_struct, cast_dictionary};
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
new file mode 100644
index 000000000..1ce6896de
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
@@ -0,0 +1,784 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Implementation of casting to integer, floating point, or decimal types
+
+#include "arrow/array/builder_primitive.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/scalar_cast_internal.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/int_util.h"
+#include "arrow/util/value_parsing.h"
+
+namespace arrow {
+
+using internal::BitBlockCount;
+using internal::CheckIntegersInRange;
+using internal::IntegersCanFit;
+using internal::OptionalBitBlockCounter;
+using internal::ParseValue;
+
+namespace compute {
+namespace internal {
+
+Status CastIntegerToInteger(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& options = checked_cast<const CastState*>(ctx->state())->options;
+ if (!options.allow_int_overflow) {
+ RETURN_NOT_OK(IntegersCanFit(batch[0], *out->type()));
+ }
+ CastNumberToNumberUnsafe(batch[0].type()->id(), out->type()->id(), batch[0], out);
+ return Status::OK();
+}
+
+Status CastFloatingToFloating(KernelContext*, const ExecBatch& batch, Datum* out) {
+ CastNumberToNumberUnsafe(batch[0].type()->id(), out->type()->id(), batch[0], out);
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Implement fast safe floating point to integer cast
+
+// InType is a floating point type we are planning to cast to integer
+template <typename InType, typename OutType, typename InT = typename InType::c_type,
+ typename OutT = typename OutType::c_type>
+ARROW_DISABLE_UBSAN("float-cast-overflow")
+Status CheckFloatTruncation(const Datum& input, const Datum& output) {
+ auto WasTruncated = [&](OutT out_val, InT in_val) -> bool {
+ return static_cast<InT>(out_val) != in_val;
+ };
+ auto WasTruncatedMaybeNull = [&](OutT out_val, InT in_val, bool is_valid) -> bool {
+ return is_valid && static_cast<InT>(out_val) != in_val;
+ };
+ auto GetErrorMessage = [&](InT val) {
+ return Status::Invalid("Float value ", val, " was truncated converting to ",
+ *output.type());
+ };
+
+ if (input.kind() == Datum::SCALAR) {
+ DCHECK_EQ(output.kind(), Datum::SCALAR);
+ const auto& in_scalar = input.scalar_as<typename TypeTraits<InType>::ScalarType>();
+ const auto& out_scalar = output.scalar_as<typename TypeTraits<OutType>::ScalarType>();
+ if (WasTruncatedMaybeNull(out_scalar.value, in_scalar.value, out_scalar.is_valid)) {
+ return GetErrorMessage(in_scalar.value);
+ }
+ return Status::OK();
+ }
+
+ const ArrayData& in_array = *input.array();
+ const ArrayData& out_array = *output.array();
+
+ const InT* in_data = in_array.GetValues<InT>(1);
+ const OutT* out_data = out_array.GetValues<OutT>(1);
+
+ const uint8_t* bitmap = nullptr;
+ if (in_array.buffers[0]) {
+ bitmap = in_array.buffers[0]->data();
+ }
+ OptionalBitBlockCounter bit_counter(bitmap, in_array.offset, in_array.length);
+ int64_t position = 0;
+ int64_t offset_position = in_array.offset;
+ while (position < in_array.length) {
+ BitBlockCount block = bit_counter.NextBlock();
+ bool block_out_of_bounds = false;
+ if (block.popcount == block.length) {
+ // Fast path: branchless
+ for (int64_t i = 0; i < block.length; ++i) {
+ block_out_of_bounds |= WasTruncated(out_data[i], in_data[i]);
+ }
+ } else if (block.popcount > 0) {
+ // Indices have nulls, must only boundscheck non-null values
+ for (int64_t i = 0; i < block.length; ++i) {
+ block_out_of_bounds |= WasTruncatedMaybeNull(
+ out_data[i], in_data[i], BitUtil::GetBit(bitmap, offset_position + i));
+ }
+ }
+ if (ARROW_PREDICT_FALSE(block_out_of_bounds)) {
+ if (in_array.GetNullCount() > 0) {
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (WasTruncatedMaybeNull(out_data[i], in_data[i],
+ BitUtil::GetBit(bitmap, offset_position + i))) {
+ return GetErrorMessage(in_data[i]);
+ }
+ }
+ } else {
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (WasTruncated(out_data[i], in_data[i])) {
+ return GetErrorMessage(in_data[i]);
+ }
+ }
+ }
+ }
+ in_data += block.length;
+ out_data += block.length;
+ position += block.length;
+ offset_position += block.length;
+ }
+ return Status::OK();
+}
+
+template <typename InType>
+Status CheckFloatToIntTruncationImpl(const Datum& input, const Datum& output) {
+ switch (output.type()->id()) {
+ case Type::INT8:
+ return CheckFloatTruncation<InType, Int8Type>(input, output);
+ case Type::INT16:
+ return CheckFloatTruncation<InType, Int16Type>(input, output);
+ case Type::INT32:
+ return CheckFloatTruncation<InType, Int32Type>(input, output);
+ case Type::INT64:
+ return CheckFloatTruncation<InType, Int64Type>(input, output);
+ case Type::UINT8:
+ return CheckFloatTruncation<InType, UInt8Type>(input, output);
+ case Type::UINT16:
+ return CheckFloatTruncation<InType, UInt16Type>(input, output);
+ case Type::UINT32:
+ return CheckFloatTruncation<InType, UInt32Type>(input, output);
+ case Type::UINT64:
+ return CheckFloatTruncation<InType, UInt64Type>(input, output);
+ default:
+ break;
+ }
+ DCHECK(false);
+ return Status::OK();
+}
+
+Status CheckFloatToIntTruncation(const Datum& input, const Datum& output) {
+ switch (input.type()->id()) {
+ case Type::FLOAT:
+ return CheckFloatToIntTruncationImpl<FloatType>(input, output);
+ case Type::DOUBLE:
+ return CheckFloatToIntTruncationImpl<DoubleType>(input, output);
+ default:
+ break;
+ }
+ DCHECK(false);
+ return Status::OK();
+}
+
+Status CastFloatingToInteger(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& options = checked_cast<const CastState*>(ctx->state())->options;
+ CastNumberToNumberUnsafe(batch[0].type()->id(), out->type()->id(), batch[0], out);
+ if (!options.allow_float_truncate) {
+ RETURN_NOT_OK(CheckFloatToIntTruncation(batch[0], *out));
+ }
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Implement fast integer to floating point cast
+
+// These are the limits for exact representation of whole numbers in floating
+// point numbers
+template <typename T>
+struct FloatingIntegerBound {};
+
+template <>
+struct FloatingIntegerBound<float> {
+ static const int64_t value = 1LL << 24;
+};
+
+template <>
+struct FloatingIntegerBound<double> {
+ static const int64_t value = 1LL << 53;
+};
+
+template <typename InType, typename OutType, typename InT = typename InType::c_type,
+ typename OutT = typename OutType::c_type,
+ bool IsSigned = is_signed_integer_type<InType>::value>
+Status CheckIntegerFloatTruncateImpl(const Datum& input) {
+ using InScalarType = typename TypeTraits<InType>::ScalarType;
+ const int64_t limit = FloatingIntegerBound<OutT>::value;
+ InScalarType bound_lower(IsSigned ? -limit : 0);
+ InScalarType bound_upper(limit);
+ return CheckIntegersInRange(input, bound_lower, bound_upper);
+}
+
+Status CheckForIntegerToFloatingTruncation(const Datum& input, Type::type out_type) {
+ switch (input.type()->id()) {
+ // Small integers are all exactly representable as whole numbers
+ case Type::INT8:
+ case Type::INT16:
+ case Type::UINT8:
+ case Type::UINT16:
+ return Status::OK();
+ case Type::INT32: {
+ if (out_type == Type::DOUBLE) {
+ return Status::OK();
+ }
+ return CheckIntegerFloatTruncateImpl<Int32Type, FloatType>(input);
+ }
+ case Type::UINT32: {
+ if (out_type == Type::DOUBLE) {
+ return Status::OK();
+ }
+ return CheckIntegerFloatTruncateImpl<UInt32Type, FloatType>(input);
+ }
+ case Type::INT64: {
+ if (out_type == Type::FLOAT) {
+ return CheckIntegerFloatTruncateImpl<Int64Type, FloatType>(input);
+ } else {
+ return CheckIntegerFloatTruncateImpl<Int64Type, DoubleType>(input);
+ }
+ }
+ case Type::UINT64: {
+ if (out_type == Type::FLOAT) {
+ return CheckIntegerFloatTruncateImpl<UInt64Type, FloatType>(input);
+ } else {
+ return CheckIntegerFloatTruncateImpl<UInt64Type, DoubleType>(input);
+ }
+ }
+ default:
+ break;
+ }
+ DCHECK(false);
+ return Status::OK();
+}
+
+Status CastIntegerToFloating(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& options = checked_cast<const CastState*>(ctx->state())->options;
+ Type::type out_type = out->type()->id();
+ if (!options.allow_float_truncate) {
+ RETURN_NOT_OK(CheckForIntegerToFloatingTruncation(batch[0], out_type));
+ }
+ CastNumberToNumberUnsafe(batch[0].type()->id(), out_type, batch[0], out);
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Boolean to number
+
+struct BooleanToNumber {
+ template <typename OutValue, typename Arg0Value>
+ static OutValue Call(KernelContext*, Arg0Value val, Status*) {
+ constexpr auto kOne = static_cast<OutValue>(1);
+ constexpr auto kZero = static_cast<OutValue>(0);
+ return val ? kOne : kZero;
+ }
+};
+
+template <typename O>
+struct CastFunctor<O, BooleanType, enable_if_number<O>> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return applicator::ScalarUnary<O, BooleanType, BooleanToNumber>::Exec(ctx, batch,
+ out);
+ }
+};
+
+// ----------------------------------------------------------------------
+// String to number
+
+template <typename OutType>
+struct ParseString {
+ template <typename OutValue, typename Arg0Value>
+ OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const {
+ OutValue result = OutValue(0);
+ if (ARROW_PREDICT_FALSE(!ParseValue<OutType>(val.data(), val.size(), &result))) {
+ *st = Status::Invalid("Failed to parse string: '", val, "' as a scalar of type ",
+ TypeTraits<OutType>::type_singleton()->ToString());
+ }
+ return result;
+ }
+};
+
+template <typename O, typename I>
+struct CastFunctor<O, I, enable_if_base_binary<I>> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return applicator::ScalarUnaryNotNull<O, I, ParseString<O>>::Exec(ctx, batch, out);
+ }
+};
+
+// ----------------------------------------------------------------------
+// Decimal to integer
+
+struct DecimalToIntegerMixin {
+ template <typename OutValue, typename Arg0Value>
+ OutValue ToInteger(KernelContext* ctx, const Arg0Value& val, Status* st) const {
+ constexpr auto min_value = std::numeric_limits<OutValue>::min();
+ constexpr auto max_value = std::numeric_limits<OutValue>::max();
+
+ if (!allow_int_overflow_ && ARROW_PREDICT_FALSE(val < min_value || val > max_value)) {
+ *st = Status::Invalid("Integer value out of bounds");
+ return OutValue{}; // Zero
+ } else {
+ return static_cast<OutValue>(val.low_bits());
+ }
+ }
+
+ DecimalToIntegerMixin(int32_t in_scale, bool allow_int_overflow)
+ : in_scale_(in_scale), allow_int_overflow_(allow_int_overflow) {}
+
+ int32_t in_scale_;
+ bool allow_int_overflow_;
+};
+
+struct UnsafeUpscaleDecimalToInteger : public DecimalToIntegerMixin {
+ using DecimalToIntegerMixin::DecimalToIntegerMixin;
+
+ template <typename OutValue, typename Arg0Value>
+ OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const {
+ return ToInteger<OutValue>(ctx, val.IncreaseScaleBy(-in_scale_), st);
+ }
+};
+
+struct UnsafeDownscaleDecimalToInteger : public DecimalToIntegerMixin {
+ using DecimalToIntegerMixin::DecimalToIntegerMixin;
+
+ template <typename OutValue, typename Arg0Value>
+ OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const {
+ return ToInteger<OutValue>(ctx, val.ReduceScaleBy(in_scale_, false), st);
+ }
+};
+
+struct SafeRescaleDecimalToInteger : public DecimalToIntegerMixin {
+ using DecimalToIntegerMixin::DecimalToIntegerMixin;
+
+ template <typename OutValue, typename Arg0Value>
+ OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const {
+ auto result = val.Rescale(in_scale_, 0);
+ if (ARROW_PREDICT_FALSE(!result.ok())) {
+ *st = result.status();
+ return OutValue{}; // Zero
+ } else {
+ return ToInteger<OutValue>(ctx, *result, st);
+ }
+ }
+};
+
+template <typename O, typename I>
+struct CastFunctor<O, I,
+ enable_if_t<is_integer_type<O>::value && is_decimal_type<I>::value>> {
+ using out_type = typename O::c_type;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& options = checked_cast<const CastState*>(ctx->state())->options;
+
+ const auto& in_type_inst = checked_cast<const I&>(*batch[0].type());
+ const auto in_scale = in_type_inst.scale();
+
+ if (options.allow_decimal_truncate) {
+ if (in_scale < 0) {
+ // Unsafe upscale
+ applicator::ScalarUnaryNotNullStateful<O, I, UnsafeUpscaleDecimalToInteger>
+ kernel(UnsafeUpscaleDecimalToInteger{in_scale, options.allow_int_overflow});
+ return kernel.Exec(ctx, batch, out);
+ } else {
+ // Unsafe downscale
+ applicator::ScalarUnaryNotNullStateful<O, I, UnsafeDownscaleDecimalToInteger>
+ kernel(UnsafeDownscaleDecimalToInteger{in_scale, options.allow_int_overflow});
+ return kernel.Exec(ctx, batch, out);
+ }
+ } else {
+ // Safe rescale
+ applicator::ScalarUnaryNotNullStateful<O, I, SafeRescaleDecimalToInteger> kernel(
+ SafeRescaleDecimalToInteger{in_scale, options.allow_int_overflow});
+ return kernel.Exec(ctx, batch, out);
+ }
+ }
+};
+
+// ----------------------------------------------------------------------
+// Integer to decimal
+
+struct IntegerToDecimal {
+ template <typename OutValue, typename IntegerType>
+ OutValue Call(KernelContext*, IntegerType val, Status* st) const {
+ auto maybe_decimal = OutValue(val).Rescale(0, out_scale_);
+ if (ARROW_PREDICT_TRUE(maybe_decimal.ok())) {
+ return maybe_decimal.MoveValueUnsafe();
+ }
+ *st = maybe_decimal.status();
+ return OutValue{};
+ }
+
+ int32_t out_scale_;
+};
+
+template <typename O, typename I>
+struct CastFunctor<O, I,
+ enable_if_t<is_decimal_type<O>::value && is_integer_type<I>::value>> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& out_type = checked_cast<const O&>(*out->type());
+ const auto out_scale = out_type.scale();
+ const auto out_precision = out_type.precision();
+
+ // verify precision and scale
+ if (out_scale < 0) {
+ return Status::Invalid("Scale must be non-negative");
+ }
+ ARROW_ASSIGN_OR_RAISE(int32_t precision, MaxDecimalDigitsForInteger(I::type_id));
+ precision += out_scale;
+ if (out_precision < precision) {
+ return Status::Invalid(
+ "Precision is not great enough for the result. "
+ "It should be at least ",
+ precision);
+ }
+
+ applicator::ScalarUnaryNotNullStateful<O, I, IntegerToDecimal> kernel(
+ IntegerToDecimal{out_scale});
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+// ----------------------------------------------------------------------
+// Decimal to decimal
+
+// Helper that converts the input and output decimals
+// For instance, Decimal128 -> Decimal256 requires converting, then scaling
+// Decimal256 -> Decimal128 requires scaling, then truncating
+template <typename OutDecimal, typename InDecimal>
+struct DecimalConversions {};
+
+template <typename InDecimal>
+struct DecimalConversions<Decimal256, InDecimal> {
+ // Convert then scale
+ static Decimal256 ConvertInput(InDecimal&& val) { return Decimal256(val); }
+ static Decimal256 ConvertOutput(Decimal256&& val) { return val; }
+};
+
+template <>
+struct DecimalConversions<Decimal128, Decimal256> {
+ // Scale then truncate
+ static Decimal256 ConvertInput(Decimal256&& val) { return val; }
+ static Decimal128 ConvertOutput(Decimal256&& val) {
+ const auto array_le = BitUtil::LittleEndianArray::Make(val.native_endian_array());
+ return Decimal128(array_le[1], array_le[0]);
+ }
+};
+
+template <>
+struct DecimalConversions<Decimal128, Decimal128> {
+ static Decimal128 ConvertInput(Decimal128&& val) { return val; }
+ static Decimal128 ConvertOutput(Decimal128&& val) { return val; }
+};
+
+struct UnsafeUpscaleDecimal {
+ template <typename OutValue, typename Arg0Value>
+ OutValue Call(KernelContext*, Arg0Value val, Status*) const {
+ using Conv = DecimalConversions<OutValue, Arg0Value>;
+ return Conv::ConvertOutput(Conv::ConvertInput(std::move(val)).IncreaseScaleBy(by_));
+ }
+ int32_t by_;
+};
+
+struct UnsafeDownscaleDecimal {
+ template <typename OutValue, typename Arg0Value>
+ OutValue Call(KernelContext*, Arg0Value val, Status*) const {
+ using Conv = DecimalConversions<OutValue, Arg0Value>;
+ return Conv::ConvertOutput(
+ Conv::ConvertInput(std::move(val)).ReduceScaleBy(by_, false));
+ }
+ int32_t by_;
+};
+
+struct SafeRescaleDecimal {
+ template <typename OutValue, typename Arg0Value>
+ OutValue Call(KernelContext*, Arg0Value val, Status* st) const {
+ using Conv = DecimalConversions<OutValue, Arg0Value>;
+ auto maybe_rescaled =
+ Conv::ConvertInput(std::move(val)).Rescale(in_scale_, out_scale_);
+ if (ARROW_PREDICT_FALSE(!maybe_rescaled.ok())) {
+ *st = maybe_rescaled.status();
+ return {}; // Zero
+ }
+
+ if (ARROW_PREDICT_TRUE(maybe_rescaled->FitsInPrecision(out_precision_))) {
+ return Conv::ConvertOutput(maybe_rescaled.MoveValueUnsafe());
+ }
+
+ *st = Status::Invalid("Decimal value does not fit in precision ", out_precision_);
+ return {}; // Zero
+ }
+
+ int32_t out_scale_, out_precision_, in_scale_;
+};
+
+template <typename O, typename I>
+struct CastFunctor<O, I,
+ enable_if_t<is_decimal_type<O>::value && is_decimal_type<I>::value>> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& options = checked_cast<const CastState*>(ctx->state())->options;
+
+ const auto& in_type = checked_cast<const I&>(*batch[0].type());
+ const auto& out_type = checked_cast<const O&>(*out->type());
+ const auto in_scale = in_type.scale();
+ const auto out_scale = out_type.scale();
+
+ if (options.allow_decimal_truncate) {
+ if (in_scale < out_scale) {
+ // Unsafe upscale
+ applicator::ScalarUnaryNotNullStateful<O, I, UnsafeUpscaleDecimal> kernel(
+ UnsafeUpscaleDecimal{out_scale - in_scale});
+ return kernel.Exec(ctx, batch, out);
+ } else {
+ // Unsafe downscale
+ applicator::ScalarUnaryNotNullStateful<O, I, UnsafeDownscaleDecimal> kernel(
+ UnsafeDownscaleDecimal{in_scale - out_scale});
+ return kernel.Exec(ctx, batch, out);
+ }
+ }
+
+ // Safe rescale
+ applicator::ScalarUnaryNotNullStateful<O, I, SafeRescaleDecimal> kernel(
+ SafeRescaleDecimal{out_scale, out_type.precision(), in_scale});
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+// ----------------------------------------------------------------------
+// Real to decimal
+
+struct RealToDecimal {
+ template <typename OutValue, typename RealType>
+ OutValue Call(KernelContext*, RealType val, Status* st) const {
+ auto maybe_decimal = OutValue::FromReal(val, out_precision_, out_scale_);
+
+ if (ARROW_PREDICT_TRUE(maybe_decimal.ok())) {
+ return maybe_decimal.MoveValueUnsafe();
+ }
+
+ if (!allow_truncate_) {
+ *st = maybe_decimal.status();
+ }
+ return {}; // Zero
+ }
+
+ int32_t out_scale_, out_precision_;
+ bool allow_truncate_;
+};
+
+template <typename O, typename I>
+struct CastFunctor<O, I,
+ enable_if_t<is_decimal_type<O>::value && is_floating_type<I>::value>> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& options = checked_cast<const CastState*>(ctx->state())->options;
+ const auto& out_type = checked_cast<const O&>(*out->type());
+ const auto out_scale = out_type.scale();
+ const auto out_precision = out_type.precision();
+
+ applicator::ScalarUnaryNotNullStateful<O, I, RealToDecimal> kernel(
+ RealToDecimal{out_scale, out_precision, options.allow_decimal_truncate});
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+// ----------------------------------------------------------------------
+// Decimal to real
+
+struct DecimalToReal {
+ template <typename RealType, typename Arg0Value>
+ RealType Call(KernelContext*, const Arg0Value& val, Status*) const {
+ return val.template ToReal<RealType>(in_scale_);
+ }
+
+ int32_t in_scale_;
+};
+
+template <typename O, typename I>
+struct CastFunctor<O, I,
+ enable_if_t<is_floating_type<O>::value && is_decimal_type<I>::value>> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& in_type = checked_cast<const I&>(*batch[0].type());
+ const auto in_scale = in_type.scale();
+
+ applicator::ScalarUnaryNotNullStateful<O, I, DecimalToReal> kernel(
+ DecimalToReal{in_scale});
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+// ----------------------------------------------------------------------
+// Top-level kernel instantiation
+
+namespace {
+
+template <typename OutType>
+void AddCommonNumberCasts(const std::shared_ptr<DataType>& out_ty, CastFunction* func) {
+ AddCommonCasts(out_ty->id(), out_ty, func);
+
+ // Cast from boolean to number
+ DCHECK_OK(func->AddKernel(Type::BOOL, {boolean()}, out_ty,
+ CastFunctor<OutType, BooleanType>::Exec));
+
+ // Cast from other strings
+ for (const std::shared_ptr<DataType>& in_ty : BaseBinaryTypes()) {
+ auto exec = GenerateVarBinaryBase<CastFunctor, OutType>(*in_ty);
+ DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, exec));
+ }
+}
+
+template <typename OutType>
+std::shared_ptr<CastFunction> GetCastToInteger(std::string name) {
+ auto func = std::make_shared<CastFunction>(std::move(name), OutType::type_id);
+ auto out_ty = TypeTraits<OutType>::type_singleton();
+
+ for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
+ DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastIntegerToInteger));
+ }
+
+ // Cast from floating point
+ for (const std::shared_ptr<DataType>& in_ty : FloatingPointTypes()) {
+ DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToInteger));
+ }
+
+ // From other numbers to integer
+ AddCommonNumberCasts<OutType>(out_ty, func.get());
+
+ // From decimal to integer
+ DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType(Type::DECIMAL)}, out_ty,
+ CastFunctor<OutType, Decimal128Type>::Exec));
+ DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, out_ty,
+ CastFunctor<OutType, Decimal256Type>::Exec));
+ return func;
+}
+
+template <typename OutType>
+std::shared_ptr<CastFunction> GetCastToFloating(std::string name) {
+ auto func = std::make_shared<CastFunction>(std::move(name), OutType::type_id);
+ auto out_ty = TypeTraits<OutType>::type_singleton();
+
+ // Casts from integer to floating point
+ for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
+ DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastIntegerToFloating));
+ }
+
+ // Cast from floating point
+ for (const std::shared_ptr<DataType>& in_ty : FloatingPointTypes()) {
+ DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToFloating));
+ }
+
+ // From other numbers to floating point
+ AddCommonNumberCasts<OutType>(out_ty, func.get());
+
+ // From decimal to floating point
+ DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType(Type::DECIMAL)}, out_ty,
+ CastFunctor<OutType, Decimal128Type>::Exec));
+ DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, out_ty,
+ CastFunctor<OutType, Decimal256Type>::Exec));
+ return func;
+}
+
+std::shared_ptr<CastFunction> GetCastToDecimal128() {
+ OutputType sig_out_ty(ResolveOutputFromOptions);
+
+ auto func = std::make_shared<CastFunction>("cast_decimal", Type::DECIMAL128);
+ AddCommonCasts(Type::DECIMAL128, sig_out_ty, func.get());
+
+ // Cast from floating point
+ DCHECK_OK(func->AddKernel(Type::FLOAT, {float32()}, sig_out_ty,
+ CastFunctor<Decimal128Type, FloatType>::Exec));
+ DCHECK_OK(func->AddKernel(Type::DOUBLE, {float64()}, sig_out_ty,
+ CastFunctor<Decimal128Type, DoubleType>::Exec));
+
+ // Cast from integer
+ for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
+ auto exec = GenerateInteger<CastFunctor, Decimal128Type>(in_ty->id());
+ DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
+ }
+
+ // Cast from other decimal
+ auto exec = CastFunctor<Decimal128Type, Decimal128Type>::Exec;
+ // We resolve the output type of this kernel from the CastOptions
+ DCHECK_OK(
+ func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec));
+ exec = CastFunctor<Decimal128Type, Decimal256Type>::Exec;
+ DCHECK_OK(
+ func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, sig_out_ty, exec));
+ return func;
+}
+
+std::shared_ptr<CastFunction> GetCastToDecimal256() {
+ OutputType sig_out_ty(ResolveOutputFromOptions);
+
+ auto func = std::make_shared<CastFunction>("cast_decimal256", Type::DECIMAL256);
+ AddCommonCasts(Type::DECIMAL256, sig_out_ty, func.get());
+
+ // Cast from floating point
+ DCHECK_OK(func->AddKernel(Type::FLOAT, {float32()}, sig_out_ty,
+ CastFunctor<Decimal256Type, FloatType>::Exec));
+ DCHECK_OK(func->AddKernel(Type::DOUBLE, {float64()}, sig_out_ty,
+ CastFunctor<Decimal256Type, DoubleType>::Exec));
+
+ // Cast from integer
+ for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
+ auto exec = GenerateInteger<CastFunctor, Decimal256Type>(in_ty->id());
+ DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
+ }
+
+ // Cast from other decimal
+ auto exec = CastFunctor<Decimal256Type, Decimal128Type>::Exec;
+ DCHECK_OK(
+ func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec));
+ exec = CastFunctor<Decimal256Type, Decimal256Type>::Exec;
+ DCHECK_OK(
+ func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, sig_out_ty, exec));
+ return func;
+}
+
+} // namespace
+
+std::vector<std::shared_ptr<CastFunction>> GetNumericCasts() {
+ std::vector<std::shared_ptr<CastFunction>> functions;
+
+ // Make a cast to null that does not do much. Not sure why we need to be able
+ // to cast from dict<null> -> null but there are unit tests for it
+ auto cast_null = std::make_shared<CastFunction>("cast_null", Type::NA);
+ DCHECK_OK(cast_null->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, null(),
+ OutputAllNull));
+ functions.push_back(cast_null);
+
+ functions.push_back(GetCastToInteger<Int8Type>("cast_int8"));
+ functions.push_back(GetCastToInteger<Int16Type>("cast_int16"));
+
+ auto cast_int32 = GetCastToInteger<Int32Type>("cast_int32");
+ // Convert DATE32 or TIME32 to INT32 zero copy
+ AddZeroCopyCast(Type::DATE32, date32(), int32(), cast_int32.get());
+ AddZeroCopyCast(Type::TIME32, InputType(Type::TIME32), int32(), cast_int32.get());
+ functions.push_back(cast_int32);
+
+ auto cast_int64 = GetCastToInteger<Int64Type>("cast_int64");
+ // Convert DATE64, DURATION, TIMESTAMP, TIME64 to INT64 zero copy
+ AddZeroCopyCast(Type::DATE64, InputType(Type::DATE64), int64(), cast_int64.get());
+ AddZeroCopyCast(Type::DURATION, InputType(Type::DURATION), int64(), cast_int64.get());
+ AddZeroCopyCast(Type::TIMESTAMP, InputType(Type::TIMESTAMP), int64(), cast_int64.get());
+ AddZeroCopyCast(Type::TIME64, InputType(Type::TIME64), int64(), cast_int64.get());
+ functions.push_back(cast_int64);
+
+ functions.push_back(GetCastToInteger<UInt8Type>("cast_uint8"));
+ functions.push_back(GetCastToInteger<UInt16Type>("cast_uint16"));
+ functions.push_back(GetCastToInteger<UInt32Type>("cast_uint32"));
+ functions.push_back(GetCastToInteger<UInt64Type>("cast_uint64"));
+
+ // HalfFloat is a bit brain-damaged for now
+ auto cast_half_float =
+ std::make_shared<CastFunction>("cast_half_float", Type::HALF_FLOAT);
+ AddCommonCasts(Type::HALF_FLOAT, float16(), cast_half_float.get());
+ functions.push_back(cast_half_float);
+
+ functions.push_back(GetCastToFloating<FloatType>("cast_float"));
+ functions.push_back(GetCastToFloating<DoubleType>("cast_double"));
+
+ functions.push_back(GetCastToDecimal128());
+ functions.push_back(GetCastToDecimal256());
+
+ return functions;
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_string.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_string.cc
new file mode 100644
index 000000000..eb2f90439
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_string.cc
@@ -0,0 +1,374 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <limits>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/scalar_cast_internal.h"
+#include "arrow/result.h"
+#include "arrow/util/formatting.h"
+#include "arrow/util/int_util.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/utf8.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::StringFormatter;
+using util::InitializeUTF8;
+using util::ValidateUTF8;
+
+namespace compute {
+namespace internal {
+
+namespace {
+
+// ----------------------------------------------------------------------
+// Number / Boolean to String
+
+template <typename O, typename I>
+struct NumericToStringCastFunctor {
+ using value_type = typename TypeTraits<I>::CType;
+ using BuilderType = typename TypeTraits<O>::BuilderType;
+ using FormatterType = StringFormatter<I>;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK(out->is_array());
+ const ArrayData& input = *batch[0].array();
+ ArrayData* output = out->mutable_array();
+ return Convert(ctx, input, output);
+ }
+
+ static Status Convert(KernelContext* ctx, const ArrayData& input, ArrayData* output) {
+ FormatterType formatter(input.type);
+ BuilderType builder(input.type, ctx->memory_pool());
+ RETURN_NOT_OK(VisitArrayDataInline<I>(
+ input,
+ [&](value_type v) {
+ return formatter(v, [&](util::string_view v) { return builder.Append(v); });
+ },
+ [&]() { return builder.AppendNull(); }));
+
+ std::shared_ptr<Array> output_array;
+ RETURN_NOT_OK(builder.Finish(&output_array));
+ *output = std::move(*output_array->data());
+ return Status::OK();
+ }
+};
+
+// ----------------------------------------------------------------------
+// Temporal to String
+
+template <typename O, typename I>
+struct TemporalToStringCastFunctor {
+ using value_type = typename TypeTraits<I>::CType;
+ using BuilderType = typename TypeTraits<O>::BuilderType;
+ using FormatterType = StringFormatter<I>;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK(out->is_array());
+ const ArrayData& input = *batch[0].array();
+ ArrayData* output = out->mutable_array();
+ return Convert(ctx, input, output);
+ }
+
+ static Status Convert(KernelContext* ctx, const ArrayData& input, ArrayData* output) {
+ FormatterType formatter(input.type);
+ BuilderType builder(input.type, ctx->memory_pool());
+ RETURN_NOT_OK(VisitArrayDataInline<I>(
+ input,
+ [&](value_type v) {
+ return formatter(v, [&](util::string_view v) { return builder.Append(v); });
+ },
+ [&]() { return builder.AppendNull(); }));
+
+ std::shared_ptr<Array> output_array;
+ RETURN_NOT_OK(builder.Finish(&output_array));
+ *output = std::move(*output_array->data());
+ return Status::OK();
+ }
+};
+
+// ----------------------------------------------------------------------
+// Binary-like to binary-like
+//
+
+#if defined(_MSC_VER)
+// Silence warning: """'visitor': unreferenced local variable"""
+#pragma warning(push)
+#pragma warning(disable : 4101)
+#endif
+
+struct Utf8Validator {
+ Status VisitNull() { return Status::OK(); }
+
+ Status VisitValue(util::string_view str) {
+ if (ARROW_PREDICT_FALSE(!ValidateUTF8(str))) {
+ return Status::Invalid("Invalid UTF8 payload");
+ }
+ return Status::OK();
+ }
+};
+
+template <typename I, typename O>
+Status CastBinaryToBinaryOffsets(KernelContext* ctx, const ArrayData& input,
+ ArrayData* output) {
+ static_assert(std::is_same<I, O>::value, "Cast same-width offsets (no-op)");
+ return Status::OK();
+}
+
+// Upcast offsets
+template <>
+Status CastBinaryToBinaryOffsets<int32_t, int64_t>(KernelContext* ctx,
+ const ArrayData& input,
+ ArrayData* output) {
+ using input_offset_type = int32_t;
+ using output_offset_type = int64_t;
+ ARROW_ASSIGN_OR_RAISE(
+ output->buffers[1],
+ ctx->Allocate((output->length + output->offset + 1) * sizeof(output_offset_type)));
+ memset(output->buffers[1]->mutable_data(), 0,
+ output->offset * sizeof(output_offset_type));
+ ::arrow::internal::CastInts(input.GetValues<input_offset_type>(1),
+ output->GetMutableValues<output_offset_type>(1),
+ output->length + 1);
+ return Status::OK();
+}
+
+// Downcast offsets
+template <>
+Status CastBinaryToBinaryOffsets<int64_t, int32_t>(KernelContext* ctx,
+ const ArrayData& input,
+ ArrayData* output) {
+ using input_offset_type = int64_t;
+ using output_offset_type = int32_t;
+
+ constexpr input_offset_type kMaxOffset = std::numeric_limits<output_offset_type>::max();
+
+ auto input_offsets = input.GetValues<input_offset_type>(1);
+
+ // Binary offsets are ascending, so it's enough to check the last one for overflow.
+ if (input_offsets[input.length] > kMaxOffset) {
+ return Status::Invalid("Failed casting from ", input.type->ToString(), " to ",
+ output->type->ToString(), ": input array too large");
+ } else {
+ ARROW_ASSIGN_OR_RAISE(output->buffers[1],
+ ctx->Allocate((output->length + output->offset + 1) *
+ sizeof(output_offset_type)));
+ memset(output->buffers[1]->mutable_data(), 0,
+ output->offset * sizeof(output_offset_type));
+ ::arrow::internal::CastInts(input.GetValues<input_offset_type>(1),
+ output->GetMutableValues<output_offset_type>(1),
+ output->length + 1);
+ return Status::OK();
+ }
+}
+
+template <typename O, typename I>
+enable_if_base_binary<I, Status> BinaryToBinaryCastExec(KernelContext* ctx,
+ const ExecBatch& batch,
+ Datum* out) {
+ DCHECK(out->is_array());
+ const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options;
+ const ArrayData& input = *batch[0].array();
+
+ if (!I::is_utf8 && O::is_utf8 && !options.allow_invalid_utf8) {
+ InitializeUTF8();
+
+ ArrayDataVisitor<I> visitor;
+ Utf8Validator validator;
+ RETURN_NOT_OK(visitor.Visit(input, &validator));
+ }
+
+ // Start with a zero-copy cast, but change indices to expected size
+ RETURN_NOT_OK(ZeroCopyCastExec(ctx, batch, out));
+ return CastBinaryToBinaryOffsets<typename I::offset_type, typename O::offset_type>(
+ ctx, input, out->mutable_array());
+}
+
+template <typename O, typename I>
+enable_if_t<std::is_same<I, FixedSizeBinaryType>::value &&
+ !std::is_same<O, FixedSizeBinaryType>::value,
+ Status>
+BinaryToBinaryCastExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK(out->is_array());
+ const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options;
+ const ArrayData& input = *batch[0].array();
+ ArrayData* output = out->mutable_array();
+
+ if (O::is_utf8 && !options.allow_invalid_utf8) {
+ InitializeUTF8();
+
+ ArrayDataVisitor<I> visitor;
+ Utf8Validator validator;
+ RETURN_NOT_OK(visitor.Visit(input, &validator));
+ }
+
+ // Check for overflow
+ using output_offset_type = typename O::offset_type;
+ constexpr output_offset_type kMaxOffset =
+ std::numeric_limits<output_offset_type>::max();
+ const int32_t width =
+ checked_cast<const FixedSizeBinaryType&>(*input.type).byte_width();
+ const int64_t max_offset = width * input.length;
+ if (max_offset > kMaxOffset) {
+ return Status::Invalid("Failed casting from ", input.type->ToString(), " to ",
+ output->type->ToString(), ": input array too large");
+ }
+
+ // Copy buffers over, then generate indices
+ output->length = input.length;
+ output->SetNullCount(input.null_count);
+ if (input.offset == output->offset) {
+ output->buffers[0] = input.buffers[0];
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ output->buffers[0],
+ arrow::internal::CopyBitmap(ctx->memory_pool(), input.GetValues<uint8_t>(0, 0),
+ input.offset, input.length));
+ }
+ output->buffers[2] = input.buffers[1];
+ output_offset_type* offsets = output->GetMutableValues<output_offset_type>(1);
+
+ offsets[0] = static_cast<output_offset_type>(input.offset * width);
+ for (int64_t i = 0; i < input.length; i++) {
+ offsets[i + 1] = offsets[i] + width;
+ }
+ return Status::OK();
+}
+
+template <typename O, typename I>
+enable_if_t<std::is_same<I, FixedSizeBinaryType>::value &&
+ std::is_same<O, FixedSizeBinaryType>::value,
+ Status>
+BinaryToBinaryCastExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK(out->is_array());
+ const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options;
+ const ArrayData& input = *batch[0].array();
+ const int32_t in_width =
+ checked_cast<const FixedSizeBinaryType&>(*input.type).byte_width();
+ const int32_t out_width =
+ checked_cast<const FixedSizeBinaryType&>(*options.to_type).byte_width();
+
+ if (in_width != out_width) {
+ return Status::Invalid("Failed casting from ", input.type->ToString(), " to ",
+ options.to_type->ToString(), ": widths must match");
+ }
+
+ return ZeroCopyCastExec(ctx, batch, out);
+}
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+// ----------------------------------------------------------------------
+// Cast functions registration
+
+template <typename OutType>
+void AddNumberToStringCasts(CastFunction* func) {
+ auto out_ty = TypeTraits<OutType>::type_singleton();
+
+ DCHECK_OK(func->AddKernel(Type::BOOL, {boolean()}, out_ty,
+ TrivialScalarUnaryAsArraysExec(
+ NumericToStringCastFunctor<OutType, BooleanType>::Exec),
+ NullHandling::COMPUTED_NO_PREALLOCATE));
+
+ for (const std::shared_ptr<DataType>& in_ty : NumericTypes()) {
+ DCHECK_OK(
+ func->AddKernel(in_ty->id(), {in_ty}, out_ty,
+ TrivialScalarUnaryAsArraysExec(
+ GenerateNumeric<NumericToStringCastFunctor, OutType>(*in_ty)),
+ NullHandling::COMPUTED_NO_PREALLOCATE));
+ }
+}
+
+template <typename OutType>
+void AddTemporalToStringCasts(CastFunction* func) {
+ auto out_ty = TypeTraits<OutType>::type_singleton();
+ for (const std::shared_ptr<DataType>& in_ty : TemporalTypes()) {
+ DCHECK_OK(func->AddKernel(
+ in_ty->id(), {in_ty}, out_ty,
+ TrivialScalarUnaryAsArraysExec(
+ GenerateTemporal<TemporalToStringCastFunctor, OutType>(*in_ty)),
+ NullHandling::COMPUTED_NO_PREALLOCATE));
+ }
+}
+
+template <typename OutType, typename InType>
+void AddBinaryToBinaryCast(CastFunction* func) {
+ auto out_ty = TypeTraits<OutType>::type_singleton();
+
+ DCHECK_OK(func->AddKernel(
+ InType::type_id, {InputType(InType::type_id)}, out_ty,
+ TrivialScalarUnaryAsArraysExec(BinaryToBinaryCastExec<OutType, InType>),
+ NullHandling::COMPUTED_NO_PREALLOCATE));
+}
+
+template <typename OutType>
+void AddBinaryToBinaryCast(CastFunction* func) {
+ AddBinaryToBinaryCast<OutType, StringType>(func);
+ AddBinaryToBinaryCast<OutType, BinaryType>(func);
+ AddBinaryToBinaryCast<OutType, LargeStringType>(func);
+ AddBinaryToBinaryCast<OutType, LargeBinaryType>(func);
+ AddBinaryToBinaryCast<OutType, FixedSizeBinaryType>(func);
+}
+
+} // namespace
+
+std::vector<std::shared_ptr<CastFunction>> GetBinaryLikeCasts() {
+ auto cast_binary = std::make_shared<CastFunction>("cast_binary", Type::BINARY);
+ AddCommonCasts(Type::BINARY, binary(), cast_binary.get());
+ AddBinaryToBinaryCast<BinaryType>(cast_binary.get());
+
+ auto cast_large_binary =
+ std::make_shared<CastFunction>("cast_large_binary", Type::LARGE_BINARY);
+ AddCommonCasts(Type::LARGE_BINARY, large_binary(), cast_large_binary.get());
+ AddBinaryToBinaryCast<LargeBinaryType>(cast_large_binary.get());
+
+ auto cast_string = std::make_shared<CastFunction>("cast_string", Type::STRING);
+ AddCommonCasts(Type::STRING, utf8(), cast_string.get());
+ AddNumberToStringCasts<StringType>(cast_string.get());
+ AddTemporalToStringCasts<StringType>(cast_string.get());
+ AddBinaryToBinaryCast<StringType>(cast_string.get());
+
+ auto cast_large_string =
+ std::make_shared<CastFunction>("cast_large_string", Type::LARGE_STRING);
+ AddCommonCasts(Type::LARGE_STRING, large_utf8(), cast_large_string.get());
+ AddNumberToStringCasts<LargeStringType>(cast_large_string.get());
+ AddTemporalToStringCasts<LargeStringType>(cast_large_string.get());
+ AddBinaryToBinaryCast<LargeStringType>(cast_large_string.get());
+
+ auto cast_fsb =
+ std::make_shared<CastFunction>("cast_fixed_size_binary", Type::FIXED_SIZE_BINARY);
+ AddCommonCasts(Type::FIXED_SIZE_BINARY, OutputType(ResolveOutputFromOptions),
+ cast_fsb.get());
+ DCHECK_OK(cast_fsb->AddKernel(
+ Type::FIXED_SIZE_BINARY, {InputType(Type::FIXED_SIZE_BINARY)},
+ OutputType(FirstType),
+ TrivialScalarUnaryAsArraysExec(
+ BinaryToBinaryCastExec<FixedSizeBinaryType, FixedSizeBinaryType>),
+ NullHandling::COMPUTED_NO_PREALLOCATE));
+
+ return {cast_binary, cast_large_binary, cast_string, cast_large_string, cast_fsb};
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
new file mode 100644
index 000000000..5f16f1e9d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
@@ -0,0 +1,598 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Implementation of casting to (or between) temporal types
+
+#include <limits>
+
+#include "arrow/array/builder_time.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/scalar_cast_internal.h"
+#include "arrow/compute/kernels/temporal_internal.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/time.h"
+#include "arrow/util/value_parsing.h"
+
+namespace arrow {
+
+using internal::ParseValue;
+
+namespace compute {
+namespace internal {
+
+constexpr int64_t kMillisecondsInDay = 86400000;
+
+// ----------------------------------------------------------------------
+// From one timestamp to another
+
+template <typename in_type, typename out_type>
+Status ShiftTime(KernelContext* ctx, const util::DivideOrMultiply factor_op,
+ const int64_t factor, const ArrayData& input, ArrayData* output) {
+ const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options;
+ auto in_data = input.GetValues<in_type>(1);
+ auto out_data = output->GetMutableValues<out_type>(1);
+
+ if (factor == 1) {
+ for (int64_t i = 0; i < input.length; i++) {
+ out_data[i] = static_cast<out_type>(in_data[i]);
+ }
+ } else if (factor_op == util::MULTIPLY) {
+ if (options.allow_time_overflow) {
+ for (int64_t i = 0; i < input.length; i++) {
+ out_data[i] = static_cast<out_type>(in_data[i] * factor);
+ }
+ } else {
+#define RAISE_OVERFLOW_CAST(VAL) \
+ return Status::Invalid("Casting from ", input.type->ToString(), " to ", \
+ output->type->ToString(), " would result in ", \
+ "out of bounds timestamp: ", VAL);
+
+ int64_t max_val = std::numeric_limits<int64_t>::max() / factor;
+ int64_t min_val = std::numeric_limits<int64_t>::min() / factor;
+ if (input.null_count != 0) {
+ BitmapReader bit_reader(input.buffers[0]->data(), input.offset, input.length);
+ for (int64_t i = 0; i < input.length; i++) {
+ if (bit_reader.IsSet() && (in_data[i] < min_val || in_data[i] > max_val)) {
+ RAISE_OVERFLOW_CAST(in_data[i]);
+ }
+ out_data[i] = static_cast<out_type>(in_data[i] * factor);
+ bit_reader.Next();
+ }
+ } else {
+ for (int64_t i = 0; i < input.length; i++) {
+ if (in_data[i] < min_val || in_data[i] > max_val) {
+ RAISE_OVERFLOW_CAST(in_data[i]);
+ }
+ out_data[i] = static_cast<out_type>(in_data[i] * factor);
+ }
+ }
+
+#undef RAISE_OVERFLOW_CAST
+ }
+ } else {
+ if (options.allow_time_truncate) {
+ for (int64_t i = 0; i < input.length; i++) {
+ out_data[i] = static_cast<out_type>(in_data[i] / factor);
+ }
+ } else {
+#define RAISE_INVALID_CAST(VAL) \
+ return Status::Invalid("Casting from ", input.type->ToString(), " to ", \
+ output->type->ToString(), " would lose data: ", VAL);
+
+ if (input.null_count != 0) {
+ BitmapReader bit_reader(input.buffers[0]->data(), input.offset, input.length);
+ for (int64_t i = 0; i < input.length; i++) {
+ out_data[i] = static_cast<out_type>(in_data[i] / factor);
+ if (bit_reader.IsSet() && (out_data[i] * factor != in_data[i])) {
+ RAISE_INVALID_CAST(in_data[i]);
+ }
+ bit_reader.Next();
+ }
+ } else {
+ for (int64_t i = 0; i < input.length; i++) {
+ out_data[i] = static_cast<out_type>(in_data[i] / factor);
+ if (out_data[i] * factor != in_data[i]) {
+ RAISE_INVALID_CAST(in_data[i]);
+ }
+ }
+ }
+
+#undef RAISE_INVALID_CAST
+ }
+ }
+
+ return Status::OK();
+}
+
+template <template <typename...> class Op, typename OutType, typename... Args>
+Status ExtractTemporal(KernelContext* ctx, const ExecBatch& batch, Datum* out,
+ Args... args) {
+ const auto& ty = checked_cast<const TimestampType&>(*batch[0].type());
+
+ switch (ty.unit()) {
+ case TimeUnit::SECOND:
+ return TemporalComponentExtract<Op, std::chrono::seconds, TimestampType, OutType,
+ Args...>::Exec(ctx, batch, out, args...);
+ case TimeUnit::MILLI:
+ return TemporalComponentExtract<Op, std::chrono::milliseconds, TimestampType,
+ OutType, Args...>::Exec(ctx, batch, out, args...);
+ case TimeUnit::MICRO:
+ return TemporalComponentExtract<Op, std::chrono::microseconds, TimestampType,
+ OutType, Args...>::Exec(ctx, batch, out, args...);
+ case TimeUnit::NANO:
+ return TemporalComponentExtract<Op, std::chrono::nanoseconds, TimestampType,
+ OutType, Args...>::Exec(ctx, batch, out, args...);
+ }
+ return Status::Invalid("Unknown timestamp unit: ", ty);
+}
+
+// <TimestampType, TimestampType> and <DurationType, DurationType>
+template <typename O, typename I>
+struct CastFunctor<
+ O, I,
+ enable_if_t<(is_timestamp_type<O>::value && is_timestamp_type<I>::value) ||
+ (is_duration_type<O>::value && is_duration_type<I>::value)>> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
+
+ const ArrayData& input = *batch[0].array();
+ ArrayData* output = out->mutable_array();
+
+ // If units are the same, zero copy, otherwise convert
+ const auto& in_type = checked_cast<const I&>(*batch[0].type());
+ const auto& out_type = checked_cast<const O&>(*output->type);
+
+ // The units may be equal if the time zones are different. We might go to
+ // lengths to make this zero copy in the future but we leave it for now
+
+ auto conversion = util::GetTimestampConversion(in_type.unit(), out_type.unit());
+ return ShiftTime<int64_t, int64_t>(ctx, conversion.first, conversion.second, input,
+ output);
+ }
+};
+
+// ----------------------------------------------------------------------
+// From timestamp to date32 or date64
+
+template <>
+struct CastFunctor<Date32Type, TimestampType> {
+ template <typename Duration, typename Localizer>
+ struct Date32 {
+ Date32(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status*) const {
+ return static_cast<T>(static_cast<const int32_t>(
+ floor<days>(localizer_.template ConvertTimePoint<Duration>(arg))
+ .time_since_epoch()
+ .count()));
+ }
+
+ Localizer localizer_;
+ };
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
+ return ExtractTemporal<Date32, Date32Type>(ctx, batch, out);
+ }
+};
+
+template <>
+struct CastFunctor<Date64Type, TimestampType> {
+ template <typename Duration, typename Localizer>
+ struct Date64 {
+ constexpr static int64_t kMillisPerDay = 86400000;
+ Date64(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status*) const {
+ return static_cast<T>(
+ kMillisPerDay *
+ static_cast<const int32_t>(
+ floor<days>(localizer_.template ConvertTimePoint<Duration>(arg))
+ .time_since_epoch()
+ .count()));
+ }
+
+ Localizer localizer_;
+ };
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
+ return ExtractTemporal<Date64, Date64Type>(ctx, batch, out);
+ }
+};
+
+// ----------------------------------------------------------------------
+// From timestamp to time32 or time64
+
+template <typename Duration, typename Localizer>
+struct ExtractTimeDownscaled {
+ ExtractTimeDownscaled(const FunctionOptions* options, Localizer&& localizer,
+ const int64_t factor)
+ : localizer_(std::move(localizer)), factor_(factor) {}
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status* st) const {
+ const auto t = localizer_.template ConvertTimePoint<Duration>(arg);
+ const int64_t orig_value = (t - floor<days>(t)).count();
+ const T scaled = static_cast<T>(orig_value / factor_);
+ const int64_t unscaled = static_cast<int64_t>(scaled) * factor_;
+ if (unscaled != orig_value) {
+ *st = Status::Invalid("Cast would lose data: ", orig_value);
+ return 0;
+ }
+ return scaled;
+ }
+
+ Localizer localizer_;
+ const int64_t factor_;
+};
+
+template <typename Duration, typename Localizer>
+struct ExtractTimeUpscaledUnchecked {
+ ExtractTimeUpscaledUnchecked(const FunctionOptions* options, Localizer&& localizer,
+ const int64_t factor)
+ : localizer_(std::move(localizer)), factor_(factor) {}
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status*) const {
+ const auto t = localizer_.template ConvertTimePoint<Duration>(arg);
+ const int64_t orig_value = (t - floor<days>(t)).count();
+ return static_cast<T>(orig_value * factor_);
+ }
+
+ Localizer localizer_;
+ const int64_t factor_;
+};
+
+template <typename Duration, typename Localizer>
+struct ExtractTimeDownscaledUnchecked {
+ ExtractTimeDownscaledUnchecked(const FunctionOptions* options, Localizer&& localizer,
+ const int64_t factor)
+ : localizer_(std::move(localizer)), factor_(factor) {}
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status*) const {
+ const auto t = localizer_.template ConvertTimePoint<Duration>(arg);
+ const int64_t orig_value = (t - floor<days>(t)).count();
+ return static_cast<T>(orig_value / factor_);
+ }
+
+ Localizer localizer_;
+ const int64_t factor_;
+};
+
+template <>
+struct CastFunctor<Time32Type, TimestampType> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
+ const auto& in_type = checked_cast<const TimestampType&>(*batch[0].type());
+ const auto& out_type = checked_cast<const Time32Type&>(*out->type());
+ const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options;
+
+ // Shifting before extraction won't work since the timestamp may not fit
+ // even if the time itself fits
+ if (in_type.unit() != out_type.unit()) {
+ auto conversion = util::GetTimestampConversion(in_type.unit(), out_type.unit());
+ if (conversion.first == util::MULTIPLY) {
+ return ExtractTemporal<ExtractTimeUpscaledUnchecked, Time32Type>(
+ ctx, batch, out, conversion.second);
+ } else {
+ if (options.allow_time_truncate) {
+ return ExtractTemporal<ExtractTimeDownscaledUnchecked, Time32Type>(
+ ctx, batch, out, conversion.second);
+ } else {
+ return ExtractTemporal<ExtractTimeDownscaled, Time32Type>(ctx, batch, out,
+ conversion.second);
+ }
+ }
+ }
+ return ExtractTemporal<ExtractTimeUpscaledUnchecked, Time32Type>(ctx, batch, out, 1);
+ }
+};
+
+template <>
+struct CastFunctor<Time64Type, TimestampType> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
+ const auto& in_type = checked_cast<const TimestampType&>(*batch[0].type());
+ const auto& out_type = checked_cast<const Time64Type&>(*out->type());
+ const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options;
+
+ // Shifting before extraction won't work since the timestamp may not fit
+ // even if the time itself fits
+ if (in_type.unit() != out_type.unit()) {
+ auto conversion = util::GetTimestampConversion(in_type.unit(), out_type.unit());
+ if (conversion.first == util::MULTIPLY) {
+ return ExtractTemporal<ExtractTimeUpscaledUnchecked, Time64Type>(
+ ctx, batch, out, conversion.second);
+ } else {
+ if (options.allow_time_truncate) {
+ return ExtractTemporal<ExtractTimeDownscaledUnchecked, Time64Type>(
+ ctx, batch, out, conversion.second);
+ } else {
+ return ExtractTemporal<ExtractTimeDownscaled, Time64Type>(ctx, batch, out,
+ conversion.second);
+ }
+ }
+ }
+ return ExtractTemporal<ExtractTimeUpscaledUnchecked, Time64Type>(ctx, batch, out, 1);
+ }
+};
+
+// ----------------------------------------------------------------------
+// From one time32 or time64 to another
+
+template <typename O, typename I>
+struct CastFunctor<O, I, enable_if_t<is_time_type<I>::value && is_time_type<O>::value>> {
+ using in_t = typename I::c_type;
+ using out_t = typename O::c_type;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
+
+ const ArrayData& input = *batch[0].array();
+ ArrayData* output = out->mutable_array();
+
+ // If units are the same, zero copy, otherwise convert
+ const auto& in_type = checked_cast<const I&>(*input.type);
+ const auto& out_type = checked_cast<const O&>(*output->type);
+ DCHECK_NE(in_type.unit(), out_type.unit()) << "Do not cast equal types";
+ auto conversion = util::GetTimestampConversion(in_type.unit(), out_type.unit());
+ return ShiftTime<in_t, out_t>(ctx, conversion.first, conversion.second, input,
+ output);
+ }
+};
+
+// ----------------------------------------------------------------------
+// Between date32 and date64
+
+template <>
+struct CastFunctor<Date64Type, Date32Type> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
+
+ return ShiftTime<int32_t, int64_t>(ctx, util::MULTIPLY, kMillisecondsInDay,
+ *batch[0].array(), out->mutable_array());
+ }
+};
+
+template <>
+struct CastFunctor<Date32Type, Date64Type> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
+
+ return ShiftTime<int64_t, int32_t>(ctx, util::DIVIDE, kMillisecondsInDay,
+ *batch[0].array(), out->mutable_array());
+ }
+};
+
+// ----------------------------------------------------------------------
+// date32, date64 to timestamp
+
+template <>
+struct CastFunctor<TimestampType, Date32Type> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
+
+ const auto& out_type = checked_cast<const TimestampType&>(*out->type());
+ // get conversion SECOND -> unit
+ auto conversion = util::GetTimestampConversion(TimeUnit::SECOND, out_type.unit());
+ DCHECK_EQ(conversion.first, util::MULTIPLY);
+
+ // multiply to achieve days -> unit
+ conversion.second *= kMillisecondsInDay / 1000;
+ return ShiftTime<int32_t, int64_t>(ctx, util::MULTIPLY, conversion.second,
+ *batch[0].array(), out->mutable_array());
+ }
+};
+
+template <>
+struct CastFunctor<TimestampType, Date64Type> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
+
+ const auto& out_type = checked_cast<const TimestampType&>(*out->type());
+
+ // date64 is ms since epoch
+ auto conversion = util::GetTimestampConversion(TimeUnit::MILLI, out_type.unit());
+ return ShiftTime<int64_t, int64_t>(ctx, conversion.first, conversion.second,
+ *batch[0].array(), out->mutable_array());
+ }
+};
+
+// ----------------------------------------------------------------------
+// String to Timestamp
+
+struct ParseTimestamp {
+ template <typename OutValue, typename Arg0Value>
+ OutValue Call(KernelContext*, Arg0Value val, Status* st) const {
+ OutValue result = 0;
+ if (ARROW_PREDICT_FALSE(!ParseValue(type, val.data(), val.size(), &result))) {
+ *st = Status::Invalid("Failed to parse string: '", val, "' as a scalar of type ",
+ type.ToString());
+ }
+ return result;
+ }
+
+ const TimestampType& type;
+};
+
+template <typename I>
+struct CastFunctor<TimestampType, I, enable_if_t<is_base_binary_type<I>::value>> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& out_type = checked_cast<const TimestampType&>(*out->type());
+ applicator::ScalarUnaryNotNullStateful<TimestampType, I, ParseTimestamp> kernel(
+ ParseTimestamp{out_type});
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+template <typename Type>
+void AddCrossUnitCast(CastFunction* func) {
+ ScalarKernel kernel;
+ kernel.exec = TrivialScalarUnaryAsArraysExec(CastFunctor<Type, Type>::Exec);
+ kernel.signature = KernelSignature::Make({InputType(Type::type_id)}, kOutputTargetType);
+ DCHECK_OK(func->AddKernel(Type::type_id, std::move(kernel)));
+}
+
+std::shared_ptr<CastFunction> GetDate32Cast() {
+ auto func = std::make_shared<CastFunction>("cast_date32", Type::DATE32);
+ auto out_ty = date32();
+ AddCommonCasts(Type::DATE32, out_ty, func.get());
+
+ // int32 -> date32
+ AddZeroCopyCast(Type::INT32, int32(), date32(), func.get());
+
+ // date64 -> date32
+ AddSimpleCast<Date64Type, Date32Type>(date64(), date32(), func.get());
+
+ // timestamp -> date32
+ AddSimpleCast<TimestampType, Date32Type>(InputType(Type::TIMESTAMP), date32(),
+ func.get());
+ return func;
+}
+
+std::shared_ptr<CastFunction> GetDate64Cast() {
+ auto func = std::make_shared<CastFunction>("cast_date64", Type::DATE64);
+ auto out_ty = date64();
+ AddCommonCasts(Type::DATE64, out_ty, func.get());
+
+ // int64 -> date64
+ AddZeroCopyCast(Type::INT64, int64(), date64(), func.get());
+
+ // date32 -> date64
+ AddSimpleCast<Date32Type, Date64Type>(date32(), date64(), func.get());
+
+ // timestamp -> date64
+ AddSimpleCast<TimestampType, Date64Type>(InputType(Type::TIMESTAMP), date64(),
+ func.get());
+ return func;
+}
+
+std::shared_ptr<CastFunction> GetDurationCast() {
+ auto func = std::make_shared<CastFunction>("cast_duration", Type::DURATION);
+ AddCommonCasts(Type::DURATION, kOutputTargetType, func.get());
+
+ auto seconds = duration(TimeUnit::SECOND);
+ auto millis = duration(TimeUnit::MILLI);
+ auto micros = duration(TimeUnit::MICRO);
+ auto nanos = duration(TimeUnit::NANO);
+
+ // Same integer representation
+ AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get());
+
+ // Between durations
+ AddCrossUnitCast<DurationType>(func.get());
+
+ return func;
+}
+
+std::shared_ptr<CastFunction> GetIntervalCast() {
+ auto func = std::make_shared<CastFunction>("cast_month_day_nano_interval",
+ Type::INTERVAL_MONTH_DAY_NANO);
+ AddCommonCasts(Type::INTERVAL_MONTH_DAY_NANO, kOutputTargetType, func.get());
+ return func;
+}
+
+std::shared_ptr<CastFunction> GetTime32Cast() {
+ auto func = std::make_shared<CastFunction>("cast_time32", Type::TIME32);
+ AddCommonCasts(Type::TIME32, kOutputTargetType, func.get());
+
+ // Zero copy when the unit is the same or same integer representation
+ AddZeroCopyCast(Type::INT32, /*in_type=*/int32(), kOutputTargetType, func.get());
+
+ // time64 -> time32
+ AddSimpleCast<Time64Type, Time32Type>(InputType(Type::TIME64), kOutputTargetType,
+ func.get());
+
+ // time32 -> time32
+ AddCrossUnitCast<Time32Type>(func.get());
+
+ // timestamp -> time32
+ AddSimpleCast<TimestampType, Time32Type>(InputType(Type::TIMESTAMP), kOutputTargetType,
+ func.get());
+
+ return func;
+}
+
+std::shared_ptr<CastFunction> GetTime64Cast() {
+ auto func = std::make_shared<CastFunction>("cast_time64", Type::TIME64);
+ AddCommonCasts(Type::TIME64, kOutputTargetType, func.get());
+
+ // Zero copy when the unit is the same or same integer representation
+ AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get());
+
+ // time32 -> time64
+ AddSimpleCast<Time32Type, Time64Type>(InputType(Type::TIME32), kOutputTargetType,
+ func.get());
+
+ // Between durations
+ AddCrossUnitCast<Time64Type>(func.get());
+
+ // timestamp -> time64
+ AddSimpleCast<TimestampType, Time64Type>(InputType(Type::TIMESTAMP), kOutputTargetType,
+ func.get());
+
+ return func;
+}
+
+std::shared_ptr<CastFunction> GetTimestampCast() {
+ auto func = std::make_shared<CastFunction>("cast_timestamp", Type::TIMESTAMP);
+ AddCommonCasts(Type::TIMESTAMP, kOutputTargetType, func.get());
+
+ // Same integer representation
+ AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get());
+
+ // From date types
+ // TODO: ARROW-8876, these casts are not directly tested
+ AddSimpleCast<Date32Type, TimestampType>(InputType(Type::DATE32), kOutputTargetType,
+ func.get());
+ AddSimpleCast<Date64Type, TimestampType>(InputType(Type::DATE64), kOutputTargetType,
+ func.get());
+
+ // string -> timestamp
+ AddSimpleCast<StringType, TimestampType>(utf8(), kOutputTargetType, func.get());
+ // large_string -> timestamp
+ AddSimpleCast<LargeStringType, TimestampType>(large_utf8(), kOutputTargetType,
+ func.get());
+
+ // From one timestamp to another
+ AddCrossUnitCast<TimestampType>(func.get());
+
+ return func;
+}
+
+std::vector<std::shared_ptr<CastFunction>> GetTemporalCasts() {
+ std::vector<std::shared_ptr<CastFunction>> functions;
+
+ functions.push_back(GetDate32Cast());
+ functions.push_back(GetDate64Cast());
+ functions.push_back(GetDurationCast());
+ functions.push_back(GetIntervalCast());
+ functions.push_back(GetTime32Cast());
+ functions.push_back(GetTime64Cast());
+ functions.push_back(GetTimestampCast());
+ return functions;
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
new file mode 100644
index 000000000..d21fa24c7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
@@ -0,0 +1,2334 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <cstdio>
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_decimal.h"
+#include "arrow/buffer.h"
+#include "arrow/chunked_array.h"
+#include "arrow/extension_type.h"
+#include "arrow/status.h"
+#include "arrow/testing/extension_type.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bitmap.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/kernel.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/compute/kernels/test_util.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+static std::shared_ptr<Array> InvalidUtf8(std::shared_ptr<DataType> type) {
+ return ArrayFromJSON(type,
+ "["
+ R"(
+ "Hi",
+ "olá mundo",
+ "你好世界",
+ "",
+ )"
+ "\"\xa0\xa1\""
+ "]");
+}
+
+static std::shared_ptr<Array> FixedSizeInvalidUtf8(std::shared_ptr<DataType> type) {
+ if (type->id() == Type::FIXED_SIZE_BINARY) {
+ // Assume a particular width for testing
+ EXPECT_EQ(3, checked_cast<const FixedSizeBinaryType&>(*type).byte_width());
+ }
+ return ArrayFromJSON(type,
+ "["
+ R"(
+ "Hi!",
+ "lá",
+ "你",
+ " ",
+ )"
+ "\"\xa0\xa1\xa2\""
+ "]");
+}
+
+static std::vector<std::shared_ptr<DataType>> kNumericTypes = {
+ uint8(), int8(), uint16(), int16(), uint32(),
+ int32(), uint64(), int64(), float32(), float64()};
+
+static std::vector<std::shared_ptr<DataType>> kIntegerTypes = {
+ int8(), uint8(), int16(), uint16(), int32(), uint32(), int64(), uint64()};
+
+static std::vector<std::shared_ptr<DataType>> kDictionaryIndexTypes = kIntegerTypes;
+
+static std::vector<std::shared_ptr<DataType>> kBaseBinaryTypes = {
+ binary(), utf8(), large_binary(), large_utf8()};
+
+static void AssertBufferSame(const Array& left, const Array& right, int buffer_index) {
+ ASSERT_EQ(left.data()->buffers[buffer_index].get(),
+ right.data()->buffers[buffer_index].get());
+}
+
+static void CheckCast(std::shared_ptr<Array> input, std::shared_ptr<Array> expected,
+ CastOptions options = CastOptions{}) {
+ options.to_type = expected->type();
+ CheckScalarUnary("cast", input, expected, &options);
+}
+
+static void CheckCastFails(std::shared_ptr<Array> input, CastOptions options) {
+ ASSERT_RAISES(Invalid, Cast(input, options))
+ << "\n to_type: " << options.to_type->ToString()
+ << "\n from_type: " << input->type()->ToString()
+ << "\n input: " << input->ToString();
+
+ // For the scalars, check that at least one of the input fails (since many
+ // of the tests contains a mix of passing and failing values). In some
+ // cases we will want to check more precisely
+ int64_t num_failing = 0;
+ for (int64_t i = 0; i < input->length(); ++i) {
+ ASSERT_OK_AND_ASSIGN(auto scalar, input->GetScalar(i));
+ num_failing += static_cast<int>(Cast(scalar, options).status().IsInvalid());
+ }
+ ASSERT_GT(num_failing, 0);
+}
+
+static void CheckCastZeroCopy(std::shared_ptr<Array> input,
+ std::shared_ptr<DataType> to_type,
+ CastOptions options = CastOptions::Safe()) {
+ ASSERT_OK_AND_ASSIGN(auto converted, Cast(*input, to_type, options));
+ ValidateOutput(*converted);
+
+ ASSERT_EQ(input->data()->buffers.size(), converted->data()->buffers.size());
+ for (size_t i = 0; i < input->data()->buffers.size(); ++i) {
+ AssertBufferSame(*input, *converted, static_cast<int>(i));
+ }
+}
+
+static std::shared_ptr<Array> MaskArrayWithNullsAt(std::shared_ptr<Array> input,
+ std::vector<int> indices_to_mask) {
+ auto masked = input->data()->Copy();
+ masked->buffers[0] = *AllocateEmptyBitmap(input->length());
+ masked->null_count = kUnknownNullCount;
+
+ using arrow::internal::Bitmap;
+ Bitmap is_valid(masked->buffers[0], 0, input->length());
+ if (auto original = input->null_bitmap()) {
+ is_valid.CopyFrom(Bitmap(original, input->offset(), input->length()));
+ } else {
+ is_valid.SetBitsTo(true);
+ }
+
+ for (int i : indices_to_mask) {
+ is_valid.SetBitTo(i, false);
+ }
+ return MakeArray(masked);
+}
+
+TEST(Cast, CanCast) {
+ auto ExpectCanCast = [](std::shared_ptr<DataType> from,
+ std::vector<std::shared_ptr<DataType>> to_set,
+ bool expected = true) {
+ for (auto to : to_set) {
+ EXPECT_EQ(CanCast(*from, *to), expected) << " from: " << from->ToString() << "\n"
+ << " to: " << to->ToString();
+ }
+ };
+
+ auto ExpectCannotCast = [ExpectCanCast](std::shared_ptr<DataType> from,
+ std::vector<std::shared_ptr<DataType>> to_set) {
+ ExpectCanCast(from, to_set, /*expected=*/false);
+ };
+
+ ExpectCanCast(null(), {boolean()});
+ ExpectCanCast(null(), kNumericTypes);
+ ExpectCanCast(null(), kBaseBinaryTypes);
+ ExpectCanCast(
+ null(), {date32(), date64(), time32(TimeUnit::MILLI), timestamp(TimeUnit::SECOND)});
+ ExpectCanCast(dictionary(uint16(), null()), {null()});
+
+ ExpectCanCast(boolean(), {boolean()});
+ ExpectCanCast(boolean(), kNumericTypes);
+ ExpectCanCast(boolean(), {utf8(), large_utf8()});
+ ExpectCanCast(dictionary(int32(), boolean()), {boolean()});
+
+ ExpectCannotCast(boolean(), {null()});
+ ExpectCannotCast(boolean(), {binary(), large_binary()});
+ ExpectCannotCast(boolean(), {date32(), date64(), time32(TimeUnit::MILLI),
+ timestamp(TimeUnit::SECOND)});
+
+ for (auto from_numeric : kNumericTypes) {
+ ExpectCanCast(from_numeric, {boolean()});
+ ExpectCanCast(from_numeric, kNumericTypes);
+ ExpectCanCast(from_numeric, {utf8(), large_utf8()});
+ ExpectCanCast(dictionary(int32(), from_numeric), {from_numeric});
+
+ ExpectCannotCast(from_numeric, {null()});
+ }
+
+ for (auto from_base_binary : kBaseBinaryTypes) {
+ ExpectCanCast(from_base_binary, {boolean()});
+ ExpectCanCast(from_base_binary, kNumericTypes);
+ ExpectCanCast(from_base_binary, kBaseBinaryTypes);
+ ExpectCanCast(dictionary(int64(), from_base_binary), {from_base_binary});
+
+ // any cast which is valid for the dictionary is valid for the DictionaryArray
+ ExpectCanCast(dictionary(uint32(), from_base_binary), kBaseBinaryTypes);
+ ExpectCanCast(dictionary(int16(), from_base_binary), kNumericTypes);
+
+ ExpectCannotCast(from_base_binary, {null()});
+ }
+
+ ExpectCanCast(utf8(), {timestamp(TimeUnit::MILLI)});
+ ExpectCanCast(large_utf8(), {timestamp(TimeUnit::NANO)});
+ ExpectCannotCast(timestamp(TimeUnit::MICRO),
+ {binary(), large_binary()}); // no formatting supported
+
+ ExpectCanCast(fixed_size_binary(3),
+ {binary(), utf8(), large_binary(), large_utf8(), fixed_size_binary(3)});
+ // Doesn't fail since a kernel exists (but it will return an error when executed)
+ // ExpectCannotCast(fixed_size_binary(3), {fixed_size_binary(5)});
+
+ ExtensionTypeGuard smallint_guard(smallint());
+ ExpectCanCast(smallint(), {int16()}); // cast storage
+ ExpectCanCast(smallint(),
+ kNumericTypes); // any cast which is valid for storage is supported
+ ExpectCannotCast(null(), {smallint()}); // FIXME missing common cast from null
+
+ ExpectCanCast(date32(), {utf8(), large_utf8()});
+ ExpectCanCast(date64(), {utf8(), large_utf8()});
+ ExpectCanCast(timestamp(TimeUnit::NANO), {utf8(), large_utf8()});
+ ExpectCanCast(timestamp(TimeUnit::MICRO), {utf8(), large_utf8()});
+ ExpectCanCast(time32(TimeUnit::MILLI), {utf8(), large_utf8()});
+ ExpectCanCast(time64(TimeUnit::NANO), {utf8(), large_utf8()});
+}
+
+TEST(Cast, SameTypeZeroCopy) {
+ std::shared_ptr<Array> arr = ArrayFromJSON(int32(), "[0, null, 2, 3, 4]");
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> result, Cast(*arr, int32()));
+
+ AssertBufferSame(*arr, *result, 0);
+ AssertBufferSame(*arr, *result, 1);
+}
+
+TEST(Cast, ZeroChunks) {
+ auto chunked_i32 = std::make_shared<ChunkedArray>(ArrayVector{}, int32());
+ ASSERT_OK_AND_ASSIGN(Datum result, Cast(chunked_i32, utf8()));
+
+ ASSERT_EQ(result.kind(), Datum::CHUNKED_ARRAY);
+ AssertChunkedEqual(*result.chunked_array(), ChunkedArray({}, utf8()));
+}
+
+TEST(Cast, CastDoesNotProvideDefaultOptions) {
+ std::shared_ptr<Array> arr = ArrayFromJSON(int32(), "[0, null, 2, 3, 4]");
+ ASSERT_RAISES(Invalid, CallFunction("cast", {arr}));
+}
+
+TEST(Cast, FromBoolean) {
+ std::string vals = "[1, 0, null, 1, 0, 1, 1, null, 0, 0, 1]";
+ CheckCast(ArrayFromJSON(boolean(), vals), ArrayFromJSON(int32(), vals));
+}
+
+TEST(Cast, ToBoolean) {
+ for (auto type : kNumericTypes) {
+ CheckCast(ArrayFromJSON(type, "[0, null, 127, 1, 0]"),
+ ArrayFromJSON(boolean(), "[false, null, true, true, false]"));
+ }
+
+ // Check negative numbers
+ for (auto type : {int8(), float64()}) {
+ CheckCast(ArrayFromJSON(type, "[0, null, 127, -1, 0]"),
+ ArrayFromJSON(boolean(), "[false, null, true, true, false]"));
+ }
+}
+
+TEST(Cast, ToIntUpcast) {
+ std::vector<bool> is_valid = {true, false, true, true, true};
+
+ // int8 to int32
+ CheckCast(ArrayFromJSON(int8(), "[0, null, 127, -1, 0]"),
+ ArrayFromJSON(int32(), "[0, null, 127, -1, 0]"));
+
+ // uint8 to int16, no overflow/underrun
+ CheckCast(ArrayFromJSON(uint8(), "[0, 100, 200, 255, 0]"),
+ ArrayFromJSON(int16(), "[0, 100, 200, 255, 0]"));
+}
+
+TEST(Cast, OverflowInNullSlot) {
+ CheckCast(
+ MaskArrayWithNullsAt(ArrayFromJSON(int32(), "[0, 87654321, 2000, 1000, 0]"), {1}),
+ ArrayFromJSON(int16(), "[0, null, 2000, 1000, 0]"));
+}
+
+TEST(Cast, ToIntDowncastSafe) {
+ // int16 to uint8, no overflow/underflow
+ CheckCast(ArrayFromJSON(int16(), "[0, null, 200, 1, 2]"),
+ ArrayFromJSON(uint8(), "[0, null, 200, 1, 2]"));
+
+ // int16 to uint8, overflow
+ CheckCastFails(ArrayFromJSON(int16(), "[0, null, 256, 0, 0]"),
+ CastOptions::Safe(uint8()));
+ // ... and underflow
+ CheckCastFails(ArrayFromJSON(int16(), "[0, null, -1, 0, 0]"),
+ CastOptions::Safe(uint8()));
+
+ // int32 to int16, no overflow/underflow
+ CheckCast(ArrayFromJSON(int32(), "[0, null, 2000, 1, 2]"),
+ ArrayFromJSON(int16(), "[0, null, 2000, 1, 2]"));
+
+ // int32 to int16, overflow
+ CheckCastFails(ArrayFromJSON(int32(), "[0, null, 2000, 70000, 2]"),
+ CastOptions::Safe(int16()));
+
+ // ... and underflow
+ CheckCastFails(ArrayFromJSON(int32(), "[0, null, 2000, -70000, 2]"),
+ CastOptions::Safe(int16()));
+
+ CheckCastFails(ArrayFromJSON(int32(), "[0, null, 2000, -70000, 2]"),
+ CastOptions::Safe(uint8()));
+}
+
+TEST(Cast, IntegerSignedToUnsigned) {
+ auto i32s = ArrayFromJSON(int32(), "[-2147483648, null, -1, 65535, 2147483647]");
+ // Same width
+ CheckCastFails(i32s, CastOptions::Safe(uint32()));
+ // Wider
+ CheckCastFails(i32s, CastOptions::Safe(uint64()));
+ // Narrower
+ CheckCastFails(i32s, CastOptions::Safe(uint16()));
+
+ CastOptions options;
+ options.allow_int_overflow = true;
+
+ CheckCast(i32s,
+ ArrayFromJSON(uint32(), "[2147483648, null, 4294967295, 65535, 2147483647]"),
+ options);
+ CheckCast(i32s,
+ ArrayFromJSON(
+ uint64(),
+ "[18446744071562067968, null, 18446744073709551615, 65535, 2147483647]"),
+ options);
+ CheckCast(i32s, ArrayFromJSON(uint16(), "[0, null, 65535, 65535, 65535]"), options);
+
+ // Fail because of overflow (instead of underflow).
+ i32s = ArrayFromJSON(int32(), "[0, null, 0, 65536, 2147483647]");
+ CheckCastFails(i32s, CastOptions::Safe(uint16()));
+
+ CheckCast(i32s, ArrayFromJSON(uint16(), "[0, null, 0, 0, 65535]"), options);
+}
+
+TEST(Cast, IntegerUnsignedToSigned) {
+ auto u32s = ArrayFromJSON(uint32(), "[4294967295, null, 0, 32768]");
+ // Same width
+ CheckCastFails(u32s, CastOptions::Safe(int32()));
+
+ // Narrower
+ CheckCastFails(u32s, CastOptions::Safe(int16()));
+ CheckCastFails(u32s->Slice(1), CastOptions::Safe(int16()));
+
+ CastOptions options;
+ options.allow_int_overflow = true;
+
+ CheckCast(u32s, ArrayFromJSON(int32(), "[-1, null, 0, 32768]"), options);
+ CheckCast(u32s, ArrayFromJSON(int64(), "[4294967295, null, 0, 32768]"), options);
+ CheckCast(u32s, ArrayFromJSON(int16(), "[-1, null, 0, -32768]"), options);
+}
+
+TEST(Cast, ToIntDowncastUnsafe) {
+ CastOptions options;
+ options.allow_int_overflow = true;
+
+ // int16 to uint8, no overflow/underflow
+ CheckCast(ArrayFromJSON(int16(), "[0, null, 200, 1, 2]"),
+ ArrayFromJSON(uint8(), "[0, null, 200, 1, 2]"), options);
+
+ // int16 to uint8, with overflow/underflow
+ CheckCast(ArrayFromJSON(int16(), "[0, null, 256, 1, 2, -1]"),
+ ArrayFromJSON(uint8(), "[0, null, 0, 1, 2, 255]"), options);
+
+ // int32 to int16, no overflow/underflow
+ CheckCast(ArrayFromJSON(int32(), "[0, null, 2000, 1, 2, -1]"),
+ ArrayFromJSON(int16(), "[0, null, 2000, 1, 2, -1]"), options);
+
+ // int32 to int16, with overflow/underflow
+ CheckCast(ArrayFromJSON(int32(), "[0, null, 2000, 70000, -70000]"),
+ ArrayFromJSON(int16(), "[0, null, 2000, 4464, -4464]"), options);
+}
+
+TEST(Cast, FloatingToInt) {
+ for (auto from : {float32(), float64()}) {
+ for (auto to : {int32(), int64()}) {
+ // float to int no truncation
+ CheckCast(ArrayFromJSON(from, "[1.0, null, 0.0, -1.0, 5.0]"),
+ ArrayFromJSON(to, "[1, null, 0, -1, 5]"));
+
+ // float to int truncate error
+ auto opts = CastOptions::Safe(to);
+ CheckCastFails(ArrayFromJSON(from, "[1.5, 0.0, null, 0.5, -1.5, 5.5]"), opts);
+
+ // float to int truncate allowed
+ opts.allow_float_truncate = true;
+ CheckCast(ArrayFromJSON(from, "[1.5, 0.0, null, 0.5, -1.5, 5.5]"),
+ ArrayFromJSON(to, "[1, 0, null, 0, -1, 5]"), opts);
+ }
+ }
+}
+
+TEST(Cast, IntToFloating) {
+ for (auto from : {uint32(), int32()}) {
+ std::string two_24 = "[16777216, 16777217]";
+
+ CheckCastFails(ArrayFromJSON(from, two_24), CastOptions::Safe(float32()));
+
+ CheckCast(ArrayFromJSON(from, two_24)->Slice(0, 1),
+ ArrayFromJSON(float32(), two_24)->Slice(0, 1));
+ }
+
+ auto i64s = ArrayFromJSON(int64(),
+ "[-9223372036854775808, -9223372036854775807, 0,"
+ " 9223372036854775806, 9223372036854775807]");
+ CheckCastFails(i64s, CastOptions::Safe(float64()));
+
+ // Masking those values with nulls makes this safe
+ CheckCast(MaskArrayWithNullsAt(i64s, {0, 1, 3, 4}),
+ ArrayFromJSON(float64(), "[null, null, 0, null, null]"));
+
+ CheckCastFails(ArrayFromJSON(uint64(), "[9007199254740992, 9007199254740993]"),
+ CastOptions::Safe(float64()));
+}
+
+TEST(Cast, Decimal128ToInt) {
+ auto options = CastOptions::Safe(int64());
+
+ for (bool allow_int_overflow : {false, true}) {
+ for (bool allow_decimal_truncate : {false, true}) {
+ options.allow_int_overflow = allow_int_overflow;
+ options.allow_decimal_truncate = allow_decimal_truncate;
+
+ auto no_overflow_no_truncation = ArrayFromJSON(decimal(38, 10), R"([
+ "02.0000000000",
+ "-11.0000000000",
+ "22.0000000000",
+ "-121.0000000000",
+ null])");
+ CheckCast(no_overflow_no_truncation,
+ ArrayFromJSON(int64(), "[2, -11, 22, -121, null]"), options);
+ }
+ }
+
+ for (bool allow_int_overflow : {false, true}) {
+ options.allow_int_overflow = allow_int_overflow;
+ auto truncation_but_no_overflow = ArrayFromJSON(decimal(38, 10), R"([
+ "02.1000000000",
+ "-11.0000004500",
+ "22.0000004500",
+ "-121.1210000000",
+ null])");
+
+ options.allow_decimal_truncate = true;
+ CheckCast(truncation_but_no_overflow,
+ ArrayFromJSON(int64(), "[2, -11, 22, -121, null]"), options);
+
+ options.allow_decimal_truncate = false;
+ CheckCastFails(truncation_but_no_overflow, options);
+ }
+
+ for (bool allow_decimal_truncate : {false, true}) {
+ options.allow_decimal_truncate = allow_decimal_truncate;
+
+ auto overflow_no_truncation = ArrayFromJSON(decimal(38, 10), R"([
+ "12345678901234567890000.0000000000",
+ "99999999999999999999999.0000000000",
+ null])");
+
+ options.allow_int_overflow = true;
+ CheckCast(
+ overflow_no_truncation,
+ ArrayFromJSON(int64(),
+ // 12345678901234567890000 % 2**64, 99999999999999999999999 % 2**64
+ "[4807115922877858896, 200376420520689663, null]"),
+ options);
+
+ options.allow_int_overflow = false;
+ CheckCastFails(overflow_no_truncation, options);
+ }
+
+ for (bool allow_int_overflow : {false, true}) {
+ for (bool allow_decimal_truncate : {false, true}) {
+ options.allow_int_overflow = allow_int_overflow;
+ options.allow_decimal_truncate = allow_decimal_truncate;
+
+ auto overflow_and_truncation = ArrayFromJSON(decimal(38, 10), R"([
+ "12345678901234567890000.0045345000",
+ "99999999999999999999999.0000344300",
+ null])");
+
+ if (options.allow_int_overflow && options.allow_decimal_truncate) {
+ CheckCast(overflow_and_truncation,
+ ArrayFromJSON(
+ int64(),
+ // 12345678901234567890000 % 2**64, 99999999999999999999999 % 2**64
+ "[4807115922877858896, 200376420520689663, null]"),
+ options);
+ } else {
+ CheckCastFails(overflow_and_truncation, options);
+ }
+ }
+ }
+
+ Decimal128Builder builder(decimal(38, -4));
+ for (auto d : {Decimal128("1234567890000."), Decimal128("-120000.")}) {
+ ASSERT_OK_AND_ASSIGN(d, d.Rescale(0, -4));
+ ASSERT_OK(builder.Append(d));
+ }
+ ASSERT_OK_AND_ASSIGN(auto negative_scale, builder.Finish());
+ options.allow_int_overflow = true;
+ options.allow_decimal_truncate = true;
+ CheckCast(negative_scale, ArrayFromJSON(int64(), "[1234567890000, -120000]"), options);
+}
+
+TEST(Cast, Decimal256ToInt) {
+ auto options = CastOptions::Safe(int64());
+
+ for (bool allow_int_overflow : {false, true}) {
+ for (bool allow_decimal_truncate : {false, true}) {
+ options.allow_int_overflow = allow_int_overflow;
+ options.allow_decimal_truncate = allow_decimal_truncate;
+
+ auto no_overflow_no_truncation = ArrayFromJSON(decimal256(40, 10), R"([
+ "02.0000000000",
+ "-11.0000000000",
+ "22.0000000000",
+ "-121.0000000000",
+ null])");
+ CheckCast(no_overflow_no_truncation,
+ ArrayFromJSON(int64(), "[2, -11, 22, -121, null]"), options);
+ }
+ }
+
+ for (bool allow_int_overflow : {false, true}) {
+ options.allow_int_overflow = allow_int_overflow;
+ auto truncation_but_no_overflow = ArrayFromJSON(decimal256(40, 10), R"([
+ "02.1000000000",
+ "-11.0000004500",
+ "22.0000004500",
+ "-121.1210000000",
+ null])");
+
+ options.allow_decimal_truncate = true;
+ CheckCast(truncation_but_no_overflow,
+ ArrayFromJSON(int64(), "[2, -11, 22, -121, null]"), options);
+
+ options.allow_decimal_truncate = false;
+ CheckCastFails(truncation_but_no_overflow, options);
+ }
+
+ for (bool allow_decimal_truncate : {false, true}) {
+ options.allow_decimal_truncate = allow_decimal_truncate;
+
+ auto overflow_no_truncation = ArrayFromJSON(decimal256(40, 10), R"([
+ "1234567890123456789000000.0000000000",
+ "9999999999999999999999999.0000000000",
+ null])");
+
+ options.allow_int_overflow = true;
+ CheckCast(overflow_no_truncation,
+ ArrayFromJSON(
+ int64(),
+ // 1234567890123456789000000 % 2**64, 9999999999999999999999999 % 2**64
+ "[1096246371337547584, 1590897978359414783, null]"),
+ options);
+
+ options.allow_int_overflow = false;
+ CheckCastFails(overflow_no_truncation, options);
+ }
+
+ for (bool allow_int_overflow : {false, true}) {
+ for (bool allow_decimal_truncate : {false, true}) {
+ options.allow_int_overflow = allow_int_overflow;
+ options.allow_decimal_truncate = allow_decimal_truncate;
+
+ auto overflow_and_truncation = ArrayFromJSON(decimal256(40, 10), R"([
+ "1234567890123456789000000.0045345000",
+ "9999999999999999999999999.0000344300",
+ null])");
+
+ if (options.allow_int_overflow && options.allow_decimal_truncate) {
+ CheckCast(
+ overflow_and_truncation,
+ ArrayFromJSON(
+ int64(),
+ // 1234567890123456789000000 % 2**64, 9999999999999999999999999 % 2**64
+ "[1096246371337547584, 1590897978359414783, null]"),
+ options);
+ } else {
+ CheckCastFails(overflow_and_truncation, options);
+ }
+ }
+ }
+
+ Decimal256Builder builder(decimal256(40, -4));
+ for (auto d : {Decimal256("1234567890000."), Decimal256("-120000.")}) {
+ ASSERT_OK_AND_ASSIGN(d, d.Rescale(0, -4));
+ ASSERT_OK(builder.Append(d));
+ }
+ ASSERT_OK_AND_ASSIGN(auto negative_scale, builder.Finish());
+ options.allow_int_overflow = true;
+ options.allow_decimal_truncate = true;
+ CheckCast(negative_scale, ArrayFromJSON(int64(), "[1234567890000, -120000]"), options);
+}
+
+TEST(Cast, IntegerToDecimal) {
+ for (auto decimal_type : {decimal128(21, 2), decimal256(21, 2)}) {
+ for (auto integer_type : kIntegerTypes) {
+ CheckCast(
+ ArrayFromJSON(integer_type, "[0, 7, null, 100, 99]"),
+ ArrayFromJSON(decimal_type, R"(["0.00", "7.00", null, "100.00", "99.00"])"));
+ }
+ }
+
+ // extreme value
+ for (auto decimal_type : {decimal128(19, 0), decimal256(19, 0)}) {
+ CheckCast(ArrayFromJSON(int64(), "[-9223372036854775808, 9223372036854775807]"),
+ ArrayFromJSON(decimal_type,
+ R"(["-9223372036854775808", "9223372036854775807"])"));
+ CheckCast(ArrayFromJSON(uint64(), "[0, 18446744073709551615]"),
+ ArrayFromJSON(decimal_type, R"(["0", "18446744073709551615"])"));
+ }
+
+ // insufficient output precision
+ {
+ CastOptions options;
+
+ options.to_type = decimal128(5, 3);
+ CheckCastFails(ArrayFromJSON(int8(), "[0]"), options);
+
+ options.to_type = decimal256(76, 67);
+ CheckCastFails(ArrayFromJSON(int32(), "[0]"), options);
+ }
+}
+
+TEST(Cast, Decimal128ToDecimal128) {
+ CastOptions options;
+
+ for (bool allow_decimal_truncate : {false, true}) {
+ options.allow_decimal_truncate = allow_decimal_truncate;
+
+ auto no_truncation = ArrayFromJSON(decimal(38, 10), R"([
+ "02.0000000000",
+ "30.0000000000",
+ "22.0000000000",
+ "-121.0000000000",
+ null])");
+ auto expected = ArrayFromJSON(decimal(28, 0), R"([
+ "02.",
+ "30.",
+ "22.",
+ "-121.",
+ null])");
+
+ CheckCast(no_truncation, expected, options);
+ CheckCast(expected, no_truncation, options);
+ }
+
+ for (bool allow_decimal_truncate : {false, true}) {
+ options.allow_decimal_truncate = allow_decimal_truncate;
+
+ // Same scale, different precision
+ auto d_5_2 = ArrayFromJSON(decimal(5, 2), R"([
+ "12.34",
+ "0.56"])");
+ auto d_4_2 = ArrayFromJSON(decimal(4, 2), R"([
+ "12.34",
+ "0.56"])");
+
+ CheckCast(d_5_2, d_4_2, options);
+ CheckCast(d_4_2, d_5_2, options);
+ }
+
+ auto d_38_10 = ArrayFromJSON(decimal(38, 10), R"([
+ "-02.1234567890",
+ "30.1234567890",
+ null])");
+
+ auto d_28_0 = ArrayFromJSON(decimal(28, 0), R"([
+ "-02.",
+ "30.",
+ null])");
+
+ auto d_38_10_roundtripped = ArrayFromJSON(decimal(38, 10), R"([
+ "-02.0000000000",
+ "30.0000000000",
+ null])");
+
+ // Rescale which leads to truncation
+ options.allow_decimal_truncate = true;
+ CheckCast(d_38_10, d_28_0, options);
+ CheckCast(d_28_0, d_38_10_roundtripped, options);
+
+ options.allow_decimal_truncate = false;
+ options.to_type = d_28_0->type();
+ CheckCastFails(d_38_10, options);
+ CheckCast(d_28_0, d_38_10_roundtripped, options);
+
+ // Precision loss without rescale leads to truncation
+ auto d_4_2 = ArrayFromJSON(decimal(4, 2), R"(["12.34"])");
+ for (auto expected : {
+ ArrayFromJSON(decimal(3, 2), R"(["12.34"])"),
+ ArrayFromJSON(decimal(4, 3), R"(["12.340"])"),
+ ArrayFromJSON(decimal(2, 1), R"(["12.3"])"),
+ }) {
+ options.allow_decimal_truncate = true;
+ CheckCast(d_4_2, expected, options);
+
+ options.allow_decimal_truncate = false;
+ options.to_type = expected->type();
+ CheckCastFails(d_4_2, options);
+ }
+}
+
+TEST(Cast, Decimal256ToDecimal256) {
+ CastOptions options;
+
+ for (bool allow_decimal_truncate : {false, true}) {
+ options.allow_decimal_truncate = allow_decimal_truncate;
+
+ auto no_truncation = ArrayFromJSON(decimal256(38, 10), R"([
+ "02.0000000000",
+ "30.0000000000",
+ "22.0000000000",
+ "-121.0000000000",
+ null])");
+ auto expected = ArrayFromJSON(decimal256(28, 0), R"([
+ "02.",
+ "30.",
+ "22.",
+ "-121.",
+ null])");
+
+ CheckCast(no_truncation, expected, options);
+ CheckCast(expected, no_truncation, options);
+ }
+
+ for (bool allow_decimal_truncate : {false, true}) {
+ options.allow_decimal_truncate = allow_decimal_truncate;
+
+ // Same scale, different precision
+ auto d_5_2 = ArrayFromJSON(decimal256(5, 2), R"([
+ "12.34",
+ "0.56"])");
+ auto d_4_2 = ArrayFromJSON(decimal256(4, 2), R"([
+ "12.34",
+ "0.56"])");
+
+ CheckCast(d_5_2, d_4_2, options);
+ CheckCast(d_4_2, d_5_2, options);
+ }
+
+ auto d_38_10 = ArrayFromJSON(decimal256(38, 10), R"([
+ "-02.1234567890",
+ "30.1234567890",
+ null])");
+
+ auto d_28_0 = ArrayFromJSON(decimal256(28, 0), R"([
+ "-02.",
+ "30.",
+ null])");
+
+ auto d_38_10_roundtripped = ArrayFromJSON(decimal256(38, 10), R"([
+ "-02.0000000000",
+ "30.0000000000",
+ null])");
+
+ // Rescale which leads to truncation
+ options.allow_decimal_truncate = true;
+ CheckCast(d_38_10, d_28_0, options);
+ CheckCast(d_28_0, d_38_10_roundtripped, options);
+
+ options.allow_decimal_truncate = false;
+ options.to_type = d_28_0->type();
+ CheckCastFails(d_38_10, options);
+ CheckCast(d_28_0, d_38_10_roundtripped, options);
+
+ // Precision loss without rescale leads to truncation
+ auto d_4_2 = ArrayFromJSON(decimal256(4, 2), R"(["12.34"])");
+ for (auto expected : {
+ ArrayFromJSON(decimal256(3, 2), R"(["12.34"])"),
+ ArrayFromJSON(decimal256(4, 3), R"(["12.340"])"),
+ ArrayFromJSON(decimal256(2, 1), R"(["12.3"])"),
+ }) {
+ options.allow_decimal_truncate = true;
+ CheckCast(d_4_2, expected, options);
+
+ options.allow_decimal_truncate = false;
+ options.to_type = expected->type();
+ CheckCastFails(d_4_2, options);
+ }
+}
+
+TEST(Cast, Decimal128ToDecimal256) {
+ CastOptions options;
+
+ for (bool allow_decimal_truncate : {false, true}) {
+ options.allow_decimal_truncate = allow_decimal_truncate;
+
+ auto no_truncation = ArrayFromJSON(decimal(38, 10), R"([
+ "02.0000000000",
+ "30.0000000000",
+ "22.0000000000",
+ "-121.0000000000",
+ null])");
+ auto expected = ArrayFromJSON(decimal256(48, 0), R"([
+ "02.",
+ "30.",
+ "22.",
+ "-121.",
+ null])");
+
+ CheckCast(no_truncation, expected, options);
+ }
+
+ for (bool allow_decimal_truncate : {false, true}) {
+ options.allow_decimal_truncate = allow_decimal_truncate;
+
+ // Same scale, different precision
+ auto d_5_2 = ArrayFromJSON(decimal(5, 2), R"([
+ "12.34",
+ "0.56"])");
+ auto d_4_2 = ArrayFromJSON(decimal256(4, 2), R"([
+ "12.34",
+ "0.56"])");
+ auto d_40_2 = ArrayFromJSON(decimal256(40, 2), R"([
+ "12.34",
+ "0.56"])");
+
+ CheckCast(d_5_2, d_4_2, options);
+ CheckCast(d_5_2, d_40_2, options);
+ }
+
+ auto d128_38_10 = ArrayFromJSON(decimal(38, 10), R"([
+ "-02.1234567890",
+ "30.1234567890",
+ null])");
+
+ auto d128_28_0 = ArrayFromJSON(decimal(28, 0), R"([
+ "-02.",
+ "30.",
+ null])");
+
+ auto d256_28_0 = ArrayFromJSON(decimal256(28, 0), R"([
+ "-02.",
+ "30.",
+ null])");
+
+ auto d256_38_10_roundtripped = ArrayFromJSON(decimal256(38, 10), R"([
+ "-02.0000000000",
+ "30.0000000000",
+ null])");
+
+ // Rescale which leads to truncation
+ options.allow_decimal_truncate = true;
+ CheckCast(d128_38_10, d256_28_0, options);
+ CheckCast(d128_28_0, d256_38_10_roundtripped, options);
+
+ options.allow_decimal_truncate = false;
+ options.to_type = d256_28_0->type();
+ CheckCastFails(d128_38_10, options);
+ CheckCast(d128_28_0, d256_38_10_roundtripped, options);
+
+ // Precision loss without rescale leads to truncation
+ auto d128_4_2 = ArrayFromJSON(decimal(4, 2), R"(["12.34"])");
+ for (auto expected : {
+ ArrayFromJSON(decimal256(3, 2), R"(["12.34"])"),
+ ArrayFromJSON(decimal256(4, 3), R"(["12.340"])"),
+ ArrayFromJSON(decimal256(2, 1), R"(["12.3"])"),
+ }) {
+ options.allow_decimal_truncate = true;
+ CheckCast(d128_4_2, expected, options);
+
+ options.allow_decimal_truncate = false;
+ options.to_type = expected->type();
+ CheckCastFails(d128_4_2, options);
+ }
+}
+
+TEST(Cast, Decimal256ToDecimal128) {
+ CastOptions options;
+
+ for (bool allow_decimal_truncate : {false, true}) {
+ options.allow_decimal_truncate = allow_decimal_truncate;
+
+ auto no_truncation = ArrayFromJSON(decimal256(42, 10), R"([
+ "02.0000000000",
+ "30.0000000000",
+ "22.0000000000",
+ "-121.0000000000",
+ null])");
+ auto expected = ArrayFromJSON(decimal(28, 0), R"([
+ "02.",
+ "30.",
+ "22.",
+ "-121.",
+ null])");
+
+ CheckCast(no_truncation, expected, options);
+ }
+
+ for (bool allow_decimal_truncate : {false, true}) {
+ options.allow_decimal_truncate = allow_decimal_truncate;
+
+ // Same scale, different precision
+ auto d_5_2 = ArrayFromJSON(decimal256(42, 2), R"([
+ "12.34",
+ "0.56"])");
+ auto d_4_2 = ArrayFromJSON(decimal(4, 2), R"([
+ "12.34",
+ "0.56"])");
+
+ CheckCast(d_5_2, d_4_2, options);
+ }
+
+ auto d256_52_10 = ArrayFromJSON(decimal256(52, 10), R"([
+ "-02.1234567890",
+ "30.1234567890",
+ null])");
+
+ auto d256_42_0 = ArrayFromJSON(decimal256(42, 0), R"([
+ "-02.",
+ "30.",
+ null])");
+
+ auto d128_28_0 = ArrayFromJSON(decimal(28, 0), R"([
+ "-02.",
+ "30.",
+ null])");
+
+ auto d128_38_10_roundtripped = ArrayFromJSON(decimal(38, 10), R"([
+ "-02.0000000000",
+ "30.0000000000",
+ null])");
+
+ // Rescale which leads to truncation
+ options.allow_decimal_truncate = true;
+ CheckCast(d256_52_10, d128_28_0, options);
+ CheckCast(d256_42_0, d128_38_10_roundtripped, options);
+
+ options.allow_decimal_truncate = false;
+ options.to_type = d128_28_0->type();
+ CheckCastFails(d256_52_10, options);
+ CheckCast(d256_42_0, d128_38_10_roundtripped, options);
+
+ // Precision loss without rescale leads to truncation
+ auto d256_4_2 = ArrayFromJSON(decimal256(4, 2), R"(["12.34"])");
+ for (auto expected : {
+ ArrayFromJSON(decimal(3, 2), R"(["12.34"])"),
+ ArrayFromJSON(decimal(4, 3), R"(["12.340"])"),
+ ArrayFromJSON(decimal(2, 1), R"(["12.3"])"),
+ }) {
+ options.allow_decimal_truncate = true;
+ CheckCast(d256_4_2, expected, options);
+
+ options.allow_decimal_truncate = false;
+ options.to_type = expected->type();
+ CheckCastFails(d256_4_2, options);
+ }
+}
+
+TEST(Cast, FloatingToDecimal) {
+ for (auto float_type : {float32(), float64()}) {
+ for (auto decimal_type : {decimal(5, 2), decimal256(5, 2)}) {
+ CheckCast(
+ ArrayFromJSON(float_type, "[0.0, null, 123.45, 123.456, 999.994]"),
+ ArrayFromJSON(decimal_type, R"(["0.00", null, "123.45", "123.46", "999.99"])"));
+
+ // Overflow
+ CastOptions options;
+ options.to_type = decimal_type;
+ CheckCastFails(ArrayFromJSON(float_type, "[999.996]"), options);
+
+ options.allow_decimal_truncate = true;
+ CheckCast(
+ ArrayFromJSON(float_type, "[0.0, null, 999.996, 123.45, 999.994]"),
+ ArrayFromJSON(decimal_type, R"(["0.00", null, "0.00", "123.45", "999.99"])"),
+ options);
+ }
+ }
+
+ for (auto decimal_type : {decimal128, decimal256}) {
+ // 2**64 + 2**41 (exactly representable as a float)
+ CheckCast(ArrayFromJSON(float32(), "[1.8446746e+19, -1.8446746e+19]"),
+ ArrayFromJSON(decimal_type(20, 0),
+ R"(["18446746272732807168", "-18446746272732807168"])"));
+
+ CheckCast(
+ ArrayFromJSON(float64(), "[1.8446744073709556e+19, -1.8446744073709556e+19]"),
+ ArrayFromJSON(decimal_type(20, 0),
+ R"(["18446744073709555712", "-18446744073709555712"])"));
+
+ CheckCast(ArrayFromJSON(float32(), "[1.8446746e+15, -1.8446746e+15]"),
+ ArrayFromJSON(decimal_type(20, 4),
+ R"(["1844674627273280.7168", "-1844674627273280.7168"])"));
+
+ CheckCast(
+ ArrayFromJSON(float64(), "[1.8446744073709556e+15, -1.8446744073709556e+15]"),
+ ArrayFromJSON(decimal_type(20, 4),
+ R"(["1844674407370955.5712", "-1844674407370955.5712"])"));
+
+ // Edge cases are tested for Decimal128::FromReal() and Decimal256::FromReal
+ }
+}
+
+TEST(Cast, DecimalToFloating) {
+ for (auto float_type : {float32(), float64()}) {
+ for (auto decimal_type : {decimal(5, 2), decimal256(5, 2)}) {
+ CheckCast(ArrayFromJSON(decimal_type, R"(["0.00", null, "123.45", "999.99"])"),
+ ArrayFromJSON(float_type, "[0.0, null, 123.45, 999.99]"));
+ }
+ }
+
+ // Edge cases are tested for Decimal128::ToReal() and Decimal256::ToReal()
+}
+
+TEST(Cast, TimestampToTimestamp) {
+ struct TimestampTypePair {
+ std::shared_ptr<DataType> coarse, fine;
+ };
+
+ CastOptions options;
+
+ for (auto types : {
+ TimestampTypePair{timestamp(TimeUnit::SECOND), timestamp(TimeUnit::MILLI)},
+ TimestampTypePair{timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MICRO)},
+ TimestampTypePair{timestamp(TimeUnit::MICRO), timestamp(TimeUnit::NANO)},
+ }) {
+ auto coarse = ArrayFromJSON(types.coarse, "[0, null, 200, 1, 2]");
+ auto promoted = ArrayFromJSON(types.fine, "[0, null, 200000, 1000, 2000]");
+
+ // multiply/promote
+ CheckCast(coarse, promoted);
+
+ auto will_be_truncated = ArrayFromJSON(types.fine, "[0, null, 200456, 1123, 2456]");
+
+ // with truncation disallowed, fails
+ options.allow_time_truncate = false;
+ options.to_type = types.coarse;
+ CheckCastFails(will_be_truncated, options);
+
+ // with truncation allowed, divide/truncate
+ options.allow_time_truncate = true;
+ CheckCast(will_be_truncated, coarse, options);
+ }
+
+ for (auto types : {
+ TimestampTypePair{timestamp(TimeUnit::SECOND), timestamp(TimeUnit::MICRO)},
+ TimestampTypePair{timestamp(TimeUnit::MILLI), timestamp(TimeUnit::NANO)},
+ }) {
+ auto coarse = ArrayFromJSON(types.coarse, "[0, null, 200, 1, 2]");
+ auto promoted = ArrayFromJSON(types.fine, "[0, null, 200000000, 1000000, 2000000]");
+
+ // multiply/promote
+ CheckCast(coarse, promoted);
+
+ auto will_be_truncated =
+ ArrayFromJSON(types.fine, "[0, null, 200456000, 1123000, 2456000]");
+
+ // with truncation disallowed, fails
+ options.allow_time_truncate = false;
+ options.to_type = types.coarse;
+ CheckCastFails(will_be_truncated, options);
+
+ // with truncation allowed, divide/truncate
+ options.allow_time_truncate = true;
+ CheckCast(will_be_truncated, coarse, options);
+ }
+
+ for (auto types : {
+ TimestampTypePair{timestamp(TimeUnit::SECOND), timestamp(TimeUnit::NANO)},
+ }) {
+ auto coarse = ArrayFromJSON(types.coarse, "[0, null, 200, 1, 2]");
+ auto promoted =
+ ArrayFromJSON(types.fine, "[0, null, 200000000000, 1000000000, 2000000000]");
+
+ // multiply/promote
+ CheckCast(coarse, promoted);
+
+ auto will_be_truncated =
+ ArrayFromJSON(types.fine, "[0, null, 200456000000, 1123000000, 2456000000]");
+
+ // with truncation disallowed, fails
+ options.allow_time_truncate = false;
+ options.to_type = types.coarse;
+ CheckCastFails(will_be_truncated, options);
+
+ // with truncation allowed, divide/truncate
+ options.allow_time_truncate = true;
+ CheckCast(will_be_truncated, coarse, options);
+ }
+}
+
+TEST(Cast, TimestampZeroCopy) {
+ for (auto zero_copy_to_type : {
+ timestamp(TimeUnit::SECOND),
+ int64(), // ARROW-1773, cast to integer
+ }) {
+ CheckCastZeroCopy(
+ ArrayFromJSON(timestamp(TimeUnit::SECOND), "[0, null, 2000, 1000, 0]"),
+ zero_copy_to_type);
+ }
+ CheckCastZeroCopy(ArrayFromJSON(int64(), "[0, null, 2000, 1000, 0]"),
+ timestamp(TimeUnit::SECOND));
+}
+
+TEST(Cast, TimestampToTimestampMultiplyOverflow) {
+ CastOptions options;
+ options.to_type = timestamp(TimeUnit::NANO);
+ // 1000-01-01, 1800-01-01 , 2000-01-01, 2300-01-01, 3000-01-01
+ CheckCastFails(
+ ArrayFromJSON(timestamp(TimeUnit::SECOND),
+ "[-30610224000, -5364662400, 946684800, 10413792000, 32503680000]"),
+ options);
+}
+
+constexpr char kTimestampJson[] =
+ R"(["1970-01-01T00:00:59.123456789","2000-02-29T23:23:23.999999999",
+ "1899-01-01T00:59:20.001001001","2033-05-18T03:33:20.000000000",
+ "2020-01-01T01:05:05.001", "2019-12-31T02:10:10.002",
+ "2019-12-30T03:15:15.003", "2009-12-31T04:20:20.004132",
+ "2010-01-01T05:25:25.005321", "2010-01-03T06:30:30.006163",
+ "2010-01-04T07:35:35", "2006-01-01T08:40:40", "2005-12-31T09:45:45",
+ "2008-12-28", "2008-12-29", "2012-01-01 01:02:03", null])";
+constexpr char kTimestampSecondsJson[] =
+ R"(["1970-01-01T00:00:59","2000-02-29T23:23:23",
+ "1899-01-01T00:59:20","2033-05-18T03:33:20",
+ "2020-01-01T01:05:05", "2019-12-31T02:10:10",
+ "2019-12-30T03:15:15", "2009-12-31T04:20:20",
+ "2010-01-01T05:25:25", "2010-01-03T06:30:30",
+ "2010-01-04T07:35:35", "2006-01-01T08:40:40",
+ "2005-12-31T09:45:45", "2008-12-28", "2008-12-29",
+ "2012-01-01 01:02:03", null])";
+constexpr char kTimestampExtremeJson[] =
+ R"(["1677-09-20T00:00:59.123456", "2262-04-13T23:23:23.999999"])";
+
+TEST(Cast, TimestampToDate) {
+ // See scalar_temporal_test.cc
+ auto timestamps = ArrayFromJSON(timestamp(TimeUnit::NANO), kTimestampJson);
+ auto date_32 = ArrayFromJSON(date32(),
+ R"([
+ 0, 11016, -25932, 23148,
+ 18262, 18261, 18260, 14609,
+ 14610, 14612, 14613, 13149,
+ 13148, 14241, 14242, 15340, null
+ ])");
+ auto date_64 = ArrayFromJSON(date64(),
+ R"([
+ 0, 951782400000, -2240524800000, 1999987200000,
+ 1577836800000, 1577750400000, 1577664000000, 1262217600000,
+ 1262304000000, 1262476800000, 1262563200000, 1136073600000,
+ 1135987200000, 1230422400000, 1230508800000, 1325376000000, null
+ ])");
+ // See TestOutsideNanosecondRange in scalar_temporal_test.cc
+ auto timestamps_extreme =
+ ArrayFromJSON(timestamp(TimeUnit::MICRO),
+ R"(["1677-09-20T00:00:59.123456", "2262-04-13T23:23:23.999999"])");
+ auto date_32_extreme = ArrayFromJSON(date32(), "[-106753, 106753]");
+ auto date_64_extreme = ArrayFromJSON(date64(), "[-9223459200000, 9223459200000]");
+
+ CheckCast(timestamps, date_32);
+ CheckCast(timestamps, date_64);
+ CheckCast(timestamps_extreme, date_32_extreme);
+ CheckCast(timestamps_extreme, date_64_extreme);
+ for (auto u : TimeUnit::values()) {
+ auto unit = timestamp(u);
+ CheckCast(ArrayFromJSON(unit, kTimestampSecondsJson), date_32);
+ CheckCast(ArrayFromJSON(unit, kTimestampSecondsJson), date_64);
+ }
+}
+
+TEST(Cast, ZonedTimestampToDate) {
+#ifdef _WIN32
+ // TODO(ARROW-13168): we lack tzdb on Windows
+ GTEST_SKIP() << "ARROW-13168: no access to timezone database on Windows";
+#endif
+
+ {
+ // See TestZoned in scalar_temporal_test.cc
+ auto timestamps =
+ ArrayFromJSON(timestamp(TimeUnit::NANO, "Pacific/Marquesas"), kTimestampJson);
+ auto date_32 = ArrayFromJSON(date32(),
+ R"([
+ -1, 11016, -25933, 23147,
+ 18261, 18260, 18259, 14608,
+ 14609, 14611, 14612, 13148,
+ 13148, 14240, 14241, 15339, null
+ ])");
+ auto date_64 = ArrayFromJSON(date64(), R"([
+ -86400000, 951782400000, -2240611200000, 1999900800000,
+ 1577750400000, 1577664000000, 1577577600000, 1262131200000,
+ 1262217600000, 1262390400000, 1262476800000, 1135987200000,
+ 1135987200000, 1230336000000, 1230422400000, 1325289600000, null
+ ])");
+ CheckCast(timestamps, date_32);
+ CheckCast(timestamps, date_64);
+ }
+
+ auto date_32 = ArrayFromJSON(date32(), R"([
+ 0, 11017, -25932, 23148,
+ 18262, 18261, 18260, 14609,
+ 14610, 14612, 14613, 13149,
+ 13148, 14241, 14242, 15340, null
+ ])");
+ auto date_64 = ArrayFromJSON(date64(), R"([
+ 0, 951868800000, -2240524800000, 1999987200000, 1577836800000,
+ 1577750400000, 1577664000000, 1262217600000, 1262304000000,
+ 1262476800000, 1262563200000, 1136073600000, 1135987200000,
+ 1230422400000, 1230508800000, 1325376000000, null
+ ])");
+
+ for (auto u : TimeUnit::values()) {
+ auto timestamps =
+ ArrayFromJSON(timestamp(u, "Australia/Broken_Hill"), kTimestampSecondsJson);
+ CheckCast(timestamps, date_32);
+ CheckCast(timestamps, date_64);
+ }
+
+ // Invalid timezone
+ for (auto u : TimeUnit::values()) {
+ auto timestamps =
+ ArrayFromJSON(timestamp(u, "Mars/Mariner_Valley"), kTimestampSecondsJson);
+ CheckCastFails(timestamps, CastOptions::Unsafe(date32()));
+ CheckCastFails(timestamps, CastOptions::Unsafe(date64()));
+ }
+}
+
+TEST(Cast, TimestampToTime) {
+ // See scalar_temporal_test.cc
+ auto timestamps = ArrayFromJSON(timestamp(TimeUnit::NANO), kTimestampJson);
+ // See TestOutsideNanosecondRange in scalar_temporal_test.cc
+ auto timestamps_extreme =
+ ArrayFromJSON(timestamp(TimeUnit::MICRO), kTimestampExtremeJson);
+ auto timestamps_us = ArrayFromJSON(timestamp(TimeUnit::MICRO), R"([
+ "1970-01-01T00:00:59.123456","2000-02-29T23:23:23.999999",
+ "1899-01-01T00:59:20.001001","2033-05-18T03:33:20.000000",
+ "2020-01-01T01:05:05.001", "2019-12-31T02:10:10.002",
+ "2019-12-30T03:15:15.003", "2009-12-31T04:20:20.004132",
+ "2010-01-01T05:25:25.005321", "2010-01-03T06:30:30.006163",
+ "2010-01-04T07:35:35", "2006-01-01T08:40:40", "2005-12-31T09:45:45",
+ "2008-12-28", "2008-12-29", "2012-01-01 01:02:03", null])");
+ auto timestamps_ms = ArrayFromJSON(timestamp(TimeUnit::MILLI), R"([
+ "1970-01-01T00:00:59.123","2000-02-29T23:23:23.999",
+ "1899-01-01T00:59:20.001","2033-05-18T03:33:20.000",
+ "2020-01-01T01:05:05.001", "2019-12-31T02:10:10.002",
+ "2019-12-30T03:15:15.003", "2009-12-31T04:20:20.004",
+ "2010-01-01T05:25:25.005", "2010-01-03T06:30:30.006",
+ "2010-01-04T07:35:35", "2006-01-01T08:40:40", "2005-12-31T09:45:45",
+ "2008-12-28", "2008-12-29", "2012-01-01 01:02:03", null])");
+ auto timestamps_s = ArrayFromJSON(timestamp(TimeUnit::SECOND), kTimestampSecondsJson);
+
+ auto times = ArrayFromJSON(time64(TimeUnit::NANO), R"([
+ 59123456789, 84203999999999, 3560001001001, 12800000000000,
+ 3905001000000, 7810002000000, 11715003000000, 15620004132000,
+ 19525005321000, 23430006163000, 27335000000000, 31240000000000,
+ 35145000000000, 0, 0, 3723000000000, null
+ ])");
+ auto times_ns_us = ArrayFromJSON(time64(TimeUnit::MICRO), R"([
+ 59123456, 84203999999, 3560001001, 12800000000,
+ 3905001000, 7810002000, 11715003000, 15620004132,
+ 19525005321, 23430006163, 27335000000, 31240000000,
+ 35145000000, 0, 0, 3723000000, null
+ ])");
+ auto times_ns_ms = ArrayFromJSON(time32(TimeUnit::MILLI), R"([
+ 59123, 84203999, 3560001, 12800000,
+ 3905001, 7810002, 11715003, 15620004,
+ 19525005, 23430006, 27335000, 31240000,
+ 35145000, 0, 0, 3723000, null
+ ])");
+ auto times_us_ns = ArrayFromJSON(time64(TimeUnit::NANO), R"([
+ 59123456000, 84203999999000, 3560001001000, 12800000000000,
+ 3905001000000, 7810002000000, 11715003000000, 15620004132000,
+ 19525005321000, 23430006163000, 27335000000000, 31240000000000,
+ 35145000000000, 0, 0, 3723000000000, null
+ ])");
+ auto times_ms_ns = ArrayFromJSON(time64(TimeUnit::NANO), R"([
+ 59123000000, 84203999000000, 3560001000000, 12800000000000,
+ 3905001000000, 7810002000000, 11715003000000, 15620004000000,
+ 19525005000000, 23430006000000, 27335000000000, 31240000000000,
+ 35145000000000, 0, 0, 3723000000000, null
+ ])");
+ auto times_ms_us = ArrayFromJSON(time64(TimeUnit::MICRO), R"([
+ 59123000, 84203999000, 3560001000, 12800000000,
+ 3905001000, 7810002000, 11715003000, 15620004000,
+ 19525005000, 23430006000, 27335000000, 31240000000,
+ 35145000000, 0, 0, 3723000000, null
+ ])");
+
+ auto times_extreme = ArrayFromJSON(time64(TimeUnit::MICRO), "[59123456, 84203999999]");
+ auto times_s = ArrayFromJSON(time32(TimeUnit::SECOND), R"([
+ 59, 84203, 3560, 12800,
+ 3905, 7810, 11715, 15620,
+ 19525, 23430, 27335, 31240,
+ 35145, 0, 0, 3723, null
+ ])");
+ auto times_ms = ArrayFromJSON(time32(TimeUnit::MILLI), R"([
+ 59000, 84203000, 3560000, 12800000,
+ 3905000, 7810000, 11715000, 15620000,
+ 19525000, 23430000, 27335000, 31240000,
+ 35145000, 0, 0, 3723000, null
+ ])");
+ auto times_us = ArrayFromJSON(time64(TimeUnit::MICRO), R"([
+ 59000000, 84203000000, 3560000000, 12800000000,
+ 3905000000, 7810000000, 11715000000, 15620000000,
+ 19525000000, 23430000000, 27335000000, 31240000000,
+ 35145000000, 0, 0, 3723000000, null
+ ])");
+ auto times_ns = ArrayFromJSON(time64(TimeUnit::NANO), R"([
+ 59000000000, 84203000000000, 3560000000000, 12800000000000,
+ 3905000000000, 7810000000000, 11715000000000, 15620000000000,
+ 19525000000000, 23430000000000, 27335000000000, 31240000000000,
+ 35145000000000, 0, 0, 3723000000000, null
+ ])");
+
+ CheckCast(timestamps, times);
+ CheckCastFails(timestamps, CastOptions::Safe(time64(TimeUnit::MICRO)));
+ CheckCast(timestamps_extreme, times_extreme);
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::SECOND), kTimestampSecondsJson), times_s);
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::SECOND), kTimestampSecondsJson), times_ms);
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::MILLI), kTimestampSecondsJson), times_s);
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::MILLI), kTimestampSecondsJson), times_ms);
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::MICRO), kTimestampSecondsJson), times_us);
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::MICRO), kTimestampSecondsJson), times_ns);
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::MICRO), kTimestampSecondsJson), times_ms);
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::MICRO), kTimestampSecondsJson), times_s);
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::NANO), kTimestampSecondsJson), times_ns);
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::NANO), kTimestampSecondsJson), times_us);
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::NANO), kTimestampSecondsJson), times_ms);
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::NANO), kTimestampSecondsJson), times_s);
+
+ CastOptions truncate = CastOptions::Safe();
+ truncate.allow_time_truncate = true;
+
+ // Truncation tests
+ CheckCastFails(timestamps, CastOptions::Safe(time64(TimeUnit::MICRO)));
+ CheckCastFails(timestamps, CastOptions::Safe(time32(TimeUnit::MILLI)));
+ CheckCastFails(timestamps, CastOptions::Safe(time32(TimeUnit::SECOND)));
+ CheckCastFails(timestamps_us, CastOptions::Safe(time32(TimeUnit::MILLI)));
+ CheckCastFails(timestamps_us, CastOptions::Safe(time32(TimeUnit::SECOND)));
+ CheckCastFails(timestamps_ms, CastOptions::Safe(time32(TimeUnit::SECOND)));
+ CheckCast(timestamps, times_ns_us, truncate);
+ CheckCast(timestamps, times_ns_ms, truncate);
+ CheckCast(timestamps, times_s, truncate);
+ CheckCast(timestamps_us, times_ns_ms, truncate);
+ CheckCast(timestamps_us, times_s, truncate);
+ CheckCast(timestamps_ms, times_s, truncate);
+
+ // Upscaling tests
+ CheckCast(timestamps_us, times_us_ns);
+ CheckCast(timestamps_ms, times_ms_ns);
+ CheckCast(timestamps_ms, times_ms_us);
+ CheckCast(timestamps_s, times_ns);
+ CheckCast(timestamps_s, times_us);
+ CheckCast(timestamps_s, times_ms);
+
+ // Invalid timezone
+ for (auto u : TimeUnit::values()) {
+ auto timestamps =
+ ArrayFromJSON(timestamp(u, "Mars/Mariner_Valley"), kTimestampSecondsJson);
+ if (u == TimeUnit::SECOND || u == TimeUnit::MILLI) {
+ CheckCastFails(timestamps, CastOptions::Unsafe(time32(u)));
+ } else {
+ CheckCastFails(timestamps, CastOptions::Unsafe(time64(u)));
+ }
+ }
+}
+
+TEST(Cast, ZonedTimestampToTime) {
+#ifdef _WIN32
+ // TODO(ARROW-13168): we lack tzdb on Windows
+ GTEST_SKIP() << "ARROW-13168: no access to timezone database on Windows";
+#endif
+
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::NANO, "Pacific/Marquesas"), kTimestampJson),
+ ArrayFromJSON(time64(TimeUnit::NANO), R"([
+ 52259123456789, 50003999999999, 56480001001001, 65000000000000,
+ 56105001000000, 60010002000000, 63915003000000, 67820004132000,
+ 71725005321000, 75630006163000, 79535000000000, 83440000000000,
+ 945000000000, 52200000000000, 52200000000000, 55923000000000, null
+ ])"));
+
+ auto time_s = R"([
+ 34259, 35603, 35960, 47000,
+ 41705, 45610, 49515, 53420,
+ 57325, 61230, 65135, 69040,
+ 72945, 37800, 37800, 41523, null
+ ])";
+ auto time_ms = R"([
+ 34259000, 35603000, 35960000, 47000000,
+ 41705000, 45610000, 49515000, 53420000,
+ 57325000, 61230000, 65135000, 69040000,
+ 72945000, 37800000, 37800000, 41523000, null
+ ])";
+ auto time_us = R"([
+ 34259000000, 35603000000, 35960000000, 47000000000,
+ 41705000000, 45610000000, 49515000000, 53420000000,
+ 57325000000, 61230000000, 65135000000, 69040000000,
+ 72945000000, 37800000000, 37800000000, 41523000000, null
+ ])";
+ auto time_ns = R"([
+ 34259000000000, 35603000000000, 35960000000000, 47000000000000,
+ 41705000000000, 45610000000000, 49515000000000, 53420000000000,
+ 57325000000000, 61230000000000, 65135000000000, 69040000000000,
+ 72945000000000, 37800000000000, 37800000000000, 41523000000000, null
+ ])";
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::SECOND, "Australia/Broken_Hill"),
+ kTimestampSecondsJson),
+ ArrayFromJSON(time32(TimeUnit::SECOND), time_s));
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::MILLI, "Australia/Broken_Hill"),
+ kTimestampSecondsJson),
+ ArrayFromJSON(time32(TimeUnit::MILLI), time_ms));
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::MICRO, "Australia/Broken_Hill"),
+ kTimestampSecondsJson),
+ ArrayFromJSON(time64(TimeUnit::MICRO), time_us));
+ CheckCast(ArrayFromJSON(timestamp(TimeUnit::NANO, "Australia/Broken_Hill"),
+ kTimestampSecondsJson),
+ ArrayFromJSON(time64(TimeUnit::NANO), time_ns));
+}
+
+TEST(Cast, TimeToTime) {
+ struct TimeTypePair {
+ std::shared_ptr<DataType> coarse, fine;
+ };
+
+ CastOptions options;
+
+ for (auto types : {
+ TimeTypePair{time32(TimeUnit::SECOND), time32(TimeUnit::MILLI)},
+ TimeTypePair{time32(TimeUnit::MILLI), time64(TimeUnit::MICRO)},
+ TimeTypePair{time64(TimeUnit::MICRO), time64(TimeUnit::NANO)},
+ }) {
+ auto coarse = ArrayFromJSON(types.coarse, "[0, null, 200, 1, 2]");
+ auto promoted = ArrayFromJSON(types.fine, "[0, null, 200000, 1000, 2000]");
+
+ // multiply/promote
+ CheckCast(coarse, promoted);
+
+ auto will_be_truncated = ArrayFromJSON(types.fine, "[0, null, 200456, 1123, 2456]");
+
+ // with truncation disallowed, fails
+ options.allow_time_truncate = false;
+ options.to_type = types.coarse;
+ CheckCastFails(will_be_truncated, options);
+
+ // with truncation allowed, divide/truncate
+ options.allow_time_truncate = true;
+ CheckCast(will_be_truncated, coarse, options);
+ }
+
+ for (auto types : {
+ TimeTypePair{time32(TimeUnit::SECOND), time64(TimeUnit::MICRO)},
+ TimeTypePair{time32(TimeUnit::MILLI), time64(TimeUnit::NANO)},
+ }) {
+ auto coarse = ArrayFromJSON(types.coarse, "[0, null, 200, 1, 2]");
+ auto promoted = ArrayFromJSON(types.fine, "[0, null, 200000000, 1000000, 2000000]");
+
+ // multiply/promote
+ CheckCast(coarse, promoted);
+
+ auto will_be_truncated =
+ ArrayFromJSON(types.fine, "[0, null, 200456000, 1123000, 2456000]");
+
+ // with truncation disallowed, fails
+ options.allow_time_truncate = false;
+ options.to_type = types.coarse;
+ CheckCastFails(will_be_truncated, options);
+
+ // with truncation allowed, divide/truncate
+ options.allow_time_truncate = true;
+ CheckCast(will_be_truncated, coarse, options);
+ }
+
+ for (auto types : {
+ TimeTypePair{time32(TimeUnit::SECOND), time64(TimeUnit::NANO)},
+ }) {
+ auto coarse = ArrayFromJSON(types.coarse, "[0, null, 200, 1, 2]");
+ auto promoted =
+ ArrayFromJSON(types.fine, "[0, null, 200000000000, 1000000000, 2000000000]");
+
+ // multiply/promote
+ CheckCast(coarse, promoted);
+
+ auto will_be_truncated =
+ ArrayFromJSON(types.fine, "[0, null, 200456000000, 1123000000, 2456000000]");
+
+ // with truncation disallowed, fails
+ options.allow_time_truncate = false;
+ options.to_type = types.coarse;
+ CheckCastFails(will_be_truncated, options);
+
+ // with truncation allowed, divide/truncate
+ options.allow_time_truncate = true;
+ CheckCast(will_be_truncated, coarse, options);
+ }
+}
+
+TEST(Cast, TimeZeroCopy) {
+ for (auto zero_copy_to_type : {
+ time32(TimeUnit::SECOND),
+ int32(), // ARROW-1773: cast to int32
+ }) {
+ CheckCastZeroCopy(ArrayFromJSON(time32(TimeUnit::SECOND), "[0, null, 2000, 1000, 0]"),
+ zero_copy_to_type);
+ }
+ CheckCastZeroCopy(ArrayFromJSON(int32(), "[0, null, 2000, 1000, 0]"),
+ time32(TimeUnit::SECOND));
+
+ for (auto zero_copy_to_type : {
+ time64(TimeUnit::MICRO),
+ int64(), // ARROW-1773: cast to int64
+ }) {
+ CheckCastZeroCopy(ArrayFromJSON(time64(TimeUnit::MICRO), "[0, null, 2000, 1000, 0]"),
+ zero_copy_to_type);
+ }
+ CheckCastZeroCopy(ArrayFromJSON(int64(), "[0, null, 2000, 1000, 0]"),
+ time64(TimeUnit::MICRO));
+}
+
+TEST(Cast, DateToString) {
+ for (auto string_type : {utf8(), large_utf8()}) {
+ CheckCast(ArrayFromJSON(date32(), "[0, null]"),
+ ArrayFromJSON(string_type, R"(["1970-01-01", null])"));
+ CheckCast(ArrayFromJSON(date64(), "[86400000, null]"),
+ ArrayFromJSON(string_type, R"(["1970-01-02", null])"));
+ }
+}
+
+TEST(Cast, TimeToString) {
+ for (auto string_type : {utf8(), large_utf8()}) {
+ CheckCast(ArrayFromJSON(time32(TimeUnit::SECOND), "[1, 62]"),
+ ArrayFromJSON(string_type, R"(["00:00:01", "00:01:02"])"));
+ CheckCast(
+ ArrayFromJSON(time64(TimeUnit::NANO), "[0, 1]"),
+ ArrayFromJSON(string_type, R"(["00:00:00.000000000", "00:00:00.000000001"])"));
+ }
+}
+
+TEST(Cast, TimestampToString) {
+ for (auto string_type : {utf8(), large_utf8()}) {
+ CheckCast(
+ ArrayFromJSON(timestamp(TimeUnit::SECOND), "[-30610224000, -5364662400]"),
+ ArrayFromJSON(string_type, R"(["1000-01-01 00:00:00", "1800-01-01 00:00:00"])"));
+ }
+}
+
+TEST(Cast, DateToDate) {
+ auto day_32 = ArrayFromJSON(date32(), "[0, null, 100, 1, 10]");
+ auto day_64 = ArrayFromJSON(date64(), R"([
+ 0,
+ null,
+ 8640000000,
+ 86400000,
+ 864000000])");
+
+ // Multiply promotion
+ CheckCast(day_32, day_64);
+
+ // No truncation
+ CheckCast(day_64, day_32);
+
+ auto day_64_will_be_truncated = ArrayFromJSON(date64(), R"([
+ 0,
+ null,
+ 8640000123,
+ 86400456,
+ 864000789])");
+
+ // Disallow truncate
+ CastOptions options;
+ options.to_type = date32();
+ CheckCastFails(day_64_will_be_truncated, options);
+
+ // Divide, truncate
+ options.allow_time_truncate = true;
+ CheckCast(day_64_will_be_truncated, day_32, options);
+}
+
+TEST(Cast, DateZeroCopy) {
+ for (auto zero_copy_to_type : {
+ date32(),
+ int32(), // ARROW-1773: cast to int32
+ }) {
+ CheckCastZeroCopy(ArrayFromJSON(date32(), "[0, null, 2000, 1000, 0]"),
+ zero_copy_to_type);
+ }
+ CheckCastZeroCopy(ArrayFromJSON(int32(), "[0, null, 2000, 1000, 0]"), date32());
+
+ for (auto zero_copy_to_type : {
+ date64(),
+ int64(), // ARROW-1773: cast to int64
+ }) {
+ CheckCastZeroCopy(ArrayFromJSON(date64(), "[0, null, 2000, 1000, 0]"),
+ zero_copy_to_type);
+ }
+ CheckCastZeroCopy(ArrayFromJSON(int64(), "[0, null, 2000, 1000, 0]"), date64());
+}
+
+TEST(Cast, DurationToDuration) {
+ struct DurationTypePair {
+ std::shared_ptr<DataType> coarse, fine;
+ };
+
+ CastOptions options;
+
+ for (auto types : {
+ DurationTypePair{duration(TimeUnit::SECOND), duration(TimeUnit::MILLI)},
+ DurationTypePair{duration(TimeUnit::MILLI), duration(TimeUnit::MICRO)},
+ DurationTypePair{duration(TimeUnit::MICRO), duration(TimeUnit::NANO)},
+ }) {
+ auto coarse = ArrayFromJSON(types.coarse, "[0, null, 200, 1, 2]");
+ auto promoted = ArrayFromJSON(types.fine, "[0, null, 200000, 1000, 2000]");
+
+ // multiply/promote
+ CheckCast(coarse, promoted);
+
+ auto will_be_truncated = ArrayFromJSON(types.fine, "[0, null, 200456, 1123, 2456]");
+
+ // with truncation disallowed, fails
+ options.allow_time_truncate = false;
+ options.to_type = types.coarse;
+ CheckCastFails(will_be_truncated, options);
+
+ // with truncation allowed, divide/truncate
+ options.allow_time_truncate = true;
+ CheckCast(will_be_truncated, coarse, options);
+ }
+
+ for (auto types : {
+ DurationTypePair{duration(TimeUnit::SECOND), duration(TimeUnit::MICRO)},
+ DurationTypePair{duration(TimeUnit::MILLI), duration(TimeUnit::NANO)},
+ }) {
+ auto coarse = ArrayFromJSON(types.coarse, "[0, null, 200, 1, 2]");
+ auto promoted = ArrayFromJSON(types.fine, "[0, null, 200000000, 1000000, 2000000]");
+
+ // multiply/promote
+ CheckCast(coarse, promoted);
+
+ auto will_be_truncated =
+ ArrayFromJSON(types.fine, "[0, null, 200000456, 1000123, 2000456]");
+
+ // with truncation disallowed, fails
+ options.allow_time_truncate = false;
+ options.to_type = types.coarse;
+ CheckCastFails(will_be_truncated, options);
+
+ // with truncation allowed, divide/truncate
+ options.allow_time_truncate = true;
+ CheckCast(will_be_truncated, coarse, options);
+ }
+
+ for (auto types : {
+ DurationTypePair{duration(TimeUnit::SECOND), duration(TimeUnit::NANO)},
+ }) {
+ auto coarse = ArrayFromJSON(types.coarse, "[0, null, 200, 1, 2]");
+ auto promoted =
+ ArrayFromJSON(types.fine, "[0, null, 200000000000, 1000000000, 2000000000]");
+
+ // multiply/promote
+ CheckCast(coarse, promoted);
+
+ auto will_be_truncated =
+ ArrayFromJSON(types.fine, "[0, null, 200000000456, 1000000123, 2000000456]");
+
+ // with truncation disallowed, fails
+ options.allow_time_truncate = false;
+ options.to_type = types.coarse;
+ CheckCastFails(will_be_truncated, options);
+
+ // with truncation allowed, divide/truncate
+ options.allow_time_truncate = true;
+ CheckCast(will_be_truncated, coarse, options);
+ }
+}
+
+TEST(Cast, DurationZeroCopy) {
+ for (auto zero_copy_to_type : {
+ duration(TimeUnit::SECOND),
+ int64(), // ARROW-1773: cast to int64
+ }) {
+ CheckCastZeroCopy(
+ ArrayFromJSON(duration(TimeUnit::SECOND), "[0, null, 2000, 1000, 0]"),
+ zero_copy_to_type);
+ }
+ CheckCastZeroCopy(ArrayFromJSON(int64(), "[0, null, 2000, 1000, 0]"),
+ duration(TimeUnit::SECOND));
+}
+
+TEST(Cast, DurationToDurationMultiplyOverflow) {
+ CastOptions options;
+ options.to_type = duration(TimeUnit::NANO);
+ CheckCastFails(
+ ArrayFromJSON(duration(TimeUnit::SECOND), "[10000000000, 1, 2, 3, 10000000000]"),
+ options);
+}
+
+TEST(Cast, MiscToFloating) {
+ for (auto to_type : {float32(), float64()}) {
+ CheckCast(ArrayFromJSON(int16(), "[0, null, 200, 1, 2]"),
+ ArrayFromJSON(to_type, "[0, null, 200, 1, 2]"));
+
+ CheckCast(ArrayFromJSON(float32(), "[0, null, 200, 1, 2]"),
+ ArrayFromJSON(to_type, "[0, null, 200, 1, 2]"));
+
+ CheckCast(ArrayFromJSON(boolean(), "[true, null, false, false, true]"),
+ ArrayFromJSON(to_type, "[1, null, 0, 0, 1]"));
+ }
+}
+
+TEST(Cast, UnsupportedInputType) {
+ // Casting to a supported target type, but with an unsupported input type
+ // for the target type.
+ const auto arr = ArrayFromJSON(int32(), "[1, 2, 3]");
+
+ const auto to_type = list(utf8());
+ const char* expected_message = "Unsupported cast from int32 to list";
+
+ // Try through concrete API
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, ::testing::HasSubstr(expected_message),
+ Cast(*arr, to_type));
+
+ // Try through general kernel API
+ CastOptions options;
+ options.to_type = to_type;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, ::testing::HasSubstr(expected_message),
+ CallFunction("cast", {arr}, &options));
+}
+
+TEST(Cast, UnsupportedTargetType) {
+ // Casting to an unsupported target type
+ const auto arr = ArrayFromJSON(int32(), "[1, 2, 3]");
+ const auto to_type = dense_union({field("a", int32())});
+
+ // Try through concrete API
+ const char* expected_message = "Unsupported cast from int32 to dense_union";
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, ::testing::HasSubstr(expected_message),
+ Cast(*arr, to_type));
+
+ // Try through general kernel API
+ CastOptions options;
+ options.to_type = to_type;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, ::testing::HasSubstr(expected_message),
+ CallFunction("cast", {arr}, &options));
+}
+
+TEST(Cast, StringToBoolean) {
+ for (auto string_type : {utf8(), large_utf8()}) {
+ CheckCast(ArrayFromJSON(string_type, R"(["False", null, "true", "True", "false"])"),
+ ArrayFromJSON(boolean(), "[false, null, true, true, false]"));
+
+ CheckCast(ArrayFromJSON(string_type, R"(["0", null, "1", "1", "0"])"),
+ ArrayFromJSON(boolean(), "[false, null, true, true, false]"));
+
+ auto options = CastOptions::Safe(boolean());
+ CheckCastFails(ArrayFromJSON(string_type, R"(["false "])"), options);
+ CheckCastFails(ArrayFromJSON(string_type, R"(["T"])"), options);
+ }
+}
+
+TEST(Cast, StringToInt) {
+ for (auto string_type : {utf8(), large_utf8()}) {
+ for (auto signed_type : {int8(), int16(), int32(), int64()}) {
+ CheckCast(
+ ArrayFromJSON(string_type, R"(["0", null, "127", "-1", "0", "0x0", "0x7F"])"),
+ ArrayFromJSON(signed_type, "[0, null, 127, -1, 0, 0, 127]"));
+ }
+
+ CheckCast(ArrayFromJSON(string_type, R"(["2147483647", null, "-2147483648", "0",
+ "0X0", "0x7FFFFFFF", "0XFFFFfFfF", "0Xf0000000"])"),
+ ArrayFromJSON(
+ int32(),
+ "[2147483647, null, -2147483648, 0, 0, 2147483647, -1, -268435456]"));
+
+ CheckCast(ArrayFromJSON(string_type,
+ R"(["9223372036854775807", null, "-9223372036854775808", "0",
+ "0x0", "0x7FFFFFFFFFFFFFFf", "0XF000000000000001"])"),
+ ArrayFromJSON(int64(),
+ "[9223372036854775807, null, -9223372036854775808, 0, 0, "
+ "9223372036854775807, -1152921504606846975]"));
+
+ for (auto unsigned_type : {uint8(), uint16(), uint32(), uint64()}) {
+ CheckCast(ArrayFromJSON(string_type,
+ R"(["0", null, "127", "255", "0", "0X0", "0xff", "0x7f"])"),
+ ArrayFromJSON(unsigned_type, "[0, null, 127, 255, 0, 0, 255, 127]"));
+ }
+
+ CheckCast(
+ ArrayFromJSON(string_type, R"(["2147483647", null, "4294967295", "0",
+ "0x0", "0x7FFFFFFf", "0xFFFFFFFF"])"),
+ ArrayFromJSON(uint32(),
+ "[2147483647, null, 4294967295, 0, 0, 2147483647, 4294967295]"));
+
+ CheckCast(ArrayFromJSON(string_type,
+ R"(["9223372036854775807", null, "18446744073709551615", "0",
+ "0x0", "0x7FFFFFFFFFFFFFFf", "0xfFFFFFFFFFFFFFFf"])"),
+ ArrayFromJSON(uint64(),
+ "[9223372036854775807, null, 18446744073709551615, 0, 0, "
+ "9223372036854775807, 18446744073709551615]"));
+
+ for (std::string not_int8 : {
+ "z",
+ "12 z",
+ "128",
+ "-129",
+ "0.5",
+ "0x",
+ "0xfff",
+ "-0xf0",
+ }) {
+ auto options = CastOptions::Safe(int8());
+ CheckCastFails(ArrayFromJSON(string_type, "[\"" + not_int8 + "\"]"), options);
+ }
+
+ for (std::string not_uint8 : {"256", "-1", "0.5", "0x", "0x3wa", "0x123"}) {
+ auto options = CastOptions::Safe(uint8());
+ CheckCastFails(ArrayFromJSON(string_type, "[\"" + not_uint8 + "\"]"), options);
+ }
+ }
+}
+
+TEST(Cast, StringToFloating) {
+ for (auto string_type : {utf8(), large_utf8()}) {
+ for (auto float_type : {float32(), float64()}) {
+ auto strings =
+ ArrayFromJSON(string_type, R"(["0.1", null, "127.3", "1e3", "200.4", "0.5"])");
+ auto floats = ArrayFromJSON(float_type, "[0.1, null, 127.3, 1000, 200.4, 0.5]");
+ CheckCast(strings, floats);
+
+ for (std::string not_float : {
+ "z",
+ }) {
+ auto options = CastOptions::Safe(float32());
+ CheckCastFails(ArrayFromJSON(string_type, "[\"" + not_float + "\"]"), options);
+ }
+
+#if !defined(_WIN32) || defined(NDEBUG)
+ // Test that casting is locale-independent
+ // French locale uses the comma as decimal point
+ LocaleGuard locale_guard("fr_FR.UTF-8");
+ CheckCast(strings, floats);
+#endif
+ }
+ }
+}
+
+TEST(Cast, StringToTimestamp) {
+ for (auto string_type : {utf8(), large_utf8()}) {
+ auto strings = ArrayFromJSON(string_type, R"(["1970-01-01", null, "2000-02-29"])");
+
+ CheckCast(strings,
+ ArrayFromJSON(timestamp(TimeUnit::SECOND), "[0, null, 951782400]"));
+
+ CheckCast(strings,
+ ArrayFromJSON(timestamp(TimeUnit::MICRO), "[0, null, 951782400000000]"));
+
+ for (auto unit :
+ {TimeUnit::SECOND, TimeUnit::MILLI, TimeUnit::MICRO, TimeUnit::NANO}) {
+ for (std::string not_ts : {
+ "",
+ "xxx",
+ }) {
+ auto options = CastOptions::Safe(timestamp(unit));
+ CheckCastFails(ArrayFromJSON(string_type, "[\"" + not_ts + "\"]"), options);
+ }
+ }
+
+ // NOTE: timestamp parsing is tested comprehensively in parsing-util-test.cc
+ }
+}
+
+static void AssertBinaryZeroCopy(std::shared_ptr<Array> lhs, std::shared_ptr<Array> rhs) {
+ // null bitmap and data buffers are always zero-copied
+ AssertBufferSame(*lhs, *rhs, 0);
+ AssertBufferSame(*lhs, *rhs, 2);
+
+ if (offset_bit_width(lhs->type_id()) == offset_bit_width(rhs->type_id())) {
+ // offset buffer is zero copied if possible
+ AssertBufferSame(*lhs, *rhs, 1);
+ return;
+ }
+
+ // offset buffers are equivalent
+ ArrayVector offsets;
+ for (auto array : {lhs, rhs}) {
+ auto length = array->length();
+ auto buffer = array->data()->buffers[1];
+ offsets.push_back(offset_bit_width(array->type_id()) == 32
+ ? *Cast(Int32Array(length, buffer), int64())
+ : std::make_shared<Int64Array>(length, buffer));
+ }
+ AssertArraysEqual(*offsets[0], *offsets[1]);
+}
+
+TEST(Cast, BinaryToString) {
+ for (auto bin_type : {binary(), large_binary()}) {
+ for (auto string_type : {utf8(), large_utf8()}) {
+ // empty -> empty always works
+ CheckCast(ArrayFromJSON(bin_type, "[]"), ArrayFromJSON(string_type, "[]"));
+
+ auto invalid_utf8 = InvalidUtf8(bin_type);
+
+ // invalid utf-8 masked by a null bit is not an error
+ CheckCast(MaskArrayWithNullsAt(InvalidUtf8(bin_type), {4}),
+ MaskArrayWithNullsAt(InvalidUtf8(string_type), {4}));
+
+ // error: invalid utf-8
+ auto options = CastOptions::Safe(string_type);
+ CheckCastFails(invalid_utf8, options);
+
+ // override utf-8 check
+ options.allow_invalid_utf8 = true;
+ ASSERT_OK_AND_ASSIGN(auto strings, Cast(*invalid_utf8, string_type, options));
+ ASSERT_RAISES(Invalid, strings->ValidateFull());
+ AssertBinaryZeroCopy(invalid_utf8, strings);
+ }
+ }
+
+ auto from_type = fixed_size_binary(3);
+ auto invalid_utf8 = FixedSizeInvalidUtf8(from_type);
+ for (auto string_type : {utf8(), large_utf8()}) {
+ CheckCast(ArrayFromJSON(from_type, "[]"), ArrayFromJSON(string_type, "[]"));
+
+ // invalid utf-8 masked by a null bit is not an error
+ CheckCast(MaskArrayWithNullsAt(invalid_utf8, {4}),
+ MaskArrayWithNullsAt(FixedSizeInvalidUtf8(string_type), {4}));
+
+ // error: invalid utf-8
+ auto options = CastOptions::Safe(string_type);
+ CheckCastFails(invalid_utf8, options);
+
+ // override utf-8 check
+ options.allow_invalid_utf8 = true;
+ ASSERT_OK_AND_ASSIGN(auto strings, Cast(*invalid_utf8, string_type, options));
+ ASSERT_RAISES(Invalid, strings->ValidateFull());
+
+ // N.B. null buffer is not always the same if input sliced
+ AssertBufferSame(*invalid_utf8, *strings, 0);
+ ASSERT_EQ(invalid_utf8->data()->buffers[1].get(), strings->data()->buffers[2].get());
+ }
+}
+
+TEST(Cast, BinaryOrStringToBinary) {
+ for (auto from_type : {utf8(), large_utf8(), binary(), large_binary()}) {
+ for (auto to_type : {binary(), large_binary()}) {
+ // empty -> empty always works
+ CheckCast(ArrayFromJSON(from_type, "[]"), ArrayFromJSON(to_type, "[]"));
+
+ auto invalid_utf8 = InvalidUtf8(from_type);
+
+ // invalid utf-8 is not an error for binary
+ ASSERT_OK_AND_ASSIGN(auto strings, Cast(*invalid_utf8, to_type));
+ ValidateOutput(*strings);
+ AssertBinaryZeroCopy(invalid_utf8, strings);
+
+ // invalid utf-8 masked by a null bit is not an error
+ CheckCast(MaskArrayWithNullsAt(InvalidUtf8(from_type), {4}),
+ MaskArrayWithNullsAt(InvalidUtf8(to_type), {4}));
+ }
+ }
+
+ auto from_type = fixed_size_binary(3);
+ auto invalid_utf8 = FixedSizeInvalidUtf8(from_type);
+ CheckCast(invalid_utf8, invalid_utf8);
+ CheckCastFails(invalid_utf8, CastOptions::Safe(fixed_size_binary(5)));
+ for (auto to_type : {binary(), large_binary()}) {
+ CheckCast(ArrayFromJSON(from_type, "[]"), ArrayFromJSON(to_type, "[]"));
+ ASSERT_OK_AND_ASSIGN(auto strings, Cast(*invalid_utf8, to_type));
+ ValidateOutput(*strings);
+
+ // N.B. null buffer is not always the same if input sliced
+ AssertBufferSame(*invalid_utf8, *strings, 0);
+ ASSERT_EQ(invalid_utf8->data()->buffers[1].get(), strings->data()->buffers[2].get());
+
+ // invalid utf-8 masked by a null bit is not an error
+ CheckCast(MaskArrayWithNullsAt(invalid_utf8, {4}),
+ MaskArrayWithNullsAt(FixedSizeInvalidUtf8(to_type), {4}));
+ }
+}
+
+TEST(Cast, StringToString) {
+ for (auto from_type : {utf8(), large_utf8()}) {
+ for (auto to_type : {utf8(), large_utf8()}) {
+ // empty -> empty always works
+ CheckCast(ArrayFromJSON(from_type, "[]"), ArrayFromJSON(to_type, "[]"));
+
+ auto invalid_utf8 = InvalidUtf8(from_type);
+
+ // invalid utf-8 masked by a null bit is not an error
+ CheckCast(MaskArrayWithNullsAt(invalid_utf8, {4}),
+ MaskArrayWithNullsAt(InvalidUtf8(to_type), {4}));
+
+ // override utf-8 check
+ auto options = CastOptions::Safe(to_type);
+ options.allow_invalid_utf8 = true;
+ // utf-8 is not checked by Cast when the origin guarantees utf-8
+ ASSERT_OK_AND_ASSIGN(auto strings, Cast(*invalid_utf8, to_type, options));
+ ASSERT_RAISES(Invalid, strings->ValidateFull());
+ AssertBinaryZeroCopy(invalid_utf8, strings);
+ }
+ }
+}
+
+TEST(Cast, IntToString) {
+ for (auto string_type : {utf8(), large_utf8()}) {
+ CheckCast(ArrayFromJSON(int8(), "[0, 1, 127, -128, null]"),
+ ArrayFromJSON(string_type, R"(["0", "1", "127", "-128", null])"));
+
+ CheckCast(ArrayFromJSON(uint8(), "[0, 1, 255, null]"),
+ ArrayFromJSON(string_type, R"(["0", "1", "255", null])"));
+
+ CheckCast(ArrayFromJSON(int16(), "[0, 1, 32767, -32768, null]"),
+ ArrayFromJSON(string_type, R"(["0", "1", "32767", "-32768", null])"));
+
+ CheckCast(ArrayFromJSON(uint16(), "[0, 1, 65535, null]"),
+ ArrayFromJSON(string_type, R"(["0", "1", "65535", null])"));
+
+ CheckCast(
+ ArrayFromJSON(int32(), "[0, 1, 2147483647, -2147483648, null]"),
+ ArrayFromJSON(string_type, R"(["0", "1", "2147483647", "-2147483648", null])"));
+
+ CheckCast(ArrayFromJSON(uint32(), "[0, 1, 4294967295, null]"),
+ ArrayFromJSON(string_type, R"(["0", "1", "4294967295", null])"));
+
+ CheckCast(
+ ArrayFromJSON(int64(), "[0, 1, 9223372036854775807, -9223372036854775808, null]"),
+ ArrayFromJSON(
+ string_type,
+ R"(["0", "1", "9223372036854775807", "-9223372036854775808", null])"));
+
+ CheckCast(ArrayFromJSON(uint64(), "[0, 1, 18446744073709551615, null]"),
+ ArrayFromJSON(string_type, R"(["0", "1", "18446744073709551615", null])"));
+ }
+}
+
+TEST(Cast, FloatingToString) {
+ for (auto string_type : {utf8(), large_utf8()}) {
+ CheckCast(
+ ArrayFromJSON(float32(), "[0.0, -0.0, 1.5, -Inf, Inf, NaN, null]"),
+ ArrayFromJSON(string_type, R"(["0", "-0", "1.5", "-inf", "inf", "nan", null])"));
+
+ CheckCast(
+ ArrayFromJSON(float64(), "[0.0, -0.0, 1.5, -Inf, Inf, NaN, null]"),
+ ArrayFromJSON(string_type, R"(["0", "-0", "1.5", "-inf", "inf", "nan", null])"));
+ }
+}
+
+TEST(Cast, BooleanToString) {
+ for (auto string_type : {utf8(), large_utf8()}) {
+ CheckCast(ArrayFromJSON(boolean(), "[true, true, false, null]"),
+ ArrayFromJSON(string_type, R"(["true", "true", "false", null])"));
+ }
+}
+
+TEST(Cast, ListToPrimitive) {
+ ASSERT_RAISES(NotImplemented,
+ Cast(*ArrayFromJSON(list(int8()), "[[1, 2], [3, 4]]"), uint8()));
+
+ ASSERT_RAISES(
+ NotImplemented,
+ Cast(*ArrayFromJSON(list(binary()), R"([["1", "2"], ["3", "4"]])"), utf8()));
+}
+
+using make_list_t = std::shared_ptr<DataType>(const std::shared_ptr<DataType>&);
+
+static const auto list_factories = std::vector<make_list_t*>{&list, &large_list};
+
+static void CheckListToList(const std::vector<std::shared_ptr<DataType>>& value_types,
+ const std::string& json_data) {
+ for (auto make_src_list : list_factories) {
+ for (auto make_dest_list : list_factories) {
+ for (const auto& src_value_type : value_types) {
+ for (const auto& dest_value_type : value_types) {
+ const auto src_type = make_src_list(src_value_type);
+ const auto dest_type = make_dest_list(dest_value_type);
+ ARROW_SCOPED_TRACE("src_type = ", src_type->ToString(),
+ ", dest_type = ", dest_type->ToString());
+ CheckCast(ArrayFromJSON(src_type, json_data),
+ ArrayFromJSON(dest_type, json_data));
+ }
+ }
+ }
+ }
+}
+
+TEST(Cast, ListToList) {
+ CheckListToList({int32(), float32(), int64()},
+ "[[0], [1], null, [2, 3, 4], [5, 6], null, [], [7], [8, 9]]");
+}
+
+TEST(Cast, ListToListNoNulls) {
+ // ARROW-12568
+ CheckListToList({int32(), float32(), int64()},
+ "[[0], [1], [2, 3, 4], [5, 6], [], [7], [8, 9]]");
+}
+
+TEST(Cast, ListToListOptionsPassthru) {
+ for (auto make_src_list : list_factories) {
+ for (auto make_dest_list : list_factories) {
+ auto list_int32 = ArrayFromJSON(make_src_list(int32()), "[[87654321]]");
+
+ auto options = CastOptions::Safe(make_dest_list(int16()));
+ CheckCastFails(list_int32, options);
+
+ options.allow_int_overflow = true;
+ CheckCast(list_int32, ArrayFromJSON(make_dest_list(int16()), "[[32689]]"), options);
+ }
+ }
+}
+
+TEST(Cast, IdentityCasts) {
+ // ARROW-4102
+ auto CheckIdentityCast = [](std::shared_ptr<DataType> type, const std::string& json) {
+ CheckCastZeroCopy(ArrayFromJSON(type, json), type);
+ };
+
+ CheckIdentityCast(null(), "[null, null, null]");
+ CheckIdentityCast(boolean(), "[false, true, null, false]");
+
+ for (auto type : kNumericTypes) {
+ CheckIdentityCast(type, "[1, 2, null, 4]");
+ }
+ CheckIdentityCast(binary(), R"(["foo", "bar"])");
+ CheckIdentityCast(utf8(), R"(["foo", "bar"])");
+ CheckIdentityCast(fixed_size_binary(3), R"(["foo", "bar"])");
+
+ CheckIdentityCast(list(int8()), "[[1, 2], [null], [], [3]]");
+
+ CheckIdentityCast(time32(TimeUnit::MILLI), "[1, 2, 3, 4]");
+ CheckIdentityCast(time64(TimeUnit::MICRO), "[1, 2, 3, 4]");
+ CheckIdentityCast(date32(), "[1, 2, 3, 4]");
+ CheckIdentityCast(date64(), "[86400000, 0]");
+ CheckIdentityCast(timestamp(TimeUnit::SECOND), "[1, 2, 3, 4]");
+
+ CheckIdentityCast(dictionary(int8(), int8()), "[1, 2, 3, 1, null, 3]");
+}
+
+TEST(Cast, EmptyCasts) {
+ // ARROW-4766: 0-length arrays should not segfault
+ auto CheckCastEmpty = [](std::shared_ptr<DataType> from, std::shared_ptr<DataType> to) {
+ // Python creates array with nullptr instead of 0-length (valid) buffers.
+ auto data = ArrayData::Make(from, /* length */ 0, /* buffers */ {nullptr, nullptr});
+ CheckCast(MakeArray(data), ArrayFromJSON(to, "[]"));
+ };
+
+ for (auto numeric : kNumericTypes) {
+ CheckCastEmpty(boolean(), numeric);
+ CheckCastEmpty(numeric, boolean());
+ }
+}
+
+TEST(Cast, CastWithNoValidityBitmapButUnknownNullCount) {
+ // ARROW-12672 segfault when casting slightly malformed array
+ // (no validity bitmap but atomic null count non-zero)
+ auto values = ArrayFromJSON(boolean(), "[true, true, false]");
+
+ ASSERT_OK_AND_ASSIGN(auto expected, Cast(*values, int8()));
+
+ ASSERT_EQ(values->data()->buffers[0], NULLPTR);
+ values->data()->null_count = kUnknownNullCount;
+ ASSERT_OK_AND_ASSIGN(auto result, Cast(*values, int8()));
+
+ AssertArraysEqual(*expected, *result);
+}
+
+// ----------------------------------------------------------------------
+// Test casting from NullType
+
+TEST(Cast, FromNull) {
+ for (auto to_type : {
+ null(),
+ uint8(),
+ int8(),
+ uint16(),
+ int16(),
+ uint32(),
+ int32(),
+ uint64(),
+ int64(),
+ float32(),
+ float64(),
+ date32(),
+ date64(),
+ fixed_size_binary(10),
+ binary(),
+ utf8(),
+ }) {
+ ASSERT_OK_AND_ASSIGN(auto expected, MakeArrayOfNull(to_type, 10));
+ CheckCast(std::make_shared<NullArray>(10), expected);
+ }
+}
+
+TEST(Cast, FromNullToDictionary) {
+ auto from = std::make_shared<NullArray>(10);
+ auto to_type = dictionary(int8(), boolean());
+
+ ASSERT_OK_AND_ASSIGN(auto expected, MakeArrayOfNull(to_type, 10));
+ CheckCast(from, expected);
+}
+
+// ----------------------------------------------------------------------
+// Test casting from DictionaryType
+
+TEST(Cast, FromDictionary) {
+ ArrayVector dictionaries;
+ dictionaries.push_back(std::make_shared<NullArray>(5));
+
+ for (auto num_type : kNumericTypes) {
+ dictionaries.push_back(ArrayFromJSON(num_type, "[23, 12, 45, 12, null]"));
+ }
+
+ for (auto string_type : kBaseBinaryTypes) {
+ dictionaries.push_back(
+ ArrayFromJSON(string_type, R"(["foo", "bar", "baz", "foo", null])"));
+ }
+
+ for (auto dict : dictionaries) {
+ for (auto index_type : kDictionaryIndexTypes) {
+ auto indices = ArrayFromJSON(index_type, "[4, 0, 1, 2, 0, 4, null, 2]");
+ ASSERT_OK_AND_ASSIGN(auto expected, Take(*dict, *indices));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto dict_arr, DictionaryArray::FromArrays(dictionary(index_type, dict->type()),
+ indices, dict));
+ CheckCast(dict_arr, expected);
+ }
+ }
+
+ for (auto dict : dictionaries) {
+ if (dict->type_id() == Type::NA) continue;
+
+ // Test with a nullptr bitmap buffer (ARROW-3208)
+ auto indices = ArrayFromJSON(int8(), "[0, 0, 1, 2, 0, 3, 3, 2]");
+ ASSERT_OK_AND_ASSIGN(auto no_nulls, Take(*dict, *indices));
+ ASSERT_EQ(no_nulls->null_count(), 0);
+
+ ASSERT_OK_AND_ASSIGN(Datum encoded, DictionaryEncode(no_nulls));
+
+ // Make a new dict array with nullptr bitmap buffer
+ auto data = encoded.array()->Copy();
+ data->buffers[0] = nullptr;
+ data->null_count = 0;
+ std::shared_ptr<Array> dict_array = std::make_shared<DictionaryArray>(data);
+ ValidateOutput(*dict_array);
+
+ CheckCast(dict_array, no_nulls);
+ }
+}
+
+std::shared_ptr<Array> SmallintArrayFromJSON(const std::string& json_data) {
+ auto arr = ArrayFromJSON(int16(), json_data);
+ auto ext_data = arr->data()->Copy();
+ ext_data->type = smallint();
+ return MakeArray(ext_data);
+}
+
+TEST(Cast, ExtensionTypeToIntDowncast) {
+ auto smallint = std::make_shared<SmallintType>();
+ ExtensionTypeGuard smallint_guard(smallint);
+
+ std::shared_ptr<Array> result;
+ std::vector<bool> is_valid = {true, false, true, true, true};
+
+ // Smallint(int16) to int16
+ CheckCastZeroCopy(SmallintArrayFromJSON("[0, 100, 200, 1, 2]"), int16());
+
+ // Smallint(int16) to uint8, no overflow/underrun
+ CheckCast(SmallintArrayFromJSON("[0, 100, 200, 1, 2]"),
+ ArrayFromJSON(uint8(), "[0, 100, 200, 1, 2]"));
+
+ // Smallint(int16) to uint8, with overflow
+ {
+ CastOptions options;
+ options.to_type = uint8();
+ CheckCastFails(SmallintArrayFromJSON("[0, null, 256, 1, 3]"), options);
+
+ options.allow_int_overflow = true;
+ CheckCast(SmallintArrayFromJSON("[0, null, 256, 1, 3]"),
+ ArrayFromJSON(uint8(), "[0, null, 0, 1, 3]"), options);
+ }
+
+ // Smallint(int16) to uint8, with underflow
+ {
+ CastOptions options;
+ options.to_type = uint8();
+ CheckCastFails(SmallintArrayFromJSON("[0, null, -1, 1, 3]"), options);
+
+ options.allow_int_overflow = true;
+ CheckCast(SmallintArrayFromJSON("[0, null, -1, 1, 3]"),
+ ArrayFromJSON(uint8(), "[0, null, 255, 1, 3]"), options);
+ }
+}
+
+TEST(Cast, DictTypeToAnotherDict) {
+ auto check_cast = [&](const std::shared_ptr<DataType>& in_type,
+ const std::shared_ptr<DataType>& out_type,
+ const std::string& json_str,
+ const CastOptions& options = CastOptions()) {
+ auto arr = ArrayFromJSON(in_type, json_str);
+ auto exp = in_type->Equals(out_type) ? arr : ArrayFromJSON(out_type, json_str);
+ // this checks for scalars as well
+ CheckCast(arr, exp, options);
+ };
+
+ // check same type passed on to casting
+ check_cast(dictionary(int8(), int16()), dictionary(int8(), int16()),
+ "[1, 2, 3, 1, null, 3]");
+ check_cast(dictionary(int8(), int16()), dictionary(int32(), int64()),
+ "[1, 2, 3, 1, null, 3]");
+ check_cast(dictionary(int8(), int16()), dictionary(int32(), float64()),
+ "[1, 2, 3, 1, null, 3]");
+ check_cast(dictionary(int32(), utf8()), dictionary(int8(), utf8()),
+ R"(["a", "b", "a", null])");
+
+ auto arr = ArrayFromJSON(dictionary(int32(), int32()), "[1, 1000]");
+ // check casting unsafe values (checking for unsafe indices is unnecessary, because it
+ // would create an invalid index array which results in a ValidateOutput failure)
+ ASSERT_OK_AND_ASSIGN(auto casted,
+ Cast(arr, dictionary(int8(), int8()), CastOptions::Unsafe()));
+ ValidateOutput(casted);
+
+ // check safe casting values
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, testing::HasSubstr("Integer value 1000 not in range"),
+ Cast(arr, dictionary(int8(), int8()), CastOptions::Safe()));
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_compare.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_compare.cc
new file mode 100644
index 000000000..316a94cbc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_compare.cc
@@ -0,0 +1,540 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cmath>
+#include <limits>
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/util/bitmap_ops.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+using util::string_view;
+
+namespace compute {
+namespace internal {
+
+namespace {
+
+struct Equal {
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr T Call(KernelContext*, const Arg0& left, const Arg1& right, Status*) {
+ static_assert(std::is_same<T, bool>::value && std::is_same<Arg0, Arg1>::value, "");
+ return left == right;
+ }
+};
+
+struct NotEqual {
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr T Call(KernelContext*, const Arg0& left, const Arg1& right, Status*) {
+ static_assert(std::is_same<T, bool>::value && std::is_same<Arg0, Arg1>::value, "");
+ return left != right;
+ }
+};
+
+struct Greater {
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr T Call(KernelContext*, const Arg0& left, const Arg1& right, Status*) {
+ static_assert(std::is_same<T, bool>::value && std::is_same<Arg0, Arg1>::value, "");
+ return left > right;
+ }
+};
+
+struct GreaterEqual {
+ template <typename T, typename Arg0, typename Arg1>
+ static constexpr T Call(KernelContext*, const Arg0& left, const Arg1& right, Status*) {
+ static_assert(std::is_same<T, bool>::value && std::is_same<Arg0, Arg1>::value, "");
+ return left >= right;
+ }
+};
+
+template <typename T>
+using is_unsigned_integer = std::integral_constant<bool, std::is_integral<T>::value &&
+ std::is_unsigned<T>::value>;
+
+template <typename T>
+using is_signed_integer =
+ std::integral_constant<bool, std::is_integral<T>::value && std::is_signed<T>::value>;
+
+template <typename T>
+using enable_if_integer =
+ enable_if_t<is_signed_integer<T>::value || is_unsigned_integer<T>::value, T>;
+
+template <typename T>
+using enable_if_floating_point = enable_if_t<std::is_floating_point<T>::value, T>;
+
+struct Minimum {
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_point<T> Call(Arg0 left, Arg1 right) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<Arg0, Arg1>::value, "");
+ return std::fmin(left, right);
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_integer<T> Call(Arg0 left, Arg1 right) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<Arg0, Arg1>::value, "");
+ return std::min(left, right);
+ }
+
+ template <typename T>
+ static constexpr enable_if_t<std::is_same<float, T>::value, T> antiextreme() {
+ return std::nanf("");
+ }
+
+ template <typename T>
+ static constexpr enable_if_t<std::is_same<double, T>::value, T> antiextreme() {
+ return std::nan("");
+ }
+
+ template <typename T>
+ static constexpr enable_if_integer<T> antiextreme() {
+ return std::numeric_limits<T>::max();
+ }
+};
+
+struct Maximum {
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_point<T> Call(Arg0 left, Arg1 right) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<Arg0, Arg1>::value, "");
+ return std::fmax(left, right);
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_integer<T> Call(Arg0 left, Arg1 right) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<Arg0, Arg1>::value, "");
+ return std::max(left, right);
+ }
+
+ template <typename T>
+ static constexpr enable_if_t<std::is_same<float, T>::value, T> antiextreme() {
+ return std::nanf("");
+ }
+
+ template <typename T>
+ static constexpr enable_if_t<std::is_same<double, T>::value, T> antiextreme() {
+ return std::nan("");
+ }
+
+ template <typename T>
+ static constexpr enable_if_integer<T> antiextreme() {
+ return std::numeric_limits<T>::min();
+ }
+};
+
+// Implement Less, LessEqual by flipping arguments to Greater, GreaterEqual
+
+template <typename Op>
+void AddIntegerCompare(const std::shared_ptr<DataType>& ty, ScalarFunction* func) {
+ auto exec =
+ GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(*ty);
+ DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
+}
+
+template <typename InType, typename Op>
+void AddGenericCompare(const std::shared_ptr<DataType>& ty, ScalarFunction* func) {
+ DCHECK_OK(
+ func->AddKernel({ty, ty}, boolean(),
+ applicator::ScalarBinaryEqualTypes<BooleanType, InType, Op>::Exec));
+}
+
+struct CompareFunction : ScalarFunction {
+ using ScalarFunction::ScalarFunction;
+
+ Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+ RETURN_NOT_OK(CheckArity(*values));
+ if (HasDecimal(*values)) {
+ RETURN_NOT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, values));
+ }
+
+ using arrow::compute::detail::DispatchExactImpl;
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+
+ EnsureDictionaryDecoded(values);
+ ReplaceNullWithOtherType(values);
+
+ if (auto type = CommonNumeric(*values)) {
+ ReplaceTypes(type, values);
+ } else if (auto type = CommonTemporal(values->data(), values->size())) {
+ ReplaceTypes(type, values);
+ } else if (auto type = CommonBinary(values->data(), values->size())) {
+ ReplaceTypes(type, values);
+ }
+
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+ return arrow::compute::detail::NoMatchingKernel(this, *values);
+ }
+};
+
+struct VarArgsCompareFunction : ScalarFunction {
+ using ScalarFunction::ScalarFunction;
+
+ Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+ RETURN_NOT_OK(CheckArity(*values));
+
+ using arrow::compute::detail::DispatchExactImpl;
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+
+ EnsureDictionaryDecoded(values);
+
+ if (auto type = CommonNumeric(*values)) {
+ ReplaceTypes(type, values);
+ } else if (auto type = CommonTemporal(values->data(), values->size())) {
+ ReplaceTypes(type, values);
+ }
+
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+ return arrow::compute::detail::NoMatchingKernel(this, *values);
+ }
+};
+
+template <typename Op>
+std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name,
+ const FunctionDoc* doc) {
+ auto func = std::make_shared<CompareFunction>(name, Arity::Binary(), doc);
+
+ DCHECK_OK(func->AddKernel(
+ {boolean(), boolean()}, boolean(),
+ applicator::ScalarBinary<BooleanType, BooleanType, BooleanType, Op>::Exec));
+
+ for (const std::shared_ptr<DataType>& ty : IntTypes()) {
+ AddIntegerCompare<Op>(ty, func.get());
+ }
+ AddIntegerCompare<Op>(date32(), func.get());
+ AddIntegerCompare<Op>(date64(), func.get());
+
+ AddGenericCompare<FloatType, Op>(float32(), func.get());
+ AddGenericCompare<DoubleType, Op>(float64(), func.get());
+
+ // Add timestamp kernels
+ for (auto unit : TimeUnit::values()) {
+ InputType in_type(match::TimestampTypeUnit(unit));
+ auto exec =
+ GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(
+ int64());
+ DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec)));
+ }
+
+ // Duration
+ for (auto unit : TimeUnit::values()) {
+ InputType in_type(match::DurationTypeUnit(unit));
+ auto exec =
+ GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(
+ int64());
+ DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec)));
+ }
+
+ // Time32 and Time64
+ for (auto unit : {TimeUnit::SECOND, TimeUnit::MILLI}) {
+ InputType in_type(match::Time32TypeUnit(unit));
+ auto exec =
+ GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(
+ int32());
+ DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec)));
+ }
+ for (auto unit : {TimeUnit::MICRO, TimeUnit::NANO}) {
+ InputType in_type(match::Time64TypeUnit(unit));
+ auto exec =
+ GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(
+ int64());
+ DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec)));
+ }
+
+ for (const std::shared_ptr<DataType>& ty : BaseBinaryTypes()) {
+ auto exec =
+ GenerateVarBinaryBase<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(*ty);
+ DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
+ }
+
+ for (const auto id : {Type::DECIMAL128, Type::DECIMAL256}) {
+ auto exec = GenerateDecimal<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id);
+ DCHECK_OK(
+ func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec)));
+ }
+
+ {
+ auto exec =
+ applicator::ScalarBinaryEqualTypes<BooleanType, FixedSizeBinaryType, Op>::Exec;
+ auto ty = InputType(Type::FIXED_SIZE_BINARY);
+ DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
+ }
+
+ return func;
+}
+
+std::shared_ptr<ScalarFunction> MakeFlippedFunction(std::string name,
+ const ScalarFunction& func,
+ const FunctionDoc* doc) {
+ auto flipped_func = std::make_shared<CompareFunction>(name, Arity::Binary(), doc);
+ for (const ScalarKernel* kernel : func.kernels()) {
+ ScalarKernel flipped_kernel = *kernel;
+ flipped_kernel.exec = MakeFlippedBinaryExec(kernel->exec);
+ DCHECK_OK(flipped_func->AddKernel(std::move(flipped_kernel)));
+ }
+ return flipped_func;
+}
+
+using MinMaxState = OptionsWrapper<ElementWiseAggregateOptions>;
+
+// Implement a variadic scalar min/max kernel.
+template <typename OutType, typename Op>
+struct ScalarMinMax {
+ using OutValue = typename GetOutputType<OutType>::T;
+
+ static void ExecScalar(const ExecBatch& batch,
+ const ElementWiseAggregateOptions& options, Scalar* out) {
+ // All arguments are scalar
+ OutValue value{};
+ bool valid = false;
+ for (const auto& arg : batch.values) {
+ // Ignore non-scalar arguments so we can use it in the mixed-scalar-and-array case
+ if (!arg.is_scalar()) continue;
+ const auto& scalar = *arg.scalar();
+ if (!scalar.is_valid) {
+ if (options.skip_nulls) continue;
+ out->is_valid = false;
+ return;
+ }
+ if (!valid) {
+ value = UnboxScalar<OutType>::Unbox(scalar);
+ valid = true;
+ } else {
+ value = Op::template Call<OutValue, OutValue, OutValue>(
+ value, UnboxScalar<OutType>::Unbox(scalar));
+ }
+ }
+ out->is_valid = valid;
+ if (valid) {
+ BoxScalar<OutType>::Box(value, out);
+ }
+ }
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const ElementWiseAggregateOptions& options = MinMaxState::Get(ctx);
+ const auto descrs = batch.GetDescriptors();
+ const size_t scalar_count =
+ static_cast<size_t>(std::count_if(batch.values.begin(), batch.values.end(),
+ [](const Datum& d) { return d.is_scalar(); }));
+ if (scalar_count == batch.values.size()) {
+ ExecScalar(batch, options, out->scalar().get());
+ return Status::OK();
+ }
+
+ ArrayData* output = out->mutable_array();
+
+ // At least one array, two or more arguments
+ ArrayDataVector arrays;
+ for (const auto& arg : batch.values) {
+ if (!arg.is_array()) continue;
+ arrays.push_back(arg.array());
+ }
+
+ bool initialize_output = true;
+ if (scalar_count > 0) {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> temp_scalar,
+ MakeScalar(out->type(), 0));
+ ExecScalar(batch, options, temp_scalar.get());
+ if (temp_scalar->is_valid) {
+ const auto value = UnboxScalar<OutType>::Unbox(*temp_scalar);
+ initialize_output = false;
+ OutValue* out = output->GetMutableValues<OutValue>(1);
+ std::fill(out, out + batch.length, value);
+ } else if (!options.skip_nulls) {
+ // Abort early
+ ARROW_ASSIGN_OR_RAISE(auto array, MakeArrayFromScalar(*temp_scalar, batch.length,
+ ctx->memory_pool()));
+ *output = *array->data();
+ return Status::OK();
+ }
+ }
+
+ if (initialize_output) {
+ OutValue* out = output->GetMutableValues<OutValue>(1);
+ std::fill(out, out + batch.length, Op::template antiextreme<OutValue>());
+ }
+
+ // Precompute the validity buffer
+ if (options.skip_nulls && initialize_output) {
+ // OR together the validity buffers of all arrays
+ if (std::all_of(arrays.begin(), arrays.end(),
+ [](const std::shared_ptr<ArrayData>& arr) {
+ return arr->MayHaveNulls();
+ })) {
+ for (const auto& arr : arrays) {
+ if (!arr->MayHaveNulls()) continue;
+ if (!output->buffers[0]) {
+ ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(batch.length));
+ ::arrow::internal::CopyBitmap(arr->buffers[0]->data(), arr->offset,
+
+ batch.length,
+ output->buffers[0]->mutable_data(),
+ /*dest_offset=*/0);
+ } else {
+ ::arrow::internal::BitmapOr(
+ output->buffers[0]->data(), /*left_offset=*/0, arr->buffers[0]->data(),
+ arr->offset, batch.length,
+ /*out_offset=*/0, output->buffers[0]->mutable_data());
+ }
+ }
+ }
+ } else if (!options.skip_nulls) {
+ // AND together the validity buffers of all arrays
+ for (const auto& arr : arrays) {
+ if (!arr->MayHaveNulls()) continue;
+ if (!output->buffers[0]) {
+ ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(batch.length));
+ ::arrow::internal::CopyBitmap(arr->buffers[0]->data(), arr->offset,
+ batch.length, output->buffers[0]->mutable_data(),
+ /*dest_offset=*/0);
+ } else {
+ ::arrow::internal::BitmapAnd(output->buffers[0]->data(), /*left_offset=*/0,
+ arr->buffers[0]->data(), arr->offset, batch.length,
+ /*out_offset=*/0,
+ output->buffers[0]->mutable_data());
+ }
+ }
+ }
+
+ for (const auto& array : arrays) {
+ OutputArrayWriter<OutType> writer(out->mutable_array());
+ ArrayIterator<OutType> out_it(*output);
+ int64_t index = 0;
+ VisitArrayValuesInline<OutType>(
+ *array,
+ [&](OutValue value) {
+ auto u = out_it();
+ if (!output->buffers[0] ||
+ BitUtil::GetBit(output->buffers[0]->data(), index)) {
+ writer.Write(Op::template Call<OutValue, OutValue, OutValue>(u, value));
+ } else {
+ writer.Write(value);
+ }
+ index++;
+ },
+ [&]() {
+ // RHS is null, preserve the LHS
+ writer.values++;
+ index++;
+ out_it();
+ });
+ }
+ output->null_count = output->buffers[0] ? -1 : 0;
+ return Status::OK();
+ }
+};
+
+template <typename Op>
+std::shared_ptr<ScalarFunction> MakeScalarMinMax(std::string name,
+ const FunctionDoc* doc) {
+ static auto default_element_wise_aggregate_options =
+ ElementWiseAggregateOptions::Defaults();
+
+ auto func = std::make_shared<VarArgsCompareFunction>(
+ name, Arity::VarArgs(), doc, &default_element_wise_aggregate_options);
+ for (const auto& ty : NumericTypes()) {
+ auto exec = GeneratePhysicalNumeric<ScalarMinMax, Op>(ty);
+ ScalarKernel kernel{KernelSignature::Make({ty}, ty, /*is_varargs=*/true), exec,
+ MinMaxState::Init};
+ kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::type::PREALLOCATE;
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+ }
+ for (const auto& ty : TemporalTypes()) {
+ auto exec = GeneratePhysicalNumeric<ScalarMinMax, Op>(ty);
+ ScalarKernel kernel{KernelSignature::Make({ty}, ty, /*is_varargs=*/true), exec,
+ MinMaxState::Init};
+ kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::type::PREALLOCATE;
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+ }
+ return func;
+}
+
+const FunctionDoc equal_doc{"Compare values for equality (x == y)",
+ ("A null on either side emits a null comparison result."),
+ {"x", "y"}};
+
+const FunctionDoc not_equal_doc{"Compare values for inequality (x != y)",
+ ("A null on either side emits a null comparison result."),
+ {"x", "y"}};
+
+const FunctionDoc greater_doc{"Compare values for ordered inequality (x > y)",
+ ("A null on either side emits a null comparison result."),
+ {"x", "y"}};
+
+const FunctionDoc greater_equal_doc{
+ "Compare values for ordered inequality (x >= y)",
+ ("A null on either side emits a null comparison result."),
+ {"x", "y"}};
+
+const FunctionDoc less_doc{"Compare values for ordered inequality (x < y)",
+ ("A null on either side emits a null comparison result."),
+ {"x", "y"}};
+
+const FunctionDoc less_equal_doc{
+ "Compare values for ordered inequality (x <= y)",
+ ("A null on either side emits a null comparison result."),
+ {"x", "y"}};
+
+const FunctionDoc min_element_wise_doc{
+ "Find the element-wise minimum value",
+ ("Nulls will be ignored (default) or propagated. "
+ "NaN will be taken over null, but not over any valid float."),
+ {"*args"},
+ "ElementWiseAggregateOptions"};
+
+const FunctionDoc max_element_wise_doc{
+ "Find the element-wise maximum value",
+ ("Nulls will be ignored (default) or propagated. "
+ "NaN will be taken over null, but not over any valid float."),
+ {"*args"},
+ "ElementWiseAggregateOptions"};
+} // namespace
+
+void RegisterScalarComparison(FunctionRegistry* registry) {
+ DCHECK_OK(registry->AddFunction(MakeCompareFunction<Equal>("equal", &equal_doc)));
+ DCHECK_OK(
+ registry->AddFunction(MakeCompareFunction<NotEqual>("not_equal", &not_equal_doc)));
+
+ auto greater = MakeCompareFunction<Greater>("greater", &greater_doc);
+ auto greater_equal =
+ MakeCompareFunction<GreaterEqual>("greater_equal", &greater_equal_doc);
+
+ auto less = MakeFlippedFunction("less", *greater, &less_doc);
+ auto less_equal = MakeFlippedFunction("less_equal", *greater_equal, &less_equal_doc);
+ DCHECK_OK(registry->AddFunction(std::move(less)));
+ DCHECK_OK(registry->AddFunction(std::move(less_equal)));
+ DCHECK_OK(registry->AddFunction(std::move(greater)));
+ DCHECK_OK(registry->AddFunction(std::move(greater_equal)));
+
+ // ----------------------------------------------------------------------
+ // Variadic element-wise functions
+
+ auto min_element_wise =
+ MakeScalarMinMax<Minimum>("min_element_wise", &min_element_wise_doc);
+ DCHECK_OK(registry->AddFunction(std::move(min_element_wise)));
+
+ auto max_element_wise =
+ MakeScalarMinMax<Maximum>("max_element_wise", &max_element_wise_doc);
+ DCHECK_OK(registry->AddFunction(std::move(max_element_wise)));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_compare_benchmark.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_compare_benchmark.cc
new file mode 100644
index 000000000..86be319a3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_compare_benchmark.cc
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <vector>
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/benchmark_util.h"
+
+namespace arrow {
+namespace compute {
+
+constexpr auto kSeed = 0x94378165;
+
+template <CompareOperator op, typename Type>
+static void CompareArrayScalar(benchmark::State& state) {
+ RegressionArgs args(state, /*size_is_bytes=*/false);
+ auto ty = TypeTraits<Type>::type_singleton();
+ auto rand = random::RandomArrayGenerator(kSeed);
+ auto array = rand.ArrayOf(ty, args.size, args.null_proportion);
+ auto scalar = *rand.ArrayOf(ty, 1, 0)->GetScalar(0);
+ for (auto _ : state) {
+ ABORT_NOT_OK(
+ CallFunction(CompareOperatorToFunctionName(op), {array, Datum(scalar)}).status());
+ }
+}
+
+template <CompareOperator op, typename Type>
+static void CompareArrayArray(benchmark::State& state) {
+ RegressionArgs args(state, /*size_is_bytes=*/false);
+ auto ty = TypeTraits<Type>::type_singleton();
+ auto rand = random::RandomArrayGenerator(kSeed);
+ auto lhs = rand.ArrayOf(ty, args.size, args.null_proportion);
+ auto rhs = rand.ArrayOf(ty, args.size, args.null_proportion);
+ for (auto _ : state) {
+ ABORT_NOT_OK(CallFunction(CompareOperatorToFunctionName(op), {lhs, rhs}).status());
+ }
+}
+
+static void GreaterArrayArrayInt64(benchmark::State& state) {
+ CompareArrayArray<GREATER, Int64Type>(state);
+}
+
+static void GreaterArrayScalarInt64(benchmark::State& state) {
+ CompareArrayScalar<GREATER, Int64Type>(state);
+}
+
+static void GreaterArrayArrayString(benchmark::State& state) {
+ CompareArrayArray<GREATER, StringType>(state);
+}
+
+static void GreaterArrayScalarString(benchmark::State& state) {
+ CompareArrayScalar<GREATER, StringType>(state);
+}
+
+BENCHMARK(GreaterArrayArrayInt64)->Apply(RegressionSetArgs);
+BENCHMARK(GreaterArrayScalarInt64)->Apply(RegressionSetArgs);
+
+BENCHMARK(GreaterArrayArrayString)->Apply(RegressionSetArgs);
+BENCHMARK(GreaterArrayScalarString)->Apply(RegressionSetArgs);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
new file mode 100644
index 000000000..800ae8063
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
@@ -0,0 +1,1388 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+#include "arrow/testing/random.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+
+using internal::BitmapReader;
+
+namespace compute {
+
+using util::string_view;
+
+template <typename ArrowType>
+static void ValidateCompare(CompareOptions options, const Datum& lhs, const Datum& rhs,
+ const Datum& expected) {
+ ASSERT_OK_AND_ASSIGN(
+ Datum result, CallFunction(CompareOperatorToFunctionName(options.op), {lhs, rhs}));
+ AssertArraysEqual(*expected.make_array(), *result.make_array(),
+ /*verbose=*/true);
+}
+
+template <typename ArrowType>
+static void ValidateCompare(CompareOptions options, const char* lhs_str, const Datum& rhs,
+ const char* expected_str) {
+ auto lhs = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), lhs_str);
+ auto expected = ArrayFromJSON(TypeTraits<BooleanType>::type_singleton(), expected_str);
+ ValidateCompare<ArrowType>(options, lhs, rhs, expected);
+}
+
+template <typename ArrowType>
+static void ValidateCompare(CompareOptions options, const Datum& lhs, const char* rhs_str,
+ const char* expected_str) {
+ auto rhs = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), rhs_str);
+ auto expected = ArrayFromJSON(TypeTraits<BooleanType>::type_singleton(), expected_str);
+ ValidateCompare<ArrowType>(options, lhs, rhs, expected);
+}
+
+template <typename ArrowType>
+static void ValidateCompare(CompareOptions options, const char* lhs_str,
+ const char* rhs_str, const char* expected_str) {
+ auto lhs = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), lhs_str);
+ auto rhs = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), rhs_str);
+ auto expected = ArrayFromJSON(TypeTraits<BooleanType>::type_singleton(), expected_str);
+ ValidateCompare<ArrowType>(options, lhs, rhs, expected);
+}
+
+template <typename T>
+static inline bool SlowCompare(CompareOperator op, const T& lhs, const T& rhs) {
+ switch (op) {
+ case EQUAL:
+ return lhs == rhs;
+ case NOT_EQUAL:
+ return lhs != rhs;
+ case GREATER:
+ return lhs > rhs;
+ case GREATER_EQUAL:
+ return lhs >= rhs;
+ case LESS:
+ return lhs < rhs;
+ case LESS_EQUAL:
+ return lhs <= rhs;
+ default:
+ return false;
+ }
+}
+
+template <typename ArrowType>
+Datum SimpleScalarArrayCompare(CompareOptions options, const Datum& lhs,
+ const Datum& rhs) {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+ using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+
+ bool swap = lhs.is_array();
+ auto array = std::static_pointer_cast<ArrayType>((swap ? lhs : rhs).make_array());
+ auto value = std::static_pointer_cast<ScalarType>((swap ? rhs : lhs).scalar())->value;
+
+ std::vector<bool> bitmap(array->length());
+ for (int64_t i = 0; i < array->length(); i++) {
+ bitmap[i] = swap ? SlowCompare(options.op, array->Value(i), value)
+ : SlowCompare(options.op, value, array->Value(i));
+ }
+
+ std::shared_ptr<Array> result;
+
+ if (array->null_count() == 0) {
+ ArrayFromVector<BooleanType>(bitmap, &result);
+ } else {
+ std::vector<bool> null_bitmap(array->length());
+ auto reader =
+ BitmapReader(array->null_bitmap_data(), array->offset(), array->length());
+ for (int64_t i = 0; i < array->length(); i++, reader.Next()) {
+ null_bitmap[i] = reader.IsSet();
+ }
+ ArrayFromVector<BooleanType>(null_bitmap, bitmap, &result);
+ }
+
+ return Datum(result);
+}
+
+template <>
+Datum SimpleScalarArrayCompare<StringType>(CompareOptions options, const Datum& lhs,
+ const Datum& rhs) {
+ bool swap = lhs.is_array();
+ auto array = std::static_pointer_cast<StringArray>((swap ? lhs : rhs).make_array());
+ auto value = util::string_view(
+ *std::static_pointer_cast<StringScalar>((swap ? rhs : lhs).scalar())->value);
+
+ std::vector<bool> bitmap(array->length());
+ for (int64_t i = 0; i < array->length(); i++) {
+ bitmap[i] = swap ? SlowCompare(options.op, array->GetView(i), value)
+ : SlowCompare(options.op, value, array->GetView(i));
+ }
+
+ std::shared_ptr<Array> result;
+
+ if (array->null_count() == 0) {
+ ArrayFromVector<BooleanType>(bitmap, &result);
+ } else {
+ std::vector<bool> null_bitmap(array->length());
+ auto reader =
+ BitmapReader(array->null_bitmap_data(), array->offset(), array->length());
+ for (int64_t i = 0; i < array->length(); i++, reader.Next()) {
+ null_bitmap[i] = reader.IsSet();
+ }
+ ArrayFromVector<BooleanType>(null_bitmap, bitmap, &result);
+ }
+
+ return Datum(result);
+}
+
+template <typename ArrayType>
+std::vector<bool> NullBitmapFromArrays(const ArrayType& lhs, const ArrayType& rhs) {
+ auto left_lambda = [&lhs](int64_t i) {
+ return lhs.null_count() == 0 ? true : lhs.IsValid(i);
+ };
+
+ auto right_lambda = [&rhs](int64_t i) {
+ return rhs.null_count() == 0 ? true : rhs.IsValid(i);
+ };
+
+ const int64_t length = lhs.length();
+ std::vector<bool> null_bitmap(length);
+
+ for (int64_t i = 0; i < length; i++) {
+ null_bitmap[i] = left_lambda(i) && right_lambda(i);
+ }
+
+ return null_bitmap;
+}
+
+template <typename ArrowType>
+Datum SimpleArrayArrayCompare(CompareOptions options, const Datum& lhs,
+ const Datum& rhs) {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+ auto l_array = std::static_pointer_cast<ArrayType>(lhs.make_array());
+ auto r_array = std::static_pointer_cast<ArrayType>(rhs.make_array());
+ const int64_t length = l_array->length();
+
+ std::vector<bool> bitmap(length);
+ for (int64_t i = 0; i < length; i++) {
+ bitmap[i] = SlowCompare(options.op, l_array->Value(i), r_array->Value(i));
+ }
+
+ std::shared_ptr<Array> result;
+
+ if (l_array->null_count() == 0 && r_array->null_count() == 0) {
+ ArrayFromVector<BooleanType>(bitmap, &result);
+ } else {
+ std::vector<bool> null_bitmap = NullBitmapFromArrays(*l_array, *r_array);
+ ArrayFromVector<BooleanType>(null_bitmap, bitmap, &result);
+ }
+
+ return Datum(result);
+}
+
+template <>
+Datum SimpleArrayArrayCompare<StringType>(CompareOptions options, const Datum& lhs,
+ const Datum& rhs) {
+ auto l_array = std::static_pointer_cast<StringArray>(lhs.make_array());
+ auto r_array = std::static_pointer_cast<StringArray>(rhs.make_array());
+ const int64_t length = l_array->length();
+
+ std::vector<bool> bitmap(length);
+ for (int64_t i = 0; i < length; i++) {
+ bitmap[i] = SlowCompare(options.op, l_array->GetView(i), r_array->GetView(i));
+ }
+
+ std::shared_ptr<Array> result;
+
+ if (l_array->null_count() == 0 && r_array->null_count() == 0) {
+ ArrayFromVector<BooleanType>(bitmap, &result);
+ } else {
+ std::vector<bool> null_bitmap = NullBitmapFromArrays(*l_array, *r_array);
+ ArrayFromVector<BooleanType>(null_bitmap, bitmap, &result);
+ }
+
+ return Datum(result);
+}
+
+template <typename ArrowType>
+void ValidateCompare(CompareOptions options, const Datum& lhs, const Datum& rhs) {
+ Datum result;
+
+ bool has_scalar = lhs.is_scalar() || rhs.is_scalar();
+ Datum expected = has_scalar ? SimpleScalarArrayCompare<ArrowType>(options, lhs, rhs)
+ : SimpleArrayArrayCompare<ArrowType>(options, lhs, rhs);
+
+ ValidateCompare<ArrowType>(options, lhs, rhs, expected);
+}
+
+template <typename ArrowType>
+class TestNumericCompareKernel : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestNumericCompareKernel, NumericArrowTypes);
+TYPED_TEST(TestNumericCompareKernel, SimpleCompareArrayScalar) {
+ using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
+ using CType = typename TypeTraits<TypeParam>::CType;
+
+ Datum one(std::make_shared<ScalarType>(CType(1)));
+
+ CompareOptions eq(CompareOperator::EQUAL);
+ ValidateCompare<TypeParam>(eq, "[]", one, "[]");
+ ValidateCompare<TypeParam>(eq, "[null]", one, "[null]");
+ ValidateCompare<TypeParam>(eq, "[0,0,1,1,2,2]", one, "[0,0,1,1,0,0]");
+ ValidateCompare<TypeParam>(eq, "[0,1,2,3,4,5]", one, "[0,1,0,0,0,0]");
+ ValidateCompare<TypeParam>(eq, "[5,4,3,2,1,0]", one, "[0,0,0,0,1,0]");
+ ValidateCompare<TypeParam>(eq, "[null,0,1,1]", one, "[null,0,1,1]");
+
+ CompareOptions neq(CompareOperator::NOT_EQUAL);
+ ValidateCompare<TypeParam>(neq, "[]", one, "[]");
+ ValidateCompare<TypeParam>(neq, "[null]", one, "[null]");
+ ValidateCompare<TypeParam>(neq, "[0,0,1,1,2,2]", one, "[1,1,0,0,1,1]");
+ ValidateCompare<TypeParam>(neq, "[0,1,2,3,4,5]", one, "[1,0,1,1,1,1]");
+ ValidateCompare<TypeParam>(neq, "[5,4,3,2,1,0]", one, "[1,1,1,1,0,1]");
+ ValidateCompare<TypeParam>(neq, "[null,0,1,1]", one, "[null,1,0,0]");
+
+ CompareOptions gt(CompareOperator::GREATER);
+ ValidateCompare<TypeParam>(gt, "[]", one, "[]");
+ ValidateCompare<TypeParam>(gt, "[null]", one, "[null]");
+ ValidateCompare<TypeParam>(gt, "[0,0,1,1,2,2]", one, "[0,0,0,0,1,1]");
+ ValidateCompare<TypeParam>(gt, "[0,1,2,3,4,5]", one, "[0,0,1,1,1,1]");
+ ValidateCompare<TypeParam>(gt, "[4,5,6,7,8,9]", one, "[1,1,1,1,1,1]");
+ ValidateCompare<TypeParam>(gt, "[null,0,1,1]", one, "[null,0,0,0]");
+
+ CompareOptions gte(CompareOperator::GREATER_EQUAL);
+ ValidateCompare<TypeParam>(gte, "[]", one, "[]");
+ ValidateCompare<TypeParam>(gte, "[null]", one, "[null]");
+ ValidateCompare<TypeParam>(gte, "[0,0,1,1,2,2]", one, "[0,0,1,1,1,1]");
+ ValidateCompare<TypeParam>(gte, "[0,1,2,3,4,5]", one, "[0,1,1,1,1,1]");
+ ValidateCompare<TypeParam>(gte, "[4,5,6,7,8,9]", one, "[1,1,1,1,1,1]");
+ ValidateCompare<TypeParam>(gte, "[null,0,1,1]", one, "[null,0,1,1]");
+
+ CompareOptions lt(CompareOperator::LESS);
+ ValidateCompare<TypeParam>(lt, "[]", one, "[]");
+ ValidateCompare<TypeParam>(lt, "[null]", one, "[null]");
+ ValidateCompare<TypeParam>(lt, "[0,0,1,1,2,2]", one, "[1,1,0,0,0,0]");
+ ValidateCompare<TypeParam>(lt, "[0,1,2,3,4,5]", one, "[1,0,0,0,0,0]");
+ ValidateCompare<TypeParam>(lt, "[4,5,6,7,8,9]", one, "[0,0,0,0,0,0]");
+ ValidateCompare<TypeParam>(lt, "[null,0,1,1]", one, "[null,1,0,0]");
+
+ CompareOptions lte(CompareOperator::LESS_EQUAL);
+ ValidateCompare<TypeParam>(lte, "[]", one, "[]");
+ ValidateCompare<TypeParam>(lte, "[null]", one, "[null]");
+ ValidateCompare<TypeParam>(lte, "[0,0,1,1,2,2]", one, "[1,1,1,1,0,0]");
+ ValidateCompare<TypeParam>(lte, "[0,1,2,3,4,5]", one, "[1,1,0,0,0,0]");
+ ValidateCompare<TypeParam>(lte, "[4,5,6,7,8,9]", one, "[0,0,0,0,0,0]");
+ ValidateCompare<TypeParam>(lte, "[null,0,1,1]", one, "[null,1,1,1]");
+}
+
+TYPED_TEST(TestNumericCompareKernel, SimpleCompareScalarArray) {
+ using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
+ using CType = typename TypeTraits<TypeParam>::CType;
+
+ Datum one(std::make_shared<ScalarType>(CType(1)));
+
+ CompareOptions eq(CompareOperator::EQUAL);
+ ValidateCompare<TypeParam>(eq, one, "[]", "[]");
+ ValidateCompare<TypeParam>(eq, one, "[null]", "[null]");
+ ValidateCompare<TypeParam>(eq, one, "[0,0,1,1,2,2]", "[0,0,1,1,0,0]");
+ ValidateCompare<TypeParam>(eq, one, "[0,1,2,3,4,5]", "[0,1,0,0,0,0]");
+ ValidateCompare<TypeParam>(eq, one, "[5,4,3,2,1,0]", "[0,0,0,0,1,0]");
+ ValidateCompare<TypeParam>(eq, one, "[null,0,1,1]", "[null,0,1,1]");
+
+ CompareOptions neq(CompareOperator::NOT_EQUAL);
+ ValidateCompare<TypeParam>(neq, one, "[]", "[]");
+ ValidateCompare<TypeParam>(neq, one, "[null]", "[null]");
+ ValidateCompare<TypeParam>(neq, one, "[0,0,1,1,2,2]", "[1,1,0,0,1,1]");
+ ValidateCompare<TypeParam>(neq, one, "[0,1,2,3,4,5]", "[1,0,1,1,1,1]");
+ ValidateCompare<TypeParam>(neq, one, "[5,4,3,2,1,0]", "[1,1,1,1,0,1]");
+ ValidateCompare<TypeParam>(neq, one, "[null,0,1,1]", "[null,1,0,0]");
+
+ CompareOptions gt(CompareOperator::GREATER);
+ ValidateCompare<TypeParam>(gt, one, "[]", "[]");
+ ValidateCompare<TypeParam>(gt, one, "[null]", "[null]");
+ ValidateCompare<TypeParam>(gt, one, "[0,0,1,1,2,2]", "[1,1,0,0,0,0]");
+ ValidateCompare<TypeParam>(gt, one, "[0,1,2,3,4,5]", "[1,0,0,0,0,0]");
+ ValidateCompare<TypeParam>(gt, one, "[4,5,6,7,8,9]", "[0,0,0,0,0,0]");
+ ValidateCompare<TypeParam>(gt, one, "[null,0,1,1]", "[null,1,0,0]");
+
+ CompareOptions gte(CompareOperator::GREATER_EQUAL);
+ ValidateCompare<TypeParam>(gte, one, "[]", "[]");
+ ValidateCompare<TypeParam>(gte, one, "[null]", "[null]");
+ ValidateCompare<TypeParam>(gte, one, "[0,0,1,1,2,2]", "[1,1,1,1,0,0]");
+ ValidateCompare<TypeParam>(gte, one, "[0,1,2,3,4,5]", "[1,1,0,0,0,0]");
+ ValidateCompare<TypeParam>(gte, one, "[4,5,6,7,8,9]", "[0,0,0,0,0,0]");
+ ValidateCompare<TypeParam>(gte, one, "[null,0,1,1]", "[null,1,1,1]");
+
+ CompareOptions lt(CompareOperator::LESS);
+ ValidateCompare<TypeParam>(lt, one, "[]", "[]");
+ ValidateCompare<TypeParam>(lt, one, "[null]", "[null]");
+ ValidateCompare<TypeParam>(lt, one, "[0,0,1,1,2,2]", "[0,0,0,0,1,1]");
+ ValidateCompare<TypeParam>(lt, one, "[0,1,2,3,4,5]", "[0,0,1,1,1,1]");
+ ValidateCompare<TypeParam>(lt, one, "[4,5,6,7,8,9]", "[1,1,1,1,1,1]");
+ ValidateCompare<TypeParam>(lt, one, "[null,0,1,1]", "[null,0,0,0]");
+
+ CompareOptions lte(CompareOperator::LESS_EQUAL);
+ ValidateCompare<TypeParam>(lte, one, "[]", "[]");
+ ValidateCompare<TypeParam>(lte, one, "[null]", "[null]");
+ ValidateCompare<TypeParam>(lte, one, "[0,0,1,1,2,2]", "[0,0,1,1,1,1]");
+ ValidateCompare<TypeParam>(lte, one, "[0,1,2,3,4,5]", "[0,1,1,1,1,1]");
+ ValidateCompare<TypeParam>(lte, one, "[4,5,6,7,8,9]", "[1,1,1,1,1,1]");
+ ValidateCompare<TypeParam>(lte, one, "[null,0,1,1]", "[null,0,1,1]");
+}
+
+TYPED_TEST(TestNumericCompareKernel, TestNullScalar) {
+ /* Ensure that null scalar broadcast to all null results. */
+ using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
+
+ Datum null(std::make_shared<ScalarType>());
+ EXPECT_FALSE(null.scalar()->is_valid);
+
+ CompareOptions eq(CompareOperator::EQUAL);
+ ValidateCompare<TypeParam>(eq, "[]", null, "[]");
+ ValidateCompare<TypeParam>(eq, null, "[]", "[]");
+ ValidateCompare<TypeParam>(eq, "[null]", null, "[null]");
+ ValidateCompare<TypeParam>(eq, null, "[null]", "[null]");
+ ValidateCompare<TypeParam>(eq, null, "[1,2,3]", "[null, null, null]");
+}
+
+TYPED_TEST_SUITE(TestNumericCompareKernel, NumericArrowTypes);
+
+template <typename Type>
+struct CompareRandomNumeric {
+ static void Test(const std::shared_ptr<DataType>& type) {
+ using ScalarType = typename TypeTraits<Type>::ScalarType;
+ using CType = typename TypeTraits<Type>::CType;
+ auto rand = random::RandomArrayGenerator(0x5416447);
+ const int64_t length = 1000;
+ for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) {
+ for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
+ auto data =
+ rand.Numeric<typename Type::PhysicalType>(length, 0, 100, null_probability);
+
+ auto data1 =
+ rand.Numeric<typename Type::PhysicalType>(length, 0, 100, null_probability);
+ auto data2 =
+ rand.Numeric<typename Type::PhysicalType>(length, 0, 100, null_probability);
+
+ // Create view of data as the type (e.g. timestamp)
+ auto array1 = Datum(*data1->View(type));
+ auto array2 = Datum(*data2->View(type));
+ auto fifty = Datum(std::make_shared<ScalarType>(CType(50), type));
+ auto options = CompareOptions(op);
+
+ ValidateCompare<Type>(options, array1, fifty);
+ ValidateCompare<Type>(options, fifty, array1);
+ ValidateCompare<Type>(options, array1, array2);
+ }
+ }
+ }
+};
+
+TEST(TestCompareKernel, PrimitiveRandomTests) {
+ TestRandomPrimitiveCTypes<CompareRandomNumeric>();
+}
+
+TYPED_TEST(TestNumericCompareKernel, SimpleCompareArrayArray) {
+ /* Ensure that null scalar broadcast to all null results. */
+ CompareOptions eq(CompareOperator::EQUAL);
+ ValidateCompare<TypeParam>(eq, "[]", "[]", "[]");
+ ValidateCompare<TypeParam>(eq, "[null]", "[null]", "[null]");
+ ValidateCompare<TypeParam>(eq, "[1]", "[1]", "[1]");
+ ValidateCompare<TypeParam>(eq, "[1]", "[2]", "[0]");
+ ValidateCompare<TypeParam>(eq, "[null]", "[1]", "[null]");
+ ValidateCompare<TypeParam>(eq, "[1]", "[null]", "[null]");
+
+ CompareOptions lte(CompareOperator::LESS_EQUAL);
+ ValidateCompare<TypeParam>(lte, "[1,2,3,4,5]", "[2,3,4,5,6]", "[1,1,1,1,1]");
+}
+
+TEST(TestCompareTimestamps, Basics) {
+ const char* example1_json = R"(["1970-01-01","2000-02-29","1900-02-28"])";
+ const char* example2_json = R"(["1970-01-02","2000-02-01","1900-02-28"])";
+
+ auto CheckArrayCase = [&](std::shared_ptr<DataType> type, CompareOperator op,
+ const char* expected_json) {
+ auto lhs = ArrayFromJSON(type, example1_json);
+ auto rhs = ArrayFromJSON(type, example2_json);
+ auto expected = ArrayFromJSON(boolean(), expected_json);
+ ASSERT_OK_AND_ASSIGN(Datum result,
+ CallFunction(CompareOperatorToFunctionName(op), {lhs, rhs}));
+ AssertArraysEqual(*expected, *result.make_array(), /*verbose=*/true);
+ };
+
+ auto seconds = timestamp(TimeUnit::SECOND);
+ auto millis = timestamp(TimeUnit::MILLI);
+ auto micros = timestamp(TimeUnit::MICRO);
+ auto nanos = timestamp(TimeUnit::NANO);
+
+ CheckArrayCase(seconds, CompareOperator::EQUAL, "[false, false, true]");
+ CheckArrayCase(seconds, CompareOperator::NOT_EQUAL, "[true, true, false]");
+ CheckArrayCase(seconds, CompareOperator::LESS, "[true, false, false]");
+ CheckArrayCase(seconds, CompareOperator::LESS_EQUAL, "[true, false, true]");
+ CheckArrayCase(seconds, CompareOperator::GREATER, "[false, true, false]");
+ CheckArrayCase(seconds, CompareOperator::GREATER_EQUAL, "[false, true, true]");
+
+ // Check that comparisons with tz-aware timestamps work fine
+ auto seconds_utc = timestamp(TimeUnit::SECOND, "utc");
+ CheckArrayCase(seconds_utc, CompareOperator::EQUAL, "[false, false, true]");
+}
+
+template <typename ArrowType>
+class TestCompareDecimal : public ::testing::Test {};
+TYPED_TEST_SUITE(TestCompareDecimal, DecimalArrowTypes);
+
+TYPED_TEST(TestCompareDecimal, ArrayScalar) {
+ auto ty = std::make_shared<TypeParam>(3, 2);
+
+ std::vector<std::pair<std::string, std::string>> cases = {
+ std::make_pair("equal", "[1, 0, 0, null]"),
+ std::make_pair("not_equal", "[0, 1, 1, null]"),
+ std::make_pair("less", "[0, 0, 1, null]"),
+ std::make_pair("less_equal", "[1, 0, 1, null]"),
+ std::make_pair("greater", "[0, 1, 0, null]"),
+ std::make_pair("greater_equal", "[1, 1, 0, null]"),
+ };
+
+ auto lhs = ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", null])");
+ auto lhs_float = ArrayFromJSON(float64(), "[1.23, 2.34, -1.23, null]");
+ auto lhs_intlike = ArrayFromJSON(ty, R"(["1.00", "2.00", "-1.00", null])");
+ auto rhs = ScalarFromJSON(ty, R"("1.23")");
+ auto rhs_float = ScalarFromJSON(float64(), "1.23");
+ auto rhs_int = ScalarFromJSON(int64(), "1");
+ for (const auto& op : cases) {
+ const auto& function = op.first;
+ const auto& expected = op.second;
+
+ SCOPED_TRACE(function);
+ CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs_float, rhs, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs, rhs_float, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs_intlike, rhs_int, ArrayFromJSON(boolean(), expected));
+ }
+}
+
+TYPED_TEST(TestCompareDecimal, ScalarArray) {
+ auto ty = std::make_shared<TypeParam>(3, 2);
+
+ std::vector<std::pair<std::string, std::string>> cases = {
+ std::make_pair("equal", "[1, 0, 0, null]"),
+ std::make_pair("not_equal", "[0, 1, 1, null]"),
+ std::make_pair("less", "[0, 1, 0, null]"),
+ std::make_pair("less_equal", "[1, 1, 0, null]"),
+ std::make_pair("greater", "[0, 0, 1, null]"),
+ std::make_pair("greater_equal", "[1, 0, 1, null]"),
+ };
+
+ auto lhs = ScalarFromJSON(ty, R"("1.23")");
+ auto lhs_float = ScalarFromJSON(float64(), "1.23");
+ auto lhs_int = ScalarFromJSON(int64(), "1");
+ auto rhs = ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", null])");
+ auto rhs_float = ArrayFromJSON(float64(), "[1.23, 2.34, -1.23, null]");
+ auto rhs_intlike = ArrayFromJSON(ty, R"(["1.00", "2.00", "-1.00", null])");
+ for (const auto& op : cases) {
+ const auto& function = op.first;
+ const auto& expected = op.second;
+
+ SCOPED_TRACE(function);
+ CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs_float, rhs, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs, rhs_float, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs_int, rhs_intlike, ArrayFromJSON(boolean(), expected));
+ }
+}
+
+TYPED_TEST(TestCompareDecimal, ArrayArray) {
+ auto ty = std::make_shared<TypeParam>(3, 2);
+
+ std::vector<std::pair<std::string, std::string>> cases = {
+ std::make_pair("equal", "[1, 0, 0, 1, 0, 0, null, null]"),
+ std::make_pair("not_equal", "[0, 1, 1, 0, 1, 1, null, null]"),
+ std::make_pair("less", "[0, 1, 0, 0, 1, 0, null, null]"),
+ std::make_pair("less_equal", "[1, 1, 0, 1, 1, 0, null, null]"),
+ std::make_pair("greater", "[0, 0, 1, 0, 0, 1, null, null]"),
+ std::make_pair("greater_equal", "[1, 0, 1, 1, 0, 1, null, null]"),
+ };
+
+ auto lhs = ArrayFromJSON(
+ ty, R"(["1.23", "1.23", "2.34", "-1.23", "-1.23", "1.23", "1.23", null])");
+ auto lhs_float =
+ ArrayFromJSON(float64(), "[1.23, 1.23, 2.34, -1.23, -1.23, 1.23, 1.23, null]");
+ auto lhs_intlike = ArrayFromJSON(
+ ty, R"(["1.00", "1.00", "2.00", "-1.00", "-1.00", "1.00", "1.00", null])");
+ auto rhs = ArrayFromJSON(
+ ty, R"(["1.23", "2.34", "1.23", "-1.23", "1.23", "-1.23", null, "1.23"])");
+ auto rhs_float =
+ ArrayFromJSON(float64(), "[1.23, 2.34, 1.23, -1.23, 1.23, -1.23, null, 1.23]");
+ auto rhs_int = ArrayFromJSON(int64(), "[1, 2, 1, -1, 1, -1, null, 1]");
+ for (const auto& op : cases) {
+ const auto& function = op.first;
+ const auto& expected = op.second;
+
+ SCOPED_TRACE(function);
+ CheckScalarBinary(function, ArrayFromJSON(ty, R"([])"), ArrayFromJSON(ty, R"([])"),
+ ArrayFromJSON(boolean(), "[]"));
+ CheckScalarBinary(function, ArrayFromJSON(ty, R"([null])"),
+ ArrayFromJSON(ty, R"([null])"), ArrayFromJSON(boolean(), "[null]"));
+ CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs_float, rhs, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs, rhs_float, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs_intlike, rhs_int, ArrayFromJSON(boolean(), expected));
+ }
+}
+
+// Helper to organize tests for fixed size binary comparisons
+struct CompareCase {
+ std::shared_ptr<DataType> lhs_type;
+ std::shared_ptr<DataType> rhs_type;
+ std::string lhs;
+ std::string rhs;
+ // An index into cases[...].second
+ int result_index;
+};
+
+TEST(TestCompareFixedSizeBinary, ArrayScalar) {
+ auto ty1 = fixed_size_binary(3);
+ auto ty2 = fixed_size_binary(1);
+
+ std::vector<std::pair<std::string, std::vector<std::string>>> cases = {
+ std::make_pair("equal",
+ std::vector<std::string>{
+ "[0, 1, 0, null]",
+ "[0, 0, 0, null]",
+ "[0, 0, 0, null]",
+ }),
+ std::make_pair("not_equal",
+ std::vector<std::string>{
+ "[1, 0, 1, null]",
+ "[1, 1, 1, null]",
+ "[1, 1, 1, null]",
+ }),
+ std::make_pair("less",
+ std::vector<std::string>{
+ "[1, 0, 0, null]",
+ "[1, 1, 1, null]",
+ "[1, 0, 0, null]",
+ }),
+ std::make_pair("less_equal",
+ std::vector<std::string>{
+ "[1, 1, 0, null]",
+ "[1, 1, 1, null]",
+ "[1, 0, 0, null]",
+ }),
+ std::make_pair("greater",
+ std::vector<std::string>{
+ "[0, 0, 1, null]",
+ "[0, 0, 0, null]",
+ "[0, 1, 1, null]",
+ }),
+ std::make_pair("greater_equal",
+ std::vector<std::string>{
+ "[0, 1, 1, null]",
+ "[0, 0, 0, null]",
+ "[0, 1, 1, null]",
+ }),
+ };
+
+ const std::string lhs1 = R"(["aba", "abc", "abd", null])";
+ const std::string rhs1 = R"("abc")";
+ const std::string lhs2 = R"(["a", "b", "c", null])";
+ const std::string rhs2 = R"("b")";
+
+ std::vector<CompareCase> types = {
+ {ty1, ty1, lhs1, rhs1, 0},
+ {ty2, ty2, lhs2, rhs2, 0},
+ {ty1, ty2, lhs1, rhs2, 1},
+ {ty2, ty1, lhs2, rhs1, 2},
+ {ty1, binary(), lhs1, rhs1, 0},
+ {binary(), ty1, lhs1, rhs1, 0},
+ {ty1, large_binary(), lhs1, rhs1, 0},
+ {large_binary(), ty1, lhs1, rhs1, 0},
+ {ty1, utf8(), lhs1, rhs1, 0},
+ {utf8(), ty1, lhs1, rhs1, 0},
+ {ty1, large_utf8(), lhs1, rhs1, 0},
+ {large_utf8(), ty1, lhs1, rhs1, 0},
+ };
+
+ for (const auto& op : cases) {
+ const auto& function = op.first;
+
+ SCOPED_TRACE(function);
+ for (const auto& test_case : types) {
+ const auto& lhs_type = test_case.lhs_type;
+ const auto& rhs_type = test_case.rhs_type;
+ auto lhs = ArrayFromJSON(lhs_type, test_case.lhs);
+ auto rhs = ScalarFromJSON(rhs_type, test_case.rhs);
+ auto expected = ArrayFromJSON(boolean(), op.second[test_case.result_index]);
+
+ CheckScalarBinary(function, ArrayFromJSON(lhs_type, R"([null])"),
+ ScalarFromJSON(rhs_type, "null"),
+ ArrayFromJSON(boolean(), "[null]"));
+ CheckScalarBinary(function, lhs, rhs, expected);
+ }
+ }
+}
+
+TEST(TestCompareFixedSizeBinary, ScalarArray) {
+ auto ty1 = fixed_size_binary(3);
+ auto ty2 = fixed_size_binary(1);
+
+ std::vector<std::pair<std::string, std::vector<std::string>>> cases = {
+ std::make_pair("equal",
+ std::vector<std::string>{
+ "[0, 1, 0, null]",
+ "[0, 0, 0, null]",
+ "[0, 0, 0, null]",
+ }),
+ std::make_pair("not_equal",
+ std::vector<std::string>{
+ "[1, 0, 1, null]",
+ "[1, 1, 1, null]",
+ "[1, 1, 1, null]",
+ }),
+ std::make_pair("less",
+ std::vector<std::string>{
+ "[0, 0, 1, null]",
+ "[0, 1, 1, null]",
+ "[0, 0, 0, null]",
+ }),
+ std::make_pair("less_equal",
+ std::vector<std::string>{
+ "[0, 1, 1, null]",
+ "[0, 1, 1, null]",
+ "[0, 0, 0, null]",
+ }),
+ std::make_pair("greater",
+ std::vector<std::string>{
+ "[1, 0, 0, null]",
+ "[1, 0, 0, null]",
+ "[1, 1, 1, null]",
+ }),
+ std::make_pair("greater_equal",
+ std::vector<std::string>{
+ "[1, 1, 0, null]",
+ "[1, 0, 0, null]",
+ "[1, 1, 1, null]",
+ }),
+ };
+
+ const std::string lhs1 = R"("abc")";
+ const std::string rhs1 = R"(["aba", "abc", "abd", null])";
+ const std::string lhs2 = R"("b")";
+ const std::string rhs2 = R"(["a", "b", "c", null])";
+
+ std::vector<CompareCase> types = {
+ {ty1, ty1, lhs1, rhs1, 0},
+ {ty2, ty2, lhs2, rhs2, 0},
+ {ty1, ty2, lhs1, rhs2, 1},
+ {ty2, ty1, lhs2, rhs1, 2},
+ {ty1, binary(), lhs1, rhs1, 0},
+ {binary(), ty1, lhs1, rhs1, 0},
+ {ty1, large_binary(), lhs1, rhs1, 0},
+ {large_binary(), ty1, lhs1, rhs1, 0},
+ {ty1, utf8(), lhs1, rhs1, 0},
+ {utf8(), ty1, lhs1, rhs1, 0},
+ {ty1, large_utf8(), lhs1, rhs1, 0},
+ {large_utf8(), ty1, lhs1, rhs1, 0},
+ };
+
+ for (const auto& op : cases) {
+ const auto& function = op.first;
+
+ SCOPED_TRACE(function);
+ for (const auto& test_case : types) {
+ const auto& lhs_type = test_case.lhs_type;
+ const auto& rhs_type = test_case.rhs_type;
+ auto lhs = ScalarFromJSON(lhs_type, test_case.lhs);
+ auto rhs = ArrayFromJSON(rhs_type, test_case.rhs);
+ auto expected = ArrayFromJSON(boolean(), op.second[test_case.result_index]);
+
+ CheckScalarBinary(function, ScalarFromJSON(rhs_type, "null"),
+ ArrayFromJSON(lhs_type, R"([null])"),
+ ArrayFromJSON(boolean(), "[null]"));
+ CheckScalarBinary(function, lhs, rhs, expected);
+ }
+ }
+}
+
+TEST(TestCompareFixedSizeBinary, ArrayArray) {
+ auto ty1 = fixed_size_binary(3);
+ auto ty2 = fixed_size_binary(1);
+
+ std::vector<std::pair<std::string, std::vector<std::string>>> cases = {
+ std::make_pair("equal",
+ std::vector<std::string>{
+ "[1, 0, 0, null, null]",
+ "[1, 0, 0, null, null]",
+ "[1, 0, 0, null, null]",
+ "[1, 0, 0, null, null]",
+ "[0, 0, 0, null, null]",
+ "[0, 0, 0, null, null]",
+ }),
+ std::make_pair("not_equal",
+ std::vector<std::string>{
+ "[0, 1, 1, null, null]",
+ "[0, 1, 1, null, null]",
+ "[0, 1, 1, null, null]",
+ "[0, 1, 1, null, null]",
+ "[1, 1, 1, null, null]",
+ "[1, 1, 1, null, null]",
+ }),
+ std::make_pair("less",
+ std::vector<std::string>{
+ "[0, 1, 0, null, null]",
+ "[0, 0, 1, null, null]",
+ "[0, 1, 0, null, null]",
+ "[0, 0, 1, null, null]",
+ "[0, 1, 1, null, null]",
+ "[1, 1, 0, null, null]",
+ }),
+ std::make_pair("less_equal",
+ std::vector<std::string>{
+ "[1, 1, 0, null, null]",
+ "[1, 0, 1, null, null]",
+ "[1, 1, 0, null, null]",
+ "[1, 0, 1, null, null]",
+ "[0, 1, 1, null, null]",
+ "[1, 1, 0, null, null]",
+ }),
+ std::make_pair("greater",
+ std::vector<std::string>{
+ "[0, 0, 1, null, null]",
+ "[0, 1, 0, null, null]",
+ "[0, 0, 1, null, null]",
+ "[0, 1, 0, null, null]",
+ "[1, 0, 0, null, null]",
+ "[0, 0, 1, null, null]",
+ }),
+ std::make_pair("greater_equal",
+ std::vector<std::string>{
+ "[1, 0, 1, null, null]",
+ "[1, 1, 0, null, null]",
+ "[1, 0, 1, null, null]",
+ "[1, 1, 0, null, null]",
+ "[1, 0, 0, null, null]",
+ "[0, 0, 1, null, null]",
+ }),
+ };
+
+ const std::string lhs1 = R"(["abc", "abc", "abd", null, "abc"])";
+ const std::string rhs1 = R"(["abc", "abd", "abc", "abc", null])";
+ const std::string lhs2 = R"(["a", "a", "d", null, "a"])";
+ const std::string rhs2 = R"(["a", "d", "c", "a", null])";
+
+ std::vector<CompareCase> types = {
+ {ty1, ty1, lhs1, rhs1, 0},
+ {ty1, ty1, rhs1, lhs1, 1},
+ {ty2, ty2, lhs2, rhs2, 2},
+ {ty2, ty2, rhs2, lhs2, 3},
+ {ty1, ty2, lhs1, rhs2, 4},
+ {ty2, ty1, lhs2, rhs1, 5},
+ {ty1, binary(), lhs1, rhs1, 0},
+ {binary(), ty1, lhs1, rhs1, 0},
+ {ty1, large_binary(), lhs1, rhs1, 0},
+ {large_binary(), ty1, lhs1, rhs1, 0},
+ {ty1, utf8(), lhs1, rhs1, 0},
+ {utf8(), ty1, lhs1, rhs1, 0},
+ {ty1, large_utf8(), lhs1, rhs1, 0},
+ {large_utf8(), ty1, lhs1, rhs1, 0},
+ };
+
+ for (const auto& op : cases) {
+ const auto& function = op.first;
+
+ SCOPED_TRACE(function);
+ for (const auto& test_case : types) {
+ const auto& lhs_type = test_case.lhs_type;
+ const auto& rhs_type = test_case.rhs_type;
+ auto lhs = ArrayFromJSON(lhs_type, test_case.lhs);
+ auto rhs = ArrayFromJSON(rhs_type, test_case.rhs);
+ auto expected = ArrayFromJSON(boolean(), op.second[test_case.result_index]);
+
+ CheckScalarBinary(function, ArrayFromJSON(lhs_type, R"([])"),
+ ArrayFromJSON(rhs_type, R"([])"), ArrayFromJSON(boolean(), "[]"));
+ CheckScalarBinary(function, ArrayFromJSON(lhs_type, R"([null])"),
+ ArrayFromJSON(rhs_type, R"([null])"),
+ ArrayFromJSON(boolean(), "[null]"));
+ CheckScalarBinary(function, lhs, rhs, expected);
+ }
+ }
+}
+
+TEST(TestCompareKernel, DispatchBest) {
+ for (std::string name :
+ {"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) {
+ CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()});
+ CheckDispatchBest(name, {int32(), null()}, {int32(), int32()});
+ CheckDispatchBest(name, {null(), int32()}, {int32(), int32()});
+
+ CheckDispatchBest(name, {int32(), int8()}, {int32(), int32()});
+ CheckDispatchBest(name, {int32(), int16()}, {int32(), int32()});
+ CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()});
+ CheckDispatchBest(name, {int32(), int64()}, {int64(), int64()});
+
+ CheckDispatchBest(name, {int32(), uint8()}, {int32(), int32()});
+ CheckDispatchBest(name, {int32(), uint16()}, {int32(), int32()});
+ CheckDispatchBest(name, {int32(), uint32()}, {int64(), int64()});
+ CheckDispatchBest(name, {int32(), uint64()}, {int64(), int64()});
+
+ CheckDispatchBest(name, {uint8(), uint8()}, {uint8(), uint8()});
+ CheckDispatchBest(name, {uint8(), uint16()}, {uint16(), uint16()});
+
+ CheckDispatchBest(name, {int32(), float32()}, {float32(), float32()});
+ CheckDispatchBest(name, {float32(), int64()}, {float32(), float32()});
+ CheckDispatchBest(name, {float64(), int32()}, {float64(), float64()});
+
+ CheckDispatchBest(name, {dictionary(int8(), float64()), float64()},
+ {float64(), float64()});
+ CheckDispatchBest(name, {dictionary(int8(), float64()), int16()},
+ {float64(), float64()});
+
+ CheckDispatchBest(name, {timestamp(TimeUnit::MICRO), date64()},
+ {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)});
+
+ CheckDispatchBest(name, {timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MICRO)},
+ {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)});
+
+ CheckDispatchBest(name, {utf8(), binary()}, {binary(), binary()});
+ CheckDispatchBest(name, {large_utf8(), binary()}, {large_binary(), large_binary()});
+ CheckDispatchBest(name, {large_utf8(), fixed_size_binary(2)},
+ {large_binary(), large_binary()});
+ CheckDispatchBest(name, {binary(), fixed_size_binary(2)}, {binary(), binary()});
+ CheckDispatchBest(name, {fixed_size_binary(4), fixed_size_binary(2)},
+ {fixed_size_binary(4), fixed_size_binary(2)});
+
+ CheckDispatchBest(name, {decimal128(3, 2), decimal128(6, 3)},
+ {decimal128(4, 3), decimal128(6, 3)});
+ CheckDispatchBest(name, {decimal128(3, 2), decimal256(3, 2)},
+ {decimal256(3, 2), decimal256(3, 2)});
+ CheckDispatchBest(name, {decimal128(3, 2), float64()}, {float64(), float64()});
+ CheckDispatchBest(name, {float64(), decimal128(3, 2)}, {float64(), float64()});
+ CheckDispatchBest(name, {decimal128(3, 2), int64()},
+ {decimal128(3, 2), decimal128(21, 2)});
+ CheckDispatchBest(name, {int64(), decimal128(3, 2)},
+ {decimal128(21, 2), decimal128(3, 2)});
+ }
+}
+
+TEST(TestCompareKernel, GreaterWithImplicitCasts) {
+ CheckScalarBinary("greater", ArrayFromJSON(int32(), "[0, 1, 2, null]"),
+ ArrayFromJSON(float64(), "[0.5, 1.0, 1.5, 2.0]"),
+ ArrayFromJSON(boolean(), "[false, false, true, null]"));
+
+ CheckScalarBinary("greater", ArrayFromJSON(int8(), "[-16, 0, 16, null]"),
+ ArrayFromJSON(uint32(), "[3, 4, 5, 7]"),
+ ArrayFromJSON(boolean(), "[false, false, true, null]"));
+
+ CheckScalarBinary("greater", ArrayFromJSON(int8(), "[-16, 0, 16, null]"),
+ ArrayFromJSON(uint8(), "[255, 254, 1, 0]"),
+ ArrayFromJSON(boolean(), "[false, false, true, null]"));
+
+ CheckScalarBinary("greater",
+ ArrayFromJSON(dictionary(int32(), int32()), "[0, 1, 2, null]"),
+ ArrayFromJSON(uint32(), "[3, 4, 5, 7]"),
+ ArrayFromJSON(boolean(), "[false, false, false, null]"));
+
+ CheckScalarBinary("greater", ArrayFromJSON(int32(), "[0, 1, 2, null]"),
+ std::make_shared<NullArray>(4),
+ ArrayFromJSON(boolean(), "[null, null, null, null]"));
+
+ CheckScalarBinary("greater",
+ ArrayFromJSON(timestamp(TimeUnit::SECOND),
+ R"(["1970-01-01","2000-02-29","1900-02-28"])"),
+ ArrayFromJSON(date64(), "[86400000, 0, 86400000]"),
+ ArrayFromJSON(boolean(), "[false, true, false]"));
+
+ CheckScalarBinary("greater",
+ ArrayFromJSON(dictionary(int32(), int8()), "[3, -3, -28, null]"),
+ ArrayFromJSON(uint32(), "[3, 4, 5, 7]"),
+ ArrayFromJSON(boolean(), "[false, false, false, null]"));
+}
+
+TEST(TestCompareKernel, GreaterWithImplicitCastsUint64EdgeCase) {
+ // int64 is as wide as we can promote
+ CheckDispatchBest("greater", {int8(), uint64()}, {int64(), int64()});
+
+ // this works sometimes
+ CheckScalarBinary("greater", ArrayFromJSON(int8(), "[-1]"),
+ ArrayFromJSON(uint64(), "[0]"), ArrayFromJSON(boolean(), "[false]"));
+
+ // ... but it can result in impossible implicit casts in the presence of uint64, since
+ // some uint64 values cannot be cast to int64:
+ ASSERT_RAISES(
+ Invalid,
+ CallFunction("greater", {ArrayFromJSON(int64(), "[-1]"),
+ ArrayFromJSON(uint64(), "[18446744073709551615]")}));
+}
+
+class TestStringCompareKernel : public ::testing::Test {};
+
+TEST_F(TestStringCompareKernel, SimpleCompareArrayScalar) {
+ Datum one(std::make_shared<StringScalar>("one"));
+
+ CompareOptions eq(CompareOperator::EQUAL);
+ ValidateCompare<StringType>(eq, "[]", one, "[]");
+ ValidateCompare<StringType>(eq, "[null]", one, "[null]");
+ ValidateCompare<StringType>(eq, R"(["zero","zero","one","one","two","two"])", one,
+ "[0,0,1,1,0,0]");
+ ValidateCompare<StringType>(eq, R"(["zero","one","two","three","four","five"])", one,
+ "[0,1,0,0,0,0]");
+ ValidateCompare<StringType>(eq, R"(["five","four","three","two","one","zero"])", one,
+ "[0,0,0,0,1,0]");
+ ValidateCompare<StringType>(eq, R"([null,"zero","one","one"])", one, "[null,0,1,1]");
+
+ Datum na(std::make_shared<StringScalar>());
+ ValidateCompare<StringType>(eq, R"([null,"zero","one","one"])", na,
+ "[null,null,null,null]");
+ ValidateCompare<StringType>(eq, na, R"([null,"zero","one","one"])",
+ "[null,null,null,null]");
+
+ CompareOptions neq(CompareOperator::NOT_EQUAL);
+ ValidateCompare<StringType>(neq, "[]", one, "[]");
+ ValidateCompare<StringType>(neq, "[null]", one, "[null]");
+ ValidateCompare<StringType>(neq, R"(["zero","zero","one","one","two","two"])", one,
+ "[1,1,0,0,1,1]");
+ ValidateCompare<StringType>(neq, R"(["zero","one","two","three","four","five"])", one,
+ "[1,0,1,1,1,1]");
+ ValidateCompare<StringType>(neq, R"(["five","four","three","two","one","zero"])", one,
+ "[1,1,1,1,0,1]");
+ ValidateCompare<StringType>(neq, R"([null,"zero","one","one"])", one, "[null,1,0,0]");
+
+ CompareOptions gt(CompareOperator::GREATER);
+ ValidateCompare<StringType>(gt, "[]", one, "[]");
+ ValidateCompare<StringType>(gt, "[null]", one, "[null]");
+ ValidateCompare<StringType>(gt, R"(["zero","zero","one","one","two","two"])", one,
+ "[1,1,0,0,1,1]");
+ ValidateCompare<StringType>(gt, R"(["zero","one","two","three","four","five"])", one,
+ "[1,0,1,1,0,0]");
+ ValidateCompare<StringType>(gt, R"(["four","five","six","seven","eight","nine"])", one,
+ "[0,0,1,1,0,0]");
+ ValidateCompare<StringType>(gt, R"([null,"zero","one","one"])", one, "[null,1,0,0]");
+
+ CompareOptions gte(CompareOperator::GREATER_EQUAL);
+ ValidateCompare<StringType>(gte, "[]", one, "[]");
+ ValidateCompare<StringType>(gte, "[null]", one, "[null]");
+ ValidateCompare<StringType>(gte, R"(["zero","zero","one","one","two","two"])", one,
+ "[1,1,1,1,1,1]");
+ ValidateCompare<StringType>(gte, R"(["zero","one","two","three","four","five"])", one,
+ "[1,1,1,1,0,0]");
+ ValidateCompare<StringType>(gte, R"(["four","five","six","seven","eight","nine"])", one,
+ "[0,0,1,1,0,0]");
+ ValidateCompare<StringType>(gte, R"([null,"zero","one","one"])", one, "[null,1,1,1]");
+
+ CompareOptions lt(CompareOperator::LESS);
+ ValidateCompare<StringType>(lt, "[]", one, "[]");
+ ValidateCompare<StringType>(lt, "[null]", one, "[null]");
+ ValidateCompare<StringType>(lt, R"(["zero","zero","one","one","two","two"])", one,
+ "[0,0,0,0,0,0]");
+ ValidateCompare<StringType>(lt, R"(["zero","one","two","three","four","five"])", one,
+ "[0,0,0,0,1,1]");
+ ValidateCompare<StringType>(lt, R"(["four","five","six","seven","eight","nine"])", one,
+ "[1,1,0,0,1,1]");
+ ValidateCompare<StringType>(lt, R"([null,"zero","one","one"])", one, "[null,0,0,0]");
+
+ CompareOptions lte(CompareOperator::LESS_EQUAL);
+ ValidateCompare<StringType>(lte, "[]", one, "[]");
+ ValidateCompare<StringType>(lte, "[null]", one, "[null]");
+ ValidateCompare<StringType>(lte, R"(["zero","zero","one","one","two","two"])", one,
+ "[0,0,1,1,0,0]");
+ ValidateCompare<StringType>(lte, R"(["zero","one","two","three","four","five"])", one,
+ "[0,1,0,0,1,1]");
+ ValidateCompare<StringType>(lte, R"(["four","five","six","seven","eight","nine"])", one,
+ "[1,1,0,0,1,1]");
+ ValidateCompare<StringType>(lte, R"([null,"zero","one","one"])", one, "[null,0,1,1]");
+}
+
+TEST_F(TestStringCompareKernel, RandomCompareArrayScalar) {
+ using ScalarType = typename TypeTraits<StringType>::ScalarType;
+
+ auto rand = random::RandomArrayGenerator(0x5416447);
+ for (size_t i = 3; i < 10; i++) {
+ for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) {
+ for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
+ const int64_t length = static_cast<int64_t>(1ULL << i);
+ auto array = Datum(rand.String(length, 0, 16, null_probability));
+ auto hello = Datum(std::make_shared<ScalarType>("hello"));
+ auto options = CompareOptions(op);
+ ValidateCompare<StringType>(options, array, hello);
+ ValidateCompare<StringType>(options, hello, array);
+ }
+ }
+ }
+}
+
+TEST_F(TestStringCompareKernel, RandomCompareArrayArray) {
+ auto rand = random::RandomArrayGenerator(0x5416447);
+ for (size_t i = 3; i < 5; i++) {
+ for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) {
+ for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
+ auto length = static_cast<int64_t>(1ULL << i);
+ auto lhs = Datum(rand.String(length << i, 0, 16, null_probability));
+ auto rhs = Datum(rand.String(length << i, 0, 16, null_probability));
+ auto options = CompareOptions(op);
+ ValidateCompare<StringType>(options, lhs, rhs);
+ }
+ }
+ }
+}
+
+template <typename T>
+class TestVarArgsCompare : public TestBase {
+ protected:
+ static std::shared_ptr<DataType> type_singleton() {
+ return TypeTraits<T>::type_singleton();
+ }
+
+ using VarArgsFunction = std::function<Result<Datum>(
+ const std::vector<Datum>&, ElementWiseAggregateOptions, ExecContext*)>;
+
+ void SetUp() override { equal_options_ = equal_options_.nans_equal(true); }
+
+ Datum scalar(const std::string& value) {
+ return ScalarFromJSON(type_singleton(), value);
+ }
+
+ Datum array(const std::string& value) { return ArrayFromJSON(type_singleton(), value); }
+
+ Datum Eval(VarArgsFunction func, const std::vector<Datum>& args) {
+ EXPECT_OK_AND_ASSIGN(auto actual,
+ func(args, element_wise_aggregate_options_, nullptr));
+ ValidateOutput(actual);
+ return actual;
+ }
+
+ void AssertNullScalar(VarArgsFunction func, const std::vector<Datum>& args) {
+ auto datum = this->Eval(func, args);
+ ASSERT_TRUE(datum.is_scalar());
+ ASSERT_FALSE(datum.scalar()->is_valid);
+ }
+
+ void Assert(VarArgsFunction func, Datum expected, const std::vector<Datum>& args) {
+ auto actual = Eval(func, args);
+ AssertDatumsApproxEqual(expected, actual, /*verbose=*/true, equal_options_);
+ }
+
+ EqualOptions equal_options_ = EqualOptions::Defaults();
+ ElementWiseAggregateOptions element_wise_aggregate_options_;
+};
+
+template <typename T>
+class TestVarArgsCompareNumeric : public TestVarArgsCompare<T> {};
+
+template <typename T>
+class TestVarArgsCompareFloating : public TestVarArgsCompare<T> {};
+
+template <typename T>
+class TestVarArgsCompareParametricTemporal : public TestVarArgsCompare<T> {
+ protected:
+ static std::shared_ptr<DataType> type_singleton() {
+ // Time32 requires second/milli, Time64 requires nano/micro
+ if (TypeTraits<T>::bytes_required(1) == 4) {
+ return std::make_shared<T>(TimeUnit::type::SECOND);
+ } else {
+ return std::make_shared<T>(TimeUnit::type::NANO);
+ }
+ }
+
+ Datum scalar(const std::string& value) {
+ return ScalarFromJSON(type_singleton(), value);
+ }
+
+ Datum array(const std::string& value) { return ArrayFromJSON(type_singleton(), value); }
+};
+
+using NumericBasedTypes =
+ ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
+ Int32Type, Int64Type, FloatType, DoubleType, Date32Type, Date64Type>;
+using ParametricTemporalTypes = ::testing::Types<TimestampType, Time32Type, Time64Type>;
+
+TYPED_TEST_SUITE(TestVarArgsCompareNumeric, NumericBasedTypes);
+TYPED_TEST_SUITE(TestVarArgsCompareFloating, RealArrowTypes);
+TYPED_TEST_SUITE(TestVarArgsCompareParametricTemporal, ParametricTemporalTypes);
+
+TYPED_TEST(TestVarArgsCompareNumeric, MinElementWise) {
+ this->AssertNullScalar(MinElementWise, {});
+ this->AssertNullScalar(MinElementWise, {this->scalar("null"), this->scalar("null")});
+
+ this->Assert(MinElementWise, this->scalar("0"), {this->scalar("0")});
+ this->Assert(MinElementWise, this->scalar("0"),
+ {this->scalar("2"), this->scalar("0"), this->scalar("1")});
+ this->Assert(
+ MinElementWise, this->scalar("0"),
+ {this->scalar("2"), this->scalar("0"), this->scalar("1"), this->scalar("null")});
+ this->Assert(MinElementWise, this->scalar("1"),
+ {this->scalar("null"), this->scalar("null"), this->scalar("1"),
+ this->scalar("null")});
+
+ this->Assert(MinElementWise, (this->array("[]")), {this->array("[]")});
+ this->Assert(MinElementWise, this->array("[1, 2, 3, null]"),
+ {this->array("[1, 2, 3, null]")});
+
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
+ {this->array("[1, 2, 3, 4]"), this->scalar("2")});
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
+ {this->array("[1, null, 3, 4]"), this->scalar("2")});
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
+ {this->array("[1, null, 3, 4]"), this->scalar("2"), this->scalar("4")});
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
+ {this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")});
+
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
+ {this->array("[1, 2, 3, 4]"), this->array("[2, 2, 2, 2]")});
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
+ {this->array("[1, 2, 3, 4]"), this->array("[2, null, 2, 2]")});
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
+ {this->array("[1, null, 3, 4]"), this->array("[2, 2, 2, 2]")});
+
+ this->Assert(MinElementWise, this->array("[1, 2, null, 6]"),
+ {this->array("[1, 2, null, null]"), this->array("[4, null, null, 6]")});
+ this->Assert(MinElementWise, this->array("[1, 2, null, 6]"),
+ {this->array("[4, null, null, 6]"), this->array("[1, 2, null, null]")});
+ this->Assert(MinElementWise, this->array("[1, 2, 3, 4]"),
+ {this->array("[1, 2, 3, 4]"), this->array("[null, null, null, null]")});
+ this->Assert(MinElementWise, this->array("[1, 2, 3, 4]"),
+ {this->array("[null, null, null, null]"), this->array("[1, 2, 3, 4]")});
+
+ this->Assert(MinElementWise, this->array("[1, 1, 1, 1]"),
+ {this->scalar("1"), this->array("[1, 2, 3, 4]")});
+ this->Assert(MinElementWise, this->array("[1, 1, 1, 1]"),
+ {this->scalar("1"), this->array("[null, null, null, null]")});
+ this->Assert(MinElementWise, this->array("[1, 1, 1, 1]"),
+ {this->scalar("null"), this->array("[1, 1, 1, 1]")});
+ this->Assert(MinElementWise, this->array("[null, null, null, null]"),
+ {this->scalar("null"), this->array("[null, null, null, null]")});
+
+ // Test null handling
+ this->element_wise_aggregate_options_.skip_nulls = false;
+ this->AssertNullScalar(MinElementWise, {this->scalar("null"), this->scalar("null")});
+ this->AssertNullScalar(MinElementWise, {this->scalar("0"), this->scalar("null")});
+
+ this->Assert(MinElementWise, this->array("[1, null, 2, 2]"),
+ {this->array("[1, null, 3, 4]"), this->scalar("2"), this->scalar("4")});
+ this->Assert(MinElementWise, this->array("[null, null, null, null]"),
+ {this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")});
+ this->Assert(MinElementWise, this->array("[1, null, 2, 2]"),
+ {this->array("[1, 2, 3, 4]"), this->array("[2, null, 2, 2]")});
+
+ this->Assert(MinElementWise, this->array("[null, null, null, null]"),
+ {this->scalar("1"), this->array("[null, null, null, null]")});
+ this->Assert(MinElementWise, this->array("[null, null, null, null]"),
+ {this->scalar("null"), this->array("[1, 1, 1, 1]")});
+}
+
+TYPED_TEST(TestVarArgsCompareFloating, MinElementWise) {
+ auto Check = [this](const std::string& expected,
+ const std::vector<std::string>& inputs) {
+ std::vector<Datum> args;
+ for (const auto& input : inputs) {
+ args.emplace_back(this->scalar(input));
+ }
+ this->Assert(MinElementWise, this->scalar(expected), args);
+
+ args.clear();
+ for (const auto& input : inputs) {
+ args.emplace_back(this->array("[" + input + "]"));
+ }
+ this->Assert(MinElementWise, this->array("[" + expected + "]"), args);
+ };
+ Check("-0.0", {"0.0", "-0.0"});
+ Check("-0.0", {"1.0", "-0.0", "0.0"});
+ Check("-1.0", {"-1.0", "-0.0"});
+ Check("0", {"0", "NaN"});
+ Check("0", {"NaN", "0"});
+ Check("Inf", {"Inf", "NaN"});
+ Check("Inf", {"NaN", "Inf"});
+ Check("-Inf", {"-Inf", "NaN"});
+ Check("-Inf", {"NaN", "-Inf"});
+ Check("NaN", {"NaN", "null"});
+ Check("0", {"0", "Inf"});
+ Check("-Inf", {"0", "-Inf"});
+}
+
+TYPED_TEST(TestVarArgsCompareParametricTemporal, MinElementWise) {
+ // Temporal kernel is implemented with numeric kernel underneath
+ this->AssertNullScalar(MinElementWise, {});
+ this->AssertNullScalar(MinElementWise, {this->scalar("null"), this->scalar("null")});
+
+ this->Assert(MinElementWise, this->scalar("0"), {this->scalar("0")});
+ this->Assert(MinElementWise, this->scalar("0"), {this->scalar("2"), this->scalar("0")});
+ this->Assert(MinElementWise, this->scalar("0"),
+ {this->scalar("0"), this->scalar("null")});
+
+ this->Assert(MinElementWise, (this->array("[]")), {this->array("[]")});
+ this->Assert(MinElementWise, this->array("[1, 2, 3, null]"),
+ {this->array("[1, 2, 3, null]")});
+
+ this->Assert(MinElementWise, this->array("[1, 2, 2, 2]"),
+ {this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")});
+
+ this->Assert(MinElementWise, this->array("[1, 2, 3, 2]"),
+ {this->array("[1, null, 3, 4]"), this->array("[2, 2, null, 2]")});
+}
+
+TYPED_TEST(TestVarArgsCompareNumeric, MaxElementWise) {
+ this->AssertNullScalar(MaxElementWise, {});
+ this->AssertNullScalar(MaxElementWise, {this->scalar("null"), this->scalar("null")});
+
+ this->Assert(MaxElementWise, this->scalar("0"), {this->scalar("0")});
+ this->Assert(MaxElementWise, this->scalar("2"),
+ {this->scalar("2"), this->scalar("0"), this->scalar("1")});
+ this->Assert(
+ MaxElementWise, this->scalar("2"),
+ {this->scalar("2"), this->scalar("0"), this->scalar("1"), this->scalar("null")});
+ this->Assert(MaxElementWise, this->scalar("1"),
+ {this->scalar("null"), this->scalar("null"), this->scalar("1"),
+ this->scalar("null")});
+
+ this->Assert(MaxElementWise, (this->array("[]")), {this->array("[]")});
+ this->Assert(MaxElementWise, this->array("[1, 2, 3, null]"),
+ {this->array("[1, 2, 3, null]")});
+
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
+ {this->array("[1, 2, 3, 4]"), this->scalar("2")});
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
+ {this->array("[1, null, 3, 4]"), this->scalar("2")});
+ this->Assert(MaxElementWise, this->array("[4, 4, 4, 4]"),
+ {this->array("[1, null, 3, 4]"), this->scalar("2"), this->scalar("4")});
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
+ {this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")});
+
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
+ {this->array("[1, 2, 3, 4]"), this->array("[2, 2, 2, 2]")});
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
+ {this->array("[1, 2, 3, 4]"), this->array("[2, null, 2, 2]")});
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
+ {this->array("[1, null, 3, 4]"), this->array("[2, 2, 2, 2]")});
+
+ this->Assert(MaxElementWise, this->array("[4, 2, null, 6]"),
+ {this->array("[1, 2, null, null]"), this->array("[4, null, null, 6]")});
+ this->Assert(MaxElementWise, this->array("[4, 2, null, 6]"),
+ {this->array("[4, null, null, 6]"), this->array("[1, 2, null, null]")});
+ this->Assert(MaxElementWise, this->array("[1, 2, 3, 4]"),
+ {this->array("[1, 2, 3, 4]"), this->array("[null, null, null, null]")});
+ this->Assert(MaxElementWise, this->array("[1, 2, 3, 4]"),
+ {this->array("[null, null, null, null]"), this->array("[1, 2, 3, 4]")});
+
+ this->Assert(MaxElementWise, this->array("[1, 2, 3, 4]"),
+ {this->scalar("1"), this->array("[1, 2, 3, 4]")});
+ this->Assert(MaxElementWise, this->array("[1, 1, 1, 1]"),
+ {this->scalar("1"), this->array("[null, null, null, null]")});
+ this->Assert(MaxElementWise, this->array("[1, 1, 1, 1]"),
+ {this->scalar("null"), this->array("[1, 1, 1, 1]")});
+ this->Assert(MaxElementWise, this->array("[null, null, null, null]"),
+ {this->scalar("null"), this->array("[null, null, null, null]")});
+
+ // Test null handling
+ this->element_wise_aggregate_options_.skip_nulls = false;
+ this->AssertNullScalar(MaxElementWise, {this->scalar("null"), this->scalar("null")});
+ this->AssertNullScalar(MaxElementWise, {this->scalar("0"), this->scalar("null")});
+
+ this->Assert(MaxElementWise, this->array("[4, null, 4, 4]"),
+ {this->array("[1, null, 3, 4]"), this->scalar("2"), this->scalar("4")});
+ this->Assert(MaxElementWise, this->array("[null, null, null, null]"),
+ {this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")});
+ this->Assert(MaxElementWise, this->array("[2, null, 3, 4]"),
+ {this->array("[1, 2, 3, 4]"), this->array("[2, null, 2, 2]")});
+
+ this->Assert(MaxElementWise, this->array("[null, null, null, null]"),
+ {this->scalar("1"), this->array("[null, null, null, null]")});
+ this->Assert(MaxElementWise, this->array("[null, null, null, null]"),
+ {this->scalar("null"), this->array("[1, 1, 1, 1]")});
+}
+
+TYPED_TEST(TestVarArgsCompareFloating, MaxElementWise) {
+ auto Check = [this](const std::string& expected,
+ const std::vector<std::string>& inputs) {
+ std::vector<Datum> args;
+ for (const auto& input : inputs) {
+ args.emplace_back(this->scalar(input));
+ }
+ this->Assert(MaxElementWise, this->scalar(expected), args);
+
+ args.clear();
+ for (const auto& input : inputs) {
+ args.emplace_back(this->array("[" + input + "]"));
+ }
+ this->Assert(MaxElementWise, this->array("[" + expected + "]"), args);
+ };
+ Check("0.0", {"0.0", "-0.0"});
+ Check("1.0", {"1.0", "-0.0", "0.0"});
+ Check("-0.0", {"-1.0", "-0.0"});
+ Check("0", {"0", "NaN"});
+ Check("0", {"NaN", "0"});
+ Check("Inf", {"Inf", "NaN"});
+ Check("Inf", {"NaN", "Inf"});
+ Check("-Inf", {"-Inf", "NaN"});
+ Check("-Inf", {"NaN", "-Inf"});
+ Check("NaN", {"NaN", "null"});
+ Check("Inf", {"0", "Inf"});
+ Check("0", {"0", "-Inf"});
+}
+
+TYPED_TEST(TestVarArgsCompareParametricTemporal, MaxElementWise) {
+ // Temporal kernel is implemented with numeric kernel underneath
+ this->AssertNullScalar(MaxElementWise, {});
+ this->AssertNullScalar(MaxElementWise, {this->scalar("null"), this->scalar("null")});
+
+ this->Assert(MaxElementWise, this->scalar("0"), {this->scalar("0")});
+ this->Assert(MaxElementWise, this->scalar("2"), {this->scalar("2"), this->scalar("0")});
+ this->Assert(MaxElementWise, this->scalar("0"),
+ {this->scalar("0"), this->scalar("null")});
+
+ this->Assert(MaxElementWise, (this->array("[]")), {this->array("[]")});
+ this->Assert(MaxElementWise, this->array("[1, 2, 3, null]"),
+ {this->array("[1, 2, 3, null]")});
+
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
+ {this->array("[1, null, 3, 4]"), this->scalar("null"), this->scalar("2")});
+
+ this->Assert(MaxElementWise, this->array("[2, 2, 3, 4]"),
+ {this->array("[1, null, 3, 4]"), this->array("[2, 2, null, 2]")});
+}
+
+TEST(TestMaxElementWiseMinElementWise, CommonTemporal) {
+ EXPECT_THAT(MinElementWise({
+ ScalarFromJSON(timestamp(TimeUnit::SECOND), "1"),
+ ScalarFromJSON(timestamp(TimeUnit::MILLI), "12000"),
+ }),
+ ResultWith(ScalarFromJSON(timestamp(TimeUnit::MILLI), "1000")));
+ EXPECT_THAT(MaxElementWise({
+ ScalarFromJSON(date32(), "1"),
+ ScalarFromJSON(timestamp(TimeUnit::SECOND), "86401"),
+ }),
+ ResultWith(ScalarFromJSON(timestamp(TimeUnit::SECOND), "86401")));
+ EXPECT_THAT(MinElementWise({
+ ScalarFromJSON(date32(), "1"),
+ ScalarFromJSON(date64(), "172800000"),
+ }),
+ ResultWith(ScalarFromJSON(date64(), "86400000")));
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_if_else.cc
new file mode 100644
index 000000000..b3ebba8ea
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -0,0 +1,2912 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/builder_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/builder_time.h"
+#include "arrow/array/builder_union.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bitmap.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/bitmap_reader.h"
+
+namespace arrow {
+
+using internal::BitBlockCount;
+using internal::BitBlockCounter;
+using internal::Bitmap;
+using internal::BitmapWordReader;
+using internal::BitRunReader;
+
+namespace compute {
+namespace internal {
+
+namespace {
+
+constexpr uint64_t kAllNull = 0;
+constexpr uint64_t kAllValid = ~kAllNull;
+
+util::optional<uint64_t> GetConstantValidityWord(const Datum& data) {
+ if (data.is_scalar()) {
+ return data.scalar()->is_valid ? kAllValid : kAllNull;
+ }
+
+ if (data.array()->null_count == data.array()->length) return kAllNull;
+
+ if (!data.array()->MayHaveNulls()) return kAllValid;
+
+ // no constant validity word available
+ return {};
+}
+
+inline Bitmap GetBitmap(const Datum& datum, int i) {
+ if (datum.is_scalar()) return {};
+ const ArrayData& a = *datum.array();
+ return Bitmap{a.buffers[i], a.offset, a.length};
+}
+
+// Ensure parameterized types are identical.
+Status CheckIdenticalTypes(const Datum* begin, size_t count) {
+ const auto& ty = begin->type();
+ const auto* end = begin + count;
+ for (auto it = begin + 1; it != end; ++it) {
+ const DataType& other_ty = *it->type();
+ if (!ty->Equals(other_ty)) {
+ return Status::TypeError("All types must be compatible, expected: ", *ty,
+ ", but got: ", other_ty);
+ }
+ }
+ return Status::OK();
+}
+
+// if the condition is null then output is null otherwise we take validity from the
+// selected argument
+// ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid)
+template <typename AllocateNullBitmap>
+Status PromoteNullsVisitor(KernelContext* ctx, const Datum& cond_d, const Datum& left_d,
+ const Datum& right_d, ArrayData* output) {
+ auto cond_const = GetConstantValidityWord(cond_d);
+ auto left_const = GetConstantValidityWord(left_d);
+ auto right_const = GetConstantValidityWord(right_d);
+
+ enum { COND_CONST = 1, LEFT_CONST = 2, RIGHT_CONST = 4 };
+ auto flag = COND_CONST * cond_const.has_value() | LEFT_CONST * left_const.has_value() |
+ RIGHT_CONST * right_const.has_value();
+
+ const ArrayData& cond = *cond_d.array();
+ // cond.data will always be available
+ Bitmap cond_data{cond.buffers[1], cond.offset, cond.length};
+ Bitmap cond_valid{cond.buffers[0], cond.offset, cond.length};
+ Bitmap left_valid = GetBitmap(left_d, 0);
+ Bitmap right_valid = GetBitmap(right_d, 0);
+
+ // cond.valid & (cond.data & left.valid | ~cond.data & right.valid)
+ // In the following cases, we dont need to allocate out_valid bitmap
+
+ // if cond & left & right all ones, then output is all valid.
+ // if output validity buffer is already allocated (NullHandling::
+ // COMPUTED_PREALLOCATE) -> set all bits
+ // else, return nullptr
+ if (cond_const == kAllValid && left_const == kAllValid && right_const == kAllValid) {
+ if (AllocateNullBitmap::value) { // NullHandling::COMPUTED_NO_PREALLOCATE
+ output->buffers[0] = nullptr;
+ } else { // NullHandling::COMPUTED_PREALLOCATE
+ BitUtil::SetBitmap(output->buffers[0]->mutable_data(), output->offset,
+ output->length);
+ }
+ return Status::OK();
+ }
+
+ if (left_const == kAllValid && right_const == kAllValid) {
+ // if both left and right are valid, no need to calculate out_valid bitmap. Copy
+ // cond validity buffer
+ if (AllocateNullBitmap::value) { // NullHandling::COMPUTED_NO_PREALLOCATE
+ // if there's an offset, copy bitmap (cannot slice a bitmap)
+ if (cond.offset) {
+ ARROW_ASSIGN_OR_RAISE(
+ output->buffers[0],
+ arrow::internal::CopyBitmap(ctx->memory_pool(), cond.buffers[0]->data(),
+ cond.offset, cond.length));
+ } else { // just copy assign cond validity buffer
+ output->buffers[0] = cond.buffers[0];
+ }
+ } else { // NullHandling::COMPUTED_PREALLOCATE
+ arrow::internal::CopyBitmap(cond.buffers[0]->data(), cond.offset, cond.length,
+ output->buffers[0]->mutable_data(), output->offset);
+ }
+ return Status::OK();
+ }
+
+ // lambda function that will be used inside the visitor
+ auto apply = [&](uint64_t c_valid, uint64_t c_data, uint64_t l_valid,
+ uint64_t r_valid) {
+ return c_valid & ((c_data & l_valid) | (~c_data & r_valid));
+ };
+
+ if (AllocateNullBitmap::value) {
+ // following cases requires a separate out_valid buffer. COMPUTED_NO_PREALLOCATE
+ // would not have allocated buffers for it.
+ ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(cond.length));
+ }
+
+ std::array<Bitmap, 1> out_bitmaps{
+ Bitmap{output->buffers[0], output->offset, output->length}};
+
+ switch (flag) {
+ case COND_CONST | LEFT_CONST | RIGHT_CONST: {
+ std::array<Bitmap, 1> bitmaps{cond_data};
+ Bitmap::VisitWordsAndWrite(bitmaps, &out_bitmaps,
+ [&](const std::array<uint64_t, 1>& words_in,
+ std::array<uint64_t, 1>* word_out) {
+ word_out->at(0) = apply(*cond_const, words_in[0],
+ *left_const, *right_const);
+ });
+ break;
+ }
+ case LEFT_CONST | RIGHT_CONST: {
+ std::array<Bitmap, 2> bitmaps{cond_valid, cond_data};
+ Bitmap::VisitWordsAndWrite(bitmaps, &out_bitmaps,
+ [&](const std::array<uint64_t, 2>& words_in,
+ std::array<uint64_t, 1>* word_out) {
+ word_out->at(0) = apply(words_in[0], words_in[1],
+ *left_const, *right_const);
+ });
+ break;
+ }
+ case COND_CONST | RIGHT_CONST: {
+ // bitmaps[C_VALID], bitmaps[R_VALID] might be null; override to make it safe for
+ // Visit()
+ std::array<Bitmap, 2> bitmaps{cond_data, left_valid};
+ Bitmap::VisitWordsAndWrite(bitmaps, &out_bitmaps,
+ [&](const std::array<uint64_t, 2>& words_in,
+ std::array<uint64_t, 1>* word_out) {
+ word_out->at(0) = apply(*cond_const, words_in[0],
+ words_in[1], *right_const);
+ });
+ break;
+ }
+ case RIGHT_CONST: {
+ // bitmaps[R_VALID] might be null; override to make it safe for Visit()
+ std::array<Bitmap, 3> bitmaps{cond_valid, cond_data, left_valid};
+ Bitmap::VisitWordsAndWrite(bitmaps, &out_bitmaps,
+ [&](const std::array<uint64_t, 3>& words_in,
+ std::array<uint64_t, 1>* word_out) {
+ word_out->at(0) = apply(words_in[0], words_in[1],
+ words_in[2], *right_const);
+ });
+ break;
+ }
+ case COND_CONST | LEFT_CONST: {
+ // bitmaps[C_VALID], bitmaps[L_VALID] might be null; override to make it safe for
+ // Visit()
+ std::array<Bitmap, 2> bitmaps{cond_data, right_valid};
+ Bitmap::VisitWordsAndWrite(bitmaps, &out_bitmaps,
+ [&](const std::array<uint64_t, 2>& words_in,
+ std::array<uint64_t, 1>* word_out) {
+ word_out->at(0) = apply(*cond_const, words_in[0],
+ *left_const, words_in[1]);
+ });
+ break;
+ }
+ case LEFT_CONST: {
+ // bitmaps[L_VALID] might be null; override to make it safe for Visit()
+ std::array<Bitmap, 3> bitmaps{cond_valid, cond_data, right_valid};
+ Bitmap::VisitWordsAndWrite(bitmaps, &out_bitmaps,
+ [&](const std::array<uint64_t, 3>& words_in,
+ std::array<uint64_t, 1>* word_out) {
+ word_out->at(0) = apply(words_in[0], words_in[1],
+ *left_const, words_in[2]);
+ });
+ break;
+ }
+ case COND_CONST: {
+ // bitmaps[C_VALID] might be null; override to make it safe for Visit()
+ std::array<Bitmap, 3> bitmaps{cond_data, left_valid, right_valid};
+ Bitmap::VisitWordsAndWrite(bitmaps, &out_bitmaps,
+ [&](const std::array<uint64_t, 3>& words_in,
+ std::array<uint64_t, 1>* word_out) {
+ word_out->at(0) = apply(*cond_const, words_in[0],
+ words_in[1], words_in[2]);
+ });
+ break;
+ }
+ case 0: {
+ std::array<Bitmap, 4> bitmaps{cond_valid, cond_data, left_valid, right_valid};
+ Bitmap::VisitWordsAndWrite(bitmaps, &out_bitmaps,
+ [&](const std::array<uint64_t, 4>& words_in,
+ std::array<uint64_t, 1>* word_out) {
+ word_out->at(0) = apply(words_in[0], words_in[1],
+ words_in[2], words_in[3]);
+ });
+ break;
+ }
+ }
+ return Status::OK();
+}
+
+using Word = uint64_t;
+static constexpr int64_t word_len = sizeof(Word) * 8;
+
+/// Runs the main if_else loop. Here, it is expected that the right data has already
+/// been copied to the output.
+/// If `invert` is meant to invert the cond.data. If is set to `true`, then the
+/// buffer will be inverted before calling the handle_block or handle_each functions.
+/// This is useful, when left is an array and right is scalar. Then rather than
+/// copying data from the right to output, we can copy left data to the output and
+/// invert the cond data to fill right values. Filling out with a scalar is presumed to
+/// be more efficient than filling with an array
+///
+/// `HandleBlock` has the signature:
+/// [](int64_t offset, int64_t length){...}
+/// It should copy `length` number of elements from source array to output array with
+/// `offset` offset in both arrays
+template <typename HandleBlock, bool invert = false>
+void RunIfElseLoop(const ArrayData& cond, const HandleBlock& handle_block) {
+ int64_t data_offset = 0;
+ int64_t bit_offset = cond.offset;
+ const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray
+
+ BitmapWordReader<Word> cond_reader(cond_data, cond.offset, cond.length);
+
+ constexpr Word pickAll = invert ? 0 : UINT64_MAX;
+ constexpr Word pickNone = ~pickAll;
+
+ int64_t cnt = cond_reader.words();
+ while (cnt--) {
+ Word word = cond_reader.NextWord();
+
+ if (word == pickAll) {
+ handle_block(data_offset, word_len);
+ } else if (word != pickNone) {
+ for (int64_t i = 0; i < word_len; ++i) {
+ if (BitUtil::GetBit(cond_data, bit_offset + i) != invert) {
+ handle_block(data_offset + i, 1);
+ }
+ }
+ }
+ data_offset += word_len;
+ bit_offset += word_len;
+ }
+
+ constexpr uint8_t pickAllByte = invert ? 0 : UINT8_MAX;
+ // byte bit-wise inversion is int-wide. Hence XOR with 0xff
+ constexpr uint8_t pickNoneByte = pickAllByte ^ 0xff;
+
+ cnt = cond_reader.trailing_bytes();
+ while (cnt--) {
+ int valid_bits;
+ uint8_t byte = cond_reader.NextTrailingByte(valid_bits);
+
+ if (byte == pickAllByte && valid_bits == 8) {
+ handle_block(data_offset, 8);
+ } else if (byte != pickNoneByte) {
+ for (int i = 0; i < valid_bits; ++i) {
+ if (BitUtil::GetBit(cond_data, bit_offset + i) != invert) {
+ handle_block(data_offset + i, 1);
+ }
+ }
+ }
+ data_offset += 8;
+ bit_offset += 8;
+ }
+}
+
+template <typename HandleBlock>
+void RunIfElseLoopInverted(const ArrayData& cond, const HandleBlock& handle_block) {
+ RunIfElseLoop<HandleBlock, true>(cond, handle_block);
+}
+
+/// Runs if-else when cond is a scalar. Two special functions are required,
+/// 1.CopyArrayData, 2. BroadcastScalar
+template <typename CopyArrayData, typename BroadcastScalar>
+Status RunIfElseScalar(const BooleanScalar& cond, const Datum& left, const Datum& right,
+ Datum* out, const CopyArrayData& copy_array_data,
+ const BroadcastScalar& broadcast_scalar) {
+ if (left.is_scalar() && right.is_scalar()) { // output will be a scalar
+ if (cond.is_valid) {
+ *out = cond.value ? left.scalar() : right.scalar();
+ } else {
+ *out = MakeNullScalar(left.type());
+ }
+ return Status::OK();
+ }
+
+ // either left or right is an array. Output is always an array`
+ const std::shared_ptr<ArrayData>& out_array = out->array();
+ if (!cond.is_valid) {
+ // cond is null; output is all null --> clear validity buffer
+ BitUtil::ClearBitmap(out_array->buffers[0]->mutable_data(), out_array->offset,
+ out_array->length);
+ return Status::OK();
+ }
+
+ // cond is a non-null scalar
+ const auto& valid_data = cond.value ? left : right;
+ if (valid_data.is_array()) {
+ // valid_data is an array. Hence copy data to the output buffers
+ const auto& valid_array = valid_data.array();
+ if (valid_array->MayHaveNulls()) {
+ arrow::internal::CopyBitmap(
+ valid_array->buffers[0]->data(), valid_array->offset, valid_array->length,
+ out_array->buffers[0]->mutable_data(), out_array->offset);
+ } else { // validity buffer is nullptr --> set all bits
+ BitUtil::SetBitmap(out_array->buffers[0]->mutable_data(), out_array->offset,
+ out_array->length);
+ }
+ copy_array_data(*valid_array, out_array.get());
+ return Status::OK();
+
+ } else { // valid data is scalar
+ // valid data is a scalar that needs to be broadcasted
+ const auto& valid_scalar = *valid_data.scalar();
+ if (valid_scalar.is_valid) { // if the scalar is non-null, broadcast
+ BitUtil::SetBitmap(out_array->buffers[0]->mutable_data(), out_array->offset,
+ out_array->length);
+ broadcast_scalar(*valid_data.scalar(), out_array.get());
+ } else { // scalar is null, clear the output validity buffer
+ BitUtil::ClearBitmap(out_array->buffers[0]->mutable_data(), out_array->offset,
+ out_array->length);
+ }
+ return Status::OK();
+ }
+}
+
+template <typename Type, typename Enable = void>
+struct IfElseFunctor {};
+
+// only number types needs to be handled for Fixed sized primitive data types because,
+// internal::GenerateTypeAgnosticPrimitive forwards types to the corresponding unsigned
+// int type
+template <typename Type>
+struct IfElseFunctor<Type, enable_if_number<Type>> {
+ using T = typename TypeTraits<Type>::CType;
+ // A - Array, S - Scalar, X = Array/Scalar
+
+ // SXX
+ static Status Call(KernelContext* ctx, const BooleanScalar& cond, const Datum& left,
+ const Datum& right, Datum* out) {
+ return RunIfElseScalar(
+ cond, left, right, out,
+ /*CopyArrayData*/
+ [&](const ArrayData& valid_array, ArrayData* out_array) {
+ std::memcpy(out_array->GetMutableValues<T>(1), valid_array.GetValues<T>(1),
+ valid_array.length * sizeof(T));
+ },
+ /*BroadcastScalar*/
+ [&](const Scalar& scalar, ArrayData* out_array) {
+ T scalar_data = internal::UnboxScalar<Type>::Unbox(scalar);
+ std::fill(out_array->GetMutableValues<T>(1),
+ out_array->GetMutableValues<T>(1) + out_array->length, scalar_data);
+ });
+ }
+
+ // AAA
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+ const ArrayData& right, ArrayData* out) {
+ T* out_values = out->template GetMutableValues<T>(1);
+
+ // copy right data to out_buff
+ const T* right_data = right.GetValues<T>(1);
+ std::memcpy(out_values, right_data, right.length * sizeof(T));
+
+ // selectively copy values from left data
+ const T* left_data = left.GetValues<T>(1);
+
+ RunIfElseLoop(cond, [&](int64_t data_offset, int64_t num_elems) {
+ std::memcpy(out_values + data_offset, left_data + data_offset,
+ num_elems * sizeof(T));
+ });
+
+ return Status::OK();
+ }
+
+ // ASA
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+ const ArrayData& right, ArrayData* out) {
+ T* out_values = out->template GetMutableValues<T>(1);
+
+ // copy right data to out_buff
+ const T* right_data = right.GetValues<T>(1);
+ std::memcpy(out_values, right_data, right.length * sizeof(T));
+
+ // selectively copy values from left data
+ T left_data = internal::UnboxScalar<Type>::Unbox(left);
+
+ RunIfElseLoop(cond, [&](int64_t data_offset, int64_t num_elems) {
+ std::fill(out_values + data_offset, out_values + data_offset + num_elems,
+ left_data);
+ });
+
+ return Status::OK();
+ }
+
+ // AAS
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+ const Scalar& right, ArrayData* out) {
+ T* out_values = out->template GetMutableValues<T>(1);
+
+ // copy left data to out_buff
+ const T* left_data = left.GetValues<T>(1);
+ std::memcpy(out_values, left_data, left.length * sizeof(T));
+
+ T right_data = internal::UnboxScalar<Type>::Unbox(right);
+
+ RunIfElseLoopInverted(cond, [&](int64_t data_offset, int64_t num_elems) {
+ std::fill(out_values + data_offset, out_values + data_offset + num_elems,
+ right_data);
+ });
+
+ return Status::OK();
+ }
+
+ // ASS
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+ const Scalar& right, ArrayData* out) {
+ T* out_values = out->template GetMutableValues<T>(1);
+
+ // copy right data to out_buff
+ T right_data = internal::UnboxScalar<Type>::Unbox(right);
+ std::fill(out_values, out_values + cond.length, right_data);
+
+ // selectively copy values from left data
+ T left_data = internal::UnboxScalar<Type>::Unbox(left);
+ RunIfElseLoop(cond, [&](int64_t data_offset, int64_t num_elems) {
+ std::fill(out_values + data_offset, out_values + data_offset + num_elems,
+ left_data);
+ });
+
+ return Status::OK();
+ }
+};
+
+template <typename Type>
+struct IfElseFunctor<Type, enable_if_boolean<Type>> {
+ // A - Array, S - Scalar, X = Array/Scalar
+
+ // SXX
+ static Status Call(KernelContext* ctx, const BooleanScalar& cond, const Datum& left,
+ const Datum& right, Datum* out) {
+ return RunIfElseScalar(
+ cond, left, right, out,
+ /*CopyArrayData*/
+ [&](const ArrayData& valid_array, ArrayData* out_array) {
+ arrow::internal::CopyBitmap(
+ valid_array.buffers[1]->data(), valid_array.offset, valid_array.length,
+ out_array->buffers[1]->mutable_data(), out_array->offset);
+ },
+ /*BroadcastScalar*/
+ [&](const Scalar& scalar, ArrayData* out_array) {
+ bool scalar_data = internal::UnboxScalar<Type>::Unbox(scalar);
+ BitUtil::SetBitsTo(out_array->buffers[1]->mutable_data(), out_array->offset,
+ out_array->length, scalar_data);
+ });
+ }
+
+ // AAA
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+ const ArrayData& right, ArrayData* out) {
+ // out_buff = right & ~cond
+ const auto& out_buf = out->buffers[1];
+ arrow::internal::BitmapAndNot(right.buffers[1]->data(), right.offset,
+ cond.buffers[1]->data(), cond.offset, cond.length,
+ out->offset, out_buf->mutable_data());
+
+ // out_buff = left & cond
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> temp_buf,
+ arrow::internal::BitmapAnd(
+ ctx->memory_pool(), left.buffers[1]->data(), left.offset,
+ cond.buffers[1]->data(), cond.offset, cond.length, 0));
+
+ arrow::internal::BitmapOr(out_buf->data(), out->offset, temp_buf->data(), 0,
+ cond.length, out->offset, out_buf->mutable_data());
+
+ return Status::OK();
+ }
+
+ // ASA
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+ const ArrayData& right, ArrayData* out) {
+ // out_buff = right & ~cond
+ const auto& out_buf = out->buffers[1];
+ arrow::internal::BitmapAndNot(right.buffers[1]->data(), right.offset,
+ cond.buffers[1]->data(), cond.offset, cond.length,
+ out->offset, out_buf->mutable_data());
+
+ // out_buff = left & cond
+ bool left_data = internal::UnboxScalar<BooleanType>::Unbox(left);
+ if (left_data) {
+ arrow::internal::BitmapOr(out_buf->data(), out->offset, cond.buffers[1]->data(),
+ cond.offset, cond.length, out->offset,
+ out_buf->mutable_data());
+ }
+
+ return Status::OK();
+ }
+
+ // AAS
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+ const Scalar& right, ArrayData* out) {
+ // out_buff = left & cond
+ const auto& out_buf = out->buffers[1];
+ arrow::internal::BitmapAnd(left.buffers[1]->data(), left.offset,
+ cond.buffers[1]->data(), cond.offset, cond.length,
+ out->offset, out_buf->mutable_data());
+
+ bool right_data = internal::UnboxScalar<BooleanType>::Unbox(right);
+
+ // out_buff = left & cond | right & ~cond
+ if (right_data) {
+ arrow::internal::BitmapOrNot(out_buf->data(), out->offset, cond.buffers[1]->data(),
+ cond.offset, cond.length, out->offset,
+ out_buf->mutable_data());
+ }
+
+ return Status::OK();
+ }
+
+ // ASS
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+ const Scalar& right, ArrayData* out) {
+ bool left_data = internal::UnboxScalar<BooleanType>::Unbox(left);
+ bool right_data = internal::UnboxScalar<BooleanType>::Unbox(right);
+
+ const auto& out_buf = out->buffers[1];
+
+ // out_buf = left & cond | right & ~cond
+ // std::shared_ptr<Buffer> out_buf = nullptr;
+ if (left_data) {
+ if (right_data) {
+ // out_buf = ones
+ BitUtil::SetBitmap(out_buf->mutable_data(), out->offset, cond.length);
+ } else {
+ // out_buf = cond
+ arrow::internal::CopyBitmap(cond.buffers[1]->data(), cond.offset, cond.length,
+ out_buf->mutable_data(), out->offset);
+ }
+ } else {
+ if (right_data) {
+ // out_buf = ~cond
+ arrow::internal::InvertBitmap(cond.buffers[1]->data(), cond.offset, cond.length,
+ out_buf->mutable_data(), out->offset);
+ } else {
+ // out_buf = zeros
+ BitUtil::ClearBitmap(out_buf->mutable_data(), out->offset, cond.length);
+ }
+ }
+
+ return Status::OK();
+ }
+};
+
+template <typename Type>
+struct IfElseFunctor<Type, enable_if_base_binary<Type>> {
+ using OffsetType = typename TypeTraits<Type>::OffsetType::c_type;
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ using BuilderType = typename TypeTraits<Type>::BuilderType;
+
+ // A - Array, S - Scalar, X = Array/Scalar
+
+ // SXX
+ static Status Call(KernelContext* ctx, const BooleanScalar& cond, const Datum& left,
+ const Datum& right, Datum* out) {
+ if (left.is_scalar() && right.is_scalar()) {
+ if (cond.is_valid) {
+ *out = cond.value ? left.scalar() : right.scalar();
+ } else {
+ *out = MakeNullScalar(left.type());
+ }
+ return Status::OK();
+ }
+ // either left or right is an array. Output is always an array
+ int64_t out_arr_len = std::max(left.length(), right.length());
+ if (!cond.is_valid) {
+ // cond is null; just create a null array
+ ARROW_ASSIGN_OR_RAISE(*out,
+ MakeArrayOfNull(left.type(), out_arr_len, ctx->memory_pool()))
+ return Status::OK();
+ }
+
+ const auto& valid_data = cond.value ? left : right;
+ if (valid_data.is_array()) {
+ *out = valid_data;
+ } else {
+ // valid data is a scalar that needs to be broadcasted
+ ARROW_ASSIGN_OR_RAISE(*out, MakeArrayFromScalar(*valid_data.scalar(), out_arr_len,
+ ctx->memory_pool()));
+ }
+ return Status::OK();
+ }
+
+ // AAA
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+ const ArrayData& right, ArrayData* out) {
+ const auto* left_offsets = left.GetValues<OffsetType>(1);
+ const uint8_t* left_data = left.buffers[2]->data();
+ const auto* right_offsets = right.GetValues<OffsetType>(1);
+ const uint8_t* right_data = right.buffers[2]->data();
+
+ // allocate data buffer conservatively
+ int64_t data_buff_alloc = left_offsets[left.length] - left_offsets[0] +
+ right_offsets[right.length] - right_offsets[0];
+
+ BuilderType builder(ctx->memory_pool());
+ ARROW_RETURN_NOT_OK(builder.Reserve(cond.length + 1));
+ ARROW_RETURN_NOT_OK(builder.ReserveData(data_buff_alloc));
+
+ RunLoop(
+ cond, *out,
+ [&](int64_t i) {
+ builder.UnsafeAppend(left_data + left_offsets[i],
+ left_offsets[i + 1] - left_offsets[i]);
+ },
+ [&](int64_t i) {
+ builder.UnsafeAppend(right_data + right_offsets[i],
+ right_offsets[i + 1] - right_offsets[i]);
+ },
+ [&]() { builder.UnsafeAppendNull(); });
+ ARROW_ASSIGN_OR_RAISE(auto out_arr, builder.Finish());
+
+ out->SetNullCount(out_arr->data()->null_count);
+ out->buffers[0] = std::move(out_arr->data()->buffers[0]);
+ out->buffers[1] = std::move(out_arr->data()->buffers[1]);
+ out->buffers[2] = std::move(out_arr->data()->buffers[2]);
+ return Status::OK();
+ }
+
+ // ASA
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+ const ArrayData& right, ArrayData* out) {
+ util::string_view left_data = internal::UnboxScalar<Type>::Unbox(left);
+ auto left_size = static_cast<OffsetType>(left_data.size());
+
+ const auto* right_offsets = right.GetValues<OffsetType>(1);
+ const uint8_t* right_data = right.buffers[2]->data();
+
+ // allocate data buffer conservatively
+ int64_t data_buff_alloc =
+ left_size * cond.length + right_offsets[right.length] - right_offsets[0];
+
+ BuilderType builder(ctx->memory_pool());
+ ARROW_RETURN_NOT_OK(builder.Reserve(cond.length + 1));
+ ARROW_RETURN_NOT_OK(builder.ReserveData(data_buff_alloc));
+
+ RunLoop(
+ cond, *out, [&](int64_t i) { builder.UnsafeAppend(left_data.data(), left_size); },
+ [&](int64_t i) {
+ builder.UnsafeAppend(right_data + right_offsets[i],
+ right_offsets[i + 1] - right_offsets[i]);
+ },
+ [&]() { builder.UnsafeAppendNull(); });
+ ARROW_ASSIGN_OR_RAISE(auto out_arr, builder.Finish());
+
+ out->SetNullCount(out_arr->data()->null_count);
+ out->buffers[0] = std::move(out_arr->data()->buffers[0]);
+ out->buffers[1] = std::move(out_arr->data()->buffers[1]);
+ out->buffers[2] = std::move(out_arr->data()->buffers[2]);
+ return Status::OK();
+ }
+
+ // AAS
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+ const Scalar& right, ArrayData* out) {
+ const auto* left_offsets = left.GetValues<OffsetType>(1);
+ const uint8_t* left_data = left.buffers[2]->data();
+
+ util::string_view right_data = internal::UnboxScalar<Type>::Unbox(right);
+ auto right_size = static_cast<OffsetType>(right_data.size());
+
+ // allocate data buffer conservatively
+ int64_t data_buff_alloc =
+ right_size * cond.length + left_offsets[left.length] - left_offsets[0];
+
+ BuilderType builder(ctx->memory_pool());
+ ARROW_RETURN_NOT_OK(builder.Reserve(cond.length + 1));
+ ARROW_RETURN_NOT_OK(builder.ReserveData(data_buff_alloc));
+
+ RunLoop(
+ cond, *out,
+ [&](int64_t i) {
+ builder.UnsafeAppend(left_data + left_offsets[i],
+ left_offsets[i + 1] - left_offsets[i]);
+ },
+ [&](int64_t i) { builder.UnsafeAppend(right_data.data(), right_size); },
+ [&]() { builder.UnsafeAppendNull(); });
+ ARROW_ASSIGN_OR_RAISE(auto out_arr, builder.Finish());
+
+ out->SetNullCount(out_arr->data()->null_count);
+ out->buffers[0] = std::move(out_arr->data()->buffers[0]);
+ out->buffers[1] = std::move(out_arr->data()->buffers[1]);
+ out->buffers[2] = std::move(out_arr->data()->buffers[2]);
+ return Status::OK();
+ }
+
+ // ASS
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+ const Scalar& right, ArrayData* out) {
+ util::string_view left_data = internal::UnboxScalar<Type>::Unbox(left);
+ auto left_size = static_cast<OffsetType>(left_data.size());
+
+ util::string_view right_data = internal::UnboxScalar<Type>::Unbox(right);
+ auto right_size = static_cast<OffsetType>(right_data.size());
+
+ // allocate data buffer conservatively
+ int64_t data_buff_alloc = std::max(right_size, left_size) * cond.length;
+ BuilderType builder(ctx->memory_pool());
+ ARROW_RETURN_NOT_OK(builder.Reserve(cond.length + 1));
+ ARROW_RETURN_NOT_OK(builder.ReserveData(data_buff_alloc));
+
+ RunLoop(
+ cond, *out, [&](int64_t i) { builder.UnsafeAppend(left_data.data(), left_size); },
+ [&](int64_t i) { builder.UnsafeAppend(right_data.data(), right_size); },
+ [&]() { builder.UnsafeAppendNull(); });
+ ARROW_ASSIGN_OR_RAISE(auto out_arr, builder.Finish());
+
+ out->SetNullCount(out_arr->data()->null_count);
+ out->buffers[0] = std::move(out_arr->data()->buffers[0]);
+ out->buffers[1] = std::move(out_arr->data()->buffers[1]);
+ out->buffers[2] = std::move(out_arr->data()->buffers[2]);
+ return Status::OK();
+ }
+
+ template <typename HandleLeft, typename HandleRight, typename HandleNull>
+ static void RunLoop(const ArrayData& cond, const ArrayData& output,
+ HandleLeft&& handle_left, HandleRight&& handle_right,
+ HandleNull&& handle_null) {
+ const auto* cond_data = cond.buffers[1]->data();
+
+ if (output.buffers[0]) { // output may have nulls
+ // output validity buffer is allocated internally from the IfElseFunctor. Therefore
+ // it is cond.length'd with 0 offset.
+ const auto* out_valid = output.buffers[0]->data();
+
+ for (int64_t i = 0; i < cond.length; i++) {
+ if (BitUtil::GetBit(out_valid, i)) {
+ BitUtil::GetBit(cond_data, cond.offset + i) ? handle_left(i) : handle_right(i);
+ } else {
+ handle_null();
+ }
+ }
+ } else { // output is all valid (no nulls)
+ for (int64_t i = 0; i < cond.length; i++) {
+ BitUtil::GetBit(cond_data, cond.offset + i) ? handle_left(i) : handle_right(i);
+ }
+ }
+ }
+};
+
+template <typename Type>
+struct IfElseFunctor<Type, enable_if_fixed_size_binary<Type>> {
+ // A - Array, S - Scalar, X = Array/Scalar
+
+ // SXX
+ static Status Call(KernelContext* ctx, const BooleanScalar& cond, const Datum& left,
+ const Datum& right, Datum* out) {
+ ARROW_ASSIGN_OR_RAISE(auto byte_width, GetByteWidth(*left.type(), *right.type()));
+ return RunIfElseScalar(
+ cond, left, right, out,
+ /*CopyArrayData*/
+ [&](const ArrayData& valid_array, ArrayData* out_array) {
+ std::memcpy(
+ out_array->buffers[1]->mutable_data() + out_array->offset * byte_width,
+ valid_array.buffers[1]->data() + valid_array.offset * byte_width,
+ valid_array.length * byte_width);
+ },
+ /*BroadcastScalar*/
+ [&](const Scalar& scalar, ArrayData* out_array) {
+ const uint8_t* scalar_data = UnboxBinaryScalar(scalar);
+ uint8_t* start =
+ out_array->buffers[1]->mutable_data() + out_array->offset * byte_width;
+ for (int64_t i = 0; i < out_array->length; i++) {
+ std::memcpy(start + i * byte_width, scalar_data, byte_width);
+ }
+ });
+ }
+
+ // AAA
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+ const ArrayData& right, ArrayData* out) {
+ ARROW_ASSIGN_OR_RAISE(auto byte_width, GetByteWidth(*left.type, *right.type));
+ auto* out_values = out->buffers[1]->mutable_data() + out->offset * byte_width;
+
+ // copy right data to out_buff
+ const uint8_t* right_data = right.buffers[1]->data() + right.offset * byte_width;
+ std::memcpy(out_values, right_data, right.length * byte_width);
+
+ // selectively copy values from left data
+ const uint8_t* left_data = left.buffers[1]->data() + left.offset * byte_width;
+
+ RunIfElseLoop(cond, [&](int64_t data_offset, int64_t num_elems) {
+ std::memcpy(out_values + data_offset * byte_width,
+ left_data + data_offset * byte_width, num_elems * byte_width);
+ });
+
+ return Status::OK();
+ }
+
+ // ASA
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+ const ArrayData& right, ArrayData* out) {
+ ARROW_ASSIGN_OR_RAISE(auto byte_width, GetByteWidth(*left.type, *right.type));
+ auto* out_values = out->buffers[1]->mutable_data() + out->offset * byte_width;
+
+ // copy right data to out_buff
+ const uint8_t* right_data = right.buffers[1]->data() + right.offset * byte_width;
+ std::memcpy(out_values, right_data, right.length * byte_width);
+
+ // selectively copy values from left data
+ const uint8_t* left_data = UnboxBinaryScalar(left);
+
+ RunIfElseLoop(cond, [&](int64_t data_offset, int64_t num_elems) {
+ if (left_data) {
+ for (int64_t i = 0; i < num_elems; i++) {
+ std::memcpy(out_values + (data_offset + i) * byte_width, left_data, byte_width);
+ }
+ }
+ });
+
+ return Status::OK();
+ }
+
+ // AAS
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+ const Scalar& right, ArrayData* out) {
+ ARROW_ASSIGN_OR_RAISE(auto byte_width, GetByteWidth(*left.type, *right.type));
+ auto* out_values = out->buffers[1]->mutable_data() + out->offset * byte_width;
+
+ // copy left data to out_buff
+ const uint8_t* left_data = left.buffers[1]->data() + left.offset * byte_width;
+ std::memcpy(out_values, left_data, left.length * byte_width);
+
+ const uint8_t* right_data = UnboxBinaryScalar(right);
+
+ RunIfElseLoopInverted(cond, [&](int64_t data_offset, int64_t num_elems) {
+ if (right_data) {
+ for (int64_t i = 0; i < num_elems; i++) {
+ std::memcpy(out_values + (data_offset + i) * byte_width, right_data,
+ byte_width);
+ }
+ }
+ });
+
+ return Status::OK();
+ }
+
+ // ASS
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+ const Scalar& right, ArrayData* out) {
+ ARROW_ASSIGN_OR_RAISE(auto byte_width, GetByteWidth(*left.type, *right.type));
+ auto* out_values = out->buffers[1]->mutable_data() + out->offset * byte_width;
+
+ // copy right data to out_buff
+ const uint8_t* right_data = UnboxBinaryScalar(right);
+ if (right_data) {
+ for (int64_t i = 0; i < cond.length; i++) {
+ std::memcpy(out_values + i * byte_width, right_data, byte_width);
+ }
+ }
+
+ // selectively copy values from left data
+ const uint8_t* left_data = UnboxBinaryScalar(left);
+ RunIfElseLoop(cond, [&](int64_t data_offset, int64_t num_elems) {
+ if (left_data) {
+ for (int64_t i = 0; i < num_elems; i++) {
+ std::memcpy(out_values + (data_offset + i) * byte_width, left_data, byte_width);
+ }
+ }
+ });
+
+ return Status::OK();
+ }
+
+ template <typename T = Type>
+ static enable_if_t<!is_decimal_type<T>::value, const uint8_t*> UnboxBinaryScalar(
+ const Scalar& scalar) {
+ return reinterpret_cast<const uint8_t*>(
+ internal::UnboxScalar<FixedSizeBinaryType>::Unbox(scalar).data());
+ }
+
+ template <typename T = Type>
+ static enable_if_decimal<T, const uint8_t*> UnboxBinaryScalar(const Scalar& scalar) {
+ return internal::UnboxScalar<T>::Unbox(scalar).native_endian_bytes();
+ }
+
+ template <typename T = Type>
+ static enable_if_t<!is_decimal_type<T>::value, Result<int32_t>> GetByteWidth(
+ const DataType& left_type, const DataType& right_type) {
+ const int32_t width =
+ checked_cast<const FixedSizeBinaryType&>(left_type).byte_width();
+ DCHECK_EQ(width, checked_cast<const FixedSizeBinaryType&>(right_type).byte_width());
+ return width;
+ }
+
+ template <typename T = Type>
+ static enable_if_decimal<T, Result<int32_t>> GetByteWidth(const DataType& left_type,
+ const DataType& right_type) {
+ const auto& left = checked_cast<const T&>(left_type);
+ const auto& right = checked_cast<const T&>(right_type);
+ DCHECK_EQ(left.precision(), right.precision());
+ DCHECK_EQ(left.scale(), right.scale());
+ return left.byte_width();
+ }
+};
+
+// Use builders for dictionaries - slower, but allows us to unify dictionaries
+struct NestedIfElseExec {
+ // A - Array, S - Scalar, X = Array/Scalar
+
+ // SXX
+ static Status Call(KernelContext* ctx, const BooleanScalar& cond, const Datum& left,
+ const Datum& right, Datum* out) {
+ if (left.is_scalar() && right.is_scalar()) {
+ if (cond.is_valid) {
+ *out = cond.value ? left.scalar() : right.scalar();
+ } else {
+ *out = MakeNullScalar(left.type());
+ }
+ return Status::OK();
+ }
+ // either left or right is an array. Output is always an array
+ int64_t out_arr_len = std::max(left.length(), right.length());
+ if (!cond.is_valid) {
+ // cond is null; just create a null array
+ ARROW_ASSIGN_OR_RAISE(*out,
+ MakeArrayOfNull(left.type(), out_arr_len, ctx->memory_pool()))
+ return Status::OK();
+ }
+
+ const auto& valid_data = cond.value ? left : right;
+ if (valid_data.is_array()) {
+ *out = valid_data;
+ } else {
+ // valid data is a scalar that needs to be broadcasted
+ ARROW_ASSIGN_OR_RAISE(*out, MakeArrayFromScalar(*valid_data.scalar(), out_arr_len,
+ ctx->memory_pool()));
+ }
+ return Status::OK();
+ }
+
+ // AAA
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+ const ArrayData& right, ArrayData* out) {
+ return RunLoop(
+ ctx, cond, out,
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendArraySlice(left, i, length);
+ },
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendArraySlice(right, i, length);
+ });
+ }
+
+ // ASA
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+ const ArrayData& right, ArrayData* out) {
+ return RunLoop(
+ ctx, cond, out,
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendScalar(left, length);
+ },
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendArraySlice(right, i, length);
+ });
+ }
+
+ // AAS
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+ const Scalar& right, ArrayData* out) {
+ return RunLoop(
+ ctx, cond, out,
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendArraySlice(left, i, length);
+ },
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendScalar(right, length);
+ });
+ }
+
+ // ASS
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+ const Scalar& right, ArrayData* out) {
+ return RunLoop(
+ ctx, cond, out,
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendScalar(left, length);
+ },
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendScalar(right, length);
+ });
+ }
+
+ template <typename HandleLeft, typename HandleRight>
+ static Status RunLoop(KernelContext* ctx, const ArrayData& cond, ArrayData* out,
+ HandleLeft&& handle_left, HandleRight&& handle_right) {
+ std::unique_ptr<ArrayBuilder> raw_builder;
+ RETURN_NOT_OK(MakeBuilderExactIndex(ctx->memory_pool(), out->type, &raw_builder));
+ RETURN_NOT_OK(raw_builder->Reserve(out->length));
+
+ const auto* cond_data = cond.buffers[1]->data();
+ if (cond.buffers[0]) {
+ BitRunReader reader(cond.buffers[0]->data(), cond.offset, cond.length);
+ int64_t position = 0;
+ while (true) {
+ auto run = reader.NextRun();
+ if (run.length == 0) break;
+ if (run.set) {
+ for (int j = 0; j < run.length; j++) {
+ if (BitUtil::GetBit(cond_data, cond.offset + position + j)) {
+ RETURN_NOT_OK(handle_left(raw_builder.get(), position + j, 1));
+ } else {
+ RETURN_NOT_OK(handle_right(raw_builder.get(), position + j, 1));
+ }
+ }
+ } else {
+ RETURN_NOT_OK(raw_builder->AppendNulls(run.length));
+ }
+ position += run.length;
+ }
+ } else {
+ BitRunReader reader(cond_data, cond.offset, cond.length);
+ int64_t position = 0;
+ while (true) {
+ auto run = reader.NextRun();
+ if (run.length == 0) break;
+ if (run.set) {
+ RETURN_NOT_OK(handle_left(raw_builder.get(), position, run.length));
+ } else {
+ RETURN_NOT_OK(handle_right(raw_builder.get(), position, run.length));
+ }
+ position += run.length;
+ }
+ }
+ ARROW_ASSIGN_OR_RAISE(auto out_arr, raw_builder->Finish());
+ *out = std::move(*out_arr->data());
+ return Status::OK();
+ }
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[1], /*count=*/2));
+ if (batch[0].is_scalar()) {
+ const auto& cond = batch[0].scalar_as<BooleanScalar>();
+ return Call(ctx, cond, batch[1], batch[2], out);
+ }
+ if (batch[1].kind() == Datum::ARRAY) {
+ if (batch[2].kind() == Datum::ARRAY) { // AAA
+ return Call(ctx, *batch[0].array(), *batch[1].array(), *batch[2].array(),
+ out->mutable_array());
+ } else { // AAS
+ return Call(ctx, *batch[0].array(), *batch[1].array(), *batch[2].scalar(),
+ out->mutable_array());
+ }
+ } else {
+ if (batch[2].kind() == Datum::ARRAY) { // ASA
+ return Call(ctx, *batch[0].array(), *batch[1].scalar(), *batch[2].array(),
+ out->mutable_array());
+ } else { // ASS
+ return Call(ctx, *batch[0].array(), *batch[1].scalar(), *batch[2].scalar(),
+ out->mutable_array());
+ }
+ }
+ }
+};
+
+template <typename Type, typename AllocateMem>
+struct ResolveIfElseExec {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // Check is unconditional because parametric types like timestamp
+ // are templated as integer
+ RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[1], /*count=*/2));
+
+ // cond is scalar
+ if (batch[0].is_scalar()) {
+ const auto& cond = batch[0].scalar_as<BooleanScalar>();
+ return IfElseFunctor<Type>::Call(ctx, cond, batch[1], batch[2], out);
+ }
+
+ // cond is array. Use functors to sort things out
+ ARROW_RETURN_NOT_OK(PromoteNullsVisitor<AllocateMem>(ctx, batch[0], batch[1],
+ batch[2], out->mutable_array()));
+
+ if (batch[1].kind() == Datum::ARRAY) {
+ if (batch[2].kind() == Datum::ARRAY) { // AAA
+ return IfElseFunctor<Type>::Call(ctx, *batch[0].array(), *batch[1].array(),
+ *batch[2].array(), out->mutable_array());
+ } else { // AAS
+ return IfElseFunctor<Type>::Call(ctx, *batch[0].array(), *batch[1].array(),
+ *batch[2].scalar(), out->mutable_array());
+ }
+ } else {
+ if (batch[2].kind() == Datum::ARRAY) { // ASA
+ return IfElseFunctor<Type>::Call(ctx, *batch[0].array(), *batch[1].scalar(),
+ *batch[2].array(), out->mutable_array());
+ } else { // ASS
+ return IfElseFunctor<Type>::Call(ctx, *batch[0].array(), *batch[1].scalar(),
+ *batch[2].scalar(), out->mutable_array());
+ }
+ }
+ }
+};
+
+template <typename AllocateMem>
+struct ResolveIfElseExec<NullType, AllocateMem> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // if all are scalars, return a null scalar
+ if (batch[0].is_scalar() && batch[1].is_scalar() && batch[2].is_scalar()) {
+ *out = MakeNullScalar(null());
+ } else {
+ ARROW_ASSIGN_OR_RAISE(*out,
+ MakeArrayOfNull(null(), batch.length, ctx->memory_pool()));
+ }
+ return Status::OK();
+ }
+};
+
+struct IfElseFunction : ScalarFunction {
+ using ScalarFunction::ScalarFunction;
+
+ Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+ RETURN_NOT_OK(CheckArity(*values));
+
+ using arrow::compute::detail::DispatchExactImpl;
+ // Do not DispatchExact here because it'll let through something like (bool,
+ // timestamp[s], timestamp[s, "UTC"])
+
+ // if 0th descriptor is null, replace with bool
+ if (values->at(0).type->id() == Type::NA) {
+ values->at(0).type = boolean();
+ }
+
+ // if-else 0'th descriptor is bool, so skip it
+ ValueDescr* left_arg = &(*values)[1];
+ constexpr size_t num_args = 2;
+
+ internal::ReplaceNullWithOtherType(left_arg, num_args);
+
+ // If both are identical dictionary types, dispatch to the dictionary kernel
+ // TODO(ARROW-14105): apply implicit casts to dictionary types too
+ ValueDescr* right_arg = &(*values)[2];
+ if (is_dictionary(left_arg->type->id()) && left_arg->type->Equals(right_arg->type)) {
+ auto kernel = DispatchExactImpl(this, *values);
+ DCHECK(kernel);
+ return kernel;
+ }
+
+ internal::EnsureDictionaryDecoded(left_arg, num_args);
+
+ if (auto type = internal::CommonNumeric(left_arg, num_args)) {
+ internal::ReplaceTypes(type, left_arg, num_args);
+ } else if (auto type = internal::CommonTemporal(left_arg, num_args)) {
+ internal::ReplaceTypes(type, left_arg, num_args);
+ } else if (auto type = internal::CommonBinary(left_arg, num_args)) {
+ internal::ReplaceTypes(type, left_arg, num_args);
+ } else if (HasDecimal(*values)) {
+ RETURN_NOT_OK(CastDecimalArgs(left_arg, num_args));
+ }
+
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+
+ return arrow::compute::detail::NoMatchingKernel(this, *values);
+ }
+};
+
+void AddNullIfElseKernel(const std::shared_ptr<IfElseFunction>& scalar_function) {
+ ScalarKernel kernel({boolean(), null(), null()}, null(),
+ ResolveIfElseExec<NullType,
+ /*AllocateMem=*/std::true_type>::Exec);
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ kernel.can_write_into_slices = false;
+
+ DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
+}
+
+void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_function,
+ const std::vector<std::shared_ptr<DataType>>& types) {
+ for (auto&& type : types) {
+ auto exec =
+ internal::GenerateTypeAgnosticPrimitive<ResolveIfElseExec,
+ /*AllocateMem=*/std::false_type>(*type);
+ // cond array needs to be boolean always
+ std::shared_ptr<KernelSignature> sig;
+ if (type->id() == Type::TIMESTAMP) {
+ auto unit = checked_cast<const TimestampType&>(*type).unit();
+ sig = KernelSignature::Make(
+ {boolean(), match::TimestampTypeUnit(unit), match::TimestampTypeUnit(unit)},
+ OutputType(LastType));
+ } else {
+ sig = KernelSignature::Make({boolean(), type, type}, type);
+ }
+ ScalarKernel kernel(std::move(sig), exec);
+ kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::PREALLOCATE;
+ kernel.can_write_into_slices = true;
+
+ DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
+ }
+}
+
+void AddBinaryIfElseKernels(const std::shared_ptr<IfElseFunction>& scalar_function,
+ const std::vector<std::shared_ptr<DataType>>& types) {
+ for (auto&& type : types) {
+ auto exec =
+ internal::GenerateTypeAgnosticVarBinaryBase<ResolveIfElseExec,
+ /*AllocateMem=*/std::true_type>(
+ *type);
+ // cond array needs to be boolean always
+ ScalarKernel kernel({boolean(), type, type}, type, exec);
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ kernel.can_write_into_slices = false;
+
+ DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
+ }
+}
+
+template <typename T>
+void AddFixedWidthIfElseKernel(const std::shared_ptr<IfElseFunction>& scalar_function) {
+ auto type_id = T::type_id;
+ ScalarKernel kernel({boolean(), InputType(type_id), InputType(type_id)},
+ OutputType(LastType),
+ ResolveIfElseExec<T, /*AllocateMem=*/std::false_type>::Exec);
+ kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::PREALLOCATE;
+ kernel.can_write_into_slices = true;
+
+ DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
+}
+
+void AddNestedIfElseKernels(const std::shared_ptr<IfElseFunction>& scalar_function) {
+ for (const auto type_id :
+ {Type::LIST, Type::LARGE_LIST, Type::FIXED_SIZE_LIST, Type::STRUCT,
+ Type::DENSE_UNION, Type::SPARSE_UNION, Type::DICTIONARY}) {
+ ScalarKernel kernel({boolean(), InputType(type_id), InputType(type_id)},
+ OutputType(LastType), NestedIfElseExec::Exec);
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ kernel.can_write_into_slices = false;
+ DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
+ }
+}
+
+// Helper to copy or broadcast fixed-width values between buffers.
+template <typename Type, typename Enable = void>
+struct CopyFixedWidth {};
+template <>
+struct CopyFixedWidth<BooleanType> {
+ static void CopyScalar(const Scalar& scalar, const int64_t length,
+ uint8_t* raw_out_values, const int64_t out_offset) {
+ const bool value = UnboxScalar<BooleanType>::Unbox(scalar);
+ BitUtil::SetBitsTo(raw_out_values, out_offset, length, value);
+ }
+ static void CopyArray(const DataType&, const uint8_t* in_values,
+ const int64_t in_offset, const int64_t length,
+ uint8_t* raw_out_values, const int64_t out_offset) {
+ arrow::internal::CopyBitmap(in_values, in_offset, length, raw_out_values, out_offset);
+ }
+};
+
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_number<Type>> {
+ using CType = typename TypeTraits<Type>::CType;
+ static void CopyScalar(const Scalar& scalar, const int64_t length,
+ uint8_t* raw_out_values, const int64_t out_offset) {
+ CType* out_values = reinterpret_cast<CType*>(raw_out_values);
+ const CType value = UnboxScalar<Type>::Unbox(scalar);
+ std::fill(out_values + out_offset, out_values + out_offset + length, value);
+ }
+ static void CopyArray(const DataType&, const uint8_t* in_values,
+ const int64_t in_offset, const int64_t length,
+ uint8_t* raw_out_values, const int64_t out_offset) {
+ std::memcpy(raw_out_values + out_offset * sizeof(CType),
+ in_values + in_offset * sizeof(CType), length * sizeof(CType));
+ }
+};
+
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_same<Type, FixedSizeBinaryType>> {
+ static void CopyScalar(const Scalar& values, const int64_t length,
+ uint8_t* raw_out_values, const int64_t out_offset) {
+ const int32_t width =
+ checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width();
+ uint8_t* next = raw_out_values + (width * out_offset);
+ const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(values);
+ // Scalar may have null value buffer
+ if (!scalar.value) {
+ std::memset(next, 0x00, width * length);
+ } else {
+ DCHECK_EQ(scalar.value->size(), width);
+ for (int i = 0; i < length; i++) {
+ std::memcpy(next, scalar.value->data(), width);
+ next += width;
+ }
+ }
+ }
+ static void CopyArray(const DataType& type, const uint8_t* in_values,
+ const int64_t in_offset, const int64_t length,
+ uint8_t* raw_out_values, const int64_t out_offset) {
+ const int32_t width = checked_cast<const FixedSizeBinaryType&>(type).byte_width();
+ uint8_t* next = raw_out_values + (width * out_offset);
+ std::memcpy(next, in_values + in_offset * width, length * width);
+ }
+};
+
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_decimal<Type>> {
+ using ScalarType = typename TypeTraits<Type>::ScalarType;
+ static void CopyScalar(const Scalar& values, const int64_t length,
+ uint8_t* raw_out_values, const int64_t out_offset) {
+ const int32_t width =
+ checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width();
+ uint8_t* next = raw_out_values + (width * out_offset);
+ const auto& scalar = checked_cast<const ScalarType&>(values);
+ const auto value = scalar.value.ToBytes();
+ for (int i = 0; i < length; i++) {
+ std::memcpy(next, value.data(), width);
+ next += width;
+ }
+ }
+ static void CopyArray(const DataType& type, const uint8_t* in_values,
+ const int64_t in_offset, const int64_t length,
+ uint8_t* raw_out_values, const int64_t out_offset) {
+ const int32_t width = checked_cast<const FixedSizeBinaryType&>(type).byte_width();
+ uint8_t* next = raw_out_values + (width * out_offset);
+ std::memcpy(next, in_values + in_offset * width, length * width);
+ }
+};
+
+// Copy fixed-width values from a scalar/array datum into an output values buffer
+template <typename Type>
+void CopyValues(const Datum& in_values, const int64_t in_offset, const int64_t length,
+ uint8_t* out_valid, uint8_t* out_values, const int64_t out_offset) {
+ if (in_values.is_scalar()) {
+ const auto& scalar = *in_values.scalar();
+ if (out_valid) {
+ BitUtil::SetBitsTo(out_valid, out_offset, length, scalar.is_valid);
+ }
+ CopyFixedWidth<Type>::CopyScalar(scalar, length, out_values, out_offset);
+ } else {
+ const ArrayData& array = *in_values.array();
+ if (out_valid) {
+ if (array.MayHaveNulls()) {
+ if (length == 1) {
+ // CopyBitmap is slow for short runs
+ BitUtil::SetBitTo(
+ out_valid, out_offset,
+ BitUtil::GetBit(array.buffers[0]->data(), array.offset + in_offset));
+ } else {
+ arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + in_offset,
+ length, out_valid, out_offset);
+ }
+ } else {
+ BitUtil::SetBitsTo(out_valid, out_offset, length, true);
+ }
+ }
+ CopyFixedWidth<Type>::CopyArray(*array.type, array.buffers[1]->data(),
+ array.offset + in_offset, length, out_values,
+ out_offset);
+ }
+}
+
+// Specialized helper to copy a single value from a source array. Allows avoiding
+// repeatedly calling MayHaveNulls and Buffer::data() which have internal checks that
+// add up when called in a loop.
+template <typename Type>
+void CopyOneArrayValue(const DataType& type, const uint8_t* in_valid,
+ const uint8_t* in_values, const int64_t in_offset,
+ uint8_t* out_valid, uint8_t* out_values,
+ const int64_t out_offset) {
+ if (out_valid) {
+ BitUtil::SetBitTo(out_valid, out_offset,
+ !in_valid || BitUtil::GetBit(in_valid, in_offset));
+ }
+ CopyFixedWidth<Type>::CopyArray(type, in_values, in_offset, /*length=*/1, out_values,
+ out_offset);
+}
+
+template <typename Type>
+void CopyOneScalarValue(const Scalar& scalar, uint8_t* out_valid, uint8_t* out_values,
+ const int64_t out_offset) {
+ if (out_valid) {
+ BitUtil::SetBitTo(out_valid, out_offset, scalar.is_valid);
+ }
+ CopyFixedWidth<Type>::CopyScalar(scalar, /*length=*/1, out_values, out_offset);
+}
+
+template <typename Type>
+void CopyOneValue(const Datum& in_values, const int64_t in_offset, uint8_t* out_valid,
+ uint8_t* out_values, const int64_t out_offset) {
+ if (in_values.is_array()) {
+ const ArrayData& array = *in_values.array();
+ CopyOneArrayValue<Type>(*array.type, array.GetValues<uint8_t>(0, 0),
+ array.GetValues<uint8_t>(1, 0), array.offset + in_offset,
+ out_valid, out_values, out_offset);
+ } else {
+ CopyOneScalarValue<Type>(*in_values.scalar(), out_valid, out_values, out_offset);
+ }
+}
+
+struct CaseWhenFunction : ScalarFunction {
+ using ScalarFunction::ScalarFunction;
+
+ Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+ // The first function is a struct of booleans, where the number of fields in the
+ // struct is either equal to the number of other arguments or is one less.
+ RETURN_NOT_OK(CheckArity(*values));
+ auto first_type = (*values)[0].type;
+ if (first_type->id() != Type::STRUCT) {
+ return Status::TypeError("case_when: first argument must be STRUCT, not ",
+ *first_type);
+ }
+ auto num_fields = static_cast<size_t>(first_type->num_fields());
+ if (num_fields < values->size() - 2 || num_fields >= values->size()) {
+ return Status::Invalid(
+ "case_when: number of struct fields must be equal to or one less than count of "
+ "remaining arguments (",
+ values->size() - 1, "), got: ", first_type->num_fields());
+ }
+ for (const auto& field : first_type->fields()) {
+ if (field->type()->id() != Type::BOOL) {
+ return Status::TypeError(
+ "case_when: all fields of first argument must be BOOL, but ", field->name(),
+ " was of type: ", *field->type());
+ }
+ }
+
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+
+ EnsureDictionaryDecoded(values);
+ if (auto type = CommonNumeric(values->data() + 1, values->size() - 1)) {
+ for (auto it = values->begin() + 1; it != values->end(); it++) {
+ it->type = type;
+ }
+ }
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+ return arrow::compute::detail::NoMatchingKernel(this, *values);
+ }
+};
+
+// Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar conditions
+template <typename Type>
+Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& conds = checked_cast<const StructScalar&>(*batch.values[0].scalar());
+ if (!conds.is_valid) {
+ return Status::Invalid("cond struct must not be null");
+ }
+ Datum result;
+ for (size_t i = 0; i < batch.values.size() - 1; i++) {
+ if (i < conds.value.size()) {
+ const Scalar& cond = *conds.value[i];
+ if (cond.is_valid && internal::UnboxScalar<BooleanType>::Unbox(cond)) {
+ result = batch[i + 1];
+ break;
+ }
+ } else {
+ // ELSE clause
+ result = batch[i + 1];
+ break;
+ }
+ }
+ if (out->is_scalar()) {
+ *out = result.is_scalar() ? result.scalar() : MakeNullScalar(out->type());
+ return Status::OK();
+ }
+ ArrayData* output = out->mutable_array();
+ if (is_dictionary_type<Type>::value) {
+ const Datum& dict_from = result.is_value() ? result : batch[1];
+ if (dict_from.is_scalar()) {
+ output->dictionary = checked_cast<const DictionaryScalar&>(*dict_from.scalar())
+ .value.dictionary->data();
+ } else {
+ output->dictionary = dict_from.array()->dictionary;
+ }
+ }
+ if (!result.is_value()) {
+ // All conditions false, no 'else' argument
+ result = MakeNullScalar(out->type());
+ }
+ CopyValues<Type>(result, /*in_offset=*/0, batch.length,
+ output->GetMutableValues<uint8_t>(0, 0),
+ output->GetMutableValues<uint8_t>(1, 0), output->offset);
+ return Status::OK();
+}
+
+// Implement 'case when' for any mix of scalar/array arguments for any fixed-width type,
+// given helper functions to copy data from a source array to a target array
+template <typename Type>
+Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& conds_array = *batch.values[0].array();
+ if (conds_array.GetNullCount() > 0) {
+ return Status::Invalid("cond struct must not have top-level nulls");
+ }
+ ArrayData* output = out->mutable_array();
+ const int64_t out_offset = output->offset;
+ const auto num_value_args = batch.values.size() - 1;
+ const bool have_else_arg =
+ static_cast<size_t>(conds_array.type->num_fields()) < num_value_args;
+ uint8_t* out_valid = output->buffers[0]->mutable_data();
+ uint8_t* out_values = output->buffers[1]->mutable_data();
+
+ if (have_else_arg) {
+ // Copy 'else' value into output
+ CopyValues<Type>(batch.values.back(), /*in_offset=*/0, batch.length, out_valid,
+ out_values, out_offset);
+ } else {
+ // There's no 'else' argument, so we should have an all-null validity bitmap
+ BitUtil::SetBitsTo(out_valid, out_offset, batch.length, false);
+ }
+
+ // Allocate a temporary bitmap to determine which elements still need setting.
+ ARROW_ASSIGN_OR_RAISE(auto mask_buffer, ctx->AllocateBitmap(batch.length));
+ uint8_t* mask = mask_buffer->mutable_data();
+ std::memset(mask, 0xFF, mask_buffer->size());
+
+ // Then iterate through each argument in turn and set elements.
+ for (size_t i = 0; i < batch.values.size() - (have_else_arg ? 2 : 1); i++) {
+ const ArrayData& cond_array = *conds_array.child_data[i];
+ const int64_t cond_offset = conds_array.offset + cond_array.offset;
+ const uint8_t* cond_values = cond_array.buffers[1]->data();
+ const Datum& values_datum = batch[i + 1];
+ int64_t offset = 0;
+
+ if (cond_array.GetNullCount() == 0) {
+ // If no valid buffer, visit mask & cond bitmap simultaneously
+ BinaryBitBlockCounter counter(mask, /*start_offset=*/0, cond_values, cond_offset,
+ batch.length);
+ while (offset < batch.length) {
+ const auto block = counter.NextAndWord();
+ if (block.AllSet()) {
+ CopyValues<Type>(values_datum, offset, block.length, out_valid, out_values,
+ out_offset + offset);
+ BitUtil::SetBitsTo(mask, offset, block.length, false);
+ } else if (block.popcount) {
+ for (int64_t j = 0; j < block.length; ++j) {
+ if (BitUtil::GetBit(mask, offset + j) &&
+ BitUtil::GetBit(cond_values, cond_offset + offset + j)) {
+ CopyValues<Type>(values_datum, offset + j, /*length=*/1, out_valid,
+ out_values, out_offset + offset + j);
+ BitUtil::SetBitTo(mask, offset + j, false);
+ }
+ }
+ }
+ offset += block.length;
+ }
+ } else {
+ // Visit mask & cond bitmap & cond validity
+ const uint8_t* cond_valid = cond_array.buffers[0]->data();
+ Bitmap bitmaps[3] = {{mask, /*offset=*/0, batch.length},
+ {cond_values, cond_offset, batch.length},
+ {cond_valid, cond_offset, batch.length}};
+ Bitmap::VisitWords(bitmaps, [&](std::array<uint64_t, 3> words) {
+ const uint64_t word = words[0] & words[1] & words[2];
+ const int64_t block_length = std::min<int64_t>(64, batch.length - offset);
+ if (word == std::numeric_limits<uint64_t>::max()) {
+ CopyValues<Type>(values_datum, offset, block_length, out_valid, out_values,
+ out_offset + offset);
+ BitUtil::SetBitsTo(mask, offset, block_length, false);
+ } else if (word) {
+ for (int64_t j = 0; j < block_length; ++j) {
+ if (BitUtil::GetBit(mask, offset + j) &&
+ BitUtil::GetBit(cond_valid, cond_offset + offset + j) &&
+ BitUtil::GetBit(cond_values, cond_offset + offset + j)) {
+ CopyValues<Type>(values_datum, offset + j, /*length=*/1, out_valid,
+ out_values, out_offset + offset + j);
+ BitUtil::SetBitTo(mask, offset + j, false);
+ }
+ }
+ }
+ });
+ }
+ }
+ if (!have_else_arg) {
+ // Need to initialize any remaining null slots (uninitialized memory)
+ BitBlockCounter counter(mask, /*offset=*/0, batch.length);
+ int64_t offset = 0;
+ auto bit_width = checked_cast<const FixedWidthType&>(*out->type()).bit_width();
+ auto byte_width = BitUtil::BytesForBits(bit_width);
+ while (offset < batch.length) {
+ const auto block = counter.NextWord();
+ if (block.AllSet()) {
+ if (bit_width == 1) {
+ BitUtil::SetBitsTo(out_values, out_offset + offset, block.length, false);
+ } else {
+ std::memset(out_values + (out_offset + offset) * byte_width, 0x00,
+ byte_width * block.length);
+ }
+ } else if (!block.NoneSet()) {
+ for (int64_t j = 0; j < block.length; ++j) {
+ if (BitUtil::GetBit(out_valid, out_offset + offset + j)) continue;
+ if (bit_width == 1) {
+ BitUtil::ClearBit(out_values, out_offset + offset + j);
+ } else {
+ std::memset(out_values + (out_offset + offset + j) * byte_width, 0x00,
+ byte_width);
+ }
+ }
+ }
+ offset += block.length;
+ }
+ }
+ return Status::OK();
+}
+
+template <typename Type, typename Enable = void>
+struct CaseWhenFunctor {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch.values[0].is_array()) {
+ return ExecArrayCaseWhen<Type>(ctx, batch, out);
+ }
+ return ExecScalarCaseWhen<Type>(ctx, batch, out);
+ }
+};
+
+template <>
+struct CaseWhenFunctor<NullType> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return Status::OK();
+ }
+};
+
+Status ExecVarWidthScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch,
+ Datum* out) {
+ const auto& conds = checked_cast<const StructScalar&>(*batch.values[0].scalar());
+ Datum result;
+ for (size_t i = 0; i < batch.values.size() - 1; i++) {
+ if (i < conds.value.size()) {
+ const Scalar& cond = *conds.value[i];
+ if (cond.is_valid && internal::UnboxScalar<BooleanType>::Unbox(cond)) {
+ result = batch[i + 1];
+ break;
+ }
+ } else {
+ // ELSE clause
+ result = batch[i + 1];
+ break;
+ }
+ }
+ if (out->is_scalar()) {
+ DCHECK(result.is_scalar() || result.kind() == Datum::NONE);
+ *out = result.is_scalar() ? result.scalar() : MakeNullScalar(out->type());
+ return Status::OK();
+ }
+ ArrayData* output = out->mutable_array();
+ if (!result.is_value()) {
+ // All conditions false, no 'else' argument
+ ARROW_ASSIGN_OR_RAISE(
+ auto array, MakeArrayOfNull(output->type, batch.length, ctx->memory_pool()));
+ *output = *array->data();
+ } else if (result.is_scalar()) {
+ ARROW_ASSIGN_OR_RAISE(auto array, MakeArrayFromScalar(*result.scalar(), batch.length,
+ ctx->memory_pool()));
+ *output = *array->data();
+ } else {
+ *output = *result.array();
+ }
+ return Status::OK();
+}
+
+// Use std::function for reserve_data to avoid instantiating template so much
+template <typename AppendScalar>
+static Status ExecVarWidthArrayCaseWhenImpl(
+ KernelContext* ctx, const ExecBatch& batch, Datum* out,
+ std::function<Status(ArrayBuilder*)> reserve_data, AppendScalar append_scalar) {
+ const auto& conds_array = *batch.values[0].array();
+ ArrayData* output = out->mutable_array();
+ const bool have_else_arg =
+ static_cast<size_t>(conds_array.type->num_fields()) < (batch.values.size() - 1);
+ std::unique_ptr<ArrayBuilder> raw_builder;
+ RETURN_NOT_OK(MakeBuilderExactIndex(ctx->memory_pool(), out->type(), &raw_builder));
+ RETURN_NOT_OK(raw_builder->Reserve(batch.length));
+ RETURN_NOT_OK(reserve_data(raw_builder.get()));
+
+ for (int64_t row = 0; row < batch.length; row++) {
+ int64_t selected = have_else_arg ? static_cast<int64_t>(batch.values.size() - 1) : -1;
+ for (int64_t arg = 0; static_cast<size_t>(arg) < conds_array.child_data.size();
+ arg++) {
+ const ArrayData& cond_array = *conds_array.child_data[arg];
+ if ((!cond_array.buffers[0] ||
+ BitUtil::GetBit(cond_array.buffers[0]->data(),
+ conds_array.offset + cond_array.offset + row)) &&
+ BitUtil::GetBit(cond_array.buffers[1]->data(),
+ conds_array.offset + cond_array.offset + row)) {
+ selected = arg + 1;
+ break;
+ }
+ }
+ if (selected < 0) {
+ RETURN_NOT_OK(raw_builder->AppendNull());
+ continue;
+ }
+ const Datum& source = batch.values[selected];
+ if (source.is_scalar()) {
+ const auto& scalar = *source.scalar();
+ if (!scalar.is_valid) {
+ RETURN_NOT_OK(raw_builder->AppendNull());
+ } else {
+ RETURN_NOT_OK(append_scalar(raw_builder.get(), scalar));
+ }
+ } else {
+ const auto& array = source.array();
+ if (!array->buffers[0] ||
+ BitUtil::GetBit(array->buffers[0]->data(), array->offset + row)) {
+ RETURN_NOT_OK(raw_builder->AppendArraySlice(*array, row, /*length=*/1));
+ } else {
+ RETURN_NOT_OK(raw_builder->AppendNull());
+ }
+ }
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto temp_output, raw_builder->Finish());
+ *output = *temp_output->data();
+ return Status::OK();
+}
+
+// Single instantiation using ArrayBuilder::AppendScalar for append_scalar
+static Status ExecVarWidthArrayCaseWhen(
+ KernelContext* ctx, const ExecBatch& batch, Datum* out,
+ std::function<Status(ArrayBuilder*)> reserve_data) {
+ return ExecVarWidthArrayCaseWhenImpl(
+ ctx, batch, out, std::move(reserve_data),
+ [](ArrayBuilder* raw_builder, const Scalar& scalar) {
+ return raw_builder->AppendScalar(scalar);
+ });
+}
+
+template <typename Type>
+struct CaseWhenFunctor<Type, enable_if_base_binary<Type>> {
+ using offset_type = typename Type::offset_type;
+ using BuilderType = typename TypeTraits<Type>::BuilderType;
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].null_count() > 0) {
+ return Status::Invalid("cond struct must not have outer nulls");
+ }
+ if (batch[0].is_scalar()) {
+ return ExecVarWidthScalarCaseWhen(ctx, batch, out);
+ }
+ return ExecArray(ctx, batch, out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return ExecVarWidthArrayCaseWhenImpl(
+ ctx, batch, out,
+ // ReserveData
+ [&](ArrayBuilder* raw_builder) {
+ int64_t reservation = 0;
+ for (size_t arg = 1; arg < batch.values.size(); arg++) {
+ auto source = batch.values[arg];
+ if (source.is_scalar()) {
+ const auto& scalar =
+ checked_cast<const BaseBinaryScalar&>(*source.scalar());
+ if (!scalar.value) continue;
+ reservation =
+ std::max<int64_t>(reservation, batch.length * scalar.value->size());
+ } else {
+ const auto& array = *source.array();
+ const auto& offsets = array.GetValues<offset_type>(1);
+ reservation =
+ std::max<int64_t>(reservation, offsets[array.length] - offsets[0]);
+ }
+ }
+ // checked_cast works since (Large)StringBuilder <: (Large)BinaryBuilder
+ return checked_cast<BuilderType*>(raw_builder)->ReserveData(reservation);
+ },
+ // AppendScalar
+ [](ArrayBuilder* raw_builder, const Scalar& raw_scalar) {
+ const auto& scalar = checked_cast<const BaseBinaryScalar&>(raw_scalar);
+ return checked_cast<BuilderType*>(raw_builder)
+ ->Append(scalar.value->data(),
+ static_cast<offset_type>(scalar.value->size()));
+ });
+ }
+};
+
+template <typename Type>
+struct CaseWhenFunctor<Type, enable_if_var_size_list<Type>> {
+ using offset_type = typename Type::offset_type;
+ using BuilderType = typename TypeTraits<Type>::BuilderType;
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].null_count() > 0) {
+ return Status::Invalid("cond struct must not have outer nulls");
+ }
+ if (batch[0].is_scalar()) {
+ return ExecVarWidthScalarCaseWhen(ctx, batch, out);
+ }
+ return ExecArray(ctx, batch, out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return ExecVarWidthArrayCaseWhen(
+ ctx, batch, out,
+ // ReserveData
+ [&](ArrayBuilder* raw_builder) {
+ auto builder = checked_cast<BuilderType*>(raw_builder);
+ auto child_builder = builder->value_builder();
+
+ int64_t reservation = 0;
+ for (size_t arg = 1; arg < batch.values.size(); arg++) {
+ auto source = batch.values[arg];
+ if (!source.is_array()) {
+ const auto& scalar = checked_cast<const BaseListScalar&>(*source.scalar());
+ if (!scalar.value) continue;
+ reservation =
+ std::max<int64_t>(reservation, batch.length * scalar.value->length());
+ } else {
+ const auto& array = *source.array();
+ reservation = std::max<int64_t>(reservation, array.child_data[0]->length);
+ }
+ }
+ return child_builder->Reserve(reservation);
+ });
+ }
+};
+
+// No-op reserve function, pulled out to avoid apparent miscompilation on MinGW
+Status ReserveNoData(ArrayBuilder*) { return Status::OK(); }
+
+template <>
+struct CaseWhenFunctor<MapType> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].null_count() > 0) {
+ return Status::Invalid("cond struct must not have outer nulls");
+ }
+ if (batch[0].is_scalar()) {
+ return ExecVarWidthScalarCaseWhen(ctx, batch, out);
+ }
+ return ExecArray(ctx, batch, out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ std::function<Status(ArrayBuilder*)> reserve_data = ReserveNoData;
+ return ExecVarWidthArrayCaseWhen(ctx, batch, out, std::move(reserve_data));
+ }
+};
+
+template <>
+struct CaseWhenFunctor<StructType> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].null_count() > 0) {
+ return Status::Invalid("cond struct must not have outer nulls");
+ }
+ if (batch[0].is_scalar()) {
+ return ExecVarWidthScalarCaseWhen(ctx, batch, out);
+ }
+ return ExecArray(ctx, batch, out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ std::function<Status(ArrayBuilder*)> reserve_data = ReserveNoData;
+ return ExecVarWidthArrayCaseWhen(ctx, batch, out, std::move(reserve_data));
+ }
+};
+
+template <>
+struct CaseWhenFunctor<FixedSizeListType> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].null_count() > 0) {
+ return Status::Invalid("cond struct must not have outer nulls");
+ }
+ if (batch[0].is_scalar()) {
+ return ExecVarWidthScalarCaseWhen(ctx, batch, out);
+ }
+ return ExecArray(ctx, batch, out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& ty = checked_cast<const FixedSizeListType&>(*out->type());
+ const int64_t width = ty.list_size();
+ return ExecVarWidthArrayCaseWhen(
+ ctx, batch, out,
+ // ReserveData
+ [&](ArrayBuilder* raw_builder) {
+ int64_t children = batch.length * width;
+ return checked_cast<FixedSizeListBuilder*>(raw_builder)
+ ->value_builder()
+ ->Reserve(children);
+ });
+ }
+};
+
+template <typename Type>
+struct CaseWhenFunctor<Type, enable_if_union<Type>> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].null_count() > 0) {
+ return Status::Invalid("cond struct must not have outer nulls");
+ }
+ if (batch[0].is_scalar()) {
+ return ExecVarWidthScalarCaseWhen(ctx, batch, out);
+ }
+ return ExecArray(ctx, batch, out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ std::function<Status(ArrayBuilder*)> reserve_data = ReserveNoData;
+ return ExecVarWidthArrayCaseWhen(ctx, batch, out, std::move(reserve_data));
+ }
+};
+
+template <>
+struct CaseWhenFunctor<DictionaryType> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].null_count() > 0) {
+ return Status::Invalid("cond struct must not have outer nulls");
+ }
+ if (batch[0].is_scalar()) {
+ return ExecVarWidthScalarCaseWhen(ctx, batch, out);
+ }
+ return ExecArray(ctx, batch, out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ std::function<Status(ArrayBuilder*)> reserve_data = ReserveNoData;
+ return ExecVarWidthArrayCaseWhen(ctx, batch, out, std::move(reserve_data));
+ }
+};
+
+struct CoalesceFunction : ScalarFunction {
+ using ScalarFunction::ScalarFunction;
+
+ Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+ RETURN_NOT_OK(CheckArity(*values));
+ using arrow::compute::detail::DispatchExactImpl;
+ // Do not DispatchExact here since we want to rescale decimals if necessary
+ EnsureDictionaryDecoded(values);
+ if (auto type = CommonNumeric(*values)) {
+ ReplaceTypes(type, values);
+ }
+ if (auto type = CommonBinary(values->data(), values->size())) {
+ ReplaceTypes(type, values);
+ }
+ if (auto type = CommonTemporal(values->data(), values->size())) {
+ ReplaceTypes(type, values);
+ }
+ if (HasDecimal(*values)) {
+ RETURN_NOT_OK(CastDecimalArgs(values->data(), values->size()));
+ }
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+ return arrow::compute::detail::NoMatchingKernel(this, *values);
+ }
+};
+
+// Implement a 'coalesce' (SQL) operator for any number of scalar inputs
+Status ExecScalarCoalesce(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ for (const auto& datum : batch.values) {
+ if (datum.scalar()->is_valid) {
+ *out = datum;
+ break;
+ }
+ }
+ return Status::OK();
+}
+
+// Helper: copy from a source datum into all null slots of the output
+template <typename Type>
+void CopyValuesAllValid(Datum source, uint8_t* out_valid, uint8_t* out_values,
+ const int64_t out_offset, const int64_t length) {
+ BitRunReader bit_reader(out_valid, out_offset, length);
+ int64_t offset = 0;
+ while (true) {
+ const auto run = bit_reader.NextRun();
+ if (run.length == 0) {
+ break;
+ }
+ if (!run.set) {
+ CopyValues<Type>(source, offset, run.length, out_valid, out_values,
+ out_offset + offset);
+ }
+ offset += run.length;
+ }
+ DCHECK_EQ(offset, length);
+}
+
+// Helper: zero the values buffer of the output wherever the slot is null
+void InitializeNullSlots(const DataType& type, uint8_t* out_valid, uint8_t* out_values,
+ const int64_t out_offset, const int64_t length) {
+ BitRunReader bit_reader(out_valid, out_offset, length);
+ int64_t offset = 0;
+ const auto bit_width = checked_cast<const FixedWidthType&>(type).bit_width();
+ const auto byte_width = BitUtil::BytesForBits(bit_width);
+ while (true) {
+ const auto run = bit_reader.NextRun();
+ if (run.length == 0) {
+ break;
+ }
+ if (!run.set) {
+ if (bit_width == 1) {
+ BitUtil::SetBitsTo(out_values, out_offset + offset, run.length, false);
+ } else {
+ std::memset(out_values + (out_offset + offset) * byte_width, 0,
+ byte_width * run.length);
+ }
+ }
+ offset += run.length;
+ }
+ DCHECK_EQ(offset, length);
+}
+
+// Implement 'coalesce' for any mix of scalar/array arguments for any fixed-width type
+template <typename Type>
+Status ExecArrayCoalesce(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ ArrayData* output = out->mutable_array();
+ const int64_t out_offset = output->offset;
+ // Use output validity buffer as mask to decide what values to copy
+ uint8_t* out_valid = output->buffers[0]->mutable_data();
+
+ // Clear output validity buffer - no values are set initially
+ BitUtil::SetBitsTo(out_valid, out_offset, batch.length, false);
+ uint8_t* out_values = output->buffers[1]->mutable_data();
+
+ for (const auto& datum : batch.values) {
+ if (datum.null_count() == 0) {
+ // Valid scalar, or all-valid array
+ CopyValuesAllValid<Type>(datum, out_valid, out_values, out_offset, batch.length);
+ break;
+ } else if (datum.is_array()) {
+ // Array with nulls
+ const ArrayData& arr = *datum.array();
+ const int64_t in_offset = arr.offset;
+ const int64_t in_null_count = arr.null_count;
+ DCHECK_GT(in_null_count, 0); // computed in datum.null_count()
+ const DataType& type = *arr.type;
+ const uint8_t* in_valid = arr.buffers[0]->data();
+ const uint8_t* in_values = arr.buffers[1]->data();
+
+ if (in_null_count < 0.8 * batch.length) {
+ // The input is not mostly null, we deem it more efficient to
+ // copy values even underlying null slots instead of the more
+ // expensive bitmasking using BinaryBitBlockCounter.
+ BitRunReader bit_reader(out_valid, out_offset, batch.length);
+ int64_t offset = 0;
+ while (true) {
+ const auto run = bit_reader.NextRun();
+ if (run.length == 0) {
+ break;
+ }
+ if (!run.set) {
+ // Copy from input
+ CopyFixedWidth<Type>::CopyArray(type, in_values, in_offset + offset,
+ run.length, out_values, out_offset + offset);
+ }
+ offset += run.length;
+ }
+ arrow::internal::BitmapOr(out_valid, out_offset, in_valid, in_offset,
+ batch.length, out_offset, out_valid);
+ } else {
+ BinaryBitBlockCounter counter(in_valid, in_offset, out_valid, out_offset,
+ batch.length);
+ int64_t offset = 0;
+ while (offset < batch.length) {
+ const auto block = counter.NextAndNotWord();
+ if (block.AllSet()) {
+ CopyValues<Type>(datum, offset, block.length, out_valid, out_values,
+ out_offset + offset);
+ } else if (block.popcount) {
+ for (int64_t j = 0; j < block.length; ++j) {
+ if (!BitUtil::GetBit(out_valid, out_offset + offset + j) &&
+ BitUtil::GetBit(in_valid, in_offset + offset + j)) {
+ // This version lets us avoid calling MayHaveNulls() on every iteration
+ // (which does an atomic load and can add up)
+ CopyOneArrayValue<Type>(type, in_valid, in_values, in_offset + offset + j,
+ out_valid, out_values, out_offset + offset + j);
+ }
+ }
+ }
+ offset += block.length;
+ }
+ }
+ }
+ }
+
+ // Initialize any remaining null slots (uninitialized memory)
+ InitializeNullSlots(*out->type(), out_valid, out_values, out_offset, batch.length);
+ return Status::OK();
+}
+
+// Special case: implement 'coalesce' for an array and a scalar for any
+// fixed-width type (a 'fill_null' operation)
+template <typename Type>
+Status ExecArrayScalarCoalesce(KernelContext* ctx, Datum left, Datum right,
+ int64_t length, Datum* out) {
+ ArrayData* output = out->mutable_array();
+ const int64_t out_offset = output->offset;
+ uint8_t* out_valid = output->buffers[0]->mutable_data();
+ uint8_t* out_values = output->buffers[1]->mutable_data();
+
+ const ArrayData& left_arr = *left.array();
+ const uint8_t* left_valid = left_arr.buffers[0]->data();
+ const uint8_t* left_values = left_arr.buffers[1]->data();
+ const Scalar& right_scalar = *right.scalar();
+
+ if (left.null_count() < length * 0.2) {
+ // There are less than 20% nulls in the left array, so first copy
+ // the left values, then fill any nulls with the right value
+ CopyFixedWidth<Type>::CopyArray(*left_arr.type, left_values, left_arr.offset, length,
+ out_values, out_offset);
+
+ BitRunReader reader(left_valid, left_arr.offset, left_arr.length);
+ int64_t offset = 0;
+ while (true) {
+ const auto run = reader.NextRun();
+ if (run.length == 0) break;
+ if (!run.set) {
+ // All from right
+ CopyFixedWidth<Type>::CopyScalar(right_scalar, run.length, out_values,
+ out_offset + offset);
+ }
+ offset += run.length;
+ }
+ DCHECK_EQ(offset, length);
+ } else {
+ BitRunReader reader(left_valid, left_arr.offset, left_arr.length);
+ int64_t offset = 0;
+ while (true) {
+ const auto run = reader.NextRun();
+ if (run.length == 0) break;
+ if (run.set) {
+ // All from left
+ CopyFixedWidth<Type>::CopyArray(*left_arr.type, left_values,
+ left_arr.offset + offset, run.length, out_values,
+ out_offset + offset);
+ } else {
+ // All from right
+ CopyFixedWidth<Type>::CopyScalar(right_scalar, run.length, out_values,
+ out_offset + offset);
+ }
+ offset += run.length;
+ }
+ DCHECK_EQ(offset, length);
+ }
+
+ if (right_scalar.is_valid || !left_valid) {
+ BitUtil::SetBitsTo(out_valid, out_offset, length, true);
+ } else {
+ arrow::internal::CopyBitmap(left_valid, left_arr.offset, length, out_valid,
+ out_offset);
+ }
+ return Status::OK();
+}
+
+// Special case: implement 'coalesce' for any 2 arguments for any fixed-width
+// type (a 'fill_null' operation)
+template <typename Type>
+Status ExecBinaryCoalesce(KernelContext* ctx, Datum left, Datum right, int64_t length,
+ Datum* out) {
+ if (left.is_scalar() && right.is_scalar()) {
+ // Both scalar
+ *out = left.scalar()->is_valid ? left : right;
+ return Status::OK();
+ }
+
+ ArrayData* output = out->mutable_array();
+ const int64_t out_offset = output->offset;
+ uint8_t* out_valid = output->buffers[0]->mutable_data();
+ uint8_t* out_values = output->buffers[1]->mutable_data();
+
+ const int64_t left_null_count = left.null_count();
+ const int64_t right_null_count = right.null_count();
+
+ if (left.is_scalar()) {
+ // (Scalar, Any)
+ CopyValues<Type>(left.scalar()->is_valid ? left : right, /*in_offset=*/0, length,
+ out_valid, out_values, out_offset);
+ return Status::OK();
+ } else if (left_null_count == 0) {
+ // LHS is array without nulls. Must copy (since we preallocate)
+ CopyValues<Type>(left, /*in_offset=*/0, length, out_valid, out_values, out_offset);
+ return Status::OK();
+ } else if (right.is_scalar()) {
+ // (Array, Scalar)
+ return ExecArrayScalarCoalesce<Type>(ctx, left, right, length, out);
+ }
+
+ // (Array, Array)
+ const ArrayData& left_arr = *left.array();
+ const ArrayData& right_arr = *right.array();
+ const uint8_t* left_valid = left_arr.buffers[0]->data();
+ const uint8_t* left_values = left_arr.buffers[1]->data();
+ const uint8_t* right_valid =
+ right_null_count > 0 ? right_arr.buffers[0]->data() : nullptr;
+ const uint8_t* right_values = right_arr.buffers[1]->data();
+
+ BitRunReader bit_reader(left_valid, left_arr.offset, left_arr.length);
+ int64_t offset = 0;
+ while (true) {
+ const auto run = bit_reader.NextRun();
+ if (run.length == 0) {
+ break;
+ }
+ if (run.set) {
+ // All from left
+ CopyFixedWidth<Type>::CopyArray(*left_arr.type, left_values,
+ left_arr.offset + offset, run.length, out_values,
+ out_offset + offset);
+ } else {
+ // All from right
+ CopyFixedWidth<Type>::CopyArray(*right_arr.type, right_values,
+ right_arr.offset + offset, run.length, out_values,
+ out_offset + offset);
+ }
+ offset += run.length;
+ }
+ DCHECK_EQ(offset, length);
+
+ if (right_null_count == 0) {
+ BitUtil::SetBitsTo(out_valid, out_offset, length, true);
+ } else {
+ arrow::internal::BitmapOr(left_valid, left_arr.offset, right_valid, right_arr.offset,
+ length, out_offset, out_valid);
+ }
+ return Status::OK();
+}
+
+template <typename AppendScalar>
+static Status ExecVarWidthCoalesceImpl(KernelContext* ctx, const ExecBatch& batch,
+ Datum* out,
+ std::function<Status(ArrayBuilder*)> reserve_data,
+ AppendScalar append_scalar) {
+ // Special case: grab any leading non-null scalar or array arguments
+ for (const auto& datum : batch.values) {
+ if (datum.is_scalar()) {
+ if (!datum.scalar()->is_valid) continue;
+ ARROW_ASSIGN_OR_RAISE(
+ *out, MakeArrayFromScalar(*datum.scalar(), batch.length, ctx->memory_pool()));
+ return Status::OK();
+ } else if (datum.is_array() && !datum.array()->MayHaveNulls()) {
+ *out = datum;
+ return Status::OK();
+ }
+ break;
+ }
+ ArrayData* output = out->mutable_array();
+ std::unique_ptr<ArrayBuilder> raw_builder;
+ RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder));
+ RETURN_NOT_OK(raw_builder->Reserve(batch.length));
+ RETURN_NOT_OK(reserve_data(raw_builder.get()));
+
+ for (int64_t i = 0; i < batch.length; i++) {
+ bool set = false;
+ for (const auto& datum : batch.values) {
+ if (datum.is_scalar()) {
+ if (datum.scalar()->is_valid) {
+ RETURN_NOT_OK(append_scalar(raw_builder.get(), *datum.scalar()));
+ set = true;
+ break;
+ }
+ } else {
+ const ArrayData& source = *datum.array();
+ if (!source.MayHaveNulls() ||
+ BitUtil::GetBit(source.buffers[0]->data(), source.offset + i)) {
+ RETURN_NOT_OK(raw_builder->AppendArraySlice(source, i, /*length=*/1));
+ set = true;
+ break;
+ }
+ }
+ }
+ if (!set) RETURN_NOT_OK(raw_builder->AppendNull());
+ }
+ ARROW_ASSIGN_OR_RAISE(auto temp_output, raw_builder->Finish());
+ *output = *temp_output->data();
+ output->type = batch[0].type();
+ return Status::OK();
+}
+
+static Status ExecVarWidthCoalesce(KernelContext* ctx, const ExecBatch& batch, Datum* out,
+ std::function<Status(ArrayBuilder*)> reserve_data) {
+ return ExecVarWidthCoalesceImpl(ctx, batch, out, std::move(reserve_data),
+ [](ArrayBuilder* builder, const Scalar& scalar) {
+ return builder->AppendScalar(scalar);
+ });
+}
+
+template <typename Type, typename Enable = void>
+struct CoalesceFunctor {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (!TypeTraits<Type>::is_parameter_free) {
+ RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[0], batch.values.size()));
+ }
+ // Special case for two arguments (since "fill_null" is a common operation)
+ if (batch.num_values() == 2) {
+ return ExecBinaryCoalesce<Type>(ctx, batch[0], batch[1], batch.length, out);
+ }
+ for (const auto& datum : batch.values) {
+ if (datum.is_array()) {
+ return ExecArrayCoalesce<Type>(ctx, batch, out);
+ }
+ }
+ return ExecScalarCoalesce(ctx, batch, out);
+ }
+};
+
+template <>
+struct CoalesceFunctor<NullType> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return Status::OK();
+ }
+};
+
+template <typename Type>
+struct CoalesceFunctor<Type, enable_if_base_binary<Type>> {
+ using offset_type = typename Type::offset_type;
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ using BuilderType = typename TypeTraits<Type>::BuilderType;
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch.num_values() == 2 && batch.values[0].is_array() &&
+ batch.values[1].is_scalar()) {
+ // Specialized implementation for common case ('fill_null' operation)
+ return ExecArrayScalar(ctx, *batch.values[0].array(), *batch.values[1].scalar(),
+ out);
+ }
+ for (const auto& datum : batch.values) {
+ if (datum.is_array()) {
+ return ExecArray(ctx, batch, out);
+ }
+ }
+ return ExecScalarCoalesce(ctx, batch, out);
+ }
+
+ static Status ExecArrayScalar(KernelContext* ctx, const ArrayData& left,
+ const Scalar& right, Datum* out) {
+ const int64_t null_count = left.GetNullCount();
+ if (null_count == 0 || !right.is_valid) {
+ *out = left;
+ return Status::OK();
+ }
+ ArrayData* output = out->mutable_array();
+ BuilderType builder(left.type, ctx->memory_pool());
+ RETURN_NOT_OK(builder.Reserve(left.length));
+ const auto& scalar = checked_cast<const BaseBinaryScalar&>(right);
+ const offset_type* offsets = left.GetValues<offset_type>(1);
+ const int64_t data_reserve = static_cast<int64_t>(offsets[left.length] - offsets[0]) +
+ null_count * scalar.value->size();
+ if (data_reserve > std::numeric_limits<offset_type>::max()) {
+ return Status::CapacityError(
+ "Result will not fit in a 32-bit binary-like array, convert to large type");
+ }
+ RETURN_NOT_OK(builder.ReserveData(static_cast<offset_type>(data_reserve)));
+
+ util::string_view fill_value(*scalar.value);
+ VisitArrayDataInline<Type>(
+ left, [&](util::string_view s) { builder.UnsafeAppend(s); },
+ [&]() { builder.UnsafeAppend(fill_value); });
+
+ ARROW_ASSIGN_OR_RAISE(auto temp_output, builder.Finish());
+ *output = *temp_output->data();
+ output->type = left.type;
+ return Status::OK();
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return ExecVarWidthCoalesceImpl(
+ ctx, batch, out,
+ [&](ArrayBuilder* builder) {
+ int64_t reservation = 0;
+ for (const auto& datum : batch.values) {
+ if (datum.is_array()) {
+ const ArrayType array(datum.array());
+ reservation = std::max<int64_t>(reservation, array.total_values_length());
+ } else {
+ const auto& scalar = *datum.scalar();
+ if (scalar.is_valid) {
+ const int64_t size = UnboxScalar<Type>::Unbox(scalar).size();
+ reservation = std::max<int64_t>(reservation, batch.length * size);
+ }
+ }
+ }
+ return checked_cast<BuilderType*>(builder)->ReserveData(reservation);
+ },
+ [&](ArrayBuilder* builder, const Scalar& scalar) {
+ return checked_cast<BuilderType*>(builder)->Append(
+ UnboxScalar<Type>::Unbox(scalar));
+ });
+ }
+};
+
+template <typename Type>
+struct CoalesceFunctor<
+ Type, enable_if_t<is_nested_type<Type>::value && !is_union_type<Type>::value>> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[0], batch.values.size()));
+ for (const auto& datum : batch.values) {
+ if (datum.is_array()) {
+ return ExecArray(ctx, batch, out);
+ }
+ }
+ return ExecScalarCoalesce(ctx, batch, out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ std::function<Status(ArrayBuilder*)> reserve_data = ReserveNoData;
+ return ExecVarWidthCoalesce(ctx, batch, out, reserve_data);
+ }
+};
+
+template <typename Type>
+struct CoalesceFunctor<Type, enable_if_union<Type>> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // Unions don't have top-level nulls, so a specialized implementation is needed
+ RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[0], batch.values.size()));
+
+ for (const auto& datum : batch.values) {
+ if (datum.is_array()) {
+ return ExecArray(ctx, batch, out);
+ }
+ }
+ return ExecScalar(ctx, batch, out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ ArrayData* output = out->mutable_array();
+ std::unique_ptr<ArrayBuilder> raw_builder;
+ RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder));
+ RETURN_NOT_OK(raw_builder->Reserve(batch.length));
+
+ const UnionType& type = checked_cast<const UnionType&>(*out->type());
+ for (int64_t i = 0; i < batch.length; i++) {
+ bool set = false;
+ for (const auto& datum : batch.values) {
+ if (datum.is_scalar()) {
+ const auto& scalar = checked_cast<const UnionScalar&>(*datum.scalar());
+ if (scalar.is_valid && scalar.value->is_valid) {
+ RETURN_NOT_OK(raw_builder->AppendScalar(scalar));
+ set = true;
+ break;
+ }
+ } else {
+ const ArrayData& source = *datum.array();
+ // Peek at the relevant child array's validity bitmap
+ if (std::is_same<Type, SparseUnionType>::value) {
+ const int8_t type_id = source.GetValues<int8_t>(1)[i];
+ const int child_id = type.child_ids()[type_id];
+ const ArrayData& child = *source.child_data[child_id];
+ if (!child.MayHaveNulls() ||
+ BitUtil::GetBit(child.buffers[0]->data(),
+ source.offset + child.offset + i)) {
+ RETURN_NOT_OK(raw_builder->AppendArraySlice(source, i, /*length=*/1));
+ set = true;
+ break;
+ }
+ } else {
+ const int8_t type_id = source.GetValues<int8_t>(1)[i];
+ const int32_t offset = source.GetValues<int32_t>(2)[i];
+ const int child_id = type.child_ids()[type_id];
+ const ArrayData& child = *source.child_data[child_id];
+ if (!child.MayHaveNulls() ||
+ BitUtil::GetBit(child.buffers[0]->data(), child.offset + offset)) {
+ RETURN_NOT_OK(raw_builder->AppendArraySlice(source, i, /*length=*/1));
+ set = true;
+ break;
+ }
+ }
+ }
+ }
+ if (!set) RETURN_NOT_OK(raw_builder->AppendNull());
+ }
+ ARROW_ASSIGN_OR_RAISE(auto temp_output, raw_builder->Finish());
+ *output = *temp_output->data();
+ return Status::OK();
+ }
+
+ static Status ExecScalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ for (const auto& datum : batch.values) {
+ const auto& scalar = checked_cast<const UnionScalar&>(*datum.scalar());
+ // Union scalars can have top-level validity
+ if (scalar.is_valid && scalar.value->is_valid) {
+ *out = datum;
+ break;
+ }
+ }
+ return Status::OK();
+ }
+};
+
+template <typename Type>
+Status ExecScalarChoose(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& index_scalar = *batch[0].scalar();
+ if (!index_scalar.is_valid) {
+ if (out->is_array()) {
+ auto source = MakeNullScalar(out->type());
+ ArrayData* output = out->mutable_array();
+ CopyValues<Type>(source, /*row=*/0, batch.length,
+ output->GetMutableValues<uint8_t>(0, /*absolute_offset=*/0),
+ output->GetMutableValues<uint8_t>(1, /*absolute_offset=*/0),
+ output->offset);
+ }
+ return Status::OK();
+ }
+ auto index = UnboxScalar<Int64Type>::Unbox(index_scalar);
+ if (index < 0 || static_cast<size_t>(index + 1) >= batch.values.size()) {
+ return Status::IndexError("choose: index ", index, " out of range");
+ }
+ auto source = batch.values[index + 1];
+ if (out->is_scalar()) {
+ *out = source;
+ } else {
+ ArrayData* output = out->mutable_array();
+ CopyValues<Type>(source, /*row=*/0, batch.length,
+ output->GetMutableValues<uint8_t>(0, /*absolute_offset=*/0),
+ output->GetMutableValues<uint8_t>(1, /*absolute_offset=*/0),
+ output->offset);
+ }
+ return Status::OK();
+}
+
+template <typename Type>
+Status ExecArrayChoose(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ ArrayData* output = out->mutable_array();
+ const int64_t out_offset = output->offset;
+ // Need a null bitmap if any input has nulls
+ uint8_t* out_valid = nullptr;
+ if (std::any_of(batch.values.begin(), batch.values.end(),
+ [](const Datum& d) { return d.null_count() > 0; })) {
+ out_valid = output->buffers[0]->mutable_data();
+ } else {
+ BitUtil::SetBitsTo(output->buffers[0]->mutable_data(), out_offset, batch.length,
+ true);
+ }
+ uint8_t* out_values = output->buffers[1]->mutable_data();
+ int64_t row = 0;
+ return VisitArrayValuesInline<Int64Type>(
+ *batch[0].array(),
+ [&](int64_t index) {
+ if (index < 0 || static_cast<size_t>(index + 1) >= batch.values.size()) {
+ return Status::IndexError("choose: index ", index, " out of range");
+ }
+ const auto& source = batch.values[index + 1];
+ CopyOneValue<Type>(source, row, out_valid, out_values, out_offset + row);
+ row++;
+ return Status::OK();
+ },
+ [&]() {
+ // Index is null, but we should still initialize the output with some value
+ const auto& source = batch.values[1];
+ CopyOneValue<Type>(source, row, out_valid, out_values, out_offset + row);
+ BitUtil::ClearBit(out_valid, out_offset + row);
+ row++;
+ return Status::OK();
+ });
+}
+
+template <typename Type, typename Enable = void>
+struct ChooseFunctor {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch.values[0].is_scalar()) {
+ return ExecScalarChoose<Type>(ctx, batch, out);
+ }
+ return ExecArrayChoose<Type>(ctx, batch, out);
+ }
+};
+
+template <>
+struct ChooseFunctor<NullType> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return Status::OK();
+ }
+};
+
+template <typename Type>
+struct ChooseFunctor<Type, enable_if_base_binary<Type>> {
+ using offset_type = typename Type::offset_type;
+ using BuilderType = typename TypeTraits<Type>::BuilderType;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch.values[0].is_scalar()) {
+ const auto& index_scalar = *batch[0].scalar();
+ if (!index_scalar.is_valid) {
+ if (out->is_array()) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto temp_array,
+ MakeArrayOfNull(out->type(), batch.length, ctx->memory_pool()));
+ *out->mutable_array() = *temp_array->data();
+ }
+ return Status::OK();
+ }
+ auto index = UnboxScalar<Int64Type>::Unbox(index_scalar);
+ if (index < 0 || static_cast<size_t>(index + 1) >= batch.values.size()) {
+ return Status::IndexError("choose: index ", index, " out of range");
+ }
+ auto source = batch.values[index + 1];
+ if (source.is_scalar() && out->is_array()) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto temp_array,
+ MakeArrayFromScalar(*source.scalar(), batch.length, ctx->memory_pool()));
+ *out->mutable_array() = *temp_array->data();
+ } else {
+ *out = source;
+ }
+ return Status::OK();
+ }
+
+ // Row-wise implementation
+ BuilderType builder(out->type(), ctx->memory_pool());
+ RETURN_NOT_OK(builder.Reserve(batch.length));
+ int64_t reserve_data = 0;
+ for (const auto& value : batch.values) {
+ if (value.is_scalar()) {
+ if (!value.scalar()->is_valid) continue;
+ const auto row_length =
+ checked_cast<const BaseBinaryScalar&>(*value.scalar()).value->size();
+ reserve_data = std::max<int64_t>(reserve_data, batch.length * row_length);
+ continue;
+ }
+ const ArrayData& arr = *value.array();
+ const offset_type* offsets = arr.GetValues<offset_type>(1);
+ const offset_type values_length = offsets[arr.length] - offsets[0];
+ reserve_data = std::max<int64_t>(reserve_data, values_length);
+ }
+ RETURN_NOT_OK(builder.ReserveData(reserve_data));
+ int64_t row = 0;
+ RETURN_NOT_OK(VisitArrayValuesInline<Int64Type>(
+ *batch[0].array(),
+ [&](int64_t index) {
+ if (index < 0 || static_cast<size_t>(index + 1) >= batch.values.size()) {
+ return Status::IndexError("choose: index ", index, " out of range");
+ }
+ const auto& source = batch.values[index + 1];
+ return CopyValue(source, &builder, row++);
+ },
+ [&]() {
+ row++;
+ return builder.AppendNull();
+ }));
+ auto actual_type = out->type();
+ std::shared_ptr<Array> temp_output;
+ RETURN_NOT_OK(builder.Finish(&temp_output));
+ ArrayData* output = out->mutable_array();
+ *output = *temp_output->data();
+ // Builder type != logical type due to GenerateTypeAgnosticVarBinaryBase
+ output->type = std::move(actual_type);
+ return Status::OK();
+ }
+
+ static Status CopyValue(const Datum& datum, BuilderType* builder, int64_t row) {
+ if (datum.is_scalar()) {
+ const auto& scalar = checked_cast<const BaseBinaryScalar&>(*datum.scalar());
+ if (!scalar.value) return builder->AppendNull();
+ return builder->Append(scalar.value->data(),
+ static_cast<offset_type>(scalar.value->size()));
+ }
+ const ArrayData& source = *datum.array();
+ if (!source.MayHaveNulls() ||
+ BitUtil::GetBit(source.buffers[0]->data(), source.offset + row)) {
+ const uint8_t* data = source.buffers[2]->data();
+ const offset_type* offsets = source.GetValues<offset_type>(1);
+ const offset_type offset0 = offsets[row];
+ const offset_type offset1 = offsets[row + 1];
+ return builder->Append(data + offset0, offset1 - offset0);
+ }
+ return builder->AppendNull();
+ }
+};
+
+struct ChooseFunction : ScalarFunction {
+ using ScalarFunction::ScalarFunction;
+
+ Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+ // The first argument is always int64 or promoted to it. The kernel is dispatched
+ // based on the type of the rest of the arguments.
+ RETURN_NOT_OK(CheckArity(*values));
+ EnsureDictionaryDecoded(values);
+ if (values->front().type->id() != Type::INT64) {
+ values->front().type = int64();
+ }
+ if (auto type = CommonNumeric(values->data() + 1, values->size() - 1)) {
+ for (auto it = values->begin() + 1; it != values->end(); it++) {
+ it->type = type;
+ }
+ }
+ if (auto kernel = DispatchExactImpl(this, {values->back()})) return kernel;
+ return arrow::compute::detail::NoMatchingKernel(this, *values);
+ }
+};
+
+void AddCaseWhenKernel(const std::shared_ptr<CaseWhenFunction>& scalar_function,
+ detail::GetTypeId get_id, ArrayKernelExec exec) {
+ ScalarKernel kernel(
+ KernelSignature::Make({InputType(Type::STRUCT), InputType(get_id.id)},
+ OutputType(LastType),
+ /*is_varargs=*/true),
+ exec);
+ if (is_fixed_width(get_id.id)) {
+ kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::PREALLOCATE;
+ kernel.can_write_into_slices = true;
+ } else {
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ kernel.can_write_into_slices = false;
+ }
+ DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
+}
+
+void AddPrimitiveCaseWhenKernels(const std::shared_ptr<CaseWhenFunction>& scalar_function,
+ const std::vector<std::shared_ptr<DataType>>& types) {
+ for (auto&& type : types) {
+ auto exec = GenerateTypeAgnosticPrimitive<CaseWhenFunctor>(*type);
+ AddCaseWhenKernel(scalar_function, type, std::move(exec));
+ }
+}
+
+void AddBinaryCaseWhenKernels(const std::shared_ptr<CaseWhenFunction>& scalar_function,
+ const std::vector<std::shared_ptr<DataType>>& types) {
+ for (auto&& type : types) {
+ auto exec = GenerateTypeAgnosticVarBinaryBase<CaseWhenFunctor>(*type);
+ AddCaseWhenKernel(scalar_function, type, std::move(exec));
+ }
+}
+
+void AddCoalesceKernel(const std::shared_ptr<ScalarFunction>& scalar_function,
+ detail::GetTypeId get_id, ArrayKernelExec exec) {
+ ScalarKernel kernel(KernelSignature::Make({InputType(get_id.id)}, OutputType(FirstType),
+ /*is_varargs=*/true),
+ exec);
+ kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::PREALLOCATE;
+ kernel.can_write_into_slices = is_fixed_width(get_id.id);
+ DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
+}
+
+void AddPrimitiveCoalesceKernels(const std::shared_ptr<ScalarFunction>& scalar_function,
+ const std::vector<std::shared_ptr<DataType>>& types) {
+ for (auto&& type : types) {
+ auto exec = GenerateTypeAgnosticPrimitive<CoalesceFunctor>(*type);
+ AddCoalesceKernel(scalar_function, type, std::move(exec));
+ }
+}
+
+void AddChooseKernel(const std::shared_ptr<ScalarFunction>& scalar_function,
+ detail::GetTypeId get_id, ArrayKernelExec exec) {
+ ScalarKernel kernel(
+ KernelSignature::Make({Type::INT64, InputType(get_id.id)}, OutputType(LastType),
+ /*is_varargs=*/true),
+ exec);
+ kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::PREALLOCATE;
+ kernel.can_write_into_slices = is_fixed_width(get_id.id);
+ DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
+}
+
+void AddPrimitiveChooseKernels(const std::shared_ptr<ScalarFunction>& scalar_function,
+ const std::vector<std::shared_ptr<DataType>>& types) {
+ for (auto&& type : types) {
+ auto exec = GenerateTypeAgnosticPrimitive<ChooseFunctor>(*type);
+ AddChooseKernel(scalar_function, type, std::move(exec));
+ }
+}
+
+const FunctionDoc if_else_doc{"Choose values based on a condition",
+ ("`cond` must be a Boolean scalar/ array. \n`left` or "
+ "`right` must be of the same type scalar/ array.\n"
+ "`null` values in `cond` will be promoted to the"
+ " output."),
+ {"cond", "left", "right"}};
+
+const FunctionDoc case_when_doc{
+ "Choose values based on multiple conditions",
+ ("`cond` must be a struct of Boolean values. `cases` can be a mix "
+ "of scalar and array arguments (of any type, but all must be the "
+ "same type or castable to a common type), with either exactly one "
+ "datum per child of `cond`, or one more `cases` than children of "
+ "`cond` (in which case we have an \"else\" value).\n"
+ "Each row of the output will be the corresponding value of the "
+ "first datum in `cases` for which the corresponding child of `cond` "
+ "is true, or otherwise the \"else\" value (if given), or null. "
+ "Essentially, this implements a switch-case or if-else, if-else... "
+ "statement."),
+ {"cond", "*cases"}};
+
+const FunctionDoc coalesce_doc{
+ "Select the first non-null value in each slot",
+ ("Each row of the output will be the value from the first corresponding input "
+ "for which the value is not null. If all inputs are null in a row, the output "
+ "will be null."),
+ {"*values"}};
+
+const FunctionDoc choose_doc{
+ "Given indices and arrays, choose the value from the corresponding array for each "
+ "index",
+ ("For each row, the value of the first argument is used as a 0-based index into the "
+ "rest of the arguments (i.e. index 0 selects the second argument). The output value "
+ "is the corresponding value of the selected argument.\n"
+ "If an index is null, the output will be null."),
+ {"indices", "*values"}};
+} // namespace
+
+void RegisterScalarIfElse(FunctionRegistry* registry) {
+ {
+ auto func =
+ std::make_shared<IfElseFunction>("if_else", Arity::Ternary(), &if_else_doc);
+
+ AddPrimitiveIfElseKernels(func, NumericTypes());
+ AddPrimitiveIfElseKernels(func, TemporalTypes());
+ AddPrimitiveIfElseKernels(func, IntervalTypes());
+ AddPrimitiveIfElseKernels(func, {boolean()});
+ AddNullIfElseKernel(func);
+ AddBinaryIfElseKernels(func, BaseBinaryTypes());
+ AddFixedWidthIfElseKernel<FixedSizeBinaryType>(func);
+ AddFixedWidthIfElseKernel<Decimal128Type>(func);
+ AddFixedWidthIfElseKernel<Decimal256Type>(func);
+ AddNestedIfElseKernels(func);
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+ {
+ auto func = std::make_shared<CaseWhenFunction>(
+ "case_when", Arity::VarArgs(/*min_args=*/2), &case_when_doc);
+ AddPrimitiveCaseWhenKernels(func, NumericTypes());
+ AddPrimitiveCaseWhenKernels(func, TemporalTypes());
+ AddPrimitiveCaseWhenKernels(func, IntervalTypes());
+ AddPrimitiveCaseWhenKernels(func, {boolean(), null()});
+ AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY,
+ CaseWhenFunctor<FixedSizeBinaryType>::Exec);
+ AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor<Decimal128Type>::Exec);
+ AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<Decimal256Type>::Exec);
+ AddBinaryCaseWhenKernels(func, BaseBinaryTypes());
+ AddCaseWhenKernel(func, Type::FIXED_SIZE_LIST,
+ CaseWhenFunctor<FixedSizeListType>::Exec);
+ AddCaseWhenKernel(func, Type::LIST, CaseWhenFunctor<ListType>::Exec);
+ AddCaseWhenKernel(func, Type::LARGE_LIST, CaseWhenFunctor<LargeListType>::Exec);
+ AddCaseWhenKernel(func, Type::MAP, CaseWhenFunctor<MapType>::Exec);
+ AddCaseWhenKernel(func, Type::STRUCT, CaseWhenFunctor<StructType>::Exec);
+ AddCaseWhenKernel(func, Type::DENSE_UNION, CaseWhenFunctor<DenseUnionType>::Exec);
+ AddCaseWhenKernel(func, Type::SPARSE_UNION, CaseWhenFunctor<SparseUnionType>::Exec);
+ AddCaseWhenKernel(func, Type::DICTIONARY, CaseWhenFunctor<DictionaryType>::Exec);
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+ {
+ auto func = std::make_shared<CoalesceFunction>(
+ "coalesce", Arity::VarArgs(/*min_args=*/1), &coalesce_doc);
+ AddPrimitiveCoalesceKernels(func, NumericTypes());
+ AddPrimitiveCoalesceKernels(func, TemporalTypes());
+ AddPrimitiveCoalesceKernels(func, IntervalTypes());
+ AddPrimitiveCoalesceKernels(func, {boolean(), null()});
+ AddCoalesceKernel(func, Type::FIXED_SIZE_BINARY,
+ CoalesceFunctor<FixedSizeBinaryType>::Exec);
+ AddCoalesceKernel(func, Type::DECIMAL128, CoalesceFunctor<Decimal128Type>::Exec);
+ AddCoalesceKernel(func, Type::DECIMAL256, CoalesceFunctor<Decimal256Type>::Exec);
+ for (const auto& ty : BaseBinaryTypes()) {
+ AddCoalesceKernel(func, ty, GenerateTypeAgnosticVarBinaryBase<CoalesceFunctor>(ty));
+ }
+ AddCoalesceKernel(func, Type::FIXED_SIZE_LIST,
+ CoalesceFunctor<FixedSizeListType>::Exec);
+ AddCoalesceKernel(func, Type::LIST, CoalesceFunctor<ListType>::Exec);
+ AddCoalesceKernel(func, Type::LARGE_LIST, CoalesceFunctor<LargeListType>::Exec);
+ AddCoalesceKernel(func, Type::MAP, CoalesceFunctor<MapType>::Exec);
+ AddCoalesceKernel(func, Type::STRUCT, CoalesceFunctor<StructType>::Exec);
+ AddCoalesceKernel(func, Type::DENSE_UNION, CoalesceFunctor<DenseUnionType>::Exec);
+ AddCoalesceKernel(func, Type::SPARSE_UNION, CoalesceFunctor<SparseUnionType>::Exec);
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+ {
+ auto func = std::make_shared<ChooseFunction>("choose", Arity::VarArgs(/*min_args=*/2),
+ &choose_doc);
+ AddPrimitiveChooseKernels(func, NumericTypes());
+ AddPrimitiveChooseKernels(func, TemporalTypes());
+ AddPrimitiveChooseKernels(func, IntervalTypes());
+ AddPrimitiveChooseKernels(func, {boolean(), null()});
+ AddChooseKernel(func, Type::FIXED_SIZE_BINARY,
+ ChooseFunctor<FixedSizeBinaryType>::Exec);
+ AddChooseKernel(func, Type::DECIMAL128, ChooseFunctor<Decimal128Type>::Exec);
+ AddChooseKernel(func, Type::DECIMAL256, ChooseFunctor<Decimal256Type>::Exec);
+ for (const auto& ty : BaseBinaryTypes()) {
+ AddChooseKernel(func, ty, GenerateTypeAgnosticVarBinaryBase<ChooseFunctor>(ty));
+ }
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc
new file mode 100644
index 000000000..b6d6bf6e4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc
@@ -0,0 +1,457 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <benchmark/benchmark.h>
+
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+namespace compute {
+
+const int64_t kNumItems = 1024 * 1024;
+const int64_t kFewItems = 64 * 1024;
+
+template <typename Type, typename Enable = void>
+struct GetBytesProcessed {};
+
+template <>
+struct GetBytesProcessed<BooleanType> {
+ static int64_t Get(const std::shared_ptr<Array>& arr) { return arr->length() / 8; }
+};
+
+template <typename Type>
+struct GetBytesProcessed<Type, enable_if_number<Type>> {
+ static int64_t Get(const std::shared_ptr<Array>& arr) {
+ using CType = typename Type::c_type;
+ return arr->length() * sizeof(CType);
+ }
+};
+
+template <typename Type>
+struct GetBytesProcessed<Type, enable_if_base_binary<Type>> {
+ static int64_t Get(const std::shared_ptr<Array>& arr) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ using OffsetType = typename TypeTraits<Type>::OffsetType::c_type;
+ return arr->length() * sizeof(OffsetType) +
+ std::static_pointer_cast<ArrayType>(arr)->total_values_length();
+ }
+};
+
+template <typename Type>
+static void IfElseBench(benchmark::State& state) {
+ auto type = TypeTraits<Type>::type_singleton();
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+
+ int64_t len = state.range(0);
+ int64_t offset = state.range(1);
+
+ random::RandomArrayGenerator rand(/*seed=*/0);
+
+ auto cond = std::static_pointer_cast<BooleanArray>(
+ rand.ArrayOf(boolean(), len, /*null_probability=*/0.01))
+ ->Slice(offset);
+ auto left = std::static_pointer_cast<ArrayType>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01))
+ ->Slice(offset);
+ auto right = std::static_pointer_cast<ArrayType>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01))
+ ->Slice(offset);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(IfElse(cond, left, right));
+ }
+
+ state.SetBytesProcessed(state.iterations() *
+ (GetBytesProcessed<BooleanType>::Get(cond) +
+ GetBytesProcessed<Type>::Get(left) +
+ GetBytesProcessed<Type>::Get(right)));
+}
+
+template <typename Type>
+static void IfElseBenchContiguous(benchmark::State& state) {
+ auto type = TypeTraits<Type>::type_singleton();
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+
+ int64_t len = state.range(0);
+ int64_t offset = state.range(1);
+
+ ASSERT_OK_AND_ASSIGN(auto temp1, MakeArrayFromScalar(BooleanScalar(true), len / 2));
+ ASSERT_OK_AND_ASSIGN(auto temp2,
+ MakeArrayFromScalar(BooleanScalar(false), len - len / 2));
+ ASSERT_OK_AND_ASSIGN(auto concat, Concatenate({temp1, temp2}));
+ auto cond = std::static_pointer_cast<BooleanArray>(concat)->Slice(offset);
+
+ random::RandomArrayGenerator rand(/*seed=*/0);
+ auto left = std::static_pointer_cast<ArrayType>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01))
+ ->Slice(offset);
+ auto right = std::static_pointer_cast<ArrayType>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01))
+ ->Slice(offset);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(IfElse(cond, left, right));
+ }
+
+ state.SetBytesProcessed(state.iterations() *
+ (GetBytesProcessed<BooleanType>::Get(cond) +
+ GetBytesProcessed<Type>::Get(left) +
+ GetBytesProcessed<Type>::Get(right)));
+}
+
+static void IfElseBench64(benchmark::State& state) {
+ return IfElseBench<UInt64Type>(state);
+}
+
+static void IfElseBench32(benchmark::State& state) {
+ return IfElseBench<UInt32Type>(state);
+}
+
+static void IfElseBenchString32(benchmark::State& state) {
+ return IfElseBench<StringType>(state);
+}
+
+static void IfElseBenchString64(benchmark::State& state) {
+ return IfElseBench<LargeStringType>(state);
+}
+
+static void IfElseBench64Contiguous(benchmark::State& state) {
+ return IfElseBenchContiguous<UInt64Type>(state);
+}
+
+static void IfElseBench32Contiguous(benchmark::State& state) {
+ return IfElseBenchContiguous<UInt32Type>(state);
+}
+
+static void IfElseBenchString64Contiguous(benchmark::State& state) {
+ return IfElseBenchContiguous<UInt64Type>(state);
+}
+
+static void IfElseBenchString32Contiguous(benchmark::State& state) {
+ return IfElseBenchContiguous<UInt32Type>(state);
+}
+
+template <typename Type>
+static void CaseWhenBench(benchmark::State& state) {
+ auto type = TypeTraits<Type>::type_singleton();
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+
+ int64_t len = state.range(0);
+ int64_t offset = state.range(1);
+
+ random::RandomArrayGenerator rand(/*seed=*/0);
+
+ auto cond_field =
+ field("cond", boolean(), key_value_metadata({{"null_probability", "0.01"}}));
+ auto cond = rand.ArrayOf(*field("", struct_({cond_field, cond_field, cond_field}),
+ key_value_metadata({{"null_probability", "0.0"}})),
+ len)
+ ->Slice(offset);
+ auto val1 = std::static_pointer_cast<ArrayType>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01))
+ ->Slice(offset);
+ auto val2 = std::static_pointer_cast<ArrayType>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01))
+ ->Slice(offset);
+ auto val3 = std::static_pointer_cast<ArrayType>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01))
+ ->Slice(offset);
+ auto val4 = std::static_pointer_cast<ArrayType>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01))
+ ->Slice(offset);
+ for (auto _ : state) {
+ ABORT_NOT_OK(CaseWhen(cond, {val1, val2, val3, val4}));
+ }
+
+ // Set bytes processed to ~length of output
+ state.SetBytesProcessed(state.iterations() * GetBytesProcessed<Type>::Get(val1));
+ state.SetItemsProcessed(state.iterations() * (len - offset));
+}
+
+static void CaseWhenBenchList(benchmark::State& state) {
+ auto type = list(int64());
+ auto fld = field("", type);
+
+ int64_t len = state.range(0);
+ int64_t offset = state.range(1);
+
+ random::RandomArrayGenerator rand(/*seed=*/0);
+
+ auto cond_field =
+ field("cond", boolean(), key_value_metadata({{"null_probability", "0.01"}}));
+ auto cond = rand.ArrayOf(*field("", struct_({cond_field, cond_field, cond_field}),
+ key_value_metadata({{"null_probability", "0.0"}})),
+ len);
+ auto val1 = rand.ArrayOf(*fld, len);
+ auto val2 = rand.ArrayOf(*fld, len);
+ auto val3 = rand.ArrayOf(*fld, len);
+ auto val4 = rand.ArrayOf(*fld, len);
+ for (auto _ : state) {
+ ABORT_NOT_OK(
+ CaseWhen(cond->Slice(offset), {val1->Slice(offset), val2->Slice(offset),
+ val3->Slice(offset), val4->Slice(offset)}));
+ }
+
+ // Set bytes processed to ~length of output
+ state.SetBytesProcessed(state.iterations() *
+ GetBytesProcessed<Int64Type>::Get(
+ std::static_pointer_cast<ListArray>(val1)->values()));
+ state.SetItemsProcessed(state.iterations() * (len - offset));
+}
+
+template <typename Type>
+static void CaseWhenBenchContiguous(benchmark::State& state) {
+ auto type = TypeTraits<Type>::type_singleton();
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+
+ int64_t len = state.range(0);
+ int64_t offset = state.range(1);
+
+ ASSERT_OK_AND_ASSIGN(auto trues, MakeArrayFromScalar(BooleanScalar(true), len / 3));
+ ASSERT_OK_AND_ASSIGN(auto falses, MakeArrayFromScalar(BooleanScalar(false), len / 3));
+ ASSERT_OK_AND_ASSIGN(auto nulls, MakeArrayOfNull(boolean(), len - 2 * (len / 3)));
+ ASSERT_OK_AND_ASSIGN(auto concat, Concatenate({trues, falses, nulls}));
+ auto cond1 = std::static_pointer_cast<BooleanArray>(concat);
+
+ random::RandomArrayGenerator rand(/*seed=*/0);
+ auto cond2 = std::static_pointer_cast<BooleanArray>(
+ rand.ArrayOf(boolean(), len, /*null_probability=*/0.01));
+ auto val1 = std::static_pointer_cast<ArrayType>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01))
+ ->Slice(offset);
+ auto val2 = std::static_pointer_cast<ArrayType>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01))
+ ->Slice(offset);
+ auto val3 = std::static_pointer_cast<ArrayType>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01))
+ ->Slice(offset);
+ ASSERT_OK_AND_ASSIGN(
+ std::shared_ptr<Array> cond,
+ StructArray::Make({cond1, cond2}, std::vector<std::string>{"a", "b"}, nullptr,
+ /*null_count=*/0));
+ cond = cond->Slice(offset);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(CaseWhen(cond, {val1, val2, val3}));
+ }
+
+ // Set bytes processed to ~length of output
+ state.SetBytesProcessed(state.iterations() * GetBytesProcessed<Type>::Get(val1));
+ state.SetItemsProcessed(state.iterations() * (len - offset));
+}
+
+static void CaseWhenBench64(benchmark::State& state) {
+ return CaseWhenBench<UInt64Type>(state);
+}
+
+static void CaseWhenBench64Contiguous(benchmark::State& state) {
+ return CaseWhenBenchContiguous<UInt64Type>(state);
+}
+
+static void CaseWhenBenchString(benchmark::State& state) {
+ return CaseWhenBench<StringType>(state);
+}
+
+static void CaseWhenBenchStringContiguous(benchmark::State& state) {
+ return CaseWhenBenchContiguous<StringType>(state);
+}
+
+struct CoalesceParams {
+ int64_t length;
+ int64_t num_arguments;
+ double null_probability;
+};
+
+std::vector<CoalesceParams> g_coalesce_params = {
+ {kNumItems, 2, 0.01}, {kNumItems, 4, 0.01}, {kNumItems, 2, 0.25},
+ {kNumItems, 4, 0.25}, {kNumItems, 2, 0.50}, {kNumItems, 4, 0.50},
+ {kNumItems, 2, 0.99}, {kNumItems, 4, 0.99},
+};
+
+struct CoalesceArgs : public CoalesceParams {
+ explicit CoalesceArgs(benchmark::State& state) : state_(state) {
+ const auto& params = g_coalesce_params[state.range(0)];
+ length = params.length;
+ num_arguments = params.num_arguments;
+ null_probability = params.null_probability;
+ }
+
+ ~CoalesceArgs() {
+ state_.counters["length"] = static_cast<double>(length);
+ state_.counters["null%"] = null_probability * 100;
+ state_.counters["num_args"] = static_cast<double>(num_arguments);
+ }
+
+ private:
+ benchmark::State& state_;
+};
+
+template <typename Type>
+static void CoalesceBench(benchmark::State& state) {
+ auto type = TypeTraits<Type>::type_singleton();
+ CoalesceArgs params(state);
+ random::RandomArrayGenerator rand(/*seed=*/0);
+
+ std::vector<Datum> arguments;
+ for (int i = 0; i < params.num_arguments; i++) {
+ arguments.emplace_back(rand.ArrayOf(type, params.length, params.null_probability));
+ }
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(CallFunction("coalesce", arguments));
+ }
+
+ state.SetBytesProcessed(state.iterations() *
+ GetBytesProcessed<Type>::Get(arguments.front().make_array()));
+ state.SetItemsProcessed(state.iterations() * params.length);
+}
+
+template <typename Type>
+static void CoalesceScalarBench(benchmark::State& state) {
+ using CType = typename Type::c_type;
+ auto type = TypeTraits<Type>::type_singleton();
+ CoalesceArgs params(state);
+ random::RandomArrayGenerator rand(/*seed=*/0);
+
+ std::vector<Datum> arguments = {
+ rand.ArrayOf(type, params.length, params.null_probability),
+ Datum(CType(42)),
+ };
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(CallFunction("coalesce", arguments));
+ }
+
+ state.SetBytesProcessed(state.iterations() *
+ GetBytesProcessed<Type>::Get(arguments.front().make_array()));
+ state.SetItemsProcessed(state.iterations() * params.length);
+}
+
+static void CoalesceScalarStringBench(benchmark::State& state) {
+ CoalesceArgs params(state);
+ random::RandomArrayGenerator rand(/*seed=*/0);
+
+ auto arr = rand.ArrayOf(utf8(), params.length, params.null_probability);
+ std::vector<Datum> arguments = {arr, Datum("foobar")};
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(CallFunction("coalesce", arguments));
+ }
+
+ state.SetBytesProcessed(state.iterations() * GetBytesProcessed<StringType>::Get(
+ arguments.front().make_array()));
+ state.SetItemsProcessed(state.iterations() * params.length);
+}
+
+static void CoalesceBench64(benchmark::State& state) {
+ return CoalesceBench<Int64Type>(state);
+}
+
+static void CoalesceScalarBench64(benchmark::State& state) {
+ return CoalesceScalarBench<Int64Type>(state);
+}
+
+template <typename Type>
+static void ChooseBench(benchmark::State& state) {
+ constexpr int kNumChoices = 5;
+ auto type = TypeTraits<Type>::type_singleton();
+
+ int64_t len = state.range(0);
+ int64_t offset = state.range(1);
+
+ random::RandomArrayGenerator rand(/*seed=*/0);
+
+ std::vector<Datum> arguments;
+ arguments.emplace_back(
+ rand.Int64(len, /*min=*/0, /*max=*/kNumChoices - 1, /*null_probability=*/0.1)
+ ->Slice(offset));
+ for (int i = 0; i < kNumChoices; i++) {
+ arguments.emplace_back(
+ rand.ArrayOf(type, len, /*null_probability=*/0.25)->Slice(offset));
+ }
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(CallFunction("choose", arguments));
+ }
+
+ state.SetBytesProcessed(state.iterations() *
+ GetBytesProcessed<Type>::Get(arguments[1].make_array()));
+ state.SetItemsProcessed(state.iterations() * (len - offset));
+}
+
+static void ChooseBench64(benchmark::State& state) {
+ return ChooseBench<Int64Type>(state);
+}
+
+BENCHMARK(IfElseBench32)->Args({kNumItems, 0});
+BENCHMARK(IfElseBench64)->Args({kNumItems, 0});
+
+BENCHMARK(IfElseBench32)->Args({kNumItems, 99});
+BENCHMARK(IfElseBench64)->Args({kNumItems, 99});
+
+BENCHMARK(IfElseBench32Contiguous)->Args({kNumItems, 0});
+BENCHMARK(IfElseBench64Contiguous)->Args({kNumItems, 0});
+
+BENCHMARK(IfElseBench32Contiguous)->Args({kNumItems, 99});
+BENCHMARK(IfElseBench64Contiguous)->Args({kNumItems, 99});
+
+BENCHMARK(IfElseBenchString32)->Args({kNumItems, 0});
+BENCHMARK(IfElseBenchString64)->Args({kNumItems, 0});
+
+BENCHMARK(IfElseBenchString32Contiguous)->Args({kNumItems, 99});
+BENCHMARK(IfElseBenchString64Contiguous)->Args({kNumItems, 99});
+
+BENCHMARK(CaseWhenBench64)->Args({kNumItems, 0});
+BENCHMARK(CaseWhenBench64)->Args({kNumItems, 99});
+
+BENCHMARK(CaseWhenBench64Contiguous)->Args({kNumItems, 0});
+BENCHMARK(CaseWhenBench64Contiguous)->Args({kNumItems, 99});
+
+BENCHMARK(CaseWhenBenchList)->Args({kFewItems, 0});
+BENCHMARK(CaseWhenBenchList)->Args({kFewItems, 99});
+
+BENCHMARK(CaseWhenBenchString)->Args({kFewItems, 0});
+BENCHMARK(CaseWhenBenchString)->Args({kFewItems, 99});
+
+BENCHMARK(CaseWhenBenchStringContiguous)->Args({kFewItems, 0});
+BENCHMARK(CaseWhenBenchStringContiguous)->Args({kFewItems, 99});
+
+void CoalesceSetArgs(benchmark::internal::Benchmark* bench) {
+ for (size_t i = 0; i < g_coalesce_params.size(); i++) {
+ bench->Args({static_cast<int64_t>(i)});
+ }
+}
+void CoalesceSetBinaryArgs(benchmark::internal::Benchmark* bench) {
+ for (size_t i = 0; i < g_coalesce_params.size(); i++) {
+ if (g_coalesce_params[i].num_arguments == 2) {
+ bench->Args({static_cast<int64_t>(i)});
+ }
+ }
+}
+BENCHMARK(CoalesceBench64)->Apply(CoalesceSetArgs);
+BENCHMARK(CoalesceScalarBench64)->Apply(CoalesceSetBinaryArgs);
+BENCHMARK(CoalesceScalarStringBench)->Apply(CoalesceSetBinaryArgs);
+
+BENCHMARK(ChooseBench64)->Args({kNumItems, 0});
+BENCHMARK(ChooseBench64)->Args({kNumItems, 99});
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
new file mode 100644
index 000000000..92e0582c6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -0,0 +1,2922 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include "arrow/array.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/compute/registry.h"
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+namespace compute {
+
+// Helper that combines a dictionary and the value type so it can
+// later be used with DictArrayFromJSON
+struct JsonDict {
+ std::shared_ptr<DataType> type;
+ std::string value;
+};
+
+// Helper that makes a list of dictionary indices
+std::shared_ptr<Array> MakeListOfDict(const std::shared_ptr<Array>& indices,
+ const std::shared_ptr<Array>& backing_array) {
+ EXPECT_OK_AND_ASSIGN(auto result, ListArray::FromArrays(*indices, *backing_array));
+ return result;
+}
+
+void CheckIfElseOutput(const Datum& cond, const Datum& left, const Datum& right,
+ const Datum& expected) {
+ ASSERT_OK_AND_ASSIGN(Datum datum_out, IfElse(cond, left, right));
+ if (datum_out.is_array()) {
+ std::shared_ptr<Array> result = datum_out.make_array();
+ ValidateOutput(*result);
+ std::shared_ptr<Array> expected_ = expected.make_array();
+ AssertArraysEqual(*expected_, *result, /*verbose=*/true);
+ } else { // expecting scalar
+ const std::shared_ptr<Scalar>& result = datum_out.scalar();
+ const std::shared_ptr<Scalar>& expected_ = expected.scalar();
+ AssertScalarsEqual(*expected_, *result, /*verbose=*/true);
+ }
+}
+
+class TestIfElseKernel : public ::testing::Test {};
+
+template <typename Type>
+class TestIfElsePrimitive : public ::testing::Test {};
+
+using NumericBasedTypes =
+ ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
+ Int32Type, Int64Type, FloatType, DoubleType, Date32Type, Date64Type,
+ Time32Type, Time64Type, TimestampType, MonthIntervalType>;
+
+TYPED_TEST_SUITE(TestIfElsePrimitive, NumericBasedTypes);
+
+TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeRand) {
+ using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+ auto type = default_type_instance<TypeParam>();
+
+ random::RandomArrayGenerator rand(/*seed=*/0);
+ int64_t len = 1000;
+
+ // adding 64 consecutive 1's and 0's in the cond array to test all-true/ all-false
+ // word code paths
+ ASSERT_OK_AND_ASSIGN(auto temp1, MakeArrayFromScalar(BooleanScalar(true), 64));
+ ASSERT_OK_AND_ASSIGN(auto temp2, MakeArrayFromScalar(BooleanScalar(false), 64));
+ auto temp3 = rand.ArrayOf(boolean(), len - 64 * 2, /*null_probability=*/0.01);
+
+ ASSERT_OK_AND_ASSIGN(auto concat, Concatenate({temp1, temp2, temp3}));
+ auto cond = std::static_pointer_cast<BooleanArray>(concat);
+ auto left = std::static_pointer_cast<ArrayType>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01));
+ auto right = std::static_pointer_cast<ArrayType>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01));
+
+ typename TypeTraits<TypeParam>::BuilderType builder(type, default_memory_pool());
+
+ for (int64_t i = 0; i < len; ++i) {
+ if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) ||
+ (!cond->Value(i) && !right->IsValid(i))) {
+ ASSERT_OK(builder.AppendNull());
+ continue;
+ }
+
+ if (cond->Value(i)) {
+ ASSERT_OK(builder.Append(left->Value(i)));
+ } else {
+ ASSERT_OK(builder.Append(right->Value(i)));
+ }
+ }
+ ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish());
+
+ CheckIfElseOutput(cond, left, right, expected_data);
+}
+
+void CheckWithDifferentShapes(const std::shared_ptr<Array>& cond,
+ const std::shared_ptr<Array>& left,
+ const std::shared_ptr<Array>& right,
+ const std::shared_ptr<Array>& expected) {
+ // this will check for whole arrays, every scalar at i'th index and slicing (offset)
+ CheckScalar("if_else", {cond, left, right}, expected);
+
+ auto len = left->length();
+ std::vector<int64_t> array_indices = {-1}; // sentinel for make_input
+ std::vector<int64_t> scalar_indices(len);
+ std::iota(scalar_indices.begin(), scalar_indices.end(), 0);
+ auto make_input = [&](const std::shared_ptr<Array>& array, int64_t index, Datum* input,
+ Datum* input_broadcast, std::string* trace) {
+ if (index >= 0) {
+ // Use scalar from array[index] as input; broadcast scalar for computing expected
+ // result
+ ASSERT_OK_AND_ASSIGN(auto scalar, array->GetScalar(index));
+ *trace += "@" + std::to_string(index) + "=" + scalar->ToString();
+ *input = std::move(scalar);
+ ASSERT_OK_AND_ASSIGN(*input_broadcast, MakeArrayFromScalar(*input->scalar(), len));
+ } else {
+ // Use array as input
+ *trace += "=Array";
+ *input = *input_broadcast = array;
+ }
+ };
+
+ enum { COND_SCALAR = 1, LEFT_SCALAR = 2, RIGHT_SCALAR = 4 };
+ for (int mask = 1; mask <= (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR); ++mask) {
+ for (int64_t cond_idx : (mask & COND_SCALAR) ? scalar_indices : array_indices) {
+ Datum cond_in, cond_bcast;
+ std::string trace_cond = "Cond";
+ make_input(cond, cond_idx, &cond_in, &cond_bcast, &trace_cond);
+
+ for (int64_t left_idx : (mask & LEFT_SCALAR) ? scalar_indices : array_indices) {
+ Datum left_in, left_bcast;
+ std::string trace_left = "Left";
+ make_input(left, left_idx, &left_in, &left_bcast, &trace_left);
+
+ for (int64_t right_idx : (mask & RIGHT_SCALAR) ? scalar_indices : array_indices) {
+ Datum right_in, right_bcast;
+ std::string trace_right = "Right";
+ make_input(right, right_idx, &right_in, &right_bcast, &trace_right);
+
+ SCOPED_TRACE(trace_right);
+ SCOPED_TRACE(trace_left);
+ SCOPED_TRACE(trace_cond);
+
+ Datum expected;
+ ASSERT_OK_AND_ASSIGN(auto actual, IfElse(cond_in, left_in, right_in));
+ if (mask == (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR)) {
+ const auto& scalar = cond_in.scalar_as<BooleanScalar>();
+ if (scalar.is_valid) {
+ expected = scalar.value ? left_in : right_in;
+ } else {
+ expected = MakeNullScalar(left_in.type());
+ }
+ if (!left_in.type()->Equals(*right_in.type())) {
+ ASSERT_OK_AND_ASSIGN(expected,
+ Cast(expected, CastOptions::Safe(actual.type())));
+ }
+ } else {
+ ASSERT_OK_AND_ASSIGN(expected, IfElse(cond_bcast, left_bcast, right_bcast));
+ }
+ AssertDatumsEqual(expected, actual, /*verbose=*/true);
+ }
+ }
+ }
+ } // for (mask)
+}
+
+TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) {
+ auto type = default_type_instance<TypeParam>();
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(type, "[1, 2, 3, 4]"),
+ ArrayFromJSON(type, "[5, 6, 7, 8]"),
+ ArrayFromJSON(type, "[1, 2, 3, 8]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(type, "[1, 2, 3, 4]"),
+ ArrayFromJSON(type, "[5, 6, 7, null]"),
+ ArrayFromJSON(type, "[1, 2, 3, null]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(type, "[1, 2, null, 4]"),
+ ArrayFromJSON(type, "[5, 6, 7, null]"),
+ ArrayFromJSON(type, "[1, 2, null, null]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(type, "[1, 2, null, 4]"),
+ ArrayFromJSON(type, "[5, 6, 7, 8]"),
+ ArrayFromJSON(type, "[1, 2, null, 8]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"),
+ ArrayFromJSON(type, "[1, 2, null, 4]"),
+ ArrayFromJSON(type, "[5, 6, 7, 8]"),
+ ArrayFromJSON(type, "[null, 2, null, 8]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"),
+ ArrayFromJSON(type, "[1, 2, null, 4]"),
+ ArrayFromJSON(type, "[5, 6, 7, null]"),
+ ArrayFromJSON(type, "[null, 2, null, null]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"),
+ ArrayFromJSON(type, "[1, 2, 3, 4]"),
+ ArrayFromJSON(type, "[5, 6, 7, null]"),
+ ArrayFromJSON(type, "[null, 2, 3, null]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"),
+ ArrayFromJSON(type, "[1, 2, 3, 4]"),
+ ArrayFromJSON(type, "[5, 6, 7, 8]"),
+ ArrayFromJSON(type, "[null, 2, 3, 8]"));
+}
+
+TEST_F(TestIfElseKernel, IfElseBoolean) {
+ auto type = boolean();
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(type, "[false, false, false, false]"),
+ ArrayFromJSON(type, "[true, true, true, true]"),
+ ArrayFromJSON(type, "[false, false, false, true]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(type, "[false, false, false, false]"),
+ ArrayFromJSON(type, "[true, true, true, null]"),
+ ArrayFromJSON(type, "[false, false, false, null]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(type, "[false, false, null, false]"),
+ ArrayFromJSON(type, "[true, true, true, null]"),
+ ArrayFromJSON(type, "[false, false, null, null]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(type, "[false, false, null, false]"),
+ ArrayFromJSON(type, "[true, true, true, true]"),
+ ArrayFromJSON(type, "[false, false, null, true]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"),
+ ArrayFromJSON(type, "[false, false, null, false]"),
+ ArrayFromJSON(type, "[true, true, true, true]"),
+ ArrayFromJSON(type, "[null, false, null, true]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"),
+ ArrayFromJSON(type, "[false, false, null, false]"),
+ ArrayFromJSON(type, "[true, true, true, null]"),
+ ArrayFromJSON(type, "[null, false, null, null]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"),
+ ArrayFromJSON(type, "[false, false, false, false]"),
+ ArrayFromJSON(type, "[true, true, true, null]"),
+ ArrayFromJSON(type, "[null, false, false, null]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"),
+ ArrayFromJSON(type, "[false, false, false, false]"),
+ ArrayFromJSON(type, "[true, true, true, true]"),
+ ArrayFromJSON(type, "[null, false, false, true]"));
+}
+
+TEST_F(TestIfElseKernel, IfElseBooleanRand) {
+ auto type = boolean();
+ random::RandomArrayGenerator rand(/*seed=*/0);
+ int64_t len = 1000;
+ auto cond = std::static_pointer_cast<BooleanArray>(
+ rand.ArrayOf(boolean(), len, /*null_probability=*/0.01));
+ auto left = std::static_pointer_cast<BooleanArray>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01));
+ auto right = std::static_pointer_cast<BooleanArray>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01));
+
+ BooleanBuilder builder;
+ for (int64_t i = 0; i < len; ++i) {
+ if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) ||
+ (!cond->Value(i) && !right->IsValid(i))) {
+ ASSERT_OK(builder.AppendNull());
+ continue;
+ }
+
+ if (cond->Value(i)) {
+ ASSERT_OK(builder.Append(left->Value(i)));
+ } else {
+ ASSERT_OK(builder.Append(right->Value(i)));
+ }
+ }
+ ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish());
+
+ CheckIfElseOutput(cond, left, right, expected_data);
+}
+
+TEST_F(TestIfElseKernel, IfElseNull) {
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, null, null, null]"),
+ ArrayFromJSON(null(), "[null, null, null, null]"),
+ ArrayFromJSON(null(), "[null, null, null, null]"),
+ ArrayFromJSON(null(), "[null, null, null, null]"));
+}
+
+TEST_F(TestIfElseKernel, IfElseMultiType) {
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(int32(), "[1, 2, 3, 4]"),
+ ArrayFromJSON(float32(), "[5, 6, 7, 8]"),
+ ArrayFromJSON(float32(), "[1, 2, 3, 8]"));
+}
+
+TEST_F(TestIfElseKernel, TimestampTypes) {
+ for (const auto unit : TimeUnit::values()) {
+ auto ty = timestamp(unit);
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(ty, "[1, 2, 3, 4]"),
+ ArrayFromJSON(ty, "[5, 6, 7, 8]"),
+ ArrayFromJSON(ty, "[1, 2, 3, 8]"));
+
+ ty = timestamp(unit, "America/Phoenix");
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(ty, "[1, 2, 3, 4]"),
+ ArrayFromJSON(ty, "[5, 6, 7, 8]"),
+ ArrayFromJSON(ty, "[1, 2, 3, 8]"));
+ }
+}
+
+TEST_F(TestIfElseKernel, TemporalTypes) {
+ for (const auto& ty : TemporalTypes()) {
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(ty, "[1, 2, 3, 4]"),
+ ArrayFromJSON(ty, "[5, 6, 7, 8]"),
+ ArrayFromJSON(ty, "[1, 2, 3, 8]"));
+ }
+}
+
+TEST_F(TestIfElseKernel, DayTimeInterval) {
+ auto ty = day_time_interval();
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(ty, "[[1, 2], [3, -4], [-5, 6], [-7, -8]]"),
+ ArrayFromJSON(ty, "[[-9, -10], [11, -12], [-13, 14], [15, 16]]"),
+ ArrayFromJSON(ty, "[[1, 2], [3, -4], [-5, 6], [15, 16]]"));
+}
+
+TEST_F(TestIfElseKernel, IfElseDispatchBest) {
+ std::string name = "if_else";
+ ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(name));
+ CheckDispatchBest(name, {boolean(), int32(), int32()}, {boolean(), int32(), int32()});
+ CheckDispatchBest(name, {boolean(), int32(), null()}, {boolean(), int32(), int32()});
+ CheckDispatchBest(name, {boolean(), null(), int32()}, {boolean(), int32(), int32()});
+
+ CheckDispatchBest(name, {boolean(), int32(), int8()}, {boolean(), int32(), int32()});
+ CheckDispatchBest(name, {boolean(), int32(), int16()}, {boolean(), int32(), int32()});
+ CheckDispatchBest(name, {boolean(), int32(), int32()}, {boolean(), int32(), int32()});
+ CheckDispatchBest(name, {boolean(), int32(), int64()}, {boolean(), int64(), int64()});
+
+ CheckDispatchBest(name, {boolean(), int32(), uint8()}, {boolean(), int32(), int32()});
+ CheckDispatchBest(name, {boolean(), int32(), uint16()}, {boolean(), int32(), int32()});
+ CheckDispatchBest(name, {boolean(), int32(), uint32()}, {boolean(), int64(), int64()});
+ CheckDispatchBest(name, {boolean(), int32(), uint64()}, {boolean(), int64(), int64()});
+
+ CheckDispatchBest(name, {boolean(), uint8(), uint8()}, {boolean(), uint8(), uint8()});
+ CheckDispatchBest(name, {boolean(), uint8(), uint16()},
+ {boolean(), uint16(), uint16()});
+
+ CheckDispatchBest(name, {boolean(), int32(), float32()},
+ {boolean(), float32(), float32()});
+ CheckDispatchBest(name, {boolean(), float32(), int64()},
+ {boolean(), float32(), float32()});
+ CheckDispatchBest(name, {boolean(), float64(), int32()},
+ {boolean(), float64(), float64()});
+
+ CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()});
+
+ CheckDispatchBest(name,
+ {boolean(), timestamp(TimeUnit::SECOND), timestamp(TimeUnit::MILLI)},
+ {boolean(), timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)});
+ CheckDispatchBest(name, {boolean(), date32(), timestamp(TimeUnit::MILLI)},
+ {boolean(), timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)});
+ CheckDispatchBest(name, {boolean(), date32(), date64()},
+ {boolean(), date64(), date64()});
+ CheckDispatchBest(name, {boolean(), date32(), date32()},
+ {boolean(), date32(), date32()});
+}
+
+template <typename Type>
+class TestIfElseBaseBinary : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestIfElseBaseBinary, BinaryArrowTypes);
+
+TYPED_TEST(TestIfElseBaseBinary, IfElseBaseBinary) {
+ auto type = TypeTraits<TypeParam>::type_singleton();
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(type, R"(["a", "ab", "abc", "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmn", "lm", "l"])"),
+ ArrayFromJSON(type, R"(["a", "ab", "abc", "l"])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([true, true, true, false])"),
+ ArrayFromJSON(type, R"(["a", "ab", "abc", "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmn", "lm", null])"),
+ ArrayFromJSON(type, R"(["a", "ab", "abc", null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([true, true, true, false])"),
+ ArrayFromJSON(type, R"(["a", "ab", null, "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmn", "lm", null])"),
+ ArrayFromJSON(type, R"(["a", "ab", null, null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([true, true, true, false])"),
+ ArrayFromJSON(type, R"(["a", "ab", null, "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmn", "lm", "l"])"),
+ ArrayFromJSON(type, R"(["a", "ab", null, "l"])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(type, R"(["a", "ab", null, "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmn", "lm", "l"])"),
+ ArrayFromJSON(type, R"([null, "ab", null, "l"])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(type, R"(["a", "ab", null, "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmn", "lm", null])"),
+ ArrayFromJSON(type, R"([null, "ab", null, null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(type, R"(["a", "ab", "abc", "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmn", "lm", null])"),
+ ArrayFromJSON(type, R"([null, "ab", "abc", null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(type, R"(["a", "ab", "abc", "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmn", "lm", "l"])"),
+ ArrayFromJSON(type, R"([null, "ab", "abc", "l"])"));
+}
+
+TYPED_TEST(TestIfElseBaseBinary, IfElseBaseBinaryRand) {
+ using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+ using OffsetType = typename TypeTraits<TypeParam>::OffsetType::c_type;
+ auto type = TypeTraits<TypeParam>::type_singleton();
+
+ random::RandomArrayGenerator rand(/*seed=*/0);
+ int64_t len = 1000;
+
+ // this is to check the BitBlockCount::AllSet/ NoneSet code paths
+ ASSERT_OK_AND_ASSIGN(auto temp1, MakeArrayFromScalar(BooleanScalar(true), 64));
+ ASSERT_OK_AND_ASSIGN(auto temp2, MakeArrayFromScalar(BooleanScalar(false), 64));
+ auto temp3 = rand.ArrayOf(boolean(), len - 64 * 2, /*null_probability=*/0.01);
+
+ ASSERT_OK_AND_ASSIGN(auto concat, Concatenate({temp1, temp2, temp3}));
+ auto cond = std::static_pointer_cast<BooleanArray>(concat);
+
+ auto left = std::static_pointer_cast<ArrayType>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01));
+ auto right = std::static_pointer_cast<ArrayType>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01));
+
+ typename TypeTraits<TypeParam>::BuilderType builder;
+
+ for (int64_t i = 0; i < len; ++i) {
+ if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) ||
+ (!cond->Value(i) && !right->IsValid(i))) {
+ ASSERT_OK(builder.AppendNull());
+ continue;
+ }
+
+ OffsetType offset;
+ const uint8_t* val;
+ if (cond->Value(i)) {
+ val = left->GetValue(i, &offset);
+ } else {
+ val = right->GetValue(i, &offset);
+ }
+ ASSERT_OK(builder.Append(val, offset));
+ }
+ ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish());
+
+ CheckIfElseOutput(cond, left, right, expected_data);
+}
+
+TEST_F(TestIfElseKernel, IfElseFSBinary) {
+ auto type = fixed_size_binary(4);
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(type, R"(["aaaa", "abab", "abca", "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmnl", "lmlm", "llll"])"),
+ ArrayFromJSON(type, R"(["aaaa", "abab", "abca", "llll"])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([true, true, true, false])"),
+ ArrayFromJSON(type, R"(["aaaa", "abab", "abca", "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmnl", "lmlm", null])"),
+ ArrayFromJSON(type, R"(["aaaa", "abab", "abca", null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([true, true, true, false])"),
+ ArrayFromJSON(type, R"(["aaaa", "abab", null, "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmnl", "lmlm", null])"),
+ ArrayFromJSON(type, R"(["aaaa", "abab", null, null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([true, true, true, false])"),
+ ArrayFromJSON(type, R"(["aaaa", "abab", null, "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmnl", "lmlm", "llll"])"),
+ ArrayFromJSON(type, R"(["aaaa", "abab", null, "llll"])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(type, R"(["aaaa", "abab", null, "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmnl", "lmlm", "llll"])"),
+ ArrayFromJSON(type, R"([null, "abab", null, "llll"])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(type, R"(["aaaa", "abab", null, "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmnl", "lmlm", null])"),
+ ArrayFromJSON(type, R"([null, "abab", null, null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(type, R"(["aaaa", "abab", "abca", "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmnl", "lmlm", null])"),
+ ArrayFromJSON(type, R"([null, "abab", "abca", null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(type, R"(["aaaa", "abab", "abca", "abcd"])"),
+ ArrayFromJSON(type, R"(["lmno", "lmnl", "lmlm", "llll"])"),
+ ArrayFromJSON(type, R"([null, "abab", "abca", "llll"])"));
+}
+
+TEST_F(TestIfElseKernel, IfElseFSBinaryRand) {
+ auto type = fixed_size_binary(5);
+
+ random::RandomArrayGenerator rand(/*seed=*/0);
+ int64_t len = 1000;
+
+ // this is to check the BitBlockCount::AllSet/ NoneSet code paths
+ ASSERT_OK_AND_ASSIGN(auto temp1, MakeArrayFromScalar(BooleanScalar(true), 64));
+ ASSERT_OK_AND_ASSIGN(auto temp2, MakeArrayFromScalar(BooleanScalar(false), 64));
+ auto temp3 = rand.ArrayOf(boolean(), len - 64 * 2, /*null_probability=*/0.01);
+
+ ASSERT_OK_AND_ASSIGN(auto concat, Concatenate({temp1, temp2, temp3}));
+ auto cond = std::static_pointer_cast<BooleanArray>(concat);
+
+ auto left = std::static_pointer_cast<FixedSizeBinaryArray>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01));
+ auto right = std::static_pointer_cast<FixedSizeBinaryArray>(
+ rand.ArrayOf(type, len, /*null_probability=*/0.01));
+
+ FixedSizeBinaryBuilder builder(type);
+
+ for (int64_t i = 0; i < len; ++i) {
+ if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) ||
+ (!cond->Value(i) && !right->IsValid(i))) {
+ ASSERT_OK(builder.AppendNull());
+ continue;
+ }
+
+ const uint8_t* val;
+ if (cond->Value(i)) {
+ val = left->GetValue(i);
+ } else {
+ val = right->GetValue(i);
+ }
+ ASSERT_OK(builder.Append(val));
+ }
+ ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish());
+
+ CheckIfElseOutput(cond, left, right, expected_data);
+}
+
+TEST_F(TestIfElseKernel, Decimal) {
+ for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", "-4.56"])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", "-4.56"])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([true, true, true, false])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", null])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([true, true, true, false])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", null, "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", null])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", null, null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([true, true, true, false])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", null, "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", "-4.56"])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", null, "-4.56"])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", null, "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", "-4.56"])"),
+ ArrayFromJSON(ty, R"([null, "2.34", null, "-4.56"])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", null, "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", null])"),
+ ArrayFromJSON(ty, R"([null, "2.34", null, null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", null])"),
+ ArrayFromJSON(ty, R"([null, "2.34", "-1.23", null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", "-4.56"])"),
+ ArrayFromJSON(ty, R"([null, "2.34", "-1.23", "-4.56"])"));
+ }
+}
+
+template <typename Type>
+class TestIfElseList : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestIfElseList, ListArrowTypes);
+
+TYPED_TEST(TestIfElseList, ListOfInt) {
+ auto type = std::make_shared<TypeParam>(int32());
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(type, "[[], null, [1, null], [2, 3]]"),
+ ArrayFromJSON(type, "[[4, 5, 6], [7], [null], null]"),
+ ArrayFromJSON(type, "[[], null, [null], null]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, null, null, null]"),
+ ArrayFromJSON(type, "[[], [2, 3, 4, 5], null, null]"),
+ ArrayFromJSON(type, "[[4, 5, 6], null, [null], null]"),
+ ArrayFromJSON(type, "[null, null, null, null]"));
+}
+
+TYPED_TEST(TestIfElseList, ListOfString) {
+ auto type = std::make_shared<TypeParam>(utf8());
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(type, R"([[], null, ["xyz", null], ["ab", "c"]])"),
+ ArrayFromJSON(type, R"([["hi", "jk", "l"], ["defg"], [null], null])"),
+ ArrayFromJSON(type, R"([[], null, [null], null])"));
+
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[null, null, null, null]"),
+ ArrayFromJSON(type, R"([[], ["b", "cd", "efg", "h"], null, null])"),
+ ArrayFromJSON(type, R"([["hi", "jk", "l"], null, [null], null])"),
+ ArrayFromJSON(type, R"([null, null, null, null])"));
+}
+
+TEST_F(TestIfElseKernel, FixedSizeList) {
+ auto type = fixed_size_list(int32(), 2);
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(type, "[[1, 2], null, [1, null], [2, 3]]"),
+ ArrayFromJSON(type, "[[4, 5], [6, 7], [null, 8], null]"),
+ ArrayFromJSON(type, "[[1, 2], null, [null, 8], null]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, null, null, null]"),
+ ArrayFromJSON(type, "[[2, 3], [4, 5], null, null]"),
+ ArrayFromJSON(type, "[[4, 5], null, [6, null], null]"),
+ ArrayFromJSON(type, "[null, null, null, null]"));
+}
+
+TEST_F(TestIfElseKernel, StructPrimitive) {
+ auto type = struct_({field("int", uint16()), field("str", utf8())});
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(type, R"([[null, "foo"], null, [1, null], [2, "spam"]])"),
+ ArrayFromJSON(type, R"([[1, "a"], [42, ""], [24, null], null])"),
+ ArrayFromJSON(type, R"([[null, "foo"], null, [24, null], null])"));
+
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[null, null, null, null]"),
+ ArrayFromJSON(type, R"([[null, "foo"], [4, "abcd"], null, null])"),
+ ArrayFromJSON(type, R"([[1, "a"], null, [24, null], null])"),
+ ArrayFromJSON(type, R"([null, null, null, null])"));
+}
+
+TEST_F(TestIfElseKernel, StructNested) {
+ auto type = struct_({field("date", date32()), field("list", list(int32()))});
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(type, R"([[-1, [null]], null, [1, null], [2, [3, 4]]])"),
+ ArrayFromJSON(type, R"([[4, [5]], [6, [7, 8]], [null, [1, null, 42]], null])"),
+ ArrayFromJSON(type, R"([[-1, [null]], null, [null, [1, null, 42]], null])"));
+
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[null, null, null, null]"),
+ ArrayFromJSON(type, R"([[-1, [null]], [4, [5, 6]], null, null])"),
+ ArrayFromJSON(type, R"([[4, [5]], null, [null, [1, null, 42]], null])"),
+ ArrayFromJSON(type, R"([null, null, null, null])"));
+}
+
+TEST_F(TestIfElseKernel, ParameterizedTypes) {
+ auto cond = ArrayFromJSON(boolean(), "[true]");
+
+ auto type0 = fixed_size_binary(4);
+ auto type1 = fixed_size_binary(5);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: "
+ "fixed_size_binary[4], but got: fixed_size_binary[5]"),
+ CallFunction("if_else", {cond, ArrayFromJSON(type0, R"(["aaaa"])"),
+ ArrayFromJSON(type1, R"(["aaaaa"])")}));
+
+ // TODO(ARROW-14105): in principle many of these could be implicitly castable too
+
+ type0 = struct_({field("a", int32())});
+ type1 = struct_({field("a", int64())});
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: struct<a: int32>, "
+ "but got: struct<a: int64>"),
+ CallFunction("if_else",
+ {cond, ArrayFromJSON(type0, "[[0]]"), ArrayFromJSON(type1, "[[0]]")}));
+
+ type0 = dense_union({field("a", int32())});
+ type1 = dense_union({field("a", int64())});
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: dense_union<a: "
+ "int32=0>, but got: dense_union<a: int64=0>"),
+ CallFunction("if_else", {cond, ArrayFromJSON(type0, "[[0, -1]]"),
+ ArrayFromJSON(type1, "[[0, -1]]")}));
+
+ type0 = list(int16());
+ type1 = list(int32());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: list<item: int16>, "
+ "but got: list<item: int32>"),
+ CallFunction("if_else",
+ {cond, ArrayFromJSON(type0, "[[0]]"), ArrayFromJSON(type1, "[[0]]")}));
+
+ type0 = timestamp(TimeUnit::SECOND);
+ type1 = timestamp(TimeUnit::MILLI);
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(type0, "[1, 2, 3, 4]"),
+ ArrayFromJSON(type1, "[5, 6, 7, 8]"),
+ ArrayFromJSON(type1, "[1000, 2000, 3000, 8]"));
+
+ type0 = timestamp(TimeUnit::SECOND);
+ type1 = timestamp(TimeUnit::SECOND, "America/Phoenix");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: timestamp[s], "
+ "but got: timestamp[s, tz=America/Phoenix]"),
+ CallFunction("if_else",
+ {cond, ArrayFromJSON(type0, "[0]"), ArrayFromJSON(type1, "[1]")}));
+
+ type0 = timestamp(TimeUnit::SECOND, "America/New_York");
+ type1 = timestamp(TimeUnit::SECOND, "America/Phoenix");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr(
+ "All types must be compatible, expected: timestamp[s, tz=America/New_York], "
+ "but got: timestamp[s, tz=America/Phoenix]"),
+ CallFunction("if_else",
+ {cond, ArrayFromJSON(type0, "[0]"), ArrayFromJSON(type1, "[1]")}));
+
+ type0 = timestamp(TimeUnit::MILLI, "America/New_York");
+ type1 = timestamp(TimeUnit::SECOND, "America/Phoenix");
+ // Casting fails so we never get to the kernel in the first place (since the units don't
+ // match)
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ NotImplemented,
+ ::testing::HasSubstr("Function if_else has no kernel matching input types "
+ "(array[bool], array[timestamp[ms, tz=America/New_York]], "
+ "array[timestamp[s, tz=America/Phoenix]]"),
+ CallFunction("if_else",
+ {cond, ArrayFromJSON(type0, "[0]"), ArrayFromJSON(type1, "[1]")}));
+}
+
+template <typename Type>
+class TestIfElseUnion : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestIfElseUnion, UnionArrowTypes);
+
+TYPED_TEST(TestIfElseUnion, UnionPrimitive) {
+ std::vector<std::shared_ptr<Field>> fields = {field("int", uint16()),
+ field("str", utf8())};
+ std::vector<int8_t> codes = {2, 7};
+ auto type = std::make_shared<TypeParam>(fields, codes);
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(type, R"([[7, "foo"], [7, null], [7, null], [7, "spam"]])"),
+ ArrayFromJSON(type, R"([[2, 15], [2, null], [2, 42], [2, null]])"),
+ ArrayFromJSON(type, R"([[7, "foo"], [7, null], [2, 42], [2, null]])"));
+
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[null, null, null, null]"),
+ ArrayFromJSON(type, R"([[7, "foo"], [7, null], [7, null], [7, "spam"]])"),
+ ArrayFromJSON(type, R"([[2, 15], [2, null], [2, 42], [2, null]])"),
+ ArrayFromJSON(type, R"([null, null, null, null])"));
+}
+
+TYPED_TEST(TestIfElseUnion, UnionNested) {
+ std::vector<std::shared_ptr<Field>> fields = {field("int", uint16()),
+ field("list", list(int16()))};
+ std::vector<int8_t> codes = {2, 7};
+ auto type = std::make_shared<TypeParam>(fields, codes);
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(type, R"([[7, [1, 2]], [7, null], [7, []], [7, [3]]])"),
+ ArrayFromJSON(type, R"([[2, 15], [2, null], [2, 42], [2, null]])"),
+ ArrayFromJSON(type, R"([[7, [1, 2]], [7, null], [2, 42], [2, null]])"));
+
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[null, null, null, null]"),
+ ArrayFromJSON(type, R"([[7, [1, 2]], [7, null], [7, []], [7, [3]]])"),
+ ArrayFromJSON(type, R"([[2, 15], [2, null], [2, 42], [2, null]])"),
+ ArrayFromJSON(type, R"([null, null, null, null])"));
+}
+
+template <typename Type>
+class TestIfElseDict : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestIfElseDict, IntegralArrowTypes);
+
+TYPED_TEST(TestIfElseDict, Simple) {
+ auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ for (const auto& dict :
+ {JsonDict{utf8(), R"(["a", null, "bc", "def"])"},
+ JsonDict{int64(), "[1, null, 2, 3]"},
+ JsonDict{decimal256(3, 2), R"(["1.23", null, "3.45", "6.78"])"}}) {
+ auto type = dictionary(default_type_instance<TypeParam>(), dict.type);
+ auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict.value);
+ auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict.value);
+ auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict.value);
+ auto scalar = DictScalarFromJSON(type, "3", dict.value);
+
+ // Easy case: all arguments have the same dictionary
+ CheckDictionary("if_else", {cond, values1, values2});
+ CheckDictionary("if_else", {cond, values1, scalar});
+ CheckDictionary("if_else", {cond, scalar, values2});
+ CheckDictionary("if_else", {cond, values_null, values2});
+ CheckDictionary("if_else", {cond, values1, values_null});
+ CheckDictionary("if_else", {Datum(true), values1, values2});
+ CheckDictionary("if_else", {Datum(false), values1, values2});
+ CheckDictionary("if_else", {Datum(true), scalar, values2});
+ CheckDictionary("if_else", {Datum(true), values1, scalar});
+ CheckDictionary("if_else", {Datum(false), values1, scalar});
+ CheckDictionary("if_else", {Datum(false), scalar, values2});
+ CheckDictionary("if_else", {MakeNullScalar(boolean()), values1, values2});
+ }
+}
+
+TYPED_TEST(TestIfElseDict, Mixed) {
+ auto index_type = default_type_instance<TypeParam>();
+ auto type = dictionary(index_type, utf8());
+ auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto dict = R"(["a", null, "bc", "def"])";
+ auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict);
+ auto values1_dict = DictArrayFromJSON(type, "[0, null, 3, 1]", dict);
+ auto values1_decoded = ArrayFromJSON(utf8(), R"(["a", null, "def", null])");
+ auto values2_dict = DictArrayFromJSON(type, "[2, 1, null, 0]", dict);
+ auto values2_decoded = ArrayFromJSON(utf8(), R"(["bc", null, null, "a"])");
+ auto scalar = ScalarFromJSON(utf8(), R"("bc")");
+
+ // If we have mixed dictionary/non-dictionary arguments, we decode dictionaries
+ CheckDictionary("if_else", {cond, values1_dict, values2_decoded},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, values1_dict, scalar}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, scalar, values2_dict}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, values_null, values2_decoded},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, values1_decoded, values_null},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(true), values1_decoded, values2_dict},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(false), values1_decoded, values2_dict},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(true), scalar, values2_dict},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(true), values1_dict, scalar},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(false), values1_dict, scalar},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(false), scalar, values2_dict},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {MakeNullScalar(boolean()), values1_decoded, values2_dict},
+ /*result_is_encoded=*/false);
+
+ // If we have mismatched dictionary types, we decode (for now)
+ auto values3_dict =
+ DictArrayFromJSON(dictionary(index_type, binary()), "[2, 1, null, 0]", dict);
+ auto values4_dict = DictArrayFromJSON(
+ dictionary(index_type->id() == Type::UINT8 ? int8() : uint8(), utf8()),
+ "[2, 1, null, 0]", dict);
+ CheckDictionary("if_else", {cond, values1_dict, values3_dict},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, values1_dict, values4_dict},
+ /*result_is_encoded=*/false);
+}
+
+TYPED_TEST(TestIfElseDict, NestedSimple) {
+ auto index_type = default_type_instance<TypeParam>();
+ auto inner_type = dictionary(index_type, utf8());
+ auto type = list(inner_type);
+ auto dict = R"(["a", null, "bc", "def"])";
+ auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto values_null = MakeListOfDict(ArrayFromJSON(int32(), "[null, null, null, null, 0]"),
+ DictArrayFromJSON(inner_type, "[]", dict));
+ auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", dict);
+ auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]", dict);
+ auto values1 =
+ MakeListOfDict(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing);
+ auto values2 =
+ MakeListOfDict(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing);
+ auto scalar =
+ Datum(std::make_shared<ListScalar>(DictArrayFromJSON(inner_type, "[0, 1]", dict)));
+
+ CheckDictionary("if_else", {cond, values1, values2}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, values1, scalar}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, scalar, values2}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, values_null, values2}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, values1, values_null}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(true), values1, values2},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(false), values1, values2},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(true), scalar, values2}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(true), values1, scalar}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(false), values1, scalar},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(false), scalar, values2},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {MakeNullScalar(boolean()), values1, values2},
+ /*result_is_encoded=*/false);
+}
+
+TYPED_TEST(TestIfElseDict, DifferentDictionaries) {
+ auto type = dictionary(default_type_instance<TypeParam>(), utf8());
+ auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto dict1 = R"(["a", null, "bc", "def"])";
+ auto dict2 = R"(["bc", "foo", null, "a"])";
+ auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1);
+ auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2);
+ auto values1 = DictArrayFromJSON(type, "[null, 0, 3, 1]", dict1);
+ auto values2 = DictArrayFromJSON(type, "[2, 1, 0, null]", dict2);
+ auto scalar1 = DictScalarFromJSON(type, "0", dict1);
+ auto scalar2 = DictScalarFromJSON(type, "0", dict2);
+
+ CheckDictionary("if_else", {cond, values1, values2});
+ CheckDictionary("if_else", {cond, values1, scalar2});
+ CheckDictionary("if_else", {cond, scalar1, values2});
+ CheckDictionary("if_else", {cond, values1_null, values2});
+ CheckDictionary("if_else", {cond, values1, values2_null});
+ CheckDictionary("if_else", {Datum(true), values1, values2});
+ CheckDictionary("if_else", {Datum(false), values1, values2});
+ CheckDictionary("if_else", {Datum(true), scalar1, values2});
+ CheckDictionary("if_else", {Datum(true), values1, scalar2});
+ CheckDictionary("if_else", {Datum(false), values1, scalar2});
+ CheckDictionary("if_else", {Datum(false), scalar1, values2});
+ CheckDictionary("if_else", {MakeNullScalar(boolean()), values1, values2});
+}
+
+template <typename Type>
+class TestCaseWhenNumeric : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestCaseWhenNumeric, NumericBasedTypes);
+
+Datum MakeStruct(const std::vector<Datum>& conds) {
+ EXPECT_OK_AND_ASSIGN(auto result, CallFunction("make_struct", conds));
+ return result;
+}
+
+TYPED_TEST(TestCaseWhenNumeric, FixedSize) {
+ auto type = default_type_instance<TypeParam>();
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "1");
+ auto scalar2 = ScalarFromJSON(type, "2");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, "[3, null, 5, 6]");
+ auto values2 = ArrayFromJSON(type, "[7, 8, null, 10]");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, "[1, 1, 2, null]"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, "[null, null, 1, 1]"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(type, "[1, 1, 2, 1]"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, "[3, null, null, null]"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, "[3, null, null, 6]"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, "[null, null, null, 6]"));
+
+ CheckScalar(
+ "case_when",
+ {MakeStruct(
+ {ArrayFromJSON(boolean(),
+ "[true, true, true, false, false, false, null, null, null]"),
+ ArrayFromJSON(boolean(),
+ "[true, false, null, true, false, null, true, false, null]")}),
+ ArrayFromJSON(type, "[10, 11, 12, 13, 14, 15, 16, 17, 18]"),
+ ArrayFromJSON(type, "[20, 21, 22, 23, 24, 25, 26, 27, 28]")},
+ ArrayFromJSON(type, "[10, 11, 12, 23, null, null, 26, null, null]"));
+ CheckScalar(
+ "case_when",
+ {MakeStruct(
+ {ArrayFromJSON(boolean(),
+ "[true, true, true, false, false, false, null, null, null]"),
+ ArrayFromJSON(boolean(),
+ "[true, false, null, true, false, null, true, false, null]")}),
+ ArrayFromJSON(type, "[10, 11, 12, 13, 14, 15, 16, 17, 18]"),
+
+ ArrayFromJSON(type, "[20, 21, 22, 23, 24, 25, 26, 27, 28]"),
+ ArrayFromJSON(type, "[30, 31, 32, 33, 34, null, 36, 37, null]")},
+ ArrayFromJSON(type, "[10, 11, 12, 23, 34, null, 26, 37, null]"));
+
+ // Error cases
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("cond struct must not be null"),
+ CallFunction(
+ "case_when",
+ {Datum(std::make_shared<StructScalar>(struct_({field("", boolean())}))),
+ Datum(scalar1)}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("cond struct must not have top-level nulls"),
+ CallFunction(
+ "case_when",
+ {Datum(*MakeArrayOfNull(struct_({field("", boolean())}), 4)), Datum(values1)}));
+}
+
+TYPED_TEST(TestCaseWhenNumeric, ListOfType) {
+ // More minimal test to check type coverage
+ auto type = list(default_type_instance<TypeParam>());
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"([[1, 2], null, [3, 4, 5], [6, null]])");
+ auto values2 = ArrayFromJSON(type, R"([[8, 9, 10], [11], null, [12]])");
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([[1, 2], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([[1, 2], null, null, [6, null]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, [6, null]])"));
+}
+
+template <typename Type>
+class TestCaseWhenDict : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestCaseWhenDict, IntegralArrowTypes);
+
+TYPED_TEST(TestCaseWhenDict, Simple) {
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ for (const auto& dict :
+ {JsonDict{utf8(), R"(["a", null, "bc", "def"])"},
+ JsonDict{int64(), "[1, null, 2, 3]"},
+ JsonDict{decimal256(3, 2), R"(["1.23", null, "3.45", "6.78"])"}}) {
+ auto type = dictionary(default_type_instance<TypeParam>(), dict.type);
+ auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict.value);
+ auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict.value);
+ auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict.value);
+
+ // Easy case: all arguments have the same dictionary
+ CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2});
+ CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1});
+ CheckDictionary("case_when",
+ {MakeStruct({cond1, cond2}), values_null, values2, values1});
+ }
+}
+
+TYPED_TEST(TestCaseWhenDict, Mixed) {
+ auto type = dictionary(default_type_instance<TypeParam>(), utf8());
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto dict = R"(["a", null, "bc", "def"])";
+ auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict);
+ auto values1_dict = DictArrayFromJSON(type, "[0, null, 3, 1]", dict);
+ auto values1_decoded = ArrayFromJSON(utf8(), R"(["a", null, "def", null])");
+ auto values2_dict = DictArrayFromJSON(type, "[2, 1, null, 0]", dict);
+ auto values2_decoded = ArrayFromJSON(utf8(), R"(["bc", null, null, "a"])");
+
+ // If we have mixed dictionary/non-dictionary arguments, we decode dictionaries
+ CheckDictionary("case_when",
+ {MakeStruct({cond1, cond2}), values1_dict, values2_decoded},
+ /*result_is_encoded=*/false);
+ CheckDictionary("case_when",
+ {MakeStruct({cond1, cond2}), values1_decoded, values2_dict},
+ /*result_is_encoded=*/false);
+ CheckDictionary(
+ "case_when",
+ {MakeStruct({cond1, cond2}), values1_dict, values2_dict, values1_decoded},
+ /*result_is_encoded=*/false);
+ CheckDictionary(
+ "case_when",
+ {MakeStruct({cond1, cond2}), values_null, values2_dict, values1_decoded},
+ /*result_is_encoded=*/false);
+}
+
+TYPED_TEST(TestCaseWhenDict, NestedSimple) {
+ auto index_type = default_type_instance<TypeParam>();
+ auto inner_type = dictionary(index_type, utf8());
+ auto type = list(inner_type);
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto dict = R"(["a", null, "bc", "def"])";
+ auto values_null = MakeListOfDict(ArrayFromJSON(int32(), "[null, null, null, null, 0]"),
+ DictArrayFromJSON(inner_type, "[]", dict));
+ auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", dict);
+ auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]", dict);
+ auto values1 =
+ MakeListOfDict(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing);
+ auto values2 =
+ MakeListOfDict(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing);
+
+ CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ /*result_is_encoded=*/false);
+ CheckDictionary(
+ "case_when",
+ {MakeStruct({cond1, cond2}), values1,
+ MakeListOfDict(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing)},
+ /*result_is_encoded=*/false);
+ CheckDictionary(
+ "case_when",
+ {MakeStruct({cond1, cond2}), values1,
+ MakeListOfDict(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing),
+ values1},
+ /*result_is_encoded=*/false);
+
+ CheckDictionary("case_when",
+ {
+ Datum(MakeStruct({cond1, cond2})),
+ Datum(std::make_shared<ListScalar>(
+ DictArrayFromJSON(inner_type, "[0, 1]", dict))),
+ Datum(std::make_shared<ListScalar>(
+ DictArrayFromJSON(inner_type, "[2, 3]", dict))),
+ },
+ /*result_is_encoded=*/false);
+
+ CheckDictionary("case_when",
+ {MakeStruct({Datum(true), Datum(false)}), values1, values2},
+ /*result_is_encoded=*/false);
+ CheckDictionary("case_when",
+ {MakeStruct({Datum(false), Datum(true)}), values1, values2},
+ /*result_is_encoded=*/false);
+ CheckDictionary("case_when", {MakeStruct({Datum(false)}), values1, values2},
+ /*result_is_encoded=*/false);
+ CheckDictionary("case_when",
+ {MakeStruct({Datum(false), Datum(false)}), values1, values2},
+ /*result_is_encoded=*/false);
+}
+
+TYPED_TEST(TestCaseWhenDict, DifferentDictionaries) {
+ auto type = dictionary(default_type_instance<TypeParam>(), utf8());
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, null, true]");
+ auto dict1 = R"(["a", null, "bc", "def"])";
+ auto dict2 = R"(["bc", "foo", null, "a"])";
+ auto dict3 = R"(["def", null, "a", "bc"])";
+ auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1);
+ auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2);
+ auto values1 = DictArrayFromJSON(type, "[null, 0, 3, 1]", dict1);
+ auto values2 = DictArrayFromJSON(type, "[2, 1, 0, null]", dict2);
+ auto values3 = DictArrayFromJSON(type, "[0, 1, 2, 3]", dict3);
+
+ CheckDictionary("case_when",
+ {MakeStruct({Datum(true), Datum(false)}), values1, values2});
+ CheckDictionary("case_when",
+ {MakeStruct({Datum(false), Datum(true)}), values1, values2});
+ CheckDictionary("case_when",
+ {MakeStruct({Datum(false), Datum(false)}), values1, values2});
+ CheckDictionary("case_when",
+ {MakeStruct({Datum(false), Datum(false)}), values2, values1});
+
+ CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2});
+ CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1});
+
+ CheckDictionary("case_when",
+ {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}),
+ values1, values2});
+ CheckDictionary("case_when",
+ {MakeStruct({ArrayFromJSON(boolean(), "[true, false, false, true]")}),
+ values1, values2});
+ CheckDictionary("case_when",
+ {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(boolean(), "[true, false, true, false]")}),
+ values1, values2});
+ CheckDictionary("case_when",
+ {MakeStruct({ArrayFromJSON(boolean(), "[false, false, false, false]"),
+ ArrayFromJSON(boolean(), "[true, true, true, true]")}),
+ values1, values3});
+ CheckDictionary("case_when",
+ {MakeStruct({ArrayFromJSON(boolean(), "[null, null, null, true]"),
+ ArrayFromJSON(boolean(), "[true, true, true, true]")}),
+ values1, values3});
+ CheckDictionary(
+ "case_when",
+ {
+ MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}),
+ DictScalarFromJSON(type, "0", dict1),
+ DictScalarFromJSON(type, "0", dict2),
+ });
+ CheckDictionary(
+ "case_when",
+ {
+ MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(boolean(), "[false, false, true, true]")}),
+ DictScalarFromJSON(type, "0", dict1),
+ DictScalarFromJSON(type, "0", dict2),
+ });
+ CheckDictionary(
+ "case_when",
+ {
+ MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(boolean(), "[false, false, true, true]")}),
+ DictScalarFromJSON(type, "null", dict1),
+ DictScalarFromJSON(type, "0", dict2),
+ });
+}
+
+TEST(TestCaseWhen, Null) {
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_arr = ArrayFromJSON(boolean(), "[true, true, false, null]");
+ auto scalar = ScalarFromJSON(null(), "null");
+ auto array = ArrayFromJSON(null(), "[null, null, null, null]");
+ CheckScalar("case_when", {MakeStruct({}), array}, array);
+ CheckScalar("case_when", {MakeStruct({cond_false}), array}, array);
+ CheckScalar("case_when", {MakeStruct({cond_true}), array, array}, array);
+ CheckScalar("case_when", {MakeStruct({cond_arr, cond_true}), array, array}, array);
+}
+
+TEST(TestCaseWhen, Boolean) {
+ auto type = boolean();
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "true");
+ auto scalar2 = ScalarFromJSON(type, "false");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, "[true, null, true, true]");
+ auto values2 = ArrayFromJSON(type, "[false, false, null, false]");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, "[true, true, false, null]"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, "[null, null, true, true]"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(type, "[true, true, false, true]"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, "[true, null, null, null]"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, "[true, null, null, true]"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, "[null, null, null, true]"));
+}
+
+TEST(TestCaseWhen, DayTimeInterval) {
+ auto type = day_time_interval();
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "[1, 1]");
+ auto scalar2 = ScalarFromJSON(type, "[2, 2]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, "[[3, 3], null, [5, 5], [6, 6]]");
+ auto values2 = ArrayFromJSON(type, "[[7, 7], [8, 8], null, [10, 10]]");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, "[[1, 1], [1, 1], [2, 2], null]"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, "[null, null, [1, 1], [1, 1]]"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(type, "[[1, 1], [1, 1], [2, 2], [1, 1]]"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, "[[3, 3], null, null, null]"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, "[[3, 3], null, null, [6, 6]]"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, "[null, null, null, [6, 6]]"));
+}
+
+TEST(TestCaseWhen, Decimal) {
+ for (const auto& type :
+ std::vector<std::shared_ptr<DataType>>{decimal128(3, 2), decimal256(3, 2)}) {
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"("1.23")");
+ auto scalar2 = ScalarFromJSON(type, R"("2.34")");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"(["3.45", null, "5.67", "6.78"])");
+ auto values2 = ArrayFromJSON(type, R"(["7.89", "8.90", null, "1.01"])");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2},
+ values2);
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, R"(["1.23", "1.23", "2.34", null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, R"([null, null, "1.23", "1.23"])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(type, R"(["1.23", "1.23", "2.34", "1.23"])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"(["3.45", null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"(["3.45", null, null, "6.78"])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, "6.78"])"));
+ }
+}
+
+TEST(TestCaseWhen, FixedSizeBinary) {
+ auto type = fixed_size_binary(3);
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"("abc")");
+ auto scalar2 = ScalarFromJSON(type, R"("bcd")");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"(["cde", null, "def", "efg"])");
+ auto values2 = ArrayFromJSON(type, R"(["fgh", "ghi", null, "hij"])");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, R"(["abc", "abc", "bcd", null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, R"([null, null, "abc", "abc"])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(type, R"(["abc", "abc", "bcd", "abc"])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"(["cde", null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"(["cde", null, null, "efg"])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, "efg"])"));
+}
+
+template <typename Type>
+class TestCaseWhenBinary : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestCaseWhenBinary, BinaryArrowTypes);
+
+TYPED_TEST(TestCaseWhenBinary, Basics) {
+ auto type = default_type_instance<TypeParam>();
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"("aBxYz")");
+ auto scalar2 = ScalarFromJSON(type, R"("b")");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"(["cDE", null, "degfhi", "efg"])");
+ auto values2 = ArrayFromJSON(type, R"(["fghijk", "ghi", null, "hi"])");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, R"(["aBxYz", "aBxYz", "b", null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, R"([null, null, "aBxYz", "aBxYz"])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(type, R"(["aBxYz", "aBxYz", "b", "aBxYz"])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"(["cDE", null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"(["cDE", null, null, "efg"])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, "efg"])"));
+}
+
+template <typename Type>
+class TestCaseWhenList : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestCaseWhenList, ListArrowTypes);
+
+TYPED_TEST(TestCaseWhenList, ListOfString) {
+ auto type = std::make_shared<TypeParam>(utf8());
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"(["aB", "xYz"])");
+ auto scalar2 = ScalarFromJSON(type, R"(["b", null])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 =
+ ArrayFromJSON(type, R"([["cD", "E"], null, ["de", "gf", "hi"], ["ef", "g"]])");
+ auto values2 = ArrayFromJSON(type, R"([["f", "ghi", "jk"], ["ghi"], null, ["hi"]])");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2);
+
+ CheckScalar(
+ "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, R"([["aB", "xYz"], ["aB", "xYz"], ["b", null], null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, R"([null, null, ["aB", "xYz"], ["aB", "xYz"]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(
+ type, R"([["aB", "xYz"], ["aB", "xYz"], ["b", null], ["aB", "xYz"]])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([["cD", "E"], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([["cD", "E"], null, null, ["ef", "g"]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, ["ef", "g"]])"));
+}
+
+// More minimal tests to check type coverage
+TYPED_TEST(TestCaseWhenList, ListOfBool) {
+ auto type = std::make_shared<TypeParam>(boolean());
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"([[true], null, [false], [false, null]])");
+ auto values2 = ArrayFromJSON(type, R"([[false], [false], null, [true]])");
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([[true], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([[true], null, null, [false, null]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, [false, null]])"));
+}
+
+TYPED_TEST(TestCaseWhenList, ListOfInt) {
+ auto type = std::make_shared<TypeParam>(int64());
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"([[1, 2], null, [3, 4, 5], [6, null]])");
+ auto values2 = ArrayFromJSON(type, R"([[8, 9, 10], [11], null, [12]])");
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([[1, 2], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([[1, 2], null, null, [6, null]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, [6, null]])"));
+}
+
+TYPED_TEST(TestCaseWhenList, ListOfDayTimeInterval) {
+ auto type = std::make_shared<TypeParam>(day_time_interval());
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 =
+ ArrayFromJSON(type, R"([[[1, 2]], null, [[3, 4], [5, 0]], [[6, 7], null]])");
+ auto values2 = ArrayFromJSON(type, R"([[[8, 9], null], [[11, 12]], null, [[12, 1]]])");
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([[[1, 2]], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([[[1, 2]], null, null, [[6, 7], null]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, [[6, 7], null]])"));
+}
+
+TYPED_TEST(TestCaseWhenList, ListOfDecimal) {
+ for (const auto& decimal_ty :
+ std::vector<std::shared_ptr<DataType>>{decimal128(3, 2), decimal256(3, 2)}) {
+ auto type = std::make_shared<TypeParam>(decimal_ty);
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(
+ type, R"([["1.23", "2.34"], null, ["3.45", "4.56", "5.67"], ["6.78", null]])");
+ auto values2 =
+ ArrayFromJSON(type, R"([["8.90", "9.01", "1.02"], ["1.12"], null, ["1.23"]])");
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([["1.23", "2.34"], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([["1.23", "2.34"], null, null, ["6.78", null]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, ["6.78", null]])"));
+ }
+}
+
+TYPED_TEST(TestCaseWhenList, ListOfFixedSizeBinary) {
+ auto type = std::make_shared<TypeParam>(fixed_size_binary(4));
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(
+ type, R"([["1.23", "2.34"], null, ["3.45", "4.56", "5.67"], ["6.78", null]])");
+ auto values2 =
+ ArrayFromJSON(type, R"([["8.90", "9.01", "1.02"], ["1.12"], null, ["1.23"]])");
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([["1.23", "2.34"], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([["1.23", "2.34"], null, null, ["6.78", null]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, ["6.78", null]])"));
+}
+
+TYPED_TEST(TestCaseWhenList, ListOfListOfInt) {
+ auto type = std::make_shared<TypeParam>(list(int64()));
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 =
+ ArrayFromJSON(type, R"([[[1, 2], []], null, [[3, 4, 5]], [[6, null], null]])");
+ auto values2 = ArrayFromJSON(type, R"([[[8, 9, 10]], [[11]], null, [[12]]])");
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([[[1, 2], []], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([[[1, 2], []], null, null, [[6, null], null]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, [[6, null], null]])"));
+}
+
+TEST(TestCaseWhen, Map) {
+ auto type = map(int64(), utf8());
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([[1, "abc"], [2, "de"]])");
+ auto scalar2 = ScalarFromJSON(type, R"([[3, "fghi"]])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 =
+ ArrayFromJSON(type, R"([[[4, "kl"]], null, [[5, "mn"]], [[6, "o"], [7, "pq"]]])");
+ auto values2 = ArrayFromJSON(type, R"([[[8, "r"], [9, "st"]], [[10, "u"]], null, []])");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2);
+
+ CheckScalar(
+ "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(
+ type,
+ R"([[[1, "abc"], [2, "de"]], [[1, "abc"], [2, "de"]], [[3, "fghi"]], null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar(
+ "case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type,
+ R"([null, null, [[1, "abc"], [2, "de"]], [[1, "abc"], [2, "de"]]])"));
+ CheckScalar(
+ "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(
+ type,
+ R"([[[1, "abc"], [2, "de"]], [[1, "abc"], [2, "de"]], [[3, "fghi"]], [[1, "abc"], [2, "de"]]])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([[[4, "kl"]], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([[[4, "kl"]], null, null, [[6, "o"], [7, "pq"]]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, [[6, "o"], [7, "pq"]]])"));
+}
+
+TEST(TestCaseWhen, FixedSizeListOfInt) {
+ auto type = fixed_size_list(int64(), 2);
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([1, 2])");
+ auto scalar2 = ScalarFromJSON(type, R"([3, null])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"([[4, 5], null, [6, 7], [8, 9]])");
+ auto values2 = ArrayFromJSON(type, R"([[10, 11], [12, null], null, [null, 13]])");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, R"([[1, 2], [1, 2], [3, null], null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, R"([null, null, [1, 2], [1, 2]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(type, R"([[1, 2], [1, 2], [3, null], [1, 2]])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([[4, 5], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([[4, 5], null, null, [8, 9]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, [8, 9]])"));
+}
+
+TEST(TestCaseWhen, FixedSizeListOfString) {
+ auto type = fixed_size_list(utf8(), 2);
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"(["aB", "xYz"])");
+ auto scalar2 = ScalarFromJSON(type, R"(["b", null])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 =
+ ArrayFromJSON(type, R"([["cD", "E"], null, ["de", "gfhi"], ["ef", "g"]])");
+ auto values2 =
+ ArrayFromJSON(type, R"([["fghi", "jk"], ["ghi", null], null, [null, "hi"]])");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2);
+
+ CheckScalar(
+ "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, R"([["aB", "xYz"], ["aB", "xYz"], ["b", null], null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, R"([null, null, ["aB", "xYz"], ["aB", "xYz"]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(
+ type, R"([["aB", "xYz"], ["aB", "xYz"], ["b", null], ["aB", "xYz"]])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([["cD", "E"], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([["cD", "E"], null, null, ["ef", "g"]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, ["ef", "g"]])"));
+}
+
+TEST(TestCaseWhen, StructOfInt) {
+ auto type = struct_({field("a", uint32()), field("b", int64())});
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([1, -2])");
+ auto scalar2 = ScalarFromJSON(type, R"([null, 3])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"([[4, null], null, [5, -6], [7, -8]])");
+ auto values2 = ArrayFromJSON(type, R"([[9, 10], [11, -12], null, [null, null]])");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, R"([[1, -2], [1, -2], [null, 3], null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, R"([null, null, [1, -2], [1, -2]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(type, R"([[1, -2], [1, -2], [null, 3], [1, -2]])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([[4, null], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([[4, null], null, null, [7, -8]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, [7, -8]])"));
+}
+
+TEST(TestCaseWhen, StructOfString) {
+ // More minimal test to check type coverage
+ auto type = struct_({field("a", utf8()), field("b", large_utf8())});
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"(["a", "bc"])");
+ auto scalar2 = ScalarFromJSON(type, R"([null, "d"])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 =
+ ArrayFromJSON(type, R"([["efg", null], null, [null, null], [null, "hi"]])");
+ auto values2 =
+ ArrayFromJSON(type, R"([["j", "k"], [null, "lmnop"], null, ["qr", "stu"]])");
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, R"([["a", "bc"], ["a", "bc"], [null, "d"], null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, R"([null, null, ["a", "bc"], ["a", "bc"]])"));
+ CheckScalar(
+ "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(type, R"([["a", "bc"], ["a", "bc"], [null, "d"], ["a", "bc"]])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([["efg", null], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([["efg", null], null, null, [null, "hi"]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, [null, "hi"]])"));
+}
+
+TEST(TestCaseWhen, StructOfListOfInt) {
+ // More minimal test to check type coverage
+ auto type = struct_({field("a", utf8()), field("b", list(int64()))});
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([null, [1, null]])");
+ auto scalar2 = ScalarFromJSON(type, R"(["b", null])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 =
+ ArrayFromJSON(type, R"([["efg", null], null, [null, null], [null, [null, 1]]])");
+ auto values2 =
+ ArrayFromJSON(type, R"([["j", [2, 3]], [null, [4, 5, 6]], null, ["qr", [7]]])");
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(
+ type, R"([[null, [1, null]], [null, [1, null]], ["b", null], null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar(
+ "case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, R"([null, null, [null, [1, null]], [null, [1, null]]])"));
+ CheckScalar(
+ "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(
+ type,
+ R"([[null, [1, null]], [null, [1, null]], ["b", null], [null, [1, null]]])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([["efg", null], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([["efg", null], null, null, [null, [null, 1]]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, [null, [null, 1]]])"));
+}
+
+TEST(TestCaseWhen, UnionBoolString) {
+ for (const auto& type : std::vector<std::shared_ptr<DataType>>{
+ sparse_union({field("a", boolean()), field("b", utf8())}, {2, 7}),
+ dense_union({field("a", boolean()), field("b", utf8())}, {2, 7})}) {
+ ARROW_SCOPED_TRACE(type->ToString());
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([2, null])");
+ auto scalar2 = ScalarFromJSON(type, R"([7, "foo"])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"([[2, true], null, [7, "bar"], [7, "baz"]])");
+ auto values2 = ArrayFromJSON(type, R"([[7, "spam"], [2, null], null, [7, null]])");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2},
+ values2);
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, R"([[2, null], [2, null], [7, "foo"], null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, R"([null, null, [2, null], [2, null]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(type, R"([[2, null], [2, null], [7, "foo"], [2, null]])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([[2, true], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([[2, true], null, null, [7, "baz"]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, [7, "baz"]])"));
+ }
+}
+
+TEST(TestCaseWhen, DispatchBest) {
+ CheckDispatchBest("case_when", {struct_({field("", boolean())}), int64(), int32()},
+ {struct_({field("", boolean())}), int64(), int64()});
+
+ ASSERT_RAISES(Invalid, CallFunction("case_when", {}));
+ // Too many/too few conditions
+ ASSERT_RAISES(
+ Invalid, CallFunction("case_when", {MakeStruct({ArrayFromJSON(boolean(), "[]")})}));
+ ASSERT_RAISES(Invalid,
+ CallFunction("case_when", {MakeStruct({}), ArrayFromJSON(int64(), "[]"),
+ ArrayFromJSON(int64(), "[]")}));
+ // Conditions must be struct of boolean
+ ASSERT_RAISES(TypeError,
+ CallFunction("case_when", {MakeStruct({ArrayFromJSON(int64(), "[]")}),
+ ArrayFromJSON(int64(), "[]")}));
+ ASSERT_RAISES(TypeError, CallFunction("case_when", {ArrayFromJSON(boolean(), "[true]"),
+ ArrayFromJSON(int32(), "[0]")}));
+ // Values must have compatible types
+ ASSERT_RAISES(NotImplemented,
+ CallFunction("case_when", {MakeStruct({ArrayFromJSON(boolean(), "[]")}),
+ ArrayFromJSON(int64(), "[]"),
+ ArrayFromJSON(utf8(), "[]")}));
+
+ // Do not dictionary-decode when we have only dictionary values
+ CheckDispatchBest("case_when",
+ {struct_({field("", boolean())}), dictionary(int64(), utf8()),
+ dictionary(int64(), utf8())},
+ {struct_({field("", boolean())}), dictionary(int64(), utf8()),
+ dictionary(int64(), utf8())});
+
+ // Dictionary-decode if we have a mix
+ CheckDispatchBest(
+ "case_when", {struct_({field("", boolean())}), dictionary(int64(), utf8()), utf8()},
+ {struct_({field("", boolean())}), utf8(), utf8()});
+}
+
+template <typename Type>
+class TestCoalesceNumeric : public ::testing::Test {};
+template <typename Type>
+class TestCoalesceBinary : public ::testing::Test {};
+template <typename Type>
+class TestCoalesceList : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestCoalesceNumeric, NumericBasedTypes);
+TYPED_TEST_SUITE(TestCoalesceBinary, BinaryArrowTypes);
+TYPED_TEST_SUITE(TestCoalesceList, ListArrowTypes);
+
+TYPED_TEST(TestCoalesceNumeric, Basics) {
+ auto type = default_type_instance<TypeParam>();
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "20");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, "[null, 10, 11, 12]");
+ auto values2 = ArrayFromJSON(type, "[13, 14, 15, 16]");
+ auto values3 = ArrayFromJSON(type, "[17, 18, 19, null]");
+ // N.B. all-scalar cases are checked in CheckScalar
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, "[20, 20, 20, 20]"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar("coalesce", {values1, scalar1}, ArrayFromJSON(type, "[20, 10, 11, 12]"));
+ CheckScalar("coalesce", {values1, values2}, ArrayFromJSON(type, "[13, 10, 11, 12]"));
+ CheckScalar("coalesce", {values1, values2, values3},
+ ArrayFromJSON(type, "[13, 10, 11, 12]"));
+ CheckScalar("coalesce", {scalar1, values1}, ArrayFromJSON(type, "[20, 20, 20, 20]"));
+}
+
+TYPED_TEST(TestCoalesceNumeric, ListOfType) {
+ auto type = list(default_type_instance<TypeParam>());
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "[20, 24]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, "[null, [10, null, 20], [], [null, null]]");
+ auto values2 = ArrayFromJSON(type, "[[23], [14, 24], [null, 15], [16]]");
+ auto values3 = ArrayFromJSON(type, "[[17, 18], [19], [], null]");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, "[[20, 24], [20, 24], [20, 24], [20, 24]]"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar("coalesce", {values1, scalar1},
+ ArrayFromJSON(type, "[[20, 24], [10, null, 20], [], [null, null]]"));
+ CheckScalar("coalesce", {values1, values2},
+ ArrayFromJSON(type, "[[23], [10, null, 20], [], [null, null]]"));
+ CheckScalar("coalesce", {values1, values2, values3},
+ ArrayFromJSON(type, "[[23], [10, null, 20], [], [null, null]]"));
+ CheckScalar("coalesce", {scalar1, values1},
+ ArrayFromJSON(type, "[[20, 24], [20, 24], [20, 24], [20, 24]]"));
+}
+
+TYPED_TEST(TestCoalesceBinary, Basics) {
+ auto type = default_type_instance<TypeParam>();
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"("a")");
+ auto values_null = ArrayFromJSON(type, R"([null, null, null, null])");
+ auto values1 = ArrayFromJSON(type, R"([null, "bc", "def", "ghij"])");
+ auto values2 = ArrayFromJSON(type, R"(["klmno", "p", "qr", "stu"])");
+ auto values3 = ArrayFromJSON(type, R"(["vwxy", "zabc", "d", null])");
+ // N.B. all-scalar cases are checked in CheckScalar
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, R"(["a", "a", "a", "a"])"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar("coalesce", {values1, scalar1},
+ ArrayFromJSON(type, R"(["a", "bc", "def", "ghij"])"));
+ CheckScalar("coalesce", {values1, values2},
+ ArrayFromJSON(type, R"(["klmno", "bc", "def", "ghij"])"));
+ CheckScalar("coalesce", {values1, values2, values3},
+ ArrayFromJSON(type, R"(["klmno", "bc", "def", "ghij"])"));
+ CheckScalar("coalesce", {scalar1, values1},
+ ArrayFromJSON(type, R"(["a", "a", "a", "a"])"));
+}
+
+TYPED_TEST(TestCoalesceList, ListOfString) {
+ auto type = std::make_shared<TypeParam>(utf8());
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([null, "a"])");
+ auto values_null = ArrayFromJSON(type, R"([null, null, null, null])");
+ auto values1 = ArrayFromJSON(type, R"([null, ["bc", null], ["def"], []])");
+ auto values2 = ArrayFromJSON(type, R"([["klmno"], ["p"], ["qr", null], ["stu"]])");
+ auto values3 = ArrayFromJSON(type, R"([["vwxy"], [], ["d"], null])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar(
+ "coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, R"([[null, "a"], [null, "a"], [null, "a"], [null, "a"]])"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar("coalesce", {values1, scalar1},
+ ArrayFromJSON(type, R"([[null, "a"], ["bc", null], ["def"], []])"));
+ CheckScalar("coalesce", {values1, values2},
+ ArrayFromJSON(type, R"([["klmno"], ["bc", null], ["def"], []])"));
+ CheckScalar("coalesce", {values1, values2, values3},
+ ArrayFromJSON(type, R"([["klmno"], ["bc", null], ["def"], []])"));
+ CheckScalar(
+ "coalesce", {scalar1, values1},
+ ArrayFromJSON(type, R"([[null, "a"], [null, "a"], [null, "a"], [null, "a"]])"));
+}
+
+// More minimal tests to check type coverage
+TYPED_TEST(TestCoalesceList, ListOfBool) {
+ auto type = std::make_shared<TypeParam>(boolean());
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "[true, false, null]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, "[null, [true, null, true], [], [null, null]]");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1},
+ ArrayFromJSON(type,
+ "[[true, false, null], [true, false, null], [true, false, "
+ "null], [true, false, null]]"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+}
+
+TYPED_TEST(TestCoalesceList, ListOfInt) {
+ auto type = std::make_shared<TypeParam>(int64());
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "[20, 24]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, "[null, [10, null, 20], [], [null, null]]");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, "[[20, 24], [20, 24], [20, 24], [20, 24]]"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+}
+
+TYPED_TEST(TestCoalesceList, ListOfDayTimeInterval) {
+ auto type = std::make_shared<TypeParam>(day_time_interval());
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "[[20, 24], null]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 =
+ ArrayFromJSON(type, "[null, [[10, 12], null, [20, 22]], [], [null, null]]");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar(
+ "coalesce", {values_null, scalar1},
+ ArrayFromJSON(
+ type,
+ "[[[20, 24], null], [[20, 24], null], [[20, 24], null], [[20, 24], null]]"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+}
+
+TYPED_TEST(TestCoalesceList, ListOfDecimal) {
+ for (auto ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ auto type = std::make_shared<TypeParam>(ty);
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"(["0.42", null])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"([null, ["1.23"], [], [null, null]])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar(
+ "coalesce", {values_null, scalar1},
+ ArrayFromJSON(
+ type, R"([["0.42", null], ["0.42", null], ["0.42", null], ["0.42", null]])"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ }
+}
+
+TYPED_TEST(TestCoalesceList, ListOfFixedSizeBinary) {
+ auto type = std::make_shared<TypeParam>(fixed_size_binary(3));
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"(["ab!", null])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"([null, ["def"], [], [null, null]])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar(
+ "coalesce", {values_null, scalar1},
+ ArrayFromJSON(type,
+ R"([["ab!", null], ["ab!", null], ["ab!", null], ["ab!", null]])"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+}
+
+TYPED_TEST(TestCoalesceList, ListOfListOfInt) {
+ auto type = std::make_shared<TypeParam>(std::make_shared<TypeParam>(int64()));
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "[[20], null]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, "[null, [[10, 12], null, []], [], [null, null]]");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar(
+ "coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, "[[[20], null], [[20], null], [[20], null], [[20], null]]"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+}
+
+TYPED_TEST(TestCoalesceList, Errors) {
+ auto type1 = std::make_shared<TypeParam>(int64());
+ auto type2 = std::make_shared<TypeParam>(utf8());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError, ::testing::HasSubstr("All types must be compatible"),
+ CallFunction("coalesce", {
+ ArrayFromJSON(type1, "[null]"),
+ ArrayFromJSON(type2, "[null]"),
+ }));
+}
+
+TEST(TestCoalesce, Null) {
+ auto type = null();
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar_null}, values_null);
+}
+
+TEST(TestCoalesce, Boolean) {
+ auto type = boolean();
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "false");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, "[null, true, false, true]");
+ auto values2 = ArrayFromJSON(type, "[true, false, true, false]");
+ auto values3 = ArrayFromJSON(type, "[false, true, false, null]");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, "[false, false, false, false]"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar("coalesce", {values1, scalar1},
+ ArrayFromJSON(type, "[false, true, false, true]"));
+ CheckScalar("coalesce", {values1, values2},
+ ArrayFromJSON(type, "[true, true, false, true]"));
+ CheckScalar("coalesce", {values1, values2, values3},
+ ArrayFromJSON(type, "[true, true, false, true]"));
+ CheckScalar("coalesce", {scalar1, values1},
+ ArrayFromJSON(type, "[false, false, false, false]"));
+}
+
+TEST(TestCoalesce, DayTimeInterval) {
+ auto type = day_time_interval();
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "[1, 2]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, "[null, [3, 4], [5, 6], [7, 8]]");
+ auto values2 = ArrayFromJSON(type, "[[9, 10], [11, 12], [13, 14], [15, 16]]");
+ auto values3 = ArrayFromJSON(type, "[[17, 18], [19, 20], [21, 22], null]");
+ // N.B. all-scalar cases are checked in CheckScalar
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, "[[1, 2], [1, 2], [1, 2], [1, 2]]"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar("coalesce", {values1, scalar1},
+ ArrayFromJSON(type, "[[1, 2], [3, 4], [5, 6], [7, 8]]"));
+ CheckScalar("coalesce", {values1, values2},
+ ArrayFromJSON(type, "[[9, 10], [3, 4], [5, 6], [7, 8]]"));
+ CheckScalar("coalesce", {values1, values2, values3},
+ ArrayFromJSON(type, "[[9, 10], [3, 4], [5, 6], [7, 8]]"));
+ CheckScalar("coalesce", {scalar1, values1},
+ ArrayFromJSON(type, "[[1, 2], [1, 2], [1, 2], [1, 2]]"));
+}
+
+TEST(TestCoalesce, Decimal) {
+ for (const auto& type :
+ std::vector<std::shared_ptr<DataType>>{decimal128(3, 2), decimal256(3, 2)}) {
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"("1.23")");
+ auto values_null = ArrayFromJSON(type, R"([null, null, null, null])");
+ auto values1 = ArrayFromJSON(type, R"([null, "4.56", "7.89", "1.34"])");
+ auto values2 = ArrayFromJSON(type, R"(["1.45", "2.34", "3.45", "4.56"])");
+ auto values3 = ArrayFromJSON(type, R"(["5.67", "6.78", "7.91", null])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, R"(["1.23", "1.23", "1.23", "1.23"])"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar("coalesce", {values1, scalar1},
+ ArrayFromJSON(type, R"(["1.23", "4.56", "7.89", "1.34"])"));
+ CheckScalar("coalesce", {values1, values2},
+ ArrayFromJSON(type, R"(["1.45", "4.56", "7.89", "1.34"])"));
+ CheckScalar("coalesce", {values1, values2, values3},
+ ArrayFromJSON(type, R"(["1.45", "4.56", "7.89", "1.34"])"));
+ CheckScalar("coalesce", {scalar1, values1},
+ ArrayFromJSON(type, R"(["1.23", "1.23", "1.23", "1.23"])"));
+ }
+ // Ensure promotion
+ CheckScalar("coalesce",
+ {
+ ArrayFromJSON(decimal128(3, 2), R"(["1.23", null])"),
+ ArrayFromJSON(decimal128(4, 1), R"([null, "1.0"])"),
+ },
+ ArrayFromJSON(decimal128(5, 2), R"(["1.23", "1.00"])"));
+}
+
+TEST(TestCoalesce, FixedSizeBinary) {
+ auto type = fixed_size_binary(3);
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"("abc")");
+ auto values_null = ArrayFromJSON(type, R"([null, null, null, null])");
+ auto values1 = ArrayFromJSON(type, R"([null, "def", "ghi", "jkl"])");
+ auto values2 = ArrayFromJSON(type, R"(["mno", "pqr", "stu", "vwx"])");
+ auto values3 = ArrayFromJSON(type, R"(["yza", "bcd", "efg", null])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, R"(["abc", "abc", "abc", "abc"])"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar("coalesce", {values1, scalar1},
+ ArrayFromJSON(type, R"(["abc", "def", "ghi", "jkl"])"));
+ CheckScalar("coalesce", {values1, values2},
+ ArrayFromJSON(type, R"(["mno", "def", "ghi", "jkl"])"));
+ CheckScalar("coalesce", {values1, values2, values3},
+ ArrayFromJSON(type, R"(["mno", "def", "ghi", "jkl"])"));
+ CheckScalar("coalesce", {scalar1, values1},
+ ArrayFromJSON(type, R"(["abc", "abc", "abc", "abc"])"));
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: "
+ "fixed_size_binary[3], but got: fixed_size_binary[2]"),
+ CallFunction("coalesce", {
+ ArrayFromJSON(type, "[null]"),
+ ArrayFromJSON(fixed_size_binary(2), "[null]"),
+ }));
+}
+
+TEST(TestCoalesce, FixedSizeListOfInt) {
+ auto type = fixed_size_list(uint8(), 2);
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([42, null])");
+ auto values_null = ArrayFromJSON(type, R"([null, null, null, null])");
+ auto values1 = ArrayFromJSON(type, R"([null, [2, null], [4, 8], [null, null]])");
+ auto values2 = ArrayFromJSON(type, R"([[1, 5], [16, 32], [64, null], [null, 128]])");
+ auto values3 = ArrayFromJSON(type, R"([[null, null], [1, 3], [9, 27], null])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, R"([[42, null], [42, null], [42, null], [42, null]])"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar("coalesce", {values1, scalar1},
+ ArrayFromJSON(type, R"([[42, null], [2, null], [4, 8], [null, null]])"));
+ CheckScalar("coalesce", {values1, values2},
+ ArrayFromJSON(type, R"([[1, 5], [2, null], [4, 8], [null, null]])"));
+ CheckScalar("coalesce", {values1, values2, values3},
+ ArrayFromJSON(type, R"([[1, 5], [2, null], [4, 8], [null, null]])"));
+ CheckScalar("coalesce", {scalar1, values1},
+ ArrayFromJSON(type, R"([[42, null], [42, null], [42, null], [42, null]])"));
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr(
+ "All types must be compatible, expected: fixed_size_list<item: "
+ "uint8>[2], but got: fixed_size_list<item: uint8>[3]"),
+ CallFunction("coalesce", {
+ ArrayFromJSON(type, "[null]"),
+ ArrayFromJSON(fixed_size_list(uint8(), 3), "[null]"),
+ }));
+}
+
+TEST(TestCoalesce, FixedSizeListOfString) {
+ auto type = fixed_size_list(utf8(), 2);
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"(["abc", null])");
+ auto values_null = ArrayFromJSON(type, R"([null, null, null, null])");
+ auto values1 =
+ ArrayFromJSON(type, R"([null, ["d", null], ["ghi", "jkl"], [null, null]])");
+ auto values2 = ArrayFromJSON(
+ type, R"([["mno", "pq"], ["pqr", "ab"], ["stu", null], [null, "vwx"]])");
+ auto values3 =
+ ArrayFromJSON(type, R"([[null, null], ["a", "bcd"], ["d", "efg"], null])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar(
+ "coalesce", {values_null, scalar1},
+ ArrayFromJSON(type,
+ R"([["abc", null], ["abc", null], ["abc", null], ["abc", null]])"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar("coalesce", {values1, scalar1},
+ ArrayFromJSON(
+ type, R"([["abc", null], ["d", null], ["ghi", "jkl"], [null, null]])"));
+ CheckScalar("coalesce", {values1, values2},
+ ArrayFromJSON(
+ type, R"([["mno", "pq"], ["d", null], ["ghi", "jkl"], [null, null]])"));
+ CheckScalar("coalesce", {values1, values2, values3},
+ ArrayFromJSON(
+ type, R"([["mno", "pq"], ["d", null], ["ghi", "jkl"], [null, null]])"));
+ CheckScalar(
+ "coalesce", {scalar1, values1},
+ ArrayFromJSON(type,
+ R"([["abc", null], ["abc", null], ["abc", null], ["abc", null]])"));
+}
+
+TEST(TestCoalesce, Map) {
+ auto type = map(int64(), utf8());
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([[1, "a"], [5, "bc"]])");
+ auto values_null = ArrayFromJSON(type, R"([null, null, null, null])");
+ auto values1 =
+ ArrayFromJSON(type, R"([null, [[2, "foo"], [4, null]], [[3, "test"]], []])");
+ auto values2 = ArrayFromJSON(
+ type, R"([[[1, "b"]], [[2, "c"]], [[5, "c"], [6, "d"]], [[7, "abc"]]])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1}, *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar(
+ "coalesce", {values1, scalar1},
+ ArrayFromJSON(
+ type,
+ R"([[[1, "a"], [5, "bc"]], [[2, "foo"], [4, null]], [[3, "test"]], []])"));
+ CheckScalar(
+ "coalesce", {values1, values2},
+ ArrayFromJSON(type, R"([[[1, "b"]], [[2, "foo"], [4, null]], [[3, "test"]], []])"));
+ CheckScalar("coalesce", {scalar1, values1}, *MakeArrayFromScalar(*scalar1, 4));
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: map<int64, "
+ "string>, but got: map<int64, int32>"),
+ CallFunction("coalesce", {
+ ArrayFromJSON(type, "[null]"),
+ ArrayFromJSON(map(int64(), int32()), "[null]"),
+ }));
+}
+
+TEST(TestCoalesce, Struct) {
+ auto type = struct_(
+ {field("int", uint32()), field("str", utf8()), field("list", list(int8()))});
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([42, "spam", [null, -1]])");
+ auto values_null = ArrayFromJSON(type, R"([null, null, null, null])");
+ auto values1 = ArrayFromJSON(
+ type, R"([null, [null, "eggs", []], [0, "", [null]], [32, "abc", [1, 2, 3]]])");
+ auto values2 = ArrayFromJSON(
+ type,
+ R"([[21, "foobar", [1, null, 2]], [5, "bar", []], [20, null, null], [1, "", [null]]])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1}, *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar(
+ "coalesce", {values1, scalar1},
+ ArrayFromJSON(
+ type,
+ R"([[42, "spam", [null, -1]], [null, "eggs", []], [0, "", [null]], [32, "abc", [1, 2, 3]]])"));
+ CheckScalar(
+ "coalesce", {values1, values2},
+ ArrayFromJSON(
+ type,
+ R"([[21, "foobar", [1, null, 2]], [null, "eggs", []], [0, "", [null]], [32, "abc", [1, 2, 3]]])"));
+ CheckScalar("coalesce", {scalar1, values1}, *MakeArrayFromScalar(*scalar1, 4));
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: struct<str: "
+ "string>, but got: struct<int: uint16>"),
+ CallFunction("coalesce",
+ {
+ ArrayFromJSON(struct_({field("str", utf8())}), "[null]"),
+ ArrayFromJSON(struct_({field("int", uint16())}), "[null]"),
+ }));
+}
+
+TEST(TestCoalesce, UnionBoolString) {
+ for (const auto& type : {
+ sparse_union({field("a", boolean()), field("b", utf8())}, {2, 7}),
+ dense_union({field("a", boolean()), field("b", utf8())}, {2, 7}),
+ }) {
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([7, "foo"])");
+ auto values_null = ArrayFromJSON(type, R"([null, null, null, null])");
+ auto values1 = ArrayFromJSON(type, R"([null, [2, false], [7, "bar"], [7, "baz"]])");
+ auto values2 =
+ ArrayFromJSON(type, R"([[2, true], [2, false], [7, "foo"], [7, "bar"]])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1}, *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar(
+ "coalesce", {values1, scalar1},
+ ArrayFromJSON(type, R"([[7, "foo"], [2, false], [7, "bar"], [7, "baz"]])"));
+ CheckScalar(
+ "coalesce", {values1, values2},
+ ArrayFromJSON(type, R"([[2, true], [2, false], [7, "bar"], [7, "baz"]])"));
+ CheckScalar("coalesce", {scalar1, values1}, *MakeArrayFromScalar(*scalar1, 4));
+ }
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: "
+ "sparse_union<a: bool=0>, but got: sparse_union<a: int64=0>"),
+ CallFunction(
+ "coalesce",
+ {
+ ArrayFromJSON(sparse_union({field("a", boolean())}), "[[0, true]]"),
+ ArrayFromJSON(sparse_union({field("a", int64())}), "[[0, 1]]"),
+ }));
+}
+
+TEST(TestCoalesce, DispatchBest) {
+ CheckDispatchBest("coalesce", {int8(), float64()}, {float64(), float64()});
+ CheckDispatchBest("coalesce", {int8(), uint32()}, {int64(), int64()});
+ CheckDispatchBest("coalesce", {binary(), utf8()}, {binary(), binary()});
+ CheckDispatchBest("coalesce", {binary(), large_binary()},
+ {large_binary(), large_binary()});
+ CheckDispatchBest("coalesce", {int32(), decimal128(3, 2)},
+ {decimal128(12, 2), decimal128(12, 2)});
+ CheckDispatchBest("coalesce", {float32(), decimal128(3, 2)}, {float64(), float64()});
+ CheckDispatchBest("coalesce", {decimal128(3, 2), decimal256(3, 2)},
+ {decimal256(3, 2), decimal256(3, 2)});
+ CheckDispatchBest("coalesce", {timestamp(TimeUnit::SECOND), date32()},
+ {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND)});
+ CheckDispatchBest("coalesce", {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::MILLI)},
+ {timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)});
+ CheckDispatchFails("coalesce", {
+ sparse_union({field("a", boolean())}),
+ dense_union({field("a", boolean())}),
+ });
+}
+
+template <typename Type>
+class TestChooseNumeric : public ::testing::Test {};
+template <typename Type>
+class TestChooseBinary : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestChooseNumeric, NumericBasedTypes);
+TYPED_TEST_SUITE(TestChooseBinary, BinaryArrowTypes);
+
+TYPED_TEST(TestChooseNumeric, FixedSize) {
+ auto type = default_type_instance<TypeParam>();
+ auto indices1 = ArrayFromJSON(int64(), "[0, 1, 0, 1, null]");
+ auto values1 = ArrayFromJSON(type, "[10, 11, null, null, 14]");
+ auto values2 = ArrayFromJSON(type, "[20, 21, null, null, 24]");
+ auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]");
+ CheckScalar("choose", {indices1, values1, values2},
+ ArrayFromJSON(type, "[10, 21, null, null, null]"));
+ CheckScalar("choose", {indices1, ScalarFromJSON(type, "1"), values1},
+ ArrayFromJSON(type, "[1, 11, 1, null, null]"));
+ // Mixed scalar and array (note CheckScalar checks all-scalar cases for us)
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls);
+ auto scalar1 = ScalarFromJSON(type, "42");
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar1, values2},
+ *MakeArrayFromScalar(*scalar1, 5));
+ CheckScalar("choose", {ScalarFromJSON(int64(), "1"), scalar1, values2}, values2);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls);
+ auto scalar_null = ScalarFromJSON(type, "null");
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar_null, values2},
+ *MakeArrayOfNull(type, 5));
+}
+
+TYPED_TEST(TestChooseBinary, Basics) {
+ auto type = default_type_instance<TypeParam>();
+ auto indices1 = ArrayFromJSON(int64(), "[0, 1, 0, 1, null]");
+ auto values1 = ArrayFromJSON(type, R"(["a", "bc", null, null, "def"])");
+ auto values2 = ArrayFromJSON(type, R"(["ghij", "klmno", null, null, "pqrstu"])");
+ auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]");
+ CheckScalar("choose", {indices1, values1, values2},
+ ArrayFromJSON(type, R"(["a", "klmno", null, null, null])"));
+ CheckScalar("choose", {indices1, ScalarFromJSON(type, R"("foo")"), values1},
+ ArrayFromJSON(type, R"(["foo", "bc", "foo", null, null])"));
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls);
+ auto scalar1 = ScalarFromJSON(type, R"("abcd")");
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar1, values2},
+ *MakeArrayFromScalar(*scalar1, 5));
+ CheckScalar("choose", {ScalarFromJSON(int64(), "1"), scalar1, values2}, values2);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls);
+ auto scalar_null = ScalarFromJSON(type, "null");
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar_null, values2},
+ *MakeArrayOfNull(type, 5));
+}
+
+TEST(TestChoose, Null) {
+ auto type = null();
+ auto indices1 = ArrayFromJSON(int64(), "[0, 1, 0, 1, null]");
+ auto nulls = *MakeArrayOfNull(type, 5);
+ CheckScalar("choose", {indices1, nulls, nulls}, nulls);
+ CheckScalar("choose", {indices1, MakeNullScalar(type), nulls}, nulls);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), nulls, nulls}, nulls);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "1"), nulls, nulls}, nulls);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "null"), nulls, nulls}, nulls);
+ auto scalar_null = ScalarFromJSON(type, "null");
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar_null, nulls}, nulls);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "1"), scalar_null, nulls}, nulls);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "null"), nulls, nulls}, nulls);
+}
+
+TEST(TestChoose, Boolean) {
+ auto type = boolean();
+ auto indices1 = ArrayFromJSON(int64(), "[0, 1, 0, 1, null]");
+ auto values1 = ArrayFromJSON(type, "[true, true, null, null, true]");
+ auto values2 = ArrayFromJSON(type, "[false, false, null, null, false]");
+ auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]");
+ CheckScalar("choose", {indices1, values1, values2},
+ ArrayFromJSON(type, "[true, false, null, null, null]"));
+ CheckScalar("choose", {indices1, ScalarFromJSON(type, "false"), values1},
+ ArrayFromJSON(type, "[false, true, false, null, null]"));
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls);
+ auto scalar1 = ScalarFromJSON(type, "true");
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar1, values2},
+ *MakeArrayFromScalar(*scalar1, 5));
+ CheckScalar("choose", {ScalarFromJSON(int64(), "1"), scalar1, values2}, values2);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls);
+ auto scalar_null = ScalarFromJSON(type, "null");
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar_null, values2},
+ *MakeArrayOfNull(type, 5));
+}
+
+TEST(TestChoose, DayTimeInterval) {
+ auto type = day_time_interval();
+ auto indices1 = ArrayFromJSON(int64(), "[0, 1, 0, 1, null]");
+ auto values1 = ArrayFromJSON(type, "[[10, 1], [10, 1], null, null, [10, 1]]");
+ auto values2 = ArrayFromJSON(type, "[[2, 20], [2, 20], null, null, [2, 20]]");
+ auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]");
+ CheckScalar("choose", {indices1, values1, values2},
+ ArrayFromJSON(type, "[[10, 1], [2, 20], null, null, null]"));
+ CheckScalar("choose", {indices1, ScalarFromJSON(type, "[1, 2]"), values1},
+ ArrayFromJSON(type, "[[1, 2], [10, 1], [1, 2], null, null]"));
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls);
+ auto scalar1 = ScalarFromJSON(type, "[10, 1]");
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar1, values2},
+ *MakeArrayFromScalar(*scalar1, 5));
+ CheckScalar("choose", {ScalarFromJSON(int64(), "1"), scalar1, values2}, values2);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls);
+ auto scalar_null = ScalarFromJSON(type, "null");
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar_null, values2},
+ *MakeArrayOfNull(type, 5));
+}
+
+TEST(TestChoose, Decimal) {
+ for (const auto& type : {decimal128(3, 2), decimal256(3, 2)}) {
+ auto indices1 = ArrayFromJSON(int64(), "[0, 1, 0, 1, null]");
+ auto values1 = ArrayFromJSON(type, R"(["1.23", "1.24", null, null, "1.25"])");
+ auto values2 = ArrayFromJSON(type, R"(["4.56", "4.57", null, null, "4.58"])");
+ auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]");
+ CheckScalar("choose", {indices1, values1, values2},
+ ArrayFromJSON(type, R"(["1.23", "4.57", null, null, null])"));
+ CheckScalar("choose", {indices1, ScalarFromJSON(type, R"("2.34")"), values1},
+ ArrayFromJSON(type, R"(["2.34", "1.24", "2.34", null, null])"));
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls);
+ auto scalar1 = ScalarFromJSON(type, R"("1.23")");
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar1, values2},
+ *MakeArrayFromScalar(*scalar1, 5));
+ CheckScalar("choose", {ScalarFromJSON(int64(), "1"), scalar1, values2}, values2);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls);
+ auto scalar_null = ScalarFromJSON(type, "null");
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar_null, values2},
+ *MakeArrayOfNull(type, 5));
+ }
+}
+
+TEST(TestChoose, FixedSizeBinary) {
+ auto type = fixed_size_binary(3);
+ auto indices1 = ArrayFromJSON(int64(), "[0, 1, 0, 1, null]");
+ auto values1 = ArrayFromJSON(type, R"(["abc", "abd", null, null, "abe"])");
+ auto values2 = ArrayFromJSON(type, R"(["def", "deg", null, null, "deh"])");
+ auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]");
+ CheckScalar("choose", {indices1, values1, values2},
+ ArrayFromJSON(type, R"(["abc", "deg", null, null, null])"));
+ CheckScalar("choose", {indices1, ScalarFromJSON(type, R"("xyz")"), values1},
+ ArrayFromJSON(type, R"(["xyz", "abd", "xyz", null, null])"));
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls);
+ auto scalar1 = ScalarFromJSON(type, R"("abc")");
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar1, values2},
+ *MakeArrayFromScalar(*scalar1, 5));
+ CheckScalar("choose", {ScalarFromJSON(int64(), "1"), scalar1, values2}, values2);
+ CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls);
+ auto scalar_null = ScalarFromJSON(type, "null");
+ CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar_null, values2},
+ *MakeArrayOfNull(type, 5));
+}
+
+TEST(TestChooseKernel, DispatchBest) {
+ ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction("choose"));
+ auto Check = [&](std::vector<ValueDescr> original_values) {
+ auto values = original_values;
+ ARROW_EXPECT_OK(function->DispatchBest(&values));
+ return values;
+ };
+
+ // Since DispatchBest for this kernel pulls tricks, we can't compare it to DispatchExact
+ // as CheckDispatchBest does
+ for (auto ty :
+ {int8(), int16(), int32(), int64(), uint8(), uint16(), uint32(), uint64()}) {
+ // Index always promoted to int64
+ EXPECT_EQ((std::vector<ValueDescr>{int64(), ty}), Check({ty, ty}));
+ EXPECT_EQ((std::vector<ValueDescr>{int64(), int64(), int64()}),
+ Check({ty, ty, int64()}));
+ }
+ // Other arguments promoted separately from index
+ EXPECT_EQ((std::vector<ValueDescr>{int64(), int32(), int32()}),
+ Check({int8(), int32(), uint8()}));
+}
+
+TEST(TestChooseKernel, Errors) {
+ ASSERT_RAISES(Invalid, CallFunction("choose", {}));
+ ASSERT_RAISES(Invalid, CallFunction("choose", {ArrayFromJSON(int64(), "[]")}));
+ ASSERT_RAISES(Invalid, CallFunction("choose", {ArrayFromJSON(utf8(), "[\"a\"]"),
+ ArrayFromJSON(int64(), "[0]")}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ IndexError, ::testing::HasSubstr("choose: index 1 out of range"),
+ CallFunction("choose",
+ {ArrayFromJSON(int64(), "[1]"), ArrayFromJSON(int32(), "[0]")}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ IndexError, ::testing::HasSubstr("choose: index -1 out of range"),
+ CallFunction("choose",
+ {ArrayFromJSON(int64(), "[-1]"), ArrayFromJSON(int32(), "[0]")}));
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_nested.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_nested.cc
new file mode 100644
index 000000000..aafaeb341
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_nested.cc
@@ -0,0 +1,317 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Vector kernels involving nested types
+
+#include "arrow/array/array_base.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/result.h"
+#include "arrow/util/bit_block_counter.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+namespace {
+
+template <typename Type, typename offset_type = typename Type::offset_type>
+Status ListValueLength(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ using ScalarType = typename TypeTraits<Type>::ScalarType;
+ using OffsetScalarType = typename TypeTraits<Type>::OffsetScalarType;
+
+ if (batch[0].kind() == Datum::ARRAY) {
+ typename TypeTraits<Type>::ArrayType list(batch[0].array());
+ ArrayData* out_arr = out->mutable_array();
+ auto out_values = out_arr->GetMutableValues<offset_type>(1);
+ const offset_type* offsets = list.raw_value_offsets();
+ ::arrow::internal::VisitBitBlocksVoid(
+ list.data()->buffers[0], list.offset(), list.length(),
+ [&](int64_t position) {
+ *out_values++ = offsets[position + 1] - offsets[position];
+ },
+ [&]() { *out_values++ = 0; });
+ } else {
+ const auto& arg0 = batch[0].scalar_as<ScalarType>();
+ if (arg0.is_valid) {
+ checked_cast<OffsetScalarType*>(out->scalar().get())->value =
+ static_cast<offset_type>(arg0.value->length());
+ }
+ }
+
+ return Status::OK();
+}
+
+Status FixedSizeListValueLength(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ using offset_type = typename FixedSizeListType::offset_type;
+ auto width = checked_cast<const FixedSizeListType&>(*batch[0].type()).list_size();
+ if (batch[0].kind() == Datum::ARRAY) {
+ const auto& arr = *batch[0].array();
+ ArrayData* out_arr = out->mutable_array();
+ auto* out_values = out_arr->GetMutableValues<offset_type>(1);
+ std::fill(out_values, out_values + arr.length, width);
+ } else {
+ const auto& arg0 = batch[0].scalar_as<FixedSizeListScalar>();
+ if (arg0.is_valid) {
+ checked_cast<Int32Scalar*>(out->scalar().get())->value = width;
+ }
+ }
+
+ return Status::OK();
+}
+
+const FunctionDoc list_value_length_doc{
+ "Compute list lengths",
+ ("`lists` must have a list-like type.\n"
+ "For each non-null value in `lists`, its length is emitted.\n"
+ "Null values emit a null in the output."),
+ {"lists"}};
+
+template <typename Type, typename IndexType>
+struct ListElementArray {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ using ListArrayType = typename TypeTraits<Type>::ArrayType;
+ using IndexScalarType = typename TypeTraits<IndexType>::ScalarType;
+ const auto& index_scalar = batch[1].scalar_as<IndexScalarType>();
+ if (ARROW_PREDICT_FALSE(!index_scalar.is_valid)) {
+ return Status::Invalid("Index must not be null");
+ }
+ ListArrayType list_array(batch[0].array());
+ auto index = index_scalar.value;
+ if (ARROW_PREDICT_FALSE(index < 0)) {
+ return Status::Invalid("Index ", index,
+ " is out of bounds: should be greater than or equal to 0");
+ }
+ std::unique_ptr<ArrayBuilder> builder;
+ RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), list_array.value_type(), &builder));
+ RETURN_NOT_OK(builder->Reserve(list_array.length()));
+ for (int i = 0; i < list_array.length(); ++i) {
+ if (list_array.IsNull(i)) {
+ RETURN_NOT_OK(builder->AppendNull());
+ continue;
+ }
+ std::shared_ptr<arrow::Array> value_array = list_array.value_slice(i);
+ auto len = value_array->length();
+ if (ARROW_PREDICT_FALSE(index >= static_cast<typename IndexType::c_type>(len))) {
+ return Status::Invalid("Index ", index, " is out of bounds: should be in [0, ",
+ len, ")");
+ }
+ RETURN_NOT_OK(builder->AppendArraySlice(*value_array->data(), index, 1));
+ }
+ ARROW_ASSIGN_OR_RAISE(auto result, builder->Finish());
+ out->value = result->data();
+ return Status::OK();
+ }
+};
+
+template <typename, typename IndexType>
+struct ListElementScalar {
+ static Status Exec(KernelContext* /*ctx*/, const ExecBatch& batch, Datum* out) {
+ using IndexScalarType = typename TypeTraits<IndexType>::ScalarType;
+ const auto& index_scalar = batch[1].scalar_as<IndexScalarType>();
+ if (ARROW_PREDICT_FALSE(!index_scalar.is_valid)) {
+ return Status::Invalid("Index must not be null");
+ }
+ const auto& list_scalar = batch[0].scalar_as<BaseListScalar>();
+ if (ARROW_PREDICT_FALSE(!list_scalar.is_valid)) {
+ out->value = MakeNullScalar(
+ checked_cast<const BaseListType&>(*batch[0].type()).value_type());
+ return Status::OK();
+ }
+ auto list = list_scalar.value;
+ auto index = index_scalar.value;
+ auto len = list->length();
+ if (ARROW_PREDICT_FALSE(index < 0 ||
+ index >= static_cast<typename IndexType::c_type>(len))) {
+ return Status::Invalid("Index ", index, " is out of bounds: should be in [0, ", len,
+ ")");
+ }
+ ARROW_ASSIGN_OR_RAISE(out->value, list->GetScalar(index));
+ return Status::OK();
+ }
+};
+
+template <typename InListType>
+void AddListElementArrayKernels(ScalarFunction* func) {
+ for (const auto& index_type : IntTypes()) {
+ auto inputs = {InputType::Array(InListType::type_id), InputType::Scalar(index_type)};
+ auto output = OutputType{ListValuesType};
+ auto sig = KernelSignature::Make(std::move(inputs), std::move(output),
+ /*is_varargs=*/false);
+ auto scalar_exec = GenerateInteger<ListElementArray, InListType>({index_type->id()});
+ ScalarKernel kernel{std::move(sig), std::move(scalar_exec)};
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+ }
+}
+
+void AddListElementArrayKernels(ScalarFunction* func) {
+ AddListElementArrayKernels<ListType>(func);
+ AddListElementArrayKernels<LargeListType>(func);
+ AddListElementArrayKernels<FixedSizeListType>(func);
+}
+
+void AddListElementScalarKernels(ScalarFunction* func) {
+ for (const auto list_type_id : {Type::LIST, Type::LARGE_LIST, Type::FIXED_SIZE_LIST}) {
+ for (const auto& index_type : IntTypes()) {
+ auto inputs = {InputType::Scalar(list_type_id), InputType::Scalar(index_type)};
+ auto output = OutputType{ListValuesType};
+ auto sig = KernelSignature::Make(std::move(inputs), std::move(output),
+ /*is_varargs=*/false);
+ auto scalar_exec = GenerateInteger<ListElementScalar, void>({index_type->id()});
+ ScalarKernel kernel{std::move(sig), std::move(scalar_exec)};
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+ }
+ }
+}
+
+const FunctionDoc list_element_doc(
+ "Compute elements using of nested list values using an index",
+ ("`lists` must have a list-like type.\n"
+ "For each value in each list of `lists`, the element at `index`\n"
+ "is emitted. Null values emit a null in the output."),
+ {"lists", "index"});
+
+Result<ValueDescr> MakeStructResolve(KernelContext* ctx,
+ const std::vector<ValueDescr>& descrs) {
+ auto names = OptionsWrapper<MakeStructOptions>::Get(ctx).field_names;
+ auto nullable = OptionsWrapper<MakeStructOptions>::Get(ctx).field_nullability;
+ auto metadata = OptionsWrapper<MakeStructOptions>::Get(ctx).field_metadata;
+
+ if (names.size() == 0) {
+ names.resize(descrs.size());
+ nullable.resize(descrs.size(), true);
+ metadata.resize(descrs.size(), nullptr);
+ int i = 0;
+ for (auto& name : names) {
+ name = std::to_string(i++);
+ }
+ } else if (names.size() != descrs.size() || nullable.size() != descrs.size() ||
+ metadata.size() != descrs.size()) {
+ return Status::Invalid("make_struct() was passed ", descrs.size(), " arguments but ",
+ names.size(), " field names, ", nullable.size(),
+ " nullability bits, and ", metadata.size(),
+ " metadata dictionaries.");
+ }
+
+ size_t i = 0;
+ FieldVector fields(descrs.size());
+
+ ValueDescr::Shape shape = ValueDescr::SCALAR;
+ for (const ValueDescr& descr : descrs) {
+ if (descr.shape != ValueDescr::SCALAR) {
+ shape = ValueDescr::ARRAY;
+ } else {
+ switch (descr.type->id()) {
+ case Type::EXTENSION:
+ case Type::DENSE_UNION:
+ case Type::SPARSE_UNION:
+ return Status::NotImplemented("Broadcasting scalars of type ", *descr.type);
+ default:
+ break;
+ }
+ }
+
+ fields[i] =
+ field(std::move(names[i]), descr.type, nullable[i], std::move(metadata[i]));
+ ++i;
+ }
+
+ return ValueDescr{struct_(std::move(fields)), shape};
+}
+
+Status MakeStructExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ ARROW_ASSIGN_OR_RAISE(auto descr, MakeStructResolve(ctx, batch.GetDescriptors()));
+
+ for (int i = 0; i < batch.num_values(); ++i) {
+ const auto& field = checked_cast<const StructType&>(*descr.type).field(i);
+ if (batch[i].null_count() > 0 && !field->nullable()) {
+ return Status::Invalid("Output field ", field, " (#", i,
+ ") does not allow nulls but the corresponding "
+ "argument was not entirely valid.");
+ }
+ }
+
+ if (descr.shape == ValueDescr::SCALAR) {
+ ScalarVector scalars(batch.num_values());
+ for (int i = 0; i < batch.num_values(); ++i) {
+ scalars[i] = batch[i].scalar();
+ }
+
+ *out =
+ Datum(std::make_shared<StructScalar>(std::move(scalars), std::move(descr.type)));
+ return Status::OK();
+ }
+
+ ArrayVector arrays(batch.num_values());
+ for (int i = 0; i < batch.num_values(); ++i) {
+ if (batch[i].is_array()) {
+ arrays[i] = batch[i].make_array();
+ continue;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(arrays[i], MakeArrayFromScalar(*batch[i].scalar(), batch.length,
+ ctx->memory_pool()));
+ }
+
+ *out = std::make_shared<StructArray>(descr.type, batch.length, std::move(arrays));
+ return Status::OK();
+}
+
+const FunctionDoc make_struct_doc{"Wrap Arrays into a StructArray",
+ ("Names of the StructArray's fields are\n"
+ "specified through MakeStructOptions."),
+ {"*args"},
+ "MakeStructOptions"};
+
+} // namespace
+
+void RegisterScalarNested(FunctionRegistry* registry) {
+ auto list_value_length = std::make_shared<ScalarFunction>(
+ "list_value_length", Arity::Unary(), &list_value_length_doc);
+ DCHECK_OK(list_value_length->AddKernel({InputType(Type::LIST)}, int32(),
+ ListValueLength<ListType>));
+ DCHECK_OK(list_value_length->AddKernel({InputType(Type::FIXED_SIZE_LIST)}, int32(),
+ FixedSizeListValueLength));
+ DCHECK_OK(list_value_length->AddKernel({InputType(Type::LARGE_LIST)}, int64(),
+ ListValueLength<LargeListType>));
+ DCHECK_OK(registry->AddFunction(std::move(list_value_length)));
+
+ auto list_element = std::make_shared<ScalarFunction>("list_element", Arity::Binary(),
+ &list_element_doc);
+ AddListElementArrayKernels(list_element.get());
+ AddListElementScalarKernels(list_element.get());
+ DCHECK_OK(registry->AddFunction(std::move(list_element)));
+
+ static MakeStructOptions kDefaultMakeStructOptions;
+ auto make_struct_function = std::make_shared<ScalarFunction>(
+ "make_struct", Arity::VarArgs(), &make_struct_doc, &kDefaultMakeStructOptions);
+
+ ScalarKernel kernel{KernelSignature::Make({InputType{}}, OutputType{MakeStructResolve},
+ /*is_varargs=*/true),
+ MakeStructExec, OptionsWrapper<MakeStructOptions>::Init};
+ kernel.null_handling = NullHandling::OUTPUT_NOT_NULL;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ DCHECK_OK(make_struct_function->AddKernel(std::move(kernel)));
+ DCHECK_OK(registry->AddFunction(std::move(make_struct_function)));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_nested_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_nested_test.cc
new file mode 100644
index 000000000..cb1625739
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_nested_test.cc
@@ -0,0 +1,249 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include "arrow/chunked_array.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/result.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+namespace compute {
+
+static std::shared_ptr<DataType> GetOffsetType(const DataType& type) {
+ return type.id() == Type::LIST ? int32() : int64();
+}
+
+TEST(TestScalarNested, ListValueLength) {
+ for (auto ty : {list(int32()), large_list(int32())}) {
+ CheckScalarUnary("list_value_length", ty, "[[0, null, 1], null, [2, 3], []]",
+ GetOffsetType(*ty), "[3, null, 2, 0]");
+ }
+
+ CheckScalarUnary("list_value_length", fixed_size_list(int32(), 3),
+ "[[0, null, 1], null, [2, 3, 4], [1, 2, null]]", int32(),
+ "[3, null, 3, 3]");
+}
+
+TEST(TestScalarNested, ListElementNonFixedListWithNulls) {
+ auto sample = "[[7, 5, 81], [6, null, 4, 7, 8], [3, 12, 2, 0], [1, 9], null]";
+ for (auto ty : NumericTypes()) {
+ for (auto list_type : {list(ty), large_list(ty)}) {
+ auto input = ArrayFromJSON(list_type, sample);
+ auto null_input = ArrayFromJSON(list_type, "[null]");
+ for (auto index_type : IntTypes()) {
+ auto index = ScalarFromJSON(index_type, "1");
+ auto expected = ArrayFromJSON(ty, "[5, null, 12, 9, null]");
+ auto expected_null = ArrayFromJSON(ty, "[null]");
+ CheckScalar("list_element", {input, index}, expected);
+ CheckScalar("list_element", {null_input, index}, expected_null);
+ }
+ }
+ }
+}
+
+TEST(TestScalarNested, ListElementFixedList) {
+ auto sample = "[[7, 5, 81], [6, 4, 8], [3, 12, 2], [1, 43, 87]]";
+ for (auto ty : NumericTypes()) {
+ auto input = ArrayFromJSON(fixed_size_list(ty, 3), sample);
+ for (auto index_type : IntTypes()) {
+ auto index = ScalarFromJSON(index_type, "0");
+ auto expected = ArrayFromJSON(ty, "[7, 6, 3, 1]");
+ CheckScalar("list_element", {input, index}, expected);
+ }
+ }
+}
+
+TEST(TestScalarNested, ListElementInvalid) {
+ auto input_array = ArrayFromJSON(list(float32()), "[[0.1, 1.1], [0.2, 1.2]]");
+ auto input_scalar = ScalarFromJSON(list(float32()), "[0.1, 0.2]");
+
+ // invalid index: null
+ auto index = ScalarFromJSON(int32(), "null");
+ EXPECT_THAT(CallFunction("list_element", {input_array, index}),
+ Raises(StatusCode::Invalid));
+ EXPECT_THAT(CallFunction("list_element", {input_scalar, index}),
+ Raises(StatusCode::Invalid));
+
+ // invalid index: < 0
+ index = ScalarFromJSON(int32(), "-1");
+ EXPECT_THAT(CallFunction("list_element", {input_array, index}),
+ Raises(StatusCode::Invalid));
+ EXPECT_THAT(CallFunction("list_element", {input_scalar, index}),
+ Raises(StatusCode::Invalid));
+
+ // invalid index: >= list.length
+ index = ScalarFromJSON(int32(), "2");
+ EXPECT_THAT(CallFunction("list_element", {input_array, index}),
+ Raises(StatusCode::Invalid));
+ EXPECT_THAT(CallFunction("list_element", {input_scalar, index}),
+ Raises(StatusCode::Invalid));
+
+ // invalid input
+ input_array = ArrayFromJSON(list(float32()), "[[41, 6, 93], [], [2]]");
+ input_scalar = ScalarFromJSON(list(float32()), "[]");
+ index = ScalarFromJSON(int32(), "0");
+ EXPECT_THAT(CallFunction("list_element", {input_array, index}),
+ Raises(StatusCode::Invalid));
+ EXPECT_THAT(CallFunction("list_element", {input_scalar, index}),
+ Raises(StatusCode::Invalid));
+}
+
+struct {
+ Result<Datum> operator()(std::vector<Datum> args) {
+ return CallFunction("make_struct", args);
+ }
+
+ template <typename... Options>
+ Result<Datum> operator()(std::vector<Datum> args, std::vector<std::string> field_names,
+ Options... options) {
+ MakeStructOptions opts{field_names, options...};
+ return CallFunction("make_struct", args, &opts);
+ }
+} MakeStruct;
+
+TEST(MakeStruct, Scalar) {
+ auto i32 = MakeScalar(1);
+ auto f64 = MakeScalar(2.5);
+ auto str = MakeScalar("yo");
+
+ EXPECT_THAT(MakeStruct({i32, f64, str}, {"i", "f", "s"}),
+ ResultWith(Datum(*StructScalar::Make({i32, f64, str}, {"i", "f", "s"}))));
+
+ // Names default to field_index
+ EXPECT_THAT(MakeStruct({i32, f64, str}),
+ ResultWith(Datum(*StructScalar::Make({i32, f64, str}, {"0", "1", "2"}))));
+
+ // No field names or input values is fine
+ EXPECT_THAT(MakeStruct({}), ResultWith(Datum(*StructScalar::Make({}, {}))));
+
+ // Three field names but one input value
+ EXPECT_THAT(MakeStruct({str}, {"i", "f", "s"}), Raises(StatusCode::Invalid));
+}
+
+TEST(MakeStruct, Array) {
+ std::vector<std::string> field_names{"i", "s"};
+
+ auto i32 = ArrayFromJSON(int32(), "[42, 13, 7]");
+ auto str = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])");
+
+ EXPECT_THAT(MakeStruct({i32, str}, {"i", "s"}),
+ ResultWith(Datum(*StructArray::Make({i32, str}, field_names))));
+
+ // Scalars are broadcast to the length of the arrays
+ EXPECT_THAT(MakeStruct({i32, MakeScalar("aa")}, {"i", "s"}),
+ ResultWith(Datum(*StructArray::Make({i32, str}, field_names))));
+
+ // Array length mismatch
+ EXPECT_THAT(MakeStruct({i32->Slice(1), str}, field_names), Raises(StatusCode::Invalid));
+}
+
+TEST(MakeStruct, NullableMetadataPassedThru) {
+ auto i32 = ArrayFromJSON(int32(), "[42, 13, 7]");
+ auto str = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])");
+
+ std::vector<std::string> field_names{"i", "s"};
+ std::vector<bool> nullability{true, false};
+ std::vector<std::shared_ptr<const KeyValueMetadata>> metadata = {
+ key_value_metadata({"a", "b"}, {"ALPHA", "BRAVO"}), nullptr};
+
+ ASSERT_OK_AND_ASSIGN(auto proj,
+ MakeStruct({i32, str}, field_names, nullability, metadata));
+
+ AssertTypeEqual(*proj.type(), StructType({
+ field("i", int32(), /*nullable=*/true, metadata[0]),
+ field("s", utf8(), /*nullable=*/false, nullptr),
+ }));
+
+ // error: projecting an array containing nulls with nullable=false
+ EXPECT_THAT(MakeStruct({i32, ArrayFromJSON(utf8(), R"(["aa", null, "aa"])")},
+ field_names, nullability, metadata),
+ Raises(StatusCode::Invalid));
+}
+
+TEST(MakeStruct, ChunkedArray) {
+ std::vector<std::string> field_names{"i", "s"};
+
+ auto i32_0 = ArrayFromJSON(int32(), "[42, 13, 7]");
+ auto i32_1 = ArrayFromJSON(int32(), "[]");
+ auto i32_2 = ArrayFromJSON(int32(), "[32, 0]");
+
+ auto str_0 = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])");
+ auto str_1 = ArrayFromJSON(utf8(), "[]");
+ auto str_2 = ArrayFromJSON(utf8(), R"(["aa", "aa"])");
+
+ ASSERT_OK_AND_ASSIGN(auto i32, ChunkedArray::Make({i32_0, i32_1, i32_2}));
+ ASSERT_OK_AND_ASSIGN(auto str, ChunkedArray::Make({str_0, str_1, str_2}));
+
+ ASSERT_OK_AND_ASSIGN(auto expected_0, StructArray::Make({i32_0, str_0}, field_names));
+ ASSERT_OK_AND_ASSIGN(auto expected_1, StructArray::Make({i32_1, str_1}, field_names));
+ ASSERT_OK_AND_ASSIGN(auto expected_2, StructArray::Make({i32_2, str_2}, field_names));
+ ASSERT_OK_AND_ASSIGN(Datum expected,
+ ChunkedArray::Make({expected_0, expected_1, expected_2}));
+
+ ASSERT_OK_AND_EQ(expected, MakeStruct({i32, str}, field_names));
+
+ // Scalars are broadcast to the length of the arrays
+ ASSERT_OK_AND_EQ(expected, MakeStruct({i32, MakeScalar("aa")}, field_names));
+
+ // Array length mismatch
+ ASSERT_RAISES(Invalid, MakeStruct({i32->Slice(1), str}, field_names));
+}
+
+TEST(MakeStruct, ChunkedArrayDifferentChunking) {
+ std::vector<std::string> field_names{"i", "s"};
+
+ auto i32_0 = ArrayFromJSON(int32(), "[42, 13, 7]");
+ auto i32_1 = ArrayFromJSON(int32(), "[]");
+ auto i32_2 = ArrayFromJSON(int32(), "[32, 0]");
+
+ auto str_0 = ArrayFromJSON(utf8(), R"(["aa", "aa"])");
+ auto str_1 = ArrayFromJSON(utf8(), R"(["aa"])");
+ auto str_2 = ArrayFromJSON(utf8(), R"([])");
+ auto str_3 = ArrayFromJSON(utf8(), R"(["aa", "aa"])");
+
+ ASSERT_OK_AND_ASSIGN(auto i32, ChunkedArray::Make({i32_0, i32_1, i32_2}));
+ ASSERT_OK_AND_ASSIGN(auto str, ChunkedArray::Make({str_0, str_1, str_2, str_3}));
+
+ std::vector<ArrayVector> expected_rechunked =
+ ::arrow::internal::RechunkArraysConsistently({i32->chunks(), str->chunks()});
+ ASSERT_EQ(expected_rechunked[0].size(), expected_rechunked[1].size());
+
+ ArrayVector expected_chunks(expected_rechunked[0].size());
+ for (size_t i = 0; i < expected_chunks.size(); ++i) {
+ ASSERT_OK_AND_ASSIGN(expected_chunks[i], StructArray::Make({expected_rechunked[0][i],
+ expected_rechunked[1][i]},
+ field_names));
+ }
+
+ ASSERT_OK_AND_ASSIGN(Datum expected, ChunkedArray::Make(expected_chunks));
+
+ ASSERT_OK_AND_EQ(expected, MakeStruct({i32, str}, field_names));
+
+ // Scalars are broadcast to the length of the arrays
+ ASSERT_OK_AND_EQ(expected, MakeStruct({i32, MakeScalar("aa")}, field_names));
+
+ // Array length mismatch
+ ASSERT_RAISES(Invalid, MakeStruct({i32->Slice(1), str}, field_names));
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
new file mode 100644
index 000000000..96d8ba23c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
@@ -0,0 +1,532 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_writer.h"
+#include "arrow/util/hashing.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::HashTraits;
+
+namespace compute {
+namespace internal {
+namespace {
+
+template <typename Type>
+struct SetLookupState : public KernelState {
+ explicit SetLookupState(MemoryPool* pool) : lookup_table(pool, 0) {}
+
+ Status Init(const SetLookupOptions& options) {
+ if (options.value_set.kind() == Datum::ARRAY) {
+ const ArrayData& value_set = *options.value_set.array();
+ memo_index_to_value_index.reserve(value_set.length);
+ RETURN_NOT_OK(AddArrayValueSet(options, *options.value_set.array()));
+ } else if (options.value_set.kind() == Datum::CHUNKED_ARRAY) {
+ const ChunkedArray& value_set = *options.value_set.chunked_array();
+ memo_index_to_value_index.reserve(value_set.length());
+ int64_t offset = 0;
+ for (const std::shared_ptr<Array>& chunk : value_set.chunks()) {
+ RETURN_NOT_OK(AddArrayValueSet(options, *chunk->data(), offset));
+ offset += chunk->length();
+ }
+ } else {
+ return Status::Invalid("value_set should be an array or chunked array");
+ }
+ if (!options.skip_nulls && lookup_table.GetNull() >= 0) {
+ null_index = memo_index_to_value_index[lookup_table.GetNull()];
+ }
+ return Status::OK();
+ }
+
+ Status AddArrayValueSet(const SetLookupOptions& options, const ArrayData& data,
+ int64_t start_index = 0) {
+ using T = typename GetViewType<Type>::T;
+ int32_t index = static_cast<int32_t>(start_index);
+ auto visit_valid = [&](T v) {
+ const auto memo_size = static_cast<int32_t>(memo_index_to_value_index.size());
+ int32_t unused_memo_index;
+ auto on_found = [&](int32_t memo_index) { DCHECK_LT(memo_index, memo_size); };
+ auto on_not_found = [&](int32_t memo_index) {
+ DCHECK_EQ(memo_index, memo_size);
+ memo_index_to_value_index.push_back(index);
+ };
+ RETURN_NOT_OK(lookup_table.GetOrInsert(
+ v, std::move(on_found), std::move(on_not_found), &unused_memo_index));
+ ++index;
+ return Status::OK();
+ };
+ auto visit_null = [&]() {
+ const auto memo_size = static_cast<int32_t>(memo_index_to_value_index.size());
+ auto on_found = [&](int32_t memo_index) { DCHECK_LT(memo_index, memo_size); };
+ auto on_not_found = [&](int32_t memo_index) {
+ DCHECK_EQ(memo_index, memo_size);
+ memo_index_to_value_index.push_back(index);
+ };
+ lookup_table.GetOrInsertNull(std::move(on_found), std::move(on_not_found));
+ ++index;
+ return Status::OK();
+ };
+
+ return VisitArrayDataInline<Type>(data, visit_valid, visit_null);
+ }
+
+ using MemoTable = typename HashTraits<Type>::MemoTableType;
+ MemoTable lookup_table;
+ // When there are duplicates in value_set, the MemoTable indices must
+ // be mapped back to indices in the value_set.
+ std::vector<int32_t> memo_index_to_value_index;
+ int32_t null_index = -1;
+};
+
+template <>
+struct SetLookupState<NullType> : public KernelState {
+ explicit SetLookupState(MemoryPool*) {}
+
+ Status Init(const SetLookupOptions& options) {
+ value_set_has_null = (options.value_set.length() > 0) && !options.skip_nulls;
+ return Status::OK();
+ }
+
+ bool value_set_has_null;
+};
+
+// TODO: Put this concept somewhere reusable
+template <int width>
+struct UnsignedIntType;
+
+template <>
+struct UnsignedIntType<1> {
+ using Type = UInt8Type;
+};
+
+template <>
+struct UnsignedIntType<2> {
+ using Type = UInt16Type;
+};
+
+template <>
+struct UnsignedIntType<4> {
+ using Type = UInt32Type;
+};
+
+template <>
+struct UnsignedIntType<8> {
+ using Type = UInt64Type;
+};
+
+// Constructing the type requires a type parameter
+struct InitStateVisitor {
+ KernelContext* ctx;
+ SetLookupOptions options;
+ const std::shared_ptr<DataType>& arg_type;
+ std::unique_ptr<KernelState> result;
+
+ InitStateVisitor(KernelContext* ctx, const KernelInitArgs& args)
+ : ctx(ctx),
+ options(*checked_cast<const SetLookupOptions*>(args.options)),
+ arg_type(args.inputs[0].type) {}
+
+ template <typename Type>
+ Status Init() {
+ using StateType = SetLookupState<Type>;
+ result.reset(new StateType(ctx->exec_context()->memory_pool()));
+ return static_cast<StateType*>(result.get())->Init(options);
+ }
+
+ Status Visit(const DataType&) { return Init<NullType>(); }
+
+ template <typename Type>
+ enable_if_boolean<Type, Status> Visit(const Type&) {
+ return Init<BooleanType>();
+ }
+
+ template <typename Type>
+ enable_if_t<has_c_type<Type>::value && !is_boolean_type<Type>::value &&
+ !std::is_same<Type, MonthDayNanoIntervalType>::value,
+ Status>
+ Visit(const Type&) {
+ return Init<typename UnsignedIntType<sizeof(typename Type::c_type)>::Type>();
+ }
+
+ template <typename Type>
+ enable_if_base_binary<Type, Status> Visit(const Type&) {
+ return Init<typename Type::PhysicalType>();
+ }
+
+ // Handle Decimal128Type, FixedSizeBinaryType
+ Status Visit(const FixedSizeBinaryType& type) { return Init<FixedSizeBinaryType>(); }
+
+ Status Visit(const MonthDayNanoIntervalType& type) {
+ return Init<MonthDayNanoIntervalType>();
+ }
+
+ Result<std::unique_ptr<KernelState>> GetResult() {
+ if (!options.value_set.type()->Equals(arg_type)) {
+ ARROW_ASSIGN_OR_RAISE(
+ options.value_set,
+ Cast(options.value_set, CastOptions::Safe(arg_type), ctx->exec_context()));
+ }
+
+ RETURN_NOT_OK(VisitTypeInline(*arg_type, this));
+ return std::move(result);
+ }
+};
+
+Result<std::unique_ptr<KernelState>> InitSetLookup(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ if (args.options == nullptr) {
+ return Status::Invalid(
+ "Attempted to call a set lookup function without SetLookupOptions");
+ }
+
+ return InitStateVisitor{ctx, args}.GetResult();
+}
+
+struct IndexInVisitor {
+ KernelContext* ctx;
+ const ArrayData& data;
+ Datum* out;
+ Int32Builder builder;
+
+ IndexInVisitor(KernelContext* ctx, const ArrayData& data, Datum* out)
+ : ctx(ctx), data(data), out(out), builder(ctx->exec_context()->memory_pool()) {}
+
+ Status Visit(const DataType& type) {
+ DCHECK_EQ(type.id(), Type::NA);
+ const auto& state = checked_cast<const SetLookupState<NullType>&>(*ctx->state());
+ if (data.length != 0) {
+ // skip_nulls is honored for consistency with other types
+ if (state.value_set_has_null) {
+ RETURN_NOT_OK(this->builder.Reserve(data.length));
+ for (int64_t i = 0; i < data.length; ++i) {
+ this->builder.UnsafeAppend(0);
+ }
+ } else {
+ RETURN_NOT_OK(this->builder.AppendNulls(data.length));
+ }
+ }
+ return Status::OK();
+ }
+
+ template <typename Type>
+ Status ProcessIndexIn() {
+ using T = typename GetViewType<Type>::T;
+
+ const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state());
+
+ RETURN_NOT_OK(this->builder.Reserve(data.length));
+ VisitArrayDataInline<Type>(
+ data,
+ [&](T v) {
+ int32_t index = state.lookup_table.Get(v);
+ if (index != -1) {
+ // matching needle; output index from value_set
+ this->builder.UnsafeAppend(state.memo_index_to_value_index[index]);
+ } else {
+ // no matching needle; output null
+ this->builder.UnsafeAppendNull();
+ }
+ },
+ [&]() {
+ if (state.null_index != -1) {
+ // value_set included null
+ this->builder.UnsafeAppend(state.null_index);
+ } else {
+ // value_set does not include null; output null
+ this->builder.UnsafeAppendNull();
+ }
+ });
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_boolean<Type, Status> Visit(const Type&) {
+ return ProcessIndexIn<BooleanType>();
+ }
+
+ template <typename Type>
+ enable_if_t<has_c_type<Type>::value && !is_boolean_type<Type>::value &&
+ !std::is_same<Type, MonthDayNanoIntervalType>::value,
+ Status>
+ Visit(const Type&) {
+ return ProcessIndexIn<
+ typename UnsignedIntType<sizeof(typename Type::c_type)>::Type>();
+ }
+
+ template <typename Type>
+ enable_if_base_binary<Type, Status> Visit(const Type&) {
+ return ProcessIndexIn<typename Type::PhysicalType>();
+ }
+
+ // Handle Decimal128Type, FixedSizeBinaryType
+ Status Visit(const FixedSizeBinaryType& type) {
+ return ProcessIndexIn<FixedSizeBinaryType>();
+ }
+
+ Status Visit(const MonthDayNanoIntervalType& type) {
+ return ProcessIndexIn<MonthDayNanoIntervalType>();
+ }
+
+ Status Execute() {
+ Status s = VisitTypeInline(*data.type, this);
+ if (!s.ok()) {
+ return s;
+ }
+ std::shared_ptr<ArrayData> out_data;
+ RETURN_NOT_OK(this->builder.FinishInternal(&out_data));
+ out->value = std::move(out_data);
+ return Status::OK();
+ }
+};
+
+Status ExecIndexIn(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return IndexInVisitor(ctx, *batch[0].array(), out).Execute();
+}
+
+// ----------------------------------------------------------------------
+
+// IsIn writes the results into a preallocated boolean data bitmap
+struct IsInVisitor {
+ KernelContext* ctx;
+ const ArrayData& data;
+ Datum* out;
+
+ IsInVisitor(KernelContext* ctx, const ArrayData& data, Datum* out)
+ : ctx(ctx), data(data), out(out) {}
+
+ Status Visit(const DataType& type) {
+ DCHECK_EQ(type.id(), Type::NA);
+ const auto& state = checked_cast<const SetLookupState<NullType>&>(*ctx->state());
+ ArrayData* output = out->mutable_array();
+ // skip_nulls is honored for consistency with other types
+ BitUtil::SetBitsTo(output->buffers[1]->mutable_data(), output->offset, output->length,
+ state.value_set_has_null);
+ return Status::OK();
+ }
+
+ template <typename Type>
+ Status ProcessIsIn() {
+ using T = typename GetViewType<Type>::T;
+ const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state());
+ ArrayData* output = out->mutable_array();
+
+ FirstTimeBitmapWriter writer(output->buffers[1]->mutable_data(), output->offset,
+ output->length);
+
+ VisitArrayDataInline<Type>(
+ this->data,
+ [&](T v) {
+ if (state.lookup_table.Get(v) != -1) {
+ writer.Set();
+ } else {
+ writer.Clear();
+ }
+ writer.Next();
+ },
+ [&]() {
+ if (state.null_index != -1) {
+ writer.Set();
+ } else {
+ writer.Clear();
+ }
+ writer.Next();
+ });
+ writer.Finish();
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_boolean<Type, Status> Visit(const Type&) {
+ return ProcessIsIn<BooleanType>();
+ }
+
+ template <typename Type>
+ enable_if_t<has_c_type<Type>::value && !is_boolean_type<Type>::value &&
+ !std::is_same<Type, MonthDayNanoIntervalType>::value,
+ Status>
+ Visit(const Type&) {
+ return ProcessIsIn<typename UnsignedIntType<sizeof(typename Type::c_type)>::Type>();
+ }
+
+ template <typename Type>
+ enable_if_base_binary<Type, Status> Visit(const Type&) {
+ return ProcessIsIn<typename Type::PhysicalType>();
+ }
+
+ // Handle Decimal128Type, FixedSizeBinaryType
+ Status Visit(const FixedSizeBinaryType& type) {
+ return ProcessIsIn<FixedSizeBinaryType>();
+ }
+
+ Status Visit(const MonthDayNanoIntervalType& type) {
+ return ProcessIsIn<MonthDayNanoIntervalType>();
+ }
+
+ Status Execute() { return VisitTypeInline(*data.type, this); }
+};
+
+Status ExecIsIn(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return IsInVisitor(ctx, *batch[0].array(), out).Execute();
+}
+
+// Unary set lookup kernels available for the following input types
+//
+// * Null type
+// * Boolean
+// * Numeric
+// * Simple temporal types (date, time, timestamp)
+// * Base binary types
+// * Decimal
+
+void AddBasicSetLookupKernels(ScalarKernel kernel,
+ const std::shared_ptr<DataType>& out_ty,
+ ScalarFunction* func) {
+ auto AddKernels = [&](const std::vector<std::shared_ptr<DataType>>& types) {
+ for (const std::shared_ptr<DataType>& ty : types) {
+ kernel.signature = KernelSignature::Make({ty}, out_ty);
+ DCHECK_OK(func->AddKernel(kernel));
+ }
+ };
+
+ AddKernels(BaseBinaryTypes());
+ AddKernels(NumericTypes());
+ AddKernels(TemporalTypes());
+ AddKernels({month_day_nano_interval()});
+
+ std::vector<Type::type> other_types = {Type::BOOL, Type::DECIMAL,
+ Type::FIXED_SIZE_BINARY};
+ for (auto ty : other_types) {
+ kernel.signature = KernelSignature::Make({InputType::Array(ty)}, out_ty);
+ DCHECK_OK(func->AddKernel(kernel));
+ }
+}
+
+// Enables calling is_in with CallFunction as though it were binary.
+class IsInMetaBinary : public MetaFunction {
+ public:
+ IsInMetaBinary()
+ : MetaFunction("is_in_meta_binary", Arity::Binary(), /*doc=*/nullptr) {}
+
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options,
+ ExecContext* ctx) const override {
+ if (options != nullptr) {
+ return Status::Invalid("Unexpected options for 'is_in_meta_binary' function");
+ }
+ return IsIn(args[0], args[1], ctx);
+ }
+};
+
+// Enables calling index_in with CallFunction as though it were binary.
+class IndexInMetaBinary : public MetaFunction {
+ public:
+ IndexInMetaBinary()
+ : MetaFunction("index_in_meta_binary", Arity::Binary(), /*doc=*/nullptr) {}
+
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options,
+ ExecContext* ctx) const override {
+ if (options != nullptr) {
+ return Status::Invalid("Unexpected options for 'index_in_meta_binary' function");
+ }
+ return IndexIn(args[0], args[1], ctx);
+ }
+};
+
+struct SetLookupFunction : ScalarFunction {
+ using ScalarFunction::ScalarFunction;
+
+ Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+ EnsureDictionaryDecoded(values);
+ return DispatchExact(*values);
+ }
+};
+
+const FunctionDoc is_in_doc{
+ "Find each element in a set of values",
+ ("For each element in `values`, return true if it is found in a given\n"
+ "set of values, false otherwise.\n"
+ "The set of values to look for must be given in SetLookupOptions.\n"
+ "By default, nulls are matched against the value set, this can be\n"
+ "changed in SetLookupOptions."),
+ {"values"},
+ "SetLookupOptions"};
+
+const FunctionDoc index_in_doc{
+ "Return index of each element in a set of values",
+ ("For each element in `values`, return its index in a given set of\n"
+ "values, or null if it is not found there.\n"
+ "The set of values to look for must be given in SetLookupOptions.\n"
+ "By default, nulls are matched against the value set, this can be\n"
+ "changed in SetLookupOptions."),
+ {"values"},
+ "SetLookupOptions"};
+
+} // namespace
+
+void RegisterScalarSetLookup(FunctionRegistry* registry) {
+ // IsIn writes its boolean output into preallocated memory
+ {
+ ScalarKernel isin_base;
+ isin_base.init = InitSetLookup;
+ isin_base.exec =
+ TrivialScalarUnaryAsArraysExec(ExecIsIn, NullHandling::OUTPUT_NOT_NULL);
+ isin_base.null_handling = NullHandling::OUTPUT_NOT_NULL;
+ auto is_in = std::make_shared<SetLookupFunction>("is_in", Arity::Unary(), &is_in_doc);
+
+ AddBasicSetLookupKernels(isin_base, /*output_type=*/boolean(), is_in.get());
+
+ isin_base.signature = KernelSignature::Make({null()}, boolean());
+ DCHECK_OK(is_in->AddKernel(isin_base));
+ DCHECK_OK(registry->AddFunction(is_in));
+
+ DCHECK_OK(registry->AddFunction(std::make_shared<IsInMetaBinary>()));
+ }
+
+ // IndexIn uses Int32Builder and so is responsible for all its own allocation
+ {
+ ScalarKernel index_in_base;
+ index_in_base.init = InitSetLookup;
+ index_in_base.exec = TrivialScalarUnaryAsArraysExec(
+ ExecIndexIn, NullHandling::COMPUTED_NO_PREALLOCATE);
+ index_in_base.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ index_in_base.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ auto index_in =
+ std::make_shared<SetLookupFunction>("index_in", Arity::Unary(), &index_in_doc);
+
+ AddBasicSetLookupKernels(index_in_base, /*output_type=*/int32(), index_in.get());
+
+ index_in_base.signature = KernelSignature::Make({null()}, int32());
+ DCHECK_OK(index_in->AddKernel(index_in_base));
+ DCHECK_OK(registry->AddFunction(index_in));
+
+ DCHECK_OK(registry->AddFunction(std::make_shared<IndexInMetaBinary>()));
+ }
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup_benchmark.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup_benchmark.cc
new file mode 100644
index 000000000..02f6af4be
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup_benchmark.cc
@@ -0,0 +1,143 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/benchmark_util.h"
+
+namespace arrow {
+namespace compute {
+
+constexpr auto kSeed = 0x94378165;
+
+static void SetLookupBenchmarkString(benchmark::State& state,
+ const std::string& func_name,
+ const int64_t value_set_length) {
+ // As the set lookup functions don't support duplicate values in the value_set,
+ // we need to choose random generation parameters that minimize the risk of
+ // duplicates (including nulls).
+ const int64_t array_length = 1 << 18;
+ const int32_t value_min_size = (value_set_length < 64) ? 2 : 10;
+ const int32_t value_max_size = 32;
+ const double null_probability = 0.2 / value_set_length;
+ random::RandomArrayGenerator rng(kSeed);
+
+ auto values =
+ rng.String(array_length, value_min_size, value_max_size, null_probability);
+ auto value_set =
+ rng.String(value_set_length, value_min_size, value_max_size, null_probability);
+ ABORT_NOT_OK(CallFunction(func_name, {values, value_set}));
+ for (auto _ : state) {
+ ABORT_NOT_OK(CallFunction(func_name, {values, value_set}));
+ }
+ state.SetItemsProcessed(state.iterations() * array_length);
+ state.SetBytesProcessed(state.iterations() * values->data()->buffers[2]->size());
+}
+
+template <typename Type>
+static void SetLookupBenchmarkNumeric(benchmark::State& state,
+ const std::string& func_name,
+ const int64_t value_set_length) {
+ const int64_t array_length = 1 << 18;
+ const int64_t value_min = 0;
+ const int64_t value_max = std::numeric_limits<typename Type::c_type>::max();
+ const double null_probability = 0.1 / value_set_length;
+ random::RandomArrayGenerator rng(kSeed);
+
+ auto values = rng.Numeric<Type>(array_length, value_min, value_max, null_probability);
+ auto value_set =
+ rng.Numeric<Type>(value_set_length, value_min, value_max, null_probability);
+ ABORT_NOT_OK(CallFunction(func_name, {values, value_set}));
+ for (auto _ : state) {
+ ABORT_NOT_OK(CallFunction(func_name, {values, value_set}));
+ }
+ state.SetItemsProcessed(state.iterations() * array_length);
+ state.SetBytesProcessed(state.iterations() * values->data()->buffers[1]->size());
+}
+
+static void IndexInStringSmallSet(benchmark::State& state) {
+ SetLookupBenchmarkString(state, "index_in_meta_binary", state.range(0));
+}
+
+static void IsInStringSmallSet(benchmark::State& state) {
+ SetLookupBenchmarkString(state, "is_in_meta_binary", state.range(0));
+}
+
+static void IndexInStringLargeSet(benchmark::State& state) {
+ SetLookupBenchmarkString(state, "index_in_meta_binary", 1 << 10);
+}
+
+static void IsInStringLargeSet(benchmark::State& state) {
+ SetLookupBenchmarkString(state, "is_in_meta_binary", 1 << 10);
+}
+
+static void IndexInInt8SmallSet(benchmark::State& state) {
+ SetLookupBenchmarkNumeric<Int8Type>(state, "index_in_meta_binary", state.range(0));
+}
+
+static void IndexInInt16SmallSet(benchmark::State& state) {
+ SetLookupBenchmarkNumeric<Int16Type>(state, "index_in_meta_binary", state.range(0));
+}
+
+static void IndexInInt32SmallSet(benchmark::State& state) {
+ SetLookupBenchmarkNumeric<Int32Type>(state, "index_in_meta_binary", state.range(0));
+}
+
+static void IndexInInt64SmallSet(benchmark::State& state) {
+ SetLookupBenchmarkNumeric<Int64Type>(state, "index_in_meta_binary", state.range(0));
+}
+
+static void IsInInt8SmallSet(benchmark::State& state) {
+ SetLookupBenchmarkNumeric<Int8Type>(state, "is_in_meta_binary", state.range(0));
+}
+
+static void IsInInt16SmallSet(benchmark::State& state) {
+ SetLookupBenchmarkNumeric<Int16Type>(state, "is_in_meta_binary", state.range(0));
+}
+
+static void IsInInt32SmallSet(benchmark::State& state) {
+ SetLookupBenchmarkNumeric<Int32Type>(state, "is_in_meta_binary", state.range(0));
+}
+
+static void IsInInt64SmallSet(benchmark::State& state) {
+ SetLookupBenchmarkNumeric<Int64Type>(state, "is_in_meta_binary", state.range(0));
+}
+
+BENCHMARK(IndexInStringSmallSet)->RangeMultiplier(4)->Range(2, 64);
+BENCHMARK(IsInStringSmallSet)->RangeMultiplier(4)->Range(2, 64);
+
+BENCHMARK(IndexInStringLargeSet);
+BENCHMARK(IsInStringLargeSet);
+
+// XXX For Int8, the value_set length has to be capped at a lower value
+// in order to avoid duplicates.
+BENCHMARK(IndexInInt8SmallSet)->RangeMultiplier(4)->Range(2, 8);
+BENCHMARK(IndexInInt16SmallSet)->RangeMultiplier(4)->Range(2, 64);
+BENCHMARK(IndexInInt32SmallSet)->RangeMultiplier(4)->Range(2, 64);
+BENCHMARK(IndexInInt64SmallSet)->RangeMultiplier(4)->Range(2, 64);
+BENCHMARK(IsInInt8SmallSet)->RangeMultiplier(4)->Range(2, 8);
+BENCHMARK(IsInInt16SmallSet)->RangeMultiplier(4)->Range(2, 64);
+BENCHMARK(IsInInt32SmallSet)->RangeMultiplier(4)->Range(2, 64);
+BENCHMARK(IsInInt64SmallSet)->RangeMultiplier(4)->Range(2, 64);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
new file mode 100644
index 000000000..284c8ccde
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
@@ -0,0 +1,992 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdio>
+#include <functional>
+#include <iosfwd>
+#include <locale>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/chunked_array.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_compat.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+namespace compute {
+
+// ----------------------------------------------------------------------
+// IsIn tests
+
+void CheckIsIn(const std::shared_ptr<DataType>& type, const std::string& input_json,
+ const std::string& value_set_json, const std::string& expected_json,
+ bool skip_nulls = false) {
+ auto input = ArrayFromJSON(type, input_json);
+ auto value_set = ArrayFromJSON(type, value_set_json);
+ auto expected = ArrayFromJSON(boolean(), expected_json);
+
+ ASSERT_OK_AND_ASSIGN(Datum actual_datum,
+ IsIn(input, SetLookupOptions(value_set, skip_nulls)));
+ std::shared_ptr<Array> actual = actual_datum.make_array();
+ ValidateOutput(actual_datum);
+ AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+}
+
+void CheckIsInChunked(const std::shared_ptr<ChunkedArray>& input,
+ const std::shared_ptr<ChunkedArray>& value_set,
+ const std::shared_ptr<ChunkedArray>& expected,
+ bool skip_nulls = false) {
+ ASSERT_OK_AND_ASSIGN(Datum actual_datum,
+ IsIn(input, SetLookupOptions(value_set, skip_nulls)));
+ auto actual = actual_datum.chunked_array();
+ ValidateOutput(actual_datum);
+ AssertChunkedEqual(*expected, *actual);
+}
+
+void CheckIsInDictionary(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<DataType>& index_type,
+ const std::string& input_dictionary_json,
+ const std::string& input_index_json,
+ const std::string& value_set_json,
+ const std::string& expected_json, bool skip_nulls = false) {
+ auto dict_type = dictionary(index_type, type);
+ auto indices = ArrayFromJSON(index_type, input_index_json);
+ auto dict = ArrayFromJSON(type, input_dictionary_json);
+
+ ASSERT_OK_AND_ASSIGN(auto input, DictionaryArray::FromArrays(dict_type, indices, dict));
+ auto value_set = ArrayFromJSON(type, value_set_json);
+ auto expected = ArrayFromJSON(boolean(), expected_json);
+
+ ASSERT_OK_AND_ASSIGN(Datum actual_datum,
+ IsIn(input, SetLookupOptions(value_set, skip_nulls)));
+ std::shared_ptr<Array> actual = actual_datum.make_array();
+ ValidateOutput(actual_datum);
+ AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+}
+
+class TestIsInKernel : public ::testing::Test {};
+
+TEST_F(TestIsInKernel, CallBinary) {
+ auto input = ArrayFromJSON(int8(), "[0, 1, 2, 3, 4, 5, 6, 7, 8]");
+ auto value_set = ArrayFromJSON(int8(), "[2, 3, 5, 7]");
+ ASSERT_RAISES(Invalid, CallFunction("is_in", {input, value_set}));
+
+ ASSERT_OK_AND_ASSIGN(Datum out, CallFunction("is_in_meta_binary", {input, value_set}));
+ auto expected = ArrayFromJSON(boolean(), ("[false, false, true, true, false,"
+ "true, false, true, false]"));
+ AssertArraysEqual(*expected, *out.make_array());
+}
+
+TEST_F(TestIsInKernel, ImplicitlyCastValueSet) {
+ auto input = ArrayFromJSON(int8(), "[0, 1, 2, 3, 4, 5, 6, 7, 8]");
+
+ SetLookupOptions opts{ArrayFromJSON(int32(), "[2, 3, 5, 7]")};
+ ASSERT_OK_AND_ASSIGN(Datum out, CallFunction("is_in", {input}, &opts));
+
+ auto expected = ArrayFromJSON(boolean(), ("[false, false, true, true, false,"
+ "true, false, true, false]"));
+ AssertArraysEqual(*expected, *out.make_array());
+
+ // fails; value_set cannot be cast to int8
+ opts = SetLookupOptions{ArrayFromJSON(float32(), "[2.5, 3.1, 5.0]")};
+ ASSERT_RAISES(Invalid, CallFunction("is_in", {input}, &opts));
+}
+
+template <typename Type>
+class TestIsInKernelPrimitive : public ::testing::Test {};
+
+template <typename Type>
+class TestIsInKernelBinary : public ::testing::Test {};
+
+using PrimitiveTypes = ::testing::Types<Int8Type, UInt8Type, Int16Type, UInt16Type,
+ Int32Type, UInt32Type, Int64Type, UInt64Type,
+ FloatType, DoubleType, Date32Type, Date64Type>;
+
+TYPED_TEST_SUITE(TestIsInKernelPrimitive, PrimitiveTypes);
+
+TYPED_TEST(TestIsInKernelPrimitive, IsIn) {
+ auto type = TypeTraits<TypeParam>::type_singleton();
+
+ // No Nulls
+ CheckIsIn(type, "[0, 1, 2, 3, 2]", "[2, 1]", "[false, true, true, false, true]");
+
+ // Nulls in left array
+ CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, 1]", "[false, true, true, false, true]",
+ /*skip_nulls=*/false);
+ CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, 1]", "[false, true, true, false, true]",
+ /*skip_nulls=*/true);
+
+ // Nulls in right array
+ CheckIsIn(type, "[0, 1, 2, 3, 2]", "[2, null, 1]", "[false, true, true, false, true]",
+ /*skip_nulls=*/false);
+ CheckIsIn(type, "[0, 1, 2, 3, 2]", "[2, null, 1]", "[false, true, true, false, true]",
+ /*skip_nulls=*/true);
+
+ // Nulls in both the arrays
+ CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, null, 1]", "[true, true, true, false, true]",
+ /*skip_nulls=*/false);
+ CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, null, 1]",
+ "[false, true, true, false, true]", /*skip_nulls=*/true);
+
+ // Duplicates in right array
+ CheckIsIn(type, "[null, 1, 2, 3, 2]", "[null, 2, 2, null, 1, 1]",
+ "[true, true, true, false, true]",
+ /*skip_nulls=*/false);
+ CheckIsIn(type, "[null, 1, 2, 3, 2]", "[null, 2, 2, null, 1, 1]",
+ "[false, true, true, false, true]", /*skip_nulls=*/true);
+
+ // Empty Arrays
+ CheckIsIn(type, "[]", "[]", "[]");
+}
+
+TEST_F(TestIsInKernel, NullType) {
+ auto type = null();
+
+ CheckIsIn(type, "[null, null, null]", "[null]", "[true, true, true]");
+ CheckIsIn(type, "[null, null, null]", "[]", "[false, false, false]");
+ CheckIsIn(type, "[]", "[]", "[]");
+
+ CheckIsIn(type, "[null, null]", "[null]", "[false, false]", /*skip_nulls=*/true);
+ CheckIsIn(type, "[null, null]", "[]", "[false, false]", /*skip_nulls=*/true);
+
+ // Duplicates in right array
+ CheckIsIn(type, "[null, null, null]", "[null, null]", "[true, true, true]");
+ CheckIsIn(type, "[null, null]", "[null, null]", "[false, false]", /*skip_nulls=*/true);
+}
+
+TEST_F(TestIsInKernel, TimeTimestamp) {
+ for (const auto& type :
+ {time32(TimeUnit::SECOND), time64(TimeUnit::NANO), timestamp(TimeUnit::MICRO)}) {
+ CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, null]",
+ "[true, true, false, true, true]", /*skip_nulls=*/false);
+ CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, null]",
+ "[true, false, false, true, true]", /*skip_nulls=*/true);
+
+ // Duplicates in right array
+ CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, 1, null, 2]",
+ "[true, true, false, true, true]", /*skip_nulls=*/false);
+ CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, 1, null, 2]",
+ "[true, false, false, true, true]", /*skip_nulls=*/true);
+ }
+}
+
+TEST_F(TestIsInKernel, Boolean) {
+ auto type = boolean();
+
+ CheckIsIn(type, "[true, false, null, true, false]", "[false]",
+ "[false, true, false, false, true]", /*skip_nulls=*/false);
+ CheckIsIn(type, "[true, false, null, true, false]", "[false]",
+ "[false, true, false, false, true]", /*skip_nulls=*/true);
+
+ CheckIsIn(type, "[true, false, null, true, false]", "[false, null]",
+ "[false, true, true, false, true]", /*skip_nulls=*/false);
+ CheckIsIn(type, "[true, false, null, true, false]", "[false, null]",
+ "[false, true, false, false, true]", /*skip_nulls=*/true);
+
+ // Duplicates in right array
+ CheckIsIn(type, "[true, false, null, true, false]", "[null, false, false, null]",
+ "[false, true, true, false, true]", /*skip_nulls=*/false);
+ CheckIsIn(type, "[true, false, null, true, false]", "[null, false, false, null]",
+ "[false, true, false, false, true]", /*skip_nulls=*/true);
+}
+
+TYPED_TEST_SUITE(TestIsInKernelBinary, BinaryArrowTypes);
+
+TYPED_TEST(TestIsInKernelBinary, Binary) {
+ auto type = TypeTraits<TypeParam>::type_singleton();
+
+ CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", ""])",
+ "[true, true, false, false, true]",
+ /*skip_nulls=*/false);
+ CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", ""])",
+ "[true, true, false, false, true]",
+ /*skip_nulls=*/true);
+
+ CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", "", null])",
+ "[true, true, false, true, true]",
+ /*skip_nulls=*/false);
+ CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", "", null])",
+ "[true, true, false, false, true]",
+ /*skip_nulls=*/true);
+
+ // Duplicates in right array
+ CheckIsIn(type, R"(["aaa", "", "cc", null, ""])",
+ R"([null, "aaa", "aaa", "", "", null])", "[true, true, false, true, true]",
+ /*skip_nulls=*/false);
+ CheckIsIn(type, R"(["aaa", "", "cc", null, ""])",
+ R"([null, "aaa", "aaa", "", "", null])", "[true, true, false, false, true]",
+ /*skip_nulls=*/true);
+}
+
+TEST_F(TestIsInKernel, FixedSizeBinary) {
+ auto type = fixed_size_binary(3);
+
+ CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb"])",
+ "[true, true, false, false, true]",
+ /*skip_nulls=*/false);
+ CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb"])",
+ "[true, true, false, false, true]",
+ /*skip_nulls=*/true);
+
+ CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])",
+ "[true, true, false, true, true]",
+ /*skip_nulls=*/false);
+ CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])",
+ "[true, true, false, false, true]",
+ /*skip_nulls=*/true);
+
+ // Duplicates in right array
+ CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])",
+ R"(["aaa", null, "aaa", "bbb", "bbb", null])",
+ "[true, true, false, true, true]",
+ /*skip_nulls=*/false);
+ CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])",
+ R"(["aaa", null, "aaa", "bbb", "bbb", null])",
+ "[true, true, false, false, true]",
+ /*skip_nulls=*/true);
+}
+
+TEST_F(TestIsInKernel, Decimal) {
+ auto type = decimal(3, 1);
+
+ CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9"])",
+ "[true, false, true, false, true]",
+ /*skip_nulls=*/false);
+ CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9"])",
+ "[true, false, true, false, true]",
+ /*skip_nulls=*/true);
+
+ CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
+ R"(["12.3", "78.9", null])", "[true, false, true, true, true]",
+ /*skip_nulls=*/false);
+ CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
+ R"(["12.3", "78.9", null])", "[true, false, true, false, true]",
+ /*skip_nulls=*/true);
+
+ // Duplicates in right array
+ CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
+ R"([null, "12.3", "12.3", "78.9", "78.9", null])",
+ "[true, false, true, true, true]",
+ /*skip_nulls=*/false);
+ CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
+ R"([null, "12.3", "12.3", "78.9", "78.9", null])",
+ "[true, false, true, false, true]",
+ /*skip_nulls=*/true);
+}
+
+TEST_F(TestIsInKernel, DictionaryArray) {
+ for (auto index_ty : all_dictionary_index_types()) {
+ CheckIsInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", "B", "C", "D"])",
+ /*input_index_json=*/"[1, 2, null, 0]",
+ /*value_set_json=*/R"(["A", "B", "C"])",
+ /*expected_json=*/"[true, true, false, true]",
+ /*skip_nulls=*/false);
+ CheckIsInDictionary(/*type=*/float32(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/"[4.1, -1.0, 42, 9.8]",
+ /*input_index_json=*/"[1, 2, null, 0]",
+ /*value_set_json=*/"[4.1, 42, -1.0]",
+ /*expected_json=*/"[true, true, false, true]",
+ /*skip_nulls=*/false);
+
+ // With nulls and skip_nulls=false
+ CheckIsInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", "B", "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "B", "A", null])",
+ /*expected_json=*/"[true, false, true, true, true]",
+ /*skip_nulls=*/false);
+ CheckIsInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", null, "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "B", "A", null])",
+ /*expected_json=*/"[true, false, true, true, true]",
+ /*skip_nulls=*/false);
+ CheckIsInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", null, "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "B", "A"])",
+ /*expected_json=*/"[false, false, false, true, false]",
+ /*skip_nulls=*/false);
+
+ // With nulls and skip_nulls=true
+ CheckIsInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", "B", "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "B", "A", null])",
+ /*expected_json=*/"[true, false, false, true, true]",
+ /*skip_nulls=*/true);
+ CheckIsInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", null, "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "B", "A", null])",
+ /*expected_json=*/"[false, false, false, true, false]",
+ /*skip_nulls=*/true);
+ CheckIsInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", null, "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "B", "A"])",
+ /*expected_json=*/"[false, false, false, true, false]",
+ /*skip_nulls=*/true);
+
+ // With duplicates in value_set
+ CheckIsInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", "B", "C", "D"])",
+ /*input_index_json=*/"[1, 2, null, 0]",
+ /*value_set_json=*/R"(["A", "A", "B", "A", "B", "C"])",
+ /*expected_json=*/"[true, true, false, true]",
+ /*skip_nulls=*/false);
+ CheckIsInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", "B", "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "C", "B", "A", null, null, "B"])",
+ /*expected_json=*/"[true, false, true, true, true]",
+ /*skip_nulls=*/false);
+ CheckIsInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", "B", "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "C", "B", "A", null, null, "B"])",
+ /*expected_json=*/"[true, false, false, true, true]",
+ /*skip_nulls=*/true);
+ }
+}
+
+TEST_F(TestIsInKernel, ChunkedArrayInvoke) {
+ auto input = ChunkedArrayFromJSON(
+ utf8(), {R"(["abc", "def", "", "abc", "jkl"])", R"(["def", null, "abc", "zzz"])"});
+ // No null in value_set
+ auto value_set = ChunkedArrayFromJSON(utf8(), {R"(["", "def"])", R"(["abc"])"});
+ auto expected = ChunkedArrayFromJSON(
+ boolean(), {"[true, true, true, true, false]", "[true, false, true, false]"});
+
+ CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/false);
+ CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/true);
+
+ value_set = ChunkedArrayFromJSON(utf8(), {R"(["", "def"])", R"([null])"});
+ expected = ChunkedArrayFromJSON(
+ boolean(), {"[false, true, true, false, false]", "[true, true, false, false]"});
+ CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/false);
+ expected = ChunkedArrayFromJSON(
+ boolean(), {"[false, true, true, false, false]", "[true, false, false, false]"});
+ CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/true);
+
+ // Duplicates in value_set
+ value_set =
+ ChunkedArrayFromJSON(utf8(), {R"(["", null, "", "def"])", R"(["def", null])"});
+ expected = ChunkedArrayFromJSON(
+ boolean(), {"[false, true, true, false, false]", "[true, true, false, false]"});
+ CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/false);
+ expected = ChunkedArrayFromJSON(
+ boolean(), {"[false, true, true, false, false]", "[true, false, false, false]"});
+ CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/true);
+}
+
+// ----------------------------------------------------------------------
+// IndexIn tests
+
+class TestIndexInKernel : public ::testing::Test {
+ public:
+ void CheckIndexIn(const std::shared_ptr<DataType>& type, const std::string& input_json,
+ const std::string& value_set_json, const std::string& expected_json,
+ bool skip_nulls = false) {
+ std::shared_ptr<Array> input = ArrayFromJSON(type, input_json);
+ std::shared_ptr<Array> value_set = ArrayFromJSON(type, value_set_json);
+ std::shared_ptr<Array> expected = ArrayFromJSON(int32(), expected_json);
+
+ SetLookupOptions options(value_set, skip_nulls);
+ ASSERT_OK_AND_ASSIGN(Datum actual_datum, IndexIn(input, options));
+ std::shared_ptr<Array> actual = actual_datum.make_array();
+ ValidateOutput(actual_datum);
+ AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+ }
+
+ void CheckIndexInChunked(const std::shared_ptr<ChunkedArray>& input,
+ const std::shared_ptr<ChunkedArray>& value_set,
+ const std::shared_ptr<ChunkedArray>& expected,
+ bool skip_nulls) {
+ ASSERT_OK_AND_ASSIGN(Datum actual,
+ IndexIn(input, SetLookupOptions(value_set, skip_nulls)));
+ ASSERT_EQ(Datum::CHUNKED_ARRAY, actual.kind());
+ ValidateOutput(actual);
+ AssertChunkedEqual(*expected, *actual.chunked_array());
+ }
+
+ void CheckIndexInDictionary(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<DataType>& index_type,
+ const std::string& input_dictionary_json,
+ const std::string& input_index_json,
+ const std::string& value_set_json,
+ const std::string& expected_json, bool skip_nulls = false) {
+ auto dict_type = dictionary(index_type, type);
+ auto indices = ArrayFromJSON(index_type, input_index_json);
+ auto dict = ArrayFromJSON(type, input_dictionary_json);
+
+ ASSERT_OK_AND_ASSIGN(auto input,
+ DictionaryArray::FromArrays(dict_type, indices, dict));
+ auto value_set = ArrayFromJSON(type, value_set_json);
+ auto expected = ArrayFromJSON(int32(), expected_json);
+
+ SetLookupOptions options(value_set, skip_nulls);
+ ASSERT_OK_AND_ASSIGN(Datum actual_datum, IndexIn(input, options));
+ std::shared_ptr<Array> actual = actual_datum.make_array();
+ ValidateOutput(actual_datum);
+ AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+ }
+};
+
+TEST_F(TestIndexInKernel, CallBinary) {
+ auto input = ArrayFromJSON(int8(), "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]");
+ auto value_set = ArrayFromJSON(int8(), "[2, 3, 5, 7]");
+ ASSERT_RAISES(Invalid, CallFunction("index_in", {input, value_set}));
+
+ ASSERT_OK_AND_ASSIGN(Datum out,
+ CallFunction("index_in_meta_binary", {input, value_set}));
+ auto expected = ArrayFromJSON(int32(), ("[null, null, 0, 1, null, 2, null, 3, null,"
+ " null, null]"));
+ AssertArraysEqual(*expected, *out.make_array());
+}
+
+template <typename Type>
+class TestIndexInKernelPrimitive : public TestIndexInKernel {};
+
+using PrimitiveDictionaries =
+ ::testing::Types<Int8Type, UInt8Type, Int16Type, UInt16Type, Int32Type, UInt32Type,
+ Int64Type, UInt64Type, FloatType, DoubleType, Date32Type,
+ Date64Type>;
+
+TYPED_TEST_SUITE(TestIndexInKernelPrimitive, PrimitiveDictionaries);
+
+TYPED_TEST(TestIndexInKernelPrimitive, IndexIn) {
+ auto type = TypeTraits<TypeParam>::type_singleton();
+
+ // No Nulls
+ this->CheckIndexIn(type,
+ /* input= */ "[2, 1, 2, 1, 2, 3]",
+ /* value_set= */ "[2, 1, 3]",
+ /* expected= */ "[0, 1, 0, 1, 0, 2]");
+
+ // Haystack array all null
+ this->CheckIndexIn(type,
+ /* input= */ "[null, null, null, null, null, null]",
+ /* value_set= */ "[2, 1, 3]",
+ /* expected= */ "[null, null, null, null, null, null]");
+
+ // Needles array all null
+ this->CheckIndexIn(type,
+ /* input= */ "[2, 1, 2, 1, 2, 3]",
+ /* value_set= */ "[null]",
+ /* expected= */ "[null, null, null, null, null, null]");
+
+ // Both arrays all null
+ this->CheckIndexIn(type,
+ /* input= */ "[null, null, null, null]",
+ /* value_set= */ "[null]",
+ /* expected= */ "[0, 0, 0, 0]");
+
+ // Duplicates in value_set
+ this->CheckIndexIn(type,
+ /* input= */ "[2, 1, 2, 1, 2, 3]",
+ /* value_set= */ "[2, 2, 1, 1, 1, 3, 3]",
+ /* expected= */ "[0, 2, 0, 2, 0, 5]");
+
+ // Duplicates and nulls in value_set
+ this->CheckIndexIn(type,
+ /* input= */ "[2, 1, 2, 1, 2, 3]",
+ /* value_set= */ "[2, 2, null, null, 1, 1, 1, 3, 3]",
+ /* expected= */ "[0, 4, 0, 4, 0, 7]");
+
+ // No Match
+ this->CheckIndexIn(type,
+ /* input= */ "[2, null, 7, 3, 8]",
+ /* value_set= */ "[2, null, 6, 3]",
+ /* expected= */ "[0, 1, null, 3, null]");
+
+ // Empty Arrays
+ this->CheckIndexIn(type, "[]", "[]", "[]");
+}
+
+TYPED_TEST(TestIndexInKernelPrimitive, SkipNulls) {
+ auto type = TypeTraits<TypeParam>::type_singleton();
+
+ // No nulls in value_set
+ this->CheckIndexIn(type,
+ /*input=*/"[0, 1, 2, 3, null]",
+ /*value_set=*/"[1, 3]",
+ /*expected=*/"[null, 0, null, 1, null]",
+ /*skip_nulls=*/false);
+ this->CheckIndexIn(type,
+ /*input=*/"[0, 1, 2, 3, null]",
+ /*value_set=*/"[1, 3]",
+ /*expected=*/"[null, 0, null, 1, null]",
+ /*skip_nulls=*/true);
+ // Same with duplicates in value_set
+ this->CheckIndexIn(type,
+ /*input=*/"[0, 1, 2, 3, null]",
+ /*value_set=*/"[1, 1, 3, 3]",
+ /*expected=*/"[null, 0, null, 2, null]",
+ /*skip_nulls=*/false);
+ this->CheckIndexIn(type,
+ /*input=*/"[0, 1, 2, 3, null]",
+ /*value_set=*/"[1, 1, 3, 3]",
+ /*expected=*/"[null, 0, null, 2, null]",
+ /*skip_nulls=*/true);
+
+ // Nulls in value_set
+ this->CheckIndexIn(type,
+ /*input=*/"[0, 1, 2, 3, null]",
+ /*value_set=*/"[1, null, 3]",
+ /*expected=*/"[null, 0, null, 2, 1]",
+ /*skip_nulls=*/false);
+ this->CheckIndexIn(type,
+ /*input=*/"[0, 1, 2, 3, null]",
+ /*value_set=*/"[1, 1, null, null, 3, 3]",
+ /*expected=*/"[null, 0, null, 4, null]",
+ /*skip_nulls=*/true);
+ // Same with duplicates in value_set
+ this->CheckIndexIn(type,
+ /*input=*/"[0, 1, 2, 3, null]",
+ /*value_set=*/"[1, 1, null, null, 3, 3]",
+ /*expected=*/"[null, 0, null, 4, 2]",
+ /*skip_nulls=*/false);
+}
+
+TEST_F(TestIndexInKernel, NullType) {
+ CheckIndexIn(null(), "[null, null, null]", "[null]", "[0, 0, 0]");
+ CheckIndexIn(null(), "[null, null, null]", "[]", "[null, null, null]");
+ CheckIndexIn(null(), "[]", "[null, null]", "[]");
+ CheckIndexIn(null(), "[]", "[]", "[]");
+
+ CheckIndexIn(null(), "[null, null]", "[null]", "[null, null]", /*skip_nulls=*/true);
+ CheckIndexIn(null(), "[null, null]", "[]", "[null, null]", /*skip_nulls=*/true);
+}
+
+TEST_F(TestIndexInKernel, TimeTimestamp) {
+ CheckIndexIn(time32(TimeUnit::SECOND),
+ /* input= */ "[1, null, 5, 1, 2]",
+ /* value_set= */ "[2, 1, null]",
+ /* expected= */ "[1, 2, null, 1, 0]");
+
+ // Duplicates in value_set
+ CheckIndexIn(time32(TimeUnit::SECOND),
+ /* input= */ "[1, null, 5, 1, 2]",
+ /* value_set= */ "[2, 2, 1, 1, null, null]",
+ /* expected= */ "[2, 4, null, 2, 0]");
+
+ // Needles array has no nulls
+ CheckIndexIn(time32(TimeUnit::SECOND),
+ /* input= */ "[2, null, 5, 1]",
+ /* value_set= */ "[2, 1]",
+ /* expected= */ "[0, null, null, 1]");
+
+ // No match
+ CheckIndexIn(time32(TimeUnit::SECOND), "[3, null, 5, 3]", "[2, 1]",
+ "[null, null, null, null]");
+
+ // Empty arrays
+ CheckIndexIn(time32(TimeUnit::SECOND), "[]", "[]", "[]");
+
+ CheckIndexIn(time64(TimeUnit::NANO), "[2, null, 2, 1]", "[2, null, 1]", "[0, 1, 0, 2]");
+
+ CheckIndexIn(timestamp(TimeUnit::NANO), "[2, null, 2, 1]", "[2, null, 1]",
+ "[0, 1, 0, 2]");
+
+ // Empty input array
+ CheckIndexIn(timestamp(TimeUnit::NANO), "[]", "[2, null, 1]", "[]");
+
+ // Empty value_set array
+ CheckIndexIn(timestamp(TimeUnit::NANO), "[2, null, 1]", "[]", "[null, null, null]");
+
+ // Both array are all null
+ CheckIndexIn(time32(TimeUnit::SECOND), "[null, null, null, null]", "[null]",
+ "[0, 0, 0, 0]");
+}
+
+TEST_F(TestIndexInKernel, Boolean) {
+ CheckIndexIn(boolean(),
+ /* input= */ "[false, null, false, true]",
+ /* value_set= */ "[null, false, true]",
+ /* expected= */ "[1, 0, 1, 2]");
+
+ CheckIndexIn(boolean(), "[false, null, false, true]", "[false, true, null]",
+ "[0, 2, 0, 1]");
+
+ // Duplicates in value_set
+ CheckIndexIn(boolean(), "[false, null, false, true]",
+ "[false, false, true, true, null, null]", "[0, 4, 0, 2]");
+
+ // No Nulls
+ CheckIndexIn(boolean(), "[true, true, false, true]", "[false, true]", "[1, 1, 0, 1]");
+
+ CheckIndexIn(boolean(), "[false, true, false, true]", "[true]", "[null, 0, null, 0]");
+
+ // No match
+ CheckIndexIn(boolean(), "[true, true, true, true]", "[false]",
+ "[null, null, null, null]");
+
+ // Nulls in input array
+ CheckIndexIn(boolean(), "[null, null, null, null]", "[true]",
+ "[null, null, null, null]");
+
+ // Nulls in value_set array
+ CheckIndexIn(boolean(), "[true, true, false, true]", "[null]",
+ "[null, null, null, null]");
+
+ // Both array have Nulls
+ CheckIndexIn(boolean(), "[null, null, null, null]", "[null]", "[0, 0, 0, 0]");
+}
+
+template <typename Type>
+class TestIndexInKernelBinary : public TestIndexInKernel {};
+
+TYPED_TEST_SUITE(TestIndexInKernelBinary, BinaryArrowTypes);
+
+TYPED_TEST(TestIndexInKernelBinary, Binary) {
+ auto type = TypeTraits<TypeParam>::type_singleton();
+ this->CheckIndexIn(type, R"(["foo", null, "bar", "foo"])", R"(["foo", null, "bar"])",
+ R"([0, 1, 2, 0])");
+
+ // Duplicates in value_set
+ this->CheckIndexIn(type, R"(["foo", null, "bar", "foo"])",
+ R"(["foo", "foo", null, null, "bar", "bar"])", R"([0, 2, 4, 0])");
+
+ // No match
+ this->CheckIndexIn(type,
+ /* input= */ R"(["foo", null, "bar", "foo"])",
+ /* value_set= */ R"(["baz", "bazzz"])",
+ /* expected= */ R"([null, null, null, null])");
+
+ // Nulls in input array
+ this->CheckIndexIn(type,
+ /* input= */ R"([null, null, null, null])",
+ /* value_set= */ R"(["foo", "bar"])",
+ /* expected= */ R"([null, null, null, null])");
+
+ // Nulls in value_set array
+ this->CheckIndexIn(type, R"(["foo", "bar", "foo"])", R"([null])",
+ R"([null, null, null])");
+
+ // Both array have Nulls
+ this->CheckIndexIn(type,
+ /* input= */ R"([null, null, null, null])",
+ /* value_set= */ R"([null])",
+ /* expected= */ R"([0, 0, 0, 0])");
+
+ // Empty arrays
+ this->CheckIndexIn(type, R"([])", R"([])", R"([])");
+
+ // Empty input array
+ this->CheckIndexIn(type, R"([])", R"(["foo", null, "bar"])", "[]");
+
+ // Empty value_set array
+ this->CheckIndexIn(type, R"(["foo", null, "bar", "foo"])", "[]",
+ R"([null, null, null, null])");
+}
+
+TEST_F(TestIndexInKernel, BinaryResizeTable) {
+ const int32_t kTotalValues = 10000;
+#if !defined(ARROW_VALGRIND)
+ const int32_t kRepeats = 10;
+#else
+ // Mitigate Valgrind's slowness
+ const int32_t kRepeats = 3;
+#endif
+
+ const int32_t kBufSize = 20;
+
+ Int32Builder expected_builder;
+ StringBuilder input_builder;
+ ASSERT_OK(expected_builder.Resize(kTotalValues * kRepeats));
+ ASSERT_OK(input_builder.Resize(kTotalValues * kRepeats));
+ ASSERT_OK(input_builder.ReserveData(kBufSize * kTotalValues * kRepeats));
+
+ for (int32_t i = 0; i < kTotalValues * kRepeats; i++) {
+ int32_t index = i % kTotalValues;
+
+ char buf[kBufSize] = "test";
+ ASSERT_GE(snprintf(buf + 4, sizeof(buf) - 4, "%d", index), 0);
+
+ input_builder.UnsafeAppend(util::string_view(buf));
+ expected_builder.UnsafeAppend(index);
+ }
+
+ std::shared_ptr<Array> input, value_set, expected;
+ ASSERT_OK(input_builder.Finish(&input));
+ value_set = input->Slice(0, kTotalValues);
+ ASSERT_OK(expected_builder.Finish(&expected));
+
+ ASSERT_OK_AND_ASSIGN(Datum actual_datum, IndexIn(input, value_set));
+ std::shared_ptr<Array> actual = actual_datum.make_array();
+ ASSERT_ARRAYS_EQUAL(*expected, *actual);
+}
+
+TEST_F(TestIndexInKernel, FixedSizeBinary) {
+ CheckIndexIn(fixed_size_binary(3),
+ /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])",
+ /*value_set=*/R"(["aaa", null, "bbb", "ccc"])",
+ /*expected=*/R"([2, 1, null, 0, 3, 0])");
+ CheckIndexIn(fixed_size_binary(3),
+ /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])",
+ /*value_set=*/R"(["aaa", null, "bbb", "ccc"])",
+ /*expected=*/R"([2, null, null, 0, 3, 0])",
+ /*skip_nulls=*/true);
+
+ CheckIndexIn(fixed_size_binary(3),
+ /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])",
+ /*value_set=*/R"(["aaa", "bbb", "ccc"])",
+ /*expected=*/R"([1, null, null, 0, 2, 0])");
+ CheckIndexIn(fixed_size_binary(3),
+ /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])",
+ /*value_set=*/R"(["aaa", "bbb", "ccc"])",
+ /*expected=*/R"([1, null, null, 0, 2, 0])",
+ /*skip_nulls=*/true);
+
+ // Duplicates in value_set
+ CheckIndexIn(fixed_size_binary(3),
+ /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])",
+ /*value_set=*/R"(["aaa", "aaa", null, null, "bbb", "bbb", "ccc"])",
+ /*expected=*/R"([4, 2, null, 0, 6, 0])");
+ CheckIndexIn(fixed_size_binary(3),
+ /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])",
+ /*value_set=*/R"(["aaa", "aaa", null, null, "bbb", "bbb", "ccc"])",
+ /*expected=*/R"([4, null, null, 0, 6, 0])",
+ /*skip_nulls=*/true);
+
+ // Empty input array
+ CheckIndexIn(fixed_size_binary(5), R"([])", R"(["bbbbb", null, "aaaaa", "ccccc"])",
+ R"([])");
+
+ // Empty value_set array
+ CheckIndexIn(fixed_size_binary(5), R"(["bbbbb", null, "bbbbb"])", R"([])",
+ R"([null, null, null])");
+
+ // Empty arrays
+ CheckIndexIn(fixed_size_binary(0), R"([])", R"([])", R"([])");
+}
+
+TEST_F(TestIndexInKernel, MonthDayNanoInterval) {
+ auto type = month_day_nano_interval();
+
+ CheckIndexIn(type,
+ /*input=*/R"([[5, -1, 5], null, [4, 5, 6], [5, -1, 5], [1, 2, 3]])",
+ /*value_set=*/R"([null, [4, 5, 6], [5, -1, 5]])",
+ /*expected=*/R"([2, 0, 1, 2, null])",
+ /*skip_nulls=*/false);
+
+ // Duplicates in value_set
+ CheckIndexIn(
+ type,
+ /*input=*/R"([[7, 8, 0], null, [0, 0, 0], [7, 8, 0], [0, 0, 1]])",
+ /*value_set=*/R"([null, null, [0, 0, 0], [0, 0, 0], [7, 8, 0], [7, 8, 0]])",
+ /*expected=*/R"([4, 0, 2, 4, null])",
+ /*skip_nulls=*/false);
+}
+
+TEST_F(TestIndexInKernel, Decimal) {
+ auto type = decimal(2, 0);
+
+ CheckIndexIn(type,
+ /*input=*/R"(["12", null, "11", "12", "13"])",
+ /*value_set=*/R"([null, "11", "12"])",
+ /*expected=*/R"([2, 0, 1, 2, null])",
+ /*skip_nulls=*/false);
+ CheckIndexIn(type,
+ /*input=*/R"(["12", null, "11", "12", "13"])",
+ /*value_set=*/R"([null, "11", "12"])",
+ /*expected=*/R"([2, null, 1, 2, null])",
+ /*skip_nulls=*/true);
+
+ CheckIndexIn(type,
+ /*input=*/R"(["12", null, "11", "12", "13"])",
+ /*value_set=*/R"(["11", "12"])",
+ /*expected=*/R"([1, null, 0, 1, null])",
+ /*skip_nulls=*/false);
+ CheckIndexIn(type,
+ /*input=*/R"(["12", null, "11", "12", "13"])",
+ /*value_set=*/R"(["11", "12"])",
+ /*expected=*/R"([1, null, 0, 1, null])",
+ /*skip_nulls=*/true);
+
+ // Duplicates in value_set
+ CheckIndexIn(type,
+ /*input=*/R"(["12", null, "11", "12", "13"])",
+ /*value_set=*/R"([null, null, "11", "11", "12", "12"])",
+ /*expected=*/R"([4, 0, 2, 4, null])",
+ /*skip_nulls=*/false);
+ CheckIndexIn(type,
+ /*input=*/R"(["12", null, "11", "12", "13"])",
+ /*value_set=*/R"([null, null, "11", "11", "12", "12"])",
+ /*expected=*/R"([4, null, 2, 4, null])",
+ /*skip_nulls=*/true);
+}
+
+TEST_F(TestIndexInKernel, DictionaryArray) {
+ for (auto index_ty : all_dictionary_index_types()) {
+ CheckIndexInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", "B", "C", "D"])",
+ /*input_index_json=*/"[1, 2, null, 0]",
+ /*value_set_json=*/R"(["A", "B", "C"])",
+ /*expected_json=*/"[1, 2, null, 0]",
+ /*skip_nulls=*/false);
+ CheckIndexInDictionary(/*type=*/float32(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/"[4.1, -1.0, 42, 9.8]",
+ /*input_index_json=*/"[1, 2, null, 0]",
+ /*value_set_json=*/"[4.1, 42, -1.0]",
+ /*expected_json=*/"[2, 1, null, 0]",
+ /*skip_nulls=*/false);
+
+ // With nulls and skip_nulls=false
+ CheckIndexInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", "B", "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "B", "A", null])",
+ /*expected_json=*/"[1, null, 3, 2, 1]",
+ /*skip_nulls=*/false);
+ CheckIndexInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", null, "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "B", "A", null])",
+ /*expected_json=*/"[3, null, 3, 2, 3]",
+ /*skip_nulls=*/false);
+ CheckIndexInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", null, "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "B", "A"])",
+ /*expected_json=*/"[null, null, null, 2, null]",
+ /*skip_nulls=*/false);
+
+ // With nulls and skip_nulls=true
+ CheckIndexInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", "B", "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "B", "A", null])",
+ /*expected_json=*/"[1, null, null, 2, 1]",
+ /*skip_nulls=*/true);
+ CheckIndexInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", null, "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "B", "A", null])",
+ /*expected_json=*/"[null, null, null, 2, null]",
+ /*skip_nulls=*/true);
+ CheckIndexInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", null, "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "B", "A"])",
+ /*expected_json=*/"[null, null, null, 2, null]",
+ /*skip_nulls=*/true);
+
+ // With duplicates in value_set
+ CheckIndexInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", "B", "C", "D"])",
+ /*input_index_json=*/"[1, 2, null, 0]",
+ /*value_set_json=*/R"(["A", "A", "B", "B", "C", "C"])",
+ /*expected_json=*/"[2, 4, null, 0]",
+ /*skip_nulls=*/false);
+ CheckIndexInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", null, "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "C", "B", "B", "A", "A", null])",
+ /*expected_json=*/"[6, null, 6, 4, 6]",
+ /*skip_nulls=*/false);
+ CheckIndexInDictionary(/*type=*/utf8(),
+ /*index_type=*/index_ty,
+ /*input_dictionary_json=*/R"(["A", null, "C", "D"])",
+ /*input_index_json=*/"[1, 3, null, 0, 1]",
+ /*value_set_json=*/R"(["C", "C", "B", "B", "A", "A", null])",
+ /*expected_json=*/"[null, null, null, 4, null]",
+ /*skip_nulls=*/true);
+ }
+}
+
+TEST_F(TestIndexInKernel, ChunkedArrayInvoke) {
+ auto input = ChunkedArrayFromJSON(utf8(), {R"(["abc", "def", "ghi", "abc", "jkl"])",
+ R"(["def", null, "abc", "zzz"])"});
+ // No null in value_set
+ auto value_set = ChunkedArrayFromJSON(utf8(), {R"(["ghi", "def"])", R"(["abc"])"});
+ auto expected =
+ ChunkedArrayFromJSON(int32(), {"[2, 1, 0, 2, null]", "[1, null, 2, null]"});
+
+ CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/false);
+ CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/true);
+
+ // Null in value_set
+ value_set = ChunkedArrayFromJSON(utf8(), {R"(["ghi", "def"])", R"([null, "abc"])"});
+ expected = ChunkedArrayFromJSON(int32(), {"[3, 1, 0, 3, null]", "[1, 2, 3, null]"});
+ CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/false);
+ expected = ChunkedArrayFromJSON(int32(), {"[3, 1, 0, 3, null]", "[1, null, 3, null]"});
+ CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/true);
+
+ // Duplicates in value_set
+ value_set = ChunkedArrayFromJSON(
+ utf8(), {R"(["ghi", "ghi", "def"])", R"(["def", null, null, "abc"])"});
+ expected = ChunkedArrayFromJSON(int32(), {"[6, 2, 0, 6, null]", "[2, 4, 6, null]"});
+ CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/false);
+ expected = ChunkedArrayFromJSON(int32(), {"[6, 2, 0, 6, null]", "[2, null, 6, null]"});
+ CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/true);
+}
+
+TEST(TestSetLookup, DispatchBest) {
+ for (std::string name : {"is_in", "index_in"}) {
+ CheckDispatchBest(name, {int32()}, {int32()});
+ CheckDispatchBest(name, {dictionary(int32(), utf8())}, {utf8()});
+ }
+}
+
+TEST(TestSetLookup, IsInWithImplicitCasts) {
+ SetLookupOptions opts{ArrayFromJSON(utf8(), R"(["b", null])")};
+ CheckScalarUnary("is_in",
+ ArrayFromJSON(dictionary(int32(), utf8()), R"(["a", "b", "c", null])"),
+ ArrayFromJSON(boolean(), "[0, 1, 0, 1]"), &opts);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_string.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_string.cc
new file mode 100644
index 000000000..11562b06d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_string.cc
@@ -0,0 +1,4490 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cctype>
+#include <iterator>
+#include <string>
+
+#ifdef ARROW_WITH_UTF8PROC
+#include <utf8proc.h>
+#endif
+
+#ifdef ARROW_WITH_RE2
+#include <re2/re2.h>
+#endif
+
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_nested.h"
+#include "arrow/buffer_builder.h"
+
+#include "arrow/builder.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/utf8.h"
+#include "arrow/util/value_parsing.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+namespace internal {
+
+namespace {
+
+#ifdef ARROW_WITH_RE2
+util::string_view ToStringView(re2::StringPiece piece) {
+ return {piece.data(), piece.length()};
+}
+
+re2::StringPiece ToStringPiece(util::string_view view) {
+ return {view.data(), view.length()};
+}
+
+Status RegexStatus(const RE2& regex) {
+ if (!regex.ok()) {
+ return Status::Invalid("Invalid regular expression: ", regex.error());
+ }
+ return Status::OK();
+}
+#endif
+
+// Code units in the range [a-z] can only be an encoding of an ASCII
+// character/codepoint, not the 2nd, 3rd or 4th code unit (byte) of a different
+// codepoint. This is guaranteed by the non-overlap design of the Unicode
+// standard. (see section 2.5 of Unicode Standard Core Specification v13.0)
+
+// IsAlpha/Digit etc
+
+static inline bool IsAsciiCharacter(uint8_t character) { return character < 128; }
+
+static inline bool IsLowerCaseCharacterAscii(uint8_t ascii_character) {
+ return (ascii_character >= 'a') && (ascii_character <= 'z');
+}
+
+static inline bool IsUpperCaseCharacterAscii(uint8_t ascii_character) {
+ return (ascii_character >= 'A') && (ascii_character <= 'Z');
+}
+
+static inline bool IsCasedCharacterAscii(uint8_t ascii_character) {
+ // Note: Non-ASCII characters are seen as uncased.
+ return IsLowerCaseCharacterAscii(ascii_character) ||
+ IsUpperCaseCharacterAscii(ascii_character);
+}
+
+static inline bool IsAlphaCharacterAscii(uint8_t ascii_character) {
+ return IsCasedCharacterAscii(ascii_character);
+}
+
+static inline bool IsAlphaNumericCharacterAscii(uint8_t ascii_character) {
+ return ((ascii_character >= '0') && (ascii_character <= '9')) ||
+ ((ascii_character >= 'a') && (ascii_character <= 'z')) ||
+ ((ascii_character >= 'A') && (ascii_character <= 'Z'));
+}
+
+static inline bool IsDecimalCharacterAscii(uint8_t ascii_character) {
+ return ((ascii_character >= '0') && (ascii_character <= '9'));
+}
+
+static inline bool IsSpaceCharacterAscii(uint8_t ascii_character) {
+ return ((ascii_character >= 9) && (ascii_character <= 13)) || (ascii_character == ' ');
+}
+
+static inline bool IsPrintableCharacterAscii(uint8_t ascii_character) {
+ return ((ascii_character >= ' ') && (ascii_character <= '~'));
+}
+
+struct BinaryLength {
+ template <typename OutValue, typename Arg0Value = util::string_view>
+ static OutValue Call(KernelContext*, Arg0Value val, Status*) {
+ return static_cast<OutValue>(val.size());
+ }
+
+ static Status FixedSizeExec(KernelContext*, const ExecBatch& batch, Datum* out) {
+ // Output is preallocated and validity buffer is precomputed
+ const int32_t width =
+ checked_cast<const FixedSizeBinaryType&>(*batch[0].type()).byte_width();
+ if (batch.values[0].is_array()) {
+ int32_t* buffer = out->mutable_array()->GetMutableValues<int32_t>(1);
+ std::fill(buffer, buffer + batch.length, width);
+ } else {
+ checked_cast<Int32Scalar*>(out->scalar().get())->value = width;
+ }
+ return Status::OK();
+ }
+};
+
+struct Utf8Length {
+ template <typename OutValue, typename Arg0Value = util::string_view>
+ static OutValue Call(KernelContext*, Arg0Value val, Status*) {
+ auto str = reinterpret_cast<const uint8_t*>(val.data());
+ auto strlen = val.size();
+ return static_cast<OutValue>(util::UTF8Length(str, str + strlen));
+ }
+};
+
+static inline uint8_t ascii_tolower(uint8_t utf8_code_unit) {
+ return ((utf8_code_unit >= 'A') && (utf8_code_unit <= 'Z')) ? (utf8_code_unit + 32)
+ : utf8_code_unit;
+}
+
+static inline uint8_t ascii_toupper(uint8_t utf8_code_unit) {
+ return ((utf8_code_unit >= 'a') && (utf8_code_unit <= 'z')) ? (utf8_code_unit - 32)
+ : utf8_code_unit;
+}
+
+static inline uint8_t ascii_swapcase(uint8_t utf8_code_unit) {
+ if (IsLowerCaseCharacterAscii(utf8_code_unit)) {
+ utf8_code_unit -= 32;
+ } else if (IsUpperCaseCharacterAscii(utf8_code_unit)) {
+ utf8_code_unit += 32;
+ }
+ return utf8_code_unit;
+}
+
+#ifdef ARROW_WITH_UTF8PROC
+
+// Direct lookup tables for unicode properties
+constexpr uint32_t kMaxCodepointLookup =
+ 0xffff; // up to this codepoint is in a lookup table
+std::vector<uint32_t> lut_upper_codepoint;
+std::vector<uint32_t> lut_lower_codepoint;
+std::vector<uint32_t> lut_swapcase_codepoint;
+std::vector<utf8proc_category_t> lut_category;
+std::once_flag flag_case_luts;
+
+// IsAlpha/Digit etc
+
+static inline bool HasAnyUnicodeGeneralCategory(uint32_t codepoint, uint32_t mask) {
+ utf8proc_category_t general_category = codepoint <= kMaxCodepointLookup
+ ? lut_category[codepoint]
+ : utf8proc_category(codepoint);
+ uint32_t general_category_bit = 1 << general_category;
+ // for e.g. undefined (but valid) codepoints, general_category == 0 ==
+ // UTF8PROC_CATEGORY_CN
+ return (general_category != UTF8PROC_CATEGORY_CN) &&
+ ((general_category_bit & mask) != 0);
+}
+
+template <typename... Categories>
+static inline bool HasAnyUnicodeGeneralCategory(uint32_t codepoint, uint32_t mask,
+ utf8proc_category_t category,
+ Categories... categories) {
+ return HasAnyUnicodeGeneralCategory(codepoint, mask | (1 << category), categories...);
+}
+
+template <typename... Categories>
+static inline bool HasAnyUnicodeGeneralCategory(uint32_t codepoint,
+ utf8proc_category_t category,
+ Categories... categories) {
+ return HasAnyUnicodeGeneralCategory(codepoint, static_cast<uint32_t>(1u << category),
+ categories...);
+}
+
+static inline bool IsCasedCharacterUnicode(uint32_t codepoint) {
+ return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LU,
+ UTF8PROC_CATEGORY_LL, UTF8PROC_CATEGORY_LT) ||
+ ((static_cast<uint32_t>(utf8proc_toupper(codepoint)) != codepoint) ||
+ (static_cast<uint32_t>(utf8proc_tolower(codepoint)) != codepoint));
+}
+
+static inline bool IsLowerCaseCharacterUnicode(uint32_t codepoint) {
+ // although this trick seems to work for upper case, this is not enough for lower case
+ // testing, see https://github.com/JuliaStrings/utf8proc/issues/195 . But currently the
+ // best we can do
+ return (HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LL) ||
+ ((static_cast<uint32_t>(utf8proc_toupper(codepoint)) != codepoint) &&
+ (static_cast<uint32_t>(utf8proc_tolower(codepoint)) == codepoint))) &&
+ !HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LT);
+}
+
+static inline bool IsUpperCaseCharacterUnicode(uint32_t codepoint) {
+ // this seems to be a good workaround for utf8proc not having case information
+ // https://github.com/JuliaStrings/utf8proc/issues/195
+ return (HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LU) ||
+ ((static_cast<uint32_t>(utf8proc_toupper(codepoint)) == codepoint) &&
+ (static_cast<uint32_t>(utf8proc_tolower(codepoint)) != codepoint))) &&
+ !HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LT);
+}
+
+static inline bool IsAlphaNumericCharacterUnicode(uint32_t codepoint) {
+ return HasAnyUnicodeGeneralCategory(
+ codepoint, UTF8PROC_CATEGORY_LU, UTF8PROC_CATEGORY_LL, UTF8PROC_CATEGORY_LT,
+ UTF8PROC_CATEGORY_LM, UTF8PROC_CATEGORY_LO, UTF8PROC_CATEGORY_ND,
+ UTF8PROC_CATEGORY_NL, UTF8PROC_CATEGORY_NO);
+}
+
+static inline bool IsAlphaCharacterUnicode(uint32_t codepoint) {
+ return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_LU,
+ UTF8PROC_CATEGORY_LL, UTF8PROC_CATEGORY_LT,
+ UTF8PROC_CATEGORY_LM, UTF8PROC_CATEGORY_LO);
+}
+
+static inline bool IsDecimalCharacterUnicode(uint32_t codepoint) {
+ return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_ND);
+}
+
+static inline bool IsDigitCharacterUnicode(uint32_t codepoint) {
+ // Python defines this as Numeric_Type=Digit or Numeric_Type=Decimal.
+ // utf8proc has no support for this, this is the best we can do:
+ return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_ND);
+}
+
+static inline bool IsNumericCharacterUnicode(uint32_t codepoint) {
+ // Formally this is not correct, but utf8proc does not allow us to query for Numerical
+ // properties, e.g. Numeric_Value and Numeric_Type
+ // Python defines Numeric as Numeric_Type=Digit, Numeric_Type=Decimal or
+ // Numeric_Type=Numeric.
+ return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_ND,
+ UTF8PROC_CATEGORY_NL, UTF8PROC_CATEGORY_NO);
+}
+
+static inline bool IsSpaceCharacterUnicode(uint32_t codepoint) {
+ auto property = utf8proc_get_property(codepoint);
+ return HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_ZS) ||
+ property->bidi_class == UTF8PROC_BIDI_CLASS_WS ||
+ property->bidi_class == UTF8PROC_BIDI_CLASS_B ||
+ property->bidi_class == UTF8PROC_BIDI_CLASS_S;
+}
+
+static inline bool IsPrintableCharacterUnicode(uint32_t codepoint) {
+ uint32_t general_category = utf8proc_category(codepoint);
+ return (general_category != UTF8PROC_CATEGORY_CN) &&
+ !HasAnyUnicodeGeneralCategory(codepoint, UTF8PROC_CATEGORY_CC,
+ UTF8PROC_CATEGORY_CF, UTF8PROC_CATEGORY_CS,
+ UTF8PROC_CATEGORY_CO, UTF8PROC_CATEGORY_ZS,
+ UTF8PROC_CATEGORY_ZL, UTF8PROC_CATEGORY_ZP);
+}
+
+void EnsureLookupTablesFilled() {
+ std::call_once(flag_case_luts, []() {
+ lut_upper_codepoint.reserve(kMaxCodepointLookup + 1);
+ lut_lower_codepoint.reserve(kMaxCodepointLookup + 1);
+ lut_swapcase_codepoint.reserve(kMaxCodepointLookup + 1);
+ for (uint32_t i = 0; i <= kMaxCodepointLookup; i++) {
+ lut_upper_codepoint.push_back(utf8proc_toupper(i));
+ lut_lower_codepoint.push_back(utf8proc_tolower(i));
+ lut_category.push_back(utf8proc_category(i));
+
+ if (IsLowerCaseCharacterUnicode(i)) {
+ lut_swapcase_codepoint.push_back(utf8proc_toupper(i));
+ } else if (IsUpperCaseCharacterUnicode(i)) {
+ lut_swapcase_codepoint.push_back(utf8proc_tolower(i));
+ } else {
+ lut_swapcase_codepoint.push_back(i);
+ }
+ }
+ });
+}
+
+#else
+
+void EnsureLookupTablesFilled() {}
+
+#endif // ARROW_WITH_UTF8PROC
+
+constexpr int64_t kTransformError = -1;
+
+struct StringTransformBase {
+ virtual ~StringTransformBase() = default;
+ virtual Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return Status::OK();
+ }
+
+ // Return the maximum total size of the output in codeunits (i.e. bytes)
+ // given input characteristics.
+ virtual int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) {
+ return input_ncodeunits;
+ }
+
+ virtual Status InvalidStatus() {
+ return Status::Invalid("Invalid UTF8 sequence in input");
+ }
+
+ // Derived classes should also define this method:
+ // int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ // uint8_t* output);
+};
+
+template <typename Type, typename StringTransform>
+struct StringTransformExecBase {
+ using offset_type = typename Type::offset_type;
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+
+ static Status Execute(KernelContext* ctx, StringTransform* transform,
+ const ExecBatch& batch, Datum* out) {
+ if (batch[0].kind() == Datum::ARRAY) {
+ return ExecArray(ctx, transform, batch[0].array(), out);
+ }
+ DCHECK_EQ(batch[0].kind(), Datum::SCALAR);
+ return ExecScalar(ctx, transform, batch[0].scalar(), out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, StringTransform* transform,
+ const std::shared_ptr<ArrayData>& data, Datum* out) {
+ ArrayType input(data);
+ ArrayData* output = out->mutable_array();
+
+ const int64_t input_ncodeunits = input.total_values_length();
+ const int64_t input_nstrings = input.length();
+
+ const int64_t output_ncodeunits_max =
+ transform->MaxCodeunits(input_nstrings, input_ncodeunits);
+ if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
+ return Status::CapacityError(
+ "Result might not fit in a 32bit utf8 array, convert to large_utf8");
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max));
+ output->buffers[2] = values_buffer;
+
+ // String offsets are preallocated
+ offset_type* output_string_offsets = output->GetMutableValues<offset_type>(1);
+ uint8_t* output_str = output->buffers[2]->mutable_data();
+ offset_type output_ncodeunits = 0;
+
+ output_string_offsets[0] = 0;
+ for (int64_t i = 0; i < input_nstrings; i++) {
+ if (!input.IsNull(i)) {
+ offset_type input_string_ncodeunits;
+ const uint8_t* input_string = input.GetValue(i, &input_string_ncodeunits);
+ auto encoded_nbytes = static_cast<offset_type>(transform->Transform(
+ input_string, input_string_ncodeunits, output_str + output_ncodeunits));
+ if (encoded_nbytes < 0) {
+ return transform->InvalidStatus();
+ }
+ output_ncodeunits += encoded_nbytes;
+ }
+ output_string_offsets[i + 1] = output_ncodeunits;
+ }
+ DCHECK_LE(output_ncodeunits, output_ncodeunits_max);
+
+ // Trim the codepoint buffer, since we allocated too much
+ return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true);
+ }
+
+ static Status ExecScalar(KernelContext* ctx, StringTransform* transform,
+ const std::shared_ptr<Scalar>& scalar, Datum* out) {
+ const auto& input = checked_cast<const BaseBinaryScalar&>(*scalar);
+ if (!input.is_valid) {
+ return Status::OK();
+ }
+ auto* result = checked_cast<BaseBinaryScalar*>(out->scalar().get());
+ result->is_valid = true;
+ const int64_t data_nbytes = static_cast<int64_t>(input.value->size());
+
+ const int64_t output_ncodeunits_max = transform->MaxCodeunits(1, data_nbytes);
+ if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
+ return Status::CapacityError(
+ "Result might not fit in a 32bit utf8 array, convert to large_utf8");
+ }
+ ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(output_ncodeunits_max));
+ result->value = value_buffer;
+ auto encoded_nbytes = static_cast<offset_type>(transform->Transform(
+ input.value->data(), data_nbytes, value_buffer->mutable_data()));
+ if (encoded_nbytes < 0) {
+ return transform->InvalidStatus();
+ }
+ DCHECK_LE(encoded_nbytes, output_ncodeunits_max);
+ return value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true);
+ }
+};
+
+template <typename Type, typename StringTransform>
+struct StringTransformExec : public StringTransformExecBase<Type, StringTransform> {
+ using StringTransformExecBase<Type, StringTransform>::Execute;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ StringTransform transform;
+ RETURN_NOT_OK(transform.PreExec(ctx, batch, out));
+ return Execute(ctx, &transform, batch, out);
+ }
+};
+
+template <typename Type, typename StringTransform>
+struct StringTransformExecWithState
+ : public StringTransformExecBase<Type, StringTransform> {
+ using State = typename StringTransform::State;
+ using StringTransformExecBase<Type, StringTransform>::Execute;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ StringTransform transform(State::Get(ctx));
+ RETURN_NOT_OK(transform.PreExec(ctx, batch, out));
+ return Execute(ctx, &transform, batch, out);
+ }
+};
+
+template <typename StringTransform>
+struct FixedSizeBinaryTransformExecBase {
+ static Status Execute(KernelContext* ctx, StringTransform* transform,
+ const ExecBatch& batch, Datum* out) {
+ if (batch[0].kind() == Datum::ARRAY) {
+ return ExecArray(ctx, transform, batch[0].array(), out);
+ }
+ DCHECK_EQ(batch[0].kind(), Datum::SCALAR);
+ return ExecScalar(ctx, transform, batch[0].scalar(), out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, StringTransform* transform,
+ const std::shared_ptr<ArrayData>& data, Datum* out) {
+ FixedSizeBinaryArray input(data);
+ ArrayData* output = out->mutable_array();
+
+ const int32_t input_width =
+ checked_cast<const FixedSizeBinaryType&>(*data->type).byte_width();
+ const int32_t output_width =
+ checked_cast<const FixedSizeBinaryType&>(*out->type()).byte_width();
+ const int64_t input_nstrings = input.length();
+ ARROW_ASSIGN_OR_RAISE(auto values_buffer,
+ ctx->Allocate(output_width * input_nstrings));
+ uint8_t* output_str = values_buffer->mutable_data();
+
+ for (int64_t i = 0; i < input_nstrings; i++) {
+ if (!input.IsNull(i)) {
+ const uint8_t* input_string = input.GetValue(i);
+ auto encoded_nbytes = static_cast<int32_t>(
+ transform->Transform(input_string, input_width, output_str));
+ if (encoded_nbytes != output_width) {
+ return transform->InvalidStatus();
+ }
+ } else {
+ std::memset(output_str, 0x00, output_width);
+ }
+ output_str += output_width;
+ }
+
+ output->buffers[1] = std::move(values_buffer);
+ return Status::OK();
+ }
+
+ static Status ExecScalar(KernelContext* ctx, StringTransform* transform,
+ const std::shared_ptr<Scalar>& scalar, Datum* out) {
+ const auto& input = checked_cast<const BaseBinaryScalar&>(*scalar);
+ if (!input.is_valid) {
+ return Status::OK();
+ }
+ const int32_t out_width =
+ checked_cast<const FixedSizeBinaryType&>(*out->type()).byte_width();
+ auto* result = checked_cast<BaseBinaryScalar*>(out->scalar().get());
+
+ const int32_t data_nbytes = static_cast<int32_t>(input.value->size());
+ ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(out_width));
+ auto encoded_nbytes = static_cast<int32_t>(transform->Transform(
+ input.value->data(), data_nbytes, value_buffer->mutable_data()));
+ if (encoded_nbytes != out_width) {
+ return transform->InvalidStatus();
+ }
+
+ result->is_valid = true;
+ result->value = std::move(value_buffer);
+ return Status::OK();
+ }
+};
+
+template <typename StringTransform>
+struct FixedSizeBinaryTransformExecWithState
+ : public FixedSizeBinaryTransformExecBase<StringTransform> {
+ using State = typename StringTransform::State;
+ using FixedSizeBinaryTransformExecBase<StringTransform>::Execute;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ StringTransform transform(State::Get(ctx));
+ RETURN_NOT_OK(transform.PreExec(ctx, batch, out));
+ return Execute(ctx, &transform, batch, out);
+ }
+
+ static Result<ValueDescr> OutputType(KernelContext* ctx,
+ const std::vector<ValueDescr>& descrs) {
+ DCHECK_EQ(1, descrs.size());
+ const auto& options = State::Get(ctx);
+ const int32_t input_width =
+ checked_cast<const FixedSizeBinaryType&>(*descrs[0].type).byte_width();
+ const int32_t output_width = StringTransform::FixedOutputSize(options, input_width);
+ return ValueDescr(fixed_size_binary(output_width), descrs[0].shape);
+ }
+};
+
+#ifdef ARROW_WITH_UTF8PROC
+
+struct FunctionalCaseMappingTransform : public StringTransformBase {
+ Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) override {
+ EnsureLookupTablesFilled();
+ return Status::OK();
+ }
+
+ int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override {
+ // Section 5.18 of the Unicode spec claims that the number of codepoints for case
+ // mapping can grow by a factor of 3. This means grow by a factor of 3 in bytes
+ // However, since we don't support all casings (SpecialCasing.txt) the growth
+ // in bytes is actually only at max 3/2 (as covered by the unittest).
+ // Note that rounding down the 3/2 is ok, since only codepoints encoded by
+ // two code units (even) can grow to 3 code units.
+ return static_cast<int64_t>(input_ncodeunits) * 3 / 2;
+ }
+};
+
+template <typename CodepointTransform>
+struct StringTransformCodepoint : public FunctionalCaseMappingTransform {
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ uint8_t* output_start = output;
+ if (ARROW_PREDICT_FALSE(
+ !arrow::util::UTF8Transform(input, input + input_string_ncodeunits, &output,
+ CodepointTransform::TransformCodepoint))) {
+ return kTransformError;
+ }
+ return output - output_start;
+ }
+};
+
+struct UTF8UpperTransform : public FunctionalCaseMappingTransform {
+ static uint32_t TransformCodepoint(uint32_t codepoint) {
+ return codepoint <= kMaxCodepointLookup ? lut_upper_codepoint[codepoint]
+ : utf8proc_toupper(codepoint);
+ }
+};
+
+template <typename Type>
+using UTF8Upper = StringTransformExec<Type, StringTransformCodepoint<UTF8UpperTransform>>;
+
+struct UTF8LowerTransform : public FunctionalCaseMappingTransform {
+ static uint32_t TransformCodepoint(uint32_t codepoint) {
+ return codepoint <= kMaxCodepointLookup ? lut_lower_codepoint[codepoint]
+ : utf8proc_tolower(codepoint);
+ }
+};
+
+template <typename Type>
+using UTF8Lower = StringTransformExec<Type, StringTransformCodepoint<UTF8LowerTransform>>;
+
+struct UTF8SwapCaseTransform : public FunctionalCaseMappingTransform {
+ static uint32_t TransformCodepoint(uint32_t codepoint) {
+ if (codepoint <= kMaxCodepointLookup) {
+ return lut_swapcase_codepoint[codepoint];
+ } else {
+ if (IsLowerCaseCharacterUnicode(codepoint)) {
+ return utf8proc_toupper(codepoint);
+ } else if (IsUpperCaseCharacterUnicode(codepoint)) {
+ return utf8proc_tolower(codepoint);
+ }
+ }
+
+ return codepoint;
+ }
+};
+
+template <typename Type>
+using UTF8SwapCase =
+ StringTransformExec<Type, StringTransformCodepoint<UTF8SwapCaseTransform>>;
+
+struct Utf8CapitalizeTransform : public FunctionalCaseMappingTransform {
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ uint8_t* output_start = output;
+ const uint8_t* end = input + input_string_ncodeunits;
+ const uint8_t* next = input;
+ if (input_string_ncodeunits > 0) {
+ if (ARROW_PREDICT_FALSE(!util::UTF8AdvanceCodepoints(input, end, &next, 1))) {
+ return kTransformError;
+ }
+ if (ARROW_PREDICT_FALSE(!util::UTF8Transform(
+ input, next, &output, UTF8UpperTransform::TransformCodepoint))) {
+ return kTransformError;
+ }
+ if (ARROW_PREDICT_FALSE(!util::UTF8Transform(
+ next, end, &output, UTF8LowerTransform::TransformCodepoint))) {
+ return kTransformError;
+ }
+ }
+ return output - output_start;
+ }
+};
+
+template <typename Type>
+using Utf8Capitalize = StringTransformExec<Type, Utf8CapitalizeTransform>;
+
+struct Utf8TitleTransform : public FunctionalCaseMappingTransform {
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ uint8_t* output_start = output;
+ const uint8_t* end = input + input_string_ncodeunits;
+ const uint8_t* next = input;
+ bool is_next_upper = true;
+ while ((input = next) < end) {
+ uint32_t codepoint;
+ if (ARROW_PREDICT_FALSE(!util::UTF8Decode(&next, &codepoint))) {
+ return kTransformError;
+ }
+ if (IsCasedCharacterUnicode(codepoint)) {
+ // Lower/uppercase current codepoint and
+ // prepare to lowercase next consecutive cased codepoints
+ output = is_next_upper
+ ? util::UTF8Encode(output,
+ UTF8UpperTransform::TransformCodepoint(codepoint))
+ : util::UTF8Encode(
+ output, UTF8LowerTransform::TransformCodepoint(codepoint));
+ is_next_upper = false;
+ } else {
+ // Copy current uncased codepoint and
+ // prepare to uppercase next cased codepoint
+ std::memcpy(output, input, next - input);
+ output += next - input;
+ is_next_upper = true;
+ }
+ }
+ return output - output_start;
+ }
+};
+
+template <typename Type>
+using Utf8Title = StringTransformExec<Type, Utf8TitleTransform>;
+
+#endif // ARROW_WITH_UTF8PROC
+
+struct AsciiReverseTransform : public StringTransformBase {
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ uint8_t utf8_char_found = 0;
+ for (int64_t i = 0; i < input_string_ncodeunits; i++) {
+ // if a utf8 char is found, report to utf8_char_found
+ utf8_char_found |= input[i] & 0x80;
+ output[input_string_ncodeunits - i - 1] = input[i];
+ }
+ return utf8_char_found ? kTransformError : input_string_ncodeunits;
+ }
+
+ Status InvalidStatus() override {
+ return Status::Invalid("Non-ASCII sequence in input");
+ }
+};
+
+template <typename Type>
+using AsciiReverse = StringTransformExec<Type, AsciiReverseTransform>;
+
+struct Utf8ReverseTransform : public StringTransformBase {
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ int64_t i = 0;
+ while (i < input_string_ncodeunits) {
+ int64_t char_end = std::min(i + util::ValidUtf8CodepointByteSize(input + i),
+ input_string_ncodeunits);
+ std::copy(input + i, input + char_end, output + input_string_ncodeunits - char_end);
+ i = char_end;
+ }
+ return input_string_ncodeunits;
+ }
+};
+
+template <typename Type>
+using Utf8Reverse = StringTransformExec<Type, Utf8ReverseTransform>;
+
+using TransformFunc = std::function<void(const uint8_t*, int64_t, uint8_t*)>;
+
+// Transform a buffer of offsets to one which begins with 0 and has same
+// value lengths.
+template <typename T>
+Status GetShiftedOffsets(KernelContext* ctx, const Buffer& input_buffer, int64_t offset,
+ int64_t length, std::shared_ptr<Buffer>* out) {
+ ARROW_ASSIGN_OR_RAISE(*out, ctx->Allocate((length + 1) * sizeof(T)));
+ const T* input_offsets = reinterpret_cast<const T*>(input_buffer.data()) + offset;
+ T* out_offsets = reinterpret_cast<T*>((*out)->mutable_data());
+ T first_offset = *input_offsets;
+ for (int64_t i = 0; i < length; ++i) {
+ *out_offsets++ = input_offsets[i] - first_offset;
+ }
+ *out_offsets = input_offsets[length] - first_offset;
+ return Status::OK();
+}
+
+// Apply `transform` to input character data- this function cannot change the
+// length
+template <typename Type>
+Status StringDataTransform(KernelContext* ctx, const ExecBatch& batch,
+ TransformFunc transform, Datum* out) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ using offset_type = typename Type::offset_type;
+
+ if (batch[0].kind() == Datum::ARRAY) {
+ const ArrayData& input = *batch[0].array();
+ ArrayType input_boxed(batch[0].array());
+
+ ArrayData* out_arr = out->mutable_array();
+
+ if (input.offset == 0) {
+ // We can reuse offsets from input
+ out_arr->buffers[1] = input.buffers[1];
+ } else {
+ DCHECK(input.buffers[1]);
+ // We must allocate new space for the offsets and shift the existing offsets
+ RETURN_NOT_OK(GetShiftedOffsets<offset_type>(ctx, *input.buffers[1], input.offset,
+ input.length, &out_arr->buffers[1]));
+ }
+
+ // Allocate space for output data
+ int64_t data_nbytes = input_boxed.total_values_length();
+ RETURN_NOT_OK(ctx->Allocate(data_nbytes).Value(&out_arr->buffers[2]));
+ if (input.length > 0) {
+ transform(input.buffers[2]->data() + input_boxed.value_offset(0), data_nbytes,
+ out_arr->buffers[2]->mutable_data());
+ }
+ } else {
+ const auto& input = checked_cast<const BaseBinaryScalar&>(*batch[0].scalar());
+ auto result = checked_pointer_cast<BaseBinaryScalar>(MakeNullScalar(out->type()));
+ if (input.is_valid) {
+ result->is_valid = true;
+ int64_t data_nbytes = input.value->size();
+ RETURN_NOT_OK(ctx->Allocate(data_nbytes).Value(&result->value));
+ transform(input.value->data(), data_nbytes, result->value->mutable_data());
+ }
+ out->value = result;
+ }
+
+ return Status::OK();
+}
+
+void TransformAsciiUpper(const uint8_t* input, int64_t length, uint8_t* output) {
+ std::transform(input, input + length, output, ascii_toupper);
+}
+
+template <typename Type>
+struct AsciiUpper {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return StringDataTransform<Type>(ctx, batch, TransformAsciiUpper, out);
+ }
+};
+
+void TransformAsciiLower(const uint8_t* input, int64_t length, uint8_t* output) {
+ std::transform(input, input + length, output, ascii_tolower);
+}
+
+template <typename Type>
+struct AsciiLower {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return StringDataTransform<Type>(ctx, batch, TransformAsciiLower, out);
+ }
+};
+
+void TransformAsciiSwapCase(const uint8_t* input, int64_t length, uint8_t* output) {
+ std::transform(input, input + length, output, ascii_swapcase);
+}
+
+template <typename Type>
+struct AsciiSwapCase {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return StringDataTransform<Type>(ctx, batch, TransformAsciiSwapCase, out);
+ }
+};
+
+struct AsciiCapitalizeTransform : public StringTransformBase {
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ if (input_string_ncodeunits > 0) {
+ *output++ = ascii_toupper(*input++);
+ TransformAsciiLower(input, input_string_ncodeunits - 1, output);
+ }
+ return input_string_ncodeunits;
+ }
+};
+
+template <typename Type>
+using AsciiCapitalize = StringTransformExec<Type, AsciiCapitalizeTransform>;
+
+struct AsciiTitleTransform : public StringTransformBase {
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ const uint8_t* end = input + input_string_ncodeunits;
+ const uint8_t* next = input;
+ bool is_next_upper = true;
+ while ((input = next++) < end) {
+ if (IsCasedCharacterAscii(*input)) {
+ // Lower/uppercase current character and
+ // prepare to lowercase next consecutive cased characters
+ *output++ = is_next_upper ? ascii_toupper(*input) : ascii_tolower(*input);
+ is_next_upper = false;
+ } else {
+ // Copy current uncased character and
+ // prepare to uppercase next cased character
+ *output++ = *input;
+ is_next_upper = true;
+ }
+ }
+ return input_string_ncodeunits;
+ }
+};
+
+template <typename Type>
+using AsciiTitle = StringTransformExec<Type, AsciiTitleTransform>;
+
+// ----------------------------------------------------------------------
+// exact pattern detection
+
+using StrToBoolTransformFunc =
+ std::function<void(const void*, const uint8_t*, int64_t, int64_t, uint8_t*)>;
+
+// Apply `transform` to input character data- this function cannot change the
+// length
+template <typename Type>
+void StringBoolTransform(KernelContext* ctx, const ExecBatch& batch,
+ StrToBoolTransformFunc transform, Datum* out) {
+ using offset_type = typename Type::offset_type;
+
+ if (batch[0].kind() == Datum::ARRAY) {
+ const ArrayData& input = *batch[0].array();
+ ArrayData* out_arr = out->mutable_array();
+ if (input.length > 0) {
+ transform(
+ reinterpret_cast<const offset_type*>(input.buffers[1]->data()) + input.offset,
+ input.buffers[2]->data(), input.length, out_arr->offset,
+ out_arr->buffers[1]->mutable_data());
+ }
+ } else {
+ const auto& input = checked_cast<const BaseBinaryScalar&>(*batch[0].scalar());
+ if (input.is_valid) {
+ uint8_t result_value = 0;
+ std::array<offset_type, 2> offsets{0,
+ static_cast<offset_type>(input.value->size())};
+ transform(offsets.data(), input.value->data(), 1, /*output_offset=*/0,
+ &result_value);
+ out->value = std::make_shared<BooleanScalar>(result_value > 0);
+ }
+ }
+}
+
+using MatchSubstringState = OptionsWrapper<MatchSubstringOptions>;
+
+// This is an implementation of the Knuth-Morris-Pratt algorithm
+struct PlainSubstringMatcher {
+ const MatchSubstringOptions& options_;
+ std::vector<int64_t> prefix_table;
+
+ static Result<std::unique_ptr<PlainSubstringMatcher>> Make(
+ const MatchSubstringOptions& options) {
+ // Should be handled by partial template specialization below
+ DCHECK(!options.ignore_case);
+ return ::arrow::internal::make_unique<PlainSubstringMatcher>(options);
+ }
+
+ explicit PlainSubstringMatcher(const MatchSubstringOptions& options)
+ : options_(options) {
+ // Phase 1: Build the prefix table
+ const auto pattern_length = options_.pattern.size();
+ prefix_table.resize(pattern_length + 1, /*value=*/0);
+ int64_t prefix_length = -1;
+ prefix_table[0] = -1;
+ for (size_t pos = 0; pos < pattern_length; ++pos) {
+ // The prefix cannot be expanded, reset.
+ while (prefix_length >= 0 &&
+ options_.pattern[pos] != options_.pattern[prefix_length]) {
+ prefix_length = prefix_table[prefix_length];
+ }
+ prefix_length++;
+ prefix_table[pos + 1] = prefix_length;
+ }
+ }
+
+ int64_t Find(util::string_view current) const {
+ // Phase 2: Find the prefix in the data
+ const auto pattern_length = options_.pattern.size();
+ int64_t pattern_pos = 0;
+ int64_t pos = 0;
+ if (pattern_length == 0) return 0;
+ for (const auto c : current) {
+ while ((pattern_pos >= 0) && (options_.pattern[pattern_pos] != c)) {
+ pattern_pos = prefix_table[pattern_pos];
+ }
+ pattern_pos++;
+ if (static_cast<size_t>(pattern_pos) == pattern_length) {
+ return pos + 1 - pattern_length;
+ }
+ pos++;
+ }
+ return -1;
+ }
+
+ bool Match(util::string_view current) const { return Find(current) >= 0; }
+};
+
+struct PlainStartsWithMatcher {
+ const MatchSubstringOptions& options_;
+
+ explicit PlainStartsWithMatcher(const MatchSubstringOptions& options)
+ : options_(options) {}
+
+ static Result<std::unique_ptr<PlainStartsWithMatcher>> Make(
+ const MatchSubstringOptions& options) {
+ // Should be handled by partial template specialization below
+ DCHECK(!options.ignore_case);
+ return ::arrow::internal::make_unique<PlainStartsWithMatcher>(options);
+ }
+
+ bool Match(util::string_view current) const {
+ // string_view::starts_with is C++20
+ return current.substr(0, options_.pattern.size()) == options_.pattern;
+ }
+};
+
+struct PlainEndsWithMatcher {
+ const MatchSubstringOptions& options_;
+
+ explicit PlainEndsWithMatcher(const MatchSubstringOptions& options)
+ : options_(options) {}
+
+ static Result<std::unique_ptr<PlainEndsWithMatcher>> Make(
+ const MatchSubstringOptions& options) {
+ // Should be handled by partial template specialization below
+ DCHECK(!options.ignore_case);
+ return ::arrow::internal::make_unique<PlainEndsWithMatcher>(options);
+ }
+
+ bool Match(util::string_view current) const {
+ // string_view::ends_with is C++20
+ return current.size() >= options_.pattern.size() &&
+ current.substr(current.size() - options_.pattern.size(),
+ options_.pattern.size()) == options_.pattern;
+ }
+};
+
+#ifdef ARROW_WITH_RE2
+struct RegexSubstringMatcher {
+ const MatchSubstringOptions& options_;
+ const RE2 regex_match_;
+
+ static Result<std::unique_ptr<RegexSubstringMatcher>> Make(
+ const MatchSubstringOptions& options, bool literal = false) {
+ auto matcher =
+ ::arrow::internal::make_unique<RegexSubstringMatcher>(options, literal);
+ RETURN_NOT_OK(RegexStatus(matcher->regex_match_));
+ return std::move(matcher);
+ }
+
+ explicit RegexSubstringMatcher(const MatchSubstringOptions& options,
+ bool literal = false)
+ : options_(options),
+ regex_match_(options_.pattern, MakeRE2Options(options, literal)) {}
+
+ bool Match(util::string_view current) const {
+ auto piece = re2::StringPiece(current.data(), current.length());
+ return re2::RE2::PartialMatch(piece, regex_match_);
+ }
+
+ static RE2::RE2::Options MakeRE2Options(const MatchSubstringOptions& options,
+ bool literal) {
+ RE2::RE2::Options re2_options(RE2::Quiet);
+ re2_options.set_case_sensitive(!options.ignore_case);
+ re2_options.set_literal(literal);
+ return re2_options;
+ }
+};
+#endif
+
+template <typename Type, typename Matcher>
+struct MatchSubstringImpl {
+ using offset_type = typename Type::offset_type;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out,
+ const Matcher* matcher) {
+ StringBoolTransform<Type>(
+ ctx, batch,
+ [&matcher](const void* raw_offsets, const uint8_t* data, int64_t length,
+ int64_t output_offset, uint8_t* output) {
+ const offset_type* offsets = reinterpret_cast<const offset_type*>(raw_offsets);
+ FirstTimeBitmapWriter bitmap_writer(output, output_offset, length);
+ for (int64_t i = 0; i < length; ++i) {
+ const char* current_data = reinterpret_cast<const char*>(data + offsets[i]);
+ int64_t current_length = offsets[i + 1] - offsets[i];
+ if (matcher->Match(util::string_view(current_data, current_length))) {
+ bitmap_writer.Set();
+ }
+ bitmap_writer.Next();
+ }
+ bitmap_writer.Finish();
+ },
+ out);
+ return Status::OK();
+ }
+};
+
+template <typename Type, typename Matcher>
+struct MatchSubstring {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // TODO Cache matcher across invocations (for regex compilation)
+ ARROW_ASSIGN_OR_RAISE(auto matcher, Matcher::Make(MatchSubstringState::Get(ctx)));
+ return MatchSubstringImpl<Type, Matcher>::Exec(ctx, batch, out, matcher.get());
+ }
+};
+
+template <typename Type>
+struct MatchSubstring<Type, PlainSubstringMatcher> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ auto options = MatchSubstringState::Get(ctx);
+ if (options.ignore_case) {
+#ifdef ARROW_WITH_RE2
+ ARROW_ASSIGN_OR_RAISE(auto matcher,
+ RegexSubstringMatcher::Make(options, /*literal=*/true));
+ return MatchSubstringImpl<Type, RegexSubstringMatcher>::Exec(ctx, batch, out,
+ matcher.get());
+#else
+ return Status::NotImplemented("ignore_case requires RE2");
+#endif
+ }
+ ARROW_ASSIGN_OR_RAISE(auto matcher, PlainSubstringMatcher::Make(options));
+ return MatchSubstringImpl<Type, PlainSubstringMatcher>::Exec(ctx, batch, out,
+ matcher.get());
+ }
+};
+
+template <typename Type>
+struct MatchSubstring<Type, PlainStartsWithMatcher> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ auto options = MatchSubstringState::Get(ctx);
+ if (options.ignore_case) {
+#ifdef ARROW_WITH_RE2
+ MatchSubstringOptions converted_options = options;
+ converted_options.pattern = "^" + RE2::QuoteMeta(options.pattern);
+ ARROW_ASSIGN_OR_RAISE(auto matcher, RegexSubstringMatcher::Make(converted_options));
+ return MatchSubstringImpl<Type, RegexSubstringMatcher>::Exec(ctx, batch, out,
+ matcher.get());
+#else
+ return Status::NotImplemented("ignore_case requires RE2");
+#endif
+ }
+ ARROW_ASSIGN_OR_RAISE(auto matcher, PlainStartsWithMatcher::Make(options));
+ return MatchSubstringImpl<Type, PlainStartsWithMatcher>::Exec(ctx, batch, out,
+ matcher.get());
+ }
+};
+
+template <typename Type>
+struct MatchSubstring<Type, PlainEndsWithMatcher> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ auto options = MatchSubstringState::Get(ctx);
+ if (options.ignore_case) {
+#ifdef ARROW_WITH_RE2
+ MatchSubstringOptions converted_options = options;
+ converted_options.pattern = RE2::QuoteMeta(options.pattern) + "$";
+ ARROW_ASSIGN_OR_RAISE(auto matcher, RegexSubstringMatcher::Make(converted_options));
+ return MatchSubstringImpl<Type, RegexSubstringMatcher>::Exec(ctx, batch, out,
+ matcher.get());
+#else
+ return Status::NotImplemented("ignore_case requires RE2");
+#endif
+ }
+ ARROW_ASSIGN_OR_RAISE(auto matcher, PlainEndsWithMatcher::Make(options));
+ return MatchSubstringImpl<Type, PlainEndsWithMatcher>::Exec(ctx, batch, out,
+ matcher.get());
+ }
+};
+
+const FunctionDoc match_substring_doc(
+ "Match strings against literal pattern",
+ ("For each string in `strings`, emit true iff it contains a given pattern.\n"
+ "Null inputs emit null. The pattern must be given in MatchSubstringOptions. "
+ "If ignore_case is set, only simple case folding is performed."),
+ {"strings"}, "MatchSubstringOptions");
+
+const FunctionDoc starts_with_doc(
+ "Check if strings start with a literal pattern",
+ ("For each string in `strings`, emit true iff it starts with a given pattern.\n"
+ "Null inputs emit null. The pattern must be given in MatchSubstringOptions. "
+ "If ignore_case is set, only simple case folding is performed."),
+ {"strings"}, "MatchSubstringOptions");
+
+const FunctionDoc ends_with_doc(
+ "Check if strings end with a literal pattern",
+ ("For each string in `strings`, emit true iff it ends with a given pattern.\n"
+ "Null inputs emit null. The pattern must be given in MatchSubstringOptions. "
+ "If ignore_case is set, only simple case folding is performed."),
+ {"strings"}, "MatchSubstringOptions");
+
+#ifdef ARROW_WITH_RE2
+const FunctionDoc match_substring_regex_doc(
+ "Match strings against regex pattern",
+ ("For each string in `strings`, emit true iff it matches a given pattern at any "
+ "position.\n"
+ "Null inputs emit null. The pattern must be given in MatchSubstringOptions. "
+ "If ignore_case is set, only simple case folding is performed."),
+ {"strings"}, "MatchSubstringOptions");
+
+// SQL LIKE match
+
+/// Convert a SQL-style LIKE pattern (using '%' and '_') into a regex pattern
+std::string MakeLikeRegex(const MatchSubstringOptions& options) {
+ // Allow . to match \n
+ std::string like_pattern = "(?s:^";
+ like_pattern.reserve(options.pattern.size() + 7);
+ bool escaped = false;
+ for (const char c : options.pattern) {
+ if (!escaped && c == '%') {
+ like_pattern.append(".*");
+ } else if (!escaped && c == '_') {
+ like_pattern.append(".");
+ } else if (!escaped && c == '\\') {
+ escaped = true;
+ } else {
+ switch (c) {
+ case '.':
+ case '?':
+ case '+':
+ case '*':
+ case '^':
+ case '$':
+ case '\\':
+ case '[':
+ case '{':
+ case '(':
+ case ')':
+ case '|': {
+ like_pattern.push_back('\\');
+ like_pattern.push_back(c);
+ escaped = false;
+ break;
+ }
+ default: {
+ like_pattern.push_back(c);
+ escaped = false;
+ break;
+ }
+ }
+ }
+ }
+ like_pattern.append("$)");
+ return like_pattern;
+}
+
+// Evaluate a SQL-like LIKE pattern by translating it to a regexp or
+// substring search as appropriate. See what Apache Impala does:
+// https://github.com/apache/impala/blob/9c38568657d62b6f6d7b10aa1c721ba843374dd8/be/src/exprs/like-predicate.cc
+template <typename StringType>
+struct MatchLike {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // NOTE: avoid making those constants global to avoid compiling regexes at startup
+ // A LIKE pattern matching this regex can be translated into a substring search.
+ static const RE2 kLikePatternIsSubstringMatch(R"(%+([^%_]*[^\\%_])?%+)");
+ // A LIKE pattern matching this regex can be translated into a prefix search.
+ static const RE2 kLikePatternIsStartsWith(R"(([^%_]*[^\\%_])?%+)");
+ // A LIKE pattern matching this regex can be translated into a suffix search.
+ static const RE2 kLikePatternIsEndsWith(R"(%+([^%_]*))");
+
+ auto original_options = MatchSubstringState::Get(ctx);
+ auto original_state = ctx->state();
+
+ Status status;
+ std::string pattern;
+ if (!original_options.ignore_case &&
+ re2::RE2::FullMatch(original_options.pattern, kLikePatternIsSubstringMatch,
+ &pattern)) {
+ MatchSubstringOptions converted_options{pattern, original_options.ignore_case};
+ MatchSubstringState converted_state(converted_options);
+ ctx->SetState(&converted_state);
+ status = MatchSubstring<StringType, PlainSubstringMatcher>::Exec(ctx, batch, out);
+ } else if (!original_options.ignore_case &&
+ re2::RE2::FullMatch(original_options.pattern, kLikePatternIsStartsWith,
+ &pattern)) {
+ MatchSubstringOptions converted_options{pattern, original_options.ignore_case};
+ MatchSubstringState converted_state(converted_options);
+ ctx->SetState(&converted_state);
+ status = MatchSubstring<StringType, PlainStartsWithMatcher>::Exec(ctx, batch, out);
+ } else if (!original_options.ignore_case &&
+ re2::RE2::FullMatch(original_options.pattern, kLikePatternIsEndsWith,
+ &pattern)) {
+ MatchSubstringOptions converted_options{pattern, original_options.ignore_case};
+ MatchSubstringState converted_state(converted_options);
+ ctx->SetState(&converted_state);
+ status = MatchSubstring<StringType, PlainEndsWithMatcher>::Exec(ctx, batch, out);
+ } else {
+ MatchSubstringOptions converted_options{MakeLikeRegex(original_options),
+ original_options.ignore_case};
+ MatchSubstringState converted_state(converted_options);
+ ctx->SetState(&converted_state);
+ status = MatchSubstring<StringType, RegexSubstringMatcher>::Exec(ctx, batch, out);
+ }
+ ctx->SetState(original_state);
+ return status;
+ }
+};
+
+const FunctionDoc match_like_doc(
+ "Match strings against SQL-style LIKE pattern",
+ ("For each string in `strings`, emit true iff it fully matches a given pattern "
+ "at any position. That is, '%' will match any number of characters, '_' will "
+ "match exactly one character, and any other character matches itself. To "
+ "match a literal '%', '_', or '\\', precede the character with a backslash.\n"
+ "Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
+ {"strings"}, "MatchSubstringOptions");
+
+#endif
+
+void AddMatchSubstring(FunctionRegistry* registry) {
+ {
+ auto func = std::make_shared<ScalarFunction>("match_substring", Arity::Unary(),
+ &match_substring_doc);
+ auto exec_32 = MatchSubstring<StringType, PlainSubstringMatcher>::Exec;
+ auto exec_64 = MatchSubstring<LargeStringType, PlainSubstringMatcher>::Exec;
+ DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, MatchSubstringState::Init));
+ DCHECK_OK(
+ func->AddKernel({large_utf8()}, boolean(), exec_64, MatchSubstringState::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+ {
+ auto func = std::make_shared<ScalarFunction>("starts_with", Arity::Unary(),
+ &match_substring_doc);
+ auto exec_32 = MatchSubstring<StringType, PlainStartsWithMatcher>::Exec;
+ auto exec_64 = MatchSubstring<LargeStringType, PlainStartsWithMatcher>::Exec;
+ DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, MatchSubstringState::Init));
+ DCHECK_OK(
+ func->AddKernel({large_utf8()}, boolean(), exec_64, MatchSubstringState::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+ {
+ auto func = std::make_shared<ScalarFunction>("ends_with", Arity::Unary(),
+ &match_substring_doc);
+ auto exec_32 = MatchSubstring<StringType, PlainEndsWithMatcher>::Exec;
+ auto exec_64 = MatchSubstring<LargeStringType, PlainEndsWithMatcher>::Exec;
+ DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, MatchSubstringState::Init));
+ DCHECK_OK(
+ func->AddKernel({large_utf8()}, boolean(), exec_64, MatchSubstringState::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+#ifdef ARROW_WITH_RE2
+ {
+ auto func = std::make_shared<ScalarFunction>("match_substring_regex", Arity::Unary(),
+ &match_substring_regex_doc);
+ auto exec_32 = MatchSubstring<StringType, RegexSubstringMatcher>::Exec;
+ auto exec_64 = MatchSubstring<LargeStringType, RegexSubstringMatcher>::Exec;
+ DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, MatchSubstringState::Init));
+ DCHECK_OK(
+ func->AddKernel({large_utf8()}, boolean(), exec_64, MatchSubstringState::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+ {
+ auto func =
+ std::make_shared<ScalarFunction>("match_like", Arity::Unary(), &match_like_doc);
+ auto exec_32 = MatchLike<StringType>::Exec;
+ auto exec_64 = MatchLike<LargeStringType>::Exec;
+ DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, MatchSubstringState::Init));
+ DCHECK_OK(
+ func->AddKernel({large_utf8()}, boolean(), exec_64, MatchSubstringState::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+#endif
+}
+
+// Substring find - lfind/index/etc.
+
+struct FindSubstring {
+ const PlainSubstringMatcher matcher_;
+
+ explicit FindSubstring(PlainSubstringMatcher matcher) : matcher_(std::move(matcher)) {}
+
+ template <typename OutValue, typename... Ignored>
+ OutValue Call(KernelContext*, util::string_view val, Status*) const {
+ return static_cast<OutValue>(matcher_.Find(val));
+ }
+};
+
+#ifdef ARROW_WITH_RE2
+struct FindSubstringRegex {
+ std::unique_ptr<RE2> regex_match_;
+
+ explicit FindSubstringRegex(const MatchSubstringOptions& options,
+ bool literal = false) {
+ std::string regex = "(";
+ regex.reserve(options.pattern.length() + 2);
+ regex += literal ? RE2::QuoteMeta(options.pattern) : options.pattern;
+ regex += ")";
+ regex_match_.reset(new RE2(std::move(regex), RegexSubstringMatcher::MakeRE2Options(
+ options, /*literal=*/false)));
+ }
+
+ template <typename OutValue, typename... Ignored>
+ OutValue Call(KernelContext*, util::string_view val, Status*) const {
+ re2::StringPiece piece(val.data(), val.length());
+ re2::StringPiece match;
+ if (re2::RE2::PartialMatch(piece, *regex_match_, &match)) {
+ return static_cast<OutValue>(match.data() - piece.data());
+ }
+ return -1;
+ }
+};
+#endif
+
+template <typename InputType>
+struct FindSubstringExec {
+ using OffsetType = typename TypeTraits<InputType>::OffsetType;
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const MatchSubstringOptions& options = MatchSubstringState::Get(ctx);
+ if (options.ignore_case) {
+#ifdef ARROW_WITH_RE2
+ applicator::ScalarUnaryNotNullStateful<OffsetType, InputType, FindSubstringRegex>
+ kernel{FindSubstringRegex(options, /*literal=*/true)};
+ return kernel.Exec(ctx, batch, out);
+#endif
+ return Status::NotImplemented("ignore_case requires RE2");
+ }
+ applicator::ScalarUnaryNotNullStateful<OffsetType, InputType, FindSubstring> kernel{
+ FindSubstring(PlainSubstringMatcher(options))};
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+const FunctionDoc find_substring_doc(
+ "Find first occurrence of substring",
+ ("For each string in `strings`, emit the index of the first occurrence of the given "
+ "pattern, or -1 if not found.\n"
+ "Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
+ {"strings"}, "MatchSubstringOptions");
+
+#ifdef ARROW_WITH_RE2
+template <typename InputType>
+struct FindSubstringRegexExec {
+ using OffsetType = typename TypeTraits<InputType>::OffsetType;
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const MatchSubstringOptions& options = MatchSubstringState::Get(ctx);
+ applicator::ScalarUnaryNotNullStateful<OffsetType, InputType, FindSubstringRegex>
+ kernel{FindSubstringRegex(options, /*literal=*/false)};
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+const FunctionDoc find_substring_regex_doc(
+ "Find location of first match of regex pattern",
+ ("For each string in `strings`, emit the index of the first match of the given "
+ "pattern, or -1 if not found.\n"
+ "Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
+ {"strings"}, "MatchSubstringOptions");
+#endif
+
+void AddFindSubstring(FunctionRegistry* registry) {
+ {
+ auto func = std::make_shared<ScalarFunction>("find_substring", Arity::Unary(),
+ &find_substring_doc);
+ for (const auto& ty : BaseBinaryTypes()) {
+ auto offset_type = offset_bit_width(ty->id()) == 64 ? int64() : int32();
+ DCHECK_OK(func->AddKernel({ty}, offset_type,
+ GenerateTypeAgnosticVarBinaryBase<FindSubstringExec>(ty),
+ MatchSubstringState::Init));
+ }
+ DCHECK_OK(func->AddKernel({InputType(Type::FIXED_SIZE_BINARY)}, int32(),
+ FindSubstringExec<FixedSizeBinaryType>::Exec,
+ MatchSubstringState::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+#ifdef ARROW_WITH_RE2
+ {
+ auto func = std::make_shared<ScalarFunction>("find_substring_regex", Arity::Unary(),
+ &find_substring_regex_doc);
+ for (const auto& ty : BaseBinaryTypes()) {
+ auto offset_type = offset_bit_width(ty->id()) == 64 ? int64() : int32();
+ DCHECK_OK(
+ func->AddKernel({ty}, offset_type,
+ GenerateTypeAgnosticVarBinaryBase<FindSubstringRegexExec>(ty),
+ MatchSubstringState::Init));
+ }
+ DCHECK_OK(func->AddKernel({InputType(Type::FIXED_SIZE_BINARY)}, int32(),
+ FindSubstringRegexExec<FixedSizeBinaryType>::Exec,
+ MatchSubstringState::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+#endif
+}
+
+// Substring count
+
+struct CountSubstring {
+ const PlainSubstringMatcher matcher_;
+
+ explicit CountSubstring(PlainSubstringMatcher matcher) : matcher_(std::move(matcher)) {}
+
+ template <typename OutValue, typename... Ignored>
+ OutValue Call(KernelContext*, util::string_view val, Status*) const {
+ OutValue count = 0;
+ uint64_t start = 0;
+ const auto pattern_size = std::max<uint64_t>(1, matcher_.options_.pattern.size());
+ while (start <= val.size()) {
+ const int64_t index = matcher_.Find(val.substr(start));
+ if (index >= 0) {
+ count++;
+ start += index + pattern_size;
+ } else {
+ break;
+ }
+ }
+ return count;
+ }
+};
+
+#ifdef ARROW_WITH_RE2
+struct CountSubstringRegex {
+ std::unique_ptr<RE2> regex_match_;
+
+ explicit CountSubstringRegex(const MatchSubstringOptions& options, bool literal = false)
+ : regex_match_(new RE2(options.pattern,
+ RegexSubstringMatcher::MakeRE2Options(options, literal))) {}
+
+ static Result<CountSubstringRegex> Make(const MatchSubstringOptions& options,
+ bool literal = false) {
+ CountSubstringRegex counter(options, literal);
+ RETURN_NOT_OK(RegexStatus(*counter.regex_match_));
+ return std::move(counter);
+ }
+
+ template <typename OutValue, typename... Ignored>
+ OutValue Call(KernelContext*, util::string_view val, Status*) const {
+ OutValue count = 0;
+ re2::StringPiece input(val.data(), val.size());
+ auto last_size = input.size();
+ while (re2::RE2::FindAndConsume(&input, *regex_match_)) {
+ count++;
+ if (last_size == input.size()) {
+ // 0-length match
+ if (input.size() > 0) {
+ input.remove_prefix(1);
+ } else {
+ break;
+ }
+ }
+ last_size = input.size();
+ }
+ return count;
+ }
+};
+
+template <typename InputType>
+struct CountSubstringRegexExec {
+ using OffsetType = typename TypeTraits<InputType>::OffsetType;
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const MatchSubstringOptions& options = MatchSubstringState::Get(ctx);
+ ARROW_ASSIGN_OR_RAISE(auto counter, CountSubstringRegex::Make(options));
+ applicator::ScalarUnaryNotNullStateful<OffsetType, InputType, CountSubstringRegex>
+ kernel{std::move(counter)};
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+#endif
+
+template <typename InputType>
+struct CountSubstringExec {
+ using OffsetType = typename TypeTraits<InputType>::OffsetType;
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const MatchSubstringOptions& options = MatchSubstringState::Get(ctx);
+ if (options.ignore_case) {
+#ifdef ARROW_WITH_RE2
+ ARROW_ASSIGN_OR_RAISE(auto counter,
+ CountSubstringRegex::Make(options, /*literal=*/true));
+ applicator::ScalarUnaryNotNullStateful<OffsetType, InputType, CountSubstringRegex>
+ kernel{std::move(counter)};
+ return kernel.Exec(ctx, batch, out);
+#else
+ return Status::NotImplemented("ignore_case requires RE2");
+#endif
+ }
+ applicator::ScalarUnaryNotNullStateful<OffsetType, InputType, CountSubstring> kernel{
+ CountSubstring(PlainSubstringMatcher(options))};
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+const FunctionDoc count_substring_doc(
+ "Count occurrences of substring",
+ ("For each string in `strings`, emit the number of occurrences of the given "
+ "pattern.\n"
+ "Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
+ {"strings"}, "MatchSubstringOptions");
+
+#ifdef ARROW_WITH_RE2
+const FunctionDoc count_substring_regex_doc(
+ "Count occurrences of substring",
+ ("For each string in `strings`, emit the number of occurrences of the given "
+ "regex pattern.\n"
+ "Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
+ {"strings"}, "MatchSubstringOptions");
+#endif
+
+void AddCountSubstring(FunctionRegistry* registry) {
+ {
+ auto func = std::make_shared<ScalarFunction>("count_substring", Arity::Unary(),
+ &count_substring_doc);
+ for (const auto& ty : BaseBinaryTypes()) {
+ auto offset_type = offset_bit_width(ty->id()) == 64 ? int64() : int32();
+ DCHECK_OK(func->AddKernel({ty}, offset_type,
+ GenerateTypeAgnosticVarBinaryBase<CountSubstringExec>(ty),
+ MatchSubstringState::Init));
+ }
+ DCHECK_OK(func->AddKernel({InputType(Type::FIXED_SIZE_BINARY)}, int32(),
+ CountSubstringExec<FixedSizeBinaryType>::Exec,
+ MatchSubstringState::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+#ifdef ARROW_WITH_RE2
+ {
+ auto func = std::make_shared<ScalarFunction>("count_substring_regex", Arity::Unary(),
+ &count_substring_regex_doc);
+ for (const auto& ty : BaseBinaryTypes()) {
+ auto offset_type = offset_bit_width(ty->id()) == 64 ? int64() : int32();
+ DCHECK_OK(
+ func->AddKernel({ty}, offset_type,
+ GenerateTypeAgnosticVarBinaryBase<CountSubstringRegexExec>(ty),
+ MatchSubstringState::Init));
+ }
+ DCHECK_OK(func->AddKernel({InputType(Type::FIXED_SIZE_BINARY)}, int32(),
+ CountSubstringRegexExec<FixedSizeBinaryType>::Exec,
+ MatchSubstringState::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+#endif
+}
+
+// Slicing
+
+struct SliceTransformBase : public StringTransformBase {
+ using State = OptionsWrapper<SliceOptions>;
+
+ const SliceOptions* options;
+
+ Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) override {
+ options = &State::Get(ctx);
+ if (options->step == 0) {
+ return Status::Invalid("Slice step cannot be zero");
+ }
+ return Status::OK();
+ }
+};
+
+struct SliceCodeunitsTransform : SliceTransformBase {
+ int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override {
+ const SliceOptions& opt = *this->options;
+ if ((opt.start >= 0) != (opt.stop >= 0)) {
+ // If start and stop don't have the same sign, we can't guess an upper bound
+ // on the resulting slice lengths, so return a worst case estimate.
+ return input_ncodeunits;
+ }
+ int64_t max_slice_codepoints = (opt.stop - opt.start + opt.step - 1) / opt.step;
+ // The maximum UTF8 byte size of a codepoint is 4
+ return std::min(input_ncodeunits,
+ 4 * ninputs * std::max<int64_t>(0, max_slice_codepoints));
+ }
+
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ if (options->step >= 1) {
+ return SliceForward(input, input_string_ncodeunits, output);
+ }
+ return SliceBackward(input, input_string_ncodeunits, output);
+ }
+
+#define RETURN_IF_UTF8_ERROR(expr) \
+ do { \
+ if (ARROW_PREDICT_FALSE(!expr)) { \
+ return kTransformError; \
+ } \
+ } while (0)
+
+ int64_t SliceForward(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ // Slice in forward order (step > 0)
+ const SliceOptions& opt = *this->options;
+ const uint8_t* begin = input;
+ const uint8_t* end = input + input_string_ncodeunits;
+ const uint8_t* begin_sliced = begin;
+ const uint8_t* end_sliced = end;
+
+ // First, compute begin_sliced and end_sliced
+ if (opt.start >= 0) {
+ // start counting from the left
+ RETURN_IF_UTF8_ERROR(
+ arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced, opt.start));
+ if (opt.stop > opt.start) {
+ // continue counting from begin_sliced
+ const int64_t length = opt.stop - opt.start;
+ RETURN_IF_UTF8_ERROR(
+ arrow::util::UTF8AdvanceCodepoints(begin_sliced, end, &end_sliced, length));
+ } else if (opt.stop < 0) {
+ // or from the end (but we will never need to < begin_sliced)
+ RETURN_IF_UTF8_ERROR(arrow::util::UTF8AdvanceCodepointsReverse(
+ begin_sliced, end, &end_sliced, -opt.stop));
+ } else {
+ // zero length slice
+ return 0;
+ }
+ } else {
+ // start counting from the right
+ RETURN_IF_UTF8_ERROR(arrow::util::UTF8AdvanceCodepointsReverse(
+ begin, end, &begin_sliced, -opt.start));
+ if (opt.stop > 0) {
+ // continue counting from the left, we cannot start from begin_sliced because we
+ // don't know how many codepoints are between begin and begin_sliced
+ RETURN_IF_UTF8_ERROR(
+ arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, opt.stop));
+ // and therefore we also needs this
+ if (end_sliced <= begin_sliced) {
+ // zero length slice
+ return 0;
+ }
+ } else if ((opt.stop < 0) && (opt.stop > opt.start)) {
+ // stop is negative, but larger than start, so we count again from the right
+ // in some cases we can optimize this, depending on the shortest path (from end
+ // or begin_sliced), but begin_sliced and opt.start can be 'out of sync',
+ // for instance when start=-100, when the string length is only 10.
+ RETURN_IF_UTF8_ERROR(arrow::util::UTF8AdvanceCodepointsReverse(
+ begin_sliced, end, &end_sliced, -opt.stop));
+ } else {
+ // zero length slice
+ return 0;
+ }
+ }
+
+ // Second, copy computed slice to output
+ DCHECK(begin_sliced <= end_sliced);
+ if (opt.step == 1) {
+ // fast case, where we simply can finish with a memcpy
+ std::copy(begin_sliced, end_sliced, output);
+ return end_sliced - begin_sliced;
+ }
+ uint8_t* dest = output;
+ const uint8_t* i = begin_sliced;
+
+ while (i < end_sliced) {
+ uint32_t codepoint = 0;
+ // write a single codepoint
+ RETURN_IF_UTF8_ERROR(arrow::util::UTF8Decode(&i, &codepoint));
+ dest = arrow::util::UTF8Encode(dest, codepoint);
+ // and skip the remainder
+ int64_t skips = opt.step - 1;
+ while ((skips--) && (i < end_sliced)) {
+ RETURN_IF_UTF8_ERROR(arrow::util::UTF8Decode(&i, &codepoint));
+ }
+ }
+ return dest - output;
+ }
+
+ int64_t SliceBackward(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ // Slice in reverse order (step < 0)
+ const SliceOptions& opt = *this->options;
+ const uint8_t* begin = input;
+ const uint8_t* end = input + input_string_ncodeunits;
+ const uint8_t* begin_sliced = begin;
+ const uint8_t* end_sliced = end;
+
+ // Serious +1 -1 kung fu because begin_sliced and end_sliced act like
+ // reverse iterators.
+ if (opt.start >= 0) {
+ // +1 because begin_sliced acts as as the end of a reverse iterator
+ RETURN_IF_UTF8_ERROR(
+ arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced, opt.start + 1));
+ } else {
+ // -1 because start=-1 means the last codeunit, which is 0 advances
+ RETURN_IF_UTF8_ERROR(arrow::util::UTF8AdvanceCodepointsReverse(
+ begin, end, &begin_sliced, -opt.start - 1));
+ }
+ // make it point at the last codeunit of the previous codeunit
+ begin_sliced--;
+
+ // similar to opt.start
+ if (opt.stop >= 0) {
+ RETURN_IF_UTF8_ERROR(
+ arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, opt.stop + 1));
+ } else {
+ RETURN_IF_UTF8_ERROR(arrow::util::UTF8AdvanceCodepointsReverse(
+ begin, end, &end_sliced, -opt.stop - 1));
+ }
+ end_sliced--;
+
+ // Copy computed slice to output
+ uint8_t* dest = output;
+ const uint8_t* i = begin_sliced;
+ while (i > end_sliced) {
+ uint32_t codepoint = 0;
+ // write a single codepoint
+ RETURN_IF_UTF8_ERROR(arrow::util::UTF8DecodeReverse(&i, &codepoint));
+ dest = arrow::util::UTF8Encode(dest, codepoint);
+ // and skip the remainder
+ int64_t skips = -opt.step - 1;
+ while ((skips--) && (i > end_sliced)) {
+ RETURN_IF_UTF8_ERROR(arrow::util::UTF8DecodeReverse(&i, &codepoint));
+ }
+ }
+ return dest - output;
+ }
+
+#undef RETURN_IF_UTF8_ERROR
+};
+
+template <typename Type>
+using SliceCodeunits = StringTransformExec<Type, SliceCodeunitsTransform>;
+
+const FunctionDoc utf8_slice_codeunits_doc(
+ "Slice string ",
+ ("For each string in `strings`, slice into a substring defined by\n"
+ "`start`, `stop`, `step`) as given by `SliceOptions` where `start` is inclusive\n"
+ "and `stop` is exclusive and are measured in codeunits. If step is negative, the\n"
+ "string will be advanced in reversed order. A `step` of zero is considered an\n"
+ "error.\n"
+ "Null inputs emit null."),
+ {"strings"}, "SliceOptions");
+
+void AddSlice(FunctionRegistry* registry) {
+ auto func = std::make_shared<ScalarFunction>("utf8_slice_codeunits", Arity::Unary(),
+ &utf8_slice_codeunits_doc);
+ using t32 = SliceCodeunits<StringType>;
+ using t64 = SliceCodeunits<LargeStringType>;
+ DCHECK_OK(
+ func->AddKernel({utf8()}, utf8(), t32::Exec, SliceCodeunitsTransform::State::Init));
+ DCHECK_OK(func->AddKernel({large_utf8()}, large_utf8(), t64::Exec,
+ SliceCodeunitsTransform::State::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+template <typename Derived, bool allow_empty = false>
+struct CharacterPredicateUnicode {
+ static bool Call(KernelContext*, const uint8_t* input, size_t input_string_ncodeunits,
+ Status* st) {
+ if (allow_empty && input_string_ncodeunits == 0) {
+ return true;
+ }
+ bool all;
+ bool any = false;
+ if (!ARROW_PREDICT_TRUE(arrow::util::UTF8AllOf(
+ input, input + input_string_ncodeunits, &all, [&any](uint32_t codepoint) {
+ any |= Derived::PredicateCharacterAny(codepoint);
+ return Derived::PredicateCharacterAll(codepoint);
+ }))) {
+ *st = Status::Invalid("Invalid UTF8 sequence in input");
+ return false;
+ }
+ return all & any;
+ }
+
+ static inline bool PredicateCharacterAny(uint32_t) {
+ return true; // default condition make sure there is at least 1 charachter
+ }
+};
+
+template <typename Derived, bool allow_empty = false>
+struct CharacterPredicateAscii {
+ static bool Call(KernelContext*, const uint8_t* input, size_t input_string_ncodeunits,
+ Status*) {
+ if (allow_empty && input_string_ncodeunits == 0) {
+ return true;
+ }
+ bool any = false;
+ // MB: A simple for loops seems 8% faster on gcc 9.3, running the IsAlphaNumericAscii
+ // benchmark. I don't consider that worth it.
+ bool all = std::all_of(input, input + input_string_ncodeunits,
+ [&any](uint8_t ascii_character) {
+ any |= Derived::PredicateCharacterAny(ascii_character);
+ return Derived::PredicateCharacterAll(ascii_character);
+ });
+ return all & any;
+ }
+
+ static inline bool PredicateCharacterAny(uint8_t) {
+ return true; // default condition make sure there is at least 1 charachter
+ }
+};
+
+#ifdef ARROW_WITH_UTF8PROC
+struct IsAlphaNumericUnicode : CharacterPredicateUnicode<IsAlphaNumericUnicode> {
+ static inline bool PredicateCharacterAll(uint32_t codepoint) {
+ return IsAlphaNumericCharacterUnicode(codepoint);
+ }
+};
+#endif
+
+struct IsAlphaNumericAscii : CharacterPredicateAscii<IsAlphaNumericAscii> {
+ static inline bool PredicateCharacterAll(uint8_t ascii_character) {
+ return IsAlphaNumericCharacterAscii(ascii_character);
+ }
+};
+
+#ifdef ARROW_WITH_UTF8PROC
+struct IsAlphaUnicode : CharacterPredicateUnicode<IsAlphaUnicode> {
+ static inline bool PredicateCharacterAll(uint32_t codepoint) {
+ return IsAlphaCharacterUnicode(codepoint);
+ }
+};
+#endif
+
+struct IsAlphaAscii : CharacterPredicateAscii<IsAlphaAscii> {
+ static inline bool PredicateCharacterAll(uint8_t ascii_character) {
+ return IsAlphaCharacterAscii(ascii_character);
+ }
+};
+
+#ifdef ARROW_WITH_UTF8PROC
+struct IsDecimalUnicode : CharacterPredicateUnicode<IsDecimalUnicode> {
+ static inline bool PredicateCharacterAll(uint32_t codepoint) {
+ return IsDecimalCharacterUnicode(codepoint);
+ }
+};
+#endif
+
+struct IsDecimalAscii : CharacterPredicateAscii<IsDecimalAscii> {
+ static inline bool PredicateCharacterAll(uint8_t ascii_character) {
+ return IsDecimalCharacterAscii(ascii_character);
+ }
+};
+
+#ifdef ARROW_WITH_UTF8PROC
+struct IsDigitUnicode : CharacterPredicateUnicode<IsDigitUnicode> {
+ static inline bool PredicateCharacterAll(uint32_t codepoint) {
+ return IsDigitCharacterUnicode(codepoint);
+ }
+};
+
+struct IsNumericUnicode : CharacterPredicateUnicode<IsNumericUnicode> {
+ static inline bool PredicateCharacterAll(uint32_t codepoint) {
+ return IsNumericCharacterUnicode(codepoint);
+ }
+};
+#endif
+
+struct IsAscii {
+ static bool Call(KernelContext*, const uint8_t* input,
+ size_t input_string_nascii_characters, Status*) {
+ return std::all_of(input, input + input_string_nascii_characters, IsAsciiCharacter);
+ }
+};
+
+#ifdef ARROW_WITH_UTF8PROC
+struct IsLowerUnicode : CharacterPredicateUnicode<IsLowerUnicode> {
+ static inline bool PredicateCharacterAll(uint32_t codepoint) {
+ // Only for cased character it needs to be lower case
+ return !IsCasedCharacterUnicode(codepoint) || IsLowerCaseCharacterUnicode(codepoint);
+ }
+ static inline bool PredicateCharacterAny(uint32_t codepoint) {
+ return IsCasedCharacterUnicode(codepoint); // at least 1 cased character
+ }
+};
+#endif
+
+struct IsLowerAscii : CharacterPredicateAscii<IsLowerAscii> {
+ static inline bool PredicateCharacterAll(uint8_t ascii_character) {
+ // Only for cased character it needs to be lower case
+ return !IsCasedCharacterAscii(ascii_character) ||
+ IsLowerCaseCharacterAscii(ascii_character);
+ }
+ static inline bool PredicateCharacterAny(uint8_t ascii_character) {
+ return IsCasedCharacterAscii(ascii_character); // at least 1 cased character
+ }
+};
+
+#ifdef ARROW_WITH_UTF8PROC
+struct IsPrintableUnicode
+ : CharacterPredicateUnicode<IsPrintableUnicode, /*allow_empty=*/true> {
+ static inline bool PredicateCharacterAll(uint32_t codepoint) {
+ return codepoint == ' ' || IsPrintableCharacterUnicode(codepoint);
+ }
+};
+#endif
+
+struct IsPrintableAscii
+ : CharacterPredicateAscii<IsPrintableAscii, /*allow_empty=*/true> {
+ static inline bool PredicateCharacterAll(uint8_t ascii_character) {
+ return IsPrintableCharacterAscii(ascii_character);
+ }
+};
+
+#ifdef ARROW_WITH_UTF8PROC
+struct IsSpaceUnicode : CharacterPredicateUnicode<IsSpaceUnicode> {
+ static inline bool PredicateCharacterAll(uint32_t codepoint) {
+ return IsSpaceCharacterUnicode(codepoint);
+ }
+};
+#endif
+
+struct IsSpaceAscii : CharacterPredicateAscii<IsSpaceAscii> {
+ static inline bool PredicateCharacterAll(uint8_t ascii_character) {
+ return IsSpaceCharacterAscii(ascii_character);
+ }
+};
+
+#ifdef ARROW_WITH_UTF8PROC
+struct IsTitleUnicode {
+ static bool Call(KernelContext*, const uint8_t* input, size_t input_string_ncodeunits,
+ Status* st) {
+ // rules:
+ // 1. lower case follows cased
+ // 2. upper case follows uncased
+ // 3. at least 1 cased character (which logically should be upper/title)
+ bool rules_1_and_2;
+ bool previous_cased = false; // in LL, LU or LT
+ bool rule_3 = false;
+ bool status =
+ arrow::util::UTF8AllOf(input, input + input_string_ncodeunits, &rules_1_and_2,
+ [&previous_cased, &rule_3](uint32_t codepoint) {
+ if (IsLowerCaseCharacterUnicode(codepoint)) {
+ if (!previous_cased) return false; // rule 1 broken
+ // next should be more lower case or uncased
+ previous_cased = true;
+ } else if (IsCasedCharacterUnicode(codepoint)) {
+ if (previous_cased) return false; // rule 2 broken
+ // next should be a lower case or uncased
+ previous_cased = true;
+ rule_3 = true; // rule 3 obeyed
+ } else {
+ // an uncased char, like _ or 1
+ // next should be upper case or more uncased
+ previous_cased = false;
+ }
+ return true;
+ });
+ if (!ARROW_PREDICT_TRUE(status)) {
+ *st = Status::Invalid("Invalid UTF8 sequence in input");
+ return false;
+ }
+ return rules_1_and_2 & rule_3;
+ }
+};
+#endif
+
+struct IsTitleAscii {
+ static bool Call(KernelContext*, const uint8_t* input, size_t input_string_ncodeunits,
+ Status*) {
+ // Rules:
+ // 1. lower case follows cased
+ // 2. upper case follows uncased
+ // 3. at least 1 cased character (which logically should be upper/title)
+ bool rules_1_and_2 = true;
+ bool previous_cased = false; // in LL, LU or LT
+ bool rule_3 = false;
+ for (const uint8_t* c = input; c < input + input_string_ncodeunits; ++c) {
+ if (IsLowerCaseCharacterAscii(*c)) {
+ if (!previous_cased) {
+ // rule 1 broken
+ rules_1_and_2 = false;
+ break;
+ }
+ // next should be more lower case or uncased
+ previous_cased = true;
+ } else if (IsCasedCharacterAscii(*c)) {
+ if (previous_cased) {
+ // rule 2 broken
+ rules_1_and_2 = false;
+ break;
+ }
+ // next should be a lower case or uncased
+ previous_cased = true;
+ rule_3 = true; // rule 3 obeyed
+ } else {
+ // an uncased character, like _ or 1
+ // next should be upper case or more uncased
+ previous_cased = false;
+ }
+ }
+ return rules_1_and_2 & rule_3;
+ }
+};
+
+#ifdef ARROW_WITH_UTF8PROC
+struct IsUpperUnicode : CharacterPredicateUnicode<IsUpperUnicode> {
+ static inline bool PredicateCharacterAll(uint32_t codepoint) {
+ // Only for cased character it needs to be lower case
+ return !IsCasedCharacterUnicode(codepoint) || IsUpperCaseCharacterUnicode(codepoint);
+ }
+ static inline bool PredicateCharacterAny(uint32_t codepoint) {
+ return IsCasedCharacterUnicode(codepoint); // at least 1 cased character
+ }
+};
+#endif
+
+struct IsUpperAscii : CharacterPredicateAscii<IsUpperAscii> {
+ static inline bool PredicateCharacterAll(uint8_t ascii_character) {
+ // Only for cased character it needs to be lower case
+ return !IsCasedCharacterAscii(ascii_character) ||
+ IsUpperCaseCharacterAscii(ascii_character);
+ }
+ static inline bool PredicateCharacterAny(uint8_t ascii_character) {
+ return IsCasedCharacterAscii(ascii_character); // at least 1 cased character
+ }
+};
+
+// splitting
+
+template <typename Options>
+struct SplitFinderBase {
+ virtual ~SplitFinderBase() = default;
+ virtual Status PreExec(const Options& options) { return Status::OK(); }
+
+ // Derived classes should also define these methods:
+ // static bool Find(const uint8_t* begin, const uint8_t* end,
+ // const uint8_t** separator_begin,
+ // const uint8_t** separator_end,
+ // const SplitPatternOptions& options);
+ //
+ // static bool FindReverse(const uint8_t* begin, const uint8_t* end,
+ // const uint8_t** separator_begin,
+ // const uint8_t** separator_end,
+ // const SplitPatternOptions& options);
+};
+
+template <typename Type, typename ListType, typename SplitFinder,
+ typename Options = typename SplitFinder::Options>
+struct SplitExec {
+ using string_offset_type = typename Type::offset_type;
+ using list_offset_type = typename ListType::offset_type;
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ using ArrayListType = typename TypeTraits<ListType>::ArrayType;
+ using ListScalarType = typename TypeTraits<ListType>::ScalarType;
+ using ScalarType = typename TypeTraits<Type>::ScalarType;
+ using BuilderType = typename TypeTraits<Type>::BuilderType;
+ using ListOffsetsBuilderType = TypedBufferBuilder<list_offset_type>;
+ using State = OptionsWrapper<Options>;
+
+ // Keep the temporary storage accross individual values, to minimize reallocations
+ std::vector<util::string_view> parts;
+ Options options;
+
+ explicit SplitExec(const Options& options) : options(options) {}
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return SplitExec{State::Get(ctx)}.Execute(ctx, batch, out);
+ }
+
+ Status Execute(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ SplitFinder finder;
+ RETURN_NOT_OK(finder.PreExec(options));
+ if (batch[0].kind() == Datum::ARRAY) {
+ return Execute(ctx, &finder, batch[0].array(), out);
+ }
+ DCHECK_EQ(batch[0].kind(), Datum::SCALAR);
+ return Execute(ctx, &finder, batch[0].scalar(), out);
+ }
+
+ Status Execute(KernelContext* ctx, SplitFinder* finder,
+ const std::shared_ptr<ArrayData>& data, Datum* out) {
+ const ArrayType input(data);
+
+ BuilderType builder(input.type(), ctx->memory_pool());
+ // A slight overestimate of the data needed
+ RETURN_NOT_OK(builder.ReserveData(input.total_values_length()));
+ // The minimum amount of strings needed
+ RETURN_NOT_OK(builder.Resize(input.length() - input.null_count()));
+
+ ArrayData* output_list = out->mutable_array();
+ // List offsets were preallocated
+ auto* list_offsets = output_list->GetMutableValues<list_offset_type>(1);
+ DCHECK_NE(list_offsets, nullptr);
+ // Initial value
+ *list_offsets++ = 0;
+ for (int64_t i = 0; i < input.length(); ++i) {
+ if (!input.IsNull(i)) {
+ RETURN_NOT_OK(SplitString(input.GetView(i), finder, &builder));
+ if (ARROW_PREDICT_FALSE(builder.length() >
+ std::numeric_limits<list_offset_type>::max())) {
+ return Status::CapacityError("List offset does not fit into 32 bit");
+ }
+ }
+ *list_offsets++ = static_cast<list_offset_type>(builder.length());
+ }
+ // Assign string array to list child data
+ std::shared_ptr<Array> string_array;
+ RETURN_NOT_OK(builder.Finish(&string_array));
+ output_list->child_data.push_back(string_array->data());
+ return Status::OK();
+ }
+
+ Status Execute(KernelContext* ctx, SplitFinder* finder,
+ const std::shared_ptr<Scalar>& scalar, Datum* out) {
+ const auto& input = checked_cast<const ScalarType&>(*scalar);
+ auto result = checked_cast<ListScalarType*>(out->scalar().get());
+ if (input.is_valid) {
+ result->is_valid = true;
+ BuilderType builder(input.type, ctx->memory_pool());
+ util::string_view s(*input.value);
+ RETURN_NOT_OK(SplitString(s, finder, &builder));
+ RETURN_NOT_OK(builder.Finish(&result->value));
+ }
+ return Status::OK();
+ }
+
+ Status SplitString(const util::string_view& s, SplitFinder* finder,
+ BuilderType* builder) {
+ const uint8_t* begin = reinterpret_cast<const uint8_t*>(s.data());
+ const uint8_t* end = begin + s.length();
+
+ int64_t max_splits = options.max_splits;
+ // if there is no max splits, reversing does not make sense (and is probably less
+ // efficient), but is useful for testing
+ if (options.reverse) {
+ // note that i points 1 further than the 'current'
+ const uint8_t* i = end;
+ // we will record the parts in reverse order
+ parts.clear();
+ if (max_splits > -1) {
+ parts.reserve(max_splits + 1);
+ }
+ while (max_splits != 0) {
+ const uint8_t *separator_begin, *separator_end;
+ // find with whatever algo the part we will 'cut out'
+ if (finder->FindReverse(begin, i, &separator_begin, &separator_end, options)) {
+ parts.emplace_back(reinterpret_cast<const char*>(separator_end),
+ i - separator_end);
+ i = separator_begin;
+ max_splits--;
+ } else {
+ // if we cannot find a separator, we're done
+ break;
+ }
+ }
+ parts.emplace_back(reinterpret_cast<const char*>(begin), i - begin);
+ // now we do the copying
+ for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
+ RETURN_NOT_OK(builder->Append(*it));
+ }
+ } else {
+ const uint8_t* i = begin;
+ while (max_splits != 0) {
+ const uint8_t *separator_begin, *separator_end;
+ // find with whatever algo the part we will 'cut out'
+ if (finder->Find(i, end, &separator_begin, &separator_end, options)) {
+ // the part till the beginning of the 'cut'
+ RETURN_NOT_OK(
+ builder->Append(i, static_cast<string_offset_type>(separator_begin - i)));
+ i = separator_end;
+ max_splits--;
+ } else {
+ // if we cannot find a separator, we're done
+ break;
+ }
+ }
+ // trailing part
+ RETURN_NOT_OK(builder->Append(i, static_cast<string_offset_type>(end - i)));
+ }
+ return Status::OK();
+ }
+};
+
+struct SplitPatternFinder : public SplitFinderBase<SplitPatternOptions> {
+ using Options = SplitPatternOptions;
+
+ Status PreExec(const SplitPatternOptions& options) override {
+ if (options.pattern.length() == 0) {
+ return Status::Invalid("Empty separator");
+ }
+ return Status::OK();
+ }
+
+ static bool Find(const uint8_t* begin, const uint8_t* end,
+ const uint8_t** separator_begin, const uint8_t** separator_end,
+ const SplitPatternOptions& options) {
+ const uint8_t* pattern = reinterpret_cast<const uint8_t*>(options.pattern.c_str());
+ const int64_t pattern_length = options.pattern.length();
+ const uint8_t* i = begin;
+ // this is O(n*m) complexity, we could use the Knuth-Morris-Pratt algorithm used in
+ // the match kernel
+ while ((i + pattern_length <= end)) {
+ i = std::search(i, end, pattern, pattern + pattern_length);
+ if (i != end) {
+ *separator_begin = i;
+ *separator_end = i + pattern_length;
+ return true;
+ }
+ }
+ return false;
+ }
+
+ static bool FindReverse(const uint8_t* begin, const uint8_t* end,
+ const uint8_t** separator_begin, const uint8_t** separator_end,
+ const SplitPatternOptions& options) {
+ const uint8_t* pattern = reinterpret_cast<const uint8_t*>(options.pattern.c_str());
+ const int64_t pattern_length = options.pattern.length();
+ // this is O(n*m) complexity, we could use the Knuth-Morris-Pratt algorithm used in
+ // the match kernel
+ std::reverse_iterator<const uint8_t*> ri(end);
+ std::reverse_iterator<const uint8_t*> rend(begin);
+ std::reverse_iterator<const uint8_t*> pattern_rbegin(pattern + pattern_length);
+ std::reverse_iterator<const uint8_t*> pattern_rend(pattern);
+ while (begin <= ri.base() - pattern_length) {
+ ri = std::search(ri, rend, pattern_rbegin, pattern_rend);
+ if (ri != rend) {
+ *separator_begin = ri.base() - pattern_length;
+ *separator_end = ri.base();
+ return true;
+ }
+ }
+ return false;
+ }
+};
+
+template <typename Type, typename ListType>
+using SplitPatternExec = SplitExec<Type, ListType, SplitPatternFinder>;
+
+const FunctionDoc split_pattern_doc(
+ "Split string according to separator",
+ ("Split each string according to the exact `pattern` defined in\n"
+ "SplitPatternOptions. The output for each string input is a list\n"
+ "of strings.\n"
+ "\n"
+ "The maximum number of splits and direction of splitting\n"
+ "(forward, reverse) can optionally be defined in SplitPatternOptions."),
+ {"strings"}, "SplitPatternOptions");
+
+const FunctionDoc ascii_split_whitespace_doc(
+ "Split string according to any ASCII whitespace",
+ ("Split each string according any non-zero length sequence of ASCII\n"
+ "whitespace characters. The output for each string input is a list\n"
+ "of strings.\n"
+ "\n"
+ "The maximum number of splits and direction of splitting\n"
+ "(forward, reverse) can optionally be defined in SplitOptions."),
+ {"strings"}, "SplitOptions");
+
+const FunctionDoc utf8_split_whitespace_doc(
+ "Split string according to any Unicode whitespace",
+ ("Split each string according any non-zero length sequence of Unicode\n"
+ "whitespace characters. The output for each string input is a list\n"
+ "of strings.\n"
+ "\n"
+ "The maximum number of splits and direction of splitting\n"
+ "(forward, reverse) can optionally be defined in SplitOptions."),
+ {"strings"}, "SplitOptions");
+
+void AddSplitPattern(FunctionRegistry* registry) {
+ auto func = std::make_shared<ScalarFunction>("split_pattern", Arity::Unary(),
+ &split_pattern_doc);
+ using t32 = SplitPatternExec<StringType, ListType>;
+ using t64 = SplitPatternExec<LargeStringType, ListType>;
+ DCHECK_OK(func->AddKernel({utf8()}, {list(utf8())}, t32::Exec, t32::State::Init));
+ DCHECK_OK(
+ func->AddKernel({large_utf8()}, {list(large_utf8())}, t64::Exec, t64::State::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+struct SplitWhitespaceAsciiFinder : public SplitFinderBase<SplitOptions> {
+ using Options = SplitOptions;
+
+ static bool Find(const uint8_t* begin, const uint8_t* end,
+ const uint8_t** separator_begin, const uint8_t** separator_end,
+ const SplitOptions& options) {
+ const uint8_t* i = begin;
+ while (i < end) {
+ if (IsSpaceCharacterAscii(*i)) {
+ *separator_begin = i;
+ do {
+ i++;
+ } while (IsSpaceCharacterAscii(*i) && i < end);
+ *separator_end = i;
+ return true;
+ }
+ i++;
+ }
+ return false;
+ }
+
+ static bool FindReverse(const uint8_t* begin, const uint8_t* end,
+ const uint8_t** separator_begin, const uint8_t** separator_end,
+ const SplitOptions& options) {
+ const uint8_t* i = end - 1;
+ while ((i >= begin)) {
+ if (IsSpaceCharacterAscii(*i)) {
+ *separator_end = i + 1;
+ do {
+ i--;
+ } while (IsSpaceCharacterAscii(*i) && i >= begin);
+ *separator_begin = i + 1;
+ return true;
+ }
+ i--;
+ }
+ return false;
+ }
+};
+
+template <typename Type, typename ListType>
+using SplitWhitespaceAsciiExec = SplitExec<Type, ListType, SplitWhitespaceAsciiFinder>;
+
+void AddSplitWhitespaceAscii(FunctionRegistry* registry) {
+ static const SplitOptions default_options{};
+ auto func =
+ std::make_shared<ScalarFunction>("ascii_split_whitespace", Arity::Unary(),
+ &ascii_split_whitespace_doc, &default_options);
+ using t32 = SplitWhitespaceAsciiExec<StringType, ListType>;
+ using t64 = SplitWhitespaceAsciiExec<LargeStringType, ListType>;
+ DCHECK_OK(func->AddKernel({utf8()}, {list(utf8())}, t32::Exec, t32::State::Init));
+ DCHECK_OK(
+ func->AddKernel({large_utf8()}, {list(large_utf8())}, t64::Exec, t64::State::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+#ifdef ARROW_WITH_UTF8PROC
+struct SplitWhitespaceUtf8Finder : public SplitFinderBase<SplitOptions> {
+ using Options = SplitOptions;
+
+ Status PreExec(const SplitOptions& options) override {
+ EnsureLookupTablesFilled();
+ return Status::OK();
+ }
+
+ bool Find(const uint8_t* begin, const uint8_t* end, const uint8_t** separator_begin,
+ const uint8_t** separator_end, const SplitOptions& options) {
+ const uint8_t* i = begin;
+ while ((i < end)) {
+ uint32_t codepoint = 0;
+ *separator_begin = i;
+ if (ARROW_PREDICT_FALSE(!arrow::util::UTF8Decode(&i, &codepoint))) {
+ return false;
+ }
+ if (IsSpaceCharacterUnicode(codepoint)) {
+ do {
+ *separator_end = i;
+ if (ARROW_PREDICT_FALSE(!arrow::util::UTF8Decode(&i, &codepoint))) {
+ return false;
+ }
+ } while (IsSpaceCharacterUnicode(codepoint) && i < end);
+ return true;
+ }
+ }
+ return false;
+ }
+
+ bool FindReverse(const uint8_t* begin, const uint8_t* end,
+ const uint8_t** separator_begin, const uint8_t** separator_end,
+ const SplitOptions& options) {
+ const uint8_t* i = end - 1;
+ while ((i >= begin)) {
+ uint32_t codepoint = 0;
+ *separator_end = i + 1;
+ if (ARROW_PREDICT_FALSE(!arrow::util::UTF8DecodeReverse(&i, &codepoint))) {
+ return false;
+ }
+ if (IsSpaceCharacterUnicode(codepoint)) {
+ do {
+ *separator_begin = i + 1;
+ if (ARROW_PREDICT_FALSE(!arrow::util::UTF8DecodeReverse(&i, &codepoint))) {
+ return false;
+ }
+ } while (IsSpaceCharacterUnicode(codepoint) && i >= begin);
+ return true;
+ }
+ }
+ return false;
+ }
+};
+
+template <typename Type, typename ListType>
+using SplitWhitespaceUtf8Exec = SplitExec<Type, ListType, SplitWhitespaceUtf8Finder>;
+
+void AddSplitWhitespaceUTF8(FunctionRegistry* registry) {
+ static const SplitOptions default_options{};
+ auto func =
+ std::make_shared<ScalarFunction>("utf8_split_whitespace", Arity::Unary(),
+ &utf8_split_whitespace_doc, &default_options);
+ using t32 = SplitWhitespaceUtf8Exec<StringType, ListType>;
+ using t64 = SplitWhitespaceUtf8Exec<LargeStringType, ListType>;
+ DCHECK_OK(func->AddKernel({utf8()}, {list(utf8())}, t32::Exec, t32::State::Init));
+ DCHECK_OK(
+ func->AddKernel({large_utf8()}, {list(large_utf8())}, t64::Exec, t64::State::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+#endif // ARROW_WITH_UTF8PROC
+
+#ifdef ARROW_WITH_RE2
+struct SplitRegexFinder : public SplitFinderBase<SplitPatternOptions> {
+ using Options = SplitPatternOptions;
+
+ util::optional<RE2> regex_split;
+
+ Status PreExec(const SplitPatternOptions& options) override {
+ if (options.reverse) {
+ return Status::NotImplemented("Cannot split in reverse with regex");
+ }
+ // RE2 does *not* give you the full match! Must wrap the regex in a capture group
+ // There is FindAndConsume, but it would give only the end of the separator
+ std::string pattern = "(";
+ pattern.reserve(options.pattern.size() + 2);
+ pattern += options.pattern;
+ pattern += ')';
+ regex_split.emplace(std::move(pattern));
+ return RegexStatus(*regex_split);
+ }
+
+ bool Find(const uint8_t* begin, const uint8_t* end, const uint8_t** separator_begin,
+ const uint8_t** separator_end, const SplitPatternOptions& options) {
+ re2::StringPiece piece(reinterpret_cast<const char*>(begin),
+ std::distance(begin, end));
+ // "StringPiece is mutated to point to matched piece"
+ re2::StringPiece result;
+ if (!re2::RE2::PartialMatch(piece, *regex_split, &result)) {
+ return false;
+ }
+ *separator_begin = reinterpret_cast<const uint8_t*>(result.data());
+ *separator_end = reinterpret_cast<const uint8_t*>(result.data() + result.size());
+ return true;
+ }
+
+ bool FindReverse(const uint8_t* begin, const uint8_t* end,
+ const uint8_t** separator_begin, const uint8_t** separator_end,
+ const SplitPatternOptions& options) {
+ // Unsupported (see PreExec)
+ return false;
+ }
+};
+
+template <typename Type, typename ListType>
+using SplitRegexExec = SplitExec<Type, ListType, SplitRegexFinder>;
+
+const FunctionDoc split_pattern_regex_doc(
+ "Split string according to regex pattern",
+ ("Split each string according to the regex `pattern` defined in\n"
+ "SplitPatternOptions. The output for each string input is a list\n"
+ "of strings.\n"
+ "\n"
+ "The maximum number of splits and direction of splitting\n"
+ "(forward, reverse) can optionally be defined in SplitPatternOptions."),
+ {"strings"}, "SplitPatternOptions");
+
+void AddSplitRegex(FunctionRegistry* registry) {
+ auto func = std::make_shared<ScalarFunction>("split_pattern_regex", Arity::Unary(),
+ &split_pattern_regex_doc);
+ using t32 = SplitRegexExec<StringType, ListType>;
+ using t64 = SplitRegexExec<LargeStringType, ListType>;
+ DCHECK_OK(func->AddKernel({utf8()}, {list(utf8())}, t32::Exec, t32::State::Init));
+ DCHECK_OK(
+ func->AddKernel({large_utf8()}, {list(large_utf8())}, t64::Exec, t64::State::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+#endif // ARROW_WITH_RE2
+
+void AddSplit(FunctionRegistry* registry) {
+ AddSplitPattern(registry);
+ AddSplitWhitespaceAscii(registry);
+#ifdef ARROW_WITH_UTF8PROC
+ AddSplitWhitespaceUTF8(registry);
+#endif
+#ifdef ARROW_WITH_RE2
+ AddSplitRegex(registry);
+#endif
+}
+
+// ----------------------------------------------------------------------
+// Replace substring (plain, regex)
+
+template <typename Type, typename Replacer>
+struct ReplaceSubString {
+ using ScalarType = typename TypeTraits<Type>::ScalarType;
+ using offset_type = typename Type::offset_type;
+ using ValueDataBuilder = TypedBufferBuilder<uint8_t>;
+ using OffsetBuilder = TypedBufferBuilder<offset_type>;
+ using State = OptionsWrapper<ReplaceSubstringOptions>;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // TODO Cache replacer across invocations (for regex compilation)
+ ARROW_ASSIGN_OR_RAISE(auto replacer, Replacer::Make(State::Get(ctx)));
+ return Replace(ctx, batch, *replacer, out);
+ }
+
+ static Status Replace(KernelContext* ctx, const ExecBatch& batch,
+ const Replacer& replacer, Datum* out) {
+ ValueDataBuilder value_data_builder(ctx->memory_pool());
+ OffsetBuilder offset_builder(ctx->memory_pool());
+
+ if (batch[0].kind() == Datum::ARRAY) {
+ // We already know how many strings we have, so we can use Reserve/UnsafeAppend
+ RETURN_NOT_OK(offset_builder.Reserve(batch[0].array()->length + 1));
+ offset_builder.UnsafeAppend(0); // offsets start at 0
+
+ const ArrayData& input = *batch[0].array();
+ RETURN_NOT_OK(VisitArrayDataInline<Type>(
+ input,
+ [&](util::string_view s) {
+ RETURN_NOT_OK(replacer.ReplaceString(s, &value_data_builder));
+ offset_builder.UnsafeAppend(
+ static_cast<offset_type>(value_data_builder.length()));
+ return Status::OK();
+ },
+ [&]() {
+ // offset for null value
+ offset_builder.UnsafeAppend(
+ static_cast<offset_type>(value_data_builder.length()));
+ return Status::OK();
+ }));
+ ArrayData* output = out->mutable_array();
+ RETURN_NOT_OK(value_data_builder.Finish(&output->buffers[2]));
+ RETURN_NOT_OK(offset_builder.Finish(&output->buffers[1]));
+ } else {
+ const auto& input = checked_cast<const ScalarType&>(*batch[0].scalar());
+ auto result = std::make_shared<ScalarType>();
+ if (input.is_valid) {
+ util::string_view s = static_cast<util::string_view>(*input.value);
+ RETURN_NOT_OK(replacer.ReplaceString(s, &value_data_builder));
+ RETURN_NOT_OK(value_data_builder.Finish(&result->value));
+ result->is_valid = true;
+ }
+ out->value = result;
+ }
+
+ return Status::OK();
+ }
+};
+
+struct PlainSubStringReplacer {
+ const ReplaceSubstringOptions& options_;
+
+ static Result<std::unique_ptr<PlainSubStringReplacer>> Make(
+ const ReplaceSubstringOptions& options) {
+ return arrow::internal::make_unique<PlainSubStringReplacer>(options);
+ }
+
+ explicit PlainSubStringReplacer(const ReplaceSubstringOptions& options)
+ : options_(options) {}
+
+ Status ReplaceString(util::string_view s, TypedBufferBuilder<uint8_t>* builder) const {
+ const char* i = s.begin();
+ const char* end = s.end();
+ int64_t max_replacements = options_.max_replacements;
+ while ((i < end) && (max_replacements != 0)) {
+ const char* pos =
+ std::search(i, end, options_.pattern.begin(), options_.pattern.end());
+ if (pos == end) {
+ RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+ static_cast<int64_t>(end - i)));
+ i = end;
+ } else {
+ // the string before the pattern
+ RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+ static_cast<int64_t>(pos - i)));
+ // the replacement
+ RETURN_NOT_OK(
+ builder->Append(reinterpret_cast<const uint8_t*>(options_.replacement.data()),
+ options_.replacement.length()));
+ // skip pattern
+ i = pos + options_.pattern.length();
+ max_replacements--;
+ }
+ }
+ // if we exited early due to max_replacements, add the trailing part
+ return builder->Append(reinterpret_cast<const uint8_t*>(i),
+ static_cast<int64_t>(end - i));
+ }
+};
+
+#ifdef ARROW_WITH_RE2
+struct RegexSubStringReplacer {
+ const ReplaceSubstringOptions& options_;
+ const RE2 regex_find_;
+ const RE2 regex_replacement_;
+
+ static Result<std::unique_ptr<RegexSubStringReplacer>> Make(
+ const ReplaceSubstringOptions& options) {
+ auto replacer = arrow::internal::make_unique<RegexSubStringReplacer>(options);
+
+ RETURN_NOT_OK(RegexStatus(replacer->regex_find_));
+ RETURN_NOT_OK(RegexStatus(replacer->regex_replacement_));
+
+ std::string replacement_error;
+ if (!replacer->regex_replacement_.CheckRewriteString(replacer->options_.replacement,
+ &replacement_error)) {
+ return Status::Invalid("Invalid replacement string: ",
+ std::move(replacement_error));
+ }
+
+ return std::move(replacer);
+ }
+
+ // Using RE2::FindAndConsume we can only find the pattern if it is a group, therefore
+ // we have 2 regexes, one with () around it, one without.
+ explicit RegexSubStringReplacer(const ReplaceSubstringOptions& options)
+ : options_(options),
+ regex_find_("(" + options_.pattern + ")", RE2::Quiet),
+ regex_replacement_(options_.pattern, RE2::Quiet) {}
+
+ Status ReplaceString(util::string_view s, TypedBufferBuilder<uint8_t>* builder) const {
+ re2::StringPiece replacement(options_.replacement);
+
+ if (options_.max_replacements == -1) {
+ std::string s_copy(s.to_string());
+ re2::RE2::GlobalReplace(&s_copy, regex_replacement_, replacement);
+ return builder->Append(reinterpret_cast<const uint8_t*>(s_copy.data()),
+ s_copy.length());
+ }
+
+ // Since RE2 does not have the concept of max_replacements, we have to do some work
+ // ourselves.
+ // We might do this faster similar to RE2::GlobalReplace using Match and Rewrite
+ const char* i = s.begin();
+ const char* end = s.end();
+ re2::StringPiece piece(s.data(), s.length());
+
+ int64_t max_replacements = options_.max_replacements;
+ while ((i < end) && (max_replacements != 0)) {
+ std::string found;
+ if (!re2::RE2::FindAndConsume(&piece, regex_find_, &found)) {
+ RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+ static_cast<int64_t>(end - i)));
+ i = end;
+ } else {
+ // wind back to the beginning of the match
+ const char* pos = piece.begin() - found.length();
+ // the string before the pattern
+ RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+ static_cast<int64_t>(pos - i)));
+ // replace the pattern in what we found
+ if (!re2::RE2::Replace(&found, regex_replacement_, replacement)) {
+ return Status::Invalid("Regex found, but replacement failed");
+ }
+ RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(found.data()),
+ static_cast<int64_t>(found.length())));
+ // skip pattern
+ i = piece.begin();
+ max_replacements--;
+ }
+ }
+ // If we exited early due to max_replacements, add the trailing part
+ return builder->Append(reinterpret_cast<const uint8_t*>(i),
+ static_cast<int64_t>(end - i));
+ }
+};
+#endif
+
+template <typename Type>
+using ReplaceSubStringPlain = ReplaceSubString<Type, PlainSubStringReplacer>;
+
+const FunctionDoc replace_substring_doc(
+ "Replace non-overlapping substrings that match pattern by replacement",
+ ("For each string in `strings`, replace non-overlapping substrings that match\n"
+ "`pattern` by `replacement`. If `max_replacements != -1`, it determines the\n"
+ "maximum amount of replacements made, counting from the left. Null values emit\n"
+ "null."),
+ {"strings"}, "ReplaceSubstringOptions");
+
+#ifdef ARROW_WITH_RE2
+template <typename Type>
+using ReplaceSubStringRegex = ReplaceSubString<Type, RegexSubStringReplacer>;
+
+const FunctionDoc replace_substring_regex_doc(
+ "Replace non-overlapping substrings that match regex `pattern` by `replacement`",
+ ("For each string in `strings`, replace non-overlapping substrings that match the\n"
+ "regular expression `pattern` by `replacement` using the Google RE2 library.\n"
+ "If `max_replacements != -1`, it determines the maximum amount of replacements\n"
+ "made, counting from the left. Note that if the pattern contains groups,\n"
+ "backreferencing macan be used. Null values emit null."),
+ {"strings"}, "ReplaceSubstringOptions");
+#endif
+
+// ----------------------------------------------------------------------
+// Replace slice
+
+struct ReplaceSliceTransformBase : public StringTransformBase {
+ using State = OptionsWrapper<ReplaceSliceOptions>;
+
+ const ReplaceSliceOptions* options;
+
+ explicit ReplaceSliceTransformBase(const ReplaceSliceOptions& options)
+ : options{&options} {}
+
+ int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override {
+ return ninputs * options->replacement.size() + input_ncodeunits;
+ }
+};
+
+struct BinaryReplaceSliceTransform : ReplaceSliceTransformBase {
+ using ReplaceSliceTransformBase::ReplaceSliceTransformBase;
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ const auto& opts = *options;
+ int64_t before_slice = 0;
+ int64_t after_slice = 0;
+ uint8_t* output_start = output;
+
+ if (opts.start >= 0) {
+ // Count from left
+ before_slice = std::min<int64_t>(input_string_ncodeunits, opts.start);
+ } else {
+ // Count from right
+ before_slice = std::max<int64_t>(0, input_string_ncodeunits + opts.start);
+ }
+ // Mimic Pandas: if stop would be before start, treat as 0-length slice
+ if (opts.stop >= 0) {
+ // Count from left
+ after_slice =
+ std::min<int64_t>(input_string_ncodeunits, std::max(before_slice, opts.stop));
+ } else {
+ // Count from right
+ after_slice = std::max<int64_t>(before_slice, input_string_ncodeunits + opts.stop);
+ }
+ output = std::copy(input, input + before_slice, output);
+ output = std::copy(opts.replacement.begin(), opts.replacement.end(), output);
+ output = std::copy(input + after_slice, input + input_string_ncodeunits, output);
+ return output - output_start;
+ }
+
+ static int32_t FixedOutputSize(const ReplaceSliceOptions& opts, int32_t input_width) {
+ int32_t before_slice = 0;
+ int32_t after_slice = 0;
+ const int32_t start = static_cast<int32_t>(opts.start);
+ const int32_t stop = static_cast<int32_t>(opts.stop);
+ if (opts.start >= 0) {
+ // Count from left
+ before_slice = std::min<int32_t>(input_width, start);
+ } else {
+ // Count from right
+ before_slice = std::max<int32_t>(0, input_width + start);
+ }
+ if (opts.stop >= 0) {
+ // Count from left
+ after_slice = std::min<int32_t>(input_width, std::max<int32_t>(before_slice, stop));
+ } else {
+ // Count from right
+ after_slice = std::max<int32_t>(before_slice, input_width + stop);
+ }
+ return static_cast<int32_t>(before_slice + opts.replacement.size() +
+ (input_width - after_slice));
+ }
+};
+
+struct Utf8ReplaceSliceTransform : ReplaceSliceTransformBase {
+ using ReplaceSliceTransformBase::ReplaceSliceTransformBase;
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ const auto& opts = *options;
+ const uint8_t* begin = input;
+ const uint8_t* end = input + input_string_ncodeunits;
+ const uint8_t *begin_sliced, *end_sliced;
+ uint8_t* output_start = output;
+
+ // Mimic Pandas: if stop would be before start, treat as 0-length slice
+ if (opts.start >= 0) {
+ // Count from left
+ if (!arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced, opts.start)) {
+ return kTransformError;
+ }
+ if (opts.stop > options->start) {
+ // Continue counting from left
+ const int64_t length = opts.stop - options->start;
+ if (!arrow::util::UTF8AdvanceCodepoints(begin_sliced, end, &end_sliced, length)) {
+ return kTransformError;
+ }
+ } else if (opts.stop < 0) {
+ // Count from right
+ if (!arrow::util::UTF8AdvanceCodepointsReverse(begin_sliced, end, &end_sliced,
+ -opts.stop)) {
+ return kTransformError;
+ }
+ } else {
+ // Zero-length slice
+ end_sliced = begin_sliced;
+ }
+ } else {
+ // Count from right
+ if (!arrow::util::UTF8AdvanceCodepointsReverse(begin, end, &begin_sliced,
+ -opts.start)) {
+ return kTransformError;
+ }
+ if (opts.stop >= 0) {
+ // Restart counting from left
+ if (!arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, opts.stop)) {
+ return kTransformError;
+ }
+ if (end_sliced <= begin_sliced) {
+ // Zero-length slice
+ end_sliced = begin_sliced;
+ }
+ } else if ((opts.stop < 0) && (options->stop > options->start)) {
+ // Count from right
+ if (!arrow::util::UTF8AdvanceCodepointsReverse(begin_sliced, end, &end_sliced,
+ -opts.stop)) {
+ return kTransformError;
+ }
+ } else {
+ // zero-length slice
+ end_sliced = begin_sliced;
+ }
+ }
+ output = std::copy(begin, begin_sliced, output);
+ output = std::copy(opts.replacement.begin(), options->replacement.end(), output);
+ output = std::copy(end_sliced, end, output);
+ return output - output_start;
+ }
+};
+
+template <typename Type>
+using BinaryReplaceSlice =
+ StringTransformExecWithState<Type, BinaryReplaceSliceTransform>;
+template <typename Type>
+using Utf8ReplaceSlice = StringTransformExecWithState<Type, Utf8ReplaceSliceTransform>;
+
+const FunctionDoc binary_replace_slice_doc(
+ "Replace a slice of a binary string with `replacement`",
+ ("For each string in `strings`, replace a slice of the string defined by `start`"
+ "and `stop` with `replacement`. `start` is inclusive and `stop` is exclusive, "
+ "and both are measured in bytes.\n"
+ "Null values emit null."),
+ {"strings"}, "ReplaceSliceOptions");
+
+const FunctionDoc utf8_replace_slice_doc(
+ "Replace a slice of a string with `replacement`",
+ ("For each string in `strings`, replace a slice of the string defined by `start`"
+ "and `stop` with `replacement`. `start` is inclusive and `stop` is exclusive, "
+ "and both are measured in codeunits.\n"
+ "Null values emit null."),
+ {"strings"}, "ReplaceSliceOptions");
+
+void AddReplaceSlice(FunctionRegistry* registry) {
+ {
+ auto func = std::make_shared<ScalarFunction>("binary_replace_slice", Arity::Unary(),
+ &binary_replace_slice_doc);
+ for (const auto& ty : BaseBinaryTypes()) {
+ DCHECK_OK(func->AddKernel({ty}, ty,
+ GenerateTypeAgnosticVarBinaryBase<BinaryReplaceSlice>(ty),
+ ReplaceSliceTransformBase::State::Init));
+ }
+ using TransformExec =
+ FixedSizeBinaryTransformExecWithState<BinaryReplaceSliceTransform>;
+ DCHECK_OK(func->AddKernel({InputType(Type::FIXED_SIZE_BINARY)},
+ OutputType(TransformExec::OutputType), TransformExec::Exec,
+ ReplaceSliceTransformBase::State::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ {
+ auto func = std::make_shared<ScalarFunction>("utf8_replace_slice", Arity::Unary(),
+ &utf8_replace_slice_doc);
+ DCHECK_OK(func->AddKernel({utf8()}, utf8(), Utf8ReplaceSlice<StringType>::Exec,
+ ReplaceSliceTransformBase::State::Init));
+ DCHECK_OK(func->AddKernel({large_utf8()}, large_utf8(),
+ Utf8ReplaceSlice<LargeStringType>::Exec,
+ ReplaceSliceTransformBase::State::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+}
+
+// ----------------------------------------------------------------------
+// Extract with regex
+
+#ifdef ARROW_WITH_RE2
+
+// TODO cache this once per ExtractRegexOptions
+struct ExtractRegexData {
+ // Use unique_ptr<> because RE2 is non-movable
+ std::unique_ptr<RE2> regex;
+ std::vector<std::string> group_names;
+
+ static Result<ExtractRegexData> Make(const ExtractRegexOptions& options) {
+ ExtractRegexData data(options.pattern);
+ RETURN_NOT_OK(RegexStatus(*data.regex));
+
+ const int group_count = data.regex->NumberOfCapturingGroups();
+ const auto& name_map = data.regex->CapturingGroupNames();
+ data.group_names.reserve(group_count);
+
+ for (int i = 0; i < group_count; i++) {
+ auto item = name_map.find(i + 1); // re2 starts counting from 1
+ if (item == name_map.end()) {
+ // XXX should we instead just create fields with an empty name?
+ return Status::Invalid("Regular expression contains unnamed groups");
+ }
+ data.group_names.emplace_back(item->second);
+ }
+ return std::move(data);
+ }
+
+ Result<ValueDescr> ResolveOutputType(const std::vector<ValueDescr>& args) const {
+ const auto& input_type = args[0].type;
+ if (input_type == nullptr) {
+ // No input type specified => propagate shape
+ return args[0];
+ }
+ // Input type is either String or LargeString and is also the type of each
+ // field in the output struct type.
+ DCHECK(input_type->id() == Type::STRING || input_type->id() == Type::LARGE_STRING);
+ FieldVector fields;
+ fields.reserve(group_names.size());
+ std::transform(group_names.begin(), group_names.end(), std::back_inserter(fields),
+ [&](const std::string& name) { return field(name, input_type); });
+ return struct_(std::move(fields));
+ }
+
+ private:
+ explicit ExtractRegexData(const std::string& pattern)
+ : regex(new RE2(pattern, RE2::Quiet)) {}
+};
+
+Result<ValueDescr> ResolveExtractRegexOutput(KernelContext* ctx,
+ const std::vector<ValueDescr>& args) {
+ using State = OptionsWrapper<ExtractRegexOptions>;
+ ExtractRegexOptions options = State::Get(ctx);
+ ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options));
+ return data.ResolveOutputType(args);
+}
+
+struct ExtractRegexBase {
+ const ExtractRegexData& data;
+ const int group_count;
+ std::vector<re2::StringPiece> found_values;
+ std::vector<re2::RE2::Arg> args;
+ std::vector<const re2::RE2::Arg*> args_pointers;
+ const re2::RE2::Arg** args_pointers_start;
+ const re2::RE2::Arg* null_arg = nullptr;
+
+ explicit ExtractRegexBase(const ExtractRegexData& data)
+ : data(data),
+ group_count(static_cast<int>(data.group_names.size())),
+ found_values(group_count) {
+ args.reserve(group_count);
+ args_pointers.reserve(group_count);
+
+ for (int i = 0; i < group_count; i++) {
+ args.emplace_back(&found_values[i]);
+ // Since we reserved capacity, we're guaranteed the pointer remains valid
+ args_pointers.push_back(&args[i]);
+ }
+ // Avoid null pointer if there is no capture group
+ args_pointers_start = (group_count > 0) ? args_pointers.data() : &null_arg;
+ }
+
+ bool Match(util::string_view s) {
+ return re2::RE2::PartialMatchN(ToStringPiece(s), *data.regex, args_pointers_start,
+ group_count);
+ }
+};
+
+template <typename Type>
+struct ExtractRegex : public ExtractRegexBase {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ using ScalarType = typename TypeTraits<Type>::ScalarType;
+ using BuilderType = typename TypeTraits<Type>::BuilderType;
+ using State = OptionsWrapper<ExtractRegexOptions>;
+
+ using ExtractRegexBase::ExtractRegexBase;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ ExtractRegexOptions options = State::Get(ctx);
+ ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options));
+ return ExtractRegex{data}.Extract(ctx, batch, out);
+ }
+
+ Status Extract(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ ARROW_ASSIGN_OR_RAISE(auto descr, data.ResolveOutputType(batch.GetDescriptors()));
+ DCHECK_NE(descr.type, nullptr);
+ const auto& type = descr.type;
+
+ if (batch[0].kind() == Datum::ARRAY) {
+ std::unique_ptr<ArrayBuilder> array_builder;
+ RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), type, &array_builder));
+ StructBuilder* struct_builder = checked_cast<StructBuilder*>(array_builder.get());
+
+ std::vector<BuilderType*> field_builders;
+ field_builders.reserve(group_count);
+ for (int i = 0; i < group_count; i++) {
+ field_builders.push_back(
+ checked_cast<BuilderType*>(struct_builder->field_builder(i)));
+ }
+
+ auto visit_null = [&]() { return struct_builder->AppendNull(); };
+ auto visit_value = [&](util::string_view s) {
+ if (Match(s)) {
+ for (int i = 0; i < group_count; i++) {
+ RETURN_NOT_OK(field_builders[i]->Append(ToStringView(found_values[i])));
+ }
+ return struct_builder->Append();
+ } else {
+ return struct_builder->AppendNull();
+ }
+ };
+ const ArrayData& input = *batch[0].array();
+ RETURN_NOT_OK(VisitArrayDataInline<Type>(input, visit_value, visit_null));
+
+ std::shared_ptr<Array> out_array;
+ RETURN_NOT_OK(struct_builder->Finish(&out_array));
+ *out = std::move(out_array);
+ } else {
+ const auto& input = checked_cast<const ScalarType&>(*batch[0].scalar());
+ auto result = std::make_shared<StructScalar>(type);
+ if (input.is_valid && Match(util::string_view(*input.value))) {
+ result->value.reserve(group_count);
+ for (int i = 0; i < group_count; i++) {
+ result->value.push_back(
+ std::make_shared<ScalarType>(found_values[i].as_string()));
+ }
+ result->is_valid = true;
+ } else {
+ result->is_valid = false;
+ }
+ out->value = std::move(result);
+ }
+
+ return Status::OK();
+ }
+};
+
+const FunctionDoc extract_regex_doc(
+ "Extract substrings captured by a regex pattern",
+ ("For each string in `strings`, match the regular expression and, if\n"
+ "successful, emit a struct with field names and values coming from the\n"
+ "regular expression's named capture groups. If the input is null or the\n"
+ "regular expression fails matching, a null output value is emitted.\n"
+ "\n"
+ "Regular expression matching is done using the Google RE2 library."),
+ {"strings"}, "ExtractRegexOptions");
+
+void AddExtractRegex(FunctionRegistry* registry) {
+ auto func = std::make_shared<ScalarFunction>("extract_regex", Arity::Unary(),
+ &extract_regex_doc);
+ using t32 = ExtractRegex<StringType>;
+ using t64 = ExtractRegex<LargeStringType>;
+ OutputType out_ty(ResolveExtractRegexOutput);
+ ScalarKernel kernel;
+
+ // Null values will be computed based on regex match or not
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ kernel.signature.reset(new KernelSignature({utf8()}, out_ty));
+ kernel.exec = t32::Exec;
+ kernel.init = t32::State::Init;
+ DCHECK_OK(func->AddKernel(kernel));
+ kernel.signature.reset(new KernelSignature({large_utf8()}, out_ty));
+ kernel.exec = t64::Exec;
+ kernel.init = t64::State::Init;
+ DCHECK_OK(func->AddKernel(kernel));
+
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+#endif // ARROW_WITH_RE2
+
+// ----------------------------------------------------------------------
+// strptime string parsing
+
+using StrptimeState = OptionsWrapper<StrptimeOptions>;
+
+struct ParseStrptime {
+ explicit ParseStrptime(const StrptimeOptions& options)
+ : parser(TimestampParser::MakeStrptime(options.format)), unit(options.unit) {}
+
+ template <typename... Ignored>
+ int64_t Call(KernelContext*, util::string_view val, Status* st) const {
+ int64_t result = 0;
+ if (!(*parser)(val.data(), val.size(), unit, &result)) {
+ *st = Status::Invalid("Failed to parse string: '", val, "' as a scalar of type ",
+ TimestampType(unit).ToString());
+ }
+ return result;
+ }
+
+ std::shared_ptr<TimestampParser> parser;
+ TimeUnit::type unit;
+};
+
+template <typename InputType>
+Status StrptimeExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ applicator::ScalarUnaryNotNullStateful<TimestampType, InputType, ParseStrptime> kernel{
+ ParseStrptime(StrptimeState::Get(ctx))};
+ return kernel.Exec(ctx, batch, out);
+}
+
+Result<ValueDescr> StrptimeResolve(KernelContext* ctx, const std::vector<ValueDescr>&) {
+ if (ctx->state()) {
+ return ::arrow::timestamp(StrptimeState::Get(ctx).unit);
+ }
+
+ return Status::Invalid("strptime does not provide default StrptimeOptions");
+}
+
+// ----------------------------------------------------------------------
+// string padding
+
+template <bool PadLeft, bool PadRight>
+struct AsciiPadTransform : public StringTransformBase {
+ using State = OptionsWrapper<PadOptions>;
+
+ const PadOptions& options_;
+
+ explicit AsciiPadTransform(const PadOptions& options) : options_(options) {}
+
+ Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) override {
+ if (options_.padding.size() != 1) {
+ return Status::Invalid("Padding must be one byte, got '", options_.padding, "'");
+ }
+ return Status::OK();
+ }
+
+ int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override {
+ // This is likely very overallocated but hard to do better without
+ // actually looking at each string (because of strings that may be
+ // longer than the given width)
+ return input_ncodeunits + ninputs * options_.width;
+ }
+
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ if (input_string_ncodeunits >= options_.width) {
+ std::copy(input, input + input_string_ncodeunits, output);
+ return input_string_ncodeunits;
+ }
+ const int64_t spaces = options_.width - input_string_ncodeunits;
+ int64_t left = 0;
+ int64_t right = 0;
+ if (PadLeft && PadRight) {
+ // If odd number of spaces, put the extra space on the right
+ left = spaces / 2;
+ right = spaces - left;
+ } else if (PadLeft) {
+ left = spaces;
+ } else if (PadRight) {
+ right = spaces;
+ } else {
+ DCHECK(false) << "unreachable";
+ return 0;
+ }
+ std::fill(output, output + left, options_.padding[0]);
+ output += left;
+ output = std::copy(input, input + input_string_ncodeunits, output);
+ std::fill(output, output + right, options_.padding[0]);
+ return options_.width;
+ }
+};
+
+template <bool PadLeft, bool PadRight>
+struct Utf8PadTransform : public StringTransformBase {
+ using State = OptionsWrapper<PadOptions>;
+
+ const PadOptions& options_;
+
+ explicit Utf8PadTransform(const PadOptions& options) : options_(options) {}
+
+ Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) override {
+ auto str = reinterpret_cast<const uint8_t*>(options_.padding.data());
+ auto strlen = options_.padding.size();
+ if (util::UTF8Length(str, str + strlen) != 1) {
+ return Status::Invalid("Padding must be one codepoint, got '", options_.padding,
+ "'");
+ }
+ return Status::OK();
+ }
+
+ int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override {
+ // This is likely very overallocated but hard to do better without
+ // actually looking at each string (because of strings that may be
+ // longer than the given width)
+ // One codepoint may be up to 4 bytes
+ return input_ncodeunits + 4 * ninputs * options_.width;
+ }
+
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ const int64_t input_width = util::UTF8Length(input, input + input_string_ncodeunits);
+ if (input_width >= options_.width) {
+ std::copy(input, input + input_string_ncodeunits, output);
+ return input_string_ncodeunits;
+ }
+ const int64_t spaces = options_.width - input_width;
+ int64_t left = 0;
+ int64_t right = 0;
+ if (PadLeft && PadRight) {
+ // If odd number of spaces, put the extra space on the right
+ left = spaces / 2;
+ right = spaces - left;
+ } else if (PadLeft) {
+ left = spaces;
+ } else if (PadRight) {
+ right = spaces;
+ } else {
+ DCHECK(false) << "unreachable";
+ return 0;
+ }
+ uint8_t* start = output;
+ while (left) {
+ output = std::copy(options_.padding.begin(), options_.padding.end(), output);
+ left--;
+ }
+ output = std::copy(input, input + input_string_ncodeunits, output);
+ while (right) {
+ output = std::copy(options_.padding.begin(), options_.padding.end(), output);
+ right--;
+ }
+ return output - start;
+ }
+};
+
+template <typename Type>
+using AsciiLPad = StringTransformExecWithState<Type, AsciiPadTransform<true, false>>;
+template <typename Type>
+using AsciiRPad = StringTransformExecWithState<Type, AsciiPadTransform<false, true>>;
+template <typename Type>
+using AsciiCenter = StringTransformExecWithState<Type, AsciiPadTransform<true, true>>;
+template <typename Type>
+using Utf8LPad = StringTransformExecWithState<Type, Utf8PadTransform<true, false>>;
+template <typename Type>
+using Utf8RPad = StringTransformExecWithState<Type, Utf8PadTransform<false, true>>;
+template <typename Type>
+using Utf8Center = StringTransformExecWithState<Type, Utf8PadTransform<true, true>>;
+
+// ----------------------------------------------------------------------
+// string trimming
+
+#ifdef ARROW_WITH_UTF8PROC
+
+template <bool TrimLeft, bool TrimRight>
+struct UTF8TrimWhitespaceTransform : public StringTransformBase {
+ Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) override {
+ EnsureLookupTablesFilled();
+ return Status::OK();
+ }
+
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ const uint8_t* begin = input;
+ const uint8_t* end = input + input_string_ncodeunits;
+ const uint8_t* end_trimmed = end;
+ const uint8_t* begin_trimmed = begin;
+
+ auto predicate = [](uint32_t c) { return !IsSpaceCharacterUnicode(c); };
+ if (TrimLeft && !ARROW_PREDICT_TRUE(
+ arrow::util::UTF8FindIf(begin, end, predicate, &begin_trimmed))) {
+ return kTransformError;
+ }
+ if (TrimRight && begin_trimmed < end) {
+ if (!ARROW_PREDICT_TRUE(arrow::util::UTF8FindIfReverse(begin_trimmed, end,
+ predicate, &end_trimmed))) {
+ return kTransformError;
+ }
+ }
+ std::copy(begin_trimmed, end_trimmed, output);
+ return end_trimmed - begin_trimmed;
+ }
+};
+
+template <typename Type>
+using UTF8TrimWhitespace =
+ StringTransformExec<Type, UTF8TrimWhitespaceTransform<true, true>>;
+
+template <typename Type>
+using UTF8LTrimWhitespace =
+ StringTransformExec<Type, UTF8TrimWhitespaceTransform<true, false>>;
+
+template <typename Type>
+using UTF8RTrimWhitespace =
+ StringTransformExec<Type, UTF8TrimWhitespaceTransform<false, true>>;
+
+struct UTF8TrimState {
+ TrimOptions options_;
+ std::vector<bool> codepoints_;
+ Status status_ = Status::OK();
+
+ explicit UTF8TrimState(KernelContext* ctx, TrimOptions options)
+ : options_(std::move(options)) {
+ if (!ARROW_PREDICT_TRUE(
+ arrow::util::UTF8ForEach(options_.characters, [&](uint32_t c) {
+ codepoints_.resize(
+ std::max(c + 1, static_cast<uint32_t>(codepoints_.size())));
+ codepoints_.at(c) = true;
+ }))) {
+ status_ = Status::Invalid("Invalid UTF8 sequence in input");
+ }
+ }
+};
+
+template <bool TrimLeft, bool TrimRight>
+struct UTF8TrimTransform : public StringTransformBase {
+ using State = KernelStateFromFunctionOptions<UTF8TrimState, TrimOptions>;
+
+ const UTF8TrimState& state_;
+
+ explicit UTF8TrimTransform(const UTF8TrimState& state) : state_(state) {}
+
+ Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) override {
+ return state_.status_;
+ }
+
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ const uint8_t* begin = input;
+ const uint8_t* end = input + input_string_ncodeunits;
+ const uint8_t* end_trimmed = end;
+ const uint8_t* begin_trimmed = begin;
+ const auto& codepoints = state_.codepoints_;
+
+ auto predicate = [&](uint32_t c) { return c >= codepoints.size() || !codepoints[c]; };
+ if (TrimLeft && !ARROW_PREDICT_TRUE(
+ arrow::util::UTF8FindIf(begin, end, predicate, &begin_trimmed))) {
+ return kTransformError;
+ }
+ if (TrimRight && begin_trimmed < end) {
+ if (!ARROW_PREDICT_TRUE(arrow::util::UTF8FindIfReverse(begin_trimmed, end,
+ predicate, &end_trimmed))) {
+ return kTransformError;
+ }
+ }
+ std::copy(begin_trimmed, end_trimmed, output);
+ return end_trimmed - begin_trimmed;
+ }
+};
+
+template <typename Type>
+using UTF8Trim = StringTransformExecWithState<Type, UTF8TrimTransform<true, true>>;
+
+template <typename Type>
+using UTF8LTrim = StringTransformExecWithState<Type, UTF8TrimTransform<true, false>>;
+
+template <typename Type>
+using UTF8RTrim = StringTransformExecWithState<Type, UTF8TrimTransform<false, true>>;
+
+#endif
+
+template <bool TrimLeft, bool TrimRight>
+struct AsciiTrimWhitespaceTransform : public StringTransformBase {
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ const uint8_t* begin = input;
+ const uint8_t* end = input + input_string_ncodeunits;
+ const uint8_t* end_trimmed = end;
+ const uint8_t* begin_trimmed = begin;
+
+ auto predicate = [](unsigned char c) { return !IsSpaceCharacterAscii(c); };
+ if (TrimLeft) {
+ begin_trimmed = std::find_if(begin, end, predicate);
+ }
+ if (TrimRight && begin_trimmed < end) {
+ std::reverse_iterator<const uint8_t*> rbegin(end);
+ std::reverse_iterator<const uint8_t*> rend(begin_trimmed);
+ end_trimmed = std::find_if(rbegin, rend, predicate).base();
+ }
+ std::copy(begin_trimmed, end_trimmed, output);
+ return end_trimmed - begin_trimmed;
+ }
+};
+
+template <typename Type>
+using AsciiTrimWhitespace =
+ StringTransformExec<Type, AsciiTrimWhitespaceTransform<true, true>>;
+
+template <typename Type>
+using AsciiLTrimWhitespace =
+ StringTransformExec<Type, AsciiTrimWhitespaceTransform<true, false>>;
+
+template <typename Type>
+using AsciiRTrimWhitespace =
+ StringTransformExec<Type, AsciiTrimWhitespaceTransform<false, true>>;
+
+struct AsciiTrimState {
+ TrimOptions options_;
+ std::vector<bool> characters_;
+
+ explicit AsciiTrimState(KernelContext* ctx, TrimOptions options)
+ : options_(std::move(options)), characters_(256) {
+ for (const auto c : options_.characters) {
+ characters_[static_cast<unsigned char>(c)] = true;
+ }
+ }
+};
+
+template <bool TrimLeft, bool TrimRight>
+struct AsciiTrimTransform : public StringTransformBase {
+ using State = KernelStateFromFunctionOptions<AsciiTrimState, TrimOptions>;
+
+ const AsciiTrimState& state_;
+
+ explicit AsciiTrimTransform(const AsciiTrimState& state) : state_(state) {}
+
+ int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
+ uint8_t* output) {
+ const uint8_t* begin = input;
+ const uint8_t* end = input + input_string_ncodeunits;
+ const uint8_t* end_trimmed = end;
+ const uint8_t* begin_trimmed = begin;
+ const auto& characters = state_.characters_;
+
+ auto predicate = [&](uint8_t c) { return !characters[c]; };
+ if (TrimLeft) {
+ begin_trimmed = std::find_if(begin, end, predicate);
+ }
+ if (TrimRight && begin_trimmed < end) {
+ std::reverse_iterator<const uint8_t*> rbegin(end);
+ std::reverse_iterator<const uint8_t*> rend(begin_trimmed);
+ end_trimmed = std::find_if(rbegin, rend, predicate).base();
+ }
+ std::copy(begin_trimmed, end_trimmed, output);
+ return end_trimmed - begin_trimmed;
+ }
+};
+
+template <typename Type>
+using AsciiTrim = StringTransformExecWithState<Type, AsciiTrimTransform<true, true>>;
+
+template <typename Type>
+using AsciiLTrim = StringTransformExecWithState<Type, AsciiTrimTransform<true, false>>;
+
+template <typename Type>
+using AsciiRTrim = StringTransformExecWithState<Type, AsciiTrimTransform<false, true>>;
+
+const FunctionDoc utf8_center_doc(
+ "Center strings by padding with a given character",
+ ("For each string in `strings`, emit a centered string by padding both sides \n"
+ "with the given UTF8 codeunit.\nNull values emit null."),
+ {"strings"}, "PadOptions");
+
+const FunctionDoc utf8_lpad_doc(
+ "Right-align strings by padding with a given character",
+ ("For each string in `strings`, emit a right-aligned string by prepending \n"
+ "the given UTF8 codeunit.\nNull values emit null."),
+ {"strings"}, "PadOptions");
+
+const FunctionDoc utf8_rpad_doc(
+ "Left-align strings by padding with a given character",
+ ("For each string in `strings`, emit a left-aligned string by appending \n"
+ "the given UTF8 codeunit.\nNull values emit null."),
+ {"strings"}, "PadOptions");
+
+const FunctionDoc ascii_center_doc(
+ utf8_center_doc.description + "",
+ ("For each string in `strings`, emit a centered string by padding both sides \n"
+ "with the given ASCII character.\nNull values emit null."),
+ {"strings"}, "PadOptions");
+
+const FunctionDoc ascii_lpad_doc(
+ utf8_lpad_doc.description + "",
+ ("For each string in `strings`, emit a right-aligned string by prepending \n"
+ "the given ASCII character.\nNull values emit null."),
+ {"strings"}, "PadOptions");
+
+const FunctionDoc ascii_rpad_doc(
+ utf8_rpad_doc.description + "",
+ ("For each string in `strings`, emit a left-aligned string by appending \n"
+ "the given ASCII character.\nNull values emit null."),
+ {"strings"}, "PadOptions");
+
+const FunctionDoc utf8_trim_whitespace_doc(
+ "Trim leading and trailing whitespace characters",
+ ("For each string in `strings`, emit a string with leading and trailing whitespace\n"
+ "characters removed, where whitespace characters are defined by the Unicode\n"
+ "standard. Null values emit null."),
+ {"strings"});
+
+const FunctionDoc utf8_ltrim_whitespace_doc(
+ "Trim leading whitespace characters",
+ ("For each string in `strings`, emit a string with leading whitespace\n"
+ "characters removed, where whitespace characters are defined by the Unicode\n"
+ "standard. Null values emit null."),
+ {"strings"});
+
+const FunctionDoc utf8_rtrim_whitespace_doc(
+ "Trim trailing whitespace characters",
+ ("For each string in `strings`, emit a string with trailing whitespace\n"
+ "characters removed, where whitespace characters are defined by the Unicode\n"
+ "standard. Null values emit null."),
+ {"strings"});
+
+const FunctionDoc ascii_trim_whitespace_doc(
+ "Trim leading and trailing ASCII whitespace characters",
+ ("For each string in `strings`, emit a string with leading and trailing ASCII\n"
+ "whitespace characters removed. Use `utf8_trim_whitespace` to trim Unicode\n"
+ "whitespace characters. Null values emit null."),
+ {"strings"});
+
+const FunctionDoc ascii_ltrim_whitespace_doc(
+ "Trim leading ASCII whitespace characters",
+ ("For each string in `strings`, emit a string with leading ASCII whitespace\n"
+ "characters removed. Use `utf8_ltrim_whitespace` to trim leading Unicode\n"
+ "whitespace characters. Null values emit null."),
+ {"strings"});
+
+const FunctionDoc ascii_rtrim_whitespace_doc(
+ "Trim trailing ASCII whitespace characters",
+ ("For each string in `strings`, emit a string with trailing ASCII whitespace\n"
+ "characters removed. Use `utf8_rtrim_whitespace` to trim trailing Unicode\n"
+ "whitespace characters. Null values emit null."),
+ {"strings"});
+
+const FunctionDoc utf8_trim_doc(
+ "Trim leading and trailing characters present in the `characters` arguments",
+ ("For each string in `strings`, emit a string with leading and trailing\n"
+ "characters removed that are present in the `characters` argument. Null values\n"
+ "emit null."),
+ {"strings"}, "TrimOptions");
+
+const FunctionDoc utf8_ltrim_doc(
+ "Trim leading characters present in the `characters` arguments",
+ ("For each string in `strings`, emit a string with leading\n"
+ "characters removed that are present in the `characters` argument. Null values\n"
+ "emit null."),
+ {"strings"}, "TrimOptions");
+
+const FunctionDoc utf8_rtrim_doc(
+ "Trim trailing characters present in the `characters` arguments",
+ ("For each string in `strings`, emit a string with leading "
+ "characters removed that are present in the `characters` argument. Null values\n"
+ "emit null."),
+ {"strings"}, "TrimOptions");
+
+const FunctionDoc ascii_trim_doc(
+ utf8_trim_doc.summary + "",
+ utf8_trim_doc.description +
+ ("\nBoth the input string as the `characters` argument are interepreted as\n"
+ "ASCII characters, to trim non-ASCII characters, use `utf8_trim`."),
+ {"strings"}, "TrimOptions");
+
+const FunctionDoc ascii_ltrim_doc(
+ utf8_ltrim_doc.summary + "",
+ utf8_ltrim_doc.description +
+ ("\nBoth the input string as the `characters` argument are interepreted as\n"
+ "ASCII characters, to trim non-ASCII characters, use `utf8_trim`."),
+ {"strings"}, "TrimOptions");
+
+const FunctionDoc ascii_rtrim_doc(
+ utf8_rtrim_doc.summary + "",
+ utf8_rtrim_doc.description +
+ ("\nBoth the input string as the `characters` argument are interepreted as\n"
+ "ASCII characters, to trim non-ASCII characters, use `utf8_trim`."),
+ {"strings"}, "TrimOptions");
+
+const FunctionDoc strptime_doc(
+ "Parse timestamps",
+ ("For each string in `strings`, parse it as a timestamp.\n"
+ "The timestamp unit and the expected string pattern must be given\n"
+ "in StrptimeOptions. Null inputs emit null. If a non-null string\n"
+ "fails parsing, an error is returned."),
+ {"strings"}, "StrptimeOptions");
+
+const FunctionDoc binary_length_doc(
+ "Compute string lengths",
+ ("For each string in `strings`, emit the number of bytes. Null values emit null."),
+ {"strings"});
+
+const FunctionDoc utf8_length_doc("Compute UTF8 string lengths",
+ ("For each string in `strings`, emit the number of "
+ "UTF8 characters. Null values emit null."),
+ {"strings"});
+
+void AddStrptime(FunctionRegistry* registry) {
+ auto func = std::make_shared<ScalarFunction>("strptime", Arity::Unary(), &strptime_doc);
+ DCHECK_OK(func->AddKernel({utf8()}, OutputType(StrptimeResolve),
+ StrptimeExec<StringType>, StrptimeState::Init));
+ DCHECK_OK(func->AddKernel({large_utf8()}, OutputType(StrptimeResolve),
+ StrptimeExec<LargeStringType>, StrptimeState::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+void AddBinaryLength(FunctionRegistry* registry) {
+ auto func = std::make_shared<ScalarFunction>("binary_length", Arity::Unary(),
+ &binary_length_doc);
+ ArrayKernelExec exec_offset_32 =
+ applicator::ScalarUnaryNotNull<Int32Type, StringType, BinaryLength>::Exec;
+ ArrayKernelExec exec_offset_64 =
+ applicator::ScalarUnaryNotNull<Int64Type, LargeStringType, BinaryLength>::Exec;
+ for (const auto& input_type : {binary(), utf8()}) {
+ DCHECK_OK(func->AddKernel({input_type}, int32(), exec_offset_32));
+ }
+ for (const auto& input_type : {large_binary(), large_utf8()}) {
+ DCHECK_OK(func->AddKernel({input_type}, int64(), exec_offset_64));
+ }
+ DCHECK_OK(func->AddKernel({InputType(Type::FIXED_SIZE_BINARY)}, int32(),
+ BinaryLength::FixedSizeExec));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+void AddUtf8Length(FunctionRegistry* registry) {
+ auto func =
+ std::make_shared<ScalarFunction>("utf8_length", Arity::Unary(), &utf8_length_doc);
+
+ ArrayKernelExec exec_offset_32 =
+ applicator::ScalarUnaryNotNull<Int32Type, StringType, Utf8Length>::Exec;
+ DCHECK_OK(func->AddKernel({utf8()}, int32(), std::move(exec_offset_32)));
+
+ ArrayKernelExec exec_offset_64 =
+ applicator::ScalarUnaryNotNull<Int64Type, LargeStringType, Utf8Length>::Exec;
+ DCHECK_OK(func->AddKernel({large_utf8()}, int64(), std::move(exec_offset_64)));
+
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+template <typename BinaryType, typename ListType>
+struct BinaryJoin {
+ using ArrayType = typename TypeTraits<BinaryType>::ArrayType;
+ using ListArrayType = typename TypeTraits<ListType>::ArrayType;
+ using ListScalarType = typename TypeTraits<ListType>::ScalarType;
+ using ListOffsetType = typename ListArrayType::offset_type;
+ using BuilderType = typename TypeTraits<BinaryType>::BuilderType;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].kind() == Datum::SCALAR) {
+ if (batch[1].kind() == Datum::SCALAR) {
+ return ExecScalarScalar(ctx, *batch[0].scalar(), *batch[1].scalar(), out);
+ }
+ DCHECK_EQ(batch[1].kind(), Datum::ARRAY);
+ return ExecScalarArray(ctx, *batch[0].scalar(), batch[1].array(), out);
+ }
+ DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
+ if (batch[1].kind() == Datum::SCALAR) {
+ return ExecArrayScalar(ctx, batch[0].array(), *batch[1].scalar(), out);
+ }
+ DCHECK_EQ(batch[1].kind(), Datum::ARRAY);
+ return ExecArrayArray(ctx, batch[0].array(), batch[1].array(), out);
+ }
+
+ struct ListScalarOffsetLookup {
+ const ArrayType& values;
+
+ int64_t GetStart(int64_t i) { return 0; }
+ int64_t GetStop(int64_t i) { return values.length(); }
+ bool IsNull(int64_t i) { return false; }
+ };
+
+ struct ListArrayOffsetLookup {
+ explicit ListArrayOffsetLookup(const ListArrayType& lists)
+ : lists_(lists), offsets_(lists.raw_value_offsets()) {}
+
+ int64_t GetStart(int64_t i) { return offsets_[i]; }
+ int64_t GetStop(int64_t i) { return offsets_[i + 1]; }
+ bool IsNull(int64_t i) { return lists_.IsNull(i); }
+
+ private:
+ const ListArrayType& lists_;
+ const ListOffsetType* offsets_;
+ };
+
+ struct SeparatorScalarLookup {
+ const util::string_view separator;
+
+ bool IsNull(int64_t i) { return false; }
+ util::string_view GetView(int64_t i) { return separator; }
+ };
+
+ struct SeparatorArrayLookup {
+ const ArrayType& separators;
+
+ bool IsNull(int64_t i) { return separators.IsNull(i); }
+ util::string_view GetView(int64_t i) { return separators.GetView(i); }
+ };
+
+ // Scalar, scalar -> scalar
+ static Status ExecScalarScalar(KernelContext* ctx, const Scalar& left,
+ const Scalar& right, Datum* out) {
+ const auto& list = checked_cast<const ListScalarType&>(left);
+ const auto& separator_scalar = checked_cast<const BaseBinaryScalar&>(right);
+ if (!list.is_valid || !separator_scalar.is_valid) {
+ return Status::OK();
+ }
+ util::string_view separator(*separator_scalar.value);
+
+ const auto& strings = checked_cast<const ArrayType&>(*list.value);
+ if (strings.null_count() > 0) {
+ out->scalar()->is_valid = false;
+ return Status::OK();
+ }
+
+ TypedBufferBuilder<uint8_t> builder(ctx->memory_pool());
+ auto Append = [&](util::string_view value) {
+ return builder.Append(reinterpret_cast<const uint8_t*>(value.data()),
+ static_cast<int64_t>(value.size()));
+ };
+ if (strings.length() > 0) {
+ auto data_length =
+ strings.total_values_length() + (strings.length() - 1) * separator.length();
+ RETURN_NOT_OK(builder.Reserve(data_length));
+ RETURN_NOT_OK(Append(strings.GetView(0)));
+ for (int64_t j = 1; j < strings.length(); j++) {
+ RETURN_NOT_OK(Append(separator));
+ RETURN_NOT_OK(Append(strings.GetView(j)));
+ }
+ }
+ auto out_scalar = checked_cast<BaseBinaryScalar*>(out->scalar().get());
+ return builder.Finish(&out_scalar->value);
+ }
+
+ // Scalar, array -> array
+ static Status ExecScalarArray(KernelContext* ctx, const Scalar& left,
+ const std::shared_ptr<ArrayData>& right, Datum* out) {
+ const auto& list_scalar = checked_cast<const BaseListScalar&>(left);
+ if (!list_scalar.is_valid) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto nulls, MakeArrayOfNull(right->type, right->length, ctx->memory_pool()));
+ *out = *nulls->data();
+ return Status::OK();
+ }
+ const auto& strings = checked_cast<const ArrayType&>(*list_scalar.value);
+ if (strings.null_count() != 0) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto nulls, MakeArrayOfNull(right->type, right->length, ctx->memory_pool()));
+ *out = *nulls->data();
+ return Status::OK();
+ }
+ const ArrayType separators(right);
+
+ BuilderType builder(ctx->memory_pool());
+ RETURN_NOT_OK(builder.Reserve(separators.length()));
+
+ // Presize data to avoid multiple reallocations when joining strings
+ int64_t total_data_length = 0;
+ const int64_t list_length = strings.length();
+ if (list_length) {
+ const int64_t string_length = strings.total_values_length();
+ total_data_length +=
+ string_length * (separators.length() - separators.null_count());
+ for (int64_t i = 0; i < separators.length(); ++i) {
+ if (separators.IsNull(i)) {
+ continue;
+ }
+ total_data_length += (list_length - 1) * separators.value_length(i);
+ }
+ }
+ RETURN_NOT_OK(builder.ReserveData(total_data_length));
+
+ return JoinStrings(separators.length(), strings, ListScalarOffsetLookup{strings},
+ SeparatorArrayLookup{separators}, &builder, out);
+ }
+
+ // Array, scalar -> array
+ static Status ExecArrayScalar(KernelContext* ctx,
+ const std::shared_ptr<ArrayData>& left,
+ const Scalar& right, Datum* out) {
+ const ListArrayType lists(left);
+ const auto& separator_scalar = checked_cast<const BaseBinaryScalar&>(right);
+
+ if (!separator_scalar.is_valid) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto nulls,
+ MakeArrayOfNull(lists.value_type(), lists.length(), ctx->memory_pool()));
+ *out = *nulls->data();
+ return Status::OK();
+ }
+
+ util::string_view separator(*separator_scalar.value);
+ const auto& strings = checked_cast<const ArrayType&>(*lists.values());
+ const auto list_offsets = lists.raw_value_offsets();
+
+ BuilderType builder(ctx->memory_pool());
+ RETURN_NOT_OK(builder.Reserve(lists.length()));
+
+ // Presize data to avoid multiple reallocations when joining strings
+ int64_t total_data_length = strings.total_values_length();
+ for (int64_t i = 0; i < lists.length(); ++i) {
+ const auto start = list_offsets[i], end = list_offsets[i + 1];
+ if (end > start && !ValuesContainNull(strings, start, end)) {
+ total_data_length += (end - start - 1) * separator.length();
+ }
+ }
+ RETURN_NOT_OK(builder.ReserveData(total_data_length));
+
+ return JoinStrings(lists.length(), strings, ListArrayOffsetLookup{lists},
+ SeparatorScalarLookup{separator}, &builder, out);
+ }
+
+ // Array, array -> array
+ static Status ExecArrayArray(KernelContext* ctx, const std::shared_ptr<ArrayData>& left,
+ const std::shared_ptr<ArrayData>& right, Datum* out) {
+ const ListArrayType lists(left);
+ const auto& strings = checked_cast<const ArrayType&>(*lists.values());
+ const auto list_offsets = lists.raw_value_offsets();
+ const auto string_offsets = strings.raw_value_offsets();
+ const ArrayType separators(right);
+
+ BuilderType builder(ctx->memory_pool());
+ RETURN_NOT_OK(builder.Reserve(lists.length()));
+
+ // Presize data to avoid multiple reallocations when joining strings
+ int64_t total_data_length = 0;
+ for (int64_t i = 0; i < lists.length(); ++i) {
+ if (separators.IsNull(i)) {
+ continue;
+ }
+ const auto start = list_offsets[i], end = list_offsets[i + 1];
+ if (end > start && !ValuesContainNull(strings, start, end)) {
+ total_data_length += string_offsets[end] - string_offsets[start];
+ total_data_length += (end - start - 1) * separators.value_length(i);
+ }
+ }
+ RETURN_NOT_OK(builder.ReserveData(total_data_length));
+
+ struct SeparatorLookup {
+ const ArrayType& separators;
+
+ bool IsNull(int64_t i) { return separators.IsNull(i); }
+ util::string_view GetView(int64_t i) { return separators.GetView(i); }
+ };
+ return JoinStrings(lists.length(), strings, ListArrayOffsetLookup{lists},
+ SeparatorArrayLookup{separators}, &builder, out);
+ }
+
+ template <typename ListOffsetLookup, typename SeparatorLookup>
+ static Status JoinStrings(int64_t length, const ArrayType& strings,
+ ListOffsetLookup&& list_offsets, SeparatorLookup&& separators,
+ BuilderType* builder, Datum* out) {
+ for (int64_t i = 0; i < length; ++i) {
+ if (list_offsets.IsNull(i) || separators.IsNull(i)) {
+ builder->UnsafeAppendNull();
+ continue;
+ }
+ const auto j_start = list_offsets.GetStart(i), j_end = list_offsets.GetStop(i);
+ if (j_start == j_end) {
+ builder->UnsafeAppendEmptyValue();
+ continue;
+ }
+ if (ValuesContainNull(strings, j_start, j_end)) {
+ builder->UnsafeAppendNull();
+ continue;
+ }
+ builder->UnsafeAppend(strings.GetView(j_start));
+ for (int64_t j = j_start + 1; j < j_end; ++j) {
+ builder->UnsafeExtendCurrent(separators.GetView(i));
+ builder->UnsafeExtendCurrent(strings.GetView(j));
+ }
+ }
+
+ std::shared_ptr<Array> string_array;
+ RETURN_NOT_OK(builder->Finish(&string_array));
+ *out = *string_array->data();
+ // Correct the output type based on the input
+ out->mutable_array()->type = strings.type();
+ return Status::OK();
+ }
+
+ static bool ValuesContainNull(const ArrayType& values, int64_t start, int64_t end) {
+ if (values.null_count() == 0) {
+ return false;
+ }
+ for (int64_t i = start; i < end; ++i) {
+ if (values.IsNull(i)) {
+ return true;
+ }
+ }
+ return false;
+ }
+};
+
+using BinaryJoinElementWiseState = OptionsWrapper<JoinOptions>;
+
+template <typename Type>
+struct BinaryJoinElementWise {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ using BuilderType = typename TypeTraits<Type>::BuilderType;
+ using offset_type = typename Type::offset_type;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ JoinOptions options = BinaryJoinElementWiseState::Get(ctx);
+ // Last argument is the separator (for consistency with binary_join)
+ if (std::all_of(batch.values.begin(), batch.values.end(),
+ [](const Datum& d) { return d.is_scalar(); })) {
+ return ExecOnlyScalar(ctx, options, batch, out);
+ }
+ return ExecContainingArrays(ctx, options, batch, out);
+ }
+
+ static Status ExecOnlyScalar(KernelContext* ctx, const JoinOptions& options,
+ const ExecBatch& batch, Datum* out) {
+ BaseBinaryScalar* output = checked_cast<BaseBinaryScalar*>(out->scalar().get());
+ const size_t num_args = batch.values.size();
+ if (num_args == 1) {
+ // Only separator, no values
+ output->is_valid = batch.values[0].scalar()->is_valid;
+ if (output->is_valid) {
+ ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(0));
+ }
+ return Status::OK();
+ }
+
+ int64_t final_size = CalculateRowSize(options, batch, 0);
+ if (final_size < 0) {
+ output->is_valid = false;
+ return Status::OK();
+ }
+ ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(final_size));
+ const auto separator = UnboxScalar<Type>::Unbox(*batch.values.back().scalar());
+ uint8_t* buf = output->value->mutable_data();
+ bool first = true;
+ for (size_t i = 0; i < num_args - 1; i++) {
+ const Scalar& scalar = *batch[i].scalar();
+ util::string_view s;
+ if (scalar.is_valid) {
+ s = UnboxScalar<Type>::Unbox(scalar);
+ } else {
+ switch (options.null_handling) {
+ case JoinOptions::EMIT_NULL:
+ // Handled by CalculateRowSize
+ DCHECK(false) << "unreachable";
+ break;
+ case JoinOptions::SKIP:
+ continue;
+ case JoinOptions::REPLACE:
+ s = options.null_replacement;
+ break;
+ }
+ }
+ if (!first) {
+ buf = std::copy(separator.begin(), separator.end(), buf);
+ }
+ first = false;
+ buf = std::copy(s.begin(), s.end(), buf);
+ }
+ output->is_valid = true;
+ DCHECK_EQ(final_size, buf - output->value->mutable_data());
+ return Status::OK();
+ }
+
+ static Status ExecContainingArrays(KernelContext* ctx, const JoinOptions& options,
+ const ExecBatch& batch, Datum* out) {
+ // Presize data to avoid reallocations
+ int64_t final_size = 0;
+ for (int64_t i = 0; i < batch.length; i++) {
+ auto size = CalculateRowSize(options, batch, i);
+ if (size > 0) final_size += size;
+ }
+ BuilderType builder(ctx->memory_pool());
+ RETURN_NOT_OK(builder.Reserve(batch.length));
+ RETURN_NOT_OK(builder.ReserveData(final_size));
+
+ std::vector<util::string_view> valid_cols(batch.values.size());
+ for (size_t row = 0; row < static_cast<size_t>(batch.length); row++) {
+ size_t num_valid = 0; // Not counting separator
+ for (size_t col = 0; col < batch.values.size(); col++) {
+ if (batch[col].is_scalar()) {
+ const auto& scalar = *batch[col].scalar();
+ if (scalar.is_valid) {
+ valid_cols[col] = UnboxScalar<Type>::Unbox(scalar);
+ if (col < batch.values.size() - 1) num_valid++;
+ } else {
+ valid_cols[col] = util::string_view();
+ }
+ } else {
+ const ArrayData& array = *batch[col].array();
+ if (!array.MayHaveNulls() ||
+ BitUtil::GetBit(array.buffers[0]->data(), array.offset + row)) {
+ const offset_type* offsets = array.GetValues<offset_type>(1);
+ const uint8_t* data = array.GetValues<uint8_t>(2, /*absolute_offset=*/0);
+ const int64_t length = offsets[row + 1] - offsets[row];
+ valid_cols[col] = util::string_view(
+ reinterpret_cast<const char*>(data + offsets[row]), length);
+ if (col < batch.values.size() - 1) num_valid++;
+ } else {
+ valid_cols[col] = util::string_view();
+ }
+ }
+ }
+
+ if (!valid_cols.back().data()) {
+ // Separator is null
+ builder.UnsafeAppendNull();
+ continue;
+ } else if (batch.values.size() == 1) {
+ // Only given separator
+ builder.UnsafeAppendEmptyValue();
+ continue;
+ } else if (num_valid < batch.values.size() - 1) {
+ // We had some nulls
+ if (options.null_handling == JoinOptions::EMIT_NULL) {
+ builder.UnsafeAppendNull();
+ continue;
+ }
+ }
+ const auto separator = valid_cols.back();
+ bool first = true;
+ for (size_t col = 0; col < batch.values.size() - 1; col++) {
+ util::string_view value = valid_cols[col];
+ if (!value.data()) {
+ switch (options.null_handling) {
+ case JoinOptions::EMIT_NULL:
+ DCHECK(false) << "unreachable";
+ break;
+ case JoinOptions::SKIP:
+ continue;
+ case JoinOptions::REPLACE:
+ value = options.null_replacement;
+ break;
+ }
+ }
+ if (first) {
+ builder.UnsafeAppend(value);
+ first = false;
+ continue;
+ }
+ builder.UnsafeExtendCurrent(separator);
+ builder.UnsafeExtendCurrent(value);
+ }
+ }
+
+ std::shared_ptr<Array> string_array;
+ RETURN_NOT_OK(builder.Finish(&string_array));
+ *out = *string_array->data();
+ out->mutable_array()->type = batch[0].type();
+ DCHECK_EQ(batch.length, out->array()->length);
+ DCHECK_EQ(final_size,
+ checked_cast<const ArrayType&>(*string_array).total_values_length());
+ return Status::OK();
+ }
+
+ // Compute the length of the output for the given position, or -1 if it would be null.
+ static int64_t CalculateRowSize(const JoinOptions& options, const ExecBatch& batch,
+ const int64_t index) {
+ const auto num_args = batch.values.size();
+ int64_t final_size = 0;
+ int64_t num_non_null_args = 0;
+ for (size_t i = 0; i < num_args; i++) {
+ int64_t element_size = 0;
+ bool valid = true;
+ if (batch[i].is_scalar()) {
+ const Scalar& scalar = *batch[i].scalar();
+ valid = scalar.is_valid;
+ element_size = UnboxScalar<Type>::Unbox(scalar).size();
+ } else {
+ const ArrayData& array = *batch[i].array();
+ valid = !array.MayHaveNulls() ||
+ BitUtil::GetBit(array.buffers[0]->data(), array.offset + index);
+ const offset_type* offsets = array.GetValues<offset_type>(1);
+ element_size = offsets[index + 1] - offsets[index];
+ }
+ if (i == num_args - 1) {
+ if (!valid) return -1;
+ if (num_non_null_args > 1) {
+ // Add separator size (only if there were values to join)
+ final_size += (num_non_null_args - 1) * element_size;
+ }
+ break;
+ }
+ if (!valid) {
+ switch (options.null_handling) {
+ case JoinOptions::EMIT_NULL:
+ return -1;
+ case JoinOptions::SKIP:
+ continue;
+ case JoinOptions::REPLACE:
+ element_size = options.null_replacement.size();
+ break;
+ }
+ }
+ num_non_null_args++;
+ final_size += element_size;
+ }
+ return final_size;
+ }
+};
+
+const FunctionDoc binary_join_doc(
+ "Join a list of strings together with a `separator` to form a single string",
+ ("Insert `separator` between `list` elements, and concatenate them.\n"
+ "Any null input and any null `list` element emits a null output.\n"),
+ {"list", "separator"});
+
+const FunctionDoc binary_join_element_wise_doc(
+ "Join string arguments into one, using the last argument as the separator",
+ ("Insert the last argument of `strings` between the rest of the elements, "
+ "and concatenate them.\n"
+ "Any null separator element emits a null output. Null elements either "
+ "emit a null (the default), are skipped, or replaced with a given string.\n"),
+ {"*strings"}, "JoinOptions");
+
+const auto kDefaultJoinOptions = JoinOptions::Defaults();
+
+template <typename ListType>
+void AddBinaryJoinForListType(ScalarFunction* func) {
+ for (const std::shared_ptr<DataType>& ty : BaseBinaryTypes()) {
+ auto exec = GenerateTypeAgnosticVarBinaryBase<BinaryJoin, ListType>(*ty);
+ auto list_ty = std::make_shared<ListType>(ty);
+ DCHECK_OK(func->AddKernel({InputType(list_ty), InputType(ty)}, ty, exec));
+ }
+}
+
+void AddBinaryJoin(FunctionRegistry* registry) {
+ {
+ auto func = std::make_shared<ScalarFunction>("binary_join", Arity::Binary(),
+ &binary_join_doc);
+ AddBinaryJoinForListType<ListType>(func.get());
+ AddBinaryJoinForListType<LargeListType>(func.get());
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+ {
+ auto func = std::make_shared<ScalarFunction>(
+ "binary_join_element_wise", Arity::VarArgs(/*min_args=*/1),
+ &binary_join_element_wise_doc, &kDefaultJoinOptions);
+ for (const auto& ty : BaseBinaryTypes()) {
+ ScalarKernel kernel{KernelSignature::Make({InputType(ty)}, ty, /*is_varargs=*/true),
+ GenerateTypeAgnosticVarBinaryBase<BinaryJoinElementWise>(ty),
+ BinaryJoinElementWiseState::Init};
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+ }
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+}
+
+template <template <typename> class ExecFunctor>
+void MakeUnaryStringBatchKernel(
+ std::string name, FunctionRegistry* registry, const FunctionDoc* doc,
+ MemAllocation::type mem_allocation = MemAllocation::PREALLOCATE) {
+ auto func = std::make_shared<ScalarFunction>(name, Arity::Unary(), doc);
+ {
+ auto exec_32 = ExecFunctor<StringType>::Exec;
+ ScalarKernel kernel{{utf8()}, utf8(), exec_32};
+ kernel.mem_allocation = mem_allocation;
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+ }
+ {
+ auto exec_64 = ExecFunctor<LargeStringType>::Exec;
+ ScalarKernel kernel{{large_utf8()}, large_utf8(), exec_64};
+ kernel.mem_allocation = mem_allocation;
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+ }
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+template <template <typename> class ExecFunctor>
+void MakeUnaryStringBatchKernelWithState(
+ std::string name, FunctionRegistry* registry, const FunctionDoc* doc,
+ MemAllocation::type mem_allocation = MemAllocation::PREALLOCATE) {
+ auto func = std::make_shared<ScalarFunction>(name, Arity::Unary(), doc);
+ {
+ using t32 = ExecFunctor<StringType>;
+ ScalarKernel kernel{{utf8()}, utf8(), t32::Exec, t32::State::Init};
+ kernel.mem_allocation = mem_allocation;
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+ }
+ {
+ using t64 = ExecFunctor<LargeStringType>;
+ ScalarKernel kernel{{large_utf8()}, large_utf8(), t64::Exec, t64::State::Init};
+ kernel.mem_allocation = mem_allocation;
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+ }
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+#ifdef ARROW_WITH_UTF8PROC
+
+template <template <typename> class Transformer>
+void MakeUnaryStringUTF8TransformKernel(std::string name, FunctionRegistry* registry,
+ const FunctionDoc* doc) {
+ auto func = std::make_shared<ScalarFunction>(name, Arity::Unary(), doc);
+ ArrayKernelExec exec_32 = Transformer<StringType>::Exec;
+ ArrayKernelExec exec_64 = Transformer<LargeStringType>::Exec;
+ DCHECK_OK(func->AddKernel({utf8()}, utf8(), exec_32));
+ DCHECK_OK(func->AddKernel({large_utf8()}, large_utf8(), exec_64));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+#endif
+
+// NOTE: Predicate should only populate 'status' with errors,
+// leave it unmodified to indicate Status::OK()
+using StringPredicate =
+ std::function<bool(KernelContext*, const uint8_t*, size_t, Status*)>;
+
+template <typename Type>
+Status ApplyPredicate(KernelContext* ctx, const ExecBatch& batch,
+ StringPredicate predicate, Datum* out) {
+ Status st = Status::OK();
+ EnsureLookupTablesFilled();
+ if (batch[0].kind() == Datum::ARRAY) {
+ const ArrayData& input = *batch[0].array();
+ ArrayIterator<Type> input_it(input);
+ ArrayData* out_arr = out->mutable_array();
+ ::arrow::internal::GenerateBitsUnrolled(
+ out_arr->buffers[1]->mutable_data(), out_arr->offset, input.length,
+ [&]() -> bool {
+ util::string_view val = input_it();
+ return predicate(ctx, reinterpret_cast<const uint8_t*>(val.data()), val.size(),
+ &st);
+ });
+ } else {
+ const auto& input = checked_cast<const BaseBinaryScalar&>(*batch[0].scalar());
+ if (input.is_valid) {
+ bool boolean_result = predicate(ctx, input.value->data(),
+ static_cast<size_t>(input.value->size()), &st);
+ // UTF decoding can lead to issues
+ if (st.ok()) {
+ out->value = std::make_shared<BooleanScalar>(boolean_result);
+ }
+ }
+ }
+ return st;
+}
+
+template <typename Predicate>
+void AddUnaryStringPredicate(std::string name, FunctionRegistry* registry,
+ const FunctionDoc* doc) {
+ auto func = std::make_shared<ScalarFunction>(name, Arity::Unary(), doc);
+ auto exec_32 = [](KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return ApplyPredicate<StringType>(ctx, batch, Predicate::Call, out);
+ };
+ auto exec_64 = [](KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return ApplyPredicate<LargeStringType>(ctx, batch, Predicate::Call, out);
+ };
+ DCHECK_OK(func->AddKernel({utf8()}, boolean(), std::move(exec_32)));
+ DCHECK_OK(func->AddKernel({large_utf8()}, boolean(), std::move(exec_64)));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+FunctionDoc StringPredicateDoc(std::string summary, std::string description) {
+ return FunctionDoc{std::move(summary), std::move(description), {"strings"}};
+}
+
+FunctionDoc StringClassifyDoc(std::string class_summary, std::string class_desc,
+ bool non_empty) {
+ std::string summary, description;
+ {
+ std::stringstream ss;
+ ss << "Classify strings as " << class_summary;
+ summary = ss.str();
+ }
+ {
+ std::stringstream ss;
+ if (non_empty) {
+ ss
+ << ("For each string in `strings`, emit true iff the string is non-empty\n"
+ "and consists only of ");
+ } else {
+ ss
+ << ("For each string in `strings`, emit true iff the string consists only\n"
+ "of ");
+ }
+ ss << class_desc << ". Null strings emit null.";
+ description = ss.str();
+ }
+ return StringPredicateDoc(std::move(summary), std::move(description));
+}
+
+const auto string_is_ascii_doc = StringClassifyDoc("ASCII", "ASCII characters", false);
+
+const auto ascii_is_alnum_doc =
+ StringClassifyDoc("ASCII alphanumeric", "alphanumeric ASCII characters", true);
+const auto ascii_is_alpha_doc =
+ StringClassifyDoc("ASCII alphabetic", "alphabetic ASCII characters", true);
+const auto ascii_is_decimal_doc =
+ StringClassifyDoc("ASCII decimal", "decimal ASCII characters", true);
+const auto ascii_is_lower_doc =
+ StringClassifyDoc("ASCII lowercase", "lowercase ASCII characters", true);
+const auto ascii_is_printable_doc =
+ StringClassifyDoc("ASCII printable", "printable ASCII characters", true);
+const auto ascii_is_space_doc =
+ StringClassifyDoc("ASCII whitespace", "whitespace ASCII characters", true);
+const auto ascii_is_upper_doc =
+ StringClassifyDoc("ASCII uppercase", "uppercase ASCII characters", true);
+
+const auto ascii_is_title_doc = StringPredicateDoc(
+ "Classify strings as ASCII titlecase",
+ ("For each string in `strings`, emit true iff the string is title-cased,\n"
+ "i.e. it has at least one cased character, each uppercase character\n"
+ "follows an uncased character, and each lowercase character follows\n"
+ "an uppercase character.\n"));
+
+const auto utf8_is_alnum_doc =
+ StringClassifyDoc("alphanumeric", "alphanumeric Unicode characters", true);
+const auto utf8_is_alpha_doc =
+ StringClassifyDoc("alphabetic", "alphabetic Unicode characters", true);
+const auto utf8_is_decimal_doc =
+ StringClassifyDoc("decimal", "decimal Unicode characters", true);
+const auto utf8_is_digit_doc = StringClassifyDoc("digits", "Unicode digits", true);
+const auto utf8_is_lower_doc =
+ StringClassifyDoc("lowercase", "lowercase Unicode characters", true);
+const auto utf8_is_numeric_doc =
+ StringClassifyDoc("numeric", "numeric Unicode characters", true);
+const auto utf8_is_printable_doc =
+ StringClassifyDoc("printable", "printable Unicode characters", true);
+const auto utf8_is_space_doc =
+ StringClassifyDoc("whitespace", "whitespace Unicode characters", true);
+const auto utf8_is_upper_doc =
+ StringClassifyDoc("uppercase", "uppercase Unicode characters", true);
+
+const auto utf8_is_title_doc = StringPredicateDoc(
+ "Classify strings as titlecase",
+ ("For each string in `strings`, emit true iff the string is title-cased,\n"
+ "i.e. it has at least one cased character, each uppercase character\n"
+ "follows an uncased character, and each lowercase character follows\n"
+ "an uppercase character.\n"));
+
+const FunctionDoc ascii_upper_doc(
+ "Transform ASCII input to uppercase",
+ ("For each string in `strings`, return an uppercase version.\n\n"
+ "This function assumes the input is fully ASCII. It it may contain\n"
+ "non-ASCII characters, use \"utf8_upper\" instead."),
+ {"strings"});
+
+const FunctionDoc ascii_lower_doc(
+ "Transform ASCII input to lowercase",
+ ("For each string in `strings`, return a lowercase version.\n\n"
+ "This function assumes the input is fully ASCII. If it may contain\n"
+ "non-ASCII characters, use \"utf8_lower\" instead."),
+ {"strings"});
+
+const FunctionDoc ascii_swapcase_doc(
+ "Transform ASCII input lowercase characters to uppercase and uppercase characters to "
+ "lowercase",
+ ("For each string in `strings`, return a string with opposite casing.\n\n"
+ "This function assumes the input is fully ASCII. If it may contain\n"
+ "non-ASCII characters, use \"utf8_swapcase\" instead."),
+ {"strings"});
+
+const FunctionDoc ascii_capitalize_doc(
+ "Capitalize the first character of ASCII input",
+ ("For each string in `strings`, return a capitalized version.\n\n"
+ "This function assumes the input is fully ASCII. If it may contain\n"
+ "non-ASCII characters, use \"utf8_capitalize\" instead."),
+ {"strings"});
+
+const FunctionDoc ascii_title_doc(
+ "Titlecase each word of ASCII input",
+ ("For each string in `strings`, return a titlecased version.\n"
+ "Each word in the output will start with an uppercase character and its\n"
+ "remaining characters will be lowercase.\n\n"
+ "This function assumes the input is fully ASCII. If it may contain\n"
+ "non-ASCII characters, use \"utf8_title\" instead."),
+ {"strings"});
+
+const FunctionDoc ascii_reverse_doc(
+ "Reverse ASCII input",
+ ("For each ASCII string in `strings`, return a reversed version.\n\n"
+ "This function assumes the input is fully ASCII. If it may contain\n"
+ "non-ASCII characters, use \"utf8_reverse\" instead."),
+ {"strings"});
+
+const FunctionDoc utf8_upper_doc(
+ "Transform input to uppercase",
+ ("For each string in `strings`, return an uppercase version."), {"strings"});
+
+const FunctionDoc utf8_lower_doc(
+ "Transform input to lowercase",
+ ("For each string in `strings`, return a lowercase version."), {"strings"});
+
+const FunctionDoc utf8_swapcase_doc(
+ "Transform input lowercase characters to uppercase and uppercase characters to "
+ "lowercase",
+ ("For each string in `strings`, return an opposite case version."), {"strings"});
+
+const FunctionDoc utf8_capitalize_doc(
+ "Capitalize the first character of input",
+ ("For each string in `strings`, return a capitalized version,\n"
+ "with the first character uppercased and the others lowercased."),
+ {"strings"});
+
+const FunctionDoc utf8_title_doc(
+ "Titlecase each word of input",
+ ("For each string in `strings`, return a titlecased version.\n"
+ "Each word in the output will start with an uppercase character and its\n"
+ "remaining characters will be lowercase."),
+ {"strings"});
+
+const FunctionDoc utf8_reverse_doc(
+ "Reverse input",
+ ("For each string in `strings`, return a reversed version.\n\n"
+ "This function operates on Unicode codepoints, not grapheme\n"
+ "clusters. Hence, it will not correctly reverse grapheme clusters\n"
+ "composed of multiple codepoints."),
+ {"strings"});
+
+} // namespace
+
+void RegisterScalarStringAscii(FunctionRegistry* registry) {
+ // Some kernels are able to reuse the original offsets buffer, so don't
+ // preallocate them in the output. Only kernels that invoke
+ // "StringDataTransform" support no preallocation.
+ MakeUnaryStringBatchKernel<AsciiUpper>("ascii_upper", registry, &ascii_upper_doc,
+ MemAllocation::NO_PREALLOCATE);
+ MakeUnaryStringBatchKernel<AsciiLower>("ascii_lower", registry, &ascii_lower_doc,
+ MemAllocation::NO_PREALLOCATE);
+ MakeUnaryStringBatchKernel<AsciiSwapCase>(
+ "ascii_swapcase", registry, &ascii_swapcase_doc, MemAllocation::NO_PREALLOCATE);
+ MakeUnaryStringBatchKernel<AsciiCapitalize>("ascii_capitalize", registry,
+ &ascii_capitalize_doc);
+ MakeUnaryStringBatchKernel<AsciiTitle>("ascii_title", registry, &ascii_title_doc);
+ MakeUnaryStringBatchKernel<AsciiTrimWhitespace>("ascii_trim_whitespace", registry,
+ &ascii_trim_whitespace_doc);
+ MakeUnaryStringBatchKernel<AsciiLTrimWhitespace>("ascii_ltrim_whitespace", registry,
+ &ascii_ltrim_whitespace_doc);
+ MakeUnaryStringBatchKernel<AsciiRTrimWhitespace>("ascii_rtrim_whitespace", registry,
+ &ascii_rtrim_whitespace_doc);
+ MakeUnaryStringBatchKernel<AsciiReverse>("ascii_reverse", registry, &ascii_reverse_doc);
+ MakeUnaryStringBatchKernel<Utf8Reverse>("utf8_reverse", registry, &utf8_reverse_doc);
+
+ MakeUnaryStringBatchKernelWithState<AsciiCenter>("ascii_center", registry,
+ &ascii_center_doc);
+ MakeUnaryStringBatchKernelWithState<AsciiLPad>("ascii_lpad", registry, &ascii_lpad_doc);
+ MakeUnaryStringBatchKernelWithState<AsciiRPad>("ascii_rpad", registry, &ascii_rpad_doc);
+ MakeUnaryStringBatchKernelWithState<Utf8Center>("utf8_center", registry,
+ &utf8_center_doc);
+ MakeUnaryStringBatchKernelWithState<Utf8LPad>("utf8_lpad", registry, &utf8_lpad_doc);
+ MakeUnaryStringBatchKernelWithState<Utf8RPad>("utf8_rpad", registry, &utf8_rpad_doc);
+
+ MakeUnaryStringBatchKernelWithState<AsciiTrim>("ascii_trim", registry, &ascii_trim_doc);
+ MakeUnaryStringBatchKernelWithState<AsciiLTrim>("ascii_ltrim", registry,
+ &ascii_ltrim_doc);
+ MakeUnaryStringBatchKernelWithState<AsciiRTrim>("ascii_rtrim", registry,
+ &ascii_rtrim_doc);
+
+ AddUnaryStringPredicate<IsAscii>("string_is_ascii", registry, &string_is_ascii_doc);
+
+ AddUnaryStringPredicate<IsAlphaNumericAscii>("ascii_is_alnum", registry,
+ &ascii_is_alnum_doc);
+ AddUnaryStringPredicate<IsAlphaAscii>("ascii_is_alpha", registry, &ascii_is_alpha_doc);
+ AddUnaryStringPredicate<IsDecimalAscii>("ascii_is_decimal", registry,
+ &ascii_is_decimal_doc);
+ // no is_digit for ascii, since it is the same as is_decimal
+ AddUnaryStringPredicate<IsLowerAscii>("ascii_is_lower", registry, &ascii_is_lower_doc);
+ // no is_numeric for ascii, since it is the same as is_decimal
+ AddUnaryStringPredicate<IsPrintableAscii>("ascii_is_printable", registry,
+ &ascii_is_printable_doc);
+ AddUnaryStringPredicate<IsSpaceAscii>("ascii_is_space", registry, &ascii_is_space_doc);
+ AddUnaryStringPredicate<IsTitleAscii>("ascii_is_title", registry, &ascii_is_title_doc);
+ AddUnaryStringPredicate<IsUpperAscii>("ascii_is_upper", registry, &ascii_is_upper_doc);
+
+#ifdef ARROW_WITH_UTF8PROC
+ MakeUnaryStringUTF8TransformKernel<UTF8Upper>("utf8_upper", registry, &utf8_upper_doc);
+ MakeUnaryStringUTF8TransformKernel<UTF8Lower>("utf8_lower", registry, &utf8_lower_doc);
+ MakeUnaryStringUTF8TransformKernel<UTF8SwapCase>("utf8_swapcase", registry,
+ &utf8_swapcase_doc);
+ MakeUnaryStringBatchKernel<Utf8Capitalize>("utf8_capitalize", registry,
+ &utf8_capitalize_doc);
+ MakeUnaryStringBatchKernel<Utf8Title>("utf8_title", registry, &utf8_title_doc);
+ MakeUnaryStringBatchKernel<UTF8TrimWhitespace>("utf8_trim_whitespace", registry,
+ &utf8_trim_whitespace_doc);
+ MakeUnaryStringBatchKernel<UTF8LTrimWhitespace>("utf8_ltrim_whitespace", registry,
+ &utf8_ltrim_whitespace_doc);
+ MakeUnaryStringBatchKernel<UTF8RTrimWhitespace>("utf8_rtrim_whitespace", registry,
+ &utf8_rtrim_whitespace_doc);
+ MakeUnaryStringBatchKernelWithState<UTF8Trim>("utf8_trim", registry, &utf8_trim_doc);
+ MakeUnaryStringBatchKernelWithState<UTF8LTrim>("utf8_ltrim", registry, &utf8_ltrim_doc);
+ MakeUnaryStringBatchKernelWithState<UTF8RTrim>("utf8_rtrim", registry, &utf8_rtrim_doc);
+
+ AddUnaryStringPredicate<IsAlphaNumericUnicode>("utf8_is_alnum", registry,
+ &utf8_is_alnum_doc);
+ AddUnaryStringPredicate<IsAlphaUnicode>("utf8_is_alpha", registry, &utf8_is_alpha_doc);
+ AddUnaryStringPredicate<IsDecimalUnicode>("utf8_is_decimal", registry,
+ &utf8_is_decimal_doc);
+ AddUnaryStringPredicate<IsDigitUnicode>("utf8_is_digit", registry, &utf8_is_digit_doc);
+ AddUnaryStringPredicate<IsLowerUnicode>("utf8_is_lower", registry, &utf8_is_lower_doc);
+ AddUnaryStringPredicate<IsNumericUnicode>("utf8_is_numeric", registry,
+ &utf8_is_numeric_doc);
+ AddUnaryStringPredicate<IsPrintableUnicode>("utf8_is_printable", registry,
+ &utf8_is_printable_doc);
+ AddUnaryStringPredicate<IsSpaceUnicode>("utf8_is_space", registry, &utf8_is_space_doc);
+ AddUnaryStringPredicate<IsTitleUnicode>("utf8_is_title", registry, &utf8_is_title_doc);
+ AddUnaryStringPredicate<IsUpperUnicode>("utf8_is_upper", registry, &utf8_is_upper_doc);
+#endif
+
+ AddBinaryLength(registry);
+ AddUtf8Length(registry);
+ AddMatchSubstring(registry);
+ AddFindSubstring(registry);
+ AddCountSubstring(registry);
+ MakeUnaryStringBatchKernelWithState<ReplaceSubStringPlain>(
+ "replace_substring", registry, &replace_substring_doc,
+ MemAllocation::NO_PREALLOCATE);
+#ifdef ARROW_WITH_RE2
+ MakeUnaryStringBatchKernelWithState<ReplaceSubStringRegex>(
+ "replace_substring_regex", registry, &replace_substring_regex_doc,
+ MemAllocation::NO_PREALLOCATE);
+ AddExtractRegex(registry);
+#endif
+ AddReplaceSlice(registry);
+ AddSlice(registry);
+ AddSplit(registry);
+ AddStrptime(registry);
+ AddBinaryJoin(registry);
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
new file mode 100644
index 000000000..ddc3a56f0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
@@ -0,0 +1,240 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/benchmark_util.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+constexpr auto kSeed = 0x94378165;
+
+static void UnaryStringBenchmark(benchmark::State& state, const std::string& func_name,
+ const FunctionOptions* options = nullptr) {
+ const int64_t array_length = 1 << 20;
+ const int64_t value_min_size = 0;
+ const int64_t value_max_size = 32;
+ const double null_probability = 0.01;
+ random::RandomArrayGenerator rng(kSeed);
+
+ // NOTE: this produces only-Ascii data
+ auto values =
+ rng.String(array_length, value_min_size, value_max_size, null_probability);
+ // Make sure lookup tables are initialized before measuring
+ ABORT_NOT_OK(CallFunction(func_name, {values}, options));
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(CallFunction(func_name, {values}, options));
+ }
+ state.SetItemsProcessed(state.iterations() * array_length);
+ state.SetBytesProcessed(state.iterations() * values->data()->buffers[2]->size());
+}
+
+static void AsciiLower(benchmark::State& state) {
+ UnaryStringBenchmark(state, "ascii_lower");
+}
+
+static void AsciiUpper(benchmark::State& state) {
+ UnaryStringBenchmark(state, "ascii_upper");
+}
+
+static void IsAlphaNumericAscii(benchmark::State& state) {
+ UnaryStringBenchmark(state, "ascii_is_alnum");
+}
+
+static void MatchSubstring(benchmark::State& state) {
+ MatchSubstringOptions options("abac");
+ UnaryStringBenchmark(state, "match_substring", &options);
+}
+
+static void SplitPattern(benchmark::State& state) {
+ SplitPatternOptions options("a");
+ UnaryStringBenchmark(state, "split_pattern", &options);
+}
+
+static void TrimSingleAscii(benchmark::State& state) {
+ TrimOptions options("a");
+ UnaryStringBenchmark(state, "ascii_trim", &options);
+}
+
+static void TrimManyAscii(benchmark::State& state) {
+ TrimOptions options("abcdefgABCDEFG");
+ UnaryStringBenchmark(state, "ascii_trim", &options);
+}
+
+#ifdef ARROW_WITH_RE2
+static void MatchLike(benchmark::State& state) {
+ MatchSubstringOptions options("ab%ac");
+ UnaryStringBenchmark(state, "match_like", &options);
+}
+
+// MatchLike optimizes the following three into a substring/prefix/suffix search instead
+// of using RE2
+static void MatchLikeSubstring(benchmark::State& state) {
+ MatchSubstringOptions options("%abac%");
+ UnaryStringBenchmark(state, "match_like", &options);
+}
+
+static void MatchLikePrefix(benchmark::State& state) {
+ MatchSubstringOptions options("%abac");
+ UnaryStringBenchmark(state, "match_like", &options);
+}
+
+static void MatchLikeSuffix(benchmark::State& state) {
+ MatchSubstringOptions options("%abac");
+ UnaryStringBenchmark(state, "match_like", &options);
+}
+#endif
+
+#ifdef ARROW_WITH_UTF8PROC
+static void Utf8Upper(benchmark::State& state) {
+ UnaryStringBenchmark(state, "utf8_upper");
+}
+
+static void Utf8Lower(benchmark::State& state) {
+ UnaryStringBenchmark(state, "utf8_lower");
+}
+
+static void IsAlphaNumericUnicode(benchmark::State& state) {
+ UnaryStringBenchmark(state, "utf8_is_alnum");
+}
+static void TrimSingleUtf8(benchmark::State& state) {
+ TrimOptions options("a");
+ UnaryStringBenchmark(state, "utf8_trim", &options);
+}
+
+static void TrimManyUtf8(benchmark::State& state) {
+ TrimOptions options("abcdefgABCDEFG");
+ UnaryStringBenchmark(state, "utf8_trim", &options);
+}
+#endif
+
+using SeparatorFactory = std::function<Datum(int64_t n, double null_probability)>;
+
+static void BinaryJoin(benchmark::State& state, SeparatorFactory make_separator) {
+ const int64_t n_strings = 10000;
+ const int64_t n_lists = 1000;
+ const double null_probability = 0.02;
+
+ random::RandomArrayGenerator rng(kSeed);
+
+ auto strings =
+ rng.String(n_strings, /*min_length=*/5, /*max_length=*/20, null_probability);
+ auto lists = rng.List(*strings, n_lists, null_probability, /*force_empty_nulls=*/true);
+ auto separator = make_separator(n_lists, null_probability);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(CallFunction("binary_join", {lists, separator}));
+ }
+ state.SetBytesProcessed(
+ state.iterations() *
+ checked_cast<const StringArray&>(*strings).total_values_length());
+}
+
+static void BinaryJoinArrayScalar(benchmark::State& state) {
+ BinaryJoin(state, [](int64_t n, double null_probability) -> Datum {
+ return ScalarFromJSON(utf8(), R"("--")");
+ });
+}
+
+static void BinaryJoinArrayArray(benchmark::State& state) {
+ BinaryJoin(state, [](int64_t n, double null_probability) -> Datum {
+ random::RandomArrayGenerator rng(kSeed + 1);
+ return rng.String(n, /*min_length=*/0, /*max_length=*/4, null_probability);
+ });
+}
+
+static void BinaryJoinElementWise(benchmark::State& state,
+ SeparatorFactory make_separator) {
+ // Unfortunately benchmark is not 1:1 with BinaryJoin since BinaryJoin can join a
+ // varying number of inputs per output
+ const int64_t n_rows = 10000;
+ const int64_t n_cols = state.range(0);
+ const double null_probability = 0.02;
+
+ random::RandomArrayGenerator rng(kSeed);
+
+ DatumVector args;
+ ArrayVector strings;
+ int64_t total_values_length = 0;
+ for (int i = 0; i < n_cols; i++) {
+ auto arr = rng.String(n_rows, /*min_length=*/5, /*max_length=*/20, null_probability);
+ strings.push_back(arr);
+ args.emplace_back(arr);
+ total_values_length += checked_cast<const StringArray&>(*arr).total_values_length();
+ }
+ auto separator = make_separator(n_rows, null_probability);
+ args.emplace_back(separator);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(CallFunction("binary_join_element_wise", args));
+ }
+ state.SetBytesProcessed(state.iterations() * total_values_length);
+}
+
+static void BinaryJoinElementWiseArrayScalar(benchmark::State& state) {
+ BinaryJoinElementWise(state, [](int64_t n, double null_probability) -> Datum {
+ return ScalarFromJSON(utf8(), R"("--")");
+ });
+}
+
+static void BinaryJoinElementWiseArrayArray(benchmark::State& state) {
+ BinaryJoinElementWise(state, [](int64_t n, double null_probability) -> Datum {
+ random::RandomArrayGenerator rng(kSeed + 1);
+ return rng.String(n, /*min_length=*/0, /*max_length=*/4, null_probability);
+ });
+}
+
+BENCHMARK(AsciiLower);
+BENCHMARK(AsciiUpper);
+BENCHMARK(IsAlphaNumericAscii);
+BENCHMARK(MatchSubstring);
+BENCHMARK(SplitPattern);
+BENCHMARK(TrimSingleAscii);
+BENCHMARK(TrimManyAscii);
+#ifdef ARROW_WITH_RE2
+BENCHMARK(MatchLike);
+BENCHMARK(MatchLikeSubstring);
+BENCHMARK(MatchLikePrefix);
+BENCHMARK(MatchLikeSuffix);
+#endif
+#ifdef ARROW_WITH_UTF8PROC
+BENCHMARK(Utf8Lower);
+BENCHMARK(Utf8Upper);
+BENCHMARK(IsAlphaNumericUnicode);
+BENCHMARK(TrimSingleUtf8);
+BENCHMARK(TrimManyUtf8);
+#endif
+
+BENCHMARK(BinaryJoinArrayScalar);
+BENCHMARK(BinaryJoinArrayArray);
+BENCHMARK(BinaryJoinElementWiseArrayScalar)->RangeMultiplier(8)->Range(2, 128);
+BENCHMARK(BinaryJoinElementWiseArrayArray)->RangeMultiplier(8)->Range(2, 128);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_string_test.cc
new file mode 100644
index 000000000..e16d9b2dc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_string_test.cc
@@ -0,0 +1,1739 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#ifdef ARROW_WITH_UTF8PROC
+#include <utf8proc.h>
+#endif
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+namespace compute {
+
+// interesting utf8 characters for testing (lower case / upper case):
+// * ῦ / Υ͂ (3 to 4 code units) (Note, we don't support this yet, utf8proc does not use
+// SpecialCasing.txt)
+// * ɑ / Ɑ (2 to 3 code units)
+// * ı / I (2 to 1 code units)
+// * Ⱥ / ⱥ (2 to 3 code units)
+
+template <typename TestType>
+class BaseTestStringKernels : public ::testing::Test {
+ protected:
+ using OffsetType = typename TypeTraits<TestType>::OffsetType;
+ using ScalarType = typename TypeTraits<TestType>::ScalarType;
+
+ void CheckUnary(std::string func_name, std::string json_input,
+ std::shared_ptr<DataType> out_ty, std::string json_expected,
+ const FunctionOptions* options = nullptr) {
+ CheckScalarUnary(func_name, type(), json_input, out_ty, json_expected, options);
+ }
+
+ void CheckBinaryScalar(std::string func_name, std::string json_left_input,
+ std::string json_right_scalar, std::shared_ptr<DataType> out_ty,
+ std::string json_expected,
+ const FunctionOptions* options = nullptr) {
+ CheckScalarBinaryScalar(func_name, type(), json_left_input, json_right_scalar, out_ty,
+ json_expected, options);
+ }
+
+ void CheckVarArgsScalar(std::string func_name, std::string json_input,
+ std::shared_ptr<DataType> out_ty, std::string json_expected,
+ const FunctionOptions* options = nullptr) {
+ // CheckScalar (on arrays) checks scalar arguments individually,
+ // but this lets us test the all-scalar case explicitly
+ ScalarVector inputs;
+ std::shared_ptr<Array> args = ArrayFromJSON(type(), json_input);
+ for (int64_t i = 0; i < args->length(); i++) {
+ ASSERT_OK_AND_ASSIGN(auto scalar, args->GetScalar(i));
+ inputs.push_back(std::move(scalar));
+ }
+ CheckScalar(func_name, inputs, ScalarFromJSON(out_ty, json_expected), options);
+ }
+
+ void CheckVarArgs(std::string func_name, const std::vector<Datum>& inputs,
+ std::shared_ptr<DataType> out_ty, std::string json_expected,
+ const FunctionOptions* options = nullptr) {
+ CheckScalar(func_name, inputs, ArrayFromJSON(out_ty, json_expected), options);
+ }
+
+ std::shared_ptr<DataType> type() { return TypeTraits<TestType>::type_singleton(); }
+
+ template <typename CType>
+ std::shared_ptr<ScalarType> scalar(CType value) {
+ return std::make_shared<ScalarType>(value);
+ }
+
+ std::shared_ptr<DataType> offset_type() {
+ return TypeTraits<OffsetType>::type_singleton();
+ }
+};
+
+template <typename TestType>
+class TestBinaryKernels : public BaseTestStringKernels<TestType> {};
+
+TYPED_TEST_SUITE(TestBinaryKernels, BinaryArrowTypes);
+
+TYPED_TEST(TestBinaryKernels, BinaryLength) {
+ this->CheckUnary("binary_length", R"(["aaa", null, "áéíóú", "", "b"])",
+ this->offset_type(), "[3, null, 10, 0, 1]");
+}
+
+TYPED_TEST(TestBinaryKernels, BinaryReplaceSlice) {
+ ReplaceSliceOptions options{0, 1, "XX"};
+ this->CheckUnary("binary_replace_slice", "[]", this->type(), "[]", &options);
+ this->CheckUnary("binary_replace_slice", R"([null, "", "a", "ab", "abc"])",
+ this->type(), R"([null, "XX", "XX", "XXb", "XXbc"])", &options);
+
+ ReplaceSliceOptions options_whole{0, 5, "XX"};
+ this->CheckUnary("binary_replace_slice",
+ R"([null, "", "a", "ab", "abc", "abcde", "abcdef"])", this->type(),
+ R"([null, "XX", "XX", "XX", "XX", "XX", "XXf"])", &options_whole);
+
+ ReplaceSliceOptions options_middle{2, 4, "XX"};
+ this->CheckUnary("binary_replace_slice",
+ R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(),
+ R"([null, "XX", "aXX", "abXX", "abXX", "abXX", "abXXe"])",
+ &options_middle);
+
+ ReplaceSliceOptions options_neg_start{-3, -2, "XX"};
+ this->CheckUnary("binary_replace_slice",
+ R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(),
+ R"([null, "XX", "XXa", "XXab", "XXbc", "aXXcd", "abXXde"])",
+ &options_neg_start);
+
+ ReplaceSliceOptions options_neg_end{2, -2, "XX"};
+ this->CheckUnary("binary_replace_slice",
+ R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(),
+ R"([null, "XX", "aXX", "abXX", "abXXc", "abXXcd", "abXXde"])",
+ &options_neg_end);
+
+ ReplaceSliceOptions options_neg_pos{-1, 2, "XX"};
+ this->CheckUnary("binary_replace_slice",
+ R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(),
+ R"([null, "XX", "XX", "aXX", "abXXc", "abcXXd", "abcdXXe"])",
+ &options_neg_pos);
+
+ // Effectively the same as [2, 2)
+ ReplaceSliceOptions options_flip{2, 0, "XX"};
+ this->CheckUnary("binary_replace_slice",
+ R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(),
+ R"([null, "XX", "aXX", "abXX", "abXXc", "abXXcd", "abXXcde"])",
+ &options_flip);
+
+ // Effectively the same as [-3, -3)
+ ReplaceSliceOptions options_neg_flip{-3, -5, "XX"};
+ this->CheckUnary("binary_replace_slice",
+ R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(),
+ R"([null, "XX", "XXa", "XXab", "XXabc", "aXXbcd", "abXXcde"])",
+ &options_neg_flip);
+}
+
+TYPED_TEST(TestBinaryKernels, FindSubstring) {
+ MatchSubstringOptions options{"ab"};
+ this->CheckUnary("find_substring", "[]", this->offset_type(), "[]", &options);
+ this->CheckUnary("find_substring", R"(["abc", "acb", "cab", null, "bac"])",
+ this->offset_type(), "[0, -1, 1, null, -1]", &options);
+
+ MatchSubstringOptions options_repeated{"abab"};
+ this->CheckUnary("find_substring", R"(["abab", "ab", "cababc", null, "bac"])",
+ this->offset_type(), "[0, -1, 1, null, -1]", &options_repeated);
+
+ MatchSubstringOptions options_double_char{"aab"};
+ this->CheckUnary("find_substring", R"(["aacb", "aab", "ab", "aaab"])",
+ this->offset_type(), "[-1, 0, -1, 1]", &options_double_char);
+
+ MatchSubstringOptions options_double_char_2{"bbcaa"};
+ this->CheckUnary("find_substring", R"(["abcbaabbbcaabccabaab"])", this->offset_type(),
+ "[7]", &options_double_char_2);
+
+ MatchSubstringOptions options_empty{""};
+ this->CheckUnary("find_substring", R"(["", "a", null])", this->offset_type(),
+ "[0, 0, null]", &options_empty);
+}
+
+#ifdef ARROW_WITH_RE2
+TYPED_TEST(TestBinaryKernels, FindSubstringIgnoreCase) {
+ MatchSubstringOptions options{"?AB)", /*ignore_case=*/true};
+ this->CheckUnary("find_substring", "[]", this->offset_type(), "[]", &options);
+ this->CheckUnary("find_substring",
+ R"-(["?aB)c", "acb", "c?Ab)", null, "?aBc", "AB)"])-",
+ this->offset_type(), "[0, -1, 1, null, -1, -1]", &options);
+}
+
+TYPED_TEST(TestBinaryKernels, FindSubstringRegex) {
+ MatchSubstringOptions options{"a+", /*ignore_case=*/false};
+ this->CheckUnary("find_substring_regex", "[]", this->offset_type(), "[]", &options);
+ this->CheckUnary("find_substring_regex", R"(["a", "A", "baaa", null, "", "AaaA"])",
+ this->offset_type(), "[0, -1, 1, null, -1, 1]", &options);
+
+ options.ignore_case = true;
+ this->CheckUnary("find_substring_regex", "[]", this->offset_type(), "[]", &options);
+ this->CheckUnary("find_substring_regex", R"(["a", "A", "baaa", null, "", "AaaA"])",
+ this->offset_type(), "[0, 0, 1, null, -1, 0]", &options);
+}
+#else
+TYPED_TEST(TestBinaryKernels, FindSubstringIgnoreCase) {
+ MatchSubstringOptions options{"a+", /*ignore_case=*/true};
+ Datum input = ArrayFromJSON(this->type(), R"(["a"])");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
+ ::testing::HasSubstr("ignore_case requires RE2"),
+ CallFunction("find_substring", {input}, &options));
+}
+#endif
+
+TYPED_TEST(TestBinaryKernels, CountSubstring) {
+ MatchSubstringOptions options{"aba"};
+ this->CheckUnary("count_substring", "[]", this->offset_type(), "[]", &options);
+ this->CheckUnary(
+ "count_substring",
+ R"(["", null, "ab", "aba", "baba", "ababa", "abaaba", "babacaba", "ABA"])",
+ this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 0]", &options);
+
+ MatchSubstringOptions options_empty{""};
+ this->CheckUnary("count_substring", R"(["", null, "abc"])", this->offset_type(),
+ "[1, null, 4]", &options_empty);
+
+ MatchSubstringOptions options_repeated{"aaa"};
+ this->CheckUnary("count_substring", R"(["", "aaaa", "aaaaa", "aaaaaa", "aaá"])",
+ this->offset_type(), "[0, 1, 1, 2, 0]", &options_repeated);
+}
+
+#ifdef ARROW_WITH_RE2
+TYPED_TEST(TestBinaryKernels, CountSubstringRegex) {
+ MatchSubstringOptions options{"aba"};
+ this->CheckUnary("count_substring_regex", "[]", this->offset_type(), "[]", &options);
+ this->CheckUnary(
+ "count_substring",
+ R"(["", null, "ab", "aba", "baba", "ababa", "abaaba", "babacaba", "ABA"])",
+ this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 0]", &options);
+
+ MatchSubstringOptions options_empty{""};
+ this->CheckUnary("count_substring_regex", R"(["", null, "abc"])", this->offset_type(),
+ "[1, null, 4]", &options_empty);
+
+ MatchSubstringOptions options_as{"a+"};
+ this->CheckUnary("count_substring_regex", R"(["", "bacaaadaaaa", "c", "AAA"])",
+ this->offset_type(), "[0, 3, 0, 0]", &options_as);
+
+ MatchSubstringOptions options_empty_match{"a*"};
+ this->CheckUnary("count_substring_regex", R"(["", "bacaaadaaaa", "c", "AAA"])",
+ // 7 is because it matches at |b|a|c|aaa|d|aaaa|
+ this->offset_type(), "[1, 7, 2, 4]", &options_empty_match);
+
+ MatchSubstringOptions options_repeated{"aaa"};
+ this->CheckUnary("count_substring", R"(["", "aaaa", "aaaaa", "aaaaaa", "aaá"])",
+ this->offset_type(), "[0, 1, 1, 2, 0]", &options_repeated);
+}
+
+TYPED_TEST(TestBinaryKernels, CountSubstringIgnoreCase) {
+ MatchSubstringOptions options{"aba", /*ignore_case=*/true};
+ this->CheckUnary("count_substring", "[]", this->offset_type(), "[]", &options);
+ this->CheckUnary(
+ "count_substring",
+ R"(["", null, "ab", "aBa", "bAbA", "aBaBa", "abaAbA", "babacaba", "ABA"])",
+ this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 1]", &options);
+
+ MatchSubstringOptions options_empty{"", /*ignore_case=*/true};
+ this->CheckUnary("count_substring", R"(["", null, "abc"])", this->offset_type(),
+ "[1, null, 4]", &options_empty);
+}
+
+TYPED_TEST(TestBinaryKernels, CountSubstringRegexIgnoreCase) {
+ MatchSubstringOptions options_as{"a+", /*ignore_case=*/true};
+ this->CheckUnary("count_substring_regex", R"(["", "bacAaAdaAaA", "c", "AAA"])",
+ this->offset_type(), "[0, 3, 0, 1]", &options_as);
+
+ MatchSubstringOptions options_empty_match{"a*", /*ignore_case=*/true};
+ this->CheckUnary("count_substring_regex", R"(["", "bacAaAdaAaA", "c", "AAA"])",
+ this->offset_type(), "[1, 7, 2, 2]", &options_empty_match);
+}
+#else
+TYPED_TEST(TestBinaryKernels, CountSubstringIgnoreCase) {
+ Datum input = ArrayFromJSON(this->type(), R"(["a"])");
+ MatchSubstringOptions options{"a", /*ignore_case=*/true};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
+ ::testing::HasSubstr("ignore_case requires RE2"),
+ CallFunction("count_substring", {input}, &options));
+}
+#endif
+
+TYPED_TEST(TestBinaryKernels, BinaryJoinElementWise) {
+ const auto ty = this->type();
+ JoinOptions options;
+ JoinOptions options_skip(JoinOptions::SKIP);
+ JoinOptions options_replace(JoinOptions::REPLACE, "X");
+ // Scalar args, Scalar separator
+ this->CheckVarArgsScalar("binary_join_element_wise", R"([null])", ty, R"(null)",
+ &options);
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["-"])", ty, R"("")", &options);
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", "-"])", ty, R"("a")",
+ &options);
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", "b", "-"])", ty,
+ R"("a-b")", &options);
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", "b", null])", ty,
+ R"(null)", &options);
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", null, "-"])", ty,
+ R"(null)", &options);
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["foo", "bar", "baz", "++"])",
+ ty, R"("foo++bar++baz")", &options);
+
+ // Scalar args, Array separator
+ const auto sep = ArrayFromJSON(ty, R"([null, "-", "--"])");
+ const auto scalar1 = ScalarFromJSON(ty, R"("foo")");
+ const auto scalar2 = ScalarFromJSON(ty, R"("bar")");
+ const auto scalar3 = ScalarFromJSON(ty, R"("")");
+ const auto scalar_null = ScalarFromJSON(ty, R"(null)");
+ this->CheckVarArgs("binary_join_element_wise", {sep}, ty, R"([null, "", ""])",
+ &options);
+ this->CheckVarArgs("binary_join_element_wise", {scalar1, sep}, ty,
+ R"([null, "foo", "foo"])", &options);
+ this->CheckVarArgs("binary_join_element_wise", {scalar1, scalar2, sep}, ty,
+ R"([null, "foo-bar", "foo--bar"])", &options);
+ this->CheckVarArgs("binary_join_element_wise", {scalar1, scalar_null, sep}, ty,
+ R"([null, null, null])", &options);
+ this->CheckVarArgs("binary_join_element_wise", {scalar1, scalar2, scalar3, sep}, ty,
+ R"([null, "foo-bar-", "foo--bar--"])", &options);
+
+ // Array args, Scalar separator
+ const auto sep1 = ScalarFromJSON(ty, R"("-")");
+ const auto sep2 = ScalarFromJSON(ty, R"("--")");
+ const auto arr1 = ArrayFromJSON(ty, R"([null, "a", "bb", "ccc"])");
+ const auto arr2 = ArrayFromJSON(ty, R"(["d", null, "e", ""])");
+ const auto arr3 = ArrayFromJSON(ty, R"(["gg", null, "h", "iii"])");
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, arr3, scalar_null}, ty,
+ R"([null, null, null, null])", &options);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, arr3, sep1}, ty,
+ R"([null, null, "bb-e-h", "ccc--iii"])", &options);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, arr3, sep2}, ty,
+ R"([null, null, "bb--e--h", "ccc----iii"])", &options);
+
+ // Array args, Array separator
+ const auto sep3 = ArrayFromJSON(ty, R"(["-", "--", null, "---"])");
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, arr3, sep3}, ty,
+ R"([null, null, null, "ccc------iii"])", &options);
+
+ // Mixed
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar2, sep3}, ty,
+ R"([null, null, null, "ccc------bar"])", &options);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar_null, sep3}, ty,
+ R"([null, null, null, null])", &options);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar2, sep1}, ty,
+ R"([null, null, "bb-e-bar", "ccc--bar"])", &options);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar_null, scalar_null},
+ ty, R"([null, null, null, null])", &options);
+
+ // Skip
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", null, "b", "-"])", ty,
+ R"("a-b")", &options_skip);
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", null, "b", null])", ty,
+ R"(null)", &options_skip);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar2, sep3}, ty,
+ R"(["d-bar", "a--bar", null, "ccc------bar"])", &options_skip);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar_null, sep3}, ty,
+ R"(["d", "a", null, "ccc---"])", &options_skip);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar2, sep1}, ty,
+ R"(["d-bar", "a-bar", "bb-e-bar", "ccc--bar"])", &options_skip);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar_null, scalar_null},
+ ty, R"([null, null, null, null])", &options_skip);
+
+ // Replace
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", null, "b", "-"])", ty,
+ R"("a-X-b")", &options_replace);
+ this->CheckVarArgsScalar("binary_join_element_wise", R"(["a", null, "b", null])", ty,
+ R"(null)", &options_replace);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar2, sep3}, ty,
+ R"(["X-d-bar", "a--X--bar", null, "ccc------bar"])",
+ &options_replace);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar_null, sep3}, ty,
+ R"(["X-d-X", "a--X--X", null, "ccc------X"])", &options_replace);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar2, sep1}, ty,
+ R"(["X-d-bar", "a-X-bar", "bb-e-bar", "ccc--bar"])",
+ &options_replace);
+ this->CheckVarArgs("binary_join_element_wise", {arr1, arr2, scalar_null, scalar_null},
+ ty, R"([null, null, null, null])", &options_replace);
+
+ // Error cases
+ ASSERT_RAISES(Invalid, CallFunction("binary_join_element_wise", {}, &options));
+}
+
+class TestFixedSizeBinaryKernels : public ::testing::Test {
+ protected:
+ void CheckUnary(std::string func_name, std::string json_input,
+ std::shared_ptr<DataType> out_ty, std::string json_expected,
+ const FunctionOptions* options = nullptr) {
+ CheckScalarUnary(func_name, type(), json_input, out_ty, json_expected, options);
+ // Ensure the equivalent binary kernel does the same thing
+ CheckScalarUnary(func_name, binary(), json_input,
+ out_ty->id() == Type::FIXED_SIZE_BINARY ? binary() : out_ty,
+ json_expected, options);
+ }
+
+ std::shared_ptr<DataType> type() const { return fixed_size_binary(6); }
+ std::shared_ptr<DataType> offset_type() const { return int32(); }
+};
+
+TEST_F(TestFixedSizeBinaryKernels, BinaryLength) {
+ CheckUnary("binary_length", R"(["aaaaaa", null, "áéí"])", offset_type(),
+ "[6, null, 6]");
+}
+
+TEST_F(TestFixedSizeBinaryKernels, BinaryReplaceSlice) {
+ ReplaceSliceOptions options{0, 1, "XX"};
+ CheckUnary("binary_replace_slice", "[]", fixed_size_binary(7), "[]", &options);
+ CheckUnary("binary_replace_slice", R"([null, "abcdef"])", fixed_size_binary(7),
+ R"([null, "XXbcdef"])", &options);
+
+ ReplaceSliceOptions options_shrink{0, 2, ""};
+ CheckUnary("binary_replace_slice", R"([null, "abcdef"])", fixed_size_binary(4),
+ R"([null, "cdef"])", &options_shrink);
+
+ ReplaceSliceOptions options_whole{0, 6, "XX"};
+ CheckUnary("binary_replace_slice", R"([null, "abcdef"])", fixed_size_binary(2),
+ R"([null, "XX"])", &options_whole);
+
+ ReplaceSliceOptions options_middle{2, 4, "XX"};
+ CheckUnary("binary_replace_slice", R"([null, "abcdef"])", fixed_size_binary(6),
+ R"([null, "abXXef"])", &options_middle);
+
+ ReplaceSliceOptions options_neg_start{-3, -2, "XX"};
+ CheckUnary("binary_replace_slice", R"([null, "abcdef"])", fixed_size_binary(7),
+ R"([null, "abcXXef"])", &options_neg_start);
+
+ ReplaceSliceOptions options_neg_end{2, -2, "XX"};
+ CheckUnary("binary_replace_slice", R"([null, "abcdef"])", fixed_size_binary(6),
+ R"([null, "abXXef"])", &options_neg_end);
+
+ ReplaceSliceOptions options_neg_pos{-1, 2, "XX"};
+ CheckUnary("binary_replace_slice", R"([null, "abcdef"])", fixed_size_binary(8),
+ R"([null, "abcdeXXf"])", &options_neg_pos);
+
+ // Effectively the same as [2, 2)
+ ReplaceSliceOptions options_flip{2, 0, "XX"};
+ CheckUnary("binary_replace_slice", R"([null, "abcdef"])", fixed_size_binary(8),
+ R"([null, "abXXcdef"])", &options_flip);
+
+ // Effectively the same as [-3, -3)
+ ReplaceSliceOptions options_neg_flip{-3, -5, "XX"};
+ CheckUnary("binary_replace_slice", R"([null, "abcdef"])", fixed_size_binary(8),
+ R"([null, "abcXXdef"])", &options_neg_flip);
+}
+
+TEST_F(TestFixedSizeBinaryKernels, CountSubstring) {
+ MatchSubstringOptions options{"aba"};
+ CheckUnary("count_substring", "[]", offset_type(), "[]", &options);
+ CheckUnary(
+ "count_substring",
+ R"([" ", null, " ab ", " aba ", "baba ", "ababa ", "abaaba", "ABAABA"])",
+ offset_type(), "[0, null, 0, 1, 1, 1, 2, 0]", &options);
+
+ MatchSubstringOptions options_empty{""};
+ CheckUnary("count_substring", R"([" ", null, "abc "])", offset_type(),
+ "[7, null, 7]", &options_empty);
+
+ MatchSubstringOptions options_repeated{"aaa"};
+ CheckUnary("count_substring", R"([" ", "aaaa ", "aaaaa ", "aaaaaa", "aaáaa"])",
+ offset_type(), "[0, 1, 1, 2, 0]", &options_repeated);
+}
+
+#ifdef ARROW_WITH_RE2
+TEST_F(TestFixedSizeBinaryKernels, CountSubstringRegex) {
+ MatchSubstringOptions options{"aba"};
+ CheckUnary("count_substring_regex", "[]", offset_type(), "[]", &options);
+ CheckUnary(
+ "count_substring_regex",
+ R"([" ", null, " ab ", " aba ", "baba ", "ababa ", "abaaba", "ABAABA"])",
+ offset_type(), "[0, null, 0, 1, 1, 1, 2, 0]", &options);
+
+ MatchSubstringOptions options_empty{""};
+ CheckUnary("count_substring_regex", R"([" ", null, "abc "])", offset_type(),
+ "[7, null, 7]", &options_empty);
+
+ MatchSubstringOptions options_repeated{"aaa"};
+ CheckUnary("count_substring_regex",
+ R"([" ", "aaaa ", "aaaaa ", "aaaaaa", "aaáaa"])", offset_type(),
+ "[0, 1, 1, 2, 0]", &options_repeated);
+
+ MatchSubstringOptions options_as{"a+"};
+ CheckUnary("count_substring_regex", R"([" ", "bacaaa", "c ", "AAAAAA"])",
+ offset_type(), "[0, 2, 0, 0]", &options_as);
+
+ MatchSubstringOptions options_empty_match{"a*"};
+ CheckUnary("count_substring_regex", R"([" ", "bacaaa", "c ", "AAAAAA"])",
+ // 5 is because it matches at |b|a|c|aaa|
+ offset_type(), "[7, 5, 7, 7]", &options_empty_match);
+}
+
+TEST_F(TestFixedSizeBinaryKernels, CountSubstringIgnoreCase) {
+ MatchSubstringOptions options{"aba", /*ignore_case=*/true};
+ CheckUnary("count_substring", "[]", offset_type(), "[]", &options);
+ CheckUnary(
+ "count_substring",
+ R"([" ", null, "ab ", "aBa ", " bAbA ", " aBaBa", "abaAbA", "abaaba", "ABAabc"])",
+ offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 1]", &options);
+
+ MatchSubstringOptions options_empty{"", /*ignore_case=*/true};
+ CheckUnary("count_substring", R"([" ", null, "abcABc"])", offset_type(),
+ "[7, null, 7]", &options_empty);
+}
+
+TEST_F(TestFixedSizeBinaryKernels, CountSubstringRegexIgnoreCase) {
+ MatchSubstringOptions options_as{"a+", /*ignore_case=*/true};
+ CheckUnary("count_substring_regex", R"([" ", "aAadaA", "c ", "AAAbbb"])",
+ offset_type(), "[0, 2, 0, 1]", &options_as);
+
+ MatchSubstringOptions options_empty_match{"a*", /*ignore_case=*/true};
+ CheckUnary("count_substring_regex", R"([" ", "aAadaA", "c ", "AAAbbb"])",
+ offset_type(), "[7, 4, 7, 5]", &options_empty_match);
+}
+#else
+TEST_F(TestFixedSizeBinaryKernels, CountSubstringIgnoreCase) {
+ Datum input = ArrayFromJSON(type(), R"([" a "])");
+ MatchSubstringOptions options{"a", /*ignore_case=*/true};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
+ ::testing::HasSubstr("ignore_case requires RE2"),
+ CallFunction("count_substring", {input}, &options));
+}
+#endif
+
+TEST_F(TestFixedSizeBinaryKernels, FindSubstring) {
+ MatchSubstringOptions options{"ab"};
+ CheckUnary("find_substring", "[]", offset_type(), "[]", &options);
+ CheckUnary("find_substring", R"(["abc ", " acb", " cab ", null, " bac "])",
+ offset_type(), "[0, -1, 2, null, -1]", &options);
+
+ MatchSubstringOptions options_repeated{"abab"};
+ CheckUnary("find_substring", R"([" abab ", " ab ", "cababc", null, " bac "])",
+ offset_type(), "[1, -1, 1, null, -1]", &options_repeated);
+
+ MatchSubstringOptions options_double_char{"aab"};
+ CheckUnary("find_substring", R"([" aacb", "aab ", " ab ", " aaab"])",
+ offset_type(), "[-1, 0, -1, 3]", &options_double_char);
+
+ MatchSubstringOptions options_double_char_2{"bbcaa"};
+ CheckUnary("find_substring", R"(["bbbcaa"])", offset_type(), "[1]",
+ &options_double_char_2);
+
+ MatchSubstringOptions options_empty{""};
+ CheckUnary("find_substring", R"([" ", "aaaaaa", null])", offset_type(),
+ "[0, 0, null]", &options_empty);
+}
+
+#ifdef ARROW_WITH_RE2
+TEST_F(TestFixedSizeBinaryKernels, FindSubstringIgnoreCase) {
+ MatchSubstringOptions options{"?AB)", /*ignore_case=*/true};
+ CheckUnary("find_substring", "[]", offset_type(), "[]", &options);
+ CheckUnary("find_substring",
+ R"-(["?aB)c ", " acb ", " c?Ab)", null, " ?aBc ", " AB) "])-",
+ offset_type(), "[0, -1, 2, null, -1, -1]", &options);
+}
+
+TEST_F(TestFixedSizeBinaryKernels, FindSubstringRegex) {
+ MatchSubstringOptions options{"a+", /*ignore_case=*/false};
+ CheckUnary("find_substring_regex", "[]", offset_type(), "[]", &options);
+ CheckUnary("find_substring_regex",
+ R"(["a ", " A ", " baaa", null, " ", " AaaA "])", offset_type(),
+ "[0, -1, 3, null, -1, 2]", &options);
+
+ options.ignore_case = true;
+ CheckUnary("find_substring_regex", "[]", offset_type(), "[]", &options);
+ CheckUnary("find_substring_regex",
+ R"(["a ", " A ", " baaa", null, " ", " AaaA "])", offset_type(),
+ "[0, 2, 3, null, -1, 1]", &options);
+}
+#else
+TEST_F(TestFixedSizeBinaryKernels, FindSubstringIgnoreCase) {
+ MatchSubstringOptions options{"a+", /*ignore_case=*/true};
+ Datum input = ArrayFromJSON(type(), R"(["aaaaaa"])");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
+ ::testing::HasSubstr("ignore_case requires RE2"),
+ CallFunction("find_substring", {input}, &options));
+}
+#endif
+
+template <typename TestType>
+class TestStringKernels : public BaseTestStringKernels<TestType> {};
+
+TYPED_TEST_SUITE(TestStringKernels, StringArrowTypes);
+
+TYPED_TEST(TestStringKernels, AsciiUpper) {
+ this->CheckUnary("ascii_upper", "[]", this->type(), "[]");
+ this->CheckUnary("ascii_upper", "[\"aAazZæÆ&\", null, \"\", \"bbb\"]", this->type(),
+ "[\"AAAZZæÆ&\", null, \"\", \"BBB\"]");
+}
+
+TYPED_TEST(TestStringKernels, AsciiLower) {
+ this->CheckUnary("ascii_lower", "[]", this->type(), "[]");
+ this->CheckUnary("ascii_lower", "[\"aAazZæÆ&\", null, \"\", \"BBB\"]", this->type(),
+ "[\"aaazzæÆ&\", null, \"\", \"bbb\"]");
+}
+
+TYPED_TEST(TestStringKernels, AsciiSwapCase) {
+ this->CheckUnary("ascii_swapcase", "[]", this->type(), "[]");
+ this->CheckUnary("ascii_swapcase", "[\"aAazZæÆ&\", null, \"\", \"BbB\"]", this->type(),
+ "[\"AaAZzæÆ&\", null, \"\", \"bBb\"]");
+ this->CheckUnary("ascii_swapcase", "[\"hEllO, WoRld!\", \"$. A35?\"]", this->type(),
+ "[\"HeLLo, wOrLD!\", \"$. a35?\"]");
+}
+
+TYPED_TEST(TestStringKernels, AsciiCapitalize) {
+ this->CheckUnary("ascii_capitalize", "[]", this->type(), "[]");
+ this->CheckUnary("ascii_capitalize",
+ "[\"aAazZæÆ&\", null, \"\", \"bBB\", \"hEllO, WoRld!\", \"$. A3\", "
+ "\"!hELlo, wORLd!\"]",
+ this->type(),
+ "[\"AaazzæÆ&\", null, \"\", \"Bbb\", \"Hello, world!\", \"$. a3\", "
+ "\"!hello, world!\"]");
+}
+
+TYPED_TEST(TestStringKernels, AsciiTitle) {
+ this->CheckUnary(
+ "ascii_title",
+ R"([null, "", "b", "aAaz;ZeA&", "arRoW", "iI", "a.a.a..A", "hEllO, WoRld!", "foo baR;heHe0zOP", "!%$^.,;"])",
+ this->type(),
+ R"([null, "", "B", "Aaaz;Zea&", "Arrow", "Ii", "A.A.A..A", "Hello, World!", "Foo Bar;Hehe0Zop", "!%$^.,;"])");
+}
+
+TYPED_TEST(TestStringKernels, AsciiReverse) {
+ this->CheckUnary("ascii_reverse", "[]", this->type(), "[]");
+ this->CheckUnary("ascii_reverse", R"(["abcd", null, "", "bbb"])", this->type(),
+ R"(["dcba", null, "", "bbb"])");
+
+ auto invalid_input = ArrayFromJSON(this->type(), R"(["aAazZæÆ&", null, "", "bcd"])");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ testing::HasSubstr("Non-ASCII sequence in input"),
+ CallFunction("ascii_reverse", {invalid_input}));
+ auto masked_input = TweakValidityBit(invalid_input, 0, false);
+ CheckScalarUnary("ascii_reverse", masked_input,
+ ArrayFromJSON(this->type(), R"([null, null, "", "dcb"])"));
+}
+
+TYPED_TEST(TestStringKernels, Utf8Reverse) {
+ this->CheckUnary("utf8_reverse", "[]", this->type(), "[]");
+ this->CheckUnary("utf8_reverse", R"(["abcd", null, "", "bbb"])", this->type(),
+ R"(["dcba", null, "", "bbb"])");
+ this->CheckUnary("utf8_reverse", R"(["aAazZæÆ&", null, "", "bbb", "ɑɽⱤæÆ"])",
+ this->type(), R"(["&ÆæZzaAa", null, "", "bbb", "ÆæⱤɽɑ"])");
+
+ // inputs with malformed utf8 chars would produce garbage output, but the end result
+ // would produce arrays with same lengths. Hence checking offset buffer equality
+ auto malformed_input = ArrayFromJSON(this->type(), "[\"ɑ\xFFɑa\", \"ɽ\xe1\xbdɽa\"]");
+ const Result<Datum>& res = CallFunction("utf8_reverse", {malformed_input});
+ ASSERT_TRUE(res->array()->buffers[1]->Equals(*malformed_input->data()->buffers[1]));
+}
+
+TEST(TestStringKernels, LARGE_MEMORY_TEST(Utf8Upper32bitGrowth)) {
+ // 0x7fff * 0xffff is the max a 32 bit string array can hold
+ // since the utf8_upper kernel can grow it by 3/2, the max we should accept is is
+ // 0x7fff * 0xffff * 2/3 = 0x5555 * 0xffff, so this should give us a CapacityError
+ std::string str(0x5556 * 0xffff, 'a');
+ arrow::StringBuilder builder;
+ ASSERT_OK(builder.Append(str));
+ std::shared_ptr<arrow::Array> array;
+ arrow::Status st = builder.Finish(&array);
+ const FunctionOptions* options = nullptr;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(CapacityError,
+ testing::HasSubstr("Result might not fit"),
+ CallFunction("utf8_upper", {array}, options));
+ ASSERT_OK_AND_ASSIGN(auto scalar, array->GetScalar(0));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(CapacityError,
+ testing::HasSubstr("Result might not fit"),
+ CallFunction("utf8_upper", {scalar}, options));
+}
+
+TYPED_TEST(TestStringKernels, Utf8Length) {
+ this->CheckUnary("utf8_length",
+ R"(["aaa", null, "áéíóú", "ɑɽⱤoW😀", "áéí 0😀", "", "b"])",
+ this->offset_type(), "[3, null, 5, 6, 6, 0, 1]");
+}
+
+#ifdef ARROW_WITH_UTF8PROC
+
+TYPED_TEST(TestStringKernels, Utf8Upper) {
+ this->CheckUnary("utf8_upper", "[\"aAazZæÆ&\", null, \"\", \"b\"]", this->type(),
+ "[\"AAAZZÆÆ&\", null, \"\", \"B\"]");
+
+ // test varying encoding lengths and thus changing indices/offsets
+ this->CheckUnary("utf8_upper", "[\"ɑɽⱤoW\", null, \"ıI\", \"b\"]", this->type(),
+ "[\"ⱭⱤⱤOW\", null, \"II\", \"B\"]");
+
+ // ῦ to Υ͂ not supported
+ // this->CheckUnary("utf8_upper", "[\"ῦɐɜʞȿ\"]", this->type(),
+ // "[\"Υ͂ⱯꞫꞰⱾ\"]");
+
+ // test maximum buffer growth
+ this->CheckUnary("utf8_upper", "[\"ɑɑɑɑ\"]", this->type(), "[\"ⱭⱭⱭⱭ\"]");
+
+ // Test invalid data
+ auto invalid_input = ArrayFromJSON(this->type(), "[\"ɑa\xFFɑ\", \"ɽ\xe1\xbdɽaa\"]");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("Invalid UTF8 sequence"),
+ CallFunction("utf8_upper", {invalid_input}));
+}
+
+TYPED_TEST(TestStringKernels, Utf8Lower) {
+ this->CheckUnary("utf8_lower", "[\"aAazZæÆ&\", null, \"\", \"b\"]", this->type(),
+ "[\"aaazzææ&\", null, \"\", \"b\"]");
+
+ // test varying encoding lengths and thus changing indices/offsets
+ this->CheckUnary("utf8_lower", "[\"ⱭɽⱤoW\", null, \"ıI\", \"B\"]", this->type(),
+ "[\"ɑɽɽow\", null, \"ıi\", \"b\"]");
+
+ // ῦ to Υ͂ is not supported, but in principle the reverse is, but it would need
+ // normalization
+ // this->CheckUnary("utf8_lower", "[\"Υ͂ⱯꞫꞰⱾ\"]", this->type(),
+ // "[\"ῦɐɜʞȿ\"]");
+
+ // test maximum buffer growth
+ this->CheckUnary("utf8_lower", "[\"ȺȺȺȺ\"]", this->type(), "[\"ⱥⱥⱥⱥ\"]");
+
+ // Test invalid data
+ auto invalid_input = ArrayFromJSON(this->type(), "[\"Ⱥa\xFFⱭ\", \"Ɽ\xe1\xbdⱤaA\"]");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("Invalid UTF8 sequence"),
+ CallFunction("utf8_lower", {invalid_input}));
+}
+
+TYPED_TEST(TestStringKernels, Utf8SwapCase) {
+ this->CheckUnary("utf8_swapcase", "[\"aAazZæÆ&\", null, \"\", \"b\"]", this->type(),
+ "[\"AaAZzÆæ&\", null, \"\", \"B\"]");
+
+ // test varying encoding lengths and thus changing indices/offsets
+ this->CheckUnary("utf8_swapcase", "[\"ⱭɽⱤoW\", null, \"ıI\", \"B\"]", this->type(),
+ "[\"ɑⱤɽOw\", null, \"Ii\", \"b\"]");
+
+ // test maximum buffer growth
+ this->CheckUnary("utf8_swapcase", "[\"ȺȺȺȺ\"]", this->type(), "[\"ⱥⱥⱥⱥ\"]");
+
+ this->CheckUnary("utf8_swapcase", "[\"hEllO, WoRld!\", \"$. A35?\"]", this->type(),
+ "[\"HeLLo, wOrLD!\", \"$. a35?\"]");
+
+ // Test invalid data
+ auto invalid_input = ArrayFromJSON(this->type(), "[\"Ⱥa\xFFⱭ\", \"Ɽ\xe1\xbdⱤaA\"]");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("Invalid UTF8 sequence"),
+ CallFunction("utf8_swapcase", {invalid_input}));
+}
+
+TYPED_TEST(TestStringKernels, Utf8Capitalize) {
+ this->CheckUnary("utf8_capitalize", "[]", this->type(), "[]");
+ this->CheckUnary("utf8_capitalize",
+ "[\"aAazZæÆ&\", null, \"\", \"b\", \"ɑɽⱤoW\", \"ıI\", \"ⱥⱥⱥȺ\", "
+ "\"hEllO, WoRld!\", \"$. A3\", \"!ɑⱤⱤow\"]",
+ this->type(),
+ "[\"Aaazzææ&\", null, \"\", \"B\", \"Ɑɽɽow\", \"Ii\", \"Ⱥⱥⱥⱥ\", "
+ "\"Hello, world!\", \"$. a3\", \"!ɑɽɽow\"]");
+}
+
+TYPED_TEST(TestStringKernels, Utf8Title) {
+ this->CheckUnary(
+ "utf8_title",
+ R"([null, "", "b", "aAaz;ZæÆ&", "ɑɽⱤoW", "ıI", "ⱥ.ⱥ.ⱥ..Ⱥ", "hEllO, WoRld!", "foo baR;héHé0zOP", "!%$^.,;"])",
+ this->type(),
+ R"([null, "", "B", "Aaaz;Zææ&", "Ɑɽɽow", "Ii", "Ⱥ.Ⱥ.Ⱥ..Ⱥ", "Hello, World!", "Foo Bar;Héhé0Zop", "!%$^.,;"])");
+}
+
+TYPED_TEST(TestStringKernels, IsAlphaNumericUnicode) {
+ // U+08BE (utf8: \xE0\xA2\xBE) is undefined, but utf8proc things it is
+ // UTF8PROC_CATEGORY_LO
+ this->CheckUnary("utf8_is_alnum", "[\"ⱭɽⱤoW123\", null, \"Ɑ2\", \"!\", \"\"]",
+ boolean(), "[true, null, true, false, false]");
+}
+
+TYPED_TEST(TestStringKernels, IsAlphaUnicode) {
+ // U+08BE (utf8: \xE0\xA2\xBE) is undefined, but utf8proc things it is
+ // UTF8PROC_CATEGORY_LO
+ this->CheckUnary("utf8_is_alpha", "[\"ⱭɽⱤoW\", null, \"Ɑ2\", \"!\", \"\"]", boolean(),
+ "[true, null, false, false, false]");
+}
+
+TYPED_TEST(TestStringKernels, IsAscii) {
+ this->CheckUnary("string_is_ascii", "[\"azAZ~\", null, \"Ɑ\", \"\"]", boolean(),
+ "[true, null, false, true]");
+}
+
+TYPED_TEST(TestStringKernels, IsDecimalUnicode) {
+ // ٣ is arabic 3 (decimal), Ⅳ roman (non-decimal)
+ this->CheckUnary("utf8_is_decimal", "[\"12\", null, \"٣\", \"Ⅳ\", \"1a\", \"\"]",
+ boolean(), "[true, null, true, false, false, false]");
+}
+
+TYPED_TEST(TestStringKernels, IsDigitUnicode) {
+ // These are digits according to Python, but we don't have the information in
+ // utf8proc for this
+ // this->CheckUnary("utf8_is_digit", "[\"²\", \"①\"]", boolean(), "[true,
+ // true]");
+}
+
+TYPED_TEST(TestStringKernels, IsNumericUnicode) {
+ // ٣ is arabic 3 (decimal), Ⅳ roman (non-decimal)
+ this->CheckUnary("utf8_is_numeric", "[\"12\", null, \"٣\", \"Ⅳ\", \"1a\", \"\"]",
+ boolean(), "[true, null, true, true, false, false]");
+ // These are numerical according to Python, but we don't have the information in
+ // utf8proc for this
+ // this->CheckUnary("utf8_is_numeric", "[\"㐅\", \"卌\"]", boolean(),
+ // "[true, null, true, true, false, false]");
+}
+
+TYPED_TEST(TestStringKernels, IsLowerUnicode) {
+ // ٣ is arabic 3 (decimal), Φ capital
+ this->CheckUnary("utf8_is_lower",
+ "[\"12\", null, \"٣a\", \"٣A\", \"1a\", \"Φ\", \"\", \"with space\", "
+ "\"With space\"]",
+ boolean(),
+ "[false, null, true, false, true, false, false, true, false]");
+ // lower case character utf8proc does not know about
+ // this->CheckUnary("utf8_is_lower", "[\"ª\", \"ₕ\"]", boolean(), "[true,
+ // true]");
+}
+
+TYPED_TEST(TestStringKernels, IsPrintableUnicode) {
+ // U+2008 (utf8: \xe2\x80\x88) is punctuation space, it is NOT printable
+ // U+0378 (utf8: \xCD\xB8) is an undefined char, it has no category
+ this->CheckUnary(
+ "utf8_is_printable",
+ "[\" 123azAZ!~\", null, \"\xe2\x80\x88\", \"\", \"\\r\", \"\xCD\xB8\"]", boolean(),
+ "[true, null, false, true, false, false]");
+}
+
+TYPED_TEST(TestStringKernels, IsSpaceUnicode) {
+ // U+2008 (utf8: \xe2\x80\x88) is punctuation space
+ this->CheckUnary("utf8_is_space", "[\" \", null, \" \", \"\\t\\r\"]", boolean(),
+ "[true, null, true, true]");
+ this->CheckUnary("utf8_is_space", "[\" a\", null, \"a \", \"~\", \"\xe2\x80\x88\"]",
+ boolean(), "[false, null, false, false, true]");
+}
+
+TYPED_TEST(TestStringKernels, IsTitleUnicode) {
+ // ٣ is arabic 3 (decimal), Φ capital
+ this->CheckUnary("utf8_is_title",
+ "[\"Is\", null, \"Is Title\", \"Is٣Title\", \"Is_DŽ\", \"Φ\", \"DŽ\"]",
+ boolean(), "[true, null, true, true, true, true, true]");
+ this->CheckUnary(
+ "utf8_is_title",
+ "[\"IsN\", null, \"IsNoTitle\", \"Is No T٣tle\", \"IsDŽ\", \"ΦΦ\", \"dž\", \"_\"]",
+ boolean(), "[false, null, false, false, false, false, false, false]");
+}
+
+// Older versions of utf8proc fail
+#if !(UTF8PROC_VERSION_MAJOR <= 2 && UTF8PROC_VERSION_MINOR < 5)
+
+TYPED_TEST(TestStringKernels, IsUpperUnicode) {
+ // ٣ is arabic 3 (decimal), Φ capital
+ this->CheckUnary("utf8_is_upper",
+ "[\"12\", null, \"٣a\", \"٣A\", \"1A\", \"Φ\", \"\", \"Ⅰ\", \"Ⅿ\"]",
+ boolean(),
+ "[false, null, false, true, true, true, false, true, true]");
+ // * Ⅰ to Ⅿ is a special case (roman capital), as well as Ⓐ to Ⓩ
+ // * ϒ - \xCF\x92 - Greek Upsilon with Hook Symbol - upper case, but has no direct lower
+ // case
+ // * U+1F88 - ᾈ - \E1\xBE\x88 - Greek Capital Letter Alpha with Psili and Prosgegrammeni
+ // - title case
+ // U+10400 - 𐐀 - \xF0x90x90x80 - Deseret Capital Letter Long - upper case
+ // * U+A7BA - Ꞻ - \xEA\x9E\xBA - Latin Capital Letter Glottal A - new in unicode 13
+ // (not tested since it depends on the version of libutf8proc)
+ // * U+A7BB - ꞻ - \xEA\x9E\xBB - Latin Small Letter Glottal A - new in unicode 13
+ this->CheckUnary("utf8_is_upper",
+ "[\"Ⓐ\", \"Ⓩ\", \"ϒ\", \"ᾈ\", \"\xEA\x9E\xBA\", \"xF0x90x90x80\"]",
+ boolean(), "[true, true, true, false, true, false]");
+}
+
+#endif // UTF8PROC_VERSION_MINOR >= 5
+
+#endif // ARROW_WITH_UTF8PROC
+
+TYPED_TEST(TestStringKernels, IsAlphaNumericAscii) {
+ this->CheckUnary("ascii_is_alnum",
+ "[\"ⱭɽⱤoW123\", null, \"Ɑ2\", \"!\", \"\", \"a space\", \"1 space\"]",
+ boolean(), "[false, null, false, false, false, false, false]");
+ this->CheckUnary("ascii_is_alnum", "[\"aRoW123\", null, \"a2\", \"a\", \"2\", \"\"]",
+ boolean(), "[true, null, true, true, true, false]");
+}
+
+TYPED_TEST(TestStringKernels, IsAlphaAscii) {
+ this->CheckUnary("ascii_is_alpha", "[\"ⱭɽⱤoW\", \"arrow\", null, \"a2\", \"!\", \"\"]",
+ boolean(), "[false, true, null, false, false, false]");
+}
+
+TYPED_TEST(TestStringKernels, IsDecimalAscii) {
+ // ٣ is arabic 3
+ this->CheckUnary("ascii_is_decimal", "[\"12\", null, \"٣\", \"Ⅳ\", \"1a\", \"\"]",
+ boolean(), "[true, null, false, false, false, false]");
+}
+
+TYPED_TEST(TestStringKernels, IsLowerAscii) {
+ // ٣ is arabic 3 (decimal), φ lower greek
+ this->CheckUnary("ascii_is_lower",
+ "[\"12\", null, \"٣a\", \"٣A\", \"1a\", \"φ\", \"\"]", boolean(),
+ "[false, null, true, false, true, false, false]");
+}
+TYPED_TEST(TestStringKernels, IsPrintableAscii) {
+ // \xe2\x80\x88 is punctuation space
+ this->CheckUnary("ascii_is_printable",
+ "[\" 123azAZ!~\", null, \"\xe2\x80\x88\", \"\", \"\\r\"]", boolean(),
+ "[true, null, false, true, false]");
+}
+
+TYPED_TEST(TestStringKernels, IsSpaceAscii) {
+ // \xe2\x80\x88 is punctuation space
+ this->CheckUnary("ascii_is_space", "[\" \", null, \" \", \"\\t\\r\"]", boolean(),
+ "[true, null, true, true]");
+ this->CheckUnary("ascii_is_space", "[\" a\", null, \"a \", \"~\", \"\xe2\x80\x88\"]",
+ boolean(), "[false, null, false, false, false]");
+}
+
+TYPED_TEST(TestStringKernels, IsTitleAscii) {
+ // ٣ is Arabic 3 (decimal), Φ capital
+ this->CheckUnary("ascii_is_title",
+ "[\"Is\", null, \"Is Title\", \"Is٣Title\", \"Is_DŽ\", \"Φ\", \"DŽ\"]",
+ boolean(), "[true, null, true, true, true, false, false]");
+ this->CheckUnary(
+ "ascii_is_title",
+ "[\"IsN\", null, \"IsNoTitle\", \"Is No T٣tle\", \"IsDŽ\", \"ΦΦ\", \"dž\", \"_\"]",
+ boolean(), "[false, null, false, false, true, false, false, false]");
+}
+
+TYPED_TEST(TestStringKernels, IsUpperAscii) {
+ // ٣ is arabic 3 (decimal), Φ capital greek
+ this->CheckUnary("ascii_is_upper",
+ "[\"12\", null, \"٣a\", \"٣A\", \"1A\", \"Φ\", \"\"]", boolean(),
+ "[false, null, false, true, true, false, false]");
+}
+
+TYPED_TEST(TestStringKernels, MatchSubstring) {
+ MatchSubstringOptions options{"ab"};
+ this->CheckUnary("match_substring", "[]", boolean(), "[]", &options);
+ this->CheckUnary("match_substring", R"(["abc", "acb", "cab", null, "bac", "AB"])",
+ boolean(), "[true, false, true, null, false, false]", &options);
+
+ MatchSubstringOptions options_repeated{"abab"};
+ this->CheckUnary("match_substring", R"(["abab", "ab", "cababc", null, "bac"])",
+ boolean(), "[true, false, true, null, false]", &options_repeated);
+
+ // ARROW-9460
+ MatchSubstringOptions options_double_char{"aab"};
+ this->CheckUnary("match_substring", R"(["aacb", "aab", "ab", "aaab"])", boolean(),
+ "[false, true, false, true]", &options_double_char);
+ MatchSubstringOptions options_double_char_2{"bbcaa"};
+ this->CheckUnary("match_substring", R"(["abcbaabbbcaabccabaab"])", boolean(), "[true]",
+ &options_double_char_2);
+
+ MatchSubstringOptions options_empty{""};
+ this->CheckUnary("match_substring", "[]", boolean(), "[]", &options);
+ this->CheckUnary("match_substring", R"(["abc", "acb", "cab", null, "bac", "AB", ""])",
+ boolean(), "[true, true, true, null, true, true, true]",
+ &options_empty);
+}
+
+#ifdef ARROW_WITH_RE2
+TYPED_TEST(TestStringKernels, MatchSubstringIgnoreCase) {
+ MatchSubstringOptions options_insensitive{"aé(", /*ignore_case=*/true};
+ this->CheckUnary("match_substring", R"(["abc", "aEb", "baÉ(", "aé(", "ae(", "Aé("])",
+ boolean(), "[false, false, true, true, false, true]",
+ &options_insensitive);
+}
+#else
+TYPED_TEST(TestStringKernels, MatchSubstringIgnoreCase) {
+ Datum input = ArrayFromJSON(this->type(), R"(["a"])");
+ MatchSubstringOptions options{"a", /*ignore_case=*/true};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
+ ::testing::HasSubstr("ignore_case requires RE2"),
+ CallFunction("match_substring", {input}, &options));
+}
+#endif
+
+TYPED_TEST(TestStringKernels, MatchStartsWith) {
+ MatchSubstringOptions options{"abab"};
+ this->CheckUnary("starts_with", "[]", boolean(), "[]", &options);
+ this->CheckUnary("starts_with", R"([null, "", "ab", "abab", "$abab", "abab$"])",
+ boolean(), "[null, false, false, true, false, true]", &options);
+ this->CheckUnary("starts_with", R"(["ABAB", "BABAB", "ABABC", "bAbAb", "aBaBc"])",
+ boolean(), "[false, false, false, false, false]", &options);
+}
+
+TYPED_TEST(TestStringKernels, MatchEndsWith) {
+ MatchSubstringOptions options{"abab"};
+ this->CheckUnary("ends_with", "[]", boolean(), "[]", &options);
+ this->CheckUnary("ends_with", R"([null, "", "ab", "abab", "$abab", "abab$"])",
+ boolean(), "[null, false, false, true, true, false]", &options);
+ this->CheckUnary("ends_with", R"(["ABAB", "BABAB", "ABABC", "bAbAb", "aBaBc"])",
+ boolean(), "[false, false, false, false, false]", &options);
+}
+
+#ifdef ARROW_WITH_RE2
+TYPED_TEST(TestStringKernels, MatchStartsWithIgnoreCase) {
+ MatchSubstringOptions options{"aBAb", /*ignore_case=*/true};
+ this->CheckUnary("starts_with", "[]", boolean(), "[]", &options);
+ this->CheckUnary("starts_with", R"([null, "", "ab", "abab", "$abab", "abab$"])",
+ boolean(), "[null, false, false, true, false, true]", &options);
+ this->CheckUnary("starts_with", R"(["ABAB", "$ABAB", "ABAB$", "$AbAb", "aBaB$"])",
+ boolean(), "[true, false, true, false, true]", &options);
+}
+
+TYPED_TEST(TestStringKernels, MatchEndsWithIgnoreCase) {
+ MatchSubstringOptions options{"aBAb", /*ignore_case=*/true};
+ this->CheckUnary("ends_with", "[]", boolean(), "[]", &options);
+ this->CheckUnary("ends_with", R"([null, "", "ab", "abab", "$abab", "abab$"])",
+ boolean(), "[null, false, false, true, true, false]", &options);
+ this->CheckUnary("ends_with", R"(["ABAB", "$ABAB", "ABAB$", "$AbAb", "aBaB$"])",
+ boolean(), "[true, true, false, true, false]", &options);
+}
+#else
+TYPED_TEST(TestStringKernels, MatchStartsWithIgnoreCase) {
+ Datum input = ArrayFromJSON(this->type(), R"(["a"])");
+ MatchSubstringOptions options{"a", /*ignore_case=*/true};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
+ ::testing::HasSubstr("ignore_case requires RE2"),
+ CallFunction("starts_with", {input}, &options));
+}
+
+TYPED_TEST(TestStringKernels, MatchEndsWithIgnoreCase) {
+ Datum input = ArrayFromJSON(this->type(), R"(["a"])");
+ MatchSubstringOptions options{"a", /*ignore_case=*/true};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
+ ::testing::HasSubstr("ignore_case requires RE2"),
+ CallFunction("ends_with", {input}, &options));
+}
+#endif
+
+#ifdef ARROW_WITH_RE2
+TYPED_TEST(TestStringKernels, MatchSubstringRegex) {
+ MatchSubstringOptions options{"ab"};
+ this->CheckUnary("match_substring_regex", "[]", boolean(), "[]", &options);
+ this->CheckUnary("match_substring_regex", R"(["abc", "acb", "cab", null, "bac", "AB"])",
+ boolean(), "[true, false, true, null, false, false]", &options);
+ MatchSubstringOptions options_repeated{"(ab){2}"};
+ this->CheckUnary("match_substring_regex", R"(["abab", "ab", "cababc", null, "bac"])",
+ boolean(), "[true, false, true, null, false]", &options_repeated);
+ MatchSubstringOptions options_digit{"\\d"};
+ this->CheckUnary("match_substring_regex", R"(["aacb", "a2ab", "", "24"])", boolean(),
+ "[false, true, false, true]", &options_digit);
+ MatchSubstringOptions options_star{"a*b"};
+ this->CheckUnary("match_substring_regex", R"(["aacb", "aab", "dab", "caaab", "b", ""])",
+ boolean(), "[true, true, true, true, true, false]", &options_star);
+ MatchSubstringOptions options_plus{"a+b"};
+ this->CheckUnary("match_substring_regex", R"(["aacb", "aab", "dab", "caaab", "b", ""])",
+ boolean(), "[false, true, true, true, false, false]", &options_plus);
+ MatchSubstringOptions options_insensitive{"ab|é", /*ignore_case=*/true};
+ this->CheckUnary("match_substring_regex", R"(["abc", "acb", "É", null, "bac", "AB"])",
+ boolean(), "[true, false, true, null, false, true]",
+ &options_insensitive);
+
+ // Unicode character semantics
+ // "\pL" means: unicode category "letter"
+ // (re2 interprets "\w" as ASCII-only: https://github.com/google/re2/wiki/Syntax)
+ MatchSubstringOptions options_unicode{"^\\pL+$"};
+ this->CheckUnary("match_substring_regex", R"(["été", "ß", "€", ""])", boolean(),
+ "[true, true, false, false]", &options_unicode);
+}
+
+TYPED_TEST(TestStringKernels, MatchSubstringRegexNoOptions) {
+ Datum input = ArrayFromJSON(this->type(), "[]");
+ ASSERT_RAISES(Invalid, CallFunction("match_substring_regex", {input}));
+}
+
+TYPED_TEST(TestStringKernels, MatchSubstringRegexInvalid) {
+ Datum input = ArrayFromJSON(this->type(), "[null]");
+ MatchSubstringOptions options{"invalid["};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Invalid regular expression: missing ]"),
+ CallFunction("match_substring_regex", {input}, &options));
+}
+
+TYPED_TEST(TestStringKernels, MatchLike) {
+ auto inputs = R"(["foo", "bar", "foobar", "barfoo", "o", "\nfoo", "foo\n", null])";
+
+ MatchSubstringOptions prefix_match{"foo%"};
+ this->CheckUnary("match_like", "[]", boolean(), "[]", &prefix_match);
+ this->CheckUnary("match_like", inputs, boolean(),
+ "[true, false, true, false, false, false, true, null]", &prefix_match);
+
+ MatchSubstringOptions suffix_match{"%foo"};
+ this->CheckUnary("match_like", inputs, boolean(),
+ "[true, false, false, true, false, true, false, null]", &suffix_match);
+
+ MatchSubstringOptions substring_match{"%foo%"};
+ this->CheckUnary("match_like", inputs, boolean(),
+ "[true, false, true, true, false, true, true, null]",
+ &substring_match);
+
+ MatchSubstringOptions trivial_match{"%%"};
+ this->CheckUnary("match_like", inputs, boolean(),
+ "[true, true, true, true, true, true, true, null]", &trivial_match);
+
+ MatchSubstringOptions regex_match{"foo%bar"};
+ this->CheckUnary("match_like", inputs, boolean(),
+ "[false, false, true, false, false, false, false, null]",
+ &regex_match);
+
+ // ignore_case means this still gets mapped to a regex search
+ MatchSubstringOptions insensitive_substring{"%é%", /*ignore_case=*/true};
+ this->CheckUnary("match_like", R"(["é", "fooÉbar", "e"])", boolean(),
+ "[true, true, false]", &insensitive_substring);
+
+ MatchSubstringOptions insensitive_regex{"_é%", /*ignore_case=*/true};
+ this->CheckUnary("match_like", R"(["éfoo", "aÉfoo", "e"])", boolean(),
+ "[false, true, false]", &insensitive_regex);
+}
+
+TYPED_TEST(TestStringKernels, MatchLikeEscaping) {
+ auto inputs = R"(["%%foo", "_bar", "({", "\\baz"])";
+
+ // N.B. I believe Impala mistakenly optimizes these into substring searches
+ MatchSubstringOptions escape_percent{"\\%%"};
+ this->CheckUnary("match_like", inputs, boolean(), "[true, false, false, false]",
+ &escape_percent);
+
+ MatchSubstringOptions not_substring{"%\\%%"};
+ this->CheckUnary("match_like", inputs, boolean(), "[true, false, false, false]",
+ &not_substring);
+
+ MatchSubstringOptions escape_underscore{"\\____"};
+ this->CheckUnary("match_like", inputs, boolean(), "[false, true, false, false]",
+ &escape_underscore);
+
+ MatchSubstringOptions escape_regex{"(%"};
+ this->CheckUnary("match_like", inputs, boolean(), "[false, false, true, false]",
+ &escape_regex);
+
+ MatchSubstringOptions escape_escape{"\\\\%"};
+ this->CheckUnary("match_like", inputs, boolean(), "[false, false, false, true]",
+ &escape_escape);
+
+ MatchSubstringOptions special_chars{"!@#$^&*()[]{}.?"};
+ this->CheckUnary("match_like", R"(["!@#$^&*()[]{}.?"])", boolean(), "[true]",
+ &special_chars);
+
+ MatchSubstringOptions escape_sequences{"\n\t%"};
+ this->CheckUnary("match_like", R"(["\n\tfoo\t", "\n\t", "\n"])", boolean(),
+ "[true, true, false]", &escape_sequences);
+}
+#endif
+
+TYPED_TEST(TestStringKernels, FindSubstring) {
+ MatchSubstringOptions options{"ab"};
+ this->CheckUnary("find_substring", "[]", this->offset_type(), "[]", &options);
+ this->CheckUnary("find_substring", R"(["abc", "acb", "cab", null, "bac"])",
+ this->offset_type(), "[0, -1, 1, null, -1]", &options);
+
+ MatchSubstringOptions options_repeated{"abab"};
+ this->CheckUnary("find_substring", R"(["abab", "ab", "cababc", null, "bac"])",
+ this->offset_type(), "[0, -1, 1, null, -1]", &options_repeated);
+
+ MatchSubstringOptions options_double_char{"aab"};
+ this->CheckUnary("find_substring", R"(["aacb", "aab", "ab", "aaab"])",
+ this->offset_type(), "[-1, 0, -1, 1]", &options_double_char);
+
+ MatchSubstringOptions options_double_char_2{"bbcaa"};
+ this->CheckUnary("find_substring", R"(["abcbaabbbcaabccabaab"])", this->offset_type(),
+ "[7]", &options_double_char_2);
+}
+
+TYPED_TEST(TestStringKernels, SplitBasics) {
+ SplitPatternOptions options{" "};
+ // basics
+ this->CheckUnary("split_pattern", R"(["foo bar", "foo"])", list(this->type()),
+ R"([["foo", "bar"], ["foo"]])", &options);
+ this->CheckUnary("split_pattern", R"(["foo bar", "foo", null])", list(this->type()),
+ R"([["foo", "bar"], ["foo"], null])", &options);
+ // edgy cases
+ this->CheckUnary("split_pattern", R"(["f o o "])", list(this->type()),
+ R"([["f", "", "o", "o", ""]])", &options);
+ this->CheckUnary("split_pattern", "[]", list(this->type()), "[]", &options);
+ // longer patterns
+ SplitPatternOptions options_long{"---"};
+ this->CheckUnary("split_pattern", R"(["-foo---bar--", "---foo---b"])",
+ list(this->type()), R"([["-foo", "bar--"], ["", "foo", "b"]])",
+ &options_long);
+ SplitPatternOptions options_long_reverse{"---", -1, /*reverse=*/true};
+ this->CheckUnary("split_pattern", R"(["-foo---bar--", "---foo---b"])",
+ list(this->type()), R"([["-foo", "bar--"], ["", "foo", "b"]])",
+ &options_long_reverse);
+}
+
+TYPED_TEST(TestStringKernels, SplitMax) {
+ SplitPatternOptions options{"---", 2};
+ SplitPatternOptions options_reverse{"---", 2, /*reverse=*/true};
+ this->CheckUnary("split_pattern", R"(["foo---bar", "foo", "foo---bar------ar"])",
+ list(this->type()),
+ R"([["foo", "bar"], ["foo"], ["foo", "bar", "---ar"]])", &options);
+ this->CheckUnary(
+ "split_pattern", R"(["foo---bar", "foo", "foo---bar------ar"])", list(this->type()),
+ R"([["foo", "bar"], ["foo"], ["foo---bar", "", "ar"]])", &options_reverse);
+}
+
+TYPED_TEST(TestStringKernels, SplitWhitespaceAscii) {
+ SplitOptions options;
+ SplitOptions options_max{1};
+ // basics
+ this->CheckUnary("ascii_split_whitespace", R"(["foo bar", "foo bar \tba"])",
+ list(this->type()), R"([["foo", "bar"], ["foo", "bar", "ba"]])",
+ &options);
+ this->CheckUnary("ascii_split_whitespace", R"(["foo bar", "foo bar \tba"])",
+ list(this->type()), R"([["foo", "bar"], ["foo", "bar \tba"]])",
+ &options_max);
+}
+
+TYPED_TEST(TestStringKernels, SplitWhitespaceAsciiReverse) {
+ SplitOptions options{-1, /*reverse=*/true};
+ SplitOptions options_max{1, /*reverse=*/true};
+ // basics
+ this->CheckUnary("ascii_split_whitespace", R"(["foo bar", "foo bar \tba"])",
+ list(this->type()), R"([["foo", "bar"], ["foo", "bar", "ba"]])",
+ &options);
+ this->CheckUnary("ascii_split_whitespace", R"(["foo bar", "foo bar \tba"])",
+ list(this->type()), R"([["foo", "bar"], ["foo bar", "ba"]])",
+ &options_max);
+}
+
+TYPED_TEST(TestStringKernels, SplitWhitespaceUTF8) {
+ SplitOptions options;
+ SplitOptions options_max{1};
+ // \xe2\x80\x88 is punctuation space
+ this->CheckUnary("utf8_split_whitespace",
+ "[\"foo bar\", \"foo\xe2\x80\x88 bar \\tba\"]", list(this->type()),
+ R"([["foo", "bar"], ["foo", "bar", "ba"]])", &options);
+ this->CheckUnary("utf8_split_whitespace",
+ "[\"foo bar\", \"foo\xe2\x80\x88 bar \\tba\"]", list(this->type()),
+ R"([["foo", "bar"], ["foo", "bar \tba"]])", &options_max);
+}
+
+TYPED_TEST(TestStringKernels, SplitWhitespaceUTF8Reverse) {
+ SplitOptions options{-1, /*reverse=*/true};
+ SplitOptions options_max{1, /*reverse=*/true};
+ // \xe2\x80\x88 is punctuation space
+ this->CheckUnary("utf8_split_whitespace",
+ "[\"foo bar\", \"foo\xe2\x80\x88 bar \\tba\"]", list(this->type()),
+ R"([["foo", "bar"], ["foo", "bar", "ba"]])", &options);
+ this->CheckUnary("utf8_split_whitespace",
+ "[\"foo bar\", \"foo\xe2\x80\x88 bar \\tba\"]", list(this->type()),
+ "[[\"foo\", \"bar\"], [\"foo\xe2\x80\x88 bar\", \"ba\"]]",
+ &options_max);
+}
+
+#ifdef ARROW_WITH_RE2
+TYPED_TEST(TestStringKernels, SplitRegex) {
+ SplitPatternOptions options{"a+|b"};
+
+ this->CheckUnary(
+ "split_pattern_regex", R"(["aaaab", "foob", "foo bar", "foo", "AaaaBaaaC", null])",
+ list(this->type()),
+ R"([["", "", ""], ["foo", ""], ["foo ", "", "r"], ["foo"], ["A", "B", "C"], null])",
+ &options);
+
+ options.max_splits = 1;
+ this->CheckUnary(
+ "split_pattern_regex", R"(["aaaab", "foob", "foo bar", "foo", "AaaaBaaaC", null])",
+ list(this->type()),
+ R"([["", "b"], ["foo", ""], ["foo ", "ar"], ["foo"], ["A", "BaaaC"], null])",
+ &options);
+}
+
+TYPED_TEST(TestStringKernels, SplitRegexReverse) {
+ SplitPatternOptions options{"a+|b", /*max_splits=*/1, /*reverse=*/true};
+ Datum input = ArrayFromJSON(this->type(), R"(["a"])");
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ NotImplemented, ::testing::HasSubstr("Cannot split in reverse with regex"),
+ CallFunction("split_pattern_regex", {input}, &options));
+}
+#endif
+
+TYPED_TEST(TestStringKernels, Utf8ReplaceSlice) {
+ ReplaceSliceOptions options{0, 1, "χχ"};
+ this->CheckUnary("utf8_replace_slice", "[]", this->type(), "[]", &options);
+ this->CheckUnary("utf8_replace_slice", R"([null, "", "π", "πb", "πbθ"])", this->type(),
+ R"([null, "χχ", "χχ", "χχb", "χχbθ"])", &options);
+
+ ReplaceSliceOptions options_whole{0, 5, "χχ"};
+ this->CheckUnary("utf8_replace_slice",
+ R"([null, "", "π", "πb", "πbθ", "πbθde", "πbθdef"])", this->type(),
+ R"([null, "χχ", "χχ", "χχ", "χχ", "χχ", "χχf"])", &options_whole);
+
+ ReplaceSliceOptions options_middle{2, 4, "χχ"};
+ this->CheckUnary("utf8_replace_slice",
+ R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(),
+ R"([null, "χχ", "πχχ", "πbχχ", "πbχχ", "πbχχ", "πbχχe"])",
+ &options_middle);
+
+ ReplaceSliceOptions options_neg_start{-3, -2, "χχ"};
+ this->CheckUnary("utf8_replace_slice",
+ R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(),
+ R"([null, "χχ", "χχπ", "χχπb", "χχbθ", "πχχθd", "πbχχde"])",
+ &options_neg_start);
+
+ ReplaceSliceOptions options_neg_end{2, -2, "χχ"};
+ this->CheckUnary("utf8_replace_slice",
+ R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(),
+ R"([null, "χχ", "πχχ", "πbχχ", "πbχχθ", "πbχχθd", "πbχχde"])",
+ &options_neg_end);
+
+ ReplaceSliceOptions options_neg_pos{-1, 2, "χχ"};
+ this->CheckUnary("utf8_replace_slice",
+ R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(),
+ R"([null, "χχ", "χχ", "πχχ", "πbχχθ", "πbθχχd", "πbθdχχe"])",
+ &options_neg_pos);
+
+ // Effectively the same as [2, 2)
+ ReplaceSliceOptions options_flip{2, 0, "χχ"};
+ this->CheckUnary("utf8_replace_slice",
+ R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(),
+ R"([null, "χχ", "πχχ", "πbχχ", "πbχχθ", "πbχχθd", "πbχχθde"])",
+ &options_flip);
+
+ // Effectively the same as [-3, -3)
+ ReplaceSliceOptions options_neg_flip{-3, -5, "χχ"};
+ this->CheckUnary("utf8_replace_slice",
+ R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(),
+ R"([null, "χχ", "χχπ", "χχπb", "χχπbθ", "πχχbθd", "πbχχθde"])",
+ &options_neg_flip);
+}
+
+TYPED_TEST(TestStringKernels, ReplaceSubstring) {
+ ReplaceSubstringOptions options{"foo", "bazz"};
+ this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])",
+ this->type(), R"(["bazz", "this bazz that bazz", null])", &options);
+}
+
+TYPED_TEST(TestStringKernels, ReplaceSubstringLimited) {
+ ReplaceSubstringOptions options{"foo", "bazz", 1};
+ this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])",
+ this->type(), R"(["bazz", "this bazz that foo", null])", &options);
+}
+
+TYPED_TEST(TestStringKernels, ReplaceSubstringNoOptions) {
+ Datum input = ArrayFromJSON(this->type(), "[]");
+ ASSERT_RAISES(Invalid, CallFunction("replace_substring", {input}));
+}
+
+#ifdef ARROW_WITH_RE2
+TYPED_TEST(TestStringKernels, ReplaceSubstringRegex) {
+ ReplaceSubstringOptions options_regex{"(fo+)\\s*", "\\1-bazz"};
+ this->CheckUnary("replace_substring_regex", R"(["foo ", "this foo that foo", null])",
+ this->type(), R"(["foo-bazz", "this foo-bazzthat foo-bazz", null])",
+ &options_regex);
+ // make sure we match non-overlapping
+ ReplaceSubstringOptions options_regex2{"(a.a)", "aba\\1"};
+ this->CheckUnary("replace_substring_regex", R"(["aaaaaa"])", this->type(),
+ R"(["abaaaaabaaaa"])", &options_regex2);
+
+ // ARROW-12774
+ ReplaceSubstringOptions options_regex3{"X", "Y"};
+ this->CheckUnary("replace_substring_regex",
+ R"(["A","A","A","A","A","A","A","A","A","A","A","A","A","A","A","A"])",
+ this->type(),
+ R"(["A","A","A","A","A","A","A","A","A","A","A","A","A","A","A","A"])",
+ &options_regex3);
+}
+
+TYPED_TEST(TestStringKernels, ReplaceSubstringRegexLimited) {
+ // With a finite number of replacements
+ ReplaceSubstringOptions options1{"foo", "bazz", 1};
+ this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])",
+ this->type(), R"(["bazz", "this bazz that foo", null])", &options1);
+ ReplaceSubstringOptions options_regex1{"(fo+)\\s*", "\\1-bazz", 1};
+ this->CheckUnary("replace_substring_regex", R"(["foo ", "this foo that foo", null])",
+ this->type(), R"(["foo-bazz", "this foo-bazzthat foo", null])",
+ &options_regex1);
+}
+
+TYPED_TEST(TestStringKernels, ReplaceSubstringRegexNoOptions) {
+ Datum input = ArrayFromJSON(this->type(), "[]");
+ ASSERT_RAISES(Invalid, CallFunction("replace_substring_regex", {input}));
+}
+
+TYPED_TEST(TestStringKernels, ReplaceSubstringRegexInvalid) {
+ Datum input = ArrayFromJSON(this->type(), R"(["foo"])");
+ ReplaceSubstringOptions options{"invalid[", ""};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Invalid regular expression: missing ]"),
+ CallFunction("replace_substring_regex", {input}, &options));
+
+ // Capture group number out of range
+ options = ReplaceSubstringOptions{"(.)", "\\9"};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Invalid replacement string"),
+ CallFunction("replace_substring_regex", {input}, &options));
+}
+
+TYPED_TEST(TestStringKernels, ExtractRegex) {
+ ExtractRegexOptions options{"(?P<letter>[ab])(?P<digit>\\d)"};
+ auto type = struct_({field("letter", this->type()), field("digit", this->type())});
+ this->CheckUnary("extract_regex", R"([])", type, R"([])", &options);
+ this->CheckUnary(
+ "extract_regex", R"(["a1", "b2", "c3", null])", type,
+ R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "2"}, null, null])",
+ &options);
+ this->CheckUnary(
+ "extract_regex", R"(["a1", "c3", null, "b2"])", type,
+ R"([{"letter": "a", "digit": "1"}, null, null, {"letter": "b", "digit": "2"}])",
+ &options);
+ this->CheckUnary("extract_regex", R"(["a1", "b2"])", type,
+ R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "2"}])",
+ &options);
+ this->CheckUnary("extract_regex", R"(["a1", "zb3z"])", type,
+ R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "3"}])",
+ &options);
+}
+
+TYPED_TEST(TestStringKernels, ExtractRegexNoCapture) {
+ // XXX Should we accept this or is it a user error?
+ ExtractRegexOptions options{"foo"};
+ auto type = struct_({});
+ this->CheckUnary("extract_regex", R"(["oofoo", "bar", null])", type,
+ R"([{}, null, null])", &options);
+}
+
+TYPED_TEST(TestStringKernels, ExtractRegexNoOptions) {
+ Datum input = ArrayFromJSON(this->type(), "[]");
+ ASSERT_RAISES(Invalid, CallFunction("extract_regex", {input}));
+}
+
+TYPED_TEST(TestStringKernels, ExtractRegexInvalid) {
+ Datum input = ArrayFromJSON(this->type(), "[]");
+ ExtractRegexOptions options{"invalid["};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Invalid regular expression: missing ]"),
+ CallFunction("extract_regex", {input}, &options));
+
+ options = ExtractRegexOptions{"(.)"};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Regular expression contains unnamed groups"),
+ CallFunction("extract_regex", {input}, &options));
+}
+
+#endif
+
+TYPED_TEST(TestStringKernels, Strptime) {
+ std::string input1 = R"(["5/1/2020", null, "12/11/1900"])";
+ std::string output1 = R"(["2020-05-01", null, "1900-12-11"])";
+ StrptimeOptions options("%m/%d/%Y", TimeUnit::MICRO);
+ this->CheckUnary("strptime", input1, timestamp(TimeUnit::MICRO), output1, &options);
+}
+
+TYPED_TEST(TestStringKernels, StrptimeDoesNotProvideDefaultOptions) {
+ auto input = ArrayFromJSON(this->type(), R"(["2020-05-01", null, "1900-12-11"])");
+ ASSERT_RAISES(Invalid, CallFunction("strptime", {input}));
+}
+
+TYPED_TEST(TestStringKernels, BinaryJoin) {
+ // Scalar separator
+ auto separator = this->scalar("--");
+ std::string list_json =
+ R"([["a", "bb", "ccc"], [], null, ["dd"], ["eee", null], ["ff", ""]])";
+ auto expected =
+ ArrayFromJSON(this->type(), R"(["a--bb--ccc", "", null, "dd", null, "ff--"])");
+ CheckScalarBinary("binary_join", ArrayFromJSON(list(this->type()), list_json),
+ Datum(separator), expected);
+ CheckScalarBinary("binary_join", ArrayFromJSON(large_list(this->type()), list_json),
+ Datum(separator), expected);
+
+ auto separator_null = MakeNullScalar(this->type());
+ expected = ArrayFromJSON(this->type(), R"([null, null, null, null, null, null])");
+ CheckScalarBinary("binary_join", ArrayFromJSON(list(this->type()), list_json),
+ separator_null, expected);
+ CheckScalarBinary("binary_join", ArrayFromJSON(large_list(this->type()), list_json),
+ separator_null, expected);
+
+ // Array list, Array separator
+ auto separators =
+ ArrayFromJSON(this->type(), R"(["1", "2", "3", "4", "5", "6", null])");
+ list_json =
+ R"([["a", "bb", "ccc"], [], null, ["dd"], ["eee", null], ["ff", ""], ["hh", "ii"]])";
+ expected =
+ ArrayFromJSON(this->type(), R"(["a1bb1ccc", "", null, "dd", null, "ff6", null])");
+ CheckScalarBinary("binary_join", ArrayFromJSON(list(this->type()), list_json),
+ separators, expected);
+ CheckScalarBinary("binary_join", ArrayFromJSON(large_list(this->type()), list_json),
+ separators, expected);
+
+ // Scalar list, Array separator
+ separators = ArrayFromJSON(this->type(), R"(["1", "", null])");
+ list_json = R"(["a", "bb", "ccc"])";
+ expected = ArrayFromJSON(this->type(), R"(["a1bb1ccc", "abbccc", null])");
+ CheckScalarBinary("binary_join", ScalarFromJSON(list(this->type()), list_json),
+ separators, expected);
+ CheckScalarBinary("binary_join", ScalarFromJSON(large_list(this->type()), list_json),
+ separators, expected);
+ list_json = R"(["a", "bb", null])";
+ expected = ArrayFromJSON(this->type(), R"([null, null, null])");
+ CheckScalarBinary("binary_join", ScalarFromJSON(list(this->type()), list_json),
+ separators, expected);
+ CheckScalarBinary("binary_join", ScalarFromJSON(large_list(this->type()), list_json),
+ separators, expected);
+}
+
+TYPED_TEST(TestStringKernels, PadUTF8) {
+ // \xe2\x80\x88 = \u2008 is punctuation space, \xc3\xa1 = \u00E1 = á
+ PadOptions options{/*width=*/5, "\xe2\x80\x88"};
+ this->CheckUnary(
+ "utf8_center", R"([null, "a", "bb", "b\u00E1r", "foobar"])", this->type(),
+ R"([null, "\u2008\u2008a\u2008\u2008", "\u2008bb\u2008\u2008", "\u2008b\u00E1r\u2008", "foobar"])",
+ &options);
+ this->CheckUnary(
+ "utf8_lpad", R"([null, "a", "bb", "b\u00E1r", "foobar"])", this->type(),
+ R"([null, "\u2008\u2008\u2008\u2008a", "\u2008\u2008\u2008bb", "\u2008\u2008b\u00E1r", "foobar"])",
+ &options);
+ this->CheckUnary(
+ "utf8_rpad", R"([null, "a", "bb", "b\u00E1r", "foobar"])", this->type(),
+ R"([null, "a\u2008\u2008\u2008\u2008", "bb\u2008\u2008\u2008", "b\u00E1r\u2008\u2008", "foobar"])",
+ &options);
+
+ PadOptions options_bad{/*width=*/3, /*padding=*/"spam"};
+ auto input = ArrayFromJSON(this->type(), R"(["foo"])");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("Padding must be one codepoint"),
+ CallFunction("utf8_lpad", {input}, &options_bad));
+ options_bad.padding = "";
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("Padding must be one codepoint"),
+ CallFunction("utf8_lpad", {input}, &options_bad));
+}
+
+#ifdef ARROW_WITH_UTF8PROC
+
+TYPED_TEST(TestStringKernels, TrimWhitespaceUTF8) {
+ // \xe2\x80\x88 is punctuation space
+ this->CheckUnary("utf8_trim_whitespace",
+ "[\" \\tfoo\", null, \"bar \", \" \xe2\x80\x88 foo bar \"]",
+ this->type(), "[\"foo\", null, \"bar\", \"foo bar\"]");
+ this->CheckUnary("utf8_rtrim_whitespace",
+ "[\" \\tfoo\", null, \"bar \", \" \xe2\x80\x88 foo bar \"]",
+ this->type(),
+ "[\" \\tfoo\", null, \"bar\", \" \xe2\x80\x88 foo bar\"]");
+ this->CheckUnary("utf8_ltrim_whitespace",
+ "[\" \\tfoo\", null, \"bar \", \" \xe2\x80\x88 foo bar \"]",
+ this->type(), "[\"foo\", null, \"bar \", \"foo bar \"]");
+}
+
+TYPED_TEST(TestStringKernels, TrimUTF8) {
+ auto options = TrimOptions{"ab"};
+ this->CheckUnary("utf8_trim", "[\"azȺz矢ba\", null, \"bab\", \"zȺz\"]", this->type(),
+ "[\"zȺz矢\", null, \"\", \"zȺz\"]", &options);
+ this->CheckUnary("utf8_ltrim", "[\"azȺz矢ba\", null, \"bab\", \"zȺz\"]", this->type(),
+ "[\"zȺz矢ba\", null, \"\", \"zȺz\"]", &options);
+ this->CheckUnary("utf8_rtrim", "[\"azȺz矢ba\", null, \"bab\", \"zȺz\"]", this->type(),
+ "[\"azȺz矢\", null, \"\", \"zȺz\"]", &options);
+
+ options = TrimOptions{"ȺA"};
+ this->CheckUnary("utf8_trim", "[\"ȺȺfoo矢ȺAȺ\", null, \"barȺAȺ\", \"ȺAȺfooȺAȺ矢barA\"]",
+ this->type(), "[\"foo矢\", null, \"bar\", \"fooȺAȺ矢bar\"]", &options);
+ this->CheckUnary(
+ "utf8_ltrim", "[\"ȺȺfoo矢ȺAȺ\", null, \"barȺAȺ\", \"ȺAȺfooȺAȺ矢barA\"]",
+ this->type(), "[\"foo矢ȺAȺ\", null, \"barȺAȺ\", \"fooȺAȺ矢barA\"]", &options);
+ this->CheckUnary(
+ "utf8_rtrim", "[\"ȺȺfoo矢ȺAȺ\", null, \"barȺAȺ\", \"ȺAȺfooȺAȺ矢barA\"]",
+ this->type(), "[\"ȺȺfoo矢\", null, \"bar\", \"ȺAȺfooȺAȺ矢bar\"]", &options);
+
+ TrimOptions options_invalid{"ɑa\xFFɑ"};
+ auto input = ArrayFromJSON(this->type(), "[\"foo\"]");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("Invalid UTF8"),
+ CallFunction("utf8_trim", {input}, &options_invalid));
+}
+#endif
+
+// produce test data with e.g.:
+// repr([k[-3:1] for k in ["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"]]).replace("'", '"')
+
+#ifdef ARROW_WITH_UTF8PROC
+TYPED_TEST(TestStringKernels, SliceCodeunitsBasic) {
+ SliceOptions options{2, 4};
+ this->CheckUnary("utf8_slice_codeunits", R"(["foo", "fo", null, "foo bar"])",
+ this->type(), R"(["o", "", null, "o "])", &options);
+ SliceOptions options_2{2, 3};
+ // ensure we slice in codeunits, not graphemes
+ // a\u0308 is ä, which is 1 grapheme (character), but two codepoints
+ // \u0308 in utf8 encoding is \xcc\x88
+ this->CheckUnary("utf8_slice_codeunits", R"(["ää", "bä"])", this->type(),
+ "[\"a\", \"\xcc\x88\"]", &options_2);
+ SliceOptions options_empty_pos{6, 6};
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓öõ"])", this->type(), R"(["",
+ ""])",
+ &options_empty_pos);
+ SliceOptions options_empty_neg{-6, -6};
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓öõ"])", this->type(), R"(["",
+ ""])",
+ &options_empty_neg);
+ SliceOptions options_empty_neg_to_zero{-6, 0};
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓öõ"])", this->type(), R"(["", ""])",
+ &options_empty_neg_to_zero);
+
+ // end is beyond 0, but before start (hence empty)
+ SliceOptions options_edgecase_1{-3, 1};
+ this->CheckUnary("utf8_slice_codeunits", R"(["𝑓öõḍš"])", this->type(), R"([""])",
+ &options_edgecase_1);
+
+ // this is a safeguard agains an optimization path possible, but actually a tricky case
+ SliceOptions options_edgecase_2{-6, -2};
+ this->CheckUnary("utf8_slice_codeunits", R"(["𝑓öõḍš"])", this->type(), R"(["𝑓öõ"])",
+ &options_edgecase_2);
+
+ auto input = ArrayFromJSON(this->type(), R"(["𝑓öõḍš"])");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr("Attempted to initialize KernelState from null FunctionOptions"),
+ CallFunction("utf8_slice_codeunits", {input}));
+
+ SliceOptions options_invalid{2, 4, 0};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, testing::HasSubstr("Slice step cannot be zero"),
+ CallFunction("utf8_slice_codeunits", {input}, &options_invalid));
+}
+
+TYPED_TEST(TestStringKernels, SliceCodeunitsPosPos) {
+ SliceOptions options{2, 4};
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])",
+ this->type(), R"(["", "", "", "õ", "õḍ", "õḍ"])", &options);
+ SliceOptions options_step{1, 5, 2};
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])",
+ this->type(), R"(["", "", "ö", "ö", "öḍ", "öḍ"])", &options_step);
+ SliceOptions options_step_neg{5, 1, -2};
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])",
+ this->type(), R"(["", "", "", "õ", "ḍ", "šõ"])", &options_step_neg);
+ options_step_neg.stop = 0;
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ","𝑓öõḍš"])",
+ this->type(), R"(["", "", "ö", "õ", "ḍö", "šõ"])", &options_step_neg);
+}
+
+TYPED_TEST(TestStringKernels, SliceCodeunitsPosNeg) {
+ SliceOptions options{2, -1};
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])",
+ this->type(), R"(["", "", "", "", "õ", "õḍ"])", &options);
+ SliceOptions options_step{1, -1, 2};
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "f", "fö", "föo", "föod","foodš"])",
+ this->type(), R"(["", "", "", "ö", "ö", "od"])", &options_step);
+ SliceOptions options_step_neg{3, -4, -2};
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ","𝑓öõḍš"])",
+ this->type(), R"(["", "𝑓", "ö", "õ𝑓", "ḍö", "ḍ"])", &options_step_neg);
+ options_step_neg.stop = -5;
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ","𝑓öõḍš"])",
+ this->type(), R"(["", "𝑓", "ö", "õ𝑓", "ḍö", "ḍö"])",
+ &options_step_neg);
+}
+
+TYPED_TEST(TestStringKernels, SliceCodeunitsNegNeg) {
+ SliceOptions options{-2, -1};
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])",
+ this->type(), R"(["", "", "𝑓", "ö", "õ", "ḍ"])", &options);
+ SliceOptions options_step{-4, -1, 2};
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])",
+ this->type(), R"(["", "", "𝑓", "𝑓", "𝑓õ", "öḍ"])", &options_step);
+ SliceOptions options_step_neg{-1, -3, -2};
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])",
+ this->type(), R"(["", "𝑓", "ö", "õ", "ḍ", "š"])", &options_step_neg);
+ options_step_neg.stop = -4;
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])",
+ this->type(), R"(["", "𝑓", "ö", "õ𝑓", "ḍö", "šõ"])",
+ &options_step_neg);
+}
+
+TYPED_TEST(TestStringKernels, SliceCodeunitsNegPos) {
+ SliceOptions options{-2, 4};
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])",
+ this->type(), R"(["", "𝑓", "𝑓ö", "öõ", "õḍ", "ḍ"])", &options);
+ SliceOptions options_step{-4, 4, 2};
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])",
+ this->type(), R"(["", "𝑓", "𝑓", "𝑓õ", "𝑓õ", "öḍ"])", &options_step);
+ SliceOptions options_step_neg{-1, 1, -2};
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])",
+ this->type(), R"(["", "", "", "õ", "ḍ", "šõ"])", &options_step_neg);
+ options_step_neg.stop = 0;
+ this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])",
+ this->type(), R"(["", "", "ö", "õ", "ḍö", "šõ"])", &options_step_neg);
+}
+
+#endif // ARROW_WITH_UTF8PROC
+
+TYPED_TEST(TestStringKernels, PadAscii) {
+ PadOptions options{/*width=*/5, " "};
+ this->CheckUnary("ascii_center", R"([null, "a", "bb", "bar", "foobar"])", this->type(),
+ R"([null, " a ", " bb ", " bar ", "foobar"])", &options);
+ this->CheckUnary("ascii_lpad", R"([null, "a", "bb", "bar", "foobar"])", this->type(),
+ R"([null, " a", " bb", " bar", "foobar"])", &options);
+ this->CheckUnary("ascii_rpad", R"([null, "a", "bb", "bar", "foobar"])", this->type(),
+ R"([null, "a ", "bb ", "bar ", "foobar"])", &options);
+
+ PadOptions options_bad{/*width=*/3, /*padding=*/"spam"};
+ auto input = ArrayFromJSON(this->type(), R"(["foo"])");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("Padding must be one byte"),
+ CallFunction("ascii_lpad", {input}, &options_bad));
+ options_bad.padding = "";
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("Padding must be one byte"),
+ CallFunction("ascii_lpad", {input}, &options_bad));
+}
+
+TYPED_TEST(TestStringKernels, TrimWhitespaceAscii) {
+ // \xe2\x80\x88 is punctuation space
+ this->CheckUnary("ascii_trim_whitespace",
+ "[\" \\tfoo\", null, \"bar \", \" \xe2\x80\x88 foo bar \"]",
+ this->type(), "[\"foo\", null, \"bar\", \"\xe2\x80\x88 foo bar\"]");
+ this->CheckUnary("ascii_rtrim_whitespace",
+ "[\" \\tfoo\", null, \"bar \", \" \xe2\x80\x88 foo bar \"]",
+ this->type(),
+ "[\" \\tfoo\", null, \"bar\", \" \xe2\x80\x88 foo bar\"]");
+ this->CheckUnary("ascii_ltrim_whitespace",
+ "[\" \\tfoo\", null, \"bar \", \" \xe2\x80\x88 foo bar \"]",
+ this->type(), "[\"foo\", null, \"bar \", \"\xe2\x80\x88 foo bar \"]");
+}
+
+TYPED_TEST(TestStringKernels, TrimAscii) {
+ TrimOptions options{"BA"};
+ this->CheckUnary("ascii_trim", "[\"BBfooBAB\", null, \"barBAB\", \"BABfooBABbarA\"]",
+ this->type(), "[\"foo\", null, \"bar\", \"fooBABbar\"]", &options);
+ this->CheckUnary("ascii_ltrim", "[\"BBfooBAB\", null, \"barBAB\", \"BABfooBABbarA\"]",
+ this->type(), "[\"fooBAB\", null, \"barBAB\", \"fooBABbarA\"]",
+ &options);
+ this->CheckUnary("ascii_rtrim", "[\"BBfooBAB\", null, \"barBAB\", \"BABfooBABbarA\"]",
+ this->type(), "[\"BBfoo\", null, \"bar\", \"BABfooBABbar\"]",
+ &options);
+}
+
+#ifdef ARROW_WITH_UTF8PROC
+TEST(TestStringKernels, UnicodeLibraryAssumptions) {
+ uint8_t output[4];
+ for (utf8proc_int32_t codepoint = 0x100; codepoint < 0x110000; codepoint++) {
+ utf8proc_ssize_t encoded_nbytes = utf8proc_encode_char(codepoint, output);
+ utf8proc_int32_t codepoint_upper = utf8proc_toupper(codepoint);
+ utf8proc_ssize_t encoded_nbytes_upper = utf8proc_encode_char(codepoint_upper, output);
+ // validate that upper casing will only lead to a byte length growth of max 3/2
+ if (encoded_nbytes == 2) {
+ EXPECT_LE(encoded_nbytes_upper, 3)
+ << "Expected the upper case codepoint for a 2 byte encoded codepoint to be "
+ "encoded in maximum 3 bytes, not "
+ << encoded_nbytes_upper;
+ }
+ utf8proc_int32_t codepoint_lower = utf8proc_tolower(codepoint);
+ utf8proc_ssize_t encoded_nbytes_lower = utf8proc_encode_char(codepoint_lower, output);
+ // validate that lower casing will only lead to a byte length growth of max 3/2
+ if (encoded_nbytes == 2) {
+ EXPECT_LE(encoded_nbytes_lower, 3)
+ << "Expected the lower case codepoint for a 2 byte encoded codepoint to be "
+ "encoded in maximum 3 bytes, not "
+ << encoded_nbytes_lower;
+ }
+ }
+}
+#endif
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_temporal_binary.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_temporal_binary.cc
new file mode 100644
index 000000000..e73c89857
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_temporal_binary.cc
@@ -0,0 +1,542 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cmath>
+#include <initializer_list>
+#include <sstream>
+
+#include "arrow/builder.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/temporal_internal.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/time.h"
+#include "arrow/vendored/datetime.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+namespace internal {
+
+namespace {
+
+using arrow_vendored::date::days;
+using arrow_vendored::date::floor;
+using arrow_vendored::date::hh_mm_ss;
+using arrow_vendored::date::local_days;
+using arrow_vendored::date::local_time;
+using arrow_vendored::date::locate_zone;
+using arrow_vendored::date::sys_days;
+using arrow_vendored::date::sys_time;
+using arrow_vendored::date::time_zone;
+using arrow_vendored::date::trunc;
+using arrow_vendored::date::weekday;
+using arrow_vendored::date::weeks;
+using arrow_vendored::date::year_month_day;
+using arrow_vendored::date::year_month_weekday;
+using arrow_vendored::date::years;
+using arrow_vendored::date::zoned_time;
+using arrow_vendored::date::literals::dec;
+using arrow_vendored::date::literals::jan;
+using arrow_vendored::date::literals::last;
+using arrow_vendored::date::literals::mon;
+using arrow_vendored::date::literals::sun;
+using arrow_vendored::date::literals::thu;
+using arrow_vendored::date::literals::wed;
+using internal::applicator::ScalarBinaryNotNullStatefulEqualTypes;
+
+using DayOfWeekState = OptionsWrapper<DayOfWeekOptions>;
+using WeekState = OptionsWrapper<WeekOptions>;
+
+Status CheckTimezones(const ExecBatch& batch) {
+ const auto& timezone = GetInputTimezone(batch.values[0]);
+ for (int i = 1; i < batch.num_values(); i++) {
+ const auto& other_timezone = GetInputTimezone(batch.values[i]);
+ if (other_timezone != timezone) {
+ return Status::TypeError("Got differing time zone '", other_timezone,
+ "' for argument ", i + 1, "; expected '", timezone, "'");
+ }
+ }
+ return Status::OK();
+}
+
+template <template <typename...> class Op, typename Duration, typename InType,
+ typename OutType>
+struct TemporalBinary {
+ template <typename OptionsType, typename T = InType>
+ static enable_if_timestamp<T, Status> ExecWithOptions(KernelContext* ctx,
+ const OptionsType* options,
+ const ExecBatch& batch,
+ Datum* out) {
+ RETURN_NOT_OK(CheckTimezones(batch));
+
+ const auto& timezone = GetInputTimezone(batch.values[0]);
+ if (timezone.empty()) {
+ using ExecTemplate = Op<Duration, NonZonedLocalizer>;
+ auto op = ExecTemplate(options, NonZonedLocalizer());
+ applicator::ScalarBinaryNotNullStatefulEqualTypes<OutType, T, ExecTemplate> kernel{
+ op};
+ return kernel.Exec(ctx, batch, out);
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto tz, LocateZone(timezone));
+ using ExecTemplate = Op<Duration, ZonedLocalizer>;
+ auto op = ExecTemplate(options, ZonedLocalizer{tz});
+ applicator::ScalarBinaryNotNullStatefulEqualTypes<OutType, T, ExecTemplate> kernel{
+ op};
+ return kernel.Exec(ctx, batch, out);
+ }
+ }
+
+ template <typename OptionsType, typename T = InType>
+ static enable_if_t<!is_timestamp_type<T>::value, Status> ExecWithOptions(
+ KernelContext* ctx, const OptionsType* options, const ExecBatch& batch,
+ Datum* out) {
+ using ExecTemplate = Op<Duration, NonZonedLocalizer>;
+ auto op = ExecTemplate(options, NonZonedLocalizer());
+ applicator::ScalarBinaryNotNullStatefulEqualTypes<OutType, T, ExecTemplate> kernel{
+ op};
+ return kernel.Exec(ctx, batch, out);
+ }
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const FunctionOptions* options = nullptr;
+ return ExecWithOptions(ctx, options, batch, out);
+ }
+};
+
+template <template <typename...> class Op, typename Duration, typename InType,
+ typename OutType>
+struct TemporalDayOfWeekBinary : public TemporalBinary<Op, Duration, InType, OutType> {
+ using Base = TemporalBinary<Op, Duration, InType, OutType>;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const DayOfWeekOptions& options = DayOfWeekState::Get(ctx);
+ RETURN_NOT_OK(ValidateDayOfWeekOptions(options));
+ return Base::ExecWithOptions(ctx, &options, batch, out);
+ }
+};
+
+// ----------------------------------------------------------------------
+// Compute boundary crossings between two timestamps
+
+template <typename Duration, typename Localizer>
+struct YearsBetween {
+ YearsBetween(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ template <typename T, typename Arg0, typename Arg1>
+ T Call(KernelContext*, Arg0 arg0, Arg1 arg1, Status*) const {
+ year_month_day from(
+ floor<days>(localizer_.template ConvertTimePoint<Duration>(arg0)));
+ year_month_day to(floor<days>(localizer_.template ConvertTimePoint<Duration>(arg1)));
+ return static_cast<T>((to.year() - from.year()).count());
+ }
+
+ Localizer localizer_;
+};
+
+template <typename Duration, typename Localizer>
+struct QuartersBetween {
+ QuartersBetween(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ static int64_t GetQuarters(const year_month_day& ymd) {
+ return static_cast<int64_t>(static_cast<int32_t>(ymd.year())) * 4 + GetQuarter(ymd);
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ T Call(KernelContext*, Arg0 arg0, Arg1 arg1, Status*) const {
+ year_month_day from_ymd(
+ floor<days>(localizer_.template ConvertTimePoint<Duration>(arg0)));
+ year_month_day to_ymd(
+ floor<days>(localizer_.template ConvertTimePoint<Duration>(arg1)));
+ int64_t from_quarters = GetQuarters(from_ymd);
+ int64_t to_quarters = GetQuarters(to_ymd);
+ return static_cast<T>(to_quarters - from_quarters);
+ }
+
+ Localizer localizer_;
+};
+
+template <typename Duration, typename Localizer>
+struct MonthsBetween {
+ MonthsBetween(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ template <typename T, typename Arg0, typename Arg1>
+ T Call(KernelContext*, Arg0 arg0, Arg1 arg1, Status*) const {
+ year_month_day from(
+ floor<days>(localizer_.template ConvertTimePoint<Duration>(arg0)));
+ year_month_day to(floor<days>(localizer_.template ConvertTimePoint<Duration>(arg1)));
+ return static_cast<T>((to.year() / to.month() - from.year() / from.month()).count());
+ }
+
+ Localizer localizer_;
+};
+
+template <typename Duration, typename Localizer>
+struct WeeksBetween {
+ using days_t = typename Localizer::days_t;
+
+ WeeksBetween(const DayOfWeekOptions* options, Localizer&& localizer)
+ : week_start_(options->week_start), localizer_(std::move(localizer)) {}
+
+ /// Adjust the day backwards to land on the start of the week.
+ days_t ToWeekStart(days_t point) const {
+ const weekday dow(point);
+ const weekday start_of_week(week_start_);
+ if (dow == start_of_week) return point;
+ const days delta = start_of_week - dow;
+ // delta is always positive and in [0, 6]
+ return point - days(7 - delta.count());
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ T Call(KernelContext*, Arg0 arg0, Arg1 arg1, Status*) const {
+ auto from =
+ ToWeekStart(floor<days>(localizer_.template ConvertTimePoint<Duration>(arg0)));
+ auto to =
+ ToWeekStart(floor<days>(localizer_.template ConvertTimePoint<Duration>(arg1)));
+ return (to - from).count() / 7;
+ }
+
+ uint32_t week_start_;
+ Localizer localizer_;
+};
+
+template <typename Duration, typename Localizer>
+struct MonthDayNanoBetween {
+ MonthDayNanoBetween(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ template <typename T, typename Arg0, typename Arg1>
+ T Call(KernelContext*, Arg0 arg0, Arg1 arg1, Status*) const {
+ static_assert(std::is_same<T, MonthDayNanoIntervalType::MonthDayNanos>::value, "");
+ auto from = localizer_.template ConvertTimePoint<Duration>(arg0);
+ auto to = localizer_.template ConvertTimePoint<Duration>(arg1);
+ year_month_day from_ymd(floor<days>(from));
+ year_month_day to_ymd(floor<days>(to));
+ const int32_t num_months = static_cast<int32_t>(
+ (to_ymd.year() / to_ymd.month() - from_ymd.year() / from_ymd.month()).count());
+ const int32_t num_days = static_cast<int32_t>(static_cast<uint32_t>(to_ymd.day())) -
+ static_cast<int32_t>(static_cast<uint32_t>(from_ymd.day()));
+ auto from_time = static_cast<int64_t>(
+ std::chrono::duration_cast<std::chrono::nanoseconds>(from - floor<days>(from))
+ .count());
+ auto to_time = static_cast<int64_t>(
+ std::chrono::duration_cast<std::chrono::nanoseconds>(to - floor<days>(to))
+ .count());
+ const int64_t num_nanos = to_time - from_time;
+ return T{num_months, num_days, num_nanos};
+ }
+
+ Localizer localizer_;
+};
+
+template <typename Duration, typename Localizer>
+struct DayTimeBetween {
+ DayTimeBetween(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ template <typename T, typename Arg0, typename Arg1>
+ T Call(KernelContext*, Arg0 arg0, Arg1 arg1, Status*) const {
+ static_assert(std::is_same<T, DayTimeIntervalType::DayMilliseconds>::value, "");
+ auto from = localizer_.template ConvertTimePoint<Duration>(arg0);
+ auto to = localizer_.template ConvertTimePoint<Duration>(arg1);
+ const int32_t num_days =
+ static_cast<int32_t>((floor<days>(to) - floor<days>(from)).count());
+ auto from_time = static_cast<int32_t>(
+ std::chrono::duration_cast<std::chrono::milliseconds>(from - floor<days>(from))
+ .count());
+ auto to_time = static_cast<int32_t>(
+ std::chrono::duration_cast<std::chrono::milliseconds>(to - floor<days>(to))
+ .count());
+ const int32_t num_millis = to_time - from_time;
+ return DayTimeIntervalType::DayMilliseconds{num_days, num_millis};
+ }
+
+ Localizer localizer_;
+};
+
+template <typename Unit, typename Duration, typename Localizer>
+struct UnitsBetween {
+ UnitsBetween(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ template <typename T, typename Arg0, typename Arg1>
+ T Call(KernelContext*, Arg0 arg0, Arg1 arg1, Status*) const {
+ auto from = floor<Unit>(localizer_.template ConvertTimePoint<Duration>(arg0));
+ auto to = floor<Unit>(localizer_.template ConvertTimePoint<Duration>(arg1));
+ return static_cast<T>((to - from).count());
+ }
+
+ Localizer localizer_;
+};
+
+template <typename Duration, typename Localizer>
+using DaysBetween = UnitsBetween<days, Duration, Localizer>;
+
+template <typename Duration, typename Localizer>
+using HoursBetween = UnitsBetween<std::chrono::hours, Duration, Localizer>;
+
+template <typename Duration, typename Localizer>
+using MinutesBetween = UnitsBetween<std::chrono::minutes, Duration, Localizer>;
+
+template <typename Duration, typename Localizer>
+using SecondsBetween = UnitsBetween<std::chrono::seconds, Duration, Localizer>;
+
+template <typename Duration, typename Localizer>
+using MillisecondsBetween = UnitsBetween<std::chrono::milliseconds, Duration, Localizer>;
+
+template <typename Duration, typename Localizer>
+using MicrosecondsBetween = UnitsBetween<std::chrono::microseconds, Duration, Localizer>;
+
+template <typename Duration, typename Localizer>
+using NanosecondsBetween = UnitsBetween<std::chrono::nanoseconds, Duration, Localizer>;
+
+// ----------------------------------------------------------------------
+// Registration helpers
+
+template <template <typename...> class Op,
+ template <template <typename...> class OpExec, typename Duration,
+ typename InType, typename OutType, typename... Args>
+ class ExecTemplate,
+ typename OutType>
+struct BinaryTemporalFactory {
+ OutputType out_type;
+ KernelInit init;
+ std::shared_ptr<ScalarFunction> func;
+
+ template <typename... WithTypes>
+ static std::shared_ptr<ScalarFunction> Make(
+ std::string name, OutputType out_type, const FunctionDoc* doc,
+ const FunctionOptions* default_options = NULLPTR, KernelInit init = NULLPTR) {
+ DCHECK_NE(sizeof...(WithTypes), 0);
+ BinaryTemporalFactory self{
+ out_type, init,
+ std::make_shared<ScalarFunction>(name, Arity::Binary(), doc, default_options)};
+ AddTemporalKernels(&self, WithTypes{}...);
+ return self.func;
+ }
+
+ template <typename Duration, typename InType>
+ void AddKernel(InputType in_type) {
+ auto exec = ExecTemplate<Op, Duration, InType, OutType>::Exec;
+ DCHECK_OK(func->AddKernel({in_type, in_type}, out_type, std::move(exec), init));
+ }
+};
+
+const FunctionDoc years_between_doc{
+ "Compute the number of years between two timestamps",
+ ("Returns the number of year boundaries crossed from `start` to `end`.\n"
+ "That is, the difference is calculated as if the timestamps were\n"
+ "truncated to the year.\n"
+ "Null values emit null."),
+ {"start", "end"}};
+
+const FunctionDoc quarters_between_doc{
+ "Compute the number of quarters between two timestamps",
+ ("Returns the number of quarter start boundaries crossed from `start` to `end`.\n"
+ "That is, the difference is calculated as if the timestamps were\n"
+ "truncated to the quarter.\n"
+ "Null values emit null."),
+ {"start", "end"}};
+
+const FunctionDoc months_between_doc{
+ "Compute the number of months between two timestamps",
+ ("Returns the number of month boundaries crossed from `start` to `end`.\n"
+ "That is, the difference is calculated as if the timestamps were\n"
+ "truncated to the month.\n"
+ "Null values emit null."),
+ {"start", "end"}};
+
+const FunctionDoc month_day_nano_interval_between_doc{
+ "Compute the number of months, days and nanoseconds between two timestamps",
+ ("Returns the number of months, days, and nanoseconds from `start` to `end`.\n"
+ "That is, first the difference in months is computed as if both timestamps\n"
+ "were truncated to the months, then the difference between the days\n"
+ "is computed, and finally the difference between the times of the two\n"
+ "timestamps is computed as if both times were truncated to the nanosecond.\n"
+ "Null values return null."),
+ {"start", "end"}};
+
+const FunctionDoc weeks_between_doc{
+ "Compute the number of weeks between two timestamps",
+ ("Returns the number of week boundaries crossed from `start` to `end`.\n"
+ "That is, the difference is calculated as if the timestamps were\n"
+ "truncated to the week.\n"
+ "Null values emit null."),
+ {"start", "end"},
+ "DayOfWeekOptions"};
+
+const FunctionDoc day_time_interval_between_doc{
+ "Compute the number of days and milliseconds between two timestamps",
+ ("Returns the number of days and milliseconds from `start` to `end`.\n"
+ "That is, first the difference in days is computed as if both\n"
+ "timestamps were truncated to the day, then the difference between time times\n"
+ "of the two timestamps is computed as if both times were truncated to the\n"
+ "millisecond.\n"
+ "Null values return null."),
+ {"start", "end"}};
+
+const FunctionDoc days_between_doc{
+ "Compute the number of days between two timestamps",
+ ("Returns the number of day boundaries crossed from `start` to `end`.\n"
+ "That is, the difference is calculated as if the timestamps were\n"
+ "truncated to the day.\n"
+ "Null values emit null."),
+ {"start", "end"}};
+
+const FunctionDoc hours_between_doc{
+ "Compute the number of hours between two timestamps",
+ ("Returns the number of hour boundaries crossed from `start` to `end`.\n"
+ "That is, the difference is calculated as if the timestamps were\n"
+ "truncated to the hour.\n"
+ "Null values emit null."),
+ {"start", "end"}};
+
+const FunctionDoc minutes_between_doc{
+ "Compute the number of minute boundaries between two timestamps",
+ ("Returns the number of minute boundaries crossed from `start` to `end`.\n"
+ "That is, the difference is calculated as if the timestamps were\n"
+ "truncated to the minute.\n"
+ "Null values emit null."),
+ {"start", "end"}};
+
+const FunctionDoc seconds_between_doc{
+ "Compute the number of seconds between two timestamps",
+ ("Returns the number of second boundaries crossed from `start` to `end`.\n"
+ "That is, the difference is calculated as if the timestamps were\n"
+ "truncated to the second.\n"
+ "Null values emit null."),
+ {"start", "end"}};
+
+const FunctionDoc milliseconds_between_doc{
+ "Compute the number of millisecond boundaries between two timestamps",
+ ("Returns the number of millisecond boundaries crossed from `start` to `end`.\n"
+ "That is, the difference is calculated as if the timestamps were\n"
+ "truncated to the millisecond.\n"
+ "Null values emit null."),
+ {"start", "end"}};
+
+const FunctionDoc microseconds_between_doc{
+ "Compute the number of microseconds between two timestamps",
+ ("Returns the number of microsecond boundaries crossed from `start` to `end`.\n"
+ "That is, the difference is calculated as if the timestamps were\n"
+ "truncated to the microsecond.\n"
+ "Null values emit null."),
+ {"start", "end"}};
+
+const FunctionDoc nanoseconds_between_doc{
+ "Compute the number of nanoseconds between two timestamps",
+ ("Returns the number of nanosecond boundaries crossed from `start` to `end`.\n"
+ "That is, the difference is calculated as if the timestamps were\n"
+ "truncated to the nanosecond.\n"
+ "Null values emit null."),
+ {"start", "end"}};
+
+} // namespace
+
+void RegisterScalarTemporalBinary(FunctionRegistry* registry) {
+ // Temporal difference functions
+ auto years_between =
+ BinaryTemporalFactory<YearsBetween, TemporalBinary, Int64Type>::Make<
+ WithDates, WithTimestamps>("years_between", int64(), &years_between_doc);
+ DCHECK_OK(registry->AddFunction(std::move(years_between)));
+
+ auto quarters_between =
+ BinaryTemporalFactory<QuartersBetween, TemporalBinary, Int64Type>::Make<
+ WithDates, WithTimestamps>("quarters_between", int64(), &quarters_between_doc);
+ DCHECK_OK(registry->AddFunction(std::move(quarters_between)));
+
+ auto month_interval_between =
+ BinaryTemporalFactory<MonthsBetween, TemporalBinary, MonthIntervalType>::Make<
+ WithDates, WithTimestamps>("month_interval_between", month_interval(),
+ &months_between_doc);
+ DCHECK_OK(registry->AddFunction(std::move(month_interval_between)));
+
+ auto month_day_nano_interval_between =
+ BinaryTemporalFactory<MonthDayNanoBetween, TemporalBinary,
+ MonthDayNanoIntervalType>::Make<WithDates, WithTimes,
+ WithTimestamps>(
+ "month_day_nano_interval_between", month_day_nano_interval(),
+ &month_day_nano_interval_between_doc);
+ DCHECK_OK(registry->AddFunction(std::move(month_day_nano_interval_between)));
+
+ static const auto default_day_of_week_options = DayOfWeekOptions::Defaults();
+ auto weeks_between =
+ BinaryTemporalFactory<WeeksBetween, TemporalDayOfWeekBinary, Int64Type>::Make<
+ WithDates, WithTimestamps>("weeks_between", int64(), &weeks_between_doc,
+ &default_day_of_week_options, DayOfWeekState::Init);
+ DCHECK_OK(registry->AddFunction(std::move(weeks_between)));
+
+ auto day_time_interval_between =
+ BinaryTemporalFactory<DayTimeBetween, TemporalBinary, DayTimeIntervalType>::Make<
+ WithDates, WithTimes, WithTimestamps>("day_time_interval_between",
+ day_time_interval(),
+ &day_time_interval_between_doc);
+ DCHECK_OK(registry->AddFunction(std::move(day_time_interval_between)));
+
+ auto days_between =
+ BinaryTemporalFactory<DaysBetween, TemporalBinary, Int64Type>::Make<WithDates,
+ WithTimestamps>(
+ "days_between", int64(), &days_between_doc);
+ DCHECK_OK(registry->AddFunction(std::move(days_between)));
+
+ auto hours_between =
+ BinaryTemporalFactory<HoursBetween, TemporalBinary, Int64Type>::Make<
+ WithDates, WithTimes, WithTimestamps>("hours_between", int64(),
+ &hours_between_doc);
+ DCHECK_OK(registry->AddFunction(std::move(hours_between)));
+
+ auto minutes_between =
+ BinaryTemporalFactory<MinutesBetween, TemporalBinary, Int64Type>::Make<
+ WithDates, WithTimes, WithTimestamps>("minutes_between", int64(),
+ &minutes_between_doc);
+ DCHECK_OK(registry->AddFunction(std::move(minutes_between)));
+
+ auto seconds_between =
+ BinaryTemporalFactory<SecondsBetween, TemporalBinary, Int64Type>::Make<
+ WithDates, WithTimes, WithTimestamps>("seconds_between", int64(),
+ &seconds_between_doc);
+ DCHECK_OK(registry->AddFunction(std::move(seconds_between)));
+
+ auto milliseconds_between =
+ BinaryTemporalFactory<MillisecondsBetween, TemporalBinary, Int64Type>::Make<
+ WithDates, WithTimes, WithTimestamps>("milliseconds_between", int64(),
+ &milliseconds_between_doc);
+ DCHECK_OK(registry->AddFunction(std::move(milliseconds_between)));
+
+ auto microseconds_between =
+ BinaryTemporalFactory<MicrosecondsBetween, TemporalBinary, Int64Type>::Make<
+ WithDates, WithTimes, WithTimestamps>("microseconds_between", int64(),
+ &microseconds_between_doc);
+ DCHECK_OK(registry->AddFunction(std::move(microseconds_between)));
+
+ auto nanoseconds_between =
+ BinaryTemporalFactory<NanosecondsBetween, TemporalBinary, Int64Type>::Make<
+ WithDates, WithTimes, WithTimestamps>("nanoseconds_between", int64(),
+ &nanoseconds_between_doc);
+ DCHECK_OK(registry->AddFunction(std::move(nanoseconds_between)));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
new file mode 100644
index 000000000..92133136b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
@@ -0,0 +1,1330 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <tuple>
+
+#include <gtest/gtest.h>
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/formatting.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::StringFormatter;
+
+namespace compute {
+
+class ScalarTemporalTest : public ::testing::Test {
+ public:
+ const char* date32s =
+ R"([0, 11016, -25932, 23148, 18262, 18261, 18260, 14609, 14610, 14612,
+ 14613, 13149, 13148, 14241, 14242, 15340, null])";
+ const char* date32s2 =
+ R"([365, 10650, -25901, 23118, 18263, 18259, 18260, 14609, 14610, 14612,
+ 14613, 13149, 13148, 14240, 13937, 15400, null])";
+ const char* date64s =
+ R"([0, 951782400000, -2240524800000, 1999987200000, 1577836800000,
+ 1577750400000, 1577664000000, 1262217600000, 1262304000000, 1262476800000,
+ 1262563200000, 1136073600000, 1135987200000, 1230422400000, 1230508800000,
+ 1325376000000, null])";
+ const char* date64s2 =
+ R"([31536000000, 920160000000, -2237846400000, 1997395200000,
+ 1577923200000, 1577577600000, 1577664000000, 1262217600000, 1262304000000,
+ 1262476800000, 1262563200000, 1136073600000, 1135987200000, 1230336000000,
+ 1204156800000, 1330560000000, null])";
+ const char* times_s =
+ R"([59, 84203, 3560, 12800, 3905, 7810, 11715, 15620, 19525, 23430, 27335,
+ 31240, 35145, 0, 0, 3723, null])";
+ const char* times_s2 =
+ R"([59, 84203, 12642, 7182, 68705, 7390, 915, 16820, 19525, 5430, 84959,
+ 31207, 35145, 0, 0, 3723, null])";
+ const char* times_ms =
+ R"([59123, 84203999, 3560001, 12800000, 3905001, 7810002, 11715003, 15620004,
+ 19525005, 23430006, 27335000, 31240000, 35145000, 0, 0, 3723000, null])";
+ const char* times_ms2 =
+ R"([59103, 84203999, 12642001, 7182000, 68705005, 7390000, 915003, 16820004,
+ 19525005, 5430006, 84959000, 31207000, 35145000, 0, 0, 3723000, null])";
+ const char* times_us =
+ R"([59123456, 84203999999, 3560001001, 12800000000, 3905001000, 7810002000,
+ 11715003000, 15620004132, 19525005321, 23430006163, 27335000000,
+ 31240000000, 35145000000, 0, 0, 3723000000, null])";
+ const char* times_us2 =
+ R"([59103476, 84203999999, 12642001001, 7182000000, 68705005000, 7390000000,
+ 915003000, 16820004432, 19525005021, 5430006163, 84959000000,
+ 31207000000, 35145000000, 0, 0, 3723000000, null])";
+ const char* times_ns =
+ R"([59123456789, 84203999999999, 3560001001001, 12800000000000, 3905001000000,
+ 7810002000000, 11715003000000, 15620004132000, 19525005321000,
+ 23430006163000, 27335000000000, 31240000000000, 35145000000000, 0, 0,
+ 3723000000000, null])";
+ const char* times_ns2 =
+ R"([59103476799, 84203999999909, 12642001001001, 7182000000000, 68705005000000,
+ 7390000000000, 915003000000, 16820004432000, 19525005021000, 5430006163000,
+ 84959000000000, 31207000000000, 35145000000000, 0, 0, 3723000000000, null])";
+ const char* times =
+ R"(["1970-01-01T00:00:59.123456789","2000-02-29T23:23:23.999999999",
+ "1899-01-01T00:59:20.001001001","2033-05-18T03:33:20.000000000",
+ "2020-01-01T01:05:05.001", "2019-12-31T02:10:10.002",
+ "2019-12-30T03:15:15.003", "2009-12-31T04:20:20.004132",
+ "2010-01-01T05:25:25.005321", "2010-01-03T06:30:30.006163",
+ "2010-01-04T07:35:35", "2006-01-01T08:40:40", "2005-12-31T09:45:45",
+ "2008-12-28", "2008-12-29", "2012-01-01 01:02:03", null])";
+ const char* times2 =
+ R"(["1970-01-01T00:00:59.103476799","2000-02-29T23:23:23.999999909",
+ "1899-01-01T03:30:42.001001001","2033-05-18T01:59:42.000000000",
+ "2020-01-01T19:05:05.005", "2019-12-31T02:03:10.000",
+ "2019-12-30T00:15:15.003", "2009-12-31T04:40:20.004432",
+ "2010-01-01T05:25:25.005021", "2010-01-03T01:30:30.006163",
+ "2010-01-04T23:35:59", "2006-01-01T08:40:07",
+ "2005-12-31T09:45:45", "2008-12-28", "2008-12-29",
+ "2012-01-01 01:02:03", null])";
+ const char* times_seconds_precision =
+ R"(["1970-01-01T00:00:59","2000-02-29T23:23:23",
+ "1899-01-01T00:59:20","2033-05-18T03:33:20",
+ "2020-01-01T01:05:05", "2019-12-31T02:10:10",
+ "2019-12-30T03:15:15", "2009-12-31T04:20:20",
+ "2010-01-01T05:25:25", "2010-01-03T06:30:30",
+ "2010-01-04T07:35:35", "2006-01-01T08:40:40",
+ "2005-12-31T09:45:45", "2008-12-28", "2008-12-29",
+ "2012-01-01 01:02:03", null])";
+ const char* times_seconds_precision2 =
+ R"(["1971-01-01T00:00:59","1999-02-28T23:23:23",
+ "1899-02-01T00:59:20","2033-04-18T03:33:20",
+ "2020-01-02T01:05:05", "2019-12-29T02:10:10",
+ "2019-12-30T04:15:15", "2009-12-31T03:20:20",
+ "2010-01-01T05:26:25", "2010-01-03T06:29:30",
+ "2010-01-04T07:35:36", "2006-01-01T08:40:39",
+ "2005-12-31T09:45:45", "2008-12-27T23:59:59",
+ "2008-02-28", "2012-03-01", null])";
+ std::shared_ptr<arrow::DataType> iso_calendar_type =
+ struct_({field("iso_year", int64()), field("iso_week", int64()),
+ field("iso_day_of_week", int64())});
+ std::shared_ptr<arrow::Array> iso_calendar =
+ ArrayFromJSON(iso_calendar_type,
+ R"([{"iso_year": 1970, "iso_week": 1, "iso_day_of_week": 4},
+ {"iso_year": 2000, "iso_week": 9, "iso_day_of_week": 2},
+ {"iso_year": 1898, "iso_week": 52, "iso_day_of_week": 7},
+ {"iso_year": 2033, "iso_week": 20, "iso_day_of_week": 3},
+ {"iso_year": 2020, "iso_week": 1, "iso_day_of_week": 3},
+ {"iso_year": 2020, "iso_week": 1, "iso_day_of_week": 2},
+ {"iso_year": 2020, "iso_week": 1, "iso_day_of_week": 1},
+ {"iso_year": 2009, "iso_week": 53, "iso_day_of_week": 4},
+ {"iso_year": 2009, "iso_week": 53, "iso_day_of_week": 5},
+ {"iso_year": 2009, "iso_week": 53, "iso_day_of_week": 7},
+ {"iso_year": 2010, "iso_week": 1, "iso_day_of_week": 1},
+ {"iso_year": 2005, "iso_week": 52, "iso_day_of_week": 7},
+ {"iso_year": 2005, "iso_week": 52, "iso_day_of_week": 6},
+ {"iso_year": 2008, "iso_week": 52, "iso_day_of_week": 7},
+ {"iso_year": 2009, "iso_week": 1, "iso_day_of_week": 1},
+ {"iso_year": 2011, "iso_week": 52, "iso_day_of_week": 7}, null])");
+ std::string year =
+ "[1970, 2000, 1899, 2033, 2020, 2019, 2019, 2009, 2010, 2010, 2010, 2006, "
+ "2005, 2008, 2008, 2012, null]";
+ std::string month = "[1, 2, 1, 5, 1, 12, 12, 12, 1, 1, 1, 1, 12, 12, 12, 1, null]";
+ std::string day = "[1, 29, 1, 18, 1, 31, 30, 31, 1, 3, 4, 1, 31, 28, 29, 1, null]";
+ std::string day_of_week = "[3, 1, 6, 2, 2, 1, 0, 3, 4, 6, 0, 6, 5, 6, 0, 6, null]";
+ std::string day_of_year =
+ "[1, 60, 1, 138, 1, 365, 364, 365, 1, 3, 4, 1, 365, 363, 364, 1, null]";
+ std::string iso_year =
+ "[1970, 2000, 1898, 2033, 2020, 2020, 2020, 2009, 2009, 2009, 2010, 2005, "
+ "2005, 2008, 2009, 2011, null]";
+ std::string iso_week =
+ "[1, 9, 52, 20, 1, 1, 1, 53, 53, 53, 1, 52, 52, 52, 1, 52, null]";
+ std::string us_week = "[53, 9, 1, 20, 1, 1, 1, 52, 52, 1, 1, 1, 52, 53, 53, 1, null]";
+ std::string week = "[1, 9, 52, 20, 1, 1, 1, 53, 53, 53, 1, 52, 52, 52, 1, 52, null]";
+
+ std::string quarter = "[1, 1, 1, 2, 1, 4, 4, 4, 1, 1, 1, 1, 4, 4, 4, 1, null]";
+ std::string hour = "[0, 23, 0, 3, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 1, null]";
+ std::string minute =
+ "[0, 23, 59, 33, 5, 10, 15, 20, 25, 30, 35, 40, 45, 0, 0, 2, null]";
+ std::string second =
+ "[59, 23, 20, 20, 5, 10, 15, 20, 25, 30, 35, 40, 45, 0, 0, 3, null]";
+ std::string millisecond = "[123, 999, 1, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0, null]";
+ std::string microsecond =
+ "[456, 999, 1, 0, 0, 0, 0, 132, 321, 163, 0, 0, 0, 0, 0, 0, null]";
+ std::string nanosecond = "[789, 999, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, null]";
+ std::string subsecond =
+ "[0.123456789, 0.999999999, 0.001001001, 0, 0.001, 0.002, 0.003, 0.004132, "
+ "0.005321, 0.006163, 0, 0, 0, 0, 0, 0, null]";
+ std::string subsecond_ms =
+ "[0.123, 0.999, 0.001, 0, 0.001, 0.002, 0.003, 0.004, 0.005, 0.006, 0, 0, 0, "
+ "0, 0, 0, null]";
+ std::string subsecond_us =
+ "[0.123456, 0.999999, 0.001001, 0, 0.001, 0.002, 0.003, 0.004132, 0.005321, "
+ "0.006163, 0, 0, 0, 0, 0, 0, null]";
+ std::string zeros = "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, null]";
+ std::string years_between = "[1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, null]";
+ std::string years_between_tz =
+ "[1, -1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, null]";
+ std::string quarters_between =
+ "[4, -4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, 0, null]";
+ std::string quarters_between_tz =
+ "[4, -4, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, 1, null]";
+ std::string months_between =
+ "[12, -12, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10, 2, null]";
+ std::string months_between_tz =
+ "[12, -12, 1, -1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10, 2, null]";
+ std::string month_day_nano_interval_between_zeros =
+ "[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], "
+ "[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], "
+ "[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], null]";
+ std::string month_day_nano_interval_between =
+ "[[12, 0, 0], [-12, -1, 0], [1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -2, 0], "
+ "[0, 0, 3600000000000], [0, 0, -3600000000000], "
+ "[0, 0, 60000000000], [0, 0, -60000000000], "
+ "[0, 0, 1000000000], [0, 0, -1000000000], "
+ "[0, 0, 0], [0, -1, 86399000000000], [-10, -1, 0], [2, 0, -3723000000000], null]";
+ std::string month_day_nano_interval_between_tz =
+ "[[12, 0, 0], [-12, -1, 0], [1, 0, 0], [-1, 0, 0], [1, -30, 0], [0, -2, 0], "
+ "[0, 0, 3600000000000], [0, 0, -3600000000000], "
+ "[0, 0, 60000000000], [0, 0, -60000000000], "
+ "[0, 0, 1000000000], [0, 0, -1000000000], "
+ "[0, 0, 0], [0, 0, -1000000000], [-10, -1, 0], [2, -2, -3723000000000], null]";
+ std::string month_day_nano_interval_between_date =
+ "[[12, 0, 0], [-12, -1, 0], [1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -2, 0], "
+ "[0, 0, 0], [0, 0, 0], "
+ "[0, 0, 0], [0, 0, 0], "
+ "[0, 0, 0], [0, 0, 0], "
+ "[0, 0, 0], [0, -1, 0], [-10, -1, 0], [2, 0, 0], null]";
+ std::string month_day_nano_interval_between_time =
+ "[[0, 0, -19979990], [0, 0, -90], [0, 0, 9082000000000], [0, 0, -5618000000000], "
+ "[0, 0, 64800004000000], [0, 0, -420002000000], [0, 0, -10800000000000], "
+ "[0, 0, 1200000300000], [0, 0, -300000], [0, 0, -18000000000000], [0, 0, "
+ "57624000000000], "
+ "[0, 0, -33000000000], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], null]";
+ std::string month_day_nano_interval_between_time_s =
+ "[[0, 0, 0], [0, 0, 0], [0, 0, 9082000000000], [0, 0, -5618000000000], "
+ "[0, 0, 64800000000000], [0, 0, -420000000000], [0, 0, -10800000000000], "
+ "[0, 0, 1200000000000], [0, 0, 0], [0, 0, -18000000000000], [0, 0, "
+ "57624000000000], "
+ "[0, 0, -33000000000], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], null]";
+ std::string month_day_nano_interval_between_time_ms =
+ "[[0, 0, -20000000], [0, 0, 0], [0, 0, 9082000000000], [0, 0, -5618000000000], "
+ "[0, 0, 64800004000000], [0, 0, -420002000000], [0, 0, -10800000000000], "
+ "[0, 0, 1200000000000], [0, 0, 0], [0, 0, -18000000000000], [0, 0, "
+ "57624000000000], "
+ "[0, 0, -33000000000], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], null]";
+ std::string month_day_nano_interval_between_time_us =
+ "[[0, 0, -19980000], [0, 0, 0], [0, 0, 9082000000000], [0, 0, -5618000000000], "
+ "[0, 0, 64800004000000], [0, 0, -420002000000], [0, 0, -10800000000000], "
+ "[0, 0, 1200000300000], [0, 0, -300000], [0, 0, -18000000000000], [0, 0, "
+ "57624000000000], "
+ "[0, 0, -33000000000], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], null]";
+ std::string day_time_interval_between_zeros =
+ "[[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], "
+ "[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], null]";
+ std::string day_time_interval_between =
+ "[[365, 0], [-366, 0], [31, 0], [-30, 0], [1, 0], [-2, 0], [0, 3600000], "
+ "[0, -3600000], [0, 60000], [0, -60000], [0, 1000], [0, -1000], [0, 0], "
+ "[-1, 86399000], [-305, 0], [60, -3723000], null]";
+ std::string day_time_interval_between_date =
+ "[[365, 0], [-366, 0], [31, 0], [-30, 0], [1, 0], [-2, 0], [0, 0], "
+ "[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], "
+ "[-1, 0], [-305, 0], [60, 0], null]";
+ std::string day_time_interval_between_tz =
+ "[[365, 0], [-366, 0], [31, 0], [-30, 0], [1, 0], [-2, 0], [0, 3600000], "
+ "[0, -3600000], [0, 60000], [0, -60000], [0, 1000], [0, -1000], [0, 0], "
+ "[0, -1000], [-305, 0], [60, -3723000], null]";
+ std::string day_time_interval_between_time =
+ "[[0, -20], [0, 0], [0, 9082000], [0, -5618000], [0, 64800004], [0, -420002], "
+ "[0, -10800000], [0, 1200000], [0, 0], [0, -18000000], [0, 57624000], "
+ "[0, -33000], [0, 0], [0, 0], [0, 0], [0, 0], null]";
+ std::string day_time_interval_between_time_s =
+ "[[0, 0], [0, 0], [0, 9082000], [0, -5618000], [0, 64800000], [0, -420000], "
+ "[0, -10800000], [0, 1200000], [0, 0], [0, -18000000], [0, 57624000], "
+ "[0, -33000], [0, 0], [0, 0], [0, 0], [0, 0], null]";
+ std::string weeks_between =
+ "[52, -53, 5, -4, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, -44, 9, null]";
+ std::string weeks_between_tz =
+ "[52, -53, 5, -5, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, -43, 9, null]";
+ std::string days_between =
+ "[365, -366, 31, -30, 1, -2, 0, 0, 0, 0, 0, 0, 0, -1, -305, 60, null]";
+ std::string days_between_tz =
+ "[365, -366, 31, -30, 1, -2, 0, 0, 0, 0, 0, 0, 0, 0, -305, 60, null]";
+ std::string hours_between =
+ "[8760, -8784, 744, -720, 24, -48, 1, -1, 0, 0, 0, 0, 0, -1, -7320, 1439, null]";
+ std::string hours_between_date =
+ "[8760, -8784, 744, -720, 24, -48, 0, 0, 0, 0, 0, 0, 0, -24, -7320, 1440, null]";
+ std::string hours_between_tz =
+ "[8760, -8784, 744, -720, 24, -48, 1, -1, 0, -1, 0, 0, 0, 0, -7320, 1439, null]";
+ std::string hours_between_time =
+ "[0, 0, 3, -2, 18, 0, -3, 0, 0, -5, 16, 0, 0, 0, 0, 0, null]";
+ std::string minutes_between =
+ "[525600, -527040, 44640, -43200, 1440, -2880, 60, -60, 1, -1, 0, 0, 0, -1, "
+ "-439200, 86338, null]";
+ std::string minutes_between_date =
+ "[525600, -527040, 44640, -43200, 1440, -2880, 0, 0, 0, 0, 0, 0, 0, -1440, "
+ "-439200, 86400, null]";
+ std::string minutes_between_time =
+ "[0, 0, 151, -94, 1080, -7, -180, 20, 0, -300, 960, 0, 0, 0, 0, 0, null]";
+ std::string seconds_between =
+ "[31536000, -31622400, 2678400, -2592000, 86400, -172800, 3600, -3600, 60, -60, 1, "
+ "-1, 0, -1, -26352000, 5180277, null]";
+ std::string seconds_between_date =
+ "[31536000, -31622400, 2678400, -2592000, 86400, -172800, 0, 0, 0, 0, 0, "
+ "0, 0, -86400, -26352000, 5184000, null]";
+ std::string seconds_between_time =
+ "[0, 0, 9082, -5618, 64800, -420, -10800, 1200, 0, -18000, 57624, -33, 0, 0, 0, 0, "
+ "null]";
+ std::string milliseconds_between =
+ "[31536000000, -31622400000, 2678400000, -2592000000, 86400000, -172800000, "
+ "3600000, -3600000, 60000, -60000, 1000, -1000, 0, -1000, -26352000000, "
+ "5180277000, null]";
+ std::string milliseconds_between_date =
+ "[31536000000, -31622400000, 2678400000, -2592000000, 86400000, -172800000, "
+ "0, 0, 0, 0, 0, 0, 0, -86400000, -26352000000, 5184000000, null]";
+ std::string milliseconds_between_time =
+ "[-20, 0, 9082000, -5618000, 64800004, -420002, -10800000, 1200000, 0, "
+ "-18000000, 57624000, -33000, 0, 0, 0, 0, null]";
+ std::string milliseconds_between_time_s =
+ "[0, 0, 9082000, -5618000, 64800000, -420000, -10800000, 1200000, 0, "
+ "-18000000, 57624000, -33000, 0, 0, 0, 0, null]";
+ std::string microseconds_between =
+ "[31536000000000, -31622400000000, 2678400000000, -2592000000000, 86400000000, "
+ "-172800000000, 3600000000, -3600000000, 60000000, -60000000, 1000000, -1000000, "
+ "0, -1000000, -26352000000000, 5180277000000, null]";
+ std::string microseconds_between_date =
+ "[31536000000000, -31622400000000, 2678400000000, -2592000000000, 86400000000, "
+ "-172800000000, 0, 0, 0, 0, 0, 0, 0, -86400000000, -26352000000000, 5184000000000, "
+ "null]";
+ std::string microseconds_between_time =
+ "[-19980, 0, 9082000000, -5618000000, 64800004000, -420002000, -10800000000, "
+ "1200000300, -300, -18000000000, 57624000000, -33000000, 0, 0, 0, 0, null]";
+ std::string microseconds_between_time_s =
+ "[0, 0, 9082000000, -5618000000, 64800000000, -420000000, -10800000000, "
+ "1200000000, 0, -18000000000, 57624000000, -33000000, 0, 0, 0, 0, null]";
+ std::string microseconds_between_time_ms =
+ "[-20000, 0, 9082000000, -5618000000, 64800004000, -420002000, -10800000000, "
+ "1200000000, 0, -18000000000, 57624000000, -33000000, 0, 0, 0, 0, null]";
+ std::string nanoseconds_between =
+ "[31536000000000000, -31622400000000000, 2678400000000000, -2592000000000000, "
+ "86400000000000, -172800000000000, 3600000000000, -3600000000000, 60000000000, "
+ "-60000000000, 1000000000, -1000000000, 0, -1000000000, -26352000000000000, "
+ "5180277000000000, null]";
+ std::string nanoseconds_between_date =
+ "[31536000000000000, -31622400000000000, 2678400000000000, -2592000000000000, "
+ "86400000000000, -172800000000000, 0, 0, 0, 0, 0, 0, 0, -86400000000000, "
+ "-26352000000000000, 5184000000000000, null]";
+ std::string nanoseconds_between_time =
+ "[-19979990, -90, 9082000000000, -5618000000000, 64800004000000, -420002000000, "
+ "-10800000000000, 1200000300000, -300000, -18000000000000, 57624000000000, "
+ "-33000000000, 0, 0, 0, 0, null]";
+ std::string nanoseconds_between_time_s =
+ "[0, 0, 9082000000000, -5618000000000, 64800000000000, -420000000000, "
+ "-10800000000000, 1200000000000, 0, -18000000000000, 57624000000000, "
+ "-33000000000, 0, 0, 0, 0, null]";
+ std::string nanoseconds_between_time_ms =
+ "[-20000000, 0, 9082000000000, -5618000000000, 64800004000000, -420002000000, "
+ "-10800000000000, 1200000000000, 0, -18000000000000, 57624000000000, "
+ "-33000000000, 0, 0, 0, 0, null]";
+ std::string nanoseconds_between_time_us =
+ "[-19980000, 0, 9082000000000, -5618000000000, 64800004000000, -420002000000, "
+ "-10800000000000, 1200000300000, -300000, -18000000000000, 57624000000000, "
+ "-33000000000, 0, 0, 0, 0, null]";
+};
+
+TEST_F(ScalarTemporalTest, TestTemporalComponentExtractionAllTemporalTypes) {
+ std::vector<std::shared_ptr<DataType>> units = {date32(), date64(),
+ timestamp(TimeUnit::NANO)};
+ std::vector<const char*> samples = {date32s, date64s, times};
+ DCHECK_EQ(units.size(), samples.size());
+ for (size_t i = 0; i < samples.size(); ++i) {
+ auto unit = units[i];
+ auto sample = samples[i];
+ CheckScalarUnary("year", unit, sample, int64(), year);
+ CheckScalarUnary("month", unit, sample, int64(), month);
+ CheckScalarUnary("day", unit, sample, int64(), day);
+ CheckScalarUnary("day_of_week", unit, sample, int64(), day_of_week);
+ CheckScalarUnary("day_of_year", unit, sample, int64(), day_of_year);
+ CheckScalarUnary("iso_year", unit, sample, int64(), iso_year);
+ CheckScalarUnary("iso_week", unit, sample, int64(), iso_week);
+ CheckScalarUnary("us_week", unit, sample, int64(), us_week);
+ CheckScalarUnary("iso_calendar", ArrayFromJSON(unit, sample), iso_calendar);
+ CheckScalarUnary("quarter", unit, sample, int64(), quarter);
+ if (unit->id() == Type::TIMESTAMP) {
+ CheckScalarUnary("hour", unit, sample, int64(), hour);
+ CheckScalarUnary("minute", unit, sample, int64(), minute);
+ CheckScalarUnary("second", unit, sample, int64(), second);
+ CheckScalarUnary("millisecond", unit, sample, int64(), millisecond);
+ CheckScalarUnary("microsecond", unit, sample, int64(), microsecond);
+ CheckScalarUnary("nanosecond", unit, sample, int64(), nanosecond);
+ CheckScalarUnary("subsecond", unit, sample, float64(), subsecond);
+ }
+ }
+
+ CheckScalarUnary("hour", time32(TimeUnit::SECOND), times_s, int64(), hour);
+ CheckScalarUnary("minute", time32(TimeUnit::SECOND), times_s, int64(), minute);
+ CheckScalarUnary("second", time32(TimeUnit::SECOND), times_s, int64(), second);
+ CheckScalarUnary("millisecond", time32(TimeUnit::SECOND), times_s, int64(), zeros);
+ CheckScalarUnary("microsecond", time32(TimeUnit::SECOND), times_s, int64(), zeros);
+ CheckScalarUnary("nanosecond", time32(TimeUnit::SECOND), times_s, int64(), zeros);
+ CheckScalarUnary("subsecond", time32(TimeUnit::SECOND), times_s, float64(), zeros);
+
+ CheckScalarUnary("hour", time32(TimeUnit::MILLI), times_ms, int64(), hour);
+ CheckScalarUnary("minute", time32(TimeUnit::MILLI), times_ms, int64(), minute);
+ CheckScalarUnary("second", time32(TimeUnit::MILLI), times_ms, int64(), second);
+ CheckScalarUnary("millisecond", time32(TimeUnit::MILLI), times_ms, int64(),
+ millisecond);
+ CheckScalarUnary("microsecond", time32(TimeUnit::MILLI), times_ms, int64(), zeros);
+ CheckScalarUnary("nanosecond", time32(TimeUnit::MILLI), times_ms, int64(), zeros);
+ CheckScalarUnary("subsecond", time32(TimeUnit::MILLI), times_ms, float64(),
+ subsecond_ms);
+
+ CheckScalarUnary("hour", time64(TimeUnit::MICRO), times_us, int64(), hour);
+ CheckScalarUnary("minute", time64(TimeUnit::MICRO), times_us, int64(), minute);
+ CheckScalarUnary("second", time64(TimeUnit::MICRO), times_us, int64(), second);
+ CheckScalarUnary("millisecond", time64(TimeUnit::MICRO), times_us, int64(),
+ millisecond);
+ CheckScalarUnary("microsecond", time64(TimeUnit::MICRO), times_us, int64(),
+ microsecond);
+ CheckScalarUnary("nanosecond", time64(TimeUnit::MICRO), times_us, int64(), zeros);
+ CheckScalarUnary("subsecond", time64(TimeUnit::MICRO), times_us, float64(),
+ subsecond_us);
+
+ CheckScalarUnary("hour", time64(TimeUnit::NANO), times_ns, int64(), hour);
+ CheckScalarUnary("minute", time64(TimeUnit::NANO), times_ns, int64(), minute);
+ CheckScalarUnary("second", time64(TimeUnit::NANO), times_ns, int64(), second);
+ CheckScalarUnary("millisecond", time64(TimeUnit::NANO), times_ns, int64(), millisecond);
+ CheckScalarUnary("microsecond", time64(TimeUnit::NANO), times_ns, int64(), microsecond);
+ CheckScalarUnary("nanosecond", time64(TimeUnit::NANO), times_ns, int64(), nanosecond);
+ CheckScalarUnary("subsecond", time64(TimeUnit::NANO), times_ns, float64(), subsecond);
+}
+
+TEST_F(ScalarTemporalTest, TestTemporalComponentExtractionWithDifferentUnits) {
+ for (auto u : TimeUnit::values()) {
+ auto unit = timestamp(u);
+ CheckScalarUnary("year", unit, times_seconds_precision, int64(), year);
+ CheckScalarUnary("month", unit, times_seconds_precision, int64(), month);
+ CheckScalarUnary("day", unit, times_seconds_precision, int64(), day);
+ CheckScalarUnary("day_of_week", unit, times_seconds_precision, int64(), day_of_week);
+ CheckScalarUnary("day_of_year", unit, times_seconds_precision, int64(), day_of_year);
+ CheckScalarUnary("iso_year", unit, times_seconds_precision, int64(), iso_year);
+ CheckScalarUnary("iso_week", unit, times_seconds_precision, int64(), iso_week);
+ CheckScalarUnary("us_week", unit, times_seconds_precision, int64(), us_week);
+ CheckScalarUnary("week", unit, times_seconds_precision, int64(), week);
+ CheckScalarUnary("iso_calendar", ArrayFromJSON(unit, times_seconds_precision),
+ iso_calendar);
+ CheckScalarUnary("quarter", unit, times_seconds_precision, int64(), quarter);
+ CheckScalarUnary("hour", unit, times_seconds_precision, int64(), hour);
+ CheckScalarUnary("minute", unit, times_seconds_precision, int64(), minute);
+ CheckScalarUnary("second", unit, times_seconds_precision, int64(), second);
+ CheckScalarUnary("millisecond", unit, times_seconds_precision, int64(), zeros);
+ CheckScalarUnary("microsecond", unit, times_seconds_precision, int64(), zeros);
+ CheckScalarUnary("nanosecond", unit, times_seconds_precision, int64(), zeros);
+ CheckScalarUnary("subsecond", unit, times_seconds_precision, float64(), zeros);
+ }
+}
+
+TEST_F(ScalarTemporalTest, TestOutsideNanosecondRange) {
+ const char* times = R"(["1677-09-20T00:00:59.123456", "2262-04-13T23:23:23.999999"])";
+ auto unit = timestamp(TimeUnit::MICRO);
+ auto year = "[1677, 2262]";
+ auto month = "[9, 4]";
+ auto day = "[20, 13]";
+ auto day_of_week = "[0, 6]";
+ auto day_of_year = "[263, 103]";
+ auto iso_year = "[1677, 2262]";
+ auto iso_week = "[38, 15]";
+ auto us_week = "[38, 16]";
+ auto week = "[38, 15]";
+ auto iso_calendar =
+ ArrayFromJSON(iso_calendar_type,
+ R"([{"iso_year": 1677, "iso_week": 38, "iso_day_of_week": 1},
+ {"iso_year": 2262, "iso_week": 15, "iso_day_of_week": 7}])");
+ auto quarter = "[3, 2]";
+ auto hour = "[0, 23]";
+ auto minute = "[0, 23]";
+ auto second = "[59, 23]";
+ auto millisecond = "[123, 999]";
+ auto microsecond = "[456, 999]";
+ auto nanosecond = "[0, 0]";
+ auto subsecond = "[0.123456, 0.999999]";
+
+ CheckScalarUnary("year", unit, times, int64(), year);
+ CheckScalarUnary("month", unit, times, int64(), month);
+ CheckScalarUnary("day", unit, times, int64(), day);
+ CheckScalarUnary("day_of_week", unit, times, int64(), day_of_week);
+ CheckScalarUnary("day_of_year", unit, times, int64(), day_of_year);
+ CheckScalarUnary("iso_year", unit, times, int64(), iso_year);
+ CheckScalarUnary("iso_week", unit, times, int64(), iso_week);
+ CheckScalarUnary("us_week", unit, times, int64(), us_week);
+ CheckScalarUnary("week", unit, times, int64(), week);
+ CheckScalarUnary("iso_calendar", ArrayFromJSON(unit, times), iso_calendar);
+ CheckScalarUnary("quarter", unit, times, int64(), quarter);
+ CheckScalarUnary("hour", unit, times, int64(), hour);
+ CheckScalarUnary("minute", unit, times, int64(), minute);
+ CheckScalarUnary("second", unit, times, int64(), second);
+ CheckScalarUnary("millisecond", unit, times, int64(), millisecond);
+ CheckScalarUnary("microsecond", unit, times, int64(), microsecond);
+ CheckScalarUnary("nanosecond", unit, times, int64(), nanosecond);
+ CheckScalarUnary("subsecond", unit, times, float64(), subsecond);
+}
+
+#ifndef _WIN32
+// TODO: We should test on windows once ARROW-13168 is resolved.
+TEST_F(ScalarTemporalTest, TestZoned1) {
+ auto unit = timestamp(TimeUnit::NANO, "Pacific/Marquesas");
+ auto year =
+ "[1969, 2000, 1898, 2033, 2019, 2019, 2019, 2009, 2009, 2010, 2010, 2005, 2005, "
+ "2008, 2008, 2011, null]";
+ auto month = "[12, 2, 12, 5, 12, 12, 12, 12, 12, 1, 1, 12, 12, 12, 12, 12, null]";
+ auto day = "[31, 29, 31, 17, 31, 30, 29, 30, 31, 2, 3, 31, 31, 27, 28, 31, null]";
+ auto day_of_week = "[2, 1, 5, 1, 1, 0, 6, 2, 3, 5, 6, 5, 5, 5, 6, 5, null]";
+ auto day_of_year =
+ "[365, 60, 365, 137, 365, 364, 363, 364, 365, 2, 3, 365, 365, 362, 363, 365, null]";
+ auto iso_year =
+ "[1970, 2000, 1898, 2033, 2020, 2020, 2019, 2009, 2009, 2009, 2009, 2005, 2005, "
+ "2008, 2008, 2011, null]";
+ auto iso_week = "[1, 9, 52, 20, 1, 1, 52, 53, 53, 53, 53, 52, 52, 52, 52, 52, null]";
+ auto us_week = "[53, 9, 52, 20, 1, 1, 1, 52, 52, 52, 1, 52, 52, 52, 53, 52, null]";
+ auto week = "[1, 9, 52, 20, 1, 1, 52, 53, 53, 53, 53, 52, 52, 52, 52, 52, null]";
+ auto iso_calendar =
+ ArrayFromJSON(iso_calendar_type,
+ R"([{"iso_year": 1970, "iso_week": 1, "iso_day_of_week": 3},
+ {"iso_year": 2000, "iso_week": 9, "iso_day_of_week": 2},
+ {"iso_year": 1898, "iso_week": 52, "iso_day_of_week": 6},
+ {"iso_year": 2033, "iso_week": 20, "iso_day_of_week": 2},
+ {"iso_year": 2020, "iso_week": 1, "iso_day_of_week": 2},
+ {"iso_year": 2020, "iso_week": 1, "iso_day_of_week": 1},
+ {"iso_year": 2019, "iso_week": 52, "iso_day_of_week": 7},
+ {"iso_year": 2009, "iso_week": 53, "iso_day_of_week": 3},
+ {"iso_year": 2009, "iso_week": 53, "iso_day_of_week": 4},
+ {"iso_year": 2009, "iso_week": 53, "iso_day_of_week": 6},
+ {"iso_year": 2009, "iso_week": 53, "iso_day_of_week": 7},
+ {"iso_year": 2005, "iso_week": 52, "iso_day_of_week": 6},
+ {"iso_year": 2005, "iso_week": 52, "iso_day_of_week": 6},
+ {"iso_year": 2008, "iso_week": 52, "iso_day_of_week": 6},
+ {"iso_year": 2008, "iso_week": 52, "iso_day_of_week": 7},
+ {"iso_year": 2011, "iso_week": 52, "iso_day_of_week": 6}, null])");
+ auto quarter = "[4, 1, 4, 2, 4, 4, 4, 4, 4, 1, 1, 4, 4, 4, 4, 4, null]";
+ auto hour = "[14, 13, 15, 18, 15, 16, 17, 18, 19, 21, 22, 23, 0, 14, 14, 15, null]";
+ auto minute = "[30, 53, 41, 3, 35, 40, 45, 50, 55, 0, 5, 10, 15, 30, 30, 32, null]";
+
+ CheckScalarUnary("year", unit, times, int64(), year);
+ CheckScalarUnary("month", unit, times, int64(), month);
+ CheckScalarUnary("day", unit, times, int64(), day);
+ CheckScalarUnary("day_of_week", unit, times, int64(), day_of_week);
+ CheckScalarUnary("day_of_year", unit, times, int64(), day_of_year);
+ CheckScalarUnary("iso_year", unit, times, int64(), iso_year);
+ CheckScalarUnary("iso_week", unit, times, int64(), iso_week);
+ CheckScalarUnary("us_week", unit, times, int64(), us_week);
+ CheckScalarUnary("week", unit, times, int64(), week);
+ CheckScalarUnary("iso_calendar", ArrayFromJSON(unit, times), iso_calendar);
+ CheckScalarUnary("quarter", unit, times, int64(), quarter);
+ CheckScalarUnary("hour", unit, times, int64(), hour);
+ CheckScalarUnary("minute", unit, times, int64(), minute);
+ CheckScalarUnary("second", unit, times, int64(), second);
+ CheckScalarUnary("millisecond", unit, times, int64(), millisecond);
+ CheckScalarUnary("microsecond", unit, times, int64(), microsecond);
+ CheckScalarUnary("nanosecond", unit, times, int64(), nanosecond);
+ CheckScalarUnary("subsecond", unit, times, float64(), subsecond);
+}
+
+TEST_F(ScalarTemporalTest, TestZoned2) {
+ for (auto u : TimeUnit::values()) {
+ auto unit = timestamp(u, "Australia/Broken_Hill");
+ auto month = "[1, 3, 1, 5, 1, 12, 12, 12, 1, 1, 1, 1, 12, 12, 12, 1, null]";
+ auto day = "[1, 1, 1, 18, 1, 31, 30, 31, 1, 3, 4, 1, 31, 28, 29, 1, null]";
+ auto day_of_week = "[3, 2, 6, 2, 2, 1, 0, 3, 4, 6, 0, 6, 5, 6, 0, 6, null]";
+ auto day_of_year =
+ "[1, 61, 1, 138, 1, 365, 364, 365, 1, 3, 4, 1, 365, 363, 364, 1, null]";
+ auto iso_year =
+ "[1970, 2000, 1898, 2033, 2020, 2020, 2020, 2009, 2009, 2009, 2010, 2005, 2005, "
+ "2008, 2009, 2011, null]";
+ auto iso_week = "[1, 9, 52, 20, 1, 1, 1, 53, 53, 53, 1, 52, 52, 52, 1, 52, null]";
+ auto us_week = "[53, 9, 1, 20, 1, 1, 1, 52, 52, 1, 1, 1, 52, 53, 53, 1, null]";
+ auto week = "[1, 9, 52, 20, 1, 1, 1, 53, 53, 53, 1, 52, 52, 52, 1, 52, null]";
+ auto iso_calendar =
+ ArrayFromJSON(iso_calendar_type,
+ R"([{"iso_year": 1970, "iso_week": 1, "iso_day_of_week": 4},
+ {"iso_year": 2000, "iso_week": 9, "iso_day_of_week": 3},
+ {"iso_year": 1898, "iso_week": 52, "iso_day_of_week": 7},
+ {"iso_year": 2033, "iso_week": 20, "iso_day_of_week": 3},
+ {"iso_year": 2020, "iso_week": 1, "iso_day_of_week": 3},
+ {"iso_year": 2020, "iso_week": 1, "iso_day_of_week": 2},
+ {"iso_year": 2020, "iso_week": 1, "iso_day_of_week": 1},
+ {"iso_year": 2009, "iso_week": 53, "iso_day_of_week": 4},
+ {"iso_year": 2009, "iso_week": 53, "iso_day_of_week": 5},
+ {"iso_year": 2009, "iso_week": 53, "iso_day_of_week": 7},
+ {"iso_year": 2010, "iso_week": 1, "iso_day_of_week": 1},
+ {"iso_year": 2005, "iso_week": 52, "iso_day_of_week": 7},
+ {"iso_year": 2005, "iso_week": 52, "iso_day_of_week": 6},
+ {"iso_year": 2008, "iso_week": 52, "iso_day_of_week": 7},
+ {"iso_year": 2009, "iso_week": 1, "iso_day_of_week": 1},
+ {"iso_year": 2011, "iso_week": 52, "iso_day_of_week": 7}, null])");
+ auto quarter = "[1, 1, 1, 2, 1, 4, 4, 4, 1, 1, 1, 1, 4, 4, 4, 1, null]";
+ auto hour = "[9, 9, 9, 13, 11, 12, 13, 14, 15, 17, 18, 19, 20, 10, 10, 11, null]";
+ auto minute = "[30, 53, 59, 3, 35, 40, 45, 50, 55, 0, 5, 10, 15, 30, 30, 32, null]";
+
+ CheckScalarUnary("year", unit, times_seconds_precision, int64(), year);
+ CheckScalarUnary("month", unit, times_seconds_precision, int64(), month);
+ CheckScalarUnary("day", unit, times_seconds_precision, int64(), day);
+ CheckScalarUnary("day_of_week", unit, times_seconds_precision, int64(), day_of_week);
+ CheckScalarUnary("day_of_year", unit, times_seconds_precision, int64(), day_of_year);
+ CheckScalarUnary("iso_year", unit, times_seconds_precision, int64(), iso_year);
+ CheckScalarUnary("iso_week", unit, times_seconds_precision, int64(), iso_week);
+ CheckScalarUnary("us_week", unit, times_seconds_precision, int64(), us_week);
+ CheckScalarUnary("week", unit, times_seconds_precision, int64(), week);
+ CheckScalarUnary("iso_calendar", ArrayFromJSON(unit, times_seconds_precision),
+ iso_calendar);
+ CheckScalarUnary("quarter", unit, times_seconds_precision, int64(), quarter);
+ CheckScalarUnary("hour", unit, times_seconds_precision, int64(), hour);
+ CheckScalarUnary("minute", unit, times_seconds_precision, int64(), minute);
+ CheckScalarUnary("second", unit, times_seconds_precision, int64(), second);
+ CheckScalarUnary("millisecond", unit, times_seconds_precision, int64(), zeros);
+ CheckScalarUnary("microsecond", unit, times_seconds_precision, int64(), zeros);
+ CheckScalarUnary("nanosecond", unit, times_seconds_precision, int64(), zeros);
+ CheckScalarUnary("subsecond", unit, times_seconds_precision, float64(), zeros);
+ }
+}
+
+TEST_F(ScalarTemporalTest, TestNonexistentTimezone) {
+ auto data_buffer = Buffer::Wrap(std::vector<int32_t>{1, 2, 3});
+ auto null_buffer = Buffer::FromString("\xff");
+
+ for (auto u : TimeUnit::values()) {
+ auto ts_type = timestamp(u, "Mars/Mariner_Valley");
+ auto timestamp_array = std::make_shared<NumericArray<TimestampType>>(
+ ts_type, 2, data_buffer, null_buffer, 0);
+ ASSERT_RAISES(Invalid, Year(timestamp_array));
+ ASSERT_RAISES(Invalid, Month(timestamp_array));
+ ASSERT_RAISES(Invalid, Day(timestamp_array));
+ ASSERT_RAISES(Invalid, DayOfWeek(timestamp_array));
+ ASSERT_RAISES(Invalid, DayOfYear(timestamp_array));
+ ASSERT_RAISES(Invalid, ISOYear(timestamp_array));
+ ASSERT_RAISES(Invalid, Week(timestamp_array));
+ ASSERT_RAISES(Invalid, ISOCalendar(timestamp_array));
+ ASSERT_RAISES(Invalid, Quarter(timestamp_array));
+ ASSERT_RAISES(Invalid, Hour(timestamp_array));
+ ASSERT_RAISES(Invalid, Minute(timestamp_array));
+ ASSERT_RAISES(Invalid, Second(timestamp_array));
+ ASSERT_RAISES(Invalid, Millisecond(timestamp_array));
+ ASSERT_RAISES(Invalid, Microsecond(timestamp_array));
+ ASSERT_RAISES(Invalid, Nanosecond(timestamp_array));
+ ASSERT_RAISES(Invalid, Subsecond(timestamp_array));
+ }
+}
+#endif
+
+TEST_F(ScalarTemporalTest, Week) {
+ auto unit = timestamp(TimeUnit::NANO);
+ std::string week_100 =
+ "[1, 9, 52, 20, 1, 1, 1, 53, 53, 53, 1, 52, 52, 52, 1, 52, null]";
+ std::string week_110 = "[1, 9, 0, 20, 1, 53, 53, 53, 0, 0, 1, 0, 52, 52, 53, 0, null]";
+ std::string week_010 = "[0, 9, 1, 20, 1, 53, 53, 52, 0, 1, 1, 1, 52, 53, 53, 1, null]";
+ std::string week_000 = "[53, 9, 1, 20, 1, 1, 1, 52, 52, 1, 1, 1, 52, 53, 53, 1, null]";
+ std::string week_111 = "[0, 9, 0, 20, 0, 52, 52, 52, 0, 0, 1, 0, 52, 51, 52, 0, null]";
+ std::string week_011 = "[0, 9, 1, 20, 0, 52, 52, 52, 0, 1, 1, 1, 52, 52, 52, 1, null]";
+ std::string week_101 =
+ "[52, 9, 52, 20, 52, 52, 52, 52, 52, 52, 1, 52, 52, 51, 52, 52, null]";
+ std::string week_001 =
+ "[52, 9, 1, 20, 52, 52, 52, 52, 52, 1, 1, 1, 52, 52, 52, 1, null]";
+
+ auto options_100 = WeekOptions(/*week_starts_monday*/ true, /*count_from_zero=*/false,
+ /*first_week_is_fully_in_year=*/false);
+ auto options_110 = WeekOptions(/*week_starts_monday*/ true, /*count_from_zero=*/true,
+ /*first_week_is_fully_in_year=*/false);
+ auto options_010 = WeekOptions(/*week_starts_monday*/ false, /*count_from_zero=*/true,
+ /*first_week_is_fully_in_year=*/false);
+ auto options_000 = WeekOptions(/*week_starts_monday*/ false, /*count_from_zero=*/false,
+ /*first_week_is_fully_in_year=*/false);
+ auto options_111 = WeekOptions(/*week_starts_monday*/ true, /*count_from_zero=*/true,
+ /*first_week_is_fully_in_year=*/true);
+ auto options_011 = WeekOptions(/*week_starts_monday*/ false, /*count_from_zero=*/true,
+ /*first_week_is_fully_in_year=*/true);
+ auto options_101 = WeekOptions(/*week_starts_monday*/ true, /*count_from_zero=*/false,
+ /*first_week_is_fully_in_year=*/true);
+ auto options_001 = WeekOptions(/*week_starts_monday*/ false, /*count_from_zero=*/false,
+ /*first_week_is_fully_in_year=*/true);
+
+ CheckScalarUnary("iso_week", unit, times, int64(), week_100);
+ CheckScalarUnary("us_week", unit, times, int64(), week_000);
+ CheckScalarUnary("week", unit, times, int64(), week_100, &options_100);
+ CheckScalarUnary("week", unit, times, int64(), week_110, &options_110);
+ CheckScalarUnary("week", unit, times, int64(), week_010, &options_010);
+ CheckScalarUnary("week", unit, times, int64(), week_000, &options_000);
+ CheckScalarUnary("week", unit, times, int64(), week_111, &options_111);
+ CheckScalarUnary("week", unit, times, int64(), week_011, &options_011);
+ CheckScalarUnary("week", unit, times, int64(), week_101, &options_101);
+ CheckScalarUnary("week", unit, times, int64(), week_001, &options_001);
+}
+
+TEST_F(ScalarTemporalTest, DayOfWeek) {
+ auto unit = timestamp(TimeUnit::NANO);
+
+ auto timestamps = ArrayFromJSON(unit, times);
+ auto day_of_week_week_start_7_zero_based =
+ "[4, 2, 0, 3, 3, 2, 1, 4, 5, 0, 1, 0, 6, 0, 1, 0, null]";
+ auto day_of_week_week_start_2_zero_based =
+ "[2, 0, 5, 1, 1, 0, 6, 2, 3, 5, 6, 5, 4, 5, 6, 5, null]";
+ auto day_of_week_week_start_7_one_based =
+ "[5, 3, 1, 4, 4, 3, 2, 5, 6, 1, 2, 1, 7, 1, 2, 1, null]";
+ auto day_of_week_week_start_2_one_based =
+ "[3, 1, 6, 2, 2, 1, 7, 3, 4, 6, 7, 6, 5, 6, 7, 6, null]";
+
+ auto expected_70 = ArrayFromJSON(int64(), day_of_week_week_start_7_zero_based);
+ ASSERT_OK_AND_ASSIGN(
+ Datum result_70,
+ DayOfWeek(timestamps, DayOfWeekOptions(
+ /*count_from_zero=*/true, /*week_start=*/7)));
+ ASSERT_TRUE(result_70.Equals(expected_70));
+
+ auto expected_20 = ArrayFromJSON(int64(), day_of_week_week_start_2_zero_based);
+ ASSERT_OK_AND_ASSIGN(
+ Datum result_20,
+ DayOfWeek(timestamps, DayOfWeekOptions(
+ /*count_from_zero=*/true, /*week_start=*/2)));
+ ASSERT_TRUE(result_20.Equals(expected_20));
+
+ auto expected_71 = ArrayFromJSON(int64(), day_of_week_week_start_7_one_based);
+ ASSERT_OK_AND_ASSIGN(
+ Datum result_71,
+ DayOfWeek(timestamps, DayOfWeekOptions(
+ /*count_from_zero=*/false, /*week_start=*/7)));
+ ASSERT_TRUE(result_71.Equals(expected_71));
+
+ auto expected_21 = ArrayFromJSON(int64(), day_of_week_week_start_2_one_based);
+ ASSERT_OK_AND_ASSIGN(
+ Datum result_21,
+ DayOfWeek(timestamps, DayOfWeekOptions(
+ /*count_from_zero=*/false, /*week_start=*/2)));
+ ASSERT_TRUE(result_21.Equals(expected_21));
+
+ ASSERT_RAISES(Invalid, DayOfWeek(timestamps, DayOfWeekOptions(/*count_from_zero=*/false,
+ /*week_start=*/0)));
+ ASSERT_RAISES(Invalid, DayOfWeek(timestamps, DayOfWeekOptions(/*count_from_zero=*/true,
+ /*week_start=*/8)));
+}
+
+TEST_F(ScalarTemporalTest, TestTemporalDifference) {
+ for (auto u : TimeUnit::values()) {
+ auto unit = timestamp(u);
+ auto arr1 = ArrayFromJSON(unit, times_seconds_precision);
+ auto arr2 = ArrayFromJSON(unit, times_seconds_precision2);
+ CheckScalarBinary("years_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("years_between", arr1, arr2, ArrayFromJSON(int64(), years_between));
+ CheckScalarBinary("quarters_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("quarters_between", arr1, arr2,
+ ArrayFromJSON(int64(), quarters_between));
+ CheckScalarBinary("month_interval_between", arr1, arr1,
+ ArrayFromJSON(month_interval(), zeros));
+ CheckScalarBinary("month_interval_between", arr1, arr1,
+ ArrayFromJSON(month_interval(), zeros));
+ CheckScalarBinary("month_interval_between", arr1, arr2,
+ ArrayFromJSON(month_interval(), months_between));
+ CheckScalarBinary(
+ "month_day_nano_interval_between", arr1, arr1,
+ ArrayFromJSON(month_day_nano_interval(), month_day_nano_interval_between_zeros));
+ CheckScalarBinary(
+ "month_day_nano_interval_between", arr1, arr2,
+ ArrayFromJSON(month_day_nano_interval(), month_day_nano_interval_between));
+ CheckScalarBinary(
+ "day_time_interval_between", arr1, arr1,
+ ArrayFromJSON(day_time_interval(), day_time_interval_between_zeros));
+ CheckScalarBinary("day_time_interval_between", arr1, arr2,
+ ArrayFromJSON(day_time_interval(), day_time_interval_between));
+ CheckScalarBinary("weeks_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("weeks_between", arr1, arr2, ArrayFromJSON(int64(), weeks_between));
+ CheckScalarBinary("days_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("days_between", arr1, arr2, ArrayFromJSON(int64(), days_between));
+ CheckScalarBinary("hours_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("hours_between", arr1, arr2, ArrayFromJSON(int64(), hours_between));
+ CheckScalarBinary("minutes_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("minutes_between", arr1, arr2,
+ ArrayFromJSON(int64(), minutes_between));
+ CheckScalarBinary("seconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("seconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), seconds_between));
+ CheckScalarBinary("milliseconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("milliseconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), milliseconds_between));
+ CheckScalarBinary("microseconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("microseconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), microseconds_between));
+ CheckScalarBinary("nanoseconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("nanoseconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), nanoseconds_between));
+ }
+
+ for (auto date_case : {std::make_tuple(date32(), date32s, date32s2),
+ std::make_tuple(date64(), date64s, date64s2)}) {
+ auto ty = std::get<0>(date_case);
+ auto arr1 = ArrayFromJSON(ty, std::get<1>(date_case));
+ auto arr2 = ArrayFromJSON(ty, std::get<2>(date_case));
+ CheckScalarBinary("years_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("years_between", arr1, arr2, ArrayFromJSON(int64(), years_between));
+ CheckScalarBinary("quarters_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("quarters_between", arr1, arr2,
+ ArrayFromJSON(int64(), quarters_between));
+ CheckScalarBinary("month_interval_between", arr1, arr1,
+ ArrayFromJSON(month_interval(), zeros));
+ CheckScalarBinary("month_interval_between", arr1, arr2,
+ ArrayFromJSON(month_interval(), months_between));
+ CheckScalarBinary(
+ "month_day_nano_interval_between", arr1, arr1,
+ ArrayFromJSON(month_day_nano_interval(), month_day_nano_interval_between_zeros));
+ CheckScalarBinary(
+ "month_day_nano_interval_between", arr1, arr2,
+ ArrayFromJSON(month_day_nano_interval(), month_day_nano_interval_between_date));
+ CheckScalarBinary(
+ "day_time_interval_between", arr1, arr1,
+ ArrayFromJSON(day_time_interval(), day_time_interval_between_zeros));
+ CheckScalarBinary("day_time_interval_between", arr1, arr2,
+ ArrayFromJSON(day_time_interval(), day_time_interval_between_date));
+ CheckScalarBinary("weeks_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("weeks_between", arr1, arr2, ArrayFromJSON(int64(), weeks_between));
+ CheckScalarBinary("days_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("days_between", arr1, arr2, ArrayFromJSON(int64(), days_between));
+ CheckScalarBinary("hours_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("hours_between", arr1, arr2,
+ ArrayFromJSON(int64(), hours_between_date));
+ CheckScalarBinary("minutes_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("minutes_between", arr1, arr2,
+ ArrayFromJSON(int64(), minutes_between_date));
+ CheckScalarBinary("seconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("seconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), seconds_between_date));
+ CheckScalarBinary("milliseconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("milliseconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), milliseconds_between_date));
+ CheckScalarBinary("microseconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("microseconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), microseconds_between_date));
+ CheckScalarBinary("nanoseconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("nanoseconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), nanoseconds_between_date));
+ }
+
+ struct TimeCase {
+ std::shared_ptr<DataType> ty;
+ std::string times1;
+ std::string times2;
+ std::string month_day_nano_interval_between;
+ std::string day_time_interval_between;
+ std::string milliseconds_between;
+ std::string microseconds_between;
+ std::string nanoseconds_between;
+ };
+ std::vector<TimeCase> cases = {
+ {time32(TimeUnit::SECOND), times_s, times_s2,
+ month_day_nano_interval_between_time_s, day_time_interval_between_time_s,
+ milliseconds_between_time_s, microseconds_between_time_s,
+ nanoseconds_between_time_s},
+ {time32(TimeUnit::MILLI), times_ms, times_ms2,
+ month_day_nano_interval_between_time_ms, day_time_interval_between_time,
+ milliseconds_between_time, microseconds_between_time_ms,
+ nanoseconds_between_time_ms},
+ {time64(TimeUnit::MICRO), times_us, times_us2,
+ month_day_nano_interval_between_time_us, day_time_interval_between_time,
+ milliseconds_between_time, microseconds_between_time, nanoseconds_between_time_us},
+ {time64(TimeUnit::NANO), times_ns, times_ns2, month_day_nano_interval_between_time,
+ day_time_interval_between_time, milliseconds_between_time,
+ microseconds_between_time, nanoseconds_between_time},
+ };
+ for (auto time_case : cases) {
+ auto arr1 = ArrayFromJSON(time_case.ty, time_case.times1);
+ auto arr2 = ArrayFromJSON(time_case.ty, time_case.times2);
+ CheckScalarBinary(
+ "month_day_nano_interval_between", arr1, arr1,
+ ArrayFromJSON(month_day_nano_interval(), month_day_nano_interval_between_zeros));
+ CheckScalarBinary("month_day_nano_interval_between", arr1, arr2,
+ ArrayFromJSON(month_day_nano_interval(),
+ time_case.month_day_nano_interval_between));
+ CheckScalarBinary("hours_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("hours_between", arr1, arr2,
+ ArrayFromJSON(int64(), hours_between_time));
+ CheckScalarBinary("minutes_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("minutes_between", arr1, arr2,
+ ArrayFromJSON(int64(), minutes_between_time));
+ CheckScalarBinary("seconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("seconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), seconds_between_time));
+ CheckScalarBinary(
+ "day_time_interval_between", arr1, arr1,
+ ArrayFromJSON(day_time_interval(), day_time_interval_between_zeros));
+ CheckScalarBinary(
+ "day_time_interval_between", arr1, arr2,
+ ArrayFromJSON(day_time_interval(), time_case.day_time_interval_between));
+ CheckScalarBinary("milliseconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("milliseconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), time_case.milliseconds_between));
+ CheckScalarBinary("microseconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("microseconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), time_case.microseconds_between));
+ CheckScalarBinary("nanoseconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("nanoseconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), time_case.nanoseconds_between));
+ }
+}
+
+TEST_F(ScalarTemporalTest, TestTemporalDifferenceWeeks) {
+ auto raw_days = ArrayFromJSON(timestamp(TimeUnit::SECOND), R"([
+ "2021-08-09", "2021-08-10", "2021-08-11", "2021-08-12", "2021-08-13", "2021-08-14", "2021-08-15",
+ "2021-08-16", "2021-08-17", "2021-08-18", "2021-08-19", "2021-08-20", "2021-08-21", "2021-08-22",
+ "2021-08-23", "2021-08-24", "2021-08-25", "2021-08-26", "2021-08-27", "2021-08-28", "2021-08-29"
+ ])");
+ std::vector<std::string> ts_scalars = {R"("2021-08-16")", R"("2021-08-17")",
+ R"("2021-08-18")"};
+ std::vector<std::string> date32_scalars = {"18855", "18856", "18857"};
+ std::vector<std::string> date64_scalars = {"1629072000000", "1629158400000",
+ "1629244800000"};
+
+ for (const auto& test_case : {std::make_pair(timestamp(TimeUnit::SECOND), ts_scalars),
+ std::make_pair(date32(), date32_scalars),
+ std::make_pair(date64(), date64_scalars)}) {
+ auto ty = test_case.first;
+ std::shared_ptr<Array> days;
+ if (ty->id() == Type::TIMESTAMP) {
+ days = raw_days;
+ } else {
+ ASSERT_OK_AND_ASSIGN(auto temp, Cast(raw_days, ty));
+ days = temp.make_array();
+ }
+ auto aug16 = ScalarFromJSON(ty, test_case.second[0]);
+ auto aug17 = ScalarFromJSON(ty, test_case.second[1]);
+ auto aug18 = ScalarFromJSON(ty, test_case.second[2]);
+
+ DayOfWeekOptions options(/*one_based_numbering=*/false, /*week_start=Monday*/ 1);
+ EXPECT_THAT(CallFunction("weeks_between", {aug16, days}, &options),
+ ResultWith(Datum(ArrayFromJSON(int64(), R"([
+-1, -1, -1, -1, -1, -1, -1,
+0, 0, 0, 0, 0, 0, 0,
+1, 1, 1, 1, 1, 1, 1
+])"))));
+ EXPECT_THAT(CallFunction("weeks_between", {aug17, days}, &options),
+ ResultWith(Datum(ArrayFromJSON(int64(), R"([
+-1, -1, -1, -1, -1, -1, -1,
+0, 0, 0, 0, 0, 0, 0,
+1, 1, 1, 1, 1, 1, 1
+])"))));
+
+ options.week_start = 3; // Wednesday
+ EXPECT_THAT(CallFunction("weeks_between", {aug16, days}, &options),
+ ResultWith(Datum(ArrayFromJSON(int64(), R"([
+-1, -1, 0, 0, 0, 0, 0,
+0, 0, 1, 1, 1, 1, 1,
+1, 1, 2, 2, 2, 2, 2
+])"))));
+ EXPECT_THAT(CallFunction("weeks_between", {aug17, days}, &options),
+ ResultWith(Datum(ArrayFromJSON(int64(), R"([
+-1, -1, 0, 0, 0, 0, 0,
+0, 0, 1, 1, 1, 1, 1,
+1, 1, 2, 2, 2, 2, 2
+])"))));
+ EXPECT_THAT(CallFunction("weeks_between", {aug18, days}, &options),
+ ResultWith(Datum(ArrayFromJSON(int64(), R"([
+-2, -2, -1, -1, -1, -1, -1,
+-1, -1, 0, 0, 0, 0, 0,
+0, 0, 1, 1, 1, 1, 1
+])"))));
+ }
+}
+
+TEST_F(ScalarTemporalTest, TestTemporalDifferenceErrors) {
+ Datum arr1 = ArrayFromJSON(timestamp(TimeUnit::SECOND, "America/New_York"),
+ R"(["1970-01-01T00:00:59"])");
+ Datum arr2 = ArrayFromJSON(timestamp(TimeUnit::SECOND, "America/Phoenix"),
+ R"(["1970-01-01T00:00:59"])");
+ Datum arr3 = ArrayFromJSON(timestamp(TimeUnit::SECOND), R"(["1970-01-01T00:00:59"])");
+ Datum arr4 =
+ ArrayFromJSON(timestamp(TimeUnit::SECOND, "UTC"), R"(["1970-01-01T00:00:59"])");
+ for (auto fn :
+ {"years_between", "month_interval_between", "month_day_nano_interval_between",
+ "day_time_interval_between", "weeks_between", "days_between", "hours_between",
+ "minutes_between", "seconds_between", "milliseconds_between",
+ "microseconds_between", "nanoseconds_between"}) {
+ SCOPED_TRACE(fn);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("Got differing time zone 'America/Phoenix' for argument 2; "
+ "expected 'America/New_York'"),
+ CallFunction(fn, {arr1, arr2}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr(
+ "Got differing time zone 'America/Phoenix' for argument 2; expected ''"),
+ CallFunction(fn, {arr3, arr2}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("Got differing time zone 'UTC' for argument 2; expected ''"),
+ CallFunction(fn, {arr3, arr4}));
+ }
+
+ DayOfWeekOptions options;
+ options.week_start = 20;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr("week_start must follow ISO convention (Monday=1, Sunday=7). "
+ "Got week_start=20"),
+ CallFunction("weeks_between", {arr1, arr1}, &options));
+}
+
+// TODO: We should test on windows once ARROW-13168 is resolved.
+#ifndef _WIN32
+TEST_F(ScalarTemporalTest, TestAssumeTimezone) {
+ std::string timezone_utc = "UTC";
+ std::string timezone_kolkata = "Asia/Kolkata";
+ std::string timezone_us_central = "US/Central";
+ const char* times_utc = R"(["1970-01-01T00:00:00", null])";
+ const char* times_kolkata = R"(["1970-01-01T05:30:00", null])";
+ const char* times_us_central = R"(["1969-12-31T18:00:00", null])";
+ auto options_utc = AssumeTimezoneOptions(timezone_utc);
+ auto options_kolkata = AssumeTimezoneOptions(timezone_kolkata);
+ auto options_us_central = AssumeTimezoneOptions(timezone_us_central);
+ auto options_invalid = AssumeTimezoneOptions("Europe/Brusselsss");
+
+ for (auto u : TimeUnit::values()) {
+ auto unit = timestamp(u);
+ auto unit_utc = timestamp(u, timezone_utc);
+ auto unit_kolkata = timestamp(u, timezone_kolkata);
+ auto unit_us_central = timestamp(u, timezone_us_central);
+
+ CheckScalarUnary("assume_timezone", unit, times_utc, unit_utc, times_utc,
+ &options_utc);
+ CheckScalarUnary("assume_timezone", unit, times_kolkata, unit_kolkata, times_utc,
+ &options_kolkata);
+ CheckScalarUnary("assume_timezone", unit, times_us_central, unit_us_central,
+ times_utc, &options_us_central);
+ ASSERT_RAISES(Invalid,
+ AssumeTimezone(ArrayFromJSON(unit_kolkata, times_utc), options_utc));
+ ASSERT_RAISES(Invalid,
+ AssumeTimezone(ArrayFromJSON(unit, times_utc), options_invalid));
+ }
+}
+
+TEST_F(ScalarTemporalTest, TestAssumeTimezoneAmbiguous) {
+ std::string timezone = "CET";
+ const char* times = R"(["2018-10-28 01:20:00",
+ "2018-10-28 02:36:00",
+ "2018-10-28 03:46:00"])";
+ const char* times_earliest = R"(["2018-10-27 23:20:00",
+ "2018-10-28 00:36:00",
+ "2018-10-28 02:46:00"])";
+ const char* times_latest = R"(["2018-10-27 23:20:00",
+ "2018-10-28 01:36:00",
+ "2018-10-28 02:46:00"])";
+
+ auto options_earliest =
+ AssumeTimezoneOptions(timezone, AssumeTimezoneOptions::AMBIGUOUS_EARLIEST);
+ auto options_latest =
+ AssumeTimezoneOptions(timezone, AssumeTimezoneOptions::AMBIGUOUS_LATEST);
+ auto options_raise =
+ AssumeTimezoneOptions(timezone, AssumeTimezoneOptions::AMBIGUOUS_RAISE);
+
+ for (auto u : TimeUnit::values()) {
+ auto unit = timestamp(u);
+ auto unit_local = timestamp(u, timezone);
+ ASSERT_RAISES(Invalid, AssumeTimezone(ArrayFromJSON(unit, times), options_raise));
+ CheckScalarUnary("assume_timezone", unit, times, unit_local, times_earliest,
+ &options_earliest);
+ CheckScalarUnary("assume_timezone", unit, times, unit_local, times_latest,
+ &options_latest);
+ }
+}
+
+TEST_F(ScalarTemporalTest, TestAssumeTimezoneNonexistent) {
+ std::string timezone = "Europe/Warsaw";
+ const char* times = R"(["2015-03-29 02:30:00", "2015-03-29 03:30:00"])";
+ const char* times_latest = R"(["2015-03-29 01:00:00", "2015-03-29 01:30:00"])";
+ const char* times_earliest = R"(["2015-03-29 00:59:59", "2015-03-29 01:30:00"])";
+ const char* times_earliest_milli =
+ R"(["2015-03-29 00:59:59.999", "2015-03-29 01:30:00"])";
+ const char* times_earliest_micro =
+ R"(["2015-03-29 00:59:59.999999", "2015-03-29 01:30:00"])";
+ const char* times_earliest_nano =
+ R"(["2015-03-29 00:59:59.999999999", "2015-03-29 01:30:00"])";
+
+ auto options_raise =
+ AssumeTimezoneOptions(timezone, AssumeTimezoneOptions::AMBIGUOUS_RAISE,
+ AssumeTimezoneOptions::NONEXISTENT_RAISE);
+ auto options_latest =
+ AssumeTimezoneOptions(timezone, AssumeTimezoneOptions::AMBIGUOUS_RAISE,
+ AssumeTimezoneOptions::NONEXISTENT_LATEST);
+ auto options_earliest =
+ AssumeTimezoneOptions(timezone, AssumeTimezoneOptions::AMBIGUOUS_RAISE,
+ AssumeTimezoneOptions::NONEXISTENT_EARLIEST);
+
+ for (auto u : TimeUnit::values()) {
+ auto unit = timestamp(u);
+ auto unit_local = timestamp(u, timezone);
+ ASSERT_RAISES(Invalid, AssumeTimezone(ArrayFromJSON(unit, times), options_raise));
+ CheckScalarUnary("assume_timezone", unit, times, unit_local, times_latest,
+ &options_latest);
+ }
+ CheckScalarUnary("assume_timezone", timestamp(TimeUnit::SECOND), times,
+ timestamp(TimeUnit::SECOND, timezone), times_earliest,
+ &options_earliest);
+ CheckScalarUnary("assume_timezone", timestamp(TimeUnit::MILLI), times,
+ timestamp(TimeUnit::MILLI, timezone), times_earliest_milli,
+ &options_earliest);
+ CheckScalarUnary("assume_timezone", timestamp(TimeUnit::MICRO), times,
+ timestamp(TimeUnit::MICRO, timezone), times_earliest_micro,
+ &options_earliest);
+ CheckScalarUnary("assume_timezone", timestamp(TimeUnit::NANO), times,
+ timestamp(TimeUnit::NANO, timezone), times_earliest_nano,
+ &options_earliest);
+}
+
+TEST_F(ScalarTemporalTest, Strftime) {
+ auto options_default = StrftimeOptions();
+ auto options = StrftimeOptions("%Y-%m-%dT%H:%M:%S%z");
+
+ const char* seconds = R"(["1970-01-01T00:00:59", "2021-08-18T15:11:50", null])";
+ const char* milliseconds = R"(["1970-01-01T00:00:59.123", null])";
+ const char* microseconds = R"(["1970-01-01T00:00:59.123456", null])";
+ const char* nanoseconds = R"(["1970-01-01T00:00:59.123456789", null])";
+
+ const char* default_seconds = R"(
+ ["1970-01-01T00:00:59", "2021-08-18T15:11:50", null])";
+ const char* string_seconds = R"(
+ ["1970-01-01T00:00:59+0000", "2021-08-18T15:11:50+0000", null])";
+ const char* string_milliseconds = R"(["1970-01-01T00:00:59.123+0000", null])";
+ const char* string_microseconds = R"(["1970-01-01T05:30:59.123456+0530", null])";
+ const char* string_nanoseconds = R"(["1969-12-31T14:00:59.123456789-1000", null])";
+
+ CheckScalarUnary("strftime", timestamp(TimeUnit::SECOND, "UTC"), seconds, utf8(),
+ default_seconds, &options_default);
+ CheckScalarUnary("strftime", timestamp(TimeUnit::SECOND, "UTC"), seconds, utf8(),
+ string_seconds, &options);
+ CheckScalarUnary("strftime", timestamp(TimeUnit::MILLI, "GMT"), milliseconds, utf8(),
+ string_milliseconds, &options);
+ CheckScalarUnary("strftime", timestamp(TimeUnit::MICRO, "Asia/Kolkata"), microseconds,
+ utf8(), string_microseconds, &options);
+ CheckScalarUnary("strftime", timestamp(TimeUnit::NANO, "US/Hawaii"), nanoseconds,
+ utf8(), string_nanoseconds, &options);
+
+ auto options_hms = StrftimeOptions("%H:%M:%S");
+ auto options_ymdhms = StrftimeOptions("%Y-%m-%dT%H:%M:%S");
+
+ const char* times_s = R"([59, null])";
+ const char* times_ms = R"([59123, null])";
+ const char* times_us = R"([59123456, null])";
+ const char* times_ns = R"([59123456789, null])";
+ const char* hms_s = R"(["00:00:59", null])";
+ const char* hms_ms = R"(["00:00:59.123", null])";
+ const char* hms_us = R"(["00:00:59.123456", null])";
+ const char* hms_ns = R"(["00:00:59.123456789", null])";
+ const char* ymdhms_s = R"(["1970-01-01T00:00:59", null])";
+ const char* ymdhms_ms = R"(["1970-01-01T00:00:59.123", null])";
+ const char* ymdhms_us = R"(["1970-01-01T00:00:59.123456", null])";
+ const char* ymdhms_ns = R"(["1970-01-01T00:00:59.123456789", null])";
+
+ CheckScalarUnary("strftime", time32(TimeUnit::SECOND), times_s, utf8(), hms_s,
+ &options_hms);
+ CheckScalarUnary("strftime", time32(TimeUnit::MILLI), times_ms, utf8(), hms_ms,
+ &options_hms);
+ CheckScalarUnary("strftime", time64(TimeUnit::MICRO), times_us, utf8(), hms_us,
+ &options_hms);
+ CheckScalarUnary("strftime", time64(TimeUnit::NANO), times_ns, utf8(), hms_ns,
+ &options_hms);
+
+ CheckScalarUnary("strftime", time32(TimeUnit::SECOND), times_s, utf8(), ymdhms_s,
+ &options_ymdhms);
+ CheckScalarUnary("strftime", time32(TimeUnit::MILLI), times_ms, utf8(), ymdhms_ms,
+ &options_ymdhms);
+ CheckScalarUnary("strftime", time64(TimeUnit::MICRO), times_us, utf8(), ymdhms_us,
+ &options_ymdhms);
+ CheckScalarUnary("strftime", time64(TimeUnit::NANO), times_ns, utf8(), ymdhms_ns,
+ &options_ymdhms);
+
+ auto arr_s = ArrayFromJSON(time32(TimeUnit::SECOND), times_s);
+ auto arr_ms = ArrayFromJSON(time32(TimeUnit::MILLI), times_ms);
+ auto arr_us = ArrayFromJSON(time64(TimeUnit::MICRO), times_us);
+ auto arr_ns = ArrayFromJSON(time64(TimeUnit::NANO), times_ns);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr("Invalid: Timezone not present, cannot convert to string"),
+ Strftime(arr_s, StrftimeOptions("%Y-%m-%dT%H:%M:%S%z")));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr("Invalid: Timezone not present, cannot convert to string"),
+ Strftime(arr_ms, StrftimeOptions("%Y-%m-%dT%H:%M:%S%Z")));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr("Invalid: Timezone not present, cannot convert to string"),
+ Strftime(arr_us, StrftimeOptions("%Y-%m-%dT%H:%M:%S%z")));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr("Invalid: Timezone not present, cannot convert to string"),
+ Strftime(arr_ns, StrftimeOptions("%Y-%m-%dT%H:%M:%S%Z")));
+
+ auto options_ymd = StrftimeOptions("%Y-%m-%d");
+
+ const char* date32s = R"([0, 10957, 10958, null])";
+ const char* date64s = R"([0, 946684800000, 946771200000, null])";
+ const char* dates32_ymd = R"(["1970-01-01", "2000-01-01", "2000-01-02", null])";
+ const char* dates64_ymd = R"(["1970-01-01", "2000-01-01", "2000-01-02", null])";
+ const char* dates32_ymdhms =
+ R"(["1970-01-01T00:00:00", "2000-01-01T00:00:00", "2000-01-02T00:00:00", null])";
+ const char* dates64_ymdhms =
+ R"(["1970-01-01T00:00:00.000", "2000-01-01T00:00:00.000",
+ "2000-01-02T00:00:00.000", null])";
+
+ CheckScalarUnary("strftime", date32(), date32s, utf8(), dates32_ymd, &options_ymd);
+ CheckScalarUnary("strftime", date64(), date64s, utf8(), dates64_ymd, &options_ymd);
+ CheckScalarUnary("strftime", date32(), date32s, utf8(), dates32_ymdhms,
+ &options_ymdhms);
+ CheckScalarUnary("strftime", date64(), date64s, utf8(), dates64_ymdhms,
+ &options_ymdhms);
+}
+
+TEST_F(ScalarTemporalTest, StrftimeNoTimezone) {
+ auto options_default = StrftimeOptions();
+ const char* seconds = R"(["1970-01-01T00:00:59", null])";
+ auto arr = ArrayFromJSON(timestamp(TimeUnit::SECOND), seconds);
+
+ CheckScalarUnary("strftime", timestamp(TimeUnit::SECOND), seconds, utf8(), seconds,
+ &options_default);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr("Invalid: Timezone not present, cannot convert to string"),
+ Strftime(arr, StrftimeOptions("%Y-%m-%dT%H:%M:%S%z")));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr("Invalid: Timezone not present, cannot convert to string"),
+ Strftime(arr, StrftimeOptions("%Y-%m-%dT%H:%M:%S%Z")));
+}
+
+TEST_F(ScalarTemporalTest, StrftimeInvalidTimezone) {
+ const char* seconds = R"(["1970-01-01T00:00:59", null])";
+ auto arr = ArrayFromJSON(timestamp(TimeUnit::SECOND, "non-existent"), seconds);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, testing::HasSubstr("Cannot locate timezone 'non-existent'"),
+ Strftime(arr, StrftimeOptions()));
+}
+
+TEST_F(ScalarTemporalTest, StrftimeCLocale) {
+ auto options_default = StrftimeOptions();
+ auto options = StrftimeOptions("%Y-%m-%dT%H:%M:%S%z", "C");
+ auto options_locale_specific = StrftimeOptions("%a", "C");
+
+ const char* seconds = R"(["1970-01-01T00:00:59", null])";
+ const char* milliseconds = R"(["1970-01-01T00:00:59.123", null])";
+ const char* microseconds = R"(["1970-01-01T00:00:59.123456", null])";
+ const char* nanoseconds = R"(["1970-01-01T00:00:59.123456789", null])";
+
+ const char* default_seconds = R"(["1970-01-01T00:00:59", null])";
+ const char* string_seconds = R"(["1970-01-01T00:00:59+0000", null])";
+ const char* string_milliseconds = R"(["1970-01-01T00:00:59.123+0000", null])";
+ const char* string_microseconds = R"(["1970-01-01T05:30:59.123456+0530", null])";
+ const char* string_nanoseconds = R"(["1969-12-31T14:00:59.123456789-1000", null])";
+
+ const char* string_locale_specific = R"(["Wed", null])";
+
+ CheckScalarUnary("strftime", timestamp(TimeUnit::SECOND, "UTC"), seconds, utf8(),
+ default_seconds, &options_default);
+ CheckScalarUnary("strftime", timestamp(TimeUnit::SECOND, "UTC"), seconds, utf8(),
+ string_seconds, &options);
+ CheckScalarUnary("strftime", timestamp(TimeUnit::MILLI, "GMT"), milliseconds, utf8(),
+ string_milliseconds, &options);
+ CheckScalarUnary("strftime", timestamp(TimeUnit::MICRO, "Asia/Kolkata"), microseconds,
+ utf8(), string_microseconds, &options);
+ CheckScalarUnary("strftime", timestamp(TimeUnit::NANO, "US/Hawaii"), nanoseconds,
+ utf8(), string_nanoseconds, &options);
+
+ CheckScalarUnary("strftime", timestamp(TimeUnit::NANO, "US/Hawaii"), nanoseconds,
+ utf8(), string_locale_specific, &options_locale_specific);
+}
+
+TEST_F(ScalarTemporalTest, StrftimeOtherLocale) {
+ if (!LocaleExists("fr_FR.UTF-8")) {
+ GTEST_SKIP() << "locale 'fr_FR.UTF-8' doesn't exist on this system";
+ }
+
+ auto options = StrftimeOptions("%d %B %Y %H:%M:%S", "fr_FR.UTF-8");
+ const char* milliseconds = R"(
+ ["1970-01-01T00:00:59.123", "2021-08-18T15:11:50.456", null])";
+ const char* expected = R"(
+ ["01 janvier 1970 00:00:59,123", "18 août 2021 15:11:50,456", null])";
+ CheckScalarUnary("strftime", timestamp(TimeUnit::MILLI, "UTC"), milliseconds, utf8(),
+ expected, &options);
+}
+
+TEST_F(ScalarTemporalTest, StrftimeInvalidLocale) {
+ auto options = StrftimeOptions("%d %B %Y %H:%M:%S", "non-existent");
+ const char* seconds = R"(["1970-01-01T00:00:59", null])";
+ auto arr = ArrayFromJSON(timestamp(TimeUnit::SECOND, "UTC"), seconds);
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ testing::HasSubstr("Cannot find locale 'non-existent'"),
+ Strftime(arr, options));
+}
+
+TEST_F(ScalarTemporalTest, TestTemporalDifferenceZoned) {
+ for (auto u : TimeUnit::values()) {
+ auto unit = timestamp(u, "Pacific/Marquesas");
+ auto arr1 = ArrayFromJSON(unit, times_seconds_precision);
+ auto arr2 = ArrayFromJSON(unit, times_seconds_precision2);
+ CheckScalarBinary("years_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("years_between", arr1, arr2,
+ ArrayFromJSON(int64(), years_between_tz));
+ CheckScalarBinary("quarters_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("quarters_between", arr1, arr2,
+ ArrayFromJSON(int64(), quarters_between_tz));
+ CheckScalarBinary(
+ "month_day_nano_interval_between", arr1, arr1,
+ ArrayFromJSON(month_day_nano_interval(), month_day_nano_interval_between_zeros));
+ CheckScalarBinary(
+ "month_day_nano_interval_between", arr1, arr2,
+ ArrayFromJSON(month_day_nano_interval(), month_day_nano_interval_between_tz));
+ CheckScalarBinary("month_interval_between", arr1, arr1,
+ ArrayFromJSON(month_interval(), zeros));
+ CheckScalarBinary("month_interval_between", arr1, arr2,
+ ArrayFromJSON(month_interval(), months_between_tz));
+ CheckScalarBinary(
+ "day_time_interval_between", arr1, arr1,
+ ArrayFromJSON(day_time_interval(), day_time_interval_between_zeros));
+ CheckScalarBinary("day_time_interval_between", arr1, arr2,
+ ArrayFromJSON(day_time_interval(), day_time_interval_between_tz));
+ CheckScalarBinary("weeks_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("weeks_between", arr1, arr2,
+ ArrayFromJSON(int64(), weeks_between_tz));
+ CheckScalarBinary("days_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("days_between", arr1, arr2,
+ ArrayFromJSON(int64(), days_between_tz));
+ CheckScalarBinary("hours_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("hours_between", arr1, arr2,
+ ArrayFromJSON(int64(), hours_between_tz));
+ CheckScalarBinary("minutes_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("minutes_between", arr1, arr2,
+ ArrayFromJSON(int64(), minutes_between));
+ CheckScalarBinary("seconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("seconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), seconds_between));
+ CheckScalarBinary("milliseconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("milliseconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), milliseconds_between));
+ CheckScalarBinary("microseconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("microseconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), microseconds_between));
+ CheckScalarBinary("nanoseconds_between", arr1, arr1, ArrayFromJSON(int64(), zeros));
+ CheckScalarBinary("nanoseconds_between", arr1, arr2,
+ ArrayFromJSON(int64(), nanoseconds_between));
+ }
+}
+
+#endif // !_WIN32
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc
new file mode 100644
index 000000000..a62fbb2cc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc
@@ -0,0 +1,1158 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cmath>
+#include <initializer_list>
+#include <sstream>
+
+#include "arrow/builder.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/temporal_internal.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/time.h"
+#include "arrow/vendored/datetime.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+namespace internal {
+
+namespace {
+
+using arrow_vendored::date::days;
+using arrow_vendored::date::floor;
+using arrow_vendored::date::hh_mm_ss;
+using arrow_vendored::date::local_days;
+using arrow_vendored::date::local_time;
+using arrow_vendored::date::locate_zone;
+using arrow_vendored::date::sys_days;
+using arrow_vendored::date::sys_time;
+using arrow_vendored::date::time_zone;
+using arrow_vendored::date::trunc;
+using arrow_vendored::date::weekday;
+using arrow_vendored::date::weeks;
+using arrow_vendored::date::year_month_day;
+using arrow_vendored::date::year_month_weekday;
+using arrow_vendored::date::years;
+using arrow_vendored::date::zoned_time;
+using arrow_vendored::date::literals::dec;
+using arrow_vendored::date::literals::jan;
+using arrow_vendored::date::literals::last;
+using arrow_vendored::date::literals::mon;
+using arrow_vendored::date::literals::sun;
+using arrow_vendored::date::literals::thu;
+using arrow_vendored::date::literals::wed;
+using internal::applicator::SimpleUnary;
+
+using DayOfWeekState = OptionsWrapper<DayOfWeekOptions>;
+using WeekState = OptionsWrapper<WeekOptions>;
+using StrftimeState = OptionsWrapper<StrftimeOptions>;
+using AssumeTimezoneState = OptionsWrapper<AssumeTimezoneOptions>;
+
+const std::shared_ptr<DataType>& IsoCalendarType() {
+ static auto type = struct_({field("iso_year", int64()), field("iso_week", int64()),
+ field("iso_day_of_week", int64())});
+ return type;
+}
+
+Result<std::locale> GetLocale(const std::string& locale) {
+ try {
+ return std::locale(locale.c_str());
+ } catch (const std::runtime_error& ex) {
+ return Status::Invalid("Cannot find locale '", locale, "': ", ex.what());
+ }
+}
+
+Status ValidateDayOfWeekOptions(const DayOfWeekOptions& options) {
+ if (options.week_start < 1 || 7 < options.week_start) {
+ return Status::Invalid(
+ "week_start must follow ISO convention (Monday=1, Sunday=7). Got week_start=",
+ options.week_start);
+ }
+ return Status::OK();
+}
+
+template <template <typename...> class Op, typename Duration, typename InType,
+ typename OutType>
+struct TemporalComponentExtractDayOfWeek
+ : public TemporalComponentExtractBase<Op, Duration, InType, OutType> {
+ using Base = TemporalComponentExtractBase<Op, Duration, InType, OutType>;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const DayOfWeekOptions& options = DayOfWeekState::Get(ctx);
+ RETURN_NOT_OK(ValidateDayOfWeekOptions(options));
+ return Base::ExecWithOptions(ctx, &options, batch, out);
+ }
+};
+
+template <template <typename...> class Op, typename Duration, typename InType,
+ typename OutType>
+struct AssumeTimezoneExtractor
+ : public TemporalComponentExtractBase<Op, Duration, InType, OutType> {
+ using Base = TemporalComponentExtractBase<Op, Duration, InType, OutType>;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const AssumeTimezoneOptions& options = AssumeTimezoneState::Get(ctx);
+ const auto& timezone = GetInputTimezone(batch.values[0]);
+ if (!timezone.empty()) {
+ return Status::Invalid("Timestamps already have a timezone: '", timezone,
+ "'. Cannot localize to '", options.timezone, "'.");
+ }
+ ARROW_ASSIGN_OR_RAISE(auto tz, LocateZone(options.timezone));
+ using ExecTemplate = Op<Duration>;
+ auto op = ExecTemplate(&options, tz);
+ applicator::ScalarUnaryNotNullStateful<OutType, TimestampType, ExecTemplate> kernel{
+ op};
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+template <template <typename...> class Op, typename Duration, typename InType,
+ typename OutType>
+struct TemporalComponentExtractWeek
+ : public TemporalComponentExtractBase<Op, Duration, InType, OutType> {
+ using Base = TemporalComponentExtractBase<Op, Duration, InType, OutType>;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const WeekOptions& options = WeekState::Get(ctx);
+ return Base::ExecWithOptions(ctx, &options, batch, out);
+ }
+};
+
+// ----------------------------------------------------------------------
+// Extract year from temporal types
+//
+// This class and the following (`Month`, etc.) are to be used as the `Op`
+// parameter to `TemporalComponentExtract`.
+
+template <typename Duration, typename Localizer>
+struct Year {
+ Year(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status*) const {
+ return static_cast<T>(static_cast<const int32_t>(
+ year_month_day(floor<days>(localizer_.template ConvertTimePoint<Duration>(arg)))
+ .year()));
+ }
+
+ Localizer localizer_;
+};
+
+// ----------------------------------------------------------------------
+// Extract month from temporal types
+
+template <typename Duration, typename Localizer>
+struct Month {
+ explicit Month(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status*) const {
+ return static_cast<T>(static_cast<const uint32_t>(
+ year_month_day(floor<days>(localizer_.template ConvertTimePoint<Duration>(arg)))
+ .month()));
+ }
+
+ Localizer localizer_;
+};
+
+// ----------------------------------------------------------------------
+// Extract day from temporal types
+
+template <typename Duration, typename Localizer>
+struct Day {
+ explicit Day(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status*) const {
+ return static_cast<T>(static_cast<const uint32_t>(
+ year_month_day(floor<days>(localizer_.template ConvertTimePoint<Duration>(arg)))
+ .day()));
+ }
+
+ Localizer localizer_;
+};
+
+// ----------------------------------------------------------------------
+// Extract day of week from temporal types
+//
+// By default week starts on Monday represented by 0 and ends on Sunday represented
+// by 6. Start day of the week (Monday=1, Sunday=7) and numbering start (0 or 1) can be
+// set using DayOfWeekOptions
+
+template <typename Duration, typename Localizer>
+struct DayOfWeek {
+ explicit DayOfWeek(const DayOfWeekOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {
+ for (int i = 0; i < 7; i++) {
+ lookup_table[i] = i + 8 - options->week_start;
+ lookup_table[i] = (lookup_table[i] > 6) ? lookup_table[i] - 7 : lookup_table[i];
+ lookup_table[i] += !options->count_from_zero;
+ }
+ }
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status*) const {
+ const auto wd = year_month_weekday(
+ floor<days>(localizer_.template ConvertTimePoint<Duration>(arg)))
+ .weekday()
+ .iso_encoding();
+ return lookup_table[wd - 1];
+ }
+
+ std::array<int64_t, 7> lookup_table;
+ Localizer localizer_;
+};
+
+// ----------------------------------------------------------------------
+// Extract day of year from temporal types
+
+template <typename Duration, typename Localizer>
+struct DayOfYear {
+ explicit DayOfYear(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status*) const {
+ const auto t = floor<days>(localizer_.template ConvertTimePoint<Duration>(arg));
+ return static_cast<T>(
+ (t - localizer_.ConvertDays(year_month_day(t).year() / jan / 0)).count());
+ }
+
+ Localizer localizer_;
+};
+
+// ----------------------------------------------------------------------
+// Extract ISO Year values from temporal types
+//
+// First week of an ISO year has the majority (4 or more) of it's days in January.
+// Last week of an ISO year has the year's last Thursday in it.
+
+template <typename Duration, typename Localizer>
+struct ISOYear {
+ explicit ISOYear(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status*) const {
+ const auto t = floor<days>(localizer_.template ConvertTimePoint<Duration>(arg));
+ auto y = year_month_day{t + days{3}}.year();
+ auto start = localizer_.ConvertDays((y - years{1}) / dec / thu[last]) + (mon - thu);
+ if (t < start) {
+ --y;
+ }
+ return static_cast<T>(static_cast<int32_t>(y));
+ }
+
+ Localizer localizer_;
+};
+
+// ----------------------------------------------------------------------
+// Extract week from temporal types
+//
+// First week of an ISO year has the majority (4 or more) of its days in January.
+// Last week of an ISO year has the year's last Thursday in it.
+// Based on
+// https://github.com/HowardHinnant/date/blob/6e921e1b1d21e84a5c82416ba7ecd98e33a436d0/include/date/iso_week.h#L1503
+
+template <typename Duration, typename Localizer>
+struct Week {
+ explicit Week(const WeekOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)),
+ count_from_zero_(options->count_from_zero),
+ first_week_is_fully_in_year_(options->first_week_is_fully_in_year) {
+ if (options->week_starts_monday) {
+ if (first_week_is_fully_in_year_) {
+ wd_ = mon;
+ } else {
+ wd_ = thu;
+ }
+ } else {
+ if (first_week_is_fully_in_year_) {
+ wd_ = sun;
+ } else {
+ wd_ = wed;
+ }
+ }
+ if (count_from_zero_) {
+ days_offset_ = days{0};
+ } else {
+ days_offset_ = days{3};
+ }
+ }
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status*) const {
+ const auto t = floor<days>(localizer_.template ConvertTimePoint<Duration>(arg));
+ auto y = year_month_day{t + days_offset_}.year();
+
+ if (first_week_is_fully_in_year_) {
+ auto start = localizer_.ConvertDays(y / jan / wd_[1]);
+ if (!count_from_zero_) {
+ if (t < start) {
+ --y;
+ start = localizer_.ConvertDays(y / jan / wd_[1]);
+ }
+ }
+ return static_cast<T>(floor<weeks>(t - start).count() + 1);
+ }
+
+ auto start = localizer_.ConvertDays((y - years{1}) / dec / wd_[last]) + (mon - thu);
+ if (!count_from_zero_) {
+ if (t < start) {
+ --y;
+ start = localizer_.ConvertDays((y - years{1}) / dec / wd_[last]) + (mon - thu);
+ }
+ }
+ return static_cast<T>(floor<weeks>(t - start).count() + 1);
+ }
+
+ Localizer localizer_;
+ arrow_vendored::date::weekday wd_;
+ arrow_vendored::date::days days_offset_;
+ const bool count_from_zero_;
+ const bool first_week_is_fully_in_year_;
+};
+
+// ----------------------------------------------------------------------
+// Extract quarter from temporal types
+
+template <typename Duration, typename Localizer>
+struct Quarter {
+ explicit Quarter(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status*) const {
+ const auto ymd =
+ year_month_day(floor<days>(localizer_.template ConvertTimePoint<Duration>(arg)));
+ return static_cast<T>(GetQuarter(ymd) + 1);
+ }
+
+ Localizer localizer_;
+};
+
+// ----------------------------------------------------------------------
+// Extract hour from timestamp
+
+template <typename Duration, typename Localizer>
+struct Hour {
+ explicit Hour(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status*) const {
+ const auto t = localizer_.template ConvertTimePoint<Duration>(arg);
+ return static_cast<T>((t - floor<days>(t)) / std::chrono::hours(1));
+ }
+
+ Localizer localizer_;
+};
+
+// ----------------------------------------------------------------------
+// Extract minute from timestamp
+
+template <typename Duration, typename Localizer>
+struct Minute {
+ explicit Minute(const FunctionOptions* options, Localizer&& localizer)
+ : localizer_(std::move(localizer)) {}
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status*) const {
+ const auto t = localizer_.template ConvertTimePoint<Duration>(arg);
+ return static_cast<T>((t - floor<std::chrono::hours>(t)) / std::chrono::minutes(1));
+ }
+
+ Localizer localizer_;
+};
+
+// ----------------------------------------------------------------------
+// Extract second from timestamp
+
+template <typename Duration, typename Localizer>
+struct Second {
+ explicit Second(const FunctionOptions* options, Localizer&& localizer) {}
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status*) const {
+ Duration t = Duration{arg};
+ return static_cast<T>((t - floor<std::chrono::minutes>(t)) / std::chrono::seconds(1));
+ }
+};
+
+// ----------------------------------------------------------------------
+// Extract subsecond from timestamp
+
+template <typename Duration, typename Localizer>
+struct Subsecond {
+ explicit Subsecond(const FunctionOptions* options, Localizer&& localizer) {}
+
+ template <typename T, typename Arg0>
+ static T Call(KernelContext*, Arg0 arg, Status*) {
+ Duration t = Duration{arg};
+ return static_cast<T>(
+ (std::chrono::duration<double>(t - floor<std::chrono::seconds>(t)).count()));
+ }
+};
+
+// ----------------------------------------------------------------------
+// Extract milliseconds from timestamp
+
+template <typename Duration, typename Localizer>
+struct Millisecond {
+ explicit Millisecond(const FunctionOptions* options, Localizer&& localizer) {}
+
+ template <typename T, typename Arg0>
+ static T Call(KernelContext*, Arg0 arg, Status*) {
+ Duration t = Duration{arg};
+ return static_cast<T>(
+ ((t - floor<std::chrono::seconds>(t)) / std::chrono::milliseconds(1)) % 1000);
+ }
+};
+
+// ----------------------------------------------------------------------
+// Extract microseconds from timestamp
+
+template <typename Duration, typename Localizer>
+struct Microsecond {
+ explicit Microsecond(const FunctionOptions* options, Localizer&& localizer) {}
+
+ template <typename T, typename Arg0>
+ static T Call(KernelContext*, Arg0 arg, Status*) {
+ Duration t = Duration{arg};
+ return static_cast<T>(
+ ((t - floor<std::chrono::seconds>(t)) / std::chrono::microseconds(1)) % 1000);
+ }
+};
+
+// ----------------------------------------------------------------------
+// Extract nanoseconds from timestamp
+
+template <typename Duration, typename Localizer>
+struct Nanosecond {
+ explicit Nanosecond(const FunctionOptions* options, Localizer&& localizer) {}
+
+ template <typename T, typename Arg0>
+ static T Call(KernelContext*, Arg0 arg, Status*) {
+ Duration t = Duration{arg};
+ return static_cast<T>(
+ ((t - floor<std::chrono::seconds>(t)) / std::chrono::nanoseconds(1)) % 1000);
+ }
+};
+
+// ----------------------------------------------------------------------
+// Convert timestamps to a string representation with an arbitrary format
+
+#ifndef _WIN32
+template <typename Duration, typename InType>
+struct Strftime {
+ const StrftimeOptions& options;
+ const time_zone* tz;
+ const std::locale locale;
+
+ static Result<Strftime> Make(KernelContext* ctx, const DataType& type) {
+ const StrftimeOptions& options = StrftimeState::Get(ctx);
+
+ // This check is due to surprising %c behavior.
+ // See https://github.com/HowardHinnant/date/issues/704
+ if ((options.format.find("%c") != std::string::npos) && (options.locale != "C")) {
+ return Status::Invalid("%c flag is not supported in non-C locales.");
+ }
+ auto timezone = GetInputTimezone(type);
+
+ if (timezone.empty()) {
+ if ((options.format.find("%z") != std::string::npos) ||
+ (options.format.find("%Z") != std::string::npos)) {
+ return Status::Invalid(
+ "Timezone not present, cannot convert to string with timezone: ",
+ options.format);
+ }
+ timezone = "UTC";
+ }
+
+ ARROW_ASSIGN_OR_RAISE(const time_zone* tz, LocateZone(timezone));
+
+ ARROW_ASSIGN_OR_RAISE(std::locale locale, GetLocale(options.locale));
+
+ return Strftime{options, tz, std::move(locale)};
+ }
+
+ static Status Call(KernelContext* ctx, const Scalar& in, Scalar* out) {
+ ARROW_ASSIGN_OR_RAISE(auto self, Make(ctx, *in.type));
+ TimestampFormatter formatter{self.options.format, self.tz, self.locale};
+
+ if (in.is_valid) {
+ const int64_t in_val = internal::UnboxScalar<const InType>::Unbox(in);
+ ARROW_ASSIGN_OR_RAISE(auto formatted, formatter(in_val));
+ checked_cast<StringScalar*>(out)->value = Buffer::FromString(std::move(formatted));
+ } else {
+ out->is_valid = false;
+ }
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& in, ArrayData* out) {
+ ARROW_ASSIGN_OR_RAISE(auto self, Make(ctx, *in.type));
+ TimestampFormatter formatter{self.options.format, self.tz, self.locale};
+
+ StringBuilder string_builder;
+ // Presize string data using a heuristic
+ {
+ ARROW_ASSIGN_OR_RAISE(auto formatted, formatter(42));
+ const auto string_size = static_cast<int64_t>(std::ceil(formatted.size() * 1.1));
+ RETURN_NOT_OK(string_builder.Reserve(in.length));
+ RETURN_NOT_OK(
+ string_builder.ReserveData((in.length - in.GetNullCount()) * string_size));
+ }
+
+ auto visit_null = [&]() { return string_builder.AppendNull(); };
+ auto visit_value = [&](int64_t arg) {
+ ARROW_ASSIGN_OR_RAISE(auto formatted, formatter(arg));
+ return string_builder.Append(std::move(formatted));
+ };
+ RETURN_NOT_OK(VisitArrayDataInline<InType>(in, visit_value, visit_null));
+
+ std::shared_ptr<Array> out_array;
+ RETURN_NOT_OK(string_builder.Finish(&out_array));
+ *out = *std::move(out_array->data());
+
+ return Status::OK();
+ }
+
+ struct TimestampFormatter {
+ const char* format;
+ const time_zone* tz;
+ std::ostringstream bufstream;
+
+ explicit TimestampFormatter(const std::string& format, const time_zone* tz,
+ const std::locale& locale)
+ : format(format.c_str()), tz(tz) {
+ bufstream.imbue(locale);
+ // Propagate errors as C++ exceptions (to get an actual error message)
+ bufstream.exceptions(std::ios::failbit | std::ios::badbit);
+ }
+
+ Result<std::string> operator()(int64_t arg) {
+ bufstream.str("");
+ const auto zt = zoned_time<Duration>{tz, sys_time<Duration>(Duration{arg})};
+ try {
+ arrow_vendored::date::to_stream(bufstream, format, zt);
+ } catch (const std::runtime_error& ex) {
+ bufstream.clear();
+ return Status::Invalid("Failed formatting timestamp: ", ex.what());
+ }
+ // XXX could return a view with std::ostringstream::view() (C++20)
+ return std::move(bufstream).str();
+ }
+ };
+};
+#else
+template <typename Duration, typename InType>
+struct Strftime {
+ static Status Call(KernelContext* ctx, const Scalar& in, Scalar* out) {
+ return Status::NotImplemented("Strftime not yet implemented on windows.");
+ }
+ static Status Call(KernelContext* ctx, const ArrayData& in, ArrayData* out) {
+ return Status::NotImplemented("Strftime not yet implemented on windows.");
+ }
+};
+#endif
+
+// ----------------------------------------------------------------------
+// Convert timestamps from local timestamp without a timezone to timestamps with a
+// timezone, interpreting the local timestamp as being in the specified timezone
+
+Result<ValueDescr> ResolveAssumeTimezoneOutput(KernelContext* ctx,
+ const std::vector<ValueDescr>& args) {
+ auto in_type = checked_cast<const TimestampType*>(args[0].type.get());
+ auto type = timestamp(in_type->unit(), AssumeTimezoneState::Get(ctx).timezone);
+ return ValueDescr(std::move(type));
+}
+
+template <typename Duration>
+struct AssumeTimezone {
+ explicit AssumeTimezone(const AssumeTimezoneOptions* options, const time_zone* tz)
+ : options(*options), tz_(tz) {}
+
+ template <typename T, typename Arg0>
+ T get_local_time(Arg0 arg, const time_zone* tz) const {
+ return static_cast<T>(zoned_time<Duration>(tz, local_time<Duration>(Duration{arg}))
+ .get_sys_time()
+ .time_since_epoch()
+ .count());
+ }
+
+ template <typename T, typename Arg0>
+ T get_local_time(Arg0 arg, const arrow_vendored::date::choose choose,
+ const time_zone* tz) const {
+ return static_cast<T>(
+ zoned_time<Duration>(tz, local_time<Duration>(Duration{arg}), choose)
+ .get_sys_time()
+ .time_since_epoch()
+ .count());
+ }
+
+ template <typename T, typename Arg0>
+ T Call(KernelContext*, Arg0 arg, Status* st) const {
+ try {
+ return get_local_time<T, Arg0>(arg, tz_);
+ } catch (const arrow_vendored::date::nonexistent_local_time& e) {
+ switch (options.nonexistent) {
+ case AssumeTimezoneOptions::Nonexistent::NONEXISTENT_RAISE: {
+ *st = Status::Invalid("Timestamp doesn't exist in timezone '", options.timezone,
+ "': ", e.what());
+ return arg;
+ }
+ case AssumeTimezoneOptions::Nonexistent::NONEXISTENT_EARLIEST: {
+ return get_local_time<T, Arg0>(arg, arrow_vendored::date::choose::latest, tz_) -
+ 1;
+ }
+ case AssumeTimezoneOptions::Nonexistent::NONEXISTENT_LATEST: {
+ return get_local_time<T, Arg0>(arg, arrow_vendored::date::choose::latest, tz_);
+ }
+ }
+ } catch (const arrow_vendored::date::ambiguous_local_time& e) {
+ switch (options.ambiguous) {
+ case AssumeTimezoneOptions::Ambiguous::AMBIGUOUS_RAISE: {
+ *st = Status::Invalid("Timestamp is ambiguous in timezone '", options.timezone,
+ "': ", e.what());
+ return arg;
+ }
+ case AssumeTimezoneOptions::Ambiguous::AMBIGUOUS_EARLIEST: {
+ return get_local_time<T, Arg0>(arg, arrow_vendored::date::choose::earliest,
+ tz_);
+ }
+ case AssumeTimezoneOptions::Ambiguous::AMBIGUOUS_LATEST: {
+ return get_local_time<T, Arg0>(arg, arrow_vendored::date::choose::latest, tz_);
+ }
+ }
+ }
+ return 0;
+ }
+ AssumeTimezoneOptions options;
+ const time_zone* tz_;
+};
+
+// ----------------------------------------------------------------------
+// Extract ISO calendar values from timestamp
+
+template <typename Duration, typename Localizer>
+std::array<int64_t, 3> GetIsoCalendar(int64_t arg, Localizer&& localizer) {
+ const auto t = floor<days>(localizer.template ConvertTimePoint<Duration>(arg));
+ const auto ymd = year_month_day(t);
+ auto y = year_month_day{t + days{3}}.year();
+ auto start = localizer.ConvertDays((y - years{1}) / dec / thu[last]) + (mon - thu);
+ if (t < start) {
+ --y;
+ start = localizer.ConvertDays((y - years{1}) / dec / thu[last]) + (mon - thu);
+ }
+ return {static_cast<int64_t>(static_cast<int32_t>(y)),
+ static_cast<int64_t>(trunc<weeks>(t - start).count() + 1),
+ static_cast<int64_t>(weekday(ymd).iso_encoding())};
+}
+
+template <typename Duration, typename InType>
+struct ISOCalendarWrapper {
+ static Result<std::array<int64_t, 3>> Get(const Scalar& in) {
+ const auto& in_val = internal::UnboxScalar<const InType>::Unbox(in);
+ return GetIsoCalendar<Duration>(in_val, NonZonedLocalizer{});
+ }
+};
+
+template <typename Duration>
+struct ISOCalendarWrapper<Duration, TimestampType> {
+ static Result<std::array<int64_t, 3>> Get(const Scalar& in) {
+ const auto& in_val = internal::UnboxScalar<const TimestampType>::Unbox(in);
+ const auto& timezone = GetInputTimezone(in);
+ if (timezone.empty()) {
+ return GetIsoCalendar<Duration>(in_val, NonZonedLocalizer{});
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto tz, LocateZone(timezone));
+ return GetIsoCalendar<Duration>(in_val, ZonedLocalizer{tz});
+ }
+ }
+};
+
+template <typename Duration, typename InType, typename BuilderType>
+struct ISOCalendarVisitValueFunction {
+ static Result<std::function<Status(typename InType::c_type arg)>> Get(
+ const std::vector<BuilderType*>& field_builders, const ArrayData&,
+ StructBuilder* struct_builder) {
+ return [=](typename InType::c_type arg) {
+ const auto iso_calendar = GetIsoCalendar<Duration>(arg, NonZonedLocalizer{});
+ field_builders[0]->UnsafeAppend(iso_calendar[0]);
+ field_builders[1]->UnsafeAppend(iso_calendar[1]);
+ field_builders[2]->UnsafeAppend(iso_calendar[2]);
+ return struct_builder->Append();
+ };
+ }
+};
+
+template <typename Duration, typename BuilderType>
+struct ISOCalendarVisitValueFunction<Duration, TimestampType, BuilderType> {
+ static Result<std::function<Status(typename TimestampType::c_type arg)>> Get(
+ const std::vector<BuilderType*>& field_builders, const ArrayData& in,
+ StructBuilder* struct_builder) {
+ const auto& timezone = GetInputTimezone(in);
+ if (timezone.empty()) {
+ return [=](TimestampType::c_type arg) {
+ const auto iso_calendar = GetIsoCalendar<Duration>(arg, NonZonedLocalizer{});
+ field_builders[0]->UnsafeAppend(iso_calendar[0]);
+ field_builders[1]->UnsafeAppend(iso_calendar[1]);
+ field_builders[2]->UnsafeAppend(iso_calendar[2]);
+ return struct_builder->Append();
+ };
+ }
+ ARROW_ASSIGN_OR_RAISE(auto tz, LocateZone(timezone));
+ return [=](TimestampType::c_type arg) {
+ const auto iso_calendar = GetIsoCalendar<Duration>(arg, ZonedLocalizer{tz});
+ field_builders[0]->UnsafeAppend(iso_calendar[0]);
+ field_builders[1]->UnsafeAppend(iso_calendar[1]);
+ field_builders[2]->UnsafeAppend(iso_calendar[2]);
+ return struct_builder->Append();
+ };
+ }
+};
+
+template <typename Duration, typename InType>
+struct ISOCalendar {
+ static Status Call(KernelContext* ctx, const Scalar& in, Scalar* out) {
+ if (in.is_valid) {
+ ARROW_ASSIGN_OR_RAISE(auto iso_calendar,
+ (ISOCalendarWrapper<Duration, InType>::Get(in)));
+ ScalarVector values = {std::make_shared<Int64Scalar>(iso_calendar[0]),
+ std::make_shared<Int64Scalar>(iso_calendar[1]),
+ std::make_shared<Int64Scalar>(iso_calendar[2])};
+ *checked_cast<StructScalar*>(out) =
+ StructScalar(std::move(values), IsoCalendarType());
+ } else {
+ out->is_valid = false;
+ }
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& in, ArrayData* out) {
+ using BuilderType = typename TypeTraits<Int64Type>::BuilderType;
+
+ std::unique_ptr<ArrayBuilder> array_builder;
+ RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), IsoCalendarType(), &array_builder));
+ StructBuilder* struct_builder = checked_cast<StructBuilder*>(array_builder.get());
+ RETURN_NOT_OK(struct_builder->Reserve(in.length));
+
+ std::vector<BuilderType*> field_builders;
+ field_builders.reserve(3);
+ for (int i = 0; i < 3; i++) {
+ field_builders.push_back(
+ checked_cast<BuilderType*>(struct_builder->field_builder(i)));
+ RETURN_NOT_OK(field_builders[i]->Reserve(1));
+ }
+ auto visit_null = [&]() { return struct_builder->AppendNull(); };
+ std::function<Status(typename InType::c_type arg)> visit_value;
+ ARROW_ASSIGN_OR_RAISE(
+ visit_value, (ISOCalendarVisitValueFunction<Duration, InType, BuilderType>::Get(
+ field_builders, in, struct_builder)));
+ RETURN_NOT_OK(
+ VisitArrayDataInline<typename InType::PhysicalType>(in, visit_value, visit_null));
+ std::shared_ptr<Array> out_array;
+ RETURN_NOT_OK(struct_builder->Finish(&out_array));
+ *out = *std::move(out_array->data());
+ return Status::OK();
+ }
+};
+
+// ----------------------------------------------------------------------
+// Registration helpers
+
+template <template <typename...> class Op,
+ template <template <typename...> class OpExec, typename Duration,
+ typename InType, typename OutType, typename... Args>
+ class ExecTemplate,
+ typename OutType>
+struct UnaryTemporalFactory {
+ OutputType out_type;
+ KernelInit init;
+ std::shared_ptr<ScalarFunction> func;
+
+ template <typename... WithTypes>
+ static std::shared_ptr<ScalarFunction> Make(
+ std::string name, OutputType out_type, const FunctionDoc* doc,
+ const FunctionOptions* default_options = NULLPTR, KernelInit init = NULLPTR) {
+ DCHECK_NE(sizeof...(WithTypes), 0);
+ UnaryTemporalFactory self{
+ out_type, init,
+ std::make_shared<ScalarFunction>(name, Arity::Unary(), doc, default_options)};
+ AddTemporalKernels(&self, WithTypes{}...);
+ return self.func;
+ }
+
+ template <typename Duration, typename InType>
+ void AddKernel(InputType in_type) {
+ auto exec = ExecTemplate<Op, Duration, InType, OutType>::Exec;
+ DCHECK_OK(func->AddKernel({std::move(in_type)}, out_type, std::move(exec), init));
+ }
+};
+
+template <template <typename...> class Op>
+struct SimpleUnaryTemporalFactory {
+ OutputType out_type;
+ KernelInit init;
+ std::shared_ptr<ScalarFunction> func;
+
+ template <typename... WithTypes>
+ static std::shared_ptr<ScalarFunction> Make(
+ std::string name, OutputType out_type, const FunctionDoc* doc,
+ const FunctionOptions* default_options = NULLPTR, KernelInit init = NULLPTR) {
+ DCHECK_NE(sizeof...(WithTypes), 0);
+ SimpleUnaryTemporalFactory self{
+ out_type, init,
+ std::make_shared<ScalarFunction>(name, Arity::Unary(), doc, default_options)};
+ AddTemporalKernels(&self, WithTypes{}...);
+ return self.func;
+ }
+
+ template <typename Duration, typename InType>
+ void AddKernel(InputType in_type) {
+ auto exec = SimpleUnary<Op<Duration, InType>>;
+ DCHECK_OK(func->AddKernel({std::move(in_type)}, out_type, std::move(exec), init));
+ }
+};
+
+const FunctionDoc year_doc{
+ "Extract year number",
+ ("Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc month_doc{
+ "Extract month number",
+ ("Month is encoded as January=1, December=12.\n"
+ "Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc day_doc{
+ "Extract day number",
+ ("Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc day_of_week_doc{
+ "Extract day of the week number",
+ ("By default, the week starts on Monday represented by 0 and ends on Sunday\n"
+ "represented by 6.\n"
+ "`DayOfWeekOptions.week_start` can be used to set another starting day using\n"
+ "the ISO numbering convention (1=start week on Monday, 7=start week on Sunday).\n"
+ "Day numbers can start at 0 or 1 based on `DayOfWeekOptions.count_from_zero`.\n"
+ "Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"},
+ "DayOfWeekOptions"};
+
+const FunctionDoc day_of_year_doc{
+ "Extract day of year number",
+ ("January 1st maps to day number 1, February 1st to 32, etc.\n"
+ "Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc iso_year_doc{
+ "Extract ISO year number",
+ ("First week of an ISO year has the majority (4 or more) of its days in January."
+ "Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc iso_week_doc{
+ "Extract ISO week of year number",
+ ("First ISO week has the majority (4 or more) of its days in January."
+ "ISO week starts on Monday.\n"
+ "Week of the year starts with 1 and can run up to 53.\n"
+ "Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc us_week_doc{
+ "Extract US week of year number",
+ ("First US week has the majority (4 or more) of its days in January."
+ "US week starts on Sunday.\n"
+ "Week of the year starts with 1 and can run up to 53.\n"
+ "Null values emit null.\n"
+ "An error is returned if the timestamps have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc week_doc{
+ "Extract week of year number",
+ ("First week has the majority (4 or more) of its days in January.\n"
+ "Year can have 52 or 53 weeks. Week numbering can start with 0 or 1 using "
+ "DayOfWeekOptions.count_from_zero.\n"
+ "An error is returned if the timestamps have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"},
+ "WeekOptions"};
+
+const FunctionDoc iso_calendar_doc{
+ "Extract (ISO year, ISO week, ISO day of week) struct",
+ ("ISO week starts on Monday denoted by 1 and ends on Sunday denoted by 7.\n"
+ "Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc quarter_doc{
+ "Extract quarter of year number",
+ ("First quarter maps to 1 and forth quarter maps to 4.\n"
+ "Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc hour_doc{
+ "Extract hour value",
+ ("Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc minute_doc{
+ "Extract minute values",
+ ("Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc second_doc{
+ "Extract second values",
+ ("Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc millisecond_doc{
+ "Extract millisecond values",
+ ("Millisecond returns number of milliseconds since the last full second.\n"
+ "Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc microsecond_doc{
+ "Extract microsecond values",
+ ("Millisecond returns number of microseconds since the last full millisecond.\n"
+ "Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc nanosecond_doc{
+ "Extract nanosecond values",
+ ("Nanosecond returns number of nanoseconds since the last full microsecond.\n"
+ "Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc subsecond_doc{
+ "Extract subsecond values",
+ ("Subsecond returns the fraction of a second since the last full second.\n"
+ "Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database."),
+ {"values"}};
+
+const FunctionDoc strftime_doc{
+ "Format temporal values according to a format string",
+ ("For each input value, emit a formatted string.\n"
+ "The time format string and locale can be set using StrftimeOptions.\n"
+ "The output precision of the \"%S\" (seconds) format code depends on\n"
+ "the input time precision: it is an integer for timestamps with\n"
+ "second precision, a real number with the required number of fractional\n"
+ "digits for higher precisions.\n"
+ "Null values emit null.\n"
+ "An error is returned if the values have a defined timezone but it\n"
+ "cannot be found in the timezone database, or if the specified locale\n"
+ "does not exist on this system."),
+ {"timestamps"},
+ "StrftimeOptions"};
+
+const FunctionDoc assume_timezone_doc{
+ "Convert naive timestamp to timezone-aware timestamp",
+ ("Input timestamps are assumed to be relative to the timezone given in the\n"
+ "`timezone` option. They are converted to UTC-relative timestamps and\n"
+ "the output type has its timezone set to the value of the `timezone`\n"
+ "option. Null values emit null.\n"
+ "This function is meant to be used when an external system produces\n"
+ "\"timezone-naive\" timestamps which need to be converted to\n"
+ "\"timezone-aware\" timestamps. An error is returned if the timestamps\n"
+ "already have a defined timezone."),
+ {"timestamps"},
+ "AssumeTimezoneOptions"};
+
+} // namespace
+
+void RegisterScalarTemporalUnary(FunctionRegistry* registry) {
+ // Date extractors
+ auto year =
+ UnaryTemporalFactory<Year, TemporalComponentExtract,
+ Int64Type>::Make<WithDates, WithTimestamps>("year", int64(),
+ &year_doc);
+ DCHECK_OK(registry->AddFunction(std::move(year)));
+
+ auto month =
+ UnaryTemporalFactory<Month, TemporalComponentExtract,
+ Int64Type>::Make<WithDates, WithTimestamps>("month", int64(),
+ &month_doc);
+ DCHECK_OK(registry->AddFunction(std::move(month)));
+
+ auto day =
+ UnaryTemporalFactory<Day, TemporalComponentExtract,
+ Int64Type>::Make<WithDates, WithTimestamps>("day", int64(),
+ &day_doc);
+ DCHECK_OK(registry->AddFunction(std::move(day)));
+
+ static const auto default_day_of_week_options = DayOfWeekOptions::Defaults();
+ auto day_of_week =
+ UnaryTemporalFactory<DayOfWeek, TemporalComponentExtractDayOfWeek, Int64Type>::Make<
+ WithDates, WithTimestamps>("day_of_week", int64(), &day_of_week_doc,
+ &default_day_of_week_options, DayOfWeekState::Init);
+ DCHECK_OK(registry->AddFunction(std::move(day_of_week)));
+
+ auto day_of_year =
+ UnaryTemporalFactory<DayOfYear, TemporalComponentExtract,
+ Int64Type>::Make<WithDates, WithTimestamps>("day_of_year",
+ int64(),
+ &day_of_year_doc);
+ DCHECK_OK(registry->AddFunction(std::move(day_of_year)));
+
+ auto iso_year =
+ UnaryTemporalFactory<ISOYear, TemporalComponentExtract,
+ Int64Type>::Make<WithDates, WithTimestamps>("iso_year",
+ int64(),
+ &iso_year_doc);
+ DCHECK_OK(registry->AddFunction(std::move(iso_year)));
+
+ static const auto default_iso_week_options = WeekOptions::ISODefaults();
+ auto iso_week =
+ UnaryTemporalFactory<Week, TemporalComponentExtractWeek, Int64Type>::Make<
+ WithDates, WithTimestamps>("iso_week", int64(), &iso_week_doc,
+ &default_iso_week_options, WeekState::Init);
+ DCHECK_OK(registry->AddFunction(std::move(iso_week)));
+
+ static const auto default_us_week_options = WeekOptions::USDefaults();
+ auto us_week =
+ UnaryTemporalFactory<Week, TemporalComponentExtractWeek, Int64Type>::Make<
+ WithDates, WithTimestamps>("us_week", int64(), &us_week_doc,
+ &default_us_week_options, WeekState::Init);
+ DCHECK_OK(registry->AddFunction(std::move(us_week)));
+
+ static const auto default_week_options = WeekOptions();
+ auto week = UnaryTemporalFactory<Week, TemporalComponentExtractWeek, Int64Type>::Make<
+ WithDates, WithTimestamps>("week", int64(), &week_doc, &default_week_options,
+ WeekState::Init);
+ DCHECK_OK(registry->AddFunction(std::move(week)));
+
+ auto iso_calendar =
+ SimpleUnaryTemporalFactory<ISOCalendar>::Make<WithDates, WithTimestamps>(
+ "iso_calendar", IsoCalendarType(), &iso_calendar_doc);
+ DCHECK_OK(registry->AddFunction(std::move(iso_calendar)));
+
+ auto quarter =
+ UnaryTemporalFactory<Quarter, TemporalComponentExtract,
+ Int64Type>::Make<WithDates, WithTimestamps>("quarter", int64(),
+ &quarter_doc);
+ DCHECK_OK(registry->AddFunction(std::move(quarter)));
+
+ // Date / time extractors
+ auto hour =
+ UnaryTemporalFactory<Hour, TemporalComponentExtract,
+ Int64Type>::Make<WithTimes, WithTimestamps>("hour", int64(),
+ &hour_doc);
+ DCHECK_OK(registry->AddFunction(std::move(hour)));
+
+ auto minute =
+ UnaryTemporalFactory<Minute, TemporalComponentExtract,
+ Int64Type>::Make<WithTimes, WithTimestamps>("minute", int64(),
+ &minute_doc);
+ DCHECK_OK(registry->AddFunction(std::move(minute)));
+
+ auto second =
+ UnaryTemporalFactory<Second, TemporalComponentExtract,
+ Int64Type>::Make<WithTimes, WithTimestamps>("second", int64(),
+ &second_doc);
+ DCHECK_OK(registry->AddFunction(std::move(second)));
+
+ auto millisecond =
+ UnaryTemporalFactory<Millisecond, TemporalComponentExtract,
+ Int64Type>::Make<WithTimes, WithTimestamps>("millisecond",
+ int64(),
+ &millisecond_doc);
+ DCHECK_OK(registry->AddFunction(std::move(millisecond)));
+
+ auto microsecond =
+ UnaryTemporalFactory<Microsecond, TemporalComponentExtract,
+ Int64Type>::Make<WithTimes, WithTimestamps>("microsecond",
+ int64(),
+ &microsecond_doc);
+ DCHECK_OK(registry->AddFunction(std::move(microsecond)));
+
+ auto nanosecond =
+ UnaryTemporalFactory<Nanosecond, TemporalComponentExtract,
+ Int64Type>::Make<WithTimes, WithTimestamps>("nanosecond",
+ int64(),
+ &nanosecond_doc);
+ DCHECK_OK(registry->AddFunction(std::move(nanosecond)));
+
+ auto subsecond =
+ UnaryTemporalFactory<Subsecond, TemporalComponentExtract,
+ DoubleType>::Make<WithTimes, WithTimestamps>("subsecond",
+ float64(),
+ &subsecond_doc);
+ DCHECK_OK(registry->AddFunction(std::move(subsecond)));
+
+ // Timezone-related functions
+ static const auto default_strftime_options = StrftimeOptions();
+ auto strftime =
+ SimpleUnaryTemporalFactory<Strftime>::Make<WithTimes, WithDates, WithTimestamps>(
+ "strftime", utf8(), &strftime_doc, &default_strftime_options,
+ StrftimeState::Init);
+ DCHECK_OK(registry->AddFunction(std::move(strftime)));
+
+ auto assume_timezone =
+ UnaryTemporalFactory<AssumeTimezone, AssumeTimezoneExtractor, TimestampType>::Make<
+ WithTimestamps>("assume_timezone",
+ OutputType::Resolver(ResolveAssumeTimezoneOutput),
+ &assume_timezone_doc, nullptr, AssumeTimezoneState::Init);
+ DCHECK_OK(registry->AddFunction(std::move(assume_timezone)));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_validity.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_validity.cc
new file mode 100644
index 000000000..d23a909c6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_validity.cc
@@ -0,0 +1,286 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cmath>
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/common.h"
+
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+
+namespace arrow {
+
+using internal::CopyBitmap;
+using internal::InvertBitmap;
+
+namespace compute {
+namespace internal {
+namespace {
+
+struct IsValidOperator {
+ static Status Call(KernelContext* ctx, const Scalar& in, Scalar* out) {
+ checked_cast<BooleanScalar*>(out)->value = in.is_valid;
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& arr, ArrayData* out) {
+ DCHECK_EQ(out->offset, 0);
+ DCHECK_LE(out->length, arr.length);
+ if (arr.MayHaveNulls()) {
+ // Input has nulls => output is the null (validity) bitmap.
+ // To avoid copying the null bitmap, slice from the starting byte offset
+ // and set the offset to the remaining bit offset.
+ out->offset = arr.offset % 8;
+ out->buffers[1] =
+ arr.offset == 0 ? arr.buffers[0]
+ : SliceBuffer(arr.buffers[0], arr.offset / 8,
+ BitUtil::BytesForBits(out->length + out->offset));
+ return Status::OK();
+ }
+
+ // Input has no nulls => output is entirely true.
+ ARROW_ASSIGN_OR_RAISE(out->buffers[1],
+ ctx->AllocateBitmap(out->length + out->offset));
+ BitUtil::SetBitsTo(out->buffers[1]->mutable_data(), out->offset, out->length, true);
+ return Status::OK();
+ }
+};
+
+struct IsFiniteOperator {
+ template <typename OutType, typename InType>
+ static constexpr OutType Call(KernelContext*, const InType& value, Status*) {
+ return std::isfinite(value);
+ }
+};
+
+struct IsInfOperator {
+ template <typename OutType, typename InType>
+ static constexpr OutType Call(KernelContext*, const InType& value, Status*) {
+ return std::isinf(value);
+ }
+};
+
+using NanOptionsState = OptionsWrapper<NullOptions>;
+
+struct IsNullOperator {
+ static Status Call(KernelContext* ctx, const Scalar& in, Scalar* out) {
+ const auto& options = NanOptionsState::Get(ctx);
+ bool* out_value = &checked_cast<BooleanScalar*>(out)->value;
+
+ if (in.is_valid) {
+ if (options.nan_is_null && is_floating(in.type->id())) {
+ switch (in.type->id()) {
+ case Type::FLOAT:
+ *out_value = std::isnan(internal::UnboxScalar<FloatType>::Unbox(in));
+ break;
+ case Type::DOUBLE:
+ *out_value = std::isnan(internal::UnboxScalar<DoubleType>::Unbox(in));
+ break;
+ default:
+ return Status::NotImplemented("NaN detection not implemented for type ",
+ in.type->ToString());
+ }
+ } else {
+ *out_value = false;
+ }
+ } else {
+ *out_value = true;
+ }
+
+ return Status::OK();
+ }
+
+ template <typename T>
+ static void SetNanBits(const ArrayData& arr, uint8_t* out_bitmap, int64_t out_offset) {
+ const T* data = arr.GetValues<T>(1);
+ for (int64_t i = 0; i < arr.length; ++i) {
+ if (std::isnan(data[i])) {
+ BitUtil::SetBit(out_bitmap, i + out_offset);
+ }
+ }
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& arr, ArrayData* out) {
+ const auto& options = NanOptionsState::Get(ctx);
+
+ uint8_t* out_bitmap = out->buffers[1]->mutable_data();
+ if (arr.GetNullCount() > 0) {
+ // Input has nulls => output is the inverted null (validity) bitmap.
+ InvertBitmap(arr.buffers[0]->data(), arr.offset, arr.length, out_bitmap,
+ out->offset);
+ } else {
+ // Input has no nulls => output is entirely false.
+ BitUtil::SetBitsTo(out_bitmap, out->offset, out->length, false);
+ }
+
+ if (is_floating(arr.type->id()) && options.nan_is_null) {
+ switch (arr.type->id()) {
+ case Type::FLOAT:
+ SetNanBits<float>(arr, out_bitmap, out->offset);
+ break;
+ case Type::DOUBLE:
+ SetNanBits<double>(arr, out_bitmap, out->offset);
+ break;
+ default:
+ return Status::NotImplemented("NaN detection not implemented for type ",
+ arr.type->ToString());
+ }
+ }
+ return Status::OK();
+ }
+};
+
+struct IsNanOperator {
+ template <typename OutType, typename InType>
+ static constexpr OutType Call(KernelContext*, const InType& value, Status*) {
+ return std::isnan(value);
+ }
+};
+
+void MakeFunction(std::string name, const FunctionDoc* doc,
+ std::vector<InputType> in_types, OutputType out_type,
+ ArrayKernelExec exec, FunctionRegistry* registry,
+ MemAllocation::type mem_allocation, bool can_write_into_slices,
+ const FunctionOptions* default_options = NULLPTR,
+ KernelInit init = NULLPTR) {
+ Arity arity{static_cast<int>(in_types.size())};
+ auto func = std::make_shared<ScalarFunction>(name, arity, doc, default_options);
+
+ ScalarKernel kernel(std::move(in_types), out_type, exec, init);
+ kernel.null_handling = NullHandling::OUTPUT_NOT_NULL;
+ kernel.can_write_into_slices = can_write_into_slices;
+ kernel.mem_allocation = mem_allocation;
+
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+template <typename InType, typename Op>
+void AddFloatValidityKernel(const std::shared_ptr<DataType>& ty, ScalarFunction* func) {
+ DCHECK_OK(func->AddKernel({ty}, boolean(),
+ applicator::ScalarUnary<BooleanType, InType, Op>::Exec));
+}
+
+std::shared_ptr<ScalarFunction> MakeIsFiniteFunction(std::string name,
+ const FunctionDoc* doc) {
+ auto func = std::make_shared<ScalarFunction>(name, Arity::Unary(), doc);
+
+ AddFloatValidityKernel<FloatType, IsFiniteOperator>(float32(), func.get());
+ AddFloatValidityKernel<DoubleType, IsFiniteOperator>(float64(), func.get());
+
+ return func;
+}
+
+std::shared_ptr<ScalarFunction> MakeIsInfFunction(std::string name,
+ const FunctionDoc* doc) {
+ auto func = std::make_shared<ScalarFunction>(name, Arity::Unary(), doc);
+
+ AddFloatValidityKernel<FloatType, IsInfOperator>(float32(), func.get());
+ AddFloatValidityKernel<DoubleType, IsInfOperator>(float64(), func.get());
+
+ return func;
+}
+
+std::shared_ptr<ScalarFunction> MakeIsNanFunction(std::string name,
+ const FunctionDoc* doc) {
+ auto func = std::make_shared<ScalarFunction>(name, Arity::Unary(), doc);
+
+ AddFloatValidityKernel<FloatType, IsNanOperator>(float32(), func.get());
+ AddFloatValidityKernel<DoubleType, IsNanOperator>(float64(), func.get());
+
+ return func;
+}
+
+Status IsValidExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const Datum& arg0 = batch[0];
+ if (arg0.type()->id() == Type::NA) {
+ auto false_value = std::make_shared<BooleanScalar>(false);
+ if (arg0.kind() == Datum::SCALAR) {
+ out->value = false_value;
+ } else {
+ std::shared_ptr<Array> false_values;
+ RETURN_NOT_OK(MakeArrayFromScalar(*false_value, out->length(), ctx->memory_pool())
+ .Value(&false_values));
+ out->value = false_values->data();
+ }
+ return Status::OK();
+ } else {
+ return applicator::SimpleUnary<IsValidOperator>(ctx, batch, out);
+ }
+}
+
+Status IsNullExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const Datum& arg0 = batch[0];
+ if (arg0.type()->id() == Type::NA) {
+ if (arg0.kind() == Datum::SCALAR) {
+ out->value = std::make_shared<BooleanScalar>(true);
+ } else {
+ // Data is preallocated
+ ArrayData* out_arr = out->mutable_array();
+ BitUtil::SetBitsTo(out_arr->buffers[1]->mutable_data(), out_arr->offset,
+ out_arr->length, true);
+ }
+ return Status::OK();
+ } else {
+ return applicator::SimpleUnary<IsNullOperator>(ctx, batch, out);
+ }
+}
+
+const FunctionDoc is_valid_doc(
+ "Return true if non-null",
+ ("For each input value, emit true iff the value is valid (non-null)."), {"values"});
+
+const FunctionDoc is_finite_doc(
+ "Return true if value is finite",
+ ("For each input value, emit true iff the value is finite (not NaN, inf, or -inf)."),
+ {"values"});
+
+const FunctionDoc is_inf_doc(
+ "Return true if infinity",
+ ("For each input value, emit true iff the value is infinite (inf or -inf)."),
+ {"values"});
+
+const FunctionDoc is_null_doc(
+ "Return true if null (and optionally NaN)",
+ ("For each input value, emit true iff the value is null.\n"
+ "True may also be emitted for NaN values by setting the `nan_is_null` flag."),
+ {"values"}, "NullOptions");
+
+const FunctionDoc is_nan_doc("Return true if NaN",
+ ("For each input value, emit true iff the value is NaN."),
+ {"values"});
+
+} // namespace
+
+void RegisterScalarValidity(FunctionRegistry* registry) {
+ static auto kNullOptions = NullOptions::Defaults();
+ MakeFunction("is_valid", &is_valid_doc, {ValueDescr::ANY}, boolean(), IsValidExec,
+ registry, MemAllocation::NO_PREALLOCATE, /*can_write_into_slices=*/false);
+
+ MakeFunction("is_null", &is_null_doc, {ValueDescr::ANY}, boolean(), IsNullExec,
+ registry, MemAllocation::PREALLOCATE,
+ /*can_write_into_slices=*/true, &kNullOptions, NanOptionsState::Init);
+
+ DCHECK_OK(registry->AddFunction(MakeIsFiniteFunction("is_finite", &is_finite_doc)));
+ DCHECK_OK(registry->AddFunction(MakeIsInfFunction("is_inf", &is_inf_doc)));
+ DCHECK_OK(registry->AddFunction(MakeIsNanFunction("is_nan", &is_nan_doc)));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/scalar_validity_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/scalar_validity_test.cc
new file mode 100644
index 000000000..35a6b831e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/scalar_validity_test.cc
@@ -0,0 +1,206 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+namespace compute {
+
+template <typename ArrowType>
+class TestValidityKernels : public ::testing::Test {
+ protected:
+ static std::shared_ptr<DataType> type_singleton() {
+ return TypeTraits<ArrowType>::type_singleton();
+ }
+};
+
+using TestBooleanValidityKernels = TestValidityKernels<BooleanType>;
+
+TEST_F(TestBooleanValidityKernels, ArrayIsValid) {
+ CheckScalarUnary("is_valid", type_singleton(), "[]", type_singleton(), "[]");
+ CheckScalarUnary("is_valid", type_singleton(), "[null]", type_singleton(), "[false]");
+ CheckScalarUnary("is_valid", type_singleton(), "[1]", type_singleton(), "[true]");
+ CheckScalarUnary("is_valid", type_singleton(), "[null, 1, 0, null]", type_singleton(),
+ "[false, true, true, false]");
+}
+
+TEST_F(TestBooleanValidityKernels, ArrayIsValidBufferPassthruOptimization) {
+ Datum arg = ArrayFromJSON(boolean(), "[null, 1, 0, null]");
+ ASSERT_OK_AND_ASSIGN(auto validity, arrow::compute::IsValid(arg));
+ ASSERT_EQ(validity.array()->buffers[1], arg.array()->buffers[0]);
+}
+
+TEST_F(TestBooleanValidityKernels, IsNull) {
+ auto ty = type_singleton();
+ NullOptions default_options;
+ NullOptions nan_is_null_options(/*nan_is_null=*/true);
+
+ CheckScalarUnary("is_null", ty, "[]", boolean(), "[]");
+ CheckScalarUnary("is_null", ty, "[]", boolean(), "[]", &default_options);
+ CheckScalarUnary("is_null", ty, "[]", boolean(), "[]", &nan_is_null_options);
+
+ CheckScalarUnary("is_null", ty, "[null]", boolean(), "[true]");
+ CheckScalarUnary("is_null", ty, "[null]", boolean(), "[true]", &default_options);
+ CheckScalarUnary("is_null", ty, "[null]", boolean(), "[true]", &nan_is_null_options);
+
+ CheckScalarUnary("is_null", ty, "[1]", boolean(), "[false]");
+ CheckScalarUnary("is_null", ty, "[1]", boolean(), "[false]", &default_options);
+ CheckScalarUnary("is_null", ty, "[1]", boolean(), "[false]", &nan_is_null_options);
+
+ CheckScalarUnary("is_null", ty, "[null, 1, 0, null]", boolean(),
+ "[true, false, false, true]");
+ CheckScalarUnary("is_null", ty, "[null, 1, 0, null]", boolean(),
+ "[true, false, false, true]", &default_options);
+ CheckScalarUnary("is_null", ty, "[null, 1, 0, null]", boolean(),
+ "[true, false, false, true]", &nan_is_null_options);
+}
+
+TEST(TestValidityKernels, IsValidIsNullNullType) {
+ CheckScalarUnary("is_null", std::make_shared<NullArray>(5),
+ ArrayFromJSON(boolean(), "[true, true, true, true, true]"));
+ CheckScalarUnary("is_valid", std::make_shared<NullArray>(5),
+ ArrayFromJSON(boolean(), "[false, false, false, false, false]"));
+}
+
+TEST(TestValidityKernels, IsNullSetsZeroNullCount) {
+ auto arr = ArrayFromJSON(int32(), "[1, 2, 3, 4, null]");
+ ASSERT_OK_AND_ASSIGN(Datum out, IsNull(arr));
+ ASSERT_EQ(out.array()->null_count, 0);
+}
+
+template <typename ArrowType>
+class TestFloatingPointValidityKernels : public TestValidityKernels<ArrowType> {
+ public:
+ void TestIsNull() {
+ NullOptions default_options;
+ NullOptions nan_is_null_options(/*nan_is_null=*/true);
+
+ auto ty = this->type_singleton();
+ auto arr = ArrayFromJSON(ty, "[]");
+ CheckScalarUnary("is_null", arr, ArrayFromJSON(boolean(), "[]"));
+ CheckScalarUnary("is_null", arr, ArrayFromJSON(boolean(), "[]"), &default_options);
+ CheckScalarUnary("is_null", arr, ArrayFromJSON(boolean(), "[]"),
+ &nan_is_null_options);
+
+ // Without nulls
+ arr = ArrayFromJSON(ty, "[1.5, 0.0, -0.0, Inf, -Inf, NaN]");
+ CheckScalarUnary(
+ "is_null", arr,
+ ArrayFromJSON(boolean(), "[false, false, false, false, false, false]"));
+ CheckScalarUnary(
+ "is_null", arr,
+ ArrayFromJSON(boolean(), "[false, false, false, false, false, false]"),
+ &default_options);
+ CheckScalarUnary(
+ "is_null", arr,
+ ArrayFromJSON(boolean(), "[false, false, false, false, false, true]"),
+ &nan_is_null_options);
+
+ // With nulls
+ arr = ArrayFromJSON(ty, "[1.5, -0.0, null, Inf, -Inf, NaN]");
+ CheckScalarUnary(
+ "is_null", arr,
+ ArrayFromJSON(boolean(), "[false, false, true, false, false, false]"));
+ CheckScalarUnary(
+ "is_null", arr,
+ ArrayFromJSON(boolean(), "[false, false, true, false, false, false]"),
+ &default_options);
+ CheckScalarUnary("is_null", arr,
+ ArrayFromJSON(boolean(), "[false, false, true, false, false, true]"),
+ &nan_is_null_options);
+
+ // Only nulls
+ arr = ArrayFromJSON(ty, "[null, null, null]");
+ CheckScalarUnary("is_null", arr, ArrayFromJSON(boolean(), "[true, true, true]"));
+ CheckScalarUnary("is_null", arr, ArrayFromJSON(boolean(), "[true, true, true]"),
+ &default_options);
+ CheckScalarUnary("is_null", arr, ArrayFromJSON(boolean(), "[true, true, true]"),
+ &nan_is_null_options);
+ }
+
+ void TestIsFinite() {
+ auto ty = this->type_singleton();
+ CheckScalarUnary("is_finite", ArrayFromJSON(ty, "[]"),
+ ArrayFromJSON(boolean(), "[]"));
+
+ // All Inf
+ CheckScalarUnary("is_finite", ArrayFromJSON(ty, "[Inf, -Inf, Inf, -Inf, Inf]"),
+ ArrayFromJSON(boolean(), "[false, false, false, false, false]"));
+ // No Inf
+ CheckScalarUnary("is_finite", ArrayFromJSON(ty, "[0.0, 1.0, 2.0, 3.0, NaN, null]"),
+ ArrayFromJSON(boolean(), "[true, true, true, true, false, null]"));
+ // Some Inf
+ CheckScalarUnary("is_finite", ArrayFromJSON(ty, "[0.0, Inf, 2.0, -Inf, NaN, null]"),
+ ArrayFromJSON(boolean(), "[true, false, true, false, false, null]"));
+ }
+
+ void TestIsInf() {
+ auto ty = this->type_singleton();
+ CheckScalarUnary("is_inf", ArrayFromJSON(ty, "[]"), ArrayFromJSON(boolean(), "[]"));
+
+ // All Inf
+ CheckScalarUnary("is_inf", ArrayFromJSON(ty, "[Inf, -Inf, Inf, -Inf, Inf]"),
+ ArrayFromJSON(boolean(), "[true, true, true, true, true]"));
+ // No Inf
+ CheckScalarUnary(
+ "is_inf", ArrayFromJSON(ty, "[0.0, 1.0, 2.0, 3.0, NaN, null]"),
+ ArrayFromJSON(boolean(), "[false, false, false, false, false, null]"));
+ // Some Inf
+ CheckScalarUnary("is_inf", ArrayFromJSON(ty, "[0.0, Inf, 2.0, -Inf, NaN, null]"),
+ ArrayFromJSON(boolean(), "[false, true, false, true, false, null]"));
+ }
+
+ void TestIsNan() {
+ auto ty = this->type_singleton();
+ CheckScalarUnary("is_nan", ArrayFromJSON(ty, "[]"), ArrayFromJSON(boolean(), "[]"));
+
+ // All NaN
+ CheckScalarUnary("is_nan", ArrayFromJSON(ty, "[NaN, NaN, NaN, NaN, NaN]"),
+ ArrayFromJSON(boolean(), "[true, true, true, true, true]"));
+ // No NaN
+ CheckScalarUnary(
+ "is_nan", ArrayFromJSON(ty, "[0.0, 1.0, 2.0, 3.0, Inf, null]"),
+ ArrayFromJSON(boolean(), "[false, false, false, false, false, null]"));
+ // Some NaNs
+ CheckScalarUnary("is_nan", ArrayFromJSON(ty, "[0.0, NaN, 2.0, NaN, Inf, null]"),
+ ArrayFromJSON(boolean(), "[false, true, false, true, false, null]"));
+ }
+};
+
+TYPED_TEST_SUITE(TestFloatingPointValidityKernels, RealArrowTypes);
+
+TYPED_TEST(TestFloatingPointValidityKernels, IsNull) { this->TestIsNull(); }
+
+TYPED_TEST(TestFloatingPointValidityKernels, IsFinite) { this->TestIsFinite(); }
+
+TYPED_TEST(TestFloatingPointValidityKernels, IsInf) { this->TestIsInf(); }
+
+TYPED_TEST(TestFloatingPointValidityKernels, IsNan) { this->TestIsNan(); }
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/select_k_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/select_k_test.cc
new file mode 100644
index 000000000..2d1d5cffe
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/select_k_test.cc
@@ -0,0 +1,716 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+template <typename ArrayType, SortOrder order>
+class SelectKCompareForResult {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ if (order == SortOrder::Ascending) {
+ return lval <= rval;
+ } else {
+ return rval <= lval;
+ }
+ }
+};
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const Datum& values, int64_t k) {
+ if (order == SortOrder::Descending) {
+ return SelectKUnstable(values, SelectKOptions::TopKDefault(k));
+ } else {
+ return SelectKUnstable(values, SelectKOptions::BottomKDefault(k));
+ }
+}
+
+void ValidateSelectK(const Datum& datum, Array& select_k_indices, SortOrder order,
+ bool stable_sort = false) {
+ ASSERT_TRUE(datum.is_arraylike());
+ ASSERT_OK_AND_ASSIGN(auto sorted_indices,
+ SortIndices(datum, SortOptions({SortKey("unused", order)})));
+
+ int64_t k = select_k_indices.length();
+ // head(k)
+ auto head_k_indices = sorted_indices->Slice(0, k);
+ if (stable_sort) {
+ AssertDatumsEqual(*head_k_indices, select_k_indices);
+ } else {
+ ASSERT_OK_AND_ASSIGN(auto expected,
+ Take(datum, *head_k_indices, TakeOptions::NoBoundsCheck()));
+ ASSERT_OK_AND_ASSIGN(auto actual,
+ Take(datum, select_k_indices, TakeOptions::NoBoundsCheck()));
+ AssertDatumsEqual(Datum(expected), Datum(actual));
+ }
+}
+
+template <typename ArrowType>
+class TestSelectKBase : public TestBase {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+ protected:
+ template <SortOrder order>
+ void AssertSelectKArray(const std::shared_ptr<Array> values, int k) {
+ std::shared_ptr<Array> select_k;
+ ASSERT_OK_AND_ASSIGN(select_k, SelectK<order>(Datum(*values), k));
+ ASSERT_EQ(select_k->data()->null_count, 0);
+ ValidateOutput(*select_k);
+ ValidateSelectK(Datum(*values), *select_k, order);
+ }
+
+ void AssertTopKArray(const std::shared_ptr<Array> values, int n) {
+ AssertSelectKArray<SortOrder::Descending>(values, n);
+ }
+ void AssertBottomKArray(const std::shared_ptr<Array> values, int n) {
+ AssertSelectKArray<SortOrder::Ascending>(values, n);
+ }
+
+ void AssertSelectKJson(const std::string& values, int n) {
+ AssertTopKArray(ArrayFromJSON(type_singleton(), values), n);
+ AssertBottomKArray(ArrayFromJSON(type_singleton(), values), n);
+ }
+
+ virtual std::shared_ptr<DataType> type_singleton() = 0;
+};
+
+template <typename ArrowType>
+class TestSelectK : public TestSelectKBase<ArrowType> {
+ protected:
+ std::shared_ptr<DataType> type_singleton() override {
+ return default_type_instance<ArrowType>();
+ }
+};
+
+template <typename ArrowType>
+class TestSelectKForReal : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForReal, RealArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForIntegral : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForIntegral, IntegralArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForBool : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForBool, ::testing::Types<BooleanType>);
+
+template <typename ArrowType>
+class TestSelectKForTemporal : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForTemporal, TemporalArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForDecimal : public TestSelectKBase<ArrowType> {
+ std::shared_ptr<DataType> type_singleton() override {
+ return std::make_shared<ArrowType>(5, 2);
+ }
+};
+TYPED_TEST_SUITE(TestSelectKForDecimal, DecimalArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForStrings : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForStrings, testing::Types<StringType>);
+
+TYPED_TEST(TestSelectKForReal, SelectKDoesNotProvideDefaultOptions) {
+ auto input = ArrayFromJSON(this->type_singleton(), "[null, 1, 3.3, null, 2, 5.3]");
+ ASSERT_RAISES(Invalid, CallFunction("select_k_unstable", {input}));
+}
+
+TYPED_TEST(TestSelectKForReal, Real) {
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 0);
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 2);
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 5);
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 6);
+
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 0);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 1);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 2);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 3);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 4);
+ this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 3);
+ this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 4);
+ this->AssertSelectKJson("[100, 4, 2, 7, 8, 3, NaN, 3, 1]", 4);
+}
+
+TYPED_TEST(TestSelectKForIntegral, Integral) {
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6);
+
+ this->AssertSelectKJson("[2, 4, 5, 7, 8, 0, 9, 1, 3]", 5);
+}
+
+TYPED_TEST(TestSelectKForBool, Bool) {
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 0);
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 2);
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 5);
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 6);
+}
+
+TYPED_TEST(TestSelectKForTemporal, Temporal) {
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6);
+}
+
+TYPED_TEST(TestSelectKForDecimal, Decimal) {
+ const std::string values = R"(["123.45", null, "-123.45", "456.78", "-456.78"])";
+ this->AssertSelectKJson(values, 0);
+ this->AssertSelectKJson(values, 2);
+ this->AssertSelectKJson(values, 4);
+ this->AssertSelectKJson(values, 5);
+}
+
+TYPED_TEST(TestSelectKForStrings, Strings) {
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 0);
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 2);
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 5);
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 6);
+}
+
+template <typename ArrowType>
+class TestSelectKRandom : public TestSelectKBase<ArrowType> {
+ public:
+ std::shared_ptr<DataType> type_singleton() override {
+ EXPECT_TRUE(0) << "shouldn't be used";
+ return nullptr;
+ }
+};
+
+using SelectKableTypes =
+ ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
+ Int32Type, Int64Type, FloatType, DoubleType, StringType>;
+
+TYPED_TEST_SUITE(TestSelectKRandom, SelectKableTypes);
+
+TYPED_TEST(TestSelectKRandom, RandomValues) {
+ Random<TypeParam> rand(0x61549225);
+ int length = 100;
+ for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
+ auto array = rand.Generate(length, null_probability);
+ // Try n from 0 to out of bound
+ for (int n = 0; n <= length; ++n) {
+ this->AssertTopKArray(array, n);
+ this->AssertBottomKArray(array, n);
+ }
+ }
+}
+
+// Test basic cases for chunked array
+
+template <typename ArrowType>
+struct TestSelectKWithChunkedArray : public ::testing::Test {
+ TestSelectKWithChunkedArray() {}
+
+ // Slice `array` into multiple chunks along `offsets`
+ ArrayVector Slices(const std::shared_ptr<Array>& array,
+ const std::shared_ptr<Int32Array>& offsets) {
+ ArrayVector slices(offsets->length() - 1);
+ for (int64_t i = 0; i != static_cast<int64_t>(slices.size()); ++i) {
+ slices[i] =
+ array->Slice(offsets->Value(i), offsets->Value(i + 1) - offsets->Value(i));
+ }
+ return slices;
+ }
+
+ template <SortOrder order = SortOrder::Descending>
+ void AssertSelectK(const std::shared_ptr<ChunkedArray>& chunked_array, int64_t k) {
+ ASSERT_OK_AND_ASSIGN(auto select_k_array, SelectK<order>(Datum(*chunked_array), k));
+ ValidateSelectK(Datum(*chunked_array), *select_k_array, order);
+ }
+
+ void AssertTopK(const std::shared_ptr<ChunkedArray>& chunked_array, int64_t k) {
+ AssertSelectK<SortOrder::Descending>(chunked_array, k);
+ }
+ void AssertBottomK(const std::shared_ptr<ChunkedArray>& chunked_array, int64_t k) {
+ AssertSelectK<SortOrder::Ascending>(chunked_array, k);
+ }
+};
+
+TYPED_TEST_SUITE(TestSelectKWithChunkedArray, SelectKableTypes);
+
+TYPED_TEST(TestSelectKWithChunkedArray, RandomValuesWithSlices) {
+ Random<TypeParam> rand(0x61549225);
+ int length = 100;
+ for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
+ // Try n from 0 to out of bound
+ auto array = rand.Generate(length, null_probability);
+ auto offsets = rand.Offsets(length, 3);
+ auto slices = this->Slices(array, offsets);
+ ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make(slices));
+ for (int k = 0; k <= length; k += 10) {
+ this->AssertTopK(chunked_array, k);
+ this->AssertBottomK(chunked_array, k);
+ }
+ }
+}
+
+template <typename ArrayType, SortOrder order>
+void ValidateSelectKIndices(const ArrayType& array) {
+ ValidateOutput(array);
+
+ SelectKCompareForResult<ArrayType, order> compare;
+ for (uint64_t i = 1; i < static_cast<uint64_t>(array.length()); i++) {
+ using ArrowType = typename ArrayType::TypeClass;
+ using GetView = internal::GetViewType<ArrowType>;
+
+ const auto lval = GetView::LogicalValue(array.GetView(i - 1));
+ const auto rval = GetView::LogicalValue(array.GetView(i));
+ ASSERT_TRUE(compare(lval, rval));
+ }
+}
+// Base class for testing against random chunked array.
+template <typename Type, SortOrder order>
+struct TestSelectKWithChunkedArrayRandomBase : public ::testing::Test {
+ void TestSelectK(int length) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ // We can use INSTANTIATE_TEST_SUITE_P() instead of using fors in a test.
+ for (auto null_probability : {0.0, 0.1, 0.5, 0.9, 1.0}) {
+ for (auto num_chunks : {1, 2, 5, 10, 40}) {
+ std::vector<std::shared_ptr<Array>> arrays;
+ for (int i = 0; i < num_chunks; ++i) {
+ auto array = this->GenerateArray(length / num_chunks, null_probability);
+ arrays.push_back(array);
+ }
+ ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make(arrays));
+ ASSERT_OK_AND_ASSIGN(auto indices, SelectK<order>(Datum(*chunked_array), 5));
+ ASSERT_OK_AND_ASSIGN(auto actual, Take(Datum(chunked_array), Datum(indices),
+ TakeOptions::NoBoundsCheck()));
+ ASSERT_OK_AND_ASSIGN(auto sorted_k,
+ Concatenate(actual.chunked_array()->chunks()));
+
+ ValidateSelectKIndices<ArrayType, order>(
+ *checked_pointer_cast<ArrayType>(sorted_k));
+ }
+ }
+ }
+
+ void SetUp() override { rand_ = new Random<Type>(0x5487655); }
+
+ void TearDown() override { delete rand_; }
+
+ protected:
+ std::shared_ptr<Array> GenerateArray(int length, double null_probability) {
+ return rand_->Generate(length, null_probability);
+ }
+
+ private:
+ Random<Type>* rand_;
+};
+
+// Long array with big value range
+template <typename Type>
+class TestTopKChunkedArrayRandom
+ : public TestSelectKWithChunkedArrayRandomBase<Type, SortOrder::Descending> {};
+
+TYPED_TEST_SUITE(TestTopKChunkedArrayRandom, SelectKableTypes);
+
+TYPED_TEST(TestTopKChunkedArrayRandom, TopK) { this->TestSelectK(1000); }
+
+template <typename Type>
+class TestBottomKChunkedArrayRandom
+ : public TestSelectKWithChunkedArrayRandomBase<Type, SortOrder::Ascending> {};
+
+TYPED_TEST_SUITE(TestBottomKChunkedArrayRandom, SelectKableTypes);
+
+TYPED_TEST(TestBottomKChunkedArrayRandom, BottomK) { this->TestSelectK(1000); }
+
+// // Test basic cases for record batch.
+class TestSelectKWithRecordBatch : public ::testing::Test {
+ public:
+ void Check(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
+ const SelectKOptions& options, const std::string& expected_batch) {
+ std::shared_ptr<RecordBatch> actual;
+ ASSERT_OK(this->DoSelectK(schm, batch_json, options, &actual));
+ ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual);
+ }
+
+ Status DoSelectK(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
+ const SelectKOptions& options, std::shared_ptr<RecordBatch>* out) {
+ auto batch = RecordBatchFromJSON(schm, batch_json);
+ ARROW_ASSIGN_OR_RAISE(auto indices, SelectKUnstable(Datum(*batch), options));
+
+ ValidateOutput(*indices);
+ ARROW_ASSIGN_OR_RAISE(
+ auto select_k, Take(Datum(batch), Datum(indices), TakeOptions::NoBoundsCheck()));
+ *out = select_k.record_batch();
+ return Status::OK();
+ }
+};
+
+TEST_F(TestSelectKWithRecordBatch, TopKNoNull) {
+ auto schema = ::arrow::schema({
+ {field("a", uint8())},
+ {field("b", uint32())},
+ });
+
+ auto batch_input = R"([
+ {"a": 3, "b": 5},
+ {"a": 30, "b": 3},
+ {"a": 3, "b": 4},
+ {"a": 0, "b": 6},
+ {"a": 20, "b": 5},
+ {"a": 10, "b": 5},
+ {"a": 10, "b": 3}
+ ])";
+
+ auto options = SelectKOptions::TopKDefault(3, {"a"});
+
+ auto expected_batch = R"([
+ {"a": 30, "b": 3},
+ {"a": 20, "b": 5},
+ {"a": 10, "b": 5}
+ ])";
+
+ Check(schema, batch_input, options, expected_batch);
+}
+
+TEST_F(TestSelectKWithRecordBatch, TopKNull) {
+ auto schema = ::arrow::schema({
+ {field("a", uint8())},
+ {field("b", uint32())},
+ });
+
+ auto batch_input = R"([
+ {"a": null, "b": 5},
+ {"a": 30, "b": 3},
+ {"a": null, "b": 4},
+ {"a": null, "b": 6},
+ {"a": 20, "b": 5},
+ {"a": null, "b": 5},
+ {"a": 10, "b": 3}
+ ])";
+
+ auto options = SelectKOptions::TopKDefault(3, {"a"});
+
+ auto expected_batch = R"([
+ {"a": 30, "b": 3},
+ {"a": 20, "b": 5},
+ {"a": 10, "b": 3}
+ ])";
+
+ Check(schema, batch_input, options, expected_batch);
+}
+
+TEST_F(TestSelectKWithRecordBatch, TopKOneColumnKey) {
+ auto schema = ::arrow::schema({
+ {field("country", utf8())},
+ {field("population", uint64())},
+ });
+
+ auto batch_input =
+ R"([{"country": "Italy", "population": 59000000},
+ {"country": "France", "population": 65000000},
+ {"country": "Malta", "population": 434000},
+ {"country": "Maldives", "population": 434000},
+ {"country": "Brunei", "population": 434000},
+ {"country": "Iceland", "population": 337000},
+ {"country": "Nauru", "population": 11300},
+ {"country": "Tuvalu", "population": 11300},
+ {"country": "Anguilla", "population": 11300},
+ {"country": "Montserrat", "population": 5200}
+ ])";
+
+ auto options = SelectKOptions::TopKDefault(3, {"population"});
+
+ auto expected_batch =
+ R"([{"country": "France", "population": 65000000},
+ {"country": "Italy", "population": 59000000},
+ {"country": "Malta", "population": 434000}
+ ])";
+ this->Check(schema, batch_input, options, expected_batch);
+}
+
+TEST_F(TestSelectKWithRecordBatch, TopKMultipleColumnKeys) {
+ auto schema = ::arrow::schema({{field("country", utf8())},
+ {field("population", uint64())},
+ {field("GDP", uint64())}});
+
+ auto batch_input =
+ R"([{"country": "Italy", "population": 59000000, "GDP": 1937894},
+ {"country": "France", "population": 65000000, "GDP": 2583560},
+ {"country": "Malta", "population": 434000, "GDP": 12011},
+ {"country": "Maldives", "population": 434000, "GDP": 4520},
+ {"country": "Brunei", "population": 434000, "GDP": 12128},
+ {"country": "Iceland", "population": 337000, "GDP": 17036},
+ {"country": "Nauru", "population": 337000, "GDP": 182},
+ {"country": "Tuvalu", "population": 11300, "GDP": 38},
+ {"country": "Anguilla", "population": 11300, "GDP": 311}
+ ])";
+ auto options = SelectKOptions::TopKDefault(3, {"population", "GDP"});
+
+ auto expected_batch =
+ R"([{"country": "France", "population": 65000000, "GDP": 2583560},
+ {"country": "Italy", "population": 59000000, "GDP": 1937894},
+ {"country": "Brunei", "population": 434000, "GDP": 12128}
+ ])";
+ this->Check(schema, batch_input, options, expected_batch);
+}
+
+TEST_F(TestSelectKWithRecordBatch, BottomKNoNull) {
+ auto schema = ::arrow::schema({
+ {field("a", uint8())},
+ {field("b", uint32())},
+ });
+
+ auto batch_input = R"([
+ {"a": 3, "b": 5},
+ {"a": 30, "b": 3},
+ {"a": 3, "b": 4},
+ {"a": 0, "b": 6},
+ {"a": 20, "b": 5},
+ {"a": 10, "b": 5},
+ {"a": 10, "b": 3}
+ ])";
+
+ auto options = SelectKOptions::BottomKDefault(3, {"a"});
+
+ auto expected_batch = R"([
+ {"a": 0, "b": 6},
+ {"a": 3, "b": 4},
+ {"a": 3, "b": 5}
+ ])";
+
+ Check(schema, batch_input, options, expected_batch);
+}
+
+TEST_F(TestSelectKWithRecordBatch, BottomKNull) {
+ auto schema = ::arrow::schema({
+ {field("a", uint8())},
+ {field("b", uint32())},
+ });
+
+ auto batch_input = R"([
+ {"a": null, "b": 5},
+ {"a": 30, "b": 3},
+ {"a": null, "b": 4},
+ {"a": null, "b": 6},
+ {"a": 20, "b": 5},
+ {"a": null, "b": 5},
+ {"a": 10, "b": 3}
+ ])";
+
+ auto options = SelectKOptions::BottomKDefault(3, {"a"});
+
+ auto expected_batch = R"([
+ {"a": 10, "b": 3},
+ {"a": 20, "b": 5},
+ {"a": 30, "b": 3}
+ ])";
+
+ Check(schema, batch_input, options, expected_batch);
+}
+
+TEST_F(TestSelectKWithRecordBatch, BottomKOneColumnKey) {
+ auto schema = ::arrow::schema({
+ {field("country", utf8())},
+ {field("population", uint64())},
+ });
+
+ auto batch_input =
+ R"([{"country": "Italy", "population": 59000000},
+ {"country": "France", "population": 65000000},
+ {"country": "Malta", "population": 434000},
+ {"country": "Maldives", "population": 434000},
+ {"country": "Brunei", "population": 434000},
+ {"country": "Iceland", "population": 337000},
+ {"country": "Nauru", "population": 11300},
+ {"country": "Tuvalu", "population": 11300},
+ {"country": "Anguilla", "population": 11300},
+ {"country": "Montserrat", "population": 5200}
+ ])";
+
+ auto options = SelectKOptions::BottomKDefault(3, {"population"});
+
+ auto expected_batch =
+ R"([{"country": "Montserrat", "population": 5200},
+ {"country": "Anguilla", "population": 11300},
+ {"country": "Tuvalu", "population": 11300}
+ ])";
+ this->Check(schema, batch_input, options, expected_batch);
+}
+
+TEST_F(TestSelectKWithRecordBatch, BottomKMultipleColumnKeys) {
+ auto schema = ::arrow::schema({{field("country", utf8())},
+ {field("population", uint64())},
+ {field("GDP", uint64())}});
+
+ auto batch_input =
+ R"([{"country": "Italy", "population": 59000000, "GDP": 1937894},
+ {"country": "France", "population": 65000000, "GDP": 2583560},
+ {"country": "Malta", "population": 434000, "GDP": 12011},
+ {"country": "Maldives", "population": 434000, "GDP": 4520},
+ {"country": "Brunei", "population": 434000, "GDP": 12128},
+ {"country": "Iceland", "population": 337000, "GDP": 17036},
+ {"country": "Nauru", "population": 337000, "GDP": 182},
+ {"country": "Tuvalu", "population": 11300, "GDP": 38},
+ {"country": "Anguilla", "population": 11300, "GDP": 311}
+ ])";
+
+ auto options = SelectKOptions::BottomKDefault(3, {"population", "GDP"});
+
+ auto expected_batch =
+ R"([{"country": "Tuvalu", "population": 11300, "GDP": 38},
+ {"country": "Anguilla", "population": 11300, "GDP": 311},
+ {"country": "Nauru", "population": 337000, "GDP": 182}
+ ])";
+ this->Check(schema, batch_input, options, expected_batch);
+}
+
+// Test basic cases for table.
+struct TestSelectKWithTable : public ::testing::Test {
+ void Check(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& input_json, const SelectKOptions& options,
+ const std::vector<std::string>& expected) {
+ std::shared_ptr<Table> actual;
+ ASSERT_OK(this->DoSelectK(schm, input_json, options, &actual));
+ ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected), *actual);
+ }
+
+ Status DoSelectK(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& input_json,
+ const SelectKOptions& options, std::shared_ptr<Table>* out) {
+ auto table = TableFromJSON(schm, input_json);
+ ARROW_ASSIGN_OR_RAISE(auto indices, SelectKUnstable(Datum(*table), options));
+ ValidateOutput(*indices);
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto select_k, Take(Datum(table), Datum(indices), TakeOptions::NoBoundsCheck()));
+ *out = select_k.table();
+ return Status::OK();
+ }
+};
+
+TEST_F(TestSelectKWithTable, TopKOneColumnKey) {
+ auto schema = ::arrow::schema({
+ {field("a", uint8())},
+ {field("b", uint32())},
+ });
+
+ std::vector<std::string> input = {R"([{"a": null, "b": 5},
+ {"a": 1, "b": 3},
+ {"a": 3, "b": null},
+ {"a": null, "b": null},
+ {"a": 2, "b": 5},
+ {"a": 1, "b": 5}
+ ])"};
+
+ auto options = SelectKOptions::TopKDefault(3, {"a"});
+
+ std::vector<std::string> expected = {R"([{"a": 3, "b": null},
+ {"a": 2, "b": 5},
+ {"a": 1, "b": 3}
+ ])"};
+ Check(schema, input, options, expected);
+}
+
+TEST_F(TestSelectKWithTable, TopKMultipleColumnKeys) {
+ auto schema = ::arrow::schema({
+ {field("a", uint8())},
+ {field("b", uint32())},
+ });
+ std::vector<std::string> input = {R"([{"a": null, "b": 5},
+ {"a": 1, "b": 3},
+ {"a": 3, "b": null}
+ ])",
+ R"([{"a": null, "b": null},
+ {"a": 2, "b": 5},
+ {"a": 1, "b": 5}
+ ])"};
+
+ auto options = SelectKOptions::TopKDefault(3, {"a", "b"});
+
+ std::vector<std::string> expected = {R"([{"a": 3, "b": null},
+ {"a": 2, "b": 5},
+ {"a": 1, "b": 5}
+ ])"};
+ Check(schema, input, options, expected);
+}
+
+TEST_F(TestSelectKWithTable, BottomKOneColumnKey) {
+ auto schema = ::arrow::schema({
+ {field("a", uint8())},
+ {field("b", uint32())},
+ });
+
+ std::vector<std::string> input = {R"([{"a": null, "b": 5},
+ {"a": 0, "b": 3},
+ {"a": 3, "b": null},
+ {"a": null, "b": null},
+ {"a": 2, "b": 5},
+ {"a": 1, "b": 5}
+ ])"};
+
+ auto options = SelectKOptions::BottomKDefault(3, {"a"});
+
+ std::vector<std::string> expected = {R"([{"a": 0, "b": 3},
+ {"a": 1, "b": 5},
+ {"a": 2, "b": 5}
+ ])"};
+ Check(schema, input, options, expected);
+}
+
+TEST_F(TestSelectKWithTable, BottomKMultipleColumnKeys) {
+ auto schema = ::arrow::schema({
+ {field("a", uint8())},
+ {field("b", uint32())},
+ });
+ std::vector<std::string> input = {R"([{"a": null, "b": 5},
+ {"a": 1, "b": 3},
+ {"a": 3, "b": null}
+ ])",
+ R"([{"a": null, "b": null},
+ {"a": 2, "b": 5},
+ {"a": 1, "b": 5}
+ ])"};
+
+ auto options = SelectKOptions::BottomKDefault(3, {"a", "b"});
+
+ std::vector<std::string> expected = {R"([{"a": 1, "b": 3},
+ {"a": 1, "b": 5},
+ {"a": 2, "b": 5}
+ ])"};
+ Check(schema, input, options, expected);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/temporal_internal.h b/src/arrow/cpp/src/arrow/compute/kernels/temporal_internal.h
new file mode 100644
index 000000000..45fa67a9b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/temporal_internal.h
@@ -0,0 +1,269 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <chrono>
+#include <cstdint>
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/vendored/datetime.h"
+
+namespace arrow {
+
+namespace compute {
+namespace internal {
+
+using arrow_vendored::date::days;
+using arrow_vendored::date::floor;
+using arrow_vendored::date::local_days;
+using arrow_vendored::date::local_time;
+using arrow_vendored::date::locate_zone;
+using arrow_vendored::date::sys_days;
+using arrow_vendored::date::sys_time;
+using arrow_vendored::date::time_zone;
+using arrow_vendored::date::year_month_day;
+
+inline int64_t GetQuarter(const year_month_day& ymd) {
+ return static_cast<int64_t>((static_cast<uint32_t>(ymd.month()) - 1) / 3);
+}
+
+static inline Result<const time_zone*> LocateZone(const std::string& timezone) {
+ try {
+ return locate_zone(timezone);
+ } catch (const std::runtime_error& ex) {
+ return Status::Invalid("Cannot locate timezone '", timezone, "': ", ex.what());
+ }
+}
+
+static inline const std::string& GetInputTimezone(const DataType& type) {
+ static const std::string no_timezone = "";
+ switch (type.id()) {
+ case Type::TIMESTAMP:
+ return checked_cast<const TimestampType&>(type).timezone();
+ default:
+ return no_timezone;
+ }
+}
+
+static inline const std::string& GetInputTimezone(const Datum& datum) {
+ return checked_cast<const TimestampType&>(*datum.type()).timezone();
+}
+
+static inline const std::string& GetInputTimezone(const Scalar& scalar) {
+ return checked_cast<const TimestampType&>(*scalar.type).timezone();
+}
+
+static inline const std::string& GetInputTimezone(const ArrayData& array) {
+ return checked_cast<const TimestampType&>(*array.type).timezone();
+}
+
+inline Status ValidateDayOfWeekOptions(const DayOfWeekOptions& options) {
+ if (options.week_start < 1 || 7 < options.week_start) {
+ return Status::Invalid(
+ "week_start must follow ISO convention (Monday=1, Sunday=7). Got week_start=",
+ options.week_start);
+ }
+ return Status::OK();
+}
+
+struct NonZonedLocalizer {
+ using days_t = sys_days;
+
+ // No-op conversions: UTC -> UTC
+ template <typename Duration>
+ sys_time<Duration> ConvertTimePoint(int64_t t) const {
+ return sys_time<Duration>(Duration{t});
+ }
+
+ sys_days ConvertDays(sys_days d) const { return d; }
+};
+
+struct ZonedLocalizer {
+ using days_t = local_days;
+
+ // Timezone-localizing conversions: UTC -> local time
+ const time_zone* tz;
+
+ template <typename Duration>
+ local_time<Duration> ConvertTimePoint(int64_t t) const {
+ return tz->to_local(sys_time<Duration>(Duration{t}));
+ }
+
+ local_days ConvertDays(sys_days d) const { return local_days(year_month_day(d)); }
+};
+
+//
+// Which types to generate a kernel for
+//
+struct WithDates {};
+struct WithTimes {};
+struct WithTimestamps {};
+
+// This helper allows generating temporal kernels for selected type categories
+// without any spurious code generation for other categories (e.g. avoid
+// generating code for date kernels for a times-only function).
+template <typename Factory>
+void AddTemporalKernels(Factory* fac) {}
+
+template <typename Factory, typename... WithOthers>
+void AddTemporalKernels(Factory* fac, WithDates, WithOthers... others) {
+ fac->template AddKernel<days, Date32Type>(date32());
+ fac->template AddKernel<std::chrono::milliseconds, Date64Type>(date64());
+ AddTemporalKernels(fac, std::forward<WithOthers>(others)...);
+}
+
+template <typename Factory, typename... WithOthers>
+void AddTemporalKernels(Factory* fac, WithTimes, WithOthers... others) {
+ fac->template AddKernel<std::chrono::seconds, Time32Type>(time32(TimeUnit::SECOND));
+ fac->template AddKernel<std::chrono::milliseconds, Time32Type>(time32(TimeUnit::MILLI));
+ fac->template AddKernel<std::chrono::microseconds, Time64Type>(time64(TimeUnit::MICRO));
+ fac->template AddKernel<std::chrono::nanoseconds, Time64Type>(time64(TimeUnit::NANO));
+ AddTemporalKernels(fac, std::forward<WithOthers>(others)...);
+}
+
+template <typename Factory, typename... WithOthers>
+void AddTemporalKernels(Factory* fac, WithTimestamps, WithOthers... others) {
+ fac->template AddKernel<std::chrono::seconds, TimestampType>(
+ match::TimestampTypeUnit(TimeUnit::SECOND));
+ fac->template AddKernel<std::chrono::milliseconds, TimestampType>(
+ match::TimestampTypeUnit(TimeUnit::MILLI));
+ fac->template AddKernel<std::chrono::microseconds, TimestampType>(
+ match::TimestampTypeUnit(TimeUnit::MICRO));
+ fac->template AddKernel<std::chrono::nanoseconds, TimestampType>(
+ match::TimestampTypeUnit(TimeUnit::NANO));
+ AddTemporalKernels(fac, std::forward<WithOthers>(others)...);
+}
+
+//
+// Executor class for temporal component extractors, i.e. scalar kernels
+// with the signature Timestamp -> <non-temporal scalar type `OutType`>
+//
+// The `Op` parameter is templated on the Duration (which depends on the timestamp
+// unit) and a Localizer class (depending on whether the timestamp has a
+// timezone defined).
+//
+template <template <typename...> class Op, typename Duration, typename InType,
+ typename OutType, typename... Args>
+struct TemporalComponentExtractBase {
+ template <typename OptionsType>
+ static Status ExecWithOptions(KernelContext* ctx, const OptionsType* options,
+ const ExecBatch& batch, Datum* out, Args... args) {
+ const auto& timezone = GetInputTimezone(batch.values[0]);
+ if (timezone.empty()) {
+ using ExecTemplate = Op<Duration, NonZonedLocalizer>;
+ auto op = ExecTemplate(options, NonZonedLocalizer(), args...);
+ applicator::ScalarUnaryNotNullStateful<OutType, InType, ExecTemplate> kernel{op};
+ return kernel.Exec(ctx, batch, out);
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto tz, LocateZone(timezone));
+ using ExecTemplate = Op<Duration, ZonedLocalizer>;
+ auto op = ExecTemplate(options, ZonedLocalizer{tz}, args...);
+ applicator::ScalarUnaryNotNullStateful<OutType, InType, ExecTemplate> kernel{op};
+ return kernel.Exec(ctx, batch, out);
+ }
+ }
+};
+
+template <template <typename...> class Op, typename OutType>
+struct TemporalComponentExtractBase<Op, days, Date32Type, OutType> {
+ template <typename OptionsType>
+ static Status ExecWithOptions(KernelContext* ctx, const OptionsType* options,
+ const ExecBatch& batch, Datum* out) {
+ using ExecTemplate = Op<days, NonZonedLocalizer>;
+ auto op = ExecTemplate(options, NonZonedLocalizer());
+ applicator::ScalarUnaryNotNullStateful<OutType, Date32Type, ExecTemplate> kernel{op};
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+template <template <typename...> class Op, typename OutType>
+struct TemporalComponentExtractBase<Op, std::chrono::milliseconds, Date64Type, OutType> {
+ template <typename OptionsType>
+ static Status ExecWithOptions(KernelContext* ctx, const OptionsType* options,
+ const ExecBatch& batch, Datum* out) {
+ using ExecTemplate = Op<std::chrono::milliseconds, NonZonedLocalizer>;
+ auto op = ExecTemplate(options, NonZonedLocalizer());
+ applicator::ScalarUnaryNotNullStateful<OutType, Date64Type, ExecTemplate> kernel{op};
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+template <template <typename...> class Op, typename OutType>
+struct TemporalComponentExtractBase<Op, std::chrono::seconds, Time32Type, OutType> {
+ template <typename OptionsType>
+ static Status ExecWithOptions(KernelContext* ctx, const OptionsType* options,
+ const ExecBatch& batch, Datum* out) {
+ using ExecTemplate = Op<std::chrono::seconds, NonZonedLocalizer>;
+ auto op = ExecTemplate(options, NonZonedLocalizer());
+ applicator::ScalarUnaryNotNullStateful<OutType, Time32Type, ExecTemplate> kernel{op};
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+template <template <typename...> class Op, typename OutType>
+struct TemporalComponentExtractBase<Op, std::chrono::milliseconds, Time32Type, OutType> {
+ template <typename OptionsType>
+ static Status ExecWithOptions(KernelContext* ctx, const OptionsType* options,
+ const ExecBatch& batch, Datum* out) {
+ using ExecTemplate = Op<std::chrono::milliseconds, NonZonedLocalizer>;
+ auto op = ExecTemplate(options, NonZonedLocalizer());
+ applicator::ScalarUnaryNotNullStateful<OutType, Time32Type, ExecTemplate> kernel{op};
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+template <template <typename...> class Op, typename OutType>
+struct TemporalComponentExtractBase<Op, std::chrono::microseconds, Time64Type, OutType> {
+ template <typename OptionsType>
+ static Status ExecWithOptions(KernelContext* ctx, const OptionsType* options,
+ const ExecBatch& batch, Datum* out) {
+ using ExecTemplate = Op<std::chrono::microseconds, NonZonedLocalizer>;
+ auto op = ExecTemplate(options, NonZonedLocalizer());
+ applicator::ScalarUnaryNotNullStateful<OutType, Date64Type, ExecTemplate> kernel{op};
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+template <template <typename...> class Op, typename OutType>
+struct TemporalComponentExtractBase<Op, std::chrono::nanoseconds, Time64Type, OutType> {
+ template <typename OptionsType>
+ static Status ExecWithOptions(KernelContext* ctx, const OptionsType* options,
+ const ExecBatch& batch, Datum* out) {
+ using ExecTemplate = Op<std::chrono::nanoseconds, NonZonedLocalizer>;
+ auto op = ExecTemplate(options, NonZonedLocalizer());
+ applicator::ScalarUnaryNotNullStateful<OutType, Date64Type, ExecTemplate> kernel{op};
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+
+template <template <typename...> class Op, typename Duration, typename InType,
+ typename OutType, typename... Args>
+struct TemporalComponentExtract
+ : public TemporalComponentExtractBase<Op, Duration, InType, OutType, Args...> {
+ using Base = TemporalComponentExtractBase<Op, Duration, InType, OutType, Args...>;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out,
+ Args... args) {
+ const FunctionOptions* options = nullptr;
+ return Base::ExecWithOptions(ctx, options, batch, out, args...);
+ }
+};
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/test_util.cc b/src/arrow/cpp/src/arrow/compute/kernels/test_util.cc
new file mode 100644
index 000000000..e72c3dce2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/test_util.cc
@@ -0,0 +1,362 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/kernels/test_util.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "arrow/array.h"
+#include "arrow/array/validate.h"
+#include "arrow/chunked_array.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/function.h"
+#include "arrow/compute/registry.h"
+#include "arrow/datum.h"
+#include "arrow/result.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+namespace compute {
+
+namespace {
+
+template <typename T>
+DatumVector GetDatums(const std::vector<T>& inputs) {
+ std::vector<Datum> datums;
+ for (const auto& input : inputs) {
+ datums.emplace_back(input);
+ }
+ return datums;
+}
+
+template <typename... SliceArgs>
+DatumVector SliceArrays(const DatumVector& inputs, SliceArgs... slice_args) {
+ DatumVector sliced;
+ for (const auto& input : inputs) {
+ if (input.is_array()) {
+ sliced.push_back(*input.make_array()->Slice(slice_args...));
+ } else {
+ sliced.push_back(input);
+ }
+ }
+ return sliced;
+}
+
+ScalarVector GetScalars(const DatumVector& inputs, int64_t index) {
+ ScalarVector scalars;
+ for (const auto& input : inputs) {
+ if (input.is_array()) {
+ scalars.push_back(*input.make_array()->GetScalar(index));
+ } else {
+ scalars.push_back(input.scalar());
+ }
+ }
+ return scalars;
+}
+
+} // namespace
+
+void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs,
+ const Datum& expected, const FunctionOptions* options) {
+ ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options));
+ ValidateOutput(out);
+ AssertDatumsEqual(expected, out, /*verbose=*/true);
+}
+
+void CheckScalar(std::string func_name, const ScalarVector& inputs,
+ std::shared_ptr<Scalar> expected, const FunctionOptions* options) {
+ ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, GetDatums(inputs), options));
+ ValidateOutput(out);
+ if (!out.scalar()->Equals(expected)) {
+ std::string summary = func_name + "(";
+ for (const auto& input : inputs) {
+ summary += input->ToString() + ",";
+ }
+ summary.back() = ')';
+
+ summary += " = " + out.scalar()->ToString() + " != " + expected->ToString();
+
+ if (!out.type()->Equals(expected->type)) {
+ summary += " (types differed: " + out.type()->ToString() + " vs " +
+ expected->type->ToString() + ")";
+ }
+
+ FAIL() << summary;
+ }
+}
+
+void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expected_datum,
+ const FunctionOptions* options) {
+ CheckScalarNonRecursive(func_name, inputs, expected_datum, options);
+
+ if (expected_datum.is_scalar()) return;
+ ASSERT_TRUE(expected_datum.is_array())
+ << "CheckScalar is only implemented for scalar/array expected values";
+ auto expected = expected_datum.make_array();
+
+ // check for at least 1 array, and make sure the others are of equal length
+ bool has_array = false;
+ for (const auto& input : inputs) {
+ if (input.is_array()) {
+ ASSERT_EQ(input.array()->length, expected->length());
+ has_array = true;
+ }
+ }
+ ASSERT_TRUE(has_array) << "Must have at least 1 array input to have an array output";
+
+ // Check all the input scalars
+ for (int64_t i = 0; i < expected->length(); ++i) {
+ CheckScalar(func_name, GetScalars(inputs, i), *expected->GetScalar(i), options);
+ }
+
+ // Since it's a scalar function, calling it on sliced inputs should
+ // result in the sliced expected output.
+ const auto slice_length = expected->length() / 3;
+ if (slice_length > 0) {
+ CheckScalarNonRecursive(func_name, SliceArrays(inputs, 0, slice_length),
+ expected->Slice(0, slice_length), options);
+
+ CheckScalarNonRecursive(func_name, SliceArrays(inputs, slice_length, slice_length),
+ expected->Slice(slice_length, slice_length), options);
+
+ CheckScalarNonRecursive(func_name, SliceArrays(inputs, 2 * slice_length),
+ expected->Slice(2 * slice_length), options);
+ }
+
+ // Should also work with an empty slice
+ CheckScalarNonRecursive(func_name, SliceArrays(inputs, 0, 0), expected->Slice(0, 0),
+ options);
+
+ // Ditto with ChunkedArray inputs
+ if (slice_length > 0) {
+ DatumVector chunked_inputs;
+ chunked_inputs.reserve(inputs.size());
+ for (const auto& input : inputs) {
+ if (input.is_array()) {
+ auto ar = input.make_array();
+ auto ar_chunked = std::make_shared<ChunkedArray>(
+ ArrayVector{ar->Slice(0, slice_length), ar->Slice(slice_length)});
+ chunked_inputs.push_back(ar_chunked);
+ } else {
+ chunked_inputs.push_back(input.scalar());
+ }
+ }
+ ArrayVector expected_chunks{expected->Slice(0, slice_length),
+ expected->Slice(slice_length)};
+
+ ASSERT_OK_AND_ASSIGN(Datum out,
+ CallFunction(func_name, GetDatums(chunked_inputs), options));
+ ValidateOutput(out);
+ auto chunked = out.chunked_array();
+ (void)chunked;
+ AssertDatumsEqual(std::make_shared<ChunkedArray>(expected_chunks), out);
+ }
+}
+
+Datum CheckDictionaryNonRecursive(const std::string& func_name, const DatumVector& args,
+ bool result_is_encoded) {
+ EXPECT_OK_AND_ASSIGN(Datum actual, CallFunction(func_name, args));
+ ValidateOutput(actual);
+
+ DatumVector decoded_args;
+ decoded_args.reserve(args.size());
+ for (const auto& arg : args) {
+ if (arg.type()->id() == Type::DICTIONARY) {
+ const auto& to_type = checked_cast<const DictionaryType&>(*arg.type()).value_type();
+ EXPECT_OK_AND_ASSIGN(auto decoded, Cast(arg, to_type));
+ decoded_args.push_back(decoded);
+ } else {
+ decoded_args.push_back(arg);
+ }
+ }
+ EXPECT_OK_AND_ASSIGN(Datum expected, CallFunction(func_name, decoded_args));
+
+ if (result_is_encoded) {
+ EXPECT_EQ(Type::DICTIONARY, actual.type()->id())
+ << "Result should have been dictionary-encoded";
+ // Decode before comparison - we care about equivalent not identical results
+ const auto& to_type =
+ checked_cast<const DictionaryType&>(*actual.type()).value_type();
+ EXPECT_OK_AND_ASSIGN(auto decoded, Cast(actual, to_type));
+ AssertDatumsApproxEqual(expected, decoded, /*verbose=*/true);
+ } else {
+ AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
+ }
+ return actual;
+}
+
+void CheckDictionary(const std::string& func_name, const DatumVector& args,
+ bool result_is_encoded) {
+ auto actual = CheckDictionaryNonRecursive(func_name, args, result_is_encoded);
+
+ if (actual.is_scalar()) return;
+ ASSERT_TRUE(actual.is_array());
+ ASSERT_GE(actual.length(), 0);
+
+ // Check all scalars
+ for (int64_t i = 0; i < actual.length(); i++) {
+ CheckDictionaryNonRecursive(func_name, GetDatums(GetScalars(args, i)),
+ result_is_encoded);
+ }
+
+ // Check slices of the input
+ const auto slice_length = actual.length() / 3;
+ if (slice_length > 0) {
+ CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, slice_length),
+ result_is_encoded);
+ CheckDictionaryNonRecursive(func_name, SliceArrays(args, slice_length, slice_length),
+ result_is_encoded);
+ CheckDictionaryNonRecursive(func_name, SliceArrays(args, 2 * slice_length),
+ result_is_encoded);
+ }
+
+ // Check empty slice
+ CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, 0), result_is_encoded);
+
+ // Check chunked arrays
+ if (slice_length > 0) {
+ DatumVector chunked_args;
+ chunked_args.reserve(args.size());
+ for (const auto& arg : args) {
+ if (arg.is_array()) {
+ auto arr = arg.make_array();
+ ArrayVector chunks{arr->Slice(0, slice_length), arr->Slice(slice_length)};
+ chunked_args.push_back(std::make_shared<ChunkedArray>(std::move(chunks)));
+ } else {
+ chunked_args.push_back(arg);
+ }
+ }
+ CheckDictionaryNonRecursive(func_name, chunked_args, result_is_encoded);
+ }
+}
+
+void CheckScalarUnary(std::string func_name, Datum input, Datum expected,
+ const FunctionOptions* options) {
+ std::vector<Datum> input_vector = {std::move(input)};
+ CheckScalar(std::move(func_name), input_vector, expected, options);
+}
+
+void CheckScalarUnary(std::string func_name, std::shared_ptr<DataType> in_ty,
+ std::string json_input, std::shared_ptr<DataType> out_ty,
+ std::string json_expected, const FunctionOptions* options) {
+ CheckScalarUnary(std::move(func_name), ArrayFromJSON(in_ty, json_input),
+ ArrayFromJSON(out_ty, json_expected), options);
+}
+
+void CheckVectorUnary(std::string func_name, Datum input, Datum expected,
+ const FunctionOptions* options) {
+ ASSERT_OK_AND_ASSIGN(Datum actual, CallFunction(func_name, {input}, options));
+ ValidateOutput(actual);
+ AssertDatumsEqual(expected, actual, /*verbose=*/true);
+}
+
+void CheckScalarBinary(std::string func_name, Datum left_input, Datum right_input,
+ Datum expected, const FunctionOptions* options) {
+ CheckScalar(std::move(func_name), {left_input, right_input}, expected, options);
+}
+
+namespace {
+
+void ValidateOutput(const ArrayData& output) {
+ ASSERT_OK(::arrow::internal::ValidateArrayFull(output));
+ TestInitialized(output);
+}
+
+void ValidateOutput(const ChunkedArray& output) {
+ ASSERT_OK(output.ValidateFull());
+ for (const auto& chunk : output.chunks()) {
+ TestInitialized(*chunk);
+ }
+}
+
+void ValidateOutput(const RecordBatch& output) {
+ ASSERT_OK(output.ValidateFull());
+ for (const auto& column : output.column_data()) {
+ TestInitialized(*column);
+ }
+}
+
+void ValidateOutput(const Table& output) {
+ ASSERT_OK(output.ValidateFull());
+ for (const auto& column : output.columns()) {
+ for (const auto& chunk : column->chunks()) {
+ TestInitialized(*chunk);
+ }
+ }
+}
+
+void ValidateOutput(const Scalar& output) { ASSERT_OK(output.ValidateFull()); }
+
+} // namespace
+
+void ValidateOutput(const Datum& output) {
+ switch (output.kind()) {
+ case Datum::ARRAY:
+ ValidateOutput(*output.array());
+ break;
+ case Datum::CHUNKED_ARRAY:
+ ValidateOutput(*output.chunked_array());
+ break;
+ case Datum::RECORD_BATCH:
+ ValidateOutput(*output.record_batch());
+ break;
+ case Datum::TABLE:
+ ValidateOutput(*output.table());
+ break;
+ case Datum::SCALAR:
+ ValidateOutput(*output.scalar());
+ break;
+ default:
+ break;
+ }
+}
+
+void CheckDispatchBest(std::string func_name, std::vector<ValueDescr> original_values,
+ std::vector<ValueDescr> expected_equivalent_values) {
+ ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name));
+
+ auto values = original_values;
+ ASSERT_OK_AND_ASSIGN(auto actual_kernel, function->DispatchBest(&values));
+
+ ASSERT_OK_AND_ASSIGN(auto expected_kernel,
+ function->DispatchExact(expected_equivalent_values));
+
+ EXPECT_EQ(actual_kernel, expected_kernel)
+ << " DispatchBest" << ValueDescr::ToString(original_values) << " => "
+ << actual_kernel->signature->ToString() << "\n"
+ << " DispatchExact" << ValueDescr::ToString(expected_equivalent_values) << " => "
+ << expected_kernel->signature->ToString();
+ EXPECT_EQ(values.size(), expected_equivalent_values.size());
+ for (size_t i = 0; i < values.size(); i++) {
+ EXPECT_EQ(values[i].shape, expected_equivalent_values[i].shape)
+ << "Argument " << i << " should have the same shape";
+ AssertTypeEqual(values[i].type, expected_equivalent_values[i].type);
+ }
+}
+
+void CheckDispatchFails(std::string func_name, std::vector<ValueDescr> values) {
+ ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name));
+ ASSERT_NOT_OK(function->DispatchBest(&values));
+ ASSERT_NOT_OK(function->DispatchExact(values));
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/test_util.h b/src/arrow/cpp/src/arrow/compute/kernels/test_util.h
new file mode 100644
index 000000000..25ea577a4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/test_util.h
@@ -0,0 +1,241 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+// IWYU pragma: begin_exports
+
+#include <gmock/gmock.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernel.h"
+#include "arrow/datum.h"
+#include "arrow/memory_pool.h"
+#include "arrow/pretty_print.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+
+// IWYU pragma: end_exports
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+using DatumVector = std::vector<Datum>;
+
+template <typename Type, typename T>
+std::shared_ptr<Array> _MakeArray(const std::shared_ptr<DataType>& type,
+ const std::vector<T>& values,
+ const std::vector<bool>& is_valid) {
+ std::shared_ptr<Array> result;
+ if (is_valid.size() > 0) {
+ ArrayFromVector<Type, T>(type, is_valid, values, &result);
+ } else {
+ ArrayFromVector<Type, T>(type, values, &result);
+ }
+ return result;
+}
+
+inline std::string CompareOperatorToFunctionName(CompareOperator op) {
+ static std::string function_names[] = {
+ "equal", "not_equal", "greater", "greater_equal", "less", "less_equal",
+ };
+ return function_names[op];
+}
+
+// Call the function with the given arguments, as well as slices of
+// the arguments and scalars extracted from the arguments.
+void CheckScalar(std::string func_name, const ScalarVector& inputs,
+ std::shared_ptr<Scalar> expected,
+ const FunctionOptions* options = nullptr);
+
+void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expected,
+ const FunctionOptions* options = nullptr);
+
+// Like CheckScalar, but gets the expected result by
+// dictionary-decoding arguments and calling the function again.
+//
+// result_is_encoded controls whether the result is expected to be a
+// dictionary or not.
+void CheckDictionary(const std::string& func_name, const DatumVector& args,
+ bool result_is_encoded = true);
+
+// Just call the function with the given arguments.
+void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs,
+ const Datum& expected,
+ const FunctionOptions* options = nullptr);
+
+void CheckScalarUnary(std::string func_name, std::shared_ptr<DataType> in_ty,
+ std::string json_input, std::shared_ptr<DataType> out_ty,
+ std::string json_expected,
+ const FunctionOptions* options = nullptr);
+
+void CheckScalarUnary(std::string func_name, Datum input, Datum expected,
+ const FunctionOptions* options = nullptr);
+
+void CheckScalarBinary(std::string func_name, Datum left_input, Datum right_input,
+ Datum expected, const FunctionOptions* options = nullptr);
+
+void CheckVectorUnary(std::string func_name, Datum input, Datum expected,
+ const FunctionOptions* options = nullptr);
+
+void ValidateOutput(const Datum& output);
+
+static constexpr random::SeedType kRandomSeed = 0x0ff1ce;
+
+template <template <typename> class DoTestFunctor>
+void TestRandomPrimitiveCTypes() {
+ DoTestFunctor<Int8Type>::Test(int8());
+ DoTestFunctor<Int16Type>::Test(int16());
+ DoTestFunctor<Int32Type>::Test(int32());
+ DoTestFunctor<Int64Type>::Test(int64());
+ DoTestFunctor<UInt8Type>::Test(uint8());
+ DoTestFunctor<UInt16Type>::Test(uint16());
+ DoTestFunctor<UInt32Type>::Test(uint32());
+ DoTestFunctor<UInt64Type>::Test(uint64());
+ DoTestFunctor<FloatType>::Test(float32());
+ DoTestFunctor<DoubleType>::Test(float64());
+ DoTestFunctor<Date32Type>::Test(date32());
+ DoTestFunctor<Date64Type>::Test(date64());
+ DoTestFunctor<Time32Type>::Test(time32(TimeUnit::SECOND));
+ DoTestFunctor<Time64Type>::Test(time64(TimeUnit::MICRO));
+ DoTestFunctor<TimestampType>::Test(timestamp(TimeUnit::SECOND));
+ DoTestFunctor<TimestampType>::Test(timestamp(TimeUnit::MICRO));
+ DoTestFunctor<DurationType>::Test(duration(TimeUnit::MILLI));
+}
+
+// Check that DispatchBest on a given function yields the same Kernel as
+// produced by DispatchExact on another set of ValueDescrs.
+void CheckDispatchBest(std::string func_name, std::vector<ValueDescr> descrs,
+ std::vector<ValueDescr> exact_descrs);
+
+// Check that function fails to produce a Kernel for the set of ValueDescrs.
+void CheckDispatchFails(std::string func_name, std::vector<ValueDescr> descrs);
+
+// Helper to get a default instance of a type, including parameterized types
+template <typename T>
+enable_if_parameter_free<T, std::shared_ptr<DataType>> default_type_instance() {
+ return TypeTraits<T>::type_singleton();
+}
+template <typename T>
+enable_if_time<T, std::shared_ptr<DataType>> default_type_instance() {
+ // Time32 requires second/milli, Time64 requires nano/micro
+ if (bit_width(T::type_id) == 32) {
+ return std::make_shared<T>(TimeUnit::type::SECOND);
+ }
+ return std::make_shared<T>(TimeUnit::type::NANO);
+}
+template <typename T>
+enable_if_timestamp<T, std::shared_ptr<DataType>> default_type_instance() {
+ return std::make_shared<T>(TimeUnit::type::SECOND);
+}
+template <typename T>
+enable_if_decimal<T, std::shared_ptr<DataType>> default_type_instance() {
+ return std::make_shared<T>(5, 2);
+}
+
+// Random Generator Helpers
+class RandomImpl {
+ protected:
+ random::RandomArrayGenerator generator_;
+ std::shared_ptr<DataType> type_;
+
+ explicit RandomImpl(random::SeedType seed, std::shared_ptr<DataType> type)
+ : generator_(seed), type_(std::move(type)) {}
+
+ public:
+ std::shared_ptr<Array> Generate(uint64_t count, double null_prob) {
+ return generator_.ArrayOf(type_, count, null_prob);
+ }
+
+ std::shared_ptr<Int32Array> Offsets(int32_t length, int32_t slice_count) {
+ return arrow::internal::checked_pointer_cast<Int32Array>(
+ generator_.Offsets(slice_count, 0, length));
+ }
+};
+
+template <typename ArrowType>
+class Random : public RandomImpl {
+ public:
+ explicit Random(random::SeedType seed)
+ : RandomImpl(seed, TypeTraits<ArrowType>::type_singleton()) {}
+};
+
+template <>
+class Random<FloatType> : public RandomImpl {
+ using CType = float;
+
+ public:
+ explicit Random(random::SeedType seed) : RandomImpl(seed, float32()) {}
+
+ std::shared_ptr<Array> Generate(uint64_t count, double null_prob, double nan_prob = 0) {
+ return generator_.Float32(count, std::numeric_limits<CType>::min(),
+ std::numeric_limits<CType>::max(), null_prob, nan_prob);
+ }
+};
+
+template <>
+class Random<DoubleType> : public RandomImpl {
+ using CType = double;
+
+ public:
+ explicit Random(random::SeedType seed) : RandomImpl(seed, float64()) {}
+
+ std::shared_ptr<Array> Generate(uint64_t count, double null_prob, double nan_prob = 0) {
+ return generator_.Float64(count, std::numeric_limits<CType>::min(),
+ std::numeric_limits<CType>::max(), null_prob, nan_prob);
+ }
+};
+
+template <>
+class Random<Decimal128Type> : public RandomImpl {
+ public:
+ explicit Random(random::SeedType seed,
+ std::shared_ptr<DataType> type = decimal128(18, 5))
+ : RandomImpl(seed, std::move(type)) {}
+};
+
+template <typename ArrowType>
+class RandomRange : public RandomImpl {
+ using CType = typename TypeTraits<ArrowType>::CType;
+
+ public:
+ explicit RandomRange(random::SeedType seed)
+ : RandomImpl(seed, TypeTraits<ArrowType>::type_singleton()) {}
+
+ std::shared_ptr<Array> Generate(uint64_t count, int range, double null_prob) {
+ CType min = std::numeric_limits<CType>::min();
+ CType max = min + range;
+ if (sizeof(CType) < 4 && (range + min) > std::numeric_limits<CType>::max()) {
+ max = std::numeric_limits<CType>::max();
+ }
+ return generator_.Numeric<ArrowType>(count, min, max, null_prob);
+ }
+};
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/util_internal.cc b/src/arrow/cpp/src/arrow/compute/kernels/util_internal.cc
new file mode 100644
index 000000000..846fa26ba
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/util_internal.cc
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/kernels/util_internal.h"
+
+#include <cstdint>
+
+#include "arrow/array/data.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+namespace internal {
+
+const uint8_t* GetValidityBitmap(const ArrayData& data) {
+ const uint8_t* bitmap = nullptr;
+ if (data.buffers[0]) {
+ bitmap = data.buffers[0]->data();
+ }
+ return bitmap;
+}
+
+int GetBitWidth(const DataType& type) {
+ return checked_cast<const FixedWidthType&>(type).bit_width();
+}
+
+PrimitiveArg GetPrimitiveArg(const ArrayData& arr) {
+ PrimitiveArg arg;
+ arg.is_valid = GetValidityBitmap(arr);
+ arg.data = arr.buffers[1]->data();
+ arg.bit_width = GetBitWidth(*arr.type);
+ arg.offset = arr.offset;
+ arg.length = arr.length;
+ if (arg.bit_width > 1) {
+ arg.data += arr.offset * arg.bit_width / 8;
+ }
+ // This may be kUnknownNullCount
+ arg.null_count = (arg.is_valid != nullptr) ? arr.null_count.load() : 0;
+ return arg;
+}
+
+ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec,
+ NullHandling::type null_handling) {
+ return [=](KernelContext* ctx, const ExecBatch& batch, Datum* out) -> Status {
+ if (out->is_array()) {
+ return exec(ctx, batch, out);
+ }
+
+ if (null_handling == NullHandling::INTERSECTION && !batch[0].scalar()->is_valid) {
+ out->scalar()->is_valid = false;
+ return Status::OK();
+ }
+
+ ARROW_ASSIGN_OR_RAISE(Datum array_in, MakeArrayFromScalar(*batch[0].scalar(), 1));
+ ARROW_ASSIGN_OR_RAISE(Datum array_out, MakeArrayFromScalar(*out->scalar(), 1));
+ RETURN_NOT_OK(exec(ctx, ExecBatch{{std::move(array_in)}, 1}, &array_out));
+ ARROW_ASSIGN_OR_RAISE(*out, array_out.make_array()->GetScalar(0));
+ return Status::OK();
+ };
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/util_internal.h b/src/arrow/cpp/src/arrow/compute/kernels/util_internal.h
new file mode 100644
index 000000000..eaaf96ef4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/util_internal.h
@@ -0,0 +1,166 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <type_traits>
+#include <utility>
+
+#include "arrow/array/util.h"
+#include "arrow/buffer.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/compute/type_fwd.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/math_constants.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+template <typename T>
+using maybe_make_unsigned =
+ typename std::conditional<std::is_integral<T>::value && !std::is_same<T, bool>::value,
+ std::make_unsigned<T>, std::common_type<T> >::type;
+
+template <typename T, typename Unsigned = typename maybe_make_unsigned<T>::type>
+constexpr Unsigned to_unsigned(T signed_) {
+ return static_cast<Unsigned>(signed_);
+}
+
+// An internal data structure for unpacking a primitive argument to pass to a
+// kernel implementation
+struct PrimitiveArg {
+ const uint8_t* is_valid;
+ // If the bit_width is a multiple of 8 (i.e. not boolean), then "data" should
+ // be shifted by offset * (bit_width / 8). For bit-packed data, the offset
+ // must be used when indexing.
+ const uint8_t* data;
+ int bit_width;
+ int64_t length;
+ int64_t offset;
+ // This may be kUnknownNullCount if the null_count has not yet been computed,
+ // so use null_count != 0 to determine "may have nulls".
+ int64_t null_count;
+};
+
+// Get validity bitmap data or return nullptr if there is no validity buffer
+const uint8_t* GetValidityBitmap(const ArrayData& data);
+
+int GetBitWidth(const DataType& type);
+
+// Reduce code size by dealing with the unboxing of the kernel inputs once
+// rather than duplicating compiled code to do all these in each kernel.
+PrimitiveArg GetPrimitiveArg(const ArrayData& arr);
+
+// Augment a unary ArrayKernelExec which supports only array-like inputs with support for
+// scalar inputs. Scalars will be transformed to 1-long arrays with the scalar's value (or
+// null if the scalar is null) as its only element. This 1-long array will be passed to
+// the original exec, then the only element of the resulting array will be extracted as
+// the output scalar. This could be far more efficient, but instead of optimizing this
+// it'd be better to support scalar inputs "upstream" in original exec.
+ArrayKernelExec TrivialScalarUnaryAsArraysExec(
+ ArrayKernelExec exec, NullHandling::type null_handling = NullHandling::INTERSECTION);
+
+// Return (min, max) of a numerical array, ignore nulls.
+// For empty array, return the maximal number limit as 'min', and minimal limit as 'max'.
+template <typename T>
+ARROW_NOINLINE std::pair<T, T> GetMinMax(const ArrayData& data) {
+ T min = std::numeric_limits<T>::max();
+ T max = std::numeric_limits<T>::lowest();
+
+ const T* values = data.GetValues<T>(1);
+ arrow::internal::VisitSetBitRunsVoid(data.buffers[0], data.offset, data.length,
+ [&](int64_t pos, int64_t len) {
+ for (int64_t i = 0; i < len; ++i) {
+ min = std::min(min, values[pos + i]);
+ max = std::max(max, values[pos + i]);
+ }
+ });
+
+ return std::make_pair(min, max);
+}
+
+template <typename T>
+std::pair<T, T> GetMinMax(const Datum& datum) {
+ T min = std::numeric_limits<T>::max();
+ T max = std::numeric_limits<T>::lowest();
+
+ for (const auto& array : datum.chunks()) {
+ T local_min, local_max;
+ std::tie(local_min, local_max) = GetMinMax<T>(*array->data());
+ min = std::min(min, local_min);
+ max = std::max(max, local_max);
+ }
+
+ return std::make_pair(min, max);
+}
+
+// Count value occurrences of an array, ignore nulls.
+// 'counts' must be zeroed and with enough size.
+template <typename T>
+ARROW_NOINLINE int64_t CountValues(uint64_t* counts, const ArrayData& data, T min) {
+ const int64_t n = data.length - data.GetNullCount();
+ if (n > 0) {
+ const T* values = data.GetValues<T>(1);
+ arrow::internal::VisitSetBitRunsVoid(data.buffers[0], data.offset, data.length,
+ [&](int64_t pos, int64_t len) {
+ for (int64_t i = 0; i < len; ++i) {
+ ++counts[values[pos + i] - min];
+ }
+ });
+ }
+ return n;
+}
+
+template <typename T>
+int64_t CountValues(uint64_t* counts, const Datum& datum, T min) {
+ int64_t n = 0;
+ for (const auto& array : datum.chunks()) {
+ n += CountValues<T>(counts, *array->data(), min);
+ }
+ return n;
+}
+
+// Copy numerical array values to a buffer, ignore nulls.
+template <typename T>
+ARROW_NOINLINE int64_t CopyNonNullValues(const ArrayData& data, T* out) {
+ const int64_t n = data.length - data.GetNullCount();
+ if (n > 0) {
+ int64_t index = 0;
+ const T* values = data.GetValues<T>(1);
+ arrow::internal::VisitSetBitRunsVoid(
+ data.buffers[0], data.offset, data.length, [&](int64_t pos, int64_t len) {
+ memcpy(out + index, values + pos, len * sizeof(T));
+ index += len;
+ });
+ }
+ return n;
+}
+
+template <typename T>
+int64_t CopyNonNullValues(const Datum& datum, T* out) {
+ int64_t n = 0;
+ for (const auto& array : datum.chunks()) {
+ n += CopyNonNullValues(*array->data(), out + n);
+ }
+ return n;
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_array_sort.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_array_sort.cc
new file mode 100644
index 000000000..1de122663
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_array_sort.cc
@@ -0,0 +1,561 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cmath>
+#include <iterator>
+#include <limits>
+#include <numeric>
+#include <type_traits>
+#include <utility>
+
+#include "arrow/array/data.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/compute/kernels/vector_sort_internal.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bitmap.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+namespace internal {
+
+namespace {
+
+// ----------------------------------------------------------------------
+// partition_nth_indices implementation
+
+// We need to preserve the options
+using PartitionNthToIndicesState = internal::OptionsWrapper<PartitionNthOptions>;
+
+template <typename OutType, typename InType>
+struct PartitionNthToIndices {
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ using GetView = GetViewType<InType>;
+
+ if (ctx->state() == nullptr) {
+ return Status::Invalid("NthToIndices requires PartitionNthOptions");
+ }
+ const auto& options = PartitionNthToIndicesState::Get(ctx);
+
+ ArrayType arr(batch[0].array());
+
+ const int64_t pivot = options.pivot;
+ if (pivot > arr.length()) {
+ return Status::IndexError("NthToIndices index out of bound");
+ }
+ ArrayData* out_arr = out->mutable_array();
+ uint64_t* out_begin = out_arr->GetMutableValues<uint64_t>(1);
+ uint64_t* out_end = out_begin + arr.length();
+ std::iota(out_begin, out_end, 0);
+ if (pivot == arr.length()) {
+ return Status::OK();
+ }
+ const auto p = PartitionNulls<ArrayType, NonStablePartitioner>(
+ out_begin, out_end, arr, 0, options.null_placement);
+ auto nth_begin = out_begin + pivot;
+ if (nth_begin >= p.non_nulls_begin && nth_begin < p.non_nulls_end) {
+ std::nth_element(p.non_nulls_begin, nth_begin, p.non_nulls_end,
+ [&arr](uint64_t left, uint64_t right) {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ return lval < rval;
+ });
+ }
+ return Status::OK();
+ }
+};
+
+template <typename OutType>
+struct PartitionNthToIndices<OutType, NullType> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (ctx->state() == nullptr) {
+ return Status::Invalid("NthToIndices requires PartitionNthOptions");
+ }
+ ArrayData* out_arr = out->mutable_array();
+ uint64_t* out_begin = out_arr->GetMutableValues<uint64_t>(1);
+ uint64_t* out_end = out_begin + batch.length;
+ std::iota(out_begin, out_end, 0);
+ return Status::OK();
+ }
+};
+
+// ----------------------------------------------------------------------
+// Array sorting implementations
+
+template <typename ArrayType, typename VisitorNotNull, typename VisitorNull>
+inline void VisitRawValuesInline(const ArrayType& values,
+ VisitorNotNull&& visitor_not_null,
+ VisitorNull&& visitor_null) {
+ const auto data = values.raw_values();
+ VisitBitBlocksVoid(
+ values.null_bitmap(), values.offset(), values.length(),
+ [&](int64_t i) { visitor_not_null(data[i]); }, [&]() { visitor_null(); });
+}
+
+template <typename VisitorNotNull, typename VisitorNull>
+inline void VisitRawValuesInline(const BooleanArray& values,
+ VisitorNotNull&& visitor_not_null,
+ VisitorNull&& visitor_null) {
+ if (values.null_count() != 0) {
+ const uint8_t* data = values.data()->GetValues<uint8_t>(1, 0);
+ VisitBitBlocksVoid(
+ values.null_bitmap(), values.offset(), values.length(),
+ [&](int64_t i) { visitor_not_null(BitUtil::GetBit(data, values.offset() + i)); },
+ [&]() { visitor_null(); });
+ } else {
+ // Can avoid GetBit() overhead in the no-nulls case
+ VisitBitBlocksVoid(
+ values.data()->buffers[1], values.offset(), values.length(),
+ [&](int64_t i) { visitor_not_null(true); }, [&]() { visitor_not_null(false); });
+ }
+}
+
+template <typename ArrowType>
+class ArrayCompareSorter {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+ using GetView = GetViewType<ArrowType>;
+
+ public:
+ // `offset` is used when this is called on a chunk of a chunked array
+ NullPartitionResult operator()(uint64_t* indices_begin, uint64_t* indices_end,
+ const Array& array, int64_t offset,
+ const ArraySortOptions& options) {
+ const auto& values = checked_cast<const ArrayType&>(array);
+
+ const auto p = PartitionNulls<ArrayType, StablePartitioner>(
+ indices_begin, indices_end, values, offset, options.null_placement);
+ if (options.order == SortOrder::Ascending) {
+ std::stable_sort(
+ p.non_nulls_begin, p.non_nulls_end,
+ [&values, &offset](uint64_t left, uint64_t right) {
+ const auto lhs = GetView::LogicalValue(values.GetView(left - offset));
+ const auto rhs = GetView::LogicalValue(values.GetView(right - offset));
+ return lhs < rhs;
+ });
+ } else {
+ std::stable_sort(
+ p.non_nulls_begin, p.non_nulls_end,
+ [&values, &offset](uint64_t left, uint64_t right) {
+ const auto lhs = GetView::LogicalValue(values.GetView(left - offset));
+ const auto rhs = GetView::LogicalValue(values.GetView(right - offset));
+ // We don't use 'left > right' here to reduce required operator.
+ // If we use 'right < left' here, '<' is only required.
+ return rhs < lhs;
+ });
+ }
+ return p;
+ }
+};
+
+template <typename ArrowType>
+class ArrayCountSorter {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+ using c_type = typename ArrowType::c_type;
+
+ public:
+ ArrayCountSorter() = default;
+
+ explicit ArrayCountSorter(c_type min, c_type max) { SetMinMax(min, max); }
+
+ // Assume: max >= min && (max - min) < 4Gi
+ void SetMinMax(c_type min, c_type max) {
+ min_ = min;
+ value_range_ = static_cast<uint32_t>(max - min) + 1;
+ }
+
+ NullPartitionResult operator()(uint64_t* indices_begin, uint64_t* indices_end,
+ const Array& array, int64_t offset,
+ const ArraySortOptions& options) const {
+ const auto& values = checked_cast<const ArrayType&>(array);
+
+ // 32bit counter performs much better than 64bit one
+ if (values.length() < (1LL << 32)) {
+ return SortInternal<uint32_t>(indices_begin, indices_end, values, offset, options);
+ } else {
+ return SortInternal<uint64_t>(indices_begin, indices_end, values, offset, options);
+ }
+ }
+
+ private:
+ c_type min_{0};
+ uint32_t value_range_{0};
+
+ // `offset` is used when this is called on a chunk of a chunked array
+ template <typename CounterType>
+ NullPartitionResult SortInternal(uint64_t* indices_begin, uint64_t* indices_end,
+ const ArrayType& values, int64_t offset,
+ const ArraySortOptions& options) const {
+ const uint32_t value_range = value_range_;
+
+ // first and last slot reserved for prefix sum (depending on sort order)
+ std::vector<CounterType> counts(2 + value_range);
+ NullPartitionResult p;
+
+ if (options.order == SortOrder::Ascending) {
+ // counts will be increasing, starting with 0 and ending with (length - null_count)
+ CountValues(values, &counts[1]);
+ for (uint32_t i = 1; i <= value_range; ++i) {
+ counts[i] += counts[i - 1];
+ }
+
+ if (options.null_placement == NullPlacement::AtStart) {
+ p = NullPartitionResult::NullsAtStart(indices_begin, indices_end,
+ indices_end - counts[value_range]);
+ } else {
+ p = NullPartitionResult::NullsAtEnd(indices_begin, indices_end,
+ indices_begin + counts[value_range]);
+ }
+ EmitIndices(p, values, offset, &counts[0]);
+ } else {
+ // counts will be decreasing, starting with (length - null_count) and ending with 0
+ CountValues(values, &counts[0]);
+ for (uint32_t i = value_range; i >= 1; --i) {
+ counts[i - 1] += counts[i];
+ }
+
+ if (options.null_placement == NullPlacement::AtStart) {
+ p = NullPartitionResult::NullsAtStart(indices_begin, indices_end,
+ indices_end - counts[0]);
+ } else {
+ p = NullPartitionResult::NullsAtEnd(indices_begin, indices_end,
+ indices_begin + counts[0]);
+ }
+ EmitIndices(p, values, offset, &counts[1]);
+ }
+ return p;
+ }
+
+ template <typename CounterType>
+ void CountValues(const ArrayType& values, CounterType* counts) const {
+ VisitRawValuesInline(
+ values, [&](c_type v) { ++counts[v - min_]; }, []() {});
+ }
+
+ template <typename CounterType>
+ void EmitIndices(const NullPartitionResult& p, const ArrayType& values, int64_t offset,
+ CounterType* counts) const {
+ int64_t index = offset;
+ CounterType count_nulls = 0;
+ VisitRawValuesInline(
+ values, [&](c_type v) { p.non_nulls_begin[counts[v - min_]++] = index++; },
+ [&]() { p.nulls_begin[count_nulls++] = index++; });
+ }
+};
+
+template <>
+class ArrayCountSorter<BooleanType> {
+ public:
+ ArrayCountSorter() = default;
+
+ // `offset` is used when this is called on a chunk of a chunked array
+ NullPartitionResult operator()(uint64_t* indices_begin, uint64_t* indices_end,
+ const Array& array, int64_t offset,
+ const ArraySortOptions& options) {
+ const auto& values = checked_cast<const BooleanArray&>(array);
+
+ std::array<int64_t, 3> counts{0, 0, 0}; // false, true, null
+
+ const int64_t nulls = values.null_count();
+ const int64_t ones = values.true_count();
+ const int64_t zeros = values.length() - ones - nulls;
+
+ NullPartitionResult p;
+ if (options.null_placement == NullPlacement::AtStart) {
+ p = NullPartitionResult::NullsAtStart(indices_begin, indices_end,
+ indices_begin + nulls);
+ } else {
+ p = NullPartitionResult::NullsAtEnd(indices_begin, indices_end,
+ indices_end - nulls);
+ }
+
+ if (options.order == SortOrder::Ascending) {
+ // ones start after zeros
+ counts[1] = zeros;
+ } else {
+ // zeros start after ones
+ counts[0] = ones;
+ }
+
+ int64_t index = offset;
+ VisitRawValuesInline(
+ values, [&](bool v) { p.non_nulls_begin[counts[v]++] = index++; },
+ [&]() { p.nulls_begin[counts[2]++] = index++; });
+ return p;
+ }
+};
+
+// Sort integers with counting sort or comparison based sorting algorithm
+// - Use O(n) counting sort if values are in a small range
+// - Use O(nlogn) std::stable_sort otherwise
+template <typename ArrowType>
+class ArrayCountOrCompareSorter {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+ using c_type = typename ArrowType::c_type;
+
+ public:
+ // `offset` is used when this is called on a chunk of a chunked array
+ NullPartitionResult operator()(uint64_t* indices_begin, uint64_t* indices_end,
+ const Array& array, int64_t offset,
+ const ArraySortOptions& options) {
+ const auto& values = checked_cast<const ArrayType&>(array);
+
+ if (values.length() >= countsort_min_len_ && values.length() > values.null_count()) {
+ c_type min, max;
+ std::tie(min, max) = GetMinMax<c_type>(*values.data());
+
+ // For signed int32/64, (max - min) may overflow and trigger UBSAN.
+ // Cast to largest unsigned type(uint64_t) before subtraction.
+ if (static_cast<uint64_t>(max) - static_cast<uint64_t>(min) <=
+ countsort_max_range_) {
+ count_sorter_.SetMinMax(min, max);
+ return count_sorter_(indices_begin, indices_end, values, offset, options);
+ }
+ }
+
+ return compare_sorter_(indices_begin, indices_end, values, offset, options);
+ }
+
+ private:
+ ArrayCompareSorter<ArrowType> compare_sorter_;
+ ArrayCountSorter<ArrowType> count_sorter_;
+
+ // Cross point to prefer counting sort than stl::stable_sort(merge sort)
+ // - array to be sorted is longer than "count_min_len_"
+ // - value range (max-min) is within "count_max_range_"
+ //
+ // The optimal setting depends heavily on running CPU. Below setting is
+ // conservative to adapt to various hardware and keep code simple.
+ // It's possible to decrease array-len and/or increase value-range to cover
+ // more cases, or setup a table for best array-len/value-range combinations.
+ // See https://issues.apache.org/jira/browse/ARROW-1571 for detailed analysis.
+ static const uint32_t countsort_min_len_ = 1024;
+ static const uint32_t countsort_max_range_ = 4096;
+};
+
+class ArrayNullSorter {
+ public:
+ NullPartitionResult operator()(uint64_t* indices_begin, uint64_t* indices_end,
+ const Array& values, int64_t offset,
+ const ArraySortOptions& options) {
+ return NullPartitionResult::NullsOnly(indices_begin, indices_end,
+ options.null_placement);
+ }
+};
+
+//
+// Generic Array sort dispatcher for physical types
+//
+
+template <typename Type, typename Enable = void>
+struct ArraySorter {};
+
+template <>
+struct ArraySorter<NullType> {
+ ArrayNullSorter impl;
+};
+
+template <>
+struct ArraySorter<BooleanType> {
+ ArrayCountSorter<BooleanType> impl;
+};
+
+template <>
+struct ArraySorter<UInt8Type> {
+ ArrayCountSorter<UInt8Type> impl;
+ ArraySorter() : impl(0, 255) {}
+};
+
+template <>
+struct ArraySorter<Int8Type> {
+ ArrayCountSorter<Int8Type> impl;
+ ArraySorter() : impl(-128, 127) {}
+};
+
+template <typename Type>
+struct ArraySorter<Type, enable_if_t<is_integer_type<Type>::value &&
+ (sizeof(typename Type::c_type) > 1)>> {
+ static constexpr bool is_supported = true;
+ ArrayCountOrCompareSorter<Type> impl;
+};
+
+template <typename Type>
+struct ArraySorter<
+ Type, enable_if_t<is_floating_type<Type>::value || is_base_binary_type<Type>::value ||
+ is_fixed_size_binary_type<Type>::value>> {
+ ArrayCompareSorter<Type> impl;
+};
+
+struct ArraySorterFactory {
+ ArraySortFunc sorter;
+
+ Status Visit(const DataType& type) {
+ return Status::TypeError("Sorting not supported for type ", type.ToString());
+ }
+
+ template <typename T, typename U = decltype(ArraySorter<T>::impl)>
+ Status Visit(const T& type, U* = nullptr) {
+ sorter = ArraySortFunc(std::move(ArraySorter<T>{}.impl));
+ return Status::OK();
+ }
+
+ Result<ArraySortFunc> MakeSorter(const DataType& type) {
+ RETURN_NOT_OK(VisitTypeInline(type, this));
+ DCHECK(sorter);
+ return std::move(sorter);
+ }
+};
+
+using ArraySortIndicesState = internal::OptionsWrapper<ArraySortOptions>;
+
+template <typename OutType, typename InType>
+struct ArraySortIndices {
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& options = ArraySortIndicesState::Get(ctx);
+ ArrayType arr(batch[0].array());
+ ARROW_ASSIGN_OR_RAISE(auto sorter, GetArraySorter(*GetPhysicalType(arr.type())));
+
+ ArrayData* out_arr = out->mutable_array();
+ uint64_t* out_begin = out_arr->GetMutableValues<uint64_t>(1);
+ uint64_t* out_end = out_begin + arr.length();
+ std::iota(out_begin, out_end, 0);
+
+ sorter(out_begin, out_end, arr, 0, options);
+ return Status::OK();
+ }
+};
+
+template <template <typename...> class ExecTemplate>
+void AddArraySortingKernels(VectorKernel base, VectorFunction* func) {
+ // null type
+ base.signature = KernelSignature::Make({InputType::Array(null())}, uint64());
+ base.exec = ExecTemplate<UInt64Type, NullType>::Exec;
+ DCHECK_OK(func->AddKernel(base));
+
+ // bool type
+ base.signature = KernelSignature::Make({InputType::Array(boolean())}, uint64());
+ base.exec = ExecTemplate<UInt64Type, BooleanType>::Exec;
+ DCHECK_OK(func->AddKernel(base));
+
+ // duration type
+ base.signature = KernelSignature::Make({InputType::Array(Type::DURATION)}, uint64());
+ base.exec = GenerateNumeric<ExecTemplate, UInt64Type>(*int64());
+ DCHECK_OK(func->AddKernel(base));
+
+ for (const auto& ty : NumericTypes()) {
+ auto physical_type = GetPhysicalType(ty);
+ base.signature = KernelSignature::Make({InputType::Array(ty)}, uint64());
+ base.exec = GenerateNumeric<ExecTemplate, UInt64Type>(*physical_type);
+ DCHECK_OK(func->AddKernel(base));
+ }
+ for (const auto& ty : TemporalTypes()) {
+ auto physical_type = GetPhysicalType(ty);
+ base.signature = KernelSignature::Make({InputType::Array(ty->id())}, uint64());
+ base.exec = GenerateNumeric<ExecTemplate, UInt64Type>(*physical_type);
+ DCHECK_OK(func->AddKernel(base));
+ }
+ for (const auto id : {Type::DECIMAL128, Type::DECIMAL256}) {
+ base.signature = KernelSignature::Make({InputType::Array(id)}, uint64());
+ base.exec = GenerateDecimal<ExecTemplate, UInt64Type>(id);
+ DCHECK_OK(func->AddKernel(base));
+ }
+ for (const auto& ty : BaseBinaryTypes()) {
+ auto physical_type = GetPhysicalType(ty);
+ base.signature = KernelSignature::Make({InputType::Array(ty)}, uint64());
+ base.exec = GenerateVarBinaryBase<ExecTemplate, UInt64Type>(*physical_type);
+ DCHECK_OK(func->AddKernel(base));
+ }
+ base.signature =
+ KernelSignature::Make({InputType::Array(Type::FIXED_SIZE_BINARY)}, uint64());
+ base.exec = ExecTemplate<UInt64Type, FixedSizeBinaryType>::Exec;
+ DCHECK_OK(func->AddKernel(base));
+}
+
+const auto kDefaultArraySortOptions = ArraySortOptions::Defaults();
+
+const FunctionDoc array_sort_indices_doc(
+ "Return the indices that would sort an array",
+ ("This function computes an array of indices that define a stable sort\n"
+ "of the input array. By default, Null values are considered greater\n"
+ "than any other value and are therefore sorted at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values.\n"
+ "\n"
+ "The handling of nulls and NaNs can be changed in ArraySortOptions."),
+ {"array"}, "ArraySortOptions");
+
+const FunctionDoc partition_nth_indices_doc(
+ "Return the indices that would partition an array around a pivot",
+ ("This functions computes an array of indices that define a non-stable\n"
+ "partial sort of the input array.\n"
+ "\n"
+ "The output is such that the `N`'th index points to the `N`'th element\n"
+ "of the input in sorted order, and all indices before the `N`'th point\n"
+ "to elements in the input less or equal to elements at or after the `N`'th.\n"
+ "\n"
+ "By default, null values are considered greater than any other value\n"
+ "and are therefore partitioned towards the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values.\n"
+ "\n"
+ "The pivot index `N` must be given in PartitionNthOptions.\n"
+ "The handling of nulls and NaNs can also be changed in PartitionNthOptions."),
+ {"array"}, "PartitionNthOptions");
+
+} // namespace
+
+Result<ArraySortFunc> GetArraySorter(const DataType& type) {
+ ArraySorterFactory factory;
+ return factory.MakeSorter(type);
+}
+
+void RegisterVectorArraySort(FunctionRegistry* registry) {
+ // The kernel outputs into preallocated memory and is never null
+ VectorKernel base;
+ base.mem_allocation = MemAllocation::PREALLOCATE;
+ base.null_handling = NullHandling::OUTPUT_NOT_NULL;
+
+ auto array_sort_indices = std::make_shared<VectorFunction>(
+ "array_sort_indices", Arity::Unary(), &array_sort_indices_doc,
+ &kDefaultArraySortOptions);
+ base.init = ArraySortIndicesState::Init;
+ AddArraySortingKernels<ArraySortIndices>(base, array_sort_indices.get());
+ DCHECK_OK(registry->AddFunction(std::move(array_sort_indices)));
+
+ // partition_nth_indices has a parameter so needs its init function
+ auto part_indices = std::make_shared<VectorFunction>(
+ "partition_nth_indices", Arity::Unary(), &partition_nth_indices_doc);
+ base.init = PartitionNthToIndicesState::Init;
+ AddArraySortingKernels<PartitionNthToIndices>(base, part_indices.get());
+ DCHECK_OK(registry->AddFunction(std::move(part_indices)));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_hash.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_hash.cc
new file mode 100644
index 000000000..02c443baf
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_hash.cc
@@ -0,0 +1,807 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstring>
+#include <mutex>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_dict.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/array/dict_internal.h"
+#include "arrow/array/util.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/result.h"
+#include "arrow/util/hashing.h"
+#include "arrow/util/make_unique.h"
+
+namespace arrow {
+
+using internal::DictionaryTraits;
+using internal::HashTraits;
+
+namespace compute {
+namespace internal {
+
+namespace {
+
+class ActionBase {
+ public:
+ ActionBase(const std::shared_ptr<DataType>& type, MemoryPool* pool)
+ : type_(type), pool_(pool) {}
+
+ protected:
+ std::shared_ptr<DataType> type_;
+ MemoryPool* pool_;
+};
+
+// ----------------------------------------------------------------------
+// Unique
+
+class UniqueAction final : public ActionBase {
+ public:
+ using ActionBase::ActionBase;
+
+ static constexpr bool with_error_status = false;
+
+ UniqueAction(const std::shared_ptr<DataType>& type, const FunctionOptions* options,
+ MemoryPool* pool)
+ : ActionBase(type, pool) {}
+
+ Status Reset() { return Status::OK(); }
+
+ Status Reserve(const int64_t length) { return Status::OK(); }
+
+ template <class Index>
+ void ObserveNullFound(Index index) {}
+
+ template <class Index>
+ void ObserveNullNotFound(Index index) {}
+
+ template <class Index>
+ void ObserveFound(Index index) {}
+
+ template <class Index>
+ void ObserveNotFound(Index index) {}
+
+ bool ShouldEncodeNulls() { return true; }
+
+ Status Flush(Datum* out) { return Status::OK(); }
+
+ Status FlushFinal(Datum* out) { return Status::OK(); }
+};
+
+// ----------------------------------------------------------------------
+// Count values
+
+class ValueCountsAction final : ActionBase {
+ public:
+ using ActionBase::ActionBase;
+
+ static constexpr bool with_error_status = true;
+
+ ValueCountsAction(const std::shared_ptr<DataType>& type, const FunctionOptions* options,
+ MemoryPool* pool)
+ : ActionBase(type, pool), count_builder_(pool) {}
+
+ Status Reserve(const int64_t length) {
+ // builder size is independent of input array size.
+ return Status::OK();
+ }
+
+ Status Reset() {
+ count_builder_.Reset();
+ return Status::OK();
+ }
+
+ // Don't do anything on flush because we don't want to finalize the builder
+ // or incur the cost of memory copies.
+ Status Flush(Datum* out) { return Status::OK(); }
+
+ // Return the counts corresponding the MemoTable keys.
+ Status FlushFinal(Datum* out) {
+ std::shared_ptr<ArrayData> result;
+ RETURN_NOT_OK(count_builder_.FinishInternal(&result));
+ out->value = std::move(result);
+ return Status::OK();
+ }
+
+ template <class Index>
+ void ObserveNullFound(Index index) {
+ count_builder_[index]++;
+ }
+
+ template <class Index>
+ void ObserveNullNotFound(Index index) {
+ ARROW_LOG(FATAL) << "ObserveNullNotFound without err_status should not be called";
+ }
+
+ template <class Index>
+ void ObserveNullNotFound(Index index, Status* status) {
+ Status s = count_builder_.Append(1);
+ if (ARROW_PREDICT_FALSE(!s.ok())) {
+ *status = s;
+ }
+ }
+
+ template <class Index>
+ void ObserveFound(Index slot) {
+ count_builder_[slot]++;
+ }
+
+ template <class Index>
+ void ObserveNotFound(Index slot, Status* status) {
+ Status s = count_builder_.Append(1);
+ if (ARROW_PREDICT_FALSE(!s.ok())) {
+ *status = s;
+ }
+ }
+
+ bool ShouldEncodeNulls() const { return true; }
+
+ private:
+ Int64Builder count_builder_;
+};
+
+// ----------------------------------------------------------------------
+// Dictionary encode implementation
+
+class DictEncodeAction final : public ActionBase {
+ public:
+ using ActionBase::ActionBase;
+
+ static constexpr bool with_error_status = false;
+
+ DictEncodeAction(const std::shared_ptr<DataType>& type, const FunctionOptions* options,
+ MemoryPool* pool)
+ : ActionBase(type, pool), indices_builder_(pool) {
+ if (auto options_ptr = static_cast<const DictionaryEncodeOptions*>(options)) {
+ encode_options_ = *options_ptr;
+ }
+ }
+
+ Status Reset() {
+ indices_builder_.Reset();
+ return Status::OK();
+ }
+
+ Status Reserve(const int64_t length) { return indices_builder_.Reserve(length); }
+
+ template <class Index>
+ void ObserveNullFound(Index index) {
+ if (encode_options_.null_encoding_behavior == DictionaryEncodeOptions::MASK) {
+ indices_builder_.UnsafeAppendNull();
+ } else {
+ indices_builder_.UnsafeAppend(index);
+ }
+ }
+
+ template <class Index>
+ void ObserveNullNotFound(Index index) {
+ ObserveNullFound(index);
+ }
+
+ template <class Index>
+ void ObserveFound(Index index) {
+ indices_builder_.UnsafeAppend(index);
+ }
+
+ template <class Index>
+ void ObserveNotFound(Index index) {
+ ObserveFound(index);
+ }
+
+ bool ShouldEncodeNulls() {
+ return encode_options_.null_encoding_behavior == DictionaryEncodeOptions::ENCODE;
+ }
+
+ Status Flush(Datum* out) {
+ std::shared_ptr<ArrayData> result;
+ RETURN_NOT_OK(indices_builder_.FinishInternal(&result));
+ out->value = std::move(result);
+ return Status::OK();
+ }
+
+ Status FlushFinal(Datum* out) { return Status::OK(); }
+
+ private:
+ Int32Builder indices_builder_;
+ DictionaryEncodeOptions encode_options_;
+};
+
+class HashKernel : public KernelState {
+ public:
+ HashKernel() : options_(nullptr) {}
+ explicit HashKernel(const FunctionOptions* options) : options_(options) {}
+
+ // Reset for another run.
+ virtual Status Reset() = 0;
+
+ // Flush out accumulated results from the last invocation of Call.
+ virtual Status Flush(Datum* out) = 0;
+ // Flush out accumulated results across all invocations of Call. The kernel
+ // should not be used until after Reset() is called.
+ virtual Status FlushFinal(Datum* out) = 0;
+ // Get the values (keys) accumulated in the dictionary so far.
+ virtual Status GetDictionary(std::shared_ptr<ArrayData>* out) = 0;
+
+ virtual std::shared_ptr<DataType> value_type() const = 0;
+
+ Status Append(KernelContext* ctx, const ArrayData& input) {
+ std::lock_guard<std::mutex> guard(lock_);
+ return Append(input);
+ }
+
+ // Prepare the Action for the given input (e.g. reserve appropriately sized
+ // data structures) and visit the given input with Action.
+ virtual Status Append(const ArrayData& arr) = 0;
+
+ protected:
+ const FunctionOptions* options_;
+ std::mutex lock_;
+};
+
+// ----------------------------------------------------------------------
+// Base class for all "regular" hash kernel implementations
+// (NullType has a separate implementation)
+
+template <typename Type, typename Scalar, typename Action,
+ bool with_error_status = Action::with_error_status>
+class RegularHashKernel : public HashKernel {
+ public:
+ RegularHashKernel(const std::shared_ptr<DataType>& type, const FunctionOptions* options,
+ MemoryPool* pool)
+ : HashKernel(options), pool_(pool), type_(type), action_(type, options, pool) {}
+
+ Status Reset() override {
+ memo_table_.reset(new MemoTable(pool_, 0));
+ return action_.Reset();
+ }
+
+ Status Append(const ArrayData& arr) override {
+ RETURN_NOT_OK(action_.Reserve(arr.length));
+ return DoAppend(arr);
+ }
+
+ Status Flush(Datum* out) override { return action_.Flush(out); }
+
+ Status FlushFinal(Datum* out) override { return action_.FlushFinal(out); }
+
+ Status GetDictionary(std::shared_ptr<ArrayData>* out) override {
+ return DictionaryTraits<Type>::GetDictionaryArrayData(pool_, type_, *memo_table_,
+ 0 /* start_offset */, out);
+ }
+
+ std::shared_ptr<DataType> value_type() const override { return type_; }
+
+ template <bool HasError = with_error_status>
+ enable_if_t<!HasError, Status> DoAppend(const ArrayData& arr) {
+ return VisitArrayDataInline<Type>(
+ arr,
+ [this](Scalar v) {
+ auto on_found = [this](int32_t memo_index) {
+ action_.ObserveFound(memo_index);
+ };
+ auto on_not_found = [this](int32_t memo_index) {
+ action_.ObserveNotFound(memo_index);
+ };
+
+ int32_t unused_memo_index;
+ return memo_table_->GetOrInsert(v, std::move(on_found), std::move(on_not_found),
+ &unused_memo_index);
+ },
+ [this]() {
+ if (action_.ShouldEncodeNulls()) {
+ auto on_found = [this](int32_t memo_index) {
+ action_.ObserveNullFound(memo_index);
+ };
+ auto on_not_found = [this](int32_t memo_index) {
+ action_.ObserveNullNotFound(memo_index);
+ };
+ memo_table_->GetOrInsertNull(std::move(on_found), std::move(on_not_found));
+ } else {
+ action_.ObserveNullNotFound(-1);
+ }
+ return Status::OK();
+ });
+ }
+
+ template <bool HasError = with_error_status>
+ enable_if_t<HasError, Status> DoAppend(const ArrayData& arr) {
+ return VisitArrayDataInline<Type>(
+ arr,
+ [this](Scalar v) {
+ Status s = Status::OK();
+ auto on_found = [this](int32_t memo_index) {
+ action_.ObserveFound(memo_index);
+ };
+ auto on_not_found = [this, &s](int32_t memo_index) {
+ action_.ObserveNotFound(memo_index, &s);
+ };
+
+ int32_t unused_memo_index;
+ RETURN_NOT_OK(memo_table_->GetOrInsert(
+ v, std::move(on_found), std::move(on_not_found), &unused_memo_index));
+ return s;
+ },
+ [this]() {
+ // Null
+ Status s = Status::OK();
+ auto on_found = [this](int32_t memo_index) {
+ action_.ObserveNullFound(memo_index);
+ };
+ auto on_not_found = [this, &s](int32_t memo_index) {
+ action_.ObserveNullNotFound(memo_index, &s);
+ };
+ if (action_.ShouldEncodeNulls()) {
+ memo_table_->GetOrInsertNull(std::move(on_found), std::move(on_not_found));
+ }
+ return s;
+ });
+ }
+
+ protected:
+ using MemoTable = typename HashTraits<Type>::MemoTableType;
+
+ MemoryPool* pool_;
+ std::shared_ptr<DataType> type_;
+ Action action_;
+ std::unique_ptr<MemoTable> memo_table_;
+};
+
+// ----------------------------------------------------------------------
+// Hash kernel implementation for nulls
+
+template <typename Action, bool with_error_status = Action::with_error_status>
+class NullHashKernel : public HashKernel {
+ public:
+ NullHashKernel(const std::shared_ptr<DataType>& type, const FunctionOptions* options,
+ MemoryPool* pool)
+ : pool_(pool), type_(type), action_(type, options, pool) {}
+
+ Status Reset() override { return action_.Reset(); }
+
+ Status Append(const ArrayData& arr) override { return DoAppend(arr); }
+
+ template <bool HasError = with_error_status>
+ enable_if_t<!HasError, Status> DoAppend(const ArrayData& arr) {
+ RETURN_NOT_OK(action_.Reserve(arr.length));
+ for (int64_t i = 0; i < arr.length; ++i) {
+ if (i == 0) {
+ seen_null_ = true;
+ action_.ObserveNullNotFound(0);
+ } else {
+ action_.ObserveNullFound(0);
+ }
+ }
+ return Status::OK();
+ }
+
+ template <bool HasError = with_error_status>
+ enable_if_t<HasError, Status> DoAppend(const ArrayData& arr) {
+ Status s = Status::OK();
+ RETURN_NOT_OK(action_.Reserve(arr.length));
+ for (int64_t i = 0; i < arr.length; ++i) {
+ if (seen_null_ == false && i == 0) {
+ seen_null_ = true;
+ action_.ObserveNullNotFound(0, &s);
+ } else {
+ action_.ObserveNullFound(0);
+ }
+ }
+ return s;
+ }
+
+ Status Flush(Datum* out) override { return action_.Flush(out); }
+ Status FlushFinal(Datum* out) override { return action_.FlushFinal(out); }
+
+ Status GetDictionary(std::shared_ptr<ArrayData>* out) override {
+ std::shared_ptr<NullArray> null_array;
+ if (seen_null_) {
+ null_array = std::make_shared<NullArray>(1);
+ } else {
+ null_array = std::make_shared<NullArray>(0);
+ }
+ *out = null_array->data();
+ return Status::OK();
+ }
+
+ std::shared_ptr<DataType> value_type() const override { return type_; }
+
+ protected:
+ MemoryPool* pool_;
+ std::shared_ptr<DataType> type_;
+ bool seen_null_ = false;
+ Action action_;
+};
+
+// ----------------------------------------------------------------------
+// Hashing for dictionary type
+
+class DictionaryHashKernel : public HashKernel {
+ public:
+ explicit DictionaryHashKernel(std::unique_ptr<HashKernel> indices_kernel,
+ std::shared_ptr<DataType> dictionary_value_type)
+ : indices_kernel_(std::move(indices_kernel)),
+ dictionary_value_type_(std::move(dictionary_value_type)) {}
+
+ Status Reset() override { return indices_kernel_->Reset(); }
+
+ Status Append(const ArrayData& arr) override {
+ if (!dictionary_) {
+ dictionary_ = arr.dictionary;
+ } else if (!MakeArray(dictionary_)->Equals(*MakeArray(arr.dictionary))) {
+ // NOTE: This approach computes a new dictionary unification per chunk.
+ // This is in effect O(n*k) where n is the total chunked array length and
+ // k is the number of chunks (therefore O(n**2) if chunks have a fixed size).
+ //
+ // A better approach may be to run the kernel over each individual chunk,
+ // and then hash-aggregate all results (for example sum-group-by for
+ // the "value_counts" kernel).
+ auto out_dict_type = dictionary_->type;
+ std::shared_ptr<Buffer> transpose_map;
+ std::shared_ptr<Array> out_dict;
+ ARROW_ASSIGN_OR_RAISE(auto unifier, DictionaryUnifier::Make(out_dict_type));
+
+ ARROW_CHECK_OK(unifier->Unify(*MakeArray(dictionary_)));
+ ARROW_CHECK_OK(unifier->Unify(*MakeArray(arr.dictionary), &transpose_map));
+ ARROW_CHECK_OK(unifier->GetResult(&out_dict_type, &out_dict));
+
+ this->dictionary_ = out_dict->data();
+ auto transpose = reinterpret_cast<const int32_t*>(transpose_map->data());
+ auto in_dict_array = MakeArray(std::make_shared<ArrayData>(arr));
+ ARROW_ASSIGN_OR_RAISE(
+ auto tmp, arrow::internal::checked_cast<const DictionaryArray&>(*in_dict_array)
+ .Transpose(arr.type, out_dict, transpose));
+ return indices_kernel_->Append(*tmp->data());
+ }
+
+ return indices_kernel_->Append(arr);
+ }
+
+ Status Flush(Datum* out) override { return indices_kernel_->Flush(out); }
+
+ Status FlushFinal(Datum* out) override { return indices_kernel_->FlushFinal(out); }
+
+ Status GetDictionary(std::shared_ptr<ArrayData>* out) override {
+ return indices_kernel_->GetDictionary(out);
+ }
+
+ std::shared_ptr<DataType> value_type() const override {
+ return indices_kernel_->value_type();
+ }
+
+ std::shared_ptr<DataType> dictionary_value_type() const {
+ return dictionary_value_type_;
+ }
+
+ std::shared_ptr<ArrayData> dictionary() const { return dictionary_; }
+
+ private:
+ std::unique_ptr<HashKernel> indices_kernel_;
+ std::shared_ptr<ArrayData> dictionary_;
+ std::shared_ptr<DataType> dictionary_value_type_;
+};
+
+// ----------------------------------------------------------------------
+
+template <typename Type, typename Action, typename Enable = void>
+struct HashKernelTraits {};
+
+template <typename Type, typename Action>
+struct HashKernelTraits<Type, Action, enable_if_null<Type>> {
+ using HashKernel = NullHashKernel<Action>;
+};
+
+template <typename Type, typename Action>
+struct HashKernelTraits<Type, Action, enable_if_has_c_type<Type>> {
+ using HashKernel = RegularHashKernel<Type, typename Type::c_type, Action>;
+};
+
+template <typename Type, typename Action>
+struct HashKernelTraits<Type, Action, enable_if_has_string_view<Type>> {
+ using HashKernel = RegularHashKernel<Type, util::string_view, Action>;
+};
+
+template <typename Type, typename Action>
+Result<std::unique_ptr<HashKernel>> HashInitImpl(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ using HashKernelType = typename HashKernelTraits<Type, Action>::HashKernel;
+ auto result = ::arrow::internal::make_unique<HashKernelType>(
+ args.inputs[0].type, args.options, ctx->memory_pool());
+ RETURN_NOT_OK(result->Reset());
+ return std::move(result);
+}
+
+template <typename Type, typename Action>
+Result<std::unique_ptr<KernelState>> HashInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ return HashInitImpl<Type, Action>(ctx, args);
+}
+
+template <typename Action>
+KernelInit GetHashInit(Type::type type_id) {
+ // ARROW-8933: Generate only a single hash kernel per physical data
+ // representation
+ switch (type_id) {
+ case Type::NA:
+ return HashInit<NullType, Action>;
+ case Type::BOOL:
+ return HashInit<BooleanType, Action>;
+ case Type::INT8:
+ case Type::UINT8:
+ return HashInit<UInt8Type, Action>;
+ case Type::INT16:
+ case Type::UINT16:
+ return HashInit<UInt16Type, Action>;
+ case Type::INT32:
+ case Type::UINT32:
+ case Type::FLOAT:
+ case Type::DATE32:
+ case Type::TIME32:
+ return HashInit<UInt32Type, Action>;
+ case Type::INT64:
+ case Type::UINT64:
+ case Type::DOUBLE:
+ case Type::DATE64:
+ case Type::TIME64:
+ case Type::TIMESTAMP:
+ case Type::DURATION:
+ return HashInit<UInt64Type, Action>;
+ case Type::BINARY:
+ case Type::STRING:
+ return HashInit<BinaryType, Action>;
+ case Type::LARGE_BINARY:
+ case Type::LARGE_STRING:
+ return HashInit<LargeBinaryType, Action>;
+ case Type::FIXED_SIZE_BINARY:
+ case Type::DECIMAL128:
+ case Type::DECIMAL256:
+ return HashInit<FixedSizeBinaryType, Action>;
+ default:
+ DCHECK(false);
+ return nullptr;
+ }
+}
+
+using DictionaryEncodeState = OptionsWrapper<DictionaryEncodeOptions>;
+
+template <typename Action>
+Result<std::unique_ptr<KernelState>> DictionaryHashInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ const auto& dict_type = checked_cast<const DictionaryType&>(*args.inputs[0].type);
+ Result<std::unique_ptr<HashKernel>> indices_hasher;
+ switch (dict_type.index_type()->id()) {
+ case Type::INT8:
+ case Type::UINT8:
+ indices_hasher = HashInitImpl<UInt8Type, Action>(ctx, args);
+ break;
+ case Type::INT16:
+ case Type::UINT16:
+ indices_hasher = HashInitImpl<UInt16Type, Action>(ctx, args);
+ break;
+ case Type::INT32:
+ case Type::UINT32:
+ indices_hasher = HashInitImpl<UInt32Type, Action>(ctx, args);
+ break;
+ case Type::INT64:
+ case Type::UINT64:
+ indices_hasher = HashInitImpl<UInt64Type, Action>(ctx, args);
+ break;
+ default:
+ DCHECK(false) << "Unsupported dictionary index type";
+ break;
+ }
+ RETURN_NOT_OK(indices_hasher);
+ return ::arrow::internal::make_unique<DictionaryHashKernel>(
+ std::move(indices_hasher.ValueOrDie()), dict_type.value_type());
+}
+
+Status HashExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ auto hash_impl = checked_cast<HashKernel*>(ctx->state());
+ RETURN_NOT_OK(hash_impl->Append(ctx, *batch[0].array()));
+ RETURN_NOT_OK(hash_impl->Flush(out));
+ return Status::OK();
+}
+
+Status UniqueFinalize(KernelContext* ctx, std::vector<Datum>* out) {
+ auto hash_impl = checked_cast<HashKernel*>(ctx->state());
+ std::shared_ptr<ArrayData> uniques;
+ RETURN_NOT_OK(hash_impl->GetDictionary(&uniques));
+ *out = {Datum(uniques)};
+ return Status::OK();
+}
+
+Status DictEncodeFinalize(KernelContext* ctx, std::vector<Datum>* out) {
+ auto hash_impl = checked_cast<HashKernel*>(ctx->state());
+ std::shared_ptr<ArrayData> uniques;
+ RETURN_NOT_OK(hash_impl->GetDictionary(&uniques));
+ auto dict_type = dictionary(int32(), uniques->type);
+ auto dict = MakeArray(uniques);
+ for (size_t i = 0; i < out->size(); ++i) {
+ (*out)[i] =
+ std::make_shared<DictionaryArray>(dict_type, (*out)[i].make_array(), dict);
+ }
+ return Status::OK();
+}
+
+std::shared_ptr<ArrayData> BoxValueCounts(const std::shared_ptr<ArrayData>& uniques,
+ const std::shared_ptr<ArrayData>& counts) {
+ auto data_type =
+ struct_({field(kValuesFieldName, uniques->type), field(kCountsFieldName, int64())});
+ ArrayVector children = {MakeArray(uniques), MakeArray(counts)};
+ return std::make_shared<StructArray>(data_type, uniques->length, children)->data();
+}
+
+Status ValueCountsFinalize(KernelContext* ctx, std::vector<Datum>* out) {
+ auto hash_impl = checked_cast<HashKernel*>(ctx->state());
+ std::shared_ptr<ArrayData> uniques;
+ Datum value_counts;
+
+ RETURN_NOT_OK(hash_impl->GetDictionary(&uniques));
+ RETURN_NOT_OK(hash_impl->FlushFinal(&value_counts));
+ *out = {Datum(BoxValueCounts(uniques, value_counts.array()))};
+ return Status::OK();
+}
+
+// Return the dictionary from the hash kernel or allocate an empty one.
+// Required because on empty inputs, we don't ever see the input and
+// hence have no dictionary.
+Result<std::shared_ptr<ArrayData>> EnsureHashDictionary(KernelContext* ctx,
+ DictionaryHashKernel* hash) {
+ if (hash->dictionary()) {
+ return hash->dictionary();
+ }
+ ARROW_ASSIGN_OR_RAISE(auto null, MakeArrayOfNull(hash->dictionary_value_type(),
+ /*length=*/0, ctx->memory_pool()));
+ return null->data();
+}
+
+Status UniqueFinalizeDictionary(KernelContext* ctx, std::vector<Datum>* out) {
+ RETURN_NOT_OK(UniqueFinalize(ctx, out));
+ auto hash = checked_cast<DictionaryHashKernel*>(ctx->state());
+ ARROW_ASSIGN_OR_RAISE((*out)[0].mutable_array()->dictionary,
+ EnsureHashDictionary(ctx, hash));
+ return Status::OK();
+}
+
+Status ValueCountsFinalizeDictionary(KernelContext* ctx, std::vector<Datum>* out) {
+ auto hash = checked_cast<DictionaryHashKernel*>(ctx->state());
+ std::shared_ptr<ArrayData> uniques;
+ Datum value_counts;
+ RETURN_NOT_OK(hash->GetDictionary(&uniques));
+ RETURN_NOT_OK(hash->FlushFinal(&value_counts));
+ ARROW_ASSIGN_OR_RAISE(uniques->dictionary, EnsureHashDictionary(ctx, hash));
+ *out = {Datum(BoxValueCounts(uniques, value_counts.array()))};
+ return Status::OK();
+}
+
+ValueDescr DictEncodeOutput(KernelContext*, const std::vector<ValueDescr>& descrs) {
+ return ValueDescr::Array(dictionary(int32(), descrs[0].type));
+}
+
+ValueDescr ValueCountsOutput(KernelContext*, const std::vector<ValueDescr>& descrs) {
+ return ValueDescr::Array(struct_(
+ {field(kValuesFieldName, descrs[0].type), field(kCountsFieldName, int64())}));
+}
+
+template <typename Action>
+void AddHashKernels(VectorFunction* func, VectorKernel base, OutputType out_ty) {
+ for (const auto& ty : PrimitiveTypes()) {
+ base.init = GetHashInit<Action>(ty->id());
+ base.signature = KernelSignature::Make({InputType::Array(ty)}, out_ty);
+ DCHECK_OK(func->AddKernel(base));
+ }
+
+ // Example parametric types that we want to match only on Type::type
+ auto parametric_types = {time32(TimeUnit::SECOND), time64(TimeUnit::MICRO),
+ timestamp(TimeUnit::SECOND), fixed_size_binary(0)};
+ for (const auto& ty : parametric_types) {
+ base.init = GetHashInit<Action>(ty->id());
+ base.signature = KernelSignature::Make({InputType::Array(ty->id())}, out_ty);
+ DCHECK_OK(func->AddKernel(base));
+ }
+
+ for (auto t : {Type::DECIMAL128, Type::DECIMAL256}) {
+ base.init = GetHashInit<Action>(t);
+ base.signature = KernelSignature::Make({InputType::Array(t)}, out_ty);
+ DCHECK_OK(func->AddKernel(base));
+ }
+}
+
+const FunctionDoc unique_doc(
+ "Compute unique elements",
+ ("Return an array with distinct values. Nulls in the input are ignored."),
+ {"array"});
+
+const FunctionDoc value_counts_doc(
+ "Compute counts of unique elements",
+ ("For each distinct value, compute the number of times it occurs in the array.\n"
+ "The result is returned as an array of `struct<input type, int64>`.\n"
+ "Nulls in the input are ignored."),
+ {"array"});
+
+const auto kDefaultDictionaryEncodeOptions = DictionaryEncodeOptions::Defaults();
+const FunctionDoc dictionary_encode_doc(
+ "Dictionary-encode array",
+ ("Return a dictionary-encoded version of the input array."), {"array"},
+ "DictionaryEncodeOptions");
+
+} // namespace
+
+void RegisterVectorHash(FunctionRegistry* registry) {
+ VectorKernel base;
+ base.exec = HashExec;
+
+ // ----------------------------------------------------------------------
+ // unique
+
+ base.finalize = UniqueFinalize;
+ base.output_chunked = false;
+ auto unique = std::make_shared<VectorFunction>("unique", Arity::Unary(), &unique_doc);
+ AddHashKernels<UniqueAction>(unique.get(), base, OutputType(FirstType));
+
+ // Dictionary unique
+ base.init = DictionaryHashInit<UniqueAction>;
+ base.finalize = UniqueFinalizeDictionary;
+ base.signature =
+ KernelSignature::Make({InputType::Array(Type::DICTIONARY)}, OutputType(FirstType));
+ DCHECK_OK(unique->AddKernel(base));
+
+ DCHECK_OK(registry->AddFunction(std::move(unique)));
+
+ // ----------------------------------------------------------------------
+ // value_counts
+
+ base.finalize = ValueCountsFinalize;
+ auto value_counts =
+ std::make_shared<VectorFunction>("value_counts", Arity::Unary(), &value_counts_doc);
+ AddHashKernels<ValueCountsAction>(value_counts.get(), base,
+ OutputType(ValueCountsOutput));
+
+ // Dictionary value counts
+ base.init = DictionaryHashInit<ValueCountsAction>;
+ base.finalize = ValueCountsFinalizeDictionary;
+ base.signature = KernelSignature::Make({InputType::Array(Type::DICTIONARY)},
+ OutputType(ValueCountsOutput));
+ DCHECK_OK(value_counts->AddKernel(base));
+
+ DCHECK_OK(registry->AddFunction(std::move(value_counts)));
+
+ // ----------------------------------------------------------------------
+ // dictionary_encode
+
+ base.finalize = DictEncodeFinalize;
+ // Unique and ValueCounts output unchunked arrays
+ base.output_chunked = true;
+ auto dict_encode = std::make_shared<VectorFunction>("dictionary_encode", Arity::Unary(),
+ &dictionary_encode_doc,
+ &kDefaultDictionaryEncodeOptions);
+ AddHashKernels<DictEncodeAction>(dict_encode.get(), base, OutputType(DictEncodeOutput));
+
+ // Calling dictionary_encode on dictionary input not supported, but if it
+ // ends up being needed (or convenience), a kernel could be added to make it
+ // a no-op
+
+ DCHECK_OK(registry->AddFunction(std::move(dict_encode)));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_hash_benchmark.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_hash_benchmark.cc
new file mode 100644
index 000000000..3be549d05
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_hash_benchmark.cc
@@ -0,0 +1,250 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <vector>
+
+#include "arrow/array/builder_binary.h"
+#include "arrow/memory_pool.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+
+#include "arrow/compute/api.h"
+
+namespace arrow {
+namespace compute {
+
+static void BuildDictionary(benchmark::State& state) { // NOLINT non-const reference
+ const int64_t iterations = 1024;
+
+ std::vector<int64_t> values;
+ std::vector<bool> is_valid;
+ for (int64_t i = 0; i < iterations; i++) {
+ for (int64_t j = 0; j < i; j++) {
+ is_valid.push_back((i + j) % 9 != 0);
+ values.push_back(j);
+ }
+ }
+
+ std::shared_ptr<Array> arr;
+ ArrayFromVector<Int64Type, int64_t>(is_valid, values, &arr);
+
+ while (state.KeepRunning()) {
+ ABORT_NOT_OK(DictionaryEncode(arr).status());
+ }
+ state.counters["null_percent"] =
+ static_cast<double>(arr->null_count()) / arr->length() * 100;
+ state.SetBytesProcessed(state.iterations() * values.size() * sizeof(int64_t));
+ state.SetItemsProcessed(state.iterations() * values.size());
+}
+
+static void BuildStringDictionary(
+ benchmark::State& state) { // NOLINT non-const reference
+ const int64_t iterations = 1024 * 64;
+ // Pre-render strings
+ std::vector<std::string> data;
+
+ int64_t total_bytes = 0;
+ for (int64_t i = 0; i < iterations; i++) {
+ std::stringstream ss;
+ ss << i;
+ auto val = ss.str();
+ data.push_back(val);
+ total_bytes += static_cast<int64_t>(val.size());
+ }
+
+ std::shared_ptr<Array> arr;
+ ArrayFromVector<StringType, std::string>(data, &arr);
+
+ while (state.KeepRunning()) {
+ ABORT_NOT_OK(DictionaryEncode(arr).status());
+ }
+ state.SetBytesProcessed(state.iterations() * total_bytes);
+ state.SetItemsProcessed(state.iterations() * data.size());
+}
+
+struct HashBenchCase {
+ int64_t length;
+ int64_t num_unique;
+ double null_probability;
+};
+
+template <typename Type>
+struct HashParams {
+ using T = typename Type::c_type;
+
+ HashBenchCase params;
+
+ void GenerateTestData(std::shared_ptr<Array>* arr) const {
+ std::vector<int64_t> draws;
+ std::vector<T> values;
+ std::vector<bool> is_valid;
+ randint<int64_t>(params.length, 0, params.num_unique, &draws);
+ for (int64_t draw : draws) {
+ values.push_back(static_cast<T>(draw));
+ }
+ if (params.null_probability > 0) {
+ random_is_valid(params.length, params.null_probability, &is_valid);
+ ArrayFromVector<Type, T>(is_valid, values, arr);
+ } else {
+ ArrayFromVector<Type, T>(values, arr);
+ }
+ }
+
+ void SetMetadata(benchmark::State& state) const {
+ state.counters["null_percent"] = params.null_probability * 100;
+ state.counters["num_unique"] = static_cast<double>(params.num_unique);
+ state.SetBytesProcessed(state.iterations() * params.length * sizeof(T));
+ state.SetItemsProcessed(state.iterations() * params.length);
+ }
+};
+
+template <>
+struct HashParams<StringType> {
+ HashBenchCase params;
+ int32_t byte_width;
+ void GenerateTestData(std::shared_ptr<Array>* arr) const {
+ std::vector<int64_t> draws;
+ randint<int64_t>(params.length, 0, params.num_unique, &draws);
+
+ const int64_t total_bytes = this->byte_width * params.num_unique;
+ std::vector<uint8_t> uniques(total_bytes);
+ const uint32_t seed = 0;
+ random_bytes(total_bytes, seed, uniques.data());
+
+ std::vector<bool> is_valid;
+ if (params.null_probability > 0) {
+ random_is_valid(params.length, params.null_probability, &is_valid);
+ }
+
+ StringBuilder builder;
+ for (int64_t i = 0; i < params.length; ++i) {
+ if (params.null_probability == 0 || is_valid[i]) {
+ ABORT_NOT_OK(builder.Append(uniques.data() + this->byte_width * draws[i],
+ this->byte_width));
+ } else {
+ ABORT_NOT_OK(builder.AppendNull());
+ }
+ }
+ ABORT_NOT_OK(builder.Finish(arr));
+ }
+
+ void SetMetadata(benchmark::State& state) const {
+ state.counters["null_percent"] = params.null_probability * 100;
+ state.counters["num_unique"] = static_cast<double>(params.num_unique);
+ state.SetBytesProcessed(state.iterations() * params.length * byte_width);
+ state.SetItemsProcessed(state.iterations() * params.length);
+ }
+};
+
+template <typename ParamType>
+void BenchUnique(benchmark::State& state, const ParamType& params) {
+ std::shared_ptr<Array> arr;
+ params.GenerateTestData(&arr);
+
+ while (state.KeepRunning()) {
+ ABORT_NOT_OK(Unique(arr).status());
+ }
+ params.SetMetadata(state);
+}
+
+template <typename ParamType>
+void BenchDictionaryEncode(benchmark::State& state, const ParamType& params) {
+ std::shared_ptr<Array> arr;
+ params.GenerateTestData(&arr);
+ while (state.KeepRunning()) {
+ ABORT_NOT_OK(DictionaryEncode(arr).status());
+ }
+ params.SetMetadata(state);
+}
+
+constexpr int kHashBenchmarkLength = 1 << 22;
+
+// clang-format off
+std::vector<HashBenchCase> uint8_bench_cases = {
+ {kHashBenchmarkLength, 200, 0},
+ {kHashBenchmarkLength, 200, 0.001},
+ {kHashBenchmarkLength, 200, 0.01},
+ {kHashBenchmarkLength, 200, 0.1},
+ {kHashBenchmarkLength, 200, 0.5},
+ {kHashBenchmarkLength, 200, 0.99},
+ {kHashBenchmarkLength, 200, 1}
+};
+// clang-format on
+
+static void UniqueUInt8(benchmark::State& state) {
+ BenchUnique(state, HashParams<UInt8Type>{uint8_bench_cases[state.range(0)]});
+}
+
+// clang-format off
+std::vector<HashBenchCase> general_bench_cases = {
+ {kHashBenchmarkLength, 100, 0},
+ {kHashBenchmarkLength, 100, 0.001},
+ {kHashBenchmarkLength, 100, 0.01},
+ {kHashBenchmarkLength, 100, 0.1},
+ {kHashBenchmarkLength, 100, 0.5},
+ {kHashBenchmarkLength, 100, 0.99},
+ {kHashBenchmarkLength, 100, 1},
+ {kHashBenchmarkLength, 100000, 0},
+ {kHashBenchmarkLength, 100000, 0.001},
+ {kHashBenchmarkLength, 100000, 0.01},
+ {kHashBenchmarkLength, 100000, 0.1},
+ {kHashBenchmarkLength, 100000, 0.5},
+ {kHashBenchmarkLength, 100000, 0.99},
+ {kHashBenchmarkLength, 100000, 1},
+};
+// clang-format on
+
+static void UniqueInt64(benchmark::State& state) {
+ BenchUnique(state, HashParams<Int64Type>{general_bench_cases[state.range(0)]});
+}
+
+static void UniqueString10bytes(benchmark::State& state) {
+ // Byte strings with 10 bytes each
+ BenchUnique(state, HashParams<StringType>{general_bench_cases[state.range(0)], 10});
+}
+
+static void UniqueString100bytes(benchmark::State& state) {
+ // Byte strings with 100 bytes each
+ BenchUnique(state, HashParams<StringType>{general_bench_cases[state.range(0)], 100});
+}
+
+void HashSetArgs(benchmark::internal::Benchmark* bench) {
+ for (int i = 0; i < static_cast<int>(general_bench_cases.size()); ++i) {
+ bench->Arg(i);
+ }
+}
+
+BENCHMARK(BuildDictionary);
+BENCHMARK(BuildStringDictionary);
+
+BENCHMARK(UniqueInt64)->Apply(HashSetArgs);
+BENCHMARK(UniqueString10bytes)->Apply(HashSetArgs);
+BENCHMARK(UniqueString100bytes)->Apply(HashSetArgs);
+
+void UInt8SetArgs(benchmark::internal::Benchmark* bench) {
+ for (int i = 0; i < static_cast<int>(uint8_bench_cases.size()); ++i) {
+ bench->Arg(i);
+ }
+}
+
+BENCHMARK(UniqueUInt8)->Apply(UInt8SetArgs);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_hash_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_hash_test.cc
new file mode 100644
index 000000000..a10667e49
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_hash_test.cc
@@ -0,0 +1,756 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdio>
+#include <functional>
+#include <locale>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_decimal.h"
+#include "arrow/buffer.h"
+#include "arrow/chunked_array.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+
+#include "arrow/compute/api.h"
+#include "arrow/compute/kernels/test_util.h"
+
+#include "arrow/ipc/json_simple.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+// ----------------------------------------------------------------------
+// Dictionary tests
+
+template <typename T>
+void CheckUnique(const std::shared_ptr<T>& input,
+ const std::shared_ptr<Array>& expected) {
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> result, Unique(input));
+ ValidateOutput(*result);
+ // TODO: We probably shouldn't rely on array ordering.
+ ASSERT_ARRAYS_EQUAL(*expected, *result);
+}
+
+template <typename Type, typename T>
+void CheckUnique(const std::shared_ptr<DataType>& type, const std::vector<T>& in_values,
+ const std::vector<bool>& in_is_valid, const std::vector<T>& out_values,
+ const std::vector<bool>& out_is_valid) {
+ std::shared_ptr<Array> input = _MakeArray<Type, T>(type, in_values, in_is_valid);
+ std::shared_ptr<Array> expected = _MakeArray<Type, T>(type, out_values, out_is_valid);
+ CheckUnique(input, expected);
+}
+
+// Check that ValueCounts() accepts a 0-length array with null buffers
+void CheckValueCountsNull(const std::shared_ptr<DataType>& type) {
+ std::vector<std::shared_ptr<Buffer>> data_buffers(2);
+ Datum input;
+ input.value =
+ ArrayData::Make(type, 0 /* length */, std::move(data_buffers), 0 /* null_count */);
+
+ std::shared_ptr<Array> ex_values = ArrayFromJSON(type, "[]");
+ std::shared_ptr<Array> ex_counts = ArrayFromJSON(int64(), "[]");
+
+ ASSERT_OK_AND_ASSIGN(auto result_struct, ValueCounts(input));
+ ValidateOutput(*result_struct);
+ ASSERT_NE(result_struct->GetFieldByName(kValuesFieldName), nullptr);
+ // TODO: We probably shouldn't rely on value ordering.
+ ASSERT_ARRAYS_EQUAL(*ex_values, *result_struct->GetFieldByName(kValuesFieldName));
+ ASSERT_ARRAYS_EQUAL(*ex_counts, *result_struct->GetFieldByName(kCountsFieldName));
+}
+
+template <typename T>
+void CheckValueCounts(const std::shared_ptr<T>& input,
+ const std::shared_ptr<Array>& expected_values,
+ const std::shared_ptr<Array>& expected_counts) {
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> result, ValueCounts(input));
+ ValidateOutput(*result);
+ auto result_struct = std::dynamic_pointer_cast<StructArray>(result);
+ ASSERT_EQ(result_struct->num_fields(), 2);
+ // TODO: We probably shouldn't rely on value ordering.
+ ASSERT_ARRAYS_EQUAL(*expected_values, *result_struct->field(kValuesFieldIndex));
+ ASSERT_ARRAYS_EQUAL(*expected_counts, *result_struct->field(kCountsFieldIndex));
+}
+
+template <typename Type, typename T>
+void CheckValueCounts(const std::shared_ptr<DataType>& type,
+ const std::vector<T>& in_values,
+ const std::vector<bool>& in_is_valid,
+ const std::vector<T>& out_values,
+ const std::vector<bool>& out_is_valid,
+ const std::vector<int64_t>& out_counts) {
+ std::vector<bool> all_valids(out_is_valid.size(), true);
+ std::shared_ptr<Array> input = _MakeArray<Type, T>(type, in_values, in_is_valid);
+ std::shared_ptr<Array> ex_values = _MakeArray<Type, T>(type, out_values, out_is_valid);
+ std::shared_ptr<Array> ex_counts =
+ _MakeArray<Int64Type, int64_t>(int64(), out_counts, all_valids);
+
+ CheckValueCounts(input, ex_values, ex_counts);
+}
+
+void CheckDictEncode(const std::shared_ptr<Array>& input,
+ const std::shared_ptr<Array>& expected_values,
+ const std::shared_ptr<Array>& expected_indices) {
+ auto type = dictionary(expected_indices->type(), expected_values->type());
+ DictionaryArray expected(type, expected_indices, expected_values);
+
+ ASSERT_OK_AND_ASSIGN(Datum datum_out, DictionaryEncode(input));
+ std::shared_ptr<Array> result = MakeArray(datum_out.array());
+ ValidateOutput(*result);
+
+ ASSERT_ARRAYS_EQUAL(expected, *result);
+}
+
+template <typename Type, typename T>
+void CheckDictEncode(const std::shared_ptr<DataType>& type,
+ const std::vector<T>& in_values,
+ const std::vector<bool>& in_is_valid,
+ const std::vector<T>& out_values,
+ const std::vector<bool>& out_is_valid,
+ const std::vector<int32_t>& out_indices) {
+ std::shared_ptr<Array> input = _MakeArray<Type, T>(type, in_values, in_is_valid);
+ std::shared_ptr<Array> ex_dict = _MakeArray<Type, T>(type, out_values, out_is_valid);
+ std::shared_ptr<Array> ex_indices =
+ _MakeArray<Int32Type, int32_t>(int32(), out_indices, in_is_valid);
+ return CheckDictEncode(input, ex_dict, ex_indices);
+}
+
+class TestHashKernel : public ::testing::Test {};
+
+template <typename Type>
+class TestHashKernelPrimitive : public ::testing::Test {};
+
+typedef ::testing::Types<Int8Type, UInt8Type, Int16Type, UInt16Type, Int32Type,
+ UInt32Type, Int64Type, UInt64Type, FloatType, DoubleType,
+ Date32Type, Date64Type>
+ PrimitiveDictionaries;
+
+TYPED_TEST_SUITE(TestHashKernelPrimitive, PrimitiveDictionaries);
+
+TYPED_TEST(TestHashKernelPrimitive, Unique) {
+ using T = typename TypeParam::c_type;
+ auto type = TypeTraits<TypeParam>::type_singleton();
+ CheckUnique<TypeParam, T>(type, {2, 1, 2, 1}, {true, false, true, true}, {2, 0, 1},
+ {1, 0, 1});
+ CheckUnique<TypeParam, T>(type, {2, 1, 3, 1}, {false, false, true, true}, {0, 3, 1},
+ {0, 1, 1});
+
+ // Sliced
+ CheckUnique(ArrayFromJSON(type, "[1, 2, null, 3, 2, null]")->Slice(1, 4),
+ ArrayFromJSON(type, "[2, null, 3]"));
+}
+
+TYPED_TEST(TestHashKernelPrimitive, ValueCounts) {
+ using T = typename TypeParam::c_type;
+ auto type = TypeTraits<TypeParam>::type_singleton();
+ CheckValueCounts<TypeParam, T>(type, {2, 1, 2, 1, 2, 3, 4},
+ {true, false, true, true, true, true, false},
+ {2, 0, 1, 3}, {1, 0, 1, 1}, {3, 2, 1, 1});
+ CheckValueCounts<TypeParam, T>(type, {}, {}, {}, {}, {});
+ CheckValueCountsNull(type);
+
+ // Sliced
+ CheckValueCounts(ArrayFromJSON(type, "[1, 2, null, 3, 2, null]")->Slice(1, 4),
+ ArrayFromJSON(type, "[2, null, 3]"),
+ ArrayFromJSON(int64(), "[2, 1, 1]"));
+}
+
+TYPED_TEST(TestHashKernelPrimitive, DictEncode) {
+ using T = typename TypeParam::c_type;
+ auto type = TypeTraits<TypeParam>::type_singleton();
+ CheckDictEncode<TypeParam, T>(type, {2, 1, 2, 1, 2, 3},
+ {true, false, true, true, true, true}, {2, 1, 3},
+ {1, 1, 1}, {0, 0, 0, 1, 0, 2});
+
+ // Sliced
+ CheckDictEncode(ArrayFromJSON(type, "[2, 1, null, 4, 3, 1, 42]")->Slice(1, 5),
+ ArrayFromJSON(type, "[1, 4, 3]"),
+ ArrayFromJSON(int32(), "[0, null, 1, 2, 0]"));
+}
+
+TYPED_TEST(TestHashKernelPrimitive, ZeroChunks) {
+ auto type = TypeTraits<TypeParam>::type_singleton();
+
+ auto zero_chunks = std::make_shared<ChunkedArray>(ArrayVector{}, type);
+ ASSERT_OK_AND_ASSIGN(Datum result, DictionaryEncode(zero_chunks));
+
+ ASSERT_EQ(result.kind(), Datum::CHUNKED_ARRAY);
+ AssertChunkedEqual(*result.chunked_array(),
+ ChunkedArray({}, dictionary(int32(), type)));
+}
+
+TYPED_TEST(TestHashKernelPrimitive, PrimitiveResizeTable) {
+ using T = typename TypeParam::c_type;
+
+ const int64_t kTotalValues = std::min<int64_t>(INT16_MAX, 1UL << sizeof(T) / 2);
+ const int64_t kRepeats = 5;
+
+ std::vector<T> values;
+ std::vector<T> uniques;
+ std::vector<int32_t> indices;
+ std::vector<int64_t> counts;
+ for (int64_t i = 0; i < kTotalValues * kRepeats; i++) {
+ const auto val = static_cast<T>(i % kTotalValues);
+ values.push_back(val);
+
+ if (i < kTotalValues) {
+ uniques.push_back(val);
+ counts.push_back(kRepeats);
+ }
+ indices.push_back(static_cast<int32_t>(i % kTotalValues));
+ }
+
+ auto type = TypeTraits<TypeParam>::type_singleton();
+ CheckUnique<TypeParam, T>(type, values, {}, uniques, {});
+ CheckValueCounts<TypeParam, T>(type, values, {}, uniques, {}, counts);
+ CheckDictEncode<TypeParam, T>(type, values, {}, uniques, {}, indices);
+}
+
+TEST_F(TestHashKernel, UniqueTimeTimestamp) {
+ CheckUnique<Time32Type, int32_t>(time32(TimeUnit::SECOND), {2, 1, 2, 1},
+ {true, false, true, true}, {2, 0, 1}, {1, 0, 1});
+
+ CheckUnique<Time64Type, int64_t>(time64(TimeUnit::NANO), {2, 1, 2, 1},
+ {true, false, true, true}, {2, 0, 1}, {1, 0, 1});
+
+ CheckUnique<TimestampType, int64_t>(timestamp(TimeUnit::NANO), {2, 1, 2, 1},
+ {true, false, true, true}, {2, 0, 1}, {1, 0, 1});
+}
+
+TEST_F(TestHashKernel, ValueCountsTimeTimestamp) {
+ CheckValueCounts<Time32Type, int32_t>(time32(TimeUnit::SECOND), {2, 1, 2, 1},
+ {true, false, true, true}, {2, 0, 1}, {1, 0, 1},
+ {2, 1, 1});
+
+ CheckValueCounts<Time64Type, int64_t>(time64(TimeUnit::NANO), {2, 1, 2, 1},
+ {true, false, true, true}, {2, 0, 1}, {1, 0, 1},
+ {2, 1, 1});
+
+ CheckValueCounts<TimestampType, int64_t>(timestamp(TimeUnit::NANO), {2, 1, 2, 1},
+ {true, false, true, true}, {2, 0, 1},
+ {1, 0, 1}, {2, 1, 1});
+}
+
+TEST_F(TestHashKernel, UniqueBoolean) {
+ CheckUnique<BooleanType, bool>(boolean(), {true, true, false, true},
+ {true, false, true, true}, {true, false, false},
+ {1, 0, 1});
+
+ CheckUnique<BooleanType, bool>(boolean(), {false, true, false, true},
+ {true, false, true, true}, {false, false, true},
+ {1, 0, 1});
+
+ // No nulls
+ CheckUnique<BooleanType, bool>(boolean(), {true, true, false, true}, {}, {true, false},
+ {});
+
+ CheckUnique<BooleanType, bool>(boolean(), {false, true, false, true}, {}, {false, true},
+ {});
+
+ // Sliced
+ CheckUnique(ArrayFromJSON(boolean(), "[null, true, true, false]")->Slice(1, 2),
+ ArrayFromJSON(boolean(), "[true]"));
+}
+
+TEST_F(TestHashKernel, ValueCountsBoolean) {
+ CheckValueCounts<BooleanType, bool>(boolean(), {true, true, false, true},
+ {true, false, true, true}, {true, false, false},
+ {1, 0, 1}, {2, 1, 1});
+
+ CheckValueCounts<BooleanType, bool>(boolean(), {false, true, false, true},
+ {true, false, true, true}, {false, false, true},
+ {1, 0, 1}, {2, 1, 1});
+
+ // No nulls
+ CheckValueCounts<BooleanType, bool>(boolean(), {true, true, false, true}, {},
+ {true, false}, {}, {3, 1});
+
+ CheckValueCounts<BooleanType, bool>(boolean(), {false, true, false, true}, {},
+ {false, true}, {}, {2, 2});
+
+ // Sliced
+ CheckValueCounts(ArrayFromJSON(boolean(), "[true, false, false, null]")->Slice(1, 2),
+ ArrayFromJSON(boolean(), "[false]"), ArrayFromJSON(int64(), "[2]"));
+}
+
+TEST_F(TestHashKernel, ValueCountsNull) {
+ CheckValueCounts(ArrayFromJSON(null(), "[null, null, null]"),
+ ArrayFromJSON(null(), "[null]"), ArrayFromJSON(int64(), "[3]"));
+}
+
+TEST_F(TestHashKernel, DictEncodeBoolean) {
+ CheckDictEncode<BooleanType, bool>(boolean(), {true, true, false, true, false},
+ {true, false, true, true, true}, {true, false}, {},
+ {0, 0, 1, 0, 1});
+
+ CheckDictEncode<BooleanType, bool>(boolean(), {false, true, false, true, false},
+ {true, false, true, true, true}, {false, true}, {},
+ {0, 0, 0, 1, 0});
+
+ // No nulls
+ CheckDictEncode<BooleanType, bool>(boolean(), {true, true, false, true, false}, {},
+ {true, false}, {}, {0, 0, 1, 0, 1});
+
+ CheckDictEncode<BooleanType, bool>(boolean(), {false, true, false, true, false}, {},
+ {false, true}, {}, {0, 1, 0, 1, 0});
+
+ // Sliced
+ CheckDictEncode(
+ ArrayFromJSON(boolean(), "[false, true, null, true, false]")->Slice(1, 3),
+ ArrayFromJSON(boolean(), "[true]"), ArrayFromJSON(int32(), "[0, null, 0]"));
+}
+
+template <typename ArrowType>
+class TestHashKernelBinaryTypes : public TestHashKernel {
+ protected:
+ std::shared_ptr<DataType> type() { return TypeTraits<ArrowType>::type_singleton(); }
+
+ void CheckDictEncodeP(const std::vector<std::string>& in_values,
+ const std::vector<bool>& in_is_valid,
+ const std::vector<std::string>& out_values,
+ const std::vector<bool>& out_is_valid,
+ const std::vector<int32_t>& out_indices) {
+ CheckDictEncode<ArrowType, std::string>(type(), in_values, in_is_valid, out_values,
+ out_is_valid, out_indices);
+ }
+
+ void CheckValueCountsP(const std::vector<std::string>& in_values,
+ const std::vector<bool>& in_is_valid,
+ const std::vector<std::string>& out_values,
+ const std::vector<bool>& out_is_valid,
+ const std::vector<int64_t>& out_counts) {
+ CheckValueCounts<ArrowType, std::string>(type(), in_values, in_is_valid, out_values,
+ out_is_valid, out_counts);
+ }
+
+ void CheckUniqueP(const std::vector<std::string>& in_values,
+ const std::vector<bool>& in_is_valid,
+ const std::vector<std::string>& out_values,
+ const std::vector<bool>& out_is_valid) {
+ CheckUnique<ArrowType, std::string>(type(), in_values, in_is_valid, out_values,
+ out_is_valid);
+ }
+};
+
+TYPED_TEST_SUITE(TestHashKernelBinaryTypes, BinaryArrowTypes);
+
+TYPED_TEST(TestHashKernelBinaryTypes, ZeroChunks) {
+ auto type = this->type();
+
+ auto zero_chunks = std::make_shared<ChunkedArray>(ArrayVector{}, type);
+ ASSERT_OK_AND_ASSIGN(Datum result, DictionaryEncode(zero_chunks));
+
+ ASSERT_EQ(result.kind(), Datum::CHUNKED_ARRAY);
+ AssertChunkedEqual(*result.chunked_array(),
+ ChunkedArray({}, dictionary(int32(), type)));
+}
+
+TYPED_TEST(TestHashKernelBinaryTypes, TwoChunks) {
+ auto type = this->type();
+
+ auto two_chunks = std::make_shared<ChunkedArray>(
+ ArrayVector{
+ ArrayFromJSON(type, "[\"a\"]"),
+ ArrayFromJSON(type, "[\"b\"]"),
+ },
+ type);
+ ASSERT_OK_AND_ASSIGN(Datum result, DictionaryEncode(two_chunks));
+
+ auto dict_type = dictionary(int32(), type);
+ auto dictionary = ArrayFromJSON(type, R"(["a", "b"])");
+
+ auto chunk_0 = std::make_shared<DictionaryArray>(
+ dict_type, ArrayFromJSON(int32(), "[0]"), dictionary);
+ auto chunk_1 = std::make_shared<DictionaryArray>(
+ dict_type, ArrayFromJSON(int32(), "[1]"), dictionary);
+
+ ASSERT_EQ(result.kind(), Datum::CHUNKED_ARRAY);
+ AssertChunkedEqual(*result.chunked_array(),
+ ChunkedArray({chunk_0, chunk_1}, dict_type));
+}
+
+TYPED_TEST(TestHashKernelBinaryTypes, Unique) {
+ this->CheckUniqueP({"test", "", "test2", "test"}, {true, false, true, true},
+ {"test", "", "test2"}, {1, 0, 1});
+
+ // Sliced
+ CheckUnique(
+ ArrayFromJSON(this->type(), R"(["ab", null, "cd", "ef", "cd", "gh"])")->Slice(1, 4),
+ ArrayFromJSON(this->type(), R"([null, "cd", "ef"])"));
+}
+
+TYPED_TEST(TestHashKernelBinaryTypes, ValueCounts) {
+ this->CheckValueCountsP({"test", "", "test2", "test"}, {true, false, true, true},
+ {"test", "", "test2"}, {1, 0, 1}, {2, 1, 1});
+
+ // Sliced
+ CheckValueCounts(
+ ArrayFromJSON(this->type(), R"(["ab", null, "cd", "ab", "cd", "ef"])")->Slice(1, 4),
+ ArrayFromJSON(this->type(), R"([null, "cd", "ab"])"),
+ ArrayFromJSON(int64(), "[1, 2, 1]"));
+}
+
+TYPED_TEST(TestHashKernelBinaryTypes, DictEncode) {
+ this->CheckDictEncodeP({"test", "", "test2", "test", "baz"},
+ {true, false, true, true, true}, {"test", "test2", "baz"}, {},
+ {0, 0, 1, 0, 2});
+
+ // Sliced
+ CheckDictEncode(
+ ArrayFromJSON(this->type(), R"(["ab", null, "cd", "ab", "cd", "ef"])")->Slice(1, 4),
+ ArrayFromJSON(this->type(), R"(["cd", "ab"])"),
+ ArrayFromJSON(int32(), "[null, 0, 1, 0]"));
+}
+
+TYPED_TEST(TestHashKernelBinaryTypes, BinaryResizeTable) {
+ const int32_t kTotalValues = 10000;
+#if !defined(ARROW_VALGRIND)
+ const int32_t kRepeats = 10;
+#else
+ // Mitigate Valgrind's slowness
+ const int32_t kRepeats = 3;
+#endif
+
+ std::vector<std::string> values;
+ std::vector<std::string> uniques;
+ std::vector<int32_t> indices;
+ std::vector<int64_t> counts;
+ char buf[20] = "test";
+
+ for (int32_t i = 0; i < kTotalValues * kRepeats; i++) {
+ int32_t index = i % kTotalValues;
+
+ ASSERT_GE(snprintf(buf + 4, sizeof(buf) - 4, "%d", index), 0);
+ values.emplace_back(buf);
+
+ if (i < kTotalValues) {
+ uniques.push_back(values.back());
+ counts.push_back(kRepeats);
+ }
+ indices.push_back(index);
+ }
+
+ this->CheckUniqueP(values, {}, uniques, {});
+ this->CheckValueCountsP(values, {}, uniques, {}, counts);
+ this->CheckDictEncodeP(values, {}, uniques, {}, indices);
+}
+
+TEST_F(TestHashKernel, UniqueFixedSizeBinary) {
+ auto type = fixed_size_binary(3);
+
+ CheckUnique<FixedSizeBinaryType, std::string>(type, {"aaa", "", "bbb", "aaa"},
+ {true, false, true, true},
+ {"aaa", "", "bbb"}, {1, 0, 1});
+
+ // Sliced
+ CheckUnique(
+ ArrayFromJSON(type, R"(["aaa", null, "bbb", "bbb", "ccc", "ddd"])")->Slice(1, 4),
+ ArrayFromJSON(type, R"([null, "bbb", "ccc"])"));
+}
+
+TEST_F(TestHashKernel, ValueCountsFixedSizeBinary) {
+ auto type = fixed_size_binary(3);
+ auto input = ArrayFromJSON(type, R"(["aaa", null, "bbb", "bbb", "ccc", null])");
+
+ CheckValueCounts(input, ArrayFromJSON(type, R"(["aaa", null, "bbb", "ccc"])"),
+ ArrayFromJSON(int64(), "[1, 2, 2, 1]"));
+
+ // Sliced
+ CheckValueCounts(input->Slice(1, 4), ArrayFromJSON(type, R"([null, "bbb", "ccc"])"),
+ ArrayFromJSON(int64(), "[1, 2, 1]"));
+}
+
+TEST_F(TestHashKernel, DictEncodeFixedSizeBinary) {
+ auto type = fixed_size_binary(3);
+
+ CheckDictEncode<FixedSizeBinaryType, std::string>(
+ type, {"bbb", "", "bbb", "aaa", "ccc"}, {true, false, true, true, true},
+ {"bbb", "aaa", "ccc"}, {}, {0, 0, 0, 1, 2});
+
+ // Sliced
+ CheckDictEncode(
+ ArrayFromJSON(type, R"(["aaa", null, "bbb", "bbb", "ccc", "ddd"])")->Slice(1, 4),
+ ArrayFromJSON(type, R"(["bbb", "ccc"])"),
+ ArrayFromJSON(int32(), "[null, 0, 0, 1]"));
+}
+
+TEST_F(TestHashKernel, FixedSizeBinaryResizeTable) {
+ const int32_t kTotalValues = 10000;
+#if !defined(ARROW_VALGRIND)
+ const int32_t kRepeats = 10;
+#else
+ // Mitigate Valgrind's slowness
+ const int32_t kRepeats = 3;
+#endif
+
+ std::vector<std::string> values;
+ std::vector<std::string> uniques;
+ std::vector<int32_t> indices;
+ char buf[7] = "test..";
+
+ for (int32_t i = 0; i < kTotalValues * kRepeats; i++) {
+ int32_t index = i % kTotalValues;
+
+ buf[4] = static_cast<char>(index / 128);
+ buf[5] = static_cast<char>(index % 128);
+ values.emplace_back(buf, 6);
+
+ if (i < kTotalValues) {
+ uniques.push_back(values.back());
+ }
+ indices.push_back(index);
+ }
+
+ auto type = fixed_size_binary(6);
+ CheckUnique<FixedSizeBinaryType, std::string>(type, values, {}, uniques, {});
+ CheckDictEncode<FixedSizeBinaryType, std::string>(type, values, {}, uniques, {},
+ indices);
+}
+
+TEST_F(TestHashKernel, UniqueDecimal) {
+ std::vector<Decimal128> values{12, 12, 11, 12};
+ std::vector<Decimal128> expected{12, 0, 11};
+
+ CheckUnique<Decimal128Type, Decimal128>(decimal(2, 0), values,
+ {true, false, true, true}, expected, {1, 0, 1});
+}
+
+TEST_F(TestHashKernel, UniqueNull) {
+ CheckUnique<NullType, std::nullptr_t>(null(), {nullptr, nullptr}, {false, true},
+ {nullptr}, {false});
+ CheckUnique<NullType, std::nullptr_t>(null(), {}, {}, {}, {});
+}
+
+TEST_F(TestHashKernel, ValueCountsDecimal) {
+ std::vector<Decimal128> values{12, 12, 11, 12};
+ std::vector<Decimal128> expected{12, 0, 11};
+
+ CheckValueCounts<Decimal128Type, Decimal128>(
+ decimal(2, 0), values, {true, false, true, true}, expected, {1, 0, 1}, {2, 1, 1});
+}
+
+TEST_F(TestHashKernel, DictEncodeDecimal) {
+ std::vector<Decimal128> values{12, 12, 11, 12, 13};
+ std::vector<Decimal128> expected{12, 11, 13};
+
+ CheckDictEncode<Decimal128Type, Decimal128>(decimal(2, 0), values,
+ {true, false, true, true, true}, expected,
+ {}, {0, 0, 1, 0, 2});
+}
+
+TEST_F(TestHashKernel, DictionaryUniqueAndValueCounts) {
+ auto dict_json = "[10, 20, 30, 40]";
+ auto dict = ArrayFromJSON(int64(), dict_json);
+ for (auto index_ty : IntTypes()) {
+ auto indices = ArrayFromJSON(index_ty, "[3, 0, 0, 0, 1, 1, 3, 0, 1, 3, 0, 1]");
+
+ auto dict_ty = dictionary(index_ty, int64());
+
+ auto ex_indices = ArrayFromJSON(index_ty, "[3, 0, 1]");
+
+ auto input = std::make_shared<DictionaryArray>(dict_ty, indices, dict);
+ auto ex_uniques = std::make_shared<DictionaryArray>(dict_ty, ex_indices, dict);
+ CheckUnique(input, ex_uniques);
+
+ auto ex_counts = ArrayFromJSON(int64(), "[3, 5, 4]");
+ CheckValueCounts(input, ex_uniques, ex_counts);
+
+ // Empty array - executor never gives the kernel any batches,
+ // so result dictionary is empty
+ CheckUnique(DictArrayFromJSON(dict_ty, "[]", dict_json),
+ DictArrayFromJSON(dict_ty, "[]", "[]"));
+ CheckValueCounts(DictArrayFromJSON(dict_ty, "[]", dict_json),
+ DictArrayFromJSON(dict_ty, "[]", "[]"),
+ ArrayFromJSON(int64(), "[]"));
+
+ // Check chunked array
+ auto chunked = *ChunkedArray::Make({input->Slice(0, 2), input->Slice(2)});
+ CheckUnique(chunked, ex_uniques);
+ CheckValueCounts(chunked, ex_uniques, ex_counts);
+
+ // Different chunk dictionaries
+ auto input_2 = DictArrayFromJSON(dict_ty, "[1, null, 2, 3]", "[30, 40, 50, 60]");
+ auto ex_uniques_2 =
+ DictArrayFromJSON(dict_ty, "[3, 0, 1, null, 4, 5]", "[10, 20, 30, 40, 50, 60]");
+ auto ex_counts_2 = ArrayFromJSON(int64(), "[4, 5, 4, 1, 1, 1]");
+ auto different_dictionaries = *ChunkedArray::Make({input, input_2}, dict_ty);
+
+ CheckUnique(different_dictionaries, ex_uniques_2);
+ CheckValueCounts(different_dictionaries, ex_uniques_2, ex_counts_2);
+
+ // Dictionary with encoded nulls
+ auto dict_with_null = ArrayFromJSON(int64(), "[10, null, 30, 40]");
+ input = std::make_shared<DictionaryArray>(dict_ty, indices, dict_with_null);
+ ex_uniques = std::make_shared<DictionaryArray>(dict_ty, ex_indices, dict_with_null);
+ CheckUnique(input, ex_uniques);
+
+ CheckValueCounts(input, ex_uniques, ex_counts);
+
+ // Dictionary with masked nulls
+ auto indices_with_null =
+ ArrayFromJSON(index_ty, "[3, 0, 0, 0, null, null, 3, 0, null, 3, 0, null]");
+ auto ex_indices_with_null = ArrayFromJSON(index_ty, "[3, 0, null]");
+ ex_uniques = std::make_shared<DictionaryArray>(dict_ty, ex_indices_with_null, dict);
+ input = std::make_shared<DictionaryArray>(dict_ty, indices_with_null, dict);
+ CheckUnique(input, ex_uniques);
+
+ CheckValueCounts(input, ex_uniques, ex_counts);
+
+ // Dictionary with encoded AND masked nulls
+ auto some_indices_with_null =
+ ArrayFromJSON(index_ty, "[3, 0, 0, 0, 1, 1, 3, 0, null, 3, 0, null]");
+ ex_uniques =
+ std::make_shared<DictionaryArray>(dict_ty, ex_indices_with_null, dict_with_null);
+ input = std::make_shared<DictionaryArray>(dict_ty, indices_with_null, dict_with_null);
+ CheckUnique(input, ex_uniques);
+ CheckValueCounts(input, ex_uniques, ex_counts);
+ }
+}
+
+/* TODO(ARROW-4124): Determine if we want to do something that is reproducible with
+ * floats.
+TEST_F(TestHashKernel, ValueCountsFloat) {
+
+ // No nulls
+ CheckValueCounts<FloatType, float>(float32(), {1.0f, 0.0f, -0.0f,
+std::nan("1"), std::nan("2") },
+ {}, {0.0f, 1.0f, std::nan("1")}, {}, {});
+
+ CheckValueCounts<DoubleType, double>(float64(), {1.0f, 0.0f, -0.0f,
+std::nan("1"), std::nan("2") },
+ {}, {0.0f, 1.0f, std::nan("1")}, {}, {});
+}
+*/
+
+TEST_F(TestHashKernel, ChunkedArrayInvoke) {
+ std::vector<std::string> values1 = {"foo", "bar", "foo"};
+ std::vector<std::string> values2 = {"bar", "baz", "quuux", "foo"};
+
+ auto type = utf8();
+ auto a1 = _MakeArray<StringType, std::string>(type, values1, {});
+ auto a2 = _MakeArray<StringType, std::string>(type, values2, {});
+
+ std::vector<std::string> dict_values = {"foo", "bar", "baz", "quuux"};
+ auto ex_dict = _MakeArray<StringType, std::string>(type, dict_values, {});
+
+ auto ex_counts = _MakeArray<Int64Type, int64_t>(int64(), {3, 2, 1, 1}, {});
+
+ ArrayVector arrays = {a1, a2};
+ auto carr = std::make_shared<ChunkedArray>(arrays);
+
+ // Unique
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> result, Unique(carr));
+ ASSERT_ARRAYS_EQUAL(*ex_dict, *result);
+
+ // Dictionary encode
+ auto dict_type = dictionary(int32(), type);
+
+ auto i1 = _MakeArray<Int32Type, int32_t>(int32(), {0, 1, 0}, {});
+ auto i2 = _MakeArray<Int32Type, int32_t>(int32(), {1, 2, 3, 0}, {});
+
+ ArrayVector dict_arrays = {std::make_shared<DictionaryArray>(dict_type, i1, ex_dict),
+ std::make_shared<DictionaryArray>(dict_type, i2, ex_dict)};
+ auto dict_carr = std::make_shared<ChunkedArray>(dict_arrays);
+
+ // Unique counts
+ ASSERT_OK_AND_ASSIGN(auto counts, ValueCounts(carr));
+ ASSERT_ARRAYS_EQUAL(*ex_dict, *counts->field(0));
+ ASSERT_ARRAYS_EQUAL(*ex_counts, *counts->field(1));
+
+ // Dictionary encode
+ ASSERT_OK_AND_ASSIGN(Datum encoded_out, DictionaryEncode(carr));
+ ASSERT_EQ(Datum::CHUNKED_ARRAY, encoded_out.kind());
+
+ AssertChunkedEqual(*dict_carr, *encoded_out.chunked_array());
+}
+
+TEST_F(TestHashKernel, ZeroLengthDictionaryEncode) {
+ // ARROW-7008
+ auto values = ArrayFromJSON(utf8(), "[]");
+ ASSERT_OK_AND_ASSIGN(Datum datum_result, DictionaryEncode(values));
+ ValidateOutput(datum_result);
+}
+
+TEST_F(TestHashKernel, NullEncodingSchemes) {
+ auto values = ArrayFromJSON(uint8(), "[1, 1, null, 2, null]");
+
+ // Masking should put null in the indices array
+ auto expected_mask_indices = ArrayFromJSON(int32(), "[0, 0, null, 1, null]");
+ auto expected_mask_dictionary = ArrayFromJSON(uint8(), "[1, 2]");
+ auto dictionary_type = dictionary(int32(), uint8());
+ std::shared_ptr<Array> expected = std::make_shared<DictionaryArray>(
+ dictionary_type, expected_mask_indices, expected_mask_dictionary);
+
+ ASSERT_OK_AND_ASSIGN(Datum datum_result, DictionaryEncode(values));
+ std::shared_ptr<Array> result = datum_result.make_array();
+ AssertArraysEqual(*expected, *result);
+
+ // Encoding should put null in the dictionary
+ auto expected_encoded_indices = ArrayFromJSON(int32(), "[0, 0, 1, 2, 1]");
+ auto expected_encoded_dict = ArrayFromJSON(uint8(), "[1, null, 2]");
+ expected = std::make_shared<DictionaryArray>(dictionary_type, expected_encoded_indices,
+ expected_encoded_dict);
+
+ auto options = DictionaryEncodeOptions::Defaults();
+ options.null_encoding_behavior = DictionaryEncodeOptions::ENCODE;
+ ASSERT_OK_AND_ASSIGN(datum_result, DictionaryEncode(values, options));
+ result = datum_result.make_array();
+ AssertArraysEqual(*expected, *result);
+}
+
+TEST_F(TestHashKernel, ChunkedArrayZeroChunk) {
+ // ARROW-6857
+ auto chunked_array = std::make_shared<ChunkedArray>(ArrayVector{}, utf8());
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> result_array, Unique(chunked_array));
+ auto expected = ArrayFromJSON(chunked_array->type(), "[]");
+ AssertArraysEqual(*expected, *result_array);
+
+ ASSERT_OK_AND_ASSIGN(result_array, ValueCounts(chunked_array));
+ expected = ArrayFromJSON(struct_({field(kValuesFieldName, chunked_array->type()),
+ field(kCountsFieldName, int64())}),
+ "[]");
+ AssertArraysEqual(*expected, *result_array);
+
+ ASSERT_OK_AND_ASSIGN(Datum result_datum, DictionaryEncode(chunked_array));
+ auto dict_type = dictionary(int32(), chunked_array->type());
+ ASSERT_EQ(result_datum.kind(), Datum::CHUNKED_ARRAY);
+
+ AssertChunkedEqual(*std::make_shared<ChunkedArray>(ArrayVector{}, dict_type),
+ *result_datum.chunked_array());
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_nested.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_nested.cc
new file mode 100644
index 000000000..f4c61ba74
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_nested.cc
@@ -0,0 +1,194 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Vector kernels involving nested types
+
+#include "arrow/array/array_base.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/result.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+namespace {
+
+template <typename Type>
+Status ListFlatten(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ typename TypeTraits<Type>::ArrayType list_array(batch[0].array());
+ ARROW_ASSIGN_OR_RAISE(auto result, list_array.Flatten(ctx->memory_pool()));
+ out->value = result->data();
+ return Status::OK();
+}
+
+struct ListParentIndicesArray {
+ KernelContext* ctx;
+ const std::shared_ptr<ArrayData>& input;
+ int64_t base_output_offset;
+ std::shared_ptr<ArrayData> out;
+
+ template <typename Type, typename offset_type = typename Type::offset_type>
+ Status VisitList(const Type&) {
+ typename TypeTraits<Type>::ArrayType list(input);
+
+ const offset_type* offsets = list.raw_value_offsets();
+ offset_type values_length = offsets[list.length()] - offsets[0];
+
+ ARROW_ASSIGN_OR_RAISE(auto indices,
+ ctx->Allocate(values_length * sizeof(offset_type)));
+ auto out_indices = reinterpret_cast<offset_type*>(indices->mutable_data());
+ for (int64_t i = 0; i < list.length(); ++i) {
+ // Note: In most cases, null slots are empty, but when they are non-empty
+ // we write out the indices so make sure they are accounted for. This
+ // behavior could be changed if needed in the future.
+ for (offset_type j = offsets[i]; j < offsets[i + 1]; ++j) {
+ *out_indices++ = static_cast<offset_type>(i + base_output_offset);
+ }
+ }
+
+ BufferVector buffers{nullptr, std::move(indices)};
+ int64_t null_count = 0;
+ if (sizeof(offset_type) == 4) {
+ out = std::make_shared<ArrayData>(int32(), values_length, std::move(buffers),
+ null_count);
+ } else {
+ out = std::make_shared<ArrayData>(int64(), values_length, std::move(buffers),
+ null_count);
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const ListType& type) { return VisitList(type); }
+
+ Status Visit(const LargeListType& type) { return VisitList(type); }
+
+ Status Visit(const FixedSizeListType& type) {
+ using offset_type = typename FixedSizeListType::offset_type;
+ const offset_type slot_length = type.list_size();
+ const int64_t values_length = slot_length * (input->length - input->GetNullCount());
+ ARROW_ASSIGN_OR_RAISE(auto indices, ctx->Allocate(values_length * sizeof(int32_t)));
+ auto* out_indices = reinterpret_cast<offset_type*>(indices->mutable_data());
+ const auto* bitmap = input->GetValues<uint8_t>(0, 0);
+ for (int32_t i = 0; i < input->length; i++) {
+ if (!bitmap || BitUtil::GetBit(bitmap, input->offset + i)) {
+ std::fill(out_indices, out_indices + slot_length,
+ static_cast<int32_t>(base_output_offset + i));
+ out_indices += slot_length;
+ }
+ }
+ out = ArrayData::Make(int32(), values_length, {nullptr, std::move(indices)},
+ /*null_count=*/0);
+ return Status::OK();
+ }
+
+ Status Visit(const DataType& type) {
+ return Status::TypeError("Function 'list_parent_indices' expects list input, got ",
+ type.ToString());
+ }
+
+ static Result<std::shared_ptr<ArrayData>> Exec(KernelContext* ctx,
+ const std::shared_ptr<ArrayData>& input,
+ int64_t base_output_offset = 0) {
+ ListParentIndicesArray self{ctx, input, base_output_offset, /*out=*/nullptr};
+ RETURN_NOT_OK(VisitTypeInline(*input->type, &self));
+ DCHECK_NE(self.out, nullptr);
+ return self.out;
+ }
+};
+
+Result<std::shared_ptr<DataType>> ListParentIndicesType(const DataType& input_type) {
+ switch (input_type.id()) {
+ case Type::LIST:
+ case Type::FIXED_SIZE_LIST:
+ return int32();
+ case Type::LARGE_LIST:
+ return int64();
+ default:
+ return Status::TypeError("Function 'list_parent_indices' expects list input, got ",
+ input_type.ToString());
+ }
+}
+
+const FunctionDoc list_flatten_doc(
+ "Flatten list values",
+ ("`lists` must have a list-like type.\n"
+ "Return an array with the top list level flattened.\n"
+ "Top-level null values in `lists` do not emit anything in the input."),
+ {"lists"});
+
+const FunctionDoc list_parent_indices_doc(
+ "Compute parent indices of nested list values",
+ ("`lists` must have a list-like type.\n"
+ "For each value in each list of `lists`, the top-level list index\n"
+ "is emitted."),
+ {"lists"});
+
+class ListParentIndicesFunction : public MetaFunction {
+ public:
+ ListParentIndicesFunction()
+ : MetaFunction("list_parent_indices", Arity::Unary(), &list_parent_indices_doc) {}
+
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options,
+ ExecContext* ctx) const override {
+ KernelContext kernel_ctx(ctx);
+ switch (args[0].kind()) {
+ case Datum::ARRAY:
+ return ListParentIndicesArray::Exec(&kernel_ctx, args[0].array());
+ case Datum::CHUNKED_ARRAY: {
+ const auto& input = args[0].chunked_array();
+ ARROW_ASSIGN_OR_RAISE(auto out_ty, ListParentIndicesType(*input->type()));
+
+ int64_t base_output_offset = 0;
+ ArrayVector out_chunks;
+ for (const auto& chunk : input->chunks()) {
+ ARROW_ASSIGN_OR_RAISE(auto out_chunk,
+ ListParentIndicesArray::Exec(&kernel_ctx, chunk->data(),
+ base_output_offset));
+ out_chunks.push_back(MakeArray(std::move(out_chunk)));
+ base_output_offset += chunk->length();
+ }
+ return std::make_shared<ChunkedArray>(std::move(out_chunks), std::move(out_ty));
+ }
+ default:
+ return Status::NotImplemented(
+ "Unsupported input type for function 'list_parent_indices': ",
+ args[0].ToString());
+ }
+ }
+};
+
+} // namespace
+
+void RegisterVectorNested(FunctionRegistry* registry) {
+ auto flatten =
+ std::make_shared<VectorFunction>("list_flatten", Arity::Unary(), &list_flatten_doc);
+ DCHECK_OK(flatten->AddKernel({InputType::Array(Type::LIST)}, OutputType(ListValuesType),
+ ListFlatten<ListType>));
+ DCHECK_OK(flatten->AddKernel({InputType::Array(Type::FIXED_SIZE_LIST)},
+ OutputType(ListValuesType),
+ ListFlatten<FixedSizeListType>));
+ DCHECK_OK(flatten->AddKernel({InputType::Array(Type::LARGE_LIST)},
+ OutputType(ListValuesType), ListFlatten<LargeListType>));
+ DCHECK_OK(registry->AddFunction(std::move(flatten)));
+
+ DCHECK_OK(registry->AddFunction(std::make_shared<ListParentIndicesFunction>()));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_nested_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_nested_test.cc
new file mode 100644
index 000000000..28bb4bdfd
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_nested_test.cc
@@ -0,0 +1,132 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include "arrow/chunked_array.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/result.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+namespace compute {
+
+using arrow::internal::checked_cast;
+
+TEST(TestVectorNested, ListFlatten) {
+ for (auto ty : {list(int16()), large_list(int16())}) {
+ auto input = ArrayFromJSON(ty, "[[0, null, 1], null, [2, 3], []]");
+ auto expected = ArrayFromJSON(int16(), "[0, null, 1, 2, 3]");
+ CheckVectorUnary("list_flatten", input, expected);
+
+ // Construct a list with a non-empty null slot
+ TweakValidityBit(input, 0, false);
+ expected = ArrayFromJSON(int16(), "[2, 3]");
+ CheckVectorUnary("list_flatten", input, expected);
+ }
+}
+
+TEST(TestVectorNested, ListFlattenChunkedArray) {
+ for (auto ty : {list(int16()), large_list(int16())}) {
+ auto input = ChunkedArrayFromJSON(ty, {"[[0, null, 1], null]", "[[2, 3], []]"});
+ auto expected = ChunkedArrayFromJSON(int16(), {"[0, null, 1]", "[2, 3]"});
+ CheckVectorUnary("list_flatten", input, expected);
+
+ input = ChunkedArrayFromJSON(ty, {});
+ expected = ChunkedArrayFromJSON(int16(), {});
+ CheckVectorUnary("list_flatten", input, expected);
+ }
+}
+
+TEST(TestVectorNested, ListFlattenFixedSizeList) {
+ for (auto ty : {fixed_size_list(int16(), 2), fixed_size_list(uint32(), 2)}) {
+ const auto& out_ty = checked_cast<const FixedSizeListType&>(*ty).value_type();
+ {
+ auto input = ArrayFromJSON(ty, "[[0, null], null, [2, 3], [0, 42]]");
+ auto expected = ArrayFromJSON(out_ty, "[0, null, 2, 3, 0, 42]");
+ CheckVectorUnary("list_flatten", input, expected);
+ }
+
+ {
+ // Test a chunked array
+ auto input = ChunkedArrayFromJSON(ty, {"[[0, null], null]", "[[2, 3], [0, 42]]"});
+ auto expected = ChunkedArrayFromJSON(out_ty, {"[0, null]", "[2, 3, 0, 42]"});
+ CheckVectorUnary("list_flatten", input, expected);
+
+ input = ChunkedArrayFromJSON(ty, {});
+ expected = ChunkedArrayFromJSON(out_ty, {});
+ CheckVectorUnary("list_flatten", input, expected);
+ }
+ }
+}
+
+TEST(TestVectorNested, ListParentIndices) {
+ for (auto ty : {list(int16()), large_list(int16())}) {
+ auto input = ArrayFromJSON(ty, "[[0, null, 1], null, [2, 3], [], [4, 5]]");
+
+ auto out_ty = ty->id() == Type::LIST ? int32() : int64();
+ auto expected = ArrayFromJSON(out_ty, "[0, 0, 0, 2, 2, 4, 4]");
+ CheckVectorUnary("list_parent_indices", input, expected);
+ }
+
+ // Construct a list with a non-empty null slot
+ auto input = ArrayFromJSON(list(int16()), "[[0, null, 1], [0, 0], [2, 3], [], [4, 5]]");
+ TweakValidityBit(input, 1, false);
+ auto expected = ArrayFromJSON(int32(), "[0, 0, 0, 1, 1, 2, 2, 4, 4]");
+ CheckVectorUnary("list_parent_indices", input, expected);
+}
+
+TEST(TestVectorNested, ListParentIndicesChunkedArray) {
+ for (auto ty : {list(int16()), large_list(int16())}) {
+ auto input =
+ ChunkedArrayFromJSON(ty, {"[[0, null, 1], null]", "[[2, 3], [], [4, 5]]"});
+
+ auto out_ty = ty->id() == Type::LIST ? int32() : int64();
+ auto expected = ChunkedArrayFromJSON(out_ty, {"[0, 0, 0]", "[2, 2, 4, 4]"});
+ CheckVectorUnary("list_parent_indices", input, expected);
+
+ input = ChunkedArrayFromJSON(ty, {});
+ expected = ChunkedArrayFromJSON(out_ty, {});
+ CheckVectorUnary("list_parent_indices", input, expected);
+ }
+}
+
+TEST(TestVectorNested, ListParentIndicesFixedSizeList) {
+ for (auto ty : {fixed_size_list(int16(), 2), fixed_size_list(uint32(), 2)}) {
+ {
+ auto input = ArrayFromJSON(ty, "[[0, null], null, [1, 2], [3, 4], [null, 5]]");
+ auto expected = ArrayFromJSON(int32(), "[0, 0, 2, 2, 3, 3, 4, 4]");
+ CheckVectorUnary("list_parent_indices", input, expected);
+ }
+ {
+ // Test a chunked array
+ auto input =
+ ChunkedArrayFromJSON(ty, {"[[0, null], null, [1, 2]]", "[[3, 4], [null, 5]]"});
+ auto expected = ChunkedArrayFromJSON(int32(), {"[0, 0, 2, 2]", "[3, 3, 4, 4]"});
+ CheckVectorUnary("list_parent_indices", input, expected);
+
+ input = ChunkedArrayFromJSON(ty, {});
+ expected = ChunkedArrayFromJSON(int32(), {});
+ CheckVectorUnary("list_parent_indices", input, expected);
+ }
+ }
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_partition_benchmark.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_partition_benchmark.cc
new file mode 100644
index 000000000..ff009c655
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_partition_benchmark.cc
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/benchmark_util.h"
+
+namespace arrow {
+namespace compute {
+constexpr auto kSeed = 0x0ff1ce;
+
+static void NthToIndicesBenchmark(benchmark::State& state,
+ const std::shared_ptr<Array>& values, int64_t n) {
+ for (auto _ : state) {
+ ABORT_NOT_OK(NthToIndices(*values, n).status());
+ }
+ state.SetItemsProcessed(state.iterations() * values->length());
+}
+
+static void NthToIndicesInt64(benchmark::State& state) {
+ RegressionArgs args(state);
+
+ const int64_t array_size = args.size / sizeof(int64_t);
+ auto rand = random::RandomArrayGenerator(kSeed);
+
+ auto min = std::numeric_limits<int64_t>::min();
+ auto max = std::numeric_limits<int64_t>::max();
+ auto values = rand.Int64(array_size, min, max, args.null_proportion);
+
+ NthToIndicesBenchmark(state, values, array_size / 2);
+}
+
+BENCHMARK(NthToIndicesInt64)
+ ->Apply(RegressionSetArgs)
+ ->Args({1 << 20, 100})
+ ->Args({1 << 23, 100})
+ ->MinTime(1.0)
+ ->Unit(benchmark::TimeUnit::kNanosecond);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_replace.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_replace.cc
new file mode 100644
index 000000000..450f99d78
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_replace.cc
@@ -0,0 +1,541 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/util/bitmap_ops.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+namespace {
+
+Status ReplacementArrayTooShort(int64_t expected, int64_t actual) {
+ return Status::Invalid("Replacement array must be of appropriate length (expected ",
+ expected, " items but got ", actual, " items)");
+}
+
+// Helper to implement replace_with kernel with scalar mask for fixed-width types,
+// using callbacks to handle both bool and byte-sized types
+template <typename Functor>
+Status ReplaceWithScalarMask(KernelContext* ctx, const ArrayData& array,
+ const BooleanScalar& mask, const Datum& replacements,
+ ArrayData* output) {
+ Datum source = array;
+ if (!mask.is_valid) {
+ // Output = null
+ source = MakeNullScalar(output->type);
+ } else if (mask.value) {
+ // Output = replacement
+ source = replacements;
+ }
+ uint8_t* out_bitmap = output->buffers[0]->mutable_data();
+ uint8_t* out_values = output->buffers[1]->mutable_data();
+ const int64_t out_offset = output->offset;
+ if (source.is_array()) {
+ const ArrayData& in_data = *source.array();
+ if (in_data.length < array.length) {
+ return ReplacementArrayTooShort(array.length, in_data.length);
+ }
+ Functor::CopyData(*array.type, out_values, out_offset, in_data, /*in_offset=*/0,
+ array.length);
+ if (in_data.MayHaveNulls()) {
+ arrow::internal::CopyBitmap(in_data.buffers[0]->data(), in_data.offset,
+ array.length, out_bitmap, out_offset);
+ } else {
+ BitUtil::SetBitsTo(out_bitmap, out_offset, array.length, true);
+ }
+ } else {
+ const Scalar& in_data = *source.scalar();
+ Functor::CopyData(*array.type, out_values, out_offset, in_data, /*in_offset=*/0,
+ array.length);
+ BitUtil::SetBitsTo(out_bitmap, out_offset, array.length, in_data.is_valid);
+ }
+ return Status::OK();
+}
+
+struct CopyArrayBitmap {
+ const uint8_t* in_bitmap;
+ int64_t in_offset;
+
+ void CopyBitmap(uint8_t* out_bitmap, int64_t out_offset, int64_t offset,
+ int64_t length) const {
+ arrow::internal::CopyBitmap(in_bitmap, in_offset + offset, length, out_bitmap,
+ out_offset);
+ }
+
+ void SetBit(uint8_t* out_bitmap, int64_t out_offset, int64_t offset) const {
+ BitUtil::SetBitTo(out_bitmap, out_offset,
+ BitUtil::GetBit(in_bitmap, in_offset + offset));
+ }
+};
+
+struct CopyScalarBitmap {
+ const bool is_valid;
+
+ void CopyBitmap(uint8_t* out_bitmap, int64_t out_offset, int64_t offset,
+ int64_t length) const {
+ BitUtil::SetBitsTo(out_bitmap, out_offset, length, is_valid);
+ }
+
+ void SetBit(uint8_t* out_bitmap, int64_t out_offset, int64_t offset) const {
+ BitUtil::SetBitTo(out_bitmap, out_offset, is_valid);
+ }
+};
+
+// Helper to implement replace_with kernel with array mask for fixed-width types,
+// using callbacks to handle both bool and byte-sized types and to handle
+// scalar and array replacements
+template <typename Functor, typename Data, typename CopyBitmap>
+void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask,
+ const Data& replacements, bool replacements_bitmap,
+ const CopyBitmap& copy_bitmap, const uint8_t* mask_bitmap,
+ const uint8_t* mask_values, uint8_t* out_bitmap,
+ uint8_t* out_values, const int64_t out_offset) {
+ Functor::CopyData(*array.type, out_values, /*out_offset=*/0, array, /*in_offset=*/0,
+ array.length);
+ arrow::internal::OptionalBinaryBitBlockCounter counter(
+ mask_values, mask.offset, mask_bitmap, mask.offset, mask.length);
+ int64_t write_offset = 0;
+ int64_t replacements_offset = 0;
+ while (write_offset < array.length) {
+ BitBlockCount block = counter.NextAndBlock();
+ if (block.AllSet()) {
+ // Copy from replacement array
+ Functor::CopyData(*array.type, out_values, out_offset + write_offset, replacements,
+ replacements_offset, block.length);
+ if (replacements_bitmap) {
+ copy_bitmap.CopyBitmap(out_bitmap, out_offset + write_offset, replacements_offset,
+ block.length);
+ } else if (!replacements_bitmap && out_bitmap) {
+ BitUtil::SetBitsTo(out_bitmap, out_offset + write_offset, block.length, true);
+ }
+ replacements_offset += block.length;
+ } else if (block.popcount) {
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (BitUtil::GetBit(mask_values, write_offset + mask.offset + i) &&
+ (!mask_bitmap ||
+ BitUtil::GetBit(mask_bitmap, write_offset + mask.offset + i))) {
+ Functor::CopyData(*array.type, out_values, out_offset + write_offset + i,
+ replacements, replacements_offset, /*length=*/1);
+ if (replacements_bitmap) {
+ copy_bitmap.SetBit(out_bitmap, out_offset + write_offset + i,
+ replacements_offset);
+ }
+ replacements_offset++;
+ }
+ }
+ }
+ write_offset += block.length;
+ }
+}
+
+template <typename Functor>
+Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array,
+ const ArrayData& mask, const Datum& replacements,
+ ArrayData* output) {
+ const int64_t out_offset = output->offset;
+ uint8_t* out_bitmap = nullptr;
+ uint8_t* out_values = output->buffers[1]->mutable_data();
+ const uint8_t* mask_bitmap = mask.MayHaveNulls() ? mask.buffers[0]->data() : nullptr;
+ const uint8_t* mask_values = mask.buffers[1]->data();
+ const bool replacements_bitmap = replacements.is_array()
+ ? replacements.array()->MayHaveNulls()
+ : !replacements.scalar()->is_valid;
+ if (replacements.is_array()) {
+ // Check that we have enough replacement values
+ const int64_t replacements_length = replacements.array()->length;
+
+ BooleanArray mask_arr(mask.length, mask.buffers[1], mask.buffers[0], mask.null_count,
+ mask.offset);
+ const int64_t count = mask_arr.true_count();
+ if (count > replacements_length) {
+ return ReplacementArrayTooShort(count, replacements_length);
+ }
+ }
+ if (array.MayHaveNulls() || mask.MayHaveNulls() || replacements_bitmap) {
+ out_bitmap = output->buffers[0]->mutable_data();
+ output->null_count = -1;
+ if (array.MayHaveNulls()) {
+ // Copy array's bitmap
+ arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset, array.length,
+ out_bitmap, out_offset);
+ } else {
+ // Array has no bitmap but mask/replacements do, generate an all-valid bitmap
+ BitUtil::SetBitsTo(out_bitmap, out_offset, array.length, true);
+ }
+ } else {
+ BitUtil::SetBitsTo(output->buffers[0]->mutable_data(), out_offset, array.length,
+ true);
+ output->null_count = 0;
+ }
+
+ if (replacements.is_array()) {
+ const ArrayData& array_repl = *replacements.array();
+ ReplaceWithArrayMaskImpl<Functor>(
+ array, mask, array_repl, replacements_bitmap,
+ CopyArrayBitmap{replacements_bitmap ? array_repl.buffers[0]->data() : nullptr,
+ array_repl.offset},
+ mask_bitmap, mask_values, out_bitmap, out_values, out_offset);
+ } else {
+ const Scalar& scalar_repl = *replacements.scalar();
+ ReplaceWithArrayMaskImpl<Functor>(array, mask, scalar_repl, replacements_bitmap,
+ CopyScalarBitmap{scalar_repl.is_valid}, mask_bitmap,
+ mask_values, out_bitmap, out_values, out_offset);
+ }
+
+ if (mask.MayHaveNulls()) {
+ arrow::internal::BitmapAnd(out_bitmap, out_offset, mask.buffers[0]->data(),
+ mask.offset, array.length, out_offset, out_bitmap);
+ }
+ return Status::OK();
+}
+
+template <typename Type, typename Enable = void>
+struct ReplaceWithMask {};
+
+template <typename Type>
+struct ReplaceWithMask<Type, enable_if_number<Type>> {
+ using T = typename TypeTraits<Type>::CType;
+
+ static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset,
+ const ArrayData& in, const int64_t in_offset,
+ const int64_t length) {
+ const auto in_arr = in.GetValues<uint8_t>(1, (in_offset + in.offset) * sizeof(T));
+ std::memcpy(out + (out_offset * sizeof(T)), in_arr, length * sizeof(T));
+ }
+
+ static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset,
+ const Scalar& in, const int64_t in_offset, const int64_t length) {
+ T* begin = reinterpret_cast<T*>(out + (out_offset * sizeof(T)));
+ T* end = begin + length;
+ std::fill(begin, end, UnboxScalar<Type>::Unbox(in));
+ }
+
+ static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+ const BooleanScalar& mask, const Datum& replacements,
+ ArrayData* output) {
+ return ReplaceWithScalarMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+ output);
+ }
+
+ static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+ const ArrayData& mask, const Datum& replacements,
+ ArrayData* output) {
+ return ReplaceWithArrayMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+ output);
+ }
+};
+
+template <typename Type>
+struct ReplaceWithMask<Type, enable_if_boolean<Type>> {
+ static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset,
+ const ArrayData& in, const int64_t in_offset,
+ const int64_t length) {
+ const auto in_arr = in.GetValues<uint8_t>(1, /*absolute_offset=*/0);
+ arrow::internal::CopyBitmap(in_arr, in_offset + in.offset, length, out, out_offset);
+ }
+ static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset,
+ const Scalar& in, const int64_t in_offset, const int64_t length) {
+ BitUtil::SetBitsTo(out, out_offset, length, in.is_valid);
+ }
+
+ static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+ const BooleanScalar& mask, const Datum& replacements,
+ ArrayData* output) {
+ return ReplaceWithScalarMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+ output);
+ }
+ static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+ const ArrayData& mask, const Datum& replacements,
+ ArrayData* output) {
+ return ReplaceWithArrayMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+ output);
+ }
+};
+
+template <typename Type>
+struct ReplaceWithMask<Type, enable_if_same<Type, FixedSizeBinaryType>> {
+ static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset,
+ const ArrayData& in, const int64_t in_offset,
+ const int64_t length) {
+ const int32_t width = checked_cast<const FixedSizeBinaryType&>(ty).byte_width();
+ uint8_t* begin = out + (out_offset * width);
+ const auto in_arr = in.GetValues<uint8_t>(1, (in_offset + in.offset) * width);
+ std::memcpy(begin, in_arr, length * width);
+ }
+ static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset,
+ const Scalar& in, const int64_t in_offset, const int64_t length) {
+ const int32_t width = checked_cast<const FixedSizeBinaryType&>(ty).byte_width();
+ uint8_t* begin = out + (out_offset * width);
+ const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(in);
+ // Null scalar may have null value buffer
+ if (!scalar.value) return;
+ const Buffer& buffer = *scalar.value;
+ const uint8_t* value = buffer.data();
+ DCHECK_GE(buffer.size(), width);
+ for (int i = 0; i < length; i++) {
+ std::memcpy(begin, value, width);
+ begin += width;
+ }
+ }
+
+ static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+ const BooleanScalar& mask, const Datum& replacements,
+ ArrayData* output) {
+ return ReplaceWithScalarMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+ output);
+ }
+
+ static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+ const ArrayData& mask, const Datum& replacements,
+ ArrayData* output) {
+ return ReplaceWithArrayMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+ output);
+ }
+};
+
+template <typename Type>
+struct ReplaceWithMask<Type, enable_if_decimal<Type>> {
+ using ScalarType = typename TypeTraits<Type>::ScalarType;
+ static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset,
+ const ArrayData& in, const int64_t in_offset,
+ const int64_t length) {
+ const int32_t width = checked_cast<const FixedSizeBinaryType&>(ty).byte_width();
+ uint8_t* begin = out + (out_offset * width);
+ const auto in_arr = in.GetValues<uint8_t>(1, (in_offset + in.offset) * width);
+ std::memcpy(begin, in_arr, length * width);
+ }
+ static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset,
+ const Scalar& in, const int64_t in_offset, const int64_t length) {
+ const int32_t width = checked_cast<const FixedSizeBinaryType&>(ty).byte_width();
+ uint8_t* begin = out + (out_offset * width);
+ const auto& scalar = checked_cast<const ScalarType&>(in);
+ const auto value = scalar.value.ToBytes();
+ for (int i = 0; i < length; i++) {
+ std::memcpy(begin, value.data(), width);
+ begin += width;
+ }
+ }
+
+ static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+ const BooleanScalar& mask, const Datum& replacements,
+ ArrayData* output) {
+ return ReplaceWithScalarMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+ output);
+ }
+
+ static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+ const ArrayData& mask, const Datum& replacements,
+ ArrayData* output) {
+ return ReplaceWithArrayMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+ output);
+ }
+};
+
+template <typename Type>
+struct ReplaceWithMask<Type, enable_if_null<Type>> {
+ static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+ const BooleanScalar& mask, const Datum& replacements,
+ ArrayData* output) {
+ *output = array;
+ return Status::OK();
+ }
+ static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+ const ArrayData& mask, const Datum& replacements,
+ ArrayData* output) {
+ *output = array;
+ return Status::OK();
+ }
+};
+
+template <typename Type>
+struct ReplaceWithMask<Type, enable_if_base_binary<Type>> {
+ using offset_type = typename Type::offset_type;
+ using BuilderType = typename TypeTraits<Type>::BuilderType;
+
+ static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+ const BooleanScalar& mask, const Datum& replacements,
+ ArrayData* output) {
+ if (!mask.is_valid) {
+ // Output = null
+ ARROW_ASSIGN_OR_RAISE(
+ auto replacement_array,
+ MakeArrayOfNull(array.type, array.length, ctx->memory_pool()));
+ *output = *replacement_array->data();
+ } else if (mask.value) {
+ // Output = replacement
+ if (replacements.is_scalar()) {
+ ARROW_ASSIGN_OR_RAISE(auto replacement_array,
+ MakeArrayFromScalar(*replacements.scalar(), array.length,
+ ctx->memory_pool()));
+ *output = *replacement_array->data();
+ } else {
+ const ArrayData& replacement_array = *replacements.array();
+ if (replacement_array.length < array.length) {
+ return ReplacementArrayTooShort(array.length, replacement_array.length);
+ }
+ *output = replacement_array;
+ output->length = array.length;
+ }
+ } else {
+ // Output = input
+ *output = array;
+ }
+ return Status::OK();
+ }
+ static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+ const ArrayData& mask, const Datum& replacements,
+ ArrayData* output) {
+ BuilderType builder(array.type, ctx->memory_pool());
+ RETURN_NOT_OK(builder.Reserve(array.length));
+ RETURN_NOT_OK(builder.ReserveData(array.buffers[2]->size()));
+ int64_t source_offset = 0;
+ int64_t replacements_offset = 0;
+ RETURN_NOT_OK(VisitArrayDataInline<BooleanType>(
+ mask,
+ [&](bool replace) {
+ if (replace && replacements.is_scalar()) {
+ const Scalar& scalar = *replacements.scalar();
+ if (scalar.is_valid) {
+ RETURN_NOT_OK(builder.Append(UnboxScalar<Type>::Unbox(scalar)));
+ } else {
+ RETURN_NOT_OK(builder.AppendNull());
+ }
+ } else {
+ const ArrayData& source = replace ? *replacements.array() : array;
+ const int64_t offset = replace ? replacements_offset++ : source_offset;
+ if (!source.MayHaveNulls() ||
+ BitUtil::GetBit(source.buffers[0]->data(), source.offset + offset)) {
+ const uint8_t* data = source.buffers[2]->data();
+ const offset_type* offsets = source.GetValues<offset_type>(1);
+ const offset_type offset0 = offsets[offset];
+ const offset_type offset1 = offsets[offset + 1];
+ RETURN_NOT_OK(builder.Append(data + offset0, offset1 - offset0));
+ } else {
+ RETURN_NOT_OK(builder.AppendNull());
+ }
+ }
+ source_offset++;
+ return Status::OK();
+ },
+ [&]() {
+ RETURN_NOT_OK(builder.AppendNull());
+ source_offset++;
+ return Status::OK();
+ }));
+ std::shared_ptr<Array> temp_output;
+ RETURN_NOT_OK(builder.Finish(&temp_output));
+ *output = *temp_output->data();
+ // Builder type != logical type due to GenerateTypeAgnosticVarBinaryBase
+ output->type = array.type;
+ return Status::OK();
+ }
+};
+
+template <typename Type>
+struct ReplaceWithMaskFunctor {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const ArrayData& array = *batch[0].array();
+ const Datum& replacements = batch[2];
+ ArrayData* output = out->array().get();
+ output->length = array.length;
+
+ // Needed for FixedSizeBinary/parameterized types
+ if (!array.type->Equals(*replacements.type(), /*check_metadata=*/false)) {
+ return Status::Invalid("Replacements must be of same type (expected ",
+ array.type->ToString(), " but got ",
+ replacements.type()->ToString(), ")");
+ }
+
+ if (!replacements.is_array() && !replacements.is_scalar()) {
+ return Status::Invalid("Replacements must be array or scalar");
+ }
+
+ if (batch[1].is_scalar()) {
+ return ReplaceWithMask<Type>::ExecScalarMask(
+ ctx, array, batch[1].scalar_as<BooleanScalar>(), replacements, output);
+ }
+ const ArrayData& mask = *batch[1].array();
+ if (array.length != mask.length) {
+ return Status::Invalid("Mask must be of same length as array (expected ",
+ array.length, " items but got ", mask.length, " items)");
+ }
+ return ReplaceWithMask<Type>::ExecArrayMask(ctx, array, mask, replacements, output);
+ }
+};
+
+} // namespace
+
+const FunctionDoc replace_with_mask_doc(
+ "Replace items using a mask and replacement values",
+ ("Given an array and a Boolean mask (either scalar or of equal length), "
+ "along with replacement values (either scalar or array), "
+ "each element of the array for which the corresponding mask element is "
+ "true will be replaced by the next value from the replacements, "
+ "or with null if the mask is null. "
+ "Hence, for replacement arrays, len(replacements) == sum(mask == true)."),
+ {"values", "mask", "replacements"});
+
+void RegisterVectorReplace(FunctionRegistry* registry) {
+ auto func = std::make_shared<VectorFunction>("replace_with_mask", Arity::Ternary(),
+ &replace_with_mask_doc);
+ auto add_kernel = [&](detail::GetTypeId get_id, ArrayKernelExec exec) {
+ VectorKernel kernel;
+ kernel.can_execute_chunkwise = false;
+ if (is_fixed_width(get_id.id)) {
+ kernel.null_handling = NullHandling::type::COMPUTED_PREALLOCATE;
+ } else {
+ kernel.can_write_into_slices = false;
+ kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE;
+ }
+ kernel.mem_allocation = MemAllocation::type::PREALLOCATE;
+ kernel.signature = KernelSignature::Make(
+ {InputType::Array(get_id.id), InputType(boolean()), InputType(get_id.id)},
+ OutputType(FirstType));
+ kernel.exec = std::move(exec);
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+ };
+ auto add_primitive_kernel = [&](detail::GetTypeId get_id) {
+ add_kernel(get_id, GenerateTypeAgnosticPrimitive<ReplaceWithMaskFunctor>(get_id));
+ };
+ for (const auto& ty : NumericTypes()) {
+ add_primitive_kernel(ty);
+ }
+ for (const auto& ty : TemporalTypes()) {
+ add_primitive_kernel(ty);
+ }
+ for (const auto& ty : IntervalTypes()) {
+ add_primitive_kernel(ty);
+ }
+ add_primitive_kernel(null());
+ add_primitive_kernel(boolean());
+ add_kernel(Type::FIXED_SIZE_BINARY, ReplaceWithMaskFunctor<FixedSizeBinaryType>::Exec);
+ add_kernel(Type::DECIMAL128, ReplaceWithMaskFunctor<Decimal128Type>::Exec);
+ add_kernel(Type::DECIMAL256, ReplaceWithMaskFunctor<Decimal256Type>::Exec);
+ for (const auto& ty : BaseBinaryTypes()) {
+ add_kernel(ty->id(), GenerateTypeAgnosticVarBinaryBase<ReplaceWithMaskFunctor>(*ty));
+ }
+ // TODO: list types
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+
+ // TODO(ARROW-9431): "replace_with_indices"
+}
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc
new file mode 100644
index 000000000..719969d46
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc
@@ -0,0 +1,89 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <benchmark/benchmark.h>
+
+#include "arrow/array.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+
+#include "arrow/compute/api_vector.h"
+
+namespace arrow {
+namespace compute {
+
+using ::arrow::internal::checked_pointer_cast;
+
+static constexpr random::SeedType kRandomSeed = 0xabcdef;
+static constexpr random::SeedType kLongLength = 16384;
+
+static std::shared_ptr<Array> MakeReplacements(random::RandomArrayGenerator* generator,
+ const BooleanArray& mask) {
+ int64_t count = 0;
+ for (int64_t i = 0; i < mask.length(); i++) {
+ count += mask.Value(i) && mask.IsValid(i);
+ }
+ return generator->Int64(count, /*min=*/-65536, /*max=*/65536, /*null_probability=*/0.1);
+}
+
+static void ReplaceWithMaskLowSelectivityBench(
+ benchmark::State& state) { // NOLINT non-const reference
+ random::RandomArrayGenerator generator(kRandomSeed);
+ const int64_t len = state.range(0);
+ const int64_t offset = state.range(1);
+
+ auto values =
+ generator.Int64(len, /*min=*/-65536, /*max=*/65536, /*null_probability=*/0.1)
+ ->Slice(offset);
+ auto mask = checked_pointer_cast<BooleanArray>(
+ generator.Boolean(len, /*true_probability=*/0.1, /*null_probability=*/0.1)
+ ->Slice(offset));
+ auto replacements = MakeReplacements(&generator, *mask);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(ReplaceWithMask(values, mask, replacements));
+ }
+ state.SetBytesProcessed(state.iterations() * (len - offset) * 8);
+}
+
+static void ReplaceWithMaskHighSelectivityBench(
+ benchmark::State& state) { // NOLINT non-const reference
+ random::RandomArrayGenerator generator(kRandomSeed);
+ const int64_t len = state.range(0);
+ const int64_t offset = state.range(1);
+
+ auto values =
+ generator.Int64(len, /*min=*/-65536, /*max=*/65536, /*null_probability=*/0.1)
+ ->Slice(offset);
+ auto mask = checked_pointer_cast<BooleanArray>(
+ generator.Boolean(len, /*true_probability=*/0.9, /*null_probability=*/0.1)
+ ->Slice(offset));
+ auto replacements = MakeReplacements(&generator, *mask);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(ReplaceWithMask(values, mask, replacements));
+ }
+ state.SetBytesProcessed(state.iterations() * (len - offset) * 8);
+}
+
+BENCHMARK(ReplaceWithMaskLowSelectivityBench)->Args({kLongLength, 0});
+BENCHMARK(ReplaceWithMaskLowSelectivityBench)->Args({kLongLength, 99});
+BENCHMARK(ReplaceWithMaskHighSelectivityBench)->Args({kLongLength, 0});
+BENCHMARK(ReplaceWithMaskHighSelectivityBench)->Args({kLongLength, 99});
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_replace_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_replace_test.cc
new file mode 100644
index 000000000..9eecc5309
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_replace_test.cc
@@ -0,0 +1,677 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/make_unique.h"
+
+namespace arrow {
+namespace compute {
+
+using arrow::internal::checked_pointer_cast;
+
+template <typename T>
+class TestReplaceKernel : public ::testing::Test {
+ protected:
+ virtual std::shared_ptr<DataType> type() = 0;
+
+ using ReplaceFunction = std::function<Result<Datum>(const Datum&, const Datum&,
+ const Datum&, ExecContext*)>;
+
+ void SetUp() override { equal_options_ = equal_options_.nans_equal(true); }
+
+ Datum mask_scalar(bool value) { return Datum(std::make_shared<BooleanScalar>(value)); }
+
+ Datum null_mask_scalar() {
+ auto scalar = std::make_shared<BooleanScalar>(true);
+ scalar->is_valid = false;
+ return Datum(std::move(scalar));
+ }
+
+ Datum scalar(const std::string& json) { return ScalarFromJSON(type(), json); }
+
+ std::shared_ptr<Array> array(const std::string& value) {
+ return ArrayFromJSON(type(), value);
+ }
+
+ std::shared_ptr<Array> mask(const std::string& value) {
+ return ArrayFromJSON(boolean(), value);
+ }
+
+ Status AssertRaises(ReplaceFunction func, const std::shared_ptr<Array>& array,
+ const Datum& mask, const std::shared_ptr<Array>& replacements) {
+ auto result = func(array, mask, replacements, nullptr);
+ EXPECT_FALSE(result.ok());
+ return result.status();
+ }
+
+ void Assert(ReplaceFunction func, const std::shared_ptr<Array>& array,
+ const Datum& mask, Datum replacements,
+ const std::shared_ptr<Array>& expected) {
+ SCOPED_TRACE("Replacements: " + (replacements.is_array()
+ ? replacements.make_array()->ToString()
+ : replacements.scalar()->ToString()));
+ SCOPED_TRACE("Mask: " + (mask.is_array() ? mask.make_array()->ToString()
+ : mask.scalar()->ToString()));
+ SCOPED_TRACE("Array: " + array->ToString());
+
+ ASSERT_OK_AND_ASSIGN(auto actual, func(array, mask, replacements, nullptr));
+ ASSERT_TRUE(actual.is_array());
+ ASSERT_OK(actual.make_array()->ValidateFull());
+
+ AssertArraysApproxEqual(*expected, *actual.make_array(), /*verbose=*/true,
+ equal_options_);
+ }
+
+ std::shared_ptr<Array> NaiveImpl(
+ const typename TypeTraits<T>::ArrayType& array, const BooleanArray& mask,
+ const typename TypeTraits<T>::ArrayType& replacements) {
+ auto length = array.length();
+ auto builder = arrow::internal::make_unique<typename TypeTraits<T>::BuilderType>(
+ default_type_instance<T>(), default_memory_pool());
+ int64_t replacement_offset = 0;
+ for (int64_t i = 0; i < length; ++i) {
+ if (mask.IsValid(i)) {
+ if (mask.Value(i)) {
+ if (replacements.IsValid(replacement_offset)) {
+ ARROW_EXPECT_OK(builder->Append(replacements.Value(replacement_offset++)));
+ } else {
+ ARROW_EXPECT_OK(builder->AppendNull());
+ replacement_offset++;
+ }
+ } else {
+ if (array.IsValid(i)) {
+ ARROW_EXPECT_OK(builder->Append(array.Value(i)));
+ } else {
+ ARROW_EXPECT_OK(builder->AppendNull());
+ }
+ }
+ } else {
+ ARROW_EXPECT_OK(builder->AppendNull());
+ }
+ }
+ EXPECT_OK_AND_ASSIGN(auto expected, builder->Finish());
+ return expected;
+ }
+
+ EqualOptions equal_options_ = EqualOptions::Defaults();
+};
+
+template <typename T>
+class TestReplaceNumeric : public TestReplaceKernel<T> {
+ protected:
+ std::shared_ptr<DataType> type() override { return default_type_instance<T>(); }
+};
+
+class TestReplaceBoolean : public TestReplaceKernel<BooleanType> {
+ protected:
+ std::shared_ptr<DataType> type() override {
+ return TypeTraits<BooleanType>::type_singleton();
+ }
+};
+
+class TestReplaceFixedSizeBinary : public TestReplaceKernel<FixedSizeBinaryType> {
+ protected:
+ std::shared_ptr<DataType> type() override { return fixed_size_binary(3); }
+};
+
+template <typename T>
+class TestReplaceDecimal : public TestReplaceKernel<T> {
+ protected:
+ std::shared_ptr<DataType> type() override { return default_type_instance<T>(); }
+};
+
+class TestReplaceDayTimeInterval : public TestReplaceKernel<DayTimeIntervalType> {
+ protected:
+ std::shared_ptr<DataType> type() override {
+ return TypeTraits<DayTimeIntervalType>::type_singleton();
+ }
+};
+
+template <typename T>
+class TestReplaceBinary : public TestReplaceKernel<T> {
+ protected:
+ std::shared_ptr<DataType> type() override { return default_type_instance<T>(); }
+};
+
+using NumericBasedTypes =
+ ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
+ Int32Type, Int64Type, FloatType, DoubleType, Date32Type, Date64Type,
+ Time32Type, Time64Type, TimestampType, MonthIntervalType>;
+
+TYPED_TEST_SUITE(TestReplaceNumeric, NumericBasedTypes);
+TYPED_TEST_SUITE(TestReplaceDecimal, DecimalArrowTypes);
+TYPED_TEST_SUITE(TestReplaceBinary, BinaryArrowTypes);
+
+TYPED_TEST(TestReplaceNumeric, ReplaceWithMask) {
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[]"));
+
+ this->Assert(ReplaceWithMask, this->array("[1]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[1]"));
+ this->Assert(ReplaceWithMask, this->array("[1]"), this->mask_scalar(true),
+ this->array("[0]"), this->array("[0]"));
+ this->Assert(ReplaceWithMask, this->array("[1]"), this->mask_scalar(true),
+ this->array("[2, 0]"), this->array("[2]"));
+ this->Assert(ReplaceWithMask, this->array("[1]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[null]"));
+
+ this->Assert(ReplaceWithMask, this->array("[0, 0]"), this->mask_scalar(false),
+ this->scalar("1"), this->array("[0, 0]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 0]"), this->mask_scalar(true),
+ this->scalar("1"), this->array("[1, 1]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 0]"), this->mask_scalar(true),
+ this->scalar("null"), this->array("[null, null]"));
+
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3]"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array("[0, 1, 2, 3]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3]"),
+ this->mask("[true, true, true, true]"), this->array("[10, 11, 12, 13]"),
+ this->array("[10, 11, 12, 13]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3]"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2, null]"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array("[0, 1, 2, null]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2, null]"),
+ this->mask("[true, true, true, true]"), this->array("[10, 11, 12, 13]"),
+ this->array("[10, 11, 12, 13]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2, null]"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3, 4, 5]"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array("[10, null]"), this->array("[10, null, 2, 3, null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array("[10, null]"),
+ this->array("[10, null, null, null, null, null]"));
+
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->scalar("1"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1]"), this->mask("[true, true]"),
+ this->scalar("10"), this->array("[10, 10]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1]"), this->mask("[true, true]"),
+ this->scalar("null"), this->array("[null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2]"),
+ this->mask("[true, false, null]"), this->scalar("10"),
+ this->array("[10, 1, null]"));
+}
+
+TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskRandom) {
+ using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+ using CType = typename TypeTraits<TypeParam>::CType;
+ auto ty = this->type();
+
+ random::RandomArrayGenerator rand(/*seed=*/0);
+ const int64_t length = 1023;
+ std::vector<std::string> values = {"0.01", "0"};
+ // Clamp the range because date/time types don't print well with extreme values
+ values.push_back(std::to_string(static_cast<CType>(std::min<double>(
+ 16384.0, static_cast<double>(std::numeric_limits<CType>::max())))));
+ auto options = key_value_metadata({"null_probability", "min", "max"}, values);
+ auto array =
+ checked_pointer_cast<ArrayType>(rand.ArrayOf(*field("a", ty, options), length));
+ auto mask = checked_pointer_cast<BooleanArray>(
+ rand.ArrayOf(boolean(), length, /*null_probability=*/0.01));
+ const int64_t num_replacements = std::count_if(
+ mask->begin(), mask->end(),
+ [](util::optional<bool> value) { return value.has_value() && *value; });
+ auto replacements = checked_pointer_cast<ArrayType>(
+ rand.ArrayOf(*field("a", ty, options), num_replacements));
+ auto expected = this->NaiveImpl(*array, *mask, *replacements);
+
+ this->Assert(ReplaceWithMask, array, mask, replacements, expected);
+ for (int64_t slice = 1; slice <= 16; slice++) {
+ auto sliced_array = checked_pointer_cast<ArrayType>(array->Slice(slice, 15));
+ auto sliced_mask = checked_pointer_cast<BooleanArray>(mask->Slice(slice, 15));
+ auto new_expected = this->NaiveImpl(*sliced_array, *sliced_mask, *replacements);
+ this->Assert(ReplaceWithMask, sliced_array, sliced_mask, replacements, new_expected);
+ }
+}
+
+TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskErrors) {
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr("Replacement array must be of appropriate length (expected 2 "
+ "items but got 1 items)"),
+ this->AssertRaises(ReplaceWithMask, this->array("[1, 2]"),
+ this->mask("[true, true]"), this->array("[0]")));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr("Replacement array must be of appropriate length (expected 1 "
+ "items but got 0 items)"),
+ this->AssertRaises(ReplaceWithMask, this->array("[1, 2]"),
+ this->mask("[true, null]"), this->array("[]")));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr("Mask must be of same length as array (expected 2 "
+ "items but got 0 items)"),
+ this->AssertRaises(ReplaceWithMask, this->array("[1, 2]"), this->mask("[]"),
+ this->array("[]")));
+}
+
+TEST_F(TestReplaceBoolean, ReplaceWithMask) {
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[]"));
+
+ this->Assert(ReplaceWithMask, this->array("[true]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[true]"));
+ this->Assert(ReplaceWithMask, this->array("[true]"), this->mask_scalar(true),
+ this->array("[false]"), this->array("[false]"));
+ this->Assert(ReplaceWithMask, this->array("[true]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[null]"));
+
+ this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask_scalar(false),
+ this->scalar("true"), this->array("[false, false]"));
+ this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask_scalar(true),
+ this->scalar("true"), this->array("[true, true]"));
+ this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask_scalar(true),
+ this->scalar("null"), this->array("[null, null]"));
+
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[true, true, true, true]"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array("[true, true, true, true]"));
+ this->Assert(ReplaceWithMask, this->array("[true, true, true, true]"),
+ this->mask("[true, true, true, true]"),
+ this->array("[false, false, false, false]"),
+ this->array("[false, false, false, false]"));
+ this->Assert(ReplaceWithMask, this->array("[true, true, true, true]"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[true, true, true, null]"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array("[true, true, true, null]"));
+ this->Assert(ReplaceWithMask, this->array("[true, true, true, null]"),
+ this->mask("[true, true, true, true]"),
+ this->array("[false, false, false, false]"),
+ this->array("[false, false, false, false]"));
+ this->Assert(ReplaceWithMask, this->array("[true, true, true, null]"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[true, true, true, true, true, true]"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array("[false, null]"),
+ this->array("[false, null, true, true, null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array("[false, null]"),
+ this->array("[false, null, null, null, null, null]"));
+
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->scalar("true"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"),
+ this->scalar("true"), this->array("[true, true]"));
+ this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"),
+ this->scalar("null"), this->array("[null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[false, false, false]"),
+ this->mask("[true, false, null]"), this->scalar("true"),
+ this->array("[true, false, null]"));
+}
+
+TEST_F(TestReplaceBoolean, ReplaceWithMaskErrors) {
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr("Replacement array must be of appropriate length (expected 2 "
+ "items but got 1 items)"),
+ this->AssertRaises(ReplaceWithMask, this->array("[true, true]"),
+ this->mask("[true, true]"), this->array("[false]")));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr("Replacement array must be of appropriate length (expected 1 "
+ "items but got 0 items)"),
+ this->AssertRaises(ReplaceWithMask, this->array("[true, true]"),
+ this->mask("[true, null]"), this->array("[]")));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr("Mask must be of same length as array (expected 2 "
+ "items but got 0 items)"),
+ this->AssertRaises(ReplaceWithMask, this->array("[true, true]"), this->mask("[]"),
+ this->array("[]")));
+}
+
+TEST_F(TestReplaceFixedSizeBinary, ReplaceWithMask) {
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[]"));
+
+ this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(false),
+ this->array("[]"), this->array(R"(["foo"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(true),
+ this->array(R"(["bar"])"), this->array(R"(["bar"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[null]"));
+
+ this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"),
+ this->mask_scalar(false), this->scalar(R"("baz")"),
+ this->array(R"(["foo", "bar"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true),
+ this->scalar(R"("baz")"), this->array(R"(["baz", "baz"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true),
+ this->scalar("null"), this->array(R"([null, null])"));
+
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", "ddd"])"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array(R"(["aaa", "bbb", "ccc", "ddd"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", "ddd"])"),
+ this->mask("[true, true, true, true]"),
+ this->array(R"(["eee", "fff", "ggg", "hhh"])"),
+ this->array(R"(["eee", "fff", "ggg", "hhh"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", "ddd"])"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array(R"([null, null, null, null])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", null])"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array(R"(["aaa", "bbb", "ccc", null])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", null])"),
+ this->mask("[true, true, true, true]"),
+ this->array(R"(["eee", "fff", "ggg", "hhh"])"),
+ this->array(R"(["eee", "fff", "ggg", "hhh"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", null])"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array(R"([null, null, null, null])"));
+ this->Assert(ReplaceWithMask,
+ this->array(R"(["aaa", "bbb", "ccc", "ddd", "eee", "fff"])"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array(R"(["ggg", null])"),
+ this->array(R"(["ggg", null, "ccc", "ddd", null, null])"));
+ this->Assert(ReplaceWithMask, this->array(R"([null, null, null, null, null, null])"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array(R"(["aaa", null])"),
+ this->array(R"(["aaa", null, null, null, null, null])"));
+
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"),
+ this->scalar(R"("zzz")"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb"])"),
+ this->mask("[true, true]"), this->scalar(R"("zzz")"),
+ this->array(R"(["zzz", "zzz"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb"])"),
+ this->mask("[true, true]"), this->scalar("null"),
+ this->array("[null, null]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc"])"),
+ this->mask("[true, false, null]"), this->scalar(R"("zzz")"),
+ this->array(R"(["zzz", "bbb", null])"));
+}
+
+TEST_F(TestReplaceFixedSizeBinary, ReplaceWithMaskErrors) {
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::AllOf(
+ ::testing::HasSubstr("Replacements must be of same type (expected "),
+ ::testing::HasSubstr(this->type()->ToString()),
+ ::testing::HasSubstr("but got fixed_size_binary[2]")),
+ this->AssertRaises(ReplaceWithMask, this->array("[]"), this->mask_scalar(true),
+ ArrayFromJSON(fixed_size_binary(2), "[]")));
+}
+
+TYPED_TEST(TestReplaceDecimal, ReplaceWithMask) {
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[]"));
+
+ this->Assert(ReplaceWithMask, this->array(R"(["1.00"])"), this->mask_scalar(false),
+ this->array("[]"), this->array(R"(["1.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["1.00"])"), this->mask_scalar(true),
+ this->array(R"(["0.00"])"), this->array(R"(["0.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["1.00"])"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[null]"));
+
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "0.00"])"),
+ this->mask_scalar(false), this->scalar(R"("1.00")"),
+ this->array(R"(["0.00", "0.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "0.00"])"),
+ this->mask_scalar(true), this->scalar(R"("1.00")"),
+ this->array(R"(["1.00", "1.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "0.00"])"),
+ this->mask_scalar(true), this->scalar("null"),
+ this->array("[null, null]"));
+
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", "3.00"])"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array(R"(["0.00", "1.00", "2.00", "3.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", "3.00"])"),
+ this->mask("[true, true, true, true]"),
+ this->array(R"(["10.00", "11.00", "12.00", "13.00"])"),
+ this->array(R"(["10.00", "11.00", "12.00", "13.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", "3.00"])"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", null])"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array(R"(["0.00", "1.00", "2.00", null])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", null])"),
+ this->mask("[true, true, true, true]"),
+ this->array(R"(["10.00", "11.00", "12.00", "13.00"])"),
+ this->array(R"(["10.00", "11.00", "12.00", "13.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", null])"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(ReplaceWithMask,
+ this->array(R"(["0.00", "1.00", "2.00", "3.00", "4.00", "5.00"])"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array(R"(["10.00", null])"),
+ this->array(R"(["10.00", null, "2.00", "3.00", null, null])"));
+ this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array(R"(["10.00", null])"),
+ this->array(R"(["10.00", null, null, null, null, null])"));
+
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"),
+ this->scalar(R"("1.00")"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00"])"),
+ this->mask("[true, true]"), this->scalar(R"("10.00")"),
+ this->array(R"(["10.00", "10.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00"])"),
+ this->mask("[true, true]"), this->scalar("null"),
+ this->array("[null, null]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00"])"),
+ this->mask("[true, false, null]"), this->scalar(R"("10.00")"),
+ this->array(R"(["10.00", "1.00", null])"));
+}
+
+TEST_F(TestReplaceDayTimeInterval, ReplaceWithMask) {
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[]"));
+
+ this->Assert(ReplaceWithMask, this->array("[[1, 2]]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[[1, 2]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2]]"), this->mask_scalar(true),
+ this->array("[[3, 4]]"), this->array("[[3, 4]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2]]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[null]"));
+
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"), this->mask_scalar(false),
+ this->scalar("[7, 8]"), this->array("[[1, 2], [3, 4]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"), this->mask_scalar(true),
+ this->scalar("[7, 8]"), this->array("[[7, 8], [7, 8]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"), this->mask_scalar(true),
+ this->scalar("null"), this->array("[null, null]"));
+
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]"),
+ this->mask("[true, true, true, true]"),
+ this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]"),
+ this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], null]"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array("[[1, 2], [1, 2], [1, 2], null]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], null]"),
+ this->mask("[true, true, true, true]"),
+ this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]"),
+ this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], null]"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(
+ ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2], [1, 2], [1, 2]]"),
+ this->mask("[true, true, false, false, null, null]"), this->array("[[3, 4], null]"),
+ this->array("[[3, 4], null, [1, 2], [1, 2], null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array("[[3, 4], null]"),
+ this->array("[[3, 4], null, null, null, null, null]"));
+
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"),
+ this->scalar("[7, 8]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"),
+ this->mask("[true, true]"), this->scalar("[7, 8]"),
+ this->array("[[7, 8], [7, 8]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"),
+ this->mask("[true, true]"), this->scalar("null"),
+ this->array("[null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4], [5, 6]]"),
+ this->mask("[true, false, null]"), this->scalar("[7, 8]"),
+ this->array("[[7, 8], [3, 4], null]"));
+}
+
+TYPED_TEST(TestReplaceBinary, ReplaceWithMask) {
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[]"));
+
+ this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(false),
+ this->array("[]"), this->array(R"(["foo"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(true),
+ this->array(R"(["bar"])"), this->array(R"(["bar"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[null]"));
+
+ this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"),
+ this->mask_scalar(false), this->scalar(R"("baz")"),
+ this->array(R"(["foo", "bar"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true),
+ this->scalar(R"("baz")"), this->array(R"(["baz", "baz"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true),
+ this->scalar("null"), this->array(R"([null, null])"));
+
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", "dddd"])"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array(R"(["a", "bb", "ccc", "dddd"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", "dddd"])"),
+ this->mask("[true, true, true, true]"),
+ this->array(R"(["eeeee", "f", "ggg", "hhh"])"),
+ this->array(R"(["eeeee", "f", "ggg", "hhh"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", "dddd"])"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array(R"([null, null, null, null])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", null])"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array(R"(["a", "bb", "ccc", null])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", null])"),
+ this->mask("[true, true, true, true]"),
+ this->array(R"(["eeeee", "f", "ggg", "hhh"])"),
+ this->array(R"(["eeeee", "f", "ggg", "hhh"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", null])"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array(R"([null, null, null, null])"));
+ this->Assert(ReplaceWithMask,
+ this->array(R"(["a", "bb", "ccc", "dddd", "eeeee", "f"])"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array(R"(["ggg", null])"),
+ this->array(R"(["ggg", null, "ccc", "dddd", null, null])"));
+ this->Assert(ReplaceWithMask, this->array(R"([null, null, null, null, null, null])"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array(R"(["a", null])"),
+ this->array(R"(["a", null, null, null, null, null])"));
+
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"),
+ this->scalar(R"("zzz")"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb"])"), this->mask("[true, true]"),
+ this->scalar(R"("zzz")"), this->array(R"(["zzz", "zzz"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb"])"), this->mask("[true, true]"),
+ this->scalar("null"), this->array("[null, null]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc"])"),
+ this->mask("[true, false, null]"), this->scalar(R"("zzz")"),
+ this->array(R"(["zzz", "bb", null])"));
+}
+
+TYPED_TEST(TestReplaceBinary, ReplaceWithMaskRandom) {
+ using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+ auto ty = this->type();
+
+ random::RandomArrayGenerator rand(/*seed=*/0);
+ const int64_t length = 1023;
+ auto options = key_value_metadata({{"null_probability", "0.01"}, {"max_length", "5"}});
+ auto array =
+ checked_pointer_cast<ArrayType>(rand.ArrayOf(*field("a", ty, options), length));
+ auto mask = checked_pointer_cast<BooleanArray>(
+ rand.ArrayOf(boolean(), length, /*null_probability=*/0.01));
+ const int64_t num_replacements = std::count_if(
+ mask->begin(), mask->end(),
+ [](util::optional<bool> value) { return value.has_value() && *value; });
+ auto replacements = checked_pointer_cast<ArrayType>(
+ rand.ArrayOf(*field("a", ty, options), num_replacements));
+ auto expected = this->NaiveImpl(*array, *mask, *replacements);
+
+ this->Assert(ReplaceWithMask, array, mask, replacements, expected);
+ for (int64_t slice = 1; slice <= 16; slice++) {
+ auto sliced_array = checked_pointer_cast<ArrayType>(array->Slice(slice, 15));
+ auto sliced_mask = checked_pointer_cast<BooleanArray>(mask->Slice(slice, 15));
+ auto new_expected = this->NaiveImpl(*sliced_array, *sliced_mask, *replacements);
+ this->Assert(ReplaceWithMask, sliced_array, sliced_mask, replacements, new_expected);
+ }
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_selection.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_selection.cc
new file mode 100644
index 000000000..2f859bea8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_selection.cc
@@ -0,0 +1,2442 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstring>
+#include <limits>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_binary.h"
+#include "arrow/array/array_dict.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/chunked_array.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/extension_type.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/int_util.h"
+
+namespace arrow {
+
+using internal::BinaryBitBlockCounter;
+using internal::BitBlockCount;
+using internal::BitBlockCounter;
+using internal::CheckIndexBounds;
+using internal::CopyBitmap;
+using internal::CountSetBits;
+using internal::GetArrayView;
+using internal::GetByteWidth;
+using internal::OptionalBitBlockCounter;
+using internal::OptionalBitIndexer;
+
+namespace compute {
+namespace internal {
+
+int64_t GetFilterOutputSize(const ArrayData& filter,
+ FilterOptions::NullSelectionBehavior null_selection) {
+ int64_t output_size = 0;
+
+ if (filter.MayHaveNulls()) {
+ const uint8_t* filter_is_valid = filter.buffers[0]->data();
+ BinaryBitBlockCounter bit_counter(filter.buffers[1]->data(), filter.offset,
+ filter_is_valid, filter.offset, filter.length);
+ int64_t position = 0;
+ if (null_selection == FilterOptions::EMIT_NULL) {
+ while (position < filter.length) {
+ BitBlockCount block = bit_counter.NextOrNotWord();
+ output_size += block.popcount;
+ position += block.length;
+ }
+ } else {
+ while (position < filter.length) {
+ BitBlockCount block = bit_counter.NextAndWord();
+ output_size += block.popcount;
+ position += block.length;
+ }
+ }
+ } else {
+ // The filter has no nulls, so we can use CountSetBits
+ output_size = CountSetBits(filter.buffers[1]->data(), filter.offset, filter.length);
+ }
+ return output_size;
+}
+
+namespace {
+
+template <typename IndexType>
+Result<std::shared_ptr<ArrayData>> GetTakeIndicesImpl(
+ const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection,
+ MemoryPool* memory_pool) {
+ using T = typename IndexType::c_type;
+
+ const uint8_t* filter_data = filter.buffers[1]->data();
+ const bool have_filter_nulls = filter.MayHaveNulls();
+ const uint8_t* filter_is_valid =
+ have_filter_nulls ? filter.buffers[0]->data() : nullptr;
+
+ if (have_filter_nulls && null_selection == FilterOptions::EMIT_NULL) {
+ // Most complex case: the filter may have nulls and we don't drop them.
+ // The logic is ternary:
+ // - filter is null: emit null
+ // - filter is valid and true: emit index
+ // - filter is valid and false: don't emit anything
+
+ typename TypeTraits<IndexType>::BuilderType builder(memory_pool);
+
+ // The position relative to the start of the filter
+ T position = 0;
+ // The current position taking the filter offset into account
+ int64_t position_with_offset = filter.offset;
+
+ // To count blocks where filter_data[i] || !filter_is_valid[i]
+ BinaryBitBlockCounter filter_counter(filter_data, filter.offset, filter_is_valid,
+ filter.offset, filter.length);
+ BitBlockCounter is_valid_counter(filter_is_valid, filter.offset, filter.length);
+ while (position < filter.length) {
+ // true OR NOT valid
+ BitBlockCount selected_or_null_block = filter_counter.NextOrNotWord();
+ if (selected_or_null_block.NoneSet()) {
+ position += selected_or_null_block.length;
+ position_with_offset += selected_or_null_block.length;
+ continue;
+ }
+ RETURN_NOT_OK(builder.Reserve(selected_or_null_block.popcount));
+
+ // If the values are all valid and the selected_or_null_block is full,
+ // then we can infer that all the values are true and skip the bit checking
+ BitBlockCount is_valid_block = is_valid_counter.NextWord();
+
+ if (selected_or_null_block.AllSet() && is_valid_block.AllSet()) {
+ // All the values are selected and non-null
+ for (int64_t i = 0; i < selected_or_null_block.length; ++i) {
+ builder.UnsafeAppend(position++);
+ }
+ position_with_offset += selected_or_null_block.length;
+ } else {
+ // Some of the values are false or null
+ for (int64_t i = 0; i < selected_or_null_block.length; ++i) {
+ if (BitUtil::GetBit(filter_is_valid, position_with_offset)) {
+ if (BitUtil::GetBit(filter_data, position_with_offset)) {
+ builder.UnsafeAppend(position);
+ }
+ } else {
+ // Null slot, so append a null
+ builder.UnsafeAppendNull();
+ }
+ ++position;
+ ++position_with_offset;
+ }
+ }
+ }
+ std::shared_ptr<ArrayData> result;
+ RETURN_NOT_OK(builder.FinishInternal(&result));
+ return result;
+ }
+
+ // Other cases don't emit nulls and are therefore simpler.
+ TypedBufferBuilder<T> builder(memory_pool);
+
+ if (have_filter_nulls) {
+ // The filter may have nulls, so we scan the validity bitmap and the filter
+ // data bitmap together.
+ DCHECK_EQ(null_selection, FilterOptions::DROP);
+
+ // The position relative to the start of the filter
+ T position = 0;
+ // The current position taking the filter offset into account
+ int64_t position_with_offset = filter.offset;
+
+ BinaryBitBlockCounter filter_counter(filter_data, filter.offset, filter_is_valid,
+ filter.offset, filter.length);
+ while (position < filter.length) {
+ BitBlockCount and_block = filter_counter.NextAndWord();
+ RETURN_NOT_OK(builder.Reserve(and_block.popcount));
+ if (and_block.AllSet()) {
+ // All the values are selected and non-null
+ for (int64_t i = 0; i < and_block.length; ++i) {
+ builder.UnsafeAppend(position++);
+ }
+ position_with_offset += and_block.length;
+ } else if (!and_block.NoneSet()) {
+ // Some of the values are false or null
+ for (int64_t i = 0; i < and_block.length; ++i) {
+ if (BitUtil::GetBit(filter_is_valid, position_with_offset) &&
+ BitUtil::GetBit(filter_data, position_with_offset)) {
+ builder.UnsafeAppend(position);
+ }
+ ++position;
+ ++position_with_offset;
+ }
+ } else {
+ position += and_block.length;
+ position_with_offset += and_block.length;
+ }
+ }
+ } else {
+ // The filter has no nulls, so we need only look for true values
+ RETURN_NOT_OK(::arrow::internal::VisitSetBitRuns(
+ filter_data, filter.offset, filter.length, [&](int64_t offset, int64_t length) {
+ // Append the consecutive run of indices
+ RETURN_NOT_OK(builder.Reserve(length));
+ for (int64_t i = 0; i < length; ++i) {
+ builder.UnsafeAppend(static_cast<T>(offset + i));
+ }
+ return Status::OK();
+ }));
+ }
+
+ const int64_t length = builder.length();
+ std::shared_ptr<Buffer> out_buffer;
+ RETURN_NOT_OK(builder.Finish(&out_buffer));
+ return std::make_shared<ArrayData>(TypeTraits<IndexType>::type_singleton(), length,
+ BufferVector{nullptr, out_buffer}, /*null_count=*/0);
+}
+
+} // namespace
+
+Result<std::shared_ptr<ArrayData>> GetTakeIndices(
+ const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection,
+ MemoryPool* memory_pool) {
+ DCHECK_EQ(filter.type->id(), Type::BOOL);
+ if (filter.length <= std::numeric_limits<uint16_t>::max()) {
+ return GetTakeIndicesImpl<UInt16Type>(filter, null_selection, memory_pool);
+ } else if (filter.length <= std::numeric_limits<uint32_t>::max()) {
+ return GetTakeIndicesImpl<UInt32Type>(filter, null_selection, memory_pool);
+ } else {
+ // Arrays over 4 billion elements, not especially likely.
+ return Status::NotImplemented(
+ "Filter length exceeds UINT32_MAX, "
+ "consider a different strategy for selecting elements");
+ }
+}
+
+namespace {
+
+using FilterState = OptionsWrapper<FilterOptions>;
+using TakeState = OptionsWrapper<TakeOptions>;
+
+Status PreallocateData(KernelContext* ctx, int64_t length, int bit_width,
+ bool allocate_validity, ArrayData* out) {
+ // Preallocate memory
+ out->length = length;
+ out->buffers.resize(2);
+
+ if (allocate_validity) {
+ ARROW_ASSIGN_OR_RAISE(out->buffers[0], ctx->AllocateBitmap(length));
+ }
+ if (bit_width == 1) {
+ ARROW_ASSIGN_OR_RAISE(out->buffers[1], ctx->AllocateBitmap(length));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(out->buffers[1], ctx->Allocate(length * bit_width / 8));
+ }
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Implement optimized take for primitive types from boolean to 1/2/4/8-byte
+// C-type based types. Use common implementation for every byte width and only
+// generate code for unsigned integer indices, since after boundschecking to
+// check for negative numbers in the indices we can safely reinterpret_cast
+// signed integers as unsigned.
+
+/// \brief The Take implementation for primitive (fixed-width) types does not
+/// use the logical Arrow type but rather the physical C type. This way we
+/// only generate one take function for each byte width.
+///
+/// This function assumes that the indices have been boundschecked.
+template <typename IndexCType, typename ValueCType>
+struct PrimitiveTakeImpl {
+ static void Exec(const PrimitiveArg& values, const PrimitiveArg& indices,
+ ArrayData* out_arr) {
+ auto values_data = reinterpret_cast<const ValueCType*>(values.data);
+ auto values_is_valid = values.is_valid;
+ auto values_offset = values.offset;
+
+ auto indices_data = reinterpret_cast<const IndexCType*>(indices.data);
+ auto indices_is_valid = indices.is_valid;
+ auto indices_offset = indices.offset;
+
+ auto out = out_arr->GetMutableValues<ValueCType>(1);
+ auto out_is_valid = out_arr->buffers[0]->mutable_data();
+ auto out_offset = out_arr->offset;
+
+ // If either the values or indices have nulls, we preemptively zero out the
+ // out validity bitmap so that we don't have to use ClearBit in each
+ // iteration for nulls.
+ if (values.null_count != 0 || indices.null_count != 0) {
+ BitUtil::SetBitsTo(out_is_valid, out_offset, indices.length, false);
+ }
+
+ OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset,
+ indices.length);
+ int64_t position = 0;
+ int64_t valid_count = 0;
+ while (position < indices.length) {
+ BitBlockCount block = indices_bit_counter.NextBlock();
+ if (values.null_count == 0) {
+ // Values are never null, so things are easier
+ valid_count += block.popcount;
+ if (block.popcount == block.length) {
+ // Fastest path: neither values nor index nulls
+ BitUtil::SetBitsTo(out_is_valid, out_offset + position, block.length, true);
+ for (int64_t i = 0; i < block.length; ++i) {
+ out[position] = values_data[indices_data[position]];
+ ++position;
+ }
+ } else if (block.popcount > 0) {
+ // Slow path: some indices but not all are null
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) {
+ // index is not null
+ BitUtil::SetBit(out_is_valid, out_offset + position);
+ out[position] = values_data[indices_data[position]];
+ } else {
+ out[position] = ValueCType{};
+ }
+ ++position;
+ }
+ } else {
+ memset(out + position, 0, sizeof(ValueCType) * block.length);
+ position += block.length;
+ }
+ } else {
+ // Values have nulls, so we must do random access into the values bitmap
+ if (block.popcount == block.length) {
+ // Faster path: indices are not null but values may be
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (BitUtil::GetBit(values_is_valid,
+ values_offset + indices_data[position])) {
+ // value is not null
+ out[position] = values_data[indices_data[position]];
+ BitUtil::SetBit(out_is_valid, out_offset + position);
+ ++valid_count;
+ } else {
+ out[position] = ValueCType{};
+ }
+ ++position;
+ }
+ } else if (block.popcount > 0) {
+ // Slow path: some but not all indices are null. Since we are doing
+ // random access in general we have to check the value nullness one by
+ // one.
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (BitUtil::GetBit(indices_is_valid, indices_offset + position) &&
+ BitUtil::GetBit(values_is_valid,
+ values_offset + indices_data[position])) {
+ // index is not null && value is not null
+ out[position] = values_data[indices_data[position]];
+ BitUtil::SetBit(out_is_valid, out_offset + position);
+ ++valid_count;
+ } else {
+ out[position] = ValueCType{};
+ }
+ ++position;
+ }
+ } else {
+ memset(out + position, 0, sizeof(ValueCType) * block.length);
+ position += block.length;
+ }
+ }
+ }
+ out_arr->null_count = out_arr->length - valid_count;
+ }
+};
+
+template <typename IndexCType>
+struct BooleanTakeImpl {
+ static void Exec(const PrimitiveArg& values, const PrimitiveArg& indices,
+ ArrayData* out_arr) {
+ const uint8_t* values_data = values.data;
+ auto values_is_valid = values.is_valid;
+ auto values_offset = values.offset;
+
+ auto indices_data = reinterpret_cast<const IndexCType*>(indices.data);
+ auto indices_is_valid = indices.is_valid;
+ auto indices_offset = indices.offset;
+
+ auto out = out_arr->buffers[1]->mutable_data();
+ auto out_is_valid = out_arr->buffers[0]->mutable_data();
+ auto out_offset = out_arr->offset;
+
+ // If either the values or indices have nulls, we preemptively zero out the
+ // out validity bitmap so that we don't have to use ClearBit in each
+ // iteration for nulls.
+ if (values.null_count != 0 || indices.null_count != 0) {
+ BitUtil::SetBitsTo(out_is_valid, out_offset, indices.length, false);
+ }
+ // Avoid uninitialized data in values array
+ BitUtil::SetBitsTo(out, out_offset, indices.length, false);
+
+ auto PlaceDataBit = [&](int64_t loc, IndexCType index) {
+ BitUtil::SetBitTo(out, out_offset + loc,
+ BitUtil::GetBit(values_data, values_offset + index));
+ };
+
+ OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset,
+ indices.length);
+ int64_t position = 0;
+ int64_t valid_count = 0;
+ while (position < indices.length) {
+ BitBlockCount block = indices_bit_counter.NextBlock();
+ if (values.null_count == 0) {
+ // Values are never null, so things are easier
+ valid_count += block.popcount;
+ if (block.popcount == block.length) {
+ // Fastest path: neither values nor index nulls
+ BitUtil::SetBitsTo(out_is_valid, out_offset + position, block.length, true);
+ for (int64_t i = 0; i < block.length; ++i) {
+ PlaceDataBit(position, indices_data[position]);
+ ++position;
+ }
+ } else if (block.popcount > 0) {
+ // Slow path: some but not all indices are null
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) {
+ // index is not null
+ BitUtil::SetBit(out_is_valid, out_offset + position);
+ PlaceDataBit(position, indices_data[position]);
+ }
+ ++position;
+ }
+ } else {
+ position += block.length;
+ }
+ } else {
+ // Values have nulls, so we must do random access into the values bitmap
+ if (block.popcount == block.length) {
+ // Faster path: indices are not null but values may be
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (BitUtil::GetBit(values_is_valid,
+ values_offset + indices_data[position])) {
+ // value is not null
+ BitUtil::SetBit(out_is_valid, out_offset + position);
+ PlaceDataBit(position, indices_data[position]);
+ ++valid_count;
+ }
+ ++position;
+ }
+ } else if (block.popcount > 0) {
+ // Slow path: some but not all indices are null. Since we are doing
+ // random access in general we have to check the value nullness one by
+ // one.
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) {
+ // index is not null
+ if (BitUtil::GetBit(values_is_valid,
+ values_offset + indices_data[position])) {
+ // value is not null
+ PlaceDataBit(position, indices_data[position]);
+ BitUtil::SetBit(out_is_valid, out_offset + position);
+ ++valid_count;
+ }
+ }
+ ++position;
+ }
+ } else {
+ position += block.length;
+ }
+ }
+ }
+ out_arr->null_count = out_arr->length - valid_count;
+ }
+};
+
+template <template <typename...> class TakeImpl, typename... Args>
+void TakeIndexDispatch(const PrimitiveArg& values, const PrimitiveArg& indices,
+ ArrayData* out) {
+ // With the simplifying assumption that boundschecking has taken place
+ // already at a higher level, we can now assume that the index values are all
+ // non-negative. Thus, we can interpret signed integers as unsigned and avoid
+ // having to generate double the amount of binary code to handle each integer
+ // width.
+ switch (indices.bit_width) {
+ case 8:
+ return TakeImpl<uint8_t, Args...>::Exec(values, indices, out);
+ case 16:
+ return TakeImpl<uint16_t, Args...>::Exec(values, indices, out);
+ case 32:
+ return TakeImpl<uint32_t, Args...>::Exec(values, indices, out);
+ case 64:
+ return TakeImpl<uint64_t, Args...>::Exec(values, indices, out);
+ default:
+ DCHECK(false) << "Invalid indices byte width";
+ break;
+ }
+}
+
+Status PrimitiveTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (TakeState::Get(ctx).boundscheck) {
+ RETURN_NOT_OK(CheckIndexBounds(*batch[1].array(), batch[0].length()));
+ }
+
+ PrimitiveArg values = GetPrimitiveArg(*batch[0].array());
+ PrimitiveArg indices = GetPrimitiveArg(*batch[1].array());
+
+ ArrayData* out_arr = out->mutable_array();
+
+ // TODO: When neither values nor indices contain nulls, we can skip
+ // allocating the validity bitmap altogether and save time and space. A
+ // streamlined PrimitiveTakeImpl would need to be written that skips all
+ // interactions with the output validity bitmap, though.
+ RETURN_NOT_OK(PreallocateData(ctx, indices.length, values.bit_width,
+ /*allocate_validity=*/true, out_arr));
+ switch (values.bit_width) {
+ case 1:
+ TakeIndexDispatch<BooleanTakeImpl>(values, indices, out_arr);
+ break;
+ case 8:
+ TakeIndexDispatch<PrimitiveTakeImpl, int8_t>(values, indices, out_arr);
+ break;
+ case 16:
+ TakeIndexDispatch<PrimitiveTakeImpl, int16_t>(values, indices, out_arr);
+ break;
+ case 32:
+ TakeIndexDispatch<PrimitiveTakeImpl, int32_t>(values, indices, out_arr);
+ break;
+ case 64:
+ TakeIndexDispatch<PrimitiveTakeImpl, int64_t>(values, indices, out_arr);
+ break;
+ default:
+ DCHECK(false) << "Invalid values byte width";
+ break;
+ }
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Optimized and streamlined filter for primitive types
+
+// Use either BitBlockCounter or BinaryBitBlockCounter to quickly scan filter a
+// word at a time for the DROP selection type.
+class DropNullCounter {
+ public:
+ // validity bitmap may be null
+ DropNullCounter(const uint8_t* validity, const uint8_t* data, int64_t offset,
+ int64_t length)
+ : data_counter_(data, offset, length),
+ data_and_validity_counter_(data, offset, validity, offset, length),
+ has_validity_(validity != nullptr) {}
+
+ BitBlockCount NextBlock() {
+ if (has_validity_) {
+ // filter is true AND not null
+ return data_and_validity_counter_.NextAndWord();
+ } else {
+ return data_counter_.NextWord();
+ }
+ }
+
+ private:
+ // For when just data is present, but no validity bitmap
+ BitBlockCounter data_counter_;
+
+ // For when both validity bitmap and data are present
+ BinaryBitBlockCounter data_and_validity_counter_;
+ const bool has_validity_;
+};
+
+/// \brief The Filter implementation for primitive (fixed-width) types does not
+/// use the logical Arrow type but rather the physical C type. This way we only
+/// generate one take function for each byte width. We use the same
+/// implementation here for boolean and fixed-byte-size inputs with some
+/// template specialization.
+template <typename ArrowType>
+class PrimitiveFilterImpl {
+ public:
+ using T = typename std::conditional<std::is_same<ArrowType, BooleanType>::value,
+ uint8_t, typename ArrowType::c_type>::type;
+
+ PrimitiveFilterImpl(const PrimitiveArg& values, const PrimitiveArg& filter,
+ FilterOptions::NullSelectionBehavior null_selection,
+ ArrayData* out_arr)
+ : values_is_valid_(values.is_valid),
+ values_data_(reinterpret_cast<const T*>(values.data)),
+ values_null_count_(values.null_count),
+ values_offset_(values.offset),
+ values_length_(values.length),
+ filter_is_valid_(filter.is_valid),
+ filter_data_(filter.data),
+ filter_null_count_(filter.null_count),
+ filter_offset_(filter.offset),
+ null_selection_(null_selection) {
+ if (out_arr->buffers[0] != nullptr) {
+ // May not be allocated if neither filter nor values contains nulls
+ out_is_valid_ = out_arr->buffers[0]->mutable_data();
+ }
+ out_data_ = reinterpret_cast<T*>(out_arr->buffers[1]->mutable_data());
+ out_offset_ = out_arr->offset;
+ out_length_ = out_arr->length;
+ out_position_ = 0;
+ }
+
+ void ExecNonNull() {
+ // Fast filter when values and filter are not null
+ ::arrow::internal::VisitSetBitRunsVoid(
+ filter_data_, filter_offset_, values_length_,
+ [&](int64_t position, int64_t length) { WriteValueSegment(position, length); });
+ }
+
+ void Exec() {
+ if (filter_null_count_ == 0 && values_null_count_ == 0) {
+ return ExecNonNull();
+ }
+
+ // Bit counters used for both null_selection behaviors
+ DropNullCounter drop_null_counter(filter_is_valid_, filter_data_, filter_offset_,
+ values_length_);
+ OptionalBitBlockCounter data_counter(values_is_valid_, values_offset_,
+ values_length_);
+ OptionalBitBlockCounter filter_valid_counter(filter_is_valid_, filter_offset_,
+ values_length_);
+
+ auto WriteNotNull = [&](int64_t index) {
+ BitUtil::SetBit(out_is_valid_, out_offset_ + out_position_);
+ // Increments out_position_
+ WriteValue(index);
+ };
+
+ auto WriteMaybeNull = [&](int64_t index) {
+ BitUtil::SetBitTo(out_is_valid_, out_offset_ + out_position_,
+ BitUtil::GetBit(values_is_valid_, values_offset_ + index));
+ // Increments out_position_
+ WriteValue(index);
+ };
+
+ int64_t in_position = 0;
+ while (in_position < values_length_) {
+ BitBlockCount filter_block = drop_null_counter.NextBlock();
+ BitBlockCount filter_valid_block = filter_valid_counter.NextWord();
+ BitBlockCount data_block = data_counter.NextWord();
+ if (filter_block.AllSet() && data_block.AllSet()) {
+ // Fastest path: all values in block are included and not null
+ BitUtil::SetBitsTo(out_is_valid_, out_offset_ + out_position_,
+ filter_block.length, true);
+ WriteValueSegment(in_position, filter_block.length);
+ in_position += filter_block.length;
+ } else if (filter_block.AllSet()) {
+ // Faster: all values are selected, but some values are null
+ // Batch copy bits from values validity bitmap to output validity bitmap
+ CopyBitmap(values_is_valid_, values_offset_ + in_position, filter_block.length,
+ out_is_valid_, out_offset_ + out_position_);
+ WriteValueSegment(in_position, filter_block.length);
+ in_position += filter_block.length;
+ } else if (filter_block.NoneSet() && null_selection_ == FilterOptions::DROP) {
+ // For this exceedingly common case in low-selectivity filters we can
+ // skip further analysis of the data and move on to the next block.
+ in_position += filter_block.length;
+ } else {
+ // Some filter values are false or null
+ if (data_block.AllSet()) {
+ // No values are null
+ if (filter_valid_block.AllSet()) {
+ // Filter is non-null but some values are false
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ if (BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) {
+ WriteNotNull(in_position);
+ }
+ ++in_position;
+ }
+ } else if (null_selection_ == FilterOptions::DROP) {
+ // If any values are selected, they ARE NOT null
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ if (BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position) &&
+ BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) {
+ WriteNotNull(in_position);
+ }
+ ++in_position;
+ }
+ } else { // null_selection == FilterOptions::EMIT_NULL
+ // Data values in this block are not null
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ const bool is_valid =
+ BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position);
+ if (is_valid &&
+ BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) {
+ // Filter slot is non-null and set
+ WriteNotNull(in_position);
+ } else if (!is_valid) {
+ // Filter slot is null, so we have a null in the output
+ BitUtil::ClearBit(out_is_valid_, out_offset_ + out_position_);
+ WriteNull();
+ }
+ ++in_position;
+ }
+ }
+ } else { // !data_block.AllSet()
+ // Some values are null
+ if (filter_valid_block.AllSet()) {
+ // Filter is non-null but some values are false
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ if (BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) {
+ WriteMaybeNull(in_position);
+ }
+ ++in_position;
+ }
+ } else if (null_selection_ == FilterOptions::DROP) {
+ // If any values are selected, they ARE NOT null
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ if (BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position) &&
+ BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) {
+ WriteMaybeNull(in_position);
+ }
+ ++in_position;
+ }
+ } else { // null_selection == FilterOptions::EMIT_NULL
+ // Data values in this block are not null
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ const bool is_valid =
+ BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position);
+ if (is_valid &&
+ BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) {
+ // Filter slot is non-null and set
+ WriteMaybeNull(in_position);
+ } else if (!is_valid) {
+ // Filter slot is null, so we have a null in the output
+ BitUtil::ClearBit(out_is_valid_, out_offset_ + out_position_);
+ WriteNull();
+ }
+ ++in_position;
+ }
+ }
+ }
+ } // !filter_block.AllSet()
+ } // while(in_position < values_length_)
+ }
+
+ // Write the next out_position given the selected in_position for the input
+ // data and advance out_position
+ void WriteValue(int64_t in_position) {
+ out_data_[out_position_++] = values_data_[in_position];
+ }
+
+ void WriteValueSegment(int64_t in_start, int64_t length) {
+ std::memcpy(out_data_ + out_position_, values_data_ + in_start, length * sizeof(T));
+ out_position_ += length;
+ }
+
+ void WriteNull() {
+ // Zero the memory
+ out_data_[out_position_++] = T{};
+ }
+
+ private:
+ const uint8_t* values_is_valid_;
+ const T* values_data_;
+ int64_t values_null_count_;
+ int64_t values_offset_;
+ int64_t values_length_;
+ const uint8_t* filter_is_valid_;
+ const uint8_t* filter_data_;
+ int64_t filter_null_count_;
+ int64_t filter_offset_;
+ FilterOptions::NullSelectionBehavior null_selection_;
+ uint8_t* out_is_valid_;
+ T* out_data_;
+ int64_t out_offset_;
+ int64_t out_length_;
+ int64_t out_position_;
+};
+
+template <>
+inline void PrimitiveFilterImpl<BooleanType>::WriteValue(int64_t in_position) {
+ BitUtil::SetBitTo(out_data_, out_offset_ + out_position_++,
+ BitUtil::GetBit(values_data_, values_offset_ + in_position));
+}
+
+template <>
+inline void PrimitiveFilterImpl<BooleanType>::WriteValueSegment(int64_t in_start,
+ int64_t length) {
+ CopyBitmap(values_data_, values_offset_ + in_start, length, out_data_,
+ out_offset_ + out_position_);
+ out_position_ += length;
+}
+
+template <>
+inline void PrimitiveFilterImpl<BooleanType>::WriteNull() {
+ // Zero the bit
+ BitUtil::ClearBit(out_data_, out_offset_ + out_position_++);
+}
+
+Status PrimitiveFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ PrimitiveArg values = GetPrimitiveArg(*batch[0].array());
+ PrimitiveArg filter = GetPrimitiveArg(*batch[1].array());
+ FilterOptions::NullSelectionBehavior null_selection =
+ FilterState::Get(ctx).null_selection_behavior;
+
+ int64_t output_length = GetFilterOutputSize(*batch[1].array(), null_selection);
+
+ ArrayData* out_arr = out->mutable_array();
+
+ // The output precomputed null count is unknown except in the narrow
+ // condition that all the values are non-null and the filter will not cause
+ // any new nulls to be created.
+ if (values.null_count == 0 &&
+ (null_selection == FilterOptions::DROP || filter.null_count == 0)) {
+ out_arr->null_count = 0;
+ } else {
+ out_arr->null_count = kUnknownNullCount;
+ }
+
+ // When neither the values nor filter is known to have any nulls, we will
+ // elect the optimized ExecNonNull path where there is no need to populate a
+ // validity bitmap.
+ bool allocate_validity = values.null_count != 0 || filter.null_count != 0;
+
+ RETURN_NOT_OK(
+ PreallocateData(ctx, output_length, values.bit_width, allocate_validity, out_arr));
+
+ switch (values.bit_width) {
+ case 1:
+ PrimitiveFilterImpl<BooleanType>(values, filter, null_selection, out_arr).Exec();
+ break;
+ case 8:
+ PrimitiveFilterImpl<UInt8Type>(values, filter, null_selection, out_arr).Exec();
+ break;
+ case 16:
+ PrimitiveFilterImpl<UInt16Type>(values, filter, null_selection, out_arr).Exec();
+ break;
+ case 32:
+ PrimitiveFilterImpl<UInt32Type>(values, filter, null_selection, out_arr).Exec();
+ break;
+ case 64:
+ PrimitiveFilterImpl<UInt64Type>(values, filter, null_selection, out_arr).Exec();
+ break;
+ default:
+ DCHECK(false) << "Invalid values bit width";
+ break;
+ }
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Optimized filter for base binary types (32-bit and 64-bit)
+
+#define BINARY_FILTER_SETUP_COMMON() \
+ auto raw_offsets = \
+ reinterpret_cast<const offset_type*>(values.buffers[1]->data()) + values.offset; \
+ const uint8_t* raw_data = values.buffers[2]->data(); \
+ \
+ TypedBufferBuilder<offset_type> offset_builder(ctx->memory_pool()); \
+ TypedBufferBuilder<uint8_t> data_builder(ctx->memory_pool()); \
+ RETURN_NOT_OK(offset_builder.Reserve(output_length + 1)); \
+ \
+ /* Presize the data builder with a rough estimate */ \
+ if (values.length > 0) { \
+ const double mean_value_length = (raw_offsets[values.length] - raw_offsets[0]) / \
+ static_cast<double>(values.length); \
+ RETURN_NOT_OK( \
+ data_builder.Reserve(static_cast<int64_t>(mean_value_length * output_length))); \
+ } \
+ int64_t space_available = data_builder.capacity(); \
+ offset_type offset = 0;
+
+#define APPEND_RAW_DATA(DATA, NBYTES) \
+ if (ARROW_PREDICT_FALSE(NBYTES > space_available)) { \
+ RETURN_NOT_OK(data_builder.Reserve(NBYTES)); \
+ space_available = data_builder.capacity() - data_builder.length(); \
+ } \
+ data_builder.UnsafeAppend(DATA, NBYTES); \
+ space_available -= NBYTES
+
+#define APPEND_SINGLE_VALUE() \
+ do { \
+ offset_type val_size = raw_offsets[in_position + 1] - raw_offsets[in_position]; \
+ APPEND_RAW_DATA(raw_data + raw_offsets[in_position], val_size); \
+ offset += val_size; \
+ } while (0)
+
+// Optimized binary filter for the case where neither values nor filter have
+// nulls
+template <typename Type>
+Status BinaryFilterNonNullImpl(KernelContext* ctx, const ArrayData& values,
+ const ArrayData& filter, int64_t output_length,
+ FilterOptions::NullSelectionBehavior null_selection,
+ ArrayData* out) {
+ using offset_type = typename Type::offset_type;
+ const auto filter_data = filter.buffers[1]->data();
+
+ BINARY_FILTER_SETUP_COMMON();
+
+ RETURN_NOT_OK(arrow::internal::VisitSetBitRuns(
+ filter_data, filter.offset, filter.length, [&](int64_t position, int64_t length) {
+ // Bulk-append raw data
+ const offset_type run_data_bytes =
+ (raw_offsets[position + length] - raw_offsets[position]);
+ APPEND_RAW_DATA(raw_data + raw_offsets[position], run_data_bytes);
+ // Append offsets
+ offset_type cur_offset = raw_offsets[position];
+ for (int64_t i = 0; i < length; ++i) {
+ offset_builder.UnsafeAppend(offset);
+ offset += raw_offsets[i + position + 1] - cur_offset;
+ cur_offset = raw_offsets[i + position + 1];
+ }
+ return Status::OK();
+ }));
+
+ offset_builder.UnsafeAppend(offset);
+ out->length = output_length;
+ RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1]));
+ return data_builder.Finish(&out->buffers[2]);
+}
+
+template <typename Type>
+Status BinaryFilterImpl(KernelContext* ctx, const ArrayData& values,
+ const ArrayData& filter, int64_t output_length,
+ FilterOptions::NullSelectionBehavior null_selection,
+ ArrayData* out) {
+ using offset_type = typename Type::offset_type;
+
+ const auto filter_data = filter.buffers[1]->data();
+ const uint8_t* filter_is_valid = GetValidityBitmap(filter);
+ const int64_t filter_offset = filter.offset;
+
+ const uint8_t* values_is_valid = GetValidityBitmap(values);
+ const int64_t values_offset = values.offset;
+
+ uint8_t* out_is_valid = out->buffers[0]->mutable_data();
+ // Zero bits and then only have to set valid values to true
+ BitUtil::SetBitsTo(out_is_valid, 0, output_length, false);
+
+ // We use 3 block counters for fast scanning of the filter
+ //
+ // * values_valid_counter: for values null/not-null
+ // * filter_valid_counter: for filter null/not-null
+ // * filter_counter: for filter true/false
+ OptionalBitBlockCounter values_valid_counter(values_is_valid, values_offset,
+ values.length);
+ OptionalBitBlockCounter filter_valid_counter(filter_is_valid, filter_offset,
+ filter.length);
+ BitBlockCounter filter_counter(filter_data, filter_offset, filter.length);
+
+ BINARY_FILTER_SETUP_COMMON();
+
+ int64_t in_position = 0;
+ int64_t out_position = 0;
+ while (in_position < filter.length) {
+ BitBlockCount filter_valid_block = filter_valid_counter.NextWord();
+ BitBlockCount values_valid_block = values_valid_counter.NextWord();
+ BitBlockCount filter_block = filter_counter.NextWord();
+ if (filter_block.NoneSet() && null_selection == FilterOptions::DROP) {
+ // For this exceedingly common case in low-selectivity filters we can
+ // skip further analysis of the data and move on to the next block.
+ in_position += filter_block.length;
+ } else if (filter_valid_block.AllSet()) {
+ // Simpler path: no filter values are null
+ if (filter_block.AllSet()) {
+ // Fastest path: filter values are all true and not null
+ if (values_valid_block.AllSet()) {
+ // The values aren't null either
+ BitUtil::SetBitsTo(out_is_valid, out_position, filter_block.length, true);
+
+ // Bulk-append raw data
+ offset_type block_data_bytes =
+ (raw_offsets[in_position + filter_block.length] - raw_offsets[in_position]);
+ APPEND_RAW_DATA(raw_data + raw_offsets[in_position], block_data_bytes);
+ // Append offsets
+ for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) {
+ offset_builder.UnsafeAppend(offset);
+ offset += raw_offsets[in_position + 1] - raw_offsets[in_position];
+ }
+ out_position += filter_block.length;
+ } else {
+ // Some of the values in this block are null
+ for (int64_t i = 0; i < filter_block.length;
+ ++i, ++in_position, ++out_position) {
+ offset_builder.UnsafeAppend(offset);
+ if (BitUtil::GetBit(values_is_valid, values_offset + in_position)) {
+ BitUtil::SetBit(out_is_valid, out_position);
+ APPEND_SINGLE_VALUE();
+ }
+ }
+ }
+ } else { // !filter_block.AllSet()
+ // Some of the filter values are false, but all not null
+ if (values_valid_block.AllSet()) {
+ // All the values are not-null, so we can skip null checking for
+ // them
+ for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) {
+ if (BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+ offset_builder.UnsafeAppend(offset);
+ BitUtil::SetBit(out_is_valid, out_position++);
+ APPEND_SINGLE_VALUE();
+ }
+ }
+ } else {
+ // Some of the values in the block are null, so we have to check
+ // each one
+ for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) {
+ if (BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+ offset_builder.UnsafeAppend(offset);
+ if (BitUtil::GetBit(values_is_valid, values_offset + in_position)) {
+ BitUtil::SetBit(out_is_valid, out_position);
+ APPEND_SINGLE_VALUE();
+ }
+ ++out_position;
+ }
+ }
+ }
+ }
+ } else { // !filter_valid_block.AllSet()
+ // Some of the filter values are null, so we have to handle the DROP
+ // versus EMIT_NULL null selection behavior.
+ if (null_selection == FilterOptions::DROP) {
+ // Filter null values are treated as false.
+ if (values_valid_block.AllSet()) {
+ for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) {
+ if (BitUtil::GetBit(filter_is_valid, filter_offset + in_position) &&
+ BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+ offset_builder.UnsafeAppend(offset);
+ BitUtil::SetBit(out_is_valid, out_position++);
+ APPEND_SINGLE_VALUE();
+ }
+ }
+ } else {
+ for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) {
+ if (BitUtil::GetBit(filter_is_valid, filter_offset + in_position) &&
+ BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+ offset_builder.UnsafeAppend(offset);
+ if (BitUtil::GetBit(values_is_valid, values_offset + in_position)) {
+ BitUtil::SetBit(out_is_valid, out_position);
+ APPEND_SINGLE_VALUE();
+ }
+ ++out_position;
+ }
+ }
+ }
+ } else {
+ // EMIT_NULL
+
+ // Filter null values are appended to output as null whether the
+ // value in the corresponding slot is valid or not
+ if (values_valid_block.AllSet()) {
+ for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) {
+ const bool filter_not_null =
+ BitUtil::GetBit(filter_is_valid, filter_offset + in_position);
+ if (filter_not_null &&
+ BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+ offset_builder.UnsafeAppend(offset);
+ BitUtil::SetBit(out_is_valid, out_position++);
+ APPEND_SINGLE_VALUE();
+ } else if (!filter_not_null) {
+ offset_builder.UnsafeAppend(offset);
+ ++out_position;
+ }
+ }
+ } else {
+ for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) {
+ const bool filter_not_null =
+ BitUtil::GetBit(filter_is_valid, filter_offset + in_position);
+ if (filter_not_null &&
+ BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+ offset_builder.UnsafeAppend(offset);
+ if (BitUtil::GetBit(values_is_valid, values_offset + in_position)) {
+ BitUtil::SetBit(out_is_valid, out_position);
+ APPEND_SINGLE_VALUE();
+ }
+ ++out_position;
+ } else if (!filter_not_null) {
+ offset_builder.UnsafeAppend(offset);
+ ++out_position;
+ }
+ }
+ }
+ }
+ }
+ }
+ offset_builder.UnsafeAppend(offset);
+ out->length = output_length;
+ RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1]));
+ return data_builder.Finish(&out->buffers[2]);
+}
+
+#undef BINARY_FILTER_SETUP_COMMON
+#undef APPEND_RAW_DATA
+#undef APPEND_SINGLE_VALUE
+
+Status BinaryFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ FilterOptions::NullSelectionBehavior null_selection =
+ FilterState::Get(ctx).null_selection_behavior;
+
+ const ArrayData& values = *batch[0].array();
+ const ArrayData& filter = *batch[1].array();
+ int64_t output_length = GetFilterOutputSize(filter, null_selection);
+ ArrayData* out_arr = out->mutable_array();
+
+ // The output precomputed null count is unknown except in the narrow
+ // condition that all the values are non-null and the filter will not cause
+ // any new nulls to be created.
+ if (values.null_count == 0 &&
+ (null_selection == FilterOptions::DROP || filter.null_count == 0)) {
+ out_arr->null_count = 0;
+ } else {
+ out_arr->null_count = kUnknownNullCount;
+ }
+ Type::type type_id = values.type->id();
+ if (values.null_count == 0 && filter.null_count == 0) {
+ // Faster no-nulls case
+ if (is_binary_like(type_id)) {
+ RETURN_NOT_OK(BinaryFilterNonNullImpl<BinaryType>(
+ ctx, values, filter, output_length, null_selection, out_arr));
+ } else if (is_large_binary_like(type_id)) {
+ RETURN_NOT_OK(BinaryFilterNonNullImpl<LargeBinaryType>(
+ ctx, values, filter, output_length, null_selection, out_arr));
+ } else {
+ DCHECK(false);
+ }
+ } else {
+ // Output may have nulls
+ RETURN_NOT_OK(ctx->AllocateBitmap(output_length).Value(&out_arr->buffers[0]));
+ if (is_binary_like(type_id)) {
+ RETURN_NOT_OK(BinaryFilterImpl<BinaryType>(ctx, values, filter, output_length,
+ null_selection, out_arr));
+ } else if (is_large_binary_like(type_id)) {
+ RETURN_NOT_OK(BinaryFilterImpl<LargeBinaryType>(ctx, values, filter, output_length,
+ null_selection, out_arr));
+ } else {
+ DCHECK(false);
+ }
+ }
+
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Null take and filter
+
+Status NullTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (TakeState::Get(ctx).boundscheck) {
+ RETURN_NOT_OK(CheckIndexBounds(*batch[1].array(), batch[0].length()));
+ }
+ // batch.length doesn't take into account the take indices
+ auto new_length = batch[1].array()->length;
+ out->value = std::make_shared<NullArray>(new_length)->data();
+ return Status::OK();
+}
+
+Status NullFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ int64_t output_length = GetFilterOutputSize(
+ *batch[1].array(), FilterState::Get(ctx).null_selection_behavior);
+ out->value = std::make_shared<NullArray>(output_length)->data();
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Dictionary take and filter
+
+Status DictionaryTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DictionaryArray values(batch[0].array());
+ Datum result;
+ RETURN_NOT_OK(
+ Take(Datum(values.indices()), batch[1], TakeState::Get(ctx), ctx->exec_context())
+ .Value(&result));
+ DictionaryArray taken_values(values.type(), result.make_array(), values.dictionary());
+ out->value = taken_values.data();
+ return Status::OK();
+}
+
+Status DictionaryFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ DictionaryArray dict_values(batch[0].array());
+ Datum result;
+ RETURN_NOT_OK(Filter(Datum(dict_values.indices()), batch[1].array(),
+ FilterState::Get(ctx), ctx->exec_context())
+ .Value(&result));
+ DictionaryArray filtered_values(dict_values.type(), result.make_array(),
+ dict_values.dictionary());
+ out->value = filtered_values.data();
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Extension take and filter
+
+Status ExtensionTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ ExtensionArray values(batch[0].array());
+ Datum result;
+ RETURN_NOT_OK(
+ Take(Datum(values.storage()), batch[1], TakeState::Get(ctx), ctx->exec_context())
+ .Value(&result));
+ ExtensionArray taken_values(values.type(), result.make_array());
+ out->value = taken_values.data();
+ return Status::OK();
+}
+
+Status ExtensionFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ ExtensionArray ext_values(batch[0].array());
+ Datum result;
+ RETURN_NOT_OK(Filter(Datum(ext_values.storage()), batch[1].array(),
+ FilterState::Get(ctx), ctx->exec_context())
+ .Value(&result));
+ ExtensionArray filtered_values(ext_values.type(), result.make_array());
+ out->value = filtered_values.data();
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Implement take for other data types where there is less performance
+// sensitivity by visiting the selected indices.
+
+// Use CRTP to dispatch to type-specific processing of take indices for each
+// unsigned integer type.
+template <typename Impl, typename Type>
+struct Selection {
+ using ValuesArrayType = typename TypeTraits<Type>::ArrayType;
+
+ // Forwards the generic value visitors to the take index visitor template
+ template <typename IndexCType>
+ struct TakeAdapter {
+ static constexpr bool is_take = true;
+
+ Impl* impl;
+ explicit TakeAdapter(Impl* impl) : impl(impl) {}
+ template <typename ValidVisitor, typename NullVisitor>
+ Status Generate(ValidVisitor&& visit_valid, NullVisitor&& visit_null) {
+ return impl->template VisitTake<IndexCType>(std::forward<ValidVisitor>(visit_valid),
+ std::forward<NullVisitor>(visit_null));
+ }
+ };
+
+ // Forwards the generic value visitors to the VisitFilter template
+ struct FilterAdapter {
+ static constexpr bool is_take = false;
+
+ Impl* impl;
+ explicit FilterAdapter(Impl* impl) : impl(impl) {}
+ template <typename ValidVisitor, typename NullVisitor>
+ Status Generate(ValidVisitor&& visit_valid, NullVisitor&& visit_null) {
+ return impl->VisitFilter(std::forward<ValidVisitor>(visit_valid),
+ std::forward<NullVisitor>(visit_null));
+ }
+ };
+
+ KernelContext* ctx;
+ std::shared_ptr<ArrayData> values;
+ std::shared_ptr<ArrayData> selection;
+ int64_t output_length;
+ ArrayData* out;
+ TypedBufferBuilder<bool> validity_builder;
+
+ Selection(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out)
+ : ctx(ctx),
+ values(batch[0].array()),
+ selection(batch[1].array()),
+ output_length(output_length),
+ out(out->mutable_array()),
+ validity_builder(ctx->memory_pool()) {}
+
+ virtual ~Selection() = default;
+
+ Status FinishCommon() {
+ out->buffers.resize(values->buffers.size());
+ out->length = validity_builder.length();
+ out->null_count = validity_builder.false_count();
+ return validity_builder.Finish(&out->buffers[0]);
+ }
+
+ template <typename IndexCType, typename ValidVisitor, typename NullVisitor>
+ Status VisitTake(ValidVisitor&& visit_valid, NullVisitor&& visit_null) {
+ const auto indices_values = selection->GetValues<IndexCType>(1);
+ const uint8_t* is_valid = GetValidityBitmap(*selection);
+ OptionalBitIndexer indices_is_valid(selection->buffers[0], selection->offset);
+ OptionalBitIndexer values_is_valid(values->buffers[0], values->offset);
+
+ const bool values_have_nulls = values->MayHaveNulls();
+ OptionalBitBlockCounter bit_counter(is_valid, selection->offset, selection->length);
+ int64_t position = 0;
+ while (position < selection->length) {
+ BitBlockCount block = bit_counter.NextBlock();
+ const bool indices_have_nulls = block.popcount < block.length;
+ if (!indices_have_nulls && !values_have_nulls) {
+ // Fastest path, neither indices nor values have nulls
+ validity_builder.UnsafeAppend(block.length, true);
+ for (int64_t i = 0; i < block.length; ++i) {
+ RETURN_NOT_OK(visit_valid(indices_values[position++]));
+ }
+ } else if (block.popcount > 0) {
+ // Since we have to branch on whether the indices are null or not, we
+ // combine the "non-null indices block but some values null" and
+ // "some-null indices block but values non-null" into a single loop.
+ for (int64_t i = 0; i < block.length; ++i) {
+ if ((!indices_have_nulls || indices_is_valid[position]) &&
+ values_is_valid[indices_values[position]]) {
+ validity_builder.UnsafeAppend(true);
+ RETURN_NOT_OK(visit_valid(indices_values[position]));
+ } else {
+ validity_builder.UnsafeAppend(false);
+ RETURN_NOT_OK(visit_null());
+ }
+ ++position;
+ }
+ } else {
+ // The whole block is null
+ validity_builder.UnsafeAppend(block.length, false);
+ for (int64_t i = 0; i < block.length; ++i) {
+ RETURN_NOT_OK(visit_null());
+ }
+ position += block.length;
+ }
+ }
+ return Status::OK();
+ }
+
+ // We use the NullVisitor both for "selected" nulls as well as "emitted"
+ // nulls coming from the filter when using FilterOptions::EMIT_NULL
+ template <typename ValidVisitor, typename NullVisitor>
+ Status VisitFilter(ValidVisitor&& visit_valid, NullVisitor&& visit_null) {
+ auto null_selection = FilterState::Get(ctx).null_selection_behavior;
+
+ const auto filter_data = selection->buffers[1]->data();
+
+ const uint8_t* filter_is_valid = GetValidityBitmap(*selection);
+ const int64_t filter_offset = selection->offset;
+ OptionalBitIndexer values_is_valid(values->buffers[0], values->offset);
+
+ // We use 3 block counters for fast scanning of the filter
+ //
+ // * values_valid_counter: for values null/not-null
+ // * filter_valid_counter: for filter null/not-null
+ // * filter_counter: for filter true/false
+ OptionalBitBlockCounter values_valid_counter(GetValidityBitmap(*values),
+ values->offset, values->length);
+ OptionalBitBlockCounter filter_valid_counter(filter_is_valid, filter_offset,
+ selection->length);
+ BitBlockCounter filter_counter(filter_data, filter_offset, selection->length);
+ int64_t in_position = 0;
+
+ auto AppendNotNull = [&](int64_t index) -> Status {
+ validity_builder.UnsafeAppend(true);
+ return visit_valid(index);
+ };
+
+ auto AppendNull = [&]() -> Status {
+ validity_builder.UnsafeAppend(false);
+ return visit_null();
+ };
+
+ auto AppendMaybeNull = [&](int64_t index) -> Status {
+ if (values_is_valid[index]) {
+ return AppendNotNull(index);
+ } else {
+ return AppendNull();
+ }
+ };
+
+ while (in_position < selection->length) {
+ BitBlockCount filter_valid_block = filter_valid_counter.NextWord();
+ BitBlockCount values_valid_block = values_valid_counter.NextWord();
+ BitBlockCount filter_block = filter_counter.NextWord();
+ if (filter_block.NoneSet() && null_selection == FilterOptions::DROP) {
+ // For this exceedingly common case in low-selectivity filters we can
+ // skip further analysis of the data and move on to the next block.
+ in_position += filter_block.length;
+ } else if (filter_valid_block.AllSet()) {
+ // Simpler path: no filter values are null
+ if (filter_block.AllSet()) {
+ // Fastest path: filter values are all true and not null
+ if (values_valid_block.AllSet()) {
+ // The values aren't null either
+ validity_builder.UnsafeAppend(filter_block.length, true);
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ RETURN_NOT_OK(visit_valid(in_position++));
+ }
+ } else {
+ // Some of the values in this block are null
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ RETURN_NOT_OK(AppendMaybeNull(in_position++));
+ }
+ }
+ } else { // !filter_block.AllSet()
+ // Some of the filter values are false, but all not null
+ if (values_valid_block.AllSet()) {
+ // All the values are not-null, so we can skip null checking for
+ // them
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ if (BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+ RETURN_NOT_OK(AppendNotNull(in_position));
+ }
+ ++in_position;
+ }
+ } else {
+ // Some of the values in the block are null, so we have to check
+ // each one
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ if (BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+ RETURN_NOT_OK(AppendMaybeNull(in_position));
+ }
+ ++in_position;
+ }
+ }
+ }
+ } else { // !filter_valid_block.AllSet()
+ // Some of the filter values are null, so we have to handle the DROP
+ // versus EMIT_NULL null selection behavior.
+ if (null_selection == FilterOptions::DROP) {
+ // Filter null values are treated as false.
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ if (BitUtil::GetBit(filter_is_valid, filter_offset + in_position) &&
+ BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+ RETURN_NOT_OK(AppendMaybeNull(in_position));
+ }
+ ++in_position;
+ }
+ } else {
+ // Filter null values are appended to output as null whether the
+ // value in the corresponding slot is valid or not
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ const bool filter_not_null =
+ BitUtil::GetBit(filter_is_valid, filter_offset + in_position);
+ if (filter_not_null &&
+ BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+ RETURN_NOT_OK(AppendMaybeNull(in_position));
+ } else if (!filter_not_null) {
+ // EMIT_NULL case
+ RETURN_NOT_OK(AppendNull());
+ }
+ ++in_position;
+ }
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ virtual Status Init() { return Status::OK(); }
+
+ // Implementation specific finish logic
+ virtual Status Finish() = 0;
+
+ Status ExecTake() {
+ RETURN_NOT_OK(this->validity_builder.Reserve(output_length));
+ RETURN_NOT_OK(Init());
+ int index_width = GetByteWidth(*this->selection->type);
+
+ // CTRP dispatch here
+ switch (index_width) {
+ case 1: {
+ Status s =
+ static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint8_t>>();
+ RETURN_NOT_OK(s);
+ } break;
+ case 2: {
+ Status s =
+ static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint16_t>>();
+ RETURN_NOT_OK(s);
+ } break;
+ case 4: {
+ Status s =
+ static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint32_t>>();
+ RETURN_NOT_OK(s);
+ } break;
+ case 8: {
+ Status s =
+ static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint64_t>>();
+ RETURN_NOT_OK(s);
+ } break;
+ default:
+ DCHECK(false) << "Invalid index width";
+ break;
+ }
+ RETURN_NOT_OK(this->FinishCommon());
+ return Finish();
+ }
+
+ Status ExecFilter() {
+ RETURN_NOT_OK(this->validity_builder.Reserve(output_length));
+ RETURN_NOT_OK(Init());
+ // CRTP dispatch
+ Status s = static_cast<Impl*>(this)->template GenerateOutput<FilterAdapter>();
+ RETURN_NOT_OK(s);
+ RETURN_NOT_OK(this->FinishCommon());
+ return Finish();
+ }
+};
+
+#define LIFT_BASE_MEMBERS() \
+ using ValuesArrayType = typename Base::ValuesArrayType; \
+ using Base::ctx; \
+ using Base::values; \
+ using Base::selection; \
+ using Base::output_length; \
+ using Base::out; \
+ using Base::validity_builder
+
+static inline Status VisitNoop() { return Status::OK(); }
+
+// A selection implementation for 32-bit and 64-bit variable binary
+// types. Common generated kernels are shared between Binary/String and
+// LargeBinary/LargeString
+template <typename Type>
+struct VarBinaryImpl : public Selection<VarBinaryImpl<Type>, Type> {
+ using offset_type = typename Type::offset_type;
+
+ using Base = Selection<VarBinaryImpl<Type>, Type>;
+ LIFT_BASE_MEMBERS();
+
+ std::shared_ptr<ArrayData> values_as_binary;
+ TypedBufferBuilder<offset_type> offset_builder;
+ TypedBufferBuilder<uint8_t> data_builder;
+
+ static constexpr int64_t kOffsetLimit = std::numeric_limits<offset_type>::max() - 1;
+
+ VarBinaryImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length,
+ Datum* out)
+ : Base(ctx, batch, output_length, out),
+ offset_builder(ctx->memory_pool()),
+ data_builder(ctx->memory_pool()) {}
+
+ template <typename Adapter>
+ Status GenerateOutput() {
+ ValuesArrayType typed_values(this->values_as_binary);
+
+ // Presize the data builder with a rough estimate of the required data size
+ if (values->length > 0) {
+ const double mean_value_length =
+ (typed_values.total_values_length() / static_cast<double>(values->length));
+
+ // TODO: See if possible to reduce output_length for take/filter cases
+ // where there are nulls in the selection array
+ RETURN_NOT_OK(
+ data_builder.Reserve(static_cast<int64_t>(mean_value_length * output_length)));
+ }
+ int64_t space_available = data_builder.capacity();
+
+ const offset_type* raw_offsets = typed_values.raw_value_offsets();
+ const uint8_t* raw_data = typed_values.raw_data();
+
+ offset_type offset = 0;
+ Adapter adapter(this);
+ RETURN_NOT_OK(adapter.Generate(
+ [&](int64_t index) {
+ offset_builder.UnsafeAppend(offset);
+ offset_type val_offset = raw_offsets[index];
+ offset_type val_size = raw_offsets[index + 1] - val_offset;
+
+ // Use static property to prune this code from the filter path in
+ // optimized builds
+ if (Adapter::is_take &&
+ ARROW_PREDICT_FALSE(static_cast<int64_t>(offset) +
+ static_cast<int64_t>(val_size)) > kOffsetLimit) {
+ return Status::Invalid("Take operation overflowed binary array capacity");
+ }
+ offset += val_size;
+ if (ARROW_PREDICT_FALSE(val_size > space_available)) {
+ RETURN_NOT_OK(data_builder.Reserve(val_size));
+ space_available = data_builder.capacity() - data_builder.length();
+ }
+ data_builder.UnsafeAppend(raw_data + val_offset, val_size);
+ space_available -= val_size;
+ return Status::OK();
+ },
+ [&]() {
+ offset_builder.UnsafeAppend(offset);
+ return Status::OK();
+ }));
+ offset_builder.UnsafeAppend(offset);
+ return Status::OK();
+ }
+
+ Status Init() override {
+ ARROW_ASSIGN_OR_RAISE(this->values_as_binary,
+ GetArrayView(this->values, TypeTraits<Type>::type_singleton()));
+ return offset_builder.Reserve(output_length + 1);
+ }
+
+ Status Finish() override {
+ RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1]));
+ return data_builder.Finish(&out->buffers[2]);
+ }
+};
+
+struct FSBImpl : public Selection<FSBImpl, FixedSizeBinaryType> {
+ using Base = Selection<FSBImpl, FixedSizeBinaryType>;
+ LIFT_BASE_MEMBERS();
+
+ TypedBufferBuilder<uint8_t> data_builder;
+
+ FSBImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out)
+ : Base(ctx, batch, output_length, out), data_builder(ctx->memory_pool()) {}
+
+ template <typename Adapter>
+ Status GenerateOutput() {
+ FixedSizeBinaryArray typed_values(this->values);
+ int32_t value_size = typed_values.byte_width();
+
+ RETURN_NOT_OK(data_builder.Reserve(value_size * output_length));
+ Adapter adapter(this);
+ return adapter.Generate(
+ [&](int64_t index) {
+ auto val = typed_values.GetView(index);
+ data_builder.UnsafeAppend(reinterpret_cast<const uint8_t*>(val.data()),
+ value_size);
+ return Status::OK();
+ },
+ [&]() {
+ data_builder.UnsafeAppend(value_size, static_cast<uint8_t>(0x00));
+ return Status::OK();
+ });
+ }
+
+ Status Finish() override { return data_builder.Finish(&out->buffers[1]); }
+};
+
+template <typename Type>
+struct ListImpl : public Selection<ListImpl<Type>, Type> {
+ using offset_type = typename Type::offset_type;
+
+ using Base = Selection<ListImpl<Type>, Type>;
+ LIFT_BASE_MEMBERS();
+
+ TypedBufferBuilder<offset_type> offset_builder;
+ typename TypeTraits<Type>::OffsetBuilderType child_index_builder;
+
+ ListImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out)
+ : Base(ctx, batch, output_length, out),
+ offset_builder(ctx->memory_pool()),
+ child_index_builder(ctx->memory_pool()) {}
+
+ template <typename Adapter>
+ Status GenerateOutput() {
+ ValuesArrayType typed_values(this->values);
+
+ // TODO presize child_index_builder with a similar heuristic as VarBinaryImpl
+
+ offset_type offset = 0;
+ Adapter adapter(this);
+ RETURN_NOT_OK(adapter.Generate(
+ [&](int64_t index) {
+ offset_builder.UnsafeAppend(offset);
+ offset_type value_offset = typed_values.value_offset(index);
+ offset_type value_length = typed_values.value_length(index);
+ offset += value_length;
+ RETURN_NOT_OK(child_index_builder.Reserve(value_length));
+ for (offset_type j = value_offset; j < value_offset + value_length; ++j) {
+ child_index_builder.UnsafeAppend(j);
+ }
+ return Status::OK();
+ },
+ [&]() {
+ offset_builder.UnsafeAppend(offset);
+ return Status::OK();
+ }));
+ offset_builder.UnsafeAppend(offset);
+ return Status::OK();
+ }
+
+ Status Init() override {
+ RETURN_NOT_OK(offset_builder.Reserve(output_length + 1));
+ return Status::OK();
+ }
+
+ Status Finish() override {
+ std::shared_ptr<Array> child_indices;
+ RETURN_NOT_OK(child_index_builder.Finish(&child_indices));
+
+ ValuesArrayType typed_values(this->values);
+
+ // No need to boundscheck the child values indices
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> taken_child,
+ Take(*typed_values.values(), *child_indices,
+ TakeOptions::NoBoundsCheck(), ctx->exec_context()));
+ RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1]));
+ out->child_data = {taken_child->data()};
+ return Status::OK();
+ }
+};
+
+struct DenseUnionImpl : public Selection<DenseUnionImpl, DenseUnionType> {
+ using Base = Selection<DenseUnionImpl, DenseUnionType>;
+ LIFT_BASE_MEMBERS();
+
+ TypedBufferBuilder<int32_t> value_offset_buffer_builder_;
+ TypedBufferBuilder<int8_t> child_id_buffer_builder_;
+ std::vector<int8_t> type_codes_;
+ std::vector<Int32Builder> child_indices_builders_;
+
+ DenseUnionImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length,
+ Datum* out)
+ : Base(ctx, batch, output_length, out),
+ value_offset_buffer_builder_(ctx->memory_pool()),
+ child_id_buffer_builder_(ctx->memory_pool()),
+ type_codes_(checked_cast<const UnionType&>(*this->values->type).type_codes()),
+ child_indices_builders_(type_codes_.size()) {
+ for (auto& child_indices_builder : child_indices_builders_) {
+ child_indices_builder = Int32Builder(ctx->memory_pool());
+ }
+ }
+
+ template <typename Adapter>
+ Status GenerateOutput() {
+ DenseUnionArray typed_values(this->values);
+ Adapter adapter(this);
+ RETURN_NOT_OK(adapter.Generate(
+ [&](int64_t index) {
+ int8_t child_id = typed_values.child_id(index);
+ child_id_buffer_builder_.UnsafeAppend(type_codes_[child_id]);
+ int32_t value_offset = typed_values.value_offset(index);
+ value_offset_buffer_builder_.UnsafeAppend(
+ static_cast<int32_t>(child_indices_builders_[child_id].length()));
+ RETURN_NOT_OK(child_indices_builders_[child_id].Reserve(1));
+ child_indices_builders_[child_id].UnsafeAppend(value_offset);
+ return Status::OK();
+ },
+ [&]() {
+ int8_t child_id = 0;
+ child_id_buffer_builder_.UnsafeAppend(type_codes_[child_id]);
+ value_offset_buffer_builder_.UnsafeAppend(
+ static_cast<int32_t>(child_indices_builders_[child_id].length()));
+ RETURN_NOT_OK(child_indices_builders_[child_id].Reserve(1));
+ child_indices_builders_[child_id].UnsafeAppendNull();
+ return Status::OK();
+ }));
+ return Status::OK();
+ }
+
+ Status Init() override {
+ RETURN_NOT_OK(child_id_buffer_builder_.Reserve(output_length));
+ RETURN_NOT_OK(value_offset_buffer_builder_.Reserve(output_length));
+ return Status::OK();
+ }
+
+ Status Finish() override {
+ ARROW_ASSIGN_OR_RAISE(auto child_ids_buffer, child_id_buffer_builder_.Finish());
+ ARROW_ASSIGN_OR_RAISE(auto value_offsets_buffer,
+ value_offset_buffer_builder_.Finish());
+ DenseUnionArray typed_values(this->values);
+ auto num_fields = typed_values.num_fields();
+ auto num_rows = child_ids_buffer->size();
+ BufferVector buffers{nullptr, std::move(child_ids_buffer),
+ std::move(value_offsets_buffer)};
+ *out = ArrayData(typed_values.type(), num_rows, std::move(buffers), 0);
+ for (auto i = 0; i < num_fields; i++) {
+ ARROW_ASSIGN_OR_RAISE(auto child_indices_array,
+ child_indices_builders_[i].Finish());
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> child_array,
+ Take(*typed_values.field(i), *child_indices_array));
+ out->child_data.push_back(child_array->data());
+ }
+ return Status::OK();
+ }
+};
+
+struct FSLImpl : public Selection<FSLImpl, FixedSizeListType> {
+ Int64Builder child_index_builder;
+
+ using Base = Selection<FSLImpl, FixedSizeListType>;
+ LIFT_BASE_MEMBERS();
+
+ FSLImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out)
+ : Base(ctx, batch, output_length, out), child_index_builder(ctx->memory_pool()) {}
+
+ template <typename Adapter>
+ Status GenerateOutput() {
+ ValuesArrayType typed_values(this->values);
+ const int32_t list_size = typed_values.list_type()->list_size();
+ const int64_t base_offset = typed_values.offset();
+
+ // We must take list_size elements even for null elements of
+ // indices.
+ RETURN_NOT_OK(child_index_builder.Reserve(output_length * list_size));
+
+ Adapter adapter(this);
+ return adapter.Generate(
+ [&](int64_t index) {
+ int64_t offset = (base_offset + index) * list_size;
+ for (int64_t j = offset; j < offset + list_size; ++j) {
+ child_index_builder.UnsafeAppend(j);
+ }
+ return Status::OK();
+ },
+ [&]() { return child_index_builder.AppendNulls(list_size); });
+ }
+
+ Status Finish() override {
+ std::shared_ptr<Array> child_indices;
+ RETURN_NOT_OK(child_index_builder.Finish(&child_indices));
+
+ ValuesArrayType typed_values(this->values);
+
+ // No need to boundscheck the child values indices
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> taken_child,
+ Take(*typed_values.values(), *child_indices,
+ TakeOptions::NoBoundsCheck(), ctx->exec_context()));
+ out->child_data = {taken_child->data()};
+ return Status::OK();
+ }
+};
+
+// ----------------------------------------------------------------------
+// Struct selection implementations
+
+// We need a slightly different approach for StructType. For Take, we can
+// invoke Take on each struct field's data with boundschecking disabled. For
+// Filter on the other hand, if we naively call Filter on each field, then the
+// filter output length will have to be redundantly computed. Thus, for Filter
+// we instead convert the filter to selection indices and then invoke take.
+
+// Struct selection implementation. ONLY used for Take
+struct StructImpl : public Selection<StructImpl, StructType> {
+ using Base = Selection<StructImpl, StructType>;
+ LIFT_BASE_MEMBERS();
+ using Base::Base;
+
+ template <typename Adapter>
+ Status GenerateOutput() {
+ StructArray typed_values(values);
+ Adapter adapter(this);
+ // There's nothing to do for Struct except to generate the validity bitmap
+ return adapter.Generate([&](int64_t index) { return Status::OK(); },
+ /*visit_null=*/VisitNoop);
+ }
+
+ Status Finish() override {
+ StructArray typed_values(values);
+
+ // Select from children without boundschecking
+ out->child_data.resize(values->type->num_fields());
+ for (int field_index = 0; field_index < values->type->num_fields(); ++field_index) {
+ ARROW_ASSIGN_OR_RAISE(Datum taken_field,
+ Take(Datum(typed_values.field(field_index)), Datum(selection),
+ TakeOptions::NoBoundsCheck(), ctx->exec_context()));
+ out->child_data[field_index] = taken_field.array();
+ }
+ return Status::OK();
+ }
+};
+
+Status StructFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // Transform filter to selection indices and then use Take.
+ std::shared_ptr<ArrayData> indices;
+ RETURN_NOT_OK(GetTakeIndices(*batch[1].array(),
+ FilterState::Get(ctx).null_selection_behavior,
+ ctx->memory_pool())
+ .Value(&indices));
+
+ Datum result;
+ RETURN_NOT_OK(
+ Take(batch[0], Datum(indices), TakeOptions::NoBoundsCheck(), ctx->exec_context())
+ .Value(&result));
+ out->value = result.array();
+ return Status::OK();
+}
+
+#undef LIFT_BASE_MEMBERS
+
+// ----------------------------------------------------------------------
+// Implement Filter metafunction
+
+Result<std::shared_ptr<RecordBatch>> FilterRecordBatch(const RecordBatch& batch,
+ const Datum& filter,
+ const FunctionOptions* options,
+ ExecContext* ctx) {
+ if (batch.num_rows() != filter.length()) {
+ return Status::Invalid("Filter inputs must all be the same length");
+ }
+
+ // Convert filter to selection vector/indices and use Take
+ const auto& filter_opts = *static_cast<const FilterOptions*>(options);
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr<ArrayData> indices,
+ GetTakeIndices(*filter.array(), filter_opts.null_selection_behavior,
+ ctx->memory_pool()));
+ std::vector<std::shared_ptr<Array>> columns(batch.num_columns());
+ for (int i = 0; i < batch.num_columns(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(Datum out, Take(batch.column(i)->data(), Datum(indices),
+ TakeOptions::NoBoundsCheck(), ctx));
+ columns[i] = out.make_array();
+ }
+ return RecordBatch::Make(batch.schema(), indices->length, std::move(columns));
+}
+
+Result<std::shared_ptr<Table>> FilterTable(const Table& table, const Datum& filter,
+ const FunctionOptions* options,
+ ExecContext* ctx) {
+ if (table.num_rows() != filter.length()) {
+ return Status::Invalid("Filter inputs must all be the same length");
+ }
+ if (table.num_rows() == 0) {
+ return Table::Make(table.schema(), table.columns(), 0);
+ }
+
+ // Last input element will be the filter array
+ const int num_columns = table.num_columns();
+ std::vector<ArrayVector> inputs(num_columns + 1);
+
+ // Fetch table columns
+ for (int i = 0; i < num_columns; ++i) {
+ inputs[i] = table.column(i)->chunks();
+ }
+ // Fetch filter
+ const auto& filter_opts = *static_cast<const FilterOptions*>(options);
+ switch (filter.kind()) {
+ case Datum::ARRAY:
+ inputs.back().push_back(filter.make_array());
+ break;
+ case Datum::CHUNKED_ARRAY:
+ inputs.back() = filter.chunked_array()->chunks();
+ break;
+ default:
+ return Status::NotImplemented("Filter should be array-like");
+ }
+
+ // Rechunk inputs to allow consistent iteration over their respective chunks
+ inputs = arrow::internal::RechunkArraysConsistently(inputs);
+
+ // Instead of filtering each column with the boolean filter
+ // (which would be slow if the table has a large number of columns: ARROW-10569),
+ // convert each filter chunk to indices, and take() the column.
+ const int64_t num_chunks = static_cast<int64_t>(inputs.back().size());
+ std::vector<ArrayVector> out_columns(num_columns);
+ int64_t out_num_rows = 0;
+
+ for (int64_t i = 0; i < num_chunks; ++i) {
+ const ArrayData& filter_chunk = *inputs.back()[i]->data();
+ ARROW_ASSIGN_OR_RAISE(
+ const auto indices,
+ GetTakeIndices(filter_chunk, filter_opts.null_selection_behavior,
+ ctx->memory_pool()));
+
+ if (indices->length > 0) {
+ // Take from all input columns
+ Datum indices_datum{std::move(indices)};
+ for (int col = 0; col < num_columns; ++col) {
+ const auto& column_chunk = inputs[col][i];
+ ARROW_ASSIGN_OR_RAISE(Datum out, Take(column_chunk, indices_datum,
+ TakeOptions::NoBoundsCheck(), ctx));
+ out_columns[col].push_back(std::move(out).make_array());
+ }
+ out_num_rows += indices->length;
+ }
+ }
+
+ ChunkedArrayVector out_chunks(num_columns);
+ for (int i = 0; i < num_columns; ++i) {
+ out_chunks[i] = std::make_shared<ChunkedArray>(std::move(out_columns[i]),
+ table.column(i)->type());
+ }
+ return Table::Make(table.schema(), std::move(out_chunks), out_num_rows);
+}
+
+static auto kDefaultFilterOptions = FilterOptions::Defaults();
+
+const FunctionDoc filter_doc(
+ "Filter with a boolean selection filter",
+ ("The output is populated with values from the input at positions\n"
+ "where the selection filter is non-zero. Nulls in the selection filter\n"
+ "are handled based on FilterOptions."),
+ {"input", "selection_filter"}, "FilterOptions");
+
+class FilterMetaFunction : public MetaFunction {
+ public:
+ FilterMetaFunction()
+ : MetaFunction("filter", Arity::Binary(), &filter_doc, &kDefaultFilterOptions) {}
+
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options,
+ ExecContext* ctx) const override {
+ if (args[1].type()->id() != Type::BOOL) {
+ return Status::NotImplemented("Filter argument must be boolean type");
+ }
+
+ if (args[0].kind() == Datum::RECORD_BATCH) {
+ auto values_batch = args[0].record_batch();
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr<RecordBatch> out_batch,
+ FilterRecordBatch(*args[0].record_batch(), args[1], options, ctx));
+ return Datum(out_batch);
+ } else if (args[0].kind() == Datum::TABLE) {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Table> out_table,
+ FilterTable(*args[0].table(), args[1], options, ctx));
+ return Datum(out_table);
+ } else {
+ return CallFunction("array_filter", args, options, ctx);
+ }
+ }
+};
+
+// ----------------------------------------------------------------------
+// Implement Take metafunction
+
+// Shorthand naming of these functions
+// A -> Array
+// C -> ChunkedArray
+// R -> RecordBatch
+// T -> Table
+
+Result<std::shared_ptr<Array>> TakeAA(const Array& values, const Array& indices,
+ const TakeOptions& options, ExecContext* ctx) {
+ ARROW_ASSIGN_OR_RAISE(Datum result,
+ CallFunction("array_take", {values, indices}, &options, ctx));
+ return result.make_array();
+}
+
+Result<std::shared_ptr<ChunkedArray>> TakeCA(const ChunkedArray& values,
+ const Array& indices,
+ const TakeOptions& options,
+ ExecContext* ctx) {
+ auto num_chunks = values.num_chunks();
+ std::vector<std::shared_ptr<Array>> new_chunks(1); // Hard-coded 1 for now
+ std::shared_ptr<Array> current_chunk;
+
+ // Case 1: `values` has a single chunk, so just use it
+ if (num_chunks == 1) {
+ current_chunk = values.chunk(0);
+ } else {
+ // TODO Case 2: See if all `indices` fall in the same chunk and call Array Take on it
+ // See
+ // https://github.com/apache/arrow/blob/6f2c9041137001f7a9212f244b51bc004efc29af/r/src/compute.cpp#L123-L151
+ // TODO Case 3: If indices are sorted, can slice them and call Array Take
+
+ // Case 4: Else, concatenate chunks and call Array Take
+ if (values.chunks().empty()) {
+ ARROW_ASSIGN_OR_RAISE(current_chunk, MakeArrayOfNull(values.type(), /*length=*/0,
+ ctx->memory_pool()));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(current_chunk,
+ Concatenate(values.chunks(), ctx->memory_pool()));
+ }
+ }
+ // Call Array Take on our single chunk
+ ARROW_ASSIGN_OR_RAISE(new_chunks[0], TakeAA(*current_chunk, indices, options, ctx));
+ return std::make_shared<ChunkedArray>(std::move(new_chunks));
+}
+
+Result<std::shared_ptr<ChunkedArray>> TakeCC(const ChunkedArray& values,
+ const ChunkedArray& indices,
+ const TakeOptions& options,
+ ExecContext* ctx) {
+ auto num_chunks = indices.num_chunks();
+ std::vector<std::shared_ptr<Array>> new_chunks(num_chunks);
+ for (int i = 0; i < num_chunks; i++) {
+ // Take with that indices chunk
+ // Note that as currently implemented, this is inefficient because `values`
+ // will get concatenated on every iteration of this loop
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ChunkedArray> current_chunk,
+ TakeCA(values, *indices.chunk(i), options, ctx));
+ // Concatenate the result to make a single array for this chunk
+ ARROW_ASSIGN_OR_RAISE(new_chunks[i],
+ Concatenate(current_chunk->chunks(), ctx->memory_pool()));
+ }
+ return std::make_shared<ChunkedArray>(std::move(new_chunks), values.type());
+}
+
+Result<std::shared_ptr<ChunkedArray>> TakeAC(const Array& values,
+ const ChunkedArray& indices,
+ const TakeOptions& options,
+ ExecContext* ctx) {
+ auto num_chunks = indices.num_chunks();
+ std::vector<std::shared_ptr<Array>> new_chunks(num_chunks);
+ for (int i = 0; i < num_chunks; i++) {
+ // Take with that indices chunk
+ ARROW_ASSIGN_OR_RAISE(new_chunks[i], TakeAA(values, *indices.chunk(i), options, ctx));
+ }
+ return std::make_shared<ChunkedArray>(std::move(new_chunks), values.type());
+}
+
+Result<std::shared_ptr<RecordBatch>> TakeRA(const RecordBatch& batch,
+ const Array& indices,
+ const TakeOptions& options,
+ ExecContext* ctx) {
+ auto ncols = batch.num_columns();
+ auto nrows = indices.length();
+ std::vector<std::shared_ptr<Array>> columns(ncols);
+ for (int j = 0; j < ncols; j++) {
+ ARROW_ASSIGN_OR_RAISE(columns[j], TakeAA(*batch.column(j), indices, options, ctx));
+ }
+ return RecordBatch::Make(batch.schema(), nrows, std::move(columns));
+}
+
+Result<std::shared_ptr<Table>> TakeTA(const Table& table, const Array& indices,
+ const TakeOptions& options, ExecContext* ctx) {
+ auto ncols = table.num_columns();
+ std::vector<std::shared_ptr<ChunkedArray>> columns(ncols);
+
+ for (int j = 0; j < ncols; j++) {
+ ARROW_ASSIGN_OR_RAISE(columns[j], TakeCA(*table.column(j), indices, options, ctx));
+ }
+ return Table::Make(table.schema(), std::move(columns));
+}
+
+Result<std::shared_ptr<Table>> TakeTC(const Table& table, const ChunkedArray& indices,
+ const TakeOptions& options, ExecContext* ctx) {
+ auto ncols = table.num_columns();
+ std::vector<std::shared_ptr<ChunkedArray>> columns(ncols);
+ for (int j = 0; j < ncols; j++) {
+ ARROW_ASSIGN_OR_RAISE(columns[j], TakeCC(*table.column(j), indices, options, ctx));
+ }
+ return Table::Make(table.schema(), std::move(columns));
+}
+
+static auto kDefaultTakeOptions = TakeOptions::Defaults();
+
+const FunctionDoc take_doc(
+ "Select values from an input based on indices from another array",
+ ("The output is populated with values from the input at positions\n"
+ "given by `indices`. Nulls in `indices` emit null in the output."),
+ {"input", "indices"}, "TakeOptions");
+
+// Metafunction for dispatching to different Take implementations other than
+// Array-Array.
+//
+// TODO: Revamp approach to executing Take operations. In addition to being
+// overly complex dispatching, there is no parallelization.
+class TakeMetaFunction : public MetaFunction {
+ public:
+ TakeMetaFunction()
+ : MetaFunction("take", Arity::Binary(), &take_doc, &kDefaultTakeOptions) {}
+
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options,
+ ExecContext* ctx) const override {
+ Datum::Kind index_kind = args[1].kind();
+ const TakeOptions& take_opts = static_cast<const TakeOptions&>(*options);
+ switch (args[0].kind()) {
+ case Datum::ARRAY:
+ if (index_kind == Datum::ARRAY) {
+ return TakeAA(*args[0].make_array(), *args[1].make_array(), take_opts, ctx);
+ } else if (index_kind == Datum::CHUNKED_ARRAY) {
+ return TakeAC(*args[0].make_array(), *args[1].chunked_array(), take_opts, ctx);
+ }
+ break;
+ case Datum::CHUNKED_ARRAY:
+ if (index_kind == Datum::ARRAY) {
+ return TakeCA(*args[0].chunked_array(), *args[1].make_array(), take_opts, ctx);
+ } else if (index_kind == Datum::CHUNKED_ARRAY) {
+ return TakeCC(*args[0].chunked_array(), *args[1].chunked_array(), take_opts,
+ ctx);
+ }
+ break;
+ case Datum::RECORD_BATCH:
+ if (index_kind == Datum::ARRAY) {
+ return TakeRA(*args[0].record_batch(), *args[1].make_array(), take_opts, ctx);
+ }
+ break;
+ case Datum::TABLE:
+ if (index_kind == Datum::ARRAY) {
+ return TakeTA(*args[0].table(), *args[1].make_array(), take_opts, ctx);
+ } else if (index_kind == Datum::CHUNKED_ARRAY) {
+ return TakeTC(*args[0].table(), *args[1].chunked_array(), take_opts, ctx);
+ }
+ break;
+ default:
+ break;
+ }
+ return Status::NotImplemented(
+ "Unsupported types for take operation: "
+ "values=",
+ args[0].ToString(), "indices=", args[1].ToString());
+ }
+};
+
+// ----------------------------------------------------------------------
+// DropNull Implementation
+
+Result<std::shared_ptr<arrow::BooleanArray>> GetDropNullFilter(const Array& values,
+ MemoryPool* memory_pool) {
+ auto bitmap_buffer = values.null_bitmap();
+ std::shared_ptr<arrow::BooleanArray> out_array = std::make_shared<BooleanArray>(
+ values.length(), bitmap_buffer, nullptr, 0, values.offset());
+ return out_array;
+}
+
+Result<std::shared_ptr<Array>> CreateEmptyArray(std::shared_ptr<DataType> type,
+ MemoryPool* memory_pool) {
+ std::unique_ptr<ArrayBuilder> builder;
+ RETURN_NOT_OK(MakeBuilder(memory_pool, type, &builder));
+ RETURN_NOT_OK(builder->Resize(0));
+ return builder->Finish();
+}
+
+Result<std::shared_ptr<ChunkedArray>> CreateEmptyChunkedArray(
+ std::shared_ptr<DataType> type, MemoryPool* memory_pool) {
+ std::vector<std::shared_ptr<Array>> new_chunks(1); // Hard-coded 1 for now
+ ARROW_ASSIGN_OR_RAISE(new_chunks[0], CreateEmptyArray(type, memory_pool));
+ return std::make_shared<ChunkedArray>(std::move(new_chunks));
+}
+
+Result<Datum> DropNullArray(const std::shared_ptr<Array>& values, ExecContext* ctx) {
+ if (values->null_count() == 0) {
+ return values;
+ }
+ if (values->null_count() == values->length()) {
+ return CreateEmptyArray(values->type(), ctx->memory_pool());
+ }
+ if (values->type()->id() == Type::type::NA) {
+ return std::make_shared<NullArray>(0);
+ }
+ ARROW_ASSIGN_OR_RAISE(auto drop_null_filter,
+ GetDropNullFilter(*values, ctx->memory_pool()));
+ return Filter(values, drop_null_filter, FilterOptions::Defaults(), ctx);
+}
+
+Result<Datum> DropNullChunkedArray(const std::shared_ptr<ChunkedArray>& values,
+ ExecContext* ctx) {
+ if (values->null_count() == 0) {
+ return values;
+ }
+ if (values->null_count() == values->length()) {
+ return CreateEmptyChunkedArray(values->type(), ctx->memory_pool());
+ }
+ std::vector<std::shared_ptr<Array>> new_chunks;
+ for (const auto& chunk : values->chunks()) {
+ ARROW_ASSIGN_OR_RAISE(auto new_chunk, DropNullArray(chunk, ctx));
+ if (new_chunk.length() > 0) {
+ new_chunks.push_back(new_chunk.make_array());
+ }
+ }
+ return std::make_shared<ChunkedArray>(std::move(new_chunks));
+}
+
+Result<Datum> DropNullRecordBatch(const std::shared_ptr<RecordBatch>& batch,
+ ExecContext* ctx) {
+ // Compute an upper bound of the final null count
+ int64_t null_count = 0;
+ for (const auto& column : batch->columns()) {
+ null_count += column->null_count();
+ }
+ if (null_count == 0) {
+ return batch;
+ }
+ ARROW_ASSIGN_OR_RAISE(auto dst,
+ AllocateEmptyBitmap(batch->num_rows(), ctx->memory_pool()));
+ BitUtil::SetBitsTo(dst->mutable_data(), 0, batch->num_rows(), true);
+ for (const auto& column : batch->columns()) {
+ if (column->type()->id() == Type::type::NA) {
+ BitUtil::SetBitsTo(dst->mutable_data(), 0, batch->num_rows(), false);
+ break;
+ }
+ if (column->null_bitmap_data()) {
+ ::arrow::internal::BitmapAnd(column->null_bitmap_data(), column->offset(),
+ dst->data(), 0, column->length(), 0,
+ dst->mutable_data());
+ }
+ }
+ auto drop_null_filter = std::make_shared<BooleanArray>(batch->num_rows(), dst);
+ if (drop_null_filter->true_count() == 0) {
+ // Shortcut: construct empty result
+ ArrayVector empty_batch(batch->num_columns());
+ for (int i = 0; i < batch->num_columns(); i++) {
+ ARROW_ASSIGN_OR_RAISE(
+ empty_batch[i], CreateEmptyArray(batch->column(i)->type(), ctx->memory_pool()));
+ }
+ return RecordBatch::Make(batch->schema(), 0, std::move(empty_batch));
+ }
+ return Filter(Datum(batch), Datum(drop_null_filter), FilterOptions::Defaults(), ctx);
+}
+
+Result<Datum> DropNullTable(const std::shared_ptr<Table>& table, ExecContext* ctx) {
+ if (table->num_rows() == 0) {
+ return table;
+ }
+ // Compute an upper bound of the final null count
+ int64_t null_count = 0;
+ for (const auto& col : table->columns()) {
+ for (const auto& column_chunk : col->chunks()) {
+ null_count += column_chunk->null_count();
+ }
+ }
+ if (null_count == 0) {
+ return table;
+ }
+
+ arrow::RecordBatchVector filtered_batches;
+ TableBatchReader batch_iter(*table);
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(auto batch, batch_iter.Next());
+ if (batch == nullptr) {
+ break;
+ }
+ ARROW_ASSIGN_OR_RAISE(auto filtered_datum, DropNullRecordBatch(batch, ctx))
+ if (filtered_datum.length() > 0) {
+ filtered_batches.push_back(filtered_datum.record_batch());
+ }
+ }
+ return Table::FromRecordBatches(table->schema(), filtered_batches);
+}
+
+const FunctionDoc drop_null_doc(
+ "Drop nulls from the input",
+ ("The output is populated with values from the input (Array, ChunkedArray,\n"
+ "RecordBatch, or Table) without the null values.\n"
+ "For the RecordBatch and Table cases, `drop_null` drops the full row if\n"
+ "there is any null."),
+ {"input"});
+
+class DropNullMetaFunction : public MetaFunction {
+ public:
+ DropNullMetaFunction() : MetaFunction("drop_null", Arity::Unary(), &drop_null_doc) {}
+
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options,
+ ExecContext* ctx) const override {
+ switch (args[0].kind()) {
+ case Datum::ARRAY: {
+ return DropNullArray(args[0].make_array(), ctx);
+ } break;
+ case Datum::CHUNKED_ARRAY: {
+ return DropNullChunkedArray(args[0].chunked_array(), ctx);
+ } break;
+ case Datum::RECORD_BATCH: {
+ return DropNullRecordBatch(args[0].record_batch(), ctx);
+ } break;
+ case Datum::TABLE: {
+ return DropNullTable(args[0].table(), ctx);
+ } break;
+ default:
+ break;
+ }
+ return Status::NotImplemented(
+ "Unsupported types for drop_null operation: "
+ "values=",
+ args[0].ToString());
+ }
+};
+
+// ----------------------------------------------------------------------
+
+template <typename Impl>
+Status FilterExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // TODO: where are the values and filter length equality checked?
+ int64_t output_length = GetFilterOutputSize(
+ *batch[1].array(), FilterState::Get(ctx).null_selection_behavior);
+ Impl kernel(ctx, batch, output_length, out);
+ return kernel.ExecFilter();
+}
+
+template <typename Impl>
+Status TakeExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (TakeState::Get(ctx).boundscheck) {
+ RETURN_NOT_OK(CheckIndexBounds(*batch[1].array(), batch[0].length()));
+ }
+ Impl kernel(ctx, batch, /*output_length=*/batch[1].length(), out);
+ return kernel.ExecTake();
+}
+
+struct SelectionKernelDescr {
+ InputType input;
+ ArrayKernelExec exec;
+};
+
+void RegisterSelectionFunction(const std::string& name, const FunctionDoc* doc,
+ VectorKernel base_kernel, InputType selection_type,
+ const std::vector<SelectionKernelDescr>& descrs,
+ const FunctionOptions* default_options,
+ FunctionRegistry* registry) {
+ auto func =
+ std::make_shared<VectorFunction>(name, Arity::Binary(), doc, default_options);
+ for (auto& descr : descrs) {
+ base_kernel.signature = KernelSignature::Make(
+ {std::move(descr.input), selection_type}, OutputType(FirstType));
+ base_kernel.exec = descr.exec;
+ DCHECK_OK(func->AddKernel(base_kernel));
+ }
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+const FunctionDoc array_filter_doc(
+ "Filter with a boolean selection filter",
+ ("The output is populated with values from the input `array` at positions\n"
+ "where the selection filter is non-zero. Nulls in the selection filter\n"
+ "are handled based on FilterOptions."),
+ {"array", "selection_filter"}, "FilterOptions");
+
+const FunctionDoc array_take_doc(
+ "Select values from an array based on indices from another array",
+ ("The output is populated with values from the input array at positions\n"
+ "given by `indices`. Nulls in `indices` emit null in the output."),
+ {"array", "indices"}, "TakeOptions");
+
+} // namespace
+
+void RegisterVectorSelection(FunctionRegistry* registry) {
+ // Filter kernels
+ std::vector<SelectionKernelDescr> filter_kernel_descrs = {
+ {InputType(match::Primitive(), ValueDescr::ARRAY), PrimitiveFilter},
+ {InputType(match::BinaryLike(), ValueDescr::ARRAY), BinaryFilter},
+ {InputType(match::LargeBinaryLike(), ValueDescr::ARRAY), BinaryFilter},
+ {InputType::Array(Type::FIXED_SIZE_BINARY), FilterExec<FSBImpl>},
+ {InputType::Array(null()), NullFilter},
+ {InputType::Array(Type::DECIMAL), FilterExec<FSBImpl>},
+ {InputType::Array(Type::DICTIONARY), DictionaryFilter},
+ {InputType::Array(Type::EXTENSION), ExtensionFilter},
+ {InputType::Array(Type::LIST), FilterExec<ListImpl<ListType>>},
+ {InputType::Array(Type::LARGE_LIST), FilterExec<ListImpl<LargeListType>>},
+ {InputType::Array(Type::FIXED_SIZE_LIST), FilterExec<FSLImpl>},
+ {InputType::Array(Type::DENSE_UNION), FilterExec<DenseUnionImpl>},
+ {InputType::Array(Type::STRUCT), StructFilter},
+ // TODO: Reuse ListType kernel for MAP
+ {InputType::Array(Type::MAP), FilterExec<ListImpl<MapType>>},
+ };
+
+ VectorKernel filter_base;
+ filter_base.init = FilterState::Init;
+ RegisterSelectionFunction("array_filter", &array_filter_doc, filter_base,
+ /*selection_type=*/InputType::Array(boolean()),
+ filter_kernel_descrs, &kDefaultFilterOptions, registry);
+
+ DCHECK_OK(registry->AddFunction(std::make_shared<FilterMetaFunction>()));
+
+ // Take kernels
+ std::vector<SelectionKernelDescr> take_kernel_descrs = {
+ {InputType(match::Primitive(), ValueDescr::ARRAY), PrimitiveTake},
+ {InputType(match::BinaryLike(), ValueDescr::ARRAY),
+ TakeExec<VarBinaryImpl<BinaryType>>},
+ {InputType(match::LargeBinaryLike(), ValueDescr::ARRAY),
+ TakeExec<VarBinaryImpl<LargeBinaryType>>},
+ {InputType::Array(Type::FIXED_SIZE_BINARY), TakeExec<FSBImpl>},
+ {InputType::Array(null()), NullTake},
+ {InputType::Array(Type::DECIMAL128), TakeExec<FSBImpl>},
+ {InputType::Array(Type::DECIMAL256), TakeExec<FSBImpl>},
+ {InputType::Array(Type::DICTIONARY), DictionaryTake},
+ {InputType::Array(Type::EXTENSION), ExtensionTake},
+ {InputType::Array(Type::LIST), TakeExec<ListImpl<ListType>>},
+ {InputType::Array(Type::LARGE_LIST), TakeExec<ListImpl<LargeListType>>},
+ {InputType::Array(Type::FIXED_SIZE_LIST), TakeExec<FSLImpl>},
+ {InputType::Array(Type::DENSE_UNION), TakeExec<DenseUnionImpl>},
+ {InputType::Array(Type::STRUCT), TakeExec<StructImpl>},
+ // TODO: Reuse ListType kernel for MAP
+ {InputType::Array(Type::MAP), TakeExec<ListImpl<MapType>>},
+ };
+
+ VectorKernel take_base;
+ take_base.init = TakeState::Init;
+ take_base.can_execute_chunkwise = false;
+ RegisterSelectionFunction(
+ "array_take", &array_take_doc, take_base,
+ /*selection_type=*/InputType(match::Integer(), ValueDescr::ARRAY),
+ take_kernel_descrs, &kDefaultTakeOptions, registry);
+
+ DCHECK_OK(registry->AddFunction(std::make_shared<TakeMetaFunction>()));
+
+ // DropNull kernel
+ DCHECK_OK(registry->AddFunction(std::make_shared<DropNullMetaFunction>()));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc
new file mode 100644
index 000000000..25e30e65a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <cstdint>
+#include <sstream>
+
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/benchmark_util.h"
+
+namespace arrow {
+namespace compute {
+
+constexpr auto kSeed = 0x0ff1ce;
+
+struct FilterParams {
+ // proportion of nulls in the values array
+ const double values_null_proportion;
+
+ // proportion of true in filter
+ const double selected_proportion;
+
+ // proportion of nulls in the filter
+ const double filter_null_proportion;
+};
+
+std::vector<int64_t> g_data_sizes = {kL2Size};
+
+// The benchmark state parameter references this vector of cases. Test high and
+// low selectivity filters.
+
+// clang-format off
+std::vector<FilterParams> g_filter_params = {
+ {0., 0.999, 0.05},
+ {0., 0.50, 0.05},
+ {0., 0.01, 0.05},
+ {0.001, 0.999, 0.05},
+ {0.001, 0.50, 0.05},
+ {0.001, 0.01, 0.05},
+ {0.01, 0.999, 0.05},
+ {0.01, 0.50, 0.05},
+ {0.01, 0.01, 0.05},
+ {0.1, 0.999, 0.05},
+ {0.1, 0.50, 0.05},
+ {0.1, 0.01, 0.05},
+ {0.9, 0.999, 0.05},
+ {0.9, 0.50, 0.05},
+ {0.9, 0.01, 0.05}
+};
+// clang-format on
+
+// RAII struct to handle some of the boilerplate in filter
+struct FilterArgs {
+ // size of memory tested (per iteration) in bytes
+ int64_t size;
+
+ // What to call the "size" that's reported in the console output, for result
+ // interpretability.
+ std::string size_name = "size";
+
+ double values_null_proportion = 0.;
+ double selected_proportion = 0.;
+ double filter_null_proportion = 0.;
+
+ FilterArgs(benchmark::State& state, bool filter_has_nulls)
+ : size(state.range(0)), state_(state) {
+ auto params = g_filter_params[state.range(1)];
+ values_null_proportion = params.values_null_proportion;
+ selected_proportion = params.selected_proportion;
+ filter_null_proportion = filter_has_nulls ? params.filter_null_proportion : 0;
+ }
+
+ ~FilterArgs() {
+ state_.counters[size_name] = static_cast<double>(size);
+ state_.counters["select%"] = selected_proportion * 100;
+ state_.counters["data null%"] = values_null_proportion * 100;
+ state_.counters["mask null%"] = filter_null_proportion * 100;
+ state_.SetBytesProcessed(state_.iterations() * size);
+ }
+
+ private:
+ benchmark::State& state_;
+};
+
+struct TakeBenchmark {
+ benchmark::State& state;
+ RegressionArgs args;
+ random::RandomArrayGenerator rand;
+ bool indices_have_nulls;
+ bool monotonic_indices = false;
+
+ TakeBenchmark(benchmark::State& state, bool indices_have_nulls,
+ bool monotonic_indices = false)
+ : state(state),
+ args(state, /*size_is_bytes=*/false),
+ rand(kSeed),
+ indices_have_nulls(indices_have_nulls),
+ monotonic_indices(monotonic_indices) {}
+
+ void Int64() {
+ auto values = rand.Int64(args.size, -100, 100, args.null_proportion);
+ Bench(values);
+ }
+
+ void FSLInt64() {
+ auto int_array = rand.Int64(args.size, -100, 100, args.null_proportion);
+ auto values = std::make_shared<FixedSizeListArray>(
+ fixed_size_list(int64(), 1), args.size, int_array, int_array->null_bitmap(),
+ int_array->null_count());
+ Bench(values);
+ }
+
+ void String() {
+ int32_t string_min_length = 0, string_max_length = 32;
+ auto values = std::static_pointer_cast<StringArray>(rand.String(
+ args.size, string_min_length, string_max_length, args.null_proportion));
+ Bench(values);
+ }
+
+ void Bench(const std::shared_ptr<Array>& values) {
+ double indices_null_proportion = indices_have_nulls ? args.null_proportion : 0;
+ auto indices =
+ rand.Int32(values->length(), 0, static_cast<int32_t>(values->length() - 1),
+ indices_null_proportion);
+
+ if (monotonic_indices) {
+ auto arg_sorter = *SortIndices(*indices);
+ indices = *Take(*indices, *arg_sorter);
+ }
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(Take(values, indices).status());
+ }
+ }
+};
+
+struct FilterBenchmark {
+ benchmark::State& state;
+ FilterArgs args;
+ random::RandomArrayGenerator rand;
+ bool filter_has_nulls;
+
+ FilterBenchmark(benchmark::State& state, bool filter_has_nulls)
+ : state(state),
+ args(state, filter_has_nulls),
+ rand(kSeed),
+ filter_has_nulls(filter_has_nulls) {}
+
+ void Int64() {
+ const int64_t array_size = args.size / sizeof(int64_t);
+ auto values = std::static_pointer_cast<NumericArray<Int64Type>>(
+ rand.Int64(array_size, -100, 100, args.values_null_proportion));
+ Bench(values);
+ }
+
+ void FSLInt64() {
+ const int64_t array_size = args.size / sizeof(int64_t);
+ auto int_array = std::static_pointer_cast<NumericArray<Int64Type>>(
+ rand.Int64(array_size, -100, 100, args.values_null_proportion));
+ auto values = std::make_shared<FixedSizeListArray>(
+ fixed_size_list(int64(), 1), array_size, int_array, int_array->null_bitmap(),
+ int_array->null_count());
+ Bench(values);
+ }
+
+ void String() {
+ int32_t string_min_length = 0, string_max_length = 32;
+ int32_t string_mean_length = (string_max_length + string_min_length) / 2;
+ // for an array of 50% null strings, we need to generate twice as many strings
+ // to ensure that they have an average of args.size total characters
+ int64_t array_size = args.size;
+ if (args.values_null_proportion < 1) {
+ array_size = static_cast<int64_t>(args.size / string_mean_length /
+ (1 - args.values_null_proportion));
+ }
+ auto values = std::static_pointer_cast<StringArray>(rand.String(
+ array_size, string_min_length, string_max_length, args.values_null_proportion));
+ Bench(values);
+ }
+
+ void Bench(const std::shared_ptr<Array>& values) {
+ auto filter = rand.Boolean(values->length(), args.selected_proportion,
+ args.filter_null_proportion);
+ for (auto _ : state) {
+ ABORT_NOT_OK(Filter(values, filter).status());
+ }
+ }
+
+ void BenchRecordBatch() {
+ const int64_t total_data_cells = 10000000;
+ const int64_t num_columns = state.range(0);
+ const int64_t num_rows = total_data_cells / num_columns;
+
+ auto col_data = rand.Float64(num_rows, 0, 1);
+
+ auto filter =
+ rand.Boolean(num_rows, args.selected_proportion, args.filter_null_proportion);
+
+ int64_t output_length =
+ internal::GetFilterOutputSize(*filter->data(), FilterOptions::DROP);
+
+ // HACK: set FilterArgs.size to the number of selected data cells *
+ // sizeof(double) for accurate memory processing performance
+ args.size = output_length * num_columns * sizeof(double);
+ args.size_name = "extracted_size";
+ state.counters["num_cols"] = static_cast<double>(num_columns);
+
+ std::vector<std::shared_ptr<Array>> columns;
+ std::vector<std::shared_ptr<Field>> fields;
+ for (int64_t i = 0; i < num_columns; ++i) {
+ std::stringstream ss;
+ ss << "f" << i;
+ fields.push_back(::arrow::field(ss.str(), float64()));
+ columns.push_back(col_data);
+ }
+
+ auto batch = RecordBatch::Make(schema(fields), num_rows, columns);
+ for (auto _ : state) {
+ ABORT_NOT_OK(Filter(batch, filter).status());
+ }
+ }
+};
+
+static void FilterInt64FilterNoNulls(benchmark::State& state) {
+ FilterBenchmark(state, false).Int64();
+}
+
+static void FilterInt64FilterWithNulls(benchmark::State& state) {
+ FilterBenchmark(state, true).Int64();
+}
+
+static void FilterFSLInt64FilterNoNulls(benchmark::State& state) {
+ FilterBenchmark(state, false).FSLInt64();
+}
+
+static void FilterFSLInt64FilterWithNulls(benchmark::State& state) {
+ FilterBenchmark(state, true).FSLInt64();
+}
+
+static void FilterStringFilterNoNulls(benchmark::State& state) {
+ FilterBenchmark(state, false).String();
+}
+
+static void FilterStringFilterWithNulls(benchmark::State& state) {
+ FilterBenchmark(state, true).String();
+}
+
+static void FilterRecordBatchNoNulls(benchmark::State& state) {
+ FilterBenchmark(state, false).BenchRecordBatch();
+}
+
+static void FilterRecordBatchWithNulls(benchmark::State& state) {
+ FilterBenchmark(state, true).BenchRecordBatch();
+}
+
+static void TakeInt64RandomIndicesNoNulls(benchmark::State& state) {
+ TakeBenchmark(state, false).Int64();
+}
+
+static void TakeInt64RandomIndicesWithNulls(benchmark::State& state) {
+ TakeBenchmark(state, true).Int64();
+}
+
+static void TakeInt64MonotonicIndices(benchmark::State& state) {
+ TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true).Int64();
+}
+
+static void TakeFSLInt64RandomIndicesNoNulls(benchmark::State& state) {
+ TakeBenchmark(state, false).FSLInt64();
+}
+
+static void TakeFSLInt64RandomIndicesWithNulls(benchmark::State& state) {
+ TakeBenchmark(state, true).FSLInt64();
+}
+
+static void TakeFSLInt64MonotonicIndices(benchmark::State& state) {
+ TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true).FSLInt64();
+}
+
+static void TakeStringRandomIndicesNoNulls(benchmark::State& state) {
+ TakeBenchmark(state, false).String();
+}
+
+static void TakeStringRandomIndicesWithNulls(benchmark::State& state) {
+ TakeBenchmark(state, true).String();
+}
+
+static void TakeStringMonotonicIndices(benchmark::State& state) {
+ TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true).FSLInt64();
+}
+
+void FilterSetArgs(benchmark::internal::Benchmark* bench) {
+ for (int64_t size : g_data_sizes) {
+ for (int i = 0; i < static_cast<int>(g_filter_params.size()); ++i) {
+ bench->Args({static_cast<ArgsType>(size), i});
+ }
+ }
+}
+
+BENCHMARK(FilterInt64FilterNoNulls)->Apply(FilterSetArgs);
+BENCHMARK(FilterInt64FilterWithNulls)->Apply(FilterSetArgs);
+BENCHMARK(FilterFSLInt64FilterNoNulls)->Apply(FilterSetArgs);
+BENCHMARK(FilterFSLInt64FilterWithNulls)->Apply(FilterSetArgs);
+BENCHMARK(FilterStringFilterNoNulls)->Apply(FilterSetArgs);
+BENCHMARK(FilterStringFilterWithNulls)->Apply(FilterSetArgs);
+
+void FilterRecordBatchSetArgs(benchmark::internal::Benchmark* bench) {
+ for (auto num_cols : std::vector<int>({10, 50, 100})) {
+ for (int i = 0; i < static_cast<int>(g_filter_params.size()); ++i) {
+ bench->Args({num_cols, i});
+ }
+ }
+}
+BENCHMARK(FilterRecordBatchNoNulls)->Apply(FilterRecordBatchSetArgs);
+BENCHMARK(FilterRecordBatchWithNulls)->Apply(FilterRecordBatchSetArgs);
+
+void TakeSetArgs(benchmark::internal::Benchmark* bench) {
+ for (int64_t size : g_data_sizes) {
+ for (auto nulls : std::vector<ArgsType>({1000, 10, 2, 1, 0})) {
+ bench->Args({static_cast<ArgsType>(size), nulls});
+ }
+ }
+}
+
+BENCHMARK(TakeInt64RandomIndicesNoNulls)->Apply(TakeSetArgs);
+BENCHMARK(TakeInt64RandomIndicesWithNulls)->Apply(TakeSetArgs);
+BENCHMARK(TakeInt64MonotonicIndices)->Apply(TakeSetArgs);
+BENCHMARK(TakeFSLInt64RandomIndicesNoNulls)->Apply(TakeSetArgs);
+BENCHMARK(TakeFSLInt64RandomIndicesWithNulls)->Apply(TakeSetArgs);
+BENCHMARK(TakeFSLInt64MonotonicIndices)->Apply(TakeSetArgs);
+BENCHMARK(TakeStringRandomIndicesNoNulls)->Apply(TakeSetArgs);
+BENCHMARK(TakeStringRandomIndicesWithNulls)->Apply(TakeSetArgs);
+BENCHMARK(TakeStringMonotonicIndices)->Apply(TakeSetArgs);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_selection_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_selection_test.cc
new file mode 100644
index 000000000..959de6035
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_selection_test.cc
@@ -0,0 +1,2332 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/concatenate.h"
+#include "arrow/chunked_array.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+using util::string_view;
+
+namespace compute {
+
+// ----------------------------------------------------------------------
+
+TEST(GetTakeIndices, Basics) {
+ auto CheckCase = [&](const std::string& filter_json, const std::string& indices_json,
+ FilterOptions::NullSelectionBehavior null_selection,
+ const std::shared_ptr<DataType>& indices_type = uint16()) {
+ auto filter = ArrayFromJSON(boolean(), filter_json);
+ auto expected_indices = ArrayFromJSON(indices_type, indices_json);
+ ASSERT_OK_AND_ASSIGN(auto indices,
+ internal::GetTakeIndices(*filter->data(), null_selection));
+ auto indices_array = MakeArray(indices);
+ ValidateOutput(indices);
+ AssertArraysEqual(*expected_indices, *indices_array, /*verbose=*/true);
+ };
+
+ // Drop null cases
+ CheckCase("[]", "[]", FilterOptions::DROP);
+ CheckCase("[null]", "[]", FilterOptions::DROP);
+ CheckCase("[null, false, true, true, false, true]", "[2, 3, 5]", FilterOptions::DROP);
+
+ // Emit null cases
+ CheckCase("[]", "[]", FilterOptions::EMIT_NULL);
+ CheckCase("[null]", "[null]", FilterOptions::EMIT_NULL);
+ CheckCase("[null, false, true, true]", "[null, 2, 3]", FilterOptions::EMIT_NULL);
+}
+
+TEST(GetTakeIndices, NullValidityBuffer) {
+ BooleanArray filter(1, *AllocateEmptyBitmap(1), /*null_bitmap=*/nullptr);
+ auto expected_indices = ArrayFromJSON(uint16(), "[]");
+
+ ASSERT_OK_AND_ASSIGN(auto indices,
+ internal::GetTakeIndices(*filter.data(), FilterOptions::DROP));
+ auto indices_array = MakeArray(indices);
+ ValidateOutput(indices);
+ AssertArraysEqual(*expected_indices, *indices_array, /*verbose=*/true);
+
+ ASSERT_OK_AND_ASSIGN(
+ indices, internal::GetTakeIndices(*filter.data(), FilterOptions::EMIT_NULL));
+ indices_array = MakeArray(indices);
+ ValidateOutput(indices);
+ AssertArraysEqual(*expected_indices, *indices_array, /*verbose=*/true);
+}
+
+template <typename IndexArrayType>
+void CheckGetTakeIndicesCase(const Array& untyped_filter) {
+ const auto& filter = checked_cast<const BooleanArray&>(untyped_filter);
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<ArrayData> drop_indices,
+ internal::GetTakeIndices(*filter.data(), FilterOptions::DROP));
+ // Verify DROP indices
+ {
+ IndexArrayType indices(drop_indices);
+ ValidateOutput(indices);
+
+ int64_t out_position = 0;
+ for (int64_t i = 0; i < filter.length(); ++i) {
+ if (filter.IsValid(i)) {
+ if (filter.Value(i)) {
+ ASSERT_EQ(indices.Value(out_position), i);
+ ++out_position;
+ }
+ }
+ }
+ ASSERT_EQ(out_position, indices.length());
+ // Check that the end length agrees with the output of GetFilterOutputSize
+ ASSERT_EQ(out_position,
+ internal::GetFilterOutputSize(*filter.data(), FilterOptions::DROP));
+ }
+
+ ASSERT_OK_AND_ASSIGN(
+ std::shared_ptr<ArrayData> emit_indices,
+ internal::GetTakeIndices(*filter.data(), FilterOptions::EMIT_NULL));
+ // Verify EMIT_NULL indices
+ {
+ IndexArrayType indices(emit_indices);
+ ValidateOutput(indices);
+
+ int64_t out_position = 0;
+ for (int64_t i = 0; i < filter.length(); ++i) {
+ if (filter.IsValid(i)) {
+ if (filter.Value(i)) {
+ ASSERT_EQ(indices.Value(out_position), i);
+ ++out_position;
+ }
+ } else {
+ ASSERT_TRUE(indices.IsNull(out_position));
+ ++out_position;
+ }
+ }
+
+ ASSERT_EQ(out_position, indices.length());
+ // Check that the end length agrees with the output of GetFilterOutputSize
+ ASSERT_EQ(out_position,
+ internal::GetFilterOutputSize(*filter.data(), FilterOptions::EMIT_NULL));
+ }
+}
+
+TEST(GetTakeIndices, RandomlyGenerated) {
+ random::RandomArrayGenerator rng(kRandomSeed);
+
+ // Multiple of word size + 1
+ const int64_t length = 6401;
+ for (auto null_prob : {0.0, 0.01, 0.999, 1.0}) {
+ for (auto true_prob : {0.0, 0.01, 0.999, 1.0}) {
+ auto filter = rng.Boolean(length, true_prob, null_prob);
+ CheckGetTakeIndicesCase<UInt16Array>(*filter);
+ CheckGetTakeIndicesCase<UInt16Array>(*filter->Slice(7));
+ }
+ }
+
+ // Check that the uint32 path is traveled successfully
+ const int64_t uint16_max = std::numeric_limits<uint16_t>::max();
+ auto filter =
+ std::static_pointer_cast<BooleanArray>(rng.Boolean(uint16_max + 1, 0.99, 0.01));
+ CheckGetTakeIndicesCase<UInt16Array>(*filter->Slice(1));
+ CheckGetTakeIndicesCase<UInt32Array>(*filter);
+}
+
+// ----------------------------------------------------------------------
+// Filter tests
+
+std::shared_ptr<Array> CoalesceNullToFalse(std::shared_ptr<Array> filter) {
+ if (filter->null_count() == 0) {
+ return filter;
+ }
+ const auto& data = *filter->data();
+ auto is_true = std::make_shared<BooleanArray>(data.length, data.buffers[1], nullptr, 0,
+ data.offset);
+ auto is_valid = std::make_shared<BooleanArray>(data.length, data.buffers[0], nullptr, 0,
+ data.offset);
+ EXPECT_OK_AND_ASSIGN(Datum out_datum, And(is_true, is_valid));
+ return out_datum.make_array();
+}
+
+class TestFilterKernel : public ::testing::Test {
+ protected:
+ TestFilterKernel() : emit_null_(FilterOptions::EMIT_NULL), drop_(FilterOptions::DROP) {}
+
+ void DoAssertFilter(const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Array>& filter,
+ const std::shared_ptr<Array>& expected) {
+ // test with EMIT_NULL
+ {
+ ARROW_SCOPED_TRACE("with EMIT_NULL");
+ ASSERT_OK_AND_ASSIGN(Datum out_datum, Filter(values, filter, emit_null_));
+ auto actual = out_datum.make_array();
+ ValidateOutput(*actual);
+ AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+ }
+
+ // test with DROP using EMIT_NULL and a coalesced filter
+ {
+ ARROW_SCOPED_TRACE("with DROP");
+ auto coalesced_filter = CoalesceNullToFalse(filter);
+ ASSERT_OK_AND_ASSIGN(Datum out_datum, Filter(values, coalesced_filter, emit_null_));
+ auto expected_for_drop = out_datum.make_array();
+ ASSERT_OK_AND_ASSIGN(out_datum, Filter(values, filter, drop_));
+ auto actual = out_datum.make_array();
+ ValidateOutput(*actual);
+ AssertArraysEqual(*expected_for_drop, *actual, /*verbose=*/true);
+ }
+ }
+
+ void AssertFilter(const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Array>& filter,
+ const std::shared_ptr<Array>& expected) {
+ DoAssertFilter(values, filter, expected);
+
+ if (values->type_id() == Type::DENSE_UNION) {
+ // Concatenation of dense union not supported
+ return;
+ }
+
+ // Check slicing: add M(=3) dummy values at the start and end of `values`,
+ // add N(=2) dummy values at the start and end of `filter`.
+ ARROW_SCOPED_TRACE("for sliced values and filter");
+ ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(values->type(), 3));
+ auto filter_filler = ArrayFromJSON(boolean(), "[true, false]");
+ ASSERT_OK_AND_ASSIGN(auto values_sliced,
+ Concatenate({values_filler, values, values_filler}));
+ ASSERT_OK_AND_ASSIGN(auto filter_sliced,
+ Concatenate({filter_filler, filter, filter_filler}));
+ values_sliced = values_sliced->Slice(3, values->length());
+ filter_sliced = filter_sliced->Slice(2, filter->length());
+ DoAssertFilter(values_sliced, filter_sliced, expected);
+ }
+
+ void AssertFilter(const std::shared_ptr<DataType>& type, const std::string& values,
+ const std::string& filter, const std::string& expected) {
+ AssertFilter(ArrayFromJSON(type, values), ArrayFromJSON(boolean(), filter),
+ ArrayFromJSON(type, expected));
+ }
+
+ FilterOptions emit_null_, drop_;
+};
+
+void ValidateFilter(const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Array>& filter_boxed) {
+ FilterOptions emit_null(FilterOptions::EMIT_NULL);
+ FilterOptions drop(FilterOptions::DROP);
+
+ ASSERT_OK_AND_ASSIGN(Datum out_datum, Filter(values, filter_boxed, emit_null));
+ auto filtered_emit_null = out_datum.make_array();
+ ValidateOutput(*filtered_emit_null);
+
+ ASSERT_OK_AND_ASSIGN(out_datum, Filter(values, filter_boxed, drop));
+ auto filtered_drop = out_datum.make_array();
+ ValidateOutput(*filtered_drop);
+
+ // Create the expected arrays using Take
+ ASSERT_OK_AND_ASSIGN(
+ std::shared_ptr<ArrayData> drop_indices,
+ internal::GetTakeIndices(*filter_boxed->data(), FilterOptions::DROP));
+ ASSERT_OK_AND_ASSIGN(Datum expected_drop, Take(values, Datum(drop_indices)));
+
+ ASSERT_OK_AND_ASSIGN(
+ std::shared_ptr<ArrayData> emit_null_indices,
+ internal::GetTakeIndices(*filter_boxed->data(), FilterOptions::EMIT_NULL));
+ ASSERT_OK_AND_ASSIGN(Datum expected_emit_null, Take(values, Datum(emit_null_indices)));
+
+ AssertArraysEqual(*expected_drop.make_array(), *filtered_drop,
+ /*verbose=*/true);
+ AssertArraysEqual(*expected_emit_null.make_array(), *filtered_emit_null,
+ /*verbose=*/true);
+}
+
+class TestFilterKernelWithNull : public TestFilterKernel {
+ protected:
+ void AssertFilter(const std::string& values, const std::string& filter,
+ const std::string& expected) {
+ TestFilterKernel::AssertFilter(ArrayFromJSON(null(), values),
+ ArrayFromJSON(boolean(), filter),
+ ArrayFromJSON(null(), expected));
+ }
+};
+
+TEST_F(TestFilterKernelWithNull, FilterNull) {
+ this->AssertFilter("[]", "[]", "[]");
+
+ this->AssertFilter("[null, null, null]", "[0, 1, 0]", "[null]");
+ this->AssertFilter("[null, null, null]", "[1, 1, 0]", "[null, null]");
+}
+
+class TestFilterKernelWithBoolean : public TestFilterKernel {
+ protected:
+ void AssertFilter(const std::string& values, const std::string& filter,
+ const std::string& expected) {
+ TestFilterKernel::AssertFilter(ArrayFromJSON(boolean(), values),
+ ArrayFromJSON(boolean(), filter),
+ ArrayFromJSON(boolean(), expected));
+ }
+};
+
+TEST_F(TestFilterKernelWithBoolean, FilterBoolean) {
+ this->AssertFilter("[]", "[]", "[]");
+
+ this->AssertFilter("[true, false, true]", "[0, 1, 0]", "[false]");
+ this->AssertFilter("[null, false, true]", "[0, 1, 0]", "[false]");
+ this->AssertFilter("[true, false, true]", "[null, 1, 0]", "[null, false]");
+}
+
+TEST_F(TestFilterKernelWithBoolean, DefaultOptions) {
+ auto values = ArrayFromJSON(int8(), "[7, 8, null, 9]");
+ auto filter = ArrayFromJSON(boolean(), "[1, 1, 0, null]");
+
+ ASSERT_OK_AND_ASSIGN(auto no_options_provided,
+ CallFunction("filter", {values, filter}));
+
+ auto default_options = FilterOptions::Defaults();
+ ASSERT_OK_AND_ASSIGN(auto explicit_defaults,
+ CallFunction("filter", {values, filter}, &default_options));
+
+ AssertDatumsEqual(explicit_defaults, no_options_provided);
+}
+
+template <typename ArrowType>
+class TestFilterKernelWithNumeric : public TestFilterKernel {
+ protected:
+ std::shared_ptr<DataType> type_singleton() {
+ return TypeTraits<ArrowType>::type_singleton();
+ }
+};
+
+TYPED_TEST_SUITE(TestFilterKernelWithNumeric, NumericArrowTypes);
+TYPED_TEST(TestFilterKernelWithNumeric, FilterNumeric) {
+ auto type = this->type_singleton();
+ this->AssertFilter(type, "[]", "[]", "[]");
+
+ this->AssertFilter(type, "[9]", "[0]", "[]");
+ this->AssertFilter(type, "[9]", "[1]", "[9]");
+ this->AssertFilter(type, "[9]", "[null]", "[null]");
+ this->AssertFilter(type, "[null]", "[0]", "[]");
+ this->AssertFilter(type, "[null]", "[1]", "[null]");
+ this->AssertFilter(type, "[null]", "[null]", "[null]");
+
+ this->AssertFilter(type, "[7, 8, 9]", "[0, 1, 0]", "[8]");
+ this->AssertFilter(type, "[7, 8, 9]", "[1, 0, 1]", "[7, 9]");
+ this->AssertFilter(type, "[null, 8, 9]", "[0, 1, 0]", "[8]");
+ this->AssertFilter(type, "[7, 8, 9]", "[null, 1, 0]", "[null, 8]");
+ this->AssertFilter(type, "[7, 8, 9]", "[1, null, 1]", "[7, null, 9]");
+
+ this->AssertFilter(ArrayFromJSON(type, "[7, 8, 9]"),
+ ArrayFromJSON(boolean(), "[0, 1, 1, 1, 0, 1]")->Slice(3, 3),
+ ArrayFromJSON(type, "[7, 9]"));
+
+ ASSERT_RAISES(Invalid, Filter(ArrayFromJSON(type, "[7, 8, 9]"),
+ ArrayFromJSON(boolean(), "[]"), this->emit_null_));
+ ASSERT_RAISES(Invalid, Filter(ArrayFromJSON(type, "[7, 8, 9]"),
+ ArrayFromJSON(boolean(), "[]"), this->drop_));
+}
+
+template <typename CType>
+using Comparator = bool(CType, CType);
+
+template <typename CType>
+Comparator<CType>* GetComparator(CompareOperator op) {
+ static Comparator<CType>* cmp[] = {
+ // EQUAL
+ [](CType l, CType r) { return l == r; },
+ // NOT_EQUAL
+ [](CType l, CType r) { return l != r; },
+ // GREATER
+ [](CType l, CType r) { return l > r; },
+ // GREATER_EQUAL
+ [](CType l, CType r) { return l >= r; },
+ // LESS
+ [](CType l, CType r) { return l < r; },
+ // LESS_EQUAL
+ [](CType l, CType r) { return l <= r; },
+ };
+ return cmp[op];
+}
+
+template <typename T, typename Fn, typename CType = typename TypeTraits<T>::CType>
+std::shared_ptr<Array> CompareAndFilter(const CType* data, int64_t length, Fn&& fn) {
+ std::vector<CType> filtered;
+ filtered.reserve(length);
+ std::copy_if(data, data + length, std::back_inserter(filtered), std::forward<Fn>(fn));
+ std::shared_ptr<Array> filtered_array;
+ ArrayFromVector<T, CType>(filtered, &filtered_array);
+ return filtered_array;
+}
+
+template <typename T, typename CType = typename TypeTraits<T>::CType>
+std::shared_ptr<Array> CompareAndFilter(const CType* data, int64_t length, CType val,
+ CompareOperator op) {
+ auto cmp = GetComparator<CType>(op);
+ return CompareAndFilter<T>(data, length, [&](CType e) { return cmp(e, val); });
+}
+
+template <typename T, typename CType = typename TypeTraits<T>::CType>
+std::shared_ptr<Array> CompareAndFilter(const CType* data, int64_t length,
+ const CType* other, CompareOperator op) {
+ auto cmp = GetComparator<CType>(op);
+ return CompareAndFilter<T>(data, length, [&](CType e) { return cmp(e, *other++); });
+}
+
+TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) {
+ using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
+ using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+ using CType = typename TypeTraits<TypeParam>::CType;
+
+ auto rand = random::RandomArrayGenerator(kRandomSeed);
+ for (size_t i = 3; i < 10; i++) {
+ const int64_t length = static_cast<int64_t>(1ULL << i);
+ // TODO(bkietz) rewrite with some nulls
+ auto array =
+ checked_pointer_cast<ArrayType>(rand.Numeric<TypeParam>(length, 0, 100, 0));
+ CType c_fifty = 50;
+ auto fifty = std::make_shared<ScalarType>(c_fifty);
+ for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
+ ASSERT_OK_AND_ASSIGN(
+ Datum selection,
+ CallFunction(CompareOperatorToFunctionName(op), {array, Datum(fifty)}));
+ ASSERT_OK_AND_ASSIGN(Datum filtered, Filter(array, selection));
+ auto filtered_array = filtered.make_array();
+ ValidateOutput(*filtered_array);
+ auto expected =
+ CompareAndFilter<TypeParam>(array->raw_values(), array->length(), c_fifty, op);
+ ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
+ }
+ }
+}
+
+TYPED_TEST(TestFilterKernelWithNumeric, CompareArrayAndFilterRandomNumeric) {
+ using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+
+ auto rand = random::RandomArrayGenerator(kRandomSeed);
+ for (size_t i = 3; i < 10; i++) {
+ const int64_t length = static_cast<int64_t>(1ULL << i);
+ auto lhs = checked_pointer_cast<ArrayType>(
+ rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
+ auto rhs = checked_pointer_cast<ArrayType>(
+ rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
+ for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
+ ASSERT_OK_AND_ASSIGN(Datum selection,
+ CallFunction(CompareOperatorToFunctionName(op), {lhs, rhs}));
+ ASSERT_OK_AND_ASSIGN(Datum filtered, Filter(lhs, selection));
+ auto filtered_array = filtered.make_array();
+ ValidateOutput(*filtered_array);
+ auto expected = CompareAndFilter<TypeParam>(lhs->raw_values(), lhs->length(),
+ rhs->raw_values(), op);
+ ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
+ }
+ }
+}
+
+TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) {
+ using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
+ using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+ using CType = typename TypeTraits<TypeParam>::CType;
+
+ auto rand = random::RandomArrayGenerator(kRandomSeed);
+ for (size_t i = 3; i < 10; i++) {
+ const int64_t length = static_cast<int64_t>(1ULL << i);
+ auto array = checked_pointer_cast<ArrayType>(
+ rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
+ CType c_fifty = 50, c_hundred = 100;
+ auto fifty = std::make_shared<ScalarType>(c_fifty);
+ auto hundred = std::make_shared<ScalarType>(c_hundred);
+ ASSERT_OK_AND_ASSIGN(Datum greater_than_fifty,
+ CallFunction("greater", {array, Datum(fifty)}));
+ ASSERT_OK_AND_ASSIGN(Datum less_than_hundred,
+ CallFunction("less", {array, Datum(hundred)}));
+ ASSERT_OK_AND_ASSIGN(Datum selection, And(greater_than_fifty, less_than_hundred));
+ ASSERT_OK_AND_ASSIGN(Datum filtered, Filter(array, selection));
+ auto filtered_array = filtered.make_array();
+ ValidateOutput(*filtered_array);
+ auto expected = CompareAndFilter<TypeParam>(
+ array->raw_values(), array->length(),
+ [&](CType e) { return (e > c_fifty) && (e < c_hundred); });
+ ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
+ }
+}
+
+TEST(TestFilterKernel, NoValidityBitmapButUnknownNullCount) {
+ auto values = ArrayFromJSON(int32(), "[1, 2, 3, 4]");
+ auto filter = ArrayFromJSON(boolean(), "[true, true, false, true]");
+
+ auto expected = (*Filter(values, filter)).make_array();
+
+ filter->data()->null_count = kUnknownNullCount;
+ auto result = (*Filter(values, filter)).make_array();
+
+ AssertArraysEqual(*expected, *result);
+}
+
+template <typename TypeClass>
+class TestFilterKernelWithString : public TestFilterKernel {
+ protected:
+ std::shared_ptr<DataType> value_type() {
+ return TypeTraits<TypeClass>::type_singleton();
+ }
+
+ void AssertFilter(const std::string& values, const std::string& filter,
+ const std::string& expected) {
+ TestFilterKernel::AssertFilter(ArrayFromJSON(value_type(), values),
+ ArrayFromJSON(boolean(), filter),
+ ArrayFromJSON(value_type(), expected));
+ }
+
+ void AssertFilterDictionary(const std::string& dictionary_values,
+ const std::string& dictionary_filter,
+ const std::string& filter,
+ const std::string& expected_filter) {
+ auto dict = ArrayFromJSON(value_type(), dictionary_values);
+ auto type = dictionary(int8(), value_type());
+ ASSERT_OK_AND_ASSIGN(auto values,
+ DictionaryArray::FromArrays(
+ type, ArrayFromJSON(int8(), dictionary_filter), dict));
+ ASSERT_OK_AND_ASSIGN(
+ auto expected,
+ DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_filter), dict));
+ auto take_filter = ArrayFromJSON(boolean(), filter);
+ TestFilterKernel::AssertFilter(values, take_filter, expected);
+ }
+};
+
+TYPED_TEST_SUITE(TestFilterKernelWithString, BinaryArrowTypes);
+
+TYPED_TEST(TestFilterKernelWithString, FilterString) {
+ this->AssertFilter(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["b"])");
+ this->AssertFilter(R"([null, "b", "c"])", "[0, 1, 0]", R"(["b"])");
+ this->AssertFilter(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b"])");
+}
+
+TYPED_TEST(TestFilterKernelWithString, FilterDictionary) {
+ auto dict = R"(["a", "b", "c", "d", "e"])";
+ this->AssertFilterDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[4]");
+ this->AssertFilterDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[4]");
+ this->AssertFilterDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4]");
+}
+
+class TestFilterKernelWithList : public TestFilterKernel {
+ public:
+};
+
+TEST_F(TestFilterKernelWithList, FilterListInt32) {
+ std::string list_json = "[[], [1,2], null, [3]]";
+ this->AssertFilter(list(int32()), list_json, "[0, 0, 0, 0]", "[]");
+ this->AssertFilter(list(int32()), list_json, "[0, 1, 1, null]", "[[1,2], null, null]");
+ this->AssertFilter(list(int32()), list_json, "[0, 0, 1, null]", "[null, null]");
+ this->AssertFilter(list(int32()), list_json, "[1, 0, 0, 1]", "[[], [3]]");
+ this->AssertFilter(list(int32()), list_json, "[1, 1, 1, 1]", list_json);
+ this->AssertFilter(list(int32()), list_json, "[0, 1, 0, 1]", "[[1,2], [3]]");
+}
+
+TEST_F(TestFilterKernelWithList, FilterListListInt32) {
+ std::string list_json = R"([
+ [],
+ [[1], [2, null, 2], []],
+ null,
+ [[3, null], null]
+ ])";
+ auto type = list(list(int32()));
+ this->AssertFilter(type, list_json, "[0, 0, 0, 0]", "[]");
+ this->AssertFilter(type, list_json, "[0, 1, 1, null]", R"([
+ [[1], [2, null, 2], []],
+ null,
+ null
+ ])");
+ this->AssertFilter(type, list_json, "[0, 0, 1, null]", "[null, null]");
+ this->AssertFilter(type, list_json, "[1, 0, 0, 1]", R"([
+ [],
+ [[3, null], null]
+ ])");
+ this->AssertFilter(type, list_json, "[1, 1, 1, 1]", list_json);
+ this->AssertFilter(type, list_json, "[0, 1, 0, 1]", R"([
+ [[1], [2, null, 2], []],
+ [[3, null], null]
+ ])");
+}
+
+class TestFilterKernelWithLargeList : public TestFilterKernel {};
+
+TEST_F(TestFilterKernelWithLargeList, FilterListInt32) {
+ std::string list_json = "[[], [1,2], null, [3]]";
+ this->AssertFilter(large_list(int32()), list_json, "[0, 0, 0, 0]", "[]");
+ this->AssertFilter(large_list(int32()), list_json, "[0, 1, 1, null]",
+ "[[1,2], null, null]");
+}
+
+class TestFilterKernelWithFixedSizeList : public TestFilterKernel {};
+
+TEST_F(TestFilterKernelWithFixedSizeList, FilterFixedSizeListInt32) {
+ std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]";
+ this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 0, 0, 0]", "[]");
+ this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 1, 1, null]",
+ "[[1, null, 3], [4, 5, 6], null]");
+ this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 0, 1, null]",
+ "[[4, 5, 6], null]");
+ this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[1, 1, 1, 1]", list_json);
+ this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 1, 0, 1]",
+ "[[1, null, 3], [7, 8, null]]");
+}
+
+class TestFilterKernelWithMap : public TestFilterKernel {};
+
+TEST_F(TestFilterKernelWithMap, FilterMapStringToInt32) {
+ std::string map_json = R"([
+ [["joe", 0], ["mark", null]],
+ null,
+ [["cap", 8]],
+ []
+ ])";
+ this->AssertFilter(map(utf8(), int32()), map_json, "[0, 0, 0, 0]", "[]");
+ this->AssertFilter(map(utf8(), int32()), map_json, "[0, 1, 1, null]", R"([
+ null,
+ [["cap", 8]],
+ null
+ ])");
+ this->AssertFilter(map(utf8(), int32()), map_json, "[1, 1, 1, 1]", map_json);
+ this->AssertFilter(map(utf8(), int32()), map_json, "[0, 1, 0, 1]", "[null, []]");
+}
+
+class TestFilterKernelWithStruct : public TestFilterKernel {};
+
+TEST_F(TestFilterKernelWithStruct, FilterStruct) {
+ auto struct_type = struct_({field("a", int32()), field("b", utf8())});
+ auto struct_json = R"([
+ null,
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])";
+ this->AssertFilter(struct_type, struct_json, "[0, 0, 0, 0]", "[]");
+ this->AssertFilter(struct_type, struct_json, "[0, 1, 1, null]", R"([
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ null
+ ])");
+ this->AssertFilter(struct_type, struct_json, "[1, 1, 1, 1]", struct_json);
+ this->AssertFilter(struct_type, struct_json, "[1, 0, 1, 0]", R"([
+ null,
+ {"a": 2, "b": "hello"}
+ ])");
+}
+
+class TestFilterKernelWithUnion : public TestFilterKernel {};
+
+TEST_F(TestFilterKernelWithUnion, FilterUnion) {
+ auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5});
+ auto union_json = R"([
+ [2, null],
+ [2, 222],
+ [5, "hello"],
+ [5, "eh"],
+ [2, null],
+ [2, 111],
+ [5, null]
+ ])";
+ this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]");
+ this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([
+ [2, 222],
+ [5, "hello"],
+ [2, null],
+ [2, 111],
+ [5, null]
+ ])");
+ this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([
+ [2, null],
+ [5, "hello"],
+ [2, null]
+ ])");
+ this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json);
+
+ // Sliced
+ // (check this manually as concatenation of dense unions isn't supported: ARROW-4975)
+ auto values = ArrayFromJSON(union_type, union_json)->Slice(2, 4);
+ auto filter = ArrayFromJSON(boolean(), "[0, 1, 1, null, 0, 1, 1]")->Slice(2, 4);
+ auto expected = ArrayFromJSON(union_type, R"([
+ [5, "hello"],
+ [2, null],
+ [2, 111]
+ ])");
+ this->AssertFilter(values, filter, expected);
+}
+
+class TestFilterKernelWithRecordBatch : public TestFilterKernel {
+ public:
+ void AssertFilter(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
+ const std::string& selection, FilterOptions options,
+ const std::string& expected_batch) {
+ std::shared_ptr<RecordBatch> actual;
+
+ ASSERT_OK(this->DoFilter(schm, batch_json, selection, options, &actual));
+ ValidateOutput(actual);
+ ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual);
+ }
+
+ Status DoFilter(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
+ const std::string& selection, FilterOptions options,
+ std::shared_ptr<RecordBatch>* out) {
+ auto batch = RecordBatchFromJSON(schm, batch_json);
+ ARROW_ASSIGN_OR_RAISE(Datum out_datum,
+ Filter(batch, ArrayFromJSON(boolean(), selection), options));
+ *out = out_datum.record_batch();
+ return Status::OK();
+ }
+};
+
+TEST_F(TestFilterKernelWithRecordBatch, FilterRecordBatch) {
+ std::vector<std::shared_ptr<Field>> fields = {field("a", int32()), field("b", utf8())};
+ auto schm = schema(fields);
+
+ auto batch_json = R"([
+ {"a": null, "b": "yo"},
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])";
+ for (auto options : {this->emit_null_, this->drop_}) {
+ this->AssertFilter(schm, batch_json, "[0, 0, 0, 0]", options, "[]");
+ this->AssertFilter(schm, batch_json, "[1, 1, 1, 1]", options, batch_json);
+ this->AssertFilter(schm, batch_json, "[1, 0, 1, 0]", options, R"([
+ {"a": null, "b": "yo"},
+ {"a": 2, "b": "hello"}
+ ])");
+ }
+
+ this->AssertFilter(schm, batch_json, "[0, 1, 1, null]", this->drop_, R"([
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"}
+ ])");
+
+ this->AssertFilter(schm, batch_json, "[0, 1, 1, null]", this->emit_null_, R"([
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ {"a": null, "b": null}
+ ])");
+}
+
+class TestFilterKernelWithChunkedArray : public TestFilterKernel {
+ public:
+ void AssertFilter(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values, const std::string& filter,
+ const std::vector<std::string>& expected) {
+ std::shared_ptr<ChunkedArray> actual;
+ ASSERT_OK(this->FilterWithArray(type, values, filter, &actual));
+ ValidateOutput(actual);
+ AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual);
+ }
+
+ void AssertChunkedFilter(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values,
+ const std::vector<std::string>& filter,
+ const std::vector<std::string>& expected) {
+ std::shared_ptr<ChunkedArray> actual;
+ ASSERT_OK(this->FilterWithChunkedArray(type, values, filter, &actual));
+ ValidateOutput(actual);
+ AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual);
+ }
+
+ Status FilterWithArray(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values,
+ const std::string& filter, std::shared_ptr<ChunkedArray>* out) {
+ ARROW_ASSIGN_OR_RAISE(Datum out_datum, Filter(ChunkedArrayFromJSON(type, values),
+ ArrayFromJSON(boolean(), filter)));
+ *out = out_datum.chunked_array();
+ return Status::OK();
+ }
+
+ Status FilterWithChunkedArray(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values,
+ const std::vector<std::string>& filter,
+ std::shared_ptr<ChunkedArray>* out) {
+ ARROW_ASSIGN_OR_RAISE(Datum out_datum,
+ Filter(ChunkedArrayFromJSON(type, values),
+ ChunkedArrayFromJSON(boolean(), filter)));
+ *out = out_datum.chunked_array();
+ return Status::OK();
+ }
+};
+
+TEST_F(TestFilterKernelWithChunkedArray, FilterChunkedArray) {
+ this->AssertFilter(int8(), {"[]"}, "[]", {});
+ this->AssertChunkedFilter(int8(), {"[]"}, {"[]"}, {});
+
+ this->AssertFilter(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0]", {"[8]"});
+ this->AssertChunkedFilter(int8(), {"[7]", "[8, 9]"}, {"[0]", "[1, 0]"}, {"[8]"});
+ this->AssertChunkedFilter(int8(), {"[7]", "[8, 9]"}, {"[0, 1]", "[0]"}, {"[8]"});
+
+ std::shared_ptr<ChunkedArray> arr;
+ ASSERT_RAISES(
+ Invalid, this->FilterWithArray(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0, 1, 1]", &arr));
+ ASSERT_RAISES(Invalid, this->FilterWithChunkedArray(int8(), {"[7]", "[8, 9]"},
+ {"[0, 1, 0]", "[1, 1]"}, &arr));
+}
+
+class TestFilterKernelWithTable : public TestFilterKernel {
+ public:
+ void AssertFilter(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& table_json, const std::string& filter,
+ FilterOptions options,
+ const std::vector<std::string>& expected_table) {
+ std::shared_ptr<Table> actual;
+
+ ASSERT_OK(this->FilterWithArray(schm, table_json, filter, options, &actual));
+ ValidateOutput(actual);
+ ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual);
+ }
+
+ void AssertChunkedFilter(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& table_json,
+ const std::vector<std::string>& filter, FilterOptions options,
+ const std::vector<std::string>& expected_table) {
+ std::shared_ptr<Table> actual;
+
+ ASSERT_OK(this->FilterWithChunkedArray(schm, table_json, filter, options, &actual));
+ ValidateOutput(actual);
+ AssertTablesEqual(*TableFromJSON(schm, expected_table), *actual,
+ /*same_chunk_layout=*/false);
+ }
+
+ Status FilterWithArray(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& values,
+ const std::string& filter, FilterOptions options,
+ std::shared_ptr<Table>* out) {
+ ARROW_ASSIGN_OR_RAISE(
+ Datum out_datum,
+ Filter(TableFromJSON(schm, values), ArrayFromJSON(boolean(), filter), options));
+ *out = out_datum.table();
+ return Status::OK();
+ }
+
+ Status FilterWithChunkedArray(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& values,
+ const std::vector<std::string>& filter,
+ FilterOptions options, std::shared_ptr<Table>* out) {
+ ARROW_ASSIGN_OR_RAISE(Datum out_datum,
+ Filter(TableFromJSON(schm, values),
+ ChunkedArrayFromJSON(boolean(), filter), options));
+ *out = out_datum.table();
+ return Status::OK();
+ }
+};
+
+TEST_F(TestFilterKernelWithTable, FilterTable) {
+ std::vector<std::shared_ptr<Field>> fields = {field("a", int32()), field("b", utf8())};
+ auto schm = schema(fields);
+
+ std::vector<std::string> table_json = {R"([
+ {"a": null, "b": "yo"},
+ {"a": 1, "b": ""}
+ ])",
+ R"([
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])"};
+ for (auto options : {this->emit_null_, this->drop_}) {
+ this->AssertFilter(schm, table_json, "[0, 0, 0, 0]", options, {});
+ this->AssertChunkedFilter(schm, table_json, {"[0]", "[0, 0, 0]"}, options, {});
+ this->AssertFilter(schm, table_json, "[1, 1, 1, 1]", options, table_json);
+ this->AssertChunkedFilter(schm, table_json, {"[1]", "[1, 1, 1]"}, options,
+ table_json);
+ }
+
+ std::vector<std::string> expected_emit_null = {R"([
+ {"a": 1, "b": ""}
+ ])",
+ R"([
+ {"a": 2, "b": "hello"},
+ {"a": null, "b": null}
+ ])"};
+ this->AssertFilter(schm, table_json, "[0, 1, 1, null]", this->emit_null_,
+ expected_emit_null);
+ this->AssertChunkedFilter(schm, table_json, {"[0, 1, 1]", "[null]"}, this->emit_null_,
+ expected_emit_null);
+
+ std::vector<std::string> expected_drop = {R"([{"a": 1, "b": ""}])",
+ R"([{"a": 2, "b": "hello"}])"};
+ this->AssertFilter(schm, table_json, "[0, 1, 1, null]", this->drop_, expected_drop);
+ this->AssertChunkedFilter(schm, table_json, {"[0, 1, 1]", "[null]"}, this->drop_,
+ expected_drop);
+}
+
+TEST(TestFilterMetaFunction, ArityChecking) {
+ ASSERT_RAISES(Invalid, CallFunction("filter", {}));
+}
+
+// ----------------------------------------------------------------------
+// Take tests
+
+void AssertTakeArrays(const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Array>& indices,
+ const std::shared_ptr<Array>& expected) {
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> actual, Take(*values, *indices));
+ ValidateOutput(actual);
+ AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+}
+
+Status TakeJSON(const std::shared_ptr<DataType>& type, const std::string& values,
+ const std::shared_ptr<DataType>& index_type, const std::string& indices,
+ std::shared_ptr<Array>* out) {
+ return Take(*ArrayFromJSON(type, values), *ArrayFromJSON(index_type, indices))
+ .Value(out);
+}
+
+void CheckTake(const std::shared_ptr<DataType>& type, const std::string& values_json,
+ const std::string& indices_json, const std::string& expected_json) {
+ auto values = ArrayFromJSON(type, values_json);
+ auto expected = ArrayFromJSON(type, expected_json);
+
+ for (auto index_type : {int8(), uint32()}) {
+ auto indices = ArrayFromJSON(index_type, indices_json);
+ AssertTakeArrays(values, indices, expected);
+
+ // Check sliced values
+ if (type->id() != Type::DENSE_UNION) {
+ ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(type, 2));
+ ASSERT_OK_AND_ASSIGN(auto values_sliced,
+ Concatenate({values_filler, values, values_filler}));
+ values_sliced = values_sliced->Slice(2, values->length());
+ AssertTakeArrays(values_sliced, indices, expected);
+ }
+
+ // Check sliced indices
+ ASSERT_OK_AND_ASSIGN(auto zero, MakeScalar(index_type, int8_t{0}));
+ ASSERT_OK_AND_ASSIGN(auto indices_filler, MakeArrayFromScalar(*zero, 3));
+ ASSERT_OK_AND_ASSIGN(auto indices_sliced,
+ Concatenate({indices_filler, indices, indices_filler}));
+ indices_sliced = indices_sliced->Slice(3, indices->length());
+ AssertTakeArrays(values, indices_sliced, expected);
+ }
+}
+
+void AssertTakeNull(const std::string& values, const std::string& indices,
+ const std::string& expected) {
+ CheckTake(null(), values, indices, expected);
+}
+
+void AssertTakeBoolean(const std::string& values, const std::string& indices,
+ const std::string& expected) {
+ CheckTake(boolean(), values, indices, expected);
+}
+
+template <typename ValuesType, typename IndexType>
+void ValidateTakeImpl(const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Array>& indices,
+ const std::shared_ptr<Array>& result) {
+ using ValuesArrayType = typename TypeTraits<ValuesType>::ArrayType;
+ using IndexArrayType = typename TypeTraits<IndexType>::ArrayType;
+ auto typed_values = checked_pointer_cast<ValuesArrayType>(values);
+ auto typed_result = checked_pointer_cast<ValuesArrayType>(result);
+ auto typed_indices = checked_pointer_cast<IndexArrayType>(indices);
+ for (int64_t i = 0; i < indices->length(); ++i) {
+ if (typed_indices->IsNull(i) || typed_values->IsNull(typed_indices->Value(i))) {
+ ASSERT_TRUE(result->IsNull(i)) << i;
+ } else {
+ ASSERT_FALSE(result->IsNull(i)) << i;
+ ASSERT_EQ(typed_result->GetView(i), typed_values->GetView(typed_indices->Value(i)))
+ << i;
+ }
+ }
+}
+
+template <typename ValuesType>
+void ValidateTake(const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Array>& indices) {
+ ASSERT_OK_AND_ASSIGN(Datum out, Take(values, indices));
+ auto taken = out.make_array();
+ ValidateOutput(taken);
+ ASSERT_EQ(indices->length(), taken->length());
+ switch (indices->type_id()) {
+ case Type::INT8:
+ ValidateTakeImpl<ValuesType, Int8Type>(values, indices, taken);
+ break;
+ case Type::INT16:
+ ValidateTakeImpl<ValuesType, Int16Type>(values, indices, taken);
+ break;
+ case Type::INT32:
+ ValidateTakeImpl<ValuesType, Int32Type>(values, indices, taken);
+ break;
+ case Type::INT64:
+ ValidateTakeImpl<ValuesType, Int64Type>(values, indices, taken);
+ break;
+ case Type::UINT8:
+ ValidateTakeImpl<ValuesType, UInt8Type>(values, indices, taken);
+ break;
+ case Type::UINT16:
+ ValidateTakeImpl<ValuesType, UInt16Type>(values, indices, taken);
+ break;
+ case Type::UINT32:
+ ValidateTakeImpl<ValuesType, UInt32Type>(values, indices, taken);
+ break;
+ case Type::UINT64:
+ ValidateTakeImpl<ValuesType, UInt64Type>(values, indices, taken);
+ break;
+ default:
+ FAIL() << "Invalid index type";
+ break;
+ }
+}
+
+template <typename T>
+T GetMaxIndex(int64_t values_length) {
+ int64_t max_index = values_length - 1;
+ if (max_index > static_cast<int64_t>(std::numeric_limits<T>::max())) {
+ max_index = std::numeric_limits<T>::max();
+ }
+ return static_cast<T>(max_index);
+}
+
+template <>
+uint64_t GetMaxIndex(int64_t values_length) {
+ return static_cast<uint64_t>(values_length - 1);
+}
+
+class TestTakeKernel : public ::testing::Test {
+ public:
+ void TestNoValidityBitmapButUnknownNullCount(const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Array>& indices) {
+ ASSERT_EQ(values->null_count(), 0);
+ ASSERT_EQ(indices->null_count(), 0);
+ auto expected = (*Take(values, indices)).make_array();
+
+ auto new_values = MakeArray(values->data()->Copy());
+ new_values->data()->buffers[0].reset();
+ new_values->data()->null_count = kUnknownNullCount;
+ auto new_indices = MakeArray(indices->data()->Copy());
+ new_indices->data()->buffers[0].reset();
+ new_indices->data()->null_count = kUnknownNullCount;
+ auto result = (*Take(new_values, new_indices)).make_array();
+
+ AssertArraysEqual(*expected, *result);
+ }
+
+ void TestNoValidityBitmapButUnknownNullCount(const std::shared_ptr<DataType>& type,
+ const std::string& values,
+ const std::string& indices) {
+ TestNoValidityBitmapButUnknownNullCount(ArrayFromJSON(type, values),
+ ArrayFromJSON(int16(), indices));
+ }
+};
+
+template <typename ArrowType>
+class TestTakeKernelTyped : public TestTakeKernel {};
+
+TEST_F(TestTakeKernel, TakeNull) {
+ AssertTakeNull("[null, null, null]", "[0, 1, 0]", "[null, null, null]");
+ AssertTakeNull("[null, null, null]", "[0, 2]", "[null, null]");
+
+ std::shared_ptr<Array> arr;
+ ASSERT_RAISES(IndexError,
+ TakeJSON(null(), "[null, null, null]", int8(), "[0, 9, 0]", &arr));
+ ASSERT_RAISES(IndexError,
+ TakeJSON(boolean(), "[null, null, null]", int8(), "[0, -1, 0]", &arr));
+}
+
+TEST_F(TestTakeKernel, InvalidIndexType) {
+ std::shared_ptr<Array> arr;
+ ASSERT_RAISES(NotImplemented, TakeJSON(null(), "[null, null, null]", float32(),
+ "[0.0, 1.0, 0.1]", &arr));
+}
+
+TEST_F(TestTakeKernel, TakeCCEmptyIndices) {
+ Datum dat = ChunkedArrayFromJSON(int8(), {"[]"});
+ Datum idx = ChunkedArrayFromJSON(int32(), {});
+ ASSERT_OK_AND_ASSIGN(auto out, Take(dat, idx));
+ ValidateOutput(out);
+ AssertDatumsEqual(ChunkedArrayFromJSON(int8(), {"[]"}), out, true);
+}
+
+TEST_F(TestTakeKernel, TakeACEmptyIndices) {
+ Datum dat = ArrayFromJSON(int8(), {"[]"});
+ Datum idx = ChunkedArrayFromJSON(int32(), {});
+ ASSERT_OK_AND_ASSIGN(auto out, Take(dat, idx));
+ ValidateOutput(out);
+ AssertDatumsEqual(ChunkedArrayFromJSON(int8(), {"[]"}), out, true);
+}
+
+TEST_F(TestTakeKernel, DefaultOptions) {
+ auto indices = ArrayFromJSON(int8(), "[null, 2, 0, 3]");
+ auto values = ArrayFromJSON(int8(), "[7, 8, 9, null]");
+ ASSERT_OK_AND_ASSIGN(auto no_options_provided, CallFunction("take", {values, indices}));
+
+ auto default_options = TakeOptions::Defaults();
+ ASSERT_OK_AND_ASSIGN(auto explicit_defaults,
+ CallFunction("take", {values, indices}, &default_options));
+
+ AssertDatumsEqual(explicit_defaults, no_options_provided);
+}
+
+TEST_F(TestTakeKernel, TakeBoolean) {
+ AssertTakeBoolean("[7, 8, 9]", "[]", "[]");
+ AssertTakeBoolean("[true, false, true]", "[0, 1, 0]", "[true, false, true]");
+ AssertTakeBoolean("[null, false, true]", "[0, 1, 0]", "[null, false, null]");
+ AssertTakeBoolean("[true, false, true]", "[null, 1, 0]", "[null, false, true]");
+
+ TestNoValidityBitmapButUnknownNullCount(boolean(), "[true, false, true]", "[1, 0, 0]");
+
+ std::shared_ptr<Array> arr;
+ ASSERT_RAISES(IndexError,
+ TakeJSON(boolean(), "[true, false, true]", int8(), "[0, 9, 0]", &arr));
+ ASSERT_RAISES(IndexError,
+ TakeJSON(boolean(), "[true, false, true]", int8(), "[0, -1, 0]", &arr));
+}
+
+template <typename ArrowType>
+class TestTakeKernelWithNumeric : public TestTakeKernelTyped<ArrowType> {
+ protected:
+ void AssertTake(const std::string& values, const std::string& indices,
+ const std::string& expected) {
+ CheckTake(type_singleton(), values, indices, expected);
+ }
+
+ std::shared_ptr<DataType> type_singleton() {
+ return TypeTraits<ArrowType>::type_singleton();
+ }
+};
+
+TYPED_TEST_SUITE(TestTakeKernelWithNumeric, NumericArrowTypes);
+TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) {
+ this->AssertTake("[7, 8, 9]", "[]", "[]");
+ this->AssertTake("[7, 8, 9]", "[0, 1, 0]", "[7, 8, 7]");
+ this->AssertTake("[null, 8, 9]", "[0, 1, 0]", "[null, 8, null]");
+ this->AssertTake("[7, 8, 9]", "[null, 1, 0]", "[null, 8, 7]");
+ this->AssertTake("[null, 8, 9]", "[]", "[]");
+ this->AssertTake("[7, 8, 9]", "[0, 0, 0, 0, 0, 0, 2]", "[7, 7, 7, 7, 7, 7, 9]");
+
+ std::shared_ptr<Array> arr;
+ ASSERT_RAISES(IndexError,
+ TakeJSON(this->type_singleton(), "[7, 8, 9]", int8(), "[0, 9, 0]", &arr));
+ ASSERT_RAISES(IndexError, TakeJSON(this->type_singleton(), "[7, 8, 9]", int8(),
+ "[0, -1, 0]", &arr));
+}
+
+template <typename TypeClass>
+class TestTakeKernelWithString : public TestTakeKernelTyped<TypeClass> {
+ public:
+ std::shared_ptr<DataType> value_type() {
+ return TypeTraits<TypeClass>::type_singleton();
+ }
+
+ void AssertTake(const std::string& values, const std::string& indices,
+ const std::string& expected) {
+ CheckTake(value_type(), values, indices, expected);
+ }
+
+ void AssertTakeDictionary(const std::string& dictionary_values,
+ const std::string& dictionary_indices,
+ const std::string& indices,
+ const std::string& expected_indices) {
+ auto dict = ArrayFromJSON(value_type(), dictionary_values);
+ auto type = dictionary(int8(), value_type());
+ ASSERT_OK_AND_ASSIGN(auto values,
+ DictionaryArray::FromArrays(
+ type, ArrayFromJSON(int8(), dictionary_indices), dict));
+ ASSERT_OK_AND_ASSIGN(
+ auto expected,
+ DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_indices), dict));
+ auto take_indices = ArrayFromJSON(int8(), indices);
+ AssertTakeArrays(values, take_indices, expected);
+ }
+};
+
+TYPED_TEST_SUITE(TestTakeKernelWithString, BinaryArrowTypes);
+
+TYPED_TEST(TestTakeKernelWithString, TakeString) {
+ this->AssertTake(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["a", "b", "a"])");
+ this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", "[null, \"b\", null]");
+ this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b", "a"])");
+
+ this->TestNoValidityBitmapButUnknownNullCount(this->value_type(), R"(["a", "b", "c"])",
+ "[0, 1, 0]");
+
+ std::shared_ptr<DataType> type = this->value_type();
+ std::shared_ptr<Array> arr;
+ ASSERT_RAISES(IndexError,
+ TakeJSON(type, R"(["a", "b", "c"])", int8(), "[0, 9, 0]", &arr));
+ ASSERT_RAISES(IndexError, TakeJSON(type, R"(["a", "b", null, "ddd", "ee"])", int64(),
+ "[2, 5]", &arr));
+}
+
+TYPED_TEST(TestTakeKernelWithString, TakeDictionary) {
+ auto dict = R"(["a", "b", "c", "d", "e"])";
+ this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[3, 4, 3]");
+ this->AssertTakeDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[null, 4, null]");
+ this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4, 3]");
+}
+
+class TestTakeKernelFSB : public TestTakeKernelTyped<FixedSizeBinaryType> {
+ public:
+ std::shared_ptr<DataType> value_type() { return fixed_size_binary(3); }
+
+ void AssertTake(const std::string& values, const std::string& indices,
+ const std::string& expected) {
+ CheckTake(value_type(), values, indices, expected);
+ }
+};
+
+TEST_F(TestTakeKernelFSB, TakeFixedSizeBinary) {
+ this->AssertTake(R"(["aaa", "bbb", "ccc"])", "[0, 1, 0]", R"(["aaa", "bbb", "aaa"])");
+ this->AssertTake(R"([null, "bbb", "ccc"])", "[0, 1, 0]", "[null, \"bbb\", null]");
+ this->AssertTake(R"(["aaa", "bbb", "ccc"])", "[null, 1, 0]", R"([null, "bbb", "aaa"])");
+
+ this->TestNoValidityBitmapButUnknownNullCount(this->value_type(),
+ R"(["aaa", "bbb", "ccc"])", "[0, 1, 0]");
+
+ std::shared_ptr<DataType> type = this->value_type();
+ std::shared_ptr<Array> arr;
+ ASSERT_RAISES(IndexError,
+ TakeJSON(type, R"(["aaa", "bbb", "ccc"])", int8(), "[0, 9, 0]", &arr));
+ ASSERT_RAISES(IndexError, TakeJSON(type, R"(["aaa", "bbb", null, "ddd", "eee"])",
+ int64(), "[2, 5]", &arr));
+}
+
+class TestTakeKernelWithList : public TestTakeKernelTyped<ListType> {};
+
+TEST_F(TestTakeKernelWithList, TakeListInt32) {
+ std::string list_json = "[[], [1,2], null, [3]]";
+ CheckTake(list(int32()), list_json, "[]", "[]");
+ CheckTake(list(int32()), list_json, "[3, 2, 1]", "[[3], null, [1,2]]");
+ CheckTake(list(int32()), list_json, "[null, 3, 0]", "[null, [3], []]");
+ CheckTake(list(int32()), list_json, "[null, null]", "[null, null]");
+ CheckTake(list(int32()), list_json, "[3, 0, 0, 3]", "[[3], [], [], [3]]");
+ CheckTake(list(int32()), list_json, "[0, 1, 2, 3]", list_json);
+ CheckTake(list(int32()), list_json, "[0, 0, 0, 0, 0, 0, 1]",
+ "[[], [], [], [], [], [], [1, 2]]");
+
+ this->TestNoValidityBitmapButUnknownNullCount(list(int32()), "[[], [1,2], [3]]",
+ "[0, 1, 0]");
+}
+
+TEST_F(TestTakeKernelWithList, TakeListListInt32) {
+ std::string list_json = R"([
+ [],
+ [[1], [2, null, 2], []],
+ null,
+ [[3, null], null]
+ ])";
+ auto type = list(list(int32()));
+ CheckTake(type, list_json, "[]", "[]");
+ CheckTake(type, list_json, "[3, 2, 1]", R"([
+ [[3, null], null],
+ null,
+ [[1], [2, null, 2], []]
+ ])");
+ CheckTake(type, list_json, "[null, 3, 0]", R"([
+ null,
+ [[3, null], null],
+ []
+ ])");
+ CheckTake(type, list_json, "[null, null]", "[null, null]");
+ CheckTake(type, list_json, "[3, 0, 0, 3]",
+ "[[[3, null], null], [], [], [[3, null], null]]");
+ CheckTake(type, list_json, "[0, 1, 2, 3]", list_json);
+ CheckTake(type, list_json, "[0, 0, 0, 0, 0, 0, 1]",
+ "[[], [], [], [], [], [], [[1], [2, null, 2], []]]");
+
+ this->TestNoValidityBitmapButUnknownNullCount(
+ type, "[[[1], [2, null, 2], []], [[3, null]]]", "[0, 1, 0]");
+}
+
+class TestTakeKernelWithLargeList : public TestTakeKernelTyped<LargeListType> {};
+
+TEST_F(TestTakeKernelWithLargeList, TakeLargeListInt32) {
+ std::string list_json = "[[], [1,2], null, [3]]";
+ CheckTake(large_list(int32()), list_json, "[]", "[]");
+ CheckTake(large_list(int32()), list_json, "[null, 1, 2, 0]", "[null, [1,2], null, []]");
+}
+
+class TestTakeKernelWithFixedSizeList : public TestTakeKernelTyped<FixedSizeListType> {};
+
+TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) {
+ std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]";
+ CheckTake(fixed_size_list(int32(), 3), list_json, "[]", "[]");
+ CheckTake(fixed_size_list(int32(), 3), list_json, "[3, 2, 1]",
+ "[[7, 8, null], [4, 5, 6], [1, null, 3]]");
+ CheckTake(fixed_size_list(int32(), 3), list_json, "[null, 2, 0]",
+ "[null, [4, 5, 6], null]");
+ CheckTake(fixed_size_list(int32(), 3), list_json, "[null, null]", "[null, null]");
+ CheckTake(fixed_size_list(int32(), 3), list_json, "[3, 0, 0, 3]",
+ "[[7, 8, null], null, null, [7, 8, null]]");
+ CheckTake(fixed_size_list(int32(), 3), list_json, "[0, 1, 2, 3]", list_json);
+ CheckTake(
+ fixed_size_list(int32(), 3), list_json, "[2, 2, 2, 2, 2, 2, 1]",
+ "[[4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [1, null, 3]]");
+
+ this->TestNoValidityBitmapButUnknownNullCount(fixed_size_list(int32(), 3),
+ "[[1, null, 3], [4, 5, 6], [7, 8, null]]",
+ "[0, 1, 0]");
+}
+
+class TestTakeKernelWithMap : public TestTakeKernelTyped<MapType> {};
+
+TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) {
+ std::string map_json = R"([
+ [["joe", 0], ["mark", null]],
+ null,
+ [["cap", 8]],
+ []
+ ])";
+ CheckTake(map(utf8(), int32()), map_json, "[]", "[]");
+ CheckTake(map(utf8(), int32()), map_json, "[3, 1, 3, 1, 3]",
+ "[[], null, [], null, []]");
+ CheckTake(map(utf8(), int32()), map_json, "[2, 1, null]", R"([
+ [["cap", 8]],
+ null,
+ null
+ ])");
+ CheckTake(map(utf8(), int32()), map_json, "[2, 1, 0]", R"([
+ [["cap", 8]],
+ null,
+ [["joe", 0], ["mark", null]]
+ ])");
+ CheckTake(map(utf8(), int32()), map_json, "[0, 1, 2, 3]", map_json);
+ CheckTake(map(utf8(), int32()), map_json, "[0, 0, 0, 0, 0, 0, 3]", R"([
+ [["joe", 0], ["mark", null]],
+ [["joe", 0], ["mark", null]],
+ [["joe", 0], ["mark", null]],
+ [["joe", 0], ["mark", null]],
+ [["joe", 0], ["mark", null]],
+ [["joe", 0], ["mark", null]],
+ []
+ ])");
+}
+
+class TestTakeKernelWithStruct : public TestTakeKernelTyped<StructType> {};
+
+TEST_F(TestTakeKernelWithStruct, TakeStruct) {
+ auto struct_type = struct_({field("a", int32()), field("b", utf8())});
+ auto struct_json = R"([
+ null,
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])";
+ CheckTake(struct_type, struct_json, "[]", "[]");
+ CheckTake(struct_type, struct_json, "[3, 1, 3, 1, 3]", R"([
+ {"a": 4, "b": "eh"},
+ {"a": 1, "b": ""},
+ {"a": 4, "b": "eh"},
+ {"a": 1, "b": ""},
+ {"a": 4, "b": "eh"}
+ ])");
+ CheckTake(struct_type, struct_json, "[3, 1, 0]", R"([
+ {"a": 4, "b": "eh"},
+ {"a": 1, "b": ""},
+ null
+ ])");
+ CheckTake(struct_type, struct_json, "[0, 1, 2, 3]", struct_json);
+ CheckTake(struct_type, struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
+ null,
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"}
+ ])");
+
+ this->TestNoValidityBitmapButUnknownNullCount(
+ struct_type, R"([{"a": 1}, {"a": 2, "b": "hello"}])", "[0, 1, 0]");
+}
+
+class TestTakeKernelWithUnion : public TestTakeKernelTyped<UnionType> {};
+
+TEST_F(TestTakeKernelWithUnion, TakeUnion) {
+ auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5});
+ auto union_json = R"([
+ [2, null],
+ [2, 222],
+ [5, "hello"],
+ [5, "eh"],
+ [2, null],
+ [2, 111],
+ [5, null]
+ ])";
+ CheckTake(union_type, union_json, "[]", "[]");
+ CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([
+ [5, "eh"],
+ [2, 222],
+ [5, "eh"],
+ [2, 222],
+ [5, "eh"]
+ ])");
+ CheckTake(union_type, union_json, "[4, 2, 1, 6]", R"([
+ [2, null],
+ [5, "hello"],
+ [2, 222],
+ [5, null]
+ ])");
+ CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json);
+ CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
+ [2, null],
+ [5, "hello"],
+ [5, "hello"],
+ [5, "hello"],
+ [5, "hello"],
+ [5, "hello"],
+ [5, "hello"]
+ ])");
+}
+
+class TestPermutationsWithTake : public TestBase {
+ protected:
+ void DoTake(const Int16Array& values, const Int16Array& indices,
+ std::shared_ptr<Int16Array>* out) {
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> boxed_out, Take(values, indices));
+ ValidateOutput(boxed_out);
+ *out = checked_pointer_cast<Int16Array>(std::move(boxed_out));
+ }
+
+ std::shared_ptr<Int16Array> DoTake(const Int16Array& values,
+ const Int16Array& indices) {
+ std::shared_ptr<Int16Array> out;
+ DoTake(values, indices, &out);
+ return out;
+ }
+
+ std::shared_ptr<Int16Array> DoTakeN(uint64_t n, std::shared_ptr<Int16Array> array) {
+ auto power_of_2 = array;
+ array = Identity(array->length());
+ while (n != 0) {
+ if (n & 1) {
+ array = DoTake(*array, *power_of_2);
+ }
+ power_of_2 = DoTake(*power_of_2, *power_of_2);
+ n >>= 1;
+ }
+ return array;
+ }
+
+ template <typename Rng>
+ void Shuffle(const Int16Array& array, Rng& gen, std::shared_ptr<Int16Array>* shuffled) {
+ auto byte_length = array.length() * sizeof(int16_t);
+ ASSERT_OK_AND_ASSIGN(auto data, array.values()->CopySlice(0, byte_length));
+ auto mutable_data = reinterpret_cast<int16_t*>(data->mutable_data());
+ std::shuffle(mutable_data, mutable_data + array.length(), gen);
+ shuffled->reset(new Int16Array(array.length(), data));
+ }
+
+ template <typename Rng>
+ std::shared_ptr<Int16Array> Shuffle(const Int16Array& array, Rng& gen) {
+ std::shared_ptr<Int16Array> out;
+ Shuffle(array, gen, &out);
+ return out;
+ }
+
+ void Identity(int64_t length, std::shared_ptr<Int16Array>* identity) {
+ Int16Builder identity_builder;
+ ASSERT_OK(identity_builder.Resize(length));
+ for (int16_t i = 0; i < length; ++i) {
+ identity_builder.UnsafeAppend(i);
+ }
+ ASSERT_OK(identity_builder.Finish(identity));
+ }
+
+ std::shared_ptr<Int16Array> Identity(int64_t length) {
+ std::shared_ptr<Int16Array> out;
+ Identity(length, &out);
+ return out;
+ }
+
+ std::shared_ptr<Int16Array> Inverse(const std::shared_ptr<Int16Array>& permutation) {
+ auto length = static_cast<int16_t>(permutation->length());
+
+ std::vector<bool> cycle_lengths(length + 1, false);
+ auto permutation_to_the_i = permutation;
+ for (int16_t cycle_length = 1; cycle_length <= length; ++cycle_length) {
+ cycle_lengths[cycle_length] = HasTrivialCycle(*permutation_to_the_i);
+ permutation_to_the_i = DoTake(*permutation, *permutation_to_the_i);
+ }
+
+ uint64_t cycle_to_identity_length = 1;
+ for (int16_t cycle_length = length; cycle_length > 1; --cycle_length) {
+ if (!cycle_lengths[cycle_length]) {
+ continue;
+ }
+ if (cycle_to_identity_length % cycle_length == 0) {
+ continue;
+ }
+ if (cycle_to_identity_length >
+ std::numeric_limits<uint64_t>::max() / cycle_length) {
+ // overflow, can't compute Inverse
+ return nullptr;
+ }
+ cycle_to_identity_length *= cycle_length;
+ }
+
+ return DoTakeN(cycle_to_identity_length - 1, permutation);
+ }
+
+ bool HasTrivialCycle(const Int16Array& permutation) {
+ for (int64_t i = 0; i < permutation.length(); ++i) {
+ if (permutation.Value(i) == static_cast<int16_t>(i)) {
+ return true;
+ }
+ }
+ return false;
+ }
+};
+
+TEST_F(TestPermutationsWithTake, InvertPermutation) {
+ for (auto seed : std::vector<random::SeedType>({0, kRandomSeed, kRandomSeed * 2 - 1})) {
+ std::default_random_engine gen(seed);
+ for (int16_t length = 0; length < 1 << 10; ++length) {
+ auto identity = Identity(length);
+ auto permutation = Shuffle(*identity, gen);
+ auto inverse = Inverse(permutation);
+ if (inverse == nullptr) {
+ break;
+ }
+ ASSERT_TRUE(DoTake(*inverse, *permutation)->Equals(identity));
+ }
+ }
+}
+
+class TestTakeKernelWithRecordBatch : public TestTakeKernelTyped<RecordBatch> {
+ public:
+ void AssertTake(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
+ const std::string& indices, const std::string& expected_batch) {
+ std::shared_ptr<RecordBatch> actual;
+
+ for (auto index_type : {int8(), uint32()}) {
+ ASSERT_OK(TakeJSON(schm, batch_json, index_type, indices, &actual));
+ ValidateOutput(actual);
+ ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual);
+ }
+ }
+
+ Status TakeJSON(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
+ const std::shared_ptr<DataType>& index_type, const std::string& indices,
+ std::shared_ptr<RecordBatch>* out) {
+ auto batch = RecordBatchFromJSON(schm, batch_json);
+ ARROW_ASSIGN_OR_RAISE(Datum result,
+ Take(Datum(batch), Datum(ArrayFromJSON(index_type, indices))));
+ *out = result.record_batch();
+ return Status::OK();
+ }
+};
+
+TEST_F(TestTakeKernelWithRecordBatch, TakeRecordBatch) {
+ std::vector<std::shared_ptr<Field>> fields = {field("a", int32()), field("b", utf8())};
+ auto schm = schema(fields);
+
+ auto struct_json = R"([
+ {"a": null, "b": "yo"},
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])";
+ this->AssertTake(schm, struct_json, "[]", "[]");
+ this->AssertTake(schm, struct_json, "[3, 1, 3, 1, 3]", R"([
+ {"a": 4, "b": "eh"},
+ {"a": 1, "b": ""},
+ {"a": 4, "b": "eh"},
+ {"a": 1, "b": ""},
+ {"a": 4, "b": "eh"}
+ ])");
+ this->AssertTake(schm, struct_json, "[3, 1, 0]", R"([
+ {"a": 4, "b": "eh"},
+ {"a": 1, "b": ""},
+ {"a": null, "b": "yo"}
+ ])");
+ this->AssertTake(schm, struct_json, "[0, 1, 2, 3]", struct_json);
+ this->AssertTake(schm, struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
+ {"a": null, "b": "yo"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"}
+ ])");
+}
+
+class TestTakeKernelWithChunkedArray : public TestTakeKernelTyped<ChunkedArray> {
+ public:
+ void AssertTake(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values, const std::string& indices,
+ const std::vector<std::string>& expected) {
+ std::shared_ptr<ChunkedArray> actual;
+ ASSERT_OK(this->TakeWithArray(type, values, indices, &actual));
+ ValidateOutput(actual);
+ AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual);
+ }
+
+ void AssertChunkedTake(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values,
+ const std::vector<std::string>& indices,
+ const std::vector<std::string>& expected) {
+ std::shared_ptr<ChunkedArray> actual;
+ ASSERT_OK(this->TakeWithChunkedArray(type, values, indices, &actual));
+ ValidateOutput(actual);
+ AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual);
+ }
+
+ Status TakeWithArray(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values, const std::string& indices,
+ std::shared_ptr<ChunkedArray>* out) {
+ ARROW_ASSIGN_OR_RAISE(Datum result, Take(ChunkedArrayFromJSON(type, values),
+ ArrayFromJSON(int8(), indices)));
+ *out = result.chunked_array();
+ return Status::OK();
+ }
+
+ Status TakeWithChunkedArray(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values,
+ const std::vector<std::string>& indices,
+ std::shared_ptr<ChunkedArray>* out) {
+ ARROW_ASSIGN_OR_RAISE(Datum result, Take(ChunkedArrayFromJSON(type, values),
+ ChunkedArrayFromJSON(int8(), indices)));
+ *out = result.chunked_array();
+ return Status::OK();
+ }
+};
+
+TEST_F(TestTakeKernelWithChunkedArray, TakeChunkedArray) {
+ this->AssertTake(int8(), {"[]"}, "[]", {"[]"});
+ this->AssertChunkedTake(int8(), {}, {}, {});
+ this->AssertChunkedTake(int8(), {}, {"[]"}, {"[]"});
+ this->AssertChunkedTake(int8(), {}, {"[null]"}, {"[null]"});
+ this->AssertChunkedTake(int8(), {"[]"}, {}, {});
+ this->AssertChunkedTake(int8(), {"[]"}, {"[]"}, {"[]"});
+ this->AssertChunkedTake(int8(), {"[]"}, {"[null]"}, {"[null]"});
+
+ this->AssertTake(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0, 2]", {"[7, 8, 7, 9]"});
+ this->AssertChunkedTake(int8(), {"[7]", "[8, 9]"}, {"[0, 1, 0]", "[]", "[2]"},
+ {"[7, 8, 7]", "[]", "[9]"});
+ this->AssertTake(int8(), {"[7]", "[8, 9]"}, "[2, 1]", {"[9, 8]"});
+
+ std::shared_ptr<ChunkedArray> arr;
+ ASSERT_RAISES(IndexError,
+ this->TakeWithArray(int8(), {"[7]", "[8, 9]"}, "[0, 5]", &arr));
+ ASSERT_RAISES(IndexError, this->TakeWithChunkedArray(int8(), {"[7]", "[8, 9]"},
+ {"[0, 1, 0]", "[5, 1]"}, &arr));
+ ASSERT_RAISES(IndexError, this->TakeWithChunkedArray(int8(), {}, {"[0]"}, &arr));
+ ASSERT_RAISES(IndexError, this->TakeWithChunkedArray(int8(), {"[]"}, {"[0]"}, &arr));
+}
+
+class TestTakeKernelWithTable : public TestTakeKernelTyped<Table> {
+ public:
+ void AssertTake(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& table_json, const std::string& filter,
+ const std::vector<std::string>& expected_table) {
+ std::shared_ptr<Table> actual;
+
+ ASSERT_OK(this->TakeWithArray(schm, table_json, filter, &actual));
+ ValidateOutput(actual);
+ ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual);
+ }
+
+ void AssertChunkedTake(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& table_json,
+ const std::vector<std::string>& filter,
+ const std::vector<std::string>& expected_table) {
+ std::shared_ptr<Table> actual;
+
+ ASSERT_OK(this->TakeWithChunkedArray(schm, table_json, filter, &actual));
+ ValidateOutput(actual);
+ ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual);
+ }
+
+ Status TakeWithArray(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& values, const std::string& indices,
+ std::shared_ptr<Table>* out) {
+ ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(TableFromJSON(schm, values)),
+ Datum(ArrayFromJSON(int8(), indices))));
+ *out = result.table();
+ return Status::OK();
+ }
+
+ Status TakeWithChunkedArray(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& values,
+ const std::vector<std::string>& indices,
+ std::shared_ptr<Table>* out) {
+ ARROW_ASSIGN_OR_RAISE(Datum result,
+ Take(Datum(TableFromJSON(schm, values)),
+ Datum(ChunkedArrayFromJSON(int8(), indices))));
+ *out = result.table();
+ return Status::OK();
+ }
+};
+
+TEST_F(TestTakeKernelWithTable, TakeTable) {
+ std::vector<std::shared_ptr<Field>> fields = {field("a", int32()), field("b", utf8())};
+ auto schm = schema(fields);
+
+ std::vector<std::string> table_json = {
+ "[{\"a\": null, \"b\": \"yo\"},{\"a\": 1, \"b\": \"\"}]",
+ "[{\"a\": 2, \"b\": \"hello\"},{\"a\": 4, \"b\": \"eh\"}]"};
+
+ this->AssertTake(schm, table_json, "[]", {"[]"});
+ std::vector<std::string> expected_310 = {
+ "[{\"a\": 4, \"b\": \"eh\"},{\"a\": 1, \"b\": \"\"},{\"a\": null, \"b\": \"yo\"}]"};
+ this->AssertTake(schm, table_json, "[3, 1, 0]", expected_310);
+ this->AssertChunkedTake(schm, table_json, {"[0, 1]", "[2, 3]"}, table_json);
+}
+
+TEST(TestTakeMetaFunction, ArityChecking) {
+ ASSERT_RAISES(Invalid, CallFunction("take", {}));
+}
+
+// ----------------------------------------------------------------------
+// Random data tests
+
+template <typename Unused = void>
+struct FilterRandomTest {
+ static void Test(const std::shared_ptr<DataType>& type) {
+ auto rand = random::RandomArrayGenerator(kRandomSeed);
+ const int64_t length = static_cast<int64_t>(1ULL << 10);
+ for (auto null_probability : {0.0, 0.01, 0.1, 0.999, 1.0}) {
+ for (auto true_probability : {0.0, 0.1, 0.999, 1.0}) {
+ auto values = rand.ArrayOf(type, length, null_probability);
+ auto filter = rand.Boolean(length + 1, true_probability, null_probability);
+ auto filter_no_nulls = rand.Boolean(length + 1, true_probability, 0.0);
+ ValidateFilter(values, filter->Slice(0, values->length()));
+ ValidateFilter(values, filter_no_nulls->Slice(0, values->length()));
+ // Test values and filter have different offsets
+ ValidateFilter(values->Slice(3), filter->Slice(4));
+ ValidateFilter(values->Slice(3), filter_no_nulls->Slice(4));
+ }
+ }
+ }
+};
+
+template <typename ValuesType, typename IndexType>
+void CheckTakeRandom(const std::shared_ptr<Array>& values, int64_t indices_length,
+ double null_probability, random::RandomArrayGenerator* rand) {
+ using IndexCType = typename IndexType::c_type;
+ IndexCType max_index = GetMaxIndex<IndexCType>(values->length());
+ auto indices = rand->Numeric<IndexType>(indices_length, static_cast<IndexCType>(0),
+ max_index, null_probability);
+ auto indices_no_nulls = rand->Numeric<IndexType>(
+ indices_length, static_cast<IndexCType>(0), max_index, /*null_probability=*/0.0);
+ ValidateTake<ValuesType>(values, indices);
+ ValidateTake<ValuesType>(values, indices_no_nulls);
+ // Sliced indices array
+ if (indices_length >= 2) {
+ indices = indices->Slice(1, indices_length - 2);
+ indices_no_nulls = indices_no_nulls->Slice(1, indices_length - 2);
+ ValidateTake<ValuesType>(values, indices);
+ ValidateTake<ValuesType>(values, indices_no_nulls);
+ }
+}
+
+template <typename ValuesType>
+struct TakeRandomTest {
+ static void Test(const std::shared_ptr<DataType>& type) {
+ auto rand = random::RandomArrayGenerator(kRandomSeed);
+ const int64_t values_length = 64 * 16 + 1;
+ const int64_t indices_length = 64 * 4 + 1;
+ for (const auto null_probability : {0.0, 0.001, 0.05, 0.25, 0.95, 0.999, 1.0}) {
+ auto values = rand.ArrayOf(type, values_length, null_probability);
+ CheckTakeRandom<ValuesType, Int8Type>(values, indices_length, null_probability,
+ &rand);
+ CheckTakeRandom<ValuesType, Int16Type>(values, indices_length, null_probability,
+ &rand);
+ CheckTakeRandom<ValuesType, Int32Type>(values, indices_length, null_probability,
+ &rand);
+ CheckTakeRandom<ValuesType, Int64Type>(values, indices_length, null_probability,
+ &rand);
+ CheckTakeRandom<ValuesType, UInt8Type>(values, indices_length, null_probability,
+ &rand);
+ CheckTakeRandom<ValuesType, UInt16Type>(values, indices_length, null_probability,
+ &rand);
+ CheckTakeRandom<ValuesType, UInt32Type>(values, indices_length, null_probability,
+ &rand);
+ CheckTakeRandom<ValuesType, UInt64Type>(values, indices_length, null_probability,
+ &rand);
+ // Sliced values array
+ if (values_length > 2) {
+ values = values->Slice(1, values_length - 2);
+ CheckTakeRandom<ValuesType, UInt64Type>(values, indices_length, null_probability,
+ &rand);
+ }
+ }
+ }
+};
+
+TEST(TestFilter, PrimitiveRandom) { TestRandomPrimitiveCTypes<FilterRandomTest>(); }
+
+TEST(TestFilter, RandomBoolean) { FilterRandomTest<>::Test(boolean()); }
+
+TEST(TestFilter, RandomString) {
+ FilterRandomTest<>::Test(utf8());
+ FilterRandomTest<>::Test(large_utf8());
+}
+
+TEST(TestFilter, RandomFixedSizeBinary) {
+ FilterRandomTest<>::Test(fixed_size_binary(0));
+ FilterRandomTest<>::Test(fixed_size_binary(16));
+}
+
+TEST(TestTake, PrimitiveRandom) { TestRandomPrimitiveCTypes<TakeRandomTest>(); }
+
+TEST(TestTake, RandomBoolean) { TakeRandomTest<BooleanType>::Test(boolean()); }
+
+TEST(TestTake, RandomString) {
+ TakeRandomTest<StringType>::Test(utf8());
+ TakeRandomTest<LargeStringType>::Test(large_utf8());
+}
+
+TEST(TestTake, RandomFixedSizeBinary) {
+ TakeRandomTest<FixedSizeBinaryType>::Test(fixed_size_binary(0));
+ TakeRandomTest<FixedSizeBinaryType>::Test(fixed_size_binary(16));
+}
+
+// ----------------------------------------------------------------------
+// DropNull tests
+
+void AssertDropNullArrays(const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Array>& expected) {
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> actual, DropNull(*values));
+ ValidateOutput(actual);
+ AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+}
+
+Status DropNullJSON(const std::shared_ptr<DataType>& type, const std::string& values,
+ std::shared_ptr<Array>* out) {
+ return DropNull(*ArrayFromJSON(type, values)).Value(out);
+}
+
+void CheckDropNull(const std::shared_ptr<DataType>& type, const std::string& values,
+ const std::string& expected) {
+ std::shared_ptr<Array> actual;
+
+ ASSERT_OK(DropNullJSON(type, values, &actual));
+ ValidateOutput(actual);
+ AssertArraysEqual(*ArrayFromJSON(type, expected), *actual, /*verbose=*/true);
+}
+
+struct TestDropNullKernel : public ::testing::Test {
+ void TestNoValidityBitmapButUnknownNullCount(const std::shared_ptr<Array>& values) {
+ ASSERT_EQ(values->null_count(), 0);
+ auto expected = (*DropNull(values)).make_array();
+
+ auto new_values = MakeArray(values->data()->Copy());
+ new_values->data()->buffers[0].reset();
+ new_values->data()->null_count = kUnknownNullCount;
+ auto result = (*DropNull(new_values)).make_array();
+ AssertArraysEqual(*expected, *result);
+ }
+
+ void TestNoValidityBitmapButUnknownNullCount(const std::shared_ptr<DataType>& type,
+ const std::string& values) {
+ TestNoValidityBitmapButUnknownNullCount(ArrayFromJSON(type, values));
+ }
+};
+
+TEST_F(TestDropNullKernel, DropNull) {
+ CheckDropNull(null(), "[null, null, null]", "[]");
+ CheckDropNull(null(), "[null]", "[]");
+}
+
+TEST_F(TestDropNullKernel, DropNullBoolean) {
+ CheckDropNull(boolean(), "[true, false, true]", "[true, false, true]");
+ CheckDropNull(boolean(), "[null, false, true]", "[false, true]");
+ CheckDropNull(boolean(), "[]", "[]");
+ CheckDropNull(boolean(), "[null, null]", "[]");
+
+ TestNoValidityBitmapButUnknownNullCount(boolean(), "[true, false, true]");
+}
+
+template <typename ArrowType>
+struct TestDropNullKernelTyped : public TestDropNullKernel {
+ TestDropNullKernelTyped() : rng_(seed_) {}
+
+ std::shared_ptr<Int32Array> Offsets(int32_t length, int32_t slice_count) {
+ return checked_pointer_cast<Int32Array>(rng_.Offsets(slice_count, 0, length));
+ }
+
+ // Slice `array` into multiple chunks along `offsets`
+ ArrayVector Slices(const std::shared_ptr<Array>& array,
+ const std::shared_ptr<Int32Array>& offsets) {
+ ArrayVector slices(offsets->length() - 1);
+ for (int64_t i = 0; i != static_cast<int64_t>(slices.size()); ++i) {
+ slices[i] =
+ array->Slice(offsets->Value(i), offsets->Value(i + 1) - offsets->Value(i));
+ }
+ return slices;
+ }
+
+ random::SeedType seed_ = 0xdeadbeef;
+ random::RandomArrayGenerator rng_;
+};
+
+template <typename ArrowType>
+class TestDropNullKernelWithNumeric : public TestDropNullKernelTyped<ArrowType> {
+ protected:
+ void AssertDropNull(const std::string& values, const std::string& expected) {
+ CheckDropNull(type_singleton(), values, expected);
+ }
+
+ std::shared_ptr<DataType> type_singleton() {
+ return TypeTraits<ArrowType>::type_singleton();
+ }
+};
+
+TYPED_TEST_SUITE(TestDropNullKernelWithNumeric, NumericArrowTypes);
+TYPED_TEST(TestDropNullKernelWithNumeric, DropNullNumeric) {
+ this->AssertDropNull("[7, 8, 9]", "[7, 8, 9]");
+ this->AssertDropNull("[null, 8, 9]", "[8, 9]");
+ this->AssertDropNull("[null, null, null]", "[]");
+}
+
+template <typename TypeClass>
+class TestDropNullKernelWithString : public TestDropNullKernelTyped<TypeClass> {
+ public:
+ std::shared_ptr<DataType> value_type() {
+ return TypeTraits<TypeClass>::type_singleton();
+ }
+
+ void AssertDropNull(const std::string& values, const std::string& expected) {
+ CheckDropNull(value_type(), values, expected);
+ }
+
+ void AssertDropNullDictionary(const std::string& dictionary_values,
+ const std::string& dictionary_indices,
+ const std::string& expected_indices) {
+ auto dict = ArrayFromJSON(value_type(), dictionary_values);
+ auto type = dictionary(int8(), value_type());
+ ASSERT_OK_AND_ASSIGN(auto values,
+ DictionaryArray::FromArrays(
+ type, ArrayFromJSON(int8(), dictionary_indices), dict));
+ ASSERT_OK_AND_ASSIGN(
+ auto expected,
+ DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_indices), dict));
+ AssertDropNullArrays(values, expected);
+ }
+};
+
+TYPED_TEST_SUITE(TestDropNullKernelWithString, BinaryArrowTypes);
+
+TYPED_TEST(TestDropNullKernelWithString, DropNullString) {
+ this->AssertDropNull(R"(["a", "b", "c"])", R"(["a", "b", "c"])");
+ this->AssertDropNull(R"([null, "b", "c"])", "[\"b\", \"c\"]");
+ this->AssertDropNull(R"(["a", "b", null])", R"(["a", "b"])");
+
+ this->TestNoValidityBitmapButUnknownNullCount(this->value_type(), R"(["a", "b", "c"])");
+}
+
+TYPED_TEST(TestDropNullKernelWithString, DropNullDictionary) {
+ auto dict = R"(["a", "b", "c", "d", "e"])";
+ this->AssertDropNullDictionary(dict, "[3, 4, 2]", "[3, 4, 2]");
+ this->AssertDropNullDictionary(dict, "[null, 4, 2]", "[4, 2]");
+}
+
+class TestDropNullKernelFSB : public TestDropNullKernelTyped<FixedSizeBinaryType> {
+ public:
+ std::shared_ptr<DataType> value_type() { return fixed_size_binary(3); }
+
+ void AssertDropNull(const std::string& values, const std::string& expected) {
+ CheckDropNull(value_type(), values, expected);
+ }
+};
+
+TEST_F(TestDropNullKernelFSB, DropNullFixedSizeBinary) {
+ this->AssertDropNull(R"(["aaa", "bbb", "ccc"])", R"(["aaa", "bbb", "ccc"])");
+ this->AssertDropNull(R"([null, "bbb", "ccc"])", "[\"bbb\", \"ccc\"]");
+
+ this->TestNoValidityBitmapButUnknownNullCount(this->value_type(),
+ R"(["aaa", "bbb", "ccc"])");
+}
+
+class TestDropNullKernelWithList : public TestDropNullKernelTyped<ListType> {};
+
+TEST_F(TestDropNullKernelWithList, DropNullListInt32) {
+ std::string list_json = "[[], [1,2], null, [3]]";
+ CheckDropNull(list(int32()), list_json, "[[], [1,2], [3]]");
+ this->TestNoValidityBitmapButUnknownNullCount(list(int32()), "[[], [1,2], [3]]");
+}
+
+TEST_F(TestDropNullKernelWithList, DropNullListListInt32) {
+ std::string list_json = R"([
+ [],
+ [[1], [2, null, 2], []],
+ null,
+ [[3, null], null]
+ ])";
+ auto type = list(list(int32()));
+ CheckDropNull(type, list_json, R"([
+ [],
+ [[1], [2, null, 2], []],
+ [[3, null], null]
+ ])");
+
+ this->TestNoValidityBitmapButUnknownNullCount(type,
+ "[[[1], [2, null, 2], []], [[3, null]]]");
+}
+
+class TestDropNullKernelWithLargeList : public TestDropNullKernelTyped<LargeListType> {};
+
+TEST_F(TestDropNullKernelWithLargeList, DropNullLargeListInt32) {
+ std::string list_json = "[[], [1,2], null, [3]]";
+ CheckDropNull(large_list(int32()), list_json, "[[], [1,2], [3]]");
+
+ this->TestNoValidityBitmapButUnknownNullCount(
+ fixed_size_list(int32(), 3), "[[1, null, 3], [4, 5, 6], [7, 8, null]]");
+}
+
+class TestDropNullKernelWithFixedSizeList
+ : public TestDropNullKernelTyped<FixedSizeListType> {};
+
+TEST_F(TestDropNullKernelWithFixedSizeList, DropNullFixedSizeListInt32) {
+ std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]";
+ CheckDropNull(fixed_size_list(int32(), 3), list_json,
+ "[[1, null, 3], [4, 5, 6], [7, 8, null]]");
+
+ this->TestNoValidityBitmapButUnknownNullCount(
+ fixed_size_list(int32(), 3), "[[1, null, 3], [4, 5, 6], [7, 8, null]]");
+}
+
+class TestDropNullKernelWithMap : public TestDropNullKernelTyped<MapType> {};
+
+TEST_F(TestDropNullKernelWithMap, DropNullMapStringToInt32) {
+ std::string map_json = R"([
+ [["joe", 0], ["mark", null]],
+ null,
+ [["cap", 8]],
+ []
+ ])";
+ std::string expected_json = R"([
+ [["joe", 0], ["mark", null]],
+ [["cap", 8]],
+ []
+ ])";
+ CheckDropNull(map(utf8(), int32()), map_json, expected_json);
+}
+
+class TestDropNullKernelWithStruct : public TestDropNullKernelTyped<StructType> {};
+
+TEST_F(TestDropNullKernelWithStruct, DropNullStruct) {
+ auto struct_type = struct_({field("a", int32()), field("b", utf8())});
+ auto struct_json = R"([
+ null,
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])";
+ auto expected_struct_json = R"([
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])";
+ CheckDropNull(struct_type, struct_json, expected_struct_json);
+ this->TestNoValidityBitmapButUnknownNullCount(struct_type, expected_struct_json);
+}
+
+class TestDropNullKernelWithUnion : public TestDropNullKernelTyped<UnionType> {};
+
+TEST_F(TestDropNullKernelWithUnion, DropNullUnion) {
+ auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5});
+ auto union_json = R"([
+ [2, null],
+ [2, 222],
+ [5, "hello"],
+ [5, "eh"],
+ [2, null],
+ [2, 111],
+ [5, null]
+ ])";
+ CheckDropNull(union_type, union_json, union_json);
+}
+
+class TestDropNullKernelWithRecordBatch : public TestDropNullKernelTyped<RecordBatch> {
+ public:
+ void AssertDropNull(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
+ const std::string& expected_batch) {
+ std::shared_ptr<RecordBatch> actual;
+
+ ASSERT_OK(this->DoDropNull(schm, batch_json, &actual));
+ ValidateOutput(actual);
+ ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual);
+ }
+
+ Status DoDropNull(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
+ std::shared_ptr<RecordBatch>* out) {
+ auto batch = RecordBatchFromJSON(schm, batch_json);
+ ARROW_ASSIGN_OR_RAISE(Datum out_datum, DropNull(batch));
+ *out = out_datum.record_batch();
+ return Status::OK();
+ }
+};
+
+TEST_F(TestDropNullKernelWithRecordBatch, DropNullRecordBatch) {
+ std::vector<std::shared_ptr<Field>> fields = {field("a", int32()), field("b", utf8())};
+ auto schm = schema(fields);
+
+ auto batch_json = R"([
+ {"a": null, "b": "yo"},
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])";
+ this->AssertDropNull(schm, batch_json, R"([
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])");
+
+ batch_json = R"([
+ {"a": null, "b": "yo"},
+ {"a": 1, "b": null},
+ {"a": null, "b": "hello"},
+ {"a": 4, "b": null}
+ ])";
+ this->AssertDropNull(schm, batch_json, R"([])");
+ this->AssertDropNull(schm, R"([])", R"([])");
+}
+
+class TestDropNullKernelWithChunkedArray : public TestDropNullKernelTyped<ChunkedArray> {
+ public:
+ TestDropNullKernelWithChunkedArray()
+ : sizes_({0, 1, 2, 4, 16, 31, 1234}),
+ null_probabilities_({0.0, 0.1, 0.5, 0.9, 1.0}) {}
+
+ void AssertDropNull(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values,
+ const std::vector<std::string>& expected) {
+ std::shared_ptr<ChunkedArray> actual;
+ ASSERT_OK(this->DoDropNull(type, values, &actual));
+ ValidateOutput(actual);
+
+ AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual);
+ }
+
+ Status DoDropNull(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values,
+ std::shared_ptr<ChunkedArray>* out) {
+ ARROW_ASSIGN_OR_RAISE(Datum out_datum, DropNull(ChunkedArrayFromJSON(type, values)));
+ *out = out_datum.chunked_array();
+ return Status::OK();
+ }
+
+ template <typename ArrayFactory>
+ void CheckDropNullWithSlices(ArrayFactory&& factory) {
+ for (auto size : this->sizes_) {
+ for (auto null_probability : this->null_probabilities_) {
+ std::shared_ptr<Array> concatenated_array;
+ std::shared_ptr<ChunkedArray> chunked_array;
+ factory(size, null_probability, &chunked_array, &concatenated_array);
+
+ ASSERT_OK_AND_ASSIGN(auto out_datum, DropNull(chunked_array));
+ auto actual_chunked_array = out_datum.chunked_array();
+ ASSERT_OK_AND_ASSIGN(auto actual, Concatenate(actual_chunked_array->chunks()));
+
+ ASSERT_OK_AND_ASSIGN(out_datum, DropNull(*concatenated_array));
+ auto expected = out_datum.make_array();
+
+ AssertArraysEqual(*expected, *actual);
+ }
+ }
+ }
+
+ std::vector<int32_t> sizes_;
+ std::vector<double> null_probabilities_;
+};
+
+TEST_F(TestDropNullKernelWithChunkedArray, DropNullChunkedArray) {
+ this->AssertDropNull(int8(), {"[]"}, {"[]"});
+ this->AssertDropNull(int8(), {"[null]", "[8, null]"}, {"[8]"});
+
+ this->AssertDropNull(int8(), {"[null]", "[null, null]"}, {"[]"});
+ this->AssertDropNull(int8(), {"[7]", "[8, 9]"}, {"[7]", "[8, 9]"});
+ this->AssertDropNull(int8(), {"[]", "[]"}, {"[]", "[]"});
+}
+
+TEST_F(TestDropNullKernelWithChunkedArray, DropNullChunkedArrayWithSlices) {
+ // With Null Arrays
+ this->CheckDropNullWithSlices([this](int32_t size, double null_probability,
+ std::shared_ptr<ChunkedArray>* out_chunked_array,
+ std::shared_ptr<Array>* out_concatenated_array) {
+ auto array = std::make_shared<NullArray>(size);
+ auto offsets = this->Offsets(size, 3);
+ auto slices = this->Slices(array, offsets);
+ *out_chunked_array = std::make_shared<ChunkedArray>(std::move(slices));
+
+ ASSERT_OK_AND_ASSIGN(*out_concatenated_array,
+ Concatenate((*out_chunked_array)->chunks()));
+ });
+ // Without Null Arrays
+ this->CheckDropNullWithSlices([this](int32_t size, double null_probability,
+ std::shared_ptr<ChunkedArray>* out_chunked_array,
+ std::shared_ptr<Array>* out_concatenated_array) {
+ auto array = this->rng_.ArrayOf(int16(), size, null_probability);
+ auto offsets = this->Offsets(size, 3);
+ auto slices = this->Slices(array, offsets);
+ *out_chunked_array = std::make_shared<ChunkedArray>(std::move(slices));
+
+ ASSERT_OK_AND_ASSIGN(*out_concatenated_array,
+ Concatenate((*out_chunked_array)->chunks()));
+ });
+}
+
+class TestDropNullKernelWithTable : public TestDropNullKernelTyped<Table> {
+ public:
+ TestDropNullKernelWithTable()
+ : sizes_({0, 1, 4, 31, 1234}), null_probabilities_({0.0, 0.1, 0.5, 0.9, 1.0}) {}
+
+ void AssertDropNull(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& table_json,
+ const std::vector<std::string>& expected_table) {
+ std::shared_ptr<Table> actual;
+ ASSERT_OK(this->DoDropNull(schm, table_json, &actual));
+ ValidateOutput(actual);
+ ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual);
+ }
+
+ Status DoDropNull(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& values, std::shared_ptr<Table>* out) {
+ ARROW_ASSIGN_OR_RAISE(Datum out_datum, DropNull(TableFromJSON(schm, values)));
+ *out = out_datum.table();
+ return Status::OK();
+ }
+
+ template <typename ArrayFactory>
+ void CheckDropNullWithSlices(ArrayFactory&& factory) {
+ for (auto size : this->sizes_) {
+ for (auto null_probability : this->null_probabilities_) {
+ std::shared_ptr<Table> table_w_slices;
+ std::shared_ptr<Table> table_wo_slices;
+
+ factory(size, null_probability, &table_w_slices, &table_wo_slices);
+
+ ASSERT_OK_AND_ASSIGN(auto out_datum, DropNull(table_w_slices));
+ ValidateOutput(out_datum);
+ auto actual = out_datum.table();
+
+ ASSERT_OK_AND_ASSIGN(out_datum, DropNull(table_wo_slices));
+ ValidateOutput(out_datum);
+ auto expected = out_datum.table();
+ if (actual->num_rows() > 0) {
+ ASSERT_TRUE(actual->num_rows() == expected->num_rows());
+ for (int index = 0; index < actual->num_columns(); index++) {
+ ASSERT_OK_AND_ASSIGN(auto actual_col,
+ Concatenate(actual->column(index)->chunks()));
+ ASSERT_OK_AND_ASSIGN(auto expected_col,
+ Concatenate(expected->column(index)->chunks()));
+ AssertArraysEqual(*actual_col, *expected_col);
+ }
+ }
+ }
+ }
+ }
+
+ std::vector<int32_t> sizes_;
+ std::vector<double> null_probabilities_;
+};
+
+TEST_F(TestDropNullKernelWithTable, DropNullTable) {
+ std::vector<std::shared_ptr<Field>> fields = {field("a", int32()), field("b", utf8())};
+ auto schm = schema(fields);
+
+ {
+ std::vector<std::string> table_json = {R"([
+ {"a": null, "b": "yo"},
+ {"a": 1, "b": ""}
+ ])",
+ R"([
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])"};
+ std::vector<std::string> expected_table_json = {R"([
+ {"a": 1, "b": ""}
+ ])",
+ R"([
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])"};
+ this->AssertDropNull(schm, table_json, expected_table_json);
+ }
+ {
+ std::vector<std::string> table_json = {R"([
+ {"a": null, "b": "yo"},
+ {"a": 1, "b": null}
+ ])",
+ R"([
+ {"a": 2, "b": null},
+ {"a": null, "b": "eh"}
+ ])"};
+ std::shared_ptr<Table> actual;
+ ASSERT_OK(this->DoDropNull(schm, table_json, &actual));
+ AssertSchemaEqual(schm, actual->schema());
+ ASSERT_EQ(actual->num_rows(), 0);
+ }
+}
+
+TEST_F(TestDropNullKernelWithTable, DropNullTableWithSlices) {
+ // With Null Arrays
+ this->CheckDropNullWithSlices([this](int32_t size, double null_probability,
+ std::shared_ptr<Table>* out_table_w_slices,
+ std::shared_ptr<Table>* out_table_wo_slices) {
+ FieldVector fields = {field("a", int32()), field("b", utf8())};
+ auto schm = schema(fields);
+ ASSERT_OK_AND_ASSIGN(auto col_a, MakeArrayOfNull(int32(), size));
+ ASSERT_OK_AND_ASSIGN(auto col_b, MakeArrayOfNull(utf8(), size));
+
+ // Compute random chunkings of columns `a` and `b`
+ auto slices_a = this->Slices(col_a, this->Offsets(size, 3));
+ auto slices_b = this->Slices(col_b, this->Offsets(size, 3));
+
+ ChunkedArrayVector table_content_w_slices{
+ std::make_shared<ChunkedArray>(std::move(slices_a)),
+ std::make_shared<ChunkedArray>(std::move(slices_b))};
+ *out_table_w_slices = Table::Make(schm, std::move(table_content_w_slices), size);
+
+ ChunkedArrayVector table_content_wo_slices{std::make_shared<ChunkedArray>(col_a),
+ std::make_shared<ChunkedArray>(col_b)};
+ *out_table_wo_slices = Table::Make(schm, std::move(table_content_wo_slices), size);
+ });
+
+ // Without Null Arrays
+ this->CheckDropNullWithSlices([this](int32_t size, double null_probability,
+ std::shared_ptr<Table>* out_table_w_slices,
+ std::shared_ptr<Table>* out_table_wo_slices) {
+ FieldVector fields = {field("a", int32()), field("b", utf8())};
+ auto schm = schema(fields);
+ auto col_a = this->rng_.ArrayOf(int32(), size, null_probability);
+ auto col_b = this->rng_.ArrayOf(utf8(), size, null_probability);
+
+ // Compute random chunkings of columns `a` and `b`
+ auto slices_a = this->Slices(col_a, this->Offsets(size, 3));
+ auto slices_b = this->Slices(col_b, this->Offsets(size, 3));
+
+ ChunkedArrayVector table_content_w_slices{
+ std::make_shared<ChunkedArray>(std::move(slices_a)),
+ std::make_shared<ChunkedArray>(std::move(slices_b))};
+ *out_table_w_slices = Table::Make(schm, std::move(table_content_w_slices), size);
+
+ ChunkedArrayVector table_content_wo_slices{std::make_shared<ChunkedArray>(col_a),
+ std::make_shared<ChunkedArray>(col_b)};
+ *out_table_wo_slices = Table::Make(schm, std::move(table_content_wo_slices), size);
+ });
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_sort.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_sort.cc
new file mode 100644
index 000000000..e7178f71d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_sort.cc
@@ -0,0 +1,1902 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cmath>
+#include <iterator>
+#include <limits>
+#include <numeric>
+#include <queue>
+#include <type_traits>
+#include <unordered_set>
+#include <utility>
+
+#include "arrow/array/concatenate.h"
+#include "arrow/array/data.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/compute/kernels/vector_sort_internal.h"
+#include "arrow/table.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/optional.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+namespace internal {
+
+namespace {
+
+struct SortField {
+ int field_index;
+ SortOrder order;
+};
+
+// Return the field indices of the sort keys, deduplicating them along the way
+Result<std::vector<SortField>> FindSortKeys(const Schema& schema,
+ const std::vector<SortKey>& sort_keys) {
+ std::vector<SortField> fields;
+ std::unordered_set<int> seen;
+ fields.reserve(sort_keys.size());
+ seen.reserve(sort_keys.size());
+
+ for (const auto& sort_key : sort_keys) {
+ const auto r = schema.GetFieldIndex(sort_key.name);
+ if (r < 0) {
+ return Status::KeyError("Nonexistent sort key column: ", sort_key.name);
+ }
+ if (seen.insert(r).second) {
+ fields.push_back({r, sort_key.order});
+ }
+ }
+ return fields;
+}
+
+template <typename ResolvedSortKey, typename ResolvedSortKeyFactory>
+Result<std::vector<ResolvedSortKey>> ResolveSortKeys(
+ const Schema& schema, const std::vector<SortKey>& sort_keys,
+ ResolvedSortKeyFactory&& factory) {
+ ARROW_ASSIGN_OR_RAISE(const auto fields, FindSortKeys(schema, sort_keys));
+ std::vector<ResolvedSortKey> resolved;
+ resolved.reserve(fields.size());
+ std::transform(fields.begin(), fields.end(), std::back_inserter(resolved), factory);
+ return resolved;
+}
+
+template <typename ResolvedSortKey, typename TableOrBatch>
+Result<std::vector<ResolvedSortKey>> ResolveSortKeys(
+ const TableOrBatch& table_or_batch, const std::vector<SortKey>& sort_keys) {
+ return ResolveSortKeys<ResolvedSortKey>(
+ *table_or_batch.schema(), sort_keys, [&](const SortField& f) {
+ return ResolvedSortKey{table_or_batch.column(f.field_index), f.order};
+ });
+}
+
+// We could try to reproduce the concrete Array classes' facilities
+// (such as cached raw values pointer) in a separate hierarchy of
+// physical accessors, but doing so ends up too cumbersome.
+// Instead, we simply create the desired concrete Array objects.
+std::shared_ptr<Array> GetPhysicalArray(const Array& array,
+ const std::shared_ptr<DataType>& physical_type) {
+ auto new_data = array.data()->Copy();
+ new_data->type = physical_type;
+ return MakeArray(std::move(new_data));
+}
+
+ArrayVector GetPhysicalChunks(const ArrayVector& chunks,
+ const std::shared_ptr<DataType>& physical_type) {
+ ArrayVector physical(chunks.size());
+ std::transform(chunks.begin(), chunks.end(), physical.begin(),
+ [&](const std::shared_ptr<Array>& array) {
+ return GetPhysicalArray(*array, physical_type);
+ });
+ return physical;
+}
+
+ArrayVector GetPhysicalChunks(const ChunkedArray& chunked_array,
+ const std::shared_ptr<DataType>& physical_type) {
+ return GetPhysicalChunks(chunked_array.chunks(), physical_type);
+}
+
+Result<RecordBatchVector> BatchesFromTable(const Table& table) {
+ RecordBatchVector batches;
+ TableBatchReader reader(table);
+ RETURN_NOT_OK(reader.ReadAll(&batches));
+ return batches;
+}
+
+// ----------------------------------------------------------------------
+// ChunkedArray sorting implementation
+
+// Sort a chunked array by sorting each array in the chunked array,
+// then merging the sorted chunks recursively.
+class ChunkedArraySorter : public TypeVisitor {
+ public:
+ ChunkedArraySorter(ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end,
+ const ChunkedArray& chunked_array, const SortOrder order,
+ const NullPlacement null_placement)
+ : TypeVisitor(),
+ indices_begin_(indices_begin),
+ indices_end_(indices_end),
+ chunked_array_(chunked_array),
+ physical_type_(GetPhysicalType(chunked_array.type())),
+ physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)),
+ order_(order),
+ null_placement_(null_placement),
+ ctx_(ctx) {}
+
+ Status Sort() {
+ ARROW_ASSIGN_OR_RAISE(array_sorter_, GetArraySorter(*physical_type_));
+ return physical_type_->Accept(this);
+ }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) override { return SortInternal<TYPE>(); }
+
+ VISIT_SORTABLE_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ Status Visit(const NullType&) override {
+ std::iota(indices_begin_, indices_end_, 0);
+ return Status::OK();
+ }
+
+ private:
+ template <typename Type>
+ Status SortInternal() {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ ArraySortOptions options(order_, null_placement_);
+ const auto num_chunks = chunked_array_.num_chunks();
+ if (num_chunks == 0) {
+ return Status::OK();
+ }
+ const auto arrays = GetArrayPointers(physical_chunks_);
+
+ // Sort each chunk independently and merge to sorted indices.
+ // This is a serial implementation.
+ std::vector<NullPartitionResult> sorted(num_chunks);
+
+ // First sort all individual chunks
+ int64_t begin_offset = 0;
+ int64_t end_offset = 0;
+ int64_t null_count = 0;
+ for (int i = 0; i < num_chunks; ++i) {
+ const auto array = checked_cast<const ArrayType*>(arrays[i]);
+ end_offset += array->length();
+ null_count += array->null_count();
+ sorted[i] =
+ array_sorter_(indices_begin_ + begin_offset, indices_begin_ + end_offset,
+ *array, begin_offset, options);
+ begin_offset = end_offset;
+ }
+ DCHECK_EQ(end_offset, indices_end_ - indices_begin_);
+
+ // Then merge them by pairs, recursively
+ if (sorted.size() > 1) {
+ auto merge_nulls = [&](uint64_t* nulls_begin, uint64_t* nulls_middle,
+ uint64_t* nulls_end, uint64_t* temp_indices,
+ int64_t null_count) {
+ if (has_null_like_values<typename ArrayType::TypeClass>::value) {
+ PartitionNullsOnly<StablePartitioner>(nulls_begin, nulls_end,
+ ChunkedArrayResolver(arrays), null_count,
+ null_placement_);
+ }
+ };
+ auto merge_non_nulls = [&](uint64_t* range_begin, uint64_t* range_middle,
+ uint64_t* range_end, uint64_t* temp_indices) {
+ MergeNonNulls<ArrayType>(range_begin, range_middle, range_end, arrays,
+ temp_indices);
+ };
+
+ MergeImpl merge_impl{null_placement_, std::move(merge_nulls),
+ std::move(merge_non_nulls)};
+ // std::merge is only called on non-null values, so size temp indices accordingly
+ RETURN_NOT_OK(merge_impl.Init(ctx_, indices_end_ - indices_begin_ - null_count));
+
+ while (sorted.size() > 1) {
+ auto out_it = sorted.begin();
+ auto it = sorted.begin();
+ while (it < sorted.end() - 1) {
+ const auto& left = *it++;
+ const auto& right = *it++;
+ DCHECK_EQ(left.overall_end(), right.overall_begin());
+ const auto merged = merge_impl.Merge(left, right, null_count);
+ *out_it++ = merged;
+ }
+ if (it < sorted.end()) {
+ *out_it++ = *it++;
+ }
+ sorted.erase(out_it, sorted.end());
+ }
+ }
+
+ DCHECK_EQ(sorted.size(), 1);
+ DCHECK_EQ(sorted[0].overall_begin(), indices_begin_);
+ DCHECK_EQ(sorted[0].overall_end(), indices_end_);
+ // Note that "nulls" can also include NaNs, hence the >= check
+ DCHECK_GE(sorted[0].null_count(), null_count);
+
+ return Status::OK();
+ }
+
+ template <typename ArrayType>
+ void MergeNonNulls(uint64_t* range_begin, uint64_t* range_middle, uint64_t* range_end,
+ const std::vector<const Array*>& arrays, uint64_t* temp_indices) {
+ const ChunkedArrayResolver left_resolver(arrays);
+ const ChunkedArrayResolver right_resolver(arrays);
+
+ if (order_ == SortOrder::Ascending) {
+ std::merge(range_begin, range_middle, range_middle, range_end, temp_indices,
+ [&](uint64_t left, uint64_t right) {
+ const auto chunk_left = left_resolver.Resolve<ArrayType>(left);
+ const auto chunk_right = right_resolver.Resolve<ArrayType>(right);
+ return chunk_left.Value() < chunk_right.Value();
+ });
+ } else {
+ std::merge(range_begin, range_middle, range_middle, range_end, temp_indices,
+ [&](uint64_t left, uint64_t right) {
+ const auto chunk_left = left_resolver.Resolve<ArrayType>(left);
+ const auto chunk_right = right_resolver.Resolve<ArrayType>(right);
+ // We don't use 'left > right' here to reduce required
+ // operator. If we use 'right < left' here, '<' is only
+ // required.
+ return chunk_right.Value() < chunk_left.Value();
+ });
+ }
+ // Copy back temp area into main buffer
+ std::copy(temp_indices, temp_indices + (range_end - range_begin), range_begin);
+ }
+
+ uint64_t* indices_begin_;
+ uint64_t* indices_end_;
+ const ChunkedArray& chunked_array_;
+ const std::shared_ptr<DataType> physical_type_;
+ const ArrayVector physical_chunks_;
+ const SortOrder order_;
+ const NullPlacement null_placement_;
+ ArraySortFunc array_sorter_;
+ ExecContext* ctx_;
+};
+
+// ----------------------------------------------------------------------
+// Record batch sorting implementation(s)
+
+// Visit contiguous ranges of equal values. All entries are assumed
+// to be non-null.
+template <typename ArrayType, typename Visitor>
+void VisitConstantRanges(const ArrayType& array, uint64_t* indices_begin,
+ uint64_t* indices_end, int64_t offset, Visitor&& visit) {
+ using GetView = GetViewType<typename ArrayType::TypeClass>;
+
+ if (indices_begin == indices_end) {
+ return;
+ }
+ auto range_start = indices_begin;
+ auto range_cur = range_start;
+ auto last_value = GetView::LogicalValue(array.GetView(*range_cur - offset));
+ while (++range_cur != indices_end) {
+ auto v = GetView::LogicalValue(array.GetView(*range_cur - offset));
+ if (v != last_value) {
+ visit(range_start, range_cur);
+ range_start = range_cur;
+ last_value = v;
+ }
+ }
+ if (range_start != range_cur) {
+ visit(range_start, range_cur);
+ }
+}
+
+// A sorter for a single column of a RecordBatch, deferring to the next column
+// for ranges of equal values.
+class RecordBatchColumnSorter {
+ public:
+ explicit RecordBatchColumnSorter(RecordBatchColumnSorter* next_column = nullptr)
+ : next_column_(next_column) {}
+ virtual ~RecordBatchColumnSorter() {}
+
+ virtual NullPartitionResult SortRange(uint64_t* indices_begin, uint64_t* indices_end,
+ int64_t offset) = 0;
+
+ protected:
+ RecordBatchColumnSorter* next_column_;
+};
+
+template <typename Type>
+class ConcreteRecordBatchColumnSorter : public RecordBatchColumnSorter {
+ public:
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+
+ ConcreteRecordBatchColumnSorter(std::shared_ptr<Array> array, SortOrder order,
+ NullPlacement null_placement,
+ RecordBatchColumnSorter* next_column = nullptr)
+ : RecordBatchColumnSorter(next_column),
+ owned_array_(std::move(array)),
+ array_(checked_cast<const ArrayType&>(*owned_array_)),
+ order_(order),
+ null_placement_(null_placement),
+ null_count_(array_.null_count()) {}
+
+ NullPartitionResult SortRange(uint64_t* indices_begin, uint64_t* indices_end,
+ int64_t offset) override {
+ using GetView = GetViewType<Type>;
+
+ NullPartitionResult p;
+ if (null_count_ == 0) {
+ p = NullPartitionResult::NoNulls(indices_begin, indices_end, null_placement_);
+ } else {
+ // NOTE that null_count_ is merely an upper bound on the number of nulls
+ // in this particular range.
+ p = PartitionNullsOnly<StablePartitioner>(indices_begin, indices_end, array_,
+ offset, null_placement_);
+ DCHECK_LE(p.nulls_end - p.nulls_begin, null_count_);
+ }
+ const NullPartitionResult q = PartitionNullLikes<ArrayType, StablePartitioner>(
+ p.non_nulls_begin, p.non_nulls_end, array_, offset, null_placement_);
+
+ // TODO This is roughly the same as ArrayCompareSorter.
+ // Also, we would like to use a counting sort if possible. This requires
+ // a counting sort compatible with indirect indexing.
+ if (order_ == SortOrder::Ascending) {
+ std::stable_sort(
+ q.non_nulls_begin, q.non_nulls_end, [&](uint64_t left, uint64_t right) {
+ const auto lhs = GetView::LogicalValue(array_.GetView(left - offset));
+ const auto rhs = GetView::LogicalValue(array_.GetView(right - offset));
+ return lhs < rhs;
+ });
+ } else {
+ std::stable_sort(
+ q.non_nulls_begin, q.non_nulls_end, [&](uint64_t left, uint64_t right) {
+ // We don't use 'left > right' here to reduce required operator.
+ // If we use 'right < left' here, '<' is only required.
+ const auto lhs = GetView::LogicalValue(array_.GetView(left - offset));
+ const auto rhs = GetView::LogicalValue(array_.GetView(right - offset));
+ return lhs > rhs;
+ });
+ }
+
+ if (next_column_ != nullptr) {
+ // Visit all ranges of equal values in this column and sort them on
+ // the next column.
+ SortNextColumn(q.nulls_begin, q.nulls_end, offset);
+ SortNextColumn(p.nulls_begin, p.nulls_end, offset);
+ VisitConstantRanges(array_, q.non_nulls_begin, q.non_nulls_end, offset,
+ [&](uint64_t* range_start, uint64_t* range_end) {
+ SortNextColumn(range_start, range_end, offset);
+ });
+ }
+ return NullPartitionResult{q.non_nulls_begin, q.non_nulls_end,
+ std::min(q.nulls_begin, p.nulls_begin),
+ std::max(q.nulls_end, p.nulls_end)};
+ }
+
+ void SortNextColumn(uint64_t* indices_begin, uint64_t* indices_end, int64_t offset) {
+ // Avoid the cost of a virtual method call in trivial cases
+ if (indices_end - indices_begin > 1) {
+ next_column_->SortRange(indices_begin, indices_end, offset);
+ }
+ }
+
+ protected:
+ const std::shared_ptr<Array> owned_array_;
+ const ArrayType& array_;
+ const SortOrder order_;
+ const NullPlacement null_placement_;
+ const int64_t null_count_;
+};
+
+template <>
+class ConcreteRecordBatchColumnSorter<NullType> : public RecordBatchColumnSorter {
+ public:
+ ConcreteRecordBatchColumnSorter(std::shared_ptr<Array> array, SortOrder order,
+ NullPlacement null_placement,
+ RecordBatchColumnSorter* next_column = nullptr)
+ : RecordBatchColumnSorter(next_column), null_placement_(null_placement) {}
+
+ NullPartitionResult SortRange(uint64_t* indices_begin, uint64_t* indices_end,
+ int64_t offset) {
+ if (next_column_ != nullptr) {
+ next_column_->SortRange(indices_begin, indices_end, offset);
+ }
+ return NullPartitionResult::NullsOnly(indices_begin, indices_end, null_placement_);
+ }
+
+ protected:
+ const NullPlacement null_placement_;
+};
+
+// Sort a batch using a single-pass left-to-right radix sort.
+class RadixRecordBatchSorter {
+ public:
+ RadixRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end,
+ const RecordBatch& batch, const SortOptions& options)
+ : batch_(batch),
+ options_(options),
+ indices_begin_(indices_begin),
+ indices_end_(indices_end) {}
+
+ // Offset is for table sorting
+ Result<NullPartitionResult> Sort(int64_t offset = 0) {
+ ARROW_ASSIGN_OR_RAISE(const auto sort_keys,
+ ResolveSortKeys(batch_, options_.sort_keys));
+
+ // Create column sorters from right to left
+ std::vector<std::unique_ptr<RecordBatchColumnSorter>> column_sorts(sort_keys.size());
+ RecordBatchColumnSorter* next_column = nullptr;
+ for (int64_t i = static_cast<int64_t>(sort_keys.size() - 1); i >= 0; --i) {
+ ColumnSortFactory factory(sort_keys[i], options_, next_column);
+ ARROW_ASSIGN_OR_RAISE(column_sorts[i], factory.MakeColumnSort());
+ next_column = column_sorts[i].get();
+ }
+
+ // Sort from left to right
+ return column_sorts.front()->SortRange(indices_begin_, indices_end_, offset);
+ }
+
+ protected:
+ struct ResolvedSortKey {
+ std::shared_ptr<Array> array;
+ SortOrder order;
+ };
+
+ struct ColumnSortFactory {
+ ColumnSortFactory(const ResolvedSortKey& sort_key, const SortOptions& options,
+ RecordBatchColumnSorter* next_column)
+ : physical_type(GetPhysicalType(sort_key.array->type())),
+ array(GetPhysicalArray(*sort_key.array, physical_type)),
+ order(sort_key.order),
+ null_placement(options.null_placement),
+ next_column(next_column) {}
+
+ Result<std::unique_ptr<RecordBatchColumnSorter>> MakeColumnSort() {
+ RETURN_NOT_OK(VisitTypeInline(*physical_type, this));
+ DCHECK_NE(result, nullptr);
+ return std::move(result);
+ }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return VisitGeneric(type); }
+
+ VISIT_SORTABLE_PHYSICAL_TYPES(VISIT)
+ VISIT(NullType)
+
+#undef VISIT
+
+ Status Visit(const DataType& type) {
+ return Status::TypeError("Unsupported type for RecordBatch sorting: ",
+ type.ToString());
+ }
+
+ template <typename Type>
+ Status VisitGeneric(const Type&) {
+ result.reset(new ConcreteRecordBatchColumnSorter<Type>(array, order, null_placement,
+ next_column));
+ return Status::OK();
+ }
+
+ std::shared_ptr<DataType> physical_type;
+ std::shared_ptr<Array> array;
+ SortOrder order;
+ NullPlacement null_placement;
+ RecordBatchColumnSorter* next_column;
+ std::unique_ptr<RecordBatchColumnSorter> result;
+ };
+
+ static Result<std::vector<ResolvedSortKey>> ResolveSortKeys(
+ const RecordBatch& batch, const std::vector<SortKey>& sort_keys) {
+ return ::arrow::compute::internal::ResolveSortKeys<ResolvedSortKey>(batch, sort_keys);
+ }
+
+ const RecordBatch& batch_;
+ const SortOptions& options_;
+ uint64_t* indices_begin_;
+ uint64_t* indices_end_;
+};
+
+// Compare two records in a single column (either from a batch or table)
+template <typename ResolvedSortKey>
+struct ColumnComparator {
+ using Location = typename ResolvedSortKey::LocationType;
+
+ ColumnComparator(const ResolvedSortKey& sort_key, NullPlacement null_placement)
+ : sort_key_(sort_key), null_placement_(null_placement) {}
+
+ virtual ~ColumnComparator() = default;
+
+ virtual int Compare(const Location& left, const Location& right) const = 0;
+
+ ResolvedSortKey sort_key_;
+ NullPlacement null_placement_;
+};
+
+template <typename ResolvedSortKey, typename Type>
+struct ConcreteColumnComparator : public ColumnComparator<ResolvedSortKey> {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ using Location = typename ResolvedSortKey::LocationType;
+
+ using ColumnComparator<ResolvedSortKey>::ColumnComparator;
+
+ int Compare(const Location& left, const Location& right) const override {
+ const auto& sort_key = this->sort_key_;
+
+ const auto chunk_left = sort_key.template GetChunk<ArrayType>(left);
+ const auto chunk_right = sort_key.template GetChunk<ArrayType>(right);
+ if (sort_key.null_count > 0) {
+ const bool is_null_left = chunk_left.IsNull();
+ const bool is_null_right = chunk_right.IsNull();
+ if (is_null_left && is_null_right) {
+ return 0;
+ } else if (is_null_left) {
+ return this->null_placement_ == NullPlacement::AtStart ? -1 : 1;
+ } else if (is_null_right) {
+ return this->null_placement_ == NullPlacement::AtStart ? 1 : -1;
+ }
+ }
+ return CompareTypeValues<Type>(chunk_left.Value(), chunk_right.Value(),
+ sort_key.order, this->null_placement_);
+ }
+};
+
+template <typename ResolvedSortKey>
+struct ConcreteColumnComparator<ResolvedSortKey, NullType>
+ : public ColumnComparator<ResolvedSortKey> {
+ using Location = typename ResolvedSortKey::LocationType;
+
+ using ColumnComparator<ResolvedSortKey>::ColumnComparator;
+
+ int Compare(const Location& left, const Location& right) const override { return 0; }
+};
+
+// Compare two records in the same RecordBatch or Table
+// (indexing is handled through ResolvedSortKey)
+template <typename ResolvedSortKey>
+class MultipleKeyComparator {
+ public:
+ using Location = typename ResolvedSortKey::LocationType;
+
+ MultipleKeyComparator(const std::vector<ResolvedSortKey>& sort_keys,
+ NullPlacement null_placement)
+ : sort_keys_(sort_keys), null_placement_(null_placement) {
+ status_ &= MakeComparators();
+ }
+
+ Status status() const { return status_; }
+
+ // Returns true if the left-th value should be ordered before the
+ // right-th value, false otherwise. The start_sort_key_index-th
+ // sort key and subsequent sort keys are used for comparison.
+ bool Compare(const Location& left, const Location& right, size_t start_sort_key_index) {
+ return CompareInternal(left, right, start_sort_key_index) < 0;
+ }
+
+ bool Equals(const Location& left, const Location& right, size_t start_sort_key_index) {
+ return CompareInternal(left, right, start_sort_key_index) == 0;
+ }
+
+ private:
+ struct ColumnComparatorFactory {
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return VisitGeneric(type); }
+
+ VISIT_SORTABLE_PHYSICAL_TYPES(VISIT)
+ VISIT(NullType)
+
+#undef VISIT
+
+ Status Visit(const DataType& type) {
+ return Status::TypeError("Unsupported type for batch or table sorting: ",
+ type.ToString());
+ }
+
+ template <typename Type>
+ Status VisitGeneric(const Type& type) {
+ res.reset(
+ new ConcreteColumnComparator<ResolvedSortKey, Type>{sort_key, null_placement});
+ return Status::OK();
+ }
+
+ const ResolvedSortKey& sort_key;
+ NullPlacement null_placement;
+ std::unique_ptr<ColumnComparator<ResolvedSortKey>> res;
+ };
+
+ Status MakeComparators() {
+ column_comparators_.reserve(sort_keys_.size());
+
+ for (const auto& sort_key : sort_keys_) {
+ ColumnComparatorFactory factory{sort_key, null_placement_, nullptr};
+ RETURN_NOT_OK(VisitTypeInline(*sort_key.type, &factory));
+ column_comparators_.push_back(std::move(factory.res));
+ }
+ return Status::OK();
+ }
+
+ // Compare two records in the same table and return -1, 0 or 1.
+ //
+ // -1: The left is less than the right.
+ // 0: The left equals to the right.
+ // 1: The left is greater than the right.
+ //
+ // This supports null and NaN. Null is processed in this and NaN
+ // is processed in CompareTypeValue().
+ int CompareInternal(const Location& left, const Location& right,
+ size_t start_sort_key_index) {
+ const auto num_sort_keys = sort_keys_.size();
+ for (size_t i = start_sort_key_index; i < num_sort_keys; ++i) {
+ const int r = column_comparators_[i]->Compare(left, right);
+ if (r != 0) {
+ return r;
+ }
+ }
+ return 0;
+ }
+
+ const std::vector<ResolvedSortKey>& sort_keys_;
+ const NullPlacement null_placement_;
+ std::vector<std::unique_ptr<ColumnComparator<ResolvedSortKey>>> column_comparators_;
+ Status status_;
+};
+
+// Sort a batch using a single sort and multiple-key comparisons.
+class MultipleKeyRecordBatchSorter : public TypeVisitor {
+ public:
+ // Preprocessed sort key.
+ struct ResolvedSortKey {
+ ResolvedSortKey(const std::shared_ptr<Array>& array, SortOrder order)
+ : type(GetPhysicalType(array->type())),
+ owned_array(GetPhysicalArray(*array, type)),
+ array(*owned_array),
+ order(order),
+ null_count(array->null_count()) {}
+
+ using LocationType = int64_t;
+
+ template <typename ArrayType>
+ ResolvedChunk<ArrayType> GetChunk(int64_t index) const {
+ return {&checked_cast<const ArrayType&>(array), index};
+ }
+
+ const std::shared_ptr<DataType> type;
+ std::shared_ptr<Array> owned_array;
+ const Array& array;
+ SortOrder order;
+ int64_t null_count;
+ };
+
+ private:
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+ MultipleKeyRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end,
+ const RecordBatch& batch, const SortOptions& options)
+ : indices_begin_(indices_begin),
+ indices_end_(indices_end),
+ sort_keys_(ResolveSortKeys(batch, options.sort_keys, &status_)),
+ null_placement_(options.null_placement),
+ comparator_(sort_keys_, null_placement_) {}
+
+ // This is optimized for the first sort key. The first sort key sort
+ // is processed in this class. The second and following sort keys
+ // are processed in Comparator.
+ Status Sort() {
+ RETURN_NOT_OK(status_);
+ return sort_keys_[0].type->Accept(this);
+ }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) override { return SortInternal<TYPE>(); }
+
+ VISIT_SORTABLE_PHYSICAL_TYPES(VISIT)
+ VISIT(NullType)
+
+#undef VISIT
+
+ private:
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const RecordBatch& batch, const std::vector<SortKey>& sort_keys, Status* status) {
+ const auto maybe_resolved =
+ ::arrow::compute::internal::ResolveSortKeys<ResolvedSortKey>(batch, sort_keys);
+ if (!maybe_resolved.ok()) {
+ *status = maybe_resolved.status();
+ return {};
+ }
+ return *std::move(maybe_resolved);
+ }
+
+ template <typename Type>
+ enable_if_t<!is_null_type<Type>::value, Status> SortInternal() {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ using GetView = GetViewType<Type>;
+
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+ const ArrayType& array = checked_cast<const ArrayType&>(first_sort_key.array);
+ const auto p = PartitionNullsInternal<Type>(first_sort_key);
+
+ // Sort first-key non-nulls
+ std::stable_sort(
+ p.non_nulls_begin, p.non_nulls_end, [&](uint64_t left, uint64_t right) {
+ // Both values are never null nor NaN
+ // (otherwise they've been partitioned away above).
+ const auto value_left = GetView::LogicalValue(array.GetView(left));
+ const auto value_right = GetView::LogicalValue(array.GetView(right));
+ if (value_left != value_right) {
+ bool compared = value_left < value_right;
+ if (first_sort_key.order == SortOrder::Ascending) {
+ return compared;
+ } else {
+ return !compared;
+ }
+ }
+ // If the left value equals to the right value,
+ // we need to compare the second and following
+ // sort keys.
+ return comparator.Compare(left, right, 1);
+ });
+ return comparator_.status();
+ }
+
+ template <typename Type>
+ enable_if_null<Type, Status> SortInternal() {
+ std::stable_sort(indices_begin_, indices_end_, [&](uint64_t left, uint64_t right) {
+ return comparator_.Compare(left, right, 1);
+ });
+ return comparator_.status();
+ }
+
+ // Behaves like PartitionNulls() but this supports multiple sort keys.
+ template <typename Type>
+ NullPartitionResult PartitionNullsInternal(const ResolvedSortKey& first_sort_key) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ const ArrayType& array = checked_cast<const ArrayType&>(first_sort_key.array);
+
+ const auto p = PartitionNullsOnly<StablePartitioner>(indices_begin_, indices_end_,
+ array, 0, null_placement_);
+ const auto q = PartitionNullLikes<ArrayType, StablePartitioner>(
+ p.non_nulls_begin, p.non_nulls_end, array, 0, null_placement_);
+
+ auto& comparator = comparator_;
+ if (q.nulls_begin != q.nulls_end) {
+ // Sort all NaNs by the second and following sort keys.
+ // TODO: could we instead run an independent sort from the second key on
+ // this slice?
+ std::stable_sort(q.nulls_begin, q.nulls_end,
+ [&comparator](uint64_t left, uint64_t right) {
+ return comparator.Compare(left, right, 1);
+ });
+ }
+ if (p.nulls_begin != p.nulls_end) {
+ // Sort all nulls by the second and following sort keys.
+ // TODO: could we instead run an independent sort from the second key on
+ // this slice?
+ std::stable_sort(p.nulls_begin, p.nulls_end,
+ [&comparator](uint64_t left, uint64_t right) {
+ return comparator.Compare(left, right, 1);
+ });
+ }
+ return q;
+ }
+
+ uint64_t* indices_begin_;
+ uint64_t* indices_end_;
+ Status status_;
+ std::vector<ResolvedSortKey> sort_keys_;
+ NullPlacement null_placement_;
+ Comparator comparator_;
+};
+
+// ----------------------------------------------------------------------
+// Table sorting implementation(s)
+
+// Sort a table using an explicit merge sort.
+// Each batch is first sorted individually (taking advantage of the fact
+// that batch columns are contiguous and therefore have less indexing
+// overhead), then sorted batches are merged recursively.
+class TableSorter {
+ public:
+ // Preprocessed sort key.
+ struct ResolvedSortKey {
+ ResolvedSortKey(const std::shared_ptr<DataType>& type, ArrayVector chunks,
+ SortOrder order, int64_t null_count)
+ : type(GetPhysicalType(type)),
+ owned_chunks(std::move(chunks)),
+ chunks(GetArrayPointers(owned_chunks)),
+ order(order),
+ null_count(null_count) {}
+
+ using LocationType = ChunkLocation;
+
+ template <typename ArrayType>
+ ResolvedChunk<ArrayType> GetChunk(ChunkLocation loc) const {
+ return {checked_cast<const ArrayType*>(chunks[loc.chunk_index]),
+ loc.index_in_chunk};
+ }
+
+ // Make a vector of ResolvedSortKeys for the sort keys and the given table.
+ // `batches` must be a chunking of `table`.
+ static Result<std::vector<ResolvedSortKey>> Make(
+ const Table& table, const RecordBatchVector& batches,
+ const std::vector<SortKey>& sort_keys) {
+ auto factory = [&](const SortField& f) {
+ const auto& type = table.schema()->field(f.field_index)->type();
+ // We must expose a homogenous chunking for all ResolvedSortKey,
+ // so we can't simply pass `table.column(f.field_index)`
+ ArrayVector chunks(batches.size());
+ std::transform(batches.begin(), batches.end(), chunks.begin(),
+ [&](const std::shared_ptr<RecordBatch>& batch) {
+ return batch->column(f.field_index);
+ });
+ return ResolvedSortKey(type, std::move(chunks), f.order,
+ table.column(f.field_index)->null_count());
+ };
+
+ return ::arrow::compute::internal::ResolveSortKeys<ResolvedSortKey>(
+ *table.schema(), sort_keys, factory);
+ }
+
+ std::shared_ptr<DataType> type;
+ ArrayVector owned_chunks;
+ std::vector<const Array*> chunks;
+ SortOrder order;
+ int64_t null_count;
+ };
+
+ // TODO make all methods const and defer initialization into a Init() method?
+
+ TableSorter(ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end,
+ const Table& table, const SortOptions& options)
+ : ctx_(ctx),
+ table_(table),
+ batches_(MakeBatches(table, &status_)),
+ options_(options),
+ null_placement_(options.null_placement),
+ left_resolver_(ChunkResolver::FromBatches(batches_)),
+ right_resolver_(ChunkResolver::FromBatches(batches_)),
+ sort_keys_(ResolveSortKeys(table, batches_, options.sort_keys, &status_)),
+ indices_begin_(indices_begin),
+ indices_end_(indices_end),
+ comparator_(sort_keys_, null_placement_) {}
+
+ // This is optimized for null partitioning and merging along the first sort key.
+ // Other sort keys are delegated to the Comparator class.
+ Status Sort() {
+ ARROW_RETURN_NOT_OK(status_);
+ return SortInternal();
+ }
+
+ private:
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ static RecordBatchVector MakeBatches(const Table& table, Status* status) {
+ const auto maybe_batches = BatchesFromTable(table);
+ if (!maybe_batches.ok()) {
+ *status = maybe_batches.status();
+ return {};
+ }
+ return *std::move(maybe_batches);
+ }
+
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const Table& table, const RecordBatchVector& batches,
+ const std::vector<SortKey>& sort_keys, Status* status) {
+ const auto maybe_resolved = ResolvedSortKey::Make(table, batches, sort_keys);
+ if (!maybe_resolved.ok()) {
+ *status = maybe_resolved.status();
+ return {};
+ }
+ return *std::move(maybe_resolved);
+ }
+
+ Status SortInternal() {
+ // Sort each batch independently and merge to sorted indices.
+ RecordBatchVector batches;
+ {
+ TableBatchReader reader(table_);
+ RETURN_NOT_OK(reader.ReadAll(&batches));
+ }
+ const int64_t num_batches = static_cast<int64_t>(batches.size());
+ if (num_batches == 0) {
+ return Status::OK();
+ }
+ std::vector<NullPartitionResult> sorted(num_batches);
+
+ // First sort all individual batches
+ int64_t begin_offset = 0;
+ int64_t end_offset = 0;
+ int64_t null_count = 0;
+ for (int64_t i = 0; i < num_batches; ++i) {
+ const auto& batch = *batches[i];
+ end_offset += batch.num_rows();
+ RadixRecordBatchSorter sorter(indices_begin_ + begin_offset,
+ indices_begin_ + end_offset, batch, options_);
+ ARROW_ASSIGN_OR_RAISE(sorted[i], sorter.Sort(begin_offset));
+ DCHECK_EQ(sorted[i].overall_begin(), indices_begin_ + begin_offset);
+ DCHECK_EQ(sorted[i].overall_end(), indices_begin_ + end_offset);
+ DCHECK_EQ(sorted[i].non_null_count() + sorted[i].null_count(), batch.num_rows());
+ begin_offset = end_offset;
+ // XXX this is an upper bound on the true null count
+ null_count += sorted[i].null_count();
+ }
+ DCHECK_EQ(end_offset, indices_end_ - indices_begin_);
+
+ // Then merge them by pairs, recursively
+ if (sorted.size() > 1) {
+ struct Visitor {
+ TableSorter* sorter;
+ std::vector<NullPartitionResult>* sorted;
+ int64_t null_count;
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ return sorter->MergeInternal<TYPE>(std::move(*sorted), null_count); \
+ }
+
+ VISIT_SORTABLE_PHYSICAL_TYPES(VISIT)
+ VISIT(NullType)
+#undef VISIT
+
+ Status Visit(const DataType& type) {
+ return Status::NotImplemented("Unsupported type for sorting: ",
+ type.ToString());
+ }
+ };
+ Visitor visitor{this, &sorted, null_count};
+ RETURN_NOT_OK(VisitTypeInline(*sort_keys_[0].type, &visitor));
+ }
+ return Status::OK();
+ }
+
+ // Recursive merge routine, typed on the first sort key
+ template <typename Type>
+ Status MergeInternal(std::vector<NullPartitionResult> sorted, int64_t null_count) {
+ auto merge_nulls = [&](uint64_t* nulls_begin, uint64_t* nulls_middle,
+ uint64_t* nulls_end, uint64_t* temp_indices,
+ int64_t null_count) {
+ MergeNulls<Type>(nulls_begin, nulls_middle, nulls_end, temp_indices, null_count);
+ };
+ auto merge_non_nulls = [&](uint64_t* range_begin, uint64_t* range_middle,
+ uint64_t* range_end, uint64_t* temp_indices) {
+ MergeNonNulls<Type>(range_begin, range_middle, range_end, temp_indices);
+ };
+
+ MergeImpl merge_impl(options_.null_placement, std::move(merge_nulls),
+ std::move(merge_non_nulls));
+ RETURN_NOT_OK(merge_impl.Init(ctx_, table_.num_rows()));
+
+ while (sorted.size() > 1) {
+ auto out_it = sorted.begin();
+ auto it = sorted.begin();
+ while (it < sorted.end() - 1) {
+ const auto& left = *it++;
+ const auto& right = *it++;
+ DCHECK_EQ(left.overall_end(), right.overall_begin());
+ *out_it++ = merge_impl.Merge(left, right, null_count);
+ }
+ if (it < sorted.end()) {
+ *out_it++ = *it++;
+ }
+ sorted.erase(out_it, sorted.end());
+ }
+ DCHECK_EQ(sorted.size(), 1);
+ DCHECK_EQ(sorted[0].overall_begin(), indices_begin_);
+ DCHECK_EQ(sorted[0].overall_end(), indices_end_);
+ return comparator_.status();
+ }
+
+ // Merge rows with a null or a null-like in the first sort key
+ template <typename Type>
+ enable_if_t<has_null_like_values<Type>::value> MergeNulls(uint64_t* nulls_begin,
+ uint64_t* nulls_middle,
+ uint64_t* nulls_end,
+ uint64_t* temp_indices,
+ int64_t null_count) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+
+ std::merge(nulls_begin, nulls_middle, nulls_middle, nulls_end, temp_indices,
+ [&](uint64_t left, uint64_t right) {
+ // First column is either null or nan
+ const auto left_loc = left_resolver_.Resolve(left);
+ const auto right_loc = right_resolver_.Resolve(right);
+ auto chunk_left = first_sort_key.GetChunk<ArrayType>(left_loc);
+ auto chunk_right = first_sort_key.GetChunk<ArrayType>(right_loc);
+ const auto left_is_null = chunk_left.IsNull();
+ const auto right_is_null = chunk_right.IsNull();
+ if (left_is_null == right_is_null) {
+ return comparator.Compare(left_loc, right_loc, 1);
+ } else if (options_.null_placement == NullPlacement::AtEnd) {
+ return right_is_null;
+ } else {
+ return left_is_null;
+ }
+ });
+ // Copy back temp area into main buffer
+ std::copy(temp_indices, temp_indices + (nulls_end - nulls_begin), nulls_begin);
+ }
+
+ template <typename Type>
+ enable_if_t<!has_null_like_values<Type>::value> MergeNulls(uint64_t* nulls_begin,
+ uint64_t* nulls_middle,
+ uint64_t* nulls_end,
+ uint64_t* temp_indices,
+ int64_t null_count) {
+ MergeNullsOnly(nulls_begin, nulls_middle, nulls_end, temp_indices, null_count);
+ }
+
+ void MergeNullsOnly(uint64_t* nulls_begin, uint64_t* nulls_middle, uint64_t* nulls_end,
+ uint64_t* temp_indices, int64_t null_count) {
+ // Untyped implementation
+ auto& comparator = comparator_;
+
+ std::merge(nulls_begin, nulls_middle, nulls_middle, nulls_end, temp_indices,
+ [&](uint64_t left, uint64_t right) {
+ // First column is always null
+ const auto left_loc = left_resolver_.Resolve(left);
+ const auto right_loc = right_resolver_.Resolve(right);
+ return comparator.Compare(left_loc, right_loc, 1);
+ });
+ // Copy back temp area into main buffer
+ std::copy(temp_indices, temp_indices + (nulls_end - nulls_begin), nulls_begin);
+ }
+
+ //
+ // Merge rows with a non-null in the first sort key
+ //
+ template <typename Type>
+ enable_if_t<!is_null_type<Type>::value> MergeNonNulls(uint64_t* range_begin,
+ uint64_t* range_middle,
+ uint64_t* range_end,
+ uint64_t* temp_indices) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+
+ std::merge(range_begin, range_middle, range_middle, range_end, temp_indices,
+ [&](uint64_t left, uint64_t right) {
+ // Both values are never null nor NaN.
+ const auto left_loc = left_resolver_.Resolve(left);
+ const auto right_loc = right_resolver_.Resolve(right);
+ auto chunk_left = first_sort_key.GetChunk<ArrayType>(left_loc);
+ auto chunk_right = first_sort_key.GetChunk<ArrayType>(right_loc);
+ DCHECK(!chunk_left.IsNull());
+ DCHECK(!chunk_right.IsNull());
+ auto value_left = chunk_left.Value();
+ auto value_right = chunk_right.Value();
+ if (value_left == value_right) {
+ // If the left value equals to the right value,
+ // we need to compare the second and following
+ // sort keys.
+ return comparator.Compare(left_loc, right_loc, 1);
+ } else {
+ auto compared = value_left < value_right;
+ if (first_sort_key.order == SortOrder::Ascending) {
+ return compared;
+ } else {
+ return !compared;
+ }
+ }
+ });
+ // Copy back temp area into main buffer
+ std::copy(temp_indices, temp_indices + (range_end - range_begin), range_begin);
+ }
+
+ template <typename Type>
+ enable_if_null<Type> MergeNonNulls(uint64_t* range_begin, uint64_t* range_middle,
+ uint64_t* range_end, uint64_t* temp_indices) {
+ const int64_t null_count = range_end - range_begin;
+ MergeNullsOnly(range_begin, range_middle, range_end, temp_indices, null_count);
+ }
+
+ ExecContext* ctx_;
+ const Table& table_;
+ const RecordBatchVector batches_;
+ const SortOptions& options_;
+ const NullPlacement null_placement_;
+ const ChunkResolver left_resolver_, right_resolver_;
+ const std::vector<ResolvedSortKey> sort_keys_;
+ uint64_t* indices_begin_;
+ uint64_t* indices_end_;
+ Comparator comparator_;
+ Status status_;
+};
+
+// ----------------------------------------------------------------------
+// Top-level sort functions
+
+const auto kDefaultSortOptions = SortOptions::Defaults();
+
+const FunctionDoc sort_indices_doc(
+ "Return the indices that would sort an array, record batch or table",
+ ("This function computes an array of indices that define a stable sort\n"
+ "of the input array, record batch or table. By default, nNull values are\n"
+ "considered greater than any other value and are therefore sorted at the\n"
+ "end of the input. For floating-point types, NaNs are considered greater\n"
+ "than any other non-null value, but smaller than null values.\n"
+ "\n"
+ "The handling of nulls and NaNs can be changed in SortOptions."),
+ {"input"}, "SortOptions");
+
+class SortIndicesMetaFunction : public MetaFunction {
+ public:
+ SortIndicesMetaFunction()
+ : MetaFunction("sort_indices", Arity::Unary(), &sort_indices_doc,
+ &kDefaultSortOptions) {}
+
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options,
+ ExecContext* ctx) const override {
+ const SortOptions& sort_options = static_cast<const SortOptions&>(*options);
+ switch (args[0].kind()) {
+ case Datum::ARRAY:
+ return SortIndices(*args[0].make_array(), sort_options, ctx);
+ break;
+ case Datum::CHUNKED_ARRAY:
+ return SortIndices(*args[0].chunked_array(), sort_options, ctx);
+ break;
+ case Datum::RECORD_BATCH: {
+ return SortIndices(*args[0].record_batch(), sort_options, ctx);
+ } break;
+ case Datum::TABLE:
+ return SortIndices(*args[0].table(), sort_options, ctx);
+ break;
+ default:
+ break;
+ }
+ return Status::NotImplemented(
+ "Unsupported types for sort_indices operation: "
+ "values=",
+ args[0].ToString());
+ }
+
+ private:
+ Result<Datum> SortIndices(const Array& values, const SortOptions& options,
+ ExecContext* ctx) const {
+ SortOrder order = SortOrder::Ascending;
+ if (!options.sort_keys.empty()) {
+ order = options.sort_keys[0].order;
+ }
+ ArraySortOptions array_options(order, options.null_placement);
+ return CallFunction("array_sort_indices", {values}, &array_options, ctx);
+ }
+
+ Result<Datum> SortIndices(const ChunkedArray& chunked_array, const SortOptions& options,
+ ExecContext* ctx) const {
+ SortOrder order = SortOrder::Ascending;
+ if (!options.sort_keys.empty()) {
+ order = options.sort_keys[0].order;
+ }
+
+ auto out_type = uint64();
+ auto length = chunked_array.length();
+ auto buffer_size = BitUtil::BytesForBits(
+ length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width());
+ std::vector<std::shared_ptr<Buffer>> buffers(2);
+ ARROW_ASSIGN_OR_RAISE(buffers[1],
+ AllocateResizableBuffer(buffer_size, ctx->memory_pool()));
+ auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0);
+ auto out_begin = out->GetMutableValues<uint64_t>(1);
+ auto out_end = out_begin + length;
+ std::iota(out_begin, out_end, 0);
+
+ ChunkedArraySorter sorter(ctx, out_begin, out_end, chunked_array, order,
+ options.null_placement);
+ ARROW_RETURN_NOT_OK(sorter.Sort());
+ return Datum(out);
+ }
+
+ Result<Datum> SortIndices(const RecordBatch& batch, const SortOptions& options,
+ ExecContext* ctx) const {
+ auto n_sort_keys = options.sort_keys.size();
+ if (n_sort_keys == 0) {
+ return Status::Invalid("Must specify one or more sort keys");
+ }
+ if (n_sort_keys == 1) {
+ auto array = batch.GetColumnByName(options.sort_keys[0].name);
+ if (!array) {
+ return Status::Invalid("Nonexistent sort key column: ",
+ options.sort_keys[0].name);
+ }
+ return SortIndices(*array, options, ctx);
+ }
+
+ auto out_type = uint64();
+ auto length = batch.num_rows();
+ auto buffer_size = BitUtil::BytesForBits(
+ length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width());
+ BufferVector buffers(2);
+ ARROW_ASSIGN_OR_RAISE(buffers[1],
+ AllocateResizableBuffer(buffer_size, ctx->memory_pool()));
+ auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0);
+ auto out_begin = out->GetMutableValues<uint64_t>(1);
+ auto out_end = out_begin + length;
+ std::iota(out_begin, out_end, 0);
+
+ // Radix sorting is consistently faster except when there is a large number
+ // of sort keys, in which case it can end up degrading catastrophically.
+ // Cut off above 8 sort keys.
+ if (n_sort_keys <= 8) {
+ RadixRecordBatchSorter sorter(out_begin, out_end, batch, options);
+ ARROW_RETURN_NOT_OK(sorter.Sort());
+ } else {
+ MultipleKeyRecordBatchSorter sorter(out_begin, out_end, batch, options);
+ ARROW_RETURN_NOT_OK(sorter.Sort());
+ }
+ return Datum(out);
+ }
+
+ Result<Datum> SortIndices(const Table& table, const SortOptions& options,
+ ExecContext* ctx) const {
+ auto n_sort_keys = options.sort_keys.size();
+ if (n_sort_keys == 0) {
+ return Status::Invalid("Must specify one or more sort keys");
+ }
+ if (n_sort_keys == 1) {
+ auto chunked_array = table.GetColumnByName(options.sort_keys[0].name);
+ if (!chunked_array) {
+ return Status::Invalid("Nonexistent sort key column: ",
+ options.sort_keys[0].name);
+ }
+ return SortIndices(*chunked_array, options, ctx);
+ }
+
+ auto out_type = uint64();
+ auto length = table.num_rows();
+ auto buffer_size = BitUtil::BytesForBits(
+ length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width());
+ std::vector<std::shared_ptr<Buffer>> buffers(2);
+ ARROW_ASSIGN_OR_RAISE(buffers[1],
+ AllocateResizableBuffer(buffer_size, ctx->memory_pool()));
+ auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0);
+ auto out_begin = out->GetMutableValues<uint64_t>(1);
+ auto out_end = out_begin + length;
+ std::iota(out_begin, out_end, 0);
+
+ TableSorter sorter(ctx, out_begin, out_end, table, options);
+ RETURN_NOT_OK(sorter.Sort());
+
+ return Datum(out);
+ }
+};
+
+// ----------------------------------------------------------------------
+// TopK/BottomK implementations
+
+const auto kDefaultSelectKOptions = SelectKOptions::Defaults();
+
+const FunctionDoc select_k_unstable_doc(
+ "Selects the indices of the first `k` ordered elements from the input",
+ ("This function selects an array of indices of the first `k` ordered elements from\n"
+ "the input array, record batch or table specified in the column keys\n"
+ "(`options.sort_keys`). Output is not guaranteed to be stable.\n"
+ "The columns that are not specified are returned as well, but not used for\n"
+ "ordering. Null values are considered greater than any other value and are\n"
+ "therefore sorted at the end of the array. For floating-point types, ordering of\n"
+ "values is such that: Null > NaN > Inf > number."),
+ {"input"}, "SelectKOptions");
+
+Result<std::shared_ptr<ArrayData>> MakeMutableUInt64Array(
+ std::shared_ptr<DataType> out_type, int64_t length, MemoryPool* memory_pool) {
+ auto buffer_size = length * sizeof(uint64_t);
+ ARROW_ASSIGN_OR_RAISE(auto data, AllocateBuffer(buffer_size, memory_pool));
+ return ArrayData::Make(uint64(), length, {nullptr, std::move(data)}, /*null_count=*/0);
+}
+
+template <SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval);
+};
+
+template <>
+class SelectKComparator<SortOrder::Ascending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return lval < rval;
+ }
+};
+
+template <>
+class SelectKComparator<SortOrder::Descending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return rval < lval;
+ }
+};
+
+class ArraySelecter : public TypeVisitor {
+ public:
+ ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions& options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ array_(array),
+ k_(options.k),
+ order_(options.sort_keys[0].order),
+ physical_type_(GetPhysicalType(array.type())),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (order_ == SortOrder::Ascending) { \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ } \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ }
+
+ VISIT_SORTABLE_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ template <typename InType, SortOrder sort_order>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+
+ ArrayType arr(array_.data());
+ std::vector<uint64_t> indices(arr.length());
+
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+ if (k_ > arr.length()) {
+ k_ = arr.length();
+ }
+
+ const auto p = PartitionNulls<ArrayType, NonStablePartitioner>(
+ indices_begin, indices_end, arr, 0, NullPlacement::AtEnd);
+ const auto end_iter = p.non_nulls_end;
+
+ auto kth_begin = std::min(indices_begin + k_, end_iter);
+
+ SelectKComparator<sort_order> comparator;
+ auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ return comparator(lval, rval);
+ };
+ using HeapContainer =
+ std::priority_queue<uint64_t, std::vector<uint64_t>, decltype(cmp)>;
+ HeapContainer heap(indices_begin, kth_begin, cmp);
+ for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ if (cmp(x_index, heap.top())) {
+ heap.pop();
+ heap.push(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(uint64(), out_size,
+ ctx_->memory_pool()));
+
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size - 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const Array& array_;
+ int64_t k_;
+ SortOrder order_;
+ const std::shared_ptr<DataType> physical_type_;
+ Datum* output_;
+};
+
+template <typename ArrayType>
+struct TypedHeapItem {
+ uint64_t index;
+ uint64_t offset;
+ ArrayType* array;
+};
+
+class ChunkedArraySelecter : public TypeVisitor {
+ public:
+ ChunkedArraySelecter(ExecContext* ctx, const ChunkedArray& chunked_array,
+ const SelectKOptions& options, Datum* output)
+ : TypeVisitor(),
+ chunked_array_(chunked_array),
+ physical_type_(GetPhysicalType(chunked_array.type())),
+ physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)),
+ k_(options.k),
+ order_(options.sort_keys[0].order),
+ ctx_(ctx),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (order_ == SortOrder::Ascending) { \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ } \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ }
+
+ VISIT_SORTABLE_PHYSICAL_TYPES(VISIT)
+#undef VISIT
+
+ template <typename InType, SortOrder sort_order>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ using HeapItem = TypedHeapItem<ArrayType>;
+
+ const auto num_chunks = chunked_array_.num_chunks();
+ if (num_chunks == 0) {
+ return Status::OK();
+ }
+ if (k_ > chunked_array_.length()) {
+ k_ = chunked_array_.length();
+ }
+ std::function<bool(const HeapItem&, const HeapItem&)> cmp;
+ SelectKComparator<sort_order> comparator;
+
+ cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool {
+ const auto lval = GetView::LogicalValue(left.array->GetView(left.index));
+ const auto rval = GetView::LogicalValue(right.array->GetView(right.index));
+ return comparator(lval, rval);
+ };
+ using HeapContainer =
+ std::priority_queue<HeapItem, std::vector<HeapItem>, decltype(cmp)>;
+
+ HeapContainer heap(cmp);
+ std::vector<std::shared_ptr<ArrayType>> chunks_holder;
+ uint64_t offset = 0;
+ for (const auto& chunk : physical_chunks_) {
+ if (chunk->length() == 0) continue;
+ chunks_holder.emplace_back(std::make_shared<ArrayType>(chunk->data()));
+ ArrayType& arr = *chunks_holder[chunks_holder.size() - 1];
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ const auto p = PartitionNulls<ArrayType, NonStablePartitioner>(
+ indices_begin, indices_end, arr, 0, NullPlacement::AtEnd);
+ const auto end_iter = p.non_nulls_end;
+
+ auto kth_begin = std::min(indices_begin + k_, end_iter);
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin && heap.size() < static_cast<size_t>(k_); ++iter) {
+ heap.push(HeapItem{*iter, offset, &arr});
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ const auto& xval = GetView::LogicalValue(arr.GetView(x_index));
+ auto top_item = heap.top();
+ const auto& top_value =
+ GetView::LogicalValue(top_item.array->GetView(top_item.index));
+ if (comparator(xval, top_value)) {
+ heap.pop();
+ heap.push(HeapItem{x_index, offset, &arr});
+ }
+ }
+ offset += chunk->length();
+ }
+
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(uint64(), out_size,
+ ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size - 1;
+ while (heap.size() > 0) {
+ auto top_item = heap.top();
+ *out_cbegin = top_item.index + top_item.offset;
+ heap.pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ const ChunkedArray& chunked_array_;
+ const std::shared_ptr<DataType> physical_type_;
+ const ArrayVector physical_chunks_;
+ int64_t k_;
+ SortOrder order_;
+ ExecContext* ctx_;
+ Datum* output_;
+};
+
+class RecordBatchSelecter : public TypeVisitor {
+ private:
+ using ResolvedSortKey = MultipleKeyRecordBatchSorter::ResolvedSortKey;
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+ RecordBatchSelecter(ExecContext* ctx, const RecordBatch& record_batch,
+ const SelectKOptions& options, Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ record_batch_(record_batch),
+ k_(options.k),
+ output_(output),
+ sort_keys_(ResolveSortKeys(record_batch, options.sort_keys)),
+ comparator_(sort_keys_, NullPlacement::AtEnd) {}
+
+ Status Run() { return sort_keys_[0].type->Accept(this); }
+
+ protected:
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (sort_keys_[0].order == SortOrder::Descending) \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ }
+ VISIT_SORTABLE_PHYSICAL_TYPES(VISIT)
+#undef VISIT
+
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const RecordBatch& batch, const std::vector<SortKey>& sort_keys) {
+ std::vector<ResolvedSortKey> resolved;
+ for (const auto& key : sort_keys) {
+ auto array = batch.GetColumnByName(key.name);
+ resolved.emplace_back(array, key.order);
+ }
+ return resolved;
+ }
+
+ template <typename InType, SortOrder sort_order>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+ const ArrayType& arr = checked_cast<const ArrayType&>(first_sort_key.array);
+
+ const auto num_rows = record_batch_.num_rows();
+ if (num_rows == 0) {
+ return Status::OK();
+ }
+ if (k_ > record_batch_.num_rows()) {
+ k_ = record_batch_.num_rows();
+ }
+ std::function<bool(const uint64_t&, const uint64_t&)> cmp;
+ SelectKComparator<sort_order> select_k_comparator;
+ cmp = [&](const uint64_t& left, const uint64_t& right) -> bool {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ if (lval == rval) {
+ // If the left value equals to the right value,
+ // we need to compare the second and following
+ // sort keys.
+ return comparator.Compare(left, right, 1);
+ }
+ return select_k_comparator(lval, rval);
+ };
+ using HeapContainer =
+ std::priority_queue<uint64_t, std::vector<uint64_t>, decltype(cmp)>;
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ const auto p = PartitionNulls<ArrayType, NonStablePartitioner>(
+ indices_begin, indices_end, arr, 0, NullPlacement::AtEnd);
+ const auto end_iter = p.non_nulls_end;
+
+ auto kth_begin = std::min(indices_begin + k_, end_iter);
+
+ HeapContainer heap(indices_begin, kth_begin, cmp);
+ for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ auto top_item = heap.top();
+ if (cmp(x_index, top_item)) {
+ heap.pop();
+ heap.push(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(uint64(), out_size,
+ ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size - 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const RecordBatch& record_batch_;
+ int64_t k_;
+ Datum* output_;
+ std::vector<ResolvedSortKey> sort_keys_;
+ Comparator comparator_;
+};
+
+class TableSelecter : public TypeVisitor {
+ private:
+ struct ResolvedSortKey {
+ ResolvedSortKey(const std::shared_ptr<ChunkedArray>& chunked_array,
+ const SortOrder order)
+ : order(order),
+ type(GetPhysicalType(chunked_array->type())),
+ chunks(GetPhysicalChunks(*chunked_array, type)),
+ chunk_pointers(GetArrayPointers(chunks)),
+ null_count(chunked_array->null_count()),
+ resolver(chunk_pointers) {}
+
+ using LocationType = int64_t;
+
+ // Find the target chunk and index in the target chunk from an
+ // index in chunked array.
+ template <typename ArrayType>
+ ResolvedChunk<ArrayType> GetChunk(int64_t index) const {
+ return resolver.Resolve<ArrayType>(index);
+ }
+
+ const SortOrder order;
+ const std::shared_ptr<DataType> type;
+ const ArrayVector chunks;
+ const std::vector<const Array*> chunk_pointers;
+ const int64_t null_count;
+ const ChunkedArrayResolver resolver;
+ };
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+ TableSelecter(ExecContext* ctx, const Table& table, const SelectKOptions& options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ table_(table),
+ k_(options.k),
+ output_(output),
+ sort_keys_(ResolveSortKeys(table, options.sort_keys)),
+ comparator_(sort_keys_, NullPlacement::AtEnd) {}
+
+ Status Run() { return sort_keys_[0].type->Accept(this); }
+
+ protected:
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (sort_keys_[0].order == SortOrder::Descending) \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ }
+ VISIT_SORTABLE_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const Table& table, const std::vector<SortKey>& sort_keys) {
+ std::vector<ResolvedSortKey> resolved;
+ for (const auto& key : sort_keys) {
+ auto chunked_array = table.GetColumnByName(key.name);
+ resolved.emplace_back(chunked_array, key.order);
+ }
+ return resolved;
+ }
+
+ // Behaves like PartitionNulls() but this supports multiple sort keys.
+ template <typename Type>
+ NullPartitionResult PartitionNullsInternal(uint64_t* indices_begin,
+ uint64_t* indices_end,
+ const ResolvedSortKey& first_sort_key) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+
+ const auto p = PartitionNullsOnly<StablePartitioner>(
+ indices_begin, indices_end, first_sort_key.resolver, first_sort_key.null_count,
+ NullPlacement::AtEnd);
+ DCHECK_EQ(p.nulls_end - p.nulls_begin, first_sort_key.null_count);
+
+ const auto q = PartitionNullLikes<ArrayType, StablePartitioner>(
+ p.non_nulls_begin, p.non_nulls_end, first_sort_key.resolver,
+ NullPlacement::AtEnd);
+
+ auto& comparator = comparator_;
+ // Sort all NaNs by the second and following sort keys.
+ std::stable_sort(q.nulls_begin, q.nulls_end, [&](uint64_t left, uint64_t right) {
+ return comparator.Compare(left, right, 1);
+ });
+ // Sort all nulls by the second and following sort keys.
+ std::stable_sort(p.nulls_begin, p.nulls_end, [&](uint64_t left, uint64_t right) {
+ return comparator.Compare(left, right, 1);
+ });
+
+ return q;
+ }
+
+ // XXX this implementation is rather inefficient as it computes chunk indices
+ // at every comparison. Instead we should iterate over individual batches
+ // and remember ChunkLocation entries in the max-heap.
+
+ template <typename InType, SortOrder sort_order>
+ Status SelectKthInternal() {
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+
+ const auto num_rows = table_.num_rows();
+ if (num_rows == 0) {
+ return Status::OK();
+ }
+ if (k_ > table_.num_rows()) {
+ k_ = table_.num_rows();
+ }
+ std::function<bool(const uint64_t&, const uint64_t&)> cmp;
+ SelectKComparator<sort_order> select_k_comparator;
+ cmp = [&](const uint64_t& left, const uint64_t& right) -> bool {
+ auto chunk_left = first_sort_key.template GetChunk<ArrayType>(left);
+ auto chunk_right = first_sort_key.template GetChunk<ArrayType>(right);
+ auto value_left = chunk_left.Value();
+ auto value_right = chunk_right.Value();
+ if (value_left == value_right) {
+ return comparator.Compare(left, right, 1);
+ }
+ return select_k_comparator(value_left, value_right);
+ };
+ using HeapContainer =
+ std::priority_queue<uint64_t, std::vector<uint64_t>, decltype(cmp)>;
+
+ std::vector<uint64_t> indices(num_rows);
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ const auto p =
+ this->PartitionNullsInternal<InType>(indices_begin, indices_end, first_sort_key);
+ const auto end_iter = p.non_nulls_end;
+ auto kth_begin = std::min(indices_begin + k_, end_iter);
+
+ HeapContainer heap(indices_begin, kth_begin, cmp);
+ for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ uint64_t top_item = heap.top();
+ if (cmp(x_index, top_item)) {
+ heap.pop();
+ heap.push(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(uint64(), out_size,
+ ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size - 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const Table& table_;
+ int64_t k_;
+ Datum* output_;
+ std::vector<ResolvedSortKey> sort_keys_;
+ Comparator comparator_;
+};
+
+static Status CheckConsistency(const Schema& schema,
+ const std::vector<SortKey>& sort_keys) {
+ for (const auto& key : sort_keys) {
+ auto field = schema.GetFieldByName(key.name);
+ if (!field) {
+ return Status::Invalid("Nonexistent sort key column: ", key.name);
+ }
+ }
+ return Status::OK();
+}
+
+class SelectKUnstableMetaFunction : public MetaFunction {
+ public:
+ SelectKUnstableMetaFunction()
+ : MetaFunction("select_k_unstable", Arity::Unary(), &select_k_unstable_doc,
+ &kDefaultSelectKOptions) {}
+
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options, ExecContext* ctx) const {
+ const SelectKOptions& select_k_options = static_cast<const SelectKOptions&>(*options);
+ if (select_k_options.k < 0) {
+ return Status::Invalid("select_k_unstable requires a nonnegative `k`, got ",
+ select_k_options.k);
+ }
+ if (select_k_options.sort_keys.size() == 0) {
+ return Status::Invalid("select_k_unstable requires a non-empty `sort_keys`");
+ }
+ switch (args[0].kind()) {
+ case Datum::ARRAY: {
+ return SelectKth(*args[0].make_array(), select_k_options, ctx);
+ } break;
+ case Datum::CHUNKED_ARRAY: {
+ return SelectKth(*args[0].chunked_array(), select_k_options, ctx);
+ } break;
+ case Datum::RECORD_BATCH:
+ return SelectKth(*args[0].record_batch(), select_k_options, ctx);
+ break;
+ case Datum::TABLE:
+ return SelectKth(*args[0].table(), select_k_options, ctx);
+ break;
+ default:
+ break;
+ }
+ return Status::NotImplemented(
+ "Unsupported types for select_k operation: "
+ "values=",
+ args[0].ToString());
+ }
+
+ private:
+ Result<Datum> SelectKth(const Array& array, const SelectKOptions& options,
+ ExecContext* ctx) const {
+ Datum output;
+ ArraySelecter selecter(ctx, array, options, &output);
+ ARROW_RETURN_NOT_OK(selecter.Run());
+ return output;
+ }
+
+ Result<Datum> SelectKth(const ChunkedArray& chunked_array,
+ const SelectKOptions& options, ExecContext* ctx) const {
+ Datum output;
+ ChunkedArraySelecter selecter(ctx, chunked_array, options, &output);
+ ARROW_RETURN_NOT_OK(selecter.Run());
+ return output;
+ }
+ Result<Datum> SelectKth(const RecordBatch& record_batch, const SelectKOptions& options,
+ ExecContext* ctx) const {
+ ARROW_RETURN_NOT_OK(CheckConsistency(*record_batch.schema(), options.sort_keys));
+ Datum output;
+ RecordBatchSelecter selecter(ctx, record_batch, options, &output);
+ ARROW_RETURN_NOT_OK(selecter.Run());
+ return output;
+ }
+ Result<Datum> SelectKth(const Table& table, const SelectKOptions& options,
+ ExecContext* ctx) const {
+ ARROW_RETURN_NOT_OK(CheckConsistency(*table.schema(), options.sort_keys));
+ Datum output;
+ TableSelecter selecter(ctx, table, options, &output);
+ ARROW_RETURN_NOT_OK(selecter.Run());
+ return output;
+ }
+};
+
+} // namespace
+
+void RegisterVectorSort(FunctionRegistry* registry) {
+ DCHECK_OK(registry->AddFunction(std::make_shared<SortIndicesMetaFunction>()));
+ DCHECK_OK(registry->AddFunction(std::make_shared<SelectKUnstableMetaFunction>()));
+}
+
+#undef VISIT_SORTABLE_PHYSICAL_TYPES
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc
new file mode 100644
index 000000000..6ab0bcfde
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc
@@ -0,0 +1,305 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/benchmark_util.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace compute {
+constexpr auto kSeed = 0x0ff1ce;
+
+static void ArraySortIndicesBenchmark(benchmark::State& state,
+ const std::shared_ptr<Array>& values) {
+ for (auto _ : state) {
+ ABORT_NOT_OK(SortIndices(*values).status());
+ }
+ state.SetItemsProcessed(state.iterations() * values->length());
+}
+
+static void ChunkedArraySortIndicesBenchmark(
+ benchmark::State& state, const std::shared_ptr<ChunkedArray>& values) {
+ for (auto _ : state) {
+ ABORT_NOT_OK(SortIndices(*values).status());
+ }
+ state.SetItemsProcessed(state.iterations() * values->length());
+}
+
+static void ArraySortIndicesInt64Benchmark(benchmark::State& state, int64_t min,
+ int64_t max) {
+ RegressionArgs args(state);
+
+ const int64_t array_size = args.size / sizeof(int64_t);
+ auto rand = random::RandomArrayGenerator(kSeed);
+ auto values = rand.Int64(array_size, min, max, args.null_proportion);
+
+ ArraySortIndicesBenchmark(state, values);
+}
+
+static void ChunkedArraySortIndicesInt64Benchmark(benchmark::State& state, int64_t min,
+ int64_t max) {
+ RegressionArgs args(state);
+
+ const int64_t n_chunks = 10;
+ const int64_t array_size = args.size / n_chunks / sizeof(int64_t);
+ auto rand = random::RandomArrayGenerator(kSeed);
+ ArrayVector chunks;
+ for (int64_t i = 0; i < n_chunks; ++i) {
+ chunks.push_back(rand.Int64(array_size, min, max, args.null_proportion));
+ }
+
+ ChunkedArraySortIndicesBenchmark(state, std::make_shared<ChunkedArray>(chunks));
+}
+
+static void ArraySortIndicesInt64Narrow(benchmark::State& state) {
+ ArraySortIndicesInt64Benchmark(state, -100, 100);
+}
+
+static void ArraySortIndicesInt64Wide(benchmark::State& state) {
+ const auto min = std::numeric_limits<int64_t>::min();
+ const auto max = std::numeric_limits<int64_t>::max();
+ ArraySortIndicesInt64Benchmark(state, min, max);
+}
+
+static void ArraySortIndicesBool(benchmark::State& state) {
+ RegressionArgs args(state);
+
+ const int64_t array_size = args.size * 8;
+ auto rand = random::RandomArrayGenerator(kSeed);
+ auto values = rand.Boolean(array_size, 0.5, args.null_proportion);
+
+ ArraySortIndicesBenchmark(state, values);
+}
+
+static void ChunkedArraySortIndicesInt64Narrow(benchmark::State& state) {
+ ChunkedArraySortIndicesInt64Benchmark(state, -100, 100);
+}
+
+static void ChunkedArraySortIndicesInt64Wide(benchmark::State& state) {
+ const auto min = std::numeric_limits<int64_t>::min();
+ const auto max = std::numeric_limits<int64_t>::max();
+ ChunkedArraySortIndicesInt64Benchmark(state, min, max);
+}
+
+static void DatumSortIndicesBenchmark(benchmark::State& state, const Datum& datum,
+ const SortOptions& options) {
+ for (auto _ : state) {
+ ABORT_NOT_OK(SortIndices(datum, options).status());
+ }
+}
+
+// Extract benchmark args from benchmark::State
+struct RecordBatchSortIndicesArgs {
+ // the number of records
+ const int64_t num_records;
+
+ // proportion of nulls in generated arrays
+ const double null_proportion;
+
+ // the number of columns
+ const int64_t num_columns;
+
+ // Extract args
+ explicit RecordBatchSortIndicesArgs(benchmark::State& state)
+ : num_records(state.range(0)),
+ null_proportion(ComputeNullProportion(state.range(1))),
+ num_columns(state.range(2)),
+ state_(state) {}
+
+ ~RecordBatchSortIndicesArgs() {
+ state_.counters["columns"] = static_cast<double>(num_columns);
+ state_.counters["null_percent"] = null_proportion * 100;
+ state_.SetItemsProcessed(state_.iterations() * num_records);
+ }
+
+ protected:
+ double ComputeNullProportion(int64_t inverse_null_proportion) {
+ if (inverse_null_proportion == 0) {
+ return 0.0;
+ } else {
+ return std::min(1., 1. / static_cast<double>(inverse_null_proportion));
+ }
+ }
+
+ benchmark::State& state_;
+};
+
+struct TableSortIndicesArgs : public RecordBatchSortIndicesArgs {
+ // the number of chunks in each generated column
+ const int64_t num_chunks;
+
+ // Extract args
+ explicit TableSortIndicesArgs(benchmark::State& state)
+ : RecordBatchSortIndicesArgs(state), num_chunks(state.range(3)) {}
+
+ ~TableSortIndicesArgs() { state_.counters["chunks"] = static_cast<double>(num_chunks); }
+};
+
+struct BatchOrTableBenchmarkData {
+ std::shared_ptr<Schema> schema;
+ std::vector<SortKey> sort_keys;
+ ChunkedArrayVector columns;
+};
+
+BatchOrTableBenchmarkData MakeBatchOrTableBenchmarkDataInt64(
+ const RecordBatchSortIndicesArgs& args, int64_t num_chunks, int64_t min_value,
+ int64_t max_value) {
+ auto rand = random::RandomArrayGenerator(kSeed);
+ FieldVector fields;
+ BatchOrTableBenchmarkData data;
+
+ for (int64_t i = 0; i < args.num_columns; ++i) {
+ auto name = std::to_string(i);
+ fields.push_back(field(name, int64()));
+ auto order = (i % 2) == 0 ? SortOrder::Ascending : SortOrder::Descending;
+ data.sort_keys.emplace_back(name, order);
+ ArrayVector chunks;
+ if ((args.num_records % num_chunks) != 0) {
+ Status::Invalid("The number of chunks (", num_chunks,
+ ") must be "
+ "a multiple of the number of records (",
+ args.num_records, ")")
+ .Abort();
+ }
+ auto num_records_in_array = args.num_records / num_chunks;
+ for (int64_t j = 0; j < num_chunks; ++j) {
+ chunks.push_back(
+ rand.Int64(num_records_in_array, min_value, max_value, args.null_proportion));
+ }
+ ASSIGN_OR_ABORT(auto chunked_array, ChunkedArray::Make(chunks, int64()));
+ data.columns.push_back(chunked_array);
+ }
+
+ data.schema = schema(fields);
+ return data;
+}
+
+static void RecordBatchSortIndicesInt64(benchmark::State& state, int64_t min,
+ int64_t max) {
+ RecordBatchSortIndicesArgs args(state);
+
+ auto data = MakeBatchOrTableBenchmarkDataInt64(args, /*num_chunks=*/1, min, max);
+ ArrayVector columns;
+ for (const auto& chunked : data.columns) {
+ ARROW_CHECK_EQ(chunked->num_chunks(), 1);
+ columns.push_back(chunked->chunk(0));
+ }
+
+ auto batch = RecordBatch::Make(data.schema, args.num_records, columns);
+ SortOptions options(data.sort_keys);
+ DatumSortIndicesBenchmark(state, Datum(*batch), options);
+}
+
+static void TableSortIndicesInt64(benchmark::State& state, int64_t min, int64_t max) {
+ TableSortIndicesArgs args(state);
+
+ auto data = MakeBatchOrTableBenchmarkDataInt64(args, args.num_chunks, min, max);
+ auto table = Table::Make(data.schema, data.columns, args.num_records);
+ SortOptions options(data.sort_keys);
+ DatumSortIndicesBenchmark(state, Datum(*table), options);
+}
+
+static void RecordBatchSortIndicesInt64Narrow(benchmark::State& state) {
+ RecordBatchSortIndicesInt64(state, -100, 100);
+}
+
+static void RecordBatchSortIndicesInt64Wide(benchmark::State& state) {
+ RecordBatchSortIndicesInt64(state, std::numeric_limits<int64_t>::min(),
+ std::numeric_limits<int64_t>::max());
+}
+
+static void TableSortIndicesInt64Narrow(benchmark::State& state) {
+ TableSortIndicesInt64(state, -100, 100);
+}
+
+static void TableSortIndicesInt64Wide(benchmark::State& state) {
+ TableSortIndicesInt64(state, std::numeric_limits<int64_t>::min(),
+ std::numeric_limits<int64_t>::max());
+}
+
+BENCHMARK(ArraySortIndicesInt64Narrow)
+ ->Apply(RegressionSetArgs)
+ ->Args({1 << 20, 100})
+ ->Args({1 << 23, 100})
+ ->Unit(benchmark::TimeUnit::kNanosecond);
+
+BENCHMARK(ArraySortIndicesInt64Wide)
+ ->Apply(RegressionSetArgs)
+ ->Args({1 << 20, 100})
+ ->Args({1 << 23, 100})
+ ->Unit(benchmark::TimeUnit::kNanosecond);
+
+BENCHMARK(ArraySortIndicesBool)
+ ->Apply(RegressionSetArgs)
+ ->Args({1 << 20, 100})
+ ->Args({1 << 23, 100})
+ ->Unit(benchmark::TimeUnit::kNanosecond);
+
+BENCHMARK(ChunkedArraySortIndicesInt64Narrow)
+ ->Apply(RegressionSetArgs)
+ ->Args({1 << 20, 100})
+ ->Args({1 << 23, 100})
+ ->Unit(benchmark::TimeUnit::kNanosecond);
+
+BENCHMARK(ChunkedArraySortIndicesInt64Wide)
+ ->Apply(RegressionSetArgs)
+ ->Args({1 << 20, 100})
+ ->Args({1 << 23, 100})
+ ->Unit(benchmark::TimeUnit::kNanosecond);
+
+BENCHMARK(RecordBatchSortIndicesInt64Narrow)
+ ->ArgsProduct({
+ {1 << 20}, // the number of records
+ {100, 4, 0}, // inverse null proportion
+ {16, 8, 2, 1}, // the number of columns
+ })
+ ->Unit(benchmark::TimeUnit::kNanosecond);
+
+BENCHMARK(RecordBatchSortIndicesInt64Wide)
+ ->ArgsProduct({
+ {1 << 20}, // the number of records
+ {100, 4, 0}, // inverse null proportion
+ {16, 8, 2, 1}, // the number of columns
+ })
+ ->Unit(benchmark::TimeUnit::kNanosecond);
+
+BENCHMARK(TableSortIndicesInt64Narrow)
+ ->ArgsProduct({
+ {1 << 20}, // the number of records
+ {100, 4, 0}, // inverse null proportion
+ {16, 8, 2, 1}, // the number of columns
+ {32, 4, 1}, // the number of chunks
+ })
+ ->Unit(benchmark::TimeUnit::kNanosecond);
+
+BENCHMARK(TableSortIndicesInt64Wide)
+ ->ArgsProduct({
+ {1 << 20}, // the number of records
+ {100, 4, 0}, // inverse null proportion
+ {16, 8, 2, 1}, // the number of columns
+ {32, 4, 1}, // the number of chunks
+ })
+ ->Unit(benchmark::TimeUnit::kNanosecond);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_sort_internal.h b/src/arrow/cpp/src/arrow/compute/kernels/vector_sort_internal.h
new file mode 100644
index 000000000..9b0295ded
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_sort_internal.h
@@ -0,0 +1,457 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <functional>
+
+#include "arrow/array.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/chunked_internal.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+// Visit all physical types for which sorting is implemented.
+#define VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) \
+ VISIT(BooleanType) \
+ VISIT(Int8Type) \
+ VISIT(Int16Type) \
+ VISIT(Int32Type) \
+ VISIT(Int64Type) \
+ VISIT(UInt8Type) \
+ VISIT(UInt16Type) \
+ VISIT(UInt32Type) \
+ VISIT(UInt64Type) \
+ VISIT(FloatType) \
+ VISIT(DoubleType) \
+ VISIT(BinaryType) \
+ VISIT(LargeBinaryType) \
+ VISIT(FixedSizeBinaryType) \
+ VISIT(Decimal128Type) \
+ VISIT(Decimal256Type)
+
+// NOTE: std::partition is usually faster than std::stable_partition.
+
+struct NonStablePartitioner {
+ template <typename Predicate>
+ uint64_t* operator()(uint64_t* indices_begin, uint64_t* indices_end, Predicate&& pred) {
+ return std::partition(indices_begin, indices_end, std::forward<Predicate>(pred));
+ }
+};
+
+struct StablePartitioner {
+ template <typename Predicate>
+ uint64_t* operator()(uint64_t* indices_begin, uint64_t* indices_end, Predicate&& pred) {
+ return std::stable_partition(indices_begin, indices_end,
+ std::forward<Predicate>(pred));
+ }
+};
+
+template <typename TypeClass, typename Enable = void>
+struct NullTraits {
+ using has_null_like_values = std::false_type;
+};
+
+template <typename TypeClass>
+struct NullTraits<TypeClass, enable_if_physical_floating_point<TypeClass>> {
+ using has_null_like_values = std::true_type;
+};
+
+template <typename TypeClass>
+using has_null_like_values = typename NullTraits<TypeClass>::has_null_like_values;
+
+// Compare two values, taking NaNs into account
+
+template <typename Type, typename Enable = void>
+struct ValueComparator;
+
+template <typename Type>
+struct ValueComparator<Type, enable_if_t<!has_null_like_values<Type>::value>> {
+ template <typename Value>
+ static int Compare(const Value& left, const Value& right, SortOrder order,
+ NullPlacement null_placement) {
+ int compared;
+ if (left == right) {
+ compared = 0;
+ } else if (left > right) {
+ compared = 1;
+ } else {
+ compared = -1;
+ }
+ if (order == SortOrder::Descending) {
+ compared = -compared;
+ }
+ return compared;
+ }
+};
+
+template <typename Type>
+struct ValueComparator<Type, enable_if_t<has_null_like_values<Type>::value>> {
+ template <typename Value>
+ static int Compare(const Value& left, const Value& right, SortOrder order,
+ NullPlacement null_placement) {
+ const bool is_nan_left = std::isnan(left);
+ const bool is_nan_right = std::isnan(right);
+ if (is_nan_left && is_nan_right) {
+ return 0;
+ } else if (is_nan_left) {
+ return null_placement == NullPlacement::AtStart ? -1 : 1;
+ } else if (is_nan_right) {
+ return null_placement == NullPlacement::AtStart ? 1 : -1;
+ }
+ int compared;
+ if (left == right) {
+ compared = 0;
+ } else if (left > right) {
+ compared = 1;
+ } else {
+ compared = -1;
+ }
+ if (order == SortOrder::Descending) {
+ compared = -compared;
+ }
+ return compared;
+ }
+};
+
+template <typename Type, typename Value>
+int CompareTypeValues(const Value& left, const Value& right, SortOrder order,
+ NullPlacement null_placement) {
+ return ValueComparator<Type>::Compare(left, right, order, null_placement);
+}
+
+struct NullPartitionResult {
+ uint64_t* non_nulls_begin;
+ uint64_t* non_nulls_end;
+ uint64_t* nulls_begin;
+ uint64_t* nulls_end;
+
+ uint64_t* overall_begin() const { return std::min(nulls_begin, non_nulls_begin); }
+
+ uint64_t* overall_end() const { return std::max(nulls_end, non_nulls_end); }
+
+ int64_t non_null_count() const { return non_nulls_end - non_nulls_begin; }
+
+ int64_t null_count() const { return nulls_end - nulls_begin; }
+
+ static NullPartitionResult NoNulls(uint64_t* indices_begin, uint64_t* indices_end,
+ NullPlacement null_placement) {
+ if (null_placement == NullPlacement::AtStart) {
+ return {indices_begin, indices_end, indices_begin, indices_begin};
+ } else {
+ return {indices_begin, indices_end, indices_end, indices_end};
+ }
+ }
+
+ static NullPartitionResult NullsOnly(uint64_t* indices_begin, uint64_t* indices_end,
+ NullPlacement null_placement) {
+ if (null_placement == NullPlacement::AtStart) {
+ return {indices_end, indices_end, indices_begin, indices_end};
+ } else {
+ return {indices_begin, indices_begin, indices_begin, indices_end};
+ }
+ }
+
+ static NullPartitionResult NullsAtEnd(uint64_t* indices_begin, uint64_t* indices_end,
+ uint64_t* midpoint) {
+ DCHECK_GE(midpoint, indices_begin);
+ DCHECK_LE(midpoint, indices_end);
+ return {indices_begin, midpoint, midpoint, indices_end};
+ }
+
+ static NullPartitionResult NullsAtStart(uint64_t* indices_begin, uint64_t* indices_end,
+ uint64_t* midpoint) {
+ DCHECK_GE(midpoint, indices_begin);
+ DCHECK_LE(midpoint, indices_end);
+ return {midpoint, indices_end, indices_begin, midpoint};
+ }
+};
+
+// Move nulls (not null-like values) to end of array.
+//
+// `offset` is used when this is called on a chunk of a chunked array
+template <typename Partitioner>
+NullPartitionResult PartitionNullsOnly(uint64_t* indices_begin, uint64_t* indices_end,
+ const Array& values, int64_t offset,
+ NullPlacement null_placement) {
+ if (values.null_count() == 0) {
+ return NullPartitionResult::NoNulls(indices_begin, indices_end, null_placement);
+ }
+ Partitioner partitioner;
+ if (null_placement == NullPlacement::AtStart) {
+ auto nulls_end = partitioner(
+ indices_begin, indices_end,
+ [&values, &offset](uint64_t ind) { return values.IsNull(ind - offset); });
+ return NullPartitionResult::NullsAtStart(indices_begin, indices_end, nulls_end);
+ } else {
+ auto nulls_begin = partitioner(
+ indices_begin, indices_end,
+ [&values, &offset](uint64_t ind) { return !values.IsNull(ind - offset); });
+ return NullPartitionResult::NullsAtEnd(indices_begin, indices_end, nulls_begin);
+ }
+}
+
+// Move non-null null-like values to end of array.
+//
+// `offset` is used when this is called on a chunk of a chunked array
+template <typename ArrayType, typename Partitioner>
+enable_if_t<!has_null_like_values<typename ArrayType::TypeClass>::value,
+ NullPartitionResult>
+PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end,
+ const ArrayType& values, int64_t offset,
+ NullPlacement null_placement) {
+ return NullPartitionResult::NoNulls(indices_begin, indices_end, null_placement);
+}
+
+template <typename ArrayType, typename Partitioner>
+enable_if_t<has_null_like_values<typename ArrayType::TypeClass>::value,
+ NullPartitionResult>
+PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end,
+ const ArrayType& values, int64_t offset,
+ NullPlacement null_placement) {
+ Partitioner partitioner;
+ if (null_placement == NullPlacement::AtStart) {
+ auto null_likes_end =
+ partitioner(indices_begin, indices_end, [&values, &offset](uint64_t ind) {
+ return std::isnan(values.GetView(ind - offset));
+ });
+ return NullPartitionResult::NullsAtStart(indices_begin, indices_end, null_likes_end);
+ } else {
+ auto null_likes_begin =
+ partitioner(indices_begin, indices_end, [&values, &offset](uint64_t ind) {
+ return !std::isnan(values.GetView(ind - offset));
+ });
+ return NullPartitionResult::NullsAtEnd(indices_begin, indices_end, null_likes_begin);
+ }
+}
+
+// Move nulls to end of array.
+//
+// `offset` is used when this is called on a chunk of a chunked array
+template <typename ArrayType, typename Partitioner>
+NullPartitionResult PartitionNulls(uint64_t* indices_begin, uint64_t* indices_end,
+ const ArrayType& values, int64_t offset,
+ NullPlacement null_placement) {
+ // Partition nulls at start (resp. end), and null-like values just before (resp. after)
+ NullPartitionResult p = PartitionNullsOnly<Partitioner>(indices_begin, indices_end,
+ values, offset, null_placement);
+ NullPartitionResult q = PartitionNullLikes<ArrayType, Partitioner>(
+ p.non_nulls_begin, p.non_nulls_end, values, offset, null_placement);
+ return NullPartitionResult{q.non_nulls_begin, q.non_nulls_end,
+ std::min(q.nulls_begin, p.nulls_begin),
+ std::max(q.nulls_end, p.nulls_end)};
+}
+
+//
+// Null partitioning on chunked arrays
+//
+
+template <typename Partitioner>
+NullPartitionResult PartitionNullsOnly(uint64_t* indices_begin, uint64_t* indices_end,
+ const ChunkedArrayResolver& resolver,
+ int64_t null_count, NullPlacement null_placement) {
+ if (null_count == 0) {
+ return NullPartitionResult::NoNulls(indices_begin, indices_end, null_placement);
+ }
+ Partitioner partitioner;
+ if (null_placement == NullPlacement::AtStart) {
+ auto nulls_end = partitioner(indices_begin, indices_end, [&](uint64_t ind) {
+ const auto chunk = resolver.Resolve<Array>(ind);
+ return chunk.IsNull();
+ });
+ return NullPartitionResult::NullsAtStart(indices_begin, indices_end, nulls_end);
+ } else {
+ auto nulls_begin = partitioner(indices_begin, indices_end, [&](uint64_t ind) {
+ const auto chunk = resolver.Resolve<Array>(ind);
+ return !chunk.IsNull();
+ });
+ return NullPartitionResult::NullsAtEnd(indices_begin, indices_end, nulls_begin);
+ }
+}
+
+template <typename ArrayType, typename Partitioner>
+enable_if_t<!has_null_like_values<typename ArrayType::TypeClass>::value,
+ NullPartitionResult>
+PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end,
+ const ChunkedArrayResolver& resolver, NullPlacement null_placement) {
+ return NullPartitionResult::NoNulls(indices_begin, indices_end, null_placement);
+}
+
+template <typename ArrayType, typename Partitioner>
+enable_if_t<has_null_like_values<typename ArrayType::TypeClass>::value,
+ NullPartitionResult>
+PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end,
+ const ChunkedArrayResolver& resolver, NullPlacement null_placement) {
+ Partitioner partitioner;
+ if (null_placement == NullPlacement::AtStart) {
+ auto null_likes_end = partitioner(indices_begin, indices_end, [&](uint64_t ind) {
+ const auto chunk = resolver.Resolve<ArrayType>(ind);
+ return std::isnan(chunk.Value());
+ });
+ return NullPartitionResult::NullsAtStart(indices_begin, indices_end, null_likes_end);
+ } else {
+ auto null_likes_begin = partitioner(indices_begin, indices_end, [&](uint64_t ind) {
+ const auto chunk = resolver.Resolve<ArrayType>(ind);
+ return !std::isnan(chunk.Value());
+ });
+ return NullPartitionResult::NullsAtEnd(indices_begin, indices_end, null_likes_begin);
+ }
+}
+
+template <typename ArrayType, typename Partitioner>
+NullPartitionResult PartitionNulls(uint64_t* indices_begin, uint64_t* indices_end,
+ const ChunkedArrayResolver& resolver,
+ int64_t null_count, NullPlacement null_placement) {
+ // Partition nulls at start (resp. end), and null-like values just before (resp. after)
+ NullPartitionResult p = PartitionNullsOnly<Partitioner>(
+ indices_begin, indices_end, resolver, null_count, null_placement);
+ NullPartitionResult q = PartitionNullLikes<ArrayType, Partitioner>(
+ p.non_nulls_begin, p.non_nulls_end, resolver, null_placement);
+ return NullPartitionResult{q.non_nulls_begin, q.non_nulls_end,
+ std::min(q.nulls_begin, p.nulls_begin),
+ std::max(q.nulls_end, p.nulls_end)};
+}
+
+struct MergeImpl {
+ using MergeNullsFunc = std::function<void(uint64_t* nulls_begin, uint64_t* nulls_middle,
+ uint64_t* nulls_end, uint64_t* temp_indices,
+ int64_t null_count)>;
+
+ using MergeNonNullsFunc =
+ std::function<void(uint64_t* range_begin, uint64_t* range_middle,
+ uint64_t* range_end, uint64_t* temp_indices)>;
+
+ MergeImpl(NullPlacement null_placement, MergeNullsFunc&& merge_nulls,
+ MergeNonNullsFunc&& merge_non_nulls)
+ : null_placement_(null_placement),
+ merge_nulls_(std::move(merge_nulls)),
+ merge_non_nulls_(std::move(merge_non_nulls)) {}
+
+ Status Init(ExecContext* ctx, int64_t temp_indices_length) {
+ ARROW_ASSIGN_OR_RAISE(
+ temp_buffer_,
+ AllocateBuffer(sizeof(int64_t) * temp_indices_length, ctx->memory_pool()));
+ temp_indices_ = reinterpret_cast<uint64_t*>(temp_buffer_->mutable_data());
+ return Status::OK();
+ }
+
+ NullPartitionResult Merge(const NullPartitionResult& left,
+ const NullPartitionResult& right, int64_t null_count) const {
+ if (null_placement_ == NullPlacement::AtStart) {
+ return MergeNullsAtStart(left, right, null_count);
+ } else {
+ return MergeNullsAtEnd(left, right, null_count);
+ }
+ }
+
+ NullPartitionResult MergeNullsAtStart(const NullPartitionResult& left,
+ const NullPartitionResult& right,
+ int64_t null_count) const {
+ // Input layout:
+ // [left nulls .... left non-nulls .... right nulls .... right non-nulls]
+ DCHECK_EQ(left.nulls_end, left.non_nulls_begin);
+ DCHECK_EQ(left.non_nulls_end, right.nulls_begin);
+ DCHECK_EQ(right.nulls_end, right.non_nulls_begin);
+
+ // Mutate the input, stably, to obtain the following layout:
+ // [left nulls .... right nulls .... left non-nulls .... right non-nulls]
+ std::rotate(left.non_nulls_begin, right.nulls_begin, right.nulls_end);
+
+ const auto p = NullPartitionResult::NullsAtStart(
+ left.nulls_begin, right.non_nulls_end,
+ left.nulls_begin + left.null_count() + right.null_count());
+
+ // If the type has null-like values (such as NaN), ensure those plus regular
+ // nulls are partitioned in the right order. Note this assumes that all
+ // null-like values (e.g. NaN) are ordered equally.
+ if (p.null_count()) {
+ merge_nulls_(p.nulls_begin, p.nulls_begin + left.null_count(), p.nulls_end,
+ temp_indices_, null_count);
+ }
+
+ // Merge the non-null values into temp area
+ DCHECK_EQ(right.non_nulls_begin - p.non_nulls_begin, left.non_null_count());
+ DCHECK_EQ(p.non_nulls_end - right.non_nulls_begin, right.non_null_count());
+ if (p.non_null_count()) {
+ merge_non_nulls_(p.non_nulls_begin, right.non_nulls_begin, p.non_nulls_end,
+ temp_indices_);
+ }
+ return p;
+ }
+
+ NullPartitionResult MergeNullsAtEnd(const NullPartitionResult& left,
+ const NullPartitionResult& right,
+ int64_t null_count) const {
+ // Input layout:
+ // [left non-nulls .... left nulls .... right non-nulls .... right nulls]
+ DCHECK_EQ(left.non_nulls_end, left.nulls_begin);
+ DCHECK_EQ(left.nulls_end, right.non_nulls_begin);
+ DCHECK_EQ(right.non_nulls_end, right.nulls_begin);
+
+ // Mutate the input, stably, to obtain the following layout:
+ // [left non-nulls .... right non-nulls .... left nulls .... right nulls]
+ std::rotate(left.nulls_begin, right.non_nulls_begin, right.non_nulls_end);
+
+ const auto p = NullPartitionResult::NullsAtEnd(
+ left.non_nulls_begin, right.nulls_end,
+ left.non_nulls_begin + left.non_null_count() + right.non_null_count());
+
+ // If the type has null-like values (such as NaN), ensure those plus regular
+ // nulls are partitioned in the right order. Note this assumes that all
+ // null-like values (e.g. NaN) are ordered equally.
+ if (p.null_count()) {
+ merge_nulls_(p.nulls_begin, p.nulls_begin + left.null_count(), p.nulls_end,
+ temp_indices_, null_count);
+ }
+
+ // Merge the non-null values into temp area
+ DCHECK_EQ(left.non_nulls_end - p.non_nulls_begin, left.non_null_count());
+ DCHECK_EQ(p.non_nulls_end - left.non_nulls_end, right.non_null_count());
+ if (p.non_null_count()) {
+ merge_non_nulls_(p.non_nulls_begin, left.non_nulls_end, p.non_nulls_end,
+ temp_indices_);
+ }
+ return p;
+ }
+
+ private:
+ NullPlacement null_placement_;
+ MergeNullsFunc merge_nulls_;
+ MergeNonNullsFunc merge_non_nulls_;
+ std::unique_ptr<Buffer> temp_buffer_;
+ uint64_t* temp_indices_ = nullptr;
+};
+
+// TODO make this usable if indices are non trivial on input
+// (see ConcreteRecordBatchColumnSorter)
+// `offset` is used when this is called on a chunk of a chunked array
+using ArraySortFunc = std::function<NullPartitionResult(
+ uint64_t* indices_begin, uint64_t* indices_end, const Array& values, int64_t offset,
+ const ArraySortOptions& options)>;
+
+Result<ArraySortFunc> GetArraySorter(const DataType& type);
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_sort_test.cc
new file mode 100644
index 000000000..d39f6722c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_sort_test.cc
@@ -0,0 +1,1925 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <functional>
+#include <limits>
+#include <memory>
+#include <ostream>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <gmock/gmock-matchers.h>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/result.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+std::vector<SortOrder> AllOrders() {
+ return {SortOrder::Ascending, SortOrder::Descending};
+}
+
+std::vector<NullPlacement> AllNullPlacements() {
+ return {NullPlacement::AtEnd, NullPlacement::AtStart};
+}
+
+std::ostream& operator<<(std::ostream& os, NullPlacement null_placement) {
+ os << (null_placement == NullPlacement::AtEnd ? "AtEnd" : "AtStart");
+ return os;
+}
+
+// ----------------------------------------------------------------------
+// Tests for NthToIndices
+
+template <typename ArrayType>
+auto GetLogicalValue(const ArrayType& array, uint64_t index)
+ -> decltype(array.GetView(index)) {
+ return array.GetView(index);
+}
+
+Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) {
+ return Decimal128(array.Value(index));
+}
+
+Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) {
+ return Decimal256(array.Value(index));
+}
+
+template <typename ArrayType>
+struct ThreeWayComparator {
+ SortOrder order;
+ NullPlacement null_placement;
+
+ int operator()(const ArrayType& array, uint64_t lhs, uint64_t rhs) const {
+ return (*this)(array, array, lhs, rhs);
+ }
+
+ // Return -1 if L < R, 0 if L == R, 1 if L > R
+ int operator()(const ArrayType& left, const ArrayType& right, uint64_t lhs,
+ uint64_t rhs) const {
+ const bool lhs_is_null = left.IsNull(lhs);
+ const bool rhs_is_null = right.IsNull(rhs);
+ if (lhs_is_null && rhs_is_null) return 0;
+ if (lhs_is_null) {
+ return null_placement == NullPlacement::AtStart ? -1 : 1;
+ }
+ if (rhs_is_null) {
+ return null_placement == NullPlacement::AtStart ? 1 : -1;
+ }
+ const auto lval = GetLogicalValue(left, lhs);
+ const auto rval = GetLogicalValue(right, rhs);
+ if (is_floating_type<typename ArrayType::TypeClass>::value) {
+ const bool lhs_isnan = lval != lval;
+ const bool rhs_isnan = rval != rval;
+ if (lhs_isnan && rhs_isnan) return 0;
+ if (lhs_isnan) {
+ return null_placement == NullPlacement::AtStart ? -1 : 1;
+ }
+ if (rhs_isnan) {
+ return null_placement == NullPlacement::AtStart ? 1 : -1;
+ }
+ }
+ if (lval == rval) return 0;
+ if (lval < rval) {
+ return order == SortOrder::Ascending ? -1 : 1;
+ } else {
+ return order == SortOrder::Ascending ? 1 : -1;
+ }
+ }
+};
+
+template <typename ArrayType>
+struct NthComparator {
+ ThreeWayComparator<ArrayType> three_way;
+
+ explicit NthComparator(NullPlacement null_placement)
+ : three_way({SortOrder::Ascending, null_placement}) {}
+
+ // Return true iff L <= R
+ bool operator()(const ArrayType& array, uint64_t lhs, uint64_t rhs) const {
+ // lhs <= rhs
+ return three_way(array, lhs, rhs) <= 0;
+ }
+};
+
+template <typename ArrayType>
+struct SortComparator {
+ ThreeWayComparator<ArrayType> three_way;
+
+ explicit SortComparator(SortOrder order, NullPlacement null_placement)
+ : three_way({order, null_placement}) {}
+
+ bool operator()(const ArrayType& array, uint64_t lhs, uint64_t rhs) const {
+ const int r = three_way(array, lhs, rhs);
+ if (r != 0) return r < 0;
+ return lhs < rhs;
+ }
+};
+
+template <typename ArrowType>
+class TestNthToIndicesBase : public TestBase {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+ protected:
+ void Validate(const ArrayType& array, int n, NullPlacement null_placement,
+ UInt64Array& offsets) {
+ if (n >= array.length()) {
+ for (int i = 0; i < array.length(); ++i) {
+ ASSERT_TRUE(offsets.Value(i) == static_cast<uint64_t>(i));
+ }
+ } else {
+ NthComparator<ArrayType> compare{null_placement};
+ uint64_t nth = offsets.Value(n);
+
+ for (int i = 0; i < n; ++i) {
+ uint64_t lhs = offsets.Value(i);
+ ASSERT_TRUE(compare(array, lhs, nth));
+ }
+ for (int i = n + 1; i < array.length(); ++i) {
+ uint64_t rhs = offsets.Value(i);
+ ASSERT_TRUE(compare(array, nth, rhs));
+ }
+ }
+ }
+
+ void AssertNthToIndicesArray(const std::shared_ptr<Array>& values, int n,
+ NullPlacement null_placement) {
+ ARROW_SCOPED_TRACE("n = ", n, ", null_placement = ", null_placement);
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> offsets,
+ NthToIndices(*values, PartitionNthOptions(n, null_placement)));
+ // null_count field should have been initialized to 0, for convenience
+ ASSERT_EQ(offsets->data()->null_count, 0);
+ ValidateOutput(*offsets);
+ Validate(*checked_pointer_cast<ArrayType>(values), n, null_placement,
+ *checked_pointer_cast<UInt64Array>(offsets));
+ }
+
+ void AssertNthToIndicesArray(const std::shared_ptr<Array>& values, int n) {
+ for (auto null_placement : AllNullPlacements()) {
+ AssertNthToIndicesArray(values, n, null_placement);
+ }
+ }
+
+ void AssertNthToIndicesJson(const std::string& values, int n) {
+ AssertNthToIndicesArray(ArrayFromJSON(GetType(), values), n);
+ }
+
+ virtual std::shared_ptr<DataType> GetType() = 0;
+};
+
+template <typename ArrowType>
+class TestNthToIndices : public TestNthToIndicesBase<ArrowType> {
+ protected:
+ std::shared_ptr<DataType> GetType() override {
+ return default_type_instance<ArrowType>();
+ }
+};
+
+template <typename ArrowType>
+class TestNthToIndicesForReal : public TestNthToIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestNthToIndicesForReal, RealArrowTypes);
+
+template <typename ArrowType>
+class TestNthToIndicesForIntegral : public TestNthToIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestNthToIndicesForIntegral, IntegralArrowTypes);
+
+template <typename ArrowType>
+class TestNthToIndicesForBool : public TestNthToIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestNthToIndicesForBool, ::testing::Types<BooleanType>);
+
+template <typename ArrowType>
+class TestNthToIndicesForTemporal : public TestNthToIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestNthToIndicesForTemporal, TemporalArrowTypes);
+
+template <typename ArrowType>
+class TestNthToIndicesForDecimal : public TestNthToIndicesBase<ArrowType> {
+ std::shared_ptr<DataType> GetType() override {
+ return std::make_shared<ArrowType>(5, 2);
+ }
+};
+TYPED_TEST_SUITE(TestNthToIndicesForDecimal, DecimalArrowTypes);
+
+template <typename ArrowType>
+class TestNthToIndicesForStrings : public TestNthToIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestNthToIndicesForStrings, testing::Types<StringType>);
+
+TYPED_TEST(TestNthToIndicesForReal, NthToIndicesDoesNotProvideDefaultOptions) {
+ auto input = ArrayFromJSON(this->GetType(), "[null, 1, 3.3, null, 2, 5.3]");
+ ASSERT_RAISES(Invalid, CallFunction("partition_nth_indices", {input}));
+}
+
+TYPED_TEST(TestNthToIndicesForReal, Real) {
+ this->AssertNthToIndicesJson("[null, 1, 3.3, null, 2, 5.3]", 0);
+ this->AssertNthToIndicesJson("[null, 1, 3.3, null, 2, 5.3]", 2);
+ this->AssertNthToIndicesJson("[null, 1, 3.3, null, 2, 5.3]", 5);
+ this->AssertNthToIndicesJson("[null, 1, 3.3, null, 2, 5.3]", 6);
+
+ this->AssertNthToIndicesJson("[null, 2, NaN, 3, 1]", 0);
+ this->AssertNthToIndicesJson("[null, 2, NaN, 3, 1]", 1);
+ this->AssertNthToIndicesJson("[null, 2, NaN, 3, 1]", 2);
+ this->AssertNthToIndicesJson("[null, 2, NaN, 3, 1]", 3);
+ this->AssertNthToIndicesJson("[null, 2, NaN, 3, 1]", 4);
+ this->AssertNthToIndicesJson("[NaN, 2, null, 3, 1]", 3);
+ this->AssertNthToIndicesJson("[NaN, 2, null, 3, 1]", 4);
+
+ this->AssertNthToIndicesJson("[NaN, 2, NaN, 3, 1]", 0);
+ this->AssertNthToIndicesJson("[NaN, 2, NaN, 3, 1]", 1);
+ this->AssertNthToIndicesJson("[NaN, 2, NaN, 3, 1]", 2);
+ this->AssertNthToIndicesJson("[NaN, 2, NaN, 3, 1]", 3);
+ this->AssertNthToIndicesJson("[NaN, 2, NaN, 3, 1]", 4);
+}
+
+TYPED_TEST(TestNthToIndicesForIntegral, Integral) {
+ this->AssertNthToIndicesJson("[null, 1, 3, null, 2, 5]", 0);
+ this->AssertNthToIndicesJson("[null, 1, 3, null, 2, 5]", 2);
+ this->AssertNthToIndicesJson("[null, 1, 3, null, 2, 5]", 5);
+ this->AssertNthToIndicesJson("[null, 1, 3, null, 2, 5]", 6);
+}
+
+TYPED_TEST(TestNthToIndicesForBool, Bool) {
+ this->AssertNthToIndicesJson("[null, false, true, null, false, true]", 0);
+ this->AssertNthToIndicesJson("[null, false, true, null, false, true]", 2);
+ this->AssertNthToIndicesJson("[null, false, true, null, false, true]", 5);
+ this->AssertNthToIndicesJson("[null, false, true, null, false, true]", 6);
+}
+
+TYPED_TEST(TestNthToIndicesForTemporal, Temporal) {
+ this->AssertNthToIndicesJson("[null, 1, 3, null, 2, 5]", 0);
+ this->AssertNthToIndicesJson("[null, 1, 3, null, 2, 5]", 2);
+ this->AssertNthToIndicesJson("[null, 1, 3, null, 2, 5]", 5);
+ this->AssertNthToIndicesJson("[null, 1, 3, null, 2, 5]", 6);
+}
+
+TYPED_TEST(TestNthToIndicesForDecimal, Decimal) {
+ const std::string values = R"(["123.45", null, "-123.45", "456.78", "-456.78"])";
+ this->AssertNthToIndicesJson(values, 0);
+ this->AssertNthToIndicesJson(values, 2);
+ this->AssertNthToIndicesJson(values, 4);
+ this->AssertNthToIndicesJson(values, 5);
+}
+
+TYPED_TEST(TestNthToIndicesForStrings, Strings) {
+ this->AssertNthToIndicesJson(R"(["testing", null, "nth", "for", null, "strings"])", 0);
+ this->AssertNthToIndicesJson(R"(["testing", null, "nth", "for", null, "strings"])", 2);
+ this->AssertNthToIndicesJson(R"(["testing", null, "nth", "for", null, "strings"])", 5);
+ this->AssertNthToIndicesJson(R"(["testing", null, "nth", "for", null, "strings"])", 6);
+}
+
+TEST(TestNthToIndices, Null) {
+ ASSERT_OK_AND_ASSIGN(auto arr, MakeArrayOfNull(null(), 6));
+ auto expected = ArrayFromJSON(uint64(), "[0, 1, 2, 3, 4, 5]");
+ for (const auto null_placement : AllNullPlacements()) {
+ for (const auto n : {0, 1, 2, 3, 4, 5, 6}) {
+ ASSERT_OK_AND_ASSIGN(auto actual,
+ NthToIndices(*arr, PartitionNthOptions(n, null_placement)));
+ AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+ }
+ }
+}
+
+template <typename ArrowType>
+class TestNthToIndicesRandom : public TestNthToIndicesBase<ArrowType> {
+ public:
+ std::shared_ptr<DataType> GetType() override {
+ EXPECT_TRUE(0) << "shouldn't be used";
+ return nullptr;
+ }
+};
+
+using NthToIndicesableTypes =
+ ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
+ Int32Type, Int64Type, FloatType, DoubleType, Decimal128Type,
+ StringType>;
+
+TYPED_TEST_SUITE(TestNthToIndicesRandom, NthToIndicesableTypes);
+
+TYPED_TEST(TestNthToIndicesRandom, RandomValues) {
+ Random<TypeParam> rand(0x61549225);
+ int length = 100;
+ for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
+ // Try n from 0 to out of bound
+ for (int n = 0; n <= length; ++n) {
+ auto array = rand.Generate(length, null_probability);
+ this->AssertNthToIndicesArray(array, n);
+ }
+ }
+}
+
+// ----------------------------------------------------------------------
+// Tests for SortToIndices
+
+template <typename T>
+void AssertSortIndices(const std::shared_ptr<T>& input, SortOrder order,
+ NullPlacement null_placement,
+ const std::shared_ptr<Array>& expected) {
+ ArraySortOptions options(order, null_placement);
+ ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(*input, options));
+ ValidateOutput(*actual);
+ AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+}
+
+template <typename T>
+void AssertSortIndices(const std::shared_ptr<T>& input, const SortOptions& options,
+ const std::shared_ptr<Array>& expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(Datum(*input), options));
+ ValidateOutput(*actual);
+ AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+}
+
+template <typename T>
+void AssertSortIndices(const std::shared_ptr<T>& input, const SortOptions& options,
+ const std::string& expected) {
+ AssertSortIndices(input, options, ArrayFromJSON(uint64(), expected));
+}
+
+template <typename T>
+void AssertSortIndices(const std::shared_ptr<T>& input, SortOrder order,
+ NullPlacement null_placement, const std::string& expected) {
+ AssertSortIndices(input, order, null_placement, ArrayFromJSON(uint64(), expected));
+}
+
+void AssertSortIndices(const std::shared_ptr<DataType>& type, const std::string& values,
+ SortOrder order, NullPlacement null_placement,
+ const std::string& expected) {
+ AssertSortIndices(ArrayFromJSON(type, values), order, null_placement,
+ ArrayFromJSON(uint64(), expected));
+}
+
+class TestArraySortIndicesBase : public TestBase {
+ public:
+ virtual std::shared_ptr<DataType> type() = 0;
+
+ virtual void AssertSortIndices(const std::string& values, SortOrder order,
+ NullPlacement null_placement,
+ const std::string& expected) {
+ arrow::compute::AssertSortIndices(this->type(), values, order, null_placement,
+ expected);
+ }
+
+ virtual void AssertSortIndices(const std::string& values, const std::string& expected) {
+ AssertSortIndices(values, SortOrder::Ascending, NullPlacement::AtEnd, expected);
+ }
+};
+
+template <typename ArrowType>
+class TestArraySortIndices : public TestArraySortIndicesBase {
+ public:
+ std::shared_ptr<DataType> type() override {
+ // Will choose default parameters for temporal types
+ return std::make_shared<ArrowType>();
+ }
+};
+
+template <typename ArrowType>
+class TestArraySortIndicesForReal : public TestArraySortIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestArraySortIndicesForReal, RealArrowTypes);
+
+template <typename ArrowType>
+class TestArraySortIndicesForBool : public TestArraySortIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestArraySortIndicesForBool, ::testing::Types<BooleanType>);
+
+template <typename ArrowType>
+class TestArraySortIndicesForIntegral : public TestArraySortIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestArraySortIndicesForIntegral, IntegralArrowTypes);
+
+template <typename ArrowType>
+class TestArraySortIndicesForTemporal : public TestArraySortIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestArraySortIndicesForTemporal, TemporalArrowTypes);
+
+using StringSortTestTypes = testing::Types<StringType, LargeStringType>;
+
+template <typename ArrowType>
+class TestArraySortIndicesForStrings : public TestArraySortIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestArraySortIndicesForStrings, StringSortTestTypes);
+
+class TestArraySortIndicesForFixedSizeBinary : public TestArraySortIndicesBase {
+ public:
+ std::shared_ptr<DataType> type() override { return fixed_size_binary(3); }
+};
+
+TYPED_TEST(TestArraySortIndicesForReal, SortReal) {
+ for (auto null_placement : AllNullPlacements()) {
+ for (auto order : AllOrders()) {
+ this->AssertSortIndices("[]", order, null_placement, "[]");
+ this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]");
+ }
+ this->AssertSortIndices("[3.4, 2.6, 6.3]", SortOrder::Ascending, null_placement,
+ "[1, 0, 2]");
+ this->AssertSortIndices("[1.1, 2.4, 3.5, 4.3, 5.1, 6.8, 7.3]", SortOrder::Ascending,
+ null_placement, "[0, 1, 2, 3, 4, 5, 6]");
+ this->AssertSortIndices("[7, 6, 5, 4, 3, 2, 1]", SortOrder::Ascending, null_placement,
+ "[6, 5, 4, 3, 2, 1, 0]");
+ this->AssertSortIndices("[10.4, 12, 4.2, 50, 50.3, 32, 11]", SortOrder::Ascending,
+ null_placement, "[2, 0, 6, 1, 5, 3, 4]");
+ }
+
+ this->AssertSortIndices("[null, 1, 3.3, null, 2, 5.3]", SortOrder::Ascending,
+ NullPlacement::AtEnd, "[1, 4, 2, 5, 0, 3]");
+ this->AssertSortIndices("[null, 1, 3.3, null, 2, 5.3]", SortOrder::Ascending,
+ NullPlacement::AtStart, "[0, 3, 1, 4, 2, 5]");
+ this->AssertSortIndices("[null, 1, 3.3, null, 2, 5.3]", SortOrder::Descending,
+ NullPlacement::AtEnd, "[5, 2, 4, 1, 0, 3]");
+ this->AssertSortIndices("[null, 1, 3.3, null, 2, 5.3]", SortOrder::Descending,
+ NullPlacement::AtStart, "[0, 3, 5, 2, 4, 1]");
+
+ this->AssertSortIndices("[3, 4, NaN, 1, 2, null]", SortOrder::Ascending,
+ NullPlacement::AtEnd, "[3, 4, 0, 1, 2, 5]");
+ this->AssertSortIndices("[3, 4, NaN, 1, 2, null]", SortOrder::Ascending,
+ NullPlacement::AtStart, "[5, 2, 3, 4, 0, 1]");
+ this->AssertSortIndices("[3, 4, NaN, 1, 2, null]", SortOrder::Descending,
+ NullPlacement::AtEnd, "[1, 0, 4, 3, 2, 5]");
+ this->AssertSortIndices("[3, 4, NaN, 1, 2, null]", SortOrder::Descending,
+ NullPlacement::AtStart, "[5, 2, 1, 0, 4, 3]");
+
+ this->AssertSortIndices("[NaN, 2, NaN, 3, 1]", SortOrder::Ascending,
+ NullPlacement::AtEnd, "[4, 1, 3, 0, 2]");
+ this->AssertSortIndices("[NaN, 2, NaN, 3, 1]", SortOrder::Ascending,
+ NullPlacement::AtStart, "[0, 2, 4, 1, 3]");
+ this->AssertSortIndices("[NaN, 2, NaN, 3, 1]", SortOrder::Descending,
+ NullPlacement::AtEnd, "[3, 1, 4, 0, 2]");
+ this->AssertSortIndices("[NaN, 2, NaN, 3, 1]", SortOrder::Descending,
+ NullPlacement::AtStart, "[0, 2, 3, 1, 4]");
+
+ this->AssertSortIndices("[null, NaN, NaN, null]", SortOrder::Ascending,
+ NullPlacement::AtEnd, "[1, 2, 0, 3]");
+ this->AssertSortIndices("[null, NaN, NaN, null]", SortOrder::Ascending,
+ NullPlacement::AtStart, "[0, 3, 1, 2]");
+ this->AssertSortIndices("[null, NaN, NaN, null]", SortOrder::Descending,
+ NullPlacement::AtEnd, "[1, 2, 0, 3]");
+ this->AssertSortIndices("[null, NaN, NaN, null]", SortOrder::Descending,
+ NullPlacement::AtStart, "[0, 3, 1, 2]");
+}
+
+TYPED_TEST(TestArraySortIndicesForIntegral, SortIntegral) {
+ for (auto null_placement : AllNullPlacements()) {
+ for (auto order : AllOrders()) {
+ this->AssertSortIndices("[]", order, null_placement, "[]");
+ this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]");
+ }
+ this->AssertSortIndices("[1, 2, 3, 4, 5, 6, 7]", SortOrder::Ascending, null_placement,
+ "[0, 1, 2, 3, 4, 5, 6]");
+ this->AssertSortIndices("[7, 6, 5, 4, 3, 2, 1]", SortOrder::Ascending, null_placement,
+ "[6, 5, 4, 3, 2, 1, 0]");
+
+ this->AssertSortIndices("[10, 12, 4, 50, 50, 32, 11]", SortOrder::Ascending,
+ null_placement, "[2, 0, 6, 1, 5, 3, 4]");
+ this->AssertSortIndices("[10, 12, 4, 50, 50, 32, 11]", SortOrder::Descending,
+ null_placement, "[3, 4, 5, 1, 6, 0, 2]");
+ }
+
+ // Values with a small range (use a counting sort)
+ this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Ascending,
+ NullPlacement::AtEnd, "[1, 4, 2, 5, 0, 3]");
+ this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Ascending,
+ NullPlacement::AtStart, "[0, 3, 1, 4, 2, 5]");
+ this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Descending,
+ NullPlacement::AtEnd, "[5, 2, 4, 1, 0, 3]");
+ this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Descending,
+ NullPlacement::AtStart, "[0, 3, 5, 2, 4, 1]");
+}
+
+TYPED_TEST(TestArraySortIndicesForBool, SortBool) {
+ for (auto null_placement : AllNullPlacements()) {
+ for (auto order : AllOrders()) {
+ this->AssertSortIndices("[]", order, null_placement, "[]");
+ this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]");
+ }
+ this->AssertSortIndices("[true, true, false]", SortOrder::Ascending, null_placement,
+ "[2, 0, 1]");
+ this->AssertSortIndices("[false, false, false, true, true, true, true]",
+ SortOrder::Ascending, null_placement,
+ "[0, 1, 2, 3, 4, 5, 6]");
+ this->AssertSortIndices("[true, true, true, true, false, false, false]",
+ SortOrder::Ascending, null_placement,
+ "[4, 5, 6, 0, 1, 2, 3]");
+
+ this->AssertSortIndices("[false, true, false, true, true, false, false]",
+ SortOrder::Ascending, null_placement,
+ "[0, 2, 5, 6, 1, 3, 4]");
+ this->AssertSortIndices("[false, true, false, true, true, false, false]",
+ SortOrder::Descending, null_placement,
+ "[1, 3, 4, 0, 2, 5, 6]");
+ }
+
+ this->AssertSortIndices("[null, true, false, null, false, true]", SortOrder::Ascending,
+ NullPlacement::AtEnd, "[2, 4, 1, 5, 0, 3]");
+ this->AssertSortIndices("[null, true, false, null, false, true]", SortOrder::Ascending,
+ NullPlacement::AtStart, "[0, 3, 2, 4, 1, 5]");
+ this->AssertSortIndices("[null, true, false, null, false, true]", SortOrder::Descending,
+ NullPlacement::AtEnd, "[1, 5, 2, 4, 0, 3]");
+ this->AssertSortIndices("[null, true, false, null, false, true]", SortOrder::Descending,
+ NullPlacement::AtStart, "[0, 3, 1, 5, 2, 4]");
+}
+
+TYPED_TEST(TestArraySortIndicesForTemporal, SortTemporal) {
+ for (auto null_placement : AllNullPlacements()) {
+ for (auto order : AllOrders()) {
+ this->AssertSortIndices("[]", order, null_placement, "[]");
+ this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]");
+ }
+ this->AssertSortIndices("[3, 2, 6]", SortOrder::Ascending, null_placement,
+ "[1, 0, 2]");
+ this->AssertSortIndices("[1, 2, 3, 4, 5, 6, 7]", SortOrder::Ascending, null_placement,
+ "[0, 1, 2, 3, 4, 5, 6]");
+ this->AssertSortIndices("[7, 6, 5, 4, 3, 2, 1]", SortOrder::Ascending, null_placement,
+ "[6, 5, 4, 3, 2, 1, 0]");
+
+ this->AssertSortIndices("[10, 12, 4, 50, 50, 32, 11]", SortOrder::Ascending,
+ null_placement, "[2, 0, 6, 1, 5, 3, 4]");
+ this->AssertSortIndices("[10, 12, 4, 50, 50, 32, 11]", SortOrder::Descending,
+ null_placement, "[3, 4, 5, 1, 6, 0, 2]");
+ }
+
+ this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Ascending,
+ NullPlacement::AtEnd, "[1, 4, 2, 5, 0, 3]");
+ this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Ascending,
+ NullPlacement::AtStart, "[0, 3, 1, 4, 2, 5]");
+ this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Descending,
+ NullPlacement::AtEnd, "[5, 2, 4, 1, 0, 3]");
+ this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Descending,
+ NullPlacement::AtStart, "[0, 3, 5, 2, 4, 1]");
+}
+
+TYPED_TEST(TestArraySortIndicesForStrings, SortStrings) {
+ for (auto null_placement : AllNullPlacements()) {
+ for (auto order : AllOrders()) {
+ this->AssertSortIndices("[]", order, null_placement, "[]");
+ this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]");
+ }
+ this->AssertSortIndices(R"(["a", "b", "c"])", SortOrder::Ascending, null_placement,
+ "[0, 1, 2]");
+ this->AssertSortIndices(R"(["foo", "bar", "baz"])", SortOrder::Ascending,
+ null_placement, "[1, 2, 0]");
+ this->AssertSortIndices(R"(["testing", "sort", "for", "strings"])",
+ SortOrder::Ascending, null_placement, "[2, 1, 3, 0]");
+ }
+
+ const char* input = R"([null, "c", "b", null, "a", "b"])";
+ this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtEnd,
+ "[4, 2, 5, 1, 0, 3]");
+ this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtStart,
+ "[0, 3, 4, 2, 5, 1]");
+ this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtEnd,
+ "[1, 2, 5, 4, 0, 3]");
+ this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtStart,
+ "[0, 3, 1, 2, 5, 4]");
+}
+
+TEST_F(TestArraySortIndicesForFixedSizeBinary, SortFixedSizeBinary) {
+ for (auto null_placement : AllNullPlacements()) {
+ for (auto order : AllOrders()) {
+ this->AssertSortIndices("[]", order, null_placement, "[]");
+ this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]");
+ }
+ this->AssertSortIndices(R"(["def", "abc", "ghi"])", SortOrder::Ascending,
+ null_placement, "[1, 0, 2]");
+ this->AssertSortIndices(R"(["def", "abc", "ghi"])", SortOrder::Descending,
+ null_placement, "[2, 0, 1]");
+ }
+
+ const char* input = R"([null, "ccc", "bbb", null, "aaa", "bbb"])";
+ this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtEnd,
+ "[4, 2, 5, 1, 0, 3]");
+ this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtStart,
+ "[0, 3, 4, 2, 5, 1]");
+ this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtEnd,
+ "[1, 2, 5, 4, 0, 3]");
+ this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtStart,
+ "[0, 3, 1, 2, 5, 4]");
+}
+
+template <typename ArrowType>
+class TestArraySortIndicesForUInt8 : public TestArraySortIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestArraySortIndicesForUInt8, UInt8Type);
+
+template <typename ArrowType>
+class TestArraySortIndicesForInt8 : public TestArraySortIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestArraySortIndicesForInt8, Int8Type);
+
+TYPED_TEST(TestArraySortIndicesForUInt8, SortUInt8) {
+ const char* input = "[255, null, 0, 255, 10, null, 128, 0]";
+ this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtEnd,
+ "[2, 7, 4, 6, 0, 3, 1, 5]");
+ this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtStart,
+ "[1, 5, 2, 7, 4, 6, 0, 3]");
+ this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtEnd,
+ "[0, 3, 6, 4, 2, 7, 1, 5]");
+ this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtStart,
+ "[1, 5, 0, 3, 6, 4, 2, 7]");
+}
+
+TYPED_TEST(TestArraySortIndicesForInt8, SortInt8) {
+ const char* input = "[127, null, -128, 127, 0, null, 10, -128]";
+ this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtEnd,
+ "[2, 7, 4, 6, 0, 3, 1, 5]");
+ this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtStart,
+ "[1, 5, 2, 7, 4, 6, 0, 3]");
+ this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtEnd,
+ "[0, 3, 6, 4, 2, 7, 1, 5]");
+ this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtStart,
+ "[1, 5, 0, 3, 6, 4, 2, 7]");
+}
+
+template <typename ArrowType>
+class TestArraySortIndicesForInt64 : public TestArraySortIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestArraySortIndicesForInt64, Int64Type);
+
+TYPED_TEST(TestArraySortIndicesForInt64, SortInt64) {
+ // Values with a large range (use a comparison-based sort)
+ const char* input =
+ "[null, -2000000000000000, 3000000000000000,"
+ " null, -1000000000000000, 5000000000000000]";
+ this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtEnd,
+ "[1, 4, 2, 5, 0, 3]");
+ this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtStart,
+ "[0, 3, 1, 4, 2, 5]");
+ this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtEnd,
+ "[5, 2, 4, 1, 0, 3]");
+ this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtStart,
+ "[0, 3, 5, 2, 4, 1]");
+}
+
+template <typename ArrowType>
+class TestArraySortIndicesForDecimal : public TestArraySortIndicesBase {
+ public:
+ std::shared_ptr<DataType> type() override { return std::make_shared<ArrowType>(5, 2); }
+};
+TYPED_TEST_SUITE(TestArraySortIndicesForDecimal, DecimalArrowTypes);
+
+TYPED_TEST(TestArraySortIndicesForDecimal, DecimalSortTestTypes) {
+ const char* input = R"(["123.45", null, "-123.45", "456.78", "-456.78", null])";
+ this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtEnd,
+ "[4, 2, 0, 3, 1, 5]");
+ this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtStart,
+ "[1, 5, 4, 2, 0, 3]");
+ this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtEnd,
+ "[3, 0, 2, 4, 1, 5]");
+ this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtStart,
+ "[1, 5, 3, 0, 2, 4]");
+}
+
+TEST(TestArraySortIndices, NullType) {
+ auto chunked = ChunkedArrayFromJSON(null(), {"[null, null]", "[]", "[null]", "[null]"});
+ for (const auto null_placement : AllNullPlacements()) {
+ for (const auto order : AllOrders()) {
+ AssertSortIndices(null(), "[null, null, null, null]", order, null_placement,
+ "[0, 1, 2, 3]");
+ AssertSortIndices(chunked, order, null_placement, "[0, 1, 2, 3]");
+ }
+ }
+}
+
+TEST(TestArraySortIndices, TemporalTypeParameters) {
+ std::vector<std::shared_ptr<DataType>> types;
+ for (auto unit : {TimeUnit::NANO, TimeUnit::MICRO, TimeUnit::MILLI, TimeUnit::SECOND}) {
+ types.push_back(duration(unit));
+ types.push_back(timestamp(unit));
+ types.push_back(timestamp(unit, "America/Phoenix"));
+ }
+ types.push_back(time64(TimeUnit::NANO));
+ types.push_back(time64(TimeUnit::MICRO));
+ types.push_back(time32(TimeUnit::MILLI));
+ types.push_back(time32(TimeUnit::SECOND));
+ for (const auto& ty : types) {
+ for (auto null_placement : AllNullPlacements()) {
+ for (auto order : AllOrders()) {
+ AssertSortIndices(ty, "[]", order, null_placement, "[]");
+ AssertSortIndices(ty, "[null, null]", order, null_placement, "[0, 1]");
+ }
+ AssertSortIndices(ty, "[3, 2, 6]", SortOrder::Ascending, null_placement,
+ "[1, 0, 2]");
+ AssertSortIndices(ty, "[1, 2, 3, 4, 5, 6, 7]", SortOrder::Ascending, null_placement,
+ "[0, 1, 2, 3, 4, 5, 6]");
+ AssertSortIndices(ty, "[7, 6, 5, 4, 3, 2, 1]", SortOrder::Ascending, null_placement,
+ "[6, 5, 4, 3, 2, 1, 0]");
+
+ AssertSortIndices(ty, "[10, 12, 4, 50, 50, 32, 11]", SortOrder::Ascending,
+ null_placement, "[2, 0, 6, 1, 5, 3, 4]");
+ AssertSortIndices(ty, "[10, 12, 4, 50, 50, 32, 11]", SortOrder::Descending,
+ null_placement, "[3, 4, 5, 1, 6, 0, 2]");
+ }
+ AssertSortIndices(ty, "[null, 1, 3, null, 2, 5]", SortOrder::Ascending,
+ NullPlacement::AtEnd, "[1, 4, 2, 5, 0, 3]");
+ AssertSortIndices(ty, "[null, 1, 3, null, 2, 5]", SortOrder::Ascending,
+ NullPlacement::AtStart, "[0, 3, 1, 4, 2, 5]");
+ AssertSortIndices(ty, "[null, 1, 3, null, 2, 5]", SortOrder::Descending,
+ NullPlacement::AtEnd, "[5, 2, 4, 1, 0, 3]");
+ AssertSortIndices(ty, "[null, 1, 3, null, 2, 5]", SortOrder::Descending,
+ NullPlacement::AtStart, "[0, 3, 5, 2, 4, 1]");
+ }
+}
+
+template <typename ArrowType>
+class TestArraySortIndicesRandom : public TestBase {};
+
+template <typename ArrowType>
+class TestArraySortIndicesRandomCount : public TestBase {};
+
+template <typename ArrowType>
+class TestArraySortIndicesRandomCompare : public TestBase {};
+
+using SortIndicesableTypes =
+ ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
+ Int32Type, Int64Type, FloatType, DoubleType, StringType,
+ Decimal128Type, BooleanType>;
+
+template <typename ArrayType>
+void ValidateSorted(const ArrayType& array, UInt64Array& offsets, SortOrder order,
+ NullPlacement null_placement) {
+ ValidateOutput(array);
+ SortComparator<ArrayType> compare{order, null_placement};
+ for (int i = 1; i < array.length(); i++) {
+ uint64_t lhs = offsets.Value(i - 1);
+ uint64_t rhs = offsets.Value(i);
+ ASSERT_TRUE(compare(array, lhs, rhs));
+ }
+}
+
+TYPED_TEST_SUITE(TestArraySortIndicesRandom, SortIndicesableTypes);
+
+TYPED_TEST(TestArraySortIndicesRandom, SortRandomValues) {
+ using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+
+ Random<TypeParam> rand(0x5487655);
+ int times = 5;
+ int length = 100;
+ for (int test = 0; test < times; test++) {
+ for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
+ auto array = rand.Generate(length, null_probability);
+ for (auto order : AllOrders()) {
+ for (auto null_placement : AllNullPlacements()) {
+ ArraySortOptions options(order, null_placement);
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> offsets,
+ SortIndices(*array, options));
+ ValidateSorted<ArrayType>(*checked_pointer_cast<ArrayType>(array),
+ *checked_pointer_cast<UInt64Array>(offsets), order,
+ null_placement);
+ }
+ }
+ }
+ }
+}
+
+// Long array with small value range: counting sort
+// - length >= 1024(CountCompareSorter::countsort_min_len_)
+// - range <= 4096(CountCompareSorter::countsort_max_range_)
+TYPED_TEST_SUITE(TestArraySortIndicesRandomCount, IntegralArrowTypes);
+
+TYPED_TEST(TestArraySortIndicesRandomCount, SortRandomValuesCount) {
+ using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+
+ RandomRange<TypeParam> rand(0x5487656);
+ int times = 5;
+ int length = 100;
+ int range = 2000;
+ for (int test = 0; test < times; test++) {
+ for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
+ auto array = rand.Generate(length, range, null_probability);
+ for (auto order : AllOrders()) {
+ for (auto null_placement : AllNullPlacements()) {
+ ArraySortOptions options(order, null_placement);
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> offsets,
+ SortIndices(*array, options));
+ ValidateSorted<ArrayType>(*checked_pointer_cast<ArrayType>(array),
+ *checked_pointer_cast<UInt64Array>(offsets), order,
+ null_placement);
+ }
+ }
+ }
+ }
+}
+
+// Long array with big value range: std::stable_sort
+TYPED_TEST_SUITE(TestArraySortIndicesRandomCompare, IntegralArrowTypes);
+
+TYPED_TEST(TestArraySortIndicesRandomCompare, SortRandomValuesCompare) {
+ using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+
+ Random<TypeParam> rand(0x5487657);
+ int times = 5;
+ int length = 100;
+ for (int test = 0; test < times; test++) {
+ for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
+ auto array = rand.Generate(length, null_probability);
+ for (auto order : AllOrders()) {
+ for (auto null_placement : AllNullPlacements()) {
+ ArraySortOptions options(order, null_placement);
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> offsets,
+ SortIndices(*array, options));
+ ValidateSorted<ArrayType>(*checked_pointer_cast<ArrayType>(array),
+ *checked_pointer_cast<UInt64Array>(offsets), order,
+ null_placement);
+ }
+ }
+ }
+ }
+}
+
+// Test basic cases for chunked array.
+class TestChunkedArraySortIndices : public ::testing::Test {};
+
+TEST_F(TestChunkedArraySortIndices, Null) {
+ auto chunked_array = ChunkedArrayFromJSON(uint8(), {
+ "[null, 1]",
+ "[3, null, 2]",
+ "[1]",
+ });
+ AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtEnd,
+ "[1, 5, 4, 2, 0, 3]");
+ AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtStart,
+ "[0, 3, 1, 5, 4, 2]");
+ AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtEnd,
+ "[2, 4, 1, 5, 0, 3]");
+ AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtStart,
+ "[0, 3, 2, 4, 1, 5]");
+}
+
+TEST_F(TestChunkedArraySortIndices, NaN) {
+ auto chunked_array = ChunkedArrayFromJSON(float32(), {
+ "[null, 1]",
+ "[3, null, NaN]",
+ "[NaN, 1]",
+ });
+ AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtEnd,
+ "[1, 6, 2, 4, 5, 0, 3]");
+ AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtStart,
+ "[0, 3, 4, 5, 1, 6, 2]");
+ AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtEnd,
+ "[2, 1, 6, 4, 5, 0, 3]");
+ AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtStart,
+ "[0, 3, 4, 5, 2, 1, 6]");
+}
+
+// Tests for temporal types
+template <typename ArrowType>
+class TestChunkedArraySortIndicesForTemporal : public TestChunkedArraySortIndices {
+ protected:
+ std::shared_ptr<DataType> GetType() { return default_type_instance<ArrowType>(); }
+};
+TYPED_TEST_SUITE(TestChunkedArraySortIndicesForTemporal, TemporalArrowTypes);
+
+TYPED_TEST(TestChunkedArraySortIndicesForTemporal, NoNull) {
+ auto type = this->GetType();
+ auto chunked_array = ChunkedArrayFromJSON(type, {
+ "[0, 1]",
+ "[3, 2, 1]",
+ "[5, 0]",
+ });
+ for (auto null_placement : AllNullPlacements()) {
+ AssertSortIndices(chunked_array, SortOrder::Ascending, null_placement,
+ "[0, 6, 1, 4, 3, 2, 5]");
+ AssertSortIndices(chunked_array, SortOrder::Descending, null_placement,
+ "[5, 2, 3, 1, 4, 0, 6]");
+ }
+}
+
+// Tests for decimal types
+template <typename ArrowType>
+class TestChunkedArraySortIndicesForDecimal : public TestChunkedArraySortIndices {
+ protected:
+ std::shared_ptr<DataType> GetType() { return std::make_shared<ArrowType>(5, 2); }
+};
+TYPED_TEST_SUITE(TestChunkedArraySortIndicesForDecimal, DecimalArrowTypes);
+
+TYPED_TEST(TestChunkedArraySortIndicesForDecimal, Basics) {
+ auto type = this->GetType();
+ auto chunked_array = ChunkedArrayFromJSON(
+ type, {R"(["123.45", "-123.45"])", R"([null, "456.78"])", R"(["-456.78", null])"});
+ AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtEnd,
+ "[4, 1, 0, 3, 2, 5]");
+ AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtStart,
+ "[2, 5, 4, 1, 0, 3]");
+ AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtEnd,
+ "[3, 0, 1, 4, 2, 5]");
+ AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtStart,
+ "[2, 5, 3, 0, 1, 4]");
+}
+
+// Base class for testing against random chunked array.
+template <typename Type>
+class TestChunkedArrayRandomBase : public TestBase {
+ protected:
+ // Generates a chunk. This should be implemented in subclasses.
+ virtual std::shared_ptr<Array> GenerateArray(int length, double null_probability) = 0;
+
+ // All tests uses this.
+ void TestSortIndices(int length) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+
+ for (auto null_probability : {0.0, 0.1, 0.5, 0.9, 1.0}) {
+ for (auto num_chunks : {1, 2, 5, 10, 40}) {
+ std::vector<std::shared_ptr<Array>> arrays;
+ for (int i = 0; i < num_chunks; ++i) {
+ auto array = this->GenerateArray(length / num_chunks, null_probability);
+ arrays.push_back(array);
+ }
+ ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make(arrays));
+ // Concatenate chunks to use existing ValidateSorted() for array.
+ ASSERT_OK_AND_ASSIGN(auto concatenated_array, Concatenate(arrays));
+
+ for (auto order : AllOrders()) {
+ for (auto null_placement : AllNullPlacements()) {
+ ArraySortOptions options(order, null_placement);
+ ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(*chunked_array, options));
+ ValidateSorted<ArrayType>(
+ *checked_pointer_cast<ArrayType>(concatenated_array),
+ *checked_pointer_cast<UInt64Array>(offsets), order, null_placement);
+ }
+ }
+ }
+ }
+ }
+};
+
+// Long array with big value range: std::stable_sort
+template <typename Type>
+class TestChunkedArrayRandom : public TestChunkedArrayRandomBase<Type> {
+ public:
+ void SetUp() override { rand_ = new Random<Type>(0x5487655); }
+
+ void TearDown() override { delete rand_; }
+
+ protected:
+ std::shared_ptr<Array> GenerateArray(int length, double null_probability) override {
+ return rand_->Generate(length, null_probability);
+ }
+
+ private:
+ Random<Type>* rand_;
+};
+TYPED_TEST_SUITE(TestChunkedArrayRandom, SortIndicesableTypes);
+
+TYPED_TEST(TestChunkedArrayRandom, SortIndices) { this->TestSortIndices(1000); }
+
+// Long array with small value range: counting sort
+// - length >= 1024(CountCompareSorter::countsort_min_len_)
+// - range <= 4096(CountCompareSorter::countsort_max_range_)
+template <typename Type>
+class TestChunkedArrayRandomNarrow : public TestChunkedArrayRandomBase<Type> {
+ public:
+ void SetUp() override {
+ range_ = 2000;
+ rand_ = new RandomRange<Type>(0x5487655);
+ }
+
+ void TearDown() override { delete rand_; }
+
+ protected:
+ std::shared_ptr<Array> GenerateArray(int length, double null_probability) override {
+ return rand_->Generate(length, range_, null_probability);
+ }
+
+ private:
+ int range_;
+ RandomRange<Type>* rand_;
+};
+TYPED_TEST_SUITE(TestChunkedArrayRandomNarrow, IntegralArrowTypes);
+TYPED_TEST(TestChunkedArrayRandomNarrow, SortIndices) { this->TestSortIndices(1000); }
+
+// Test basic cases for record batch.
+class TestRecordBatchSortIndices : public ::testing::Test {};
+
+TEST_F(TestRecordBatchSortIndices, NoNull) {
+ auto schema = ::arrow::schema({
+ {field("a", uint8())},
+ {field("b", uint32())},
+ });
+ auto batch = RecordBatchFromJSON(schema,
+ R"([{"a": 3, "b": 5},
+ {"a": 1, "b": 3},
+ {"a": 3, "b": 4},
+ {"a": 0, "b": 6},
+ {"a": 2, "b": 5},
+ {"a": 1, "b": 5},
+ {"a": 1, "b": 3}
+ ])");
+
+ for (auto null_placement : AllNullPlacements()) {
+ SortOptions options(
+ {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)},
+ null_placement);
+
+ AssertSortIndices(batch, options, "[3, 5, 1, 6, 4, 0, 2]");
+ }
+}
+
+TEST_F(TestRecordBatchSortIndices, Null) {
+ auto schema = ::arrow::schema({
+ {field("a", uint8())},
+ {field("b", uint32())},
+ });
+ auto batch = RecordBatchFromJSON(schema,
+ R"([{"a": null, "b": 5},
+ {"a": 1, "b": 3},
+ {"a": 3, "b": null},
+ {"a": null, "b": null},
+ {"a": 2, "b": 5},
+ {"a": 1, "b": 5},
+ {"a": 3, "b": 5}
+ ])");
+ const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Ascending),
+ SortKey("b", SortOrder::Descending)};
+
+ SortOptions options(sort_keys, NullPlacement::AtEnd);
+ AssertSortIndices(batch, options, "[5, 1, 4, 6, 2, 0, 3]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(batch, options, "[3, 0, 5, 1, 4, 2, 6]");
+}
+
+TEST_F(TestRecordBatchSortIndices, NaN) {
+ auto schema = ::arrow::schema({
+ {field("a", float32())},
+ {field("b", float64())},
+ });
+ auto batch = RecordBatchFromJSON(schema,
+ R"([{"a": 3, "b": 5},
+ {"a": 1, "b": NaN},
+ {"a": 3, "b": 4},
+ {"a": 0, "b": 6},
+ {"a": NaN, "b": 5},
+ {"a": NaN, "b": NaN},
+ {"a": NaN, "b": 5},
+ {"a": 1, "b": 5}
+ ])");
+ const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Ascending),
+ SortKey("b", SortOrder::Descending)};
+
+ SortOptions options(sort_keys, NullPlacement::AtEnd);
+ AssertSortIndices(batch, options, "[3, 7, 1, 0, 2, 4, 6, 5]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(batch, options, "[5, 4, 6, 3, 1, 7, 0, 2]");
+}
+
+TEST_F(TestRecordBatchSortIndices, NaNAndNull) {
+ auto schema = ::arrow::schema({
+ {field("a", float32())},
+ {field("b", float64())},
+ });
+ auto batch = RecordBatchFromJSON(schema,
+ R"([{"a": null, "b": 5},
+ {"a": 1, "b": 3},
+ {"a": 3, "b": null},
+ {"a": null, "b": null},
+ {"a": NaN, "b": null},
+ {"a": NaN, "b": NaN},
+ {"a": NaN, "b": 5},
+ {"a": 1, "b": 5}
+ ])");
+ const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Ascending),
+ SortKey("b", SortOrder::Descending)};
+
+ SortOptions options(sort_keys, NullPlacement::AtEnd);
+ AssertSortIndices(batch, options, "[7, 1, 2, 6, 5, 4, 0, 3]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(batch, options, "[3, 0, 4, 5, 6, 7, 1, 2]");
+}
+
+TEST_F(TestRecordBatchSortIndices, Boolean) {
+ auto schema = ::arrow::schema({
+ {field("a", boolean())},
+ {field("b", boolean())},
+ });
+ auto batch = RecordBatchFromJSON(schema,
+ R"([{"a": true, "b": null},
+ {"a": false, "b": null},
+ {"a": true, "b": true},
+ {"a": false, "b": true},
+ {"a": true, "b": false},
+ {"a": null, "b": false},
+ {"a": false, "b": null},
+ {"a": null, "b": true}
+ ])");
+ const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Ascending),
+ SortKey("b", SortOrder::Descending)};
+
+ SortOptions options(sort_keys, NullPlacement::AtEnd);
+ AssertSortIndices(batch, options, "[3, 1, 6, 2, 4, 0, 7, 5]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(batch, options, "[7, 5, 1, 6, 3, 0, 2, 4]");
+}
+
+TEST_F(TestRecordBatchSortIndices, MoreTypes) {
+ auto schema = ::arrow::schema({
+ {field("a", timestamp(TimeUnit::MICRO))},
+ {field("b", large_utf8())},
+ {field("c", fixed_size_binary(3))},
+ });
+ auto batch = RecordBatchFromJSON(schema,
+ R"([{"a": 3, "b": "05", "c": "aaa"},
+ {"a": 1, "b": "031", "c": "bbb"},
+ {"a": 3, "b": "05", "c": "bbb"},
+ {"a": 0, "b": "0666", "c": "aaa"},
+ {"a": 2, "b": "05", "c": "aaa"},
+ {"a": 1, "b": "05", "c": "bbb"}
+ ])");
+ const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Ascending),
+ SortKey("b", SortOrder::Descending),
+ SortKey("c", SortOrder::Ascending)};
+
+ for (auto null_placement : AllNullPlacements()) {
+ SortOptions options(sort_keys, null_placement);
+ AssertSortIndices(batch, options, "[3, 5, 1, 4, 0, 2]");
+ }
+}
+
+TEST_F(TestRecordBatchSortIndices, Decimal) {
+ auto schema = ::arrow::schema({
+ {field("a", decimal128(3, 1))},
+ {field("b", decimal256(4, 2))},
+ });
+ auto batch = RecordBatchFromJSON(schema,
+ R"([{"a": "12.3", "b": "12.34"},
+ {"a": "45.6", "b": "12.34"},
+ {"a": "12.3", "b": "-12.34"},
+ {"a": "-12.3", "b": null},
+ {"a": "-12.3", "b": "-45.67"}
+ ])");
+ const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Ascending),
+ SortKey("b", SortOrder::Descending)};
+
+ SortOptions options(sort_keys, NullPlacement::AtEnd);
+ AssertSortIndices(batch, options, "[4, 3, 0, 2, 1]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(batch, options, "[3, 4, 0, 2, 1]");
+}
+
+TEST_F(TestRecordBatchSortIndices, NullType) {
+ auto schema = arrow::schema({
+ field("a", null()),
+ field("b", int32()),
+ field("c", int32()),
+ field("d", int32()),
+ field("e", int32()),
+ field("f", int32()),
+ field("g", int32()),
+ field("h", int32()),
+ field("i", null()),
+ });
+ auto batch = RecordBatchFromJSON(schema, R"([
+ {"a": null, "b": 5, "c": 0, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": null},
+ {"a": null, "b": 5, "c": 1, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": null},
+ {"a": null, "b": 2, "c": 2, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": null},
+ {"a": null, "b": 4, "c": 3, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": null}
+])");
+ for (const auto null_placement : AllNullPlacements()) {
+ for (const auto order : AllOrders()) {
+ // Uses radix sorter
+ AssertSortIndices(batch,
+ SortOptions(
+ {
+ SortKey("a", order),
+ SortKey("i", order),
+ },
+ null_placement),
+ "[0, 1, 2, 3]");
+ AssertSortIndices(batch,
+ SortOptions(
+ {
+ SortKey("a", order),
+ SortKey("b", SortOrder::Ascending),
+ SortKey("i", order),
+ },
+ null_placement),
+ "[2, 3, 0, 1]");
+ // Uses multiple-key sorter
+ AssertSortIndices(batch,
+ SortOptions(
+ {
+ SortKey("a", order),
+ SortKey("b", SortOrder::Ascending),
+ SortKey("c", SortOrder::Ascending),
+ SortKey("d", SortOrder::Ascending),
+ SortKey("e", SortOrder::Ascending),
+ SortKey("f", SortOrder::Ascending),
+ SortKey("g", SortOrder::Ascending),
+ SortKey("h", SortOrder::Ascending),
+ SortKey("i", order),
+ },
+ null_placement),
+ "[2, 3, 0, 1]");
+ }
+ }
+}
+
+TEST_F(TestRecordBatchSortIndices, DuplicateSortKeys) {
+ // ARROW-14073: only the first occurrence of a given sort column is taken
+ // into account.
+ auto schema = ::arrow::schema({
+ {field("a", float32())},
+ {field("b", float64())},
+ });
+ auto batch = RecordBatchFromJSON(schema,
+ R"([{"a": null, "b": 5},
+ {"a": 1, "b": 3},
+ {"a": 3, "b": null},
+ {"a": null, "b": null},
+ {"a": NaN, "b": null},
+ {"a": NaN, "b": NaN},
+ {"a": NaN, "b": 5},
+ {"a": 1, "b": 5}
+ ])");
+ const std::vector<SortKey> sort_keys{
+ SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending),
+ SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Ascending),
+ SortKey("a", SortOrder::Descending)};
+
+ SortOptions options(sort_keys, NullPlacement::AtEnd);
+ AssertSortIndices(batch, options, "[7, 1, 2, 6, 5, 4, 0, 3]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(batch, options, "[3, 0, 4, 5, 6, 7, 1, 2]");
+}
+
+// Test basic cases for table.
+class TestTableSortIndices : public ::testing::Test {};
+
+TEST_F(TestTableSortIndices, EmptyTable) {
+ auto schema = ::arrow::schema({
+ {field("a", uint8())},
+ {field("b", uint32())},
+ });
+ const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Ascending),
+ SortKey("b", SortOrder::Descending)};
+
+ auto table = TableFromJSON(schema, {"[]"});
+ auto chunked_table = TableFromJSON(schema, {"[]", "[]"});
+
+ SortOptions options(sort_keys, NullPlacement::AtEnd);
+ AssertSortIndices(table, options, "[]");
+ AssertSortIndices(chunked_table, options, "[]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(table, options, "[]");
+ AssertSortIndices(chunked_table, options, "[]");
+}
+
+TEST_F(TestTableSortIndices, EmptySortKeys) {
+ auto schema = ::arrow::schema({
+ {field("a", uint8())},
+ {field("b", uint32())},
+ });
+ const std::vector<SortKey> sort_keys{};
+ const SortOptions options(sort_keys, NullPlacement::AtEnd);
+
+ auto table = TableFromJSON(schema, {R"([{"a": null, "b": 5}])"});
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, testing::HasSubstr("Must specify one or more sort keys"),
+ CallFunction("sort_indices", {table}, &options));
+
+ // Several chunks
+ table = TableFromJSON(schema, {R"([{"a": null, "b": 5}])", R"([{"a": 0, "b": 6}])"});
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, testing::HasSubstr("Must specify one or more sort keys"),
+ CallFunction("sort_indices", {table}, &options));
+}
+
+TEST_F(TestTableSortIndices, Null) {
+ auto schema = ::arrow::schema({
+ {field("a", uint8())},
+ {field("b", uint32())},
+ });
+ const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Ascending),
+ SortKey("b", SortOrder::Descending)};
+ std::shared_ptr<Table> table;
+
+ table = TableFromJSON(schema, {R"([{"a": null, "b": 5},
+ {"a": 1, "b": 3},
+ {"a": 3, "b": null},
+ {"a": null, "b": null},
+ {"a": 2, "b": 5},
+ {"a": 1, "b": 5},
+ {"a": 3, "b": 5}
+ ])"});
+ SortOptions options(sort_keys, NullPlacement::AtEnd);
+ AssertSortIndices(table, options, "[5, 1, 4, 6, 2, 0, 3]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(table, options, "[3, 0, 5, 1, 4, 2, 6]");
+
+ // Same data, several chunks
+ table = TableFromJSON(schema, {R"([{"a": null, "b": 5},
+ {"a": 1, "b": 3},
+ {"a": 3, "b": null}
+ ])",
+ R"([{"a": null, "b": null},
+ {"a": 2, "b": 5},
+ {"a": 1, "b": 5},
+ {"a": 3, "b": 5}
+ ])"});
+ options.null_placement = NullPlacement::AtEnd;
+ AssertSortIndices(table, options, "[5, 1, 4, 6, 2, 0, 3]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(table, options, "[3, 0, 5, 1, 4, 2, 6]");
+}
+
+TEST_F(TestTableSortIndices, NaN) {
+ auto schema = ::arrow::schema({
+ {field("a", float32())},
+ {field("b", float64())},
+ });
+ const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Ascending),
+ SortKey("b", SortOrder::Descending)};
+ std::shared_ptr<Table> table;
+
+ table = TableFromJSON(schema, {R"([{"a": 3, "b": 5},
+ {"a": 1, "b": NaN},
+ {"a": 3, "b": 4},
+ {"a": 0, "b": 6},
+ {"a": NaN, "b": 5},
+ {"a": NaN, "b": NaN},
+ {"a": NaN, "b": 5},
+ {"a": 1, "b": 5}
+ ])"});
+ SortOptions options(sort_keys, NullPlacement::AtEnd);
+ AssertSortIndices(table, options, "[3, 7, 1, 0, 2, 4, 6, 5]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(table, options, "[5, 4, 6, 3, 1, 7, 0, 2]");
+
+ // Same data, several chunks
+ table = TableFromJSON(schema, {R"([{"a": 3, "b": 5},
+ {"a": 1, "b": NaN},
+ {"a": 3, "b": 4},
+ {"a": 0, "b": 6}
+ ])",
+ R"([{"a": NaN, "b": 5},
+ {"a": NaN, "b": NaN},
+ {"a": NaN, "b": 5},
+ {"a": 1, "b": 5}
+ ])"});
+ options.null_placement = NullPlacement::AtEnd;
+ AssertSortIndices(table, options, "[3, 7, 1, 0, 2, 4, 6, 5]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(table, options, "[5, 4, 6, 3, 1, 7, 0, 2]");
+}
+
+TEST_F(TestTableSortIndices, NaNAndNull) {
+ auto schema = ::arrow::schema({
+ {field("a", float32())},
+ {field("b", float64())},
+ });
+ const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Ascending),
+ SortKey("b", SortOrder::Descending)};
+ std::shared_ptr<Table> table;
+
+ table = TableFromJSON(schema, {R"([{"a": null, "b": 5},
+ {"a": 1, "b": 3},
+ {"a": 3, "b": null},
+ {"a": null, "b": null},
+ {"a": NaN, "b": null},
+ {"a": NaN, "b": NaN},
+ {"a": NaN, "b": 5},
+ {"a": 1, "b": 5}
+ ])"});
+ SortOptions options(sort_keys, NullPlacement::AtEnd);
+ AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]");
+
+ // Same data, several chunks
+ table = TableFromJSON(schema, {R"([{"a": null, "b": 5},
+ {"a": 1, "b": 3},
+ {"a": 3, "b": null},
+ {"a": null, "b": null}
+ ])",
+ R"([{"a": NaN, "b": null},
+ {"a": NaN, "b": NaN},
+ {"a": NaN, "b": 5},
+ {"a": 1, "b": 5}
+ ])"});
+ options.null_placement = NullPlacement::AtEnd;
+ AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]");
+}
+
+TEST_F(TestTableSortIndices, Boolean) {
+ auto schema = ::arrow::schema({
+ {field("a", boolean())},
+ {field("b", boolean())},
+ });
+ const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Ascending),
+ SortKey("b", SortOrder::Descending)};
+
+ auto table = TableFromJSON(schema, {R"([{"a": true, "b": null},
+ {"a": false, "b": null},
+ {"a": true, "b": true},
+ {"a": false, "b": true}
+ ])",
+ R"([{"a": true, "b": false},
+ {"a": null, "b": false},
+ {"a": false, "b": null},
+ {"a": null, "b": true}
+ ])"});
+ SortOptions options(sort_keys, NullPlacement::AtEnd);
+ AssertSortIndices(table, options, "[3, 1, 6, 2, 4, 0, 7, 5]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(table, options, "[7, 5, 1, 6, 3, 0, 2, 4]");
+}
+
+TEST_F(TestTableSortIndices, BinaryLike) {
+ auto schema = ::arrow::schema({
+ {field("a", large_utf8())},
+ {field("b", fixed_size_binary(3))},
+ });
+ const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Descending),
+ SortKey("b", SortOrder::Ascending)};
+
+ auto table = TableFromJSON(schema, {R"([{"a": "one", "b": null},
+ {"a": "two", "b": "aaa"},
+ {"a": "three", "b": "bbb"},
+ {"a": "four", "b": "ccc"}
+ ])",
+ R"([{"a": "one", "b": "ddd"},
+ {"a": "two", "b": "ccc"},
+ {"a": "three", "b": "bbb"},
+ {"a": "four", "b": "aaa"}
+ ])"});
+ SortOptions options(sort_keys, NullPlacement::AtEnd);
+ AssertSortIndices(table, options, "[1, 5, 2, 6, 4, 0, 7, 3]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(table, options, "[1, 5, 2, 6, 0, 4, 7, 3]");
+}
+
+TEST_F(TestTableSortIndices, Decimal) {
+ auto schema = ::arrow::schema({
+ {field("a", decimal128(3, 1))},
+ {field("b", decimal256(4, 2))},
+ });
+ const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Ascending),
+ SortKey("b", SortOrder::Descending)};
+
+ auto table = TableFromJSON(schema, {R"([{"a": "12.3", "b": "12.34"},
+ {"a": "45.6", "b": "12.34"},
+ {"a": "12.3", "b": "-12.34"}
+ ])",
+ R"([{"a": "-12.3", "b": null},
+ {"a": "-12.3", "b": "-45.67"}
+ ])"});
+ SortOptions options(sort_keys, NullPlacement::AtEnd);
+ AssertSortIndices(table, options, "[4, 3, 0, 2, 1]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(table, options, "[3, 4, 0, 2, 1]");
+}
+
+TEST_F(TestTableSortIndices, NullType) {
+ auto schema = arrow::schema({
+ field("a", null()),
+ field("b", int32()),
+ field("c", int32()),
+ field("d", null()),
+ });
+ auto table = TableFromJSON(schema, {
+ R"([
+ {"a": null, "b": 5, "c": 0, "d": null},
+ {"a": null, "b": 5, "c": 1, "d": null},
+ {"a": null, "b": 2, "c": 2, "d": null}
+ ])",
+ R"([])",
+ R"([{"a": null, "b": 4, "c": 3, "d": null}])",
+ });
+ for (const auto null_placement : AllNullPlacements()) {
+ for (const auto order : AllOrders()) {
+ AssertSortIndices(table,
+ SortOptions(
+ {
+ SortKey("a", order),
+ SortKey("d", order),
+ },
+ null_placement),
+ "[0, 1, 2, 3]");
+ AssertSortIndices(table,
+ SortOptions(
+ {
+ SortKey("a", order),
+ SortKey("b", SortOrder::Ascending),
+ SortKey("d", order),
+ },
+ null_placement),
+ "[2, 3, 0, 1]");
+ }
+ }
+}
+
+TEST_F(TestTableSortIndices, DuplicateSortKeys) {
+ // ARROW-14073: only the first occurrence of a given sort column is taken
+ // into account.
+ auto schema = ::arrow::schema({
+ {field("a", float32())},
+ {field("b", float64())},
+ });
+ const std::vector<SortKey> sort_keys{
+ SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending),
+ SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Ascending),
+ SortKey("a", SortOrder::Descending)};
+ std::shared_ptr<Table> table;
+
+ table = TableFromJSON(schema, {R"([{"a": null, "b": 5},
+ {"a": 1, "b": 3},
+ {"a": 3, "b": null},
+ {"a": null, "b": null}
+ ])",
+ R"([{"a": NaN, "b": null},
+ {"a": NaN, "b": NaN},
+ {"a": NaN, "b": 5},
+ {"a": 1, "b": 5}
+ ])"});
+ SortOptions options(sort_keys, NullPlacement::AtEnd);
+ AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]");
+}
+
+TEST_F(TestTableSortIndices, HeterogenousChunking) {
+ auto schema = ::arrow::schema({
+ {field("a", float32())},
+ {field("b", float64())},
+ });
+
+ // Same logical data as in "NaNAndNull" test above
+ auto col_a =
+ ChunkedArrayFromJSON(float32(), {"[null, 1]", "[]", "[3, null, NaN, NaN, NaN, 1]"});
+ auto col_b = ChunkedArrayFromJSON(float64(),
+ {"[5]", "[3, null, null]", "[null, NaN, 5]", "[5]"});
+ auto table = Table::Make(schema, {col_a, col_b});
+
+ SortOptions options(
+ {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)});
+ AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]");
+
+ options = SortOptions(
+ {SortKey("b", SortOrder::Ascending), SortKey("a", SortOrder::Descending)});
+ AssertSortIndices(table, options, "[1, 7, 6, 0, 5, 2, 4, 3]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(table, options, "[3, 4, 2, 5, 1, 0, 6, 7]");
+}
+
+// Tests for temporal types
+template <typename ArrowType>
+class TestTableSortIndicesForTemporal : public TestTableSortIndices {
+ protected:
+ std::shared_ptr<DataType> GetType() { return default_type_instance<ArrowType>(); }
+};
+TYPED_TEST_SUITE(TestTableSortIndicesForTemporal, TemporalArrowTypes);
+
+TYPED_TEST(TestTableSortIndicesForTemporal, NoNull) {
+ auto type = this->GetType();
+ const std::vector<SortKey> sort_keys{SortKey("a", SortOrder::Ascending),
+ SortKey("b", SortOrder::Descending)};
+ auto table = TableFromJSON(schema({
+ {field("a", type)},
+ {field("b", type)},
+ }),
+ {R"([{"a": 0, "b": 5},
+ {"a": 1, "b": 3},
+ {"a": 3, "b": 0},
+ {"a": 2, "b": 1},
+ {"a": 1, "b": 3},
+ {"a": 5, "b": 0},
+ {"a": 0, "b": 4},
+ {"a": 1, "b": 2}
+ ])"});
+ for (auto null_placement : AllNullPlacements()) {
+ SortOptions options(sort_keys, null_placement);
+ AssertSortIndices(table, options, "[0, 6, 1, 4, 7, 3, 2, 5]");
+ }
+}
+
+// For random table tests.
+using RandomParam = std::tuple<std::string, int, double>;
+
+class TestTableSortIndicesRandom : public testing::TestWithParam<RandomParam> {
+ // Compares two records in a column
+ class ColumnComparator : public TypeVisitor {
+ public:
+ ColumnComparator(SortOrder order, NullPlacement null_placement)
+ : order_(order), null_placement_(null_placement) {}
+
+ int operator()(const Array& left, const Array& right, uint64_t lhs, uint64_t rhs) {
+ left_ = &left;
+ right_ = &right;
+ lhs_ = lhs;
+ rhs_ = rhs;
+ ARROW_CHECK_OK(left.type()->Accept(this));
+ return compared_;
+ }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE##Type& type) override { \
+ compared_ = CompareType<TYPE##Type>(); \
+ return Status::OK(); \
+ }
+
+ VISIT(Boolean)
+ VISIT(Int8)
+ VISIT(Int16)
+ VISIT(Int32)
+ VISIT(Int64)
+ VISIT(UInt8)
+ VISIT(UInt16)
+ VISIT(UInt32)
+ VISIT(UInt64)
+ VISIT(Float)
+ VISIT(Double)
+ VISIT(String)
+ VISIT(LargeString)
+ VISIT(Decimal128)
+ VISIT(Decimal256)
+
+#undef VISIT
+
+ template <typename Type>
+ int CompareType() {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ ThreeWayComparator<ArrayType> three_way{order_, null_placement_};
+ return three_way(checked_cast<const ArrayType&>(*left_),
+ checked_cast<const ArrayType&>(*right_), lhs_, rhs_);
+ }
+
+ const SortOrder order_;
+ const NullPlacement null_placement_;
+ const Array* left_;
+ const Array* right_;
+ uint64_t lhs_;
+ uint64_t rhs_;
+ int compared_;
+ };
+
+ // Compares two records in the same table.
+ class Comparator {
+ public:
+ Comparator(const Table& table, const SortOptions& options) : options_(options) {
+ for (const auto& sort_key : options_.sort_keys) {
+ sort_columns_.emplace_back(table.GetColumnByName(sort_key.name).get(),
+ sort_key.order);
+ }
+ }
+
+ // Return true if the left record is less or equals to the right record,
+ // false otherwise.
+ bool operator()(uint64_t lhs, uint64_t rhs) {
+ for (const auto& pair : sort_columns_) {
+ ColumnComparator comparator(pair.second, options_.null_placement);
+ const auto& chunked_array = *pair.first;
+ int64_t lhs_index = 0, rhs_index = 0;
+ const Array* lhs_array = FindTargetArray(chunked_array, lhs, &lhs_index);
+ const Array* rhs_array = FindTargetArray(chunked_array, rhs, &rhs_index);
+ int compared = comparator(*lhs_array, *rhs_array, lhs_index, rhs_index);
+ if (compared != 0) {
+ return compared < 0;
+ }
+ }
+ return lhs < rhs;
+ }
+
+ // Find the target chunk and index in the target chunk from an
+ // index in chunked array.
+ const Array* FindTargetArray(const ChunkedArray& chunked_array, int64_t i,
+ int64_t* chunk_index) {
+ int64_t offset = 0;
+ for (const auto& chunk : chunked_array.chunks()) {
+ if (i < offset + chunk->length()) {
+ *chunk_index = i - offset;
+ return chunk.get();
+ }
+ offset += chunk->length();
+ }
+ return nullptr;
+ }
+
+ const SortOptions& options_;
+ std::vector<std::pair<const ChunkedArray*, SortOrder>> sort_columns_;
+ };
+
+ public:
+ // Validates the sorted indices are really sorted.
+ void Validate(const Table& table, const SortOptions& options, UInt64Array& offsets) {
+ ValidateOutput(offsets);
+ Comparator comparator{table, options};
+ for (int i = 1; i < table.num_rows(); i++) {
+ uint64_t lhs = offsets.Value(i - 1);
+ uint64_t rhs = offsets.Value(i);
+ if (!comparator(lhs, rhs)) {
+ std::stringstream ss;
+ ss << "Rows not ordered at consecutive sort indices:";
+ ss << "\nFirst row (index = " << lhs << "): ";
+ PrintRow(table, lhs, &ss);
+ ss << "\nSecond row (index = " << rhs << "): ";
+ PrintRow(table, rhs, &ss);
+ FAIL() << ss.str();
+ }
+ }
+ }
+
+ void PrintRow(const Table& table, uint64_t index, std::ostream* os) {
+ *os << "{";
+ const auto& columns = table.columns();
+ for (size_t i = 0; i < columns.size(); ++i) {
+ if (i != 0) {
+ *os << ", ";
+ }
+ ASSERT_OK_AND_ASSIGN(auto scal, columns[i]->GetScalar(index));
+ *os << scal->ToString();
+ }
+ *os << "}";
+ }
+};
+
+TEST_P(TestTableSortIndicesRandom, Sort) {
+ const auto first_sort_key_name = std::get<0>(GetParam());
+ const auto n_sort_keys = std::get<1>(GetParam());
+ const auto null_probability = std::get<2>(GetParam());
+ const auto nan_probability = (1.0 - null_probability) / 4;
+ const auto seed = 0x61549225;
+
+ ARROW_SCOPED_TRACE("n_sort_keys = ", n_sort_keys);
+ ARROW_SCOPED_TRACE("null_probability = ", null_probability);
+
+ ::arrow::random::RandomArrayGenerator rng(seed);
+
+ // Of these, "uint8", "boolean" and "string" should have many duplicates
+ const FieldVector fields = {
+ {field("uint8", uint8())},
+ {field("int16", int16())},
+ {field("int32", int32())},
+ {field("uint64", uint64())},
+ {field("float", float32())},
+ {field("boolean", boolean())},
+ {field("string", utf8())},
+ {field("large_string", large_utf8())},
+ {field("decimal128", decimal128(25, 3))},
+ {field("decimal256", decimal256(42, 6))},
+ };
+ const auto schema = ::arrow::schema(fields);
+ const int64_t length = 80;
+
+ using ArrayFactory = std::function<std::shared_ptr<Array>(int64_t length)>;
+
+ std::vector<ArrayFactory> column_factories{
+ [&](int64_t length) { return rng.UInt8(length, 0, 10, null_probability); },
+ [&](int64_t length) {
+ return rng.Int16(length, -1000, 12000, /*null_probability=*/0.0);
+ },
+ [&](int64_t length) {
+ return rng.Int32(length, -123456789, 987654321, null_probability);
+ },
+ [&](int64_t length) {
+ return rng.UInt64(length, 1, 1234567890123456789ULL, /*null_probability=*/0.0);
+ },
+ [&](int64_t length) {
+ return rng.Float32(length, -1.0f, 1.0f, null_probability, nan_probability);
+ },
+ [&](int64_t length) {
+ return rng.Boolean(length, /*true_probability=*/0.3, null_probability);
+ },
+ [&](int64_t length) {
+ if (length > 0) {
+ return rng.StringWithRepeats(length, /*unique=*/1 + length / 10,
+ /*min_length=*/5,
+ /*max_length=*/15, null_probability);
+ } else {
+ return *MakeArrayOfNull(utf8(), 0);
+ }
+ },
+ [&](int64_t length) {
+ return rng.LargeString(length, /*min_length=*/5, /*max_length=*/15,
+ /*null_probability=*/0.0);
+ },
+ [&](int64_t length) {
+ return rng.Decimal128(fields[8]->type(), length, null_probability);
+ },
+ [&](int64_t length) {
+ return rng.Decimal256(fields[9]->type(), length, /*null_probability=*/0.0);
+ },
+ };
+
+ // Generate random sort keys, making sure no column is included twice
+ std::default_random_engine engine(seed);
+ std::uniform_int_distribution<> distribution(0);
+
+ auto generate_order = [&]() {
+ return (distribution(engine) & 1) ? SortOrder::Ascending : SortOrder::Descending;
+ };
+
+ std::vector<SortKey> sort_keys;
+ sort_keys.reserve(fields.size());
+ for (const auto& field : fields) {
+ if (field->name() != first_sort_key_name) {
+ sort_keys.emplace_back(field->name(), generate_order());
+ }
+ }
+ std::shuffle(sort_keys.begin(), sort_keys.end(), engine);
+ sort_keys.emplace(sort_keys.begin(), first_sort_key_name, generate_order());
+ sort_keys.erase(sort_keys.begin() + n_sort_keys, sort_keys.end());
+ ASSERT_EQ(sort_keys.size(), n_sort_keys);
+
+ std::stringstream ss;
+ for (const auto& sort_key : sort_keys) {
+ ss << sort_key.name << (sort_key.order == SortOrder::Ascending ? " ASC" : " DESC");
+ ss << ", ";
+ }
+ ARROW_SCOPED_TRACE("sort_keys = ", ss.str());
+
+ SortOptions options(sort_keys);
+
+ // Test with different, heterogenous table chunkings
+ for (const int64_t max_num_chunks : {1, 3, 15}) {
+ ARROW_SCOPED_TRACE("Table sorting: max chunks per column = ", max_num_chunks);
+ std::uniform_int_distribution<int64_t> num_chunk_dist(1 + max_num_chunks / 2,
+ max_num_chunks);
+ ChunkedArrayVector columns;
+ columns.reserve(fields.size());
+
+ // Chunk each column independently, and make sure they consist of
+ // physically non-contiguous chunks.
+ for (const auto& factory : column_factories) {
+ const int64_t num_chunks = num_chunk_dist(engine);
+ ArrayVector chunks(num_chunks);
+ const auto offsets =
+ checked_pointer_cast<Int32Array>(rng.Offsets(num_chunks + 1, 0, length));
+ for (int64_t i = 0; i < num_chunks; ++i) {
+ const auto chunk_len = offsets->Value(i + 1) - offsets->Value(i);
+ chunks[i] = factory(chunk_len);
+ }
+ columns.push_back(std::make_shared<ChunkedArray>(std::move(chunks)));
+ ASSERT_EQ(columns.back()->length(), length);
+ }
+
+ auto table = Table::Make(schema, std::move(columns));
+ for (auto null_placement : AllNullPlacements()) {
+ ARROW_SCOPED_TRACE("null_placement = ", null_placement);
+ options.null_placement = null_placement;
+ ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*table), options));
+ Validate(*table, options, *checked_pointer_cast<UInt64Array>(offsets));
+ }
+ }
+
+ // Also validate RecordBatch sorting
+ ARROW_SCOPED_TRACE("Record batch sorting");
+ ArrayVector columns;
+ columns.reserve(fields.size());
+ for (const auto& factory : column_factories) {
+ columns.push_back(factory(length));
+ }
+ auto batch = RecordBatch::Make(schema, length, std::move(columns));
+ ASSERT_OK(batch->ValidateFull());
+ ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches(schema, {batch}));
+
+ for (auto null_placement : AllNullPlacements()) {
+ ARROW_SCOPED_TRACE("null_placement = ", null_placement);
+ options.null_placement = null_placement;
+ ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(batch), options));
+ Validate(*table, options, *checked_pointer_cast<UInt64Array>(offsets));
+ }
+}
+
+// Some first keys will have duplicates, others not
+static const auto first_sort_keys = testing::Values("uint8", "int16", "uint64", "float",
+ "boolean", "string", "decimal128");
+
+// Different numbers of sort keys may trigger different algorithms
+static const auto num_sort_keys = testing::Values(1, 3, 7, 9);
+
+INSTANTIATE_TEST_SUITE_P(NoNull, TestTableSortIndicesRandom,
+ testing::Combine(first_sort_keys, num_sort_keys,
+ testing::Values(0.0)));
+
+INSTANTIATE_TEST_SUITE_P(SomeNulls, TestTableSortIndicesRandom,
+ testing::Combine(first_sort_keys, num_sort_keys,
+ testing::Values(0.1, 0.5)));
+
+INSTANTIATE_TEST_SUITE_P(AllNull, TestTableSortIndicesRandom,
+ testing::Combine(first_sort_keys, num_sort_keys,
+ testing::Values(1.0)));
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/vector_topk_benchmark.cc b/src/arrow/cpp/src/arrow/compute/kernels/vector_topk_benchmark.cc
new file mode 100644
index 000000000..3f89eb6be
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/kernels/vector_topk_benchmark.cc
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/benchmark_util.h"
+
+namespace arrow {
+namespace compute {
+constexpr auto kSeed = 0x0ff1ce;
+
+static void SelectKBenchmark(benchmark::State& state,
+ const std::shared_ptr<Array>& values, int64_t k) {
+ for (auto _ : state) {
+ ABORT_NOT_OK(SelectKUnstable(*values, SelectKOptions::TopKDefault(k)).status());
+ }
+ state.SetItemsProcessed(state.iterations() * values->length());
+}
+
+static void SelectKInt64(benchmark::State& state) {
+ RegressionArgs args(state);
+
+ const int64_t array_size = args.size / sizeof(int64_t);
+ auto rand = random::RandomArrayGenerator(kSeed);
+
+ auto min = std::numeric_limits<int64_t>::min();
+ auto max = std::numeric_limits<int64_t>::max();
+ auto values = rand.Int64(array_size, min, max, args.null_proportion);
+
+ SelectKBenchmark(state, values, array_size / 8);
+}
+
+BENCHMARK(SelectKInt64)
+ ->Apply(RegressionSetArgs)
+ ->Args({1 << 20, 100})
+ ->Args({1 << 23, 100})
+ ->MinTime(1.0)
+ ->Unit(benchmark::TimeUnit::kNanosecond);
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/registry.cc b/src/arrow/cpp/src/arrow/compute/registry.cc
new file mode 100644
index 000000000..bd303ea42
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/registry.cc
@@ -0,0 +1,200 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/registry.h"
+
+#include <algorithm>
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+
+#include "arrow/compute/function.h"
+#include "arrow/compute/function_internal.h"
+#include "arrow/compute/registry_internal.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace compute {
+
+class FunctionRegistry::FunctionRegistryImpl {
+ public:
+ Status AddFunction(std::shared_ptr<Function> function, bool allow_overwrite) {
+ RETURN_NOT_OK(function->Validate());
+
+ std::lock_guard<std::mutex> mutation_guard(lock_);
+
+ const std::string& name = function->name();
+ auto it = name_to_function_.find(name);
+ if (it != name_to_function_.end() && !allow_overwrite) {
+ return Status::KeyError("Already have a function registered with name: ", name);
+ }
+ name_to_function_[name] = std::move(function);
+ return Status::OK();
+ }
+
+ Status AddAlias(const std::string& target_name, const std::string& source_name) {
+ std::lock_guard<std::mutex> mutation_guard(lock_);
+
+ auto it = name_to_function_.find(source_name);
+ if (it == name_to_function_.end()) {
+ return Status::KeyError("No function registered with name: ", source_name);
+ }
+ name_to_function_[target_name] = it->second;
+ return Status::OK();
+ }
+
+ Status AddFunctionOptionsType(const FunctionOptionsType* options_type,
+ bool allow_overwrite = false) {
+ std::lock_guard<std::mutex> mutation_guard(lock_);
+
+ const std::string name = options_type->type_name();
+ auto it = name_to_options_type_.find(name);
+ if (it != name_to_options_type_.end() && !allow_overwrite) {
+ return Status::KeyError(
+ "Already have a function options type registered with name: ", name);
+ }
+ name_to_options_type_[name] = options_type;
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<Function>> GetFunction(const std::string& name) const {
+ auto it = name_to_function_.find(name);
+ if (it == name_to_function_.end()) {
+ return Status::KeyError("No function registered with name: ", name);
+ }
+ return it->second;
+ }
+
+ std::vector<std::string> GetFunctionNames() const {
+ std::vector<std::string> results;
+ for (auto it : name_to_function_) {
+ results.push_back(it.first);
+ }
+ std::sort(results.begin(), results.end());
+ return results;
+ }
+
+ Result<const FunctionOptionsType*> GetFunctionOptionsType(
+ const std::string& name) const {
+ auto it = name_to_options_type_.find(name);
+ if (it == name_to_options_type_.end()) {
+ return Status::KeyError("No function options type registered with name: ", name);
+ }
+ return it->second;
+ }
+
+ int num_functions() const { return static_cast<int>(name_to_function_.size()); }
+
+ private:
+ std::mutex lock_;
+ std::unordered_map<std::string, std::shared_ptr<Function>> name_to_function_;
+ std::unordered_map<std::string, const FunctionOptionsType*> name_to_options_type_;
+};
+
+std::unique_ptr<FunctionRegistry> FunctionRegistry::Make() {
+ return std::unique_ptr<FunctionRegistry>(new FunctionRegistry());
+}
+
+FunctionRegistry::FunctionRegistry() { impl_.reset(new FunctionRegistryImpl()); }
+
+FunctionRegistry::~FunctionRegistry() {}
+
+Status FunctionRegistry::AddFunction(std::shared_ptr<Function> function,
+ bool allow_overwrite) {
+ return impl_->AddFunction(std::move(function), allow_overwrite);
+}
+
+Status FunctionRegistry::AddAlias(const std::string& target_name,
+ const std::string& source_name) {
+ return impl_->AddAlias(target_name, source_name);
+}
+
+Status FunctionRegistry::AddFunctionOptionsType(const FunctionOptionsType* options_type,
+ bool allow_overwrite) {
+ return impl_->AddFunctionOptionsType(options_type, allow_overwrite);
+}
+
+Result<std::shared_ptr<Function>> FunctionRegistry::GetFunction(
+ const std::string& name) const {
+ return impl_->GetFunction(name);
+}
+
+std::vector<std::string> FunctionRegistry::GetFunctionNames() const {
+ return impl_->GetFunctionNames();
+}
+
+Result<const FunctionOptionsType*> FunctionRegistry::GetFunctionOptionsType(
+ const std::string& name) const {
+ return impl_->GetFunctionOptionsType(name);
+}
+
+int FunctionRegistry::num_functions() const { return impl_->num_functions(); }
+
+namespace internal {
+
+static std::unique_ptr<FunctionRegistry> CreateBuiltInRegistry() {
+ auto registry = FunctionRegistry::Make();
+
+ // Scalar functions
+ RegisterScalarArithmetic(registry.get());
+ RegisterScalarBoolean(registry.get());
+ RegisterScalarCast(registry.get());
+ RegisterScalarComparison(registry.get());
+ RegisterScalarIfElse(registry.get());
+ RegisterScalarNested(registry.get());
+ RegisterScalarSetLookup(registry.get());
+ RegisterScalarStringAscii(registry.get());
+ RegisterScalarTemporalBinary(registry.get());
+ RegisterScalarTemporalUnary(registry.get());
+ RegisterScalarValidity(registry.get());
+
+ RegisterScalarOptions(registry.get());
+
+ // Vector functions
+ RegisterVectorArraySort(registry.get());
+ RegisterVectorHash(registry.get());
+ RegisterVectorNested(registry.get());
+ RegisterVectorReplace(registry.get());
+ RegisterVectorSelection(registry.get());
+ RegisterVectorSort(registry.get());
+
+ RegisterVectorOptions(registry.get());
+
+ // Aggregate functions
+ RegisterHashAggregateBasic(registry.get());
+ RegisterScalarAggregateBasic(registry.get());
+ RegisterScalarAggregateMode(registry.get());
+ RegisterScalarAggregateQuantile(registry.get());
+ RegisterScalarAggregateTDigest(registry.get());
+ RegisterScalarAggregateVariance(registry.get());
+
+ RegisterAggregateOptions(registry.get());
+
+ return registry;
+}
+
+} // namespace internal
+
+FunctionRegistry* GetFunctionRegistry() {
+ static auto g_registry = internal::CreateBuiltInRegistry();
+ return g_registry.get();
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/registry.h b/src/arrow/cpp/src/arrow/compute/registry.h
new file mode 100644
index 000000000..e83036db6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/registry.h
@@ -0,0 +1,93 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// NOTE: API is EXPERIMENTAL and will change without going through a
+// deprecation cycle
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace compute {
+
+class Function;
+class FunctionOptionsType;
+
+/// \brief A mutable central function registry for built-in functions as well
+/// as user-defined functions. Functions are implementations of
+/// arrow::compute::Function.
+///
+/// Generally, each function contains kernels which are implementations of a
+/// function for a specific argument signature. After looking up a function in
+/// the registry, one can either execute it eagerly with Function::Execute or
+/// use one of the function's dispatch methods to pick a suitable kernel for
+/// lower-level function execution.
+class ARROW_EXPORT FunctionRegistry {
+ public:
+ ~FunctionRegistry();
+
+ /// \brief Construct a new registry. Most users only need to use the global
+ /// registry
+ static std::unique_ptr<FunctionRegistry> Make();
+
+ /// \brief Add a new function to the registry. Returns Status::KeyError if a
+ /// function with the same name is already registered
+ Status AddFunction(std::shared_ptr<Function> function, bool allow_overwrite = false);
+
+ /// \brief Add aliases for the given function name. Returns Status::KeyError if the
+ /// function with the given name is not registered
+ Status AddAlias(const std::string& target_name, const std::string& source_name);
+
+ /// \brief Add a new function options type to the registry. Returns Status::KeyError if
+ /// a function options type with the same name is already registered
+ Status AddFunctionOptionsType(const FunctionOptionsType* options_type,
+ bool allow_overwrite = false);
+
+ /// \brief Retrieve a function by name from the registry
+ Result<std::shared_ptr<Function>> GetFunction(const std::string& name) const;
+
+ /// \brief Return vector of all entry names in the registry. Helpful for
+ /// displaying a manifest of available functions
+ std::vector<std::string> GetFunctionNames() const;
+
+ /// \brief Retrieve a function options type by name from the registry
+ Result<const FunctionOptionsType*> GetFunctionOptionsType(
+ const std::string& name) const;
+
+ /// \brief The number of currently registered functions
+ int num_functions() const;
+
+ private:
+ FunctionRegistry();
+
+ // Use PIMPL pattern to not have std::unordered_map here
+ class FunctionRegistryImpl;
+ std::unique_ptr<FunctionRegistryImpl> impl_;
+};
+
+/// \brief Return the process-global function registry
+ARROW_EXPORT FunctionRegistry* GetFunctionRegistry();
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/registry_internal.h b/src/arrow/cpp/src/arrow/compute/registry_internal.h
new file mode 100644
index 000000000..98f61185f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/registry_internal.h
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+namespace arrow {
+namespace compute {
+
+class FunctionRegistry;
+
+namespace internal {
+
+// Built-in scalar / elementwise functions
+void RegisterScalarArithmetic(FunctionRegistry* registry);
+void RegisterScalarBoolean(FunctionRegistry* registry);
+void RegisterScalarCast(FunctionRegistry* registry);
+void RegisterScalarComparison(FunctionRegistry* registry);
+void RegisterScalarIfElse(FunctionRegistry* registry);
+void RegisterScalarNested(FunctionRegistry* registry);
+void RegisterScalarSetLookup(FunctionRegistry* registry);
+void RegisterScalarStringAscii(FunctionRegistry* registry);
+void RegisterScalarTemporalBinary(FunctionRegistry* registry);
+void RegisterScalarTemporalUnary(FunctionRegistry* registry);
+void RegisterScalarValidity(FunctionRegistry* registry);
+
+void RegisterScalarOptions(FunctionRegistry* registry);
+
+// Vector functions
+void RegisterVectorArraySort(FunctionRegistry* registry);
+void RegisterVectorHash(FunctionRegistry* registry);
+void RegisterVectorNested(FunctionRegistry* registry);
+void RegisterVectorReplace(FunctionRegistry* registry);
+void RegisterVectorSelection(FunctionRegistry* registry);
+void RegisterVectorSort(FunctionRegistry* registry);
+
+void RegisterVectorOptions(FunctionRegistry* registry);
+
+// Aggregate functions
+void RegisterHashAggregateBasic(FunctionRegistry* registry);
+void RegisterScalarAggregateBasic(FunctionRegistry* registry);
+void RegisterScalarAggregateMode(FunctionRegistry* registry);
+void RegisterScalarAggregateQuantile(FunctionRegistry* registry);
+void RegisterScalarAggregateTDigest(FunctionRegistry* registry);
+void RegisterScalarAggregateVariance(FunctionRegistry* registry);
+
+void RegisterAggregateOptions(FunctionRegistry* registry);
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/registry_test.cc b/src/arrow/cpp/src/arrow/compute/registry_test.cc
new file mode 100644
index 000000000..e1e0d5231
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/registry_test.cc
@@ -0,0 +1,87 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/compute/function.h"
+#include "arrow/compute/registry.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace compute {
+
+class TestRegistry : public ::testing::Test {
+ public:
+ void SetUp() { registry_ = FunctionRegistry::Make(); }
+
+ protected:
+ std::unique_ptr<FunctionRegistry> registry_;
+};
+
+TEST_F(TestRegistry, CreateBuiltInRegistry) {
+ // This does DCHECK_OK internally for now so this will fail in debug builds
+ // if there is a problem initializing the global function registry
+ FunctionRegistry* registry = GetFunctionRegistry();
+ ARROW_UNUSED(registry);
+}
+
+TEST_F(TestRegistry, Basics) {
+ ASSERT_EQ(0, registry_->num_functions());
+
+ std::shared_ptr<Function> func =
+ std::make_shared<ScalarFunction>("f1", Arity::Unary(), /*doc=*/nullptr);
+ ASSERT_OK(registry_->AddFunction(func));
+ ASSERT_EQ(1, registry_->num_functions());
+
+ func = std::make_shared<VectorFunction>("f0", Arity::Binary(), /*doc=*/nullptr);
+ ASSERT_OK(registry_->AddFunction(func));
+ ASSERT_EQ(2, registry_->num_functions());
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<const Function> f1, registry_->GetFunction("f1"));
+ ASSERT_EQ("f1", f1->name());
+
+ // Non-existent function
+ ASSERT_RAISES(KeyError, registry_->GetFunction("f2"));
+
+ // Try adding a function with name collision
+ func = std::make_shared<ScalarAggregateFunction>("f1", Arity::Unary(), /*doc=*/nullptr);
+ ASSERT_RAISES(KeyError, registry_->AddFunction(func));
+
+ // Allow overwriting by flag
+ ASSERT_OK(registry_->AddFunction(func, /*allow_overwrite=*/true));
+ ASSERT_OK_AND_ASSIGN(f1, registry_->GetFunction("f1"));
+ ASSERT_EQ(Function::SCALAR_AGGREGATE, f1->kind());
+
+ std::vector<std::string> expected_names = {"f0", "f1"};
+ ASSERT_EQ(expected_names, registry_->GetFunctionNames());
+
+ // Aliases
+ ASSERT_RAISES(KeyError, registry_->AddAlias("f33", "f3"));
+ ASSERT_OK(registry_->AddAlias("f11", "f1"));
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<const Function> f2, registry_->GetFunction("f11"));
+ ASSERT_EQ(func, f2);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/compute/type_fwd.h b/src/arrow/cpp/src/arrow/compute/type_fwd.h
new file mode 100644
index 000000000..127929ced
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/compute/type_fwd.h
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+namespace arrow {
+
+struct Datum;
+struct ValueDescr;
+
+namespace compute {
+
+class Function;
+class FunctionOptions;
+
+class CastOptions;
+
+struct ExecBatch;
+class ExecContext;
+class KernelContext;
+
+struct Kernel;
+struct ScalarKernel;
+struct ScalarAggregateKernel;
+struct VectorKernel;
+
+struct KernelState;
+
+class Expression;
+class ExecNode;
+class ExecPlan;
+class ExecNodeOptions;
+class ExecFactoryRegistry;
+
+} // namespace compute
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/config.cc b/src/arrow/cpp/src/arrow/config.cc
new file mode 100644
index 000000000..b93f20716
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/config.cc
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/config.h"
+
+#include <cstdint>
+
+#include "arrow/util/config.h"
+#include "arrow/util/cpu_info.h"
+
+namespace arrow {
+
+using internal::CpuInfo;
+
+namespace {
+
+const BuildInfo kBuildInfo = {
+ // clang-format off
+ ARROW_VERSION,
+ ARROW_VERSION_MAJOR,
+ ARROW_VERSION_MINOR,
+ ARROW_VERSION_PATCH,
+ ARROW_VERSION_STRING,
+ ARROW_SO_VERSION,
+ ARROW_FULL_SO_VERSION,
+ ARROW_CXX_COMPILER_ID,
+ ARROW_CXX_COMPILER_VERSION,
+ ARROW_CXX_COMPILER_FLAGS,
+ ARROW_GIT_ID,
+ ARROW_GIT_DESCRIPTION,
+ ARROW_PACKAGE_KIND,
+ // clang-format on
+};
+
+template <typename QueryFlagFunction>
+std::string MakeSimdLevelString(QueryFlagFunction&& query_flag) {
+ if (query_flag(CpuInfo::AVX512)) {
+ return "avx512";
+ } else if (query_flag(CpuInfo::AVX2)) {
+ return "avx2";
+ } else if (query_flag(CpuInfo::AVX)) {
+ return "avx";
+ } else if (query_flag(CpuInfo::SSE4_2)) {
+ return "sse4_2";
+ } else {
+ return "none";
+ }
+}
+
+}; // namespace
+
+const BuildInfo& GetBuildInfo() { return kBuildInfo; }
+
+RuntimeInfo GetRuntimeInfo() {
+ RuntimeInfo info;
+ auto cpu_info = CpuInfo::GetInstance();
+ info.simd_level =
+ MakeSimdLevelString([&](int64_t flags) { return cpu_info->IsSupported(flags); });
+ info.detected_simd_level =
+ MakeSimdLevelString([&](int64_t flags) { return cpu_info->IsDetected(flags); });
+ return info;
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/config.h b/src/arrow/cpp/src/arrow/config.h
new file mode 100644
index 000000000..5ae7e2231
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/config.h
@@ -0,0 +1,72 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+
+#include "arrow/util/config.h" // IWYU pragma: export
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+struct BuildInfo {
+ /// The packed version number, e.g. 1002003 (decimal) for Arrow 1.2.3
+ int version;
+ /// The "major" version number, e.g. 1 for Arrow 1.2.3
+ int version_major;
+ /// The "minor" version number, e.g. 2 for Arrow 1.2.3
+ int version_minor;
+ /// The "patch" version number, e.g. 3 for Arrow 1.2.3
+ int version_patch;
+ /// The version string, e.g. "1.2.3"
+ std::string version_string;
+ std::string so_version;
+ std::string full_so_version;
+ std::string compiler_id;
+ std::string compiler_version;
+ std::string compiler_flags;
+ std::string git_id;
+ std::string git_description;
+ std::string package_kind;
+};
+
+struct RuntimeInfo {
+ /// The enabled SIMD level
+ ///
+ /// This can be less than `detected_simd_level` if the ARROW_USER_SIMD_LEVEL
+ /// environment variable is set to another value.
+ std::string simd_level;
+
+ /// The SIMD level available on the OS and CPU
+ std::string detected_simd_level;
+};
+
+/// \brief Get runtime build info.
+///
+/// The returned values correspond to exact loaded version of the Arrow library,
+/// rather than the values frozen at application compile-time through the `ARROW_*`
+/// preprocessor definitions.
+ARROW_EXPORT
+const BuildInfo& GetBuildInfo();
+
+/// \brief Get runtime info.
+///
+ARROW_EXPORT
+RuntimeInfo GetRuntimeInfo();
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/CMakeLists.txt b/src/arrow/cpp/src/arrow/csv/CMakeLists.txt
new file mode 100644
index 000000000..561faf1b5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/CMakeLists.txt
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set(CSV_TEST_SRCS
+ chunker_test.cc
+ column_builder_test.cc
+ column_decoder_test.cc
+ converter_test.cc
+ parser_test.cc
+ reader_test.cc)
+
+# Writer depends on compute's cast functionality
+if(ARROW_COMPUTE)
+ list(APPEND CSV_TEST_SRCS writer_test.cc)
+endif()
+
+add_arrow_test(csv-test SOURCES ${CSV_TEST_SRCS})
+
+add_arrow_benchmark(converter_benchmark PREFIX "arrow-csv")
+add_arrow_benchmark(parser_benchmark PREFIX "arrow-csv")
+
+arrow_install_all_headers("arrow/csv")
+
+# pkg-config support
+arrow_add_pkg_config("arrow-csv")
diff --git a/src/arrow/cpp/src/arrow/csv/api.h b/src/arrow/cpp/src/arrow/csv/api.h
new file mode 100644
index 000000000..7bf393157
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/api.h
@@ -0,0 +1,26 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/csv/options.h"
+#include "arrow/csv/reader.h"
+
+// The writer depends on compute module for casting.
+#ifdef ARROW_COMPUTE
+#include "arrow/csv/writer.h"
+#endif
diff --git a/src/arrow/cpp/src/arrow/csv/arrow-csv.pc.in b/src/arrow/cpp/src/arrow/csv/arrow-csv.pc.in
new file mode 100644
index 000000000..9c69c6923
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/arrow-csv.pc.in
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+Name: Apache Arrow CSV
+Description: CSV reader module for Apache Arrow
+Version: @ARROW_VERSION@
+Requires: arrow
diff --git a/src/arrow/cpp/src/arrow/csv/chunker.cc b/src/arrow/cpp/src/arrow/csv/chunker.cc
new file mode 100644
index 000000000..12bb03a88
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/chunker.cc
@@ -0,0 +1,311 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/csv/chunker.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace csv {
+
+namespace {
+
+// NOTE: csvmonkey (https://github.com/dw/csvmonkey) has optimization ideas
+
+template <bool quoting, bool escaping>
+class Lexer {
+ public:
+ enum State {
+ FIELD_START,
+ IN_FIELD,
+ AT_ESCAPE,
+ IN_QUOTED_FIELD,
+ AT_QUOTED_QUOTE,
+ AT_QUOTED_ESCAPE
+ };
+
+ explicit Lexer(const ParseOptions& options) : options_(options) {
+ DCHECK_EQ(quoting, options_.quoting);
+ DCHECK_EQ(escaping, options_.escaping);
+ }
+
+ const char* ReadLine(const char* data, const char* data_end) {
+ // The parsing state machine
+ char c;
+ DCHECK_GT(data_end - data, 0);
+ if (ARROW_PREDICT_TRUE(state_ == FIELD_START)) {
+ goto FieldStart;
+ }
+ switch (state_) {
+ case FIELD_START:
+ goto FieldStart;
+ case IN_FIELD:
+ goto InField;
+ case AT_ESCAPE:
+ // will never reach here if escaping = false
+ // just to hint the compiler to remove dead code
+ if (!escaping) return nullptr;
+ goto AtEscape;
+ case IN_QUOTED_FIELD:
+ if (!quoting) return nullptr;
+ goto InQuotedField;
+ case AT_QUOTED_QUOTE:
+ if (!quoting) return nullptr;
+ goto AtQuotedQuote;
+ case AT_QUOTED_ESCAPE:
+ if (!quoting) return nullptr;
+ goto AtQuotedEscape;
+ }
+
+ FieldStart:
+ if (!quoting) {
+ goto InField;
+ } else {
+ // At the start of a field
+ if (ARROW_PREDICT_FALSE(data == data_end)) {
+ state_ = FIELD_START;
+ goto AbortLine;
+ }
+ // Quoting is only recognized at start of field
+ if (*data == options_.quote_char) {
+ data++;
+ goto InQuotedField;
+ } else {
+ goto InField;
+ }
+ }
+
+ InField:
+ // Inside a non-quoted part of a field
+ if (ARROW_PREDICT_FALSE(data == data_end)) {
+ state_ = IN_FIELD;
+ goto AbortLine;
+ }
+ c = *data++;
+ if (escaping && ARROW_PREDICT_FALSE(c == options_.escape_char)) {
+ if (ARROW_PREDICT_FALSE(data == data_end)) {
+ state_ = AT_ESCAPE;
+ goto AbortLine;
+ }
+ data++;
+ goto InField;
+ }
+ if (ARROW_PREDICT_FALSE(c == '\r')) {
+ if (ARROW_PREDICT_TRUE(data != data_end) && *data == '\n') {
+ data++;
+ }
+ goto LineEnd;
+ }
+ if (ARROW_PREDICT_FALSE(c == '\n')) {
+ goto LineEnd;
+ }
+ // treat delimiter as a normal token if quoting is disabled
+ if (ARROW_PREDICT_FALSE(quoting && c == options_.delimiter)) {
+ goto FieldEnd;
+ }
+ goto InField;
+
+ AtEscape:
+ // Coming here if last block ended on a non-quoted escape
+ data++;
+ goto InField;
+
+ InQuotedField:
+ // Inside a quoted part of a field
+ if (ARROW_PREDICT_FALSE(data == data_end)) {
+ state_ = IN_QUOTED_FIELD;
+ goto AbortLine;
+ }
+ c = *data++;
+ if (escaping && ARROW_PREDICT_FALSE(c == options_.escape_char)) {
+ if (ARROW_PREDICT_FALSE(data == data_end)) {
+ state_ = AT_QUOTED_ESCAPE;
+ goto AbortLine;
+ }
+ data++;
+ goto InQuotedField;
+ }
+ if (ARROW_PREDICT_FALSE(c == options_.quote_char)) {
+ if (ARROW_PREDICT_FALSE(data == data_end)) {
+ state_ = AT_QUOTED_QUOTE;
+ goto AbortLine;
+ }
+ if (options_.double_quote && *data == options_.quote_char) {
+ // Double-quoting
+ data++;
+ } else {
+ // End of single-quoting
+ goto InField;
+ }
+ }
+ goto InQuotedField;
+
+ AtQuotedEscape:
+ // Coming here if last block ended on a quoted escape
+ data++;
+ goto InQuotedField;
+
+ AtQuotedQuote:
+ // Coming here if last block ended on a quoted quote
+ if (options_.double_quote && *data == options_.quote_char) {
+ // Double-quoting
+ data++;
+ goto InQuotedField;
+ } else {
+ // End of single-quoting
+ goto InField;
+ }
+
+ FieldEnd:
+ // At the end of a field
+ goto FieldStart;
+
+ LineEnd:
+ state_ = FIELD_START;
+ return data;
+
+ AbortLine:
+ // Truncated line
+ return nullptr;
+ }
+
+ protected:
+ const ParseOptions& options_;
+ State state_ = FIELD_START;
+};
+
+// A BoundaryFinder implementation that assumes CSV cells can contain raw newlines,
+// and uses actual CSV lexing to delimit them.
+template <bool quoting, bool escaping>
+class LexingBoundaryFinder : public BoundaryFinder {
+ public:
+ explicit LexingBoundaryFinder(ParseOptions options) : options_(std::move(options)) {}
+
+ Status FindFirst(util::string_view partial, util::string_view block,
+ int64_t* out_pos) override {
+ Lexer<quoting, escaping> lexer(options_);
+
+ const char* line_end =
+ lexer.ReadLine(partial.data(), partial.data() + partial.size());
+ DCHECK_EQ(line_end, nullptr); // Otherwise `partial` is a whole CSV line
+ line_end = lexer.ReadLine(block.data(), block.data() + block.size());
+
+ if (line_end == nullptr) {
+ // No complete CSV line
+ *out_pos = -1;
+ } else {
+ *out_pos = static_cast<int64_t>(line_end - block.data());
+ DCHECK_GT(*out_pos, 0);
+ }
+ return Status::OK();
+ }
+
+ Status FindLast(util::string_view block, int64_t* out_pos) override {
+ Lexer<quoting, escaping> lexer(options_);
+
+ const char* data = block.data();
+ const char* const data_end = block.data() + block.size();
+
+ while (data < data_end) {
+ const char* line_end = lexer.ReadLine(data, data_end);
+ if (line_end == nullptr) {
+ // Cannot read any further
+ break;
+ }
+ DCHECK_GT(line_end, data);
+ data = line_end;
+ }
+ if (data == block.data()) {
+ // No complete CSV line
+ *out_pos = -1;
+ } else {
+ *out_pos = static_cast<int64_t>(data - block.data());
+ DCHECK_GT(*out_pos, 0);
+ }
+ return Status::OK();
+ }
+
+ Status FindNth(util::string_view partial, util::string_view block, int64_t count,
+ int64_t* out_pos, int64_t* num_found) override {
+ Lexer<quoting, escaping> lexer(options_);
+ int64_t found = 0;
+ const char* data = block.data();
+ const char* const data_end = block.data() + block.size();
+
+ const char* line_end;
+ if (partial.size()) {
+ line_end = lexer.ReadLine(partial.data(), partial.data() + partial.size());
+ DCHECK_EQ(line_end, nullptr); // Otherwise `partial` is a whole CSV line
+ }
+
+ for (; data < data_end && found < count; ++found) {
+ line_end = lexer.ReadLine(data, data_end);
+ if (line_end == nullptr) {
+ // Cannot read any further
+ break;
+ }
+ DCHECK_GT(line_end, data);
+ data = line_end;
+ }
+
+ if (data == block.data()) {
+ // No complete CSV line
+ *out_pos = kNoDelimiterFound;
+ } else {
+ *out_pos = static_cast<int64_t>(data - block.data());
+ }
+ *num_found = found;
+ return Status::OK();
+ }
+
+ protected:
+ ParseOptions options_;
+};
+
+} // namespace
+
+std::unique_ptr<Chunker> MakeChunker(const ParseOptions& options) {
+ std::shared_ptr<BoundaryFinder> delimiter;
+ if (!options.newlines_in_values) {
+ delimiter = MakeNewlineBoundaryFinder();
+ } else {
+ if (options.quoting) {
+ if (options.escaping) {
+ delimiter = std::make_shared<LexingBoundaryFinder<true, true>>(options);
+ } else {
+ delimiter = std::make_shared<LexingBoundaryFinder<true, false>>(options);
+ }
+ } else {
+ if (options.escaping) {
+ delimiter = std::make_shared<LexingBoundaryFinder<false, true>>(options);
+ } else {
+ delimiter = std::make_shared<LexingBoundaryFinder<false, false>>(options);
+ }
+ }
+ }
+ return internal::make_unique<Chunker>(std::move(delimiter));
+}
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/chunker.h b/src/arrow/cpp/src/arrow/csv/chunker.h
new file mode 100644
index 000000000..662b16ec4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/chunker.h
@@ -0,0 +1,36 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/csv/options.h"
+#include "arrow/status.h"
+#include "arrow/util/delimiting.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace csv {
+
+ARROW_EXPORT
+std::unique_ptr<Chunker> MakeChunker(const ParseOptions& options);
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/chunker_test.cc b/src/arrow/cpp/src/arrow/csv/chunker_test.cc
new file mode 100644
index 000000000..07ce5a413
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/chunker_test.cc
@@ -0,0 +1,372 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <memory>
+#include <numeric>
+#include <string>
+
+#include <gtest/gtest.h>
+
+#include "arrow/buffer.h"
+#include "arrow/csv/chunker.h"
+#include "arrow/csv/options.h"
+#include "arrow/csv/test_common.h"
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+namespace csv {
+
+void AssertChunkSize(Chunker& chunker, const std::string& str, uint32_t chunk_size) {
+ std::shared_ptr<Buffer> block, whole, partial;
+ block = std::make_shared<Buffer>(reinterpret_cast<const uint8_t*>(str.data()),
+ static_cast<int64_t>(str.size()));
+ ASSERT_OK(chunker.Process(block, &whole, &partial));
+ ASSERT_EQ(block->size(), whole->size() + partial->size());
+ auto actual_chunk_size = static_cast<uint32_t>(whole->size());
+ ASSERT_EQ(actual_chunk_size, chunk_size);
+}
+
+template <typename IntContainer>
+void AssertChunking(Chunker& chunker, const std::string& str,
+ const IntContainer& expected_lengths) {
+ uint32_t expected_chunk_size;
+
+ // First chunkize whole CSV block
+ expected_chunk_size = static_cast<uint32_t>(
+ std::accumulate(expected_lengths.begin(), expected_lengths.end(), 0ULL));
+ AssertChunkSize(chunker, str, expected_chunk_size);
+
+ // Then chunkize incomplete substrings of the block
+ expected_chunk_size = 0;
+ for (const auto length : expected_lengths) {
+ AssertChunkSize(chunker, str.substr(0, expected_chunk_size + length - 1),
+ expected_chunk_size);
+
+ expected_chunk_size += static_cast<uint32_t>(length);
+ AssertChunkSize(chunker, str.substr(0, expected_chunk_size), expected_chunk_size);
+ }
+}
+
+class BaseChunkerTest : public ::testing::TestWithParam<bool> {
+ protected:
+ void SetUp() override {
+ options_ = ParseOptions::Defaults();
+ options_.newlines_in_values = GetParam();
+ }
+
+ void MakeChunker() { chunker_ = ::arrow::csv::MakeChunker(options_); }
+
+ void AssertSkip(const std::string& str, int64_t count, int64_t rem_count,
+ int64_t rest_size) {
+ MakeChunker();
+ {
+ auto test_count = count;
+ auto partial = std::make_shared<Buffer>("");
+ auto block = std::make_shared<Buffer>(reinterpret_cast<const uint8_t*>(str.data()),
+ static_cast<int64_t>(str.size()));
+ std::shared_ptr<Buffer> rest;
+ ASSERT_OK(chunker_->ProcessSkip(partial, block, true, &test_count, &rest));
+ ASSERT_EQ(rem_count, test_count);
+ ASSERT_EQ(rest_size, rest->size());
+ AssertBufferEqual(*SliceBuffer(block, block->size() - rest_size), *rest);
+ }
+ {
+ auto test_count = count;
+ auto split = static_cast<int64_t>(str.find_first_of('\n'));
+ auto partial =
+ std::make_shared<Buffer>(reinterpret_cast<const uint8_t*>(str.data()), split);
+ auto block =
+ std::make_shared<Buffer>(reinterpret_cast<const uint8_t*>(str.data() + split),
+ static_cast<int64_t>(str.size()) - split);
+ std::shared_ptr<Buffer> rest;
+ ASSERT_OK(chunker_->ProcessSkip(partial, block, true, &test_count, &rest));
+ ASSERT_EQ(rem_count, test_count);
+ ASSERT_EQ(rest_size, rest->size());
+ AssertBufferEqual(*SliceBuffer(block, block->size() - rest_size), *rest);
+ }
+ }
+
+ ParseOptions options_;
+ std::unique_ptr<Chunker> chunker_;
+};
+
+INSTANTIATE_TEST_SUITE_P(ChunkerTest, BaseChunkerTest, ::testing::Values(true));
+
+INSTANTIATE_TEST_SUITE_P(NoNewlineChunkerTest, BaseChunkerTest, ::testing::Values(false));
+
+TEST_P(BaseChunkerTest, Basics) {
+ auto csv = MakeCSVData({"ab,c,\n", "def,,gh\n", ",ij,kl\n"});
+ auto lengths = {6, 8, 7};
+
+ MakeChunker();
+ AssertChunking(*chunker_, csv, lengths);
+}
+
+TEST_P(BaseChunkerTest, Empty) {
+ MakeChunker();
+ {
+ auto csv = MakeCSVData({"\n"});
+ auto lengths = {1};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ {
+ auto csv = MakeCSVData({"\n\n"});
+ auto lengths = {1, 1};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ {
+ auto csv = MakeCSVData({",\n"});
+ auto lengths = {2};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ {
+ auto csv = MakeCSVData({",\n,\n"});
+ auto lengths = {2, 2};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+}
+
+TEST_P(BaseChunkerTest, Newlines) {
+ MakeChunker();
+ {
+ auto csv = MakeCSVData({"a\n", "b\r", "c,d\r\n"});
+ AssertChunkSize(*chunker_, csv, static_cast<uint32_t>(csv.size()));
+ // Trailing \n after \r is optional
+ AssertChunkSize(*chunker_, csv.substr(0, csv.size() - 1),
+ static_cast<uint32_t>(csv.size() - 1));
+ }
+}
+
+TEST_P(BaseChunkerTest, QuotingSimple) {
+ auto csv = MakeCSVData({"1,\",3,\",5\n"});
+ {
+ MakeChunker();
+ auto lengths = {csv.size()};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ {
+ options_.quoting = false;
+ MakeChunker();
+ auto lengths = {csv.size()};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+}
+
+TEST_P(BaseChunkerTest, QuotingNewline) {
+ auto csv = MakeCSVData({"a,\"c \n d\",e\n"});
+ if (options_.newlines_in_values) {
+ MakeChunker();
+ auto lengths = {12};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ {
+ options_.quoting = false;
+ MakeChunker();
+ auto lengths = {6, 6};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+}
+
+TEST_P(BaseChunkerTest, QuotingUnbalanced) {
+ // Quote introduces a quoted field that doesn't end
+ auto csv = MakeCSVData({"a,b\n", "1,\",3,,5\n", "c,d\n"});
+ if (options_.newlines_in_values) {
+ MakeChunker();
+ auto lengths = {4};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ {
+ options_.quoting = false;
+ MakeChunker();
+ auto lengths = {4, 9, 4};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+}
+
+TEST_P(BaseChunkerTest, QuotingEmpty) {
+ MakeChunker();
+ {
+ auto csv = MakeCSVData({"\"\"\n", "a\n"});
+ auto lengths = {3, 2};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ {
+ auto csv = MakeCSVData({",\"\"\n", "a\n"});
+ auto lengths = {4, 2};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ {
+ auto csv = MakeCSVData({"\"\",\n", "a\n"});
+ auto lengths = {4, 2};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+}
+
+TEST_P(BaseChunkerTest, QuotingDouble) {
+ {
+ MakeChunker();
+ // 4 quotes is a quoted quote
+ auto csv = MakeCSVData({"\"\"\"\"\n", "a\n"});
+ auto lengths = {5, 2};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+}
+
+TEST_P(BaseChunkerTest, QuotesSpecial) {
+ // Some non-trivial cases
+ {
+ MakeChunker();
+ auto csv = MakeCSVData({"a,b\"c,d\n", "e\n"});
+ auto lengths = {8, 2};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ {
+ MakeChunker();
+ auto csv = MakeCSVData({"a,\"b\" \"c\",d\n", "e\n"});
+ auto lengths = {12, 2};
+ AssertChunking(*chunker_, csv, lengths);
+ }
+}
+
+TEST_P(BaseChunkerTest, Escaping) {
+ {
+ auto csv = MakeCSVData({"a\\b,c\n", "d\n"});
+ auto lengths = {6, 2};
+ {
+ options_.escaping = false;
+ MakeChunker();
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ {
+ options_.escaping = true;
+ MakeChunker();
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ }
+ {
+ auto csv = MakeCSVData({"a\\,b,c\n", "d\n"});
+ auto lengths = {7, 2};
+ {
+ options_.escaping = false;
+ MakeChunker();
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ {
+ options_.escaping = true;
+ MakeChunker();
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ }
+}
+
+TEST_P(BaseChunkerTest, EscapingNewline) {
+ if (options_.newlines_in_values) {
+ auto csv = MakeCSVData({"a\\\nb\n", "c\n"});
+ {
+ auto lengths = {3, 2, 2};
+ MakeChunker();
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ options_.escaping = true;
+ {
+ auto lengths = {5, 2};
+ MakeChunker();
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ }
+}
+
+TEST_P(BaseChunkerTest, EscapingAndQuoting) {
+ if (options_.newlines_in_values) {
+ {
+ auto csv = MakeCSVData({"\"a\\\"\n", "\"b\\\"\n"});
+ {
+ options_.quoting = true;
+ options_.escaping = true;
+ auto lengths = {10};
+ MakeChunker();
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ {
+ options_.quoting = true;
+ options_.escaping = false;
+ auto lengths = {5, 5};
+ MakeChunker();
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ }
+ {
+ auto csv = MakeCSVData({"\"a\\\n\"\n"});
+ {
+ options_.quoting = false;
+ options_.escaping = true;
+ auto lengths = {6};
+ MakeChunker();
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ {
+ options_.quoting = false;
+ options_.escaping = false;
+ auto lengths = {4, 2};
+ MakeChunker();
+ AssertChunking(*chunker_, csv, lengths);
+ }
+ }
+ }
+}
+
+TEST_P(BaseChunkerTest, ParseSkip) {
+ {
+ auto csv = MakeCSVData({"ab,c,\n", "def,,gh\n", ",ij,kl\n"});
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 1, 0, 15));
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 2, 0, 7));
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 3, 0, 0));
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 4, 1, 0));
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 6, 3, 0));
+ }
+
+ // Test with no trailing new line
+ {
+ auto csv = MakeCSVData({"ab,c,\n", "def,,gh\n", ",ij,kl"});
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 2, 0, 6));
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 3, 0, 0));
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 4, 1, 0));
+ }
+
+ // Test skip with new lines in values
+ {
+ auto csv = MakeCSVData({"ab,\"c\n\",\n", "\"d\nef\",,gh\n", ",ij,\"nkl\"\n"});
+ options_.newlines_in_values = true;
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 1, 0, 21));
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 2, 0, 10));
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 3, 0, 0));
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 4, 1, 0));
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 6, 3, 0));
+ }
+
+ // Test with no trailing new line and new lines in values
+ {
+ auto csv = MakeCSVData({"ab,\"c\n\",\n", "\"d\nef\",,gh\n", ",ij,\"nkl\""});
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 2, 0, 9));
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 3, 0, 0));
+ ASSERT_NO_FATAL_FAILURE(AssertSkip(csv, 4, 1, 0));
+ }
+}
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/column_builder.cc b/src/arrow/cpp/src/arrow/csv/column_builder.cc
new file mode 100644
index 000000000..bc9744287
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/column_builder.cc
@@ -0,0 +1,367 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_base.h"
+#include "arrow/chunked_array.h"
+#include "arrow/csv/column_builder.h"
+#include "arrow/csv/converter.h"
+#include "arrow/csv/inference_internal.h"
+#include "arrow/csv/options.h"
+#include "arrow/csv/parser.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/task_group.h"
+
+namespace arrow {
+namespace csv {
+
+class BlockParser;
+
+using internal::TaskGroup;
+
+class ConcreteColumnBuilder : public ColumnBuilder {
+ public:
+ explicit ConcreteColumnBuilder(MemoryPool* pool,
+ std::shared_ptr<internal::TaskGroup> task_group,
+ int32_t col_index = -1)
+ : ColumnBuilder(std::move(task_group)), pool_(pool), col_index_(col_index) {}
+
+ void Append(const std::shared_ptr<BlockParser>& parser) override {
+ Insert(static_cast<int64_t>(chunks_.size()), parser);
+ }
+
+ Result<std::shared_ptr<ChunkedArray>> Finish() override {
+ std::lock_guard<std::mutex> lock(mutex_);
+
+ return FinishUnlocked();
+ }
+
+ protected:
+ virtual std::shared_ptr<DataType> type() const = 0;
+
+ Result<std::shared_ptr<ChunkedArray>> FinishUnlocked() {
+ auto type = this->type();
+ for (const auto& chunk : chunks_) {
+ if (chunk == nullptr) {
+ return Status::UnknownError("a chunk failed converting for an unknown reason");
+ }
+ DCHECK_EQ(chunk->type()->id(), type->id()) << "Chunk types not equal!";
+ }
+ return std::make_shared<ChunkedArray>(chunks_, std::move(type));
+ }
+
+ void ReserveChunks(int64_t block_index) {
+ // Create a null Array pointer at the back at the list.
+ std::lock_guard<std::mutex> lock(mutex_);
+ ReserveChunksUnlocked(block_index);
+ }
+
+ void ReserveChunksUnlocked(int64_t block_index) {
+ // Create a null Array pointer at the back at the list.
+ size_t chunk_index = static_cast<size_t>(block_index);
+ if (chunks_.size() <= chunk_index) {
+ chunks_.resize(chunk_index + 1);
+ }
+ }
+
+ Status SetChunk(int64_t chunk_index, Result<std::shared_ptr<Array>> maybe_array) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ return SetChunkUnlocked(chunk_index, std::move(maybe_array));
+ }
+
+ Status SetChunkUnlocked(int64_t chunk_index,
+ Result<std::shared_ptr<Array>> maybe_array) {
+ // Should not insert an already built chunk
+ DCHECK_EQ(chunks_[chunk_index], nullptr);
+
+ if (maybe_array.ok()) {
+ chunks_[chunk_index] = *std::move(maybe_array);
+ return Status::OK();
+ } else {
+ return WrapConversionError(maybe_array.status());
+ }
+ }
+
+ Status WrapConversionError(const Status& st) {
+ if (ARROW_PREDICT_TRUE(st.ok())) {
+ return st;
+ } else {
+ std::stringstream ss;
+ ss << "In CSV column #" << col_index_ << ": " << st.message();
+ return st.WithMessage(ss.str());
+ }
+ }
+
+ MemoryPool* pool_;
+ int32_t col_index_;
+
+ ArrayVector chunks_;
+
+ std::mutex mutex_;
+};
+
+//////////////////////////////////////////////////////////////////////////
+// Null column builder implementation (for a column not in the CSV file)
+
+class NullColumnBuilder : public ConcreteColumnBuilder {
+ public:
+ explicit NullColumnBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool,
+ const std::shared_ptr<internal::TaskGroup>& task_group)
+ : ConcreteColumnBuilder(pool, task_group), type_(type) {}
+
+ void Insert(int64_t block_index, const std::shared_ptr<BlockParser>& parser) override;
+
+ protected:
+ std::shared_ptr<DataType> type() const override { return type_; }
+
+ std::shared_ptr<DataType> type_;
+};
+
+void NullColumnBuilder::Insert(int64_t block_index,
+ const std::shared_ptr<BlockParser>& parser) {
+ ReserveChunks(block_index);
+
+ // Spawn a task that will build an array of nulls with the right DataType
+ const int32_t num_rows = parser->num_rows();
+ DCHECK_GE(num_rows, 0);
+
+ task_group_->Append([=]() -> Status {
+ std::unique_ptr<ArrayBuilder> builder;
+ RETURN_NOT_OK(MakeBuilder(pool_, type_, &builder));
+ std::shared_ptr<Array> res;
+ RETURN_NOT_OK(builder->AppendNulls(num_rows));
+ RETURN_NOT_OK(builder->Finish(&res));
+
+ return SetChunk(block_index, res);
+ });
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Pre-typed column builder implementation
+
+class TypedColumnBuilder : public ConcreteColumnBuilder {
+ public:
+ TypedColumnBuilder(const std::shared_ptr<DataType>& type, int32_t col_index,
+ const ConvertOptions& options, MemoryPool* pool,
+ const std::shared_ptr<internal::TaskGroup>& task_group)
+ : ConcreteColumnBuilder(pool, task_group, col_index),
+ type_(type),
+ options_(options) {}
+
+ Status Init();
+
+ void Insert(int64_t block_index, const std::shared_ptr<BlockParser>& parser) override;
+
+ protected:
+ std::shared_ptr<DataType> type() const override { return type_; }
+
+ std::shared_ptr<DataType> type_;
+ // CAUTION: ConvertOptions can grow large (if it customizes hundreds or
+ // thousands of columns), so avoid copying it in each TypedColumnBuilder.
+ const ConvertOptions& options_;
+
+ std::shared_ptr<Converter> converter_;
+};
+
+Status TypedColumnBuilder::Init() {
+ ARROW_ASSIGN_OR_RAISE(converter_, Converter::Make(type_, options_, pool_));
+ return Status::OK();
+}
+
+void TypedColumnBuilder::Insert(int64_t block_index,
+ const std::shared_ptr<BlockParser>& parser) {
+ DCHECK_NE(converter_, nullptr);
+
+ ReserveChunks(block_index);
+
+ // We're careful that all references in the closure outlive the Append() call
+ task_group_->Append([=]() -> Status {
+ return SetChunk(block_index, converter_->Convert(*parser, col_index_));
+ });
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Type-inferring column builder implementation
+
+class InferringColumnBuilder : public ConcreteColumnBuilder {
+ public:
+ InferringColumnBuilder(int32_t col_index, const ConvertOptions& options,
+ MemoryPool* pool,
+ const std::shared_ptr<internal::TaskGroup>& task_group)
+ : ConcreteColumnBuilder(pool, task_group, col_index),
+ options_(options),
+ infer_status_(options) {}
+
+ Status Init();
+
+ void Insert(int64_t block_index, const std::shared_ptr<BlockParser>& parser) override;
+ Result<std::shared_ptr<ChunkedArray>> Finish() override;
+
+ protected:
+ std::shared_ptr<DataType> type() const override {
+ DCHECK_NE(converter_, nullptr);
+ return converter_->type();
+ }
+
+ Status UpdateType();
+ Status TryConvertChunk(int64_t chunk_index);
+ // This must be called unlocked!
+ void ScheduleConvertChunk(int64_t chunk_index);
+
+ // CAUTION: ConvertOptions can grow large (if it customizes hundreds or
+ // thousands of columns), so avoid copying it in each InferringColumnBuilder.
+ const ConvertOptions& options_;
+
+ // Current inference status
+ InferStatus infer_status_;
+ std::shared_ptr<Converter> converter_;
+
+ // The parsers corresponding to each chunk (for reconverting)
+ std::vector<std::shared_ptr<BlockParser>> parsers_;
+};
+
+Status InferringColumnBuilder::Init() { return UpdateType(); }
+
+Status InferringColumnBuilder::UpdateType() {
+ return infer_status_.MakeConverter(pool_).Value(&converter_);
+}
+
+void InferringColumnBuilder::ScheduleConvertChunk(int64_t chunk_index) {
+ task_group_->Append([=]() { return TryConvertChunk(chunk_index); });
+}
+
+Status InferringColumnBuilder::TryConvertChunk(int64_t chunk_index) {
+ std::unique_lock<std::mutex> lock(mutex_);
+ std::shared_ptr<Converter> converter = converter_;
+ std::shared_ptr<BlockParser> parser = parsers_[chunk_index];
+ InferKind kind = infer_status_.kind();
+
+ DCHECK_NE(parser, nullptr);
+
+ lock.unlock();
+ auto maybe_array = converter->Convert(*parser, col_index_);
+ lock.lock();
+
+ if (kind != infer_status_.kind()) {
+ // infer_kind_ was changed by another task, reconvert
+ lock.unlock();
+ ScheduleConvertChunk(chunk_index);
+ return Status::OK();
+ }
+
+ if (maybe_array.ok() || !infer_status_.can_loosen_type()) {
+ // Conversion succeeded, or failed definitively
+ if (!infer_status_.can_loosen_type()) {
+ // We won't try to reconvert anymore
+ parsers_[chunk_index].reset();
+ }
+ return SetChunkUnlocked(chunk_index, maybe_array);
+ }
+
+ // Conversion failed, try another type
+ infer_status_.LoosenType(maybe_array.status());
+ RETURN_NOT_OK(UpdateType());
+
+ // Reconvert past finished chunks
+ // (unfinished chunks will notice by themselves if they need reconverting)
+ const auto nchunks = static_cast<int64_t>(chunks_.size());
+ for (int64_t i = 0; i < nchunks; ++i) {
+ if (i != chunk_index && chunks_[i]) {
+ // We're assuming the chunk was converted using the wrong type
+ // (which should be true unless the executor reorders tasks)
+ chunks_[i].reset();
+ lock.unlock();
+ ScheduleConvertChunk(i);
+ lock.lock();
+ }
+ }
+
+ // Reconvert this chunk
+ lock.unlock();
+ ScheduleConvertChunk(chunk_index);
+
+ return Status::OK();
+}
+
+void InferringColumnBuilder::Insert(int64_t block_index,
+ const std::shared_ptr<BlockParser>& parser) {
+ // Create a slot for the new chunk and spawn a task to convert it
+ size_t chunk_index = static_cast<size_t>(block_index);
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+
+ DCHECK_NE(converter_, nullptr);
+ if (parsers_.size() <= chunk_index) {
+ parsers_.resize(chunk_index + 1);
+ }
+ // Should not insert an already converting chunk
+ DCHECK_EQ(parsers_[chunk_index], nullptr);
+ parsers_[chunk_index] = parser;
+ ReserveChunksUnlocked(block_index);
+ }
+
+ ScheduleConvertChunk(chunk_index);
+}
+
+Result<std::shared_ptr<ChunkedArray>> InferringColumnBuilder::Finish() {
+ std::lock_guard<std::mutex> lock(mutex_);
+
+ parsers_.clear();
+ return FinishUnlocked();
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Factory functions
+
+Result<std::shared_ptr<ColumnBuilder>> ColumnBuilder::Make(
+ MemoryPool* pool, const std::shared_ptr<DataType>& type, int32_t col_index,
+ const ConvertOptions& options, const std::shared_ptr<TaskGroup>& task_group) {
+ auto ptr =
+ std::make_shared<TypedColumnBuilder>(type, col_index, options, pool, task_group);
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+}
+
+Result<std::shared_ptr<ColumnBuilder>> ColumnBuilder::Make(
+ MemoryPool* pool, int32_t col_index, const ConvertOptions& options,
+ const std::shared_ptr<TaskGroup>& task_group) {
+ auto ptr =
+ std::make_shared<InferringColumnBuilder>(col_index, options, pool, task_group);
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+}
+
+Result<std::shared_ptr<ColumnBuilder>> ColumnBuilder::MakeNull(
+ MemoryPool* pool, const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<internal::TaskGroup>& task_group) {
+ return std::make_shared<NullColumnBuilder>(type, pool, task_group);
+}
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/column_builder.h b/src/arrow/cpp/src/arrow/csv/column_builder.h
new file mode 100644
index 000000000..170a8ad06
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/column_builder.h
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "arrow/result.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace csv {
+
+class BlockParser;
+struct ConvertOptions;
+
+class ARROW_EXPORT ColumnBuilder {
+ public:
+ virtual ~ColumnBuilder() = default;
+
+ /// Spawn a task that will try to convert and append the given CSV block.
+ /// All calls to Append() should happen on the same thread, otherwise
+ /// call Insert() instead.
+ virtual void Append(const std::shared_ptr<BlockParser>& parser) = 0;
+
+ /// Spawn a task that will try to convert and insert the given CSV block
+ virtual void Insert(int64_t block_index,
+ const std::shared_ptr<BlockParser>& parser) = 0;
+
+ /// Return the final chunked array. The TaskGroup _must_ have finished!
+ virtual Result<std::shared_ptr<ChunkedArray>> Finish() = 0;
+
+ std::shared_ptr<internal::TaskGroup> task_group() { return task_group_; }
+
+ /// Construct a strictly-typed ColumnBuilder.
+ static Result<std::shared_ptr<ColumnBuilder>> Make(
+ MemoryPool* pool, const std::shared_ptr<DataType>& type, int32_t col_index,
+ const ConvertOptions& options,
+ const std::shared_ptr<internal::TaskGroup>& task_group);
+
+ /// Construct a type-inferring ColumnBuilder.
+ static Result<std::shared_ptr<ColumnBuilder>> Make(
+ MemoryPool* pool, int32_t col_index, const ConvertOptions& options,
+ const std::shared_ptr<internal::TaskGroup>& task_group);
+
+ /// Construct a ColumnBuilder for a column of nulls
+ /// (i.e. not present in the CSV file).
+ static Result<std::shared_ptr<ColumnBuilder>> MakeNull(
+ MemoryPool* pool, const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<internal::TaskGroup>& task_group);
+
+ protected:
+ explicit ColumnBuilder(std::shared_ptr<internal::TaskGroup> task_group)
+ : task_group_(std::move(task_group)) {}
+
+ std::shared_ptr<internal::TaskGroup> task_group_;
+};
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/column_builder_test.cc b/src/arrow/cpp/src/arrow/csv/column_builder_test.cc
new file mode 100644
index 000000000..7577c883e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/column_builder_test.cc
@@ -0,0 +1,608 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/csv/column_builder.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/csv/options.h"
+#include "arrow/csv/test_common.h"
+#include "arrow/memory_pool.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/task_group.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+namespace csv {
+
+class BlockParser;
+
+using internal::checked_cast;
+using internal::GetCpuThreadPool;
+using internal::TaskGroup;
+
+using ChunkData = std::vector<std::vector<std::string>>;
+
+class ColumnBuilderTest : public ::testing::Test {
+ public:
+ void AssertBuilding(const std::shared_ptr<ColumnBuilder>& builder,
+ const ChunkData& chunks, bool validate_full,
+ std::shared_ptr<ChunkedArray>* out) {
+ for (const auto& chunk : chunks) {
+ std::shared_ptr<BlockParser> parser;
+ MakeColumnParser(chunk, &parser);
+ builder->Append(parser);
+ }
+ ASSERT_OK(builder->task_group()->Finish());
+ ASSERT_OK_AND_ASSIGN(*out, builder->Finish());
+ if (validate_full) {
+ ASSERT_OK((*out)->ValidateFull());
+ } else {
+ ASSERT_OK((*out)->Validate());
+ }
+ }
+
+ void AssertBuilding(const std::shared_ptr<ColumnBuilder>& builder,
+ const ChunkData& chunks, std::shared_ptr<ChunkedArray>* out) {
+ AssertBuilding(builder, chunks, /*validate_full=*/true, out);
+ }
+
+ void CheckInferred(const std::shared_ptr<TaskGroup>& tg, const ChunkData& csv_data,
+ const ConvertOptions& options,
+ std::shared_ptr<ChunkedArray> expected, bool validate_full = true) {
+ std::shared_ptr<ColumnBuilder> builder;
+ std::shared_ptr<ChunkedArray> actual;
+ ASSERT_OK_AND_ASSIGN(builder,
+ ColumnBuilder::Make(default_memory_pool(), 0, options, tg));
+ AssertBuilding(builder, csv_data, validate_full, &actual);
+ AssertChunkedEqual(*actual, *expected);
+ }
+
+ void CheckInferred(const std::shared_ptr<TaskGroup>& tg, const ChunkData& csv_data,
+ const ConvertOptions& options,
+ std::vector<std::shared_ptr<Array>> expected_chunks,
+ bool validate_full = true) {
+ CheckInferred(tg, csv_data, options, std::make_shared<ChunkedArray>(expected_chunks),
+ validate_full);
+ }
+
+ void CheckFixedType(const std::shared_ptr<TaskGroup>& tg,
+ const std::shared_ptr<DataType>& type, const ChunkData& csv_data,
+ const ConvertOptions& options,
+ std::shared_ptr<ChunkedArray> expected) {
+ std::shared_ptr<ColumnBuilder> builder;
+ std::shared_ptr<ChunkedArray> actual;
+ ASSERT_OK_AND_ASSIGN(
+ builder, ColumnBuilder::Make(default_memory_pool(), type, 0, options, tg));
+ AssertBuilding(builder, csv_data, &actual);
+ AssertChunkedEqual(*actual, *expected);
+ }
+
+ void CheckFixedType(const std::shared_ptr<TaskGroup>& tg,
+ const std::shared_ptr<DataType>& type, const ChunkData& csv_data,
+ const ConvertOptions& options,
+ std::vector<std::shared_ptr<Array>> expected_chunks) {
+ CheckFixedType(tg, type, csv_data, options,
+ std::make_shared<ChunkedArray>(expected_chunks));
+ }
+
+ protected:
+ ConvertOptions default_options = ConvertOptions::Defaults();
+};
+
+//////////////////////////////////////////////////////////////////////////
+// Tests for null column builder
+
+class NullColumnBuilderTest : public ColumnBuilderTest {};
+
+TEST_F(NullColumnBuilderTest, Empty) {
+ std::shared_ptr<DataType> type = null();
+ auto tg = TaskGroup::MakeSerial();
+
+ std::shared_ptr<ColumnBuilder> builder;
+ ASSERT_OK_AND_ASSIGN(builder, ColumnBuilder::MakeNull(default_memory_pool(), type, tg));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder, {}, &actual);
+
+ ChunkedArray expected({}, type);
+ AssertChunkedEqual(*actual, expected);
+}
+
+TEST_F(NullColumnBuilderTest, InsertNull) {
+ // Building a column of nulls with type null()
+ std::shared_ptr<DataType> type = null();
+ auto tg = TaskGroup::MakeSerial();
+
+ std::shared_ptr<ColumnBuilder> builder;
+ ASSERT_OK_AND_ASSIGN(builder, ColumnBuilder::MakeNull(default_memory_pool(), type, tg));
+
+ std::shared_ptr<BlockParser> parser;
+ std::shared_ptr<ChunkedArray> actual, expected;
+ // Those values are indifferent, only the number of rows is used
+ MakeColumnParser({"456", "789"}, &parser);
+ builder->Insert(1, parser);
+ MakeColumnParser({"123"}, &parser);
+ builder->Insert(0, parser);
+ ASSERT_OK(builder->task_group()->Finish());
+ ASSERT_OK_AND_ASSIGN(actual, builder->Finish());
+ ASSERT_OK(actual->ValidateFull());
+
+ auto chunks =
+ ArrayVector{std::make_shared<NullArray>(1), std::make_shared<NullArray>(2)};
+ expected = std::make_shared<ChunkedArray>(chunks);
+ AssertChunkedEqual(*actual, *expected);
+}
+
+TEST_F(NullColumnBuilderTest, InsertTyped) {
+ // Building a column of nulls with another type
+ std::shared_ptr<DataType> type = int16();
+ auto tg = TaskGroup::MakeSerial();
+
+ std::shared_ptr<ColumnBuilder> builder;
+ ASSERT_OK_AND_ASSIGN(builder, ColumnBuilder::MakeNull(default_memory_pool(), type, tg));
+
+ std::shared_ptr<BlockParser> parser;
+ std::shared_ptr<ChunkedArray> actual, expected;
+ // Those values are indifferent, only the number of rows is used
+ MakeColumnParser({"abc", "def", "ghi"}, &parser);
+ builder->Insert(1, parser);
+ MakeColumnParser({"jkl"}, &parser);
+ builder->Insert(0, parser);
+ ASSERT_OK(builder->task_group()->Finish());
+ ASSERT_OK_AND_ASSIGN(actual, builder->Finish());
+ ASSERT_OK(actual->ValidateFull());
+
+ auto chunks = ArrayVector{ArrayFromJSON(type, "[null]"),
+ ArrayFromJSON(type, "[null, null, null]")};
+ expected = std::make_shared<ChunkedArray>(chunks);
+ AssertChunkedEqual(*actual, *expected);
+}
+
+TEST_F(NullColumnBuilderTest, EmptyChunks) {
+ std::shared_ptr<DataType> type = int16();
+ auto tg = TaskGroup::MakeSerial();
+
+ std::shared_ptr<ColumnBuilder> builder;
+ ASSERT_OK_AND_ASSIGN(builder, ColumnBuilder::MakeNull(default_memory_pool(), type, tg));
+
+ std::shared_ptr<BlockParser> parser;
+ std::shared_ptr<ChunkedArray> actual, expected;
+ // Those values are indifferent, only the number of rows is used
+ MakeColumnParser({}, &parser);
+ builder->Insert(0, parser);
+ MakeColumnParser({"abc", "def"}, &parser);
+ builder->Insert(1, parser);
+ MakeColumnParser({}, &parser);
+ builder->Insert(2, parser);
+ ASSERT_OK(builder->task_group()->Finish());
+ ASSERT_OK_AND_ASSIGN(actual, builder->Finish());
+ ASSERT_OK(actual->ValidateFull());
+
+ auto chunks =
+ ArrayVector{ArrayFromJSON(type, "[]"), ArrayFromJSON(type, "[null, null]"),
+ ArrayFromJSON(type, "[]")};
+ expected = std::make_shared<ChunkedArray>(chunks);
+ AssertChunkedEqual(*actual, *expected);
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Tests for fixed-type column builder
+
+class TypedColumnBuilderTest : public ColumnBuilderTest {};
+
+TEST_F(TypedColumnBuilderTest, Empty) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ColumnBuilder> builder;
+ ASSERT_OK_AND_ASSIGN(
+ builder, ColumnBuilder::Make(default_memory_pool(), int32(), 0, options, tg));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder, {}, &actual);
+
+ ChunkedArray expected({}, int32());
+ AssertChunkedEqual(*actual, expected);
+}
+
+TEST_F(TypedColumnBuilderTest, Basics) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckFixedType(tg, int32(), {{"123", "-456"}}, options,
+ {ArrayFromJSON(int32(), "[123, -456]")});
+}
+
+TEST_F(TypedColumnBuilderTest, Insert) {
+ // Test ColumnBuilder::Insert()
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ColumnBuilder> builder;
+ ASSERT_OK_AND_ASSIGN(
+ builder, ColumnBuilder::Make(default_memory_pool(), int32(), 0, options, tg));
+
+ std::shared_ptr<BlockParser> parser;
+ std::shared_ptr<ChunkedArray> actual, expected;
+ MakeColumnParser({"456"}, &parser);
+ builder->Insert(1, parser);
+ MakeColumnParser({"123"}, &parser);
+ builder->Insert(0, parser);
+ ASSERT_OK(builder->task_group()->Finish());
+ ASSERT_OK_AND_ASSIGN(actual, builder->Finish());
+ ASSERT_OK(actual->ValidateFull());
+
+ ChunkedArrayFromVector<Int32Type>({{123}, {456}}, &expected);
+ AssertChunkedEqual(*actual, *expected);
+}
+
+TEST_F(TypedColumnBuilderTest, MultipleChunks) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckFixedType(tg, int16(), {{"1", "2", "3"}, {"4", "5"}}, options,
+ {ArrayFromJSON(int16(), "[1, 2, 3]"), ArrayFromJSON(int16(), "[4, 5]")});
+}
+
+TEST_F(TypedColumnBuilderTest, MultipleChunksParallel) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeThreaded(GetCpuThreadPool());
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<Int32Type>({{1, 2}, {3}, {4, 5}, {6, 7}}, &expected);
+ CheckFixedType(tg, int32(), {{"1", "2"}, {"3"}, {"4", "5"}, {"6", "7"}}, options,
+ expected);
+}
+
+TEST_F(TypedColumnBuilderTest, EmptyChunks) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckFixedType(tg, int16(), {{}, {"1", "2"}, {}}, options,
+ {ArrayFromJSON(int16(), "[]"), ArrayFromJSON(int16(), "[1, 2]"),
+ ArrayFromJSON(int16(), "[]")});
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Tests for type-inferring column builder
+
+class InferringColumnBuilderTest : public ColumnBuilderTest {
+ public:
+ void CheckAutoDictEncoded(const std::shared_ptr<TaskGroup>& tg,
+ const ChunkData& csv_data, const ConvertOptions& options,
+ std::vector<std::shared_ptr<Array>> expected_indices,
+ std::vector<std::shared_ptr<Array>> expected_dictionaries,
+ bool validate_full = true) {
+ std::shared_ptr<ColumnBuilder> builder;
+ std::shared_ptr<ChunkedArray> actual;
+ ASSERT_OK_AND_ASSIGN(builder,
+ ColumnBuilder::Make(default_memory_pool(), 0, options, tg));
+ AssertBuilding(builder, csv_data, validate_full, &actual);
+ ASSERT_EQ(actual->num_chunks(), static_cast<int>(csv_data.size()));
+ for (int i = 0; i < actual->num_chunks(); ++i) {
+ ASSERT_EQ(actual->chunk(i)->type_id(), Type::DICTIONARY);
+ const auto& dict_array = checked_cast<const DictionaryArray&>(*actual->chunk(i));
+ AssertArraysEqual(*dict_array.dictionary(), *expected_dictionaries[i]);
+ AssertArraysEqual(*dict_array.indices(), *expected_indices[i]);
+ }
+ }
+};
+
+TEST_F(InferringColumnBuilderTest, Empty) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckInferred(tg, {}, options, std::make_shared<ChunkedArray>(ArrayVector(), null()));
+}
+
+TEST_F(InferringColumnBuilderTest, SingleChunkNull) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckInferred(tg, {{"", "NA"}}, options, {std::make_shared<NullArray>(2)});
+}
+
+TEST_F(InferringColumnBuilderTest, MultipleChunkNull) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckInferred(tg, {{"", "NA"}, {""}, {"NaN"}}, options,
+ {std::make_shared<NullArray>(2), std::make_shared<NullArray>(1),
+ std::make_shared<NullArray>(1)});
+}
+
+TEST_F(InferringColumnBuilderTest, SingleChunkInteger) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckInferred(tg, {{"", "123", "456"}}, options,
+ {ArrayFromJSON(int64(), "[null, 123, 456]")});
+}
+
+TEST_F(InferringColumnBuilderTest, MultipleChunkInteger) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckInferred(
+ tg, {{""}, {"NA", "123", "456"}}, options,
+ {ArrayFromJSON(int64(), "[null]"), ArrayFromJSON(int64(), "[null, 123, 456]")});
+}
+
+TEST_F(InferringColumnBuilderTest, SingleChunkBoolean) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckInferred(tg, {{"", "0", "FALSE", "TRUE"}}, options,
+ {ArrayFromJSON(boolean(), "[null, false, false, true]")});
+}
+
+TEST_F(InferringColumnBuilderTest, MultipleChunkBoolean) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckInferred(tg, {{""}, {"1", "True", "0"}}, options,
+ {ArrayFromJSON(boolean(), "[null]"),
+ ArrayFromJSON(boolean(), "[true, true, false]")});
+}
+
+TEST_F(InferringColumnBuilderTest, SingleChunkReal) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckInferred(tg, {{"", "0.0", "12.5"}}, options,
+ {ArrayFromJSON(float64(), "[null, 0.0, 12.5]")});
+}
+
+TEST_F(InferringColumnBuilderTest, MultipleChunkReal) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckInferred(tg, {{""}, {"008"}, {"NaN", "12.5"}}, options,
+ {ArrayFromJSON(float64(), "[null]"), ArrayFromJSON(float64(), "[8.0]"),
+ ArrayFromJSON(float64(), "[null, 12.5]")});
+}
+
+TEST_F(InferringColumnBuilderTest, SingleChunkDate) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckInferred(tg, {{"", "1970-01-04", "NA"}}, options,
+ {ArrayFromJSON(date32(), "[null, 3, null]")});
+}
+
+TEST_F(InferringColumnBuilderTest, MultipleChunkDate) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckInferred(tg, {{""}, {"1970-01-04"}, {"NA"}}, options,
+ {ArrayFromJSON(date32(), "[null]"), ArrayFromJSON(date32(), "[3]"),
+ ArrayFromJSON(date32(), "[null]")});
+}
+
+TEST_F(InferringColumnBuilderTest, SingleChunkTime) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckInferred(tg, {{"", "01:23:45", "NA"}}, options,
+ {ArrayFromJSON(time32(TimeUnit::SECOND), "[null, 5025, null]")});
+}
+
+TEST_F(InferringColumnBuilderTest, MultipleChunkTime) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+ auto type = time32(TimeUnit::SECOND);
+
+ CheckInferred(tg, {{""}, {"01:23:45"}, {"NA"}}, options,
+ {ArrayFromJSON(type, "[null]"), ArrayFromJSON(type, "[5025]"),
+ ArrayFromJSON(type, "[null]")});
+}
+
+TEST_F(InferringColumnBuilderTest, SingleChunkTimestamp) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<TimestampType>(timestamp(TimeUnit::SECOND),
+ {{false, true, true}}, {{0, 0, 1542129070}},
+ &expected);
+ CheckInferred(tg, {{"", "1970-01-01", "2018-11-13 17:11:10"}}, options, expected);
+}
+
+TEST_F(InferringColumnBuilderTest, MultipleChunkTimestamp) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<TimestampType>(timestamp(TimeUnit::SECOND),
+ {{false}, {true}, {true}},
+ {{0}, {0}, {1542129070}}, &expected);
+ CheckInferred(tg, {{""}, {"1970-01-01"}, {"2018-11-13 17:11:10"}}, options, expected);
+}
+
+TEST_F(InferringColumnBuilderTest, SingleChunkTimestampNS) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<TimestampType>(
+ timestamp(TimeUnit::NANO), {{false, true, true, true, true}},
+ {{0, 0, 1542129070123000000, 1542129070123456000, 1542129070123456789}}, &expected);
+ CheckInferred(tg,
+ {{"", "1970-01-01", "2018-11-13 17:11:10.123",
+ "2018-11-13 17:11:10.123456", "2018-11-13 17:11:10.123456789"}},
+ options, expected);
+}
+
+TEST_F(InferringColumnBuilderTest, MultipleChunkTimestampNS) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<TimestampType>(
+ timestamp(TimeUnit::NANO), {{false}, {true}, {true, true, true}},
+ {{0}, {0}, {1542129070123000000, 1542129070123456000, 1542129070123456789}},
+ &expected);
+ CheckInferred(tg,
+ {{""},
+ {"1970-01-01"},
+ {"2018-11-13 17:11:10.123", "2018-11-13 17:11:10.123456",
+ "2018-11-13 17:11:10.123456789"}},
+ options, expected);
+}
+
+TEST_F(InferringColumnBuilderTest, SingleChunkIntegerAndTime) {
+ // Fallback to utf-8
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckInferred(tg, {{"", "99", "01:23:45", "NA"}}, options,
+ {ArrayFromJSON(utf8(), R"(["", "99", "01:23:45", "NA"])")});
+}
+
+TEST_F(InferringColumnBuilderTest, MultipleChunkIntegerAndTime) {
+ // Fallback to utf-8
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+ auto type = utf8();
+
+ CheckInferred(tg, {{""}, {"99"}, {"01:23:45", "NA"}}, options,
+ {ArrayFromJSON(type, R"([""])"), ArrayFromJSON(type, R"(["99"])"),
+ ArrayFromJSON(type, R"(["01:23:45", "NA"])")});
+}
+
+TEST_F(InferringColumnBuilderTest, SingleChunkDateAndTime) {
+ // Fallback to utf-8
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ CheckInferred(tg, {{"", "01:23:45", "1998-04-05"}}, options,
+ {ArrayFromJSON(utf8(), R"(["", "01:23:45", "1998-04-05"])")});
+}
+
+TEST_F(InferringColumnBuilderTest, MultipleChunkDateAndTime) {
+ // Fallback to utf-8
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+ auto type = utf8();
+
+ CheckInferred(tg, {{""}, {"01:23:45"}, {"1998-04-05"}}, options,
+ {ArrayFromJSON(type, R"([""])"), ArrayFromJSON(type, R"(["01:23:45"])"),
+ ArrayFromJSON(type, R"(["1998-04-05"])")});
+}
+
+TEST_F(InferringColumnBuilderTest, SingleChunkString) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArray> expected;
+
+ // With valid UTF8
+ CheckInferred(tg, {{"", "foo", "baré"}}, options,
+ {ArrayFromJSON(utf8(), R"(["", "foo", "baré"])")});
+
+ // With invalid UTF8, non-checking
+ options.check_utf8 = false;
+ tg = TaskGroup::MakeSerial();
+ ChunkedArrayFromVector<StringType, std::string>({{true, true, true}},
+ {{"", "foo\xff", "baré"}}, &expected);
+ CheckInferred(tg, {{"", "foo\xff", "baré"}}, options, expected,
+ /*validate_full=*/false);
+}
+
+TEST_F(InferringColumnBuilderTest, SingleChunkBinary) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArray> expected;
+
+ // With invalid UTF8, checking
+ tg = TaskGroup::MakeSerial();
+ ChunkedArrayFromVector<BinaryType, std::string>({{true, true, true}},
+ {{"", "foo\xff", "baré"}}, &expected);
+ CheckInferred(tg, {{"", "foo\xff", "baré"}}, options, expected);
+}
+
+TEST_F(InferringColumnBuilderTest, MultipleChunkString) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<StringType, std::string>(
+ {{true}, {true}, {true, true}}, {{""}, {"008"}, {"NaN", "baré"}}, &expected);
+
+ CheckInferred(tg, {{""}, {"008"}, {"NaN", "baré"}}, options, expected);
+}
+
+TEST_F(InferringColumnBuilderTest, MultipleChunkBinary) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeSerial();
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<BinaryType, std::string>(
+ {{true}, {true}, {true, true}}, {{""}, {"008"}, {"NaN", "baré\xff"}}, &expected);
+
+ CheckInferred(tg, {{""}, {"008"}, {"NaN", "baré\xff"}}, options, expected);
+}
+
+// Parallel parsing is tested more comprehensively on the Python side
+// (see python/pyarrow/tests/test_csv.py)
+
+TEST_F(InferringColumnBuilderTest, MultipleChunkIntegerParallel) {
+ auto options = ConvertOptions::Defaults();
+ auto tg = TaskGroup::MakeThreaded(GetCpuThreadPool());
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<Int64Type>({{1, 2}, {3}, {4, 5}, {6, 7}}, &expected);
+ CheckInferred(tg, {{"1", "2"}, {"3"}, {"4", "5"}, {"6", "7"}}, options, expected);
+}
+
+TEST_F(InferringColumnBuilderTest, SingleChunkBinaryAutoDict) {
+ auto options = ConvertOptions::Defaults();
+ options.auto_dict_encode = true;
+ options.auto_dict_max_cardinality = 3;
+
+ // With valid UTF8
+ auto expected_indices = ArrayFromJSON(int32(), "[0, 1, 0]");
+ auto expected_dictionary = ArrayFromJSON(utf8(), R"(["abé", "cd"])");
+ ChunkData csv_data = {{"abé", "cd", "abé"}};
+
+ CheckAutoDictEncoded(TaskGroup::MakeSerial(), csv_data, options, {expected_indices},
+ {expected_dictionary});
+
+ // With invalid UTF8, non-checking
+ csv_data = {{"ab", "cd\xff", "ab"}};
+ options.check_utf8 = false;
+ ArrayFromVector<StringType, std::string>({"ab", "cd\xff"}, &expected_dictionary);
+
+ CheckAutoDictEncoded(TaskGroup::MakeSerial(), csv_data, options, {expected_indices},
+ {expected_dictionary}, /*validate_full=*/false);
+
+ // With invalid UTF8, checking
+ options.check_utf8 = true;
+ ArrayFromVector<BinaryType, std::string>({"ab", "cd\xff"}, &expected_dictionary);
+
+ CheckAutoDictEncoded(TaskGroup::MakeSerial(), csv_data, options, {expected_indices},
+ {expected_dictionary});
+}
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/column_decoder.cc b/src/arrow/cpp/src/arrow/csv/column_decoder.cc
new file mode 100644
index 000000000..5f746c727
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/column_decoder.cc
@@ -0,0 +1,250 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/csv/column_decoder.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_base.h"
+#include "arrow/csv/converter.h"
+#include "arrow/csv/inference_internal.h"
+#include "arrow/csv/options.h"
+#include "arrow/csv/parser.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/task_group.h"
+
+namespace arrow {
+namespace csv {
+
+using internal::TaskGroup;
+
+class ConcreteColumnDecoder : public ColumnDecoder {
+ public:
+ explicit ConcreteColumnDecoder(MemoryPool* pool, int32_t col_index = -1)
+ : ColumnDecoder(), pool_(pool), col_index_(col_index) {}
+
+ protected:
+ // XXX useful?
+ virtual std::shared_ptr<DataType> type() const = 0;
+
+ Result<std::shared_ptr<Array>> WrapConversionError(
+ const Result<std::shared_ptr<Array>>& result) {
+ if (ARROW_PREDICT_TRUE(result.ok())) {
+ return result;
+ } else {
+ const auto& st = result.status();
+ std::stringstream ss;
+ ss << "In CSV column #" << col_index_ << ": " << st.message();
+ return st.WithMessage(ss.str());
+ }
+ }
+
+ MemoryPool* pool_;
+ int32_t col_index_;
+ internal::Executor* executor_;
+};
+
+//////////////////////////////////////////////////////////////////////////
+// Null column decoder implementation (for a column not in the CSV file)
+
+class NullColumnDecoder : public ConcreteColumnDecoder {
+ public:
+ explicit NullColumnDecoder(const std::shared_ptr<DataType>& type, MemoryPool* pool)
+ : ConcreteColumnDecoder(pool), type_(type) {}
+
+ Future<std::shared_ptr<Array>> Decode(
+ const std::shared_ptr<BlockParser>& parser) override;
+
+ protected:
+ std::shared_ptr<DataType> type() const override { return type_; }
+
+ std::shared_ptr<DataType> type_;
+};
+
+Future<std::shared_ptr<Array>> NullColumnDecoder::Decode(
+ const std::shared_ptr<BlockParser>& parser) {
+ DCHECK_GE(parser->num_rows(), 0);
+ return WrapConversionError(MakeArrayOfNull(type_, parser->num_rows(), pool_));
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Pre-typed column decoder implementation
+
+class TypedColumnDecoder : public ConcreteColumnDecoder {
+ public:
+ TypedColumnDecoder(const std::shared_ptr<DataType>& type, int32_t col_index,
+ const ConvertOptions& options, MemoryPool* pool)
+ : ConcreteColumnDecoder(pool, col_index), type_(type), options_(options) {}
+
+ Status Init();
+
+ Future<std::shared_ptr<Array>> Decode(
+ const std::shared_ptr<BlockParser>& parser) override;
+
+ protected:
+ std::shared_ptr<DataType> type() const override { return type_; }
+
+ std::shared_ptr<DataType> type_;
+ // CAUTION: ConvertOptions can grow large (if it customizes hundreds or
+ // thousands of columns), so avoid copying it in each TypedColumnDecoder.
+ const ConvertOptions& options_;
+
+ std::shared_ptr<Converter> converter_;
+};
+
+Status TypedColumnDecoder::Init() {
+ ARROW_ASSIGN_OR_RAISE(converter_, Converter::Make(type_, options_, pool_));
+ return Status::OK();
+}
+
+Future<std::shared_ptr<Array>> TypedColumnDecoder::Decode(
+ const std::shared_ptr<BlockParser>& parser) {
+ DCHECK_NE(converter_, nullptr);
+ return Future<std::shared_ptr<Array>>::MakeFinished(
+ WrapConversionError(converter_->Convert(*parser, col_index_)));
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Type-inferring column builder implementation
+
+class InferringColumnDecoder : public ConcreteColumnDecoder {
+ public:
+ InferringColumnDecoder(int32_t col_index, const ConvertOptions& options,
+ MemoryPool* pool)
+ : ConcreteColumnDecoder(pool, col_index),
+ options_(options),
+ infer_status_(options),
+ type_frozen_(false) {
+ first_inference_run_ = Future<>::Make();
+ first_inferrer_ = 0;
+ }
+
+ Status Init();
+
+ Future<std::shared_ptr<Array>> Decode(
+ const std::shared_ptr<BlockParser>& parser) override;
+
+ protected:
+ std::shared_ptr<DataType> type() const override {
+ DCHECK_NE(converter_, nullptr);
+ return converter_->type();
+ }
+
+ Status UpdateType();
+ Result<std::shared_ptr<Array>> RunInference(const std::shared_ptr<BlockParser>& parser);
+
+ // CAUTION: ConvertOptions can grow large (if it customizes hundreds or
+ // thousands of columns), so avoid copying it in each InferringColumnDecoder.
+ const ConvertOptions& options_;
+
+ // Current inference status
+ InferStatus infer_status_;
+ bool type_frozen_;
+ std::atomic<int> first_inferrer_;
+ Future<> first_inference_run_;
+ std::shared_ptr<Converter> converter_;
+};
+
+Status InferringColumnDecoder::Init() { return UpdateType(); }
+
+Status InferringColumnDecoder::UpdateType() {
+ return infer_status_.MakeConverter(pool_).Value(&converter_);
+}
+
+Result<std::shared_ptr<Array>> InferringColumnDecoder::RunInference(
+ const std::shared_ptr<BlockParser>& parser) {
+ while (true) {
+ // (no one else should be updating converter_ concurrently)
+ auto maybe_array = converter_->Convert(*parser, col_index_);
+
+ if (maybe_array.ok() || !infer_status_.can_loosen_type()) {
+ // Conversion succeeded, or failed definitively
+ DCHECK(!type_frozen_);
+ type_frozen_ = true;
+ return maybe_array;
+ }
+ // Conversion failed temporarily, try another type
+ infer_status_.LoosenType(maybe_array.status());
+ auto update_status = UpdateType();
+ if (!update_status.ok()) {
+ return update_status;
+ }
+ }
+}
+
+Future<std::shared_ptr<Array>> InferringColumnDecoder::Decode(
+ const std::shared_ptr<BlockParser>& parser) {
+ // Empty arrays before the first inference run must be discarded since the type of the
+ // array will be NA and not match arrays decoded later
+ if (parser->num_rows() == 0) {
+ return Future<std::shared_ptr<Array>>::MakeFinished(
+ MakeArrayOfNull(converter_->type(), 0));
+ }
+
+ bool already_taken = first_inferrer_.fetch_or(1);
+ // First block: run inference
+ if (!already_taken) {
+ auto maybe_array = RunInference(parser);
+ first_inference_run_.MarkFinished();
+ return Future<std::shared_ptr<Array>>::MakeFinished(std::move(maybe_array));
+ }
+
+ // Non-first block: wait for inference to finish on first block now,
+ // without blocking a TaskGroup thread.
+ return first_inference_run_.Then([this, parser] {
+ DCHECK(type_frozen_);
+ auto maybe_array = converter_->Convert(*parser, col_index_);
+ return WrapConversionError(converter_->Convert(*parser, col_index_));
+ });
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Factory functions
+
+Result<std::shared_ptr<ColumnDecoder>> ColumnDecoder::Make(
+ MemoryPool* pool, int32_t col_index, const ConvertOptions& options) {
+ auto ptr = std::make_shared<InferringColumnDecoder>(col_index, options, pool);
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+}
+
+Result<std::shared_ptr<ColumnDecoder>> ColumnDecoder::Make(
+ MemoryPool* pool, std::shared_ptr<DataType> type, int32_t col_index,
+ const ConvertOptions& options) {
+ auto ptr =
+ std::make_shared<TypedColumnDecoder>(std::move(type), col_index, options, pool);
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+}
+
+Result<std::shared_ptr<ColumnDecoder>> ColumnDecoder::MakeNull(
+ MemoryPool* pool, std::shared_ptr<DataType> type) {
+ return std::make_shared<NullColumnDecoder>(std::move(type), pool);
+}
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/column_decoder.h b/src/arrow/cpp/src/arrow/csv/column_decoder.h
new file mode 100644
index 000000000..5fbbd5df5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/column_decoder.h
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "arrow/result.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace csv {
+
+class BlockParser;
+struct ConvertOptions;
+
+class ARROW_EXPORT ColumnDecoder {
+ public:
+ virtual ~ColumnDecoder() = default;
+
+ /// Spawn a task that will try to convert and insert the given CSV block
+ virtual Future<std::shared_ptr<Array>> Decode(
+ const std::shared_ptr<BlockParser>& parser) = 0;
+
+ /// Construct a strictly-typed ColumnDecoder.
+ static Result<std::shared_ptr<ColumnDecoder>> Make(MemoryPool* pool,
+ std::shared_ptr<DataType> type,
+ int32_t col_index,
+ const ConvertOptions& options);
+
+ /// Construct a type-inferring ColumnDecoder.
+ /// Inference will run only on the first block, the type will be frozen afterwards.
+ static Result<std::shared_ptr<ColumnDecoder>> Make(MemoryPool* pool, int32_t col_index,
+ const ConvertOptions& options);
+
+ /// Construct a ColumnDecoder for a column of nulls
+ /// (i.e. not present in the CSV file).
+ static Result<std::shared_ptr<ColumnDecoder>> MakeNull(MemoryPool* pool,
+ std::shared_ptr<DataType> type);
+
+ protected:
+ ColumnDecoder() = default;
+};
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/column_decoder_test.cc b/src/arrow/cpp/src/arrow/csv/column_decoder_test.cc
new file mode 100644
index 000000000..c8b96e046
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/column_decoder_test.cc
@@ -0,0 +1,385 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/csv/column_decoder.h"
+#include "arrow/csv/options.h"
+#include "arrow/csv/test_common.h"
+#include "arrow/memory_pool.h"
+#include "arrow/table.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+namespace csv {
+
+class BlockParser;
+
+using internal::checked_cast;
+using internal::GetCpuThreadPool;
+
+using ChunkData = std::vector<std::vector<std::string>>;
+
+class ThreadJoiner {
+ public:
+ explicit ThreadJoiner(std::shared_ptr<std::thread> thread)
+ : thread_(std::move(thread)) {}
+
+ ~ThreadJoiner() {
+ if (thread_->joinable()) {
+ thread_->join();
+ }
+ }
+
+ protected:
+ std::shared_ptr<std::thread> thread_;
+};
+
+template <typename Func>
+ThreadJoiner RunThread(Func&& func) {
+ return ThreadJoiner(std::make_shared<std::thread>(std::forward<Func>(func)));
+}
+
+template <typename Func>
+void RunThreadsAndJoin(Func&& func, int iters) {
+ std::vector<ThreadJoiner> threads;
+ for (int i = 0; i < iters; i++) {
+ threads.emplace_back(std::make_shared<std::thread>([i, func] { func(i); }));
+ }
+}
+
+class ColumnDecoderTest : public ::testing::Test {
+ public:
+ ColumnDecoderTest() : num_chunks_(0), read_ptr_(0) {}
+
+ void SetDecoder(std::shared_ptr<ColumnDecoder> decoder) {
+ decoder_ = std::move(decoder);
+ decoded_chunks_.clear();
+ num_chunks_ = 0;
+ read_ptr_ = 0;
+ }
+
+ void InsertChunk(std::vector<std::string> chunk) {
+ std::shared_ptr<BlockParser> parser;
+ MakeColumnParser(chunk, &parser);
+ auto decoded = decoder_->Decode(parser);
+ decoded_chunks_.push_back(decoded);
+ ++num_chunks_;
+ }
+
+ void AppendChunks(const ChunkData& chunks) {
+ for (const auto& chunk : chunks) {
+ InsertChunk(chunk);
+ }
+ }
+
+ Result<std::shared_ptr<Array>> NextChunk() {
+ EXPECT_LT(read_ptr_, static_cast<int64_t>(decoded_chunks_.size()));
+ return decoded_chunks_[read_ptr_++].result();
+ }
+
+ void AssertChunk(std::vector<std::string> chunk, std::shared_ptr<Array> expected) {
+ std::shared_ptr<BlockParser> parser;
+ MakeColumnParser(chunk, &parser);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto decoded, decoder_->Decode(parser));
+ AssertArraysEqual(*expected, *decoded);
+ }
+
+ void AssertChunkInvalid(std::vector<std::string> chunk) {
+ std::shared_ptr<BlockParser> parser;
+ MakeColumnParser(chunk, &parser);
+ ASSERT_FINISHES_AND_RAISES(Invalid, decoder_->Decode(parser));
+ }
+
+ void AssertFetch(std::shared_ptr<Array> expected_chunk) {
+ ASSERT_OK_AND_ASSIGN(auto chunk, NextChunk());
+ ASSERT_NE(chunk, nullptr);
+ AssertArraysEqual(*expected_chunk, *chunk);
+ }
+
+ void AssertFetchInvalid() { ASSERT_RAISES(Invalid, NextChunk()); }
+
+ protected:
+ std::shared_ptr<ColumnDecoder> decoder_;
+ std::vector<Future<std::shared_ptr<Array>>> decoded_chunks_;
+ int64_t num_chunks_ = 0;
+ int64_t read_ptr_ = 0;
+
+ ConvertOptions default_options = ConvertOptions::Defaults();
+};
+
+//////////////////////////////////////////////////////////////////////////
+// Tests for null column decoder
+
+class NullColumnDecoderTest : public ColumnDecoderTest {
+ public:
+ NullColumnDecoderTest() {}
+
+ void MakeDecoder(std::shared_ptr<DataType> type) {
+ ASSERT_OK_AND_ASSIGN(auto decoder,
+ ColumnDecoder::MakeNull(default_memory_pool(), type));
+ SetDecoder(decoder);
+ }
+
+ void TestNullType() {
+ auto type = null();
+
+ MakeDecoder(type);
+
+ AppendChunks({{"1", "2", "3"}, {"4", "5"}});
+ AssertFetch(ArrayFromJSON(type, "[null, null, null]"));
+ AssertFetch(ArrayFromJSON(type, "[null, null]"));
+
+ MakeDecoder(type);
+
+ AppendChunks({{}, {"6"}});
+ AssertFetch(ArrayFromJSON(type, "[]"));
+ AppendChunks({{"7", "8"}});
+ AssertFetch(ArrayFromJSON(type, "[null]"));
+ AssertFetch(ArrayFromJSON(type, "[null, null]"));
+ }
+
+ void TestOtherType() {
+ auto type = int32();
+
+ MakeDecoder(type);
+
+ AppendChunks({{"1", "2", "3"}, {"4", "5"}});
+ AssertFetch(ArrayFromJSON(type, "[null, null, null]"));
+ AssertFetch(ArrayFromJSON(type, "[null, null]"));
+ }
+
+ void TestThreaded() {
+ constexpr int NITERS = 10;
+ auto type = int32();
+ MakeDecoder(type);
+
+ RunThreadsAndJoin(
+ [&](int thread_id) {
+ AssertChunk({"4", "5", std::to_string(thread_id)},
+ ArrayFromJSON(type, "[null, null, null]"));
+ },
+ NITERS);
+ }
+};
+
+TEST_F(NullColumnDecoderTest, NullType) { this->TestNullType(); }
+
+TEST_F(NullColumnDecoderTest, OtherType) { this->TestOtherType(); }
+
+TEST_F(NullColumnDecoderTest, Threaded) { this->TestThreaded(); }
+
+//////////////////////////////////////////////////////////////////////////
+// Tests for fixed-type column decoder
+
+class TypedColumnDecoderTest : public ColumnDecoderTest {
+ public:
+ TypedColumnDecoderTest() {}
+
+ void MakeDecoder(const std::shared_ptr<DataType>& type, const ConvertOptions& options) {
+ ASSERT_OK_AND_ASSIGN(auto decoder,
+ ColumnDecoder::Make(default_memory_pool(), type, 0, options));
+ SetDecoder(decoder);
+ }
+
+ void TestIntegers() {
+ auto type = int16();
+
+ MakeDecoder(type, default_options);
+
+ AppendChunks({{"123", "456", "-78"}, {"901", "N/A"}});
+ AssertFetch(ArrayFromJSON(type, "[123, 456, -78]"));
+ AssertFetch(ArrayFromJSON(type, "[901, null]"));
+
+ MakeDecoder(type, default_options);
+
+ AppendChunks({{}, {"-987"}});
+ AssertFetch(ArrayFromJSON(type, "[]"));
+ AppendChunks({{"N/A", "N/A"}});
+ AssertFetch(ArrayFromJSON(type, "[-987]"));
+ AssertFetch(ArrayFromJSON(type, "[null, null]"));
+ }
+
+ void TestOptions() {
+ auto type = boolean();
+
+ MakeDecoder(type, default_options);
+
+ AppendChunks({{"true", "false", "N/A"}});
+ AssertFetch(ArrayFromJSON(type, "[true, false, null]"));
+
+ // With non-default options
+ auto options = default_options;
+ options.null_values = {"true"};
+ options.true_values = {"false"};
+ options.false_values = {"N/A"};
+ MakeDecoder(type, options);
+
+ AppendChunks({{"true", "false", "N/A"}});
+ AssertFetch(ArrayFromJSON(type, "[null, true, false]"));
+ }
+
+ void TestErrors() {
+ auto type = uint64();
+
+ MakeDecoder(type, default_options);
+
+ AppendChunks({{"123", "456", "N/A"}, {"-901"}});
+ AppendChunks({{"N/A", "1000"}});
+ AssertFetch(ArrayFromJSON(type, "[123, 456, null]"));
+ AssertFetchInvalid();
+ AssertFetch(ArrayFromJSON(type, "[null, 1000]"));
+ }
+
+ void TestThreaded() {
+ constexpr int NITERS = 10;
+ auto type = uint32();
+ MakeDecoder(type, default_options);
+
+ RunThreadsAndJoin(
+ [&](int thread_id) {
+ if (thread_id % 2 == 0) {
+ AssertChunkInvalid({"4", "-5"});
+ } else {
+ AssertChunk({"1", "2", "3"}, ArrayFromJSON(type, "[1, 2, 3]"));
+ }
+ },
+ NITERS);
+ }
+};
+
+TEST_F(TypedColumnDecoderTest, Integers) { this->TestIntegers(); }
+
+TEST_F(TypedColumnDecoderTest, Options) { this->TestOptions(); }
+
+TEST_F(TypedColumnDecoderTest, Errors) { this->TestErrors(); }
+
+TEST_F(TypedColumnDecoderTest, Threaded) { this->TestThreaded(); }
+
+//////////////////////////////////////////////////////////////////////////
+// Tests for type-inferring column decoder
+
+class InferringColumnDecoderTest : public ColumnDecoderTest {
+ public:
+ InferringColumnDecoderTest() {}
+
+ void MakeDecoder(const ConvertOptions& options) {
+ ASSERT_OK_AND_ASSIGN(auto decoder,
+ ColumnDecoder::Make(default_memory_pool(), 0, options));
+ SetDecoder(decoder);
+ }
+
+ void TestIntegers() {
+ auto type = int64();
+
+ MakeDecoder(default_options);
+
+ AppendChunks({{"123", "456", "-78"}, {"901", "N/A"}});
+ AssertFetch(ArrayFromJSON(type, "[123, 456, -78]"));
+ AssertFetch(ArrayFromJSON(type, "[901, null]"));
+ }
+
+ void TestThreaded() {
+ constexpr int NITERS = 10;
+ auto type = float64();
+ MakeDecoder(default_options);
+
+ // One of these will do the inference so we need to make sure they all have floating
+ // point
+ RunThreadsAndJoin(
+ [&](int thread_id) {
+ if (thread_id % 2 == 0) {
+ AssertChunk({"6.3", "7.2"}, ArrayFromJSON(type, "[6.3, 7.2]"));
+ } else {
+ AssertChunk({"1.1", "2", "3"}, ArrayFromJSON(type, "[1.1, 2, 3]"));
+ }
+ },
+ NITERS);
+
+ // These will run after the inference
+ RunThreadsAndJoin(
+ [&](int thread_id) {
+ if (thread_id % 2 == 0) {
+ AssertChunk({"1", "2"}, ArrayFromJSON(type, "[1, 2]"));
+ } else {
+ AssertChunkInvalid({"xyz"});
+ }
+ },
+ NITERS);
+ }
+
+ void TestOptions() {
+ auto type = boolean();
+
+ auto options = default_options;
+ options.null_values = {"true"};
+ options.true_values = {"false"};
+ options.false_values = {"N/A"};
+ MakeDecoder(options);
+
+ AppendChunks({{"true", "false", "N/A"}, {"true"}});
+ AssertFetch(ArrayFromJSON(type, "[null, true, false]"));
+ AssertFetch(ArrayFromJSON(type, "[null]"));
+ }
+
+ void TestErrors() {
+ auto type = int64();
+
+ MakeDecoder(default_options);
+
+ AppendChunks({{"123", "456", "-78"}, {"9.5", "N/A"}});
+ AppendChunks({{"1000", "N/A"}});
+ AssertFetch(ArrayFromJSON(type, "[123, 456, -78]"));
+ AssertFetchInvalid();
+ AssertFetch(ArrayFromJSON(type, "[1000, null]"));
+ }
+
+ void TestEmpty() {
+ auto type = null();
+
+ MakeDecoder(default_options);
+
+ AppendChunks({{}, {}});
+ AssertFetch(ArrayFromJSON(type, "[]"));
+ AssertFetch(ArrayFromJSON(type, "[]"));
+ }
+};
+
+TEST_F(InferringColumnDecoderTest, Integers) { this->TestIntegers(); }
+
+TEST_F(InferringColumnDecoderTest, Threaded) { this->TestThreaded(); }
+
+TEST_F(InferringColumnDecoderTest, Options) { this->TestOptions(); }
+
+TEST_F(InferringColumnDecoderTest, Errors) { this->TestErrors(); }
+
+TEST_F(InferringColumnDecoderTest, Empty) { this->TestEmpty(); }
+
+// More inference tests are in InferringColumnBuilderTest
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/converter.cc b/src/arrow/cpp/src/arrow/csv/converter.cc
new file mode 100644
index 000000000..66d054580
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/converter.cc
@@ -0,0 +1,780 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/csv/converter.h"
+
+#include <array>
+#include <cstring>
+#include <limits>
+#include <sstream>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_decimal.h"
+#include "arrow/array/builder_dict.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/csv/parser.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/trie.h"
+#include "arrow/util/utf8.h"
+#include "arrow/util/value_parsing.h" // IWYU pragma: keep
+
+namespace arrow {
+namespace csv {
+
+using internal::checked_cast;
+using internal::Trie;
+using internal::TrieBuilder;
+
+namespace {
+
+Status GenericConversionError(const std::shared_ptr<DataType>& type, const uint8_t* data,
+ uint32_t size) {
+ return Status::Invalid("CSV conversion error to ", type->ToString(),
+ ": invalid value '",
+ std::string(reinterpret_cast<const char*>(data), size), "'");
+}
+
+inline bool IsWhitespace(uint8_t c) {
+ if (ARROW_PREDICT_TRUE(c > ' ')) {
+ return false;
+ }
+ return c == ' ' || c == '\t';
+}
+
+// Updates data_inout and size_inout to not include leading/trailing whitespace
+// characters.
+inline void TrimWhiteSpace(const uint8_t** data_inout, uint32_t* size_inout) {
+ const uint8_t*& data = *data_inout;
+ uint32_t& size = *size_inout;
+ // Skip trailing whitespace
+ if (ARROW_PREDICT_TRUE(size > 0) && ARROW_PREDICT_FALSE(IsWhitespace(data[size - 1]))) {
+ const uint8_t* p = data + size - 1;
+ while (size > 0 && IsWhitespace(*p)) {
+ --size;
+ --p;
+ }
+ }
+ // Skip leading whitespace
+ if (ARROW_PREDICT_TRUE(size > 0) && ARROW_PREDICT_FALSE(IsWhitespace(data[0]))) {
+ while (size > 0 && IsWhitespace(*data)) {
+ --size;
+ ++data;
+ }
+ }
+}
+
+Status InitializeTrie(const std::vector<std::string>& inputs, Trie* trie) {
+ TrieBuilder builder;
+ for (const auto& s : inputs) {
+ RETURN_NOT_OK(builder.Append(s, true /* allow_duplicates */));
+ }
+ *trie = builder.Finish();
+ return Status::OK();
+}
+
+// Presize a builder based on parser contents
+template <typename BuilderType>
+enable_if_t<!is_base_binary_type<typename BuilderType::TypeClass>::value, Status>
+PresizeBuilder(const BlockParser& parser, BuilderType* builder) {
+ return builder->Resize(parser.num_rows());
+}
+
+// Same, for variable-sized binary builders
+template <typename T>
+Status PresizeBuilder(const BlockParser& parser, BaseBinaryBuilder<T>* builder) {
+ RETURN_NOT_OK(builder->Resize(parser.num_rows()));
+ return builder->ReserveData(parser.num_bytes());
+}
+
+/////////////////////////////////////////////////////////////////////////
+// Per-type value decoders
+
+struct ValueDecoder {
+ explicit ValueDecoder(const std::shared_ptr<DataType>& type,
+ const ConvertOptions& options)
+ : type_(type), options_(options) {}
+
+ Status Initialize() {
+ // TODO no need to build a separate Trie for each instance
+ return InitializeTrie(options_.null_values, &null_trie_);
+ }
+
+ bool IsNull(const uint8_t* data, uint32_t size, bool quoted) {
+ if (quoted && !options_.quoted_strings_can_be_null) {
+ return false;
+ }
+ return null_trie_.Find(
+ util::string_view(reinterpret_cast<const char*>(data), size)) >= 0;
+ }
+
+ protected:
+ Trie null_trie_;
+ const std::shared_ptr<DataType> type_;
+ const ConvertOptions& options_;
+};
+
+//
+// Value decoder for fixed-size binary
+//
+
+struct FixedSizeBinaryValueDecoder : public ValueDecoder {
+ using value_type = const uint8_t*;
+
+ explicit FixedSizeBinaryValueDecoder(const std::shared_ptr<DataType>& type,
+ const ConvertOptions& options)
+ : ValueDecoder(type, options),
+ byte_width_(checked_cast<const FixedSizeBinaryType&>(*type).byte_width()) {}
+
+ Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) {
+ if (ARROW_PREDICT_FALSE(size != byte_width_)) {
+ return Status::Invalid("CSV conversion error to ", type_->ToString(), ": got a ",
+ size, "-byte long string");
+ }
+ *out = data;
+ return Status::OK();
+ }
+
+ protected:
+ const uint32_t byte_width_;
+};
+
+//
+// Value decoder for variable-size binary
+//
+
+template <bool CheckUTF8>
+struct BinaryValueDecoder : public ValueDecoder {
+ using value_type = util::string_view;
+
+ using ValueDecoder::ValueDecoder;
+
+ Status Initialize() {
+ util::InitializeUTF8();
+ return ValueDecoder::Initialize();
+ }
+
+ Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) {
+ if (CheckUTF8 && ARROW_PREDICT_FALSE(!util::ValidateUTF8(data, size))) {
+ return Status::Invalid("CSV conversion error to ", type_->ToString(),
+ ": invalid UTF8 data");
+ }
+ *out = {reinterpret_cast<const char*>(data), size};
+ return Status::OK();
+ }
+
+ bool IsNull(const uint8_t* data, uint32_t size, bool quoted) {
+ return options_.strings_can_be_null &&
+ (!quoted || options_.quoted_strings_can_be_null) &&
+ ValueDecoder::IsNull(data, size, false /* quoted */);
+ }
+};
+
+//
+// Value decoder for integers, floats and temporals
+//
+
+template <typename T>
+struct NumericValueDecoder : public ValueDecoder {
+ using value_type = typename T::c_type;
+
+ explicit NumericValueDecoder(const std::shared_ptr<DataType>& type,
+ const ConvertOptions& options)
+ : ValueDecoder(type, options), concrete_type_(checked_cast<const T&>(*type)) {}
+
+ Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) {
+ // XXX should quoted values be allowed at all?
+ TrimWhiteSpace(&data, &size);
+ if (ARROW_PREDICT_FALSE(!internal::ParseValue<T>(
+ concrete_type_, reinterpret_cast<const char*>(data), size, out))) {
+ return GenericConversionError(type_, data, size);
+ }
+ return Status::OK();
+ }
+
+ protected:
+ const T& concrete_type_;
+};
+
+//
+// Value decoder for booleans
+//
+
+struct BooleanValueDecoder : public ValueDecoder {
+ using value_type = bool;
+
+ using ValueDecoder::ValueDecoder;
+
+ Status Initialize() {
+ // TODO no need to build separate Tries for each instance
+ RETURN_NOT_OK(InitializeTrie(options_.true_values, &true_trie_));
+ RETURN_NOT_OK(InitializeTrie(options_.false_values, &false_trie_));
+ return ValueDecoder::Initialize();
+ }
+
+ Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) {
+ // XXX should quoted values be allowed at all?
+ if (false_trie_.Find(util::string_view(reinterpret_cast<const char*>(data), size)) >=
+ 0) {
+ *out = false;
+ return Status::OK();
+ }
+ if (ARROW_PREDICT_TRUE(true_trie_.Find(util::string_view(
+ reinterpret_cast<const char*>(data), size)) >= 0)) {
+ *out = true;
+ return Status::OK();
+ }
+ return GenericConversionError(type_, data, size);
+ }
+
+ protected:
+ Trie true_trie_;
+ Trie false_trie_;
+};
+
+//
+// Value decoder for decimals
+//
+
+struct DecimalValueDecoder : public ValueDecoder {
+ using value_type = Decimal128;
+
+ explicit DecimalValueDecoder(const std::shared_ptr<DataType>& type,
+ const ConvertOptions& options)
+ : ValueDecoder(type, options),
+ decimal_type_(internal::checked_cast<const DecimalType&>(*type_)),
+ type_precision_(decimal_type_.precision()),
+ type_scale_(decimal_type_.scale()) {}
+
+ Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) {
+ TrimWhiteSpace(&data, &size);
+ Decimal128 decimal;
+ int32_t precision, scale;
+ util::string_view view(reinterpret_cast<const char*>(data), size);
+ RETURN_NOT_OK(Decimal128::FromString(view, &decimal, &precision, &scale));
+ if (precision > type_precision_) {
+ return Status::Invalid("Error converting '", view, "' to ", type_->ToString(),
+ ": precision not supported by type.");
+ }
+ if (scale != type_scale_) {
+ ARROW_ASSIGN_OR_RAISE(*out, decimal.Rescale(scale, type_scale_));
+ } else {
+ *out = std::move(decimal);
+ }
+ return Status::OK();
+ }
+
+ protected:
+ const DecimalType& decimal_type_;
+ const int32_t type_precision_;
+ const int32_t type_scale_;
+};
+
+//
+// Value decoder wrapper for floating-point and decimals
+// with a non-default decimal point
+//
+
+template <typename WrappedDecoder>
+struct CustomDecimalPointValueDecoder : public ValueDecoder {
+ using value_type = typename WrappedDecoder::value_type;
+
+ explicit CustomDecimalPointValueDecoder(const std::shared_ptr<DataType>& type,
+ const ConvertOptions& options)
+ : ValueDecoder(type, options), wrapped_decoder_(type, options) {}
+
+ Status Initialize() {
+ RETURN_NOT_OK(wrapped_decoder_.Initialize());
+ for (int i = 0; i < 256; ++i) {
+ mapping_[i] = i;
+ }
+ mapping_[options_.decimal_point] = '.';
+ mapping_['.'] = options_.decimal_point; // error out on standard decimal point
+ temp_.resize(30);
+ return Status::OK();
+ }
+
+ Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) {
+ if (ARROW_PREDICT_FALSE(size > temp_.size())) {
+ temp_.resize(size);
+ }
+ uint8_t* temp_data = temp_.data();
+ for (uint32_t i = 0; i < size; ++i) {
+ temp_data[i] = mapping_[data[i]];
+ }
+ if (ARROW_PREDICT_FALSE(
+ !wrapped_decoder_.Decode(temp_data, size, quoted, out).ok())) {
+ return GenericConversionError(type_, data, size);
+ }
+ return Status::OK();
+ }
+
+ bool IsNull(const uint8_t* data, uint32_t size, bool quoted) {
+ return wrapped_decoder_.IsNull(data, size, quoted);
+ }
+
+ protected:
+ WrappedDecoder wrapped_decoder_;
+ std::array<uint8_t, 256> mapping_;
+ std::vector<uint8_t> temp_;
+};
+
+//
+// Value decoders for timestamps
+//
+
+struct InlineISO8601ValueDecoder : public ValueDecoder {
+ using value_type = int64_t;
+
+ explicit InlineISO8601ValueDecoder(const std::shared_ptr<DataType>& type,
+ const ConvertOptions& options)
+ : ValueDecoder(type, options),
+ unit_(checked_cast<const TimestampType&>(*type_).unit()) {}
+
+ Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) {
+ if (ARROW_PREDICT_FALSE(!internal::ParseTimestampISO8601(
+ reinterpret_cast<const char*>(data), size, unit_, out))) {
+ return GenericConversionError(type_, data, size);
+ }
+ return Status::OK();
+ }
+
+ protected:
+ TimeUnit::type unit_;
+};
+
+struct SingleParserTimestampValueDecoder : public ValueDecoder {
+ using value_type = int64_t;
+
+ explicit SingleParserTimestampValueDecoder(const std::shared_ptr<DataType>& type,
+ const ConvertOptions& options)
+ : ValueDecoder(type, options),
+ unit_(checked_cast<const TimestampType&>(*type_).unit()),
+ parser_(*options_.timestamp_parsers[0]) {}
+
+ Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) {
+ if (ARROW_PREDICT_FALSE(
+ !parser_(reinterpret_cast<const char*>(data), size, unit_, out))) {
+ return GenericConversionError(type_, data, size);
+ }
+ return Status::OK();
+ }
+
+ protected:
+ TimeUnit::type unit_;
+ const TimestampParser& parser_;
+};
+
+struct MultipleParsersTimestampValueDecoder : public ValueDecoder {
+ using value_type = int64_t;
+
+ explicit MultipleParsersTimestampValueDecoder(const std::shared_ptr<DataType>& type,
+ const ConvertOptions& options)
+ : ValueDecoder(type, options),
+ unit_(checked_cast<const TimestampType&>(*type_).unit()),
+ parsers_(GetParsers(options_)) {}
+
+ Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) {
+ for (const auto& parser : parsers_) {
+ if (parser->operator()(reinterpret_cast<const char*>(data), size, unit_, out)) {
+ return Status::OK();
+ }
+ }
+ return GenericConversionError(type_, data, size);
+ }
+
+ protected:
+ using ParserVector = std::vector<const TimestampParser*>;
+
+ static ParserVector GetParsers(const ConvertOptions& options) {
+ ParserVector parsers(options.timestamp_parsers.size());
+ for (size_t i = 0; i < options.timestamp_parsers.size(); ++i) {
+ parsers[i] = options.timestamp_parsers[i].get();
+ }
+ return parsers;
+ }
+
+ TimeUnit::type unit_;
+ std::vector<const TimestampParser*> parsers_;
+};
+
+/////////////////////////////////////////////////////////////////////////
+// Concrete Converter hierarchy
+
+class ConcreteConverter : public Converter {
+ public:
+ using Converter::Converter;
+};
+
+class ConcreteDictionaryConverter : public DictionaryConverter {
+ public:
+ using DictionaryConverter::DictionaryConverter;
+};
+
+//
+// Concrete Converter for nulls
+//
+
+class NullConverter : public ConcreteConverter {
+ public:
+ NullConverter(const std::shared_ptr<DataType>& type, const ConvertOptions& options,
+ MemoryPool* pool)
+ : ConcreteConverter(type, options, pool), decoder_(type_, options_) {}
+
+ Result<std::shared_ptr<Array>> Convert(const BlockParser& parser,
+ int32_t col_index) override {
+ NullBuilder builder(pool_);
+
+ auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) -> Status {
+ if (ARROW_PREDICT_TRUE(decoder_.IsNull(data, size, quoted))) {
+ return builder.AppendNull();
+ } else {
+ return GenericConversionError(type_, data, size);
+ }
+ };
+ RETURN_NOT_OK(parser.VisitColumn(col_index, visit));
+ std::shared_ptr<Array> res;
+ RETURN_NOT_OK(builder.Finish(&res));
+ return res;
+ }
+
+ protected:
+ Status Initialize() override { return decoder_.Initialize(); }
+
+ ValueDecoder decoder_;
+};
+
+//
+// Concrete Converter for primitives
+//
+
+template <typename T, typename ValueDecoderType>
+class PrimitiveConverter : public ConcreteConverter {
+ public:
+ PrimitiveConverter(const std::shared_ptr<DataType>& type, const ConvertOptions& options,
+ MemoryPool* pool)
+ : ConcreteConverter(type, options, pool), decoder_(type_, options_) {}
+
+ Result<std::shared_ptr<Array>> Convert(const BlockParser& parser,
+ int32_t col_index) override {
+ using BuilderType = typename TypeTraits<T>::BuilderType;
+ using value_type = typename ValueDecoderType::value_type;
+
+ BuilderType builder(type_, pool_);
+ RETURN_NOT_OK(PresizeBuilder(parser, &builder));
+
+ auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) -> Status {
+ if (decoder_.IsNull(data, size, quoted /* quoted */)) {
+ return builder.AppendNull();
+ }
+ value_type value{};
+ RETURN_NOT_OK(decoder_.Decode(data, size, quoted, &value));
+ builder.UnsafeAppend(value);
+ return Status::OK();
+ };
+ RETURN_NOT_OK(parser.VisitColumn(col_index, visit));
+
+ std::shared_ptr<Array> res;
+ RETURN_NOT_OK(builder.Finish(&res));
+ return res;
+ }
+
+ protected:
+ Status Initialize() override { return decoder_.Initialize(); }
+
+ ValueDecoderType decoder_;
+};
+
+//
+// Concrete Converter for dictionaries
+//
+
+template <typename T, typename ValueDecoderType>
+class TypedDictionaryConverter : public ConcreteDictionaryConverter {
+ public:
+ TypedDictionaryConverter(const std::shared_ptr<DataType>& value_type,
+ const ConvertOptions& options, MemoryPool* pool)
+ : ConcreteDictionaryConverter(value_type, options, pool),
+ decoder_(value_type, options_) {}
+
+ Result<std::shared_ptr<Array>> Convert(const BlockParser& parser,
+ int32_t col_index) override {
+ // We use a fixed index width so that all column chunks get the same index type
+ using BuilderType = Dictionary32Builder<T>;
+ using value_type = typename ValueDecoderType::value_type;
+
+ BuilderType builder(value_type_, pool_);
+ RETURN_NOT_OK(PresizeBuilder(parser, &builder));
+
+ auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) -> Status {
+ if (decoder_.IsNull(data, size, quoted /* quoted */)) {
+ return builder.AppendNull();
+ }
+ if (ARROW_PREDICT_FALSE(builder.dictionary_length() > max_cardinality_)) {
+ return Status::IndexError("Dictionary length exceeded max cardinality");
+ }
+ value_type value{};
+ RETURN_NOT_OK(decoder_.Decode(data, size, quoted, &value));
+ return builder.Append(value);
+ };
+ RETURN_NOT_OK(parser.VisitColumn(col_index, visit));
+
+ std::shared_ptr<Array> res;
+ RETURN_NOT_OK(builder.Finish(&res));
+ return res;
+ }
+
+ void SetMaxCardinality(int32_t max_length) override { max_cardinality_ = max_length; }
+
+ protected:
+ Status Initialize() override {
+ util::InitializeUTF8();
+ return decoder_.Initialize();
+ }
+
+ ValueDecoderType decoder_;
+ int32_t max_cardinality_ = std::numeric_limits<int32_t>::max();
+};
+
+//
+// Concrete Converter factory for timestamps
+//
+
+template <template <typename, typename> class ConverterType>
+std::shared_ptr<Converter> MakeTimestampConverter(const std::shared_ptr<DataType>& type,
+ const ConvertOptions& options,
+ MemoryPool* pool) {
+ if (options.timestamp_parsers.size() == 0) {
+ // Default to ISO-8601
+ return std::make_shared<ConverterType<TimestampType, InlineISO8601ValueDecoder>>(
+ type, options, pool);
+ } else if (options.timestamp_parsers.size() == 1) {
+ // Single user-supplied converter
+ return std::make_shared<
+ ConverterType<TimestampType, SingleParserTimestampValueDecoder>>(type, options,
+ pool);
+ } else {
+ // Multiple converters, must iterate for each value
+ return std::make_shared<
+ ConverterType<TimestampType, MultipleParsersTimestampValueDecoder>>(type, options,
+ pool);
+ }
+}
+
+//
+// Concrete Converter factory for reals
+//
+
+template <typename ConverterType, template <typename...> class ConcreteConverterType,
+ typename Type, typename DecoderType>
+std::shared_ptr<ConverterType> MakeRealConverter(const std::shared_ptr<DataType>& type,
+ const ConvertOptions& options,
+ MemoryPool* pool) {
+ if (options.decimal_point == '.') {
+ return std::make_shared<ConcreteConverterType<Type, DecoderType>>(type, options,
+ pool);
+ }
+ return std::make_shared<
+ ConcreteConverterType<Type, CustomDecimalPointValueDecoder<DecoderType>>>(
+ type, options, pool);
+}
+
+} // namespace
+
+/////////////////////////////////////////////////////////////////////////
+// Base Converter class implementation
+
+Converter::Converter(const std::shared_ptr<DataType>& type, const ConvertOptions& options,
+ MemoryPool* pool)
+ : options_(options), pool_(pool), type_(type) {}
+
+DictionaryConverter::DictionaryConverter(const std::shared_ptr<DataType>& value_type,
+ const ConvertOptions& options, MemoryPool* pool)
+ : Converter(dictionary(int32(), value_type), options, pool),
+ value_type_(value_type) {}
+
+Result<std::shared_ptr<Converter>> Converter::Make(const std::shared_ptr<DataType>& type,
+ const ConvertOptions& options,
+ MemoryPool* pool) {
+ std::shared_ptr<Converter> ptr;
+
+ switch (type->id()) {
+#define CONVERTER_CASE(TYPE_ID, CONVERTER_TYPE) \
+ case TYPE_ID: \
+ ptr.reset(new CONVERTER_TYPE(type, options, pool)); \
+ break;
+
+#define NUMERIC_CONVERTER_CASE(TYPE_ID, TYPE_CLASS) \
+ CONVERTER_CASE(TYPE_ID, \
+ (PrimitiveConverter<TYPE_CLASS, NumericValueDecoder<TYPE_CLASS>>))
+
+#define REAL_CONVERTER_CASE(TYPE_ID, TYPE_CLASS, DECODER) \
+ case TYPE_ID: \
+ ptr = MakeRealConverter<Converter, PrimitiveConverter, TYPE_CLASS, DECODER>( \
+ type, options, pool); \
+ break;
+
+ CONVERTER_CASE(Type::NA, NullConverter)
+ NUMERIC_CONVERTER_CASE(Type::INT8, Int8Type)
+ NUMERIC_CONVERTER_CASE(Type::INT16, Int16Type)
+ NUMERIC_CONVERTER_CASE(Type::INT32, Int32Type)
+ NUMERIC_CONVERTER_CASE(Type::INT64, Int64Type)
+ NUMERIC_CONVERTER_CASE(Type::UINT8, UInt8Type)
+ NUMERIC_CONVERTER_CASE(Type::UINT16, UInt16Type)
+ NUMERIC_CONVERTER_CASE(Type::UINT32, UInt32Type)
+ NUMERIC_CONVERTER_CASE(Type::UINT64, UInt64Type)
+ REAL_CONVERTER_CASE(Type::FLOAT, FloatType, NumericValueDecoder<FloatType>)
+ REAL_CONVERTER_CASE(Type::DOUBLE, DoubleType, NumericValueDecoder<DoubleType>)
+ REAL_CONVERTER_CASE(Type::DECIMAL, Decimal128Type, DecimalValueDecoder)
+ NUMERIC_CONVERTER_CASE(Type::DATE32, Date32Type)
+ NUMERIC_CONVERTER_CASE(Type::DATE64, Date64Type)
+ NUMERIC_CONVERTER_CASE(Type::TIME32, Time32Type)
+ NUMERIC_CONVERTER_CASE(Type::TIME64, Time64Type)
+ CONVERTER_CASE(Type::BOOL, (PrimitiveConverter<BooleanType, BooleanValueDecoder>))
+ CONVERTER_CASE(Type::BINARY,
+ (PrimitiveConverter<BinaryType, BinaryValueDecoder<false>>))
+ CONVERTER_CASE(Type::LARGE_BINARY,
+ (PrimitiveConverter<LargeBinaryType, BinaryValueDecoder<false>>))
+ CONVERTER_CASE(Type::FIXED_SIZE_BINARY,
+ (PrimitiveConverter<FixedSizeBinaryType, FixedSizeBinaryValueDecoder>))
+
+ case Type::TIMESTAMP:
+ ptr = MakeTimestampConverter<PrimitiveConverter>(type, options, pool);
+ break;
+
+ case Type::STRING:
+ if (options.check_utf8) {
+ ptr = std::make_shared<PrimitiveConverter<StringType, BinaryValueDecoder<true>>>(
+ type, options, pool);
+ } else {
+ ptr = std::make_shared<PrimitiveConverter<StringType, BinaryValueDecoder<false>>>(
+ type, options, pool);
+ }
+ break;
+
+ case Type::LARGE_STRING:
+ if (options.check_utf8) {
+ ptr = std::make_shared<
+ PrimitiveConverter<LargeStringType, BinaryValueDecoder<true>>>(type, options,
+ pool);
+ } else {
+ ptr = std::make_shared<
+ PrimitiveConverter<LargeStringType, BinaryValueDecoder<false>>>(type, options,
+ pool);
+ }
+ break;
+
+ case Type::DICTIONARY: {
+ const auto& dict_type = checked_cast<const DictionaryType&>(*type);
+ if (dict_type.index_type()->id() != Type::INT32) {
+ return Status::NotImplemented(
+ "CSV conversion to dictionary only supported for int32 indices, "
+ "got ",
+ type->ToString());
+ }
+ return DictionaryConverter::Make(dict_type.value_type(), options, pool);
+ }
+
+ default: {
+ return Status::NotImplemented("CSV conversion to ", type->ToString(),
+ " is not supported");
+ }
+
+#undef CONVERTER_CASE
+#undef NUMERIC_CONVERTER_CASE
+#undef REAL_CONVERTER_CASE
+ }
+ RETURN_NOT_OK(ptr->Initialize());
+ return ptr;
+}
+
+Result<std::shared_ptr<DictionaryConverter>> DictionaryConverter::Make(
+ const std::shared_ptr<DataType>& type, const ConvertOptions& options,
+ MemoryPool* pool) {
+ std::shared_ptr<DictionaryConverter> ptr;
+
+ switch (type->id()) {
+#define CONVERTER_CASE(TYPE_ID, TYPE, VALUE_DECODER_TYPE) \
+ case TYPE_ID: \
+ ptr.reset( \
+ new TypedDictionaryConverter<TYPE, VALUE_DECODER_TYPE>(type, options, pool)); \
+ break;
+
+#define REAL_CONVERTER_CASE(TYPE_ID, TYPE_CLASS, DECODER) \
+ case TYPE_ID: \
+ ptr = MakeRealConverter<DictionaryConverter, TypedDictionaryConverter, TYPE_CLASS, \
+ DECODER>(type, options, pool); \
+ break;
+
+ // XXX Are 32-bit types useful?
+ CONVERTER_CASE(Type::INT32, Int32Type, NumericValueDecoder<Int32Type>)
+ CONVERTER_CASE(Type::INT64, Int64Type, NumericValueDecoder<Int64Type>)
+ CONVERTER_CASE(Type::UINT32, UInt32Type, NumericValueDecoder<UInt32Type>)
+ CONVERTER_CASE(Type::UINT64, UInt64Type, NumericValueDecoder<UInt64Type>)
+ REAL_CONVERTER_CASE(Type::FLOAT, FloatType, NumericValueDecoder<FloatType>)
+ REAL_CONVERTER_CASE(Type::DOUBLE, DoubleType, NumericValueDecoder<DoubleType>)
+ REAL_CONVERTER_CASE(Type::DECIMAL, Decimal128Type, DecimalValueDecoder)
+ CONVERTER_CASE(Type::FIXED_SIZE_BINARY, FixedSizeBinaryType,
+ FixedSizeBinaryValueDecoder)
+ CONVERTER_CASE(Type::BINARY, BinaryType, BinaryValueDecoder<false>)
+ CONVERTER_CASE(Type::LARGE_BINARY, LargeBinaryType, BinaryValueDecoder<false>)
+
+ case Type::STRING:
+ if (options.check_utf8) {
+ ptr = std::make_shared<
+ TypedDictionaryConverter<StringType, BinaryValueDecoder<true>>>(type, options,
+ pool);
+ } else {
+ ptr = std::make_shared<
+ TypedDictionaryConverter<StringType, BinaryValueDecoder<false>>>(
+ type, options, pool);
+ }
+ break;
+
+ case Type::LARGE_STRING:
+ if (options.check_utf8) {
+ ptr = std::make_shared<
+ TypedDictionaryConverter<LargeStringType, BinaryValueDecoder<true>>>(
+ type, options, pool);
+ } else {
+ ptr = std::make_shared<
+ TypedDictionaryConverter<LargeStringType, BinaryValueDecoder<false>>>(
+ type, options, pool);
+ }
+ break;
+
+ default: {
+ return Status::NotImplemented("CSV dictionary conversion to ", type->ToString(),
+ " is not supported");
+ }
+
+#undef CONVERTER_CASE
+#undef REAL_CONVERTER_CASE
+ }
+ RETURN_NOT_OK(ptr->Initialize());
+ return ptr;
+}
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/converter.h b/src/arrow/cpp/src/arrow/csv/converter.h
new file mode 100644
index 000000000..639f692f2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/converter.h
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/csv/options.h"
+#include "arrow/result.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace csv {
+
+class BlockParser;
+
+class ARROW_EXPORT Converter {
+ public:
+ Converter(const std::shared_ptr<DataType>& type, const ConvertOptions& options,
+ MemoryPool* pool);
+ virtual ~Converter() = default;
+
+ virtual Result<std::shared_ptr<Array>> Convert(const BlockParser& parser,
+ int32_t col_index) = 0;
+
+ std::shared_ptr<DataType> type() const { return type_; }
+
+ // Create a Converter for the given data type
+ static Result<std::shared_ptr<Converter>> Make(
+ const std::shared_ptr<DataType>& type, const ConvertOptions& options,
+ MemoryPool* pool = default_memory_pool());
+
+ protected:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Converter);
+
+ virtual Status Initialize() = 0;
+
+ // CAUTION: ConvertOptions can grow large (if it customizes hundreds or
+ // thousands of columns), so avoid copying it in each Converter.
+ const ConvertOptions& options_;
+ MemoryPool* pool_;
+ std::shared_ptr<DataType> type_;
+};
+
+class ARROW_EXPORT DictionaryConverter : public Converter {
+ public:
+ DictionaryConverter(const std::shared_ptr<DataType>& value_type,
+ const ConvertOptions& options, MemoryPool* pool);
+
+ // If the dictionary length goes above this value, conversion will fail
+ // with Status::IndexError.
+ virtual void SetMaxCardinality(int32_t max_length) = 0;
+
+ // Create a Converter for the given dictionary value type.
+ // The dictionary index type will always be Int32.
+ static Result<std::shared_ptr<DictionaryConverter>> Make(
+ const std::shared_ptr<DataType>& value_type, const ConvertOptions& options,
+ MemoryPool* pool = default_memory_pool());
+
+ protected:
+ std::shared_ptr<DataType> value_type_;
+};
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/converter_benchmark.cc b/src/arrow/cpp/src/arrow/csv/converter_benchmark.cc
new file mode 100644
index 000000000..b7311880d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/converter_benchmark.cc
@@ -0,0 +1,152 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <sstream>
+#include <string>
+
+#include "arrow/buffer.h"
+#include "arrow/csv/converter.h"
+#include "arrow/csv/options.h"
+#include "arrow/csv/parser.h"
+#include "arrow/csv/reader.h"
+#include "arrow/csv/test_common.h"
+#include "arrow/io/memory.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/value_parsing.h"
+
+namespace arrow {
+namespace csv {
+
+static std::shared_ptr<BlockParser> BuildFromExamples(
+ const std::vector<std::string>& base_rows, int32_t num_rows) {
+ std::vector<std::string> rows;
+ for (int32_t i = 0; i < num_rows; ++i) {
+ rows.push_back(base_rows[i % base_rows.size()]);
+ }
+
+ std::shared_ptr<BlockParser> result;
+ MakeCSVParser(rows, &result);
+ return result;
+}
+
+static std::shared_ptr<BlockParser> BuildInt64Data(int32_t num_rows) {
+ const std::vector<std::string> base_rows = {"123\n", "4\n", "-317005557\n",
+ "\n", "N/A\n", "0\n"};
+ return BuildFromExamples(base_rows, num_rows);
+}
+
+static std::shared_ptr<BlockParser> BuildFloatData(int32_t num_rows) {
+ const std::vector<std::string> base_rows = {"0\n", "123.456\n", "-3170.55766\n", "\n",
+ "N/A\n"};
+ return BuildFromExamples(base_rows, num_rows);
+}
+
+static std::shared_ptr<BlockParser> BuildDecimal128Data(int32_t num_rows) {
+ const std::vector<std::string> base_rows = {"0\n", "123.456\n", "-3170.55766\n",
+ "\n", "N/A\n", "1233456789.123456789"};
+ return BuildFromExamples(base_rows, num_rows);
+}
+
+static std::shared_ptr<BlockParser> BuildStringData(int32_t num_rows) {
+ return BuildDecimal128Data(num_rows);
+}
+
+static std::shared_ptr<BlockParser> BuildISO8601Data(int32_t num_rows) {
+ const std::vector<std::string> base_rows = {
+ "1917-10-17\n", "2018-09-13\n", "1941-06-22 04:00\n", "1945-05-09 09:45:38\n"};
+ return BuildFromExamples(base_rows, num_rows);
+}
+
+static std::shared_ptr<BlockParser> BuildStrptimeData(int32_t num_rows) {
+ const std::vector<std::string> base_rows = {"10/17/1917\n", "9/13/2018\n",
+ "9/5/1945\n"};
+ return BuildFromExamples(base_rows, num_rows);
+}
+
+static void BenchmarkConversion(benchmark::State& state, // NOLINT non-const reference
+ BlockParser& parser,
+ const std::shared_ptr<DataType>& type,
+ ConvertOptions options) {
+ std::shared_ptr<Converter> converter = *Converter::Make(type, options);
+
+ while (state.KeepRunning()) {
+ auto converted = *converter->Convert(parser, 0 /* col_index */);
+ if (converted->length() != parser.num_rows()) {
+ std::cerr << "Conversion incomplete\n";
+ std::abort();
+ }
+ }
+
+ state.SetItemsProcessed(state.iterations() * parser.num_rows());
+}
+
+constexpr size_t num_rows = 10000;
+
+static void Int64Conversion(benchmark::State& state) { // NOLINT non-const reference
+ auto parser = BuildInt64Data(num_rows);
+ auto options = ConvertOptions::Defaults();
+
+ BenchmarkConversion(state, *parser, int64(), options);
+}
+
+static void FloatConversion(benchmark::State& state) { // NOLINT non-const reference
+ auto parser = BuildFloatData(num_rows);
+ auto options = ConvertOptions::Defaults();
+
+ BenchmarkConversion(state, *parser, float64(), options);
+}
+
+static void Decimal128Conversion(benchmark::State& state) { // NOLINT non-const reference
+ auto parser = BuildDecimal128Data(num_rows);
+ auto options = ConvertOptions::Defaults();
+
+ BenchmarkConversion(state, *parser, decimal(24, 9), options);
+}
+
+static void StringConversion(benchmark::State& state) { // NOLINT non-const reference
+ auto parser = BuildStringData(num_rows);
+ auto options = ConvertOptions::Defaults();
+
+ BenchmarkConversion(state, *parser, utf8(), options);
+}
+
+static void TimestampConversionDefault(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto parser = BuildISO8601Data(num_rows);
+ auto options = ConvertOptions::Defaults();
+ BenchmarkConversion(state, *parser, timestamp(TimeUnit::MILLI), options);
+}
+
+static void TimestampConversionStrptime(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto parser = BuildStrptimeData(num_rows);
+ auto options = ConvertOptions::Defaults();
+ options.timestamp_parsers.push_back(TimestampParser::MakeStrptime("%m/%d/%Y"));
+ BenchmarkConversion(state, *parser, timestamp(TimeUnit::MILLI), options);
+}
+
+BENCHMARK(Int64Conversion);
+BENCHMARK(FloatConversion);
+BENCHMARK(Decimal128Conversion);
+BENCHMARK(StringConversion);
+BENCHMARK(TimestampConversionDefault);
+BENCHMARK(TimestampConversionStrptime);
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/converter_test.cc b/src/arrow/cpp/src/arrow/csv/converter_test.cc
new file mode 100644
index 000000000..9a68b9568
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/converter_test.cc
@@ -0,0 +1,818 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/csv/converter.h"
+
+#include <gtest/gtest.h>
+
+#include <cstdint>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_decimal.h"
+#include "arrow/csv/options.h"
+#include "arrow/csv/test_common.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/value_parsing.h"
+
+namespace arrow {
+namespace csv {
+
+class BlockParser;
+
+// All recognized (non-empty) null values
+std::vector<std::string> AllNulls() {
+ return {"#N/A\n", "#N/A N/A\n", "#NA\n", "-1.#IND\n", "-1.#QNAN\n", "-NaN\n",
+ "-nan\n", "1.#IND\n", "1.#QNAN\n", "N/A\n", "NA\n", "NULL\n",
+ "NaN\n", "n/a\n", "nan\n", "null\n"};
+}
+
+template <typename DATA_TYPE, typename C_TYPE>
+void AssertConversion(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& csv_string,
+ const std::vector<std::vector<C_TYPE>>& expected,
+ ConvertOptions options = ConvertOptions::Defaults(),
+ bool validate_full = true) {
+ std::shared_ptr<BlockParser> parser;
+ std::shared_ptr<Converter> converter;
+ std::shared_ptr<Array> array, expected_array;
+
+ ASSERT_OK_AND_ASSIGN(converter, Converter::Make(type, options));
+
+ MakeCSVParser(csv_string, &parser);
+ for (int32_t col_index = 0; col_index < static_cast<int32_t>(expected.size());
+ ++col_index) {
+ ASSERT_OK_AND_ASSIGN(array, converter->Convert(*parser, col_index));
+ if (validate_full) {
+ ASSERT_OK(array->ValidateFull());
+ } else {
+ ASSERT_OK(array->Validate());
+ }
+ ArrayFromVector<DATA_TYPE, C_TYPE>(type, expected[col_index], &expected_array);
+ AssertArraysEqual(*expected_array, *array);
+ }
+}
+
+template <typename DATA_TYPE, typename C_TYPE>
+void AssertConversion(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& csv_string,
+ const std::vector<std::vector<C_TYPE>>& expected,
+ const std::vector<std::vector<bool>>& is_valid,
+ ConvertOptions options = ConvertOptions::Defaults()) {
+ std::shared_ptr<BlockParser> parser;
+ std::shared_ptr<Converter> converter;
+ std::shared_ptr<Array> array, expected_array;
+
+ ASSERT_OK_AND_ASSIGN(converter, Converter::Make(type, options));
+
+ MakeCSVParser(csv_string, &parser);
+ for (int32_t col_index = 0; col_index < static_cast<int32_t>(expected.size());
+ ++col_index) {
+ ASSERT_OK_AND_ASSIGN(array, converter->Convert(*parser, col_index));
+ ASSERT_OK(array->ValidateFull());
+ ArrayFromVector<DATA_TYPE, C_TYPE>(type, is_valid[col_index], expected[col_index],
+ &expected_array);
+ AssertArraysEqual(*expected_array, *array);
+ }
+}
+
+Result<std::shared_ptr<Array>> DictConversion(
+ const std::shared_ptr<DataType>& value_type, const std::string& csv_string,
+ int32_t max_cardinality = -1, ConvertOptions options = ConvertOptions::Defaults()) {
+ std::shared_ptr<BlockParser> parser;
+ std::shared_ptr<DictionaryConverter> converter;
+
+ ARROW_ASSIGN_OR_RAISE(converter, DictionaryConverter::Make(value_type, options));
+ if (max_cardinality >= 0) {
+ converter->SetMaxCardinality(max_cardinality);
+ }
+
+ ParseOptions parse_options;
+ parse_options.ignore_empty_lines = false;
+ MakeCSVParser({csv_string}, parse_options, &parser);
+
+ const int32_t col_index = 0;
+ return converter->Convert(*parser, col_index);
+}
+
+void AssertDictConversion(const std::string& csv_string,
+ const std::shared_ptr<Array>& expected_indices,
+ const std::shared_ptr<Array>& expected_dict,
+ int32_t max_cardinality = -1,
+ ConvertOptions options = ConvertOptions::Defaults(),
+ bool validate_full = true) {
+ std::shared_ptr<BlockParser> parser;
+ std::shared_ptr<DictionaryConverter> converter;
+ std::shared_ptr<Array> array, expected_array;
+ std::shared_ptr<DataType> expected_type;
+
+ ASSERT_OK_AND_ASSIGN(
+ array, DictConversion(expected_dict->type(), csv_string, max_cardinality, options));
+ if (validate_full) {
+ ASSERT_OK(array->ValidateFull());
+ } else {
+ ASSERT_OK(array->Validate());
+ }
+ expected_type = dictionary(expected_indices->type(), expected_dict->type());
+ ASSERT_TRUE(array->type()->Equals(*expected_type));
+ const auto& dict_array = internal::checked_cast<const DictionaryArray&>(*array);
+ AssertArraysEqual(*dict_array.dictionary(), *expected_dict);
+ AssertArraysEqual(*dict_array.indices(), *expected_indices);
+}
+
+template <typename DATA_TYPE, typename C_TYPE>
+void AssertConversionAllNulls(const std::shared_ptr<DataType>& type) {
+ std::vector<std::string> nulls = AllNulls();
+ std::vector<bool> is_valid(nulls.size(), false);
+ std::vector<C_TYPE> values(nulls.size());
+ AssertConversion<DATA_TYPE, C_TYPE>(type, nulls, {values}, {is_valid});
+}
+
+void AssertConversionError(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& csv_string,
+ const std::set<int32_t>& invalid_columns,
+ ConvertOptions options = ConvertOptions::Defaults()) {
+ std::shared_ptr<BlockParser> parser;
+ std::shared_ptr<Converter> converter;
+
+ ASSERT_OK_AND_ASSIGN(converter, Converter::Make(type, options));
+
+ MakeCSVParser(csv_string, &parser);
+ for (int32_t i = 0; i < parser->num_cols(); ++i) {
+ if (invalid_columns.find(i) == invalid_columns.end()) {
+ ASSERT_OK(converter->Convert(*parser, i));
+ } else {
+ ASSERT_RAISES(Invalid, converter->Convert(*parser, i));
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Converter tests
+
+template <typename T>
+class BinaryConversionTestBase : public testing::Test {
+ public:
+ std::shared_ptr<DataType> type() { return TypeTraits<T>::type_singleton(); }
+
+ void TestNulls() {
+ auto type = this->type();
+ AssertConversion<T, std::string>(type, {"ab,N/A\n", "NULL,\n"},
+ {{"ab", "NULL"}, {"N/A", ""}},
+ {{true, true}, {true, true}});
+
+ auto options = ConvertOptions::Defaults();
+ options.strings_can_be_null = true;
+ AssertConversion<T, std::string>(type, {"ab,N/A\n", "NULL,\n"},
+ {{"ab", ""}, {"", ""}},
+ {{true, false}, {false, false}}, options);
+ AssertConversion<T, std::string>(type, {"ab,\"N/A\"\n", "\"NULL\",\"\"\n"},
+ {{"ab", ""}, {"", ""}},
+ {{true, false}, {false, false}}, options);
+ options.quoted_strings_can_be_null = false;
+ AssertConversion<T, std::string>(type, {"ab,N/A\n", "NULL,\n"},
+ {{"ab", ""}, {"", ""}},
+ {{true, false}, {false, false}}, options);
+ AssertConversion<T, std::string>(type, {"ab,\"N/A\"\n", "\"NULL\",\"\"\n"},
+ {{"ab", "NULL"}, {"N/A", ""}},
+ {{true, true}, {true, true}}, options);
+ }
+
+ void TestCustomNulls() {
+ auto type = this->type();
+ auto options = ConvertOptions::Defaults();
+ options.null_values = {"xxx", "zzz"};
+ AssertConversion<T, std::string>(type, {"ab,N/A\n", "xxx,\"zzz\"\n"},
+ {{"ab", "xxx"}, {"N/A", "zzz"}},
+ {{true, true}, {true, true}}, options);
+
+ options.strings_can_be_null = true;
+ AssertConversion<T, std::string>(type, {"ab,N/A\n", "xxx,\"zzz\"\n"},
+ {{"ab", ""}, {"N/A", ""}},
+ {{true, false}, {true, false}}, options);
+ options.quoted_strings_can_be_null = false;
+ AssertConversion<T, std::string>(type, {"ab,N/A\n", "xxx,\"zzz\"\n"},
+ {{"ab", ""}, {"N/A", "zzz"}},
+ {{true, false}, {true, true}}, options);
+ }
+};
+
+template <typename T>
+class BinaryConversionTest : public BinaryConversionTestBase<T> {
+ public:
+ void TestBasics() {
+ auto type = this->type();
+ AssertConversion<T, std::string>(type, {"ab,cdé\n", ",\xffgh\n"},
+ {{"ab", ""}, {"cdé", "\xffgh"}});
+ }
+};
+
+using BinaryTestTypes = ::testing::Types<BinaryType, LargeBinaryType>;
+
+TYPED_TEST_SUITE(BinaryConversionTest, BinaryTestTypes);
+
+TYPED_TEST(BinaryConversionTest, Basics) { this->TestBasics(); }
+
+TYPED_TEST(BinaryConversionTest, Nulls) { this->TestNulls(); }
+
+TYPED_TEST(BinaryConversionTest, CustomNulls) { this->TestNulls(); }
+
+template <typename T>
+class StringConversionTest : public BinaryConversionTestBase<T> {
+ public:
+ void TestBasics() {
+ auto type = TypeTraits<T>::type_singleton();
+ AssertConversion<T, std::string>(type, {"ab,cdé\n", ",gh\n"},
+ {{"ab", ""}, {"cdé", "gh"}});
+ }
+
+ void TestInvalidUtf8() {
+ auto type = TypeTraits<T>::type_singleton();
+ // Invalid UTF8 in column 0
+ AssertConversionError(type, {"ab,cdé\n", "\xff,gh\n"}, {0});
+
+ auto options = ConvertOptions::Defaults();
+ options.check_utf8 = false;
+ AssertConversion<T, std::string>(type, {"ab,cdé\n", ",\xffgh\n"},
+ {{"ab", ""}, {"cdé", "\xffgh"}}, options,
+ /*validate_full=*/false);
+ }
+};
+
+using StringTestTypes = ::testing::Types<StringType, LargeStringType>;
+
+TYPED_TEST_SUITE(StringConversionTest, StringTestTypes);
+
+TYPED_TEST(StringConversionTest, Basics) { this->TestBasics(); }
+
+TYPED_TEST(StringConversionTest, Nulls) { this->TestNulls(); }
+
+TYPED_TEST(StringConversionTest, CustomNulls) { this->TestCustomNulls(); }
+
+TYPED_TEST(StringConversionTest, InvalidUtf8) { this->TestInvalidUtf8(); }
+
+TEST(FixedSizeBinaryConversion, Basics) {
+ AssertConversion<FixedSizeBinaryType, std::string>(
+ fixed_size_binary(2), {"ab,cd\n", "gh,ij\n"}, {{"ab", "gh"}, {"cd", "ij"}});
+}
+
+TEST(FixedSizeBinaryConversion, Errors) {
+ // Wrong-sized string in column 0
+ AssertConversionError(fixed_size_binary(2), {"ab,cd\n", "g,ij\n"}, {0});
+}
+
+TEST(FixedSizeBinaryConversion, Nulls) {
+ AssertConversion<FixedSizeBinaryType, std::string>(
+ fixed_size_binary(2), {"ab,N/A\n", ",ij\n"}, {{"ab", "\0\0"}, {"\0\0", "ij"}},
+ {{true, false}, {false, true}});
+
+ AssertConversionAllNulls<FixedSizeBinaryType, std::string>(fixed_size_binary(2));
+}
+
+TEST(FixedSizeBinaryConversion, CustomNulls) {
+ auto options = ConvertOptions::Defaults();
+ options.null_values = {"xxx", "zzz"};
+
+ AssertConversion<FixedSizeBinaryType, std::string>(fixed_size_binary(2),
+ {"\"ab\",\"xxx\"\n"}, {{"ab"}, {""}},
+ {{true}, {false}}, options);
+
+ options.quoted_strings_can_be_null = false;
+ AssertConversionError(fixed_size_binary(2), {"\"ab\",\"xxx\"\n"}, {1}, options);
+
+ AssertConversion<FixedSizeBinaryType, std::string>(
+ fixed_size_binary(2), {"ab,xxx\n", "zzz,ij\n"}, {{"ab", "\0\0"}, {"\0\0", "ij"}},
+ {{true, false}, {false, true}}, options);
+
+ AssertConversionError(fixed_size_binary(2), {",xxx,N/A\n"}, {0, 2}, options);
+
+ // Duplicate nulls allowed
+ options.null_values = {"xxx", "zzz", "xxx"};
+ AssertConversion<FixedSizeBinaryType, std::string>(
+ fixed_size_binary(2), {"ab,xxx\n", "zzz,ij\n"}, {{"ab", "\0,\0"}, {"\0\0", "ij"}},
+ {{true, false}, {false, true}}, options);
+}
+
+TEST(NullConversion, Basics) {
+ std::shared_ptr<BlockParser> parser;
+ std::shared_ptr<Converter> converter;
+ std::shared_ptr<Array> array;
+ std::shared_ptr<DataType> type = null();
+
+ auto options = ConvertOptions::Defaults();
+ ASSERT_OK_AND_ASSIGN(converter, Converter::Make(type, options));
+
+ MakeCSVParser({"NA,z\n", ",0\n"}, &parser);
+ ASSERT_OK_AND_ASSIGN(array, converter->Convert(*parser, 0));
+ ASSERT_EQ(array->type()->id(), Type::NA);
+ ASSERT_EQ(array->length(), 2);
+ ASSERT_RAISES(Invalid, converter->Convert(*parser, 1));
+}
+
+TEST(IntegerConversion, Basics) {
+ AssertConversion<Int8Type, int8_t>(int8(), {"12,34\n", "0,-128\n"},
+ {{12, 0}, {34, -128}});
+ AssertConversion<Int64Type, int64_t>(
+ int64(), {"12,34\n", "9223372036854775807,-9223372036854775808\n"},
+ {{12, 9223372036854775807LL}, {34, -9223372036854775807LL - 1}});
+
+ AssertConversion<UInt16Type, uint16_t>(uint16(), {"12,34\n", "0,65535\n"},
+ {{12, 0}, {34, 65535}});
+ AssertConversion<UInt64Type, uint64_t>(uint64(),
+ {"12,34\n", "0,18446744073709551615\n"},
+ {{12, 0}, {34, 18446744073709551615ULL}});
+}
+
+TEST(IntegerConversion, Nulls) {
+ AssertConversion<Int8Type, int8_t>(int8(), {"12,N/A\n", ",-128\n"},
+ {{12, 0}, {0, -128}},
+ {{true, false}, {false, true}});
+
+ AssertConversionAllNulls<Int8Type, int8_t>(int8());
+}
+
+TEST(IntegerConversion, CustomNulls) {
+ auto options = ConvertOptions::Defaults();
+ options.null_values = {"xxx", "zzz"};
+
+ AssertConversion<Int8Type, int8_t>(int8(), {"\"12\",\"xxx\"\n"}, {{12}, {0}},
+ {{true}, {false}}, options);
+
+ options.quoted_strings_can_be_null = false;
+ AssertConversionError(int8(), {"\"12\",\"xxx\"\n"}, {1}, options);
+
+ AssertConversion<Int8Type, int8_t>(int8(), {"12,xxx\n", "zzz,-128\n"},
+ {{12, 0}, {0, -128}}, {{true, false}, {false, true}},
+ options);
+
+ AssertConversionError(int8(), {",xxx,N/A\n"}, {0, 2}, options);
+
+ // Duplicate nulls allowed
+ options.null_values = {"xxx", "zzz", "xxx"};
+ AssertConversion<Int8Type, int8_t>(int8(), {"12,xxx\n", "zzz,-128\n"},
+ {{12, 0}, {0, -128}}, {{true, false}, {false, true}},
+ options);
+}
+
+TEST(IntegerConversion, Whitespace) {
+ AssertConversion<Int32Type, int32_t>(int32(), {" 12,34 \n", " 56 ,78\n"},
+ {{12, 56}, {34, 78}});
+}
+
+TEST(FloatingPointConversion, Basics) {
+ AssertConversion<FloatType, float>(float32(), {"12,34.5\n", "0,-1e30\n"},
+ {{12., 0.}, {34.5, -1e30f}});
+ AssertConversion<DoubleType, double>(float64(), {"12,34.5\n", "0,-1e100\n"},
+ {{12., 0.}, {34.5, -1e100}});
+}
+
+TEST(FloatingPointConversion, Nulls) {
+ AssertConversion<FloatType, float>(float32(), {"1.5,0.\n", ",-1e10\n"},
+ {{1.5, 0.}, {0., -1e10f}},
+ {{true, false}, {true, true}});
+
+ AssertConversionAllNulls<DoubleType, double>(float64());
+}
+
+TEST(FloatingPointConversion, CustomNulls) {
+ auto options = ConvertOptions::Defaults();
+ options.null_values = {"xxx", "zzz"};
+
+ AssertConversion<FloatType, float>(float32(), {"\"1.5\",\"xxx\"\n"}, {{1.5}, {0}},
+ {{true}, {false}}, options);
+
+ options.quoted_strings_can_be_null = false;
+ AssertConversionError(float32(), {"\"1.5\",\"xxx\"\n"}, {1}, options);
+
+ AssertConversion<FloatType, float>(float32(), {"1.5,xxx\n", "zzz,-1e10\n"},
+ {{1.5, 0.}, {0., -1e10f}},
+ {{true, false}, {false, true}}, options);
+}
+
+TEST(FloatingPointConversion, Whitespace) {
+ AssertConversion<DoubleType, double>(float64(), {" 12,34.5\n", " 0 ,-1e100 \n"},
+ {{12., 0.}, {34.5, -1e100}});
+}
+
+TEST(FloatingPointConversion, CustomDecimalPoint) {
+ auto options = ConvertOptions::Defaults();
+ options.decimal_point = '/';
+
+ AssertConversion<FloatType, float>(float32(), {"1/5\n", "-1e10\n", "N/A\n"},
+ {{1.5, -1e10f, 0.}}, {{true, true, false}}, options);
+ AssertConversion<DoubleType, double>(float64(), {"1/5\n", "-1e10\n", "N/A\n"},
+ {{1.5, -1e10, 0.}}, {{true, true, false}},
+ options);
+ AssertConversionError(float32(), {"1.5\n"}, {0}, options);
+}
+
+TEST(BooleanConversion, Basics) {
+ // XXX we may want to accept more bool-like values
+ AssertConversion<BooleanType, bool>(boolean(), {"true,false\n", "1,0\n"},
+ {{true, true}, {false, false}});
+}
+
+TEST(BooleanConversion, Nulls) {
+ AssertConversion<BooleanType, bool>(boolean(), {"true,\n", "1,0\n"},
+ {{true, true}, {false, false}},
+ {{true, true}, {false, true}});
+}
+
+TEST(BooleanConversion, CustomNulls) {
+ auto options = ConvertOptions::Defaults();
+ options.null_values = {"xxx", "zzz"};
+
+ AssertConversion<BooleanType, bool>(boolean(), {"\"true\",\"xxx\"\n"}, {{1}, {0}},
+ {{true}, {false}}, options);
+
+ options.quoted_strings_can_be_null = false;
+ AssertConversionError(boolean(), {"\"true\",\"xxx\"\n"}, {1}, options);
+
+ AssertConversion<BooleanType, bool>(boolean(), {"true,xxx\n", "zzz,0\n"},
+ {{true, false}, {false, false}},
+ {{true, false}, {false, true}}, options);
+}
+
+TEST(Date32Conversion, Basics) {
+ AssertConversion<Date32Type, int32_t>(date32(), {"1945-05-08\n", "2020-03-15\n"},
+ {{-9004, 18336}});
+}
+
+TEST(Date32Conversion, Nulls) {
+ AssertConversion<Date32Type, int32_t>(date32(), {"N/A\n", "2020-03-15\n"}, {{0, 18336}},
+ {{false, true}});
+}
+
+TEST(Date32Conversion, Errors) {
+ AssertConversionError(date32(), {"1945-06-31\n"}, {0});
+ AssertConversionError(date32(), {"2020-13-01\n"}, {0});
+}
+
+TEST(Date64Conversion, Basics) {
+ AssertConversion<Date64Type, int64_t>(date64(), {"1945-05-08\n", "2020-03-15\n"},
+ {{-777945600000LL, 1584230400000LL}});
+}
+
+TEST(Date64Conversion, Nulls) {
+ AssertConversion<Date64Type, int64_t>(date64(), {"N/A\n", "2020-03-15\n"},
+ {{0, 1584230400000LL}}, {{false, true}});
+}
+
+TEST(Date64Conversion, Errors) {
+ AssertConversionError(date64(), {"1945-06-31\n"}, {0});
+ AssertConversionError(date64(), {"2020-13-01\n"}, {0});
+}
+
+TEST(Time32Conversion, Seconds) {
+ const auto type = time32(TimeUnit::SECOND);
+
+ AssertConversion<Time32Type, int32_t>(type, {"00:00\n", "00:00:00\n"}, {{0, 0}});
+ AssertConversion<Time32Type, int32_t>(type, {"01:23:45\n", "23:45:43\n"},
+ {{5025, 85543}});
+ AssertConversion<Time32Type, int32_t>(type, {"N/A\n", "23:59:59\n"}, {{0, 86399}},
+ {{false, true}});
+
+ AssertConversionError(type, {"24:00\n"}, {0});
+ AssertConversionError(type, {"23:59:60\n"}, {0});
+}
+
+TEST(Time32Conversion, Millis) {
+ const auto type = time32(TimeUnit::MILLI);
+
+ AssertConversion<Time32Type, int32_t>(type, {"00:00\n", "00:00:00\n"}, {{0, 0}});
+ AssertConversion<Time32Type, int32_t>(type, {"01:23:45.1\n", "23:45:43.789\n"},
+ {{5025100, 85543789}});
+ AssertConversion<Time32Type, int32_t>(type, {"N/A\n", "23:59:59.999\n"},
+ {{0, 86399999}}, {{false, true}});
+
+ AssertConversionError(type, {"24:00\n"}, {0});
+ AssertConversionError(type, {"23:59:60\n"}, {0});
+}
+
+TEST(Time64Conversion, Micros) {
+ const auto type = time64(TimeUnit::MICRO);
+
+ AssertConversion<Time64Type, int64_t>(type, {"00:00\n", "00:00:00\n"}, {{0LL, 0LL}});
+ AssertConversion<Time64Type, int64_t>(type, {"01:23:45.1\n", "23:45:43.456789\n"},
+ {{5025100000LL, 85543456789LL}});
+ AssertConversion<Time64Type, int64_t>(type, {"N/A\n", "23:59:59.999999\n"},
+ {{0, 86399999999LL}}, {{false, true}});
+
+ AssertConversionError(type, {"24:00\n"}, {0});
+ AssertConversionError(type, {"23:59:60\n"}, {0});
+}
+
+TEST(Time64Conversion, Nanos) {
+ const auto type = time64(TimeUnit::NANO);
+
+ AssertConversion<Time64Type, int64_t>(type, {"00:00\n", "00:00:00\n"}, {{0LL, 0LL}});
+ AssertConversion<Time64Type, int64_t>(type, {"01:23:45.1\n", "23:45:43.123456789\n"},
+ {{5025100000000LL, 85543123456789LL}});
+ AssertConversion<Time64Type, int64_t>(type, {"N/A\n", "23:59:59.999999999\n"},
+ {{0, 86399999999999LL}}, {{false, true}});
+
+ AssertConversionError(type, {"24:00\n"}, {0});
+ AssertConversionError(type, {"23:59:60\n"}, {0});
+}
+
+TEST(TimestampConversion, Basics) {
+ auto type = timestamp(TimeUnit::SECOND);
+
+ AssertConversion<TimestampType, int64_t>(
+ type, {"1970-01-01\n2000-02-29\n3989-07-14\n1900-02-28\n"},
+ {{0, 951782400, 63730281600LL, -2203977600LL}});
+ AssertConversion<TimestampType, int64_t>(type,
+ {"2018-11-13 17:11:10\n1900-02-28 12:34:56\n"},
+ {{1542129070, -2203932304LL}});
+
+ type = timestamp(TimeUnit::NANO);
+ AssertConversion<TimestampType, int64_t>(
+ type, {"1970-01-01\n2000-02-29\n1900-02-28\n"},
+ {{0, 951782400000000000LL, -2203977600000000000LL}});
+}
+
+TEST(TimestampConversion, Nulls) {
+ auto type = timestamp(TimeUnit::MILLI);
+ AssertConversion<TimestampType, int64_t>(
+ type, {"1970-01-01 00:01:00,,N/A\n"}, {{60000}, {0}, {0}},
+ {{true}, {false}, {false}}, ConvertOptions::Defaults());
+}
+
+TEST(TimestampConversion, CustomNulls) {
+ auto options = ConvertOptions::Defaults();
+ options.null_values = {"xxx", "zzz"};
+ auto type = timestamp(TimeUnit::MILLI);
+
+ AssertConversion<TimestampType, int64_t>(type, {"\"1970-01-01 00:01:00\",\"xxx\"\n"},
+ {{60000}, {0}}, {{true}, {false}}, options);
+
+ options.quoted_strings_can_be_null = false;
+ AssertConversionError(type, {"\"1970-01-01 00:01:00\",\"xxx\"\n"}, {1}, options);
+
+ AssertConversion<TimestampType, int64_t>(type, {"1970-01-01 00:01:00,xxx,zzz\n"},
+ {{60000}, {0}, {0}},
+ {{true}, {false}, {false}}, options);
+}
+
+TEST(TimestampConversion, UserDefinedParsers) {
+ auto options = ConvertOptions::Defaults();
+ auto type = timestamp(TimeUnit::MILLI);
+
+ // Test a single parser
+ options.timestamp_parsers = {TimestampParser::MakeStrptime("%m/%d/%Y")};
+ AssertConversion<TimestampType, int64_t>(type, {"01/02/1970,01/03/1970\n"},
+ {{86400000}, {172800000}}, options);
+
+ // Test multiple parsers
+ options.timestamp_parsers.push_back(TimestampParser::MakeISO8601());
+ AssertConversion<TimestampType, int64_t>(type, {"01/02/1970,1970-01-03\n"},
+ {{86400000}, {172800000}}, options);
+}
+
+Decimal128 Dec128(util::string_view value) {
+ Decimal128 dec;
+ int32_t scale = 0;
+ int32_t precision = 0;
+ DCHECK_OK(Decimal128::FromString(value, &dec, &precision, &scale));
+ return dec;
+}
+
+TEST(DecimalConversion, Basics) {
+ AssertConversion<Decimal128Type, Decimal128>(
+ decimal(23, 2), {"12,34.5\n", "36.37,-1e5\n"},
+ {{Dec128("12.00"), Dec128("36.37")}, {Dec128("34.50"), Dec128("-100000.00")}});
+}
+
+TEST(DecimalConversion, Nulls) {
+ AssertConversion<Decimal128Type, Decimal128>(
+ decimal(14, 3), {"1.5,0.\n", ",-1e3\n"},
+ {{Dec128("1.500"), Decimal128()}, {Decimal128(), Dec128("-1000.000")}},
+ {{true, false}, {true, true}});
+
+ AssertConversionAllNulls<Decimal128Type, Decimal128>(decimal(14, 2));
+}
+
+TEST(DecimalConversion, CustomNulls) {
+ auto options = ConvertOptions::Defaults();
+ options.null_values = {"xxx", "zzz"};
+
+ AssertConversion<Decimal128Type, Decimal128>(decimal(14, 3), {"\"1.5\",\"xxx\"\n"},
+ {{Dec128("1.500")}, {0}},
+ {{true}, {false}}, options);
+
+ options.quoted_strings_can_be_null = false;
+ AssertConversionError(decimal(14, 3), {"\"1.5\",\"xxx\"\n"}, {1}, options);
+
+ AssertConversion<Decimal128Type, Decimal128>(
+ decimal(14, 3), {"1.5,xxx\n", "zzz,-1e3\n"},
+ {{Dec128("1.500"), Decimal128()}, {Decimal128(), Dec128("-1000.000")}},
+ {{true, false}, {false, true}}, options);
+}
+
+TEST(DecimalConversion, CustomDecimalPoint) {
+ auto options = ConvertOptions::Defaults();
+ options.decimal_point = '/';
+
+ AssertConversion<Decimal128Type, Decimal128>(
+ decimal(14, 3), {"1/5,0/\n", ",-1e3\n"},
+ {{Dec128("1.500"), Decimal128()}, {Decimal128(), Dec128("-1000.000")}},
+ {{true, false}, {true, true}}, options);
+ AssertConversionError(decimal128(14, 3), {"1.5\n"}, {0}, options);
+}
+
+TEST(DecimalConversion, Whitespace) {
+ AssertConversion<Decimal128Type, Decimal128>(
+ decimal(5, 1), {" 12.00,34.5\n", " 0 ,-1e2 \n"},
+ {{Dec128("12.0"), Decimal128()}, {Dec128("34.5"), Dec128("-100.0")}});
+}
+
+TEST(DecimalConversion, OverflowFails) {
+ AssertConversionError(decimal(5, 0), {"1e6,0\n"}, {0});
+
+ AssertConversionError(decimal(5, 1), {"123.22\n"}, {0});
+ AssertConversionError(decimal(5, 1), {"12345.6\n"}, {0});
+ AssertConversionError(decimal(5, 1), {"1.61\n"}, {0});
+}
+
+//////////////////////////////////////////////////////////////////////////
+// DictionaryConverter tests
+
+template <typename T>
+class TestNumericDictConverter : public ::testing::Test {
+ public:
+ std::shared_ptr<DataType> type() const { return TypeTraits<T>::type_singleton(); }
+};
+
+using NumericDictConversionTypes =
+ ::testing::Types<Int32Type, UInt32Type, Int64Type, UInt64Type, FloatType, DoubleType>;
+
+TYPED_TEST_SUITE(TestNumericDictConverter, NumericDictConversionTypes);
+
+TYPED_TEST(TestNumericDictConverter, Basics) {
+ auto expected_dict = ArrayFromJSON(this->type(), "[4, 5]");
+ auto expected_indices = ArrayFromJSON(int32(), "[0, 1, 0, 0]");
+
+ AssertDictConversion("4\n5\n4\n4\n", expected_indices, expected_dict);
+}
+
+TYPED_TEST(TestNumericDictConverter, Nulls) {
+ auto expected_dict = ArrayFromJSON(this->type(), "[4, 5]");
+ auto expected_indices = ArrayFromJSON(int32(), "[0, 1, null, 0]");
+
+ AssertDictConversion("4\n5\nN/A\n4\n", expected_indices, expected_dict);
+ AssertDictConversion("\"4\"\n\"5\"\n\"N/A\"\n\"4\"\n", expected_indices, expected_dict);
+}
+
+TYPED_TEST(TestNumericDictConverter, Errors) {
+ auto value_type = this->type();
+ ASSERT_RAISES(Invalid, DictConversion(value_type, "xxx\n"));
+
+ ConvertOptions options = ConvertOptions::Defaults();
+
+ options.quoted_strings_can_be_null = false;
+ ASSERT_RAISES(Invalid, DictConversion(value_type, "\"N/A\"\n", -1, options));
+
+ // Overflow
+ if (is_integer(value_type->id())) {
+ ASSERT_RAISES(Invalid, DictConversion(value_type, "99999999999999999999999\n"));
+ }
+ if (is_unsigned_integer(value_type->id())) {
+ ASSERT_RAISES(Invalid, DictConversion(value_type, "-1\n"));
+ }
+}
+
+template <typename T>
+class TestStringDictConverter : public ::testing::Test {
+ public:
+ std::shared_ptr<DataType> type() const { return TypeTraits<T>::type_singleton(); }
+
+ bool is_utf8_type() const {
+ return T::type_id == Type::STRING || T::type_id == Type::LARGE_STRING;
+ }
+};
+
+using StringDictConversionTypes =
+ ::testing::Types<BinaryType, LargeBinaryType, StringType, LargeStringType>;
+
+TYPED_TEST_SUITE(TestStringDictConverter, StringDictConversionTypes);
+
+TYPED_TEST(TestStringDictConverter, Basics) {
+ auto expected_dict = ArrayFromJSON(this->type(), R"(["ab", "cdé", ""])");
+ auto expected_indices = ArrayFromJSON(int32(), "[0, 1, 2, 0]");
+
+ AssertDictConversion("ab\ncdé\n\nab\n", expected_indices, expected_dict);
+}
+
+TYPED_TEST(TestStringDictConverter, Nulls) {
+ auto expected_dict = ArrayFromJSON(this->type(), R"(["ab", "N/A", ""])");
+ auto expected_indices = ArrayFromJSON(int32(), "[0, 1, 2, 0]");
+
+ AssertDictConversion("ab\nN/A\n\nab\n", expected_indices, expected_dict);
+
+ auto options = ConvertOptions::Defaults();
+ options.strings_can_be_null = true;
+ expected_dict = ArrayFromJSON(this->type(), R"(["ab"])");
+ expected_indices = ArrayFromJSON(int32(), "[0, null, null, 0]");
+ AssertDictConversion("ab\nN/A\n\nab\n", expected_indices, expected_dict, -1, options);
+}
+
+TYPED_TEST(TestStringDictConverter, NonUTF8) {
+ auto expected_indices = ArrayFromJSON(int32(), "[0, 1, 2, 0]");
+ std::shared_ptr<Array> expected_dict;
+ ArrayFromVector<TypeParam, std::string>({"ab", "cd\xff", ""}, &expected_dict);
+ std::string csv_string = "ab\ncd\xff\n\nab\n";
+
+ if (this->is_utf8_type()) {
+ ASSERT_RAISES(Invalid, DictConversion(this->type(), "ab\ncd\xff\n\nab\n"));
+
+ auto options = ConvertOptions::Defaults();
+ options.check_utf8 = false;
+ AssertDictConversion(csv_string, expected_indices, expected_dict, -1, options,
+ /*validate_full=*/false);
+ } else {
+ AssertDictConversion(csv_string, expected_indices, expected_dict);
+ }
+}
+
+TYPED_TEST(TestStringDictConverter, MaxCardinality) {
+ auto expected_dict = ArrayFromJSON(this->type(), R"(["ab", "cd", "ef"])");
+ auto expected_indices = ArrayFromJSON(int32(), "[0, 1, 2, 1]");
+ std::string csv_string = "ab\ncd\nef\ncd\n";
+
+ AssertDictConversion(csv_string, expected_indices, expected_dict, 3);
+ ASSERT_RAISES(IndexError, DictConversion(this->type(), csv_string, 2));
+}
+
+TEST(TestFixedSizeBinaryDictConverter, Basics) {
+ auto value_type = fixed_size_binary(3);
+
+ auto expected_dict = ArrayFromJSON(value_type, R"(["abc", "def"])");
+ auto expected_indices = ArrayFromJSON(int32(), "[0, 1, 0, 1]");
+
+ AssertDictConversion("abc\ndef\nabc\ndef\n", expected_indices, expected_dict);
+}
+
+TEST(TestFixedSizeBinaryDictConverter, Errors) {
+ auto value_type = fixed_size_binary(3);
+
+ // Invalid string size
+ ASSERT_RAISES(Invalid, DictConversion(value_type, "abc\nde\n"));
+}
+
+TEST(TestDecimalDictConverter, Basics) {
+ auto value_type = decimal(9, 3);
+
+ auto expected_dict = ArrayFromJSON(value_type, R"(["1.234", "456.789"])");
+ auto expected_indices = ArrayFromJSON(int32(), "[0, 1, null, 1]");
+
+ AssertDictConversion("1.234\n456.789\nN/A\n4.56789e2\n", expected_indices,
+ expected_dict);
+}
+
+TEST(TestDecimalDictConverter, CustomDecimalPoint) {
+ auto value_type = decimal(9, 3);
+
+ auto options = ConvertOptions::Defaults();
+ options.decimal_point = '\'';
+
+ auto expected_dict = ArrayFromJSON(value_type, R"(["1.234", "456.789"])");
+ auto expected_indices = ArrayFromJSON(int32(), "[0, 1, null, 1]");
+
+ AssertDictConversion("1'234\n456'789\nN/A\n4'56789e2\n", expected_indices,
+ expected_dict, -1, options);
+
+ ASSERT_RAISES(Invalid, DictConversion(value_type, "1.234\n", -1, options));
+}
+
+TEST(TestDecimalDictConverter, Errors) {
+ auto value_type = decimal(9, 3);
+
+ // Overflow
+ ASSERT_RAISES(Invalid, DictConversion(value_type, "1e10\n"));
+}
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/inference_internal.h b/src/arrow/cpp/src/arrow/csv/inference_internal.h
new file mode 100644
index 000000000..1fd6d41b5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/inference_internal.h
@@ -0,0 +1,155 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/csv/converter.h"
+#include "arrow/csv/options.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace csv {
+
+enum class InferKind {
+ Null,
+ Integer,
+ Boolean,
+ Real,
+ Date,
+ Time,
+ Timestamp,
+ TimestampNS,
+ TextDict,
+ BinaryDict,
+ Text,
+ Binary
+};
+
+class InferStatus {
+ public:
+ explicit InferStatus(const ConvertOptions& options)
+ : kind_(InferKind::Null), can_loosen_type_(true), options_(options) {}
+
+ InferKind kind() const { return kind_; }
+
+ bool can_loosen_type() const { return can_loosen_type_; }
+
+ void LoosenType(const Status& conversion_error) {
+ DCHECK(can_loosen_type_);
+
+ switch (kind_) {
+ case InferKind::Null:
+ return SetKind(InferKind::Integer);
+ case InferKind::Integer:
+ return SetKind(InferKind::Boolean);
+ case InferKind::Boolean:
+ return SetKind(InferKind::Date);
+ case InferKind::Date:
+ return SetKind(InferKind::Time);
+ case InferKind::Time:
+ return SetKind(InferKind::Timestamp);
+ case InferKind::Timestamp:
+ return SetKind(InferKind::TimestampNS);
+ case InferKind::TimestampNS:
+ return SetKind(InferKind::Real);
+ case InferKind::Real:
+ if (options_.auto_dict_encode) {
+ return SetKind(InferKind::TextDict);
+ } else {
+ return SetKind(InferKind::Text);
+ }
+ case InferKind::TextDict:
+ if (conversion_error.IsIndexError()) {
+ // Cardinality too large, fall back to non-dict encoding
+ return SetKind(InferKind::Text);
+ } else {
+ // Assuming UTF8 validation failure
+ return SetKind(InferKind::BinaryDict);
+ }
+ break;
+ case InferKind::BinaryDict:
+ // Assuming cardinality too large
+ return SetKind(InferKind::Binary);
+ case InferKind::Text:
+ // Assuming UTF8 validation failure
+ return SetKind(InferKind::Binary);
+ default:
+ ARROW_LOG(FATAL) << "Shouldn't come here";
+ }
+ }
+
+ Result<std::shared_ptr<Converter>> MakeConverter(MemoryPool* pool) {
+ auto make_converter =
+ [&](std::shared_ptr<DataType> type) -> Result<std::shared_ptr<Converter>> {
+ return Converter::Make(type, options_, pool);
+ };
+
+ auto make_dict_converter =
+ [&](std::shared_ptr<DataType> type) -> Result<std::shared_ptr<Converter>> {
+ ARROW_ASSIGN_OR_RAISE(auto dict_converter,
+ DictionaryConverter::Make(type, options_, pool));
+ dict_converter->SetMaxCardinality(options_.auto_dict_max_cardinality);
+ return dict_converter;
+ };
+
+ switch (kind_) {
+ case InferKind::Null:
+ return make_converter(null());
+ case InferKind::Integer:
+ return make_converter(int64());
+ case InferKind::Boolean:
+ return make_converter(boolean());
+ case InferKind::Date:
+ return make_converter(date32());
+ case InferKind::Time:
+ return make_converter(time32(TimeUnit::SECOND));
+ case InferKind::Timestamp:
+ return make_converter(timestamp(TimeUnit::SECOND));
+ case InferKind::TimestampNS:
+ return make_converter(timestamp(TimeUnit::NANO));
+ case InferKind::Real:
+ return make_converter(float64());
+ case InferKind::Text:
+ return make_converter(utf8());
+ case InferKind::Binary:
+ return make_converter(binary());
+ case InferKind::TextDict:
+ return make_dict_converter(utf8());
+ case InferKind::BinaryDict:
+ return make_dict_converter(binary());
+ }
+ return Status::UnknownError("Shouldn't come here");
+ }
+
+ protected:
+ void SetKind(InferKind kind) {
+ kind_ = kind;
+ if (kind == InferKind::Binary) {
+ // Binary is the catch-all type
+ can_loosen_type_ = false;
+ }
+ }
+
+ InferKind kind_;
+ bool can_loosen_type_;
+ const ConvertOptions& options_;
+};
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/invalid_row.h b/src/arrow/cpp/src/arrow/csv/invalid_row.h
new file mode 100644
index 000000000..8a07b568a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/invalid_row.h
@@ -0,0 +1,56 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <functional>
+
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace csv {
+
+/// \brief Description of an invalid row
+struct InvalidRow {
+ /// \brief Number of columns expected in the row
+ int32_t expected_columns;
+ /// \brief Actual number of columns found in the row
+ int32_t actual_columns;
+ /// \brief The physical row number if known or -1
+ ///
+ /// This number is one-based and also accounts for non-data rows (such as
+ /// CSV header rows).
+ int64_t number;
+ /// \brief View of the entire row. Memory will be freed after callback returns
+ const util::string_view text;
+};
+
+/// \brief Result returned by an InvalidRowHandler
+enum class InvalidRowResult {
+ // Generate an error describing this row
+ Error,
+ // Skip over this row
+ Skip
+};
+
+/// \brief callback for handling a row with an invalid number of columns while parsing
+/// \return result indicating if an error should be returned from the parser or the row is
+/// skipped
+using InvalidRowHandler = std::function<InvalidRowResult(const InvalidRow&)>;
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/options.cc b/src/arrow/cpp/src/arrow/csv/options.cc
new file mode 100644
index 000000000..c71cfdaf2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/options.cc
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/csv/options.h"
+
+namespace arrow {
+namespace csv {
+
+ParseOptions ParseOptions::Defaults() { return ParseOptions(); }
+
+Status ParseOptions::Validate() const {
+ if (ARROW_PREDICT_FALSE(delimiter == '\n' || delimiter == '\r')) {
+ return Status::Invalid("ParseOptions: delimiter cannot be \\r or \\n");
+ }
+ if (ARROW_PREDICT_FALSE(quoting && (quote_char == '\n' || quote_char == '\r'))) {
+ return Status::Invalid("ParseOptions: quote_char cannot be \\r or \\n");
+ }
+ if (ARROW_PREDICT_FALSE(escaping && (escape_char == '\n' || escape_char == '\r'))) {
+ return Status::Invalid("ParseOptions: escape_char cannot be \\r or \\n");
+ }
+ return Status::OK();
+}
+
+ConvertOptions ConvertOptions::Defaults() {
+ auto options = ConvertOptions();
+ // Same default null / true / false spellings as in Pandas.
+ options.null_values = {"", "#N/A", "#N/A N/A", "#NA", "-1.#IND", "-1.#QNAN",
+ "-NaN", "-nan", "1.#IND", "1.#QNAN", "N/A", "NA",
+ "NULL", "NaN", "n/a", "nan", "null"};
+ options.true_values = {"1", "True", "TRUE", "true"};
+ options.false_values = {"0", "False", "FALSE", "false"};
+ return options;
+}
+
+Status ConvertOptions::Validate() const { return Status::OK(); }
+
+ReadOptions ReadOptions::Defaults() { return ReadOptions(); }
+
+Status ReadOptions::Validate() const {
+ if (ARROW_PREDICT_FALSE(block_size < 1)) {
+ // Min is 1 because some tests use really small block sizes
+ return Status::Invalid("ReadOptions: block_size must be at least 1: ", block_size);
+ }
+ if (ARROW_PREDICT_FALSE(skip_rows < 0)) {
+ return Status::Invalid("ReadOptions: skip_rows cannot be negative: ", skip_rows);
+ }
+ if (ARROW_PREDICT_FALSE(skip_rows_after_names < 0)) {
+ return Status::Invalid("ReadOptions: skip_rows_after_names cannot be negative: ",
+ skip_rows_after_names);
+ }
+ if (ARROW_PREDICT_FALSE(autogenerate_column_names && !column_names.empty())) {
+ return Status::Invalid(
+ "ReadOptions: autogenerate_column_names cannot be true when column_names are "
+ "provided");
+ }
+ return Status::OK();
+}
+
+WriteOptions WriteOptions::Defaults() { return WriteOptions(); }
+
+Status WriteOptions::Validate() const {
+ if (ARROW_PREDICT_FALSE(batch_size < 1)) {
+ return Status::Invalid("WriteOptions: batch_size must be at least 1: ", batch_size);
+ }
+ return Status::OK();
+}
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/options.h b/src/arrow/cpp/src/arrow/csv/options.h
new file mode 100644
index 000000000..38514bcb9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/options.h
@@ -0,0 +1,194 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "arrow/csv/invalid_row.h"
+#include "arrow/csv/type_fwd.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/status.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class DataType;
+class TimestampParser;
+
+namespace csv {
+
+// Silly workaround for https://github.com/michaeljones/breathe/issues/453
+constexpr char kDefaultEscapeChar = '\\';
+
+struct ARROW_EXPORT ParseOptions {
+ // Parsing options
+
+ /// Field delimiter
+ char delimiter = ',';
+ /// Whether quoting is used
+ bool quoting = true;
+ /// Quoting character (if `quoting` is true)
+ char quote_char = '"';
+ /// Whether a quote inside a value is double-quoted
+ bool double_quote = true;
+ /// Whether escaping is used
+ bool escaping = false;
+ /// Escaping character (if `escaping` is true)
+ char escape_char = kDefaultEscapeChar;
+ /// Whether values are allowed to contain CR (0x0d) and LF (0x0a) characters
+ bool newlines_in_values = false;
+ /// Whether empty lines are ignored. If false, an empty line represents
+ /// a single empty value (assuming a one-column CSV file).
+ bool ignore_empty_lines = true;
+ /// A handler function for rows which do not have the correct number of columns
+ InvalidRowHandler invalid_row_handler;
+
+ /// Create parsing options with default values
+ static ParseOptions Defaults();
+
+ /// \brief Test that all set options are valid
+ Status Validate() const;
+};
+
+struct ARROW_EXPORT ConvertOptions {
+ // Conversion options
+
+ /// Whether to check UTF8 validity of string columns
+ bool check_utf8 = true;
+ /// Optional per-column types (disabling type inference on those columns)
+ std::unordered_map<std::string, std::shared_ptr<DataType>> column_types;
+ /// Recognized spellings for null values
+ std::vector<std::string> null_values;
+ /// Recognized spellings for boolean true values
+ std::vector<std::string> true_values;
+ /// Recognized spellings for boolean false values
+ std::vector<std::string> false_values;
+
+ /// Whether string / binary columns can have null values.
+ ///
+ /// If true, then strings in "null_values" are considered null for string columns.
+ /// If false, then all strings are valid string values.
+ bool strings_can_be_null = false;
+
+ /// Whether quoted values can be null.
+ ///
+ /// If true, then strings in "null_values" are also considered null when they
+ /// appear quoted in the CSV file. Otherwise, quoted values are never considered null.
+ bool quoted_strings_can_be_null = true;
+
+ /// Whether to try to automatically dict-encode string / binary data.
+ /// If true, then when type inference detects a string or binary column,
+ /// it is dict-encoded up to `auto_dict_max_cardinality` distinct values
+ /// (per chunk), after which it switches to regular encoding.
+ ///
+ /// This setting is ignored for non-inferred columns (those in `column_types`).
+ bool auto_dict_encode = false;
+ int32_t auto_dict_max_cardinality = 50;
+
+ /// Decimal point character for floating-point and decimal data
+ char decimal_point = '.';
+
+ // XXX Should we have a separate FilterOptions?
+
+ /// If non-empty, indicates the names of columns from the CSV file that should
+ /// be actually read and converted (in the vector's order).
+ /// Columns not in this vector will be ignored.
+ std::vector<std::string> include_columns;
+ /// If false, columns in `include_columns` but not in the CSV file will error out.
+ /// If true, columns in `include_columns` but not in the CSV file will produce
+ /// a column of nulls (whose type is selected using `column_types`,
+ /// or null by default)
+ /// This option is ignored if `include_columns` is empty.
+ bool include_missing_columns = false;
+
+ /// User-defined timestamp parsers, using the virtual parser interface in
+ /// arrow/util/value_parsing.h. More than one parser can be specified, and
+ /// the CSV conversion logic will try parsing values starting from the
+ /// beginning of this vector. If no parsers are specified, we use the default
+ /// built-in ISO-8601 parser.
+ std::vector<std::shared_ptr<TimestampParser>> timestamp_parsers;
+
+ /// Create conversion options with default values, including conventional
+ /// values for `null_values`, `true_values` and `false_values`
+ static ConvertOptions Defaults();
+
+ /// \brief Test that all set options are valid
+ Status Validate() const;
+};
+
+struct ARROW_EXPORT ReadOptions {
+ // Reader options
+
+ /// Whether to use the global CPU thread pool
+ bool use_threads = true;
+
+ /// \brief Block size we request from the IO layer.
+ ///
+ /// This will determine multi-threading granularity as well as
+ /// the size of individual record batches.
+ /// Minimum valid value for block size is 1
+ int32_t block_size = 1 << 20; // 1 MB
+
+ /// Number of header rows to skip (not including the row of column names, if any)
+ int32_t skip_rows = 0;
+
+ /// Number of rows to skip after the column names are read, if any
+ int32_t skip_rows_after_names = 0;
+
+ /// Column names for the target table.
+ /// If empty, fall back on autogenerate_column_names.
+ std::vector<std::string> column_names;
+
+ /// Whether to autogenerate column names if `column_names` is empty.
+ /// If true, column names will be of the form "f0", "f1"...
+ /// If false, column names will be read from the first CSV row after `skip_rows`.
+ bool autogenerate_column_names = false;
+
+ /// Create read options with default values
+ static ReadOptions Defaults();
+
+ /// \brief Test that all set options are valid
+ Status Validate() const;
+};
+
+struct ARROW_EXPORT WriteOptions {
+ /// Whether to write an initial header line with column names
+ bool include_header = true;
+
+ /// \brief Maximum number of rows processed at a time
+ ///
+ /// The CSV writer converts and writes data in batches of N rows.
+ /// This number can impact performance.
+ int32_t batch_size = 1024;
+
+ /// \brief IO context for writing.
+ io::IOContext io_context;
+
+ /// Create write options with default values
+ static WriteOptions Defaults();
+
+ /// \brief Test that all set options are valid
+ Status Validate() const;
+};
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/parser.cc b/src/arrow/cpp/src/arrow/csv/parser.cc
new file mode 100644
index 000000000..6400c94bb
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/parser.cc
@@ -0,0 +1,608 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/csv/parser.h"
+
+#include <algorithm>
+#include <cstdio>
+#include <limits>
+#include <utility>
+
+#include "arrow/memory_pool.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace csv {
+
+using detail::DataBatch;
+using detail::ParsedValueDesc;
+
+namespace {
+
+template <typename... Args>
+Status ParseError(Args&&... args) {
+ return Status::Invalid("CSV parse error: ", std::forward<Args>(args)...);
+}
+
+Status MismatchingColumns(const InvalidRow& row) {
+ std::string ellipse;
+ auto row_string = row.text;
+ if (row_string.length() > 100) {
+ row_string = row_string.substr(0, 96);
+ ellipse = " ...";
+ }
+ if (row.number < 0) {
+ return ParseError("Expected ", row.expected_columns, " columns, got ",
+ row.actual_columns, ": ", row_string, ellipse);
+ }
+ return ParseError("Row #", row.number, ": Expected ", row.expected_columns,
+ " columns, got ", row.actual_columns, ": ", row_string, ellipse);
+}
+
+inline bool IsControlChar(uint8_t c) { return c < ' '; }
+
+template <bool Quoting, bool Escaping>
+class SpecializedOptions {
+ public:
+ static constexpr bool quoting = Quoting;
+ static constexpr bool escaping = Escaping;
+};
+
+// A helper class allocating the buffer for parsed values and writing into it
+// without any further resizes, except at the end.
+class PresizedDataWriter {
+ public:
+ PresizedDataWriter(MemoryPool* pool, uint32_t size)
+ : parsed_size_(0), parsed_capacity_(size) {
+ parsed_buffer_ = *AllocateResizableBuffer(parsed_capacity_, pool);
+ parsed_ = parsed_buffer_->mutable_data();
+ }
+
+ void Finish(std::shared_ptr<Buffer>* out_parsed) {
+ ARROW_CHECK_OK(parsed_buffer_->Resize(parsed_size_));
+ *out_parsed = parsed_buffer_;
+ }
+
+ void BeginLine() { saved_parsed_size_ = parsed_size_; }
+
+ void PushFieldChar(char c) {
+ DCHECK_LT(parsed_size_, parsed_capacity_);
+ parsed_[parsed_size_++] = static_cast<uint8_t>(c);
+ }
+
+ // Rollback the state that was saved in BeginLine()
+ void RollbackLine() { parsed_size_ = saved_parsed_size_; }
+
+ int64_t size() { return parsed_size_; }
+
+ protected:
+ std::shared_ptr<ResizableBuffer> parsed_buffer_;
+ uint8_t* parsed_;
+ int64_t parsed_size_;
+ int64_t parsed_capacity_;
+ // Checkpointing, for when an incomplete line is encountered at end of block
+ int64_t saved_parsed_size_;
+};
+
+template <typename Derived>
+class ValueDescWriter {
+ public:
+ Derived* derived() { return static_cast<Derived*>(this); }
+
+ template <typename DataWriter>
+ void Start(DataWriter& parsed_writer) {
+ derived()->PushValue(
+ {static_cast<uint32_t>(parsed_writer.size()) & 0x7fffffffU, false});
+ }
+
+ void BeginLine() { saved_values_size_ = values_size_; }
+
+ // Rollback the state that was saved in BeginLine()
+ void RollbackLine() { values_size_ = saved_values_size_; }
+
+ void StartField(bool quoted) { quoted_ = quoted; }
+
+ template <typename DataWriter>
+ void FinishField(DataWriter* parsed_writer) {
+ derived()->PushValue(
+ {static_cast<uint32_t>(parsed_writer->size()) & 0x7fffffffU, quoted_});
+ }
+
+ void Finish(std::shared_ptr<Buffer>* out_values) {
+ ARROW_CHECK_OK(values_buffer_->Resize(values_size_ * sizeof(*values_)));
+ *out_values = values_buffer_;
+ }
+
+ protected:
+ ValueDescWriter(MemoryPool* pool, int64_t values_capacity)
+ : values_size_(0), values_capacity_(values_capacity) {
+ values_buffer_ = *AllocateResizableBuffer(values_capacity_ * sizeof(*values_), pool);
+ values_ = reinterpret_cast<ParsedValueDesc*>(values_buffer_->mutable_data());
+ }
+
+ std::shared_ptr<ResizableBuffer> values_buffer_;
+ ParsedValueDesc* values_;
+ int64_t values_size_;
+ int64_t values_capacity_;
+ bool quoted_;
+ // Checkpointing, for when an incomplete line is encountered at end of block
+ int64_t saved_values_size_;
+};
+
+// A helper class handling a growable buffer for values offsets. This class is
+// used when the number of columns is not yet known and we therefore cannot
+// efficiently presize the target area for a given number of rows.
+class ResizableValueDescWriter : public ValueDescWriter<ResizableValueDescWriter> {
+ public:
+ explicit ResizableValueDescWriter(MemoryPool* pool)
+ : ValueDescWriter(pool, /*values_capacity=*/256) {}
+
+ void PushValue(ParsedValueDesc v) {
+ if (ARROW_PREDICT_FALSE(values_size_ == values_capacity_)) {
+ values_capacity_ = values_capacity_ * 2;
+ ARROW_CHECK_OK(values_buffer_->Resize(values_capacity_ * sizeof(*values_)));
+ values_ = reinterpret_cast<ParsedValueDesc*>(values_buffer_->mutable_data());
+ }
+ values_[values_size_++] = v;
+ }
+};
+
+// A helper class allocating the buffer for values offsets and writing into it
+// without any further resizes, except at the end. This class is used once the
+// number of columns is known, as it eliminates resizes and generates simpler,
+// faster CSV parsing code.
+class PresizedValueDescWriter : public ValueDescWriter<PresizedValueDescWriter> {
+ public:
+ PresizedValueDescWriter(MemoryPool* pool, int32_t num_rows, int32_t num_cols)
+ : ValueDescWriter(pool, /*values_capacity=*/1 + num_rows * num_cols) {}
+
+ void PushValue(ParsedValueDesc v) {
+ DCHECK_LT(values_size_, values_capacity_);
+ values_[values_size_++] = v;
+ }
+};
+
+} // namespace
+
+class BlockParserImpl {
+ public:
+ BlockParserImpl(MemoryPool* pool, ParseOptions options, int32_t num_cols,
+ int64_t first_row, int32_t max_num_rows)
+ : pool_(pool),
+ options_(options),
+ first_row_(first_row),
+ max_num_rows_(max_num_rows),
+ batch_(num_cols) {}
+
+ const DataBatch& parsed_batch() const { return batch_; }
+
+ int64_t first_row_num() const { return first_row_; }
+
+ template <typename ValueDescWriter, typename DataWriter>
+ Status HandleInvalidRow(ValueDescWriter* values_writer, DataWriter* parsed_writer,
+ const char* start, const char* data, int32_t num_cols,
+ const char** out_data) {
+ // Find the end of the line without newline or carriage return
+ auto end = data;
+ if (*(end - 1) == '\n') {
+ --end;
+ }
+ if (*(end - 1) == '\r') {
+ --end;
+ }
+ const int32_t batch_row_including_skipped =
+ batch_.num_rows_ + batch_.num_skipped_rows();
+ InvalidRow row{batch_.num_cols_, num_cols,
+ first_row_ < 0 ? -1 : first_row_ + batch_row_including_skipped,
+ util::string_view(start, end - start)};
+
+ if (options_.invalid_row_handler &&
+ options_.invalid_row_handler(row) == InvalidRowResult::Skip) {
+ values_writer->RollbackLine();
+ parsed_writer->RollbackLine();
+ if (!batch_.skipped_rows_.empty()) {
+ // Should be increasing (non-strictly)
+ DCHECK_GE(batch_.num_rows_, batch_.skipped_rows_.back());
+ }
+ // Record the logical row number (not including skipped) since that
+ // is what we are going to look for later.
+ batch_.skipped_rows_.push_back(batch_.num_rows_);
+ *out_data = data;
+ return Status::OK();
+ }
+
+ return MismatchingColumns(row);
+ }
+
+ template <typename SpecializedOptions, typename ValueDescWriter, typename DataWriter>
+ Status ParseLine(ValueDescWriter* values_writer, DataWriter* parsed_writer,
+ const char* data, const char* data_end, bool is_final,
+ const char** out_data) {
+ int32_t num_cols = 0;
+ char c;
+ const auto start = data;
+
+ DCHECK_GT(data_end, data);
+
+ auto FinishField = [&]() { values_writer->FinishField(parsed_writer); };
+
+ values_writer->BeginLine();
+ parsed_writer->BeginLine();
+
+ // The parsing state machine
+
+ // Special case empty lines: do we start with a newline separator?
+ c = *data;
+ if (ARROW_PREDICT_FALSE(IsControlChar(c))) {
+ if (c == '\r') {
+ data++;
+ if (data < data_end && *data == '\n') {
+ data++;
+ }
+ goto EmptyLine;
+ }
+ if (c == '\n') {
+ data++;
+ goto EmptyLine;
+ }
+ }
+
+ FieldStart:
+ // At the start of a field
+ // Quoting is only recognized at start of field
+ if (SpecializedOptions::quoting &&
+ ARROW_PREDICT_FALSE(*data == options_.quote_char)) {
+ ++data;
+ values_writer->StartField(true /* quoted */);
+ goto InQuotedField;
+ } else {
+ values_writer->StartField(false /* quoted */);
+ goto InField;
+ }
+
+ InField:
+ // Inside a non-quoted part of a field
+ if (ARROW_PREDICT_FALSE(data == data_end)) {
+ goto AbortLine;
+ }
+ c = *data++;
+ if (SpecializedOptions::escaping && ARROW_PREDICT_FALSE(c == options_.escape_char)) {
+ if (ARROW_PREDICT_FALSE(data == data_end)) {
+ goto AbortLine;
+ }
+ c = *data++;
+ parsed_writer->PushFieldChar(c);
+ goto InField;
+ }
+ if (ARROW_PREDICT_FALSE(c == options_.delimiter)) {
+ goto FieldEnd;
+ }
+ if (ARROW_PREDICT_FALSE(IsControlChar(c))) {
+ if (c == '\r') {
+ // In the middle of a newline separator?
+ if (ARROW_PREDICT_TRUE(data < data_end) && *data == '\n') {
+ data++;
+ }
+ goto LineEnd;
+ }
+ if (c == '\n') {
+ goto LineEnd;
+ }
+ }
+ parsed_writer->PushFieldChar(c);
+ goto InField;
+
+ InQuotedField:
+ // Inside a quoted part of a field
+ if (ARROW_PREDICT_FALSE(data == data_end)) {
+ goto AbortLine;
+ }
+ c = *data++;
+ if (SpecializedOptions::escaping && ARROW_PREDICT_FALSE(c == options_.escape_char)) {
+ if (ARROW_PREDICT_FALSE(data == data_end)) {
+ goto AbortLine;
+ }
+ c = *data++;
+ parsed_writer->PushFieldChar(c);
+ goto InQuotedField;
+ }
+ if (ARROW_PREDICT_FALSE(c == options_.quote_char)) {
+ if (options_.double_quote && ARROW_PREDICT_TRUE(data < data_end) &&
+ ARROW_PREDICT_FALSE(*data == options_.quote_char)) {
+ // Double-quoting
+ ++data;
+ } else {
+ // End of single-quoting
+ goto InField;
+ }
+ }
+ parsed_writer->PushFieldChar(c);
+ goto InQuotedField;
+
+ FieldEnd:
+ // At the end of a field
+ FinishField();
+ ++num_cols;
+ if (ARROW_PREDICT_FALSE(data == data_end)) {
+ goto AbortLine;
+ }
+ goto FieldStart;
+
+ LineEnd:
+ // At the end of line
+ FinishField();
+ ++num_cols;
+ if (ARROW_PREDICT_FALSE(num_cols != batch_.num_cols_)) {
+ if (batch_.num_cols_ == -1) {
+ batch_.num_cols_ = num_cols;
+ } else {
+ return HandleInvalidRow(values_writer, parsed_writer, start, data, num_cols,
+ out_data);
+ }
+ }
+ ++batch_.num_rows_;
+ *out_data = data;
+ return Status::OK();
+
+ AbortLine:
+ // Not a full line except perhaps if in final block
+ if (is_final) {
+ goto LineEnd;
+ }
+ // Truncated line at end of block, rewind parsed state
+ values_writer->RollbackLine();
+ parsed_writer->RollbackLine();
+ return Status::OK();
+
+ EmptyLine:
+ if (!options_.ignore_empty_lines) {
+ if (batch_.num_cols_ == -1) {
+ // Consider as single value
+ batch_.num_cols_ = 1;
+ }
+ // Record as row of empty (null?) values
+ while (num_cols++ < batch_.num_cols_) {
+ values_writer->StartField(false /* quoted */);
+ FinishField();
+ }
+ ++batch_.num_rows_;
+ }
+ *out_data = data;
+ return Status::OK();
+ }
+
+ template <typename SpecializedOptions, typename ValueDescWriter, typename DataWriter>
+ Status ParseChunk(ValueDescWriter* values_writer, DataWriter* parsed_writer,
+ const char* data, const char* data_end, bool is_final,
+ int32_t rows_in_chunk, const char** out_data,
+ bool* finished_parsing) {
+ int32_t num_rows_deadline = batch_.num_rows_ + rows_in_chunk;
+
+ while (data < data_end && batch_.num_rows_ < num_rows_deadline) {
+ const char* line_end = data;
+ RETURN_NOT_OK(ParseLine<SpecializedOptions>(values_writer, parsed_writer, data,
+ data_end, is_final, &line_end));
+ if (line_end == data) {
+ // Cannot parse any further
+ *finished_parsing = true;
+ break;
+ }
+ data = line_end;
+ }
+ // Append new buffers and update size
+ std::shared_ptr<Buffer> values_buffer;
+ values_writer->Finish(&values_buffer);
+ if (values_buffer->size() > 0) {
+ values_size_ +=
+ static_cast<int32_t>(values_buffer->size() / sizeof(ParsedValueDesc) - 1);
+ batch_.values_buffers_.push_back(std::move(values_buffer));
+ }
+ *out_data = data;
+ return Status::OK();
+ }
+
+ template <typename SpecializedOptions>
+ Status ParseSpecialized(const std::vector<util::string_view>& views, bool is_final,
+ uint32_t* out_size) {
+ batch_ = DataBatch{batch_.num_cols_};
+ values_size_ = 0;
+
+ size_t total_view_length = 0;
+ for (const auto& view : views) {
+ total_view_length += view.length();
+ }
+ if (total_view_length > std::numeric_limits<uint32_t>::max()) {
+ return Status::Invalid("CSV block too large");
+ }
+
+ PresizedDataWriter parsed_writer(pool_, static_cast<uint32_t>(total_view_length));
+ uint32_t total_parsed_length = 0;
+
+ for (const auto& view : views) {
+ const char* data = view.data();
+ const char* data_end = view.data() + view.length();
+ bool finished_parsing = false;
+
+ if (batch_.num_cols_ == -1) {
+ // Can't presize values when the number of columns is not known, first parse
+ // a single line
+ const int32_t rows_in_chunk = 1;
+ ResizableValueDescWriter values_writer(pool_);
+ values_writer.Start(parsed_writer);
+
+ RETURN_NOT_OK(ParseChunk<SpecializedOptions>(&values_writer, &parsed_writer, data,
+ data_end, is_final, rows_in_chunk,
+ &data, &finished_parsing));
+ if (batch_.num_cols_ == -1) {
+ return ParseError("Empty CSV file or block: cannot infer number of columns");
+ }
+ }
+
+ while (!finished_parsing && data < data_end && batch_.num_rows_ < max_num_rows_) {
+ // We know the number of columns, so can presize a values array for
+ // a given number of rows
+ DCHECK_GE(batch_.num_cols_, 0);
+
+ int32_t rows_in_chunk;
+ constexpr int32_t kTargetChunkSize = 32768; // in number of values
+ if (batch_.num_cols_ > 0) {
+ rows_in_chunk = std::min(std::max(kTargetChunkSize / batch_.num_cols_, 512),
+ max_num_rows_ - batch_.num_rows_);
+ } else {
+ rows_in_chunk = std::min(kTargetChunkSize, max_num_rows_ - batch_.num_rows_);
+ }
+
+ PresizedValueDescWriter values_writer(pool_, rows_in_chunk, batch_.num_cols_);
+ values_writer.Start(parsed_writer);
+
+ RETURN_NOT_OK(ParseChunk<SpecializedOptions>(&values_writer, &parsed_writer, data,
+ data_end, is_final, rows_in_chunk,
+ &data, &finished_parsing));
+ }
+ DCHECK_GE(data, view.data());
+ DCHECK_LE(data, data_end);
+ total_parsed_length += static_cast<uint32_t>(data - view.data());
+
+ if (data < data_end) {
+ // Stopped early, for some reason
+ break;
+ }
+ }
+
+ parsed_writer.Finish(&batch_.parsed_buffer_);
+ batch_.parsed_size_ = static_cast<int32_t>(batch_.parsed_buffer_->size());
+ batch_.parsed_ = batch_.parsed_buffer_->data();
+
+ if (batch_.num_cols_ == -1) {
+ DCHECK_EQ(batch_.num_rows_, 0);
+ }
+ DCHECK_EQ(values_size_, batch_.num_rows_ * batch_.num_cols_);
+#ifndef NDEBUG
+ if (batch_.num_rows_ > 0) {
+ // Ending parsed offset should be equal to number of parsed bytes
+ DCHECK_GT(batch_.values_buffers_.size(), 0);
+ const auto& last_values_buffer = batch_.values_buffers_.back();
+ const auto last_values =
+ reinterpret_cast<const ParsedValueDesc*>(last_values_buffer->data());
+ const auto last_values_size = last_values_buffer->size() / sizeof(ParsedValueDesc);
+ const auto check_parsed_size =
+ static_cast<int32_t>(last_values[last_values_size - 1].offset);
+ DCHECK_EQ(batch_.parsed_size_, check_parsed_size);
+ } else {
+ DCHECK_EQ(batch_.parsed_size_, 0);
+ }
+#endif
+ *out_size = static_cast<uint32_t>(total_parsed_length);
+ return Status::OK();
+ }
+
+ Status Parse(const std::vector<util::string_view>& data, bool is_final,
+ uint32_t* out_size) {
+ if (options_.quoting) {
+ if (options_.escaping) {
+ return ParseSpecialized<SpecializedOptions<true, true>>(data, is_final, out_size);
+ } else {
+ return ParseSpecialized<SpecializedOptions<true, false>>(data, is_final,
+ out_size);
+ }
+ } else {
+ if (options_.escaping) {
+ return ParseSpecialized<SpecializedOptions<false, true>>(data, is_final,
+ out_size);
+ } else {
+ return ParseSpecialized<SpecializedOptions<false, false>>(data, is_final,
+ out_size);
+ }
+ }
+ }
+
+ protected:
+ MemoryPool* pool_;
+ const ParseOptions options_;
+ const int64_t first_row_;
+ // The maximum number of rows to parse from a block
+ int32_t max_num_rows_;
+
+ // Unparsed data size
+ int32_t values_size_;
+ // Parsed data batch
+ DataBatch batch_;
+};
+
+BlockParser::BlockParser(ParseOptions options, int32_t num_cols, int64_t first_row,
+ int32_t max_num_rows)
+ : BlockParser(default_memory_pool(), options, num_cols, first_row, max_num_rows) {}
+
+BlockParser::BlockParser(MemoryPool* pool, ParseOptions options, int32_t num_cols,
+ int64_t first_row, int32_t max_num_rows)
+ : impl_(new BlockParserImpl(pool, std::move(options), num_cols, first_row,
+ max_num_rows)) {}
+
+BlockParser::~BlockParser() {}
+
+Status BlockParser::Parse(const std::vector<util::string_view>& data,
+ uint32_t* out_size) {
+ return impl_->Parse(data, false /* is_final */, out_size);
+}
+
+Status BlockParser::ParseFinal(const std::vector<util::string_view>& data,
+ uint32_t* out_size) {
+ return impl_->Parse(data, true /* is_final */, out_size);
+}
+
+Status BlockParser::Parse(util::string_view data, uint32_t* out_size) {
+ return impl_->Parse({data}, false /* is_final */, out_size);
+}
+
+Status BlockParser::ParseFinal(util::string_view data, uint32_t* out_size) {
+ return impl_->Parse({data}, true /* is_final */, out_size);
+}
+
+const DataBatch& BlockParser::parsed_batch() const { return impl_->parsed_batch(); }
+
+int64_t BlockParser::first_row_num() const { return impl_->first_row_num(); }
+
+int32_t SkipRows(const uint8_t* data, uint32_t size, int32_t num_rows,
+ const uint8_t** out_data) {
+ const auto end = data + size;
+ int32_t skipped_rows = 0;
+ *out_data = data;
+
+ for (; skipped_rows < num_rows; ++skipped_rows) {
+ uint8_t c;
+ do {
+ while (ARROW_PREDICT_FALSE(data < end && !IsControlChar(*data))) {
+ ++data;
+ }
+ if (ARROW_PREDICT_FALSE(data == end)) {
+ return skipped_rows;
+ }
+ c = *data++;
+ } while (c != '\r' && c != '\n');
+ if (c == '\r' && data < end && *data == '\n') {
+ ++data;
+ }
+ *out_data = data;
+ }
+
+ return skipped_rows;
+}
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/parser.h b/src/arrow/cpp/src/arrow/csv/parser.h
new file mode 100644
index 000000000..fb003faaf
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/parser.h
@@ -0,0 +1,227 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/csv/options.h"
+#include "arrow/csv/type_fwd.h"
+#include "arrow/status.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class MemoryPool;
+
+namespace csv {
+
+/// Skip at most num_rows from the given input. The input pointer is updated
+/// and the number of actually skipped rows is returns (may be less than
+/// requested if the input is too short).
+ARROW_EXPORT
+int32_t SkipRows(const uint8_t* data, uint32_t size, int32_t num_rows,
+ const uint8_t** out_data);
+
+class BlockParserImpl;
+
+namespace detail {
+
+struct ParsedValueDesc {
+ uint32_t offset : 31;
+ bool quoted : 1;
+};
+
+class ARROW_EXPORT DataBatch {
+ public:
+ explicit DataBatch(int32_t num_cols) : num_cols_(num_cols) {}
+
+ /// \brief Return the number of parsed rows (not skipped)
+ int32_t num_rows() const { return num_rows_; }
+ /// \brief Return the number of parsed columns
+ int32_t num_cols() const { return num_cols_; }
+ /// \brief Return the total size in bytes of parsed data
+ uint32_t num_bytes() const { return parsed_size_; }
+ /// \brief Return the number of skipped rows
+ int32_t num_skipped_rows() const { return static_cast<int32_t>(skipped_rows_.size()); }
+
+ template <typename Visitor>
+ Status VisitColumn(int32_t col_index, int64_t first_row, Visitor&& visit) const {
+ using detail::ParsedValueDesc;
+
+ int32_t batch_row = 0;
+ for (size_t buf_index = 0; buf_index < values_buffers_.size(); ++buf_index) {
+ const auto& values_buffer = values_buffers_[buf_index];
+ const auto values = reinterpret_cast<const ParsedValueDesc*>(values_buffer->data());
+ const auto max_pos =
+ static_cast<int32_t>(values_buffer->size() / sizeof(ParsedValueDesc)) - 1;
+ for (int32_t pos = col_index; pos < max_pos; pos += num_cols_, ++batch_row) {
+ auto start = values[pos].offset;
+ auto stop = values[pos + 1].offset;
+ auto quoted = values[pos + 1].quoted;
+ Status status = visit(parsed_ + start, stop - start, quoted);
+ if (ARROW_PREDICT_FALSE(!status.ok())) {
+ return DecorateWithRowNumber(std::move(status), first_row, batch_row);
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ template <typename Visitor>
+ Status VisitLastRow(Visitor&& visit) const {
+ using detail::ParsedValueDesc;
+
+ const auto& values_buffer = values_buffers_.back();
+ const auto values = reinterpret_cast<const ParsedValueDesc*>(values_buffer->data());
+ const auto start_pos =
+ static_cast<int32_t>(values_buffer->size() / sizeof(ParsedValueDesc)) -
+ num_cols_ - 1;
+ for (int32_t col_index = 0; col_index < num_cols_; ++col_index) {
+ auto start = values[start_pos + col_index].offset;
+ auto stop = values[start_pos + col_index + 1].offset;
+ auto quoted = values[start_pos + col_index + 1].quoted;
+ ARROW_RETURN_NOT_OK(visit(parsed_ + start, stop - start, quoted));
+ }
+ return Status::OK();
+ }
+
+ protected:
+ Status DecorateWithRowNumber(Status&& status, int64_t first_row,
+ int32_t batch_row) const {
+ if (first_row >= 0) {
+ // `skipped_rows_` is in ascending order by construction, so use bisection
+ // to find out how many rows were skipped before `batch_row`.
+ const auto skips_before =
+ std::upper_bound(skipped_rows_.begin(), skipped_rows_.end(), batch_row) -
+ skipped_rows_.begin();
+ status = status.WithMessage("Row #", batch_row + skips_before + first_row, ": ",
+ status.message());
+ }
+ // Use return_if so that when extra context is enabled it will be added
+ ARROW_RETURN_IF_(true, std::move(status), ARROW_STRINGIFY(status));
+ }
+
+ // The number of rows in this batch (not including any skipped ones)
+ int32_t num_rows_ = 0;
+ // The number of columns
+ int32_t num_cols_ = 0;
+
+ // XXX should we ensure the parsed buffer is padded with 8 or 16 excess zero bytes?
+ // It may help with null parsing...
+ std::vector<std::shared_ptr<Buffer>> values_buffers_;
+ std::shared_ptr<Buffer> parsed_buffer_;
+ const uint8_t* parsed_ = NULLPTR;
+ int32_t parsed_size_ = 0;
+
+ // Record the current num_rows_ each time a row is skipped
+ std::vector<int32_t> skipped_rows_;
+
+ friend class ::arrow::csv::BlockParserImpl;
+};
+
+} // namespace detail
+
+constexpr int32_t kMaxParserNumRows = 100000;
+
+/// \class BlockParser
+/// \brief A reusable block-based parser for CSV data
+///
+/// The parser takes a block of CSV data and delimits rows and fields,
+/// unquoting and unescaping them on the fly. Parsed data is own by the
+/// parser, so the original buffer can be discarded after Parse() returns.
+///
+/// If the block is truncated (i.e. not all data can be parsed), it is up
+/// to the caller to arrange the next block to start with the trailing data.
+/// Also, if the previous block ends with CR (0x0d) and a new block starts
+/// with LF (0x0a), the parser will consider the leading newline as an empty
+/// line; the caller should therefore strip it.
+class ARROW_EXPORT BlockParser {
+ public:
+ explicit BlockParser(ParseOptions options, int32_t num_cols = -1,
+ int64_t first_row = -1, int32_t max_num_rows = kMaxParserNumRows);
+ explicit BlockParser(MemoryPool* pool, ParseOptions options, int32_t num_cols = -1,
+ int64_t first_row = -1, int32_t max_num_rows = kMaxParserNumRows);
+ ~BlockParser();
+
+ /// \brief Parse a block of data
+ ///
+ /// Parse a block of CSV data, ingesting up to max_num_rows rows.
+ /// The number of bytes actually parsed is returned in out_size.
+ Status Parse(util::string_view data, uint32_t* out_size);
+
+ /// \brief Parse sequential blocks of data
+ ///
+ /// Only the last block is allowed to be truncated.
+ Status Parse(const std::vector<util::string_view>& data, uint32_t* out_size);
+
+ /// \brief Parse the final block of data
+ ///
+ /// Like Parse(), but called with the final block in a file.
+ /// The last row may lack a trailing line separator.
+ Status ParseFinal(util::string_view data, uint32_t* out_size);
+
+ /// \brief Parse the final sequential blocks of data
+ ///
+ /// Only the last block is allowed to be truncated.
+ Status ParseFinal(const std::vector<util::string_view>& data, uint32_t* out_size);
+
+ /// \brief Return the number of parsed rows
+ int32_t num_rows() const { return parsed_batch().num_rows(); }
+ /// \brief Return the number of parsed columns
+ int32_t num_cols() const { return parsed_batch().num_cols(); }
+ /// \brief Return the total size in bytes of parsed data
+ uint32_t num_bytes() const { return parsed_batch().num_bytes(); }
+
+ /// \brief Return the total number of rows including rows which were skipped
+ int32_t total_num_rows() const {
+ return parsed_batch().num_rows() + parsed_batch().num_skipped_rows();
+ }
+
+ /// \brief Return the row number of the first row in the block or -1 if unsupported
+ int64_t first_row_num() const;
+
+ /// \brief Visit parsed values in a column
+ ///
+ /// The signature of the visitor is
+ /// Status(const uint8_t* data, uint32_t size, bool quoted)
+ template <typename Visitor>
+ Status VisitColumn(int32_t col_index, Visitor&& visit) const {
+ return parsed_batch().VisitColumn(col_index, first_row_num(),
+ std::forward<Visitor>(visit));
+ }
+
+ template <typename Visitor>
+ Status VisitLastRow(Visitor&& visit) const {
+ return parsed_batch().VisitLastRow(std::forward<Visitor>(visit));
+ }
+
+ protected:
+ std::unique_ptr<BlockParserImpl> impl_;
+
+ const detail::DataBatch& parsed_batch() const;
+};
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/parser_benchmark.cc b/src/arrow/cpp/src/arrow/csv/parser_benchmark.cc
new file mode 100644
index 000000000..b279a3c0c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/parser_benchmark.cc
@@ -0,0 +1,205 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include "arrow/csv/chunker.h"
+#include "arrow/csv/options.h"
+#include "arrow/csv/parser.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace csv {
+
+struct Example {
+ int32_t num_rows;
+ const char* csv_rows;
+};
+
+const Example quoted_example{1, "abc,\"d,f\",12.34,\n"};
+const Example escaped_example{1, "abc,d\\,f,12.34,\n"};
+
+const Example flights_example{
+ 8,
+ R"(2015,1,1,4,AA,2336,N3KUAA,LAX,PBI,0010,0002,-8,12,0014,280,279,263,2330,0737,4,0750,0741,-9,0,0,,,,,,
+2015,1,1,4,US,840,N171US,SFO,CLT,0020,0018,-2,16,0034,286,293,266,2296,0800,11,0806,0811,5,0,0,,,,,,
+2015,1,1,4,AA,258,N3HYAA,LAX,MIA,0020,0015,-5,15,0030,285,281,258,2342,0748,8,0805,0756,-9,0,0,,,,,,
+2015,1,1,4,AS,135,N527AS,SEA,ANC,0025,0024,-1,11,0035,235,215,199,1448,0254,5,0320,0259,-21,0,0,,,,,,
+2015,1,1,4,DL,806,N3730B,SFO,MSP,0025,0020,-5,18,0038,217,230,206,1589,0604,6,0602,0610,8,0,0,,,,,,
+2015,1,1,4,NK,612,N635NK,LAS,MSP,0025,0019,-6,11,0030,181,170,154,1299,0504,5,0526,0509,-17,0,0,,,,,,
+2015,1,1,4,US,2013,N584UW,LAX,CLT,0030,0044,14,13,0057,273,249,228,2125,0745,8,0803,0753,-10,0,0,,,,,,
+2015,1,1,4,AA,1112,N3LAAA,SFO,DFW,0030,0019,-11,17,0036,195,193,173,1464,0529,3,0545,0532,-13,0,0,,,,,,
+)"};
+
+// NOTE: quoted
+const Example vehicles_example{
+ 2,
+ R"(7088743681,https://greensboro.craigslist.org/ctd/d/cary-2004-honda-element-lx-4dr-suv/7088743681.html,greensboro,https://greensboro.craigslist.org,3995,2004,honda,element,,,gas,212526,clean,automatic,5J6YH18314L006498,fwd,,SUV,orange,https://images.craigslist.org/00E0E_eAUnhFF86M4_600x450.jpg,"2004 Honda Element LX 4dr SUV Offered by: Best Import Auto Sales Inc — (919) 800-0650 — $3,995 EXCELLENT SHAPE INSIDE AND OUT FULLY SERVICED AND READY TO GO ,RUNS AND DRIVES PERFECT ,PLEASE CALL OR TEXT 919 454 4848 OR CALL 919 380 0380 IF INTERESTED. Best Import Auto Sales Inc Year: 2004 Make: Honda Model: Element Series: LX 4dr SUV VIN: 5J6YH18314L006498 Stock #: 4L006498 Condition: Used Mileage: 212,526 Exterior: Orange Interior: Black Body: SUV Transmission: Automatic 4-Speed Engine: 2.4L I4 **** Best Import Auto Sales Inc. 🚘 Raleigh Auto Dealer ***** ⚡️⚡️⚡️ Call Or Text (919) 800-0650 ⚡️⚡️⚡️ ✅ - We can arrange Financing Options with most banks and credit unions!!!! ✅ Extended Warranties Available on most vehicles!! ""Call To Inquire"" ✅ Full Service ASE-Certified Shop Onsite! More vehicle details: best-import-auto-sales-inc.hammerwebsites.net/v/3kE08kSD Address: 1501 Buck Jones Rd Raleigh, NC 27606 Phone: (919) 800-0650 Website: www.bestimportsonline.com 📲 ☎️ Call or text (919) 800-0650 for quick answers to your questions about this Honda Element Your message will always be answered by a real human — never an automated system. Disclaimer: Best Import Auto Sales Inc will never sell, share, or spam your mobile number. Standard text messaging rates may apply. 2004 Honda Element LX 4dr SUV 6fbc204ebd7e4a32a30dcf2c8c3bcdea",,nc,35.7636,-78.7443
+ 7088744126,https://greensboro.craigslist.org/cto/d/greensboro-2011-jaguar-xf-premier/7088744126.html,greensboro,https://greensboro.craigslist.org,9500,2011,jaguar,xf,excellent,,gas,85000,clean,automatic,,,,,blue,https://images.craigslist.org/00505_f22HGItCRpc_600x450.jpg,"2011 jaguar XF premium - estate sale. Retired lady executive. Like new, garaged and maintained. Very nice leather, heated seats, electric sunroof, metallic blue paint. 85K miles bumper-to-bumper warranty. Premium radio sound system. Built-in phone connection. Please call show contact info cell or show contact info . Asking Price $9500",,nc,36.1032,-79.8794
+)"};
+
+const Example stocks_example{
+ 3,
+ R"(2,2010-01-27 00:00:00,002204,华锐铸钢,536498.0,135378.0,2652784.2001924426,14160629.45,5.382023337513902,5.288274712474071,5.382023337513902,5.341540976701248,,5.338025403262254,1.01364599,0.21306505690870553
+3,2010-02-05 00:00:00,600266,北京城建,1122615.0,1122615.0,8102476.086666377,57695471.0,7.236029036381633,7.025270909108382,7.170459841229955,7.095523618199466,,7.120720923193468,2.3025570905818964,0.4683513939405588
+4,2010-01-04 00:00:00,600289,亿阳信通,602926.359,602926.359,16393247.138998777,167754890.0,10.381817699665978,9.960037526145015,10.092597009251604,10.321563389162982,,10.233170315655089,4.436963485334562,0.6025431050299465
+)"};
+
+static constexpr int32_t kNumRows = 10000;
+
+static std::string BuildCSVData(const Example& example) {
+ std::stringstream ss;
+ for (int32_t i = 0; i < kNumRows; i += example.num_rows) {
+ ss << example.csv_rows;
+ }
+ return ss.str();
+}
+
+static void BenchmarkCSVChunking(benchmark::State& state, // NOLINT non-const reference
+ const std::string& csv, ParseOptions options) {
+ auto chunker = MakeChunker(options);
+ auto block = std::make_shared<Buffer>(util::string_view(csv));
+
+ while (state.KeepRunning()) {
+ std::shared_ptr<Buffer> whole, partial;
+ ABORT_NOT_OK(chunker->Process(block, &whole, &partial));
+ benchmark::DoNotOptimize(whole->size());
+ }
+
+ state.SetBytesProcessed(state.iterations() * csv.length());
+}
+
+static void ChunkCSVQuotedBlock(benchmark::State& state) { // NOLINT non-const reference
+ auto csv = BuildCSVData(quoted_example);
+ auto options = ParseOptions::Defaults();
+ options.quoting = true;
+ options.escaping = false;
+ options.newlines_in_values = true;
+
+ BenchmarkCSVChunking(state, csv, options);
+}
+
+static void ChunkCSVEscapedBlock(benchmark::State& state) { // NOLINT non-const reference
+ auto csv = BuildCSVData(escaped_example);
+ auto options = ParseOptions::Defaults();
+ options.quoting = false;
+ options.escaping = true;
+ options.newlines_in_values = true;
+
+ BenchmarkCSVChunking(state, csv, options);
+}
+
+static void ChunkCSVNoNewlinesBlock(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto csv = BuildCSVData(escaped_example);
+ auto options = ParseOptions::Defaults();
+ options.quoting = true;
+ options.escaping = false;
+ options.newlines_in_values = false;
+
+ BenchmarkCSVChunking(state, csv, options);
+ // Provides better regression stability with timings rather than bogus
+ // bandwidth.
+ state.SetBytesProcessed(0);
+}
+
+static void BenchmarkCSVParsing(benchmark::State& state, // NOLINT non-const reference
+ const std::string& csv, int32_t num_rows,
+ ParseOptions options) {
+ BlockParser parser(options, -1, num_rows + 1);
+
+ while (state.KeepRunning()) {
+ uint32_t parsed_size = 0;
+ ABORT_NOT_OK(parser.Parse(util::string_view(csv), &parsed_size));
+
+ // Include performance of visiting the parsed values, as that might
+ // vary depending on the parser's internal data structures.
+ bool dummy_quoted = false;
+ uint32_t dummy_size = 0;
+ auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) {
+ dummy_size += size;
+ dummy_quoted ^= quoted;
+ return Status::OK();
+ };
+ for (int32_t col = 0; col < parser.num_cols(); ++col) {
+ ABORT_NOT_OK(parser.VisitColumn(col, visit));
+ benchmark::DoNotOptimize(dummy_size);
+ benchmark::DoNotOptimize(dummy_quoted);
+ }
+ }
+
+ state.SetBytesProcessed(state.iterations() * csv.size());
+}
+
+static void BenchmarkCSVParsing(benchmark::State& state, // NOLINT non-const reference
+ const Example& example, ParseOptions options) {
+ auto csv = BuildCSVData(example);
+ BenchmarkCSVParsing(state, csv, kNumRows, options);
+}
+
+static void ParseCSVQuotedBlock(benchmark::State& state) { // NOLINT non-const reference
+ auto options = ParseOptions::Defaults();
+ options.quoting = true;
+ options.escaping = false;
+
+ BenchmarkCSVParsing(state, quoted_example, options);
+}
+
+static void ParseCSVEscapedBlock(benchmark::State& state) { // NOLINT non-const reference
+ auto options = ParseOptions::Defaults();
+ options.quoting = false;
+ options.escaping = true;
+
+ BenchmarkCSVParsing(state, escaped_example, options);
+}
+
+static void ParseCSVFlightsExample(
+ benchmark::State& state) { // NOLINT non-const reference
+ BenchmarkCSVParsing(state, flights_example, ParseOptions::Defaults());
+}
+
+static void ParseCSVVehiclesExample(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto options = ParseOptions::Defaults();
+ options.quoting = true;
+ options.escaping = false;
+
+ BenchmarkCSVParsing(state, vehicles_example, options);
+}
+
+static void ParseCSVStocksExample(
+ benchmark::State& state) { // NOLINT non-const reference
+ BenchmarkCSVParsing(state, stocks_example, ParseOptions::Defaults());
+}
+
+BENCHMARK(ChunkCSVQuotedBlock);
+BENCHMARK(ChunkCSVEscapedBlock);
+BENCHMARK(ChunkCSVNoNewlinesBlock);
+
+BENCHMARK(ParseCSVQuotedBlock);
+BENCHMARK(ParseCSVEscapedBlock);
+BENCHMARK(ParseCSVFlightsExample);
+BENCHMARK(ParseCSVVehiclesExample);
+BENCHMARK(ParseCSVStocksExample);
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/parser_test.cc b/src/arrow/cpp/src/arrow/csv/parser_test.cc
new file mode 100644
index 000000000..3eeb746af
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/parser_test.cc
@@ -0,0 +1,805 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/csv/options.h"
+#include "arrow/csv/parser.h"
+#include "arrow/csv/test_common.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+namespace csv {
+
+void CheckSkipRows(const std::string& rows, int32_t num_rows,
+ int32_t expected_skipped_rows, int32_t expected_skipped_bytes) {
+ const uint8_t* start = reinterpret_cast<const uint8_t*>(rows.data());
+ const uint8_t* data;
+ int32_t skipped_rows =
+ SkipRows(start, static_cast<int32_t>(rows.size()), num_rows, &data);
+ ASSERT_EQ(skipped_rows, expected_skipped_rows);
+ ASSERT_EQ(data - start, expected_skipped_bytes);
+}
+
+TEST(SkipRows, Basics) {
+ CheckSkipRows("", 0, 0, 0);
+ CheckSkipRows("", 15, 0, 0);
+
+ CheckSkipRows("a\nb\nc\nd", 1, 1, 2);
+ CheckSkipRows("a\nb\nc\nd", 2, 2, 4);
+ CheckSkipRows("a\nb\nc\nd", 3, 3, 6);
+ CheckSkipRows("a\nb\nc\nd", 4, 3, 6);
+
+ CheckSkipRows("a\nb\nc\nd\n", 3, 3, 6);
+ CheckSkipRows("a\nb\nc\nd\n", 4, 4, 8);
+ CheckSkipRows("a\nb\nc\nd\n", 5, 4, 8);
+
+ CheckSkipRows("\t\n\t\n\t\n\t", 1, 1, 2);
+ CheckSkipRows("\t\n\t\n\t\n\t", 3, 3, 6);
+ CheckSkipRows("\t\n\t\n\t\n\t", 4, 3, 6);
+
+ CheckSkipRows("a\r\nb\nc\rd\r\n", 1, 1, 3);
+ CheckSkipRows("a\r\nb\nc\rd\r\n", 2, 2, 5);
+ CheckSkipRows("a\r\nb\nc\rd\r\n", 3, 3, 7);
+ CheckSkipRows("a\r\nb\nc\rd\r\n", 4, 4, 10);
+ CheckSkipRows("a\r\nb\nc\rd\r\n", 5, 4, 10);
+
+ CheckSkipRows("a\r\nb\nc\rd\r", 4, 4, 9);
+ CheckSkipRows("a\r\nb\nc\rd\r", 5, 4, 9);
+ CheckSkipRows("a\r\nb\nc\rd\re", 4, 4, 9);
+ CheckSkipRows("a\r\nb\nc\rd\re", 5, 4, 9);
+
+ CheckSkipRows("\n\r\n\r\r\n\n\r\n\r", 1, 1, 1);
+ CheckSkipRows("\n\r\n\r\r\n\n\r\n\r", 2, 2, 3);
+ CheckSkipRows("\n\r\n\r\r\n\n\r\n\r", 3, 3, 4);
+ CheckSkipRows("\n\r\n\r\r\n\n\r\n\r", 4, 4, 6);
+ CheckSkipRows("\n\r\n\r\r\n\n\r\n\r", 5, 5, 7);
+ CheckSkipRows("\n\r\n\r\r\n\n\r\n\r", 6, 6, 9);
+ CheckSkipRows("\n\r\n\r\r\n\n\r\n\r", 7, 7, 10);
+ CheckSkipRows("\n\r\n\r\r\n\n\r\n\r", 8, 7, 10);
+}
+
+////////////////////////////////////////////////////////////////////////////
+// BlockParser tests
+
+// Read the column with the given index out of the BlockParser.
+void GetColumn(const BlockParser& parser, int32_t col_index,
+ std::vector<std::string>* out, std::vector<bool>* out_quoted = nullptr) {
+ std::vector<std::string> values;
+ std::vector<bool> quoted_values;
+ auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) -> Status {
+ values.push_back(std::string(reinterpret_cast<const char*>(data), size));
+ if (out_quoted) {
+ quoted_values.push_back(quoted);
+ }
+ return Status::OK();
+ };
+ ASSERT_OK(parser.VisitColumn(col_index, visit));
+ *out = std::move(values);
+ if (out_quoted) {
+ *out_quoted = std::move(quoted_values);
+ }
+}
+
+void GetLastRow(const BlockParser& parser, std::vector<std::string>* out,
+ std::vector<bool>* out_quoted = nullptr) {
+ std::vector<std::string> values;
+ std::vector<bool> quoted_values;
+ auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) -> Status {
+ values.push_back(std::string(reinterpret_cast<const char*>(data), size));
+ if (out_quoted) {
+ quoted_values.push_back(quoted);
+ }
+ return Status::OK();
+ };
+ ASSERT_OK(parser.VisitLastRow(visit));
+ *out = std::move(values);
+ if (out_quoted) {
+ *out_quoted = std::move(quoted_values);
+ }
+}
+
+size_t TotalViewLength(const std::vector<util::string_view>& views) {
+ size_t total_view_length = 0;
+ for (const auto& view : views) {
+ total_view_length += view.length();
+ }
+ return total_view_length;
+}
+
+Status Parse(BlockParser& parser, const std::string& str, uint32_t* out_size) {
+ return parser.Parse(util::string_view(str), out_size);
+}
+
+Status ParseFinal(BlockParser& parser, const std::string& str, uint32_t* out_size) {
+ return parser.ParseFinal(util::string_view(str), out_size);
+}
+
+void AssertParseOk(BlockParser& parser, const std::string& str) {
+ uint32_t parsed_size = static_cast<uint32_t>(-1);
+ ASSERT_OK(Parse(parser, str, &parsed_size));
+ ASSERT_EQ(parsed_size, str.size());
+}
+
+void AssertParseOk(BlockParser& parser, const std::vector<util::string_view>& data) {
+ uint32_t parsed_size = static_cast<uint32_t>(-1);
+ ASSERT_OK(parser.Parse(data, &parsed_size));
+ ASSERT_EQ(parsed_size, TotalViewLength(data));
+}
+
+void AssertParseFinal(BlockParser& parser, const std::string& str) {
+ uint32_t parsed_size = static_cast<uint32_t>(-1);
+ ASSERT_OK(ParseFinal(parser, str, &parsed_size));
+ ASSERT_EQ(parsed_size, str.size());
+}
+
+void AssertParseFinal(BlockParser& parser, const std::vector<util::string_view>& data) {
+ uint32_t parsed_size = static_cast<uint32_t>(-1);
+ ASSERT_OK(parser.ParseFinal(data, &parsed_size));
+ ASSERT_EQ(parsed_size, TotalViewLength(data));
+}
+
+void AssertParsePartial(BlockParser& parser, const std::string& str,
+ uint32_t expected_size) {
+ uint32_t parsed_size = static_cast<uint32_t>(-1);
+ ASSERT_OK(Parse(parser, str, &parsed_size));
+ ASSERT_EQ(parsed_size, expected_size);
+}
+
+void AssertLastRowEq(const BlockParser& parser, const std::vector<std::string> expected) {
+ std::vector<std::string> values;
+ GetLastRow(parser, &values);
+ ASSERT_EQ(parser.num_rows(), expected.size());
+ ASSERT_EQ(values, expected);
+}
+
+void AssertLastRowEq(const BlockParser& parser, const std::vector<std::string> expected,
+ const std::vector<bool> expected_quoted) {
+ std::vector<std::string> values;
+ std::vector<bool> quoted;
+ GetLastRow(parser, &values, &quoted);
+ ASSERT_EQ(parser.num_cols(), expected.size());
+ ASSERT_EQ(values, expected);
+ ASSERT_EQ(quoted, expected_quoted);
+}
+
+void AssertColumnEq(const BlockParser& parser, int32_t col_index,
+ const std::vector<std::string> expected) {
+ std::vector<std::string> values;
+ GetColumn(parser, col_index, &values);
+ ASSERT_EQ(parser.num_rows(), expected.size());
+ ASSERT_EQ(values, expected);
+}
+
+void AssertColumnEq(const BlockParser& parser, int32_t col_index,
+ const std::vector<std::string> expected,
+ const std::vector<bool> expected_quoted) {
+ std::vector<std::string> values;
+ std::vector<bool> quoted;
+ GetColumn(parser, col_index, &values, &quoted);
+ ASSERT_EQ(parser.num_rows(), expected.size());
+ ASSERT_EQ(values, expected);
+ ASSERT_EQ(quoted, expected_quoted);
+}
+
+void AssertColumnsEq(const BlockParser& parser,
+ const std::vector<std::vector<std::string>> expected) {
+ ASSERT_EQ(parser.num_cols(), expected.size());
+ for (int32_t col_index = 0; col_index < parser.num_cols(); ++col_index) {
+ AssertColumnEq(parser, col_index, expected[col_index]);
+ }
+}
+
+void AssertColumnsEq(const BlockParser& parser,
+ const std::vector<std::vector<std::string>> expected,
+ const std::vector<std::vector<bool>> quoted) {
+ ASSERT_EQ(parser.num_cols(), expected.size());
+ for (int32_t col_index = 0; col_index < parser.num_cols(); ++col_index) {
+ AssertColumnEq(parser, col_index, expected[col_index], quoted[col_index]);
+ }
+ uint32_t total_bytes = 0;
+ for (const auto& col : expected) {
+ for (const auto& field : col) {
+ total_bytes += static_cast<uint32_t>(field.size());
+ }
+ }
+ ASSERT_EQ(total_bytes, parser.num_bytes());
+}
+
+TEST(BlockParser, Basics) {
+ {
+ auto csv = MakeCSVData({"ab,cd,\n", "ef,,gh\n", ",ij,kl\n"});
+ BlockParser parser(ParseOptions::Defaults());
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"ab", "ef", ""}, {"cd", "", "ij"}, {"", "gh", "kl"}});
+ AssertLastRowEq(parser, {"", "ij", "kl"}, {false, false, false});
+ }
+ {
+ auto csv1 = MakeCSVData({"ab,cd,\n", "ef,,gh\n"});
+ auto csv2 = MakeCSVData({",ij,kl\n"});
+ std::vector<util::string_view> csvs = {csv1, csv2};
+ BlockParser parser(ParseOptions::Defaults());
+ AssertParseOk(parser, {{csv1}, {csv2}});
+ AssertColumnsEq(parser, {{"ab", "ef", ""}, {"cd", "", "ij"}, {"", "gh", "kl"}});
+ AssertLastRowEq(parser, {"", "ij", "kl"}, {false, false, false});
+ }
+}
+
+TEST(BlockParser, EmptyHeader) {
+ // Cannot infer number of columns
+ uint32_t out_size;
+ {
+ auto csv = MakeCSVData({""});
+ BlockParser parser(ParseOptions::Defaults());
+ ASSERT_RAISES(Invalid, ParseFinal(parser, csv, &out_size));
+ }
+ {
+ auto csv = MakeCSVData({"\n"});
+ BlockParser parser(ParseOptions::Defaults());
+ ASSERT_RAISES(Invalid, ParseFinal(parser, csv, &out_size));
+ }
+}
+
+TEST(BlockParser, Empty) {
+ {
+ auto csv = MakeCSVData({",\n"});
+ BlockParser parser(ParseOptions::Defaults());
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{""}, {""}});
+ AssertLastRowEq(parser, {"", ""}, {false, false});
+ }
+ {
+ auto csv = MakeCSVData({",\n,\n"});
+ BlockParser parser(ParseOptions::Defaults());
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"", ""}, {"", ""}});
+ AssertLastRowEq(parser, {"", ""}, {false, false});
+ }
+}
+
+TEST(BlockParser, Whitespace) {
+ // Non-newline whitespace is preserved
+ auto csv = MakeCSVData({"a b, cd, \n", " ef, \t,gh\n"});
+ BlockParser parser(ParseOptions::Defaults());
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a b", " ef"}, {" cd", " \t"}, {" ", "gh"}});
+}
+
+TEST(BlockParser, Newlines) {
+ auto csv = MakeCSVData({"a,b\n", "c,d\r\n", "e,f\r", "g,h\r"});
+ BlockParser parser(ParseOptions::Defaults());
+
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a", "c", "e", "g"}, {"b", "d", "f", "h"}});
+}
+
+TEST(BlockParser, MaxNumRows) {
+ auto csv = MakeCSVData({"a\n", "b\n", "c\n", "d\n"});
+ BlockParser parser(ParseOptions::Defaults(), -1, 0, 3 /* max_num_rows */);
+
+ AssertParsePartial(parser, csv, 6);
+ AssertColumnsEq(parser, {{"a", "b", "c"}});
+
+ AssertParseOk(parser, csv.substr(6));
+ AssertColumnsEq(parser, {{"d"}});
+
+ AssertParseOk(parser, csv.substr(8));
+ AssertColumnsEq(parser, {{}});
+}
+
+TEST(BlockParser, EmptyLinesWithOneColumn) {
+ auto csv = MakeCSVData({"a\n", "\n", "b\r", "\r", "c\r\n", "\r\n", "d\n"});
+ {
+ BlockParser parser(ParseOptions::Defaults());
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a", "b", "c", "d"}});
+ }
+ {
+ auto options = ParseOptions::Defaults();
+ options.ignore_empty_lines = false;
+ BlockParser parser(options);
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a", "", "b", "", "c", "", "d"}});
+ }
+}
+
+TEST(BlockParser, EmptyLinesWithSeveralColumns) {
+ auto csv = MakeCSVData({"a,b\n", "\n", "c,d\r", "\r", "e,f\r\n", "\r\n", "g,h\n"});
+ {
+ BlockParser parser(ParseOptions::Defaults());
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a", "c", "e", "g"}, {"b", "d", "f", "h"}});
+ }
+ {
+ // Non-ignored empty lines get turned into empty values
+ auto options = ParseOptions::Defaults();
+ options.ignore_empty_lines = false;
+ BlockParser parser(options);
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser,
+ {{"a", "", "c", "", "e", "", "g"}, {"b", "", "d", "", "f", "", "h"}});
+ }
+}
+
+TEST(BlockParser, EmptyLineFirst) {
+ auto csv = MakeCSVData({"\n", "\n", "a\n", "b\n"});
+ {
+ BlockParser parser(ParseOptions::Defaults());
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a", "b"}});
+ }
+ {
+ auto options = ParseOptions::Defaults();
+ options.ignore_empty_lines = false;
+ BlockParser parser(options);
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"", "", "a", "b"}});
+ }
+}
+
+TEST(BlockParser, TruncatedData) {
+ BlockParser parser(ParseOptions::Defaults());
+ auto csv = MakeCSVData({"a,b\n", "c,d\n"});
+ for (auto trim : {1, 2, 3}) {
+ AssertParsePartial(parser, csv.substr(0, csv.length() - trim), 4);
+ AssertColumnsEq(parser, {{"a"}, {"b"}});
+ }
+}
+
+TEST(BlockParser, Final) {
+ // Tests for ParseFinal()
+ BlockParser parser(ParseOptions::Defaults());
+ auto csv = MakeCSVData({"ab,cd\n", "ef,gh\n"});
+ AssertParseFinal(parser, csv);
+ AssertColumnsEq(parser, {{"ab", "ef"}, {"cd", "gh"}});
+
+ // Same without newline
+ csv = MakeCSVData({"ab,cd\n", "ef,gh"});
+ AssertParseFinal(parser, csv);
+ AssertColumnsEq(parser, {{"ab", "ef"}, {"cd", "gh"}});
+
+ // Same with empty last item
+ csv = MakeCSVData({"ab,cd\n", "ef,"});
+ AssertParseFinal(parser, csv);
+ AssertColumnsEq(parser, {{"ab", "ef"}, {"cd", ""}});
+
+ // Same with single line
+ csv = MakeCSVData({"ab,cd"});
+ AssertParseFinal(parser, csv);
+ AssertColumnsEq(parser, {{"ab"}, {"cd"}});
+
+ // Two blocks
+ auto csv1 = MakeCSVData({"ab,cd\n"});
+ auto csv2 = MakeCSVData({"ef,"});
+ AssertParseFinal(parser, {{csv1}, {csv2}});
+ AssertColumnsEq(parser, {{"ab", "ef"}, {"cd", ""}});
+}
+
+TEST(BlockParser, FinalTruncatedData) {
+ // Test ParseFinal() with truncated data
+ uint32_t out_size;
+ BlockParser parser(ParseOptions::Defaults());
+ auto csv = MakeCSVData({"ab,cd\n", "ef"});
+ Status st = ParseFinal(parser, csv, &out_size);
+ ASSERT_RAISES(Invalid, st);
+}
+
+TEST(BlockParser, QuotingSimple) {
+ auto csv = MakeCSVData({"1,\",3,\",5\n"});
+
+ {
+ BlockParser parser(ParseOptions::Defaults());
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"1"}, {",3,"}, {"5"}},
+ {{false}, {true}, {false}} /* quoted */);
+ }
+ {
+ auto options = ParseOptions::Defaults();
+ options.quoting = false;
+ BlockParser parser(options);
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"1"}, {"\""}, {"3"}, {"\""}, {"5"}},
+ {{false}, {false}, {false}, {false}, {false}} /* quoted */);
+ }
+ {
+ auto options = ParseOptions::Defaults();
+ options.quote_char = 'Z';
+ BlockParser parser(options);
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"1"}, {"\""}, {"3"}, {"\""}, {"5"}},
+ {{false}, {false}, {false}, {false}, {false}} /* quoted */);
+ }
+}
+
+TEST(BlockParser, QuotingNewline) {
+ auto csv = MakeCSVData({"a,\"c \n d\",e\n"});
+ BlockParser parser(ParseOptions::Defaults());
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a"}, {"c \n d"}, {"e"}},
+ {{false}, {true}, {false}} /* quoted */);
+}
+
+TEST(BlockParser, QuotingUnbalanced) {
+ // Quote introduces a quoted field that doesn't end
+ auto csv = MakeCSVData({"a,b\n", "1,\",3,,5\n"});
+ BlockParser parser(ParseOptions::Defaults());
+ AssertParsePartial(parser, csv, 4);
+ AssertColumnsEq(parser, {{"a"}, {"b"}}, {{false}, {false}} /* quoted */);
+}
+
+TEST(BlockParser, QuotingEmpty) {
+ {
+ BlockParser parser(ParseOptions::Defaults());
+ auto csv = MakeCSVData({"\"\"\n"});
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{""}}, {{true}} /* quoted */);
+ AssertLastRowEq(parser, {""}, {true});
+ }
+ {
+ BlockParser parser(ParseOptions::Defaults());
+ auto csv = MakeCSVData({",\"\"\n"});
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{""}, {""}}, {{false}, {true}} /* quoted */);
+ AssertLastRowEq(parser, {"", ""}, {false, true});
+ }
+ {
+ BlockParser parser(ParseOptions::Defaults());
+ auto csv = MakeCSVData({"\"\",\n"});
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{""}, {""}}, {{true}, {false}} /* quoted */);
+ AssertLastRowEq(parser, {"", ""}, {true, false});
+ }
+}
+
+TEST(BlockParser, QuotingDouble) {
+ {
+ BlockParser parser(ParseOptions::Defaults());
+ // 4 quotes is a quoted quote
+ auto csv = MakeCSVData({"\"\"\"\"\n"});
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"\""}}, {{true}} /* quoted */);
+ }
+ {
+ BlockParser parser(ParseOptions::Defaults());
+ // 4 quotes is a quoted quote
+ auto csv = MakeCSVData({"a,\"\"\"\",b\n"});
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a"}, {"\""}, {"b"}},
+ {{false}, {true}, {false}} /* quoted */);
+ }
+ {
+ BlockParser parser(ParseOptions::Defaults());
+ // 6 quotes is two quoted quotes
+ auto csv = MakeCSVData({"\"\"\"\"\"\"\n"});
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"\"\""}}, {{true}} /* quoted */);
+ }
+ {
+ BlockParser parser(ParseOptions::Defaults());
+ // 6 quotes is two quoted quotes
+ auto csv = MakeCSVData({"a,\"\"\"\"\"\",b\n"});
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a"}, {"\"\""}, {"b"}},
+ {{false}, {true}, {false}} /* quoted */);
+ }
+}
+
+TEST(BlockParser, QuotesAndMore) {
+ // There may be trailing data after the quoted part of a field
+ {
+ BlockParser parser(ParseOptions::Defaults());
+ auto csv = MakeCSVData({"a,\"b\"c,d\n"});
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a"}, {"bc"}, {"d"}},
+ {{false}, {true}, {false}} /* quoted */);
+ }
+}
+
+TEST(BlockParser, QuotesSpecial) {
+ // Some non-trivial cases
+ {
+ BlockParser parser(ParseOptions::Defaults());
+ auto csv = MakeCSVData({"a,b\"c,d\n"});
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a"}, {"b\"c"}, {"d"}},
+ {{false}, {false}, {false}} /* quoted */);
+ }
+ {
+ BlockParser parser(ParseOptions::Defaults());
+ auto csv = MakeCSVData({"a,\"b\" \"c\",d\n"});
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a"}, {"b \"c\""}, {"d"}},
+ {{false}, {true}, {false}} /* quoted */);
+ }
+}
+
+TEST(BlockParser, MismatchingNumColumns) {
+ uint32_t out_size;
+ {
+ BlockParser parser(ParseOptions::Defaults(), -1, 0 /* first_row */);
+ auto csv = MakeCSVData({"a,b\nc\n"});
+ Status st = Parse(parser, csv, &out_size);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr("CSV parse error: Row #1: Expected 2 columns, got 1: c"), st);
+ }
+ {
+ BlockParser parser(ParseOptions::Defaults(), 2 /* num_cols */, 0 /* first_row */);
+ auto csv = MakeCSVData({"a\n"});
+ Status st = Parse(parser, csv, &out_size);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr("CSV parse error: Row #0: Expected 2 columns, got 1: a"), st);
+ }
+ {
+ BlockParser parser(ParseOptions::Defaults(), 2 /* num_cols */, 50 /* first_row */);
+ auto csv = MakeCSVData({"a,b,c\n"});
+ Status st = Parse(parser, csv, &out_size);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr("CSV parse error: Row #50: Expected 2 columns, got 3: a,b,c"),
+ st);
+ }
+ // No row number
+ {
+ BlockParser parser(ParseOptions::Defaults(), 2 /* num_cols */, -1);
+ auto csv = MakeCSVData({"a\n"});
+ Status st = Parse(parser, csv, &out_size);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, testing::HasSubstr("CSV parse error: Expected 2 columns, got 1: a"), st);
+ }
+}
+
+TEST(BlockParser, MismatchingNumColumnsHandler) {
+ struct CustomHandler {
+ operator InvalidRowHandler() {
+ return [this](const InvalidRow& row) {
+ // Copy the row to a string since the array behind the string_view can go away
+ rows.emplace_back(row, row.text.to_string());
+ return InvalidRowResult::Skip;
+ };
+ }
+
+ std::vector<std::pair<InvalidRow, std::string>> rows;
+ };
+
+ {
+ ParseOptions opts = ParseOptions::Defaults();
+ CustomHandler handler;
+ opts.invalid_row_handler = handler;
+ BlockParser parser(opts);
+ ASSERT_NO_FATAL_FAILURE(AssertParseOk(parser, "a,b\nc\nd,e\n"));
+ ASSERT_EQ(2, parser.num_rows());
+ ASSERT_EQ(3, parser.total_num_rows());
+ ASSERT_EQ(1, handler.rows.size());
+ ASSERT_EQ(2, handler.rows[0].first.expected_columns);
+ ASSERT_EQ(1, handler.rows[0].first.actual_columns);
+ ASSERT_EQ("c", handler.rows[0].second);
+ ASSERT_NO_FATAL_FAILURE(AssertLastRowEq(parser, {"d", "e"}, {false, false}));
+ }
+ {
+ ParseOptions opts = ParseOptions::Defaults();
+ CustomHandler handler;
+ opts.invalid_row_handler = handler;
+ BlockParser parser(opts, 2 /* num_cols */);
+ ASSERT_NO_FATAL_FAILURE(AssertParseOk(parser, "a\nb,c\n"));
+ ASSERT_EQ(1, parser.num_rows());
+ ASSERT_EQ(2, parser.total_num_rows());
+ ASSERT_EQ(1, handler.rows.size());
+ ASSERT_EQ(2, handler.rows[0].first.expected_columns);
+ ASSERT_EQ(1, handler.rows[0].first.actual_columns);
+ ASSERT_EQ("a", handler.rows[0].second);
+ ASSERT_NO_FATAL_FAILURE(AssertLastRowEq(parser, {"b", "c"}, {false, false}));
+ }
+ {
+ ParseOptions opts = ParseOptions::Defaults();
+ CustomHandler handler;
+ opts.invalid_row_handler = handler;
+ BlockParser parser(opts, 2 /* num_cols */);
+ ASSERT_NO_FATAL_FAILURE(AssertParseOk(parser, "a,b,c\nd,e\n"));
+ ASSERT_EQ(1, parser.num_rows());
+ ASSERT_EQ(2, parser.total_num_rows());
+ ASSERT_EQ(1, handler.rows.size());
+ ASSERT_EQ(2, handler.rows[0].first.expected_columns);
+ ASSERT_EQ(3, handler.rows[0].first.actual_columns);
+ ASSERT_EQ("a,b,c", handler.rows[0].second);
+ ASSERT_NO_FATAL_FAILURE(AssertLastRowEq(parser, {"d", "e"}, {false, false}));
+ }
+
+ // Skip multiple bad lines are skipped
+ {
+ ParseOptions opts = ParseOptions::Defaults();
+ CustomHandler handler;
+ opts.invalid_row_handler = handler;
+ BlockParser parser(opts, /*num_col=*/2, /*first_row=*/1);
+ ASSERT_NO_FATAL_FAILURE(AssertParseOk(parser, "a,b,c\nd,e\nf,g\nh\ni\nj,k\nl\n"));
+ ASSERT_EQ(3, parser.num_rows());
+ ASSERT_EQ(7, parser.total_num_rows());
+ ASSERT_EQ(4, handler.rows.size());
+
+ {
+ auto row = handler.rows[0];
+ ASSERT_EQ(2, row.first.expected_columns);
+ ASSERT_EQ(3, row.first.actual_columns);
+ ASSERT_EQ(1, row.first.number);
+ ASSERT_EQ("a,b,c", row.second);
+ }
+
+ {
+ auto row = handler.rows[1];
+ ASSERT_EQ(2, row.first.expected_columns);
+ ASSERT_EQ(1, row.first.actual_columns);
+ ASSERT_EQ(4, row.first.number);
+ ASSERT_EQ("h", row.second);
+ }
+
+ {
+ auto row = handler.rows[2];
+ ASSERT_EQ(2, row.first.expected_columns);
+ ASSERT_EQ(1, row.first.actual_columns);
+ ASSERT_EQ(5, row.first.number);
+ ASSERT_EQ("i", row.second);
+ }
+
+ {
+ auto row = handler.rows[3];
+ ASSERT_EQ(2, row.first.expected_columns);
+ ASSERT_EQ(1, row.first.actual_columns);
+ ASSERT_EQ(7, row.first.number);
+ ASSERT_EQ("l", row.second);
+ }
+
+ ASSERT_NO_FATAL_FAILURE(AssertLastRowEq(parser, {"j", "k"}, {false, false}));
+ }
+}
+
+TEST(BlockParser, Escaping) {
+ auto options = ParseOptions::Defaults();
+ options.escaping = true;
+
+ {
+ auto csv = MakeCSVData({"a\\b,c\n"});
+ {
+ BlockParser parser(ParseOptions::Defaults());
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a\\b"}, {"c"}});
+ }
+ {
+ BlockParser parser(options);
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"ab"}, {"c"}});
+ }
+ }
+ {
+ auto csv = MakeCSVData({"a\\,b,c\n"});
+ BlockParser parser(options);
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a,b"}, {"c"}});
+ }
+}
+
+// Generate test data with the given number of columns.
+std::string MakeLotsOfCsvColumns(int32_t num_columns) {
+ std::string values, header;
+ header.reserve(num_columns * 10);
+ values.reserve(num_columns * 10);
+ for (int x = 0; x < num_columns; x++) {
+ if (x != 0) {
+ header += ",";
+ values += ",";
+ }
+ header += "c" + std::to_string(x);
+ values += std::to_string(x);
+ }
+
+ header += "\n";
+ values += "\n";
+ return MakeCSVData({header, values});
+}
+
+TEST(BlockParser, LotsOfColumns) {
+ auto options = ParseOptions::Defaults();
+ BlockParser parser(options);
+ AssertParseOk(parser, MakeLotsOfCsvColumns(1024 * 100));
+}
+
+TEST(BlockParser, QuotedEscape) {
+ auto options = ParseOptions::Defaults();
+ options.escaping = true;
+
+ {
+ auto csv = MakeCSVData({"\"a\\,b\",c\n"});
+ BlockParser parser(options);
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a,b"}, {"c"}}, {{true}, {false}} /* quoted */);
+ }
+ {
+ auto csv = MakeCSVData({"\"a\\\"b\",c\n"});
+ BlockParser parser(options);
+ AssertParseOk(parser, csv);
+ AssertColumnsEq(parser, {{"a\"b"}, {"c"}}, {{true}, {false}} /* quoted */);
+ }
+}
+
+TEST(BlockParser, RowNumberAppendedToError) {
+ auto options = ParseOptions::Defaults();
+ auto csv = "a,b,c\nd,e,f\ng,h,i\n";
+ {
+ BlockParser parser(options, -1, 0);
+ ASSERT_NO_FATAL_FAILURE(AssertParseOk(parser, csv));
+ int row = 0;
+ auto status = parser.VisitColumn(
+ 0, [row](const uint8_t* data, uint32_t size, bool quoted) mutable -> Status {
+ return ++row == 2 ? Status::Invalid("Bad value") : Status::OK();
+ });
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("Row #1: Bad value"),
+ status);
+ }
+
+ {
+ BlockParser parser(options, -1, 100);
+ ASSERT_NO_FATAL_FAILURE(AssertParseOk(parser, csv));
+ int row = 0;
+ auto status = parser.VisitColumn(
+ 0, [row](const uint8_t* data, uint32_t size, bool quoted) mutable -> Status {
+ return ++row == 3 ? Status::Invalid("Bad value") : Status::OK();
+ });
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("Row #102: Bad value"),
+ status);
+ }
+
+ // No first row specified should not append row information
+ {
+ BlockParser parser(options, -1, -1);
+ ASSERT_NO_FATAL_FAILURE(AssertParseOk(parser, csv));
+ int row = 0;
+ auto status = parser.VisitColumn(
+ 0, [row](const uint8_t* data, uint32_t size, bool quoted) mutable -> Status {
+ return ++row == 3 ? Status::Invalid("Bad value") : Status::OK();
+ });
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::Not(testing::HasSubstr("Row")),
+ status);
+ }
+
+ // Error message is correct even with skipped parsed rows
+ {
+ ParseOptions opts = ParseOptions::Defaults();
+ opts.invalid_row_handler = [](const InvalidRow& row) {
+ return InvalidRowResult::Skip;
+ };
+ BlockParser parser(opts, /*num_cols=*/2, /*first_row=*/1);
+ ASSERT_NO_FATAL_FAILURE(AssertParseOk(parser, "a,b,c\nd,e\nf,g\nh\ni\nj,k\nl\n"));
+ int row = 0;
+ auto status = parser.VisitColumn(
+ 0, [row](const uint8_t* data, uint32_t size, bool quoted) mutable -> Status {
+ return ++row == 3 ? Status::Invalid("Bad value") : Status::OK();
+ });
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("Row #6: Bad value"),
+ status);
+ }
+}
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/reader.cc b/src/arrow/cpp/src/arrow/csv/reader.cc
new file mode 100644
index 000000000..546de7787
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/reader.cc
@@ -0,0 +1,1303 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/csv/reader.h"
+
+#include <cstdint>
+#include <cstring>
+#include <functional>
+#include <limits>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/csv/chunker.h"
+#include "arrow/csv/column_builder.h"
+#include "arrow/csv/column_decoder.h"
+#include "arrow/csv/options.h"
+#include "arrow/csv/parser.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/future.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/task_group.h"
+#include "arrow/util/thread_pool.h"
+#include "arrow/util/utf8.h"
+#include "arrow/util/vector.h"
+
+namespace arrow {
+namespace csv {
+
+using internal::Executor;
+
+namespace {
+
+struct ConversionSchema {
+ struct Column {
+ std::string name;
+ // Physical column index in CSV file
+ int32_t index;
+ // If true, make a column of nulls
+ bool is_missing;
+ // If set, convert the CSV column to this type
+ // If unset (and is_missing is false), infer the type from the CSV column
+ std::shared_ptr<DataType> type;
+ };
+
+ static Column NullColumn(std::string col_name, std::shared_ptr<DataType> type) {
+ return Column{std::move(col_name), -1, true, std::move(type)};
+ }
+
+ static Column TypedColumn(std::string col_name, int32_t col_index,
+ std::shared_ptr<DataType> type) {
+ return Column{std::move(col_name), col_index, false, std::move(type)};
+ }
+
+ static Column InferredColumn(std::string col_name, int32_t col_index) {
+ return Column{std::move(col_name), col_index, false, nullptr};
+ }
+
+ std::vector<Column> columns;
+};
+
+// An iterator of Buffers that makes sure there is no straddling CRLF sequence.
+class CSVBufferIterator {
+ public:
+ static Iterator<std::shared_ptr<Buffer>> Make(
+ Iterator<std::shared_ptr<Buffer>> buffer_iterator) {
+ Transformer<std::shared_ptr<Buffer>, std::shared_ptr<Buffer>> fn =
+ CSVBufferIterator();
+ return MakeTransformedIterator(std::move(buffer_iterator), fn);
+ }
+
+ static AsyncGenerator<std::shared_ptr<Buffer>> MakeAsync(
+ AsyncGenerator<std::shared_ptr<Buffer>> buffer_iterator) {
+ Transformer<std::shared_ptr<Buffer>, std::shared_ptr<Buffer>> fn =
+ CSVBufferIterator();
+ return MakeTransformedGenerator(std::move(buffer_iterator), fn);
+ }
+
+ Result<TransformFlow<std::shared_ptr<Buffer>>> operator()(std::shared_ptr<Buffer> buf) {
+ if (buf == nullptr) {
+ // EOF
+ return TransformFinish();
+ }
+
+ int64_t offset = 0;
+ if (first_buffer_) {
+ ARROW_ASSIGN_OR_RAISE(auto data, util::SkipUTF8BOM(buf->data(), buf->size()));
+ offset += data - buf->data();
+ DCHECK_GE(offset, 0);
+ first_buffer_ = false;
+ }
+
+ if (trailing_cr_ && buf->data()[offset] == '\n') {
+ // Skip '\r\n' line separator that started at the end of previous buffer
+ ++offset;
+ }
+
+ trailing_cr_ = (buf->data()[buf->size() - 1] == '\r');
+ buf = SliceBuffer(buf, offset);
+ if (buf->size() == 0) {
+ // EOF
+ return TransformFinish();
+ } else {
+ return TransformYield(buf);
+ }
+ }
+
+ protected:
+ bool first_buffer_ = true;
+ // Whether there was a trailing CR at the end of last received buffer
+ bool trailing_cr_ = false;
+};
+
+struct CSVBlock {
+ // (partial + completion + buffer) is an entire delimited CSV buffer.
+ std::shared_ptr<Buffer> partial;
+ std::shared_ptr<Buffer> completion;
+ std::shared_ptr<Buffer> buffer;
+ int64_t block_index;
+ bool is_final;
+ int64_t bytes_skipped;
+ std::function<Status(int64_t)> consume_bytes;
+};
+
+} // namespace
+} // namespace csv
+
+template <>
+struct IterationTraits<csv::CSVBlock> {
+ static csv::CSVBlock End() { return csv::CSVBlock{{}, {}, {}, -1, true, 0, {}}; }
+ static bool IsEnd(const csv::CSVBlock& val) { return val.block_index < 0; }
+};
+
+namespace csv {
+namespace {
+
+// This is a callable that can be used to transform an iterator. The source iterator
+// will contain buffers of data and the output iterator will contain delimited CSV
+// blocks. util::optional is used so that there is an end token (required by the
+// iterator APIs (e.g. Visit)) even though an empty optional is never used in this code.
+class BlockReader {
+ public:
+ BlockReader(std::unique_ptr<Chunker> chunker, std::shared_ptr<Buffer> first_buffer,
+ int64_t skip_rows)
+ : chunker_(std::move(chunker)),
+ partial_(std::make_shared<Buffer>("")),
+ buffer_(std::move(first_buffer)),
+ skip_rows_(skip_rows) {}
+
+ protected:
+ std::unique_ptr<Chunker> chunker_;
+ std::shared_ptr<Buffer> partial_, buffer_;
+ int64_t skip_rows_;
+ int64_t block_index_ = 0;
+ // Whether there was a trailing CR at the end of last received buffer
+ bool trailing_cr_ = false;
+};
+
+// An object that reads delimited CSV blocks for serial use.
+// The number of bytes consumed should be notified after each read,
+// using CSVBlock::consume_bytes.
+class SerialBlockReader : public BlockReader {
+ public:
+ using BlockReader::BlockReader;
+
+ static Iterator<CSVBlock> MakeIterator(
+ Iterator<std::shared_ptr<Buffer>> buffer_iterator, std::unique_ptr<Chunker> chunker,
+ std::shared_ptr<Buffer> first_buffer, int64_t skip_rows) {
+ auto block_reader =
+ std::make_shared<SerialBlockReader>(std::move(chunker), first_buffer, skip_rows);
+ // Wrap shared pointer in callable
+ Transformer<std::shared_ptr<Buffer>, CSVBlock> block_reader_fn =
+ [block_reader](std::shared_ptr<Buffer> buf) {
+ return (*block_reader)(std::move(buf));
+ };
+ return MakeTransformedIterator(std::move(buffer_iterator), block_reader_fn);
+ }
+
+ static AsyncGenerator<CSVBlock> MakeAsyncIterator(
+ AsyncGenerator<std::shared_ptr<Buffer>> buffer_generator,
+ std::unique_ptr<Chunker> chunker, std::shared_ptr<Buffer> first_buffer,
+ int64_t skip_rows) {
+ auto block_reader =
+ std::make_shared<SerialBlockReader>(std::move(chunker), first_buffer, skip_rows);
+ // Wrap shared pointer in callable
+ Transformer<std::shared_ptr<Buffer>, CSVBlock> block_reader_fn =
+ [block_reader](std::shared_ptr<Buffer> next) {
+ return (*block_reader)(std::move(next));
+ };
+ return MakeTransformedGenerator(std::move(buffer_generator), block_reader_fn);
+ }
+
+ Result<TransformFlow<CSVBlock>> operator()(std::shared_ptr<Buffer> next_buffer) {
+ if (buffer_ == nullptr) {
+ return TransformFinish();
+ }
+
+ bool is_final = (next_buffer == nullptr);
+ int64_t bytes_skipped = 0;
+
+ if (skip_rows_) {
+ bytes_skipped += partial_->size();
+ auto orig_size = buffer_->size();
+ RETURN_NOT_OK(
+ chunker_->ProcessSkip(partial_, buffer_, is_final, &skip_rows_, &buffer_));
+ bytes_skipped += orig_size - buffer_->size();
+ auto empty = std::make_shared<Buffer>(nullptr, 0);
+ if (skip_rows_) {
+ // Still have rows beyond this buffer to skip return empty block
+ partial_ = std::move(buffer_);
+ buffer_ = next_buffer;
+ return TransformYield<CSVBlock>(CSVBlock{empty, empty, empty, block_index_++,
+ is_final, bytes_skipped,
+ [](int64_t) { return Status::OK(); }});
+ }
+ partial_ = std::move(empty);
+ }
+
+ std::shared_ptr<Buffer> completion;
+
+ if (is_final) {
+ // End of file reached => compute completion from penultimate block
+ RETURN_NOT_OK(chunker_->ProcessFinal(partial_, buffer_, &completion, &buffer_));
+ } else {
+ // Get completion of partial from previous block.
+ RETURN_NOT_OK(
+ chunker_->ProcessWithPartial(partial_, buffer_, &completion, &buffer_));
+ }
+ int64_t bytes_before_buffer = partial_->size() + completion->size();
+
+ auto consume_bytes = [this, bytes_before_buffer,
+ next_buffer](int64_t nbytes) -> Status {
+ DCHECK_GE(nbytes, 0);
+ auto offset = nbytes - bytes_before_buffer;
+ if (offset < 0) {
+ // Should not happen
+ return Status::Invalid("CSV parser got out of sync with chunker");
+ }
+ partial_ = SliceBuffer(buffer_, offset);
+ buffer_ = next_buffer;
+ return Status::OK();
+ };
+
+ return TransformYield<CSVBlock>(CSVBlock{partial_, completion, buffer_,
+ block_index_++, is_final, bytes_skipped,
+ std::move(consume_bytes)});
+ }
+};
+
+// An object that reads delimited CSV blocks for threaded use.
+class ThreadedBlockReader : public BlockReader {
+ public:
+ using BlockReader::BlockReader;
+
+ static AsyncGenerator<CSVBlock> MakeAsyncIterator(
+ AsyncGenerator<std::shared_ptr<Buffer>> buffer_generator,
+ std::unique_ptr<Chunker> chunker, std::shared_ptr<Buffer> first_buffer,
+ int64_t skip_rows) {
+ auto block_reader = std::make_shared<ThreadedBlockReader>(std::move(chunker),
+ first_buffer, skip_rows);
+ // Wrap shared pointer in callable
+ Transformer<std::shared_ptr<Buffer>, CSVBlock> block_reader_fn =
+ [block_reader](std::shared_ptr<Buffer> next) { return (*block_reader)(next); };
+ return MakeTransformedGenerator(std::move(buffer_generator), block_reader_fn);
+ }
+
+ Result<TransformFlow<CSVBlock>> operator()(std::shared_ptr<Buffer> next_buffer) {
+ if (buffer_ == nullptr) {
+ // EOF
+ return TransformFinish();
+ }
+
+ bool is_final = (next_buffer == nullptr);
+
+ auto current_partial = std::move(partial_);
+ auto current_buffer = std::move(buffer_);
+ int64_t bytes_skipped = 0;
+
+ if (skip_rows_) {
+ auto orig_size = current_buffer->size();
+ bytes_skipped = current_partial->size();
+ RETURN_NOT_OK(chunker_->ProcessSkip(current_partial, current_buffer, is_final,
+ &skip_rows_, &current_buffer));
+ bytes_skipped += orig_size - current_buffer->size();
+ current_partial = std::make_shared<Buffer>(nullptr, 0);
+ if (skip_rows_) {
+ partial_ = std::move(current_buffer);
+ buffer_ = std::move(next_buffer);
+ return TransformYield<CSVBlock>(CSVBlock{current_partial,
+ current_partial,
+ current_partial,
+ block_index_++,
+ is_final,
+ bytes_skipped,
+ {}});
+ }
+ }
+
+ std::shared_ptr<Buffer> whole, completion, next_partial;
+
+ if (is_final) {
+ // End of file reached => compute completion from penultimate block
+ RETURN_NOT_OK(
+ chunker_->ProcessFinal(current_partial, current_buffer, &completion, &whole));
+ } else {
+ // Get completion of partial from previous block.
+ std::shared_ptr<Buffer> starts_with_whole;
+ // Get completion of partial from previous block.
+ RETURN_NOT_OK(chunker_->ProcessWithPartial(current_partial, current_buffer,
+ &completion, &starts_with_whole));
+
+ // Get a complete CSV block inside `partial + block`, and keep
+ // the rest for the next iteration.
+ RETURN_NOT_OK(chunker_->Process(starts_with_whole, &whole, &next_partial));
+ }
+
+ partial_ = std::move(next_partial);
+ buffer_ = std::move(next_buffer);
+
+ return TransformYield<CSVBlock>(CSVBlock{
+ current_partial, completion, whole, block_index_++, is_final, bytes_skipped, {}});
+ }
+};
+
+struct ParsedBlock {
+ std::shared_ptr<BlockParser> parser;
+ int64_t block_index;
+ int64_t bytes_parsed_or_skipped;
+};
+
+struct DecodedBlock {
+ std::shared_ptr<RecordBatch> record_batch;
+ // Represents the number of input bytes represented by this batch
+ // This will include bytes skipped when skipping rows after the header
+ int64_t bytes_processed;
+};
+
+} // namespace
+
+} // namespace csv
+
+template <>
+struct IterationTraits<csv::ParsedBlock> {
+ static csv::ParsedBlock End() { return csv::ParsedBlock{nullptr, -1, -1}; }
+ static bool IsEnd(const csv::ParsedBlock& val) { return val.block_index < 0; }
+};
+
+template <>
+struct IterationTraits<csv::DecodedBlock> {
+ static csv::DecodedBlock End() { return csv::DecodedBlock{nullptr, -1}; }
+ static bool IsEnd(const csv::DecodedBlock& val) { return val.bytes_processed < 0; }
+};
+
+namespace csv {
+namespace {
+
+// A function object that takes in a buffer of CSV data and returns a parsed batch of CSV
+// data (CSVBlock -> ParsedBlock) for use with MakeMappedGenerator.
+// The parsed batch contains a list of offsets for each of the columns so that columns
+// can be individually scanned
+//
+// This operator is not re-entrant
+class BlockParsingOperator {
+ public:
+ BlockParsingOperator(io::IOContext io_context, ParseOptions parse_options,
+ int num_csv_cols, int64_t first_row)
+ : io_context_(io_context),
+ parse_options_(parse_options),
+ num_csv_cols_(num_csv_cols),
+ count_rows_(first_row >= 0),
+ num_rows_seen_(first_row) {}
+
+ Result<ParsedBlock> operator()(const CSVBlock& block) {
+ constexpr int32_t max_num_rows = std::numeric_limits<int32_t>::max();
+ auto parser = std::make_shared<BlockParser>(
+ io_context_.pool(), parse_options_, num_csv_cols_, num_rows_seen_, max_num_rows);
+
+ std::shared_ptr<Buffer> straddling;
+ std::vector<util::string_view> views;
+ if (block.partial->size() != 0 || block.completion->size() != 0) {
+ if (block.partial->size() == 0) {
+ straddling = block.completion;
+ } else if (block.completion->size() == 0) {
+ straddling = block.partial;
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ straddling,
+ ConcatenateBuffers({block.partial, block.completion}, io_context_.pool()));
+ }
+ views = {util::string_view(*straddling), util::string_view(*block.buffer)};
+ } else {
+ views = {util::string_view(*block.buffer)};
+ }
+ uint32_t parsed_size;
+ if (block.is_final) {
+ RETURN_NOT_OK(parser->ParseFinal(views, &parsed_size));
+ } else {
+ RETURN_NOT_OK(parser->Parse(views, &parsed_size));
+ }
+ if (count_rows_) {
+ num_rows_seen_ += parser->total_num_rows();
+ }
+ RETURN_NOT_OK(block.consume_bytes(parsed_size));
+ return ParsedBlock{std::move(parser), block.block_index,
+ static_cast<int64_t>(parsed_size) + block.bytes_skipped};
+ }
+
+ private:
+ io::IOContext io_context_;
+ ParseOptions parse_options_;
+ int num_csv_cols_;
+ bool count_rows_;
+ int64_t num_rows_seen_;
+};
+
+// A function object that takes in parsed batch of CSV data and decodes it to an arrow
+// record batch (ParsedBlock -> DecodedBlock) for use with MakeMappedGenerator.
+class BlockDecodingOperator {
+ public:
+ Future<DecodedBlock> operator()(const ParsedBlock& block) {
+ DCHECK(!state_->column_decoders.empty());
+ std::vector<Future<std::shared_ptr<Array>>> decoded_array_futs;
+ for (auto& decoder : state_->column_decoders) {
+ decoded_array_futs.push_back(decoder->Decode(block.parser));
+ }
+ auto bytes_parsed_or_skipped = block.bytes_parsed_or_skipped;
+ auto decoded_arrays_fut = All(std::move(decoded_array_futs));
+ auto state = state_;
+ return decoded_arrays_fut.Then(
+ [state, bytes_parsed_or_skipped](
+ const std::vector<Result<std::shared_ptr<Array>>>& maybe_decoded_arrays)
+ -> Result<DecodedBlock> {
+ ARROW_ASSIGN_OR_RAISE(auto decoded_arrays,
+ internal::UnwrapOrRaise(maybe_decoded_arrays));
+
+ ARROW_ASSIGN_OR_RAISE(auto batch,
+ state->DecodedArraysToBatch(std::move(decoded_arrays)));
+ return DecodedBlock{std::move(batch), bytes_parsed_or_skipped};
+ });
+ }
+
+ static Result<BlockDecodingOperator> Make(io::IOContext io_context,
+ ConvertOptions convert_options,
+ ConversionSchema conversion_schema) {
+ BlockDecodingOperator op(std::move(io_context), std::move(convert_options),
+ std::move(conversion_schema));
+ RETURN_NOT_OK(op.state_->MakeColumnDecoders(io_context));
+ return op;
+ }
+
+ private:
+ BlockDecodingOperator(io::IOContext io_context, ConvertOptions convert_options,
+ ConversionSchema conversion_schema)
+ : state_(std::make_shared<State>(std::move(io_context), std::move(convert_options),
+ std::move(conversion_schema))) {}
+
+ struct State {
+ State(io::IOContext io_context, ConvertOptions convert_options,
+ ConversionSchema conversion_schema)
+ : convert_options(std::move(convert_options)),
+ conversion_schema(std::move(conversion_schema)) {}
+
+ Result<std::shared_ptr<RecordBatch>> DecodedArraysToBatch(
+ std::vector<std::shared_ptr<Array>> arrays) {
+ const auto n_rows = arrays[0]->length();
+
+ if (schema == nullptr) {
+ FieldVector fields(arrays.size());
+ for (size_t i = 0; i < arrays.size(); ++i) {
+ fields[i] = field(conversion_schema.columns[i].name, arrays[i]->type());
+ }
+
+ if (n_rows == 0) {
+ // No rows so schema is not reliable. return RecordBatch but do not set schema
+ return RecordBatch::Make(arrow::schema(std::move(fields)), n_rows,
+ std::move(arrays));
+ }
+
+ schema = arrow::schema(std::move(fields));
+ }
+
+ return RecordBatch::Make(schema, n_rows, std::move(arrays));
+ }
+
+ // Make column decoders from conversion schema
+ Status MakeColumnDecoders(io::IOContext io_context) {
+ for (const auto& column : conversion_schema.columns) {
+ std::shared_ptr<ColumnDecoder> decoder;
+ if (column.is_missing) {
+ ARROW_ASSIGN_OR_RAISE(decoder,
+ ColumnDecoder::MakeNull(io_context.pool(), column.type));
+ } else if (column.type != nullptr) {
+ ARROW_ASSIGN_OR_RAISE(
+ decoder, ColumnDecoder::Make(io_context.pool(), column.type, column.index,
+ convert_options));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ decoder,
+ ColumnDecoder::Make(io_context.pool(), column.index, convert_options));
+ }
+ column_decoders.push_back(std::move(decoder));
+ }
+ return Status::OK();
+ }
+
+ ConvertOptions convert_options;
+ ConversionSchema conversion_schema;
+ std::vector<std::shared_ptr<ColumnDecoder>> column_decoders;
+ std::shared_ptr<Schema> schema;
+ };
+
+ std::shared_ptr<State> state_;
+};
+
+/////////////////////////////////////////////////////////////////////////
+// Base class for common functionality
+
+class ReaderMixin {
+ public:
+ ReaderMixin(io::IOContext io_context, std::shared_ptr<io::InputStream> input,
+ const ReadOptions& read_options, const ParseOptions& parse_options,
+ const ConvertOptions& convert_options, bool count_rows)
+ : io_context_(std::move(io_context)),
+ read_options_(read_options),
+ parse_options_(parse_options),
+ convert_options_(convert_options),
+ count_rows_(count_rows),
+ num_rows_seen_(count_rows_ ? 1 : -1),
+ input_(std::move(input)) {}
+
+ protected:
+ // Read header and column names from buffer, create column builders
+ // Returns the # of bytes consumed
+ Result<int64_t> ProcessHeader(const std::shared_ptr<Buffer>& buf,
+ std::shared_ptr<Buffer>* rest) {
+ const uint8_t* data = buf->data();
+ const auto data_end = data + buf->size();
+ DCHECK_GT(data_end - data, 0);
+
+ if (read_options_.skip_rows) {
+ // Skip initial rows (potentially invalid CSV data)
+ auto num_skipped_rows = SkipRows(data, static_cast<uint32_t>(data_end - data),
+ read_options_.skip_rows, &data);
+ if (num_skipped_rows < read_options_.skip_rows) {
+ return Status::Invalid(
+ "Could not skip initial ", read_options_.skip_rows,
+ " rows from CSV file, "
+ "either file is too short or header is larger than block size");
+ }
+ if (count_rows_) {
+ num_rows_seen_ += num_skipped_rows;
+ }
+ }
+
+ if (read_options_.column_names.empty()) {
+ // Parse one row (either to read column names or to know the number of columns)
+ BlockParser parser(io_context_.pool(), parse_options_, num_csv_cols_,
+ num_rows_seen_, 1);
+ uint32_t parsed_size = 0;
+ RETURN_NOT_OK(parser.Parse(
+ util::string_view(reinterpret_cast<const char*>(data), data_end - data),
+ &parsed_size));
+ if (parser.num_rows() != 1) {
+ return Status::Invalid(
+ "Could not read first row from CSV file, either "
+ "file is too short or header is larger than block size");
+ }
+ if (parser.num_cols() == 0) {
+ return Status::Invalid("No columns in CSV file");
+ }
+
+ if (read_options_.autogenerate_column_names) {
+ column_names_ = GenerateColumnNames(parser.num_cols());
+ } else {
+ // Read column names from header row
+ auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) -> Status {
+ column_names_.emplace_back(reinterpret_cast<const char*>(data), size);
+ return Status::OK();
+ };
+ RETURN_NOT_OK(parser.VisitLastRow(visit));
+ DCHECK_EQ(static_cast<size_t>(parser.num_cols()), column_names_.size());
+ // Skip parsed header row
+ data += parsed_size;
+ if (count_rows_) {
+ ++num_rows_seen_;
+ }
+ }
+ } else {
+ column_names_ = read_options_.column_names;
+ }
+
+ if (count_rows_) {
+ // increase rows seen to skip past rows which will be skipped
+ num_rows_seen_ += read_options_.skip_rows_after_names;
+ }
+
+ auto bytes_consumed = data - buf->data();
+ *rest = SliceBuffer(buf, bytes_consumed);
+
+ num_csv_cols_ = static_cast<int32_t>(column_names_.size());
+ DCHECK_GT(num_csv_cols_, 0);
+
+ RETURN_NOT_OK(MakeConversionSchema());
+ return bytes_consumed;
+ }
+
+ std::vector<std::string> GenerateColumnNames(int32_t num_cols) {
+ std::vector<std::string> res;
+ res.reserve(num_cols);
+ for (int32_t i = 0; i < num_cols; ++i) {
+ std::stringstream ss;
+ ss << "f" << i;
+ res.push_back(ss.str());
+ }
+ return res;
+ }
+
+ // Make conversion schema from options and parsed CSV header
+ Status MakeConversionSchema() {
+ // Append a column converted from CSV data
+ auto append_csv_column = [&](std::string col_name, int32_t col_index) {
+ // Does the named column have a fixed type?
+ auto it = convert_options_.column_types.find(col_name);
+ if (it == convert_options_.column_types.end()) {
+ conversion_schema_.columns.push_back(
+ ConversionSchema::InferredColumn(std::move(col_name), col_index));
+ } else {
+ conversion_schema_.columns.push_back(
+ ConversionSchema::TypedColumn(std::move(col_name), col_index, it->second));
+ }
+ };
+
+ // Append a column of nulls
+ auto append_null_column = [&](std::string col_name) {
+ // If the named column has a fixed type, use it, otherwise use null()
+ std::shared_ptr<DataType> type;
+ auto it = convert_options_.column_types.find(col_name);
+ if (it == convert_options_.column_types.end()) {
+ type = null();
+ } else {
+ type = it->second;
+ }
+ conversion_schema_.columns.push_back(
+ ConversionSchema::NullColumn(std::move(col_name), std::move(type)));
+ };
+
+ if (convert_options_.include_columns.empty()) {
+ // Include all columns in CSV file order
+ for (int32_t col_index = 0; col_index < num_csv_cols_; ++col_index) {
+ append_csv_column(column_names_[col_index], col_index);
+ }
+ } else {
+ // Include columns from `include_columns` (in that order)
+ // Compute indices of columns in the CSV file
+ std::unordered_map<std::string, int32_t> col_indices;
+ col_indices.reserve(column_names_.size());
+ for (int32_t i = 0; i < static_cast<int32_t>(column_names_.size()); ++i) {
+ col_indices.emplace(column_names_[i], i);
+ }
+
+ for (const auto& col_name : convert_options_.include_columns) {
+ auto it = col_indices.find(col_name);
+ if (it != col_indices.end()) {
+ append_csv_column(col_name, it->second);
+ } else if (convert_options_.include_missing_columns) {
+ append_null_column(col_name);
+ } else {
+ return Status::KeyError("Column '", col_name,
+ "' in include_columns "
+ "does not exist in CSV file");
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ struct ParseResult {
+ std::shared_ptr<BlockParser> parser;
+ int64_t parsed_bytes;
+ };
+
+ Result<ParseResult> Parse(const std::shared_ptr<Buffer>& partial,
+ const std::shared_ptr<Buffer>& completion,
+ const std::shared_ptr<Buffer>& block, int64_t block_index,
+ bool is_final) {
+ static constexpr int32_t max_num_rows = std::numeric_limits<int32_t>::max();
+ auto parser = std::make_shared<BlockParser>(
+ io_context_.pool(), parse_options_, num_csv_cols_, num_rows_seen_, max_num_rows);
+
+ std::shared_ptr<Buffer> straddling;
+ std::vector<util::string_view> views;
+ if (partial->size() != 0 || completion->size() != 0) {
+ if (partial->size() == 0) {
+ straddling = completion;
+ } else if (completion->size() == 0) {
+ straddling = partial;
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ straddling, ConcatenateBuffers({partial, completion}, io_context_.pool()));
+ }
+ views = {util::string_view(*straddling), util::string_view(*block)};
+ } else {
+ views = {util::string_view(*block)};
+ }
+ uint32_t parsed_size;
+ if (is_final) {
+ RETURN_NOT_OK(parser->ParseFinal(views, &parsed_size));
+ } else {
+ RETURN_NOT_OK(parser->Parse(views, &parsed_size));
+ }
+ if (count_rows_) {
+ num_rows_seen_ += parser->total_num_rows();
+ }
+ return ParseResult{std::move(parser), static_cast<int64_t>(parsed_size)};
+ }
+
+ io::IOContext io_context_;
+ ReadOptions read_options_;
+ ParseOptions parse_options_;
+ ConvertOptions convert_options_;
+
+ // Number of columns in the CSV file
+ int32_t num_csv_cols_ = -1;
+ // Whether num_rows_seen_ tracks the number of rows seen in the CSV being parsed
+ bool count_rows_;
+ // Number of rows seen in the csv. Not used if count_rows is false
+ int64_t num_rows_seen_;
+ // Column names in the CSV file
+ std::vector<std::string> column_names_;
+ ConversionSchema conversion_schema_;
+
+ std::shared_ptr<io::InputStream> input_;
+ std::shared_ptr<internal::TaskGroup> task_group_;
+};
+
+/////////////////////////////////////////////////////////////////////////
+// Base class for one-shot table readers
+
+class BaseTableReader : public ReaderMixin, public csv::TableReader {
+ public:
+ using ReaderMixin::ReaderMixin;
+
+ virtual Status Init() = 0;
+
+ Future<std::shared_ptr<Table>> ReadAsync() override {
+ return Future<std::shared_ptr<Table>>::MakeFinished(Read());
+ }
+
+ protected:
+ // Make column builders from conversion schema
+ Status MakeColumnBuilders() {
+ for (const auto& column : conversion_schema_.columns) {
+ std::shared_ptr<ColumnBuilder> builder;
+ if (column.is_missing) {
+ ARROW_ASSIGN_OR_RAISE(builder, ColumnBuilder::MakeNull(io_context_.pool(),
+ column.type, task_group_));
+ } else if (column.type != nullptr) {
+ ARROW_ASSIGN_OR_RAISE(
+ builder, ColumnBuilder::Make(io_context_.pool(), column.type, column.index,
+ convert_options_, task_group_));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(builder,
+ ColumnBuilder::Make(io_context_.pool(), column.index,
+ convert_options_, task_group_));
+ }
+ column_builders_.push_back(std::move(builder));
+ }
+ return Status::OK();
+ }
+
+ Result<int64_t> ParseAndInsert(const std::shared_ptr<Buffer>& partial,
+ const std::shared_ptr<Buffer>& completion,
+ const std::shared_ptr<Buffer>& block,
+ int64_t block_index, bool is_final) {
+ ARROW_ASSIGN_OR_RAISE(auto result,
+ Parse(partial, completion, block, block_index, is_final));
+ RETURN_NOT_OK(ProcessData(result.parser, block_index));
+ return result.parsed_bytes;
+ }
+
+ // Trigger conversion of parsed block data
+ Status ProcessData(const std::shared_ptr<BlockParser>& parser, int64_t block_index) {
+ for (auto& builder : column_builders_) {
+ builder->Insert(block_index, parser);
+ }
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<Table>> MakeTable() {
+ DCHECK_EQ(column_builders_.size(), conversion_schema_.columns.size());
+
+ std::vector<std::shared_ptr<Field>> fields;
+ std::vector<std::shared_ptr<ChunkedArray>> columns;
+
+ for (int32_t i = 0; i < static_cast<int32_t>(column_builders_.size()); ++i) {
+ const auto& column = conversion_schema_.columns[i];
+ ARROW_ASSIGN_OR_RAISE(auto array, column_builders_[i]->Finish());
+ fields.push_back(::arrow::field(column.name, array->type()));
+ columns.emplace_back(std::move(array));
+ }
+ return Table::Make(schema(std::move(fields)), std::move(columns));
+ }
+
+ // Column builders for target Table (in ConversionSchema order)
+ std::vector<std::shared_ptr<ColumnBuilder>> column_builders_;
+};
+
+/////////////////////////////////////////////////////////////////////////
+// Base class for streaming readers
+
+class StreamingReaderImpl : public ReaderMixin,
+ public csv::StreamingReader,
+ public std::enable_shared_from_this<StreamingReaderImpl> {
+ public:
+ StreamingReaderImpl(io::IOContext io_context, std::shared_ptr<io::InputStream> input,
+ const ReadOptions& read_options, const ParseOptions& parse_options,
+ const ConvertOptions& convert_options, bool count_rows)
+ : ReaderMixin(io_context, std::move(input), read_options, parse_options,
+ convert_options, count_rows),
+ bytes_decoded_(std::make_shared<std::atomic<int64_t>>(0)) {}
+
+ Future<> Init(Executor* cpu_executor) {
+ ARROW_ASSIGN_OR_RAISE(auto istream_it,
+ io::MakeInputStreamIterator(input_, read_options_.block_size));
+
+ // TODO Consider exposing readahead as a read option (ARROW-12090)
+ ARROW_ASSIGN_OR_RAISE(auto bg_it, MakeBackgroundGenerator(std::move(istream_it),
+ io_context_.executor()));
+
+ auto transferred_it = MakeTransferredGenerator(bg_it, cpu_executor);
+
+ auto buffer_generator = CSVBufferIterator::MakeAsync(std::move(transferred_it));
+
+ int max_readahead = cpu_executor->GetCapacity();
+ auto self = shared_from_this();
+
+ return buffer_generator().Then([self, buffer_generator, max_readahead](
+ const std::shared_ptr<Buffer>& first_buffer) {
+ return self->InitAfterFirstBuffer(first_buffer, buffer_generator, max_readahead);
+ });
+ }
+
+ std::shared_ptr<Schema> schema() const override { return schema_; }
+
+ int64_t bytes_read() const override { return bytes_decoded_->load(); }
+
+ Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
+ auto next_fut = ReadNextAsync();
+ auto next_result = next_fut.result();
+ return std::move(next_result).Value(batch);
+ }
+
+ Future<std::shared_ptr<RecordBatch>> ReadNextAsync() override {
+ return record_batch_gen_();
+ }
+
+ protected:
+ Future<> InitAfterFirstBuffer(const std::shared_ptr<Buffer>& first_buffer,
+ AsyncGenerator<std::shared_ptr<Buffer>> buffer_generator,
+ int max_readahead) {
+ if (first_buffer == nullptr) {
+ return Status::Invalid("Empty CSV file");
+ }
+
+ std::shared_ptr<Buffer> after_header;
+ ARROW_ASSIGN_OR_RAISE(auto header_bytes_consumed,
+ ProcessHeader(first_buffer, &after_header));
+ bytes_decoded_->fetch_add(header_bytes_consumed);
+
+ auto parser_op =
+ BlockParsingOperator(io_context_, parse_options_, num_csv_cols_, num_rows_seen_);
+ ARROW_ASSIGN_OR_RAISE(
+ auto decoder_op,
+ BlockDecodingOperator::Make(io_context_, convert_options_, conversion_schema_));
+
+ auto block_gen = SerialBlockReader::MakeAsyncIterator(
+ std::move(buffer_generator), MakeChunker(parse_options_), std::move(after_header),
+ read_options_.skip_rows_after_names);
+ auto parsed_block_gen =
+ MakeMappedGenerator(std::move(block_gen), std::move(parser_op));
+ auto rb_gen = MakeMappedGenerator(std::move(parsed_block_gen), std::move(decoder_op));
+
+ auto self = shared_from_this();
+ return rb_gen().Then([self, rb_gen, max_readahead](const DecodedBlock& first_block) {
+ return self->InitFromBlock(first_block, std::move(rb_gen), max_readahead, 0);
+ });
+ }
+
+ Future<> InitFromBlock(const DecodedBlock& block,
+ AsyncGenerator<DecodedBlock> batch_gen, int max_readahead,
+ int64_t prev_bytes_processed) {
+ if (!block.record_batch) {
+ // End of file just return null batches
+ record_batch_gen_ = MakeEmptyGenerator<std::shared_ptr<RecordBatch>>();
+ return Status::OK();
+ }
+
+ schema_ = block.record_batch->schema();
+
+ if (block.record_batch->num_rows() == 0) {
+ // Keep consuming blocks until the first non empty block is found
+ auto self = shared_from_this();
+ prev_bytes_processed += block.bytes_processed;
+ return batch_gen().Then([self, batch_gen, max_readahead,
+ prev_bytes_processed](const DecodedBlock& next_block) {
+ return self->InitFromBlock(next_block, std::move(batch_gen), max_readahead,
+ prev_bytes_processed);
+ });
+ }
+
+ AsyncGenerator<DecodedBlock> readahead_gen;
+ if (read_options_.use_threads) {
+ readahead_gen = MakeReadaheadGenerator(std::move(batch_gen), max_readahead);
+ } else {
+ readahead_gen = std::move(batch_gen);
+ }
+
+ AsyncGenerator<DecodedBlock> restarted_gen =
+ MakeGeneratorStartsWith({block}, std::move(readahead_gen));
+
+ auto bytes_decoded = bytes_decoded_;
+ auto unwrap_and_record_bytes =
+ [bytes_decoded, prev_bytes_processed](
+ const DecodedBlock& block) mutable -> Result<std::shared_ptr<RecordBatch>> {
+ bytes_decoded->fetch_add(block.bytes_processed + prev_bytes_processed);
+ prev_bytes_processed = 0;
+ return block.record_batch;
+ };
+
+ auto unwrapped =
+ MakeMappedGenerator(std::move(restarted_gen), std::move(unwrap_and_record_bytes));
+
+ record_batch_gen_ = MakeCancellable(std::move(unwrapped), io_context_.stop_token());
+ return Status::OK();
+ }
+
+ std::shared_ptr<Schema> schema_;
+ AsyncGenerator<std::shared_ptr<RecordBatch>> record_batch_gen_;
+ // bytes which have been decoded and asked for by the caller
+ std::shared_ptr<std::atomic<int64_t>> bytes_decoded_;
+};
+
+/////////////////////////////////////////////////////////////////////////
+// Serial TableReader implementation
+
+class SerialTableReader : public BaseTableReader {
+ public:
+ using BaseTableReader::BaseTableReader;
+
+ Status Init() override {
+ ARROW_ASSIGN_OR_RAISE(auto istream_it,
+ io::MakeInputStreamIterator(input_, read_options_.block_size));
+
+ // Since we're converting serially, no need to readahead more than one block
+ int32_t block_queue_size = 1;
+ ARROW_ASSIGN_OR_RAISE(auto rh_it,
+ MakeReadaheadIterator(std::move(istream_it), block_queue_size));
+ buffer_iterator_ = CSVBufferIterator::Make(std::move(rh_it));
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<Table>> Read() override {
+ task_group_ = internal::TaskGroup::MakeSerial(io_context_.stop_token());
+
+ // First block
+ ARROW_ASSIGN_OR_RAISE(auto first_buffer, buffer_iterator_.Next());
+ if (first_buffer == nullptr) {
+ return Status::Invalid("Empty CSV file");
+ }
+ RETURN_NOT_OK(ProcessHeader(first_buffer, &first_buffer));
+ RETURN_NOT_OK(MakeColumnBuilders());
+
+ auto block_iterator = SerialBlockReader::MakeIterator(
+ std::move(buffer_iterator_), MakeChunker(parse_options_), std::move(first_buffer),
+ read_options_.skip_rows_after_names);
+ while (true) {
+ RETURN_NOT_OK(io_context_.stop_token().Poll());
+
+ ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_iterator.Next());
+ if (IsIterationEnd(maybe_block)) {
+ // EOF
+ break;
+ }
+ ARROW_ASSIGN_OR_RAISE(
+ int64_t parsed_bytes,
+ ParseAndInsert(maybe_block.partial, maybe_block.completion, maybe_block.buffer,
+ maybe_block.block_index, maybe_block.is_final));
+ RETURN_NOT_OK(maybe_block.consume_bytes(parsed_bytes));
+ }
+ // Finish conversion, create schema and table
+ RETURN_NOT_OK(task_group_->Finish());
+ return MakeTable();
+ }
+
+ protected:
+ Iterator<std::shared_ptr<Buffer>> buffer_iterator_;
+};
+
+class AsyncThreadedTableReader
+ : public BaseTableReader,
+ public std::enable_shared_from_this<AsyncThreadedTableReader> {
+ public:
+ using BaseTableReader::BaseTableReader;
+
+ AsyncThreadedTableReader(io::IOContext io_context,
+ std::shared_ptr<io::InputStream> input,
+ const ReadOptions& read_options,
+ const ParseOptions& parse_options,
+ const ConvertOptions& convert_options, Executor* cpu_executor)
+ // Count rows is currently not supported during parallel read
+ : BaseTableReader(std::move(io_context), input, read_options, parse_options,
+ convert_options, /*count_rows=*/false),
+ cpu_executor_(cpu_executor) {}
+
+ ~AsyncThreadedTableReader() override {
+ if (task_group_) {
+ // In case of error, make sure all pending tasks are finished before
+ // we start destroying BaseTableReader members
+ ARROW_UNUSED(task_group_->Finish());
+ }
+ }
+
+ Status Init() override {
+ ARROW_ASSIGN_OR_RAISE(auto istream_it,
+ io::MakeInputStreamIterator(input_, read_options_.block_size));
+
+ int max_readahead = cpu_executor_->GetCapacity();
+ int readahead_restart = std::max(1, max_readahead / 2);
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto bg_it, MakeBackgroundGenerator(std::move(istream_it), io_context_.executor(),
+ max_readahead, readahead_restart));
+
+ auto transferred_it = MakeTransferredGenerator(bg_it, cpu_executor_);
+ buffer_generator_ = CSVBufferIterator::MakeAsync(std::move(transferred_it));
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<Table>> Read() override { return ReadAsync().result(); }
+
+ Future<std::shared_ptr<Table>> ReadAsync() override {
+ task_group_ =
+ internal::TaskGroup::MakeThreaded(cpu_executor_, io_context_.stop_token());
+
+ auto self = shared_from_this();
+ return ProcessFirstBuffer().Then([self](const std::shared_ptr<Buffer>& first_buffer) {
+ auto block_generator = ThreadedBlockReader::MakeAsyncIterator(
+ self->buffer_generator_, MakeChunker(self->parse_options_),
+ std::move(first_buffer), self->read_options_.skip_rows_after_names);
+
+ std::function<Status(CSVBlock)> block_visitor =
+ [self](CSVBlock maybe_block) -> Status {
+ // The logic in VisitAsyncGenerator ensures that we will never be
+ // passed an empty block (visit does not call with the end token) so
+ // we can be assured maybe_block has a value.
+ DCHECK_GE(maybe_block.block_index, 0);
+ DCHECK(!maybe_block.consume_bytes);
+
+ // Launch parse task
+ self->task_group_->Append([self, maybe_block] {
+ return self
+ ->ParseAndInsert(maybe_block.partial, maybe_block.completion,
+ maybe_block.buffer, maybe_block.block_index,
+ maybe_block.is_final)
+ .status();
+ });
+ return Status::OK();
+ };
+
+ return VisitAsyncGenerator(std::move(block_generator), block_visitor)
+ .Then([self]() -> Future<> {
+ // By this point we've added all top level tasks so it is safe to call
+ // FinishAsync
+ return self->task_group_->FinishAsync();
+ })
+ .Then([self]() -> Result<std::shared_ptr<Table>> {
+ // Finish conversion, create schema and table
+ return self->MakeTable();
+ });
+ });
+ }
+
+ protected:
+ Future<std::shared_ptr<Buffer>> ProcessFirstBuffer() {
+ // First block
+ auto first_buffer_future = buffer_generator_();
+ return first_buffer_future.Then([this](const std::shared_ptr<Buffer>& first_buffer)
+ -> Result<std::shared_ptr<Buffer>> {
+ if (first_buffer == nullptr) {
+ return Status::Invalid("Empty CSV file");
+ }
+ std::shared_ptr<Buffer> first_buffer_processed;
+ RETURN_NOT_OK(ProcessHeader(first_buffer, &first_buffer_processed));
+ RETURN_NOT_OK(MakeColumnBuilders());
+ return first_buffer_processed;
+ });
+ }
+
+ Executor* cpu_executor_;
+ AsyncGenerator<std::shared_ptr<Buffer>> buffer_generator_;
+};
+
+Result<std::shared_ptr<TableReader>> MakeTableReader(
+ MemoryPool* pool, io::IOContext io_context, std::shared_ptr<io::InputStream> input,
+ const ReadOptions& read_options, const ParseOptions& parse_options,
+ const ConvertOptions& convert_options) {
+ RETURN_NOT_OK(parse_options.Validate());
+ RETURN_NOT_OK(read_options.Validate());
+ RETURN_NOT_OK(convert_options.Validate());
+ std::shared_ptr<BaseTableReader> reader;
+ if (read_options.use_threads) {
+ auto cpu_executor = internal::GetCpuThreadPool();
+ reader = std::make_shared<AsyncThreadedTableReader>(
+ io_context, input, read_options, parse_options, convert_options, cpu_executor);
+ } else {
+ reader = std::make_shared<SerialTableReader>(io_context, input, read_options,
+ parse_options, convert_options,
+ /*count_rows=*/true);
+ }
+ RETURN_NOT_OK(reader->Init());
+ return reader;
+}
+
+Future<std::shared_ptr<StreamingReader>> MakeStreamingReader(
+ io::IOContext io_context, std::shared_ptr<io::InputStream> input,
+ internal::Executor* cpu_executor, const ReadOptions& read_options,
+ const ParseOptions& parse_options, const ConvertOptions& convert_options) {
+ RETURN_NOT_OK(parse_options.Validate());
+ RETURN_NOT_OK(read_options.Validate());
+ RETURN_NOT_OK(convert_options.Validate());
+ std::shared_ptr<StreamingReaderImpl> reader;
+ reader = std::make_shared<StreamingReaderImpl>(
+ io_context, input, read_options, parse_options, convert_options,
+ /*count_rows=*/!read_options.use_threads || cpu_executor->GetCapacity() == 1);
+ return reader->Init(cpu_executor).Then([reader] {
+ return std::dynamic_pointer_cast<StreamingReader>(reader);
+ });
+}
+
+/////////////////////////////////////////////////////////////////////////
+// Row count implementation
+
+class CSVRowCounter : public ReaderMixin,
+ public std::enable_shared_from_this<CSVRowCounter> {
+ public:
+ CSVRowCounter(io::IOContext io_context, Executor* cpu_executor,
+ std::shared_ptr<io::InputStream> input, const ReadOptions& read_options,
+ const ParseOptions& parse_options)
+ : ReaderMixin(io_context, std::move(input), read_options, parse_options,
+ ConvertOptions::Defaults(), /*count_rows=*/true),
+ cpu_executor_(cpu_executor),
+ row_count_(0) {}
+
+ Future<int64_t> Count() {
+ auto self = shared_from_this();
+ return Init(self).Then([self]() { return self->DoCount(self); });
+ }
+
+ private:
+ Future<> Init(const std::shared_ptr<CSVRowCounter>& self) {
+ ARROW_ASSIGN_OR_RAISE(auto istream_it,
+ io::MakeInputStreamIterator(input_, read_options_.block_size));
+ // TODO Consider exposing readahead as a read option (ARROW-12090)
+ ARROW_ASSIGN_OR_RAISE(auto bg_it, MakeBackgroundGenerator(std::move(istream_it),
+ io_context_.executor()));
+ auto transferred_it = MakeTransferredGenerator(bg_it, cpu_executor_);
+ auto buffer_generator = CSVBufferIterator::MakeAsync(std::move(transferred_it));
+
+ return buffer_generator().Then(
+ [self, buffer_generator](std::shared_ptr<Buffer> first_buffer) {
+ if (!first_buffer) {
+ return Status::Invalid("Empty CSV file");
+ }
+ RETURN_NOT_OK(self->ProcessHeader(first_buffer, &first_buffer));
+ self->block_generator_ = SerialBlockReader::MakeAsyncIterator(
+ buffer_generator, MakeChunker(self->parse_options_),
+ std::move(first_buffer), 0);
+ return Status::OK();
+ });
+ }
+
+ Future<int64_t> DoCount(const std::shared_ptr<CSVRowCounter>& self) {
+ // count_cb must return a value instead of Status/Future<> to work with
+ // MakeMappedGenerator, and it must use a type with a valid end value to work with
+ // IterationEnd.
+ std::function<Result<util::optional<int64_t>>(const CSVBlock&)> count_cb =
+ [self](const CSVBlock& maybe_block) -> Result<util::optional<int64_t>> {
+ ARROW_ASSIGN_OR_RAISE(
+ auto parser,
+ self->Parse(maybe_block.partial, maybe_block.completion, maybe_block.buffer,
+ maybe_block.block_index, maybe_block.is_final));
+ RETURN_NOT_OK(maybe_block.consume_bytes(parser.parsed_bytes));
+ int32_t total_row_count = parser.parser->total_num_rows();
+ self->row_count_ += total_row_count;
+ return total_row_count;
+ };
+ auto count_gen = MakeMappedGenerator(block_generator_, std::move(count_cb));
+ return DiscardAllFromAsyncGenerator(count_gen).Then(
+ [self]() { return self->row_count_; });
+ }
+
+ Executor* cpu_executor_;
+ AsyncGenerator<CSVBlock> block_generator_;
+ int64_t row_count_;
+};
+
+} // namespace
+
+/////////////////////////////////////////////////////////////////////////
+// Factory functions
+
+Result<std::shared_ptr<TableReader>> TableReader::Make(
+ io::IOContext io_context, std::shared_ptr<io::InputStream> input,
+ const ReadOptions& read_options, const ParseOptions& parse_options,
+ const ConvertOptions& convert_options) {
+ return MakeTableReader(io_context.pool(), io_context, std::move(input), read_options,
+ parse_options, convert_options);
+}
+
+Result<std::shared_ptr<TableReader>> TableReader::Make(
+ MemoryPool* pool, io::IOContext io_context, std::shared_ptr<io::InputStream> input,
+ const ReadOptions& read_options, const ParseOptions& parse_options,
+ const ConvertOptions& convert_options) {
+ return MakeTableReader(pool, io_context, std::move(input), read_options, parse_options,
+ convert_options);
+}
+
+Result<std::shared_ptr<StreamingReader>> StreamingReader::Make(
+ MemoryPool* pool, std::shared_ptr<io::InputStream> input,
+ const ReadOptions& read_options, const ParseOptions& parse_options,
+ const ConvertOptions& convert_options) {
+ auto io_context = io::IOContext(pool);
+ auto cpu_executor = internal::GetCpuThreadPool();
+ auto reader_fut = MakeStreamingReader(io_context, std::move(input), cpu_executor,
+ read_options, parse_options, convert_options);
+ auto reader_result = reader_fut.result();
+ ARROW_ASSIGN_OR_RAISE(auto reader, reader_result);
+ return reader;
+}
+
+Result<std::shared_ptr<StreamingReader>> StreamingReader::Make(
+ io::IOContext io_context, std::shared_ptr<io::InputStream> input,
+ const ReadOptions& read_options, const ParseOptions& parse_options,
+ const ConvertOptions& convert_options) {
+ auto cpu_executor = internal::GetCpuThreadPool();
+ auto reader_fut = MakeStreamingReader(io_context, std::move(input), cpu_executor,
+ read_options, parse_options, convert_options);
+ auto reader_result = reader_fut.result();
+ ARROW_ASSIGN_OR_RAISE(auto reader, reader_result);
+ return reader;
+}
+
+Future<std::shared_ptr<StreamingReader>> StreamingReader::MakeAsync(
+ io::IOContext io_context, std::shared_ptr<io::InputStream> input,
+ internal::Executor* cpu_executor, const ReadOptions& read_options,
+ const ParseOptions& parse_options, const ConvertOptions& convert_options) {
+ return MakeStreamingReader(io_context, std::move(input), cpu_executor, read_options,
+ parse_options, convert_options);
+}
+
+Future<int64_t> CountRowsAsync(io::IOContext io_context,
+ std::shared_ptr<io::InputStream> input,
+ internal::Executor* cpu_executor,
+ const ReadOptions& read_options,
+ const ParseOptions& parse_options) {
+ RETURN_NOT_OK(parse_options.Validate());
+ RETURN_NOT_OK(read_options.Validate());
+ auto counter = std::make_shared<CSVRowCounter>(
+ io_context, cpu_executor, std::move(input), read_options, parse_options);
+ return counter->Count();
+}
+
+} // namespace csv
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/reader.h b/src/arrow/cpp/src/arrow/csv/reader.h
new file mode 100644
index 000000000..253db6892
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/reader.h
@@ -0,0 +1,125 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/csv/options.h" // IWYU pragma: keep
+#include "arrow/io/interfaces.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/future.h"
+#include "arrow/util/thread_pool.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace io {
+class InputStream;
+} // namespace io
+
+namespace csv {
+
+/// A class that reads an entire CSV file into a Arrow Table
+class ARROW_EXPORT TableReader {
+ public:
+ virtual ~TableReader() = default;
+
+ /// Read the entire CSV file and convert it to a Arrow Table
+ virtual Result<std::shared_ptr<Table>> Read() = 0;
+ /// Read the entire CSV file and convert it to a Arrow Table
+ virtual Future<std::shared_ptr<Table>> ReadAsync() = 0;
+
+ /// Create a TableReader instance
+ static Result<std::shared_ptr<TableReader>> Make(io::IOContext io_context,
+ std::shared_ptr<io::InputStream> input,
+ const ReadOptions&,
+ const ParseOptions&,
+ const ConvertOptions&);
+
+ ARROW_DEPRECATED(
+ "Deprecated in 4.0.0. "
+ "Use MemoryPool-less variant (the IOContext holds a pool already)")
+ static Result<std::shared_ptr<TableReader>> Make(
+ MemoryPool* pool, io::IOContext io_context, std::shared_ptr<io::InputStream> input,
+ const ReadOptions&, const ParseOptions&, const ConvertOptions&);
+};
+
+/// \brief A class that reads a CSV file incrementally
+///
+/// Caveats:
+/// - For now, this is always single-threaded (regardless of `ReadOptions::use_threads`.
+/// - Type inference is done on the first block and types are frozen afterwards;
+/// to make sure the right data types are inferred, either set
+/// `ReadOptions::block_size` to a large enough value, or use
+/// `ConvertOptions::column_types` to set the desired data types explicitly.
+class ARROW_EXPORT StreamingReader : public RecordBatchReader {
+ public:
+ virtual ~StreamingReader() = default;
+
+ virtual Future<std::shared_ptr<RecordBatch>> ReadNextAsync() = 0;
+
+ /// \brief Return the number of bytes which have been read and processed
+ ///
+ /// The returned number includes CSV bytes which the StreamingReader has
+ /// finished processing, but not bytes for which some processing (e.g.
+ /// CSV parsing or conversion to Arrow layout) is still ongoing.
+ ///
+ /// Furthermore, the following rules apply:
+ /// - bytes skipped by `ReadOptions.skip_rows` are counted as being read before
+ /// any records are returned.
+ /// - bytes read while parsing the header are counted as being read before any
+ /// records are returned.
+ /// - bytes skipped by `ReadOptions.skip_rows_after_names` are counted after the
+ /// first batch is returned.
+ virtual int64_t bytes_read() const = 0;
+
+ /// Create a StreamingReader instance
+ ///
+ /// This involves some I/O as the first batch must be loaded during the creation process
+ /// so it is returned as a future
+ ///
+ /// Currently, the StreamingReader is not async-reentrant and does not do any fan-out
+ /// parsing (see ARROW-11889)
+ static Future<std::shared_ptr<StreamingReader>> MakeAsync(
+ io::IOContext io_context, std::shared_ptr<io::InputStream> input,
+ internal::Executor* cpu_executor, const ReadOptions&, const ParseOptions&,
+ const ConvertOptions&);
+
+ static Result<std::shared_ptr<StreamingReader>> Make(
+ io::IOContext io_context, std::shared_ptr<io::InputStream> input,
+ const ReadOptions&, const ParseOptions&, const ConvertOptions&);
+
+ ARROW_DEPRECATED("Deprecated in 4.0.0. Use IOContext-based overload")
+ static Result<std::shared_ptr<StreamingReader>> Make(
+ MemoryPool* pool, std::shared_ptr<io::InputStream> input,
+ const ReadOptions& read_options, const ParseOptions& parse_options,
+ const ConvertOptions& convert_options);
+};
+
+/// \brief Count the logical rows of data in a CSV file (i.e. the
+/// number of rows you would get if you read the file into a table).
+ARROW_EXPORT
+Future<int64_t> CountRowsAsync(io::IOContext io_context,
+ std::shared_ptr<io::InputStream> input,
+ internal::Executor* cpu_executor, const ReadOptions&,
+ const ParseOptions&);
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/reader_test.cc b/src/arrow/cpp/src/arrow/csv/reader_test.cc
new file mode 100644
index 000000000..328ad0c4d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/reader_test.cc
@@ -0,0 +1,490 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/csv/reader.h"
+
+#include <gtest/gtest.h>
+
+#include <atomic>
+#include <cstdint>
+#include <string>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include "arrow/csv/options.h"
+#include "arrow/csv/test_common.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/io/memory.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/future.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+namespace csv {
+
+// Allows the streaming reader to be used in tests that expect a table reader
+class StreamingReaderAsTableReader : public TableReader {
+ public:
+ explicit StreamingReaderAsTableReader(std::shared_ptr<StreamingReader> reader)
+ : reader_(std::move(reader)) {}
+ virtual ~StreamingReaderAsTableReader() = default;
+ virtual Result<std::shared_ptr<Table>> Read() {
+ std::shared_ptr<Table> table;
+ RETURN_NOT_OK(reader_->ReadAll(&table));
+ return table;
+ }
+ virtual Future<std::shared_ptr<Table>> ReadAsync() {
+ auto reader = reader_;
+ AsyncGenerator<std::shared_ptr<RecordBatch>> gen = [reader] {
+ return reader->ReadNextAsync();
+ };
+ return CollectAsyncGenerator(std::move(gen))
+ .Then([](const RecordBatchVector& batches) {
+ return Table::FromRecordBatches(batches);
+ });
+ }
+
+ private:
+ std::shared_ptr<StreamingReader> reader_;
+};
+
+using TableReaderFactory = std::function<Result<std::shared_ptr<TableReader>>(
+ std::shared_ptr<io::InputStream>, ParseOptions)>;
+using StreamingReaderFactory = std::function<Result<std::shared_ptr<StreamingReader>>(
+ std::shared_ptr<io::InputStream>)>;
+
+void TestEmptyTable(TableReaderFactory reader_factory) {
+ auto empty_buffer = std::make_shared<Buffer>("");
+ auto empty_input = std::make_shared<io::BufferReader>(empty_buffer);
+ auto maybe_reader = reader_factory(empty_input, ParseOptions::Defaults());
+ // Streaming reader fails on open, table readers fail on first read
+ if (maybe_reader.ok()) {
+ ASSERT_FINISHES_AND_RAISES(Invalid, (*maybe_reader)->ReadAsync());
+ } else {
+ ASSERT_TRUE(maybe_reader.status().IsInvalid());
+ }
+}
+
+void TestHeaderOnly(TableReaderFactory reader_factory) {
+ auto header_only_buffer = std::make_shared<Buffer>("a,b,c\n");
+ auto input = std::make_shared<io::BufferReader>(header_only_buffer);
+ ASSERT_OK_AND_ASSIGN(auto reader, reader_factory(input, ParseOptions::Defaults()));
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto table, reader->ReadAsync());
+ ASSERT_EQ(table->schema()->num_fields(), 3);
+ ASSERT_EQ(table->num_rows(), 0);
+}
+
+void TestHeaderOnlyStreaming(StreamingReaderFactory reader_factory) {
+ auto header_only_buffer = std::make_shared<Buffer>("a,b,c\n");
+ auto input = std::make_shared<io::BufferReader>(header_only_buffer);
+ ASSERT_OK_AND_ASSIGN(auto reader, reader_factory(input));
+ std::shared_ptr<RecordBatch> next_batch;
+ ASSERT_OK(reader->ReadNext(&next_batch));
+ ASSERT_EQ(next_batch, nullptr);
+}
+
+void StressTableReader(TableReaderFactory reader_factory) {
+#ifdef ARROW_VALGRIND
+ const int NTASKS = 10;
+ const int NROWS = 100;
+#else
+ const int NTASKS = 100;
+ const int NROWS = 1000;
+#endif
+ ASSERT_OK_AND_ASSIGN(auto table_buffer, MakeSampleCsvBuffer(NROWS));
+
+ std::vector<Future<std::shared_ptr<Table>>> task_futures(NTASKS);
+ for (int i = 0; i < NTASKS; i++) {
+ auto input = std::make_shared<io::BufferReader>(table_buffer);
+ ASSERT_OK_AND_ASSIGN(auto reader, reader_factory(input, ParseOptions::Defaults()));
+ task_futures[i] = reader->ReadAsync();
+ }
+ auto combined_future = All(task_futures);
+ combined_future.Wait();
+
+ ASSERT_OK_AND_ASSIGN(std::vector<Result<std::shared_ptr<Table>>> results,
+ combined_future.result());
+ for (auto&& result : results) {
+ ASSERT_OK_AND_ASSIGN(auto table, result);
+ ASSERT_EQ(NROWS, table->num_rows());
+ }
+}
+
+void StressInvalidTableReader(TableReaderFactory reader_factory) {
+#ifdef ARROW_VALGRIND
+ const int NTASKS = 10;
+ const int NROWS = 100;
+#else
+ const int NTASKS = 100;
+ const int NROWS = 1000;
+#endif
+ ASSERT_OK_AND_ASSIGN(auto table_buffer, MakeSampleCsvBuffer(NROWS, [=](size_t row_num) {
+ return row_num != NROWS / 2;
+ }));
+
+ std::vector<Future<std::shared_ptr<Table>>> task_futures(NTASKS);
+ for (int i = 0; i < NTASKS; i++) {
+ auto input = std::make_shared<io::BufferReader>(table_buffer);
+ ASSERT_OK_AND_ASSIGN(auto reader, reader_factory(input, ParseOptions::Defaults()));
+ task_futures[i] = reader->ReadAsync();
+ }
+ auto combined_future = All(task_futures);
+ combined_future.Wait();
+
+ ASSERT_OK_AND_ASSIGN(std::vector<Result<std::shared_ptr<Table>>> results,
+ combined_future.result());
+ for (auto&& result : results) {
+ ASSERT_RAISES(Invalid, result);
+ }
+}
+
+void TestNestedParallelism(std::shared_ptr<internal::ThreadPool> thread_pool,
+ TableReaderFactory reader_factory) {
+ const int NROWS = 1000;
+ ASSERT_OK_AND_ASSIGN(auto table_buffer, MakeSampleCsvBuffer(NROWS));
+ auto input = std::make_shared<io::BufferReader>(table_buffer);
+ ASSERT_OK_AND_ASSIGN(auto reader, reader_factory(input, ParseOptions::Defaults()));
+
+ Future<std::shared_ptr<Table>> table_future;
+
+ auto read_task = [&reader, &table_future]() mutable {
+ table_future = reader->ReadAsync();
+ return Status::OK();
+ };
+ ASSERT_OK_AND_ASSIGN(auto future, thread_pool->Submit(read_task));
+
+ ASSERT_FINISHES_OK(future);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto table, table_future);
+ ASSERT_EQ(table->num_rows(), NROWS);
+}
+
+void TestInvalidRowsSkipped(TableReaderFactory reader_factory, bool async) {
+ const int NROWS = 1000;
+ const int INVALID_EVERY = 20;
+ const int NINVALID = 50;
+
+ auto opts = ParseOptions::Defaults();
+ std::atomic<int> num_invalid_rows(0);
+ opts.invalid_row_handler = [&](const InvalidRow& row) {
+ auto cur_invalid_rows = ++num_invalid_rows;
+ if (async) {
+ // Row numbers are not counted in batches during async processing
+ EXPECT_EQ(-1, row.number);
+ } else {
+ // actual data starts at row #2 in the CSV "file"
+ EXPECT_EQ(cur_invalid_rows * INVALID_EVERY + 1, row.number);
+ }
+ return InvalidRowResult::Skip;
+ };
+
+ ASSERT_OK_AND_ASSIGN(auto table_buffer, MakeSampleCsvBuffer(NROWS, [=](size_t row_num) {
+ // row_num is 0-based
+ return (row_num + 1) % static_cast<size_t>(INVALID_EVERY) != 0;
+ }));
+ auto input = std::make_shared<io::BufferReader>(table_buffer);
+ ASSERT_OK_AND_ASSIGN(auto reader, reader_factory(input, std::move(opts)));
+ ASSERT_OK_AND_ASSIGN(auto table, reader->Read());
+ ASSERT_EQ(NROWS - NINVALID, table->num_rows());
+ ASSERT_EQ(NINVALID, num_invalid_rows);
+}
+
+TableReaderFactory MakeSerialFactory() {
+ return [](std::shared_ptr<io::InputStream> input_stream, ParseOptions parse_options) {
+ auto read_options = ReadOptions::Defaults();
+ read_options.block_size = 1 << 10;
+ read_options.use_threads = false;
+ return TableReader::Make(io::default_io_context(), input_stream, read_options,
+ std::move(parse_options), ConvertOptions::Defaults());
+ };
+}
+
+TEST(SerialReaderTests, Empty) { TestEmptyTable(MakeSerialFactory()); }
+TEST(SerialReaderTests, HeaderOnly) { TestHeaderOnly(MakeSerialFactory()); }
+TEST(SerialReaderTests, Stress) { StressTableReader(MakeSerialFactory()); }
+TEST(SerialReaderTests, StressInvalid) { StressInvalidTableReader(MakeSerialFactory()); }
+TEST(SerialReaderTests, NestedParallelism) {
+ ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1));
+ TestNestedParallelism(thread_pool, MakeSerialFactory());
+}
+TEST(SerialReaderTests, InvalidRowsSkipped) {
+ TestInvalidRowsSkipped(MakeSerialFactory(), /*async=*/false);
+}
+
+Result<TableReaderFactory> MakeAsyncFactory(
+ std::shared_ptr<internal::ThreadPool> thread_pool = nullptr) {
+ if (!thread_pool) {
+ ARROW_ASSIGN_OR_RAISE(thread_pool, internal::ThreadPool::Make(1));
+ }
+ return [thread_pool](
+ std::shared_ptr<io::InputStream> input_stream,
+ ParseOptions parse_options) -> Result<std::shared_ptr<TableReader>> {
+ ReadOptions read_options = ReadOptions::Defaults();
+ read_options.use_threads = true;
+ read_options.block_size = 1 << 10;
+ auto table_reader =
+ TableReader::Make(io::IOContext(thread_pool.get()), input_stream, read_options,
+ std::move(parse_options), ConvertOptions::Defaults());
+ return table_reader;
+ };
+}
+
+TEST(AsyncReaderTests, Empty) {
+ ASSERT_OK_AND_ASSIGN(auto table_factory, MakeAsyncFactory());
+ TestEmptyTable(table_factory);
+}
+TEST(AsyncReaderTests, HeaderOnly) {
+ ASSERT_OK_AND_ASSIGN(auto table_factory, MakeAsyncFactory());
+ TestHeaderOnly(table_factory);
+}
+TEST(AsyncReaderTests, Stress) {
+ ASSERT_OK_AND_ASSIGN(auto table_factory, MakeAsyncFactory());
+ StressTableReader(table_factory);
+}
+TEST(AsyncReaderTests, StressInvalid) {
+ ASSERT_OK_AND_ASSIGN(auto table_factory, MakeAsyncFactory());
+ StressInvalidTableReader(table_factory);
+}
+TEST(AsyncReaderTests, NestedParallelism) {
+ ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1));
+ ASSERT_OK_AND_ASSIGN(auto table_factory, MakeAsyncFactory(thread_pool));
+ TestNestedParallelism(thread_pool, table_factory);
+}
+TEST(AsyncReaderTests, InvalidRowsSkipped) {
+ ASSERT_OK_AND_ASSIGN(auto table_factory, MakeAsyncFactory());
+ TestInvalidRowsSkipped(table_factory, /*async=*/true);
+}
+
+TableReaderFactory MakeStreamingFactory(bool use_threads = true) {
+ return [use_threads](
+ std::shared_ptr<io::InputStream> input_stream,
+ ParseOptions parse_options) -> Result<std::shared_ptr<TableReader>> {
+ auto read_options = ReadOptions::Defaults();
+ read_options.block_size = 1 << 10;
+ read_options.use_threads = use_threads;
+ ARROW_ASSIGN_OR_RAISE(
+ auto streaming_reader,
+ StreamingReader::Make(io::default_io_context(), input_stream, read_options,
+ std::move(parse_options), ConvertOptions::Defaults()));
+ return std::make_shared<StreamingReaderAsTableReader>(std::move(streaming_reader));
+ };
+}
+
+Result<StreamingReaderFactory> MakeStreamingReaderFactory() {
+ return [](std::shared_ptr<io::InputStream> input_stream)
+ -> Result<std::shared_ptr<StreamingReader>> {
+ auto read_options = ReadOptions::Defaults();
+ read_options.block_size = 1 << 10;
+ read_options.use_threads = true;
+ return StreamingReader::Make(io::default_io_context(), input_stream, read_options,
+ ParseOptions::Defaults(), ConvertOptions::Defaults());
+ };
+}
+
+TEST(StreamingReaderTests, Empty) { TestEmptyTable(MakeStreamingFactory()); }
+TEST(StreamingReaderTests, HeaderOnly) {
+ ASSERT_OK_AND_ASSIGN(auto table_factory, MakeStreamingReaderFactory());
+ TestHeaderOnlyStreaming(table_factory);
+}
+TEST(StreamingReaderTests, Stress) { StressTableReader(MakeStreamingFactory()); }
+TEST(StreamingReaderTests, StressInvalid) {
+ StressInvalidTableReader(MakeStreamingFactory());
+}
+TEST(StreamingReaderTests, NestedParallelism) {
+ ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1));
+ TestNestedParallelism(thread_pool, MakeStreamingFactory());
+}
+TEST(StreamingReaderTests, InvalidRowsSkipped) {
+ TestInvalidRowsSkipped(MakeStreamingFactory(/*use_threads=*/false), /*async=*/false);
+}
+TEST(StreamingReaderTests, InvalidRowsSkippedAsync) {
+ TestInvalidRowsSkipped(MakeStreamingFactory(), /*async=*/true);
+}
+
+TEST(StreamingReaderTests, BytesRead) {
+ ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1));
+ auto table_buffer =
+ std::make_shared<Buffer>("a,b,c\n123,456,789\n101,112,131\n415,161,718\n");
+
+ // Basic read without any skips and small block size
+ {
+ auto input = std::make_shared<io::BufferReader>(table_buffer);
+
+ auto read_options = ReadOptions::Defaults();
+ read_options.block_size = 20;
+ read_options.use_threads = false;
+ ASSERT_OK_AND_ASSIGN(
+ auto streaming_reader,
+ StreamingReader::Make(io::default_io_context(), input, read_options,
+ ParseOptions::Defaults(), ConvertOptions::Defaults()));
+ std::shared_ptr<RecordBatch> batch;
+ int64_t bytes = 6; // Size of header (counted during StreamingReader::Make)
+ do {
+ ASSERT_EQ(bytes, streaming_reader->bytes_read());
+ ASSERT_OK(streaming_reader->ReadNext(&batch));
+ bytes += 12; // Add size of each row
+ } while (bytes <= 42);
+ ASSERT_EQ(42, streaming_reader->bytes_read());
+ ASSERT_EQ(batch.get(), nullptr);
+ }
+
+ // Interaction of skip_rows and bytes_read()
+ {
+ auto input = std::make_shared<io::BufferReader>(table_buffer);
+
+ auto read_options = ReadOptions::Defaults();
+ read_options.skip_rows = 1;
+ read_options.block_size = 32;
+ ASSERT_OK_AND_ASSIGN(
+ auto streaming_reader,
+ StreamingReader::Make(io::default_io_context(), input, read_options,
+ ParseOptions::Defaults(), ConvertOptions::Defaults()));
+ std::shared_ptr<RecordBatch> batch;
+ // The header (6 bytes) and first skipped row (12 bytes) are counted during
+ // StreamingReader::Make
+ ASSERT_EQ(18, streaming_reader->bytes_read());
+ ASSERT_OK(streaming_reader->ReadNext(&batch));
+ ASSERT_NE(batch.get(), nullptr);
+ ASSERT_EQ(30, streaming_reader->bytes_read());
+ ASSERT_OK(streaming_reader->ReadNext(&batch));
+ ASSERT_NE(batch.get(), nullptr);
+ ASSERT_EQ(42, streaming_reader->bytes_read());
+ ASSERT_OK(streaming_reader->ReadNext(&batch));
+ ASSERT_EQ(batch.get(), nullptr);
+ }
+
+ // Interaction of skip_rows_after_names and bytes_read()
+ {
+ auto input = std::make_shared<io::BufferReader>(table_buffer);
+
+ auto read_options = ReadOptions::Defaults();
+ read_options.block_size = 32;
+ read_options.skip_rows_after_names = 1;
+
+ ASSERT_OK_AND_ASSIGN(
+ auto streaming_reader,
+ StreamingReader::Make(io::default_io_context(), input, read_options,
+ ParseOptions::Defaults(), ConvertOptions::Defaults()));
+ std::shared_ptr<RecordBatch> batch;
+
+ // The header is read as part of StreamingReader::Make
+ ASSERT_EQ(6, streaming_reader->bytes_read());
+ ASSERT_OK(streaming_reader->ReadNext(&batch));
+ ASSERT_NE(batch.get(), nullptr);
+ // Next the skipped batch (12 bytes) and 1 row (12 bytes)
+ ASSERT_EQ(30, streaming_reader->bytes_read());
+ ASSERT_OK(streaming_reader->ReadNext(&batch));
+ ASSERT_NE(batch.get(), nullptr);
+ ASSERT_EQ(42, streaming_reader->bytes_read());
+ ASSERT_OK(streaming_reader->ReadNext(&batch));
+ ASSERT_EQ(batch.get(), nullptr);
+ }
+}
+
+TEST(StreamingReaderTests, SkipMultipleEmptyBlocksAtStart) {
+ ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1));
+ auto table_buffer = std::make_shared<Buffer>(
+ "aaa,bbb,ccc\n123,456,789\n101,112,131\n415,161,718\n192,021,222\n324,252,627\n"
+ "282,930,313\n233,343,536\n");
+
+ auto input = std::make_shared<io::BufferReader>(table_buffer);
+
+ auto read_options = ReadOptions::Defaults();
+ read_options.block_size = 34;
+ read_options.skip_rows_after_names = 6;
+
+ ASSERT_OK_AND_ASSIGN(
+ auto streaming_reader,
+ StreamingReader::Make(io::default_io_context(), input, read_options,
+ ParseOptions::Defaults(), ConvertOptions::Defaults()));
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_EQ(12, streaming_reader->bytes_read());
+
+ // The first batch should have the one and only row in it
+ ASSERT_OK(streaming_reader->ReadNext(&batch));
+ ASSERT_NE(nullptr, batch.get());
+ ASSERT_EQ(1, batch->num_rows());
+ ASSERT_EQ(96, streaming_reader->bytes_read());
+
+ auto expected_schema =
+ schema({field("aaa", int64()), field("bbb", int64()), field("ccc", int64())});
+ AssertSchemaEqual(expected_schema, streaming_reader->schema());
+ auto expected_batch = RecordBatchFromJSON(expected_schema, "[[233,343,536]]");
+ ASSERT_TRUE(expected_batch->Equals(*batch));
+
+ ASSERT_OK(streaming_reader->ReadNext(&batch));
+ ASSERT_EQ(nullptr, batch.get());
+}
+
+TEST(CountRowsAsync, Basics) {
+ constexpr int NROWS = 4096;
+ ASSERT_OK_AND_ASSIGN(auto table_buffer, MakeSampleCsvBuffer(NROWS));
+ {
+ auto reader = std::make_shared<io::BufferReader>(table_buffer);
+ auto read_options = ReadOptions::Defaults();
+ auto parse_options = ParseOptions::Defaults();
+ ASSERT_FINISHES_OK_AND_EQ(
+ NROWS, CountRowsAsync(io::default_io_context(), reader,
+ internal::GetCpuThreadPool(), read_options, parse_options));
+ }
+ {
+ auto reader = std::make_shared<io::BufferReader>(table_buffer);
+ auto read_options = ReadOptions::Defaults();
+ read_options.skip_rows = 20;
+ auto parse_options = ParseOptions::Defaults();
+ ASSERT_FINISHES_OK_AND_EQ(NROWS - 20, CountRowsAsync(io::default_io_context(), reader,
+ internal::GetCpuThreadPool(),
+ read_options, parse_options));
+ }
+ {
+ auto reader = std::make_shared<io::BufferReader>(table_buffer);
+ auto read_options = ReadOptions::Defaults();
+ read_options.autogenerate_column_names = true;
+ auto parse_options = ParseOptions::Defaults();
+ ASSERT_FINISHES_OK_AND_EQ(NROWS + 1, CountRowsAsync(io::default_io_context(), reader,
+ internal::GetCpuThreadPool(),
+ read_options, parse_options));
+ }
+ {
+ auto reader = std::make_shared<io::BufferReader>(table_buffer);
+ auto read_options = ReadOptions::Defaults();
+ read_options.block_size = 1024;
+ auto parse_options = ParseOptions::Defaults();
+ ASSERT_FINISHES_OK_AND_EQ(
+ NROWS, CountRowsAsync(io::default_io_context(), reader,
+ internal::GetCpuThreadPool(), read_options, parse_options));
+ }
+}
+
+TEST(CountRowsAsync, Errors) {
+ ASSERT_OK_AND_ASSIGN(auto table_buffer, MakeSampleCsvBuffer(4096, [](size_t row_num) {
+ return row_num != 2048;
+ }));
+ auto reader = std::make_shared<io::BufferReader>(table_buffer);
+ auto read_options = ReadOptions::Defaults();
+ auto parse_options = ParseOptions::Defaults();
+ ASSERT_FINISHES_AND_RAISES(
+ Invalid, CountRowsAsync(io::default_io_context(), reader,
+ internal::GetCpuThreadPool(), read_options, parse_options));
+}
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/test_common.cc b/src/arrow/cpp/src/arrow/csv/test_common.cc
new file mode 100644
index 000000000..6ba4ff2e3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/test_common.cc
@@ -0,0 +1,121 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/csv/test_common.h"
+
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+namespace csv {
+
+std::string MakeCSVData(std::vector<std::string> lines) {
+ std::string s;
+ for (const auto& line : lines) {
+ s += line;
+ }
+ return s;
+}
+
+void MakeCSVParser(std::vector<std::string> lines, ParseOptions options, int32_t num_cols,
+ std::shared_ptr<BlockParser>* out) {
+ auto csv = MakeCSVData(lines);
+ auto parser = std::make_shared<BlockParser>(options, num_cols);
+ uint32_t out_size;
+ ASSERT_OK(parser->Parse(util::string_view(csv), &out_size));
+ ASSERT_EQ(out_size, csv.size()) << "trailing CSV data not parsed";
+ *out = parser;
+}
+
+void MakeCSVParser(std::vector<std::string> lines, ParseOptions options,
+ std::shared_ptr<BlockParser>* out) {
+ return MakeCSVParser(lines, options, -1, out);
+}
+
+void MakeCSVParser(std::vector<std::string> lines, std::shared_ptr<BlockParser>* out) {
+ MakeCSVParser(lines, ParseOptions::Defaults(), out);
+}
+
+void MakeColumnParser(std::vector<std::string> items, std::shared_ptr<BlockParser>* out) {
+ auto options = ParseOptions::Defaults();
+ // Need this to test for null (empty) values
+ options.ignore_empty_lines = false;
+ std::vector<std::string> lines;
+ for (const auto& item : items) {
+ lines.push_back(item + '\n');
+ }
+ MakeCSVParser(lines, options, 1, out);
+ ASSERT_EQ((*out)->num_cols(), 1) << "Should have seen only 1 CSV column";
+ ASSERT_EQ((*out)->num_rows(), items.size());
+}
+
+namespace {
+
+const std::vector<std::string> int64_rows = {"123", "4", "-317005557", "", "N/A", "0"};
+const std::vector<std::string> float_rows = {"0", "123.456", "-3170.55766", "", "N/A"};
+const std::vector<std::string> decimal128_rows = {"0", "123.456", "-3170.55766",
+ "", "N/A", "1233456789.123456789"};
+const std::vector<std::string> iso8601_rows = {"1917-10-17", "2018-09-13",
+ "1941-06-22 04:00", "1945-05-09 09:45:38"};
+const std::vector<std::string> strptime_rows = {"10/17/1917", "9/13/2018", "9/5/1945"};
+
+static void WriteHeader(std::ostream& writer) {
+ writer << "Int64,Float,Decimal128,ISO8601,Strptime" << std::endl;
+}
+
+static std::string GetCell(const std::vector<std::string>& base_rows, size_t row_index) {
+ return base_rows[row_index % base_rows.size()];
+}
+
+static void WriteRow(std::ostream& writer, size_t row_index) {
+ writer << GetCell(int64_rows, row_index);
+ writer << ',';
+ writer << GetCell(float_rows, row_index);
+ writer << ',';
+ writer << GetCell(decimal128_rows, row_index);
+ writer << ',';
+ writer << GetCell(iso8601_rows, row_index);
+ writer << ',';
+ writer << GetCell(strptime_rows, row_index);
+ writer << std::endl;
+}
+
+static void WriteInvalidRow(std::ostream& writer, size_t row_index) {
+ writer << "\"" << std::endl << "\"";
+ writer << std::endl;
+}
+} // namespace
+
+Result<std::shared_ptr<Buffer>> MakeSampleCsvBuffer(
+ size_t num_rows, std::function<bool(size_t)> is_valid) {
+ std::stringstream writer;
+
+ WriteHeader(writer);
+ for (size_t i = 0; i < num_rows; ++i) {
+ if (!is_valid || is_valid(i)) {
+ WriteRow(writer, i);
+ } else {
+ WriteInvalidRow(writer, i);
+ }
+ }
+
+ auto table_str = writer.str();
+ auto table_buffer = std::make_shared<Buffer>(table_str);
+ return MemoryManager::CopyBuffer(table_buffer, default_cpu_memory_manager());
+}
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/test_common.h b/src/arrow/cpp/src/arrow/csv/test_common.h
new file mode 100644
index 000000000..810a0b472
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/test_common.h
@@ -0,0 +1,55 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/csv/parser.h"
+#include "arrow/testing/visibility.h"
+
+namespace arrow {
+namespace csv {
+
+ARROW_TESTING_EXPORT
+std::string MakeCSVData(std::vector<std::string> lines);
+
+// Make a BlockParser from a vector of lines representing a CSV file
+ARROW_TESTING_EXPORT
+void MakeCSVParser(std::vector<std::string> lines, ParseOptions options, int32_t num_cols,
+ std::shared_ptr<BlockParser>* out);
+
+ARROW_TESTING_EXPORT
+void MakeCSVParser(std::vector<std::string> lines, ParseOptions options,
+ std::shared_ptr<BlockParser>* out);
+
+ARROW_TESTING_EXPORT
+void MakeCSVParser(std::vector<std::string> lines, std::shared_ptr<BlockParser>* out);
+
+// Make a BlockParser from a vector of strings representing a single CSV column
+ARROW_TESTING_EXPORT
+void MakeColumnParser(std::vector<std::string> items, std::shared_ptr<BlockParser>* out);
+
+ARROW_TESTING_EXPORT
+Result<std::shared_ptr<Buffer>> MakeSampleCsvBuffer(
+ size_t num_rows, std::function<bool(size_t row_num)> is_valid = {});
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/type_fwd.h b/src/arrow/cpp/src/arrow/csv/type_fwd.h
new file mode 100644
index 000000000..c0a53847a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/type_fwd.h
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+namespace arrow {
+namespace csv {
+
+class TableReader;
+struct ConvertOptions;
+struct ReadOptions;
+struct ParseOptions;
+struct WriteOptions;
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/writer.cc b/src/arrow/cpp/src/arrow/csv/writer.cc
new file mode 100644
index 000000000..1b782cae7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/writer.cc
@@ -0,0 +1,460 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/csv/writer.h"
+#include "arrow/array.h"
+#include "arrow/compute/cast.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/result_internal.h"
+#include "arrow/stl_allocator.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+namespace csv {
+// This implementation is intentionally light on configurability to minimize the size of
+// the initial PR. Aditional features can be added as there is demand and interest to
+// implement them.
+//
+// The algorithm used here at a high level is to break RecordBatches/Tables into slices
+// and convert each slice independently. A slice is then converted to CSV by first
+// scanning each column to determine the size of its contents when rendered as a string in
+// CSV. For non-string types this requires casting the value to string (which is cached).
+// This data is used to understand the precise length of each row and a single allocation
+// for the final CSV data buffer. Once the final size is known each column is then
+// iterated over again to place its contents into the CSV data buffer. The rationale for
+// choosing this approach is it allows for reuse of the cast functionality in the compute
+// module and inline data visiting functionality in the core library. A performance
+// comparison has not been done using a naive single-pass approach. This approach might
+// still be competitive due to reduction in the number of per row branches necessary with
+// a single pass approach. Profiling would likely yield further opportunities for
+// optimization with this approach.
+
+namespace {
+
+struct SliceIteratorFunctor {
+ Result<std::shared_ptr<RecordBatch>> Next() {
+ if (current_offset < batch->num_rows()) {
+ std::shared_ptr<RecordBatch> next = batch->Slice(current_offset, slice_size);
+ current_offset += slice_size;
+ return next;
+ }
+ return IterationTraits<std::shared_ptr<RecordBatch>>::End();
+ }
+ const RecordBatch* const batch;
+ const int64_t slice_size;
+ int64_t current_offset;
+};
+
+RecordBatchIterator RecordBatchSliceIterator(const RecordBatch& batch,
+ int64_t slice_size) {
+ SliceIteratorFunctor functor = {&batch, slice_size, /*offset=*/static_cast<int64_t>(0)};
+ return RecordBatchIterator(std::move(functor));
+}
+
+// Counts the number of characters that need escaping in s.
+int64_t CountEscapes(util::string_view s) {
+ return static_cast<int64_t>(std::count(s.begin(), s.end(), '"'));
+}
+
+// Matching quote pair character length.
+constexpr int64_t kQuoteCount = 2;
+constexpr int64_t kQuoteDelimiterCount = kQuoteCount + /*end_char*/ 1;
+
+// Interface for generating CSV data per column.
+// The intended usage is to iteratively call UpdateRowLengths for a column and
+// then PopulateColumns. PopulateColumns must be called in the reverse order of the
+// populators (it populates data backwards).
+class ColumnPopulator {
+ public:
+ ColumnPopulator(MemoryPool* pool, char end_char) : end_char_(end_char), pool_(pool) {}
+
+ virtual ~ColumnPopulator() = default;
+
+ // Adds the number of characters each entry in data will add to to elements
+ // in row_lengths.
+ Status UpdateRowLengths(const Array& data, int32_t* row_lengths) {
+ compute::ExecContext ctx(pool_);
+ // Populators are intented to be applied to reasonably small data. In most cases
+ // threading overhead would not be justified.
+ ctx.set_use_threads(false);
+ ASSIGN_OR_RAISE(
+ std::shared_ptr<Array> casted,
+ compute::Cast(data, /*to_type=*/utf8(), compute::CastOptions(), &ctx));
+ casted_array_ = internal::checked_pointer_cast<StringArray>(casted);
+ return UpdateRowLengths(row_lengths);
+ }
+
+ // Places string data onto each row in output and updates the corresponding row
+ // row pointers in preparation for calls to other (preceding) ColumnPopulators.
+ // Args:
+ // output: character buffer to write to.
+ // offsets: an array of end of row column within the the output buffer (values are
+ // one past the end of the position to write to).
+ virtual void PopulateColumns(char* output, int32_t* offsets) const = 0;
+
+ protected:
+ virtual Status UpdateRowLengths(int32_t* row_lengths) = 0;
+ std::shared_ptr<StringArray> casted_array_;
+ const char end_char_;
+
+ private:
+ MemoryPool* const pool_;
+};
+
+// Copies the contents of to out properly escaping any necessary characters.
+// Returns the position prior to last copied character (out_end is decremented).
+char* EscapeReverse(arrow::util::string_view s, char* out_end) {
+ for (const char* val = s.data() + s.length() - 1; val >= s.data(); val--, out_end--) {
+ if (*val == '"') {
+ *out_end = *val;
+ out_end--;
+ }
+ *out_end = *val;
+ }
+ return out_end;
+}
+
+// Populator for non-string types. This populator relies on compute Cast functionality to
+// String if it doesn't exist it will be an error. it also assumes the resulting string
+// from a cast does not require quoting or escaping.
+class UnquotedColumnPopulator : public ColumnPopulator {
+ public:
+ explicit UnquotedColumnPopulator(MemoryPool* memory_pool, char end_char)
+ : ColumnPopulator(memory_pool, end_char) {}
+
+ Status UpdateRowLengths(int32_t* row_lengths) override {
+ for (int x = 0; x < casted_array_->length(); x++) {
+ row_lengths[x] += casted_array_->value_length(x);
+ }
+ return Status::OK();
+ }
+
+ void PopulateColumns(char* output, int32_t* offsets) const override {
+ VisitArrayDataInline<StringType>(
+ *casted_array_->data(),
+ [&](arrow::util::string_view s) {
+ int64_t next_column_offset = s.length() + /*end_char*/ 1;
+ memcpy((output + *offsets - next_column_offset), s.data(), s.length());
+ *(output + *offsets - 1) = end_char_;
+ *offsets -= static_cast<int32_t>(next_column_offset);
+ offsets++;
+ },
+ [&]() {
+ // Nulls are empty (unquoted) to distinguish with empty string.
+ *(output + *offsets - 1) = end_char_;
+ *offsets -= 1;
+ offsets++;
+ });
+ }
+};
+
+// Strings need special handling to ensure they are escaped properly.
+// This class handles escaping assuming that all strings will be quoted
+// and that the only character within the string that needs to escaped is
+// a quote character (") and escaping is done my adding another quote.
+class QuotedColumnPopulator : public ColumnPopulator {
+ public:
+ QuotedColumnPopulator(MemoryPool* pool, char end_char)
+ : ColumnPopulator(pool, end_char) {}
+
+ Status UpdateRowLengths(int32_t* row_lengths) override {
+ const StringArray& input = *casted_array_;
+ int row_number = 0;
+ row_needs_escaping_.resize(casted_array_->length());
+ VisitArrayDataInline<StringType>(
+ *input.data(),
+ [&](arrow::util::string_view s) {
+ int64_t escaped_count = CountEscapes(s);
+ // TODO: Maybe use 64 bit row lengths or safe cast?
+ row_needs_escaping_[row_number] = escaped_count > 0;
+ row_lengths[row_number] += static_cast<int32_t>(s.length()) +
+ static_cast<int32_t>(escaped_count + kQuoteCount);
+ row_number++;
+ },
+ [&]() {
+ row_needs_escaping_[row_number] = false;
+ row_number++;
+ });
+ return Status::OK();
+ }
+
+ void PopulateColumns(char* output, int32_t* offsets) const override {
+ auto needs_escaping = row_needs_escaping_.begin();
+ VisitArrayDataInline<StringType>(
+ *(casted_array_->data()),
+ [&](arrow::util::string_view s) {
+ // still needs string content length to be added
+ char* row_end = output + *offsets;
+ int32_t next_column_offset = 0;
+ if (!*needs_escaping) {
+ next_column_offset = static_cast<int32_t>(s.length() + kQuoteDelimiterCount);
+ memcpy(row_end - next_column_offset + /*quote_offset=*/1, s.data(),
+ s.length());
+ } else {
+ // Adjust row_end by 3: 1 quote char, 1 end char and 1 to position at the
+ // first position to write to.
+ next_column_offset =
+ static_cast<int32_t>(row_end - EscapeReverse(s, row_end - 3));
+ }
+ *(row_end - next_column_offset) = '"';
+ *(row_end - 2) = '"';
+ *(row_end - 1) = end_char_;
+ *offsets -= next_column_offset;
+ offsets++;
+ needs_escaping++;
+ },
+ [&]() {
+ // Nulls are empty (unquoted) to distinguish with empty string.
+ *(output + *offsets - 1) = end_char_;
+ *offsets -= 1;
+ offsets++;
+ needs_escaping++;
+ });
+ }
+
+ private:
+ // Older version of GCC don't support custom allocators
+ // at some point we should change this to use memory_pool
+ // backed allocator.
+ std::vector<bool> row_needs_escaping_;
+};
+
+struct PopulatorFactory {
+ template <typename TypeClass>
+ enable_if_t<is_base_binary_type<TypeClass>::value ||
+ std::is_same<FixedSizeBinaryType, TypeClass>::value,
+ Status>
+ Visit(const TypeClass& type) {
+ populator = new QuotedColumnPopulator(pool, end_char);
+ return Status::OK();
+ }
+
+ template <typename TypeClass>
+ enable_if_dictionary<TypeClass, Status> Visit(const TypeClass& type) {
+ return VisitTypeInline(*type.value_type(), this);
+ }
+
+ template <typename TypeClass>
+ enable_if_t<is_nested_type<TypeClass>::value || is_extension_type<TypeClass>::value,
+ Status>
+ Visit(const TypeClass& type) {
+ return Status::Invalid("Unsupported Type:", type.ToString());
+ }
+
+ template <typename TypeClass>
+ enable_if_t<is_primitive_ctype<TypeClass>::value || is_decimal_type<TypeClass>::value ||
+ is_null_type<TypeClass>::value || is_temporal_type<TypeClass>::value,
+ Status>
+ Visit(const TypeClass& type) {
+ populator = new UnquotedColumnPopulator(pool, end_char);
+ return Status::OK();
+ }
+
+ char end_char;
+ MemoryPool* pool;
+ ColumnPopulator* populator;
+};
+
+Result<std::unique_ptr<ColumnPopulator>> MakePopulator(const Field& field, char end_char,
+ MemoryPool* pool) {
+ PopulatorFactory factory{end_char, pool, nullptr};
+ RETURN_NOT_OK(VisitTypeInline(*field.type(), &factory));
+ return std::unique_ptr<ColumnPopulator>(factory.populator);
+}
+
+class CSVWriterImpl : public ipc::RecordBatchWriter {
+ public:
+ static Result<std::shared_ptr<CSVWriterImpl>> Make(
+ io::OutputStream* sink, std::shared_ptr<io::OutputStream> owned_sink,
+ std::shared_ptr<Schema> schema, const WriteOptions& options) {
+ RETURN_NOT_OK(options.Validate());
+ std::vector<std::unique_ptr<ColumnPopulator>> populators(schema->num_fields());
+ for (int col = 0; col < schema->num_fields(); col++) {
+ char end_char = col < schema->num_fields() - 1 ? ',' : '\n';
+ ASSIGN_OR_RAISE(populators[col], MakePopulator(*schema->field(col), end_char,
+ options.io_context.pool()));
+ }
+ auto writer = std::make_shared<CSVWriterImpl>(
+ sink, std::move(owned_sink), std::move(schema), std::move(populators), options);
+ RETURN_NOT_OK(writer->PrepareForContentsWrite());
+ if (options.include_header) {
+ RETURN_NOT_OK(writer->WriteHeader());
+ }
+ return writer;
+ }
+
+ Status WriteRecordBatch(const RecordBatch& batch) override {
+ RecordBatchIterator iterator = RecordBatchSliceIterator(batch, options_.batch_size);
+ for (auto maybe_slice : iterator) {
+ ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> slice, maybe_slice);
+ RETURN_NOT_OK(TranslateMinimalBatch(*slice));
+ RETURN_NOT_OK(sink_->Write(data_buffer_));
+ stats_.num_record_batches++;
+ }
+ return Status::OK();
+ }
+
+ Status WriteTable(const Table& table, int64_t max_chunksize) override {
+ TableBatchReader reader(table);
+ reader.set_chunksize(max_chunksize > 0 ? max_chunksize : options_.batch_size);
+ std::shared_ptr<RecordBatch> batch;
+ RETURN_NOT_OK(reader.ReadNext(&batch));
+ while (batch != nullptr) {
+ RETURN_NOT_OK(TranslateMinimalBatch(*batch));
+ RETURN_NOT_OK(sink_->Write(data_buffer_));
+ RETURN_NOT_OK(reader.ReadNext(&batch));
+ stats_.num_record_batches++;
+ }
+
+ return Status::OK();
+ }
+
+ Status Close() override { return Status::OK(); }
+
+ ipc::WriteStats stats() const override { return stats_; }
+
+ CSVWriterImpl(io::OutputStream* sink, std::shared_ptr<io::OutputStream> owned_sink,
+ std::shared_ptr<Schema> schema,
+ std::vector<std::unique_ptr<ColumnPopulator>> populators,
+ const WriteOptions& options)
+ : sink_(sink),
+ owned_sink_(std::move(owned_sink)),
+ column_populators_(std::move(populators)),
+ offsets_(0, 0, ::arrow::stl::allocator<char*>(options.io_context.pool())),
+ schema_(std::move(schema)),
+ options_(options) {}
+
+ private:
+ Status PrepareForContentsWrite() {
+ // Only called once, as part of initialization
+ if (data_buffer_ == nullptr) {
+ ASSIGN_OR_RAISE(data_buffer_,
+ AllocateResizableBuffer(
+ options_.batch_size * schema_->num_fields() * kColumnSizeGuess,
+ options_.io_context.pool()));
+ }
+ return Status::OK();
+ }
+
+ int64_t CalculateHeaderSize() const {
+ int64_t header_length = 0;
+ for (int col = 0; col < schema_->num_fields(); col++) {
+ const std::string& col_name = schema_->field(col)->name();
+ header_length += col_name.size();
+ header_length += CountEscapes(col_name);
+ }
+ return header_length + (kQuoteDelimiterCount * schema_->num_fields());
+ }
+
+ Status WriteHeader() {
+ // Only called once, as part of initialization
+ RETURN_NOT_OK(data_buffer_->Resize(CalculateHeaderSize(), /*shrink_to_fit=*/false));
+ char* next =
+ reinterpret_cast<char*>(data_buffer_->mutable_data() + data_buffer_->size() - 1);
+ for (int col = schema_->num_fields() - 1; col >= 0; col--) {
+ *next-- = ',';
+ *next-- = '"';
+ next = EscapeReverse(schema_->field(col)->name(), next);
+ *next-- = '"';
+ }
+ *(data_buffer_->mutable_data() + data_buffer_->size() - 1) = '\n';
+ DCHECK_EQ(reinterpret_cast<uint8_t*>(next + 1), data_buffer_->data());
+ return sink_->Write(data_buffer_);
+ }
+
+ Status TranslateMinimalBatch(const RecordBatch& batch) {
+ if (batch.num_rows() == 0) {
+ return Status::OK();
+ }
+ offsets_.resize(batch.num_rows());
+ std::fill(offsets_.begin(), offsets_.end(), 0);
+
+ // Calculate relative offsets for each row (excluding delimiters)
+ for (int32_t col = 0; col < static_cast<int32_t>(column_populators_.size()); col++) {
+ RETURN_NOT_OK(
+ column_populators_[col]->UpdateRowLengths(*batch.column(col), offsets_.data()));
+ }
+ // Calculate cumulalative offsets for each row (including delimiters).
+ offsets_[0] += batch.num_columns();
+ for (int64_t row = 1; row < batch.num_rows(); row++) {
+ offsets_[row] += offsets_[row - 1] + /*delimiter lengths*/ batch.num_columns();
+ }
+ // Resize the target buffer to required size. We assume batch to batch sizes
+ // should be pretty close so don't shrink the buffer to avoid allocation churn.
+ RETURN_NOT_OK(data_buffer_->Resize(offsets_.back(), /*shrink_to_fit=*/false));
+
+ // Use the offsets to populate contents.
+ for (auto populator = column_populators_.rbegin();
+ populator != column_populators_.rend(); populator++) {
+ (*populator)
+ ->PopulateColumns(reinterpret_cast<char*>(data_buffer_->mutable_data()),
+ offsets_.data());
+ }
+ DCHECK_EQ(0, offsets_[0]);
+ return Status::OK();
+ }
+
+ static constexpr int64_t kColumnSizeGuess = 8;
+ io::OutputStream* sink_;
+ std::shared_ptr<io::OutputStream> owned_sink_;
+ std::vector<std::unique_ptr<ColumnPopulator>> column_populators_;
+ std::vector<int32_t, arrow::stl::allocator<int32_t>> offsets_;
+ std::shared_ptr<ResizableBuffer> data_buffer_;
+ const std::shared_ptr<Schema> schema_;
+ const WriteOptions options_;
+ ipc::WriteStats stats_;
+};
+
+} // namespace
+
+Status WriteCSV(const Table& table, const WriteOptions& options,
+ arrow::io::OutputStream* output) {
+ ASSIGN_OR_RAISE(auto writer, MakeCSVWriter(output, table.schema(), options));
+ RETURN_NOT_OK(writer->WriteTable(table));
+ return writer->Close();
+}
+
+Status WriteCSV(const RecordBatch& batch, const WriteOptions& options,
+ arrow::io::OutputStream* output) {
+ ASSIGN_OR_RAISE(auto writer, MakeCSVWriter(output, batch.schema(), options));
+ RETURN_NOT_OK(writer->WriteRecordBatch(batch));
+ return writer->Close();
+}
+
+ARROW_EXPORT
+Result<std::shared_ptr<ipc::RecordBatchWriter>> MakeCSVWriter(
+ std::shared_ptr<io::OutputStream> sink, const std::shared_ptr<Schema>& schema,
+ const WriteOptions& options) {
+ return CSVWriterImpl::Make(sink.get(), sink, schema, options);
+}
+
+ARROW_EXPORT
+Result<std::shared_ptr<ipc::RecordBatchWriter>> MakeCSVWriter(
+ io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
+ const WriteOptions& options) {
+ return CSVWriterImpl::Make(sink, nullptr, schema, options);
+}
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/writer.h b/src/arrow/cpp/src/arrow/csv/writer.h
new file mode 100644
index 000000000..2f1442ae0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/writer.h
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/csv/options.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/ipc/type_fwd.h"
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+
+namespace arrow {
+namespace csv {
+// Functionality for converting Arrow data to Comma separated value text.
+// This library supports all primitive types that can be cast to a StringArrays.
+// It applies to following formatting rules:
+// - For non-binary types no quotes surround values. Nulls are represented as the empty
+// string.
+// - For binary types all non-null data is quoted (and quotes within data are escaped
+// with an additional quote).
+// Null values are empty and unquoted.
+// - LF (\n) is always used as a line ending.
+
+/// \brief Converts table to a CSV and writes the results to output.
+/// Experimental
+ARROW_EXPORT Status WriteCSV(const Table& table, const WriteOptions& options,
+ arrow::io::OutputStream* output);
+/// \brief Converts batch to CSV and writes the results to output.
+/// Experimental
+ARROW_EXPORT Status WriteCSV(const RecordBatch& batch, const WriteOptions& options,
+ arrow::io::OutputStream* output);
+
+/// \brief Create a new CSV writer. User is responsible for closing the
+/// actual OutputStream.
+///
+/// \param[in] sink output stream to write to
+/// \param[in] schema the schema of the record batches to be written
+/// \param[in] options options for serialization
+/// \return Result<std::shared_ptr<RecordBatchWriter>>
+ARROW_EXPORT
+Result<std::shared_ptr<ipc::RecordBatchWriter>> MakeCSVWriter(
+ std::shared_ptr<io::OutputStream> sink, const std::shared_ptr<Schema>& schema,
+ const WriteOptions& options = WriteOptions::Defaults());
+
+/// \brief Create a new CSV writer.
+///
+/// \param[in] sink output stream to write to (does not take ownership)
+/// \param[in] schema the schema of the record batches to be written
+/// \param[in] options options for serialization
+/// \return Result<std::shared_ptr<RecordBatchWriter>>
+ARROW_EXPORT
+Result<std::shared_ptr<ipc::RecordBatchWriter>> MakeCSVWriter(
+ io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
+ const WriteOptions& options = WriteOptions::Defaults());
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/csv/writer_test.cc b/src/arrow/cpp/src/arrow/csv/writer_test.cc
new file mode 100644
index 000000000..57b42c7f5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/csv/writer_test.cc
@@ -0,0 +1,159 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gtest/gtest.h"
+
+#include <memory>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/csv/writer.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/result_internal.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+
+namespace arrow {
+namespace csv {
+
+struct WriterTestParams {
+ std::shared_ptr<Schema> schema;
+ std::string batch_data;
+ WriteOptions options;
+ std::string expected_output;
+};
+
+// Avoid Valgrind failures with GTest trying to represent a WriterTestParams
+void PrintTo(const WriterTestParams& p, std::ostream* os) {
+ *os << "WriterTestParams(" << reinterpret_cast<const void*>(&p) << ")";
+}
+
+WriteOptions DefaultTestOptions(bool include_header) {
+ WriteOptions options;
+ options.batch_size = 5;
+ options.include_header = include_header;
+ return options;
+}
+
+std::vector<WriterTestParams> GenerateTestCases() {
+ auto abc_schema = schema({
+ {field("a", uint64())},
+ {field("b\"", utf8())},
+ {field("c ", int32())},
+ {field("d", date32())},
+ {field("e", date64())},
+ });
+ auto populated_batch = R"([{"a": 1, "c ": -1},
+ { "a": 1, "b\"": "abc\"efg", "c ": 2324},
+ { "b\"": "abcd", "c ": 5467},
+ { },
+ { "a": 546, "b\"": "", "c ": 517 },
+ { "a": 124, "b\"": "a\"\"b\"" },
+ { "d": 0 },
+ { "e": 86400000 }])";
+ std::string expected_without_header = std::string("1,,-1,,") + "\n" + // line 1
+ R"(1,"abc""efg",2324,,)" + "\n" + // line 2
+ R"(,"abcd",5467,,)" + "\n" + // line 3
+ R"(,,,,)" + "\n" + // line 4
+ R"(546,"",517,,)" + "\n" + // line 5
+ R"(124,"a""""b""",,,)" + "\n" + // line 6
+ R"(,,,1970-01-01,)" + "\n" + // line 7
+ R"(,,,,1970-01-02)" + "\n"; // line 8
+ std::string expected_header = std::string(R"("a","b""","c ","d","e")") + "\n";
+
+ return std::vector<WriterTestParams>{
+ {abc_schema, "[]", DefaultTestOptions(/*header=*/false), ""},
+ {abc_schema, "[]", DefaultTestOptions(/*header=*/true), expected_header},
+ {abc_schema, populated_batch, DefaultTestOptions(/*header=*/false),
+ expected_without_header},
+ {abc_schema, populated_batch, DefaultTestOptions(/*header=*/true),
+ expected_header + expected_without_header}};
+}
+
+class TestWriteCSV : public ::testing::TestWithParam<WriterTestParams> {
+ protected:
+ template <typename Data>
+ Result<std::string> ToCsvString(const Data& data, const WriteOptions& options) {
+ std::shared_ptr<io::BufferOutputStream> out;
+ ASSIGN_OR_RAISE(out, io::BufferOutputStream::Create());
+
+ RETURN_NOT_OK(WriteCSV(data, options, out.get()));
+ ASSIGN_OR_RAISE(std::shared_ptr<Buffer> buffer, out->Finish());
+ return std::string(reinterpret_cast<const char*>(buffer->data()), buffer->size());
+ }
+
+ Result<std::string> ToCsvStringUsingWriter(const Table& data,
+ const WriteOptions& options) {
+ std::shared_ptr<io::BufferOutputStream> out;
+ ASSIGN_OR_RAISE(out, io::BufferOutputStream::Create());
+ // Write row-by-row
+ ASSIGN_OR_RAISE(auto writer, MakeCSVWriter(out, data.schema(), options));
+ TableBatchReader reader(data);
+ reader.set_chunksize(1);
+ std::shared_ptr<RecordBatch> batch;
+ RETURN_NOT_OK(reader.ReadNext(&batch));
+ while (batch != nullptr) {
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+ RETURN_NOT_OK(reader.ReadNext(&batch));
+ }
+ RETURN_NOT_OK(writer->Close());
+ EXPECT_EQ(data.num_rows(), writer->stats().num_record_batches);
+ ASSIGN_OR_RAISE(std::shared_ptr<Buffer> buffer, out->Finish());
+ return std::string(reinterpret_cast<const char*>(buffer->data()), buffer->size());
+ }
+};
+
+TEST_P(TestWriteCSV, TestWrite) {
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<io::BufferOutputStream> out,
+ io::BufferOutputStream::Create());
+ WriteOptions options = GetParam().options;
+ std::string csv;
+ auto record_batch = RecordBatchFromJSON(GetParam().schema, GetParam().batch_data);
+ ASSERT_OK_AND_ASSIGN(csv, ToCsvString(*record_batch, options));
+ EXPECT_EQ(csv, GetParam().expected_output);
+
+ // Batch size shouldn't matter.
+ options.batch_size /= 2;
+ ASSERT_OK_AND_ASSIGN(csv, ToCsvString(*record_batch, options));
+ EXPECT_EQ(csv, GetParam().expected_output);
+
+ // Table and Record batch should work identically.
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Table> table,
+ Table::FromRecordBatches({record_batch}));
+ ASSERT_OK_AND_ASSIGN(csv, ToCsvString(*table, options));
+ EXPECT_EQ(csv, GetParam().expected_output);
+
+ // The writer should work identically.
+ ASSERT_OK_AND_ASSIGN(csv, ToCsvStringUsingWriter(*table, options));
+ EXPECT_EQ(csv, GetParam().expected_output);
+}
+
+INSTANTIATE_TEST_SUITE_P(MultiColumnWriteCSVTest, TestWriteCSV,
+ ::testing::ValuesIn(GenerateTestCases()));
+
+INSTANTIATE_TEST_SUITE_P(SingleColumnWriteCSVTest, TestWriteCSV,
+ ::testing::Values(WriterTestParams{
+ schema({field("int64", int64())}),
+ R"([{ "int64": 9999}, {}, { "int64": -15}])", WriteOptions(),
+ R"("int64")"
+ "\n9999\n\n-15\n"}));
+
+} // namespace csv
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/ArrowDatasetConfig.cmake.in b/src/arrow/cpp/src/arrow/dataset/ArrowDatasetConfig.cmake.in
new file mode 100644
index 000000000..ee732cfd5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/ArrowDatasetConfig.cmake.in
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This config sets the following variables in your project::
+#
+# ArrowDataset_FOUND - true if Arrow Dataset found on the system
+#
+# This config sets the following targets in your project::
+#
+# arrow_dataset_shared - for linked as shared library if shared library is built
+# arrow_dataset_static - for linked as static library if static library is built
+
+@PACKAGE_INIT@
+
+include(CMakeFindDependencyMacro)
+find_dependency(Arrow)
+find_dependency(Parquet)
+
+# Load targets only once. If we load targets multiple times, CMake reports
+# already existent target error.
+if(NOT (TARGET arrow_dataset_shared OR TARGET arrow_dataset_static))
+ include("${CMAKE_CURRENT_LIST_DIR}/ArrowDatasetTargets.cmake")
+endif()
diff --git a/src/arrow/cpp/src/arrow/dataset/CMakeLists.txt b/src/arrow/cpp/src/arrow/dataset/CMakeLists.txt
new file mode 100644
index 000000000..6aa4794a3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/CMakeLists.txt
@@ -0,0 +1,145 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+add_custom_target(arrow_dataset)
+
+arrow_install_all_headers("arrow/dataset")
+
+set(ARROW_DATASET_SRCS
+ dataset.cc
+ dataset_writer.cc
+ discovery.cc
+ file_base.cc
+ file_ipc.cc
+ partition.cc
+ plan.cc
+ projector.cc
+ scanner.cc)
+
+set(ARROW_DATASET_LINK_STATIC arrow_static)
+set(ARROW_DATASET_LINK_SHARED arrow_shared)
+
+if(ARROW_CSV)
+ set(ARROW_DATASET_SRCS ${ARROW_DATASET_SRCS} file_csv.cc)
+endif()
+
+if(ARROW_ORC)
+ set(ARROW_DATASET_SRCS ${ARROW_DATASET_SRCS} file_orc.cc)
+endif()
+
+if(ARROW_PARQUET)
+ set(ARROW_DATASET_LINK_STATIC ${ARROW_DATASET_LINK_STATIC} parquet_static)
+ set(ARROW_DATASET_LINK_SHARED ${ARROW_DATASET_LINK_SHARED} parquet_shared)
+ set(ARROW_DATASET_SRCS ${ARROW_DATASET_SRCS} file_parquet.cc)
+ set(ARROW_DATASET_PRIVATE_INCLUDES ${PROJECT_SOURCE_DIR}/src/parquet)
+endif()
+
+add_arrow_lib(arrow_dataset
+ CMAKE_PACKAGE_NAME
+ ArrowDataset
+ PKG_CONFIG_NAME
+ arrow-dataset
+ OUTPUTS
+ ARROW_DATASET_LIBRARIES
+ SOURCES
+ ${ARROW_DATASET_SRCS}
+ PRECOMPILED_HEADERS
+ "$<$<COMPILE_LANGUAGE:CXX>:arrow/dataset/pch.h>"
+ DEPENDENCIES
+ toolchain
+ PRIVATE_INCLUDES
+ ${ARROW_DATASET_PRIVATE_INCLUDES}
+ SHARED_LINK_LIBS
+ ${ARROW_DATASET_LINK_SHARED}
+ STATIC_LINK_LIBS
+ ${ARROW_DATASET_LINK_STATIC})
+
+if(ARROW_TEST_LINKAGE STREQUAL "static")
+ set(ARROW_DATASET_TEST_LINK_LIBS arrow_dataset_static ${ARROW_TEST_STATIC_LINK_LIBS})
+else()
+ set(ARROW_DATASET_TEST_LINK_LIBS arrow_dataset_shared ${ARROW_TEST_SHARED_LINK_LIBS})
+endif()
+
+foreach(LIB_TARGET ${ARROW_DATASET_LIBRARIES})
+ target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_DS_EXPORTING)
+endforeach()
+
+# Adding unit tests part of the "dataset" portion of the test suite
+function(ADD_ARROW_DATASET_TEST REL_TEST_NAME)
+ set(options)
+ set(one_value_args PREFIX)
+ set(multi_value_args LABELS)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+
+ if(ARG_PREFIX)
+ set(PREFIX ${ARG_PREFIX})
+ else()
+ set(PREFIX "arrow-dataset")
+ endif()
+
+ if(ARG_LABELS)
+ set(LABELS ${ARG_LABELS})
+ else()
+ set(LABELS "arrow_dataset")
+ endif()
+
+ add_arrow_test(${REL_TEST_NAME}
+ EXTRA_LINK_LIBS
+ ${ARROW_DATASET_TEST_LINK_LIBS}
+ PREFIX
+ ${PREFIX}
+ LABELS
+ ${LABELS}
+ ${ARG_UNPARSED_ARGUMENTS})
+endfunction()
+
+add_arrow_dataset_test(dataset_test)
+add_arrow_dataset_test(dataset_writer_test)
+add_arrow_dataset_test(discovery_test)
+add_arrow_dataset_test(file_ipc_test)
+add_arrow_dataset_test(file_test)
+add_arrow_dataset_test(partition_test)
+add_arrow_dataset_test(scanner_test)
+
+if(ARROW_CSV)
+ add_arrow_dataset_test(file_csv_test)
+endif()
+
+if(ARROW_ORC)
+ add_arrow_dataset_test(file_orc_test)
+endif()
+
+if(ARROW_PARQUET)
+ add_arrow_dataset_test(file_parquet_test)
+endif()
+
+if(ARROW_BUILD_BENCHMARKS)
+ add_arrow_benchmark(file_benchmark PREFIX "arrow-dataset")
+ add_arrow_benchmark(scanner_benchmark PREFIX "arrow-dataset")
+
+ if(ARROW_BUILD_STATIC)
+ target_link_libraries(arrow-dataset-file-benchmark PUBLIC arrow_dataset_static)
+ target_link_libraries(arrow-dataset-scanner-benchmark PUBLIC arrow_dataset_static)
+ else()
+ target_link_libraries(arrow-dataset-file-benchmark PUBLIC arrow_dataset_shared)
+ target_link_libraries(arrow-dataset-scanner-benchmark PUBLIC arrow_dataset_shared)
+ endif()
+endif()
diff --git a/src/arrow/cpp/src/arrow/dataset/README.md b/src/arrow/cpp/src/arrow/dataset/README.md
new file mode 100644
index 000000000..225f38a5c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/README.md
@@ -0,0 +1,32 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# Arrow C++ Dataset
+
+The `arrow::dataset` subcomponent provides an API to read and write
+semantic datasets stored in different locations and formats. It
+facilitates parallel processing of datasets spread across different
+physical files and serialization formats. Other concerns such as
+partitioning, filtering (partition- and column-level), and schema
+normalization are also addressed.
+
+## Development Status
+
+Alpha/beta stage as of April 2020. API subject to change, possibly
+without deprecation notices.
diff --git a/src/arrow/cpp/src/arrow/dataset/api.h b/src/arrow/cpp/src/arrow/dataset/api.h
new file mode 100644
index 000000000..8b81f4c15
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/api.h
@@ -0,0 +1,30 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include "arrow/compute/exec/expression.h"
+#include "arrow/dataset/dataset.h"
+#include "arrow/dataset/discovery.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/file_csv.h"
+#include "arrow/dataset/file_ipc.h"
+#include "arrow/dataset/file_orc.h"
+#include "arrow/dataset/file_parquet.h"
+#include "arrow/dataset/scanner.h"
diff --git a/src/arrow/cpp/src/arrow/dataset/arrow-dataset.pc.in b/src/arrow/cpp/src/arrow/dataset/arrow-dataset.pc.in
new file mode 100644
index 000000000..c03aad378
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/arrow-dataset.pc.in
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+Name: Apache Arrow Dataset
+Description: Apache Arrow Dataset provides an API to read and write semantic datasets stored in different locations and formats.
+Version: @ARROW_VERSION@
+Requires: arrow parquet
+Libs: -L${libdir} -larrow_dataset
diff --git a/src/arrow/cpp/src/arrow/dataset/dataset.cc b/src/arrow/cpp/src/arrow/dataset/dataset.cc
new file mode 100644
index 000000000..680713a7b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/dataset.cc
@@ -0,0 +1,269 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/dataset.h"
+
+#include <memory>
+#include <utility>
+
+#include "arrow/dataset/dataset_internal.h"
+#include "arrow/dataset/scanner.h"
+#include "arrow/table.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+
+namespace arrow {
+
+using internal::checked_pointer_cast;
+
+namespace dataset {
+
+Fragment::Fragment(compute::Expression partition_expression,
+ std::shared_ptr<Schema> physical_schema)
+ : partition_expression_(std::move(partition_expression)),
+ physical_schema_(std::move(physical_schema)) {}
+
+Result<std::shared_ptr<Schema>> Fragment::ReadPhysicalSchema() {
+ {
+ auto lock = physical_schema_mutex_.Lock();
+ if (physical_schema_ != nullptr) return physical_schema_;
+ }
+
+ // allow ReadPhysicalSchemaImpl to lock mutex_, if necessary
+ ARROW_ASSIGN_OR_RAISE(auto physical_schema, ReadPhysicalSchemaImpl());
+
+ auto lock = physical_schema_mutex_.Lock();
+ if (physical_schema_ == nullptr) {
+ physical_schema_ = std::move(physical_schema);
+ }
+ return physical_schema_;
+}
+
+Future<util::optional<int64_t>> Fragment::CountRows(compute::Expression,
+ const std::shared_ptr<ScanOptions>&) {
+ return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
+}
+
+Result<std::shared_ptr<Schema>> InMemoryFragment::ReadPhysicalSchemaImpl() {
+ return physical_schema_;
+}
+
+InMemoryFragment::InMemoryFragment(std::shared_ptr<Schema> schema,
+ RecordBatchVector record_batches,
+ compute::Expression partition_expression)
+ : Fragment(std::move(partition_expression), std::move(schema)),
+ record_batches_(std::move(record_batches)) {
+ DCHECK_NE(physical_schema_, nullptr);
+}
+
+InMemoryFragment::InMemoryFragment(RecordBatchVector record_batches,
+ compute::Expression partition_expression)
+ : Fragment(std::move(partition_expression), /*schema=*/nullptr),
+ record_batches_(std::move(record_batches)) {
+ // Order of argument evaluation is undefined, so compute physical_schema here
+ physical_schema_ = record_batches_.empty() ? schema({}) : record_batches_[0]->schema();
+}
+
+Result<ScanTaskIterator> InMemoryFragment::Scan(std::shared_ptr<ScanOptions> options) {
+ // Make an explicit copy of record_batches_ to ensure Scan can be called
+ // multiple times.
+ auto batches_it = MakeVectorIterator(record_batches_);
+
+ auto batch_size = options->batch_size;
+ // RecordBatch -> ScanTask
+ auto self = shared_from_this();
+ auto fn = [=](std::shared_ptr<RecordBatch> batch) -> std::shared_ptr<ScanTask> {
+ RecordBatchVector batches;
+
+ auto n_batches = BitUtil::CeilDiv(batch->num_rows(), batch_size);
+ for (int i = 0; i < n_batches; i++) {
+ batches.push_back(batch->Slice(batch_size * i, batch_size));
+ }
+
+ return ::arrow::internal::make_unique<InMemoryScanTask>(std::move(batches),
+ std::move(options), self);
+ };
+
+ return MakeMapIterator(fn, std::move(batches_it));
+}
+
+Result<RecordBatchGenerator> InMemoryFragment::ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options) {
+ struct State {
+ State(std::shared_ptr<InMemoryFragment> fragment, int64_t batch_size)
+ : fragment(std::move(fragment)),
+ batch_index(0),
+ offset(0),
+ batch_size(batch_size) {}
+
+ std::shared_ptr<RecordBatch> Next() {
+ const auto& next_parent = fragment->record_batches_[batch_index];
+ if (offset < next_parent->num_rows()) {
+ auto next = next_parent->Slice(offset, batch_size);
+ offset += batch_size;
+ return next;
+ }
+ batch_index++;
+ offset = 0;
+ return nullptr;
+ }
+
+ bool Finished() { return batch_index >= fragment->record_batches_.size(); }
+
+ std::shared_ptr<InMemoryFragment> fragment;
+ std::size_t batch_index;
+ int64_t offset;
+ int64_t batch_size;
+ };
+
+ struct Generator {
+ Generator(std::shared_ptr<InMemoryFragment> fragment, int64_t batch_size)
+ : state(std::make_shared<State>(std::move(fragment), batch_size)) {}
+
+ Future<std::shared_ptr<RecordBatch>> operator()() {
+ while (!state->Finished()) {
+ auto next = state->Next();
+ if (next) {
+ return Future<std::shared_ptr<RecordBatch>>::MakeFinished(std::move(next));
+ }
+ }
+ return AsyncGeneratorEnd<std::shared_ptr<RecordBatch>>();
+ }
+
+ std::shared_ptr<State> state;
+ };
+ return Generator(checked_pointer_cast<InMemoryFragment>(shared_from_this()),
+ options->batch_size);
+}
+
+Future<util::optional<int64_t>> InMemoryFragment::CountRows(
+ compute::Expression predicate, const std::shared_ptr<ScanOptions>& options) {
+ if (ExpressionHasFieldRefs(predicate)) {
+ return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
+ }
+ int64_t total = 0;
+ for (const auto& batch : record_batches_) {
+ total += batch->num_rows();
+ }
+ return Future<util::optional<int64_t>>::MakeFinished(total);
+}
+
+Dataset::Dataset(std::shared_ptr<Schema> schema, compute::Expression partition_expression)
+ : schema_(std::move(schema)),
+ partition_expression_(std::move(partition_expression)) {}
+
+Result<std::shared_ptr<ScannerBuilder>> Dataset::NewScan() {
+ return std::make_shared<ScannerBuilder>(this->shared_from_this());
+}
+
+Result<FragmentIterator> Dataset::GetFragments() {
+ return GetFragments(compute::literal(true));
+}
+
+Result<FragmentIterator> Dataset::GetFragments(compute::Expression predicate) {
+ ARROW_ASSIGN_OR_RAISE(
+ predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_));
+ return predicate.IsSatisfiable() ? GetFragmentsImpl(std::move(predicate))
+ : MakeEmptyIterator<std::shared_ptr<Fragment>>();
+}
+
+struct VectorRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator {
+ explicit VectorRecordBatchGenerator(RecordBatchVector batches)
+ : batches_(std::move(batches)) {}
+
+ RecordBatchIterator Get() const final { return MakeVectorIterator(batches_); }
+
+ RecordBatchVector batches_;
+};
+
+InMemoryDataset::InMemoryDataset(std::shared_ptr<Schema> schema,
+ RecordBatchVector batches)
+ : Dataset(std::move(schema)),
+ get_batches_(new VectorRecordBatchGenerator(std::move(batches))) {}
+
+struct TableRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator {
+ explicit TableRecordBatchGenerator(std::shared_ptr<Table> table)
+ : table_(std::move(table)) {}
+
+ RecordBatchIterator Get() const final {
+ auto reader = std::make_shared<TableBatchReader>(*table_);
+ auto table = table_;
+ return MakeFunctionIterator([reader, table] { return reader->Next(); });
+ }
+
+ std::shared_ptr<Table> table_;
+};
+
+InMemoryDataset::InMemoryDataset(std::shared_ptr<Table> table)
+ : Dataset(table->schema()),
+ get_batches_(new TableRecordBatchGenerator(std::move(table))) {}
+
+Result<std::shared_ptr<Dataset>> InMemoryDataset::ReplaceSchema(
+ std::shared_ptr<Schema> schema) const {
+ RETURN_NOT_OK(CheckProjectable(*schema_, *schema));
+ return std::make_shared<InMemoryDataset>(std::move(schema), get_batches_);
+}
+
+Result<FragmentIterator> InMemoryDataset::GetFragmentsImpl(compute::Expression) {
+ auto schema = this->schema();
+
+ auto create_fragment =
+ [schema](std::shared_ptr<RecordBatch> batch) -> Result<std::shared_ptr<Fragment>> {
+ if (!batch->schema()->Equals(schema)) {
+ return Status::TypeError("yielded batch had schema ", *batch->schema(),
+ " which did not match InMemorySource's: ", *schema);
+ }
+
+ return std::make_shared<InMemoryFragment>(RecordBatchVector{std::move(batch)});
+ };
+
+ auto batches_it = get_batches_->Get();
+ return MakeMaybeMapIterator(std::move(create_fragment), std::move(batches_it));
+}
+
+Result<std::shared_ptr<UnionDataset>> UnionDataset::Make(std::shared_ptr<Schema> schema,
+ DatasetVector children) {
+ for (const auto& child : children) {
+ if (!child->schema()->Equals(*schema)) {
+ return Status::TypeError("child Dataset had schema ", *child->schema(),
+ " but the union schema was ", *schema);
+ }
+ }
+
+ return std::shared_ptr<UnionDataset>(
+ new UnionDataset(std::move(schema), std::move(children)));
+}
+
+Result<std::shared_ptr<Dataset>> UnionDataset::ReplaceSchema(
+ std::shared_ptr<Schema> schema) const {
+ auto children = children_;
+ for (auto& child : children) {
+ ARROW_ASSIGN_OR_RAISE(child, child->ReplaceSchema(schema));
+ }
+
+ return std::shared_ptr<Dataset>(
+ new UnionDataset(std::move(schema), std::move(children)));
+}
+
+Result<FragmentIterator> UnionDataset::GetFragmentsImpl(compute::Expression predicate) {
+ return GetFragmentsFromDatasets(children_, predicate);
+}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/dataset.h b/src/arrow/cpp/src/arrow/dataset/dataset.h
new file mode 100644
index 000000000..11210fdc2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/dataset.h
@@ -0,0 +1,264 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/compute/exec/expression.h"
+#include "arrow/dataset/type_fwd.h"
+#include "arrow/dataset/visibility.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/mutex.h"
+#include "arrow/util/optional.h"
+
+namespace arrow {
+namespace dataset {
+
+using RecordBatchGenerator = std::function<Future<std::shared_ptr<RecordBatch>>()>;
+
+/// \brief A granular piece of a Dataset, such as an individual file.
+///
+/// A Fragment can be read/scanned separately from other fragments. It yields a
+/// collection of RecordBatches when scanned, encapsulated in one or more
+/// ScanTasks.
+///
+/// Note that Fragments have well defined physical schemas which are reconciled by
+/// the Datasets which contain them; these physical schemas may differ from a parent
+/// Dataset's schema and the physical schemas of sibling Fragments.
+class ARROW_DS_EXPORT Fragment : public std::enable_shared_from_this<Fragment> {
+ public:
+ /// \brief Return the physical schema of the Fragment.
+ ///
+ /// The physical schema is also called the writer schema.
+ /// This method is blocking and may suffer from high latency filesystem.
+ /// The schema is cached after being read once, or may be specified at construction.
+ Result<std::shared_ptr<Schema>> ReadPhysicalSchema();
+
+ /// \brief Scan returns an iterator of ScanTasks, each of which yields
+ /// RecordBatches from this Fragment.
+ ///
+ /// Note that batches yielded using this method will not be filtered and may not align
+ /// with the Fragment's schema. In particular, note that columns referenced by the
+ /// filter may be present in yielded batches even if they are not projected (so that
+ /// they are available when a filter is applied). Additionally, explicitly projected
+ /// columns may be absent if they were not present in this fragment.
+ ///
+ /// To receive a record batch stream which is fully filtered and projected, use Scanner.
+ virtual Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) = 0;
+
+ /// An asynchronous version of Scan
+ virtual Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options) = 0;
+
+ /// \brief Count the number of rows in this fragment matching the filter using metadata
+ /// only. That is, this method may perform I/O, but will not load data.
+ ///
+ /// If this is not possible, resolve with an empty optional. The fragment can perform
+ /// I/O (e.g. to read metadata) before it deciding whether it can satisfy the request.
+ virtual Future<util::optional<int64_t>> CountRows(
+ compute::Expression predicate, const std::shared_ptr<ScanOptions>& options);
+
+ virtual std::string type_name() const = 0;
+ virtual std::string ToString() const { return type_name(); }
+
+ /// \brief An expression which evaluates to true for all data viewed by this
+ /// Fragment.
+ const compute::Expression& partition_expression() const {
+ return partition_expression_;
+ }
+
+ virtual ~Fragment() = default;
+
+ protected:
+ Fragment() = default;
+ explicit Fragment(compute::Expression partition_expression,
+ std::shared_ptr<Schema> physical_schema);
+
+ virtual Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() = 0;
+
+ util::Mutex physical_schema_mutex_;
+ compute::Expression partition_expression_ = compute::literal(true);
+ std::shared_ptr<Schema> physical_schema_;
+};
+
+/// \brief Per-scan options for fragment(s) in a dataset.
+///
+/// These options are not intrinsic to the format or fragment itself, but do affect
+/// the results of a scan. These are options which make sense to change between
+/// repeated reads of the same dataset, such as format-specific conversion options
+/// (that do not affect the schema).
+///
+/// \ingroup dataset-scanning
+class ARROW_DS_EXPORT FragmentScanOptions {
+ public:
+ virtual std::string type_name() const = 0;
+ virtual std::string ToString() const { return type_name(); }
+ virtual ~FragmentScanOptions() = default;
+};
+
+/// \defgroup dataset-implementations Concrete implementations
+///
+/// @{
+
+/// \brief A trivial Fragment that yields ScanTask out of a fixed set of
+/// RecordBatch.
+class ARROW_DS_EXPORT InMemoryFragment : public Fragment {
+ public:
+ InMemoryFragment(std::shared_ptr<Schema> schema, RecordBatchVector record_batches,
+ compute::Expression = compute::literal(true));
+ explicit InMemoryFragment(RecordBatchVector record_batches,
+ compute::Expression = compute::literal(true));
+
+ Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override;
+ Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options) override;
+ Future<util::optional<int64_t>> CountRows(
+ compute::Expression predicate,
+ const std::shared_ptr<ScanOptions>& options) override;
+
+ std::string type_name() const override { return "in-memory"; }
+
+ protected:
+ Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override;
+
+ RecordBatchVector record_batches_;
+};
+
+/// @}
+
+/// \brief A container of zero or more Fragments.
+///
+/// A Dataset acts as a union of Fragments, e.g. files deeply nested in a
+/// directory. A Dataset has a schema to which Fragments must align during a
+/// scan operation. This is analogous to Avro's reader and writer schema.
+class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this<Dataset> {
+ public:
+ /// \brief Begin to build a new Scan operation against this Dataset
+ Result<std::shared_ptr<ScannerBuilder>> NewScan();
+
+ /// \brief GetFragments returns an iterator of Fragments given a predicate.
+ Result<FragmentIterator> GetFragments(compute::Expression predicate);
+ Result<FragmentIterator> GetFragments();
+
+ const std::shared_ptr<Schema>& schema() const { return schema_; }
+
+ /// \brief An expression which evaluates to true for all data viewed by this Dataset.
+ /// May be null, which indicates no information is available.
+ const compute::Expression& partition_expression() const {
+ return partition_expression_;
+ }
+
+ /// \brief The name identifying the kind of Dataset
+ virtual std::string type_name() const = 0;
+
+ /// \brief Return a copy of this Dataset with a different schema.
+ ///
+ /// The copy will view the same Fragments. If the new schema is not compatible with the
+ /// original dataset's schema then an error will be raised.
+ virtual Result<std::shared_ptr<Dataset>> ReplaceSchema(
+ std::shared_ptr<Schema> schema) const = 0;
+
+ virtual ~Dataset() = default;
+
+ protected:
+ explicit Dataset(std::shared_ptr<Schema> schema) : schema_(std::move(schema)) {}
+
+ Dataset(std::shared_ptr<Schema> schema, compute::Expression partition_expression);
+
+ virtual Result<FragmentIterator> GetFragmentsImpl(compute::Expression predicate) = 0;
+
+ std::shared_ptr<Schema> schema_;
+ compute::Expression partition_expression_ = compute::literal(true);
+};
+
+/// \addtogroup dataset-implementations
+///
+/// @{
+
+/// \brief A Source which yields fragments wrapping a stream of record batches.
+///
+/// The record batches must match the schema provided to the source at construction.
+class ARROW_DS_EXPORT InMemoryDataset : public Dataset {
+ public:
+ class RecordBatchGenerator {
+ public:
+ virtual ~RecordBatchGenerator() = default;
+ virtual RecordBatchIterator Get() const = 0;
+ };
+
+ /// Construct a dataset from a schema and a factory of record batch iterators.
+ InMemoryDataset(std::shared_ptr<Schema> schema,
+ std::shared_ptr<RecordBatchGenerator> get_batches)
+ : Dataset(std::move(schema)), get_batches_(std::move(get_batches)) {}
+
+ /// Convenience constructor taking a fixed list of batches
+ InMemoryDataset(std::shared_ptr<Schema> schema, RecordBatchVector batches);
+
+ /// Convenience constructor taking a Table
+ explicit InMemoryDataset(std::shared_ptr<Table> table);
+
+ std::string type_name() const override { return "in-memory"; }
+
+ Result<std::shared_ptr<Dataset>> ReplaceSchema(
+ std::shared_ptr<Schema> schema) const override;
+
+ protected:
+ Result<FragmentIterator> GetFragmentsImpl(compute::Expression predicate) override;
+
+ std::shared_ptr<RecordBatchGenerator> get_batches_;
+};
+
+/// \brief A Dataset wrapping child Datasets.
+class ARROW_DS_EXPORT UnionDataset : public Dataset {
+ public:
+ /// \brief Construct a UnionDataset wrapping child Datasets.
+ ///
+ /// \param[in] schema the schema of the resulting dataset.
+ /// \param[in] children one or more child Datasets. Their schemas must be identical to
+ /// schema.
+ static Result<std::shared_ptr<UnionDataset>> Make(std::shared_ptr<Schema> schema,
+ DatasetVector children);
+
+ const DatasetVector& children() const { return children_; }
+
+ std::string type_name() const override { return "union"; }
+
+ Result<std::shared_ptr<Dataset>> ReplaceSchema(
+ std::shared_ptr<Schema> schema) const override;
+
+ protected:
+ Result<FragmentIterator> GetFragmentsImpl(compute::Expression predicate) override;
+
+ explicit UnionDataset(std::shared_ptr<Schema> schema, DatasetVector children)
+ : Dataset(std::move(schema)), children_(std::move(children)) {}
+
+ DatasetVector children_;
+
+ friend class UnionDatasetFactory;
+};
+
+/// @}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/dataset_internal.h b/src/arrow/cpp/src/arrow/dataset/dataset_internal.h
new file mode 100644
index 000000000..aac437df0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/dataset_internal.h
@@ -0,0 +1,160 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/dataset/dataset.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/type_fwd.h"
+#include "arrow/record_batch.h"
+#include "arrow/scalar.h"
+#include "arrow/type.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/iterator.h"
+
+namespace arrow {
+namespace dataset {
+
+/// \brief GetFragmentsFromDatasets transforms a vector<Dataset> into a
+/// flattened FragmentIterator.
+inline Result<FragmentIterator> GetFragmentsFromDatasets(const DatasetVector& datasets,
+ compute::Expression predicate) {
+ // Iterator<Dataset>
+ auto datasets_it = MakeVectorIterator(datasets);
+
+ // Dataset -> Iterator<Fragment>
+ auto fn = [predicate](std::shared_ptr<Dataset> dataset) -> Result<FragmentIterator> {
+ return dataset->GetFragments(predicate);
+ };
+
+ // Iterator<Iterator<Fragment>>
+ auto fragments_it = MakeMaybeMapIterator(fn, std::move(datasets_it));
+
+ // Iterator<Fragment>
+ return MakeFlattenIterator(std::move(fragments_it));
+}
+
+inline std::shared_ptr<Schema> SchemaFromColumnNames(
+ const std::shared_ptr<Schema>& input, const std::vector<std::string>& column_names) {
+ std::vector<std::shared_ptr<Field>> columns;
+ for (FieldRef ref : column_names) {
+ auto maybe_field = ref.GetOne(*input);
+ if (maybe_field.ok()) {
+ columns.push_back(std::move(maybe_field).ValueOrDie());
+ }
+ }
+
+ return schema(std::move(columns))->WithMetadata(input->metadata());
+}
+
+/// Get fragment scan options of the expected type.
+/// \return Fragment scan options if provided on the scan options, else the default
+/// options if set, else a default-constructed value. If options are provided
+/// but of the wrong type, an error is returned.
+template <typename T>
+arrow::Result<std::shared_ptr<T>> GetFragmentScanOptions(
+ const std::string& type_name, const ScanOptions* scan_options,
+ const std::shared_ptr<FragmentScanOptions>& default_options) {
+ auto source = default_options;
+ if (scan_options && scan_options->fragment_scan_options) {
+ source = scan_options->fragment_scan_options;
+ }
+ if (!source) {
+ return std::make_shared<T>();
+ }
+ if (source->type_name() != type_name) {
+ return Status::Invalid("FragmentScanOptions of type ", source->type_name(),
+ " were provided for scanning a fragment of type ", type_name);
+ }
+ return ::arrow::internal::checked_pointer_cast<T>(source);
+}
+
+class FragmentDataset : public Dataset {
+ public:
+ FragmentDataset(std::shared_ptr<Schema> schema, FragmentVector fragments)
+ : Dataset(std::move(schema)), fragments_(std::move(fragments)) {}
+
+ FragmentDataset(std::shared_ptr<Schema> schema,
+ AsyncGenerator<std::shared_ptr<Fragment>> fragments)
+ : Dataset(std::move(schema)), fragment_gen_(std::move(fragments)) {}
+
+ std::string type_name() const override { return "fragment"; }
+
+ Result<std::shared_ptr<Dataset>> ReplaceSchema(
+ std::shared_ptr<Schema> schema) const override {
+ return std::make_shared<FragmentDataset>(std::move(schema), fragments_);
+ }
+
+ protected:
+ Result<FragmentIterator> GetFragmentsImpl(compute::Expression predicate) override {
+ if (fragment_gen_) {
+ // TODO(ARROW-8163): Async fragment scanning can be forwarded rather than waiting
+ // for the whole generator here. For now, all Dataset impls have a vector of
+ // Fragments anyway
+ auto fragments_fut = CollectAsyncGenerator(std::move(fragment_gen_));
+ ARROW_ASSIGN_OR_RAISE(fragments_, fragments_fut.result());
+ }
+
+ // TODO(ARROW-12891) Provide subtree pruning for any vector of fragments
+ FragmentVector fragments;
+ for (const auto& fragment : fragments_) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto simplified_filter,
+ compute::SimplifyWithGuarantee(predicate, fragment->partition_expression()));
+
+ if (simplified_filter.IsSatisfiable()) {
+ fragments.push_back(fragment);
+ }
+ }
+ return MakeVectorIterator(std::move(fragments));
+ }
+
+ FragmentVector fragments_;
+ AsyncGenerator<std::shared_ptr<Fragment>> fragment_gen_;
+};
+
+// Given a record batch generator, creates a new generator that slices
+// batches so individual batches have at most batch_size rows. The
+// resulting generator is async-reentrant, but does not forward
+// reentrant pulls, so apply readahead before using this helper.
+inline RecordBatchGenerator MakeChunkedBatchGenerator(RecordBatchGenerator gen,
+ int64_t batch_size) {
+ return MakeFlatMappedGenerator(
+ std::move(gen),
+ [batch_size](const std::shared_ptr<RecordBatch>& batch)
+ -> ::arrow::AsyncGenerator<std::shared_ptr<::arrow::RecordBatch>> {
+ const int64_t rows = batch->num_rows();
+ if (rows <= batch_size) {
+ return ::arrow::MakeVectorGenerator<std::shared_ptr<RecordBatch>>({batch});
+ }
+ std::vector<std::shared_ptr<RecordBatch>> slices;
+ slices.reserve(rows / batch_size + (rows % batch_size != 0));
+ for (int64_t i = 0; i < rows; i += batch_size) {
+ slices.push_back(batch->Slice(i, batch_size));
+ }
+ return ::arrow::MakeVectorGenerator(std::move(slices));
+ });
+}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/dataset_test.cc b/src/arrow/cpp/src/arrow/dataset/dataset_test.cc
new file mode 100644
index 000000000..66d69c30c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/dataset_test.cc
@@ -0,0 +1,734 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/dataset.h"
+
+#include "arrow/dataset/dataset_internal.h"
+#include "arrow/dataset/discovery.h"
+#include "arrow/dataset/partition.h"
+#include "arrow/dataset/test_util.h"
+#include "arrow/filesystem/mockfs.h"
+#include "arrow/stl.h"
+#include "arrow/testing/generator.h"
+#include "arrow/util/optional.h"
+
+namespace arrow {
+namespace dataset {
+
+class TestInMemoryFragment : public DatasetFixtureMixin {};
+
+using RecordBatchVector = std::vector<std::shared_ptr<RecordBatch>>;
+
+TEST_F(TestInMemoryFragment, Scan) {
+ constexpr int64_t kBatchSize = 1024;
+ constexpr int64_t kNumberBatches = 16;
+
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+ auto reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
+
+ // Creates a InMemoryFragment of the same repeated batch.
+ RecordBatchVector batches = {static_cast<size_t>(kNumberBatches), batch};
+ auto fragment = std::make_shared<InMemoryFragment>(batches);
+
+ AssertFragmentEquals(reader.get(), fragment.get());
+}
+
+class TestInMemoryDataset : public DatasetFixtureMixin {};
+
+TEST_F(TestInMemoryDataset, ReplaceSchema) {
+ constexpr int64_t kBatchSize = 1;
+ constexpr int64_t kNumberBatches = 1;
+
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+ auto reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
+
+ auto dataset = std::make_shared<InMemoryDataset>(
+ schema_, RecordBatchVector{static_cast<size_t>(kNumberBatches), batch});
+
+ // drop field
+ ASSERT_OK(dataset->ReplaceSchema(schema({field("i32", int32())})).status());
+ // add field (will be materialized as null during projection)
+ ASSERT_OK(dataset->ReplaceSchema(schema({field("str", utf8())})).status());
+ // incompatible type
+ ASSERT_RAISES(TypeError,
+ dataset->ReplaceSchema(schema({field("i32", utf8())})).status());
+ // incompatible nullability
+ ASSERT_RAISES(
+ TypeError,
+ dataset->ReplaceSchema(schema({field("f64", float64(), /*nullable=*/false)}))
+ .status());
+ // add non-nullable field
+ ASSERT_RAISES(TypeError,
+ dataset->ReplaceSchema(schema({field("str", utf8(), /*nullable=*/false)}))
+ .status());
+}
+
+TEST_F(TestInMemoryDataset, GetFragments) {
+ constexpr int64_t kBatchSize = 1024;
+ constexpr int64_t kNumberBatches = 16;
+
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+ auto reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
+
+ auto dataset = std::make_shared<InMemoryDataset>(
+ schema_, RecordBatchVector{static_cast<size_t>(kNumberBatches), batch});
+
+ AssertDatasetEquals(reader.get(), dataset.get());
+}
+
+TEST_F(TestInMemoryDataset, InMemoryFragment) {
+ constexpr int64_t kBatchSize = 1024;
+
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+ RecordBatchVector batches{batch};
+
+ // Regression test: previously this constructor relied on undefined behavior (order of
+ // evaluation of arguments) leading to fragments being constructed with empty schemas
+ auto fragment = std::make_shared<InMemoryFragment>(batches);
+ ASSERT_OK_AND_ASSIGN(auto schema, fragment->ReadPhysicalSchema());
+ AssertSchemaEqual(batch->schema(), schema);
+}
+
+class TestUnionDataset : public DatasetFixtureMixin {};
+
+TEST_F(TestUnionDataset, ReplaceSchema) {
+ constexpr int64_t kBatchSize = 1;
+ constexpr int64_t kNumberBatches = 1;
+
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+
+ std::vector<std::shared_ptr<RecordBatch>> batches{static_cast<size_t>(kNumberBatches),
+ batch};
+
+ DatasetVector children = {
+ std::make_shared<InMemoryDataset>(schema_, batches),
+ std::make_shared<InMemoryDataset>(schema_, batches),
+ };
+
+ const int64_t total_batches = children.size() * kNumberBatches;
+ auto reader = ConstantArrayGenerator::Repeat(total_batches, batch);
+
+ ASSERT_OK_AND_ASSIGN(auto dataset, UnionDataset::Make(schema_, children));
+ AssertDatasetEquals(reader.get(), dataset.get());
+
+ // drop field
+ ASSERT_OK(dataset->ReplaceSchema(schema({field("i32", int32())})).status());
+ // add nullable field (will be materialized as null during projection)
+ ASSERT_OK(dataset->ReplaceSchema(schema({field("str", utf8())})).status());
+ // incompatible type
+ ASSERT_RAISES(TypeError,
+ dataset->ReplaceSchema(schema({field("i32", utf8())})).status());
+ // incompatible nullability
+ ASSERT_RAISES(
+ TypeError,
+ dataset->ReplaceSchema(schema({field("f64", float64(), /*nullable=*/false)}))
+ .status());
+ // add non-nullable field
+ ASSERT_RAISES(TypeError,
+ dataset->ReplaceSchema(schema({field("str", utf8(), /*nullable=*/false)}))
+ .status());
+}
+
+TEST_F(TestUnionDataset, GetFragments) {
+ constexpr int64_t kBatchSize = 1024;
+ constexpr int64_t kChildPerNode = 2;
+ constexpr int64_t kCompleteBinaryTreeDepth = 4;
+
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+
+ auto n_leaves = 1U << kCompleteBinaryTreeDepth;
+ auto reader = ConstantArrayGenerator::Repeat(n_leaves, batch);
+
+ // Creates a complete binary tree of depth kCompleteBinaryTreeDepth where the
+ // leaves are InMemoryDataset containing kChildPerNode fragments.
+
+ auto l1_leaf_dataset = std::make_shared<InMemoryDataset>(
+ schema_, RecordBatchVector{static_cast<size_t>(kChildPerNode), batch});
+
+ ASSERT_OK_AND_ASSIGN(
+ auto l2_leaf_tree_dataset,
+ UnionDataset::Make(
+ schema_, DatasetVector{static_cast<size_t>(kChildPerNode), l1_leaf_dataset}));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto l3_middle_tree_dataset,
+ UnionDataset::Make(schema_, DatasetVector{static_cast<size_t>(kChildPerNode),
+ l2_leaf_tree_dataset}));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto root_dataset,
+ UnionDataset::Make(schema_, DatasetVector{static_cast<size_t>(kChildPerNode),
+ l3_middle_tree_dataset}));
+
+ AssertDatasetEquals(reader.get(), root_dataset.get());
+}
+
+TEST_F(TestUnionDataset, TrivialScan) {
+ constexpr int64_t kNumberBatches = 16;
+ constexpr int64_t kBatchSize = 1024;
+
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+
+ std::vector<std::shared_ptr<RecordBatch>> batches{static_cast<size_t>(kNumberBatches),
+ batch};
+
+ DatasetVector children = {
+ std::make_shared<InMemoryDataset>(schema_, batches),
+ std::make_shared<InMemoryDataset>(schema_, batches),
+ };
+
+ const int64_t total_batches = children.size() * kNumberBatches;
+ auto reader = ConstantArrayGenerator::Repeat(total_batches, batch);
+
+ ASSERT_OK_AND_ASSIGN(auto dataset, UnionDataset::Make(schema_, children));
+ AssertDatasetEquals(reader.get(), dataset.get());
+}
+
+TEST(TestProjector, CheckProjectable) {
+ struct Assert {
+ explicit Assert(FieldVector from) : from_(from) {}
+ Schema from_;
+
+ void ProjectableTo(FieldVector to) {
+ ARROW_EXPECT_OK(CheckProjectable(from_, Schema(to)));
+ }
+
+ void NotProjectableTo(FieldVector to, std::string substr = "") {
+ EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, testing::HasSubstr(substr),
+ CheckProjectable(from_, Schema(to)));
+ }
+ };
+
+ auto i8 = field("i8", int8());
+ auto u16 = field("u16", uint16());
+ auto str = field("str", utf8());
+ auto i8_req = field("i8", int8(), false);
+ auto u16_req = field("u16", uint16(), false);
+ auto str_req = field("str", utf8(), false);
+ auto str_nil = field("str", null());
+
+ // trivial
+ Assert({}).ProjectableTo({});
+ Assert({i8}).ProjectableTo({i8});
+ Assert({i8, u16_req}).ProjectableTo({i8, u16_req});
+
+ // reorder
+ Assert({i8, u16}).ProjectableTo({u16, i8});
+ Assert({i8, str, u16}).ProjectableTo({u16, i8, str});
+
+ // drop field(s)
+ Assert({i8}).ProjectableTo({});
+
+ // add field(s)
+ Assert({}).ProjectableTo({i8});
+ Assert({}).ProjectableTo({i8, u16});
+ Assert({}).NotProjectableTo({u16_req},
+ "is not nullable and does not exist in origin schema");
+ Assert({i8}).NotProjectableTo({u16_req, i8});
+
+ // change nullability
+ Assert({i8}).NotProjectableTo({i8_req},
+ "not nullable but is not required in origin schema");
+ Assert({i8_req}).ProjectableTo({i8});
+ Assert({str_nil}).ProjectableTo({str});
+ Assert({str_nil}).NotProjectableTo({str_req});
+
+ // change field type
+ Assert({i8}).NotProjectableTo({field("i8", utf8())},
+ "fields had matching names but differing types");
+}
+
+class TestEndToEnd : public TestUnionDataset {
+ void SetUp() override {
+ bool nullable = false;
+ SetSchema({
+ field("region", utf8(), nullable),
+ field("model", utf8(), nullable),
+ field("sales", float64(), nullable),
+ // partition columns
+ field("year", int32()),
+ field("month", int32()),
+ field("country", utf8()),
+ });
+
+ using PathAndContent = std::vector<std::pair<std::string, std::string>>;
+ auto files = PathAndContent{
+ {"/dataset/2018/01/US/dat.json", R"([
+ {"region": "NY", "model": "3", "sales": 742.0},
+ {"region": "NY", "model": "S", "sales": 304.125},
+ {"region": "NY", "model": "X", "sales": 136.25},
+ {"region": "NY", "model": "Y", "sales": 27.5}
+ ])"},
+ {"/dataset/2018/01/CA/dat.json", R"([
+ {"region": "CA", "model": "3", "sales": 512},
+ {"region": "CA", "model": "S", "sales": 978},
+ {"region": "CA", "model": "X", "sales": 1.0},
+ {"region": "CA", "model": "Y", "sales": 69}
+ ])"},
+ {"/dataset/2019/01/US/dat.json", R"([
+ {"region": "QC", "model": "3", "sales": 273.5},
+ {"region": "QC", "model": "S", "sales": 13},
+ {"region": "QC", "model": "X", "sales": 54},
+ {"region": "QC", "model": "Y", "sales": 21}
+ ])"},
+ {"/dataset/2019/01/CA/dat.json", R"([
+ {"region": "QC", "model": "3", "sales": 152.25},
+ {"region": "QC", "model": "S", "sales": 10},
+ {"region": "QC", "model": "X", "sales": 42},
+ {"region": "QC", "model": "Y", "sales": 37}
+ ])"},
+ {"/dataset/.pesky", "garbage content"},
+ };
+
+ auto mock_fs = std::make_shared<fs::internal::MockFileSystem>(fs::kNoTime);
+ for (const auto& f : files) {
+ ARROW_EXPECT_OK(mock_fs->CreateFile(f.first, f.second, /*recursive=*/true));
+ }
+
+ fs_ = mock_fs;
+ }
+
+ protected:
+ std::shared_ptr<fs::FileSystem> fs_;
+};
+
+TEST_F(TestEndToEnd, EndToEndSingleDataset) {
+ // The dataset API is divided in 3 parts:
+ // - Creation
+ // - Querying
+ // - Consuming
+
+ // Creation.
+ //
+ // A Dataset is the union of one or more Datasets with the same schema.
+ // Example of Dataset, FileSystemDataset, OdbcDataset,
+ // FlightDataset.
+
+ // A Dataset is composed of Fragments. Each Fragment can yield
+ // multiple RecordBatches. Datasets can be created manually or "discovered"
+ // via the DatasetFactory interface.
+ std::shared_ptr<DatasetFactory> factory;
+
+ // The user must specify which FileFormat is used to create FileFragments.
+ // This option is specific to FileSystemDataset (and the builder).
+ auto format_schema = SchemaFromColumnNames(schema_, {"region", "model", "sales"});
+ auto format = std::make_shared<JSONRecordBatchFileFormat>(format_schema);
+
+ // A selector is used to crawl files and directories of a
+ // filesystem. If the options in FileSelector are not enough, the
+ // FileSystemDatasetFactory class also supports an explicit list of
+ // fs::FileInfo instead of the selector.
+ fs::FileSelector s;
+ s.base_dir = "/dataset";
+ s.recursive = true;
+
+ // Further options can be given to the factory mechanism via the
+ // FileSystemFactoryOptions configuration class. See the docstring for more
+ // information.
+ FileSystemFactoryOptions options;
+ options.selector_ignore_prefixes = {"."};
+
+ // Partitions expressions can be discovered for Dataset and Fragments.
+ // This metadata is then used in conjunction with the query filter to apply
+ // the pushdown predicate optimization.
+ //
+ // The DirectoryPartitioning is a partitioning where the path is split with
+ // the directory separator character and the components are parsed as values
+ // of the corresponding fields in its schema.
+ //
+ // Since a PartitioningFactory is specified instead of an explicit
+ // Partitioning, the types of partition fields will be inferred.
+ //
+ // - "/2019" -> {"year": 2019}
+ // - "/2019/01 -> {"year": 2019, "month": 1}
+ // - "/2019/01/CA -> {"year": 2019, "month": 1, "country": "CA"}
+ // - "/2019/01/CA/a_file.json -> {"year": 2019, "month": 1, "country": "CA"}
+ options.partitioning = DirectoryPartitioning::MakeFactory({"year", "month", "country"});
+
+ ASSERT_OK_AND_ASSIGN(factory, FileSystemDatasetFactory::Make(fs_, s, format, options));
+
+ // Fragments might have compatible but slightly different schemas, e.g.
+ // schema evolved by adding/renaming columns. In this case, the schema is
+ // passed to the dataset constructor.
+ // The inspected_schema may optionally be modified before being finalized.
+ InspectOptions inspect_options;
+ inspect_options.fragments = InspectOptions::kInspectAllFragments;
+ ASSERT_OK_AND_ASSIGN(auto inspected_schema, factory->Inspect(inspect_options));
+ EXPECT_EQ(*schema_, *inspected_schema);
+
+ // Build the Dataset where partitions are attached to fragments (files).
+ ASSERT_OK_AND_ASSIGN(auto source, factory->Finish(inspected_schema));
+
+ // Create the Dataset from our single Dataset.
+ ASSERT_OK_AND_ASSIGN(auto dataset, UnionDataset::Make(inspected_schema, {source}));
+
+ // Querying.
+ //
+ // The Scan operator materializes data from io into memory. Avoiding data
+ // transfer is a critical optimization done by analytical engine. Thus, a
+ // Scan can take multiple options, notably a subset of columns and a filter
+ // expression.
+ ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan());
+
+ // An optional subset of columns can be provided. This will trickle to
+ // Fragment drivers. The net effect is that only columns of interest will
+ // be materialized if the Fragment supports it. This is the major benefit
+ // of using a column-major format versus a row-major format.
+ //
+ // This API decouples the Dataset/Fragment implementation and column
+ // projection from the query part.
+ //
+ // For example, a ParquetFileFragment may read the necessary byte ranges
+ // exclusively, ranges, or an OdbcFragment could convert the projection to a SELECT
+ // statement. The CsvFileFragment wouldn't benefit from this as much, but
+ // can still benefit from skipping conversion of unneeded columns.
+ std::vector<std::string> columns{"sales", "model", "country"};
+ ASSERT_OK(scanner_builder->Project(columns));
+
+ // An optional filter expression may also be specified. The filter expression
+ // is evaluated against input rows. Only rows for which the filter evaluates to true
+ // are yielded. Predicate pushdown optimizations are applied using partition
+ // information if available.
+ //
+ // This API decouples predicate pushdown from the Dataset implementation
+ // and partition discovery.
+ //
+ // The following filter tests both predicate pushdown and post filtering
+ // without partition information because `year` is a partition and `sales` is
+ // not.
+ auto filter = and_(equal(field_ref("year"), literal(2019)),
+ greater(field_ref("sales"), literal(100.0)));
+ ASSERT_OK(scanner_builder->Filter(filter));
+
+ ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
+ // In the simplest case, consumption is simply conversion to a Table.
+ ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable());
+
+ auto expected = TableFromJSON(scanner_builder->projected_schema(), {R"([
+ {"sales": 152.25, "model": "3", "country": "CA"},
+ {"sales": 273.5, "model": "3", "country": "US"}
+ ])"});
+ AssertTablesEqual(*expected, *table, false, true);
+}
+
+inline std::shared_ptr<Schema> SchemaFromNames(const std::vector<std::string> names) {
+ std::vector<std::shared_ptr<Field>> fields;
+ for (const auto& name : names) {
+ fields.push_back(field(name, int32()));
+ }
+
+ return schema(fields);
+}
+
+class TestSchemaUnification : public TestUnionDataset {
+ public:
+ using i32 = util::optional<int32_t>;
+ using PathAndContent = std::vector<std::pair<std::string, std::string>>;
+
+ void SetUp() override {
+ // The following test creates 2 sources with divergent but compatible
+ // schemas. Each source have a common partitioning where the
+ // fields are not materialized in the data fragments.
+ //
+ // Each data is composed of 2 data fragments with divergent but
+ // compatible schemas. The data fragment within a source share at
+ // least one column.
+ //
+ // Thus, the fixture helps verifying various scenarios where the Scanner
+ // must fix the RecordBatches to align with the final unified schema exposed
+ // to the consumer.
+ static constexpr auto ds1_df1 = "/dataset/alpha/part_ds=1/part_df=1/data.json";
+ static constexpr auto ds1_df2 = "/dataset/alpha/part_ds=1/part_df=2/data.json";
+ static constexpr auto ds2_df1 = "/dataset/beta/part_ds=2/part_df=1/data.json";
+ static constexpr auto ds2_df2 = "/dataset/beta/part_ds=2/part_df=2/data.json";
+ auto files = PathAndContent{
+ // First Dataset
+ {ds1_df1, R"([{"phy_1": 111, "phy_2": 211}])"},
+ {ds1_df2, R"([{"phy_2": 212, "phy_3": 312}])"},
+ // Second Dataset
+ {ds2_df1, R"([{"phy_3": 321, "phy_4": 421}])"},
+ {ds2_df2, R"([{"phy_4": 422, "phy_2": 222}])"},
+ };
+
+ auto mock_fs = std::make_shared<fs::internal::MockFileSystem>(fs::kNoTime);
+ for (const auto& f : files) {
+ ARROW_EXPECT_OK(mock_fs->CreateFile(f.first, f.second, /* recursive */ true));
+ }
+ fs_ = mock_fs;
+
+ auto get_source =
+ [this](std::string base,
+ std::vector<std::string> paths) -> Result<std::shared_ptr<Dataset>> {
+ auto resolver = [](const FileSource& source) -> std::shared_ptr<Schema> {
+ auto path = source.path();
+ // A different schema for each data fragment.
+ if (path == ds1_df1) {
+ return SchemaFromNames({"phy_1", "phy_2"});
+ } else if (path == ds1_df2) {
+ return SchemaFromNames({"phy_2", "phy_3"});
+ } else if (path == ds2_df1) {
+ return SchemaFromNames({"phy_3", "phy_4"});
+ } else if (path == ds2_df2) {
+ return SchemaFromNames({"phy_4", "phy_2"});
+ }
+
+ return nullptr;
+ };
+
+ auto format = std::make_shared<JSONRecordBatchFileFormat>(resolver);
+
+ FileSystemFactoryOptions options;
+ options.partition_base_dir = base;
+ options.partitioning =
+ std::make_shared<HivePartitioning>(SchemaFromNames({"part_ds", "part_df"}));
+
+ ARROW_ASSIGN_OR_RAISE(auto factory,
+ FileSystemDatasetFactory::Make(fs_, paths, format, options));
+
+ ARROW_ASSIGN_OR_RAISE(auto schema, factory->Inspect());
+
+ return factory->Finish(schema);
+ };
+
+ schema_ = SchemaFromNames({"phy_1", "phy_2", "phy_3", "phy_4", "part_ds", "part_df"});
+ ASSERT_OK_AND_ASSIGN(auto ds1, get_source("/dataset/alpha", {ds1_df1, ds1_df2}));
+ ASSERT_OK_AND_ASSIGN(auto ds2, get_source("/dataset/beta", {ds2_df1, ds2_df2}));
+
+ // FIXME(bkietz) this is a hack: allow differing schemas for the purposes of this
+ // test
+ class DisparateSchemasUnionDataset : public UnionDataset {
+ public:
+ DisparateSchemasUnionDataset(std::shared_ptr<Schema> schema, DatasetVector children)
+ : UnionDataset(std::move(schema), std::move(children)) {}
+ };
+ dataset_ =
+ std::make_shared<DisparateSchemasUnionDataset>(schema_, DatasetVector{ds1, ds2});
+ }
+
+ template <typename TupleType>
+ void AssertScanEquals(std::shared_ptr<Scanner> scanner,
+ const std::vector<TupleType>& expected_rows) {
+ std::vector<std::string> columns;
+ for (const auto& field : scanner->options()->projected_schema->fields()) {
+ columns.push_back(field->name());
+ }
+
+ ASSERT_OK_AND_ASSIGN(auto actual, scanner->ToTable());
+ std::shared_ptr<Table> expected;
+ ASSERT_OK(stl::TableFromTupleRange(default_memory_pool(), expected_rows, columns,
+ &expected));
+ AssertTablesEqual(*expected, *actual, false, true);
+ }
+
+ template <typename TupleType>
+ void AssertBuilderEquals(std::shared_ptr<ScannerBuilder> builder,
+ const std::vector<TupleType>& expected_rows) {
+ ASSERT_OK_AND_ASSIGN(auto scanner, builder->Finish());
+ AssertScanEquals(scanner, expected_rows);
+ }
+
+ protected:
+ std::shared_ptr<fs::FileSystem> fs_;
+ std::shared_ptr<Dataset> dataset_;
+};
+
+using util::nullopt;
+
+TEST_F(TestSchemaUnification, SelectStar) {
+ // This is a `SELECT * FROM dataset` where it ensures:
+ //
+ // - proper re-ordering of columns
+ // - materializing missing physical columns in Fragments
+ // - materializing missing partition columns extracted from Partitioning
+ ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan());
+
+ using TupleType = std::tuple<i32, i32, i32, i32, i32, i32>;
+ std::vector<TupleType> rows = {
+ TupleType(111, 211, nullopt, nullopt, 1, 1),
+ TupleType(nullopt, 212, 312, nullopt, 1, 2),
+ TupleType(nullopt, nullopt, 321, 421, 2, 1),
+ TupleType(nullopt, 222, nullopt, 422, 2, 2),
+ };
+
+ AssertBuilderEquals(scan_builder, rows);
+}
+
+TEST_F(TestSchemaUnification, SelectPhysicalColumns) {
+ // Same as above, but scoped to physical columns.
+ ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan());
+ ASSERT_OK(scan_builder->Project({"phy_1", "phy_2", "phy_3", "phy_4"}));
+
+ using TupleType = std::tuple<i32, i32, i32, i32>;
+ std::vector<TupleType> rows = {
+ TupleType(111, 211, nullopt, nullopt),
+ TupleType(nullopt, 212, 312, nullopt),
+ TupleType(nullopt, nullopt, 321, 421),
+ TupleType(nullopt, 222, nullopt, 422),
+ };
+
+ AssertBuilderEquals(scan_builder, rows);
+}
+
+TEST_F(TestSchemaUnification, SelectSomeReorderedPhysicalColumns) {
+ // Select physical columns in a different order than physical Fragments
+ ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan());
+ ASSERT_OK(scan_builder->Project({"phy_2", "phy_1", "phy_4"}));
+
+ using TupleType = std::tuple<i32, i32, i32>;
+ std::vector<TupleType> rows = {
+ TupleType(211, 111, nullopt),
+ TupleType(212, nullopt, nullopt),
+ TupleType(nullopt, nullopt, 421),
+ TupleType(222, nullopt, 422),
+ };
+
+ AssertBuilderEquals(scan_builder, rows);
+}
+
+TEST_F(TestSchemaUnification, SelectPhysicalColumnsFilterPartitionColumn) {
+ // Select a subset of physical column with a filter on a missing physical
+ // column and a partition column, it ensures:
+ //
+ // - Can filter on virtual and physical columns with a non-trivial filter
+ // when some of the columns may not be materialized
+ ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan());
+ ASSERT_OK(scan_builder->Project({"phy_2", "phy_3", "phy_4"}));
+ ASSERT_OK(scan_builder->Filter(or_(and_(equal(field_ref("part_df"), literal(1)),
+ equal(field_ref("phy_2"), literal(211))),
+ and_(equal(field_ref("part_ds"), literal(2)),
+ not_equal(field_ref("phy_4"), literal(422))))));
+
+ using TupleType = std::tuple<i32, i32, i32>;
+ std::vector<TupleType> rows = {
+ TupleType(211, nullopt, nullopt),
+ TupleType(nullopt, 321, 421),
+ };
+
+ AssertBuilderEquals(scan_builder, rows);
+}
+
+TEST_F(TestSchemaUnification, SelectSyntheticColumn) {
+ // Select only a synthetic column
+ ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan());
+ ASSERT_OK(scan_builder->Project(
+ {call("add", {field_ref("phy_1"), field_ref("part_df")})}, {"phy_1 + part_df"}));
+
+ ASSERT_OK_AND_ASSIGN(auto scanner, scan_builder->Finish());
+ AssertSchemaEqual(Schema({field("phy_1 + part_df", int32())}),
+ *scanner->options()->projected_schema);
+
+ using TupleType = std::tuple<i32>;
+ std::vector<TupleType> rows = {
+ TupleType(111 + 1),
+ TupleType(nullopt),
+ TupleType(nullopt),
+ TupleType(nullopt),
+ };
+
+ AssertBuilderEquals(scan_builder, rows);
+}
+
+TEST_F(TestSchemaUnification, SelectPartitionColumns) {
+ // Selects partition (virtual) columns, it ensures:
+ //
+ // - virtual column are materialized
+ // - Fragment yield the right number of rows even if no column is selected
+ ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan());
+ ASSERT_OK(scan_builder->Project({"part_ds", "part_df"}));
+ using TupleType = std::tuple<i32, i32>;
+ std::vector<TupleType> rows = {
+ TupleType(1, 1),
+ TupleType(1, 2),
+ TupleType(2, 1),
+ TupleType(2, 2),
+ };
+ AssertBuilderEquals(scan_builder, rows);
+}
+
+TEST_F(TestSchemaUnification, SelectPartitionColumnsFilterPhysicalColumn) {
+ // Selects re-ordered virtual columns with a filter on a physical columns
+ ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan());
+ ASSERT_OK(scan_builder->Filter(equal(field_ref("phy_1"), literal(111))));
+
+ ASSERT_OK(scan_builder->Project({"part_df", "part_ds"}));
+ using TupleType = std::tuple<i32, i32>;
+ std::vector<TupleType> rows = {
+ TupleType(1, 1),
+ };
+ AssertBuilderEquals(scan_builder, rows);
+}
+
+TEST_F(TestSchemaUnification, SelectMixedColumnsAndFilter) {
+ // Selects mix of physical/virtual with a different order and uses a filter on
+ // a physical column not selected.
+ ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan());
+ ASSERT_OK(scan_builder->Filter(greater_equal(field_ref("phy_2"), literal(212))));
+ ASSERT_OK(scan_builder->Project({"part_df", "phy_3", "part_ds", "phy_1"}));
+
+ using TupleType = std::tuple<i32, i32, i32, i32>;
+ std::vector<TupleType> rows = {
+ TupleType(2, 312, 1, nullopt),
+ TupleType(2, nullopt, 2, nullopt),
+ };
+ AssertBuilderEquals(scan_builder, rows);
+}
+
+TEST(TestDictPartitionColumn, SelectPartitionColumnFilterPhysicalColumn) {
+ auto partition_field = field("part", dictionary(int32(), utf8()));
+ auto path = "/dataset/part=one/data.json";
+ auto dictionary = ArrayFromJSON(utf8(), R"(["one"])");
+
+ auto mock_fs = std::make_shared<fs::internal::MockFileSystem>(fs::kNoTime);
+ ARROW_EXPECT_OK(mock_fs->CreateFile(path, R"([ {"phy_1": 111, "phy_2": 211} ])",
+ /*recursive=*/true));
+
+ auto physical_schema = SchemaFromNames({"phy_1", "phy_2"});
+ auto format = std::make_shared<JSONRecordBatchFileFormat>(
+ [=](const FileSource&) { return physical_schema; });
+
+ FileSystemFactoryOptions options;
+ options.partition_base_dir = "/dataset";
+ options.partitioning = std::make_shared<HivePartitioning>(schema({partition_field}),
+ ArrayVector{dictionary});
+
+ ASSERT_OK_AND_ASSIGN(auto factory,
+ FileSystemDatasetFactory::Make(mock_fs, {path}, format, options));
+
+ ASSERT_OK_AND_ASSIGN(auto schema, factory->Inspect());
+
+ ASSERT_OK_AND_ASSIGN(auto dataset, factory->Finish(schema));
+
+ // Selects re-ordered virtual column with a filter on a physical column
+ ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset->NewScan());
+ ASSERT_OK(scan_builder->Filter(equal(field_ref("phy_1"), literal(111))));
+ ASSERT_OK(scan_builder->Project({"part"}));
+
+ ASSERT_OK_AND_ASSIGN(auto scanner, scan_builder->Finish());
+ ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable());
+ AssertArraysEqual(*table->column(0)->chunk(0),
+ *ArrayFromJSON(partition_field->type(), R"(["one"])"));
+}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/dataset_writer.cc b/src/arrow/cpp/src/arrow/dataset/dataset_writer.cc
new file mode 100644
index 000000000..a61f32cbc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/dataset_writer.cc
@@ -0,0 +1,529 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/dataset_writer.h"
+
+#include <list>
+#include <mutex>
+#include <unordered_map>
+
+#include "arrow/filesystem/path_util.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/map.h"
+#include "arrow/util/string.h"
+
+namespace arrow {
+namespace dataset {
+namespace internal {
+
+namespace {
+
+constexpr util::string_view kIntegerToken = "{i}";
+
+class Throttle {
+ public:
+ explicit Throttle(uint64_t max_value) : max_value_(max_value) {}
+
+ bool Unthrottled() const { return max_value_ <= 0; }
+
+ Future<> Acquire(uint64_t values) {
+ if (Unthrottled()) {
+ return Future<>::MakeFinished();
+ }
+ std::lock_guard<std::mutex> lg(mutex_);
+ if (values + current_value_ > max_value_) {
+ in_waiting_ = values;
+ backpressure_ = Future<>::Make();
+ } else {
+ current_value_ += values;
+ }
+ return backpressure_;
+ }
+
+ void Release(uint64_t values) {
+ if (Unthrottled()) {
+ return;
+ }
+ Future<> to_complete;
+ {
+ std::lock_guard<std::mutex> lg(mutex_);
+ current_value_ -= values;
+ if (in_waiting_ > 0 && in_waiting_ + current_value_ <= max_value_) {
+ in_waiting_ = 0;
+ to_complete = backpressure_;
+ }
+ }
+ if (to_complete.is_valid()) {
+ to_complete.MarkFinished();
+ }
+ }
+
+ private:
+ Future<> backpressure_ = Future<>::MakeFinished();
+ uint64_t max_value_;
+ uint64_t in_waiting_ = 0;
+ uint64_t current_value_ = 0;
+ std::mutex mutex_;
+};
+
+class DatasetWriterFileQueue : public util::AsyncDestroyable {
+ public:
+ explicit DatasetWriterFileQueue(const Future<std::shared_ptr<FileWriter>>& writer_fut,
+ const FileSystemDatasetWriteOptions& options,
+ std::mutex* visitors_mutex)
+ : options_(options), visitors_mutex_(visitors_mutex) {
+ running_task_ = Future<>::Make();
+ writer_fut.AddCallback(
+ [this](const Result<std::shared_ptr<FileWriter>>& maybe_writer) {
+ if (maybe_writer.ok()) {
+ writer_ = *maybe_writer;
+ Flush();
+ } else {
+ Abort(maybe_writer.status());
+ }
+ });
+ }
+
+ Future<uint64_t> Push(std::shared_ptr<RecordBatch> batch) {
+ std::unique_lock<std::mutex> lk(mutex);
+ write_queue_.push_back(std::move(batch));
+ Future<uint64_t> write_future = Future<uint64_t>::Make();
+ write_futures_.push_back(write_future);
+ if (!running_task_.is_valid()) {
+ running_task_ = Future<>::Make();
+ FlushUnlocked(std::move(lk));
+ }
+ return write_future;
+ }
+
+ Future<> DoDestroy() override {
+ std::lock_guard<std::mutex> lg(mutex);
+ if (!running_task_.is_valid()) {
+ RETURN_NOT_OK(DoFinish());
+ return Future<>::MakeFinished();
+ }
+ return running_task_.Then([this] { return DoFinish(); });
+ }
+
+ private:
+ Future<uint64_t> WriteNext() {
+ // May want to prototype / measure someday pushing the async write down further
+ return DeferNotOk(
+ io::default_io_context().executor()->Submit([this]() -> Result<uint64_t> {
+ DCHECK(running_task_.is_valid());
+ std::unique_lock<std::mutex> lk(mutex);
+ const std::shared_ptr<RecordBatch>& to_write = write_queue_.front();
+ Future<uint64_t> on_complete = write_futures_.front();
+ uint64_t rows_to_write = to_write->num_rows();
+ lk.unlock();
+ Status status = writer_->Write(to_write);
+ lk.lock();
+ write_queue_.pop_front();
+ write_futures_.pop_front();
+ lk.unlock();
+ if (!status.ok()) {
+ on_complete.MarkFinished(status);
+ } else {
+ on_complete.MarkFinished(rows_to_write);
+ }
+ return rows_to_write;
+ }));
+ }
+
+ Status DoFinish() {
+ {
+ std::lock_guard<std::mutex> lg(*visitors_mutex_);
+ RETURN_NOT_OK(options_.writer_pre_finish(writer_.get()));
+ }
+ RETURN_NOT_OK(writer_->Finish());
+ {
+ std::lock_guard<std::mutex> lg(*visitors_mutex_);
+ return options_.writer_post_finish(writer_.get());
+ }
+ }
+
+ void Abort(Status err) {
+ std::vector<Future<uint64_t>> futures_to_abort;
+ Future<> old_running_task = running_task_;
+ {
+ std::lock_guard<std::mutex> lg(mutex);
+ write_queue_.clear();
+ futures_to_abort =
+ std::vector<Future<uint64_t>>(write_futures_.begin(), write_futures_.end());
+ write_futures_.clear();
+ running_task_ = Future<>();
+ }
+ for (auto& fut : futures_to_abort) {
+ fut.MarkFinished(err);
+ }
+ old_running_task.MarkFinished(std::move(err));
+ }
+
+ void Flush() {
+ std::unique_lock<std::mutex> lk(mutex);
+ FlushUnlocked(std::move(lk));
+ }
+
+ void FlushUnlocked(std::unique_lock<std::mutex> lk) {
+ if (write_queue_.empty()) {
+ Future<> old_running_task = running_task_;
+ running_task_ = Future<>();
+ lk.unlock();
+ old_running_task.MarkFinished();
+ return;
+ }
+ WriteNext().AddCallback([this](const Result<uint64_t>& res) {
+ if (res.ok()) {
+ Flush();
+ } else {
+ Abort(res.status());
+ }
+ });
+ }
+
+ const FileSystemDatasetWriteOptions& options_;
+ std::mutex* visitors_mutex_;
+ std::shared_ptr<FileWriter> writer_;
+ std::mutex mutex;
+ std::list<std::shared_ptr<RecordBatch>> write_queue_;
+ std::list<Future<uint64_t>> write_futures_;
+ Future<> running_task_;
+};
+
+struct WriteTask {
+ std::string filename;
+ uint64_t num_rows;
+};
+
+class DatasetWriterDirectoryQueue : public util::AsyncDestroyable {
+ public:
+ DatasetWriterDirectoryQueue(std::string directory, std::shared_ptr<Schema> schema,
+ const FileSystemDatasetWriteOptions& write_options,
+ Throttle* open_files_throttle, std::mutex* visitors_mutex)
+ : directory_(std::move(directory)),
+ schema_(std::move(schema)),
+ write_options_(write_options),
+ open_files_throttle_(open_files_throttle),
+ visitors_mutex_(visitors_mutex) {}
+
+ Result<std::shared_ptr<RecordBatch>> NextWritableChunk(
+ std::shared_ptr<RecordBatch> batch, std::shared_ptr<RecordBatch>* remainder,
+ bool* will_open_file) const {
+ DCHECK_GT(batch->num_rows(), 0);
+ uint64_t rows_available = std::numeric_limits<uint64_t>::max();
+ *will_open_file = rows_written_ == 0;
+ if (write_options_.max_rows_per_file > 0) {
+ rows_available = write_options_.max_rows_per_file - rows_written_;
+ }
+
+ std::shared_ptr<RecordBatch> to_queue;
+ if (rows_available < static_cast<uint64_t>(batch->num_rows())) {
+ to_queue = batch->Slice(0, static_cast<int64_t>(rows_available));
+ *remainder = batch->Slice(static_cast<int64_t>(rows_available));
+ } else {
+ to_queue = std::move(batch);
+ }
+ return to_queue;
+ }
+
+ Future<WriteTask> StartWrite(const std::shared_ptr<RecordBatch>& batch) {
+ rows_written_ += batch->num_rows();
+ WriteTask task{current_filename_, static_cast<uint64_t>(batch->num_rows())};
+ if (!latest_open_file_) {
+ ARROW_ASSIGN_OR_RAISE(latest_open_file_, OpenFileQueue(current_filename_));
+ }
+ return latest_open_file_->Push(batch).Then([task] { return task; });
+ }
+
+ Result<std::string> GetNextFilename() {
+ auto basename = ::arrow::internal::Replace(
+ write_options_.basename_template, kIntegerToken, std::to_string(file_counter_++));
+ if (!basename) {
+ return Status::Invalid("string interpolation of basename template failed");
+ }
+
+ return fs::internal::ConcatAbstractPath(directory_, *basename);
+ }
+
+ Status FinishCurrentFile() {
+ if (latest_open_file_) {
+ latest_open_file_ = nullptr;
+ }
+ rows_written_ = 0;
+ return GetNextFilename().Value(&current_filename_);
+ }
+
+ Result<std::shared_ptr<FileWriter>> OpenWriter(const std::string& filename) {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<io::OutputStream> out_stream,
+ write_options_.filesystem->OpenOutputStream(filename));
+ return write_options_.format()->MakeWriter(std::move(out_stream), schema_,
+ write_options_.file_write_options,
+ {write_options_.filesystem, filename});
+ }
+
+ Result<std::shared_ptr<DatasetWriterFileQueue>> OpenFileQueue(
+ const std::string& filename) {
+ Future<std::shared_ptr<FileWriter>> file_writer_fut =
+ init_future_.Then([this, filename] {
+ ::arrow::internal::Executor* io_executor =
+ write_options_.filesystem->io_context().executor();
+ return DeferNotOk(
+ io_executor->Submit([this, filename]() { return OpenWriter(filename); }));
+ });
+ auto file_queue = util::MakeSharedAsync<DatasetWriterFileQueue>(
+ file_writer_fut, write_options_, visitors_mutex_);
+ RETURN_NOT_OK(task_group_.AddTask(
+ file_queue->on_closed().Then([this] { open_files_throttle_->Release(1); })));
+ return file_queue;
+ }
+
+ uint64_t rows_written() const { return rows_written_; }
+
+ void PrepareDirectory() {
+ init_future_ =
+ DeferNotOk(write_options_.filesystem->io_context().executor()->Submit([this] {
+ RETURN_NOT_OK(write_options_.filesystem->CreateDir(directory_));
+ if (write_options_.existing_data_behavior ==
+ ExistingDataBehavior::kDeleteMatchingPartitions) {
+ return write_options_.filesystem->DeleteDirContents(directory_);
+ }
+ return Status::OK();
+ }));
+ }
+
+ static Result<std::unique_ptr<DatasetWriterDirectoryQueue,
+ util::DestroyingDeleter<DatasetWriterDirectoryQueue>>>
+ Make(util::AsyncTaskGroup* task_group,
+ const FileSystemDatasetWriteOptions& write_options, Throttle* open_files_throttle,
+ std::shared_ptr<Schema> schema, std::string dir, std::mutex* visitors_mutex) {
+ auto dir_queue = util::MakeUniqueAsync<DatasetWriterDirectoryQueue>(
+ std::move(dir), std::move(schema), write_options, open_files_throttle,
+ visitors_mutex);
+ RETURN_NOT_OK(task_group->AddTask(dir_queue->on_closed()));
+ dir_queue->PrepareDirectory();
+ ARROW_ASSIGN_OR_RAISE(dir_queue->current_filename_, dir_queue->GetNextFilename());
+ // std::move required to make RTools 3.5 mingw compiler happy
+ return std::move(dir_queue);
+ }
+
+ Future<> DoDestroy() override {
+ latest_open_file_.reset();
+ return task_group_.End();
+ }
+
+ private:
+ util::AsyncTaskGroup task_group_;
+ std::string directory_;
+ std::shared_ptr<Schema> schema_;
+ const FileSystemDatasetWriteOptions& write_options_;
+ Throttle* open_files_throttle_;
+ std::mutex* visitors_mutex_;
+ Future<> init_future_;
+ std::string current_filename_;
+ std::shared_ptr<DatasetWriterFileQueue> latest_open_file_;
+ uint64_t rows_written_ = 0;
+ uint32_t file_counter_ = 0;
+};
+
+Status ValidateBasenameTemplate(util::string_view basename_template) {
+ if (basename_template.find(fs::internal::kSep) != util::string_view::npos) {
+ return Status::Invalid("basename_template contained '/'");
+ }
+ size_t token_start = basename_template.find(kIntegerToken);
+ if (token_start == util::string_view::npos) {
+ return Status::Invalid("basename_template did not contain '", kIntegerToken, "'");
+ }
+ size_t next_token_start = basename_template.find(kIntegerToken, token_start + 1);
+ if (next_token_start != util::string_view::npos) {
+ return Status::Invalid("basename_template contained '", kIntegerToken,
+ "' more than once");
+ }
+ return Status::OK();
+}
+
+Status EnsureDestinationValid(const FileSystemDatasetWriteOptions& options) {
+ if (options.existing_data_behavior == ExistingDataBehavior::kError) {
+ fs::FileSelector selector;
+ selector.base_dir = options.base_dir;
+ selector.recursive = true;
+ Result<std::vector<fs::FileInfo>> maybe_files =
+ options.filesystem->GetFileInfo(selector);
+ if (!maybe_files.ok()) {
+ // If the path doesn't exist then continue
+ return Status::OK();
+ }
+ if (maybe_files->size() > 1) {
+ return Status::Invalid(
+ "Could not write to ", options.base_dir,
+ " as the directory is not empty and existing_data_behavior is to error");
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+class DatasetWriter::DatasetWriterImpl : public util::AsyncDestroyable {
+ public:
+ DatasetWriterImpl(FileSystemDatasetWriteOptions write_options, uint64_t max_rows_queued)
+ : write_options_(std::move(write_options)),
+ rows_in_flight_throttle_(max_rows_queued),
+ open_files_throttle_(write_options_.max_open_files) {}
+
+ Future<> WriteRecordBatch(std::shared_ptr<RecordBatch> batch,
+ const std::string& directory) {
+ RETURN_NOT_OK(CheckError());
+ if (batch->num_rows() == 0) {
+ return Future<>::MakeFinished();
+ }
+ if (!directory.empty()) {
+ auto full_path =
+ fs::internal::ConcatAbstractPath(write_options_.base_dir, directory);
+ return DoWriteRecordBatch(std::move(batch), full_path);
+ } else {
+ return DoWriteRecordBatch(std::move(batch), write_options_.base_dir);
+ }
+ }
+
+ protected:
+ Status CloseLargestFile() {
+ std::shared_ptr<DatasetWriterDirectoryQueue> largest = nullptr;
+ uint64_t largest_num_rows = 0;
+ for (auto& dir_queue : directory_queues_) {
+ if (dir_queue.second->rows_written() > largest_num_rows) {
+ largest_num_rows = dir_queue.second->rows_written();
+ largest = dir_queue.second;
+ }
+ }
+ DCHECK_NE(largest, nullptr);
+ return largest->FinishCurrentFile();
+ }
+
+ Future<> DoWriteRecordBatch(std::shared_ptr<RecordBatch> batch,
+ const std::string& directory) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto dir_queue_itr,
+ ::arrow::internal::GetOrInsertGenerated(
+ &directory_queues_, directory, [this, &batch](const std::string& dir) {
+ return DatasetWriterDirectoryQueue::Make(
+ &task_group_, write_options_, &open_files_throttle_, batch->schema(),
+ dir, &visitors_mutex_);
+ }));
+ std::shared_ptr<DatasetWriterDirectoryQueue> dir_queue = dir_queue_itr->second;
+ std::vector<Future<WriteTask>> scheduled_writes;
+ Future<> backpressure;
+ while (batch) {
+ // Keep opening new files until batch is done.
+ std::shared_ptr<RecordBatch> remainder;
+ bool will_open_file = false;
+ ARROW_ASSIGN_OR_RAISE(auto next_chunk, dir_queue->NextWritableChunk(
+ batch, &remainder, &will_open_file));
+
+ backpressure = rows_in_flight_throttle_.Acquire(next_chunk->num_rows());
+ if (!backpressure.is_finished()) {
+ break;
+ }
+ if (will_open_file) {
+ backpressure = open_files_throttle_.Acquire(1);
+ if (!backpressure.is_finished()) {
+ RETURN_NOT_OK(CloseLargestFile());
+ break;
+ }
+ }
+ scheduled_writes.push_back(dir_queue->StartWrite(next_chunk));
+ batch = std::move(remainder);
+ if (batch) {
+ RETURN_NOT_OK(dir_queue->FinishCurrentFile());
+ }
+ }
+
+ for (auto& scheduled_write : scheduled_writes) {
+ RETURN_NOT_OK(task_group_.AddTask(scheduled_write.Then(
+ [this](const WriteTask& write) {
+ rows_in_flight_throttle_.Release(write.num_rows);
+ },
+ [this](const Status& err) { SetError(err); })));
+ // The previously added callback could run immediately and set err_ so we check
+ // it each time through the loop
+ RETURN_NOT_OK(CheckError());
+ }
+ if (batch) {
+ return backpressure.Then(
+ [this, batch, directory] { return DoWriteRecordBatch(batch, directory); });
+ }
+ return Future<>::MakeFinished();
+ }
+
+ void SetError(Status st) {
+ std::lock_guard<std::mutex> lg(mutex_);
+ err_ = std::move(st);
+ }
+
+ Status CheckError() {
+ std::lock_guard<std::mutex> lg(mutex_);
+ return err_;
+ }
+
+ Future<> DoDestroy() override {
+ directory_queues_.clear();
+ return task_group_.End().Then([this] { return err_; });
+ }
+
+ util::AsyncTaskGroup task_group_;
+ FileSystemDatasetWriteOptions write_options_;
+ Throttle rows_in_flight_throttle_;
+ Throttle open_files_throttle_;
+ std::unordered_map<std::string, std::shared_ptr<DatasetWriterDirectoryQueue>>
+ directory_queues_;
+ std::mutex mutex_;
+ // A mutex to guard access to the visitor callbacks
+ std::mutex visitors_mutex_;
+ Status err_;
+};
+
+DatasetWriter::DatasetWriter(FileSystemDatasetWriteOptions write_options,
+ uint64_t max_rows_queued)
+ : impl_(util::MakeUniqueAsync<DatasetWriterImpl>(std::move(write_options),
+ max_rows_queued)) {}
+
+Result<std::unique_ptr<DatasetWriter>> DatasetWriter::Make(
+ FileSystemDatasetWriteOptions write_options, uint64_t max_rows_queued) {
+ RETURN_NOT_OK(ValidateBasenameTemplate(write_options.basename_template));
+ RETURN_NOT_OK(EnsureDestinationValid(write_options));
+ return std::unique_ptr<DatasetWriter>(
+ new DatasetWriter(std::move(write_options), max_rows_queued));
+}
+
+DatasetWriter::~DatasetWriter() = default;
+
+Future<> DatasetWriter::WriteRecordBatch(std::shared_ptr<RecordBatch> batch,
+ const std::string& directory) {
+ return impl_->WriteRecordBatch(std::move(batch), directory);
+}
+
+Future<> DatasetWriter::Finish() {
+ Future<> finished = impl_->on_closed();
+ impl_.reset();
+ return finished;
+}
+
+} // namespace internal
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/dataset_writer.h b/src/arrow/cpp/src/arrow/dataset/dataset_writer.h
new file mode 100644
index 000000000..b014f9635
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/dataset_writer.h
@@ -0,0 +1,97 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+
+#include "arrow/dataset/file_base.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/util/async_util.h"
+#include "arrow/util/future.h"
+
+namespace arrow {
+namespace dataset {
+namespace internal {
+
+constexpr uint64_t kDefaultDatasetWriterMaxRowsQueued = 64 * 1024 * 1024;
+
+/// \brief Utility class that manages a set of writers to different paths
+///
+/// Writers may be closed and reopened (and a new file created) based on the dataset
+/// write options (for example, max_rows_per_file or max_open_files)
+///
+/// The dataset writer enforces its own back pressure based on the # of rows (as opposed
+/// to # of batches which is how it is typically enforced elsewhere) and # of files.
+class ARROW_DS_EXPORT DatasetWriter {
+ public:
+ /// \brief Create a dataset writer
+ ///
+ /// Will fail if basename_template is invalid or if there is existing data and
+ /// existing_data_behavior is kError
+ ///
+ /// \param write_options options to control how the data should be written
+ /// \param max_rows_queued max # of rows allowed to be queued before the dataset_writer
+ /// will ask for backpressure
+ static Result<std::unique_ptr<DatasetWriter>> Make(
+ FileSystemDatasetWriteOptions write_options,
+ uint64_t max_rows_queued = kDefaultDatasetWriterMaxRowsQueued);
+
+ ~DatasetWriter();
+
+ /// \brief Write a batch to the dataset
+ /// \param[in] batch The batch to write
+ /// \param[in] directory The directory to write to
+ ///
+ /// Note: The written filename will be {directory}/{filename_factory(i)} where i is a
+ /// counter controlled by `max_open_files` and `max_rows_per_file`
+ ///
+ /// If multiple WriteRecordBatch calls arrive with the same `directory` then the batches
+ /// may be written to the same file.
+ ///
+ /// The returned future will be marked finished when the record batch has been queued
+ /// to be written. If the returned future is unfinished then this indicates the dataset
+ /// writer's queue is full and the data provider should pause.
+ ///
+ /// This method is NOT async reentrant. The returned future will only be unfinished
+ /// if back pressure needs to be applied. Async reentrancy is not necessary for
+ /// concurrent writes to happen. Calling this method again before the previous future
+ /// completes will not just violate max_rows_queued but likely lead to race conditions.
+ ///
+ /// One thing to note is that the ordering of your data can affect your maximum
+ /// potential parallelism. If this seems odd then consider a dataset where the first
+ /// 1000 batches go to the same directory and then the 1001st batch goes to a different
+ /// directory. The only way to get two parallel writes immediately would be to queue
+ /// all 1000 pending writes to the first directory.
+ Future<> WriteRecordBatch(std::shared_ptr<RecordBatch> batch,
+ const std::string& directory);
+
+ /// Finish all pending writes and close any open files
+ Future<> Finish();
+
+ protected:
+ DatasetWriter(FileSystemDatasetWriteOptions write_options,
+ uint64_t max_rows_queued = kDefaultDatasetWriterMaxRowsQueued);
+
+ class DatasetWriterImpl;
+ std::unique_ptr<DatasetWriterImpl, util::DestroyingDeleter<DatasetWriterImpl>> impl_;
+};
+
+} // namespace internal
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/dataset_writer_test.cc b/src/arrow/cpp/src/arrow/dataset/dataset_writer_test.cc
new file mode 100644
index 000000000..bf38c2f60
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/dataset_writer_test.cc
@@ -0,0 +1,349 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/dataset_writer.h"
+
+#include <chrono>
+#include <mutex>
+#include <vector>
+
+#include "arrow/dataset/file_ipc.h"
+#include "arrow/filesystem/mockfs.h"
+#include "arrow/filesystem/test_util.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/table.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/optional.h"
+#include "gtest/gtest.h"
+
+namespace arrow {
+namespace dataset {
+namespace internal {
+
+using arrow::fs::internal::MockFileInfo;
+using arrow::fs::internal::MockFileSystem;
+
+class DatasetWriterTestFixture : public testing::Test {
+ protected:
+ struct ExpectedFile {
+ std::string filename;
+ uint64_t start;
+ uint64_t num_rows;
+ };
+
+ void SetUp() override {
+ fs::TimePoint mock_now = std::chrono::system_clock::now();
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<fs::FileSystem> fs,
+ MockFileSystem::Make(mock_now, {::arrow::fs::Dir("testdir")}));
+ filesystem_ = std::dynamic_pointer_cast<MockFileSystem>(fs);
+ schema_ = schema({field("int64", int64())});
+ write_options_.filesystem = filesystem_;
+ write_options_.basename_template = "chunk-{i}.arrow";
+ write_options_.base_dir = "testdir";
+ write_options_.writer_pre_finish = [this](FileWriter* writer) {
+ pre_finish_visited_.push_back(writer->destination().path);
+ return Status::OK();
+ };
+ write_options_.writer_post_finish = [this](FileWriter* writer) {
+ post_finish_visited_.push_back(writer->destination().path);
+ return Status::OK();
+ };
+ std::shared_ptr<FileFormat> format = std::make_shared<IpcFileFormat>();
+ write_options_.file_write_options = format->DefaultWriteOptions();
+ }
+
+ std::shared_ptr<fs::GatedMockFilesystem> UseGatedFs() {
+ fs::TimePoint mock_now = std::chrono::system_clock::now();
+ auto fs = std::make_shared<fs::GatedMockFilesystem>(mock_now);
+ ARROW_EXPECT_OK(fs->CreateDir("testdir"));
+ write_options_.filesystem = fs;
+ filesystem_ = fs;
+ return fs;
+ }
+
+ std::shared_ptr<RecordBatch> MakeBatch(uint64_t start, uint64_t num_rows) {
+ Int64Builder builder;
+ for (uint64_t i = 0; i < num_rows; i++) {
+ ARROW_EXPECT_OK(builder.Append(i + start));
+ }
+ EXPECT_OK_AND_ASSIGN(std::shared_ptr<Array> arr, builder.Finish());
+ return RecordBatch::Make(schema_, static_cast<int64_t>(num_rows), {std::move(arr)});
+ }
+
+ std::shared_ptr<RecordBatch> MakeBatch(uint64_t num_rows) {
+ std::shared_ptr<RecordBatch> batch = MakeBatch(counter_, num_rows);
+ counter_ += num_rows;
+ return batch;
+ }
+
+ util::optional<MockFileInfo> FindFile(const std::string& filename) {
+ for (const auto& mock_file : filesystem_->AllFiles()) {
+ if (mock_file.full_path == filename) {
+ return mock_file;
+ }
+ }
+ return util::nullopt;
+ }
+
+ void AssertVisited(const std::vector<std::string>& actual_paths,
+ const std::string& expected_path) {
+ const auto found = std::find(actual_paths.begin(), actual_paths.end(), expected_path);
+ ASSERT_NE(found, actual_paths.end())
+ << "The file " << expected_path << " was not in the list of files visited";
+ }
+
+ std::shared_ptr<RecordBatch> ReadAsBatch(util::string_view data) {
+ std::shared_ptr<io::RandomAccessFile> in_stream =
+ std::make_shared<io::BufferReader>(data);
+ EXPECT_OK_AND_ASSIGN(std::shared_ptr<ipc::RecordBatchFileReader> reader,
+ ipc::RecordBatchFileReader::Open(std::move(in_stream)));
+ RecordBatchVector batches;
+ for (int i = 0; i < reader->num_record_batches(); i++) {
+ EXPECT_OK_AND_ASSIGN(std::shared_ptr<RecordBatch> next_batch,
+ reader->ReadRecordBatch(i));
+ batches.push_back(next_batch);
+ }
+ EXPECT_OK_AND_ASSIGN(std::shared_ptr<Table> table, Table::FromRecordBatches(batches));
+ EXPECT_OK_AND_ASSIGN(std::shared_ptr<Table> combined_table, table->CombineChunks());
+ EXPECT_OK_AND_ASSIGN(std::shared_ptr<RecordBatch> batch,
+ TableBatchReader(*combined_table).Next());
+ return batch;
+ }
+
+ void AssertFileCreated(const util::optional<MockFileInfo>& maybe_file,
+ const std::string& expected_filename) {
+ ASSERT_TRUE(maybe_file.has_value())
+ << "The file " << expected_filename << " was not created";
+ {
+ SCOPED_TRACE("pre_finish");
+ AssertVisited(pre_finish_visited_, expected_filename);
+ }
+ {
+ SCOPED_TRACE("post_finish");
+ AssertVisited(post_finish_visited_, expected_filename);
+ }
+ }
+
+ void AssertCreatedData(const std::vector<ExpectedFile>& expected_files) {
+ counter_ = 0;
+ for (const auto& expected_file : expected_files) {
+ util::optional<MockFileInfo> written_file = FindFile(expected_file.filename);
+ AssertFileCreated(written_file, expected_file.filename);
+ AssertBatchesEqual(*MakeBatch(expected_file.start, expected_file.num_rows),
+ *ReadAsBatch(written_file->data));
+ }
+ }
+
+ void AssertFilesCreated(const std::vector<std::string>& expected_files) {
+ for (const std::string& expected_file : expected_files) {
+ util::optional<MockFileInfo> written_file = FindFile(expected_file);
+ AssertFileCreated(written_file, expected_file);
+ }
+ }
+
+ void AssertNotFiles(const std::vector<std::string>& expected_non_files) {
+ for (const auto& expected_non_file : expected_non_files) {
+ util::optional<MockFileInfo> file = FindFile(expected_non_file);
+ ASSERT_FALSE(file.has_value());
+ }
+ }
+
+ void AssertEmptyFiles(const std::vector<std::string>& expected_empty_files) {
+ for (const auto& expected_empty_file : expected_empty_files) {
+ util::optional<MockFileInfo> file = FindFile(expected_empty_file);
+ ASSERT_TRUE(file.has_value());
+ ASSERT_EQ("", file->data);
+ }
+ }
+
+ std::shared_ptr<MockFileSystem> filesystem_;
+ std::shared_ptr<Schema> schema_;
+ std::vector<std::string> pre_finish_visited_;
+ std::vector<std::string> post_finish_visited_;
+ FileSystemDatasetWriteOptions write_options_;
+ uint64_t counter_ = 0;
+};
+
+TEST_F(DatasetWriterTestFixture, Basic) {
+ EXPECT_OK_AND_ASSIGN(auto dataset_writer, DatasetWriter::Make(write_options_));
+ Future<> queue_fut = dataset_writer->WriteRecordBatch(MakeBatch(100), "");
+ AssertFinished(queue_fut);
+ ASSERT_FINISHES_OK(dataset_writer->Finish());
+ AssertCreatedData({{"testdir/chunk-0.arrow", 0, 100}});
+}
+
+TEST_F(DatasetWriterTestFixture, MaxRowsOneWrite) {
+ write_options_.max_rows_per_file = 10;
+ EXPECT_OK_AND_ASSIGN(auto dataset_writer, DatasetWriter::Make(write_options_));
+ Future<> queue_fut = dataset_writer->WriteRecordBatch(MakeBatch(35), "");
+ AssertFinished(queue_fut);
+ ASSERT_FINISHES_OK(dataset_writer->Finish());
+ AssertCreatedData({{"testdir/chunk-0.arrow", 0, 10},
+ {"testdir/chunk-1.arrow", 10, 10},
+ {"testdir/chunk-2.arrow", 20, 10},
+ {"testdir/chunk-3.arrow", 30, 5}});
+}
+
+TEST_F(DatasetWriterTestFixture, MaxRowsManyWrites) {
+ write_options_.max_rows_per_file = 10;
+ EXPECT_OK_AND_ASSIGN(auto dataset_writer, DatasetWriter::Make(write_options_));
+ ASSERT_FINISHES_OK(dataset_writer->WriteRecordBatch(MakeBatch(3), ""));
+ ASSERT_FINISHES_OK(dataset_writer->WriteRecordBatch(MakeBatch(3), ""));
+ ASSERT_FINISHES_OK(dataset_writer->WriteRecordBatch(MakeBatch(3), ""));
+ ASSERT_FINISHES_OK(dataset_writer->WriteRecordBatch(MakeBatch(3), ""));
+ ASSERT_FINISHES_OK(dataset_writer->WriteRecordBatch(MakeBatch(3), ""));
+ ASSERT_FINISHES_OK(dataset_writer->WriteRecordBatch(MakeBatch(3), ""));
+ ASSERT_FINISHES_OK(dataset_writer->Finish());
+ AssertCreatedData({{"testdir/chunk-0.arrow", 0, 10}, {"testdir/chunk-1.arrow", 10, 8}});
+}
+
+TEST_F(DatasetWriterTestFixture, ConcurrentWritesSameFile) {
+ // Use a gated filesystem to queue up many writes behind a file open to make sure the
+ // file isn't opened multiple times.
+ auto gated_fs = UseGatedFs();
+ EXPECT_OK_AND_ASSIGN(auto dataset_writer, DatasetWriter::Make(write_options_));
+ for (int i = 0; i < 10; i++) {
+ Future<> queue_fut = dataset_writer->WriteRecordBatch(MakeBatch(10), "");
+ AssertFinished(queue_fut);
+ ASSERT_FINISHES_OK(queue_fut);
+ }
+ ASSERT_OK(gated_fs->WaitForOpenOutputStream(1));
+ ASSERT_OK(gated_fs->UnlockOpenOutputStream(1));
+ ASSERT_FINISHES_OK(dataset_writer->Finish());
+ AssertCreatedData({{"testdir/chunk-0.arrow", 0, 100}});
+}
+
+TEST_F(DatasetWriterTestFixture, ConcurrentWritesDifferentFiles) {
+ // NBATCHES must be less than I/O executor concurrency to avoid deadlock / test failure
+ constexpr int NBATCHES = 6;
+ auto gated_fs = UseGatedFs();
+ std::vector<ExpectedFile> expected_files;
+ EXPECT_OK_AND_ASSIGN(auto dataset_writer, DatasetWriter::Make(write_options_));
+ for (int i = 0; i < NBATCHES; i++) {
+ std::string i_str = std::to_string(i);
+ expected_files.push_back(ExpectedFile{"testdir/part" + i_str + "/chunk-0.arrow",
+ static_cast<uint64_t>(i) * 10, 10});
+ Future<> queue_fut = dataset_writer->WriteRecordBatch(MakeBatch(10), "part" + i_str);
+ AssertFinished(queue_fut);
+ ASSERT_FINISHES_OK(queue_fut);
+ }
+ ASSERT_OK(gated_fs->WaitForOpenOutputStream(NBATCHES));
+ ASSERT_OK(gated_fs->UnlockOpenOutputStream(NBATCHES));
+ ASSERT_FINISHES_OK(dataset_writer->Finish());
+ AssertCreatedData(expected_files);
+}
+
+TEST_F(DatasetWriterTestFixture, MaxOpenFiles) {
+ auto gated_fs = UseGatedFs();
+ write_options_.max_open_files = 2;
+ EXPECT_OK_AND_ASSIGN(auto dataset_writer, DatasetWriter::Make(write_options_));
+
+ ASSERT_FINISHES_OK(dataset_writer->WriteRecordBatch(MakeBatch(10), "part0"));
+ ASSERT_FINISHES_OK(dataset_writer->WriteRecordBatch(MakeBatch(10), "part1"));
+ ASSERT_FINISHES_OK(dataset_writer->WriteRecordBatch(MakeBatch(10), "part0"));
+ Future<> fut = dataset_writer->WriteRecordBatch(MakeBatch(10), "part2");
+ // Backpressure will be applied until an existing file can be evicted
+ AssertNotFinished(fut);
+
+ // Ungate the writes to relieve the pressure, testdir/part0 should be closed
+ ASSERT_OK(gated_fs->WaitForOpenOutputStream(2));
+ ASSERT_OK(gated_fs->UnlockOpenOutputStream(5));
+ ASSERT_FINISHES_OK(fut);
+
+ ASSERT_FINISHES_OK(dataset_writer->WriteRecordBatch(MakeBatch(10), "part0"));
+ // Following call should resume existing write but, on slow test systems, the old
+ // write may have already been finished
+ ASSERT_FINISHES_OK(dataset_writer->WriteRecordBatch(MakeBatch(10), "part1"));
+ ASSERT_FINISHES_OK(dataset_writer->Finish());
+ AssertFilesCreated({"testdir/part0/chunk-0.arrow", "testdir/part0/chunk-1.arrow",
+ "testdir/part1/chunk-0.arrow", "testdir/part2/chunk-0.arrow"});
+}
+
+TEST_F(DatasetWriterTestFixture, DeleteExistingData) {
+ fs::TimePoint mock_now = std::chrono::system_clock::now();
+ ASSERT_OK_AND_ASSIGN(
+ std::shared_ptr<fs::FileSystem> fs,
+ MockFileSystem::Make(
+ mock_now, {::arrow::fs::Dir("testdir"), fs::File("testdir/subdir/foo.txt"),
+ fs::File("testdir/chunk-5.arrow"), fs::File("testdir/blah.txt")}));
+ filesystem_ = std::dynamic_pointer_cast<MockFileSystem>(fs);
+ write_options_.filesystem = filesystem_;
+ write_options_.existing_data_behavior = ExistingDataBehavior::kDeleteMatchingPartitions;
+ EXPECT_OK_AND_ASSIGN(auto dataset_writer, DatasetWriter::Make(write_options_));
+ Future<> queue_fut = dataset_writer->WriteRecordBatch(MakeBatch(100), "");
+ AssertFinished(queue_fut);
+ ASSERT_FINISHES_OK(dataset_writer->Finish());
+ AssertCreatedData({{"testdir/chunk-0.arrow", 0, 100}});
+ AssertNotFiles({"testdir/chunk-5.arrow", "testdir/blah.txt", "testdir/subdir/foo.txt"});
+}
+
+TEST_F(DatasetWriterTestFixture, PartitionedDeleteExistingData) {
+ fs::TimePoint mock_now = std::chrono::system_clock::now();
+ ASSERT_OK_AND_ASSIGN(
+ std::shared_ptr<fs::FileSystem> fs,
+ MockFileSystem::Make(
+ mock_now, {::arrow::fs::Dir("testdir"), fs::File("testdir/part0/foo.arrow"),
+ fs::File("testdir/part1/bar.arrow")}));
+ filesystem_ = std::dynamic_pointer_cast<MockFileSystem>(fs);
+ write_options_.filesystem = filesystem_;
+ write_options_.existing_data_behavior = ExistingDataBehavior::kDeleteMatchingPartitions;
+ EXPECT_OK_AND_ASSIGN(auto dataset_writer, DatasetWriter::Make(write_options_));
+ Future<> queue_fut = dataset_writer->WriteRecordBatch(MakeBatch(100), "part0");
+ AssertFinished(queue_fut);
+ ASSERT_FINISHES_OK(dataset_writer->Finish());
+ AssertCreatedData({{"testdir/part0/chunk-0.arrow", 0, 100}});
+ AssertNotFiles({"testdir/part0/foo.arrow"});
+ AssertEmptyFiles({"testdir/part1/bar.arrow"});
+}
+
+TEST_F(DatasetWriterTestFixture, LeaveExistingData) {
+ fs::TimePoint mock_now = std::chrono::system_clock::now();
+ ASSERT_OK_AND_ASSIGN(
+ std::shared_ptr<fs::FileSystem> fs,
+ MockFileSystem::Make(
+ mock_now, {::arrow::fs::Dir("testdir"), fs::File("testdir/chunk-0.arrow"),
+ fs::File("testdir/chunk-5.arrow"), fs::File("testdir/blah.txt")}));
+ filesystem_ = std::dynamic_pointer_cast<MockFileSystem>(fs);
+ write_options_.filesystem = filesystem_;
+ write_options_.existing_data_behavior = ExistingDataBehavior::kOverwriteOrIgnore;
+ EXPECT_OK_AND_ASSIGN(auto dataset_writer, DatasetWriter::Make(write_options_));
+ Future<> queue_fut = dataset_writer->WriteRecordBatch(MakeBatch(100), "");
+ AssertFinished(queue_fut);
+ ASSERT_FINISHES_OK(dataset_writer->Finish());
+ AssertCreatedData({{"testdir/chunk-0.arrow", 0, 100}});
+ AssertEmptyFiles({"testdir/chunk-5.arrow", "testdir/blah.txt"});
+}
+
+TEST_F(DatasetWriterTestFixture, ErrOnExistingData) {
+ fs::TimePoint mock_now = std::chrono::system_clock::now();
+ ASSERT_OK_AND_ASSIGN(
+ std::shared_ptr<fs::FileSystem> fs,
+ MockFileSystem::Make(
+ mock_now, {::arrow::fs::Dir("testdir"), fs::File("testdir/chunk-0.arrow"),
+ fs::File("testdir/chunk-5.arrow"), fs::File("testdir/blah.txt")}));
+ filesystem_ = std::dynamic_pointer_cast<MockFileSystem>(fs);
+ write_options_.filesystem = filesystem_;
+ ASSERT_RAISES(Invalid, DatasetWriter::Make(write_options_));
+ AssertEmptyFiles(
+ {"testdir/chunk-0.arrow", "testdir/chunk-5.arrow", "testdir/blah.txt"});
+}
+
+} // namespace internal
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/discovery.cc b/src/arrow/cpp/src/arrow/dataset/discovery.cc
new file mode 100644
index 000000000..0f9d479b9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/discovery.cc
@@ -0,0 +1,282 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/discovery.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "arrow/dataset/dataset.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/partition.h"
+#include "arrow/dataset/type_fwd.h"
+#include "arrow/filesystem/path_util.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace dataset {
+
+DatasetFactory::DatasetFactory() : root_partition_(compute::literal(true)) {}
+
+Result<std::shared_ptr<Schema>> DatasetFactory::Inspect(InspectOptions options) {
+ ARROW_ASSIGN_OR_RAISE(auto schemas, InspectSchemas(std::move(options)));
+
+ if (schemas.empty()) {
+ return arrow::schema({});
+ }
+
+ return UnifySchemas(schemas);
+}
+
+Result<std::shared_ptr<Dataset>> DatasetFactory::Finish() {
+ FinishOptions options;
+ return Finish(options);
+}
+
+Result<std::shared_ptr<Dataset>> DatasetFactory::Finish(std::shared_ptr<Schema> schema) {
+ FinishOptions options;
+ options.schema = schema;
+ return Finish(std::move(options));
+}
+
+UnionDatasetFactory::UnionDatasetFactory(
+ std::vector<std::shared_ptr<DatasetFactory>> factories)
+ : factories_(std::move(factories)) {}
+
+Result<std::shared_ptr<DatasetFactory>> UnionDatasetFactory::Make(
+ std::vector<std::shared_ptr<DatasetFactory>> factories) {
+ for (const auto& factory : factories) {
+ if (factory == nullptr) {
+ return Status::Invalid("Can't accept nullptr DatasetFactory");
+ }
+ }
+
+ return std::shared_ptr<UnionDatasetFactory>{
+ new UnionDatasetFactory(std::move(factories))};
+}
+
+Result<std::vector<std::shared_ptr<Schema>>> UnionDatasetFactory::InspectSchemas(
+ InspectOptions options) {
+ std::vector<std::shared_ptr<Schema>> schemas;
+
+ for (const auto& child_factory : factories_) {
+ ARROW_ASSIGN_OR_RAISE(auto child_schemas, child_factory->InspectSchemas(options));
+ ARROW_ASSIGN_OR_RAISE(auto child_schema, UnifySchemas(child_schemas));
+ schemas.emplace_back(child_schema);
+ }
+
+ return schemas;
+}
+
+Result<std::shared_ptr<Dataset>> UnionDatasetFactory::Finish(FinishOptions options) {
+ std::vector<std::shared_ptr<Dataset>> children;
+
+ if (options.schema == nullptr) {
+ // Set the schema in the option directly for use in `child_factory->Finish()`
+ ARROW_ASSIGN_OR_RAISE(options.schema, Inspect(options.inspect_options));
+ }
+
+ for (const auto& child_factory : factories_) {
+ ARROW_ASSIGN_OR_RAISE(auto child, child_factory->Finish(options));
+ children.emplace_back(child);
+ }
+
+ return std::shared_ptr<Dataset>(new UnionDataset(options.schema, std::move(children)));
+}
+
+FileSystemDatasetFactory::FileSystemDatasetFactory(
+ std::vector<fs::FileInfo> files, std::shared_ptr<fs::FileSystem> filesystem,
+ std::shared_ptr<FileFormat> format, FileSystemFactoryOptions options)
+ : files_(std::move(files)),
+ fs_(std::move(filesystem)),
+ format_(std::move(format)),
+ options_(std::move(options)) {}
+
+Result<std::shared_ptr<DatasetFactory>> FileSystemDatasetFactory::Make(
+ std::shared_ptr<fs::FileSystem> filesystem, const std::vector<std::string>& paths,
+ std::shared_ptr<FileFormat> format, FileSystemFactoryOptions options) {
+ std::vector<fs::FileInfo> filtered_files;
+ for (const auto& path : paths) {
+ if (options.exclude_invalid_files) {
+ ARROW_ASSIGN_OR_RAISE(auto supported,
+ format->IsSupported(FileSource(path, filesystem)));
+ if (!supported) {
+ continue;
+ }
+ }
+
+ filtered_files.emplace_back(path);
+ }
+
+ return std::shared_ptr<DatasetFactory>(
+ new FileSystemDatasetFactory(std::move(filtered_files), std::move(filesystem),
+ std::move(format), std::move(options)));
+}
+
+Result<std::shared_ptr<DatasetFactory>> FileSystemDatasetFactory::Make(
+ std::shared_ptr<fs::FileSystem> filesystem, const std::vector<fs::FileInfo>& files,
+ std::shared_ptr<FileFormat> format, FileSystemFactoryOptions options) {
+ std::vector<fs::FileInfo> filtered_files;
+ for (const auto& info : files) {
+ if (options.exclude_invalid_files) {
+ ARROW_ASSIGN_OR_RAISE(auto supported,
+ format->IsSupported(FileSource(info, filesystem)));
+ if (!supported) {
+ continue;
+ }
+ }
+
+ filtered_files.emplace_back(info);
+ }
+
+ return std::shared_ptr<DatasetFactory>(
+ new FileSystemDatasetFactory(std::move(filtered_files), std::move(filesystem),
+ std::move(format), std::move(options)));
+}
+
+bool StartsWithAnyOf(const std::string& path, const std::vector<std::string>& prefixes) {
+ if (prefixes.empty()) {
+ return false;
+ }
+
+ auto parts = fs::internal::SplitAbstractPath(path);
+ return std::any_of(parts.cbegin(), parts.cend(), [&](util::string_view part) {
+ return std::any_of(prefixes.cbegin(), prefixes.cend(), [&](util::string_view prefix) {
+ return util::string_view(part).starts_with(prefix);
+ });
+ });
+}
+
+Result<std::shared_ptr<DatasetFactory>> FileSystemDatasetFactory::Make(
+ std::shared_ptr<fs::FileSystem> filesystem, fs::FileSelector selector,
+ std::shared_ptr<FileFormat> format, FileSystemFactoryOptions options) {
+ // By automatically setting the options base_dir to the selector's base_dir,
+ // we provide a better experience for user providing Partitioning that are
+ // relative to the base_dir instead of the full path.
+ if (options.partition_base_dir.empty() && !selector.base_dir.empty()) {
+ options.partition_base_dir = selector.base_dir;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(selector.base_dir, filesystem->NormalizePath(selector.base_dir));
+ ARROW_ASSIGN_OR_RAISE(auto files, filesystem->GetFileInfo(selector));
+
+ // Filter out anything that's not a file or that's explicitly ignored
+ Status st;
+ auto files_end =
+ std::remove_if(files.begin(), files.end(), [&](const fs::FileInfo& info) {
+ if (!info.IsFile()) return true;
+
+ auto relative = fs::internal::RemoveAncestor(selector.base_dir, info.path());
+ if (!relative.has_value()) {
+ st = Status::Invalid("GetFileInfo() yielded path '", info.path(),
+ "', which is outside base dir '", selector.base_dir, "'");
+ return false;
+ }
+
+ if (StartsWithAnyOf(std::string(*relative), options.selector_ignore_prefixes)) {
+ return true;
+ }
+
+ return false;
+ });
+ RETURN_NOT_OK(st);
+ files.erase(files_end, files.end());
+
+ // Sorting by path guarantees a stability sometimes needed by unit tests.
+ std::sort(files.begin(), files.end(), fs::FileInfo::ByPath());
+
+ return Make(std::move(filesystem), std::move(files), std::move(format),
+ std::move(options));
+}
+
+Result<std::shared_ptr<DatasetFactory>> FileSystemDatasetFactory::Make(
+ std::string uri, std::shared_ptr<FileFormat> format,
+ FileSystemFactoryOptions options) {
+ std::string internal_path;
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<fs::FileSystem> filesystem,
+ arrow::fs::FileSystemFromUri(uri, &internal_path))
+ ARROW_ASSIGN_OR_RAISE(fs::FileInfo file_info, filesystem->GetFileInfo(internal_path))
+ return std::shared_ptr<DatasetFactory>(new FileSystemDatasetFactory(
+ {file_info}, std::move(filesystem), std::move(format), std::move(options)));
+}
+
+Result<std::vector<std::shared_ptr<Schema>>> FileSystemDatasetFactory::InspectSchemas(
+ InspectOptions options) {
+ std::vector<std::shared_ptr<Schema>> schemas;
+
+ const bool has_fragments_limit = options.fragments >= 0;
+ int fragments = options.fragments;
+ for (const auto& info : files_) {
+ if (has_fragments_limit && fragments-- == 0) break;
+ auto result = format_->Inspect({info, fs_});
+ if (ARROW_PREDICT_FALSE(!result.ok())) {
+ return result.status().WithMessage(
+ "Error creating dataset. Could not read schema from '", info.path(),
+ "': ", result.status().message(), ". Is this a '", format_->type_name(),
+ "' file?");
+ }
+ schemas.push_back(result.MoveValueUnsafe());
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto partition_schema,
+ options_.partitioning.GetOrInferSchema(
+ StripPrefixAndFilename(files_, options_.partition_base_dir)));
+ schemas.push_back(partition_schema);
+
+ return schemas;
+}
+
+Result<std::shared_ptr<Dataset>> FileSystemDatasetFactory::Finish(FinishOptions options) {
+ std::shared_ptr<Schema> schema = options.schema;
+ bool schema_missing = schema == nullptr;
+ if (schema_missing) {
+ ARROW_ASSIGN_OR_RAISE(schema, Inspect(options.inspect_options));
+ }
+
+ if (options.validate_fragments && !schema_missing) {
+ // If the schema was not explicitly provided we don't need to validate
+ // since Inspect has already succeeded in producing a valid unified schema.
+ ARROW_ASSIGN_OR_RAISE(auto schemas, InspectSchemas(options.inspect_options));
+ for (const auto& s : schemas) {
+ RETURN_NOT_OK(SchemaBuilder::AreCompatible({schema, s}));
+ }
+ }
+
+ std::shared_ptr<Partitioning> partitioning = options_.partitioning.partitioning();
+ if (partitioning == nullptr) {
+ auto factory = options_.partitioning.factory();
+ ARROW_ASSIGN_OR_RAISE(partitioning, factory->Finish(schema));
+ }
+
+ std::vector<std::shared_ptr<FileFragment>> fragments;
+ for (const auto& info : files_) {
+ auto fixed_path = StripPrefixAndFilename(info.path(), options_.partition_base_dir);
+ ARROW_ASSIGN_OR_RAISE(auto partition, partitioning->Parse(fixed_path));
+ ARROW_ASSIGN_OR_RAISE(auto fragment, format_->MakeFragment({info, fs_}, partition));
+ fragments.push_back(fragment);
+ }
+
+ return FileSystemDataset::Make(std::move(schema), root_partition_, format_, fs_,
+ std::move(fragments), std::move(partitioning));
+}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/discovery.h b/src/arrow/cpp/src/arrow/dataset/discovery.h
new file mode 100644
index 000000000..40c020519
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/discovery.h
@@ -0,0 +1,271 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Logic for automatically determining the structure of multi-file
+/// dataset with possible partitioning according to available
+/// partitioning
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/dataset/partition.h"
+#include "arrow/dataset/type_fwd.h"
+#include "arrow/dataset/visibility.h"
+#include "arrow/filesystem/type_fwd.h"
+#include "arrow/result.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/variant.h"
+
+namespace arrow {
+namespace dataset {
+
+/// \defgroup dataset-discovery Discovery API
+///
+/// @{
+
+struct InspectOptions {
+ /// See `fragments` property.
+ static constexpr int kInspectAllFragments = -1;
+
+ /// Indicate how many fragments should be inspected to infer the unified dataset
+ /// schema. Limiting the number of fragments accessed improves the latency of
+ /// the discovery process when dealing with a high number of fragments and/or
+ /// high latency file systems.
+ ///
+ /// The default value of `1` inspects the schema of the first (in no particular
+ /// order) fragment only. If the dataset has a uniform schema for all fragments,
+ /// this default is the optimal value. In order to inspect all fragments and
+ /// robustly unify their potentially varying schemas, set this option to
+ /// `kInspectAllFragments`. A value of `0` disables inspection of fragments
+ /// altogether so only the partitioning schema will be inspected.
+ int fragments = 1;
+};
+
+struct FinishOptions {
+ /// Finalize the dataset with this given schema. If the schema is not
+ /// provided, infer the schema via the Inspect, see the `inspect_options`
+ /// property.
+ std::shared_ptr<Schema> schema = NULLPTR;
+
+ /// If the schema is not provided, it will be discovered by passing the
+ /// following options to `DatasetDiscovery::Inspect`.
+ InspectOptions inspect_options{};
+
+ /// Indicate if the given Schema (when specified), should be validated against
+ /// the fragments' schemas. `inspect_options` will control how many fragments
+ /// are checked.
+ bool validate_fragments = false;
+};
+
+/// \brief DatasetFactory provides a way to inspect/discover a Dataset's expected
+/// schema before materializing said Dataset.
+class ARROW_DS_EXPORT DatasetFactory {
+ public:
+ /// \brief Get the schemas of the Fragments and Partitioning.
+ virtual Result<std::vector<std::shared_ptr<Schema>>> InspectSchemas(
+ InspectOptions options) = 0;
+
+ /// \brief Get unified schema for the resulting Dataset.
+ Result<std::shared_ptr<Schema>> Inspect(InspectOptions options = {});
+
+ /// \brief Create a Dataset
+ Result<std::shared_ptr<Dataset>> Finish();
+ /// \brief Create a Dataset with the given schema (see \a InspectOptions::schema)
+ Result<std::shared_ptr<Dataset>> Finish(std::shared_ptr<Schema> schema);
+ /// \brief Create a Dataset with the given options
+ virtual Result<std::shared_ptr<Dataset>> Finish(FinishOptions options) = 0;
+
+ /// \brief Optional root partition for the resulting Dataset.
+ const compute::Expression& root_partition() const { return root_partition_; }
+ /// \brief Set the root partition for the resulting Dataset.
+ Status SetRootPartition(compute::Expression partition) {
+ root_partition_ = std::move(partition);
+ return Status::OK();
+ }
+
+ virtual ~DatasetFactory() = default;
+
+ protected:
+ DatasetFactory();
+
+ compute::Expression root_partition_;
+};
+
+/// @}
+
+/// \brief DatasetFactory provides a way to inspect/discover a Dataset's
+/// expected schema before materialization.
+/// \ingroup dataset-implementations
+class ARROW_DS_EXPORT UnionDatasetFactory : public DatasetFactory {
+ public:
+ static Result<std::shared_ptr<DatasetFactory>> Make(
+ std::vector<std::shared_ptr<DatasetFactory>> factories);
+
+ /// \brief Return the list of child DatasetFactory
+ const std::vector<std::shared_ptr<DatasetFactory>>& factories() const {
+ return factories_;
+ }
+
+ /// \brief Get the schemas of the Datasets.
+ ///
+ /// Instead of applying options globally, it applies at each child factory.
+ /// This will not respect `options.fragments` exactly, but will respect the
+ /// spirit of peeking the first fragments or all of them.
+ Result<std::vector<std::shared_ptr<Schema>>> InspectSchemas(
+ InspectOptions options) override;
+
+ /// \brief Create a Dataset.
+ Result<std::shared_ptr<Dataset>> Finish(FinishOptions options) override;
+
+ protected:
+ explicit UnionDatasetFactory(std::vector<std::shared_ptr<DatasetFactory>> factories);
+
+ std::vector<std::shared_ptr<DatasetFactory>> factories_;
+};
+
+/// \ingroup dataset-filesystem
+struct FileSystemFactoryOptions {
+ /// Either an explicit Partitioning or a PartitioningFactory to discover one.
+ ///
+ /// If a factory is provided, it will be used to infer a schema for partition fields
+ /// based on file and directory paths then construct a Partitioning. The default
+ /// is a Partitioning which will yield no partition information.
+ ///
+ /// The (explicit or discovered) partitioning will be applied to discovered files
+ /// and the resulting partition information embedded in the Dataset.
+ PartitioningOrFactory partitioning{Partitioning::Default()};
+
+ /// For the purposes of applying the partitioning, paths will be stripped
+ /// of the partition_base_dir. Files not matching the partition_base_dir
+ /// prefix will be skipped for partition discovery. The ignored files will still
+ /// be part of the Dataset, but will not have partition information.
+ ///
+ /// Example:
+ /// partition_base_dir = "/dataset";
+ ///
+ /// - "/dataset/US/sales.csv" -> "US/sales.csv" will be given to the partitioning
+ ///
+ /// - "/home/john/late_sales.csv" -> Will be ignored for partition discovery.
+ ///
+ /// This is useful for partitioning which parses directory when ordering
+ /// is important, e.g. DirectoryPartitioning.
+ std::string partition_base_dir;
+
+ /// Invalid files (via selector or explicitly) will be excluded by checking
+ /// with the FileFormat::IsSupported method. This will incur IO for each files
+ /// in a serial and single threaded fashion. Disabling this feature will skip the
+ /// IO, but unsupported files may be present in the Dataset
+ /// (resulting in an error at scan time).
+ bool exclude_invalid_files = false;
+
+ /// When discovering from a Selector (and not from an explicit file list), ignore
+ /// files and directories matching any of these prefixes.
+ ///
+ /// Example (with selector = "/dataset/**"):
+ /// selector_ignore_prefixes = {"_", ".DS_STORE" };
+ ///
+ /// - "/dataset/data.csv" -> not ignored
+ /// - "/dataset/_metadata" -> ignored
+ /// - "/dataset/.DS_STORE" -> ignored
+ /// - "/dataset/_hidden/dat" -> ignored
+ /// - "/dataset/nested/.DS_STORE" -> ignored
+ std::vector<std::string> selector_ignore_prefixes = {
+ ".",
+ "_",
+ };
+};
+
+/// \brief FileSystemDatasetFactory creates a Dataset from a vector of
+/// fs::FileInfo or a fs::FileSelector.
+/// \ingroup dataset-filesystem
+class ARROW_DS_EXPORT FileSystemDatasetFactory : public DatasetFactory {
+ public:
+ /// \brief Build a FileSystemDatasetFactory from an explicit list of
+ /// paths.
+ ///
+ /// \param[in] filesystem passed to FileSystemDataset
+ /// \param[in] paths passed to FileSystemDataset
+ /// \param[in] format passed to FileSystemDataset
+ /// \param[in] options see FileSystemFactoryOptions for more information.
+ static Result<std::shared_ptr<DatasetFactory>> Make(
+ std::shared_ptr<fs::FileSystem> filesystem, const std::vector<std::string>& paths,
+ std::shared_ptr<FileFormat> format, FileSystemFactoryOptions options);
+
+ /// \brief Build a FileSystemDatasetFactory from a fs::FileSelector.
+ ///
+ /// The selector will expand to a vector of FileInfo. The expansion/crawling
+ /// is performed in this function call. Thus, the finalized Dataset is
+ /// working with a snapshot of the filesystem.
+ //
+ /// If options.partition_base_dir is not provided, it will be overwritten
+ /// with selector.base_dir.
+ ///
+ /// \param[in] filesystem passed to FileSystemDataset
+ /// \param[in] selector used to crawl and search files
+ /// \param[in] format passed to FileSystemDataset
+ /// \param[in] options see FileSystemFactoryOptions for more information.
+ static Result<std::shared_ptr<DatasetFactory>> Make(
+ std::shared_ptr<fs::FileSystem> filesystem, fs::FileSelector selector,
+ std::shared_ptr<FileFormat> format, FileSystemFactoryOptions options);
+
+ /// \brief Build a FileSystemDatasetFactory from an uri including filesystem
+ /// information.
+ ///
+ /// \param[in] uri passed to FileSystemDataset
+ /// \param[in] format passed to FileSystemDataset
+ /// \param[in] options see FileSystemFactoryOptions for more information.
+ static Result<std::shared_ptr<DatasetFactory>> Make(std::string uri,
+ std::shared_ptr<FileFormat> format,
+ FileSystemFactoryOptions options);
+
+ /// \brief Build a FileSystemDatasetFactory from an explicit list of
+ /// file information.
+ ///
+ /// \param[in] filesystem passed to FileSystemDataset
+ /// \param[in] files passed to FileSystemDataset
+ /// \param[in] format passed to FileSystemDataset
+ /// \param[in] options see FileSystemFactoryOptions for more information.
+ static Result<std::shared_ptr<DatasetFactory>> Make(
+ std::shared_ptr<fs::FileSystem> filesystem, const std::vector<fs::FileInfo>& files,
+ std::shared_ptr<FileFormat> format, FileSystemFactoryOptions options);
+
+ Result<std::vector<std::shared_ptr<Schema>>> InspectSchemas(
+ InspectOptions options) override;
+
+ Result<std::shared_ptr<Dataset>> Finish(FinishOptions options) override;
+
+ protected:
+ FileSystemDatasetFactory(std::vector<fs::FileInfo> files,
+ std::shared_ptr<fs::FileSystem> filesystem,
+ std::shared_ptr<FileFormat> format,
+ FileSystemFactoryOptions options);
+
+ Result<std::shared_ptr<Schema>> PartitionSchema();
+
+ std::vector<fs::FileInfo> files_;
+ std::shared_ptr<fs::FileSystem> fs_;
+ std::shared_ptr<FileFormat> format_;
+ FileSystemFactoryOptions options_;
+};
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/discovery_test.cc b/src/arrow/cpp/src/arrow/dataset/discovery_test.cc
new file mode 100644
index 000000000..a51b3c099
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/discovery_test.cc
@@ -0,0 +1,479 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/discovery.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <utility>
+
+#include "arrow/dataset/partition.h"
+#include "arrow/dataset/test_util.h"
+#include "arrow/filesystem/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/checked_cast.h"
+
+using testing::SizeIs;
+
+namespace arrow {
+namespace dataset {
+
+void AssertSchemasAre(std::vector<std::shared_ptr<Schema>> actual,
+ std::vector<std::shared_ptr<Schema>> expected) {
+ EXPECT_EQ(actual.size(), expected.size());
+ for (size_t i = 0; i < actual.size(); i++) {
+ EXPECT_EQ(*actual[i], *expected[i]);
+ }
+}
+
+class DatasetFactoryTest : public TestFileSystemDataset {
+ public:
+ void AssertInspect(std::shared_ptr<Schema> expected, InspectOptions options = {}) {
+ ASSERT_OK_AND_ASSIGN(auto actual, factory_->Inspect(options));
+ EXPECT_EQ(*actual, *expected);
+ }
+
+ void AssertInspectSchemas(std::vector<std::shared_ptr<Schema>> expected,
+ InspectOptions options = {}) {
+ ASSERT_OK_AND_ASSIGN(auto actual, factory_->InspectSchemas(options));
+ AssertSchemasAre(actual, expected);
+ }
+
+ protected:
+ std::shared_ptr<DatasetFactory> factory_;
+};
+
+class MockDatasetFactory : public DatasetFactory {
+ public:
+ explicit MockDatasetFactory(std::vector<std::shared_ptr<Schema>> schemas)
+ : schemas_(std::move(schemas)) {}
+
+ Result<std::vector<std::shared_ptr<Schema>>> InspectSchemas(
+ InspectOptions options) override {
+ return schemas_;
+ }
+
+ Result<std::shared_ptr<Dataset>> Finish(FinishOptions options) override {
+ return std::make_shared<InMemoryDataset>(options.schema,
+ std::vector<std::shared_ptr<RecordBatch>>{});
+ }
+
+ protected:
+ std::vector<std::shared_ptr<Schema>> schemas_;
+};
+
+class MockDatasetFactoryTest : public DatasetFactoryTest {
+ public:
+ void MakeFactory(std::vector<std::shared_ptr<Schema>> schemas) {
+ factory_ = std::make_shared<MockDatasetFactory>(schemas);
+ }
+
+ protected:
+ std::shared_ptr<Field> i32 = field("i32", int32());
+ std::shared_ptr<Field> i64 = field("i64", int64());
+ std::shared_ptr<Field> f32 = field("f32", float64());
+ std::shared_ptr<Field> f64 = field("f64", float64());
+ // Non-nullable
+ std::shared_ptr<Field> i32_req = field("i32", int32(), false);
+ // bad type with name `i32`
+ std::shared_ptr<Field> i32_fake = field("i32", boolean());
+};
+
+TEST_F(MockDatasetFactoryTest, UnifySchemas) {
+ MakeFactory({});
+ AssertInspect(schema({}));
+
+ MakeFactory({schema({i32}), schema({i32})});
+ AssertInspect(schema({i32}));
+
+ MakeFactory({schema({i32}), schema({i64})});
+ AssertInspect(schema({i32, i64}));
+
+ MakeFactory({schema({i32}), schema({i64})});
+ AssertInspect(schema({i32, i64}));
+
+ MakeFactory({schema({i32}), schema({i32_req})});
+ AssertInspect(schema({i32}));
+
+ MakeFactory({schema({i32, f64}), schema({i32_req, i64})});
+ AssertInspect(schema({i32, f64, i64}));
+
+ MakeFactory({schema({i32, f64}), schema({f64, i32_fake})});
+ // Unification fails when fields with the same name have clashing types.
+ ASSERT_RAISES(Invalid, factory_->Inspect());
+ // Return the individual schema for closer inspection should not fail.
+ AssertInspectSchemas({schema({i32, f64}), schema({f64, i32_fake})});
+}
+
+class FileSystemDatasetFactoryTest : public DatasetFactoryTest {
+ public:
+ void MakeFactory(const std::vector<fs::FileInfo>& files) {
+ MakeFileSystem(files);
+ ASSERT_OK_AND_ASSIGN(factory_, FileSystemDatasetFactory::Make(fs_, selector_, format_,
+ factory_options_));
+ }
+
+ void AssertFinishWithPaths(std::vector<std::string> paths,
+ std::shared_ptr<Schema> schema = nullptr,
+ InspectOptions options = {}) {
+ if (schema == nullptr) {
+ ASSERT_OK_AND_ASSIGN(schema, factory_->Inspect(options));
+ }
+ options_ = std::make_shared<ScanOptions>();
+ options_->dataset_schema = schema;
+ ASSERT_OK(SetProjection(options_.get(), schema->field_names()));
+ ASSERT_OK_AND_ASSIGN(dataset_, factory_->Finish(schema));
+ ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset_->GetFragments());
+ AssertFragmentsAreFromPath(std::move(fragment_it), paths);
+ }
+
+ protected:
+ fs::FileSelector selector_;
+ FileSystemFactoryOptions factory_options_;
+ std::shared_ptr<FileFormat> format_ = std::make_shared<DummyFileFormat>(schema({}));
+};
+
+TEST_F(FileSystemDatasetFactoryTest, Basic) {
+ MakeFactory({fs::File("a"), fs::File("b")});
+ AssertFinishWithPaths({"a", "b"});
+ MakeFactory({fs::Dir("a"), fs::Dir("a/b"), fs::File("a/b/c")});
+}
+
+TEST_F(FileSystemDatasetFactoryTest, Selector) {
+ selector_.base_dir = "A";
+ selector_.recursive = true;
+
+ MakeFactory({fs::File("0"), fs::File("A/a"), fs::File("A/A/a")});
+ // "0" doesn't match selector, so it has been dropped:
+ AssertFinishWithPaths({"A/a", "A/A/a"});
+
+ factory_options_.partition_base_dir = "A/A";
+ MakeFactory({fs::File("0"), fs::File("A/a"), fs::File("A/A/a")});
+ // partition_base_dir should not affect filtered files, only the applied partition
+ AssertInspect(schema({}));
+ AssertFinishWithPaths({"A/a", "A/A/a"});
+}
+
+TEST_F(FileSystemDatasetFactoryTest, ExplicitPartition) {
+ selector_.base_dir = "a=ignored/base";
+ auto part_field = field("a", int32());
+ factory_options_.partitioning =
+ std::make_shared<HivePartitioning>(schema({part_field}));
+
+ auto a_1 = "a=ignored/base/a=1";
+ MakeFactory({fs::File(a_1)});
+
+ InspectOptions options;
+ // Should inspect the partition's Schema even if no files are inspected.
+ options.fragments = 0;
+ AssertInspect(schema({part_field}), options);
+ AssertFinishWithPaths({a_1}, nullptr, options);
+}
+
+TEST_F(FileSystemDatasetFactoryTest, DiscoveredPartition) {
+ selector_.base_dir = "a=ignored/base";
+ selector_.recursive = true;
+ factory_options_.partitioning = HivePartitioning::MakeFactory();
+
+ auto a_1 = "a=ignored/base/a=1/file.data";
+ MakeFactory({fs::File(a_1)});
+
+ InspectOptions options;
+
+ auto schema_with = schema({field("a", int32())});
+ AssertInspect(schema_with, options);
+ AssertFinishWithPaths({a_1}, schema_with);
+}
+
+TEST_F(FileSystemDatasetFactoryTest, MissingDirectories) {
+ auto partition_path = "base_dir/a=3/b=3/dat";
+ auto unpartition_path = "unpartitioned/ignored=3";
+ MakeFileSystem({fs::File(partition_path), fs::File(unpartition_path)});
+
+ factory_options_.partition_base_dir = "base_dir";
+ factory_options_.partitioning = std::make_shared<HivePartitioning>(
+ schema({field("a", int32()), field("b", int32())}));
+
+ auto paths = std::vector<std::string>{partition_path, unpartition_path};
+
+ ASSERT_OK_AND_ASSIGN(
+ factory_, FileSystemDatasetFactory::Make(fs_, paths, format_, factory_options_));
+
+ InspectOptions options;
+ AssertInspect(schema({field("a", int32()), field("b", int32())}), options);
+ AssertFinishWithPaths({partition_path, unpartition_path});
+}
+
+TEST_F(FileSystemDatasetFactoryTest, OptionsIgnoredDefaultPrefixes) {
+ // When constructing a factory from a FileSelector,
+ // `selector_ignore_prefixes` governs which files are filtered out.
+ selector_.recursive = true;
+ MakeFactory({
+ fs::File("."),
+ fs::File("_"),
+ fs::File("_$folder$/dat"),
+ fs::File("_SUCCESS"),
+ fs::File("not_ignored_by_default"),
+ fs::File("not_ignored_by_default_either/dat"),
+ });
+
+ AssertFinishWithPaths({"not_ignored_by_default", "not_ignored_by_default_either/dat"});
+}
+
+TEST_F(FileSystemDatasetFactoryTest, OptionsIgnoredDefaultExplicitFiles) {
+ // When constructing a factory from an explicit list of paths,
+ // `selector_ignore_prefixes` is ignored.
+ selector_.recursive = true;
+ std::vector<fs::FileInfo> ignored_by_default = {
+ fs::File(".ignored_by_default.parquet"),
+ fs::File("_ignored_by_default.csv"),
+ fs::File("_$folder$/ignored_by_default.arrow"),
+ };
+ MakeFileSystem(ignored_by_default);
+
+ std::vector<std::string> paths;
+ for (const auto& info : ignored_by_default) paths.push_back(info.path());
+ ASSERT_OK_AND_ASSIGN(
+ factory_, FileSystemDatasetFactory::Make(fs_, paths, format_, factory_options_));
+
+ AssertFinishWithPaths(paths);
+}
+
+TEST_F(FileSystemDatasetFactoryTest, OptionsIgnoredCustomPrefixes) {
+ selector_.recursive = true;
+ factory_options_.selector_ignore_prefixes = {"not_ignored"};
+ MakeFactory({
+ fs::File("."),
+ fs::File("_"),
+ fs::File("_$folder$/dat"),
+ fs::File("_SUCCESS"),
+ fs::File("not_ignored_by_default"),
+ fs::File("not_ignored_by_default_either/dat"),
+ });
+
+ AssertFinishWithPaths({".", "_", "_$folder$/dat", "_SUCCESS"});
+}
+
+TEST_F(FileSystemDatasetFactoryTest, OptionsIgnoredNoPrefixes) {
+ // Ignore nothing
+ selector_.recursive = true;
+ factory_options_.selector_ignore_prefixes = {};
+ MakeFactory({
+ fs::File("."),
+ fs::File("_"),
+ fs::File("_$folder$/dat"),
+ fs::File("_SUCCESS"),
+ fs::File("not_ignored_by_default"),
+ fs::File("not_ignored_by_default_either/dat"),
+ });
+
+ AssertFinishWithPaths({".", "_", "_$folder$/dat", "_SUCCESS", "not_ignored_by_default",
+ "not_ignored_by_default_either/dat"});
+}
+
+TEST_F(FileSystemDatasetFactoryTest, OptionsIgnoredPrefixesWithBaseDirectory) {
+ // ARROW-9644: the selector base_dir shouldn't be filtered out even if matches
+ // `selector_ignore_prefixes`.
+ std::string dir = "_shouldnt_be_ignored/.dataset/";
+ selector_.base_dir = dir;
+ selector_.recursive = true;
+ MakeFactory({
+ fs::File(dir + "."),
+ fs::File(dir + "_"),
+ fs::File(dir + "_$folder$/dat"),
+ fs::File(dir + "_SUCCESS"),
+ fs::File(dir + "not_ignored_by_default"),
+ fs::File(dir + "not_ignored_by_default_either/dat"),
+ });
+
+ AssertFinishWithPaths(
+ {dir + "not_ignored_by_default", dir + "not_ignored_by_default_either/dat"});
+}
+
+TEST_F(FileSystemDatasetFactoryTest, Inspect) {
+ auto s = schema({field("f64", float64())});
+ format_ = std::make_shared<DummyFileFormat>(s);
+
+ // No files
+ MakeFactory({});
+ AssertInspect(schema({}));
+
+ MakeFactory({fs::File("test")});
+ AssertInspect(s);
+}
+
+TEST_F(FileSystemDatasetFactoryTest, FinishWithIncompatibleSchemaShouldFail) {
+ auto s = schema({field("f64", float64())});
+ format_ = std::make_shared<DummyFileFormat>(s);
+
+ auto broken_s = schema({field("f64", utf8())});
+
+ FinishOptions options;
+ options.schema = broken_s;
+ options.validate_fragments = true;
+
+ // No files and validation
+ MakeFactory({});
+ ASSERT_OK_AND_ASSIGN(auto dataset, factory_->Finish(options));
+
+ MakeFactory({fs::File("test")});
+ ASSERT_RAISES(Invalid, factory_->Finish(options));
+
+ // Disable validation
+ options.validate_fragments = false;
+ ASSERT_OK_AND_ASSIGN(dataset, factory_->Finish(options));
+}
+
+TEST_F(FileSystemDatasetFactoryTest, InspectFragmentsLimit) {
+ MakeFactory({fs::File("a"), fs::File("b"), fs::File("c")});
+
+ InspectOptions options;
+ // By default, inspect one fragment and the partitioning.
+ ASSERT_OK_AND_ASSIGN(auto schemas, factory_->InspectSchemas(options));
+ EXPECT_THAT(schemas, SizeIs(2));
+
+ for (int fragments = 0; fragments < 3; fragments++) {
+ options.fragments = fragments;
+ ASSERT_OK_AND_ASSIGN(auto schemas, factory_->InspectSchemas(options));
+ EXPECT_THAT(schemas, SizeIs(fragments + 1));
+ }
+}
+
+TEST_F(FileSystemDatasetFactoryTest, FilenameNotPartOfPartitions) {
+ // ARROW-8726: Ensure filename is not a partition.
+
+ // Creates a partition with 2 explicit fields. The type `int32` is
+ // specifically chosen such that parsing would fail given a non-integer
+ // string.
+ auto s = schema({field("first", utf8()), field("second", int32())});
+ factory_options_.partitioning = std::make_shared<DirectoryPartitioning>(s);
+
+ selector_.recursive = true;
+ // The file doesn't have a directory component for the second partition
+ // column. In such case, the filename should not be used.
+ MakeFactory({fs::File("one/file.parquet")});
+
+ auto expected = equal(field_ref("first"), literal("one"));
+
+ ASSERT_OK_AND_ASSIGN(auto dataset, factory_->Finish());
+ ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments());
+ for (const auto& maybe_fragment : fragment_it) {
+ ASSERT_OK_AND_ASSIGN(auto fragment, maybe_fragment);
+ EXPECT_EQ(fragment->partition_expression(), expected);
+ }
+}
+
+TEST_F(FileSystemDatasetFactoryTest, UnparseablePartitionExpression) {
+ auto s = schema({field("first", int32()), field("second", int32())});
+ factory_options_.partitioning = std::make_shared<HivePartitioning>(s);
+ selector_.recursive = true;
+
+ for (auto pathlist : {"first=one/file.parquet", "second=one/file.parquet",
+ R"(first=1/second=0/file.parquet
+ first=1/second=zero/file.parquet)"}) {
+ MakeFactory(ParsePathList(pathlist));
+ ASSERT_RAISES(Invalid, factory_->Finish().status());
+ }
+
+ for (auto pathlist : {
+ R"(first=1/file.parquet
+ second=0/file.parquet)",
+ R"(first=1/second=2/file.parquet
+ second=0/file.parquet)",
+ R"(first=1/file.parquet
+ second=0/first=1/file.parquet)",
+ }) {
+ MakeFactory(ParsePathList(pathlist));
+ ASSERT_OK(factory_->Finish().status());
+ }
+}
+
+std::shared_ptr<DatasetFactory> DatasetFactoryFromSchemas(
+ std::vector<std::shared_ptr<Schema>> schemas) {
+ return std::make_shared<MockDatasetFactory>(schemas);
+}
+
+TEST(UnionDatasetFactoryTest, Basic) {
+ auto f64 = field("f64", float64());
+ auto i32 = field("i32", int32());
+ auto i32_req = field("i32", int32(), /*nullable*/ false);
+ auto str = field("str", utf8());
+
+ auto schema_1 = schema({f64, i32_req});
+ auto schema_2 = schema({f64, i32});
+ auto schema_3 = schema({str, i32});
+
+ auto dataset_1 = DatasetFactoryFromSchemas({schema_1, schema_2});
+ auto dataset_2 = DatasetFactoryFromSchemas({schema_2});
+ auto dataset_3 = DatasetFactoryFromSchemas({schema_3});
+
+ ASSERT_OK_AND_ASSIGN(auto factory,
+ UnionDatasetFactory::Make({dataset_1, dataset_2, dataset_3}));
+
+ ASSERT_OK_AND_ASSIGN(auto schemas, factory->InspectSchemas({}));
+ AssertSchemasAre(schemas, {schema_2, schema_2, schema_3});
+
+ auto expected_schema = schema({f64, i32, str});
+ ASSERT_OK_AND_ASSIGN(auto inspected, factory->Inspect());
+ EXPECT_EQ(*inspected, *expected_schema);
+
+ ASSERT_OK_AND_ASSIGN(auto dataset, factory->Finish());
+ EXPECT_EQ(*dataset->schema(), *expected_schema);
+
+ auto f64_schema = schema({f64});
+ ASSERT_OK_AND_ASSIGN(dataset, factory->Finish(f64_schema));
+ EXPECT_EQ(*dataset->schema(), *f64_schema);
+}
+
+TEST(UnionDatasetFactoryTest, ConflictingSchemas) {
+ auto f64 = field("f64", float64());
+ auto i32 = field("i32", int32());
+ auto i32_req = field("i32", int32(), /*nullable*/ false);
+ auto bad_f64 = field("f64", float32());
+
+ auto schema_1 = schema({f64, i32_req});
+ auto schema_2 = schema({f64, i32});
+ // Incompatible with schema_1
+ auto schema_3 = schema({bad_f64, i32});
+
+ auto dataset_factory_1 = DatasetFactoryFromSchemas({schema_1, schema_2});
+ auto dataset_factory_2 = DatasetFactoryFromSchemas({schema_2});
+ auto dataset_factory_3 = DatasetFactoryFromSchemas({schema_3});
+
+ ASSERT_OK_AND_ASSIGN(auto factory,
+ UnionDatasetFactory::Make(
+ {dataset_factory_1, dataset_factory_2, dataset_factory_3}));
+
+ // schema_3 conflicts with other, Inspect/Finish should not work
+ ASSERT_RAISES(Invalid, factory->Inspect());
+ ASSERT_RAISES(Invalid, factory->Finish());
+
+ // The user can inspect without error
+ ASSERT_OK_AND_ASSIGN(auto schemas, factory->InspectSchemas({}));
+ AssertSchemasAre(schemas, {schema_2, schema_2, schema_3});
+
+ // The user decided to ignore the conflicting `f64` field.
+ auto i32_schema = schema({i32});
+ ASSERT_OK_AND_ASSIGN(auto dataset, factory->Finish(i32_schema));
+ EXPECT_EQ(*dataset->schema(), *i32_schema);
+}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_base.cc b/src/arrow/cpp/src/arrow/dataset/file_base.cc
new file mode 100644
index 000000000..4ff3c6d2b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_base.cc
@@ -0,0 +1,466 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/file_base.h"
+
+#include <arrow/compute/exec/exec_plan.h>
+
+#include <algorithm>
+#include <unordered_map>
+#include <vector>
+
+#include "arrow/compute/exec/forest_internal.h"
+#include "arrow/compute/exec/subtree_internal.h"
+#include "arrow/dataset/dataset_internal.h"
+#include "arrow/dataset/dataset_writer.h"
+#include "arrow/dataset/scanner.h"
+#include "arrow/dataset/scanner_internal.h"
+#include "arrow/filesystem/filesystem.h"
+#include "arrow/filesystem/path_util.h"
+#include "arrow/io/compressed.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/io/memory.h"
+#include "arrow/util/compression.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/map.h"
+#include "arrow/util/string.h"
+#include "arrow/util/task_group.h"
+#include "arrow/util/variant.h"
+
+namespace arrow {
+
+using internal::checked_pointer_cast;
+
+namespace dataset {
+
+Result<std::shared_ptr<io::RandomAccessFile>> FileSource::Open() const {
+ if (filesystem_) {
+ return filesystem_->OpenInputFile(file_info_);
+ }
+
+ if (buffer_) {
+ return std::make_shared<io::BufferReader>(buffer_);
+ }
+
+ return custom_open_();
+}
+
+Result<std::shared_ptr<io::InputStream>> FileSource::OpenCompressed(
+ util::optional<Compression::type> compression) const {
+ ARROW_ASSIGN_OR_RAISE(auto file, Open());
+ auto actual_compression = Compression::type::UNCOMPRESSED;
+ if (!compression.has_value()) {
+ // Guess compression from file extension
+ auto extension = fs::internal::GetAbstractPathExtension(path());
+ if (extension == "gz") {
+ actual_compression = Compression::type::GZIP;
+ } else {
+ auto maybe_compression = util::Codec::GetCompressionType(extension);
+ if (maybe_compression.ok()) {
+ ARROW_ASSIGN_OR_RAISE(actual_compression, maybe_compression);
+ }
+ }
+ } else {
+ actual_compression = compression.value();
+ }
+ if (actual_compression == Compression::type::UNCOMPRESSED) {
+ return file;
+ }
+ ARROW_ASSIGN_OR_RAISE(auto codec, util::Codec::Create(actual_compression));
+ return io::CompressedInputStream::Make(codec.get(), std::move(file));
+}
+
+Future<util::optional<int64_t>> FileFormat::CountRows(
+ const std::shared_ptr<FileFragment>&, compute::Expression,
+ const std::shared_ptr<ScanOptions>&) {
+ return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
+}
+
+Result<std::shared_ptr<FileFragment>> FileFormat::MakeFragment(
+ FileSource source, std::shared_ptr<Schema> physical_schema) {
+ return MakeFragment(std::move(source), compute::literal(true),
+ std::move(physical_schema));
+}
+
+Result<std::shared_ptr<FileFragment>> FileFormat::MakeFragment(
+ FileSource source, compute::Expression partition_expression) {
+ return MakeFragment(std::move(source), std::move(partition_expression), nullptr);
+}
+
+Result<std::shared_ptr<FileFragment>> FileFormat::MakeFragment(
+ FileSource source, compute::Expression partition_expression,
+ std::shared_ptr<Schema> physical_schema) {
+ return std::shared_ptr<FileFragment>(
+ new FileFragment(std::move(source), shared_from_this(),
+ std::move(partition_expression), std::move(physical_schema)));
+}
+
+// The following implementation of ScanBatchesAsync is both ugly and terribly inefficient.
+// Each of the formats should provide their own efficient implementation. However, this
+// is a reasonable starting point or implementation for a dummy/mock format.
+Result<RecordBatchGenerator> FileFormat::ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& scan_options,
+ const std::shared_ptr<FileFragment>& file) const {
+ ARROW_ASSIGN_OR_RAISE(auto scan_task_it, ScanFile(scan_options, file));
+ struct State {
+ State(std::shared_ptr<ScanOptions> scan_options, ScanTaskIterator scan_task_it)
+ : scan_options(std::move(scan_options)),
+ scan_task_it(std::move(scan_task_it)),
+ current_rb_it(),
+ finished(false) {}
+
+ std::shared_ptr<ScanOptions> scan_options;
+ ScanTaskIterator scan_task_it;
+ RecordBatchIterator current_rb_it;
+ bool finished;
+ };
+ struct Generator {
+ Future<std::shared_ptr<RecordBatch>> operator()() {
+ while (!state->finished) {
+ if (!state->current_rb_it) {
+ RETURN_NOT_OK(PumpScanTask());
+ if (state->finished) {
+ return AsyncGeneratorEnd<std::shared_ptr<RecordBatch>>();
+ }
+ }
+ ARROW_ASSIGN_OR_RAISE(auto next_batch, state->current_rb_it.Next());
+ if (IsIterationEnd(next_batch)) {
+ state->current_rb_it = RecordBatchIterator();
+ } else {
+ return Future<std::shared_ptr<RecordBatch>>::MakeFinished(next_batch);
+ }
+ }
+ return AsyncGeneratorEnd<std::shared_ptr<RecordBatch>>();
+ }
+ Status PumpScanTask() {
+ ARROW_ASSIGN_OR_RAISE(auto next_task, state->scan_task_it.Next());
+ if (IsIterationEnd(next_task)) {
+ state->finished = true;
+ } else {
+ ARROW_ASSIGN_OR_RAISE(state->current_rb_it, next_task->Execute());
+ }
+ return Status::OK();
+ }
+ std::shared_ptr<State> state;
+ };
+ return Generator{std::make_shared<State>(scan_options, std::move(scan_task_it))};
+}
+
+Result<std::shared_ptr<Schema>> FileFragment::ReadPhysicalSchemaImpl() {
+ return format_->Inspect(source_);
+}
+
+Result<ScanTaskIterator> FileFragment::Scan(std::shared_ptr<ScanOptions> options) {
+ auto self = std::dynamic_pointer_cast<FileFragment>(shared_from_this());
+ return format_->ScanFile(options, self);
+}
+
+Result<RecordBatchGenerator> FileFragment::ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options) {
+ auto self = std::dynamic_pointer_cast<FileFragment>(shared_from_this());
+ return format_->ScanBatchesAsync(options, self);
+}
+
+Future<util::optional<int64_t>> FileFragment::CountRows(
+ compute::Expression predicate, const std::shared_ptr<ScanOptions>& options) {
+ ARROW_ASSIGN_OR_RAISE(predicate, compute::SimplifyWithGuarantee(std::move(predicate),
+ partition_expression_));
+ if (!predicate.IsSatisfiable()) {
+ return Future<util::optional<int64_t>>::MakeFinished(0);
+ }
+ auto self = checked_pointer_cast<FileFragment>(shared_from_this());
+ return format()->CountRows(self, std::move(predicate), options);
+}
+
+struct FileSystemDataset::FragmentSubtrees {
+ // Forest for skipping fragments based on extracted subtree expressions
+ compute::Forest forest;
+ // fragment indices and subtree expressions in forest order
+ std::vector<util::Variant<int, compute::Expression>> fragments_and_subtrees;
+};
+
+Result<std::shared_ptr<FileSystemDataset>> FileSystemDataset::Make(
+ std::shared_ptr<Schema> schema, compute::Expression root_partition,
+ std::shared_ptr<FileFormat> format, std::shared_ptr<fs::FileSystem> filesystem,
+ std::vector<std::shared_ptr<FileFragment>> fragments,
+ std::shared_ptr<Partitioning> partitioning) {
+ std::shared_ptr<FileSystemDataset> out(
+ new FileSystemDataset(std::move(schema), std::move(root_partition)));
+ out->format_ = std::move(format);
+ out->filesystem_ = std::move(filesystem);
+ out->fragments_ = std::move(fragments);
+ out->partitioning_ = std::move(partitioning);
+ out->SetupSubtreePruning();
+ return out;
+}
+
+Result<std::shared_ptr<Dataset>> FileSystemDataset::ReplaceSchema(
+ std::shared_ptr<Schema> schema) const {
+ RETURN_NOT_OK(CheckProjectable(*schema_, *schema));
+ return Make(std::move(schema), partition_expression_, format_, filesystem_, fragments_);
+}
+
+std::vector<std::string> FileSystemDataset::files() const {
+ std::vector<std::string> files;
+
+ for (const auto& fragment : fragments_) {
+ files.push_back(fragment->source().path());
+ }
+
+ return files;
+}
+
+std::string FileSystemDataset::ToString() const {
+ std::string repr = "FileSystemDataset:";
+
+ if (fragments_.empty()) {
+ return repr + " []";
+ }
+
+ for (const auto& fragment : fragments_) {
+ repr += "\n" + fragment->source().path();
+
+ const auto& partition = fragment->partition_expression();
+ if (partition != compute::literal(true)) {
+ repr += ": " + partition.ToString();
+ }
+ }
+
+ return repr;
+}
+
+void FileSystemDataset::SetupSubtreePruning() {
+ subtrees_ = std::make_shared<FragmentSubtrees>();
+ compute::SubtreeImpl impl;
+
+ auto encoded = impl.EncodeGuarantees(
+ [&](int index) { return fragments_[index]->partition_expression(); },
+ static_cast<int>(fragments_.size()));
+
+ std::sort(encoded.begin(), encoded.end(), compute::SubtreeImpl::ByGuarantee());
+
+ for (const auto& e : encoded) {
+ if (e.index) {
+ subtrees_->fragments_and_subtrees.emplace_back(*e.index);
+ } else {
+ subtrees_->fragments_and_subtrees.emplace_back(impl.GetSubtreeExpression(e));
+ }
+ }
+
+ subtrees_->forest = compute::Forest(static_cast<int>(encoded.size()),
+ compute::SubtreeImpl::IsAncestor{encoded});
+}
+
+Result<FragmentIterator> FileSystemDataset::GetFragmentsImpl(
+ compute::Expression predicate) {
+ if (predicate == compute::literal(true)) {
+ // trivial predicate; skip subtree pruning
+ return MakeVectorIterator(FragmentVector(fragments_.begin(), fragments_.end()));
+ }
+
+ std::vector<int> fragment_indices;
+
+ std::vector<compute::Expression> predicates{predicate};
+ RETURN_NOT_OK(subtrees_->forest.Visit(
+ [&](compute::Forest::Ref ref) -> Result<bool> {
+ if (auto fragment_index =
+ util::get_if<int>(&subtrees_->fragments_and_subtrees[ref.i])) {
+ fragment_indices.push_back(*fragment_index);
+ return false;
+ }
+
+ const auto& subtree_expr =
+ util::get<compute::Expression>(subtrees_->fragments_and_subtrees[ref.i]);
+ ARROW_ASSIGN_OR_RAISE(auto simplified,
+ SimplifyWithGuarantee(predicates.back(), subtree_expr));
+
+ if (!simplified.IsSatisfiable()) {
+ return false;
+ }
+
+ predicates.push_back(std::move(simplified));
+ return true;
+ },
+ [&](compute::Forest::Ref ref) { predicates.pop_back(); }));
+
+ std::sort(fragment_indices.begin(), fragment_indices.end());
+
+ FragmentVector fragments(fragment_indices.size());
+ std::transform(fragment_indices.begin(), fragment_indices.end(), fragments.begin(),
+ [this](int i) { return fragments_[i]; });
+
+ return MakeVectorIterator(std::move(fragments));
+}
+
+Status FileWriter::Write(RecordBatchReader* batches) {
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(auto batch, batches->Next());
+ if (batch == nullptr) break;
+ RETURN_NOT_OK(Write(batch));
+ }
+ return Status::OK();
+}
+
+Status FileWriter::Finish() {
+ RETURN_NOT_OK(FinishInternal());
+ return destination_->Close();
+}
+
+namespace {
+
+class DatasetWritingSinkNodeConsumer : public compute::SinkNodeConsumer {
+ public:
+ DatasetWritingSinkNodeConsumer(std::shared_ptr<Schema> schema,
+ std::unique_ptr<internal::DatasetWriter> dataset_writer,
+ FileSystemDatasetWriteOptions write_options,
+ std::shared_ptr<util::AsyncToggle> backpressure_toggle)
+ : schema_(std::move(schema)),
+ dataset_writer_(std::move(dataset_writer)),
+ write_options_(std::move(write_options)),
+ backpressure_toggle_(std::move(backpressure_toggle)) {}
+
+ Status Consume(compute::ExecBatch batch) {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> record_batch,
+ batch.ToRecordBatch(schema_));
+ return WriteNextBatch(std::move(record_batch), batch.guarantee);
+ }
+
+ Future<> Finish() {
+ RETURN_NOT_OK(task_group_.AddTask([this] { return dataset_writer_->Finish(); }));
+ return task_group_.End();
+ }
+
+ private:
+ Status WriteNextBatch(std::shared_ptr<RecordBatch> batch,
+ compute::Expression guarantee) {
+ ARROW_ASSIGN_OR_RAISE(auto groups, write_options_.partitioning->Partition(batch));
+ batch.reset(); // drop to hopefully conserve memory
+
+ if (groups.batches.size() > static_cast<size_t>(write_options_.max_partitions)) {
+ return Status::Invalid("Fragment would be written into ", groups.batches.size(),
+ " partitions. This exceeds the maximum of ",
+ write_options_.max_partitions);
+ }
+
+ for (std::size_t index = 0; index < groups.batches.size(); index++) {
+ auto partition_expression = and_(groups.expressions[index], guarantee);
+ auto next_batch = groups.batches[index];
+ ARROW_ASSIGN_OR_RAISE(std::string destination,
+ write_options_.partitioning->Format(partition_expression));
+ RETURN_NOT_OK(task_group_.AddTask([this, next_batch, destination] {
+ Future<> has_room = dataset_writer_->WriteRecordBatch(next_batch, destination);
+ if (!has_room.is_finished() && backpressure_toggle_) {
+ backpressure_toggle_->Close();
+ return has_room.Then([this] { backpressure_toggle_->Open(); });
+ }
+ return has_room;
+ }));
+ }
+ return Status::OK();
+ }
+
+ std::shared_ptr<Schema> schema_;
+ std::unique_ptr<internal::DatasetWriter> dataset_writer_;
+ FileSystemDatasetWriteOptions write_options_;
+ std::shared_ptr<util::AsyncToggle> backpressure_toggle_;
+ util::SerializedAsyncTaskGroup task_group_;
+};
+
+} // namespace
+
+Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_options,
+ std::shared_ptr<Scanner> scanner) {
+ if (!scanner->options()->use_async) {
+ return Status::Invalid(
+ "A dataset write operation was invoked on a scanner that was configured for "
+ "synchronous scanning. Dataset writing requires a scanner configured for "
+ "asynchronous scanning. Please recreate the scanner with the use_async or "
+ "UseAsync option set to true");
+ }
+ const io::IOContext& io_context = scanner->options()->io_context;
+ std::shared_ptr<compute::ExecContext> exec_context =
+ std::make_shared<compute::ExecContext>(io_context.pool(),
+ ::arrow::internal::GetCpuThreadPool());
+
+ ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(exec_context.get()));
+
+ auto exprs = scanner->options()->projection.call()->arguments;
+ auto names = checked_cast<const compute::MakeStructOptions*>(
+ scanner->options()->projection.call()->options.get())
+ ->field_names;
+ std::shared_ptr<Dataset> dataset = scanner->dataset();
+ std::shared_ptr<util::AsyncToggle> backpressure_toggle =
+ std::make_shared<util::AsyncToggle>();
+
+ RETURN_NOT_OK(
+ compute::Declaration::Sequence(
+ {
+ {"scan", ScanNodeOptions{dataset, scanner->options(), backpressure_toggle}},
+ {"filter", compute::FilterNodeOptions{scanner->options()->filter}},
+ {"project",
+ compute::ProjectNodeOptions{std::move(exprs), std::move(names)}},
+ {"write",
+ WriteNodeOptions{write_options, scanner->options()->projected_schema,
+ backpressure_toggle}},
+ })
+ .AddToPlan(plan.get()));
+
+ RETURN_NOT_OK(plan->StartProducing());
+ return plan->finished().status();
+}
+
+Result<compute::ExecNode*> MakeWriteNode(compute::ExecPlan* plan,
+ std::vector<compute::ExecNode*> inputs,
+ const compute::ExecNodeOptions& options) {
+ if (inputs.size() != 1) {
+ return Status::Invalid("Write SinkNode requires exactly 1 input, got ",
+ inputs.size());
+ }
+
+ const WriteNodeOptions write_node_options =
+ checked_cast<const WriteNodeOptions&>(options);
+ const FileSystemDatasetWriteOptions& write_options = write_node_options.write_options;
+ const std::shared_ptr<Schema>& schema = write_node_options.schema;
+ const std::shared_ptr<util::AsyncToggle>& backpressure_toggle =
+ write_node_options.backpressure_toggle;
+
+ ARROW_ASSIGN_OR_RAISE(auto dataset_writer,
+ internal::DatasetWriter::Make(write_options));
+
+ std::shared_ptr<DatasetWritingSinkNodeConsumer> consumer =
+ std::make_shared<DatasetWritingSinkNodeConsumer>(
+ schema, std::move(dataset_writer), write_options, backpressure_toggle);
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto node,
+ compute::MakeExecNode("consuming_sink", plan, std::move(inputs),
+ compute::ConsumingSinkNodeOptions{std::move(consumer)}));
+
+ return node;
+}
+
+namespace internal {
+void InitializeDatasetWriter(arrow::compute::ExecFactoryRegistry* registry) {
+ DCHECK_OK(registry->AddFactory("write", MakeWriteNode));
+}
+} // namespace internal
+
+} // namespace dataset
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_base.h b/src/arrow/cpp/src/arrow/dataset/file_base.h
new file mode 100644
index 000000000..911369181
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_base.h
@@ -0,0 +1,421 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/dataset/dataset.h"
+#include "arrow/dataset/partition.h"
+#include "arrow/dataset/scanner.h"
+#include "arrow/dataset/type_fwd.h"
+#include "arrow/dataset/visibility.h"
+#include "arrow/filesystem/filesystem.h"
+#include "arrow/io/file.h"
+#include "arrow/util/compression.h"
+
+namespace arrow {
+
+namespace dataset {
+
+/// \defgroup dataset-file-formats File formats for reading and writing datasets
+/// \defgroup dataset-filesystem File system datasets
+///
+/// @{
+
+/// \brief The path and filesystem where an actual file is located or a buffer which can
+/// be read like a file
+class ARROW_DS_EXPORT FileSource {
+ public:
+ FileSource(std::string path, std::shared_ptr<fs::FileSystem> filesystem,
+ Compression::type compression = Compression::UNCOMPRESSED)
+ : file_info_(std::move(path)),
+ filesystem_(std::move(filesystem)),
+ compression_(compression) {}
+
+ FileSource(fs::FileInfo info, std::shared_ptr<fs::FileSystem> filesystem,
+ Compression::type compression = Compression::UNCOMPRESSED)
+ : file_info_(std::move(info)),
+ filesystem_(std::move(filesystem)),
+ compression_(compression) {}
+
+ explicit FileSource(std::shared_ptr<Buffer> buffer,
+ Compression::type compression = Compression::UNCOMPRESSED)
+ : buffer_(std::move(buffer)), compression_(compression) {}
+
+ using CustomOpen = std::function<Result<std::shared_ptr<io::RandomAccessFile>>()>;
+ explicit FileSource(CustomOpen open) : custom_open_(std::move(open)) {}
+
+ using CustomOpenWithCompression =
+ std::function<Result<std::shared_ptr<io::RandomAccessFile>>(Compression::type)>;
+ explicit FileSource(CustomOpenWithCompression open_with_compression,
+ Compression::type compression = Compression::UNCOMPRESSED)
+ : custom_open_(std::bind(std::move(open_with_compression), compression)),
+ compression_(compression) {}
+
+ explicit FileSource(std::shared_ptr<io::RandomAccessFile> file,
+ Compression::type compression = Compression::UNCOMPRESSED)
+ : custom_open_([=] { return ToResult(file); }), compression_(compression) {}
+
+ FileSource() : custom_open_(CustomOpen{&InvalidOpen}) {}
+
+ static std::vector<FileSource> FromPaths(const std::shared_ptr<fs::FileSystem>& fs,
+ std::vector<std::string> paths) {
+ std::vector<FileSource> sources;
+ for (auto&& path : paths) {
+ sources.emplace_back(std::move(path), fs);
+ }
+ return sources;
+ }
+
+ /// \brief Return the type of raw compression on the file, if any.
+ Compression::type compression() const { return compression_; }
+
+ /// \brief Return the file path, if any. Only valid when file source wraps a path.
+ const std::string& path() const {
+ static std::string buffer_path = "<Buffer>";
+ static std::string custom_open_path = "<Buffer>";
+ return filesystem_ ? file_info_.path() : buffer_ ? buffer_path : custom_open_path;
+ }
+
+ /// \brief Return the filesystem, if any. Otherwise returns nullptr
+ const std::shared_ptr<fs::FileSystem>& filesystem() const { return filesystem_; }
+
+ /// \brief Return the buffer containing the file, if any. Otherwise returns nullptr
+ const std::shared_ptr<Buffer>& buffer() const { return buffer_; }
+
+ /// \brief Get a RandomAccessFile which views this file source
+ Result<std::shared_ptr<io::RandomAccessFile>> Open() const;
+
+ /// \brief Get an InputStream which views this file source (and decompresses if needed)
+ /// \param[in] compression If nullopt, guess the compression scheme from the
+ /// filename, else decompress with the given codec
+ Result<std::shared_ptr<io::InputStream>> OpenCompressed(
+ util::optional<Compression::type> compression = util::nullopt) const;
+
+ private:
+ static Result<std::shared_ptr<io::RandomAccessFile>> InvalidOpen() {
+ return Status::Invalid("Called Open() on an uninitialized FileSource");
+ }
+
+ fs::FileInfo file_info_;
+ std::shared_ptr<fs::FileSystem> filesystem_;
+ std::shared_ptr<Buffer> buffer_;
+ CustomOpen custom_open_;
+ Compression::type compression_ = Compression::UNCOMPRESSED;
+};
+
+/// \brief Base class for file format implementation
+class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this<FileFormat> {
+ public:
+ /// Options affecting how this format is scanned.
+ ///
+ /// The options here can be overridden at scan time.
+ std::shared_ptr<FragmentScanOptions> default_fragment_scan_options;
+
+ virtual ~FileFormat() = default;
+
+ /// \brief The name identifying the kind of file format
+ virtual std::string type_name() const = 0;
+
+ virtual bool Equals(const FileFormat& other) const = 0;
+
+ /// \brief Indicate if the FileSource is supported/readable by this format.
+ virtual Result<bool> IsSupported(const FileSource& source) const = 0;
+
+ /// \brief Return the schema of the file if possible.
+ virtual Result<std::shared_ptr<Schema>> Inspect(const FileSource& source) const = 0;
+
+ /// \brief Open a FileFragment for scanning.
+ /// May populate lazy properties of the FileFragment.
+ virtual Result<ScanTaskIterator> ScanFile(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& file) const = 0;
+
+ virtual Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& file) const;
+ virtual Future<util::optional<int64_t>> CountRows(
+ const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
+ const std::shared_ptr<ScanOptions>& options);
+
+ /// \brief Open a fragment
+ virtual Result<std::shared_ptr<FileFragment>> MakeFragment(
+ FileSource source, compute::Expression partition_expression,
+ std::shared_ptr<Schema> physical_schema);
+
+ /// \brief Create a FileFragment for a FileSource.
+ Result<std::shared_ptr<FileFragment>> MakeFragment(
+ FileSource source, compute::Expression partition_expression);
+
+ /// \brief Create a FileFragment for a FileSource.
+ Result<std::shared_ptr<FileFragment>> MakeFragment(
+ FileSource source, std::shared_ptr<Schema> physical_schema = NULLPTR);
+
+ /// \brief Create a writer for this format.
+ virtual Result<std::shared_ptr<FileWriter>> MakeWriter(
+ std::shared_ptr<io::OutputStream> destination, std::shared_ptr<Schema> schema,
+ std::shared_ptr<FileWriteOptions> options,
+ fs::FileLocator destination_locator) const = 0;
+
+ /// \brief Get default write options for this format.
+ virtual std::shared_ptr<FileWriteOptions> DefaultWriteOptions() = 0;
+};
+
+/// \brief A Fragment that is stored in a file with a known format
+class ARROW_DS_EXPORT FileFragment : public Fragment {
+ public:
+ Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override;
+ Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options) override;
+ Future<util::optional<int64_t>> CountRows(
+ compute::Expression predicate,
+ const std::shared_ptr<ScanOptions>& options) override;
+
+ std::string type_name() const override { return format_->type_name(); }
+ std::string ToString() const override { return source_.path(); };
+
+ const FileSource& source() const { return source_; }
+ const std::shared_ptr<FileFormat>& format() const { return format_; }
+
+ protected:
+ FileFragment(FileSource source, std::shared_ptr<FileFormat> format,
+ compute::Expression partition_expression,
+ std::shared_ptr<Schema> physical_schema)
+ : Fragment(std::move(partition_expression), std::move(physical_schema)),
+ source_(std::move(source)),
+ format_(std::move(format)) {}
+
+ Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override;
+
+ FileSource source_;
+ std::shared_ptr<FileFormat> format_;
+
+ friend class FileFormat;
+};
+
+/// \brief A Dataset of FileFragments.
+///
+/// A FileSystemDataset is composed of one or more FileFragment. The fragments
+/// are independent and don't need to share the same format and/or filesystem.
+class ARROW_DS_EXPORT FileSystemDataset : public Dataset {
+ public:
+ /// \brief Create a FileSystemDataset.
+ ///
+ /// \param[in] schema the schema of the dataset
+ /// \param[in] root_partition the partition expression of the dataset
+ /// \param[in] format the format of each FileFragment.
+ /// \param[in] filesystem the filesystem of each FileFragment, or nullptr if the
+ /// fragments wrap buffers.
+ /// \param[in] fragments list of fragments to create the dataset from.
+ /// \param[in] partitioning the Partitioning object in case the dataset is created
+ /// with a known partitioning (e.g. from a discovered partitioning
+ /// through a DatasetFactory), or nullptr if not known.
+ ///
+ /// Note that fragments wrapping files resident in differing filesystems are not
+ /// permitted; to work with multiple filesystems use a UnionDataset.
+ ///
+ /// \return A constructed dataset.
+ static Result<std::shared_ptr<FileSystemDataset>> Make(
+ std::shared_ptr<Schema> schema, compute::Expression root_partition,
+ std::shared_ptr<FileFormat> format, std::shared_ptr<fs::FileSystem> filesystem,
+ std::vector<std::shared_ptr<FileFragment>> fragments,
+ std::shared_ptr<Partitioning> partitioning = NULLPTR);
+
+ /// \brief Write a dataset.
+ static Status Write(const FileSystemDatasetWriteOptions& write_options,
+ std::shared_ptr<Scanner> scanner);
+
+ /// \brief Return the type name of the dataset.
+ std::string type_name() const override { return "filesystem"; }
+
+ /// \brief Replace the schema of the dataset.
+ Result<std::shared_ptr<Dataset>> ReplaceSchema(
+ std::shared_ptr<Schema> schema) const override;
+
+ /// \brief Return the path of files.
+ std::vector<std::string> files() const;
+
+ /// \brief Return the format.
+ const std::shared_ptr<FileFormat>& format() const { return format_; }
+
+ /// \brief Return the filesystem. May be nullptr if the fragments wrap buffers.
+ const std::shared_ptr<fs::FileSystem>& filesystem() const { return filesystem_; }
+
+ /// \brief Return the partitioning. May be nullptr if the dataset was not constructed
+ /// with a partitioning.
+ const std::shared_ptr<Partitioning>& partitioning() const { return partitioning_; }
+
+ std::string ToString() const;
+
+ protected:
+ struct FragmentSubtrees;
+
+ explicit FileSystemDataset(std::shared_ptr<Schema> schema)
+ : Dataset(std::move(schema)) {}
+
+ FileSystemDataset(std::shared_ptr<Schema> schema,
+ compute::Expression partition_expression)
+ : Dataset(std::move(schema), partition_expression) {}
+
+ Result<FragmentIterator> GetFragmentsImpl(compute::Expression predicate) override;
+
+ void SetupSubtreePruning();
+
+ std::shared_ptr<FileFormat> format_;
+ std::shared_ptr<fs::FileSystem> filesystem_;
+ std::vector<std::shared_ptr<FileFragment>> fragments_;
+ std::shared_ptr<Partitioning> partitioning_;
+
+ std::shared_ptr<FragmentSubtrees> subtrees_;
+};
+
+/// \brief Options for writing a file of this format.
+class ARROW_DS_EXPORT FileWriteOptions {
+ public:
+ virtual ~FileWriteOptions() = default;
+
+ const std::shared_ptr<FileFormat>& format() const { return format_; }
+
+ std::string type_name() const { return format_->type_name(); }
+
+ protected:
+ explicit FileWriteOptions(std::shared_ptr<FileFormat> format)
+ : format_(std::move(format)) {}
+
+ std::shared_ptr<FileFormat> format_;
+};
+
+/// \brief A writer for this format.
+class ARROW_DS_EXPORT FileWriter {
+ public:
+ virtual ~FileWriter() = default;
+
+ /// \brief Write the given batch.
+ virtual Status Write(const std::shared_ptr<RecordBatch>& batch) = 0;
+
+ /// \brief Write all batches from the reader.
+ Status Write(RecordBatchReader* batches);
+
+ /// \brief Indicate that writing is done.
+ virtual Status Finish();
+
+ const std::shared_ptr<FileFormat>& format() const { return options_->format(); }
+ const std::shared_ptr<Schema>& schema() const { return schema_; }
+ const std::shared_ptr<FileWriteOptions>& options() const { return options_; }
+ const fs::FileLocator& destination() const { return destination_locator_; }
+
+ protected:
+ FileWriter(std::shared_ptr<Schema> schema, std::shared_ptr<FileWriteOptions> options,
+ std::shared_ptr<io::OutputStream> destination,
+ fs::FileLocator destination_locator)
+ : schema_(std::move(schema)),
+ options_(std::move(options)),
+ destination_(std::move(destination)),
+ destination_locator_(std::move(destination_locator)) {}
+
+ virtual Status FinishInternal() = 0;
+
+ std::shared_ptr<Schema> schema_;
+ std::shared_ptr<FileWriteOptions> options_;
+ std::shared_ptr<io::OutputStream> destination_;
+ fs::FileLocator destination_locator_;
+};
+
+/// \brief Options for writing a dataset.
+struct ARROW_DS_EXPORT FileSystemDatasetWriteOptions {
+ /// Options for individual fragment writing.
+ std::shared_ptr<FileWriteOptions> file_write_options;
+
+ /// FileSystem into which a dataset will be written.
+ std::shared_ptr<fs::FileSystem> filesystem;
+
+ /// Root directory into which the dataset will be written.
+ std::string base_dir;
+
+ /// Partitioning used to generate fragment paths.
+ std::shared_ptr<Partitioning> partitioning;
+
+ /// Maximum number of partitions any batch may be written into, default is 1K.
+ int max_partitions = 1024;
+
+ /// Template string used to generate fragment basenames.
+ /// {i} will be replaced by an auto incremented integer.
+ std::string basename_template;
+
+ /// If greater than 0 then this will limit the maximum number of files that can be left
+ /// open. If an attempt is made to open too many files then the least recently used file
+ /// will be closed. If this setting is set too low you may end up fragmenting your data
+ /// into many small files.
+ uint32_t max_open_files = 1024;
+
+ /// If greater than 0 then this will limit how many rows are placed in any single file.
+ /// Otherwise there will be no limit and one file will be created in each output
+ /// directory unless files need to be closed to respect max_open_files
+ uint64_t max_rows_per_file = 0;
+
+ /// Controls what happens if an output directory already exists.
+ ExistingDataBehavior existing_data_behavior = ExistingDataBehavior::kError;
+
+ /// Callback to be invoked against all FileWriters before
+ /// they are finalized with FileWriter::Finish().
+ std::function<Status(FileWriter*)> writer_pre_finish = [](FileWriter*) {
+ return Status::OK();
+ };
+
+ /// Callback to be invoked against all FileWriters after they have
+ /// called FileWriter::Finish().
+ std::function<Status(FileWriter*)> writer_post_finish = [](FileWriter*) {
+ return Status::OK();
+ };
+
+ const std::shared_ptr<FileFormat>& format() const {
+ return file_write_options->format();
+ }
+};
+
+/// \brief Wraps FileSystemDatasetWriteOptions for consumption as compute::ExecNodeOptions
+class ARROW_DS_EXPORT WriteNodeOptions : public compute::ExecNodeOptions {
+ public:
+ explicit WriteNodeOptions(
+ FileSystemDatasetWriteOptions options, std::shared_ptr<Schema> schema,
+ std::shared_ptr<util::AsyncToggle> backpressure_toggle = NULLPTR)
+ : write_options(std::move(options)),
+ schema(std::move(schema)),
+ backpressure_toggle(std::move(backpressure_toggle)) {}
+
+ FileSystemDatasetWriteOptions write_options;
+ std::shared_ptr<Schema> schema;
+ std::shared_ptr<util::AsyncToggle> backpressure_toggle;
+};
+
+/// @}
+
+namespace internal {
+ARROW_DS_EXPORT void InitializeDatasetWriter(
+ arrow::compute::ExecFactoryRegistry* registry);
+}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_benchmark.cc b/src/arrow/cpp/src/arrow/dataset/file_benchmark.cc
new file mode 100644
index 000000000..5caea1851
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_benchmark.cc
@@ -0,0 +1,90 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/compute/exec/expression.h"
+#include "arrow/dataset/discovery.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/file_ipc.h"
+#include "arrow/dataset/partition.h"
+#include "arrow/filesystem/mockfs.h"
+#include "arrow/filesystem/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/iterator.h"
+
+namespace arrow {
+namespace dataset {
+
+static std::shared_ptr<Dataset> GetDataset() {
+ std::vector<fs::FileInfo> files;
+ std::vector<std::string> paths;
+ for (int a = 0; a < 100; a++) {
+ for (int b = 0; b < 100; b++) {
+ auto path = "a=" + std::to_string(a) + "/b=" + std::to_string(b) + "/data.feather";
+ files.push_back(fs::File(path));
+ paths.push_back(path);
+ }
+ }
+ EXPECT_OK_AND_ASSIGN(auto fs,
+ arrow::fs::internal::MockFileSystem::Make(fs::kNoTime, files));
+ auto format = std::make_shared<IpcFileFormat>();
+ FileSystemFactoryOptions options;
+ options.partitioning = HivePartitioning::MakeFactory();
+ EXPECT_OK_AND_ASSIGN(auto factory,
+ FileSystemDatasetFactory::Make(fs, paths, format, options));
+ FinishOptions finish_options;
+ finish_options.inspect_options.fragments = 0;
+ EXPECT_OK_AND_ASSIGN(auto dataset, factory->Finish(finish_options));
+ return dataset;
+}
+
+// A benchmark of filtering fragments in a dataset.
+static void GetAllFragments(benchmark::State& state) {
+ auto dataset = GetDataset();
+ for (auto _ : state) {
+ ASSERT_OK_AND_ASSIGN(auto fragments, dataset->GetFragments());
+ ABORT_NOT_OK(fragments.Visit([](std::shared_ptr<Fragment>) { return Status::OK(); }));
+ }
+}
+
+static void GetFilteredFragments(benchmark::State& state, compute::Expression filter) {
+ auto dataset = GetDataset();
+ ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*dataset->schema()));
+ for (auto _ : state) {
+ ASSERT_OK_AND_ASSIGN(auto fragments, dataset->GetFragments(filter));
+ ABORT_NOT_OK(fragments.Visit([](std::shared_ptr<Fragment>) { return Status::OK(); }));
+ }
+}
+
+using compute::field_ref;
+using compute::literal;
+
+BENCHMARK(GetAllFragments);
+// Drill down to a subtree.
+BENCHMARK_CAPTURE(GetFilteredFragments, single_dir, equal(field_ref("a"), literal(90)));
+// Drill down, but not to a subtree.
+BENCHMARK_CAPTURE(GetFilteredFragments, multi_dir, equal(field_ref("b"), literal(90)));
+// Drill down to a single file.
+BENCHMARK_CAPTURE(GetFilteredFragments, single_file,
+ and_(equal(field_ref("a"), literal(90)),
+ equal(field_ref("b"), literal(90))));
+// Apply a filter, but keep most of the files.
+BENCHMARK_CAPTURE(GetFilteredFragments, range, greater(field_ref("a"), literal(1)));
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_csv.cc b/src/arrow/cpp/src/arrow/dataset/file_csv.cc
new file mode 100644
index 000000000..3b6632907
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_csv.cc
@@ -0,0 +1,335 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/file_csv.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+
+#include "arrow/csv/options.h"
+#include "arrow/csv/parser.h"
+#include "arrow/csv/reader.h"
+#include "arrow/csv/writer.h"
+#include "arrow/dataset/dataset_internal.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/type_fwd.h"
+#include "arrow/dataset/visibility.h"
+#include "arrow/io/buffered.h"
+#include "arrow/io/compressed.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/result.h"
+#include "arrow/type.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+using internal::Executor;
+using internal::SerialExecutor;
+
+namespace dataset {
+
+using RecordBatchGenerator = std::function<Future<std::shared_ptr<RecordBatch>>()>;
+
+Result<std::unordered_set<std::string>> GetColumnNames(
+ const csv::ReadOptions& read_options, const csv::ParseOptions& parse_options,
+ util::string_view first_block, MemoryPool* pool) {
+ if (!read_options.column_names.empty()) {
+ std::unordered_set<std::string> column_names;
+ for (const auto& s : read_options.column_names) {
+ if (!column_names.emplace(s).second) {
+ return Status::Invalid("CSV file contained multiple columns named ", s);
+ }
+ }
+ return column_names;
+ }
+
+ uint32_t parsed_size = 0;
+ int32_t max_num_rows = read_options.skip_rows + 1;
+ csv::BlockParser parser(pool, parse_options, /*num_cols=*/-1, /*first_row=*/1,
+ max_num_rows);
+
+ RETURN_NOT_OK(parser.Parse(util::string_view{first_block}, &parsed_size));
+
+ if (parser.num_rows() != max_num_rows) {
+ return Status::Invalid("Could not read first ", max_num_rows,
+ " rows from CSV file, either file is truncated or"
+ " header is larger than block size");
+ }
+
+ if (parser.num_cols() == 0) {
+ return Status::Invalid("No columns in CSV file");
+ }
+
+ std::unordered_set<std::string> column_names;
+
+ RETURN_NOT_OK(
+ parser.VisitLastRow([&](const uint8_t* data, uint32_t size, bool quoted) -> Status {
+ util::string_view view{reinterpret_cast<const char*>(data), size};
+ if (column_names.emplace(std::string(view)).second) {
+ return Status::OK();
+ }
+ return Status::Invalid("CSV file contained multiple columns named ", view);
+ }));
+
+ return column_names;
+}
+
+static inline Result<csv::ConvertOptions> GetConvertOptions(
+ const CsvFileFormat& format, const ScanOptions* scan_options,
+ const util::string_view first_block) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto csv_scan_options,
+ GetFragmentScanOptions<CsvFragmentScanOptions>(
+ kCsvTypeName, scan_options, format.default_fragment_scan_options));
+ ARROW_ASSIGN_OR_RAISE(
+ auto column_names,
+ GetColumnNames(csv_scan_options->read_options, format.parse_options, first_block,
+ scan_options ? scan_options->pool : default_memory_pool()));
+
+ auto convert_options = csv_scan_options->convert_options;
+
+ if (!scan_options) return convert_options;
+
+ auto materialized = scan_options->MaterializedFields();
+ std::unordered_set<std::string> materialized_fields(materialized.begin(),
+ materialized.end());
+ for (auto field : scan_options->dataset_schema->fields()) {
+ if (materialized_fields.find(field->name()) == materialized_fields.end()) continue;
+ // Ignore virtual columns.
+ if (column_names.find(field->name()) == column_names.end()) continue;
+ // Only read the requested columns
+ convert_options.include_columns.push_back(field->name());
+ // Properly set conversion types
+ convert_options.column_types[field->name()] = field->type();
+ }
+ return convert_options;
+}
+
+static inline Result<csv::ReadOptions> GetReadOptions(
+ const CsvFileFormat& format, const std::shared_ptr<ScanOptions>& scan_options) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto csv_scan_options,
+ GetFragmentScanOptions<CsvFragmentScanOptions>(
+ kCsvTypeName, scan_options.get(), format.default_fragment_scan_options));
+ auto read_options = csv_scan_options->read_options;
+ // Multithreaded conversion of individual files would lead to excessive thread
+ // contention when ScanTasks are also executed in multiple threads, so we disable it
+ // here. Also, this is a no-op since the streaming CSV reader is currently serial
+ read_options.use_threads = false;
+ return read_options;
+}
+
+static inline Future<std::shared_ptr<csv::StreamingReader>> OpenReaderAsync(
+ const FileSource& source, const CsvFileFormat& format,
+ const std::shared_ptr<ScanOptions>& scan_options, Executor* cpu_executor) {
+ ARROW_ASSIGN_OR_RAISE(auto reader_options, GetReadOptions(format, scan_options));
+
+ ARROW_ASSIGN_OR_RAISE(auto input, source.OpenCompressed());
+ ARROW_ASSIGN_OR_RAISE(
+ input, io::BufferedInputStream::Create(reader_options.block_size,
+ default_memory_pool(), std::move(input)));
+
+ // Grab the first block and use it to determine the schema and create a reader. The
+ // input->Peek call blocks so we run the whole thing on the I/O thread pool.
+ auto reader_fut = DeferNotOk(input->io_context().executor()->Submit(
+ [=]() -> Future<std::shared_ptr<csv::StreamingReader>> {
+ ARROW_ASSIGN_OR_RAISE(auto first_block, input->Peek(reader_options.block_size));
+ const auto& parse_options = format.parse_options;
+ ARROW_ASSIGN_OR_RAISE(
+ auto convert_options,
+ GetConvertOptions(format, scan_options ? scan_options.get() : nullptr,
+ first_block));
+ return csv::StreamingReader::MakeAsync(io::default_io_context(), std::move(input),
+ cpu_executor, reader_options,
+ parse_options, convert_options);
+ }));
+ return reader_fut.Then(
+ // Adds the filename to the error
+ [](const std::shared_ptr<csv::StreamingReader>& reader)
+ -> Result<std::shared_ptr<csv::StreamingReader>> { return reader; },
+ [source](const Status& err) -> Result<std::shared_ptr<csv::StreamingReader>> {
+ return err.WithMessage("Could not open CSV input source '", source.path(),
+ "': ", err);
+ });
+}
+
+static inline Result<std::shared_ptr<csv::StreamingReader>> OpenReader(
+ const FileSource& source, const CsvFileFormat& format,
+ const std::shared_ptr<ScanOptions>& scan_options = nullptr) {
+ auto open_reader_fut = OpenReaderAsync(source, format, scan_options,
+ ::arrow::internal::GetCpuThreadPool());
+ return open_reader_fut.result();
+}
+
+static RecordBatchGenerator GeneratorFromReader(
+ const Future<std::shared_ptr<csv::StreamingReader>>& reader, int64_t batch_size) {
+ auto gen_fut = reader.Then(
+ [batch_size](
+ const std::shared_ptr<csv::StreamingReader>& reader) -> RecordBatchGenerator {
+ auto batch_gen = [reader]() { return reader->ReadNextAsync(); };
+ return MakeChunkedBatchGenerator(std::move(batch_gen), batch_size);
+ });
+ return MakeFromFuture(std::move(gen_fut));
+}
+
+/// \brief A ScanTask backed by an Csv file.
+class CsvScanTask : public ScanTask {
+ public:
+ CsvScanTask(std::shared_ptr<const CsvFileFormat> format,
+ std::shared_ptr<ScanOptions> options,
+ std::shared_ptr<FileFragment> fragment)
+ : ScanTask(std::move(options), fragment),
+ format_(std::move(format)),
+ source_(fragment->source()) {}
+
+ Result<RecordBatchIterator> Execute() override {
+ auto reader_fut = OpenReaderAsync(source_, *format_, options(),
+ ::arrow::internal::GetCpuThreadPool());
+ auto reader_gen = GeneratorFromReader(std::move(reader_fut), options()->batch_size);
+ return MakeGeneratorIterator(std::move(reader_gen));
+ }
+
+ Future<RecordBatchVector> SafeExecute(Executor* executor) override {
+ auto reader_fut = OpenReaderAsync(source_, *format_, options(), executor);
+ auto reader_gen = GeneratorFromReader(std::move(reader_fut), options()->batch_size);
+ return CollectAsyncGenerator(reader_gen);
+ }
+
+ Future<> SafeVisit(
+ Executor* executor,
+ std::function<Status(std::shared_ptr<RecordBatch>)> visitor) override {
+ auto reader_fut = OpenReaderAsync(source_, *format_, options(), executor);
+ auto reader_gen = GeneratorFromReader(std::move(reader_fut), options()->batch_size);
+ return VisitAsyncGenerator(reader_gen, visitor);
+ }
+
+ private:
+ std::shared_ptr<const CsvFileFormat> format_;
+ FileSource source_;
+};
+
+bool CsvFileFormat::Equals(const FileFormat& format) const {
+ if (type_name() != format.type_name()) return false;
+
+ const auto& other_parse_options =
+ checked_cast<const CsvFileFormat&>(format).parse_options;
+
+ return parse_options.delimiter == other_parse_options.delimiter &&
+ parse_options.quoting == other_parse_options.quoting &&
+ parse_options.quote_char == other_parse_options.quote_char &&
+ parse_options.double_quote == other_parse_options.double_quote &&
+ parse_options.escaping == other_parse_options.escaping &&
+ parse_options.escape_char == other_parse_options.escape_char &&
+ parse_options.newlines_in_values == other_parse_options.newlines_in_values &&
+ parse_options.ignore_empty_lines == other_parse_options.ignore_empty_lines;
+}
+
+Result<bool> CsvFileFormat::IsSupported(const FileSource& source) const {
+ RETURN_NOT_OK(source.Open().status());
+ return OpenReader(source, *this).ok();
+}
+
+Result<std::shared_ptr<Schema>> CsvFileFormat::Inspect(const FileSource& source) const {
+ ARROW_ASSIGN_OR_RAISE(auto reader, OpenReader(source, *this));
+ return reader->schema();
+}
+
+Result<ScanTaskIterator> CsvFileFormat::ScanFile(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& fragment) const {
+ auto this_ = checked_pointer_cast<const CsvFileFormat>(shared_from_this());
+ auto task = std::make_shared<CsvScanTask>(std::move(this_), options, fragment);
+
+ return MakeVectorIterator<std::shared_ptr<ScanTask>>({std::move(task)});
+}
+
+Result<RecordBatchGenerator> CsvFileFormat::ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& scan_options,
+ const std::shared_ptr<FileFragment>& file) const {
+ auto this_ = checked_pointer_cast<const CsvFileFormat>(shared_from_this());
+ auto source = file->source();
+ auto reader_fut =
+ OpenReaderAsync(source, *this, scan_options, ::arrow::internal::GetCpuThreadPool());
+ return GeneratorFromReader(std::move(reader_fut), scan_options->batch_size);
+}
+
+Future<util::optional<int64_t>> CsvFileFormat::CountRows(
+ const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
+ const std::shared_ptr<ScanOptions>& options) {
+ if (ExpressionHasFieldRefs(predicate)) {
+ return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
+ }
+ auto self = checked_pointer_cast<CsvFileFormat>(shared_from_this());
+ ARROW_ASSIGN_OR_RAISE(auto input, file->source().OpenCompressed());
+ ARROW_ASSIGN_OR_RAISE(auto read_options, GetReadOptions(*self, options));
+ return csv::CountRowsAsync(options->io_context, std::move(input),
+ ::arrow::internal::GetCpuThreadPool(), read_options,
+ self->parse_options)
+ .Then([](int64_t count) { return util::make_optional<int64_t>(count); });
+}
+
+//
+// CsvFileWriter, CsvFileWriteOptions
+//
+
+std::shared_ptr<FileWriteOptions> CsvFileFormat::DefaultWriteOptions() {
+ std::shared_ptr<CsvFileWriteOptions> csv_options(
+ new CsvFileWriteOptions(shared_from_this()));
+ csv_options->write_options =
+ std::make_shared<csv::WriteOptions>(csv::WriteOptions::Defaults());
+ return csv_options;
+}
+
+Result<std::shared_ptr<FileWriter>> CsvFileFormat::MakeWriter(
+ std::shared_ptr<io::OutputStream> destination, std::shared_ptr<Schema> schema,
+ std::shared_ptr<FileWriteOptions> options,
+ fs::FileLocator destination_locator) const {
+ if (!Equals(*options->format())) {
+ return Status::TypeError("Mismatching format/write options.");
+ }
+ auto csv_options = checked_pointer_cast<CsvFileWriteOptions>(options);
+ ARROW_ASSIGN_OR_RAISE(
+ auto writer, csv::MakeCSVWriter(destination, schema, *csv_options->write_options));
+ return std::shared_ptr<FileWriter>(
+ new CsvFileWriter(std::move(destination), std::move(writer), std::move(schema),
+ std::move(csv_options), std::move(destination_locator)));
+}
+
+CsvFileWriter::CsvFileWriter(std::shared_ptr<io::OutputStream> destination,
+ std::shared_ptr<ipc::RecordBatchWriter> writer,
+ std::shared_ptr<Schema> schema,
+ std::shared_ptr<CsvFileWriteOptions> options,
+ fs::FileLocator destination_locator)
+ : FileWriter(std::move(schema), std::move(options), std::move(destination),
+ std::move(destination_locator)),
+ batch_writer_(std::move(writer)) {}
+
+Status CsvFileWriter::Write(const std::shared_ptr<RecordBatch>& batch) {
+ return batch_writer_->WriteRecordBatch(*batch);
+}
+
+Status CsvFileWriter::FinishInternal() { return batch_writer_->Close(); }
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_csv.h b/src/arrow/cpp/src/arrow/dataset/file_csv.h
new file mode 100644
index 000000000..8d7391727
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_csv.h
@@ -0,0 +1,123 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/csv/options.h"
+#include "arrow/dataset/dataset.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/type_fwd.h"
+#include "arrow/dataset/visibility.h"
+#include "arrow/ipc/type_fwd.h"
+#include "arrow/status.h"
+#include "arrow/util/compression.h"
+
+namespace arrow {
+namespace dataset {
+
+constexpr char kCsvTypeName[] = "csv";
+
+/// \addtogroup dataset-file-formats
+///
+/// @{
+
+/// \brief A FileFormat implementation that reads from and writes to Csv files
+class ARROW_DS_EXPORT CsvFileFormat : public FileFormat {
+ public:
+ /// Options affecting the parsing of CSV files
+ csv::ParseOptions parse_options = csv::ParseOptions::Defaults();
+
+ std::string type_name() const override { return kCsvTypeName; }
+
+ bool Equals(const FileFormat& other) const override;
+
+ Result<bool> IsSupported(const FileSource& source) const override;
+
+ /// \brief Return the schema of the file if possible.
+ Result<std::shared_ptr<Schema>> Inspect(const FileSource& source) const override;
+
+ /// \brief Open a file for scanning
+ Result<ScanTaskIterator> ScanFile(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& fragment) const override;
+
+ Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& scan_options,
+ const std::shared_ptr<FileFragment>& file) const override;
+
+ Future<util::optional<int64_t>> CountRows(
+ const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
+ const std::shared_ptr<ScanOptions>& options) override;
+
+ Result<std::shared_ptr<FileWriter>> MakeWriter(
+ std::shared_ptr<io::OutputStream> destination, std::shared_ptr<Schema> schema,
+ std::shared_ptr<FileWriteOptions> options,
+ fs::FileLocator destination_locator) const override;
+
+ std::shared_ptr<FileWriteOptions> DefaultWriteOptions() override;
+};
+
+/// \brief Per-scan options for CSV fragments
+struct ARROW_DS_EXPORT CsvFragmentScanOptions : public FragmentScanOptions {
+ std::string type_name() const override { return kCsvTypeName; }
+
+ /// CSV conversion options
+ csv::ConvertOptions convert_options = csv::ConvertOptions::Defaults();
+
+ /// CSV reading options
+ ///
+ /// Note that use_threads is always ignored.
+ csv::ReadOptions read_options = csv::ReadOptions::Defaults();
+};
+
+class ARROW_DS_EXPORT CsvFileWriteOptions : public FileWriteOptions {
+ public:
+ /// Options passed to csv::MakeCSVWriter.
+ std::shared_ptr<csv::WriteOptions> write_options;
+
+ protected:
+ using FileWriteOptions::FileWriteOptions;
+
+ friend class CsvFileFormat;
+};
+
+class ARROW_DS_EXPORT CsvFileWriter : public FileWriter {
+ public:
+ Status Write(const std::shared_ptr<RecordBatch>& batch) override;
+
+ private:
+ CsvFileWriter(std::shared_ptr<io::OutputStream> destination,
+ std::shared_ptr<ipc::RecordBatchWriter> writer,
+ std::shared_ptr<Schema> schema,
+ std::shared_ptr<CsvFileWriteOptions> options,
+ fs::FileLocator destination_locator);
+
+ Status FinishInternal() override;
+
+ std::shared_ptr<io::OutputStream> destination_;
+ std::shared_ptr<ipc::RecordBatchWriter> batch_writer_;
+
+ friend class CsvFileFormat;
+};
+
+/// @}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_csv_test.cc b/src/arrow/cpp/src/arrow/dataset/file_csv_test.cc
new file mode 100644
index 000000000..e4303d022
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_csv_test.cc
@@ -0,0 +1,404 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/file_csv.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/csv/writer.h"
+#include "arrow/dataset/dataset_internal.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/partition.h"
+#include "arrow/dataset/test_util.h"
+#include "arrow/filesystem/mockfs.h"
+#include "arrow/io/compressed.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/testing/generator.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+
+namespace arrow {
+namespace dataset {
+
+class CsvFormatHelper {
+ public:
+ using FormatType = CsvFileFormat;
+ static Result<std::shared_ptr<Buffer>> Write(RecordBatchReader* reader) {
+ ARROW_ASSIGN_OR_RAISE(auto sink, io::BufferOutputStream::Create());
+ std::shared_ptr<Table> table;
+ RETURN_NOT_OK(reader->ReadAll(&table));
+ auto options = csv::WriteOptions::Defaults();
+ RETURN_NOT_OK(csv::WriteCSV(*table, options, sink.get()));
+ return sink->Finish();
+ }
+
+ static std::shared_ptr<CsvFileFormat> MakeFormat() {
+ auto format = std::make_shared<CsvFileFormat>();
+ // Required for CountRows (since the test generates data with nulls that get written
+ // as empty lines)
+ format->parse_options.ignore_empty_lines = false;
+ return format;
+ }
+};
+
+class TestCsvFileFormat : public FileFormatFixtureMixin<CsvFormatHelper>,
+ public ::testing::WithParamInterface<Compression::type> {
+ public:
+ Compression::type GetCompression() { return GetParam(); }
+
+ std::unique_ptr<FileSource> GetFileSource(std::string csv) {
+ if (GetCompression() == Compression::UNCOMPRESSED) {
+ return ::arrow::internal::make_unique<FileSource>(
+ Buffer::FromString(std::move(csv)));
+ }
+ std::string path = "test.csv";
+ switch (GetCompression()) {
+ case Compression::type::GZIP:
+ path += ".gz";
+ break;
+ case Compression::type::ZSTD:
+ path += ".zstd";
+ break;
+ case Compression::type::LZ4_FRAME:
+ path += ".lz4";
+ break;
+ case Compression::type::BZ2:
+ path += ".bz2";
+ break;
+ default:
+ // No known extension
+ break;
+ }
+ EXPECT_OK_AND_ASSIGN(auto fs, fs::internal::MockFileSystem::Make(fs::kNoTime, {}));
+ EXPECT_OK_AND_ASSIGN(auto codec, util::Codec::Create(GetCompression()));
+ EXPECT_OK_AND_ASSIGN(auto buffer_writer, fs->OpenOutputStream(path));
+ EXPECT_OK_AND_ASSIGN(auto stream,
+ io::CompressedOutputStream::Make(codec.get(), buffer_writer));
+ ARROW_EXPECT_OK(stream->Write(csv));
+ ARROW_EXPECT_OK(stream->Close());
+ EXPECT_OK_AND_ASSIGN(auto info, fs->GetFileInfo(path));
+ return ::arrow::internal::make_unique<FileSource>(info, fs, GetCompression());
+ }
+
+ RecordBatchIterator Batches(ScanTaskIterator scan_task_it) {
+ return MakeFlattenIterator(MakeMaybeMapIterator(
+ [](std::shared_ptr<ScanTask> scan_task) { return scan_task->Execute(); },
+ std::move(scan_task_it)));
+ }
+
+ RecordBatchIterator Batches(Fragment* fragment) {
+ EXPECT_OK_AND_ASSIGN(auto scan_task_it, fragment->Scan(opts_));
+ return Batches(std::move(scan_task_it));
+ }
+};
+
+// Basic scanning tests (to exercise compression support); see the parameterized test
+// below for more comprehensive testing of scan behaviors
+TEST_P(TestCsvFileFormat, ScanRecordBatchReader) {
+ auto source = GetFileSource(R"(f64
+1.0
+
+N/A
+2)");
+ SetSchema({field("f64", float64())});
+ auto fragment = MakeFragment(*source);
+
+ int64_t row_count = 0;
+
+ for (auto maybe_batch : Batches(fragment.get())) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ row_count += batch->num_rows();
+ }
+
+ ASSERT_EQ(row_count, 4);
+}
+
+TEST_P(TestCsvFileFormat, CustomConvertOptions) {
+ auto source = GetFileSource(R"(str
+foo
+MYNULL
+N/A
+bar)");
+ SetSchema({field("str", utf8())});
+ auto fragment = MakeFragment(*source);
+ auto fragment_scan_options = std::make_shared<CsvFragmentScanOptions>();
+ fragment_scan_options->convert_options.null_values = {"MYNULL"};
+ fragment_scan_options->convert_options.strings_can_be_null = true;
+ opts_->fragment_scan_options = fragment_scan_options;
+
+ int64_t null_count = 0;
+ for (auto maybe_batch : Batches(fragment.get())) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ null_count += batch->GetColumnByName("str")->null_count();
+ }
+
+ ASSERT_EQ(null_count, 1);
+}
+
+TEST_P(TestCsvFileFormat, CustomReadOptions) {
+ auto source = GetFileSource(R"(header_skipped
+str
+foo
+MYNULL
+N/A
+bar)");
+ {
+ SetSchema({field("str", utf8())});
+ auto defaults = std::make_shared<CsvFragmentScanOptions>();
+ defaults->read_options.skip_rows = 1;
+ format_->default_fragment_scan_options = defaults;
+ auto fragment = MakeFragment(*source);
+ ASSERT_OK_AND_ASSIGN(auto physical_schema, fragment->ReadPhysicalSchema());
+ AssertSchemaEqual(opts_->dataset_schema, physical_schema);
+
+ int64_t rows = 0;
+ for (auto maybe_batch : Batches(fragment.get())) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ rows += batch->GetColumnByName("str")->length();
+ }
+ ASSERT_EQ(rows, 4);
+ }
+ {
+ SetSchema({field("header_skipped", utf8())});
+ // These options completely override the default ones
+ auto fragment_scan_options = std::make_shared<CsvFragmentScanOptions>();
+ fragment_scan_options->read_options.block_size = 1 << 22;
+ opts_->fragment_scan_options = fragment_scan_options;
+ int64_t rows = 0;
+ auto fragment = MakeFragment(*source);
+ for (auto maybe_batch : Batches(fragment.get())) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ rows += batch->GetColumnByName("header_skipped")->length();
+ }
+ ASSERT_EQ(rows, 5);
+ }
+ {
+ SetSchema({field("custom_header", utf8())});
+ auto defaults = std::make_shared<CsvFragmentScanOptions>();
+ defaults->read_options.column_names = {"custom_header"};
+ format_->default_fragment_scan_options = defaults;
+ opts_->fragment_scan_options = nullptr;
+ int64_t rows = 0;
+ auto fragment = MakeFragment(*source);
+ for (auto maybe_batch : Batches(fragment.get())) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ rows += batch->GetColumnByName("custom_header")->length();
+ }
+ ASSERT_EQ(rows, 6);
+ }
+}
+
+TEST_P(TestCsvFileFormat, CustomReadOptionsColumnNames) {
+ auto source = GetFileSource("1,1\n2,3");
+ SetSchema({field("ints_1", int64()), field("ints_2", int64())});
+ auto defaults = std::make_shared<CsvFragmentScanOptions>();
+ defaults->read_options.column_names = {"ints_1", "ints_2"};
+ format_->default_fragment_scan_options = defaults;
+ auto fragment = MakeFragment(*source);
+ ASSERT_OK_AND_ASSIGN(auto physical_schema, fragment->ReadPhysicalSchema());
+ AssertSchemaEqual(opts_->dataset_schema, physical_schema);
+ int64_t rows = 0;
+ for (auto maybe_batch : Batches(fragment.get())) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ rows += batch->num_rows();
+ }
+ ASSERT_EQ(rows, 2);
+
+ defaults->read_options.column_names = {"same", "same"};
+ format_->default_fragment_scan_options = defaults;
+ fragment = MakeFragment(*source);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("CSV file contained multiple columns named same"),
+ Batches(fragment.get()).Next());
+}
+
+TEST_P(TestCsvFileFormat, ScanRecordBatchReaderWithVirtualColumn) {
+ auto source = GetFileSource(R"(f64
+1.0
+
+N/A
+2)");
+ // NB: dataset_schema includes a column not present in the file
+ SetSchema({field("f64", float64()), field("virtual", int32())});
+ auto fragment = MakeFragment(*source);
+
+ ASSERT_OK_AND_ASSIGN(auto physical_schema, fragment->ReadPhysicalSchema());
+ AssertSchemaEqual(Schema({field("f64", float64())}), *physical_schema);
+
+ int64_t row_count = 0;
+
+ for (auto maybe_batch : Batches(fragment.get())) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ AssertSchemaEqual(*batch->schema(), *physical_schema);
+ row_count += batch->num_rows();
+ }
+
+ ASSERT_EQ(row_count, 4);
+}
+
+TEST_P(TestCsvFileFormat, InspectFailureWithRelevantError) {
+ TestInspectFailureWithRelevantError(StatusCode::Invalid, "CSV");
+}
+
+TEST_P(TestCsvFileFormat, Inspect) {
+ TestInspect();
+ auto source = GetFileSource(R"(f64
+1.0
+
+N/A
+2)");
+ ASSERT_OK_AND_ASSIGN(auto actual, format_->Inspect(*source.get()));
+ EXPECT_EQ(*actual, Schema({field("f64", float64())}));
+}
+
+TEST_P(TestCsvFileFormat, InspectWithCustomConvertOptions) {
+ // Regression test for ARROW-12083
+ auto source = GetFileSource(R"(actually_string
+1.0
+
+N/A
+2)");
+ auto defaults = std::make_shared<CsvFragmentScanOptions>();
+ format_->default_fragment_scan_options = defaults;
+
+ ASSERT_OK_AND_ASSIGN(auto actual, format_->Inspect(*source.get()));
+ // Default type inferred
+ EXPECT_EQ(*actual, Schema({field("actually_string", float64())}));
+
+ // Override the inferred type
+ defaults->convert_options.column_types["actually_string"] = utf8();
+ ASSERT_OK_AND_ASSIGN(actual, format_->Inspect(*source.get()));
+ EXPECT_EQ(*actual, Schema({field("actually_string", utf8())}));
+}
+
+TEST_P(TestCsvFileFormat, IsSupported) {
+ TestIsSupported();
+ bool supported;
+
+ auto source = GetFileSource("");
+ ASSERT_OK_AND_ASSIGN(supported, format_->IsSupported(*source));
+ ASSERT_EQ(supported, false);
+
+ source = GetFileSource(R"(declare,two
+ 1,2,3)");
+ ASSERT_OK_AND_ASSIGN(supported, format_->IsSupported(*source));
+ ASSERT_EQ(supported, false);
+
+ source = GetFileSource(R"(f64
+1.0
+
+N/A
+2)");
+ ASSERT_OK_AND_ASSIGN(supported, format_->IsSupported(*source));
+ EXPECT_EQ(supported, true);
+}
+
+TEST_P(TestCsvFileFormat, NonProjectedFieldWithDifferingTypeFromInferred) {
+ auto source = GetFileSource(R"(betrayal_not_really_f64,str
+1.0,foo
+,
+N/A,bar
+2,baz)");
+ auto fragment = MakeFragment(*source);
+ ASSERT_OK_AND_ASSIGN(auto physical_schema, fragment->ReadPhysicalSchema());
+ AssertSchemaEqual(
+ Schema({field("betrayal_not_really_f64", float64()), field("str", utf8())}),
+ *physical_schema);
+
+ // CSV is a text format, so it is valid to read column betrayal_not_really_f64 as string
+ // rather than double
+ auto not_float64 = utf8();
+ auto dataset_schema =
+ schema({field("betrayal_not_really_f64", not_float64), field("str", utf8())});
+
+ ScannerBuilder builder(dataset_schema, fragment, opts_);
+
+ // This filter is valid with declared schema, but would *not* be valid
+ // if betrayal_not_really_f64 were read as double rather than string.
+ ASSERT_OK(
+ builder.Filter(equal(field_ref("betrayal_not_really_f64"), field_ref("str"))));
+
+ // project only "str"
+ ASSERT_OK(builder.Project({"str"}));
+ ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+
+ ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
+ // Run through the scan checking for errors to ensure that "f64" is read with the
+ // specified type and does not revert to the inferred type (if it reverts to
+ // inferring float64 then evaluation of the comparison expression should break)
+ ASSERT_OK(batch_it.Visit([](TaggedRecordBatch) { return Status::OK(); }));
+}
+
+TEST_P(TestCsvFileFormat, WriteRecordBatchReader) { TestWrite(); }
+
+TEST_P(TestCsvFileFormat, WriteRecordBatchReaderCustomOptions) {
+ auto options =
+ checked_pointer_cast<CsvFileWriteOptions>(format_->DefaultWriteOptions());
+ options->write_options->include_header = false;
+ auto data_schema = schema({field("f64", float64())});
+ ASSERT_OK_AND_ASSIGN(auto sink, GetFileSink());
+ ASSERT_OK_AND_ASSIGN(auto writer, format_->MakeWriter(sink, data_schema, options, {}));
+ ASSERT_OK(writer->Write(ConstantArrayGenerator::Zeroes(5, data_schema)));
+ ASSERT_OK(writer->Finish());
+ ASSERT_OK_AND_ASSIGN(auto written, sink->Finish());
+ ASSERT_EQ("0\n0\n0\n0\n0\n", written->ToString());
+}
+
+TEST_P(TestCsvFileFormat, CountRows) { TestCountRows(); }
+
+INSTANTIATE_TEST_SUITE_P(TestUncompressedCsv, TestCsvFileFormat,
+ ::testing::Values(Compression::UNCOMPRESSED));
+#ifdef ARROW_WITH_BZ2
+INSTANTIATE_TEST_SUITE_P(TestBZ2Csv, TestCsvFileFormat,
+ ::testing::Values(Compression::BZ2));
+#endif
+#ifdef ARROW_WITH_LZ4
+INSTANTIATE_TEST_SUITE_P(TestLZ4Csv, TestCsvFileFormat,
+ ::testing::Values(Compression::LZ4_FRAME));
+#endif
+// Snappy does not support streaming compression
+#ifdef ARROW_WITH_ZLIB
+INSTANTIATE_TEST_SUITE_P(TestGZipCsv, TestCsvFileFormat,
+ ::testing::Values(Compression::GZIP));
+#endif
+#ifdef ARROW_WITH_ZSTD
+INSTANTIATE_TEST_SUITE_P(TestZSTDCsv, TestCsvFileFormat,
+ ::testing::Values(Compression::ZSTD));
+#endif
+
+class TestCsvFileFormatScan : public FileFormatScanMixin<CsvFormatHelper> {};
+
+TEST_P(TestCsvFileFormatScan, ScanRecordBatchReader) { TestScan(); }
+TEST_P(TestCsvFileFormatScan, ScanBatchSize) { TestScanBatchSize(); }
+TEST_P(TestCsvFileFormatScan, ScanRecordBatchReaderWithVirtualColumn) {
+ TestScanWithVirtualColumn();
+}
+TEST_P(TestCsvFileFormatScan, ScanRecordBatchReaderProjected) { TestScanProjected(); }
+TEST_P(TestCsvFileFormatScan, ScanRecordBatchReaderProjectedMissingCols) {
+ TestScanProjectedMissingCols();
+}
+
+INSTANTIATE_TEST_SUITE_P(TestScan, TestCsvFileFormatScan,
+ ::testing::ValuesIn(TestFormatParams::Values()),
+ TestFormatParams::ToTestNameString);
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_ipc.cc b/src/arrow/cpp/src/arrow/dataset/file_ipc.cc
new file mode 100644
index 000000000..e01373e79
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_ipc.cc
@@ -0,0 +1,310 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/file_ipc.h"
+
+#include <algorithm>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/dataset/dataset_internal.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/scanner.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_pointer_cast;
+
+namespace dataset {
+
+static inline ipc::IpcReadOptions default_read_options() {
+ auto options = ipc::IpcReadOptions::Defaults();
+ options.use_threads = false;
+ return options;
+}
+
+static inline Result<std::shared_ptr<ipc::RecordBatchFileReader>> OpenReader(
+ const FileSource& source,
+ const ipc::IpcReadOptions& options = default_read_options()) {
+ ARROW_ASSIGN_OR_RAISE(auto input, source.Open());
+
+ std::shared_ptr<ipc::RecordBatchFileReader> reader;
+
+ auto status =
+ ipc::RecordBatchFileReader::Open(std::move(input), options).Value(&reader);
+ if (!status.ok()) {
+ return status.WithMessage("Could not open IPC input source '", source.path(),
+ "': ", status.message());
+ }
+ return reader;
+}
+
+static inline Future<std::shared_ptr<ipc::RecordBatchFileReader>> OpenReaderAsync(
+ const FileSource& source,
+ const ipc::IpcReadOptions& options = default_read_options()) {
+ ARROW_ASSIGN_OR_RAISE(auto input, source.Open());
+ auto path = source.path();
+ return ipc::RecordBatchFileReader::OpenAsync(std::move(input), options)
+ .Then([](const std::shared_ptr<ipc::RecordBatchFileReader>& reader)
+ -> Result<std::shared_ptr<ipc::RecordBatchFileReader>> { return reader; },
+ [path](const Status& status)
+ -> Result<std::shared_ptr<ipc::RecordBatchFileReader>> {
+ return status.WithMessage("Could not open IPC input source '", path,
+ "': ", status.message());
+ });
+}
+
+static inline Result<std::vector<int>> GetIncludedFields(
+ const Schema& schema, const std::vector<std::string>& materialized_fields) {
+ std::vector<int> included_fields;
+
+ for (FieldRef ref : materialized_fields) {
+ ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOneOrNone(schema));
+ if (match.indices().empty()) continue;
+
+ included_fields.push_back(match.indices()[0]);
+ }
+
+ return included_fields;
+}
+
+static inline Result<ipc::IpcReadOptions> GetReadOptions(
+ const Schema& schema, const FileFormat& format, const ScanOptions& scan_options) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto ipc_scan_options,
+ GetFragmentScanOptions<IpcFragmentScanOptions>(
+ kIpcTypeName, &scan_options, format.default_fragment_scan_options));
+ auto options =
+ ipc_scan_options->options ? *ipc_scan_options->options : default_read_options();
+ options.memory_pool = scan_options.pool;
+ if (!options.included_fields.empty()) {
+ // Cannot set them here
+ ARROW_LOG(WARNING) << "IpcFragmentScanOptions.options->included_fields was set "
+ "but will be ignored; included_fields are derived from "
+ "fields referenced by the scan";
+ }
+ ARROW_ASSIGN_OR_RAISE(options.included_fields,
+ GetIncludedFields(schema, scan_options.MaterializedFields()));
+ return options;
+}
+
+/// \brief A ScanTask backed by an Ipc file.
+class IpcScanTask : public ScanTask {
+ public:
+ IpcScanTask(std::shared_ptr<FileFragment> fragment,
+ std::shared_ptr<ScanOptions> options)
+ : ScanTask(std::move(options), fragment), source_(fragment->source()) {}
+
+ Result<RecordBatchIterator> Execute() override {
+ struct Impl {
+ static Result<RecordBatchIterator> Make(const FileSource& source,
+ const FileFormat& format,
+ const ScanOptions& scan_options) {
+ ARROW_ASSIGN_OR_RAISE(auto reader, OpenReader(source));
+ ARROW_ASSIGN_OR_RAISE(auto options,
+ GetReadOptions(*reader->schema(), format, scan_options));
+ ARROW_ASSIGN_OR_RAISE(reader, OpenReader(source, options));
+ return RecordBatchIterator(
+ Impl{std::move(reader), scan_options.batch_size, nullptr, 0});
+ }
+
+ Result<std::shared_ptr<RecordBatch>> Next() {
+ if (leftover_) {
+ if (leftover_->num_rows() > batch_size) {
+ auto chunk = leftover_->Slice(0, batch_size);
+ leftover_ = leftover_->Slice(batch_size);
+ return chunk;
+ }
+ return std::move(leftover_);
+ }
+ if (i_ == reader_->num_record_batches()) {
+ return nullptr;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto batch, reader_->ReadRecordBatch(i_++));
+ if (batch->num_rows() > batch_size) {
+ leftover_ = batch->Slice(batch_size);
+ return batch->Slice(0, batch_size);
+ }
+ return batch;
+ }
+
+ std::shared_ptr<ipc::RecordBatchFileReader> reader_;
+ const int64_t batch_size;
+ std::shared_ptr<RecordBatch> leftover_;
+ int i_;
+ };
+
+ return Impl::Make(source_, *checked_pointer_cast<FileFragment>(fragment_)->format(),
+ *options_);
+ }
+
+ private:
+ FileSource source_;
+};
+
+class IpcScanTaskIterator {
+ public:
+ static Result<ScanTaskIterator> Make(std::shared_ptr<ScanOptions> options,
+ std::shared_ptr<FileFragment> fragment) {
+ return ScanTaskIterator(IpcScanTaskIterator(std::move(options), std::move(fragment)));
+ }
+
+ Result<std::shared_ptr<ScanTask>> Next() {
+ if (once_) {
+ // Iteration is done.
+ return nullptr;
+ }
+
+ once_ = true;
+ return std::shared_ptr<ScanTask>(new IpcScanTask(fragment_, options_));
+ }
+
+ private:
+ IpcScanTaskIterator(std::shared_ptr<ScanOptions> options,
+ std::shared_ptr<FileFragment> fragment)
+ : options_(std::move(options)), fragment_(std::move(fragment)) {}
+
+ bool once_ = false;
+ std::shared_ptr<ScanOptions> options_;
+ std::shared_ptr<FileFragment> fragment_;
+};
+
+Result<bool> IpcFileFormat::IsSupported(const FileSource& source) const {
+ RETURN_NOT_OK(source.Open().status());
+ return OpenReader(source).ok();
+}
+
+Result<std::shared_ptr<Schema>> IpcFileFormat::Inspect(const FileSource& source) const {
+ ARROW_ASSIGN_OR_RAISE(auto reader, OpenReader(source));
+ return reader->schema();
+}
+
+Result<ScanTaskIterator> IpcFileFormat::ScanFile(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& fragment) const {
+ return IpcScanTaskIterator::Make(options, fragment);
+}
+
+Result<RecordBatchGenerator> IpcFileFormat::ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& file) const {
+ auto self = shared_from_this();
+ auto source = file->source();
+ auto open_reader = OpenReaderAsync(source);
+ auto reopen_reader = [self, options,
+ source](std::shared_ptr<ipc::RecordBatchFileReader> reader)
+ -> Future<std::shared_ptr<ipc::RecordBatchFileReader>> {
+ ARROW_ASSIGN_OR_RAISE(auto options,
+ GetReadOptions(*reader->schema(), *self, *options));
+ return OpenReader(source, options);
+ };
+ auto readahead_level = options->batch_readahead;
+ auto default_fragment_scan_options = this->default_fragment_scan_options;
+ auto open_generator = [=](const std::shared_ptr<ipc::RecordBatchFileReader>& reader)
+ -> Result<RecordBatchGenerator> {
+ ARROW_ASSIGN_OR_RAISE(
+ auto ipc_scan_options,
+ GetFragmentScanOptions<IpcFragmentScanOptions>(kIpcTypeName, options.get(),
+ default_fragment_scan_options));
+
+ RecordBatchGenerator generator;
+ if (ipc_scan_options->cache_options) {
+ // Transferring helps performance when coalescing
+ ARROW_ASSIGN_OR_RAISE(generator, reader->GetRecordBatchGenerator(
+ /*coalesce=*/true, options->io_context,
+ *ipc_scan_options->cache_options,
+ ::arrow::internal::GetCpuThreadPool()));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(generator, reader->GetRecordBatchGenerator(
+ /*coalesce=*/false, options->io_context));
+ }
+ auto batch_generator = MakeReadaheadGenerator(std::move(generator), readahead_level);
+ return MakeChunkedBatchGenerator(std::move(batch_generator), options->batch_size);
+ };
+ return MakeFromFuture(open_reader.Then(reopen_reader).Then(open_generator));
+}
+
+Future<util::optional<int64_t>> IpcFileFormat::CountRows(
+ const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
+ const std::shared_ptr<ScanOptions>& options) {
+ if (ExpressionHasFieldRefs(predicate)) {
+ return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
+ }
+ auto self = checked_pointer_cast<IpcFileFormat>(shared_from_this());
+ return DeferNotOk(options->io_context.executor()->Submit(
+ [self, file]() -> Result<util::optional<int64_t>> {
+ ARROW_ASSIGN_OR_RAISE(auto reader, OpenReader(file->source()));
+ return reader->CountRows();
+ }));
+}
+
+//
+// IpcFileWriter, IpcFileWriteOptions
+//
+
+std::shared_ptr<FileWriteOptions> IpcFileFormat::DefaultWriteOptions() {
+ std::shared_ptr<IpcFileWriteOptions> ipc_options(
+ new IpcFileWriteOptions(shared_from_this()));
+
+ ipc_options->options =
+ std::make_shared<ipc::IpcWriteOptions>(ipc::IpcWriteOptions::Defaults());
+ return ipc_options;
+}
+
+Result<std::shared_ptr<FileWriter>> IpcFileFormat::MakeWriter(
+ std::shared_ptr<io::OutputStream> destination, std::shared_ptr<Schema> schema,
+ std::shared_ptr<FileWriteOptions> options,
+ fs::FileLocator destination_locator) const {
+ if (!Equals(*options->format())) {
+ return Status::TypeError("Mismatching format/write options.");
+ }
+
+ auto ipc_options = checked_pointer_cast<IpcFileWriteOptions>(options);
+
+ ARROW_ASSIGN_OR_RAISE(auto writer,
+ ipc::MakeFileWriter(destination, schema, *ipc_options->options,
+ ipc_options->metadata));
+
+ return std::shared_ptr<FileWriter>(
+ new IpcFileWriter(std::move(destination), std::move(writer), std::move(schema),
+ std::move(ipc_options), std::move(destination_locator)));
+}
+
+IpcFileWriter::IpcFileWriter(std::shared_ptr<io::OutputStream> destination,
+ std::shared_ptr<ipc::RecordBatchWriter> writer,
+ std::shared_ptr<Schema> schema,
+ std::shared_ptr<IpcFileWriteOptions> options,
+ fs::FileLocator destination_locator)
+ : FileWriter(std::move(schema), std::move(options), std::move(destination),
+ std::move(destination_locator)),
+ batch_writer_(std::move(writer)) {}
+
+Status IpcFileWriter::Write(const std::shared_ptr<RecordBatch>& batch) {
+ return batch_writer_->WriteRecordBatch(*batch);
+}
+
+Status IpcFileWriter::FinishInternal() { return batch_writer_->Close(); }
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_ipc.h b/src/arrow/cpp/src/arrow/dataset/file_ipc.h
new file mode 100644
index 000000000..ef7851522
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_ipc.h
@@ -0,0 +1,125 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/type_fwd.h"
+#include "arrow/dataset/visibility.h"
+#include "arrow/io/type_fwd.h"
+#include "arrow/ipc/type_fwd.h"
+#include "arrow/result.h"
+
+namespace arrow {
+namespace dataset {
+
+/// \addtogroup dataset-file-formats
+///
+/// @{
+
+constexpr char kIpcTypeName[] = "ipc";
+
+/// \brief A FileFormat implementation that reads from and writes to Ipc files
+class ARROW_DS_EXPORT IpcFileFormat : public FileFormat {
+ public:
+ std::string type_name() const override { return kIpcTypeName; }
+
+ bool Equals(const FileFormat& other) const override {
+ return type_name() == other.type_name();
+ }
+
+ Result<bool> IsSupported(const FileSource& source) const override;
+
+ /// \brief Return the schema of the file if possible.
+ Result<std::shared_ptr<Schema>> Inspect(const FileSource& source) const override;
+
+ /// \brief Open a file for scanning
+ Result<ScanTaskIterator> ScanFile(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& fragment) const override;
+
+ Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& file) const override;
+
+ Future<util::optional<int64_t>> CountRows(
+ const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
+ const std::shared_ptr<ScanOptions>& options) override;
+
+ Result<std::shared_ptr<FileWriter>> MakeWriter(
+ std::shared_ptr<io::OutputStream> destination, std::shared_ptr<Schema> schema,
+ std::shared_ptr<FileWriteOptions> options,
+ fs::FileLocator destination_locator) const override;
+
+ std::shared_ptr<FileWriteOptions> DefaultWriteOptions() override;
+};
+
+/// \brief Per-scan options for IPC fragments
+class ARROW_DS_EXPORT IpcFragmentScanOptions : public FragmentScanOptions {
+ public:
+ std::string type_name() const override { return kIpcTypeName; }
+
+ /// Options passed to the IPC file reader.
+ /// included_fields, memory_pool, and use_threads are ignored.
+ std::shared_ptr<ipc::IpcReadOptions> options;
+ /// If present, the async scanner will enable I/O coalescing.
+ /// This is ignored by the sync scanner.
+ std::shared_ptr<io::CacheOptions> cache_options;
+};
+
+class ARROW_DS_EXPORT IpcFileWriteOptions : public FileWriteOptions {
+ public:
+ /// Options passed to ipc::MakeFileWriter. use_threads is ignored
+ std::shared_ptr<ipc::IpcWriteOptions> options;
+
+ /// custom_metadata written to the file's footer
+ std::shared_ptr<const KeyValueMetadata> metadata;
+
+ protected:
+ using FileWriteOptions::FileWriteOptions;
+
+ friend class IpcFileFormat;
+};
+
+class ARROW_DS_EXPORT IpcFileWriter : public FileWriter {
+ public:
+ Status Write(const std::shared_ptr<RecordBatch>& batch) override;
+
+ private:
+ IpcFileWriter(std::shared_ptr<io::OutputStream> destination,
+ std::shared_ptr<ipc::RecordBatchWriter> writer,
+ std::shared_ptr<Schema> schema,
+ std::shared_ptr<IpcFileWriteOptions> options,
+ fs::FileLocator destination_locator);
+
+ Status FinishInternal() override;
+
+ std::shared_ptr<io::OutputStream> destination_;
+ std::shared_ptr<ipc::RecordBatchWriter> batch_writer_;
+
+ friend class IpcFileFormat;
+};
+
+/// @}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_ipc_test.cc b/src/arrow/cpp/src/arrow/dataset/file_ipc_test.cc
new file mode 100644
index 000000000..cb625a9e1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_ipc_test.cc
@@ -0,0 +1,173 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/file_ipc.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/dataset/dataset_internal.h"
+#include "arrow/dataset/discovery.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/partition.h"
+#include "arrow/dataset/scanner_internal.h"
+#include "arrow/dataset/test_util.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+
+using internal::checked_pointer_cast;
+
+namespace dataset {
+
+class IpcFormatHelper {
+ public:
+ using FormatType = IpcFileFormat;
+ static Result<std::shared_ptr<Buffer>> Write(RecordBatchReader* reader) {
+ ARROW_ASSIGN_OR_RAISE(auto sink, io::BufferOutputStream::Create());
+ ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(sink, reader->schema()));
+ std::vector<std::shared_ptr<RecordBatch>> batches;
+ RETURN_NOT_OK(reader->ReadAll(&batches));
+ for (auto batch : batches) {
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+ }
+ RETURN_NOT_OK(writer->Close());
+ return sink->Finish();
+ }
+
+ static std::shared_ptr<IpcFileFormat> MakeFormat() {
+ return std::make_shared<IpcFileFormat>();
+ }
+};
+
+class TestIpcFileFormat : public FileFormatFixtureMixin<IpcFormatHelper> {};
+
+TEST_F(TestIpcFileFormat, WriteRecordBatchReader) { TestWrite(); }
+
+TEST_F(TestIpcFileFormat, WriteRecordBatchReaderCustomOptions) {
+ auto reader = GetRecordBatchReader(schema({field("f64", float64())}));
+ auto source = GetFileSource(reader.get());
+ auto ipc_options =
+ checked_pointer_cast<IpcFileWriteOptions>(format_->DefaultWriteOptions());
+ if (util::Codec::IsAvailable(Compression::ZSTD)) {
+ EXPECT_OK_AND_ASSIGN(ipc_options->options->codec,
+ util::Codec::Create(Compression::ZSTD));
+ }
+ ipc_options->metadata = key_value_metadata({{"hello", "world"}});
+
+ auto written = WriteToBuffer(reader->schema(), ipc_options);
+
+ EXPECT_OK_AND_ASSIGN(auto ipc_reader, ipc::RecordBatchFileReader::Open(
+ std::make_shared<io::BufferReader>(written)));
+ EXPECT_EQ(ipc_reader->metadata()->sorted_pairs(),
+ ipc_options->metadata->sorted_pairs());
+}
+
+TEST_F(TestIpcFileFormat, InspectFailureWithRelevantError) {
+ TestInspectFailureWithRelevantError(StatusCode::Invalid, "IPC");
+}
+TEST_F(TestIpcFileFormat, Inspect) { TestInspect(); }
+TEST_F(TestIpcFileFormat, IsSupported) { TestIsSupported(); }
+TEST_F(TestIpcFileFormat, CountRows) { TestCountRows(); }
+
+class TestIpcFileSystemDataset : public testing::Test,
+ public WriteFileSystemDatasetMixin {
+ public:
+ void SetUp() override {
+ MakeSourceDataset();
+ auto ipc_format = std::make_shared<IpcFileFormat>();
+ format_ = ipc_format;
+ SetWriteOptions(ipc_format->DefaultWriteOptions());
+ }
+};
+
+TEST_F(TestIpcFileSystemDataset, WriteWithIdenticalPartitioningSchema) {
+ TestWriteWithIdenticalPartitioningSchema();
+}
+
+TEST_F(TestIpcFileSystemDataset, WriteWithUnrelatedPartitioningSchema) {
+ TestWriteWithUnrelatedPartitioningSchema();
+}
+
+TEST_F(TestIpcFileSystemDataset, WriteWithSupersetPartitioningSchema) {
+ TestWriteWithSupersetPartitioningSchema();
+}
+
+TEST_F(TestIpcFileSystemDataset, WriteWithEmptyPartitioningSchema) {
+ TestWriteWithEmptyPartitioningSchema();
+}
+
+TEST_F(TestIpcFileSystemDataset, WriteExceedsMaxPartitions) {
+ write_options_.partitioning = std::make_shared<DirectoryPartitioning>(
+ SchemaFromColumnNames(source_schema_, {"model"}));
+
+ // require that no batch be grouped into more than 2 written batches:
+ write_options_.max_partitions = 2;
+
+ auto scanner_builder = ScannerBuilder(dataset_, scan_options_);
+ ASSERT_OK(scanner_builder.UseAsync(true));
+ EXPECT_OK_AND_ASSIGN(auto scanner, scanner_builder.Finish());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("This exceeds the maximum"),
+ FileSystemDataset::Write(write_options_, scanner));
+}
+
+class TestIpcFileFormatScan : public FileFormatScanMixin<IpcFormatHelper> {};
+
+TEST_P(TestIpcFileFormatScan, ScanRecordBatchReader) { TestScan(); }
+TEST_P(TestIpcFileFormatScan, ScanBatchSize) { TestScanBatchSize(); }
+TEST_P(TestIpcFileFormatScan, ScanRecordBatchReaderWithVirtualColumn) {
+ TestScanWithVirtualColumn();
+}
+TEST_P(TestIpcFileFormatScan, ScanRecordBatchReaderProjected) { TestScanProjected(); }
+TEST_P(TestIpcFileFormatScan, ScanRecordBatchReaderProjectedMissingCols) {
+ TestScanProjectedMissingCols();
+}
+TEST_P(TestIpcFileFormatScan, FragmentScanOptions) {
+ auto reader = GetRecordBatchReader(
+ // ARROW-12077: on Windows/mimalloc/release, nullable list column leads to crash
+ schema({field("list", list(float64()), false,
+ key_value_metadata({{"max_length", "1"}})),
+ field("f64", float64())}));
+ auto source = GetFileSource(reader.get());
+
+ SetSchema(reader->schema()->fields());
+ auto fragment = MakeFragment(*source);
+
+ // Set scan options that ensure reading fails
+ auto fragment_scan_options = std::make_shared<IpcFragmentScanOptions>();
+ fragment_scan_options->options = std::make_shared<ipc::IpcReadOptions>();
+ fragment_scan_options->options->max_recursion_depth = 0;
+ opts_->fragment_scan_options = fragment_scan_options;
+ ASSERT_OK_AND_ASSIGN(auto scan_tasks, fragment->Scan(opts_));
+ ASSERT_OK_AND_ASSIGN(auto scan_task, scan_tasks.Next());
+ ASSERT_OK_AND_ASSIGN(auto batches, scan_task->Execute());
+ ASSERT_RAISES(Invalid, batches.Next());
+}
+INSTANTIATE_TEST_SUITE_P(TestScan, TestIpcFileFormatScan,
+ ::testing::ValuesIn(TestFormatParams::Values()),
+ TestFormatParams::ToTestNameString);
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_orc.cc b/src/arrow/cpp/src/arrow/dataset/file_orc.cc
new file mode 100644
index 000000000..44ae3a770
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_orc.cc
@@ -0,0 +1,193 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/file_orc.h"
+
+#include <memory>
+
+#include "arrow/adapters/orc/adapter.h"
+#include "arrow/dataset/dataset_internal.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/scanner.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_pointer_cast;
+
+namespace dataset {
+
+namespace {
+
+Result<std::unique_ptr<arrow::adapters::orc::ORCFileReader>> OpenORCReader(
+ const FileSource& source,
+ const std::shared_ptr<ScanOptions>& scan_options = nullptr) {
+ ARROW_ASSIGN_OR_RAISE(auto input, source.Open());
+
+ arrow::MemoryPool* pool;
+ if (scan_options) {
+ pool = scan_options->pool;
+ } else {
+ pool = default_memory_pool();
+ }
+
+ auto reader = arrow::adapters::orc::ORCFileReader::Open(std::move(input), pool);
+ auto status = reader.status();
+ if (!status.ok()) {
+ return status.WithMessage("Could not open ORC input source '", source.path(),
+ "': ", status.message());
+ }
+ return reader;
+}
+
+/// \brief A ScanTask backed by an ORC file.
+class OrcScanTask : public ScanTask {
+ public:
+ OrcScanTask(std::shared_ptr<FileFragment> fragment,
+ std::shared_ptr<ScanOptions> options)
+ : ScanTask(std::move(options), fragment), source_(fragment->source()) {}
+
+ Result<RecordBatchIterator> Execute() override {
+ struct Impl {
+ static Result<RecordBatchIterator> Make(const FileSource& source,
+ const FileFormat& format,
+ const ScanOptions& scan_options) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto reader,
+ OpenORCReader(source, std::make_shared<ScanOptions>(scan_options)));
+ int num_stripes = reader->NumberOfStripes();
+
+ auto materialized_fields = scan_options.MaterializedFields();
+ // filter out virtual columns
+ std::vector<std::string> included_fields;
+ ARROW_ASSIGN_OR_RAISE(auto schema, reader->ReadSchema());
+ for (auto name : materialized_fields) {
+ FieldRef ref(name);
+ ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOneOrNone(*schema));
+ if (match.indices().empty()) continue;
+
+ included_fields.push_back(name);
+ }
+
+ return RecordBatchIterator(
+ Impl{std::move(reader), 0, num_stripes, included_fields});
+ }
+
+ Result<std::shared_ptr<RecordBatch>> Next() {
+ if (i_ == num_stripes_) {
+ return nullptr;
+ }
+ std::shared_ptr<RecordBatch> batch;
+ // TODO (https://issues.apache.org/jira/browse/ARROW-14153)
+ // pass scan_options_->batch_size
+ return reader_->ReadStripe(i_++, included_fields_);
+ }
+
+ std::unique_ptr<arrow::adapters::orc::ORCFileReader> reader_;
+ int i_;
+ int num_stripes_;
+ std::vector<std::string> included_fields_;
+ };
+
+ return Impl::Make(source_, *checked_pointer_cast<FileFragment>(fragment_)->format(),
+ *options_);
+ }
+
+ private:
+ FileSource source_;
+};
+
+class OrcScanTaskIterator {
+ public:
+ static Result<ScanTaskIterator> Make(std::shared_ptr<ScanOptions> options,
+ std::shared_ptr<FileFragment> fragment) {
+ return ScanTaskIterator(OrcScanTaskIterator(std::move(options), std::move(fragment)));
+ }
+
+ Result<std::shared_ptr<ScanTask>> Next() {
+ if (once_) {
+ // Iteration is done.
+ return nullptr;
+ }
+
+ once_ = true;
+ return std::shared_ptr<ScanTask>(new OrcScanTask(fragment_, options_));
+ }
+
+ private:
+ OrcScanTaskIterator(std::shared_ptr<ScanOptions> options,
+ std::shared_ptr<FileFragment> fragment)
+ : options_(std::move(options)), fragment_(std::move(fragment)) {}
+
+ bool once_ = false;
+ std::shared_ptr<ScanOptions> options_;
+ std::shared_ptr<FileFragment> fragment_;
+};
+
+} // namespace
+
+Result<bool> OrcFileFormat::IsSupported(const FileSource& source) const {
+ RETURN_NOT_OK(source.Open().status());
+ return OpenORCReader(source).ok();
+}
+
+Result<std::shared_ptr<Schema>> OrcFileFormat::Inspect(const FileSource& source) const {
+ ARROW_ASSIGN_OR_RAISE(auto reader, OpenORCReader(source));
+ return reader->ReadSchema();
+}
+
+Result<ScanTaskIterator> OrcFileFormat::ScanFile(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& fragment) const {
+ return OrcScanTaskIterator::Make(options, fragment);
+}
+
+Future<util::optional<int64_t>> OrcFileFormat::CountRows(
+ const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
+ const std::shared_ptr<ScanOptions>& options) {
+ if (ExpressionHasFieldRefs(predicate)) {
+ return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
+ }
+ auto self = checked_pointer_cast<OrcFileFormat>(shared_from_this());
+ return DeferNotOk(options->io_context.executor()->Submit(
+ [self, file]() -> Result<util::optional<int64_t>> {
+ ARROW_ASSIGN_OR_RAISE(auto reader, OpenORCReader(file->source()));
+ return reader->NumberOfRows();
+ }));
+}
+
+// //
+// // OrcFileWriter, OrcFileWriteOptions
+// //
+
+std::shared_ptr<FileWriteOptions> OrcFileFormat::DefaultWriteOptions() {
+ // TODO (https://issues.apache.org/jira/browse/ARROW-13796)
+ return nullptr;
+}
+
+Result<std::shared_ptr<FileWriter>> OrcFileFormat::MakeWriter(
+ std::shared_ptr<io::OutputStream> destination, std::shared_ptr<Schema> schema,
+ std::shared_ptr<FileWriteOptions> options,
+ fs::FileLocator destination_locator) const {
+ // TODO (https://issues.apache.org/jira/browse/ARROW-13796)
+ return Status::NotImplemented("ORC writer not yet implemented.");
+}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_orc.h b/src/arrow/cpp/src/arrow/dataset/file_orc.h
new file mode 100644
index 000000000..ca682935b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_orc.h
@@ -0,0 +1,79 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/type_fwd.h"
+#include "arrow/dataset/visibility.h"
+#include "arrow/io/type_fwd.h"
+#include "arrow/result.h"
+
+namespace arrow {
+namespace dataset {
+
+/// \addtogroup dataset-file-formats
+///
+/// @{
+
+constexpr char kOrcTypeName[] = "orc";
+
+/// \brief A FileFormat implementation that reads from and writes to ORC files
+class ARROW_DS_EXPORT OrcFileFormat : public FileFormat {
+ public:
+ std::string type_name() const override { return kOrcTypeName; }
+
+ bool Equals(const FileFormat& other) const override {
+ return type_name() == other.type_name();
+ }
+
+ Result<bool> IsSupported(const FileSource& source) const override;
+
+ /// \brief Return the schema of the file if possible.
+ Result<std::shared_ptr<Schema>> Inspect(const FileSource& source) const override;
+
+ /// \brief Open a file for scanning
+ Result<ScanTaskIterator> ScanFile(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& fragment) const override;
+
+ // TODO add async version (https://issues.apache.org/jira/browse/ARROW-13795)
+ // Result<RecordBatchGenerator> ScanBatchesAsync(
+ // const std::shared_ptr<ScanOptions>& options,
+ // const std::shared_ptr<FileFragment>& file) const override;
+
+ Future<util::optional<int64_t>> CountRows(
+ const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
+ const std::shared_ptr<ScanOptions>& options) override;
+
+ Result<std::shared_ptr<FileWriter>> MakeWriter(
+ std::shared_ptr<io::OutputStream> destination, std::shared_ptr<Schema> schema,
+ std::shared_ptr<FileWriteOptions> options,
+ fs::FileLocator destination_locator) const override;
+
+ std::shared_ptr<FileWriteOptions> DefaultWriteOptions() override;
+};
+
+/// @}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_orc_test.cc b/src/arrow/cpp/src/arrow/dataset/file_orc_test.cc
new file mode 100644
index 000000000..197d7afeb
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_orc_test.cc
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/file_orc.h"
+
+#include <memory>
+#include <utility>
+
+#include "arrow/adapters/orc/adapter.h"
+#include "arrow/dataset/dataset_internal.h"
+#include "arrow/dataset/discovery.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/partition.h"
+#include "arrow/dataset/scanner_internal.h"
+#include "arrow/dataset/test_util.h"
+#include "arrow/io/memory.h"
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+
+namespace arrow {
+namespace dataset {
+
+class OrcFormatHelper {
+ public:
+ using FormatType = OrcFileFormat;
+ static Result<std::shared_ptr<Buffer>> Write(RecordBatchReader* reader) {
+ ARROW_ASSIGN_OR_RAISE(auto sink, io::BufferOutputStream::Create());
+ ARROW_ASSIGN_OR_RAISE(auto writer, adapters::orc::ORCFileWriter::Open(sink.get()));
+ std::shared_ptr<Table> table;
+ RETURN_NOT_OK(reader->ReadAll(&table));
+ writer->Write(*table);
+ RETURN_NOT_OK(writer->Close());
+ return sink->Finish();
+ }
+
+ static std::shared_ptr<OrcFileFormat> MakeFormat() {
+ return std::make_shared<OrcFileFormat>();
+ }
+};
+
+class TestOrcFileFormat : public FileFormatFixtureMixin<OrcFormatHelper> {};
+
+// TEST_F(TestOrcFileFormat, WriteRecordBatchReader) { TestWrite(); }
+
+TEST_F(TestOrcFileFormat, InspectFailureWithRelevantError) {
+ TestInspectFailureWithRelevantError(StatusCode::IOError, "ORC");
+}
+TEST_F(TestOrcFileFormat, Inspect) { TestInspect(); }
+TEST_F(TestOrcFileFormat, IsSupported) { TestIsSupported(); }
+TEST_F(TestOrcFileFormat, CountRows) { TestCountRows(); }
+
+// TODO add TestOrcFileSystemDataset if write support is added
+
+class TestOrcFileFormatScan : public FileFormatScanMixin<OrcFormatHelper> {};
+
+TEST_P(TestOrcFileFormatScan, ScanRecordBatchReader) { TestScan(); }
+TEST_P(TestOrcFileFormatScan, ScanRecordBatchReaderWithVirtualColumn) {
+ TestScanWithVirtualColumn();
+}
+TEST_P(TestOrcFileFormatScan, ScanRecordBatchReaderProjected) { TestScanProjected(); }
+TEST_P(TestOrcFileFormatScan, ScanRecordBatchReaderProjectedMissingCols) {
+ TestScanProjectedMissingCols();
+}
+INSTANTIATE_TEST_SUITE_P(TestScan, TestOrcFileFormatScan,
+ ::testing::ValuesIn(TestFormatParams::Values()),
+ TestFormatParams::ToTestNameString);
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_parquet.cc b/src/arrow/cpp/src/arrow/dataset/file_parquet.cc
new file mode 100644
index 000000000..ba9be0c8b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_parquet.cc
@@ -0,0 +1,974 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/file_parquet.h"
+
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "arrow/compute/exec.h"
+#include "arrow/dataset/dataset_internal.h"
+#include "arrow/dataset/scanner.h"
+#include "arrow/filesystem/path_util.h"
+#include "arrow/table.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/range.h"
+#include "parquet/arrow/reader.h"
+#include "parquet/arrow/schema.h"
+#include "parquet/arrow/writer.h"
+#include "parquet/file_reader.h"
+#include "parquet/properties.h"
+#include "parquet/statistics.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+using internal::Iota;
+
+namespace dataset {
+
+using parquet::arrow::SchemaField;
+using parquet::arrow::SchemaManifest;
+using parquet::arrow::StatisticsAsScalars;
+
+namespace {
+
+/// \brief A ScanTask backed by a parquet file and a RowGroup within a parquet file.
+class ParquetScanTask : public ScanTask {
+ public:
+ ParquetScanTask(int row_group, std::vector<int> column_projection,
+ std::shared_ptr<parquet::arrow::FileReader> reader,
+ std::shared_ptr<std::once_flag> pre_buffer_once,
+ std::vector<int> pre_buffer_row_groups, arrow::io::IOContext io_context,
+ arrow::io::CacheOptions cache_options,
+ std::shared_ptr<ScanOptions> options,
+ std::shared_ptr<Fragment> fragment)
+ : ScanTask(std::move(options), std::move(fragment)),
+ row_group_(row_group),
+ column_projection_(std::move(column_projection)),
+ reader_(std::move(reader)),
+ pre_buffer_once_(std::move(pre_buffer_once)),
+ pre_buffer_row_groups_(std::move(pre_buffer_row_groups)),
+ io_context_(std::move(io_context)),
+ cache_options_(cache_options) {}
+
+ Result<RecordBatchIterator> Execute() override {
+ // The construction of parquet's RecordBatchReader is deferred here to
+ // control the memory usage of consumers who materialize all ScanTasks
+ // before dispatching them, e.g. for scheduling purposes.
+ //
+ // The memory and IO incurred by the RecordBatchReader is allocated only
+ // when Execute is called.
+ struct {
+ Result<std::shared_ptr<RecordBatch>> operator()() const {
+ return record_batch_reader->Next();
+ }
+
+ // The RecordBatchIterator must hold a reference to the FileReader;
+ // since it must outlive the wrapped RecordBatchReader
+ std::shared_ptr<parquet::arrow::FileReader> file_reader;
+ std::unique_ptr<RecordBatchReader> record_batch_reader;
+ } NextBatch;
+
+ RETURN_NOT_OK(EnsurePreBuffered());
+ NextBatch.file_reader = reader_;
+ RETURN_NOT_OK(reader_->GetRecordBatchReader({row_group_}, column_projection_,
+ &NextBatch.record_batch_reader));
+ return MakeFunctionIterator(std::move(NextBatch));
+ }
+
+ // Ensure that pre-buffering has been applied to the underlying Parquet reader
+ // exactly once (if needed). If we instead set pre_buffer on in the Arrow
+ // reader properties, each scan task will try to separately pre-buffer, which
+ // will lead to crashes as they trample the Parquet file reader's internal
+ // state. Instead, pre-buffer once at the file level. This also has the
+ // advantage that we can coalesce reads across row groups.
+ Status EnsurePreBuffered() {
+ if (pre_buffer_once_) {
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ std::call_once(*pre_buffer_once_, [this]() {
+ // Ignore the future here - don't wait for pre-buffering (the reader itself will
+ // block as necessary)
+ ARROW_UNUSED(reader_->parquet_reader()->PreBuffer(
+ pre_buffer_row_groups_, column_projection_, io_context_, cache_options_));
+ });
+ END_PARQUET_CATCH_EXCEPTIONS
+ }
+ return Status::OK();
+ }
+
+ private:
+ int row_group_;
+ std::vector<int> column_projection_;
+ std::shared_ptr<parquet::arrow::FileReader> reader_;
+ // Pre-buffering state. pre_buffer_once will be nullptr if no pre-buffering is
+ // to be done. We assume all scan tasks have the same column projection.
+ std::shared_ptr<std::once_flag> pre_buffer_once_;
+ std::vector<int> pre_buffer_row_groups_;
+ arrow::io::IOContext io_context_;
+ arrow::io::CacheOptions cache_options_;
+};
+
+parquet::ReaderProperties MakeReaderProperties(
+ const ParquetFileFormat& format, ParquetFragmentScanOptions* parquet_scan_options,
+ MemoryPool* pool = default_memory_pool()) {
+ // Can't mutate pool after construction
+ parquet::ReaderProperties properties(pool);
+ if (parquet_scan_options->reader_properties->is_buffered_stream_enabled()) {
+ properties.enable_buffered_stream();
+ } else {
+ properties.disable_buffered_stream();
+ }
+ properties.set_buffer_size(parquet_scan_options->reader_properties->buffer_size());
+ properties.file_decryption_properties(
+ parquet_scan_options->reader_properties->file_decryption_properties());
+ return properties;
+}
+
+parquet::ArrowReaderProperties MakeArrowReaderProperties(
+ const ParquetFileFormat& format, const parquet::FileMetaData& metadata) {
+ parquet::ArrowReaderProperties properties(/* use_threads = */ false);
+ for (const std::string& name : format.reader_options.dict_columns) {
+ auto column_index = metadata.schema()->ColumnIndex(name);
+ properties.set_read_dictionary(column_index, true);
+ }
+ properties.set_coerce_int96_timestamp_unit(
+ format.reader_options.coerce_int96_timestamp_unit);
+ return properties;
+}
+
+template <typename M>
+Result<std::shared_ptr<SchemaManifest>> GetSchemaManifest(
+ const M& metadata, const parquet::ArrowReaderProperties& properties) {
+ auto manifest = std::make_shared<SchemaManifest>();
+ const std::shared_ptr<const ::arrow::KeyValueMetadata>& key_value_metadata = nullptr;
+ RETURN_NOT_OK(SchemaManifest::Make(metadata.schema(), key_value_metadata, properties,
+ manifest.get()));
+ return manifest;
+}
+
+util::optional<compute::Expression> ColumnChunkStatisticsAsExpression(
+ const SchemaField& schema_field, const parquet::RowGroupMetaData& metadata) {
+ // For the remaining of this function, failure to extract/parse statistics
+ // are ignored by returning nullptr. The goal is two fold. First
+ // avoid an optimization which breaks the computation. Second, allow the
+ // following columns to maybe succeed in extracting column statistics.
+
+ // For now, only leaf (primitive) types are supported.
+ if (!schema_field.is_leaf()) {
+ return util::nullopt;
+ }
+
+ auto column_metadata = metadata.ColumnChunk(schema_field.column_index);
+ auto statistics = column_metadata->statistics();
+ if (statistics == nullptr) {
+ return util::nullopt;
+ }
+
+ const auto& field = schema_field.field;
+ auto field_expr = compute::field_ref(field->name());
+
+ // Optimize for corner case where all values are nulls
+ if (statistics->num_values() == 0 && statistics->null_count() > 0) {
+ return is_null(std::move(field_expr));
+ }
+
+ std::shared_ptr<Scalar> min, max;
+ if (!StatisticsAsScalars(*statistics, &min, &max).ok()) {
+ return util::nullopt;
+ }
+
+ auto maybe_min = min->CastTo(field->type());
+ auto maybe_max = max->CastTo(field->type());
+ if (maybe_min.ok() && maybe_max.ok()) {
+ auto col_min = maybe_min.MoveValueUnsafe();
+ auto col_max = maybe_max.MoveValueUnsafe();
+ if (col_min->Equals(col_max)) {
+ return compute::equal(std::move(field_expr), compute::literal(std::move(col_min)));
+ }
+
+ auto lower_bound =
+ compute::greater_equal(field_expr, compute::literal(std::move(col_min)));
+ auto upper_bound =
+ compute::less_equal(std::move(field_expr), compute::literal(std::move(col_max)));
+ return compute::and_(std::move(lower_bound), std::move(upper_bound));
+ }
+
+ return util::nullopt;
+}
+
+void AddColumnIndices(const SchemaField& schema_field,
+ std::vector<int>* column_projection) {
+ if (schema_field.is_leaf()) {
+ column_projection->push_back(schema_field.column_index);
+ } else {
+ // The following ensure that complex types, e.g. struct, are materialized.
+ for (const auto& child : schema_field.children) {
+ AddColumnIndices(child, column_projection);
+ }
+ }
+}
+
+// Compute the column projection out of an optional arrow::Schema
+std::vector<int> InferColumnProjection(const parquet::arrow::FileReader& reader,
+ const ScanOptions& options) {
+ auto manifest = reader.manifest();
+ // Checks if the field is needed in either the projection or the filter.
+ auto field_names = options.MaterializedFields();
+ std::unordered_set<std::string> materialized_fields{field_names.cbegin(),
+ field_names.cend()};
+ auto should_materialize_column = [&materialized_fields](const std::string& f) {
+ return materialized_fields.find(f) != materialized_fields.end();
+ };
+
+ std::vector<int> columns_selection;
+ // Note that the loop is using the file's schema to iterate instead of the
+ // materialized fields of the ScanOptions. This ensures that missing
+ // fields in the file (but present in the ScanOptions) will be ignored. The
+ // scanner's projector will take care of padding the column with the proper
+ // values.
+ for (const auto& schema_field : manifest.schema_fields) {
+ if (should_materialize_column(schema_field.field->name())) {
+ AddColumnIndices(schema_field, &columns_selection);
+ }
+ }
+
+ return columns_selection;
+}
+
+Status WrapSourceError(const Status& status, const std::string& path) {
+ return status.WithMessage("Could not open Parquet input source '", path,
+ "': ", status.message());
+}
+
+Result<bool> IsSupportedParquetFile(const ParquetFileFormat& format,
+ const FileSource& source) {
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ try {
+ ARROW_ASSIGN_OR_RAISE(auto input, source.Open());
+ ARROW_ASSIGN_OR_RAISE(
+ auto parquet_scan_options,
+ GetFragmentScanOptions<ParquetFragmentScanOptions>(
+ kParquetTypeName, nullptr, format.default_fragment_scan_options));
+ auto reader = parquet::ParquetFileReader::Open(
+ std::move(input), MakeReaderProperties(format, parquet_scan_options.get()));
+ std::shared_ptr<parquet::FileMetaData> metadata = reader->metadata();
+ return metadata != nullptr && metadata->can_decompress();
+ } catch (const ::parquet::ParquetInvalidOrCorruptedFileException& e) {
+ ARROW_UNUSED(e);
+ return false;
+ }
+ END_PARQUET_CATCH_EXCEPTIONS
+}
+
+} // namespace
+
+bool ParquetFileFormat::Equals(const FileFormat& other) const {
+ if (other.type_name() != type_name()) return false;
+
+ const auto& other_reader_options =
+ checked_cast<const ParquetFileFormat&>(other).reader_options;
+
+ // FIXME implement comparison for decryption options
+ return (reader_options.dict_columns == other_reader_options.dict_columns &&
+ reader_options.coerce_int96_timestamp_unit ==
+ other_reader_options.coerce_int96_timestamp_unit);
+}
+
+ParquetFileFormat::ParquetFileFormat(const parquet::ReaderProperties& reader_properties) {
+ auto parquet_scan_options = std::make_shared<ParquetFragmentScanOptions>();
+ *parquet_scan_options->reader_properties = reader_properties;
+ default_fragment_scan_options = std::move(parquet_scan_options);
+}
+
+Result<bool> ParquetFileFormat::IsSupported(const FileSource& source) const {
+ auto maybe_is_supported = IsSupportedParquetFile(*this, source);
+ if (!maybe_is_supported.ok()) {
+ return WrapSourceError(maybe_is_supported.status(), source.path());
+ }
+ return maybe_is_supported;
+}
+
+Result<std::shared_ptr<Schema>> ParquetFileFormat::Inspect(
+ const FileSource& source) const {
+ ARROW_ASSIGN_OR_RAISE(auto reader, GetReader(source));
+ std::shared_ptr<Schema> schema;
+ RETURN_NOT_OK(reader->GetSchema(&schema));
+ return schema;
+}
+
+Result<std::unique_ptr<parquet::arrow::FileReader>> ParquetFileFormat::GetReader(
+ const FileSource& source, ScanOptions* options) const {
+ ARROW_ASSIGN_OR_RAISE(auto parquet_scan_options,
+ GetFragmentScanOptions<ParquetFragmentScanOptions>(
+ kParquetTypeName, options, default_fragment_scan_options));
+ MemoryPool* pool = options ? options->pool : default_memory_pool();
+ auto properties = MakeReaderProperties(*this, parquet_scan_options.get(), pool);
+
+ ARROW_ASSIGN_OR_RAISE(auto input, source.Open());
+
+ auto make_reader = [&]() -> Result<std::unique_ptr<parquet::ParquetFileReader>> {
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ return parquet::ParquetFileReader::Open(std::move(input), std::move(properties));
+ END_PARQUET_CATCH_EXCEPTIONS
+ };
+
+ auto maybe_reader = std::move(make_reader)();
+ if (!maybe_reader.ok()) {
+ return WrapSourceError(maybe_reader.status(), source.path());
+ }
+ std::unique_ptr<parquet::ParquetFileReader> reader = *std::move(maybe_reader);
+ std::shared_ptr<parquet::FileMetaData> metadata = reader->metadata();
+ auto arrow_properties = MakeArrowReaderProperties(*this, *metadata);
+
+ if (options) {
+ arrow_properties.set_batch_size(options->batch_size);
+ }
+
+ if (options && !options->use_threads) {
+ arrow_properties.set_use_threads(
+ parquet_scan_options->enable_parallel_column_conversion);
+ }
+
+ std::unique_ptr<parquet::arrow::FileReader> arrow_reader;
+ RETURN_NOT_OK(parquet::arrow::FileReader::Make(
+ pool, std::move(reader), std::move(arrow_properties), &arrow_reader));
+ return std::move(arrow_reader);
+}
+
+Future<std::shared_ptr<parquet::arrow::FileReader>> ParquetFileFormat::GetReaderAsync(
+ const FileSource& source, const std::shared_ptr<ScanOptions>& options) const {
+ ARROW_ASSIGN_OR_RAISE(
+ auto parquet_scan_options,
+ GetFragmentScanOptions<ParquetFragmentScanOptions>(kParquetTypeName, options.get(),
+ default_fragment_scan_options));
+ auto properties =
+ MakeReaderProperties(*this, parquet_scan_options.get(), options->pool);
+ ARROW_ASSIGN_OR_RAISE(auto input, source.Open());
+ // TODO(ARROW-12259): workaround since we have Future<(move-only type)>
+ auto reader_fut =
+ parquet::ParquetFileReader::OpenAsync(std::move(input), std::move(properties));
+ auto path = source.path();
+ auto self = checked_pointer_cast<const ParquetFileFormat>(shared_from_this());
+ return reader_fut.Then(
+ [=](const std::unique_ptr<parquet::ParquetFileReader>&) mutable
+ -> Result<std::shared_ptr<parquet::arrow::FileReader>> {
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<parquet::ParquetFileReader> reader,
+ reader_fut.MoveResult());
+ std::shared_ptr<parquet::FileMetaData> metadata = reader->metadata();
+ auto arrow_properties = MakeArrowReaderProperties(*self, *metadata);
+ arrow_properties.set_batch_size(options->batch_size);
+ // Must be set here since the sync ScanTask handles pre-buffering itself
+ arrow_properties.set_pre_buffer(
+ parquet_scan_options->arrow_reader_properties->pre_buffer());
+ arrow_properties.set_cache_options(
+ parquet_scan_options->arrow_reader_properties->cache_options());
+ arrow_properties.set_io_context(
+ parquet_scan_options->arrow_reader_properties->io_context());
+ arrow_properties.set_use_threads(options->use_threads);
+ std::unique_ptr<parquet::arrow::FileReader> arrow_reader;
+ RETURN_NOT_OK(parquet::arrow::FileReader::Make(options->pool, std::move(reader),
+ std::move(arrow_properties),
+ &arrow_reader));
+ return std::move(arrow_reader);
+ },
+ [path](
+ const Status& status) -> Result<std::shared_ptr<parquet::arrow::FileReader>> {
+ return WrapSourceError(status, path);
+ });
+}
+
+Result<ScanTaskIterator> ParquetFileFormat::ScanFile(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& fragment) const {
+ auto* parquet_fragment = checked_cast<ParquetFileFragment*>(fragment.get());
+ std::vector<int> row_groups;
+
+ bool pre_filtered = false;
+ auto MakeEmpty = [] { return MakeEmptyIterator<std::shared_ptr<ScanTask>>(); };
+
+ // If RowGroup metadata is cached completely we can pre-filter RowGroups before opening
+ // a FileReader, potentially avoiding IO altogether if all RowGroups are excluded due to
+ // prior statistics knowledge. In the case where a RowGroup doesn't have statistics
+ // metdata, it will not be excluded.
+ if (parquet_fragment->metadata() != nullptr) {
+ ARROW_ASSIGN_OR_RAISE(row_groups, parquet_fragment->FilterRowGroups(options->filter));
+
+ pre_filtered = true;
+ if (row_groups.empty()) MakeEmpty();
+ }
+
+ // Open the reader and pay the real IO cost.
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<parquet::arrow::FileReader> reader,
+ GetReader(fragment->source(), options.get()));
+
+ // Ensure that parquet_fragment has FileMetaData
+ RETURN_NOT_OK(parquet_fragment->EnsureCompleteMetadata(reader.get()));
+
+ if (!pre_filtered) {
+ // row groups were not already filtered; do this now
+ ARROW_ASSIGN_OR_RAISE(row_groups, parquet_fragment->FilterRowGroups(options->filter));
+
+ if (row_groups.empty()) MakeEmpty();
+ }
+
+ auto column_projection = InferColumnProjection(*reader, *options);
+ ScanTaskVector tasks(row_groups.size());
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto parquet_scan_options,
+ GetFragmentScanOptions<ParquetFragmentScanOptions>(kParquetTypeName, options.get(),
+ default_fragment_scan_options));
+ std::shared_ptr<std::once_flag> pre_buffer_once = nullptr;
+ if (parquet_scan_options->arrow_reader_properties->pre_buffer()) {
+ pre_buffer_once = std::make_shared<std::once_flag>();
+ }
+
+ for (size_t i = 0; i < row_groups.size(); ++i) {
+ tasks[i] = std::make_shared<ParquetScanTask>(
+ row_groups[i], column_projection, reader, pre_buffer_once, row_groups,
+ parquet_scan_options->arrow_reader_properties->io_context(),
+ parquet_scan_options->arrow_reader_properties->cache_options(), options,
+ fragment);
+ }
+
+ return MakeVectorIterator(std::move(tasks));
+}
+
+Result<RecordBatchGenerator> ParquetFileFormat::ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& file) const {
+ auto parquet_fragment = checked_pointer_cast<ParquetFileFragment>(file);
+ std::vector<int> row_groups;
+ bool pre_filtered = false;
+ // If RowGroup metadata is cached completely we can pre-filter RowGroups before opening
+ // a FileReader, potentially avoiding IO altogether if all RowGroups are excluded due to
+ // prior statistics knowledge. In the case where a RowGroup doesn't have statistics
+ // metdata, it will not be excluded.
+ if (parquet_fragment->metadata() != nullptr) {
+ ARROW_ASSIGN_OR_RAISE(row_groups, parquet_fragment->FilterRowGroups(options->filter));
+ pre_filtered = true;
+ if (row_groups.empty()) return MakeEmptyGenerator<std::shared_ptr<RecordBatch>>();
+ }
+ // Open the reader and pay the real IO cost.
+ auto make_generator =
+ [=](const std::shared_ptr<parquet::arrow::FileReader>& reader) mutable
+ -> Result<RecordBatchGenerator> {
+ // Ensure that parquet_fragment has FileMetaData
+ RETURN_NOT_OK(parquet_fragment->EnsureCompleteMetadata(reader.get()));
+ if (!pre_filtered) {
+ // row groups were not already filtered; do this now
+ ARROW_ASSIGN_OR_RAISE(row_groups,
+ parquet_fragment->FilterRowGroups(options->filter));
+ if (row_groups.empty()) return MakeEmptyGenerator<std::shared_ptr<RecordBatch>>();
+ }
+ auto column_projection = InferColumnProjection(*reader, *options);
+ ARROW_ASSIGN_OR_RAISE(
+ auto parquet_scan_options,
+ GetFragmentScanOptions<ParquetFragmentScanOptions>(
+ kParquetTypeName, options.get(), default_fragment_scan_options));
+ // Assume 1 row group corresponds to 1 batch (this factor could be
+ // improved by looking at metadata)
+ int row_group_readahead = options->batch_readahead;
+ ARROW_ASSIGN_OR_RAISE(
+ auto generator, reader->GetRecordBatchGenerator(
+ reader, row_groups, column_projection,
+ ::arrow::internal::GetCpuThreadPool(), row_group_readahead));
+ return generator;
+ };
+ return MakeFromFuture(GetReaderAsync(parquet_fragment->source(), options)
+ .Then(std::move(make_generator)));
+}
+
+Future<util::optional<int64_t>> ParquetFileFormat::CountRows(
+ const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
+ const std::shared_ptr<ScanOptions>& options) {
+ auto parquet_file = checked_pointer_cast<ParquetFileFragment>(file);
+ if (parquet_file->metadata()) {
+ ARROW_ASSIGN_OR_RAISE(auto maybe_count,
+ parquet_file->TryCountRows(std::move(predicate)));
+ return Future<util::optional<int64_t>>::MakeFinished(maybe_count);
+ } else {
+ return DeferNotOk(options->io_context.executor()->Submit(
+ [parquet_file, predicate]() -> Result<util::optional<int64_t>> {
+ RETURN_NOT_OK(parquet_file->EnsureCompleteMetadata());
+ return parquet_file->TryCountRows(predicate);
+ }));
+ }
+}
+
+Result<std::shared_ptr<ParquetFileFragment>> ParquetFileFormat::MakeFragment(
+ FileSource source, compute::Expression partition_expression,
+ std::shared_ptr<Schema> physical_schema, std::vector<int> row_groups) {
+ return std::shared_ptr<ParquetFileFragment>(new ParquetFileFragment(
+ std::move(source), shared_from_this(), std::move(partition_expression),
+ std::move(physical_schema), std::move(row_groups)));
+}
+
+Result<std::shared_ptr<FileFragment>> ParquetFileFormat::MakeFragment(
+ FileSource source, compute::Expression partition_expression,
+ std::shared_ptr<Schema> physical_schema) {
+ return std::shared_ptr<FileFragment>(new ParquetFileFragment(
+ std::move(source), shared_from_this(), std::move(partition_expression),
+ std::move(physical_schema), util::nullopt));
+}
+
+//
+// ParquetFileWriter, ParquetFileWriteOptions
+//
+
+std::shared_ptr<FileWriteOptions> ParquetFileFormat::DefaultWriteOptions() {
+ std::shared_ptr<ParquetFileWriteOptions> options(
+ new ParquetFileWriteOptions(shared_from_this()));
+ options->writer_properties = parquet::default_writer_properties();
+ options->arrow_writer_properties = parquet::default_arrow_writer_properties();
+ return options;
+}
+
+Result<std::shared_ptr<FileWriter>> ParquetFileFormat::MakeWriter(
+ std::shared_ptr<io::OutputStream> destination, std::shared_ptr<Schema> schema,
+ std::shared_ptr<FileWriteOptions> options,
+ fs::FileLocator destination_locator) const {
+ if (!Equals(*options->format())) {
+ return Status::TypeError("Mismatching format/write options");
+ }
+
+ auto parquet_options = checked_pointer_cast<ParquetFileWriteOptions>(options);
+
+ std::unique_ptr<parquet::arrow::FileWriter> parquet_writer;
+ RETURN_NOT_OK(parquet::arrow::FileWriter::Open(
+ *schema, default_memory_pool(), destination, parquet_options->writer_properties,
+ parquet_options->arrow_writer_properties, &parquet_writer));
+
+ return std::shared_ptr<FileWriter>(
+ new ParquetFileWriter(std::move(destination), std::move(parquet_writer),
+ std::move(parquet_options), std::move(destination_locator)));
+}
+
+ParquetFileWriter::ParquetFileWriter(std::shared_ptr<io::OutputStream> destination,
+ std::shared_ptr<parquet::arrow::FileWriter> writer,
+ std::shared_ptr<ParquetFileWriteOptions> options,
+ fs::FileLocator destination_locator)
+ : FileWriter(writer->schema(), std::move(options), std::move(destination),
+ std::move(destination_locator)),
+ parquet_writer_(std::move(writer)) {}
+
+Status ParquetFileWriter::Write(const std::shared_ptr<RecordBatch>& batch) {
+ ARROW_ASSIGN_OR_RAISE(auto table, Table::FromRecordBatches(batch->schema(), {batch}));
+ return parquet_writer_->WriteTable(*table, batch->num_rows());
+}
+
+Status ParquetFileWriter::FinishInternal() { return parquet_writer_->Close(); }
+
+//
+// ParquetFileFragment
+//
+
+ParquetFileFragment::ParquetFileFragment(FileSource source,
+ std::shared_ptr<FileFormat> format,
+ compute::Expression partition_expression,
+ std::shared_ptr<Schema> physical_schema,
+ util::optional<std::vector<int>> row_groups)
+ : FileFragment(std::move(source), std::move(format), std::move(partition_expression),
+ std::move(physical_schema)),
+ parquet_format_(checked_cast<ParquetFileFormat&>(*format_)),
+ row_groups_(std::move(row_groups)) {}
+
+Status ParquetFileFragment::EnsureCompleteMetadata(parquet::arrow::FileReader* reader) {
+ auto lock = physical_schema_mutex_.Lock();
+ if (metadata_ != nullptr) {
+ return Status::OK();
+ }
+
+ if (reader == nullptr) {
+ lock.Unlock();
+ ARROW_ASSIGN_OR_RAISE(auto reader, parquet_format_.GetReader(source_));
+ return EnsureCompleteMetadata(reader.get());
+ }
+
+ std::shared_ptr<Schema> schema;
+ RETURN_NOT_OK(reader->GetSchema(&schema));
+ if (physical_schema_ && !physical_schema_->Equals(*schema)) {
+ return Status::Invalid("Fragment initialized with physical schema ",
+ *physical_schema_, " but ", source_.path(), " has schema ",
+ *schema);
+ }
+ physical_schema_ = std::move(schema);
+
+ if (!row_groups_) {
+ row_groups_ = Iota(reader->num_row_groups());
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto manifest,
+ GetSchemaManifest(*reader->parquet_reader()->metadata(), reader->properties()));
+ return SetMetadata(reader->parquet_reader()->metadata(), std::move(manifest));
+}
+
+Status ParquetFileFragment::SetMetadata(
+ std::shared_ptr<parquet::FileMetaData> metadata,
+ std::shared_ptr<parquet::arrow::SchemaManifest> manifest) {
+ DCHECK(row_groups_.has_value());
+
+ metadata_ = std::move(metadata);
+ manifest_ = std::move(manifest);
+
+ statistics_expressions_.resize(row_groups_->size(), compute::literal(true));
+ statistics_expressions_complete_.resize(physical_schema_->num_fields(), false);
+
+ for (int row_group : *row_groups_) {
+ // Ensure RowGroups are indexing valid RowGroups before augmenting.
+ if (row_group < metadata_->num_row_groups()) continue;
+
+ return Status::IndexError("ParquetFileFragment references row group ", row_group,
+ " but ", source_.path(), " only has ",
+ metadata_->num_row_groups(), " row groups");
+ }
+
+ return Status::OK();
+}
+
+Result<FragmentVector> ParquetFileFragment::SplitByRowGroup(
+ compute::Expression predicate) {
+ RETURN_NOT_OK(EnsureCompleteMetadata());
+ ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate));
+
+ FragmentVector fragments(row_groups.size());
+ int i = 0;
+ for (int row_group : row_groups) {
+ ARROW_ASSIGN_OR_RAISE(auto fragment,
+ parquet_format_.MakeFragment(source_, partition_expression(),
+ physical_schema_, {row_group}));
+
+ RETURN_NOT_OK(fragment->SetMetadata(metadata_, manifest_));
+ fragments[i++] = std::move(fragment);
+ }
+
+ return fragments;
+}
+
+Result<std::shared_ptr<Fragment>> ParquetFileFragment::Subset(
+ compute::Expression predicate) {
+ RETURN_NOT_OK(EnsureCompleteMetadata());
+ ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate));
+ return Subset(std::move(row_groups));
+}
+
+Result<std::shared_ptr<Fragment>> ParquetFileFragment::Subset(
+ std::vector<int> row_groups) {
+ RETURN_NOT_OK(EnsureCompleteMetadata());
+ ARROW_ASSIGN_OR_RAISE(auto new_fragment, parquet_format_.MakeFragment(
+ source_, partition_expression(),
+ physical_schema_, std::move(row_groups)));
+
+ RETURN_NOT_OK(new_fragment->SetMetadata(metadata_, manifest_));
+ return new_fragment;
+}
+
+inline void FoldingAnd(compute::Expression* l, compute::Expression r) {
+ if (*l == compute::literal(true)) {
+ *l = std::move(r);
+ } else {
+ *l = and_(std::move(*l), std::move(r));
+ }
+}
+
+Result<std::vector<int>> ParquetFileFragment::FilterRowGroups(
+ compute::Expression predicate) {
+ std::vector<int> row_groups;
+ ARROW_ASSIGN_OR_RAISE(auto expressions, TestRowGroups(std::move(predicate)));
+
+ auto lock = physical_schema_mutex_.Lock();
+ DCHECK(expressions.empty() || (expressions.size() == row_groups_->size()));
+ for (size_t i = 0; i < expressions.size(); i++) {
+ if (expressions[i].IsSatisfiable()) {
+ row_groups.push_back(row_groups_->at(i));
+ }
+ }
+ return row_groups;
+}
+
+Result<std::vector<compute::Expression>> ParquetFileFragment::TestRowGroups(
+ compute::Expression predicate) {
+ auto lock = physical_schema_mutex_.Lock();
+
+ DCHECK_NE(metadata_, nullptr);
+ ARROW_ASSIGN_OR_RAISE(
+ predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_));
+
+ if (!predicate.IsSatisfiable()) {
+ return std::vector<compute::Expression>{};
+ }
+
+ for (const FieldRef& ref : FieldsInExpression(predicate)) {
+ ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOneOrNone(*physical_schema_));
+
+ if (match.empty()) continue;
+ if (statistics_expressions_complete_[match[0]]) continue;
+ statistics_expressions_complete_[match[0]] = true;
+
+ const SchemaField& schema_field = manifest_->schema_fields[match[0]];
+ int i = 0;
+ for (int row_group : *row_groups_) {
+ auto row_group_metadata = metadata_->RowGroup(row_group);
+
+ if (auto minmax =
+ ColumnChunkStatisticsAsExpression(schema_field, *row_group_metadata)) {
+ FoldingAnd(&statistics_expressions_[i], std::move(*minmax));
+ ARROW_ASSIGN_OR_RAISE(statistics_expressions_[i],
+ statistics_expressions_[i].Bind(*physical_schema_));
+ }
+
+ ++i;
+ }
+ }
+
+ std::vector<compute::Expression> row_groups(row_groups_->size());
+ for (size_t i = 0; i < row_groups_->size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto row_group_predicate,
+ SimplifyWithGuarantee(predicate, statistics_expressions_[i]));
+ row_groups[i] = std::move(row_group_predicate);
+ }
+ return row_groups;
+}
+
+Result<util::optional<int64_t>> ParquetFileFragment::TryCountRows(
+ compute::Expression predicate) {
+ DCHECK_NE(metadata_, nullptr);
+ if (ExpressionHasFieldRefs(predicate)) {
+#if defined(__GNUC__) && (__GNUC__ < 5)
+ // ARROW-12694: with GCC 4.9 (RTools 35) we sometimes segfault here if we move(result)
+ auto result = TestRowGroups(std::move(predicate));
+ if (!result.ok()) {
+ return result.status();
+ }
+ auto expressions = result.ValueUnsafe();
+#else
+ ARROW_ASSIGN_OR_RAISE(auto expressions, TestRowGroups(std::move(predicate)));
+#endif
+ int64_t rows = 0;
+ for (size_t i = 0; i < row_groups_->size(); i++) {
+ // If the row group is entirely excluded, exclude it from the row count
+ if (!expressions[i].IsSatisfiable()) continue;
+ // Unless the row group is entirely included, bail out of fast path
+ if (expressions[i] != compute::literal(true)) return util::nullopt;
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ rows += metadata()->RowGroup((*row_groups_)[i])->num_rows();
+ END_PARQUET_CATCH_EXCEPTIONS
+ }
+ return rows;
+ }
+ return metadata()->num_rows();
+}
+
+//
+// ParquetFragmentScanOptions
+//
+
+ParquetFragmentScanOptions::ParquetFragmentScanOptions() {
+ reader_properties = std::make_shared<parquet::ReaderProperties>();
+ arrow_reader_properties =
+ std::make_shared<parquet::ArrowReaderProperties>(/*use_threads=*/false);
+}
+
+//
+// ParquetDatasetFactory
+//
+
+static inline Result<std::string> FileFromRowGroup(
+ fs::FileSystem* filesystem, const std::string& base_path,
+ const parquet::RowGroupMetaData& row_group, bool validate_column_chunk_paths) {
+ constexpr auto prefix = "Extracting file path from RowGroup failed. ";
+
+ if (row_group.num_columns() == 0) {
+ return Status::Invalid(prefix,
+ "RowGroup must have a least one column to extract path.");
+ }
+
+ auto path = row_group.ColumnChunk(0)->file_path();
+ if (path == "") {
+ return Status::Invalid(
+ prefix,
+ "The column chunks' file paths should be set, but got an empty file path.");
+ }
+
+ if (validate_column_chunk_paths) {
+ for (int i = 1; i < row_group.num_columns(); ++i) {
+ const auto& column_path = row_group.ColumnChunk(i)->file_path();
+ if (column_path != path) {
+ return Status::Invalid(prefix, "Path '", column_path, "' not equal to path '",
+ path, ", for ColumnChunk at index ", i,
+ "; ColumnChunks in a RowGroup must have the same path.");
+ }
+ }
+ }
+
+ path = fs::internal::JoinAbstractPath(
+ std::vector<std::string>{base_path, std::move(path)});
+ // Normalizing path is required for Windows.
+ return filesystem->NormalizePath(std::move(path));
+}
+
+Result<std::shared_ptr<Schema>> GetSchema(
+ const parquet::FileMetaData& metadata,
+ const parquet::ArrowReaderProperties& properties) {
+ std::shared_ptr<Schema> schema;
+ RETURN_NOT_OK(parquet::arrow::FromParquetSchema(
+ metadata.schema(), properties, metadata.key_value_metadata(), &schema));
+ return schema;
+}
+
+Result<std::shared_ptr<DatasetFactory>> ParquetDatasetFactory::Make(
+ const std::string& metadata_path, std::shared_ptr<fs::FileSystem> filesystem,
+ std::shared_ptr<ParquetFileFormat> format, ParquetFactoryOptions options) {
+ // Paths in ColumnChunk are relative to the `_metadata` file. Thus, the base
+ // directory of all parquet files is `dirname(metadata_path)`.
+ auto dirname = arrow::fs::internal::GetAbstractPathParent(metadata_path).first;
+ return Make({metadata_path, filesystem}, dirname, filesystem, std::move(format),
+ std::move(options));
+}
+
+Result<std::shared_ptr<DatasetFactory>> ParquetDatasetFactory::Make(
+ const FileSource& metadata_source, const std::string& base_path,
+ std::shared_ptr<fs::FileSystem> filesystem, std::shared_ptr<ParquetFileFormat> format,
+ ParquetFactoryOptions options) {
+ DCHECK_NE(filesystem, nullptr);
+ DCHECK_NE(format, nullptr);
+
+ // By automatically setting the options base_dir to the metadata's base_path,
+ // we provide a better experience for user providing Partitioning that are
+ // relative to the base_dir instead of the full path.
+ if (options.partition_base_dir.empty()) {
+ options.partition_base_dir = base_path;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto reader, format->GetReader(metadata_source));
+ std::shared_ptr<parquet::FileMetaData> metadata = reader->parquet_reader()->metadata();
+
+ if (metadata->num_columns() == 0) {
+ return Status::Invalid(
+ "ParquetDatasetFactory must contain a schema with at least one column");
+ }
+
+ auto properties = MakeArrowReaderProperties(*format, *metadata);
+ ARROW_ASSIGN_OR_RAISE(auto physical_schema, GetSchema(*metadata, properties));
+ ARROW_ASSIGN_OR_RAISE(auto manifest, GetSchemaManifest(*metadata, properties));
+
+ std::vector<std::pair<std::string, std::vector<int>>> paths_with_row_group_ids;
+ std::unordered_map<std::string, int> paths_to_index;
+
+ for (int i = 0; i < metadata->num_row_groups(); i++) {
+ auto row_group = metadata->RowGroup(i);
+ ARROW_ASSIGN_OR_RAISE(auto path,
+ FileFromRowGroup(filesystem.get(), base_path, *row_group,
+ options.validate_column_chunk_paths));
+
+ // Insert the path, or increase the count of row groups. It will be assumed that the
+ // RowGroup of a file are ordered exactly as in the metadata file.
+ auto inserted_index = paths_to_index.emplace(
+ std::move(path), static_cast<int>(paths_with_row_group_ids.size()));
+ if (inserted_index.second) {
+ paths_with_row_group_ids.push_back({inserted_index.first->first, {}});
+ }
+ paths_with_row_group_ids[inserted_index.first->second].second.push_back(i);
+ }
+
+ return std::shared_ptr<DatasetFactory>(new ParquetDatasetFactory(
+ std::move(filesystem), std::move(format), std::move(metadata), std::move(manifest),
+ std::move(physical_schema), base_path, std::move(options),
+ std::move(paths_with_row_group_ids)));
+}
+
+Result<std::vector<std::shared_ptr<FileFragment>>>
+ParquetDatasetFactory::CollectParquetFragments(const Partitioning& partitioning) {
+ std::vector<std::shared_ptr<FileFragment>> fragments(paths_with_row_group_ids_.size());
+
+ size_t i = 0;
+ for (const auto& e : paths_with_row_group_ids_) {
+ const auto& path = e.first;
+ auto metadata_subset = metadata_->Subset(e.second);
+
+ auto row_groups = Iota(metadata_subset->num_row_groups());
+
+ auto partition_expression =
+ partitioning.Parse(StripPrefixAndFilename(path, options_.partition_base_dir))
+ .ValueOr(compute::literal(true));
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto fragment,
+ format_->MakeFragment({path, filesystem_}, std::move(partition_expression),
+ physical_schema_, std::move(row_groups)));
+
+ RETURN_NOT_OK(fragment->SetMetadata(metadata_subset, manifest_));
+ fragments[i++] = std::move(fragment);
+ }
+
+ return fragments;
+}
+
+Result<std::vector<std::shared_ptr<Schema>>> ParquetDatasetFactory::InspectSchemas(
+ InspectOptions options) {
+ // The physical_schema from the _metadata file is always yielded
+ std::vector<std::shared_ptr<Schema>> schemas = {physical_schema_};
+
+ if (auto factory = options_.partitioning.factory()) {
+ // Gather paths found in RowGroups' ColumnChunks.
+ std::vector<std::string> stripped(paths_with_row_group_ids_.size());
+
+ size_t i = 0;
+ for (const auto& e : paths_with_row_group_ids_) {
+ stripped[i++] = StripPrefixAndFilename(e.first, options_.partition_base_dir);
+ }
+ ARROW_ASSIGN_OR_RAISE(auto partition_schema, factory->Inspect(stripped));
+
+ schemas.push_back(std::move(partition_schema));
+ } else {
+ schemas.push_back(options_.partitioning.partitioning()->schema());
+ }
+
+ return schemas;
+}
+
+Result<std::shared_ptr<Dataset>> ParquetDatasetFactory::Finish(FinishOptions options) {
+ std::shared_ptr<Schema> schema = options.schema;
+ bool schema_missing = schema == nullptr;
+ if (schema_missing) {
+ ARROW_ASSIGN_OR_RAISE(schema, Inspect(options.inspect_options));
+ }
+
+ std::shared_ptr<Partitioning> partitioning = options_.partitioning.partitioning();
+ if (partitioning == nullptr) {
+ auto factory = options_.partitioning.factory();
+ ARROW_ASSIGN_OR_RAISE(partitioning, factory->Finish(schema));
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto fragments, CollectParquetFragments(*partitioning));
+ return FileSystemDataset::Make(std::move(schema), compute::literal(true), format_,
+ filesystem_, std::move(fragments),
+ std::move(partitioning));
+}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_parquet.h b/src/arrow/cpp/src/arrow/dataset/file_parquet.h
new file mode 100644
index 000000000..daf4bd92d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_parquet.h
@@ -0,0 +1,385 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "arrow/dataset/discovery.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/type_fwd.h"
+#include "arrow/dataset/visibility.h"
+#include "arrow/io/caching.h"
+#include "arrow/util/optional.h"
+
+namespace parquet {
+class ParquetFileReader;
+class Statistics;
+class ColumnChunkMetaData;
+class RowGroupMetaData;
+class FileMetaData;
+class FileDecryptionProperties;
+class FileEncryptionProperties;
+
+class ReaderProperties;
+class ArrowReaderProperties;
+
+class WriterProperties;
+class ArrowWriterProperties;
+
+namespace arrow {
+class FileReader;
+class FileWriter;
+struct SchemaManifest;
+} // namespace arrow
+} // namespace parquet
+
+namespace arrow {
+namespace dataset {
+
+/// \addtogroup dataset-file-formats
+///
+/// @{
+
+constexpr char kParquetTypeName[] = "parquet";
+
+/// \brief A FileFormat implementation that reads from Parquet files
+class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat {
+ public:
+ ParquetFileFormat() = default;
+
+ /// Convenience constructor which copies properties from a parquet::ReaderProperties.
+ /// memory_pool will be ignored.
+ explicit ParquetFileFormat(const parquet::ReaderProperties& reader_properties);
+
+ std::string type_name() const override { return kParquetTypeName; }
+
+ bool Equals(const FileFormat& other) const override;
+
+ struct ReaderOptions {
+ /// \defgroup parquet-file-format-arrow-reader-properties properties which correspond
+ /// to members of parquet::ArrowReaderProperties.
+ ///
+ /// We don't embed parquet::ReaderProperties directly because column names (rather
+ /// than indices) are used to indicate dictionary columns, and other options are
+ /// deferred to scan time.
+ ///
+ /// @{
+ std::unordered_set<std::string> dict_columns;
+ arrow::TimeUnit::type coerce_int96_timestamp_unit = arrow::TimeUnit::NANO;
+ /// @}
+ } reader_options;
+
+ Result<bool> IsSupported(const FileSource& source) const override;
+
+ /// \brief Return the schema of the file if possible.
+ Result<std::shared_ptr<Schema>> Inspect(const FileSource& source) const override;
+
+ /// \brief Open a file for scanning
+ Result<ScanTaskIterator> ScanFile(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& file) const override;
+
+ Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& file) const override;
+
+ Future<util::optional<int64_t>> CountRows(
+ const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
+ const std::shared_ptr<ScanOptions>& options) override;
+
+ using FileFormat::MakeFragment;
+
+ /// \brief Create a Fragment targeting all RowGroups.
+ Result<std::shared_ptr<FileFragment>> MakeFragment(
+ FileSource source, compute::Expression partition_expression,
+ std::shared_ptr<Schema> physical_schema) override;
+
+ /// \brief Create a Fragment, restricted to the specified row groups.
+ Result<std::shared_ptr<ParquetFileFragment>> MakeFragment(
+ FileSource source, compute::Expression partition_expression,
+ std::shared_ptr<Schema> physical_schema, std::vector<int> row_groups);
+
+ /// \brief Return a FileReader on the given source.
+ Result<std::unique_ptr<parquet::arrow::FileReader>> GetReader(
+ const FileSource& source, ScanOptions* = NULLPTR) const;
+
+ Future<std::shared_ptr<parquet::arrow::FileReader>> GetReaderAsync(
+ const FileSource& source, const std::shared_ptr<ScanOptions>& options) const;
+
+ Result<std::shared_ptr<FileWriter>> MakeWriter(
+ std::shared_ptr<io::OutputStream> destination, std::shared_ptr<Schema> schema,
+ std::shared_ptr<FileWriteOptions> options,
+ fs::FileLocator destination_locator) const override;
+
+ std::shared_ptr<FileWriteOptions> DefaultWriteOptions() override;
+};
+
+/// \brief A FileFragment with parquet logic.
+///
+/// ParquetFileFragment provides a lazy (with respect to IO) interface to
+/// scan parquet files. Any heavy IO calls are deferred to the Scan() method.
+///
+/// The caller can provide an optional list of selected RowGroups to limit the
+/// number of scanned RowGroups, or to partition the scans across multiple
+/// threads.
+///
+/// Metadata can be explicitly provided, enabling pushdown predicate benefits without
+/// the potentially heavy IO of loading Metadata from the file system. This can induce
+/// significant performance boost when scanning high latency file systems.
+class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment {
+ public:
+ Result<FragmentVector> SplitByRowGroup(compute::Expression predicate);
+
+ /// \brief Return the RowGroups selected by this fragment.
+ const std::vector<int>& row_groups() const {
+ if (row_groups_) return *row_groups_;
+ static std::vector<int> empty;
+ return empty;
+ }
+
+ /// \brief Return the FileMetaData associated with this fragment.
+ const std::shared_ptr<parquet::FileMetaData>& metadata() const { return metadata_; }
+
+ /// \brief Ensure this fragment's FileMetaData is in memory.
+ Status EnsureCompleteMetadata(parquet::arrow::FileReader* reader = NULLPTR);
+
+ /// \brief Return fragment which selects a filtered subset of this fragment's RowGroups.
+ Result<std::shared_ptr<Fragment>> Subset(compute::Expression predicate);
+ Result<std::shared_ptr<Fragment>> Subset(std::vector<int> row_group_ids);
+
+ private:
+ ParquetFileFragment(FileSource source, std::shared_ptr<FileFormat> format,
+ compute::Expression partition_expression,
+ std::shared_ptr<Schema> physical_schema,
+ util::optional<std::vector<int>> row_groups);
+
+ Status SetMetadata(std::shared_ptr<parquet::FileMetaData> metadata,
+ std::shared_ptr<parquet::arrow::SchemaManifest> manifest);
+
+ // Overridden to opportunistically set metadata since a reader must be opened anyway.
+ Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override {
+ ARROW_RETURN_NOT_OK(EnsureCompleteMetadata());
+ return physical_schema_;
+ }
+
+ /// Return a filtered subset of row group indices.
+ Result<std::vector<int>> FilterRowGroups(compute::Expression predicate);
+ /// Simplify the predicate against the statistics of each row group.
+ Result<std::vector<compute::Expression>> TestRowGroups(compute::Expression predicate);
+ /// Try to count rows matching the predicate using metadata. Expects
+ /// metadata to be present, and expects the predicate to have been
+ /// simplified against the partition expression already.
+ Result<util::optional<int64_t>> TryCountRows(compute::Expression predicate);
+
+ ParquetFileFormat& parquet_format_;
+
+ /// Indices of row groups selected by this fragment,
+ /// or util::nullopt if all row groups are selected.
+ util::optional<std::vector<int>> row_groups_;
+
+ std::vector<compute::Expression> statistics_expressions_;
+ std::vector<bool> statistics_expressions_complete_;
+ std::shared_ptr<parquet::FileMetaData> metadata_;
+ std::shared_ptr<parquet::arrow::SchemaManifest> manifest_;
+
+ friend class ParquetFileFormat;
+ friend class ParquetDatasetFactory;
+};
+
+/// \brief Per-scan options for Parquet fragments
+class ARROW_DS_EXPORT ParquetFragmentScanOptions : public FragmentScanOptions {
+ public:
+ ParquetFragmentScanOptions();
+ std::string type_name() const override { return kParquetTypeName; }
+
+ /// Reader properties. Not all properties are respected: memory_pool comes from
+ /// ScanOptions.
+ std::shared_ptr<parquet::ReaderProperties> reader_properties;
+ /// Arrow reader properties. Not all properties are respected: batch_size comes from
+ /// ScanOptions, and use_threads will be overridden based on
+ /// enable_parallel_column_conversion. Additionally, dictionary columns come from
+ /// ParquetFileFormat::ReaderOptions::dict_columns.
+ std::shared_ptr<parquet::ArrowReaderProperties> arrow_reader_properties;
+ /// EXPERIMENTAL: Parallelize conversion across columns. This option is ignored if a
+ /// scan is already parallelized across input files to avoid thread contention. This
+ /// option will be removed after support is added for simultaneous parallelization
+ /// across files and columns. Only affects the threaded reader; the async reader
+ /// will parallelize across columns if use_threads is enabled.
+ bool enable_parallel_column_conversion = false;
+};
+
+class ARROW_DS_EXPORT ParquetFileWriteOptions : public FileWriteOptions {
+ public:
+ /// \brief Parquet writer properties.
+ std::shared_ptr<parquet::WriterProperties> writer_properties;
+
+ /// \brief Parquet Arrow writer properties.
+ std::shared_ptr<parquet::ArrowWriterProperties> arrow_writer_properties;
+
+ protected:
+ using FileWriteOptions::FileWriteOptions;
+
+ friend class ParquetFileFormat;
+};
+
+class ARROW_DS_EXPORT ParquetFileWriter : public FileWriter {
+ public:
+ const std::shared_ptr<parquet::arrow::FileWriter>& parquet_writer() const {
+ return parquet_writer_;
+ }
+
+ Status Write(const std::shared_ptr<RecordBatch>& batch) override;
+
+ private:
+ ParquetFileWriter(std::shared_ptr<io::OutputStream> destination,
+ std::shared_ptr<parquet::arrow::FileWriter> writer,
+ std::shared_ptr<ParquetFileWriteOptions> options,
+ fs::FileLocator destination_locator);
+
+ Status FinishInternal() override;
+
+ std::shared_ptr<parquet::arrow::FileWriter> parquet_writer_;
+
+ friend class ParquetFileFormat;
+};
+
+/// \brief Options for making a FileSystemDataset from a Parquet _metadata file.
+struct ParquetFactoryOptions {
+ /// Either an explicit Partitioning or a PartitioningFactory to discover one.
+ ///
+ /// If a factory is provided, it will be used to infer a schema for partition fields
+ /// based on file and directory paths then construct a Partitioning. The default
+ /// is a Partitioning which will yield no partition information.
+ ///
+ /// The (explicit or discovered) partitioning will be applied to discovered files
+ /// and the resulting partition information embedded in the Dataset.
+ PartitioningOrFactory partitioning{Partitioning::Default()};
+
+ /// For the purposes of applying the partitioning, paths will be stripped
+ /// of the partition_base_dir. Files not matching the partition_base_dir
+ /// prefix will be skipped for partition discovery. The ignored files will still
+ /// be part of the Dataset, but will not have partition information.
+ ///
+ /// Example:
+ /// partition_base_dir = "/dataset";
+ ///
+ /// - "/dataset/US/sales.csv" -> "US/sales.csv" will be given to the partitioning
+ ///
+ /// - "/home/john/late_sales.csv" -> Will be ignored for partition discovery.
+ ///
+ /// This is useful for partitioning which parses directory when ordering
+ /// is important, e.g. DirectoryPartitioning.
+ std::string partition_base_dir;
+
+ /// Assert that all ColumnChunk paths are consistent. The parquet spec allows for
+ /// ColumnChunk data to be stored in multiple files, but ParquetDatasetFactory
+ /// supports only a single file with all ColumnChunk data. If this flag is set
+ /// construction of a ParquetDatasetFactory will raise an error if ColumnChunk
+ /// data is not resident in a single file.
+ bool validate_column_chunk_paths = false;
+};
+
+/// \brief Create FileSystemDataset from custom `_metadata` cache file.
+///
+/// Dask and other systems will generate a cache metadata file by concatenating
+/// the RowGroupMetaData of multiple parquet files into a single parquet file
+/// that only contains metadata and no ColumnChunk data.
+///
+/// ParquetDatasetFactory creates a FileSystemDataset composed of
+/// ParquetFileFragment where each fragment is pre-populated with the exact
+/// number of row groups and statistics for each columns.
+class ARROW_DS_EXPORT ParquetDatasetFactory : public DatasetFactory {
+ public:
+ /// \brief Create a ParquetDatasetFactory from a metadata path.
+ ///
+ /// The `metadata_path` will be read from `filesystem`. Each RowGroup
+ /// contained in the metadata file will be relative to `dirname(metadata_path)`.
+ ///
+ /// \param[in] metadata_path path of the metadata parquet file
+ /// \param[in] filesystem from which to open/read the path
+ /// \param[in] format to read the file with.
+ /// \param[in] options see ParquetFactoryOptions
+ static Result<std::shared_ptr<DatasetFactory>> Make(
+ const std::string& metadata_path, std::shared_ptr<fs::FileSystem> filesystem,
+ std::shared_ptr<ParquetFileFormat> format, ParquetFactoryOptions options);
+
+ /// \brief Create a ParquetDatasetFactory from a metadata source.
+ ///
+ /// Similar to the previous Make definition, but the metadata can be a Buffer
+ /// and the base_path is explicited instead of inferred from the metadata
+ /// path.
+ ///
+ /// \param[in] metadata source to open the metadata parquet file from
+ /// \param[in] base_path used as the prefix of every parquet files referenced
+ /// \param[in] filesystem from which to read the files referenced.
+ /// \param[in] format to read the file with.
+ /// \param[in] options see ParquetFactoryOptions
+ static Result<std::shared_ptr<DatasetFactory>> Make(
+ const FileSource& metadata, const std::string& base_path,
+ std::shared_ptr<fs::FileSystem> filesystem,
+ std::shared_ptr<ParquetFileFormat> format, ParquetFactoryOptions options);
+
+ Result<std::vector<std::shared_ptr<Schema>>> InspectSchemas(
+ InspectOptions options) override;
+
+ Result<std::shared_ptr<Dataset>> Finish(FinishOptions options) override;
+
+ protected:
+ ParquetDatasetFactory(
+ std::shared_ptr<fs::FileSystem> filesystem,
+ std::shared_ptr<ParquetFileFormat> format,
+ std::shared_ptr<parquet::FileMetaData> metadata,
+ std::shared_ptr<parquet::arrow::SchemaManifest> manifest,
+ std::shared_ptr<Schema> physical_schema, std::string base_path,
+ ParquetFactoryOptions options,
+ std::vector<std::pair<std::string, std::vector<int>>> paths_with_row_group_ids)
+ : filesystem_(std::move(filesystem)),
+ format_(std::move(format)),
+ metadata_(std::move(metadata)),
+ manifest_(std::move(manifest)),
+ physical_schema_(std::move(physical_schema)),
+ base_path_(std::move(base_path)),
+ options_(std::move(options)),
+ paths_with_row_group_ids_(std::move(paths_with_row_group_ids)) {}
+
+ std::shared_ptr<fs::FileSystem> filesystem_;
+ std::shared_ptr<ParquetFileFormat> format_;
+ std::shared_ptr<parquet::FileMetaData> metadata_;
+ std::shared_ptr<parquet::arrow::SchemaManifest> manifest_;
+ std::shared_ptr<Schema> physical_schema_;
+ std::string base_path_;
+ ParquetFactoryOptions options_;
+ std::vector<std::pair<std::string, std::vector<int>>> paths_with_row_group_ids_;
+
+ private:
+ Result<std::vector<std::shared_ptr<FileFragment>>> CollectParquetFragments(
+ const Partitioning& partitioning);
+
+ Result<std::shared_ptr<Schema>> PartitionSchema();
+};
+
+/// @}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_parquet_test.cc b/src/arrow/cpp/src/arrow/dataset/file_parquet_test.cc
new file mode 100644
index 000000000..1612f1d25
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_parquet_test.cc
@@ -0,0 +1,612 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/file_parquet.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/dataset/dataset_internal.h"
+#include "arrow/dataset/scanner_internal.h"
+#include "arrow/dataset/test_util.h"
+#include "arrow/io/memory.h"
+#include "arrow/io/util_internal.h"
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/range.h"
+
+#include "parquet/arrow/writer.h"
+#include "parquet/metadata.h"
+
+namespace arrow {
+
+using internal::checked_pointer_cast;
+
+namespace dataset {
+
+using parquet::ArrowWriterProperties;
+using parquet::default_arrow_writer_properties;
+
+using parquet::default_writer_properties;
+using parquet::WriterProperties;
+
+using parquet::CreateOutputStream;
+using parquet::arrow::WriteTable;
+
+using testing::Pointee;
+
+class ParquetFormatHelper {
+ public:
+ using FormatType = ParquetFileFormat;
+
+ static Result<std::shared_ptr<Buffer>> Write(RecordBatchReader* reader) {
+ auto pool = ::arrow::default_memory_pool();
+ std::shared_ptr<Buffer> out;
+ auto sink = CreateOutputStream(pool);
+ RETURN_NOT_OK(WriteRecordBatchReader(reader, pool, sink));
+ return sink->Finish();
+ }
+ static std::shared_ptr<ParquetFileFormat> MakeFormat() {
+ return std::make_shared<ParquetFileFormat>();
+ }
+
+ private:
+ static Status WriteRecordBatch(const RecordBatch& batch,
+ parquet::arrow::FileWriter* writer) {
+ auto schema = batch.schema();
+ auto size = batch.num_rows();
+
+ if (!schema->Equals(*writer->schema(), false)) {
+ return Status::Invalid("RecordBatch schema does not match this writer's. batch:'",
+ schema->ToString(), "' this:'", writer->schema()->ToString(),
+ "'");
+ }
+
+ RETURN_NOT_OK(writer->NewRowGroup(size));
+ for (int i = 0; i < batch.num_columns(); i++) {
+ RETURN_NOT_OK(writer->WriteColumnChunk(*batch.column(i)));
+ }
+
+ return Status::OK();
+ }
+
+ static Status WriteRecordBatchReader(RecordBatchReader* reader,
+ parquet::arrow::FileWriter* writer) {
+ auto schema = reader->schema();
+
+ if (!schema->Equals(*writer->schema(), false)) {
+ return Status::Invalid("RecordBatch schema does not match this writer's. batch:'",
+ schema->ToString(), "' this:'", writer->schema()->ToString(),
+ "'");
+ }
+
+ return MakeFunctionIterator([reader] { return reader->Next(); })
+ .Visit([&](std::shared_ptr<RecordBatch> batch) {
+ return WriteRecordBatch(*batch, writer);
+ });
+ }
+
+ static Status WriteRecordBatchReader(
+ RecordBatchReader* reader, MemoryPool* pool,
+ const std::shared_ptr<io::OutputStream>& sink,
+ const std::shared_ptr<WriterProperties>& properties = default_writer_properties(),
+ const std::shared_ptr<ArrowWriterProperties>& arrow_properties =
+ default_arrow_writer_properties()) {
+ std::unique_ptr<parquet::arrow::FileWriter> writer;
+ RETURN_NOT_OK(parquet::arrow::FileWriter::Open(
+ *reader->schema(), pool, sink, properties, arrow_properties, &writer));
+ RETURN_NOT_OK(WriteRecordBatchReader(reader, writer.get()));
+ return writer->Close();
+ }
+};
+
+class TestParquetFileFormat : public FileFormatFixtureMixin<ParquetFormatHelper> {
+ public:
+ RecordBatchIterator Batches(ScanTaskIterator scan_task_it) {
+ return MakeFlattenIterator(MakeMaybeMapIterator(
+ [](std::shared_ptr<ScanTask> scan_task) { return scan_task->Execute(); },
+ std::move(scan_task_it)));
+ }
+
+ RecordBatchIterator Batches(Fragment* fragment) {
+ EXPECT_OK_AND_ASSIGN(auto scan_task_it, fragment->Scan(opts_));
+ return Batches(std::move(scan_task_it));
+ }
+
+ std::shared_ptr<RecordBatch> SingleBatch(Fragment* fragment) {
+ auto batches = IteratorToVector(Batches(fragment));
+ EXPECT_EQ(batches.size(), 1);
+ return batches.front();
+ }
+
+ void CountRowsAndBatchesInScan(Fragment* fragment, int64_t expected_rows,
+ int64_t expected_batches) {
+ int64_t actual_rows = 0;
+ int64_t actual_batches = 0;
+
+ for (auto maybe_batch : Batches(fragment)) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ actual_rows += batch->num_rows();
+ ++actual_batches;
+ }
+
+ EXPECT_EQ(actual_rows, expected_rows);
+ EXPECT_EQ(actual_batches, expected_batches);
+ }
+
+ void CountRowsAndBatchesInScan(const std::shared_ptr<Fragment>& fragment,
+ int64_t expected_rows, int64_t expected_batches) {
+ return CountRowsAndBatchesInScan(fragment.get(), expected_rows, expected_batches);
+ }
+
+ void CountRowGroupsInFragment(const std::shared_ptr<Fragment>& fragment,
+ std::vector<int> expected_row_groups,
+ compute::Expression filter) {
+ SetFilter(filter);
+
+ auto parquet_fragment = checked_pointer_cast<ParquetFileFragment>(fragment);
+ ASSERT_OK_AND_ASSIGN(auto fragments, parquet_fragment->SplitByRowGroup(opts_->filter))
+
+ EXPECT_EQ(fragments.size(), expected_row_groups.size());
+ for (size_t i = 0; i < fragments.size(); i++) {
+ auto expected = expected_row_groups[i];
+ auto parquet_fragment = checked_pointer_cast<ParquetFileFragment>(fragments[i]);
+
+ EXPECT_EQ(parquet_fragment->row_groups(), std::vector<int>{expected});
+ EXPECT_EQ(SingleBatch(parquet_fragment.get())->num_rows(), expected + 1);
+ }
+ }
+};
+
+TEST_F(TestParquetFileFormat, InspectFailureWithRelevantError) {
+ TestInspectFailureWithRelevantError(StatusCode::Invalid, "Parquet");
+}
+TEST_F(TestParquetFileFormat, Inspect) { TestInspect(); }
+
+TEST_F(TestParquetFileFormat, InspectDictEncoded) {
+ auto reader = GetRecordBatchReader(schema({field("utf8", utf8())}));
+ auto source = GetFileSource(reader.get());
+
+ format_->reader_options.dict_columns = {"utf8"};
+ ASSERT_OK_AND_ASSIGN(auto actual, format_->Inspect(*source.get()));
+
+ Schema expected_schema({field("utf8", dictionary(int32(), utf8()))});
+ AssertSchemaEqual(*actual, expected_schema, /* check_metadata = */ false);
+}
+
+TEST_F(TestParquetFileFormat, IsSupported) { TestIsSupported(); }
+
+TEST_F(TestParquetFileFormat, WriteRecordBatchReader) { TestWrite(); }
+
+TEST_F(TestParquetFileFormat, WriteRecordBatchReaderCustomOptions) {
+ TimeUnit::type coerce_timestamps_to = TimeUnit::MICRO,
+ coerce_timestamps_from = TimeUnit::NANO;
+
+ auto reader =
+ GetRecordBatchReader(schema({field("ts", timestamp(coerce_timestamps_from))}));
+ auto options =
+ checked_pointer_cast<ParquetFileWriteOptions>(format_->DefaultWriteOptions());
+ options->writer_properties = parquet::WriterProperties::Builder()
+ .created_by("TestParquetFileFormat")
+ ->disable_statistics()
+ ->build();
+ options->arrow_writer_properties = parquet::ArrowWriterProperties::Builder()
+ .coerce_timestamps(coerce_timestamps_to)
+ ->allow_truncated_timestamps()
+ ->build();
+
+ auto written = WriteToBuffer(reader->schema(), options);
+
+ EXPECT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(FileSource{written}));
+ EXPECT_OK_AND_ASSIGN(auto actual_schema, fragment->ReadPhysicalSchema());
+ AssertSchemaEqual(Schema({field("ts", timestamp(coerce_timestamps_to))}),
+ *actual_schema);
+}
+
+TEST_F(TestParquetFileFormat, CountRows) { TestCountRows(); }
+
+TEST_F(TestParquetFileFormat, CountRowsPredicatePushdown) {
+ constexpr int64_t kNumRowGroups = 16;
+ constexpr int64_t kTotalNumRows = kNumRowGroups * (kNumRowGroups + 1) / 2;
+
+ // See PredicatePushdown test below for a description of the generated data
+ auto reader = ArithmeticDatasetFixture::GetRecordBatchReader(kNumRowGroups);
+ auto source = GetFileSource(reader.get());
+ auto options = std::make_shared<ScanOptions>();
+
+ auto fragment = MakeFragment(*source);
+
+ ASSERT_FINISHES_OK_AND_EQ(util::make_optional<int64_t>(kTotalNumRows),
+ fragment->CountRows(literal(true), options));
+
+ for (int i = 1; i <= kNumRowGroups; i++) {
+ SCOPED_TRACE(i);
+ // The row group for which all values in column i64 == i has i rows
+ auto predicate = less_equal(field_ref("i64"), literal(i));
+ ASSERT_OK_AND_ASSIGN(predicate, predicate.Bind(*reader->schema()));
+ auto expected = i * (i + 1) / 2;
+ ASSERT_FINISHES_OK_AND_EQ(util::make_optional<int64_t>(expected),
+ fragment->CountRows(predicate, options));
+
+ predicate = and_(less_equal(field_ref("i64"), literal(i)),
+ greater_equal(field_ref("i64"), literal(i)));
+ ASSERT_OK_AND_ASSIGN(predicate, predicate.Bind(*reader->schema()));
+ ASSERT_FINISHES_OK_AND_EQ(util::make_optional<int64_t>(i),
+ fragment->CountRows(predicate, options));
+
+ predicate = equal(field_ref("i64"), literal(i));
+ ASSERT_OK_AND_ASSIGN(predicate, predicate.Bind(*reader->schema()));
+ ASSERT_FINISHES_OK_AND_EQ(util::make_optional<int64_t>(i),
+ fragment->CountRows(predicate, options));
+ }
+
+ // Ensure nulls are properly handled
+ {
+ auto dataset_schema = schema({field("i64", int64())});
+ auto null_batch = RecordBatchFromJSON(dataset_schema, R"([
+[null],
+[null],
+[null]
+])");
+ auto batch = RecordBatchFromJSON(dataset_schema, R"([
+[1],
+[2]
+])");
+ ASSERT_OK_AND_ASSIGN(auto reader,
+ RecordBatchReader::Make({null_batch, batch}, dataset_schema));
+ auto source = GetFileSource(reader.get());
+ auto fragment = MakeFragment(*source);
+ ASSERT_OK_AND_ASSIGN(
+ auto predicate,
+ greater_equal(field_ref("i64"), literal(1)).Bind(*dataset_schema));
+ ASSERT_FINISHES_OK_AND_EQ(util::make_optional<int64_t>(2),
+ fragment->CountRows(predicate, options));
+ // TODO(ARROW-12659): SimplifyWithGuarantee can't handle
+ // not(is_null) so trying to count with is_null doesn't work
+ }
+}
+
+TEST_F(TestParquetFileFormat, MultithreadedScan) {
+ constexpr int64_t kNumRowGroups = 16;
+
+ // See PredicatePushdown test below for a description of the generated data
+ auto reader = ArithmeticDatasetFixture::GetRecordBatchReader(kNumRowGroups);
+ auto source = GetFileSource(reader.get());
+ auto options = std::make_shared<ScanOptions>();
+
+ auto fragment = MakeFragment(*source);
+
+ FragmentDataset dataset(ArithmeticDatasetFixture::schema(), {fragment});
+ ScannerBuilder builder({&dataset, [](...) {}});
+
+ ASSERT_OK(builder.UseAsync(true));
+ ASSERT_OK(builder.UseThreads(true));
+ ASSERT_OK(builder.Project({call("add", {field_ref("i64"), literal(3)})}, {""}));
+ ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+
+ ASSERT_OK_AND_ASSIGN(auto gen, scanner->ScanBatchesUnorderedAsync());
+
+ auto collect_fut = CollectAsyncGenerator(gen);
+ ASSERT_OK_AND_ASSIGN(auto batches, collect_fut.result());
+
+ ASSERT_EQ(batches.size(), kNumRowGroups);
+}
+
+class TestParquetFileSystemDataset : public WriteFileSystemDatasetMixin,
+ public testing::Test {
+ public:
+ void SetUp() override {
+ MakeSourceDataset();
+ check_metadata_ = false;
+ auto parquet_format = std::make_shared<ParquetFileFormat>();
+ format_ = parquet_format;
+ SetWriteOptions(parquet_format->DefaultWriteOptions());
+ }
+};
+
+TEST_F(TestParquetFileSystemDataset, WriteWithIdenticalPartitioningSchema) {
+ TestWriteWithIdenticalPartitioningSchema();
+}
+
+TEST_F(TestParquetFileSystemDataset, WriteWithUnrelatedPartitioningSchema) {
+ TestWriteWithUnrelatedPartitioningSchema();
+}
+
+TEST_F(TestParquetFileSystemDataset, WriteWithSupersetPartitioningSchema) {
+ TestWriteWithSupersetPartitioningSchema();
+}
+
+TEST_F(TestParquetFileSystemDataset, WriteWithEmptyPartitioningSchema) {
+ TestWriteWithEmptyPartitioningSchema();
+}
+
+class TestParquetFileFormatScan : public FileFormatScanMixin<ParquetFormatHelper> {
+ public:
+ std::shared_ptr<RecordBatch> SingleBatch(std::shared_ptr<Fragment> fragment) {
+ auto batches = IteratorToVector(PhysicalBatches(fragment));
+ EXPECT_EQ(batches.size(), 1);
+ return batches.front();
+ }
+
+ void CountRowsAndBatchesInScan(std::shared_ptr<Fragment> fragment,
+ int64_t expected_rows, int64_t expected_batches) {
+ int64_t actual_rows = 0;
+ int64_t actual_batches = 0;
+
+ for (auto maybe_batch : PhysicalBatches(fragment)) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ actual_rows += batch->num_rows();
+ ++actual_batches;
+ }
+
+ EXPECT_EQ(actual_rows, expected_rows);
+ EXPECT_EQ(actual_batches, expected_batches);
+ }
+
+ void CountRowGroupsInFragment(const std::shared_ptr<Fragment>& fragment,
+ std::vector<int> expected_row_groups,
+ compute::Expression filter) {
+ SetFilter(filter);
+
+ auto parquet_fragment = checked_pointer_cast<ParquetFileFragment>(fragment);
+ ASSERT_OK_AND_ASSIGN(auto fragments, parquet_fragment->SplitByRowGroup(opts_->filter))
+
+ EXPECT_EQ(fragments.size(), expected_row_groups.size());
+ for (size_t i = 0; i < fragments.size(); i++) {
+ auto expected = expected_row_groups[i];
+ auto parquet_fragment = checked_pointer_cast<ParquetFileFragment>(fragments[i]);
+
+ EXPECT_EQ(parquet_fragment->row_groups(), std::vector<int>{expected});
+ EXPECT_EQ(SingleBatch(parquet_fragment)->num_rows(), expected + 1);
+ }
+ }
+};
+
+TEST_P(TestParquetFileFormatScan, ScanRecordBatchReader) { TestScan(); }
+TEST_P(TestParquetFileFormatScan, ScanBatchSize) { TestScanBatchSize(); }
+TEST_P(TestParquetFileFormatScan, ScanRecordBatchReaderProjected) { TestScanProjected(); }
+TEST_P(TestParquetFileFormatScan, ScanRecordBatchReaderProjectedMissingCols) {
+ TestScanProjectedMissingCols();
+}
+TEST_P(TestParquetFileFormatScan, ScanRecordBatchReaderDictEncoded) {
+ auto reader = GetRecordBatchReader(schema({field("utf8", utf8())}));
+ auto source = GetFileSource(reader.get());
+
+ SetSchema(reader->schema()->fields());
+ SetFilter(literal(true));
+ format_->reader_options.dict_columns = {"utf8"};
+ ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source));
+
+ int64_t row_count = 0;
+ Schema expected_schema({field("utf8", dictionary(int32(), utf8()))});
+
+ for (auto maybe_batch : PhysicalBatches(fragment)) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ row_count += batch->num_rows();
+ AssertSchemaEqual(*batch->schema(), expected_schema, /* check_metadata = */ false);
+ }
+ ASSERT_EQ(row_count, expected_rows());
+}
+TEST_P(TestParquetFileFormatScan, ScanRecordBatchReaderPreBuffer) {
+ auto reader = GetRecordBatchReader(schema({field("f64", float64())}));
+ auto source = GetFileSource(reader.get());
+
+ SetSchema(reader->schema()->fields());
+ SetFilter(literal(true));
+
+ ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source));
+ auto fragment_scan_options = std::make_shared<ParquetFragmentScanOptions>();
+ fragment_scan_options->arrow_reader_properties->set_pre_buffer(true);
+ opts_->fragment_scan_options = fragment_scan_options;
+ ASSERT_OK_AND_ASSIGN(auto scan_task_it, fragment->Scan(opts_));
+
+ int64_t row_count = 0;
+ for (auto maybe_batch : PhysicalBatches(fragment)) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ row_count += batch->num_rows();
+ }
+ ASSERT_EQ(row_count, expected_rows());
+}
+TEST_P(TestParquetFileFormatScan, PredicatePushdown) {
+ // Given a number `n`, the arithmetic dataset creates n RecordBatches where
+ // each RecordBatch is keyed by a unique integer in [1, n]. Let `rb_i` denote
+ // the record batch keyed by `i`. `rb_i` is composed of `i` rows where all
+ // values are a variant of `i`, e.g. {"i64": i, "u8": i, ... }.
+ //
+ // Thus the ArithmeticDataset(n) has n RecordBatches and the total number of
+ // rows is n(n+1)/2.
+ //
+ // This test uses the Fragment directly, and so no post-filtering is
+ // applied via ScanOptions' evaluator. Thus, counting the number of returned
+ // rows and returned row groups is a good enough proxy to check if pushdown
+ // predicate is working.
+
+ constexpr int64_t kNumRowGroups = 16;
+ constexpr int64_t kTotalNumRows = kNumRowGroups * (kNumRowGroups + 1) / 2;
+
+ auto reader = ArithmeticDatasetFixture::GetRecordBatchReader(kNumRowGroups);
+ auto source = GetFileSource(reader.get());
+
+ SetSchema(reader->schema()->fields());
+ ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source));
+
+ SetFilter(literal(true));
+ CountRowsAndBatchesInScan(fragment, kTotalNumRows, kNumRowGroups);
+
+ for (int64_t i = 1; i <= kNumRowGroups; i++) {
+ SetFilter(equal(field_ref("i64"), literal(i)));
+ CountRowsAndBatchesInScan(fragment, i, 1);
+ }
+
+ // Out of bound filters should skip all RowGroups.
+ SetFilter(literal(false));
+ CountRowsAndBatchesInScan(fragment, 0, 0);
+ SetFilter(equal(field_ref("i64"), literal<int64_t>(kNumRowGroups + 1)));
+ CountRowsAndBatchesInScan(fragment, 0, 0);
+ SetFilter(equal(field_ref("i64"), literal<int64_t>(-1)));
+ CountRowsAndBatchesInScan(fragment, 0, 0);
+ // No rows match 1 and 2.
+ SetFilter(and_(equal(field_ref("i64"), literal<int64_t>(1)),
+ equal(field_ref("u8"), literal<uint8_t>(2))));
+ CountRowsAndBatchesInScan(fragment, 0, 0);
+
+ SetFilter(or_(equal(field_ref("i64"), literal<int64_t>(2)),
+ equal(field_ref("i64"), literal<int64_t>(4))));
+ CountRowsAndBatchesInScan(fragment, 2 + 4, 2);
+
+ SetFilter(less(field_ref("i64"), literal<int64_t>(6)));
+ CountRowsAndBatchesInScan(fragment, 5 * (5 + 1) / 2, 5);
+
+ SetFilter(greater_equal(field_ref("i64"), literal<int64_t>(6)));
+ CountRowsAndBatchesInScan(fragment, kTotalNumRows - (5 * (5 + 1) / 2),
+ kNumRowGroups - 5);
+}
+
+TEST_P(TestParquetFileFormatScan, PredicatePushdownRowGroupFragments) {
+ constexpr int64_t kNumRowGroups = 16;
+
+ auto reader = ArithmeticDatasetFixture::GetRecordBatchReader(kNumRowGroups);
+ auto source = GetFileSource(reader.get());
+
+ SetSchema(reader->schema()->fields());
+ ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source));
+
+ auto all_row_groups = ::arrow::internal::Iota(static_cast<int>(kNumRowGroups));
+ CountRowGroupsInFragment(fragment, all_row_groups, literal(true));
+
+ for (int i = 0; i < kNumRowGroups; ++i) {
+ CountRowGroupsInFragment(fragment, {i}, equal(field_ref("i64"), literal(i + 1)));
+ }
+
+ // Out of bound filters should skip all RowGroups.
+ CountRowGroupsInFragment(fragment, {}, literal(false));
+ CountRowGroupsInFragment(fragment, {},
+ equal(field_ref("i64"), literal(kNumRowGroups + 1)));
+ CountRowGroupsInFragment(fragment, {}, equal(field_ref("i64"), literal(-1)));
+
+ // No rows match 1 and 2.
+ CountRowGroupsInFragment(
+ fragment, {},
+ and_(equal(field_ref("i64"), literal(1)), equal(field_ref("u8"), literal(2))));
+ CountRowGroupsInFragment(
+ fragment, {},
+ and_(equal(field_ref("i64"), literal(2)), equal(field_ref("i64"), literal(4))));
+
+ CountRowGroupsInFragment(
+ fragment, {1, 3},
+ or_(equal(field_ref("i64"), literal(2)), equal(field_ref("i64"), literal(4))));
+
+ auto set = ArrayFromJSON(int64(), "[2, 4]");
+ CountRowGroupsInFragment(
+ fragment, {1, 3},
+ call("is_in", {field_ref("i64")}, compute::SetLookupOptions{set}));
+
+ CountRowGroupsInFragment(fragment, {0, 1, 2, 3, 4}, less(field_ref("i64"), literal(6)));
+
+ CountRowGroupsInFragment(fragment,
+ ::arrow::internal::Iota(5, static_cast<int>(kNumRowGroups)),
+ greater_equal(field_ref("i64"), literal(6)));
+
+ CountRowGroupsInFragment(fragment, {5, 6},
+ and_(greater_equal(field_ref("i64"), literal(6)),
+ less(field_ref("i64"), literal(8))));
+}
+
+TEST_P(TestParquetFileFormatScan, ExplicitRowGroupSelection) {
+ constexpr int64_t kNumRowGroups = 16;
+ constexpr int64_t kTotalNumRows = kNumRowGroups * (kNumRowGroups + 1) / 2;
+
+ auto reader = ArithmeticDatasetFixture::GetRecordBatchReader(kNumRowGroups);
+ auto source = GetFileSource(reader.get());
+
+ SetSchema(reader->schema()->fields());
+ SetFilter(literal(true));
+
+ auto row_groups_fragment = [&](std::vector<int> row_groups) {
+ EXPECT_OK_AND_ASSIGN(auto fragment,
+ format_->MakeFragment(*source, literal(true),
+ /*physical_schema=*/nullptr, row_groups));
+ return fragment;
+ };
+
+ // select all row groups
+ EXPECT_OK_AND_ASSIGN(auto all_row_groups_fragment,
+ format_->MakeFragment(*source, literal(true))
+ .Map([](std::shared_ptr<FileFragment> f) {
+ return checked_pointer_cast<ParquetFileFragment>(f);
+ }));
+
+ EXPECT_EQ(all_row_groups_fragment->row_groups(), std::vector<int>{});
+
+ ARROW_EXPECT_OK(all_row_groups_fragment->EnsureCompleteMetadata());
+ CountRowsAndBatchesInScan(all_row_groups_fragment, kTotalNumRows, kNumRowGroups);
+
+ // individual selection selects a single row group
+ for (int i = 0; i < kNumRowGroups; ++i) {
+ CountRowsAndBatchesInScan(row_groups_fragment({i}), i + 1, 1);
+ EXPECT_EQ(row_groups_fragment({i})->row_groups(), std::vector<int>{i});
+ }
+
+ for (int i = 0; i < kNumRowGroups; ++i) {
+ // conflicting selection/filter
+ SetFilter(equal(field_ref("i64"), literal(i)));
+ CountRowsAndBatchesInScan(row_groups_fragment({i}), 0, 0);
+ }
+
+ for (int i = 0; i < kNumRowGroups; ++i) {
+ // identical selection/filter
+ SetFilter(equal(field_ref("i64"), literal(i + 1)));
+ CountRowsAndBatchesInScan(row_groups_fragment({i}), i + 1, 1);
+ }
+
+ SetFilter(greater(field_ref("i64"), literal(3)));
+ CountRowsAndBatchesInScan(row_groups_fragment({2, 3, 4, 5}), 4 + 5 + 6, 3);
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ IndexError,
+ testing::HasSubstr("only has " + std::to_string(kNumRowGroups) + " row groups"),
+ row_groups_fragment({kNumRowGroups + 1})->Scan(opts_));
+}
+
+TEST_P(TestParquetFileFormatScan, PredicatePushdownRowGroupFragmentsUsingStringColumn) {
+ auto table = TableFromJSON(schema({field("x", utf8())}),
+ {
+ R"([{"x": "a"}])",
+ R"([{"x": "b"}, {"x": "b"}])",
+ R"([{"x": "c"}, {"x": "c"}, {"x": "c"}])",
+ R"([{"x": "a"}, {"x": "b"}, {"x": "c"}, {"x": "d"}])",
+ });
+ TableBatchReader reader(*table);
+ auto source = GetFileSource(&reader);
+
+ SetSchema(reader.schema()->fields());
+ ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source));
+
+ CountRowGroupsInFragment(fragment, {0, 3}, equal(field_ref("x"), literal("a")));
+}
+
+INSTANTIATE_TEST_SUITE_P(TestScan, TestParquetFileFormatScan,
+ ::testing::ValuesIn(TestFormatParams::Values()),
+ TestFormatParams::ToTestNameString);
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/file_test.cc b/src/arrow/cpp/src/arrow/dataset/file_test.cc
new file mode 100644
index 000000000..db2eedbcf
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/file_test.cc
@@ -0,0 +1,346 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/dataset/api.h"
+#include "arrow/dataset/partition.h"
+#include "arrow/dataset/test_util.h"
+#include "arrow/filesystem/path_util.h"
+#include "arrow/filesystem/test_util.h"
+#include "arrow/status.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/io_util.h"
+
+namespace arrow {
+
+using internal::TemporaryDir;
+
+namespace dataset {
+
+using fs::internal::GetAbstractPathExtension;
+using testing::ContainerEq;
+
+TEST(FileSource, PathBased) {
+ auto localfs = std::make_shared<fs::LocalFileSystem>();
+
+ std::string p1 = "/path/to/file.ext";
+ std::string p2 = "/path/to/file.ext.gz";
+
+ FileSource source1(p1, localfs);
+ FileSource source2(p2, localfs, Compression::GZIP);
+
+ ASSERT_EQ(p1, source1.path());
+ ASSERT_TRUE(localfs->Equals(*source1.filesystem()));
+ ASSERT_EQ(Compression::UNCOMPRESSED, source1.compression());
+
+ ASSERT_EQ(p2, source2.path());
+ ASSERT_TRUE(localfs->Equals(*source2.filesystem()));
+ ASSERT_EQ(Compression::GZIP, source2.compression());
+
+ // Test copy constructor and comparison
+ FileSource source3;
+ source3 = source1;
+ ASSERT_EQ(source1.path(), source3.path());
+ ASSERT_EQ(source1.filesystem(), source3.filesystem());
+}
+
+TEST(FileSource, BufferBased) {
+ std::string the_data = "this is the file contents";
+ auto buf = std::make_shared<Buffer>(the_data);
+
+ FileSource source1(buf);
+ FileSource source2(buf, Compression::LZ4);
+
+ ASSERT_TRUE(source1.buffer()->Equals(*buf));
+ ASSERT_EQ(Compression::UNCOMPRESSED, source1.compression());
+
+ ASSERT_TRUE(source2.buffer()->Equals(*buf));
+ ASSERT_EQ(Compression::LZ4, source2.compression());
+
+ FileSource source3;
+ source3 = source1;
+ ASSERT_EQ(source1.buffer(), source3.buffer());
+}
+
+constexpr int kNumScanTasks = 2;
+constexpr int kBatchesPerScanTask = 2;
+constexpr int kRowsPerBatch = 1024;
+class MockFileFormat : public FileFormat {
+ std::string type_name() const override { return "mock"; }
+ bool Equals(const FileFormat& other) const override { return false; }
+ Result<bool> IsSupported(const FileSource& source) const override { return true; }
+ Result<std::shared_ptr<Schema>> Inspect(const FileSource& source) const override {
+ return Status::NotImplemented("Not needed for test");
+ }
+ Result<std::shared_ptr<FileWriter>> MakeWriter(
+ std::shared_ptr<io::OutputStream> destination, std::shared_ptr<Schema> schema,
+ std::shared_ptr<FileWriteOptions> options,
+ fs::FileLocator destination_locator) const override {
+ return Status::NotImplemented("Not needed for test");
+ }
+ std::shared_ptr<FileWriteOptions> DefaultWriteOptions() override { return nullptr; }
+
+ Result<ScanTaskIterator> ScanFile(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& file) const override {
+ auto sch = schema({field("i32", int32())});
+ ScanTaskVector scan_tasks;
+ for (int i = 0; i < kNumScanTasks; i++) {
+ RecordBatchVector batches;
+ for (int j = 0; j < kBatchesPerScanTask; j++) {
+ batches.push_back(ConstantArrayGenerator::Zeroes(kRowsPerBatch, sch));
+ }
+ scan_tasks.push_back(std::make_shared<InMemoryScanTask>(
+ std::move(batches), std::make_shared<ScanOptions>(), nullptr));
+ }
+ return MakeVectorIterator(std::move(scan_tasks));
+ }
+};
+
+TEST(FileFormat, ScanAsync) {
+ MockFileFormat format;
+ auto scan_options = std::make_shared<ScanOptions>();
+ ASSERT_OK_AND_ASSIGN(auto batch_gen, format.ScanBatchesAsync(scan_options, nullptr));
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto batches, CollectAsyncGenerator(batch_gen));
+ ASSERT_EQ(kNumScanTasks * kBatchesPerScanTask, static_cast<int>(batches.size()));
+ for (int i = 0; i < kNumScanTasks * kBatchesPerScanTask; i++) {
+ ASSERT_EQ(kRowsPerBatch, batches[i]->num_rows());
+ }
+}
+
+TEST_F(TestFileSystemDataset, Basic) {
+ MakeDataset({});
+ AssertFragmentsAreFromPath(*dataset_->GetFragments(), {});
+
+ MakeDataset({fs::File("a"), fs::File("b"), fs::File("c")});
+ AssertFragmentsAreFromPath(*dataset_->GetFragments(), {"a", "b", "c"});
+ AssertFilesAre(dataset_, {"a", "b", "c"});
+
+ // Should not create fragment from directories.
+ MakeDataset({fs::Dir("A"), fs::Dir("A/B"), fs::File("A/a"), fs::File("A/B/b")});
+ AssertFragmentsAreFromPath(*dataset_->GetFragments(), {"A/a", "A/B/b"});
+ AssertFilesAre(dataset_, {"A/a", "A/B/b"});
+}
+
+TEST_F(TestFileSystemDataset, ReplaceSchema) {
+ auto schm = schema({field("i32", int32()), field("f64", float64())});
+ auto format = std::make_shared<DummyFileFormat>(schm);
+ ASSERT_OK_AND_ASSIGN(auto dataset,
+ FileSystemDataset::Make(schm, literal(true), format, nullptr, {}));
+
+ // drop field
+ ASSERT_OK(dataset->ReplaceSchema(schema({field("i32", int32())})).status());
+ // add nullable field (will be materialized as null during projection)
+ ASSERT_OK(dataset->ReplaceSchema(schema({field("str", utf8())})).status());
+ // incompatible type
+ ASSERT_RAISES(TypeError,
+ dataset->ReplaceSchema(schema({field("i32", utf8())})).status());
+ // incompatible nullability
+ ASSERT_RAISES(
+ TypeError,
+ dataset->ReplaceSchema(schema({field("f64", float64(), /*nullable=*/false)}))
+ .status());
+ // add non-nullable field
+ ASSERT_RAISES(TypeError,
+ dataset->ReplaceSchema(schema({field("str", utf8(), /*nullable=*/false)}))
+ .status());
+}
+
+TEST_F(TestFileSystemDataset, RootPartitionPruning) {
+ auto root_partition = equal(field_ref("i32"), literal(5));
+ MakeDataset({fs::File("a"), fs::File("b")}, root_partition, {},
+ schema({field("i32", int32()), field("f32", float32())}));
+
+ auto GetFragments = [&](compute::Expression filter) {
+ return *dataset_->GetFragments(*filter.Bind(*dataset_->schema()));
+ };
+
+ // Default filter should always return all data.
+ AssertFragmentsAreFromPath(*dataset_->GetFragments(), {"a", "b"});
+
+ // filter == partition
+ AssertFragmentsAreFromPath(GetFragments(root_partition), {"a", "b"});
+
+ // Same partition key, but non matching filter
+ AssertFragmentsAreFromPath(GetFragments(equal(field_ref("i32"), literal(6))), {});
+
+ AssertFragmentsAreFromPath(GetFragments(greater(field_ref("i32"), literal(1))),
+ {"a", "b"});
+
+ // different key shouldn't prune
+ AssertFragmentsAreFromPath(GetFragments(equal(field_ref("f32"), literal(3.F))),
+ {"a", "b"});
+
+ // No root partition: don't prune any fragments
+ MakeDataset({fs::File("a"), fs::File("b")}, literal(true), {},
+ schema({field("i32", int32()), field("f32", float32())}));
+ AssertFragmentsAreFromPath(GetFragments(equal(field_ref("f32"), literal(3.F))),
+ {"a", "b"});
+}
+
+TEST_F(TestFileSystemDataset, TreePartitionPruning) {
+ auto root_partition = equal(field_ref("country"), literal("US"));
+
+ std::vector<fs::FileInfo> regions = {
+ fs::Dir("NY"), fs::File("NY/New York"), fs::File("NY/Franklin"),
+ fs::Dir("CA"), fs::File("CA/San Francisco"), fs::File("CA/Franklin"),
+ };
+
+ std::vector<compute::Expression> partitions = {
+ equal(field_ref("state"), literal("NY")),
+
+ and_(equal(field_ref("state"), literal("NY")),
+ equal(field_ref("city"), literal("New York"))),
+
+ and_(equal(field_ref("state"), literal("NY")),
+ equal(field_ref("city"), literal("Franklin"))),
+
+ equal(field_ref("state"), literal("CA")),
+
+ and_(equal(field_ref("state"), literal("CA")),
+ equal(field_ref("city"), literal("San Francisco"))),
+
+ and_(equal(field_ref("state"), literal("CA")),
+ equal(field_ref("city"), literal("Franklin"))),
+ };
+
+ MakeDataset(
+ regions, root_partition, partitions,
+ schema({field("country", utf8()), field("state", utf8()), field("city", utf8())}));
+
+ std::vector<std::string> all_cities = {"CA/San Francisco", "CA/Franklin", "NY/New York",
+ "NY/Franklin"};
+ std::vector<std::string> ca_cities = {"CA/San Francisco", "CA/Franklin"};
+ std::vector<std::string> franklins = {"CA/Franklin", "NY/Franklin"};
+
+ // Default filter should always return all data.
+ AssertFragmentsAreFromPath(*dataset_->GetFragments(), all_cities);
+
+ auto GetFragments = [&](compute::Expression filter) {
+ return *dataset_->GetFragments(*filter.Bind(*dataset_->schema()));
+ };
+
+ // Dataset's partitions are respected
+ AssertFragmentsAreFromPath(GetFragments(equal(field_ref("country"), literal("US"))),
+ all_cities);
+ AssertFragmentsAreFromPath(GetFragments(equal(field_ref("country"), literal("FR"))),
+ {});
+
+ AssertFragmentsAreFromPath(GetFragments(equal(field_ref("state"), literal("CA"))),
+ ca_cities);
+
+ // Filter where no decisions can be made on inner nodes when filter don't
+ // apply to inner partitions.
+ AssertFragmentsAreFromPath(GetFragments(equal(field_ref("city"), literal("Franklin"))),
+ franklins);
+}
+
+TEST_F(TestFileSystemDataset, FragmentPartitions) {
+ auto root_partition = equal(field_ref("country"), literal("US"));
+ std::vector<fs::FileInfo> regions = {
+ fs::Dir("NY"), fs::File("NY/New York"), fs::File("NY/Franklin"),
+ fs::Dir("CA"), fs::File("CA/San Francisco"), fs::File("CA/Franklin"),
+ };
+
+ std::vector<compute::Expression> partitions = {
+ equal(field_ref("state"), literal("NY")),
+
+ and_(equal(field_ref("state"), literal("NY")),
+ equal(field_ref("city"), literal("New York"))),
+
+ and_(equal(field_ref("state"), literal("NY")),
+ equal(field_ref("city"), literal("Franklin"))),
+
+ equal(field_ref("state"), literal("CA")),
+
+ and_(equal(field_ref("state"), literal("CA")),
+ equal(field_ref("city"), literal("San Francisco"))),
+
+ and_(equal(field_ref("state"), literal("CA")),
+ equal(field_ref("city"), literal("Franklin"))),
+ };
+
+ MakeDataset(
+ regions, root_partition, partitions,
+ schema({field("country", utf8()), field("state", utf8()), field("city", utf8())}));
+
+ AssertFragmentsHavePartitionExpressions(
+ dataset_, {
+ and_(equal(field_ref("state"), literal("CA")),
+ equal(field_ref("city"), literal("San Francisco"))),
+ and_(equal(field_ref("state"), literal("CA")),
+ equal(field_ref("city"), literal("Franklin"))),
+ and_(equal(field_ref("state"), literal("NY")),
+ equal(field_ref("city"), literal("New York"))),
+ and_(equal(field_ref("state"), literal("NY")),
+ equal(field_ref("city"), literal("Franklin"))),
+ });
+}
+
+TEST_F(TestFileSystemDataset, WriteProjected) {
+ // Regression test for ARROW-12620
+ auto format = std::make_shared<IpcFileFormat>();
+ auto fs = std::make_shared<fs::internal::MockFileSystem>(fs::kNoTime);
+ FileSystemDatasetWriteOptions write_options;
+ write_options.file_write_options = format->DefaultWriteOptions();
+ write_options.filesystem = fs;
+ write_options.base_dir = "root";
+ write_options.partitioning = std::make_shared<HivePartitioning>(schema({}));
+ write_options.basename_template = "{i}.feather";
+
+ auto dataset_schema = schema({field("a", int64())});
+ RecordBatchVector batches{
+ ConstantArrayGenerator::Zeroes(kRowsPerBatch, dataset_schema)};
+ ASSERT_EQ(0, batches[0]->column(0)->null_count());
+ auto dataset = std::make_shared<InMemoryDataset>(dataset_schema, batches);
+ ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan());
+ ASSERT_OK(scanner_builder->Project(
+ {compute::call("add", {compute::field_ref("a"), compute::literal(1)})},
+ {"a_plus_one"}));
+ ASSERT_OK(scanner_builder->UseAsync(true));
+ ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
+
+ ASSERT_OK(FileSystemDataset::Write(write_options, scanner));
+
+ ASSERT_OK_AND_ASSIGN(auto dataset_factory, FileSystemDatasetFactory::Make(
+ fs, {"root/0.feather"}, format, {}));
+ ASSERT_OK_AND_ASSIGN(auto written_dataset, dataset_factory->Finish(FinishOptions{}));
+ auto expected_schema = schema({field("a_plus_one", int64())});
+ AssertSchemaEqual(*expected_schema, *written_dataset->schema());
+ ASSERT_OK_AND_ASSIGN(scanner_builder, written_dataset->NewScan());
+ ASSERT_OK_AND_ASSIGN(scanner, scanner_builder->Finish());
+ ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable());
+ auto col = table->column(0);
+ ASSERT_EQ(0, col->null_count());
+ for (auto chunk : col->chunks()) {
+ auto arr = std::dynamic_pointer_cast<Int64Array>(chunk);
+ for (auto val : *arr) {
+ ASSERT_TRUE(val.has_value());
+ ASSERT_EQ(1, *val);
+ }
+ }
+}
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/partition.cc b/src/arrow/cpp/src/arrow/dataset/partition.cc
new file mode 100644
index 000000000..8db9d7c84
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/partition.cc
@@ -0,0 +1,732 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/partition.h"
+
+#include <algorithm>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_dict.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/array/builder_dict.h"
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/exec/expression_internal.h"
+#include "arrow/dataset/dataset_internal.h"
+#include "arrow/filesystem/path_util.h"
+#include "arrow/scalar.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/uri.h"
+#include "arrow/util/utf8.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+using util::string_view;
+
+using internal::DictionaryMemoTable;
+
+namespace dataset {
+
+namespace {
+/// Apply UriUnescape, then ensure the results are valid UTF-8.
+Result<std::string> SafeUriUnescape(util::string_view encoded) {
+ auto decoded = ::arrow::internal::UriUnescape(encoded);
+ if (!util::ValidateUTF8(decoded)) {
+ return Status::Invalid("Partition segment was not valid UTF-8 after URL decoding: ",
+ encoded);
+ }
+ return decoded;
+}
+} // namespace
+
+std::shared_ptr<Partitioning> Partitioning::Default() {
+ class DefaultPartitioning : public Partitioning {
+ public:
+ DefaultPartitioning() : Partitioning(::arrow::schema({})) {}
+
+ std::string type_name() const override { return "default"; }
+
+ Result<compute::Expression> Parse(const std::string& path) const override {
+ return compute::literal(true);
+ }
+
+ Result<std::string> Format(const compute::Expression& expr) const override {
+ return Status::NotImplemented("formatting paths from ", type_name(),
+ " Partitioning");
+ }
+
+ Result<PartitionedBatches> Partition(
+ const std::shared_ptr<RecordBatch>& batch) const override {
+ return PartitionedBatches{{batch}, {compute::literal(true)}};
+ }
+ };
+
+ return std::make_shared<DefaultPartitioning>();
+}
+
+static Result<RecordBatchVector> ApplyGroupings(
+ const ListArray& groupings, const std::shared_ptr<RecordBatch>& batch) {
+ ARROW_ASSIGN_OR_RAISE(Datum sorted,
+ compute::Take(batch, groupings.data()->child_data[0]));
+
+ const auto& sorted_batch = *sorted.record_batch();
+
+ RecordBatchVector out(static_cast<size_t>(groupings.length()));
+ for (size_t i = 0; i < out.size(); ++i) {
+ out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i));
+ }
+
+ return out;
+}
+
+Result<Partitioning::PartitionedBatches> KeyValuePartitioning::Partition(
+ const std::shared_ptr<RecordBatch>& batch) const {
+ std::vector<int> key_indices;
+ int num_keys = 0;
+
+ // assemble vector of indices of fields in batch on which we'll partition
+ for (const auto& partition_field : schema_->fields()) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto match, FieldRef(partition_field->name()).FindOneOrNone(*batch->schema()))
+
+ if (match.empty()) continue;
+ key_indices.push_back(match[0]);
+ ++num_keys;
+ }
+
+ if (key_indices.empty()) {
+ // no fields to group by; return the whole batch
+ return PartitionedBatches{{batch}, {compute::literal(true)}};
+ }
+
+ // assemble an ExecBatch of the key columns
+ compute::ExecBatch key_batch({}, batch->num_rows());
+ for (int i : key_indices) {
+ key_batch.values.emplace_back(batch->column_data(i));
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto grouper,
+ compute::internal::Grouper::Make(key_batch.GetDescriptors()));
+
+ ARROW_ASSIGN_OR_RAISE(Datum id_batch, grouper->Consume(key_batch));
+
+ auto ids = id_batch.array_as<UInt32Array>();
+ ARROW_ASSIGN_OR_RAISE(auto groupings, compute::internal::Grouper::MakeGroupings(
+ *ids, grouper->num_groups()));
+
+ ARROW_ASSIGN_OR_RAISE(auto uniques, grouper->GetUniques());
+ ArrayVector unique_arrays(num_keys);
+ for (int i = 0; i < num_keys; ++i) {
+ unique_arrays[i] = uniques.values[i].make_array();
+ }
+
+ PartitionedBatches out;
+
+ // assemble partition expressions from the unique keys
+ out.expressions.resize(grouper->num_groups());
+ for (uint32_t group = 0; group < grouper->num_groups(); ++group) {
+ std::vector<compute::Expression> exprs(num_keys);
+
+ for (int i = 0; i < num_keys; ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto val, unique_arrays[i]->GetScalar(group));
+ const auto& name = batch->schema()->field(key_indices[i])->name();
+
+ exprs[i] = val->is_valid ? compute::equal(compute::field_ref(name),
+ compute::literal(std::move(val)))
+ : compute::is_null(compute::field_ref(name));
+ }
+ out.expressions[group] = and_(std::move(exprs));
+ }
+
+ // remove key columns from batch to which we'll be applying the groupings
+ auto rest = batch;
+ std::sort(key_indices.begin(), key_indices.end(), std::greater<int>());
+ for (int i : key_indices) {
+ // indices are in descending order; indices larger than i (which would be invalidated
+ // here) have already been handled
+ ARROW_ASSIGN_OR_RAISE(rest, rest->RemoveColumn(i));
+ }
+ ARROW_ASSIGN_OR_RAISE(out.batches, ApplyGroupings(*groupings, rest));
+
+ return out;
+}
+
+std::ostream& operator<<(std::ostream& os, SegmentEncoding segment_encoding) {
+ switch (segment_encoding) {
+ case SegmentEncoding::None:
+ os << "SegmentEncoding::None";
+ break;
+ case SegmentEncoding::Uri:
+ os << "SegmentEncoding::Uri";
+ break;
+ default:
+ os << "(invalid SegmentEncoding " << static_cast<int8_t>(segment_encoding) << ")";
+ break;
+ }
+ return os;
+}
+
+Result<compute::Expression> KeyValuePartitioning::ConvertKey(const Key& key) const {
+ ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(key.name).FindOneOrNone(*schema_));
+ if (match.empty()) {
+ return compute::literal(true);
+ }
+
+ auto field_index = match[0];
+ auto field = schema_->field(field_index);
+
+ std::shared_ptr<Scalar> converted;
+
+ if (!key.value.has_value()) {
+ return compute::is_null(compute::field_ref(field->name()));
+ } else if (field->type()->id() == Type::DICTIONARY) {
+ if (dictionaries_.empty() || dictionaries_[field_index] == nullptr) {
+ return Status::Invalid("No dictionary provided for dictionary field ",
+ field->ToString());
+ }
+
+ DictionaryScalar::ValueType value;
+ value.dictionary = dictionaries_[field_index];
+
+ const auto& dictionary_type = checked_cast<const DictionaryType&>(*field->type());
+ if (!value.dictionary->type()->Equals(dictionary_type.value_type())) {
+ return Status::TypeError("Dictionary supplied for field ", field->ToString(),
+ " had incorrect type ",
+ value.dictionary->type()->ToString());
+ }
+
+ // look up the partition value in the dictionary
+ ARROW_ASSIGN_OR_RAISE(converted, Scalar::Parse(value.dictionary->type(), *key.value));
+ ARROW_ASSIGN_OR_RAISE(auto index, compute::IndexIn(converted, value.dictionary));
+ auto to_index_type = compute::CastOptions::Safe(dictionary_type.index_type());
+ ARROW_ASSIGN_OR_RAISE(index, compute::Cast(index, to_index_type));
+ value.index = index.scalar();
+ if (!value.index->is_valid) {
+ return Status::Invalid("Dictionary supplied for field ", field->ToString(),
+ " does not contain '", *key.value, "'");
+ }
+ converted = std::make_shared<DictionaryScalar>(std::move(value), field->type());
+ } else {
+ ARROW_ASSIGN_OR_RAISE(converted, Scalar::Parse(field->type(), *key.value));
+ }
+
+ return compute::equal(compute::field_ref(field->name()),
+ compute::literal(std::move(converted)));
+}
+
+Result<compute::Expression> KeyValuePartitioning::Parse(const std::string& path) const {
+ std::vector<compute::Expression> expressions;
+
+ ARROW_ASSIGN_OR_RAISE(auto parsed, ParseKeys(path));
+ for (const Key& key : parsed) {
+ ARROW_ASSIGN_OR_RAISE(auto expr, ConvertKey(key));
+ if (expr == compute::literal(true)) continue;
+ expressions.push_back(std::move(expr));
+ }
+
+ return and_(std::move(expressions));
+}
+
+Result<std::string> KeyValuePartitioning::Format(const compute::Expression& expr) const {
+ ScalarVector values{static_cast<size_t>(schema_->num_fields()), nullptr};
+
+ ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr));
+ for (const auto& ref_value : known_values.map) {
+ if (!ref_value.second.is_scalar()) {
+ return Status::Invalid("non-scalar partition key ", ref_value.second.ToString());
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto match, ref_value.first.FindOneOrNone(*schema_));
+ if (match.empty()) continue;
+
+ auto value = ref_value.second.scalar();
+
+ const auto& field = schema_->field(match[0]);
+ if (!value->type->Equals(field->type())) {
+ if (value->is_valid) {
+ auto maybe_converted = compute::Cast(value, field->type());
+ if (!maybe_converted.ok()) {
+ return Status::TypeError("Error converting scalar ", value->ToString(),
+ " (of type ", *value->type,
+ ") to a partition key for ", field->ToString(), ": ",
+ maybe_converted.status().message());
+ }
+ value = maybe_converted->scalar();
+ } else {
+ value = MakeNullScalar(field->type());
+ }
+ }
+
+ if (value->type->id() == Type::DICTIONARY) {
+ ARROW_ASSIGN_OR_RAISE(
+ value, checked_cast<const DictionaryScalar&>(*value).GetEncodedValue());
+ }
+
+ values[match[0]] = std::move(value);
+ }
+
+ return FormatValues(values);
+}
+
+DirectoryPartitioning::DirectoryPartitioning(std::shared_ptr<Schema> schema,
+ ArrayVector dictionaries,
+ KeyValuePartitioningOptions options)
+ : KeyValuePartitioning(std::move(schema), std::move(dictionaries), options) {
+ util::InitializeUTF8();
+}
+
+Result<std::vector<KeyValuePartitioning::Key>> DirectoryPartitioning::ParseKeys(
+ const std::string& path) const {
+ std::vector<Key> keys;
+
+ int i = 0;
+ for (auto&& segment : fs::internal::SplitAbstractPath(path)) {
+ if (i >= schema_->num_fields()) break;
+
+ switch (options_.segment_encoding) {
+ case SegmentEncoding::None: {
+ if (ARROW_PREDICT_FALSE(!util::ValidateUTF8(segment))) {
+ return Status::Invalid("Partition segment was not valid UTF-8: ", segment);
+ }
+ keys.push_back({schema_->field(i++)->name(), std::move(segment)});
+ break;
+ }
+ case SegmentEncoding::Uri: {
+ ARROW_ASSIGN_OR_RAISE(auto decoded, SafeUriUnescape(segment));
+ keys.push_back({schema_->field(i++)->name(), std::move(decoded)});
+ break;
+ }
+ default:
+ return Status::NotImplemented("Unknown segment encoding: ",
+ options_.segment_encoding);
+ }
+ }
+
+ return keys;
+}
+
+inline util::optional<int> NextValid(const ScalarVector& values, int first_null) {
+ auto it = std::find_if(values.begin() + first_null + 1, values.end(),
+ [](const std::shared_ptr<Scalar>& v) { return v != nullptr; });
+
+ if (it == values.end()) {
+ return util::nullopt;
+ }
+
+ return static_cast<int>(it - values.begin());
+}
+
+Result<std::string> DirectoryPartitioning::FormatValues(
+ const ScalarVector& values) const {
+ std::vector<std::string> segments(static_cast<size_t>(schema_->num_fields()));
+
+ for (int i = 0; i < schema_->num_fields(); ++i) {
+ if (values[i] != nullptr && values[i]->is_valid) {
+ segments[i] = values[i]->ToString();
+ continue;
+ }
+
+ if (auto illegal_index = NextValid(values, i)) {
+ // XXX maybe we should just ignore keys provided after the first absent one?
+ return Status::Invalid("No partition key for ", schema_->field(i)->name(),
+ " but a key was provided subsequently for ",
+ schema_->field(*illegal_index)->name(), ".");
+ }
+
+ // if all subsequent keys are absent we'll just print the available keys
+ break;
+ }
+
+ return fs::internal::JoinAbstractPath(std::move(segments));
+}
+
+KeyValuePartitioningOptions PartitioningFactoryOptions::AsPartitioningOptions() const {
+ KeyValuePartitioningOptions options;
+ options.segment_encoding = segment_encoding;
+ return options;
+}
+
+HivePartitioningOptions HivePartitioningFactoryOptions::AsHivePartitioningOptions()
+ const {
+ HivePartitioningOptions options;
+ options.segment_encoding = segment_encoding;
+ options.null_fallback = null_fallback;
+ return options;
+}
+
+namespace {
+class KeyValuePartitioningFactory : public PartitioningFactory {
+ protected:
+ explicit KeyValuePartitioningFactory(PartitioningFactoryOptions options)
+ : options_(std::move(options)) {}
+
+ int GetOrInsertField(const std::string& name) {
+ auto it_inserted =
+ name_to_index_.emplace(name, static_cast<int>(name_to_index_.size()));
+
+ if (it_inserted.second) {
+ repr_memos_.push_back(MakeMemo());
+ }
+
+ return it_inserted.first->second;
+ }
+
+ Status InsertRepr(const std::string& name, util::optional<string_view> repr) {
+ auto field_index = GetOrInsertField(name);
+ if (repr.has_value()) {
+ return InsertRepr(field_index, *repr);
+ } else {
+ return Status::OK();
+ }
+ }
+
+ Status InsertRepr(int index, util::string_view repr) {
+ int dummy;
+ return repr_memos_[index]->GetOrInsert<StringType>(repr, &dummy);
+ }
+
+ Result<std::shared_ptr<Schema>> DoInspect() {
+ dictionaries_.assign(name_to_index_.size(), nullptr);
+
+ std::vector<std::shared_ptr<Field>> fields(name_to_index_.size());
+ if (options_.schema) {
+ const auto requested_size = options_.schema->fields().size();
+ const auto inferred_size = fields.size();
+ if (inferred_size != requested_size) {
+ return Status::Invalid("Requested schema has ", requested_size,
+ " fields, but only ", inferred_size, " were detected");
+ }
+ }
+
+ for (const auto& name_index : name_to_index_) {
+ const auto& name = name_index.first;
+ auto index = name_index.second;
+
+ std::shared_ptr<ArrayData> reprs;
+ RETURN_NOT_OK(repr_memos_[index]->GetArrayData(0, &reprs));
+
+ if (reprs->length == 0) {
+ return Status::Invalid("No non-null segments were available for field '", name,
+ "'; couldn't infer type");
+ }
+
+ std::shared_ptr<Field> current_field;
+ std::shared_ptr<Array> dict;
+ if (options_.schema) {
+ // if we have a schema, use the schema type.
+ current_field = options_.schema->field(index);
+ auto cast_target = current_field->type();
+ if (is_dictionary(cast_target->id())) {
+ cast_target = checked_pointer_cast<DictionaryType>(cast_target)->value_type();
+ }
+ auto maybe_dict = compute::Cast(reprs, cast_target);
+ if (!maybe_dict.ok()) {
+ return Status::Invalid("Could not cast segments for partition field ",
+ current_field->name(), " to requested type ",
+ current_field->type()->ToString(),
+ " because: ", maybe_dict.status());
+ }
+ dict = maybe_dict.ValueOrDie().make_array();
+ } else {
+ // try casting to int32, otherwise bail and just use the string reprs
+ dict = compute::Cast(reprs, int32()).ValueOr(reprs).make_array();
+ auto type = dict->type();
+ if (options_.infer_dictionary) {
+ // wrap the inferred type in dictionary()
+ type = dictionary(int32(), std::move(type));
+ }
+ current_field = field(name, std::move(type));
+ }
+ fields[index] = std::move(current_field);
+ dictionaries_[index] = std::move(dict);
+ }
+
+ Reset();
+ return ::arrow::schema(std::move(fields));
+ }
+
+ std::vector<std::string> FieldNames() {
+ std::vector<std::string> names(name_to_index_.size());
+
+ for (auto kv : name_to_index_) {
+ names[kv.second] = kv.first;
+ }
+ return names;
+ }
+
+ virtual void Reset() {
+ name_to_index_.clear();
+ repr_memos_.clear();
+ }
+
+ std::unique_ptr<DictionaryMemoTable> MakeMemo() {
+ return ::arrow::internal::make_unique<DictionaryMemoTable>(default_memory_pool(),
+ utf8());
+ }
+
+ PartitioningFactoryOptions options_;
+ ArrayVector dictionaries_;
+ std::unordered_map<std::string, int> name_to_index_;
+ std::vector<std::unique_ptr<DictionaryMemoTable>> repr_memos_;
+};
+
+class DirectoryPartitioningFactory : public KeyValuePartitioningFactory {
+ public:
+ DirectoryPartitioningFactory(std::vector<std::string> field_names,
+ PartitioningFactoryOptions options)
+ : KeyValuePartitioningFactory(options), field_names_(std::move(field_names)) {
+ Reset();
+ util::InitializeUTF8();
+ }
+
+ std::string type_name() const override { return "directory"; }
+
+ Result<std::shared_ptr<Schema>> Inspect(
+ const std::vector<std::string>& paths) override {
+ for (auto path : paths) {
+ size_t field_index = 0;
+ for (auto&& segment : fs::internal::SplitAbstractPath(path)) {
+ if (field_index == field_names_.size()) break;
+
+ switch (options_.segment_encoding) {
+ case SegmentEncoding::None: {
+ if (ARROW_PREDICT_FALSE(!util::ValidateUTF8(segment))) {
+ return Status::Invalid("Partition segment was not valid UTF-8: ", segment);
+ }
+ RETURN_NOT_OK(InsertRepr(static_cast<int>(field_index++), segment));
+ break;
+ }
+ case SegmentEncoding::Uri: {
+ ARROW_ASSIGN_OR_RAISE(auto decoded, SafeUriUnescape(segment));
+ RETURN_NOT_OK(InsertRepr(static_cast<int>(field_index++), decoded));
+ break;
+ }
+ default:
+ return Status::NotImplemented("Unknown segment encoding: ",
+ options_.segment_encoding);
+ }
+ }
+ }
+
+ return DoInspect();
+ }
+
+ Result<std::shared_ptr<Partitioning>> Finish(
+ const std::shared_ptr<Schema>& schema) const override {
+ for (FieldRef ref : field_names_) {
+ // ensure all of field_names_ are present in schema
+ RETURN_NOT_OK(ref.FindOne(*schema).status());
+ }
+
+ // drop fields which aren't in field_names_
+ auto out_schema = SchemaFromColumnNames(schema, field_names_);
+
+ return std::make_shared<DirectoryPartitioning>(std::move(out_schema), dictionaries_,
+ options_.AsPartitioningOptions());
+ }
+
+ private:
+ void Reset() override {
+ KeyValuePartitioningFactory::Reset();
+
+ for (const auto& name : field_names_) {
+ GetOrInsertField(name);
+ }
+ }
+
+ std::vector<std::string> field_names_;
+};
+
+} // namespace
+
+std::shared_ptr<PartitioningFactory> DirectoryPartitioning::MakeFactory(
+ std::vector<std::string> field_names, PartitioningFactoryOptions options) {
+ return std::shared_ptr<PartitioningFactory>(
+ new DirectoryPartitioningFactory(std::move(field_names), options));
+}
+
+Result<util::optional<KeyValuePartitioning::Key>> HivePartitioning::ParseKey(
+ const std::string& segment, const HivePartitioningOptions& options) {
+ auto name_end = string_view(segment).find_first_of('=');
+ // Not round-trippable
+ if (name_end == string_view::npos) {
+ return util::nullopt;
+ }
+
+ // Static method, so we have no better place for it
+ util::InitializeUTF8();
+
+ auto name = segment.substr(0, name_end);
+ std::string value;
+ switch (options.segment_encoding) {
+ case SegmentEncoding::None: {
+ value = segment.substr(name_end + 1);
+ if (ARROW_PREDICT_FALSE(!util::ValidateUTF8(value))) {
+ return Status::Invalid("Partition segment was not valid UTF-8: ", value);
+ }
+ break;
+ }
+ case SegmentEncoding::Uri: {
+ auto raw_value = util::string_view(segment).substr(name_end + 1);
+ ARROW_ASSIGN_OR_RAISE(value, SafeUriUnescape(raw_value));
+ break;
+ }
+ default:
+ return Status::NotImplemented("Unknown segment encoding: ",
+ options.segment_encoding);
+ }
+
+ if (value == options.null_fallback) {
+ return Key{std::move(name), util::nullopt};
+ }
+ return Key{std::move(name), std::move(value)};
+}
+
+Result<std::vector<KeyValuePartitioning::Key>> HivePartitioning::ParseKeys(
+ const std::string& path) const {
+ std::vector<Key> keys;
+
+ for (const auto& segment : fs::internal::SplitAbstractPath(path)) {
+ ARROW_ASSIGN_OR_RAISE(auto maybe_key, ParseKey(segment, hive_options_));
+ if (auto key = maybe_key) {
+ keys.push_back(std::move(*key));
+ }
+ }
+
+ return keys;
+}
+
+Result<std::string> HivePartitioning::FormatValues(const ScalarVector& values) const {
+ std::vector<std::string> segments(static_cast<size_t>(schema_->num_fields()));
+
+ for (int i = 0; i < schema_->num_fields(); ++i) {
+ const std::string& name = schema_->field(i)->name();
+
+ if (values[i] == nullptr) {
+ segments[i] = "";
+ } else if (!values[i]->is_valid) {
+ // If no key is available just provide a placeholder segment to maintain the
+ // field_index <-> path nesting relation
+ segments[i] = name + "=" + hive_options_.null_fallback;
+ } else {
+ segments[i] = name + "=" + values[i]->ToString();
+ }
+ }
+
+ return fs::internal::JoinAbstractPath(std::move(segments));
+}
+
+class HivePartitioningFactory : public KeyValuePartitioningFactory {
+ public:
+ explicit HivePartitioningFactory(HivePartitioningFactoryOptions options)
+ : KeyValuePartitioningFactory(options), options_(std::move(options)) {}
+
+ std::string type_name() const override { return "hive"; }
+
+ Result<std::shared_ptr<Schema>> Inspect(
+ const std::vector<std::string>& paths) override {
+ auto options = options_.AsHivePartitioningOptions();
+ for (auto path : paths) {
+ for (auto&& segment : fs::internal::SplitAbstractPath(path)) {
+ ARROW_ASSIGN_OR_RAISE(auto maybe_key,
+ HivePartitioning::ParseKey(segment, options));
+ if (auto key = maybe_key) {
+ RETURN_NOT_OK(InsertRepr(key->name, key->value));
+ }
+ }
+ }
+
+ field_names_ = FieldNames();
+ return DoInspect();
+ }
+
+ Result<std::shared_ptr<Partitioning>> Finish(
+ const std::shared_ptr<Schema>& schema) const override {
+ if (dictionaries_.empty()) {
+ return std::make_shared<HivePartitioning>(schema, dictionaries_);
+ } else {
+ for (FieldRef ref : field_names_) {
+ // ensure all of field_names_ are present in schema
+ RETURN_NOT_OK(ref.FindOne(*schema));
+ }
+
+ // drop fields which aren't in field_names_
+ auto out_schema = SchemaFromColumnNames(schema, field_names_);
+
+ return std::make_shared<HivePartitioning>(std::move(out_schema), dictionaries_,
+ options_.AsHivePartitioningOptions());
+ }
+ }
+
+ private:
+ const HivePartitioningFactoryOptions options_;
+ std::vector<std::string> field_names_;
+};
+
+std::shared_ptr<PartitioningFactory> HivePartitioning::MakeFactory(
+ HivePartitioningFactoryOptions options) {
+ return std::shared_ptr<PartitioningFactory>(new HivePartitioningFactory(options));
+}
+
+std::string StripPrefixAndFilename(const std::string& path, const std::string& prefix) {
+ auto maybe_base_less = fs::internal::RemoveAncestor(prefix, path);
+ auto base_less = maybe_base_less ? std::string(*maybe_base_less) : path;
+ auto basename_filename = fs::internal::GetAbstractPathParent(base_less);
+ return basename_filename.first;
+}
+
+std::vector<std::string> StripPrefixAndFilename(const std::vector<std::string>& paths,
+ const std::string& prefix) {
+ std::vector<std::string> result;
+ result.reserve(paths.size());
+ for (const auto& path : paths) {
+ result.emplace_back(StripPrefixAndFilename(path, prefix));
+ }
+ return result;
+}
+
+std::vector<std::string> StripPrefixAndFilename(const std::vector<fs::FileInfo>& files,
+ const std::string& prefix) {
+ std::vector<std::string> result;
+ result.reserve(files.size());
+ for (const auto& info : files) {
+ result.emplace_back(StripPrefixAndFilename(info.path(), prefix));
+ }
+ return result;
+}
+
+Result<std::shared_ptr<Schema>> PartitioningOrFactory::GetOrInferSchema(
+ const std::vector<std::string>& paths) {
+ if (auto part = partitioning()) {
+ return part->schema();
+ }
+
+ return factory()->Inspect(paths);
+}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/partition.h b/src/arrow/cpp/src/arrow/dataset/partition.h
new file mode 100644
index 000000000..aa6958ed1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/partition.h
@@ -0,0 +1,372 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <functional>
+#include <iosfwd>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "arrow/compute/exec/expression.h"
+#include "arrow/dataset/type_fwd.h"
+#include "arrow/dataset/visibility.h"
+#include "arrow/util/optional.h"
+
+namespace arrow {
+
+namespace dataset {
+
+// ----------------------------------------------------------------------
+// Partitioning
+
+/// \defgroup dataset-partitioning Partitioning API
+///
+/// @{
+
+/// \brief Interface for parsing partition expressions from string partition
+/// identifiers.
+///
+/// For example, the identifier "foo=5" might be parsed to an equality expression
+/// between the "foo" field and the value 5.
+///
+/// Some partitionings may store the field names in a metadata
+/// store instead of in file paths, for example
+/// dataset_root/2009/11/... could be used when the partition fields
+/// are "year" and "month"
+///
+/// Paths are consumed from left to right. Paths must be relative to
+/// the root of a partition; path prefixes must be removed before passing
+/// the path to a partitioning for parsing.
+class ARROW_DS_EXPORT Partitioning {
+ public:
+ virtual ~Partitioning() = default;
+
+ /// \brief The name identifying the kind of partitioning
+ virtual std::string type_name() const = 0;
+
+ /// \brief If the input batch shares any fields with this partitioning,
+ /// produce sub-batches which satisfy mutually exclusive Expressions.
+ struct PartitionedBatches {
+ RecordBatchVector batches;
+ std::vector<compute::Expression> expressions;
+ };
+ virtual Result<PartitionedBatches> Partition(
+ const std::shared_ptr<RecordBatch>& batch) const = 0;
+
+ /// \brief Parse a path into a partition expression
+ virtual Result<compute::Expression> Parse(const std::string& path) const = 0;
+
+ virtual Result<std::string> Format(const compute::Expression& expr) const = 0;
+
+ /// \brief A default Partitioning which always yields scalar(true)
+ static std::shared_ptr<Partitioning> Default();
+
+ /// \brief The partition schema.
+ const std::shared_ptr<Schema>& schema() { return schema_; }
+
+ protected:
+ explicit Partitioning(std::shared_ptr<Schema> schema) : schema_(std::move(schema)) {}
+
+ std::shared_ptr<Schema> schema_;
+};
+
+/// \brief The encoding of partition segments.
+enum class SegmentEncoding : int8_t {
+ /// No encoding.
+ None = 0,
+ /// Segment values are URL-encoded.
+ Uri = 1,
+};
+
+ARROW_DS_EXPORT
+std::ostream& operator<<(std::ostream& os, SegmentEncoding segment_encoding);
+
+/// \brief Options for key-value based partitioning (hive/directory).
+struct ARROW_DS_EXPORT KeyValuePartitioningOptions {
+ /// After splitting a path into components, decode the path components
+ /// before parsing according to this scheme.
+ SegmentEncoding segment_encoding = SegmentEncoding::Uri;
+};
+
+/// \brief Options for inferring a partitioning.
+struct ARROW_DS_EXPORT PartitioningFactoryOptions {
+ /// When inferring a schema for partition fields, yield dictionary encoded types
+ /// instead of plain. This can be more efficient when materializing virtual
+ /// columns, and Expressions parsed by the finished Partitioning will include
+ /// dictionaries of all unique inspected values for each field.
+ bool infer_dictionary = false;
+ /// Optionally, an expected schema can be provided, in which case inference
+ /// will only check discovered fields against the schema and update internal
+ /// state (such as dictionaries).
+ std::shared_ptr<Schema> schema;
+ /// After splitting a path into components, decode the path components
+ /// before parsing according to this scheme.
+ SegmentEncoding segment_encoding = SegmentEncoding::Uri;
+
+ KeyValuePartitioningOptions AsPartitioningOptions() const;
+};
+
+/// \brief Options for inferring a hive-style partitioning.
+struct ARROW_DS_EXPORT HivePartitioningFactoryOptions : PartitioningFactoryOptions {
+ /// The hive partitioning scheme maps null to a hard coded fallback string.
+ std::string null_fallback;
+
+ HivePartitioningOptions AsHivePartitioningOptions() const;
+};
+
+/// \brief PartitioningFactory provides creation of a partitioning when the
+/// specific schema must be inferred from available paths (no explicit schema is known).
+class ARROW_DS_EXPORT PartitioningFactory {
+ public:
+ virtual ~PartitioningFactory() = default;
+
+ /// \brief The name identifying the kind of partitioning
+ virtual std::string type_name() const = 0;
+
+ /// Get the schema for the resulting Partitioning.
+ /// This may reset internal state, for example dictionaries of unique representations.
+ virtual Result<std::shared_ptr<Schema>> Inspect(
+ const std::vector<std::string>& paths) = 0;
+
+ /// Create a partitioning using the provided schema
+ /// (fields may be dropped).
+ virtual Result<std::shared_ptr<Partitioning>> Finish(
+ const std::shared_ptr<Schema>& schema) const = 0;
+};
+
+/// \brief Subclass for the common case of a partitioning which yields an equality
+/// expression for each segment
+class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning {
+ public:
+ /// An unconverted equality expression consisting of a field name and the representation
+ /// of a scalar value
+ struct Key {
+ std::string name;
+ util::optional<std::string> value;
+ };
+
+ Result<PartitionedBatches> Partition(
+ const std::shared_ptr<RecordBatch>& batch) const override;
+
+ Result<compute::Expression> Parse(const std::string& path) const override;
+
+ Result<std::string> Format(const compute::Expression& expr) const override;
+
+ const ArrayVector& dictionaries() const { return dictionaries_; }
+
+ protected:
+ KeyValuePartitioning(std::shared_ptr<Schema> schema, ArrayVector dictionaries,
+ KeyValuePartitioningOptions options)
+ : Partitioning(std::move(schema)),
+ dictionaries_(std::move(dictionaries)),
+ options_(options) {
+ if (dictionaries_.empty()) {
+ dictionaries_.resize(schema_->num_fields());
+ }
+ }
+
+ virtual Result<std::vector<Key>> ParseKeys(const std::string& path) const = 0;
+
+ virtual Result<std::string> FormatValues(const ScalarVector& values) const = 0;
+
+ /// Convert a Key to a full expression.
+ Result<compute::Expression> ConvertKey(const Key& key) const;
+
+ ArrayVector dictionaries_;
+ KeyValuePartitioningOptions options_;
+};
+
+/// \brief DirectoryPartitioning parses one segment of a path for each field in its
+/// schema. All fields are required, so paths passed to DirectoryPartitioning::Parse
+/// must contain segments for each field.
+///
+/// For example given schema<year:int16, month:int8> the path "/2009/11" would be
+/// parsed to ("year"_ == 2009 and "month"_ == 11)
+class ARROW_DS_EXPORT DirectoryPartitioning : public KeyValuePartitioning {
+ public:
+ /// If a field in schema is of dictionary type, the corresponding element of
+ /// dictionaries must be contain the dictionary of values for that field.
+ explicit DirectoryPartitioning(std::shared_ptr<Schema> schema,
+ ArrayVector dictionaries = {},
+ KeyValuePartitioningOptions options = {});
+
+ std::string type_name() const override { return "directory"; }
+
+ /// \brief Create a factory for a directory partitioning.
+ ///
+ /// \param[in] field_names The names for the partition fields. Types will be
+ /// inferred.
+ static std::shared_ptr<PartitioningFactory> MakeFactory(
+ std::vector<std::string> field_names, PartitioningFactoryOptions = {});
+
+ private:
+ Result<std::vector<Key>> ParseKeys(const std::string& path) const override;
+
+ Result<std::string> FormatValues(const ScalarVector& values) const override;
+};
+
+/// \brief The default fallback used for null values in a Hive-style partitioning.
+static constexpr char kDefaultHiveNullFallback[] = "__HIVE_DEFAULT_PARTITION__";
+
+struct ARROW_DS_EXPORT HivePartitioningOptions : public KeyValuePartitioningOptions {
+ std::string null_fallback = kDefaultHiveNullFallback;
+
+ static HivePartitioningOptions DefaultsWithNullFallback(std::string fallback) {
+ HivePartitioningOptions options;
+ options.null_fallback = std::move(fallback);
+ return options;
+ }
+};
+
+/// \brief Multi-level, directory based partitioning
+/// originating from Apache Hive with all data files stored in the
+/// leaf directories. Data is partitioned by static values of a
+/// particular column in the schema. Partition keys are represented in
+/// the form $key=$value in directory names.
+/// Field order is ignored, as are missing or unrecognized field names.
+///
+/// For example given schema<year:int16, month:int8, day:int8> the path
+/// "/day=321/ignored=3.4/year=2009" parses to ("year"_ == 2009 and "day"_ == 321)
+class ARROW_DS_EXPORT HivePartitioning : public KeyValuePartitioning {
+ public:
+ /// If a field in schema is of dictionary type, the corresponding element of
+ /// dictionaries must be contain the dictionary of values for that field.
+ explicit HivePartitioning(std::shared_ptr<Schema> schema, ArrayVector dictionaries = {},
+ std::string null_fallback = kDefaultHiveNullFallback)
+ : KeyValuePartitioning(std::move(schema), std::move(dictionaries),
+ KeyValuePartitioningOptions()),
+ hive_options_(
+ HivePartitioningOptions::DefaultsWithNullFallback(std::move(null_fallback))) {
+ }
+
+ explicit HivePartitioning(std::shared_ptr<Schema> schema, ArrayVector dictionaries,
+ HivePartitioningOptions options)
+ : KeyValuePartitioning(std::move(schema), std::move(dictionaries), options),
+ hive_options_(options) {}
+
+ std::string type_name() const override { return "hive"; }
+ std::string null_fallback() const { return hive_options_.null_fallback; }
+ const HivePartitioningOptions& options() const { return hive_options_; }
+
+ static Result<util::optional<Key>> ParseKey(const std::string& segment,
+ const HivePartitioningOptions& options);
+
+ /// \brief Create a factory for a hive partitioning.
+ static std::shared_ptr<PartitioningFactory> MakeFactory(
+ HivePartitioningFactoryOptions = {});
+
+ private:
+ const HivePartitioningOptions hive_options_;
+ Result<std::vector<Key>> ParseKeys(const std::string& path) const override;
+
+ Result<std::string> FormatValues(const ScalarVector& values) const override;
+};
+
+/// \brief Implementation provided by lambda or other callable
+class ARROW_DS_EXPORT FunctionPartitioning : public Partitioning {
+ public:
+ using ParseImpl = std::function<Result<compute::Expression>(const std::string&)>;
+
+ using FormatImpl = std::function<Result<std::string>(const compute::Expression&)>;
+
+ FunctionPartitioning(std::shared_ptr<Schema> schema, ParseImpl parse_impl,
+ FormatImpl format_impl = NULLPTR, std::string name = "function")
+ : Partitioning(std::move(schema)),
+ parse_impl_(std::move(parse_impl)),
+ format_impl_(std::move(format_impl)),
+ name_(std::move(name)) {}
+
+ std::string type_name() const override { return name_; }
+
+ Result<compute::Expression> Parse(const std::string& path) const override {
+ return parse_impl_(path);
+ }
+
+ Result<std::string> Format(const compute::Expression& expr) const override {
+ if (format_impl_) {
+ return format_impl_(expr);
+ }
+ return Status::NotImplemented("formatting paths from ", type_name(), " Partitioning");
+ }
+
+ Result<PartitionedBatches> Partition(
+ const std::shared_ptr<RecordBatch>& batch) const override {
+ return Status::NotImplemented("partitioning batches from ", type_name(),
+ " Partitioning");
+ }
+
+ private:
+ ParseImpl parse_impl_;
+ FormatImpl format_impl_;
+ std::string name_;
+};
+
+/// \brief Remove a prefix and the filename of a path.
+///
+/// e.g., `StripPrefixAndFilename("/data/year=2019/c.txt", "/data") -> "year=2019"`
+ARROW_DS_EXPORT std::string StripPrefixAndFilename(const std::string& path,
+ const std::string& prefix);
+
+/// \brief Vector version of StripPrefixAndFilename.
+ARROW_DS_EXPORT std::vector<std::string> StripPrefixAndFilename(
+ const std::vector<std::string>& paths, const std::string& prefix);
+
+/// \brief Vector version of StripPrefixAndFilename.
+ARROW_DS_EXPORT std::vector<std::string> StripPrefixAndFilename(
+ const std::vector<fs::FileInfo>& files, const std::string& prefix);
+
+/// \brief Either a Partitioning or a PartitioningFactory
+class ARROW_DS_EXPORT PartitioningOrFactory {
+ public:
+ explicit PartitioningOrFactory(std::shared_ptr<Partitioning> partitioning)
+ : partitioning_(std::move(partitioning)) {}
+
+ explicit PartitioningOrFactory(std::shared_ptr<PartitioningFactory> factory)
+ : factory_(std::move(factory)) {}
+
+ PartitioningOrFactory& operator=(std::shared_ptr<Partitioning> partitioning) {
+ return *this = PartitioningOrFactory(std::move(partitioning));
+ }
+
+ PartitioningOrFactory& operator=(std::shared_ptr<PartitioningFactory> factory) {
+ return *this = PartitioningOrFactory(std::move(factory));
+ }
+
+ /// \brief The partitioning (if given).
+ const std::shared_ptr<Partitioning>& partitioning() const { return partitioning_; }
+
+ /// \brief The partition factory (if given).
+ const std::shared_ptr<PartitioningFactory>& factory() const { return factory_; }
+
+ /// \brief Get the partition schema, inferring it with the given factory if needed.
+ Result<std::shared_ptr<Schema>> GetOrInferSchema(const std::vector<std::string>& paths);
+
+ private:
+ std::shared_ptr<PartitioningFactory> factory_;
+ std::shared_ptr<Partitioning> partitioning_;
+};
+
+/// @}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/partition_test.cc b/src/arrow/cpp/src/arrow/dataset/partition_test.cc
new file mode 100644
index 000000000..c4cd528c2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/partition_test.cc
@@ -0,0 +1,836 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/partition.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <cstdint>
+#include <memory>
+#include <regex>
+#include <string>
+#include <vector>
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/dataset/scanner_internal.h"
+#include "arrow/dataset/test_util.h"
+#include "arrow/filesystem/path_util.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/range.h"
+
+namespace arrow {
+
+using internal::checked_pointer_cast;
+
+namespace dataset {
+
+class TestPartitioning : public ::testing::Test {
+ public:
+ void AssertParseError(const std::string& path) {
+ ASSERT_RAISES(Invalid, partitioning_->Parse(path));
+ }
+
+ void AssertParse(const std::string& path, compute::Expression expected) {
+ ASSERT_OK_AND_ASSIGN(auto parsed, partitioning_->Parse(path));
+ ASSERT_EQ(parsed, expected);
+ }
+
+ template <StatusCode code = StatusCode::Invalid>
+ void AssertFormatError(compute::Expression expr) {
+ ASSERT_EQ(partitioning_->Format(expr).status().code(), code);
+ }
+
+ void AssertFormat(compute::Expression expr, const std::string& expected) {
+ // formatted partition expressions are bound to the schema of the dataset being
+ // written
+ ASSERT_OK_AND_ASSIGN(auto formatted, partitioning_->Format(expr));
+ ASSERT_EQ(formatted, expected);
+
+ // ensure the formatted path round trips the relevant components of the partition
+ // expression: roundtripped should be a subset of expr
+ ASSERT_OK_AND_ASSIGN(compute::Expression roundtripped,
+ partitioning_->Parse(formatted));
+
+ ASSERT_OK_AND_ASSIGN(roundtripped, roundtripped.Bind(*written_schema_));
+ ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(roundtripped, expr));
+ ASSERT_EQ(simplified, literal(true));
+ }
+
+ void AssertInspect(const std::vector<std::string>& paths,
+ const std::vector<std::shared_ptr<Field>>& expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, factory_->Inspect(paths));
+ ASSERT_EQ(*actual, Schema(expected));
+ ASSERT_OK_AND_ASSIGN(partitioning_, factory_->Finish(actual));
+ }
+
+ void AssertPartition(const std::shared_ptr<Partitioning> partitioning,
+ const std::shared_ptr<RecordBatch> full_batch,
+ const RecordBatchVector& expected_batches,
+ const std::vector<compute::Expression>& expected_expressions) {
+ ASSERT_OK_AND_ASSIGN(auto partition_results, partitioning->Partition(full_batch));
+ std::shared_ptr<RecordBatch> rest = full_batch;
+
+ ASSERT_EQ(partition_results.batches.size(), expected_batches.size());
+
+ for (size_t i = 0; i < partition_results.batches.size(); i++) {
+ std::shared_ptr<RecordBatch> actual_batch = partition_results.batches[i];
+ compute::Expression actual_expression = partition_results.expressions[i];
+
+ auto expected_expression = std::find(expected_expressions.begin(),
+ expected_expressions.end(), actual_expression);
+ ASSERT_NE(expected_expression, expected_expressions.end())
+ << "Unexpected partition expr " << actual_expression.ToString();
+
+ auto expected_batch =
+ expected_batches[expected_expression - expected_expressions.begin()];
+
+ SCOPED_TRACE("Batch for " + expected_expression->ToString());
+ AssertBatchesEqual(*expected_batch, *actual_batch);
+ }
+ }
+
+ void AssertPartition(const std::shared_ptr<Partitioning> partitioning,
+ const std::shared_ptr<Schema> schema,
+ const std::string& record_batch_json,
+ const std::shared_ptr<Schema> partitioned_schema,
+ const std::vector<std::string>& expected_record_batch_strs,
+ const std::vector<compute::Expression>& expected_expressions) {
+ auto record_batch = RecordBatchFromJSON(schema, record_batch_json);
+ RecordBatchVector expected_batches;
+ for (const auto& expected_record_batch_str : expected_record_batch_strs) {
+ expected_batches.push_back(
+ RecordBatchFromJSON(partitioned_schema, expected_record_batch_str));
+ }
+ AssertPartition(partitioning, record_batch, expected_batches, expected_expressions);
+ }
+
+ void AssertInspectError(const std::vector<std::string>& paths) {
+ ASSERT_RAISES(Invalid, factory_->Inspect(paths));
+ }
+
+ protected:
+ static std::shared_ptr<Field> Int(std::string name) {
+ return field(std::move(name), int32());
+ }
+
+ static std::shared_ptr<Field> Str(std::string name) {
+ return field(std::move(name), utf8());
+ }
+
+ static std::shared_ptr<Field> DictStr(std::string name) {
+ return field(std::move(name), dictionary(int32(), utf8()));
+ }
+
+ static std::shared_ptr<Field> DictInt(std::string name) {
+ return field(std::move(name), dictionary(int32(), int32()));
+ }
+
+ std::shared_ptr<Partitioning> partitioning_;
+ std::shared_ptr<PartitioningFactory> factory_;
+ std::shared_ptr<Schema> written_schema_;
+};
+
+TEST_F(TestPartitioning, Partition) {
+ auto dataset_schema =
+ schema({field("a", int32()), field("b", utf8()), field("c", uint32())});
+
+ auto partition_schema = schema({field("a", int32()), field("b", utf8())});
+
+ auto physical_schema = schema({field("c", uint32())});
+
+ auto partitioning = std::make_shared<DirectoryPartitioning>(partition_schema);
+ std::string json = R"([{"a": 3, "b": "x", "c": 0},
+ {"a": 3, "b": "x", "c": 1},
+ {"a": 1, "b": null, "c": 2},
+ {"a": null, "b": null, "c": 3},
+ {"a": null, "b": "z", "c": 4},
+ {"a": null, "b": null, "c": 5}
+ ])";
+
+ std::vector<std::string> expected_batches = {
+ R"([{"c": 0}, {"c": 1}])",
+ R"([{"c": 2}])",
+ R"([{"c": 3}, {"c": 5}])",
+ R"([{"c": 4}])",
+ };
+
+ std::vector<compute::Expression> expected_expressions = {
+ and_(equal(field_ref("a"), literal(3)), equal(field_ref("b"), literal("x"))),
+ and_(equal(field_ref("a"), literal(1)), is_null(field_ref("b"))),
+ and_(is_null(field_ref("a")), is_null(field_ref("b"))),
+ and_(is_null(field_ref("a")), equal(field_ref("b"), literal("z"))),
+ };
+
+ AssertPartition(partitioning, dataset_schema, json, physical_schema, expected_batches,
+ expected_expressions);
+}
+
+TEST_F(TestPartitioning, DirectoryPartitioning) {
+ partitioning_ = std::make_shared<DirectoryPartitioning>(
+ schema({field("alpha", int32()), field("beta", utf8())}));
+
+ AssertParse("/0/hello", and_(equal(field_ref("alpha"), literal(0)),
+ equal(field_ref("beta"), literal("hello"))));
+ AssertParse("/3", equal(field_ref("alpha"), literal(3)));
+ AssertParseError("/world/0"); // reversed order
+ AssertParseError("/0.0/foo"); // invalid alpha
+ AssertParseError("/3.25"); // invalid alpha with missing beta
+ AssertParse("", literal(true)); // no segments to parse
+
+ // gotcha someday:
+ AssertParse("/0/dat.parquet", and_(equal(field_ref("alpha"), literal(0)),
+ equal(field_ref("beta"), literal("dat.parquet"))));
+
+ AssertParse("/0/foo/ignored=2341", and_(equal(field_ref("alpha"), literal(0)),
+ equal(field_ref("beta"), literal("foo"))));
+}
+
+TEST_F(TestPartitioning, DirectoryPartitioningFormat) {
+ partitioning_ = std::make_shared<DirectoryPartitioning>(
+ schema({field("alpha", int32()), field("beta", utf8())}));
+
+ written_schema_ = partitioning_->schema();
+
+ AssertFormat(and_(equal(field_ref("alpha"), literal(0)),
+ equal(field_ref("beta"), literal("hello"))),
+ "0/hello");
+ AssertFormat(and_(equal(field_ref("beta"), literal("hello")),
+ equal(field_ref("alpha"), literal(0))),
+ "0/hello");
+ AssertFormat(equal(field_ref("alpha"), literal(0)), "0");
+ AssertFormat(and_(equal(field_ref("alpha"), literal(0)), is_null(field_ref("beta"))),
+ "0");
+ AssertFormatError(
+ and_(is_null(field_ref("alpha")), equal(field_ref("beta"), literal("hello"))));
+ AssertFormatError(equal(field_ref("beta"), literal("hello")));
+ AssertFormat(literal(true), "");
+
+ ASSERT_OK_AND_ASSIGN(written_schema_,
+ written_schema_->AddField(0, field("gamma", utf8())));
+ AssertFormat(and_({equal(field_ref("gamma"), literal("yo")),
+ equal(field_ref("alpha"), literal(0)),
+ equal(field_ref("beta"), literal("hello"))}),
+ "0/hello");
+
+ // written_schema_ is incompatible with partitioning_'s schema
+ written_schema_ = schema({field("alpha", utf8()), field("beta", utf8())});
+ AssertFormatError<StatusCode::TypeError>(
+ and_(equal(field_ref("alpha"), literal("0.0")),
+ equal(field_ref("beta"), literal("hello"))));
+}
+
+TEST_F(TestPartitioning, DirectoryPartitioningFormatDictionary) {
+ auto dictionary = ArrayFromJSON(utf8(), R"(["hello", "world"])");
+ partitioning_ = std::make_shared<DirectoryPartitioning>(schema({DictStr("alpha")}),
+ ArrayVector{dictionary});
+ written_schema_ = partitioning_->schema();
+
+ ASSERT_OK_AND_ASSIGN(auto dict_hello, MakeScalar("hello")->CastTo(DictStr("")->type()));
+ AssertFormat(equal(field_ref("alpha"), literal(dict_hello)), "hello");
+}
+
+TEST_F(TestPartitioning, DirectoryPartitioningFormatDictionaryCustomIndex) {
+ // Make sure a non-int32 index type is properly cast to, else we fail a CHECK when
+ // we construct a dictionary array with the wrong index type
+ auto dict_type = dictionary(int8(), utf8());
+ auto dictionary = ArrayFromJSON(utf8(), R"(["hello", "world"])");
+ partitioning_ = std::make_shared<DirectoryPartitioning>(
+ schema({field("alpha", dict_type)}), ArrayVector{dictionary});
+ written_schema_ = partitioning_->schema();
+
+ ASSERT_OK_AND_ASSIGN(auto dict_hello, MakeScalar("hello")->CastTo(dict_type));
+ AssertFormat(equal(field_ref("alpha"), literal(dict_hello)), "hello");
+}
+
+TEST_F(TestPartitioning, DirectoryPartitioningWithTemporal) {
+ for (auto temporal : {timestamp(TimeUnit::SECOND), date32()}) {
+ partitioning_ = std::make_shared<DirectoryPartitioning>(
+ schema({field("year", int32()), field("month", int8()), field("day", temporal)}));
+
+ ASSERT_OK_AND_ASSIGN(auto day, StringScalar("2020-06-08").CastTo(temporal));
+ AssertParse("/2020/06/2020-06-08",
+ and_({equal(field_ref("year"), literal(2020)),
+ equal(field_ref("month"), literal<int8_t>(6)),
+ equal(field_ref("day"), literal(day))}));
+ }
+}
+
+TEST_F(TestPartitioning, DiscoverSchema) {
+ factory_ = DirectoryPartitioning::MakeFactory({"alpha", "beta"});
+
+ // type is int32 if possible
+ AssertInspect({"/0/1"}, {Int("alpha"), Int("beta")});
+
+ // extra segments are ignored
+ AssertInspect({"/0/1/what"}, {Int("alpha"), Int("beta")});
+
+ // fall back to string if any segment for field alpha is not parseable as int
+ AssertInspect({"/0/1", "/hello/1"}, {Str("alpha"), Int("beta")});
+
+ // If there are too many digits fall back to string
+ AssertInspect({"/3760212050/1"}, {Str("alpha"), Int("beta")});
+
+ // missing segment for beta doesn't cause an error or fallback
+ AssertInspect({"/0/1", "/hello"}, {Str("alpha"), Int("beta")});
+}
+
+TEST_F(TestPartitioning, DictionaryInference) {
+ PartitioningFactoryOptions options;
+ options.infer_dictionary = true;
+ factory_ = DirectoryPartitioning::MakeFactory({"alpha", "beta"}, options);
+
+ // type is still int32 if possible
+ AssertInspect({"/0/1"}, {DictInt("alpha"), DictInt("beta")});
+
+ // If there are too many digits fall back to string
+ AssertInspect({"/3760212050/1"}, {DictStr("alpha"), DictInt("beta")});
+
+ // successful dictionary inference
+ AssertInspect({"/a/0"}, {DictStr("alpha"), DictInt("beta")});
+ AssertInspect({"/a/0", "/a/1"}, {DictStr("alpha"), DictInt("beta")});
+ AssertInspect({"/a/0", "/a"}, {DictStr("alpha"), DictInt("beta")});
+ AssertInspect({"/0/a", "/1"}, {DictInt("alpha"), DictStr("beta")});
+ AssertInspect({"/a/0", "/b/0", "/a/1", "/b/1"}, {DictStr("alpha"), DictInt("beta")});
+ AssertInspect({"/a/-", "/b/-", "/a/_", "/b/_"}, {DictStr("alpha"), DictStr("beta")});
+}
+
+TEST_F(TestPartitioning, DictionaryHasUniqueValues) {
+ PartitioningFactoryOptions options;
+ options.infer_dictionary = true;
+ factory_ = DirectoryPartitioning::MakeFactory({"alpha"}, options);
+
+ auto alpha = DictStr("alpha");
+ AssertInspect({"/a", "/b", "/a", "/b", "/c", "/a"}, {alpha});
+ ASSERT_OK_AND_ASSIGN(auto partitioning, factory_->Finish(schema({alpha})));
+
+ auto expected_dictionary =
+ checked_pointer_cast<StringArray>(ArrayFromJSON(utf8(), R"(["a", "b", "c"])"));
+
+ for (int32_t i = 0; i < expected_dictionary->length(); ++i) {
+ DictionaryScalar::ValueType index_and_dictionary{std::make_shared<Int32Scalar>(i),
+ expected_dictionary};
+ auto dictionary_scalar =
+ std::make_shared<DictionaryScalar>(index_and_dictionary, alpha->type());
+
+ auto path = "/" + expected_dictionary->GetString(i);
+ AssertParse(path, equal(field_ref("alpha"), literal(dictionary_scalar)));
+ }
+
+ AssertParseError("/yosemite"); // not in inspected dictionary
+}
+
+TEST_F(TestPartitioning, DiscoverSchemaSegfault) {
+ // ARROW-7638
+ factory_ = DirectoryPartitioning::MakeFactory({"alpha", "beta"});
+ AssertInspectError({"oops.txt"});
+}
+
+TEST_F(TestPartitioning, HivePartitioning) {
+ partitioning_ = std::make_shared<HivePartitioning>(
+ schema({field("alpha", int32()), field("beta", float32())}), ArrayVector(), "xyz");
+
+ AssertParse("/alpha=0/beta=3.25", and_(equal(field_ref("alpha"), literal(0)),
+ equal(field_ref("beta"), literal(3.25f))));
+ AssertParse("/beta=3.25/alpha=0", and_(equal(field_ref("beta"), literal(3.25f)),
+ equal(field_ref("alpha"), literal(0))));
+ AssertParse("/alpha=0", equal(field_ref("alpha"), literal(0)));
+ AssertParse("/alpha=xyz/beta=3.25", and_(is_null(field_ref("alpha")),
+ equal(field_ref("beta"), literal(3.25f))));
+ AssertParse("/beta=3.25", equal(field_ref("beta"), literal(3.25f)));
+ AssertParse("", literal(true));
+
+ AssertParse("/alpha=0/unexpected/beta=3.25",
+ and_(equal(field_ref("alpha"), literal(0)),
+ equal(field_ref("beta"), literal(3.25f))));
+
+ AssertParse("/alpha=0/beta=3.25/ignored=2341",
+ and_(equal(field_ref("alpha"), literal(0)),
+ equal(field_ref("beta"), literal(3.25f))));
+
+ AssertParse("/ignored=2341", literal(true));
+
+ AssertParseError("/alpha=0.0/beta=3.25"); // conversion of "0.0" to int32 fails
+}
+
+TEST_F(TestPartitioning, HivePartitioningFormat) {
+ partitioning_ = std::make_shared<HivePartitioning>(
+ schema({field("alpha", int32()), field("beta", float32())}), ArrayVector(), "xyz");
+
+ written_schema_ = partitioning_->schema();
+
+ AssertFormat(and_(equal(field_ref("alpha"), literal(0)),
+ equal(field_ref("beta"), literal(3.25f))),
+ "alpha=0/beta=3.25");
+ AssertFormat(and_(equal(field_ref("beta"), literal(3.25f)),
+ equal(field_ref("alpha"), literal(0))),
+ "alpha=0/beta=3.25");
+ AssertFormat(equal(field_ref("alpha"), literal(0)), "alpha=0");
+ AssertFormat(and_(equal(field_ref("alpha"), literal(0)), is_null(field_ref("beta"))),
+ "alpha=0/beta=xyz");
+ AssertFormat(
+ and_(is_null(field_ref("alpha")), equal(field_ref("beta"), literal(3.25f))),
+ "alpha=xyz/beta=3.25");
+ AssertFormat(literal(true), "");
+
+ AssertFormat(and_(is_null(field_ref("alpha")), is_null(field_ref("beta"))),
+ "alpha=xyz/beta=xyz");
+
+ ASSERT_OK_AND_ASSIGN(written_schema_,
+ written_schema_->AddField(0, field("gamma", utf8())));
+ AssertFormat(and_({equal(field_ref("gamma"), literal("yo")),
+ equal(field_ref("alpha"), literal(0)),
+ equal(field_ref("beta"), literal(3.25f))}),
+ "alpha=0/beta=3.25");
+
+ // written_schema_ is incompatible with partitioning_'s schema
+ written_schema_ = schema({field("alpha", utf8()), field("beta", utf8())});
+ AssertFormatError<StatusCode::TypeError>(
+ and_(equal(field_ref("alpha"), literal("0.0")),
+ equal(field_ref("beta"), literal("hello"))));
+}
+
+TEST_F(TestPartitioning, DiscoverHiveSchema) {
+ auto options = HivePartitioningFactoryOptions();
+ options.null_fallback = "xyz";
+ factory_ = HivePartitioning::MakeFactory(options);
+
+ // type is int32 if possible
+ AssertInspect({"/alpha=0/beta=1"}, {Int("alpha"), Int("beta")});
+
+ // extra segments are ignored
+ AssertInspect({"/gamma=0/unexpected/delta=1/dat.parquet"},
+ {Int("gamma"), Int("delta")});
+
+ // schema field names are in order of first occurrence
+ // (...so ensure your partitions are ordered the same for all paths)
+ AssertInspect({"/alpha=0/beta=1", "/beta=2/alpha=3"}, {Int("alpha"), Int("beta")});
+
+ // Null fallback strings shouldn't interfere with type inference
+ AssertInspect({"/alpha=xyz/beta=x", "/alpha=7/beta=xyz"}, {Int("alpha"), Str("beta")});
+
+ // Cannot infer if the only values are null
+ AssertInspectError({"/alpha=xyz"});
+
+ // If there are too many digits fall back to string
+ AssertInspect({"/alpha=3760212050"}, {Str("alpha")});
+
+ // missing path segments will not cause an error
+ AssertInspect({"/alpha=0/beta=1", "/beta=2/alpha=3", "/gamma=what"},
+ {Int("alpha"), Int("beta"), Str("gamma")});
+}
+
+TEST_F(TestPartitioning, HiveDictionaryInference) {
+ HivePartitioningFactoryOptions options;
+ options.infer_dictionary = true;
+ options.null_fallback = "xyz";
+ factory_ = HivePartitioning::MakeFactory(options);
+
+ // type is still int32 if possible
+ AssertInspect({"/alpha=0/beta=1"}, {DictInt("alpha"), DictInt("beta")});
+
+ // If there are too many digits fall back to string
+ AssertInspect({"/alpha=3760212050"}, {DictStr("alpha")});
+
+ // successful dictionary inference
+ AssertInspect({"/alpha=a/beta=0"}, {DictStr("alpha"), DictInt("beta")});
+ AssertInspect({"/alpha=a/beta=0", "/alpha=a/1"}, {DictStr("alpha"), DictInt("beta")});
+ AssertInspect({"/alpha=a/beta=0", "/alpha=xyz/beta=xyz"},
+ {DictStr("alpha"), DictInt("beta")});
+ AssertInspect(
+ {"/alpha=a/beta=0", "/alpha=b/beta=0", "/alpha=a/beta=1", "/alpha=b/beta=1"},
+ {DictStr("alpha"), DictInt("beta")});
+ AssertInspect(
+ {"/alpha=a/beta=-", "/alpha=b/beta=-", "/alpha=a/beta=_", "/alpha=b/beta=_"},
+ {DictStr("alpha"), DictStr("beta")});
+}
+
+TEST_F(TestPartitioning, HiveNullFallbackPassedOn) {
+ HivePartitioningFactoryOptions options;
+ options.null_fallback = "xyz";
+ factory_ = HivePartitioning::MakeFactory(options);
+
+ EXPECT_OK_AND_ASSIGN(auto schema, factory_->Inspect({"/alpha=a/beta=0"}));
+ EXPECT_OK_AND_ASSIGN(auto partitioning, factory_->Finish(schema));
+ ASSERT_EQ("xyz",
+ std::static_pointer_cast<HivePartitioning>(partitioning)->null_fallback());
+}
+
+TEST_F(TestPartitioning, HiveDictionaryHasUniqueValues) {
+ HivePartitioningFactoryOptions options;
+ options.infer_dictionary = true;
+ factory_ = HivePartitioning::MakeFactory(options);
+
+ auto alpha = DictStr("alpha");
+ AssertInspect({"/alpha=a", "/alpha=b", "/alpha=a", "/alpha=b", "/alpha=c", "/alpha=a"},
+ {alpha});
+ ASSERT_OK_AND_ASSIGN(auto partitioning, factory_->Finish(schema({alpha})));
+
+ auto expected_dictionary =
+ checked_pointer_cast<StringArray>(ArrayFromJSON(utf8(), R"(["a", "b", "c"])"));
+
+ for (int32_t i = 0; i < expected_dictionary->length(); ++i) {
+ DictionaryScalar::ValueType index_and_dictionary{std::make_shared<Int32Scalar>(i),
+ expected_dictionary};
+ auto dictionary_scalar =
+ std::make_shared<DictionaryScalar>(index_and_dictionary, alpha->type());
+
+ auto path = "/alpha=" + expected_dictionary->GetString(i);
+ AssertParse(path, equal(field_ref("alpha"), literal(dictionary_scalar)));
+ }
+
+ AssertParseError("/alpha=yosemite"); // not in inspected dictionary
+}
+
+TEST_F(TestPartitioning, ExistingSchemaDirectory) {
+ // Infer dictionary values but with a given schema
+ auto dict_type = dictionary(int8(), utf8());
+ PartitioningFactoryOptions options;
+ options.schema = schema({field("alpha", int64()), field("beta", dict_type)});
+ factory_ = DirectoryPartitioning::MakeFactory({"alpha", "beta"}, options);
+
+ AssertInspect({"/0/1"}, options.schema->fields());
+ AssertInspect({"/0/1/what"}, options.schema->fields());
+
+ // fail if any segment is not parseable as schema type
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Failed to parse string"),
+ factory_->Inspect({"/0/1", "/hello/1"}));
+ factory_ = DirectoryPartitioning::MakeFactory({"alpha", "beta"}, options);
+
+ // Now we don't fail since our type is large enough
+ AssertInspect({"/3760212050/1"}, options.schema->fields());
+ // If there are still too many digits, fail
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Failed to parse string"),
+ factory_->Inspect({"/1038581385102940193760212050/1"}));
+ factory_ = DirectoryPartitioning::MakeFactory({"alpha", "beta"}, options);
+
+ AssertInspect({"/0/1", "/2"}, options.schema->fields());
+}
+
+TEST_F(TestPartitioning, ExistingSchemaHive) {
+ // Infer dictionary values but with a given schema
+ auto dict_type = dictionary(int8(), utf8());
+ HivePartitioningFactoryOptions options;
+ options.schema = schema({field("a", int64()), field("b", dict_type)});
+ factory_ = HivePartitioning::MakeFactory(options);
+
+ AssertInspect({"/a=0/b=1"}, options.schema->fields());
+ AssertInspect({"/a=0/b=1/what"}, options.schema->fields());
+ AssertInspect({"/a=0", "/b=1"}, options.schema->fields());
+
+ // fail if any segment for field alpha is not parseable as schema type
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr(
+ "Could not cast segments for partition field a to requested type int64"),
+ factory_->Inspect({"/a=0/b=1", "/a=hello/b=1"}));
+ factory_ = HivePartitioning::MakeFactory(options);
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr("Requested schema has 2 fields, but only 1 were detected"),
+ factory_->Inspect({"/a=0", "/a=hello"}));
+ factory_ = HivePartitioning::MakeFactory(options);
+
+ // Now we don't fail since our type is large enough
+ AssertInspect({"/a=3760212050/b=1"}, options.schema->fields());
+ // If there are still too many digits, fail
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Failed to parse string"),
+ factory_->Inspect({"/a=1038581385102940193760212050/b=1"}));
+ factory_ = HivePartitioning::MakeFactory(options);
+
+ AssertInspect({"/a=0/b=1", "/b=2"}, options.schema->fields());
+}
+
+TEST_F(TestPartitioning, UrlEncodedDirectory) {
+ PartitioningFactoryOptions options;
+ auto ts = timestamp(TimeUnit::type::SECOND);
+ options.schema = schema({field("date", ts), field("time", ts), field("str", utf8())});
+ factory_ = DirectoryPartitioning::MakeFactory(options.schema->field_names(), options);
+
+ AssertInspect({"/2021-05-04 00:00:00/2021-05-04 07:27:00/%24",
+ "/2021-05-04 00%3A00%3A00/2021-05-04 07%3A27%3A00/foo"},
+ options.schema->fields());
+ auto date = std::make_shared<TimestampScalar>(1620086400, ts);
+ auto time = std::make_shared<TimestampScalar>(1620113220, ts);
+ partitioning_ = std::make_shared<DirectoryPartitioning>(options.schema, ArrayVector());
+ AssertParse("/2021-05-04 00%3A00%3A00/2021-05-04 07%3A27%3A00/%24",
+ and_({equal(field_ref("date"), literal(date)),
+ equal(field_ref("time"), literal(time)),
+ equal(field_ref("str"), literal("$"))}));
+
+ // Invalid UTF-8
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("was not valid UTF-8"),
+ factory_->Inspect({"/%AF/%BF/%CF"}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("was not valid UTF-8"),
+ partitioning_->Parse({"/%AF/%BF/%CF"}));
+
+ options.segment_encoding = SegmentEncoding::None;
+ options.schema =
+ schema({field("date", utf8()), field("time", utf8()), field("str", utf8())});
+ factory_ = DirectoryPartitioning::MakeFactory(options.schema->field_names(), options);
+ AssertInspect({"/2021-05-04 00:00:00/2021-05-04 07:27:00/%E3%81%8F%E3%81%BE",
+ "/2021-05-04 00%3A00%3A00/2021-05-04 07%3A27%3A00/foo"},
+ options.schema->fields());
+ partitioning_ = std::make_shared<DirectoryPartitioning>(
+ options.schema, ArrayVector(), options.AsPartitioningOptions());
+ AssertParse("/2021-05-04 00%3A00%3A00/2021-05-04 07%3A27%3A00/%24",
+ and_({equal(field_ref("date"), literal("2021-05-04 00%3A00%3A00")),
+ equal(field_ref("time"), literal("2021-05-04 07%3A27%3A00")),
+ equal(field_ref("str"), literal("%24"))}));
+}
+
+TEST_F(TestPartitioning, UrlEncodedHive) {
+ HivePartitioningFactoryOptions options;
+ auto ts = timestamp(TimeUnit::type::SECOND);
+ options.schema = schema({field("date", ts), field("time", ts), field("str", utf8())});
+ options.null_fallback = "$";
+ factory_ = HivePartitioning::MakeFactory(options);
+
+ AssertInspect(
+ {"/date=2021-05-04 00:00:00/time=2021-05-04 07:27:00/str=$",
+ "/date=2021-05-04 00:00:00/time=2021-05-04 07:27:00/str=%E3%81%8F%E3%81%BE",
+ "/date=2021-05-04 00%3A00%3A00/time=2021-05-04 07%3A27%3A00/str=%24"},
+ options.schema->fields());
+
+ auto date = std::make_shared<TimestampScalar>(1620086400, ts);
+ auto time = std::make_shared<TimestampScalar>(1620113220, ts);
+ partitioning_ = std::make_shared<HivePartitioning>(options.schema, ArrayVector(),
+ options.AsHivePartitioningOptions());
+ AssertParse("/date=2021-05-04 00:00:00/time=2021-05-04 07:27:00/str=$",
+ and_({equal(field_ref("date"), literal(date)),
+ equal(field_ref("time"), literal(time)), is_null(field_ref("str"))}));
+ AssertParse("/date=2021-05-04 00:00:00/time=2021-05-04 07:27:00/str=%E3%81%8F%E3%81%BE",
+ and_({equal(field_ref("date"), literal(date)),
+ equal(field_ref("time"), literal(time)),
+ equal(field_ref("str"), literal("\xE3\x81\x8F\xE3\x81\xBE"))}));
+ // URL-encoded null fallback value
+ AssertParse("/date=2021-05-04 00%3A00%3A00/time=2021-05-04 07%3A27%3A00/str=%24",
+ and_({equal(field_ref("date"), literal(date)),
+ equal(field_ref("time"), literal(time)), is_null(field_ref("str"))}));
+
+ // Invalid UTF-8
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("was not valid UTF-8"),
+ factory_->Inspect({"/date=%AF/time=%BF/str=%CF"}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("was not valid UTF-8"),
+ partitioning_->Parse({"/date=%AF/time=%BF/str=%CF"}));
+
+ options.segment_encoding = SegmentEncoding::None;
+ options.schema =
+ schema({field("date", utf8()), field("time", utf8()), field("str", utf8())});
+ factory_ = HivePartitioning::MakeFactory(options);
+ AssertInspect(
+ {"/date=2021-05-04 00:00:00/time=2021-05-04 07:27:00/str=$",
+ "/date=2021-05-04 00:00:00/time=2021-05-04 07:27:00/str=%E3%81%8F%E3%81%BE",
+ "/date=2021-05-04 00%3A00%3A00/time=2021-05-04 07%3A27%3A00/str=%24"},
+ options.schema->fields());
+ partitioning_ = std::make_shared<HivePartitioning>(options.schema, ArrayVector(),
+ options.AsHivePartitioningOptions());
+ AssertParse("/date=2021-05-04 00%3A00%3A00/time=2021-05-04 07%3A27%3A00/str=%24",
+ and_({equal(field_ref("date"), literal("2021-05-04 00%3A00%3A00")),
+ equal(field_ref("time"), literal("2021-05-04 07%3A27%3A00")),
+ equal(field_ref("str"), literal("%24"))}));
+
+ // Invalid UTF-8
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("was not valid UTF-8"),
+ factory_->Inspect({"/date=\xAF/time=\xBF/str=\xCF"}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("was not valid UTF-8"),
+ partitioning_->Parse({"/date=\xAF/time=\xBF/str=\xCF"}));
+}
+
+TEST_F(TestPartitioning, EtlThenHive) {
+ FieldVector etl_fields{field("year", int16()), field("month", int8()),
+ field("day", int8()), field("hour", int8())};
+ DirectoryPartitioning etl_part(schema(etl_fields));
+
+ FieldVector alphabeta_fields{field("alpha", int32()), field("beta", float32())};
+ HivePartitioning alphabeta_part(schema(alphabeta_fields));
+
+ auto schm =
+ schema({field("year", int16()), field("month", int8()), field("day", int8()),
+ field("hour", int8()), field("alpha", int32()), field("beta", float32())});
+
+ partitioning_ = std::make_shared<FunctionPartitioning>(
+ schm, [&](const std::string& path) -> Result<compute::Expression> {
+ auto segments = fs::internal::SplitAbstractPath(path);
+ if (segments.size() < etl_fields.size() + alphabeta_fields.size()) {
+ return Status::Invalid("path ", path, " can't be parsed");
+ }
+
+ auto etl_segments_end = segments.begin() + etl_fields.size();
+ auto etl_path =
+ fs::internal::JoinAbstractPath(segments.begin(), etl_segments_end);
+ ARROW_ASSIGN_OR_RAISE(auto etl_expr, etl_part.Parse(etl_path));
+
+ auto alphabeta_segments_end = etl_segments_end + alphabeta_fields.size();
+ auto alphabeta_path =
+ fs::internal::JoinAbstractPath(etl_segments_end, alphabeta_segments_end);
+ ARROW_ASSIGN_OR_RAISE(auto alphabeta_expr, alphabeta_part.Parse(alphabeta_path));
+
+ return and_(etl_expr, alphabeta_expr);
+ });
+
+ AssertParse("/1999/12/31/00/alpha=0/beta=3.25",
+ and_({equal(field_ref("year"), literal<int16_t>(1999)),
+ equal(field_ref("month"), literal<int8_t>(12)),
+ equal(field_ref("day"), literal<int8_t>(31)),
+ equal(field_ref("hour"), literal<int8_t>(0)),
+ and_(equal(field_ref("alpha"), literal<int32_t>(0)),
+ equal(field_ref("beta"), literal<float>(3.25f)))}));
+
+ AssertParseError("/20X6/03/21/05/alpha=0/beta=3.25");
+}
+
+TEST_F(TestPartitioning, Set) {
+ auto ints = [](std::vector<int32_t> ints) {
+ std::shared_ptr<Array> out;
+ ArrayFromVector<Int32Type>(ints, &out);
+ return out;
+ };
+
+ auto schm = schema({field("x", int32())});
+
+ // An adhoc partitioning which parses segments like "/x in [1 4 5]"
+ // into (field_ref("x") == 1 or field_ref("x") == 4 or field_ref("x") == 5)
+ partitioning_ = std::make_shared<FunctionPartitioning>(
+ schm, [&](const std::string& path) -> Result<compute::Expression> {
+ std::vector<compute::Expression> subexpressions;
+ for (auto segment : fs::internal::SplitAbstractPath(path)) {
+ std::smatch matches;
+
+ static std::regex re(R"(^(\S+) in \[(.*)\]$)");
+ if (!std::regex_match(segment, matches, re) || matches.size() != 3) {
+ return Status::Invalid("regex failed to parse");
+ }
+
+ std::vector<int32_t> set;
+ std::istringstream elements(matches[2]);
+ for (std::string element; elements >> element;) {
+ ARROW_ASSIGN_OR_RAISE(auto s, Scalar::Parse(int32(), element));
+ set.push_back(checked_cast<const Int32Scalar&>(*s).value);
+ }
+
+ subexpressions.push_back(call("is_in", {field_ref(std::string(matches[1]))},
+ compute::SetLookupOptions{ints(set)}));
+ }
+ return and_(std::move(subexpressions));
+ });
+
+ auto x_in = [&](std::vector<int32_t> set) {
+ return call("is_in", {field_ref("x")}, compute::SetLookupOptions{ints(set)});
+ };
+ AssertParse("/x in [1]", x_in({1}));
+ AssertParse("/x in [1 4 5]", x_in({1, 4, 5}));
+ AssertParse("/x in []", x_in({}));
+}
+
+// An adhoc partitioning which parses segments like "/x=[-3.25, 0.0)"
+// into (field_ref("x") >= -3.25 and "x" < 0.0)
+class RangePartitioning : public Partitioning {
+ public:
+ explicit RangePartitioning(std::shared_ptr<Schema> s) : Partitioning(std::move(s)) {}
+
+ std::string type_name() const override { return "range"; }
+
+ Result<compute::Expression> Parse(const std::string& path) const override {
+ std::vector<compute::Expression> ranges;
+
+ HivePartitioningOptions options;
+ for (auto segment : fs::internal::SplitAbstractPath(path)) {
+ ARROW_ASSIGN_OR_RAISE(auto key, HivePartitioning::ParseKey(segment, options));
+ if (!key) {
+ return Status::Invalid("can't parse '", segment, "' as a range");
+ }
+
+ std::smatch matches;
+ RETURN_NOT_OK(DoRegex(*key->value, &matches));
+
+ auto& min_cmp = matches[1] == "[" ? greater_equal : greater;
+ std::string min_repr = matches[2];
+ std::string max_repr = matches[3];
+ auto& max_cmp = matches[4] == "]" ? less_equal : less;
+
+ const auto& type = schema_->GetFieldByName(key->name)->type();
+ ARROW_ASSIGN_OR_RAISE(auto min, Scalar::Parse(type, min_repr));
+ ARROW_ASSIGN_OR_RAISE(auto max, Scalar::Parse(type, max_repr));
+
+ ranges.push_back(and_(min_cmp(field_ref(key->name), literal(min)),
+ max_cmp(field_ref(key->name), literal(max))));
+ }
+
+ return and_(ranges);
+ }
+
+ static Status DoRegex(const std::string& segment, std::smatch* matches) {
+ static std::regex re(
+ "^"
+ "([\\[\\(])" // open bracket or paren
+ "([^ ]+)" // representation of range minimum
+ " "
+ "([^ ]+)" // representation of range maximum
+ "([\\]\\)])" // close bracket or paren
+ "$");
+
+ if (!std::regex_match(segment, *matches, re) || matches->size() != 5) {
+ return Status::Invalid("regex failed to parse");
+ }
+
+ return Status::OK();
+ }
+
+ Result<std::string> Format(const compute::Expression&) const override { return ""; }
+ Result<PartitionedBatches> Partition(
+ const std::shared_ptr<RecordBatch>&) const override {
+ return Status::OK();
+ }
+};
+
+TEST_F(TestPartitioning, Range) {
+ partitioning_ = std::make_shared<RangePartitioning>(
+ schema({field("x", float64()), field("y", float64()), field("z", float64())}));
+
+ AssertParse("/x=[-1.5 0.0)/y=[0.0 1.5)/z=(1.5 3.0]",
+ and_({and_(greater_equal(field_ref("x"), literal(-1.5)),
+ less(field_ref("x"), literal(0.0))),
+ and_(greater_equal(field_ref("y"), literal(0.0)),
+ less(field_ref("y"), literal(1.5))),
+ and_(greater(field_ref("z"), literal(1.5)),
+ less_equal(field_ref("z"), literal(3.0)))}));
+}
+
+TEST(TestStripPrefixAndFilename, Basic) {
+ ASSERT_EQ(StripPrefixAndFilename("", ""), "");
+ ASSERT_EQ(StripPrefixAndFilename("a.csv", ""), "");
+ ASSERT_EQ(StripPrefixAndFilename("a/b.csv", ""), "a");
+ ASSERT_EQ(StripPrefixAndFilename("/a/b/c.csv", "/a"), "b");
+ ASSERT_EQ(StripPrefixAndFilename("/a/b/c/d.csv", "/a"), "b/c");
+ ASSERT_EQ(StripPrefixAndFilename("/a/b/c.csv", "/a/b"), "");
+
+ std::vector<std::string> input{"/data/year=2019/file.parquet",
+ "/data/year=2019/month=12/file.parquet",
+ "/data/year=2019/month=12/day=01/file.parquet"};
+ EXPECT_THAT(StripPrefixAndFilename(input, "/data"),
+ testing::ElementsAre("year=2019", "year=2019/month=12",
+ "year=2019/month=12/day=01"));
+}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/pch.h b/src/arrow/cpp/src/arrow/dataset/pch.h
new file mode 100644
index 000000000..a74fd96e3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/pch.h
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Often-used headers, for precompiling.
+// If updating this header, please make sure you check compilation speed
+// before checking in. Adding headers which are not used extremely often
+// may incur a slowdown, since it makes the precompiled header heavier to load.
+
+// This API is EXPERIMENTAL.
+
+#include "arrow/dataset/dataset.h"
+#include "arrow/dataset/scanner.h"
+#include "arrow/pch.h"
diff --git a/src/arrow/cpp/src/arrow/dataset/plan.cc b/src/arrow/cpp/src/arrow/dataset/plan.cc
new file mode 100644
index 000000000..9b222ff57
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/plan.cc
@@ -0,0 +1,39 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/plan.h"
+
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/scanner.h"
+
+namespace arrow {
+namespace dataset {
+namespace internal {
+
+void Initialize() {
+ static auto registry = compute::default_exec_factory_registry();
+ if (registry) {
+ InitializeScanner(registry);
+ InitializeDatasetWriter(registry);
+ registry = nullptr;
+ }
+}
+
+} // namespace internal
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/plan.h b/src/arrow/cpp/src/arrow/dataset/plan.h
new file mode 100644
index 000000000..10260ccec
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/plan.h
@@ -0,0 +1,33 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#include "arrow/dataset/visibility.h"
+
+namespace arrow {
+namespace dataset {
+namespace internal {
+
+/// Register dataset-based exec nodes with the exec node registry
+///
+/// This function must be called before using dataset ExecNode factories
+ARROW_DS_EXPORT void Initialize();
+
+} // namespace internal
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/projector.cc b/src/arrow/cpp/src/arrow/dataset/projector.cc
new file mode 100644
index 000000000..b2196a874
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/projector.cc
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/projector.h"
+
+#include "arrow/status.h"
+#include "arrow/type.h"
+
+namespace arrow {
+namespace dataset {
+
+Status CheckProjectable(const Schema& from, const Schema& to) {
+ for (const auto& to_field : to.fields()) {
+ ARROW_ASSIGN_OR_RAISE(auto from_field, FieldRef(to_field->name()).GetOneOrNone(from));
+
+ if (from_field == nullptr) {
+ if (to_field->nullable()) continue;
+
+ return Status::TypeError("field ", to_field->ToString(),
+ " is not nullable and does not exist in origin schema ",
+ from);
+ }
+
+ if (from_field->type()->id() == Type::NA) {
+ // promotion from null to any type is supported
+ if (to_field->nullable()) continue;
+
+ return Status::TypeError("field ", to_field->ToString(),
+ " is not nullable but has type ", NullType(),
+ " in origin schema ", from);
+ }
+
+ if (!from_field->type()->Equals(to_field->type())) {
+ return Status::TypeError("fields had matching names but differing types. From: ",
+ from_field->ToString(), " To: ", to_field->ToString());
+ }
+
+ if (from_field->nullable() && !to_field->nullable()) {
+ return Status::TypeError("field ", to_field->ToString(),
+ " is not nullable but is not required in origin schema ",
+ from);
+ }
+ }
+
+ return Status::OK();
+}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/projector.h b/src/arrow/cpp/src/arrow/dataset/projector.h
new file mode 100644
index 000000000..86d38f0af
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/projector.h
@@ -0,0 +1,32 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include "arrow/dataset/visibility.h"
+#include "arrow/type_fwd.h"
+
+namespace arrow {
+namespace dataset {
+
+// FIXME this is superceded by compute::Expression::Bind
+ARROW_DS_EXPORT Status CheckProjectable(const Schema& from, const Schema& to);
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/scanner.cc b/src/arrow/cpp/src/arrow/dataset/scanner.cc
new file mode 100644
index 000000000..23942ec37
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/scanner.cc
@@ -0,0 +1,1347 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/scanner.h"
+
+#include <algorithm>
+#include <condition_variable>
+#include <memory>
+#include <mutex>
+#include <sstream>
+
+#include "arrow/array/array_primitive.h"
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/dataset/dataset.h"
+#include "arrow/dataset/dataset_internal.h"
+#include "arrow/dataset/plan.h"
+#include "arrow/dataset/scanner_internal.h"
+#include "arrow/table.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/task_group.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::Executor;
+using internal::SerialExecutor;
+using internal::TaskGroup;
+
+namespace dataset {
+
+using FragmentGenerator = std::function<Future<std::shared_ptr<Fragment>>()>;
+
+std::vector<std::string> ScanOptions::MaterializedFields() const {
+ std::vector<std::string> fields;
+
+ for (const compute::Expression* expr : {&filter, &projection}) {
+ for (const FieldRef& ref : FieldsInExpression(*expr)) {
+ DCHECK(ref.name());
+ fields.push_back(*ref.name());
+ }
+ }
+
+ return fields;
+}
+
+std::shared_ptr<TaskGroup> ScanOptions::TaskGroup() const {
+ if (use_threads) {
+ auto* thread_pool = arrow::internal::GetCpuThreadPool();
+ return TaskGroup::MakeThreaded(thread_pool);
+ }
+ return TaskGroup::MakeSerial();
+}
+
+Result<RecordBatchIterator> InMemoryScanTask::Execute() {
+ return MakeVectorIterator(record_batches_);
+}
+
+Future<RecordBatchVector> ScanTask::SafeExecute(Executor* executor) {
+ // If the ScanTask can't possibly be async then just execute it
+ ARROW_ASSIGN_OR_RAISE(auto rb_it, Execute());
+ return Future<RecordBatchVector>::MakeFinished(rb_it.ToVector());
+}
+
+Future<> ScanTask::SafeVisit(
+ Executor* executor, std::function<Status(std::shared_ptr<RecordBatch>)> visitor) {
+ // If the ScanTask can't possibly be async then just execute it
+ ARROW_ASSIGN_OR_RAISE(auto rb_it, Execute());
+ return Future<>::MakeFinished(rb_it.Visit(visitor));
+}
+
+Result<ScanTaskIterator> Scanner::Scan() {
+ // TODO(ARROW-12289) This is overridden in SyncScanner and will never be implemented in
+ // AsyncScanner. It is deprecated and will eventually go away.
+ return Status::NotImplemented("This scanner does not support the legacy Scan() method");
+}
+
+Result<EnumeratedRecordBatchIterator> Scanner::ScanBatchesUnordered() {
+ // If a scanner doesn't support unordered scanning (i.e. SyncScanner) then we just
+ // fall back to an ordered scan and assign the appropriate tagging
+ ARROW_ASSIGN_OR_RAISE(auto ordered_scan, ScanBatches());
+ return AddPositioningToInOrderScan(std::move(ordered_scan));
+}
+
+Result<EnumeratedRecordBatchIterator> Scanner::AddPositioningToInOrderScan(
+ TaggedRecordBatchIterator scan) {
+ ARROW_ASSIGN_OR_RAISE(auto first, scan.Next());
+ if (IsIterationEnd(first)) {
+ return MakeEmptyIterator<EnumeratedRecordBatch>();
+ }
+ struct State {
+ State(TaggedRecordBatchIterator source, TaggedRecordBatch first)
+ : source(std::move(source)),
+ batch_index(0),
+ fragment_index(0),
+ finished(false),
+ prev_batch(std::move(first)) {}
+ TaggedRecordBatchIterator source;
+ int batch_index;
+ int fragment_index;
+ bool finished;
+ TaggedRecordBatch prev_batch;
+ };
+ struct EnumeratingIterator {
+ Result<EnumeratedRecordBatch> Next() {
+ if (state->finished) {
+ return IterationEnd<EnumeratedRecordBatch>();
+ }
+ ARROW_ASSIGN_OR_RAISE(auto next, state->source.Next());
+ if (IsIterationEnd<TaggedRecordBatch>(next)) {
+ state->finished = true;
+ return EnumeratedRecordBatch{
+ {std::move(state->prev_batch.record_batch), state->batch_index, true},
+ {std::move(state->prev_batch.fragment), state->fragment_index, true}};
+ }
+ auto prev = std::move(state->prev_batch);
+ bool prev_is_last_batch = false;
+ auto prev_batch_index = state->batch_index;
+ auto prev_fragment_index = state->fragment_index;
+ // Reference equality here seems risky but a dataset should have a constant set of
+ // fragments which should be consistent for the lifetime of a scan
+ if (prev.fragment.get() != next.fragment.get()) {
+ state->batch_index = 0;
+ state->fragment_index++;
+ prev_is_last_batch = true;
+ } else {
+ state->batch_index++;
+ }
+ state->prev_batch = std::move(next);
+ return EnumeratedRecordBatch{
+ {std::move(prev.record_batch), prev_batch_index, prev_is_last_batch},
+ {std::move(prev.fragment), prev_fragment_index, false}};
+ }
+ std::shared_ptr<State> state;
+ };
+ return EnumeratedRecordBatchIterator(
+ EnumeratingIterator{std::make_shared<State>(std::move(scan), std::move(first))});
+}
+
+Result<int64_t> Scanner::CountRows() {
+ // Naive base implementation
+ ARROW_ASSIGN_OR_RAISE(auto batch_it, ScanBatchesUnordered());
+ int64_t count = 0;
+ RETURN_NOT_OK(batch_it.Visit([&](EnumeratedRecordBatch batch) {
+ count += batch.record_batch.value->num_rows();
+ return Status::OK();
+ }));
+ return count;
+}
+
+namespace {
+class ScannerRecordBatchReader : public RecordBatchReader {
+ public:
+ explicit ScannerRecordBatchReader(std::shared_ptr<Schema> schema,
+ TaggedRecordBatchIterator delegate)
+ : schema_(std::move(schema)), delegate_(std::move(delegate)) {}
+
+ std::shared_ptr<Schema> schema() const override { return schema_; }
+ Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
+ ARROW_ASSIGN_OR_RAISE(auto next, delegate_.Next());
+ if (IsIterationEnd(next)) {
+ *batch = nullptr;
+ } else {
+ *batch = std::move(next.record_batch);
+ }
+ return Status::OK();
+ }
+
+ private:
+ std::shared_ptr<Schema> schema_;
+ TaggedRecordBatchIterator delegate_;
+};
+} // namespace
+
+Result<std::shared_ptr<RecordBatchReader>> Scanner::ToRecordBatchReader() {
+ ARROW_ASSIGN_OR_RAISE(auto it, ScanBatches());
+ return std::make_shared<ScannerRecordBatchReader>(options()->projected_schema,
+ std::move(it));
+}
+
+namespace {
+
+struct ScanBatchesState : public std::enable_shared_from_this<ScanBatchesState> {
+ explicit ScanBatchesState(ScanTaskIterator scan_task_it,
+ std::shared_ptr<TaskGroup> task_group_)
+ : scan_tasks(std::move(scan_task_it)), task_group(std::move(task_group_)) {}
+
+ void ResizeBatches(size_t task_index) {
+ if (task_batches.size() <= task_index) {
+ task_batches.resize(task_index + 1);
+ task_drained.resize(task_index + 1);
+ }
+ }
+
+ void Push(TaggedRecordBatch batch, size_t task_index) {
+ {
+ std::lock_guard<std::mutex> lock(mutex);
+ ResizeBatches(task_index);
+ task_batches[task_index].push_back(std::move(batch));
+ }
+ ready.notify_one();
+ }
+
+ template <typename T>
+ Result<T> PushError(Result<T>&& result, size_t task_index) {
+ if (!result.ok()) {
+ {
+ std::lock_guard<std::mutex> lock(mutex);
+ task_drained[task_index] = true;
+ iteration_error = result.status();
+ }
+ ready.notify_one();
+ }
+ return std::move(result);
+ }
+
+ Status Finish(size_t task_index) {
+ {
+ std::lock_guard<std::mutex> lock(mutex);
+ ResizeBatches(task_index);
+ task_drained[task_index] = true;
+ }
+ ready.notify_one();
+ return Status::OK();
+ }
+
+ void PushScanTask() {
+ if (no_more_tasks) {
+ return;
+ }
+ std::unique_lock<std::mutex> lock(mutex);
+ auto maybe_task = scan_tasks.Next();
+ if (!maybe_task.ok()) {
+ no_more_tasks = true;
+ iteration_error = maybe_task.status();
+ return;
+ }
+ auto scan_task = maybe_task.ValueOrDie();
+ if (IsIterationEnd(scan_task)) {
+ no_more_tasks = true;
+ return;
+ }
+ auto state = shared_from_this();
+ auto id = next_scan_task_id++;
+ ResizeBatches(id);
+
+ lock.unlock();
+ task_group->Append([state, id, scan_task]() {
+ // If we were to return an error to the task group, subsequent tasks
+ // may never be executed, which would produce a deadlock in Pop()
+ // (ARROW-13480).
+ auto status_unused = [&]() {
+ ARROW_ASSIGN_OR_RAISE(auto batch_it, state->PushError(scan_task->Execute(), id));
+ for (auto maybe_batch : batch_it) {
+ ARROW_ASSIGN_OR_RAISE(auto batch, state->PushError(std::move(maybe_batch), id));
+ state->Push(TaggedRecordBatch{std::move(batch), scan_task->fragment()}, id);
+ }
+ return state->Finish(id);
+ }();
+ return Status::OK();
+ });
+ }
+
+ Result<TaggedRecordBatch> Pop() {
+ std::unique_lock<std::mutex> lock(mutex);
+ ready.wait(lock, [this, &lock] {
+ while (pop_cursor < task_batches.size()) {
+ // queue for current scan task contains at least one batch, pop that
+ if (!task_batches[pop_cursor].empty()) return true;
+ // queue is empty but will be appended to eventually, wait for that
+ if (!task_drained[pop_cursor]) return false;
+
+ // Finished draining current scan task, enqueue a new one
+ ++pop_cursor;
+ // Must unlock since serial task group will execute synchronously
+ lock.unlock();
+ PushScanTask();
+ lock.lock();
+ }
+ DCHECK(no_more_tasks);
+ // all scan tasks drained (or getting next task failed), terminate
+ return true;
+ });
+
+ // We're not bubbling any task errors into the task group
+ DCHECK(task_group->ok());
+
+ if (pop_cursor == task_batches.size()) {
+ // Don't report an error until we yield up everything we can first
+ RETURN_NOT_OK(iteration_error);
+ return IterationEnd<TaggedRecordBatch>();
+ }
+
+ auto batch = std::move(task_batches[pop_cursor].front());
+ task_batches[pop_cursor].pop_front();
+ return batch;
+ }
+
+ /// Protecting mutating accesses to batches
+ std::mutex mutex;
+ std::condition_variable ready;
+ ScanTaskIterator scan_tasks;
+ std::shared_ptr<TaskGroup> task_group;
+ int next_scan_task_id = 0;
+ bool no_more_tasks = false;
+ Status iteration_error;
+ std::vector<std::deque<TaggedRecordBatch>> task_batches;
+ std::vector<bool> task_drained;
+ size_t pop_cursor = 0;
+};
+
+class SyncScanner : public Scanner {
+ public:
+ SyncScanner(std::shared_ptr<Dataset> dataset, std::shared_ptr<ScanOptions> scan_options)
+ : Scanner(std::move(scan_options)), dataset_(std::move(dataset)) {}
+
+ Result<TaggedRecordBatchIterator> ScanBatches() override;
+ Result<ScanTaskIterator> Scan() override;
+ Status Scan(std::function<Status(TaggedRecordBatch)> visitor) override;
+ Result<std::shared_ptr<Table>> ToTable() override;
+ Result<TaggedRecordBatchGenerator> ScanBatchesAsync() override;
+ Result<EnumeratedRecordBatchGenerator> ScanBatchesUnorderedAsync() override;
+ Result<int64_t> CountRows() override;
+ const std::shared_ptr<Dataset>& dataset() const override;
+
+ protected:
+ /// \brief GetFragments returns an iterator over all Fragments in this scan.
+ Result<FragmentIterator> GetFragments();
+ Result<TaggedRecordBatchIterator> ScanBatches(ScanTaskIterator scan_task_it);
+ Future<std::shared_ptr<Table>> ToTableInternal(Executor* cpu_executor);
+ Result<ScanTaskIterator> ScanInternal();
+
+ std::shared_ptr<Dataset> dataset_;
+};
+
+Result<TaggedRecordBatchIterator> SyncScanner::ScanBatches() {
+ ARROW_ASSIGN_OR_RAISE(auto scan_task_it, ScanInternal());
+ return ScanBatches(std::move(scan_task_it));
+}
+
+Result<TaggedRecordBatchIterator> SyncScanner::ScanBatches(
+ ScanTaskIterator scan_task_it) {
+ auto task_group = scan_options_->TaskGroup();
+ auto state = std::make_shared<ScanBatchesState>(std::move(scan_task_it), task_group);
+ for (int i = 0; i < scan_options_->fragment_readahead; i++) {
+ state->PushScanTask();
+ }
+ return MakeFunctionIterator([task_group, state]() -> Result<TaggedRecordBatch> {
+ ARROW_ASSIGN_OR_RAISE(auto batch, state->Pop());
+ if (!IsIterationEnd(batch)) return batch;
+ RETURN_NOT_OK(task_group->Finish());
+ return IterationEnd<TaggedRecordBatch>();
+ });
+}
+
+Result<TaggedRecordBatchGenerator> SyncScanner::ScanBatchesAsync() {
+ return Status::NotImplemented("Asynchronous scanning is not supported by SyncScanner");
+}
+
+Result<EnumeratedRecordBatchGenerator> SyncScanner::ScanBatchesUnorderedAsync() {
+ return Status::NotImplemented("Asynchronous scanning is not supported by SyncScanner");
+}
+
+Result<FragmentIterator> SyncScanner::GetFragments() {
+ // Transform Datasets in a flat Iterator<Fragment>. This
+ // iterator is lazily constructed, i.e. Dataset::GetFragments is
+ // not invoked until a Fragment is requested.
+ return GetFragmentsFromDatasets({dataset_}, scan_options_->filter);
+}
+
+Result<ScanTaskIterator> SyncScanner::Scan() { return ScanInternal(); }
+
+Status SyncScanner::Scan(std::function<Status(TaggedRecordBatch)> visitor) {
+ ARROW_ASSIGN_OR_RAISE(auto scan_task_it, ScanInternal());
+
+ auto task_group = scan_options_->TaskGroup();
+
+ for (auto maybe_scan_task : scan_task_it) {
+ ARROW_ASSIGN_OR_RAISE(auto scan_task, maybe_scan_task);
+ task_group->Append([scan_task, visitor] {
+ ARROW_ASSIGN_OR_RAISE(auto batch_it, scan_task->Execute());
+ for (auto maybe_batch : batch_it) {
+ ARROW_ASSIGN_OR_RAISE(auto batch, maybe_batch);
+ RETURN_NOT_OK(
+ visitor(TaggedRecordBatch{std::move(batch), scan_task->fragment()}));
+ }
+ return Status::OK();
+ });
+ }
+
+ return task_group->Finish();
+}
+
+Result<ScanTaskIterator> SyncScanner::ScanInternal() {
+ // Transforms Iterator<Fragment> into a unified
+ // Iterator<ScanTask>. The first Iterator::Next invocation is going to do
+ // all the work of unwinding the chained iterators.
+ ARROW_ASSIGN_OR_RAISE(auto fragment_it, GetFragments());
+ return GetScanTaskIterator(std::move(fragment_it), scan_options_);
+}
+
+const std::shared_ptr<Dataset>& SyncScanner::dataset() const { return dataset_; }
+
+class AsyncScanner : public Scanner, public std::enable_shared_from_this<AsyncScanner> {
+ public:
+ AsyncScanner(std::shared_ptr<Dataset> dataset,
+ std::shared_ptr<ScanOptions> scan_options)
+ : Scanner(std::move(scan_options)), dataset_(std::move(dataset)) {
+ internal::Initialize();
+ }
+
+ Status Scan(std::function<Status(TaggedRecordBatch)> visitor) override;
+ Result<TaggedRecordBatchIterator> ScanBatches() override;
+ Result<TaggedRecordBatchGenerator> ScanBatchesAsync() override;
+ Result<EnumeratedRecordBatchIterator> ScanBatchesUnordered() override;
+ Result<EnumeratedRecordBatchGenerator> ScanBatchesUnorderedAsync() override;
+ Result<std::shared_ptr<Table>> ToTable() override;
+ Result<int64_t> CountRows() override;
+ const std::shared_ptr<Dataset>& dataset() const override;
+
+ private:
+ Result<TaggedRecordBatchGenerator> ScanBatchesAsync(Executor* executor);
+ Future<> VisitBatchesAsync(std::function<Status(TaggedRecordBatch)> visitor,
+ Executor* executor);
+ Result<EnumeratedRecordBatchGenerator> ScanBatchesUnorderedAsync(
+ Executor* executor, bool sequence_fragments = false);
+ Future<std::shared_ptr<Table>> ToTableAsync(Executor* executor);
+
+ Result<FragmentGenerator> GetFragments() const;
+
+ std::shared_ptr<Dataset> dataset_;
+};
+
+Result<EnumeratedRecordBatchGenerator> FragmentToBatches(
+ const Enumerated<std::shared_ptr<Fragment>>& fragment,
+ const std::shared_ptr<ScanOptions>& options) {
+ ARROW_ASSIGN_OR_RAISE(auto batch_gen, fragment.value->ScanBatchesAsync(options));
+ ArrayVector columns;
+ for (const auto& field : options->dataset_schema->fields()) {
+ // TODO(ARROW-7051): use helper to make empty batch
+ ARROW_ASSIGN_OR_RAISE(auto array,
+ MakeArrayOfNull(field->type(), /*length=*/0, options->pool));
+ columns.push_back(std::move(array));
+ }
+ batch_gen = MakeDefaultIfEmptyGenerator(
+ std::move(batch_gen),
+ RecordBatch::Make(options->dataset_schema, /*num_rows=*/0, std::move(columns)));
+ auto enumerated_batch_gen = MakeEnumeratedGenerator(std::move(batch_gen));
+
+ auto combine_fn =
+ [fragment](const Enumerated<std::shared_ptr<RecordBatch>>& record_batch) {
+ return EnumeratedRecordBatch{record_batch, fragment};
+ };
+
+ return MakeMappedGenerator(enumerated_batch_gen, std::move(combine_fn));
+}
+
+Result<AsyncGenerator<EnumeratedRecordBatchGenerator>> FragmentsToBatches(
+ FragmentGenerator fragment_gen, const std::shared_ptr<ScanOptions>& options) {
+ auto enumerated_fragment_gen = MakeEnumeratedGenerator(std::move(fragment_gen));
+ return MakeMappedGenerator(std::move(enumerated_fragment_gen),
+ [=](const Enumerated<std::shared_ptr<Fragment>>& fragment) {
+ return FragmentToBatches(fragment, options);
+ });
+}
+
+const FieldVector kAugmentedFields{
+ field("__fragment_index", int32()),
+ field("__batch_index", int32()),
+ field("__last_in_fragment", boolean()),
+};
+
+class OneShotScanTask : public ScanTask {
+ public:
+ OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
+ std::shared_ptr<Fragment> fragment)
+ : ScanTask(std::move(options), std::move(fragment)),
+ batch_it_(std::move(batch_it)) {}
+ Result<RecordBatchIterator> Execute() override {
+ if (!batch_it_) return Status::Invalid("OneShotScanTask was already scanned");
+ return std::move(batch_it_);
+ }
+
+ private:
+ RecordBatchIterator batch_it_;
+};
+
+class OneShotFragment : public Fragment {
+ public:
+ OneShotFragment(std::shared_ptr<Schema> schema, RecordBatchIterator batch_it)
+ : Fragment(compute::literal(true), std::move(schema)),
+ batch_it_(std::move(batch_it)) {
+ DCHECK_NE(physical_schema_, nullptr);
+ }
+ Status CheckConsumed() {
+ if (!batch_it_) return Status::Invalid("OneShotFragment was already scanned");
+ return Status::OK();
+ }
+ Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
+ RETURN_NOT_OK(CheckConsumed());
+ ScanTaskVector tasks{std::make_shared<OneShotScanTask>(
+ std::move(batch_it_), std::move(options), shared_from_this())};
+ return MakeVectorIterator(std::move(tasks));
+ }
+ Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options) override {
+ RETURN_NOT_OK(CheckConsumed());
+ ARROW_ASSIGN_OR_RAISE(
+ auto background_gen,
+ MakeBackgroundGenerator(std::move(batch_it_), options->io_context.executor()));
+ return MakeTransferredGenerator(std::move(background_gen),
+ ::arrow::internal::GetCpuThreadPool());
+ }
+ std::string type_name() const override { return "one-shot"; }
+
+ protected:
+ Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override {
+ return physical_schema_;
+ }
+
+ RecordBatchIterator batch_it_;
+};
+
+Result<FragmentGenerator> AsyncScanner::GetFragments() const {
+ // TODO(ARROW-8163): Async fragment scanning will return AsyncGenerator<Fragment>
+ // here. Current iterator based versions are all fast & sync so we will just ToVector
+ // it
+ ARROW_ASSIGN_OR_RAISE(auto fragments_it, dataset_->GetFragments(scan_options_->filter));
+ ARROW_ASSIGN_OR_RAISE(auto fragments_vec, fragments_it.ToVector());
+ return MakeVectorGenerator(std::move(fragments_vec));
+}
+
+Result<TaggedRecordBatchIterator> AsyncScanner::ScanBatches() {
+ ARROW_ASSIGN_OR_RAISE(auto batches_gen,
+ ScanBatchesAsync(::arrow::internal::GetCpuThreadPool()));
+ return MakeGeneratorIterator(std::move(batches_gen));
+}
+
+Result<EnumeratedRecordBatchIterator> AsyncScanner::ScanBatchesUnordered() {
+ ARROW_ASSIGN_OR_RAISE(auto batches_gen,
+ ScanBatchesUnorderedAsync(::arrow::internal::GetCpuThreadPool()));
+ return MakeGeneratorIterator(std::move(batches_gen));
+}
+
+Result<std::shared_ptr<Table>> AsyncScanner::ToTable() {
+ auto table_fut = ToTableAsync(::arrow::internal::GetCpuThreadPool());
+ return table_fut.result();
+}
+
+Result<EnumeratedRecordBatchGenerator> AsyncScanner::ScanBatchesUnorderedAsync() {
+ return ScanBatchesUnorderedAsync(::arrow::internal::GetCpuThreadPool());
+}
+
+Result<EnumeratedRecordBatch> ToEnumeratedRecordBatch(
+ const util::optional<compute::ExecBatch>& batch, const ScanOptions& options,
+ const FragmentVector& fragments) {
+ int num_fields = options.projected_schema->num_fields();
+
+ EnumeratedRecordBatch out;
+ out.fragment.index = batch->values[num_fields].scalar_as<Int32Scalar>().value;
+ out.fragment.last = false; // ignored during reordering
+ out.fragment.value = fragments[out.fragment.index];
+
+ out.record_batch.index = batch->values[num_fields + 1].scalar_as<Int32Scalar>().value;
+ out.record_batch.last = batch->values[num_fields + 2].scalar_as<BooleanScalar>().value;
+ ARROW_ASSIGN_OR_RAISE(out.record_batch.value,
+ batch->ToRecordBatch(options.projected_schema, options.pool));
+ return out;
+}
+
+Result<EnumeratedRecordBatchGenerator> AsyncScanner::ScanBatchesUnorderedAsync(
+ Executor* cpu_executor, bool sequence_fragments) {
+ if (!scan_options_->use_threads) {
+ cpu_executor = nullptr;
+ }
+
+ auto exec_context =
+ std::make_shared<compute::ExecContext>(scan_options_->pool, cpu_executor);
+
+ ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(exec_context.get()));
+ AsyncGenerator<util::optional<compute::ExecBatch>> sink_gen;
+
+ util::BackpressureOptions backpressure =
+ util::BackpressureOptions::Make(kDefaultBackpressureLow, kDefaultBackpressureHigh);
+ auto exprs = scan_options_->projection.call()->arguments;
+ auto names = checked_cast<const compute::MakeStructOptions*>(
+ scan_options_->projection.call()->options.get())
+ ->field_names;
+
+ RETURN_NOT_OK(
+ compute::Declaration::Sequence(
+ {
+ {"scan", ScanNodeOptions{dataset_, scan_options_, backpressure.toggle,
+ sequence_fragments}},
+ {"filter", compute::FilterNodeOptions{scan_options_->filter}},
+ {"augmented_project",
+ compute::ProjectNodeOptions{std::move(exprs), std::move(names)}},
+ {"sink", compute::SinkNodeOptions{&sink_gen, std::move(backpressure)}},
+ })
+ .AddToPlan(plan.get()));
+
+ RETURN_NOT_OK(plan->StartProducing());
+
+ auto options = scan_options_;
+ ARROW_ASSIGN_OR_RAISE(auto fragments_it, dataset_->GetFragments(scan_options_->filter));
+ ARROW_ASSIGN_OR_RAISE(auto fragments, fragments_it.ToVector());
+ auto shared_fragments = std::make_shared<FragmentVector>(std::move(fragments));
+
+ // If the generator is destroyed before being completely drained, inform plan
+ std::shared_ptr<void> stop_producing{
+ nullptr, [plan, exec_context](...) {
+ bool not_finished_yet = plan->finished().TryAddCallback(
+ [&plan, &exec_context] { return [plan, exec_context](const Status&) {}; });
+
+ if (not_finished_yet) {
+ plan->StopProducing();
+ }
+ }};
+
+ return MakeMappedGenerator(
+ std::move(sink_gen),
+ [sink_gen, options, stop_producing,
+ shared_fragments](const util::optional<compute::ExecBatch>& batch)
+ -> Future<EnumeratedRecordBatch> {
+ return ToEnumeratedRecordBatch(batch, *options, *shared_fragments);
+ });
+}
+
+Result<TaggedRecordBatchGenerator> AsyncScanner::ScanBatchesAsync() {
+ return ScanBatchesAsync(::arrow::internal::GetCpuThreadPool());
+}
+
+Result<TaggedRecordBatchGenerator> AsyncScanner::ScanBatchesAsync(
+ Executor* cpu_executor) {
+ ARROW_ASSIGN_OR_RAISE(auto unordered, ScanBatchesUnorderedAsync(
+ cpu_executor, /*sequence_fragments=*/true));
+ // We need an initial value sentinel, so we use one with fragment.index < 0
+ auto is_before_any = [](const EnumeratedRecordBatch& batch) {
+ return batch.fragment.index < 0;
+ };
+ auto left_after_right = [&is_before_any](const EnumeratedRecordBatch& left,
+ const EnumeratedRecordBatch& right) {
+ // Before any comes first
+ if (is_before_any(left)) {
+ return false;
+ }
+ if (is_before_any(right)) {
+ return true;
+ }
+ // Compare batches if fragment is the same
+ if (left.fragment.index == right.fragment.index) {
+ return left.record_batch.index > right.record_batch.index;
+ }
+ // Otherwise compare fragment
+ return left.fragment.index > right.fragment.index;
+ };
+ auto is_next = [is_before_any](const EnumeratedRecordBatch& prev,
+ const EnumeratedRecordBatch& next) {
+ // Only true if next is the first batch
+ if (is_before_any(prev)) {
+ return next.fragment.index == 0 && next.record_batch.index == 0;
+ }
+ // If same fragment, compare batch index
+ if (prev.fragment.index == next.fragment.index) {
+ return next.record_batch.index == prev.record_batch.index + 1;
+ }
+ // Else only if next first batch of next fragment and prev is last batch of previous
+ return next.fragment.index == prev.fragment.index + 1 && prev.record_batch.last &&
+ next.record_batch.index == 0;
+ };
+ auto before_any = EnumeratedRecordBatch{{nullptr, -1, false}, {nullptr, -1, false}};
+ auto sequenced = MakeSequencingGenerator(std::move(unordered), left_after_right,
+ is_next, before_any);
+
+ auto unenumerate_fn = [](const EnumeratedRecordBatch& enumerated_batch) {
+ return TaggedRecordBatch{enumerated_batch.record_batch.value,
+ enumerated_batch.fragment.value};
+ };
+ return MakeMappedGenerator(std::move(sequenced), unenumerate_fn);
+}
+
+struct AsyncTableAssemblyState {
+ /// Protecting mutating accesses to batches
+ std::mutex mutex{};
+ std::vector<RecordBatchVector> batches{};
+
+ void Emplace(const EnumeratedRecordBatch& batch) {
+ std::lock_guard<std::mutex> lock(mutex);
+ auto fragment_index = batch.fragment.index;
+ auto batch_index = batch.record_batch.index;
+ if (static_cast<int>(batches.size()) <= fragment_index) {
+ batches.resize(fragment_index + 1);
+ }
+ if (static_cast<int>(batches[fragment_index].size()) <= batch_index) {
+ batches[fragment_index].resize(batch_index + 1);
+ }
+ batches[fragment_index][batch_index] = batch.record_batch.value;
+ }
+
+ RecordBatchVector Finish() {
+ RecordBatchVector all_batches;
+ for (auto& fragment_batches : batches) {
+ auto end = std::make_move_iterator(fragment_batches.end());
+ for (auto it = std::make_move_iterator(fragment_batches.begin()); it != end; it++) {
+ all_batches.push_back(*it);
+ }
+ }
+ return all_batches;
+ }
+};
+
+Status AsyncScanner::Scan(std::function<Status(TaggedRecordBatch)> visitor) {
+ auto top_level_task = [this, &visitor](Executor* executor) {
+ return VisitBatchesAsync(visitor, executor);
+ };
+ return ::arrow::internal::RunSynchronously<Future<>>(top_level_task,
+ scan_options_->use_threads);
+}
+
+Future<> AsyncScanner::VisitBatchesAsync(std::function<Status(TaggedRecordBatch)> visitor,
+ Executor* executor) {
+ ARROW_ASSIGN_OR_RAISE(auto batches_gen, ScanBatchesAsync(executor));
+ return VisitAsyncGenerator(std::move(batches_gen), visitor);
+}
+
+Future<std::shared_ptr<Table>> AsyncScanner::ToTableAsync(Executor* cpu_executor) {
+ auto scan_options = scan_options_;
+ ARROW_ASSIGN_OR_RAISE(auto positioned_batch_gen,
+ ScanBatchesUnorderedAsync(cpu_executor));
+ /// Wraps the state in a shared_ptr to ensure that failing ScanTasks don't
+ /// invalidate concurrently running tasks when Finish() early returns
+ /// and the mutex/batches fail out of scope.
+ auto state = std::make_shared<AsyncTableAssemblyState>();
+
+ auto table_building_task = [state](const EnumeratedRecordBatch& batch) {
+ state->Emplace(batch);
+ return batch;
+ };
+
+ auto table_building_gen =
+ MakeMappedGenerator(positioned_batch_gen, table_building_task);
+
+ return DiscardAllFromAsyncGenerator(table_building_gen).Then([state, scan_options]() {
+ return Table::FromRecordBatches(scan_options->projected_schema, state->Finish());
+ });
+}
+
+Result<int64_t> AsyncScanner::CountRows() {
+ ARROW_ASSIGN_OR_RAISE(auto fragment_gen, GetFragments());
+
+ auto cpu_executor =
+ scan_options_->use_threads ? ::arrow::internal::GetCpuThreadPool() : nullptr;
+ compute::ExecContext exec_context(scan_options_->pool, cpu_executor);
+
+ ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(&exec_context));
+ // Drop projection since we only need to count rows
+ const auto options = std::make_shared<ScanOptions>(*scan_options_);
+ RETURN_NOT_OK(SetProjection(options.get(), std::vector<std::string>()));
+
+ std::atomic<int64_t> total{0};
+
+ fragment_gen = MakeMappedGenerator(
+ std::move(fragment_gen), [&](const std::shared_ptr<Fragment>& fragment) {
+ return fragment->CountRows(options->filter, options)
+ .Then([&, fragment](util::optional<int64_t> fast_count) mutable
+ -> std::shared_ptr<Fragment> {
+ if (fast_count) {
+ // fast path: got row count directly; skip scanning this fragment
+ total += *fast_count;
+ return std::make_shared<InMemoryFragment>(options->dataset_schema,
+ RecordBatchVector{});
+ }
+
+ // slow path: actually filter this fragment's batches
+ return std::move(fragment);
+ });
+ });
+
+ AsyncGenerator<util::optional<compute::ExecBatch>> sink_gen;
+
+ RETURN_NOT_OK(
+ compute::Declaration::Sequence(
+ {
+ {"scan", ScanNodeOptions{std::make_shared<FragmentDataset>(
+ scan_options_->dataset_schema,
+ std::move(fragment_gen)),
+ options}},
+ {"project", compute::ProjectNodeOptions{{options->filter}, {"mask"}}},
+ {"aggregate", compute::AggregateNodeOptions{{compute::internal::Aggregate{
+ "sum", nullptr}},
+ /*targets=*/{"mask"},
+ /*names=*/{"selected_count"}}},
+ {"sink", compute::SinkNodeOptions{&sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ RETURN_NOT_OK(plan->StartProducing());
+ auto maybe_slow_count = sink_gen().result();
+ plan->finished().Wait();
+
+ ARROW_ASSIGN_OR_RAISE(auto slow_count, maybe_slow_count);
+ total += slow_count->values[0].scalar_as<UInt64Scalar>().value;
+
+ return total.load();
+}
+
+const std::shared_ptr<Dataset>& AsyncScanner::dataset() const { return dataset_; }
+
+} // namespace
+
+ScannerBuilder::ScannerBuilder(std::shared_ptr<Dataset> dataset)
+ : ScannerBuilder(std::move(dataset), std::make_shared<ScanOptions>()) {}
+
+ScannerBuilder::ScannerBuilder(std::shared_ptr<Dataset> dataset,
+ std::shared_ptr<ScanOptions> scan_options)
+ : dataset_(std::move(dataset)), scan_options_(std::move(scan_options)) {
+ scan_options_->dataset_schema = dataset_->schema();
+ DCHECK_OK(Filter(scan_options_->filter));
+}
+
+ScannerBuilder::ScannerBuilder(std::shared_ptr<Schema> schema,
+ std::shared_ptr<Fragment> fragment,
+ std::shared_ptr<ScanOptions> scan_options)
+ : ScannerBuilder(std::make_shared<FragmentDataset>(
+ std::move(schema), FragmentVector{std::move(fragment)}),
+ std::move(scan_options)) {}
+
+std::shared_ptr<ScannerBuilder> ScannerBuilder::FromRecordBatchReader(
+ std::shared_ptr<RecordBatchReader> reader) {
+ auto batch_it = MakeIteratorFromReader(reader);
+ auto fragment =
+ std::make_shared<OneShotFragment>(reader->schema(), std::move(batch_it));
+ return std::make_shared<ScannerBuilder>(reader->schema(), std::move(fragment),
+ std::make_shared<ScanOptions>());
+}
+
+const std::shared_ptr<Schema>& ScannerBuilder::schema() const {
+ return scan_options_->dataset_schema;
+}
+
+const std::shared_ptr<Schema>& ScannerBuilder::projected_schema() const {
+ return scan_options_->projected_schema;
+}
+
+Status ScannerBuilder::Project(std::vector<std::string> columns) {
+ return SetProjection(scan_options_.get(), std::move(columns));
+}
+
+Status ScannerBuilder::Project(std::vector<compute::Expression> exprs,
+ std::vector<std::string> names) {
+ return SetProjection(scan_options_.get(), std::move(exprs), std::move(names));
+}
+
+Status ScannerBuilder::Filter(const compute::Expression& filter) {
+ return SetFilter(scan_options_.get(), filter);
+}
+
+Status ScannerBuilder::UseThreads(bool use_threads) {
+ scan_options_->use_threads = use_threads;
+ return Status::OK();
+}
+
+Status ScannerBuilder::FragmentReadahead(int fragment_readahead) {
+ if (fragment_readahead <= 0) {
+ return Status::Invalid("FragmentReadahead must be greater than 0, got ",
+ fragment_readahead);
+ }
+ scan_options_->fragment_readahead = fragment_readahead;
+ return Status::OK();
+}
+
+Status ScannerBuilder::UseAsync(bool use_async) {
+ scan_options_->use_async = use_async;
+ return Status::OK();
+}
+
+Status ScannerBuilder::BatchSize(int64_t batch_size) {
+ if (batch_size <= 0) {
+ return Status::Invalid("BatchSize must be greater than 0, got ", batch_size);
+ }
+ scan_options_->batch_size = batch_size;
+ return Status::OK();
+}
+
+Status ScannerBuilder::Pool(MemoryPool* pool) {
+ scan_options_->pool = pool;
+ return Status::OK();
+}
+
+Status ScannerBuilder::FragmentScanOptions(
+ std::shared_ptr<dataset::FragmentScanOptions> fragment_scan_options) {
+ scan_options_->fragment_scan_options = std::move(fragment_scan_options);
+ return Status::OK();
+}
+
+Result<std::shared_ptr<Scanner>> ScannerBuilder::Finish() {
+ if (!scan_options_->projection.IsBound()) {
+ RETURN_NOT_OK(Project(scan_options_->dataset_schema->field_names()));
+ }
+
+ if (scan_options_->use_async) {
+ return std::make_shared<AsyncScanner>(dataset_, scan_options_);
+ } else {
+ return std::make_shared<SyncScanner>(dataset_, scan_options_);
+ }
+}
+
+namespace {
+
+inline RecordBatchVector FlattenRecordBatchVector(
+ std::vector<RecordBatchVector> nested_batches) {
+ RecordBatchVector flattened;
+
+ for (auto& task_batches : nested_batches) {
+ for (auto& batch : task_batches) {
+ flattened.emplace_back(std::move(batch));
+ }
+ }
+
+ return flattened;
+}
+
+struct TableAssemblyState {
+ /// Protecting mutating accesses to batches
+ std::mutex mutex{};
+ std::vector<RecordBatchVector> batches{};
+
+ void Emplace(RecordBatchVector b, size_t position) {
+ std::lock_guard<std::mutex> lock(mutex);
+ if (batches.size() <= position) {
+ batches.resize(position + 1);
+ }
+ batches[position] = std::move(b);
+ }
+};
+
+Result<std::shared_ptr<Table>> SyncScanner::ToTable() {
+ ARROW_ASSIGN_OR_RAISE(auto scan_task_it, ScanInternal());
+ auto task_group = scan_options_->TaskGroup();
+
+ /// Wraps the state in a shared_ptr to ensure that failing ScanTasks don't
+ /// invalidate concurrently running tasks when Finish() early returns
+ /// and the mutex/batches fail out of scope.
+ auto state = std::make_shared<TableAssemblyState>();
+
+ // TODO (ARROW-11797) Migrate to using ScanBatches()
+ size_t scan_task_id = 0;
+ for (auto maybe_scan_task : scan_task_it) {
+ ARROW_ASSIGN_OR_RAISE(auto scan_task, maybe_scan_task);
+
+ auto id = scan_task_id++;
+ task_group->Append([state, id, scan_task] {
+ ARROW_ASSIGN_OR_RAISE(
+ auto local,
+ ::arrow::internal::SerialExecutor::RunInSerialExecutor<RecordBatchVector>(
+ [&](Executor* executor) { return scan_task->SafeExecute(executor); }));
+ state->Emplace(std::move(local), id);
+ return Status::OK();
+ });
+ }
+ auto scan_options = scan_options_;
+ // Wait for all tasks to complete, or the first error
+ RETURN_NOT_OK(task_group->Finish());
+ return Table::FromRecordBatches(scan_options->projected_schema,
+ FlattenRecordBatchVector(std::move(state->batches)));
+}
+
+Result<int64_t> SyncScanner::CountRows() {
+ // While readers could implement an optimization where they just fabricate empty
+ // batches based on metadata when no columns are selected, skipping I/O (and
+ // indeed, the Parquet reader does this), counting rows using that optimization is
+ // still slower than just hitting metadata directly where possible.
+ ARROW_ASSIGN_OR_RAISE(auto fragment_it, GetFragments());
+ // Fragment is non-null iff fast path could not be taken.
+ std::vector<Future<std::pair<int64_t, std::shared_ptr<Fragment>>>> futures;
+ for (auto maybe_fragment : fragment_it) {
+ ARROW_ASSIGN_OR_RAISE(auto fragment, maybe_fragment);
+ auto count_fut = fragment->CountRows(scan_options_->filter, scan_options_);
+ futures.push_back(
+ count_fut.Then([fragment](const util::optional<int64_t>& count)
+ -> std::pair<int64_t, std::shared_ptr<Fragment>> {
+ if (count.has_value()) {
+ return std::make_pair(*count, nullptr);
+ }
+ return std::make_pair(0, std::move(fragment));
+ }));
+ }
+
+ int64_t count = 0;
+ FragmentVector fragments;
+ for (auto& future : futures) {
+ ARROW_ASSIGN_OR_RAISE(auto count_result, future.result());
+ count += count_result.first;
+ if (count_result.second) {
+ fragments.push_back(std::move(count_result.second));
+ }
+ }
+ // Now check for any fragments where we couldn't take the fast path
+ if (!fragments.empty()) {
+ auto options = std::make_shared<ScanOptions>(*scan_options_);
+ RETURN_NOT_OK(SetProjection(options.get(), std::vector<std::string>()));
+ ARROW_ASSIGN_OR_RAISE(
+ auto scan_task_it,
+ GetScanTaskIterator(MakeVectorIterator(std::move(fragments)), options));
+ ARROW_ASSIGN_OR_RAISE(auto batch_it, ScanBatches(std::move(scan_task_it)));
+ RETURN_NOT_OK(batch_it.Visit([&](TaggedRecordBatch batch) {
+ count += batch.record_batch->num_rows();
+ return Status::OK();
+ }));
+ }
+ return count;
+}
+
+} // namespace
+
+Result<std::shared_ptr<Table>> Scanner::TakeRows(const Array& indices) {
+ if (indices.null_count() != 0) {
+ return Status::NotImplemented("null take indices");
+ }
+
+ compute::ExecContext ctx(scan_options_->pool);
+
+ const Array* original_indices;
+ // If we have to cast, this is the backing reference
+ std::shared_ptr<Array> original_indices_ptr;
+ if (indices.type_id() != Type::INT64) {
+ ARROW_ASSIGN_OR_RAISE(
+ original_indices_ptr,
+ compute::Cast(indices, int64(), compute::CastOptions::Safe(), &ctx));
+ original_indices = original_indices_ptr.get();
+ } else {
+ original_indices = &indices;
+ }
+
+ std::shared_ptr<Array> unsort_indices;
+ {
+ ARROW_ASSIGN_OR_RAISE(
+ auto sort_indices,
+ compute::SortIndices(*original_indices, compute::SortOrder::Ascending, &ctx));
+ ARROW_ASSIGN_OR_RAISE(original_indices_ptr,
+ compute::Take(*original_indices, *sort_indices,
+ compute::TakeOptions::Defaults(), &ctx));
+ original_indices = original_indices_ptr.get();
+ ARROW_ASSIGN_OR_RAISE(
+ unsort_indices,
+ compute::SortIndices(*sort_indices, compute::SortOrder::Ascending, &ctx));
+ }
+
+ RecordBatchVector out_batches;
+
+ auto raw_indices = static_cast<const Int64Array&>(*original_indices).raw_values();
+ int64_t offset = 0, row_begin = 0;
+
+ ARROW_ASSIGN_OR_RAISE(auto batch_it, ScanBatches());
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(auto batch, batch_it.Next());
+ if (IsIterationEnd(batch)) break;
+ if (offset == original_indices->length()) break;
+ DCHECK_LT(offset, original_indices->length());
+
+ int64_t length = 0;
+ while (offset + length < original_indices->length()) {
+ auto rel_index = raw_indices[offset + length] - row_begin;
+ if (rel_index >= batch.record_batch->num_rows()) break;
+ ++length;
+ }
+ DCHECK_LE(offset + length, original_indices->length());
+ if (length == 0) {
+ row_begin += batch.record_batch->num_rows();
+ continue;
+ }
+
+ Datum rel_indices = original_indices->Slice(offset, length);
+ ARROW_ASSIGN_OR_RAISE(rel_indices,
+ compute::Subtract(rel_indices, Datum(row_begin),
+ compute::ArithmeticOptions(), &ctx));
+
+ ARROW_ASSIGN_OR_RAISE(Datum out_batch,
+ compute::Take(batch.record_batch, rel_indices,
+ compute::TakeOptions::Defaults(), &ctx));
+ out_batches.push_back(out_batch.record_batch());
+
+ offset += length;
+ row_begin += batch.record_batch->num_rows();
+ }
+
+ if (offset < original_indices->length()) {
+ std::stringstream error;
+ const int64_t max_values_shown = 3;
+ const int64_t num_remaining = original_indices->length() - offset;
+ for (int64_t i = 0; i < std::min<int64_t>(max_values_shown, num_remaining); i++) {
+ if (i > 0) error << ", ";
+ error << static_cast<const Int64Array*>(original_indices)->Value(offset + i);
+ }
+ if (num_remaining > max_values_shown) error << ", ...";
+ return Status::IndexError("Some indices were out of bounds: ", error.str());
+ }
+ ARROW_ASSIGN_OR_RAISE(Datum out, Table::FromRecordBatches(options()->projected_schema,
+ std::move(out_batches)));
+ ARROW_ASSIGN_OR_RAISE(
+ out, compute::Take(out, unsort_indices, compute::TakeOptions::Defaults(), &ctx));
+ return out.table();
+}
+
+Result<std::shared_ptr<Table>> Scanner::Head(int64_t num_rows) {
+ if (num_rows == 0) {
+ return Table::FromRecordBatches(options()->projected_schema, {});
+ }
+ ARROW_ASSIGN_OR_RAISE(auto batch_iterator, ScanBatches());
+ RecordBatchVector batches;
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(auto batch, batch_iterator.Next());
+ if (IsIterationEnd(batch)) break;
+ batches.push_back(batch.record_batch->Slice(0, num_rows));
+ num_rows -= batch.record_batch->num_rows();
+ if (num_rows <= 0) break;
+ }
+ return Table::FromRecordBatches(options()->projected_schema, batches);
+}
+
+namespace {
+
+Result<compute::ExecNode*> MakeScanNode(compute::ExecPlan* plan,
+ std::vector<compute::ExecNode*> inputs,
+ const compute::ExecNodeOptions& options) {
+ const auto& scan_node_options = checked_cast<const ScanNodeOptions&>(options);
+ auto scan_options = scan_node_options.scan_options;
+ auto dataset = scan_node_options.dataset;
+ const auto& backpressure_toggle = scan_node_options.backpressure_toggle;
+ bool require_sequenced_output = scan_node_options.require_sequenced_output;
+
+ if (!scan_options->use_async) {
+ return Status::NotImplemented("ScanNodes without asynchrony");
+ }
+
+ if (scan_options->dataset_schema == nullptr) {
+ scan_options->dataset_schema = dataset->schema();
+ }
+
+ if (!scan_options->filter.IsBound()) {
+ ARROW_ASSIGN_OR_RAISE(scan_options->filter,
+ scan_options->filter.Bind(*dataset->schema()));
+ }
+
+ if (!scan_options->projection.IsBound()) {
+ auto fields = dataset->schema()->fields();
+ for (const auto& aug_field : kAugmentedFields) {
+ fields.push_back(aug_field);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(scan_options->projection,
+ scan_options->projection.Bind(Schema(std::move(fields))));
+ }
+
+ // using a generator for speculative forward compatibility with async fragment discovery
+ ARROW_ASSIGN_OR_RAISE(auto fragments_it, dataset->GetFragments(scan_options->filter));
+ ARROW_ASSIGN_OR_RAISE(auto fragments_vec, fragments_it.ToVector());
+ auto fragment_gen = MakeVectorGenerator(std::move(fragments_vec));
+
+ ARROW_ASSIGN_OR_RAISE(auto batch_gen_gen,
+ FragmentsToBatches(std::move(fragment_gen), scan_options));
+
+ AsyncGenerator<EnumeratedRecordBatch> merged_batch_gen;
+ if (require_sequenced_output) {
+ ARROW_ASSIGN_OR_RAISE(merged_batch_gen,
+ MakeSequencedMergedGenerator(std::move(batch_gen_gen),
+ scan_options->fragment_readahead));
+ } else {
+ merged_batch_gen =
+ MakeMergedGenerator(std::move(batch_gen_gen), scan_options->fragment_readahead);
+ }
+
+ auto batch_gen = MakeReadaheadGenerator(std::move(merged_batch_gen),
+ scan_options->fragment_readahead);
+
+ auto gen = MakeMappedGenerator(
+ std::move(batch_gen),
+ [scan_options](const EnumeratedRecordBatch& partial)
+ -> Result<util::optional<compute::ExecBatch>> {
+ ARROW_ASSIGN_OR_RAISE(util::optional<compute::ExecBatch> batch,
+ compute::MakeExecBatch(*scan_options->dataset_schema,
+ partial.record_batch.value));
+ // TODO(ARROW-13263) fragments may be able to attach more guarantees to batches
+ // than this, for example parquet's row group stats. Failing to do this leaves
+ // perf on the table because row group stats could be used to skip kernel execs in
+ // FilterNode.
+ //
+ // Additionally, if a fragment failed to perform projection pushdown there may be
+ // unnecessarily materialized columns in batch. We could drop them now instead of
+ // letting them coast through the rest of the plan.
+ batch->guarantee = partial.fragment.value->partition_expression();
+
+ // tag rows with fragment- and batch-of-origin
+ batch->values.emplace_back(partial.fragment.index);
+ batch->values.emplace_back(partial.record_batch.index);
+ batch->values.emplace_back(partial.record_batch.last);
+ return batch;
+ });
+
+ if (backpressure_toggle) {
+ gen = MakePauseable(gen, backpressure_toggle);
+ }
+
+ auto fields = scan_options->dataset_schema->fields();
+ for (const auto& aug_field : kAugmentedFields) {
+ fields.push_back(aug_field);
+ }
+
+ return compute::MakeExecNode(
+ "source", plan, {},
+ compute::SourceNodeOptions{schema(std::move(fields)), std::move(gen)});
+}
+
+Result<compute::ExecNode*> MakeAugmentedProjectNode(
+ compute::ExecPlan* plan, std::vector<compute::ExecNode*> inputs,
+ const compute::ExecNodeOptions& options) {
+ const auto& project_options = checked_cast<const compute::ProjectNodeOptions&>(options);
+ auto exprs = project_options.expressions;
+ auto names = project_options.names;
+
+ if (names.size() == 0) {
+ names.resize(exprs.size());
+ for (size_t i = 0; i < exprs.size(); ++i) {
+ names[i] = exprs[i].ToString();
+ }
+ }
+
+ for (const auto& aug_field : kAugmentedFields) {
+ exprs.push_back(compute::field_ref(aug_field->name()));
+ names.push_back(aug_field->name());
+ }
+ return compute::MakeExecNode(
+ "project", plan, std::move(inputs),
+ compute::ProjectNodeOptions{std::move(exprs), std::move(names)});
+}
+
+Result<compute::ExecNode*> MakeOrderedSinkNode(compute::ExecPlan* plan,
+ std::vector<compute::ExecNode*> inputs,
+ const compute::ExecNodeOptions& options) {
+ if (inputs.size() != 1) {
+ return Status::Invalid("Ordered SinkNode requires exactly 1 input, got ",
+ inputs.size());
+ }
+ auto input = inputs[0];
+
+ AsyncGenerator<util::optional<compute::ExecBatch>> unordered;
+ ARROW_ASSIGN_OR_RAISE(auto node,
+ compute::MakeExecNode("sink", plan, std::move(inputs),
+ compute::SinkNodeOptions{&unordered}));
+
+ const Schema& schema = *input->output_schema();
+ ARROW_ASSIGN_OR_RAISE(FieldPath match, FieldRef("__fragment_index").FindOne(schema));
+ int i = match[0];
+ auto fragment_index = [i](const compute::ExecBatch& batch) {
+ return batch.values[i].scalar_as<Int32Scalar>().value;
+ };
+ compute::ExecBatch before_any{{}, 0};
+ before_any.values.resize(i + 1);
+ before_any.values.back() = Datum(-1);
+
+ ARROW_ASSIGN_OR_RAISE(match, FieldRef("__batch_index").FindOne(schema));
+ i = match[0];
+ auto batch_index = [i](const compute::ExecBatch& batch) {
+ return batch.values[i].scalar_as<Int32Scalar>().value;
+ };
+
+ ARROW_ASSIGN_OR_RAISE(match, FieldRef("__last_in_fragment").FindOne(schema));
+ i = match[0];
+ auto last_in_fragment = [i](const compute::ExecBatch& batch) {
+ return batch.values[i].scalar_as<BooleanScalar>().value;
+ };
+
+ auto is_before_any = [=](const compute::ExecBatch& batch) {
+ return fragment_index(batch) < 0;
+ };
+
+ auto left_after_right = [=](const util::optional<compute::ExecBatch>& left,
+ const util::optional<compute::ExecBatch>& right) {
+ // Before any comes first
+ if (is_before_any(*left)) {
+ return false;
+ }
+ if (is_before_any(*right)) {
+ return true;
+ }
+ // Compare batches if fragment is the same
+ if (fragment_index(*left) == fragment_index(*right)) {
+ return batch_index(*left) > batch_index(*right);
+ }
+ // Otherwise compare fragment
+ return fragment_index(*left) > fragment_index(*right);
+ };
+
+ auto is_next = [=](const util::optional<compute::ExecBatch>& prev,
+ const util::optional<compute::ExecBatch>& next) {
+ // Only true if next is the first batch
+ if (is_before_any(*prev)) {
+ return fragment_index(*next) == 0 && batch_index(*next) == 0;
+ }
+ // If same fragment, compare batch index
+ if (fragment_index(*next) == fragment_index(*prev)) {
+ return batch_index(*next) == batch_index(*prev) + 1;
+ }
+ // Else only if next first batch of next fragment and prev is last batch of previous
+ return fragment_index(*next) == fragment_index(*prev) + 1 &&
+ last_in_fragment(*prev) && batch_index(*next) == 0;
+ };
+
+ const auto& sink_options = checked_cast<const compute::SinkNodeOptions&>(options);
+ *sink_options.generator =
+ MakeSequencingGenerator(std::move(unordered), left_after_right, is_next,
+ util::make_optional(std::move(before_any)));
+
+ return node;
+}
+
+} // namespace
+
+namespace internal {
+void InitializeScanner(arrow::compute::ExecFactoryRegistry* registry) {
+ DCHECK_OK(registry->AddFactory("scan", MakeScanNode));
+ DCHECK_OK(registry->AddFactory("ordered_sink", MakeOrderedSinkNode));
+ DCHECK_OK(registry->AddFactory("augmented_project", MakeAugmentedProjectNode));
+}
+} // namespace internal
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/scanner.h b/src/arrow/cpp/src/arrow/dataset/scanner.h
new file mode 100644
index 000000000..75e9806fb
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/scanner.h
@@ -0,0 +1,458 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/compute/exec/expression.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/type_fwd.h"
+#include "arrow/dataset/dataset.h"
+#include "arrow/dataset/projector.h"
+#include "arrow/dataset/type_fwd.h"
+#include "arrow/dataset/visibility.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/memory_pool.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/thread_pool.h"
+#include "arrow/util/type_fwd.h"
+
+namespace arrow {
+
+using RecordBatchGenerator = std::function<Future<std::shared_ptr<RecordBatch>>()>;
+
+namespace dataset {
+
+/// \defgroup dataset-scanning Scanning API
+///
+/// @{
+
+constexpr int64_t kDefaultBatchSize = 1 << 20;
+constexpr int32_t kDefaultBatchReadahead = 32;
+constexpr int32_t kDefaultFragmentReadahead = 8;
+constexpr int32_t kDefaultBackpressureHigh = 64;
+constexpr int32_t kDefaultBackpressureLow = 32;
+
+/// Scan-specific options, which can be changed between scans of the same dataset.
+struct ARROW_DS_EXPORT ScanOptions {
+ /// A row filter (which will be pushed down to partitioning/reading if supported).
+ compute::Expression filter = compute::literal(true);
+ /// A projection expression (which can add/remove/rename columns).
+ compute::Expression projection;
+
+ /// Schema with which batches will be read from fragments. This is also known as the
+ /// "reader schema" it will be used (for example) in constructing CSV file readers to
+ /// identify column types for parsing. Usually only a subset of its fields (see
+ /// MaterializedFields) will be materialized during a scan.
+ std::shared_ptr<Schema> dataset_schema;
+
+ /// Schema of projected record batches. This is independent of dataset_schema as its
+ /// fields are derived from the projection. For example, let
+ ///
+ /// dataset_schema = {"a": int32, "b": int32, "id": utf8}
+ /// projection = project({equal(field_ref("a"), field_ref("b"))}, {"a_plus_b"})
+ ///
+ /// (no filter specified). In this case, the projected_schema would be
+ ///
+ /// {"a_plus_b": int32}
+ std::shared_ptr<Schema> projected_schema;
+
+ /// Maximum row count for scanned batches.
+ int64_t batch_size = kDefaultBatchSize;
+
+ /// How many batches to read ahead within a file
+ ///
+ /// Set to 0 to disable batch readahead
+ ///
+ /// Note: May not be supported by all formats
+ /// Note: May not be supported by all scanners
+ /// Note: Will be ignored if use_threads is set to false
+ int32_t batch_readahead = kDefaultBatchReadahead;
+
+ /// How many files to read ahead
+ ///
+ /// Set to 0 to disable fragment readahead
+ ///
+ /// Note: May not be enforced by all scanners
+ /// Note: Will be ignored if use_threads is set to false
+ int32_t fragment_readahead = kDefaultFragmentReadahead;
+
+ /// A pool from which materialized and scanned arrays will be allocated.
+ MemoryPool* pool = arrow::default_memory_pool();
+
+ /// IOContext for any IO tasks
+ ///
+ /// Note: The IOContext executor will be ignored if use_threads is set to false
+ io::IOContext io_context;
+
+ /// If true the scanner will scan in parallel
+ ///
+ /// Note: If true, this will use threads from both the cpu_executor and the
+ /// io_context.executor
+ /// Note: This must be true in order for any readahead to happen
+ bool use_threads = false;
+
+ /// If true then an asycnhronous implementation of the scanner will be used.
+ /// This implementation is newer and generally performs better. However, it
+ /// makes extensive use of threading and is still considered experimental
+ bool use_async = false;
+
+ /// Fragment-specific scan options.
+ std::shared_ptr<FragmentScanOptions> fragment_scan_options;
+
+ // Return a vector of fields that requires materialization.
+ //
+ // This is usually the union of the fields referenced in the projection and the
+ // filter expression. Examples:
+ //
+ // - `SELECT a, b WHERE a < 2 && c > 1` => ["a", "b", "a", "c"]
+ // - `SELECT a + b < 3 WHERE a > 1` => ["a", "b"]
+ //
+ // This is needed for expression where a field may not be directly
+ // used in the final projection but is still required to evaluate the
+ // expression.
+ //
+ // This is used by Fragment implementations to apply the column
+ // sub-selection optimization.
+ std::vector<std::string> MaterializedFields() const;
+
+ // Return a threaded or serial TaskGroup according to use_threads.
+ std::shared_ptr<::arrow::internal::TaskGroup> TaskGroup() const;
+};
+
+/// \brief Read record batches from a range of a single data fragment. A
+/// ScanTask is meant to be a unit of work to be dispatched. The implementation
+/// must be thread and concurrent safe.
+class ARROW_DS_EXPORT ScanTask {
+ public:
+ /// \brief Iterate through sequence of materialized record batches
+ /// resulting from the Scan. Execution semantics are encapsulated in the
+ /// particular ScanTask implementation
+ virtual Result<RecordBatchIterator> Execute() = 0;
+ virtual Future<RecordBatchVector> SafeExecute(::arrow::internal::Executor* executor);
+ virtual Future<> SafeVisit(::arrow::internal::Executor* executor,
+ std::function<Status(std::shared_ptr<RecordBatch>)> visitor);
+
+ virtual ~ScanTask() = default;
+
+ const std::shared_ptr<ScanOptions>& options() const { return options_; }
+ const std::shared_ptr<Fragment>& fragment() const { return fragment_; }
+
+ protected:
+ ScanTask(std::shared_ptr<ScanOptions> options, std::shared_ptr<Fragment> fragment)
+ : options_(std::move(options)), fragment_(std::move(fragment)) {}
+
+ std::shared_ptr<ScanOptions> options_;
+ std::shared_ptr<Fragment> fragment_;
+};
+
+/// \brief Combines a record batch with the fragment that the record batch originated
+/// from
+///
+/// Knowing the source fragment can be useful for debugging & understanding loaded data
+struct TaggedRecordBatch {
+ std::shared_ptr<RecordBatch> record_batch;
+ std::shared_ptr<Fragment> fragment;
+};
+using TaggedRecordBatchGenerator = std::function<Future<TaggedRecordBatch>()>;
+using TaggedRecordBatchIterator = Iterator<TaggedRecordBatch>;
+
+/// \brief Combines a tagged batch with positional information
+///
+/// This is returned when scanning batches in an unordered fashion. This information is
+/// needed if you ever want to reassemble the batches in order
+struct EnumeratedRecordBatch {
+ Enumerated<std::shared_ptr<RecordBatch>> record_batch;
+ Enumerated<std::shared_ptr<Fragment>> fragment;
+};
+using EnumeratedRecordBatchGenerator = std::function<Future<EnumeratedRecordBatch>()>;
+using EnumeratedRecordBatchIterator = Iterator<EnumeratedRecordBatch>;
+
+/// @}
+
+} // namespace dataset
+
+template <>
+struct IterationTraits<dataset::TaggedRecordBatch> {
+ static dataset::TaggedRecordBatch End() {
+ return dataset::TaggedRecordBatch{NULLPTR, NULLPTR};
+ }
+ static bool IsEnd(const dataset::TaggedRecordBatch& val) {
+ return val.record_batch == NULLPTR;
+ }
+};
+
+template <>
+struct IterationTraits<dataset::EnumeratedRecordBatch> {
+ static dataset::EnumeratedRecordBatch End() {
+ return dataset::EnumeratedRecordBatch{
+ IterationEnd<Enumerated<std::shared_ptr<RecordBatch>>>(),
+ IterationEnd<Enumerated<std::shared_ptr<dataset::Fragment>>>()};
+ }
+ static bool IsEnd(const dataset::EnumeratedRecordBatch& val) {
+ return IsIterationEnd(val.fragment);
+ }
+};
+
+namespace dataset {
+
+/// \defgroup dataset-scanning Scanning API
+///
+/// @{
+
+/// \brief A scanner glues together several dataset classes to load in data.
+/// The dataset contains a collection of fragments and partitioning rules.
+///
+/// The fragments identify independently loadable units of data (i.e. each fragment has
+/// a potentially unique schema and possibly even format. It should be possible to read
+/// fragments in parallel if desired).
+///
+/// The fragment's format contains the logic necessary to actually create a task to load
+/// the fragment into memory. That task may or may not support parallel execution of
+/// its own.
+///
+/// The scanner is then responsible for creating scan tasks from every fragment in the
+/// dataset and (potentially) sequencing the loaded record batches together.
+///
+/// The scanner should not buffer the entire dataset in memory (unless asked) instead
+/// yielding record batches as soon as they are ready to scan. Various readahead
+/// properties control how much data is allowed to be scanned before pausing to let a
+/// slow consumer catchup.
+///
+/// Today the scanner also handles projection & filtering although that may change in
+/// the future.
+class ARROW_DS_EXPORT Scanner {
+ public:
+ virtual ~Scanner() = default;
+
+ /// \brief The Scan operator returns a stream of ScanTask. The caller is
+ /// responsible to dispatch/schedule said tasks. Tasks should be safe to run
+ /// in a concurrent fashion and outlive the iterator.
+ ///
+ /// Note: Not supported by the async scanner
+ /// Planned for removal from the public API in ARROW-11782.
+ ARROW_DEPRECATED("Deprecated in 4.0.0 for removal in 5.0.0. Use ScanBatches().")
+ virtual Result<ScanTaskIterator> Scan();
+
+ /// \brief Apply a visitor to each RecordBatch as it is scanned. If multiple threads
+ /// are used (via use_threads), the visitor will be invoked from those threads and is
+ /// responsible for any synchronization.
+ virtual Status Scan(std::function<Status(TaggedRecordBatch)> visitor) = 0;
+ /// \brief Convert a Scanner into a Table.
+ ///
+ /// Use this convenience utility with care. This will serially materialize the
+ /// Scan result in memory before creating the Table.
+ virtual Result<std::shared_ptr<Table>> ToTable() = 0;
+ /// \brief Scan the dataset into a stream of record batches. Each batch is tagged
+ /// with the fragment it originated from. The batches will arrive in order. The
+ /// order of fragments is determined by the dataset.
+ ///
+ /// Note: The scanner will perform some readahead but will avoid materializing too
+ /// much in memory (this is goverended by the readahead options and use_threads option).
+ /// If the readahead queue fills up then I/O will pause until the calling thread catches
+ /// up.
+ virtual Result<TaggedRecordBatchIterator> ScanBatches() = 0;
+ virtual Result<TaggedRecordBatchGenerator> ScanBatchesAsync() = 0;
+ /// \brief Scan the dataset into a stream of record batches. Unlike ScanBatches this
+ /// method may allow record batches to be returned out of order. This allows for more
+ /// efficient scanning: some fragments may be accessed more quickly than others (e.g.
+ /// may be cached in RAM or just happen to get scheduled earlier by the I/O)
+ ///
+ /// To make up for the out-of-order iteration each batch is further tagged with
+ /// positional information.
+ virtual Result<EnumeratedRecordBatchIterator> ScanBatchesUnordered();
+ virtual Result<EnumeratedRecordBatchGenerator> ScanBatchesUnorderedAsync() = 0;
+ /// \brief A convenience to synchronously load the given rows by index.
+ ///
+ /// Will only consume as many batches as needed from ScanBatches().
+ virtual Result<std::shared_ptr<Table>> TakeRows(const Array& indices);
+ /// \brief Get the first N rows.
+ virtual Result<std::shared_ptr<Table>> Head(int64_t num_rows);
+ /// \brief Count rows matching a predicate.
+ ///
+ /// This method will push down the predicate and compute the result based on fragment
+ /// metadata if possible.
+ virtual Result<int64_t> CountRows();
+ /// \brief Convert the Scanner to a RecordBatchReader so it can be
+ /// easily used with APIs that expect a reader.
+ Result<std::shared_ptr<RecordBatchReader>> ToRecordBatchReader();
+
+ /// \brief Get the options for this scan.
+ const std::shared_ptr<ScanOptions>& options() const { return scan_options_; }
+ /// \brief Get the dataset that this scanner will scan
+ virtual const std::shared_ptr<Dataset>& dataset() const = 0;
+
+ protected:
+ explicit Scanner(std::shared_ptr<ScanOptions> scan_options)
+ : scan_options_(std::move(scan_options)) {}
+
+ Result<EnumeratedRecordBatchIterator> AddPositioningToInOrderScan(
+ TaggedRecordBatchIterator scan);
+
+ const std::shared_ptr<ScanOptions> scan_options_;
+};
+
+/// \brief ScannerBuilder is a factory class to construct a Scanner. It is used
+/// to pass information, notably a potential filter expression and a subset of
+/// columns to materialize.
+class ARROW_DS_EXPORT ScannerBuilder {
+ public:
+ explicit ScannerBuilder(std::shared_ptr<Dataset> dataset);
+
+ ScannerBuilder(std::shared_ptr<Dataset> dataset,
+ std::shared_ptr<ScanOptions> scan_options);
+
+ ScannerBuilder(std::shared_ptr<Schema> schema, std::shared_ptr<Fragment> fragment,
+ std::shared_ptr<ScanOptions> scan_options);
+
+ /// \brief Make a scanner from a record batch reader.
+ ///
+ /// The resulting scanner can be scanned only once. This is intended
+ /// to support writing data from streaming sources or other sources
+ /// that can be iterated only once.
+ static std::shared_ptr<ScannerBuilder> FromRecordBatchReader(
+ std::shared_ptr<RecordBatchReader> reader);
+
+ /// \brief Set the subset of columns to materialize.
+ ///
+ /// Columns which are not referenced may not be read from fragments.
+ ///
+ /// \param[in] columns list of columns to project. Order and duplicates will
+ /// be preserved.
+ ///
+ /// \return Failure if any column name does not exists in the dataset's
+ /// Schema.
+ Status Project(std::vector<std::string> columns);
+
+ /// \brief Set expressions which will be evaluated to produce the materialized
+ /// columns.
+ ///
+ /// Columns which are not referenced may not be read from fragments.
+ ///
+ /// \param[in] exprs expressions to evaluate to produce columns.
+ /// \param[in] names list of names for the resulting columns.
+ ///
+ /// \return Failure if any referenced column does not exists in the dataset's
+ /// Schema.
+ Status Project(std::vector<compute::Expression> exprs, std::vector<std::string> names);
+
+ /// \brief Set the filter expression to return only rows matching the filter.
+ ///
+ /// The predicate will be passed down to Sources and corresponding
+ /// Fragments to exploit predicate pushdown if possible using
+ /// partition information or Fragment internal metadata, e.g. Parquet statistics.
+ /// Columns which are not referenced may not be read from fragments.
+ ///
+ /// \param[in] filter expression to filter rows with.
+ ///
+ /// \return Failure if any referenced columns does not exist in the dataset's
+ /// Schema.
+ Status Filter(const compute::Expression& filter);
+
+ /// \brief Indicate if the Scanner should make use of the available
+ /// ThreadPool found in ScanOptions;
+ Status UseThreads(bool use_threads = true);
+
+ /// \brief Limit how many fragments the scanner will read at once
+ ///
+ /// Note: This is only enforced in "async" mode
+ Status FragmentReadahead(int fragment_readahead);
+
+ /// \brief Indicate if the Scanner should run in experimental "async" mode
+ ///
+ /// This mode should have considerably better performance on high-latency or parallel
+ /// filesystems but is still experimental
+ Status UseAsync(bool use_async = true);
+
+ /// \brief Set the maximum number of rows per RecordBatch.
+ ///
+ /// \param[in] batch_size the maximum number of rows.
+ /// \returns An error if the number for batch is not greater than 0.
+ ///
+ /// This option provides a control limiting the memory owned by any RecordBatch.
+ Status BatchSize(int64_t batch_size);
+
+ /// \brief Set the pool from which materialized and scanned arrays will be allocated.
+ Status Pool(MemoryPool* pool);
+
+ /// \brief Set fragment-specific scan options.
+ Status FragmentScanOptions(std::shared_ptr<FragmentScanOptions> fragment_scan_options);
+
+ /// \brief Return the constructed now-immutable Scanner object
+ Result<std::shared_ptr<Scanner>> Finish();
+
+ const std::shared_ptr<Schema>& schema() const;
+ const std::shared_ptr<Schema>& projected_schema() const;
+
+ private:
+ std::shared_ptr<Dataset> dataset_;
+ std::shared_ptr<ScanOptions> scan_options_ = std::make_shared<ScanOptions>();
+};
+
+/// \brief Construct a source ExecNode which yields batches from a dataset scan.
+///
+/// Does not construct associated filter or project nodes.
+/// Yielded batches will be augmented with fragment/batch indices to enable stable
+/// ordering for simple ExecPlans.
+class ARROW_DS_EXPORT ScanNodeOptions : public compute::ExecNodeOptions {
+ public:
+ explicit ScanNodeOptions(
+ std::shared_ptr<Dataset> dataset, std::shared_ptr<ScanOptions> scan_options,
+ std::shared_ptr<util::AsyncToggle> backpressure_toggle = NULLPTR,
+ bool require_sequenced_output = false)
+ : dataset(std::move(dataset)),
+ scan_options(std::move(scan_options)),
+ backpressure_toggle(std::move(backpressure_toggle)),
+ require_sequenced_output(require_sequenced_output) {}
+
+ std::shared_ptr<Dataset> dataset;
+ std::shared_ptr<ScanOptions> scan_options;
+ std::shared_ptr<util::AsyncToggle> backpressure_toggle;
+ bool require_sequenced_output;
+};
+
+/// @}
+
+/// \brief A trivial ScanTask that yields the RecordBatch of an array.
+class ARROW_DS_EXPORT InMemoryScanTask : public ScanTask {
+ public:
+ InMemoryScanTask(std::vector<std::shared_ptr<RecordBatch>> record_batches,
+ std::shared_ptr<ScanOptions> options,
+ std::shared_ptr<Fragment> fragment)
+ : ScanTask(std::move(options), std::move(fragment)),
+ record_batches_(std::move(record_batches)) {}
+
+ Result<RecordBatchIterator> Execute() override;
+
+ protected:
+ std::vector<std::shared_ptr<RecordBatch>> record_batches_;
+};
+
+namespace internal {
+ARROW_DS_EXPORT void InitializeScanner(arrow::compute::ExecFactoryRegistry* registry);
+} // namespace internal
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/scanner_benchmark.cc b/src/arrow/cpp/src/arrow/dataset/scanner_benchmark.cc
new file mode 100644
index 000000000..e3021794c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/scanner_benchmark.cc
@@ -0,0 +1,210 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/api.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/test_util.h"
+#include "arrow/dataset/dataset.h"
+#include "arrow/dataset/plan.h"
+#include "arrow/dataset/scanner.h"
+#include "arrow/dataset/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+#include "arrow/testing/random.h"
+
+namespace arrow {
+namespace compute {
+
+constexpr auto kSeed = 0x0ff1ce;
+
+void GenerateBatchesFromSchema(const std::shared_ptr<Schema>& schema, size_t num_batches,
+ BatchesWithSchema* out_batches, int multiplicity = 1,
+ int64_t batch_size = 4) {
+ ::arrow::random::RandomArrayGenerator rng_(kSeed);
+ if (num_batches == 0) {
+ auto empty_record_batch = ExecBatch(*rng_.BatchOf(schema->fields(), 0));
+ out_batches->batches.push_back(empty_record_batch);
+ } else {
+ for (size_t j = 0; j < num_batches; j++) {
+ out_batches->batches.push_back(
+ ExecBatch(*rng_.BatchOf(schema->fields(), batch_size)));
+ }
+ }
+
+ size_t batch_count = out_batches->batches.size();
+ for (int repeat = 1; repeat < multiplicity; ++repeat) {
+ for (size_t i = 0; i < batch_count; ++i) {
+ out_batches->batches.push_back(out_batches->batches[i]);
+ }
+ }
+ out_batches->schema = schema;
+}
+
+RecordBatchVector GenerateBatches(const std::shared_ptr<Schema>& schema,
+ size_t num_batches, size_t batch_size) {
+ BatchesWithSchema input_batches;
+
+ RecordBatchVector batches;
+ GenerateBatchesFromSchema(schema, num_batches, &input_batches, 1, batch_size);
+
+ for (const auto& batch : input_batches.batches) {
+ batches.push_back(batch.ToRecordBatch(schema).MoveValueUnsafe());
+ }
+ return batches;
+}
+
+} // namespace compute
+
+namespace dataset {
+
+static std::map<std::pair<size_t, size_t>, RecordBatchVector> datasets;
+
+void StoreBatches(size_t num_batches, size_t batch_size,
+ const RecordBatchVector& batches) {
+ datasets[std::make_pair(num_batches, batch_size)] = batches;
+}
+
+RecordBatchVector GetBatches(size_t num_batches, size_t batch_size) {
+ auto iter = datasets.find(std::make_pair(num_batches, batch_size));
+ if (iter == datasets.end()) {
+ return RecordBatchVector{};
+ }
+ return iter->second;
+}
+
+std::shared_ptr<Schema> GetSchema() {
+ static std::shared_ptr<Schema> s = schema({field("a", int32()), field("b", boolean())});
+ return s;
+}
+
+size_t GetBytesForSchema() { return sizeof(int32_t) + sizeof(bool); }
+
+void MinimalEndToEndScan(size_t num_batches, size_t batch_size, bool async_mode) {
+ // NB: This test is here for didactic purposes
+
+ // Specify a MemoryPool and ThreadPool for the ExecPlan
+ compute::ExecContext exec_context(default_memory_pool(),
+ ::arrow::internal::GetCpuThreadPool());
+
+ // ensure arrow::dataset node factories are in the registry
+ ::arrow::dataset::internal::Initialize();
+
+ // A ScanNode is constructed from an ExecPlan (into which it is inserted),
+ // a Dataset (whose batches will be scanned), and ScanOptions (to specify a filter for
+ // predicate pushdown, a projection to skip materialization of unnecessary columns,
+ // ...)
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<compute::ExecPlan> plan,
+ compute::ExecPlan::Make(&exec_context));
+
+ RecordBatchVector batches = GetBatches(num_batches, batch_size);
+
+ std::shared_ptr<Dataset> dataset =
+ std::make_shared<InMemoryDataset>(GetSchema(), batches);
+
+ auto options = std::make_shared<ScanOptions>();
+ // sync scanning is not supported by ScanNode
+ options->use_async = true;
+ // specify the filter
+ compute::Expression b_is_true = field_ref("b");
+ options->filter = b_is_true;
+ // for now, specify the projection as the full project expression (eventually this can
+ // just be a list of materialized field names)
+ compute::Expression a_times_2 = call("multiply", {field_ref("a"), literal(2)});
+ options->projection =
+ call("make_struct", {a_times_2}, compute::MakeStructOptions{{"a * 2"}});
+
+ // construct the scan node
+ ASSERT_OK_AND_ASSIGN(
+ compute::ExecNode * scan,
+ compute::MakeExecNode("scan", plan.get(), {}, ScanNodeOptions{dataset, options}));
+
+ // pipe the scan node into a filter node
+ ASSERT_OK_AND_ASSIGN(
+ compute::ExecNode * filter,
+ compute::MakeExecNode("filter", plan.get(), {scan},
+ compute::FilterNodeOptions{b_is_true, async_mode}));
+
+ // pipe the filter node into a project node
+ // NB: we're using the project node factory which preserves fragment/batch index
+ // tagging, so we *can* reorder later if we choose. The tags will not appear in
+ // our output.
+ ASSERT_OK_AND_ASSIGN(
+ compute::ExecNode * project,
+ compute::MakeExecNode("augmented_project", plan.get(), {filter},
+ compute::ProjectNodeOptions{{a_times_2}, {}, async_mode}));
+
+ // finally, pipe the project node into a sink node
+ AsyncGenerator<util::optional<compute::ExecBatch>> sink_gen;
+ ASSERT_OK_AND_ASSIGN(compute::ExecNode * sink,
+ compute::MakeExecNode("sink", plan.get(), {project},
+ compute::SinkNodeOptions{&sink_gen}));
+
+ ASSERT_NE(sink, nullptr);
+
+ // translate sink_gen (async) to sink_reader (sync)
+ std::shared_ptr<RecordBatchReader> sink_reader = compute::MakeGeneratorReader(
+ schema({field("a * 2", int32())}), std::move(sink_gen), exec_context.memory_pool());
+
+ // start the ExecPlan
+ ASSERT_OK(plan->StartProducing());
+
+ // collect sink_reader into a Table
+ ASSERT_OK_AND_ASSIGN(auto collected, Table::FromRecordBatchReader(sink_reader.get()));
+
+ ASSERT_GT(collected->num_rows(), 0);
+
+ // wait 1s for completion
+ ASSERT_TRUE(plan->finished().Wait(/*seconds=*/1)) << "ExecPlan didn't finish within 1s";
+}
+
+static void MinimalEndToEndBench(benchmark::State& state) {
+ size_t num_batches = state.range(0);
+ size_t batch_size = state.range(1);
+ bool async_mode = state.range(2);
+
+ for (auto _ : state) {
+ MinimalEndToEndScan(num_batches, batch_size, async_mode);
+ }
+ state.SetItemsProcessed(state.iterations() * num_batches);
+ state.SetBytesProcessed(state.iterations() * num_batches * batch_size *
+ GetBytesForSchema());
+}
+
+static const std::vector<int32_t> kWorkload = {100, 1000, 10000, 100000};
+
+static void MinimalEndToEnd_Customize(benchmark::internal::Benchmark* b) {
+ for (const int32_t num_batches : kWorkload) {
+ for (const int batch_size : {10, 100, 1000}) {
+ for (const bool async_mode : {true, false}) {
+ b->Args({num_batches, batch_size, async_mode});
+ RecordBatchVector batches =
+ ::arrow::compute::GenerateBatches(GetSchema(), num_batches, batch_size);
+ StoreBatches(num_batches, batch_size, batches);
+ }
+ }
+ }
+ b->ArgNames({"num_batches", "batch_size", "async_mode"});
+ b->UseRealTime();
+}
+
+BENCHMARK(MinimalEndToEndBench)->Apply(MinimalEndToEnd_Customize);
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/scanner_internal.h b/src/arrow/cpp/src/arrow/dataset/scanner_internal.h
new file mode 100644
index 000000000..7a43feb61
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/scanner_internal.h
@@ -0,0 +1,264 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <utility>
+
+#include "arrow/array/array_nested.h"
+#include "arrow/array/util.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/exec.h"
+#include "arrow/dataset/dataset_internal.h"
+#include "arrow/dataset/partition.h"
+#include "arrow/dataset/scanner.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::Executor;
+
+namespace dataset {
+
+inline Result<std::shared_ptr<RecordBatch>> FilterSingleBatch(
+ const std::shared_ptr<RecordBatch>& in, const compute::Expression& filter,
+ const std::shared_ptr<ScanOptions>& options) {
+ compute::ExecContext exec_context{options->pool};
+ ARROW_ASSIGN_OR_RAISE(
+ Datum mask,
+ ExecuteScalarExpression(filter, *options->dataset_schema, in, &exec_context));
+
+ if (mask.is_scalar()) {
+ const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
+ if (mask_scalar.is_valid && mask_scalar.value) {
+ return in;
+ }
+ return in->Slice(0, 0);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ Datum filtered,
+ compute::Filter(in, mask, compute::FilterOptions::Defaults(), &exec_context));
+ return filtered.record_batch();
+}
+
+inline RecordBatchIterator FilterRecordBatch(
+ RecordBatchIterator it, compute::Expression filter,
+ const std::shared_ptr<ScanOptions>& options) {
+ return MakeMaybeMapIterator(
+ [=](std::shared_ptr<RecordBatch> in) -> Result<std::shared_ptr<RecordBatch>> {
+ return FilterSingleBatch(in, filter, options);
+ },
+ std::move(it));
+}
+
+inline Result<std::shared_ptr<RecordBatch>> ProjectSingleBatch(
+ const std::shared_ptr<RecordBatch>& in, const compute::Expression& projection,
+ const std::shared_ptr<ScanOptions>& options) {
+ compute::ExecContext exec_context{options->pool};
+ ARROW_ASSIGN_OR_RAISE(
+ Datum projected,
+ ExecuteScalarExpression(projection, *options->dataset_schema, in, &exec_context));
+
+ DCHECK_EQ(projected.type()->id(), Type::STRUCT);
+ if (projected.shape() == ValueDescr::SCALAR) {
+ // Only virtual columns are projected. Broadcast to an array
+ ARROW_ASSIGN_OR_RAISE(projected, MakeArrayFromScalar(*projected.scalar(),
+ in->num_rows(), options->pool));
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto out,
+ RecordBatch::FromStructArray(projected.array_as<StructArray>()));
+
+ return out->ReplaceSchemaMetadata(in->schema()->metadata());
+}
+
+inline RecordBatchIterator ProjectRecordBatch(
+ RecordBatchIterator it, compute::Expression projection,
+ const std::shared_ptr<ScanOptions>& options) {
+ return MakeMaybeMapIterator(
+ [=](std::shared_ptr<RecordBatch> in) -> Result<std::shared_ptr<RecordBatch>> {
+ return ProjectSingleBatch(in, projection, options);
+ },
+ std::move(it));
+}
+
+class FilterAndProjectScanTask : public ScanTask {
+ public:
+ explicit FilterAndProjectScanTask(std::shared_ptr<ScanTask> task,
+ compute::Expression partition)
+ : ScanTask(task->options(), task->fragment()),
+ task_(std::move(task)),
+ partition_(std::move(partition)) {}
+
+ Result<RecordBatchIterator> Execute() override {
+ ARROW_ASSIGN_OR_RAISE(auto it, task_->Execute());
+
+ ARROW_ASSIGN_OR_RAISE(compute::Expression simplified_filter,
+ SimplifyWithGuarantee(options()->filter, partition_));
+
+ ARROW_ASSIGN_OR_RAISE(compute::Expression simplified_projection,
+ SimplifyWithGuarantee(options()->projection, partition_));
+
+ RecordBatchIterator filter_it =
+ FilterRecordBatch(std::move(it), simplified_filter, options_);
+
+ return ProjectRecordBatch(std::move(filter_it), simplified_projection, options_);
+ }
+
+ Result<RecordBatchIterator> ToFilteredAndProjectedIterator(
+ const RecordBatchVector& rbs) {
+ auto it = MakeVectorIterator(rbs);
+ ARROW_ASSIGN_OR_RAISE(compute::Expression simplified_filter,
+ SimplifyWithGuarantee(options()->filter, partition_));
+
+ ARROW_ASSIGN_OR_RAISE(compute::Expression simplified_projection,
+ SimplifyWithGuarantee(options()->projection, partition_));
+
+ RecordBatchIterator filter_it =
+ FilterRecordBatch(std::move(it), simplified_filter, options_);
+
+ return ProjectRecordBatch(std::move(filter_it), simplified_projection, options_);
+ }
+
+ Result<std::shared_ptr<RecordBatch>> FilterAndProjectBatch(
+ const std::shared_ptr<RecordBatch>& batch) {
+ ARROW_ASSIGN_OR_RAISE(compute::Expression simplified_filter,
+ SimplifyWithGuarantee(options()->filter, partition_));
+
+ ARROW_ASSIGN_OR_RAISE(compute::Expression simplified_projection,
+ SimplifyWithGuarantee(options()->projection, partition_));
+ ARROW_ASSIGN_OR_RAISE(auto filtered,
+ FilterSingleBatch(batch, simplified_filter, options_));
+ return ProjectSingleBatch(filtered, simplified_projection, options_);
+ }
+
+ inline Future<RecordBatchVector> SafeExecute(Executor* executor) override {
+ return task_->SafeExecute(executor).Then(
+ // This should only be run via SerialExecutor so it should be safe to capture
+ // `this`
+ [this](const RecordBatchVector& rbs) -> Result<RecordBatchVector> {
+ ARROW_ASSIGN_OR_RAISE(auto projected_it, ToFilteredAndProjectedIterator(rbs));
+ return projected_it.ToVector();
+ });
+ }
+
+ inline Future<> SafeVisit(
+ Executor* executor,
+ std::function<Status(std::shared_ptr<RecordBatch>)> visitor) override {
+ auto filter_and_project_visitor =
+ [this, visitor](const std::shared_ptr<RecordBatch>& batch) {
+ ARROW_ASSIGN_OR_RAISE(auto projected, FilterAndProjectBatch(batch));
+ return visitor(projected);
+ };
+ return task_->SafeVisit(executor, filter_and_project_visitor);
+ }
+
+ private:
+ std::shared_ptr<ScanTask> task_;
+ compute::Expression partition_;
+};
+
+/// \brief GetScanTaskIterator transforms an Iterator<Fragment> in a
+/// flattened Iterator<ScanTask>.
+inline Result<ScanTaskIterator> GetScanTaskIterator(
+ FragmentIterator fragments, std::shared_ptr<ScanOptions> options) {
+ // Fragment -> ScanTaskIterator
+ auto fn = [options](std::shared_ptr<Fragment> fragment) -> Result<ScanTaskIterator> {
+ ARROW_ASSIGN_OR_RAISE(auto scan_task_it, fragment->Scan(options));
+
+ auto partition = fragment->partition_expression();
+ // Apply the filter and/or projection to incoming RecordBatches by
+ // wrapping the ScanTask with a FilterAndProjectScanTask
+ auto wrap_scan_task =
+ [partition](std::shared_ptr<ScanTask> task) -> std::shared_ptr<ScanTask> {
+ return std::make_shared<FilterAndProjectScanTask>(std::move(task), partition);
+ };
+
+ return MakeMapIterator(wrap_scan_task, std::move(scan_task_it));
+ };
+
+ // Iterator<Iterator<ScanTask>>
+ auto maybe_scantask_it = MakeMaybeMapIterator(fn, std::move(fragments));
+
+ // Iterator<ScanTask>
+ return MakeFlattenIterator(std::move(maybe_scantask_it));
+}
+
+inline Status NestedFieldRefsNotImplemented() {
+ // TODO(ARROW-11259) Several functions (for example, IpcScanTask::Make) assume that
+ // only top level fields will be materialized.
+ return Status::NotImplemented("Nested field references in scans.");
+}
+
+inline Status SetProjection(ScanOptions* options, const compute::Expression& projection) {
+ ARROW_ASSIGN_OR_RAISE(options->projection, projection.Bind(*options->dataset_schema));
+
+ if (options->projection.type()->id() != Type::STRUCT) {
+ return Status::Invalid("Projection ", projection.ToString(),
+ " cannot yield record batches");
+ }
+ options->projected_schema = ::arrow::schema(
+ checked_cast<const StructType&>(*options->projection.type()).fields(),
+ options->dataset_schema->metadata());
+
+ return Status::OK();
+}
+
+inline Status SetProjection(ScanOptions* options, std::vector<compute::Expression> exprs,
+ std::vector<std::string> names) {
+ compute::MakeStructOptions project_options{std::move(names)};
+
+ for (size_t i = 0; i < exprs.size(); ++i) {
+ if (auto ref = exprs[i].field_ref()) {
+ if (!ref->name()) return NestedFieldRefsNotImplemented();
+
+ // set metadata and nullability for plain field references
+ ARROW_ASSIGN_OR_RAISE(auto field, ref->GetOne(*options->dataset_schema));
+ project_options.field_nullability[i] = field->nullable();
+ project_options.field_metadata[i] = field->metadata();
+ }
+ }
+
+ return SetProjection(options,
+ call("make_struct", std::move(exprs), std::move(project_options)));
+}
+
+inline Status SetProjection(ScanOptions* options, std::vector<std::string> names) {
+ std::vector<compute::Expression> exprs(names.size());
+ for (size_t i = 0; i < exprs.size(); ++i) {
+ exprs[i] = compute::field_ref(names[i]);
+ }
+ return SetProjection(options, std::move(exprs), std::move(names));
+}
+
+inline Status SetFilter(ScanOptions* options, const compute::Expression& filter) {
+ for (const auto& ref : FieldsInExpression(filter)) {
+ if (!ref.name()) return NestedFieldRefsNotImplemented();
+
+ RETURN_NOT_OK(ref.FindOne(*options->dataset_schema));
+ }
+ ARROW_ASSIGN_OR_RAISE(options->filter, filter.Bind(*options->dataset_schema));
+ return Status::OK();
+}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/scanner_test.cc b/src/arrow/cpp/src/arrow/dataset/scanner_test.cc
new file mode 100644
index 000000000..83151de2c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/scanner_test.cc
@@ -0,0 +1,1814 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/scanner.h"
+
+#include <memory>
+#include <utility>
+
+#include <gmock/gmock.h>
+
+#include "arrow/compute/api.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/dataset/plan.h"
+#include "arrow/dataset/scanner_internal.h"
+#include "arrow/dataset/test_util.h"
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+#include "arrow/testing/async_test_util.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/generator.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/range.h"
+#include "arrow/util/thread_pool.h"
+#include "arrow/util/vector.h"
+
+using testing::ElementsAre;
+using testing::IsEmpty;
+using testing::UnorderedElementsAreArray;
+
+namespace arrow {
+
+using internal::GetCpuThreadPool;
+using internal::Iota;
+
+namespace dataset {
+
+struct TestScannerParams {
+ bool use_async;
+ bool use_threads;
+ int num_child_datasets;
+ int num_batches;
+ int items_per_batch;
+
+ std::string ToString() const {
+ // GTest requires this to be alphanumeric
+ std::stringstream ss;
+ ss << (use_async ? "Async" : "Sync") << (use_threads ? "Threaded" : "Serial")
+ << num_child_datasets << "d" << num_batches << "b" << items_per_batch << "r";
+ return ss.str();
+ }
+
+ static std::string ToTestNameString(
+ const ::testing::TestParamInfo<TestScannerParams>& info) {
+ return std::to_string(info.index) + info.param.ToString();
+ }
+
+ static std::vector<TestScannerParams> Values() {
+ std::vector<TestScannerParams> values;
+ for (int sync = 0; sync < 2; sync++) {
+ for (int use_threads = 0; use_threads < 2; use_threads++) {
+ values.push_back(
+ {static_cast<bool>(sync), static_cast<bool>(use_threads), 1, 1, 1024});
+ values.push_back(
+ {static_cast<bool>(sync), static_cast<bool>(use_threads), 2, 16, 1024});
+ }
+ }
+ return values;
+ }
+};
+
+std::ostream& operator<<(std::ostream& out, const TestScannerParams& params) {
+ out << (params.use_async ? "async-" : "sync-")
+ << (params.use_threads ? "threaded-" : "serial-") << params.num_child_datasets
+ << "d-" << params.num_batches << "b-" << params.items_per_batch << "i";
+ return out;
+}
+
+class TestScanner : public DatasetFixtureMixinWithParam<TestScannerParams> {
+ protected:
+ std::shared_ptr<Scanner> MakeScanner(std::shared_ptr<Dataset> dataset) {
+ ScannerBuilder builder(std::move(dataset), options_);
+ ARROW_EXPECT_OK(builder.UseThreads(GetParam().use_threads));
+ ARROW_EXPECT_OK(builder.UseAsync(GetParam().use_async));
+ EXPECT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+ return scanner;
+ }
+
+ std::shared_ptr<Scanner> MakeScanner(std::shared_ptr<RecordBatch> batch) {
+ std::vector<std::shared_ptr<RecordBatch>> batches{
+ static_cast<size_t>(GetParam().num_batches), batch};
+
+ DatasetVector children{static_cast<size_t>(GetParam().num_child_datasets),
+ std::make_shared<InMemoryDataset>(batch->schema(), batches)};
+
+ EXPECT_OK_AND_ASSIGN(auto dataset, UnionDataset::Make(batch->schema(), children));
+ return MakeScanner(std::move(dataset));
+ }
+
+ void AssertScannerEqualsRepetitionsOf(
+ std::shared_ptr<Scanner> scanner, std::shared_ptr<RecordBatch> batch,
+ const int64_t total_batches = GetParam().num_child_datasets *
+ GetParam().num_batches) {
+ auto expected = ConstantArrayGenerator::Repeat(total_batches, batch);
+
+ // Verifies that the unified BatchReader is equivalent to flattening all the
+ // structures of the scanner, i.e. Scanner[Dataset[ScanTask[RecordBatch]]]
+ AssertScannerEquals(expected.get(), scanner.get());
+ }
+
+ void AssertScanBatchesEqualRepetitionsOf(
+ std::shared_ptr<Scanner> scanner, std::shared_ptr<RecordBatch> batch,
+ const int64_t total_batches = GetParam().num_child_datasets *
+ GetParam().num_batches) {
+ auto expected = ConstantArrayGenerator::Repeat(total_batches, batch);
+
+ AssertScanBatchesEquals(expected.get(), scanner.get());
+ }
+
+ void AssertScanBatchesUnorderedEqualRepetitionsOf(
+ std::shared_ptr<Scanner> scanner, std::shared_ptr<RecordBatch> batch,
+ const int64_t total_batches = GetParam().num_child_datasets *
+ GetParam().num_batches) {
+ auto expected = ConstantArrayGenerator::Repeat(total_batches, batch);
+
+ AssertScanBatchesUnorderedEquals(expected.get(), scanner.get(), 1);
+ }
+};
+
+TEST_P(TestScanner, Scan) {
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+ AssertScanBatchesUnorderedEqualRepetitionsOf(MakeScanner(batch), batch);
+}
+
+TEST_P(TestScanner, ScanBatches) {
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+ AssertScanBatchesEqualRepetitionsOf(MakeScanner(batch), batch);
+}
+
+TEST_P(TestScanner, ScanBatchesUnordered) {
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+ AssertScanBatchesUnorderedEqualRepetitionsOf(MakeScanner(batch), batch);
+}
+
+TEST_P(TestScanner, ScanWithCappedBatchSize) {
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+ options_->batch_size = GetParam().items_per_batch / 2;
+ auto expected = batch->Slice(GetParam().items_per_batch / 2);
+ AssertScanBatchesEqualRepetitionsOf(
+ MakeScanner(batch), expected,
+ GetParam().num_child_datasets * GetParam().num_batches * 2);
+}
+
+TEST_P(TestScanner, FilteredScan) {
+ SetSchema({field("f64", float64())});
+
+ double value = 0.5;
+ ASSERT_OK_AND_ASSIGN(auto f64,
+ ArrayFromBuilderVisitor(float64(), GetParam().items_per_batch,
+ GetParam().items_per_batch / 2,
+ [&](DoubleBuilder* builder) {
+ builder->UnsafeAppend(value);
+ builder->UnsafeAppend(-value);
+ value += 1.0;
+ }));
+
+ SetFilter(greater(field_ref("f64"), literal(0.0)));
+
+ auto batch = RecordBatch::Make(schema_, f64->length(), {f64});
+
+ value = 0.5;
+ ASSERT_OK_AND_ASSIGN(auto f64_filtered,
+ ArrayFromBuilderVisitor(float64(), GetParam().items_per_batch / 2,
+ [&](DoubleBuilder* builder) {
+ builder->UnsafeAppend(value);
+ value += 1.0;
+ }));
+
+ auto filtered_batch =
+ RecordBatch::Make(schema_, f64_filtered->length(), {f64_filtered});
+
+ AssertScanBatchesEqualRepetitionsOf(MakeScanner(batch), filtered_batch);
+}
+
+TEST_P(TestScanner, ProjectedScan) {
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ SetProjectedColumns({"i32"});
+ auto batch_in = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+ auto batch_out = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch,
+ schema({field("i32", int32())}));
+ AssertScanBatchesUnorderedEqualRepetitionsOf(MakeScanner(batch_in), batch_out);
+}
+
+TEST_P(TestScanner, MaterializeMissingColumn) {
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch_missing_f64 = ConstantArrayGenerator::Zeroes(
+ GetParam().items_per_batch, schema({field("i32", int32())}));
+
+ auto fragment_missing_f64 = std::make_shared<InMemoryFragment>(
+ RecordBatchVector{
+ static_cast<size_t>(GetParam().num_child_datasets * GetParam().num_batches),
+ batch_missing_f64},
+ equal(field_ref("f64"), literal(2.5)));
+
+ ASSERT_OK_AND_ASSIGN(auto f64,
+ ArrayFromBuilderVisitor(
+ float64(), GetParam().items_per_batch,
+ [&](DoubleBuilder* builder) { builder->UnsafeAppend(2.5); }));
+ auto batch_with_f64 =
+ RecordBatch::Make(schema_, f64->length(), {batch_missing_f64->column(0), f64});
+
+ FragmentVector fragments{fragment_missing_f64};
+ auto dataset = std::make_shared<FragmentDataset>(schema_, fragments);
+ auto scanner = MakeScanner(std::move(dataset));
+ AssertScanBatchesEqualRepetitionsOf(scanner, batch_with_f64);
+}
+
+TEST_P(TestScanner, ToTable) {
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+ std::vector<std::shared_ptr<RecordBatch>> batches{
+ static_cast<std::size_t>(GetParam().num_batches * GetParam().num_child_datasets),
+ batch};
+
+ ASSERT_OK_AND_ASSIGN(auto expected, Table::FromRecordBatches(batches));
+
+ auto scanner = MakeScanner(batch);
+ std::shared_ptr<Table> actual;
+
+ // There is no guarantee on the ordering when using multiple threads, but
+ // since the RecordBatch is always the same it will pass.
+ ASSERT_OK_AND_ASSIGN(actual, scanner->ToTable());
+ AssertTablesEqual(*expected, *actual, /*same_chunk_layout=*/false);
+}
+
+TEST_P(TestScanner, ScanWithVisitor) {
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+ auto scanner = MakeScanner(batch);
+ ASSERT_OK(scanner->Scan([batch](TaggedRecordBatch scanned_batch) {
+ AssertBatchesEqual(*batch, *scanned_batch.record_batch, /*same_chunk_layout=*/false);
+ return Status::OK();
+ }));
+}
+
+TEST_P(TestScanner, TakeIndices) {
+ auto batch_size = GetParam().items_per_batch;
+ auto num_batches = GetParam().num_batches;
+ auto num_datasets = GetParam().num_child_datasets;
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ ArrayVector arrays(2);
+ ArrayFromVector<Int32Type>(Iota<int32_t>(batch_size), &arrays[0]);
+ ArrayFromVector<DoubleType>(Iota<double>(static_cast<double>(batch_size)), &arrays[1]);
+ auto batch = RecordBatch::Make(schema_, batch_size, arrays);
+
+ auto scanner = MakeScanner(batch);
+
+ std::shared_ptr<Array> indices;
+ {
+ ArrayFromVector<Int64Type>(Iota(batch_size), &indices);
+ ASSERT_OK_AND_ASSIGN(auto taken, scanner->TakeRows(*indices));
+ ASSERT_OK_AND_ASSIGN(auto expected, Table::FromRecordBatches({batch}));
+ ASSERT_EQ(expected->num_rows(), batch_size);
+ AssertTablesEqual(*expected, *taken, /*same_chunk_layout=*/false);
+ }
+ {
+ ArrayFromVector<Int64Type>({7, 5, 3, 1}, &indices);
+ ASSERT_OK_AND_ASSIGN(auto taken, scanner->TakeRows(*indices));
+ ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable());
+ ASSERT_OK_AND_ASSIGN(auto expected, compute::Take(table, *indices));
+ ASSERT_EQ(expected.table()->num_rows(), 4);
+ AssertTablesEqual(*expected.table(), *taken, /*same_chunk_layout=*/false);
+ }
+ if (num_batches > 1) {
+ ArrayFromVector<Int64Type>({batch_size + 2, batch_size + 1}, &indices);
+ ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable());
+ ASSERT_OK_AND_ASSIGN(auto taken, scanner->TakeRows(*indices));
+ ASSERT_OK_AND_ASSIGN(auto expected, compute::Take(table, *indices));
+ ASSERT_EQ(expected.table()->num_rows(), 2);
+ AssertTablesEqual(*expected.table(), *taken, /*same_chunk_layout=*/false);
+ }
+ if (num_batches > 1) {
+ ArrayFromVector<Int64Type>({1, 3, 5, 7, batch_size + 1, 2 * batch_size + 2},
+ &indices);
+ ASSERT_OK_AND_ASSIGN(auto taken, scanner->TakeRows(*indices));
+ ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable());
+ ASSERT_OK_AND_ASSIGN(auto expected, compute::Take(table, *indices));
+ ASSERT_EQ(expected.table()->num_rows(), 6);
+ AssertTablesEqual(*expected.table(), *taken, /*same_chunk_layout=*/false);
+ }
+ {
+ auto base = num_datasets * num_batches * batch_size;
+ ArrayFromVector<Int64Type>({base + 1}, &indices);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ IndexError,
+ ::testing::HasSubstr("Some indices were out of bounds: " +
+ std::to_string(base + 1)),
+ scanner->TakeRows(*indices));
+ }
+ {
+ auto base = num_datasets * num_batches * batch_size;
+ ArrayFromVector<Int64Type>(
+ {1, 2, base + 1, base + 2, base + 3, base + 4, base + 5, base + 6}, &indices);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ IndexError,
+ ::testing::HasSubstr(
+ "Some indices were out of bounds: " + std::to_string(base + 1) + ", " +
+ std::to_string(base + 2) + ", " + std::to_string(base + 3) + ", ..."),
+ scanner->TakeRows(*indices));
+ }
+}
+
+TEST_P(TestScanner, CountRows) {
+ const auto items_per_batch = GetParam().items_per_batch;
+ const auto num_batches = GetParam().num_batches;
+ const auto num_datasets = GetParam().num_child_datasets;
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ ArrayVector arrays(2);
+ ArrayFromVector<Int32Type>(Iota<int32_t>(static_cast<int32_t>(items_per_batch)),
+ &arrays[0]);
+ ArrayFromVector<DoubleType>(Iota<double>(static_cast<double>(items_per_batch)),
+ &arrays[1]);
+ auto batch = RecordBatch::Make(schema_, items_per_batch, arrays);
+ auto scanner = MakeScanner(batch);
+
+ ASSERT_OK_AND_ASSIGN(auto rows, scanner->CountRows());
+ ASSERT_EQ(rows, num_datasets * num_batches * items_per_batch);
+
+ ASSERT_OK_AND_ASSIGN(options_->filter,
+ greater_equal(field_ref("i32"), literal(64)).Bind(*schema_));
+ ASSERT_OK_AND_ASSIGN(rows, scanner->CountRows());
+ ASSERT_EQ(rows, num_datasets * num_batches * (items_per_batch - 64));
+}
+
+TEST_P(TestScanner, EmptyFragment) {
+ // Regression test for ARROW-13982
+ if (!GetParam().use_async) GTEST_SKIP() << "Test only applies to async scanner";
+
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+ auto empty_batch = ConstantArrayGenerator::Zeroes(0, schema_);
+ std::vector<std::shared_ptr<RecordBatch>> batches{
+ static_cast<std::size_t>(GetParam().num_batches * GetParam().num_child_datasets),
+ batch};
+
+ FragmentVector fragments{
+ std::make_shared<InMemoryFragment>(RecordBatchVector{empty_batch}),
+ std::make_shared<InMemoryFragment>(batches)};
+ auto dataset = std::make_shared<FragmentDataset>(schema_, fragments);
+ auto scanner = MakeScanner(dataset);
+
+ // There is no guarantee on the ordering when using multiple threads, but
+ // since the RecordBatch is always the same (or empty) it will pass.
+ ASSERT_OK_AND_ASSIGN(auto gen, scanner->ScanBatchesAsync());
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto tagged, CollectAsyncGenerator(gen));
+ RecordBatchVector actual_batches;
+ for (const auto& batch : tagged) {
+ actual_batches.push_back(batch.record_batch);
+ }
+
+ ASSERT_OK_AND_ASSIGN(auto expected, Table::FromRecordBatches(batches));
+ ASSERT_OK_AND_ASSIGN(auto actual, Table::FromRecordBatches(std::move(actual_batches)));
+ AssertTablesEqual(*expected, *actual, /*same_chunk_layout=*/false);
+}
+
+class CountRowsOnlyFragment : public InMemoryFragment {
+ public:
+ using InMemoryFragment::InMemoryFragment;
+
+ Future<util::optional<int64_t>> CountRows(
+ compute::Expression predicate, const std::shared_ptr<ScanOptions>&) override {
+ if (compute::FieldsInExpression(predicate).size() > 0) {
+ return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
+ }
+ int64_t sum = 0;
+ for (const auto& batch : record_batches_) {
+ sum += batch->num_rows();
+ }
+ return Future<util::optional<int64_t>>::MakeFinished(sum);
+ }
+ Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions>) override {
+ return Status::Invalid("Don't scan me!");
+ }
+ Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>&) override {
+ return Status::Invalid("Don't scan me!");
+ }
+};
+
+class ScanOnlyFragment : public InMemoryFragment {
+ public:
+ using InMemoryFragment::InMemoryFragment;
+
+ Future<util::optional<int64_t>> CountRows(
+ compute::Expression predicate, const std::shared_ptr<ScanOptions>&) override {
+ return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
+ }
+ Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
+ auto self = shared_from_this();
+ ScanTaskVector tasks{
+ std::make_shared<InMemoryScanTask>(record_batches_, options, self)};
+ return MakeVectorIterator(std::move(tasks));
+ }
+ Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>&) override {
+ return MakeVectorGenerator(record_batches_);
+ }
+};
+
+// Ensure the pipeline does not break on an empty batch
+TEST_P(TestScanner, CountRowsEmpty) {
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto empty_batch = ConstantArrayGenerator::Zeroes(0, schema_);
+ auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+ RecordBatchVector batches = {empty_batch, batch};
+ ScannerBuilder builder(
+ std::make_shared<FragmentDataset>(
+ schema_, FragmentVector{std::make_shared<ScanOnlyFragment>(batches)}),
+ options_);
+ ASSERT_OK(builder.UseAsync(GetParam().use_async));
+ ASSERT_OK(builder.UseThreads(GetParam().use_threads));
+ ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+ ASSERT_OK_AND_EQ(batch->num_rows(), scanner->CountRows());
+}
+
+// Regression test for ARROW-12668: ensure failures are properly handled
+class CountFailFragment : public InMemoryFragment {
+ public:
+ explicit CountFailFragment(RecordBatchVector record_batches)
+ : InMemoryFragment(std::move(record_batches)),
+ count(Future<util::optional<int64_t>>::Make()) {}
+
+ Future<util::optional<int64_t>> CountRows(
+ compute::Expression, const std::shared_ptr<ScanOptions>&) override {
+ return count;
+ }
+
+ Future<util::optional<int64_t>> count;
+};
+TEST_P(TestScanner, CountRowsFailure) {
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+ RecordBatchVector batches = {batch};
+ auto fragment1 = std::make_shared<CountFailFragment>(batches);
+ auto fragment2 = std::make_shared<CountFailFragment>(batches);
+ ScannerBuilder builder(
+ std::make_shared<FragmentDataset>(schema_, FragmentVector{fragment1, fragment2}),
+ options_);
+ ASSERT_OK(builder.UseAsync(GetParam().use_async));
+ ASSERT_OK(builder.UseThreads(GetParam().use_threads));
+ ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+ fragment1->count.MarkFinished(Status::Invalid(""));
+ // Should immediately stop the count
+ ASSERT_RAISES(Invalid, scanner->CountRows());
+ // Fragment 2 doesn't complete until after the count stops - should not break anything
+ // under ASan, etc.
+ fragment2->count.MarkFinished(util::nullopt);
+}
+
+TEST_P(TestScanner, CountRowsWithMetadata) {
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+ RecordBatchVector batches = {batch, batch, batch, batch};
+ ScannerBuilder builder(
+ std::make_shared<FragmentDataset>(
+ schema_, FragmentVector{std::make_shared<CountRowsOnlyFragment>(batches)}),
+ options_);
+ ASSERT_OK(builder.UseAsync(GetParam().use_async));
+ ASSERT_OK(builder.UseThreads(GetParam().use_threads));
+ ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+ ASSERT_OK_AND_EQ(4 * batch->num_rows(), scanner->CountRows());
+
+ ASSERT_OK(builder.Filter(equal(field_ref("i32"), literal(5))));
+ ASSERT_OK_AND_ASSIGN(scanner, builder.Finish());
+ // Scanner should fall back on reading data and hit the error
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Don't scan me!"),
+ scanner->CountRows());
+}
+
+TEST_P(TestScanner, ToRecordBatchReader) {
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+ std::vector<std::shared_ptr<RecordBatch>> batches{
+ static_cast<std::size_t>(GetParam().num_batches * GetParam().num_child_datasets),
+ batch};
+
+ ASSERT_OK_AND_ASSIGN(auto expected, Table::FromRecordBatches(batches));
+
+ std::shared_ptr<Table> actual;
+ auto scanner = MakeScanner(batch);
+ ASSERT_OK_AND_ASSIGN(auto reader, scanner->ToRecordBatchReader());
+ scanner.reset();
+ ASSERT_OK(reader->ReadAll(&actual));
+ AssertTablesEqual(*expected, *actual, /*same_chunk_layout=*/false);
+}
+
+class FailingFragment : public InMemoryFragment {
+ public:
+ using InMemoryFragment::InMemoryFragment;
+ Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
+ int index = 0;
+ auto self = shared_from_this();
+ return MakeFunctionIterator([=]() mutable -> Result<std::shared_ptr<ScanTask>> {
+ if (index > 16) {
+ return Status::Invalid("Oh no, we failed!");
+ }
+ RecordBatchVector batches = {record_batches_[index++ % record_batches_.size()]};
+ return std::make_shared<InMemoryScanTask>(batches, options, self);
+ });
+ }
+
+ Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options) override {
+ struct {
+ Future<std::shared_ptr<RecordBatch>> operator()() {
+ if (index > 16) {
+ return Status::Invalid("Oh no, we failed!");
+ }
+ auto batch = batches[index++ % batches.size()];
+ return Future<std::shared_ptr<RecordBatch>>::MakeFinished(batch);
+ }
+ RecordBatchVector batches;
+ int index = 0;
+ } Generator;
+ Generator.batches = record_batches_;
+ return Generator;
+ }
+};
+
+class FailingExecuteScanTask : public InMemoryScanTask {
+ public:
+ using InMemoryScanTask::InMemoryScanTask;
+
+ Result<RecordBatchIterator> Execute() override {
+ return Status::Invalid("Oh no, we failed!");
+ }
+};
+
+class FailingIterationScanTask : public InMemoryScanTask {
+ public:
+ using InMemoryScanTask::InMemoryScanTask;
+
+ Result<RecordBatchIterator> Execute() override {
+ int index = 0;
+ auto batches = record_batches_;
+ return MakeFunctionIterator(
+ [index, batches]() mutable -> Result<std::shared_ptr<RecordBatch>> {
+ if (index < 1) {
+ return batches[index++];
+ }
+ return Status::Invalid("Oh no, we failed!");
+ });
+ }
+};
+
+template <typename T>
+class FailingScanTaskFragment : public InMemoryFragment {
+ public:
+ using InMemoryFragment::InMemoryFragment;
+ Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
+ auto self = shared_from_this();
+ ScanTaskVector scan_tasks;
+ for (int i = 0; i < 4; i++) {
+ scan_tasks.push_back(std::make_shared<T>(record_batches_, options, self));
+ }
+ return MakeVectorIterator(std::move(scan_tasks));
+ }
+
+ // Unlike the sync case, there's only two places to fail - during
+ // iteration (covered by FailingFragment) or at the initial scan
+ // (covered here)
+ Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options) override {
+ return Status::Invalid("Oh no, we failed!");
+ }
+};
+
+template <typename It, typename GetBatch>
+bool CheckIteratorRaises(const RecordBatch& batch, It batch_it, GetBatch get_batch) {
+ while (true) {
+ auto maybe_batch = batch_it.Next();
+ if (maybe_batch.ok()) {
+ EXPECT_OK_AND_ASSIGN(auto scanned_batch, maybe_batch);
+ if (IsIterationEnd(scanned_batch)) break;
+ AssertBatchesEqual(batch, *get_batch(scanned_batch));
+ } else {
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Oh no, we failed!"),
+ maybe_batch);
+ return true;
+ }
+ }
+ return false;
+}
+
+TEST_P(TestScanner, ScanBatchesFailure) {
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+ RecordBatchVector batches = {batch, batch, batch, batch};
+
+ auto check_scanner = [](const RecordBatch& batch, Scanner* scanner) {
+ auto maybe_batch_it = scanner->ScanBatchesUnordered();
+ if (!maybe_batch_it.ok()) {
+ // SyncScanner can fail here as it eagerly consumes the first value
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Oh no, we failed!"),
+ std::move(maybe_batch_it));
+ } else {
+ ASSERT_OK_AND_ASSIGN(auto batch_it, std::move(maybe_batch_it));
+ EXPECT_TRUE(CheckIteratorRaises(
+ batch, std::move(batch_it),
+ [](const EnumeratedRecordBatch& batch) { return batch.record_batch.value; }))
+ << "ScanBatchesUnordered() did not raise an error";
+ }
+ ASSERT_OK_AND_ASSIGN(auto tagged_batch_it, scanner->ScanBatches());
+ EXPECT_TRUE(CheckIteratorRaises(
+ batch, std::move(tagged_batch_it),
+ [](const TaggedRecordBatch& batch) { return batch.record_batch; }))
+ << "ScanBatches() did not raise an error";
+ };
+
+ // Case 1: failure when getting next scan task
+ {
+ FragmentVector fragments{std::make_shared<FailingFragment>(batches)};
+ auto dataset = std::make_shared<FragmentDataset>(schema_, fragments);
+ auto scanner = MakeScanner(std::move(dataset));
+ check_scanner(*batch, scanner.get());
+ }
+
+ // Case 2: failure when calling ScanTask::Execute
+ {
+ FragmentVector fragments{
+ std::make_shared<FailingScanTaskFragment<FailingExecuteScanTask>>(batches)};
+ auto dataset = std::make_shared<FragmentDataset>(schema_, fragments);
+ auto scanner = MakeScanner(std::move(dataset));
+ check_scanner(*batch, scanner.get());
+ }
+
+ // Case 3: failure when calling RecordBatchIterator::Next
+ {
+ FragmentVector fragments{
+ std::make_shared<FailingScanTaskFragment<FailingIterationScanTask>>(batches)};
+ auto dataset = std::make_shared<FragmentDataset>(schema_, fragments);
+ auto scanner = MakeScanner(std::move(dataset));
+ check_scanner(*batch, scanner.get());
+ }
+}
+
+TEST_P(TestScanner, Head) {
+ auto batch_size = GetParam().items_per_batch;
+ auto num_batches = GetParam().num_batches;
+ auto num_datasets = GetParam().num_child_datasets;
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(batch_size, schema_);
+
+ auto scanner = MakeScanner(batch);
+ std::shared_ptr<Table> expected, actual;
+
+ ASSERT_OK_AND_ASSIGN(expected, Table::FromRecordBatches(schema_, {}));
+ ASSERT_OK_AND_ASSIGN(actual, scanner->Head(0));
+ AssertTablesEqual(*expected, *actual, /*same_chunk_layout=*/false);
+
+ ASSERT_OK_AND_ASSIGN(expected, Table::FromRecordBatches(schema_, {batch}));
+ ASSERT_OK_AND_ASSIGN(actual, scanner->Head(batch_size));
+ AssertTablesEqual(*expected, *actual, /*same_chunk_layout=*/false);
+
+ ASSERT_OK_AND_ASSIGN(expected, Table::FromRecordBatches(schema_, {batch->Slice(0, 1)}));
+ ASSERT_OK_AND_ASSIGN(actual, scanner->Head(1));
+ AssertTablesEqual(*expected, *actual, /*same_chunk_layout=*/false);
+
+ if (num_batches > 1) {
+ ASSERT_OK_AND_ASSIGN(expected,
+ Table::FromRecordBatches(schema_, {batch, batch->Slice(0, 1)}));
+ ASSERT_OK_AND_ASSIGN(actual, scanner->Head(batch_size + 1));
+ AssertTablesEqual(*expected, *actual, /*same_chunk_layout=*/false);
+ }
+
+ ASSERT_OK_AND_ASSIGN(expected, scanner->ToTable());
+ ASSERT_OK_AND_ASSIGN(actual, scanner->Head(batch_size * num_batches * num_datasets));
+ AssertTablesEqual(*expected, *actual, /*same_chunk_layout=*/false);
+
+ ASSERT_OK_AND_ASSIGN(expected, scanner->ToTable());
+ ASSERT_OK_AND_ASSIGN(actual,
+ scanner->Head(batch_size * num_batches * num_datasets + 100));
+ AssertTablesEqual(*expected, *actual, /*same_chunk_layout=*/false);
+}
+
+TEST_P(TestScanner, FromReader) {
+ if (GetParam().use_async) {
+ GTEST_SKIP() << "Async scanner does not support construction from reader";
+ }
+ auto batch_size = GetParam().items_per_batch;
+ auto num_batches = GetParam().num_batches;
+
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(batch_size, schema_);
+ auto source_reader = ConstantArrayGenerator::Repeat(num_batches, batch);
+ auto target_reader = ConstantArrayGenerator::Repeat(num_batches, batch);
+
+ auto builder = ScannerBuilder::FromRecordBatchReader(source_reader);
+ ARROW_EXPECT_OK(builder->UseThreads(GetParam().use_threads));
+ ASSERT_OK_AND_ASSIGN(auto scanner, builder->Finish());
+ AssertScannerEquals(target_reader.get(), scanner.get());
+
+ // Such datasets can only be scanned once (but you can get fragments multiple times)
+ ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
+ batch_it.Next());
+}
+
+INSTANTIATE_TEST_SUITE_P(TestScannerThreading, TestScanner,
+ ::testing::ValuesIn(TestScannerParams::Values()),
+ [](const ::testing::TestParamInfo<TestScannerParams>& info) {
+ return std::to_string(info.index) + info.param.ToString();
+ });
+
+/// These ControlledXyz classes allow for controlling the order in which things are
+/// delivered so that we can test out of order resequencing. The dataset allows
+/// batches to be delivered on any fragment. When delivering batches a num_rows
+/// parameter is taken which can be used to differentiate batches.
+class ControlledFragment : public Fragment {
+ public:
+ explicit ControlledFragment(std::shared_ptr<Schema> schema)
+ : Fragment(literal(true), std::move(schema)),
+ record_batch_generator_(),
+ tracking_generator_(record_batch_generator_) {}
+
+ Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
+ return Status::NotImplemented(
+ "Not needed for testing. Sync can only return things in-order.");
+ }
+ Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override {
+ return physical_schema_;
+ }
+ std::string type_name() const override { return "scanner_test.cc::ControlledFragment"; }
+
+ Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options) override {
+ return tracking_generator_;
+ };
+
+ int NumBatchesRead() { return tracking_generator_.num_read(); }
+
+ void Finish() { ARROW_UNUSED(record_batch_generator_.producer().Close()); }
+ void DeliverBatch(uint32_t num_rows) {
+ auto batch = ConstantArrayGenerator::Zeroes(num_rows, physical_schema_);
+ record_batch_generator_.producer().Push(std::move(batch));
+ }
+
+ private:
+ PushGenerator<std::shared_ptr<RecordBatch>> record_batch_generator_;
+ util::TrackingGenerator<std::shared_ptr<RecordBatch>> tracking_generator_;
+};
+
+// TODO(ARROW-8163) Add testing for fragments arriving out of order
+class ControlledDataset : public Dataset {
+ public:
+ explicit ControlledDataset(int num_fragments)
+ : Dataset(arrow::schema({field("i32", int32())})), fragments_() {
+ for (int i = 0; i < num_fragments; i++) {
+ fragments_.push_back(std::make_shared<ControlledFragment>(schema_));
+ }
+ }
+
+ std::string type_name() const override { return "scanner_test.cc::ControlledDataset"; }
+ Result<std::shared_ptr<Dataset>> ReplaceSchema(
+ std::shared_ptr<Schema> schema) const override {
+ return Status::NotImplemented("Should not be called by unit test");
+ }
+
+ void DeliverBatch(int fragment_index, int num_rows) {
+ fragments_[fragment_index]->DeliverBatch(num_rows);
+ }
+
+ void FinishFragment(int fragment_index) { fragments_[fragment_index]->Finish(); }
+
+ protected:
+ Result<FragmentIterator> GetFragmentsImpl(compute::Expression predicate) override {
+ std::vector<std::shared_ptr<Fragment>> casted_fragments(fragments_.begin(),
+ fragments_.end());
+ return MakeVectorIterator(std::move(casted_fragments));
+ }
+
+ private:
+ std::vector<std::shared_ptr<ControlledFragment>> fragments_;
+};
+
+constexpr int kNumFragments = 2;
+
+class TestReordering : public ::testing::Test {
+ public:
+ void SetUp() override { dataset_ = std::make_shared<ControlledDataset>(kNumFragments); }
+
+ // Given a vector of fragment indices (one per batch) return a vector
+ // (one per fragment) mapping fragment index to the last occurrence of that
+ // index in order
+ //
+ // This allows us to know when to mark a fragment as finished
+ std::vector<int> GetLastIndices(const std::vector<int>& order) {
+ std::vector<int> last_indices(kNumFragments);
+ for (std::size_t i = 0; i < kNumFragments; i++) {
+ auto last_p = std::find(order.rbegin(), order.rend(), static_cast<int>(i));
+ EXPECT_NE(last_p, order.rend());
+ last_indices[i] = static_cast<int>(std::distance(last_p, order.rend())) - 1;
+ }
+ return last_indices;
+ }
+
+ /// We buffer one item in order to enumerate it (technically this could be avoided if
+ /// delivering in order but easier to have a single code path). We also can't deliver
+ /// items that don't come next. These two facts make for some pretty complex logic
+ /// to determine when items are ready to be collected.
+ std::vector<TaggedRecordBatch> DeliverAndCollect(std::vector<int> order,
+ TaggedRecordBatchGenerator gen) {
+ std::vector<TaggedRecordBatch> collected;
+ auto last_indices = GetLastIndices(order);
+ int num_fragments = static_cast<int>(last_indices.size());
+ std::vector<int> batches_seen_for_fragment(num_fragments);
+ auto current_fragment_index = 0;
+ auto seen_fragment = false;
+ for (std::size_t i = 0; i < order.size(); i++) {
+ auto fragment_index = order[i];
+ dataset_->DeliverBatch(fragment_index, static_cast<int>(i));
+ batches_seen_for_fragment[fragment_index]++;
+ if (static_cast<int>(i) == last_indices[fragment_index]) {
+ dataset_->FinishFragment(fragment_index);
+ }
+ if (current_fragment_index == fragment_index) {
+ if (seen_fragment) {
+ EXPECT_FINISHES_OK_AND_ASSIGN(auto next, gen());
+ collected.push_back(std::move(next));
+ } else {
+ seen_fragment = true;
+ }
+ if (static_cast<int>(i) == last_indices[fragment_index]) {
+ // Immediately collect your bonus fragment
+ EXPECT_FINISHES_OK_AND_ASSIGN(auto next, gen());
+ collected.push_back(std::move(next));
+ // Now collect any batches freed up that couldn't be delivered because they came
+ // from the wrong fragment
+ auto last_fragment_index = fragment_index;
+ fragment_index++;
+ seen_fragment = batches_seen_for_fragment[fragment_index] > 0;
+ while (fragment_index < num_fragments &&
+ fragment_index != last_fragment_index) {
+ last_fragment_index = fragment_index;
+ for (int j = 0; j < batches_seen_for_fragment[fragment_index] - 1; j++) {
+ EXPECT_FINISHES_OK_AND_ASSIGN(auto next, gen());
+ collected.push_back(std::move(next));
+ }
+ if (static_cast<int>(i) >= last_indices[fragment_index]) {
+ EXPECT_FINISHES_OK_AND_ASSIGN(auto next, gen());
+ collected.push_back(std::move(next));
+ fragment_index++;
+ if (fragment_index < num_fragments) {
+ seen_fragment = batches_seen_for_fragment[fragment_index] > 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ return collected;
+ }
+
+ struct FragmentStats {
+ int last_index;
+ bool seen;
+ };
+
+ std::vector<FragmentStats> GetFragmentStats(const std::vector<int>& order) {
+ auto last_indices = GetLastIndices(order);
+ std::vector<FragmentStats> fragment_stats;
+ for (std::size_t i = 0; i < last_indices.size(); i++) {
+ fragment_stats.push_back({last_indices[i], false});
+ }
+ return fragment_stats;
+ }
+
+ /// When data arrives out of order then we first have to buffer up 1 item in order to
+ /// know when the last item has arrived (so we can mark it as the last). This means
+ /// sometimes we deliver an item and don't get one (first in a fragment) and sometimes
+ /// we deliver an item and we end up getting two (last in a fragment)
+ std::vector<EnumeratedRecordBatch> DeliverAndCollect(
+ std::vector<int> order, EnumeratedRecordBatchGenerator gen) {
+ std::vector<EnumeratedRecordBatch> collected;
+ auto fragment_stats = GetFragmentStats(order);
+ for (std::size_t i = 0; i < order.size(); i++) {
+ auto fragment_index = order[i];
+ dataset_->DeliverBatch(fragment_index, static_cast<int>(i));
+ if (static_cast<int>(i) == fragment_stats[fragment_index].last_index) {
+ dataset_->FinishFragment(fragment_index);
+ EXPECT_FINISHES_OK_AND_ASSIGN(auto next, gen());
+ collected.push_back(std::move(next));
+ }
+ if (!fragment_stats[fragment_index].seen) {
+ fragment_stats[fragment_index].seen = true;
+ } else {
+ EXPECT_FINISHES_OK_AND_ASSIGN(auto next, gen());
+ collected.push_back(std::move(next));
+ }
+ }
+ return collected;
+ }
+
+ std::shared_ptr<Scanner> MakeScanner(int fragment_readahead = 0) {
+ ScannerBuilder builder(dataset_);
+ // Reordering tests only make sense for async
+ ARROW_EXPECT_OK(builder.UseAsync(true));
+ if (fragment_readahead != 0) {
+ ARROW_EXPECT_OK(builder.FragmentReadahead(fragment_readahead));
+ }
+ EXPECT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+ return scanner;
+ }
+
+ void AssertBatchesInOrder(const std::vector<TaggedRecordBatch>& batches,
+ std::vector<int> expected_order) {
+ ASSERT_EQ(expected_order.size(), batches.size());
+ for (std::size_t i = 0; i < batches.size(); i++) {
+ ASSERT_EQ(expected_order[i], batches[i].record_batch->num_rows());
+ }
+ }
+
+ void AssertBatchesInOrder(const std::vector<EnumeratedRecordBatch>& batches,
+ std::vector<int> expected_batch_indices,
+ std::vector<int> expected_row_sizes) {
+ ASSERT_EQ(expected_batch_indices.size(), batches.size());
+ for (std::size_t i = 0; i < batches.size(); i++) {
+ ASSERT_EQ(expected_row_sizes[i], batches[i].record_batch.value->num_rows());
+ ASSERT_EQ(expected_batch_indices[i], batches[i].record_batch.index);
+ }
+ }
+
+ std::shared_ptr<ControlledDataset> dataset_;
+};
+
+TEST_F(TestReordering, ScanBatches) {
+ auto scanner = MakeScanner();
+ ASSERT_OK_AND_ASSIGN(auto batch_gen, scanner->ScanBatchesAsync());
+ auto collected = DeliverAndCollect({0, 0, 1, 1, 0}, std::move(batch_gen));
+ AssertBatchesInOrder(collected, {0, 1, 4, 2, 3});
+}
+
+TEST_F(TestReordering, ScanBatchesUnordered) {
+ auto scanner = MakeScanner();
+ ASSERT_OK_AND_ASSIGN(auto batch_gen, scanner->ScanBatchesUnorderedAsync());
+ auto collected = DeliverAndCollect({0, 0, 1, 1, 0}, std::move(batch_gen));
+ AssertBatchesInOrder(collected, {0, 0, 1, 1, 2}, {0, 2, 3, 1, 4});
+}
+
+class TestBackpressure : public ::testing::Test {
+ protected:
+ static constexpr int NFRAGMENTS = 10;
+ static constexpr int NBATCHES = 50;
+ static constexpr int NROWS = 10;
+
+ FragmentVector MakeFragmentsAndDeliverInitialBatches() {
+ FragmentVector fragments;
+ for (int i = 0; i < NFRAGMENTS; i++) {
+ controlled_fragments_.emplace_back(std::make_shared<ControlledFragment>(schema_));
+ fragments.push_back(controlled_fragments_[i]);
+ // We only emit one batch on the first fragment. This triggers the sequencing
+ // generator to dig really deep to try and find the second batch
+ int num_to_emit = NBATCHES;
+ if (i == 0) {
+ num_to_emit = 1;
+ }
+ for (int j = 0; j < num_to_emit; j++) {
+ controlled_fragments_[i]->DeliverBatch(NROWS);
+ }
+ }
+ return fragments;
+ }
+
+ void DeliverAdditionalBatches() {
+ // Deliver a bunch of batches that should not be read in
+ for (int i = 1; i < NFRAGMENTS; i++) {
+ for (int j = 0; j < NBATCHES; j++) {
+ controlled_fragments_[i]->DeliverBatch(NROWS);
+ }
+ }
+ }
+
+ std::shared_ptr<Dataset> MakeDataset() {
+ FragmentVector fragments = MakeFragmentsAndDeliverInitialBatches();
+ return std::make_shared<FragmentDataset>(schema_, std::move(fragments));
+ }
+
+ std::shared_ptr<Scanner> MakeScanner() {
+ std::shared_ptr<Dataset> dataset = MakeDataset();
+ std::shared_ptr<ScanOptions> options = std::make_shared<ScanOptions>();
+ ScannerBuilder builder(std::move(dataset), options);
+ ARROW_EXPECT_OK(builder.UseThreads(true));
+ ARROW_EXPECT_OK(builder.UseAsync(true));
+ ARROW_EXPECT_OK(builder.FragmentReadahead(4));
+ EXPECT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+ return scanner;
+ }
+
+ int TotalBatchesRead() {
+ int sum = 0;
+ for (const auto& controlled_fragment : controlled_fragments_) {
+ sum += controlled_fragment->NumBatchesRead();
+ }
+ return sum;
+ }
+
+ template <typename T>
+ void Finish(AsyncGenerator<T> gen) {
+ for (const auto& controlled_fragment : controlled_fragments_) {
+ controlled_fragment->Finish();
+ }
+ ASSERT_FINISHES_OK(VisitAsyncGenerator(gen, [](T batch) { return Status::OK(); }));
+ }
+
+ std::shared_ptr<Schema> schema_ = schema({field("values", int32())});
+ std::vector<std::shared_ptr<ControlledFragment>> controlled_fragments_;
+};
+
+TEST_F(TestBackpressure, ScanBatchesUnordered) {
+ std::shared_ptr<Scanner> scanner = MakeScanner();
+ EXPECT_OK_AND_ASSIGN(AsyncGenerator<EnumeratedRecordBatch> gen,
+ scanner->ScanBatchesUnorderedAsync());
+ ASSERT_FINISHES_OK(gen());
+ // The exact numbers may be imprecise due to threading but we should pretty quickly read
+ // up to our backpressure limit and a little above. We should not be able to go too far
+ // above.
+ BusyWait(30, [&] { return TotalBatchesRead() >= kDefaultBackpressureHigh; });
+ ASSERT_GE(TotalBatchesRead(), kDefaultBackpressureHigh);
+ // Wait for the thread pool to idle. By this point the scanner should have paused
+ // itself This helps with timing on slower CI systems where there is only one core and
+ // the scanner might keep that core until it has scanned all the batches which never
+ // gives the sink a chance to report it is falling behind.
+ GetCpuThreadPool()->WaitForIdle();
+ DeliverAdditionalBatches();
+
+ SleepABit();
+ // Worst case we read in the entire set of initial batches
+ ASSERT_LE(TotalBatchesRead(), NBATCHES * (NFRAGMENTS - 1) + 1);
+
+ Finish(std::move(gen));
+}
+
+TEST_F(TestBackpressure, ScanBatchesOrdered) {
+ std::shared_ptr<Scanner> scanner = MakeScanner();
+ EXPECT_OK_AND_ASSIGN(AsyncGenerator<TaggedRecordBatch> gen,
+ scanner->ScanBatchesAsync());
+ // This future never actually finishes because we only emit the first batch so far and
+ // the scanner delays by one batch. It is enough to start the system pumping though so
+ // we don't need it to finish.
+ Future<TaggedRecordBatch> fut = gen();
+
+ // See note on other test
+ GetCpuThreadPool()->WaitForIdle();
+ // Worst case we read in the entire set of initial batches
+ ASSERT_LE(TotalBatchesRead(), NBATCHES * (NFRAGMENTS - 1) + 1);
+
+ DeliverAdditionalBatches();
+ Finish(std::move(gen));
+}
+
+struct BatchConsumer {
+ explicit BatchConsumer(EnumeratedRecordBatchGenerator generator)
+ : generator(std::move(generator)), next() {}
+
+ void AssertCanConsume() {
+ if (!next.is_valid()) {
+ next = generator();
+ }
+ ASSERT_FINISHES_OK(next);
+ next = Future<EnumeratedRecordBatch>();
+ }
+
+ void AssertCannotConsume() {
+ if (!next.is_valid()) {
+ next = generator();
+ }
+ SleepABit();
+ ASSERT_FALSE(next.is_finished());
+ }
+
+ void AssertFinished() {
+ if (!next.is_valid()) {
+ next = generator();
+ }
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto last, next);
+ ASSERT_TRUE(IsIterationEnd(last));
+ }
+
+ EnumeratedRecordBatchGenerator generator;
+ Future<EnumeratedRecordBatch> next;
+};
+
+TEST_F(TestReordering, FileReadahead) {
+ auto scanner = MakeScanner(/*fragment_readahead=*/1);
+ ASSERT_OK_AND_ASSIGN(auto batch_gen, scanner->ScanBatchesUnorderedAsync());
+ BatchConsumer consumer(std::move(batch_gen));
+ dataset_->DeliverBatch(0, 0);
+ dataset_->DeliverBatch(0, 1);
+ consumer.AssertCanConsume();
+ consumer.AssertCannotConsume();
+ dataset_->DeliverBatch(1, 0);
+ consumer.AssertCannotConsume();
+ dataset_->FinishFragment(1);
+ // Even though fragment 1 is finished we cannot read it because fragment_readahead
+ // is 1 so we should only be reading fragment 0
+ consumer.AssertCannotConsume();
+ dataset_->FinishFragment(0);
+ consumer.AssertCanConsume();
+ consumer.AssertCanConsume();
+ consumer.AssertFinished();
+}
+
+class TestScannerBuilder : public ::testing::Test {
+ void SetUp() override {
+ DatasetVector sources;
+
+ schema_ = schema({
+ field("b", boolean()),
+ field("i8", int8()),
+ field("i16", int16()),
+ field("i32", int32()),
+ field("i64", int64()),
+ });
+
+ ASSERT_OK_AND_ASSIGN(dataset_, UnionDataset::Make(schema_, sources));
+ }
+
+ protected:
+ std::shared_ptr<ScanOptions> options_ = std::make_shared<ScanOptions>();
+ std::shared_ptr<Schema> schema_;
+ std::shared_ptr<Dataset> dataset_;
+};
+
+TEST_F(TestScannerBuilder, TestProject) {
+ ScannerBuilder builder(dataset_, options_);
+
+ // It is valid to request no columns, e.g. `SELECT 1 FROM t WHERE t.a > 0`.
+ // still needs to touch the `a` column.
+ ASSERT_OK(builder.Project({}));
+ ASSERT_OK(builder.Project({"i64", "b", "i8"}));
+ ASSERT_OK(builder.Project({"i16", "i16"}));
+ ASSERT_OK(builder.Project(
+ {field_ref("i16"), call("multiply", {field_ref("i16"), literal(2)})},
+ {"i16 renamed", "i16 * 2"}));
+
+ ASSERT_RAISES(Invalid, builder.Project({"not_found_column"}));
+ ASSERT_RAISES(Invalid, builder.Project({"i8", "not_found_column"}));
+ ASSERT_RAISES(Invalid,
+ builder.Project({field_ref("not_found_column"),
+ call("multiply", {field_ref("i16"), literal(2)})},
+ {"i16 renamed", "i16 * 2"}));
+
+ ASSERT_RAISES(NotImplemented, builder.Project({field_ref(FieldRef("nested", "column"))},
+ {"nested column"}));
+
+ // provided more field names than column exprs or vice versa
+ ASSERT_RAISES(Invalid, builder.Project({}, {"i16 renamed", "i16 * 2"}));
+ ASSERT_RAISES(Invalid, builder.Project({literal(2), field_ref("a")}, {"a"}));
+}
+
+TEST_F(TestScannerBuilder, TestFilter) {
+ ScannerBuilder builder(dataset_, options_);
+
+ ASSERT_OK(builder.Filter(literal(true)));
+ ASSERT_OK(builder.Filter(equal(field_ref("i64"), literal<int64_t>(10))));
+ ASSERT_OK(builder.Filter(or_(equal(field_ref("i64"), literal<int64_t>(10)),
+ equal(field_ref("b"), literal(true)))));
+
+ ASSERT_OK(builder.Filter(equal(field_ref("i64"), literal<double>(10))));
+
+ ASSERT_RAISES(Invalid, builder.Filter(equal(field_ref("not_a_column"), literal(true))));
+
+ ASSERT_RAISES(
+ NotImplemented,
+ builder.Filter(equal(field_ref(FieldRef("nested", "column")), literal(true))));
+
+ ASSERT_RAISES(Invalid,
+ builder.Filter(or_(equal(field_ref("i64"), literal<int64_t>(10)),
+ equal(field_ref("not_a_column"), literal(true)))));
+}
+
+TEST(ScanOptions, TestMaterializedFields) {
+ auto i32 = field("i32", int32());
+ auto i64 = field("i64", int64());
+ auto opts = std::make_shared<ScanOptions>();
+
+ // empty dataset, project nothing = nothing materialized
+ opts->dataset_schema = schema({});
+ ASSERT_OK(SetProjection(opts.get(), {}, {}));
+ EXPECT_THAT(opts->MaterializedFields(), IsEmpty());
+
+ // non-empty dataset, project nothing = nothing materialized
+ opts->dataset_schema = schema({i32, i64});
+ EXPECT_THAT(opts->MaterializedFields(), IsEmpty());
+
+ // project nothing, filter on i32 = materialize i32
+ opts->filter = equal(field_ref("i32"), literal(10));
+ EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32"));
+
+ // project i32 & i64, filter nothing = materialize i32 & i64
+ opts->filter = literal(true);
+ ASSERT_OK(SetProjection(opts.get(), {"i32", "i64"}));
+ EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i64"));
+
+ // project i32 + i64, filter nothing = materialize i32 & i64
+ opts->filter = literal(true);
+ ASSERT_OK(SetProjection(opts.get(), {call("add", {field_ref("i32"), field_ref("i64")})},
+ {"i32 + i64"}));
+ EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i64"));
+
+ // project i32, filter nothing = materialize i32
+ ASSERT_OK(SetProjection(opts.get(), {"i32"}));
+ EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32"));
+
+ // project i32, filter on i32 = materialize i32 (reported twice)
+ opts->filter = equal(field_ref("i32"), literal(10));
+ EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i32"));
+
+ // project i32, filter on i32 & i64 = materialize i64, i32 (reported twice)
+ opts->filter = less(field_ref("i32"), field_ref("i64"));
+ EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i64", "i32"));
+
+ // project i32, filter on i64 = materialize i32 & i64
+ opts->filter = equal(field_ref("i64"), literal(10));
+ EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i64", "i32"));
+}
+
+namespace {
+
+struct TestPlan {
+ explicit TestPlan(compute::ExecContext* ctx = compute::default_exec_context())
+ : plan(compute::ExecPlan::Make(ctx).ValueOrDie()) {
+ internal::Initialize();
+ }
+
+ Future<std::vector<compute::ExecBatch>> Run() {
+ RETURN_NOT_OK(plan->Validate());
+ RETURN_NOT_OK(plan->StartProducing());
+
+ auto collected_fut = CollectAsyncGenerator(sink_gen);
+
+ return AllComplete({plan->finished(), Future<>(collected_fut)})
+ .Then([collected_fut]() -> Result<std::vector<compute::ExecBatch>> {
+ ARROW_ASSIGN_OR_RAISE(auto collected, collected_fut.result());
+ return ::arrow::internal::MapVector(
+ [](util::optional<compute::ExecBatch> batch) { return std::move(*batch); },
+ std::move(collected));
+ });
+ }
+
+ compute::ExecPlan* get() { return plan.get(); }
+
+ std::shared_ptr<compute::ExecPlan> plan;
+ AsyncGenerator<util::optional<compute::ExecBatch>> sink_gen;
+};
+
+struct DatasetAndBatches {
+ std::shared_ptr<Dataset> dataset;
+ std::vector<compute::ExecBatch> batches;
+};
+
+DatasetAndBatches MakeBasicDataset() {
+ const auto dataset_schema = ::arrow::schema({
+ field("a", int32()),
+ field("b", boolean()),
+ field("c", int32()),
+ });
+
+ const auto physical_schema = SchemaFromColumnNames(dataset_schema, {"a", "b"});
+
+ RecordBatchVector record_batches{
+ RecordBatchFromJSON(physical_schema, R"([{"a": 1, "b": null},
+ {"a": 2, "b": true}])"),
+ RecordBatchFromJSON(physical_schema, R"([{"a": null, "b": true},
+ {"a": 3, "b": false}])"),
+ RecordBatchFromJSON(physical_schema, R"([{"a": null, "b": true},
+ {"a": 4, "b": false}])"),
+ RecordBatchFromJSON(physical_schema, R"([{"a": 5, "b": null},
+ {"a": 6, "b": false},
+ {"a": 7, "b": false}])"),
+ };
+
+ auto dataset = std::make_shared<FragmentDataset>(
+ dataset_schema,
+ FragmentVector{
+ std::make_shared<InMemoryFragment>(
+ physical_schema, RecordBatchVector{record_batches[0], record_batches[1]},
+ equal(field_ref("c"), literal(23))),
+ std::make_shared<InMemoryFragment>(
+ physical_schema, RecordBatchVector{record_batches[2], record_batches[3]},
+ equal(field_ref("c"), literal(47))),
+ });
+
+ std::vector<compute::ExecBatch> batches;
+
+ auto batch_it = record_batches.begin();
+ for (int fragment_index = 0; fragment_index < 2; ++fragment_index) {
+ for (int batch_index = 0; batch_index < 2; ++batch_index) {
+ const auto& batch = *batch_it++;
+
+ // the scanned ExecBatches will begin with physical columns
+ batches.emplace_back(*batch);
+
+ // a placeholder will be inserted for partition field "c"
+ batches.back().values.emplace_back(std::make_shared<Int32Scalar>());
+
+ // scanned batches will be augmented with fragment and batch indices
+ batches.back().values.emplace_back(fragment_index);
+ batches.back().values.emplace_back(batch_index);
+
+ // ... and with the last-in-fragment flag
+ batches.back().values.emplace_back(batch_index == 1);
+
+ // each batch carries a guarantee inherited from its Fragment's partition expression
+ batches.back().guarantee =
+ equal(field_ref("c"), literal(fragment_index == 0 ? 23 : 47));
+ }
+ }
+
+ return {dataset, batches};
+}
+
+compute::Expression Materialize(std::vector<std::string> names,
+ bool include_aug_fields = false) {
+ if (include_aug_fields) {
+ for (auto aug_name : {"__fragment_index", "__batch_index", "__last_in_fragment"}) {
+ names.emplace_back(aug_name);
+ }
+ }
+
+ std::vector<compute::Expression> exprs;
+ for (const auto& name : names) {
+ exprs.push_back(field_ref(name));
+ }
+
+ return project(exprs, names);
+}
+} // namespace
+
+TEST(ScanNode, Schema) {
+ TestPlan plan;
+
+ auto basic = MakeBasicDataset();
+
+ auto options = std::make_shared<ScanOptions>();
+ options->use_async = true;
+ options->projection = Materialize({}); // set an empty projection
+
+ ASSERT_OK_AND_ASSIGN(auto scan,
+ compute::MakeExecNode("scan", plan.get(), {},
+ ScanNodeOptions{basic.dataset, options}));
+
+ auto fields = basic.dataset->schema()->fields();
+ fields.push_back(field("__fragment_index", int32()));
+ fields.push_back(field("__batch_index", int32()));
+ fields.push_back(field("__last_in_fragment", boolean()));
+ // output_schema is *always* the full augmented dataset schema, regardless of projection
+ // (but some columns *may* be placeholder null Scalars if not projected)
+ AssertSchemaEqual(Schema(fields), *scan->output_schema());
+}
+
+TEST(ScanNode, Trivial) {
+ TestPlan plan;
+
+ auto basic = MakeBasicDataset();
+
+ auto options = std::make_shared<ScanOptions>();
+ options->use_async = true;
+ // ensure all fields are materialized
+ options->projection = Materialize({"a", "b", "c"}, /*include_aug_fields=*/true);
+
+ ASSERT_OK(compute::Declaration::Sequence(
+ {
+ {"scan", ScanNodeOptions{basic.dataset, options}},
+ {"sink", compute::SinkNodeOptions{&plan.sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ // trivial scan: the batches are returned unmodified
+ auto expected = basic.batches;
+ ASSERT_THAT(plan.Run(), Finishes(ResultWith(UnorderedElementsAreArray(expected))));
+}
+
+TEST(ScanNode, FilteredOnVirtualColumn) {
+ TestPlan plan;
+
+ auto basic = MakeBasicDataset();
+
+ auto options = std::make_shared<ScanOptions>();
+ options->use_async = true;
+ options->filter = less(field_ref("c"), literal(30));
+ // ensure all fields are materialized
+ options->projection = Materialize({"a", "b", "c"}, /*include_aug_fields=*/true);
+
+ ASSERT_OK(compute::Declaration::Sequence(
+ {
+ {"scan", ScanNodeOptions{basic.dataset, options}},
+ {"sink", compute::SinkNodeOptions{&plan.sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ auto expected = basic.batches;
+
+ // only the first fragment will make it past the filter
+ expected.pop_back();
+ expected.pop_back();
+
+ ASSERT_THAT(plan.Run(), Finishes(ResultWith(UnorderedElementsAreArray(expected))));
+}
+
+TEST(ScanNode, DeferredFilterOnPhysicalColumn) {
+ TestPlan plan;
+
+ auto basic = MakeBasicDataset();
+
+ auto options = std::make_shared<ScanOptions>();
+ options->use_async = true;
+ options->filter = greater(field_ref("a"), literal(4));
+ // ensure all fields are materialized
+ options->projection = Materialize({"a", "b", "c"}, /*include_aug_fields=*/true);
+
+ ASSERT_OK(compute::Declaration::Sequence(
+ {
+ {"scan", ScanNodeOptions{basic.dataset, options}},
+ {"sink", compute::SinkNodeOptions{&plan.sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ // No post filtering is performed by ScanNode: all batches will be yielded whole.
+ // To filter out rows from individual batches, construct a FilterNode.
+ auto expected = basic.batches;
+
+ ASSERT_THAT(plan.Run(), Finishes(ResultWith(UnorderedElementsAreArray(expected))));
+}
+
+TEST(ScanNode, DISABLED_ProjectionPushdown) {
+ // ARROW-13263
+ TestPlan plan;
+
+ auto basic = MakeBasicDataset();
+
+ auto options = std::make_shared<ScanOptions>();
+ options->use_async = true;
+ options->projection = Materialize({"b"}, /*include_aug_fields=*/true);
+
+ ASSERT_OK(compute::Declaration::Sequence(
+ {
+ {"scan", ScanNodeOptions{basic.dataset, options}},
+ {"sink", compute::SinkNodeOptions{&plan.sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ auto expected = basic.batches;
+
+ int a_index = basic.dataset->schema()->GetFieldIndex("a");
+ int c_index = basic.dataset->schema()->GetFieldIndex("c");
+ for (auto& batch : expected) {
+ // "a", "c" were not projected or filtered so they are dropped eagerly
+ batch.values[a_index] = MakeNullScalar(batch.values[a_index].type());
+ batch.values[c_index] = MakeNullScalar(batch.values[c_index].type());
+ }
+
+ ASSERT_THAT(plan.Run(), Finishes(ResultWith(UnorderedElementsAreArray(expected))));
+}
+
+TEST(ScanNode, MaterializationOfVirtualColumn) {
+ TestPlan plan;
+
+ auto basic = MakeBasicDataset();
+
+ auto options = std::make_shared<ScanOptions>();
+ options->use_async = true;
+ options->projection = Materialize({"a", "b", "c"}, /*include_aug_fields=*/true);
+
+ ASSERT_OK(compute::Declaration::Sequence(
+ {
+ {"scan", ScanNodeOptions{basic.dataset, options}},
+ {"augmented_project",
+ compute::ProjectNodeOptions{
+ {field_ref("a"), field_ref("b"), field_ref("c")}}},
+ {"sink", compute::SinkNodeOptions{&plan.sink_gen}},
+ })
+ .AddToPlan(plan.get()));
+
+ auto expected = basic.batches;
+
+ for (auto& batch : expected) {
+ // ProjectNode overwrites "c" placeholder with non-null drawn from guarantee
+ const auto& value = *batch.guarantee.call()->arguments[1].literal();
+ batch.values[2] = value;
+ }
+
+ ASSERT_THAT(plan.Run(), Finishes(ResultWith(UnorderedElementsAreArray(expected))));
+}
+
+TEST(ScanNode, MinimalEndToEnd) {
+ // NB: This test is here for didactic purposes
+
+ // Specify a MemoryPool and ThreadPool for the ExecPlan
+ compute::ExecContext exec_context(default_memory_pool(),
+ ::arrow::internal::GetCpuThreadPool());
+
+ // ensure arrow::dataset node factories are in the registry
+ arrow::dataset::internal::Initialize();
+
+ // A ScanNode is constructed from an ExecPlan (into which it is inserted),
+ // a Dataset (whose batches will be scanned), and ScanOptions (to specify a filter for
+ // predicate pushdown, a projection to skip materialization of unnecessary columns, ...)
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<compute::ExecPlan> plan,
+ compute::ExecPlan::Make(&exec_context));
+
+ std::shared_ptr<Dataset> dataset = std::make_shared<InMemoryDataset>(
+ TableFromJSON(schema({field("a", int32()), field("b", boolean())}),
+ {
+ R"([{"a": 1, "b": null},
+ {"a": 2, "b": true}])",
+ R"([{"a": null, "b": true},
+ {"a": 3, "b": false}])",
+ R"([{"a": null, "b": true},
+ {"a": 4, "b": false}])",
+ R"([{"a": 5, "b": null},
+ {"a": 6, "b": false},
+ {"a": 7, "b": false}])",
+ }));
+
+ auto options = std::make_shared<ScanOptions>();
+ // sync scanning is not supported by ScanNode
+ options->use_async = true;
+ // specify the filter
+ compute::Expression b_is_true = field_ref("b");
+ options->filter = b_is_true;
+ // for now, specify the projection as the full project expression (eventually this can
+ // just be a list of materialized field names)
+ compute::Expression a_times_2 = call("multiply", {field_ref("a"), literal(2)});
+ options->projection =
+ call("make_struct", {a_times_2}, compute::MakeStructOptions{{"a * 2"}});
+
+ // construct the scan node
+ ASSERT_OK_AND_ASSIGN(
+ compute::ExecNode * scan,
+ compute::MakeExecNode("scan", plan.get(), {}, ScanNodeOptions{dataset, options}));
+
+ // pipe the scan node into a filter node
+ ASSERT_OK_AND_ASSIGN(compute::ExecNode * filter,
+ compute::MakeExecNode("filter", plan.get(), {scan},
+ compute::FilterNodeOptions{b_is_true}));
+
+ // pipe the filter node into a project node
+ // NB: we're using the project node factory which preserves fragment/batch index
+ // tagging, so we *can* reorder later if we choose. The tags will not appear in
+ // our output.
+ ASSERT_OK_AND_ASSIGN(compute::ExecNode * project,
+ compute::MakeExecNode("augmented_project", plan.get(), {filter},
+ compute::ProjectNodeOptions{{a_times_2}}));
+
+ // finally, pipe the project node into a sink node
+ AsyncGenerator<util::optional<compute::ExecBatch>> sink_gen;
+ ASSERT_OK_AND_ASSIGN(compute::ExecNode * sink,
+ compute::MakeExecNode("ordered_sink", plan.get(), {project},
+ compute::SinkNodeOptions{&sink_gen}));
+
+ ASSERT_THAT(plan->sinks(), ElementsAre(sink));
+
+ // translate sink_gen (async) to sink_reader (sync)
+ std::shared_ptr<RecordBatchReader> sink_reader = compute::MakeGeneratorReader(
+ schema({field("a * 2", int32())}), std::move(sink_gen), exec_context.memory_pool());
+
+ // start the ExecPlan
+ ASSERT_OK(plan->StartProducing());
+
+ // collect sink_reader into a Table
+ ASSERT_OK_AND_ASSIGN(auto collected, Table::FromRecordBatchReader(sink_reader.get()));
+
+ // Sort table
+ ASSERT_OK_AND_ASSIGN(
+ auto indices,
+ compute::SortIndices(collected, compute::SortOptions({compute::SortKey(
+ "a * 2", compute::SortOrder::Ascending)})));
+ ASSERT_OK_AND_ASSIGN(auto sorted, compute::Take(collected, indices));
+
+ // wait 1s for completion
+ ASSERT_TRUE(plan->finished().Wait(/*seconds=*/1)) << "ExecPlan didn't finish within 1s";
+
+ auto expected = TableFromJSON(schema({field("a * 2", int32())}), {
+ R"([
+ {"a * 2": 4},
+ {"a * 2": null},
+ {"a * 2": null}
+ ])"});
+ AssertTablesEqual(*expected, *sorted.table(), /*same_chunk_layout=*/false);
+}
+
+TEST(ScanNode, MinimalScalarAggEndToEnd) {
+ // NB: This test is here for didactic purposes
+
+ // Specify a MemoryPool and ThreadPool for the ExecPlan
+ compute::ExecContext exec_context(default_memory_pool(), GetCpuThreadPool());
+
+ // ensure arrow::dataset node factories are in the registry
+ arrow::dataset::internal::Initialize();
+
+ // A ScanNode is constructed from an ExecPlan (into which it is inserted),
+ // a Dataset (whose batches will be scanned), and ScanOptions (to specify a filter for
+ // predicate pushdown, a projection to skip materialization of unnecessary columns, ...)
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<compute::ExecPlan> plan,
+ compute::ExecPlan::Make(&exec_context));
+
+ std::shared_ptr<Dataset> dataset = std::make_shared<InMemoryDataset>(
+ TableFromJSON(schema({field("a", int32()), field("b", boolean())}),
+ {
+ R"([{"a": 1, "b": null},
+ {"a": 2, "b": true}])",
+ R"([{"a": null, "b": true},
+ {"a": 3, "b": false}])",
+ R"([{"a": null, "b": true},
+ {"a": 4, "b": false}])",
+ R"([{"a": 5, "b": null},
+ {"a": 6, "b": false},
+ {"a": 7, "b": false}])",
+ }));
+
+ auto options = std::make_shared<ScanOptions>();
+ // sync scanning is not supported by ScanNode
+ options->use_async = true;
+ // specify the filter
+ compute::Expression b_is_true = field_ref("b");
+ options->filter = b_is_true;
+ // for now, specify the projection as the full project expression (eventually this can
+ // just be a list of materialized field names)
+ compute::Expression a_times_2 = call("multiply", {field_ref("a"), literal(2)});
+ options->projection =
+ call("make_struct", {a_times_2}, compute::MakeStructOptions{{"a * 2"}});
+
+ // construct the scan node
+ ASSERT_OK_AND_ASSIGN(
+ compute::ExecNode * scan,
+ compute::MakeExecNode("scan", plan.get(), {}, ScanNodeOptions{dataset, options}));
+
+ // pipe the scan node into a filter node
+ ASSERT_OK_AND_ASSIGN(compute::ExecNode * filter,
+ compute::MakeExecNode("filter", plan.get(), {scan},
+ compute::FilterNodeOptions{b_is_true}));
+
+ // pipe the filter node into a project node
+ ASSERT_OK_AND_ASSIGN(
+ compute::ExecNode * project,
+ compute::MakeExecNode("project", plan.get(), {filter},
+ compute::ProjectNodeOptions{{a_times_2}, {"a * 2"}}));
+
+ // pipe the projection into a scalar aggregate node
+ ASSERT_OK_AND_ASSIGN(
+ compute::ExecNode * aggregate,
+ compute::MakeExecNode(
+ "aggregate", plan.get(), {project},
+ compute::AggregateNodeOptions{{compute::internal::Aggregate{"sum", nullptr}},
+ /*targets=*/{"a * 2"},
+ /*names=*/{"sum(a * 2)"}}));
+
+ // finally, pipe the aggregate node into a sink node
+ AsyncGenerator<util::optional<compute::ExecBatch>> sink_gen;
+ ASSERT_OK_AND_ASSIGN(compute::ExecNode * sink,
+ compute::MakeExecNode("sink", plan.get(), {aggregate},
+ compute::SinkNodeOptions{&sink_gen}));
+
+ ASSERT_THAT(plan->sinks(), ElementsAre(sink));
+
+ // translate sink_gen (async) to sink_reader (sync)
+ std::shared_ptr<RecordBatchReader> sink_reader =
+ compute::MakeGeneratorReader(schema({field("a*2 sum", int64())}),
+ std::move(sink_gen), exec_context.memory_pool());
+
+ // start the ExecPlan
+ ASSERT_OK(plan->StartProducing());
+
+ // collect sink_reader into a Table
+ ASSERT_OK_AND_ASSIGN(auto collected, Table::FromRecordBatchReader(sink_reader.get()));
+
+ // wait 1s for completion
+ ASSERT_TRUE(plan->finished().Wait(/*seconds=*/1)) << "ExecPlan didn't finish within 1s";
+
+ auto expected = TableFromJSON(schema({field("a*2 sum", int64())}), {
+ R"([
+ {"a*2 sum": 4}
+ ])"});
+ AssertTablesEqual(*expected, *collected, /*same_chunk_layout=*/false);
+}
+
+TEST(ScanNode, MinimalGroupedAggEndToEnd) {
+ // NB: This test is here for didactic purposes
+
+ // Specify a MemoryPool and ThreadPool for the ExecPlan
+ compute::ExecContext exec_context(default_memory_pool(), GetCpuThreadPool());
+
+ // ensure arrow::dataset node factories are in the registry
+ arrow::dataset::internal::Initialize();
+
+ // A ScanNode is constructed from an ExecPlan (into which it is inserted),
+ // a Dataset (whose batches will be scanned), and ScanOptions (to specify a filter for
+ // predicate pushdown, a projection to skip materialization of unnecessary columns, ...)
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<compute::ExecPlan> plan,
+ compute::ExecPlan::Make(&exec_context));
+
+ std::shared_ptr<Dataset> dataset = std::make_shared<InMemoryDataset>(
+ TableFromJSON(schema({field("a", int32()), field("b", boolean())}),
+ {
+ R"([{"a": 1, "b": null},
+ {"a": 2, "b": true}])",
+ R"([{"a": null, "b": true},
+ {"a": 3, "b": false}])",
+ R"([{"a": null, "b": true},
+ {"a": 4, "b": false}])",
+ R"([{"a": 5, "b": null},
+ {"a": 6, "b": false},
+ {"a": 7, "b": false}])",
+ }));
+
+ auto options = std::make_shared<ScanOptions>();
+ // sync scanning is not supported by ScanNode
+ options->use_async = true;
+ // specify the filter
+ compute::Expression b_is_true = field_ref("b");
+ options->filter = b_is_true;
+ // for now, specify the projection as the full project expression (eventually this can
+ // just be a list of materialized field names)
+ compute::Expression a_times_2 = call("multiply", {field_ref("a"), literal(2)});
+ compute::Expression b = field_ref("b");
+ options->projection =
+ call("make_struct", {a_times_2, b}, compute::MakeStructOptions{{"a * 2", "b"}});
+
+ // construct the scan node
+ ASSERT_OK_AND_ASSIGN(
+ compute::ExecNode * scan,
+ compute::MakeExecNode("scan", plan.get(), {}, ScanNodeOptions{dataset, options}));
+
+ // pipe the scan node into a project node
+ ASSERT_OK_AND_ASSIGN(
+ compute::ExecNode * project,
+ compute::MakeExecNode("project", plan.get(), {scan},
+ compute::ProjectNodeOptions{{a_times_2, b}, {"a * 2", "b"}}));
+
+ // pipe the projection into a grouped aggregate node
+ ASSERT_OK_AND_ASSIGN(
+ compute::ExecNode * aggregate,
+ compute::MakeExecNode("aggregate", plan.get(), {project},
+ compute::AggregateNodeOptions{
+ {compute::internal::Aggregate{"hash_sum", nullptr}},
+ /*targets=*/{"a * 2"},
+ /*names=*/{"sum(a * 2)"},
+ /*keys=*/{"b"}}));
+
+ // finally, pipe the aggregate node into a sink node
+ AsyncGenerator<util::optional<compute::ExecBatch>> sink_gen;
+ ASSERT_OK_AND_ASSIGN(compute::ExecNode * sink,
+ compute::MakeExecNode("sink", plan.get(), {aggregate},
+ compute::SinkNodeOptions{&sink_gen}));
+
+ ASSERT_THAT(plan->sinks(), ElementsAre(sink));
+
+ // translate sink_gen (async) to sink_reader (sync)
+ std::shared_ptr<RecordBatchReader> sink_reader = compute::MakeGeneratorReader(
+ schema({field("sum(a * 2)", int64()), field("b", boolean())}), std::move(sink_gen),
+ exec_context.memory_pool());
+
+ // start the ExecPlan
+ ASSERT_OK(plan->StartProducing());
+
+ // collect sink_reader into a Table
+ ASSERT_OK_AND_ASSIGN(auto collected, Table::FromRecordBatchReader(sink_reader.get()));
+
+ // Sort table
+ ASSERT_OK_AND_ASSIGN(
+ auto indices, compute::SortIndices(
+ collected, compute::SortOptions({compute::SortKey(
+ "sum(a * 2)", compute::SortOrder::Ascending)})));
+ ASSERT_OK_AND_ASSIGN(auto sorted, compute::Take(collected, indices));
+
+ // wait 1s for completion
+ ASSERT_TRUE(plan->finished().Wait(/*seconds=*/1)) << "ExecPlan didn't finish within 1s";
+
+ auto expected = TableFromJSON(
+ schema({field("sum(a * 2)", int64()), field("b", boolean())}), {
+ R"JSON([
+ {"sum(a * 2)": 4, "b": true},
+ {"sum(a * 2)": 12, "b": null},
+ {"sum(a * 2)": 40, "b": false}
+ ])JSON"});
+ AssertTablesEqual(*expected, *sorted.table(), /*same_chunk_layout=*/false);
+}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/test_util.h b/src/arrow/cpp/src/arrow/dataset/test_util.h
new file mode 100644
index 000000000..722046e5e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/test_util.h
@@ -0,0 +1,1300 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <ciso646>
+#include <functional>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/compute/exec/expression.h"
+#include "arrow/dataset/dataset_internal.h"
+#include "arrow/dataset/discovery.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/scanner_internal.h"
+#include "arrow/filesystem/localfs.h"
+#include "arrow/filesystem/mockfs.h"
+#include "arrow/filesystem/path_util.h"
+#include "arrow/filesystem/test_util.h"
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/generator.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+using internal::TemporaryDir;
+
+namespace dataset {
+
+using compute::call;
+using compute::field_ref;
+using compute::literal;
+
+using compute::and_;
+using compute::equal;
+using compute::greater;
+using compute::greater_equal;
+using compute::is_null;
+using compute::is_valid;
+using compute::less;
+using compute::less_equal;
+using compute::not_;
+using compute::not_equal;
+using compute::or_;
+using compute::project;
+
+using fs::internal::GetAbstractPathExtension;
+
+class FileSourceFixtureMixin : public ::testing::Test {
+ public:
+ std::unique_ptr<FileSource> GetSource(std::shared_ptr<Buffer> buffer) {
+ return ::arrow::internal::make_unique<FileSource>(std::move(buffer));
+ }
+};
+
+template <typename Gen>
+class GeneratedRecordBatch : public RecordBatchReader {
+ public:
+ GeneratedRecordBatch(std::shared_ptr<Schema> schema, Gen gen)
+ : schema_(std::move(schema)), gen_(gen) {}
+
+ std::shared_ptr<Schema> schema() const override { return schema_; }
+
+ Status ReadNext(std::shared_ptr<RecordBatch>* batch) override { return gen_(batch); }
+
+ private:
+ std::shared_ptr<Schema> schema_;
+ Gen gen_;
+};
+
+template <typename Gen>
+std::unique_ptr<GeneratedRecordBatch<Gen>> MakeGeneratedRecordBatch(
+ std::shared_ptr<Schema> schema, Gen&& gen) {
+ return ::arrow::internal::make_unique<GeneratedRecordBatch<Gen>>(
+ schema, std::forward<Gen>(gen));
+}
+
+std::unique_ptr<RecordBatchReader> MakeGeneratedRecordBatch(
+ std::shared_ptr<Schema> schema, int64_t batch_size, int64_t batch_repetitions) {
+ auto batch = random::GenerateBatch(schema->fields(), batch_size, /*seed=*/0);
+ int64_t i = 0;
+ return MakeGeneratedRecordBatch(
+ schema, [batch, i, batch_repetitions](std::shared_ptr<RecordBatch>* out) mutable {
+ *out = i++ < batch_repetitions ? batch : nullptr;
+ return Status::OK();
+ });
+}
+
+void EnsureRecordBatchReaderDrained(RecordBatchReader* reader) {
+ ASSERT_OK_AND_ASSIGN(auto batch, reader->Next());
+ EXPECT_EQ(batch, nullptr);
+}
+
+class DatasetFixtureMixin : public ::testing::Test {
+ public:
+ /// \brief Ensure that record batches found in reader are equals to the
+ /// record batches yielded by the data fragment.
+ void AssertScanTaskEquals(RecordBatchReader* expected, ScanTask* task,
+ bool ensure_drained = true) {
+ ASSERT_OK_AND_ASSIGN(auto it, task->Execute());
+ ARROW_EXPECT_OK(it.Visit([expected](std::shared_ptr<RecordBatch> rhs) -> Status {
+ std::shared_ptr<RecordBatch> lhs;
+ RETURN_NOT_OK(expected->ReadNext(&lhs));
+ EXPECT_NE(lhs, nullptr);
+ AssertBatchesEqual(*lhs, *rhs);
+ return Status::OK();
+ }));
+
+ if (ensure_drained) {
+ EnsureRecordBatchReaderDrained(expected);
+ }
+ }
+
+ /// \brief Assert the value of the next batch yielded by the reader
+ void AssertBatchEquals(RecordBatchReader* expected, const RecordBatch& batch) {
+ std::shared_ptr<RecordBatch> lhs;
+ ASSERT_OK(expected->ReadNext(&lhs));
+ EXPECT_NE(lhs, nullptr);
+ AssertBatchesEqual(*lhs, batch);
+ }
+
+ /// \brief Ensure that record batches found in reader are equals to the
+ /// record batches yielded by the data fragment.
+ void AssertFragmentEquals(RecordBatchReader* expected, Fragment* fragment,
+ bool ensure_drained = true) {
+ ASSERT_OK_AND_ASSIGN(auto it, fragment->Scan(options_));
+
+ ARROW_EXPECT_OK(it.Visit([&](std::shared_ptr<ScanTask> task) -> Status {
+ AssertScanTaskEquals(expected, task.get(), false);
+ return Status::OK();
+ }));
+
+ if (ensure_drained) {
+ EnsureRecordBatchReaderDrained(expected);
+ }
+ }
+
+ /// \brief Ensure that record batches found in reader are equals to the
+ /// record batches yielded by the data fragments of a dataset.
+ void AssertDatasetFragmentsEqual(RecordBatchReader* expected, Dataset* dataset,
+ bool ensure_drained = true) {
+ ASSERT_OK_AND_ASSIGN(auto predicate, options_->filter.Bind(*dataset->schema()));
+ ASSERT_OK_AND_ASSIGN(auto it, dataset->GetFragments(predicate));
+
+ ARROW_EXPECT_OK(it.Visit([&](std::shared_ptr<Fragment> fragment) -> Status {
+ AssertFragmentEquals(expected, fragment.get(), false);
+ return Status::OK();
+ }));
+
+ if (ensure_drained) {
+ EnsureRecordBatchReaderDrained(expected);
+ }
+ }
+
+ /// \brief Ensure that record batches found in reader are equals to the
+ /// record batches yielded by a scanner.
+ void AssertScannerEquals(RecordBatchReader* expected, Scanner* scanner,
+ bool ensure_drained = true) {
+ ASSERT_OK_AND_ASSIGN(auto it, scanner->ScanBatches());
+
+ ARROW_EXPECT_OK(it.Visit([&](TaggedRecordBatch batch) -> Status {
+ std::shared_ptr<RecordBatch> lhs;
+ RETURN_NOT_OK(expected->ReadNext(&lhs));
+ EXPECT_NE(lhs, nullptr);
+ AssertBatchesEqual(*lhs, *batch.record_batch);
+ return Status::OK();
+ }));
+
+ if (ensure_drained) {
+ EnsureRecordBatchReaderDrained(expected);
+ }
+ }
+
+ /// \brief Ensure that record batches found in reader are equals to the
+ /// record batches yielded by a scanner.
+ void AssertScanBatchesEquals(RecordBatchReader* expected, Scanner* scanner,
+ bool ensure_drained = true) {
+ ASSERT_OK_AND_ASSIGN(auto it, scanner->ScanBatches());
+
+ ARROW_EXPECT_OK(it.Visit([&](TaggedRecordBatch batch) -> Status {
+ AssertBatchEquals(expected, *batch.record_batch);
+ return Status::OK();
+ }));
+
+ if (ensure_drained) {
+ EnsureRecordBatchReaderDrained(expected);
+ }
+ }
+
+ /// \brief Ensure that record batches found in reader are equals to the
+ /// record batches yielded by a scanner.
+ void AssertScanBatchesUnorderedEquals(RecordBatchReader* expected, Scanner* scanner,
+ int expected_batches_per_fragment,
+ bool ensure_drained = true) {
+ ASSERT_OK_AND_ASSIGN(auto it, scanner->ScanBatchesUnordered());
+
+ // ToVector does not work since EnumeratedRecordBatch is not comparable
+ std::vector<EnumeratedRecordBatch> batches;
+ for (;;) {
+ ASSERT_OK_AND_ASSIGN(auto batch, it.Next());
+ if (IsIterationEnd(batch)) break;
+ batches.push_back(std::move(batch));
+ }
+ std::sort(batches.begin(), batches.end(),
+ [](const EnumeratedRecordBatch& left,
+ const EnumeratedRecordBatch& right) -> bool {
+ if (left.fragment.index < right.fragment.index) {
+ return true;
+ }
+ if (left.fragment.index > right.fragment.index) {
+ return false;
+ }
+ return left.record_batch.index < right.record_batch.index;
+ });
+
+ int fragment_counter = 0;
+ bool saw_last_fragment = false;
+ int batch_counter = 0;
+
+ for (const auto& batch : batches) {
+ if (batch_counter == 0) {
+ EXPECT_FALSE(saw_last_fragment);
+ }
+ EXPECT_EQ(batch_counter++, batch.record_batch.index);
+ auto last_batch = batch_counter == expected_batches_per_fragment;
+ EXPECT_EQ(last_batch, batch.record_batch.last);
+ EXPECT_EQ(fragment_counter, batch.fragment.index);
+ if (last_batch) {
+ fragment_counter++;
+ batch_counter = 0;
+ }
+ saw_last_fragment = batch.fragment.last;
+ AssertBatchEquals(expected, *batch.record_batch.value);
+ }
+
+ if (ensure_drained) {
+ EnsureRecordBatchReaderDrained(expected);
+ }
+ }
+
+ /// \brief Ensure that record batches found in reader are equals to the
+ /// record batches yielded by a dataset.
+ void AssertDatasetEquals(RecordBatchReader* expected, Dataset* dataset,
+ bool ensure_drained = true) {
+ ASSERT_OK_AND_ASSIGN(auto builder, dataset->NewScan());
+ ASSERT_OK_AND_ASSIGN(auto scanner, builder->Finish());
+ AssertScannerEquals(expected, scanner.get());
+
+ if (ensure_drained) {
+ EnsureRecordBatchReaderDrained(expected);
+ }
+ }
+
+ protected:
+ void SetSchema(std::vector<std::shared_ptr<Field>> fields) {
+ schema_ = schema(std::move(fields));
+ options_ = std::make_shared<ScanOptions>();
+ options_->dataset_schema = schema_;
+ ASSERT_OK(SetProjection(options_.get(), schema_->field_names()));
+ SetFilter(literal(true));
+ }
+
+ void SetFilter(compute::Expression filter) {
+ ASSERT_OK_AND_ASSIGN(options_->filter, filter.Bind(*schema_));
+ }
+
+ void SetProjectedColumns(std::vector<std::string> column_names) {
+ ASSERT_OK(SetProjection(options_.get(), std::move(column_names)));
+ }
+
+ std::shared_ptr<Schema> schema_;
+ std::shared_ptr<ScanOptions> options_;
+};
+
+template <typename P>
+class DatasetFixtureMixinWithParam : public DatasetFixtureMixin,
+ public ::testing::WithParamInterface<P> {};
+
+struct TestFormatParams {
+ bool use_async;
+ bool use_threads;
+ int num_batches;
+ int items_per_batch;
+
+ int64_t expected_rows() const { return num_batches * items_per_batch; }
+
+ std::string ToString() const {
+ // GTest requires this to be alphanumeric
+ std::stringstream ss;
+ ss << (use_async ? "Async" : "Sync") << (use_threads ? "Threaded" : "Serial")
+ << num_batches << "b" << items_per_batch << "r";
+ return ss.str();
+ }
+
+ static std::string ToTestNameString(
+ const ::testing::TestParamInfo<TestFormatParams>& info) {
+ return std::to_string(info.index) + info.param.ToString();
+ }
+
+ static std::vector<TestFormatParams> Values() {
+ std::vector<TestFormatParams> values;
+ for (const bool async : std::vector<bool>{true, false}) {
+ for (const bool use_threads : std::vector<bool>{true, false}) {
+ values.push_back(TestFormatParams{async, use_threads, 16, 1024});
+ }
+ }
+ return values;
+ }
+};
+
+std::ostream& operator<<(std::ostream& out, const TestFormatParams& params) {
+ out << params.ToString();
+ return out;
+}
+
+class FileFormatWriterMixin {
+ virtual std::shared_ptr<Buffer> Write(RecordBatchReader* reader) = 0;
+ virtual std::shared_ptr<Buffer> Write(const Table& table) = 0;
+};
+
+/// FormatHelper should be a class with these static methods:
+/// std::shared_ptr<Buffer> Write(RecordBatchReader* reader);
+/// std::shared_ptr<FileFormat> MakeFormat();
+template <typename FormatHelper>
+class FileFormatFixtureMixin : public ::testing::Test {
+ public:
+ constexpr static int64_t kBatchSize = 1UL << 12;
+ constexpr static int64_t kBatchRepetitions = 1 << 5;
+
+ FileFormatFixtureMixin()
+ : format_(FormatHelper::MakeFormat()), opts_(std::make_shared<ScanOptions>()) {}
+
+ int64_t expected_batches() const { return kBatchRepetitions; }
+ int64_t expected_rows() const { return kBatchSize * kBatchRepetitions; }
+
+ std::shared_ptr<FileFragment> MakeFragment(const FileSource& source) {
+ EXPECT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(source));
+ return fragment;
+ }
+
+ std::shared_ptr<FileFragment> MakeFragment(const FileSource& source,
+ compute::Expression partition_expression) {
+ EXPECT_OK_AND_ASSIGN(auto fragment,
+ format_->MakeFragment(source, partition_expression));
+ return fragment;
+ }
+
+ std::shared_ptr<FileSource> GetFileSource(RecordBatchReader* reader) {
+ EXPECT_OK_AND_ASSIGN(auto buffer, FormatHelper::Write(reader));
+ return std::make_shared<FileSource>(std::move(buffer));
+ }
+
+ virtual std::shared_ptr<RecordBatchReader> GetRecordBatchReader(
+ std::shared_ptr<Schema> schema) {
+ return MakeGeneratedRecordBatch(schema, kBatchSize, kBatchRepetitions);
+ }
+
+ Result<std::shared_ptr<io::BufferOutputStream>> GetFileSink() {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ResizableBuffer> buffer,
+ AllocateResizableBuffer(0));
+ return std::make_shared<io::BufferOutputStream>(buffer);
+ }
+
+ void SetSchema(std::vector<std::shared_ptr<Field>> fields) {
+ opts_->dataset_schema = schema(std::move(fields));
+ ASSERT_OK(SetProjection(opts_.get(), opts_->dataset_schema->field_names()));
+ }
+
+ void SetFilter(compute::Expression filter) {
+ ASSERT_OK_AND_ASSIGN(opts_->filter, filter.Bind(*opts_->dataset_schema));
+ }
+
+ void Project(std::vector<std::string> names) {
+ ASSERT_OK(SetProjection(opts_.get(), std::move(names)));
+ }
+
+ // Shared test cases
+ void AssertInspectFailure(const std::string& contents, StatusCode code,
+ const std::string& format_name) {
+ SCOPED_TRACE("Format: " + format_name + " File contents: " + contents);
+ constexpr auto file_name = "herp/derp";
+ auto make_error_message = [&](const std::string& filename) {
+ return "Could not open " + format_name + " input source '" + filename + "':";
+ };
+ const auto buf = std::make_shared<Buffer>(contents);
+ Status status;
+
+ status = format_->Inspect(FileSource(buf)).status();
+ EXPECT_EQ(code, status.code());
+ EXPECT_THAT(status.ToString(), ::testing::HasSubstr(make_error_message("<Buffer>")));
+
+ ASSERT_OK_AND_EQ(false, format_->IsSupported(FileSource(buf)));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto fs, fs::internal::MockFileSystem::Make(fs::kNoTime, {fs::File(file_name)}));
+ status = format_->Inspect({file_name, fs}).status();
+ EXPECT_EQ(code, status.code());
+ EXPECT_THAT(status.ToString(), testing::HasSubstr(make_error_message("herp/derp")));
+
+ fs::FileSelector s;
+ s.base_dir = "/";
+ s.recursive = true;
+ FileSystemFactoryOptions options;
+ ASSERT_OK_AND_ASSIGN(auto factory,
+ FileSystemDatasetFactory::Make(fs, s, format_, options));
+ status = factory->Finish().status();
+ EXPECT_EQ(code, status.code());
+ EXPECT_THAT(
+ status.ToString(),
+ ::testing::AllOf(
+ ::testing::HasSubstr(make_error_message("/herp/derp")),
+ ::testing::HasSubstr(
+ "Error creating dataset. Could not read schema from '/herp/derp':"),
+ ::testing::HasSubstr("Is this a '" + format_->type_name() + "' file?")));
+ }
+
+ void TestInspectFailureWithRelevantError(StatusCode code,
+ const std::string& format_name) {
+ const std::vector<std::string> file_contents{"", "PAR0", "ASDFPAR1", "ARROW1"};
+ for (const auto& contents : file_contents) {
+ AssertInspectFailure(contents, code, format_name);
+ }
+ }
+
+ void TestInspect() {
+ auto reader = GetRecordBatchReader(schema({field("f64", float64())}));
+ auto source = GetFileSource(reader.get());
+
+ ASSERT_OK_AND_ASSIGN(auto actual, format_->Inspect(*source.get()));
+ AssertSchemaEqual(*actual, *reader->schema(), /*check_metadata=*/false);
+ }
+ void TestIsSupported() {
+ auto reader = GetRecordBatchReader(schema({field("f64", float64())}));
+ auto source = GetFileSource(reader.get());
+
+ bool supported = false;
+
+ std::shared_ptr<Buffer> buf = std::make_shared<Buffer>(util::string_view(""));
+ ASSERT_OK_AND_ASSIGN(supported, format_->IsSupported(FileSource(buf)));
+ ASSERT_EQ(supported, false);
+
+ buf = std::make_shared<Buffer>(util::string_view("corrupted"));
+ ASSERT_OK_AND_ASSIGN(supported, format_->IsSupported(FileSource(buf)));
+ ASSERT_EQ(supported, false);
+
+ ASSERT_OK_AND_ASSIGN(supported, format_->IsSupported(*source));
+ EXPECT_EQ(supported, true);
+ }
+ std::shared_ptr<Buffer> WriteToBuffer(
+ std::shared_ptr<Schema> schema,
+ std::shared_ptr<FileWriteOptions> options = nullptr) {
+ auto format = format_;
+ SetSchema(schema->fields());
+ EXPECT_OK_AND_ASSIGN(auto sink, GetFileSink());
+
+ if (!options) options = format->DefaultWriteOptions();
+ EXPECT_OK_AND_ASSIGN(auto writer, format->MakeWriter(sink, schema, options, {}));
+ ARROW_EXPECT_OK(writer->Write(GetRecordBatchReader(schema).get()));
+ ARROW_EXPECT_OK(writer->Finish());
+ EXPECT_OK_AND_ASSIGN(auto written, sink->Finish());
+ return written;
+ }
+ void TestWrite() {
+ auto reader = this->GetRecordBatchReader(schema({field("f64", float64())}));
+ auto source = this->GetFileSource(reader.get());
+ auto written = this->WriteToBuffer(reader->schema());
+ AssertBufferEqual(*written, *source->buffer());
+ }
+ void TestCountRows() {
+ auto options = std::make_shared<ScanOptions>();
+ auto reader = this->GetRecordBatchReader(schema({field("f64", float64())}));
+ auto full_schema = schema({field("f64", float64()), field("part", int64())});
+ auto source = this->GetFileSource(reader.get());
+
+ auto fragment = this->MakeFragment(*source);
+ ASSERT_FINISHES_OK_AND_EQ(util::make_optional<int64_t>(expected_rows()),
+ fragment->CountRows(literal(true), options));
+
+ fragment = this->MakeFragment(*source, equal(field_ref("part"), literal(2)));
+ ASSERT_FINISHES_OK_AND_EQ(util::make_optional<int64_t>(expected_rows()),
+ fragment->CountRows(literal(true), options));
+
+ auto predicate = equal(field_ref("part"), literal(1));
+ ASSERT_OK_AND_ASSIGN(predicate, predicate.Bind(*full_schema));
+ ASSERT_FINISHES_OK_AND_EQ(util::make_optional<int64_t>(0),
+ fragment->CountRows(predicate, options));
+
+ predicate = equal(field_ref("part"), literal(2));
+ ASSERT_OK_AND_ASSIGN(predicate, predicate.Bind(*full_schema));
+ ASSERT_FINISHES_OK_AND_EQ(util::make_optional<int64_t>(expected_rows()),
+ fragment->CountRows(predicate, options));
+
+ predicate = equal(call("add", {field_ref("f64"), literal(3)}), literal(2));
+ ASSERT_OK_AND_ASSIGN(predicate, predicate.Bind(*full_schema));
+ ASSERT_FINISHES_OK_AND_EQ(util::nullopt, fragment->CountRows(predicate, options));
+ }
+
+ protected:
+ std::shared_ptr<typename FormatHelper::FormatType> format_;
+ std::shared_ptr<ScanOptions> opts_;
+};
+
+template <typename FormatHelper>
+class FileFormatScanMixin : public FileFormatFixtureMixin<FormatHelper>,
+ public ::testing::WithParamInterface<TestFormatParams> {
+ public:
+ int64_t expected_batches() const { return GetParam().num_batches; }
+ int64_t expected_rows() const { return GetParam().expected_rows(); }
+
+ std::shared_ptr<RecordBatchReader> GetRecordBatchReader(
+ std::shared_ptr<Schema> schema) override {
+ return MakeGeneratedRecordBatch(schema, GetParam().items_per_batch,
+ GetParam().num_batches);
+ }
+
+ // Scan the fragment through the scanner.
+ RecordBatchIterator Batches(std::shared_ptr<Fragment> fragment) {
+ auto dataset = std::make_shared<FragmentDataset>(opts_->dataset_schema,
+ FragmentVector{fragment});
+ ScannerBuilder builder(dataset, opts_);
+ ARROW_EXPECT_OK(builder.UseAsync(GetParam().use_async));
+ ARROW_EXPECT_OK(builder.UseThreads(GetParam().use_threads));
+ EXPECT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+ EXPECT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
+ return MakeMapIterator([](TaggedRecordBatch tagged) { return tagged.record_batch; },
+ std::move(batch_it));
+ }
+
+ // Scan the fragment directly, without using the scanner.
+ RecordBatchIterator PhysicalBatches(std::shared_ptr<Fragment> fragment) {
+ opts_->use_threads = GetParam().use_threads;
+ if (GetParam().use_async) {
+ EXPECT_OK_AND_ASSIGN(auto batch_gen, fragment->ScanBatchesAsync(opts_));
+ auto batch_it = MakeGeneratorIterator(std::move(batch_gen));
+ return batch_it;
+ }
+ EXPECT_OK_AND_ASSIGN(auto scan_task_it, fragment->Scan(opts_));
+ return MakeFlattenIterator(MakeMaybeMapIterator(
+ [](std::shared_ptr<ScanTask> scan_task) { return scan_task->Execute(); },
+ std::move(scan_task_it)));
+ }
+
+ // Shared test cases
+ void TestScan() {
+ auto reader = GetRecordBatchReader(schema({field("f64", float64())}));
+ auto source = this->GetFileSource(reader.get());
+
+ this->SetSchema(reader->schema()->fields());
+ auto fragment = this->MakeFragment(*source);
+
+ int64_t row_count = 0;
+ for (auto maybe_batch : Batches(fragment)) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ row_count += batch->num_rows();
+ }
+ ASSERT_EQ(row_count, GetParam().expected_rows());
+ }
+ // Ensure batch_size is respected
+ void TestScanBatchSize() {
+ constexpr int kBatchSize = 17;
+ auto reader = GetRecordBatchReader(schema({field("f64", float64())}));
+ auto source = this->GetFileSource(reader.get());
+
+ this->SetSchema(reader->schema()->fields());
+ auto fragment = this->MakeFragment(*source);
+
+ int64_t row_count = 0;
+ opts_->batch_size = kBatchSize;
+ for (auto maybe_batch : Batches(fragment)) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ ASSERT_LE(batch->num_rows(), kBatchSize);
+ row_count += batch->num_rows();
+ }
+ ASSERT_EQ(row_count, GetParam().expected_rows());
+ }
+ // Ensure file formats only return columns needed to fulfill filter/projection
+ void TestScanProjected() {
+ auto f32 = field("f32", float32());
+ auto f64 = field("f64", float64());
+ auto i32 = field("i32", int32());
+ auto i64 = field("i64", int64());
+ this->SetSchema({f64, i64, f32, i32});
+ this->Project({"f64"});
+ this->SetFilter(equal(field_ref("i32"), literal(0)));
+
+ // NB: projection is applied by the scanner; FileFragment does not evaluate it so
+ // we will not drop "i32" even though it is not projected since we need it for
+ // filtering
+ auto expected_schema = schema({f64, i32});
+
+ auto reader = this->GetRecordBatchReader(opts_->dataset_schema);
+ auto source = this->GetFileSource(reader.get());
+ auto fragment = this->MakeFragment(*source);
+
+ int64_t row_count = 0;
+
+ for (auto maybe_batch : PhysicalBatches(fragment)) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ row_count += batch->num_rows();
+ AssertSchemaEqual(*batch->schema(), *expected_schema,
+ /*check_metadata=*/false);
+ }
+
+ ASSERT_EQ(row_count, expected_rows());
+ }
+ void TestScanProjectedMissingCols() {
+ auto f32 = field("f32", float32());
+ auto f64 = field("f64", float64());
+ auto i32 = field("i32", int32());
+ auto i64 = field("i64", int64());
+ this->SetSchema({f64, i64, f32, i32});
+ this->Project({"f64"});
+ this->SetFilter(equal(field_ref("i32"), literal(0)));
+
+ auto reader_without_i32 = this->GetRecordBatchReader(schema({f64, i64, f32}));
+ auto reader_without_f64 = this->GetRecordBatchReader(schema({i64, f32, i32}));
+ auto reader = this->GetRecordBatchReader(schema({f64, i64, f32, i32}));
+
+ auto readers = {reader.get(), reader_without_i32.get(), reader_without_f64.get()};
+ for (auto reader : readers) {
+ SCOPED_TRACE(reader->schema()->ToString());
+ auto source = this->GetFileSource(reader);
+ auto fragment = this->MakeFragment(*source);
+
+ // NB: projection is applied by the scanner; FileFragment does not evaluate it so
+ // we will not drop "i32" even though it is not projected since we need it for
+ // filtering
+ //
+ // in the case where a file doesn't contain a referenced field, we won't
+ // materialize it as nulls later
+ std::shared_ptr<Schema> expected_schema;
+ if (reader == reader_without_i32.get()) {
+ expected_schema = schema({f64});
+ } else if (reader == reader_without_f64.get()) {
+ expected_schema = schema({i32});
+ } else {
+ expected_schema = schema({f64, i32});
+ }
+
+ int64_t row_count = 0;
+ for (auto maybe_batch : PhysicalBatches(fragment)) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ row_count += batch->num_rows();
+ AssertSchemaEqual(*batch->schema(), *expected_schema,
+ /*check_metadata=*/false);
+ }
+ ASSERT_EQ(row_count, expected_rows());
+ }
+ }
+ void TestScanWithVirtualColumn() {
+ auto reader = this->GetRecordBatchReader(schema({field("f64", float64())}));
+ auto source = this->GetFileSource(reader.get());
+ // NB: dataset_schema includes a column not present in the file
+ this->SetSchema({reader->schema()->field(0), field("virtual", int32())});
+ auto fragment = this->MakeFragment(*source);
+
+ ASSERT_OK_AND_ASSIGN(auto physical_schema, fragment->ReadPhysicalSchema());
+ AssertSchemaEqual(Schema({field("f64", float64())}), *physical_schema);
+ {
+ int64_t row_count = 0;
+ for (auto maybe_batch : Batches(fragment)) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ AssertSchemaEqual(*batch->schema(), *opts_->projected_schema);
+ row_count += batch->num_rows();
+ }
+ ASSERT_EQ(row_count, expected_rows());
+ }
+ {
+ int64_t row_count = 0;
+ for (auto maybe_batch : PhysicalBatches(fragment)) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ AssertSchemaEqual(*batch->schema(), *physical_schema);
+ row_count += batch->num_rows();
+ }
+ ASSERT_EQ(row_count, expected_rows());
+ }
+ }
+
+ protected:
+ using FileFormatFixtureMixin<FormatHelper>::opts_;
+};
+
+/// \brief A dummy FileFormat implementation
+class DummyFileFormat : public FileFormat {
+ public:
+ explicit DummyFileFormat(std::shared_ptr<Schema> schema = NULLPTR)
+ : schema_(std::move(schema)) {}
+
+ std::string type_name() const override { return "dummy"; }
+
+ bool Equals(const FileFormat& other) const override {
+ return type_name() == other.type_name() &&
+ schema_->Equals(checked_cast<const DummyFileFormat&>(other).schema_);
+ }
+
+ Result<bool> IsSupported(const FileSource& source) const override { return true; }
+
+ Result<std::shared_ptr<Schema>> Inspect(const FileSource& source) const override {
+ return schema_;
+ }
+
+ /// \brief Open a file for scanning (always returns an empty iterator)
+ Result<ScanTaskIterator> ScanFile(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& fragment) const override {
+ return MakeEmptyIterator<std::shared_ptr<ScanTask>>();
+ }
+
+ Result<std::shared_ptr<FileWriter>> MakeWriter(
+ std::shared_ptr<io::OutputStream> destination, std::shared_ptr<Schema> schema,
+ std::shared_ptr<FileWriteOptions> options,
+ fs::FileLocator destination_locator) const override {
+ return Status::NotImplemented("writing fragment of DummyFileFormat");
+ }
+
+ std::shared_ptr<FileWriteOptions> DefaultWriteOptions() override { return nullptr; }
+
+ protected:
+ std::shared_ptr<Schema> schema_;
+};
+
+class JSONRecordBatchFileFormat : public FileFormat {
+ public:
+ using SchemaResolver = std::function<std::shared_ptr<Schema>(const FileSource&)>;
+
+ explicit JSONRecordBatchFileFormat(std::shared_ptr<Schema> schema)
+ : resolver_([schema](const FileSource&) { return schema; }) {}
+
+ explicit JSONRecordBatchFileFormat(SchemaResolver resolver)
+ : resolver_(std::move(resolver)) {}
+
+ bool Equals(const FileFormat& other) const override { return this == &other; }
+
+ std::string type_name() const override { return "json_record_batch"; }
+
+ /// \brief Return true if the given file extension
+ Result<bool> IsSupported(const FileSource& source) const override { return true; }
+
+ Result<std::shared_ptr<Schema>> Inspect(const FileSource& source) const override {
+ return resolver_(source);
+ }
+
+ /// \brief Open a file for scanning
+ Result<ScanTaskIterator> ScanFile(
+ const std::shared_ptr<ScanOptions>& options,
+ const std::shared_ptr<FileFragment>& fragment) const override {
+ ARROW_ASSIGN_OR_RAISE(auto file, fragment->source().Open());
+ ARROW_ASSIGN_OR_RAISE(int64_t size, file->GetSize());
+ ARROW_ASSIGN_OR_RAISE(auto buffer, file->Read(size));
+ ARROW_ASSIGN_OR_RAISE(auto schema, Inspect(fragment->source()));
+
+ RecordBatchVector batches{RecordBatchFromJSON(schema, util::string_view{*buffer})};
+ return std::make_shared<InMemoryFragment>(std::move(schema), std::move(batches))
+ ->Scan(std::move(options));
+ }
+
+ Result<std::shared_ptr<FileWriter>> MakeWriter(
+ std::shared_ptr<io::OutputStream> destination, std::shared_ptr<Schema> schema,
+ std::shared_ptr<FileWriteOptions> options,
+ fs::FileLocator destination_locator) const override {
+ return Status::NotImplemented("writing fragment of JSONRecordBatchFileFormat");
+ }
+
+ std::shared_ptr<FileWriteOptions> DefaultWriteOptions() override { return nullptr; }
+
+ protected:
+ SchemaResolver resolver_;
+};
+
+struct MakeFileSystemDatasetMixin {
+ std::vector<fs::FileInfo> ParsePathList(const std::string& pathlist) {
+ std::vector<fs::FileInfo> infos;
+
+ std::stringstream ss(pathlist);
+ std::string line;
+ while (std::getline(ss, line)) {
+ auto start = line.find_first_not_of(" \n\r\t");
+ if (start == std::string::npos) {
+ continue;
+ }
+ line.erase(0, start);
+
+ if (line.front() == '#') {
+ continue;
+ }
+
+ if (line.back() == '/') {
+ infos.push_back(fs::Dir(line));
+ continue;
+ }
+
+ infos.push_back(fs::File(line));
+ }
+
+ return infos;
+ }
+
+ void MakeFileSystem(const std::vector<fs::FileInfo>& infos) {
+ ASSERT_OK_AND_ASSIGN(fs_, fs::internal::MockFileSystem::Make(fs::kNoTime, infos));
+ }
+
+ void MakeFileSystem(const std::vector<std::string>& paths) {
+ std::vector<fs::FileInfo> infos{paths.size()};
+ std::transform(paths.cbegin(), paths.cend(), infos.begin(),
+ [](const std::string& p) { return fs::File(p); });
+
+ ASSERT_OK_AND_ASSIGN(fs_, fs::internal::MockFileSystem::Make(fs::kNoTime, infos));
+ }
+
+ void MakeDataset(const std::vector<fs::FileInfo>& infos,
+ compute::Expression root_partition = literal(true),
+ std::vector<compute::Expression> partitions = {},
+ std::shared_ptr<Schema> s = schema({})) {
+ auto n_fragments = infos.size();
+ if (partitions.empty()) {
+ partitions.resize(n_fragments, literal(true));
+ }
+
+ MakeFileSystem(infos);
+ auto format = std::make_shared<DummyFileFormat>(s);
+
+ std::vector<std::shared_ptr<FileFragment>> fragments;
+ for (size_t i = 0; i < n_fragments; i++) {
+ const auto& info = infos[i];
+ if (!info.IsFile()) {
+ continue;
+ }
+
+ ASSERT_OK_AND_ASSIGN(partitions[i], partitions[i].Bind(*s));
+ ASSERT_OK_AND_ASSIGN(auto fragment,
+ format->MakeFragment({info, fs_}, partitions[i]));
+ fragments.push_back(std::move(fragment));
+ }
+
+ ASSERT_OK_AND_ASSIGN(root_partition, root_partition.Bind(*s));
+ ASSERT_OK_AND_ASSIGN(dataset_, FileSystemDataset::Make(s, root_partition, format, fs_,
+ std::move(fragments)));
+ }
+
+ std::shared_ptr<fs::FileSystem> fs_;
+ std::shared_ptr<Dataset> dataset_;
+ std::shared_ptr<ScanOptions> options_;
+};
+
+static const std::string& PathOf(const std::shared_ptr<Fragment>& fragment) {
+ EXPECT_NE(fragment, nullptr);
+ EXPECT_THAT(fragment->type_name(), "dummy");
+ return checked_cast<const FileFragment&>(*fragment).source().path();
+}
+
+class TestFileSystemDataset : public ::testing::Test,
+ public MakeFileSystemDatasetMixin {};
+
+static std::vector<std::string> PathsOf(const FragmentVector& fragments) {
+ std::vector<std::string> paths(fragments.size());
+ std::transform(fragments.begin(), fragments.end(), paths.begin(), PathOf);
+ return paths;
+}
+
+void AssertFilesAre(const std::shared_ptr<Dataset>& dataset,
+ std::vector<std::string> expected) {
+ auto fs_dataset = checked_cast<FileSystemDataset*>(dataset.get());
+ EXPECT_THAT(fs_dataset->files(), testing::UnorderedElementsAreArray(expected));
+}
+
+void AssertFragmentsAreFromPath(FragmentIterator it, std::vector<std::string> expected) {
+ // Ordering is not guaranteed.
+ EXPECT_THAT(PathsOf(IteratorToVector(std::move(it))),
+ testing::UnorderedElementsAreArray(expected));
+}
+
+static std::vector<compute::Expression> PartitionExpressionsOf(
+ const FragmentVector& fragments) {
+ std::vector<compute::Expression> partition_expressions;
+ std::transform(fragments.begin(), fragments.end(),
+ std::back_inserter(partition_expressions),
+ [](const std::shared_ptr<Fragment>& fragment) {
+ return fragment->partition_expression();
+ });
+ return partition_expressions;
+}
+
+void AssertFragmentsHavePartitionExpressions(std::shared_ptr<Dataset> dataset,
+ std::vector<compute::Expression> expected) {
+ ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments());
+ for (auto& expr : expected) {
+ ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*dataset->schema()));
+ }
+ // Ordering is not guaranteed.
+ EXPECT_THAT(PartitionExpressionsOf(IteratorToVector(std::move(fragment_it))),
+ testing::UnorderedElementsAreArray(expected));
+}
+
+struct ArithmeticDatasetFixture {
+ static std::shared_ptr<Schema> schema() {
+ return ::arrow::schema({
+ field("i64", int64()),
+ field("struct", struct_({
+ field("i32", int32()),
+ field("str", utf8()),
+ })),
+ field("u8", uint8()),
+ field("list", list(int32())),
+ field("bool", boolean()),
+ });
+ }
+
+ /// \brief Creates a single JSON record templated with n as follow.
+ ///
+ /// {"i64": n, "struct": {"i32": n, "str": "n"}, "u8": n "list": [n,n], "bool": n %
+ /// 2},
+ static std::string JSONRecordFor(int64_t n) {
+ std::stringstream ss;
+ auto n_i32 = static_cast<int32_t>(n);
+
+ ss << "{";
+ ss << "\"i64\": " << n << ", ";
+ ss << "\"struct\": {";
+ {
+ ss << "\"i32\": " << n_i32 << ", ";
+ ss << R"("str": ")" << std::to_string(n) << "\"";
+ }
+ ss << "}, ";
+ ss << "\"u8\": " << static_cast<int32_t>(n) << ", ";
+ ss << "\"list\": [" << n_i32 << ", " << n_i32 << "], ";
+ ss << "\"bool\": " << (static_cast<bool>(n % 2) ? "true" : "false");
+ ss << "}";
+
+ return ss.str();
+ }
+
+ /// \brief Creates a JSON RecordBatch
+ static std::string JSONRecordBatch(int64_t n) {
+ DCHECK_GT(n, 0);
+
+ auto record = JSONRecordFor(n);
+
+ std::stringstream ss;
+ ss << "[\n";
+ for (int64_t i = 1; i <= n; i++) {
+ if (i != 1) {
+ ss << "\n,";
+ }
+ ss << record;
+ }
+ ss << "]\n";
+ return ss.str();
+ }
+
+ static std::shared_ptr<RecordBatch> GetRecordBatch(int64_t n) {
+ return RecordBatchFromJSON(ArithmeticDatasetFixture::schema(), JSONRecordBatch(n));
+ }
+
+ static std::unique_ptr<RecordBatchReader> GetRecordBatchReader(int64_t n) {
+ DCHECK_GT(n, 0);
+
+ // Functor which generates `n` RecordBatch
+ struct {
+ Status operator()(std::shared_ptr<RecordBatch>* out) {
+ *out = i++ < count ? GetRecordBatch(i) : nullptr;
+ return Status::OK();
+ }
+ int64_t i;
+ int64_t count;
+ } generator{0, n};
+
+ return MakeGeneratedRecordBatch(schema(), std::move(generator));
+ }
+};
+
+class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin {
+ public:
+ using PathAndContent = std::unordered_map<std::string, std::string>;
+
+ void MakeSourceDataset() {
+ PathAndContent source_files;
+
+ source_files["/dataset/year=2018/month=01/dat0.json"] = R"([
+ {"region": "NY", "model": "3", "sales": 742.0, "country": "US"},
+ {"region": "NY", "model": "S", "sales": 304.125, "country": "US"},
+ {"region": "NY", "model": "Y", "sales": 27.5, "country": "US"}
+ ])";
+ source_files["/dataset/year=2018/month=01/dat1.json"] = R"([
+ {"region": "QC", "model": "3", "sales": 512, "country": "CA"},
+ {"region": "QC", "model": "S", "sales": 978, "country": "CA"},
+ {"region": "NY", "model": "X", "sales": 136.25, "country": "US"},
+ {"region": "QC", "model": "X", "sales": 1.0, "country": "CA"},
+ {"region": "QC", "model": "Y", "sales": 69, "country": "CA"}
+ ])";
+ source_files["/dataset/year=2019/month=01/dat0.json"] = R"([
+ {"region": "CA", "model": "3", "sales": 273.5, "country": "US"},
+ {"region": "CA", "model": "S", "sales": 13, "country": "US"},
+ {"region": "CA", "model": "X", "sales": 54, "country": "US"},
+ {"region": "QC", "model": "S", "sales": 10, "country": "CA"},
+ {"region": "CA", "model": "Y", "sales": 21, "country": "US"}
+ ])";
+ source_files["/dataset/year=2019/month=01/dat1.json"] = R"([
+ {"region": "QC", "model": "3", "sales": 152.25, "country": "CA"},
+ {"region": "QC", "model": "X", "sales": 42, "country": "CA"},
+ {"region": "QC", "model": "Y", "sales": 37, "country": "CA"}
+ ])";
+ source_files["/dataset/.pesky"] = "garbage content";
+
+ auto mock_fs = std::make_shared<fs::internal::MockFileSystem>(fs::kNoTime);
+ for (const auto& f : source_files) {
+ ARROW_EXPECT_OK(mock_fs->CreateFile(f.first, f.second, /* recursive */ true));
+ }
+ fs_ = mock_fs;
+
+ /// schema for the whole dataset (both source and destination)
+ source_schema_ = schema({
+ field("region", utf8()),
+ field("model", utf8()),
+ field("sales", float64()),
+ field("year", int32()),
+ field("month", int32()),
+ field("country", utf8()),
+ });
+
+ /// Dummy file format for source dataset. Note that it isn't partitioned on country
+ auto source_format = std::make_shared<JSONRecordBatchFileFormat>(
+ SchemaFromColumnNames(source_schema_, {"region", "model", "sales", "country"}));
+
+ fs::FileSelector s;
+ s.base_dir = "/dataset";
+ s.recursive = true;
+
+ FileSystemFactoryOptions options;
+ options.selector_ignore_prefixes = {"."};
+ options.partitioning = std::make_shared<HivePartitioning>(
+ SchemaFromColumnNames(source_schema_, {"year", "month"}));
+ ASSERT_OK_AND_ASSIGN(auto factory,
+ FileSystemDatasetFactory::Make(fs_, s, source_format, options));
+ ASSERT_OK_AND_ASSIGN(dataset_, factory->Finish());
+
+ scan_options_ = std::make_shared<ScanOptions>();
+ scan_options_->dataset_schema = dataset_->schema();
+ ASSERT_OK(SetProjection(scan_options_.get(), source_schema_->field_names()));
+ }
+
+ void SetWriteOptions(std::shared_ptr<FileWriteOptions> file_write_options) {
+ write_options_.file_write_options = file_write_options;
+ write_options_.filesystem = fs_;
+ write_options_.base_dir = "/new_root/";
+ write_options_.basename_template = "dat_{i}";
+ write_options_.writer_pre_finish = [this](FileWriter* writer) {
+ visited_paths_.push_back(writer->destination().path);
+ return Status::OK();
+ };
+ }
+
+ void DoWrite(std::shared_ptr<Partitioning> desired_partitioning) {
+ write_options_.partitioning = desired_partitioning;
+ auto scanner_builder = ScannerBuilder(dataset_, scan_options_);
+ ASSERT_OK(scanner_builder.UseAsync(true));
+ ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder.Finish());
+ ASSERT_OK(FileSystemDataset::Write(write_options_, scanner));
+
+ // re-discover the written dataset
+ fs::FileSelector s;
+ s.recursive = true;
+ s.base_dir = "/new_root";
+
+ FileSystemFactoryOptions factory_options;
+ factory_options.partitioning = desired_partitioning;
+ ASSERT_OK_AND_ASSIGN(
+ auto factory, FileSystemDatasetFactory::Make(fs_, s, format_, factory_options));
+ ASSERT_OK_AND_ASSIGN(written_, factory->Finish());
+ }
+
+ void TestWriteWithIdenticalPartitioningSchema() {
+ DoWrite(std::make_shared<DirectoryPartitioning>(
+ SchemaFromColumnNames(source_schema_, {"year", "month"})));
+
+ expected_files_["/new_root/2018/1/dat_0"] = R"([
+ {"region": "QC", "model": "X", "sales": 1.0, "country": "CA"},
+ {"region": "NY", "model": "Y", "sales": 27.5, "country": "US"},
+ {"region": "QC", "model": "Y", "sales": 69, "country": "CA"},
+ {"region": "NY", "model": "X", "sales": 136.25, "country": "US"},
+ {"region": "NY", "model": "S", "sales": 304.125, "country": "US"},
+ {"region": "QC", "model": "3", "sales": 512, "country": "CA"},
+ {"region": "NY", "model": "3", "sales": 742.0, "country": "US"},
+ {"region": "QC", "model": "S", "sales": 978, "country": "CA"}
+ ])";
+ expected_files_["/new_root/2019/1/dat_0"] = R"([
+ {"region": "QC", "model": "S", "sales": 10, "country": "CA"},
+ {"region": "CA", "model": "S", "sales": 13, "country": "US"},
+ {"region": "CA", "model": "Y", "sales": 21, "country": "US"},
+ {"region": "QC", "model": "Y", "sales": 37, "country": "CA"},
+ {"region": "QC", "model": "X", "sales": 42, "country": "CA"},
+ {"region": "CA", "model": "X", "sales": 54, "country": "US"},
+ {"region": "QC", "model": "3", "sales": 152.25, "country": "CA"},
+ {"region": "CA", "model": "3", "sales": 273.5, "country": "US"}
+ ])";
+ expected_physical_schema_ =
+ SchemaFromColumnNames(source_schema_, {"region", "model", "sales", "country"});
+
+ AssertWrittenAsExpected();
+ }
+
+ void TestWriteWithUnrelatedPartitioningSchema() {
+ DoWrite(std::make_shared<DirectoryPartitioning>(
+ SchemaFromColumnNames(source_schema_, {"country", "region"})));
+
+ // XXX first thing a user will be annoyed by: we don't support left
+ // padding the month field with 0.
+ expected_files_["/new_root/US/NY/dat_0"] = R"([
+ {"year": 2018, "month": 1, "model": "Y", "sales": 27.5},
+ {"year": 2018, "month": 1, "model": "X", "sales": 136.25},
+ {"year": 2018, "month": 1, "model": "S", "sales": 304.125},
+ {"year": 2018, "month": 1, "model": "3", "sales": 742.0}
+ ])";
+ expected_files_["/new_root/CA/QC/dat_0"] = R"([
+ {"year": 2018, "month": 1, "model": "X", "sales": 1.0},
+ {"year": 2019, "month": 1, "model": "S", "sales": 10},
+ {"year": 2019, "month": 1, "model": "Y", "sales": 37},
+ {"year": 2019, "month": 1, "model": "X", "sales": 42},
+ {"year": 2018, "month": 1, "model": "Y", "sales": 69},
+ {"year": 2019, "month": 1, "model": "3", "sales": 152.25},
+ {"year": 2018, "month": 1, "model": "3", "sales": 512},
+ {"year": 2018, "month": 1, "model": "S", "sales": 978}
+ ])";
+ expected_files_["/new_root/US/CA/dat_0"] = R"([
+ {"year": 2019, "month": 1, "model": "S", "sales": 13},
+ {"year": 2019, "month": 1, "model": "Y", "sales": 21},
+ {"year": 2019, "month": 1, "model": "X", "sales": 54},
+ {"year": 2019, "month": 1, "model": "3", "sales": 273.5}
+ ])";
+ expected_physical_schema_ =
+ SchemaFromColumnNames(source_schema_, {"model", "sales", "year", "month"});
+
+ AssertWrittenAsExpected();
+ }
+
+ void TestWriteWithSupersetPartitioningSchema() {
+ DoWrite(std::make_shared<DirectoryPartitioning>(
+ SchemaFromColumnNames(source_schema_, {"year", "month", "country", "region"})));
+
+ // XXX first thing a user will be annoyed by: we don't support left
+ // padding the month field with 0.
+ expected_files_["/new_root/2018/1/US/NY/dat_0"] = R"([
+ {"model": "Y", "sales": 27.5},
+ {"model": "X", "sales": 136.25},
+ {"model": "S", "sales": 304.125},
+ {"model": "3", "sales": 742.0}
+ ])";
+ expected_files_["/new_root/2018/1/CA/QC/dat_0"] = R"([
+ {"model": "X", "sales": 1.0},
+ {"model": "Y", "sales": 69},
+ {"model": "3", "sales": 512},
+ {"model": "S", "sales": 978}
+ ])";
+ expected_files_["/new_root/2019/1/US/CA/dat_0"] = R"([
+ {"model": "S", "sales": 13},
+ {"model": "Y", "sales": 21},
+ {"model": "X", "sales": 54},
+ {"model": "3", "sales": 273.5}
+ ])";
+ expected_files_["/new_root/2019/1/CA/QC/dat_0"] = R"([
+ {"model": "S", "sales": 10},
+ {"model": "Y", "sales": 37},
+ {"model": "X", "sales": 42},
+ {"model": "3", "sales": 152.25}
+ ])";
+ expected_physical_schema_ = SchemaFromColumnNames(source_schema_, {"model", "sales"});
+
+ AssertWrittenAsExpected();
+ }
+
+ void TestWriteWithEmptyPartitioningSchema() {
+ DoWrite(std::make_shared<DirectoryPartitioning>(
+ SchemaFromColumnNames(source_schema_, {})));
+
+ expected_files_["/new_root/dat_0"] = R"([
+ {"country": "CA", "region": "QC", "year": 2018, "month": 1, "model": "X", "sales": 1.0},
+ {"country": "CA", "region": "QC", "year": 2019, "month": 1, "model": "S", "sales": 10},
+ {"country": "US", "region": "CA", "year": 2019, "month": 1, "model": "S", "sales": 13},
+ {"country": "US", "region": "CA", "year": 2019, "month": 1, "model": "Y", "sales": 21},
+ {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "Y", "sales": 27.5},
+ {"country": "CA", "region": "QC", "year": 2019, "month": 1, "model": "Y", "sales": 37},
+ {"country": "CA", "region": "QC", "year": 2019, "month": 1, "model": "X", "sales": 42},
+ {"country": "US", "region": "CA", "year": 2019, "month": 1, "model": "X", "sales": 54},
+ {"country": "CA", "region": "QC", "year": 2018, "month": 1, "model": "Y", "sales": 69},
+ {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "X", "sales": 136.25},
+ {"country": "CA", "region": "QC", "year": 2019, "month": 1, "model": "3", "sales": 152.25},
+ {"country": "US", "region": "CA", "year": 2019, "month": 1, "model": "3", "sales": 273.5},
+ {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "S", "sales": 304.125},
+ {"country": "CA", "region": "QC", "year": 2018, "month": 1, "model": "3", "sales": 512},
+ {"country": "US", "region": "NY", "year": 2018, "month": 1, "model": "3", "sales": 742.0},
+ {"country": "CA", "region": "QC", "year": 2018, "month": 1, "model": "S", "sales": 978}
+ ])";
+ expected_physical_schema_ = source_schema_;
+
+ AssertWrittenAsExpected();
+ }
+
+ void AssertWrittenAsExpected() {
+ std::unordered_set<std::string> expected_paths, actual_paths;
+ for (const auto& file_contents : expected_files_) {
+ expected_paths.insert(file_contents.first);
+ }
+
+ // expect the written filesystem to contain precisely the paths we expected
+ for (auto path : checked_pointer_cast<FileSystemDataset>(written_)->files()) {
+ actual_paths.insert(std::move(path));
+ }
+ EXPECT_THAT(actual_paths, testing::UnorderedElementsAreArray(expected_paths));
+
+ // Additionally, the writer producing each written file was visited and its path
+ // collected. That should match the expected paths as well
+ EXPECT_THAT(visited_paths_, testing::UnorderedElementsAreArray(expected_paths));
+
+ ASSERT_OK_AND_ASSIGN(auto written_fragments_it, written_->GetFragments());
+ for (auto maybe_fragment : written_fragments_it) {
+ ASSERT_OK_AND_ASSIGN(auto fragment, maybe_fragment);
+
+ ASSERT_OK_AND_ASSIGN(auto actual_physical_schema, fragment->ReadPhysicalSchema());
+ AssertSchemaEqual(*expected_physical_schema_, *actual_physical_schema,
+ check_metadata_);
+
+ const auto& path = checked_pointer_cast<FileFragment>(fragment)->source().path();
+
+ auto file_contents = expected_files_.find(path);
+ if (file_contents == expected_files_.end()) {
+ // file wasn't expected to be written at all; nothing to compare with
+ continue;
+ }
+
+ ASSERT_OK_AND_ASSIGN(auto scanner, ScannerBuilder(actual_physical_schema, fragment,
+ std::make_shared<ScanOptions>())
+ .Finish());
+ ASSERT_OK_AND_ASSIGN(auto actual_table, scanner->ToTable());
+ ASSERT_OK_AND_ASSIGN(actual_table, actual_table->CombineChunks());
+ std::shared_ptr<Array> actual_struct;
+
+ for (auto maybe_batch :
+ MakeIteratorFromReader(std::make_shared<TableBatchReader>(*actual_table))) {
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ ASSERT_OK_AND_ASSIGN(
+ auto sort_indices,
+ compute::SortIndices(batch->GetColumnByName("sales"),
+ compute::SortOptions({compute::SortKey{"sales"}})));
+ ASSERT_OK_AND_ASSIGN(Datum sorted_batch, compute::Take(batch, sort_indices));
+ ASSERT_OK_AND_ASSIGN(actual_struct, sorted_batch.record_batch()->ToStructArray());
+ }
+
+ auto expected_struct = ArrayFromJSON(struct_(expected_physical_schema_->fields()),
+ {file_contents->second});
+
+ AssertArraysEqual(*expected_struct, *actual_struct, /*verbose=*/true);
+ }
+ }
+
+ bool check_metadata_ = true;
+ std::shared_ptr<Schema> source_schema_;
+ std::shared_ptr<FileFormat> format_;
+ PathAndContent expected_files_;
+ std::shared_ptr<Schema> expected_physical_schema_;
+ std::shared_ptr<Dataset> written_;
+ std::vector<std::string> visited_paths_;
+ FileSystemDatasetWriteOptions write_options_;
+ std::shared_ptr<ScanOptions> scan_options_;
+};
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/type_fwd.h b/src/arrow/cpp/src/arrow/dataset/type_fwd.h
new file mode 100644
index 000000000..78748a314
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/type_fwd.h
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "arrow/compute/type_fwd.h" // IWYU pragma: export
+#include "arrow/dataset/visibility.h"
+#include "arrow/filesystem/type_fwd.h" // IWYU pragma: export
+#include "arrow/type_fwd.h" // IWYU pragma: export
+
+namespace arrow {
+namespace dataset {
+
+class Dataset;
+class DatasetFactory;
+using DatasetVector = std::vector<std::shared_ptr<Dataset>>;
+
+class UnionDataset;
+class UnionDatasetFactory;
+
+class Fragment;
+using FragmentIterator = Iterator<std::shared_ptr<Fragment>>;
+using FragmentVector = std::vector<std::shared_ptr<Fragment>>;
+
+class FragmentScanOptions;
+
+class FileSource;
+class FileFormat;
+class FileFragment;
+class FileWriter;
+class FileWriteOptions;
+class FileSystemDataset;
+class FileSystemDatasetFactory;
+struct FileSystemDatasetWriteOptions;
+
+/// \brief Controls what happens if files exist in an output directory during a dataset
+/// write
+enum class ExistingDataBehavior : int8_t {
+ /// Deletes all files in a directory the first time that directory is encountered
+ kDeleteMatchingPartitions,
+ /// Ignores existing files, overwriting any that happen to have the same name as an
+ /// output file
+ kOverwriteOrIgnore,
+ /// Returns an error if there are any files or subdirectories in the output directory
+ kError,
+};
+
+class InMemoryDataset;
+
+class CsvFileFormat;
+class CsvFileWriter;
+class CsvFileWriteOptions;
+struct CsvFragmentScanOptions;
+
+class IpcFileFormat;
+class IpcFileWriter;
+class IpcFileWriteOptions;
+class IpcFragmentScanOptions;
+
+class ParquetFileFormat;
+class ParquetFileFragment;
+class ParquetFragmentScanOptions;
+class ParquetFileWriter;
+class ParquetFileWriteOptions;
+
+class Partitioning;
+class PartitioningFactory;
+class PartitioningOrFactory;
+struct KeyValuePartitioningOptions;
+class DirectoryPartitioning;
+class HivePartitioning;
+struct HivePartitioningOptions;
+
+struct ScanOptions;
+
+class Scanner;
+
+class ScannerBuilder;
+
+class ScanTask;
+using ScanTaskVector = std::vector<std::shared_ptr<ScanTask>>;
+using ScanTaskIterator = Iterator<std::shared_ptr<ScanTask>>;
+
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dataset/visibility.h b/src/arrow/cpp/src/arrow/dataset/visibility.h
new file mode 100644
index 000000000..b43a25305
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dataset/visibility.h
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#if defined(_WIN32) || defined(__CYGWIN__)
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4251)
+#else
+#pragma GCC diagnostic ignored "-Wattributes"
+#endif
+
+#ifdef ARROW_DS_STATIC
+#define ARROW_DS_EXPORT
+#elif defined(ARROW_DS_EXPORTING)
+#define ARROW_DS_EXPORT __declspec(dllexport)
+#else
+#define ARROW_DS_EXPORT __declspec(dllimport)
+#endif
+
+#define ARROW_DS_NO_EXPORT
+#else // Not Windows
+#ifndef ARROW_DS_EXPORT
+#define ARROW_DS_EXPORT __attribute__((visibility("default")))
+#endif
+#ifndef ARROW_DS_NO_EXPORT
+#define ARROW_DS_NO_EXPORT __attribute__((visibility("hidden")))
+#endif
+#endif // Non-Windows
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
diff --git a/src/arrow/cpp/src/arrow/datum.cc b/src/arrow/cpp/src/arrow/datum.cc
new file mode 100644
index 000000000..397e91de5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/datum.cc
@@ -0,0 +1,292 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/datum.h"
+
+#include <cstddef>
+#include <memory>
+#include <sstream>
+#include <vector>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/util.h"
+#include "arrow/chunked_array.h"
+#include "arrow/record_batch.h"
+#include "arrow/scalar.h"
+#include "arrow/table.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/memory.h"
+
+namespace arrow {
+
+static bool CollectionEquals(const std::vector<Datum>& left,
+ const std::vector<Datum>& right) {
+ if (left.size() != right.size()) {
+ return false;
+ }
+
+ for (size_t i = 0; i < left.size(); i++) {
+ if (!left[i].Equals(right[i])) {
+ return false;
+ }
+ }
+ return true;
+}
+
+Datum::Datum(const Array& value) : Datum(value.data()) {}
+
+Datum::Datum(const std::shared_ptr<Array>& value)
+ : Datum(value ? value->data() : NULLPTR) {}
+
+Datum::Datum(std::shared_ptr<ChunkedArray> value) : value(std::move(value)) {}
+Datum::Datum(std::shared_ptr<RecordBatch> value) : value(std::move(value)) {}
+Datum::Datum(std::shared_ptr<Table> value) : value(std::move(value)) {}
+Datum::Datum(std::vector<Datum> value) : value(std::move(value)) {}
+
+Datum::Datum(bool value) : value(std::make_shared<BooleanScalar>(value)) {}
+Datum::Datum(int8_t value) : value(std::make_shared<Int8Scalar>(value)) {}
+Datum::Datum(uint8_t value) : value(std::make_shared<UInt8Scalar>(value)) {}
+Datum::Datum(int16_t value) : value(std::make_shared<Int16Scalar>(value)) {}
+Datum::Datum(uint16_t value) : value(std::make_shared<UInt16Scalar>(value)) {}
+Datum::Datum(int32_t value) : value(std::make_shared<Int32Scalar>(value)) {}
+Datum::Datum(uint32_t value) : value(std::make_shared<UInt32Scalar>(value)) {}
+Datum::Datum(int64_t value) : value(std::make_shared<Int64Scalar>(value)) {}
+Datum::Datum(uint64_t value) : value(std::make_shared<UInt64Scalar>(value)) {}
+Datum::Datum(float value) : value(std::make_shared<FloatScalar>(value)) {}
+Datum::Datum(double value) : value(std::make_shared<DoubleScalar>(value)) {}
+Datum::Datum(std::string value)
+ : value(std::make_shared<StringScalar>(std::move(value))) {}
+Datum::Datum(const char* value) : value(std::make_shared<StringScalar>(value)) {}
+
+Datum::Datum(const ChunkedArray& value)
+ : value(std::make_shared<ChunkedArray>(value.chunks(), value.type())) {}
+
+Datum::Datum(const Table& value)
+ : value(Table::Make(value.schema(), value.columns(), value.num_rows())) {}
+
+Datum::Datum(const RecordBatch& value)
+ : value(RecordBatch::Make(value.schema(), value.num_rows(), value.columns())) {}
+
+std::shared_ptr<Array> Datum::make_array() const {
+ DCHECK_EQ(Datum::ARRAY, this->kind());
+ return MakeArray(util::get<std::shared_ptr<ArrayData>>(this->value));
+}
+
+const std::shared_ptr<DataType>& Datum::type() const {
+ if (this->kind() == Datum::ARRAY) {
+ return util::get<std::shared_ptr<ArrayData>>(this->value)->type;
+ }
+ if (this->kind() == Datum::CHUNKED_ARRAY) {
+ return util::get<std::shared_ptr<ChunkedArray>>(this->value)->type();
+ }
+ if (this->kind() == Datum::SCALAR) {
+ return util::get<std::shared_ptr<Scalar>>(this->value)->type;
+ }
+ static std::shared_ptr<DataType> no_type;
+ return no_type;
+}
+
+const std::shared_ptr<Schema>& Datum::schema() const {
+ if (this->kind() == Datum::RECORD_BATCH) {
+ return util::get<std::shared_ptr<RecordBatch>>(this->value)->schema();
+ }
+ if (this->kind() == Datum::TABLE) {
+ return util::get<std::shared_ptr<Table>>(this->value)->schema();
+ }
+ static std::shared_ptr<Schema> no_schema;
+ return no_schema;
+}
+
+int64_t Datum::length() const {
+ switch (this->kind()) {
+ case Datum::ARRAY:
+ return util::get<std::shared_ptr<ArrayData>>(this->value)->length;
+ case Datum::CHUNKED_ARRAY:
+ return util::get<std::shared_ptr<ChunkedArray>>(this->value)->length();
+ case Datum::RECORD_BATCH:
+ return util::get<std::shared_ptr<RecordBatch>>(this->value)->num_rows();
+ case Datum::TABLE:
+ return util::get<std::shared_ptr<Table>>(this->value)->num_rows();
+ case Datum::SCALAR:
+ return 1;
+ default:
+ return kUnknownLength;
+ }
+}
+
+int64_t Datum::null_count() const {
+ if (this->kind() == Datum::ARRAY) {
+ return util::get<std::shared_ptr<ArrayData>>(this->value)->GetNullCount();
+ } else if (this->kind() == Datum::CHUNKED_ARRAY) {
+ return util::get<std::shared_ptr<ChunkedArray>>(this->value)->null_count();
+ } else if (this->kind() == Datum::SCALAR) {
+ const auto& val = *util::get<std::shared_ptr<Scalar>>(this->value);
+ return val.is_valid ? 0 : 1;
+ } else {
+ DCHECK(false) << "This function only valid for array-like values";
+ return 0;
+ }
+}
+
+ArrayVector Datum::chunks() const {
+ if (!this->is_arraylike()) {
+ return {};
+ }
+ if (this->is_array()) {
+ return {this->make_array()};
+ }
+ return this->chunked_array()->chunks();
+}
+
+bool Datum::Equals(const Datum& other) const {
+ if (this->kind() != other.kind()) return false;
+
+ switch (this->kind()) {
+ case Datum::NONE:
+ return true;
+ case Datum::SCALAR:
+ return internal::SharedPtrEquals(this->scalar(), other.scalar());
+ case Datum::ARRAY:
+ return internal::SharedPtrEquals(this->make_array(), other.make_array());
+ case Datum::CHUNKED_ARRAY:
+ return internal::SharedPtrEquals(this->chunked_array(), other.chunked_array());
+ case Datum::RECORD_BATCH:
+ return internal::SharedPtrEquals(this->record_batch(), other.record_batch());
+ case Datum::TABLE:
+ return internal::SharedPtrEquals(this->table(), other.table());
+ case Datum::COLLECTION:
+ return CollectionEquals(this->collection(), other.collection());
+ default:
+ return false;
+ }
+}
+
+ValueDescr Datum::descr() const {
+ if (this->is_arraylike()) {
+ return ValueDescr(this->type(), ValueDescr::ARRAY);
+ } else if (this->is_scalar()) {
+ return ValueDescr(this->type(), ValueDescr::SCALAR);
+ } else {
+ DCHECK(false) << "Datum is not value-like, this method should not be called";
+ return ValueDescr();
+ }
+}
+
+ValueDescr::Shape Datum::shape() const {
+ if (this->is_arraylike()) {
+ return ValueDescr::ARRAY;
+ } else if (this->is_scalar()) {
+ return ValueDescr::SCALAR;
+ } else {
+ DCHECK(false) << "Datum is not value-like, this method should not be called";
+ return ValueDescr::ANY;
+ }
+}
+
+static std::string FormatValueDescr(const ValueDescr& descr) {
+ std::stringstream ss;
+ switch (descr.shape) {
+ case ValueDescr::ANY:
+ ss << "any";
+ break;
+ case ValueDescr::ARRAY:
+ ss << "array";
+ break;
+ case ValueDescr::SCALAR:
+ ss << "scalar";
+ break;
+ default:
+ DCHECK(false);
+ break;
+ }
+ ss << "[" << descr.type->ToString() << "]";
+ return ss.str();
+}
+
+std::string ValueDescr::ToString() const { return FormatValueDescr(*this); }
+
+std::string ValueDescr::ToString(const std::vector<ValueDescr>& descrs) {
+ std::stringstream ss;
+ ss << "(";
+ for (size_t i = 0; i < descrs.size(); ++i) {
+ if (i > 0) {
+ ss << ", ";
+ }
+ ss << descrs[i].ToString();
+ }
+ ss << ")";
+ return ss.str();
+}
+
+void PrintTo(const ValueDescr& descr, std::ostream* os) { *os << descr.ToString(); }
+
+std::string Datum::ToString() const {
+ switch (this->kind()) {
+ case Datum::NONE:
+ return "nullptr";
+ case Datum::SCALAR:
+ return "Scalar";
+ case Datum::ARRAY:
+ return "Array";
+ case Datum::CHUNKED_ARRAY:
+ return "ChunkedArray";
+ case Datum::RECORD_BATCH:
+ return "RecordBatch";
+ case Datum::TABLE:
+ return "Table";
+ case Datum::COLLECTION: {
+ std::stringstream ss;
+ ss << "Collection(";
+ const auto& values = this->collection();
+ for (size_t i = 0; i < values.size(); ++i) {
+ if (i > 0) {
+ ss << ", ";
+ }
+ ss << values[i].ToString();
+ }
+ ss << ')';
+ return ss.str();
+ }
+ default:
+ DCHECK(false);
+ return "";
+ }
+}
+
+ValueDescr::Shape GetBroadcastShape(const std::vector<ValueDescr>& args) {
+ for (const auto& descr : args) {
+ if (descr.shape == ValueDescr::ARRAY) {
+ return ValueDescr::ARRAY;
+ }
+ }
+ return ValueDescr::SCALAR;
+}
+
+void PrintTo(const Datum& datum, std::ostream* os) {
+ switch (datum.kind()) {
+ case Datum::SCALAR:
+ *os << datum.scalar()->ToString();
+ break;
+ case Datum::ARRAY:
+ *os << datum.make_array()->ToString();
+ break;
+ default:
+ *os << datum.ToString();
+ }
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/datum.h b/src/arrow/cpp/src/arrow/datum.h
new file mode 100644
index 000000000..da851d917
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/datum.h
@@ -0,0 +1,281 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/data.h"
+#include "arrow/scalar.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/variant.h" // IWYU pragma: export
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Array;
+class ChunkedArray;
+class RecordBatch;
+class Table;
+
+/// \brief A descriptor type that gives the shape (array or scalar) and
+/// DataType of a Value, but without the data
+struct ARROW_EXPORT ValueDescr {
+ std::shared_ptr<DataType> type;
+ enum Shape {
+ /// \brief Either Array or Scalar
+ ANY,
+
+ /// \brief Array type
+ ARRAY,
+
+ /// \brief Only Scalar arguments supported
+ SCALAR
+ };
+
+ Shape shape;
+
+ ValueDescr() : shape(ANY) {}
+
+ ValueDescr(std::shared_ptr<DataType> type, ValueDescr::Shape shape)
+ : type(std::move(type)), shape(shape) {}
+
+ ValueDescr(std::shared_ptr<DataType> type) // NOLINT implicit conversion
+ : type(std::move(type)), shape(ValueDescr::ANY) {}
+
+ /// \brief Convenience constructor for ANY descr
+ static ValueDescr Any(std::shared_ptr<DataType> type) {
+ return ValueDescr(std::move(type), ANY);
+ }
+
+ /// \brief Convenience constructor for Value::ARRAY descr
+ static ValueDescr Array(std::shared_ptr<DataType> type) {
+ return ValueDescr(std::move(type), ARRAY);
+ }
+
+ /// \brief Convenience constructor for Value::SCALAR descr
+ static ValueDescr Scalar(std::shared_ptr<DataType> type) {
+ return ValueDescr(std::move(type), SCALAR);
+ }
+
+ bool operator==(const ValueDescr& other) const {
+ if (shape != other.shape) return false;
+ if (type == other.type) return true;
+ return type && type->Equals(other.type);
+ }
+
+ bool operator!=(const ValueDescr& other) const { return !(*this == other); }
+
+ std::string ToString() const;
+ static std::string ToString(const std::vector<ValueDescr>&);
+
+ ARROW_EXPORT friend void PrintTo(const ValueDescr&, std::ostream*);
+};
+
+/// \brief For use with scalar functions, returns the broadcasted Value::Shape
+/// given a vector of value descriptors. Return SCALAR unless any value is
+/// ARRAY
+ARROW_EXPORT
+ValueDescr::Shape GetBroadcastShape(const std::vector<ValueDescr>& args);
+
+/// \class Datum
+/// \brief Variant type for various Arrow C++ data structures
+struct ARROW_EXPORT Datum {
+ enum Kind { NONE, SCALAR, ARRAY, CHUNKED_ARRAY, RECORD_BATCH, TABLE, COLLECTION };
+
+ struct Empty {};
+
+ // Datums variants may have a length. This special value indicate that the
+ // current variant does not have a length.
+ static constexpr int64_t kUnknownLength = -1;
+
+ util::Variant<Empty, std::shared_ptr<Scalar>, std::shared_ptr<ArrayData>,
+ std::shared_ptr<ChunkedArray>, std::shared_ptr<RecordBatch>,
+ std::shared_ptr<Table>, std::vector<Datum>>
+ value;
+
+ /// \brief Empty datum, to be populated elsewhere
+ Datum() = default;
+
+ Datum(const Datum& other) = default;
+ Datum& operator=(const Datum& other) = default;
+ Datum(Datum&& other) = default;
+ Datum& operator=(Datum&& other) = default;
+
+ Datum(std::shared_ptr<Scalar> value) // NOLINT implicit conversion
+ : value(std::move(value)) {}
+
+ Datum(std::shared_ptr<ArrayData> value) // NOLINT implicit conversion
+ : value(std::move(value)) {}
+
+ Datum(ArrayData arg) // NOLINT implicit conversion
+ : value(std::make_shared<ArrayData>(std::move(arg))) {}
+
+ Datum(const Array& value); // NOLINT implicit conversion
+ Datum(const std::shared_ptr<Array>& value); // NOLINT implicit conversion
+ Datum(std::shared_ptr<ChunkedArray> value); // NOLINT implicit conversion
+ Datum(std::shared_ptr<RecordBatch> value); // NOLINT implicit conversion
+ Datum(std::shared_ptr<Table> value); // NOLINT implicit conversion
+ Datum(std::vector<Datum> value); // NOLINT implicit conversion
+
+ // Explicit constructors from const-refs. Can be expensive, prefer the
+ // shared_ptr constructors
+ explicit Datum(const ChunkedArray& value);
+ explicit Datum(const RecordBatch& value);
+ explicit Datum(const Table& value);
+
+ // Cast from subtypes of Array to Datum
+ template <typename T, typename = enable_if_t<std::is_base_of<Array, T>::value>>
+ Datum(const std::shared_ptr<T>& value) // NOLINT implicit conversion
+ : Datum(std::shared_ptr<Array>(value)) {}
+
+ // Convenience constructors
+ explicit Datum(bool value);
+ explicit Datum(int8_t value);
+ explicit Datum(uint8_t value);
+ explicit Datum(int16_t value);
+ explicit Datum(uint16_t value);
+ explicit Datum(int32_t value);
+ explicit Datum(uint32_t value);
+ explicit Datum(int64_t value);
+ explicit Datum(uint64_t value);
+ explicit Datum(float value);
+ explicit Datum(double value);
+ explicit Datum(std::string value);
+ explicit Datum(const char* value);
+
+ Datum::Kind kind() const {
+ switch (this->value.index()) {
+ case 0:
+ return Datum::NONE;
+ case 1:
+ return Datum::SCALAR;
+ case 2:
+ return Datum::ARRAY;
+ case 3:
+ return Datum::CHUNKED_ARRAY;
+ case 4:
+ return Datum::RECORD_BATCH;
+ case 5:
+ return Datum::TABLE;
+ case 6:
+ return Datum::COLLECTION;
+ default:
+ return Datum::NONE;
+ }
+ }
+
+ const std::shared_ptr<ArrayData>& array() const {
+ return util::get<std::shared_ptr<ArrayData>>(this->value);
+ }
+
+ ArrayData* mutable_array() const { return this->array().get(); }
+
+ std::shared_ptr<Array> make_array() const;
+
+ const std::shared_ptr<ChunkedArray>& chunked_array() const {
+ return util::get<std::shared_ptr<ChunkedArray>>(this->value);
+ }
+
+ const std::shared_ptr<RecordBatch>& record_batch() const {
+ return util::get<std::shared_ptr<RecordBatch>>(this->value);
+ }
+
+ const std::shared_ptr<Table>& table() const {
+ return util::get<std::shared_ptr<Table>>(this->value);
+ }
+
+ const std::vector<Datum>& collection() const {
+ return util::get<std::vector<Datum>>(this->value);
+ }
+
+ const std::shared_ptr<Scalar>& scalar() const {
+ return util::get<std::shared_ptr<Scalar>>(this->value);
+ }
+
+ template <typename ExactType>
+ std::shared_ptr<ExactType> array_as() const {
+ return internal::checked_pointer_cast<ExactType>(this->make_array());
+ }
+
+ template <typename ExactType>
+ const ExactType& scalar_as() const {
+ return internal::checked_cast<const ExactType&>(*this->scalar());
+ }
+
+ bool is_array() const { return this->kind() == Datum::ARRAY; }
+
+ bool is_arraylike() const {
+ return this->kind() == Datum::ARRAY || this->kind() == Datum::CHUNKED_ARRAY;
+ }
+
+ bool is_scalar() const { return this->kind() == Datum::SCALAR; }
+
+ /// \brief True if Datum contains a scalar or array-like data
+ bool is_value() const { return this->is_arraylike() || this->is_scalar(); }
+
+ bool is_collection() const { return this->kind() == Datum::COLLECTION; }
+
+ int64_t null_count() const;
+
+ /// \brief Return the shape (array or scalar) and type for supported kinds
+ /// (ARRAY, CHUNKED_ARRAY, and SCALAR). Debug asserts otherwise
+ ValueDescr descr() const;
+
+ /// \brief Return the shape (array or scalar) for supported kinds (ARRAY,
+ /// CHUNKED_ARRAY, and SCALAR). Debug asserts otherwise
+ ValueDescr::Shape shape() const;
+
+ /// \brief The value type of the variant, if any
+ ///
+ /// \return nullptr if no type
+ const std::shared_ptr<DataType>& type() const;
+
+ /// \brief The schema of the variant, if any
+ ///
+ /// \return nullptr if no schema
+ const std::shared_ptr<Schema>& schema() const;
+
+ /// \brief The value length of the variant, if any
+ ///
+ /// \return kUnknownLength if no type
+ int64_t length() const;
+
+ /// \brief The array chunks of the variant, if any
+ ///
+ /// \return empty if not arraylike
+ ArrayVector chunks() const;
+
+ bool Equals(const Datum& other) const;
+
+ bool operator==(const Datum& other) const { return Equals(other); }
+ bool operator!=(const Datum& other) const { return !Equals(other); }
+
+ std::string ToString() const;
+
+ ARROW_EXPORT friend void PrintTo(const Datum&, std::ostream*);
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/datum_test.cc b/src/arrow/cpp/src/arrow/datum_test.cc
new file mode 100644
index 000000000..cf65d515d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/datum_test.cc
@@ -0,0 +1,172 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <string>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array/array_base.h"
+#include "arrow/chunked_array.h"
+#include "arrow/datum.h"
+#include "arrow/scalar.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+
+class BinaryArray;
+class RecordBatch;
+
+using internal::checked_cast;
+
+// ----------------------------------------------------------------------
+// Datum
+
+template <typename T>
+void CheckImplicitConstructor(Datum::Kind expected_kind) {
+ std::shared_ptr<T> value;
+ Datum datum = value;
+ ASSERT_EQ(expected_kind, datum.kind());
+}
+
+TEST(Datum, ImplicitConstructors) {
+ CheckImplicitConstructor<Scalar>(Datum::SCALAR);
+
+ CheckImplicitConstructor<Array>(Datum::ARRAY);
+
+ // Instantiate from array subclass
+ CheckImplicitConstructor<BinaryArray>(Datum::ARRAY);
+
+ CheckImplicitConstructor<ChunkedArray>(Datum::CHUNKED_ARRAY);
+ CheckImplicitConstructor<RecordBatch>(Datum::RECORD_BATCH);
+
+ CheckImplicitConstructor<Table>(Datum::TABLE);
+}
+
+TEST(Datum, Constructors) {
+ Datum val(std::make_shared<Int64Scalar>(1));
+ ASSERT_EQ(ValueDescr::SCALAR, val.shape());
+ AssertTypeEqual(*int64(), *val.type());
+ ASSERT_TRUE(val.is_scalar());
+ ASSERT_FALSE(val.is_array());
+ ASSERT_EQ(1, val.length());
+
+ const Int64Scalar& val_as_i64 = checked_cast<const Int64Scalar&>(*val.scalar());
+ const Int64Scalar& val_as_i64_2 = val.scalar_as<Int64Scalar>();
+ ASSERT_EQ(1, val_as_i64.value);
+ ASSERT_EQ(1, val_as_i64_2.value);
+
+ auto arr = ArrayFromJSON(int64(), "[1, 2, 3, 4]");
+ auto sel_indices = ArrayFromJSON(int32(), "[0, 3]");
+
+ Datum val2(arr);
+ ASSERT_EQ(Datum::ARRAY, val2.kind());
+ ASSERT_EQ(ValueDescr::ARRAY, val2.shape());
+ AssertTypeEqual(*int64(), *val2.type());
+ AssertArraysEqual(*arr, *val2.make_array());
+ ASSERT_TRUE(val2.is_array());
+ ASSERT_FALSE(val2.is_scalar());
+ ASSERT_EQ(arr->length(), val2.length());
+
+ auto Check = [&](const Datum& v) { AssertArraysEqual(*arr, *v.make_array()); };
+
+ // Copy constructor
+ Datum val3 = val2;
+ Check(val3);
+
+ // Copy assignment
+ Datum val4;
+ val4 = val2;
+ Check(val4);
+
+ // Move constructor
+ Datum val5 = std::move(val2);
+ Check(val5);
+
+ // Move assignment
+ Datum val6;
+ val6 = std::move(val4);
+ Check(val6);
+}
+
+TEST(Datum, NullCount) {
+ Datum val1(std::make_shared<Int8Scalar>(1));
+ ASSERT_EQ(0, val1.null_count());
+
+ Datum val2(MakeNullScalar(int8()));
+ ASSERT_EQ(1, val2.null_count());
+
+ Datum val3(ArrayFromJSON(int8(), "[1, null, null, null]"));
+ ASSERT_EQ(3, val3.null_count());
+}
+
+TEST(Datum, MutableArray) {
+ auto arr = ArrayFromJSON(int8(), "[1, 2, 3, 4]");
+
+ Datum val(arr);
+
+ val.mutable_array()->length = 0;
+ ASSERT_EQ(0, val.array()->length);
+}
+
+TEST(Datum, ToString) {
+ auto arr = ArrayFromJSON(int8(), "[1, 2, 3, 4]");
+
+ Datum v1(arr);
+ Datum v2(std::make_shared<Int8Scalar>(1));
+
+ std::vector<Datum> vec1 = {v1};
+ Datum v3(vec1);
+
+ std::vector<Datum> vec2 = {v1, v2};
+ Datum v4(vec2);
+
+ ASSERT_EQ("Array", v1.ToString());
+ ASSERT_EQ("Scalar", v2.ToString());
+ ASSERT_EQ("Collection(Array)", v3.ToString());
+ ASSERT_EQ("Collection(Array, Scalar)", v4.ToString());
+}
+
+TEST(ValueDescr, Basics) {
+ ValueDescr d1(utf8(), ValueDescr::SCALAR);
+ ValueDescr d2 = ValueDescr::Any(utf8());
+ ValueDescr d3 = ValueDescr::Scalar(utf8());
+ ValueDescr d4 = ValueDescr::Array(utf8());
+
+ ASSERT_EQ(ValueDescr::SCALAR, d1.shape);
+ AssertTypeEqual(*utf8(), *d1.type);
+ ASSERT_EQ(ValueDescr::Scalar(utf8()), d1);
+
+ ASSERT_EQ(ValueDescr::ANY, d2.shape);
+ AssertTypeEqual(*utf8(), *d2.type);
+ ASSERT_EQ(ValueDescr::Any(utf8()), d2);
+ ASSERT_NE(ValueDescr::Any(int32()), d2);
+
+ ASSERT_EQ(ValueDescr::SCALAR, d3.shape);
+ ASSERT_EQ(ValueDescr::ARRAY, d4.shape);
+
+ ASSERT_EQ("scalar[string]", d1.ToString());
+ ASSERT_EQ("any[string]", d2.ToString());
+ ASSERT_EQ("scalar[string]", d3.ToString());
+ ASSERT_EQ("array[string]", d4.ToString());
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dbi/README.md b/src/arrow/cpp/src/arrow/dbi/README.md
new file mode 100644
index 000000000..d73666c37
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/README.md
@@ -0,0 +1,24 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# Arrow Database Interfaces
+
+## HiveServer2
+
+For Apache Hive and Apache Impala. See `hiveserver2/` directory
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/CMakeLists.txt b/src/arrow/cpp/src/arrow/dbi/hiveserver2/CMakeLists.txt
new file mode 100644
index 000000000..2638456c6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/CMakeLists.txt
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+add_custom_target(arrow_hiveserver2)
+add_custom_target(arrow_hiveserver2-tests)
+
+# Headers: top level
+arrow_install_all_headers("arrow/dbi/hiveserver2")
+
+set(ARROW_HIVESERVER2_SRCS
+ columnar_row_set.cc
+ service.cc
+ session.cc
+ operation.cc
+ sample_usage.cc
+ thrift_internal.cc
+ types.cc
+ util.cc)
+
+add_subdirectory(thrift)
+
+set(HIVESERVER2_THRIFT_SRC
+ ErrorCodes_constants.cpp
+ ErrorCodes_types.cpp
+ ImpalaService.cpp
+ ImpalaService_constants.cpp
+ ImpalaService_types.cpp
+ ImpalaHiveServer2Service.cpp
+ beeswax_constants.cpp
+ beeswax_types.cpp
+ BeeswaxService.cpp
+ TCLIService.cpp
+ TCLIService_constants.cpp
+ TCLIService_types.cpp
+ ExecStats_constants.cpp
+ ExecStats_types.cpp
+ hive_metastore_constants.cpp
+ hive_metastore_types.cpp
+ Status_constants.cpp
+ Status_types.cpp
+ Types_constants.cpp
+ Types_types.cpp)
+
+set_source_files_properties(${HIVESERVER2_THRIFT_SRC}
+ PROPERTIES COMPILE_FLAGS
+ "-Wno-unused-variable -Wno-shadow-field" GENERATED
+ TRUE)
+
+# keep everything in one library, the object files reference
+# each other
+add_library(arrow_hiveserver2_thrift STATIC ${HIVESERVER2_THRIFT_SRC})
+
+# Setting these files as code-generated lets make clean and incremental builds work
+# correctly
+
+# TODO(wesm): Something is broken with the dependency chain with
+# ImpalaService.cpp and others. Couldn't figure out what is different between
+# this setup and Impala.
+
+add_dependencies(arrow_hiveserver2_thrift hs2-thrift-cpp)
+
+set_target_properties(arrow_hiveserver2_thrift
+ PROPERTIES LIBRARY_OUTPUT_DIRECTORY
+ "${BUILD_OUTPUT_ROOT_DIRECTORY}")
+
+add_arrow_lib(arrow_hiveserver2
+ SOURCES
+ ${ARROW_HIVESERVER2_SRCS}
+ OUTPUTS
+ ARROW_HIVESERVER2_LIBRARIES
+ DEPENDENCIES
+ arrow_hiveserver2_thrift
+ SHARED_LINK_FLAGS
+ ""
+ SHARED_LINK_LIBS
+ ${ARROW_PYTHON_SHARED_LINK_LIBS})
+
+add_dependencies(arrow_hiveserver2 ${ARROW_HIVESERVER2_LIBRARIES})
+
+foreach(LIB_TARGET ${ARROW_HIVESERVER2_LIBRARIES})
+ target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_EXPORTING)
+endforeach()
+
+set_property(SOURCE ${ARROW_HIVESERVER2_SRCS}
+ APPEND_STRING
+ PROPERTY COMPILE_FLAGS " -Wno-shadow-field")
+
+set(ARROW_HIVESERVER2_TEST_LINK_LIBS arrow_hiveserver2_static arrow_hiveserver2_thrift
+ ${ARROW_TEST_LINK_LIBS} thrift::thrift)
+
+if(ARROW_BUILD_TESTS)
+ add_test_case(hiveserver2_test
+ STATIC_LINK_LIBS
+ "${ARROW_HIVESERVER2_TEST_LINK_LIBS}"
+ LABELS
+ "arrow_hiveserver2-tests")
+ if(TARGET arrow-hiveserver2-test)
+ set_property(TARGET arrow-hiveserver2-test
+ APPEND_STRING
+ PROPERTY COMPILE_FLAGS " -Wno-shadow-field")
+ endif()
+endif(ARROW_BUILD_TESTS)
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/api.h b/src/arrow/cpp/src/arrow/dbi/hiveserver2/api.h
new file mode 100644
index 000000000..da860e6d8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/api.h
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/dbi/hiveserver2/columnar_row_set.h"
+#include "arrow/dbi/hiveserver2/operation.h"
+#include "arrow/dbi/hiveserver2/service.h"
+#include "arrow/dbi/hiveserver2/session.h"
+#include "arrow/dbi/hiveserver2/types.h"
+#include "arrow/dbi/hiveserver2/util.h"
+
+#include "arrow/status.h"
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/columnar_row_set.cc b/src/arrow/cpp/src/arrow/dbi/hiveserver2/columnar_row_set.cc
new file mode 100644
index 000000000..bef894014
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/columnar_row_set.cc
@@ -0,0 +1,100 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dbi/hiveserver2/columnar_row_set.h"
+
+#include <string>
+#include <vector>
+
+#include "arrow/dbi/hiveserver2/TCLIService.h"
+#include "arrow/dbi/hiveserver2/thrift_internal.h"
+
+#include "arrow/util/logging.h"
+
+namespace hs2 = apache::hive::service::cli::thrift;
+
+namespace arrow {
+namespace hiveserver2 {
+
+Column::Column(const std::string* nulls) {
+ DCHECK(nulls);
+ nulls_ = reinterpret_cast<const uint8_t*>(nulls->c_str());
+ nulls_size_ = static_cast<int64_t>(nulls->size());
+}
+
+ColumnarRowSet::ColumnarRowSet(ColumnarRowSetImpl* impl) : impl_(impl) {}
+
+ColumnarRowSet::~ColumnarRowSet() = default;
+
+template <typename T>
+struct type_helpers {};
+
+#define VALUE_GETTER(COLUMN_TYPE, VALUE_TYPE, ATTR_NAME) \
+ template <> \
+ struct type_helpers<COLUMN_TYPE> { \
+ static const std::vector<VALUE_TYPE>* GetValues(const hs2::TColumn& col) { \
+ return &col.ATTR_NAME.values; \
+ } \
+ \
+ static const std::string* GetNulls(const hs2::TColumn& col) { \
+ return &col.ATTR_NAME.nulls; \
+ } \
+ };
+
+VALUE_GETTER(BoolColumn, bool, boolVal);
+VALUE_GETTER(ByteColumn, int8_t, byteVal);
+VALUE_GETTER(Int16Column, int16_t, i16Val);
+VALUE_GETTER(Int32Column, int32_t, i32Val);
+VALUE_GETTER(Int64Column, int64_t, i64Val);
+VALUE_GETTER(DoubleColumn, double, doubleVal);
+VALUE_GETTER(StringColumn, std::string, stringVal);
+
+#undef VALUE_GETTER
+
+template <typename T>
+std::unique_ptr<T> ColumnarRowSet::GetCol(int i) const {
+ using helper = type_helpers<T>;
+
+ DCHECK_LT(i, static_cast<int>(impl_->resp.results.columns.size()));
+
+ const hs2::TColumn& col = impl_->resp.results.columns[i];
+ return std::unique_ptr<T>(new T(helper::GetNulls(col), helper::GetValues(col)));
+}
+
+#define TYPED_GETTER(FUNC_NAME, TYPE) \
+ std::unique_ptr<TYPE> ColumnarRowSet::FUNC_NAME(int i) const { \
+ return GetCol<TYPE>(i); \
+ } \
+ template std::unique_ptr<TYPE> ColumnarRowSet::GetCol<TYPE>(int i) const;
+
+TYPED_GETTER(GetBoolCol, BoolColumn);
+TYPED_GETTER(GetByteCol, ByteColumn);
+TYPED_GETTER(GetInt16Col, Int16Column);
+TYPED_GETTER(GetInt32Col, Int32Column);
+TYPED_GETTER(GetInt64Col, Int64Column);
+TYPED_GETTER(GetDoubleCol, DoubleColumn);
+TYPED_GETTER(GetStringCol, StringColumn);
+
+#undef TYPED_GETTER
+
+// BinaryColumn is an alias for StringColumn
+std::unique_ptr<BinaryColumn> ColumnarRowSet::GetBinaryCol(int i) const {
+ return GetCol<BinaryColumn>(i);
+}
+
+} // namespace hiveserver2
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/columnar_row_set.h b/src/arrow/cpp/src/arrow/dbi/hiveserver2/columnar_row_set.h
new file mode 100644
index 000000000..a62c73802
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/columnar_row_set.h
@@ -0,0 +1,155 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace hiveserver2 {
+
+// The Column class is used to access data that was fetched in columnar format.
+// The contents of the data can be accessed through the data() fn, which returns
+// a ptr to a vector containing the contents of this column in the fetched
+// results, avoiding copies. This vector will be of size length().
+//
+// If any of the values are null, they will be represented in the data vector as
+// default values, i.e. 0 for numeric types. The nulls() fn returns a ptr to a
+// bit array representing which values are null, and the IsNull() fn is provided
+// for convenience when working with this bit array. The user should check
+// IsNull() to distinguish between actual instances of the default values and nulls.
+//
+// A Column object is returned from a ColumnarRowSet and is only valid as long
+// as that ColumnarRowSet still exists.
+//
+// Example:
+// unique_ptr<Int32Column> col = columnar_row_set->GetInt32Col();
+// for (int i = 0; i < col->length(); i++) {
+// if (col->IsNull(i)) {
+// cout << "NULL\n";
+// } else {
+// cout << col->data()[i] << "\n";
+// }
+// }
+class ARROW_EXPORT Column {
+ public:
+ virtual ~Column() {}
+
+ virtual int64_t length() const = 0;
+
+ const uint8_t* nulls() const { return nulls_; }
+ int64_t nulls_size() const { return nulls_size_; }
+
+ // Returns true iff the value for the i-th row within this set of data for this
+ // column is null.
+ bool IsNull(int64_t i) const { return (nulls_[i / 8] & (1 << (i % 8))) != 0; }
+
+ protected:
+ explicit Column(const std::string* nulls);
+
+ // The memory for these ptrs is owned by the ColumnarRowSet that
+ // created this Column.
+ //
+ // Due to the issue described in HUE-2722, the null bitmap may have fewer
+ // bytes than expected for some versions of Hive, so we retain the ability to
+ // check the buffer size in case this happens.
+ const uint8_t* nulls_;
+ int64_t nulls_size_;
+};
+
+template <class T>
+class ARROW_EXPORT TypedColumn : public Column {
+ public:
+ const std::vector<T>& data() const { return *data_; }
+ int64_t length() const { return data().size(); }
+
+ // Returns the value for the i-th row within this set of data for this column.
+ const T& GetData(int64_t i) const { return data()[i]; }
+
+ private:
+ // For access to the c'tor.
+ friend class ColumnarRowSet;
+
+ TypedColumn(const std::string* nulls, const std::vector<T>* data)
+ : Column(nulls), data_(data) {}
+
+ const std::vector<T>* data_;
+};
+
+typedef TypedColumn<bool> BoolColumn;
+typedef TypedColumn<int8_t> ByteColumn;
+typedef TypedColumn<int16_t> Int16Column;
+typedef TypedColumn<int32_t> Int32Column;
+typedef TypedColumn<int64_t> Int64Column;
+typedef TypedColumn<double> DoubleColumn;
+typedef TypedColumn<std::string> StringColumn;
+typedef TypedColumn<std::string> BinaryColumn;
+
+// A ColumnarRowSet represents the full results returned by a call to
+// Operation::Fetch() when a columnar format is being used.
+//
+// ColumnarRowSet provides access to specific columns by their type and index in
+// the results. All Column objects returned from a given ColumnarRowSet will have
+// the same length(). A Column object returned by a ColumnarRowSet is only valid
+// as long as the ColumnarRowSet still exists.
+//
+// Example:
+// unique_ptr<Operation> op;
+// session->ExecuteStatement("select int_col, string_col from tbl", &op);
+// unique_ptr<ColumnarRowSet> columnar_row_set;
+// if (op->Fetch(&columnar_row_set).ok()) {
+// unique_ptr<Int32Column> int32_col = columnar_row_set->GetInt32Col(0);
+// unique_ptr<StringColumn> string_col = columnar_row_set->GetStringCol(1);
+// }
+class ARROW_EXPORT ColumnarRowSet {
+ public:
+ ~ColumnarRowSet();
+
+ std::unique_ptr<BoolColumn> GetBoolCol(int i) const;
+ std::unique_ptr<ByteColumn> GetByteCol(int i) const;
+ std::unique_ptr<Int16Column> GetInt16Col(int i) const;
+ std::unique_ptr<Int32Column> GetInt32Col(int i) const;
+ std::unique_ptr<Int64Column> GetInt64Col(int i) const;
+ std::unique_ptr<DoubleColumn> GetDoubleCol(int i) const;
+ std::unique_ptr<StringColumn> GetStringCol(int i) const;
+ std::unique_ptr<BinaryColumn> GetBinaryCol(int i) const;
+
+ template <typename T>
+ std::unique_ptr<T> GetCol(int i) const;
+
+ private:
+ // Hides Thrift objects from the header.
+ struct ColumnarRowSetImpl;
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ColumnarRowSet);
+
+ // For access to the c'tor.
+ friend class Operation;
+
+ explicit ColumnarRowSet(ColumnarRowSetImpl* impl);
+
+ std::unique_ptr<ColumnarRowSetImpl> impl_;
+};
+
+} // namespace hiveserver2
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/hiveserver2_test.cc b/src/arrow/cpp/src/arrow/dbi/hiveserver2/hiveserver2_test.cc
new file mode 100644
index 000000000..c19ec1c80
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/hiveserver2_test.cc
@@ -0,0 +1,458 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dbi/hiveserver2/operation.h"
+
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/dbi/hiveserver2/service.h"
+#include "arrow/dbi/hiveserver2/session.h"
+#include "arrow/dbi/hiveserver2/thrift_internal.h"
+
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+namespace hiveserver2 {
+
+static std::string GetTestHost() {
+ const char* host = std::getenv("ARROW_HIVESERVER2_TEST_HOST");
+ return host == nullptr ? "localhost" : std::string(host);
+}
+
+// Convenience functions for finding a row of values given several columns.
+template <typename VType, typename CType>
+bool FindRow(VType value, CType* column) {
+ for (int i = 0; i < column->length(); ++i) {
+ if (column->data()[i] == value) {
+ return true;
+ }
+ }
+ return false;
+}
+
+template <typename V1Type, typename V2Type, typename C1Type, typename C2Type>
+bool FindRow(V1Type value1, V2Type value2, C1Type* column1, C2Type* column2) {
+ EXPECT_EQ(column1->length(), column2->length());
+ for (int i = 0; i < column1->length(); ++i) {
+ if (column1->data()[i] == value1 && column2->data()[i] == value2) {
+ return true;
+ }
+ }
+ return false;
+}
+
+template <typename V1Type, typename V2Type, typename V3Type, typename C1Type,
+ typename C2Type, typename C3Type>
+bool FindRow(V1Type value1, V2Type value2, V3Type value3, C1Type* column1,
+ C2Type* column2, C3Type column3) {
+ EXPECT_EQ(column1->length(), column2->length());
+ EXPECT_EQ(column1->length(), column3->length());
+ for (int i = 0; i < column1->length(); ++i) {
+ if (column1->data()[i] == value1 && column2->data()[i] == value2 &&
+ column3->data()[i] == value3) {
+ return true;
+ }
+ }
+ return false;
+}
+
+// Waits for this operation to reach the given state, sleeping for sleep microseconds
+// between checks, and failing after max_retries checks.
+Status Wait(const std::unique_ptr<Operation>& op,
+ Operation::State state = Operation::State::FINISHED, int sleep_us = 10000,
+ int max_retries = 100) {
+ int retries = 0;
+ Operation::State op_state;
+ RETURN_NOT_OK(op->GetState(&op_state));
+ while (op_state != state && retries < max_retries) {
+ usleep(sleep_us);
+ RETURN_NOT_OK(op->GetState(&op_state));
+ ++retries;
+ }
+
+ if (op_state == state) {
+ return Status::OK();
+ } else {
+ return Status::IOError("Failed to reach state '", OperationStateToString(state),
+ "' after ", retries, " retries");
+ }
+}
+
+// Creates a service, session, and database for use in tests.
+class HS2ClientTest : public ::testing::Test {
+ protected:
+ virtual void SetUp() {
+ hostname_ = GetTestHost();
+
+ int conn_timeout = 0;
+ ProtocolVersion protocol_version = ProtocolVersion::PROTOCOL_V7;
+ ASSERT_OK(
+ Service::Connect(hostname_, port, conn_timeout, protocol_version, &service_));
+
+ std::string user = "user";
+ HS2ClientConfig config;
+ ASSERT_OK(service_->OpenSession(user, config, &session_));
+
+ std::unique_ptr<Operation> drop_db_op;
+ ASSERT_OK(session_->ExecuteStatement(
+ "drop database if exists " + TEST_DB + " cascade", &drop_db_op));
+ ASSERT_OK(drop_db_op->Close());
+
+ std::unique_ptr<Operation> create_db_op;
+ ASSERT_OK(session_->ExecuteStatement("create database " + TEST_DB, &create_db_op));
+ ASSERT_OK(create_db_op->Close());
+
+ std::unique_ptr<Operation> use_db_op;
+ ASSERT_OK(session_->ExecuteStatement("use " + TEST_DB, &use_db_op));
+ ASSERT_OK(use_db_op->Close());
+ }
+
+ virtual void TearDown() {
+ std::unique_ptr<Operation> use_db_op;
+ if (session_) {
+ // We were able to create a session and service
+ ASSERT_OK(session_->ExecuteStatement("use default", &use_db_op));
+ ASSERT_OK(use_db_op->Close());
+
+ std::unique_ptr<Operation> drop_db_op;
+ ASSERT_OK(session_->ExecuteStatement("drop database " + TEST_DB + " cascade",
+ &drop_db_op));
+ ASSERT_OK(drop_db_op->Close());
+
+ ASSERT_OK(session_->Close());
+ ASSERT_OK(service_->Close());
+ }
+ }
+
+ void CreateTestTable() {
+ std::unique_ptr<Operation> create_table_op;
+ ASSERT_OK(session_->ExecuteStatement(
+ "create table " + TEST_TBL + " (" + TEST_COL1 + " int, " + TEST_COL2 + " string)",
+ &create_table_op));
+ ASSERT_OK(create_table_op->Close());
+ }
+
+ void InsertIntoTestTable(std::vector<int> int_col_data,
+ std::vector<std::string> string_col_data) {
+ ASSERT_EQ(int_col_data.size(), string_col_data.size());
+
+ std::stringstream query;
+ query << "insert into " << TEST_TBL << " VALUES ";
+ for (size_t i = 0; i < int_col_data.size(); i++) {
+ if (int_col_data[i] == NULL_INT_VALUE) {
+ query << " (NULL, ";
+ } else {
+ query << " (" << int_col_data[i] << ", ";
+ }
+
+ if (string_col_data[i] == "NULL") {
+ query << "NULL)";
+ } else {
+ query << "'" << string_col_data[i] << "')";
+ }
+
+ if (i != int_col_data.size() - 1) {
+ query << ", ";
+ }
+ }
+
+ std::unique_ptr<Operation> insert_op;
+ ASSERT_OK(session_->ExecuteStatement(query.str(), &insert_op));
+ ASSERT_OK(Wait(insert_op));
+ Operation::State insert_op_state;
+ ASSERT_OK(insert_op->GetState(&insert_op_state));
+ ASSERT_EQ(insert_op_state, Operation::State::FINISHED);
+ ASSERT_OK(insert_op->Close());
+ }
+ std::string hostname_;
+
+ int port = 21050;
+
+ const std::string TEST_DB = "hs2client_test_db";
+ const std::string TEST_TBL = "hs2client_test_table";
+ const std::string TEST_COL1 = "int_col";
+ const std::string TEST_COL2 = "string_col";
+
+ const int NULL_INT_VALUE = -1;
+
+ std::unique_ptr<Service> service_;
+ std::unique_ptr<Session> session_;
+};
+
+class OperationTest : public HS2ClientTest {};
+
+TEST_F(OperationTest, TestFetch) {
+ CreateTestTable();
+ InsertIntoTestTable(std::vector<int>({1, 2, 3, 4}),
+ std::vector<std::string>({"a", "b", "c", "d"}));
+
+ std::unique_ptr<Operation> select_op;
+ ASSERT_OK(session_->ExecuteStatement("select * from " + TEST_TBL + " order by int_col",
+ &select_op));
+
+ std::unique_ptr<ColumnarRowSet> results;
+ bool has_more_rows = false;
+ // Impala only supports NEXT and FIRST.
+ ASSERT_RAISES(IOError,
+ select_op->Fetch(2, FetchOrientation::LAST, &results, &has_more_rows));
+
+ // Fetch the results in two batches by passing max_rows to Fetch.
+ ASSERT_OK(select_op->Fetch(2, FetchOrientation::NEXT, &results, &has_more_rows));
+ ASSERT_OK(Wait(select_op));
+ ASSERT_TRUE(select_op->HasResultSet());
+ std::unique_ptr<Int32Column> int_col = results->GetInt32Col(0);
+ std::unique_ptr<StringColumn> string_col = results->GetStringCol(1);
+ ASSERT_EQ(int_col->data(), std::vector<int>({1, 2}));
+ ASSERT_EQ(string_col->data(), std::vector<std::string>({"a", "b"}));
+ ASSERT_TRUE(has_more_rows);
+
+ ASSERT_OK(select_op->Fetch(2, FetchOrientation::NEXT, &results, &has_more_rows));
+ int_col = results->GetInt32Col(0);
+ string_col = results->GetStringCol(1);
+ ASSERT_EQ(int_col->data(), std::vector<int>({3, 4}));
+ ASSERT_EQ(string_col->data(), std::vector<std::string>({"c", "d"}));
+
+ ASSERT_OK(select_op->Fetch(2, FetchOrientation::NEXT, &results, &has_more_rows));
+ int_col = results->GetInt32Col(0);
+ string_col = results->GetStringCol(1);
+ ASSERT_EQ(int_col->length(), 0);
+ ASSERT_EQ(string_col->length(), 0);
+ ASSERT_FALSE(has_more_rows);
+
+ ASSERT_OK(select_op->Fetch(2, FetchOrientation::NEXT, &results, &has_more_rows));
+ int_col = results->GetInt32Col(0);
+ string_col = results->GetStringCol(1);
+ ASSERT_EQ(int_col->length(), 0);
+ ASSERT_EQ(string_col->length(), 0);
+ ASSERT_FALSE(has_more_rows);
+
+ ASSERT_OK(select_op->Close());
+}
+
+TEST_F(OperationTest, TestIsNull) {
+ CreateTestTable();
+ // Insert some NULLs and ensure Column::IsNull() is correct.
+ InsertIntoTestTable(std::vector<int>({1, 2, 3, 4, 5, NULL_INT_VALUE}),
+ std::vector<std::string>({"a", "b", "NULL", "d", "NULL", "f"}));
+
+ std::unique_ptr<Operation> select_nulls_op;
+ ASSERT_OK(session_->ExecuteStatement("select * from " + TEST_TBL + " order by int_col",
+ &select_nulls_op));
+
+ std::unique_ptr<ColumnarRowSet> nulls_results;
+ bool has_more_rows = false;
+ ASSERT_OK(select_nulls_op->Fetch(&nulls_results, &has_more_rows));
+ std::unique_ptr<Int32Column> int_col = nulls_results->GetInt32Col(0);
+ std::unique_ptr<StringColumn> string_col = nulls_results->GetStringCol(1);
+ ASSERT_EQ(int_col->length(), 6);
+ ASSERT_EQ(int_col->length(), string_col->length());
+
+ bool int_nulls[] = {false, false, false, false, false, true};
+ for (int i = 0; i < int_col->length(); i++) {
+ ASSERT_EQ(int_col->IsNull(i), int_nulls[i]);
+ }
+ bool string_nulls[] = {false, false, true, false, true, false};
+ for (int i = 0; i < string_col->length(); i++) {
+ ASSERT_EQ(string_col->IsNull(i), string_nulls[i]);
+ }
+
+ ASSERT_OK(select_nulls_op->Close());
+}
+
+TEST_F(OperationTest, TestCancel) {
+ CreateTestTable();
+ InsertIntoTestTable(std::vector<int>({1, 2, 3, 4}),
+ std::vector<std::string>({"a", "b", "c", "d"}));
+
+ std::unique_ptr<Operation> op;
+ ASSERT_OK(session_->ExecuteStatement("select count(*) from " + TEST_TBL, &op));
+ ASSERT_OK(op->Cancel());
+ // Impala currently returns ERROR and not CANCELED for canceled queries
+ // due to the use of beeswax states, which don't support a canceled state.
+ ASSERT_OK(Wait(op, Operation::State::ERROR));
+
+ std::string profile;
+ ASSERT_OK(op->GetProfile(&profile));
+ ASSERT_TRUE(profile.find("Cancelled") != std::string::npos);
+
+ ASSERT_OK(op->Close());
+}
+
+TEST_F(OperationTest, TestGetLog) {
+ CreateTestTable();
+
+ std::unique_ptr<Operation> op;
+ ASSERT_OK(session_->ExecuteStatement("select count(*) from " + TEST_TBL, &op));
+ std::string log;
+ ASSERT_OK(op->GetLog(&log));
+ ASSERT_NE(log, "");
+
+ ASSERT_OK(op->Close());
+}
+
+TEST_F(OperationTest, TestGetResultSetMetadata) {
+ const std::string TEST_COL1 = "int_col";
+ const std::string TEST_COL2 = "varchar_col";
+ const int MAX_LENGTH = 10;
+ const std::string TEST_COL3 = "decimal_cal";
+ const int PRECISION = 5;
+ const int SCALE = 3;
+ std::stringstream create_query;
+ create_query << "create table " << TEST_TBL << " (" << TEST_COL1 << " int, "
+ << TEST_COL2 << " varchar(" << MAX_LENGTH << "), " << TEST_COL3
+ << " decimal(" << PRECISION << ", " << SCALE << "))";
+ std::unique_ptr<Operation> create_table_op;
+ ASSERT_OK(session_->ExecuteStatement(create_query.str(), &create_table_op));
+ ASSERT_OK(create_table_op->Close());
+
+ // Perform a select, and check that we get the right metadata back.
+ std::unique_ptr<Operation> select_op;
+ ASSERT_OK(session_->ExecuteStatement("select * from " + TEST_TBL, &select_op));
+ std::vector<ColumnDesc> column_descs;
+ ASSERT_OK(select_op->GetResultSetMetadata(&column_descs));
+ ASSERT_EQ(column_descs.size(), 3);
+
+ ASSERT_EQ(column_descs[0].column_name(), TEST_COL1);
+ ASSERT_EQ(column_descs[0].type()->ToString(), "INT");
+ ASSERT_EQ(column_descs[0].type()->type_id(), ColumnType::TypeId::INT);
+ ASSERT_EQ(column_descs[0].position(), 0);
+
+ ASSERT_EQ(column_descs[1].column_name(), TEST_COL2);
+ ASSERT_EQ(column_descs[1].type()->ToString(), "VARCHAR");
+ ASSERT_EQ(column_descs[1].type()->type_id(), ColumnType::TypeId::VARCHAR);
+ ASSERT_EQ(column_descs[1].position(), 1);
+ ASSERT_EQ(column_descs[1].GetCharacterType()->max_length(), MAX_LENGTH);
+
+ ASSERT_EQ(column_descs[2].column_name(), TEST_COL3);
+ ASSERT_EQ(column_descs[2].type()->ToString(), "DECIMAL");
+ ASSERT_EQ(column_descs[2].type()->type_id(), ColumnType::TypeId::DECIMAL);
+ ASSERT_EQ(column_descs[2].position(), 2);
+ ASSERT_EQ(column_descs[2].GetDecimalType()->precision(), PRECISION);
+ ASSERT_EQ(column_descs[2].GetDecimalType()->scale(), SCALE);
+
+ ASSERT_OK(select_op->Close());
+
+ // Insert ops don't have result sets.
+ std::stringstream insert_query;
+ insert_query << "insert into " << TEST_TBL << " VALUES (1, cast('a' as varchar("
+ << MAX_LENGTH << ")), cast(1 as decimal(" << PRECISION << ", " << SCALE
+ << ")))";
+ std::unique_ptr<Operation> insert_op;
+ ASSERT_OK(session_->ExecuteStatement(insert_query.str(), &insert_op));
+ std::vector<ColumnDesc> insert_column_descs;
+ ASSERT_OK(insert_op->GetResultSetMetadata(&insert_column_descs));
+ ASSERT_EQ(insert_column_descs.size(), 0);
+ ASSERT_OK(insert_op->Close());
+}
+
+class SessionTest : public HS2ClientTest {};
+
+TEST_F(SessionTest, TestSessionConfig) {
+ // Create a table in TEST_DB.
+ const std::string& TEST_TBL = "hs2client_test_table";
+ std::unique_ptr<Operation> create_table_op;
+ ASSERT_OK(session_->ExecuteStatement(
+ "create table " + TEST_TBL + " (int_col int, string_col string)",
+ &create_table_op));
+ ASSERT_OK(create_table_op->Close());
+
+ // Start a new session with the use:database session option.
+ std::string user = "user";
+ HS2ClientConfig config_use;
+ config_use.SetOption("use:database", TEST_DB);
+ std::unique_ptr<Session> session_ok;
+ ASSERT_OK(service_->OpenSession(user, config_use, &session_ok));
+
+ // Ensure the use:database worked and we can access the table.
+ std::unique_ptr<Operation> select_op;
+ ASSERT_OK(session_ok->ExecuteStatement("select * from " + TEST_TBL, &select_op));
+ ASSERT_OK(select_op->Close());
+ ASSERT_OK(session_ok->Close());
+
+ // Start another session without use:database.
+ HS2ClientConfig config_no_use;
+ std::unique_ptr<Session> session_error;
+ ASSERT_OK(service_->OpenSession(user, config_no_use, &session_error));
+
+ // Ensure the we can't access the table.
+ std::unique_ptr<Operation> select_op_error;
+ ASSERT_RAISES(IOError, session_error->ExecuteStatement("select * from " + TEST_TBL,
+ &select_op_error));
+ ASSERT_OK(session_error->Close());
+}
+
+TEST(ServiceTest, TestConnect) {
+ // Open a connection.
+ std::string host = GetTestHost();
+ int port = 21050;
+ int conn_timeout = 0;
+ ProtocolVersion protocol_version = ProtocolVersion::PROTOCOL_V7;
+ std::unique_ptr<Service> service;
+ ASSERT_OK(Service::Connect(host, port, conn_timeout, protocol_version, &service));
+ ASSERT_TRUE(service->IsConnected());
+
+ // Check that we can start a session.
+ std::string user = "user";
+ HS2ClientConfig config;
+ std::unique_ptr<Session> session1;
+ ASSERT_OK(service->OpenSession(user, config, &session1));
+ ASSERT_OK(session1->Close());
+
+ // Close the service. We should not be able to open a session.
+ ASSERT_OK(service->Close());
+ ASSERT_FALSE(service->IsConnected());
+ ASSERT_OK(service->Close());
+ std::unique_ptr<Session> session3;
+ ASSERT_RAISES(IOError, service->OpenSession(user, config, &session3));
+ ASSERT_OK(session3->Close());
+
+ // We should be able to call Close again without errors.
+ ASSERT_OK(service->Close());
+ ASSERT_FALSE(service->IsConnected());
+}
+
+TEST(ServiceTest, TestFailedConnect) {
+ std::string host = GetTestHost();
+ int port = 21050;
+
+ // Set 100ms timeout so these return quickly
+ int conn_timeout = 100;
+
+ ProtocolVersion protocol_version = ProtocolVersion::PROTOCOL_V7;
+ std::unique_ptr<Service> service;
+
+ std::string invalid_host = "does_not_exist";
+ ASSERT_RAISES(IOError, Service::Connect(invalid_host, port, conn_timeout,
+ protocol_version, &service));
+
+ int invalid_port = -1;
+ ASSERT_RAISES(IOError, Service::Connect(host, invalid_port, conn_timeout,
+ protocol_version, &service));
+
+ ProtocolVersion invalid_protocol_version = ProtocolVersion::PROTOCOL_V2;
+ ASSERT_RAISES(NotImplemented, Service::Connect(host, port, conn_timeout,
+ invalid_protocol_version, &service));
+}
+
+} // namespace hiveserver2
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/operation.cc b/src/arrow/cpp/src/arrow/dbi/hiveserver2/operation.cc
new file mode 100644
index 000000000..3a5cceb25
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/operation.cc
@@ -0,0 +1,150 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dbi/hiveserver2/operation.h"
+
+#include "arrow/dbi/hiveserver2/thrift_internal.h"
+
+#include "arrow/dbi/hiveserver2/ImpalaService_types.h"
+#include "arrow/dbi/hiveserver2/TCLIService.h"
+
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+namespace hs2 = apache::hive::service::cli::thrift;
+using std::unique_ptr;
+
+namespace arrow {
+namespace hiveserver2 {
+
+// Max rows to fetch, if not specified.
+constexpr int kDefaultMaxRows = 1024;
+
+Operation::Operation(const std::shared_ptr<ThriftRPC>& rpc)
+ : impl_(new OperationImpl()), rpc_(rpc), open_(false) {}
+
+Operation::~Operation() { DCHECK(!open_); }
+
+Status Operation::GetState(Operation::State* out) const {
+ hs2::TGetOperationStatusReq req;
+ req.__set_operationHandle(impl_->handle);
+ hs2::TGetOperationStatusResp resp;
+ TRY_RPC_OR_RETURN(rpc_->client->GetOperationStatus(resp, req));
+ THRIFT_RETURN_NOT_OK(resp.status);
+ *out = TOperationStateToOperationState(resp.operationState);
+ return TStatusToStatus(resp.status);
+}
+
+Status Operation::GetLog(std::string* out) const {
+ hs2::TGetLogReq req;
+ req.__set_operationHandle(impl_->handle);
+ hs2::TGetLogResp resp;
+ TRY_RPC_OR_RETURN(rpc_->client->GetLog(resp, req));
+ THRIFT_RETURN_NOT_OK(resp.status);
+ *out = resp.log;
+ return TStatusToStatus(resp.status);
+}
+
+Status Operation::GetProfile(std::string* out) const {
+ impala::TGetRuntimeProfileReq req;
+ req.__set_operationHandle(impl_->handle);
+ req.__set_sessionHandle(impl_->session_handle);
+ impala::TGetRuntimeProfileResp resp;
+ TRY_RPC_OR_RETURN(rpc_->client->GetRuntimeProfile(resp, req));
+ THRIFT_RETURN_NOT_OK(resp.status);
+ *out = resp.profile;
+ return TStatusToStatus(resp.status);
+}
+
+Status Operation::GetResultSetMetadata(std::vector<ColumnDesc>* column_descs) const {
+ hs2::TGetResultSetMetadataReq req;
+ req.__set_operationHandle(impl_->handle);
+ hs2::TGetResultSetMetadataResp resp;
+ TRY_RPC_OR_RETURN(rpc_->client->GetResultSetMetadata(resp, req));
+ THRIFT_RETURN_NOT_OK(resp.status);
+
+ column_descs->clear();
+ column_descs->reserve(resp.schema.columns.size());
+ for (const hs2::TColumnDesc& tcolumn_desc : resp.schema.columns) {
+ column_descs->emplace_back(tcolumn_desc.columnName,
+ TTypeDescToColumnType(tcolumn_desc.typeDesc),
+ tcolumn_desc.position, tcolumn_desc.comment);
+ }
+
+ return TStatusToStatus(resp.status);
+}
+
+Status Operation::Fetch(unique_ptr<ColumnarRowSet>* results, bool* has_more_rows) const {
+ return Fetch(kDefaultMaxRows, FetchOrientation::NEXT, results, has_more_rows);
+}
+
+Status Operation::Fetch(int max_rows, FetchOrientation orientation,
+ unique_ptr<ColumnarRowSet>* results, bool* has_more_rows) const {
+ hs2::TFetchResultsReq req;
+ req.__set_operationHandle(impl_->handle);
+ req.__set_orientation(FetchOrientationToTFetchOrientation(orientation));
+ req.__set_maxRows(max_rows);
+ std::unique_ptr<ColumnarRowSet::ColumnarRowSetImpl> row_set_impl(
+ new ColumnarRowSet::ColumnarRowSetImpl());
+ TRY_RPC_OR_RETURN(rpc_->client->FetchResults(row_set_impl->resp, req));
+ THRIFT_RETURN_NOT_OK(row_set_impl->resp.status);
+
+ if (has_more_rows != NULL) {
+ *has_more_rows = row_set_impl->resp.hasMoreRows;
+ }
+ Status status = TStatusToStatus(row_set_impl->resp.status);
+ RETURN_NOT_OK(status);
+ results->reset(new ColumnarRowSet(row_set_impl.release()));
+ return status;
+}
+
+Status Operation::Cancel() const {
+ hs2::TCancelOperationReq req;
+ req.__set_operationHandle(impl_->handle);
+ hs2::TCancelOperationResp resp;
+ TRY_RPC_OR_RETURN(rpc_->client->CancelOperation(resp, req));
+ return TStatusToStatus(resp.status);
+}
+
+Status Operation::Close() {
+ if (!open_) return Status::OK();
+
+ hs2::TCloseOperationReq req;
+ req.__set_operationHandle(impl_->handle);
+ hs2::TCloseOperationResp resp;
+ TRY_RPC_OR_RETURN(rpc_->client->CloseOperation(resp, req));
+ THRIFT_RETURN_NOT_OK(resp.status);
+
+ open_ = false;
+ return TStatusToStatus(resp.status);
+}
+
+bool Operation::HasResultSet() const {
+ State op_state;
+ Status s = GetState(&op_state);
+ if (!s.ok()) return false;
+ return op_state == State::FINISHED;
+}
+
+bool Operation::IsColumnar() const {
+ // We currently only support the columnar hs2 protocols.
+ return true;
+}
+
+} // namespace hiveserver2
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/operation.h b/src/arrow/cpp/src/arrow/dbi/hiveserver2/operation.h
new file mode 100644
index 000000000..3efe66659
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/operation.h
@@ -0,0 +1,127 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/dbi/hiveserver2/columnar_row_set.h"
+#include "arrow/dbi/hiveserver2/types.h"
+
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Status;
+
+namespace hiveserver2 {
+
+struct ThriftRPC;
+
+// Maps directly to TFetchOrientation in the HiveServer2 interface.
+enum class FetchOrientation {
+ NEXT, // supported
+ PRIOR, // not supported
+ RELATIVE, // not supported
+ ABSOLUTE, // not supported
+ FIRST, // supported if query result caching is enabled in Impala
+ LAST // not supported
+};
+
+// Represents a single HiveServer2 operation. Used to monitor the status of an operation
+// and to retrieve its results. The only Operation function that will block is Fetch,
+// which blocks if there aren't any results ready yet.
+//
+// Operations are created using Session functions, eg. ExecuteStatement. They must
+// have Close called on them before they can be deleted.
+//
+// This class is not thread-safe.
+class ARROW_EXPORT Operation {
+ public:
+ // Maps directly to TOperationState in the HiveServer2 interface.
+ enum class State {
+ INITIALIZED,
+ RUNNING,
+ FINISHED,
+ CANCELED,
+ CLOSED,
+ ERROR,
+ UNKNOWN,
+ PENDING,
+ };
+
+ ~Operation();
+
+ // Fetches the current state of this operation. If successful, sets the operation state
+ // in 'out' and returns an OK status, otherwise an error status is returned. May be
+ // called after successfully creating the operation and before calling Close.
+ Status GetState(Operation::State* out) const;
+
+ // May be called after successfully creating the operation and before calling Close.
+ Status GetLog(std::string* out) const;
+
+ // May be called after successfully creating the operation and before calling Close.
+ Status GetProfile(std::string* out) const;
+
+ // Fetches metadata for the columns in the output of this operation, such as the
+ // names and types of the columns, and returns it as a list of column descriptions.
+ // May be called after successfully creating the operation and before calling Close.
+ Status GetResultSetMetadata(std::vector<ColumnDesc>* column_descs) const;
+
+ // Fetches a batch of results, stores them in 'results', and sets has_more_rows.
+ // Fetch will block if there aren't any results that are ready.
+ Status Fetch(std::unique_ptr<ColumnarRowSet>* results, bool* has_more_rows) const;
+ Status Fetch(int max_rows, FetchOrientation orientation,
+ std::unique_ptr<ColumnarRowSet>* results, bool* has_more_rows) const;
+
+ // May be called after successfully creating the operation and before calling Close.
+ Status Cancel() const;
+
+ // Closes the operation. Must be called before the operation is deleted. May be safely
+ // called on an invalid or already closed operation - will only return an error if the
+ // operation is open but the close rpc fails.
+ Status Close();
+
+ // May be called after successfully creating the operation and before calling Close.
+ bool HasResultSet() const;
+
+ // Returns true iff this operation's results will be returned in a columnar format.
+ // May be called at any time.
+ bool IsColumnar() const;
+
+ protected:
+ // Hides Thrift objects from the header.
+ struct OperationImpl;
+
+ explicit Operation(const std::shared_ptr<ThriftRPC>& rpc);
+
+ std::unique_ptr<OperationImpl> impl_;
+ std::shared_ptr<ThriftRPC> rpc_;
+
+ // True iff this operation has been successfully created and has not been closed yet,
+ // corresponding to when the operation has a valid operation handle.
+ bool open_;
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Operation);
+};
+
+} // namespace hiveserver2
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/public_api_test.cc b/src/arrow/cpp/src/arrow/dbi/hiveserver2/public_api_test.cc
new file mode 100644
index 000000000..833ad02ea
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/public_api_test.cc
@@ -0,0 +1,26 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include "arrow/adapters/hiveserver2/api.h"
+
+TEST(TestPublicAPI, DoesNotIncludeThrift) {
+#ifdef _THRIFT_THRIFT_H_
+ FAIL() << "Thrift headers should not be in the public API";
+#endif
+}
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/sample_usage.cc b/src/arrow/cpp/src/arrow/dbi/hiveserver2/sample_usage.cc
new file mode 100644
index 000000000..14c91adf8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/sample_usage.cc
@@ -0,0 +1,137 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cassert>
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "arrow/dbi/hiveserver2/api.h"
+
+namespace hs2 = arrow::hiveserver2;
+
+using arrow::Status;
+
+using hs2::Operation;
+using hs2::Service;
+using hs2::Session;
+
+#define ABORT_NOT_OK(s) \
+ do { \
+ ::arrow::Status _s = (s); \
+ if (ARROW_PREDICT_FALSE(!_s.ok())) { \
+ std::cerr << s.ToString() << "\n"; \
+ std::abort(); \
+ } \
+ } while (false);
+
+int main(int argc, char** argv) {
+ // Connect to the server.
+ std::string host = "localhost";
+ int port = 21050;
+ int conn_timeout = 0;
+ hs2::ProtocolVersion protocol = hs2::ProtocolVersion::PROTOCOL_V7;
+ std::unique_ptr<Service> service;
+ Status status = Service::Connect(host, port, conn_timeout, protocol, &service);
+ if (!status.ok()) {
+ std::cout << "Failed to connect to service: " << status.ToString();
+ ABORT_NOT_OK(service->Close());
+ return 1;
+ }
+
+ // Open a session.
+ std::string user = "user";
+ hs2::HS2ClientConfig config;
+ std::unique_ptr<Session> session;
+ status = service->OpenSession(user, config, &session);
+ if (!status.ok()) {
+ std::cout << "Failed to open session: " << status.ToString();
+ ABORT_NOT_OK(session->Close());
+ ABORT_NOT_OK(service->Close());
+ return 1;
+ }
+
+ // Execute a statement.
+ std::string statement = "SELECT int_col, string_col FROM test order by int_col";
+ std::unique_ptr<hs2::Operation> execute_op;
+ status = session->ExecuteStatement(statement, &execute_op);
+ if (!status.ok()) {
+ std::cout << "Failed to execute select: " << status.ToString();
+ ABORT_NOT_OK(execute_op->Close());
+ ABORT_NOT_OK(session->Close());
+ ABORT_NOT_OK(service->Close());
+ return 1;
+ }
+
+ std::unique_ptr<hs2::ColumnarRowSet> execute_results;
+ bool has_more_rows = true;
+ int64_t total_retrieved = 0;
+ std::cout << "Contents of test:\n";
+ while (has_more_rows) {
+ status = execute_op->Fetch(&execute_results, &has_more_rows);
+ if (!status.ok()) {
+ std::cout << "Failed to fetch results: " << status.ToString();
+ ABORT_NOT_OK(execute_op->Close());
+ ABORT_NOT_OK(session->Close());
+ ABORT_NOT_OK(service->Close());
+ return 1;
+ }
+
+ std::unique_ptr<hs2::Int32Column> int_col = execute_results->GetInt32Col(0);
+ std::unique_ptr<hs2::StringColumn> string_col = execute_results->GetStringCol(1);
+ assert(int_col->length() == string_col->length());
+ total_retrieved += int_col->length();
+ for (int64_t i = 0; i < int_col->length(); ++i) {
+ if (int_col->IsNull(i)) {
+ std::cout << "NULL";
+ } else {
+ std::cout << int_col->GetData(i);
+ }
+ std::cout << ":";
+
+ if (string_col->IsNull(i)) {
+ std::cout << "NULL";
+ } else {
+ std::cout << "'" << string_col->GetData(i) << "'";
+ }
+ std::cout << "\n";
+ }
+ }
+ std::cout << "retrieved " << total_retrieved << " rows\n";
+ std::cout << "\n";
+ ABORT_NOT_OK(execute_op->Close());
+
+ std::unique_ptr<Operation> show_tables_op;
+ status = session->ExecuteStatement("show tables", &show_tables_op);
+ if (!status.ok()) {
+ std::cout << "Failed to execute GetTables: " << status.ToString();
+ ABORT_NOT_OK(show_tables_op->Close());
+ ABORT_NOT_OK(session->Close());
+ ABORT_NOT_OK(service->Close());
+ return 1;
+ }
+
+ std::cout << "Show tables:\n";
+ hs2::Util::PrintResults(show_tables_op.get(), std::cout);
+ ABORT_NOT_OK(show_tables_op->Close());
+
+ // Shut down.
+ ABORT_NOT_OK(session->Close());
+ ABORT_NOT_OK(service->Close());
+
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/service.cc b/src/arrow/cpp/src/arrow/dbi/hiveserver2/service.cc
new file mode 100644
index 000000000..8ac19a8d3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/service.cc
@@ -0,0 +1,110 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dbi/hiveserver2/service.h"
+
+#include <thrift/protocol/TBinaryProtocol.h>
+#include <thrift/transport/TSocket.h>
+#include <thrift/transport/TTransportUtils.h>
+#include <sstream>
+
+#include "arrow/dbi/hiveserver2/session.h"
+#include "arrow/dbi/hiveserver2/thrift_internal.h"
+
+#include "arrow/dbi/hiveserver2/ImpalaHiveServer2Service.h"
+#include "arrow/dbi/hiveserver2/TCLIService.h"
+
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+
+namespace hs2 = apache::hive::service::cli::thrift;
+
+using apache::thrift::TException;
+using apache::thrift::protocol::TBinaryProtocol;
+using apache::thrift::protocol::TProtocol;
+using apache::thrift::transport::TBufferedTransport;
+using apache::thrift::transport::TSocket;
+using apache::thrift::transport::TTransport;
+
+namespace arrow {
+namespace hiveserver2 {
+
+struct Service::ServiceImpl {
+ hs2::TProtocolVersion::type protocol_version;
+ std::shared_ptr<TSocket> socket;
+ std::shared_ptr<TTransport> transport;
+ std::shared_ptr<TProtocol> protocol;
+};
+
+Status Service::Connect(const std::string& host, int port, int conn_timeout,
+ ProtocolVersion protocol_version,
+ std::unique_ptr<Service>* service) {
+ service->reset(new Service(host, port, conn_timeout, protocol_version));
+ return (*service)->Open();
+}
+
+Service::~Service() { DCHECK(!IsConnected()); }
+
+Status Service::Close() {
+ if (!IsConnected()) return Status::OK();
+ TRY_RPC_OR_RETURN(impl_->transport->close());
+ return Status::OK();
+}
+
+bool Service::IsConnected() const {
+ return impl_->transport && impl_->transport->isOpen();
+}
+
+void Service::SetRecvTimeout(int timeout) { impl_->socket->setRecvTimeout(timeout); }
+
+void Service::SetSendTimeout(int timeout) { impl_->socket->setSendTimeout(timeout); }
+
+Status Service::OpenSession(const std::string& user, const HS2ClientConfig& config,
+ std::unique_ptr<Session>* session) const {
+ session->reset(new Session(rpc_));
+ return (*session)->Open(config, user);
+}
+
+Service::Service(const std::string& host, int port, int conn_timeout,
+ ProtocolVersion protocol_version)
+ : host_(host),
+ port_(port),
+ conn_timeout_(conn_timeout),
+ impl_(new ServiceImpl()),
+ rpc_(new ThriftRPC()) {
+ impl_->protocol_version = ProtocolVersionToTProtocolVersion(protocol_version);
+}
+
+Status Service::Open() {
+ if (impl_->protocol_version < hs2::TProtocolVersion::HIVE_CLI_SERVICE_PROTOCOL_V6) {
+ return Status::NotImplemented("Unsupported protocol: ", impl_->protocol_version);
+ }
+
+ impl_->socket.reset(new TSocket(host_, port_));
+ impl_->socket->setConnTimeout(conn_timeout_);
+ impl_->transport.reset(new TBufferedTransport(impl_->socket));
+ impl_->protocol.reset(new TBinaryProtocol(impl_->transport));
+
+ rpc_->client.reset(new impala::ImpalaHiveServer2ServiceClient(impl_->protocol));
+
+ TRY_RPC_OR_RETURN(impl_->transport->open());
+
+ return Status::OK();
+}
+
+} // namespace hiveserver2
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/service.h b/src/arrow/cpp/src/arrow/dbi/hiveserver2/service.h
new file mode 100644
index 000000000..8b9a094f5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/service.h
@@ -0,0 +1,140 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <map>
+#include <memory>
+#include <string>
+
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Status;
+
+namespace hiveserver2 {
+
+class Session;
+struct ThriftRPC;
+
+// Stores per-session or per-operation configuration parameters.
+class HS2ClientConfig {
+ public:
+ void SetOption(const std::string& key, const std::string& value) {
+ config_[key] = value;
+ }
+
+ bool GetOption(const std::string& key, std::string* value_out) {
+ if (config_.find(key) != config_.end() && value_out) {
+ *value_out = config_[key];
+ return true;
+ }
+ return false;
+ }
+
+ const std::map<std::string, std::string>& GetConfig() const { return config_; }
+
+ private:
+ std::map<std::string, std::string> config_;
+};
+
+// Maps directly to TProtocolVersion in the HiveServer2 interface.
+enum class ProtocolVersion {
+ PROTOCOL_V1, // not supported
+ PROTOCOL_V2, // not supported
+ PROTOCOL_V3, // not supported
+ PROTOCOL_V4, // not supported
+ PROTOCOL_V5, // not supported
+ PROTOCOL_V6, // supported
+ PROTOCOL_V7, // supported
+};
+
+// Manages a connection to a HiveServer2 server. Primarily used to create
+// new sessions via OpenSession.
+//
+// Service objects are created using Service::Connect(). They must
+// have Close called on them before they can be deleted.
+//
+// This class is not thread-safe.
+//
+// Example:
+// unique_ptr<Service> service;
+// if (Service::Connect(host, port, protocol_version, &service).ok()) {
+// // do some work
+// service->Close();
+// }
+class ARROW_EXPORT Service {
+ public:
+ // Creates a new connection to a HS2 service at the given host and port. If
+ // conn_timeout > 0, connection attempts will timeout after conn_timeout ms, otherwise
+ // no timeout is used. protocol_version is the HiveServer2 protocol to use, and
+ // determines whether the results returned by operations from this service are row or
+ // column oriented. Only column oriented protocols are currently supported.
+ //
+ // The client calling Connect has ownership of the new Service that is created.
+ // Executing RPCs with a Session or Operation corresponding to a particular
+ // Service after that Service has been closed or deleted in undefined.
+ static Status Connect(const std::string& host, int port, int conn_timeout,
+ ProtocolVersion protocol_version,
+ std::unique_ptr<Service>* service);
+
+ ~Service();
+
+ // Closes the connection. Must be called before the service is deleted. May be
+ // safely called on an invalid or already closed service - will only return an
+ // error if the service is open but the close rpc fails.
+ Status Close();
+
+ // Returns true iff this service has an active connection to the HiveServer2 server.
+ bool IsConnected() const;
+
+ // Set the send and receive timeout for Thrift RPCs in ms. 0 indicates no timeout,
+ // negative values are ignored.
+ void SetRecvTimeout(int timeout);
+ void SetSendTimeout(int timeout);
+
+ // Opens a new HS2 session using this service.
+ // The client calling OpenSession has ownership of the Session that is created.
+ // Operations on the Session are undefined once it is closed.
+ Status OpenSession(const std::string& user, const HS2ClientConfig& config,
+ std::unique_ptr<Session>* session) const;
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Service);
+
+ // Hides Thrift objects from the header.
+ struct ServiceImpl;
+
+ Service(const std::string& host, int port, int conn_timeout,
+ ProtocolVersion protocol_version);
+
+ // Opens the connection to the server. Called by Connect before new service is returned
+ // to the user. Must be called before OpenSession.
+ Status Open();
+
+ std::string host_;
+ int port_;
+ int conn_timeout_;
+
+ std::unique_ptr<ServiceImpl> impl_;
+ std::shared_ptr<ThriftRPC> rpc_;
+};
+
+} // namespace hiveserver2
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/session.cc b/src/arrow/cpp/src/arrow/dbi/hiveserver2/session.cc
new file mode 100644
index 000000000..069f07275
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/session.cc
@@ -0,0 +1,103 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dbi/hiveserver2/session.h"
+
+#include "arrow/dbi/hiveserver2/TCLIService.h"
+#include "arrow/dbi/hiveserver2/thrift_internal.h"
+
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+
+namespace hs2 = apache::hive::service::cli::thrift;
+using apache::thrift::TException;
+
+namespace arrow {
+namespace hiveserver2 {
+
+struct Session::SessionImpl {
+ hs2::TSessionHandle handle;
+};
+
+Session::Session(const std::shared_ptr<ThriftRPC>& rpc)
+ : impl_(new SessionImpl()), rpc_(rpc), open_(false) {}
+
+Session::~Session() { DCHECK(!open_); }
+
+Status Session::Close() {
+ if (!open_) return Status::OK();
+
+ hs2::TCloseSessionReq req;
+ req.__set_sessionHandle(impl_->handle);
+ hs2::TCloseSessionResp resp;
+ TRY_RPC_OR_RETURN(rpc_->client->CloseSession(resp, req));
+ THRIFT_RETURN_NOT_OK(resp.status);
+
+ open_ = false;
+ return TStatusToStatus(resp.status);
+}
+
+Status Session::Open(const HS2ClientConfig& config, const std::string& user) {
+ hs2::TOpenSessionReq req;
+ req.__set_configuration(config.GetConfig());
+ req.__set_username(user);
+ hs2::TOpenSessionResp resp;
+ TRY_RPC_OR_RETURN(rpc_->client->OpenSession(resp, req));
+ THRIFT_RETURN_NOT_OK(resp.status);
+
+ impl_->handle = resp.sessionHandle;
+ open_ = true;
+ return TStatusToStatus(resp.status);
+}
+
+class ExecuteStatementOperation : public Operation {
+ public:
+ explicit ExecuteStatementOperation(const std::shared_ptr<ThriftRPC>& rpc)
+ : Operation(rpc) {}
+
+ Status Open(hs2::TSessionHandle session_handle, const std::string& statement,
+ const HS2ClientConfig& config) {
+ hs2::TExecuteStatementReq req;
+ req.__set_sessionHandle(session_handle);
+ req.__set_statement(statement);
+ req.__set_confOverlay(config.GetConfig());
+ hs2::TExecuteStatementResp resp;
+ TRY_RPC_OR_RETURN(rpc_->client->ExecuteStatement(resp, req));
+ THRIFT_RETURN_NOT_OK(resp.status);
+
+ impl_->handle = resp.operationHandle;
+ impl_->session_handle = session_handle;
+ open_ = true;
+ return TStatusToStatus(resp.status);
+ }
+};
+
+Status Session::ExecuteStatement(const std::string& statement,
+ std::unique_ptr<Operation>* operation) const {
+ return ExecuteStatement(statement, HS2ClientConfig(), operation);
+}
+
+Status Session::ExecuteStatement(const std::string& statement,
+ const HS2ClientConfig& conf_overlay,
+ std::unique_ptr<Operation>* operation) const {
+ ExecuteStatementOperation* op = new ExecuteStatementOperation(rpc_);
+ operation->reset(op);
+ return op->Open(impl_->handle, statement, conf_overlay);
+}
+
+} // namespace hiveserver2
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/session.h b/src/arrow/cpp/src/arrow/dbi/hiveserver2/session.h
new file mode 100644
index 000000000..4e223de6c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/session.h
@@ -0,0 +1,84 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/dbi/hiveserver2/operation.h"
+#include "arrow/dbi/hiveserver2/service.h"
+
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Status;
+
+namespace hiveserver2 {
+
+struct ThriftRPC;
+
+// Manages a single HiveServer2 session - stores the session handle returned by
+// the OpenSession RPC and uses it to create and return operations.
+//
+// Sessions are created with Service::OpenSession(). They must have Close
+// called on them before they can be deleted.
+//
+// Executing RPCs with an Operation corresponding to a particular Session after
+// that Session has been closed or deleted is undefined.
+//
+// This class is not thread-safe.
+class ARROW_EXPORT Session {
+ public:
+ ~Session();
+
+ // Closes the session. Must be called before the session is deleted. May be safely
+ // called on an invalid or already closed session - will only return an error if the
+ // session is open but the close rpc fails.
+ Status Close();
+
+ Status ExecuteStatement(const std::string& statement,
+ std::unique_ptr<Operation>* operation) const;
+ Status ExecuteStatement(const std::string& statement,
+ const HS2ClientConfig& conf_overlay,
+ std::unique_ptr<Operation>* operation) const;
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Session);
+
+ // Hides Thrift objects from the header.
+ struct SessionImpl;
+
+ // For access to the c'tor.
+ friend class Service;
+
+ explicit Session(const std::shared_ptr<ThriftRPC>& rpc);
+
+ // Performs the RPC that initiates the session and stores the returned handle.
+ // Must be called before operations can be executed.
+ Status Open(const HS2ClientConfig& config, const std::string& user);
+
+ std::unique_ptr<SessionImpl> impl_;
+ std::shared_ptr<ThriftRPC> rpc_;
+
+ // True if Open has been called and Close has not.
+ bool open_;
+};
+
+} // namespace hiveserver2
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/.gitignore b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/.gitignore
new file mode 100644
index 000000000..f510e7c95
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/.gitignore
@@ -0,0 +1 @@
+ErrorCodes.thrift
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/CMakeLists.txt b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/CMakeLists.txt
new file mode 100644
index 000000000..237a92a82
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/CMakeLists.txt
@@ -0,0 +1,120 @@
+# Copyright 2012 Cloudera Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Helper function to generate build rules. For each input thrift file, this function will
+# generate a rule that maps the input file to the output c++ file.
+# Thrift will generate multiple output files for each input (including java files) and
+# ideally, we'd specify all of the outputs for dependency tracking.
+# Unfortunately, it's not easy to figure out all the output files without parsing the
+# thrift input. (TODO: can thrift tells us what the java output files will be?)
+# The list of output files is used for build dependency tracking so it's not necessary to
+# capture all the output files.
+#
+# To call this function, pass it the output file list followed by the input thrift files:
+# i.e. HS2_THRIFT_GEN(OUTPUT_FILES, ${THRIFT_FILES})
+#
+# cmake seems to be case sensitive for some keywords. Changing the first IF check to lower
+# case makes it not work. TODO: investigate this
+function(HS2_THRIFT_GEN VAR)
+ if(NOT ARGN)
+ message(SEND_ERROR "Error: THRIFT_GEN called without any src files")
+ return()
+ endif(NOT ARGN)
+
+ set(${VAR})
+ foreach(FIL ${ARGN})
+ # Get full path
+ get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
+ # Get basename
+ get_filename_component(FIL_WE ${FIL} NAME_WE)
+
+ set(GEN_DIR "${OUTPUT_DIR}/arrow/dbi/hiveserver2")
+
+ # All the output files we can determine based on filename.
+ # - Does not include .skeleton.cpp files
+ # - Does not include java output files
+ set(OUTPUT_BE_FILE
+ "${GEN_DIR}/${FIL_WE}_types.cpp" "${GEN_DIR}/${FIL_WE}_types.h"
+ "${GEN_DIR}/${FIL_WE}_constants.cpp" "${GEN_DIR}/${FIL_WE}_constants.h")
+ list(APPEND ${VAR} ${OUTPUT_BE_FILE})
+
+ # BeeswaxService thrift generation
+ # It depends on hive_meta_store, which in turn depends on fb303.
+ # The java dependency is handled by maven.
+ # We need to generate C++ src file for the parent dependencies using the "-r" option.
+ set(CPP_ARGS
+ -nowarn
+ --gen
+ cpp
+ -out
+ ${GEN_DIR})
+ if(FIL STREQUAL "beeswax.thrift")
+ set(CPP_ARGS
+ -r
+ -nowarn
+ --gen
+ cpp
+ -out
+ ${GEN_DIR})
+ endif(FIL STREQUAL "beeswax.thrift")
+
+ # Be able to include generated ErrorCodes.thrift file
+ set(CPP_ARGS ${CPP_ARGS} -I ${CMAKE_CURRENT_BINARY_DIR})
+
+ add_custom_command(OUTPUT ${OUTPUT_BE_FILE}
+ COMMAND ${THRIFT_COMPILER} ${CPP_ARGS} ${FIL}
+ DEPENDS ${ABS_FIL}
+ COMMENT "Running thrift compiler on ${FIL}"
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+ VERBATIM)
+ endforeach(FIL)
+
+ set(${VAR}
+ ${${VAR}}
+ PARENT_SCOPE)
+endfunction(HS2_THRIFT_GEN)
+
+message("Using Thrift compiler: ${THRIFT_COMPILER}")
+
+set(OUTPUT_DIR ${ARROW_BINARY_DIR}/src)
+file(MAKE_DIRECTORY ${OUTPUT_DIR})
+
+add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ErrorCodes.thrift
+ COMMAND python generate_error_codes.py ${CMAKE_CURRENT_BINARY_DIR}
+ DEPENDS generate_error_codes.py
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
+
+set(SRC_FILES
+ ${CMAKE_CURRENT_BINARY_DIR}/ErrorCodes.thrift
+ beeswax.thrift
+ TCLIService.thrift
+ ExecStats.thrift
+ ImpalaService.thrift
+ Status.thrift
+ Types.thrift)
+
+set_source_files_properties(Status.thrift
+ PROPERTIES OBJECT_DEPENDS
+ ${CMAKE_CURRENT_BINARY_DIR}/ErrorCodes.thrift)
+
+# Create a build command for each of the thrift src files and generate
+# a list of files they produce
+hs2_thrift_gen(THRIFT_ALL_FILES ${SRC_FILES})
+
+# Add a custom target that generates all the thrift files
+add_custom_target(hs2-thrift-cpp ALL DEPENDS ${THRIFT_ALL_FILES})
+
+add_custom_target(hs2-thrift-generated-files-error
+ DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/ErrorCodes.thrift)
+add_dependencies(hs2-thrift-cpp hs2-thrift-generated-files-error)
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/ExecStats.thrift b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/ExecStats.thrift
new file mode 100644
index 000000000..bcf5c4c6a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/ExecStats.thrift
@@ -0,0 +1,103 @@
+// Copyright 2012 Cloudera Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+namespace cpp impala
+namespace java com.cloudera.impala.thrift
+
+include "Status.thrift"
+include "Types.thrift"
+
+enum TExecState {
+ REGISTERED = 0,
+ PLANNING = 1,
+ QUEUED = 2,
+ RUNNING = 3,
+ FINISHED = 4,
+
+ CANCELLED = 5,
+ FAILED = 6,
+}
+
+// Execution stats for a single plan node.
+struct TExecStats {
+ // The wall clock time spent on the "main" thread. This is the user perceived
+ // latency. This value indicates the current bottleneck.
+ // Note: anywhere we have a queue between operators, this time can fluctuate
+ // significantly without the overall query time changing much (i.e. the bottleneck
+ // moved to another operator). This is unavoidable though.
+ 1: optional i64 latency_ns
+
+ // Total CPU time spent across all threads. For operators that have an async
+ // component (e.g. multi-threaded) this will be >= latency_ns.
+ 2: optional i64 cpu_time_ns
+
+ // Number of rows returned.
+ 3: optional i64 cardinality
+
+ // Peak memory used (in bytes).
+ 4: optional i64 memory_used
+}
+
+// Summary for a single plan node. This includes labels for how to display the
+// node as well as per instance stats.
+struct TPlanNodeExecSummary {
+ 1: required Types.TPlanNodeId node_id
+ 2: required i32 fragment_id
+ 3: required string label
+ 4: optional string label_detail
+ 5: required i32 num_children
+
+ // Estimated stats generated by the planner
+ 6: optional TExecStats estimated_stats
+
+ // One entry for each BE executing this plan node.
+ 7: optional list<TExecStats> exec_stats
+
+ // One entry for each BE executing this plan node. True if this plan node is still
+ // running.
+ 8: optional list<bool> is_active
+
+ // If true, this plan node is an exchange node that is the receiver of a broadcast.
+ 9: optional bool is_broadcast
+}
+
+// Progress counters for an in-flight query.
+struct TExecProgress {
+ 1: optional i64 total_scan_ranges
+ 2: optional i64 num_completed_scan_ranges
+}
+
+// Execution summary of an entire query.
+struct TExecSummary {
+ // State of the query.
+ 1: required TExecState state
+
+ // Contains the error if state is FAILED.
+ 2: optional Status.TStatus status
+
+ // Flattened execution summary of the plan tree.
+ 3: optional list<TPlanNodeExecSummary> nodes
+
+ // For each exch node in 'nodes', contains the index to the root node of the sending
+ // fragment for this exch. Both the key and value are indices into 'nodes'.
+ 4: optional map<i32, i32> exch_to_sender_map
+
+ // List of errors that were encountered during execution. This can be non-empty
+ // even if status is okay, in which case it contains errors that impala skipped
+ // over.
+ 5: optional list<string> error_logs
+
+ // Optional record indicating the query progress
+ 6: optional TExecProgress progress
+}
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/ImpalaService.thrift b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/ImpalaService.thrift
new file mode 100644
index 000000000..76a839604
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/ImpalaService.thrift
@@ -0,0 +1,300 @@
+// Copyright 2012 Cloudera Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+namespace cpp impala
+namespace java com.cloudera.impala.thrift
+
+include "ExecStats.thrift"
+include "Status.thrift"
+include "Types.thrift"
+include "beeswax.thrift"
+include "TCLIService.thrift"
+
+// ImpalaService accepts query execution options through beeswax.Query.configuration in
+// key:value form. For example, the list of strings could be:
+// "num_nodes:1", "abort_on_error:false"
+// The valid keys are listed in this enum. They map to TQueryOptions.
+// Note: If you add an option or change the default, you also need to update:
+// - ImpalaInternalService.thrift: TQueryOptions
+// - SetQueryOption(), SetQueryOptions()
+// - TQueryOptionsToMap()
+enum TImpalaQueryOptions {
+ // if true, abort execution on the first error
+ ABORT_ON_ERROR,
+
+ // maximum # of errors to be reported; Unspecified or 0 indicates backend default
+ MAX_ERRORS,
+
+ // if true, disable llvm codegen
+ DISABLE_CODEGEN,
+
+ // batch size to be used by backend; Unspecified or a size of 0 indicates backend
+ // default
+ BATCH_SIZE,
+
+ // a per-machine approximate limit on the memory consumption of this query;
+ // unspecified or a limit of 0 means no limit;
+ // otherwise specified either as:
+ // a) an int (= number of bytes);
+ // b) a float followed by "M" (MB) or "G" (GB)
+ MEM_LIMIT,
+
+ // specifies the degree of parallelism with which to execute the query;
+ // 1: single-node execution
+ // NUM_NODES_ALL: executes on all nodes that contain relevant data
+ // NUM_NODES_ALL_RACKS: executes on one node per rack that holds relevant data
+ // > 1: executes on at most that many nodes at any point in time (ie, there can be
+ // more nodes than numNodes with plan fragments for this query, but at most
+ // numNodes would be active at any point in time)
+ // Constants (NUM_NODES_ALL, NUM_NODES_ALL_RACKS) are defined in JavaConstants.thrift.
+ NUM_NODES,
+
+ // maximum length of the scan range; only applicable to HDFS scan range; Unspecified or
+ // a length of 0 indicates backend default;
+ MAX_SCAN_RANGE_LENGTH,
+
+ // Maximum number of io buffers (per disk)
+ MAX_IO_BUFFERS,
+
+ // Number of scanner threads.
+ NUM_SCANNER_THREADS,
+
+ // If true, Impala will try to execute on file formats that are not fully supported yet
+ ALLOW_UNSUPPORTED_FORMATS,
+
+ // if set and > -1, specifies the default limit applied to a top-level SELECT statement
+ // with an ORDER BY but without a LIMIT clause (ie, if the SELECT statement also has
+ // a LIMIT clause, this default is ignored)
+ DEFAULT_ORDER_BY_LIMIT,
+
+ // DEBUG ONLY:
+ // If set to
+ // "[<backend number>:]<node id>:<TExecNodePhase>:<TDebugAction>",
+ // the exec node with the given id will perform the specified action in the given
+ // phase. If the optional backend number (starting from 0) is specified, only that
+ // backend instance will perform the debug action, otherwise all backends will behave
+ // in that way.
+ // If the string doesn't have the required format or if any of its components is
+ // invalid, the option is ignored.
+ DEBUG_ACTION,
+
+ // If true, raise an error when the DEFAULT_ORDER_BY_LIMIT has been reached.
+ ABORT_ON_DEFAULT_LIMIT_EXCEEDED,
+
+ // Compression codec when inserting into tables.
+ // Valid values are "snappy", "gzip", "bzip2" and "none"
+ // Leave blank to use default.
+ COMPRESSION_CODEC,
+
+ // Mode for compressing sequence files; either BLOCK, RECORD, or DEFAULT
+ SEQ_COMPRESSION_MODE,
+
+ // HBase scan query option. If set and > 0, HBASE_CACHING is the value for
+ // "hbase.client.Scan.setCaching()" when querying HBase table. Otherwise, use backend
+ // default.
+ // If the value is too high, then the hbase region server will have a hard time (GC
+ // pressure and long response times). If the value is too small, then there will be
+ // extra trips to the hbase region server.
+ HBASE_CACHING,
+
+ // HBase scan query option. If set, HBase scan will always set
+ // "hbase.client.setCacheBlocks" to CACHE_BLOCKS. Default is false.
+ // If the table is large and the query is doing big scan, set it to false to
+ // avoid polluting the cache in the hbase region server.
+ // If the table is small and the table is used several time, set it to true to improve
+ // performance.
+ HBASE_CACHE_BLOCKS,
+
+ // Target file size for inserts into parquet tables. 0 uses the default.
+ PARQUET_FILE_SIZE,
+
+ // Level of detail for explain output (NORMAL, VERBOSE).
+ EXPLAIN_LEVEL,
+
+ // If true, waits for the result of all catalog operations to be processed by all
+ // active impalad in the cluster before completing.
+ SYNC_DDL,
+
+ // Request pool this request should be submitted to. If not set
+ // the pool is determined based on the user.
+ REQUEST_POOL,
+
+ // Per-host virtual CPU cores required for query (only relevant with RM).
+ V_CPU_CORES,
+
+ // Max time in milliseconds the resource broker should wait for
+ // a resource request to be granted by Llama/Yarn (only relevant with RM).
+ RESERVATION_REQUEST_TIMEOUT,
+
+ // if true, disables cached reads. This option has no effect if REPLICA_PREFERENCE is
+ // configured.
+ // TODO: Retire in C6
+ DISABLE_CACHED_READS,
+
+ // Temporary testing flag
+ DISABLE_OUTERMOST_TOPN,
+
+ // Size of initial memory reservation when RM is enabled
+ RM_INITIAL_MEM,
+
+ // Time, in s, before a query will be timed out if it is inactive. May not exceed
+ // --idle_query_timeout if that flag > 0.
+ QUERY_TIMEOUT_S,
+
+ // Test hook for spill to disk operators
+ MAX_BLOCK_MGR_MEMORY,
+
+ // Transforms all count(distinct) aggregations into NDV()
+ APPX_COUNT_DISTINCT,
+
+ // If true, allows Impala to internally disable spilling for potentially
+ // disastrous query plans. Impala will exercise this option if a query
+ // has no plan hints, and at least one table is missing relevant stats.
+ DISABLE_UNSAFE_SPILLS,
+
+ // If the number of rows that are processed for a single query is below the
+ // threshold, it will be executed on the coordinator only with codegen disabled
+ EXEC_SINGLE_NODE_ROWS_THRESHOLD,
+
+ // If true, use the table's metadata to produce the partition columns instead of table
+ // scans whenever possible. This option is opt-in by default as this optimization may
+ // produce different results than the scan based approach in some edge cases.
+ OPTIMIZE_PARTITION_KEY_SCANS,
+
+ // Preferred memory distance of replicas. This parameter determines the pool of replicas
+ // among which scans will be scheduled in terms of the distance of the replica storage
+ // from the impalad.
+ REPLICA_PREFERENCE,
+
+ // Determines tie breaking policy when picking locations.
+ RANDOM_REPLICA,
+
+ // For scan nodes with any conjuncts, use codegen to evaluate the conjuncts if
+ // the number of rows * number of operators in the conjuncts exceeds this threshold.
+ SCAN_NODE_CODEGEN_THRESHOLD,
+
+ // If true, the planner will not generate plans with streaming preaggregations.
+ DISABLE_STREAMING_PREAGGREGATIONS,
+
+ RUNTIME_FILTER_MODE,
+
+ // Size (in bytes) of a runtime Bloom Filter. Will be rounded up to nearest power of
+ // two.
+ RUNTIME_BLOOM_FILTER_SIZE,
+
+ // Time (in ms) to wait in scans for partition filters to arrive.
+ RUNTIME_FILTER_WAIT_TIME_MS,
+
+ // If true, disable application of runtime filters to individual rows.
+ DISABLE_ROW_RUNTIME_FILTERING,
+
+ // Maximum number of runtime filters allowed per query.
+ MAX_NUM_RUNTIME_FILTERS
+}
+
+// The summary of an insert.
+struct TInsertResult {
+ // Number of appended rows per modified partition. Only applies to HDFS tables.
+ // The keys represent partitions to create, coded as k1=v1/k2=v2/k3=v3..., with the
+ // root in an unpartitioned table being the empty string.
+ 1: required map<string, i64> rows_appended
+}
+
+// Response from a call to PingImpalaService
+struct TPingImpalaServiceResp {
+ // The Impala service's version string.
+ 1: string version
+}
+
+// Parameters for a ResetTable request which will invalidate a table's metadata.
+// DEPRECATED.
+struct TResetTableReq {
+ // Name of the table's parent database.
+ 1: required string db_name
+
+ // Name of the table.
+ 2: required string table_name
+}
+
+// For all rpc that return a TStatus as part of their result type,
+// if the status_code field is set to anything other than OK, the contents
+// of the remainder of the result type is undefined (typically not set)
+service ImpalaService extends beeswax.BeeswaxService {
+ // Cancel execution of query. Returns RUNTIME_ERROR if query_id
+ // unknown.
+ // This terminates all threads running on behalf of this query at
+ // all nodes that were involved in the execution.
+ // Throws BeeswaxException if the query handle is invalid (this doesn't
+ // necessarily indicate an error: the query might have finished).
+ Status.TStatus Cancel(1:beeswax.QueryHandle query_id)
+ throws(1:beeswax.BeeswaxException error);
+
+ // Invalidates all catalog metadata, forcing a reload
+ // DEPRECATED; execute query "invalidate metadata" to refresh metadata
+ Status.TStatus ResetCatalog();
+
+ // Invalidates a specific table's catalog metadata, forcing a reload on the next access
+ // DEPRECATED; execute query "refresh <table>" to refresh metadata
+ Status.TStatus ResetTable(1:TResetTableReq request)
+
+ // Returns the runtime profile string for the given query handle.
+ string GetRuntimeProfile(1:beeswax.QueryHandle query_id)
+ throws(1:beeswax.BeeswaxException error);
+
+ // Closes the query handle and return the result summary of the insert.
+ TInsertResult CloseInsert(1:beeswax.QueryHandle handle)
+ throws(1:beeswax.QueryNotFoundException error, 2:beeswax.BeeswaxException error2);
+
+ // Client calls this RPC to verify that the server is an ImpalaService. Returns the
+ // server version.
+ TPingImpalaServiceResp PingImpalaService();
+
+ // Returns the summary of the current execution.
+ ExecStats.TExecSummary GetExecSummary(1:beeswax.QueryHandle handle)
+ throws(1:beeswax.QueryNotFoundException error, 2:beeswax.BeeswaxException error2);
+}
+
+// Impala HiveServer2 service
+
+struct TGetExecSummaryReq {
+ 1: optional TCLIService.TOperationHandle operationHandle
+
+ 2: optional TCLIService.TSessionHandle sessionHandle
+}
+
+struct TGetExecSummaryResp {
+ 1: required TCLIService.TStatus status
+
+ 2: optional ExecStats.TExecSummary summary
+}
+
+struct TGetRuntimeProfileReq {
+ 1: optional TCLIService.TOperationHandle operationHandle
+
+ 2: optional TCLIService.TSessionHandle sessionHandle
+}
+
+struct TGetRuntimeProfileResp {
+ 1: required TCLIService.TStatus status
+
+ 2: optional string profile
+}
+
+service ImpalaHiveServer2Service extends TCLIService.TCLIService {
+ // Returns the exec summary for the given query
+ TGetExecSummaryResp GetExecSummary(1:TGetExecSummaryReq req);
+
+ // Returns the runtime profile string for the given query
+ TGetRuntimeProfileResp GetRuntimeProfile(1:TGetRuntimeProfileReq req);
+}
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/Status.thrift b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/Status.thrift
new file mode 100644
index 000000000..db9518e02
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/Status.thrift
@@ -0,0 +1,23 @@
+// Copyright 2012 Cloudera Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+include "ErrorCodes.thrift"
+
+namespace cpp impala
+namespace java com.cloudera.impala.thrift
+
+struct TStatus {
+ 1: required ErrorCodes.TErrorCode status_code
+ 2: list<string> error_msgs
+} \ No newline at end of file
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/TCLIService.thrift b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/TCLIService.thrift
new file mode 100644
index 000000000..e0d74c53a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/TCLIService.thrift
@@ -0,0 +1,1180 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Coding Conventions for this file:
+//
+// Structs/Enums/Unions
+// * Struct, Enum, and Union names begin with a "T",
+// and use a capital letter for each new word, with no underscores.
+// * All fields should be declared as either optional or required.
+//
+// Functions
+// * Function names start with a capital letter and have a capital letter for
+// each new word, with no underscores.
+// * Each function should take exactly one parameter, named TFunctionNameReq,
+// and should return either void or TFunctionNameResp. This convention allows
+// incremental updates.
+//
+// Services
+// * Service names begin with the letter "T", use a capital letter for each
+// new word (with no underscores), and end with the word "Service".
+
+namespace java org.apache.hive.service.cli.thrift
+namespace cpp apache.hive.service.cli.thrift
+
+// List of protocol versions. A new token should be
+// added to the end of this list every time a change is made.
+enum TProtocolVersion {
+ HIVE_CLI_SERVICE_PROTOCOL_V1,
+
+ // V2 adds support for asynchronous execution
+ HIVE_CLI_SERVICE_PROTOCOL_V2
+
+ // V3 add varchar type, primitive type qualifiers
+ HIVE_CLI_SERVICE_PROTOCOL_V3
+
+ // V4 add decimal precision/scale, char type
+ HIVE_CLI_SERVICE_PROTOCOL_V4
+
+ // V5 adds error details when GetOperationStatus returns in error state
+ HIVE_CLI_SERVICE_PROTOCOL_V5
+
+ // V6 uses binary type for binary payload (was string) and uses columnar result set
+ HIVE_CLI_SERVICE_PROTOCOL_V6
+
+ // V7 adds support for delegation token based connection
+ HIVE_CLI_SERVICE_PROTOCOL_V7
+}
+
+enum TTypeId {
+ BOOLEAN_TYPE,
+ TINYINT_TYPE,
+ SMALLINT_TYPE,
+ INT_TYPE,
+ BIGINT_TYPE,
+ FLOAT_TYPE,
+ DOUBLE_TYPE,
+ STRING_TYPE,
+ TIMESTAMP_TYPE,
+ BINARY_TYPE,
+ ARRAY_TYPE,
+ MAP_TYPE,
+ STRUCT_TYPE,
+ UNION_TYPE,
+ USER_DEFINED_TYPE,
+ DECIMAL_TYPE,
+ NULL_TYPE,
+ DATE_TYPE,
+ VARCHAR_TYPE,
+ CHAR_TYPE
+}
+
+const set<TTypeId> PRIMITIVE_TYPES = [
+ TTypeId.BOOLEAN_TYPE,
+ TTypeId.TINYINT_TYPE,
+ TTypeId.SMALLINT_TYPE,
+ TTypeId.INT_TYPE,
+ TTypeId.BIGINT_TYPE,
+ TTypeId.FLOAT_TYPE,
+ TTypeId.DOUBLE_TYPE,
+ TTypeId.STRING_TYPE,
+ TTypeId.TIMESTAMP_TYPE,
+ TTypeId.BINARY_TYPE,
+ TTypeId.DECIMAL_TYPE,
+ TTypeId.NULL_TYPE,
+ TTypeId.DATE_TYPE,
+ TTypeId.VARCHAR_TYPE,
+ TTypeId.CHAR_TYPE
+]
+
+const set<TTypeId> COMPLEX_TYPES = [
+ TTypeId.ARRAY_TYPE
+ TTypeId.MAP_TYPE
+ TTypeId.STRUCT_TYPE
+ TTypeId.UNION_TYPE
+ TTypeId.USER_DEFINED_TYPE
+]
+
+const set<TTypeId> COLLECTION_TYPES = [
+ TTypeId.ARRAY_TYPE
+ TTypeId.MAP_TYPE
+]
+
+const map<TTypeId,string> TYPE_NAMES = {
+ TTypeId.BOOLEAN_TYPE: "BOOLEAN",
+ TTypeId.TINYINT_TYPE: "TINYINT",
+ TTypeId.SMALLINT_TYPE: "SMALLINT",
+ TTypeId.INT_TYPE: "INT",
+ TTypeId.BIGINT_TYPE: "BIGINT",
+ TTypeId.FLOAT_TYPE: "FLOAT",
+ TTypeId.DOUBLE_TYPE: "DOUBLE",
+ TTypeId.STRING_TYPE: "STRING",
+ TTypeId.TIMESTAMP_TYPE: "TIMESTAMP",
+ TTypeId.BINARY_TYPE: "BINARY",
+ TTypeId.ARRAY_TYPE: "ARRAY",
+ TTypeId.MAP_TYPE: "MAP",
+ TTypeId.STRUCT_TYPE: "STRUCT",
+ TTypeId.UNION_TYPE: "UNIONTYPE",
+ TTypeId.DECIMAL_TYPE: "DECIMAL",
+ TTypeId.NULL_TYPE: "NULL"
+ TTypeId.DATE_TYPE: "DATE"
+ TTypeId.VARCHAR_TYPE: "VARCHAR"
+ TTypeId.CHAR_TYPE: "CHAR"
+}
+
+// Thrift does not support recursively defined types or forward declarations,
+// which makes it difficult to represent Hive's nested types.
+// To get around these limitations TTypeDesc employs a type list that maps
+// integer "pointers" to TTypeEntry objects. The following examples show
+// how different types are represented using this scheme:
+//
+// "INT":
+// TTypeDesc {
+// types = [
+// TTypeEntry.primitive_entry {
+// type = INT_TYPE
+// }
+// ]
+// }
+//
+// "ARRAY<INT>":
+// TTypeDesc {
+// types = [
+// TTypeEntry.array_entry {
+// object_type_ptr = 1
+// },
+// TTypeEntry.primitive_entry {
+// type = INT_TYPE
+// }
+// ]
+// }
+//
+// "MAP<INT,STRING>":
+// TTypeDesc {
+// types = [
+// TTypeEntry.map_entry {
+// key_type_ptr = 1
+// value_type_ptr = 2
+// },
+// TTypeEntry.primitive_entry {
+// type = INT_TYPE
+// },
+// TTypeEntry.primitive_entry {
+// type = STRING_TYPE
+// }
+// ]
+// }
+
+typedef i32 TTypeEntryPtr
+
+// Valid TTypeQualifiers key names
+const string CHARACTER_MAXIMUM_LENGTH = "characterMaximumLength"
+
+// Type qualifier key name for decimal
+const string PRECISION = "precision"
+const string SCALE = "scale"
+
+union TTypeQualifierValue {
+ 1: optional i32 i32Value
+ 2: optional string stringValue
+}
+
+// Type qualifiers for primitive type.
+struct TTypeQualifiers {
+ 1: required map <string, TTypeQualifierValue> qualifiers
+}
+
+// Type entry for a primitive type.
+struct TPrimitiveTypeEntry {
+ // The primitive type token. This must satisfy the condition
+ // that type is in the PRIMITIVE_TYPES set.
+ 1: required TTypeId type
+ 2: optional TTypeQualifiers typeQualifiers
+}
+
+// Type entry for an ARRAY type.
+struct TArrayTypeEntry {
+ 1: required TTypeEntryPtr objectTypePtr
+}
+
+// Type entry for a MAP type.
+struct TMapTypeEntry {
+ 1: required TTypeEntryPtr keyTypePtr
+ 2: required TTypeEntryPtr valueTypePtr
+}
+
+// Type entry for a STRUCT type.
+struct TStructTypeEntry {
+ 1: required map<string, TTypeEntryPtr> nameToTypePtr
+}
+
+// Type entry for a UNIONTYPE type.
+struct TUnionTypeEntry {
+ 1: required map<string, TTypeEntryPtr> nameToTypePtr
+}
+
+struct TUserDefinedTypeEntry {
+ // The fully qualified name of the class implementing this type.
+ 1: required string typeClassName
+}
+
+// We use a union here since Thrift does not support inheritance.
+union TTypeEntry {
+ 1: TPrimitiveTypeEntry primitiveEntry
+ 2: TArrayTypeEntry arrayEntry
+ 3: TMapTypeEntry mapEntry
+ 4: TStructTypeEntry structEntry
+ 5: TUnionTypeEntry unionEntry
+ 6: TUserDefinedTypeEntry userDefinedTypeEntry
+}
+
+// Type descriptor for columns.
+struct TTypeDesc {
+ // The "top" type is always the first element of the list.
+ // If the top type is an ARRAY, MAP, STRUCT, or UNIONTYPE
+ // type, then subsequent elements represent nested types.
+ 1: required list<TTypeEntry> types
+}
+
+// A result set column descriptor.
+struct TColumnDesc {
+ // The name of the column
+ 1: required string columnName
+
+ // The type descriptor for this column
+ 2: required TTypeDesc typeDesc
+
+ // The ordinal position of this column in the schema
+ 3: required i32 position
+
+ 4: optional string comment
+}
+
+// Metadata used to describe the schema (column names, types, comments)
+// of result sets.
+struct TTableSchema {
+ 1: required list<TColumnDesc> columns
+}
+
+// A Boolean column value.
+struct TBoolValue {
+ // NULL if value is unset.
+ 1: optional bool value
+}
+
+// A Byte column value.
+struct TByteValue {
+ // NULL if value is unset.
+ 1: optional byte value
+}
+
+// A signed, 16 bit column value.
+struct TI16Value {
+ // NULL if value is unset
+ 1: optional i16 value
+}
+
+// A signed, 32 bit column value
+struct TI32Value {
+ // NULL if value is unset
+ 1: optional i32 value
+}
+
+// A signed 64 bit column value
+struct TI64Value {
+ // NULL if value is unset
+ 1: optional i64 value
+}
+
+// A floating point 64 bit column value
+struct TDoubleValue {
+ // NULL if value is unset
+ 1: optional double value
+}
+
+struct TStringValue {
+ // NULL if value is unset
+ 1: optional string value
+}
+
+// A single column value in a result set.
+// Note that Hive's type system is richer than Thrift's,
+// so in some cases we have to map multiple Hive types
+// to the same Thrift type. On the client-side this is
+// disambiguated by looking at the Schema of the
+// result set.
+union TColumnValue {
+ 1: TBoolValue boolVal // BOOLEAN
+ 2: TByteValue byteVal // TINYINT
+ 3: TI16Value i16Val // SMALLINT
+ 4: TI32Value i32Val // INT
+ 5: TI64Value i64Val // BIGINT, TIMESTAMP
+ 6: TDoubleValue doubleVal // FLOAT, DOUBLE
+ 7: TStringValue stringVal // STRING, LIST, MAP, STRUCT, UNIONTYPE, BINARY, DECIMAL, NULL
+}
+
+// Represents a row in a rowset.
+struct TRow {
+ 1: required list<TColumnValue> colVals
+}
+
+struct TBoolColumn {
+ 1: required list<bool> values
+ 2: required binary nulls
+}
+
+struct TByteColumn {
+ 1: required list<byte> values
+ 2: required binary nulls
+}
+
+struct TI16Column {
+ 1: required list<i16> values
+ 2: required binary nulls
+}
+
+struct TI32Column {
+ 1: required list<i32> values
+ 2: required binary nulls
+}
+
+struct TI64Column {
+ 1: required list<i64> values
+ 2: required binary nulls
+}
+
+struct TDoubleColumn {
+ 1: required list<double> values
+ 2: required binary nulls
+}
+
+struct TStringColumn {
+ 1: required list<string> values
+ 2: required binary nulls
+}
+
+struct TBinaryColumn {
+ 1: required list<binary> values
+ 2: required binary nulls
+}
+
+// Note that Hive's type system is richer than Thrift's,
+// so in some cases we have to map multiple Hive types
+// to the same Thrift type. On the client-side this is
+// disambiguated by looking at the Schema of the
+// result set.
+union TColumn {
+ 1: TBoolColumn boolVal // BOOLEAN
+ 2: TByteColumn byteVal // TINYINT
+ 3: TI16Column i16Val // SMALLINT
+ 4: TI32Column i32Val // INT
+ 5: TI64Column i64Val // BIGINT, TIMESTAMP
+ 6: TDoubleColumn doubleVal // FLOAT, DOUBLE
+ 7: TStringColumn stringVal // STRING, LIST, MAP, STRUCT, UNIONTYPE, DECIMAL, NULL
+ 8: TBinaryColumn binaryVal // BINARY
+}
+
+// Represents a rowset
+struct TRowSet {
+ // The starting row offset of this rowset.
+ 1: required i64 startRowOffset
+ 2: required list<TRow> rows
+ 3: optional list<TColumn> columns
+}
+
+// The return status code contained in each response.
+enum TStatusCode {
+ SUCCESS_STATUS,
+ SUCCESS_WITH_INFO_STATUS,
+ STILL_EXECUTING_STATUS,
+ ERROR_STATUS,
+ INVALID_HANDLE_STATUS
+}
+
+// The return status of a remote request
+struct TStatus {
+ 1: required TStatusCode statusCode
+
+ // If status is SUCCESS_WITH_INFO, info_msgs may be populated with
+ // additional diagnostic information.
+ 2: optional list<string> infoMessages
+
+ // If status is ERROR, then the following fields may be set
+ 3: optional string sqlState // as defined in the ISO/IEF CLI specification
+ 4: optional i32 errorCode // internal error code
+ 5: optional string errorMessage
+}
+
+// The state of an operation (i.e. a query or other
+// asynchronous operation that generates a result set)
+// on the server.
+enum TOperationState {
+ // The operation has been initialized
+ INITIALIZED_STATE,
+
+ // The operation is running. In this state the result
+ // set is not available.
+ RUNNING_STATE,
+
+ // The operation has completed. When an operation is in
+ // this state its result set may be fetched.
+ FINISHED_STATE,
+
+ // The operation was canceled by a client
+ CANCELED_STATE,
+
+ // The operation was closed by a client
+ CLOSED_STATE,
+
+ // The operation failed due to an error
+ ERROR_STATE,
+
+ // The operation is in an unrecognized state
+ UKNOWN_STATE,
+
+ // The operation is in a pending state
+ PENDING_STATE,
+}
+
+// A string identifier. This is interpreted literally.
+typedef string TIdentifier
+
+// A search pattern.
+//
+// Valid search pattern characters:
+// '_': Any single character.
+// '%': Any sequence of zero or more characters.
+// '\': Escape character used to include special characters,
+// e.g. '_', '%', '\'. If a '\' precedes a non-special
+// character it has no special meaning and is interpreted
+// literally.
+typedef string TPattern
+
+
+// A search pattern or identifier. Used as input
+// parameter for many of the catalog functions.
+typedef string TPatternOrIdentifier
+
+struct THandleIdentifier {
+ // 16 byte globally unique identifier
+ // This is the public ID of the handle and
+ // can be used for reporting.
+ 1: required binary guid,
+
+ // 16 byte secret generated by the server
+ // and used to verify that the handle is not
+ // being hijacked by another user.
+ 2: required binary secret,
+}
+
+// Client-side handle to persistent
+// session information on the server-side.
+struct TSessionHandle {
+ 1: required THandleIdentifier sessionId
+}
+
+// The subtype of an OperationHandle.
+enum TOperationType {
+ EXECUTE_STATEMENT,
+ GET_TYPE_INFO,
+ GET_CATALOGS,
+ GET_SCHEMAS,
+ GET_TABLES,
+ GET_TABLE_TYPES,
+ GET_COLUMNS,
+ GET_FUNCTIONS,
+ UNKNOWN,
+}
+
+// Client-side reference to a task running
+// asynchronously on the server.
+struct TOperationHandle {
+ 1: required THandleIdentifier operationId
+ 2: required TOperationType operationType
+
+ // If hasResultSet = TRUE, then this operation
+ // generates a result set that can be fetched.
+ // Note that the result set may be empty.
+ //
+ // If hasResultSet = FALSE, then this operation
+ // does not generate a result set, and calling
+ // GetResultSetMetadata or FetchResults against
+ // this OperationHandle will generate an error.
+ 3: required bool hasResultSet
+
+ // For operations that don't generate result sets,
+ // modifiedRowCount is either:
+ //
+ // 1) The number of rows that were modified by
+ // the DML operation (e.g. number of rows inserted,
+ // number of rows deleted, etc).
+ //
+ // 2) 0 for operations that don't modify or add rows.
+ //
+ // 3) < 0 if the operation is capable of modifying rows,
+ // but Hive is unable to determine how many rows were
+ // modified. For example, Hive's LOAD DATA command
+ // doesn't generate row count information because
+ // Hive doesn't inspect the data as it is loaded.
+ //
+ // modifiedRowCount is unset if the operation generates
+ // a result set.
+ 4: optional double modifiedRowCount
+}
+
+
+// OpenSession()
+//
+// Open a session (connection) on the server against
+// which operations may be executed.
+struct TOpenSessionReq {
+ // The version of the HiveServer2 protocol that the client is using.
+ 1: required TProtocolVersion client_protocol = TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6
+
+ // Username and password for authentication.
+ // Depending on the authentication scheme being used,
+ // this information may instead be provided by a lower
+ // protocol layer, in which case these fields may be
+ // left unset.
+ 2: optional string username
+ 3: optional string password
+
+ // Configuration overlay which is applied when the session is
+ // first created.
+ 4: optional map<string, string> configuration
+}
+
+struct TOpenSessionResp {
+ 1: required TStatus status
+
+ // The protocol version that the server is using.
+ 2: required TProtocolVersion serverProtocolVersion = TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6
+
+ // Session Handle
+ 3: optional TSessionHandle sessionHandle
+
+ // The configuration settings for this session.
+ 4: optional map<string, string> configuration
+}
+
+
+// CloseSession()
+//
+// Closes the specified session and frees any resources
+// currently allocated to that session. Any open
+// operations in that session will be canceled.
+struct TCloseSessionReq {
+ 1: required TSessionHandle sessionHandle
+}
+
+struct TCloseSessionResp {
+ 1: required TStatus status
+}
+
+
+
+enum TGetInfoType {
+ CLI_MAX_DRIVER_CONNECTIONS = 0,
+ CLI_MAX_CONCURRENT_ACTIVITIES = 1,
+ CLI_DATA_SOURCE_NAME = 2,
+ CLI_FETCH_DIRECTION = 8,
+ CLI_SERVER_NAME = 13,
+ CLI_SEARCH_PATTERN_ESCAPE = 14,
+ CLI_DBMS_NAME = 17,
+ CLI_DBMS_VER = 18,
+ CLI_ACCESSIBLE_TABLES = 19,
+ CLI_ACCESSIBLE_PROCEDURES = 20,
+ CLI_CURSOR_COMMIT_BEHAVIOR = 23,
+ CLI_DATA_SOURCE_READ_ONLY = 25,
+ CLI_DEFAULT_TXN_ISOLATION = 26,
+ CLI_IDENTIFIER_CASE = 28,
+ CLI_IDENTIFIER_QUOTE_CHAR = 29,
+ CLI_MAX_COLUMN_NAME_LEN = 30,
+ CLI_MAX_CURSOR_NAME_LEN = 31,
+ CLI_MAX_SCHEMA_NAME_LEN = 32,
+ CLI_MAX_CATALOG_NAME_LEN = 34,
+ CLI_MAX_TABLE_NAME_LEN = 35,
+ CLI_SCROLL_CONCURRENCY = 43,
+ CLI_TXN_CAPABLE = 46,
+ CLI_USER_NAME = 47,
+ CLI_TXN_ISOLATION_OPTION = 72,
+ CLI_INTEGRITY = 73,
+ CLI_GETDATA_EXTENSIONS = 81,
+ CLI_NULL_COLLATION = 85,
+ CLI_ALTER_TABLE = 86,
+ CLI_ORDER_BY_COLUMNS_IN_SELECT = 90,
+ CLI_SPECIAL_CHARACTERS = 94,
+ CLI_MAX_COLUMNS_IN_GROUP_BY = 97,
+ CLI_MAX_COLUMNS_IN_INDEX = 98,
+ CLI_MAX_COLUMNS_IN_ORDER_BY = 99,
+ CLI_MAX_COLUMNS_IN_SELECT = 100,
+ CLI_MAX_COLUMNS_IN_TABLE = 101,
+ CLI_MAX_INDEX_SIZE = 102,
+ CLI_MAX_ROW_SIZE = 104,
+ CLI_MAX_STATEMENT_LEN = 105,
+ CLI_MAX_TABLES_IN_SELECT = 106,
+ CLI_MAX_USER_NAME_LEN = 107,
+ CLI_OJ_CAPABILITIES = 115,
+
+ CLI_XOPEN_CLI_YEAR = 10000,
+ CLI_CURSOR_SENSITIVITY = 10001,
+ CLI_DESCRIBE_PARAMETER = 10002,
+ CLI_CATALOG_NAME = 10003,
+ CLI_COLLATION_SEQ = 10004,
+ CLI_MAX_IDENTIFIER_LEN = 10005,
+}
+
+union TGetInfoValue {
+ 1: string stringValue
+ 2: i16 smallIntValue
+ 3: i32 integerBitmask
+ 4: i32 integerFlag
+ 5: i32 binaryValue
+ 6: i64 lenValue
+}
+
+// GetInfo()
+//
+// This function is based on ODBC's CLIGetInfo() function.
+// The function returns general information about the data source
+// using the same keys as ODBC.
+struct TGetInfoReq {
+ // The session to run this request against
+ 1: required TSessionHandle sessionHandle
+
+ 2: required TGetInfoType infoType
+}
+
+struct TGetInfoResp {
+ 1: required TStatus status
+
+ 2: required TGetInfoValue infoValue
+}
+
+
+// ExecuteStatement()
+//
+// Execute a statement.
+// The returned OperationHandle can be used to check on the
+// status of the statement, and to fetch results once the
+// statement has finished executing.
+struct TExecuteStatementReq {
+ // The session to execute the statement against
+ 1: required TSessionHandle sessionHandle
+
+ // The statement to be executed (DML, DDL, SET, etc)
+ 2: required string statement
+
+ // Configuration properties that are overlaid on top of the
+ // the existing session configuration before this statement
+ // is executed. These properties apply to this statement
+ // only and will not affect the subsequent state of the Session.
+ 3: optional map<string, string> confOverlay
+
+ // Execute asynchronously when runAsync is true
+ 4: optional bool runAsync = false
+}
+
+struct TExecuteStatementResp {
+ 1: required TStatus status
+ 2: optional TOperationHandle operationHandle
+}
+
+// GetTypeInfo()
+//
+// Get information about types supported by the HiveServer instance.
+// The information is returned as a result set which can be fetched
+// using the OperationHandle provided in the response.
+//
+// Refer to the documentation for ODBC's CLIGetTypeInfo function for
+// the format of the result set.
+struct TGetTypeInfoReq {
+ // The session to run this request against.
+ 1: required TSessionHandle sessionHandle
+}
+
+struct TGetTypeInfoResp {
+ 1: required TStatus status
+ 2: optional TOperationHandle operationHandle
+}
+
+
+// GetCatalogs()
+//
+// Returns the list of catalogs (databases)
+// Results are ordered by TABLE_CATALOG
+//
+// Resultset columns :
+// col1
+// name: TABLE_CAT
+// type: STRING
+// desc: Catalog name. NULL if not applicable.
+//
+struct TGetCatalogsReq {
+ // Session to run this request against
+ 1: required TSessionHandle sessionHandle
+}
+
+struct TGetCatalogsResp {
+ 1: required TStatus status
+ 2: optional TOperationHandle operationHandle
+}
+
+
+// GetSchemas()
+//
+// Retrieves the schema names available in this database.
+// The results are ordered by TABLE_CATALOG and TABLE_SCHEM.
+// col1
+// name: TABLE_SCHEM
+// type: STRING
+// desc: schema name
+// col2
+// name: TABLE_CATALOG
+// type: STRING
+// desc: catalog name
+struct TGetSchemasReq {
+ // Session to run this request against
+ 1: required TSessionHandle sessionHandle
+
+ // Name of the catalog. Must not contain a search pattern.
+ 2: optional TIdentifier catalogName
+
+ // schema name or pattern
+ 3: optional TPatternOrIdentifier schemaName
+}
+
+struct TGetSchemasResp {
+ 1: required TStatus status
+ 2: optional TOperationHandle operationHandle
+}
+
+
+// GetTables()
+//
+// Returns a list of tables with catalog, schema, and table
+// type information. The information is returned as a result
+// set which can be fetched using the OperationHandle
+// provided in the response.
+// Results are ordered by TABLE_TYPE, TABLE_CAT, TABLE_SCHEM, and TABLE_NAME
+//
+// Result Set Columns:
+//
+// col1
+// name: TABLE_CAT
+// type: STRING
+// desc: Catalog name. NULL if not applicable.
+//
+// col2
+// name: TABLE_SCHEM
+// type: STRING
+// desc: Schema name.
+//
+// col3
+// name: TABLE_NAME
+// type: STRING
+// desc: Table name.
+//
+// col4
+// name: TABLE_TYPE
+// type: STRING
+// desc: The table type, e.g. "TABLE", "VIEW", etc.
+//
+// col5
+// name: REMARKS
+// type: STRING
+// desc: Comments about the table
+//
+struct TGetTablesReq {
+ // Session to run this request against
+ 1: required TSessionHandle sessionHandle
+
+ // Name of the catalog or a search pattern.
+ 2: optional TPatternOrIdentifier catalogName
+
+ // Name of the schema or a search pattern.
+ 3: optional TPatternOrIdentifier schemaName
+
+ // Name of the table or a search pattern.
+ 4: optional TPatternOrIdentifier tableName
+
+ // List of table types to match
+ // e.g. "TABLE", "VIEW", "SYSTEM TABLE", "GLOBAL TEMPORARY",
+ // "LOCAL TEMPORARY", "ALIAS", "SYNONYM", etc.
+ 5: optional list<string> tableTypes
+}
+
+struct TGetTablesResp {
+ 1: required TStatus status
+ 2: optional TOperationHandle operationHandle
+}
+
+
+// GetTableTypes()
+//
+// Returns the table types available in this database.
+// The results are ordered by table type.
+//
+// col1
+// name: TABLE_TYPE
+// type: STRING
+// desc: Table type name.
+struct TGetTableTypesReq {
+ // Session to run this request against
+ 1: required TSessionHandle sessionHandle
+}
+
+struct TGetTableTypesResp {
+ 1: required TStatus status
+ 2: optional TOperationHandle operationHandle
+}
+
+
+// GetColumns()
+//
+// Returns a list of columns in the specified tables.
+// The information is returned as a result set which can be fetched
+// using the OperationHandle provided in the response.
+// Results are ordered by TABLE_CAT, TABLE_SCHEM, TABLE_NAME,
+// and ORDINAL_POSITION.
+//
+// Result Set Columns are the same as those for the ODBC CLIColumns
+// function.
+//
+struct TGetColumnsReq {
+ // Session to run this request against
+ 1: required TSessionHandle sessionHandle
+
+ // Name of the catalog. Must not contain a search pattern.
+ 2: optional TIdentifier catalogName
+
+ // Schema name or search pattern
+ 3: optional TPatternOrIdentifier schemaName
+
+ // Table name or search pattern
+ 4: optional TPatternOrIdentifier tableName
+
+ // Column name or search pattern
+ 5: optional TPatternOrIdentifier columnName
+}
+
+struct TGetColumnsResp {
+ 1: required TStatus status
+ 2: optional TOperationHandle operationHandle
+}
+
+
+// GetFunctions()
+//
+// Returns a list of functions supported by the data source. The
+// behavior of this function matches
+// java.sql.DatabaseMetaData.getFunctions() both in terms of
+// inputs and outputs.
+//
+// Result Set Columns:
+//
+// col1
+// name: FUNCTION_CAT
+// type: STRING
+// desc: Function catalog (may be null)
+//
+// col2
+// name: FUNCTION_SCHEM
+// type: STRING
+// desc: Function schema (may be null)
+//
+// col3
+// name: FUNCTION_NAME
+// type: STRING
+// desc: Function name. This is the name used to invoke the function.
+//
+// col4
+// name: REMARKS
+// type: STRING
+// desc: Explanatory comment on the function.
+//
+// col5
+// name: FUNCTION_TYPE
+// type: SMALLINT
+// desc: Kind of function. One of:
+// * functionResultUnknown - Cannot determine if a return value or a table
+// will be returned.
+// * functionNoTable - Does not a return a table.
+// * functionReturnsTable - Returns a table.
+//
+// col6
+// name: SPECIFIC_NAME
+// type: STRING
+// desc: The name which uniquely identifies this function within its schema.
+// In this case this is the fully qualified class name of the class
+// that implements this function.
+//
+struct TGetFunctionsReq {
+ // Session to run this request against
+ 1: required TSessionHandle sessionHandle
+
+ // A catalog name; must match the catalog name as it is stored in the
+ // database; "" retrieves those without a catalog; null means
+ // that the catalog name should not be used to narrow the search.
+ 2: optional TIdentifier catalogName
+
+ // A schema name pattern; must match the schema name as it is stored
+ // in the database; "" retrieves those without a schema; null means
+ // that the schema name should not be used to narrow the search.
+ 3: optional TPatternOrIdentifier schemaName
+
+ // A function name pattern; must match the function name as it is stored
+ // in the database.
+ 4: required TPatternOrIdentifier functionName
+}
+
+struct TGetFunctionsResp {
+ 1: required TStatus status
+ 2: optional TOperationHandle operationHandle
+}
+
+
+// GetOperationStatus()
+//
+// Get the status of an operation running on the server.
+struct TGetOperationStatusReq {
+ // Session to run this request against
+ 1: required TOperationHandle operationHandle
+}
+
+struct TGetOperationStatusResp {
+ 1: required TStatus status
+ 2: optional TOperationState operationState
+
+ // If operationState is ERROR_STATE, then the following fields may be set
+ // sqlState as defined in the ISO/IEF CLI specification
+ 3: optional string sqlState
+
+ // Internal error code
+ 4: optional i32 errorCode
+
+ // Error message
+ 5: optional string errorMessage
+}
+
+
+// CancelOperation()
+//
+// Cancels processing on the specified operation handle and
+// frees any resources which were allocated.
+struct TCancelOperationReq {
+ // Operation to cancel
+ 1: required TOperationHandle operationHandle
+}
+
+struct TCancelOperationResp {
+ 1: required TStatus status
+}
+
+
+// CloseOperation()
+//
+// Given an operation in the FINISHED, CANCELED,
+// or ERROR states, CloseOperation() will free
+// all of the resources which were allocated on
+// the server to service the operation.
+struct TCloseOperationReq {
+ 1: required TOperationHandle operationHandle
+}
+
+struct TCloseOperationResp {
+ 1: required TStatus status
+}
+
+
+// GetResultSetMetadata()
+//
+// Retrieves schema information for the specified operation
+struct TGetResultSetMetadataReq {
+ // Operation for which to fetch result set schema information
+ 1: required TOperationHandle operationHandle
+}
+
+struct TGetResultSetMetadataResp {
+ 1: required TStatus status
+ 2: optional TTableSchema schema
+}
+
+
+enum TFetchOrientation {
+ // Get the next rowset. The fetch offset is ignored.
+ FETCH_NEXT,
+
+ // Get the previous rowset. The fetch offset is ignored.
+ // NOT SUPPORTED
+ FETCH_PRIOR,
+
+ // Return the rowset at the given fetch offset relative
+ // to the curren rowset.
+ // NOT SUPPORTED
+ FETCH_RELATIVE,
+
+ // Return the rowset at the specified fetch offset.
+ // NOT SUPPORTED
+ FETCH_ABSOLUTE,
+
+ // Get the first rowset in the result set.
+ FETCH_FIRST,
+
+ // Get the last rowset in the result set.
+ // NOT SUPPORTED
+ FETCH_LAST
+}
+
+// FetchResults()
+//
+// Fetch rows from the server corresponding to
+// a particular OperationHandle.
+struct TFetchResultsReq {
+ // Operation from which to fetch results.
+ 1: required TOperationHandle operationHandle
+
+ // The fetch orientation. For V1 this must be either
+ // FETCH_NEXT or FETCH_FIRST. Defaults to FETCH_NEXT.
+ 2: required TFetchOrientation orientation = TFetchOrientation.FETCH_NEXT
+
+ // Max number of rows that should be returned in
+ // the rowset.
+ 3: required i64 maxRows
+}
+
+struct TFetchResultsResp {
+ 1: required TStatus status
+
+ // TRUE if there are more rows left to fetch from the server.
+ 2: optional bool hasMoreRows
+
+ // The rowset. This is optional so that we have the
+ // option in the future of adding alternate formats for
+ // representing result set data, e.g. delimited strings,
+ // binary encoded, etc.
+ 3: optional TRowSet results
+}
+
+// GetDelegationToken()
+// Retrieve delegation token for the current user
+struct TGetDelegationTokenReq {
+ // session handle
+ 1: required TSessionHandle sessionHandle
+
+ // userid for the proxy user
+ 2: required string owner
+
+ // designated renewer userid
+ 3: required string renewer
+}
+
+struct TGetDelegationTokenResp {
+ // status of the request
+ 1: required TStatus status
+
+ // delegation token string
+ 2: optional string delegationToken
+}
+
+// CancelDelegationToken()
+// Cancel the given delegation token
+struct TCancelDelegationTokenReq {
+ // session handle
+ 1: required TSessionHandle sessionHandle
+
+ // delegation token to cancel
+ 2: required string delegationToken
+}
+
+struct TCancelDelegationTokenResp {
+ // status of the request
+ 1: required TStatus status
+}
+
+// RenewDelegationToken()
+// Renew the given delegation token
+struct TRenewDelegationTokenReq {
+ // session handle
+ 1: required TSessionHandle sessionHandle
+
+ // delegation token to renew
+ 2: required string delegationToken
+}
+
+struct TRenewDelegationTokenResp {
+ // status of the request
+ 1: required TStatus status
+}
+
+// GetLog()
+// Not present in Hive 0.13, re-added for backwards compatibility.
+//
+// Fetch operation log from the server corresponding to
+// a particular OperationHandle.
+struct TGetLogReq {
+ // Operation whose log is requested
+ 1: required TOperationHandle operationHandle
+}
+
+struct TGetLogResp {
+ 1: required TStatus status
+ 2: required string log
+}
+
+service TCLIService {
+
+ TOpenSessionResp OpenSession(1:TOpenSessionReq req);
+
+ TCloseSessionResp CloseSession(1:TCloseSessionReq req);
+
+ TGetInfoResp GetInfo(1:TGetInfoReq req);
+
+ TExecuteStatementResp ExecuteStatement(1:TExecuteStatementReq req);
+
+ TGetTypeInfoResp GetTypeInfo(1:TGetTypeInfoReq req);
+
+ TGetCatalogsResp GetCatalogs(1:TGetCatalogsReq req);
+
+ TGetSchemasResp GetSchemas(1:TGetSchemasReq req);
+
+ TGetTablesResp GetTables(1:TGetTablesReq req);
+
+ TGetTableTypesResp GetTableTypes(1:TGetTableTypesReq req);
+
+ TGetColumnsResp GetColumns(1:TGetColumnsReq req);
+
+ TGetFunctionsResp GetFunctions(1:TGetFunctionsReq req);
+
+ TGetOperationStatusResp GetOperationStatus(1:TGetOperationStatusReq req);
+
+ TCancelOperationResp CancelOperation(1:TCancelOperationReq req);
+
+ TCloseOperationResp CloseOperation(1:TCloseOperationReq req);
+
+ TGetResultSetMetadataResp GetResultSetMetadata(1:TGetResultSetMetadataReq req);
+
+ TFetchResultsResp FetchResults(1:TFetchResultsReq req);
+
+ TGetDelegationTokenResp GetDelegationToken(1:TGetDelegationTokenReq req);
+
+ TCancelDelegationTokenResp CancelDelegationToken(1:TCancelDelegationTokenReq req);
+
+ TRenewDelegationTokenResp RenewDelegationToken(1:TRenewDelegationTokenReq req);
+
+ // Not present in Hive 0.13, re-added for backwards compatibility.
+ TGetLogResp GetLog(1:TGetLogReq req);
+}
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/Types.thrift b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/Types.thrift
new file mode 100644
index 000000000..39ae6d0ba
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/Types.thrift
@@ -0,0 +1,218 @@
+// Copyright 2012 Cloudera Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+namespace cpp impala
+namespace java com.cloudera.impala.thrift
+
+typedef i64 TTimestamp
+typedef i32 TPlanNodeId
+typedef i32 TTupleId
+typedef i32 TSlotId
+typedef i32 TTableId
+
+// TODO: Consider moving unrelated enums to better locations.
+
+enum TPrimitiveType {
+ INVALID_TYPE,
+ NULL_TYPE,
+ BOOLEAN,
+ TINYINT,
+ SMALLINT,
+ INT,
+ BIGINT,
+ FLOAT,
+ DOUBLE,
+ DATE,
+ DATETIME,
+ TIMESTAMP,
+ STRING,
+ // Unsupported types
+ BINARY,
+ DECIMAL,
+ // CHAR(n). Currently only supported in UDAs
+ CHAR,
+ VARCHAR
+}
+
+enum TTypeNodeType {
+ SCALAR,
+ ARRAY,
+ MAP,
+ STRUCT
+}
+
+struct TScalarType {
+ 1: required TPrimitiveType type
+
+ // Only set if type == CHAR or type == VARCHAR
+ 2: optional i32 len
+
+ // Only set for DECIMAL
+ 3: optional i32 precision
+ 4: optional i32 scale
+}
+
+// Represents a field in a STRUCT type.
+// TODO: Model column stats for struct fields.
+struct TStructField {
+ 1: required string name
+ 2: optional string comment
+}
+
+struct TTypeNode {
+ 1: required TTypeNodeType type
+
+ // only set for scalar types
+ 2: optional TScalarType scalar_type
+
+ // only used for structs; has struct_fields.size() corresponding child types
+ 3: optional list<TStructField> struct_fields
+}
+
+// A flattened representation of a tree of column types obtained by depth-first
+// traversal. Complex types such as map, array and struct have child types corresponding
+// to the map key/value, array item type, and struct fields, respectively.
+// For scalar types the list contains only a single node.
+// Note: We cannot rename this to TType because it conflicts with Thrift's internal TType
+// and the generated Python thrift files will not work.
+struct TColumnType {
+ 1: list<TTypeNode> types
+}
+
+enum TStmtType {
+ QUERY,
+ DDL, // Data definition, e.g. CREATE TABLE (includes read-only functions e.g. SHOW)
+ DML, // Data modification e.g. INSERT
+ EXPLAIN,
+ LOAD, // Statement type for LOAD commands
+ SET
+}
+
+// Level of verboseness for "explain" output.
+enum TExplainLevel {
+ MINIMAL,
+ STANDARD,
+ EXTENDED,
+ VERBOSE
+}
+
+enum TRuntimeFilterMode {
+ // No filters are computed in the FE or the BE.
+ OFF,
+
+ // Only broadcast filters are computed in the BE, and are only published to the local
+ // fragment.
+ LOCAL,
+
+ // All filters are computed in the BE, and are published globally.
+ GLOBAL
+}
+
+// A TNetworkAddress is the standard host, port representation of a
+// network address. The hostname field must be resolvable to an IPv4
+// address.
+struct TNetworkAddress {
+ 1: required string hostname
+ 2: required i32 port
+}
+
+// Wire format for UniqueId
+struct TUniqueId {
+ 1: required i64 hi
+ 2: required i64 lo
+}
+
+enum TFunctionCategory {
+ SCALAR,
+ AGGREGATE,
+ ANALYTIC
+}
+
+enum TFunctionBinaryType {
+ // Impala builtin. We can either run this interpreted or via codegen
+ // depending on the query option.
+ BUILTIN,
+
+ // Java UDFs, loaded from *.jar
+ JAVA,
+
+ // Native-interface, precompiled UDFs loaded from *.so
+ NATIVE,
+
+ // Native-interface, precompiled to IR; loaded from *.ll
+ IR,
+}
+
+// Represents a fully qualified function name.
+struct TFunctionName {
+ // Name of the function's parent database. Not set if in global
+ // namespace (e.g. builtins)
+ 1: optional string db_name
+
+ // Name of the function
+ 2: required string function_name
+}
+
+struct TScalarFunction {
+ 1: required string symbol;
+ 2: optional string prepare_fn_symbol
+ 3: optional string close_fn_symbol
+}
+
+struct TAggregateFunction {
+ 1: required TColumnType intermediate_type
+ 2: required string update_fn_symbol
+ 3: required string init_fn_symbol
+ 4: optional string serialize_fn_symbol
+ 5: optional string merge_fn_symbol
+ 6: optional string finalize_fn_symbol
+ 8: optional string get_value_fn_symbol
+ 9: optional string remove_fn_symbol
+
+ 7: optional bool ignores_distinct
+}
+
+// Represents a function in the Catalog.
+struct TFunction {
+ // Fully qualified function name.
+ 1: required TFunctionName name
+
+ // Type of the udf. e.g. hive, native, ir
+ 2: required TFunctionBinaryType binary_type
+
+ // The types of the arguments to the function
+ 3: required list<TColumnType> arg_types
+
+ // Return type for the function.
+ 4: required TColumnType ret_type
+
+ // If true, this function takes var args.
+ 5: required bool has_var_args
+
+ // Optional comment to attach to the function
+ 6: optional string comment
+
+ 7: optional string signature
+
+ // HDFS path for the function binary. This binary must exist at the time the
+ // function is created.
+ 8: optional string hdfs_location
+
+ // One of these should be set.
+ 9: optional TScalarFunction scalar_fn
+ 10: optional TAggregateFunction aggregate_fn
+
+ // True for builtins or user-defined functions persisted by the catalog
+ 11: required bool is_persistent
+}
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/beeswax.thrift b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/beeswax.thrift
new file mode 100644
index 000000000..a0ca5a746
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/beeswax.thrift
@@ -0,0 +1,174 @@
+/*
+ * Licensed to Cloudera, Inc. under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. Cloudera, Inc. licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Interface for interacting with Beeswax Server
+ */
+
+namespace java com.cloudera.beeswax.api
+namespace py beeswaxd
+namespace cpp beeswax
+
+include "hive_metastore.thrift"
+
+// A Query
+struct Query {
+ 1: string query;
+ // A list of HQL commands to execute before the query.
+ // This is typically defining UDFs, setting settings, and loading resources.
+ 3: list<string> configuration;
+
+ // User and groups to "act as" for purposes of Hadoop.
+ 4: string hadoop_user;
+}
+
+typedef string LogContextId
+
+enum QueryState {
+ CREATED,
+ INITIALIZED,
+ COMPILED,
+ RUNNING,
+ FINISHED,
+ EXCEPTION
+}
+
+struct QueryHandle {
+ 1: string id;
+ 2: LogContextId log_context;
+}
+
+struct QueryExplanation {
+ 1: string textual
+}
+
+struct Results {
+ // If set, data is valid. Otherwise, results aren't ready yet.
+ 1: bool ready,
+ // Columns for the results
+ 2: list<string> columns,
+ // A set of results
+ 3: list<string> data,
+ // The starting row of the results
+ 4: i64 start_row,
+ // Whether there are more results to fetch
+ 5: bool has_more
+}
+
+/**
+ * Metadata information about the results.
+ * Applicable only for SELECT.
+ */
+struct ResultsMetadata {
+ /** The schema of the results */
+ 1: hive_metastore.Schema schema,
+ /** The directory containing the results. Not applicable for partition table. */
+ 2: string table_dir,
+ /** If the results are straight from an existing table, the table name. */
+ 3: string in_tablename,
+ /** Field delimiter */
+ 4: string delim,
+}
+
+exception BeeswaxException {
+ 1: string message,
+ // Use get_log(log_context) to retrieve any log related to this exception
+ 2: LogContextId log_context,
+ // (Optional) The QueryHandle that caused this exception
+ 3: QueryHandle handle,
+ 4: optional i32 errorCode = 0,
+ 5: optional string SQLState = " "
+}
+
+exception QueryNotFoundException {
+}
+
+/** Represents a Hadoop-style configuration variable. */
+struct ConfigVariable {
+ 1: string key,
+ 2: string value,
+ 3: string description
+}
+
+service BeeswaxService {
+ /**
+ * Submit a query and return a handle (QueryHandle). The query runs asynchronously.
+ */
+ QueryHandle query(1:Query query) throws(1:BeeswaxException error),
+
+ /**
+ * run a query synchronously and return a handle (QueryHandle).
+ */
+ QueryHandle executeAndWait(1:Query query, 2:LogContextId clientCtx)
+ throws(1:BeeswaxException error),
+
+ /**
+ * Get the query plan for a query.
+ */
+ QueryExplanation explain(1:Query query)
+ throws(1:BeeswaxException error),
+
+ /**
+ * Get the results of a query. This is non-blocking. Caller should check
+ * Results.ready to determine if the results are in yet. The call requests
+ * the batch size of fetch.
+ */
+ Results fetch(1:QueryHandle query_id, 2:bool start_over, 3:i32 fetch_size=-1)
+ throws(1:QueryNotFoundException error, 2:BeeswaxException error2),
+
+ /**
+ * Get the state of the query
+ */
+ QueryState get_state(1:QueryHandle handle) throws(1:QueryNotFoundException error),
+
+ /**
+ * Get the result metadata
+ */
+ ResultsMetadata get_results_metadata(1:QueryHandle handle)
+ throws(1:QueryNotFoundException error),
+
+ /**
+ * Used to test connection to server. A "noop" command.
+ */
+ string echo(1:string s)
+
+ /**
+ * Returns a string representation of the configuration object being used.
+ * Handy for debugging.
+ */
+ string dump_config()
+
+ /**
+ * Get the log messages related to the given context.
+ */
+ string get_log(1:LogContextId context) throws(1:QueryNotFoundException error)
+
+ /*
+ * Returns "default" configuration.
+ */
+ list<ConfigVariable> get_default_configuration(1:bool include_hadoop)
+
+ /*
+ * closes the query with given handle
+ */
+ void close(1:QueryHandle handle) throws(1:QueryNotFoundException error,
+ 2:BeeswaxException error2)
+
+ /*
+ * clean the log context for given id
+ */
+ void clean(1:LogContextId log_context)
+}
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/fb303.thrift b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/fb303.thrift
new file mode 100644
index 000000000..66c831527
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/fb303.thrift
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * fb303.thrift
+ */
+
+namespace java com.facebook.fb303
+namespace cpp facebook.fb303
+namespace perl Facebook.FB303
+
+/**
+ * Common status reporting mechanism across all services
+ */
+enum fb_status {
+ DEAD = 0,
+ STARTING = 1,
+ ALIVE = 2,
+ STOPPING = 3,
+ STOPPED = 4,
+ WARNING = 5,
+}
+
+/**
+ * Standard base service
+ */
+service FacebookService {
+
+ /**
+ * Returns a descriptive name of the service
+ */
+ string getName(),
+
+ /**
+ * Returns the version of the service
+ */
+ string getVersion(),
+
+ /**
+ * Gets the status of this service
+ */
+ fb_status getStatus(),
+
+ /**
+ * User friendly description of status, such as why the service is in
+ * the dead or warning state, or what is being started or stopped.
+ */
+ string getStatusDetails(),
+
+ /**
+ * Gets the counters for this service
+ */
+ map<string, i64> getCounters(),
+
+ /**
+ * Gets the value of a single counter
+ */
+ i64 getCounter(1: string key),
+
+ /**
+ * Sets an option
+ */
+ void setOption(1: string key, 2: string value),
+
+ /**
+ * Gets an option
+ */
+ string getOption(1: string key),
+
+ /**
+ * Gets all options
+ */
+ map<string, string> getOptions(),
+
+ /**
+ * Returns a CPU profile over the given time interval (client and server
+ * must agree on the profile format).
+ */
+ string getCpuProfile(1: i32 profileDurationInSec),
+
+ /**
+ * Returns the unix time that the server has been running since
+ */
+ i64 aliveSince(),
+
+ /**
+ * Tell the server to reload its configuration, reopen log files, etc
+ */
+ oneway void reinitialize(),
+
+ /**
+ * Suggest a shutdown to the server
+ */
+ oneway void shutdown(),
+
+}
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/generate_error_codes.py b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/generate_error_codes.py
new file mode 100644
index 000000000..3790057d2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/generate_error_codes.py
@@ -0,0 +1,293 @@
+#!/usr/bin/env python
+# Copyright 2015 Cloudera Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os
+
+
+# For readability purposes we define the error codes and messages at the top of the
+# file. New codes and messages must be added here. Old error messages MUST NEVER BE
+# DELETED, but can be renamed. The tuple layout for a new entry is: error code enum name,
+# numeric error code, format string of the message.
+#
+# TODO Add support for SQL Error Codes
+# https://msdn.microsoft.com/en-us/library/ms714687%28v=vs.85%29.aspx
+error_codes = (
+ ("OK", 0, ""),
+
+ ("UNUSED", 1, "<UNUSED>"),
+
+ ("GENERAL", 2, "$0"),
+
+ ("CANCELLED", 3, "$0"),
+
+ ("ANALYSIS_ERROR", 4, "$0"),
+
+ ("NOT_IMPLEMENTED_ERROR", 5, "$0"),
+
+ ("RUNTIME_ERROR", 6, "$0"),
+
+ ("MEM_LIMIT_EXCEEDED", 7, "$0"),
+
+ ("INTERNAL_ERROR", 8, "$0"),
+
+ ("RECOVERABLE_ERROR", 9, "$0"),
+
+ ("PARQUET_MULTIPLE_BLOCKS", 10,
+ "Parquet files should not be split into multiple hdfs-blocks. file=$0"),
+
+ ("PARQUET_COLUMN_METADATA_INVALID", 11,
+ "Column metadata states there are $0 values, but read $1 values from column $2. "
+ "file=$3"),
+
+ ("PARQUET_HEADER_PAGE_SIZE_EXCEEDED", 12, "(unused)"),
+
+ ("PARQUET_HEADER_EOF", 13,
+ "ParquetScanner: reached EOF while deserializing data page header. file=$0"),
+
+ ("PARQUET_GROUP_ROW_COUNT_ERROR", 14,
+ "Metadata states that in group $0($1) there are $2 rows, but $3 rows were read."),
+
+ ("PARQUET_GROUP_ROW_COUNT_OVERFLOW", 15, "(unused)"),
+
+ ("PARQUET_MISSING_PRECISION", 16,
+ "File '$0' column '$1' does not have the decimal precision set."),
+
+ ("PARQUET_WRONG_PRECISION", 17,
+ "File '$0' column '$1' has a precision that does not match the table metadata "
+ " precision. File metadata precision: $2, table metadata precision: $3."),
+
+ ("PARQUET_BAD_CONVERTED_TYPE", 18,
+ "File '$0' column '$1' does not have converted type set to DECIMAL"),
+
+ ("PARQUET_INCOMPATIBLE_DECIMAL", 19,
+ "File '$0' column '$1' contains decimal data but the table metadata has type $2"),
+
+ ("SEQUENCE_SCANNER_PARSE_ERROR", 20,
+ "Problem parsing file $0 at $1$2"),
+
+ ("SNAPPY_DECOMPRESS_INVALID_BLOCK_SIZE", 21,
+ "Decompressor: block size is too big. Data is likely corrupt. Size: $0"),
+
+ ("SNAPPY_DECOMPRESS_INVALID_COMPRESSED_LENGTH", 22,
+ "Decompressor: invalid compressed length. Data is likely corrupt."),
+
+ ("SNAPPY_DECOMPRESS_UNCOMPRESSED_LENGTH_FAILED", 23,
+ "Snappy: GetUncompressedLength failed"),
+
+ ("SNAPPY_DECOMPRESS_RAW_UNCOMPRESS_FAILED", 24,
+ "SnappyBlock: RawUncompress failed"),
+
+ ("SNAPPY_DECOMPRESS_DECOMPRESS_SIZE_INCORRECT", 25,
+ "Snappy: Decompressed size is not correct."),
+
+ ("HDFS_SCAN_NODE_UNKNOWN_DISK", 26, "Unknown disk id. "
+ "This will negatively affect performance. "
+ "Check your hdfs settings to enable block location metadata."),
+
+ ("FRAGMENT_EXECUTOR", 27, "Reserved resource size ($0) is larger than "
+ "query mem limit ($1), and will be restricted to $1. Configure the reservation "
+ "size by setting RM_INITIAL_MEM."),
+
+ ("PARTITIONED_HASH_JOIN_MAX_PARTITION_DEPTH", 28,
+ "Cannot perform join at hash join node with id $0."
+ " The input data was partitioned the maximum number of $1 times."
+ " This could mean there is significant skew in the data or the memory limit is"
+ " set too low."),
+
+ ("PARTITIONED_AGG_MAX_PARTITION_DEPTH", 29,
+ "Cannot perform aggregation at hash aggregation node with id $0."
+ " The input data was partitioned the maximum number of $1 times."
+ " This could mean there is significant skew in the data or the memory limit is"
+ " set too low."),
+
+ ("MISSING_BUILTIN", 30, "Builtin '$0' with symbol '$1' does not exist. "
+ "Verify that all your impalads are the same version."),
+
+ ("RPC_GENERAL_ERROR", 31, "RPC Error: $0"),
+ ("RPC_TIMEOUT", 32, "RPC timed out"),
+
+ ("UDF_VERIFY_FAILED", 33,
+ "Failed to verify function $0 from LLVM module $1, see log for more details."),
+
+ ("PARQUET_CORRUPT_VALUE", 34, "File $0 corrupt. RLE level data bytes = $1"),
+
+ ("AVRO_DECIMAL_RESOLUTION_ERROR", 35, "Column '$0' has conflicting Avro decimal types. "
+ "Table schema $1: $2, file schema $1: $3"),
+
+ ("AVRO_DECIMAL_METADATA_MISMATCH", 36, "Column '$0' has conflicting Avro decimal types. "
+ "Declared $1: $2, $1 in table's Avro schema: $3"),
+
+ ("AVRO_SCHEMA_RESOLUTION_ERROR", 37, "Unresolvable types for column '$0': "
+ "table type: $1, file type: $2"),
+
+ ("AVRO_SCHEMA_METADATA_MISMATCH", 38, "Unresolvable types for column '$0': "
+ "declared column type: $1, table's Avro schema type: $2"),
+
+ ("AVRO_UNSUPPORTED_DEFAULT_VALUE", 39, "Field $0 is missing from file and default "
+ "values of type $1 are not yet supported."),
+
+ ("AVRO_MISSING_FIELD", 40, "Inconsistent table metadata. Mismatch between column "
+ "definition and Avro schema: cannot read field $0 because there are only $1 fields."),
+
+ ("AVRO_MISSING_DEFAULT", 41,
+ "Field $0 is missing from file and does not have a default value."),
+
+ ("AVRO_NULLABILITY_MISMATCH", 42,
+ "Field $0 is nullable in the file schema but not the table schema."),
+
+ ("AVRO_NOT_A_RECORD", 43,
+ "Inconsistent table metadata. Field $0 is not a record in the Avro schema."),
+
+ ("PARQUET_DEF_LEVEL_ERROR", 44, "Could not read definition level, even though metadata"
+ " states there are $0 values remaining in data page. file=$1"),
+
+ ("PARQUET_NUM_COL_VALS_ERROR", 45, "Mismatched number of values in column index $0 "
+ "($1 vs. $2). file=$3"),
+
+ ("PARQUET_DICT_DECODE_FAILURE", 46, "Failed to decode dictionary-encoded value. "
+ "file=$0"),
+
+ ("SSL_PASSWORD_CMD_FAILED", 47,
+ "SSL private-key password command ('$0') failed with error: $1"),
+
+ ("SSL_CERTIFICATE_PATH_BLANK", 48, "The SSL certificate path is blank"),
+ ("SSL_PRIVATE_KEY_PATH_BLANK", 49, "The SSL private key path is blank"),
+
+ ("SSL_CERTIFICATE_NOT_FOUND", 50, "The SSL certificate file does not exist at path $0"),
+ ("SSL_PRIVATE_KEY_NOT_FOUND", 51, "The SSL private key file does not exist at path $0"),
+
+ ("SSL_SOCKET_CREATION_FAILED", 52, "SSL socket creation failed: $0"),
+
+ ("MEM_ALLOC_FAILED", 53, "Memory allocation of $0 bytes failed"),
+
+ ("PARQUET_REP_LEVEL_ERROR", 54, "Could not read repetition level, even though metadata"
+ " states there are $0 values remaining in data page. file=$1"),
+
+ ("PARQUET_UNRECOGNIZED_SCHEMA", 55, "File '$0' has an incompatible Parquet schema for "
+ "column '$1'. Column type: $2, Parquet schema:\\n$3"),
+
+ ("COLLECTION_ALLOC_FAILED", 56, "Failed to allocate buffer for collection '$0'."),
+
+ ("TMP_DEVICE_BLACKLISTED", 57,
+ "Temporary device for directory $0 is blacklisted from a previous error and cannot "
+ "be used."),
+
+ ("TMP_FILE_BLACKLISTED", 58,
+ "Temporary file $0 is blacklisted from a previous error and cannot be expanded."),
+
+ ("RPC_CLIENT_CONNECT_FAILURE", 59,
+ "RPC client failed to connect: $0"),
+
+ ("STALE_METADATA_FILE_TOO_SHORT", 60, "Metadata for file '$0' appears stale. "
+ "Try running \\\"refresh $1\\\" to reload the file metadata."),
+
+ ("PARQUET_BAD_VERSION_NUMBER", 61, "File '$0' has an invalid version number: $1\\n"
+ "This could be due to stale metadata. Try running \\\"refresh $2\\\"."),
+
+ ("SCANNER_INCOMPLETE_READ", 62, "Tried to read $0 bytes but could only read $1 bytes. "
+ "This may indicate data file corruption. (file $2, byte offset: $3)"),
+
+ ("SCANNER_INVALID_READ", 63, "Invalid read of $0 bytes. This may indicate data file "
+ "corruption. (file $1, byte offset: $2)"),
+
+ ("AVRO_BAD_VERSION_HEADER", 64, "File '$0' has an invalid version header: $1\\n"
+ "Make sure the file is an Avro data file."),
+
+ ("UDF_MEM_LIMIT_EXCEEDED", 65, "$0's allocations exceeded memory limits."),
+
+ ("BTS_BLOCK_OVERFLOW", 66, "Cannot process row that is bigger than the IO size "
+ "(row_size=$0, null_indicators_size=$1). To run this query, increase the IO size "
+ "(--read_size option)."),
+
+ ("COMPRESSED_FILE_MULTIPLE_BLOCKS", 67,
+ "For better performance, snappy-, gzip-, and bzip-compressed files "
+ "should not be split into multiple HDFS blocks. file=$0 offset $1"),
+
+ ("COMPRESSED_FILE_BLOCK_CORRUPTED", 68,
+ "$0 Data error, likely data corrupted in this block."),
+
+ ("COMPRESSED_FILE_DECOMPRESSOR_ERROR", 69, "$0 Decompressor error at $1, code=$2"),
+
+ ("COMPRESSED_FILE_DECOMPRESSOR_NO_PROGRESS", 70,
+ "Decompression failed to make progress, but end of input is not reached. "
+ "File appears corrupted. file=$0"),
+
+ ("COMPRESSED_FILE_TRUNCATED", 71,
+ "Unexpected end of compressed file. File may be truncated. file=$0")
+)
+
+# Verifies the uniqueness of the error constants and numeric error codes.
+# Numeric codes must start from 0, be in order and have no gaps
+def check_duplicates(codes):
+ constants = {}
+ next_num_code = 0
+ for row in codes:
+ if row[0] in constants:
+ print("Constant %s already used, please check definition of '%s'!" % \
+ (row[0], constants[row[0]]))
+ exit(1)
+ if row[1] != next_num_code:
+ print("Numeric error codes must start from 0, be in order, and not have any gaps: "
+ "got %d, expected %d" % (row[1], next_num_code))
+ exit(1)
+ next_num_code += 1
+ constants[row[0]] = row[2]
+
+preamble = """
+// Copyright 2015 Cloudera Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//
+// THIS FILE IS AUTO GENERATED BY generated_error_codes.py DO NOT MODIFY
+// IT BY HAND.
+//
+
+namespace cpp impala
+namespace java com.cloudera.impala.thrift
+
+"""
+# The script will always generate the file, CMake will take care of running it only if
+# necessary.
+target_file = os.path.join(sys.argv[1], "ErrorCodes.thrift")
+
+# Check uniqueness of error constants and numeric codes
+check_duplicates(error_codes)
+
+fid = open(target_file, "w+")
+try:
+ fid.write(preamble)
+ fid.write("""\nenum TErrorCode {\n""")
+ fid.write(",\n".join(map(lambda x: " %s = %d" % (x[0], x[1]), error_codes)))
+ fid.write("\n}")
+ fid.write("\n")
+ fid.write("const list<string> TErrorMessage = [\n")
+ fid.write(",\n".join(map(lambda x: " // %s\n \"%s\"" %(x[0], x[2]), error_codes)))
+ fid.write("\n]")
+finally:
+ fid.close()
+
+print("%s created." % target_file)
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/hive_metastore.thrift b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/hive_metastore.thrift
new file mode 100644
index 000000000..30dae14fc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift/hive_metastore.thrift
@@ -0,0 +1,1214 @@
+#!/usr/local/bin/thrift -java
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#
+# Thrift Service that the MetaStore is built on
+#
+
+include "fb303.thrift"
+
+namespace java org.apache.hadoop.hive.metastore.api
+namespace php metastore
+namespace cpp Apache.Hadoop.Hive
+
+const string DDL_TIME = "transient_lastDdlTime"
+
+struct Version {
+ 1: string version,
+ 2: string comments
+}
+
+struct FieldSchema {
+ 1: string name, // name of the field
+ 2: string type, // type of the field. primitive types defined above, specify list<TYPE_NAME>, map<TYPE_NAME, TYPE_NAME> for lists & maps
+ 3: string comment
+}
+
+struct Type {
+ 1: string name, // one of the types in PrimitiveTypes or CollectionTypes or User defined types
+ 2: optional string type1, // object type if the name is 'list' (LIST_TYPE), key type if the name is 'map' (MAP_TYPE)
+ 3: optional string type2, // val type if the name is 'map' (MAP_TYPE)
+ 4: optional list<FieldSchema> fields // if the name is one of the user defined types
+}
+
+enum HiveObjectType {
+ GLOBAL = 1,
+ DATABASE = 2,
+ TABLE = 3,
+ PARTITION = 4,
+ COLUMN = 5,
+}
+
+enum PrincipalType {
+ USER = 1,
+ ROLE = 2,
+ GROUP = 3,
+}
+
+const string HIVE_FILTER_FIELD_OWNER = "hive_filter_field_owner__"
+const string HIVE_FILTER_FIELD_PARAMS = "hive_filter_field_params__"
+const string HIVE_FILTER_FIELD_LAST_ACCESS = "hive_filter_field_last_access__"
+
+enum PartitionEventType {
+ LOAD_DONE = 1,
+}
+
+// Enums for transaction and lock management
+enum TxnState {
+ COMMITTED = 1,
+ ABORTED = 2,
+ OPEN = 3,
+}
+
+enum LockLevel {
+ DB = 1,
+ TABLE = 2,
+ PARTITION = 3,
+}
+
+enum LockState {
+ ACQUIRED = 1, // requester has the lock
+ WAITING = 2, // requester is waiting for the lock and should call checklock at a later point to see if the lock has been obtained.
+ ABORT = 3, // the lock has been aborted, most likely due to timeout
+ NOT_ACQUIRED = 4, // returned only with lockNoWait, indicates the lock was not available and was not acquired
+}
+
+enum LockType {
+ SHARED_READ = 1,
+ SHARED_WRITE = 2,
+ EXCLUSIVE = 3,
+}
+
+enum CompactionType {
+ MINOR = 1,
+ MAJOR = 2,
+}
+
+enum GrantRevokeType {
+ GRANT = 1,
+ REVOKE = 2,
+}
+
+struct HiveObjectRef{
+ 1: HiveObjectType objectType,
+ 2: string dbName,
+ 3: string objectName,
+ 4: list<string> partValues,
+ 5: string columnName,
+}
+
+struct PrivilegeGrantInfo {
+ 1: string privilege,
+ 2: i32 createTime,
+ 3: string grantor,
+ 4: PrincipalType grantorType,
+ 5: bool grantOption,
+}
+
+struct HiveObjectPrivilege {
+ 1: HiveObjectRef hiveObject,
+ 2: string principalName,
+ 3: PrincipalType principalType,
+ 4: PrivilegeGrantInfo grantInfo,
+}
+
+struct PrivilegeBag {
+ 1: list<HiveObjectPrivilege> privileges,
+}
+
+struct PrincipalPrivilegeSet {
+ 1: map<string, list<PrivilegeGrantInfo>> userPrivileges, // user name -> privilege grant info
+ 2: map<string, list<PrivilegeGrantInfo>> groupPrivileges, // group name -> privilege grant info
+ 3: map<string, list<PrivilegeGrantInfo>> rolePrivileges, //role name -> privilege grant info
+}
+
+struct GrantRevokePrivilegeRequest {
+ 1: GrantRevokeType requestType;
+ 2: PrivilegeBag privileges;
+ 3: optional bool revokeGrantOption; // Only for revoke request
+}
+
+struct GrantRevokePrivilegeResponse {
+ 1: optional bool success;
+}
+
+struct Role {
+ 1: string roleName,
+ 2: i32 createTime,
+ 3: string ownerName,
+}
+
+// Representation of a grant for a principal to a role
+struct RolePrincipalGrant {
+ 1: string roleName,
+ 2: string principalName,
+ 3: PrincipalType principalType,
+ 4: bool grantOption,
+ 5: i32 grantTime,
+ 6: string grantorName,
+ 7: PrincipalType grantorPrincipalType
+}
+
+struct GetRoleGrantsForPrincipalRequest {
+ 1: required string principal_name,
+ 2: required PrincipalType principal_type
+}
+
+struct GetRoleGrantsForPrincipalResponse {
+ 1: required list<RolePrincipalGrant> principalGrants;
+}
+
+struct GetPrincipalsInRoleRequest {
+ 1: required string roleName;
+}
+
+struct GetPrincipalsInRoleResponse {
+ 1: required list<RolePrincipalGrant> principalGrants;
+}
+
+struct GrantRevokeRoleRequest {
+ 1: GrantRevokeType requestType;
+ 2: string roleName;
+ 3: string principalName;
+ 4: PrincipalType principalType;
+ 5: optional string grantor; // Needed for grant
+ 6: optional PrincipalType grantorType; // Needed for grant
+ 7: optional bool grantOption;
+}
+
+struct GrantRevokeRoleResponse {
+ 1: optional bool success;
+}
+
+// namespace for tables
+struct Database {
+ 1: string name,
+ 2: string description,
+ 3: string locationUri,
+ 4: map<string, string> parameters, // properties associated with the database
+ 5: optional PrincipalPrivilegeSet privileges,
+ 6: optional string ownerName,
+ 7: optional PrincipalType ownerType
+}
+
+// This object holds the information needed by SerDes
+struct SerDeInfo {
+ 1: string name, // name of the serde, table name by default
+ 2: string serializationLib, // usually the class that implements the extractor & loader
+ 3: map<string, string> parameters // initialization parameters
+}
+
+// sort order of a column (column name along with asc(1)/desc(0))
+struct Order {
+ 1: string col, // sort column name
+ 2: i32 order // asc(1) or desc(0)
+}
+
+// this object holds all the information about skewed table
+struct SkewedInfo {
+ 1: list<string> skewedColNames, // skewed column names
+ 2: list<list<string>> skewedColValues, //skewed values
+ 3: map<list<string>, string> skewedColValueLocationMaps, //skewed value to location mappings
+}
+
+// this object holds all the information about physical storage of the data belonging to a table
+struct StorageDescriptor {
+ 1: list<FieldSchema> cols, // required (refer to types defined above)
+ 2: string location, // defaults to <warehouse loc>/<db loc>/tablename
+ 3: string inputFormat, // SequenceFileInputFormat (binary) or TextInputFormat` or custom format
+ 4: string outputFormat, // SequenceFileOutputFormat (binary) or IgnoreKeyTextOutputFormat or custom format
+ 5: bool compressed, // compressed or not
+ 6: i32 numBuckets, // this must be specified if there are any dimension columns
+ 7: SerDeInfo serdeInfo, // serialization and deserialization information
+ 8: list<string> bucketCols, // reducer grouping columns and clustering columns and bucketing columns`
+ 9: list<Order> sortCols, // sort order of the data in each bucket
+ 10: map<string, string> parameters, // any user supplied key value hash
+ 11: optional SkewedInfo skewedInfo, // skewed information
+ 12: optional bool storedAsSubDirectories // stored as subdirectories or not
+}
+
+// table information
+struct Table {
+ 1: string tableName, // name of the table
+ 2: string dbName, // database name ('default')
+ 3: string owner, // owner of this table
+ 4: i32 createTime, // creation time of the table
+ 5: i32 lastAccessTime, // last access time (usually this will be filled from HDFS and shouldn't be relied on)
+ 6: i32 retention, // retention time
+ 7: StorageDescriptor sd, // storage descriptor of the table
+ 8: list<FieldSchema> partitionKeys, // partition keys of the table. only primitive types are supported
+ 9: map<string, string> parameters, // to store comments or any other user level parameters
+ 10: string viewOriginalText, // original view text, null for non-view
+ 11: string viewExpandedText, // expanded view text, null for non-view
+ 12: string tableType, // table type enum, e.g. EXTERNAL_TABLE
+ 13: optional PrincipalPrivilegeSet privileges,
+ 14: optional bool temporary=false
+}
+
+struct Partition {
+ 1: list<string> values // string value is converted to appropriate partition key type
+ 2: string dbName,
+ 3: string tableName,
+ 4: i32 createTime,
+ 5: i32 lastAccessTime,
+ 6: StorageDescriptor sd,
+ 7: map<string, string> parameters,
+ 8: optional PrincipalPrivilegeSet privileges
+}
+
+struct PartitionWithoutSD {
+ 1: list<string> values // string value is converted to appropriate partition key type
+ 2: i32 createTime,
+ 3: i32 lastAccessTime,
+ 4: string relativePath,
+ 5: map<string, string> parameters,
+ 6: optional PrincipalPrivilegeSet privileges
+}
+
+struct PartitionSpecWithSharedSD {
+ 1: list<PartitionWithoutSD> partitions,
+ 2: StorageDescriptor sd,
+}
+
+struct PartitionListComposingSpec {
+ 1: list<Partition> partitions
+}
+
+struct PartitionSpec {
+ 1: string dbName,
+ 2: string tableName,
+ 3: string rootPath,
+ 4: optional PartitionSpecWithSharedSD sharedSDPartitionSpec,
+ 5: optional PartitionListComposingSpec partitionList
+}
+
+struct Index {
+ 1: string indexName, // unique with in the whole database namespace
+ 2: string indexHandlerClass, // reserved
+ 3: string dbName,
+ 4: string origTableName,
+ 5: i32 createTime,
+ 6: i32 lastAccessTime,
+ 7: string indexTableName,
+ 8: StorageDescriptor sd,
+ 9: map<string, string> parameters,
+ 10: bool deferredRebuild
+}
+
+// column statistics
+struct BooleanColumnStatsData {
+1: required i64 numTrues,
+2: required i64 numFalses,
+3: required i64 numNulls
+}
+
+struct DoubleColumnStatsData {
+1: optional double lowValue,
+2: optional double highValue,
+3: required i64 numNulls,
+4: required i64 numDVs
+}
+
+struct LongColumnStatsData {
+1: optional i64 lowValue,
+2: optional i64 highValue,
+3: required i64 numNulls,
+4: required i64 numDVs
+}
+
+struct StringColumnStatsData {
+1: required i64 maxColLen,
+2: required double avgColLen,
+3: required i64 numNulls,
+4: required i64 numDVs
+}
+
+struct BinaryColumnStatsData {
+1: required i64 maxColLen,
+2: required double avgColLen,
+3: required i64 numNulls
+}
+
+
+struct Decimal {
+1: required binary unscaled,
+3: required i16 scale
+}
+
+struct DecimalColumnStatsData {
+1: optional Decimal lowValue,
+2: optional Decimal highValue,
+3: required i64 numNulls,
+4: required i64 numDVs
+}
+
+union ColumnStatisticsData {
+1: BooleanColumnStatsData booleanStats,
+2: LongColumnStatsData longStats,
+3: DoubleColumnStatsData doubleStats,
+4: StringColumnStatsData stringStats,
+5: BinaryColumnStatsData binaryStats,
+6: DecimalColumnStatsData decimalStats
+}
+
+struct ColumnStatisticsObj {
+1: required string colName,
+2: required string colType,
+3: required ColumnStatisticsData statsData
+}
+
+struct ColumnStatisticsDesc {
+1: required bool isTblLevel,
+2: required string dbName,
+3: required string tableName,
+4: optional string partName,
+5: optional i64 lastAnalyzed
+}
+
+struct ColumnStatistics {
+1: required ColumnStatisticsDesc statsDesc,
+2: required list<ColumnStatisticsObj> statsObj;
+}
+
+struct AggrStats {
+1: required list<ColumnStatisticsObj> colStats,
+2: required i64 partsFound // number of partitions for which stats were found
+}
+
+struct SetPartitionsStatsRequest {
+1: required list<ColumnStatistics> colStats
+}
+
+// schema of the table/query results etc.
+struct Schema {
+ // column names, types, comments
+ 1: list<FieldSchema> fieldSchemas, // delimiters etc
+ 2: map<string, string> properties
+}
+
+// Key-value store to be used with selected
+// Metastore APIs (create, alter methods).
+// The client can pass environment properties / configs that can be
+// accessed in hooks.
+struct EnvironmentContext {
+ 1: map<string, string> properties
+}
+
+// Return type for get_partitions_by_expr
+struct PartitionsByExprResult {
+ 1: required list<Partition> partitions,
+ // Whether the results has any (currently, all) partitions which may or may not match
+ 2: required bool hasUnknownPartitions
+}
+
+struct PartitionsByExprRequest {
+ 1: required string dbName,
+ 2: required string tblName,
+ 3: required binary expr,
+ 4: optional string defaultPartitionName,
+ 5: optional i16 maxParts=-1
+}
+
+struct TableStatsResult {
+ 1: required list<ColumnStatisticsObj> tableStats
+}
+
+struct PartitionsStatsResult {
+ 1: required map<string, list<ColumnStatisticsObj>> partStats
+}
+
+struct TableStatsRequest {
+ 1: required string dbName,
+ 2: required string tblName,
+ 3: required list<string> colNames
+}
+
+struct PartitionsStatsRequest {
+ 1: required string dbName,
+ 2: required string tblName,
+ 3: required list<string> colNames,
+ 4: required list<string> partNames
+}
+
+// Return type for add_partitions_req
+struct AddPartitionsResult {
+ 1: optional list<Partition> partitions,
+}
+
+// Request type for add_partitions_req
+struct AddPartitionsRequest {
+ 1: required string dbName,
+ 2: required string tblName,
+ 3: required list<Partition> parts,
+ 4: required bool ifNotExists,
+ 5: optional bool needResult=true
+}
+
+// Return type for drop_partitions_req
+struct DropPartitionsResult {
+ 1: optional list<Partition> partitions,
+}
+
+struct DropPartitionsExpr {
+ 1: required binary expr;
+ 2: optional i32 partArchiveLevel;
+}
+
+union RequestPartsSpec {
+ 1: list<string> names;
+ 2: list<DropPartitionsExpr> exprs;
+}
+
+// Request type for drop_partitions_req
+// TODO: we might want to add "bestEffort" flag; where a subset can fail
+struct DropPartitionsRequest {
+ 1: required string dbName,
+ 2: required string tblName,
+ 3: required RequestPartsSpec parts,
+ 4: optional bool deleteData,
+ 5: optional bool ifExists=true, // currently verified on client
+ 6: optional bool ignoreProtection,
+ 7: optional EnvironmentContext environmentContext,
+ 8: optional bool needResult=true
+}
+
+enum FunctionType {
+ JAVA = 1,
+}
+
+enum ResourceType {
+ JAR = 1,
+ FILE = 2,
+ ARCHIVE = 3,
+}
+
+struct ResourceUri {
+ 1: ResourceType resourceType,
+ 2: string uri,
+}
+
+// User-defined function
+struct Function {
+ 1: string functionName,
+ 2: string dbName,
+ 3: string className,
+ 4: string ownerName,
+ 5: PrincipalType ownerType,
+ 6: i32 createTime,
+ 7: FunctionType functionType,
+ 8: list<ResourceUri> resourceUris,
+}
+
+// Structs for transaction and locks
+struct TxnInfo {
+ 1: required i64 id,
+ 2: required TxnState state,
+ 3: required string user, // used in 'show transactions' to help admins find who has open transactions
+ 4: required string hostname, // used in 'show transactions' to help admins find who has open transactions
+}
+
+struct GetOpenTxnsInfoResponse {
+ 1: required i64 txn_high_water_mark,
+ 2: required list<TxnInfo> open_txns,
+}
+
+struct GetOpenTxnsResponse {
+ 1: required i64 txn_high_water_mark,
+ 2: required set<i64> open_txns,
+}
+
+struct OpenTxnRequest {
+ 1: required i32 num_txns,
+ 2: required string user,
+ 3: required string hostname,
+}
+
+struct OpenTxnsResponse {
+ 1: required list<i64> txn_ids,
+}
+
+struct AbortTxnRequest {
+ 1: required i64 txnid,
+}
+
+struct CommitTxnRequest {
+ 1: required i64 txnid,
+}
+
+struct LockComponent {
+ 1: required LockType type,
+ 2: required LockLevel level,
+ 3: required string dbname,
+ 4: optional string tablename,
+ 5: optional string partitionname,
+}
+
+struct LockRequest {
+ 1: required list<LockComponent> component,
+ 2: optional i64 txnid,
+ 3: required string user, // used in 'show locks' to help admins find who has open locks
+ 4: required string hostname, // used in 'show locks' to help admins find who has open locks
+}
+
+struct LockResponse {
+ 1: required i64 lockid,
+ 2: required LockState state,
+}
+
+struct CheckLockRequest {
+ 1: required i64 lockid,
+}
+
+struct UnlockRequest {
+ 1: required i64 lockid,
+}
+
+struct ShowLocksRequest {
+}
+
+struct ShowLocksResponseElement {
+ 1: required i64 lockid,
+ 2: required string dbname,
+ 3: optional string tablename,
+ 4: optional string partname,
+ 5: required LockState state,
+ 6: required LockType type,
+ 7: optional i64 txnid,
+ 8: required i64 lastheartbeat,
+ 9: optional i64 acquiredat,
+ 10: required string user,
+ 11: required string hostname,
+}
+
+struct ShowLocksResponse {
+ 1: list<ShowLocksResponseElement> locks,
+}
+
+struct HeartbeatRequest {
+ 1: optional i64 lockid,
+ 2: optional i64 txnid
+}
+
+struct HeartbeatTxnRangeRequest {
+ 1: required i64 min,
+ 2: required i64 max
+}
+
+struct HeartbeatTxnRangeResponse {
+ 1: required set<i64> aborted,
+ 2: required set<i64> nosuch
+}
+
+struct CompactionRequest {
+ 1: required string dbname,
+ 2: required string tablename,
+ 3: optional string partitionname,
+ 4: required CompactionType type,
+ 5: optional string runas,
+}
+
+struct ShowCompactRequest {
+}
+
+struct ShowCompactResponseElement {
+ 1: required string dbname,
+ 2: required string tablename,
+ 3: optional string partitionname,
+ 4: required CompactionType type,
+ 5: required string state,
+ 6: optional string workerid,
+ 7: optional i64 start,
+ 8: optional string runAs,
+}
+
+struct ShowCompactResponse {
+ 1: required list<ShowCompactResponseElement> compacts,
+}
+
+struct NotificationEventRequest {
+ 1: required i64 lastEvent,
+ 2: optional i32 maxEvents,
+}
+
+struct NotificationEvent {
+ 1: required i64 eventId,
+ 2: required i32 eventTime,
+ 3: required string eventType,
+ 4: optional string dbName,
+ 5: optional string tableName,
+ 6: required string message,
+}
+
+struct NotificationEventResponse {
+ 1: required list<NotificationEvent> events,
+}
+
+struct CurrentNotificationEventId {
+ 1: required i64 eventId,
+}
+
+struct InsertEventRequestData {
+ 1: required list<string> filesAdded
+}
+
+union FireEventRequestData {
+ 1: InsertEventRequestData insertData
+}
+
+struct FireEventRequest {
+ 1: required bool successful,
+ 2: required FireEventRequestData data
+ // dbname, tablename, and partition vals are included as optional in the top level event rather than placed in each type of
+ // subevent as I assume they'll be used across most event types.
+ 3: optional string dbName,
+ 4: optional string tableName,
+ 5: optional list<string> partitionVals,
+}
+
+struct FireEventResponse {
+ // NOP for now, this is just a place holder for future responses
+}
+
+
+struct GetAllFunctionsResponse {
+ 1: optional list<Function> functions
+}
+
+struct TableMeta {
+ 1: required string dbName;
+ 2: required string tableName;
+ 3: required string tableType;
+ 4: optional string comments;
+}
+
+exception MetaException {
+ 1: string message
+}
+
+exception UnknownTableException {
+ 1: string message
+}
+
+exception UnknownDBException {
+ 1: string message
+}
+
+exception AlreadyExistsException {
+ 1: string message
+}
+
+exception InvalidPartitionException {
+ 1: string message
+}
+
+exception UnknownPartitionException {
+ 1: string message
+}
+
+exception InvalidObjectException {
+ 1: string message
+}
+
+exception NoSuchObjectException {
+ 1: string message
+}
+
+exception IndexAlreadyExistsException {
+ 1: string message
+}
+
+exception InvalidOperationException {
+ 1: string message
+}
+
+exception ConfigValSecurityException {
+ 1: string message
+}
+
+exception InvalidInputException {
+ 1: string message
+}
+
+// Transaction and lock exceptions
+exception NoSuchTxnException {
+ 1: string message
+}
+
+exception TxnAbortedException {
+ 1: string message
+}
+
+exception TxnOpenException {
+ 1: string message
+}
+
+exception NoSuchLockException {
+ 1: string message
+}
+
+/**
+* This interface is live.
+*/
+service ThriftHiveMetastore extends fb303.FacebookService
+{
+ string getMetaConf(1:string key) throws(1:MetaException o1)
+ void setMetaConf(1:string key, 2:string value) throws(1:MetaException o1)
+
+ void create_database(1:Database database) throws(1:AlreadyExistsException o1, 2:InvalidObjectException o2, 3:MetaException o3)
+ Database get_database(1:string name) throws(1:NoSuchObjectException o1, 2:MetaException o2)
+ void drop_database(1:string name, 2:bool deleteData, 3:bool cascade) throws(1:NoSuchObjectException o1, 2:InvalidOperationException o2, 3:MetaException o3)
+ list<string> get_databases(1:string pattern) throws(1:MetaException o1)
+ list<string> get_all_databases() throws(1:MetaException o1)
+ void alter_database(1:string dbname, 2:Database db) throws(1:MetaException o1, 2:NoSuchObjectException o2)
+
+ // returns the type with given name (make separate calls for the dependent types if needed)
+ Type get_type(1:string name) throws(1:MetaException o1, 2:NoSuchObjectException o2)
+ bool create_type(1:Type type) throws(1:AlreadyExistsException o1, 2:InvalidObjectException o2, 3:MetaException o3)
+ bool drop_type(1:string type) throws(1:MetaException o1, 2:NoSuchObjectException o2)
+ map<string, Type> get_type_all(1:string name)
+ throws(1:MetaException o2)
+
+ // Gets a list of FieldSchemas describing the columns of a particular table
+ list<FieldSchema> get_fields(1: string db_name, 2: string table_name) throws (1: MetaException o1, 2: UnknownTableException o2, 3: UnknownDBException o3),
+ list<FieldSchema> get_fields_with_environment_context(1: string db_name, 2: string table_name, 3:EnvironmentContext environment_context) throws (1: MetaException o1, 2: UnknownTableException o2, 3: UnknownDBException o3)
+
+ // Gets a list of FieldSchemas describing both the columns and the partition keys of a particular table
+ list<FieldSchema> get_schema(1: string db_name, 2: string table_name) throws (1: MetaException o1, 2: UnknownTableException o2, 3: UnknownDBException o3)
+ list<FieldSchema> get_schema_with_environment_context(1: string db_name, 2: string table_name, 3:EnvironmentContext environment_context) throws (1: MetaException o1, 2: UnknownTableException o2, 3: UnknownDBException o3)
+
+ // create a Hive table. Following fields must be set
+ // tableName
+ // database (only 'default' for now until Hive QL supports databases)
+ // owner (not needed, but good to have for tracking purposes)
+ // sd.cols (list of field schemas)
+ // sd.inputFormat (SequenceFileInputFormat (binary like falcon tables or u_full) or TextInputFormat)
+ // sd.outputFormat (SequenceFileInputFormat (binary) or TextInputFormat)
+ // sd.serdeInfo.serializationLib (SerDe class name eg org.apache.hadoop.hive.serde.simple_meta.MetadataTypedColumnsetSerDe
+ // * See notes on DDL_TIME
+ void create_table(1:Table tbl) throws(1:AlreadyExistsException o1, 2:InvalidObjectException o2, 3:MetaException o3, 4:NoSuchObjectException o4)
+ void create_table_with_environment_context(1:Table tbl,
+ 2:EnvironmentContext environment_context)
+ throws (1:AlreadyExistsException o1,
+ 2:InvalidObjectException o2, 3:MetaException o3,
+ 4:NoSuchObjectException o4)
+ // drops the table and all the partitions associated with it if the table has partitions
+ // delete data (including partitions) if deleteData is set to true
+ void drop_table(1:string dbname, 2:string name, 3:bool deleteData)
+ throws(1:NoSuchObjectException o1, 2:MetaException o3)
+ void drop_table_with_environment_context(1:string dbname, 2:string name, 3:bool deleteData,
+ 4:EnvironmentContext environment_context)
+ throws(1:NoSuchObjectException o1, 2:MetaException o3)
+ list<string> get_tables(1: string db_name, 2: string pattern) throws (1: MetaException o1)
+ list<TableMeta> get_table_meta(1: string db_patterns, 2: string tbl_patterns, 3: list<string> tbl_types)
+ throws (1: MetaException o1)
+ list<string> get_all_tables(1: string db_name) throws (1: MetaException o1)
+
+ Table get_table(1:string dbname, 2:string tbl_name)
+ throws (1:MetaException o1, 2:NoSuchObjectException o2)
+ list<Table> get_table_objects_by_name(1:string dbname, 2:list<string> tbl_names)
+ throws (1:MetaException o1, 2:InvalidOperationException o2, 3:UnknownDBException o3)
+
+ // Get a list of table names that match a filter.
+ // The filter operators are LIKE, <, <=, >, >=, =, <>
+ //
+ // In the filter statement, values interpreted as strings must be enclosed in quotes,
+ // while values interpreted as integers should not be. Strings and integers are the only
+ // supported value types.
+ //
+ // The currently supported key names in the filter are:
+ // Constants.HIVE_FILTER_FIELD_OWNER, which filters on the tables' owner's name
+ // and supports all filter operators
+ // Constants.HIVE_FILTER_FIELD_LAST_ACCESS, which filters on the last access times
+ // and supports all filter operators except LIKE
+ // Constants.HIVE_FILTER_FIELD_PARAMS, which filters on the tables' parameter keys and values
+ // and only supports the filter operators = and <>.
+ // Append the parameter key name to HIVE_FILTER_FIELD_PARAMS in the filter statement.
+ // For example, to filter on parameter keys called "retention", the key name in the filter
+ // statement should be Constants.HIVE_FILTER_FIELD_PARAMS + "retention"
+ // Also, = and <> only work for keys that exist
+ // in the tables. E.g., if you are looking for tables where key1 <> value, it will only
+ // look at tables that have a value for the parameter key1.
+ // Some example filter statements include:
+ // filter = Constants.HIVE_FILTER_FIELD_OWNER + " like \".*test.*\" and " +
+ // Constants.HIVE_FILTER_FIELD_LAST_ACCESS + " = 0";
+ // filter = Constants.HIVE_FILTER_FIELD_PARAMS + "retention = \"30\" or " +
+ // Constants.HIVE_FILTER_FIELD_PARAMS + "retention = \"90\""
+ // @param dbName
+ // The name of the database from which you will retrieve the table names
+ // @param filterType
+ // The type of filter
+ // @param filter
+ // The filter string
+ // @param max_tables
+ // The maximum number of tables returned
+ // @return A list of table names that match the desired filter
+ list<string> get_table_names_by_filter(1:string dbname, 2:string filter, 3:i16 max_tables=-1)
+ throws (1:MetaException o1, 2:InvalidOperationException o2, 3:UnknownDBException o3)
+
+ // alter table applies to only future partitions not for existing partitions
+ // * See notes on DDL_TIME
+ void alter_table(1:string dbname, 2:string tbl_name, 3:Table new_tbl)
+ throws (1:InvalidOperationException o1, 2:MetaException o2)
+ void alter_table_with_environment_context(1:string dbname, 2:string tbl_name,
+ 3:Table new_tbl, 4:EnvironmentContext environment_context)
+ throws (1:InvalidOperationException o1, 2:MetaException o2)
+ // alter table not only applies to future partitions but also cascade to existing partitions
+ void alter_table_with_cascade(1:string dbname, 2:string tbl_name, 3:Table new_tbl, 4:bool cascade)
+ throws (1:InvalidOperationException o1, 2:MetaException o2)
+ // the following applies to only tables that have partitions
+ // * See notes on DDL_TIME
+ Partition add_partition(1:Partition new_part)
+ throws(1:InvalidObjectException o1, 2:AlreadyExistsException o2, 3:MetaException o3)
+ Partition add_partition_with_environment_context(1:Partition new_part,
+ 2:EnvironmentContext environment_context)
+ throws (1:InvalidObjectException o1, 2:AlreadyExistsException o2,
+ 3:MetaException o3)
+ i32 add_partitions(1:list<Partition> new_parts)
+ throws(1:InvalidObjectException o1, 2:AlreadyExistsException o2, 3:MetaException o3)
+ i32 add_partitions_pspec(1:list<PartitionSpec> new_parts)
+ throws(1:InvalidObjectException o1, 2:AlreadyExistsException o2, 3:MetaException o3)
+ Partition append_partition(1:string db_name, 2:string tbl_name, 3:list<string> part_vals)
+ throws (1:InvalidObjectException o1, 2:AlreadyExistsException o2, 3:MetaException o3)
+ AddPartitionsResult add_partitions_req(1:AddPartitionsRequest request)
+ throws(1:InvalidObjectException o1, 2:AlreadyExistsException o2, 3:MetaException o3)
+ Partition append_partition_with_environment_context(1:string db_name, 2:string tbl_name,
+ 3:list<string> part_vals, 4:EnvironmentContext environment_context)
+ throws (1:InvalidObjectException o1, 2:AlreadyExistsException o2, 3:MetaException o3)
+ Partition append_partition_by_name(1:string db_name, 2:string tbl_name, 3:string part_name)
+ throws (1:InvalidObjectException o1, 2:AlreadyExistsException o2, 3:MetaException o3)
+ Partition append_partition_by_name_with_environment_context(1:string db_name, 2:string tbl_name,
+ 3:string part_name, 4:EnvironmentContext environment_context)
+ throws (1:InvalidObjectException o1, 2:AlreadyExistsException o2, 3:MetaException o3)
+ bool drop_partition(1:string db_name, 2:string tbl_name, 3:list<string> part_vals, 4:bool deleteData)
+ throws(1:NoSuchObjectException o1, 2:MetaException o2)
+ bool drop_partition_with_environment_context(1:string db_name, 2:string tbl_name,
+ 3:list<string> part_vals, 4:bool deleteData, 5:EnvironmentContext environment_context)
+ throws(1:NoSuchObjectException o1, 2:MetaException o2)
+ bool drop_partition_by_name(1:string db_name, 2:string tbl_name, 3:string part_name, 4:bool deleteData)
+ throws(1:NoSuchObjectException o1, 2:MetaException o2)
+ bool drop_partition_by_name_with_environment_context(1:string db_name, 2:string tbl_name,
+ 3:string part_name, 4:bool deleteData, 5:EnvironmentContext environment_context)
+ throws(1:NoSuchObjectException o1, 2:MetaException o2)
+ DropPartitionsResult drop_partitions_req(1: DropPartitionsRequest req)
+ throws(1:NoSuchObjectException o1, 2:MetaException o2)
+
+ Partition get_partition(1:string db_name, 2:string tbl_name, 3:list<string> part_vals)
+ throws(1:MetaException o1, 2:NoSuchObjectException o2)
+ Partition exchange_partition(1:map<string, string> partitionSpecs, 2:string source_db,
+ 3:string source_table_name, 4:string dest_db, 5:string dest_table_name)
+ throws(1:MetaException o1, 2:NoSuchObjectException o2, 3:InvalidObjectException o3,
+ 4:InvalidInputException o4)
+
+ Partition get_partition_with_auth(1:string db_name, 2:string tbl_name, 3:list<string> part_vals,
+ 4: string user_name, 5: list<string> group_names) throws(1:MetaException o1, 2:NoSuchObjectException o2)
+
+ Partition get_partition_by_name(1:string db_name 2:string tbl_name, 3:string part_name)
+ throws(1:MetaException o1, 2:NoSuchObjectException o2)
+
+ // returns all the partitions for this table in reverse chronological order.
+ // If max parts is given then it will return only that many.
+ list<Partition> get_partitions(1:string db_name, 2:string tbl_name, 3:i16 max_parts=-1)
+ throws(1:NoSuchObjectException o1, 2:MetaException o2)
+ list<Partition> get_partitions_with_auth(1:string db_name, 2:string tbl_name, 3:i16 max_parts=-1,
+ 4: string user_name, 5: list<string> group_names) throws(1:NoSuchObjectException o1, 2:MetaException o2)
+
+ list<PartitionSpec> get_partitions_pspec(1:string db_name, 2:string tbl_name, 3:i32 max_parts=-1)
+ throws(1:NoSuchObjectException o1, 2:MetaException o2)
+
+ list<string> get_partition_names(1:string db_name, 2:string tbl_name, 3:i16 max_parts=-1)
+ throws(1:MetaException o2)
+
+ // get_partition*_ps methods allow filtering by a partial partition specification,
+ // as needed for dynamic partitions. The values that are not restricted should
+ // be empty strings. Nulls were considered (instead of "") but caused errors in
+ // generated Python code. The size of part_vals may be smaller than the
+ // number of partition columns - the unspecified values are considered the same
+ // as "".
+ list<Partition> get_partitions_ps(1:string db_name 2:string tbl_name
+ 3:list<string> part_vals, 4:i16 max_parts=-1)
+ throws(1:MetaException o1, 2:NoSuchObjectException o2)
+ list<Partition> get_partitions_ps_with_auth(1:string db_name, 2:string tbl_name, 3:list<string> part_vals, 4:i16 max_parts=-1,
+ 5: string user_name, 6: list<string> group_names) throws(1:NoSuchObjectException o1, 2:MetaException o2)
+
+ list<string> get_partition_names_ps(1:string db_name,
+ 2:string tbl_name, 3:list<string> part_vals, 4:i16 max_parts=-1)
+ throws(1:MetaException o1, 2:NoSuchObjectException o2)
+
+ // get the partitions matching the given partition filter
+ list<Partition> get_partitions_by_filter(1:string db_name 2:string tbl_name
+ 3:string filter, 4:i16 max_parts=-1)
+ throws(1:MetaException o1, 2:NoSuchObjectException o2)
+
+ // List partitions as PartitionSpec instances.
+ list<PartitionSpec> get_part_specs_by_filter(1:string db_name 2:string tbl_name
+ 3:string filter, 4:i32 max_parts=-1)
+ throws(1:MetaException o1, 2:NoSuchObjectException o2)
+
+ // get the partitions matching the given partition filter
+ // unlike get_partitions_by_filter, takes serialized hive expression, and with that can work
+ // with any filter (get_partitions_by_filter only works if the filter can be pushed down to JDOQL.
+ PartitionsByExprResult get_partitions_by_expr(1:PartitionsByExprRequest req)
+ throws(1:MetaException o1, 2:NoSuchObjectException o2)
+
+ // get partitions give a list of partition names
+ list<Partition> get_partitions_by_names(1:string db_name 2:string tbl_name 3:list<string> names)
+ throws(1:MetaException o1, 2:NoSuchObjectException o2)
+
+ // changes the partition to the new partition object. partition is identified from the part values
+ // in the new_part
+ // * See notes on DDL_TIME
+ void alter_partition(1:string db_name, 2:string tbl_name, 3:Partition new_part)
+ throws (1:InvalidOperationException o1, 2:MetaException o2)
+
+ // change a list of partitions. All partitions are altered atomically and all
+ // prehooks are fired together followed by all post hooks
+ void alter_partitions(1:string db_name, 2:string tbl_name, 3:list<Partition> new_parts)
+ throws (1:InvalidOperationException o1, 2:MetaException o2)
+
+ void alter_partition_with_environment_context(1:string db_name,
+ 2:string tbl_name, 3:Partition new_part,
+ 4:EnvironmentContext environment_context)
+ throws (1:InvalidOperationException o1, 2:MetaException o2)
+
+ // rename the old partition to the new partition object by changing old part values to the part values
+ // in the new_part. old partition is identified from part_vals.
+ // partition keys in new_part should be the same as those in old partition.
+ void rename_partition(1:string db_name, 2:string tbl_name, 3:list<string> part_vals, 4:Partition new_part)
+ throws (1:InvalidOperationException o1, 2:MetaException o2)
+
+ // returns whether or not the partition name is valid based on the value of the config
+ // hive.metastore.partition.name.whitelist.pattern
+ bool partition_name_has_valid_characters(1:list<string> part_vals, 2:bool throw_exception)
+ throws(1: MetaException o1)
+
+ // gets the value of the configuration key in the metastore server. returns
+ // defaultValue if the key does not exist. if the configuration key does not
+ // begin with "hive", "mapred", or "hdfs", a ConfigValSecurityException is
+ // thrown.
+ string get_config_value(1:string name, 2:string defaultValue)
+ throws(1:ConfigValSecurityException o1)
+
+ // converts a partition name into a partition values array
+ list<string> partition_name_to_vals(1: string part_name)
+ throws(1: MetaException o1)
+ // converts a partition name into a partition specification (a mapping from
+ // the partition cols to the values)
+ map<string, string> partition_name_to_spec(1: string part_name)
+ throws(1: MetaException o1)
+
+ void markPartitionForEvent(1:string db_name, 2:string tbl_name, 3:map<string,string> part_vals,
+ 4:PartitionEventType eventType) throws (1: MetaException o1, 2: NoSuchObjectException o2,
+ 3: UnknownDBException o3, 4: UnknownTableException o4, 5: UnknownPartitionException o5,
+ 6: InvalidPartitionException o6)
+ bool isPartitionMarkedForEvent(1:string db_name, 2:string tbl_name, 3:map<string,string> part_vals,
+ 4: PartitionEventType eventType) throws (1: MetaException o1, 2:NoSuchObjectException o2,
+ 3: UnknownDBException o3, 4: UnknownTableException o4, 5: UnknownPartitionException o5,
+ 6: InvalidPartitionException o6)
+
+ //index
+ Index add_index(1:Index new_index, 2: Table index_table)
+ throws(1:InvalidObjectException o1, 2:AlreadyExistsException o2, 3:MetaException o3)
+ void alter_index(1:string dbname, 2:string base_tbl_name, 3:string idx_name, 4:Index new_idx)
+ throws (1:InvalidOperationException o1, 2:MetaException o2)
+ bool drop_index_by_name(1:string db_name, 2:string tbl_name, 3:string index_name, 4:bool deleteData)
+ throws(1:NoSuchObjectException o1, 2:MetaException o2)
+ Index get_index_by_name(1:string db_name 2:string tbl_name, 3:string index_name)
+ throws(1:MetaException o1, 2:NoSuchObjectException o2)
+
+ list<Index> get_indexes(1:string db_name, 2:string tbl_name, 3:i16 max_indexes=-1)
+ throws(1:NoSuchObjectException o1, 2:MetaException o2)
+ list<string> get_index_names(1:string db_name, 2:string tbl_name, 3:i16 max_indexes=-1)
+ throws(1:MetaException o2)
+
+ // column statistics interfaces
+
+ // update APIs persist the column statistics object(s) that are passed in. If statistics already
+ // exists for one or more columns, the existing statistics will be overwritten. The update APIs
+ // validate that the dbName, tableName, partName, colName[] passed in as part of the ColumnStatistics
+ // struct are valid, throws InvalidInputException/NoSuchObjectException if found to be invalid
+ bool update_table_column_statistics(1:ColumnStatistics stats_obj) throws (1:NoSuchObjectException o1,
+ 2:InvalidObjectException o2, 3:MetaException o3, 4:InvalidInputException o4)
+ bool update_partition_column_statistics(1:ColumnStatistics stats_obj) throws (1:NoSuchObjectException o1,
+ 2:InvalidObjectException o2, 3:MetaException o3, 4:InvalidInputException o4)
+
+ // get APIs return the column statistics corresponding to db_name, tbl_name, [part_name], col_name if
+ // such statistics exists. If the required statistics doesn't exist, get APIs throw NoSuchObjectException
+ // For instance, if get_table_column_statistics is called on a partitioned table for which only
+ // partition level column stats exist, get_table_column_statistics will throw NoSuchObjectException
+ ColumnStatistics get_table_column_statistics(1:string db_name, 2:string tbl_name, 3:string col_name) throws
+ (1:NoSuchObjectException o1, 2:MetaException o2, 3:InvalidInputException o3, 4:InvalidObjectException o4)
+ ColumnStatistics get_partition_column_statistics(1:string db_name, 2:string tbl_name, 3:string part_name,
+ 4:string col_name) throws (1:NoSuchObjectException o1, 2:MetaException o2,
+ 3:InvalidInputException o3, 4:InvalidObjectException o4)
+ TableStatsResult get_table_statistics_req(1:TableStatsRequest request) throws
+ (1:NoSuchObjectException o1, 2:MetaException o2)
+ PartitionsStatsResult get_partitions_statistics_req(1:PartitionsStatsRequest request) throws
+ (1:NoSuchObjectException o1, 2:MetaException o2)
+ AggrStats get_aggr_stats_for(1:PartitionsStatsRequest request) throws
+ (1:NoSuchObjectException o1, 2:MetaException o2)
+ bool set_aggr_stats_for(1:SetPartitionsStatsRequest request) throws
+ (1:NoSuchObjectException o1, 2:InvalidObjectException o2, 3:MetaException o3, 4:InvalidInputException o4)
+
+
+ // delete APIs attempt to delete column statistics, if found, associated with a given db_name, tbl_name, [part_name]
+ // and col_name. If the delete API doesn't find the statistics record in the metastore, throws NoSuchObjectException
+ // Delete API validates the input and if the input is invalid throws InvalidInputException/InvalidObjectException.
+ bool delete_partition_column_statistics(1:string db_name, 2:string tbl_name, 3:string part_name, 4:string col_name) throws
+ (1:NoSuchObjectException o1, 2:MetaException o2, 3:InvalidObjectException o3,
+ 4:InvalidInputException o4)
+ bool delete_table_column_statistics(1:string db_name, 2:string tbl_name, 3:string col_name) throws
+ (1:NoSuchObjectException o1, 2:MetaException o2, 3:InvalidObjectException o3,
+ 4:InvalidInputException o4)
+
+ //
+ // user-defined functions
+ //
+
+ void create_function(1:Function func)
+ throws (1:AlreadyExistsException o1,
+ 2:InvalidObjectException o2,
+ 3:MetaException o3,
+ 4:NoSuchObjectException o4)
+
+ void drop_function(1:string dbName, 2:string funcName)
+ throws (1:NoSuchObjectException o1, 2:MetaException o3)
+
+ void alter_function(1:string dbName, 2:string funcName, 3:Function newFunc)
+ throws (1:InvalidOperationException o1, 2:MetaException o2)
+
+ list<string> get_functions(1:string dbName, 2:string pattern)
+ throws (1:MetaException o1)
+ Function get_function(1:string dbName, 2:string funcName)
+ throws (1:MetaException o1, 2:NoSuchObjectException o2)
+
+ GetAllFunctionsResponse get_all_functions() throws (1:MetaException o1)
+
+ //authorization privileges
+
+ bool create_role(1:Role role) throws(1:MetaException o1)
+ bool drop_role(1:string role_name) throws(1:MetaException o1)
+ list<string> get_role_names() throws(1:MetaException o1)
+ // Deprecated, use grant_revoke_role()
+ bool grant_role(1:string role_name, 2:string principal_name, 3:PrincipalType principal_type,
+ 4:string grantor, 5:PrincipalType grantorType, 6:bool grant_option) throws(1:MetaException o1)
+ // Deprecated, use grant_revoke_role()
+ bool revoke_role(1:string role_name, 2:string principal_name, 3:PrincipalType principal_type)
+ throws(1:MetaException o1)
+ list<Role> list_roles(1:string principal_name, 2:PrincipalType principal_type) throws(1:MetaException o1)
+ GrantRevokeRoleResponse grant_revoke_role(1:GrantRevokeRoleRequest request) throws(1:MetaException o1)
+
+ // get all role-grants for users/roles that have been granted the given role
+ // Note that in the returned list of RolePrincipalGrants, the roleName is
+ // redundant as it would match the role_name argument of this function
+ GetPrincipalsInRoleResponse get_principals_in_role(1: GetPrincipalsInRoleRequest request) throws(1:MetaException o1)
+
+ // get grant information of all roles granted to the given principal
+ // Note that in the returned list of RolePrincipalGrants, the principal name,type is
+ // redundant as it would match the principal name,type arguments of this function
+ GetRoleGrantsForPrincipalResponse get_role_grants_for_principal(1: GetRoleGrantsForPrincipalRequest request) throws(1:MetaException o1)
+
+ PrincipalPrivilegeSet get_privilege_set(1:HiveObjectRef hiveObject, 2:string user_name,
+ 3: list<string> group_names) throws(1:MetaException o1)
+ list<HiveObjectPrivilege> list_privileges(1:string principal_name, 2:PrincipalType principal_type,
+ 3: HiveObjectRef hiveObject) throws(1:MetaException o1)
+
+ // Deprecated, use grant_revoke_privileges()
+ bool grant_privileges(1:PrivilegeBag privileges) throws(1:MetaException o1)
+ // Deprecated, use grant_revoke_privileges()
+ bool revoke_privileges(1:PrivilegeBag privileges) throws(1:MetaException o1)
+ GrantRevokePrivilegeResponse grant_revoke_privileges(1:GrantRevokePrivilegeRequest request) throws(1:MetaException o1);
+
+ // this is used by metastore client to send UGI information to metastore server immediately
+ // after setting up a connection.
+ list<string> set_ugi(1:string user_name, 2:list<string> group_names) throws (1:MetaException o1)
+
+ //Authentication (delegation token) interfaces
+
+ // get metastore server delegation token for use from the map/reduce tasks to authenticate
+ // to metastore server
+ string get_delegation_token(1:string token_owner, 2:string renewer_kerberos_principal_name)
+ throws (1:MetaException o1)
+
+ // method to renew delegation token obtained from metastore server
+ i64 renew_delegation_token(1:string token_str_form) throws (1:MetaException o1)
+
+ // method to cancel delegation token obtained from metastore server
+ void cancel_delegation_token(1:string token_str_form) throws (1:MetaException o1)
+
+ // Transaction and lock management calls
+ // Get just list of open transactions
+ GetOpenTxnsResponse get_open_txns()
+ // Get list of open transactions with state (open, aborted)
+ GetOpenTxnsInfoResponse get_open_txns_info()
+ OpenTxnsResponse open_txns(1:OpenTxnRequest rqst)
+ void abort_txn(1:AbortTxnRequest rqst) throws (1:NoSuchTxnException o1)
+ void commit_txn(1:CommitTxnRequest rqst) throws (1:NoSuchTxnException o1, 2:TxnAbortedException o2)
+ LockResponse lock(1:LockRequest rqst) throws (1:NoSuchTxnException o1, 2:TxnAbortedException o2)
+ LockResponse check_lock(1:CheckLockRequest rqst)
+ throws (1:NoSuchTxnException o1, 2:TxnAbortedException o2, 3:NoSuchLockException o3)
+ void unlock(1:UnlockRequest rqst) throws (1:NoSuchLockException o1, 2:TxnOpenException o2)
+ ShowLocksResponse show_locks(1:ShowLocksRequest rqst)
+ void heartbeat(1:HeartbeatRequest ids) throws (1:NoSuchLockException o1, 2:NoSuchTxnException o2, 3:TxnAbortedException o3)
+ HeartbeatTxnRangeResponse heartbeat_txn_range(1:HeartbeatTxnRangeRequest txns)
+ void compact(1:CompactionRequest rqst)
+ ShowCompactResponse show_compact(1:ShowCompactRequest rqst)
+
+ // Notification logging calls
+ NotificationEventResponse get_next_notification(1:NotificationEventRequest rqst)
+ CurrentNotificationEventId get_current_notificationEventId()
+}
+
+// * Note about the DDL_TIME: When creating or altering a table or a partition,
+// if the DDL_TIME is not set, the current time will be used.
+
+// For storing info about archived partitions in parameters
+
+// Whether the partition is archived
+const string IS_ARCHIVED = "is_archived",
+// The original location of the partition, before archiving. After archiving,
+// this directory will contain the archive. When the partition
+// is dropped, this directory will be deleted
+const string ORIGINAL_LOCATION = "original_location",
+
+// Whether or not the table is considered immutable - immutable tables can only be
+// overwritten or created if unpartitioned, or if partitioned, partitions inside them
+// can only be overwritten or created. Immutability supports write-once and replace
+// semantics, but not append.
+const string IS_IMMUTABLE = "immutable",
+
+// these should be needed only for backward compatibility with filestore
+const string META_TABLE_COLUMNS = "columns",
+const string META_TABLE_COLUMN_TYPES = "columns.types",
+const string BUCKET_FIELD_NAME = "bucket_field_name",
+const string BUCKET_COUNT = "bucket_count",
+const string FIELD_TO_DIMENSION = "field_to_dimension",
+const string META_TABLE_NAME = "name",
+const string META_TABLE_DB = "db",
+const string META_TABLE_LOCATION = "location",
+const string META_TABLE_SERDE = "serde",
+const string META_TABLE_PARTITION_COLUMNS = "partition_columns",
+const string META_TABLE_PARTITION_COLUMN_TYPES = "partition_columns.types",
+const string FILE_INPUT_FORMAT = "file.inputformat",
+const string FILE_OUTPUT_FORMAT = "file.outputformat",
+const string META_TABLE_STORAGE = "storage_handler",
+const string TABLE_IS_TRANSACTIONAL = "transactional",
+const string TABLE_NO_AUTO_COMPACT = "no_auto_compaction",
+
+
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift_internal.cc b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift_internal.cc
new file mode 100644
index 000000000..91cd51e58
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift_internal.cc
@@ -0,0 +1,301 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dbi/hiveserver2/thrift_internal.h"
+
+#include <map>
+#include <sstream>
+
+#include "arrow/dbi/hiveserver2/TCLIService_constants.h"
+#include "arrow/dbi/hiveserver2/service.h"
+
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+
+namespace hs2 = apache::hive::service::cli::thrift;
+
+namespace arrow {
+namespace hiveserver2 {
+
+namespace {
+
+// Convert an "enum class" value to an integer equivalent, for outputting.
+template <typename ENUM>
+typename std::underlying_type<ENUM>::type EnumToInt(const ENUM& value) {
+ return static_cast<typename std::underlying_type<ENUM>::type>(value);
+}
+
+} // namespace
+
+const std::string OperationStateToString(const Operation::State& state) {
+ switch (state) {
+ case Operation::State::INITIALIZED:
+ return "INITIALIZED";
+ case Operation::State::RUNNING:
+ return "RUNNING";
+ case Operation::State::FINISHED:
+ return "FINISHED";
+ case Operation::State::CANCELED:
+ return "CANCELED";
+ case Operation::State::CLOSED:
+ return "CLOSED";
+ case Operation::State::ERROR:
+ return "ERROR";
+ case Operation::State::UNKNOWN:
+ return "UNKNOWN";
+ case Operation::State::PENDING:
+ return "PENDING";
+ default:
+ std::stringstream ss;
+ ss << "Unknown Operation::State " << EnumToInt(state);
+ return ss.str();
+ }
+}
+
+const std::string TypeIdToString(const ColumnType::TypeId& type_id) {
+ switch (type_id) {
+ case ColumnType::TypeId::BOOLEAN:
+ return "BOOLEAN";
+ case ColumnType::TypeId::TINYINT:
+ return "TINYINT";
+ case ColumnType::TypeId::SMALLINT:
+ return "SMALLINT";
+ case ColumnType::TypeId::INT:
+ return "INT";
+ case ColumnType::TypeId::BIGINT:
+ return "BIGINT";
+ case ColumnType::TypeId::FLOAT:
+ return "FLOAT";
+ case ColumnType::TypeId::DOUBLE:
+ return "DOUBLE";
+ case ColumnType::TypeId::STRING:
+ return "STRING";
+ case ColumnType::TypeId::TIMESTAMP:
+ return "TIMESTAMP";
+ case ColumnType::TypeId::BINARY:
+ return "BINARY";
+ case ColumnType::TypeId::ARRAY:
+ return "ARRAY";
+ case ColumnType::TypeId::MAP:
+ return "MAP";
+ case ColumnType::TypeId::STRUCT:
+ return "STRUCT";
+ case ColumnType::TypeId::UNION:
+ return "UNION";
+ case ColumnType::TypeId::USER_DEFINED:
+ return "USER_DEFINED";
+ case ColumnType::TypeId::DECIMAL:
+ return "DECIMAL";
+ case ColumnType::TypeId::NULL_TYPE:
+ return "NULL_TYPE";
+ case ColumnType::TypeId::DATE:
+ return "DATE";
+ case ColumnType::TypeId::VARCHAR:
+ return "VARCHAR";
+ case ColumnType::TypeId::CHAR:
+ return "CHAR";
+ case ColumnType::TypeId::INVALID:
+ return "INVALID";
+ default: {
+ std::stringstream ss;
+ ss << "Unknown ColumnType::TypeId " << EnumToInt(type_id);
+ return ss.str();
+ }
+ }
+}
+
+hs2::TFetchOrientation::type FetchOrientationToTFetchOrientation(
+ FetchOrientation orientation) {
+ switch (orientation) {
+ case FetchOrientation::NEXT:
+ return hs2::TFetchOrientation::FETCH_NEXT;
+ case FetchOrientation::PRIOR:
+ return hs2::TFetchOrientation::FETCH_PRIOR;
+ case FetchOrientation::RELATIVE:
+ return hs2::TFetchOrientation::FETCH_RELATIVE;
+ case FetchOrientation::ABSOLUTE:
+ return hs2::TFetchOrientation::FETCH_ABSOLUTE;
+ case FetchOrientation::FIRST:
+ return hs2::TFetchOrientation::FETCH_FIRST;
+ case FetchOrientation::LAST:
+ return hs2::TFetchOrientation::FETCH_LAST;
+ default:
+ DCHECK(false) << "Unknown FetchOrientation " << EnumToInt(orientation);
+ return hs2::TFetchOrientation::FETCH_NEXT;
+ }
+}
+
+hs2::TProtocolVersion::type ProtocolVersionToTProtocolVersion(ProtocolVersion protocol) {
+ switch (protocol) {
+ case ProtocolVersion::PROTOCOL_V1:
+ return hs2::TProtocolVersion::HIVE_CLI_SERVICE_PROTOCOL_V1;
+ case ProtocolVersion::PROTOCOL_V2:
+ return hs2::TProtocolVersion::HIVE_CLI_SERVICE_PROTOCOL_V2;
+ case ProtocolVersion::PROTOCOL_V3:
+ return hs2::TProtocolVersion::HIVE_CLI_SERVICE_PROTOCOL_V3;
+ case ProtocolVersion::PROTOCOL_V4:
+ return hs2::TProtocolVersion::HIVE_CLI_SERVICE_PROTOCOL_V4;
+ case ProtocolVersion::PROTOCOL_V5:
+ return hs2::TProtocolVersion::HIVE_CLI_SERVICE_PROTOCOL_V5;
+ case ProtocolVersion::PROTOCOL_V6:
+ return hs2::TProtocolVersion::HIVE_CLI_SERVICE_PROTOCOL_V6;
+ case ProtocolVersion::PROTOCOL_V7:
+ return hs2::TProtocolVersion::HIVE_CLI_SERVICE_PROTOCOL_V7;
+ default:
+ DCHECK(false) << "Unknown ProtocolVersion " << EnumToInt(protocol);
+ return hs2::TProtocolVersion::HIVE_CLI_SERVICE_PROTOCOL_V7;
+ }
+}
+
+Operation::State TOperationStateToOperationState(
+ const hs2::TOperationState::type& tstate) {
+ switch (tstate) {
+ case hs2::TOperationState::INITIALIZED_STATE:
+ return Operation::State::INITIALIZED;
+ case hs2::TOperationState::RUNNING_STATE:
+ return Operation::State::RUNNING;
+ case hs2::TOperationState::FINISHED_STATE:
+ return Operation::State::FINISHED;
+ case hs2::TOperationState::CANCELED_STATE:
+ return Operation::State::CANCELED;
+ case hs2::TOperationState::CLOSED_STATE:
+ return Operation::State::CLOSED;
+ case hs2::TOperationState::ERROR_STATE:
+ return Operation::State::ERROR;
+ case hs2::TOperationState::UKNOWN_STATE:
+ return Operation::State::UNKNOWN;
+ case hs2::TOperationState::PENDING_STATE:
+ return Operation::State::PENDING;
+ default:
+ ARROW_LOG(WARNING) << "Unknown TOperationState " << tstate;
+ return Operation::State::UNKNOWN;
+ }
+}
+
+Status TStatusToStatus(const hs2::TStatus& tstatus) {
+ switch (tstatus.statusCode) {
+ case hs2::TStatusCode::SUCCESS_STATUS:
+ return Status::OK();
+ case hs2::TStatusCode::SUCCESS_WITH_INFO_STATUS: {
+ std::stringstream ss;
+ for (size_t i = 0; i < tstatus.infoMessages.size(); i++) {
+ if (i != 0) ss << ",";
+ ss << tstatus.infoMessages[i];
+ }
+ return Status::OK(ss.str());
+ }
+ case hs2::TStatusCode::STILL_EXECUTING_STATUS:
+ return Status::ExecutionError("Still executing");
+ case hs2::TStatusCode::ERROR_STATUS:
+ return Status::IOError(tstatus.errorMessage);
+ case hs2::TStatusCode::INVALID_HANDLE_STATUS:
+ return Status::Invalid("Invalid handle");
+ default: {
+ return Status::UnknownError("Unknown TStatusCode ", tstatus.statusCode);
+ }
+ }
+}
+
+std::unique_ptr<ColumnType> TTypeDescToColumnType(const hs2::TTypeDesc& ttype_desc) {
+ if (ttype_desc.types.size() != 1 || !ttype_desc.types[0].__isset.primitiveEntry) {
+ ARROW_LOG(WARNING) << "TTypeDescToColumnType only supports primitive types.";
+ return std::unique_ptr<ColumnType>(new PrimitiveType(ColumnType::TypeId::INVALID));
+ }
+
+ ColumnType::TypeId type_id = TTypeIdToTypeId(ttype_desc.types[0].primitiveEntry.type);
+ if (type_id == ColumnType::TypeId::CHAR || type_id == ColumnType::TypeId::VARCHAR) {
+ const std::map<std::string, hs2::TTypeQualifierValue>& qualifiers =
+ ttype_desc.types[0].primitiveEntry.typeQualifiers.qualifiers;
+ DCHECK_EQ(qualifiers.count(hs2::g_TCLIService_constants.CHARACTER_MAXIMUM_LENGTH), 1);
+
+ try {
+ return std::unique_ptr<ColumnType>(new CharacterType(
+ type_id,
+ qualifiers.at(hs2::g_TCLIService_constants.CHARACTER_MAXIMUM_LENGTH).i32Value));
+ } catch (std::out_of_range e) {
+ ARROW_LOG(ERROR) << "Character type qualifiers invalid: " << e.what();
+ return std::unique_ptr<ColumnType>(new PrimitiveType(ColumnType::TypeId::INVALID));
+ }
+ } else if (type_id == ColumnType::TypeId::DECIMAL) {
+ const std::map<std::string, hs2::TTypeQualifierValue>& qualifiers =
+ ttype_desc.types[0].primitiveEntry.typeQualifiers.qualifiers;
+ DCHECK_EQ(qualifiers.count(hs2::g_TCLIService_constants.PRECISION), 1);
+ DCHECK_EQ(qualifiers.count(hs2::g_TCLIService_constants.SCALE), 1);
+
+ try {
+ return std::unique_ptr<ColumnType>(new DecimalType(
+ type_id, qualifiers.at(hs2::g_TCLIService_constants.PRECISION).i32Value,
+ qualifiers.at(hs2::g_TCLIService_constants.SCALE).i32Value));
+ } catch (std::out_of_range e) {
+ ARROW_LOG(ERROR) << "Decimal type qualifiers invalid: " << e.what();
+ return std::unique_ptr<ColumnType>(new PrimitiveType(ColumnType::TypeId::INVALID));
+ }
+ } else {
+ return std::unique_ptr<ColumnType>(new PrimitiveType(type_id));
+ }
+}
+
+ColumnType::TypeId TTypeIdToTypeId(const hs2::TTypeId::type& type_id) {
+ switch (type_id) {
+ case hs2::TTypeId::BOOLEAN_TYPE:
+ return ColumnType::TypeId::BOOLEAN;
+ case hs2::TTypeId::TINYINT_TYPE:
+ return ColumnType::TypeId::TINYINT;
+ case hs2::TTypeId::SMALLINT_TYPE:
+ return ColumnType::TypeId::SMALLINT;
+ case hs2::TTypeId::INT_TYPE:
+ return ColumnType::TypeId::INT;
+ case hs2::TTypeId::BIGINT_TYPE:
+ return ColumnType::TypeId::BIGINT;
+ case hs2::TTypeId::FLOAT_TYPE:
+ return ColumnType::TypeId::FLOAT;
+ case hs2::TTypeId::DOUBLE_TYPE:
+ return ColumnType::TypeId::DOUBLE;
+ case hs2::TTypeId::STRING_TYPE:
+ return ColumnType::TypeId::STRING;
+ case hs2::TTypeId::TIMESTAMP_TYPE:
+ return ColumnType::TypeId::TIMESTAMP;
+ case hs2::TTypeId::BINARY_TYPE:
+ return ColumnType::TypeId::BINARY;
+ case hs2::TTypeId::ARRAY_TYPE:
+ return ColumnType::TypeId::ARRAY;
+ case hs2::TTypeId::MAP_TYPE:
+ return ColumnType::TypeId::MAP;
+ case hs2::TTypeId::STRUCT_TYPE:
+ return ColumnType::TypeId::STRUCT;
+ case hs2::TTypeId::UNION_TYPE:
+ return ColumnType::TypeId::UNION;
+ case hs2::TTypeId::USER_DEFINED_TYPE:
+ return ColumnType::TypeId::USER_DEFINED;
+ case hs2::TTypeId::DECIMAL_TYPE:
+ return ColumnType::TypeId::DECIMAL;
+ case hs2::TTypeId::NULL_TYPE:
+ return ColumnType::TypeId::NULL_TYPE;
+ case hs2::TTypeId::DATE_TYPE:
+ return ColumnType::TypeId::DATE;
+ case hs2::TTypeId::VARCHAR_TYPE:
+ return ColumnType::TypeId::VARCHAR;
+ case hs2::TTypeId::CHAR_TYPE:
+ return ColumnType::TypeId::CHAR;
+ default:
+ ARROW_LOG(WARNING) << "Unknown TTypeId " << type_id;
+ return ColumnType::TypeId::INVALID;
+ }
+}
+
+} // namespace hiveserver2
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift_internal.h b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift_internal.h
new file mode 100644
index 000000000..44b6f3642
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/thrift_internal.h
@@ -0,0 +1,91 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/dbi/hiveserver2/columnar_row_set.h"
+#include "arrow/dbi/hiveserver2/operation.h"
+#include "arrow/dbi/hiveserver2/service.h"
+#include "arrow/dbi/hiveserver2/types.h"
+
+#include "arrow/dbi/hiveserver2/ImpalaHiveServer2Service.h"
+#include "arrow/dbi/hiveserver2/TCLIService.h"
+
+namespace arrow {
+namespace hiveserver2 {
+
+// PIMPL structs.
+struct ColumnarRowSet::ColumnarRowSetImpl {
+ apache::hive::service::cli::thrift::TFetchResultsResp resp;
+};
+
+struct Operation::OperationImpl {
+ apache::hive::service::cli::thrift::TOperationHandle handle;
+ apache::hive::service::cli::thrift::TSessionHandle session_handle;
+};
+
+struct ThriftRPC {
+ std::unique_ptr<impala::ImpalaHiveServer2ServiceClient> client;
+};
+
+const std::string OperationStateToString(const Operation::State& state);
+
+const std::string TypeIdToString(const ColumnType::TypeId& type_id);
+
+// Functions for converting Thrift object to hs2client objects and vice-versa.
+apache::hive::service::cli::thrift::TFetchOrientation::type
+FetchOrientationToTFetchOrientation(FetchOrientation orientation);
+
+apache::hive::service::cli::thrift::TProtocolVersion::type
+ProtocolVersionToTProtocolVersion(ProtocolVersion protocol);
+
+Operation::State TOperationStateToOperationState(
+ const apache::hive::service::cli::thrift::TOperationState::type& tstate);
+
+Status TStatusToStatus(const apache::hive::service::cli::thrift::TStatus& tstatus);
+
+// Converts a TTypeDesc to a ColumnType. Currently only primitive types are supported.
+// The converted type is returned as a pointer to allow for polymorphism with ColumnType
+// and its subclasses.
+std::unique_ptr<ColumnType> TTypeDescToColumnType(
+ const apache::hive::service::cli::thrift::TTypeDesc& ttype_desc);
+
+ColumnType::TypeId TTypeIdToTypeId(
+ const apache::hive::service::cli::thrift::TTypeId::type& type_id);
+
+} // namespace hiveserver2
+} // namespace arrow
+
+#define TRY_RPC_OR_RETURN(rpc) \
+ do { \
+ try { \
+ (rpc); \
+ } catch (apache::thrift::TException & tx) { \
+ return Status::IOError(tx.what()); \
+ } \
+ } while (0)
+
+#define THRIFT_RETURN_NOT_OK(tstatus) \
+ do { \
+ if (tstatus.statusCode != hs2::TStatusCode::SUCCESS_STATUS && \
+ tstatus.statusCode != hs2::TStatusCode::SUCCESS_WITH_INFO_STATUS) { \
+ return TStatusToStatus(tstatus); \
+ } \
+ } while (0)
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/types.cc b/src/arrow/cpp/src/arrow/dbi/hiveserver2/types.cc
new file mode 100644
index 000000000..ef2a02ecb
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/types.cc
@@ -0,0 +1,45 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dbi/hiveserver2/types.h"
+
+#include "arrow/dbi/hiveserver2/thrift_internal.h"
+
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace hiveserver2 {
+
+const PrimitiveType* ColumnDesc::GetPrimitiveType() const {
+ return static_cast<PrimitiveType*>(type_.get());
+}
+
+const CharacterType* ColumnDesc::GetCharacterType() const {
+ DCHECK(type_->type_id() == ColumnType::TypeId::CHAR ||
+ type_->type_id() == ColumnType::TypeId::VARCHAR);
+ return static_cast<CharacterType*>(type_.get());
+}
+
+const DecimalType* ColumnDesc::GetDecimalType() const {
+ DCHECK(type_->type_id() == ColumnType::TypeId::DECIMAL);
+ return static_cast<DecimalType*>(type_.get());
+}
+
+std::string PrimitiveType::ToString() const { return TypeIdToString(type_id_); }
+
+} // namespace hiveserver2
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/types.h b/src/arrow/cpp/src/arrow/dbi/hiveserver2/types.h
new file mode 100644
index 000000000..38cebcc2e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/types.h
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <utility>
+
+namespace arrow {
+namespace hiveserver2 {
+
+// Represents a column's type.
+//
+// For now only PrimitiveType is implemented, as thase are the only types Impala will
+// currently return. In the future, nested types will be represented as other subclasses
+// of ColumnType containing ptrs to other ColumnTypes - for example, an ArrayType subclass
+// would contain a single ptr to another ColumnType representing the type of objects
+// stored in the array.
+class ColumnType {
+ public:
+ virtual ~ColumnType() = default;
+
+ // Maps directly to TTypeId in the HiveServer2 interface.
+ enum class TypeId {
+ BOOLEAN,
+ TINYINT,
+ SMALLINT,
+ INT,
+ BIGINT,
+ FLOAT,
+ DOUBLE,
+ STRING,
+ TIMESTAMP,
+ BINARY,
+ ARRAY,
+ MAP,
+ STRUCT,
+ UNION,
+ USER_DEFINED,
+ DECIMAL,
+ NULL_TYPE,
+ DATE,
+ VARCHAR,
+ CHAR,
+ INVALID,
+ };
+
+ virtual TypeId type_id() const = 0;
+ virtual std::string ToString() const = 0;
+};
+
+class PrimitiveType : public ColumnType {
+ public:
+ explicit PrimitiveType(const TypeId& type_id) : type_id_(type_id) {}
+
+ TypeId type_id() const override { return type_id_; }
+ std::string ToString() const override;
+
+ private:
+ const TypeId type_id_;
+};
+
+// Represents CHAR and VARCHAR types.
+class CharacterType : public PrimitiveType {
+ public:
+ CharacterType(const TypeId& type_id, int max_length)
+ : PrimitiveType(type_id), max_length_(max_length) {}
+
+ int max_length() const { return max_length_; }
+
+ private:
+ const int max_length_;
+};
+
+// Represents DECIMAL types.
+class DecimalType : public PrimitiveType {
+ public:
+ DecimalType(const TypeId& type_id, int precision, int scale)
+ : PrimitiveType(type_id), precision_(precision), scale_(scale) {}
+
+ int precision() const { return precision_; }
+ int scale() const { return scale_; }
+
+ private:
+ const int precision_;
+ const int scale_;
+};
+
+// Represents the metadata for a single column.
+class ColumnDesc {
+ public:
+ ColumnDesc(const std::string& column_name, std::unique_ptr<ColumnType> type,
+ int position, const std::string& comment)
+ : column_name_(column_name),
+ type_(move(type)),
+ position_(position),
+ comment_(comment) {}
+
+ const std::string& column_name() const { return column_name_; }
+ const ColumnType* type() const { return type_.get(); }
+ int position() const { return position_; }
+ const std::string& comment() const { return comment_; }
+
+ const PrimitiveType* GetPrimitiveType() const;
+ const CharacterType* GetCharacterType() const;
+ const DecimalType* GetDecimalType() const;
+
+ private:
+ const std::string column_name_;
+ std::unique_ptr<ColumnType> type_;
+ const int position_;
+ const std::string comment_;
+};
+
+} // namespace hiveserver2
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/util.cc b/src/arrow/cpp/src/arrow/dbi/hiveserver2/util.cc
new file mode 100644
index 000000000..772be4e38
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/util.cc
@@ -0,0 +1,250 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dbi/hiveserver2/util.h"
+
+#include <algorithm>
+#include <memory>
+#include <sstream>
+#include <vector>
+
+#include "arrow/dbi/hiveserver2/columnar_row_set.h"
+#include "arrow/dbi/hiveserver2/thrift_internal.h"
+
+#include "arrow/dbi/hiveserver2/TCLIService.h"
+#include "arrow/dbi/hiveserver2/TCLIService_types.h"
+
+#include "arrow/status.h"
+
+namespace hs2 = apache::hive::service::cli::thrift;
+using std::unique_ptr;
+
+namespace arrow {
+namespace hiveserver2 {
+
+// PrintResults
+namespace {
+
+const char kNullSymbol[] = "NULL";
+const char kTrueSymbol[] = "true";
+const char kFalseSymbol[] = "false";
+
+struct PrintInfo {
+ // The PrintInfo takes ownership of the Column ptr.
+ PrintInfo(Column* c, size_t m) : column(c), max_size(m) {}
+
+ unique_ptr<Column> column;
+ size_t max_size;
+};
+
+// Adds a horizontal line of '-'s, with '+'s at the column breaks.
+static void AddTableBreak(std::ostream& out, std::vector<PrintInfo>* columns) {
+ for (size_t i = 0; i < columns->size(); ++i) {
+ out << "+";
+ for (size_t j = 0; j < (*columns)[i].max_size + 2; ++j) {
+ out << "-";
+ }
+ }
+ out << "+\n";
+}
+
+// Returns the number of spaces needed to display n, i.e. the number of digits n has,
+// plus 1 if n is negative.
+static size_t NumSpaces(int64_t n) {
+ if (n < 0) {
+ return 1 + NumSpaces(-n);
+ } else if (n < 10) {
+ return 1;
+ } else {
+ return 1 + NumSpaces(n / 10);
+ }
+}
+
+// Returns the max size needed to display a column of integer type.
+template <typename T>
+static size_t GetIntMaxSize(T* column, const std::string& column_name) {
+ size_t max_size = column_name.size();
+ for (int i = 0; i < column->length(); ++i) {
+ if (!column->IsNull(i)) {
+ max_size = std::max(max_size, NumSpaces(column->data()[i]));
+ } else {
+ max_size = std::max(max_size, sizeof(kNullSymbol));
+ }
+ }
+ return max_size;
+}
+
+} // namespace
+
+void Util::PrintResults(const Operation* op, std::ostream& out) {
+ unique_ptr<ColumnarRowSet> results;
+ bool has_more_rows = true;
+ while (has_more_rows) {
+ Status s = op->Fetch(&results, &has_more_rows);
+ if (!s.ok()) {
+ out << s.ToString();
+ return;
+ }
+
+ std::vector<ColumnDesc> column_descs;
+ s = op->GetResultSetMetadata(&column_descs);
+
+ if (!s.ok()) {
+ out << s.ToString();
+ return;
+ } else if (column_descs.size() == 0) {
+ out << "No result set to print.\n";
+ return;
+ }
+
+ std::vector<PrintInfo> columns;
+ for (int i = 0; i < static_cast<int>(column_descs.size()); i++) {
+ const std::string column_name = column_descs[i].column_name();
+ switch (column_descs[i].type()->type_id()) {
+ case ColumnType::TypeId::BOOLEAN: {
+ BoolColumn* bool_col = results->GetBoolCol(i).release();
+
+ // The largest symbol is length 4 unless there is a FALSE, then is it
+ // kFalseSymbol.size() = 5.
+ size_t max_size = std::max(column_name.size(), sizeof(kTrueSymbol));
+ for (int j = 0; j < bool_col->length(); ++j) {
+ if (!bool_col->IsNull(j) && !bool_col->data()[j]) {
+ max_size = std::max(max_size, sizeof(kFalseSymbol));
+ break;
+ }
+ }
+
+ columns.emplace_back(bool_col, max_size);
+ break;
+ }
+ case ColumnType::TypeId::TINYINT: {
+ ByteColumn* byte_col = results->GetByteCol(i).release();
+ columns.emplace_back(byte_col, GetIntMaxSize(byte_col, column_name));
+ break;
+ }
+ case ColumnType::TypeId::SMALLINT: {
+ Int16Column* int16_col = results->GetInt16Col(i).release();
+ columns.emplace_back(int16_col, GetIntMaxSize(int16_col, column_name));
+ break;
+ }
+ case ColumnType::TypeId::INT: {
+ Int32Column* int32_col = results->GetInt32Col(i).release();
+ columns.emplace_back(int32_col, GetIntMaxSize(int32_col, column_name));
+ break;
+ }
+ case ColumnType::TypeId::BIGINT: {
+ Int64Column* int64_col = results->GetInt64Col(i).release();
+ columns.emplace_back(int64_col, GetIntMaxSize(int64_col, column_name));
+ break;
+ }
+ case ColumnType::TypeId::STRING: {
+ unique_ptr<StringColumn> string_col = results->GetStringCol(i);
+
+ size_t max_size = column_name.size();
+ for (int j = 0; j < string_col->length(); ++j) {
+ if (!string_col->IsNull(j)) {
+ max_size = std::max(max_size, string_col->data()[j].size());
+ } else {
+ max_size = std::max(max_size, sizeof(kNullSymbol));
+ }
+ }
+
+ columns.emplace_back(string_col.release(), max_size);
+ break;
+ }
+ case ColumnType::TypeId::BINARY:
+ columns.emplace_back(results->GetBinaryCol(i).release(), column_name.size());
+ break;
+ default: {
+ out << "Unrecognized ColumnType = " << column_descs[i].type()->ToString();
+ }
+ }
+ }
+
+ AddTableBreak(out, &columns);
+ for (size_t i = 0; i < columns.size(); ++i) {
+ out << "| " << column_descs[i].column_name() << " ";
+
+ int padding =
+ static_cast<int>(columns[i].max_size - column_descs[i].column_name().size());
+ while (padding > 0) {
+ out << " ";
+ --padding;
+ }
+ }
+ out << "|\n";
+ AddTableBreak(out, &columns);
+
+ for (int i = 0; i < columns[0].column->length(); ++i) {
+ for (size_t j = 0; j < columns.size(); ++j) {
+ std::stringstream value;
+
+ if (columns[j].column->IsNull(i)) {
+ value << kNullSymbol;
+ } else {
+ switch (column_descs[j].type()->type_id()) {
+ case ColumnType::TypeId::BOOLEAN:
+ if (reinterpret_cast<BoolColumn*>(columns[j].column.get())->data()[i]) {
+ value << kTrueSymbol;
+ } else {
+ value << kFalseSymbol;
+ }
+ break;
+ case ColumnType::TypeId::TINYINT:
+ // The cast prevents us from printing this as a char.
+ value << static_cast<int16_t>(
+ reinterpret_cast<ByteColumn*>(columns[j].column.get())->data()[i]);
+ break;
+ case ColumnType::TypeId::SMALLINT:
+ value << reinterpret_cast<Int16Column*>(columns[j].column.get())->data()[i];
+ break;
+ case ColumnType::TypeId::INT:
+ value << reinterpret_cast<Int32Column*>(columns[j].column.get())->data()[i];
+ break;
+ case ColumnType::TypeId::BIGINT:
+ value << reinterpret_cast<Int64Column*>(columns[j].column.get())->data()[i];
+ break;
+ case ColumnType::TypeId::STRING:
+ value
+ << reinterpret_cast<StringColumn*>(columns[j].column.get())->data()[i];
+ break;
+ case ColumnType::TypeId::BINARY:
+ value
+ << reinterpret_cast<BinaryColumn*>(columns[j].column.get())->data()[i];
+ break;
+ default:
+ value << "unrecognized type";
+ break;
+ }
+ }
+
+ std::string value_str = value.str();
+ out << "| " << value_str << " ";
+ int padding = static_cast<int>(columns[j].max_size - value_str.size());
+ while (padding > 0) {
+ out << " ";
+ --padding;
+ }
+ }
+ out << "|\n";
+ }
+ AddTableBreak(out, &columns);
+ }
+}
+
+} // namespace hiveserver2
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/dbi/hiveserver2/util.h b/src/arrow/cpp/src/arrow/dbi/hiveserver2/util.h
new file mode 100644
index 000000000..a17e7b228
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/dbi/hiveserver2/util.h
@@ -0,0 +1,36 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+
+#include "arrow/dbi/hiveserver2/operation.h"
+
+namespace arrow {
+namespace hiveserver2 {
+
+// Utility functions. Intended primary for testing purposes - clients should not
+// rely on stability of the behavior or API of these functions.
+class Util {
+ public:
+ // Fetches the operation's results and returns them in a nicely formatted string.
+ static void PrintResults(const Operation* op, std::ostream& out);
+};
+
+} // namespace hiveserver2
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/device.cc b/src/arrow/cpp/src/arrow/device.cc
new file mode 100644
index 000000000..1aead49bf
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/device.cc
@@ -0,0 +1,209 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/device.h"
+
+#include <cstring>
+#include <utility>
+
+#include "arrow/buffer.h"
+#include "arrow/io/memory.h"
+#include "arrow/result.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+MemoryManager::~MemoryManager() {}
+
+Device::~Device() {}
+
+#define COPY_BUFFER_SUCCESS(maybe_buffer) \
+ ((maybe_buffer).ok() && *(maybe_buffer) != nullptr)
+
+#define COPY_BUFFER_RETURN(maybe_buffer, to) \
+ if (!maybe_buffer.ok()) { \
+ return maybe_buffer; \
+ } \
+ if (COPY_BUFFER_SUCCESS(maybe_buffer)) { \
+ DCHECK_EQ(*(**maybe_buffer).device(), *to->device()); \
+ return maybe_buffer; \
+ }
+
+Result<std::shared_ptr<Buffer>> MemoryManager::CopyBuffer(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& to) {
+ const auto& from = buf->memory_manager();
+ auto maybe_buffer = to->CopyBufferFrom(buf, from);
+ COPY_BUFFER_RETURN(maybe_buffer, to);
+ // `to` doesn't support copying from `from`, try the other way
+ maybe_buffer = from->CopyBufferTo(buf, to);
+ COPY_BUFFER_RETURN(maybe_buffer, to);
+ if (!from->is_cpu() && !to->is_cpu()) {
+ // Try an intermediate view on the CPU
+ auto cpu_mm = default_cpu_memory_manager();
+ maybe_buffer = from->ViewBufferTo(buf, cpu_mm);
+ if (!COPY_BUFFER_SUCCESS(maybe_buffer)) {
+ // View failed, try a copy instead
+ // XXX should we have a MemoryManager::IsCopySupportedTo(MemoryManager)
+ // to avoid copying to CPU if copy from CPU to dest is unsupported?
+ maybe_buffer = from->CopyBufferTo(buf, cpu_mm);
+ }
+ if (COPY_BUFFER_SUCCESS(maybe_buffer)) {
+ // Copy from source to CPU succeeded, now try to copy from CPU into dest
+ maybe_buffer = to->CopyBufferFrom(*maybe_buffer, cpu_mm);
+ if (COPY_BUFFER_SUCCESS(maybe_buffer)) {
+ return maybe_buffer;
+ }
+ }
+ }
+
+ return Status::NotImplemented("Copying buffer from ", from->device()->ToString(),
+ " to ", to->device()->ToString(), " not supported");
+}
+
+Result<std::shared_ptr<Buffer>> MemoryManager::ViewBuffer(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& to) {
+ if (buf->memory_manager() == to) {
+ return buf;
+ }
+ const auto& from = buf->memory_manager();
+ auto maybe_buffer = to->ViewBufferFrom(buf, from);
+ COPY_BUFFER_RETURN(maybe_buffer, to);
+ // `to` doesn't support viewing from `from`, try the other way
+ maybe_buffer = from->ViewBufferTo(buf, to);
+ COPY_BUFFER_RETURN(maybe_buffer, to);
+
+ return Status::NotImplemented("Viewing buffer from ", from->device()->ToString(),
+ " on ", to->device()->ToString(), " not supported");
+}
+
+#undef COPY_BUFFER_RETURN
+#undef COPY_BUFFER_SUCCESS
+
+Result<std::shared_ptr<Buffer>> MemoryManager::CopyBufferFrom(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& from) {
+ return nullptr;
+}
+
+Result<std::shared_ptr<Buffer>> MemoryManager::CopyBufferTo(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& to) {
+ return nullptr;
+}
+
+Result<std::shared_ptr<Buffer>> MemoryManager::ViewBufferFrom(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& from) {
+ return nullptr;
+}
+
+Result<std::shared_ptr<Buffer>> MemoryManager::ViewBufferTo(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& to) {
+ return nullptr;
+}
+
+// ----------------------------------------------------------------------
+// CPU backend implementation
+
+namespace {
+const char kCPUDeviceTypeName[] = "arrow::CPUDevice";
+}
+
+std::shared_ptr<MemoryManager> CPUMemoryManager::Make(
+ const std::shared_ptr<Device>& device, MemoryPool* pool) {
+ return std::shared_ptr<MemoryManager>(new CPUMemoryManager(device, pool));
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> CPUMemoryManager::GetBufferReader(
+ std::shared_ptr<Buffer> buf) {
+ return std::make_shared<io::BufferReader>(std::move(buf));
+}
+
+Result<std::shared_ptr<io::OutputStream>> CPUMemoryManager::GetBufferWriter(
+ std::shared_ptr<Buffer> buf) {
+ return std::make_shared<io::FixedSizeBufferWriter>(std::move(buf));
+}
+
+Result<std::shared_ptr<Buffer>> CPUMemoryManager::AllocateBuffer(int64_t size) {
+ return ::arrow::AllocateBuffer(size, pool_);
+}
+
+Result<std::shared_ptr<Buffer>> CPUMemoryManager::CopyBufferFrom(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& from) {
+ if (!from->is_cpu()) {
+ return nullptr;
+ }
+ ARROW_ASSIGN_OR_RAISE(auto dest, ::arrow::AllocateBuffer(buf->size(), pool_));
+ if (buf->size() > 0) {
+ memcpy(dest->mutable_data(), buf->data(), static_cast<size_t>(buf->size()));
+ }
+ return std::move(dest);
+}
+
+Result<std::shared_ptr<Buffer>> CPUMemoryManager::ViewBufferFrom(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& from) {
+ if (!from->is_cpu()) {
+ return nullptr;
+ }
+ return buf;
+}
+
+Result<std::shared_ptr<Buffer>> CPUMemoryManager::CopyBufferTo(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& to) {
+ if (!to->is_cpu()) {
+ return nullptr;
+ }
+ ARROW_ASSIGN_OR_RAISE(auto dest, ::arrow::AllocateBuffer(buf->size(), pool_));
+ if (buf->size() > 0) {
+ memcpy(dest->mutable_data(), buf->data(), static_cast<size_t>(buf->size()));
+ }
+ return std::move(dest);
+}
+
+Result<std::shared_ptr<Buffer>> CPUMemoryManager::ViewBufferTo(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& to) {
+ if (!to->is_cpu()) {
+ return nullptr;
+ }
+ return buf;
+}
+
+std::shared_ptr<MemoryManager> default_cpu_memory_manager() {
+ static auto instance =
+ CPUMemoryManager::Make(CPUDevice::Instance(), default_memory_pool());
+ return instance;
+}
+
+std::shared_ptr<Device> CPUDevice::Instance() {
+ static auto instance = std::shared_ptr<Device>(new CPUDevice());
+ return instance;
+}
+
+const char* CPUDevice::type_name() const { return kCPUDeviceTypeName; }
+
+std::string CPUDevice::ToString() const { return "CPUDevice()"; }
+
+bool CPUDevice::Equals(const Device& other) const {
+ return other.type_name() == kCPUDeviceTypeName;
+}
+
+std::shared_ptr<MemoryManager> CPUDevice::memory_manager(MemoryPool* pool) {
+ return CPUMemoryManager::Make(Instance(), pool);
+}
+
+std::shared_ptr<MemoryManager> CPUDevice::default_memory_manager() {
+ return default_cpu_memory_manager();
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/device.h b/src/arrow/cpp/src/arrow/device.h
new file mode 100644
index 000000000..068be483e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/device.h
@@ -0,0 +1,226 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "arrow/io/type_fwd.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/compare.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class MemoryManager;
+
+/// \brief EXPERIMENTAL: Abstract interface for hardware devices
+///
+/// This object represents a device with access to some memory spaces.
+/// When handling a Buffer or raw memory address, it allows deciding in which
+/// context the raw memory address should be interpreted
+/// (e.g. CPU-accessible memory, or embedded memory on some particular GPU).
+class ARROW_EXPORT Device : public std::enable_shared_from_this<Device>,
+ public util::EqualityComparable<Device> {
+ public:
+ virtual ~Device();
+
+ /// \brief A shorthand for this device's type.
+ ///
+ /// The returned value is different for each device class, but is the
+ /// same for all instances of a given class. It can be used as a replacement
+ /// for RTTI.
+ virtual const char* type_name() const = 0;
+
+ /// \brief A human-readable description of the device.
+ ///
+ /// The returned value should be detailed enough to distinguish between
+ /// different instances, where necessary.
+ virtual std::string ToString() const = 0;
+
+ /// \brief Whether this instance points to the same device as another one.
+ virtual bool Equals(const Device&) const = 0;
+
+ /// \brief Whether this device is the main CPU device.
+ ///
+ /// This shorthand method is very useful when deciding whether a memory address
+ /// is CPU-accessible.
+ bool is_cpu() const { return is_cpu_; }
+
+ /// \brief Return a MemoryManager instance tied to this device
+ ///
+ /// The returned instance uses default parameters for this device type's
+ /// MemoryManager implementation. Some devices also allow constructing
+ /// MemoryManager instances with non-default parameters.
+ virtual std::shared_ptr<MemoryManager> default_memory_manager() = 0;
+
+ protected:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Device);
+ explicit Device(bool is_cpu = false) : is_cpu_(is_cpu) {}
+
+ bool is_cpu_;
+};
+
+/// \brief EXPERIMENTAL: An object that provides memory management primitives
+///
+/// A MemoryManager is always tied to a particular Device instance.
+/// It can also have additional parameters (such as a MemoryPool to
+/// allocate CPU memory).
+class ARROW_EXPORT MemoryManager : public std::enable_shared_from_this<MemoryManager> {
+ public:
+ virtual ~MemoryManager();
+
+ /// \brief The device this MemoryManager is tied to
+ const std::shared_ptr<Device>& device() const { return device_; }
+
+ /// \brief Whether this MemoryManager is tied to the main CPU device.
+ ///
+ /// This shorthand method is very useful when deciding whether a memory address
+ /// is CPU-accessible.
+ bool is_cpu() const { return device_->is_cpu(); }
+
+ /// \brief Create a RandomAccessFile to read a particular buffer.
+ ///
+ /// The given buffer must be tied to this MemoryManager.
+ ///
+ /// See also the Buffer::GetReader shorthand.
+ virtual Result<std::shared_ptr<io::RandomAccessFile>> GetBufferReader(
+ std::shared_ptr<Buffer> buf) = 0;
+
+ /// \brief Create a OutputStream to write to a particular buffer.
+ ///
+ /// The given buffer must be mutable and tied to this MemoryManager.
+ /// The returned stream object writes into the buffer's underlying memory
+ /// (but it won't resize it).
+ ///
+ /// See also the Buffer::GetWriter shorthand.
+ virtual Result<std::shared_ptr<io::OutputStream>> GetBufferWriter(
+ std::shared_ptr<Buffer> buf) = 0;
+
+ /// \brief Allocate a (mutable) Buffer
+ ///
+ /// The buffer will be allocated in the device's memory.
+ virtual Result<std::shared_ptr<Buffer>> AllocateBuffer(int64_t size) = 0;
+
+ // XXX Should this take a `const Buffer&` instead
+ /// \brief Copy a Buffer to a destination MemoryManager
+ ///
+ /// See also the Buffer::Copy shorthand.
+ static Result<std::shared_ptr<Buffer>> CopyBuffer(
+ const std::shared_ptr<Buffer>& source, const std::shared_ptr<MemoryManager>& to);
+
+ /// \brief Make a no-copy Buffer view in a destination MemoryManager
+ ///
+ /// See also the Buffer::View shorthand.
+ static Result<std::shared_ptr<Buffer>> ViewBuffer(
+ const std::shared_ptr<Buffer>& source, const std::shared_ptr<MemoryManager>& to);
+
+ protected:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(MemoryManager);
+
+ explicit MemoryManager(const std::shared_ptr<Device>& device) : device_(device) {}
+
+ // Default implementations always return nullptr, should be overridden
+ // by subclasses that support data transfer.
+ // (returning nullptr means unsupported copy / view)
+ // In CopyBufferFrom and ViewBufferFrom, the `from` parameter is guaranteed to
+ // be equal to `buf->memory_manager()`.
+ virtual Result<std::shared_ptr<Buffer>> CopyBufferFrom(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& from);
+ virtual Result<std::shared_ptr<Buffer>> CopyBufferTo(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& to);
+ virtual Result<std::shared_ptr<Buffer>> ViewBufferFrom(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& from);
+ virtual Result<std::shared_ptr<Buffer>> ViewBufferTo(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& to);
+
+ std::shared_ptr<Device> device_;
+};
+
+// ----------------------------------------------------------------------
+// CPU backend implementation
+
+class ARROW_EXPORT CPUDevice : public Device {
+ public:
+ const char* type_name() const override;
+ std::string ToString() const override;
+ bool Equals(const Device&) const override;
+
+ std::shared_ptr<MemoryManager> default_memory_manager() override;
+
+ /// \brief Return the global CPUDevice instance
+ static std::shared_ptr<Device> Instance();
+
+ /// \brief Create a MemoryManager
+ ///
+ /// The returned MemoryManager will use the given MemoryPool for allocations.
+ static std::shared_ptr<MemoryManager> memory_manager(MemoryPool* pool);
+
+ protected:
+ CPUDevice() : Device(true) {}
+};
+
+class ARROW_EXPORT CPUMemoryManager : public MemoryManager {
+ public:
+ Result<std::shared_ptr<io::RandomAccessFile>> GetBufferReader(
+ std::shared_ptr<Buffer> buf) override;
+ Result<std::shared_ptr<io::OutputStream>> GetBufferWriter(
+ std::shared_ptr<Buffer> buf) override;
+
+ Result<std::shared_ptr<Buffer>> AllocateBuffer(int64_t size) override;
+
+ /// \brief Return the MemoryPool associated with this MemoryManager.
+ MemoryPool* pool() const { return pool_; }
+
+ protected:
+ CPUMemoryManager(const std::shared_ptr<Device>& device, MemoryPool* pool)
+ : MemoryManager(device), pool_(pool) {}
+
+ static std::shared_ptr<MemoryManager> Make(const std::shared_ptr<Device>& device,
+ MemoryPool* pool = default_memory_pool());
+
+ Result<std::shared_ptr<Buffer>> CopyBufferFrom(
+ const std::shared_ptr<Buffer>& buf,
+ const std::shared_ptr<MemoryManager>& from) override;
+ Result<std::shared_ptr<Buffer>> CopyBufferTo(
+ const std::shared_ptr<Buffer>& buf,
+ const std::shared_ptr<MemoryManager>& to) override;
+ Result<std::shared_ptr<Buffer>> ViewBufferFrom(
+ const std::shared_ptr<Buffer>& buf,
+ const std::shared_ptr<MemoryManager>& from) override;
+ Result<std::shared_ptr<Buffer>> ViewBufferTo(
+ const std::shared_ptr<Buffer>& buf,
+ const std::shared_ptr<MemoryManager>& to) override;
+
+ MemoryPool* pool_;
+
+ friend std::shared_ptr<MemoryManager> CPUDevice::memory_manager(MemoryPool* pool);
+ friend ARROW_EXPORT std::shared_ptr<MemoryManager> default_cpu_memory_manager();
+};
+
+/// \brief Return the default CPU MemoryManager instance
+///
+/// The returned singleton instance uses the default MemoryPool.
+/// This function is a faster spelling of
+/// `CPUDevice::Instance()->default_memory_manager()`.
+ARROW_EXPORT
+std::shared_ptr<MemoryManager> default_cpu_memory_manager();
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/extension_type.cc b/src/arrow/cpp/src/arrow/extension_type.cc
new file mode 100644
index 000000000..e579b6910
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/extension_type.cc
@@ -0,0 +1,169 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/extension_type.h"
+
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+
+#include "arrow/array/util.h"
+#include "arrow/chunked_array.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+DataTypeLayout ExtensionType::layout() const { return storage_type_->layout(); }
+
+std::string ExtensionType::ToString() const {
+ std::stringstream ss;
+ ss << "extension<" << this->extension_name() << ">";
+ return ss.str();
+}
+
+std::shared_ptr<Array> ExtensionType::WrapArray(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<Array>& storage) {
+ DCHECK_EQ(type->id(), Type::EXTENSION);
+ const auto& ext_type = checked_cast<const ExtensionType&>(*type);
+ DCHECK_EQ(storage->type_id(), ext_type.storage_type()->id());
+ auto data = storage->data()->Copy();
+ data->type = type;
+ return ext_type.MakeArray(std::move(data));
+}
+
+std::shared_ptr<ChunkedArray> ExtensionType::WrapArray(
+ const std::shared_ptr<DataType>& type, const std::shared_ptr<ChunkedArray>& storage) {
+ DCHECK_EQ(type->id(), Type::EXTENSION);
+ const auto& ext_type = checked_cast<const ExtensionType&>(*type);
+ DCHECK_EQ(storage->type()->id(), ext_type.storage_type()->id());
+
+ ArrayVector out_chunks(storage->num_chunks());
+ for (int i = 0; i < storage->num_chunks(); i++) {
+ auto data = storage->chunk(i)->data()->Copy();
+ data->type = type;
+ out_chunks[i] = ext_type.MakeArray(std::move(data));
+ }
+ return std::make_shared<ChunkedArray>(std::move(out_chunks));
+}
+
+ExtensionArray::ExtensionArray(const std::shared_ptr<ArrayData>& data) { SetData(data); }
+
+ExtensionArray::ExtensionArray(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<Array>& storage) {
+ ARROW_CHECK_EQ(type->id(), Type::EXTENSION);
+ ARROW_CHECK(
+ storage->type()->Equals(*checked_cast<const ExtensionType&>(*type).storage_type()));
+ auto data = storage->data()->Copy();
+ // XXX This pointer is reverted below in SetData()...
+ data->type = type;
+ SetData(data);
+}
+
+void ExtensionArray::SetData(const std::shared_ptr<ArrayData>& data) {
+ ARROW_CHECK_EQ(data->type->id(), Type::EXTENSION);
+ this->Array::SetData(data);
+
+ auto storage_data = data->Copy();
+ storage_data->type = (static_cast<const ExtensionType&>(*data->type).storage_type());
+ storage_ = MakeArray(storage_data);
+}
+
+class ExtensionTypeRegistryImpl : public ExtensionTypeRegistry {
+ public:
+ ExtensionTypeRegistryImpl() {}
+
+ Status RegisterType(std::shared_ptr<ExtensionType> type) override {
+ std::lock_guard<std::mutex> lock(lock_);
+ std::string type_name = type->extension_name();
+ auto it = name_to_type_.find(type_name);
+ if (it != name_to_type_.end()) {
+ return Status::KeyError("A type extension with name ", type_name,
+ " already defined");
+ }
+ name_to_type_[type_name] = std::move(type);
+ return Status::OK();
+ }
+
+ Status UnregisterType(const std::string& type_name) override {
+ std::lock_guard<std::mutex> lock(lock_);
+ auto it = name_to_type_.find(type_name);
+ if (it == name_to_type_.end()) {
+ return Status::KeyError("No type extension with name ", type_name, " found");
+ }
+ name_to_type_.erase(it);
+ return Status::OK();
+ }
+
+ std::shared_ptr<ExtensionType> GetType(const std::string& type_name) override {
+ std::lock_guard<std::mutex> lock(lock_);
+ auto it = name_to_type_.find(type_name);
+ if (it == name_to_type_.end()) {
+ return nullptr;
+ } else {
+ return it->second;
+ }
+ return nullptr;
+ }
+
+ private:
+ std::mutex lock_;
+ std::unordered_map<std::string, std::shared_ptr<ExtensionType>> name_to_type_;
+};
+
+static std::shared_ptr<ExtensionTypeRegistry> g_registry;
+static std::once_flag registry_initialized;
+
+namespace internal {
+
+static void CreateGlobalRegistry() {
+ g_registry = std::make_shared<ExtensionTypeRegistryImpl>();
+}
+
+} // namespace internal
+
+std::shared_ptr<ExtensionTypeRegistry> ExtensionTypeRegistry::GetGlobalRegistry() {
+ std::call_once(registry_initialized, internal::CreateGlobalRegistry);
+ return g_registry;
+}
+
+Status RegisterExtensionType(std::shared_ptr<ExtensionType> type) {
+ auto registry = ExtensionTypeRegistry::GetGlobalRegistry();
+ return registry->RegisterType(type);
+}
+
+Status UnregisterExtensionType(const std::string& type_name) {
+ auto registry = ExtensionTypeRegistry::GetGlobalRegistry();
+ return registry->UnregisterType(type_name);
+}
+
+std::shared_ptr<ExtensionType> GetExtensionType(const std::string& type_name) {
+ auto registry = ExtensionTypeRegistry::GetGlobalRegistry();
+ return registry->GetType(type_name);
+}
+
+extern const char kExtensionTypeKeyName[] = "ARROW:extension:name";
+extern const char kExtensionMetadataKeyName[] = "ARROW:extension:metadata";
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/extension_type.h b/src/arrow/cpp/src/arrow/extension_type.h
new file mode 100644
index 000000000..39cbc805a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/extension_type.h
@@ -0,0 +1,161 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// User-defined extension types.
+/// \since 0.13.0
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/data.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+/// \brief The base class for custom / user-defined types.
+class ARROW_EXPORT ExtensionType : public DataType {
+ public:
+ static constexpr Type::type type_id = Type::EXTENSION;
+
+ static constexpr const char* type_name() { return "extension"; }
+
+ /// \brief The type of array used to represent this extension type's data
+ const std::shared_ptr<DataType>& storage_type() const { return storage_type_; }
+
+ DataTypeLayout layout() const override;
+
+ std::string ToString() const override;
+
+ std::string name() const override { return "extension"; }
+
+ /// \brief Unique name of extension type used to identify type for
+ /// serialization
+ /// \return the string name of the extension
+ virtual std::string extension_name() const = 0;
+
+ /// \brief Determine if two instances of the same extension types are
+ /// equal. Invoked from ExtensionType::Equals
+ /// \param[in] other the type to compare this type with
+ /// \return bool true if type instances are equal
+ virtual bool ExtensionEquals(const ExtensionType& other) const = 0;
+
+ /// \brief Wrap built-in Array type in a user-defined ExtensionArray instance
+ /// \param[in] data the physical storage for the extension type
+ virtual std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const = 0;
+
+ /// \brief Create an instance of the ExtensionType given the actual storage
+ /// type and the serialized representation
+ /// \param[in] storage_type the physical storage type of the extension
+ /// \param[in] serialized_data the serialized representation produced by
+ /// Serialize
+ virtual Result<std::shared_ptr<DataType>> Deserialize(
+ std::shared_ptr<DataType> storage_type,
+ const std::string& serialized_data) const = 0;
+
+ /// \brief Create a serialized representation of the extension type's
+ /// metadata. The storage type will be handled automatically in IPC code
+ /// paths
+ /// \return the serialized representation
+ virtual std::string Serialize() const = 0;
+
+ /// \brief Wrap the given storage array as an extension array
+ static std::shared_ptr<Array> WrapArray(const std::shared_ptr<DataType>& ext_type,
+ const std::shared_ptr<Array>& storage);
+
+ /// \brief Wrap the given chunked storage array as a chunked extension array
+ static std::shared_ptr<ChunkedArray> WrapArray(
+ const std::shared_ptr<DataType>& ext_type,
+ const std::shared_ptr<ChunkedArray>& storage);
+
+ protected:
+ explicit ExtensionType(std::shared_ptr<DataType> storage_type)
+ : DataType(Type::EXTENSION), storage_type_(storage_type) {}
+
+ std::shared_ptr<DataType> storage_type_;
+};
+
+/// \brief Base array class for user-defined extension types
+class ARROW_EXPORT ExtensionArray : public Array {
+ public:
+ /// \brief Construct an ExtensionArray from an ArrayData.
+ ///
+ /// The ArrayData must have the right ExtensionType.
+ explicit ExtensionArray(const std::shared_ptr<ArrayData>& data);
+
+ /// \brief Construct an ExtensionArray from a type and the underlying storage.
+ ExtensionArray(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<Array>& storage);
+
+ const ExtensionType* extension_type() const {
+ return internal::checked_cast<const ExtensionType*>(data_->type.get());
+ }
+
+ /// \brief The physical storage for the extension array
+ const std::shared_ptr<Array>& storage() const { return storage_; }
+
+ protected:
+ void SetData(const std::shared_ptr<ArrayData>& data);
+ std::shared_ptr<Array> storage_;
+};
+
+class ARROW_EXPORT ExtensionTypeRegistry {
+ public:
+ /// \brief Provide access to the global registry to allow code to control for
+ /// race conditions in registry teardown when some types need to be
+ /// unregistered and destroyed first
+ static std::shared_ptr<ExtensionTypeRegistry> GetGlobalRegistry();
+
+ virtual ~ExtensionTypeRegistry() = default;
+
+ virtual Status RegisterType(std::shared_ptr<ExtensionType> type) = 0;
+ virtual Status UnregisterType(const std::string& type_name) = 0;
+ virtual std::shared_ptr<ExtensionType> GetType(const std::string& type_name) = 0;
+};
+
+/// \brief Register an extension type globally. The name returned by the type's
+/// extension_name() method should be unique. This method is thread-safe
+/// \param[in] type an instance of the extension type
+/// \return Status
+ARROW_EXPORT
+Status RegisterExtensionType(std::shared_ptr<ExtensionType> type);
+
+/// \brief Delete an extension type from the global registry. This method is
+/// thread-safe
+/// \param[in] type_name the unique name of a registered extension type
+/// \return Status error if the type name is unknown
+ARROW_EXPORT
+Status UnregisterExtensionType(const std::string& type_name);
+
+/// \brief Retrieve an extension type from the global registry. Returns nullptr
+/// if not found. This method is thread-safe
+/// \return the globally-registered extension type
+ARROW_EXPORT
+std::shared_ptr<ExtensionType> GetExtensionType(const std::string& type_name);
+
+ARROW_EXPORT extern const char kExtensionTypeKeyName[];
+ARROW_EXPORT extern const char kExtensionMetadataKeyName[];
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/extension_type_test.cc b/src/arrow/cpp/src/arrow/extension_type_test.cc
new file mode 100644
index 000000000..31222d748
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/extension_type_test.cc
@@ -0,0 +1,336 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array/array_nested.h"
+#include "arrow/array/util.h"
+#include "arrow/extension_type.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/options.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/testing/extension_type.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+class Parametric1Array : public ExtensionArray {
+ public:
+ using ExtensionArray::ExtensionArray;
+};
+
+class Parametric2Array : public ExtensionArray {
+ public:
+ using ExtensionArray::ExtensionArray;
+};
+
+// A parametric type where the extension_name() is always the same
+class Parametric1Type : public ExtensionType {
+ public:
+ explicit Parametric1Type(int32_t parameter)
+ : ExtensionType(int32()), parameter_(parameter) {}
+
+ int32_t parameter() const { return parameter_; }
+
+ std::string extension_name() const override { return "parametric-type-1"; }
+
+ bool ExtensionEquals(const ExtensionType& other) const override {
+ const auto& other_ext = static_cast<const ExtensionType&>(other);
+ if (other_ext.extension_name() != this->extension_name()) {
+ return false;
+ }
+ return this->parameter() == static_cast<const Parametric1Type&>(other).parameter();
+ }
+
+ std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
+ return std::make_shared<Parametric1Array>(data);
+ }
+
+ Result<std::shared_ptr<DataType>> Deserialize(
+ std::shared_ptr<DataType> storage_type,
+ const std::string& serialized) const override {
+ DCHECK_EQ(4, serialized.size());
+ const int32_t parameter = *reinterpret_cast<const int32_t*>(serialized.data());
+ DCHECK(storage_type->Equals(int32()));
+ return std::make_shared<Parametric1Type>(parameter);
+ }
+
+ std::string Serialize() const override {
+ std::string result(" ");
+ memcpy(&result[0], &parameter_, sizeof(int32_t));
+ return result;
+ }
+
+ private:
+ int32_t parameter_;
+};
+
+// A parametric type where the extension_name() is different for each
+// parameter, and must be separately registered
+class Parametric2Type : public ExtensionType {
+ public:
+ explicit Parametric2Type(int32_t parameter)
+ : ExtensionType(int32()), parameter_(parameter) {}
+
+ int32_t parameter() const { return parameter_; }
+
+ std::string extension_name() const override {
+ std::stringstream ss;
+ ss << "parametric-type-2<param=" << parameter_ << ">";
+ return ss.str();
+ }
+
+ bool ExtensionEquals(const ExtensionType& other) const override {
+ const auto& other_ext = static_cast<const ExtensionType&>(other);
+ if (other_ext.extension_name() != this->extension_name()) {
+ return false;
+ }
+ return this->parameter() == static_cast<const Parametric2Type&>(other).parameter();
+ }
+
+ std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
+ return std::make_shared<Parametric2Array>(data);
+ }
+
+ Result<std::shared_ptr<DataType>> Deserialize(
+ std::shared_ptr<DataType> storage_type,
+ const std::string& serialized) const override {
+ DCHECK_EQ(4, serialized.size());
+ const int32_t parameter = *reinterpret_cast<const int32_t*>(serialized.data());
+ DCHECK(storage_type->Equals(int32()));
+ return std::make_shared<Parametric2Type>(parameter);
+ }
+
+ std::string Serialize() const override {
+ std::string result(" ");
+ memcpy(&result[0], &parameter_, sizeof(int32_t));
+ return result;
+ }
+
+ private:
+ int32_t parameter_;
+};
+
+// An extension type with a non-primitive storage type
+class ExtStructArray : public ExtensionArray {
+ public:
+ using ExtensionArray::ExtensionArray;
+};
+
+class ExtStructType : public ExtensionType {
+ public:
+ ExtStructType()
+ : ExtensionType(
+ struct_({::arrow::field("a", int64()), ::arrow::field("b", float64())})) {}
+
+ std::string extension_name() const override { return "ext-struct-type"; }
+
+ bool ExtensionEquals(const ExtensionType& other) const override {
+ const auto& other_ext = static_cast<const ExtensionType&>(other);
+ if (other_ext.extension_name() != this->extension_name()) {
+ return false;
+ }
+ return true;
+ }
+
+ std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
+ return std::make_shared<ExtStructArray>(data);
+ }
+
+ Result<std::shared_ptr<DataType>> Deserialize(
+ std::shared_ptr<DataType> storage_type,
+ const std::string& serialized) const override {
+ if (serialized != "ext-struct-type-unique-code") {
+ return Status::Invalid("Type identifier did not match");
+ }
+ return std::make_shared<ExtStructType>();
+ }
+
+ std::string Serialize() const override { return "ext-struct-type-unique-code"; }
+};
+
+class TestExtensionType : public ::testing::Test {
+ public:
+ void SetUp() { ASSERT_OK(RegisterExtensionType(std::make_shared<UuidType>())); }
+
+ void TearDown() {
+ if (GetExtensionType("uuid")) {
+ ASSERT_OK(UnregisterExtensionType("uuid"));
+ }
+ }
+};
+
+TEST_F(TestExtensionType, ExtensionTypeTest) {
+ auto type_not_exist = GetExtensionType("uuid-unknown");
+ ASSERT_EQ(type_not_exist, nullptr);
+
+ auto registered_type = GetExtensionType("uuid");
+ ASSERT_NE(registered_type, nullptr);
+
+ auto type = uuid();
+ ASSERT_EQ(type->id(), Type::EXTENSION);
+
+ const auto& ext_type = static_cast<const ExtensionType&>(*type);
+ std::string serialized = ext_type.Serialize();
+
+ ASSERT_OK_AND_ASSIGN(auto deserialized,
+ ext_type.Deserialize(fixed_size_binary(16), serialized));
+ ASSERT_TRUE(deserialized->Equals(*type));
+ ASSERT_FALSE(deserialized->Equals(*fixed_size_binary(16)));
+}
+
+auto RoundtripBatch = [](const std::shared_ptr<RecordBatch>& batch,
+ std::shared_ptr<RecordBatch>* out) {
+ ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create());
+ ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(),
+ out_stream.get()));
+
+ ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish());
+
+ io::BufferReader reader(complete_ipc_stream);
+ std::shared_ptr<RecordBatchReader> batch_reader;
+ ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader));
+ ASSERT_OK(batch_reader->ReadNext(out));
+};
+
+TEST_F(TestExtensionType, IpcRoundtrip) {
+ auto ext_arr = ExampleUuid();
+ auto batch = RecordBatch::Make(schema({field("f0", uuid())}), 4, {ext_arr});
+
+ std::shared_ptr<RecordBatch> read_batch;
+ RoundtripBatch(batch, &read_batch);
+ CompareBatch(*batch, *read_batch, false /* compare_metadata */);
+
+ // Wrap type in a ListArray and ensure it also makes it
+ auto offsets_arr = ArrayFromJSON(int32(), "[0, 0, 2, 4]");
+ ASSERT_OK_AND_ASSIGN(auto list_arr, ListArray::FromArrays(*offsets_arr, *ext_arr));
+ batch = RecordBatch::Make(schema({field("f0", list(uuid()))}), 3, {list_arr});
+ RoundtripBatch(batch, &read_batch);
+ CompareBatch(*batch, *read_batch, false /* compare_metadata */);
+}
+
+TEST_F(TestExtensionType, UnrecognizedExtension) {
+ auto ext_arr = ExampleUuid();
+ auto batch = RecordBatch::Make(schema({field("f0", uuid())}), 4, {ext_arr});
+
+ auto storage_arr = static_cast<const ExtensionArray&>(*ext_arr).storage();
+
+ // Write full IPC stream including schema, then unregister type, then read
+ // and ensure that a plain instance of the storage type is created
+ ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create());
+ ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(),
+ out_stream.get()));
+
+ ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish());
+
+ ASSERT_OK(UnregisterExtensionType("uuid"));
+ auto ext_metadata =
+ key_value_metadata({{"ARROW:extension:name", "uuid"},
+ {"ARROW:extension:metadata", "uuid-serialized"}});
+ auto ext_field = field("f0", fixed_size_binary(16), true, ext_metadata);
+ auto batch_no_ext = RecordBatch::Make(schema({ext_field}), 4, {storage_arr});
+
+ io::BufferReader reader(complete_ipc_stream);
+ std::shared_ptr<RecordBatchReader> batch_reader;
+ ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader));
+ std::shared_ptr<RecordBatch> read_batch;
+ ASSERT_OK(batch_reader->ReadNext(&read_batch));
+ CompareBatch(*batch_no_ext, *read_batch);
+}
+
+std::shared_ptr<Array> ExampleParametric(std::shared_ptr<DataType> type,
+ const std::string& json_data) {
+ auto arr = ArrayFromJSON(int32(), json_data);
+ auto ext_data = arr->data()->Copy();
+ ext_data->type = type;
+ return MakeArray(ext_data);
+}
+
+TEST_F(TestExtensionType, ParametricTypes) {
+ auto p1_type = std::make_shared<Parametric1Type>(6);
+ auto p1 = ExampleParametric(p1_type, "[null, 1, 2, 3]");
+
+ auto p2_type = std::make_shared<Parametric1Type>(12);
+ auto p2 = ExampleParametric(p2_type, "[2, null, 3, 4]");
+
+ auto p3_type = std::make_shared<Parametric2Type>(2);
+ auto p3 = ExampleParametric(p3_type, "[5, 6, 7, 8]");
+
+ auto p4_type = std::make_shared<Parametric2Type>(3);
+ auto p4 = ExampleParametric(p4_type, "[5, 6, 7, 9]");
+
+ ASSERT_OK(RegisterExtensionType(std::make_shared<Parametric1Type>(-1)));
+ ASSERT_OK(RegisterExtensionType(p3_type));
+ ASSERT_OK(RegisterExtensionType(p4_type));
+
+ auto batch = RecordBatch::Make(schema({field("f0", p1_type), field("f1", p2_type),
+ field("f2", p3_type), field("f3", p4_type)}),
+ 4, {p1, p2, p3, p4});
+
+ std::shared_ptr<RecordBatch> read_batch;
+ RoundtripBatch(batch, &read_batch);
+ CompareBatch(*batch, *read_batch, false /* compare_metadata */);
+}
+
+TEST_F(TestExtensionType, ParametricEquals) {
+ auto p1_type = std::make_shared<Parametric1Type>(6);
+ auto p2_type = std::make_shared<Parametric1Type>(6);
+ auto p3_type = std::make_shared<Parametric1Type>(3);
+
+ ASSERT_TRUE(p1_type->Equals(p2_type));
+ ASSERT_FALSE(p1_type->Equals(p3_type));
+
+ ASSERT_EQ(p1_type->fingerprint(), "");
+}
+
+std::shared_ptr<Array> ExampleStruct() {
+ auto ext_type = std::make_shared<ExtStructType>();
+ auto storage_type = ext_type->storage_type();
+ auto arr = ArrayFromJSON(storage_type, "[[1, 0.1], [2, 0.2]]");
+
+ auto ext_data = arr->data()->Copy();
+ ext_data->type = ext_type;
+ return MakeArray(ext_data);
+}
+
+TEST_F(TestExtensionType, ValidateExtensionArray) {
+ auto ext_arr1 = ExampleUuid();
+ auto p1_type = std::make_shared<Parametric1Type>(6);
+ auto ext_arr2 = ExampleParametric(p1_type, "[null, 1, 2, 3]");
+ auto ext_arr3 = ExampleStruct();
+ auto ext_arr4 = ExampleComplex128();
+
+ ASSERT_OK(ext_arr1->ValidateFull());
+ ASSERT_OK(ext_arr2->ValidateFull());
+ ASSERT_OK(ext_arr3->ValidateFull());
+ ASSERT_OK(ext_arr4->ValidateFull());
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/CMakeLists.txt b/src/arrow/cpp/src/arrow/filesystem/CMakeLists.txt
new file mode 100644
index 000000000..67ebe5489
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/CMakeLists.txt
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Headers: top level
+arrow_install_all_headers("arrow/filesystem")
+
+# pkg-config support
+arrow_add_pkg_config("arrow-filesystem")
+
+add_arrow_test(filesystem-test
+ SOURCES
+ filesystem_test.cc
+ localfs_test.cc
+ EXTRA_LABELS
+ filesystem)
+
+if(ARROW_GCS)
+ add_arrow_test(gcsfs_test EXTRA_LABELS filesystem)
+endif()
+
+if(ARROW_S3)
+ add_arrow_test(s3fs_test EXTRA_LABELS filesystem)
+ if(TARGET arrow-s3fs-test)
+ set(ARROW_S3FS_TEST_COMPILE_DEFINITIONS ${ARROW_BOOST_PROCESS_COMPILE_DEFINITIONS})
+ get_target_property(AWS_CPP_SDK_S3_TYPE aws-cpp-sdk-s3 TYPE)
+ # We need to initialize AWS C++ SDK for direct use (not via
+ # arrow::fs::S3FileSystem) in arrow-s3fs-test if we use static AWS
+ # C++ SDK. Because AWS C++ SDK has internal static variables that
+ # aren't shared in libarrow and arrow-s3fs-test. It means that
+ # arrow::fs::InitializeS3() doesn't initialize AWS C++ SDK that is
+ # directly used in arrow-s3fs-test.
+ #
+ # But it seems that internal static variables in AWS C++ SDK are
+ # shared on macOS even if we link static AWS C++ SDK to both
+ # libarrow and arrow-s3fs-test. So we don't need to initialize AWS
+ # C++ SDK in arrow-s3fs-test on macOS.
+ if(AWS_CPP_SDK_S3_TYPE STREQUAL "STATIC_LIBRARY" AND NOT APPLE)
+ list(APPEND ARROW_S3FS_TEST_COMPILE_DEFINITIONS "AWS_CPP_SDK_S3_NOT_SHARED")
+ endif()
+ target_compile_definitions(arrow-s3fs-test
+ PRIVATE ${ARROW_S3FS_TEST_COMPILE_DEFINITIONS})
+ endif()
+
+ if(ARROW_BUILD_TESTS)
+ add_executable(arrow-s3fs-narrative-test s3fs_narrative_test.cc)
+ target_link_libraries(arrow-s3fs-narrative-test ${ARROW_TEST_LINK_LIBS}
+ ${GFLAGS_LIBRARIES} GTest::gtest)
+ add_dependencies(arrow-tests arrow-s3fs-narrative-test)
+ endif()
+
+ if(ARROW_BUILD_BENCHMARKS AND ARROW_PARQUET)
+ add_arrow_benchmark(s3fs_benchmark PREFIX "arrow-filesystem")
+ target_compile_definitions(arrow-filesystem-s3fs-benchmark
+ PRIVATE ${ARROW_BOOST_PROCESS_COMPILE_DEFINITIONS})
+ if(ARROW_TEST_LINKAGE STREQUAL "static")
+ target_link_libraries(arrow-filesystem-s3fs-benchmark PRIVATE parquet_static)
+ else()
+ target_link_libraries(arrow-filesystem-s3fs-benchmark PRIVATE parquet_shared)
+ endif()
+ endif()
+endif()
+
+if(ARROW_HDFS)
+ add_arrow_test(hdfs_test EXTRA_LABELS filesystem)
+endif()
diff --git a/src/arrow/cpp/src/arrow/filesystem/api.h b/src/arrow/cpp/src/arrow/filesystem/api.h
new file mode 100644
index 000000000..5b0c97d15
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/api.h
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/util/config.h" // IWYU pragma: export
+
+#include "arrow/filesystem/filesystem.h" // IWYU pragma: export
+#include "arrow/filesystem/hdfs.h" // IWYU pragma: export
+#include "arrow/filesystem/localfs.h" // IWYU pragma: export
+#include "arrow/filesystem/mockfs.h" // IWYU pragma: export
+#ifdef ARROW_S3
+#include "arrow/filesystem/s3fs.h" // IWYU pragma: export
+#endif
diff --git a/src/arrow/cpp/src/arrow/filesystem/arrow-filesystem.pc.in b/src/arrow/cpp/src/arrow/filesystem/arrow-filesystem.pc.in
new file mode 100644
index 000000000..4fcc6244f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/arrow-filesystem.pc.in
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+Name: Apache Arrow Filesystem
+Description: Filesystem API for accessing local and remote filesystems
+Version: @ARROW_VERSION@
+Requires: arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/filesystem.cc b/src/arrow/cpp/src/arrow/filesystem/filesystem.cc
new file mode 100644
index 000000000..fbe8b1f17
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/filesystem.cc
@@ -0,0 +1,767 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <sstream>
+#include <utility>
+
+#include "arrow/util/config.h"
+
+#include "arrow/filesystem/filesystem.h"
+#ifdef ARROW_HDFS
+#include "arrow/filesystem/hdfs.h"
+#endif
+#ifdef ARROW_S3
+#include "arrow/filesystem/s3fs.h"
+#endif
+#include "arrow/filesystem/localfs.h"
+#include "arrow/filesystem/mockfs.h"
+#include "arrow/filesystem/path_util.h"
+#include "arrow/filesystem/util_internal.h"
+#include "arrow/io/slow.h"
+#include "arrow/io/util_internal.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/parallel.h"
+#include "arrow/util/uri.h"
+#include "arrow/util/vector.h"
+#include "arrow/util/windows_fixup.h"
+
+namespace arrow {
+
+using internal::checked_pointer_cast;
+using internal::TaskHints;
+using internal::Uri;
+using io::internal::SubmitIO;
+
+namespace fs {
+
+using internal::ConcatAbstractPath;
+using internal::EnsureTrailingSlash;
+using internal::GetAbstractPathParent;
+using internal::kSep;
+using internal::RemoveLeadingSlash;
+using internal::RemoveTrailingSlash;
+using internal::ToSlashes;
+
+std::string ToString(FileType ftype) {
+ switch (ftype) {
+ case FileType::NotFound:
+ return "not-found";
+ case FileType::Unknown:
+ return "unknown";
+ case FileType::File:
+ return "file";
+ case FileType::Directory:
+ return "directory";
+ default:
+ ARROW_LOG(FATAL) << "Invalid FileType value: " << static_cast<int>(ftype);
+ return "???";
+ }
+}
+
+// For googletest
+ARROW_EXPORT std::ostream& operator<<(std::ostream& os, FileType ftype) {
+#define FILE_TYPE_CASE(value_name) \
+ case FileType::value_name: \
+ os << "FileType::" ARROW_STRINGIFY(value_name); \
+ break;
+
+ switch (ftype) {
+ FILE_TYPE_CASE(NotFound)
+ FILE_TYPE_CASE(Unknown)
+ FILE_TYPE_CASE(File)
+ FILE_TYPE_CASE(Directory)
+ default:
+ ARROW_LOG(FATAL) << "Invalid FileType value: " << static_cast<int>(ftype);
+ }
+
+#undef FILE_TYPE_CASE
+ return os;
+}
+
+std::string FileInfo::base_name() const {
+ return internal::GetAbstractPathParent(path_).second;
+}
+
+std::string FileInfo::dir_name() const {
+ return internal::GetAbstractPathParent(path_).first;
+}
+
+// Debug helper
+std::string FileInfo::ToString() const {
+ std::stringstream os;
+ os << *this;
+ return os.str();
+}
+
+std::ostream& operator<<(std::ostream& os, const FileInfo& info) {
+ return os << "FileInfo(" << info.type() << ", " << info.path() << ")";
+}
+
+std::string FileInfo::extension() const {
+ return internal::GetAbstractPathExtension(path_);
+}
+
+//////////////////////////////////////////////////////////////////////////
+// FileSystem default method implementations
+
+FileSystem::~FileSystem() {}
+
+Result<std::string> FileSystem::NormalizePath(std::string path) { return path; }
+
+Result<std::vector<FileInfo>> FileSystem::GetFileInfo(
+ const std::vector<std::string>& paths) {
+ std::vector<FileInfo> res;
+ res.reserve(paths.size());
+ for (const auto& path : paths) {
+ ARROW_ASSIGN_OR_RAISE(FileInfo info, GetFileInfo(path));
+ res.push_back(std::move(info));
+ }
+ return res;
+}
+
+namespace {
+
+template <typename DeferredFunc>
+auto FileSystemDefer(FileSystem* fs, bool synchronous, DeferredFunc&& func)
+ -> decltype(DeferNotOk(
+ fs->io_context().executor()->Submit(func, std::shared_ptr<FileSystem>{}))) {
+ auto self = fs->shared_from_this();
+ if (synchronous) {
+ return std::forward<DeferredFunc>(func)(std::move(self));
+ }
+ return DeferNotOk(io::internal::SubmitIO(
+ fs->io_context(), std::forward<DeferredFunc>(func), std::move(self)));
+}
+
+} // namespace
+
+Future<std::vector<FileInfo>> FileSystem::GetFileInfoAsync(
+ const std::vector<std::string>& paths) {
+ return FileSystemDefer(
+ this, default_async_is_sync_,
+ [paths](std::shared_ptr<FileSystem> self) { return self->GetFileInfo(paths); });
+}
+
+FileInfoGenerator FileSystem::GetFileInfoGenerator(const FileSelector& select) {
+ auto fut = FileSystemDefer(
+ this, default_async_is_sync_,
+ [select](std::shared_ptr<FileSystem> self) { return self->GetFileInfo(select); });
+ return MakeSingleFutureGenerator(std::move(fut));
+}
+
+Status FileSystem::DeleteFiles(const std::vector<std::string>& paths) {
+ Status st = Status::OK();
+ for (const auto& path : paths) {
+ st &= DeleteFile(path);
+ }
+ return st;
+}
+
+namespace {
+
+Status ValidateInputFileInfo(const FileInfo& info) {
+ if (info.type() == FileType::NotFound) {
+ return internal::PathNotFound(info.path());
+ }
+ if (info.type() != FileType::File && info.type() != FileType::Unknown) {
+ return internal::NotAFile(info.path());
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Result<std::shared_ptr<io::InputStream>> FileSystem::OpenInputStream(
+ const FileInfo& info) {
+ RETURN_NOT_OK(ValidateInputFileInfo(info));
+ return OpenInputStream(info.path());
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> FileSystem::OpenInputFile(
+ const FileInfo& info) {
+ RETURN_NOT_OK(ValidateInputFileInfo(info));
+ return OpenInputFile(info.path());
+}
+
+Future<std::shared_ptr<io::InputStream>> FileSystem::OpenInputStreamAsync(
+ const std::string& path) {
+ return FileSystemDefer(
+ this, default_async_is_sync_,
+ [path](std::shared_ptr<FileSystem> self) { return self->OpenInputStream(path); });
+}
+
+Future<std::shared_ptr<io::InputStream>> FileSystem::OpenInputStreamAsync(
+ const FileInfo& info) {
+ RETURN_NOT_OK(ValidateInputFileInfo(info));
+ return FileSystemDefer(
+ this, default_async_is_sync_,
+ [info](std::shared_ptr<FileSystem> self) { return self->OpenInputStream(info); });
+}
+
+Future<std::shared_ptr<io::RandomAccessFile>> FileSystem::OpenInputFileAsync(
+ const std::string& path) {
+ return FileSystemDefer(
+ this, default_async_is_sync_,
+ [path](std::shared_ptr<FileSystem> self) { return self->OpenInputFile(path); });
+}
+
+Future<std::shared_ptr<io::RandomAccessFile>> FileSystem::OpenInputFileAsync(
+ const FileInfo& info) {
+ RETURN_NOT_OK(ValidateInputFileInfo(info));
+ return FileSystemDefer(
+ this, default_async_is_sync_,
+ [info](std::shared_ptr<FileSystem> self) { return self->OpenInputFile(info); });
+}
+
+Result<std::shared_ptr<io::OutputStream>> FileSystem::OpenOutputStream(
+ const std::string& path) {
+ return OpenOutputStream(path, std::shared_ptr<const KeyValueMetadata>{});
+}
+
+Result<std::shared_ptr<io::OutputStream>> FileSystem::OpenAppendStream(
+ const std::string& path) {
+ ARROW_SUPPRESS_DEPRECATION_WARNING
+ return OpenAppendStream(path, std::shared_ptr<const KeyValueMetadata>{});
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
+}
+
+//////////////////////////////////////////////////////////////////////////
+// SubTreeFileSystem implementation
+
+SubTreeFileSystem::SubTreeFileSystem(const std::string& base_path,
+ std::shared_ptr<FileSystem> base_fs)
+ : FileSystem(base_fs->io_context()),
+ base_path_(NormalizeBasePath(base_path, base_fs).ValueOrDie()),
+ base_fs_(base_fs) {}
+
+SubTreeFileSystem::~SubTreeFileSystem() {}
+
+Result<std::string> SubTreeFileSystem::NormalizeBasePath(
+ std::string base_path, const std::shared_ptr<FileSystem>& base_fs) {
+ ARROW_ASSIGN_OR_RAISE(base_path, base_fs->NormalizePath(std::move(base_path)));
+ return EnsureTrailingSlash(std::move(base_path));
+}
+
+bool SubTreeFileSystem::Equals(const FileSystem& other) const {
+ if (this == &other) {
+ return true;
+ }
+ if (other.type_name() != type_name()) {
+ return false;
+ }
+ const auto& subfs = ::arrow::internal::checked_cast<const SubTreeFileSystem&>(other);
+ return base_path_ == subfs.base_path_ && base_fs_->Equals(subfs.base_fs_);
+}
+
+std::string SubTreeFileSystem::PrependBase(const std::string& s) const {
+ if (s.empty()) {
+ return base_path_;
+ } else {
+ return ConcatAbstractPath(base_path_, s);
+ }
+}
+
+Status SubTreeFileSystem::PrependBaseNonEmpty(std::string* s) const {
+ if (s->empty()) {
+ return Status::IOError("Empty path");
+ } else {
+ *s = ConcatAbstractPath(base_path_, *s);
+ return Status::OK();
+ }
+}
+
+Result<std::string> SubTreeFileSystem::StripBase(const std::string& s) const {
+ auto len = base_path_.length();
+ // Note base_path_ ends with a slash (if not empty)
+ if (s.length() >= len && s.substr(0, len) == base_path_) {
+ return s.substr(len);
+ } else {
+ return Status::UnknownError("Underlying filesystem returned path '", s,
+ "', which is not a subpath of '", base_path_, "'");
+ }
+}
+
+Status SubTreeFileSystem::FixInfo(FileInfo* info) const {
+ ARROW_ASSIGN_OR_RAISE(auto fixed_path, StripBase(info->path()));
+ info->set_path(std::move(fixed_path));
+ return Status::OK();
+}
+
+Result<std::string> SubTreeFileSystem::NormalizePath(std::string path) {
+ ARROW_ASSIGN_OR_RAISE(auto normalized, base_fs_->NormalizePath(PrependBase(path)));
+ return StripBase(std::move(normalized));
+}
+
+Result<FileInfo> SubTreeFileSystem::GetFileInfo(const std::string& path) {
+ ARROW_ASSIGN_OR_RAISE(FileInfo info, base_fs_->GetFileInfo(PrependBase(path)));
+ RETURN_NOT_OK(FixInfo(&info));
+ return info;
+}
+
+Result<std::vector<FileInfo>> SubTreeFileSystem::GetFileInfo(const FileSelector& select) {
+ auto selector = select;
+ selector.base_dir = PrependBase(selector.base_dir);
+ ARROW_ASSIGN_OR_RAISE(auto infos, base_fs_->GetFileInfo(selector));
+ for (auto& info : infos) {
+ RETURN_NOT_OK(FixInfo(&info));
+ }
+ return infos;
+}
+
+FileInfoGenerator SubTreeFileSystem::GetFileInfoGenerator(const FileSelector& select) {
+ auto selector = select;
+ selector.base_dir = PrependBase(selector.base_dir);
+ auto gen = base_fs_->GetFileInfoGenerator(selector);
+
+ auto self = checked_pointer_cast<SubTreeFileSystem>(shared_from_this());
+
+ std::function<Result<std::vector<FileInfo>>(const std::vector<FileInfo>& infos)>
+ fix_infos = [self](std::vector<FileInfo> infos) -> Result<std::vector<FileInfo>> {
+ for (auto& info : infos) {
+ RETURN_NOT_OK(self->FixInfo(&info));
+ }
+ return infos;
+ };
+ return MakeMappedGenerator(gen, fix_infos);
+}
+
+Status SubTreeFileSystem::CreateDir(const std::string& path, bool recursive) {
+ auto s = path;
+ RETURN_NOT_OK(PrependBaseNonEmpty(&s));
+ return base_fs_->CreateDir(s, recursive);
+}
+
+Status SubTreeFileSystem::DeleteDir(const std::string& path) {
+ auto s = path;
+ RETURN_NOT_OK(PrependBaseNonEmpty(&s));
+ return base_fs_->DeleteDir(s);
+}
+
+Status SubTreeFileSystem::DeleteDirContents(const std::string& path) {
+ if (internal::IsEmptyPath(path)) {
+ return internal::InvalidDeleteDirContents(path);
+ }
+ auto s = PrependBase(path);
+ return base_fs_->DeleteDirContents(s);
+}
+
+Status SubTreeFileSystem::DeleteRootDirContents() {
+ if (base_path_.empty()) {
+ return base_fs_->DeleteRootDirContents();
+ } else {
+ return base_fs_->DeleteDirContents(base_path_);
+ }
+}
+
+Status SubTreeFileSystem::DeleteFile(const std::string& path) {
+ auto s = path;
+ RETURN_NOT_OK(PrependBaseNonEmpty(&s));
+ return base_fs_->DeleteFile(s);
+}
+
+Status SubTreeFileSystem::Move(const std::string& src, const std::string& dest) {
+ auto s = src;
+ auto d = dest;
+ RETURN_NOT_OK(PrependBaseNonEmpty(&s));
+ RETURN_NOT_OK(PrependBaseNonEmpty(&d));
+ return base_fs_->Move(s, d);
+}
+
+Status SubTreeFileSystem::CopyFile(const std::string& src, const std::string& dest) {
+ auto s = src;
+ auto d = dest;
+ RETURN_NOT_OK(PrependBaseNonEmpty(&s));
+ RETURN_NOT_OK(PrependBaseNonEmpty(&d));
+ return base_fs_->CopyFile(s, d);
+}
+
+Result<std::shared_ptr<io::InputStream>> SubTreeFileSystem::OpenInputStream(
+ const std::string& path) {
+ auto s = path;
+ RETURN_NOT_OK(PrependBaseNonEmpty(&s));
+ return base_fs_->OpenInputStream(s);
+}
+
+Result<std::shared_ptr<io::InputStream>> SubTreeFileSystem::OpenInputStream(
+ const FileInfo& info) {
+ auto s = info.path();
+ RETURN_NOT_OK(PrependBaseNonEmpty(&s));
+ FileInfo new_info(info);
+ new_info.set_path(std::move(s));
+ return base_fs_->OpenInputStream(new_info);
+}
+
+Future<std::shared_ptr<io::InputStream>> SubTreeFileSystem::OpenInputStreamAsync(
+ const std::string& path) {
+ auto s = path;
+ RETURN_NOT_OK(PrependBaseNonEmpty(&s));
+ return base_fs_->OpenInputStreamAsync(s);
+}
+
+Future<std::shared_ptr<io::InputStream>> SubTreeFileSystem::OpenInputStreamAsync(
+ const FileInfo& info) {
+ auto s = info.path();
+ RETURN_NOT_OK(PrependBaseNonEmpty(&s));
+ FileInfo new_info(info);
+ new_info.set_path(std::move(s));
+ return base_fs_->OpenInputStreamAsync(new_info);
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> SubTreeFileSystem::OpenInputFile(
+ const std::string& path) {
+ auto s = path;
+ RETURN_NOT_OK(PrependBaseNonEmpty(&s));
+ return base_fs_->OpenInputFile(s);
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> SubTreeFileSystem::OpenInputFile(
+ const FileInfo& info) {
+ auto s = info.path();
+ RETURN_NOT_OK(PrependBaseNonEmpty(&s));
+ FileInfo new_info(info);
+ new_info.set_path(std::move(s));
+ return base_fs_->OpenInputFile(new_info);
+}
+
+Future<std::shared_ptr<io::RandomAccessFile>> SubTreeFileSystem::OpenInputFileAsync(
+ const std::string& path) {
+ auto s = path;
+ RETURN_NOT_OK(PrependBaseNonEmpty(&s));
+ return base_fs_->OpenInputFileAsync(s);
+}
+
+Future<std::shared_ptr<io::RandomAccessFile>> SubTreeFileSystem::OpenInputFileAsync(
+ const FileInfo& info) {
+ auto s = info.path();
+ RETURN_NOT_OK(PrependBaseNonEmpty(&s));
+ FileInfo new_info(info);
+ new_info.set_path(std::move(s));
+ return base_fs_->OpenInputFileAsync(new_info);
+}
+
+Result<std::shared_ptr<io::OutputStream>> SubTreeFileSystem::OpenOutputStream(
+ const std::string& path, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ auto s = path;
+ RETURN_NOT_OK(PrependBaseNonEmpty(&s));
+ return base_fs_->OpenOutputStream(s, metadata);
+}
+
+Result<std::shared_ptr<io::OutputStream>> SubTreeFileSystem::OpenAppendStream(
+ const std::string& path, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ auto s = path;
+ RETURN_NOT_OK(PrependBaseNonEmpty(&s));
+ ARROW_SUPPRESS_DEPRECATION_WARNING
+ return base_fs_->OpenAppendStream(s, metadata);
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
+}
+
+//////////////////////////////////////////////////////////////////////////
+// SlowFileSystem implementation
+
+SlowFileSystem::SlowFileSystem(std::shared_ptr<FileSystem> base_fs,
+ std::shared_ptr<io::LatencyGenerator> latencies)
+ : FileSystem(base_fs->io_context()), base_fs_(base_fs), latencies_(latencies) {}
+
+SlowFileSystem::SlowFileSystem(std::shared_ptr<FileSystem> base_fs,
+ double average_latency)
+ : FileSystem(base_fs->io_context()),
+ base_fs_(base_fs),
+ latencies_(io::LatencyGenerator::Make(average_latency)) {}
+
+SlowFileSystem::SlowFileSystem(std::shared_ptr<FileSystem> base_fs,
+ double average_latency, int32_t seed)
+ : FileSystem(base_fs->io_context()),
+ base_fs_(base_fs),
+ latencies_(io::LatencyGenerator::Make(average_latency, seed)) {}
+
+bool SlowFileSystem::Equals(const FileSystem& other) const { return this == &other; }
+
+Result<FileInfo> SlowFileSystem::GetFileInfo(const std::string& path) {
+ latencies_->Sleep();
+ return base_fs_->GetFileInfo(path);
+}
+
+Result<std::vector<FileInfo>> SlowFileSystem::GetFileInfo(const FileSelector& selector) {
+ latencies_->Sleep();
+ return base_fs_->GetFileInfo(selector);
+}
+
+Status SlowFileSystem::CreateDir(const std::string& path, bool recursive) {
+ latencies_->Sleep();
+ return base_fs_->CreateDir(path, recursive);
+}
+
+Status SlowFileSystem::DeleteDir(const std::string& path) {
+ latencies_->Sleep();
+ return base_fs_->DeleteDir(path);
+}
+
+Status SlowFileSystem::DeleteDirContents(const std::string& path) {
+ latencies_->Sleep();
+ return base_fs_->DeleteDirContents(path);
+}
+
+Status SlowFileSystem::DeleteRootDirContents() {
+ latencies_->Sleep();
+ return base_fs_->DeleteRootDirContents();
+}
+
+Status SlowFileSystem::DeleteFile(const std::string& path) {
+ latencies_->Sleep();
+ return base_fs_->DeleteFile(path);
+}
+
+Status SlowFileSystem::Move(const std::string& src, const std::string& dest) {
+ latencies_->Sleep();
+ return base_fs_->Move(src, dest);
+}
+
+Status SlowFileSystem::CopyFile(const std::string& src, const std::string& dest) {
+ latencies_->Sleep();
+ return base_fs_->CopyFile(src, dest);
+}
+
+Result<std::shared_ptr<io::InputStream>> SlowFileSystem::OpenInputStream(
+ const std::string& path) {
+ latencies_->Sleep();
+ ARROW_ASSIGN_OR_RAISE(auto stream, base_fs_->OpenInputStream(path));
+ return std::make_shared<io::SlowInputStream>(stream, latencies_);
+}
+
+Result<std::shared_ptr<io::InputStream>> SlowFileSystem::OpenInputStream(
+ const FileInfo& info) {
+ latencies_->Sleep();
+ ARROW_ASSIGN_OR_RAISE(auto stream, base_fs_->OpenInputStream(info));
+ return std::make_shared<io::SlowInputStream>(stream, latencies_);
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> SlowFileSystem::OpenInputFile(
+ const std::string& path) {
+ latencies_->Sleep();
+ ARROW_ASSIGN_OR_RAISE(auto file, base_fs_->OpenInputFile(path));
+ return std::make_shared<io::SlowRandomAccessFile>(file, latencies_);
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> SlowFileSystem::OpenInputFile(
+ const FileInfo& info) {
+ latencies_->Sleep();
+ ARROW_ASSIGN_OR_RAISE(auto file, base_fs_->OpenInputFile(info));
+ return std::make_shared<io::SlowRandomAccessFile>(file, latencies_);
+}
+
+Result<std::shared_ptr<io::OutputStream>> SlowFileSystem::OpenOutputStream(
+ const std::string& path, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ latencies_->Sleep();
+ // XXX Should we have a SlowOutputStream that waits on Flush() and Close()?
+ return base_fs_->OpenOutputStream(path, metadata);
+}
+
+Result<std::shared_ptr<io::OutputStream>> SlowFileSystem::OpenAppendStream(
+ const std::string& path, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ latencies_->Sleep();
+ ARROW_SUPPRESS_DEPRECATION_WARNING
+ return base_fs_->OpenAppendStream(path, metadata);
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
+}
+
+Status CopyFiles(const std::vector<FileLocator>& sources,
+ const std::vector<FileLocator>& destinations,
+ const io::IOContext& io_context, int64_t chunk_size, bool use_threads) {
+ if (sources.size() != destinations.size()) {
+ return Status::Invalid("Trying to copy ", sources.size(), " files into ",
+ destinations.size(), " paths.");
+ }
+
+ auto copy_one_file = [&](int i) {
+ if (sources[i].filesystem->Equals(destinations[i].filesystem)) {
+ return sources[i].filesystem->CopyFile(sources[i].path, destinations[i].path);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto source,
+ sources[i].filesystem->OpenInputStream(sources[i].path));
+ ARROW_ASSIGN_OR_RAISE(const auto metadata, source->ReadMetadata());
+
+ ARROW_ASSIGN_OR_RAISE(auto destination, destinations[i].filesystem->OpenOutputStream(
+ destinations[i].path, metadata));
+ RETURN_NOT_OK(internal::CopyStream(source, destination, chunk_size, io_context));
+ return destination->Close();
+ };
+
+ return ::arrow::internal::OptionalParallelFor(
+ use_threads, static_cast<int>(sources.size()), std::move(copy_one_file),
+ io_context.executor());
+}
+
+Status CopyFiles(const std::shared_ptr<FileSystem>& source_fs,
+ const FileSelector& source_sel,
+ const std::shared_ptr<FileSystem>& destination_fs,
+ const std::string& destination_base_dir, const io::IOContext& io_context,
+ int64_t chunk_size, bool use_threads) {
+ ARROW_ASSIGN_OR_RAISE(auto source_infos, source_fs->GetFileInfo(source_sel));
+ if (source_infos.empty()) {
+ return Status::OK();
+ }
+
+ std::vector<FileLocator> sources, destinations;
+ std::vector<std::string> dirs;
+
+ for (const FileInfo& source_info : source_infos) {
+ auto relative = internal::RemoveAncestor(source_sel.base_dir, source_info.path());
+ if (!relative.has_value()) {
+ return Status::Invalid("GetFileInfo() yielded path '", source_info.path(),
+ "', which is outside base dir '", source_sel.base_dir, "'");
+ }
+
+ auto destination_path =
+ internal::ConcatAbstractPath(destination_base_dir, relative->to_string());
+
+ if (source_info.IsDirectory()) {
+ dirs.push_back(destination_path);
+ } else if (source_info.IsFile()) {
+ sources.push_back({source_fs, source_info.path()});
+ destinations.push_back({destination_fs, destination_path});
+ }
+ }
+
+ auto create_one_dir = [&](int i) { return destination_fs->CreateDir(dirs[i]); };
+
+ dirs = internal::MinimalCreateDirSet(std::move(dirs));
+ RETURN_NOT_OK(::arrow::internal::OptionalParallelFor(
+ use_threads, static_cast<int>(dirs.size()), std::move(create_one_dir),
+ io_context.executor()));
+
+ return CopyFiles(sources, destinations, io_context, chunk_size, use_threads);
+}
+
+namespace {
+
+Result<Uri> ParseFileSystemUri(const std::string& uri_string) {
+ Uri uri;
+ auto status = uri.Parse(uri_string);
+ if (!status.ok()) {
+#ifdef _WIN32
+ // Could be a "file:..." URI with backslashes instead of regular slashes.
+ RETURN_NOT_OK(uri.Parse(ToSlashes(uri_string)));
+ if (uri.scheme() != "file") {
+ return status;
+ }
+#else
+ return status;
+#endif
+ }
+ return std::move(uri);
+}
+
+Result<std::shared_ptr<FileSystem>> FileSystemFromUriReal(const Uri& uri,
+ const std::string& uri_string,
+ const io::IOContext& io_context,
+ std::string* out_path) {
+ const auto scheme = uri.scheme();
+
+ if (scheme == "file") {
+ std::string path;
+ ARROW_ASSIGN_OR_RAISE(auto options, LocalFileSystemOptions::FromUri(uri, &path));
+ if (out_path != nullptr) {
+ *out_path = path;
+ }
+ return std::make_shared<LocalFileSystem>(options, io_context);
+ }
+ if (scheme == "hdfs" || scheme == "viewfs") {
+#ifdef ARROW_HDFS
+ ARROW_ASSIGN_OR_RAISE(auto options, HdfsOptions::FromUri(uri));
+ if (out_path != nullptr) {
+ *out_path = uri.path();
+ }
+ ARROW_ASSIGN_OR_RAISE(auto hdfs, HadoopFileSystem::Make(options, io_context));
+ return hdfs;
+#else
+ return Status::NotImplemented("Got HDFS URI but Arrow compiled without HDFS support");
+#endif
+ }
+ if (scheme == "s3") {
+#ifdef ARROW_S3
+ RETURN_NOT_OK(EnsureS3Initialized());
+ ARROW_ASSIGN_OR_RAISE(auto options, S3Options::FromUri(uri, out_path));
+ ARROW_ASSIGN_OR_RAISE(auto s3fs, S3FileSystem::Make(options, io_context));
+ return s3fs;
+#else
+ return Status::NotImplemented("Got S3 URI but Arrow compiled without S3 support");
+#endif
+ }
+
+ if (scheme == "mock") {
+ // MockFileSystem does not have an absolute / relative path distinction,
+ // normalize path by removing leading slash.
+ if (out_path != nullptr) {
+ *out_path = std::string(RemoveLeadingSlash(uri.path()));
+ }
+ return std::make_shared<internal::MockFileSystem>(internal::CurrentTimePoint(),
+ io_context);
+ }
+
+ return Status::Invalid("Unrecognized filesystem type in URI: ", uri_string);
+}
+
+} // namespace
+
+Result<std::shared_ptr<FileSystem>> FileSystemFromUri(const std::string& uri_string,
+ std::string* out_path) {
+ return FileSystemFromUri(uri_string, io::default_io_context(), out_path);
+}
+
+Result<std::shared_ptr<FileSystem>> FileSystemFromUri(const std::string& uri_string,
+ const io::IOContext& io_context,
+ std::string* out_path) {
+ ARROW_ASSIGN_OR_RAISE(auto fsuri, ParseFileSystemUri(uri_string));
+ return FileSystemFromUriReal(fsuri, uri_string, io_context, out_path);
+}
+
+Result<std::shared_ptr<FileSystem>> FileSystemFromUriOrPath(const std::string& uri_string,
+ std::string* out_path) {
+ return FileSystemFromUriOrPath(uri_string, io::default_io_context(), out_path);
+}
+
+Result<std::shared_ptr<FileSystem>> FileSystemFromUriOrPath(
+ const std::string& uri_string, const io::IOContext& io_context,
+ std::string* out_path) {
+ if (internal::DetectAbsolutePath(uri_string)) {
+ // Normalize path separators
+ if (out_path != nullptr) {
+ *out_path = ToSlashes(uri_string);
+ }
+ return std::make_shared<LocalFileSystem>();
+ }
+ return FileSystemFromUri(uri_string, io_context, out_path);
+}
+
+Status FileSystemFromUri(const std::string& uri, std::shared_ptr<FileSystem>* out_fs,
+ std::string* out_path) {
+ return FileSystemFromUri(uri, out_path).Value(out_fs);
+}
+
+Status Initialize(const FileSystemGlobalOptions& options) {
+ internal::global_options = options;
+ return Status::OK();
+}
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/filesystem.h b/src/arrow/cpp/src/arrow/filesystem/filesystem.h
new file mode 100644
index 000000000..6a36d51e9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/filesystem.h
@@ -0,0 +1,535 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <chrono>
+#include <cstdint>
+#include <functional>
+#include <iosfwd>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/filesystem/type_fwd.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/compare.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/type_fwd.h"
+#include "arrow/util/visibility.h"
+#include "arrow/util/windows_fixup.h"
+
+namespace arrow {
+namespace fs {
+
+// A system clock time point expressed as a 64-bit (or more) number of
+// nanoseconds since the epoch.
+using TimePoint =
+ std::chrono::time_point<std::chrono::system_clock, std::chrono::nanoseconds>;
+
+ARROW_EXPORT std::string ToString(FileType);
+
+ARROW_EXPORT std::ostream& operator<<(std::ostream& os, FileType);
+
+static const int64_t kNoSize = -1;
+static const TimePoint kNoTime = TimePoint(TimePoint::duration(-1));
+
+/// \brief FileSystem entry info
+struct ARROW_EXPORT FileInfo : public util::EqualityComparable<FileInfo> {
+ FileInfo() = default;
+ FileInfo(FileInfo&&) = default;
+ FileInfo& operator=(FileInfo&&) = default;
+ FileInfo(const FileInfo&) = default;
+ FileInfo& operator=(const FileInfo&) = default;
+
+ explicit FileInfo(std::string path, FileType type = FileType::Unknown)
+ : path_(std::move(path)), type_(type) {}
+
+ /// The file type
+ FileType type() const { return type_; }
+ void set_type(FileType type) { type_ = type; }
+
+ /// The full file path in the filesystem
+ const std::string& path() const { return path_; }
+ void set_path(std::string path) { path_ = std::move(path); }
+
+ /// The file base name (component after the last directory separator)
+ std::string base_name() const;
+
+ // The directory base name (component before the file base name).
+ std::string dir_name() const;
+
+ /// The size in bytes, if available
+ ///
+ /// Only regular files are guaranteed to have a size.
+ int64_t size() const { return size_; }
+ void set_size(int64_t size) { size_ = size; }
+
+ /// The file extension (excluding the dot)
+ std::string extension() const;
+
+ /// The time of last modification, if available
+ TimePoint mtime() const { return mtime_; }
+ void set_mtime(TimePoint mtime) { mtime_ = mtime; }
+
+ bool IsFile() const { return type_ == FileType::File; }
+ bool IsDirectory() const { return type_ == FileType::Directory; }
+
+ bool Equals(const FileInfo& other) const {
+ return type() == other.type() && path() == other.path() && size() == other.size() &&
+ mtime() == other.mtime();
+ }
+
+ std::string ToString() const;
+
+ /// Function object implementing less-than comparison and hashing by
+ /// path, to support sorting infos, using them as keys, and other
+ /// interactions with the STL.
+ struct ByPath {
+ bool operator()(const FileInfo& l, const FileInfo& r) const {
+ return l.path() < r.path();
+ }
+
+ size_t operator()(const FileInfo& i) const {
+ return std::hash<std::string>{}(i.path());
+ }
+ };
+
+ protected:
+ std::string path_;
+ FileType type_ = FileType::Unknown;
+ int64_t size_ = kNoSize;
+ TimePoint mtime_ = kNoTime;
+};
+
+ARROW_EXPORT std::ostream& operator<<(std::ostream& os, const FileInfo&);
+
+/// \brief File selector for filesystem APIs
+struct ARROW_EXPORT FileSelector {
+ /// The directory in which to select files.
+ /// If the path exists but doesn't point to a directory, this should be an error.
+ std::string base_dir;
+ /// The behavior if `base_dir` isn't found in the filesystem. If false,
+ /// an error is returned. If true, an empty selection is returned.
+ bool allow_not_found;
+ /// Whether to recurse into subdirectories.
+ bool recursive;
+ /// The maximum number of subdirectories to recurse into.
+ int32_t max_recursion;
+
+ FileSelector() : allow_not_found(false), recursive(false), max_recursion(INT32_MAX) {}
+};
+
+/// \brief FileSystem, path pair
+struct ARROW_EXPORT FileLocator {
+ std::shared_ptr<FileSystem> filesystem;
+ std::string path;
+};
+
+using FileInfoVector = std::vector<FileInfo>;
+using FileInfoGenerator = std::function<Future<FileInfoVector>()>;
+
+} // namespace fs
+
+template <>
+struct IterationTraits<fs::FileInfoVector> {
+ static fs::FileInfoVector End() { return {}; }
+ static bool IsEnd(const fs::FileInfoVector& val) { return val.empty(); }
+};
+
+namespace fs {
+
+/// \brief Abstract file system API
+class ARROW_EXPORT FileSystem : public std::enable_shared_from_this<FileSystem> {
+ public:
+ virtual ~FileSystem();
+
+ virtual std::string type_name() const = 0;
+
+ /// EXPERIMENTAL: The IOContext associated with this filesystem.
+ const io::IOContext& io_context() const { return io_context_; }
+
+ /// Normalize path for the given filesystem
+ ///
+ /// The default implementation of this method is a no-op, but subclasses
+ /// may allow normalizing irregular path forms (such as Windows local paths).
+ virtual Result<std::string> NormalizePath(std::string path);
+
+ virtual bool Equals(const FileSystem& other) const = 0;
+
+ virtual bool Equals(const std::shared_ptr<FileSystem>& other) const {
+ return Equals(*other);
+ }
+
+ /// Get info for the given target.
+ ///
+ /// Any symlink is automatically dereferenced, recursively.
+ /// A nonexistent or unreachable file returns an Ok status and
+ /// has a FileType of value NotFound. An error status indicates
+ /// a truly exceptional condition (low-level I/O error, etc.).
+ virtual Result<FileInfo> GetFileInfo(const std::string& path) = 0;
+ /// Same, for many targets at once.
+ virtual Result<FileInfoVector> GetFileInfo(const std::vector<std::string>& paths);
+ /// Same, according to a selector.
+ ///
+ /// The selector's base directory will not be part of the results, even if
+ /// it exists.
+ /// If it doesn't exist, see `FileSelector::allow_not_found`.
+ virtual Result<FileInfoVector> GetFileInfo(const FileSelector& select) = 0;
+
+ /// Async version of GetFileInfo
+ virtual Future<FileInfoVector> GetFileInfoAsync(const std::vector<std::string>& paths);
+
+ /// Streaming async version of GetFileInfo
+ ///
+ /// The returned generator is not async-reentrant, i.e. you need to wait for
+ /// the returned future to complete before calling the generator again.
+ virtual FileInfoGenerator GetFileInfoGenerator(const FileSelector& select);
+
+ /// Create a directory and subdirectories.
+ ///
+ /// This function succeeds if the directory already exists.
+ virtual Status CreateDir(const std::string& path, bool recursive = true) = 0;
+
+ /// Delete a directory and its contents, recursively.
+ virtual Status DeleteDir(const std::string& path) = 0;
+
+ /// Delete a directory's contents, recursively.
+ ///
+ /// Like DeleteDir, but doesn't delete the directory itself.
+ /// Passing an empty path ("" or "/") is disallowed, see DeleteRootDirContents.
+ virtual Status DeleteDirContents(const std::string& path) = 0;
+
+ /// EXPERIMENTAL: Delete the root directory's contents, recursively.
+ ///
+ /// Implementations may decide to raise an error if this operation is
+ /// too dangerous.
+ // NOTE: may decide to remove this if it's deemed not useful
+ virtual Status DeleteRootDirContents() = 0;
+
+ /// Delete a file.
+ virtual Status DeleteFile(const std::string& path) = 0;
+ /// Delete many files.
+ ///
+ /// The default implementation issues individual delete operations in sequence.
+ virtual Status DeleteFiles(const std::vector<std::string>& paths);
+
+ /// Move / rename a file or directory.
+ ///
+ /// If the destination exists:
+ /// - if it is a non-empty directory, an error is returned
+ /// - otherwise, if it has the same type as the source, it is replaced
+ /// - otherwise, behavior is unspecified (implementation-dependent).
+ virtual Status Move(const std::string& src, const std::string& dest) = 0;
+
+ /// Copy a file.
+ ///
+ /// If the destination exists and is a directory, an error is returned.
+ /// Otherwise, it is replaced.
+ virtual Status CopyFile(const std::string& src, const std::string& dest) = 0;
+
+ /// Open an input stream for sequential reading.
+ virtual Result<std::shared_ptr<io::InputStream>> OpenInputStream(
+ const std::string& path) = 0;
+ /// Open an input stream for sequential reading.
+ ///
+ /// This override assumes the given FileInfo validly represents the file's
+ /// characteristics, and may optimize access depending on them (for example
+ /// avoid querying the file size or its existence).
+ virtual Result<std::shared_ptr<io::InputStream>> OpenInputStream(const FileInfo& info);
+
+ /// Open an input file for random access reading.
+ virtual Result<std::shared_ptr<io::RandomAccessFile>> OpenInputFile(
+ const std::string& path) = 0;
+ /// Open an input file for random access reading.
+ ///
+ /// This override assumes the given FileInfo validly represents the file's
+ /// characteristics, and may optimize access depending on them (for example
+ /// avoid querying the file size or its existence).
+ virtual Result<std::shared_ptr<io::RandomAccessFile>> OpenInputFile(
+ const FileInfo& info);
+
+ /// Async version of OpenInputStream
+ virtual Future<std::shared_ptr<io::InputStream>> OpenInputStreamAsync(
+ const std::string& path);
+ /// Async version of OpenInputStream
+ virtual Future<std::shared_ptr<io::InputStream>> OpenInputStreamAsync(
+ const FileInfo& info);
+
+ /// Async version of OpenInputFile
+ virtual Future<std::shared_ptr<io::RandomAccessFile>> OpenInputFileAsync(
+ const std::string& path);
+ /// Async version of OpenInputFile
+ virtual Future<std::shared_ptr<io::RandomAccessFile>> OpenInputFileAsync(
+ const FileInfo& info);
+
+ /// Open an output stream for sequential writing.
+ ///
+ /// If the target already exists, existing data is truncated.
+ virtual Result<std::shared_ptr<io::OutputStream>> OpenOutputStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata) = 0;
+ Result<std::shared_ptr<io::OutputStream>> OpenOutputStream(const std::string& path);
+
+ /// Open an output stream for appending.
+ ///
+ /// If the target doesn't exist, a new empty file is created.
+ ARROW_DEPRECATED(
+ "Deprecated in 6.0.0. "
+ "OpenAppendStream is unsupported on several filesystems and will be later removed.")
+ virtual Result<std::shared_ptr<io::OutputStream>> OpenAppendStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata) = 0;
+ Result<std::shared_ptr<io::OutputStream>> OpenAppendStream(const std::string& path);
+
+ protected:
+ explicit FileSystem(const io::IOContext& io_context = io::default_io_context())
+ : io_context_(io_context) {}
+
+ io::IOContext io_context_;
+ // Whether metadata operations (such as GetFileInfo or OpenInputStream)
+ // are cheap enough that the default async variants don't bother with
+ // a thread pool.
+ bool default_async_is_sync_ = true;
+};
+
+/// \brief A FileSystem implementation that delegates to another
+/// implementation after prepending a fixed base path.
+///
+/// This is useful to expose a logical view of a subtree of a filesystem,
+/// for example a directory in a LocalFileSystem.
+/// This works on abstract paths, i.e. paths using forward slashes and
+/// and a single root "/". Windows paths are not guaranteed to work.
+/// This makes no security guarantee. For example, symlinks may allow to
+/// "escape" the subtree and access other parts of the underlying filesystem.
+class ARROW_EXPORT SubTreeFileSystem : public FileSystem {
+ public:
+ // This constructor may abort if base_path is invalid.
+ explicit SubTreeFileSystem(const std::string& base_path,
+ std::shared_ptr<FileSystem> base_fs);
+ ~SubTreeFileSystem() override;
+
+ std::string type_name() const override { return "subtree"; }
+ std::string base_path() const { return base_path_; }
+ std::shared_ptr<FileSystem> base_fs() const { return base_fs_; }
+
+ Result<std::string> NormalizePath(std::string path) override;
+
+ bool Equals(const FileSystem& other) const override;
+
+ /// \cond FALSE
+ using FileSystem::GetFileInfo;
+ /// \endcond
+ Result<FileInfo> GetFileInfo(const std::string& path) override;
+ Result<FileInfoVector> GetFileInfo(const FileSelector& select) override;
+
+ FileInfoGenerator GetFileInfoGenerator(const FileSelector& select) override;
+
+ Status CreateDir(const std::string& path, bool recursive = true) override;
+
+ Status DeleteDir(const std::string& path) override;
+ Status DeleteDirContents(const std::string& path) override;
+ Status DeleteRootDirContents() override;
+
+ Status DeleteFile(const std::string& path) override;
+
+ Status Move(const std::string& src, const std::string& dest) override;
+
+ Status CopyFile(const std::string& src, const std::string& dest) override;
+
+ Result<std::shared_ptr<io::InputStream>> OpenInputStream(
+ const std::string& path) override;
+ Result<std::shared_ptr<io::InputStream>> OpenInputStream(const FileInfo& info) override;
+ Result<std::shared_ptr<io::RandomAccessFile>> OpenInputFile(
+ const std::string& path) override;
+ Result<std::shared_ptr<io::RandomAccessFile>> OpenInputFile(
+ const FileInfo& info) override;
+
+ Future<std::shared_ptr<io::InputStream>> OpenInputStreamAsync(
+ const std::string& path) override;
+ Future<std::shared_ptr<io::InputStream>> OpenInputStreamAsync(
+ const FileInfo& info) override;
+ Future<std::shared_ptr<io::RandomAccessFile>> OpenInputFileAsync(
+ const std::string& path) override;
+ Future<std::shared_ptr<io::RandomAccessFile>> OpenInputFileAsync(
+ const FileInfo& info) override;
+
+ Result<std::shared_ptr<io::OutputStream>> OpenOutputStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = {}) override;
+ Result<std::shared_ptr<io::OutputStream>> OpenAppendStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = {}) override;
+
+ protected:
+ SubTreeFileSystem() {}
+
+ const std::string base_path_;
+ std::shared_ptr<FileSystem> base_fs_;
+
+ std::string PrependBase(const std::string& s) const;
+ Status PrependBaseNonEmpty(std::string* s) const;
+ Result<std::string> StripBase(const std::string& s) const;
+ Status FixInfo(FileInfo* info) const;
+
+ static Result<std::string> NormalizeBasePath(
+ std::string base_path, const std::shared_ptr<FileSystem>& base_fs);
+};
+
+/// \brief A FileSystem implementation that delegates to another
+/// implementation but inserts latencies at various points.
+class ARROW_EXPORT SlowFileSystem : public FileSystem {
+ public:
+ SlowFileSystem(std::shared_ptr<FileSystem> base_fs,
+ std::shared_ptr<io::LatencyGenerator> latencies);
+ SlowFileSystem(std::shared_ptr<FileSystem> base_fs, double average_latency);
+ SlowFileSystem(std::shared_ptr<FileSystem> base_fs, double average_latency,
+ int32_t seed);
+
+ std::string type_name() const override { return "slow"; }
+ bool Equals(const FileSystem& other) const override;
+
+ using FileSystem::GetFileInfo;
+ Result<FileInfo> GetFileInfo(const std::string& path) override;
+ Result<FileInfoVector> GetFileInfo(const FileSelector& select) override;
+
+ Status CreateDir(const std::string& path, bool recursive = true) override;
+
+ Status DeleteDir(const std::string& path) override;
+ Status DeleteDirContents(const std::string& path) override;
+ Status DeleteRootDirContents() override;
+
+ Status DeleteFile(const std::string& path) override;
+
+ Status Move(const std::string& src, const std::string& dest) override;
+
+ Status CopyFile(const std::string& src, const std::string& dest) override;
+
+ Result<std::shared_ptr<io::InputStream>> OpenInputStream(
+ const std::string& path) override;
+ Result<std::shared_ptr<io::InputStream>> OpenInputStream(const FileInfo& info) override;
+ Result<std::shared_ptr<io::RandomAccessFile>> OpenInputFile(
+ const std::string& path) override;
+ Result<std::shared_ptr<io::RandomAccessFile>> OpenInputFile(
+ const FileInfo& info) override;
+ Result<std::shared_ptr<io::OutputStream>> OpenOutputStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = {}) override;
+ Result<std::shared_ptr<io::OutputStream>> OpenAppendStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = {}) override;
+
+ protected:
+ std::shared_ptr<FileSystem> base_fs_;
+ std::shared_ptr<io::LatencyGenerator> latencies_;
+};
+
+/// \defgroup filesystem-factories Functions for creating FileSystem instances
+///
+/// @{
+
+/// \brief Create a new FileSystem by URI
+///
+/// Recognized schemes are "file", "mock", "hdfs" and "s3fs".
+///
+/// \param[in] uri a URI-based path, ex: file:///some/local/path
+/// \param[out] out_path (optional) Path inside the filesystem.
+/// \return out_fs FileSystem instance.
+ARROW_EXPORT
+Result<std::shared_ptr<FileSystem>> FileSystemFromUri(const std::string& uri,
+ std::string* out_path = NULLPTR);
+
+/// \brief Create a new FileSystem by URI with a custom IO context
+///
+/// Recognized schemes are "file", "mock", "hdfs" and "s3fs".
+///
+/// \param[in] uri a URI-based path, ex: file:///some/local/path
+/// \param[in] io_context an IOContext which will be associated with the filesystem
+/// \param[out] out_path (optional) Path inside the filesystem.
+/// \return out_fs FileSystem instance.
+ARROW_EXPORT
+Result<std::shared_ptr<FileSystem>> FileSystemFromUri(const std::string& uri,
+ const io::IOContext& io_context,
+ std::string* out_path = NULLPTR);
+
+/// \brief Create a new FileSystem by URI
+///
+/// Same as FileSystemFromUri, but in addition also recognize non-URIs
+/// and treat them as local filesystem paths. Only absolute local filesystem
+/// paths are allowed.
+ARROW_EXPORT
+Result<std::shared_ptr<FileSystem>> FileSystemFromUriOrPath(
+ const std::string& uri, std::string* out_path = NULLPTR);
+
+/// \brief Create a new FileSystem by URI with a custom IO context
+///
+/// Same as FileSystemFromUri, but in addition also recognize non-URIs
+/// and treat them as local filesystem paths. Only absolute local filesystem
+/// paths are allowed.
+ARROW_EXPORT
+Result<std::shared_ptr<FileSystem>> FileSystemFromUriOrPath(
+ const std::string& uri, const io::IOContext& io_context,
+ std::string* out_path = NULLPTR);
+
+/// @}
+
+/// \brief Copy files, including from one FileSystem to another
+///
+/// If a source and destination are resident in the same FileSystem FileSystem::CopyFile
+/// will be used, otherwise the file will be opened as a stream in both FileSystems and
+/// chunks copied from the source to the destination. No directories will be created.
+ARROW_EXPORT
+Status CopyFiles(const std::vector<FileLocator>& sources,
+ const std::vector<FileLocator>& destinations,
+ const io::IOContext& io_context = io::default_io_context(),
+ int64_t chunk_size = 1024 * 1024, bool use_threads = true);
+
+/// \brief Copy selected files, including from one FileSystem to another
+///
+/// Directories will be created under the destination base directory as needed.
+ARROW_EXPORT
+Status CopyFiles(const std::shared_ptr<FileSystem>& source_fs,
+ const FileSelector& source_sel,
+ const std::shared_ptr<FileSystem>& destination_fs,
+ const std::string& destination_base_dir,
+ const io::IOContext& io_context = io::default_io_context(),
+ int64_t chunk_size = 1024 * 1024, bool use_threads = true);
+
+struct FileSystemGlobalOptions {
+ /// Path to a single PEM file holding all TLS CA certificates
+ ///
+ /// If empty, the underlying TLS library's defaults will be used.
+ std::string tls_ca_file_path;
+
+ /// Path to a directory holding TLS CA certificates in individual PEM files
+ /// named along the OpenSSL "hashed" format.
+ ///
+ /// If empty, the underlying TLS library's defaults will be used.
+ std::string tls_ca_dir_path;
+};
+
+/// EXPERIMENTAL: optional global initialization routine
+///
+/// This is for environments (such as manylinux) where the path
+/// to TLS CA certificates needs to be configured at runtime.
+ARROW_EXPORT
+Status Initialize(const FileSystemGlobalOptions& options);
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/filesystem_test.cc b/src/arrow/cpp/src/arrow/filesystem/filesystem_test.cc
new file mode 100644
index 000000000..44889356b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/filesystem_test.cc
@@ -0,0 +1,825 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/filesystem/filesystem.h"
+#include "arrow/filesystem/mockfs.h"
+#include "arrow/filesystem/path_util.h"
+#include "arrow/filesystem/test_util.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+namespace fs {
+namespace internal {
+
+void AssertPartsEqual(const std::vector<std::string>& parts,
+ const std::vector<std::string>& expected) {
+ ASSERT_EQ(parts, expected);
+}
+
+void AssertPairEqual(const std::pair<std::string, std::string>& pair,
+ const std::pair<std::string, std::string>& expected) {
+ ASSERT_EQ(pair, expected);
+}
+
+TEST(FileInfo, BaseName) {
+ auto info = FileInfo();
+ ASSERT_EQ(info.base_name(), "");
+ info.set_path("foo");
+ ASSERT_EQ(info.base_name(), "foo");
+ info.set_path("foo/bar/baz.qux");
+ ASSERT_EQ(info.base_name(), "baz.qux");
+}
+
+TEST(PathUtil, SplitAbstractPath) {
+ std::vector<std::string> parts;
+
+ parts = SplitAbstractPath("");
+ AssertPartsEqual(parts, {});
+ parts = SplitAbstractPath("abc");
+ AssertPartsEqual(parts, {"abc"});
+ parts = SplitAbstractPath("abc/def.ghi");
+ AssertPartsEqual(parts, {"abc", "def.ghi"});
+ parts = SplitAbstractPath("abc/def/ghi");
+ AssertPartsEqual(parts, {"abc", "def", "ghi"});
+ parts = SplitAbstractPath("abc\\def\\ghi");
+ AssertPartsEqual(parts, {"abc\\def\\ghi"});
+
+ // Trailing slash
+ parts = SplitAbstractPath("abc/");
+ AssertPartsEqual(parts, {"abc"});
+ parts = SplitAbstractPath("abc/def.ghi/");
+ AssertPartsEqual(parts, {"abc", "def.ghi"});
+ parts = SplitAbstractPath("abc/def.ghi\\");
+ AssertPartsEqual(parts, {"abc", "def.ghi\\"});
+
+ // Leading slash
+ parts = SplitAbstractPath("/");
+ AssertPartsEqual(parts, {});
+ parts = SplitAbstractPath("/abc");
+ AssertPartsEqual(parts, {"abc"});
+ parts = SplitAbstractPath("/abc/def.ghi");
+ AssertPartsEqual(parts, {"abc", "def.ghi"});
+ parts = SplitAbstractPath("/abc/def.ghi/");
+ AssertPartsEqual(parts, {"abc", "def.ghi"});
+}
+
+TEST(PathUtil, GetAbstractPathExtension) {
+ ASSERT_EQ(GetAbstractPathExtension("abc.txt"), "txt");
+ ASSERT_EQ(GetAbstractPathExtension("dir/abc.txt"), "txt");
+ ASSERT_EQ(GetAbstractPathExtension("/dir/abc.txt"), "txt");
+ ASSERT_EQ(GetAbstractPathExtension("dir/abc.txt.gz"), "gz");
+ ASSERT_EQ(GetAbstractPathExtension("/run.d/abc.txt"), "txt");
+ ASSERT_EQ(GetAbstractPathExtension("abc"), "");
+ ASSERT_EQ(GetAbstractPathExtension("/dir/abc"), "");
+ ASSERT_EQ(GetAbstractPathExtension("/run.d/abc"), "");
+}
+
+TEST(PathUtil, GetAbstractPathParent) {
+ std::pair<std::string, std::string> pair;
+
+ pair = GetAbstractPathParent("");
+ AssertPairEqual(pair, {"", ""});
+ pair = GetAbstractPathParent("abc");
+ AssertPairEqual(pair, {"", "abc"});
+ pair = GetAbstractPathParent("abc/def/ghi");
+ AssertPairEqual(pair, {"abc/def", "ghi"});
+ pair = GetAbstractPathParent("abc/def\\ghi");
+ AssertPairEqual(pair, {"abc", "def\\ghi"});
+}
+
+TEST(PathUtil, ValidateAbstractPathParts) {
+ ASSERT_OK(ValidateAbstractPathParts({}));
+ ASSERT_OK(ValidateAbstractPathParts({"abc"}));
+ ASSERT_OK(ValidateAbstractPathParts({"abc", "def"}));
+ ASSERT_OK(ValidateAbstractPathParts({"abc", "def.ghi"}));
+ ASSERT_OK(ValidateAbstractPathParts({"abc", "def\\ghi"}));
+
+ // Empty path component
+ ASSERT_RAISES(Invalid, ValidateAbstractPathParts({""}));
+ ASSERT_RAISES(Invalid, ValidateAbstractPathParts({"abc", "", "def"}));
+
+ // Separator in component
+ ASSERT_RAISES(Invalid, ValidateAbstractPathParts({"/"}));
+ ASSERT_RAISES(Invalid, ValidateAbstractPathParts({"abc/def"}));
+}
+
+TEST(PathUtil, ConcatAbstractPath) {
+ ASSERT_EQ("abc", ConcatAbstractPath("", "abc"));
+ ASSERT_EQ("abc/def", ConcatAbstractPath("abc", "def"));
+ ASSERT_EQ("abc/def/ghi", ConcatAbstractPath("abc/def", "ghi"));
+
+ ASSERT_EQ("abc/def", ConcatAbstractPath("abc/", "def"));
+ ASSERT_EQ("abc/def/ghi", ConcatAbstractPath("abc/def/", "ghi"));
+
+ ASSERT_EQ("/abc", ConcatAbstractPath("/", "abc"));
+ ASSERT_EQ("/abc/def", ConcatAbstractPath("/abc", "def"));
+ ASSERT_EQ("/abc/def", ConcatAbstractPath("/abc/", "def"));
+}
+
+TEST(PathUtil, JoinAbstractPath) {
+ std::vector<std::string> parts = {"abc", "def", "ghi", "", "jkl"};
+
+ ASSERT_EQ("abc/def/ghi/jkl", JoinAbstractPath(parts.begin(), parts.end()));
+ ASSERT_EQ("def/ghi", JoinAbstractPath(parts.begin() + 1, parts.begin() + 3));
+ ASSERT_EQ("", JoinAbstractPath(parts.begin(), parts.begin()));
+}
+
+TEST(PathUtil, EnsureTrailingSlash) {
+ ASSERT_EQ("", EnsureTrailingSlash(""));
+ ASSERT_EQ("/", EnsureTrailingSlash("/"));
+ ASSERT_EQ("abc/", EnsureTrailingSlash("abc"));
+ ASSERT_EQ("abc/", EnsureTrailingSlash("abc/"));
+ ASSERT_EQ("/abc/", EnsureTrailingSlash("/abc"));
+ ASSERT_EQ("/abc/", EnsureTrailingSlash("/abc/"));
+}
+
+TEST(PathUtil, RemoveTrailingSlash) {
+ ASSERT_EQ("", std::string(RemoveTrailingSlash("")));
+ ASSERT_EQ("", std::string(RemoveTrailingSlash("/")));
+ ASSERT_EQ("", std::string(RemoveTrailingSlash("//")));
+ ASSERT_EQ("abc/def", std::string(RemoveTrailingSlash("abc/def")));
+ ASSERT_EQ("abc/def", std::string(RemoveTrailingSlash("abc/def/")));
+ ASSERT_EQ("abc/def", std::string(RemoveTrailingSlash("abc/def//")));
+ ASSERT_EQ("/abc/def", std::string(RemoveTrailingSlash("/abc/def")));
+ ASSERT_EQ("/abc/def", std::string(RemoveTrailingSlash("/abc/def/")));
+ ASSERT_EQ("/abc/def", std::string(RemoveTrailingSlash("/abc/def//")));
+}
+
+TEST(PathUtil, EnsureLeadingSlash) {
+ ASSERT_EQ("/", EnsureLeadingSlash(""));
+ ASSERT_EQ("/", EnsureLeadingSlash("/"));
+ ASSERT_EQ("/abc", EnsureLeadingSlash("abc"));
+ ASSERT_EQ("/abc/", EnsureLeadingSlash("abc/"));
+ ASSERT_EQ("/abc", EnsureLeadingSlash("/abc"));
+ ASSERT_EQ("/abc/", EnsureLeadingSlash("/abc/"));
+}
+
+TEST(PathUtil, RemoveLeadingSlash) {
+ ASSERT_EQ("", std::string(RemoveLeadingSlash("")));
+ ASSERT_EQ("", std::string(RemoveLeadingSlash("/")));
+ ASSERT_EQ("", std::string(RemoveLeadingSlash("//")));
+ ASSERT_EQ("abc/def", std::string(RemoveLeadingSlash("abc/def")));
+ ASSERT_EQ("abc/def", std::string(RemoveLeadingSlash("/abc/def")));
+ ASSERT_EQ("abc/def", std::string(RemoveLeadingSlash("//abc/def")));
+ ASSERT_EQ("abc/def/", std::string(RemoveLeadingSlash("abc/def/")));
+ ASSERT_EQ("abc/def/", std::string(RemoveLeadingSlash("/abc/def/")));
+ ASSERT_EQ("abc/def/", std::string(RemoveLeadingSlash("//abc/def/")));
+}
+
+TEST(PathUtil, IsAncestorOf) {
+ ASSERT_TRUE(IsAncestorOf("", ""));
+ ASSERT_TRUE(IsAncestorOf("", "/hello"));
+ ASSERT_TRUE(IsAncestorOf("/hello", "/hello"));
+ ASSERT_FALSE(IsAncestorOf("/hello", "/world"));
+ ASSERT_TRUE(IsAncestorOf("/hello", "/hello/world"));
+ ASSERT_TRUE(IsAncestorOf("/hello", "/hello/world/how/are/you"));
+ ASSERT_FALSE(IsAncestorOf("/hello/w", "/hello/world"));
+}
+
+TEST(PathUtil, MakeAbstractPathRelative) {
+ ASSERT_OK_AND_EQ("", MakeAbstractPathRelative("/", "/"));
+ ASSERT_OK_AND_EQ("foo/bar", MakeAbstractPathRelative("/", "/foo/bar"));
+
+ ASSERT_OK_AND_EQ("", MakeAbstractPathRelative("/foo", "/foo"));
+ ASSERT_OK_AND_EQ("", MakeAbstractPathRelative("/foo/", "/foo"));
+ ASSERT_OK_AND_EQ("", MakeAbstractPathRelative("/foo", "/foo/"));
+ ASSERT_OK_AND_EQ("", MakeAbstractPathRelative("/foo/", "/foo/"));
+
+ ASSERT_OK_AND_EQ("bar", MakeAbstractPathRelative("/foo", "/foo/bar"));
+ ASSERT_OK_AND_EQ("bar", MakeAbstractPathRelative("/foo/", "/foo/bar"));
+ ASSERT_OK_AND_EQ("bar/", MakeAbstractPathRelative("/foo/", "/foo/bar/"));
+
+ // Not relative to base
+ ASSERT_RAISES(Invalid, MakeAbstractPathRelative("/xxx", "/foo/bar"));
+ ASSERT_RAISES(Invalid, MakeAbstractPathRelative("/xxx", "/xxxx"));
+
+ // Base is not absolute
+ ASSERT_RAISES(Invalid, MakeAbstractPathRelative("foo/bar", "foo/bar/baz"));
+ ASSERT_RAISES(Invalid, MakeAbstractPathRelative("", "foo/bar/baz"));
+}
+
+TEST(PathUtil, AncestorsFromBasePath) {
+ using V = std::vector<std::string>;
+
+ // Not relative to base
+ ASSERT_EQ(AncestorsFromBasePath("xxx", "foo/bar"), V{});
+ ASSERT_EQ(AncestorsFromBasePath("xxx", "xxxx"), V{});
+
+ ASSERT_EQ(AncestorsFromBasePath("foo", "foo/bar"), V{});
+ ASSERT_EQ(AncestorsFromBasePath("foo", "foo/bar/baz"), V({"foo/bar"}));
+ ASSERT_EQ(AncestorsFromBasePath("foo", "foo/bar/baz/quux"),
+ V({"foo/bar", "foo/bar/baz"}));
+}
+
+TEST(PathUtil, MinimalCreateDirSet) {
+ using V = std::vector<std::string>;
+
+ ASSERT_EQ(MinimalCreateDirSet({}), V{});
+ ASSERT_EQ(MinimalCreateDirSet({"foo"}), V{"foo"});
+ ASSERT_EQ(MinimalCreateDirSet({"foo", "foo/bar"}), V{"foo/bar"});
+ ASSERT_EQ(MinimalCreateDirSet({"foo", "foo/bar/baz"}), V{"foo/bar/baz"});
+ ASSERT_EQ(MinimalCreateDirSet({"foo", "foo/bar", "foo/bar"}), V{"foo/bar"});
+ ASSERT_EQ(MinimalCreateDirSet({"foo", "foo/bar", "foo", "foo/baz", "foo/baz/quux"}),
+ V({"foo/bar", "foo/baz/quux"}));
+
+ ASSERT_EQ(MinimalCreateDirSet({""}), V{});
+ ASSERT_EQ(MinimalCreateDirSet({"", "/foo"}), V{"/foo"});
+}
+
+TEST(PathUtil, ToBackslashes) {
+ ASSERT_EQ(ToBackslashes("foo/bar"), "foo\\bar");
+ ASSERT_EQ(ToBackslashes("//foo/bar/"), "\\\\foo\\bar\\");
+ ASSERT_EQ(ToBackslashes("foo\\bar"), "foo\\bar");
+}
+
+TEST(PathUtil, ToSlashes) {
+#ifdef _WIN32
+ ASSERT_EQ(ToSlashes("foo\\bar"), "foo/bar");
+ ASSERT_EQ(ToSlashes("\\\\foo\\bar\\"), "//foo/bar/");
+#else
+ ASSERT_EQ(ToSlashes("foo\\bar"), "foo\\bar");
+ ASSERT_EQ(ToSlashes("\\\\foo\\bar\\"), "\\\\foo\\bar\\");
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Generic MockFileSystem tests
+
+template <typename MockFileSystemType>
+class TestMockFSGeneric : public ::testing::Test, public GenericFileSystemTest {
+ public:
+ void SetUp() override {
+ time_ = TimePoint(TimePoint::duration(42));
+ fs_ = std::make_shared<MockFileSystemType>(time_);
+ }
+
+ protected:
+ std::shared_ptr<FileSystem> GetEmptyFileSystem() override { return fs_; }
+
+ bool have_file_metadata() const override { return true; }
+
+ TimePoint time_;
+ std::shared_ptr<FileSystem> fs_;
+};
+
+using MockFileSystemTypes = ::testing::Types<MockFileSystem, MockAsyncFileSystem>;
+
+TYPED_TEST_SUITE(TestMockFSGeneric, MockFileSystemTypes);
+
+GENERIC_FS_TYPED_TEST_FUNCTIONS(TestMockFSGeneric);
+
+////////////////////////////////////////////////////////////////////////////
+// Concrete MockFileSystem tests
+
+class TestMockFS : public ::testing::Test {
+ public:
+ void SetUp() override {
+ time_ = TimePoint(TimePoint::duration(42));
+ fs_ = std::make_shared<MockFileSystem>(time_);
+ }
+
+ Status WriteString(io::OutputStream* stream, const std::string& s) {
+ return stream->Write(s.data(), static_cast<int64_t>(s.length()));
+ }
+
+ std::vector<MockDirInfo> AllDirs() {
+ return arrow::internal::checked_pointer_cast<MockFileSystem>(fs_)->AllDirs();
+ }
+
+ std::vector<MockFileInfo> AllFiles() {
+ return arrow::internal::checked_pointer_cast<MockFileSystem>(fs_)->AllFiles();
+ }
+
+ void CheckDirs(const std::vector<MockDirInfo>& expected) {
+ ASSERT_EQ(AllDirs(), expected);
+ }
+
+ void CheckDirPaths(const std::vector<std::string>& expected) {
+ std::vector<MockDirInfo> infos;
+ infos.reserve(expected.size());
+ for (const auto& s : expected) {
+ infos.push_back({s, time_});
+ }
+ ASSERT_EQ(AllDirs(), infos);
+ }
+
+ void CheckFiles(const std::vector<MockFileInfo>& expected) {
+ ASSERT_EQ(AllFiles(), expected);
+ }
+
+ void CreateFile(const std::string& path, const std::string& data) {
+ ::arrow::fs::CreateFile(fs_.get(), path, data);
+ }
+
+ protected:
+ TimePoint time_;
+ std::shared_ptr<FileSystem> fs_;
+};
+
+TEST_F(TestMockFS, Empty) {
+ CheckDirs({});
+ CheckFiles({});
+}
+
+TEST_F(TestMockFS, CreateDir) {
+ ASSERT_OK(fs_->CreateDir("AB"));
+ ASSERT_OK(fs_->CreateDir("AB/CD/EF")); // Recursive
+ // Non-recursive, parent doesn't exist
+ ASSERT_RAISES(IOError, fs_->CreateDir("AB/GH/IJ", false /* recursive */));
+ ASSERT_OK(fs_->CreateDir("AB/GH", false /* recursive */));
+ ASSERT_OK(fs_->CreateDir("AB/GH/IJ", false /* recursive */));
+ // Idempotency
+ ASSERT_OK(fs_->CreateDir("AB/GH/IJ", false /* recursive */));
+ ASSERT_OK(fs_->CreateDir("XY"));
+ CheckDirs({{"AB", time_},
+ {"AB/CD", time_},
+ {"AB/CD/EF", time_},
+ {"AB/GH", time_},
+ {"AB/GH/IJ", time_},
+ {"XY", time_}});
+ CheckFiles({});
+}
+
+TEST_F(TestMockFS, DeleteDir) {
+ ASSERT_OK(fs_->CreateDir("AB/CD/EF"));
+ ASSERT_OK(fs_->CreateDir("AB/GH/IJ"));
+ ASSERT_OK(fs_->DeleteDir("AB/CD"));
+ ASSERT_OK(fs_->DeleteDir("AB/GH/IJ"));
+ CheckDirs({{"AB", time_}, {"AB/GH", time_}});
+ CheckFiles({});
+ ASSERT_RAISES(IOError, fs_->DeleteDir("AB/CD"));
+ ASSERT_OK(fs_->DeleteDir("AB"));
+ CheckDirs({});
+ CheckFiles({});
+}
+
+TEST_F(TestMockFS, DeleteFile) {
+ ASSERT_OK(fs_->CreateDir("AB"));
+ CreateFile("AB/cd", "data");
+ CheckDirs({{"AB", time_}});
+ CheckFiles({{"AB/cd", time_, "data"}});
+
+ ASSERT_OK(fs_->DeleteFile("AB/cd"));
+ CheckDirs({{"AB", time_}});
+ CheckFiles({});
+
+ CreateFile("ab", "data");
+ CheckDirs({{"AB", time_}});
+ CheckFiles({{"ab", time_, "data"}});
+
+ ASSERT_OK(fs_->DeleteFile("ab"));
+ CheckDirs({{"AB", time_}});
+ CheckFiles({});
+}
+
+TEST_F(TestMockFS, GetFileInfo) {
+ ASSERT_OK(fs_->CreateDir("AB/CD"));
+ CreateFile("AB/CD/ef", "some data");
+
+ FileInfo info;
+ ASSERT_OK_AND_ASSIGN(info, fs_->GetFileInfo("AB"));
+ AssertFileInfo(info, "AB", FileType::Directory, time_);
+ ASSERT_EQ(info.base_name(), "AB");
+ ASSERT_OK_AND_ASSIGN(info, fs_->GetFileInfo("AB/CD/ef"));
+ AssertFileInfo(info, "AB/CD/ef", FileType::File, time_, 9);
+ ASSERT_EQ(info.base_name(), "ef");
+
+ // Invalid path
+ ASSERT_RAISES(Invalid, fs_->GetFileInfo("//foo//bar//baz//"));
+}
+
+TEST_F(TestMockFS, GetFileInfoVector) {
+ ASSERT_OK(fs_->CreateDir("AB/CD"));
+ CreateFile("AB/CD/ef", "some data");
+
+ std::vector<FileInfo> infos;
+ ASSERT_OK_AND_ASSIGN(
+ infos, fs_->GetFileInfo({"AB", "AB/CD", "AB/zz", "zz", "XX/zz", "AB/CD/ef"}));
+ ASSERT_EQ(infos.size(), 6);
+ AssertFileInfo(infos[0], "AB", FileType::Directory, time_);
+ AssertFileInfo(infos[1], "AB/CD", FileType::Directory, time_);
+ AssertFileInfo(infos[2], "AB/zz", FileType::NotFound);
+ AssertFileInfo(infos[3], "zz", FileType::NotFound);
+ AssertFileInfo(infos[4], "XX/zz", FileType::NotFound);
+ AssertFileInfo(infos[5], "AB/CD/ef", FileType::File, time_, 9);
+
+ // Invalid path
+ ASSERT_RAISES(Invalid, fs_->GetFileInfo({"AB", "AB/CD", "//foo//bar//baz//"}));
+}
+
+TEST_F(TestMockFS, GetFileInfoSelector) {
+ ASSERT_OK(fs_->CreateDir("AB/CD"));
+ CreateFile("ab", "data");
+
+ FileSelector s;
+ s.base_dir = "";
+ std::vector<FileInfo> infos;
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(s));
+ ASSERT_EQ(infos.size(), 2);
+ AssertFileInfo(infos[0], "AB", FileType::Directory, time_);
+ AssertFileInfo(infos[1], "ab", FileType::File, time_, 4);
+
+ s.recursive = true;
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(s));
+ ASSERT_EQ(infos.size(), 3);
+ AssertFileInfo(infos[0], "AB", FileType::Directory, time_);
+ AssertFileInfo(infos[1], "AB/CD", FileType::Directory, time_);
+ AssertFileInfo(infos[2], "ab", FileType::File, time_, 4);
+
+ // Invalid path
+ s.base_dir = "//foo//bar//baz//";
+ ASSERT_RAISES(Invalid, fs_->GetFileInfo(s));
+}
+
+TEST_F(TestMockFS, OpenOutputStream) {
+ ASSERT_OK_AND_ASSIGN(auto stream, fs_->OpenOutputStream("ab"));
+ ASSERT_OK(stream->Close());
+ CheckDirs({});
+ CheckFiles({{"ab", time_, ""}});
+
+ // With metadata
+ auto metadata = KeyValueMetadata::Make({"some key"}, {"some value"});
+ ASSERT_OK_AND_ASSIGN(stream, fs_->OpenOutputStream("cd", metadata));
+ ASSERT_OK(WriteString(stream.get(), "data"));
+ ASSERT_OK(stream->Close());
+ CheckFiles({{"ab", time_, ""}, {"cd", time_, "data"}});
+
+ ASSERT_OK_AND_ASSIGN(auto input, fs_->OpenInputStream("cd"));
+ ASSERT_OK_AND_ASSIGN(auto got_metadata, input->ReadMetadata());
+ ASSERT_NE(got_metadata, nullptr);
+ ASSERT_TRUE(got_metadata->Equals(*metadata));
+}
+
+TEST_F(TestMockFS, OpenAppendStream) {
+ ASSERT_OK_AND_ASSIGN(auto stream, fs_->OpenAppendStream("ab"));
+ ASSERT_OK(WriteString(stream.get(), "some "));
+ ASSERT_OK(stream->Close());
+
+ ASSERT_OK_AND_ASSIGN(stream, fs_->OpenAppendStream("ab"));
+ ASSERT_OK(WriteString(stream.get(), "data"));
+ ASSERT_OK(stream->Close());
+ CheckDirs({});
+ CheckFiles({{"ab", time_, "some data"}});
+}
+
+TEST_F(TestMockFS, Make) {
+ ASSERT_OK_AND_ASSIGN(fs_, MockFileSystem::Make(time_, {}));
+ CheckDirs({});
+ CheckFiles({});
+
+ ASSERT_OK_AND_ASSIGN(fs_, MockFileSystem::Make(time_, {Dir("A/B/C"), File("A/a")}));
+ CheckDirs({{"A", time_}, {"A/B", time_}, {"A/B/C", time_}});
+ CheckFiles({{"A/a", time_, ""}});
+}
+
+TEST_F(TestMockFS, FileSystemFromUri) {
+ std::string path;
+ ASSERT_OK_AND_ASSIGN(fs_, FileSystemFromUri("mock:", &path));
+ ASSERT_EQ(path, "");
+ CheckDirs({}); // Ensures it's a MockFileSystem
+ ASSERT_OK_AND_ASSIGN(fs_, FileSystemFromUri("mock:foo/bar", &path));
+ ASSERT_EQ(path, "foo/bar");
+ CheckDirs({});
+ ASSERT_OK_AND_ASSIGN(fs_, FileSystemFromUri("mock:/foo/bar", &path));
+ ASSERT_EQ(path, "foo/bar");
+ CheckDirs({});
+ ASSERT_OK_AND_ASSIGN(fs_, FileSystemFromUri("mock:/foo/bar/?q=xxx", &path));
+ ASSERT_EQ(path, "foo/bar/");
+ CheckDirs({});
+ ASSERT_OK_AND_ASSIGN(fs_, FileSystemFromUri("mock:///foo/bar", &path));
+ ASSERT_EQ(path, "foo/bar");
+ CheckDirs({});
+ ASSERT_OK_AND_ASSIGN(fs_, FileSystemFromUri("mock:///foo/bar?q=zzz", &path));
+ ASSERT_EQ(path, "foo/bar");
+ CheckDirs({});
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Concrete SubTreeFileSystem tests
+
+class TestSubTreeFileSystem : public TestMockFS {
+ public:
+ void SetUp() override {
+ TestMockFS::SetUp();
+ ASSERT_OK(fs_->CreateDir("sub/tree"));
+ subfs_ = std::make_shared<SubTreeFileSystem>("sub/tree", fs_);
+ }
+
+ void CreateFile(const std::string& path, const std::string& data) {
+ ::arrow::fs::CreateFile(subfs_.get(), path, data);
+ }
+
+ protected:
+ std::shared_ptr<SubTreeFileSystem> subfs_;
+};
+
+TEST_F(TestSubTreeFileSystem, CreateDir) {
+ ASSERT_OK(subfs_->CreateDir("AB"));
+ ASSERT_OK(subfs_->CreateDir("AB/CD/EF")); // Recursive
+ // Non-recursive, parent doesn't exist
+ ASSERT_RAISES(IOError, subfs_->CreateDir("AB/GH/IJ", false /* recursive */));
+ ASSERT_OK(subfs_->CreateDir("AB/GH", false /* recursive */));
+ ASSERT_OK(subfs_->CreateDir("AB/GH/IJ", false /* recursive */));
+ // Can't create root dir
+ ASSERT_RAISES(IOError, subfs_->CreateDir(""));
+ CheckDirs({{"sub", time_},
+ {"sub/tree", time_},
+ {"sub/tree/AB", time_},
+ {"sub/tree/AB/CD", time_},
+ {"sub/tree/AB/CD/EF", time_},
+ {"sub/tree/AB/GH", time_},
+ {"sub/tree/AB/GH/IJ", time_}});
+ CheckFiles({});
+}
+
+TEST_F(TestSubTreeFileSystem, DeleteDir) {
+ ASSERT_OK(subfs_->CreateDir("AB/CD/EF"));
+ ASSERT_OK(subfs_->CreateDir("AB/GH/IJ"));
+ ASSERT_OK(subfs_->DeleteDir("AB/CD"));
+ ASSERT_OK(subfs_->DeleteDir("AB/GH/IJ"));
+ CheckDirs({{"sub", time_},
+ {"sub/tree", time_},
+ {"sub/tree/AB", time_},
+ {"sub/tree/AB/GH", time_}});
+ CheckFiles({});
+ ASSERT_RAISES(IOError, subfs_->DeleteDir("AB/CD"));
+ ASSERT_OK(subfs_->DeleteDir("AB"));
+ CheckDirs({{"sub", time_}, {"sub/tree", time_}});
+ CheckFiles({});
+
+ // Can't delete root dir
+ ASSERT_RAISES(IOError, subfs_->DeleteDir(""));
+ CheckDirs({{"sub", time_}, {"sub/tree", time_}});
+ CheckFiles({});
+}
+
+TEST_F(TestSubTreeFileSystem, DeleteFile) {
+ ASSERT_OK(subfs_->CreateDir("AB"));
+
+ CreateFile("ab", "");
+ CheckFiles({{"sub/tree/ab", time_, ""}});
+ ASSERT_OK(subfs_->DeleteFile("ab"));
+ CheckFiles({});
+
+ CreateFile("AB/cd", "");
+ CheckFiles({{"sub/tree/AB/cd", time_, ""}});
+ ASSERT_OK(subfs_->DeleteFile("AB/cd"));
+ CheckFiles({});
+
+ ASSERT_RAISES(IOError, subfs_->DeleteFile("nonexistent"));
+ ASSERT_RAISES(IOError, subfs_->DeleteFile(""));
+}
+
+TEST_F(TestSubTreeFileSystem, MoveFile) {
+ CreateFile("ab", "");
+ CheckFiles({{"sub/tree/ab", time_, ""}});
+ ASSERT_OK(subfs_->Move("ab", "cd"));
+ CheckFiles({{"sub/tree/cd", time_, ""}});
+
+ ASSERT_OK(subfs_->CreateDir("AB"));
+ ASSERT_OK(subfs_->Move("cd", "AB/ef"));
+ CheckFiles({{"sub/tree/AB/ef", time_, ""}});
+
+ ASSERT_RAISES(IOError, subfs_->Move("AB/ef", ""));
+ ASSERT_RAISES(IOError, subfs_->Move("", "xxx"));
+ CheckFiles({{"sub/tree/AB/ef", time_, ""}});
+ CheckDirs({{"sub", time_}, {"sub/tree", time_}, {"sub/tree/AB", time_}});
+}
+
+TEST_F(TestSubTreeFileSystem, MoveDir) {
+ ASSERT_OK(subfs_->CreateDir("AB/CD/EF"));
+ ASSERT_OK(subfs_->Move("AB/CD", "GH"));
+ CheckDirs({{"sub", time_},
+ {"sub/tree", time_},
+ {"sub/tree/AB", time_},
+ {"sub/tree/GH", time_},
+ {"sub/tree/GH/EF", time_}});
+
+ ASSERT_RAISES(IOError, subfs_->Move("AB", ""));
+}
+
+TEST_F(TestSubTreeFileSystem, CopyFile) {
+ CreateFile("ab", "data");
+ CheckFiles({{"sub/tree/ab", time_, "data"}});
+ ASSERT_OK(subfs_->CopyFile("ab", "cd"));
+ CheckFiles({{"sub/tree/ab", time_, "data"}, {"sub/tree/cd", time_, "data"}});
+
+ ASSERT_OK(subfs_->CreateDir("AB"));
+ ASSERT_OK(subfs_->CopyFile("cd", "AB/ef"));
+ CheckFiles({{"sub/tree/AB/ef", time_, "data"},
+ {"sub/tree/ab", time_, "data"},
+ {"sub/tree/cd", time_, "data"}});
+
+ ASSERT_RAISES(IOError, subfs_->CopyFile("ab", ""));
+ ASSERT_RAISES(IOError, subfs_->CopyFile("", "xxx"));
+ CheckFiles({{"sub/tree/AB/ef", time_, "data"},
+ {"sub/tree/ab", time_, "data"},
+ {"sub/tree/cd", time_, "data"}});
+}
+
+TEST_F(TestSubTreeFileSystem, CopyFiles) {
+ ASSERT_OK(subfs_->CreateDir("AB"));
+ ASSERT_OK(subfs_->CreateDir("CD/CD"));
+ ASSERT_OK(subfs_->CreateDir("EF/EF/EF"));
+
+ CreateFile("AB/ab", "ab");
+ CreateFile("CD/CD/cd", "cd");
+ CreateFile("EF/EF/EF/ef", "ef");
+
+ ASSERT_OK(fs_->CreateDir("sub/copy"));
+ auto dest_fs = std::make_shared<SubTreeFileSystem>("sub/copy", fs_);
+
+ FileSelector sel;
+ sel.recursive = true;
+ ASSERT_OK(CopyFiles(subfs_, sel, dest_fs, ""));
+
+ CheckFiles({
+ {"sub/copy/AB/ab", time_, "ab"},
+ {"sub/copy/CD/CD/cd", time_, "cd"},
+ {"sub/copy/EF/EF/EF/ef", time_, "ef"},
+ {"sub/tree/AB/ab", time_, "ab"},
+ {"sub/tree/CD/CD/cd", time_, "cd"},
+ {"sub/tree/EF/EF/EF/ef", time_, "ef"},
+ });
+}
+
+TEST_F(TestSubTreeFileSystem, OpenInputStream) {
+ std::shared_ptr<io::InputStream> stream;
+ CreateFile("ab", "data");
+
+ ASSERT_OK_AND_ASSIGN(stream, subfs_->OpenInputStream("ab"));
+ ASSERT_OK_AND_ASSIGN(auto buffer, stream->Read(4));
+ AssertBufferEqual(*buffer, "data");
+ ASSERT_OK(stream->Close());
+
+ ASSERT_RAISES(IOError, subfs_->OpenInputStream("nonexistent"));
+ ASSERT_RAISES(IOError, subfs_->OpenInputStream(""));
+}
+
+TEST_F(TestSubTreeFileSystem, OpenInputFile) {
+ std::shared_ptr<io::RandomAccessFile> stream;
+ CreateFile("ab", "some data");
+
+ ASSERT_OK_AND_ASSIGN(stream, subfs_->OpenInputFile("ab"));
+ ASSERT_OK_AND_ASSIGN(auto buffer, stream->ReadAt(5, 4));
+ AssertBufferEqual(*buffer, "data");
+ ASSERT_OK(stream->Close());
+
+ ASSERT_RAISES(IOError, subfs_->OpenInputFile("nonexistent"));
+ ASSERT_RAISES(IOError, subfs_->OpenInputFile(""));
+}
+
+TEST_F(TestSubTreeFileSystem, OpenOutputStream) {
+ std::shared_ptr<io::OutputStream> stream;
+
+ ASSERT_OK_AND_ASSIGN(stream, subfs_->OpenOutputStream("ab"));
+ ASSERT_OK(stream->Write("data"));
+ ASSERT_OK(stream->Close());
+ CheckFiles({{"sub/tree/ab", time_, "data"}});
+
+ ASSERT_OK(subfs_->CreateDir("AB"));
+ ASSERT_OK_AND_ASSIGN(stream, subfs_->OpenOutputStream("AB/cd"));
+ ASSERT_OK(stream->Write("other"));
+ ASSERT_OK(stream->Close());
+ CheckFiles({{"sub/tree/AB/cd", time_, "other"}, {"sub/tree/ab", time_, "data"}});
+
+ ASSERT_RAISES(IOError, subfs_->OpenOutputStream("nonexistent/xxx"));
+ ASSERT_RAISES(IOError, subfs_->OpenOutputStream("AB"));
+ ASSERT_RAISES(IOError, subfs_->OpenOutputStream(""));
+ CheckFiles({{"sub/tree/AB/cd", time_, "other"}, {"sub/tree/ab", time_, "data"}});
+}
+
+TEST_F(TestSubTreeFileSystem, OpenAppendStream) {
+ std::shared_ptr<io::OutputStream> stream;
+
+ ASSERT_OK_AND_ASSIGN(stream, subfs_->OpenAppendStream("ab"));
+ ASSERT_OK(stream->Write("some"));
+ ASSERT_OK(stream->Close());
+ CheckFiles({{"sub/tree/ab", time_, "some"}});
+
+ ASSERT_OK_AND_ASSIGN(stream, subfs_->OpenAppendStream("ab"));
+ ASSERT_OK(stream->Write(" data"));
+ ASSERT_OK(stream->Close());
+ CheckFiles({{"sub/tree/ab", time_, "some data"}});
+}
+
+TEST_F(TestSubTreeFileSystem, GetFileInfo) {
+ ASSERT_OK(subfs_->CreateDir("AB/CD"));
+
+ AssertFileInfo(subfs_.get(), "AB", FileType::Directory, time_);
+ AssertFileInfo(subfs_.get(), "AB/CD", FileType::Directory, time_);
+
+ CreateFile("ab", "data");
+ AssertFileInfo(subfs_.get(), "ab", FileType::File, time_, 4);
+
+ AssertFileInfo(subfs_.get(), "nonexistent", FileType::NotFound);
+}
+
+TEST_F(TestSubTreeFileSystem, GetFileInfoVector) {
+ std::vector<FileInfo> infos;
+
+ ASSERT_OK(subfs_->CreateDir("AB/CD"));
+ CreateFile("ab", "data");
+ CreateFile("AB/cd", "other data");
+
+ ASSERT_OK_AND_ASSIGN(infos, subfs_->GetFileInfo({"ab", "AB", "AB/cd", "nonexistent"}));
+ ASSERT_EQ(infos.size(), 4);
+ AssertFileInfo(infos[0], "ab", FileType::File, time_, 4);
+ AssertFileInfo(infos[1], "AB", FileType::Directory, time_);
+ AssertFileInfo(infos[2], "AB/cd", FileType::File, time_, 10);
+ AssertFileInfo(infos[3], "nonexistent", FileType::NotFound);
+}
+
+TEST_F(TestSubTreeFileSystem, GetFileInfoSelector) {
+ std::vector<FileInfo> infos;
+ FileSelector selector;
+
+ ASSERT_OK(subfs_->CreateDir("AB/CD"));
+ CreateFile("ab", "data");
+ CreateFile("AB/cd", "data2");
+ CreateFile("AB/CD/ef", "data34");
+
+ selector.base_dir = "AB";
+ selector.recursive = false;
+ ASSERT_OK_AND_ASSIGN(infos, subfs_->GetFileInfo(selector));
+ ASSERT_EQ(infos.size(), 2);
+ AssertFileInfo(infos[0], "AB/CD", FileType::Directory, time_);
+ AssertFileInfo(infos[1], "AB/cd", FileType::File, time_, 5);
+
+ selector.recursive = true;
+ ASSERT_OK_AND_ASSIGN(infos, subfs_->GetFileInfo(selector));
+ ASSERT_EQ(infos.size(), 3);
+ AssertFileInfo(infos[0], "AB/CD", FileType::Directory, time_);
+ AssertFileInfo(infos[1], "AB/CD/ef", FileType::File, time_, 6);
+ AssertFileInfo(infos[2], "AB/cd", FileType::File, time_, 5);
+
+ selector.base_dir = "";
+ selector.recursive = false;
+ ASSERT_OK_AND_ASSIGN(infos, subfs_->GetFileInfo(selector));
+ ASSERT_EQ(infos.size(), 2);
+ AssertFileInfo(infos[0], "AB", FileType::Directory, time_);
+ AssertFileInfo(infos[1], "ab", FileType::File, time_, 4);
+
+ selector.recursive = true;
+ ASSERT_OK_AND_ASSIGN(infos, subfs_->GetFileInfo(selector));
+ ASSERT_EQ(infos.size(), 5);
+ AssertFileInfo(infos[0], "AB", FileType::Directory, time_);
+ AssertFileInfo(infos[1], "AB/CD", FileType::Directory, time_);
+ AssertFileInfo(infos[2], "AB/CD/ef", FileType::File, time_, 6);
+ AssertFileInfo(infos[3], "AB/cd", FileType::File, time_, 5);
+ AssertFileInfo(infos[4], "ab", FileType::File, time_, 4);
+
+ selector.base_dir = "nonexistent";
+ ASSERT_RAISES(IOError, subfs_->GetFileInfo(selector));
+ selector.allow_not_found = true;
+ ASSERT_OK_AND_ASSIGN(infos, subfs_->GetFileInfo(selector));
+ ASSERT_EQ(infos.size(), 0);
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Generic SlowFileSystem tests
+
+class TestSlowFSGeneric : public ::testing::Test, public GenericFileSystemTest {
+ public:
+ void SetUp() override {
+ time_ = TimePoint(TimePoint::duration(42));
+ fs_ = std::make_shared<MockFileSystem>(time_);
+ slow_fs_ = std::make_shared<SlowFileSystem>(fs_, 0.001);
+ }
+
+ protected:
+ std::shared_ptr<FileSystem> GetEmptyFileSystem() override { return slow_fs_; }
+
+ TimePoint time_;
+ std::shared_ptr<MockFileSystem> fs_;
+ std::shared_ptr<SlowFileSystem> slow_fs_;
+};
+
+GENERIC_FS_TEST_FUNCTIONS(TestSlowFSGeneric);
+
+} // namespace internal
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/gcsfs.cc b/src/arrow/cpp/src/arrow/filesystem/gcsfs.cc
new file mode 100644
index 000000000..898e54cf5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/gcsfs.cc
@@ -0,0 +1,269 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/filesystem/gcsfs.h"
+
+#include <google/cloud/storage/client.h>
+
+#include "arrow/buffer.h"
+#include "arrow/filesystem/gcsfs_internal.h"
+#include "arrow/filesystem/path_util.h"
+#include "arrow/result.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+namespace fs {
+namespace {
+
+namespace gcs = google::cloud::storage;
+
+auto constexpr kSep = '/';
+
+struct GcsPath {
+ std::string full_path;
+ std::string bucket;
+ std::string object;
+
+ static Result<GcsPath> FromString(const std::string& s) {
+ const auto src = internal::RemoveTrailingSlash(s);
+ auto const first_sep = src.find_first_of(kSep);
+ if (first_sep == 0) {
+ return Status::Invalid("Path cannot start with a separator ('", s, "')");
+ }
+ if (first_sep == std::string::npos) {
+ return GcsPath{std::string(src), std::string(src), ""};
+ }
+ GcsPath path;
+ path.full_path = std::string(src);
+ path.bucket = std::string(src.substr(0, first_sep));
+ path.object = std::string(src.substr(first_sep + 1));
+ return path;
+ }
+
+ bool empty() const { return bucket.empty() && object.empty(); }
+
+ bool operator==(const GcsPath& other) const {
+ return bucket == other.bucket && object == other.object;
+ }
+};
+
+class GcsInputStream : public arrow::io::InputStream {
+ public:
+ explicit GcsInputStream(gcs::ObjectReadStream stream) : stream_(std::move(stream)) {}
+
+ ~GcsInputStream() override = default;
+
+ Status Close() override {
+ stream_.Close();
+ return Status::OK();
+ }
+
+ Result<int64_t> Tell() const override {
+ if (!stream_) {
+ return Status::IOError("invalid stream");
+ }
+ return stream_.tellg();
+ }
+
+ bool closed() const override { return !stream_.IsOpen(); }
+
+ Result<int64_t> Read(int64_t nbytes, void* out) override {
+ stream_.read(static_cast<char*>(out), nbytes);
+ if (!stream_.status().ok()) {
+ return internal::ToArrowStatus(stream_.status());
+ }
+ return stream_.gcount();
+ }
+
+ Result<std::shared_ptr<Buffer>> Read(int64_t nbytes) override {
+ ARROW_ASSIGN_OR_RAISE(auto buffer, arrow::AllocateResizableBuffer(nbytes));
+ stream_.read(reinterpret_cast<char*>(buffer->mutable_data()), nbytes);
+ if (!stream_.status().ok()) {
+ return internal::ToArrowStatus(stream_.status());
+ }
+ return arrow::SliceMutableBufferSafe(std::move(buffer), 0, stream_.gcount());
+ }
+
+ private:
+ mutable gcs::ObjectReadStream stream_;
+};
+
+} // namespace
+
+google::cloud::Options AsGoogleCloudOptions(const GcsOptions& o) {
+ auto options = google::cloud::Options{};
+ if (!o.endpoint_override.empty()) {
+ std::string scheme = o.scheme;
+ if (scheme.empty()) scheme = "https";
+ if (scheme == "https") {
+ options.set<google::cloud::UnifiedCredentialsOption>(
+ google::cloud::MakeGoogleDefaultCredentials());
+ } else {
+ options.set<google::cloud::UnifiedCredentialsOption>(
+ google::cloud::MakeInsecureCredentials());
+ }
+ options.set<gcs::RestEndpointOption>(scheme + "://" + o.endpoint_override);
+ }
+ return options;
+}
+
+class GcsFileSystem::Impl {
+ public:
+ explicit Impl(GcsOptions o)
+ : options_(std::move(o)), client_(AsGoogleCloudOptions(options_)) {}
+
+ const GcsOptions& options() const { return options_; }
+
+ Result<FileInfo> GetFileInfo(const GcsPath& path) {
+ if (!path.object.empty()) {
+ auto meta = client_.GetObjectMetadata(path.bucket, path.object);
+ return GetFileInfoImpl(path, std::move(meta).status(), FileType::File);
+ }
+ auto meta = client_.GetBucketMetadata(path.bucket);
+ return GetFileInfoImpl(path, std::move(meta).status(), FileType::Directory);
+ }
+
+ Result<std::shared_ptr<io::InputStream>> OpenInputStream(const GcsPath& path) {
+ auto stream = client_.ReadObject(path.bucket, path.object);
+ if (!stream.status().ok()) {
+ return internal::ToArrowStatus(stream.status());
+ }
+ return std::make_shared<GcsInputStream>(std::move(stream));
+ }
+
+ private:
+ static Result<FileInfo> GetFileInfoImpl(const GcsPath& path,
+ const google::cloud::Status& status,
+ FileType type) {
+ if (status.ok()) {
+ return FileInfo(path.full_path, type);
+ }
+ using ::google::cloud::StatusCode;
+ if (status.code() == StatusCode::kNotFound) {
+ return FileInfo(path.full_path, FileType::NotFound);
+ }
+ return internal::ToArrowStatus(status);
+ }
+
+ GcsOptions options_;
+ gcs::Client client_;
+};
+
+bool GcsOptions::Equals(const GcsOptions& other) const {
+ return endpoint_override == other.endpoint_override && scheme == other.scheme;
+}
+
+std::string GcsFileSystem::type_name() const { return "gcs"; }
+
+bool GcsFileSystem::Equals(const FileSystem& other) const {
+ if (this == &other) {
+ return true;
+ }
+ if (other.type_name() != type_name()) {
+ return false;
+ }
+ const auto& fs = ::arrow::internal::checked_cast<const GcsFileSystem&>(other);
+ return impl_->options().Equals(fs.impl_->options());
+}
+
+Result<FileInfo> GcsFileSystem::GetFileInfo(const std::string& path) {
+ ARROW_ASSIGN_OR_RAISE(auto p, GcsPath::FromString(path));
+ return impl_->GetFileInfo(p);
+}
+
+Result<FileInfoVector> GcsFileSystem::GetFileInfo(const FileSelector& select) {
+ return Status::NotImplemented("The GCS FileSystem is not fully implemented");
+}
+
+Status GcsFileSystem::CreateDir(const std::string& path, bool recursive) {
+ return Status::NotImplemented("The GCS FileSystem is not fully implemented");
+}
+
+Status GcsFileSystem::DeleteDir(const std::string& path) {
+ return Status::NotImplemented("The GCS FileSystem is not fully implemented");
+}
+
+Status GcsFileSystem::DeleteDirContents(const std::string& path) {
+ return Status::NotImplemented("The GCS FileSystem is not fully implemented");
+}
+
+Status GcsFileSystem::DeleteRootDirContents() {
+ return Status::NotImplemented("The GCS FileSystem is not fully implemented");
+}
+
+Status GcsFileSystem::DeleteFile(const std::string& path) {
+ return Status::NotImplemented("The GCS FileSystem is not fully implemented");
+}
+
+Status GcsFileSystem::Move(const std::string& src, const std::string& dest) {
+ return Status::NotImplemented("The GCS FileSystem is not fully implemented");
+}
+
+Status GcsFileSystem::CopyFile(const std::string& src, const std::string& dest) {
+ return Status::NotImplemented("The GCS FileSystem is not fully implemented");
+}
+
+Result<std::shared_ptr<io::InputStream>> GcsFileSystem::OpenInputStream(
+ const std::string& path) {
+ ARROW_ASSIGN_OR_RAISE(auto p, GcsPath::FromString(path));
+ return impl_->OpenInputStream(p);
+}
+
+Result<std::shared_ptr<io::InputStream>> GcsFileSystem::OpenInputStream(
+ const FileInfo& info) {
+ if (!info.IsFile()) {
+ return Status::IOError("Only files can be opened as input streams");
+ }
+ ARROW_ASSIGN_OR_RAISE(auto p, GcsPath::FromString(info.path()));
+ return impl_->OpenInputStream(p);
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> GcsFileSystem::OpenInputFile(
+ const std::string& path) {
+ return Status::NotImplemented("The GCS FileSystem is not fully implemented");
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> GcsFileSystem::OpenInputFile(
+ const FileInfo& info) {
+ return Status::NotImplemented("The GCS FileSystem is not fully implemented");
+}
+
+Result<std::shared_ptr<io::OutputStream>> GcsFileSystem::OpenOutputStream(
+ const std::string& path, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ return Status::NotImplemented("The GCS FileSystem is not fully implemented");
+}
+
+Result<std::shared_ptr<io::OutputStream>> GcsFileSystem::OpenAppendStream(
+ const std::string&, const std::shared_ptr<const KeyValueMetadata>&) {
+ return Status::NotImplemented("Append is not supported in GCS");
+}
+
+GcsFileSystem::GcsFileSystem(const GcsOptions& options, const io::IOContext& context)
+ : FileSystem(context), impl_(std::make_shared<Impl>(options)) {}
+
+namespace internal {
+
+std::shared_ptr<GcsFileSystem> MakeGcsFileSystemForTest(const GcsOptions& options) {
+ // Cannot use `std::make_shared<>` as the constructor is private.
+ return std::shared_ptr<GcsFileSystem>(
+ new GcsFileSystem(options, io::default_io_context()));
+}
+
+} // namespace internal
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/gcsfs.h b/src/arrow/cpp/src/arrow/filesystem/gcsfs.h
new file mode 100644
index 000000000..2583bdee8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/gcsfs.h
@@ -0,0 +1,118 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/filesystem/filesystem.h"
+
+namespace arrow {
+namespace fs {
+class GcsFileSystem;
+struct GcsOptions;
+namespace internal {
+// TODO(ARROW-1231) - remove, and provide a public API (static GcsFileSystem::Make()).
+std::shared_ptr<GcsFileSystem> MakeGcsFileSystemForTest(const GcsOptions& options);
+} // namespace internal
+
+/// Options for the GcsFileSystem implementation.
+struct ARROW_EXPORT GcsOptions {
+ std::string endpoint_override;
+ std::string scheme;
+
+ bool Equals(const GcsOptions& other) const;
+};
+
+/// \brief GCS-backed FileSystem implementation.
+///
+/// Some implementation notes:
+/// - TODO(ARROW-1231) - review all the notes once completed.
+/// - buckets are treated as top-level directories on a "root".
+/// - GCS buckets are in a global namespace, only one bucket
+/// named `foo` exists in Google Cloud.
+/// - Creating new top-level directories is implemented by creating
+/// a bucket, this may be a slower operation than usual.
+/// - A principal (service account, user, etc) can only list the
+/// buckets for a single project, but can access the buckets
+/// for many projects. It is possible that listing "all"
+/// the buckets returns fewer buckets than you have access to.
+/// - GCS does not have directories, they are emulated in this
+/// library by listing objects with a common prefix.
+/// - In general, GCS has much higher latency than local filesystems.
+/// The throughput of GCS is comparable to the throughput of
+/// a local file system.
+class ARROW_EXPORT GcsFileSystem : public FileSystem {
+ public:
+ ~GcsFileSystem() override = default;
+
+ std::string type_name() const override;
+
+ bool Equals(const FileSystem& other) const override;
+
+ Result<FileInfo> GetFileInfo(const std::string& path) override;
+ Result<FileInfoVector> GetFileInfo(const FileSelector& select) override;
+
+ Status CreateDir(const std::string& path, bool recursive) override;
+
+ Status DeleteDir(const std::string& path) override;
+
+ Status DeleteDirContents(const std::string& path) override;
+
+ Status DeleteRootDirContents() override;
+
+ Status DeleteFile(const std::string& path) override;
+
+ Status Move(const std::string& src, const std::string& dest) override;
+
+ Status CopyFile(const std::string& src, const std::string& dest) override;
+
+ Result<std::shared_ptr<io::InputStream>> OpenInputStream(
+ const std::string& path) override;
+ Result<std::shared_ptr<io::InputStream>> OpenInputStream(const FileInfo& info) override;
+
+ Result<std::shared_ptr<io::RandomAccessFile>> OpenInputFile(
+ const std::string& path) override;
+ Result<std::shared_ptr<io::RandomAccessFile>> OpenInputFile(
+ const FileInfo& info) override;
+
+ Result<std::shared_ptr<io::OutputStream>> OpenOutputStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata) override;
+
+ ARROW_DEPRECATED(
+ "Deprecated. "
+ "OpenAppendStream is unsupported on the GCS FileSystem.")
+ Result<std::shared_ptr<io::OutputStream>> OpenAppendStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata) override;
+
+ private:
+ /// Create a GcsFileSystem instance from the given options.
+ friend std::shared_ptr<GcsFileSystem> internal::MakeGcsFileSystemForTest(
+ const GcsOptions& options);
+
+ explicit GcsFileSystem(const GcsOptions& options, const io::IOContext& io_context);
+
+ class Impl;
+ std::shared_ptr<Impl> impl_;
+};
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/gcsfs_internal.cc b/src/arrow/cpp/src/arrow/filesystem/gcsfs_internal.cc
new file mode 100644
index 000000000..898015859
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/gcsfs_internal.cc
@@ -0,0 +1,67 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/filesystem/gcsfs_internal.h"
+
+#include <google/cloud/storage/client.h>
+
+#include <sstream>
+
+namespace arrow {
+namespace fs {
+namespace internal {
+
+Status ToArrowStatus(const google::cloud::Status& s) {
+ std::ostringstream os;
+ os << "google::cloud::Status(" << s << ")";
+ switch (s.code()) {
+ case google::cloud::StatusCode::kOk:
+ break;
+ case google::cloud::StatusCode::kCancelled:
+ return Status::Cancelled(os.str());
+ case google::cloud::StatusCode::kUnknown:
+ return Status::UnknownError(os.str());
+ case google::cloud::StatusCode::kInvalidArgument:
+ return Status::Invalid(os.str());
+ case google::cloud::StatusCode::kDeadlineExceeded:
+ case google::cloud::StatusCode::kNotFound:
+ return Status::IOError(os.str());
+ case google::cloud::StatusCode::kAlreadyExists:
+ return Status::AlreadyExists(os.str());
+ case google::cloud::StatusCode::kPermissionDenied:
+ case google::cloud::StatusCode::kUnauthenticated:
+ return Status::IOError(os.str());
+ case google::cloud::StatusCode::kResourceExhausted:
+ return Status::CapacityError(os.str());
+ case google::cloud::StatusCode::kFailedPrecondition:
+ case google::cloud::StatusCode::kAborted:
+ return Status::IOError(os.str());
+ case google::cloud::StatusCode::kOutOfRange:
+ return Status::Invalid(os.str());
+ case google::cloud::StatusCode::kUnimplemented:
+ return Status::NotImplemented(os.str());
+ case google::cloud::StatusCode::kInternal:
+ case google::cloud::StatusCode::kUnavailable:
+ case google::cloud::StatusCode::kDataLoss:
+ return Status::IOError(os.str());
+ }
+ return Status::OK();
+}
+
+} // namespace internal
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/gcsfs_internal.h b/src/arrow/cpp/src/arrow/filesystem/gcsfs_internal.h
new file mode 100644
index 000000000..8d568701e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/gcsfs_internal.h
@@ -0,0 +1,36 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <google/cloud/status.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/filesystem/filesystem.h"
+
+namespace arrow {
+namespace fs {
+namespace internal {
+
+Status ToArrowStatus(const google::cloud::Status& s);
+
+} // namespace internal
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/gcsfs_test.cc b/src/arrow/cpp/src/arrow/filesystem/gcsfs_test.cc
new file mode 100644
index 000000000..369317fbb
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/gcsfs_test.cc
@@ -0,0 +1,264 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/filesystem/gcsfs.h"
+
+#include <gmock/gmock-matchers.h>
+#include <gmock/gmock-more-matchers.h>
+#include <google/cloud/credentials.h>
+#include <google/cloud/storage/client.h>
+#include <google/cloud/storage/options.h>
+#include <gtest/gtest.h>
+
+#include <array>
+#include <boost/process.hpp>
+#include <string>
+
+#include "arrow/filesystem/gcsfs_internal.h"
+#include "arrow/filesystem/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+
+namespace arrow {
+namespace fs {
+namespace {
+
+namespace bp = boost::process;
+namespace gc = google::cloud;
+namespace gcs = google::cloud::storage;
+
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
+using ::testing::Not;
+using ::testing::NotNull;
+
+auto const* kPreexistingBucket = "test-bucket-name";
+auto const* kPreexistingObject = "test-object-name";
+auto const* kLoremIpsum = R"""(
+Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor
+incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis
+nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu
+fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in
+culpa qui officia deserunt mollit anim id est laborum.
+)""";
+
+class GcsIntegrationTest : public ::testing::Test {
+ public:
+ ~GcsIntegrationTest() override {
+ if (server_process_.valid()) {
+ // Brutal shutdown
+ server_process_.terminate();
+ server_process_.wait();
+ }
+ }
+
+ protected:
+ void SetUp() override {
+ port_ = std::to_string(GetListenPort());
+ auto exe_path = bp::search_path("python3");
+ ASSERT_THAT(exe_path, Not(IsEmpty()));
+
+ server_process_ = bp::child(boost::this_process::environment(), exe_path, "-m",
+ "testbench", "--port", port_);
+
+ // Create a bucket and a small file in the testbench. This makes it easier to
+ // bootstrap GcsFileSystem and its tests.
+ auto client = gcs::Client(
+ google::cloud::Options{}
+ .set<gcs::RestEndpointOption>("http://127.0.0.1:" + port_)
+ .set<gc::UnifiedCredentialsOption>(gc::MakeInsecureCredentials()));
+ google::cloud::StatusOr<gcs::BucketMetadata> bucket = client.CreateBucketForProject(
+ kPreexistingBucket, "ignored-by-testbench", gcs::BucketMetadata{});
+ ASSERT_TRUE(bucket.ok()) << "Failed to create bucket <" << kPreexistingBucket
+ << ">, status=" << bucket.status();
+
+ google::cloud::StatusOr<gcs::ObjectMetadata> object =
+ client.InsertObject(kPreexistingBucket, kPreexistingObject, kLoremIpsum);
+ ASSERT_TRUE(object.ok()) << "Failed to create object <" << kPreexistingObject
+ << ">, status=" << object.status();
+ }
+
+ static std::string PreexistingObjectPath() {
+ return std::string(kPreexistingBucket) + "/" + kPreexistingObject;
+ }
+
+ static std::string NotFoundObjectPath() {
+ return std::string(kPreexistingBucket) + "/not-found";
+ }
+
+ GcsOptions TestGcsOptions() {
+ GcsOptions options;
+ options.endpoint_override = "127.0.0.1:" + port_;
+ options.scheme = "http";
+ return options;
+ }
+
+ private:
+ std::string port_;
+ bp::child server_process_;
+};
+
+TEST(GcsFileSystem, OptionsCompare) {
+ GcsOptions a;
+ GcsOptions b;
+ b.endpoint_override = "localhost:1234";
+ EXPECT_TRUE(a.Equals(a));
+ EXPECT_TRUE(b.Equals(b));
+ auto c = b;
+ c.scheme = "http";
+ EXPECT_FALSE(b.Equals(c));
+}
+
+TEST(GcsFileSystem, ToArrowStatusOK) {
+ Status actual = internal::ToArrowStatus(google::cloud::Status());
+ EXPECT_TRUE(actual.ok());
+}
+
+TEST(GcsFileSystem, ToArrowStatus) {
+ struct {
+ google::cloud::StatusCode input;
+ arrow::StatusCode expected;
+ } cases[] = {
+ {google::cloud::StatusCode::kCancelled, StatusCode::Cancelled},
+ {google::cloud::StatusCode::kUnknown, StatusCode::UnknownError},
+ {google::cloud::StatusCode::kInvalidArgument, StatusCode::Invalid},
+ {google::cloud::StatusCode::kDeadlineExceeded, StatusCode::IOError},
+ {google::cloud::StatusCode::kNotFound, StatusCode::IOError},
+ {google::cloud::StatusCode::kAlreadyExists, StatusCode::AlreadyExists},
+ {google::cloud::StatusCode::kPermissionDenied, StatusCode::IOError},
+ {google::cloud::StatusCode::kUnauthenticated, StatusCode::IOError},
+ {google::cloud::StatusCode::kResourceExhausted, StatusCode::CapacityError},
+ {google::cloud::StatusCode::kFailedPrecondition, StatusCode::IOError},
+ {google::cloud::StatusCode::kAborted, StatusCode::IOError},
+ {google::cloud::StatusCode::kOutOfRange, StatusCode::Invalid},
+ {google::cloud::StatusCode::kUnimplemented, StatusCode::NotImplemented},
+ {google::cloud::StatusCode::kInternal, StatusCode::IOError},
+ {google::cloud::StatusCode::kUnavailable, StatusCode::IOError},
+ {google::cloud::StatusCode::kDataLoss, StatusCode::IOError},
+ };
+
+ for (const auto& test : cases) {
+ auto status = google::cloud::Status(test.input, "test-message");
+ auto message = [&] {
+ std::ostringstream os;
+ os << status;
+ return os.str();
+ }();
+ SCOPED_TRACE("Testing with status=" + message);
+ const auto actual = arrow::fs::internal::ToArrowStatus(status);
+ EXPECT_EQ(actual.code(), test.expected);
+ EXPECT_THAT(actual.message(), HasSubstr(message));
+ }
+}
+
+TEST(GcsFileSystem, FileSystemCompare) {
+ GcsOptions a_options;
+ a_options.scheme = "http";
+ auto a = internal::MakeGcsFileSystemForTest(a_options);
+ EXPECT_THAT(a, NotNull());
+ EXPECT_TRUE(a->Equals(*a));
+
+ GcsOptions b_options;
+ b_options.scheme = "http";
+ b_options.endpoint_override = "localhost:1234";
+ auto b = internal::MakeGcsFileSystemForTest(b_options);
+ EXPECT_THAT(b, NotNull());
+ EXPECT_TRUE(b->Equals(*b));
+
+ EXPECT_FALSE(a->Equals(*b));
+}
+
+TEST_F(GcsIntegrationTest, GetFileInfoBucket) {
+ auto fs = internal::MakeGcsFileSystemForTest(TestGcsOptions());
+ arrow::fs::AssertFileInfo(fs.get(), kPreexistingBucket, FileType::Directory);
+}
+
+TEST_F(GcsIntegrationTest, GetFileInfoObject) {
+ auto fs = internal::MakeGcsFileSystemForTest(TestGcsOptions());
+ arrow::fs::AssertFileInfo(fs.get(), PreexistingObjectPath(), FileType::File);
+}
+
+TEST_F(GcsIntegrationTest, ReadObjectString) {
+ auto fs = internal::MakeGcsFileSystemForTest(TestGcsOptions());
+
+ std::shared_ptr<io::InputStream> stream;
+ ASSERT_OK_AND_ASSIGN(stream, fs->OpenInputStream(PreexistingObjectPath()));
+
+ std::array<char, 1024> buffer{};
+ std::int64_t size;
+ ASSERT_OK_AND_ASSIGN(size, stream->Read(buffer.size(), buffer.data()));
+
+ EXPECT_EQ(std::string(buffer.data(), size), kLoremIpsum);
+}
+
+TEST_F(GcsIntegrationTest, ReadObjectStringBuffers) {
+ auto fs = internal::MakeGcsFileSystemForTest(TestGcsOptions());
+
+ std::shared_ptr<io::InputStream> stream;
+ ASSERT_OK_AND_ASSIGN(stream, fs->OpenInputStream(PreexistingObjectPath()));
+
+ std::string contents;
+ std::shared_ptr<Buffer> buffer;
+ do {
+ ASSERT_OK_AND_ASSIGN(buffer, stream->Read(16));
+ contents.append(buffer->ToString());
+ } while (buffer && buffer->size() != 0);
+
+ EXPECT_EQ(contents, kLoremIpsum);
+}
+
+TEST_F(GcsIntegrationTest, ReadObjectInfo) {
+ auto fs = internal::MakeGcsFileSystemForTest(TestGcsOptions());
+
+ arrow::fs::FileInfo info;
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo(PreexistingObjectPath()));
+
+ std::shared_ptr<io::InputStream> stream;
+ ASSERT_OK_AND_ASSIGN(stream, fs->OpenInputStream(info));
+
+ std::array<char, 1024> buffer{};
+ std::int64_t size;
+ ASSERT_OK_AND_ASSIGN(size, stream->Read(buffer.size(), buffer.data()));
+
+ EXPECT_EQ(std::string(buffer.data(), size), kLoremIpsum);
+}
+
+TEST_F(GcsIntegrationTest, ReadObjectNotFound) {
+ auto fs = internal::MakeGcsFileSystemForTest(TestGcsOptions());
+
+ auto result = fs->OpenInputStream(NotFoundObjectPath());
+ EXPECT_EQ(result.status().code(), StatusCode::IOError);
+}
+
+TEST_F(GcsIntegrationTest, ReadObjectInfoInvalid) {
+ auto fs = internal::MakeGcsFileSystemForTest(TestGcsOptions());
+
+ arrow::fs::FileInfo info;
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo(kPreexistingBucket));
+
+ auto result = fs->OpenInputStream(NotFoundObjectPath());
+ EXPECT_EQ(result.status().code(), StatusCode::IOError);
+
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo(NotFoundObjectPath()));
+ result = fs->OpenInputStream(NotFoundObjectPath());
+ EXPECT_EQ(result.status().code(), StatusCode::IOError);
+}
+
+} // namespace
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/hdfs.cc b/src/arrow/cpp/src/arrow/filesystem/hdfs.cc
new file mode 100644
index 000000000..c6396deac
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/hdfs.cc
@@ -0,0 +1,518 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <chrono>
+#include <cstring>
+#include <unordered_map>
+#include <utility>
+
+#include "arrow/filesystem/hdfs.h"
+#include "arrow/filesystem/path_util.h"
+#include "arrow/filesystem/util_internal.h"
+#include "arrow/io/hdfs.h"
+#include "arrow/io/hdfs_internal.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/value_parsing.h"
+#include "arrow/util/windows_fixup.h"
+
+namespace arrow {
+
+using internal::ParseValue;
+using internal::Uri;
+
+namespace fs {
+
+using internal::GetAbstractPathParent;
+using internal::MakeAbstractPathRelative;
+using internal::RemoveLeadingSlash;
+
+class HadoopFileSystem::Impl {
+ public:
+ Impl(HdfsOptions options, const io::IOContext& io_context)
+ : options_(std::move(options)), io_context_(io_context) {}
+
+ ~Impl() {
+ Status st = Close();
+ if (!st.ok()) {
+ ARROW_LOG(WARNING) << "Failed to disconnect hdfs client: " << st.ToString();
+ }
+ }
+
+ Status Init() {
+ io::internal::LibHdfsShim* driver_shim;
+ RETURN_NOT_OK(ConnectLibHdfs(&driver_shim));
+ RETURN_NOT_OK(io::HadoopFileSystem::Connect(&options_.connection_config, &client_));
+ return Status::OK();
+ }
+
+ Status Close() {
+ if (client_) {
+ RETURN_NOT_OK(client_->Disconnect());
+ }
+ return Status::OK();
+ }
+
+ HdfsOptions options() const { return options_; }
+
+ Result<FileInfo> GetFileInfo(const std::string& path) {
+ // It has unfortunately been a frequent logic error to pass URIs down
+ // to GetFileInfo (e.g. ARROW-10264). Unlike other filesystems, HDFS
+ // silently accepts URIs but returns different results than if given the
+ // equivalent in-filesystem paths. Instead of raising cryptic errors
+ // later, notify the underlying problem immediately.
+ if (path.substr(0, 5) == "hdfs:") {
+ return Status::Invalid("GetFileInfo must not be passed a URI, got: ", path);
+ }
+ FileInfo info;
+ io::HdfsPathInfo path_info;
+ auto status = client_->GetPathInfo(path, &path_info);
+ info.set_path(path);
+ if (status.IsIOError()) {
+ info.set_type(FileType::NotFound);
+ return info;
+ }
+
+ PathInfoToFileInfo(path_info, &info);
+ return info;
+ }
+
+ Status StatSelector(const std::string& wd, const std::string& path,
+ const FileSelector& select, int nesting_depth,
+ std::vector<FileInfo>* out) {
+ std::vector<io::HdfsPathInfo> children;
+ Status st = client_->ListDirectory(path, &children);
+ if (!st.ok()) {
+ if (select.allow_not_found) {
+ ARROW_ASSIGN_OR_RAISE(auto info, GetFileInfo(path));
+ if (info.type() == FileType::NotFound) {
+ return Status::OK();
+ }
+ }
+ return st;
+ }
+ for (const auto& child_path_info : children) {
+ // HDFS returns an absolute "URI" here, need to extract path relative to wd
+ // XXX: unfortunately, this is not a real URI as special characters
+ // are not %-escaped... hence parsing it as URI would fail.
+ std::string child_path;
+ if (!wd.empty()) {
+ if (child_path_info.name.substr(0, wd.length()) != wd) {
+ return Status::IOError("HDFS returned path '", child_path_info.name,
+ "' that is not a child of '", wd, "'");
+ }
+ child_path = child_path_info.name.substr(wd.length());
+ } else {
+ child_path = child_path_info.name;
+ }
+
+ FileInfo info;
+ info.set_path(child_path);
+ PathInfoToFileInfo(child_path_info, &info);
+ const bool is_dir = info.type() == FileType::Directory;
+ out->push_back(std::move(info));
+ if (is_dir && select.recursive && nesting_depth < select.max_recursion) {
+ RETURN_NOT_OK(StatSelector(wd, child_path, select, nesting_depth + 1, out));
+ }
+ }
+ return Status::OK();
+ }
+
+ Result<std::vector<FileInfo>> GetFileInfo(const FileSelector& select) {
+ // See GetFileInfo(const std::string&) above.
+ if (select.base_dir.substr(0, 5) == "hdfs:") {
+ return Status::Invalid("FileSelector.base_dir must not be a URI, got: ",
+ select.base_dir);
+ }
+ std::vector<FileInfo> results;
+
+ // Fetch working directory.
+ // If select.base_dir is relative, we need to trim it from the start
+ // of paths returned by ListDirectory.
+ // If select.base_dir is absolute, we need to trim the "URI authority"
+ // portion of the working directory.
+ std::string wd;
+ RETURN_NOT_OK(client_->GetWorkingDirectory(&wd));
+
+ if (!select.base_dir.empty() && select.base_dir.front() == '/') {
+ // base_dir is absolute, only keep the URI authority portion.
+ // As mentioned in StatSelector() above, the URI may contain unescaped
+ // special chars and therefore may not be a valid URI, so we parse by hand.
+ auto pos = wd.find("://"); // start of host:port portion
+ if (pos == std::string::npos) {
+ return Status::IOError("Unexpected HDFS working directory URI: ", wd);
+ }
+ pos = wd.find("/", pos + 3); // end of host:port portion
+ if (pos == std::string::npos) {
+ return Status::IOError("Unexpected HDFS working directory URI: ", wd);
+ }
+ wd = wd.substr(0, pos); // keep up until host:port (included)
+ } else if (!wd.empty() && wd.back() != '/') {
+ // For a relative lookup, trim leading slashes
+ wd += '/';
+ }
+
+ if (!select.base_dir.empty()) {
+ ARROW_ASSIGN_OR_RAISE(auto info, GetFileInfo(select.base_dir));
+ if (info.type() == FileType::File) {
+ return Status::IOError(
+ "GetFileInfo expects base_dir of selector to be a directory, but '",
+ select.base_dir, "' is a file");
+ }
+ }
+ RETURN_NOT_OK(StatSelector(wd, select.base_dir, select, 0, &results));
+ return results;
+ }
+
+ Status CreateDir(const std::string& path, bool recursive) {
+ if (IsDirectory(path)) {
+ return Status::OK();
+ }
+ if (!recursive) {
+ const auto parent = GetAbstractPathParent(path).first;
+ if (!parent.empty() && !IsDirectory(parent)) {
+ return Status::IOError("Cannot create directory '", path,
+ "': parent is not a directory");
+ }
+ }
+ RETURN_NOT_OK(client_->MakeDirectory(path));
+ return Status::OK();
+ }
+
+ Status DeleteDir(const std::string& path) {
+ if (!IsDirectory(path)) {
+ return Status::IOError("Cannot delete directory '", path, "': not a directory");
+ }
+ RETURN_NOT_OK(client_->DeleteDirectory(path));
+ return Status::OK();
+ }
+
+ Status DeleteDirContents(const std::string& path) {
+ if (!IsDirectory(path)) {
+ return Status::IOError("Cannot delete contents of directory '", path,
+ "': not a directory");
+ }
+ std::vector<std::string> file_list;
+ RETURN_NOT_OK(client_->GetChildren(path, &file_list));
+ for (auto file : file_list) {
+ RETURN_NOT_OK(client_->Delete(file, /*recursive=*/true));
+ }
+ return Status::OK();
+ }
+
+ Status DeleteFile(const std::string& path) {
+ if (IsDirectory(path)) {
+ return Status::IOError("path is a directory");
+ }
+ RETURN_NOT_OK(client_->Delete(path));
+ return Status::OK();
+ }
+
+ Status Move(const std::string& src, const std::string& dest) {
+ auto st = client_->Rename(src, dest);
+ if (st.IsIOError() && IsFile(src) && IsFile(dest)) {
+ // Allow file -> file clobber
+ RETURN_NOT_OK(client_->Delete(dest));
+ st = client_->Rename(src, dest);
+ }
+ return st;
+ }
+
+ Status CopyFile(const std::string& src, const std::string& dest) {
+ return client_->Copy(src, dest);
+ }
+
+ Result<std::shared_ptr<io::InputStream>> OpenInputStream(const std::string& path) {
+ std::shared_ptr<io::HdfsReadableFile> file;
+ RETURN_NOT_OK(client_->OpenReadable(path, io_context_, &file));
+ return file;
+ }
+
+ Result<std::shared_ptr<io::RandomAccessFile>> OpenInputFile(const std::string& path) {
+ std::shared_ptr<io::HdfsReadableFile> file;
+ RETURN_NOT_OK(client_->OpenReadable(path, io_context_, &file));
+ return file;
+ }
+
+ Result<std::shared_ptr<io::OutputStream>> OpenOutputStream(const std::string& path) {
+ bool append = false;
+ return OpenOutputStreamGeneric(path, append);
+ }
+
+ Result<std::shared_ptr<io::OutputStream>> OpenAppendStream(const std::string& path) {
+ bool append = true;
+ return OpenOutputStreamGeneric(path, append);
+ }
+
+ protected:
+ const HdfsOptions options_;
+ const io::IOContext io_context_;
+ std::shared_ptr<::arrow::io::HadoopFileSystem> client_;
+
+ void PathInfoToFileInfo(const io::HdfsPathInfo& info, FileInfo* out) {
+ if (info.kind == io::ObjectType::DIRECTORY) {
+ out->set_type(FileType::Directory);
+ out->set_size(kNoSize);
+ } else if (info.kind == io::ObjectType::FILE) {
+ out->set_type(FileType::File);
+ out->set_size(info.size);
+ }
+ out->set_mtime(ToTimePoint(info.last_modified_time));
+ }
+
+ Result<std::shared_ptr<io::OutputStream>> OpenOutputStreamGeneric(
+ const std::string& path, bool append) {
+ std::shared_ptr<io::HdfsOutputStream> stream;
+ RETURN_NOT_OK(client_->OpenWritable(path, append, options_.buffer_size,
+ options_.replication, options_.default_block_size,
+ &stream));
+ return stream;
+ }
+
+ bool IsDirectory(const std::string& path) {
+ io::HdfsPathInfo info;
+ return GetPathInfo(path, &info) && info.kind == io::ObjectType::DIRECTORY;
+ }
+
+ bool IsFile(const std::string& path) {
+ io::HdfsPathInfo info;
+ return GetPathInfo(path, &info) && info.kind == io::ObjectType::FILE;
+ }
+
+ bool GetPathInfo(const std::string& path, io::HdfsPathInfo* info) {
+ return client_->GetPathInfo(path, info).ok();
+ }
+
+ TimePoint ToTimePoint(int secs) {
+ std::chrono::nanoseconds ns_count(static_cast<int64_t>(secs) * 1000000000);
+ return TimePoint(std::chrono::duration_cast<TimePoint::duration>(ns_count));
+ }
+};
+
+void HdfsOptions::ConfigureEndPoint(std::string host, int port) {
+ connection_config.host = std::move(host);
+ connection_config.port = port;
+}
+
+void HdfsOptions::ConfigureUser(std::string user_name) {
+ connection_config.user = std::move(user_name);
+}
+
+void HdfsOptions::ConfigureKerberosTicketCachePath(std::string path) {
+ connection_config.kerb_ticket = std::move(path);
+}
+
+void HdfsOptions::ConfigureReplication(int16_t replication) {
+ this->replication = replication;
+}
+
+void HdfsOptions::ConfigureBufferSize(int32_t buffer_size) {
+ this->buffer_size = buffer_size;
+}
+
+void HdfsOptions::ConfigureBlockSize(int64_t default_block_size) {
+ this->default_block_size = default_block_size;
+}
+
+void HdfsOptions::ConfigureExtraConf(std::string key, std::string val) {
+ connection_config.extra_conf.emplace(std::move(key), std::move(val));
+}
+
+bool HdfsOptions::Equals(const HdfsOptions& other) const {
+ return (buffer_size == other.buffer_size && replication == other.replication &&
+ default_block_size == other.default_block_size &&
+ connection_config.host == other.connection_config.host &&
+ connection_config.port == other.connection_config.port &&
+ connection_config.user == other.connection_config.user &&
+ connection_config.kerb_ticket == other.connection_config.kerb_ticket &&
+ connection_config.extra_conf == other.connection_config.extra_conf);
+}
+
+Result<HdfsOptions> HdfsOptions::FromUri(const Uri& uri) {
+ HdfsOptions options;
+
+ std::unordered_map<std::string, std::string> options_map;
+ ARROW_ASSIGN_OR_RAISE(const auto options_items, uri.query_items());
+ for (const auto& kv : options_items) {
+ options_map.emplace(kv.first, kv.second);
+ }
+
+ std::string host;
+ host = uri.scheme() + "://" + uri.host();
+
+ // configure endpoint
+ const auto port = uri.port();
+ if (port == -1) {
+ // default port will be determined by hdfs FileSystem impl
+ options.ConfigureEndPoint(host, 0);
+ } else {
+ options.ConfigureEndPoint(host, port);
+ }
+
+ // configure replication
+ auto it = options_map.find("replication");
+ if (it != options_map.end()) {
+ const auto& v = it->second;
+ int16_t replication;
+ if (!ParseValue<Int16Type>(v.data(), v.size(), &replication)) {
+ return Status::Invalid("Invalid value for option 'replication': '", v, "'");
+ }
+ options.ConfigureReplication(replication);
+ options_map.erase(it);
+ }
+
+ // configure buffer_size
+ it = options_map.find("buffer_size");
+ if (it != options_map.end()) {
+ const auto& v = it->second;
+ int32_t buffer_size;
+ if (!ParseValue<Int32Type>(v.data(), v.size(), &buffer_size)) {
+ return Status::Invalid("Invalid value for option 'buffer_size': '", v, "'");
+ }
+ options.ConfigureBufferSize(buffer_size);
+ options_map.erase(it);
+ }
+
+ // configure default_block_size
+ it = options_map.find("default_block_size");
+ if (it != options_map.end()) {
+ const auto& v = it->second;
+ int64_t default_block_size;
+ if (!ParseValue<Int64Type>(v.data(), v.size(), &default_block_size)) {
+ return Status::Invalid("Invalid value for option 'default_block_size': '", v, "'");
+ }
+ options.ConfigureBlockSize(default_block_size);
+ options_map.erase(it);
+ }
+
+ // configure user
+ it = options_map.find("user");
+ if (it != options_map.end()) {
+ const auto& user = it->second;
+ options.ConfigureUser(user);
+ options_map.erase(it);
+ }
+
+ // configure kerberos
+ it = options_map.find("kerb_ticket");
+ if (it != options_map.end()) {
+ const auto& ticket = it->second;
+ options.ConfigureKerberosTicketCachePath(ticket);
+ options_map.erase(it);
+ }
+
+ // configure other options
+ for (const auto& it : options_map) {
+ options.ConfigureExtraConf(it.first, it.second);
+ }
+
+ return options;
+}
+
+Result<HdfsOptions> HdfsOptions::FromUri(const std::string& uri_string) {
+ Uri uri;
+ RETURN_NOT_OK(uri.Parse(uri_string));
+ return FromUri(uri);
+}
+
+HadoopFileSystem::HadoopFileSystem(const HdfsOptions& options,
+ const io::IOContext& io_context)
+ : FileSystem(io_context), impl_(new Impl{options, io_context_}) {
+ default_async_is_sync_ = false;
+}
+
+HadoopFileSystem::~HadoopFileSystem() {}
+
+Result<std::shared_ptr<HadoopFileSystem>> HadoopFileSystem::Make(
+ const HdfsOptions& options, const io::IOContext& io_context) {
+ std::shared_ptr<HadoopFileSystem> ptr(new HadoopFileSystem(options, io_context));
+ RETURN_NOT_OK(ptr->impl_->Init());
+ return ptr;
+}
+
+Result<FileInfo> HadoopFileSystem::GetFileInfo(const std::string& path) {
+ return impl_->GetFileInfo(path);
+}
+
+HdfsOptions HadoopFileSystem::options() const { return impl_->options(); }
+
+bool HadoopFileSystem::Equals(const FileSystem& other) const {
+ if (this == &other) {
+ return true;
+ }
+ if (other.type_name() != type_name()) {
+ return false;
+ }
+ const auto& hdfs = ::arrow::internal::checked_cast<const HadoopFileSystem&>(other);
+ return options().Equals(hdfs.options());
+}
+
+Result<std::vector<FileInfo>> HadoopFileSystem::GetFileInfo(const FileSelector& select) {
+ return impl_->GetFileInfo(select);
+}
+
+Status HadoopFileSystem::CreateDir(const std::string& path, bool recursive) {
+ return impl_->CreateDir(path, recursive);
+}
+
+Status HadoopFileSystem::DeleteDir(const std::string& path) {
+ return impl_->DeleteDir(path);
+}
+
+Status HadoopFileSystem::DeleteDirContents(const std::string& path) {
+ if (internal::IsEmptyPath(path)) {
+ return internal::InvalidDeleteDirContents(path);
+ }
+ return impl_->DeleteDirContents(path);
+}
+
+Status HadoopFileSystem::DeleteRootDirContents() { return impl_->DeleteDirContents(""); }
+
+Status HadoopFileSystem::DeleteFile(const std::string& path) {
+ return impl_->DeleteFile(path);
+}
+
+Status HadoopFileSystem::Move(const std::string& src, const std::string& dest) {
+ return impl_->Move(src, dest);
+}
+
+Status HadoopFileSystem::CopyFile(const std::string& src, const std::string& dest) {
+ return impl_->CopyFile(src, dest);
+}
+
+Result<std::shared_ptr<io::InputStream>> HadoopFileSystem::OpenInputStream(
+ const std::string& path) {
+ return impl_->OpenInputStream(path);
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> HadoopFileSystem::OpenInputFile(
+ const std::string& path) {
+ return impl_->OpenInputFile(path);
+}
+
+Result<std::shared_ptr<io::OutputStream>> HadoopFileSystem::OpenOutputStream(
+ const std::string& path, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ return impl_->OpenOutputStream(path);
+}
+
+Result<std::shared_ptr<io::OutputStream>> HadoopFileSystem::OpenAppendStream(
+ const std::string& path, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ return impl_->OpenAppendStream(path);
+}
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/hdfs.h b/src/arrow/cpp/src/arrow/filesystem/hdfs.h
new file mode 100644
index 000000000..bc72e1cdc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/hdfs.h
@@ -0,0 +1,113 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/filesystem/filesystem.h"
+#include "arrow/io/hdfs.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+namespace fs {
+
+/// Options for the HDFS implementation.
+struct ARROW_EXPORT HdfsOptions {
+ HdfsOptions() = default;
+ ~HdfsOptions() = default;
+
+ /// Hdfs configuration options, contains host, port, driver
+ io::HdfsConnectionConfig connection_config;
+
+ /// Used by Hdfs OpenWritable Interface.
+ int32_t buffer_size = 0;
+ int16_t replication = 3;
+ int64_t default_block_size = 0;
+
+ void ConfigureEndPoint(std::string host, int port);
+ void ConfigureReplication(int16_t replication);
+ void ConfigureUser(std::string user_name);
+ void ConfigureBufferSize(int32_t buffer_size);
+ void ConfigureBlockSize(int64_t default_block_size);
+ void ConfigureKerberosTicketCachePath(std::string path);
+ void ConfigureExtraConf(std::string key, std::string val);
+
+ bool Equals(const HdfsOptions& other) const;
+
+ static Result<HdfsOptions> FromUri(const ::arrow::internal::Uri& uri);
+ static Result<HdfsOptions> FromUri(const std::string& uri);
+};
+
+/// HDFS-backed FileSystem implementation.
+///
+/// implementation notes:
+/// - This is a wrapper of arrow/io/hdfs, so we can use FileSystem API to handle hdfs.
+class ARROW_EXPORT HadoopFileSystem : public FileSystem {
+ public:
+ ~HadoopFileSystem() override;
+
+ std::string type_name() const override { return "hdfs"; }
+ HdfsOptions options() const;
+ bool Equals(const FileSystem& other) const override;
+
+ /// \cond FALSE
+ using FileSystem::GetFileInfo;
+ /// \endcond
+ Result<FileInfo> GetFileInfo(const std::string& path) override;
+ Result<std::vector<FileInfo>> GetFileInfo(const FileSelector& select) override;
+
+ Status CreateDir(const std::string& path, bool recursive = true) override;
+
+ Status DeleteDir(const std::string& path) override;
+
+ Status DeleteDirContents(const std::string& path) override;
+
+ Status DeleteRootDirContents() override;
+
+ Status DeleteFile(const std::string& path) override;
+
+ Status Move(const std::string& src, const std::string& dest) override;
+
+ Status CopyFile(const std::string& src, const std::string& dest) override;
+
+ Result<std::shared_ptr<io::InputStream>> OpenInputStream(
+ const std::string& path) override;
+ Result<std::shared_ptr<io::RandomAccessFile>> OpenInputFile(
+ const std::string& path) override;
+ Result<std::shared_ptr<io::OutputStream>> OpenOutputStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = {}) override;
+ Result<std::shared_ptr<io::OutputStream>> OpenAppendStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = {}) override;
+
+ /// Create a HdfsFileSystem instance from the given options.
+ static Result<std::shared_ptr<HadoopFileSystem>> Make(
+ const HdfsOptions& options, const io::IOContext& = io::default_io_context());
+
+ protected:
+ HadoopFileSystem(const HdfsOptions& options, const io::IOContext&);
+
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/hdfs_test.cc b/src/arrow/cpp/src/arrow/filesystem/hdfs_test.cc
new file mode 100644
index 000000000..bd670c4b9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/hdfs_test.cc
@@ -0,0 +1,356 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/filesystem/hdfs.h"
+
+#include <gtest/gtest.h>
+
+#include <chrono>
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include "arrow/filesystem/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+
+using internal::Uri;
+
+namespace fs {
+
+TEST(TestHdfsOptions, FromUri) {
+ HdfsOptions options;
+ Uri uri;
+
+ ASSERT_OK(uri.Parse("hdfs://localhost"));
+ ASSERT_OK_AND_ASSIGN(options, HdfsOptions::FromUri(uri));
+ ASSERT_EQ(options.replication, 3);
+ ASSERT_EQ(options.connection_config.host, "hdfs://localhost");
+ ASSERT_EQ(options.connection_config.port, 0);
+ ASSERT_EQ(options.connection_config.user, "");
+
+ ASSERT_OK(uri.Parse("hdfs://otherhost:9999/?replication=2&kerb_ticket=kerb.ticket"));
+ ASSERT_OK_AND_ASSIGN(options, HdfsOptions::FromUri(uri));
+ ASSERT_EQ(options.replication, 2);
+ ASSERT_EQ(options.connection_config.kerb_ticket, "kerb.ticket");
+ ASSERT_EQ(options.connection_config.host, "hdfs://otherhost");
+ ASSERT_EQ(options.connection_config.port, 9999);
+ ASSERT_EQ(options.connection_config.user, "");
+
+ ASSERT_OK(uri.Parse("hdfs://otherhost:9999/?hdfs_token=hdfs_token_ticket"));
+ ASSERT_OK_AND_ASSIGN(options, HdfsOptions::FromUri(uri));
+ ASSERT_EQ(options.connection_config.host, "hdfs://otherhost");
+ ASSERT_EQ(options.connection_config.port, 9999);
+ ASSERT_EQ(options.connection_config.extra_conf["hdfs_token"], "hdfs_token_ticket");
+
+ ASSERT_OK(uri.Parse("viewfs://other-nn/mypath/myfile"));
+ ASSERT_OK_AND_ASSIGN(options, HdfsOptions::FromUri(uri));
+ ASSERT_EQ(options.connection_config.host, "viewfs://other-nn");
+ ASSERT_EQ(options.connection_config.port, 0);
+ ASSERT_EQ(options.connection_config.user, "");
+}
+
+class HadoopFileSystemTestMixin {
+ public:
+ void MakeFileSystem() {
+ const char* host = std::getenv("ARROW_HDFS_TEST_HOST");
+ const char* port = std::getenv("ARROW_HDFS_TEST_PORT");
+ const char* user = std::getenv("ARROW_HDFS_TEST_USER");
+
+ std::string hdfs_host = host == nullptr ? "localhost" : std::string(host);
+ int hdfs_port = port == nullptr ? 20500 : atoi(port);
+ std::string hdfs_user = user == nullptr ? "root" : std::string(user);
+
+ options_.ConfigureEndPoint(hdfs_host, hdfs_port);
+ options_.ConfigureUser(hdfs_user);
+ options_.ConfigureReplication(0);
+
+ auto result = HadoopFileSystem::Make(options_);
+ if (!result.ok()) {
+ ARROW_LOG(INFO)
+ << "HadoopFileSystem::Make failed, it is possible when we don't have "
+ "proper driver on this node, err msg is "
+ << result.status().ToString();
+ loaded_driver_ = false;
+ return;
+ }
+ loaded_driver_ = true;
+ fs_ = *result;
+ }
+
+ protected:
+ HdfsOptions options_;
+ bool loaded_driver_ = false;
+ std::shared_ptr<FileSystem> fs_;
+};
+
+class TestHadoopFileSystem : public ::testing::Test, public HadoopFileSystemTestMixin {
+ public:
+ void SetUp() override { MakeFileSystem(); }
+
+ void TestFileSystemFromUri() {
+ std::stringstream ss;
+ ss << "hdfs://" << options_.connection_config.host << ":"
+ << options_.connection_config.port << "/"
+ << "?replication=0&user=" << options_.connection_config.user;
+
+ std::shared_ptr<FileSystem> uri_fs;
+ std::string path;
+ ARROW_LOG(INFO) << "!!! uri = " << ss.str();
+ ASSERT_OK_AND_ASSIGN(uri_fs, FileSystemFromUri(ss.str(), &path));
+ ASSERT_EQ(path, "/");
+
+ // Sanity check
+ ASSERT_OK(uri_fs->CreateDir("AB"));
+ AssertFileInfo(uri_fs.get(), "AB", FileType::Directory);
+ ASSERT_OK(uri_fs->DeleteDir("AB"));
+ AssertFileInfo(uri_fs.get(), "AB", FileType::NotFound);
+ }
+
+ void TestGetFileInfo(const std::string& base_dir) {
+ std::vector<FileInfo> infos;
+
+ ASSERT_OK(fs_->CreateDir(base_dir + "AB"));
+ ASSERT_OK(fs_->CreateDir(base_dir + "AB/CD"));
+ ASSERT_OK(fs_->CreateDir(base_dir + "AB/EF"));
+ ASSERT_OK(fs_->CreateDir(base_dir + "AB/EF/GH"));
+ ASSERT_OK(fs_->CreateDir(base_dir + "AB/EF/GH/IJ"));
+ CreateFile(fs_.get(), base_dir + "AB/data", "some data");
+
+ // With single path
+ FileInfo info;
+ ASSERT_OK_AND_ASSIGN(info, fs_->GetFileInfo(base_dir + "AB"));
+ AssertFileInfo(info, base_dir + "AB", FileType::Directory);
+ ASSERT_OK_AND_ASSIGN(info, fs_->GetFileInfo(base_dir + "AB/data"));
+ AssertFileInfo(info, base_dir + "AB/data", FileType::File, 9);
+
+ // With selector
+ FileSelector selector;
+ selector.base_dir = base_dir + "AB";
+ selector.recursive = false;
+
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(selector));
+ ASSERT_EQ(infos.size(), 3);
+ AssertFileInfo(infos[0], base_dir + "AB/CD", FileType::Directory);
+ AssertFileInfo(infos[1], base_dir + "AB/EF", FileType::Directory);
+ AssertFileInfo(infos[2], base_dir + "AB/data", FileType::File);
+
+ selector.recursive = true;
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(selector));
+ ASSERT_EQ(infos.size(), 5);
+ AssertFileInfo(infos[0], base_dir + "AB/CD", FileType::Directory);
+ AssertFileInfo(infos[1], base_dir + "AB/EF", FileType::Directory);
+ AssertFileInfo(infos[2], base_dir + "AB/EF/GH", FileType::Directory);
+ AssertFileInfo(infos[3], base_dir + "AB/EF/GH/IJ", FileType::Directory);
+ AssertFileInfo(infos[4], base_dir + "AB/data", FileType::File, 9);
+
+ selector.max_recursion = 0;
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(selector));
+ ASSERT_EQ(infos.size(), 3);
+ AssertFileInfo(infos[0], base_dir + "AB/CD", FileType::Directory);
+ AssertFileInfo(infos[1], base_dir + "AB/EF", FileType::Directory);
+ AssertFileInfo(infos[2], base_dir + "AB/data", FileType::File);
+
+ selector.max_recursion = 1;
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(selector));
+ ASSERT_EQ(infos.size(), 4);
+ AssertFileInfo(infos[0], base_dir + "AB/CD", FileType::Directory);
+ AssertFileInfo(infos[1], base_dir + "AB/EF", FileType::Directory);
+ AssertFileInfo(infos[2], base_dir + "AB/EF/GH", FileType::Directory);
+ AssertFileInfo(infos[3], base_dir + "AB/data", FileType::File);
+
+ selector.base_dir = base_dir + "XYZ";
+ selector.allow_not_found = true;
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(selector));
+ ASSERT_EQ(infos.size(), 0);
+
+ selector.allow_not_found = false;
+ ASSERT_RAISES(IOError, fs_->GetFileInfo(selector));
+
+ ASSERT_OK(fs_->DeleteDir(base_dir + "AB"));
+ AssertFileInfo(fs_.get(), base_dir + "AB", FileType::NotFound);
+ }
+};
+
+#define SKIP_IF_NO_DRIVER() \
+ if (!this->loaded_driver_) { \
+ GTEST_SKIP() << "Driver not loaded, skipping"; \
+ }
+
+TEST_F(TestHadoopFileSystem, CreateDirDeleteDir) {
+ SKIP_IF_NO_DRIVER();
+
+ // recursive = true
+ ASSERT_OK(this->fs_->CreateDir("AB/CD"));
+ CreateFile(this->fs_.get(), "AB/CD/data", "some data");
+ AssertFileInfo(this->fs_.get(), "AB", FileType::Directory);
+ AssertFileInfo(this->fs_.get(), "AB/CD", FileType::Directory);
+ AssertFileInfo(this->fs_.get(), "AB/CD/data", FileType::File, 9);
+
+ ASSERT_OK(this->fs_->DeleteDir("AB"));
+ AssertFileInfo(this->fs_.get(), "AB", FileType::NotFound);
+
+ // recursive = false
+ ASSERT_RAISES(IOError, this->fs_->CreateDir("AB/CD", /*recursive=*/false));
+ ASSERT_OK(this->fs_->CreateDir("AB", /*recursive=*/false));
+ ASSERT_OK(this->fs_->CreateDir("AB/CD", /*recursive=*/false));
+
+ ASSERT_OK(this->fs_->DeleteDir("AB"));
+ AssertFileInfo(this->fs_.get(), "AB", FileType::NotFound);
+ ASSERT_RAISES(IOError, this->fs_->DeleteDir("AB"));
+}
+
+TEST_F(TestHadoopFileSystem, DeleteDirContents) {
+ SKIP_IF_NO_DRIVER();
+
+ ASSERT_OK(this->fs_->CreateDir("AB/CD"));
+ CreateFile(this->fs_.get(), "AB/CD/data", "some data");
+ AssertFileInfo(this->fs_.get(), "AB", FileType::Directory);
+ AssertFileInfo(this->fs_.get(), "AB/CD", FileType::Directory);
+ AssertFileInfo(this->fs_.get(), "AB/CD/data", FileType::File, 9);
+
+ ASSERT_OK(this->fs_->DeleteDirContents("AB"));
+ AssertFileInfo(this->fs_.get(), "AB", FileType::Directory);
+ AssertFileInfo(this->fs_.get(), "AB/CD", FileType::NotFound);
+ AssertFileInfo(this->fs_.get(), "AB/CD/data", FileType::NotFound);
+
+ ASSERT_OK(this->fs_->DeleteDirContents("AB"));
+ AssertFileInfo(this->fs_.get(), "AB", FileType::Directory);
+ ASSERT_OK(this->fs_->DeleteDir("AB"));
+}
+
+TEST_F(TestHadoopFileSystem, WriteReadFile) {
+ SKIP_IF_NO_DRIVER();
+
+ ASSERT_OK(this->fs_->CreateDir("CD"));
+ constexpr int kDataSize = 9;
+ std::string file_name = "CD/abc";
+ std::string data = "some data";
+ std::shared_ptr<io::OutputStream> stream;
+ ASSERT_OK_AND_ASSIGN(stream, this->fs_->OpenOutputStream(file_name));
+ auto data_size = static_cast<int64_t>(data.size());
+ ASSERT_OK(stream->Write(data.data(), data_size));
+ ASSERT_OK(stream->Close());
+
+ std::shared_ptr<io::RandomAccessFile> file;
+ ASSERT_OK_AND_ASSIGN(file, this->fs_->OpenInputFile(file_name));
+ ASSERT_OK_AND_EQ(kDataSize, file->GetSize());
+ uint8_t buffer[kDataSize];
+ ASSERT_OK_AND_EQ(kDataSize, file->Read(kDataSize, buffer));
+ ASSERT_EQ(0, std::memcmp(buffer, data.c_str(), kDataSize));
+
+ ASSERT_OK(this->fs_->DeleteDir("CD"));
+}
+
+TEST_F(TestHadoopFileSystem, GetFileInfoRelative) {
+ // Test GetFileInfo() with relative paths
+ SKIP_IF_NO_DRIVER();
+
+ this->TestGetFileInfo("");
+}
+
+TEST_F(TestHadoopFileSystem, GetFileInfoAbsolute) {
+ // Test GetFileInfo() with absolute paths
+ SKIP_IF_NO_DRIVER();
+
+ this->TestGetFileInfo("/");
+}
+
+TEST_F(TestHadoopFileSystem, RelativeVsAbsolutePaths) {
+ SKIP_IF_NO_DRIVER();
+
+ // XXX This test assumes the current working directory is not "/"
+
+ ASSERT_OK(this->fs_->CreateDir("AB"));
+ AssertFileInfo(this->fs_.get(), "AB", FileType::Directory);
+ AssertFileInfo(this->fs_.get(), "/AB", FileType::NotFound);
+
+ ASSERT_OK(this->fs_->CreateDir("/CD"));
+ AssertFileInfo(this->fs_.get(), "/CD", FileType::Directory);
+ AssertFileInfo(this->fs_.get(), "CD", FileType::NotFound);
+}
+
+TEST_F(TestHadoopFileSystem, MoveDir) {
+ SKIP_IF_NO_DRIVER();
+
+ FileInfo info;
+ std::string directory_name_src = "AB";
+ std::string directory_name_dest = "CD";
+ ASSERT_OK(this->fs_->CreateDir(directory_name_src));
+ ASSERT_OK_AND_ASSIGN(info, this->fs_->GetFileInfo(directory_name_src));
+ AssertFileInfo(info, directory_name_src, FileType::Directory);
+
+ // move file
+ ASSERT_OK(this->fs_->Move(directory_name_src, directory_name_dest));
+ ASSERT_OK_AND_ASSIGN(info, this->fs_->GetFileInfo(directory_name_src));
+ ASSERT_TRUE(info.type() == FileType::NotFound);
+
+ ASSERT_OK_AND_ASSIGN(info, this->fs_->GetFileInfo(directory_name_dest));
+ AssertFileInfo(info, directory_name_dest, FileType::Directory);
+ ASSERT_OK(this->fs_->DeleteDir(directory_name_dest));
+}
+
+TEST_F(TestHadoopFileSystem, FileSystemFromUri) {
+ SKIP_IF_NO_DRIVER();
+
+ this->TestFileSystemFromUri();
+}
+
+class TestHadoopFileSystemGeneric : public ::testing::Test,
+ public HadoopFileSystemTestMixin,
+ public GenericFileSystemTest {
+ public:
+ void SetUp() override {
+ MakeFileSystem();
+ SKIP_IF_NO_DRIVER();
+ timestamp_ =
+ static_cast<int64_t>(std::chrono::time_point_cast<std::chrono::nanoseconds>(
+ std::chrono::steady_clock::now())
+ .time_since_epoch()
+ .count());
+ }
+
+ protected:
+ bool allow_write_file_over_dir() const override { return true; }
+ bool allow_move_dir_over_non_empty_dir() const override { return true; }
+ bool have_implicit_directories() const override { return true; }
+ bool allow_append_to_new_file() const override { return false; }
+
+ std::shared_ptr<FileSystem> GetEmptyFileSystem() override {
+ // Since the HDFS contents are kept persistently between test runs,
+ // make sure each test gets a pristine fresh directory.
+ std::stringstream ss;
+ ss << "GenericTest" << timestamp_ << "-" << test_num_++;
+ const auto subdir = ss.str();
+ ARROW_EXPECT_OK(fs_->CreateDir(subdir));
+ return std::make_shared<SubTreeFileSystem>(subdir, fs_);
+ }
+
+ static int test_num_;
+ int64_t timestamp_;
+};
+
+int TestHadoopFileSystemGeneric::test_num_ = 1;
+
+GENERIC_FS_TEST_FUNCTIONS(TestHadoopFileSystemGeneric);
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/localfs.cc b/src/arrow/cpp/src/arrow/filesystem/localfs.cc
new file mode 100644
index 000000000..775fd746a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/localfs.cc
@@ -0,0 +1,448 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <chrono>
+#include <cstring>
+#include <sstream>
+#include <utility>
+
+#ifdef _WIN32
+#include "arrow/util/windows_compatibility.h"
+#else
+#include <errno.h>
+#include <fcntl.h>
+#include <stdio.h>
+#include <sys/stat.h>
+#endif
+
+#include "arrow/filesystem/localfs.h"
+#include "arrow/filesystem/path_util.h"
+#include "arrow/filesystem/util_internal.h"
+#include "arrow/io/file.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/uri.h"
+#include "arrow/util/windows_fixup.h"
+
+namespace arrow {
+namespace fs {
+
+using ::arrow::internal::IOErrorFromErrno;
+#ifdef _WIN32
+using ::arrow::internal::IOErrorFromWinError;
+#endif
+using ::arrow::internal::NativePathString;
+using ::arrow::internal::PlatformFilename;
+
+namespace internal {
+
+#ifdef _WIN32
+static bool IsDriveLetter(char c) {
+ // Can't use locale-dependent functions from the C/C++ stdlib
+ return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
+}
+#endif
+
+bool DetectAbsolutePath(const std::string& s) {
+ // Is it a /-prefixed local path?
+ if (s.length() >= 1 && s[0] == '/') {
+ return true;
+ }
+#ifdef _WIN32
+ // Is it a \-prefixed local path?
+ if (s.length() >= 1 && s[0] == '\\') {
+ return true;
+ }
+ // Does it start with a drive letter in addition to being /- or \-prefixed,
+ // e.g. "C:\..."?
+ if (s.length() >= 3 && s[1] == ':' && (s[2] == '/' || s[2] == '\\') &&
+ IsDriveLetter(s[0])) {
+ return true;
+ }
+#endif
+ return false;
+}
+
+} // namespace internal
+
+namespace {
+
+#ifdef _WIN32
+
+std::string NativeToString(const NativePathString& ns) {
+ PlatformFilename fn(ns);
+ return fn.ToString();
+}
+
+TimePoint ToTimePoint(FILETIME ft) {
+ // Hundreds of nanoseconds between January 1, 1601 (UTC) and the Unix epoch.
+ static constexpr int64_t kFileTimeEpoch = 11644473600LL * 10000000;
+
+ int64_t hundreds = (static_cast<int64_t>(ft.dwHighDateTime) << 32) + ft.dwLowDateTime -
+ kFileTimeEpoch; // hundreds of ns since Unix epoch
+ std::chrono::nanoseconds ns_count(100 * hundreds);
+ return TimePoint(std::chrono::duration_cast<TimePoint::duration>(ns_count));
+}
+
+FileInfo FileInformationToFileInfo(const BY_HANDLE_FILE_INFORMATION& information) {
+ FileInfo info;
+ if (information.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) {
+ info.set_type(FileType::Directory);
+ info.set_size(kNoSize);
+ } else {
+ // Regular file
+ info.set_type(FileType::File);
+ info.set_size((static_cast<int64_t>(information.nFileSizeHigh) << 32) +
+ information.nFileSizeLow);
+ }
+ info.set_mtime(ToTimePoint(information.ftLastWriteTime));
+ return info;
+}
+
+Result<FileInfo> StatFile(const std::wstring& path) {
+ HANDLE h;
+ std::string bytes_path = NativeToString(path);
+ FileInfo info;
+
+ /* Inspired by CPython, see Modules/posixmodule.c */
+ h = CreateFileW(path.c_str(), FILE_READ_ATTRIBUTES, /* desired access */
+ 0, /* share mode */
+ NULL, /* security attributes */
+ OPEN_EXISTING,
+ /* FILE_FLAG_BACKUP_SEMANTICS is required to open a directory */
+ FILE_ATTRIBUTE_NORMAL | FILE_FLAG_BACKUP_SEMANTICS, NULL);
+
+ if (h == INVALID_HANDLE_VALUE) {
+ DWORD err = GetLastError();
+ if (err == ERROR_FILE_NOT_FOUND || err == ERROR_PATH_NOT_FOUND) {
+ info.set_path(bytes_path);
+ info.set_type(FileType::NotFound);
+ info.set_mtime(kNoTime);
+ info.set_size(kNoSize);
+ return info;
+ } else {
+ return IOErrorFromWinError(GetLastError(), "Failed querying information for path '",
+ bytes_path, "'");
+ }
+ }
+ BY_HANDLE_FILE_INFORMATION information;
+ if (!GetFileInformationByHandle(h, &information)) {
+ CloseHandle(h);
+ return IOErrorFromWinError(GetLastError(), "Failed querying information for path '",
+ bytes_path, "'");
+ }
+ CloseHandle(h);
+ info = FileInformationToFileInfo(information);
+ info.set_path(bytes_path);
+ return info;
+}
+
+#else // POSIX systems
+
+TimePoint ToTimePoint(const struct timespec& s) {
+ std::chrono::nanoseconds ns_count(static_cast<int64_t>(s.tv_sec) * 1000000000 +
+ static_cast<int64_t>(s.tv_nsec));
+ return TimePoint(std::chrono::duration_cast<TimePoint::duration>(ns_count));
+}
+
+FileInfo StatToFileInfo(const struct stat& s) {
+ FileInfo info;
+ if (S_ISREG(s.st_mode)) {
+ info.set_type(FileType::File);
+ info.set_size(static_cast<int64_t>(s.st_size));
+ } else if (S_ISDIR(s.st_mode)) {
+ info.set_type(FileType::Directory);
+ info.set_size(kNoSize);
+ } else {
+ info.set_type(FileType::Unknown);
+ info.set_size(kNoSize);
+ }
+#ifdef __APPLE__
+ // macOS doesn't use the POSIX-compliant spelling
+ info.set_mtime(ToTimePoint(s.st_mtimespec));
+#else
+ info.set_mtime(ToTimePoint(s.st_mtim));
+#endif
+ return info;
+}
+
+Result<FileInfo> StatFile(const std::string& path) {
+ FileInfo info;
+ struct stat s;
+ int r = stat(path.c_str(), &s);
+ if (r == -1) {
+ if (errno == ENOENT || errno == ENOTDIR || errno == ELOOP) {
+ info.set_type(FileType::NotFound);
+ info.set_mtime(kNoTime);
+ info.set_size(kNoSize);
+ } else {
+ return IOErrorFromErrno(errno, "Failed stat()ing path '", path, "'");
+ }
+ } else {
+ info = StatToFileInfo(s);
+ }
+ info.set_path(path);
+ return info;
+}
+
+#endif
+
+Status StatSelector(const PlatformFilename& dir_fn, const FileSelector& select,
+ int32_t nesting_depth, std::vector<FileInfo>* out) {
+ auto result = ListDir(dir_fn);
+ if (!result.ok()) {
+ auto status = result.status();
+ if (select.allow_not_found && status.IsIOError()) {
+ ARROW_ASSIGN_OR_RAISE(bool exists, FileExists(dir_fn));
+ if (!exists) {
+ return Status::OK();
+ }
+ }
+ return status;
+ }
+
+ for (const auto& child_fn : *result) {
+ PlatformFilename full_fn = dir_fn.Join(child_fn);
+ ARROW_ASSIGN_OR_RAISE(FileInfo info, StatFile(full_fn.ToNative()));
+ if (info.type() != FileType::NotFound) {
+ out->push_back(std::move(info));
+ }
+ if (nesting_depth < select.max_recursion && select.recursive &&
+ info.type() == FileType::Directory) {
+ RETURN_NOT_OK(StatSelector(full_fn, select, nesting_depth + 1, out));
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+LocalFileSystemOptions LocalFileSystemOptions::Defaults() {
+ return LocalFileSystemOptions();
+}
+
+bool LocalFileSystemOptions::Equals(const LocalFileSystemOptions& other) const {
+ return use_mmap == other.use_mmap;
+}
+
+Result<LocalFileSystemOptions> LocalFileSystemOptions::FromUri(
+ const ::arrow::internal::Uri& uri, std::string* out_path) {
+ if (!uri.username().empty() || !uri.password().empty()) {
+ return Status::Invalid("Unsupported username or password in local URI: '",
+ uri.ToString(), "'");
+ }
+ std::string path;
+ const auto host = uri.host();
+ if (!host.empty()) {
+#ifdef _WIN32
+ std::stringstream ss;
+ ss << "//" << host << "/" << internal::RemoveLeadingSlash(uri.path());
+ *out_path = ss.str();
+#else
+ return Status::Invalid("Unsupported hostname in non-Windows local URI: '",
+ uri.ToString(), "'");
+#endif
+ } else {
+ *out_path = uri.path();
+ }
+
+ // TODO handle use_mmap option
+ return LocalFileSystemOptions();
+}
+
+LocalFileSystem::LocalFileSystem(const io::IOContext& io_context)
+ : FileSystem(io_context), options_(LocalFileSystemOptions::Defaults()) {}
+
+LocalFileSystem::LocalFileSystem(const LocalFileSystemOptions& options,
+ const io::IOContext& io_context)
+ : FileSystem(io_context), options_(options) {}
+
+LocalFileSystem::~LocalFileSystem() {}
+
+Result<std::string> LocalFileSystem::NormalizePath(std::string path) {
+ ARROW_ASSIGN_OR_RAISE(auto fn, PlatformFilename::FromString(path));
+ return fn.ToString();
+}
+
+bool LocalFileSystem::Equals(const FileSystem& other) const {
+ if (other.type_name() != type_name()) {
+ return false;
+ } else {
+ const auto& localfs = ::arrow::internal::checked_cast<const LocalFileSystem&>(other);
+ return options_.Equals(localfs.options());
+ }
+}
+
+Result<FileInfo> LocalFileSystem::GetFileInfo(const std::string& path) {
+ ARROW_ASSIGN_OR_RAISE(auto fn, PlatformFilename::FromString(path));
+ return StatFile(fn.ToNative());
+}
+
+Result<std::vector<FileInfo>> LocalFileSystem::GetFileInfo(const FileSelector& select) {
+ ARROW_ASSIGN_OR_RAISE(auto fn, PlatformFilename::FromString(select.base_dir));
+ std::vector<FileInfo> results;
+ RETURN_NOT_OK(StatSelector(fn, select, 0, &results));
+ return results;
+}
+
+Status LocalFileSystem::CreateDir(const std::string& path, bool recursive) {
+ ARROW_ASSIGN_OR_RAISE(auto fn, PlatformFilename::FromString(path));
+ if (recursive) {
+ return ::arrow::internal::CreateDirTree(fn).status();
+ } else {
+ return ::arrow::internal::CreateDir(fn).status();
+ }
+}
+
+Status LocalFileSystem::DeleteDir(const std::string& path) {
+ ARROW_ASSIGN_OR_RAISE(auto fn, PlatformFilename::FromString(path));
+ auto st = ::arrow::internal::DeleteDirTree(fn, /*allow_not_found=*/false).status();
+ if (!st.ok()) {
+ // TODO Status::WithPrefix()?
+ std::stringstream ss;
+ ss << "Cannot delete directory '" << path << "': " << st.message();
+ return st.WithMessage(ss.str());
+ }
+ return Status::OK();
+}
+
+Status LocalFileSystem::DeleteDirContents(const std::string& path) {
+ if (internal::IsEmptyPath(path)) {
+ return internal::InvalidDeleteDirContents(path);
+ }
+ ARROW_ASSIGN_OR_RAISE(auto fn, PlatformFilename::FromString(path));
+ auto st = ::arrow::internal::DeleteDirContents(fn, /*allow_not_found=*/false).status();
+ if (!st.ok()) {
+ std::stringstream ss;
+ ss << "Cannot delete directory contents in '" << path << "': " << st.message();
+ return st.WithMessage(ss.str());
+ }
+ return Status::OK();
+}
+
+Status LocalFileSystem::DeleteRootDirContents() {
+ return Status::Invalid("LocalFileSystem::DeleteRootDirContents is strictly forbidden");
+}
+
+Status LocalFileSystem::DeleteFile(const std::string& path) {
+ ARROW_ASSIGN_OR_RAISE(auto fn, PlatformFilename::FromString(path));
+ return ::arrow::internal::DeleteFile(fn, /*allow_not_found=*/false).status();
+}
+
+Status LocalFileSystem::Move(const std::string& src, const std::string& dest) {
+ ARROW_ASSIGN_OR_RAISE(auto sfn, PlatformFilename::FromString(src));
+ ARROW_ASSIGN_OR_RAISE(auto dfn, PlatformFilename::FromString(dest));
+
+#ifdef _WIN32
+ if (!MoveFileExW(sfn.ToNative().c_str(), dfn.ToNative().c_str(),
+ MOVEFILE_REPLACE_EXISTING)) {
+ return IOErrorFromWinError(GetLastError(), "Failed renaming '", sfn.ToString(),
+ "' to '", dfn.ToString(), "'");
+ }
+#else
+ if (rename(sfn.ToNative().c_str(), dfn.ToNative().c_str()) == -1) {
+ return IOErrorFromErrno(errno, "Failed renaming '", sfn.ToString(), "' to '",
+ dfn.ToString(), "'");
+ }
+#endif
+ return Status::OK();
+}
+
+Status LocalFileSystem::CopyFile(const std::string& src, const std::string& dest) {
+ ARROW_ASSIGN_OR_RAISE(auto sfn, PlatformFilename::FromString(src));
+ ARROW_ASSIGN_OR_RAISE(auto dfn, PlatformFilename::FromString(dest));
+ // XXX should we use fstat() to compare inodes?
+ if (sfn.ToNative() == dfn.ToNative()) {
+ return Status::OK();
+ }
+
+#ifdef _WIN32
+ if (!CopyFileW(sfn.ToNative().c_str(), dfn.ToNative().c_str(),
+ FALSE /* bFailIfExists */)) {
+ return IOErrorFromWinError(GetLastError(), "Failed copying '", sfn.ToString(),
+ "' to '", dfn.ToString(), "'");
+ }
+ return Status::OK();
+#else
+ ARROW_ASSIGN_OR_RAISE(auto is, OpenInputStream(src));
+ ARROW_ASSIGN_OR_RAISE(auto os, OpenOutputStream(dest));
+ RETURN_NOT_OK(internal::CopyStream(is, os, 1024 * 1024 /* chunk_size */, io_context()));
+ RETURN_NOT_OK(os->Close());
+ return is->Close();
+#endif
+}
+
+namespace {
+
+template <typename InputStreamType>
+Result<std::shared_ptr<InputStreamType>> OpenInputStreamGeneric(
+ const std::string& path, const LocalFileSystemOptions& options,
+ const io::IOContext& io_context) {
+ if (options.use_mmap) {
+ return io::MemoryMappedFile::Open(path, io::FileMode::READ);
+ } else {
+ return io::ReadableFile::Open(path, io_context.pool());
+ }
+}
+
+} // namespace
+
+Result<std::shared_ptr<io::InputStream>> LocalFileSystem::OpenInputStream(
+ const std::string& path) {
+ return OpenInputStreamGeneric<io::InputStream>(path, options_, io_context());
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> LocalFileSystem::OpenInputFile(
+ const std::string& path) {
+ return OpenInputStreamGeneric<io::RandomAccessFile>(path, options_, io_context());
+}
+
+namespace {
+
+Result<std::shared_ptr<io::OutputStream>> OpenOutputStreamGeneric(const std::string& path,
+ bool truncate,
+ bool append) {
+ int fd;
+ bool write_only = true;
+ ARROW_ASSIGN_OR_RAISE(auto fn, PlatformFilename::FromString(path));
+ ARROW_ASSIGN_OR_RAISE(
+ fd, ::arrow::internal::FileOpenWritable(fn, write_only, truncate, append));
+ auto maybe_stream = io::FileOutputStream::Open(fd);
+ if (!maybe_stream.ok()) {
+ ARROW_UNUSED(::arrow::internal::FileClose(fd));
+ }
+ return maybe_stream;
+}
+
+} // namespace
+
+Result<std::shared_ptr<io::OutputStream>> LocalFileSystem::OpenOutputStream(
+ const std::string& path, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ bool truncate = true;
+ bool append = false;
+ return OpenOutputStreamGeneric(path, truncate, append);
+}
+
+Result<std::shared_ptr<io::OutputStream>> LocalFileSystem::OpenAppendStream(
+ const std::string& path, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ bool truncate = false;
+ bool append = true;
+ return OpenOutputStreamGeneric(path, truncate, append);
+}
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/localfs.h b/src/arrow/cpp/src/arrow/filesystem/localfs.h
new file mode 100644
index 000000000..f8e77aee5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/localfs.h
@@ -0,0 +1,113 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/filesystem/filesystem.h"
+
+namespace arrow {
+namespace internal {
+
+class Uri;
+
+}
+
+namespace fs {
+
+/// Options for the LocalFileSystem implementation.
+struct ARROW_EXPORT LocalFileSystemOptions {
+ /// Whether OpenInputStream and OpenInputFile return a mmap'ed file,
+ /// or a regular one.
+ bool use_mmap = false;
+
+ /// \brief Initialize with defaults
+ static LocalFileSystemOptions Defaults();
+
+ bool Equals(const LocalFileSystemOptions& other) const;
+
+ static Result<LocalFileSystemOptions> FromUri(const ::arrow::internal::Uri& uri,
+ std::string* out_path);
+};
+
+/// \brief A FileSystem implementation accessing files on the local machine.
+///
+/// This class handles only `/`-separated paths. If desired, conversion
+/// from Windows backslash-separated paths should be done by the caller.
+/// Details such as symlinks are abstracted away (symlinks are always
+/// followed, except when deleting an entry).
+class ARROW_EXPORT LocalFileSystem : public FileSystem {
+ public:
+ explicit LocalFileSystem(const io::IOContext& = io::default_io_context());
+ explicit LocalFileSystem(const LocalFileSystemOptions&,
+ const io::IOContext& = io::default_io_context());
+ ~LocalFileSystem() override;
+
+ std::string type_name() const override { return "local"; }
+
+ Result<std::string> NormalizePath(std::string path) override;
+
+ bool Equals(const FileSystem& other) const override;
+
+ LocalFileSystemOptions options() const { return options_; }
+
+ /// \cond FALSE
+ using FileSystem::GetFileInfo;
+ /// \endcond
+ Result<FileInfo> GetFileInfo(const std::string& path) override;
+ Result<std::vector<FileInfo>> GetFileInfo(const FileSelector& select) override;
+
+ Status CreateDir(const std::string& path, bool recursive = true) override;
+
+ Status DeleteDir(const std::string& path) override;
+ Status DeleteDirContents(const std::string& path) override;
+ Status DeleteRootDirContents() override;
+
+ Status DeleteFile(const std::string& path) override;
+
+ Status Move(const std::string& src, const std::string& dest) override;
+
+ Status CopyFile(const std::string& src, const std::string& dest) override;
+
+ Result<std::shared_ptr<io::InputStream>> OpenInputStream(
+ const std::string& path) override;
+ Result<std::shared_ptr<io::RandomAccessFile>> OpenInputFile(
+ const std::string& path) override;
+ Result<std::shared_ptr<io::OutputStream>> OpenOutputStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = {}) override;
+ Result<std::shared_ptr<io::OutputStream>> OpenAppendStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = {}) override;
+
+ protected:
+ LocalFileSystemOptions options_;
+};
+
+namespace internal {
+
+// Return whether the string is detected as a local absolute path.
+ARROW_EXPORT
+bool DetectAbsolutePath(const std::string& s);
+
+} // namespace internal
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/localfs_test.cc b/src/arrow/cpp/src/arrow/filesystem/localfs_test.cc
new file mode 100644
index 000000000..e33816095
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/localfs_test.cc
@@ -0,0 +1,396 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cerrno>
+#include <chrono>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/filesystem/filesystem.h"
+#include "arrow/filesystem/localfs.h"
+#include "arrow/filesystem/path_util.h"
+#include "arrow/filesystem/test_util.h"
+#include "arrow/filesystem/util_internal.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/io_util.h"
+
+namespace arrow {
+namespace fs {
+namespace internal {
+
+using ::arrow::internal::PlatformFilename;
+using ::arrow::internal::TemporaryDir;
+
+class LocalFSTestMixin : public ::testing::Test {
+ public:
+ void SetUp() override {
+ ASSERT_OK_AND_ASSIGN(temp_dir_, TemporaryDir::Make("test-localfs-"));
+ }
+
+ protected:
+ std::unique_ptr<TemporaryDir> temp_dir_;
+};
+
+struct CommonPathFormatter {
+ std::string operator()(std::string fn) { return fn; }
+ bool supports_uri() { return true; }
+};
+
+#ifdef _WIN32
+struct ExtendedLengthPathFormatter {
+ std::string operator()(std::string fn) { return "//?/" + fn; }
+ // The path prefix conflicts with URI syntax
+ bool supports_uri() { return false; }
+};
+
+using PathFormatters = ::testing::Types<CommonPathFormatter, ExtendedLengthPathFormatter>;
+#else
+using PathFormatters = ::testing::Types<CommonPathFormatter>;
+#endif
+
+// Non-overloaded version of FileSystemFromUri, for template resolution
+// in CheckFileSystemFromUriFunc.
+Result<std::shared_ptr<FileSystem>> FSFromUri(const std::string& uri,
+ std::string* out_path = NULLPTR) {
+ return FileSystemFromUri(uri, out_path);
+}
+
+Result<std::shared_ptr<FileSystem>> FSFromUriOrPath(const std::string& uri,
+ std::string* out_path = NULLPTR) {
+ return FileSystemFromUriOrPath(uri, out_path);
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Misc tests
+
+TEST(DetectAbsolutePath, Basics) {
+ ASSERT_TRUE(DetectAbsolutePath("/"));
+ ASSERT_TRUE(DetectAbsolutePath("/foo"));
+ ASSERT_TRUE(DetectAbsolutePath("/foo/bar.txt"));
+ ASSERT_TRUE(DetectAbsolutePath("//foo/bar/baz"));
+
+#ifdef _WIN32
+ constexpr bool is_win32 = true;
+#else
+ constexpr bool is_win32 = false;
+#endif
+ ASSERT_EQ(is_win32, DetectAbsolutePath("A:/"));
+ ASSERT_EQ(is_win32, DetectAbsolutePath("z:/foo"));
+
+ ASSERT_EQ(is_win32, DetectAbsolutePath("\\"));
+ ASSERT_EQ(is_win32, DetectAbsolutePath("\\foo"));
+ ASSERT_EQ(is_win32, DetectAbsolutePath("\\foo\\bar"));
+ ASSERT_EQ(is_win32, DetectAbsolutePath("\\\\foo\\bar\\baz"));
+ ASSERT_EQ(is_win32, DetectAbsolutePath("Z:\\"));
+ ASSERT_EQ(is_win32, DetectAbsolutePath("z:\\foo"));
+
+ ASSERT_FALSE(DetectAbsolutePath("A:"));
+ ASSERT_FALSE(DetectAbsolutePath("z:foo"));
+ ASSERT_FALSE(DetectAbsolutePath(""));
+ ASSERT_FALSE(DetectAbsolutePath("AB:"));
+ ASSERT_FALSE(DetectAbsolutePath(":"));
+ ASSERT_FALSE(DetectAbsolutePath(""));
+ ASSERT_FALSE(DetectAbsolutePath("@:"));
+ ASSERT_FALSE(DetectAbsolutePath("à:"));
+ ASSERT_FALSE(DetectAbsolutePath("0:"));
+ ASSERT_FALSE(DetectAbsolutePath("A"));
+ ASSERT_FALSE(DetectAbsolutePath("foo/bar"));
+ ASSERT_FALSE(DetectAbsolutePath("foo\\bar"));
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Generic LocalFileSystem tests
+
+template <typename PathFormatter>
+class TestLocalFSGeneric : public LocalFSTestMixin, public GenericFileSystemTest {
+ public:
+ void SetUp() override {
+ LocalFSTestMixin::SetUp();
+ local_fs_ = std::make_shared<LocalFileSystem>(options());
+ auto path = PathFormatter()(temp_dir_->path().ToString());
+ fs_ = std::make_shared<SubTreeFileSystem>(path, local_fs_);
+ }
+
+ protected:
+ virtual LocalFileSystemOptions options() { return LocalFileSystemOptions::Defaults(); }
+
+ std::shared_ptr<FileSystem> GetEmptyFileSystem() override { return fs_; }
+
+ std::shared_ptr<LocalFileSystem> local_fs_;
+ std::shared_ptr<FileSystem> fs_;
+};
+
+TYPED_TEST_SUITE(TestLocalFSGeneric, PathFormatters);
+
+GENERIC_FS_TYPED_TEST_FUNCTIONS(TestLocalFSGeneric);
+
+class TestLocalFSGenericMMap : public TestLocalFSGeneric<CommonPathFormatter> {
+ protected:
+ LocalFileSystemOptions options() override {
+ auto options = LocalFileSystemOptions::Defaults();
+ options.use_mmap = true;
+ return options;
+ }
+};
+
+GENERIC_FS_TEST_FUNCTIONS(TestLocalFSGenericMMap);
+
+////////////////////////////////////////////////////////////////////////////
+// Concrete LocalFileSystem tests
+
+template <typename PathFormatter>
+class TestLocalFS : public LocalFSTestMixin {
+ public:
+ void SetUp() {
+ LocalFSTestMixin::SetUp();
+ path_formatter_ = PathFormatter();
+ local_fs_ = std::make_shared<LocalFileSystem>();
+ local_path_ = EnsureTrailingSlash(path_formatter_(temp_dir_->path().ToString()));
+ fs_ = std::make_shared<SubTreeFileSystem>(local_path_, local_fs_);
+ }
+
+ std::string UriFromAbsolutePath(const std::string& path) {
+#ifdef _WIN32
+ // Path is supposed to start with "X:/..."
+ return "file:///" + path;
+#else
+ // Path is supposed to start with "/..."
+ return "file://" + path;
+#endif
+ }
+
+ template <typename FileSystemFromUriFunc>
+ void CheckFileSystemFromUriFunc(const std::string& uri,
+ FileSystemFromUriFunc&& fs_from_uri) {
+ if (!path_formatter_.supports_uri()) {
+ return; // skip
+ }
+ std::string path;
+ ASSERT_OK_AND_ASSIGN(fs_, fs_from_uri(uri, &path));
+ ASSERT_EQ(path, local_path_);
+
+ // Test that the right location on disk is accessed
+ CreateFile(fs_.get(), local_path_ + "abc", "some data");
+ CheckConcreteFile(this->temp_dir_->path().ToString() + "abc", 9);
+ }
+
+ void TestFileSystemFromUri(const std::string& uri) {
+ CheckFileSystemFromUriFunc(uri, FSFromUri);
+ }
+
+ void TestFileSystemFromUriOrPath(const std::string& uri) {
+ CheckFileSystemFromUriFunc(uri, FSFromUriOrPath);
+ }
+
+ template <typename FileSystemFromUriFunc>
+ void CheckLocalUri(const std::string& uri, const std::string& expected_path,
+ FileSystemFromUriFunc&& fs_from_uri) {
+ if (!path_formatter_.supports_uri()) {
+ return; // skip
+ }
+ std::string path;
+ ASSERT_OK_AND_ASSIGN(fs_, fs_from_uri(uri, &path));
+ ASSERT_EQ(fs_->type_name(), "local");
+ ASSERT_EQ(path, expected_path);
+ }
+
+ // Like TestFileSystemFromUri, but with an arbitrary non-existing path
+ void TestLocalUri(const std::string& uri, const std::string& expected_path) {
+ CheckLocalUri(uri, expected_path, FSFromUri);
+ }
+
+ void TestLocalUriOrPath(const std::string& uri, const std::string& expected_path) {
+ CheckLocalUri(uri, expected_path, FSFromUriOrPath);
+ }
+
+ void TestInvalidUri(const std::string& uri) {
+ if (!path_formatter_.supports_uri()) {
+ return; // skip
+ }
+ ASSERT_RAISES(Invalid, FileSystemFromUri(uri));
+ }
+
+ void TestInvalidUriOrPath(const std::string& uri) {
+ if (!path_formatter_.supports_uri()) {
+ return; // skip
+ }
+ ASSERT_RAISES(Invalid, FileSystemFromUriOrPath(uri));
+ }
+
+ void CheckConcreteFile(const std::string& path, int64_t expected_size) {
+ ASSERT_OK_AND_ASSIGN(auto fn, PlatformFilename::FromString(path));
+ ASSERT_OK_AND_ASSIGN(int fd, ::arrow::internal::FileOpenReadable(fn));
+ auto result = ::arrow::internal::FileGetSize(fd);
+ ASSERT_OK(::arrow::internal::FileClose(fd));
+ ASSERT_OK_AND_ASSIGN(int64_t size, result);
+ ASSERT_EQ(size, expected_size);
+ }
+
+ static void CheckNormalizePath(const std::shared_ptr<FileSystem>& fs) {}
+
+ protected:
+ PathFormatter path_formatter_;
+ std::shared_ptr<LocalFileSystem> local_fs_;
+ std::shared_ptr<FileSystem> fs_;
+ std::string local_path_;
+};
+
+TYPED_TEST_SUITE(TestLocalFS, PathFormatters);
+
+TYPED_TEST(TestLocalFS, CorrectPathExists) {
+ // Test that the right location on disk is accessed
+ std::shared_ptr<io::OutputStream> stream;
+ ASSERT_OK_AND_ASSIGN(stream, this->fs_->OpenOutputStream("abc"));
+ std::string data = "some data";
+ auto data_size = static_cast<int64_t>(data.size());
+ ASSERT_OK(stream->Write(data.data(), data_size));
+ ASSERT_OK(stream->Close());
+
+ // Now check the file's existence directly, bypassing the FileSystem abstraction
+ this->CheckConcreteFile(this->temp_dir_->path().ToString() + "abc", data_size);
+}
+
+TYPED_TEST(TestLocalFS, NormalizePath) {
+#ifdef _WIN32
+ ASSERT_OK_AND_EQ("AB/CD", this->local_fs_->NormalizePath("AB\\CD"));
+ ASSERT_OK_AND_EQ("/AB/CD", this->local_fs_->NormalizePath("\\AB\\CD"));
+ ASSERT_OK_AND_EQ("C:DE/fgh", this->local_fs_->NormalizePath("C:DE\\fgh"));
+ ASSERT_OK_AND_EQ("C:/DE/fgh", this->local_fs_->NormalizePath("C:\\DE\\fgh"));
+ ASSERT_OK_AND_EQ("//some/share/AB",
+ this->local_fs_->NormalizePath("\\\\some\\share\\AB"));
+#else
+ ASSERT_OK_AND_EQ("AB\\CD", this->local_fs_->NormalizePath("AB\\CD"));
+#endif
+}
+
+TYPED_TEST(TestLocalFS, NormalizePathThroughSubtreeFS) {
+#ifdef _WIN32
+ ASSERT_OK_AND_EQ("AB/CD", this->fs_->NormalizePath("AB\\CD"));
+#else
+ ASSERT_OK_AND_EQ("AB\\CD", this->fs_->NormalizePath("AB\\CD"));
+#endif
+}
+
+TYPED_TEST(TestLocalFS, FileSystemFromUriFile) {
+ // Concrete test with actual file
+ const auto uri_string = this->UriFromAbsolutePath(this->local_path_);
+ this->TestFileSystemFromUri(uri_string);
+ this->TestFileSystemFromUriOrPath(uri_string);
+
+ // Variations
+ this->TestLocalUri("file:/foo/bar", "/foo/bar");
+ this->TestLocalUri("file:///foo/bar", "/foo/bar");
+#ifdef _WIN32
+ this->TestLocalUri("file:/C:/foo/bar", "C:/foo/bar");
+ this->TestLocalUri("file:///C:/foo/bar", "C:/foo/bar");
+#endif
+
+ // Non-empty authority
+#ifdef _WIN32
+ this->TestLocalUri("file://server/share/foo/bar", "//server/share/foo/bar");
+#else
+ this->TestInvalidUri("file://server/share/foo/bar");
+#endif
+
+ // Relative paths
+ this->TestInvalidUri("file:");
+ this->TestInvalidUri("file:foo/bar");
+}
+
+TYPED_TEST(TestLocalFS, FileSystemFromUriNoScheme) {
+ // Concrete test with actual file
+ this->TestFileSystemFromUriOrPath(this->local_path_);
+ this->TestInvalidUri(this->local_path_); // Not actually an URI
+
+ // Variations
+ this->TestLocalUriOrPath(this->path_formatter_("/foo/bar"), "/foo/bar");
+
+#ifdef _WIN32
+ this->TestLocalUriOrPath(this->path_formatter_("C:/foo/bar/"), "C:/foo/bar/");
+#endif
+
+ // Relative paths
+ this->TestInvalidUriOrPath("C:foo/bar");
+ this->TestInvalidUriOrPath("foo/bar");
+}
+
+TYPED_TEST(TestLocalFS, FileSystemFromUriNoSchemeBackslashes) {
+ const auto uri_string = ToBackslashes(this->local_path_);
+#ifdef _WIN32
+ this->TestFileSystemFromUriOrPath(uri_string);
+
+ // Variations
+ this->TestLocalUriOrPath(this->path_formatter_("C:\\foo\\bar"), "C:/foo/bar");
+#else
+ this->TestInvalidUri(uri_string);
+#endif
+
+ // Relative paths
+ this->TestInvalidUriOrPath("C:foo\\bar");
+ this->TestInvalidUriOrPath("foo\\bar");
+}
+
+TYPED_TEST(TestLocalFS, DirectoryMTime) {
+ TimePoint t1 = CurrentTimePoint();
+ ASSERT_OK(this->fs_->CreateDir("AB/CD/EF"));
+ TimePoint t2 = CurrentTimePoint();
+
+ std::vector<FileInfo> infos;
+ ASSERT_OK_AND_ASSIGN(infos, this->fs_->GetFileInfo({"AB", "AB/CD/EF", "xxx"}));
+ ASSERT_EQ(infos.size(), 3);
+ AssertFileInfo(infos[0], "AB", FileType::Directory);
+ AssertFileInfo(infos[1], "AB/CD/EF", FileType::Directory);
+ AssertFileInfo(infos[2], "xxx", FileType::NotFound);
+
+ // NOTE: creating AB/CD updates AB's modification time, but creating
+ // AB/CD/EF doesn't. So AB/CD/EF's modification time should always be
+ // the same as or after AB's modification time.
+ AssertDurationBetween(infos[1].mtime() - infos[0].mtime(), 0, kTimeSlack);
+ // Depending on filesystem time granularity, the recorded time could be
+ // before the system time when doing the modification.
+ AssertDurationBetween(infos[0].mtime() - t1, -kTimeSlack, kTimeSlack);
+ AssertDurationBetween(t2 - infos[1].mtime(), -kTimeSlack, kTimeSlack);
+}
+
+TYPED_TEST(TestLocalFS, FileMTime) {
+ TimePoint t1 = CurrentTimePoint();
+ ASSERT_OK(this->fs_->CreateDir("AB/CD"));
+ CreateFile(this->fs_.get(), "AB/CD/ab", "data");
+ TimePoint t2 = CurrentTimePoint();
+
+ std::vector<FileInfo> infos;
+ ASSERT_OK_AND_ASSIGN(infos, this->fs_->GetFileInfo({"AB", "AB/CD/ab", "xxx"}));
+ ASSERT_EQ(infos.size(), 3);
+ AssertFileInfo(infos[0], "AB", FileType::Directory);
+ AssertFileInfo(infos[1], "AB/CD/ab", FileType::File, 4);
+ AssertFileInfo(infos[2], "xxx", FileType::NotFound);
+
+ AssertDurationBetween(infos[1].mtime() - infos[0].mtime(), 0, kTimeSlack);
+ AssertDurationBetween(infos[0].mtime() - t1, -kTimeSlack, kTimeSlack);
+ AssertDurationBetween(t2 - infos[1].mtime(), -kTimeSlack, kTimeSlack);
+}
+
+// TODO Should we test backslash paths on Windows?
+// SubTreeFileSystem isn't compatible with them.
+
+} // namespace internal
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/mockfs.cc b/src/arrow/cpp/src/arrow/filesystem/mockfs.cc
new file mode 100644
index 000000000..f2d2f8726
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/mockfs.cc
@@ -0,0 +1,778 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <iterator>
+#include <map>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/filesystem/mockfs.h"
+#include "arrow/filesystem/path_util.h"
+#include "arrow/filesystem/util_internal.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/io/memory.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/variant.h"
+#include "arrow/util/windows_fixup.h"
+
+namespace arrow {
+namespace fs {
+namespace internal {
+
+namespace {
+
+////////////////////////////////////////////////////////////////////////////
+// Filesystem structure
+
+class Entry;
+
+struct File {
+ TimePoint mtime;
+ std::string name;
+ std::shared_ptr<Buffer> data;
+ std::shared_ptr<const KeyValueMetadata> metadata;
+
+ File(TimePoint mtime, std::string name) : mtime(mtime), name(std::move(name)) {}
+
+ int64_t size() const { return data ? data->size() : 0; }
+
+ explicit operator util::string_view() const {
+ if (data) {
+ return util::string_view(*data);
+ } else {
+ return "";
+ }
+ }
+};
+
+struct Directory {
+ std::string name;
+ TimePoint mtime;
+ std::map<std::string, std::unique_ptr<Entry>> entries;
+
+ Directory(std::string name, TimePoint mtime) : name(std::move(name)), mtime(mtime) {}
+ Directory(Directory&& other) noexcept
+ : name(std::move(other.name)),
+ mtime(other.mtime),
+ entries(std::move(other.entries)) {}
+
+ Directory& operator=(Directory&& other) noexcept {
+ name = std::move(other.name);
+ mtime = other.mtime;
+ entries = std::move(other.entries);
+ return *this;
+ }
+
+ Entry* Find(const std::string& s) {
+ auto it = entries.find(s);
+ if (it != entries.end()) {
+ return it->second.get();
+ } else {
+ return nullptr;
+ }
+ }
+
+ bool CreateEntry(const std::string& s, std::unique_ptr<Entry> entry) {
+ DCHECK(!s.empty());
+ auto p = entries.emplace(s, std::move(entry));
+ return p.second;
+ }
+
+ void AssignEntry(const std::string& s, std::unique_ptr<Entry> entry) {
+ DCHECK(!s.empty());
+ entries[s] = std::move(entry);
+ }
+
+ bool DeleteEntry(const std::string& s) { return entries.erase(s) > 0; }
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Directory);
+};
+
+// A filesystem entry
+using EntryBase = util::Variant<std::nullptr_t, File, Directory>;
+
+class Entry : public EntryBase {
+ public:
+ Entry(Entry&&) = default;
+ Entry& operator=(Entry&&) = default;
+ explicit Entry(Directory&& v) : EntryBase(std::move(v)) {}
+ explicit Entry(File&& v) : EntryBase(std::move(v)) {}
+
+ bool is_dir() const { return util::holds_alternative<Directory>(*this); }
+
+ bool is_file() const { return util::holds_alternative<File>(*this); }
+
+ Directory& as_dir() { return util::get<Directory>(*this); }
+
+ File& as_file() { return util::get<File>(*this); }
+
+ // Get info for this entry. Note the path() property isn't set.
+ FileInfo GetInfo() {
+ FileInfo info;
+ if (is_dir()) {
+ Directory& dir = as_dir();
+ info.set_type(FileType::Directory);
+ info.set_mtime(dir.mtime);
+ } else {
+ DCHECK(is_file());
+ File& file = as_file();
+ info.set_type(FileType::File);
+ info.set_mtime(file.mtime);
+ info.set_size(file.size());
+ }
+ return info;
+ }
+
+ // Get info for this entry, knowing the parent path.
+ FileInfo GetInfo(const std::string& base_path) {
+ FileInfo info;
+ if (is_dir()) {
+ Directory& dir = as_dir();
+ info.set_type(FileType::Directory);
+ info.set_mtime(dir.mtime);
+ info.set_path(ConcatAbstractPath(base_path, dir.name));
+ } else {
+ DCHECK(is_file());
+ File& file = as_file();
+ info.set_type(FileType::File);
+ info.set_mtime(file.mtime);
+ info.set_size(file.size());
+ info.set_path(ConcatAbstractPath(base_path, file.name));
+ }
+ return info;
+ }
+
+ // Set the entry name
+ void SetName(const std::string& name) {
+ if (is_dir()) {
+ as_dir().name = name;
+ } else {
+ DCHECK(is_file());
+ as_file().name = name;
+ }
+ }
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Entry);
+};
+
+////////////////////////////////////////////////////////////////////////////
+// Streams
+
+class MockFSOutputStream : public io::OutputStream {
+ public:
+ MockFSOutputStream(File* file, MemoryPool* pool)
+ : file_(file), builder_(pool), closed_(false) {}
+
+ ~MockFSOutputStream() override = default;
+
+ // Implement the OutputStream interface
+ Status Close() override {
+ if (!closed_) {
+ RETURN_NOT_OK(builder_.Finish(&file_->data));
+ closed_ = true;
+ }
+ return Status::OK();
+ }
+
+ Status Abort() override {
+ if (!closed_) {
+ // MockFSOutputStream is mainly used for debugging and testing, so
+ // mark an aborted file's contents explicitly.
+ std::stringstream ss;
+ ss << "MockFSOutputStream aborted after " << file_->size() << " bytes written";
+ file_->data = Buffer::FromString(ss.str());
+ closed_ = true;
+ }
+ return Status::OK();
+ }
+
+ bool closed() const override { return closed_; }
+
+ Result<int64_t> Tell() const override {
+ if (closed_) {
+ return Status::Invalid("Invalid operation on closed stream");
+ }
+ return builder_.length();
+ }
+
+ Status Write(const void* data, int64_t nbytes) override {
+ if (closed_) {
+ return Status::Invalid("Invalid operation on closed stream");
+ }
+ return builder_.Append(data, nbytes);
+ }
+
+ protected:
+ File* file_;
+ BufferBuilder builder_;
+ bool closed_;
+};
+
+class MockFSInputStream : public io::BufferReader {
+ public:
+ explicit MockFSInputStream(const File& file)
+ : io::BufferReader(file.data), metadata_(file.metadata) {}
+
+ Result<std::shared_ptr<const KeyValueMetadata>> ReadMetadata() override {
+ return metadata_;
+ }
+
+ protected:
+ std::shared_ptr<const KeyValueMetadata> metadata_;
+};
+
+} // namespace
+
+std::ostream& operator<<(std::ostream& os, const MockDirInfo& di) {
+ return os << "'" << di.full_path << "' [mtime=" << di.mtime.time_since_epoch().count()
+ << "]";
+}
+
+std::ostream& operator<<(std::ostream& os, const MockFileInfo& di) {
+ return os << "'" << di.full_path << "' [mtime=" << di.mtime.time_since_epoch().count()
+ << ", size=" << di.data.length() << "]";
+}
+
+////////////////////////////////////////////////////////////////////////////
+// MockFileSystem implementation
+
+class MockFileSystem::Impl {
+ public:
+ TimePoint current_time;
+ MemoryPool* pool;
+
+ // The root directory
+ Entry root;
+ std::mutex mutex;
+
+ Impl(TimePoint current_time, MemoryPool* pool)
+ : current_time(current_time), pool(pool), root(Directory("", current_time)) {}
+
+ std::unique_lock<std::mutex> lock_guard() {
+ return std::unique_lock<std::mutex>(mutex);
+ }
+
+ Directory& RootDir() { return root.as_dir(); }
+
+ template <typename It>
+ Entry* FindEntry(It it, It end, size_t* nconsumed) {
+ size_t consumed = 0;
+ Entry* entry = &root;
+
+ for (; it != end; ++it) {
+ const std::string& part = *it;
+ DCHECK(entry->is_dir());
+ Entry* child = entry->as_dir().Find(part);
+ if (child == nullptr) {
+ // Partial find only
+ break;
+ }
+ ++consumed;
+ entry = child;
+ if (entry->is_file()) {
+ // Cannot go any further
+ break;
+ }
+ // Recurse
+ }
+ *nconsumed = consumed;
+ return entry;
+ }
+
+ // Find an entry, allowing partial matching
+ Entry* FindEntry(const std::vector<std::string>& parts, size_t* nconsumed) {
+ return FindEntry(parts.begin(), parts.end(), nconsumed);
+ }
+
+ // Find an entry, only full matching allowed
+ Entry* FindEntry(const std::vector<std::string>& parts) {
+ size_t consumed;
+ auto entry = FindEntry(parts, &consumed);
+ return (consumed == parts.size()) ? entry : nullptr;
+ }
+
+ // Find the parent entry, only full matching allowed
+ Entry* FindParent(const std::vector<std::string>& parts) {
+ if (parts.size() == 0) {
+ return nullptr;
+ }
+ size_t consumed;
+ auto entry = FindEntry(parts.begin(), --parts.end(), &consumed);
+ return (consumed == parts.size() - 1) ? entry : nullptr;
+ }
+
+ void GatherInfos(const FileSelector& select, const std::string& base_path,
+ const Directory& base_dir, int32_t nesting_depth,
+ std::vector<FileInfo>* infos) {
+ for (const auto& pair : base_dir.entries) {
+ Entry* child = pair.second.get();
+ infos->push_back(child->GetInfo(base_path));
+ if (select.recursive && nesting_depth < select.max_recursion && child->is_dir()) {
+ Directory& child_dir = child->as_dir();
+ std::string child_path = infos->back().path();
+ GatherInfos(select, std::move(child_path), child_dir, nesting_depth + 1, infos);
+ }
+ }
+ }
+
+ void DumpDirs(const std::string& prefix, const Directory& dir,
+ std::vector<MockDirInfo>* out) {
+ std::string path = prefix + dir.name;
+ if (!path.empty()) {
+ out->push_back({path, dir.mtime});
+ path += "/";
+ }
+ for (const auto& pair : dir.entries) {
+ Entry* child = pair.second.get();
+ if (child->is_dir()) {
+ DumpDirs(path, child->as_dir(), out);
+ }
+ }
+ }
+
+ void DumpFiles(const std::string& prefix, const Directory& dir,
+ std::vector<MockFileInfo>* out) {
+ std::string path = prefix + dir.name;
+ if (!path.empty()) {
+ path += "/";
+ }
+ for (const auto& pair : dir.entries) {
+ Entry* child = pair.second.get();
+ if (child->is_file()) {
+ auto& file = child->as_file();
+ out->push_back({path + file.name, file.mtime, util::string_view(file)});
+ } else if (child->is_dir()) {
+ DumpFiles(path, child->as_dir(), out);
+ }
+ }
+ }
+
+ Result<std::shared_ptr<io::OutputStream>> OpenOutputStream(
+ const std::string& path, bool append,
+ const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ auto parts = SplitAbstractPath(path);
+ RETURN_NOT_OK(ValidateAbstractPathParts(parts));
+
+ Entry* parent = FindParent(parts);
+ if (parent == nullptr || !parent->is_dir()) {
+ return PathNotFound(path);
+ }
+ // Find the file in the parent dir, or create it
+ const auto& name = parts.back();
+ Entry* child = parent->as_dir().Find(name);
+ File* file;
+ if (child == nullptr) {
+ child = new Entry(File(current_time, name));
+ parent->as_dir().AssignEntry(name, std::unique_ptr<Entry>(child));
+ file = &child->as_file();
+ } else if (child->is_file()) {
+ file = &child->as_file();
+ file->mtime = current_time;
+ } else {
+ return NotAFile(path);
+ }
+ file->metadata = metadata;
+ auto ptr = std::make_shared<MockFSOutputStream>(file, pool);
+ if (append && file->data) {
+ RETURN_NOT_OK(ptr->Write(file->data->data(), file->data->size()));
+ }
+ return ptr;
+ }
+
+ Result<std::shared_ptr<io::BufferReader>> OpenInputReader(const std::string& path) {
+ auto parts = SplitAbstractPath(path);
+ RETURN_NOT_OK(ValidateAbstractPathParts(parts));
+
+ Entry* entry = FindEntry(parts);
+ if (entry == nullptr) {
+ return PathNotFound(path);
+ }
+ if (!entry->is_file()) {
+ return NotAFile(path);
+ }
+ return std::make_shared<MockFSInputStream>(entry->as_file());
+ }
+};
+
+MockFileSystem::~MockFileSystem() = default;
+
+MockFileSystem::MockFileSystem(TimePoint current_time, const io::IOContext& io_context) {
+ impl_ = std::unique_ptr<Impl>(new Impl(current_time, io_context.pool()));
+}
+
+bool MockFileSystem::Equals(const FileSystem& other) const { return this == &other; }
+
+Status MockFileSystem::CreateDir(const std::string& path, bool recursive) {
+ auto parts = SplitAbstractPath(path);
+ RETURN_NOT_OK(ValidateAbstractPathParts(parts));
+
+ auto guard = impl_->lock_guard();
+
+ size_t consumed;
+ Entry* entry = impl_->FindEntry(parts, &consumed);
+ if (!entry->is_dir()) {
+ auto file_path = JoinAbstractPath(parts.begin(), parts.begin() + consumed);
+ return Status::IOError("Cannot create directory '", path, "': ", "ancestor '",
+ file_path, "' is not a directory");
+ }
+ if (!recursive && (parts.size() - consumed) > 1) {
+ return Status::IOError("Cannot create directory '", path,
+ "': ", "parent does not exist");
+ }
+ for (size_t i = consumed; i < parts.size(); ++i) {
+ const auto& name = parts[i];
+ std::unique_ptr<Entry> child(new Entry(Directory(name, impl_->current_time)));
+ Entry* child_ptr = child.get();
+ bool inserted = entry->as_dir().CreateEntry(name, std::move(child));
+ // No race condition on insertion is possible, as all operations are locked
+ DCHECK(inserted);
+ entry = child_ptr;
+ }
+ return Status::OK();
+}
+
+Status MockFileSystem::DeleteDir(const std::string& path) {
+ auto parts = SplitAbstractPath(path);
+ RETURN_NOT_OK(ValidateAbstractPathParts(parts));
+
+ auto guard = impl_->lock_guard();
+
+ Entry* parent = impl_->FindParent(parts);
+ if (parent == nullptr || !parent->is_dir()) {
+ return PathNotFound(path);
+ }
+ Directory& parent_dir = parent->as_dir();
+ auto child = parent_dir.Find(parts.back());
+ if (child == nullptr) {
+ return PathNotFound(path);
+ }
+ if (!child->is_dir()) {
+ return NotADir(path);
+ }
+
+ bool deleted = parent_dir.DeleteEntry(parts.back());
+ DCHECK(deleted);
+ return Status::OK();
+}
+
+Status MockFileSystem::DeleteDirContents(const std::string& path) {
+ auto parts = SplitAbstractPath(path);
+ RETURN_NOT_OK(ValidateAbstractPathParts(parts));
+
+ auto guard = impl_->lock_guard();
+
+ if (parts.empty()) {
+ // Wipe filesystem
+ return internal::InvalidDeleteDirContents(path);
+ }
+
+ Entry* entry = impl_->FindEntry(parts);
+ if (entry == nullptr) {
+ return PathNotFound(path);
+ }
+ if (!entry->is_dir()) {
+ return NotADir(path);
+ }
+ entry->as_dir().entries.clear();
+ return Status::OK();
+}
+
+Status MockFileSystem::DeleteRootDirContents() {
+ auto guard = impl_->lock_guard();
+
+ impl_->RootDir().entries.clear();
+ return Status::OK();
+}
+
+Status MockFileSystem::DeleteFile(const std::string& path) {
+ auto parts = SplitAbstractPath(path);
+ RETURN_NOT_OK(ValidateAbstractPathParts(parts));
+
+ auto guard = impl_->lock_guard();
+
+ Entry* parent = impl_->FindParent(parts);
+ if (parent == nullptr || !parent->is_dir()) {
+ return PathNotFound(path);
+ }
+ Directory& parent_dir = parent->as_dir();
+ auto child = parent_dir.Find(parts.back());
+ if (child == nullptr) {
+ return PathNotFound(path);
+ }
+ if (!child->is_file()) {
+ return NotAFile(path);
+ }
+ bool deleted = parent_dir.DeleteEntry(parts.back());
+ DCHECK(deleted);
+ return Status::OK();
+}
+
+Result<FileInfo> MockFileSystem::GetFileInfo(const std::string& path) {
+ auto parts = SplitAbstractPath(path);
+ RETURN_NOT_OK(ValidateAbstractPathParts(parts));
+
+ auto guard = impl_->lock_guard();
+
+ FileInfo info;
+ Entry* entry = impl_->FindEntry(parts);
+ if (entry == nullptr) {
+ info.set_type(FileType::NotFound);
+ } else {
+ info = entry->GetInfo();
+ }
+ info.set_path(path);
+ return info;
+}
+
+Result<FileInfoVector> MockFileSystem::GetFileInfo(const FileSelector& selector) {
+ auto parts = SplitAbstractPath(selector.base_dir);
+ RETURN_NOT_OK(ValidateAbstractPathParts(parts));
+
+ auto guard = impl_->lock_guard();
+
+ FileInfoVector results;
+
+ Entry* base_dir = impl_->FindEntry(parts);
+ if (base_dir == nullptr) {
+ // Base directory does not exist
+ if (selector.allow_not_found) {
+ return results;
+ } else {
+ return PathNotFound(selector.base_dir);
+ }
+ }
+ if (!base_dir->is_dir()) {
+ return NotADir(selector.base_dir);
+ }
+
+ impl_->GatherInfos(selector, selector.base_dir, base_dir->as_dir(), 0, &results);
+ return results;
+}
+
+namespace {
+
+// Helper for binary operations (move, copy)
+struct BinaryOp {
+ std::vector<std::string> src_parts;
+ std::vector<std::string> dest_parts;
+ Directory& src_dir;
+ Directory& dest_dir;
+ std::string src_name;
+ std::string dest_name;
+ Entry* src_entry;
+ Entry* dest_entry;
+
+ template <typename OpFunc>
+ static Status Run(MockFileSystem::Impl* impl, const std::string& src,
+ const std::string& dest, OpFunc&& op_func) {
+ auto src_parts = SplitAbstractPath(src);
+ auto dest_parts = SplitAbstractPath(dest);
+ RETURN_NOT_OK(ValidateAbstractPathParts(src_parts));
+ RETURN_NOT_OK(ValidateAbstractPathParts(dest_parts));
+
+ auto guard = impl->lock_guard();
+
+ // Both source and destination must have valid parents
+ Entry* src_parent = impl->FindParent(src_parts);
+ if (src_parent == nullptr || !src_parent->is_dir()) {
+ return PathNotFound(src);
+ }
+ Entry* dest_parent = impl->FindParent(dest_parts);
+ if (dest_parent == nullptr || !dest_parent->is_dir()) {
+ return PathNotFound(dest);
+ }
+ Directory& src_dir = src_parent->as_dir();
+ Directory& dest_dir = dest_parent->as_dir();
+ DCHECK_GE(src_parts.size(), 1);
+ DCHECK_GE(dest_parts.size(), 1);
+ const auto& src_name = src_parts.back();
+ const auto& dest_name = dest_parts.back();
+
+ BinaryOp op{std::move(src_parts),
+ std::move(dest_parts),
+ src_dir,
+ dest_dir,
+ src_name,
+ dest_name,
+ src_dir.Find(src_name),
+ dest_dir.Find(dest_name)};
+
+ return op_func(std::move(op));
+ }
+};
+
+} // namespace
+
+Status MockFileSystem::Move(const std::string& src, const std::string& dest) {
+ return BinaryOp::Run(impl_.get(), src, dest, [&](const BinaryOp& op) -> Status {
+ if (op.src_entry == nullptr) {
+ return PathNotFound(src);
+ }
+ if (op.dest_entry != nullptr) {
+ if (op.dest_entry->is_dir()) {
+ return Status::IOError("Cannot replace destination '", dest,
+ "', which is a directory");
+ }
+ if (op.dest_entry->is_file() && op.src_entry->is_dir()) {
+ return Status::IOError("Cannot replace destination '", dest,
+ "', which is a file, with directory '", src, "'");
+ }
+ }
+ if (op.src_parts.size() < op.dest_parts.size()) {
+ // Check if dest is a child of src
+ auto p =
+ std::mismatch(op.src_parts.begin(), op.src_parts.end(), op.dest_parts.begin());
+ if (p.first == op.src_parts.end()) {
+ return Status::IOError("Cannot move '", src, "' into child path '", dest, "'");
+ }
+ }
+
+ // Move original entry, fix its name
+ std::unique_ptr<Entry> new_entry(new Entry(std::move(*op.src_entry)));
+ new_entry->SetName(op.dest_name);
+ bool deleted = op.src_dir.DeleteEntry(op.src_name);
+ DCHECK(deleted);
+ op.dest_dir.AssignEntry(op.dest_name, std::move(new_entry));
+ return Status::OK();
+ });
+}
+
+Status MockFileSystem::CopyFile(const std::string& src, const std::string& dest) {
+ return BinaryOp::Run(impl_.get(), src, dest, [&](const BinaryOp& op) -> Status {
+ if (op.src_entry == nullptr) {
+ return PathNotFound(src);
+ }
+ if (!op.src_entry->is_file()) {
+ return NotAFile(src);
+ }
+ if (op.dest_entry != nullptr && op.dest_entry->is_dir()) {
+ return Status::IOError("Cannot replace destination '", dest,
+ "', which is a directory");
+ }
+
+ // Copy original entry, fix its name
+ std::unique_ptr<Entry> new_entry(new Entry(File(op.src_entry->as_file())));
+ new_entry->SetName(op.dest_name);
+ op.dest_dir.AssignEntry(op.dest_name, std::move(new_entry));
+ return Status::OK();
+ });
+}
+
+Result<std::shared_ptr<io::InputStream>> MockFileSystem::OpenInputStream(
+ const std::string& path) {
+ auto guard = impl_->lock_guard();
+
+ return impl_->OpenInputReader(path);
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> MockFileSystem::OpenInputFile(
+ const std::string& path) {
+ auto guard = impl_->lock_guard();
+
+ return impl_->OpenInputReader(path);
+}
+
+Result<std::shared_ptr<io::OutputStream>> MockFileSystem::OpenOutputStream(
+ const std::string& path, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ auto guard = impl_->lock_guard();
+
+ return impl_->OpenOutputStream(path, /*append=*/false, metadata);
+}
+
+Result<std::shared_ptr<io::OutputStream>> MockFileSystem::OpenAppendStream(
+ const std::string& path, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ auto guard = impl_->lock_guard();
+
+ return impl_->OpenOutputStream(path, /*append=*/true, metadata);
+}
+
+std::vector<MockDirInfo> MockFileSystem::AllDirs() {
+ auto guard = impl_->lock_guard();
+
+ std::vector<MockDirInfo> result;
+ impl_->DumpDirs("", impl_->RootDir(), &result);
+ return result;
+}
+
+std::vector<MockFileInfo> MockFileSystem::AllFiles() {
+ auto guard = impl_->lock_guard();
+
+ std::vector<MockFileInfo> result;
+ impl_->DumpFiles("", impl_->RootDir(), &result);
+ return result;
+}
+
+Status MockFileSystem::CreateFile(const std::string& path, util::string_view contents,
+ bool recursive) {
+ auto parent = fs::internal::GetAbstractPathParent(path).first;
+
+ if (parent != "") {
+ RETURN_NOT_OK(CreateDir(parent, recursive));
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto file, OpenOutputStream(path));
+ RETURN_NOT_OK(file->Write(contents));
+ return file->Close();
+}
+
+Result<std::shared_ptr<FileSystem>> MockFileSystem::Make(
+ TimePoint current_time, const std::vector<FileInfo>& infos) {
+ auto fs = std::make_shared<MockFileSystem>(current_time);
+ for (const auto& info : infos) {
+ switch (info.type()) {
+ case FileType::Directory:
+ RETURN_NOT_OK(fs->CreateDir(info.path(), /*recursive*/ true));
+ break;
+ case FileType::File:
+ RETURN_NOT_OK(fs->CreateFile(info.path(), "", /*recursive*/ true));
+ break;
+ default:
+ break;
+ }
+ }
+
+ return fs;
+}
+
+FileInfoGenerator MockAsyncFileSystem::GetFileInfoGenerator(const FileSelector& select) {
+ auto maybe_infos = GetFileInfo(select);
+ if (maybe_infos.ok()) {
+ // Return the FileInfo entries one by one
+ const auto& infos = *maybe_infos;
+ std::vector<FileInfoVector> chunks(infos.size());
+ std::transform(infos.begin(), infos.end(), chunks.begin(),
+ [](const FileInfo& info) { return FileInfoVector{info}; });
+ return MakeVectorGenerator(std::move(chunks));
+ } else {
+ return MakeFailingGenerator(maybe_infos);
+ }
+}
+
+} // namespace internal
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/mockfs.h b/src/arrow/cpp/src/arrow/filesystem/mockfs.h
new file mode 100644
index 000000000..378f30d29
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/mockfs.h
@@ -0,0 +1,132 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <iosfwd>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/filesystem/filesystem.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/windows_fixup.h"
+
+namespace arrow {
+namespace fs {
+namespace internal {
+
+struct MockDirInfo {
+ std::string full_path;
+ TimePoint mtime;
+
+ bool operator==(const MockDirInfo& other) const {
+ return mtime == other.mtime && full_path == other.full_path;
+ }
+
+ friend ARROW_EXPORT std::ostream& operator<<(std::ostream&, const MockDirInfo&);
+};
+
+struct MockFileInfo {
+ std::string full_path;
+ TimePoint mtime;
+ util::string_view data;
+
+ bool operator==(const MockFileInfo& other) const {
+ return mtime == other.mtime && full_path == other.full_path && data == other.data;
+ }
+
+ friend ARROW_EXPORT std::ostream& operator<<(std::ostream&, const MockFileInfo&);
+};
+
+/// A mock FileSystem implementation that holds its contents in memory.
+///
+/// Useful for validating the FileSystem API, writing conformance suite,
+/// and bootstrapping FileSystem-based APIs.
+class ARROW_EXPORT MockFileSystem : public FileSystem {
+ public:
+ explicit MockFileSystem(TimePoint current_time,
+ const io::IOContext& = io::default_io_context());
+ ~MockFileSystem() override;
+
+ std::string type_name() const override { return "mock"; }
+
+ bool Equals(const FileSystem& other) const override;
+
+ // XXX It's not very practical to have to explicitly declare inheritance
+ // of default overrides.
+ using FileSystem::GetFileInfo;
+ Result<FileInfo> GetFileInfo(const std::string& path) override;
+ Result<std::vector<FileInfo>> GetFileInfo(const FileSelector& select) override;
+
+ Status CreateDir(const std::string& path, bool recursive = true) override;
+
+ Status DeleteDir(const std::string& path) override;
+ Status DeleteDirContents(const std::string& path) override;
+ Status DeleteRootDirContents() override;
+
+ Status DeleteFile(const std::string& path) override;
+
+ Status Move(const std::string& src, const std::string& dest) override;
+
+ Status CopyFile(const std::string& src, const std::string& dest) override;
+
+ Result<std::shared_ptr<io::InputStream>> OpenInputStream(
+ const std::string& path) override;
+ Result<std::shared_ptr<io::RandomAccessFile>> OpenInputFile(
+ const std::string& path) override;
+ Result<std::shared_ptr<io::OutputStream>> OpenOutputStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = {}) override;
+ Result<std::shared_ptr<io::OutputStream>> OpenAppendStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = {}) override;
+
+ // Contents-dumping helpers to ease testing.
+ // Output is lexicographically-ordered by full path.
+ std::vector<MockDirInfo> AllDirs();
+ std::vector<MockFileInfo> AllFiles();
+
+ // Create a File with a content from a string.
+ Status CreateFile(const std::string& path, util::string_view content,
+ bool recursive = true);
+
+ // Create a MockFileSystem out of (empty) FileInfo. The content of every
+ // file is empty and of size 0. All directories will be created recursively.
+ static Result<std::shared_ptr<FileSystem>> Make(TimePoint current_time,
+ const std::vector<FileInfo>& infos);
+
+ class Impl;
+
+ protected:
+ std::unique_ptr<Impl> impl_;
+};
+
+class ARROW_EXPORT MockAsyncFileSystem : public MockFileSystem {
+ public:
+ explicit MockAsyncFileSystem(TimePoint current_time,
+ const io::IOContext& io_context = io::default_io_context())
+ : MockFileSystem(current_time, io_context) {
+ default_async_is_sync_ = false;
+ }
+
+ FileInfoGenerator GetFileInfoGenerator(const FileSelector& select) override;
+};
+
+} // namespace internal
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/path_util.cc b/src/arrow/cpp/src/arrow/filesystem/path_util.cc
new file mode 100644
index 000000000..f1bd5c087
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/path_util.cc
@@ -0,0 +1,271 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+
+#include "arrow/filesystem/path_util.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace fs {
+namespace internal {
+
+// XXX How does this encode Windows UNC paths?
+
+std::vector<std::string> SplitAbstractPath(const std::string& path) {
+ std::vector<std::string> parts;
+ auto v = util::string_view(path);
+ // Strip trailing slash
+ if (v.length() > 0 && v.back() == kSep) {
+ v = v.substr(0, v.length() - 1);
+ }
+ // Strip leading slash
+ if (v.length() > 0 && v.front() == kSep) {
+ v = v.substr(1);
+ }
+ if (v.length() == 0) {
+ return parts;
+ }
+
+ auto append_part = [&parts, &v](size_t start, size_t end) {
+ parts.push_back(std::string(v.substr(start, end - start)));
+ };
+
+ size_t start = 0;
+ while (true) {
+ size_t end = v.find_first_of(kSep, start);
+ append_part(start, end);
+ if (end == std::string::npos) {
+ break;
+ }
+ start = end + 1;
+ }
+ return parts;
+}
+
+std::pair<std::string, std::string> GetAbstractPathParent(const std::string& s) {
+ // XXX should strip trailing slash?
+
+ auto pos = s.find_last_of(kSep);
+ if (pos == std::string::npos) {
+ // Empty parent
+ return {{}, s};
+ }
+ return {s.substr(0, pos), s.substr(pos + 1)};
+}
+
+std::string GetAbstractPathExtension(const std::string& s) {
+ util::string_view basename(s);
+ auto offset = basename.find_last_of(kSep);
+ if (offset != std::string::npos) {
+ basename = basename.substr(offset);
+ }
+ auto dot = basename.find_last_of('.');
+ if (dot == util::string_view::npos) {
+ // Empty extension
+ return "";
+ }
+ return std::string(basename.substr(dot + 1));
+}
+
+Status ValidateAbstractPathParts(const std::vector<std::string>& parts) {
+ for (const auto& part : parts) {
+ if (part.length() == 0) {
+ return Status::Invalid("Empty path component");
+ }
+ if (part.find_first_of(kSep) != std::string::npos) {
+ return Status::Invalid("Separator in component '", part, "'");
+ }
+ }
+ return Status::OK();
+}
+
+std::string ConcatAbstractPath(const std::string& base, const std::string& stem) {
+ DCHECK(!stem.empty());
+ if (base.empty()) {
+ return stem;
+ }
+ return EnsureTrailingSlash(base) + std::string(RemoveLeadingSlash(stem));
+}
+
+std::string EnsureTrailingSlash(util::string_view v) {
+ if (v.length() > 0 && v.back() != kSep) {
+ // XXX How about "C:" on Windows? We probably don't want to turn it into "C:/"...
+ // Unless the local filesystem always uses absolute paths
+ return std::string(v) + kSep;
+ } else {
+ return std::string(v);
+ }
+}
+
+std::string EnsureLeadingSlash(util::string_view v) {
+ if (v.length() == 0 || v.front() != kSep) {
+ // XXX How about "C:" on Windows? We probably don't want to turn it into "/C:"...
+ return kSep + std::string(v);
+ } else {
+ return std::string(v);
+ }
+}
+util::string_view RemoveTrailingSlash(util::string_view key) {
+ while (!key.empty() && key.back() == kSep) {
+ key.remove_suffix(1);
+ }
+ return key;
+}
+
+util::string_view RemoveLeadingSlash(util::string_view key) {
+ while (!key.empty() && key.front() == kSep) {
+ key.remove_prefix(1);
+ }
+ return key;
+}
+
+Result<std::string> MakeAbstractPathRelative(const std::string& base,
+ const std::string& path) {
+ if (base.empty() || base.front() != kSep) {
+ return Status::Invalid("MakeAbstractPathRelative called with non-absolute base '",
+ base, "'");
+ }
+ auto b = EnsureLeadingSlash(RemoveTrailingSlash(base));
+ auto p = util::string_view(path);
+ if (p.substr(0, b.size()) != util::string_view(b)) {
+ return Status::Invalid("Path '", path, "' is not relative to '", base, "'");
+ }
+ p = p.substr(b.size());
+ if (!p.empty() && p.front() != kSep && b.back() != kSep) {
+ return Status::Invalid("Path '", path, "' is not relative to '", base, "'");
+ }
+ return std::string(RemoveLeadingSlash(p));
+}
+
+bool IsAncestorOf(util::string_view ancestor, util::string_view descendant) {
+ ancestor = RemoveTrailingSlash(ancestor);
+ if (ancestor == "") {
+ // everything is a descendant of the root directory
+ return true;
+ }
+
+ descendant = RemoveTrailingSlash(descendant);
+ if (!descendant.starts_with(ancestor)) {
+ // an ancestor path is a prefix of descendant paths
+ return false;
+ }
+
+ descendant.remove_prefix(ancestor.size());
+
+ if (descendant.empty()) {
+ // "/hello" is an ancestor of "/hello"
+ return true;
+ }
+
+ // "/hello/w" is not an ancestor of "/hello/world"
+ return descendant.starts_with(std::string{kSep});
+}
+
+util::optional<util::string_view> RemoveAncestor(util::string_view ancestor,
+ util::string_view descendant) {
+ if (!IsAncestorOf(ancestor, descendant)) {
+ return util::nullopt;
+ }
+
+ auto relative_to_ancestor = descendant.substr(ancestor.size());
+ return RemoveLeadingSlash(relative_to_ancestor);
+}
+
+std::vector<std::string> AncestorsFromBasePath(util::string_view base_path,
+ util::string_view descendant) {
+ std::vector<std::string> ancestry;
+ if (auto relative = RemoveAncestor(base_path, descendant)) {
+ auto relative_segments = fs::internal::SplitAbstractPath(std::string(*relative));
+
+ // the last segment indicates descendant
+ relative_segments.pop_back();
+
+ if (relative_segments.empty()) {
+ // no missing parent
+ return {};
+ }
+
+ for (auto&& relative_segment : relative_segments) {
+ ancestry.push_back(JoinAbstractPath(
+ std::vector<std::string>{std::string(base_path), std::move(relative_segment)}));
+ base_path = ancestry.back();
+ }
+ }
+ return ancestry;
+}
+
+std::vector<std::string> MinimalCreateDirSet(std::vector<std::string> dirs) {
+ std::sort(dirs.begin(), dirs.end());
+
+ for (auto ancestor = dirs.begin(); ancestor != dirs.end(); ++ancestor) {
+ auto descendant = ancestor;
+ auto descendants_end = descendant + 1;
+
+ while (descendants_end != dirs.end() && IsAncestorOf(*descendant, *descendants_end)) {
+ ++descendant;
+ ++descendants_end;
+ }
+
+ ancestor = dirs.erase(ancestor, descendants_end - 1);
+ }
+
+ // the root directory need not be created
+ if (dirs.size() == 1 && IsAncestorOf(dirs[0], "")) {
+ return {};
+ }
+
+ return dirs;
+}
+
+std::string ToBackslashes(util::string_view v) {
+ std::string s(v);
+ for (auto& c : s) {
+ if (c == '/') {
+ c = '\\';
+ }
+ }
+ return s;
+}
+
+std::string ToSlashes(util::string_view v) {
+ std::string s(v);
+#ifdef _WIN32
+ for (auto& c : s) {
+ if (c == '\\') {
+ c = '/';
+ }
+ }
+#endif
+ return s;
+}
+
+bool IsEmptyPath(util::string_view v) {
+ for (const auto c : v) {
+ if (c != '/') {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace internal
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/path_util.h b/src/arrow/cpp/src/arrow/filesystem/path_util.h
new file mode 100644
index 000000000..5701c11b5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/path_util.h
@@ -0,0 +1,130 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/type_fwd.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace fs {
+namespace internal {
+
+constexpr char kSep = '/';
+
+// Computations on abstract paths (not local paths with system-dependent behaviour).
+// Abstract paths are typically used in URIs.
+
+// Split an abstract path into its individual components.
+ARROW_EXPORT
+std::vector<std::string> SplitAbstractPath(const std::string& s);
+
+// Return the extension of the file
+ARROW_EXPORT
+std::string GetAbstractPathExtension(const std::string& s);
+
+// Return the parent directory and basename of an abstract path. Both values may be
+// empty.
+ARROW_EXPORT
+std::pair<std::string, std::string> GetAbstractPathParent(const std::string& s);
+
+// Validate the components of an abstract path.
+ARROW_EXPORT
+Status ValidateAbstractPathParts(const std::vector<std::string>& parts);
+
+// Append a non-empty stem to an abstract path.
+ARROW_EXPORT
+std::string ConcatAbstractPath(const std::string& base, const std::string& stem);
+
+// Make path relative to base, if it starts with base. Otherwise error out.
+ARROW_EXPORT
+Result<std::string> MakeAbstractPathRelative(const std::string& base,
+ const std::string& path);
+
+ARROW_EXPORT
+std::string EnsureLeadingSlash(util::string_view s);
+
+ARROW_EXPORT
+util::string_view RemoveLeadingSlash(util::string_view s);
+
+ARROW_EXPORT
+std::string EnsureTrailingSlash(util::string_view s);
+
+ARROW_EXPORT
+util::string_view RemoveTrailingSlash(util::string_view s);
+
+ARROW_EXPORT
+bool IsAncestorOf(util::string_view ancestor, util::string_view descendant);
+
+ARROW_EXPORT
+util::optional<util::string_view> RemoveAncestor(util::string_view ancestor,
+ util::string_view descendant);
+
+/// Return a vector of ancestors between a base path and a descendant.
+/// For example,
+///
+/// AncestorsFromBasePath("a/b", "a/b/c/d/e") -> ["a/b/c", "a/b/c/d"]
+ARROW_EXPORT
+std::vector<std::string> AncestorsFromBasePath(util::string_view base_path,
+ util::string_view descendant);
+
+/// Given a vector of paths of directories which must be created, produce a the minimal
+/// subset for passing to CreateDir(recursive=true) by removing redundant parent
+/// directories
+ARROW_EXPORT
+std::vector<std::string> MinimalCreateDirSet(std::vector<std::string> dirs);
+
+// Join the components of an abstract path.
+template <class StringIt>
+std::string JoinAbstractPath(StringIt it, StringIt end) {
+ std::string path;
+ for (; it != end; ++it) {
+ if (it->empty()) continue;
+
+ if (!path.empty()) {
+ path += kSep;
+ }
+ path += *it;
+ }
+ return path;
+}
+
+template <class StringRange>
+std::string JoinAbstractPath(const StringRange& range) {
+ return JoinAbstractPath(range.begin(), range.end());
+}
+
+/// Convert slashes to backslashes, on all platforms. Mostly useful for testing.
+ARROW_EXPORT
+std::string ToBackslashes(util::string_view s);
+
+/// Ensure a local path is abstract, by converting backslashes to regular slashes
+/// on Windows. Return the path unchanged on other systems.
+ARROW_EXPORT
+std::string ToSlashes(util::string_view s);
+
+ARROW_EXPORT
+bool IsEmptyPath(util::string_view s);
+
+} // namespace internal
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/s3_internal.h b/src/arrow/cpp/src/arrow/filesystem/s3_internal.h
new file mode 100644
index 000000000..ceb92b554
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/s3_internal.h
@@ -0,0 +1,215 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <utility>
+
+#include <aws/core/Aws.h>
+#include <aws/core/client/RetryStrategy.h>
+#include <aws/core/http/HttpTypes.h>
+#include <aws/core/utils/DateTime.h>
+#include <aws/core/utils/StringUtils.h>
+
+#include "arrow/filesystem/filesystem.h"
+#include "arrow/filesystem/s3fs.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/print.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace fs {
+namespace internal {
+
+#define ARROW_AWS_ASSIGN_OR_RAISE_IMPL(outcome_name, lhs, rexpr) \
+ auto outcome_name = (rexpr); \
+ if (!outcome_name.IsSuccess()) { \
+ return ErrorToStatus(outcome_name.GetError()); \
+ } \
+ lhs = std::move(outcome_name).GetResultWithOwnership();
+
+#define ARROW_AWS_ASSIGN_OR_RAISE_NAME(x, y) ARROW_CONCAT(x, y)
+
+#define ARROW_AWS_ASSIGN_OR_RAISE(lhs, rexpr) \
+ ARROW_AWS_ASSIGN_OR_RAISE_IMPL( \
+ ARROW_AWS_ASSIGN_OR_RAISE_NAME(_aws_error_or_value, __COUNTER__), lhs, rexpr);
+
+// XXX Should we expose this at some point?
+enum class S3Backend { Amazon, Minio, Other };
+
+// Detect the S3 backend type from the S3 server's response headers
+S3Backend DetectS3Backend(const Aws::Http::HeaderValueCollection& headers) {
+ const auto it = headers.find("server");
+ if (it != headers.end()) {
+ const auto& value = util::string_view(it->second);
+ if (value.find("AmazonS3") != std::string::npos) {
+ return S3Backend::Amazon;
+ }
+ if (value.find("MinIO") != std::string::npos) {
+ return S3Backend::Minio;
+ }
+ }
+ return S3Backend::Other;
+}
+
+template <typename Error>
+S3Backend DetectS3Backend(const Aws::Client::AWSError<Error>& error) {
+ return DetectS3Backend(error.GetResponseHeaders());
+}
+
+template <typename Error>
+inline bool IsConnectError(const Aws::Client::AWSError<Error>& error) {
+ if (error.ShouldRetry()) {
+ return true;
+ }
+ // Sometimes Minio may fail with a 503 error
+ // (exception name: XMinioServerNotInitialized,
+ // message: "Server not initialized, please try again")
+ if (error.GetExceptionName() == "XMinioServerNotInitialized") {
+ return true;
+ }
+ return false;
+}
+
+inline bool IsNotFound(const Aws::Client::AWSError<Aws::S3::S3Errors>& error) {
+ const auto error_type = error.GetErrorType();
+ return (error_type == Aws::S3::S3Errors::NO_SUCH_BUCKET ||
+ error_type == Aws::S3::S3Errors::RESOURCE_NOT_FOUND);
+}
+
+inline bool IsAlreadyExists(const Aws::Client::AWSError<Aws::S3::S3Errors>& error) {
+ const auto error_type = error.GetErrorType();
+ return (error_type == Aws::S3::S3Errors::BUCKET_ALREADY_EXISTS ||
+ error_type == Aws::S3::S3Errors::BUCKET_ALREADY_OWNED_BY_YOU);
+}
+
+// TODO qualify error messages with a prefix indicating context
+// (e.g. "When completing multipart upload to bucket 'xxx', key 'xxx': ...")
+template <typename ErrorType>
+Status ErrorToStatus(const std::string& prefix,
+ const Aws::Client::AWSError<ErrorType>& error) {
+ // XXX Handle fine-grained error types
+ // See
+ // https://sdk.amazonaws.com/cpp/api/LATEST/namespace_aws_1_1_s3.html#ae3f82f8132b619b6e91c88a9f1bde371
+ return Status::IOError(prefix, "AWS Error [code ",
+ static_cast<int>(error.GetErrorType()),
+ "]: ", error.GetMessage());
+}
+
+template <typename ErrorType, typename... Args>
+Status ErrorToStatus(const std::tuple<Args&...>& prefix,
+ const Aws::Client::AWSError<ErrorType>& error) {
+ std::stringstream ss;
+ ::arrow::internal::PrintTuple(&ss, prefix);
+ return ErrorToStatus(ss.str(), error);
+}
+
+template <typename ErrorType>
+Status ErrorToStatus(const Aws::Client::AWSError<ErrorType>& error) {
+ return ErrorToStatus(std::string(), error);
+}
+
+template <typename AwsResult, typename Error>
+Status OutcomeToStatus(const std::string& prefix,
+ const Aws::Utils::Outcome<AwsResult, Error>& outcome) {
+ if (outcome.IsSuccess()) {
+ return Status::OK();
+ } else {
+ return ErrorToStatus(prefix, outcome.GetError());
+ }
+}
+
+template <typename AwsResult, typename Error, typename... Args>
+Status OutcomeToStatus(const std::tuple<Args&...>& prefix,
+ const Aws::Utils::Outcome<AwsResult, Error>& outcome) {
+ if (outcome.IsSuccess()) {
+ return Status::OK();
+ } else {
+ return ErrorToStatus(prefix, outcome.GetError());
+ }
+}
+
+template <typename AwsResult, typename Error>
+Status OutcomeToStatus(const Aws::Utils::Outcome<AwsResult, Error>& outcome) {
+ return OutcomeToStatus(std::string(), outcome);
+}
+
+template <typename AwsResult, typename Error>
+Result<AwsResult> OutcomeToResult(Aws::Utils::Outcome<AwsResult, Error> outcome) {
+ if (outcome.IsSuccess()) {
+ return std::move(outcome).GetResultWithOwnership();
+ } else {
+ return ErrorToStatus(outcome.GetError());
+ }
+}
+
+inline Aws::String ToAwsString(const std::string& s) {
+ // Direct construction of Aws::String from std::string doesn't work because
+ // it uses a specific Allocator class.
+ return Aws::String(s.begin(), s.end());
+}
+
+inline util::string_view FromAwsString(const Aws::String& s) {
+ return {s.data(), s.length()};
+}
+
+inline Aws::String ToURLEncodedAwsString(const std::string& s) {
+ return Aws::Utils::StringUtils::URLEncode(s.data());
+}
+
+inline TimePoint FromAwsDatetime(const Aws::Utils::DateTime& dt) {
+ return std::chrono::time_point_cast<std::chrono::nanoseconds>(dt.UnderlyingTimestamp());
+}
+
+// A connect retry strategy with a controlled max duration.
+
+class ConnectRetryStrategy : public Aws::Client::RetryStrategy {
+ public:
+ static const int32_t kDefaultRetryInterval = 200; /* milliseconds */
+ static const int32_t kDefaultMaxRetryDuration = 6000; /* milliseconds */
+
+ explicit ConnectRetryStrategy(int32_t retry_interval = kDefaultRetryInterval,
+ int32_t max_retry_duration = kDefaultMaxRetryDuration)
+ : retry_interval_(retry_interval), max_retry_duration_(max_retry_duration) {}
+
+ bool ShouldRetry(const Aws::Client::AWSError<Aws::Client::CoreErrors>& error,
+ long attempted_retries) const override { // NOLINT runtime/int
+ if (!IsConnectError(error)) {
+ // Not a connect error, don't retry
+ return false;
+ }
+ return attempted_retries * retry_interval_ < max_retry_duration_;
+ }
+
+ long CalculateDelayBeforeNextRetry( // NOLINT runtime/int
+ const Aws::Client::AWSError<Aws::Client::CoreErrors>& error,
+ long attempted_retries) const override { // NOLINT runtime/int
+ return retry_interval_;
+ }
+
+ protected:
+ int32_t retry_interval_;
+ int32_t max_retry_duration_;
+};
+
+} // namespace internal
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/s3_test_util.h b/src/arrow/cpp/src/arrow/filesystem/s3_test_util.h
new file mode 100644
index 000000000..432ff1d22
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/s3_test_util.h
@@ -0,0 +1,153 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+
+// We need BOOST_USE_WINDOWS_H definition with MinGW when we use
+// boost/process.hpp. See ARROW_BOOST_PROCESS_COMPILE_DEFINITIONS in
+// cpp/cmake_modules/BuildUtils.cmake for details.
+#include <aws/core/Aws.h>
+#include <gtest/gtest.h>
+
+#include <boost/process.hpp>
+
+#include "arrow/filesystem/s3fs.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/io_util.h"
+
+namespace arrow {
+namespace fs {
+
+using ::arrow::internal::TemporaryDir;
+
+namespace bp = boost::process;
+
+// TODO: allocate an ephemeral port
+static const char* kMinioExecutableName = "minio";
+static const char* kMinioAccessKey = "minio";
+static const char* kMinioSecretKey = "miniopass";
+
+// Environment variables to configure another S3-compatible service
+static const char* kEnvConnectString = "ARROW_TEST_S3_CONNECT_STRING";
+static const char* kEnvAccessKey = "ARROW_TEST_S3_ACCESS_KEY";
+static const char* kEnvSecretKey = "ARROW_TEST_S3_SECRET_KEY";
+
+static std::string GenerateConnectString() {
+ std::stringstream ss;
+ ss << "127.0.0.1:" << GetListenPort();
+ return ss.str();
+}
+
+// A minio test server, managed as a child process
+
+class MinioTestServer {
+ public:
+ Status Start();
+
+ Status Stop();
+
+ std::string connect_string() const { return connect_string_; }
+
+ std::string access_key() const { return access_key_; }
+
+ std::string secret_key() const { return secret_key_; }
+
+ private:
+ std::unique_ptr<TemporaryDir> temp_dir_;
+ std::string connect_string_;
+ std::string access_key_ = kMinioAccessKey;
+ std::string secret_key_ = kMinioSecretKey;
+ std::shared_ptr<::boost::process::child> server_process_;
+};
+
+Status MinioTestServer::Start() {
+ const char* connect_str = std::getenv(kEnvConnectString);
+ const char* access_key = std::getenv(kEnvAccessKey);
+ const char* secret_key = std::getenv(kEnvSecretKey);
+ if (connect_str && access_key && secret_key) {
+ // Use external instance
+ connect_string_ = connect_str;
+ access_key_ = access_key;
+ secret_key_ = secret_key;
+ return Status::OK();
+ }
+
+ ARROW_ASSIGN_OR_RAISE(temp_dir_, TemporaryDir::Make("s3fs-test-"));
+
+ // Get a copy of the current environment.
+ // (NOTE: using "auto" would return a native_environment that mutates
+ // the current environment)
+ bp::environment env = boost::this_process::environment();
+ env["MINIO_ACCESS_KEY"] = kMinioAccessKey;
+ env["MINIO_SECRET_KEY"] = kMinioSecretKey;
+
+ connect_string_ = GenerateConnectString();
+
+ auto exe_path = bp::search_path(kMinioExecutableName);
+ if (exe_path.empty()) {
+ return Status::IOError("Failed to find minio executable ('", kMinioExecutableName,
+ "') in PATH");
+ }
+
+ try {
+ // NOTE: --quiet makes startup faster by suppressing remote version check
+ server_process_ = std::make_shared<bp::child>(
+ env, exe_path, "server", "--quiet", "--compat", "--address", connect_string_,
+ temp_dir_->path().ToString());
+ } catch (const std::exception& e) {
+ return Status::IOError("Failed to launch Minio server: ", e.what());
+ }
+ return Status::OK();
+}
+
+Status MinioTestServer::Stop() {
+ if (server_process_ && server_process_->valid()) {
+ // Brutal shutdown
+ server_process_->terminate();
+ server_process_->wait();
+ }
+ return Status::OK();
+}
+
+// A global test "environment", to ensure that the S3 API is initialized before
+// running unit tests.
+
+class S3Environment : public ::testing::Environment {
+ public:
+ void SetUp() override {
+ // Change this to increase logging during tests
+ S3GlobalOptions options;
+ options.log_level = S3LogLevel::Fatal;
+ ASSERT_OK(InitializeS3(options));
+ }
+
+ void TearDown() override { ASSERT_OK(FinalizeS3()); }
+
+ protected:
+ Aws::SDKOptions options_;
+};
+
+::testing::Environment* s3_env = ::testing::AddGlobalTestEnvironment(new S3Environment);
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/s3fs.cc b/src/arrow/cpp/src/arrow/filesystem/s3fs.cc
new file mode 100644
index 000000000..49766d175
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/s3fs.cc
@@ -0,0 +1,2453 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/filesystem/s3fs.h"
+
+#include <algorithm>
+#include <atomic>
+#include <chrono>
+#include <condition_variable>
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <thread>
+#include <unordered_map>
+#include <utility>
+
+#ifdef _WIN32
+// Undefine preprocessor macros that interfere with AWS function / method names
+#ifdef GetMessage
+#undef GetMessage
+#endif
+#ifdef GetObject
+#undef GetObject
+#endif
+#endif
+
+#include <aws/core/Aws.h>
+#include <aws/core/Region.h>
+#include <aws/core/auth/AWSCredentials.h>
+#include <aws/core/auth/AWSCredentialsProviderChain.h>
+#include <aws/core/auth/STSCredentialsProvider.h>
+#include <aws/core/client/DefaultRetryStrategy.h>
+#include <aws/core/client/RetryStrategy.h>
+#include <aws/core/http/HttpResponse.h>
+#include <aws/core/utils/logging/ConsoleLogSystem.h>
+#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
+#include <aws/core/utils/xml/XmlSerializer.h>
+#include <aws/identity-management/auth/STSAssumeRoleCredentialsProvider.h>
+#include <aws/s3/S3Client.h>
+#include <aws/s3/model/AbortMultipartUploadRequest.h>
+#include <aws/s3/model/CompleteMultipartUploadRequest.h>
+#include <aws/s3/model/CompletedMultipartUpload.h>
+#include <aws/s3/model/CompletedPart.h>
+#include <aws/s3/model/CopyObjectRequest.h>
+#include <aws/s3/model/CreateBucketRequest.h>
+#include <aws/s3/model/CreateMultipartUploadRequest.h>
+#include <aws/s3/model/DeleteBucketRequest.h>
+#include <aws/s3/model/DeleteObjectRequest.h>
+#include <aws/s3/model/DeleteObjectsRequest.h>
+#include <aws/s3/model/GetObjectRequest.h>
+#include <aws/s3/model/HeadBucketRequest.h>
+#include <aws/s3/model/HeadObjectRequest.h>
+#include <aws/s3/model/ListBucketsResult.h>
+#include <aws/s3/model/ListObjectsV2Request.h>
+#include <aws/s3/model/ObjectCannedACL.h>
+#include <aws/s3/model/PutObjectRequest.h>
+#include <aws/s3/model/UploadPartRequest.h>
+
+#include "arrow/util/windows_fixup.h"
+
+#include "arrow/buffer.h"
+#include "arrow/filesystem/filesystem.h"
+#include "arrow/filesystem/path_util.h"
+#include "arrow/filesystem/s3_internal.h"
+#include "arrow/filesystem/util_internal.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/io/memory.h"
+#include "arrow/io/util_internal.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/atomic_shared_ptr.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/task_group.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::TaskGroup;
+using internal::Uri;
+using io::internal::SubmitIO;
+
+namespace fs {
+
+using ::Aws::Client::AWSError;
+using ::Aws::S3::S3Errors;
+namespace S3Model = Aws::S3::Model;
+
+using internal::ConnectRetryStrategy;
+using internal::DetectS3Backend;
+using internal::ErrorToStatus;
+using internal::FromAwsDatetime;
+using internal::FromAwsString;
+using internal::IsAlreadyExists;
+using internal::IsNotFound;
+using internal::OutcomeToResult;
+using internal::OutcomeToStatus;
+using internal::S3Backend;
+using internal::ToAwsString;
+using internal::ToURLEncodedAwsString;
+
+static const char kSep = '/';
+
+namespace {
+
+std::mutex aws_init_lock;
+Aws::SDKOptions aws_options;
+std::atomic<bool> aws_initialized(false);
+
+Status DoInitializeS3(const S3GlobalOptions& options) {
+ Aws::Utils::Logging::LogLevel aws_log_level;
+
+#define LOG_LEVEL_CASE(level_name) \
+ case S3LogLevel::level_name: \
+ aws_log_level = Aws::Utils::Logging::LogLevel::level_name; \
+ break;
+
+ switch (options.log_level) {
+ LOG_LEVEL_CASE(Fatal)
+ LOG_LEVEL_CASE(Error)
+ LOG_LEVEL_CASE(Warn)
+ LOG_LEVEL_CASE(Info)
+ LOG_LEVEL_CASE(Debug)
+ LOG_LEVEL_CASE(Trace)
+ default:
+ aws_log_level = Aws::Utils::Logging::LogLevel::Off;
+ }
+
+#undef LOG_LEVEL_CASE
+
+ aws_options.loggingOptions.logLevel = aws_log_level;
+ // By default the AWS SDK logs to files, log to console instead
+ aws_options.loggingOptions.logger_create_fn = [] {
+ return std::make_shared<Aws::Utils::Logging::ConsoleLogSystem>(
+ aws_options.loggingOptions.logLevel);
+ };
+ Aws::InitAPI(aws_options);
+ aws_initialized.store(true);
+ return Status::OK();
+}
+
+} // namespace
+
+Status InitializeS3(const S3GlobalOptions& options) {
+ std::lock_guard<std::mutex> lock(aws_init_lock);
+ return DoInitializeS3(options);
+}
+
+Status FinalizeS3() {
+ std::lock_guard<std::mutex> lock(aws_init_lock);
+ Aws::ShutdownAPI(aws_options);
+ aws_initialized.store(false);
+ return Status::OK();
+}
+
+Status EnsureS3Initialized() {
+ std::lock_guard<std::mutex> lock(aws_init_lock);
+ if (!aws_initialized.load()) {
+ S3GlobalOptions options{S3LogLevel::Fatal};
+ return DoInitializeS3(options);
+ }
+ return Status::OK();
+}
+
+// -----------------------------------------------------------------------
+// S3ProxyOptions implementation
+
+Result<S3ProxyOptions> S3ProxyOptions::FromUri(const Uri& uri) {
+ S3ProxyOptions options;
+
+ options.scheme = uri.scheme();
+ options.host = uri.host();
+ options.port = uri.port();
+ options.username = uri.username();
+ options.password = uri.password();
+
+ return options;
+}
+
+Result<S3ProxyOptions> S3ProxyOptions::FromUri(const std::string& uri_string) {
+ Uri uri;
+ RETURN_NOT_OK(uri.Parse(uri_string));
+ return FromUri(uri);
+}
+
+bool S3ProxyOptions::Equals(const S3ProxyOptions& other) const {
+ return (scheme == other.scheme && host == other.host && port == other.port &&
+ username == other.username && password == other.password);
+}
+
+// -----------------------------------------------------------------------
+// S3Options implementation
+
+void S3Options::ConfigureDefaultCredentials() {
+ credentials_provider =
+ std::make_shared<Aws::Auth::DefaultAWSCredentialsProviderChain>();
+ credentials_kind = S3CredentialsKind::Default;
+}
+
+void S3Options::ConfigureAnonymousCredentials() {
+ credentials_provider = std::make_shared<Aws::Auth::AnonymousAWSCredentialsProvider>();
+ credentials_kind = S3CredentialsKind::Anonymous;
+}
+
+void S3Options::ConfigureAccessKey(const std::string& access_key,
+ const std::string& secret_key,
+ const std::string& session_token) {
+ credentials_provider = std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(
+ ToAwsString(access_key), ToAwsString(secret_key), ToAwsString(session_token));
+ credentials_kind = S3CredentialsKind::Explicit;
+}
+
+void S3Options::ConfigureAssumeRoleCredentials(
+ const std::string& role_arn, const std::string& session_name,
+ const std::string& external_id, int load_frequency,
+ const std::shared_ptr<Aws::STS::STSClient>& stsClient) {
+ credentials_provider = std::make_shared<Aws::Auth::STSAssumeRoleCredentialsProvider>(
+ ToAwsString(role_arn), ToAwsString(session_name), ToAwsString(external_id),
+ load_frequency, stsClient);
+ credentials_kind = S3CredentialsKind::Role;
+}
+
+void S3Options::ConfigureAssumeRoleWithWebIdentityCredentials() {
+ // The AWS SDK uses environment variables AWS_DEFAULT_REGION,
+ // AWS_ROLE_ARN, AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_SESSION_NAME
+ // to configure the required credentials
+ credentials_provider =
+ std::make_shared<Aws::Auth::STSAssumeRoleWebIdentityCredentialsProvider>();
+ credentials_kind = S3CredentialsKind::WebIdentity;
+}
+
+std::string S3Options::GetAccessKey() const {
+ auto credentials = credentials_provider->GetAWSCredentials();
+ return std::string(FromAwsString(credentials.GetAWSAccessKeyId()));
+}
+
+std::string S3Options::GetSecretKey() const {
+ auto credentials = credentials_provider->GetAWSCredentials();
+ return std::string(FromAwsString(credentials.GetAWSSecretKey()));
+}
+
+std::string S3Options::GetSessionToken() const {
+ auto credentials = credentials_provider->GetAWSCredentials();
+ return std::string(FromAwsString(credentials.GetSessionToken()));
+}
+
+S3Options S3Options::Defaults() {
+ S3Options options;
+ options.ConfigureDefaultCredentials();
+ return options;
+}
+
+S3Options S3Options::Anonymous() {
+ S3Options options;
+ options.ConfigureAnonymousCredentials();
+ return options;
+}
+
+S3Options S3Options::FromAccessKey(const std::string& access_key,
+ const std::string& secret_key,
+ const std::string& session_token) {
+ S3Options options;
+ options.ConfigureAccessKey(access_key, secret_key, session_token);
+ return options;
+}
+
+S3Options S3Options::FromAssumeRole(
+ const std::string& role_arn, const std::string& session_name,
+ const std::string& external_id, int load_frequency,
+ const std::shared_ptr<Aws::STS::STSClient>& stsClient) {
+ S3Options options;
+ options.role_arn = role_arn;
+ options.session_name = session_name;
+ options.external_id = external_id;
+ options.load_frequency = load_frequency;
+ options.ConfigureAssumeRoleCredentials(role_arn, session_name, external_id,
+ load_frequency, stsClient);
+ return options;
+}
+
+S3Options S3Options::FromAssumeRoleWithWebIdentity() {
+ S3Options options;
+ options.ConfigureAssumeRoleWithWebIdentityCredentials();
+ return options;
+}
+
+Result<S3Options> S3Options::FromUri(const Uri& uri, std::string* out_path) {
+ S3Options options;
+
+ const auto bucket = uri.host();
+ auto path = uri.path();
+ if (bucket.empty()) {
+ if (!path.empty()) {
+ return Status::Invalid("Missing bucket name in S3 URI");
+ }
+ } else {
+ if (path.empty()) {
+ path = bucket;
+ } else {
+ if (path[0] != '/') {
+ return Status::Invalid("S3 URI should absolute, not relative");
+ }
+ path = bucket + path;
+ }
+ }
+ if (out_path != nullptr) {
+ *out_path = std::string(internal::RemoveTrailingSlash(path));
+ }
+
+ std::unordered_map<std::string, std::string> options_map;
+ ARROW_ASSIGN_OR_RAISE(const auto options_items, uri.query_items());
+ for (const auto& kv : options_items) {
+ options_map.emplace(kv.first, kv.second);
+ }
+
+ const auto username = uri.username();
+ if (!username.empty()) {
+ options.ConfigureAccessKey(username, uri.password());
+ } else {
+ options.ConfigureDefaultCredentials();
+ }
+
+ bool region_set = false;
+ for (const auto& kv : options_map) {
+ if (kv.first == "region") {
+ options.region = kv.second;
+ region_set = true;
+ } else if (kv.first == "scheme") {
+ options.scheme = kv.second;
+ } else if (kv.first == "endpoint_override") {
+ options.endpoint_override = kv.second;
+ } else {
+ return Status::Invalid("Unexpected query parameter in S3 URI: '", kv.first, "'");
+ }
+ }
+
+ if (!region_set && !bucket.empty() && options.endpoint_override.empty()) {
+ // XXX Should we use a dedicated resolver with the given credentials?
+ ARROW_ASSIGN_OR_RAISE(options.region, ResolveBucketRegion(bucket));
+ }
+
+ return options;
+}
+
+Result<S3Options> S3Options::FromUri(const std::string& uri_string,
+ std::string* out_path) {
+ Uri uri;
+ RETURN_NOT_OK(uri.Parse(uri_string));
+ return FromUri(uri, out_path);
+}
+
+bool S3Options::Equals(const S3Options& other) const {
+ return (region == other.region && endpoint_override == other.endpoint_override &&
+ scheme == other.scheme && background_writes == other.background_writes &&
+ credentials_kind == other.credentials_kind &&
+ proxy_options.Equals(other.proxy_options) &&
+ GetAccessKey() == other.GetAccessKey() &&
+ GetSecretKey() == other.GetSecretKey() &&
+ GetSessionToken() == other.GetSessionToken());
+}
+
+namespace {
+
+Status CheckS3Initialized() {
+ if (!aws_initialized.load()) {
+ return Status::Invalid(
+ "S3 subsystem not initialized; please call InitializeS3() "
+ "before carrying out any S3-related operation");
+ }
+ return Status::OK();
+}
+
+// XXX Sanitize paths by removing leading slash?
+
+struct S3Path {
+ std::string full_path;
+ std::string bucket;
+ std::string key;
+ std::vector<std::string> key_parts;
+
+ static Result<S3Path> FromString(const std::string& s) {
+ const auto src = internal::RemoveTrailingSlash(s);
+ auto first_sep = src.find_first_of(kSep);
+ if (first_sep == 0) {
+ return Status::Invalid("Path cannot start with a separator ('", s, "')");
+ }
+ if (first_sep == std::string::npos) {
+ return S3Path{std::string(src), std::string(src), "", {}};
+ }
+ S3Path path;
+ path.full_path = std::string(src);
+ path.bucket = std::string(src.substr(0, first_sep));
+ path.key = std::string(src.substr(first_sep + 1));
+ path.key_parts = internal::SplitAbstractPath(path.key);
+ RETURN_NOT_OK(Validate(&path));
+ return path;
+ }
+
+ static Status Validate(const S3Path* path) {
+ auto result = internal::ValidateAbstractPathParts(path->key_parts);
+ if (!result.ok()) {
+ return Status::Invalid(result.message(), " in path ", path->full_path);
+ } else {
+ return result;
+ }
+ }
+
+ Aws::String ToAwsString() const {
+ Aws::String res(bucket.begin(), bucket.end());
+ res.reserve(bucket.size() + key.size() + 1);
+ res += kSep;
+ res.append(key.begin(), key.end());
+ return res;
+ }
+
+ Aws::String ToURLEncodedAwsString() const {
+ // URL-encode individual parts, not the '/' separator
+ Aws::String res;
+ res += internal::ToURLEncodedAwsString(bucket);
+ for (const auto& part : key_parts) {
+ res += kSep;
+ res += internal::ToURLEncodedAwsString(part);
+ }
+ return res;
+ }
+
+ S3Path parent() const {
+ DCHECK(!key_parts.empty());
+ auto parent = S3Path{"", bucket, "", key_parts};
+ parent.key_parts.pop_back();
+ parent.key = internal::JoinAbstractPath(parent.key_parts);
+ parent.full_path = parent.bucket + kSep + parent.key;
+ return parent;
+ }
+
+ bool has_parent() const { return !key.empty(); }
+
+ bool empty() const { return bucket.empty() && key.empty(); }
+
+ bool operator==(const S3Path& other) const {
+ return bucket == other.bucket && key == other.key;
+ }
+};
+
+// XXX return in OutcomeToStatus instead?
+Status PathNotFound(const S3Path& path) {
+ return ::arrow::fs::internal::PathNotFound(path.full_path);
+}
+
+Status PathNotFound(const std::string& bucket, const std::string& key) {
+ return ::arrow::fs::internal::PathNotFound(bucket + kSep + key);
+}
+
+Status NotAFile(const S3Path& path) {
+ return ::arrow::fs::internal::NotAFile(path.full_path);
+}
+
+Status ValidateFilePath(const S3Path& path) {
+ if (path.bucket.empty() || path.key.empty()) {
+ return NotAFile(path);
+ }
+ return Status::OK();
+}
+
+std::string FormatRange(int64_t start, int64_t length) {
+ // Format a HTTP range header value
+ std::stringstream ss;
+ ss << "bytes=" << start << "-" << start + length - 1;
+ return ss.str();
+}
+
+// An AWS RetryStrategy that wraps a provided arrow::fs::S3RetryStrategy
+class WrappedRetryStrategy : public Aws::Client::RetryStrategy {
+ public:
+ explicit WrappedRetryStrategy(const std::shared_ptr<S3RetryStrategy>& s3_retry_strategy)
+ : s3_retry_strategy_(s3_retry_strategy) {}
+
+ bool ShouldRetry(const Aws::Client::AWSError<Aws::Client::CoreErrors>& error,
+ long attempted_retries) const override { // NOLINT runtime/int
+ S3RetryStrategy::AWSErrorDetail detail = ErrorToDetail(error);
+ return s3_retry_strategy_->ShouldRetry(detail,
+ static_cast<int64_t>(attempted_retries));
+ }
+
+ long CalculateDelayBeforeNextRetry( // NOLINT runtime/int
+ const Aws::Client::AWSError<Aws::Client::CoreErrors>& error,
+ long attempted_retries) const override { // NOLINT runtime/int
+ S3RetryStrategy::AWSErrorDetail detail = ErrorToDetail(error);
+ return static_cast<long>( // NOLINT runtime/int
+ s3_retry_strategy_->CalculateDelayBeforeNextRetry(
+ detail, static_cast<int64_t>(attempted_retries)));
+ }
+
+ private:
+ template <typename ErrorType>
+ static S3RetryStrategy::AWSErrorDetail ErrorToDetail(
+ const Aws::Client::AWSError<ErrorType>& error) {
+ S3RetryStrategy::AWSErrorDetail detail;
+ detail.error_type = static_cast<int>(error.GetErrorType());
+ detail.message = std::string(FromAwsString(error.GetMessage()));
+ detail.exception_name = std::string(FromAwsString(error.GetExceptionName()));
+ detail.should_retry = error.ShouldRetry();
+ return detail;
+ }
+
+ std::shared_ptr<S3RetryStrategy> s3_retry_strategy_;
+};
+
+class S3Client : public Aws::S3::S3Client {
+ public:
+ using Aws::S3::S3Client::S3Client;
+
+ // To get a bucket's region, we must extract the "x-amz-bucket-region" header
+ // from the response to a HEAD bucket request.
+ // Unfortunately, the S3Client APIs don't let us access the headers of successful
+ // responses. So we have to cook a AWS request and issue it ourselves.
+
+ Result<std::string> GetBucketRegion(const S3Model::HeadBucketRequest& request) {
+ auto uri = GeneratePresignedUrl(request.GetBucket(),
+ /*key=*/"", Aws::Http::HttpMethod::HTTP_HEAD);
+ // NOTE: The signer region argument isn't passed here, as there's no easy
+ // way of computing it (the relevant method is private).
+ auto outcome = MakeRequest(uri, request, Aws::Http::HttpMethod::HTTP_HEAD,
+ Aws::Auth::SIGV4_SIGNER);
+ const auto code = outcome.IsSuccess() ? outcome.GetResult().GetResponseCode()
+ : outcome.GetError().GetResponseCode();
+ const auto& headers = outcome.IsSuccess()
+ ? outcome.GetResult().GetHeaderValueCollection()
+ : outcome.GetError().GetResponseHeaders();
+
+ const auto it = headers.find(ToAwsString("x-amz-bucket-region"));
+ if (it == headers.end()) {
+ if (code == Aws::Http::HttpResponseCode::NOT_FOUND) {
+ return Status::IOError("Bucket '", request.GetBucket(), "' not found");
+ } else if (!outcome.IsSuccess()) {
+ return ErrorToStatus(std::forward_as_tuple("When resolving region for bucket '",
+ request.GetBucket(), "': "),
+ outcome.GetError());
+ } else {
+ return Status::IOError("When resolving region for bucket '", request.GetBucket(),
+ "': missing 'x-amz-bucket-region' header in response");
+ }
+ }
+ return std::string(FromAwsString(it->second));
+ }
+
+ Result<std::string> GetBucketRegion(const std::string& bucket) {
+ S3Model::HeadBucketRequest req;
+ req.SetBucket(ToAwsString(bucket));
+ return GetBucketRegion(req);
+ }
+
+ S3Model::CompleteMultipartUploadOutcome CompleteMultipartUploadWithErrorFixup(
+ S3Model::CompleteMultipartUploadRequest&& request) const {
+ // CompletedMultipartUpload can return a 200 OK response with an error
+ // encoded in the response body, in which case we should either retry
+ // or propagate the error to the user (see
+ // https://docs.aws.amazon.com/AmazonS3/latest/API/API_CompleteMultipartUpload.html).
+ //
+ // Unfortunately the AWS SDK doesn't detect such situations but lets them
+ // return successfully (see https://github.com/aws/aws-sdk-cpp/issues/658).
+ //
+ // We work around the issue by registering a DataReceivedEventHandler
+ // which parses the XML response for embedded errors.
+
+ util::optional<AWSError<Aws::Client::CoreErrors>> aws_error;
+
+ auto handler = [&](const Aws::Http::HttpRequest* http_req,
+ Aws::Http::HttpResponse* http_resp,
+ long long) { // NOLINT runtime/int
+ auto& stream = http_resp->GetResponseBody();
+ const auto pos = stream.tellg();
+ const auto doc = Aws::Utils::Xml::XmlDocument::CreateFromXmlStream(stream);
+ // Rewind stream for later
+ stream.clear();
+ stream.seekg(pos);
+
+ if (doc.WasParseSuccessful()) {
+ auto root = doc.GetRootElement();
+ if (!root.IsNull()) {
+ // Detect something that looks like an abnormal CompletedMultipartUpload
+ // response.
+ if (root.GetName() != "CompleteMultipartUploadResult" ||
+ !root.FirstChild("Error").IsNull() || !root.FirstChild("Errors").IsNull()) {
+ // Make sure the error marshaller doesn't see a 200 OK
+ http_resp->SetResponseCode(
+ Aws::Http::HttpResponseCode::INTERNAL_SERVER_ERROR);
+ aws_error = GetErrorMarshaller()->Marshall(*http_resp);
+ // Rewind stream for later
+ stream.clear();
+ stream.seekg(pos);
+ }
+ }
+ }
+ };
+
+ request.SetDataReceivedEventHandler(std::move(handler));
+
+ // We don't have access to the configured AWS retry strategy
+ // (m_retryStrategy is a private member of AwsClient), so don't use that.
+ std::unique_ptr<Aws::Client::RetryStrategy> retry_strategy;
+ if (s3_retry_strategy_) {
+ retry_strategy.reset(new WrappedRetryStrategy(s3_retry_strategy_));
+ } else {
+ // Note that DefaultRetryStrategy, unlike StandardRetryStrategy,
+ // has empty definitions for RequestBookkeeping() and GetSendToken(),
+ // which simplifies the code below.
+ retry_strategy.reset(new Aws::Client::DefaultRetryStrategy());
+ }
+
+ for (int32_t retries = 0;; retries++) {
+ aws_error.reset();
+ auto outcome = Aws::S3::S3Client::S3Client::CompleteMultipartUpload(request);
+ if (!outcome.IsSuccess()) {
+ // Error returned in HTTP headers (or client failure)
+ return outcome;
+ }
+ if (!aws_error.has_value()) {
+ // Genuinely successful outcome
+ return outcome;
+ }
+
+ const bool should_retry = retry_strategy->ShouldRetry(*aws_error, retries);
+
+ ARROW_LOG(WARNING)
+ << "CompletedMultipartUpload got error embedded in a 200 OK response: "
+ << aws_error->GetExceptionName() << " (\"" << aws_error->GetMessage()
+ << "\"), retry = " << should_retry;
+
+ if (!should_retry) {
+ break;
+ }
+ const auto delay = std::chrono::milliseconds(
+ retry_strategy->CalculateDelayBeforeNextRetry(*aws_error, retries));
+ std::this_thread::sleep_for(delay);
+ }
+
+ DCHECK(aws_error.has_value());
+ auto s3_error = AWSError<S3Errors>(std::move(aws_error).value());
+ return S3Model::CompleteMultipartUploadOutcome(std::move(s3_error));
+ }
+
+ std::shared_ptr<S3RetryStrategy> s3_retry_strategy_;
+};
+
+// In AWS SDK < 1.8, Aws::Client::ClientConfiguration::followRedirects is a bool.
+template <bool Never = false>
+void DisableRedirectsImpl(bool* followRedirects) {
+ *followRedirects = false;
+}
+
+// In AWS SDK >= 1.8, it's a Aws::Client::FollowRedirectsPolicy scoped enum.
+template <typename PolicyEnum, PolicyEnum Never = PolicyEnum::NEVER>
+void DisableRedirectsImpl(PolicyEnum* followRedirects) {
+ *followRedirects = Never;
+}
+
+void DisableRedirects(Aws::Client::ClientConfiguration* c) {
+ DisableRedirectsImpl(&c->followRedirects);
+}
+
+class ClientBuilder {
+ public:
+ explicit ClientBuilder(S3Options options) : options_(std::move(options)) {}
+
+ const Aws::Client::ClientConfiguration& config() const { return client_config_; }
+
+ Aws::Client::ClientConfiguration* mutable_config() { return &client_config_; }
+
+ Result<std::shared_ptr<S3Client>> BuildClient() {
+ credentials_provider_ = options_.credentials_provider;
+ if (!options_.region.empty()) {
+ client_config_.region = ToAwsString(options_.region);
+ }
+ client_config_.endpointOverride = ToAwsString(options_.endpoint_override);
+ if (options_.scheme == "http") {
+ client_config_.scheme = Aws::Http::Scheme::HTTP;
+ } else if (options_.scheme == "https") {
+ client_config_.scheme = Aws::Http::Scheme::HTTPS;
+ } else {
+ return Status::Invalid("Invalid S3 connection scheme '", options_.scheme, "'");
+ }
+ if (options_.retry_strategy) {
+ client_config_.retryStrategy =
+ std::make_shared<WrappedRetryStrategy>(options_.retry_strategy);
+ } else {
+ client_config_.retryStrategy = std::make_shared<ConnectRetryStrategy>();
+ }
+ if (!internal::global_options.tls_ca_file_path.empty()) {
+ client_config_.caFile = ToAwsString(internal::global_options.tls_ca_file_path);
+ }
+ if (!internal::global_options.tls_ca_dir_path.empty()) {
+ client_config_.caPath = ToAwsString(internal::global_options.tls_ca_dir_path);
+ }
+
+ const bool use_virtual_addressing = options_.endpoint_override.empty();
+
+ // Set proxy options if provided
+ if (!options_.proxy_options.scheme.empty()) {
+ if (options_.proxy_options.scheme == "http") {
+ client_config_.proxyScheme = Aws::Http::Scheme::HTTP;
+ } else if (options_.proxy_options.scheme == "https") {
+ client_config_.proxyScheme = Aws::Http::Scheme::HTTPS;
+ } else {
+ return Status::Invalid("Invalid proxy connection scheme '",
+ options_.proxy_options.scheme, "'");
+ }
+ }
+ if (!options_.proxy_options.host.empty()) {
+ client_config_.proxyHost = ToAwsString(options_.proxy_options.host);
+ }
+ if (options_.proxy_options.port != -1) {
+ client_config_.proxyPort = options_.proxy_options.port;
+ }
+ if (!options_.proxy_options.username.empty()) {
+ client_config_.proxyUserName = ToAwsString(options_.proxy_options.username);
+ }
+ if (!options_.proxy_options.password.empty()) {
+ client_config_.proxyPassword = ToAwsString(options_.proxy_options.password);
+ }
+
+ auto client = std::make_shared<S3Client>(
+ credentials_provider_, client_config_,
+ Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
+ use_virtual_addressing);
+ client->s3_retry_strategy_ = options_.retry_strategy;
+ return client;
+ }
+
+ const S3Options& options() const { return options_; }
+
+ protected:
+ S3Options options_;
+ Aws::Client::ClientConfiguration client_config_;
+ std::shared_ptr<Aws::Auth::AWSCredentialsProvider> credentials_provider_;
+};
+
+// -----------------------------------------------------------------------
+// S3 region resolver
+
+class RegionResolver {
+ public:
+ static Result<std::shared_ptr<RegionResolver>> Make(S3Options options) {
+ std::shared_ptr<RegionResolver> resolver(new RegionResolver(std::move(options)));
+ RETURN_NOT_OK(resolver->Init());
+ return resolver;
+ }
+
+ static Result<std::shared_ptr<RegionResolver>> DefaultInstance() {
+ static std::shared_ptr<RegionResolver> instance;
+ auto resolver = arrow::internal::atomic_load(&instance);
+ if (resolver) {
+ return resolver;
+ }
+ auto maybe_resolver = Make(S3Options::Anonymous());
+ if (!maybe_resolver.ok()) {
+ return maybe_resolver;
+ }
+ // Make sure to always return the same instance even if several threads
+ // call DefaultInstance at once.
+ std::shared_ptr<RegionResolver> existing;
+ if (arrow::internal::atomic_compare_exchange_strong(&instance, &existing,
+ *maybe_resolver)) {
+ return *maybe_resolver;
+ } else {
+ return existing;
+ }
+ }
+
+ Result<std::string> ResolveRegion(const std::string& bucket) {
+ std::unique_lock<std::mutex> lock(cache_mutex_);
+ auto it = cache_.find(bucket);
+ if (it != cache_.end()) {
+ return it->second;
+ }
+ lock.unlock();
+ ARROW_ASSIGN_OR_RAISE(auto region, ResolveRegionUncached(bucket));
+ lock.lock();
+ // Note we don't cache a non-existent bucket, as the bucket could be created later
+ cache_[bucket] = region;
+ return region;
+ }
+
+ Result<std::string> ResolveRegionUncached(const std::string& bucket) {
+ return client_->GetBucketRegion(bucket);
+ }
+
+ protected:
+ explicit RegionResolver(S3Options options) : builder_(std::move(options)) {}
+
+ Status Init() {
+ DCHECK(builder_.options().endpoint_override.empty());
+ // On Windows with AWS SDK >= 1.8, it is necessary to disable redirects (ARROW-10085).
+ DisableRedirects(builder_.mutable_config());
+ return builder_.BuildClient().Value(&client_);
+ }
+
+ ClientBuilder builder_;
+ std::shared_ptr<S3Client> client_;
+
+ std::mutex cache_mutex_;
+ // XXX Should cache size be bounded? It must be quite unusual to query millions
+ // of different buckets in a single program invocation...
+ std::unordered_map<std::string, std::string> cache_;
+};
+
+// -----------------------------------------------------------------------
+// S3 file stream implementations
+
+// A non-copying iostream.
+// See https://stackoverflow.com/questions/35322033/aws-c-sdk-uploadpart-times-out
+// https://stackoverflow.com/questions/13059091/creating-an-input-stream-from-constant-memory
+class StringViewStream : Aws::Utils::Stream::PreallocatedStreamBuf, public std::iostream {
+ public:
+ StringViewStream(const void* data, int64_t nbytes)
+ : Aws::Utils::Stream::PreallocatedStreamBuf(
+ reinterpret_cast<unsigned char*>(const_cast<void*>(data)),
+ static_cast<size_t>(nbytes)),
+ std::iostream(this) {}
+};
+
+// By default, the AWS SDK reads object data into an auto-growing StringStream.
+// To avoid copies, read directly into our preallocated buffer instead.
+// See https://github.com/aws/aws-sdk-cpp/issues/64 for an alternative but
+// functionally similar recipe.
+Aws::IOStreamFactory AwsWriteableStreamFactory(void* data, int64_t nbytes) {
+ return [=]() { return Aws::New<StringViewStream>("", data, nbytes); };
+}
+
+Result<S3Model::GetObjectResult> GetObjectRange(Aws::S3::S3Client* client,
+ const S3Path& path, int64_t start,
+ int64_t length, void* out) {
+ S3Model::GetObjectRequest req;
+ req.SetBucket(ToAwsString(path.bucket));
+ req.SetKey(ToAwsString(path.key));
+ req.SetRange(ToAwsString(FormatRange(start, length)));
+ req.SetResponseStreamFactory(AwsWriteableStreamFactory(out, length));
+ return OutcomeToResult(client->GetObject(req));
+}
+
+template <typename ObjectResult>
+std::shared_ptr<const KeyValueMetadata> GetObjectMetadata(const ObjectResult& result) {
+ auto md = std::make_shared<KeyValueMetadata>();
+
+ auto push = [&](std::string k, const Aws::String& v) {
+ if (!v.empty()) {
+ md->Append(std::move(k), FromAwsString(v).to_string());
+ }
+ };
+ auto push_datetime = [&](std::string k, const Aws::Utils::DateTime& v) {
+ if (v != Aws::Utils::DateTime(0.0)) {
+ push(std::move(k), v.ToGmtString(Aws::Utils::DateFormat::ISO_8601));
+ }
+ };
+
+ md->Append("Content-Length", std::to_string(result.GetContentLength()));
+ push("Cache-Control", result.GetCacheControl());
+ push("Content-Type", result.GetContentType());
+ push("Content-Language", result.GetContentLanguage());
+ push("ETag", result.GetETag());
+ push("VersionId", result.GetVersionId());
+ push_datetime("Last-Modified", result.GetLastModified());
+ push_datetime("Expires", result.GetExpires());
+ // NOTE the "canned ACL" isn't available for reading (one can get an expanded
+ // ACL using a separate GetObjectAcl request)
+ return md;
+}
+
+template <typename ObjectRequest>
+struct ObjectMetadataSetter {
+ using Setter = std::function<Status(const std::string& value, ObjectRequest* req)>;
+
+ static std::unordered_map<std::string, Setter> GetSetters() {
+ return {{"ACL", CannedACLSetter()},
+ {"Cache-Control", StringSetter(&ObjectRequest::SetCacheControl)},
+ {"Content-Type", StringSetter(&ObjectRequest::SetContentType)},
+ {"Content-Language", StringSetter(&ObjectRequest::SetContentLanguage)},
+ {"Expires", DateTimeSetter(&ObjectRequest::SetExpires)}};
+ }
+
+ private:
+ static Setter StringSetter(void (ObjectRequest::*req_method)(Aws::String&&)) {
+ return [req_method](const std::string& v, ObjectRequest* req) {
+ (req->*req_method)(ToAwsString(v));
+ return Status::OK();
+ };
+ }
+
+ static Setter DateTimeSetter(
+ void (ObjectRequest::*req_method)(Aws::Utils::DateTime&&)) {
+ return [req_method](const std::string& v, ObjectRequest* req) {
+ (req->*req_method)(
+ Aws::Utils::DateTime(v.data(), Aws::Utils::DateFormat::ISO_8601));
+ return Status::OK();
+ };
+ }
+
+ static Setter CannedACLSetter() {
+ return [](const std::string& v, ObjectRequest* req) {
+ ARROW_ASSIGN_OR_RAISE(auto acl, ParseACL(v));
+ req->SetACL(acl);
+ return Status::OK();
+ };
+ }
+
+ static Result<S3Model::ObjectCannedACL> ParseACL(const std::string& v) {
+ if (v.empty()) {
+ return S3Model::ObjectCannedACL::NOT_SET;
+ }
+ auto acl = S3Model::ObjectCannedACLMapper::GetObjectCannedACLForName(ToAwsString(v));
+ if (acl == S3Model::ObjectCannedACL::NOT_SET) {
+ // XXX This actually never happens, as the AWS SDK dynamically
+ // expands the enum range using Aws::GetEnumOverflowContainer()
+ return Status::Invalid("Invalid S3 canned ACL: '", v, "'");
+ }
+ return acl;
+ }
+};
+
+template <typename ObjectRequest>
+Status SetObjectMetadata(const std::shared_ptr<const KeyValueMetadata>& metadata,
+ ObjectRequest* req) {
+ static auto setters = ObjectMetadataSetter<ObjectRequest>::GetSetters();
+
+ DCHECK_NE(metadata, nullptr);
+ const auto& keys = metadata->keys();
+ const auto& values = metadata->values();
+
+ for (size_t i = 0; i < keys.size(); ++i) {
+ auto it = setters.find(keys[i]);
+ if (it != setters.end()) {
+ RETURN_NOT_OK(it->second(values[i], req));
+ }
+ }
+ return Status::OK();
+}
+
+// A RandomAccessFile that reads from a S3 object
+class ObjectInputFile final : public io::RandomAccessFile {
+ public:
+ ObjectInputFile(std::shared_ptr<Aws::S3::S3Client> client,
+ const io::IOContext& io_context, const S3Path& path,
+ int64_t size = kNoSize)
+ : client_(std::move(client)),
+ io_context_(io_context),
+ path_(path),
+ content_length_(size) {}
+
+ Status Init() {
+ // Issue a HEAD Object to get the content-length and ensure any
+ // errors (e.g. file not found) don't wait until the first Read() call.
+ if (content_length_ != kNoSize) {
+ DCHECK_GE(content_length_, 0);
+ return Status::OK();
+ }
+
+ S3Model::HeadObjectRequest req;
+ req.SetBucket(ToAwsString(path_.bucket));
+ req.SetKey(ToAwsString(path_.key));
+
+ auto outcome = client_->HeadObject(req);
+ if (!outcome.IsSuccess()) {
+ if (IsNotFound(outcome.GetError())) {
+ return PathNotFound(path_);
+ } else {
+ return ErrorToStatus(
+ std::forward_as_tuple("When reading information for key '", path_.key,
+ "' in bucket '", path_.bucket, "': "),
+ outcome.GetError());
+ }
+ }
+ content_length_ = outcome.GetResult().GetContentLength();
+ DCHECK_GE(content_length_, 0);
+ metadata_ = GetObjectMetadata(outcome.GetResult());
+ return Status::OK();
+ }
+
+ Status CheckClosed() const {
+ if (closed_) {
+ return Status::Invalid("Operation on closed stream");
+ }
+ return Status::OK();
+ }
+
+ Status CheckPosition(int64_t position, const char* action) const {
+ if (position < 0) {
+ return Status::Invalid("Cannot ", action, " from negative position");
+ }
+ if (position > content_length_) {
+ return Status::IOError("Cannot ", action, " past end of file");
+ }
+ return Status::OK();
+ }
+
+ // RandomAccessFile APIs
+
+ Result<std::shared_ptr<const KeyValueMetadata>> ReadMetadata() override {
+ return metadata_;
+ }
+
+ Future<std::shared_ptr<const KeyValueMetadata>> ReadMetadataAsync(
+ const io::IOContext& io_context) override {
+ return metadata_;
+ }
+
+ Status Close() override {
+ client_ = nullptr;
+ closed_ = true;
+ return Status::OK();
+ }
+
+ bool closed() const override { return closed_; }
+
+ Result<int64_t> Tell() const override {
+ RETURN_NOT_OK(CheckClosed());
+ return pos_;
+ }
+
+ Result<int64_t> GetSize() override {
+ RETURN_NOT_OK(CheckClosed());
+ return content_length_;
+ }
+
+ Status Seek(int64_t position) override {
+ RETURN_NOT_OK(CheckClosed());
+ RETURN_NOT_OK(CheckPosition(position, "seek"));
+
+ pos_ = position;
+ return Status::OK();
+ }
+
+ Result<int64_t> ReadAt(int64_t position, int64_t nbytes, void* out) override {
+ RETURN_NOT_OK(CheckClosed());
+ RETURN_NOT_OK(CheckPosition(position, "read"));
+
+ nbytes = std::min(nbytes, content_length_ - position);
+ if (nbytes == 0) {
+ return 0;
+ }
+
+ // Read the desired range of bytes
+ ARROW_ASSIGN_OR_RAISE(S3Model::GetObjectResult result,
+ GetObjectRange(client_.get(), path_, position, nbytes, out));
+
+ auto& stream = result.GetBody();
+ stream.ignore(nbytes);
+ // NOTE: the stream is a stringstream by default, there is no actual error
+ // to check for. However, stream.fail() may return true if EOF is reached.
+ return stream.gcount();
+ }
+
+ Result<std::shared_ptr<Buffer>> ReadAt(int64_t position, int64_t nbytes) override {
+ RETURN_NOT_OK(CheckClosed());
+ RETURN_NOT_OK(CheckPosition(position, "read"));
+
+ // No need to allocate more than the remaining number of bytes
+ nbytes = std::min(nbytes, content_length_ - position);
+
+ ARROW_ASSIGN_OR_RAISE(auto buf, AllocateResizableBuffer(nbytes, io_context_.pool()));
+ if (nbytes > 0) {
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read,
+ ReadAt(position, nbytes, buf->mutable_data()));
+ DCHECK_LE(bytes_read, nbytes);
+ RETURN_NOT_OK(buf->Resize(bytes_read));
+ }
+ return std::move(buf);
+ }
+
+ Result<int64_t> Read(int64_t nbytes, void* out) override {
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, ReadAt(pos_, nbytes, out));
+ pos_ += bytes_read;
+ return bytes_read;
+ }
+
+ Result<std::shared_ptr<Buffer>> Read(int64_t nbytes) override {
+ ARROW_ASSIGN_OR_RAISE(auto buffer, ReadAt(pos_, nbytes));
+ pos_ += buffer->size();
+ return std::move(buffer);
+ }
+
+ protected:
+ std::shared_ptr<Aws::S3::S3Client> client_;
+ const io::IOContext io_context_;
+ S3Path path_;
+
+ bool closed_ = false;
+ int64_t pos_ = 0;
+ int64_t content_length_ = kNoSize;
+ std::shared_ptr<const KeyValueMetadata> metadata_;
+};
+
+// Minimum size for each part of a multipart upload, except for the last part.
+// AWS doc says "5 MB" but it's not clear whether those are MB or MiB,
+// so I chose the safer value.
+// (see https://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadUploadPart.html)
+static constexpr int64_t kMinimumPartUpload = 5 * 1024 * 1024;
+
+// An OutputStream that writes to a S3 object
+class ObjectOutputStream final : public io::OutputStream {
+ protected:
+ struct UploadState;
+
+ public:
+ ObjectOutputStream(std::shared_ptr<S3Client> client, const io::IOContext& io_context,
+ const S3Path& path, const S3Options& options,
+ const std::shared_ptr<const KeyValueMetadata>& metadata)
+ : client_(std::move(client)),
+ io_context_(io_context),
+ path_(path),
+ metadata_(metadata),
+ default_metadata_(options.default_metadata),
+ background_writes_(options.background_writes) {}
+
+ ~ObjectOutputStream() override {
+ // For compliance with the rest of the IO stack, Close rather than Abort,
+ // even though it may be more expensive.
+ io::internal::CloseFromDestructor(this);
+ }
+
+ Status Init() {
+ // Initiate the multi-part upload
+ S3Model::CreateMultipartUploadRequest req;
+ req.SetBucket(ToAwsString(path_.bucket));
+ req.SetKey(ToAwsString(path_.key));
+ if (metadata_ && metadata_->size() != 0) {
+ RETURN_NOT_OK(SetObjectMetadata(metadata_, &req));
+ } else if (default_metadata_ && default_metadata_->size() != 0) {
+ RETURN_NOT_OK(SetObjectMetadata(default_metadata_, &req));
+ }
+
+ auto outcome = client_->CreateMultipartUpload(req);
+ if (!outcome.IsSuccess()) {
+ return ErrorToStatus(
+ std::forward_as_tuple("When initiating multiple part upload for key '",
+ path_.key, "' in bucket '", path_.bucket, "': "),
+ outcome.GetError());
+ }
+ upload_id_ = outcome.GetResult().GetUploadId();
+ upload_state_ = std::make_shared<UploadState>();
+ closed_ = false;
+ return Status::OK();
+ }
+
+ Status Abort() override {
+ if (closed_) {
+ return Status::OK();
+ }
+
+ S3Model::AbortMultipartUploadRequest req;
+ req.SetBucket(ToAwsString(path_.bucket));
+ req.SetKey(ToAwsString(path_.key));
+ req.SetUploadId(upload_id_);
+
+ auto outcome = client_->AbortMultipartUpload(req);
+ if (!outcome.IsSuccess()) {
+ return ErrorToStatus(
+ std::forward_as_tuple("When aborting multiple part upload for key '", path_.key,
+ "' in bucket '", path_.bucket, "': "),
+ outcome.GetError());
+ }
+ current_part_.reset();
+ client_ = nullptr;
+ closed_ = true;
+ return Status::OK();
+ }
+
+ // OutputStream interface
+
+ Status Close() override {
+ if (closed_) {
+ return Status::OK();
+ }
+
+ if (current_part_) {
+ // Upload last part
+ RETURN_NOT_OK(CommitCurrentPart());
+ }
+
+ // S3 mandates at least one part, upload an empty one if necessary
+ if (part_number_ == 1) {
+ RETURN_NOT_OK(UploadPart("", 0));
+ }
+
+ // Wait for in-progress uploads to finish (if async writes are enabled)
+ RETURN_NOT_OK(Flush());
+
+ // At this point, all part uploads have finished successfully
+ DCHECK_GT(part_number_, 1);
+ DCHECK_EQ(upload_state_->completed_parts.size(),
+ static_cast<size_t>(part_number_ - 1));
+
+ S3Model::CompletedMultipartUpload completed_upload;
+ completed_upload.SetParts(upload_state_->completed_parts);
+ S3Model::CompleteMultipartUploadRequest req;
+ req.SetBucket(ToAwsString(path_.bucket));
+ req.SetKey(ToAwsString(path_.key));
+ req.SetUploadId(upload_id_);
+ req.SetMultipartUpload(std::move(completed_upload));
+
+ auto outcome = client_->CompleteMultipartUploadWithErrorFixup(std::move(req));
+ if (!outcome.IsSuccess()) {
+ return ErrorToStatus(
+ std::forward_as_tuple("When completing multiple part upload for key '",
+ path_.key, "' in bucket '", path_.bucket, "': "),
+ outcome.GetError());
+ }
+
+ client_ = nullptr;
+ closed_ = true;
+ return Status::OK();
+ }
+
+ bool closed() const override { return closed_; }
+
+ Result<int64_t> Tell() const override {
+ if (closed_) {
+ return Status::Invalid("Operation on closed stream");
+ }
+ return pos_;
+ }
+
+ Status Write(const std::shared_ptr<Buffer>& buffer) override {
+ return DoWrite(buffer->data(), buffer->size(), buffer);
+ }
+
+ Status Write(const void* data, int64_t nbytes) override {
+ return DoWrite(data, nbytes);
+ }
+
+ Status DoWrite(const void* data, int64_t nbytes,
+ std::shared_ptr<Buffer> owned_buffer = nullptr) {
+ if (closed_) {
+ return Status::Invalid("Operation on closed stream");
+ }
+
+ if (!current_part_ && nbytes >= part_upload_threshold_) {
+ // No current part and data large enough, upload it directly
+ // (without copying if the buffer is owned)
+ RETURN_NOT_OK(UploadPart(data, nbytes, owned_buffer));
+ pos_ += nbytes;
+ return Status::OK();
+ }
+ // Can't upload data on its own, need to buffer it
+ if (!current_part_) {
+ ARROW_ASSIGN_OR_RAISE(
+ current_part_,
+ io::BufferOutputStream::Create(part_upload_threshold_, io_context_.pool()));
+ current_part_size_ = 0;
+ }
+ RETURN_NOT_OK(current_part_->Write(data, nbytes));
+ pos_ += nbytes;
+ current_part_size_ += nbytes;
+
+ if (current_part_size_ >= part_upload_threshold_) {
+ // Current part large enough, upload it
+ RETURN_NOT_OK(CommitCurrentPart());
+ }
+
+ return Status::OK();
+ }
+
+ Status Flush() override {
+ if (closed_) {
+ return Status::Invalid("Operation on closed stream");
+ }
+ // Wait for background writes to finish
+ std::unique_lock<std::mutex> lock(upload_state_->mutex);
+ upload_state_->cv.wait(lock,
+ [this]() { return upload_state_->parts_in_progress == 0; });
+ return upload_state_->status;
+ }
+
+ // Upload-related helpers
+
+ Status CommitCurrentPart() {
+ ARROW_ASSIGN_OR_RAISE(auto buf, current_part_->Finish());
+ current_part_.reset();
+ current_part_size_ = 0;
+ return UploadPart(buf);
+ }
+
+ Status UploadPart(std::shared_ptr<Buffer> buffer) {
+ return UploadPart(buffer->data(), buffer->size(), buffer);
+ }
+
+ Status UploadPart(const void* data, int64_t nbytes,
+ std::shared_ptr<Buffer> owned_buffer = nullptr) {
+ S3Model::UploadPartRequest req;
+ req.SetBucket(ToAwsString(path_.bucket));
+ req.SetKey(ToAwsString(path_.key));
+ req.SetUploadId(upload_id_);
+ req.SetPartNumber(part_number_);
+ req.SetContentLength(nbytes);
+
+ if (!background_writes_) {
+ req.SetBody(std::make_shared<StringViewStream>(data, nbytes));
+ auto outcome = client_->UploadPart(req);
+ if (!outcome.IsSuccess()) {
+ return UploadPartError(req, outcome);
+ } else {
+ AddCompletedPart(upload_state_, part_number_, outcome.GetResult());
+ }
+ } else {
+ // If the data isn't owned, make an immutable copy for the lifetime of the closure
+ if (owned_buffer == nullptr) {
+ ARROW_ASSIGN_OR_RAISE(owned_buffer, AllocateBuffer(nbytes, io_context_.pool()));
+ memcpy(owned_buffer->mutable_data(), data, nbytes);
+ } else {
+ DCHECK_EQ(data, owned_buffer->data());
+ DCHECK_EQ(nbytes, owned_buffer->size());
+ }
+ req.SetBody(
+ std::make_shared<StringViewStream>(owned_buffer->data(), owned_buffer->size()));
+
+ {
+ std::unique_lock<std::mutex> lock(upload_state_->mutex);
+ ++upload_state_->parts_in_progress;
+ }
+ auto client = client_;
+ ARROW_ASSIGN_OR_RAISE(auto fut, SubmitIO(io_context_, [client, req]() {
+ return client->UploadPart(req);
+ }));
+ // The closure keeps the buffer and the upload state alive
+ auto state = upload_state_;
+ auto part_number = part_number_;
+ auto handler = [owned_buffer, state, part_number,
+ req](const Result<S3Model::UploadPartOutcome>& result) -> void {
+ HandleUploadOutcome(state, part_number, req, result);
+ };
+ fut.AddCallback(std::move(handler));
+ }
+
+ ++part_number_;
+ // With up to 10000 parts in an upload (S3 limit), a stream writing chunks
+ // of exactly 5MB would be limited to 50GB total. To avoid that, we bump
+ // the upload threshold every 100 parts. So the pattern is:
+ // - part 1 to 99: 5MB threshold
+ // - part 100 to 199: 10MB threshold
+ // - part 200 to 299: 15MB threshold
+ // ...
+ // - part 9900 to 9999: 500MB threshold
+ // So the total size limit is 2475000MB or ~2.4TB, while keeping manageable
+ // chunk sizes and avoiding too much buffering in the common case of a small-ish
+ // stream. If the limit's not enough, we can revisit.
+ if (part_number_ % 100 == 0) {
+ part_upload_threshold_ += kMinimumPartUpload;
+ }
+
+ return Status::OK();
+ }
+
+ static void HandleUploadOutcome(const std::shared_ptr<UploadState>& state,
+ int part_number, const S3Model::UploadPartRequest& req,
+ const Result<S3Model::UploadPartOutcome>& result) {
+ std::unique_lock<std::mutex> lock(state->mutex);
+ if (!result.ok()) {
+ state->status &= result.status();
+ } else {
+ const auto& outcome = *result;
+ if (!outcome.IsSuccess()) {
+ state->status &= UploadPartError(req, outcome);
+ } else {
+ AddCompletedPart(state, part_number, outcome.GetResult());
+ }
+ }
+ // Notify completion
+ if (--state->parts_in_progress == 0) {
+ state->cv.notify_all();
+ }
+ }
+
+ static void AddCompletedPart(const std::shared_ptr<UploadState>& state, int part_number,
+ const S3Model::UploadPartResult& result) {
+ S3Model::CompletedPart part;
+ // Append ETag and part number for this uploaded part
+ // (will be needed for upload completion in Close())
+ part.SetPartNumber(part_number);
+ part.SetETag(result.GetETag());
+ int slot = part_number - 1;
+ if (state->completed_parts.size() <= static_cast<size_t>(slot)) {
+ state->completed_parts.resize(slot + 1);
+ }
+ DCHECK(!state->completed_parts[slot].PartNumberHasBeenSet());
+ state->completed_parts[slot] = std::move(part);
+ }
+
+ static Status UploadPartError(const S3Model::UploadPartRequest& req,
+ const S3Model::UploadPartOutcome& outcome) {
+ return ErrorToStatus(
+ std::forward_as_tuple("When uploading part for key '", req.GetKey(),
+ "' in bucket '", req.GetBucket(), "': "),
+ outcome.GetError());
+ }
+
+ protected:
+ std::shared_ptr<S3Client> client_;
+ const io::IOContext io_context_;
+ const S3Path path_;
+ const std::shared_ptr<const KeyValueMetadata> metadata_;
+ const std::shared_ptr<const KeyValueMetadata> default_metadata_;
+ const bool background_writes_;
+
+ Aws::String upload_id_;
+ bool closed_ = true;
+ int64_t pos_ = 0;
+ int32_t part_number_ = 1;
+ std::shared_ptr<io::BufferOutputStream> current_part_;
+ int64_t current_part_size_ = 0;
+ int64_t part_upload_threshold_ = kMinimumPartUpload;
+
+ // This struct is kept alive through background writes to avoid problems
+ // in the completion handler.
+ struct UploadState {
+ std::mutex mutex;
+ std::condition_variable cv;
+ Aws::Vector<S3Model::CompletedPart> completed_parts;
+ int64_t parts_in_progress = 0;
+ Status status;
+ };
+ std::shared_ptr<UploadState> upload_state_;
+};
+
+// This function assumes info->path() is already set
+void FileObjectToInfo(const S3Model::HeadObjectResult& obj, FileInfo* info) {
+ info->set_type(FileType::File);
+ info->set_size(static_cast<int64_t>(obj.GetContentLength()));
+ info->set_mtime(FromAwsDatetime(obj.GetLastModified()));
+}
+
+void FileObjectToInfo(const S3Model::Object& obj, FileInfo* info) {
+ info->set_type(FileType::File);
+ info->set_size(static_cast<int64_t>(obj.GetSize()));
+ info->set_mtime(FromAwsDatetime(obj.GetLastModified()));
+}
+
+struct TreeWalker : public std::enable_shared_from_this<TreeWalker> {
+ using ResultHandler = std::function<Status(const std::string& prefix,
+ const S3Model::ListObjectsV2Result&)>;
+ using ErrorHandler = std::function<Status(const AWSError<S3Errors>& error)>;
+ using RecursionHandler = std::function<Result<bool>(int32_t nesting_depth)>;
+
+ std::shared_ptr<Aws::S3::S3Client> client_;
+ io::IOContext io_context_;
+ const std::string bucket_;
+ const std::string base_dir_;
+ const int32_t max_keys_;
+ const ResultHandler result_handler_;
+ const ErrorHandler error_handler_;
+ const RecursionHandler recursion_handler_;
+
+ template <typename... Args>
+ static Status Walk(Args&&... args) {
+ return WalkAsync(std::forward<Args>(args)...).status();
+ }
+
+ template <typename... Args>
+ static Future<> WalkAsync(Args&&... args) {
+ auto self = std::make_shared<TreeWalker>(std::forward<Args>(args)...);
+ return self->DoWalk();
+ }
+
+ TreeWalker(std::shared_ptr<Aws::S3::S3Client> client, io::IOContext io_context,
+ std::string bucket, std::string base_dir, int32_t max_keys,
+ ResultHandler result_handler, ErrorHandler error_handler,
+ RecursionHandler recursion_handler)
+ : client_(std::move(client)),
+ io_context_(io_context),
+ bucket_(std::move(bucket)),
+ base_dir_(std::move(base_dir)),
+ max_keys_(max_keys),
+ result_handler_(std::move(result_handler)),
+ error_handler_(std::move(error_handler)),
+ recursion_handler_(std::move(recursion_handler)) {}
+
+ private:
+ std::shared_ptr<TaskGroup> task_group_;
+ std::mutex mutex_;
+
+ Future<> DoWalk() {
+ task_group_ =
+ TaskGroup::MakeThreaded(io_context_.executor(), io_context_.stop_token());
+ WalkChild(base_dir_, /*nesting_depth=*/0);
+ // When this returns, ListObjectsV2 tasks either have finished or will exit early
+ return task_group_->FinishAsync();
+ }
+
+ bool ok() const { return task_group_->ok(); }
+
+ struct ListObjectsV2Handler {
+ std::shared_ptr<TreeWalker> walker;
+ std::string prefix;
+ int32_t nesting_depth;
+ S3Model::ListObjectsV2Request req;
+
+ Status operator()(const Result<S3Model::ListObjectsV2Outcome>& result) {
+ // Serialize calls to operation-specific handlers
+ if (!walker->ok()) {
+ // Early exit: avoid executing handlers if DoWalk() returned
+ return Status::OK();
+ }
+ if (!result.ok()) {
+ return result.status();
+ }
+ const auto& outcome = *result;
+ if (!outcome.IsSuccess()) {
+ {
+ std::lock_guard<std::mutex> guard(walker->mutex_);
+ return walker->error_handler_(outcome.GetError());
+ }
+ }
+ return HandleResult(outcome.GetResult());
+ }
+
+ void SpawnListObjectsV2() {
+ auto cb = *this;
+ walker->task_group_->Append([cb]() mutable {
+ Result<S3Model::ListObjectsV2Outcome> result =
+ cb.walker->client_->ListObjectsV2(cb.req);
+ return cb(result);
+ });
+ }
+
+ Status HandleResult(const S3Model::ListObjectsV2Result& result) {
+ bool recurse;
+ {
+ // Only one thread should be running result_handler_/recursion_handler_ at a time
+ std::lock_guard<std::mutex> guard(walker->mutex_);
+ recurse = result.GetCommonPrefixes().size() > 0;
+ if (recurse) {
+ ARROW_ASSIGN_OR_RAISE(auto maybe_recurse,
+ walker->recursion_handler_(nesting_depth + 1));
+ recurse &= maybe_recurse;
+ }
+ RETURN_NOT_OK(walker->result_handler_(prefix, result));
+ }
+ if (recurse) {
+ walker->WalkChildren(result, nesting_depth + 1);
+ }
+ // If the result was truncated, issue a continuation request to get
+ // further directory entries.
+ if (result.GetIsTruncated()) {
+ DCHECK(!result.GetNextContinuationToken().empty());
+ req.SetContinuationToken(result.GetNextContinuationToken());
+ SpawnListObjectsV2();
+ }
+ return Status::OK();
+ }
+
+ void Start() {
+ req.SetBucket(ToAwsString(walker->bucket_));
+ if (!prefix.empty()) {
+ req.SetPrefix(ToAwsString(prefix) + kSep);
+ }
+ req.SetDelimiter(Aws::String() + kSep);
+ req.SetMaxKeys(walker->max_keys_);
+ SpawnListObjectsV2();
+ }
+ };
+
+ void WalkChild(std::string key, int32_t nesting_depth) {
+ ListObjectsV2Handler handler{shared_from_this(), std::move(key), nesting_depth, {}};
+ handler.Start();
+ }
+
+ void WalkChildren(const S3Model::ListObjectsV2Result& result, int32_t nesting_depth) {
+ for (const auto& prefix : result.GetCommonPrefixes()) {
+ const auto child_key =
+ internal::RemoveTrailingSlash(FromAwsString(prefix.GetPrefix()));
+ WalkChild(std::string{child_key}, nesting_depth);
+ }
+ }
+
+ friend struct ListObjectsV2Handler;
+};
+
+} // namespace
+
+// -----------------------------------------------------------------------
+// S3 filesystem implementation
+
+class S3FileSystem::Impl : public std::enable_shared_from_this<S3FileSystem::Impl> {
+ public:
+ ClientBuilder builder_;
+ io::IOContext io_context_;
+ std::shared_ptr<S3Client> client_;
+ util::optional<S3Backend> backend_;
+
+ const int32_t kListObjectsMaxKeys = 1000;
+ // At most 1000 keys per multiple-delete request
+ const int32_t kMultipleDeleteMaxKeys = 1000;
+ // Limit recursing depth, since a recursion bomb can be created
+ const int32_t kMaxNestingDepth = 100;
+
+ explicit Impl(S3Options options, io::IOContext io_context)
+ : builder_(std::move(options)), io_context_(io_context) {}
+
+ Status Init() { return builder_.BuildClient().Value(&client_); }
+
+ const S3Options& options() const { return builder_.options(); }
+
+ std::string region() const {
+ return std::string(FromAwsString(builder_.config().region));
+ }
+
+ template <typename Error>
+ void SaveBackend(const Aws::Client::AWSError<Error>& error) {
+ if (!backend_ || *backend_ == S3Backend::Other) {
+ backend_ = DetectS3Backend(error);
+ }
+ }
+
+ // Tests to see if a bucket exists
+ Result<bool> BucketExists(const std::string& bucket) {
+ S3Model::HeadBucketRequest req;
+ req.SetBucket(ToAwsString(bucket));
+
+ auto outcome = client_->HeadBucket(req);
+ if (!outcome.IsSuccess()) {
+ if (!IsNotFound(outcome.GetError())) {
+ return ErrorToStatus(std::forward_as_tuple(
+ "When testing for existence of bucket '", bucket, "': "),
+ outcome.GetError());
+ }
+ return false;
+ }
+ return true;
+ }
+
+ // Create a bucket. Successful if bucket already exists.
+ Status CreateBucket(const std::string& bucket) {
+ S3Model::CreateBucketConfiguration config;
+ S3Model::CreateBucketRequest req;
+ auto _region = region();
+ // AWS S3 treats the us-east-1 differently than other regions
+ // https://docs.aws.amazon.com/cli/latest/reference/s3api/create-bucket.html
+ if (_region != "us-east-1") {
+ config.SetLocationConstraint(
+ S3Model::BucketLocationConstraintMapper::GetBucketLocationConstraintForName(
+ ToAwsString(_region)));
+ }
+ req.SetBucket(ToAwsString(bucket));
+ req.SetCreateBucketConfiguration(config);
+
+ auto outcome = client_->CreateBucket(req);
+ if (!outcome.IsSuccess() && !IsAlreadyExists(outcome.GetError())) {
+ return ErrorToStatus(std::forward_as_tuple("When creating bucket '", bucket, "': "),
+ outcome.GetError());
+ }
+ return Status::OK();
+ }
+
+ // Create an object with empty contents. Successful if object already exists.
+ Status CreateEmptyObject(const std::string& bucket, const std::string& key) {
+ S3Model::PutObjectRequest req;
+ req.SetBucket(ToAwsString(bucket));
+ req.SetKey(ToAwsString(key));
+ return OutcomeToStatus(
+ std::forward_as_tuple("When creating key '", key, "' in bucket '", bucket, "': "),
+ client_->PutObject(req));
+ }
+
+ Status CreateEmptyDir(const std::string& bucket, const std::string& key) {
+ DCHECK(!key.empty());
+ return CreateEmptyObject(bucket, key + kSep);
+ }
+
+ Status DeleteObject(const std::string& bucket, const std::string& key) {
+ S3Model::DeleteObjectRequest req;
+ req.SetBucket(ToAwsString(bucket));
+ req.SetKey(ToAwsString(key));
+ return OutcomeToStatus(
+ std::forward_as_tuple("When delete key '", key, "' in bucket '", bucket, "': "),
+ client_->DeleteObject(req));
+ }
+
+ Status CopyObject(const S3Path& src_path, const S3Path& dest_path) {
+ S3Model::CopyObjectRequest req;
+ req.SetBucket(ToAwsString(dest_path.bucket));
+ req.SetKey(ToAwsString(dest_path.key));
+ // ARROW-13048: Copy source "Must be URL-encoded" according to AWS SDK docs.
+ // However at least in 1.8 and 1.9 the SDK URL-encodes the path for you
+ req.SetCopySource(src_path.ToAwsString());
+ return OutcomeToStatus(
+ std::forward_as_tuple("When copying key '", src_path.key, "' in bucket '",
+ src_path.bucket, "' to key '", dest_path.key,
+ "' in bucket '", dest_path.bucket, "': "),
+ client_->CopyObject(req));
+ }
+
+ // On Minio, an empty "directory" doesn't satisfy the same API requests as
+ // a non-empty "directory". This is a Minio-specific quirk, but we need
+ // to handle it for unit testing.
+
+ Status IsEmptyDirectory(const std::string& bucket, const std::string& key, bool* out) {
+ S3Model::HeadObjectRequest req;
+ req.SetBucket(ToAwsString(bucket));
+ if (backend_ && *backend_ == S3Backend::Minio) {
+ // Minio wants a slash at the end, Amazon doesn't
+ req.SetKey(ToAwsString(key) + kSep);
+ } else {
+ req.SetKey(ToAwsString(key));
+ }
+
+ auto outcome = client_->HeadObject(req);
+ if (outcome.IsSuccess()) {
+ *out = true;
+ return Status::OK();
+ }
+ if (!backend_) {
+ SaveBackend(outcome.GetError());
+ DCHECK(backend_);
+ if (*backend_ == S3Backend::Minio) {
+ // Try again with separator-terminated key (see above)
+ return IsEmptyDirectory(bucket, key, out);
+ }
+ }
+ if (IsNotFound(outcome.GetError())) {
+ *out = false;
+ return Status::OK();
+ }
+ return ErrorToStatus(std::forward_as_tuple("When reading information for key '", key,
+ "' in bucket '", bucket, "': "),
+ outcome.GetError());
+ }
+
+ Status IsEmptyDirectory(const S3Path& path, bool* out) {
+ return IsEmptyDirectory(path.bucket, path.key, out);
+ }
+
+ Status IsNonEmptyDirectory(const S3Path& path, bool* out) {
+ S3Model::ListObjectsV2Request req;
+ req.SetBucket(ToAwsString(path.bucket));
+ req.SetPrefix(ToAwsString(path.key) + kSep);
+ req.SetDelimiter(Aws::String() + kSep);
+ req.SetMaxKeys(1);
+ auto outcome = client_->ListObjectsV2(req);
+ if (outcome.IsSuccess()) {
+ *out = outcome.GetResult().GetKeyCount() > 0;
+ return Status::OK();
+ }
+ if (IsNotFound(outcome.GetError())) {
+ *out = false;
+ return Status::OK();
+ }
+ return ErrorToStatus(
+ std::forward_as_tuple("When listing objects under key '", path.key,
+ "' in bucket '", path.bucket, "': "),
+ outcome.GetError());
+ }
+
+ Status CheckNestingDepth(int32_t nesting_depth) {
+ if (nesting_depth >= kMaxNestingDepth) {
+ return Status::IOError("S3 filesystem tree exceeds maximum nesting depth (",
+ kMaxNestingDepth, ")");
+ }
+ return Status::OK();
+ }
+
+ // A helper class for Walk and WalkAsync
+ struct FileInfoCollector {
+ FileInfoCollector(std::string bucket, std::string key, const FileSelector& select)
+ : bucket(std::move(bucket)),
+ key(std::move(key)),
+ allow_not_found(select.allow_not_found) {}
+
+ Status Collect(const std::string& prefix, const S3Model::ListObjectsV2Result& result,
+ std::vector<FileInfo>* out) {
+ // Walk "directories"
+ for (const auto& child_prefix : result.GetCommonPrefixes()) {
+ is_empty = false;
+ const auto child_key =
+ internal::RemoveTrailingSlash(FromAwsString(child_prefix.GetPrefix()));
+ std::stringstream child_path;
+ child_path << bucket << kSep << child_key;
+ FileInfo info;
+ info.set_path(child_path.str());
+ info.set_type(FileType::Directory);
+ out->push_back(std::move(info));
+ }
+ // Walk "files"
+ for (const auto& obj : result.GetContents()) {
+ is_empty = false;
+ FileInfo info;
+ const auto child_key = internal::RemoveTrailingSlash(FromAwsString(obj.GetKey()));
+ if (child_key == util::string_view(prefix)) {
+ // Amazon can return the "directory" key itself as part of the results, skip
+ continue;
+ }
+ std::stringstream child_path;
+ child_path << bucket << kSep << child_key;
+ info.set_path(child_path.str());
+ FileObjectToInfo(obj, &info);
+ out->push_back(std::move(info));
+ }
+ return Status::OK();
+ }
+
+ Status Finish(Impl* impl) {
+ // If no contents were found, perhaps it's an empty "directory",
+ // or perhaps it's a nonexistent entry. Check.
+ if (is_empty && !allow_not_found) {
+ bool is_actually_empty;
+ RETURN_NOT_OK(impl->IsEmptyDirectory(bucket, key, &is_actually_empty));
+ if (!is_actually_empty) {
+ return PathNotFound(bucket, key);
+ }
+ }
+ return Status::OK();
+ }
+
+ std::string bucket;
+ std::string key;
+ bool allow_not_found;
+ bool is_empty = true;
+ };
+
+ // Workhorse for GetFileInfo(FileSelector...)
+ Status Walk(const FileSelector& select, const std::string& bucket,
+ const std::string& key, std::vector<FileInfo>* out) {
+ FileInfoCollector collector(bucket, key, select);
+
+ auto handle_error = [&](const AWSError<S3Errors>& error) -> Status {
+ if (select.allow_not_found && IsNotFound(error)) {
+ return Status::OK();
+ }
+ return ErrorToStatus(std::forward_as_tuple("When listing objects under key '", key,
+ "' in bucket '", bucket, "': "),
+ error);
+ };
+
+ auto handle_recursion = [&](int32_t nesting_depth) -> Result<bool> {
+ RETURN_NOT_OK(CheckNestingDepth(nesting_depth));
+ return select.recursive && nesting_depth <= select.max_recursion;
+ };
+
+ auto handle_results = [&](const std::string& prefix,
+ const S3Model::ListObjectsV2Result& result) -> Status {
+ return collector.Collect(prefix, result, out);
+ };
+
+ RETURN_NOT_OK(TreeWalker::Walk(client_, io_context_, bucket, key, kListObjectsMaxKeys,
+ handle_results, handle_error, handle_recursion));
+
+ // If no contents were found, perhaps it's an empty "directory",
+ // or perhaps it's a nonexistent entry. Check.
+ RETURN_NOT_OK(collector.Finish(this));
+ // Sort results for convenience, since they can come massively out of order
+ std::sort(out->begin(), out->end(), FileInfo::ByPath{});
+ return Status::OK();
+ }
+
+ // Workhorse for GetFileInfoGenerator(FileSelector...)
+ FileInfoGenerator WalkAsync(const FileSelector& select, const std::string& bucket,
+ const std::string& key) {
+ PushGenerator<std::vector<FileInfo>> gen;
+ auto producer = gen.producer();
+ auto collector = std::make_shared<FileInfoCollector>(bucket, key, select);
+ auto self = shared_from_this();
+
+ auto handle_error = [select, bucket, key](const AWSError<S3Errors>& error) -> Status {
+ if (select.allow_not_found && IsNotFound(error)) {
+ return Status::OK();
+ }
+ return ErrorToStatus(std::forward_as_tuple("When listing objects under key '", key,
+ "' in bucket '", bucket, "': "),
+ error);
+ };
+
+ auto handle_recursion = [producer, select,
+ self](int32_t nesting_depth) -> Result<bool> {
+ if (producer.is_closed()) {
+ return false;
+ }
+ RETURN_NOT_OK(self->CheckNestingDepth(nesting_depth));
+ return select.recursive && nesting_depth <= select.max_recursion;
+ };
+
+ auto handle_results =
+ [collector, producer](
+ const std::string& prefix,
+ const S3Model::ListObjectsV2Result& result) mutable -> Status {
+ std::vector<FileInfo> out;
+ RETURN_NOT_OK(collector->Collect(prefix, result, &out));
+ if (!out.empty()) {
+ producer.Push(std::move(out));
+ }
+ return Status::OK();
+ };
+
+ TreeWalker::WalkAsync(client_, io_context_, bucket, key, kListObjectsMaxKeys,
+ handle_results, handle_error, handle_recursion)
+ .AddCallback([collector, producer, self](const Status& status) mutable {
+ auto st = collector->Finish(self.get());
+ if (!st.ok()) {
+ producer.Push(st);
+ }
+ producer.Close();
+ });
+ return gen;
+ }
+
+ Status WalkForDeleteDir(const std::string& bucket, const std::string& key,
+ std::vector<std::string>* file_keys,
+ std::vector<std::string>* dir_keys) {
+ auto handle_results = [&](const std::string& prefix,
+ const S3Model::ListObjectsV2Result& result) -> Status {
+ // Walk "files"
+ file_keys->reserve(file_keys->size() + result.GetContents().size());
+ for (const auto& obj : result.GetContents()) {
+ file_keys->emplace_back(FromAwsString(obj.GetKey()));
+ }
+ // Walk "directories"
+ dir_keys->reserve(dir_keys->size() + result.GetCommonPrefixes().size());
+ for (const auto& prefix : result.GetCommonPrefixes()) {
+ dir_keys->emplace_back(FromAwsString(prefix.GetPrefix()));
+ }
+ return Status::OK();
+ };
+
+ auto handle_error = [&](const AWSError<S3Errors>& error) -> Status {
+ return ErrorToStatus(std::forward_as_tuple("When listing objects under key '", key,
+ "' in bucket '", bucket, "': "),
+ error);
+ };
+
+ auto handle_recursion = [&](int32_t nesting_depth) -> Result<bool> {
+ RETURN_NOT_OK(CheckNestingDepth(nesting_depth));
+ return true; // Recurse
+ };
+
+ return TreeWalker::Walk(client_, io_context_, bucket, key, kListObjectsMaxKeys,
+ handle_results, handle_error, handle_recursion);
+ }
+
+ // Delete multiple objects at once
+ Future<> DeleteObjectsAsync(const std::string& bucket,
+ const std::vector<std::string>& keys) {
+ struct DeleteCallback {
+ const std::string bucket;
+
+ Status operator()(const S3Model::DeleteObjectsOutcome& outcome) {
+ if (!outcome.IsSuccess()) {
+ return ErrorToStatus(outcome.GetError());
+ }
+ // Also need to check per-key errors, even on successful outcome
+ // See
+ // https://docs.aws.amazon.com/fr_fr/AmazonS3/latest/API/multiobjectdeleteapi.html
+ const auto& errors = outcome.GetResult().GetErrors();
+ if (!errors.empty()) {
+ std::stringstream ss;
+ ss << "Got the following " << errors.size()
+ << " errors when deleting objects in S3 bucket '" << bucket << "':\n";
+ for (const auto& error : errors) {
+ ss << "- key '" << error.GetKey() << "': " << error.GetMessage() << "\n";
+ }
+ return Status::IOError(ss.str());
+ }
+ return Status::OK();
+ }
+ };
+
+ const auto chunk_size = static_cast<size_t>(kMultipleDeleteMaxKeys);
+ DeleteCallback delete_cb{bucket};
+ auto client = client_;
+
+ std::vector<Future<>> futures;
+ futures.reserve(keys.size() / chunk_size + 1);
+
+ for (size_t start = 0; start < keys.size(); start += chunk_size) {
+ S3Model::DeleteObjectsRequest req;
+ S3Model::Delete del;
+ for (size_t i = start; i < std::min(keys.size(), chunk_size); ++i) {
+ del.AddObjects(S3Model::ObjectIdentifier().WithKey(ToAwsString(keys[i])));
+ }
+ req.SetBucket(ToAwsString(bucket));
+ req.SetDelete(std::move(del));
+ ARROW_ASSIGN_OR_RAISE(auto fut, SubmitIO(io_context_, [client, req]() {
+ return client->DeleteObjects(req);
+ }));
+ futures.push_back(std::move(fut).Then(delete_cb));
+ }
+
+ return AllComplete(futures);
+ }
+
+ Status DeleteObjects(const std::string& bucket, const std::vector<std::string>& keys) {
+ return DeleteObjectsAsync(bucket, keys).status();
+ }
+
+ Status DeleteDirContents(const std::string& bucket, const std::string& key) {
+ std::vector<std::string> file_keys;
+ std::vector<std::string> dir_keys;
+ RETURN_NOT_OK(WalkForDeleteDir(bucket, key, &file_keys, &dir_keys));
+ if (file_keys.empty() && dir_keys.empty() && !key.empty()) {
+ // No contents found, is it an empty directory?
+ bool exists = false;
+ RETURN_NOT_OK(IsEmptyDirectory(bucket, key, &exists));
+ if (!exists) {
+ return PathNotFound(bucket, key);
+ }
+ }
+ // First delete all "files", then delete all child "directories"
+ RETURN_NOT_OK(DeleteObjects(bucket, file_keys));
+ // Delete directories in reverse lexicographic order, to ensure children
+ // are deleted before their parents (Minio).
+ std::sort(dir_keys.rbegin(), dir_keys.rend());
+ return DeleteObjects(bucket, dir_keys);
+ }
+
+ Status EnsureDirectoryExists(const S3Path& path) {
+ if (!path.key.empty()) {
+ return CreateEmptyDir(path.bucket, path.key);
+ }
+ return Status::OK();
+ }
+
+ Status EnsureParentExists(const S3Path& path) {
+ if (path.has_parent()) {
+ return EnsureDirectoryExists(path.parent());
+ }
+ return Status::OK();
+ }
+
+ static Result<std::vector<std::string>> ProcessListBuckets(
+ const Aws::S3::Model::ListBucketsOutcome& outcome) {
+ if (!outcome.IsSuccess()) {
+ return ErrorToStatus(std::forward_as_tuple("When listing buckets: "),
+ outcome.GetError());
+ }
+ std::vector<std::string> buckets;
+ buckets.reserve(outcome.GetResult().GetBuckets().size());
+ for (const auto& bucket : outcome.GetResult().GetBuckets()) {
+ buckets.emplace_back(FromAwsString(bucket.GetName()));
+ }
+ return buckets;
+ }
+
+ Result<std::vector<std::string>> ListBuckets() {
+ auto outcome = client_->ListBuckets();
+ return ProcessListBuckets(outcome);
+ }
+
+ Future<std::vector<std::string>> ListBucketsAsync(io::IOContext ctx) {
+ auto self = shared_from_this();
+ return DeferNotOk(SubmitIO(ctx, [self]() { return self->client_->ListBuckets(); }))
+ // TODO(ARROW-12655) Change to Then(Impl::ProcessListBuckets)
+ .Then([](const Aws::S3::Model::ListBucketsOutcome& outcome) {
+ return Impl::ProcessListBuckets(outcome);
+ });
+ }
+
+ Result<std::shared_ptr<ObjectInputFile>> OpenInputFile(const std::string& s,
+ S3FileSystem* fs) {
+ ARROW_ASSIGN_OR_RAISE(auto path, S3Path::FromString(s));
+ RETURN_NOT_OK(ValidateFilePath(path));
+
+ auto ptr = std::make_shared<ObjectInputFile>(client_, fs->io_context(), path);
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+ }
+
+ Result<std::shared_ptr<ObjectInputFile>> OpenInputFile(const FileInfo& info,
+ S3FileSystem* fs) {
+ if (info.type() == FileType::NotFound) {
+ return ::arrow::fs::internal::PathNotFound(info.path());
+ }
+ if (info.type() != FileType::File && info.type() != FileType::Unknown) {
+ return ::arrow::fs::internal::NotAFile(info.path());
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto path, S3Path::FromString(info.path()));
+ RETURN_NOT_OK(ValidateFilePath(path));
+
+ auto ptr =
+ std::make_shared<ObjectInputFile>(client_, fs->io_context(), path, info.size());
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+ }
+};
+
+S3FileSystem::S3FileSystem(const S3Options& options, const io::IOContext& io_context)
+ : FileSystem(io_context), impl_(std::make_shared<Impl>(options, io_context)) {
+ default_async_is_sync_ = false;
+}
+
+S3FileSystem::~S3FileSystem() {}
+
+Result<std::shared_ptr<S3FileSystem>> S3FileSystem::Make(
+ const S3Options& options, const io::IOContext& io_context) {
+ RETURN_NOT_OK(CheckS3Initialized());
+
+ std::shared_ptr<S3FileSystem> ptr(new S3FileSystem(options, io_context));
+ RETURN_NOT_OK(ptr->impl_->Init());
+ return ptr;
+}
+
+bool S3FileSystem::Equals(const FileSystem& other) const {
+ if (this == &other) {
+ return true;
+ }
+ if (other.type_name() != type_name()) {
+ return false;
+ }
+ const auto& s3fs = ::arrow::internal::checked_cast<const S3FileSystem&>(other);
+ return options().Equals(s3fs.options());
+}
+
+S3Options S3FileSystem::options() const { return impl_->options(); }
+
+std::string S3FileSystem::region() const { return impl_->region(); }
+
+Result<FileInfo> S3FileSystem::GetFileInfo(const std::string& s) {
+ ARROW_ASSIGN_OR_RAISE(auto path, S3Path::FromString(s));
+ FileInfo info;
+ info.set_path(s);
+
+ if (path.empty()) {
+ // It's the root path ""
+ info.set_type(FileType::Directory);
+ return info;
+ } else if (path.key.empty()) {
+ // It's a bucket
+ S3Model::HeadBucketRequest req;
+ req.SetBucket(ToAwsString(path.bucket));
+
+ auto outcome = impl_->client_->HeadBucket(req);
+ if (!outcome.IsSuccess()) {
+ if (!IsNotFound(outcome.GetError())) {
+ return ErrorToStatus(
+ std::forward_as_tuple("When getting information for bucket '", path.bucket,
+ "': "),
+ outcome.GetError());
+ }
+ info.set_type(FileType::NotFound);
+ return info;
+ }
+ // NOTE: S3 doesn't have a bucket modification time. Only a creation
+ // time is available, and you have to list all buckets to get it.
+ info.set_type(FileType::Directory);
+ return info;
+ } else {
+ // It's an object
+ S3Model::HeadObjectRequest req;
+ req.SetBucket(ToAwsString(path.bucket));
+ req.SetKey(ToAwsString(path.key));
+
+ auto outcome = impl_->client_->HeadObject(req);
+ if (outcome.IsSuccess()) {
+ // "File" object found
+ FileObjectToInfo(outcome.GetResult(), &info);
+ return info;
+ }
+ if (!IsNotFound(outcome.GetError())) {
+ return ErrorToStatus(
+ std::forward_as_tuple("When getting information for key '", path.key,
+ "' in bucket '", path.bucket, "': "),
+ outcome.GetError());
+ }
+ // Not found => perhaps it's an empty "directory"
+ bool is_dir = false;
+ RETURN_NOT_OK(impl_->IsEmptyDirectory(path, &is_dir));
+ if (is_dir) {
+ info.set_type(FileType::Directory);
+ return info;
+ }
+ // Not found => perhaps it's a non-empty "directory"
+ RETURN_NOT_OK(impl_->IsNonEmptyDirectory(path, &is_dir));
+ if (is_dir) {
+ info.set_type(FileType::Directory);
+ } else {
+ info.set_type(FileType::NotFound);
+ }
+ return info;
+ }
+}
+
+Result<FileInfoVector> S3FileSystem::GetFileInfo(const FileSelector& select) {
+ ARROW_ASSIGN_OR_RAISE(auto base_path, S3Path::FromString(select.base_dir));
+
+ FileInfoVector results;
+
+ if (base_path.empty()) {
+ // List all buckets
+ ARROW_ASSIGN_OR_RAISE(auto buckets, impl_->ListBuckets());
+ for (const auto& bucket : buckets) {
+ FileInfo info;
+ info.set_path(bucket);
+ info.set_type(FileType::Directory);
+ results.push_back(std::move(info));
+ if (select.recursive) {
+ RETURN_NOT_OK(impl_->Walk(select, bucket, "", &results));
+ }
+ }
+ return results;
+ }
+
+ // Nominal case -> walk a single bucket
+ RETURN_NOT_OK(impl_->Walk(select, base_path.bucket, base_path.key, &results));
+ return results;
+}
+
+FileInfoGenerator S3FileSystem::GetFileInfoGenerator(const FileSelector& select) {
+ auto maybe_base_path = S3Path::FromString(select.base_dir);
+ if (!maybe_base_path.ok()) {
+ return MakeFailingGenerator<FileInfoVector>(maybe_base_path.status());
+ }
+ auto base_path = *std::move(maybe_base_path);
+
+ if (base_path.empty()) {
+ // List all buckets, then possibly recurse
+ PushGenerator<AsyncGenerator<FileInfoVector>> gen;
+ auto producer = gen.producer();
+
+ auto fut = impl_->ListBucketsAsync(io_context());
+ auto impl = impl_->shared_from_this();
+ fut.AddCallback(
+ [producer, select, impl](const Result<std::vector<std::string>>& res) mutable {
+ if (!res.ok()) {
+ producer.Push(res.status());
+ producer.Close();
+ return;
+ }
+ FileInfoVector buckets;
+ for (const auto& bucket : *res) {
+ buckets.push_back(FileInfo{bucket, FileType::Directory});
+ }
+ // Generate all bucket infos
+ auto buckets_fut = Future<FileInfoVector>::MakeFinished(std::move(buckets));
+ producer.Push(MakeSingleFutureGenerator(buckets_fut));
+ if (select.recursive) {
+ // Generate recursive walk for each bucket in turn
+ for (const auto& bucket : *buckets_fut.result()) {
+ producer.Push(impl->WalkAsync(select, bucket.path(), ""));
+ }
+ }
+ producer.Close();
+ });
+
+ return MakeConcatenatedGenerator(
+ AsyncGenerator<AsyncGenerator<FileInfoVector>>{std::move(gen)});
+ }
+
+ // Nominal case -> walk a single bucket
+ return impl_->WalkAsync(select, base_path.bucket, base_path.key);
+}
+
+Status S3FileSystem::CreateDir(const std::string& s, bool recursive) {
+ ARROW_ASSIGN_OR_RAISE(auto path, S3Path::FromString(s));
+
+ if (path.key.empty()) {
+ // Create bucket
+ return impl_->CreateBucket(path.bucket);
+ }
+
+ // Create object
+ if (recursive) {
+ // Ensure bucket exists
+ ARROW_ASSIGN_OR_RAISE(bool bucket_exists, impl_->BucketExists(path.bucket));
+ if (!bucket_exists) {
+ RETURN_NOT_OK(impl_->CreateBucket(path.bucket));
+ }
+ // Ensure that all parents exist, then the directory itself
+ std::string parent_key;
+ for (const auto& part : path.key_parts) {
+ parent_key += part;
+ parent_key += kSep;
+ RETURN_NOT_OK(impl_->CreateEmptyObject(path.bucket, parent_key));
+ }
+ return Status::OK();
+ } else {
+ // Check parent dir exists
+ if (path.has_parent()) {
+ S3Path parent_path = path.parent();
+ bool exists;
+ RETURN_NOT_OK(impl_->IsNonEmptyDirectory(parent_path, &exists));
+ if (!exists) {
+ RETURN_NOT_OK(impl_->IsEmptyDirectory(parent_path, &exists));
+ }
+ if (!exists) {
+ return Status::IOError("Cannot create directory '", path.full_path,
+ "': parent directory does not exist");
+ }
+ }
+
+ // XXX Should we check that no non-directory entry exists?
+ // Minio does it for us, not sure about other S3 implementations.
+ return impl_->CreateEmptyDir(path.bucket, path.key);
+ }
+}
+
+Status S3FileSystem::DeleteDir(const std::string& s) {
+ ARROW_ASSIGN_OR_RAISE(auto path, S3Path::FromString(s));
+
+ if (path.empty()) {
+ return Status::NotImplemented("Cannot delete all S3 buckets");
+ }
+ RETURN_NOT_OK(impl_->DeleteDirContents(path.bucket, path.key));
+ if (path.key.empty()) {
+ // Delete bucket
+ S3Model::DeleteBucketRequest req;
+ req.SetBucket(ToAwsString(path.bucket));
+ return OutcomeToStatus(
+ std::forward_as_tuple("When deleting bucket '", path.bucket, "': "),
+ impl_->client_->DeleteBucket(req));
+ } else {
+ // Delete "directory"
+ RETURN_NOT_OK(impl_->DeleteObject(path.bucket, path.key + kSep));
+ // Parent may be implicitly deleted if it became empty, recreate it
+ return impl_->EnsureParentExists(path);
+ }
+}
+
+Status S3FileSystem::DeleteDirContents(const std::string& s) {
+ ARROW_ASSIGN_OR_RAISE(auto path, S3Path::FromString(s));
+
+ if (path.empty()) {
+ return Status::NotImplemented("Cannot delete all S3 buckets");
+ }
+ RETURN_NOT_OK(impl_->DeleteDirContents(path.bucket, path.key));
+ // Directory may be implicitly deleted, recreate it
+ return impl_->EnsureDirectoryExists(path);
+}
+
+Status S3FileSystem::DeleteRootDirContents() {
+ return Status::NotImplemented("Cannot delete all S3 buckets");
+}
+
+Status S3FileSystem::DeleteFile(const std::string& s) {
+ ARROW_ASSIGN_OR_RAISE(auto path, S3Path::FromString(s));
+ RETURN_NOT_OK(ValidateFilePath(path));
+
+ // Check the object exists
+ S3Model::HeadObjectRequest req;
+ req.SetBucket(ToAwsString(path.bucket));
+ req.SetKey(ToAwsString(path.key));
+
+ auto outcome = impl_->client_->HeadObject(req);
+ if (!outcome.IsSuccess()) {
+ if (IsNotFound(outcome.GetError())) {
+ return PathNotFound(path);
+ } else {
+ return ErrorToStatus(
+ std::forward_as_tuple("When getting information for key '", path.key,
+ "' in bucket '", path.bucket, "': "),
+ outcome.GetError());
+ }
+ }
+ // Object found, delete it
+ RETURN_NOT_OK(impl_->DeleteObject(path.bucket, path.key));
+ // Parent may be implicitly deleted if it became empty, recreate it
+ return impl_->EnsureParentExists(path);
+}
+
+Status S3FileSystem::Move(const std::string& src, const std::string& dest) {
+ // XXX We don't implement moving directories as it would be too expensive:
+ // one must copy all directory contents one by one (including object data),
+ // then delete the original contents.
+
+ ARROW_ASSIGN_OR_RAISE(auto src_path, S3Path::FromString(src));
+ RETURN_NOT_OK(ValidateFilePath(src_path));
+ ARROW_ASSIGN_OR_RAISE(auto dest_path, S3Path::FromString(dest));
+ RETURN_NOT_OK(ValidateFilePath(dest_path));
+
+ if (src_path == dest_path) {
+ return Status::OK();
+ }
+ RETURN_NOT_OK(impl_->CopyObject(src_path, dest_path));
+ RETURN_NOT_OK(impl_->DeleteObject(src_path.bucket, src_path.key));
+ // Source parent may be implicitly deleted if it became empty, recreate it
+ return impl_->EnsureParentExists(src_path);
+}
+
+Status S3FileSystem::CopyFile(const std::string& src, const std::string& dest) {
+ ARROW_ASSIGN_OR_RAISE(auto src_path, S3Path::FromString(src));
+ RETURN_NOT_OK(ValidateFilePath(src_path));
+ ARROW_ASSIGN_OR_RAISE(auto dest_path, S3Path::FromString(dest));
+ RETURN_NOT_OK(ValidateFilePath(dest_path));
+
+ if (src_path == dest_path) {
+ return Status::OK();
+ }
+ return impl_->CopyObject(src_path, dest_path);
+}
+
+Result<std::shared_ptr<io::InputStream>> S3FileSystem::OpenInputStream(
+ const std::string& s) {
+ return impl_->OpenInputFile(s, this);
+}
+
+Result<std::shared_ptr<io::InputStream>> S3FileSystem::OpenInputStream(
+ const FileInfo& info) {
+ return impl_->OpenInputFile(info, this);
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> S3FileSystem::OpenInputFile(
+ const std::string& s) {
+ return impl_->OpenInputFile(s, this);
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> S3FileSystem::OpenInputFile(
+ const FileInfo& info) {
+ return impl_->OpenInputFile(info, this);
+}
+
+Result<std::shared_ptr<io::OutputStream>> S3FileSystem::OpenOutputStream(
+ const std::string& s, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ ARROW_ASSIGN_OR_RAISE(auto path, S3Path::FromString(s));
+ RETURN_NOT_OK(ValidateFilePath(path));
+
+ auto ptr = std::make_shared<ObjectOutputStream>(impl_->client_, io_context(), path,
+ impl_->options(), metadata);
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+}
+
+Result<std::shared_ptr<io::OutputStream>> S3FileSystem::OpenAppendStream(
+ const std::string& path, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ // XXX Investigate UploadPartCopy? Does it work with source == destination?
+ // https://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadUploadPartCopy.html
+ // (but would need to fall back to GET if the current data is < 5 MB)
+ return Status::NotImplemented("It is not possible to append efficiently to S3 objects");
+}
+
+//
+// Top-level utility functions
+//
+
+Result<std::string> ResolveBucketRegion(const std::string& bucket) {
+ ARROW_ASSIGN_OR_RAISE(auto resolver, RegionResolver::DefaultInstance());
+ return resolver->ResolveRegion(bucket);
+}
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/s3fs.h b/src/arrow/cpp/src/arrow/filesystem/s3fs.h
new file mode 100644
index 000000000..abb9c852a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/s3fs.h
@@ -0,0 +1,315 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/filesystem/filesystem.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/uri.h"
+
+namespace Aws {
+namespace Auth {
+
+class AWSCredentialsProvider;
+class STSAssumeRoleCredentialsProvider;
+
+} // namespace Auth
+namespace STS {
+class STSClient;
+}
+} // namespace Aws
+
+namespace arrow {
+namespace fs {
+
+/// Options for using a proxy for S3
+struct ARROW_EXPORT S3ProxyOptions {
+ std::string scheme;
+ std::string host;
+ int port = -1;
+ std::string username;
+ std::string password;
+
+ /// Initialize from URI such as http://username:password@host:port
+ /// or http://host:port
+ static Result<S3ProxyOptions> FromUri(const std::string& uri);
+ static Result<S3ProxyOptions> FromUri(const ::arrow::internal::Uri& uri);
+
+ bool Equals(const S3ProxyOptions& other) const;
+};
+
+enum class S3CredentialsKind : int8_t {
+ /// Anonymous access (no credentials used)
+ Anonymous,
+ /// Use default AWS credentials, configured through environment variables
+ Default,
+ /// Use explicitly-provided access key pair
+ Explicit,
+ /// Assume role through a role ARN
+ Role,
+ /// Use web identity token to assume role, configured through environment variables
+ WebIdentity
+};
+
+/// Pure virtual class for describing custom S3 retry strategies
+class S3RetryStrategy {
+ public:
+ virtual ~S3RetryStrategy() = default;
+
+ /// Simple struct where each field corresponds to a field in Aws::Client::AWSError
+ struct AWSErrorDetail {
+ /// Corresponds to AWSError::GetErrorType()
+ int error_type;
+ /// Corresponds to AWSError::GetMessage()
+ std::string message;
+ /// Corresponds to AWSError::GetExceptionName()
+ std::string exception_name;
+ /// Corresponds to AWSError::ShouldRetry()
+ bool should_retry;
+ };
+ /// Returns true if the S3 request resulting in the provided error should be retried.
+ virtual bool ShouldRetry(const AWSErrorDetail& error, int64_t attempted_retries) = 0;
+ /// Returns the time in milliseconds the S3 client should sleep for until retrying.
+ virtual int64_t CalculateDelayBeforeNextRetry(const AWSErrorDetail& error,
+ int64_t attempted_retries) = 0;
+};
+
+/// Options for the S3FileSystem implementation.
+struct ARROW_EXPORT S3Options {
+ /// \brief AWS region to connect to.
+ ///
+ /// If unset, the AWS SDK will choose a default value. The exact algorithm
+ /// depends on the SDK version. Before 1.8, the default is hardcoded
+ /// to "us-east-1". Since 1.8, several heuristics are used to determine
+ /// the region (environment variables, configuration profile, EC2 metadata
+ /// server).
+ std::string region;
+
+ /// If non-empty, override region with a connect string such as "localhost:9000"
+ // XXX perhaps instead take a URL like "http://localhost:9000"?
+ std::string endpoint_override;
+ /// S3 connection transport, default "https"
+ std::string scheme = "https";
+
+ /// ARN of role to assume
+ std::string role_arn;
+ /// Optional identifier for an assumed role session.
+ std::string session_name;
+ /// Optional external idenitifer to pass to STS when assuming a role
+ std::string external_id;
+ /// Frequency (in seconds) to refresh temporary credentials from assumed role
+ int load_frequency;
+
+ /// If connection is through a proxy, set options here
+ S3ProxyOptions proxy_options;
+
+ /// AWS credentials provider
+ std::shared_ptr<Aws::Auth::AWSCredentialsProvider> credentials_provider;
+
+ /// Type of credentials being used. Set along with credentials_provider.
+ S3CredentialsKind credentials_kind = S3CredentialsKind::Default;
+
+ /// Whether OutputStream writes will be issued in the background, without blocking.
+ bool background_writes = true;
+
+ /// \brief Default metadata for OpenOutputStream.
+ ///
+ /// This will be ignored if non-empty metadata is passed to OpenOutputStream.
+ std::shared_ptr<const KeyValueMetadata> default_metadata;
+
+ /// Optional retry strategy to determine which error types should be retried, and the
+ /// delay between retries.
+ std::shared_ptr<S3RetryStrategy> retry_strategy;
+
+ /// Configure with the default AWS credentials provider chain.
+ void ConfigureDefaultCredentials();
+
+ /// Configure with anonymous credentials. This will only let you access public buckets.
+ void ConfigureAnonymousCredentials();
+
+ /// Configure with explicit access and secret key.
+ void ConfigureAccessKey(const std::string& access_key, const std::string& secret_key,
+ const std::string& session_token = "");
+
+ /// Configure with credentials from an assumed role.
+ void ConfigureAssumeRoleCredentials(
+ const std::string& role_arn, const std::string& session_name = "",
+ const std::string& external_id = "", int load_frequency = 900,
+ const std::shared_ptr<Aws::STS::STSClient>& stsClient = NULLPTR);
+
+ /// Configure with credentials from role assumed using a web identitiy token
+ void ConfigureAssumeRoleWithWebIdentityCredentials();
+
+ std::string GetAccessKey() const;
+ std::string GetSecretKey() const;
+ std::string GetSessionToken() const;
+
+ bool Equals(const S3Options& other) const;
+
+ /// \brief Initialize with default credentials provider chain
+ ///
+ /// This is recommended if you use the standard AWS environment variables
+ /// and/or configuration file.
+ static S3Options Defaults();
+
+ /// \brief Initialize with anonymous credentials.
+ ///
+ /// This will only let you access public buckets.
+ static S3Options Anonymous();
+
+ /// \brief Initialize with explicit access and secret key.
+ ///
+ /// Optionally, a session token may also be provided for temporary credentials
+ /// (from STS).
+ static S3Options FromAccessKey(const std::string& access_key,
+ const std::string& secret_key,
+ const std::string& session_token = "");
+
+ /// \brief Initialize from an assumed role.
+ static S3Options FromAssumeRole(
+ const std::string& role_arn, const std::string& session_name = "",
+ const std::string& external_id = "", int load_frequency = 900,
+ const std::shared_ptr<Aws::STS::STSClient>& stsClient = NULLPTR);
+
+ /// \brief Initialize from an assumed role with web-identity.
+ /// Uses the AWS SDK which uses environment variables to
+ /// generate temporary credentials.
+ static S3Options FromAssumeRoleWithWebIdentity();
+
+ static Result<S3Options> FromUri(const ::arrow::internal::Uri& uri,
+ std::string* out_path = NULLPTR);
+ static Result<S3Options> FromUri(const std::string& uri,
+ std::string* out_path = NULLPTR);
+};
+
+/// S3-backed FileSystem implementation.
+///
+/// Some implementation notes:
+/// - buckets are special and the operations available on them may be limited
+/// or more expensive than desired.
+class ARROW_EXPORT S3FileSystem : public FileSystem {
+ public:
+ ~S3FileSystem() override;
+
+ std::string type_name() const override { return "s3"; }
+
+ /// Return the original S3 options when constructing the filesystem
+ S3Options options() const;
+ /// Return the actual region this filesystem connects to
+ std::string region() const;
+
+ bool Equals(const FileSystem& other) const override;
+
+ /// \cond FALSE
+ using FileSystem::GetFileInfo;
+ /// \endcond
+ Result<FileInfo> GetFileInfo(const std::string& path) override;
+ Result<std::vector<FileInfo>> GetFileInfo(const FileSelector& select) override;
+
+ FileInfoGenerator GetFileInfoGenerator(const FileSelector& select) override;
+
+ Status CreateDir(const std::string& path, bool recursive = true) override;
+
+ Status DeleteDir(const std::string& path) override;
+ Status DeleteDirContents(const std::string& path) override;
+ Status DeleteRootDirContents() override;
+
+ Status DeleteFile(const std::string& path) override;
+
+ Status Move(const std::string& src, const std::string& dest) override;
+
+ Status CopyFile(const std::string& src, const std::string& dest) override;
+
+ /// Create a sequential input stream for reading from a S3 object.
+ ///
+ /// NOTE: Reads from the stream will be synchronous and unbuffered.
+ /// You way want to wrap the stream in a BufferedInputStream or use
+ /// a custom readahead strategy to avoid idle waits.
+ Result<std::shared_ptr<io::InputStream>> OpenInputStream(
+ const std::string& path) override;
+ /// Create a sequential input stream for reading from a S3 object.
+ ///
+ /// This override avoids a HEAD request by assuming the FileInfo
+ /// contains correct information.
+ Result<std::shared_ptr<io::InputStream>> OpenInputStream(const FileInfo& info) override;
+
+ /// Create a random access file for reading from a S3 object.
+ ///
+ /// See OpenInputStream for performance notes.
+ Result<std::shared_ptr<io::RandomAccessFile>> OpenInputFile(
+ const std::string& path) override;
+ /// Create a random access file for reading from a S3 object.
+ ///
+ /// This override avoids a HEAD request by assuming the FileInfo
+ /// contains correct information.
+ Result<std::shared_ptr<io::RandomAccessFile>> OpenInputFile(
+ const FileInfo& info) override;
+
+ /// Create a sequential output stream for writing to a S3 object.
+ ///
+ /// NOTE: Writes to the stream will be buffered. Depending on
+ /// S3Options.background_writes, they can be synchronous or not.
+ /// It is recommended to enable background_writes unless you prefer
+ /// implementing your own background execution strategy.
+ Result<std::shared_ptr<io::OutputStream>> OpenOutputStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = {}) override;
+
+ Result<std::shared_ptr<io::OutputStream>> OpenAppendStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = {}) override;
+
+ /// Create a S3FileSystem instance from the given options.
+ static Result<std::shared_ptr<S3FileSystem>> Make(
+ const S3Options& options, const io::IOContext& = io::default_io_context());
+
+ protected:
+ explicit S3FileSystem(const S3Options& options, const io::IOContext&);
+
+ class Impl;
+ std::shared_ptr<Impl> impl_;
+};
+
+enum class S3LogLevel : int8_t { Off, Fatal, Error, Warn, Info, Debug, Trace };
+
+struct ARROW_EXPORT S3GlobalOptions {
+ S3LogLevel log_level;
+};
+
+/// Initialize the S3 APIs. It is required to call this function at least once
+/// before using S3FileSystem.
+ARROW_EXPORT
+Status InitializeS3(const S3GlobalOptions& options);
+
+/// Ensure the S3 APIs are initialized, but only if not already done.
+/// If necessary, this will call InitializeS3() with some default options.
+ARROW_EXPORT
+Status EnsureS3Initialized();
+
+/// Shutdown the S3 APIs.
+ARROW_EXPORT
+Status FinalizeS3();
+
+ARROW_EXPORT
+Result<std::string> ResolveBucketRegion(const std::string& bucket);
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/s3fs_benchmark.cc b/src/arrow/cpp/src/arrow/filesystem/s3fs_benchmark.cc
new file mode 100644
index 000000000..869601b84
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/s3fs_benchmark.cc
@@ -0,0 +1,432 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <sstream>
+#include <utility>
+
+#include "benchmark/benchmark.h"
+
+#include <aws/core/auth/AWSCredentials.h>
+#include <aws/s3/S3Client.h>
+#include <aws/s3/model/CreateBucketRequest.h>
+#include <aws/s3/model/HeadBucketRequest.h>
+#include <aws/s3/model/PutObjectRequest.h>
+
+#include "arrow/filesystem/s3_internal.h"
+#include "arrow/filesystem/s3_test_util.h"
+#include "arrow/filesystem/s3fs.h"
+#include "arrow/io/caching.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/range.h"
+
+#include "parquet/arrow/reader.h"
+#include "parquet/arrow/writer.h"
+#include "parquet/properties.h"
+
+namespace arrow {
+namespace fs {
+
+using ::arrow::fs::internal::ConnectRetryStrategy;
+using ::arrow::fs::internal::OutcomeToStatus;
+using ::arrow::fs::internal::ToAwsString;
+
+// Environment variables to configure the S3 test environment
+static const char* kEnvBucketName = "ARROW_TEST_S3_BUCKET";
+static const char* kEnvSkipSetup = "ARROW_TEST_S3_SKIP_SETUP";
+static const char* kEnvAwsRegion = "ARROW_TEST_S3_REGION";
+
+// Set up Minio and create the test bucket and files.
+class MinioFixture : public benchmark::Fixture {
+ public:
+ void SetUp(const ::benchmark::State& state) override {
+ minio_.reset(new MinioTestServer());
+ ASSERT_OK(minio_->Start());
+
+ const char* region_str = std::getenv(kEnvAwsRegion);
+ if (region_str) {
+ region_ = region_str;
+ std::cerr << "Using region from environment: " << region_ << std::endl;
+ } else {
+ std::cerr << "Using default region" << std::endl;
+ }
+
+ const char* bucket_str = std::getenv(kEnvBucketName);
+ if (bucket_str) {
+ bucket_ = bucket_str;
+ std::cerr << "Using bucket from environment: " << bucket_ << std::endl;
+ } else {
+ bucket_ = "bucket";
+ std::cerr << "Using default bucket: " << bucket_ << std::endl;
+ }
+
+ client_config_.endpointOverride = ToAwsString(minio_->connect_string());
+ client_config_.scheme = Aws::Http::Scheme::HTTP;
+ if (!region_.empty()) {
+ client_config_.region = ToAwsString(region_);
+ }
+ client_config_.retryStrategy = std::make_shared<ConnectRetryStrategy>();
+ credentials_ = {ToAwsString(minio_->access_key()), ToAwsString(minio_->secret_key())};
+ bool use_virtual_addressing = false;
+ client_.reset(
+ new Aws::S3::S3Client(credentials_, client_config_,
+ Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
+ use_virtual_addressing));
+
+ MakeFileSystem();
+
+ const char* skip_str = std::getenv(kEnvSkipSetup);
+ const std::string skip = skip_str ? std::string(skip_str) : "";
+ if (!skip.empty()) {
+ std::cerr << "Skipping creation of bucket/objects as requested" << std::endl;
+ } else {
+ ASSERT_OK(MakeBucket());
+ ASSERT_OK(MakeObject("bytes_1mib", 1024 * 1024));
+ ASSERT_OK(MakeObject("bytes_100mib", 100 * 1024 * 1024));
+ ASSERT_OK(MakeObject("bytes_500mib", 500 * 1024 * 1024));
+ ASSERT_OK(MakeParquetObject(bucket_ + "/pq_c402_r250k", 400, 250000));
+ }
+ }
+
+ void MakeFileSystem() {
+ options_.ConfigureAccessKey(minio_->access_key(), minio_->secret_key());
+ options_.scheme = "http";
+ if (!region_.empty()) {
+ options_.region = region_;
+ }
+ options_.endpoint_override = minio_->connect_string();
+ ASSERT_OK_AND_ASSIGN(fs_, S3FileSystem::Make(options_));
+ }
+
+ /// Set up bucket if it doesn't exist.
+ ///
+ /// When using Minio we'll have a fresh setup each time, but
+ /// otherwise we may have a leftover bucket.
+ Status MakeBucket() {
+ Aws::S3::Model::HeadBucketRequest head;
+ head.SetBucket(ToAwsString(bucket_));
+ const Status st = OutcomeToStatus(client_->HeadBucket(head));
+ if (st.ok()) {
+ // Bucket exists already
+ return st;
+ }
+ Aws::S3::Model::CreateBucketRequest req;
+ req.SetBucket(ToAwsString(bucket_));
+ return OutcomeToStatus(client_->CreateBucket(req));
+ }
+
+ /// Make an object with dummy data.
+ Status MakeObject(const std::string& name, int size) {
+ Aws::S3::Model::PutObjectRequest req;
+ req.SetBucket(ToAwsString(bucket_));
+ req.SetKey(ToAwsString(name));
+ req.SetBody(std::make_shared<std::stringstream>(std::string(size, 'a')));
+ return OutcomeToStatus(client_->PutObject(req));
+ }
+
+ /// Make an object with Parquet data.
+ /// Appends integer columns to the beginning (to act as indices).
+ Status MakeParquetObject(const std::string& path, int num_columns, int num_rows) {
+ std::vector<std::shared_ptr<ChunkedArray>> columns;
+ FieldVector fields{
+ field("timestamp", int64(), /*nullable=*/true,
+ key_value_metadata(
+ {{"min", "0"}, {"max", "10000000000"}, {"null_probability", "0"}})),
+ field("val", int32(), /*nullable=*/true,
+ key_value_metadata(
+ {{"min", "0"}, {"max", "1000000000"}, {"null_probability", "0"}}))};
+ for (int i = 0; i < num_columns; i++) {
+ std::stringstream ss;
+ ss << "col" << i;
+ fields.push_back(
+ field(ss.str(), float64(), /*nullable=*/true,
+ key_value_metadata(
+ {{"min", "-1.e10"}, {"max", "1e10"}, {"null_probability", "0"}})));
+ }
+ auto batch = random::GenerateBatch(fields, num_rows, 0);
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Table> table,
+ Table::FromRecordBatches({batch}));
+
+ std::shared_ptr<io::OutputStream> sink;
+ ARROW_ASSIGN_OR_RAISE(sink, fs_->OpenOutputStream(path));
+ RETURN_NOT_OK(
+ parquet::arrow::WriteTable(*table, arrow::default_memory_pool(), sink, num_rows));
+
+ return Status::OK();
+ }
+
+ void TearDown(const ::benchmark::State& state) override {
+ ASSERT_OK(minio_->Stop());
+ // Delete temporary directory, freeing up disk space
+ minio_.reset();
+ }
+
+ protected:
+ std::unique_ptr<MinioTestServer> minio_;
+ std::string region_;
+ std::string bucket_;
+ Aws::Client::ClientConfiguration client_config_;
+ Aws::Auth::AWSCredentials credentials_;
+ std::unique_ptr<Aws::S3::S3Client> client_;
+ S3Options options_;
+ std::shared_ptr<S3FileSystem> fs_;
+};
+
+/// Set up/tear down the AWS SDK globally.
+/// (GBenchmark doesn't run GTest environments.)
+class S3BenchmarkEnvironment {
+ public:
+ S3BenchmarkEnvironment() { s3_env->SetUp(); }
+ ~S3BenchmarkEnvironment() { s3_env->TearDown(); }
+};
+
+S3BenchmarkEnvironment env{};
+
+/// Read the entire file into memory in one go to measure bandwidth.
+static void NaiveRead(benchmark::State& st, S3FileSystem* fs, const std::string& path) {
+ int64_t total_bytes = 0;
+ int total_items = 0;
+ for (auto _ : st) {
+ std::shared_ptr<io::RandomAccessFile> file;
+ std::shared_ptr<Buffer> buf;
+ int64_t size;
+ ASSERT_OK_AND_ASSIGN(file, fs->OpenInputFile(path));
+ ASSERT_OK_AND_ASSIGN(size, file->GetSize());
+ ASSERT_OK_AND_ASSIGN(buf, file->ReadAt(0, size));
+ total_bytes += buf->size();
+ total_items += 1;
+ }
+ st.SetBytesProcessed(total_bytes);
+ st.SetItemsProcessed(total_items);
+ std::cerr << "Read the file " << total_items << " times" << std::endl;
+}
+
+constexpr int64_t kChunkSize = 5 * 1024 * 1024;
+
+/// Mimic the Parquet reader, reading the file in small chunks.
+static void ChunkedRead(benchmark::State& st, S3FileSystem* fs, const std::string& path) {
+ int64_t total_bytes = 0;
+ int total_items = 0;
+ for (auto _ : st) {
+ std::shared_ptr<io::RandomAccessFile> file;
+ std::shared_ptr<Buffer> buf;
+ int64_t size = 0;
+ ASSERT_OK_AND_ASSIGN(file, fs->OpenInputFile(path));
+ ASSERT_OK_AND_ASSIGN(size, file->GetSize());
+ total_items += 1;
+
+ int64_t offset = 0;
+ while (offset < size) {
+ const int64_t read = std::min(size, kChunkSize);
+ ASSERT_OK_AND_ASSIGN(buf, file->ReadAt(offset, read));
+ total_bytes += buf->size();
+ offset += buf->size();
+ }
+ }
+ st.SetBytesProcessed(total_bytes);
+ st.SetItemsProcessed(total_items);
+ std::cerr << "Read the file " << total_items << " times" << std::endl;
+}
+
+/// Read the file in small chunks, but using read coalescing.
+static void CoalescedRead(benchmark::State& st, S3FileSystem* fs,
+ const std::string& path) {
+ int64_t total_bytes = 0;
+ int total_items = 0;
+ for (auto _ : st) {
+ std::shared_ptr<io::RandomAccessFile> file;
+ std::shared_ptr<Buffer> buf;
+ int64_t size = 0;
+ ASSERT_OK_AND_ASSIGN(file, fs->OpenInputFile(path));
+ ASSERT_OK_AND_ASSIGN(size, file->GetSize());
+ total_items += 1;
+
+ io::internal::ReadRangeCache cache(
+ file, {},
+ io::CacheOptions{/*hole_size_limit=*/8192, /*range_size_limit=*/64 * 1024 * 1024,
+ /*lazy=*/false});
+ std::vector<io::ReadRange> ranges;
+
+ int64_t offset = 0;
+ while (offset < size) {
+ const int64_t read = std::min(size, kChunkSize);
+ ranges.push_back(io::ReadRange{offset, read});
+ offset += read;
+ }
+ ASSERT_OK(cache.Cache(ranges));
+
+ offset = 0;
+ while (offset < size) {
+ const int64_t read = std::min(size, kChunkSize);
+ ASSERT_OK_AND_ASSIGN(buf, cache.Read({offset, read}));
+ total_bytes += buf->size();
+ offset += read;
+ }
+ }
+ st.SetBytesProcessed(total_bytes);
+ st.SetItemsProcessed(total_items);
+ std::cerr << "Read the file " << total_items << " times" << std::endl;
+}
+
+/// Read a Parquet file from S3.
+static void ParquetRead(benchmark::State& st, S3FileSystem* fs, const std::string& path,
+ std::vector<int> column_indices, bool pre_buffer,
+ std::string read_strategy) {
+ int64_t total_bytes = 0;
+ int total_items = 0;
+
+ parquet::ArrowReaderProperties properties;
+ properties.set_use_threads(true);
+ properties.set_pre_buffer(pre_buffer);
+ parquet::ReaderProperties parquet_properties = parquet::default_reader_properties();
+
+ for (auto _ : st) {
+ std::shared_ptr<io::RandomAccessFile> file;
+ int64_t size = 0;
+ ASSERT_OK_AND_ASSIGN(file, fs->OpenInputFile(path));
+ ASSERT_OK_AND_ASSIGN(size, file->GetSize());
+
+ std::unique_ptr<parquet::arrow::FileReader> reader;
+ parquet::arrow::FileReaderBuilder builder;
+ ASSERT_OK(builder.Open(file, parquet_properties));
+ ASSERT_OK(builder.properties(properties)->Build(&reader));
+
+ std::shared_ptr<Table> table;
+
+ if (read_strategy == "ReadTable") {
+ ASSERT_OK(reader->ReadTable(column_indices, &table));
+ } else {
+ std::shared_ptr<RecordBatchReader> rb_reader;
+ ASSERT_OK(reader->GetRecordBatchReader({0}, column_indices, &rb_reader));
+ ASSERT_OK(rb_reader->ReadAll(&table));
+ }
+
+ // TODO: actually measure table memory usage
+ total_bytes += size;
+ total_items += 1;
+ }
+ st.SetBytesProcessed(total_bytes);
+ st.SetItemsProcessed(total_items);
+}
+
+/// Helper function used in the macros below to template benchmarks.
+static void ParquetReadAll(benchmark::State& st, S3FileSystem* fs,
+ const std::string& bucket, int64_t file_rows,
+ int64_t file_cols, bool pre_buffer,
+ std::string read_strategy) {
+ std::vector<int> column_indices(file_cols);
+ std::iota(column_indices.begin(), column_indices.end(), 0);
+ std::stringstream ss;
+ ss << bucket << "/pq_c" << file_cols << "_r" << file_rows << "k";
+ ParquetRead(st, fs, ss.str(), column_indices, false, read_strategy);
+}
+
+/// Helper function used in the macros below to template benchmarks.
+static void ParquetReadSome(benchmark::State& st, S3FileSystem* fs,
+ const std::string& bucket, int64_t file_rows,
+ int64_t file_cols, std::vector<int> cols_to_read,
+ bool pre_buffer, std::string read_strategy) {
+ std::stringstream ss;
+ ss << bucket << "/pq_c" << file_cols << "_r" << file_rows << "k";
+ ParquetRead(st, fs, ss.str(), cols_to_read, false, read_strategy);
+}
+
+BENCHMARK_DEFINE_F(MinioFixture, ReadAll1Mib)(benchmark::State& st) {
+ NaiveRead(st, fs_.get(), bucket_ + "/bytes_1mib");
+}
+BENCHMARK_REGISTER_F(MinioFixture, ReadAll1Mib)->UseRealTime();
+BENCHMARK_DEFINE_F(MinioFixture, ReadAll100Mib)(benchmark::State& st) {
+ NaiveRead(st, fs_.get(), bucket_ + "/bytes_100mib");
+}
+BENCHMARK_REGISTER_F(MinioFixture, ReadAll100Mib)->UseRealTime();
+BENCHMARK_DEFINE_F(MinioFixture, ReadAll500Mib)(benchmark::State& st) {
+ NaiveRead(st, fs_.get(), bucket_ + "/bytes_500mib");
+}
+BENCHMARK_REGISTER_F(MinioFixture, ReadAll500Mib)->UseRealTime();
+
+BENCHMARK_DEFINE_F(MinioFixture, ReadChunked100Mib)(benchmark::State& st) {
+ ChunkedRead(st, fs_.get(), bucket_ + "/bytes_100mib");
+}
+BENCHMARK_REGISTER_F(MinioFixture, ReadChunked100Mib)->UseRealTime();
+BENCHMARK_DEFINE_F(MinioFixture, ReadChunked500Mib)(benchmark::State& st) {
+ ChunkedRead(st, fs_.get(), bucket_ + "/bytes_500mib");
+}
+BENCHMARK_REGISTER_F(MinioFixture, ReadChunked500Mib)->UseRealTime();
+
+BENCHMARK_DEFINE_F(MinioFixture, ReadCoalesced100Mib)(benchmark::State& st) {
+ CoalescedRead(st, fs_.get(), bucket_ + "/bytes_100mib");
+}
+BENCHMARK_REGISTER_F(MinioFixture, ReadCoalesced100Mib)->UseRealTime();
+BENCHMARK_DEFINE_F(MinioFixture, ReadCoalesced500Mib)(benchmark::State& st) {
+ CoalescedRead(st, fs_.get(), bucket_ + "/bytes_500mib");
+}
+BENCHMARK_REGISTER_F(MinioFixture, ReadCoalesced500Mib)->UseRealTime();
+
+// Helpers to generate various multiple benchmarks for a given Parquet file.
+
+// NAME: the base name of the benchmark.
+// ROWS: the number of rows in the Parquet file.
+// COLS: the number of columns in the Parquet file.
+// STRATEGY: how to read the file (ReadTable or GetRecordBatchReader)
+#define PQ_BENCHMARK_IMPL(NAME, ROWS, COLS, STRATEGY) \
+ BENCHMARK_DEFINE_F(MinioFixture, NAME##STRATEGY##AllNaive)(benchmark::State & st) { \
+ ParquetReadAll(st, fs_.get(), bucket_, ROWS, COLS, false, #STRATEGY); \
+ } \
+ BENCHMARK_REGISTER_F(MinioFixture, NAME##STRATEGY##AllNaive)->UseRealTime(); \
+ BENCHMARK_DEFINE_F(MinioFixture, NAME##STRATEGY##AllCoalesced) \
+ (benchmark::State & st) { \
+ ParquetReadAll(st, fs_.get(), bucket_, ROWS, COLS, true, #STRATEGY); \
+ } \
+ BENCHMARK_REGISTER_F(MinioFixture, NAME##STRATEGY##AllCoalesced)->UseRealTime();
+
+// COL_INDICES: a vector specifying a subset of column indices to read.
+#define PQ_BENCHMARK_PICK_IMPL(NAME, ROWS, COLS, COL_INDICES, STRATEGY) \
+ BENCHMARK_DEFINE_F(MinioFixture, NAME##STRATEGY##PickNaive)(benchmark::State & st) { \
+ ParquetReadSome(st, fs_.get(), bucket_, ROWS, COLS, COL_INDICES, false, #STRATEGY); \
+ } \
+ BENCHMARK_REGISTER_F(MinioFixture, NAME##STRATEGY##PickNaive)->UseRealTime(); \
+ BENCHMARK_DEFINE_F(MinioFixture, NAME##STRATEGY##PickCoalesced) \
+ (benchmark::State & st) { \
+ ParquetReadSome(st, fs_.get(), bucket_, ROWS, COLS, COL_INDICES, true, #STRATEGY); \
+ } \
+ BENCHMARK_REGISTER_F(MinioFixture, NAME##STRATEGY##PickCoalesced)->UseRealTime();
+
+#define PQ_BENCHMARK(ROWS, COLS) \
+ PQ_BENCHMARK_IMPL(ReadParquet_c##COLS##_r##ROWS##K_, ROWS, COLS, \
+ GetRecordBatchReader); \
+ PQ_BENCHMARK_IMPL(ReadParquet_c##COLS##_r##ROWS##K_, ROWS, COLS, ReadTable);
+
+#define PQ_BENCHMARK_PICK(NAME, ROWS, COLS, COL_INDICES) \
+ PQ_BENCHMARK_PICK_IMPL(ReadParquet_c##COLS##_r##ROWS##K_##NAME##_, ROWS, COLS, \
+ COL_INDICES, GetRecordBatchReader); \
+ PQ_BENCHMARK_PICK_IMPL(ReadParquet_c##COLS##_r##ROWS##K_##NAME##_, ROWS, COLS, \
+ COL_INDICES, ReadTable);
+
+// Test a Parquet file with 250k rows, 402 columns.
+PQ_BENCHMARK(250, 402);
+// Scenario A: test selecting a small set of contiguous columns, and a "far" column.
+PQ_BENCHMARK_PICK(A, 250, 402, (std::vector<int>{0, 1, 2, 3, 4, 90}));
+// Scenario B: test selecting a large set of contiguous columns.
+PQ_BENCHMARK_PICK(B, 250, 402, (::arrow::internal::Iota(41)));
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/s3fs_narrative_test.cc b/src/arrow/cpp/src/arrow/filesystem/s3fs_narrative_test.cc
new file mode 100644
index 000000000..f65ccba02
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/s3fs_narrative_test.cc
@@ -0,0 +1,245 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// "Narrative" test for S3. This must be run manually against a S3 endpoint.
+// The test bucket must exist and be empty (you can use --clear to delete its
+// contents).
+
+#include <iostream>
+#include <sstream>
+#include <string>
+
+#include <gflags/gflags.h>
+
+#include "arrow/filesystem/s3fs.h"
+#include "arrow/filesystem/test_util.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/logging.h"
+
+DEFINE_bool(clear, false, "delete all bucket contents");
+DEFINE_bool(test, false, "run narrative test against bucket");
+
+DEFINE_bool(verbose, false, "be more verbose (includes AWS warnings)");
+DEFINE_bool(debug, false, "be very verbose (includes AWS debug logs)");
+
+DEFINE_string(access_key, "", "S3 access key");
+DEFINE_string(secret_key, "", "S3 secret key");
+
+DEFINE_string(bucket, "", "bucket name");
+DEFINE_string(region, "", "AWS region");
+DEFINE_string(endpoint, "", "Endpoint override (e.g. '127.0.0.1:9000')");
+DEFINE_string(scheme, "https", "Connection scheme");
+
+namespace arrow {
+namespace fs {
+
+#define ASSERT_RAISES_PRINT(context_msg, error_type, expr) \
+ do { \
+ auto _status_or_result = (expr); \
+ ASSERT_RAISES(error_type, _status_or_result); \
+ PrintError(context_msg, _status_or_result); \
+ } while (0)
+
+std::shared_ptr<FileSystem> MakeFileSystem() {
+ std::shared_ptr<S3FileSystem> s3fs;
+ S3Options options;
+ if (!FLAGS_access_key.empty()) {
+ options = S3Options::FromAccessKey(FLAGS_access_key, FLAGS_secret_key);
+ } else {
+ options = S3Options::Defaults();
+ }
+ options.endpoint_override = FLAGS_endpoint;
+ options.scheme = FLAGS_scheme;
+ options.region = FLAGS_region;
+ s3fs = S3FileSystem::Make(options).ValueOrDie();
+ return std::make_shared<SubTreeFileSystem>(FLAGS_bucket, s3fs);
+}
+
+void PrintError(const std::string& context_msg, const Status& st) {
+ if (FLAGS_verbose) {
+ std::cout << "-- Error printout (" << context_msg << ") --\n"
+ << st.ToString() << std::endl;
+ }
+}
+
+template <typename T>
+void PrintError(const std::string& context_msg, const Result<T>& result) {
+ PrintError(context_msg, result.status());
+}
+
+void ClearBucket(int argc, char** argv) {
+ auto fs = MakeFileSystem();
+
+ ASSERT_OK(fs->DeleteRootDirContents());
+}
+
+void TestBucket(int argc, char** argv) {
+ auto fs = MakeFileSystem();
+ std::vector<FileInfo> infos;
+ FileSelector select;
+ std::shared_ptr<io::InputStream> is;
+ std::shared_ptr<io::RandomAccessFile> file;
+ std::shared_ptr<Buffer> buf;
+ Status status;
+
+ // Check bucket exists and is empty
+ select.base_dir = "";
+ select.allow_not_found = false;
+ select.recursive = false;
+ ASSERT_OK_AND_ASSIGN(infos, fs->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 0) << "Bucket should be empty, perhaps use --clear?";
+
+ // Create directory structure
+ ASSERT_OK(fs->CreateDir("EmptyDir", /*recursive=*/false));
+ ASSERT_OK(fs->CreateDir("Dir1", /*recursive=*/false));
+ ASSERT_OK(fs->CreateDir("Dir1/Subdir", /*recursive=*/false));
+ ASSERT_RAISES_PRINT("CreateDir in nonexistent parent", IOError,
+ fs->CreateDir("Dir2/Subdir", /*recursive=*/false));
+ ASSERT_OK(fs->CreateDir("Dir2/Subdir", /*recursive=*/true));
+ CreateFile(fs.get(), "File1", "first data");
+ CreateFile(fs.get(), "Dir1/File2", "second data");
+ CreateFile(fs.get(), "Dir2/Subdir/File3", "third data");
+
+ // GetFileInfo(Selector)
+ select.base_dir = "";
+ ASSERT_OK_AND_ASSIGN(infos, fs->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 4);
+ SortInfos(&infos);
+ AssertFileInfo(infos[0], "Dir1", FileType::Directory);
+ AssertFileInfo(infos[1], "Dir2", FileType::Directory);
+ AssertFileInfo(infos[2], "EmptyDir", FileType::Directory);
+ AssertFileInfo(infos[3], "File1", FileType::File, 10);
+
+ select.base_dir = "zzzz";
+ ASSERT_RAISES_PRINT("GetFileInfo(Selector) with nonexisting base_dir", IOError,
+ fs->GetFileInfo(select));
+ select.allow_not_found = true;
+ ASSERT_OK_AND_ASSIGN(infos, fs->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 0);
+
+ select.base_dir = "Dir1";
+ select.allow_not_found = false;
+ ASSERT_OK_AND_ASSIGN(infos, fs->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 2);
+ SortInfos(&infos);
+ AssertFileInfo(infos[0], "Dir1/File2", FileType::File, 11);
+ AssertFileInfo(infos[1], "Dir1/Subdir", FileType::Directory);
+
+ select.base_dir = "Dir2";
+ select.recursive = true;
+ ASSERT_OK_AND_ASSIGN(infos, fs->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 2);
+ SortInfos(&infos);
+ AssertFileInfo(infos[0], "Dir2/Subdir", FileType::Directory);
+ AssertFileInfo(infos[1], "Dir2/Subdir/File3", FileType::File, 10);
+
+ // Read a file
+ ASSERT_RAISES_PRINT("OpenInputStream with nonexistent file", IOError,
+ fs->OpenInputStream("zzz"));
+ ASSERT_OK_AND_ASSIGN(is, fs->OpenInputStream("File1"));
+ ASSERT_OK_AND_ASSIGN(buf, is->Read(5));
+ AssertBufferEqual(*buf, "first");
+ ASSERT_OK_AND_ASSIGN(buf, is->Read(10));
+ AssertBufferEqual(*buf, " data");
+ ASSERT_OK_AND_ASSIGN(buf, is->Read(10));
+ AssertBufferEqual(*buf, "");
+ ASSERT_OK(is->Close());
+
+ ASSERT_OK_AND_ASSIGN(file, fs->OpenInputFile("Dir1/File2"));
+ ASSERT_OK_AND_EQ(0, file->Tell());
+ ASSERT_OK(file->Seek(7));
+ ASSERT_OK_AND_EQ(7, file->Tell());
+ ASSERT_OK_AND_ASSIGN(buf, file->Read(2));
+ AssertBufferEqual(*buf, "da");
+ ASSERT_OK_AND_EQ(9, file->Tell());
+ ASSERT_OK_AND_ASSIGN(buf, file->ReadAt(2, 4));
+ AssertBufferEqual(*buf, "cond");
+ ASSERT_OK(file->Close());
+
+ // Copy a file
+ ASSERT_OK(fs->CopyFile("File1", "Dir2/File4"));
+ AssertFileInfo(fs.get(), "File1", FileType::File, 10);
+ AssertFileInfo(fs.get(), "Dir2/File4", FileType::File, 10);
+ AssertFileContents(fs.get(), "Dir2/File4", "first data");
+
+ // Copy a file over itself
+ ASSERT_OK(fs->CopyFile("File1", "File1"));
+ AssertFileInfo(fs.get(), "File1", FileType::File, 10);
+ AssertFileContents(fs.get(), "File1", "first data");
+
+ // Move a file
+ ASSERT_OK(fs->Move("Dir2/File4", "File5"));
+ AssertFileInfo(fs.get(), "Dir2/File4", FileType::NotFound);
+ AssertFileInfo(fs.get(), "File5", FileType::File, 10);
+ AssertFileContents(fs.get(), "File5", "first data");
+
+ // Move a file over itself
+ ASSERT_OK(fs->Move("File5", "File5"));
+ AssertFileInfo(fs.get(), "File5", FileType::File, 10);
+ AssertFileContents(fs.get(), "File5", "first data");
+}
+
+void TestMain(int argc, char** argv) {
+ S3GlobalOptions options;
+ options.log_level = FLAGS_debug
+ ? S3LogLevel::Debug
+ : (FLAGS_verbose ? S3LogLevel::Warn : S3LogLevel::Fatal);
+ ASSERT_OK(InitializeS3(options));
+
+ if (FLAGS_region.empty()) {
+ ASSERT_OK_AND_ASSIGN(FLAGS_region, ResolveBucketRegion(FLAGS_bucket));
+ }
+
+ if (FLAGS_clear) {
+ ClearBucket(argc, argv);
+ } else if (FLAGS_test) {
+ TestBucket(argc, argv);
+ }
+
+ ASSERT_OK(FinalizeS3());
+}
+
+} // namespace fs
+} // namespace arrow
+
+int main(int argc, char** argv) {
+ std::stringstream ss;
+ ss << "Narrative test for S3. Needs an initialized empty bucket.\n";
+ ss << "Usage: " << argv[0];
+ gflags::SetUsageMessage(ss.str());
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+ if (FLAGS_clear + FLAGS_test != 1) {
+ ARROW_LOG(ERROR) << "Need exactly one of --test and --clear";
+ return 2;
+ }
+ if (FLAGS_bucket.empty()) {
+ ARROW_LOG(ERROR) << "--bucket is mandatory";
+ return 2;
+ }
+
+ arrow::fs::TestMain(argc, argv);
+ if (::testing::Test::HasFatalFailure() || ::testing::Test::HasNonfatalFailure()) {
+ return 1;
+ } else {
+ std::cout << "Ok" << std::endl;
+ return 0;
+ }
+}
diff --git a/src/arrow/cpp/src/arrow/filesystem/s3fs_test.cc b/src/arrow/cpp/src/arrow/filesystem/s3fs_test.cc
new file mode 100644
index 000000000..d7618730f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/s3fs_test.cc
@@ -0,0 +1,1084 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <exception>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+// This boost/asio/io_context.hpp include is needless for no MinGW
+// build.
+//
+// This is for including boost/asio/detail/socket_types.hpp before any
+// "#include <windows.h>". boost/asio/detail/socket_types.hpp doesn't
+// work if windows.h is already included. boost/process.h ->
+// boost/process/args.hpp -> boost/process/detail/basic_cmd.hpp
+// includes windows.h. boost/process/args.hpp is included before
+// boost/process/async.h that includes
+// boost/asio/detail/socket_types.hpp implicitly is included.
+#include <boost/asio/io_context.hpp>
+// We need BOOST_USE_WINDOWS_H definition with MinGW when we use
+// boost/process.hpp. See ARROW_BOOST_PROCESS_COMPILE_DEFINITIONS in
+// cpp/cmake_modules/BuildUtils.cmake for details.
+#include <boost/process.hpp>
+
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
+
+#ifdef _WIN32
+// Undefine preprocessor macros that interfere with AWS function / method names
+#ifdef GetMessage
+#undef GetMessage
+#endif
+#ifdef GetObject
+#undef GetObject
+#endif
+#endif
+
+#include <aws/core/Aws.h>
+#include <aws/core/Version.h>
+#include <aws/core/auth/AWSCredentials.h>
+#include <aws/core/auth/AWSCredentialsProvider.h>
+#include <aws/core/client/CoreErrors.h>
+#include <aws/core/client/RetryStrategy.h>
+#include <aws/core/utils/logging/ConsoleLogSystem.h>
+#include <aws/s3/S3Client.h>
+#include <aws/s3/model/CreateBucketRequest.h>
+#include <aws/s3/model/GetObjectRequest.h>
+#include <aws/s3/model/PutObjectRequest.h>
+#include <aws/sts/STSClient.h>
+
+#include "arrow/filesystem/filesystem.h"
+#include "arrow/filesystem/s3_internal.h"
+#include "arrow/filesystem/s3_test_util.h"
+#include "arrow/filesystem/s3fs.h"
+#include "arrow/filesystem/test_util.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace fs {
+
+using ::arrow::internal::checked_pointer_cast;
+using ::arrow::internal::PlatformFilename;
+using ::arrow::internal::UriEscape;
+
+using ::arrow::fs::internal::ConnectRetryStrategy;
+using ::arrow::fs::internal::ErrorToStatus;
+using ::arrow::fs::internal::OutcomeToStatus;
+using ::arrow::fs::internal::ToAwsString;
+
+namespace bp = boost::process;
+
+// NOTE: Connecting in Python:
+// >>> fs = s3fs.S3FileSystem(key='minio', secret='miniopass',
+// client_kwargs=dict(endpoint_url='http://127.0.0.1:9000'))
+// >>> fs.ls('')
+// ['bucket']
+// or:
+// >>> from fs_s3fs import S3FS
+// >>> fs = S3FS('bucket', endpoint_url='http://127.0.0.1:9000',
+// aws_access_key_id='minio', aws_secret_access_key='miniopass')
+
+#define ARROW_AWS_ASSIGN_OR_FAIL_IMPL(outcome_name, lhs, rexpr) \
+ auto outcome_name = (rexpr); \
+ if (!outcome_name.IsSuccess()) { \
+ FAIL() << "'" ARROW_STRINGIFY(rexpr) "' failed with " \
+ << outcome_name.GetError().GetMessage(); \
+ } \
+ lhs = std::move(outcome_name).GetResultWithOwnership();
+
+#define ARROW_AWS_ASSIGN_OR_FAIL_NAME(x, y) ARROW_CONCAT(x, y)
+
+#define ARROW_AWS_ASSIGN_OR_FAIL(lhs, rexpr) \
+ ARROW_AWS_ASSIGN_OR_FAIL_IMPL( \
+ ARROW_AWS_ASSIGN_OR_FAIL_NAME(_aws_error_or_value, __COUNTER__), lhs, rexpr);
+
+class AwsTestMixin : public ::testing::Test {
+ public:
+ // We set this environment variable to speed up tests by ensuring
+ // DefaultAWSCredentialsProviderChain does not query (inaccessible)
+ // EC2 metadata endpoint
+ AwsTestMixin() : ec2_metadata_disabled_guard_("AWS_EC2_METADATA_DISABLED", "true") {}
+
+ void SetUp() override {
+#ifdef AWS_CPP_SDK_S3_NOT_SHARED
+ auto aws_log_level = Aws::Utils::Logging::LogLevel::Fatal;
+ aws_options_.loggingOptions.logLevel = aws_log_level;
+ aws_options_.loggingOptions.logger_create_fn = [&aws_log_level] {
+ return std::make_shared<Aws::Utils::Logging::ConsoleLogSystem>(aws_log_level);
+ };
+ Aws::InitAPI(aws_options_);
+#endif
+ }
+
+ void TearDown() override {
+#ifdef AWS_CPP_SDK_S3_NOT_SHARED
+ Aws::ShutdownAPI(aws_options_);
+#endif
+ }
+
+ private:
+ EnvVarGuard ec2_metadata_disabled_guard_;
+#ifdef AWS_CPP_SDK_S3_NOT_SHARED
+ Aws::SDKOptions aws_options_;
+#endif
+};
+
+class S3TestMixin : public AwsTestMixin {
+ public:
+ void SetUp() override {
+ AwsTestMixin::SetUp();
+
+ ASSERT_OK(minio_.Start());
+
+ client_config_.reset(new Aws::Client::ClientConfiguration());
+ client_config_->endpointOverride = ToAwsString(minio_.connect_string());
+ client_config_->scheme = Aws::Http::Scheme::HTTP;
+ client_config_->retryStrategy = std::make_shared<ConnectRetryStrategy>();
+ credentials_ = {ToAwsString(minio_.access_key()), ToAwsString(minio_.secret_key())};
+ bool use_virtual_addressing = false;
+ client_.reset(
+ new Aws::S3::S3Client(credentials_, *client_config_,
+ Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
+ use_virtual_addressing));
+ }
+
+ void TearDown() override {
+ ASSERT_OK(minio_.Stop());
+
+ AwsTestMixin::TearDown();
+ }
+
+ protected:
+ MinioTestServer minio_;
+ std::unique_ptr<Aws::Client::ClientConfiguration> client_config_;
+ Aws::Auth::AWSCredentials credentials_;
+ std::unique_ptr<Aws::S3::S3Client> client_;
+};
+
+void AssertGetObject(Aws::S3::Model::GetObjectResult& result,
+ const std::string& expected) {
+ auto length = static_cast<int64_t>(expected.length());
+ ASSERT_EQ(result.GetContentLength(), length);
+ auto& stream = result.GetBody();
+ std::string actual;
+ actual.resize(length + 1);
+ stream.read(&actual[0], length + 1);
+ ASSERT_EQ(stream.gcount(), length); // EOF was reached before length + 1
+ actual.resize(length);
+ ASSERT_EQ(actual.size(), expected.size());
+ ASSERT_TRUE(actual == expected); // Avoid ASSERT_EQ on large data
+}
+
+void AssertObjectContents(Aws::S3::S3Client* client, const std::string& bucket,
+ const std::string& key, const std::string& expected) {
+ Aws::S3::Model::GetObjectRequest req;
+ req.SetBucket(ToAwsString(bucket));
+ req.SetKey(ToAwsString(key));
+ ARROW_AWS_ASSIGN_OR_FAIL(auto result, client->GetObject(req));
+ AssertGetObject(result, expected);
+}
+
+////////////////////////////////////////////////////////////////////////////
+// S3Options tests
+
+class S3OptionsTest : public AwsTestMixin {};
+
+TEST_F(S3OptionsTest, FromUri) {
+ std::string path;
+ S3Options options;
+
+ ASSERT_OK_AND_ASSIGN(options, S3Options::FromUri("s3://", &path));
+ ASSERT_EQ(options.region, "");
+ ASSERT_EQ(options.scheme, "https");
+ ASSERT_EQ(options.endpoint_override, "");
+ ASSERT_EQ(path, "");
+
+ ASSERT_OK_AND_ASSIGN(options, S3Options::FromUri("s3:", &path));
+ ASSERT_EQ(path, "");
+
+ ASSERT_OK_AND_ASSIGN(options, S3Options::FromUri("s3://access:secret@mybucket", &path));
+ ASSERT_EQ(path, "mybucket");
+ const auto creds = options.credentials_provider->GetAWSCredentials();
+ ASSERT_EQ(creds.GetAWSAccessKeyId(), "access");
+ ASSERT_EQ(creds.GetAWSSecretKey(), "secret");
+
+ ASSERT_OK_AND_ASSIGN(options, S3Options::FromUri("s3://mybucket/", &path));
+ ASSERT_NE(options.region, ""); // Some region was chosen
+ ASSERT_EQ(options.scheme, "https");
+ ASSERT_EQ(options.endpoint_override, "");
+ ASSERT_EQ(path, "mybucket");
+
+ ASSERT_OK_AND_ASSIGN(options, S3Options::FromUri("s3://mybucket/foo/bar/", &path));
+ ASSERT_NE(options.region, "");
+ ASSERT_EQ(options.scheme, "https");
+ ASSERT_EQ(options.endpoint_override, "");
+ ASSERT_EQ(path, "mybucket/foo/bar");
+
+ // Region resolution with a well-known bucket
+ ASSERT_OK_AND_ASSIGN(
+ options, S3Options::FromUri("s3://aws-earth-mo-atmospheric-ukv-prd/", &path));
+ ASSERT_EQ(options.region, "eu-west-2");
+
+ // Explicit region override
+ ASSERT_OK_AND_ASSIGN(
+ options,
+ S3Options::FromUri(
+ "s3://mybucket/foo/bar/?region=utopia&endpoint_override=localhost&scheme=http",
+ &path));
+ ASSERT_EQ(options.region, "utopia");
+ ASSERT_EQ(options.scheme, "http");
+ ASSERT_EQ(options.endpoint_override, "localhost");
+ ASSERT_EQ(path, "mybucket/foo/bar");
+
+ // Missing bucket name
+ ASSERT_RAISES(Invalid, S3Options::FromUri("s3:///foo/bar/", &path));
+
+ // Invalid option
+ ASSERT_RAISES(Invalid, S3Options::FromUri("s3://mybucket/?xxx=zzz", &path));
+}
+
+TEST_F(S3OptionsTest, FromAccessKey) {
+ S3Options options;
+
+ // session token is optional and should default to empty string
+ options = S3Options::FromAccessKey("access", "secret");
+ ASSERT_EQ(options.GetAccessKey(), "access");
+ ASSERT_EQ(options.GetSecretKey(), "secret");
+ ASSERT_EQ(options.GetSessionToken(), "");
+
+ options = S3Options::FromAccessKey("access", "secret", "token");
+ ASSERT_EQ(options.GetAccessKey(), "access");
+ ASSERT_EQ(options.GetSecretKey(), "secret");
+ ASSERT_EQ(options.GetSessionToken(), "token");
+}
+
+TEST_F(S3OptionsTest, FromAssumeRole) {
+ S3Options options;
+
+ // arn should be only required argument
+ options = S3Options::FromAssumeRole("my_role_arn");
+ options = S3Options::FromAssumeRole("my_role_arn", "session");
+ options = S3Options::FromAssumeRole("my_role_arn", "session", "id");
+ options = S3Options::FromAssumeRole("my_role_arn", "session", "id", 42);
+
+ // test w/ custom STSClient (will not use DefaultAWSCredentialsProviderChain)
+ Aws::Auth::AWSCredentials test_creds = Aws::Auth::AWSCredentials("access", "secret");
+ std::shared_ptr<Aws::STS::STSClient> sts_client =
+ std::make_shared<Aws::STS::STSClient>(Aws::STS::STSClient(test_creds));
+ options = S3Options::FromAssumeRole("my_role_arn", "session", "id", 42, sts_client);
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Region resolution test
+
+class S3RegionResolutionTest : public AwsTestMixin {};
+
+TEST_F(S3RegionResolutionTest, PublicBucket) {
+ ASSERT_OK_AND_EQ("us-east-2", ResolveBucketRegion("ursa-labs-taxi-data"));
+
+ // Taken from a registry of open S3-hosted datasets
+ // at https://github.com/awslabs/open-data-registry
+ ASSERT_OK_AND_EQ("eu-west-2", ResolveBucketRegion("aws-earth-mo-atmospheric-ukv-prd"));
+ // Same again, cached
+ ASSERT_OK_AND_EQ("eu-west-2", ResolveBucketRegion("aws-earth-mo-atmospheric-ukv-prd"));
+}
+
+TEST_F(S3RegionResolutionTest, RestrictedBucket) {
+ ASSERT_OK_AND_EQ("us-west-2", ResolveBucketRegion("ursa-labs-r-test"));
+ // Same again, cached
+ ASSERT_OK_AND_EQ("us-west-2", ResolveBucketRegion("ursa-labs-r-test"));
+}
+
+TEST_F(S3RegionResolutionTest, NonExistentBucket) {
+ auto maybe_region = ResolveBucketRegion("ursa-labs-non-existent-bucket");
+ ASSERT_RAISES(IOError, maybe_region);
+ ASSERT_THAT(maybe_region.status().message(),
+ ::testing::HasSubstr("Bucket 'ursa-labs-non-existent-bucket' not found"));
+}
+
+////////////////////////////////////////////////////////////////////////////
+// S3FileSystem region test
+
+class S3FileSystemRegionTest : public AwsTestMixin {};
+
+TEST_F(S3FileSystemRegionTest, Default) {
+ ASSERT_OK_AND_ASSIGN(auto fs, FileSystemFromUri("s3://"));
+ auto s3fs = checked_pointer_cast<S3FileSystem>(fs);
+ ASSERT_EQ(s3fs->region(), "us-east-1");
+}
+
+// Skipped on Windows, as the AWS SDK ignores runtime environment changes:
+// https://github.com/aws/aws-sdk-cpp/issues/1476
+
+#ifndef _WIN32
+TEST_F(S3FileSystemRegionTest, EnvironmentVariable) {
+ // Region override with environment variable (AWS SDK >= 1.8)
+ EnvVarGuard region_guard("AWS_DEFAULT_REGION", "eu-north-1");
+
+ ASSERT_OK_AND_ASSIGN(auto fs, FileSystemFromUri("s3://"));
+ auto s3fs = checked_pointer_cast<S3FileSystem>(fs);
+
+ if (Aws::Version::GetVersionMajor() > 1 || Aws::Version::GetVersionMinor() >= 8) {
+ ASSERT_EQ(s3fs->region(), "eu-north-1");
+ } else {
+ ASSERT_EQ(s3fs->region(), "us-east-1");
+ }
+}
+#endif
+
+////////////////////////////////////////////////////////////////////////////
+// Basic test for the Minio test server.
+
+class TestMinioServer : public S3TestMixin {
+ public:
+ void SetUp() override { S3TestMixin::SetUp(); }
+
+ protected:
+};
+
+TEST_F(TestMinioServer, Connect) {
+ // Just a dummy connection test. Check that we can list buckets,
+ // and that there are none (the server is launched in an empty temp dir).
+ ARROW_AWS_ASSIGN_OR_FAIL(auto bucket_list, client_->ListBuckets());
+ ASSERT_EQ(bucket_list.GetBuckets().size(), 0);
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Concrete S3 tests
+
+class TestS3FS : public S3TestMixin {
+ public:
+ void SetUp() override {
+ S3TestMixin::SetUp();
+ MakeFileSystem();
+ // Set up test bucket
+ {
+ Aws::S3::Model::CreateBucketRequest req;
+ req.SetBucket(ToAwsString("bucket"));
+ ASSERT_OK(OutcomeToStatus(client_->CreateBucket(req)));
+ req.SetBucket(ToAwsString("empty-bucket"));
+ ASSERT_OK(OutcomeToStatus(client_->CreateBucket(req)));
+ }
+ {
+ Aws::S3::Model::PutObjectRequest req;
+ req.SetBucket(ToAwsString("bucket"));
+ req.SetKey(ToAwsString("emptydir/"));
+ ASSERT_OK(OutcomeToStatus(client_->PutObject(req)));
+ // NOTE: no need to create intermediate "directories" somedir/ and
+ // somedir/subdir/
+ req.SetKey(ToAwsString("somedir/subdir/subfile"));
+ req.SetBody(std::make_shared<std::stringstream>("sub data"));
+ ASSERT_OK(OutcomeToStatus(client_->PutObject(req)));
+ req.SetKey(ToAwsString("somefile"));
+ req.SetBody(std::make_shared<std::stringstream>("some data"));
+ req.SetContentType("x-arrow/test");
+ ASSERT_OK(OutcomeToStatus(client_->PutObject(req)));
+ }
+ }
+
+ void MakeFileSystem() {
+ options_.ConfigureAccessKey(minio_.access_key(), minio_.secret_key());
+ options_.scheme = "http";
+ options_.endpoint_override = minio_.connect_string();
+ ASSERT_OK_AND_ASSIGN(fs_, S3FileSystem::Make(options_));
+ }
+
+ template <typename Matcher>
+ void AssertMetadataRoundtrip(const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata,
+ Matcher&& matcher) {
+ ASSERT_OK_AND_ASSIGN(auto output, fs_->OpenOutputStream(path, metadata));
+ ASSERT_OK(output->Close());
+ ASSERT_OK_AND_ASSIGN(auto input, fs_->OpenInputStream(path));
+ ASSERT_OK_AND_ASSIGN(auto got_metadata, input->ReadMetadata());
+ ASSERT_NE(got_metadata, nullptr);
+ ASSERT_THAT(got_metadata->sorted_pairs(), matcher);
+ }
+
+ void TestOpenOutputStream() {
+ std::shared_ptr<io::OutputStream> stream;
+
+ // Nonexistent
+ ASSERT_RAISES(IOError, fs_->OpenOutputStream("nonexistent-bucket/somefile"));
+
+ // Create new empty file
+ ASSERT_OK_AND_ASSIGN(stream, fs_->OpenOutputStream("bucket/newfile1"));
+ ASSERT_OK(stream->Close());
+ AssertObjectContents(client_.get(), "bucket", "newfile1", "");
+
+ // Create new file with 1 small write
+ ASSERT_OK_AND_ASSIGN(stream, fs_->OpenOutputStream("bucket/newfile2"));
+ ASSERT_OK(stream->Write("some data"));
+ ASSERT_OK(stream->Close());
+ AssertObjectContents(client_.get(), "bucket", "newfile2", "some data");
+
+ // Create new file with 3 small writes
+ ASSERT_OK_AND_ASSIGN(stream, fs_->OpenOutputStream("bucket/newfile3"));
+ ASSERT_OK(stream->Write("some "));
+ ASSERT_OK(stream->Write(""));
+ ASSERT_OK(stream->Write("new data"));
+ ASSERT_OK(stream->Close());
+ AssertObjectContents(client_.get(), "bucket", "newfile3", "some new data");
+
+ // Create new file with some large writes
+ std::string s1, s2, s3, s4, s5, expected;
+ s1 = random_string(6000000, /*seed =*/42); // More than the 5 MB minimum part upload
+ s2 = "xxx";
+ s3 = random_string(6000000, 43);
+ s4 = "zzz";
+ s5 = random_string(600000, 44);
+ expected = s1 + s2 + s3 + s4 + s5;
+ ASSERT_OK_AND_ASSIGN(stream, fs_->OpenOutputStream("bucket/newfile4"));
+ for (auto input : {s1, s2, s3, s4, s5}) {
+ ASSERT_OK(stream->Write(input));
+ // Clobber source contents. This shouldn't reflect in the data written.
+ input.front() = 'x';
+ input.back() = 'x';
+ }
+ ASSERT_OK(stream->Close());
+ AssertObjectContents(client_.get(), "bucket", "newfile4", expected);
+
+ // Overwrite
+ ASSERT_OK_AND_ASSIGN(stream, fs_->OpenOutputStream("bucket/newfile1"));
+ ASSERT_OK(stream->Write("overwritten data"));
+ ASSERT_OK(stream->Close());
+ AssertObjectContents(client_.get(), "bucket", "newfile1", "overwritten data");
+
+ // Overwrite and make empty
+ ASSERT_OK_AND_ASSIGN(stream, fs_->OpenOutputStream("bucket/newfile1"));
+ ASSERT_OK(stream->Close());
+ AssertObjectContents(client_.get(), "bucket", "newfile1", "");
+
+ // Open file and then lose filesystem reference
+ ASSERT_EQ(fs_.use_count(), 1); // needed for test to work
+ std::weak_ptr<S3FileSystem> weak_fs(fs_);
+ ASSERT_OK_AND_ASSIGN(stream, fs_->OpenOutputStream("bucket/newfile99"));
+ fs_.reset();
+ ASSERT_OK(stream->Write("some other data"));
+ ASSERT_OK(stream->Close());
+ ASSERT_TRUE(weak_fs.expired());
+ AssertObjectContents(client_.get(), "bucket", "newfile99", "some other data");
+ }
+
+ void TestOpenOutputStreamAbort() {
+ std::shared_ptr<io::OutputStream> stream;
+ ASSERT_OK_AND_ASSIGN(stream, fs_->OpenOutputStream("bucket/somefile"));
+ ASSERT_OK(stream->Write("new data"));
+ // Abort() cancels the multipart upload.
+ ASSERT_OK(stream->Abort());
+ ASSERT_EQ(stream->closed(), true);
+ AssertObjectContents(client_.get(), "bucket", "somefile", "some data");
+ }
+
+ void TestOpenOutputStreamDestructor() {
+ std::shared_ptr<io::OutputStream> stream;
+ ASSERT_OK_AND_ASSIGN(stream, fs_->OpenOutputStream("bucket/somefile"));
+ ASSERT_OK(stream->Write("new data"));
+ // Destructor implicitly closes stream and completes the multipart upload.
+ stream.reset();
+ AssertObjectContents(client_.get(), "bucket", "somefile", "new data");
+ }
+
+ protected:
+ S3Options options_;
+ std::shared_ptr<S3FileSystem> fs_;
+};
+
+TEST_F(TestS3FS, GetFileInfoRoot) { AssertFileInfo(fs_.get(), "", FileType::Directory); }
+
+TEST_F(TestS3FS, GetFileInfoBucket) {
+ AssertFileInfo(fs_.get(), "bucket", FileType::Directory);
+ AssertFileInfo(fs_.get(), "empty-bucket", FileType::Directory);
+ AssertFileInfo(fs_.get(), "nonexistent-bucket", FileType::NotFound);
+ // Trailing slashes
+ AssertFileInfo(fs_.get(), "bucket/", FileType::Directory);
+ AssertFileInfo(fs_.get(), "empty-bucket/", FileType::Directory);
+ AssertFileInfo(fs_.get(), "nonexistent-bucket/", FileType::NotFound);
+}
+
+TEST_F(TestS3FS, GetFileInfoObject) {
+ // "Directories"
+ AssertFileInfo(fs_.get(), "bucket/emptydir", FileType::Directory, kNoSize);
+ AssertFileInfo(fs_.get(), "bucket/somedir", FileType::Directory, kNoSize);
+ AssertFileInfo(fs_.get(), "bucket/somedir/subdir", FileType::Directory, kNoSize);
+
+ // "Files"
+ AssertFileInfo(fs_.get(), "bucket/somefile", FileType::File, 9);
+ AssertFileInfo(fs_.get(), "bucket/somedir/subdir/subfile", FileType::File, 8);
+
+ // Nonexistent
+ AssertFileInfo(fs_.get(), "bucket/emptyd", FileType::NotFound);
+ AssertFileInfo(fs_.get(), "bucket/somed", FileType::NotFound);
+ AssertFileInfo(fs_.get(), "non-existent-bucket/somed", FileType::NotFound);
+
+ // Trailing slashes
+ AssertFileInfo(fs_.get(), "bucket/emptydir/", FileType::Directory, kNoSize);
+ AssertFileInfo(fs_.get(), "bucket/somefile/", FileType::File, 9);
+ AssertFileInfo(fs_.get(), "bucket/emptyd/", FileType::NotFound);
+ AssertFileInfo(fs_.get(), "non-existent-bucket/somed/", FileType::NotFound);
+}
+
+TEST_F(TestS3FS, GetFileInfoSelector) {
+ FileSelector select;
+ std::vector<FileInfo> infos;
+
+ // Root dir
+ select.base_dir = "";
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 2);
+ SortInfos(&infos);
+ AssertFileInfo(infos[0], "bucket", FileType::Directory);
+ AssertFileInfo(infos[1], "empty-bucket", FileType::Directory);
+
+ // Empty bucket
+ select.base_dir = "empty-bucket";
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 0);
+ // Nonexistent bucket
+ select.base_dir = "nonexistent-bucket";
+ ASSERT_RAISES(IOError, fs_->GetFileInfo(select));
+ select.allow_not_found = true;
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 0);
+ select.allow_not_found = false;
+ // Non-empty bucket
+ select.base_dir = "bucket";
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ SortInfos(&infos);
+ ASSERT_EQ(infos.size(), 3);
+ AssertFileInfo(infos[0], "bucket/emptydir", FileType::Directory);
+ AssertFileInfo(infos[1], "bucket/somedir", FileType::Directory);
+ AssertFileInfo(infos[2], "bucket/somefile", FileType::File, 9);
+
+ // Empty "directory"
+ select.base_dir = "bucket/emptydir";
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 0);
+ // Non-empty "directories"
+ select.base_dir = "bucket/somedir";
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 1);
+ AssertFileInfo(infos[0], "bucket/somedir/subdir", FileType::Directory);
+ select.base_dir = "bucket/somedir/subdir";
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 1);
+ AssertFileInfo(infos[0], "bucket/somedir/subdir/subfile", FileType::File, 8);
+ // Nonexistent
+ select.base_dir = "bucket/nonexistent";
+ ASSERT_RAISES(IOError, fs_->GetFileInfo(select));
+ select.allow_not_found = true;
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 0);
+ select.allow_not_found = false;
+
+ // Trailing slashes
+ select.base_dir = "empty-bucket/";
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 0);
+ select.base_dir = "nonexistent-bucket/";
+ ASSERT_RAISES(IOError, fs_->GetFileInfo(select));
+ select.base_dir = "bucket/";
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ SortInfos(&infos);
+ ASSERT_EQ(infos.size(), 3);
+}
+
+TEST_F(TestS3FS, GetFileInfoSelectorRecursive) {
+ FileSelector select;
+ std::vector<FileInfo> infos;
+ select.recursive = true;
+
+ // Root dir
+ select.base_dir = "";
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 7);
+ SortInfos(&infos);
+ AssertFileInfo(infos[0], "bucket", FileType::Directory);
+ AssertFileInfo(infos[1], "bucket/emptydir", FileType::Directory);
+ AssertFileInfo(infos[2], "bucket/somedir", FileType::Directory);
+ AssertFileInfo(infos[3], "bucket/somedir/subdir", FileType::Directory);
+ AssertFileInfo(infos[4], "bucket/somedir/subdir/subfile", FileType::File, 8);
+ AssertFileInfo(infos[5], "bucket/somefile", FileType::File, 9);
+ AssertFileInfo(infos[6], "empty-bucket", FileType::Directory);
+
+ // Empty bucket
+ select.base_dir = "empty-bucket";
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 0);
+
+ // Non-empty bucket
+ select.base_dir = "bucket";
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ SortInfos(&infos);
+ ASSERT_EQ(infos.size(), 5);
+ AssertFileInfo(infos[0], "bucket/emptydir", FileType::Directory);
+ AssertFileInfo(infos[1], "bucket/somedir", FileType::Directory);
+ AssertFileInfo(infos[2], "bucket/somedir/subdir", FileType::Directory);
+ AssertFileInfo(infos[3], "bucket/somedir/subdir/subfile", FileType::File, 8);
+ AssertFileInfo(infos[4], "bucket/somefile", FileType::File, 9);
+
+ // Empty "directory"
+ select.base_dir = "bucket/emptydir";
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 0);
+
+ // Non-empty "directories"
+ select.base_dir = "bucket/somedir";
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ SortInfos(&infos);
+ ASSERT_EQ(infos.size(), 2);
+ AssertFileInfo(infos[0], "bucket/somedir/subdir", FileType::Directory);
+ AssertFileInfo(infos[1], "bucket/somedir/subdir/subfile", FileType::File, 8);
+}
+
+TEST_F(TestS3FS, GetFileInfoGenerator) {
+ FileSelector select;
+ FileInfoVector infos;
+
+ // Root dir
+ select.base_dir = "";
+ CollectFileInfoGenerator(fs_->GetFileInfoGenerator(select), &infos);
+ ASSERT_EQ(infos.size(), 2);
+ SortInfos(&infos);
+ AssertFileInfo(infos[0], "bucket", FileType::Directory);
+ AssertFileInfo(infos[1], "empty-bucket", FileType::Directory);
+
+ // Root dir, recursive
+ select.recursive = true;
+ CollectFileInfoGenerator(fs_->GetFileInfoGenerator(select), &infos);
+ ASSERT_EQ(infos.size(), 7);
+ SortInfos(&infos);
+ AssertFileInfo(infos[0], "bucket", FileType::Directory);
+ AssertFileInfo(infos[1], "bucket/emptydir", FileType::Directory);
+ AssertFileInfo(infos[2], "bucket/somedir", FileType::Directory);
+ AssertFileInfo(infos[3], "bucket/somedir/subdir", FileType::Directory);
+ AssertFileInfo(infos[4], "bucket/somedir/subdir/subfile", FileType::File, 8);
+ AssertFileInfo(infos[5], "bucket/somefile", FileType::File, 9);
+ AssertFileInfo(infos[6], "empty-bucket", FileType::Directory);
+
+ // Non-root dir case is tested by generic tests
+}
+
+TEST_F(TestS3FS, CreateDir) {
+ FileInfo st;
+
+ // Existing bucket
+ ASSERT_OK(fs_->CreateDir("bucket"));
+ AssertFileInfo(fs_.get(), "bucket", FileType::Directory);
+
+ // New bucket
+ AssertFileInfo(fs_.get(), "new-bucket", FileType::NotFound);
+ ASSERT_OK(fs_->CreateDir("new-bucket"));
+ AssertFileInfo(fs_.get(), "new-bucket", FileType::Directory);
+
+ // Existing "directory"
+ AssertFileInfo(fs_.get(), "bucket/somedir", FileType::Directory);
+ ASSERT_OK(fs_->CreateDir("bucket/somedir"));
+ AssertFileInfo(fs_.get(), "bucket/somedir", FileType::Directory);
+
+ AssertFileInfo(fs_.get(), "bucket/emptydir", FileType::Directory);
+ ASSERT_OK(fs_->CreateDir("bucket/emptydir"));
+ AssertFileInfo(fs_.get(), "bucket/emptydir", FileType::Directory);
+
+ // New "directory"
+ AssertFileInfo(fs_.get(), "bucket/newdir", FileType::NotFound);
+ ASSERT_OK(fs_->CreateDir("bucket/newdir"));
+ AssertFileInfo(fs_.get(), "bucket/newdir", FileType::Directory);
+
+ // New "directory", recursive
+ ASSERT_OK(fs_->CreateDir("bucket/newdir/newsub/newsubsub", /*recursive=*/true));
+ AssertFileInfo(fs_.get(), "bucket/newdir/newsub", FileType::Directory);
+ AssertFileInfo(fs_.get(), "bucket/newdir/newsub/newsubsub", FileType::Directory);
+
+ // Existing "file", should fail
+ ASSERT_RAISES(IOError, fs_->CreateDir("bucket/somefile"));
+}
+
+TEST_F(TestS3FS, DeleteFile) {
+ // Bucket
+ ASSERT_RAISES(IOError, fs_->DeleteFile("bucket"));
+ ASSERT_RAISES(IOError, fs_->DeleteFile("empty-bucket"));
+ ASSERT_RAISES(IOError, fs_->DeleteFile("nonexistent-bucket"));
+
+ // "File"
+ ASSERT_OK(fs_->DeleteFile("bucket/somefile"));
+ AssertFileInfo(fs_.get(), "bucket/somefile", FileType::NotFound);
+ ASSERT_RAISES(IOError, fs_->DeleteFile("bucket/somefile"));
+ ASSERT_RAISES(IOError, fs_->DeleteFile("bucket/nonexistent"));
+
+ // "Directory"
+ ASSERT_RAISES(IOError, fs_->DeleteFile("bucket/somedir"));
+ AssertFileInfo(fs_.get(), "bucket/somedir", FileType::Directory);
+}
+
+TEST_F(TestS3FS, DeleteDir) {
+ FileSelector select;
+ select.base_dir = "bucket";
+ std::vector<FileInfo> infos;
+
+ // Empty "directory"
+ ASSERT_OK(fs_->DeleteDir("bucket/emptydir"));
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 2);
+ SortInfos(&infos);
+ AssertFileInfo(infos[0], "bucket/somedir", FileType::Directory);
+ AssertFileInfo(infos[1], "bucket/somefile", FileType::File);
+
+ // Non-empty "directory"
+ ASSERT_OK(fs_->DeleteDir("bucket/somedir"));
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 1);
+ AssertFileInfo(infos[0], "bucket/somefile", FileType::File);
+
+ // Leaving parent "directory" empty
+ ASSERT_OK(fs_->CreateDir("bucket/newdir/newsub/newsubsub"));
+ ASSERT_OK(fs_->DeleteDir("bucket/newdir/newsub"));
+ ASSERT_OK_AND_ASSIGN(infos, fs_->GetFileInfo(select));
+ ASSERT_EQ(infos.size(), 2);
+ SortInfos(&infos);
+ AssertFileInfo(infos[0], "bucket/newdir", FileType::Directory); // still exists
+ AssertFileInfo(infos[1], "bucket/somefile", FileType::File);
+
+ // Bucket
+ ASSERT_OK(fs_->DeleteDir("bucket"));
+ AssertFileInfo(fs_.get(), "bucket", FileType::NotFound);
+}
+
+TEST_F(TestS3FS, CopyFile) {
+ // "File"
+ ASSERT_OK(fs_->CopyFile("bucket/somefile", "bucket/newfile"));
+ AssertFileInfo(fs_.get(), "bucket/newfile", FileType::File, 9);
+ AssertObjectContents(client_.get(), "bucket", "newfile", "some data");
+ AssertFileInfo(fs_.get(), "bucket/somefile", FileType::File, 9); // still exists
+ // Overwrite
+ ASSERT_OK(fs_->CopyFile("bucket/somedir/subdir/subfile", "bucket/newfile"));
+ AssertFileInfo(fs_.get(), "bucket/newfile", FileType::File, 8);
+ AssertObjectContents(client_.get(), "bucket", "newfile", "sub data");
+ // ARROW-13048: URL-encoded paths
+ ASSERT_OK(fs_->CopyFile("bucket/somefile", "bucket/a=2/newfile"));
+ ASSERT_OK(fs_->CopyFile("bucket/a=2/newfile", "bucket/a=3/newfile"));
+ // Nonexistent
+ ASSERT_RAISES(IOError, fs_->CopyFile("bucket/nonexistent", "bucket/newfile2"));
+ ASSERT_RAISES(IOError, fs_->CopyFile("nonexistent-bucket/somefile", "bucket/newfile2"));
+ ASSERT_RAISES(IOError, fs_->CopyFile("bucket/somefile", "nonexistent-bucket/newfile2"));
+ AssertFileInfo(fs_.get(), "bucket/newfile2", FileType::NotFound);
+}
+
+TEST_F(TestS3FS, Move) {
+ // "File"
+ ASSERT_OK(fs_->Move("bucket/somefile", "bucket/newfile"));
+ AssertFileInfo(fs_.get(), "bucket/newfile", FileType::File, 9);
+ AssertObjectContents(client_.get(), "bucket", "newfile", "some data");
+ // Source was deleted
+ AssertFileInfo(fs_.get(), "bucket/somefile", FileType::NotFound);
+
+ // Overwrite
+ ASSERT_OK(fs_->Move("bucket/somedir/subdir/subfile", "bucket/newfile"));
+ AssertFileInfo(fs_.get(), "bucket/newfile", FileType::File, 8);
+ AssertObjectContents(client_.get(), "bucket", "newfile", "sub data");
+ // Source was deleted
+ AssertFileInfo(fs_.get(), "bucket/somedir/subdir/subfile", FileType::NotFound);
+
+ // ARROW-13048: URL-encoded paths
+ ASSERT_OK(fs_->Move("bucket/newfile", "bucket/a=2/newfile"));
+ ASSERT_OK(fs_->Move("bucket/a=2/newfile", "bucket/a=3/newfile"));
+
+ // Nonexistent
+ ASSERT_RAISES(IOError, fs_->Move("bucket/non-existent", "bucket/newfile2"));
+ ASSERT_RAISES(IOError, fs_->Move("nonexistent-bucket/somefile", "bucket/newfile2"));
+ ASSERT_RAISES(IOError, fs_->Move("bucket/somefile", "nonexistent-bucket/newfile2"));
+ AssertFileInfo(fs_.get(), "bucket/newfile2", FileType::NotFound);
+}
+
+TEST_F(TestS3FS, OpenInputStream) {
+ std::shared_ptr<io::InputStream> stream;
+ std::shared_ptr<Buffer> buf;
+
+ // Nonexistent
+ ASSERT_RAISES(IOError, fs_->OpenInputStream("nonexistent-bucket/somefile"));
+ ASSERT_RAISES(IOError, fs_->OpenInputStream("bucket/zzzt"));
+
+ // "Files"
+ ASSERT_OK_AND_ASSIGN(stream, fs_->OpenInputStream("bucket/somefile"));
+ ASSERT_OK_AND_ASSIGN(buf, stream->Read(2));
+ AssertBufferEqual(*buf, "so");
+ ASSERT_OK_AND_ASSIGN(buf, stream->Read(5));
+ AssertBufferEqual(*buf, "me da");
+ ASSERT_OK_AND_ASSIGN(buf, stream->Read(5));
+ AssertBufferEqual(*buf, "ta");
+ ASSERT_OK_AND_ASSIGN(buf, stream->Read(5));
+ AssertBufferEqual(*buf, "");
+
+ ASSERT_OK_AND_ASSIGN(stream, fs_->OpenInputStream("bucket/somedir/subdir/subfile"));
+ ASSERT_OK_AND_ASSIGN(buf, stream->Read(100));
+ AssertBufferEqual(*buf, "sub data");
+ ASSERT_OK_AND_ASSIGN(buf, stream->Read(100));
+ AssertBufferEqual(*buf, "");
+ ASSERT_OK(stream->Close());
+
+ // "Directories"
+ ASSERT_RAISES(IOError, fs_->OpenInputStream("bucket/emptydir"));
+ ASSERT_RAISES(IOError, fs_->OpenInputStream("bucket/somedir"));
+ ASSERT_RAISES(IOError, fs_->OpenInputStream("bucket"));
+
+ // Open file and then lose filesystem reference
+ ASSERT_EQ(fs_.use_count(), 1); // needed for test to work
+ std::weak_ptr<S3FileSystem> weak_fs(fs_);
+ ASSERT_OK_AND_ASSIGN(stream, fs_->OpenInputStream("bucket/somefile"));
+ fs_.reset();
+ ASSERT_OK_AND_ASSIGN(buf, stream->Read(10));
+ AssertBufferEqual(*buf, "some data");
+ ASSERT_OK(stream->Close());
+ ASSERT_TRUE(weak_fs.expired());
+}
+
+TEST_F(TestS3FS, OpenInputStreamMetadata) {
+ std::shared_ptr<io::InputStream> stream;
+ std::shared_ptr<const KeyValueMetadata> metadata;
+
+ ASSERT_OK_AND_ASSIGN(stream, fs_->OpenInputStream("bucket/somefile"));
+ ASSERT_FINISHES_OK_AND_ASSIGN(metadata, stream->ReadMetadataAsync());
+
+ std::vector<std::pair<std::string, std::string>> expected_kv{
+ {"Content-Length", "9"}, {"Content-Type", "x-arrow/test"}};
+ ASSERT_NE(metadata, nullptr);
+ ASSERT_THAT(metadata->sorted_pairs(), testing::IsSupersetOf(expected_kv));
+}
+
+TEST_F(TestS3FS, OpenInputFile) {
+ std::shared_ptr<io::RandomAccessFile> file;
+ std::shared_ptr<Buffer> buf;
+
+ // Nonexistent
+ ASSERT_RAISES(IOError, fs_->OpenInputFile("nonexistent-bucket/somefile"));
+ ASSERT_RAISES(IOError, fs_->OpenInputFile("bucket/zzzt"));
+
+ // "Files"
+ ASSERT_OK_AND_ASSIGN(file, fs_->OpenInputFile("bucket/somefile"));
+ ASSERT_OK_AND_EQ(9, file->GetSize());
+ ASSERT_OK_AND_ASSIGN(buf, file->Read(4));
+ AssertBufferEqual(*buf, "some");
+ ASSERT_OK_AND_EQ(9, file->GetSize());
+ ASSERT_OK_AND_EQ(4, file->Tell());
+
+ ASSERT_OK_AND_ASSIGN(buf, file->ReadAt(2, 5));
+ AssertBufferEqual(*buf, "me da");
+ ASSERT_OK_AND_EQ(4, file->Tell());
+ ASSERT_OK_AND_ASSIGN(buf, file->ReadAt(5, 20));
+ AssertBufferEqual(*buf, "data");
+ ASSERT_OK_AND_ASSIGN(buf, file->ReadAt(9, 20));
+ AssertBufferEqual(*buf, "");
+
+ char result[10];
+ ASSERT_OK_AND_EQ(5, file->ReadAt(2, 5, &result));
+ ASSERT_OK_AND_EQ(4, file->ReadAt(5, 20, &result));
+ ASSERT_OK_AND_EQ(0, file->ReadAt(9, 0, &result));
+
+ // Reading past end of file
+ ASSERT_RAISES(IOError, file->ReadAt(10, 20));
+
+ ASSERT_OK(file->Seek(5));
+ ASSERT_OK_AND_ASSIGN(buf, file->Read(2));
+ AssertBufferEqual(*buf, "da");
+ ASSERT_OK(file->Seek(9));
+ ASSERT_OK_AND_ASSIGN(buf, file->Read(2));
+ AssertBufferEqual(*buf, "");
+ // Seeking past end of file
+ ASSERT_RAISES(IOError, file->Seek(10));
+}
+
+TEST_F(TestS3FS, OpenOutputStreamBackgroundWrites) { TestOpenOutputStream(); }
+
+TEST_F(TestS3FS, OpenOutputStreamSyncWrites) {
+ options_.background_writes = false;
+ MakeFileSystem();
+ TestOpenOutputStream();
+}
+
+TEST_F(TestS3FS, OpenOutputStreamAbortBackgroundWrites) { TestOpenOutputStreamAbort(); }
+
+TEST_F(TestS3FS, OpenOutputStreamAbortSyncWrites) {
+ options_.background_writes = false;
+ MakeFileSystem();
+ TestOpenOutputStreamAbort();
+}
+
+TEST_F(TestS3FS, OpenOutputStreamDestructorBackgroundWrites) {
+ TestOpenOutputStreamDestructor();
+}
+
+TEST_F(TestS3FS, OpenOutputStreamDestructorSyncWrite) {
+ options_.background_writes = false;
+ MakeFileSystem();
+ TestOpenOutputStreamDestructor();
+}
+
+TEST_F(TestS3FS, OpenOutputStreamMetadata) {
+ std::shared_ptr<io::OutputStream> stream;
+
+ // Create new file with explicit metadata
+ auto metadata = KeyValueMetadata::Make({"Content-Type", "Expires"},
+ {"x-arrow/test6", "2016-02-05T20:08:35Z"});
+ AssertMetadataRoundtrip("bucket/mdfile1", metadata,
+ testing::IsSupersetOf(metadata->sorted_pairs()));
+
+ // Create new file with valid canned ACL
+ // XXX: no easy way of testing the ACL actually gets set
+ metadata = KeyValueMetadata::Make({"ACL"}, {"authenticated-read"});
+ AssertMetadataRoundtrip("bucket/mdfile2", metadata, testing::_);
+
+ // Create new file with default metadata
+ auto default_metadata = KeyValueMetadata::Make({"Content-Type", "Content-Language"},
+ {"image/png", "fr_FR"});
+ options_.default_metadata = default_metadata;
+ MakeFileSystem();
+ // (null, then empty metadata argument)
+ AssertMetadataRoundtrip("bucket/mdfile3", nullptr,
+ testing::IsSupersetOf(default_metadata->sorted_pairs()));
+ AssertMetadataRoundtrip("bucket/mdfile4", KeyValueMetadata::Make({}, {}),
+ testing::IsSupersetOf(default_metadata->sorted_pairs()));
+
+ // Create new file with explicit metadata replacing default metadata
+ metadata = KeyValueMetadata::Make({"Content-Type"}, {"x-arrow/test6"});
+ AssertMetadataRoundtrip("bucket/mdfile5", metadata,
+ testing::IsSupersetOf(metadata->sorted_pairs()));
+}
+
+TEST_F(TestS3FS, FileSystemFromUri) {
+ std::stringstream ss;
+ ss << "s3://" << minio_.access_key() << ":" << minio_.secret_key()
+ << "@bucket/somedir/subdir/subfile"
+ << "?scheme=http&endpoint_override=" << UriEscape(minio_.connect_string());
+
+ std::string path;
+ ASSERT_OK_AND_ASSIGN(auto fs, FileSystemFromUri(ss.str(), &path));
+ ASSERT_EQ(path, "bucket/somedir/subdir/subfile");
+
+ // Check the filesystem has the right connection parameters
+ AssertFileInfo(fs.get(), path, FileType::File, 8);
+}
+
+// Simple retry strategy that records errors encountered and its emitted retry delays
+class TestRetryStrategy : public S3RetryStrategy {
+ public:
+ bool ShouldRetry(const S3RetryStrategy::AWSErrorDetail& error,
+ int64_t attempted_retries) final {
+ errors_encountered_.emplace_back(error);
+ constexpr int64_t MAX_RETRIES = 2;
+ return attempted_retries < MAX_RETRIES;
+ }
+
+ int64_t CalculateDelayBeforeNextRetry(const S3RetryStrategy::AWSErrorDetail& error,
+ int64_t attempted_retries) final {
+ int64_t delay = attempted_retries;
+ retry_delays_.emplace_back(delay);
+ return delay;
+ }
+
+ std::vector<S3RetryStrategy::AWSErrorDetail> GetErrorsEncountered() {
+ return errors_encountered_;
+ }
+ std::vector<int64_t> GetRetryDelays() { return retry_delays_; }
+
+ private:
+ std::vector<S3RetryStrategy::AWSErrorDetail> errors_encountered_;
+ std::vector<int64_t> retry_delays_;
+};
+
+TEST_F(TestS3FS, CustomRetryStrategy) {
+ auto retry_strategy = std::make_shared<TestRetryStrategy>();
+ options_.retry_strategy = retry_strategy;
+ MakeFileSystem();
+ // Attempt to open file that doesn't exist. Should hit TestRetryStrategy::ShouldRetry()
+ // 3 times before bubbling back up here.
+ ASSERT_RAISES(IOError, fs_->OpenInputStream("nonexistent-bucket/somefile"));
+ ASSERT_EQ(retry_strategy->GetErrorsEncountered().size(), 3);
+ for (const auto& error : retry_strategy->GetErrorsEncountered()) {
+ ASSERT_EQ(static_cast<Aws::Client::CoreErrors>(error.error_type),
+ Aws::Client::CoreErrors::RESOURCE_NOT_FOUND);
+ ASSERT_EQ(error.message, "No response body.");
+ ASSERT_EQ(error.exception_name, "");
+ ASSERT_EQ(error.should_retry, false);
+ }
+ std::vector<int64_t> expected_retry_delays = {0, 1, 2};
+ ASSERT_EQ(retry_strategy->GetRetryDelays(), expected_retry_delays);
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Generic S3 tests
+
+class TestS3FSGeneric : public S3TestMixin, public GenericFileSystemTest {
+ public:
+ void SetUp() override {
+ S3TestMixin::SetUp();
+ // Set up test bucket
+ {
+ Aws::S3::Model::CreateBucketRequest req;
+ req.SetBucket(ToAwsString("s3fs-test-bucket"));
+ ASSERT_OK(OutcomeToStatus(client_->CreateBucket(req)));
+ }
+
+ options_.ConfigureAccessKey(minio_.access_key(), minio_.secret_key());
+ options_.scheme = "http";
+ options_.endpoint_override = minio_.connect_string();
+ ASSERT_OK_AND_ASSIGN(s3fs_, S3FileSystem::Make(options_));
+ fs_ = std::make_shared<SubTreeFileSystem>("s3fs-test-bucket", s3fs_);
+ }
+
+ protected:
+ std::shared_ptr<FileSystem> GetEmptyFileSystem() override { return fs_; }
+
+ bool have_implicit_directories() const override { return true; }
+ bool allow_write_file_over_dir() const override { return true; }
+ bool allow_move_dir() const override { return false; }
+ bool allow_append_to_file() const override { return false; }
+ bool have_directory_mtimes() const override { return false; }
+ bool have_flaky_directory_tree_deletion() const override {
+#ifdef _WIN32
+ // Recent Minio versions on Windows may not register deletion of all
+ // directories in a tree when doing a bulk delete.
+ return true;
+#else
+ return false;
+#endif
+ }
+ bool have_file_metadata() const override { return true; }
+
+ S3Options options_;
+ std::shared_ptr<S3FileSystem> s3fs_;
+ std::shared_ptr<FileSystem> fs_;
+};
+
+GENERIC_FS_TEST_FUNCTIONS(TestS3FSGeneric);
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/test_util.cc b/src/arrow/cpp/src/arrow/filesystem/test_util.cc
new file mode 100644
index 000000000..0e2833781
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/test_util.cc
@@ -0,0 +1,1135 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <chrono>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/filesystem/mockfs.h"
+#include "arrow/filesystem/test_util.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/status.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/vector.h"
+
+using ::testing::ElementsAre;
+
+namespace arrow {
+namespace fs {
+
+namespace {
+
+std::vector<FileInfo> GetAllWithType(FileSystem* fs, FileType type) {
+ FileSelector selector;
+ selector.base_dir = "";
+ selector.recursive = true;
+ std::vector<FileInfo> infos = std::move(fs->GetFileInfo(selector)).ValueOrDie();
+ std::vector<FileInfo> result;
+ for (const auto& info : infos) {
+ if (info.type() == type) {
+ result.push_back(info);
+ }
+ }
+ return result;
+}
+
+std::vector<FileInfo> GetAllDirs(FileSystem* fs) {
+ return GetAllWithType(fs, FileType::Directory);
+}
+
+std::vector<FileInfo> GetAllFiles(FileSystem* fs) {
+ return GetAllWithType(fs, FileType::File);
+}
+
+void AssertPaths(const std::vector<FileInfo>& infos,
+ const std::vector<std::string>& expected_paths) {
+ auto sorted_infos = infos;
+ SortInfos(&sorted_infos);
+ std::vector<std::string> paths(sorted_infos.size());
+ std::transform(sorted_infos.begin(), sorted_infos.end(), paths.begin(),
+ [&](const FileInfo& info) { return info.path(); });
+
+ ASSERT_EQ(paths, expected_paths);
+}
+
+void AssertAllDirs(FileSystem* fs, const std::vector<std::string>& expected_paths) {
+ AssertPaths(GetAllDirs(fs), expected_paths);
+}
+
+void AssertAllFiles(FileSystem* fs, const std::vector<std::string>& expected_paths) {
+ AssertPaths(GetAllFiles(fs), expected_paths);
+}
+
+void ValidateTimePoint(TimePoint tp) { ASSERT_GE(tp.time_since_epoch().count(), 0); }
+
+}; // namespace
+
+void AssertFileContents(FileSystem* fs, const std::string& path,
+ const std::string& expected_data) {
+ ASSERT_OK_AND_ASSIGN(FileInfo info, fs->GetFileInfo(path));
+ ASSERT_EQ(info.type(), FileType::File) << "For path '" << path << "'";
+ ASSERT_EQ(info.size(), static_cast<int64_t>(expected_data.length()))
+ << "For path '" << path << "'";
+
+ ASSERT_OK_AND_ASSIGN(auto stream, fs->OpenInputStream(path));
+ ASSERT_OK_AND_ASSIGN(auto buffer, stream->Read(info.size()));
+ AssertBufferEqual(*buffer, expected_data);
+ // No data left in stream
+ ASSERT_OK_AND_ASSIGN(auto leftover, stream->Read(1));
+ ASSERT_EQ(leftover->size(), 0);
+
+ ASSERT_OK(stream->Close());
+}
+
+void CreateFile(FileSystem* fs, const std::string& path, const std::string& data) {
+ ASSERT_OK_AND_ASSIGN(auto stream, fs->OpenOutputStream(path));
+ ASSERT_OK(stream->Write(data));
+ ASSERT_OK(stream->Close());
+}
+
+void SortInfos(std::vector<FileInfo>* infos) {
+ std::sort(infos->begin(), infos->end(), FileInfo::ByPath{});
+}
+
+void CollectFileInfoGenerator(FileInfoGenerator gen, FileInfoVector* out_infos) {
+ auto fut = CollectAsyncGenerator(gen);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto nested_infos, fut);
+ *out_infos = ::arrow::internal::FlattenVectors(nested_infos);
+}
+
+void AssertFileInfo(const FileInfo& info, const std::string& path, FileType type) {
+ ASSERT_EQ(info.path(), path);
+ ASSERT_EQ(info.type(), type) << "For path '" << info.path() << "'";
+}
+
+void AssertFileInfo(const FileInfo& info, const std::string& path, FileType type,
+ TimePoint mtime) {
+ AssertFileInfo(info, path, type);
+ ASSERT_EQ(info.mtime(), mtime) << "For path '" << info.path() << "'";
+}
+
+void AssertFileInfo(const FileInfo& info, const std::string& path, FileType type,
+ TimePoint mtime, int64_t size) {
+ AssertFileInfo(info, path, type, mtime);
+ ASSERT_EQ(info.size(), size) << "For path '" << info.path() << "'";
+}
+
+void AssertFileInfo(const FileInfo& info, const std::string& path, FileType type,
+ int64_t size) {
+ AssertFileInfo(info, path, type);
+ ASSERT_EQ(info.size(), size) << "For path '" << info.path() << "'";
+}
+
+void AssertFileInfo(FileSystem* fs, const std::string& path, FileType type) {
+ ASSERT_OK_AND_ASSIGN(FileInfo info, fs->GetFileInfo(path));
+ AssertFileInfo(info, path, type);
+}
+
+void AssertFileInfo(FileSystem* fs, const std::string& path, FileType type,
+ TimePoint mtime) {
+ ASSERT_OK_AND_ASSIGN(FileInfo info, fs->GetFileInfo(path));
+ AssertFileInfo(info, path, type, mtime);
+}
+
+void AssertFileInfo(FileSystem* fs, const std::string& path, FileType type,
+ TimePoint mtime, int64_t size) {
+ ASSERT_OK_AND_ASSIGN(FileInfo info, fs->GetFileInfo(path));
+ AssertFileInfo(info, path, type, mtime, size);
+}
+
+void AssertFileInfo(FileSystem* fs, const std::string& path, FileType type,
+ int64_t size) {
+ ASSERT_OK_AND_ASSIGN(FileInfo info, fs->GetFileInfo(path));
+ AssertFileInfo(info, path, type, size);
+}
+
+GatedMockFilesystem::GatedMockFilesystem(TimePoint current_time,
+ const io::IOContext& io_context)
+ : internal::MockFileSystem(current_time, io_context) {}
+GatedMockFilesystem::~GatedMockFilesystem() = default;
+
+Result<std::shared_ptr<io::OutputStream>> GatedMockFilesystem::OpenOutputStream(
+ const std::string& path, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ RETURN_NOT_OK(open_output_sem_.Acquire(1));
+ return MockFileSystem::OpenOutputStream(path, metadata);
+}
+
+Status GatedMockFilesystem::WaitForOpenOutputStream(uint32_t num_waiters) {
+ return open_output_sem_.WaitForWaiters(num_waiters);
+}
+
+Status GatedMockFilesystem::UnlockOpenOutputStream(uint32_t num_waiters) {
+ return open_output_sem_.Release(num_waiters);
+}
+
+////////////////////////////////////////////////////////////////////////////
+// GenericFileSystemTest implementation
+
+// XXX is there a way we can test mtimes reliably and precisely?
+
+GenericFileSystemTest::~GenericFileSystemTest() {}
+
+void GenericFileSystemTest::TestEmpty(FileSystem* fs) {
+ auto dirs = GetAllDirs(fs);
+ ASSERT_EQ(dirs.size(), 0);
+ auto files = GetAllFiles(fs);
+ ASSERT_EQ(files.size(), 0);
+}
+
+void GenericFileSystemTest::TestNormalizePath(FileSystem* fs) {
+ // Canonical abstract paths should go through unchanged
+ ASSERT_OK_AND_EQ("AB", fs->NormalizePath("AB"));
+ ASSERT_OK_AND_EQ("AB/CD/efg", fs->NormalizePath("AB/CD/efg"));
+}
+
+void GenericFileSystemTest::TestCreateDir(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB"));
+ ASSERT_OK(fs->CreateDir("AB/CD/EF")); // Recursive
+ if (!have_implicit_directories()) {
+ // Non-recursive, parent doesn't exist
+ ASSERT_RAISES(IOError, fs->CreateDir("AB/GH/IJ", false /* recursive */));
+ }
+ ASSERT_OK(fs->CreateDir("AB/GH", false /* recursive */));
+ ASSERT_OK(fs->CreateDir("AB/GH/IJ", false /* recursive */));
+ // Idempotency
+ ASSERT_OK(fs->CreateDir("AB/GH/IJ", false /* recursive */));
+ ASSERT_OK(fs->CreateDir("XY"));
+
+ AssertAllDirs(fs, {"AB", "AB/CD", "AB/CD/EF", "AB/GH", "AB/GH/IJ", "XY"});
+ AssertAllFiles(fs, {});
+
+ // Cannot create a directory as child of a file
+ CreateFile(fs, "AB/def", "");
+ ASSERT_RAISES(IOError, fs->CreateDir("AB/def/EF/GH", true /* recursive */));
+ ASSERT_RAISES(IOError, fs->CreateDir("AB/def/EF", false /* recursive */));
+
+ // Cannot create a directory when there is already a file with the same name
+ ASSERT_RAISES(IOError, fs->CreateDir("AB/def"));
+
+ AssertAllDirs(fs, {"AB", "AB/CD", "AB/CD/EF", "AB/GH", "AB/GH/IJ", "XY"});
+ AssertAllFiles(fs, {"AB/def"});
+}
+
+void GenericFileSystemTest::TestDeleteDir(FileSystem* fs) {
+ if (have_flaky_directory_tree_deletion())
+ GTEST_SKIP() << "Flaky directory deletion on Windows";
+
+ ASSERT_OK(fs->CreateDir("AB/CD/EF"));
+ ASSERT_OK(fs->CreateDir("AB/GH/IJ"));
+ CreateFile(fs, "AB/abc", "");
+ CreateFile(fs, "AB/CD/def", "");
+ CreateFile(fs, "AB/CD/EF/ghi", "");
+ ASSERT_OK(fs->DeleteDir("AB/CD"));
+ ASSERT_OK(fs->DeleteDir("AB/GH/IJ"));
+
+ AssertAllDirs(fs, {"AB", "AB/GH"});
+ AssertAllFiles(fs, {"AB/abc"});
+
+ // File doesn't exist
+ ASSERT_RAISES(IOError, fs->DeleteDir("AB/GH/IJ"));
+ ASSERT_RAISES(IOError, fs->DeleteDir(""));
+
+ AssertAllDirs(fs, {"AB", "AB/GH"});
+
+ // Not a directory
+ CreateFile(fs, "AB/def", "");
+ ASSERT_RAISES(IOError, fs->DeleteDir("AB/def"));
+
+ AssertAllDirs(fs, {"AB", "AB/GH"});
+ AssertAllFiles(fs, {"AB/abc", "AB/def"});
+}
+
+void GenericFileSystemTest::TestDeleteDirContents(FileSystem* fs) {
+ if (have_flaky_directory_tree_deletion())
+ GTEST_SKIP() << "Flaky directory deletion on Windows";
+
+ ASSERT_OK(fs->CreateDir("AB/CD/EF"));
+ ASSERT_OK(fs->CreateDir("AB/GH/IJ"));
+ CreateFile(fs, "AB/abc", "");
+ CreateFile(fs, "AB/CD/def", "");
+ CreateFile(fs, "AB/CD/EF/ghi", "");
+ ASSERT_OK(fs->DeleteDirContents("AB/CD"));
+ ASSERT_OK(fs->DeleteDirContents("AB/GH/IJ"));
+
+ AssertAllDirs(fs, {"AB", "AB/CD", "AB/GH", "AB/GH/IJ"});
+ AssertAllFiles(fs, {"AB/abc"});
+
+ // Calling DeleteDirContents on root directory is disallowed
+ ASSERT_RAISES(Invalid, fs->DeleteDirContents(""));
+ ASSERT_RAISES(Invalid, fs->DeleteDirContents("/"));
+ AssertAllDirs(fs, {"AB", "AB/CD", "AB/GH", "AB/GH/IJ"});
+ AssertAllFiles(fs, {"AB/abc"});
+
+ // Not a directory
+ CreateFile(fs, "abc", "");
+ ASSERT_RAISES(IOError, fs->DeleteDirContents("abc"));
+ AssertAllFiles(fs, {"AB/abc", "abc"});
+}
+
+void GenericFileSystemTest::TestDeleteRootDirContents(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB/CD"));
+ CreateFile(fs, "AB/abc", "");
+
+ auto st = fs->DeleteRootDirContents();
+ if (!st.ok()) {
+ // Not all filesystems support deleting root directory contents
+ ASSERT_TRUE(st.IsInvalid() || st.IsNotImplemented());
+ AssertAllDirs(fs, {"AB", "AB/CD"});
+ AssertAllFiles(fs, {"AB/abc"});
+ } else {
+ if (!have_flaky_directory_tree_deletion()) {
+ AssertAllDirs(fs, {});
+ }
+ AssertAllFiles(fs, {});
+ }
+}
+
+void GenericFileSystemTest::TestDeleteFile(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB"));
+ CreateFile(fs, "AB/def", "");
+ AssertAllDirs(fs, {"AB"});
+ AssertAllFiles(fs, {"AB/def"});
+
+ ASSERT_OK(fs->DeleteFile("AB/def"));
+ AssertAllDirs(fs, {"AB"});
+ AssertAllFiles(fs, {});
+
+ CreateFile(fs, "abc", "data");
+ AssertAllDirs(fs, {"AB"});
+ AssertAllFiles(fs, {"abc"});
+
+ ASSERT_OK(fs->DeleteFile("abc"));
+ AssertAllDirs(fs, {"AB"});
+ AssertAllFiles(fs, {});
+
+ // File doesn't exist
+ ASSERT_RAISES(IOError, fs->DeleteFile("abc"));
+ ASSERT_RAISES(IOError, fs->DeleteFile("AB/def"));
+
+ // Not a file
+ ASSERT_RAISES(IOError, fs->DeleteFile("AB"));
+ AssertAllDirs(fs, {"AB"});
+ AssertAllFiles(fs, {});
+}
+
+void GenericFileSystemTest::TestDeleteFiles(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB"));
+ CreateFile(fs, "abc", "");
+ CreateFile(fs, "AB/def", "123");
+ CreateFile(fs, "AB/ghi", "456");
+ CreateFile(fs, "AB/jkl", "789");
+ CreateFile(fs, "AB/mno", "789");
+ AssertAllDirs(fs, {"AB"});
+ AssertAllFiles(fs, {"AB/def", "AB/ghi", "AB/jkl", "AB/mno", "abc"});
+
+ // All successful
+ ASSERT_OK(fs->DeleteFiles({"abc", "AB/def"}));
+ AssertAllDirs(fs, {"AB"});
+ AssertAllFiles(fs, {"AB/ghi", "AB/jkl", "AB/mno"});
+
+ // One error: file doesn't exist
+ ASSERT_RAISES(IOError, fs->DeleteFiles({"xx", "AB/jkl"}));
+ AssertAllDirs(fs, {"AB"});
+ AssertAllFiles(fs, {"AB/ghi", "AB/mno"});
+
+ // One error: not a file
+ ASSERT_RAISES(IOError, fs->DeleteFiles({"AB", "AB/mno"}));
+ AssertAllDirs(fs, {"AB"});
+ AssertAllFiles(fs, {"AB/ghi"});
+}
+
+void GenericFileSystemTest::TestMoveFile(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB/CD"));
+ ASSERT_OK(fs->CreateDir("EF"));
+ CreateFile(fs, "abc", "data");
+ std::vector<std::string> all_dirs{"AB", "AB/CD", "EF"};
+ AssertAllDirs(fs, all_dirs);
+ AssertAllFiles(fs, {"abc"});
+
+ // Move inside root dir
+ ASSERT_OK(fs->Move("abc", "def"));
+ AssertAllDirs(fs, all_dirs);
+ AssertAllFiles(fs, {"def"});
+ AssertFileInfo(fs, "def", FileType::File, 4);
+ AssertFileContents(fs, "def", "data");
+
+ // Move out of root dir
+ ASSERT_OK(fs->Move("def", "AB/CD/ghi"));
+ AssertAllDirs(fs, all_dirs);
+ AssertAllFiles(fs, {"AB/CD/ghi"});
+ AssertFileInfo(fs, "AB/CD/ghi", FileType::File, 4);
+ AssertFileContents(fs, "AB/CD/ghi", "data");
+
+ ASSERT_OK(fs->Move("AB/CD/ghi", "EF/jkl"));
+ AssertAllDirs(fs, all_dirs);
+ AssertAllFiles(fs, {"EF/jkl"});
+ AssertFileInfo(fs, "EF/jkl", FileType::File, 4);
+ AssertFileContents(fs, "EF/jkl", "data");
+
+ // Move back into root dir
+ ASSERT_OK(fs->Move("EF/jkl", "mno"));
+ AssertAllDirs(fs, all_dirs);
+ AssertAllFiles(fs, {"mno"});
+ AssertFileInfo(fs, "mno", FileType::File, 4);
+ AssertFileContents(fs, "mno", "data");
+
+ // Destination is a file => clobber
+ CreateFile(fs, "AB/pqr", "other data");
+ AssertAllFiles(fs, {"AB/pqr", "mno"});
+ ASSERT_OK(fs->Move("mno", "AB/pqr"));
+ AssertAllFiles(fs, {"AB/pqr"});
+ AssertFileInfo(fs, "AB/pqr", FileType::File, 4);
+ AssertFileContents(fs, "AB/pqr", "data");
+
+ // Identical source and destination: allowed to succeed or raise IOError,
+ // but should not lose data.
+ auto err = fs->Move("AB/pqr", "AB/pqr");
+ if (!err.ok()) {
+ ASSERT_RAISES(IOError, err);
+ }
+ AssertAllFiles(fs, {"AB/pqr"});
+ AssertFileInfo(fs, "AB/pqr", FileType::File, 4);
+ AssertFileContents(fs, "AB/pqr", "data");
+
+ // Source doesn't exist
+ ASSERT_RAISES(IOError, fs->Move("abc", "def"));
+ if (!have_implicit_directories()) {
+ // Parent destination doesn't exist
+ ASSERT_RAISES(IOError, fs->Move("AB/pqr", "XX/mno"));
+ }
+ // Parent destination is not a directory
+ CreateFile(fs, "xxx", "");
+ ASSERT_RAISES(IOError, fs->Move("AB/pqr", "xxx/mno"));
+ if (!allow_write_file_over_dir()) {
+ // Destination is a directory
+ ASSERT_RAISES(IOError, fs->Move("AB/pqr", "EF"));
+ }
+ AssertAllDirs(fs, all_dirs);
+ AssertAllFiles(fs, {"AB/pqr", "xxx"});
+}
+
+void GenericFileSystemTest::TestMoveDir(FileSystem* fs) {
+ if (!allow_move_dir()) {
+ GTEST_SKIP() << "Filesystem doesn't allow moving directories";
+ }
+
+ ASSERT_OK(fs->CreateDir("AB/CD"));
+ ASSERT_OK(fs->CreateDir("EF"));
+ CreateFile(fs, "AB/abc", "abc data");
+ CreateFile(fs, "AB/CD/def", "def data");
+ CreateFile(fs, "EF/ghi", "ghi data");
+ AssertAllDirs(fs, {"AB", "AB/CD", "EF"});
+ AssertAllFiles(fs, {"AB/CD/def", "AB/abc", "EF/ghi"});
+
+ // Move inside root dir
+ ASSERT_OK(fs->Move("AB", "GH"));
+ AssertAllDirs(fs, {"EF", "GH", "GH/CD"});
+ AssertAllFiles(fs, {"EF/ghi", "GH/CD/def", "GH/abc"});
+
+ // Move out of root dir
+ ASSERT_OK(fs->Move("GH", "EF/IJ"));
+ AssertAllDirs(fs, {"EF", "EF/IJ", "EF/IJ/CD"});
+ AssertAllFiles(fs, {"EF/IJ/CD/def", "EF/IJ/abc", "EF/ghi"});
+
+ // Move back into root dir
+ ASSERT_OK(fs->Move("EF/IJ", "KL"));
+ AssertAllDirs(fs, {"EF", "KL", "KL/CD"});
+ AssertAllFiles(fs, {"EF/ghi", "KL/CD/def", "KL/abc"});
+
+ // Overwrite file with directory => untested (implementation-dependent)
+
+ // Identical source and destination: allowed to succeed or raise IOError,
+ // but should not lose data.
+ Status st = fs->Move("KL", "KL");
+ if (!st.ok()) {
+ ASSERT_RAISES(IOError, st);
+ }
+ AssertAllDirs(fs, {"EF", "KL", "KL/CD"});
+ AssertAllFiles(fs, {"EF/ghi", "KL/CD/def", "KL/abc"});
+
+ // Cannot move directory inside itself
+ ASSERT_RAISES(IOError, fs->Move("KL", "KL/ZZ"));
+
+ // Contents didn't change
+ AssertAllDirs(fs, {"EF", "KL", "KL/CD"});
+ AssertFileContents(fs, "KL/abc", "abc data");
+ AssertFileContents(fs, "KL/CD/def", "def data");
+
+ // Destination is a non-empty directory
+ if (!allow_move_dir_over_non_empty_dir()) {
+ ASSERT_RAISES(IOError, fs->Move("KL", "EF"));
+ AssertAllDirs(fs, {"EF", "KL", "KL/CD"});
+ AssertAllFiles(fs, {"EF/ghi", "KL/CD/def", "KL/abc"});
+ } else {
+ // In some filesystems such as HDFS, this operation is interpreted
+ // as with the Unix `mv` command, i.e. move KL *inside* EF.
+ ASSERT_OK(fs->Move("KL", "EF"));
+ AssertAllDirs(fs, {"EF", "EF/KL", "EF/KL/CD"});
+ AssertAllFiles(fs, {"EF/KL/CD/def", "EF/KL/abc", "EF/ghi"});
+ }
+
+ // (other errors tested in TestMoveFile)
+}
+
+void GenericFileSystemTest::TestCopyFile(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB/CD"));
+ ASSERT_OK(fs->CreateDir("EF"));
+ CreateFile(fs, "AB/abc", "data");
+ std::vector<std::string> all_dirs{"AB", "AB/CD", "EF"};
+
+ // Copy into root dir
+ ASSERT_OK(fs->CopyFile("AB/abc", "def"));
+ AssertAllDirs(fs, all_dirs);
+ AssertAllFiles(fs, {"AB/abc", "def"});
+
+ // Copy out of root dir
+ ASSERT_OK(fs->CopyFile("def", "EF/ghi"));
+ AssertAllDirs(fs, all_dirs);
+ AssertAllFiles(fs, {"AB/abc", "EF/ghi", "def"});
+
+ // Overwrite contents for one file => other data shouldn't change
+ CreateFile(fs, "def", "other data");
+ AssertFileContents(fs, "AB/abc", "data");
+ AssertFileContents(fs, "def", "other data");
+ AssertFileContents(fs, "EF/ghi", "data");
+
+ // Destination is a file => clobber
+ ASSERT_OK(fs->CopyFile("def", "AB/abc"));
+ AssertAllDirs(fs, all_dirs);
+ AssertAllFiles(fs, {"AB/abc", "EF/ghi", "def"});
+ AssertFileContents(fs, "AB/abc", "other data");
+ AssertFileContents(fs, "def", "other data");
+ AssertFileContents(fs, "EF/ghi", "data");
+
+ // Identical source and destination: allowed to succeed or raise IOError,
+ // but should not lose data.
+ Status st = fs->CopyFile("def", "def");
+ if (!st.ok()) {
+ ASSERT_RAISES(IOError, st);
+ }
+ AssertAllFiles(fs, {"AB/abc", "EF/ghi", "def"});
+ AssertFileContents(fs, "def", "other data");
+
+ // Source doesn't exist
+ ASSERT_RAISES(IOError, fs->CopyFile("abc", "xxx"));
+ if (!allow_write_file_over_dir()) {
+ // Destination is a non-empty directory
+ ASSERT_RAISES(IOError, fs->CopyFile("def", "AB"));
+ }
+ if (!have_implicit_directories()) {
+ // Parent destination doesn't exist
+ ASSERT_RAISES(IOError, fs->CopyFile("AB/abc", "XX/mno"));
+ }
+ // Parent destination is not a directory
+ ASSERT_RAISES(IOError, fs->CopyFile("AB/abc", "def/mno"));
+ AssertAllDirs(fs, all_dirs);
+ AssertAllFiles(fs, {"AB/abc", "EF/ghi", "def"});
+}
+
+void GenericFileSystemTest::TestGetFileInfo(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB/CD/EF"));
+ CreateFile(fs, "AB/CD/ghi", "some data");
+ CreateFile(fs, "AB/CD/jkl", "some other data");
+
+ FileInfo info;
+ TimePoint first_dir_time, first_file_time;
+
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo("AB"));
+ AssertFileInfo(info, "AB", FileType::Directory);
+ ASSERT_EQ(info.base_name(), "AB");
+ ASSERT_EQ(info.size(), kNoSize);
+ first_dir_time = info.mtime();
+ if (have_directory_mtimes()) {
+ ValidateTimePoint(first_dir_time);
+ }
+
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo("AB/CD/EF"));
+ AssertFileInfo(info, "AB/CD/EF", FileType::Directory);
+ ASSERT_EQ(info.base_name(), "EF");
+ ASSERT_EQ(info.size(), kNoSize);
+ // AB/CD's creation can impact AB's modification time, however, AB/CD/EF's
+ // creation doesn't, so AB/CD/EF's mtime should be after AB's.
+ if (have_directory_mtimes()) {
+ AssertDurationBetween(info.mtime() - first_dir_time, 0.0, kTimeSlack);
+ }
+
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo("AB/CD/ghi"));
+ AssertFileInfo(info, "AB/CD/ghi", FileType::File, 9);
+ ASSERT_EQ(info.base_name(), "ghi");
+ first_file_time = info.mtime();
+ // AB/CD/ghi's creation doesn't impact AB's modification time,
+ // so AB/CD/ghi's mtime should be after AB's.
+ if (have_directory_mtimes()) {
+ AssertDurationBetween(first_file_time - first_dir_time, 0.0, kTimeSlack);
+ }
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo("AB/CD/jkl"));
+ AssertFileInfo(info, "AB/CD/jkl", FileType::File, 15);
+ // This file was created after the one above
+ AssertDurationBetween(info.mtime() - first_file_time, 0.0, kTimeSlack);
+
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo("zz"));
+ AssertFileInfo(info, "zz", FileType::NotFound);
+ ASSERT_EQ(info.base_name(), "zz");
+ ASSERT_EQ(info.size(), kNoSize);
+ ASSERT_EQ(info.mtime(), kNoTime);
+}
+
+void GenericFileSystemTest::TestGetFileInfoAsync(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB/CD"));
+ CreateFile(fs, "AB/CD/ghi", "some data");
+
+ std::vector<FileInfo> infos;
+ auto fut = fs->GetFileInfoAsync({"AB", "AB/CD", "AB/zz", "zz", "XX/zz", "AB/CD/ghi"});
+ ASSERT_FINISHES_OK_AND_ASSIGN(infos, fut);
+
+ ASSERT_EQ(infos.size(), 6);
+ AssertFileInfo(infos[0], "AB", FileType::Directory);
+ AssertFileInfo(infos[1], "AB/CD", FileType::Directory);
+ AssertFileInfo(infos[2], "AB/zz", FileType::NotFound);
+ AssertFileInfo(infos[3], "zz", FileType::NotFound);
+ AssertFileInfo(infos[4], "XX/zz", FileType::NotFound);
+ AssertFileInfo(infos[5], "AB/CD/ghi", FileType::File, 9);
+}
+
+void GenericFileSystemTest::TestGetFileInfoVector(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB/CD"));
+ CreateFile(fs, "AB/CD/ghi", "some data");
+
+ std::vector<FileInfo> infos;
+ TimePoint dir_time, file_time;
+ ASSERT_OK_AND_ASSIGN(
+ infos, fs->GetFileInfo({"AB", "AB/CD", "AB/zz", "zz", "XX/zz", "AB/CD/ghi"}));
+ ASSERT_EQ(infos.size(), 6);
+ AssertFileInfo(infos[0], "AB", FileType::Directory);
+ dir_time = infos[0].mtime();
+ if (have_directory_mtimes()) {
+ ValidateTimePoint(dir_time);
+ }
+ AssertFileInfo(infos[1], "AB/CD", FileType::Directory);
+ AssertFileInfo(infos[2], "AB/zz", FileType::NotFound);
+ AssertFileInfo(infos[3], "zz", FileType::NotFound);
+ AssertFileInfo(infos[4], "XX/zz", FileType::NotFound);
+ ASSERT_EQ(infos[4].size(), kNoSize);
+ ASSERT_EQ(infos[4].mtime(), kNoTime);
+ AssertFileInfo(infos[5], "AB/CD/ghi", FileType::File, 9);
+ file_time = infos[5].mtime();
+ if (have_directory_mtimes()) {
+ AssertDurationBetween(file_time - dir_time, 0.0, kTimeSlack);
+ } else {
+ ValidateTimePoint(file_time);
+ }
+
+ // Check the mtime is the same from one call to the other
+ FileInfo info;
+ if (have_directory_mtimes()) {
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo("AB"));
+ AssertFileInfo(info, "AB", FileType::Directory);
+ ASSERT_EQ(info.mtime(), dir_time);
+ }
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo("AB/CD/ghi"));
+ AssertFileInfo(info, "AB/CD/ghi", FileType::File, 9);
+ ASSERT_EQ(info.mtime(), file_time);
+}
+
+void GenericFileSystemTest::TestGetFileInfoSelector(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB/CD"));
+ CreateFile(fs, "abc", "data");
+ CreateFile(fs, "AB/def", "some data");
+ CreateFile(fs, "AB/CD/ghi", "some other data");
+ CreateFile(fs, "AB/CD/jkl", "yet other data");
+
+ TimePoint first_dir_time, first_file_time;
+ FileSelector s;
+ s.base_dir = "";
+ std::vector<FileInfo> infos;
+ ASSERT_OK_AND_ASSIGN(infos, fs->GetFileInfo(s));
+ // Need to sort results to make testing deterministic
+ SortInfos(&infos);
+ ASSERT_EQ(infos.size(), 2);
+ AssertFileInfo(infos[0], "AB", FileType::Directory);
+ first_dir_time = infos[0].mtime();
+ if (have_directory_mtimes()) {
+ ValidateTimePoint(first_dir_time);
+ }
+ AssertFileInfo(infos[1], "abc", FileType::File, 4);
+
+ s.base_dir = "AB";
+ ASSERT_OK_AND_ASSIGN(infos, fs->GetFileInfo(s));
+ SortInfos(&infos);
+ ASSERT_EQ(infos.size(), 2);
+ AssertFileInfo(infos[0], "AB/CD", FileType::Directory);
+ AssertFileInfo(infos[1], "AB/def", FileType::File, 9);
+
+ s.base_dir = "AB/CD";
+ ASSERT_OK_AND_ASSIGN(infos, fs->GetFileInfo(s));
+ SortInfos(&infos);
+ ASSERT_EQ(infos.size(), 2);
+ AssertFileInfo(infos[0], "AB/CD/ghi", FileType::File, 15);
+ AssertFileInfo(infos[1], "AB/CD/jkl", FileType::File, 14);
+ first_file_time = infos[0].mtime();
+ if (have_directory_mtimes()) {
+ AssertDurationBetween(first_file_time - first_dir_time, 0.0, kTimeSlack);
+ }
+ AssertDurationBetween(infos[1].mtime() - first_file_time, 0.0, kTimeSlack);
+
+ // Recursive
+ s.base_dir = "AB";
+ s.recursive = true;
+ ASSERT_OK_AND_ASSIGN(infos, fs->GetFileInfo(s));
+ SortInfos(&infos);
+ ASSERT_EQ(infos.size(), 4);
+ AssertFileInfo(infos[0], "AB/CD", FileType::Directory);
+ AssertFileInfo(infos[1], "AB/CD/ghi", FileType::File, first_file_time, 15);
+ AssertFileInfo(infos[2], "AB/CD/jkl", FileType::File, 14);
+ AssertFileInfo(infos[3], "AB/def", FileType::File, 9);
+
+ // Check the mtime is the same from one call to the other
+ FileInfo info;
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo("AB"));
+ AssertFileInfo(info, "AB", FileType::Directory, first_dir_time);
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo("AB/CD/ghi"));
+ AssertFileInfo(info, "AB/CD/ghi", FileType::File, first_file_time, 15);
+
+ // Doesn't exist
+ s.base_dir = "XX";
+ ASSERT_RAISES(IOError, fs->GetFileInfo(s));
+ s.allow_not_found = true;
+ ASSERT_OK_AND_ASSIGN(infos, fs->GetFileInfo(s));
+ ASSERT_EQ(infos.size(), 0);
+ s.allow_not_found = false;
+
+ // Not a dir
+ s.base_dir = "abc";
+ ASSERT_RAISES(IOError, fs->GetFileInfo(s));
+}
+
+void GenericFileSystemTest::TestGetFileInfoGenerator(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB/CD"));
+ CreateFile(fs, "abc", "data");
+ CreateFile(fs, "AB/def", "some data");
+ CreateFile(fs, "AB/CD/ghi", "some other data");
+ CreateFile(fs, "AB/CD/jkl", "yet other data");
+
+ FileSelector s;
+ s.base_dir = "";
+ std::vector<FileInfo> infos;
+ std::vector<std::vector<FileInfo>> nested_infos;
+
+ // Non-recursive
+ auto gen = fs->GetFileInfoGenerator(s);
+ CollectFileInfoGenerator(std::move(gen), &infos);
+ SortInfos(&infos);
+ ASSERT_EQ(infos.size(), 2);
+ AssertFileInfo(infos[0], "AB", FileType::Directory);
+ AssertFileInfo(infos[1], "abc", FileType::File, 4);
+
+ // Recursive
+ s.base_dir = "AB";
+ s.recursive = true;
+ CollectFileInfoGenerator(fs->GetFileInfoGenerator(s), &infos);
+ SortInfos(&infos);
+ ASSERT_EQ(infos.size(), 4);
+ AssertFileInfo(infos[0], "AB/CD", FileType::Directory);
+ AssertFileInfo(infos[1], "AB/CD/ghi", FileType::File, 15);
+ AssertFileInfo(infos[2], "AB/CD/jkl", FileType::File, 14);
+ AssertFileInfo(infos[3], "AB/def", FileType::File, 9);
+
+ // Doesn't exist
+ s.base_dir = "XX";
+ auto fut = CollectAsyncGenerator(fs->GetFileInfoGenerator(s));
+ ASSERT_FINISHES_AND_RAISES(IOError, fut);
+ s.allow_not_found = true;
+ CollectFileInfoGenerator(fs->GetFileInfoGenerator(s), &infos);
+ ASSERT_EQ(infos.size(), 0);
+}
+
+void GetSortedInfos(FileSystem* fs, FileSelector s, std::vector<FileInfo>& infos) {
+ ASSERT_OK_AND_ASSIGN(infos, fs->GetFileInfo(s));
+ // Clear mtime & size for easier testing.
+ for_each(infos.begin(), infos.end(), [](FileInfo& info) {
+ info.set_mtime(kNoTime);
+ info.set_size(kNoSize);
+ });
+ SortInfos(&infos);
+}
+
+void GenericFileSystemTest::TestGetFileInfoSelectorWithRecursion(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("01/02/03/04"));
+ ASSERT_OK(fs->CreateDir("AA"));
+ CreateFile(fs, "00.file", "00");
+ CreateFile(fs, "01/01.file", "01");
+ CreateFile(fs, "AA/AA.file", "aa");
+ CreateFile(fs, "01/02/02.file", "02");
+ CreateFile(fs, "01/02/03/03.file", "03");
+ CreateFile(fs, "01/02/03/04/04.file", "04");
+
+ std::vector<FileInfo> infos;
+ FileSelector s;
+
+ s.base_dir = "";
+ s.recursive = false;
+ GetSortedInfos(fs, s, infos);
+ EXPECT_THAT(infos, ElementsAre(File("00.file"), Dir("01"), Dir("AA")));
+
+ // recursive should prevail on max_recursion
+ s.max_recursion = 9000;
+ GetSortedInfos(fs, s, infos);
+ EXPECT_THAT(infos, ElementsAre(File("00.file"), Dir("01"), Dir("AA")));
+
+ // recursive but no traversal
+ s.recursive = true;
+ s.max_recursion = 0;
+ GetSortedInfos(fs, s, infos);
+ EXPECT_THAT(infos, ElementsAre(File("00.file"), Dir("01"), Dir("AA")));
+
+ s.recursive = true;
+ s.max_recursion = 1;
+ GetSortedInfos(fs, s, infos);
+ EXPECT_THAT(infos, ElementsAre(File("00.file"), Dir("01"), File("01/01.file"),
+ Dir("01/02"), Dir("AA"), File("AA/AA.file")));
+
+ s.recursive = true;
+ s.max_recursion = 2;
+ GetSortedInfos(fs, s, infos);
+ EXPECT_THAT(infos, ElementsAre(File("00.file"), Dir("01"), File("01/01.file"),
+ Dir("01/02"), File("01/02/02.file"), Dir("01/02/03"),
+ Dir("AA"), File("AA/AA.file")));
+
+ s.base_dir = "01";
+ s.recursive = false;
+ GetSortedInfos(fs, s, infos);
+ EXPECT_THAT(infos, ElementsAre(File("01/01.file"), Dir("01/02")));
+
+ s.base_dir = "01";
+ s.recursive = true;
+ s.max_recursion = 1;
+ GetSortedInfos(fs, s, infos);
+ EXPECT_THAT(infos, ElementsAre(File("01/01.file"), Dir("01/02"), File("01/02/02.file"),
+ Dir("01/02/03")));
+
+ // All-in
+ s.base_dir = "";
+ s.recursive = true;
+ s.max_recursion = INT32_MAX;
+ GetSortedInfos(fs, s, infos);
+ EXPECT_THAT(
+ infos, ElementsAre(File("00.file"), Dir("01"), File("01/01.file"), Dir("01/02"),
+ File("01/02/02.file"), Dir("01/02/03"), File("01/02/03/03.file"),
+ Dir("01/02/03/04"), File("01/02/03/04/04.file"), Dir("AA"),
+ File("AA/AA.file")));
+}
+
+void GenericFileSystemTest::TestOpenOutputStream(FileSystem* fs) {
+ std::shared_ptr<io::OutputStream> stream;
+
+ ASSERT_OK_AND_ASSIGN(stream, fs->OpenOutputStream("abc"));
+ ASSERT_OK_AND_EQ(0, stream->Tell());
+ ASSERT_FALSE(stream->closed());
+ ASSERT_OK(stream->Close());
+ ASSERT_TRUE(stream->closed());
+ AssertAllDirs(fs, {});
+ AssertAllFiles(fs, {"abc"});
+ AssertFileContents(fs, "abc", "");
+
+ // Parent does not exist
+ if (!have_implicit_directories()) {
+ ASSERT_RAISES(IOError, fs->OpenOutputStream("AB/def"));
+ }
+ AssertAllDirs(fs, {});
+ AssertAllFiles(fs, {"abc"});
+
+ // Several writes
+ ASSERT_OK(fs->CreateDir("CD"));
+ ASSERT_OK_AND_ASSIGN(stream, fs->OpenOutputStream("CD/ghi"));
+ ASSERT_OK(stream->Write("some "));
+ ASSERT_OK(stream->Write(Buffer::FromString("data")));
+ ASSERT_OK_AND_EQ(9, stream->Tell());
+ ASSERT_OK(stream->Close());
+ AssertAllDirs(fs, {"CD"});
+ AssertAllFiles(fs, {"CD/ghi", "abc"});
+ AssertFileContents(fs, "CD/ghi", "some data");
+
+ // Overwrite
+ ASSERT_OK_AND_ASSIGN(stream, fs->OpenOutputStream("CD/ghi"));
+ ASSERT_OK(stream->Write("overwritten"));
+ ASSERT_OK(stream->Close());
+ AssertAllDirs(fs, {"CD"});
+ AssertAllFiles(fs, {"CD/ghi", "abc"});
+ AssertFileContents(fs, "CD/ghi", "overwritten");
+
+ ASSERT_RAISES(Invalid, stream->Write("x")); // Stream is closed
+
+ // Storing metadata along file
+ auto metadata = KeyValueMetadata::Make({"Content-Type", "Content-Language"},
+ {"x-arrow/filesystem-test", "fr_FR"});
+ ASSERT_OK_AND_ASSIGN(stream, fs->OpenOutputStream("jkl", metadata));
+ ASSERT_OK(stream->Write("data"));
+ ASSERT_OK(stream->Close());
+ ASSERT_OK_AND_ASSIGN(auto input, fs->OpenInputStream("jkl"));
+ ASSERT_OK_AND_ASSIGN(auto got_metadata, input->ReadMetadata());
+ if (have_file_metadata()) {
+ ASSERT_NE(got_metadata, nullptr);
+ ASSERT_GE(got_metadata->size(), 2);
+ ASSERT_OK_AND_EQ("x-arrow/filesystem-test", got_metadata->Get("Content-Type"));
+ } else {
+ if (got_metadata) {
+ ASSERT_EQ(got_metadata->size(), 0);
+ }
+ }
+
+ if (!allow_write_file_over_dir()) {
+ // Cannot turn dir into file
+ ASSERT_RAISES(IOError, fs->OpenOutputStream("CD"));
+ AssertAllDirs(fs, {"CD"});
+ }
+}
+
+void GenericFileSystemTest::TestOpenAppendStream(FileSystem* fs) {
+ if (!allow_append_to_file()) {
+ GTEST_SKIP() << "Filesystem doesn't allow file appends";
+ }
+
+ std::shared_ptr<io::OutputStream> stream;
+
+ if (allow_append_to_new_file()) {
+ ASSERT_OK_AND_ASSIGN(stream, fs->OpenAppendStream("abc"));
+ } else {
+ ASSERT_OK_AND_ASSIGN(stream, fs->OpenOutputStream("abc"));
+ }
+ ASSERT_OK_AND_EQ(0, stream->Tell());
+ ASSERT_OK(stream->Write("some "));
+ ASSERT_OK(stream->Write(Buffer::FromString("data")));
+ ASSERT_OK_AND_EQ(9, stream->Tell());
+ ASSERT_OK(stream->Close());
+ AssertAllDirs(fs, {});
+ AssertAllFiles(fs, {"abc"});
+ AssertFileContents(fs, "abc", "some data");
+
+ ASSERT_OK_AND_ASSIGN(stream, fs->OpenAppendStream("abc"));
+ ASSERT_OK_AND_EQ(9, stream->Tell());
+ ASSERT_OK(stream->Write(" appended"));
+ ASSERT_OK(stream->Close());
+ AssertAllDirs(fs, {});
+ AssertAllFiles(fs, {"abc"});
+ AssertFileContents(fs, "abc", "some data appended");
+
+ ASSERT_RAISES(Invalid, stream->Write("x")); // Stream is closed
+}
+
+void GenericFileSystemTest::TestOpenInputStream(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB"));
+ CreateFile(fs, "AB/abc", "some data");
+
+ std::shared_ptr<io::InputStream> stream;
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_OK_AND_ASSIGN(stream, fs->OpenInputStream("AB/abc"));
+ ASSERT_OK_AND_ASSIGN(auto metadata, stream->ReadMetadata());
+ // XXX we cannot really test anything more about metadata...
+ ASSERT_OK_AND_ASSIGN(buffer, stream->Read(4));
+ AssertBufferEqual(*buffer, "some");
+ ASSERT_OK_AND_ASSIGN(buffer, stream->Read(6));
+ AssertBufferEqual(*buffer, " data");
+ ASSERT_OK_AND_ASSIGN(buffer, stream->Read(1));
+ AssertBufferEqual(*buffer, "");
+ ASSERT_OK(stream->Close());
+ ASSERT_RAISES(Invalid, stream->Read(1)); // Stream is closed
+
+ // File does not exist
+ ASSERT_RAISES(IOError, fs->OpenInputStream("AB/def"));
+ ASSERT_RAISES(IOError, fs->OpenInputStream("def"));
+
+ // Cannot open directory
+ ASSERT_RAISES(IOError, fs->OpenInputStream("AB"));
+}
+
+void GenericFileSystemTest::TestOpenInputStreamWithFileInfo(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB"));
+ CreateFile(fs, "AB/abc", "some data");
+
+ ASSERT_OK_AND_ASSIGN(auto info, fs->GetFileInfo("AB/abc"));
+
+ ASSERT_OK_AND_ASSIGN(auto stream, fs->OpenInputStream(info));
+ ASSERT_OK_AND_ASSIGN(auto buffer, stream->Read(9));
+ AssertBufferEqual(*buffer, "some data");
+
+ // Passing an incomplete FileInfo should also work
+ info.set_type(FileType::Unknown);
+ info.set_size(kNoSize);
+ info.set_mtime(kNoTime);
+ ASSERT_OK_AND_ASSIGN(stream, fs->OpenInputStream(info));
+ ASSERT_OK_AND_ASSIGN(buffer, stream->Read(4));
+ AssertBufferEqual(*buffer, "some");
+
+ // File does not exist
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo("zzzzt"));
+ ASSERT_RAISES(IOError, fs->OpenInputStream(info));
+ // (same, with incomplete FileInfo)
+ info.set_type(FileType::Unknown);
+ ASSERT_RAISES(IOError, fs->OpenInputStream(info));
+
+ // Cannot open directory
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo("AB"));
+ ASSERT_RAISES(IOError, fs->OpenInputStream(info));
+}
+
+void GenericFileSystemTest::TestOpenInputStreamAsync(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB"));
+ CreateFile(fs, "AB/abc", "some data");
+
+ std::shared_ptr<io::InputStream> stream;
+ std::shared_ptr<Buffer> buffer;
+ std::shared_ptr<const KeyValueMetadata> metadata;
+ ASSERT_FINISHES_OK_AND_ASSIGN(stream, fs->OpenInputStreamAsync("AB/abc"));
+ ASSERT_FINISHES_OK_AND_ASSIGN(metadata, stream->ReadMetadataAsync());
+ ASSERT_OK_AND_ASSIGN(buffer, stream->Read(4));
+ AssertBufferEqual(*buffer, "some");
+ ASSERT_OK(stream->Close());
+
+ // File does not exist
+ ASSERT_RAISES(IOError, fs->OpenInputStreamAsync("AB/def").result());
+}
+
+void GenericFileSystemTest::TestOpenInputFile(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB"));
+ CreateFile(fs, "AB/abc", "some other data");
+
+ std::shared_ptr<io::RandomAccessFile> file;
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_OK_AND_ASSIGN(file, fs->OpenInputFile("AB/abc"));
+ ASSERT_OK_AND_ASSIGN(buffer, file->ReadAt(5, 6));
+ AssertBufferEqual(*buffer, "other ");
+ ASSERT_OK_AND_EQ(15, file->GetSize());
+ ASSERT_OK(file->Close());
+ ASSERT_RAISES(Invalid, file->ReadAt(1, 1)); // Stream is closed
+
+ // File does not exist
+ ASSERT_RAISES(IOError, fs->OpenInputFile("AB/def"));
+ ASSERT_RAISES(IOError, fs->OpenInputFile("def"));
+
+ // Cannot open directory
+ ASSERT_RAISES(IOError, fs->OpenInputFile("AB"));
+}
+
+void GenericFileSystemTest::TestOpenInputFileAsync(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB"));
+ CreateFile(fs, "AB/abc", "some other data");
+
+ std::shared_ptr<io::RandomAccessFile> file;
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_FINISHES_OK_AND_ASSIGN(file, fs->OpenInputFileAsync("AB/abc"));
+ ASSERT_OK_AND_ASSIGN(buffer, file->ReadAt(5, 6));
+ AssertBufferEqual(*buffer, "other ");
+ ASSERT_OK(file->Close());
+
+ // File does not exist
+ ASSERT_RAISES(IOError, fs->OpenInputFileAsync("AB/def").result());
+}
+
+void GenericFileSystemTest::TestOpenInputFileWithFileInfo(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("AB"));
+ CreateFile(fs, "AB/abc", "some data");
+
+ ASSERT_OK_AND_ASSIGN(auto info, fs->GetFileInfo("AB/abc"));
+
+ ASSERT_OK_AND_ASSIGN(auto file, fs->OpenInputFile(info));
+ ASSERT_OK_AND_EQ(9, file->GetSize());
+ ASSERT_OK_AND_ASSIGN(auto buffer, file->Read(9));
+ AssertBufferEqual(*buffer, "some data");
+
+ // Passing an incomplete FileInfo should also work
+ info.set_type(FileType::Unknown);
+ info.set_size(kNoSize);
+ info.set_mtime(kNoTime);
+ ASSERT_OK_AND_ASSIGN(file, fs->OpenInputFile(info));
+ ASSERT_OK_AND_EQ(9, file->GetSize());
+ ASSERT_OK_AND_ASSIGN(buffer, file->Read(4));
+ AssertBufferEqual(*buffer, "some");
+
+ // File does not exist
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo("zzzzt"));
+ ASSERT_RAISES(IOError, fs->OpenInputFile(info));
+ // (same, with incomplete FileInfo)
+ info.set_type(FileType::Unknown);
+ ASSERT_RAISES(IOError, fs->OpenInputFile(info));
+
+ // Cannot open directory
+ ASSERT_OK_AND_ASSIGN(info, fs->GetFileInfo("AB"));
+ ASSERT_RAISES(IOError, fs->OpenInputFile(info));
+}
+
+void GenericFileSystemTest::TestSpecialChars(FileSystem* fs) {
+ ASSERT_OK(fs->CreateDir("Blank Char"));
+ CreateFile(fs, "Blank Char/Special%Char.txt", "data");
+ std::vector<std::string> all_dirs{"Blank Char"};
+
+ AssertAllDirs(fs, all_dirs);
+ AssertAllFiles(fs, {"Blank Char/Special%Char.txt"});
+ AssertFileContents(fs, "Blank Char/Special%Char.txt", "data");
+
+ ASSERT_OK(fs->CopyFile("Blank Char/Special%Char.txt", "Special and%different.txt"));
+ AssertAllDirs(fs, all_dirs);
+ AssertAllFiles(fs, {"Blank Char/Special%Char.txt", "Special and%different.txt"});
+ AssertFileContents(fs, "Special and%different.txt", "data");
+
+ ASSERT_OK(fs->DeleteFile("Special and%different.txt"));
+ ASSERT_OK(fs->DeleteDir("Blank Char"));
+ AssertAllDirs(fs, {});
+ AssertAllFiles(fs, {});
+}
+
+#define GENERIC_FS_TEST_DEFINE(FUNC_NAME) \
+ void GenericFileSystemTest::FUNC_NAME() { FUNC_NAME(GetEmptyFileSystem().get()); }
+
+GENERIC_FS_TEST_DEFINE(TestEmpty)
+GENERIC_FS_TEST_DEFINE(TestNormalizePath)
+GENERIC_FS_TEST_DEFINE(TestCreateDir)
+GENERIC_FS_TEST_DEFINE(TestDeleteDir)
+GENERIC_FS_TEST_DEFINE(TestDeleteDirContents)
+GENERIC_FS_TEST_DEFINE(TestDeleteRootDirContents)
+GENERIC_FS_TEST_DEFINE(TestDeleteFile)
+GENERIC_FS_TEST_DEFINE(TestDeleteFiles)
+GENERIC_FS_TEST_DEFINE(TestMoveFile)
+GENERIC_FS_TEST_DEFINE(TestMoveDir)
+GENERIC_FS_TEST_DEFINE(TestCopyFile)
+GENERIC_FS_TEST_DEFINE(TestGetFileInfo)
+GENERIC_FS_TEST_DEFINE(TestGetFileInfoVector)
+GENERIC_FS_TEST_DEFINE(TestGetFileInfoSelector)
+GENERIC_FS_TEST_DEFINE(TestGetFileInfoSelectorWithRecursion)
+GENERIC_FS_TEST_DEFINE(TestGetFileInfoAsync)
+GENERIC_FS_TEST_DEFINE(TestGetFileInfoGenerator)
+GENERIC_FS_TEST_DEFINE(TestOpenOutputStream)
+GENERIC_FS_TEST_DEFINE(TestOpenAppendStream)
+GENERIC_FS_TEST_DEFINE(TestOpenInputStream)
+GENERIC_FS_TEST_DEFINE(TestOpenInputStreamWithFileInfo)
+GENERIC_FS_TEST_DEFINE(TestOpenInputStreamAsync)
+GENERIC_FS_TEST_DEFINE(TestOpenInputFile)
+GENERIC_FS_TEST_DEFINE(TestOpenInputFileWithFileInfo)
+GENERIC_FS_TEST_DEFINE(TestOpenInputFileAsync)
+GENERIC_FS_TEST_DEFINE(TestSpecialChars)
+
+#undef GENERIC_FS_TEST_DEFINE
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/test_util.h b/src/arrow/cpp/src/arrow/filesystem/test_util.h
new file mode 100644
index 000000000..8d80b10ba
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/test_util.h
@@ -0,0 +1,246 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <chrono>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/filesystem/filesystem.h"
+#include "arrow/filesystem/mockfs.h"
+#include "arrow/testing/visibility.h"
+#include "arrow/util/counting_semaphore.h"
+
+namespace arrow {
+namespace fs {
+
+static constexpr double kTimeSlack = 2.0; // In seconds
+
+static inline FileInfo File(std::string path) {
+ return FileInfo(std::move(path), FileType::File);
+}
+
+static inline FileInfo Dir(std::string path) {
+ return FileInfo(std::move(path), FileType::Directory);
+}
+
+// A subclass of MockFileSystem that blocks operations until an unlock method is
+// called.
+//
+// This is intended for testing fine-grained ordering of filesystem operations.
+//
+// N.B. Only OpenOutputStream supports gating at the moment but this is simply because
+// it is all that has been needed so far. Feel free to add support for more methods
+// as required.
+class ARROW_TESTING_EXPORT GatedMockFilesystem : public internal::MockFileSystem {
+ public:
+ GatedMockFilesystem(TimePoint current_time,
+ const io::IOContext& = io::default_io_context());
+ ~GatedMockFilesystem() override;
+
+ Result<std::shared_ptr<io::OutputStream>> OpenOutputStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = {}) override;
+
+ // Wait until at least num_waiters are waiting on OpenOutputStream
+ Status WaitForOpenOutputStream(uint32_t num_waiters);
+ // Unlock `num_waiters` individual calls to OpenOutputStream
+ Status UnlockOpenOutputStream(uint32_t num_waiters);
+
+ private:
+ util::CountingSemaphore open_output_sem_;
+};
+
+ARROW_TESTING_EXPORT
+void CreateFile(FileSystem* fs, const std::string& path, const std::string& data);
+
+// Sort a vector of FileInfo by lexicographic path order
+ARROW_TESTING_EXPORT
+void SortInfos(FileInfoVector* infos);
+
+ARROW_TESTING_EXPORT
+void CollectFileInfoGenerator(FileInfoGenerator gen, FileInfoVector* out_infos);
+
+ARROW_TESTING_EXPORT
+void AssertFileInfo(const FileInfo& info, const std::string& path, FileType type);
+
+ARROW_TESTING_EXPORT
+void AssertFileInfo(const FileInfo& info, const std::string& path, FileType type,
+ TimePoint mtime);
+
+ARROW_TESTING_EXPORT
+void AssertFileInfo(const FileInfo& info, const std::string& path, FileType type,
+ TimePoint mtime, int64_t size);
+
+ARROW_TESTING_EXPORT
+void AssertFileInfo(const FileInfo& info, const std::string& path, FileType type,
+ int64_t size);
+
+ARROW_TESTING_EXPORT
+void AssertFileInfo(FileSystem* fs, const std::string& path, FileType type);
+
+ARROW_TESTING_EXPORT
+void AssertFileInfo(FileSystem* fs, const std::string& path, FileType type,
+ TimePoint mtime);
+
+ARROW_TESTING_EXPORT
+void AssertFileInfo(FileSystem* fs, const std::string& path, FileType type,
+ TimePoint mtime, int64_t size);
+
+ARROW_TESTING_EXPORT
+void AssertFileInfo(FileSystem* fs, const std::string& path, FileType type, int64_t size);
+
+ARROW_TESTING_EXPORT
+void AssertFileContents(FileSystem* fs, const std::string& path,
+ const std::string& expected_data);
+
+template <typename Duration>
+void AssertDurationBetween(Duration d, double min_secs, double max_secs) {
+ auto seconds = std::chrono::duration_cast<std::chrono::duration<double>>(d);
+ ASSERT_GE(seconds.count(), min_secs);
+ ASSERT_LE(seconds.count(), max_secs);
+}
+
+// Generic tests for FileSystem implementations.
+// To use this class, subclass both from it and ::testing::Test,
+// implement GetEmptyFileSystem(), and use GENERIC_FS_TEST_FUNCTIONS()
+// to define the various tests.
+class ARROW_TESTING_EXPORT GenericFileSystemTest {
+ public:
+ virtual ~GenericFileSystemTest();
+
+ void TestEmpty();
+ void TestNormalizePath();
+ void TestCreateDir();
+ void TestDeleteDir();
+ void TestDeleteDirContents();
+ void TestDeleteRootDirContents();
+ void TestDeleteFile();
+ void TestDeleteFiles();
+ void TestMoveFile();
+ void TestMoveDir();
+ void TestCopyFile();
+ void TestGetFileInfo();
+ void TestGetFileInfoVector();
+ void TestGetFileInfoSelector();
+ void TestGetFileInfoSelectorWithRecursion();
+ void TestGetFileInfoAsync();
+ void TestGetFileInfoGenerator();
+ void TestOpenOutputStream();
+ void TestOpenAppendStream();
+ void TestOpenInputStream();
+ void TestOpenInputStreamWithFileInfo();
+ void TestOpenInputStreamAsync();
+ void TestOpenInputFile();
+ void TestOpenInputFileWithFileInfo();
+ void TestOpenInputFileAsync();
+ void TestSpecialChars();
+
+ protected:
+ // This function should return the filesystem under test.
+ virtual std::shared_ptr<FileSystem> GetEmptyFileSystem() = 0;
+
+ // Override the following functions to specify deviations from expected
+ // filesystem semantics.
+ // - Whether the filesystem may "implicitly" create intermediate directories
+ virtual bool have_implicit_directories() const { return false; }
+ // - Whether the filesystem may allow writing a file "over" a directory
+ virtual bool allow_write_file_over_dir() const { return false; }
+ // - Whether the filesystem allows moving a directory
+ virtual bool allow_move_dir() const { return true; }
+ // - Whether the filesystem allows moving a directory "over" a non-empty destination
+ virtual bool allow_move_dir_over_non_empty_dir() const { return false; }
+ // - Whether the filesystem allows appending to a file
+ virtual bool allow_append_to_file() const { return true; }
+ // - Whether the filesystem allows appending to a new (not existent yet) file
+ virtual bool allow_append_to_new_file() const { return true; }
+ // - Whether the filesystem supports directory modification times
+ virtual bool have_directory_mtimes() const { return true; }
+ // - Whether some directory tree deletion tests may fail randomly
+ virtual bool have_flaky_directory_tree_deletion() const { return false; }
+ // - Whether the filesystem stores some metadata alongside files
+ virtual bool have_file_metadata() const { return false; }
+
+ void TestEmpty(FileSystem* fs);
+ void TestNormalizePath(FileSystem* fs);
+ void TestCreateDir(FileSystem* fs);
+ void TestDeleteDir(FileSystem* fs);
+ void TestDeleteDirContents(FileSystem* fs);
+ void TestDeleteRootDirContents(FileSystem* fs);
+ void TestDeleteFile(FileSystem* fs);
+ void TestDeleteFiles(FileSystem* fs);
+ void TestMoveFile(FileSystem* fs);
+ void TestMoveDir(FileSystem* fs);
+ void TestCopyFile(FileSystem* fs);
+ void TestGetFileInfo(FileSystem* fs);
+ void TestGetFileInfoVector(FileSystem* fs);
+ void TestGetFileInfoSelector(FileSystem* fs);
+ void TestGetFileInfoSelectorWithRecursion(FileSystem* fs);
+ void TestGetFileInfoAsync(FileSystem* fs);
+ void TestGetFileInfoGenerator(FileSystem* fs);
+ void TestOpenOutputStream(FileSystem* fs);
+ void TestOpenAppendStream(FileSystem* fs);
+ void TestOpenInputStream(FileSystem* fs);
+ void TestOpenInputStreamWithFileInfo(FileSystem* fs);
+ void TestOpenInputStreamAsync(FileSystem* fs);
+ void TestOpenInputFile(FileSystem* fs);
+ void TestOpenInputFileWithFileInfo(FileSystem* fs);
+ void TestOpenInputFileAsync(FileSystem* fs);
+ void TestSpecialChars(FileSystem* fs);
+};
+
+#define GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, NAME) \
+ TEST_MACRO(TEST_CLASS, NAME) { this->Test##NAME(); }
+
+#define GENERIC_FS_TEST_FUNCTIONS_MACROS(TEST_MACRO, TEST_CLASS) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, Empty) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, NormalizePath) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, CreateDir) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, DeleteDir) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, DeleteDirContents) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, DeleteRootDirContents) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, DeleteFile) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, DeleteFiles) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, MoveFile) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, MoveDir) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, CopyFile) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, GetFileInfo) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, GetFileInfoVector) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, GetFileInfoSelector) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, GetFileInfoSelectorWithRecursion) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, GetFileInfoAsync) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, GetFileInfoGenerator) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, OpenOutputStream) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, OpenAppendStream) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, OpenInputStream) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, OpenInputStreamWithFileInfo) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, OpenInputStreamAsync) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, OpenInputFile) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, OpenInputFileWithFileInfo) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, OpenInputFileAsync) \
+ GENERIC_FS_TEST_FUNCTION(TEST_MACRO, TEST_CLASS, SpecialChars)
+
+#define GENERIC_FS_TEST_FUNCTIONS(TEST_CLASS) \
+ GENERIC_FS_TEST_FUNCTIONS_MACROS(TEST_F, TEST_CLASS)
+
+#define GENERIC_FS_TYPED_TEST_FUNCTIONS(TEST_CLASS) \
+ GENERIC_FS_TEST_FUNCTIONS_MACROS(TYPED_TEST, TEST_CLASS)
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/type_fwd.h b/src/arrow/cpp/src/arrow/filesystem/type_fwd.h
new file mode 100644
index 000000000..112563577
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/type_fwd.h
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+namespace arrow {
+namespace fs {
+
+/// \brief FileSystem entry type
+enum class FileType : int8_t {
+ /// Entry is not found
+ NotFound,
+ /// Entry exists but its type is unknown
+ ///
+ /// This can designate a special file such as a Unix socket or character
+ /// device, or Windows NUL / CON / ...
+ Unknown,
+ /// Entry is a regular file
+ File,
+ /// Entry is a directory
+ Directory
+};
+
+struct FileInfo;
+
+struct FileSelector;
+
+class FileSystem;
+class SubTreeFileSystem;
+class SlowFileSystem;
+class LocalFileSystem;
+class S3FileSystem;
+
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/util_internal.cc b/src/arrow/cpp/src/arrow/filesystem/util_internal.cc
new file mode 100644
index 000000000..8f8670737
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/util_internal.cc
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/filesystem/util_internal.h"
+#include "arrow/buffer.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+
+namespace arrow {
+namespace fs {
+namespace internal {
+
+TimePoint CurrentTimePoint() {
+ auto now = std::chrono::system_clock::now();
+ return TimePoint(
+ std::chrono::duration_cast<TimePoint::duration>(now.time_since_epoch()));
+}
+
+Status CopyStream(const std::shared_ptr<io::InputStream>& src,
+ const std::shared_ptr<io::OutputStream>& dest, int64_t chunk_size,
+ const io::IOContext& io_context) {
+ ARROW_ASSIGN_OR_RAISE(auto chunk, AllocateBuffer(chunk_size, io_context.pool()));
+
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read,
+ src->Read(chunk_size, chunk->mutable_data()));
+ if (bytes_read == 0) {
+ // EOF
+ break;
+ }
+ RETURN_NOT_OK(dest->Write(chunk->data(), bytes_read));
+ }
+
+ return Status::OK();
+}
+
+Status PathNotFound(const std::string& path) {
+ return Status::IOError("Path does not exist '", path, "'");
+}
+
+Status NotADir(const std::string& path) {
+ return Status::IOError("Not a directory: '", path, "'");
+}
+
+Status NotAFile(const std::string& path) {
+ return Status::IOError("Not a regular file: '", path, "'");
+}
+
+Status InvalidDeleteDirContents(const std::string& path) {
+ return Status::Invalid(
+ "DeleteDirContents called on invalid path '", path, "'. ",
+ "If you wish to delete the root directory's contents, call DeleteRootDirContents.");
+}
+
+FileSystemGlobalOptions global_options;
+
+} // namespace internal
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/filesystem/util_internal.h b/src/arrow/cpp/src/arrow/filesystem/util_internal.h
new file mode 100644
index 000000000..915c8d03d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/filesystem/util_internal.h
@@ -0,0 +1,56 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/filesystem/filesystem.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/status.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace fs {
+namespace internal {
+
+ARROW_EXPORT
+TimePoint CurrentTimePoint();
+
+ARROW_EXPORT
+Status CopyStream(const std::shared_ptr<io::InputStream>& src,
+ const std::shared_ptr<io::OutputStream>& dest, int64_t chunk_size,
+ const io::IOContext& io_context);
+
+ARROW_EXPORT
+Status PathNotFound(const std::string& path);
+
+ARROW_EXPORT
+Status NotADir(const std::string& path);
+
+ARROW_EXPORT
+Status NotAFile(const std::string& path);
+
+ARROW_EXPORT
+Status InvalidDeleteDirContents(const std::string& path);
+
+extern FileSystemGlobalOptions global_options;
+
+} // namespace internal
+} // namespace fs
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/ArrowFlightConfig.cmake.in b/src/arrow/cpp/src/arrow/flight/ArrowFlightConfig.cmake.in
new file mode 100644
index 000000000..11be45794
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/ArrowFlightConfig.cmake.in
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This config sets the following variables in your project::
+#
+# ArrowFlight_FOUND - true if Arrow Flight found on the system
+#
+# This config sets the following targets in your project::
+#
+# arrow_flight_shared - for linked as shared library if shared library is built
+# arrow_flight_static - for linked as static library if static library is built
+
+@PACKAGE_INIT@
+
+include(CMakeFindDependencyMacro)
+find_dependency(Arrow)
+
+# Load targets only once. If we load targets multiple times, CMake reports
+# already existent target error.
+if(NOT (TARGET arrow_flight_shared OR TARGET arrow_flight_static))
+ include("${CMAKE_CURRENT_LIST_DIR}/ArrowFlightTargets.cmake")
+endif()
diff --git a/src/arrow/cpp/src/arrow/flight/ArrowFlightTestingConfig.cmake.in b/src/arrow/cpp/src/arrow/flight/ArrowFlightTestingConfig.cmake.in
new file mode 100644
index 000000000..f3e1a63d6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/ArrowFlightTestingConfig.cmake.in
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This config sets the following variables in your project::
+#
+# ArrowFlightTesting_FOUND - true if Arrow Flight testing library found on the system
+#
+# This config sets the following targets in your project::
+#
+# arrow_flight_testing_shared - for linked as shared library if shared library is built
+# arrow_flight_testing_static - for linked as static library if static library is built
+
+@PACKAGE_INIT@
+
+include(CMakeFindDependencyMacro)
+find_dependency(ArrowFlight)
+find_dependency(ArrowTesting)
+
+# Load targets only once. If we load targets multiple times, CMake reports
+# already existent target error.
+if(NOT (TARGET arrow_flight_testing_shared OR TARGET arrow_flight_testing_static))
+ include("${CMAKE_CURRENT_LIST_DIR}/ArrowFlightTestingTargets.cmake")
+endif()
diff --git a/src/arrow/cpp/src/arrow/flight/CMakeLists.txt b/src/arrow/cpp/src/arrow/flight/CMakeLists.txt
new file mode 100644
index 000000000..309e5a968
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/CMakeLists.txt
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+add_custom_target(arrow_flight)
+
+arrow_install_all_headers("arrow/flight")
+
+set(ARROW_FLIGHT_LINK_LIBS gRPC::grpc++ ${ARROW_PROTOBUF_LIBPROTOBUF})
+
+if(WIN32)
+ list(APPEND ARROW_FLIGHT_LINK_LIBS ws2_32.lib)
+endif()
+
+if(ARROW_TEST_LINKAGE STREQUAL "static")
+ set(ARROW_FLIGHT_TEST_LINK_LIBS
+ arrow_flight_static arrow_flight_testing_static ${ARROW_FLIGHT_STATIC_LINK_LIBS}
+ ${ARROW_TEST_LINK_LIBS})
+else()
+ set(ARROW_FLIGHT_TEST_LINK_LIBS arrow_flight_shared arrow_flight_testing_shared
+ ${ARROW_TEST_LINK_LIBS})
+endif()
+
+# TODO(wesm): Protobuf shared vs static linking
+
+set(FLIGHT_PROTO_PATH "${ARROW_SOURCE_DIR}/../format")
+set(FLIGHT_PROTO ${ARROW_SOURCE_DIR}/../format/Flight.proto)
+
+set(FLIGHT_GENERATED_PROTO_FILES
+ "${CMAKE_CURRENT_BINARY_DIR}/Flight.pb.cc" "${CMAKE_CURRENT_BINARY_DIR}/Flight.pb.h"
+ "${CMAKE_CURRENT_BINARY_DIR}/Flight.grpc.pb.cc"
+ "${CMAKE_CURRENT_BINARY_DIR}/Flight.grpc.pb.h")
+
+set(PROTO_DEPENDS ${FLIGHT_PROTO} ${ARROW_PROTOBUF_LIBPROTOBUF} gRPC::grpc_cpp_plugin)
+
+add_custom_command(OUTPUT ${FLIGHT_GENERATED_PROTO_FILES}
+ COMMAND ${ARROW_PROTOBUF_PROTOC} "-I${FLIGHT_PROTO_PATH}"
+ "--cpp_out=${CMAKE_CURRENT_BINARY_DIR}" "${FLIGHT_PROTO}"
+ DEPENDS ${PROTO_DEPENDS} ARGS
+ COMMAND ${ARROW_PROTOBUF_PROTOC} "-I${FLIGHT_PROTO_PATH}"
+ "--grpc_out=${CMAKE_CURRENT_BINARY_DIR}"
+ "--plugin=protoc-gen-grpc=$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
+ "${FLIGHT_PROTO}")
+
+set_source_files_properties(${FLIGHT_GENERATED_PROTO_FILES} PROPERTIES GENERATED TRUE)
+
+add_custom_target(flight_grpc_gen ALL DEPENDS ${FLIGHT_GENERATED_PROTO_FILES})
+
+# <KLUDGE> -Werror / /WX cause try_compile to fail because there seems to be no
+# way to pass -isystem $GRPC_INCLUDE_DIR instead of -I$GRPC_INCLUDE_DIR
+set(CMAKE_CXX_FLAGS_BACKUP "${CMAKE_CXX_FLAGS}")
+string(REPLACE "/WX" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
+string(REPLACE "-Werror " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
+
+# Probe the version of gRPC being used to see if it supports disabling server
+# verification when using TLS.
+function(test_grpc_version DST_VAR DETECT_VERSION TEST_FILE)
+ if(NOT DEFINED ${DST_VAR})
+ message(STATUS "Checking support for TlsCredentialsOptions (gRPC >= ${DETECT_VERSION})..."
+ )
+ get_property(CURRENT_INCLUDE_DIRECTORIES
+ DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+ PROPERTY INCLUDE_DIRECTORIES)
+ # ARROW-13881: when detecting support, avoid mismatch between
+ # debug flags of gRPC and our probe (which results in LNK2038)
+ set(CMAKE_TRY_COMPILE_CONFIGURATION ${CMAKE_BUILD_TYPE})
+ try_compile(HAS_GRPC_VERSION ${CMAKE_CURRENT_BINARY_DIR}/try_compile
+ SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/try_compile/${TEST_FILE}"
+ CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CURRENT_INCLUDE_DIRECTORIES}"
+ LINK_LIBRARIES gRPC::grpc++
+ OUTPUT_VARIABLE TLS_CREDENTIALS_OPTIONS_CHECK_OUTPUT CXX_STANDARD 11)
+ if(HAS_GRPC_VERSION)
+ set(${DST_VAR}
+ "${DETECT_VERSION}"
+ CACHE INTERNAL "The detected (approximate) gRPC version.")
+ else()
+ message(STATUS "TlsCredentialsOptions (for gRPC ${DETECT_VERSION}) not found in grpc::experimental."
+ )
+ message(DEBUG "Build output:")
+ list(APPEND CMAKE_MESSAGE_INDENT "${TEST_FILE}: ")
+ message(DEBUG ${TLS_CREDENTIALS_OPTIONS_CHECK_OUTPUT})
+ list(REMOVE_AT CMAKE_MESSAGE_INDENT -1)
+ endif()
+ endif()
+endfunction()
+
+if(GRPC_VENDORED)
+ # v1.35.0 -> 1.35
+ string(REGEX MATCH "[0-9]+\\.[0-9]+" GRPC_VERSION "${ARROW_GRPC_BUILD_VERSION}")
+else()
+ test_grpc_version(GRPC_VERSION "1.36" "check_tls_opts_136.cc")
+ test_grpc_version(GRPC_VERSION "1.34" "check_tls_opts_134.cc")
+ test_grpc_version(GRPC_VERSION "1.32" "check_tls_opts_132.cc")
+ test_grpc_version(GRPC_VERSION "1.27" "check_tls_opts_127.cc")
+ message(STATUS "Found approximate gRPC version: ${GRPC_VERSION} (ARROW_FLIGHT_REQUIRE_TLSCREDENTIALSOPTIONS=${ARROW_FLIGHT_REQUIRE_TLSCREDENTIALSOPTIONS})"
+ )
+endif()
+if(GRPC_VERSION EQUAL "1.27")
+ add_definitions(-DGRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS=grpc_impl::experimental)
+elseif(GRPC_VERSION EQUAL "1.32")
+ add_definitions(-DGRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS=grpc::experimental)
+elseif(GRPC_VERSION EQUAL "1.34" OR GRPC_VERSION EQUAL "1.35")
+ add_definitions(-DGRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS
+ -DGRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS_ROOT_CERTS
+ -DGRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS=grpc::experimental)
+elseif(GRPC_VERSION EQUAL "1.36")
+ add_definitions(-DGRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS
+ -DGRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS=grpc::experimental)
+else()
+ message(STATUS "A proper version of gRPC could not be found to support TlsCredentialsOptions in Arrow Flight."
+ )
+ message(STATUS "You may need a newer version of gRPC (>= 1.27), or the gRPC API has changed and Flight must be updated to match."
+ )
+ if(ARROW_FLIGHT_REQUIRE_TLSCREDENTIALSOPTIONS)
+ message(FATAL_ERROR "Halting build since ARROW_FLIGHT_REQUIRE_TLSCREDENTIALSOPTIONS is set."
+ )
+ endif()
+endif()
+
+# </KLUDGE> Restore the CXXFLAGS that were modified above
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}")
+
+# Note, we do not compile the generated Protobuf sources directly, instead
+# compiling then via protocol_internal.cc which contains some gRPC template
+# overrides to enable Flight-specific optimizations. See comments in
+# protobuf-internal.cc
+set(ARROW_FLIGHT_SRCS
+ client.cc
+ client_cookie_middleware.cc
+ client_header_internal.cc
+ internal.cc
+ protocol_internal.cc
+ serialization_internal.cc
+ server.cc
+ server_auth.cc
+ types.cc)
+
+add_arrow_lib(arrow_flight
+ CMAKE_PACKAGE_NAME
+ ArrowFlight
+ PKG_CONFIG_NAME
+ arrow-flight
+ OUTPUTS
+ ARROW_FLIGHT_LIBRARIES
+ SOURCES
+ ${ARROW_FLIGHT_SRCS}
+ PRECOMPILED_HEADERS
+ "$<$<COMPILE_LANGUAGE:CXX>:arrow/flight/pch.h>"
+ DEPENDENCIES
+ flight_grpc_gen
+ SHARED_LINK_FLAGS
+ ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt
+ SHARED_LINK_LIBS
+ arrow_shared
+ ${ARROW_FLIGHT_LINK_LIBS}
+ STATIC_LINK_LIBS
+ arrow_static
+ ${ARROW_FLIGHT_LINK_LIBS})
+
+foreach(LIB_TARGET ${ARROW_FLIGHT_LIBRARIES})
+ target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_FLIGHT_EXPORTING)
+endforeach()
+
+# Define arrow_flight_testing library
+if(ARROW_TESTING)
+ add_arrow_lib(arrow_flight_testing
+ CMAKE_PACKAGE_NAME
+ ArrowFlightTesting
+ PKG_CONFIG_NAME
+ arrow-flight-testing
+ OUTPUTS
+ ARROW_FLIGHT_TESTING_LIBRARIES
+ SOURCES
+ test_integration.cc
+ test_util.cc
+ DEPENDENCIES
+ GTest::gtest
+ flight_grpc_gen
+ arrow_dependencies
+ SHARED_LINK_LIBS
+ arrow_shared
+ arrow_flight_shared
+ arrow_testing_shared
+ ${BOOST_FILESYSTEM_LIBRARY}
+ ${BOOST_SYSTEM_LIBRARY}
+ GTest::gtest
+ STATIC_LINK_LIBS
+ arrow_static
+ arrow_flight_static
+ arrow_testing_static)
+endif()
+
+foreach(LIB_TARGET ${ARROW_FLIGHT_TESTING_LIBRARIES})
+ target_compile_definitions(${LIB_TARGET}
+ PRIVATE ARROW_FLIGHT_EXPORTING
+ ${ARROW_BOOST_PROCESS_COMPILE_DEFINITIONS})
+endforeach()
+
+add_arrow_test(flight_test
+ STATIC_LINK_LIBS
+ ${ARROW_FLIGHT_TEST_LINK_LIBS}
+ LABELS
+ "arrow_flight")
+
+# Build test server for unit tests or benchmarks
+if(ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS)
+ add_executable(flight-test-server test_server.cc)
+ target_link_libraries(flight-test-server ${ARROW_FLIGHT_TEST_LINK_LIBS}
+ ${GFLAGS_LIBRARIES} GTest::gtest)
+
+ if(ARROW_BUILD_TESTS)
+ add_dependencies(arrow-flight-test flight-test-server)
+ endif()
+
+ add_dependencies(arrow_flight flight-test-server)
+endif()
+
+if(ARROW_BUILD_INTEGRATION)
+ add_executable(flight-test-integration-server test_integration_server.cc)
+ target_link_libraries(flight-test-integration-server ${ARROW_FLIGHT_TEST_LINK_LIBS}
+ ${GFLAGS_LIBRARIES} GTest::gtest)
+
+ add_executable(flight-test-integration-client test_integration_client.cc)
+ target_link_libraries(flight-test-integration-client ${ARROW_FLIGHT_TEST_LINK_LIBS}
+ ${GFLAGS_LIBRARIES} GTest::gtest)
+
+ add_dependencies(arrow_flight flight-test-integration-client
+ flight-test-integration-server)
+ add_dependencies(arrow-integration flight-test-integration-client
+ flight-test-integration-server)
+endif()
+
+if(ARROW_BUILD_BENCHMARKS)
+ # Perf server for benchmarks
+ set(PERF_PROTO_GENERATED_FILES "${CMAKE_CURRENT_BINARY_DIR}/perf.pb.cc"
+ "${CMAKE_CURRENT_BINARY_DIR}/perf.pb.h")
+
+ add_custom_command(OUTPUT ${PERF_PROTO_GENERATED_FILES}
+ COMMAND ${ARROW_PROTOBUF_PROTOC} "-I${CMAKE_CURRENT_SOURCE_DIR}"
+ "--cpp_out=${CMAKE_CURRENT_BINARY_DIR}" "perf.proto"
+ DEPENDS ${PROTO_DEPENDS})
+
+ add_executable(arrow-flight-perf-server perf_server.cc perf.pb.cc)
+ target_link_libraries(arrow-flight-perf-server ${ARROW_FLIGHT_TEST_LINK_LIBS}
+ ${GFLAGS_LIBRARIES} GTest::gtest)
+
+ add_executable(arrow-flight-benchmark flight_benchmark.cc perf.pb.cc)
+ target_link_libraries(arrow-flight-benchmark ${ARROW_FLIGHT_TEST_LINK_LIBS}
+ ${GFLAGS_LIBRARIES} GTest::gtest)
+
+ add_dependencies(arrow-flight-benchmark arrow-flight-perf-server)
+
+ add_dependencies(arrow_flight arrow-flight-benchmark)
+endif(ARROW_BUILD_BENCHMARKS)
diff --git a/src/arrow/cpp/src/arrow/flight/README.md b/src/arrow/cpp/src/arrow/flight/README.md
new file mode 100644
index 000000000..5156973ac
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/README.md
@@ -0,0 +1,36 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# Arrow Flight RPC System for C++
+
+## Development notes
+
+The gRPC protobuf plugin requires that libprotoc is in your
+`LD_LIBRARY_PATH`. Until we figure out a general solution, you may need to do:
+
+```
+export LD_LIBRARY_PATH=$PROTOBUF_HOME/lib:$LD_LIBRARY_PATH
+```
+
+Currently, to run the unit tests, the directory of executables must either be
+your current working directory or you need to add it to your path, e.g.
+
+```
+PATH=debug:$PATH debug/flight-test
+``` \ No newline at end of file
diff --git a/src/arrow/cpp/src/arrow/flight/api.h b/src/arrow/cpp/src/arrow/flight/api.h
new file mode 100644
index 000000000..c58a9d48a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/api.h
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/flight/client.h"
+#include "arrow/flight/client_auth.h"
+#include "arrow/flight/client_middleware.h"
+#include "arrow/flight/middleware.h"
+#include "arrow/flight/server.h"
+#include "arrow/flight/server_auth.h"
+#include "arrow/flight/server_middleware.h"
+#include "arrow/flight/types.h"
diff --git a/src/arrow/cpp/src/arrow/flight/arrow-flight-testing.pc.in b/src/arrow/cpp/src/arrow/flight/arrow-flight-testing.pc.in
new file mode 100644
index 000000000..6946b84f7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/arrow-flight-testing.pc.in
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+Name: Apache Arrow Flight testing
+Description: Library for testing Apache Arrow Flight related programs.
+Version: @ARROW_VERSION@
+Requires: arrow-flight arrow-testing
+Libs: -L${libdir} -larrow_flight_testing
diff --git a/src/arrow/cpp/src/arrow/flight/arrow-flight.pc.in b/src/arrow/cpp/src/arrow/flight/arrow-flight.pc.in
new file mode 100644
index 000000000..3d66f9937
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/arrow-flight.pc.in
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+Name: Apache Arrow Flight
+Description: Apache Arrow's RPC system built on gRPC
+Version: @ARROW_VERSION@
+Requires: arrow
+Libs: -L${libdir} -larrow_flight
diff --git a/src/arrow/cpp/src/arrow/flight/client.cc b/src/arrow/cpp/src/arrow/flight/client.cc
new file mode 100644
index 000000000..f9728f849
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/client.cc
@@ -0,0 +1,1355 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/client.h"
+
+// Platform-specific defines
+#include "arrow/flight/platform.h"
+
+#include <map>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+
+#ifdef GRPCPP_PP_INCLUDE
+#include <grpcpp/grpcpp.h>
+#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS)
+#include <grpcpp/security/tls_credentials_options.h>
+#endif
+#else
+#include <grpc++/grpc++.h>
+#endif
+
+#include <grpc/grpc_security_constants.h>
+
+#include "arrow/buffer.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/uri.h"
+
+#include "arrow/flight/client_auth.h"
+#include "arrow/flight/client_header_internal.h"
+#include "arrow/flight/client_middleware.h"
+#include "arrow/flight/internal.h"
+#include "arrow/flight/middleware.h"
+#include "arrow/flight/middleware_internal.h"
+#include "arrow/flight/serialization_internal.h"
+#include "arrow/flight/types.h"
+
+namespace arrow {
+
+namespace flight {
+
+namespace pb = arrow::flight::protocol;
+
+const char* kWriteSizeDetailTypeId = "flight::FlightWriteSizeStatusDetail";
+
+FlightCallOptions::FlightCallOptions()
+ : timeout(-1),
+ read_options(ipc::IpcReadOptions::Defaults()),
+ write_options(ipc::IpcWriteOptions::Defaults()) {}
+
+const char* FlightWriteSizeStatusDetail::type_id() const {
+ return kWriteSizeDetailTypeId;
+}
+
+std::string FlightWriteSizeStatusDetail::ToString() const {
+ std::stringstream ss;
+ ss << "IPC payload size (" << actual_ << " bytes) exceeded soft limit (" << limit_
+ << " bytes)";
+ return ss.str();
+}
+
+std::shared_ptr<FlightWriteSizeStatusDetail> FlightWriteSizeStatusDetail::UnwrapStatus(
+ const arrow::Status& status) {
+ if (!status.detail() || status.detail()->type_id() != kWriteSizeDetailTypeId) {
+ return nullptr;
+ }
+ return std::dynamic_pointer_cast<FlightWriteSizeStatusDetail>(status.detail());
+}
+
+FlightClientOptions FlightClientOptions::Defaults() { return FlightClientOptions(); }
+
+Status FlightStreamReader::ReadAll(std::shared_ptr<Table>* table,
+ const StopToken& stop_token) {
+ std::vector<std::shared_ptr<RecordBatch>> batches;
+ RETURN_NOT_OK(ReadAll(&batches, stop_token));
+ ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema());
+ return Table::FromRecordBatches(schema, std::move(batches)).Value(table);
+}
+
+struct ClientRpc {
+ grpc::ClientContext context;
+
+ explicit ClientRpc(const FlightCallOptions& options) {
+ if (options.timeout.count() >= 0) {
+ std::chrono::system_clock::time_point deadline =
+ std::chrono::time_point_cast<std::chrono::system_clock::time_point::duration>(
+ std::chrono::system_clock::now() + options.timeout);
+ context.set_deadline(deadline);
+ }
+ for (auto header : options.headers) {
+ context.AddMetadata(header.first, header.second);
+ }
+ }
+
+ /// \brief Add an auth token via an auth handler
+ Status SetToken(ClientAuthHandler* auth_handler) {
+ if (auth_handler) {
+ std::string token;
+ RETURN_NOT_OK(auth_handler->GetToken(&token));
+ context.AddMetadata(internal::kGrpcAuthHeader, token);
+ }
+ return Status::OK();
+ }
+};
+
+/// Helper that manages Finish() of a gRPC stream.
+///
+/// When we encounter an error (e.g. could not decode an IPC message),
+/// we want to provide both the client-side error context and any
+/// available server-side context. This helper helps wrap up that
+/// logic.
+///
+/// This class protects the stream with a flag (so that Finish is
+/// idempotent), and drains the read side (so that Finish won't hang).
+///
+/// The template lets us abstract between DoGet/DoExchange and DoPut,
+/// which respectively read internal::FlightData and pb::PutResult.
+template <typename Stream, typename ReadT>
+class FinishableStream {
+ public:
+ FinishableStream(std::shared_ptr<ClientRpc> rpc, std::shared_ptr<Stream> stream)
+ : rpc_(rpc), stream_(stream), finished_(false), server_status_() {}
+ virtual ~FinishableStream() = default;
+
+ /// \brief Get the underlying stream.
+ std::shared_ptr<Stream> stream() const { return stream_; }
+
+ /// \brief Finish the call, adding server context to the given status.
+ virtual Status Finish(Status st) {
+ if (finished_) {
+ return MergeStatus(std::move(st));
+ }
+
+ // Drain the read side, as otherwise gRPC Finish() will hang. We
+ // only call Finish() when the client closes the writer or the
+ // reader finishes, so it's OK to assume the client no longer
+ // wants to read and drain the read side. (If the client wants to
+ // indicate that it is done writing, but not done reading, it
+ // should use DoneWriting.
+ ReadT message;
+ while (internal::ReadPayload(stream_.get(), &message)) {
+ // Drain the read side to avoid gRPC hanging in Finish()
+ }
+
+ server_status_ = internal::FromGrpcStatus(stream_->Finish(), &rpc_->context);
+ finished_ = true;
+
+ return MergeStatus(std::move(st));
+ }
+
+ private:
+ Status MergeStatus(Status&& st) {
+ if (server_status_.ok()) {
+ return std::move(st);
+ }
+ return Status::FromDetailAndArgs(
+ server_status_.code(), server_status_.detail(), server_status_.message(),
+ ". Client context: ", st.ToString(),
+ ". gRPC client debug context: ", rpc_->context.debug_error_string());
+ }
+
+ std::shared_ptr<ClientRpc> rpc_;
+ std::shared_ptr<Stream> stream_;
+ bool finished_;
+ Status server_status_;
+};
+
+/// Helper that manages \a Finish() of a read-write gRPC stream.
+///
+/// This also calls \a WritesDone() and protects itself with a mutex
+/// to enable sharing between the reader and writer.
+template <typename Stream, typename ReadT>
+class FinishableWritableStream : public FinishableStream<Stream, ReadT> {
+ public:
+ FinishableWritableStream(std::shared_ptr<ClientRpc> rpc,
+ std::shared_ptr<std::mutex> read_mutex,
+ std::shared_ptr<Stream> stream)
+ : FinishableStream<Stream, ReadT>(rpc, stream),
+ finish_mutex_(),
+ read_mutex_(read_mutex),
+ done_writing_(false) {}
+ virtual ~FinishableWritableStream() = default;
+
+ /// \brief Indicate to gRPC that the write half of the stream is done.
+ Status DoneWriting() {
+ // This is only used by the writer side of a stream, so it need
+ // not be protected with a lock.
+ if (done_writing_) {
+ return Status::OK();
+ }
+ done_writing_ = true;
+ if (!this->stream()->WritesDone()) {
+ // Error happened, try to close the stream to get more detailed info
+ return Finish(MakeFlightError(FlightStatusCode::Internal,
+ "Could not flush pending record batches"));
+ }
+ return Status::OK();
+ }
+
+ Status Finish(Status st) override {
+ // This may be used concurrently by reader/writer side of a
+ // stream, so it needs to be protected.
+ std::lock_guard<std::mutex> guard(finish_mutex_);
+
+ // Now that we're shared between a reader and writer, we need to
+ // protect ourselves from being called while there's an
+ // outstanding read.
+ std::unique_lock<std::mutex> read_guard(*read_mutex_, std::try_to_lock);
+ if (!read_guard.owns_lock()) {
+ return MakeFlightError(
+ FlightStatusCode::Internal,
+ "Cannot close stream with pending read operation. Client context: " +
+ st.ToString());
+ }
+
+ // Try to flush pending writes. Don't use our WritesDone() to
+ // avoid recursion.
+ bool finished_writes = done_writing_ || this->stream()->WritesDone();
+ done_writing_ = true;
+
+ st = FinishableStream<Stream, ReadT>::Finish(std::move(st));
+
+ if (!finished_writes) {
+ return Status::FromDetailAndArgs(
+ st.code(), st.detail(), st.message(),
+ ". Additionally, could not finish writing record batches before closing");
+ }
+ return st;
+ }
+
+ private:
+ std::mutex finish_mutex_;
+ std::shared_ptr<std::mutex> read_mutex_;
+ bool done_writing_;
+};
+
+class GrpcAddCallHeaders : public AddCallHeaders {
+ public:
+ explicit GrpcAddCallHeaders(std::multimap<grpc::string, grpc::string>* metadata)
+ : metadata_(metadata) {}
+ ~GrpcAddCallHeaders() override = default;
+
+ void AddHeader(const std::string& key, const std::string& value) override {
+ metadata_->insert(std::make_pair(key, value));
+ }
+
+ private:
+ std::multimap<grpc::string, grpc::string>* metadata_;
+};
+
+class GrpcClientInterceptorAdapter : public grpc::experimental::Interceptor {
+ public:
+ explicit GrpcClientInterceptorAdapter(
+ std::vector<std::unique_ptr<ClientMiddleware>> middleware)
+ : middleware_(std::move(middleware)), received_headers_(false) {}
+
+ void Intercept(grpc::experimental::InterceptorBatchMethods* methods) {
+ using InterceptionHookPoints = grpc::experimental::InterceptionHookPoints;
+ if (methods->QueryInterceptionHookPoint(
+ InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ GrpcAddCallHeaders add_headers(methods->GetSendInitialMetadata());
+ for (const auto& middleware : middleware_) {
+ middleware->SendingHeaders(&add_headers);
+ }
+ }
+
+ if (methods->QueryInterceptionHookPoint(
+ InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
+ if (!methods->GetRecvInitialMetadata()->empty()) {
+ ReceivedHeaders(*methods->GetRecvInitialMetadata());
+ }
+ }
+
+ if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::POST_RECV_STATUS)) {
+ DCHECK_NE(nullptr, methods->GetRecvStatus());
+ DCHECK_NE(nullptr, methods->GetRecvTrailingMetadata());
+ ReceivedHeaders(*methods->GetRecvTrailingMetadata());
+ const Status status = internal::FromGrpcStatus(*methods->GetRecvStatus());
+ for (const auto& middleware : middleware_) {
+ middleware->CallCompleted(status);
+ }
+ }
+
+ methods->Proceed();
+ }
+
+ private:
+ void ReceivedHeaders(
+ const std::multimap<grpc::string_ref, grpc::string_ref>& metadata) {
+ if (received_headers_) {
+ return;
+ }
+ received_headers_ = true;
+ CallHeaders headers;
+ for (const auto& entry : metadata) {
+ headers.insert({util::string_view(entry.first.data(), entry.first.length()),
+ util::string_view(entry.second.data(), entry.second.length())});
+ }
+ for (const auto& middleware : middleware_) {
+ middleware->ReceivedHeaders(headers);
+ }
+ }
+
+ std::vector<std::unique_ptr<ClientMiddleware>> middleware_;
+ // When communicating with a gRPC-Java server, the server may not
+ // send back headers if the call fails right away. Instead, the
+ // headers will be consolidated into the trailers. We don't want to
+ // call the client middleware callback twice, so instead track
+ // whether we saw headers - if not, then we need to check trailers.
+ bool received_headers_;
+};
+
+class GrpcClientInterceptorAdapterFactory
+ : public grpc::experimental::ClientInterceptorFactoryInterface {
+ public:
+ GrpcClientInterceptorAdapterFactory(
+ std::vector<std::shared_ptr<ClientMiddlewareFactory>> middleware)
+ : middleware_(middleware) {}
+
+ grpc::experimental::Interceptor* CreateClientInterceptor(
+ grpc::experimental::ClientRpcInfo* info) override {
+ std::vector<std::unique_ptr<ClientMiddleware>> middleware;
+
+ FlightMethod flight_method = FlightMethod::Invalid;
+ util::string_view method(info->method());
+ if (method.ends_with("/Handshake")) {
+ flight_method = FlightMethod::Handshake;
+ } else if (method.ends_with("/ListFlights")) {
+ flight_method = FlightMethod::ListFlights;
+ } else if (method.ends_with("/GetFlightInfo")) {
+ flight_method = FlightMethod::GetFlightInfo;
+ } else if (method.ends_with("/GetSchema")) {
+ flight_method = FlightMethod::GetSchema;
+ } else if (method.ends_with("/DoGet")) {
+ flight_method = FlightMethod::DoGet;
+ } else if (method.ends_with("/DoPut")) {
+ flight_method = FlightMethod::DoPut;
+ } else if (method.ends_with("/DoExchange")) {
+ flight_method = FlightMethod::DoExchange;
+ } else if (method.ends_with("/DoAction")) {
+ flight_method = FlightMethod::DoAction;
+ } else if (method.ends_with("/ListActions")) {
+ flight_method = FlightMethod::ListActions;
+ } else {
+ DCHECK(false) << "Unknown Flight method: " << info->method();
+ }
+
+ const CallInfo flight_info{flight_method};
+ for (auto& factory : middleware_) {
+ std::unique_ptr<ClientMiddleware> instance;
+ factory->StartCall(flight_info, &instance);
+ if (instance) {
+ middleware.push_back(std::move(instance));
+ }
+ }
+ return new GrpcClientInterceptorAdapter(std::move(middleware));
+ }
+
+ private:
+ std::vector<std::shared_ptr<ClientMiddlewareFactory>> middleware_;
+};
+
+class GrpcClientAuthSender : public ClientAuthSender {
+ public:
+ explicit GrpcClientAuthSender(
+ std::shared_ptr<
+ grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>>
+ stream)
+ : stream_(stream) {}
+
+ Status Write(const std::string& token) override {
+ pb::HandshakeRequest response;
+ response.set_payload(token);
+ if (stream_->Write(response)) {
+ return Status::OK();
+ }
+ return internal::FromGrpcStatus(stream_->Finish());
+ }
+
+ private:
+ std::shared_ptr<grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>>
+ stream_;
+};
+
+class GrpcClientAuthReader : public ClientAuthReader {
+ public:
+ explicit GrpcClientAuthReader(
+ std::shared_ptr<
+ grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>>
+ stream)
+ : stream_(stream) {}
+
+ Status Read(std::string* token) override {
+ pb::HandshakeResponse request;
+ if (stream_->Read(&request)) {
+ *token = std::move(*request.mutable_payload());
+ return Status::OK();
+ }
+ return internal::FromGrpcStatus(stream_->Finish());
+ }
+
+ private:
+ std::shared_ptr<grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>>
+ stream_;
+};
+
+// An ipc::MessageReader that adapts any readable gRPC stream
+// returning FlightData.
+template <typename Reader>
+class GrpcIpcMessageReader : public ipc::MessageReader {
+ public:
+ GrpcIpcMessageReader(
+ std::shared_ptr<ClientRpc> rpc, std::shared_ptr<std::mutex> read_mutex,
+ std::shared_ptr<FinishableStream<Reader, internal::FlightData>> stream,
+ std::shared_ptr<internal::PeekableFlightDataReader<std::shared_ptr<Reader>>>
+ peekable_reader,
+ std::shared_ptr<Buffer>* app_metadata)
+ : rpc_(rpc),
+ read_mutex_(read_mutex),
+ stream_(std::move(stream)),
+ peekable_reader_(peekable_reader),
+ app_metadata_(app_metadata),
+ stream_finished_(false) {}
+
+ ::arrow::Result<std::unique_ptr<ipc::Message>> ReadNextMessage() override {
+ if (stream_finished_) {
+ return nullptr;
+ }
+ internal::FlightData* data;
+ {
+ auto guard = read_mutex_ ? std::unique_lock<std::mutex>(*read_mutex_)
+ : std::unique_lock<std::mutex>();
+ peekable_reader_->Next(&data);
+ }
+ if (!data) {
+ stream_finished_ = true;
+ return stream_->Finish(Status::OK());
+ }
+ // Validate IPC message
+ auto result = data->OpenMessage();
+ if (!result.ok()) {
+ return stream_->Finish(std::move(result).status());
+ }
+ *app_metadata_ = std::move(data->app_metadata);
+ return result;
+ }
+
+ private:
+ // The RPC context lifetime must be coupled to the ClientReader
+ std::shared_ptr<ClientRpc> rpc_;
+ // Guard reads with a mutex to prevent concurrent reads if the write
+ // side calls Finish(). Nullable as DoGet doesn't need this.
+ std::shared_ptr<std::mutex> read_mutex_;
+ std::shared_ptr<FinishableStream<Reader, internal::FlightData>> stream_;
+ std::shared_ptr<internal::PeekableFlightDataReader<std::shared_ptr<Reader>>>
+ peekable_reader_;
+ // A reference to GrpcStreamReader.app_metadata_. That class
+ // can't access the app metadata because when it Peek()s the stream,
+ // it may be looking at a dictionary batch, not the record
+ // batch. Updating it here ensures the reader is always updated with
+ // the last metadata message read.
+ std::shared_ptr<Buffer>* app_metadata_;
+ bool stream_finished_;
+};
+
+/// The implementation of the public-facing API for reading from a
+/// FlightData stream
+template <typename Reader>
+class GrpcStreamReader : public FlightStreamReader {
+ public:
+ GrpcStreamReader(std::shared_ptr<ClientRpc> rpc, std::shared_ptr<std::mutex> read_mutex,
+ const ipc::IpcReadOptions& options, StopToken stop_token,
+ std::shared_ptr<FinishableStream<Reader, internal::FlightData>> stream)
+ : rpc_(rpc),
+ read_mutex_(read_mutex),
+ options_(options),
+ stop_token_(std::move(stop_token)),
+ stream_(stream),
+ peekable_reader_(new internal::PeekableFlightDataReader<std::shared_ptr<Reader>>(
+ stream->stream())),
+ app_metadata_(nullptr) {}
+
+ Status EnsureDataStarted() {
+ if (!batch_reader_) {
+ bool skipped_to_data = false;
+ {
+ auto guard = TakeGuard();
+ skipped_to_data = peekable_reader_->SkipToData();
+ }
+ // peek() until we find the first data message; discard metadata
+ if (!skipped_to_data) {
+ return OverrideWithServerError(MakeFlightError(
+ FlightStatusCode::Internal, "Server never sent a data message"));
+ }
+
+ auto message_reader =
+ std::unique_ptr<ipc::MessageReader>(new GrpcIpcMessageReader<Reader>(
+ rpc_, read_mutex_, stream_, peekable_reader_, &app_metadata_));
+ auto result =
+ ipc::RecordBatchStreamReader::Open(std::move(message_reader), options_);
+ RETURN_NOT_OK(OverrideWithServerError(std::move(result).Value(&batch_reader_)));
+ }
+ return Status::OK();
+ }
+ arrow::Result<std::shared_ptr<Schema>> GetSchema() override {
+ RETURN_NOT_OK(EnsureDataStarted());
+ return batch_reader_->schema();
+ }
+ Status Next(FlightStreamChunk* out) override {
+ internal::FlightData* data;
+ {
+ auto guard = TakeGuard();
+ peekable_reader_->Peek(&data);
+ }
+ if (!data) {
+ out->app_metadata = nullptr;
+ out->data = nullptr;
+ return stream_->Finish(Status::OK());
+ }
+
+ if (!data->metadata) {
+ // Metadata-only (data->metadata is the IPC header)
+ out->app_metadata = data->app_metadata;
+ out->data = nullptr;
+ {
+ auto guard = TakeGuard();
+ peekable_reader_->Next(&data);
+ }
+ return Status::OK();
+ }
+
+ if (!batch_reader_) {
+ RETURN_NOT_OK(EnsureDataStarted());
+ // Re-peek here since EnsureDataStarted() advances the stream
+ return Next(out);
+ }
+ RETURN_NOT_OK(batch_reader_->ReadNext(&out->data));
+ out->app_metadata = std::move(app_metadata_);
+ return Status::OK();
+ }
+ Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches) override {
+ return ReadAll(batches, stop_token_);
+ }
+ Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches,
+ const StopToken& stop_token) override {
+ FlightStreamChunk chunk;
+
+ while (true) {
+ if (stop_token.IsStopRequested()) {
+ Cancel();
+ return stop_token.Poll();
+ }
+ RETURN_NOT_OK(Next(&chunk));
+ if (!chunk.data) break;
+ batches->emplace_back(std::move(chunk.data));
+ }
+ return Status::OK();
+ }
+ Status ReadAll(std::shared_ptr<Table>* table) override {
+ return ReadAll(table, stop_token_);
+ }
+ using FlightStreamReader::ReadAll;
+ void Cancel() override { rpc_->context.TryCancel(); }
+
+ private:
+ std::unique_lock<std::mutex> TakeGuard() {
+ return read_mutex_ ? std::unique_lock<std::mutex>(*read_mutex_)
+ : std::unique_lock<std::mutex>();
+ }
+
+ Status OverrideWithServerError(Status&& st) {
+ if (st.ok()) {
+ return std::move(st);
+ }
+ return stream_->Finish(std::move(st));
+ }
+
+ friend class GrpcIpcMessageReader<Reader>;
+ std::shared_ptr<ClientRpc> rpc_;
+ // Guard reads with a lock to prevent Finish()/Close() from being
+ // called on the writer while the reader has a pending
+ // read. Nullable, as DoGet() doesn't need this.
+ std::shared_ptr<std::mutex> read_mutex_;
+ ipc::IpcReadOptions options_;
+ StopToken stop_token_;
+ std::shared_ptr<FinishableStream<Reader, internal::FlightData>> stream_;
+ std::shared_ptr<internal::PeekableFlightDataReader<std::shared_ptr<Reader>>>
+ peekable_reader_;
+ std::shared_ptr<ipc::RecordBatchReader> batch_reader_;
+ std::shared_ptr<Buffer> app_metadata_;
+};
+
+// The next two classes implement writing to a FlightData stream.
+// Similarly to the read side, we want to reuse the implementation of
+// RecordBatchWriter. As a result, these two classes are intertwined
+// in order to pass application metadata "through" RecordBatchWriter.
+// In order to get application-specific metadata to the
+// IpcPayloadWriter, DoPutPayloadWriter takes a pointer to
+// GrpcStreamWriter. GrpcStreamWriter updates a metadata field on
+// write; DoPutPayloadWriter reads that metadata field to determine
+// what to write.
+
+template <typename ProtoReadT, typename FlightReadT>
+class DoPutPayloadWriter;
+
+template <typename ProtoReadT, typename FlightReadT>
+class GrpcStreamWriter : public FlightStreamWriter {
+ public:
+ ~GrpcStreamWriter() override = default;
+
+ using GrpcStream = grpc::ClientReaderWriter<pb::FlightData, ProtoReadT>;
+
+ explicit GrpcStreamWriter(
+ const FlightDescriptor& descriptor, std::shared_ptr<ClientRpc> rpc,
+ int64_t write_size_limit_bytes, const ipc::IpcWriteOptions& options,
+ std::shared_ptr<FinishableWritableStream<GrpcStream, FlightReadT>> writer)
+ : app_metadata_(nullptr),
+ batch_writer_(nullptr),
+ writer_(std::move(writer)),
+ rpc_(std::move(rpc)),
+ write_size_limit_bytes_(write_size_limit_bytes),
+ options_(options),
+ descriptor_(descriptor),
+ writer_closed_(false) {}
+
+ static Status Open(
+ const FlightDescriptor& descriptor, std::shared_ptr<Schema> schema,
+ const ipc::IpcWriteOptions& options, std::shared_ptr<ClientRpc> rpc,
+ int64_t write_size_limit_bytes,
+ std::shared_ptr<FinishableWritableStream<GrpcStream, FlightReadT>> writer,
+ std::unique_ptr<FlightStreamWriter>* out);
+
+ Status CheckStarted() {
+ if (!batch_writer_) {
+ return Status::Invalid("Writer not initialized. Call Begin() with a schema.");
+ }
+ return Status::OK();
+ }
+
+ Status Begin(const std::shared_ptr<Schema>& schema,
+ const ipc::IpcWriteOptions& options) override {
+ if (batch_writer_) {
+ return Status::Invalid("This writer has already been started.");
+ }
+ std::unique_ptr<ipc::internal::IpcPayloadWriter> payload_writer(
+ new DoPutPayloadWriter<ProtoReadT, FlightReadT>(
+ descriptor_, std::move(rpc_), write_size_limit_bytes_, writer_, this));
+ // XXX: this does not actually write the message to the stream.
+ // See Close().
+ ARROW_ASSIGN_OR_RAISE(batch_writer_, ipc::internal::OpenRecordBatchWriter(
+ std::move(payload_writer), schema, options));
+ return Status::OK();
+ }
+
+ Status Begin(const std::shared_ptr<Schema>& schema) override {
+ return Begin(schema, options_);
+ }
+
+ Status WriteRecordBatch(const RecordBatch& batch) override {
+ RETURN_NOT_OK(CheckStarted());
+ return WriteWithMetadata(batch, nullptr);
+ }
+
+ Status WriteMetadata(std::shared_ptr<Buffer> app_metadata) override {
+ FlightPayload payload{};
+ payload.app_metadata = app_metadata;
+ auto status = internal::WritePayload(payload, writer_->stream().get());
+ if (status.IsIOError()) {
+ return writer_->Finish(MakeFlightError(FlightStatusCode::Internal,
+ "Could not write metadata to stream"));
+ }
+ return status;
+ }
+
+ Status WriteWithMetadata(const RecordBatch& batch,
+ std::shared_ptr<Buffer> app_metadata) override {
+ RETURN_NOT_OK(CheckStarted());
+ app_metadata_ = app_metadata;
+ return batch_writer_->WriteRecordBatch(batch);
+ }
+
+ Status DoneWriting() override {
+ // Do not CheckStarted - DoneWriting applies to data and metadata
+ if (batch_writer_) {
+ // Close the writer if we have one; this will force it to flush any
+ // remaining data, before we close the write side of the stream.
+ writer_closed_ = true;
+ Status st = batch_writer_->Close();
+ if (!st.ok()) {
+ return writer_->Finish(std::move(st));
+ }
+ }
+ return writer_->DoneWriting();
+ }
+
+ Status Close() override {
+ // Do not CheckStarted - Close applies to data and metadata
+ if (batch_writer_ && !writer_closed_) {
+ // This is important! Close() calls
+ // IpcPayloadWriter::CheckStarted() which will force the initial
+ // schema message to be written to the stream. This is required
+ // to unstick the server, else the client and the server end up
+ // waiting for each other. This happens if the client never
+ // wrote anything before calling Close().
+ writer_closed_ = true;
+ return writer_->Finish(batch_writer_->Close());
+ }
+ return writer_->Finish(Status::OK());
+ }
+
+ ipc::WriteStats stats() const override {
+ ARROW_CHECK_NE(batch_writer_, nullptr);
+ return batch_writer_->stats();
+ }
+
+ private:
+ friend class DoPutPayloadWriter<ProtoReadT, FlightReadT>;
+ std::shared_ptr<Buffer> app_metadata_;
+ std::unique_ptr<ipc::RecordBatchWriter> batch_writer_;
+ std::shared_ptr<FinishableWritableStream<GrpcStream, FlightReadT>> writer_;
+
+ // Fields used to lazy-initialize the IpcPayloadWriter. They're
+ // invalid once Begin() is called.
+ std::shared_ptr<ClientRpc> rpc_;
+ int64_t write_size_limit_bytes_;
+ ipc::IpcWriteOptions options_;
+ FlightDescriptor descriptor_;
+ bool writer_closed_;
+};
+
+/// A IpcPayloadWriter implementation that writes to a gRPC stream of
+/// FlightData messages.
+template <typename ProtoReadT, typename FlightReadT>
+class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
+ public:
+ using GrpcStream = grpc::ClientReaderWriter<pb::FlightData, ProtoReadT>;
+
+ DoPutPayloadWriter(
+ const FlightDescriptor& descriptor, std::shared_ptr<ClientRpc> rpc,
+ int64_t write_size_limit_bytes,
+ std::shared_ptr<FinishableWritableStream<GrpcStream, FlightReadT>> writer,
+ GrpcStreamWriter<ProtoReadT, FlightReadT>* stream_writer)
+ : descriptor_(descriptor),
+ rpc_(rpc),
+ write_size_limit_bytes_(write_size_limit_bytes),
+ writer_(std::move(writer)),
+ first_payload_(true),
+ stream_writer_(stream_writer) {}
+
+ ~DoPutPayloadWriter() override = default;
+
+ Status Start() override { return Status::OK(); }
+
+ Status WritePayload(const ipc::IpcPayload& ipc_payload) override {
+ FlightPayload payload;
+ payload.ipc_message = ipc_payload;
+
+ if (first_payload_) {
+ // First Flight message needs to encore the Flight descriptor
+ if (ipc_payload.type != ipc::MessageType::SCHEMA) {
+ return Status::Invalid("First IPC message should be schema");
+ }
+ // Write the descriptor to begin with
+ RETURN_NOT_OK(internal::ToPayload(descriptor_, &payload.descriptor));
+ first_payload_ = false;
+ } else if (ipc_payload.type == ipc::MessageType::RECORD_BATCH &&
+ stream_writer_->app_metadata_) {
+ payload.app_metadata = std::move(stream_writer_->app_metadata_);
+ }
+
+ if (write_size_limit_bytes_ > 0) {
+ // Check if the total size is greater than the user-configured
+ // soft-limit.
+ int64_t size = ipc_payload.body_length + ipc_payload.metadata->size();
+ if (payload.descriptor) {
+ size += payload.descriptor->size();
+ }
+ if (payload.app_metadata) {
+ size += payload.app_metadata->size();
+ }
+ if (size > write_size_limit_bytes_) {
+ return arrow::Status(
+ arrow::StatusCode::Invalid, "IPC payload size exceeded soft limit",
+ std::make_shared<FlightWriteSizeStatusDetail>(write_size_limit_bytes_, size));
+ }
+ }
+
+ auto status = internal::WritePayload(payload, writer_->stream().get());
+ if (status.IsIOError()) {
+ return writer_->Finish(MakeFlightError(FlightStatusCode::Internal,
+ "Could not write record batch to stream"));
+ }
+ return status;
+ }
+
+ Status Close() override {
+ // Closing is handled one layer up in GrpcStreamWriter::Close
+ return Status::OK();
+ }
+
+ protected:
+ const FlightDescriptor descriptor_;
+ std::shared_ptr<ClientRpc> rpc_;
+ int64_t write_size_limit_bytes_;
+ std::shared_ptr<FinishableWritableStream<GrpcStream, FlightReadT>> writer_;
+ bool first_payload_;
+ GrpcStreamWriter<ProtoReadT, FlightReadT>* stream_writer_;
+};
+
+template <typename ProtoReadT, typename FlightReadT>
+Status GrpcStreamWriter<ProtoReadT, FlightReadT>::Open(
+ const FlightDescriptor& descriptor,
+ std::shared_ptr<Schema> schema, // this schema is nullable
+ const ipc::IpcWriteOptions& options, std::shared_ptr<ClientRpc> rpc,
+ int64_t write_size_limit_bytes,
+ std::shared_ptr<FinishableWritableStream<GrpcStream, FlightReadT>> writer,
+ std::unique_ptr<FlightStreamWriter>* out) {
+ std::unique_ptr<GrpcStreamWriter<ProtoReadT, FlightReadT>> instance(
+ new GrpcStreamWriter<ProtoReadT, FlightReadT>(
+ descriptor, std::move(rpc), write_size_limit_bytes, options, writer));
+ if (schema) {
+ // The schema was provided (DoPut). Eagerly write the schema and
+ // descriptor together as the first message.
+ RETURN_NOT_OK(instance->Begin(schema, options));
+ } else {
+ // The schema was not provided (DoExchange). Eagerly write just
+ // the descriptor as the first message. Note that if the client
+ // calls Begin() to send data, we'll send a redundant descriptor.
+ FlightPayload payload{};
+ RETURN_NOT_OK(internal::ToPayload(descriptor, &payload.descriptor));
+ auto status = internal::WritePayload(payload, instance->writer_->stream().get());
+ if (status.IsIOError()) {
+ return writer->Finish(MakeFlightError(FlightStatusCode::Internal,
+ "Could not write descriptor to stream"));
+ }
+ RETURN_NOT_OK(status);
+ }
+ *out = std::move(instance);
+ return Status::OK();
+}
+
+FlightMetadataReader::~FlightMetadataReader() = default;
+
+class GrpcMetadataReader : public FlightMetadataReader {
+ public:
+ explicit GrpcMetadataReader(
+ std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> reader,
+ std::shared_ptr<std::mutex> read_mutex)
+ : reader_(reader), read_mutex_(read_mutex) {}
+
+ Status ReadMetadata(std::shared_ptr<Buffer>* out) override {
+ std::lock_guard<std::mutex> guard(*read_mutex_);
+ pb::PutResult message;
+ if (reader_->Read(&message)) {
+ *out = Buffer::FromString(std::move(*message.mutable_app_metadata()));
+ } else {
+ // Stream finished
+ *out = nullptr;
+ }
+ return Status::OK();
+ }
+
+ private:
+ std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> reader_;
+ std::shared_ptr<std::mutex> read_mutex_;
+};
+
+namespace {
+// Dummy self-signed certificate to be used because TlsCredentials
+// requires root CA certs, even if you are skipping server
+// verification.
+#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS)
+constexpr char kDummyRootCert[] =
+ "-----BEGIN CERTIFICATE-----\n"
+ "MIICwzCCAaugAwIBAgIJAM12DOkcaqrhMA0GCSqGSIb3DQEBBQUAMBQxEjAQBgNV\n"
+ "BAMTCWxvY2FsaG9zdDAeFw0yMDEwMDcwODIyNDFaFw0zMDEwMDUwODIyNDFaMBQx\n"
+ "EjAQBgNVBAMTCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC\n"
+ "ggEBALjJ8KPEpF0P4GjMPrJhjIBHUL0AX9E4oWdgJRCSFkPKKEWzQabTQBikMOhI\n"
+ "W4VvBMaHEBuECE5OEyrzDRiAO354I4F4JbBfxMOY8NIW0uWD6THWm2KkCzoZRIPW\n"
+ "yZL6dN+mK6cEH+YvbNuy5ZQGNjGG43tyiXdOCAc4AI9POeTtjdMpbbpR2VY4Ad/E\n"
+ "oTEiS3gNnN7WIAdgMhCJxjzvPwKszV3f7pwuTHzFMsuHLKr6JeaVUYfbi4DxxC8Z\n"
+ "k6PF6dLlLf3ngTSLBJyaXP1BhKMvz0TaMK3F0y2OGwHM9J8np2zWjTlNVEzffQZx\n"
+ "SWMOQManlJGs60xYx9KCPJMZZsMCAwEAAaMYMBYwFAYDVR0RBA0wC4IJbG9jYWxo\n"
+ "b3N0MA0GCSqGSIb3DQEBBQUAA4IBAQC0LrmbcNKgO+D50d/wOc+vhi9K04EZh8bg\n"
+ "WYAK1kLOT4eShbzqWGV/1EggY4muQ6ypSELCLuSsg88kVtFQIeRilA6bHFqQSj6t\n"
+ "sqgh2cWsMwyllCtmX6Maf3CLb2ZdoJlqUwdiBdrbIbuyeAZj3QweCtLKGSQzGDyI\n"
+ "KH7G8nC5d0IoRPiCMB6RnMMKsrhviuCdWbAFHop7Ff36JaOJ8iRa2sSf2OXE8j/5\n"
+ "obCXCUvYHf4Zw27JcM2AnnQI9VJLnYxis83TysC5s2Z7t0OYNS9kFmtXQbUNlmpS\n"
+ "doQ/Eu47vWX7S0TXeGziGtbAOKxbHE0BGGPDOAB/jGW/JVbeTiXY\n"
+ "-----END CERTIFICATE-----\n";
+#endif
+} // namespace
+class FlightClient::FlightClientImpl {
+ public:
+ Status Connect(const Location& location, const FlightClientOptions& options) {
+ const std::string& scheme = location.scheme();
+
+ std::stringstream grpc_uri;
+ std::shared_ptr<grpc::ChannelCredentials> creds;
+ if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) {
+ grpc_uri << arrow::internal::UriEncodeHost(location.uri_->host()) << ':'
+ << location.uri_->port_text();
+
+ if (scheme == kSchemeGrpcTls) {
+ if (options.disable_server_verification) {
+#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS)
+ namespace ge = GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS;
+
+ // A callback to supply to TlsCredentialsOptions that accepts any server
+ // arguments.
+ struct NoOpTlsAuthorizationCheck
+ : public ge::TlsServerAuthorizationCheckInterface {
+ int Schedule(ge::TlsServerAuthorizationCheckArg* arg) override {
+ arg->set_success(1);
+ arg->set_status(GRPC_STATUS_OK);
+ return 0;
+ }
+ };
+ auto server_authorization_check = std::make_shared<NoOpTlsAuthorizationCheck>();
+ noop_auth_check_ = std::make_shared<ge::TlsServerAuthorizationCheckConfig>(
+ server_authorization_check);
+#if defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS)
+ auto certificate_provider =
+ std::make_shared<grpc::experimental::StaticDataCertificateProvider>(
+ kDummyRootCert);
+#if defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS_ROOT_CERTS)
+ grpc::experimental::TlsChannelCredentialsOptions tls_options(
+ certificate_provider);
+#else
+ // While gRPC >= 1.36 does not require a root cert (it has a default)
+ // in practice the path it hardcodes is broken. See grpc/grpc#21655.
+ grpc::experimental::TlsChannelCredentialsOptions tls_options;
+ tls_options.set_certificate_provider(certificate_provider);
+#endif
+ tls_options.watch_root_certs();
+ tls_options.set_root_cert_name("dummy");
+ tls_options.set_server_verification_option(
+ grpc_tls_server_verification_option::GRPC_TLS_SKIP_ALL_SERVER_VERIFICATION);
+ tls_options.set_server_authorization_check_config(noop_auth_check_);
+#elif defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS)
+ auto materials_config = std::make_shared<ge::TlsKeyMaterialsConfig>();
+ materials_config->set_pem_root_certs(kDummyRootCert);
+ ge::TlsCredentialsOptions tls_options(
+ GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE,
+ GRPC_TLS_SKIP_ALL_SERVER_VERIFICATION, materials_config,
+ std::shared_ptr<ge::TlsCredentialReloadConfig>(), noop_auth_check_);
+#endif
+ creds = ge::TlsCredentials(tls_options);
+#else
+ return Status::NotImplemented(
+ "Using encryption with server verification disabled is unsupported. "
+ "Please use a release of Arrow Flight built with gRPC 1.27 or higher.");
+#endif
+ } else {
+ grpc::SslCredentialsOptions ssl_options;
+ if (!options.tls_root_certs.empty()) {
+ ssl_options.pem_root_certs = options.tls_root_certs;
+ }
+ if (!options.cert_chain.empty()) {
+ ssl_options.pem_cert_chain = options.cert_chain;
+ }
+ if (!options.private_key.empty()) {
+ ssl_options.pem_private_key = options.private_key;
+ }
+ creds = grpc::SslCredentials(ssl_options);
+ }
+ } else {
+ creds = grpc::InsecureChannelCredentials();
+ }
+ } else if (scheme == kSchemeGrpcUnix) {
+ grpc_uri << "unix://" << location.uri_->path();
+ creds = grpc::InsecureChannelCredentials();
+ } else {
+ return Status::NotImplemented("Flight scheme " + scheme + " is not supported.");
+ }
+
+ grpc::ChannelArguments args;
+ // We can't set the same config value twice, so for values where
+ // we want to set defaults, keep them in a map and update them;
+ // then update them all at once
+ std::unordered_map<std::string, int> default_args;
+ // Try to reconnect quickly at first, in case the server is still starting up
+ default_args[GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS] = 100;
+ // Receive messages of any size
+ default_args[GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH] = -1;
+ // Setting this arg enables each client to open it's own TCP connection to server,
+ // not sharing one single connection, which becomes bottleneck under high load.
+ default_args[GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL] = 1;
+
+ if (options.override_hostname != "") {
+ args.SetSslTargetNameOverride(options.override_hostname);
+ }
+
+ // Allow setting generic gRPC options.
+ for (const auto& arg : options.generic_options) {
+ if (util::holds_alternative<int>(arg.second)) {
+ default_args[arg.first] = util::get<int>(arg.second);
+ } else if (util::holds_alternative<std::string>(arg.second)) {
+ args.SetString(arg.first, util::get<std::string>(arg.second));
+ }
+ // Otherwise unimplemented
+ }
+ for (const auto& pair : default_args) {
+ args.SetInt(pair.first, pair.second);
+ }
+
+ std::vector<std::unique_ptr<grpc::experimental::ClientInterceptorFactoryInterface>>
+ interceptors;
+ interceptors.emplace_back(
+ new GrpcClientInterceptorAdapterFactory(std::move(options.middleware)));
+
+ stub_ = pb::FlightService::NewStub(
+ grpc::experimental::CreateCustomChannelWithInterceptors(
+ grpc_uri.str(), creds, args, std::move(interceptors)));
+
+ write_size_limit_bytes_ = options.write_size_limit_bytes;
+ return Status::OK();
+ }
+
+ Status Authenticate(const FlightCallOptions& options,
+ std::unique_ptr<ClientAuthHandler> auth_handler) {
+ auth_handler_ = std::move(auth_handler);
+ ClientRpc rpc(options);
+ std::shared_ptr<grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>>
+ stream = stub_->Handshake(&rpc.context);
+ GrpcClientAuthSender outgoing{stream};
+ GrpcClientAuthReader incoming{stream};
+ RETURN_NOT_OK(auth_handler_->Authenticate(&outgoing, &incoming));
+ // Explicitly close our side of the connection
+ bool finished_writes = stream->WritesDone();
+ RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context));
+ if (!finished_writes) {
+ return MakeFlightError(FlightStatusCode::Internal,
+ "Could not finish writing before closing");
+ }
+ return Status::OK();
+ }
+
+ arrow::Result<std::pair<std::string, std::string>> AuthenticateBasicToken(
+ const FlightCallOptions& options, const std::string& username,
+ const std::string& password) {
+ // Add basic auth headers to outgoing headers.
+ ClientRpc rpc(options);
+ internal::AddBasicAuthHeaders(&rpc.context, username, password);
+
+ std::shared_ptr<grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>>
+ stream = stub_->Handshake(&rpc.context);
+ GrpcClientAuthSender outgoing{stream};
+ GrpcClientAuthReader incoming{stream};
+
+ // Explicitly close our side of the connection.
+ bool finished_writes = stream->WritesDone();
+ RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context));
+ if (!finished_writes) {
+ return MakeFlightError(FlightStatusCode::Internal,
+ "Could not finish writing before closing");
+ }
+
+ // Grab bearer token from incoming headers.
+ return internal::GetBearerTokenHeader(rpc.context);
+ }
+
+ Status ListFlights(const FlightCallOptions& options, const Criteria& criteria,
+ std::unique_ptr<FlightListing>* listing) {
+ pb::Criteria pb_criteria;
+ RETURN_NOT_OK(internal::ToProto(criteria, &pb_criteria));
+
+ ClientRpc rpc(options);
+ RETURN_NOT_OK(rpc.SetToken(auth_handler_.get()));
+ std::unique_ptr<grpc::ClientReader<pb::FlightInfo>> stream(
+ stub_->ListFlights(&rpc.context, pb_criteria));
+
+ std::vector<FlightInfo> flights;
+
+ pb::FlightInfo pb_info;
+ while (!options.stop_token.IsStopRequested() && stream->Read(&pb_info)) {
+ FlightInfo::Data info_data;
+ RETURN_NOT_OK(internal::FromProto(pb_info, &info_data));
+ flights.emplace_back(std::move(info_data));
+ }
+ if (options.stop_token.IsStopRequested()) rpc.context.TryCancel();
+ RETURN_NOT_OK(options.stop_token.Poll());
+ listing->reset(new SimpleFlightListing(std::move(flights)));
+ return internal::FromGrpcStatus(stream->Finish(), &rpc.context);
+ }
+
+ Status DoAction(const FlightCallOptions& options, const Action& action,
+ std::unique_ptr<ResultStream>* results) {
+ pb::Action pb_action;
+ RETURN_NOT_OK(internal::ToProto(action, &pb_action));
+
+ ClientRpc rpc(options);
+ RETURN_NOT_OK(rpc.SetToken(auth_handler_.get()));
+ std::unique_ptr<grpc::ClientReader<pb::Result>> stream(
+ stub_->DoAction(&rpc.context, pb_action));
+
+ pb::Result pb_result;
+
+ std::vector<Result> materialized_results;
+ while (!options.stop_token.IsStopRequested() && stream->Read(&pb_result)) {
+ Result result;
+ RETURN_NOT_OK(internal::FromProto(pb_result, &result));
+ materialized_results.emplace_back(std::move(result));
+ }
+ if (options.stop_token.IsStopRequested()) rpc.context.TryCancel();
+ RETURN_NOT_OK(options.stop_token.Poll());
+
+ *results = std::unique_ptr<ResultStream>(
+ new SimpleResultStream(std::move(materialized_results)));
+ return internal::FromGrpcStatus(stream->Finish(), &rpc.context);
+ }
+
+ Status ListActions(const FlightCallOptions& options, std::vector<ActionType>* types) {
+ pb::Empty empty;
+
+ ClientRpc rpc(options);
+ RETURN_NOT_OK(rpc.SetToken(auth_handler_.get()));
+ std::unique_ptr<grpc::ClientReader<pb::ActionType>> stream(
+ stub_->ListActions(&rpc.context, empty));
+
+ pb::ActionType pb_type;
+ ActionType type;
+ while (!options.stop_token.IsStopRequested() && stream->Read(&pb_type)) {
+ RETURN_NOT_OK(internal::FromProto(pb_type, &type));
+ types->emplace_back(std::move(type));
+ }
+ if (options.stop_token.IsStopRequested()) rpc.context.TryCancel();
+ RETURN_NOT_OK(options.stop_token.Poll());
+ return internal::FromGrpcStatus(stream->Finish(), &rpc.context);
+ }
+
+ Status GetFlightInfo(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
+ std::unique_ptr<FlightInfo>* info) {
+ pb::FlightDescriptor pb_descriptor;
+ pb::FlightInfo pb_response;
+
+ RETURN_NOT_OK(internal::ToProto(descriptor, &pb_descriptor));
+
+ ClientRpc rpc(options);
+ RETURN_NOT_OK(rpc.SetToken(auth_handler_.get()));
+ Status s = internal::FromGrpcStatus(
+ stub_->GetFlightInfo(&rpc.context, pb_descriptor, &pb_response), &rpc.context);
+ RETURN_NOT_OK(s);
+
+ FlightInfo::Data info_data;
+ RETURN_NOT_OK(internal::FromProto(pb_response, &info_data));
+ info->reset(new FlightInfo(std::move(info_data)));
+ return Status::OK();
+ }
+
+ Status GetSchema(const FlightCallOptions& options, const FlightDescriptor& descriptor,
+ std::unique_ptr<SchemaResult>* schema_result) {
+ pb::FlightDescriptor pb_descriptor;
+ pb::SchemaResult pb_response;
+
+ RETURN_NOT_OK(internal::ToProto(descriptor, &pb_descriptor));
+
+ ClientRpc rpc(options);
+ RETURN_NOT_OK(rpc.SetToken(auth_handler_.get()));
+ Status s = internal::FromGrpcStatus(
+ stub_->GetSchema(&rpc.context, pb_descriptor, &pb_response), &rpc.context);
+ RETURN_NOT_OK(s);
+
+ std::string str;
+ RETURN_NOT_OK(internal::FromProto(pb_response, &str));
+ schema_result->reset(new SchemaResult(str));
+ return Status::OK();
+ }
+
+ Status DoGet(const FlightCallOptions& options, const Ticket& ticket,
+ std::unique_ptr<FlightStreamReader>* out) {
+ using StreamReader = GrpcStreamReader<grpc::ClientReader<pb::FlightData>>;
+ pb::Ticket pb_ticket;
+ internal::ToProto(ticket, &pb_ticket);
+
+ auto rpc = std::make_shared<ClientRpc>(options);
+ RETURN_NOT_OK(rpc->SetToken(auth_handler_.get()));
+ std::shared_ptr<grpc::ClientReader<pb::FlightData>> stream =
+ stub_->DoGet(&rpc->context, pb_ticket);
+ auto finishable_stream = std::make_shared<
+ FinishableStream<grpc::ClientReader<pb::FlightData>, internal::FlightData>>(
+ rpc, stream);
+ *out = std::unique_ptr<StreamReader>(new StreamReader(
+ rpc, nullptr, options.read_options, options.stop_token, finishable_stream));
+ // Eagerly read the schema
+ return static_cast<StreamReader*>(out->get())->EnsureDataStarted();
+ }
+
+ Status DoPut(const FlightCallOptions& options, const FlightDescriptor& descriptor,
+ const std::shared_ptr<Schema>& schema,
+ std::unique_ptr<FlightStreamWriter>* out,
+ std::unique_ptr<FlightMetadataReader>* reader) {
+ using GrpcStream = grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>;
+ using StreamWriter = GrpcStreamWriter<pb::PutResult, pb::PutResult>;
+
+ auto rpc = std::make_shared<ClientRpc>(options);
+ RETURN_NOT_OK(rpc->SetToken(auth_handler_.get()));
+ std::shared_ptr<GrpcStream> stream = stub_->DoPut(&rpc->context);
+ // The writer drains the reader on close to avoid hanging inside
+ // gRPC. Concurrent reads are unsafe, so a mutex protects this operation.
+ std::shared_ptr<std::mutex> read_mutex = std::make_shared<std::mutex>();
+ auto finishable_stream =
+ std::make_shared<FinishableWritableStream<GrpcStream, pb::PutResult>>(
+ rpc, read_mutex, stream);
+ *reader =
+ std::unique_ptr<FlightMetadataReader>(new GrpcMetadataReader(stream, read_mutex));
+ return StreamWriter::Open(descriptor, schema, options.write_options, rpc,
+ write_size_limit_bytes_, finishable_stream, out);
+ }
+
+ Status DoExchange(const FlightCallOptions& options, const FlightDescriptor& descriptor,
+ std::unique_ptr<FlightStreamWriter>* writer,
+ std::unique_ptr<FlightStreamReader>* reader) {
+ using GrpcStream = grpc::ClientReaderWriter<pb::FlightData, pb::FlightData>;
+ using StreamReader = GrpcStreamReader<GrpcStream>;
+ using StreamWriter = GrpcStreamWriter<pb::FlightData, internal::FlightData>;
+
+ auto rpc = std::make_shared<ClientRpc>(options);
+ RETURN_NOT_OK(rpc->SetToken(auth_handler_.get()));
+ std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::FlightData>> stream =
+ stub_->DoExchange(&rpc->context);
+ // The writer drains the reader on close to avoid hanging inside
+ // gRPC. Concurrent reads are unsafe, so a mutex protects this operation.
+ std::shared_ptr<std::mutex> read_mutex = std::make_shared<std::mutex>();
+ auto finishable_stream =
+ std::make_shared<FinishableWritableStream<GrpcStream, internal::FlightData>>(
+ rpc, read_mutex, stream);
+ *reader = std::unique_ptr<StreamReader>(new StreamReader(
+ rpc, read_mutex, options.read_options, options.stop_token, finishable_stream));
+ // Do not eagerly read the schema. There may be metadata messages
+ // before any data is sent, or data may not be sent at all.
+ return StreamWriter::Open(descriptor, nullptr, options.write_options, rpc,
+ write_size_limit_bytes_, finishable_stream, writer);
+ }
+
+ private:
+ std::unique_ptr<pb::FlightService::Stub> stub_;
+ std::shared_ptr<ClientAuthHandler> auth_handler_;
+#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS)
+ // Scope the TlsServerAuthorizationCheckConfig to be at the class instance level, since
+ // it gets created during Connect() and needs to persist to DoAction() calls. gRPC does
+ // not correctly increase the reference count of this object:
+ // https://github.com/grpc/grpc/issues/22287
+ std::shared_ptr<
+ GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS::TlsServerAuthorizationCheckConfig>
+ noop_auth_check_;
+#endif
+ int64_t write_size_limit_bytes_;
+};
+
+FlightClient::FlightClient() { impl_.reset(new FlightClientImpl); }
+
+FlightClient::~FlightClient() {}
+
+Status FlightClient::Connect(const Location& location,
+ std::unique_ptr<FlightClient>* client) {
+ return Connect(location, FlightClientOptions::Defaults(), client);
+}
+
+Status FlightClient::Connect(const Location& location, const FlightClientOptions& options,
+ std::unique_ptr<FlightClient>* client) {
+ client->reset(new FlightClient);
+ return (*client)->impl_->Connect(location, options);
+}
+
+Status FlightClient::Authenticate(const FlightCallOptions& options,
+ std::unique_ptr<ClientAuthHandler> auth_handler) {
+ return impl_->Authenticate(options, std::move(auth_handler));
+}
+
+arrow::Result<std::pair<std::string, std::string>> FlightClient::AuthenticateBasicToken(
+ const FlightCallOptions& options, const std::string& username,
+ const std::string& password) {
+ return impl_->AuthenticateBasicToken(options, username, password);
+}
+
+Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action,
+ std::unique_ptr<ResultStream>* results) {
+ return impl_->DoAction(options, action, results);
+}
+
+Status FlightClient::ListActions(const FlightCallOptions& options,
+ std::vector<ActionType>* actions) {
+ return impl_->ListActions(options, actions);
+}
+
+Status FlightClient::GetFlightInfo(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
+ std::unique_ptr<FlightInfo>* info) {
+ return impl_->GetFlightInfo(options, descriptor, info);
+}
+
+Status FlightClient::GetSchema(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
+ std::unique_ptr<SchemaResult>* schema_result) {
+ return impl_->GetSchema(options, descriptor, schema_result);
+}
+
+Status FlightClient::ListFlights(std::unique_ptr<FlightListing>* listing) {
+ return ListFlights({}, {}, listing);
+}
+
+Status FlightClient::ListFlights(const FlightCallOptions& options,
+ const Criteria& criteria,
+ std::unique_ptr<FlightListing>* listing) {
+ return impl_->ListFlights(options, criteria, listing);
+}
+
+Status FlightClient::DoGet(const FlightCallOptions& options, const Ticket& ticket,
+ std::unique_ptr<FlightStreamReader>* stream) {
+ return impl_->DoGet(options, ticket, stream);
+}
+
+Status FlightClient::DoPut(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
+ const std::shared_ptr<Schema>& schema,
+ std::unique_ptr<FlightStreamWriter>* stream,
+ std::unique_ptr<FlightMetadataReader>* reader) {
+ return impl_->DoPut(options, descriptor, schema, stream, reader);
+}
+
+Status FlightClient::DoExchange(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
+ std::unique_ptr<FlightStreamWriter>* writer,
+ std::unique_ptr<FlightStreamReader>* reader) {
+ return impl_->DoExchange(options, descriptor, writer, reader);
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/client.h b/src/arrow/cpp/src/arrow/flight/client.h
new file mode 100644
index 000000000..0a35b6d10
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/client.h
@@ -0,0 +1,330 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// \brief Implementation of Flight RPC client using gRPC. API should be
+// considered experimental for now
+
+#pragma once
+
+#include <chrono>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/ipc/options.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/cancel.h"
+#include "arrow/util/variant.h"
+
+#include "arrow/flight/types.h" // IWYU pragma: keep
+#include "arrow/flight/visibility.h"
+
+namespace arrow {
+
+class RecordBatch;
+class Schema;
+
+namespace flight {
+
+class ClientAuthHandler;
+class ClientMiddleware;
+class ClientMiddlewareFactory;
+
+/// \brief A duration type for Flight call timeouts.
+typedef std::chrono::duration<double, std::chrono::seconds::period> TimeoutDuration;
+
+/// \brief Hints to the underlying RPC layer for Arrow Flight calls.
+class ARROW_FLIGHT_EXPORT FlightCallOptions {
+ public:
+ /// Create a default set of call options.
+ FlightCallOptions();
+
+ /// \brief An optional timeout for this call. Negative durations
+ /// mean an implementation-defined default behavior will be used
+ /// instead. This is the default value.
+ TimeoutDuration timeout;
+
+ /// \brief IPC reader options, if applicable for the call.
+ ipc::IpcReadOptions read_options;
+
+ /// \brief IPC writer options, if applicable for the call.
+ ipc::IpcWriteOptions write_options;
+
+ /// \brief Headers for client to add to context.
+ std::vector<std::pair<std::string, std::string>> headers;
+
+ /// \brief A token to enable interactive user cancellation of long-running requests.
+ StopToken stop_token;
+};
+
+/// \brief Indicate that the client attempted to write a message
+/// larger than the soft limit set via write_size_limit_bytes.
+class ARROW_FLIGHT_EXPORT FlightWriteSizeStatusDetail : public arrow::StatusDetail {
+ public:
+ explicit FlightWriteSizeStatusDetail(int64_t limit, int64_t actual)
+ : limit_(limit), actual_(actual) {}
+ const char* type_id() const override;
+ std::string ToString() const override;
+ int64_t limit() const { return limit_; }
+ int64_t actual() const { return actual_; }
+
+ /// \brief Extract this status detail from a status, or return
+ /// nullptr if the status doesn't contain this status detail.
+ static std::shared_ptr<FlightWriteSizeStatusDetail> UnwrapStatus(
+ const arrow::Status& status);
+
+ private:
+ int64_t limit_;
+ int64_t actual_;
+};
+
+struct ARROW_FLIGHT_EXPORT FlightClientOptions {
+ /// \brief Root certificates to use for validating server
+ /// certificates.
+ std::string tls_root_certs;
+ /// \brief Override the hostname checked by TLS. Use with caution.
+ std::string override_hostname;
+ /// \brief The client certificate to use if using Mutual TLS
+ std::string cert_chain;
+ /// \brief The private key associated with the client certificate for Mutual TLS
+ std::string private_key;
+ /// \brief A list of client middleware to apply.
+ std::vector<std::shared_ptr<ClientMiddlewareFactory>> middleware;
+ /// \brief A soft limit on the number of bytes to write in a single
+ /// batch when sending Arrow data to a server.
+ ///
+ /// Used to help limit server memory consumption. Only enabled if
+ /// positive. When enabled, FlightStreamWriter.Write* may yield a
+ /// IOError with error detail FlightWriteSizeStatusDetail.
+ int64_t write_size_limit_bytes = 0;
+
+ /// \brief Generic connection options, passed to the underlying
+ /// transport; interpretation is implementation-dependent.
+ std::vector<std::pair<std::string, util::Variant<int, std::string>>> generic_options;
+
+ /// \brief Use TLS without validating the server certificate. Use with caution.
+ bool disable_server_verification = false;
+
+ /// \brief Get default options.
+ static FlightClientOptions Defaults();
+};
+
+/// \brief A RecordBatchReader exposing Flight metadata and cancel
+/// operations.
+class ARROW_FLIGHT_EXPORT FlightStreamReader : public MetadataRecordBatchReader {
+ public:
+ /// \brief Try to cancel the call.
+ virtual void Cancel() = 0;
+ using MetadataRecordBatchReader::ReadAll;
+ /// \brief Consume entire stream as a vector of record batches
+ virtual Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches,
+ const StopToken& stop_token) = 0;
+ /// \brief Consume entire stream as a Table
+ Status ReadAll(std::shared_ptr<Table>* table, const StopToken& stop_token);
+};
+
+// Silence warning
+// "non dll-interface class RecordBatchReader used as base for dll-interface class"
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4275)
+#endif
+
+/// \brief A RecordBatchWriter that also allows sending
+/// application-defined metadata via the Flight protocol.
+class ARROW_FLIGHT_EXPORT FlightStreamWriter : public MetadataRecordBatchWriter {
+ public:
+ /// \brief Indicate that the application is done writing to this stream.
+ ///
+ /// The application may not write to this stream after calling
+ /// this. This differs from closing the stream because this writer
+ /// may represent only one half of a readable and writable stream.
+ virtual Status DoneWriting() = 0;
+};
+
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+
+/// \brief A reader for application-specific metadata sent back to the
+/// client during an upload.
+class ARROW_FLIGHT_EXPORT FlightMetadataReader {
+ public:
+ virtual ~FlightMetadataReader();
+ /// \brief Read a message from the server.
+ virtual Status ReadMetadata(std::shared_ptr<Buffer>* out) = 0;
+};
+
+/// \brief Client class for Arrow Flight RPC services (gRPC-based).
+/// API experimental for now
+class ARROW_FLIGHT_EXPORT FlightClient {
+ public:
+ ~FlightClient();
+
+ /// \brief Connect to an unauthenticated flight service
+ /// \param[in] location the URI
+ /// \param[out] client the created FlightClient
+ /// \return Status OK status may not indicate that the connection was
+ /// successful
+ static Status Connect(const Location& location, std::unique_ptr<FlightClient>* client);
+
+ /// \brief Connect to an unauthenticated flight service
+ /// \param[in] location the URI
+ /// \param[in] options Other options for setting up the client
+ /// \param[out] client the created FlightClient
+ /// \return Status OK status may not indicate that the connection was
+ /// successful
+ static Status Connect(const Location& location, const FlightClientOptions& options,
+ std::unique_ptr<FlightClient>* client);
+
+ /// \brief Authenticate to the server using the given handler.
+ /// \param[in] options Per-RPC options
+ /// \param[in] auth_handler The authentication mechanism to use
+ /// \return Status OK if the client authenticated successfully
+ Status Authenticate(const FlightCallOptions& options,
+ std::unique_ptr<ClientAuthHandler> auth_handler);
+
+ /// \brief Authenticate to the server using basic HTTP style authentication.
+ /// \param[in] options Per-RPC options
+ /// \param[in] username Username to use
+ /// \param[in] password Password to use
+ /// \return Arrow result with bearer token and status OK if client authenticated
+ /// sucessfully
+ arrow::Result<std::pair<std::string, std::string>> AuthenticateBasicToken(
+ const FlightCallOptions& options, const std::string& username,
+ const std::string& password);
+
+ /// \brief Perform the indicated action, returning an iterator to the stream
+ /// of results, if any
+ /// \param[in] options Per-RPC options
+ /// \param[in] action the action to be performed
+ /// \param[out] results an iterator object for reading the returned results
+ /// \return Status
+ Status DoAction(const FlightCallOptions& options, const Action& action,
+ std::unique_ptr<ResultStream>* results);
+ Status DoAction(const Action& action, std::unique_ptr<ResultStream>* results) {
+ return DoAction({}, action, results);
+ }
+
+ /// \brief Retrieve a list of available Action types
+ /// \param[in] options Per-RPC options
+ /// \param[out] actions the available actions
+ /// \return Status
+ Status ListActions(const FlightCallOptions& options, std::vector<ActionType>* actions);
+ Status ListActions(std::vector<ActionType>* actions) {
+ return ListActions({}, actions);
+ }
+
+ /// \brief Request access plan for a single flight, which may be an existing
+ /// dataset or a command to be executed
+ /// \param[in] options Per-RPC options
+ /// \param[in] descriptor the dataset request, whether a named dataset or
+ /// command
+ /// \param[out] info the FlightInfo describing where to access the dataset
+ /// \return Status
+ Status GetFlightInfo(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
+ std::unique_ptr<FlightInfo>* info);
+ Status GetFlightInfo(const FlightDescriptor& descriptor,
+ std::unique_ptr<FlightInfo>* info) {
+ return GetFlightInfo({}, descriptor, info);
+ }
+
+ /// \brief Request schema for a single flight, which may be an existing
+ /// dataset or a command to be executed
+ /// \param[in] options Per-RPC options
+ /// \param[in] descriptor the dataset request, whether a named dataset or
+ /// command
+ /// \param[out] schema_result the SchemaResult describing the dataset schema
+ /// \return Status
+ Status GetSchema(const FlightCallOptions& options, const FlightDescriptor& descriptor,
+ std::unique_ptr<SchemaResult>* schema_result);
+ Status GetSchema(const FlightDescriptor& descriptor,
+ std::unique_ptr<SchemaResult>* schema_result) {
+ return GetSchema({}, descriptor, schema_result);
+ }
+
+ /// \brief List all available flights known to the server
+ /// \param[out] listing an iterator that returns a FlightInfo for each flight
+ /// \return Status
+ Status ListFlights(std::unique_ptr<FlightListing>* listing);
+
+ /// \brief List available flights given indicated filter criteria
+ /// \param[in] options Per-RPC options
+ /// \param[in] criteria the filter criteria (opaque)
+ /// \param[out] listing an iterator that returns a FlightInfo for each flight
+ /// \return Status
+ Status ListFlights(const FlightCallOptions& options, const Criteria& criteria,
+ std::unique_ptr<FlightListing>* listing);
+
+ /// \brief Given a flight ticket and schema, request to be sent the
+ /// stream. Returns record batch stream reader
+ /// \param[in] options Per-RPC options
+ /// \param[in] ticket The flight ticket to use
+ /// \param[out] stream the returned RecordBatchReader
+ /// \return Status
+ Status DoGet(const FlightCallOptions& options, const Ticket& ticket,
+ std::unique_ptr<FlightStreamReader>* stream);
+ Status DoGet(const Ticket& ticket, std::unique_ptr<FlightStreamReader>* stream) {
+ return DoGet({}, ticket, stream);
+ }
+
+ /// \brief Upload data to a Flight described by the given
+ /// descriptor. The caller must call Close() on the returned stream
+ /// once they are done writing.
+ ///
+ /// The reader and writer are linked; closing the writer will also
+ /// close the reader. Use \a DoneWriting to only close the write
+ /// side of the channel.
+ ///
+ /// \param[in] options Per-RPC options
+ /// \param[in] descriptor the descriptor of the stream
+ /// \param[in] schema the schema for the data to upload
+ /// \param[out] stream a writer to write record batches to
+ /// \param[out] reader a reader for application metadata from the server
+ /// \return Status
+ Status DoPut(const FlightCallOptions& options, const FlightDescriptor& descriptor,
+ const std::shared_ptr<Schema>& schema,
+ std::unique_ptr<FlightStreamWriter>* stream,
+ std::unique_ptr<FlightMetadataReader>* reader);
+ Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr<Schema>& schema,
+ std::unique_ptr<FlightStreamWriter>* stream,
+ std::unique_ptr<FlightMetadataReader>* reader) {
+ return DoPut({}, descriptor, schema, stream, reader);
+ }
+
+ Status DoExchange(const FlightCallOptions& options, const FlightDescriptor& descriptor,
+ std::unique_ptr<FlightStreamWriter>* writer,
+ std::unique_ptr<FlightStreamReader>* reader);
+ Status DoExchange(const FlightDescriptor& descriptor,
+ std::unique_ptr<FlightStreamWriter>* writer,
+ std::unique_ptr<FlightStreamReader>* reader) {
+ return DoExchange({}, descriptor, writer, reader);
+ }
+
+ private:
+ FlightClient();
+ class FlightClientImpl;
+ std::unique_ptr<FlightClientImpl> impl_;
+};
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/client_auth.h b/src/arrow/cpp/src/arrow/flight/client_auth.h
new file mode 100644
index 000000000..9dad36aa0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/client_auth.h
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+
+#include "arrow/flight/visibility.h"
+#include "arrow/status.h"
+
+namespace arrow {
+
+namespace flight {
+
+/// \brief A reader for messages from the server during an
+/// authentication handshake.
+class ARROW_FLIGHT_EXPORT ClientAuthReader {
+ public:
+ virtual ~ClientAuthReader() = default;
+ virtual Status Read(std::string* response) = 0;
+};
+
+/// \brief A writer for messages to the server during an
+/// authentication handshake.
+class ARROW_FLIGHT_EXPORT ClientAuthSender {
+ public:
+ virtual ~ClientAuthSender() = default;
+ virtual Status Write(const std::string& token) = 0;
+};
+
+/// \brief An authentication implementation for a Flight service.
+/// Authentication includes both an initial negotiation and a per-call
+/// token validation. Implementations may choose to use either or both
+/// mechanisms.
+class ARROW_FLIGHT_EXPORT ClientAuthHandler {
+ public:
+ virtual ~ClientAuthHandler() = default;
+ /// \brief Authenticate the client on initial connection. The client
+ /// can send messages to/read responses from the server at any time.
+ /// \return Status OK if authenticated successfully
+ virtual Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) = 0;
+ /// \brief Get a per-call token.
+ /// \param[out] token The token to send to the server.
+ virtual Status GetToken(std::string* token) = 0;
+};
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/client_cookie_middleware.cc b/src/arrow/cpp/src/arrow/flight/client_cookie_middleware.cc
new file mode 100644
index 000000000..145705e97
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/client_cookie_middleware.cc
@@ -0,0 +1,65 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/client_cookie_middleware.h"
+#include "arrow/flight/client_header_internal.h"
+#include "arrow/util/value_parsing.h"
+
+namespace arrow {
+namespace flight {
+
+/// \brief Client-side middleware for sending/receiving HTTP style cookies.
+class ClientCookieMiddlewareFactory : public ClientMiddlewareFactory {
+ public:
+ void StartCall(const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware) {
+ ARROW_UNUSED(info);
+ *middleware = std::unique_ptr<ClientMiddleware>(new ClientCookieMiddleware(*this));
+ }
+
+ private:
+ class ClientCookieMiddleware : public ClientMiddleware {
+ public:
+ explicit ClientCookieMiddleware(ClientCookieMiddlewareFactory& factory)
+ : factory_(factory) {}
+
+ void SendingHeaders(AddCallHeaders* outgoing_headers) override {
+ const std::string& cookie_string = factory_.cookie_cache_.GetValidCookiesAsString();
+ if (!cookie_string.empty()) {
+ outgoing_headers->AddHeader("cookie", cookie_string);
+ }
+ }
+
+ void ReceivedHeaders(const CallHeaders& incoming_headers) override {
+ factory_.cookie_cache_.UpdateCachedCookies(incoming_headers);
+ }
+
+ void CallCompleted(const Status& status) override {}
+
+ private:
+ ClientCookieMiddlewareFactory& factory_;
+ };
+
+ // Cookie cache has mutex to protect itself.
+ internal::CookieCache cookie_cache_;
+};
+
+std::shared_ptr<ClientMiddlewareFactory> GetCookieFactory() {
+ return std::make_shared<ClientCookieMiddlewareFactory>();
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/client_cookie_middleware.h b/src/arrow/cpp/src/arrow/flight/client_cookie_middleware.h
new file mode 100644
index 000000000..6a56a632d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/client_cookie_middleware.h
@@ -0,0 +1,33 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Middleware implementation for sending and receiving HTTP cookies.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/flight/client_middleware.h"
+
+namespace arrow {
+namespace flight {
+
+/// \brief Returns a ClientMiddlewareFactory that handles sending and receiving cookies.
+ARROW_FLIGHT_EXPORT std::shared_ptr<ClientMiddlewareFactory> GetCookieFactory();
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/client_header_internal.cc b/src/arrow/cpp/src/arrow/flight/client_header_internal.cc
new file mode 100644
index 000000000..689fe0fc1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/client_header_internal.cc
@@ -0,0 +1,337 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Interfaces for defining middleware for Flight clients. Currently
+// experimental.
+
+#include "arrow/flight/client_header_internal.h"
+#include "arrow/flight/client.h"
+#include "arrow/flight/client_auth.h"
+#include "arrow/flight/platform.h"
+#include "arrow/util/base64.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/string.h"
+#include "arrow/util/uri.h"
+#include "arrow/util/value_parsing.h"
+
+// Mingw-w64 defines strcasecmp in string.h
+#if defined(_WIN32) && !defined(strcasecmp)
+#define strcasecmp stricmp
+#endif
+
+#include <algorithm>
+#include <cctype>
+#include <chrono>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <string>
+
+const char kAuthHeader[] = "authorization";
+const char kBearerPrefix[] = "Bearer ";
+const char kBasicPrefix[] = "Basic ";
+const char kCookieExpiresFormat[] = "%d %m %Y %H:%M:%S";
+
+namespace arrow {
+namespace flight {
+namespace internal {
+
+using CookiePair = arrow::util::optional<std::pair<std::string, std::string>>;
+using CookieHeaderPair =
+ const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>&;
+
+bool CaseInsensitiveComparator::operator()(const std::string& lhs,
+ const std::string& rhs) const {
+ return strcasecmp(lhs.c_str(), rhs.c_str()) < 0;
+}
+
+size_t CaseInsensitiveHash::operator()(const std::string& key) const {
+ std::string upper_string = key;
+ std::transform(upper_string.begin(), upper_string.end(), upper_string.begin(),
+ ::toupper);
+ return std::hash<std::string>{}(upper_string);
+}
+
+Cookie Cookie::parse(const arrow::util::string_view& cookie_header_value) {
+ // Parse the cookie string. If the cookie has an expiration, record it.
+ // If the cookie has a max-age, calculate the current time + max_age and set that as
+ // the expiration.
+ Cookie cookie;
+ cookie.has_expiry_ = false;
+ std::string cookie_value_str(cookie_header_value);
+
+ // There should always be a first match which should be the name and value of the
+ // cookie.
+ std::string::size_type pos = 0;
+ CookiePair cookie_pair = ParseCookieAttribute(cookie_value_str, &pos);
+ if (!cookie_pair.has_value()) {
+ // No cookie found. Mark the output cookie as expired.
+ cookie.has_expiry_ = true;
+ cookie.expiration_time_ = std::chrono::system_clock::now();
+ } else {
+ cookie.cookie_name_ = cookie_pair.value().first;
+ cookie.cookie_value_ = cookie_pair.value().second;
+ }
+
+ while (pos < cookie_value_str.size()) {
+ cookie_pair = ParseCookieAttribute(cookie_value_str, &pos);
+ if (!cookie_pair.has_value()) {
+ break;
+ }
+
+ std::string cookie_attr_value_str = cookie_pair.value().second;
+ if (arrow::internal::AsciiEqualsCaseInsensitive(cookie_pair.value().first,
+ "max-age")) {
+ // Note: max-age takes precedence over expires. We don't really care about other
+ // attributes and will arbitrarily take the first max-age. We can stop the loop
+ // here.
+ cookie.has_expiry_ = true;
+ int max_age = -1;
+ try {
+ max_age = std::stoi(cookie_attr_value_str);
+ } catch (...) {
+ // stoi throws an exception when it fails, just ignore and leave max_age as -1.
+ }
+
+ if (max_age <= 0) {
+ // Force expiration.
+ cookie.expiration_time_ = std::chrono::system_clock::now();
+ } else {
+ // Max-age is in seconds.
+ cookie.expiration_time_ =
+ std::chrono::system_clock::now() + std::chrono::seconds(max_age);
+ }
+ break;
+ } else if (arrow::internal::AsciiEqualsCaseInsensitive(cookie_pair.value().first,
+ "expires")) {
+ cookie.has_expiry_ = true;
+ int64_t seconds = 0;
+ ConvertCookieDate(&cookie_attr_value_str);
+ if (arrow::internal::ParseTimestampStrptime(
+ cookie_attr_value_str.c_str(), cookie_attr_value_str.size(),
+ kCookieExpiresFormat, false, true, arrow::TimeUnit::SECOND, &seconds)) {
+ cookie.expiration_time_ = std::chrono::time_point<std::chrono::system_clock>(
+ std::chrono::seconds(seconds));
+ } else {
+ // Force expiration.
+ cookie.expiration_time_ = std::chrono::system_clock::now();
+ }
+ }
+ }
+
+ return cookie;
+}
+
+CookiePair Cookie::ParseCookieAttribute(const std::string& cookie_header_value,
+ std::string::size_type* start_pos) {
+ std::string::size_type equals_pos = cookie_header_value.find('=', *start_pos);
+ if (std::string::npos == equals_pos) {
+ // No cookie attribute.
+ *start_pos = std::string::npos;
+ return arrow::util::nullopt;
+ }
+
+ std::string::size_type semi_col_pos = cookie_header_value.find(';', equals_pos);
+ std::string out_key = arrow::internal::TrimString(
+ cookie_header_value.substr(*start_pos, equals_pos - *start_pos));
+ std::string out_value;
+ if (std::string::npos == semi_col_pos) {
+ // Last item - set start pos to end
+ out_value = arrow::internal::TrimString(cookie_header_value.substr(equals_pos + 1));
+ *start_pos = std::string::npos;
+ } else {
+ out_value = arrow::internal::TrimString(
+ cookie_header_value.substr(equals_pos + 1, semi_col_pos - equals_pos - 1));
+ *start_pos = semi_col_pos + 1;
+ }
+
+ // Key/Value may be URI-encoded.
+ out_key = arrow::internal::UriUnescape(out_key);
+ out_value = arrow::internal::UriUnescape(out_value);
+
+ // Strip outer quotes on the value.
+ if (out_value.size() >= 2 && out_value[0] == '"' &&
+ out_value[out_value.size() - 1] == '"') {
+ out_value = out_value.substr(1, out_value.size() - 2);
+ }
+
+ // Update the start position for subsequent calls to this function.
+ return std::make_pair(out_key, out_value);
+}
+
+void Cookie::ConvertCookieDate(std::string* date) {
+ // Abbreviated months in order.
+ static const std::vector<std::string> months = {
+ "JAN", "FEB", "MAR", "APR", "MAY", "JUN", "JUL", "AUG", "SEP", "OCT", "NOV", "DEC"};
+
+ // The date comes in with the following format: Wed, 01 Jan 3000 22:15:36 GMT
+ // Symbolics are not supported by Windows parsing, so we need to convert to
+ // the following format: 01 01 3000 22:15:36
+
+ // String is currently in regular format: 'Wed, 01 Jan 3000 22:15:36 GMT'
+ // Start by removing comma and everything before it, then trimming space.
+ auto comma_loc = date->find(",");
+ if (comma_loc == std::string::npos) {
+ return;
+ }
+ *date = arrow::internal::TrimString(date->substr(comma_loc + 1));
+
+ // String is now in trimmed format: '01 Jan 3000 22:15:36 GMT'
+ // Now swap month to proper month format for Windows.
+ // Start by removing case sensitivity.
+ std::transform(date->begin(), date->end(), date->begin(), ::toupper);
+
+ // Loop through months.
+ for (size_t i = 0; i < months.size(); i++) {
+ // Search the date for the month.
+ auto it = date->find(months[i]);
+ if (it != std::string::npos) {
+ // Create month integer, pad with leading zeros if required.
+ std::string padded_month;
+ if ((i + 1) < 10) {
+ padded_month = "0";
+ }
+ padded_month += std::to_string(i + 1);
+
+ // Replace symbolic month with numeric month.
+ date->replace(it, months[i].length(), padded_month);
+
+ // String is now in format: '01 01 3000 22:15:36 GMT'.
+ break;
+ }
+ }
+
+ // String is now in format '01 01 3000 22:15:36'.
+ auto gmt = date->find(" GMT");
+ if (gmt == std::string::npos) {
+ return;
+ }
+ date->erase(gmt, 4);
+
+ // Sometimes a semicolon is added at the end, if this is the case, remove it.
+ if (date->back() == ';') {
+ date->pop_back();
+ }
+}
+
+bool Cookie::IsExpired() const {
+ // Check if current-time is less than creation time.
+ return (has_expiry_ && (expiration_time_ <= std::chrono::system_clock::now()));
+}
+
+std::string Cookie::AsCookieString() const {
+ // Return the string for the cookie as it would appear in a Cookie header.
+ // Keys must be wrapped in quotes depending on if this is a v1 or v2 cookie.
+ return cookie_name_ + "=\"" + cookie_value_ + "\"";
+}
+
+std::string Cookie::GetName() const { return cookie_name_; }
+
+void CookieCache::DiscardExpiredCookies() {
+ for (auto it = cookies.begin(); it != cookies.end();) {
+ if (it->second.IsExpired()) {
+ it = cookies.erase(it);
+ } else {
+ ++it;
+ }
+ }
+}
+
+void CookieCache::UpdateCachedCookies(const CallHeaders& incoming_headers) {
+ CookieHeaderPair header_values = incoming_headers.equal_range("set-cookie");
+ const std::lock_guard<std::mutex> guard(mutex_);
+
+ for (auto it = header_values.first; it != header_values.second; ++it) {
+ const util::string_view& value = it->second;
+ Cookie cookie = Cookie::parse(value);
+
+ // Cache cookies regardless of whether or not they are expired. The server may have
+ // explicitly sent a Set-Cookie to expire a cached cookie.
+ auto insertable = cookies.insert({cookie.GetName(), cookie});
+
+ // Force overwrite on insert collision.
+ if (!insertable.second) {
+ insertable.first->second = cookie;
+ }
+ }
+}
+
+std::string CookieCache::GetValidCookiesAsString() {
+ const std::lock_guard<std::mutex> guard(mutex_);
+
+ DiscardExpiredCookies();
+ if (cookies.empty()) {
+ return "";
+ }
+
+ std::string cookie_string = cookies.begin()->second.AsCookieString();
+ for (auto it = (++cookies.begin()); cookies.end() != it; ++it) {
+ cookie_string += "; " + it->second.AsCookieString();
+ }
+ return cookie_string;
+}
+
+/// \brief Add base64 encoded credentials to the outbound headers.
+///
+/// \param context Context object to add the headers to.
+/// \param username Username to format and encode.
+/// \param password Password to format and encode.
+void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username,
+ const std::string& password) {
+ const std::string credentials = username + ":" + password;
+ context->AddMetadata(kAuthHeader,
+ kBasicPrefix + arrow::util::base64_encode(credentials));
+}
+
+/// \brief Get bearer token from inbound headers.
+///
+/// \param context Incoming ClientContext that contains headers.
+/// \return Arrow result with bearer token (empty if no bearer token found).
+arrow::Result<std::pair<std::string, std::string>> GetBearerTokenHeader(
+ grpc::ClientContext& context) {
+ // Lambda function to compare characters without case sensitivity.
+ auto char_compare = [](const char& char1, const char& char2) {
+ return (::toupper(char1) == ::toupper(char2));
+ };
+
+ // Get the auth token if it exists, this can be in the initial or the trailing metadata.
+ auto trailing_headers = context.GetServerTrailingMetadata();
+ auto initial_headers = context.GetServerInitialMetadata();
+ auto bearer_iter = trailing_headers.find(kAuthHeader);
+ if (bearer_iter == trailing_headers.end()) {
+ bearer_iter = initial_headers.find(kAuthHeader);
+ if (bearer_iter == initial_headers.end()) {
+ return std::make_pair("", "");
+ }
+ }
+
+ // Check if the value of the auth token starts with the bearer prefix and latch it.
+ std::string bearer_val(bearer_iter->second.data(), bearer_iter->second.size());
+ if (bearer_val.size() > strlen(kBearerPrefix)) {
+ if (std::equal(bearer_val.begin(), bearer_val.begin() + strlen(kBearerPrefix),
+ kBearerPrefix, char_compare)) {
+ return std::make_pair(kAuthHeader, bearer_val);
+ }
+ }
+
+ // The server is not required to provide a bearer token.
+ return std::make_pair("", "");
+}
+
+} // namespace internal
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/client_header_internal.h b/src/arrow/cpp/src/arrow/flight/client_header_internal.h
new file mode 100644
index 000000000..dd4498e03
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/client_header_internal.h
@@ -0,0 +1,151 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Interfaces for defining middleware for Flight clients. Currently
+// experimental.
+
+#pragma once
+
+#include "arrow/flight/client_middleware.h"
+#include "arrow/result.h"
+#include "arrow/util/optional.h"
+
+#ifdef GRPCPP_PP_INCLUDE
+#include <grpcpp/grpcpp.h>
+#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS)
+#include <grpcpp/security/tls_credentials_options.h>
+#endif
+#else
+#include <grpc++/grpc++.h>
+#endif
+
+#include <chrono>
+#include <functional>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+
+namespace arrow {
+namespace flight {
+namespace internal {
+
+/// \brief Case insensitive comparator for use by cookie caching map. Cookies are not
+/// case sensitive.
+class ARROW_FLIGHT_EXPORT CaseInsensitiveComparator {
+ public:
+ bool operator()(const std::string& t1, const std::string& t2) const;
+};
+
+/// \brief Case insensitive hasher for use by cookie caching map. Cookies are not
+/// case sensitive.
+class ARROW_FLIGHT_EXPORT CaseInsensitiveHash {
+ public:
+ size_t operator()(const std::string& key) const;
+};
+
+/// \brief Class to represent a cookie.
+class ARROW_FLIGHT_EXPORT Cookie {
+ public:
+ /// \brief Parse function to parse a cookie header value and return a Cookie object.
+ ///
+ /// \return Cookie object based on cookie header value.
+ static Cookie parse(const arrow::util::string_view& cookie_header_value);
+
+ /// \brief Parse a cookie header string beginning at the given start_pos and identify
+ /// the name and value of an attribute.
+ ///
+ /// \param cookie_header_value The value of the Set-Cookie header.
+ /// \param[out] start_pos An input/output parameter indicating the starting position
+ /// of the attribute. It will store the position of the next attribute when the
+ /// function returns.
+ ///
+ /// \return Optional cookie key value pair.
+ static arrow::util::optional<std::pair<std::string, std::string>> ParseCookieAttribute(
+ const std::string& cookie_header_value, std::string::size_type* start_pos);
+
+ /// \brief Function to fix cookie format date string so it is accepted by Windows
+ ///
+ /// parsers.
+ /// \param date Date to fix.
+ static void ConvertCookieDate(std::string* date);
+
+ /// \brief Function to check if the cookie has expired.
+ ///
+ /// \return Returns true if the cookie has expired.
+ bool IsExpired() const;
+
+ /// \brief Function to get cookie as a string.
+ ///
+ /// \return Cookie as a string.
+ std::string AsCookieString() const;
+
+ /// \brief Function to get name of the cookie as a string.
+ ///
+ /// \return Name of the cookie as a string.
+ std::string GetName() const;
+
+ private:
+ std::string cookie_name_;
+ std::string cookie_value_;
+ std::chrono::time_point<std::chrono::system_clock> expiration_time_;
+ bool has_expiry_;
+};
+
+/// \brief Class to handle updating a cookie cache.
+class ARROW_FLIGHT_EXPORT CookieCache {
+ public:
+ /// \brief Updates the cache of cookies with new Set-Cookie header values.
+ ///
+ /// \param incoming_headers The range representing header values.
+ void UpdateCachedCookies(const CallHeaders& incoming_headers);
+
+ /// \brief Retrieve the cached cookie values as a string. This function discards
+ /// cookies that have expired.
+ ///
+ /// \return a string that can be used in a Cookie header representing the cookies that
+ /// have been cached.
+ std::string GetValidCookiesAsString();
+
+ private:
+ /// \brief Removes cookies that are marked as expired from the cache.
+ void DiscardExpiredCookies();
+
+ // Mutex must be used to protect cookie cache.
+ std::mutex mutex_;
+ std::unordered_map<std::string, Cookie, CaseInsensitiveHash, CaseInsensitiveComparator>
+ cookies;
+};
+
+/// \brief Add basic authentication header key value pair to context.
+///
+/// \param context grpc context variable to add header to.
+/// \param username username to encode into header.
+/// \param password password to to encode into header.
+void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context,
+ const std::string& username,
+ const std::string& password);
+
+/// \brief Get bearer token from incoming headers.
+///
+/// \param context context that contains headers which hold the bearer token.
+/// \return Bearer token, parsed from headers, empty if one is not present.
+arrow::Result<std::pair<std::string, std::string>> ARROW_FLIGHT_EXPORT
+GetBearerTokenHeader(grpc::ClientContext& context);
+
+} // namespace internal
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/client_middleware.h b/src/arrow/cpp/src/arrow/flight/client_middleware.h
new file mode 100644
index 000000000..5b67e784b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/client_middleware.h
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Interfaces for defining middleware for Flight clients. Currently
+// experimental.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/flight/middleware.h"
+#include "arrow/flight/visibility.h" // IWYU pragma: keep
+#include "arrow/status.h"
+
+namespace arrow {
+namespace flight {
+
+/// \brief Client-side middleware for a call, instantiated per RPC.
+///
+/// Middleware should be fast and must be infallible: there is no way
+/// to reject the call or report errors from the middleware instance.
+class ARROW_FLIGHT_EXPORT ClientMiddleware {
+ public:
+ virtual ~ClientMiddleware() = default;
+
+ /// \brief A callback before headers are sent. Extra headers can be
+ /// added, but existing ones cannot be read.
+ virtual void SendingHeaders(AddCallHeaders* outgoing_headers) = 0;
+
+ /// \brief A callback when headers are received from the server.
+ virtual void ReceivedHeaders(const CallHeaders& incoming_headers) = 0;
+
+ /// \brief A callback after the call has completed.
+ virtual void CallCompleted(const Status& status) = 0;
+};
+
+/// \brief A factory for new middleware instances.
+///
+/// If added to a client, this will be called for each RPC (including
+/// Handshake) to give the opportunity to intercept the call.
+///
+/// It is guaranteed that all client middleware methods are called
+/// from the same thread that calls the RPC method implementation.
+class ARROW_FLIGHT_EXPORT ClientMiddlewareFactory {
+ public:
+ virtual ~ClientMiddlewareFactory() = default;
+
+ /// \brief A callback for the start of a new call.
+ ///
+ /// \param info Information about the call.
+ /// \param[out] middleware The middleware instance for this call. If
+ /// unset, will not add middleware to this call instance from
+ /// this factory.
+ virtual void StartCall(const CallInfo& info,
+ std::unique_ptr<ClientMiddleware>* middleware) = 0;
+};
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/customize_protobuf.h b/src/arrow/cpp/src/arrow/flight/customize_protobuf.h
new file mode 100644
index 000000000..1508af254
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/customize_protobuf.h
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <limits>
+#include <memory>
+
+#include "arrow/flight/platform.h"
+#include "arrow/util/config.h"
+
+// Silence protobuf warnings
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4244)
+#pragma warning(disable : 4267)
+#endif
+
+#ifdef GRPCPP_PP_INCLUDE
+#include <grpcpp/impl/codegen/config_protobuf.h>
+#else
+#include <grpc++/impl/codegen/config_protobuf.h>
+#endif
+
+#ifdef GRPCPP_PP_INCLUDE
+#include <grpcpp/impl/codegen/proto_utils.h>
+#else
+#include <grpc++/impl/codegen/proto_utils.h>
+#endif
+
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+
+namespace grpc {
+
+class ByteBuffer;
+
+} // namespace grpc
+
+namespace arrow {
+namespace flight {
+
+struct FlightPayload;
+
+namespace internal {
+
+struct FlightData;
+
+// Those two functions are defined in serialization-internal.cc
+
+// Write FlightData to a grpc::ByteBuffer without extra copying
+grpc::Status FlightDataSerialize(const FlightPayload& msg, grpc::ByteBuffer* out,
+ bool* own_buffer);
+
+// Read internal::FlightData from grpc::ByteBuffer containing FlightData
+// protobuf without copying
+grpc::Status FlightDataDeserialize(grpc::ByteBuffer* buffer, FlightData* out);
+
+} // namespace internal
+
+namespace protocol {
+
+class FlightData;
+
+} // namespace protocol
+} // namespace flight
+} // namespace arrow
+
+namespace grpc {
+
+template <>
+class SerializationTraits<arrow::flight::protocol::FlightData> {
+#ifdef GRPC_CUSTOM_MESSAGELITE
+ using MessageType = grpc::protobuf::MessageLite;
+#else
+ using MessageType = grpc::protobuf::Message;
+#endif
+
+ public:
+ // In the functions below, we cast back the Message argument to its real
+ // type (see ReadPayload() and WritePayload() for the initial cast).
+ static Status Serialize(const MessageType& msg, ByteBuffer* bb, bool* own_buffer) {
+ return arrow::flight::internal::FlightDataSerialize(
+ *reinterpret_cast<const arrow::flight::FlightPayload*>(&msg), bb, own_buffer);
+ }
+
+ static Status Deserialize(ByteBuffer* buffer, MessageType* msg) {
+ return arrow::flight::internal::FlightDataDeserialize(
+ buffer, reinterpret_cast<arrow::flight::internal::FlightData*>(msg));
+ }
+};
+
+} // namespace grpc
diff --git a/src/arrow/cpp/src/arrow/flight/flight_benchmark.cc b/src/arrow/cpp/src/arrow/flight/flight_benchmark.cc
new file mode 100644
index 000000000..1b5f27d31
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/flight_benchmark.cc
@@ -0,0 +1,493 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <gflags/gflags.h>
+
+#include "arrow/io/file.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/api.h"
+#include "arrow/record_batch.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/compression.h"
+#include "arrow/util/stopwatch.h"
+#include "arrow/util/tdigest.h"
+#include "arrow/util/thread_pool.h"
+
+#include "arrow/flight/api.h"
+#include "arrow/flight/perf.pb.h"
+#include "arrow/flight/test_util.h"
+
+DEFINE_string(server_host, "",
+ "An existing performance server to benchmark against (leave blank to spawn "
+ "one automatically)");
+DEFINE_int32(server_port, 31337, "The port to connect to");
+DEFINE_string(server_unix, "",
+ "An existing performance server listening on Unix socket (leave blank to "
+ "spawn one automatically)");
+DEFINE_bool(test_unix, false, "Test Unix socket instead of TCP");
+DEFINE_int32(num_perf_runs, 1,
+ "Number of times to run the perf test to "
+ "increase precision");
+DEFINE_int32(num_servers, 1, "Number of performance servers to run");
+DEFINE_int32(num_streams, 4, "Number of streams for each server");
+DEFINE_int32(num_threads, 4, "Number of concurrent gets");
+DEFINE_int64(records_per_stream, 10000000, "Total records per stream");
+DEFINE_int32(records_per_batch, 4096, "Total records per batch within stream");
+DEFINE_bool(test_put, false, "Test DoPut instead of DoGet");
+DEFINE_string(compression, "",
+ "Select compression method (\"zstd\", \"lz4\"). "
+ "Leave blank to disable compression.\n"
+ "E.g., \"zstd\": zstd with default compression level.\n"
+ " \"zstd:7\": zstd with compression leve = 7.\n");
+DEFINE_string(
+ data_file, "",
+ "Instead of random data, use data from the given IPC file. Only affects -test_put.");
+DEFINE_string(cert_file, "", "Path to TLS certificate");
+DEFINE_string(key_file, "", "Path to TLS private key (used when spawning a server)");
+
+namespace perf = arrow::flight::perf;
+
+namespace arrow {
+
+using internal::StopWatch;
+using internal::ThreadPool;
+
+namespace flight {
+
+struct PerformanceResult {
+ int64_t num_batches;
+ int64_t num_records;
+ int64_t num_bytes;
+};
+
+struct PerformanceStats {
+ std::mutex mutex;
+ int64_t total_batches = 0;
+ int64_t total_records = 0;
+ int64_t total_bytes = 0;
+ const std::array<double, 3> quantiles = {0.5, 0.95, 0.99};
+ mutable arrow::internal::TDigest latencies;
+
+ void Update(int64_t total_batches, int64_t total_records, int64_t total_bytes) {
+ std::lock_guard<std::mutex> lock(this->mutex);
+ this->total_batches += total_batches;
+ this->total_records += total_records;
+ this->total_bytes += total_bytes;
+ }
+
+ // Invoked per batch in the test loop. Holding a lock looks not scalable.
+ // Tested with 1 ~ 8 threads, no noticeable overhead is observed.
+ // A better approach may be calculate per-thread quantiles and merge.
+ void AddLatency(uint64_t elapsed_nanos) {
+ std::lock_guard<std::mutex> lock(this->mutex);
+ latencies.Add(static_cast<double>(elapsed_nanos));
+ }
+
+ // ns -> us
+ uint64_t max_latency() const { return latencies.Max() / 1000; }
+
+ uint64_t mean_latency() const { return latencies.Mean() / 1000; }
+
+ uint64_t quantile_latency(double q) const { return latencies.Quantile(q) / 1000; }
+};
+
+Status WaitForReady(FlightClient* client, const FlightCallOptions& call_options) {
+ Action action{"ping", nullptr};
+ for (int attempt = 0; attempt < 10; attempt++) {
+ std::unique_ptr<ResultStream> stream;
+ if (client->DoAction(call_options, action, &stream).ok()) {
+ return Status::OK();
+ }
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+ }
+ return Status::IOError("Server was not available after 10 attempts");
+}
+
+arrow::Result<PerformanceResult> RunDoGetTest(FlightClient* client,
+ const FlightCallOptions& call_options,
+ const perf::Token& token,
+ const FlightEndpoint& endpoint,
+ PerformanceStats* stats) {
+ std::unique_ptr<FlightStreamReader> reader;
+ RETURN_NOT_OK(client->DoGet(call_options, endpoint.ticket, &reader));
+
+ FlightStreamChunk batch;
+
+ // This is hard-coded for right now, 4 columns each with int64
+ const int bytes_per_record = 32;
+
+ // This must also be set in perf_server.cc
+ const bool verify = false;
+
+ int64_t num_bytes = 0;
+ int64_t num_records = 0;
+ int64_t num_batches = 0;
+ StopWatch timer;
+ while (true) {
+ timer.Start();
+ RETURN_NOT_OK(reader->Next(&batch));
+ stats->AddLatency(timer.Stop());
+ if (!batch.data) {
+ break;
+ }
+
+ if (verify) {
+ auto values = batch.data->column_data(0)->GetValues<int64_t>(1);
+ const int64_t start = token.start() + num_records;
+ for (int64_t i = 0; i < batch.data->num_rows(); ++i) {
+ if (values[i] != start + i) {
+ return Status::Invalid("verification failure");
+ }
+ }
+ }
+
+ ++num_batches;
+ num_records += batch.data->num_rows();
+
+ // Hard-coded
+ num_bytes += batch.data->num_rows() * bytes_per_record;
+ }
+ return PerformanceResult{num_batches, num_records, num_bytes};
+}
+
+struct SizedBatch {
+ std::shared_ptr<arrow::RecordBatch> batch;
+ int64_t bytes;
+};
+
+arrow::Result<std::vector<SizedBatch>> GetPutData(const perf::Token& token) {
+ if (!FLAGS_data_file.empty()) {
+ ARROW_ASSIGN_OR_RAISE(auto file, arrow::io::ReadableFile::Open(FLAGS_data_file));
+ ARROW_ASSIGN_OR_RAISE(auto reader,
+ arrow::ipc::RecordBatchFileReader::Open(std::move(file)));
+ std::vector<SizedBatch> batches(reader->num_record_batches());
+ for (int i = 0; i < reader->num_record_batches(); i++) {
+ ARROW_ASSIGN_OR_RAISE(batches[i].batch, reader->ReadRecordBatch(i));
+ RETURN_NOT_OK(arrow::ipc::GetRecordBatchSize(*batches[i].batch, &batches[i].bytes));
+ }
+ return batches;
+ }
+
+ std::shared_ptr<Schema> schema =
+ arrow::schema({field("a", int64()), field("b", int64()), field("c", int64()),
+ field("d", int64())});
+
+ // This is hard-coded for right now, 4 columns each with int64
+ const int bytes_per_record = 32;
+
+ std::shared_ptr<ResizableBuffer> buffer;
+ std::vector<std::shared_ptr<Array>> arrays;
+
+ const int64_t total_records = token.definition().records_per_stream();
+ const int32_t length = token.definition().records_per_batch();
+ const int32_t ncolumns = 4;
+ for (int i = 0; i < ncolumns; ++i) {
+ RETURN_NOT_OK(MakeRandomByteBuffer(length * sizeof(int64_t), default_memory_pool(),
+ &buffer, static_cast<int32_t>(i) /* seed */));
+ arrays.push_back(std::make_shared<Int64Array>(length, buffer));
+ RETURN_NOT_OK(arrays.back()->Validate());
+ }
+
+ std::shared_ptr<RecordBatch> batch = RecordBatch::Make(schema, length, arrays);
+ std::vector<SizedBatch> batches;
+
+ int64_t records_sent = 0;
+ while (records_sent < total_records) {
+ if (records_sent + length > total_records) {
+ const int last_length = total_records - records_sent;
+ // Hard-coded
+ batches.push_back(SizedBatch{batch->Slice(0, last_length),
+ /*bytes=*/last_length * bytes_per_record});
+ records_sent += last_length;
+ } else {
+ // Hard-coded
+ batches.push_back(SizedBatch{batch, /*bytes=*/length * bytes_per_record});
+ records_sent += length;
+ }
+ }
+ return batches;
+}
+
+arrow::Result<PerformanceResult> RunDoPutTest(FlightClient* client,
+ const FlightCallOptions& call_options,
+ const perf::Token& token,
+ const FlightEndpoint& endpoint,
+ PerformanceStats* stats) {
+ ARROW_ASSIGN_OR_RAISE(const auto batches, GetPutData(token));
+ StopWatch timer;
+ int64_t num_records = 0;
+ int64_t num_bytes = 0;
+ std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightMetadataReader> reader;
+ RETURN_NOT_OK(client->DoPut(call_options, FlightDescriptor{},
+ batches[0].batch->schema(), &writer, &reader));
+ for (size_t i = 0; i < batches.size(); i++) {
+ auto batch = batches[i];
+ auto is_last = i == (batches.size() - 1);
+ if (is_last) {
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch.batch));
+ num_records += batch.batch->num_rows();
+ num_bytes += batch.bytes;
+ } else {
+ timer.Start();
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch.batch));
+ stats->AddLatency(timer.Stop());
+ num_records += batch.batch->num_rows();
+ num_bytes += batch.bytes;
+ }
+ }
+ RETURN_NOT_OK(writer->Close());
+ return PerformanceResult{static_cast<int64_t>(batches.size()), num_records, num_bytes};
+}
+
+Status DoSinglePerfRun(FlightClient* client, const FlightClientOptions client_options,
+ const FlightCallOptions& call_options, bool test_put,
+ PerformanceStats* stats) {
+ // schema not needed
+ perf::Perf perf;
+ perf.set_stream_count(FLAGS_num_streams);
+ perf.set_records_per_stream(FLAGS_records_per_stream);
+ perf.set_records_per_batch(FLAGS_records_per_batch);
+
+ // Plan the query
+ FlightDescriptor descriptor;
+ descriptor.type = FlightDescriptor::CMD;
+ perf.SerializeToString(&descriptor.cmd);
+
+ std::unique_ptr<FlightInfo> plan;
+ RETURN_NOT_OK(client->GetFlightInfo(call_options, descriptor, &plan));
+
+ // Read the streams in parallel
+ std::shared_ptr<Schema> schema;
+ ipc::DictionaryMemo dict_memo;
+ RETURN_NOT_OK(plan->GetSchema(&dict_memo, &schema));
+
+ int64_t start_total_records = stats->total_records;
+
+ auto test_loop = test_put ? &RunDoPutTest : &RunDoGetTest;
+ auto ConsumeStream = [&stats, &test_loop, &client_options,
+ &call_options](const FlightEndpoint& endpoint) {
+ std::unique_ptr<FlightClient> client;
+ RETURN_NOT_OK(
+ FlightClient::Connect(endpoint.locations.front(), client_options, &client));
+
+ perf::Token token;
+ token.ParseFromString(endpoint.ticket.ticket);
+
+ const auto& result = test_loop(client.get(), call_options, token, endpoint, stats);
+ if (result.ok()) {
+ const PerformanceResult& perf = result.ValueOrDie();
+ stats->Update(perf.num_batches, perf.num_records, perf.num_bytes);
+ }
+ return result.status();
+ };
+
+ // XXX(wesm): Serial version for debugging
+ // for (const auto& endpoint : plan->endpoints()) {
+ // RETURN_NOT_OK(ConsumeStream(endpoint));
+ // }
+
+ ARROW_ASSIGN_OR_RAISE(auto pool, ThreadPool::Make(FLAGS_num_threads));
+ std::vector<Future<>> tasks;
+ for (const auto& endpoint : plan->endpoints()) {
+ ARROW_ASSIGN_OR_RAISE(auto task, pool->Submit(ConsumeStream, endpoint));
+ tasks.push_back(std::move(task));
+ }
+
+ // Wait for tasks to finish
+ for (auto&& task : tasks) {
+ RETURN_NOT_OK(task.status());
+ }
+
+ if (FLAGS_data_file.empty()) {
+ // Check that number of rows read / written is as expected
+ int64_t records_for_run = stats->total_records - start_total_records;
+ if (records_for_run != static_cast<int64_t>(plan->total_records())) {
+ return Status::Invalid("Did not consume expected number of records");
+ }
+ }
+ return Status::OK();
+}
+
+Status RunPerformanceTest(FlightClient* client, const FlightClientOptions& client_options,
+ const FlightCallOptions& call_options, bool test_put) {
+ StopWatch timer;
+ timer.Start();
+
+ PerformanceStats stats;
+ for (int i = 0; i < FLAGS_num_perf_runs; ++i) {
+ RETURN_NOT_OK(
+ DoSinglePerfRun(client, client_options, call_options, test_put, &stats));
+ }
+
+ // Elapsed time in seconds
+ uint64_t elapsed_nanos = timer.Stop();
+ double time_elapsed =
+ static_cast<double>(elapsed_nanos) / static_cast<double>(1000000000);
+
+ constexpr double kMegabyte = static_cast<double>(1 << 20);
+
+ std::cout << "Number of perf runs: " << FLAGS_num_perf_runs << std::endl;
+ std::cout << "Number of concurrent gets/puts: " << FLAGS_num_threads << std::endl;
+ std::cout << "Batch size: " << stats.total_bytes / stats.total_batches << std::endl;
+ if (FLAGS_test_put) {
+ std::cout << "Batches written: " << stats.total_batches << std::endl;
+ std::cout << "Bytes written: " << stats.total_bytes << std::endl;
+ } else {
+ std::cout << "Batches read: " << stats.total_batches << std::endl;
+ std::cout << "Bytes read: " << stats.total_bytes << std::endl;
+ }
+
+ std::cout << "Nanos: " << elapsed_nanos << std::endl;
+ std::cout << "Speed: "
+ << (static_cast<double>(stats.total_bytes) / kMegabyte / time_elapsed)
+ << " MB/s" << std::endl;
+
+ // Calculate throughput(IOPS) and latency vs batch size
+ std::cout << "Throughput: " << (static_cast<double>(stats.total_batches) / time_elapsed)
+ << " batches/s" << std::endl;
+ std::cout << "Latency mean: " << stats.mean_latency() << " us" << std::endl;
+ for (auto q : stats.quantiles) {
+ std::cout << "Latency quantile=" << q << ": " << stats.quantile_latency(q) << " us"
+ << std::endl;
+ }
+ std::cout << "Latency max: " << stats.max_latency() << " us" << std::endl;
+
+ return Status::OK();
+}
+
+} // namespace flight
+} // namespace arrow
+
+int main(int argc, char** argv) {
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+ std::cout << "Testing method: ";
+ if (FLAGS_test_put) {
+ std::cout << "DoPut";
+ } else {
+ std::cout << "DoGet";
+ }
+ std::cout << std::endl;
+
+ arrow::flight::FlightCallOptions call_options;
+ if (!FLAGS_compression.empty()) {
+ if (!FLAGS_test_put) {
+ std::cerr << "Compression is only useful for Put test now, "
+ "please append \"-test_put\" to command line"
+ << std::endl;
+ std::abort();
+ }
+
+ // "zstd" -> name = "zstd", level = default
+ // "zstd:7" -> name = "zstd", level = 7
+ const size_t delim = FLAGS_compression.find(":");
+ const std::string name = FLAGS_compression.substr(0, delim);
+ const std::string level_str =
+ delim == std::string::npos
+ ? ""
+ : FLAGS_compression.substr(delim + 1, FLAGS_compression.length() - delim - 1);
+ const int level = level_str.empty() ? arrow::util::kUseDefaultCompressionLevel
+ : std::stoi(level_str);
+ const auto type = arrow::util::Codec::GetCompressionType(name).ValueOrDie();
+ auto codec = arrow::util::Codec::Create(type, level).ValueOrDie();
+ std::cout << "Compression method: " << name;
+ if (!level_str.empty()) {
+ std::cout << ", level " << level;
+ }
+ std::cout << std::endl;
+
+ call_options.write_options.codec = std::move(codec);
+ }
+ if (!FLAGS_data_file.empty() && !FLAGS_test_put) {
+ std::cerr << "A data file can only be specified with \"-test_put\"" << std::endl;
+ return 1;
+ }
+
+ std::unique_ptr<arrow::flight::TestServer> server;
+ arrow::flight::Location location;
+ auto options = arrow::flight::FlightClientOptions::Defaults();
+ if (FLAGS_test_unix || !FLAGS_server_unix.empty()) {
+ if (FLAGS_server_unix == "") {
+ FLAGS_server_unix = "/tmp/flight-bench-spawn.sock";
+ std::cout << "Using spawned Unix server" << std::endl;
+ server.reset(
+ new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_unix));
+ server->Start();
+ } else {
+ std::cout << "Using standalone Unix server" << std::endl;
+ }
+ std::cout << "Server unix socket: " << FLAGS_server_unix << std::endl;
+ ABORT_NOT_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &location));
+ } else {
+ if (FLAGS_server_host == "") {
+ FLAGS_server_host = "localhost";
+ std::cout << "Using spawned TCP server" << std::endl;
+ server.reset(
+ new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_port));
+ std::vector<std::string> args;
+ if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) {
+ if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) {
+ std::cout << "Enabling TLS for spawned server" << std::endl;
+ args.push_back("-cert_file");
+ args.push_back(FLAGS_cert_file);
+ args.push_back("-key_file");
+ args.push_back(FLAGS_key_file);
+ } else {
+ std::cerr << "If providing TLS cert/key, must provide both" << std::endl;
+ return 1;
+ }
+ }
+ server->Start(args);
+ } else {
+ std::cout << "Using standalone TCP server" << std::endl;
+ }
+ std::cout << "Server host: " << FLAGS_server_host << std::endl
+ << "Server port: " << FLAGS_server_port << std::endl;
+ if (FLAGS_cert_file.empty()) {
+ ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host,
+ FLAGS_server_port, &location));
+ } else {
+ ABORT_NOT_OK(arrow::flight::Location::ForGrpcTls(FLAGS_server_host,
+ FLAGS_server_port, &location));
+ options.disable_server_verification = true;
+ }
+ }
+
+ std::unique_ptr<arrow::flight::FlightClient> client;
+ ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, options, &client));
+ ABORT_NOT_OK(arrow::flight::WaitForReady(client.get(), call_options));
+
+ arrow::Status s = arrow::flight::RunPerformanceTest(client.get(), options, call_options,
+ FLAGS_test_put);
+
+ if (server) {
+ server->Stop();
+ }
+
+ if (!s.ok()) {
+ std::cerr << "Failed with error: << " << s.ToString() << std::endl;
+ }
+
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/flight/flight_test.cc b/src/arrow/cpp/src/arrow/flight/flight_test.cc
new file mode 100644
index 000000000..56ca468a0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/flight_test.cc
@@ -0,0 +1,2872 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <atomic>
+#include <chrono>
+#include <cstdint>
+#include <cstdio>
+#include <cstring>
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "arrow/flight/api.h"
+#include "arrow/ipc/test_common.h"
+#include "arrow/status.h"
+#include "arrow/testing/generator.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/base64.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/string.h"
+
+#ifdef GRPCPP_GRPCPP_H
+#error "gRPC headers should not be in public API"
+#endif
+
+#include "arrow/flight/client_cookie_middleware.h"
+#include "arrow/flight/client_header_internal.h"
+#include "arrow/flight/internal.h"
+#include "arrow/flight/middleware_internal.h"
+#include "arrow/flight/test_util.h"
+
+namespace arrow {
+namespace flight {
+
+namespace pb = arrow::flight::protocol;
+
+const char kValidUsername[] = "flight_username";
+const char kValidPassword[] = "flight_password";
+const char kInvalidUsername[] = "invalid_flight_username";
+const char kInvalidPassword[] = "invalid_flight_password";
+const char kBearerToken[] = "bearertoken";
+const char kBasicPrefix[] = "Basic ";
+const char kBearerPrefix[] = "Bearer ";
+const char kAuthHeader[] = "authorization";
+
+void AssertEqual(const ActionType& expected, const ActionType& actual) {
+ ASSERT_EQ(expected.type, actual.type);
+ ASSERT_EQ(expected.description, actual.description);
+}
+
+void AssertEqual(const FlightDescriptor& expected, const FlightDescriptor& actual) {
+ ASSERT_TRUE(expected.Equals(actual));
+}
+
+void AssertEqual(const Ticket& expected, const Ticket& actual) {
+ ASSERT_EQ(expected.ticket, actual.ticket);
+}
+
+void AssertEqual(const Location& expected, const Location& actual) {
+ ASSERT_EQ(expected, actual);
+}
+
+void AssertEqual(const std::vector<FlightEndpoint>& expected,
+ const std::vector<FlightEndpoint>& actual) {
+ ASSERT_EQ(expected.size(), actual.size());
+ for (size_t i = 0; i < expected.size(); ++i) {
+ AssertEqual(expected[i].ticket, actual[i].ticket);
+
+ ASSERT_EQ(expected[i].locations.size(), actual[i].locations.size());
+ for (size_t j = 0; j < expected[i].locations.size(); ++j) {
+ AssertEqual(expected[i].locations[j], actual[i].locations[j]);
+ }
+ }
+}
+
+template <typename T>
+void AssertEqual(const std::vector<T>& expected, const std::vector<T>& actual) {
+ ASSERT_EQ(expected.size(), actual.size());
+ for (size_t i = 0; i < expected.size(); ++i) {
+ AssertEqual(expected[i], actual[i]);
+ }
+}
+
+void AssertEqual(const FlightInfo& expected, const FlightInfo& actual) {
+ std::shared_ptr<Schema> ex_schema, actual_schema;
+ ipc::DictionaryMemo expected_memo;
+ ipc::DictionaryMemo actual_memo;
+ ASSERT_OK(expected.GetSchema(&expected_memo, &ex_schema));
+ ASSERT_OK(actual.GetSchema(&actual_memo, &actual_schema));
+
+ AssertSchemaEqual(*ex_schema, *actual_schema);
+ ASSERT_EQ(expected.total_records(), actual.total_records());
+ ASSERT_EQ(expected.total_bytes(), actual.total_bytes());
+
+ AssertEqual(expected.descriptor(), actual.descriptor());
+ AssertEqual(expected.endpoints(), actual.endpoints());
+}
+
+TEST(TestFlightDescriptor, Basics) {
+ auto a = FlightDescriptor::Command("select * from table");
+ auto b = FlightDescriptor::Command("select * from table");
+ auto c = FlightDescriptor::Command("select foo from table");
+ auto d = FlightDescriptor::Path({"foo", "bar"});
+ auto e = FlightDescriptor::Path({"foo", "baz"});
+ auto f = FlightDescriptor::Path({"foo", "baz"});
+
+ ASSERT_EQ(a.ToString(), "FlightDescriptor<cmd = 'select * from table'>");
+ ASSERT_EQ(d.ToString(), "FlightDescriptor<path = 'foo/bar'>");
+ ASSERT_TRUE(a.Equals(b));
+ ASSERT_FALSE(a.Equals(c));
+ ASSERT_FALSE(a.Equals(d));
+ ASSERT_FALSE(d.Equals(e));
+ ASSERT_TRUE(e.Equals(f));
+}
+
+// This tests the internal protobuf types which don't get exported in the Flight DLL.
+#ifndef _WIN32
+TEST(TestFlightDescriptor, ToFromProto) {
+ FlightDescriptor descr_test;
+ pb::FlightDescriptor pb_descr;
+
+ FlightDescriptor descr1{FlightDescriptor::PATH, "", {"foo", "bar"}};
+ ASSERT_OK(internal::ToProto(descr1, &pb_descr));
+ ASSERT_OK(internal::FromProto(pb_descr, &descr_test));
+ AssertEqual(descr1, descr_test);
+
+ FlightDescriptor descr2{FlightDescriptor::CMD, "command", {}};
+ ASSERT_OK(internal::ToProto(descr2, &pb_descr));
+ ASSERT_OK(internal::FromProto(pb_descr, &descr_test));
+ AssertEqual(descr2, descr_test);
+}
+#endif
+
+TEST(TestFlight, DISABLED_StartStopTestServer) {
+ TestServer server("flight-test-server");
+ server.Start();
+ ASSERT_TRUE(server.IsRunning());
+
+ std::this_thread::sleep_for(std::chrono::duration<double>(0.2));
+
+ ASSERT_TRUE(server.IsRunning());
+ int exit_code = server.Stop();
+#ifdef _WIN32
+ // We do a hard kill on Windows
+ ASSERT_EQ(259, exit_code);
+#else
+ ASSERT_EQ(0, exit_code);
+#endif
+ ASSERT_FALSE(server.IsRunning());
+}
+
+// ARROW-6017: we should be able to construct locations for unknown
+// schemes
+TEST(TestFlight, UnknownLocationScheme) {
+ Location location;
+ ASSERT_OK(Location::Parse("s3://test", &location));
+ ASSERT_OK(Location::Parse("https://example.com/foo", &location));
+}
+
+TEST(TestFlight, ConnectUri) {
+ TestServer server("flight-test-server");
+ server.Start();
+ ASSERT_TRUE(server.IsRunning());
+
+ std::stringstream ss;
+ ss << "grpc://localhost:" << server.port();
+ std::string uri = ss.str();
+
+ std::unique_ptr<FlightClient> client;
+ Location location1;
+ Location location2;
+ ASSERT_OK(Location::Parse(uri, &location1));
+ ASSERT_OK(Location::Parse(uri, &location2));
+ ASSERT_OK(FlightClient::Connect(location1, &client));
+ ASSERT_OK(FlightClient::Connect(location2, &client));
+}
+
+#ifndef _WIN32
+TEST(TestFlight, ConnectUriUnix) {
+ TestServer server("flight-test-server", "/tmp/flight-test.sock");
+ server.Start();
+ ASSERT_TRUE(server.IsRunning());
+
+ std::stringstream ss;
+ ss << "grpc+unix://" << server.unix_sock();
+ std::string uri = ss.str();
+
+ std::unique_ptr<FlightClient> client;
+ Location location1;
+ Location location2;
+ ASSERT_OK(Location::Parse(uri, &location1));
+ ASSERT_OK(Location::Parse(uri, &location2));
+ ASSERT_OK(FlightClient::Connect(location1, &client));
+ ASSERT_OK(FlightClient::Connect(location2, &client));
+}
+#endif
+
+TEST(TestFlight, RoundTripTypes) {
+ Ticket ticket{"foo"};
+ std::string ticket_serialized;
+ Ticket ticket_deserialized;
+ ASSERT_OK(ticket.SerializeToString(&ticket_serialized));
+ ASSERT_OK(Ticket::Deserialize(ticket_serialized, &ticket_deserialized));
+ ASSERT_EQ(ticket.ticket, ticket_deserialized.ticket);
+
+ FlightDescriptor desc = FlightDescriptor::Command("select * from foo;");
+ std::string desc_serialized;
+ FlightDescriptor desc_deserialized;
+ ASSERT_OK(desc.SerializeToString(&desc_serialized));
+ ASSERT_OK(FlightDescriptor::Deserialize(desc_serialized, &desc_deserialized));
+ ASSERT_TRUE(desc.Equals(desc_deserialized));
+
+ desc = FlightDescriptor::Path({"a", "b", "test.arrow"});
+ ASSERT_OK(desc.SerializeToString(&desc_serialized));
+ ASSERT_OK(FlightDescriptor::Deserialize(desc_serialized, &desc_deserialized));
+ ASSERT_TRUE(desc.Equals(desc_deserialized));
+
+ FlightInfo::Data data;
+ std::shared_ptr<Schema> schema =
+ arrow::schema({field("a", int64()), field("b", int64()), field("c", int64()),
+ field("d", int64())});
+ Location location1, location2, location3;
+ ASSERT_OK(Location::ForGrpcTcp("localhost", 10010, &location1));
+ ASSERT_OK(Location::ForGrpcTls("localhost", 10010, &location2));
+ ASSERT_OK(Location::ForGrpcUnix("/tmp/test.sock", &location3));
+ std::vector<FlightEndpoint> endpoints{FlightEndpoint{ticket, {location1, location2}},
+ FlightEndpoint{ticket, {location3}}};
+ ASSERT_OK(MakeFlightInfo(*schema, desc, endpoints, -1, -1, &data));
+ std::unique_ptr<FlightInfo> info = std::unique_ptr<FlightInfo>(new FlightInfo(data));
+ std::string info_serialized;
+ std::unique_ptr<FlightInfo> info_deserialized;
+ ASSERT_OK(info->SerializeToString(&info_serialized));
+ ASSERT_OK(FlightInfo::Deserialize(info_serialized, &info_deserialized));
+ ASSERT_TRUE(info->descriptor().Equals(info_deserialized->descriptor()));
+ ASSERT_EQ(info->endpoints(), info_deserialized->endpoints());
+ ASSERT_EQ(info->total_records(), info_deserialized->total_records());
+ ASSERT_EQ(info->total_bytes(), info_deserialized->total_bytes());
+}
+
+TEST(TestFlight, RoundtripStatus) {
+ // Make sure status codes round trip through our conversions
+
+ std::shared_ptr<FlightStatusDetail> detail;
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::Internal, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::Internal, detail->code());
+
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::TimedOut, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::TimedOut, detail->code());
+
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::Cancelled, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::Cancelled, detail->code());
+
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::Unauthenticated, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::Unauthenticated, detail->code());
+
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::Unauthorized, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::Unauthorized, detail->code());
+
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::Unavailable, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::Unavailable, detail->code());
+
+ Status status = internal::FromGrpcStatus(
+ internal::ToGrpcStatus(Status::NotImplemented("Sentinel")));
+ ASSERT_TRUE(status.IsNotImplemented());
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
+
+ status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::Invalid("Sentinel")));
+ ASSERT_TRUE(status.IsInvalid());
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
+
+ status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::KeyError("Sentinel")));
+ ASSERT_TRUE(status.IsKeyError());
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
+
+ status =
+ internal::FromGrpcStatus(internal::ToGrpcStatus(Status::AlreadyExists("Sentinel")));
+ ASSERT_TRUE(status.IsAlreadyExists());
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
+}
+
+TEST(TestFlight, GetPort) {
+ Location location;
+ std::unique_ptr<FlightServerBase> server = ExampleTestServer();
+
+ ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location));
+ FlightServerOptions options(location);
+ ASSERT_OK(server->Init(options));
+ ASSERT_GT(server->port(), 0);
+}
+
+// CI environments don't have an IPv6 interface configured
+TEST(TestFlight, DISABLED_IpV6Port) {
+ Location location, location2;
+ std::unique_ptr<FlightServerBase> server = ExampleTestServer();
+
+ ASSERT_OK(Location::ForGrpcTcp("[::1]", 0, &location));
+ FlightServerOptions options(location);
+ ASSERT_OK(server->Init(options));
+ ASSERT_GT(server->port(), 0);
+
+ ASSERT_OK(Location::ForGrpcTcp("[::1]", server->port(), &location2));
+ std::unique_ptr<FlightClient> client;
+ ASSERT_OK(FlightClient::Connect(location2, &client));
+ std::unique_ptr<FlightListing> listing;
+ ASSERT_OK(client->ListFlights(&listing));
+}
+
+TEST(TestFlight, BuilderHook) {
+ Location location;
+ std::unique_ptr<FlightServerBase> server = ExampleTestServer();
+
+ ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location));
+ FlightServerOptions options(location);
+ bool builder_hook_run = false;
+ options.builder_hook = [&builder_hook_run](void* builder) {
+ ASSERT_NE(nullptr, builder);
+ builder_hook_run = true;
+ };
+ ASSERT_OK(server->Init(options));
+ ASSERT_TRUE(builder_hook_run);
+ ASSERT_GT(server->port(), 0);
+ ASSERT_OK(server->Shutdown());
+}
+
+// ----------------------------------------------------------------------
+// Client tests
+
+// Helper to initialize a server and matching client with callbacks to
+// populate options.
+template <typename T, typename... Args>
+Status MakeServer(std::unique_ptr<FlightServerBase>* server,
+ std::unique_ptr<FlightClient>* client,
+ std::function<Status(FlightServerOptions*)> make_server_options,
+ std::function<Status(FlightClientOptions*)> make_client_options,
+ Args&&... server_args) {
+ Location location;
+ RETURN_NOT_OK(Location::ForGrpcTcp("localhost", 0, &location));
+ *server = arrow::internal::make_unique<T>(std::forward<Args>(server_args)...);
+ FlightServerOptions server_options(location);
+ RETURN_NOT_OK(make_server_options(&server_options));
+ RETURN_NOT_OK((*server)->Init(server_options));
+ Location real_location;
+ RETURN_NOT_OK(Location::ForGrpcTcp("localhost", (*server)->port(), &real_location));
+ FlightClientOptions client_options = FlightClientOptions::Defaults();
+ RETURN_NOT_OK(make_client_options(&client_options));
+ return FlightClient::Connect(real_location, client_options, client);
+}
+
+class TestFlightClient : public ::testing::Test {
+ public:
+ void SetUp() {
+ server_ = ExampleTestServer();
+
+ Location location;
+ ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location));
+ FlightServerOptions options(location);
+ ASSERT_OK(server_->Init(options));
+
+ ASSERT_OK(ConnectClient());
+ }
+
+ void TearDown() { ASSERT_OK(server_->Shutdown()); }
+
+ Status ConnectClient() {
+ Location location;
+ RETURN_NOT_OK(Location::ForGrpcTcp("localhost", server_->port(), &location));
+ return FlightClient::Connect(location, &client_);
+ }
+
+ template <typename EndpointCheckFunc>
+ void CheckDoGet(const FlightDescriptor& descr, const BatchVector& expected_batches,
+ EndpointCheckFunc&& check_endpoints) {
+ auto expected_schema = expected_batches[0]->schema();
+
+ std::unique_ptr<FlightInfo> info;
+ ASSERT_OK(client_->GetFlightInfo(descr, &info));
+ check_endpoints(info->endpoints());
+
+ std::shared_ptr<Schema> schema;
+ ipc::DictionaryMemo dict_memo;
+ ASSERT_OK(info->GetSchema(&dict_memo, &schema));
+ AssertSchemaEqual(*expected_schema, *schema);
+
+ // By convention, fetch the first endpoint
+ Ticket ticket = info->endpoints()[0].ticket;
+ CheckDoGet(ticket, expected_batches);
+ }
+
+ void CheckDoGet(const Ticket& ticket, const BatchVector& expected_batches) {
+ auto num_batches = static_cast<int>(expected_batches.size());
+ ASSERT_GE(num_batches, 2);
+
+ std::unique_ptr<FlightStreamReader> stream;
+ ASSERT_OK(client_->DoGet(ticket, &stream));
+
+ std::unique_ptr<FlightStreamReader> stream2;
+ ASSERT_OK(client_->DoGet(ticket, &stream2));
+ ASSERT_OK_AND_ASSIGN(auto reader, MakeRecordBatchReader(std::move(stream2)));
+
+ FlightStreamChunk chunk;
+ std::shared_ptr<RecordBatch> batch;
+ for (int i = 0; i < num_batches; ++i) {
+ ASSERT_OK(stream->Next(&chunk));
+ ASSERT_OK(reader->ReadNext(&batch));
+ ASSERT_NE(nullptr, chunk.data);
+ ASSERT_NE(nullptr, batch);
+#if !defined(__MINGW32__)
+ ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data);
+ ASSERT_BATCHES_EQUAL(*expected_batches[i], *batch);
+#else
+ // In MINGW32, the following code does not have the reproducibility at the LSB
+ // even when this is called twice with the same seed.
+ // As a workaround, use approxEqual
+ // /* from GenerateTypedData in random.cc */
+ // std::default_random_engine rng(seed); // seed = 282475250
+ // std::uniform_real_distribution<double> dist;
+ // std::generate(data, data + n, // n = 10
+ // [&dist, &rng] { return static_cast<ValueType>(dist(rng)); });
+ // /* data[1] = 0x40852cdfe23d3976 or 0x40852cdfe23d3975 */
+ ASSERT_BATCHES_APPROX_EQUAL(*expected_batches[i], *chunk.data);
+ ASSERT_BATCHES_APPROX_EQUAL(*expected_batches[i], *batch);
+#endif
+ }
+
+ // Stream exhausted
+ ASSERT_OK(stream->Next(&chunk));
+ ASSERT_OK(reader->ReadNext(&batch));
+ ASSERT_EQ(nullptr, chunk.data);
+ ASSERT_EQ(nullptr, batch);
+ }
+
+ protected:
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+};
+
+class AuthTestServer : public FlightServerBase {
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* result) override {
+ auto buf = Buffer::FromString(context.peer_identity());
+ auto peer = Buffer::FromString(context.peer());
+ *result = std::unique_ptr<ResultStream>(
+ new SimpleResultStream({Result{buf}, Result{peer}}));
+ return Status::OK();
+ }
+};
+
+class TlsTestServer : public FlightServerBase {
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* result) override {
+ auto buf = Buffer::FromString("Hello, world!");
+ *result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
+ return Status::OK();
+ }
+};
+
+class DoPutTestServer : public FlightServerBase {
+ public:
+ Status DoPut(const ServerCallContext& context,
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer) override {
+ descriptor_ = reader->descriptor();
+ return reader->ReadAll(&batches_);
+ }
+
+ protected:
+ FlightDescriptor descriptor_;
+ BatchVector batches_;
+
+ friend class TestDoPut;
+};
+
+class MetadataTestServer : public FlightServerBase {
+ Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr<FlightDataStream>* data_stream) override {
+ BatchVector batches;
+ if (request.ticket == "dicts") {
+ RETURN_NOT_OK(ExampleDictBatches(&batches));
+ } else if (request.ticket == "floats") {
+ RETURN_NOT_OK(ExampleFloatBatches(&batches));
+ } else {
+ RETURN_NOT_OK(ExampleIntBatches(&batches));
+ }
+ std::shared_ptr<RecordBatchReader> batch_reader =
+ std::make_shared<BatchIterator>(batches[0]->schema(), batches);
+
+ *data_stream = std::unique_ptr<FlightDataStream>(new NumberingStream(
+ std::unique_ptr<FlightDataStream>(new RecordBatchStream(batch_reader))));
+ return Status::OK();
+ }
+
+ Status DoPut(const ServerCallContext& context,
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer) override {
+ FlightStreamChunk chunk;
+ int counter = 0;
+ while (true) {
+ RETURN_NOT_OK(reader->Next(&chunk));
+ if (chunk.data == nullptr) break;
+ if (chunk.app_metadata == nullptr) {
+ return Status::Invalid("Expected application metadata to be provided");
+ }
+ if (std::to_string(counter) != chunk.app_metadata->ToString()) {
+ return Status::Invalid("Expected metadata value: " + std::to_string(counter) +
+ " but got: " + chunk.app_metadata->ToString());
+ }
+ auto metadata = Buffer::FromString(std::to_string(counter));
+ RETURN_NOT_OK(writer->WriteMetadata(*metadata));
+ counter++;
+ }
+ return Status::OK();
+ }
+};
+
+// Server for testing custom IPC options support
+class OptionsTestServer : public FlightServerBase {
+ Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr<FlightDataStream>* data_stream) override {
+ BatchVector batches;
+ RETURN_NOT_OK(ExampleNestedBatches(&batches));
+ auto reader = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
+ *data_stream = std::unique_ptr<FlightDataStream>(new RecordBatchStream(reader));
+ return Status::OK();
+ }
+
+ // Just echo the number of batches written. The client will try to
+ // call this method with different write options set.
+ Status DoPut(const ServerCallContext& context,
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer) override {
+ FlightStreamChunk chunk;
+ int counter = 0;
+ while (true) {
+ RETURN_NOT_OK(reader->Next(&chunk));
+ if (chunk.data == nullptr) break;
+ counter++;
+ }
+ auto metadata = Buffer::FromString(std::to_string(counter));
+ return writer->WriteMetadata(*metadata);
+ }
+
+ // Echo client data, but with write options set to limit the nesting
+ // level.
+ Status DoExchange(const ServerCallContext& context,
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMessageWriter> writer) override {
+ FlightStreamChunk chunk;
+ auto options = ipc::IpcWriteOptions::Defaults();
+ options.max_recursion_depth = 1;
+ bool begun = false;
+ while (true) {
+ RETURN_NOT_OK(reader->Next(&chunk));
+ if (!chunk.data && !chunk.app_metadata) {
+ break;
+ }
+ if (!begun && chunk.data) {
+ begun = true;
+ RETURN_NOT_OK(writer->Begin(chunk.data->schema(), options));
+ }
+ if (chunk.data && chunk.app_metadata) {
+ RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata));
+ } else if (chunk.data) {
+ RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data));
+ } else if (chunk.app_metadata) {
+ RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata));
+ }
+ }
+ return Status::OK();
+ }
+};
+
+class HeaderAuthTestServer : public FlightServerBase {
+ public:
+ Status ListFlights(const ServerCallContext& context, const Criteria* criteria,
+ std::unique_ptr<FlightListing>* listings) override {
+ return Status::OK();
+ }
+};
+
+class TestMetadata : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK(MakeServer<MetadataTestServer>(
+ &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
+ [](FlightClientOptions* options) { return Status::OK(); }));
+ }
+
+ void TearDown() { ASSERT_OK(server_->Shutdown()); }
+
+ protected:
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+};
+
+class TestOptions : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK(MakeServer<OptionsTestServer>(
+ &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
+ [](FlightClientOptions* options) { return Status::OK(); }));
+ }
+
+ void TearDown() { ASSERT_OK(server_->Shutdown()); }
+
+ protected:
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+};
+
+class TestAuthHandler : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK(MakeServer<AuthTestServer>(
+ &server_, &client_,
+ [](FlightServerOptions* options) {
+ options->auth_handler = std::unique_ptr<ServerAuthHandler>(
+ new TestServerAuthHandler("user", "p4ssw0rd"));
+ return Status::OK();
+ },
+ [](FlightClientOptions* options) { return Status::OK(); }));
+ }
+
+ void TearDown() { ASSERT_OK(server_->Shutdown()); }
+
+ protected:
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+};
+
+class TestBasicAuthHandler : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK(MakeServer<AuthTestServer>(
+ &server_, &client_,
+ [](FlightServerOptions* options) {
+ options->auth_handler = std::unique_ptr<ServerAuthHandler>(
+ new TestServerBasicAuthHandler("user", "p4ssw0rd"));
+ return Status::OK();
+ },
+ [](FlightClientOptions* options) { return Status::OK(); }));
+ }
+
+ void TearDown() { ASSERT_OK(server_->Shutdown()); }
+
+ protected:
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+};
+
+class TestDoPut : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK(MakeServer<DoPutTestServer>(
+ &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
+ [](FlightClientOptions* options) { return Status::OK(); }));
+ do_put_server_ = (DoPutTestServer*)server_.get();
+ }
+
+ void TearDown() { ASSERT_OK(server_->Shutdown()); }
+
+ void CheckBatches(FlightDescriptor expected_descriptor,
+ const BatchVector& expected_batches) {
+ ASSERT_TRUE(do_put_server_->descriptor_.Equals(expected_descriptor));
+ ASSERT_EQ(do_put_server_->batches_.size(), expected_batches.size());
+ for (size_t i = 0; i < expected_batches.size(); ++i) {
+ ASSERT_BATCHES_EQUAL(*do_put_server_->batches_[i], *expected_batches[i]);
+ }
+ }
+
+ void CheckDoPut(FlightDescriptor descr, const std::shared_ptr<Schema>& schema,
+ const BatchVector& batches) {
+ std::unique_ptr<FlightStreamWriter> stream;
+ std::unique_ptr<FlightMetadataReader> reader;
+ ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader));
+ for (const auto& batch : batches) {
+ ASSERT_OK(stream->WriteRecordBatch(*batch));
+ }
+ ASSERT_OK(stream->DoneWriting());
+ ASSERT_OK(stream->Close());
+
+ CheckBatches(descr, batches);
+ }
+
+ protected:
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+ DoPutTestServer* do_put_server_;
+};
+
+class TestTls : public ::testing::Test {
+ public:
+ void SetUp() {
+ // Manually initialize gRPC to try to ensure some thread-locals
+ // get initialized.
+ // https://github.com/grpc/grpc/issues/13856
+ // https://github.com/grpc/grpc/issues/20311
+ // In general, gRPC on MacOS struggles with TLS (both in the sense
+ // of thread-locals and encryption)
+ grpc_init();
+
+ server_.reset(new TlsTestServer);
+
+ Location location;
+ ASSERT_OK(Location::ForGrpcTls("localhost", 0, &location));
+ FlightServerOptions options(location);
+ ASSERT_RAISES(UnknownError, server_->Init(options));
+ ASSERT_OK(ExampleTlsCertificates(&options.tls_certificates));
+ ASSERT_OK(server_->Init(options));
+
+ ASSERT_OK(Location::ForGrpcTls("localhost", server_->port(), &location_));
+ ASSERT_OK(ConnectClient());
+ }
+
+ void TearDown() {
+ ASSERT_OK(server_->Shutdown());
+ grpc_shutdown();
+ }
+
+ Status ConnectClient() {
+ auto options = FlightClientOptions::Defaults();
+ CertKeyPair root_cert;
+ RETURN_NOT_OK(ExampleTlsCertificateRoot(&root_cert));
+ options.tls_root_certs = root_cert.pem_cert;
+ return FlightClient::Connect(location_, options, &client_);
+ }
+
+ protected:
+ Location location_;
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+};
+
+// A server middleware that rejects all calls.
+class RejectServerMiddlewareFactory : public ServerMiddlewareFactory {
+ Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ std::shared_ptr<ServerMiddleware>* middleware) override {
+ return MakeFlightError(FlightStatusCode::Unauthenticated, "All calls are rejected");
+ }
+};
+
+// A server middleware that counts the number of successful and failed
+// calls.
+class CountingServerMiddleware : public ServerMiddleware {
+ public:
+ CountingServerMiddleware(std::atomic<int>* successful, std::atomic<int>* failed)
+ : successful_(successful), failed_(failed) {}
+ void SendingHeaders(AddCallHeaders* outgoing_headers) override {}
+ void CallCompleted(const Status& status) override {
+ if (status.ok()) {
+ ARROW_IGNORE_EXPR((*successful_)++);
+ } else {
+ ARROW_IGNORE_EXPR((*failed_)++);
+ }
+ }
+
+ std::string name() const override { return "CountingServerMiddleware"; }
+
+ private:
+ std::atomic<int>* successful_;
+ std::atomic<int>* failed_;
+};
+
+class CountingServerMiddlewareFactory : public ServerMiddlewareFactory {
+ public:
+ CountingServerMiddlewareFactory() : successful_(0), failed_(0) {}
+
+ Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ std::shared_ptr<ServerMiddleware>* middleware) override {
+ *middleware = std::make_shared<CountingServerMiddleware>(&successful_, &failed_);
+ return Status::OK();
+ }
+
+ std::atomic<int> successful_;
+ std::atomic<int> failed_;
+};
+
+// The current span ID, used to emulate OpenTracing style distributed
+// tracing. Only used for communication between application code and
+// client middleware.
+static thread_local std::string current_span_id = "";
+
+// A server middleware that stores the current span ID, in an
+// emulation of OpenTracing style distributed tracing.
+class TracingServerMiddleware : public ServerMiddleware {
+ public:
+ explicit TracingServerMiddleware(const std::string& current_span_id)
+ : span_id(current_span_id) {}
+ void SendingHeaders(AddCallHeaders* outgoing_headers) override {}
+ void CallCompleted(const Status& status) override {}
+
+ std::string name() const override { return "TracingServerMiddleware"; }
+
+ std::string span_id;
+};
+
+class TracingServerMiddlewareFactory : public ServerMiddlewareFactory {
+ public:
+ TracingServerMiddlewareFactory() {}
+
+ Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ std::shared_ptr<ServerMiddleware>* middleware) override {
+ const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>& iter_pair =
+ incoming_headers.equal_range("x-tracing-span-id");
+ if (iter_pair.first != iter_pair.second) {
+ const util::string_view& value = (*iter_pair.first).second;
+ *middleware = std::make_shared<TracingServerMiddleware>(std::string(value));
+ }
+ return Status::OK();
+ }
+};
+
+// Function to look in CallHeaders for a key that has a value starting with prefix and
+// return the rest of the value after the prefix.
+std::string FindKeyValPrefixInCallHeaders(const CallHeaders& incoming_headers,
+ const std::string& key,
+ const std::string& prefix) {
+ // Lambda function to compare characters without case sensitivity.
+ auto char_compare = [](const char& char1, const char& char2) {
+ return (::toupper(char1) == ::toupper(char2));
+ };
+
+ auto iter = incoming_headers.find(key);
+ if (iter == incoming_headers.end()) {
+ return "";
+ }
+ const std::string val = iter->second.to_string();
+ if (val.size() > prefix.length()) {
+ if (std::equal(val.begin(), val.begin() + prefix.length(), prefix.begin(),
+ char_compare)) {
+ return val.substr(prefix.length());
+ }
+ }
+ return "";
+}
+
+class HeaderAuthServerMiddleware : public ServerMiddleware {
+ public:
+ void SendingHeaders(AddCallHeaders* outgoing_headers) override {
+ outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + kBearerToken);
+ }
+
+ void CallCompleted(const Status& status) override {}
+
+ std::string name() const override { return "HeaderAuthServerMiddleware"; }
+};
+
+void ParseBasicHeader(const CallHeaders& incoming_headers, std::string& username,
+ std::string& password) {
+ std::string encoded_credentials =
+ FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix);
+ std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials));
+ std::getline(decoded_stream, username, ':');
+ std::getline(decoded_stream, password, ':');
+}
+
+// Factory for base64 header authentication testing.
+class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory {
+ public:
+ HeaderAuthServerMiddlewareFactory() {}
+
+ Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ std::shared_ptr<ServerMiddleware>* middleware) override {
+ std::string username, password;
+ ParseBasicHeader(incoming_headers, username, password);
+ if ((username == kValidUsername) && (password == kValidPassword)) {
+ *middleware = std::make_shared<HeaderAuthServerMiddleware>();
+ } else if ((username == kInvalidUsername) && (password == kInvalidPassword)) {
+ return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid credentials");
+ }
+ return Status::OK();
+ }
+};
+
+// A server middleware for validating incoming bearer header authentication.
+class BearerAuthServerMiddleware : public ServerMiddleware {
+ public:
+ explicit BearerAuthServerMiddleware(const CallHeaders& incoming_headers, bool* isValid)
+ : isValid_(isValid) {
+ incoming_headers_ = incoming_headers;
+ }
+
+ void SendingHeaders(AddCallHeaders* outgoing_headers) override {
+ std::string bearer_token =
+ FindKeyValPrefixInCallHeaders(incoming_headers_, kAuthHeader, kBearerPrefix);
+ *isValid_ = (bearer_token == std::string(kBearerToken));
+ }
+
+ void CallCompleted(const Status& status) override {}
+
+ std::string name() const override { return "BearerAuthServerMiddleware"; }
+
+ private:
+ CallHeaders incoming_headers_;
+ bool* isValid_;
+};
+
+// Factory for base64 header authentication testing.
+class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory {
+ public:
+ BearerAuthServerMiddlewareFactory() : isValid_(false) {}
+
+ Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ std::shared_ptr<ServerMiddleware>* middleware) override {
+ const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>& iter_pair =
+ incoming_headers.equal_range(kAuthHeader);
+ if (iter_pair.first != iter_pair.second) {
+ *middleware =
+ std::make_shared<BearerAuthServerMiddleware>(incoming_headers, &isValid_);
+ }
+ return Status::OK();
+ }
+
+ bool GetIsValid() { return isValid_; }
+
+ private:
+ bool isValid_;
+};
+
+// A client middleware that adds a thread-local "request ID" to
+// outgoing calls as a header, and keeps track of the status of
+// completed calls. NOT thread-safe.
+class PropagatingClientMiddleware : public ClientMiddleware {
+ public:
+ explicit PropagatingClientMiddleware(std::atomic<int>* received_headers,
+ std::vector<Status>* recorded_status)
+ : received_headers_(received_headers), recorded_status_(recorded_status) {}
+
+ void SendingHeaders(AddCallHeaders* outgoing_headers) {
+ // Pick up the span ID from thread locals. We have to use a
+ // thread-local for communication, since we aren't even
+ // instantiated until after the application code has already
+ // started the call (and so there's no chance for application code
+ // to pass us parameters directly).
+ outgoing_headers->AddHeader("x-tracing-span-id", current_span_id);
+ }
+
+ void ReceivedHeaders(const CallHeaders& incoming_headers) { (*received_headers_)++; }
+
+ void CallCompleted(const Status& status) { recorded_status_->push_back(status); }
+
+ private:
+ std::atomic<int>* received_headers_;
+ std::vector<Status>* recorded_status_;
+};
+
+class PropagatingClientMiddlewareFactory : public ClientMiddlewareFactory {
+ public:
+ void StartCall(const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware) {
+ recorded_calls_.push_back(info.method);
+ *middleware = arrow::internal::make_unique<PropagatingClientMiddleware>(
+ &received_headers_, &recorded_status_);
+ }
+
+ void Reset() {
+ recorded_calls_.clear();
+ recorded_status_.clear();
+ received_headers_.fetch_and(0);
+ }
+
+ std::vector<FlightMethod> recorded_calls_;
+ std::vector<Status> recorded_status_;
+ std::atomic<int> received_headers_;
+};
+
+class ReportContextTestServer : public FlightServerBase {
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* result) override {
+ std::shared_ptr<Buffer> buf;
+ const ServerMiddleware* middleware = context.GetMiddleware("tracing");
+ if (middleware == nullptr || middleware->name() != "TracingServerMiddleware") {
+ buf = Buffer::FromString("");
+ } else {
+ buf = Buffer::FromString(((const TracingServerMiddleware*)middleware)->span_id);
+ }
+ *result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
+ return Status::OK();
+ }
+};
+
+class ErrorMiddlewareServer : public FlightServerBase {
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* result) override {
+ std::string msg = "error_message";
+ auto buf = Buffer::FromString("");
+
+ std::shared_ptr<FlightStatusDetail> flightStatusDetail(
+ new FlightStatusDetail(FlightStatusCode::Failed, msg));
+ *result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
+ return Status(StatusCode::ExecutionError, "test failed", flightStatusDetail);
+ }
+};
+
+class PropagatingTestServer : public FlightServerBase {
+ public:
+ explicit PropagatingTestServer(std::unique_ptr<FlightClient> client)
+ : client_(std::move(client)) {}
+
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* result) override {
+ const ServerMiddleware* middleware = context.GetMiddleware("tracing");
+ if (middleware == nullptr || middleware->name() != "TracingServerMiddleware") {
+ current_span_id = "";
+ } else {
+ current_span_id = ((const TracingServerMiddleware*)middleware)->span_id;
+ }
+
+ return client_->DoAction(action, result);
+ }
+
+ private:
+ std::unique_ptr<FlightClient> client_;
+};
+
+class TestRejectServerMiddleware : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK(MakeServer<MetadataTestServer>(
+ &server_, &client_,
+ [](FlightServerOptions* options) {
+ options->middleware.push_back(
+ {"reject", std::make_shared<RejectServerMiddlewareFactory>()});
+ return Status::OK();
+ },
+ [](FlightClientOptions* options) { return Status::OK(); }));
+ }
+
+ void TearDown() { ASSERT_OK(server_->Shutdown()); }
+
+ protected:
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+};
+
+class TestCountingServerMiddleware : public ::testing::Test {
+ public:
+ void SetUp() {
+ request_counter_ = std::make_shared<CountingServerMiddlewareFactory>();
+ ASSERT_OK(MakeServer<MetadataTestServer>(
+ &server_, &client_,
+ [&](FlightServerOptions* options) {
+ options->middleware.push_back({"request_counter", request_counter_});
+ return Status::OK();
+ },
+ [](FlightClientOptions* options) { return Status::OK(); }));
+ }
+
+ void TearDown() { ASSERT_OK(server_->Shutdown()); }
+
+ protected:
+ std::shared_ptr<CountingServerMiddlewareFactory> request_counter_;
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+};
+
+// Setup for this test is 2 servers
+// 1. Client makes request to server A with a request ID set
+// 2. server A extracts the request ID and makes a request to server B
+// with the same request ID set
+// 3. server B extracts the request ID and sends it back
+// 4. server A returns the response of server B
+// 5. Client validates the response
+class TestPropagatingMiddleware : public ::testing::Test {
+ public:
+ void SetUp() {
+ server_middleware_ = std::make_shared<TracingServerMiddlewareFactory>();
+ second_client_middleware_ = std::make_shared<PropagatingClientMiddlewareFactory>();
+ client_middleware_ = std::make_shared<PropagatingClientMiddlewareFactory>();
+
+ std::unique_ptr<FlightClient> server_client;
+ ASSERT_OK(MakeServer<ReportContextTestServer>(
+ &second_server_, &server_client,
+ [&](FlightServerOptions* options) {
+ options->middleware.push_back({"tracing", server_middleware_});
+ return Status::OK();
+ },
+ [&](FlightClientOptions* options) {
+ options->middleware.push_back(second_client_middleware_);
+ return Status::OK();
+ }));
+
+ ASSERT_OK(MakeServer<PropagatingTestServer>(
+ &first_server_, &client_,
+ [&](FlightServerOptions* options) {
+ options->middleware.push_back({"tracing", server_middleware_});
+ return Status::OK();
+ },
+ [&](FlightClientOptions* options) {
+ options->middleware.push_back(client_middleware_);
+ return Status::OK();
+ },
+ std::move(server_client)));
+ }
+
+ void ValidateStatus(const Status& status, const FlightMethod& method) {
+ ASSERT_EQ(1, client_middleware_->received_headers_);
+ ASSERT_EQ(method, client_middleware_->recorded_calls_.at(0));
+ ASSERT_EQ(status.code(), client_middleware_->recorded_status_.at(0).code());
+ }
+
+ void TearDown() {
+ ASSERT_OK(first_server_->Shutdown());
+ ASSERT_OK(second_server_->Shutdown());
+ }
+
+ void CheckHeader(const std::string& header, const std::string& value,
+ const CallHeaders::const_iterator& it) {
+ // Construct a string_view before comparison to satisfy MSVC
+ util::string_view header_view(header.data(), header.length());
+ util::string_view value_view(value.data(), value.length());
+ ASSERT_EQ(header_view, (*it).first);
+ ASSERT_EQ(value_view, (*it).second);
+ }
+
+ protected:
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> first_server_;
+ std::unique_ptr<FlightServerBase> second_server_;
+ std::shared_ptr<TracingServerMiddlewareFactory> server_middleware_;
+ std::shared_ptr<PropagatingClientMiddlewareFactory> second_client_middleware_;
+ std::shared_ptr<PropagatingClientMiddlewareFactory> client_middleware_;
+};
+
+class TestErrorMiddleware : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK(MakeServer<ErrorMiddlewareServer>(
+ &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
+ [](FlightClientOptions* options) { return Status::OK(); }));
+ }
+
+ void TearDown() { ASSERT_OK(server_->Shutdown()); }
+
+ protected:
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+};
+
+class TestBasicHeaderAuthMiddleware : public ::testing::Test {
+ public:
+ void SetUp() {
+ header_middleware_ = std::make_shared<HeaderAuthServerMiddlewareFactory>();
+ bearer_middleware_ = std::make_shared<BearerAuthServerMiddlewareFactory>();
+ std::pair<std::string, std::string> bearer = make_pair(
+ kAuthHeader, std::string(kBearerPrefix) + " " + std::string(kBearerToken));
+ ASSERT_OK(MakeServer<HeaderAuthTestServer>(
+ &server_, &client_,
+ [&](FlightServerOptions* options) {
+ options->auth_handler =
+ std::unique_ptr<ServerAuthHandler>(new NoOpAuthHandler());
+ options->middleware.push_back({"header-auth-server", header_middleware_});
+ options->middleware.push_back({"bearer-auth-server", bearer_middleware_});
+ return Status::OK();
+ },
+ [&](FlightClientOptions* options) { return Status::OK(); }));
+ }
+
+ void RunValidClientAuth() {
+ arrow::Result<std::pair<std::string, std::string>> bearer_result =
+ client_->AuthenticateBasicToken({}, kValidUsername, kValidPassword);
+ ASSERT_OK(bearer_result.status());
+ ASSERT_EQ(bearer_result.ValueOrDie().first, kAuthHeader);
+ ASSERT_EQ(bearer_result.ValueOrDie().second,
+ (std::string(kBearerPrefix) + kBearerToken));
+ std::unique_ptr<FlightListing> listing;
+ FlightCallOptions call_options;
+ call_options.headers.push_back(bearer_result.ValueOrDie());
+ ASSERT_OK(client_->ListFlights(call_options, {}, &listing));
+ ASSERT_TRUE(bearer_middleware_->GetIsValid());
+ }
+
+ void RunInvalidClientAuth() {
+ arrow::Result<std::pair<std::string, std::string>> bearer_result =
+ client_->AuthenticateBasicToken({}, kInvalidUsername, kInvalidPassword);
+ ASSERT_RAISES(IOError, bearer_result.status());
+ ASSERT_THAT(bearer_result.status().message(),
+ ::testing::HasSubstr("Invalid credentials"));
+ }
+
+ void TearDown() { ASSERT_OK(server_->Shutdown()); }
+
+ protected:
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+ std::shared_ptr<HeaderAuthServerMiddlewareFactory> header_middleware_;
+ std::shared_ptr<BearerAuthServerMiddlewareFactory> bearer_middleware_;
+};
+
+// This test keeps an internal cookie cache and compares that with the middleware.
+class TestCookieMiddleware : public ::testing::Test {
+ public:
+ // Setup function creates middleware factory and starts it up.
+ void SetUp() {
+ factory_ = GetCookieFactory();
+ CallInfo callInfo;
+ factory_->StartCall(callInfo, &middleware_);
+ }
+
+ // Function to add incoming cookies to middleware and validate them.
+ void AddAndValidate(const std::string& incoming_cookie) {
+ // Add cookie
+ CallHeaders call_headers;
+ call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"),
+ arrow::util::string_view(incoming_cookie)));
+ middleware_->ReceivedHeaders(call_headers);
+ expected_cookie_cache_.UpdateCachedCookies(call_headers);
+
+ // Get cookie from middleware.
+ TestCallHeaders add_call_headers;
+ middleware_->SendingHeaders(&add_call_headers);
+ const std::string actual_cookies = add_call_headers.GetCookies();
+
+ // Validate cookie
+ const std::string expected_cookies = expected_cookie_cache_.GetValidCookiesAsString();
+ const std::vector<std::string> split_expected_cookies =
+ SplitCookies(expected_cookies);
+ const std::vector<std::string> split_actual_cookies = SplitCookies(actual_cookies);
+ EXPECT_EQ(split_expected_cookies, split_actual_cookies);
+ }
+
+ // Function to take a list of cookies and split them into a vector of individual
+ // cookies. This is done because the cookie cache is a map so ordering is not
+ // necessarily consistent.
+ static std::vector<std::string> SplitCookies(const std::string& cookies) {
+ std::vector<std::string> split_cookies;
+ std::string::size_type pos1 = 0;
+ std::string::size_type pos2 = 0;
+ while ((pos2 = cookies.find(';', pos1)) != std::string::npos) {
+ split_cookies.push_back(
+ arrow::internal::TrimString(cookies.substr(pos1, pos2 - pos1)));
+ pos1 = pos2 + 1;
+ }
+ if (pos1 < cookies.size()) {
+ split_cookies.push_back(arrow::internal::TrimString(cookies.substr(pos1)));
+ }
+ std::sort(split_cookies.begin(), split_cookies.end());
+ return split_cookies;
+ }
+
+ protected:
+ // Class to allow testing of the call headers.
+ class TestCallHeaders : public AddCallHeaders {
+ public:
+ TestCallHeaders() {}
+ ~TestCallHeaders() {}
+
+ // Function to add cookie header.
+ void AddHeader(const std::string& key, const std::string& value) {
+ ASSERT_EQ(key, "cookie");
+ outbound_cookie_ = value;
+ }
+
+ // Function to get outgoing cookie.
+ std::string GetCookies() { return outbound_cookie_; }
+
+ private:
+ std::string outbound_cookie_;
+ };
+
+ internal::CookieCache expected_cookie_cache_;
+ std::unique_ptr<ClientMiddleware> middleware_;
+ std::shared_ptr<ClientMiddlewareFactory> factory_;
+};
+
+// This test is used to test the parsing capabilities of the cookie framework.
+class TestCookieParsing : public ::testing::Test {
+ public:
+ void VerifyParseCookie(const std::string& cookie_str, bool expired) {
+ internal::Cookie cookie = internal::Cookie::parse(cookie_str);
+ EXPECT_EQ(expired, cookie.IsExpired());
+ }
+
+ void VerifyCookieName(const std::string& cookie_str, const std::string& name) {
+ internal::Cookie cookie = internal::Cookie::parse(cookie_str);
+ EXPECT_EQ(name, cookie.GetName());
+ }
+
+ void VerifyCookieString(const std::string& cookie_str,
+ const std::string& cookie_as_string) {
+ internal::Cookie cookie = internal::Cookie::parse(cookie_str);
+ EXPECT_EQ(cookie_as_string, cookie.AsCookieString());
+ }
+
+ void VerifyCookieDateConverson(std::string date, const std::string& converted_date) {
+ internal::Cookie::ConvertCookieDate(&date);
+ EXPECT_EQ(converted_date, date);
+ }
+
+ void VerifyCookieAttributeParsing(
+ const std::string cookie_str, std::string::size_type start_pos,
+ const util::optional<std::pair<std::string, std::string>> cookie_attribute,
+ const std::string::size_type start_pos_after) {
+ util::optional<std::pair<std::string, std::string>> attr =
+ internal::Cookie::ParseCookieAttribute(cookie_str, &start_pos);
+
+ if (cookie_attribute == util::nullopt) {
+ EXPECT_EQ(cookie_attribute, attr);
+ } else {
+ EXPECT_EQ(cookie_attribute.value(), attr.value());
+ }
+ EXPECT_EQ(start_pos_after, start_pos);
+ }
+
+ void AddCookieVerifyCache(const std::vector<std::string>& cookies,
+ const std::string& expected_cookies) {
+ internal::CookieCache cookie_cache;
+ for (auto& cookie : cookies) {
+ // Add cookie
+ CallHeaders call_headers;
+ call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"),
+ arrow::util::string_view(cookie)));
+ cookie_cache.UpdateCachedCookies(call_headers);
+ }
+ const std::string actual_cookies = cookie_cache.GetValidCookiesAsString();
+ const std::vector<std::string> actual_split_cookies =
+ TestCookieMiddleware::SplitCookies(actual_cookies);
+ const std::vector<std::string> expected_split_cookies =
+ TestCookieMiddleware::SplitCookies(expected_cookies);
+ }
+};
+
+TEST_F(TestErrorMiddleware, TestMetadata) {
+ Action action;
+ std::unique_ptr<ResultStream> stream;
+
+ // Run action1
+ action.type = "action1";
+
+ action.body = Buffer::FromString("action1-content");
+ Status s = client_->DoAction(action, &stream);
+ ASSERT_FALSE(s.ok());
+ std::shared_ptr<FlightStatusDetail> flightStatusDetail =
+ FlightStatusDetail::UnwrapStatus(s);
+ ASSERT_TRUE(flightStatusDetail);
+ ASSERT_EQ(flightStatusDetail->extra_info(), "error_message");
+}
+
+TEST_F(TestFlightClient, ListFlights) {
+ std::unique_ptr<FlightListing> listing;
+ ASSERT_OK(client_->ListFlights(&listing));
+ ASSERT_TRUE(listing != nullptr);
+
+ std::vector<FlightInfo> flights = ExampleFlightInfo();
+
+ std::unique_ptr<FlightInfo> info;
+ for (const FlightInfo& flight : flights) {
+ ASSERT_OK(listing->Next(&info));
+ AssertEqual(flight, *info);
+ }
+ ASSERT_OK(listing->Next(&info));
+ ASSERT_TRUE(info == nullptr);
+
+ ASSERT_OK(listing->Next(&info));
+ ASSERT_TRUE(info == nullptr);
+}
+
+TEST_F(TestFlightClient, ListFlightsWithCriteria) {
+ std::unique_ptr<FlightListing> listing;
+ ASSERT_OK(client_->ListFlights(FlightCallOptions(), {"foo"}, &listing));
+ std::unique_ptr<FlightInfo> info;
+ ASSERT_OK(listing->Next(&info));
+ ASSERT_TRUE(info == nullptr);
+}
+
+TEST_F(TestFlightClient, GetFlightInfo) {
+ auto descr = FlightDescriptor::Path({"examples", "ints"});
+ std::unique_ptr<FlightInfo> info;
+
+ ASSERT_OK(client_->GetFlightInfo(descr, &info));
+ ASSERT_NE(info, nullptr);
+
+ std::vector<FlightInfo> flights = ExampleFlightInfo();
+ AssertEqual(flights[0], *info);
+}
+
+TEST_F(TestFlightClient, GetSchema) {
+ auto descr = FlightDescriptor::Path({"examples", "ints"});
+ std::unique_ptr<SchemaResult> schema_result;
+ std::shared_ptr<Schema> schema;
+ ipc::DictionaryMemo dict_memo;
+
+ ASSERT_OK(client_->GetSchema(descr, &schema_result));
+ ASSERT_NE(schema_result, nullptr);
+ ASSERT_OK(schema_result->GetSchema(&dict_memo, &schema));
+}
+
+TEST_F(TestFlightClient, GetFlightInfoNotFound) {
+ auto descr = FlightDescriptor::Path({"examples", "things"});
+ std::unique_ptr<FlightInfo> info;
+ // XXX Ideally should be Invalid (or KeyError), but gRPC doesn't support
+ // multiple error codes.
+ auto st = client_->GetFlightInfo(descr, &info);
+ ASSERT_RAISES(Invalid, st);
+ ASSERT_NE(st.message().find("Flight not found"), std::string::npos);
+}
+
+TEST_F(TestFlightClient, DoGetInts) {
+ auto descr = FlightDescriptor::Path({"examples", "ints"});
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleIntBatches(&expected_batches));
+
+ auto check_endpoints = [](const std::vector<FlightEndpoint>& endpoints) {
+ // Two endpoints in the example FlightInfo
+ ASSERT_EQ(2, endpoints.size());
+ AssertEqual(Ticket{"ticket-ints-1"}, endpoints[0].ticket);
+ };
+
+ CheckDoGet(descr, expected_batches, check_endpoints);
+}
+
+TEST_F(TestFlightClient, DoGetFloats) {
+ auto descr = FlightDescriptor::Path({"examples", "floats"});
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleFloatBatches(&expected_batches));
+
+ auto check_endpoints = [](const std::vector<FlightEndpoint>& endpoints) {
+ // One endpoint in the example FlightInfo
+ ASSERT_EQ(1, endpoints.size());
+ AssertEqual(Ticket{"ticket-floats-1"}, endpoints[0].ticket);
+ };
+
+ CheckDoGet(descr, expected_batches, check_endpoints);
+}
+
+TEST_F(TestFlightClient, DoGetDicts) {
+ auto descr = FlightDescriptor::Path({"examples", "dicts"});
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleDictBatches(&expected_batches));
+
+ auto check_endpoints = [](const std::vector<FlightEndpoint>& endpoints) {
+ // One endpoint in the example FlightInfo
+ ASSERT_EQ(1, endpoints.size());
+ AssertEqual(Ticket{"ticket-dicts-1"}, endpoints[0].ticket);
+ };
+
+ CheckDoGet(descr, expected_batches, check_endpoints);
+}
+
+// Ensure the gRPC client is configured to allow large messages
+// Tests a 32 MiB batch
+TEST_F(TestFlightClient, DoGetLargeBatch) {
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleLargeBatches(&expected_batches));
+ Ticket ticket{"ticket-large-batch-1"};
+ CheckDoGet(ticket, expected_batches);
+}
+
+TEST_F(TestFlightClient, FlightDataOverflowServerBatch) {
+ // Regression test for ARROW-13253
+ // N.B. this is rather a slow and memory-hungry test
+ {
+ // DoGet: check for overflow on large batch
+ Ticket ticket{"ARROW-13253-DoGet-Batch"};
+ std::unique_ptr<FlightStreamReader> stream;
+ ASSERT_OK(client_->DoGet(ticket, &stream));
+ FlightStreamChunk chunk;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
+ stream->Next(&chunk));
+ }
+ {
+ // DoExchange: check for overflow on large batch from server
+ auto descr = FlightDescriptor::Command("large_batch");
+ std::unique_ptr<FlightStreamReader> reader;
+ std::unique_ptr<FlightStreamWriter> writer;
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ BatchVector batches;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
+ reader->ReadAll(&batches));
+ }
+}
+
+TEST_F(TestFlightClient, FlightDataOverflowClientBatch) {
+ ASSERT_OK_AND_ASSIGN(auto batch, VeryLargeBatch());
+ {
+ // DoPut: check for overflow on large batch
+ std::unique_ptr<FlightStreamWriter> stream;
+ std::unique_ptr<FlightMetadataReader> reader;
+ auto descr = FlightDescriptor::Path({""});
+ ASSERT_OK(client_->DoPut(descr, batch->schema(), &stream, &reader));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
+ stream->WriteRecordBatch(*batch));
+ ASSERT_OK(stream->Close());
+ }
+ {
+ // DoExchange: check for overflow on large batch from client
+ auto descr = FlightDescriptor::Command("counter");
+ std::unique_ptr<FlightStreamReader> reader;
+ std::unique_ptr<FlightStreamWriter> writer;
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK(writer->Begin(batch->schema()));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
+ writer->WriteRecordBatch(*batch));
+ ASSERT_OK(writer->Close());
+ }
+}
+
+TEST_F(TestFlightClient, DoExchange) {
+ auto descr = FlightDescriptor::Command("counter");
+ BatchVector batches;
+ auto a1 = ArrayFromJSON(int32(), "[4, 5, 6, null]");
+ auto schema = arrow::schema({field("f1", a1->type())});
+ batches.push_back(RecordBatch::Make(schema, a1->length(), {a1}));
+ std::unique_ptr<FlightStreamReader> reader;
+ std::unique_ptr<FlightStreamWriter> writer;
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK(writer->Begin(schema));
+ for (const auto& batch : batches) {
+ ASSERT_OK(writer->WriteRecordBatch(*batch));
+ }
+ ASSERT_OK(writer->DoneWriting());
+ FlightStreamChunk chunk;
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_NE(nullptr, chunk.app_metadata);
+ ASSERT_EQ(nullptr, chunk.data);
+ ASSERT_EQ("1", chunk.app_metadata->ToString());
+ ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema());
+ AssertSchemaEqual(schema, server_schema);
+ for (const auto& batch : batches) {
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_BATCHES_EQUAL(*batch, *chunk.data);
+ }
+ ASSERT_OK(writer->Close());
+}
+
+// Test pure-metadata DoExchange to ensure nothing blocks waiting for
+// schema messages
+TEST_F(TestFlightClient, DoExchangeNoData) {
+ auto descr = FlightDescriptor::Command("counter");
+ std::unique_ptr<FlightStreamReader> reader;
+ std::unique_ptr<FlightStreamWriter> writer;
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK(writer->DoneWriting());
+ FlightStreamChunk chunk;
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_EQ(nullptr, chunk.data);
+ ASSERT_NE(nullptr, chunk.app_metadata);
+ ASSERT_EQ("0", chunk.app_metadata->ToString());
+ ASSERT_OK(writer->Close());
+}
+
+// Test sending a schema without any data, as this hits an edge case
+// in the client-side writer.
+TEST_F(TestFlightClient, DoExchangeWriteOnlySchema) {
+ auto descr = FlightDescriptor::Command("counter");
+ std::unique_ptr<FlightStreamReader> reader;
+ std::unique_ptr<FlightStreamWriter> writer;
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ auto schema = arrow::schema({field("f1", arrow::int32())});
+ ASSERT_OK(writer->Begin(schema));
+ ASSERT_OK(writer->WriteMetadata(Buffer::FromString("foo")));
+ ASSERT_OK(writer->DoneWriting());
+ FlightStreamChunk chunk;
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_EQ(nullptr, chunk.data);
+ ASSERT_NE(nullptr, chunk.app_metadata);
+ ASSERT_EQ("0", chunk.app_metadata->ToString());
+ ASSERT_OK(writer->Close());
+}
+
+// Emulate DoGet
+TEST_F(TestFlightClient, DoExchangeGet) {
+ auto descr = FlightDescriptor::Command("get");
+ std::unique_ptr<FlightStreamReader> reader;
+ std::unique_ptr<FlightStreamWriter> writer;
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema());
+ AssertSchemaEqual(*ExampleIntSchema(), *server_schema);
+ BatchVector batches;
+ ASSERT_OK(ExampleIntBatches(&batches));
+ FlightStreamChunk chunk;
+ for (const auto& batch : batches) {
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_NE(nullptr, chunk.data);
+ AssertBatchesEqual(*batch, *chunk.data);
+ }
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_EQ(nullptr, chunk.data);
+ ASSERT_EQ(nullptr, chunk.app_metadata);
+ ASSERT_OK(writer->Close());
+}
+
+// Emulate DoPut
+TEST_F(TestFlightClient, DoExchangePut) {
+ auto descr = FlightDescriptor::Command("put");
+ std::unique_ptr<FlightStreamReader> reader;
+ std::unique_ptr<FlightStreamWriter> writer;
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK(writer->Begin(ExampleIntSchema()));
+ BatchVector batches;
+ ASSERT_OK(ExampleIntBatches(&batches));
+ for (const auto& batch : batches) {
+ ASSERT_OK(writer->WriteRecordBatch(*batch));
+ }
+ ASSERT_OK(writer->DoneWriting());
+ FlightStreamChunk chunk;
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_NE(nullptr, chunk.app_metadata);
+ AssertBufferEqual(*chunk.app_metadata, "done");
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_EQ(nullptr, chunk.data);
+ ASSERT_EQ(nullptr, chunk.app_metadata);
+ ASSERT_OK(writer->Close());
+}
+
+// Test the echo server
+TEST_F(TestFlightClient, DoExchangeEcho) {
+ auto descr = FlightDescriptor::Command("echo");
+ std::unique_ptr<FlightStreamReader> reader;
+ std::unique_ptr<FlightStreamWriter> writer;
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK(writer->Begin(ExampleIntSchema()));
+ BatchVector batches;
+ FlightStreamChunk chunk;
+ ASSERT_OK(ExampleIntBatches(&batches));
+ for (const auto& batch : batches) {
+ ASSERT_OK(writer->WriteRecordBatch(*batch));
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_NE(nullptr, chunk.data);
+ ASSERT_EQ(nullptr, chunk.app_metadata);
+ AssertBatchesEqual(*batch, *chunk.data);
+ }
+ for (int i = 0; i < 10; i++) {
+ const auto buf = Buffer::FromString(std::to_string(i));
+ ASSERT_OK(writer->WriteMetadata(buf));
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_EQ(nullptr, chunk.data);
+ ASSERT_NE(nullptr, chunk.app_metadata);
+ AssertBufferEqual(*buf, *chunk.app_metadata);
+ }
+ int index = 0;
+ for (const auto& batch : batches) {
+ const auto buf = Buffer::FromString(std::to_string(index));
+ ASSERT_OK(writer->WriteWithMetadata(*batch, buf));
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_NE(nullptr, chunk.data);
+ ASSERT_NE(nullptr, chunk.app_metadata);
+ AssertBatchesEqual(*batch, *chunk.data);
+ AssertBufferEqual(*buf, *chunk.app_metadata);
+ index++;
+ }
+ ASSERT_OK(writer->DoneWriting());
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_EQ(nullptr, chunk.data);
+ ASSERT_EQ(nullptr, chunk.app_metadata);
+ ASSERT_OK(writer->Close());
+}
+
+// Test interleaved reading/writing
+TEST_F(TestFlightClient, DoExchangeTotal) {
+ auto descr = FlightDescriptor::Command("total");
+ std::unique_ptr<FlightStreamReader> reader;
+ std::unique_ptr<FlightStreamWriter> writer;
+ {
+ auto a1 = ArrayFromJSON(arrow::int32(), "[4, 5, 6, null]");
+ auto schema = arrow::schema({field("f1", a1->type())});
+ // XXX: as noted in flight/client.cc, Begin() is lazy and the
+ // schema message won't be written until some data is also
+ // written. There's also timing issues; hence we check each status
+ // here.
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Field is not INT64: f1"), ([&]() {
+ RETURN_NOT_OK(client_->DoExchange(descr, &writer, &reader));
+ RETURN_NOT_OK(writer->Begin(schema));
+ auto batch = RecordBatch::Make(schema, /* num_rows */ 4, {a1});
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+ return writer->Close();
+ })());
+ }
+ {
+ auto a1 = ArrayFromJSON(arrow::int64(), "[1, 2, null, 3]");
+ auto a2 = ArrayFromJSON(arrow::int64(), "[null, 4, 5, 6]");
+ auto schema = arrow::schema({field("f1", a1->type()), field("f2", a2->type())});
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK(writer->Begin(schema));
+ auto batch = RecordBatch::Make(schema, /* num_rows */ 4, {a1, a2});
+ FlightStreamChunk chunk;
+ ASSERT_OK(writer->WriteRecordBatch(*batch));
+ ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema());
+ AssertSchemaEqual(*schema, *server_schema);
+
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_NE(nullptr, chunk.data);
+ auto expected1 = RecordBatch::Make(
+ schema, /* num_rows */ 1,
+ {ArrayFromJSON(arrow::int64(), "[6]"), ArrayFromJSON(arrow::int64(), "[15]")});
+ AssertBatchesEqual(*expected1, *chunk.data);
+
+ ASSERT_OK(writer->WriteRecordBatch(*batch));
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_NE(nullptr, chunk.data);
+ auto expected2 = RecordBatch::Make(
+ schema, /* num_rows */ 1,
+ {ArrayFromJSON(arrow::int64(), "[12]"), ArrayFromJSON(arrow::int64(), "[30]")});
+ AssertBatchesEqual(*expected2, *chunk.data);
+
+ ASSERT_OK(writer->Close());
+ }
+}
+
+// Ensure server errors get propagated no matter what we try
+TEST_F(TestFlightClient, DoExchangeError) {
+ auto descr = FlightDescriptor::Command("error");
+ std::unique_ptr<FlightStreamReader> reader;
+ std::unique_ptr<FlightStreamWriter> writer;
+ {
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ auto status = writer->Close();
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ NotImplemented, ::testing::HasSubstr("Expected error"), writer->Close());
+ }
+ {
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ FlightStreamChunk chunk;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ NotImplemented, ::testing::HasSubstr("Expected error"), reader->Next(&chunk));
+ }
+ {
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ NotImplemented, ::testing::HasSubstr("Expected error"), reader->GetSchema());
+ }
+ // writer->Begin isn't tested here because, as noted in client.cc,
+ // OpenRecordBatchWriter lazily writes the initial message - hence
+ // Begin() won't fail. Additionally, it appears gRPC may buffer
+ // writes - a write won't immediately fail even when the server
+ // immediately fails.
+}
+
+TEST_F(TestFlightClient, ListActions) {
+ std::vector<ActionType> actions;
+ ASSERT_OK(client_->ListActions(&actions));
+
+ std::vector<ActionType> expected = ExampleActionTypes();
+ AssertEqual(expected, actions);
+}
+
+TEST_F(TestFlightClient, DoAction) {
+ Action action;
+ std::unique_ptr<ResultStream> stream;
+ std::unique_ptr<Result> result;
+
+ // Run action1
+ action.type = "action1";
+
+ const std::string action1_value = "action1-content";
+ action.body = Buffer::FromString(action1_value);
+ ASSERT_OK(client_->DoAction(action, &stream));
+
+ for (int i = 0; i < 3; ++i) {
+ ASSERT_OK(stream->Next(&result));
+ std::string expected = action1_value + "-part" + std::to_string(i);
+ ASSERT_EQ(expected, result->body->ToString());
+ }
+
+ // stream consumed
+ ASSERT_OK(stream->Next(&result));
+ ASSERT_EQ(nullptr, result);
+
+ // Run action2, no results
+ action.type = "action2";
+ ASSERT_OK(client_->DoAction(action, &stream));
+
+ ASSERT_OK(stream->Next(&result));
+ ASSERT_EQ(nullptr, result);
+}
+
+TEST_F(TestFlightClient, RoundTripStatus) {
+ const auto descr = FlightDescriptor::Command("status-outofmemory");
+ std::unique_ptr<FlightInfo> info;
+ const auto status = client_->GetFlightInfo(descr, &info);
+ ASSERT_RAISES(OutOfMemory, status);
+}
+
+TEST_F(TestFlightClient, Issue5095) {
+ // Make sure the server-side error message is reflected to the
+ // client
+ Ticket ticket1{"ARROW-5095-fail"};
+ std::unique_ptr<FlightStreamReader> stream;
+ Status status = client_->DoGet(ticket1, &stream);
+ ASSERT_RAISES(UnknownError, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Server-side error"));
+
+ Ticket ticket2{"ARROW-5095-success"};
+ status = client_->DoGet(ticket2, &stream);
+ ASSERT_RAISES(KeyError, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("No data"));
+}
+
+// Test setting generic transport options by configuring gRPC to fail
+// all calls.
+TEST_F(TestFlightClient, GenericOptions) {
+ std::unique_ptr<FlightClient> client;
+ auto options = FlightClientOptions::Defaults();
+ // Set a very low limit at the gRPC layer to fail all calls
+ options.generic_options.emplace_back(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, 4);
+ Location location;
+ ASSERT_OK(Location::ForGrpcTcp("localhost", server_->port(), &location));
+ ASSERT_OK(FlightClient::Connect(location, options, &client));
+ auto descr = FlightDescriptor::Path({"examples", "ints"});
+ std::unique_ptr<SchemaResult> schema_result;
+ std::shared_ptr<Schema> schema;
+ ipc::DictionaryMemo dict_memo;
+ auto status = client->GetSchema(descr, &schema_result);
+ ASSERT_RAISES(Invalid, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("resource exhausted"));
+}
+
+TEST_F(TestFlightClient, TimeoutFires) {
+ // Server does not exist on this port, so call should fail
+ std::unique_ptr<FlightClient> client;
+ Location location;
+ ASSERT_OK(Location::ForGrpcTcp("localhost", 30001, &location));
+ ASSERT_OK(FlightClient::Connect(location, &client));
+ FlightCallOptions options;
+ options.timeout = TimeoutDuration{0.2};
+ std::unique_ptr<FlightInfo> info;
+ auto start = std::chrono::system_clock::now();
+ Status status = client->GetFlightInfo(options, FlightDescriptor{}, &info);
+ auto end = std::chrono::system_clock::now();
+#ifdef ARROW_WITH_TIMING_TESTS
+ EXPECT_LE(end - start, std::chrono::milliseconds{400});
+#else
+ ARROW_UNUSED(end - start);
+#endif
+ ASSERT_RAISES(IOError, status);
+}
+
+TEST_F(TestFlightClient, NoTimeout) {
+ // Call should complete quickly, so timeout should not fire
+ FlightCallOptions options;
+ options.timeout = TimeoutDuration{5.0}; // account for slow server process startup
+ std::unique_ptr<FlightInfo> info;
+ auto start = std::chrono::system_clock::now();
+ auto descriptor = FlightDescriptor::Path({"examples", "ints"});
+ Status status = client_->GetFlightInfo(options, descriptor, &info);
+ auto end = std::chrono::system_clock::now();
+#ifdef ARROW_WITH_TIMING_TESTS
+ EXPECT_LE(end - start, std::chrono::milliseconds{600});
+#else
+ ARROW_UNUSED(end - start);
+#endif
+ ASSERT_OK(status);
+ ASSERT_NE(nullptr, info);
+}
+
+TEST_F(TestDoPut, DoPutInts) {
+ auto descr = FlightDescriptor::Path({"ints"});
+ BatchVector batches;
+ auto a0 = ArrayFromJSON(int8(), "[0, 1, 127, -128, null]");
+ auto a1 = ArrayFromJSON(uint8(), "[0, 1, 127, 255, null]");
+ auto a2 = ArrayFromJSON(int16(), "[0, 258, 32767, -32768, null]");
+ auto a3 = ArrayFromJSON(uint16(), "[0, 258, 32767, 65535, null]");
+ auto a4 = ArrayFromJSON(int32(), "[0, 65538, 2147483647, -2147483648, null]");
+ auto a5 = ArrayFromJSON(uint32(), "[0, 65538, 2147483647, 4294967295, null]");
+ auto a6 = ArrayFromJSON(
+ int64(), "[0, 4294967298, 9223372036854775807, -9223372036854775808, null]");
+ auto a7 = ArrayFromJSON(
+ uint64(), "[0, 4294967298, 9223372036854775807, 18446744073709551615, null]");
+ auto schema = arrow::schema({field("f0", a0->type()), field("f1", a1->type()),
+ field("f2", a2->type()), field("f3", a3->type()),
+ field("f4", a4->type()), field("f5", a5->type()),
+ field("f6", a6->type()), field("f7", a7->type())});
+ batches.push_back(
+ RecordBatch::Make(schema, a0->length(), {a0, a1, a2, a3, a4, a5, a6, a7}));
+
+ CheckDoPut(descr, schema, batches);
+}
+
+TEST_F(TestDoPut, DoPutFloats) {
+ auto descr = FlightDescriptor::Path({"floats"});
+ BatchVector batches;
+ auto a0 = ArrayFromJSON(float32(), "[0, 1.2, -3.4, 5.6, null]");
+ auto a1 = ArrayFromJSON(float64(), "[0, 1.2, -3.4, 5.6, null]");
+ auto schema = arrow::schema({field("f0", a0->type()), field("f1", a1->type())});
+ batches.push_back(RecordBatch::Make(schema, a0->length(), {a0, a1}));
+
+ CheckDoPut(descr, schema, batches);
+}
+
+TEST_F(TestDoPut, DoPutEmptyBatch) {
+ // Sending and receiving a 0-sized batch shouldn't fail
+ auto descr = FlightDescriptor::Path({"ints"});
+ BatchVector batches;
+ auto a1 = ArrayFromJSON(int32(), "[]");
+ auto schema = arrow::schema({field("f1", a1->type())});
+ batches.push_back(RecordBatch::Make(schema, a1->length(), {a1}));
+
+ CheckDoPut(descr, schema, batches);
+}
+
+TEST_F(TestDoPut, DoPutDicts) {
+ auto descr = FlightDescriptor::Path({"dicts"});
+ BatchVector batches;
+ auto dict_values = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"quux\"]");
+ auto ty = dictionary(int8(), dict_values->type());
+ auto schema = arrow::schema({field("f1", ty)});
+ // Make several batches
+ for (const char* json : {"[1, 0, 1]", "[null]", "[null, 1]"}) {
+ auto indices = ArrayFromJSON(int8(), json);
+ auto dict_array = std::make_shared<DictionaryArray>(ty, indices, dict_values);
+ batches.push_back(RecordBatch::Make(schema, dict_array->length(), {dict_array}));
+ }
+
+ CheckDoPut(descr, schema, batches);
+}
+
+// Ensure the gRPC server is configured to allow large messages
+// Tests a 32 MiB batch
+TEST_F(TestDoPut, DoPutLargeBatch) {
+ auto descr = FlightDescriptor::Path({"large-batches"});
+ auto schema = ExampleLargeSchema();
+ BatchVector batches;
+ ASSERT_OK(ExampleLargeBatches(&batches));
+ CheckDoPut(descr, schema, batches);
+}
+
+TEST_F(TestDoPut, DoPutSizeLimit) {
+ const int64_t size_limit = 4096;
+ Location location;
+ ASSERT_OK(Location::ForGrpcTcp("localhost", server_->port(), &location));
+ auto client_options = FlightClientOptions::Defaults();
+ client_options.write_size_limit_bytes = size_limit;
+ std::unique_ptr<FlightClient> client;
+ ASSERT_OK(FlightClient::Connect(location, client_options, &client));
+
+ auto descr = FlightDescriptor::Path({"ints"});
+ // Batch is too large to fit in one message
+ auto schema = arrow::schema({field("f1", arrow::int64())});
+ auto batch = arrow::ConstantArrayGenerator::Zeroes(768, schema);
+ BatchVector batches;
+ batches.push_back(batch->Slice(0, 384));
+ batches.push_back(batch->Slice(384));
+
+ std::unique_ptr<FlightStreamWriter> stream;
+ std::unique_ptr<FlightMetadataReader> reader;
+ ASSERT_OK(client->DoPut(descr, schema, &stream, &reader));
+
+ // Large batch will exceed the limit
+ const auto status = stream->WriteRecordBatch(*batch);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("exceeded soft limit"),
+ status);
+ auto detail = FlightWriteSizeStatusDetail::UnwrapStatus(status);
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(size_limit, detail->limit());
+ ASSERT_GT(detail->actual(), size_limit);
+
+ // But we can retry with a smaller batch
+ for (const auto& batch : batches) {
+ ASSERT_OK(stream->WriteRecordBatch(*batch));
+ }
+
+ ASSERT_OK(stream->DoneWriting());
+ ASSERT_OK(stream->Close());
+ CheckBatches(descr, batches);
+}
+
+TEST_F(TestAuthHandler, PassAuthenticatedCalls) {
+ ASSERT_OK(client_->Authenticate(
+ {},
+ std::unique_ptr<ClientAuthHandler>(new TestClientAuthHandler("user", "p4ssw0rd"))));
+
+ Status status;
+ std::unique_ptr<FlightListing> listing;
+ status = client_->ListFlights(&listing);
+ ASSERT_RAISES(NotImplemented, status);
+
+ std::unique_ptr<ResultStream> results;
+ Action action;
+ action.type = "";
+ action.body = Buffer::FromString("");
+ status = client_->DoAction(action, &results);
+ ASSERT_OK(status);
+
+ std::vector<ActionType> actions;
+ status = client_->ListActions(&actions);
+ ASSERT_RAISES(NotImplemented, status);
+
+ std::unique_ptr<FlightInfo> info;
+ status = client_->GetFlightInfo(FlightDescriptor{}, &info);
+ ASSERT_RAISES(NotImplemented, status);
+
+ std::unique_ptr<FlightStreamReader> stream;
+ status = client_->DoGet(Ticket{}, &stream);
+ ASSERT_RAISES(NotImplemented, status);
+
+ std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightMetadataReader> reader;
+ std::shared_ptr<Schema> schema = arrow::schema({});
+ status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
+ ASSERT_OK(status);
+ status = writer->Close();
+ ASSERT_RAISES(NotImplemented, status);
+}
+
+TEST_F(TestAuthHandler, FailUnauthenticatedCalls) {
+ Status status;
+ std::unique_ptr<FlightListing> listing;
+ status = client_->ListFlights(&listing);
+ ASSERT_RAISES(IOError, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
+
+ std::unique_ptr<ResultStream> results;
+ Action action;
+ action.type = "";
+ action.body = Buffer::FromString("");
+ status = client_->DoAction(action, &results);
+ ASSERT_RAISES(IOError, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
+
+ std::vector<ActionType> actions;
+ status = client_->ListActions(&actions);
+ ASSERT_RAISES(IOError, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
+
+ std::unique_ptr<FlightInfo> info;
+ status = client_->GetFlightInfo(FlightDescriptor{}, &info);
+ ASSERT_RAISES(IOError, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
+
+ std::unique_ptr<FlightStreamReader> stream;
+ status = client_->DoGet(Ticket{}, &stream);
+ ASSERT_RAISES(IOError, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
+
+ std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightMetadataReader> reader;
+ std::shared_ptr<Schema> schema(
+ (new arrow::Schema(std::vector<std::shared_ptr<Field>>())));
+ status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
+ ASSERT_OK(status);
+ status = writer->Close();
+ ASSERT_RAISES(IOError, status);
+ // ARROW-7583: don't check the error message here.
+ // Because gRPC reports errors in some paths with booleans, instead
+ // of statuses, we can fail the call without knowing why it fails,
+ // instead reporting a generic error message. This is
+ // nondeterministic, so don't assert any particular message here.
+}
+
+TEST_F(TestAuthHandler, CheckPeerIdentity) {
+ ASSERT_OK(client_->Authenticate(
+ {},
+ std::unique_ptr<ClientAuthHandler>(new TestClientAuthHandler("user", "p4ssw0rd"))));
+
+ Action action;
+ action.type = "who-am-i";
+ action.body = Buffer::FromString("");
+ std::unique_ptr<ResultStream> results;
+ ASSERT_OK(client_->DoAction(action, &results));
+ ASSERT_NE(results, nullptr);
+
+ std::unique_ptr<Result> result;
+ ASSERT_OK(results->Next(&result));
+ ASSERT_NE(result, nullptr);
+ // Action returns the peer identity as the result.
+ ASSERT_EQ(result->body->ToString(), "user");
+
+ ASSERT_OK(results->Next(&result));
+ ASSERT_NE(result, nullptr);
+ // Action returns the peer address as the result.
+#ifndef _WIN32
+ // On Windows gRPC sometimes returns a blank peer address, so don't
+ // bother checking for it.
+ ASSERT_NE(result->body->ToString(), "");
+#endif
+}
+
+TEST_F(TestBasicAuthHandler, PassAuthenticatedCalls) {
+ ASSERT_OK(
+ client_->Authenticate({}, std::unique_ptr<ClientAuthHandler>(
+ new TestClientBasicAuthHandler("user", "p4ssw0rd"))));
+
+ Status status;
+ std::unique_ptr<FlightListing> listing;
+ status = client_->ListFlights(&listing);
+ ASSERT_RAISES(NotImplemented, status);
+
+ std::unique_ptr<ResultStream> results;
+ Action action;
+ action.type = "";
+ action.body = Buffer::FromString("");
+ status = client_->DoAction(action, &results);
+ ASSERT_OK(status);
+
+ std::vector<ActionType> actions;
+ status = client_->ListActions(&actions);
+ ASSERT_RAISES(NotImplemented, status);
+
+ std::unique_ptr<FlightInfo> info;
+ status = client_->GetFlightInfo(FlightDescriptor{}, &info);
+ ASSERT_RAISES(NotImplemented, status);
+
+ std::unique_ptr<FlightStreamReader> stream;
+ status = client_->DoGet(Ticket{}, &stream);
+ ASSERT_RAISES(NotImplemented, status);
+
+ std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightMetadataReader> reader;
+ std::shared_ptr<Schema> schema = arrow::schema({});
+ status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
+ ASSERT_OK(status);
+ status = writer->Close();
+ ASSERT_RAISES(NotImplemented, status);
+}
+
+TEST_F(TestBasicAuthHandler, FailUnauthenticatedCalls) {
+ Status status;
+ std::unique_ptr<FlightListing> listing;
+ status = client_->ListFlights(&listing);
+ ASSERT_RAISES(IOError, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
+
+ std::unique_ptr<ResultStream> results;
+ Action action;
+ action.type = "";
+ action.body = Buffer::FromString("");
+ status = client_->DoAction(action, &results);
+ ASSERT_RAISES(IOError, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
+
+ std::vector<ActionType> actions;
+ status = client_->ListActions(&actions);
+ ASSERT_RAISES(IOError, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
+
+ std::unique_ptr<FlightInfo> info;
+ status = client_->GetFlightInfo(FlightDescriptor{}, &info);
+ ASSERT_RAISES(IOError, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
+
+ std::unique_ptr<FlightStreamReader> stream;
+ status = client_->DoGet(Ticket{}, &stream);
+ ASSERT_RAISES(IOError, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
+
+ std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightMetadataReader> reader;
+ std::shared_ptr<Schema> schema(
+ (new arrow::Schema(std::vector<std::shared_ptr<Field>>())));
+ status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
+ ASSERT_OK(status);
+ status = writer->Close();
+ ASSERT_RAISES(IOError, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
+}
+
+TEST_F(TestBasicAuthHandler, CheckPeerIdentity) {
+ ASSERT_OK(
+ client_->Authenticate({}, std::unique_ptr<ClientAuthHandler>(
+ new TestClientBasicAuthHandler("user", "p4ssw0rd"))));
+
+ Action action;
+ action.type = "who-am-i";
+ action.body = Buffer::FromString("");
+ std::unique_ptr<ResultStream> results;
+ ASSERT_OK(client_->DoAction(action, &results));
+ ASSERT_NE(results, nullptr);
+
+ std::unique_ptr<Result> result;
+ ASSERT_OK(results->Next(&result));
+ ASSERT_NE(result, nullptr);
+ // Action returns the peer identity as the result.
+ ASSERT_EQ(result->body->ToString(), "user");
+}
+
+TEST_F(TestTls, DoAction) {
+ FlightCallOptions options;
+ options.timeout = TimeoutDuration{5.0};
+ Action action;
+ action.type = "test";
+ action.body = Buffer::FromString("");
+ std::unique_ptr<ResultStream> results;
+ ASSERT_OK(client_->DoAction(options, action, &results));
+ ASSERT_NE(results, nullptr);
+
+ std::unique_ptr<Result> result;
+ ASSERT_OK(results->Next(&result));
+ ASSERT_NE(result, nullptr);
+ ASSERT_EQ(result->body->ToString(), "Hello, world!");
+}
+
+#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS)
+TEST_F(TestTls, DisableServerVerification) {
+ std::unique_ptr<FlightClient> client;
+ auto client_options = FlightClientOptions::Defaults();
+ // For security reasons, if encryption is being used,
+ // the client should be configured to verify the server by default.
+ ASSERT_EQ(client_options.disable_server_verification, false);
+ client_options.disable_server_verification = true;
+ ASSERT_OK(FlightClient::Connect(location_, client_options, &client));
+
+ FlightCallOptions options;
+ options.timeout = TimeoutDuration{5.0};
+ Action action;
+ action.type = "test";
+ action.body = Buffer::FromString("");
+ std::unique_ptr<ResultStream> results;
+ ASSERT_OK(client->DoAction(options, action, &results));
+ ASSERT_NE(results, nullptr);
+
+ std::unique_ptr<Result> result;
+ ASSERT_OK(results->Next(&result));
+ ASSERT_NE(result, nullptr);
+ ASSERT_EQ(result->body->ToString(), "Hello, world!");
+}
+#endif
+
+TEST_F(TestTls, OverrideHostname) {
+ std::unique_ptr<FlightClient> client;
+ auto client_options = FlightClientOptions::Defaults();
+ client_options.override_hostname = "fakehostname";
+ CertKeyPair root_cert;
+ ASSERT_OK(ExampleTlsCertificateRoot(&root_cert));
+ client_options.tls_root_certs = root_cert.pem_cert;
+ ASSERT_OK(FlightClient::Connect(location_, client_options, &client));
+
+ FlightCallOptions options;
+ options.timeout = TimeoutDuration{5.0};
+ Action action;
+ action.type = "test";
+ action.body = Buffer::FromString("");
+ std::unique_ptr<ResultStream> results;
+ ASSERT_RAISES(IOError, client->DoAction(options, action, &results));
+}
+
+// Test the facility for setting generic transport options.
+TEST_F(TestTls, OverrideHostnameGeneric) {
+ std::unique_ptr<FlightClient> client;
+ auto client_options = FlightClientOptions::Defaults();
+ client_options.generic_options.emplace_back(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG,
+ "fakehostname");
+ CertKeyPair root_cert;
+ ASSERT_OK(ExampleTlsCertificateRoot(&root_cert));
+ client_options.tls_root_certs = root_cert.pem_cert;
+ ASSERT_OK(FlightClient::Connect(location_, client_options, &client));
+
+ FlightCallOptions options;
+ options.timeout = TimeoutDuration{5.0};
+ Action action;
+ action.type = "test";
+ action.body = Buffer::FromString("");
+ std::unique_ptr<ResultStream> results;
+ ASSERT_RAISES(IOError, client->DoAction(options, action, &results));
+ // Could check error message for the gRPC error message but it isn't
+ // necessarily stable
+}
+
+TEST_F(TestMetadata, DoGet) {
+ Ticket ticket{""};
+ std::unique_ptr<FlightStreamReader> stream;
+ ASSERT_OK(client_->DoGet(ticket, &stream));
+
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleIntBatches(&expected_batches));
+
+ FlightStreamChunk chunk;
+ auto num_batches = static_cast<int>(expected_batches.size());
+ for (int i = 0; i < num_batches; ++i) {
+ ASSERT_OK(stream->Next(&chunk));
+ ASSERT_NE(nullptr, chunk.data);
+ ASSERT_NE(nullptr, chunk.app_metadata);
+ ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data);
+ ASSERT_EQ(std::to_string(i), chunk.app_metadata->ToString());
+ }
+ ASSERT_OK(stream->Next(&chunk));
+ ASSERT_EQ(nullptr, chunk.data);
+}
+
+// Test dictionaries. This tests a corner case in the reader:
+// dictionary batches come in between the schema and the first record
+// batch, so the server must take care to read application metadata
+// from the record batch, and not one of the dictionary batches.
+TEST_F(TestMetadata, DoGetDictionaries) {
+ Ticket ticket{"dicts"};
+ std::unique_ptr<FlightStreamReader> stream;
+ ASSERT_OK(client_->DoGet(ticket, &stream));
+
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleDictBatches(&expected_batches));
+
+ FlightStreamChunk chunk;
+ auto num_batches = static_cast<int>(expected_batches.size());
+ for (int i = 0; i < num_batches; ++i) {
+ ASSERT_OK(stream->Next(&chunk));
+ ASSERT_NE(nullptr, chunk.data);
+ ASSERT_NE(nullptr, chunk.app_metadata);
+ ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data);
+ ASSERT_EQ(std::to_string(i), chunk.app_metadata->ToString());
+ }
+ ASSERT_OK(stream->Next(&chunk));
+ ASSERT_EQ(nullptr, chunk.data);
+}
+
+TEST_F(TestMetadata, DoPut) {
+ std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightMetadataReader> reader;
+ std::shared_ptr<Schema> schema = ExampleIntSchema();
+ ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));
+
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleIntBatches(&expected_batches));
+
+ std::shared_ptr<RecordBatch> chunk;
+ std::shared_ptr<Buffer> metadata;
+ auto num_batches = static_cast<int>(expected_batches.size());
+ for (int i = 0; i < num_batches; ++i) {
+ ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
+ Buffer::FromString(std::to_string(i))));
+ }
+ // This eventually calls grpc::ClientReaderWriter::Finish which can
+ // hang if there are unread messages. So make sure our wrapper
+ // around this doesn't hang (because it drains any unread messages)
+ ASSERT_OK(writer->Close());
+}
+
+// Test DoPut() with dictionaries. This tests a corner case in the
+// server-side reader; see DoGetDictionaries above.
+TEST_F(TestMetadata, DoPutDictionaries) {
+ std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightMetadataReader> reader;
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleDictBatches(&expected_batches));
+ // ARROW-8749: don't get the schema via ExampleDictSchema because
+ // DictionaryMemo uses field addresses to determine whether it's
+ // seen a field before. Hence, if we use a schema that is different
+ // (identity-wise) than the schema of the first batch we write,
+ // we'll end up generating a duplicate set of dictionaries that
+ // confuses the reader.
+ ASSERT_OK(client_->DoPut(FlightDescriptor{}, expected_batches[0]->schema(), &writer,
+ &reader));
+ std::shared_ptr<RecordBatch> chunk;
+ std::shared_ptr<Buffer> metadata;
+ auto num_batches = static_cast<int>(expected_batches.size());
+ for (int i = 0; i < num_batches; ++i) {
+ ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
+ Buffer::FromString(std::to_string(i))));
+ }
+ ASSERT_OK(writer->Close());
+}
+
+TEST_F(TestMetadata, DoPutReadMetadata) {
+ std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightMetadataReader> reader;
+ std::shared_ptr<Schema> schema = ExampleIntSchema();
+ ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));
+
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleIntBatches(&expected_batches));
+
+ std::shared_ptr<RecordBatch> chunk;
+ std::shared_ptr<Buffer> metadata;
+ auto num_batches = static_cast<int>(expected_batches.size());
+ for (int i = 0; i < num_batches; ++i) {
+ ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
+ Buffer::FromString(std::to_string(i))));
+ ASSERT_OK(reader->ReadMetadata(&metadata));
+ ASSERT_NE(nullptr, metadata);
+ ASSERT_EQ(std::to_string(i), metadata->ToString());
+ }
+ // As opposed to DoPutDrainMetadata, now we've read the messages, so
+ // make sure this still closes as expected.
+ ASSERT_OK(writer->Close());
+}
+
+TEST_F(TestOptions, DoGetReadOptions) {
+ // Call DoGet, but with a very low read nesting depth set to fail the call.
+ Ticket ticket{""};
+ auto options = FlightCallOptions();
+ options.read_options.max_recursion_depth = 1;
+ std::unique_ptr<FlightStreamReader> stream;
+ ASSERT_OK(client_->DoGet(options, ticket, &stream));
+ FlightStreamChunk chunk;
+ ASSERT_RAISES(Invalid, stream->Next(&chunk));
+}
+
+TEST_F(TestOptions, DoPutWriteOptions) {
+ // Call DoPut, but with a very low write nesting depth set to fail the call.
+ std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightMetadataReader> reader;
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleNestedBatches(&expected_batches));
+
+ auto options = FlightCallOptions();
+ options.write_options.max_recursion_depth = 1;
+ ASSERT_OK(client_->DoPut(options, FlightDescriptor{}, expected_batches[0]->schema(),
+ &writer, &reader));
+ for (const auto& batch : expected_batches) {
+ ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*batch));
+ }
+}
+
+TEST_F(TestOptions, DoExchangeClientWriteOptions) {
+ // Call DoExchange and write nested data, but with a very low nesting depth set to
+ // fail the call.
+ auto options = FlightCallOptions();
+ options.write_options.max_recursion_depth = 1;
+ auto descr = FlightDescriptor::Command("");
+ std::unique_ptr<FlightStreamReader> reader;
+ std::unique_ptr<FlightStreamWriter> writer;
+ ASSERT_OK(client_->DoExchange(options, descr, &writer, &reader));
+ BatchVector batches;
+ ASSERT_OK(ExampleNestedBatches(&batches));
+ ASSERT_OK(writer->Begin(batches[0]->schema()));
+ for (const auto& batch : batches) {
+ ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*batch));
+ }
+ ASSERT_OK(writer->DoneWriting());
+ ASSERT_OK(writer->Close());
+}
+
+TEST_F(TestOptions, DoExchangeClientWriteOptionsBegin) {
+ // Call DoExchange and write nested data, but with a very low nesting depth set to
+ // fail the call. Here the options are set explicitly when we write data and not in the
+ // call options.
+ auto descr = FlightDescriptor::Command("");
+ std::unique_ptr<FlightStreamReader> reader;
+ std::unique_ptr<FlightStreamWriter> writer;
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ BatchVector batches;
+ ASSERT_OK(ExampleNestedBatches(&batches));
+ auto options = ipc::IpcWriteOptions::Defaults();
+ options.max_recursion_depth = 1;
+ ASSERT_OK(writer->Begin(batches[0]->schema(), options));
+ for (const auto& batch : batches) {
+ ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*batch));
+ }
+ ASSERT_OK(writer->DoneWriting());
+ ASSERT_OK(writer->Close());
+}
+
+TEST_F(TestOptions, DoExchangeServerWriteOptions) {
+ // Call DoExchange and write nested data, but with a very low nesting depth set to fail
+ // the call. (The low nesting depth is set on the server side.)
+ auto descr = FlightDescriptor::Command("");
+ std::unique_ptr<FlightStreamReader> reader;
+ std::unique_ptr<FlightStreamWriter> writer;
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ BatchVector batches;
+ ASSERT_OK(ExampleNestedBatches(&batches));
+ ASSERT_OK(writer->Begin(batches[0]->schema()));
+ FlightStreamChunk chunk;
+ ASSERT_OK(writer->WriteRecordBatch(*batches[0]));
+ ASSERT_OK(writer->DoneWriting());
+ ASSERT_RAISES(Invalid, writer->Close());
+}
+
+TEST_F(TestRejectServerMiddleware, Rejected) {
+ std::unique_ptr<FlightInfo> info;
+ const auto& status = client_->GetFlightInfo(FlightDescriptor{}, &info);
+ ASSERT_RAISES(IOError, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("All calls are rejected"));
+}
+
+TEST_F(TestCountingServerMiddleware, Count) {
+ std::unique_ptr<FlightInfo> info;
+ const auto& status = client_->GetFlightInfo(FlightDescriptor{}, &info);
+ ASSERT_RAISES(NotImplemented, status);
+
+ Ticket ticket{""};
+ std::unique_ptr<FlightStreamReader> stream;
+ ASSERT_OK(client_->DoGet(ticket, &stream));
+
+ ASSERT_EQ(1, request_counter_->failed_);
+
+ while (true) {
+ FlightStreamChunk chunk;
+ ASSERT_OK(stream->Next(&chunk));
+ if (chunk.data == nullptr) {
+ break;
+ }
+ }
+
+ ASSERT_EQ(1, request_counter_->successful_);
+ ASSERT_EQ(1, request_counter_->failed_);
+}
+
+TEST_F(TestPropagatingMiddleware, Propagate) {
+ Action action;
+ std::unique_ptr<ResultStream> stream;
+ std::unique_ptr<Result> result;
+
+ current_span_id = "trace-id";
+ client_middleware_->Reset();
+
+ action.type = "action1";
+ action.body = Buffer::FromString("action1-content");
+ ASSERT_OK(client_->DoAction(action, &stream));
+
+ ASSERT_OK(stream->Next(&result));
+ ASSERT_EQ("trace-id", result->body->ToString());
+ ValidateStatus(Status::OK(), FlightMethod::DoAction);
+}
+
+// For each method, make sure that the client middleware received
+// headers from the server and that the proper method enum value was
+// passed to the interceptor
+TEST_F(TestPropagatingMiddleware, ListFlights) {
+ client_middleware_->Reset();
+ std::unique_ptr<FlightListing> listing;
+ const Status status = client_->ListFlights(&listing);
+ ASSERT_RAISES(NotImplemented, status);
+ ValidateStatus(status, FlightMethod::ListFlights);
+}
+
+TEST_F(TestPropagatingMiddleware, GetFlightInfo) {
+ client_middleware_->Reset();
+ auto descr = FlightDescriptor::Path({"examples", "ints"});
+ std::unique_ptr<FlightInfo> info;
+ const Status status = client_->GetFlightInfo(descr, &info);
+ ASSERT_RAISES(NotImplemented, status);
+ ValidateStatus(status, FlightMethod::GetFlightInfo);
+}
+
+TEST_F(TestPropagatingMiddleware, GetSchema) {
+ client_middleware_->Reset();
+ auto descr = FlightDescriptor::Path({"examples", "ints"});
+ std::unique_ptr<SchemaResult> result;
+ const Status status = client_->GetSchema(descr, &result);
+ ASSERT_RAISES(NotImplemented, status);
+ ValidateStatus(status, FlightMethod::GetSchema);
+}
+
+TEST_F(TestPropagatingMiddleware, ListActions) {
+ client_middleware_->Reset();
+ std::vector<ActionType> actions;
+ const Status status = client_->ListActions(&actions);
+ ASSERT_RAISES(NotImplemented, status);
+ ValidateStatus(status, FlightMethod::ListActions);
+}
+
+TEST_F(TestPropagatingMiddleware, DoGet) {
+ client_middleware_->Reset();
+ Ticket ticket1{"ARROW-5095-fail"};
+ std::unique_ptr<FlightStreamReader> stream;
+ Status status = client_->DoGet(ticket1, &stream);
+ ASSERT_RAISES(NotImplemented, status);
+ ValidateStatus(status, FlightMethod::DoGet);
+}
+
+TEST_F(TestPropagatingMiddleware, DoPut) {
+ client_middleware_->Reset();
+ auto descr = FlightDescriptor::Path({"ints"});
+ auto a1 = ArrayFromJSON(int32(), "[4, 5, 6, null]");
+ auto schema = arrow::schema({field("f1", a1->type())});
+
+ std::unique_ptr<FlightStreamWriter> stream;
+ std::unique_ptr<FlightMetadataReader> reader;
+ ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader));
+ const Status status = stream->Close();
+ ASSERT_RAISES(NotImplemented, status);
+ ValidateStatus(status, FlightMethod::DoPut);
+}
+
+TEST_F(TestBasicHeaderAuthMiddleware, ValidCredentials) { RunValidClientAuth(); }
+
+TEST_F(TestBasicHeaderAuthMiddleware, InvalidCredentials) { RunInvalidClientAuth(); }
+
+TEST_F(TestCookieMiddleware, BasicParsing) {
+ AddAndValidate("id1=1; foo=bar;");
+ AddAndValidate("id1=1; foo=bar");
+ AddAndValidate("id2=2;");
+ AddAndValidate("id4=\"4\"");
+ AddAndValidate("id5=5; foo=bar; baz=buz;");
+}
+
+TEST_F(TestCookieMiddleware, Overwrite) {
+ AddAndValidate("id0=0");
+ AddAndValidate("id0=1");
+ AddAndValidate("id1=0");
+ AddAndValidate("id1=1");
+ AddAndValidate("id1=1");
+ AddAndValidate("id1=10");
+ AddAndValidate("id=3");
+ AddAndValidate("id=0");
+ AddAndValidate("id=0");
+}
+
+TEST_F(TestCookieMiddleware, MaxAge) {
+ AddAndValidate("id0=0; max-age=0;");
+ AddAndValidate("id1=0; max-age=-1;");
+ AddAndValidate("id2=0; max-age=0");
+ AddAndValidate("id3=0; max-age=-1");
+ AddAndValidate("id4=0; max-age=1");
+ AddAndValidate("id5=0; max-age=1");
+ AddAndValidate("id4=0; max-age=0");
+ AddAndValidate("id5=0; max-age=0");
+}
+
+TEST_F(TestCookieMiddleware, Expires) {
+ AddAndValidate("id0=0; expires=0, 0 0 0 0:0:0 GMT;");
+ AddAndValidate("id0=0; expires=0, 0 0 0 0:0:0 GMT");
+ AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;");
+ AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT");
+ AddAndValidate("id0=0; expires=Fri, 01 Jan 2038 22:15:36 GMT;");
+ AddAndValidate("id1=0; expires=Fri, 01 Jan 2038 22:15:36 GMT");
+ AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;");
+ AddAndValidate("id1=0; expires=Fri, 22 Dec 2017 22:15:36 GMT");
+}
+
+TEST_F(TestCookieParsing, Expired) {
+ VerifyParseCookie("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;", true);
+ VerifyParseCookie("id1=0; max-age=-1;", true);
+ VerifyParseCookie("id0=0; max-age=0;", true);
+}
+
+TEST_F(TestCookieParsing, Invalid) {
+ VerifyParseCookie("id1=0; expires=0, 0 0 0 0:0:0 GMT;", true);
+ VerifyParseCookie("id1=0; expires=Fri, 01 FOO 2038 22:15:36 GMT", true);
+ VerifyParseCookie("id1=0; expires=foo", true);
+ VerifyParseCookie("id1=0; expires=", true);
+ VerifyParseCookie("id1=0; max-age=FOO", true);
+ VerifyParseCookie("id1=0; max-age=", true);
+}
+
+TEST_F(TestCookieParsing, NoExpiry) {
+ VerifyParseCookie("id1=0;", false);
+ VerifyParseCookie("id1=0; noexpiry=Fri, 01 Jan 2038 22:15:36 GMT", false);
+ VerifyParseCookie("id1=0; noexpiry=\"Fri, 01 Jan 2038 22:15:36 GMT\"", false);
+ VerifyParseCookie("id1=0; nomax-age=-1", false);
+ VerifyParseCookie("id1=0; nomax-age=\"-1\"", false);
+ VerifyParseCookie("id1=0; randomattr=foo", false);
+}
+
+TEST_F(TestCookieParsing, NotExpired) {
+ VerifyParseCookie("id5=0; max-age=1", false);
+ VerifyParseCookie("id0=0; expires=Fri, 01 Jan 2038 22:15:36 GMT;", false);
+}
+
+TEST_F(TestCookieParsing, GetName) {
+ VerifyCookieName("id1=1; foo=bar;", "id1");
+ VerifyCookieName("id1=1; foo=bar", "id1");
+ VerifyCookieName("id2=2;", "id2");
+ VerifyCookieName("id4=\"4\"", "id4");
+ VerifyCookieName("id5=5; foo=bar; baz=buz;", "id5");
+}
+
+TEST_F(TestCookieParsing, ToString) {
+ VerifyCookieString("id1=1; foo=bar;", "id1=\"1\"");
+ VerifyCookieString("id1=1; foo=bar", "id1=\"1\"");
+ VerifyCookieString("id2=2;", "id2=\"2\"");
+ VerifyCookieString("id4=\"4\"", "id4=\"4\"");
+ VerifyCookieString("id5=5; foo=bar; baz=buz;", "id5=\"5\"");
+}
+
+TEST_F(TestCookieParsing, DateConversion) {
+ VerifyCookieDateConverson("Mon, 01 jan 2038 22:15:36 GMT;", "01 01 2038 22:15:36");
+ VerifyCookieDateConverson("TUE, 10 Feb 2038 22:15:36 GMT", "10 02 2038 22:15:36");
+ VerifyCookieDateConverson("WED, 20 MAr 2038 22:15:36 GMT;", "20 03 2038 22:15:36");
+ VerifyCookieDateConverson("thu, 15 APR 2038 22:15:36 GMT", "15 04 2038 22:15:36");
+ VerifyCookieDateConverson("Fri, 30 mAY 2038 22:15:36 GMT;", "30 05 2038 22:15:36");
+ VerifyCookieDateConverson("Sat, 03 juN 2038 22:15:36 GMT", "03 06 2038 22:15:36");
+ VerifyCookieDateConverson("Sun, 01 JuL 2038 22:15:36 GMT;", "01 07 2038 22:15:36");
+ VerifyCookieDateConverson("Fri, 06 aUg 2038 22:15:36 GMT", "06 08 2038 22:15:36");
+ VerifyCookieDateConverson("Fri, 01 SEP 2038 22:15:36 GMT;", "01 09 2038 22:15:36");
+ VerifyCookieDateConverson("Fri, 01 OCT 2038 22:15:36 GMT", "01 10 2038 22:15:36");
+ VerifyCookieDateConverson("Fri, 01 Nov 2038 22:15:36 GMT;", "01 11 2038 22:15:36");
+ VerifyCookieDateConverson("Fri, 01 deC 2038 22:15:36 GMT", "01 12 2038 22:15:36");
+ VerifyCookieDateConverson("", "");
+ VerifyCookieDateConverson("Fri, 01 INVALID 2038 22:15:36 GMT;",
+ "01 INVALID 2038 22:15:36");
+}
+
+TEST_F(TestCookieParsing, ParseCookieAttribute) {
+ VerifyCookieAttributeParsing("", 0, util::nullopt, std::string::npos);
+
+ std::string cookie_string = "attr0=0; attr1=1; attr2=2; attr3=3";
+ auto attr_length = std::string("attr0=0;").length();
+ std::string::size_type start_pos = 0;
+ VerifyCookieAttributeParsing(cookie_string, start_pos, std::make_pair("attr0", "0"),
+ cookie_string.find("attr0=0;") + attr_length);
+ VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)),
+ std::make_pair("attr1", "1"),
+ cookie_string.find("attr1=1;") + attr_length);
+ VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)),
+ std::make_pair("attr2", "2"),
+ cookie_string.find("attr2=2;") + attr_length);
+ VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)),
+ std::make_pair("attr3", "3"), std::string::npos);
+ VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length - 1)),
+ util::nullopt, std::string::npos);
+ VerifyCookieAttributeParsing(cookie_string, std::string::npos, util::nullopt,
+ std::string::npos);
+}
+
+TEST_F(TestCookieParsing, CookieCache) {
+ AddCookieVerifyCache({"id0=0;"}, "");
+ AddCookieVerifyCache({"id0=0;", "id0=1;"}, "id0=\"1\"");
+ AddCookieVerifyCache({"id0=0;", "id1=1;"}, "id0=\"0\"; id1=\"1\"");
+ AddCookieVerifyCache({"id0=0;", "id1=1;", "id2=2"}, "id0=\"0\"; id1=\"1\"; id2=\"2\"");
+}
+
+class ForeverFlightListing : public FlightListing {
+ Status Next(std::unique_ptr<FlightInfo>* info) override {
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ *info = arrow::internal::make_unique<FlightInfo>(ExampleFlightInfo()[0]);
+ return Status::OK();
+ }
+};
+
+class ForeverResultStream : public ResultStream {
+ Status Next(std::unique_ptr<Result>* result) override {
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ *result = arrow::internal::make_unique<Result>();
+ (*result)->body = Buffer::FromString("foo");
+ return Status::OK();
+ }
+};
+
+class ForeverDataStream : public FlightDataStream {
+ public:
+ ForeverDataStream() : schema_(arrow::schema({})), mapper_(*schema_) {}
+ std::shared_ptr<Schema> schema() override { return schema_; }
+
+ Status GetSchemaPayload(FlightPayload* payload) override {
+ return ipc::GetSchemaPayload(*schema_, ipc::IpcWriteOptions::Defaults(), mapper_,
+ &payload->ipc_message);
+ }
+
+ Status Next(FlightPayload* payload) override {
+ auto batch = RecordBatch::Make(schema_, 0, ArrayVector{});
+ return ipc::GetRecordBatchPayload(*batch, ipc::IpcWriteOptions::Defaults(),
+ &payload->ipc_message);
+ }
+
+ private:
+ std::shared_ptr<Schema> schema_;
+ ipc::DictionaryFieldMapper mapper_;
+};
+
+class CancelTestServer : public FlightServerBase {
+ public:
+ Status ListFlights(const ServerCallContext&, const Criteria*,
+ std::unique_ptr<FlightListing>* listings) override {
+ *listings = arrow::internal::make_unique<ForeverFlightListing>();
+ return Status::OK();
+ }
+ Status DoAction(const ServerCallContext&, const Action&,
+ std::unique_ptr<ResultStream>* result) override {
+ *result = arrow::internal::make_unique<ForeverResultStream>();
+ return Status::OK();
+ }
+ Status ListActions(const ServerCallContext&,
+ std::vector<ActionType>* actions) override {
+ *actions = {};
+ return Status::OK();
+ }
+ Status DoGet(const ServerCallContext&, const Ticket&,
+ std::unique_ptr<FlightDataStream>* data_stream) override {
+ *data_stream = arrow::internal::make_unique<ForeverDataStream>();
+ return Status::OK();
+ }
+};
+
+class TestCancel : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK(MakeServer<CancelTestServer>(
+ &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
+ [](FlightClientOptions* options) { return Status::OK(); }));
+ }
+ void TearDown() { ASSERT_OK(server_->Shutdown()); }
+
+ protected:
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+};
+
+TEST_F(TestCancel, ListFlights) {
+ StopSource stop_source;
+ FlightCallOptions options;
+ options.stop_token = stop_source.token();
+ std::unique_ptr<FlightListing> listing;
+ stop_source.RequestStop(Status::Cancelled("StopSource"));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
+ client_->ListFlights(options, {}, &listing));
+}
+
+TEST_F(TestCancel, DoAction) {
+ StopSource stop_source;
+ FlightCallOptions options;
+ options.stop_token = stop_source.token();
+ std::unique_ptr<ResultStream> results;
+ stop_source.RequestStop(Status::Cancelled("StopSource"));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
+ client_->DoAction(options, {}, &results));
+}
+
+TEST_F(TestCancel, ListActions) {
+ StopSource stop_source;
+ FlightCallOptions options;
+ options.stop_token = stop_source.token();
+ std::vector<ActionType> results;
+ stop_source.RequestStop(Status::Cancelled("StopSource"));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
+ client_->ListActions(options, &results));
+}
+
+TEST_F(TestCancel, DoGet) {
+ StopSource stop_source;
+ FlightCallOptions options;
+ options.stop_token = stop_source.token();
+ std::unique_ptr<ResultStream> results;
+ stop_source.RequestStop(Status::Cancelled("StopSource"));
+ std::unique_ptr<FlightStreamReader> stream;
+ ASSERT_OK(client_->DoGet(options, {}, &stream));
+ std::shared_ptr<Table> table;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
+ stream->ReadAll(&table));
+
+ ASSERT_OK(client_->DoGet({}, &stream));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
+ stream->ReadAll(&table, options.stop_token));
+}
+
+TEST_F(TestCancel, DoExchange) {
+ StopSource stop_source;
+ FlightCallOptions options;
+ options.stop_token = stop_source.token();
+ std::unique_ptr<ResultStream> results;
+ stop_source.RequestStop(Status::Cancelled("StopSource"));
+ std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightStreamReader> stream;
+ ASSERT_OK(
+ client_->DoExchange(options, FlightDescriptor::Command(""), &writer, &stream));
+ std::shared_ptr<Table> table;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
+ stream->ReadAll(&table));
+
+ ASSERT_OK(client_->DoExchange(FlightDescriptor::Command(""), &writer, &stream));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
+ stream->ReadAll(&table, options.stop_token));
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/internal.cc b/src/arrow/cpp/src/arrow/flight/internal.cc
new file mode 100644
index 000000000..f27de208a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/internal.cc
@@ -0,0 +1,514 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/internal.h"
+
+#include <cstddef>
+#include <map>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+
+#include "arrow/flight/platform.h"
+#include "arrow/flight/protocol_internal.h"
+#include "arrow/flight/types.h"
+
+#ifdef GRPCPP_PP_INCLUDE
+#include <grpcpp/grpcpp.h>
+#else
+#include <grpc++/grpc++.h>
+#endif
+
+#include "arrow/buffer.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string_builder.h"
+
+namespace arrow {
+namespace flight {
+namespace internal {
+
+const char* kGrpcAuthHeader = "auth-token-bin";
+const char* kGrpcStatusCodeHeader = "x-arrow-status";
+const char* kGrpcStatusMessageHeader = "x-arrow-status-message-bin";
+const char* kGrpcStatusDetailHeader = "x-arrow-status-detail-bin";
+const char* kBinaryErrorDetailsKey = "grpc-status-details-bin";
+
+static Status StatusCodeFromString(const grpc::string_ref& code_ref, StatusCode* code) {
+ // Bounce through std::string to get a proper null-terminated C string
+ const auto code_int = std::atoi(std::string(code_ref.data(), code_ref.size()).c_str());
+ switch (code_int) {
+ case static_cast<int>(StatusCode::OutOfMemory):
+ case static_cast<int>(StatusCode::KeyError):
+ case static_cast<int>(StatusCode::TypeError):
+ case static_cast<int>(StatusCode::Invalid):
+ case static_cast<int>(StatusCode::IOError):
+ case static_cast<int>(StatusCode::CapacityError):
+ case static_cast<int>(StatusCode::IndexError):
+ case static_cast<int>(StatusCode::UnknownError):
+ case static_cast<int>(StatusCode::NotImplemented):
+ case static_cast<int>(StatusCode::SerializationError):
+ case static_cast<int>(StatusCode::RError):
+ case static_cast<int>(StatusCode::CodeGenError):
+ case static_cast<int>(StatusCode::ExpressionValidationError):
+ case static_cast<int>(StatusCode::ExecutionError):
+ case static_cast<int>(StatusCode::AlreadyExists): {
+ *code = static_cast<StatusCode>(code_int);
+ return Status::OK();
+ }
+ default:
+ // Code is invalid
+ return Status::UnknownError("Unknown Arrow status code", code_ref);
+ }
+}
+
+/// Try to extract a status from gRPC trailers.
+/// Return Status::OK if found, an error otherwise.
+static Status FromGrpcContext(const grpc::ClientContext& ctx, Status* status,
+ std::shared_ptr<FlightStatusDetail> flightStatusDetail) {
+ const std::multimap<grpc::string_ref, grpc::string_ref>& trailers =
+ ctx.GetServerTrailingMetadata();
+ const auto code_val = trailers.find(kGrpcStatusCodeHeader);
+ if (code_val == trailers.end()) {
+ return Status::IOError("Status code header not found");
+ }
+
+ const grpc::string_ref code_ref = code_val->second;
+ StatusCode code = {};
+ RETURN_NOT_OK(StatusCodeFromString(code_ref, &code));
+
+ const auto message_val = trailers.find(kGrpcStatusMessageHeader);
+ if (message_val == trailers.end()) {
+ return Status::IOError("Status message header not found");
+ }
+
+ const grpc::string_ref message_ref = message_val->second;
+ std::string message = std::string(message_ref.data(), message_ref.size());
+ const auto detail_val = trailers.find(kGrpcStatusDetailHeader);
+ if (detail_val != trailers.end()) {
+ const grpc::string_ref detail_ref = detail_val->second;
+ message += ". Detail: ";
+ message += std::string(detail_ref.data(), detail_ref.size());
+ }
+ const auto grpc_detail_val = trailers.find(kBinaryErrorDetailsKey);
+ if (grpc_detail_val != trailers.end()) {
+ const grpc::string_ref detail_ref = grpc_detail_val->second;
+ std::string bin_detail = std::string(detail_ref.data(), detail_ref.size());
+ if (!flightStatusDetail) {
+ flightStatusDetail =
+ std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal);
+ }
+ flightStatusDetail->set_extra_info(bin_detail);
+ }
+ *status = Status(code, message, flightStatusDetail);
+ return Status::OK();
+}
+
+/// Convert a gRPC status to an Arrow status, ignoring any
+/// implementation-defined headers that encode further detail.
+static Status FromGrpcCode(const grpc::Status& grpc_status) {
+ switch (grpc_status.error_code()) {
+ case grpc::StatusCode::OK:
+ return Status::OK();
+ case grpc::StatusCode::CANCELLED:
+ return Status::IOError("gRPC cancelled call, with message: ",
+ grpc_status.error_message())
+ .WithDetail(std::make_shared<FlightStatusDetail>(FlightStatusCode::Cancelled));
+ case grpc::StatusCode::UNKNOWN: {
+ std::stringstream ss;
+ ss << "Flight RPC failed with message: " << grpc_status.error_message();
+ return Status::UnknownError(ss.str()).WithDetail(
+ std::make_shared<FlightStatusDetail>(FlightStatusCode::Failed));
+ }
+ case grpc::StatusCode::INVALID_ARGUMENT:
+ return Status::Invalid("gRPC returned invalid argument error, with message: ",
+ grpc_status.error_message());
+ case grpc::StatusCode::DEADLINE_EXCEEDED:
+ return Status::IOError("gRPC returned deadline exceeded error, with message: ",
+ grpc_status.error_message())
+ .WithDetail(std::make_shared<FlightStatusDetail>(FlightStatusCode::TimedOut));
+ case grpc::StatusCode::NOT_FOUND:
+ return Status::KeyError("gRPC returned not found error, with message: ",
+ grpc_status.error_message());
+ case grpc::StatusCode::ALREADY_EXISTS:
+ return Status::AlreadyExists("gRPC returned already exists error, with message: ",
+ grpc_status.error_message());
+ case grpc::StatusCode::PERMISSION_DENIED:
+ return Status::IOError("gRPC returned permission denied error, with message: ",
+ grpc_status.error_message())
+ .WithDetail(
+ std::make_shared<FlightStatusDetail>(FlightStatusCode::Unauthorized));
+ case grpc::StatusCode::RESOURCE_EXHAUSTED:
+ return Status::Invalid("gRPC returned resource exhausted error, with message: ",
+ grpc_status.error_message());
+ case grpc::StatusCode::FAILED_PRECONDITION:
+ return Status::Invalid("gRPC returned precondition failed error, with message: ",
+ grpc_status.error_message());
+ case grpc::StatusCode::ABORTED:
+ return Status::IOError("gRPC returned aborted error, with message: ",
+ grpc_status.error_message())
+ .WithDetail(std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal));
+ case grpc::StatusCode::OUT_OF_RANGE:
+ return Status::Invalid("gRPC returned out-of-range error, with message: ",
+ grpc_status.error_message());
+ case grpc::StatusCode::UNIMPLEMENTED:
+ return Status::NotImplemented("gRPC returned unimplemented error, with message: ",
+ grpc_status.error_message());
+ case grpc::StatusCode::INTERNAL:
+ return Status::IOError("gRPC returned internal error, with message: ",
+ grpc_status.error_message())
+ .WithDetail(std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal));
+ case grpc::StatusCode::UNAVAILABLE:
+ return Status::IOError("gRPC returned unavailable error, with message: ",
+ grpc_status.error_message())
+ .WithDetail(
+ std::make_shared<FlightStatusDetail>(FlightStatusCode::Unavailable));
+ case grpc::StatusCode::DATA_LOSS:
+ return Status::IOError("gRPC returned data loss error, with message: ",
+ grpc_status.error_message())
+ .WithDetail(std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal));
+ case grpc::StatusCode::UNAUTHENTICATED:
+ return Status::IOError("gRPC returned unauthenticated error, with message: ",
+ grpc_status.error_message())
+ .WithDetail(
+ std::make_shared<FlightStatusDetail>(FlightStatusCode::Unauthenticated));
+ default:
+ return Status::UnknownError("gRPC failed with error code ",
+ grpc_status.error_code(),
+ " and message: ", grpc_status.error_message());
+ }
+}
+
+Status FromGrpcStatus(const grpc::Status& grpc_status, grpc::ClientContext* ctx) {
+ const Status status = FromGrpcCode(grpc_status);
+
+ if (!status.ok() && ctx) {
+ Status arrow_status;
+
+ if (!FromGrpcContext(*ctx, &arrow_status, FlightStatusDetail::UnwrapStatus(status))
+ .ok()) {
+ // If we fail to decode a more detailed status from the headers,
+ // proceed normally
+ return status;
+ }
+
+ return arrow_status;
+ }
+ return status;
+}
+
+/// Convert an Arrow status to a gRPC status.
+static grpc::Status ToRawGrpcStatus(const Status& arrow_status) {
+ if (arrow_status.ok()) {
+ return grpc::Status::OK;
+ }
+
+ grpc::StatusCode grpc_code = grpc::StatusCode::UNKNOWN;
+ std::string message = arrow_status.message();
+ if (arrow_status.detail()) {
+ message += ". Detail: ";
+ message += arrow_status.detail()->ToString();
+ }
+
+ std::shared_ptr<FlightStatusDetail> flight_status =
+ FlightStatusDetail::UnwrapStatus(arrow_status);
+ if (flight_status) {
+ switch (flight_status->code()) {
+ case FlightStatusCode::Internal:
+ grpc_code = grpc::StatusCode::INTERNAL;
+ break;
+ case FlightStatusCode::TimedOut:
+ grpc_code = grpc::StatusCode::DEADLINE_EXCEEDED;
+ break;
+ case FlightStatusCode::Cancelled:
+ grpc_code = grpc::StatusCode::CANCELLED;
+ break;
+ case FlightStatusCode::Unauthenticated:
+ grpc_code = grpc::StatusCode::UNAUTHENTICATED;
+ break;
+ case FlightStatusCode::Unauthorized:
+ grpc_code = grpc::StatusCode::PERMISSION_DENIED;
+ break;
+ case FlightStatusCode::Unavailable:
+ grpc_code = grpc::StatusCode::UNAVAILABLE;
+ break;
+ default:
+ break;
+ }
+ } else if (arrow_status.IsNotImplemented()) {
+ grpc_code = grpc::StatusCode::UNIMPLEMENTED;
+ } else if (arrow_status.IsInvalid()) {
+ grpc_code = grpc::StatusCode::INVALID_ARGUMENT;
+ } else if (arrow_status.IsKeyError()) {
+ grpc_code = grpc::StatusCode::NOT_FOUND;
+ } else if (arrow_status.IsAlreadyExists()) {
+ grpc_code = grpc::StatusCode::ALREADY_EXISTS;
+ }
+ return grpc::Status(grpc_code, message);
+}
+
+/// Convert an Arrow status to a gRPC status, and add extra headers to
+/// the response to encode the original Arrow status.
+grpc::Status ToGrpcStatus(const Status& arrow_status, grpc::ServerContext* ctx) {
+ grpc::Status status = ToRawGrpcStatus(arrow_status);
+ if (!status.ok() && ctx) {
+ const std::string code = std::to_string(static_cast<int>(arrow_status.code()));
+ ctx->AddTrailingMetadata(internal::kGrpcStatusCodeHeader, code);
+ ctx->AddTrailingMetadata(internal::kGrpcStatusMessageHeader, arrow_status.message());
+ if (arrow_status.detail()) {
+ const std::string detail_string = arrow_status.detail()->ToString();
+ ctx->AddTrailingMetadata(internal::kGrpcStatusDetailHeader, detail_string);
+ }
+ auto fsd = FlightStatusDetail::UnwrapStatus(arrow_status);
+ if (fsd && !fsd->extra_info().empty()) {
+ ctx->AddTrailingMetadata(internal::kBinaryErrorDetailsKey, fsd->extra_info());
+ }
+ }
+
+ return status;
+}
+
+// ActionType
+
+Status FromProto(const pb::ActionType& pb_type, ActionType* type) {
+ type->type = pb_type.type();
+ type->description = pb_type.description();
+ return Status::OK();
+}
+
+Status ToProto(const ActionType& type, pb::ActionType* pb_type) {
+ pb_type->set_type(type.type);
+ pb_type->set_description(type.description);
+ return Status::OK();
+}
+
+// Action
+
+Status FromProto(const pb::Action& pb_action, Action* action) {
+ action->type = pb_action.type();
+ action->body = Buffer::FromString(pb_action.body());
+ return Status::OK();
+}
+
+Status ToProto(const Action& action, pb::Action* pb_action) {
+ pb_action->set_type(action.type);
+ if (action.body) {
+ pb_action->set_body(action.body->ToString());
+ }
+ return Status::OK();
+}
+
+// Result (of an Action)
+
+Status FromProto(const pb::Result& pb_result, Result* result) {
+ // ARROW-3250; can avoid copy. Can also write custom deserializer if it
+ // becomes an issue
+ result->body = Buffer::FromString(pb_result.body());
+ return Status::OK();
+}
+
+Status ToProto(const Result& result, pb::Result* pb_result) {
+ pb_result->set_body(result.body->ToString());
+ return Status::OK();
+}
+
+// Criteria
+
+Status FromProto(const pb::Criteria& pb_criteria, Criteria* criteria) {
+ criteria->expression = pb_criteria.expression();
+ return Status::OK();
+}
+Status ToProto(const Criteria& criteria, pb::Criteria* pb_criteria) {
+ pb_criteria->set_expression(criteria.expression);
+ return Status::OK();
+}
+
+// Location
+
+Status FromProto(const pb::Location& pb_location, Location* location) {
+ return Location::Parse(pb_location.uri(), location);
+}
+
+void ToProto(const Location& location, pb::Location* pb_location) {
+ pb_location->set_uri(location.ToString());
+}
+
+Status ToProto(const BasicAuth& basic_auth, pb::BasicAuth* pb_basic_auth) {
+ pb_basic_auth->set_username(basic_auth.username);
+ pb_basic_auth->set_password(basic_auth.password);
+ return Status::OK();
+}
+
+// Ticket
+
+Status FromProto(const pb::Ticket& pb_ticket, Ticket* ticket) {
+ ticket->ticket = pb_ticket.ticket();
+ return Status::OK();
+}
+
+void ToProto(const Ticket& ticket, pb::Ticket* pb_ticket) {
+ pb_ticket->set_ticket(ticket.ticket);
+}
+
+// FlightData
+
+Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor,
+ std::unique_ptr<ipc::Message>* message) {
+ RETURN_NOT_OK(internal::FromProto(pb_data.flight_descriptor(), descriptor));
+ const std::string& header = pb_data.data_header();
+ const std::string& body = pb_data.data_body();
+ std::shared_ptr<Buffer> header_buf = Buffer::Wrap(header.data(), header.size());
+ std::shared_ptr<Buffer> body_buf = Buffer::Wrap(body.data(), body.size());
+ if (header_buf == nullptr || body_buf == nullptr) {
+ return Status::UnknownError("Could not create buffers from protobuf");
+ }
+ return ipc::Message::Open(header_buf, body_buf).Value(message);
+}
+
+// FlightEndpoint
+
+Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint) {
+ RETURN_NOT_OK(FromProto(pb_endpoint.ticket(), &endpoint->ticket));
+ endpoint->locations.resize(pb_endpoint.location_size());
+ for (int i = 0; i < pb_endpoint.location_size(); ++i) {
+ RETURN_NOT_OK(FromProto(pb_endpoint.location(i), &endpoint->locations[i]));
+ }
+ return Status::OK();
+}
+
+void ToProto(const FlightEndpoint& endpoint, pb::FlightEndpoint* pb_endpoint) {
+ ToProto(endpoint.ticket, pb_endpoint->mutable_ticket());
+ pb_endpoint->clear_location();
+ for (const Location& location : endpoint.locations) {
+ ToProto(location, pb_endpoint->add_location());
+ }
+}
+
+// FlightDescriptor
+
+Status FromProto(const pb::FlightDescriptor& pb_descriptor,
+ FlightDescriptor* descriptor) {
+ if (pb_descriptor.type() == pb::FlightDescriptor::PATH) {
+ descriptor->type = FlightDescriptor::PATH;
+ descriptor->path.reserve(pb_descriptor.path_size());
+ for (int i = 0; i < pb_descriptor.path_size(); ++i) {
+ descriptor->path.emplace_back(pb_descriptor.path(i));
+ }
+ } else if (pb_descriptor.type() == pb::FlightDescriptor::CMD) {
+ descriptor->type = FlightDescriptor::CMD;
+ descriptor->cmd = pb_descriptor.cmd();
+ } else {
+ return Status::Invalid("Client sent UNKNOWN descriptor type");
+ }
+ return Status::OK();
+}
+
+Status ToProto(const FlightDescriptor& descriptor, pb::FlightDescriptor* pb_descriptor) {
+ if (descriptor.type == FlightDescriptor::PATH) {
+ pb_descriptor->set_type(pb::FlightDescriptor::PATH);
+ for (const std::string& path : descriptor.path) {
+ pb_descriptor->add_path(path);
+ }
+ } else {
+ pb_descriptor->set_type(pb::FlightDescriptor::CMD);
+ pb_descriptor->set_cmd(descriptor.cmd);
+ }
+ return Status::OK();
+}
+
+// FlightInfo
+
+Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info) {
+ RETURN_NOT_OK(FromProto(pb_info.flight_descriptor(), &info->descriptor));
+
+ info->schema = pb_info.schema();
+
+ info->endpoints.resize(pb_info.endpoint_size());
+ for (int i = 0; i < pb_info.endpoint_size(); ++i) {
+ RETURN_NOT_OK(FromProto(pb_info.endpoint(i), &info->endpoints[i]));
+ }
+
+ info->total_records = pb_info.total_records();
+ info->total_bytes = pb_info.total_bytes();
+ return Status::OK();
+}
+
+Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* basic_auth) {
+ basic_auth->password = pb_basic_auth.password();
+ basic_auth->username = pb_basic_auth.username();
+
+ return Status::OK();
+}
+
+Status FromProto(const pb::SchemaResult& pb_result, std::string* result) {
+ *result = pb_result.schema();
+ return Status::OK();
+}
+
+Status SchemaToString(const Schema& schema, std::string* out) {
+ ipc::DictionaryMemo unused_dict_memo;
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> serialized_schema,
+ ipc::SerializeSchema(schema));
+ *out = std::string(reinterpret_cast<const char*>(serialized_schema->data()),
+ static_cast<size_t>(serialized_schema->size()));
+ return Status::OK();
+}
+
+Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info) {
+ // clear any repeated fields
+ pb_info->clear_endpoint();
+
+ pb_info->set_schema(info.serialized_schema());
+
+ // descriptor
+ RETURN_NOT_OK(ToProto(info.descriptor(), pb_info->mutable_flight_descriptor()));
+
+ // endpoints
+ for (const FlightEndpoint& endpoint : info.endpoints()) {
+ ToProto(endpoint, pb_info->add_endpoint());
+ }
+
+ pb_info->set_total_records(info.total_records());
+ pb_info->set_total_bytes(info.total_bytes());
+ return Status::OK();
+}
+
+Status ToProto(const SchemaResult& result, pb::SchemaResult* pb_result) {
+ pb_result->set_schema(result.serialized_schema());
+ return Status::OK();
+}
+
+Status ToPayload(const FlightDescriptor& descr, std::shared_ptr<Buffer>* out) {
+ // TODO(lidavidm): make these use Result<T>
+ std::string str_descr;
+ pb::FlightDescriptor pb_descr;
+ RETURN_NOT_OK(ToProto(descr, &pb_descr));
+ if (!pb_descr.SerializeToString(&str_descr)) {
+ return Status::UnknownError("Failed to serialize Flight descriptor");
+ }
+ *out = Buffer::FromString(std::move(str_descr));
+ return Status::OK();
+}
+
+} // namespace internal
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/internal.h b/src/arrow/cpp/src/arrow/flight/internal.h
new file mode 100644
index 000000000..c0964c68f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/internal.h
@@ -0,0 +1,128 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/flight/protocol_internal.h" // IWYU pragma: keep
+#include "arrow/flight/types.h"
+#include "arrow/util/macros.h"
+
+namespace grpc {
+
+class Status;
+
+} // namespace grpc
+
+namespace arrow {
+
+class Schema;
+class Status;
+
+namespace pb = arrow::flight::protocol;
+
+namespace ipc {
+
+class Message;
+
+} // namespace ipc
+
+namespace flight {
+
+#define GRPC_RETURN_NOT_OK(expr) \
+ do { \
+ ::arrow::Status _s = (expr); \
+ if (ARROW_PREDICT_FALSE(!_s.ok())) { \
+ return ::arrow::flight::internal::ToGrpcStatus(_s); \
+ } \
+ } while (0)
+
+#define GRPC_RETURN_NOT_GRPC_OK(expr) \
+ do { \
+ ::grpc::Status _s = (expr); \
+ if (ARROW_PREDICT_FALSE(!_s.ok())) { \
+ return _s; \
+ } \
+ } while (0)
+
+namespace internal {
+
+/// The name of the header used to pass authentication tokens.
+ARROW_FLIGHT_EXPORT
+extern const char* kGrpcAuthHeader;
+
+/// The name of the header used to pass the exact Arrow status code.
+ARROW_FLIGHT_EXPORT
+extern const char* kGrpcStatusCodeHeader;
+
+/// The name of the header used to pass the exact Arrow status message.
+ARROW_FLIGHT_EXPORT
+extern const char* kGrpcStatusMessageHeader;
+
+/// The name of the header used to pass the exact Arrow status detail.
+ARROW_FLIGHT_EXPORT
+extern const char* kGrpcStatusDetailHeader;
+
+ARROW_FLIGHT_EXPORT
+extern const char* kBinaryErrorDetailsKey;
+
+ARROW_FLIGHT_EXPORT
+Status SchemaToString(const Schema& schema, std::string* out);
+
+/// Convert a gRPC status to an Arrow status. Optionally, provide a
+/// ClientContext to recover the exact Arrow status if it was passed
+/// over the wire.
+ARROW_FLIGHT_EXPORT
+Status FromGrpcStatus(const grpc::Status& grpc_status,
+ grpc::ClientContext* ctx = nullptr);
+
+ARROW_FLIGHT_EXPORT
+grpc::Status ToGrpcStatus(const Status& arrow_status, grpc::ServerContext* ctx = nullptr);
+
+// These functions depend on protobuf types which are not exported in the Flight DLL.
+
+Status FromProto(const pb::ActionType& pb_type, ActionType* type);
+Status FromProto(const pb::Action& pb_action, Action* action);
+Status FromProto(const pb::Result& pb_result, Result* result);
+Status FromProto(const pb::Criteria& pb_criteria, Criteria* criteria);
+Status FromProto(const pb::Location& pb_location, Location* location);
+Status FromProto(const pb::Ticket& pb_ticket, Ticket* ticket);
+Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor,
+ std::unique_ptr<ipc::Message>* message);
+Status FromProto(const pb::FlightDescriptor& pb_descr, FlightDescriptor* descr);
+Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint);
+Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info);
+Status FromProto(const pb::SchemaResult& pb_result, std::string* result);
+Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* info);
+
+Status ToProto(const FlightDescriptor& descr, pb::FlightDescriptor* pb_descr);
+Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info);
+Status ToProto(const ActionType& type, pb::ActionType* pb_type);
+Status ToProto(const Action& action, pb::Action* pb_action);
+Status ToProto(const Result& result, pb::Result* pb_result);
+Status ToProto(const Criteria& criteria, pb::Criteria* pb_criteria);
+Status ToProto(const SchemaResult& result, pb::SchemaResult* pb_result);
+void ToProto(const Ticket& ticket, pb::Ticket* pb_ticket);
+Status ToProto(const BasicAuth& basic_auth, pb::BasicAuth* pb_basic_auth);
+
+Status ToPayload(const FlightDescriptor& descr, std::shared_ptr<Buffer>* out);
+
+} // namespace internal
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/middleware.h b/src/arrow/cpp/src/arrow/flight/middleware.h
new file mode 100644
index 000000000..d4f5e9320
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/middleware.h
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Interfaces for defining middleware for Flight clients and
+// servers. Currently experimental.
+
+#pragma once
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "arrow/flight/visibility.h" // IWYU pragma: keep
+#include "arrow/status.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+
+namespace flight {
+
+/// \brief Headers sent from the client or server.
+///
+/// Header values are ordered.
+using CallHeaders = std::multimap<util::string_view, util::string_view>;
+
+/// \brief A write-only wrapper around headers for an RPC call.
+class ARROW_FLIGHT_EXPORT AddCallHeaders {
+ public:
+ virtual ~AddCallHeaders() = default;
+
+ /// \brief Add a header to be sent to the client.
+ virtual void AddHeader(const std::string& key, const std::string& value) = 0;
+};
+
+/// \brief An enumeration of the RPC methods Flight implements.
+enum class FlightMethod : char {
+ Invalid = 0,
+ Handshake = 1,
+ ListFlights = 2,
+ GetFlightInfo = 3,
+ GetSchema = 4,
+ DoGet = 5,
+ DoPut = 6,
+ DoAction = 7,
+ ListActions = 8,
+ DoExchange = 9,
+};
+
+/// \brief Information about an instance of a Flight RPC.
+struct ARROW_FLIGHT_EXPORT CallInfo {
+ public:
+ /// \brief The RPC method of this call.
+ FlightMethod method;
+};
+
+} // namespace flight
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/middleware_internal.h b/src/arrow/cpp/src/arrow/flight/middleware_internal.h
new file mode 100644
index 000000000..8ee76476a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/middleware_internal.h
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Interfaces for defining middleware for Flight clients and
+// servers. Currently experimental.
+
+#pragma once
+
+#include "arrow/flight/platform.h"
+#include "arrow/flight/visibility.h" // IWYU pragma: keep
+
+#include <map>
+#include <string>
+#include <utility>
+
+#ifdef GRPCPP_PP_INCLUDE
+#include <grpcpp/grpcpp.h>
+#else
+#include <grpc++/grpc++.h>
+#endif
+
+#include "arrow/flight/middleware.h"
+
+namespace arrow {
+
+namespace flight {
+
+namespace internal {} // namespace internal
+
+} // namespace flight
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/pch.h b/src/arrow/cpp/src/arrow/flight/pch.h
new file mode 100644
index 000000000..fff107fa8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/pch.h
@@ -0,0 +1,26 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Often-used headers, for precompiling.
+// If updating this header, please make sure you check compilation speed
+// before checking in. Adding headers which are not used extremely often
+// may incur a slowdown, since it makes the precompiled header heavier to load.
+
+#include "arrow/flight/client.h"
+#include "arrow/flight/server.h"
+#include "arrow/flight/types.h"
+#include "arrow/pch.h"
diff --git a/src/arrow/cpp/src/arrow/flight/perf.proto b/src/arrow/cpp/src/arrow/flight/perf.proto
new file mode 100644
index 000000000..9123bafba
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/perf.proto
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package arrow.flight.perf;
+
+message Perf {
+ bytes schema = 1;
+ int32 stream_count = 2;
+ int64 records_per_stream = 3;
+ int32 records_per_batch = 4;
+}
+
+/*
+ * Payload of ticket
+ */
+message Token {
+
+ // definition of entire flight.
+ Perf definition = 1;
+
+ // inclusive start
+ int64 start = 2;
+
+ // exclusive end
+ int64 end = 3;
+
+} \ No newline at end of file
diff --git a/src/arrow/cpp/src/arrow/flight/perf_server.cc b/src/arrow/cpp/src/arrow/flight/perf_server.cc
new file mode 100644
index 000000000..7efd034ad
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/perf_server.cc
@@ -0,0 +1,285 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Performance server for benchmarking purposes
+
+#include <signal.h>
+#include <cstdint>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include <gflags/gflags.h>
+
+#include "arrow/array.h"
+#include "arrow/io/test_common.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/logging.h"
+
+#include "arrow/flight/api.h"
+#include "arrow/flight/internal.h"
+#include "arrow/flight/perf.pb.h"
+#include "arrow/flight/test_util.h"
+
+DEFINE_string(server_host, "localhost", "Host where the server is running on");
+DEFINE_int32(port, 31337, "Server port to listen on");
+DEFINE_string(server_unix, "", "Unix socket path where the server is running on");
+DEFINE_string(cert_file, "", "Path to TLS certificate");
+DEFINE_string(key_file, "", "Path to TLS private key");
+
+namespace perf = arrow::flight::perf;
+namespace proto = arrow::flight::protocol;
+
+namespace arrow {
+namespace flight {
+
+#define CHECK_PARSE(EXPR) \
+ do { \
+ if (!EXPR) { \
+ return Status::Invalid("cannot parse protobuf"); \
+ } \
+ } while (0)
+
+// Create record batches with a unique "a" column so we can verify on the
+// client side that the results are correct
+class PerfDataStream : public FlightDataStream {
+ public:
+ PerfDataStream(bool verify, const int64_t start, const int64_t total_records,
+ const std::shared_ptr<Schema>& schema, const ArrayVector& arrays)
+ : start_(start),
+ verify_(verify),
+ batch_length_(arrays[0]->length()),
+ total_records_(total_records),
+ records_sent_(0),
+ schema_(schema),
+ mapper_(*schema),
+ arrays_(arrays) {
+ batch_ = RecordBatch::Make(schema, batch_length_, arrays_);
+ }
+
+ std::shared_ptr<Schema> schema() override { return schema_; }
+
+ Status GetSchemaPayload(FlightPayload* payload) override {
+ return ipc::GetSchemaPayload(*schema_, ipc_options_, mapper_, &payload->ipc_message);
+ }
+
+ Status Next(FlightPayload* payload) override {
+ if (records_sent_ >= total_records_) {
+ // Signal that iteration is over
+ payload->ipc_message.metadata = nullptr;
+ return Status::OK();
+ }
+
+ if (verify_) {
+ // mutate first array
+ auto data =
+ reinterpret_cast<int64_t*>(arrays_[0]->data()->buffers[1]->mutable_data());
+ for (int64_t i = 0; i < batch_length_; ++i) {
+ data[i] = start_ + records_sent_ + i;
+ }
+ }
+
+ auto batch = batch_;
+
+ // Last partial batch
+ if (records_sent_ + batch_length_ > total_records_) {
+ batch = batch_->Slice(0, total_records_ - records_sent_);
+ records_sent_ += total_records_ - records_sent_;
+ } else {
+ records_sent_ += batch_length_;
+ }
+ return ipc::GetRecordBatchPayload(*batch, ipc_options_, &payload->ipc_message);
+ }
+
+ private:
+ const int64_t start_;
+ bool verify_;
+ const int64_t batch_length_;
+ const int64_t total_records_;
+ int64_t records_sent_;
+ std::shared_ptr<Schema> schema_;
+ ipc::DictionaryFieldMapper mapper_;
+ ipc::IpcWriteOptions ipc_options_;
+ std::shared_ptr<RecordBatch> batch_;
+ ArrayVector arrays_;
+};
+
+Status GetPerfBatches(const perf::Token& token, const std::shared_ptr<Schema>& schema,
+ bool use_verifier, std::unique_ptr<FlightDataStream>* data_stream) {
+ std::shared_ptr<ResizableBuffer> buffer;
+ std::vector<std::shared_ptr<Array>> arrays;
+
+ const int32_t length = token.definition().records_per_batch();
+ const int32_t ncolumns = 4;
+ for (int i = 0; i < ncolumns; ++i) {
+ RETURN_NOT_OK(MakeRandomByteBuffer(length * sizeof(int64_t), default_memory_pool(),
+ &buffer, static_cast<int32_t>(i) /* seed */));
+ arrays.push_back(std::make_shared<Int64Array>(length, buffer));
+ RETURN_NOT_OK(arrays.back()->Validate());
+ }
+
+ *data_stream = std::unique_ptr<FlightDataStream>(
+ new PerfDataStream(use_verifier, token.start(),
+ token.definition().records_per_stream(), schema, arrays));
+ return Status::OK();
+}
+
+class FlightPerfServer : public FlightServerBase {
+ public:
+ FlightPerfServer() : location_() {
+ perf_schema_ = schema({field("a", int64()), field("b", int64()), field("c", int64()),
+ field("d", int64())});
+ }
+
+ void SetLocation(Location location) { location_ = location; }
+
+ Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
+ std::unique_ptr<FlightInfo>* info) override {
+ perf::Perf perf_request;
+ CHECK_PARSE(perf_request.ParseFromString(request.cmd));
+
+ perf::Token token;
+ token.mutable_definition()->CopyFrom(perf_request);
+
+ std::vector<FlightEndpoint> endpoints;
+ Ticket tmp_ticket;
+ for (int64_t i = 0; i < perf_request.stream_count(); ++i) {
+ token.set_start(i * perf_request.records_per_stream());
+ token.set_end((i + 1) * perf_request.records_per_stream());
+
+ (void)token.SerializeToString(&tmp_ticket.ticket);
+
+ // All endpoints same location for now
+ endpoints.push_back(FlightEndpoint{tmp_ticket, {location_}});
+ }
+
+ uint64_t total_records =
+ perf_request.stream_count() * perf_request.records_per_stream();
+
+ FlightInfo::Data data;
+ RETURN_NOT_OK(
+ MakeFlightInfo(*perf_schema_, request, endpoints, total_records, -1, &data));
+ *info = std::unique_ptr<FlightInfo>(new FlightInfo(data));
+ return Status::OK();
+ }
+
+ Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr<FlightDataStream>* data_stream) override {
+ perf::Token token;
+ CHECK_PARSE(token.ParseFromString(request.ticket));
+ return GetPerfBatches(token, perf_schema_, false, data_stream);
+ }
+
+ Status DoPut(const ServerCallContext& context,
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer) override {
+ FlightStreamChunk chunk;
+ while (true) {
+ RETURN_NOT_OK(reader->Next(&chunk));
+ if (!chunk.data) break;
+ if (chunk.app_metadata) {
+ RETURN_NOT_OK(writer->WriteMetadata(*chunk.app_metadata));
+ }
+ }
+ return Status::OK();
+ }
+
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* result) override {
+ if (action.type == "ping") {
+ std::shared_ptr<Buffer> buf = Buffer::FromString("ok");
+ *result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
+ return Status::OK();
+ }
+ return Status::NotImplemented(action.type);
+ }
+
+ private:
+ Location location_;
+ std::shared_ptr<Schema> perf_schema_;
+};
+
+} // namespace flight
+} // namespace arrow
+
+std::unique_ptr<arrow::flight::FlightPerfServer> g_server;
+
+void Shutdown(int signal) {
+ if (g_server != nullptr) {
+ ARROW_CHECK_OK(g_server->Shutdown());
+ }
+}
+
+int main(int argc, char** argv) {
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+ g_server.reset(new arrow::flight::FlightPerfServer);
+
+ arrow::flight::Location bind_location;
+ arrow::flight::Location connect_location;
+ if (FLAGS_server_unix.empty()) {
+ if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) {
+ if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) {
+ ARROW_CHECK_OK(
+ arrow::flight::Location::ForGrpcTls("0.0.0.0", FLAGS_port, &bind_location));
+ ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTls(FLAGS_server_host, FLAGS_port,
+ &connect_location));
+ } else {
+ std::cerr << "If providing TLS cert/key, must provide both" << std::endl;
+ return 1;
+ }
+ } else {
+ ARROW_CHECK_OK(
+ arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &bind_location));
+ ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host, FLAGS_port,
+ &connect_location));
+ }
+ } else {
+ ARROW_CHECK_OK(
+ arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &bind_location));
+ ARROW_CHECK_OK(
+ arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &connect_location));
+ }
+ arrow::flight::FlightServerOptions options(bind_location);
+ if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) {
+ std::cout << "Enabling TLS" << std::endl;
+ std::ifstream cert_file(FLAGS_cert_file);
+ std::string cert((std::istreambuf_iterator<char>(cert_file)),
+ (std::istreambuf_iterator<char>()));
+ std::ifstream key_file(FLAGS_key_file);
+ std::string key((std::istreambuf_iterator<char>(key_file)),
+ (std::istreambuf_iterator<char>()));
+ options.tls_certificates.push_back(arrow::flight::CertKeyPair{cert, key});
+ }
+
+ ARROW_CHECK_OK(g_server->Init(options));
+ // Exit with a clean error code (0) on SIGTERM
+ ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
+ if (FLAGS_server_unix.empty()) {
+ std::cout << "Server host: " << FLAGS_server_host << std::endl;
+ std::cout << "Server port: " << FLAGS_port << std::endl;
+ } else {
+ std::cout << "Server unix socket: " << FLAGS_server_unix << std::endl;
+ }
+ g_server->SetLocation(connect_location);
+ ARROW_CHECK_OK(g_server->Serve());
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/flight/platform.h b/src/arrow/cpp/src/arrow/flight/platform.h
new file mode 100644
index 000000000..7f1b0954d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/platform.h
@@ -0,0 +1,32 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Internal header. Platform-specific definitions for gRPC.
+
+#pragma once
+
+#ifdef _MSC_VER
+
+// The protobuf documentation says that C4251 warnings when using the
+// library are spurious and suppressed when the build the library and
+// compiler, but must be also suppressed in downstream projects
+#pragma warning(disable : 4251)
+
+#endif // _MSC_VER
+
+#include "arrow/util/config.h" // IWYU pragma: keep
+#include "arrow/util/windows_compatibility.h" // IWYU pragma: keep
diff --git a/src/arrow/cpp/src/arrow/flight/protocol_internal.cc b/src/arrow/cpp/src/arrow/flight/protocol_internal.cc
new file mode 100644
index 000000000..9f815398e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/protocol_internal.cc
@@ -0,0 +1,26 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+
+#include "arrow/flight/protocol_internal.h"
+
+// NOTE(wesm): Including .cc files in another .cc file would ordinarily be a
+// no-no. We have customized the serialization path for FlightData, which is
+// currently only possible through some pre-processor commands that need to be
+// included before either of these files is compiled. Because we don't want to
+// edit the generated C++ files, we include them here and do our gRPC
+// customizations in protocol-internal.h
+#include "arrow/flight/Flight.grpc.pb.cc" // NOLINT
+#include "arrow/flight/Flight.pb.cc" // NOLINT
diff --git a/src/arrow/cpp/src/arrow/flight/protocol_internal.h b/src/arrow/cpp/src/arrow/flight/protocol_internal.h
new file mode 100644
index 000000000..98bf92388
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/protocol_internal.h
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+
+#pragma once
+
+// This addresses platform-specific defines, e.g. on Windows
+#include "arrow/flight/platform.h" // IWYU pragma: keep
+
+// This header holds the Flight protobuf definitions.
+
+// Need to include this first to get our gRPC customizations
+#include "arrow/flight/customize_protobuf.h" // IWYU pragma: export
+
+#include "arrow/flight/Flight.grpc.pb.h" // IWYU pragma: export
+#include "arrow/flight/Flight.pb.h" // IWYU pragma: export
diff --git a/src/arrow/cpp/src/arrow/flight/serialization_internal.cc b/src/arrow/cpp/src/arrow/flight/serialization_internal.cc
new file mode 100644
index 000000000..36c6cc9e6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/serialization_internal.cc
@@ -0,0 +1,474 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/serialization_internal.h"
+
+#include <cstdint>
+#include <limits>
+#include <string>
+#include <vector>
+
+#include "arrow/flight/platform.h"
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4267)
+#endif
+
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
+#include <google/protobuf/wire_format_lite.h>
+
+#include <grpc/byte_buffer_reader.h>
+#ifdef GRPCPP_PP_INCLUDE
+#include <grpcpp/grpcpp.h>
+#include <grpcpp/impl/codegen/proto_utils.h>
+#else
+#include <grpc++/grpc++.h>
+#include <grpc++/impl/codegen/proto_utils.h>
+#endif
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+#include "arrow/buffer.h"
+#include "arrow/flight/server.h"
+#include "arrow/ipc/message.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/logging.h"
+
+static constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();
+
+namespace arrow {
+namespace flight {
+namespace internal {
+
+namespace pb = arrow::flight::protocol;
+
+using arrow::ipc::IpcPayload;
+
+using google::protobuf::internal::WireFormatLite;
+using google::protobuf::io::ArrayOutputStream;
+using google::protobuf::io::CodedInputStream;
+using google::protobuf::io::CodedOutputStream;
+
+using grpc::ByteBuffer;
+
+bool ReadBytesZeroCopy(const std::shared_ptr<Buffer>& source_data,
+ CodedInputStream* input, std::shared_ptr<Buffer>* out) {
+ uint32_t length;
+ if (!input->ReadVarint32(&length)) {
+ return false;
+ }
+ auto buf =
+ SliceBuffer(source_data, input->CurrentPosition(), static_cast<int64_t>(length));
+ *out = buf;
+ return input->Skip(static_cast<int>(length));
+}
+
+// Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow
+// consumers with zero-copy
+class GrpcBuffer : public MutableBuffer {
+ public:
+ GrpcBuffer(grpc_slice slice, bool incref)
+ : MutableBuffer(GRPC_SLICE_START_PTR(slice),
+ static_cast<int64_t>(GRPC_SLICE_LENGTH(slice))),
+ slice_(incref ? grpc_slice_ref(slice) : slice) {}
+
+ ~GrpcBuffer() override {
+ // Decref slice
+ grpc_slice_unref(slice_);
+ }
+
+ static Status Wrap(ByteBuffer* cpp_buf, std::shared_ptr<Buffer>* out) {
+ // These types are guaranteed by static assertions in gRPC to have the same
+ // in-memory representation
+
+ auto buffer = *reinterpret_cast<grpc_byte_buffer**>(cpp_buf);
+
+ // This part below is based on the Flatbuffers gRPC SerializationTraits in
+ // flatbuffers/grpc.h
+
+ // Check if this is a single uncompressed slice.
+ if ((buffer->type == GRPC_BB_RAW) &&
+ (buffer->data.raw.compression == GRPC_COMPRESS_NONE) &&
+ (buffer->data.raw.slice_buffer.count == 1)) {
+ // If it is, then we can reference the `grpc_slice` directly.
+ grpc_slice slice = buffer->data.raw.slice_buffer.slices[0];
+
+ if (slice.refcount) {
+ // Increment reference count so this memory remains valid
+ *out = std::make_shared<GrpcBuffer>(slice, true);
+ } else {
+ // Small slices (less than GRPC_SLICE_INLINED_SIZE bytes) are
+ // inlined into the structure and must be copied.
+ const uint8_t length = slice.data.inlined.length;
+ ARROW_ASSIGN_OR_RAISE(*out, arrow::AllocateBuffer(length));
+ std::memcpy((*out)->mutable_data(), slice.data.inlined.bytes, length);
+ }
+ } else {
+ // Otherwise, we need to use `grpc_byte_buffer_reader_readall` to read
+ // `buffer` into a single contiguous `grpc_slice`. The gRPC reader gives
+ // us back a new slice with the refcount already incremented.
+ grpc_byte_buffer_reader reader;
+ if (!grpc_byte_buffer_reader_init(&reader, buffer)) {
+ return Status::IOError("Internal gRPC error reading from ByteBuffer");
+ }
+ grpc_slice slice = grpc_byte_buffer_reader_readall(&reader);
+ grpc_byte_buffer_reader_destroy(&reader);
+
+ // Steal the slice reference
+ *out = std::make_shared<GrpcBuffer>(slice, false);
+ }
+
+ return Status::OK();
+ }
+
+ private:
+ grpc_slice slice_;
+};
+
+// Destructor callback for grpc::Slice
+static void ReleaseBuffer(void* buf_ptr) {
+ delete reinterpret_cast<std::shared_ptr<Buffer>*>(buf_ptr);
+}
+
+// Initialize gRPC Slice from arrow Buffer
+grpc::Slice SliceFromBuffer(const std::shared_ptr<Buffer>& buf) {
+ // Allocate persistent shared_ptr to control Buffer lifetime
+ auto ptr = new std::shared_ptr<Buffer>(buf);
+ grpc::Slice slice(const_cast<uint8_t*>(buf->data()), static_cast<size_t>(buf->size()),
+ &ReleaseBuffer, ptr);
+ // Make sure no copy was done (some grpc::Slice() constructors do an implicit memcpy)
+ DCHECK_EQ(slice.begin(), buf->data());
+ return slice;
+}
+
+static const uint8_t kPaddingBytes[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+
+// Update the sizes of our Protobuf fields based on the given IPC payload.
+grpc::Status IpcMessageHeaderSize(const arrow::ipc::IpcPayload& ipc_msg, bool has_body,
+ size_t* header_size, int32_t* metadata_size) {
+ DCHECK_LE(ipc_msg.metadata->size(), kInt32Max);
+ *metadata_size = static_cast<int32_t>(ipc_msg.metadata->size());
+
+ // 1 byte for metadata tag
+ *header_size += 1 + WireFormatLite::LengthDelimitedSize(*metadata_size);
+
+ // 2 bytes for body tag
+ if (has_body) {
+ // We write the body tag in the header but not the actual body data
+ *header_size += 2 + WireFormatLite::LengthDelimitedSize(ipc_msg.body_length) -
+ ipc_msg.body_length;
+ }
+
+ return grpc::Status::OK;
+}
+
+grpc::Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out,
+ bool* own_buffer) {
+ // Size of the IPC body (protobuf: data_body)
+ size_t body_size = 0;
+ // Size of the Protobuf "header" (everything except for the body)
+ size_t header_size = 0;
+ // Size of IPC header metadata (protobuf: data_header)
+ int32_t metadata_size = 0;
+
+ // Write the descriptor if present
+ int32_t descriptor_size = 0;
+ if (msg.descriptor != nullptr) {
+ DCHECK_LE(msg.descriptor->size(), kInt32Max);
+ descriptor_size = static_cast<int32_t>(msg.descriptor->size());
+ header_size += 1 + WireFormatLite::LengthDelimitedSize(descriptor_size);
+ }
+
+ // App metadata tag if appropriate
+ int32_t app_metadata_size = 0;
+ if (msg.app_metadata && msg.app_metadata->size() > 0) {
+ DCHECK_LE(msg.app_metadata->size(), kInt32Max);
+ app_metadata_size = static_cast<int32_t>(msg.app_metadata->size());
+ header_size += 1 + WireFormatLite::LengthDelimitedSize(app_metadata_size);
+ }
+
+ const arrow::ipc::IpcPayload& ipc_msg = msg.ipc_message;
+ // No data in this payload (metadata-only).
+ bool has_ipc = ipc_msg.type != ipc::MessageType::NONE;
+ bool has_body = has_ipc ? ipc::Message::HasBody(ipc_msg.type) : false;
+
+ if (has_ipc) {
+ DCHECK(has_body || ipc_msg.body_length == 0);
+ GRPC_RETURN_NOT_GRPC_OK(
+ IpcMessageHeaderSize(ipc_msg, has_body, &header_size, &metadata_size));
+ body_size = static_cast<size_t>(ipc_msg.body_length);
+ }
+
+ // TODO(wesm): messages over 2GB unlikely to be yet supported
+ // Validated in WritePayload since returning error here causes gRPC to fail an assertion
+ DCHECK_LE(body_size, kInt32Max);
+
+ // Allocate and initialize slices
+ std::vector<grpc::Slice> slices;
+ slices.emplace_back(header_size);
+
+ // Force the header_stream to be destructed, which actually flushes
+ // the data into the slice.
+ {
+ ArrayOutputStream header_writer(const_cast<uint8_t*>(slices[0].begin()),
+ static_cast<int>(slices[0].size()));
+ CodedOutputStream header_stream(&header_writer);
+
+ // Write descriptor
+ if (msg.descriptor != nullptr) {
+ WireFormatLite::WriteTag(pb::FlightData::kFlightDescriptorFieldNumber,
+ WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream);
+ header_stream.WriteVarint32(descriptor_size);
+ header_stream.WriteRawMaybeAliased(msg.descriptor->data(),
+ static_cast<int>(msg.descriptor->size()));
+ }
+
+ // Write header
+ if (has_ipc) {
+ WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber,
+ WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream);
+ header_stream.WriteVarint32(metadata_size);
+ header_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(),
+ static_cast<int>(ipc_msg.metadata->size()));
+ }
+
+ // Write app metadata
+ if (app_metadata_size > 0) {
+ WireFormatLite::WriteTag(pb::FlightData::kAppMetadataFieldNumber,
+ WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream);
+ header_stream.WriteVarint32(app_metadata_size);
+ header_stream.WriteRawMaybeAliased(msg.app_metadata->data(),
+ static_cast<int>(msg.app_metadata->size()));
+ }
+
+ if (has_body) {
+ // Write body tag
+ WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber,
+ WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream);
+ header_stream.WriteVarint32(static_cast<uint32_t>(body_size));
+
+ // Enqueue body buffers for writing, without copying
+ for (const auto& buffer : ipc_msg.body_buffers) {
+ // Buffer may be null when the row length is zero, or when all
+ // entries are invalid.
+ if (!buffer) continue;
+
+ slices.push_back(SliceFromBuffer(buffer));
+
+ // Write padding if not multiple of 8
+ const auto remainder = static_cast<int>(
+ BitUtil::RoundUpToMultipleOf8(buffer->size()) - buffer->size());
+ if (remainder) {
+ slices.push_back(grpc::Slice(kPaddingBytes, remainder));
+ }
+ }
+ }
+
+ DCHECK_EQ(static_cast<int>(header_size), header_stream.ByteCount());
+ }
+
+ // Hand off the slices to the returned ByteBuffer
+ *out = grpc::ByteBuffer(slices.data(), slices.size());
+ *own_buffer = true;
+ return grpc::Status::OK;
+}
+
+// Read internal::FlightData from grpc::ByteBuffer containing FlightData
+// protobuf without copying
+grpc::Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) {
+ if (!buffer) {
+ return grpc::Status(grpc::StatusCode::INTERNAL, "No payload");
+ }
+
+ // Reset fields in case the caller reuses a single allocation
+ out->descriptor = nullptr;
+ out->app_metadata = nullptr;
+ out->metadata = nullptr;
+ out->body = nullptr;
+
+ std::shared_ptr<arrow::Buffer> wrapped_buffer;
+ GRPC_RETURN_NOT_OK(GrpcBuffer::Wrap(buffer, &wrapped_buffer));
+
+ auto buffer_length = static_cast<int>(wrapped_buffer->size());
+ CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length);
+
+ pb_stream.SetTotalBytesLimit(buffer_length);
+
+ // This is the bytes remaining when using CodedInputStream like this
+ while (pb_stream.BytesUntilTotalBytesLimit()) {
+ const uint32_t tag = pb_stream.ReadTag();
+ const int field_number = WireFormatLite::GetTagFieldNumber(tag);
+ switch (field_number) {
+ case pb::FlightData::kFlightDescriptorFieldNumber: {
+ pb::FlightDescriptor pb_descriptor;
+ uint32_t length;
+ if (!pb_stream.ReadVarint32(&length)) {
+ return grpc::Status(grpc::StatusCode::INTERNAL,
+ "Unable to parse length of FlightDescriptor");
+ }
+ // Can't use ParseFromCodedStream as this reads the entire
+ // rest of the stream into the descriptor command field.
+ std::string buffer;
+ pb_stream.ReadString(&buffer, length);
+ if (!pb_descriptor.ParseFromString(buffer)) {
+ return grpc::Status(grpc::StatusCode::INTERNAL,
+ "Unable to parse FlightDescriptor");
+ }
+ arrow::flight::FlightDescriptor descriptor;
+ GRPC_RETURN_NOT_OK(
+ arrow::flight::internal::FromProto(pb_descriptor, &descriptor));
+ out->descriptor.reset(new arrow::flight::FlightDescriptor(descriptor));
+ } break;
+ case pb::FlightData::kDataHeaderFieldNumber: {
+ if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) {
+ return grpc::Status(grpc::StatusCode::INTERNAL,
+ "Unable to read FlightData metadata");
+ }
+ } break;
+ case pb::FlightData::kAppMetadataFieldNumber: {
+ if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->app_metadata)) {
+ return grpc::Status(grpc::StatusCode::INTERNAL,
+ "Unable to read FlightData application metadata");
+ }
+ } break;
+ case pb::FlightData::kDataBodyFieldNumber: {
+ if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) {
+ return grpc::Status(grpc::StatusCode::INTERNAL,
+ "Unable to read FlightData body");
+ }
+ } break;
+ default:
+ DCHECK(false) << "cannot happen";
+ }
+ }
+ buffer->Clear();
+
+ // TODO(wesm): Where and when should we verify that the FlightData is not
+ // malformed?
+
+ // Set the default value for an unspecified FlightData body. The other
+ // fields can be null if they're unspecified.
+ if (out->body == nullptr) {
+ out->body = std::make_shared<Buffer>(nullptr, 0);
+ }
+
+ return grpc::Status::OK;
+}
+
+::arrow::Result<std::unique_ptr<ipc::Message>> FlightData::OpenMessage() {
+ return ipc::Message::Open(metadata, body);
+}
+
+// The pointer bitcast hack below causes legitimate warnings, silence them.
+#ifndef _WIN32
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wstrict-aliasing"
+#endif
+
+// Pointer bitcast explanation: grpc::*Writer<T>::Write() and grpc::*Reader<T>::Read()
+// both take a T* argument (here pb::FlightData*). But they don't do anything
+// with that argument except pass it to SerializationTraits<T>::Serialize() and
+// SerializationTraits<T>::Deserialize().
+//
+// Since we control SerializationTraits<pb::FlightData>, we can interpret the
+// pointer argument whichever way we want, including cast it back to the original type.
+// (see customize_protobuf.h).
+
+Status WritePayload(const FlightPayload& payload,
+ grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>* writer) {
+ RETURN_NOT_OK(payload.Validate());
+ // Pretend to be pb::FlightData and intercept in SerializationTraits
+ if (!writer->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
+ grpc::WriteOptions())) {
+ return Status::IOError("Could not write payload to stream");
+ }
+ return Status::OK();
+}
+
+Status WritePayload(const FlightPayload& payload,
+ grpc::ClientReaderWriter<pb::FlightData, pb::FlightData>* writer) {
+ RETURN_NOT_OK(payload.Validate());
+ // Pretend to be pb::FlightData and intercept in SerializationTraits
+ if (!writer->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
+ grpc::WriteOptions())) {
+ return Status::IOError("Could not write payload to stream");
+ }
+ return Status::OK();
+}
+
+Status WritePayload(const FlightPayload& payload,
+ grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* writer) {
+ RETURN_NOT_OK(payload.Validate());
+ // Pretend to be pb::FlightData and intercept in SerializationTraits
+ if (!writer->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
+ grpc::WriteOptions())) {
+ return Status::IOError("Could not write payload to stream");
+ }
+ return Status::OK();
+}
+
+Status WritePayload(const FlightPayload& payload,
+ grpc::ServerWriter<pb::FlightData>* writer) {
+ RETURN_NOT_OK(payload.Validate());
+ // Pretend to be pb::FlightData and intercept in SerializationTraits
+ if (!writer->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
+ grpc::WriteOptions())) {
+ return Status::IOError("Could not write payload to stream");
+ }
+ return Status::OK();
+}
+
+bool ReadPayload(grpc::ClientReader<pb::FlightData>* reader, FlightData* data) {
+ // Pretend to be pb::FlightData and intercept in SerializationTraits
+ return reader->Read(reinterpret_cast<pb::FlightData*>(data));
+}
+
+bool ReadPayload(grpc::ClientReaderWriter<pb::FlightData, pb::FlightData>* reader,
+ FlightData* data) {
+ // Pretend to be pb::FlightData and intercept in SerializationTraits
+ return reader->Read(reinterpret_cast<pb::FlightData*>(data));
+}
+
+bool ReadPayload(grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* reader,
+ FlightData* data) {
+ // Pretend to be pb::FlightData and intercept in SerializationTraits
+ return reader->Read(reinterpret_cast<pb::FlightData*>(data));
+}
+
+bool ReadPayload(grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* reader,
+ FlightData* data) {
+ // Pretend to be pb::FlightData and intercept in SerializationTraits
+ return reader->Read(reinterpret_cast<pb::FlightData*>(data));
+}
+
+bool ReadPayload(grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>* reader,
+ pb::PutResult* data) {
+ return reader->Read(data);
+}
+
+#ifndef _WIN32
+#pragma GCC diagnostic pop
+#endif
+
+} // namespace internal
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/serialization_internal.h b/src/arrow/cpp/src/arrow/flight/serialization_internal.h
new file mode 100644
index 000000000..5f7d0cc48
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/serialization_internal.h
@@ -0,0 +1,152 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// (De)serialization utilities that hook into gRPC, efficiently
+// handling Arrow-encoded data in a gRPC call.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/flight/internal.h"
+#include "arrow/flight/types.h"
+#include "arrow/ipc/message.h"
+#include "arrow/result.h"
+
+namespace arrow {
+
+class Buffer;
+
+namespace flight {
+namespace internal {
+
+/// Internal, not user-visible type used for memory-efficient reads from gRPC
+/// stream
+struct FlightData {
+ /// Used only for puts, may be null
+ std::unique_ptr<FlightDescriptor> descriptor;
+
+ /// Non-length-prefixed Message header as described in format/Message.fbs
+ std::shared_ptr<Buffer> metadata;
+
+ /// Application-defined metadata
+ std::shared_ptr<Buffer> app_metadata;
+
+ /// Message body
+ std::shared_ptr<Buffer> body;
+
+ /// Open IPC message from the metadata and body
+ ::arrow::Result<std::unique_ptr<ipc::Message>> OpenMessage();
+};
+
+/// Write Flight message on gRPC stream with zero-copy optimizations.
+// Returns Invalid if the payload is ill-formed
+// Returns IOError if gRPC did not write the message (note this is not
+// necessarily an error - the client may simply have gone away)
+Status WritePayload(const FlightPayload& payload,
+ grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>* writer);
+Status WritePayload(const FlightPayload& payload,
+ grpc::ClientReaderWriter<pb::FlightData, pb::FlightData>* writer);
+Status WritePayload(const FlightPayload& payload,
+ grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* writer);
+Status WritePayload(const FlightPayload& payload,
+ grpc::ServerWriter<pb::FlightData>* writer);
+
+/// Read Flight message from gRPC stream with zero-copy optimizations.
+/// True is returned on success, false if stream ended.
+bool ReadPayload(grpc::ClientReader<pb::FlightData>* reader, FlightData* data);
+bool ReadPayload(grpc::ClientReaderWriter<pb::FlightData, pb::FlightData>* reader,
+ FlightData* data);
+bool ReadPayload(grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* reader,
+ FlightData* data);
+bool ReadPayload(grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* reader,
+ FlightData* data);
+// Overload to make genericity easier in DoPutPayloadWriter
+bool ReadPayload(grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>* reader,
+ pb::PutResult* data);
+
+// We want to reuse RecordBatchStreamReader's implementation while
+// (1) Adapting it to the Flight message format
+// (2) Allowing pure-metadata messages before data is sent
+// (3) Reusing the reader implementation between DoGet and DoExchange.
+// To do this, we wrap the gRPC reader in a peekable iterator.
+// The Flight reader can then peek at the message to determine whether
+// it has application metadata or not, and pass the message to
+// RecordBatchStreamReader as appropriate.
+template <typename ReaderPtr>
+class PeekableFlightDataReader {
+ public:
+ explicit PeekableFlightDataReader(ReaderPtr stream)
+ : stream_(stream), peek_(), finished_(false), valid_(false) {}
+
+ void Peek(internal::FlightData** out) {
+ *out = nullptr;
+ if (finished_) {
+ return;
+ }
+ if (EnsurePeek()) {
+ *out = &peek_;
+ }
+ }
+
+ void Next(internal::FlightData** out) {
+ Peek(out);
+ valid_ = false;
+ }
+
+ /// \brief Peek() until the first data message.
+ ///
+ /// After this is called, either this will return \a false, or the
+ /// next result of \a Peek and \a Next will contain Arrow data.
+ bool SkipToData() {
+ FlightData* data;
+ while (true) {
+ Peek(&data);
+ if (!data) {
+ return false;
+ }
+ if (data->metadata) {
+ return true;
+ }
+ Next(&data);
+ }
+ }
+
+ private:
+ bool EnsurePeek() {
+ if (finished_ || valid_) {
+ return valid_;
+ }
+
+ if (!internal::ReadPayload(&*stream_, &peek_)) {
+ finished_ = true;
+ valid_ = false;
+ } else {
+ valid_ = true;
+ }
+ return valid_;
+ }
+
+ ReaderPtr stream_;
+ internal::FlightData peek_;
+ bool finished_;
+ bool valid_;
+};
+
+} // namespace internal
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/server.cc b/src/arrow/cpp/src/arrow/flight/server.cc
new file mode 100644
index 000000000..b52c16246
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/server.cc
@@ -0,0 +1,1165 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Platform-specific defines
+#include "arrow/flight/platform.h"
+
+#include "arrow/flight/server.h"
+
+#ifdef _WIN32
+#include <io.h>
+#else
+#include <fcntl.h>
+#include <unistd.h>
+#endif
+#include <atomic>
+#include <cerrno>
+#include <cstdint>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <utility>
+
+#ifdef GRPCPP_PP_INCLUDE
+#include <grpcpp/grpcpp.h>
+#else
+#include <grpc++/grpc++.h>
+#endif
+
+#include "arrow/buffer.h"
+#include "arrow/ipc/dictionary.h"
+#include "arrow/ipc/options.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/memory_pool.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/uri.h"
+
+#include "arrow/flight/internal.h"
+#include "arrow/flight/middleware.h"
+#include "arrow/flight/middleware_internal.h"
+#include "arrow/flight/serialization_internal.h"
+#include "arrow/flight/server_auth.h"
+#include "arrow/flight/server_middleware.h"
+#include "arrow/flight/types.h"
+
+using FlightService = arrow::flight::protocol::FlightService;
+using ServerContext = grpc::ServerContext;
+
+template <typename T>
+using ServerWriter = grpc::ServerWriter<T>;
+
+namespace arrow {
+namespace flight {
+
+namespace pb = arrow::flight::protocol;
+
+// Macro that runs interceptors before returning the given status
+#define RETURN_WITH_MIDDLEWARE(CONTEXT, STATUS) \
+ do { \
+ const auto& __s = (STATUS); \
+ return CONTEXT.FinishRequest(__s); \
+ } while (false)
+
+#define CHECK_ARG_NOT_NULL(CONTEXT, VAL, MESSAGE) \
+ if (VAL == nullptr) { \
+ RETURN_WITH_MIDDLEWARE(CONTEXT, \
+ grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE)); \
+ }
+
+// Same as RETURN_NOT_OK, but accepts either Arrow or gRPC status, and
+// will run interceptors
+#define SERVICE_RETURN_NOT_OK(CONTEXT, expr) \
+ do { \
+ const auto& _s = (expr); \
+ if (ARROW_PREDICT_FALSE(!_s.ok())) { \
+ return CONTEXT.FinishRequest(_s); \
+ } \
+ } while (false)
+
+namespace {
+
+// A MessageReader implementation that reads from a gRPC ServerReader.
+// Templated to be generic over DoPut/DoExchange.
+template <typename Reader>
+class FlightIpcMessageReader : public ipc::MessageReader {
+ public:
+ explicit FlightIpcMessageReader(
+ std::shared_ptr<internal::PeekableFlightDataReader<Reader*>> peekable_reader,
+ std::shared_ptr<Buffer>* app_metadata)
+ : peekable_reader_(peekable_reader), app_metadata_(app_metadata) {}
+
+ ::arrow::Result<std::unique_ptr<ipc::Message>> ReadNextMessage() override {
+ if (stream_finished_) {
+ return nullptr;
+ }
+ internal::FlightData* data;
+ peekable_reader_->Next(&data);
+ if (!data) {
+ stream_finished_ = true;
+ if (first_message_) {
+ return Status::Invalid(
+ "Client provided malformed message or did not provide message");
+ }
+ return nullptr;
+ }
+ *app_metadata_ = std::move(data->app_metadata);
+ return data->OpenMessage();
+ }
+
+ protected:
+ std::shared_ptr<internal::PeekableFlightDataReader<Reader*>> peekable_reader_;
+ // A reference to FlightMessageReaderImpl.app_metadata_. That class
+ // can't access the app metadata because when it Peek()s the stream,
+ // it may be looking at a dictionary batch, not the record
+ // batch. Updating it here ensures the reader is always updated with
+ // the last metadata message read.
+ std::shared_ptr<Buffer>* app_metadata_;
+ bool first_message_ = true;
+ bool stream_finished_ = false;
+};
+
+template <typename WritePayload>
+class FlightMessageReaderImpl : public FlightMessageReader {
+ public:
+ using GrpcStream = grpc::ServerReaderWriter<WritePayload, pb::FlightData>;
+
+ explicit FlightMessageReaderImpl(GrpcStream* reader)
+ : reader_(reader),
+ peekable_reader_(new internal::PeekableFlightDataReader<GrpcStream*>(reader)) {}
+
+ Status Init() {
+ // Peek the first message to get the descriptor.
+ internal::FlightData* data;
+ peekable_reader_->Peek(&data);
+ if (!data) {
+ return Status::IOError("Stream finished before first message sent");
+ }
+ if (!data->descriptor) {
+ return Status::IOError("Descriptor missing on first message");
+ }
+ descriptor_ = *data->descriptor.get(); // Copy
+ // If there's a schema (=DoPut), also Open().
+ if (data->metadata) {
+ return EnsureDataStarted();
+ }
+ peekable_reader_->Next(&data);
+ return Status::OK();
+ }
+
+ const FlightDescriptor& descriptor() const override { return descriptor_; }
+
+ arrow::Result<std::shared_ptr<Schema>> GetSchema() override {
+ RETURN_NOT_OK(EnsureDataStarted());
+ return batch_reader_->schema();
+ }
+
+ Status Next(FlightStreamChunk* out) override {
+ internal::FlightData* data;
+ peekable_reader_->Peek(&data);
+ if (!data) {
+ out->app_metadata = nullptr;
+ out->data = nullptr;
+ return Status::OK();
+ }
+
+ if (!data->metadata) {
+ // Metadata-only (data->metadata is the IPC header)
+ out->app_metadata = data->app_metadata;
+ out->data = nullptr;
+ peekable_reader_->Next(&data);
+ return Status::OK();
+ }
+
+ if (!batch_reader_) {
+ RETURN_NOT_OK(EnsureDataStarted());
+ // re-peek here since EnsureDataStarted() advances the stream
+ return Next(out);
+ }
+ RETURN_NOT_OK(batch_reader_->ReadNext(&out->data));
+ out->app_metadata = std::move(app_metadata_);
+ return Status::OK();
+ }
+
+ private:
+ /// Ensure we are set up to read data.
+ Status EnsureDataStarted() {
+ if (!batch_reader_) {
+ // peek() until we find the first data message; discard metadata
+ if (!peekable_reader_->SkipToData()) {
+ return Status::IOError("Client never sent a data message");
+ }
+ auto message_reader = std::unique_ptr<ipc::MessageReader>(
+ new FlightIpcMessageReader<GrpcStream>(peekable_reader_, &app_metadata_));
+ ARROW_ASSIGN_OR_RAISE(
+ batch_reader_, ipc::RecordBatchStreamReader::Open(std::move(message_reader)));
+ }
+ return Status::OK();
+ }
+
+ FlightDescriptor descriptor_;
+ GrpcStream* reader_;
+ std::shared_ptr<internal::PeekableFlightDataReader<GrpcStream*>> peekable_reader_;
+ std::shared_ptr<RecordBatchReader> batch_reader_;
+ std::shared_ptr<Buffer> app_metadata_;
+};
+
+class GrpcMetadataWriter : public FlightMetadataWriter {
+ public:
+ explicit GrpcMetadataWriter(
+ grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* writer)
+ : writer_(writer) {}
+
+ Status WriteMetadata(const Buffer& buffer) override {
+ pb::PutResult message{};
+ message.set_app_metadata(buffer.data(), buffer.size());
+ if (writer_->Write(message)) {
+ return Status::OK();
+ }
+ return Status::IOError("Unknown error writing metadata.");
+ }
+
+ private:
+ grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* writer_;
+};
+
+class GrpcServerAuthReader : public ServerAuthReader {
+ public:
+ explicit GrpcServerAuthReader(
+ grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream)
+ : stream_(stream) {}
+
+ Status Read(std::string* token) override {
+ pb::HandshakeRequest request;
+ if (stream_->Read(&request)) {
+ *token = std::move(*request.mutable_payload());
+ return Status::OK();
+ }
+ return Status::IOError("Stream is closed.");
+ }
+
+ private:
+ grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream_;
+};
+
+class GrpcServerAuthSender : public ServerAuthSender {
+ public:
+ explicit GrpcServerAuthSender(
+ grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream)
+ : stream_(stream) {}
+
+ Status Write(const std::string& token) override {
+ pb::HandshakeResponse response;
+ response.set_payload(token);
+ if (stream_->Write(response)) {
+ return Status::OK();
+ }
+ return Status::IOError("Stream was closed.");
+ }
+
+ private:
+ grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream_;
+};
+
+/// The implementation of the write side of a bidirectional FlightData
+/// stream for DoExchange.
+class DoExchangeMessageWriter : public FlightMessageWriter {
+ public:
+ explicit DoExchangeMessageWriter(
+ grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* stream)
+ : stream_(stream), ipc_options_(::arrow::ipc::IpcWriteOptions::Defaults()) {}
+
+ Status Begin(const std::shared_ptr<Schema>& schema,
+ const ipc::IpcWriteOptions& options) override {
+ if (started_) {
+ return Status::Invalid("This writer has already been started.");
+ }
+ started_ = true;
+ ipc_options_ = options;
+
+ RETURN_NOT_OK(mapper_.AddSchemaFields(*schema));
+ FlightPayload schema_payload;
+ RETURN_NOT_OK(ipc::GetSchemaPayload(*schema, ipc_options_, mapper_,
+ &schema_payload.ipc_message));
+ return WritePayload(schema_payload);
+ }
+
+ Status WriteRecordBatch(const RecordBatch& batch) override {
+ return WriteWithMetadata(batch, nullptr);
+ }
+
+ Status WriteMetadata(std::shared_ptr<Buffer> app_metadata) override {
+ FlightPayload payload{};
+ payload.app_metadata = app_metadata;
+ return WritePayload(payload);
+ }
+
+ Status WriteWithMetadata(const RecordBatch& batch,
+ std::shared_ptr<Buffer> app_metadata) override {
+ RETURN_NOT_OK(CheckStarted());
+ RETURN_NOT_OK(EnsureDictionariesWritten(batch));
+ FlightPayload payload{};
+ if (app_metadata) {
+ payload.app_metadata = app_metadata;
+ }
+ RETURN_NOT_OK(ipc::GetRecordBatchPayload(batch, ipc_options_, &payload.ipc_message));
+ RETURN_NOT_OK(WritePayload(payload));
+ ++stats_.num_record_batches;
+ return Status::OK();
+ }
+
+ Status Close() override {
+ // It's fine to Close() without writing data
+ return Status::OK();
+ }
+
+ ipc::WriteStats stats() const override { return stats_; }
+
+ private:
+ Status WritePayload(const FlightPayload& payload) {
+ RETURN_NOT_OK(internal::WritePayload(payload, stream_));
+ ++stats_.num_messages;
+ return Status::OK();
+ }
+
+ Status CheckStarted() {
+ if (!started_) {
+ return Status::Invalid("This writer is not started. Call Begin() with a schema");
+ }
+ return Status::OK();
+ }
+
+ Status EnsureDictionariesWritten(const RecordBatch& batch) {
+ if (dictionaries_written_) {
+ return Status::OK();
+ }
+ dictionaries_written_ = true;
+ ARROW_ASSIGN_OR_RAISE(const auto dictionaries,
+ ipc::CollectDictionaries(batch, mapper_));
+ for (const auto& pair : dictionaries) {
+ FlightPayload payload{};
+ RETURN_NOT_OK(ipc::GetDictionaryPayload(pair.first, pair.second, ipc_options_,
+ &payload.ipc_message));
+ RETURN_NOT_OK(WritePayload(payload));
+ ++stats_.num_dictionary_batches;
+ }
+ return Status::OK();
+ }
+
+ grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* stream_;
+ ::arrow::ipc::IpcWriteOptions ipc_options_;
+ ipc::DictionaryFieldMapper mapper_;
+ ipc::WriteStats stats_;
+ bool started_ = false;
+ bool dictionaries_written_ = false;
+};
+
+class FlightServiceImpl;
+class GrpcServerCallContext : public ServerCallContext {
+ explicit GrpcServerCallContext(grpc::ServerContext* context)
+ : context_(context), peer_(context_->peer()) {}
+
+ const std::string& peer_identity() const override { return peer_identity_; }
+ const std::string& peer() const override { return peer_; }
+ bool is_cancelled() const override { return context_->IsCancelled(); }
+
+ // Helper method that runs interceptors given the result of an RPC,
+ // then returns the final gRPC status to send to the client
+ grpc::Status FinishRequest(const grpc::Status& status) {
+ // Don't double-convert status - return the original one here
+ FinishRequest(internal::FromGrpcStatus(status));
+ return status;
+ }
+
+ grpc::Status FinishRequest(const arrow::Status& status) {
+ for (const auto& instance : middleware_) {
+ instance->CallCompleted(status);
+ }
+
+ // Set custom headers to map the exact Arrow status for clients
+ // who want it.
+ return internal::ToGrpcStatus(status, context_);
+ }
+
+ ServerMiddleware* GetMiddleware(const std::string& key) const override {
+ const auto& instance = middleware_map_.find(key);
+ if (instance == middleware_map_.end()) {
+ return nullptr;
+ }
+ return instance->second.get();
+ }
+
+ private:
+ friend class FlightServiceImpl;
+ ServerContext* context_;
+ std::string peer_;
+ std::string peer_identity_;
+ std::vector<std::shared_ptr<ServerMiddleware>> middleware_;
+ std::unordered_map<std::string, std::shared_ptr<ServerMiddleware>> middleware_map_;
+};
+
+class GrpcAddCallHeaders : public AddCallHeaders {
+ public:
+ explicit GrpcAddCallHeaders(grpc::ServerContext* context) : context_(context) {}
+ ~GrpcAddCallHeaders() override = default;
+
+ void AddHeader(const std::string& key, const std::string& value) override {
+ context_->AddInitialMetadata(key, value);
+ }
+
+ private:
+ grpc::ServerContext* context_;
+};
+
+// This class glues an implementation of FlightServerBase together with the
+// gRPC service definition, so the latter is not exposed in the public API
+class FlightServiceImpl : public FlightService::Service {
+ public:
+ explicit FlightServiceImpl(
+ std::shared_ptr<ServerAuthHandler> auth_handler,
+ std::vector<std::pair<std::string, std::shared_ptr<ServerMiddlewareFactory>>>
+ middleware,
+ FlightServerBase* server)
+ : auth_handler_(auth_handler), middleware_(middleware), server_(server) {}
+
+ template <typename UserType, typename Iterator, typename ProtoType>
+ grpc::Status WriteStream(Iterator* iterator, ServerWriter<ProtoType>* writer) {
+ if (!iterator) {
+ return grpc::Status(grpc::StatusCode::INTERNAL, "No items to iterate");
+ }
+ // Write flight info to stream until listing is exhausted
+ while (true) {
+ ProtoType pb_value;
+ std::unique_ptr<UserType> value;
+ GRPC_RETURN_NOT_OK(iterator->Next(&value));
+ if (!value) {
+ break;
+ }
+ GRPC_RETURN_NOT_OK(internal::ToProto(*value, &pb_value));
+
+ // Blocking write
+ if (!writer->Write(pb_value)) {
+ // Write returns false if the stream is closed
+ break;
+ }
+ }
+ return grpc::Status::OK;
+ }
+
+ template <typename UserType, typename ProtoType>
+ grpc::Status WriteStream(const std::vector<UserType>& values,
+ ServerWriter<ProtoType>* writer) {
+ // Write flight info to stream until listing is exhausted
+ for (const UserType& value : values) {
+ ProtoType pb_value;
+ GRPC_RETURN_NOT_OK(internal::ToProto(value, &pb_value));
+ // Blocking write
+ if (!writer->Write(pb_value)) {
+ // Write returns false if the stream is closed
+ break;
+ }
+ }
+ return grpc::Status::OK;
+ }
+
+ // Authenticate the client (if applicable) and construct the call context
+ grpc::Status CheckAuth(const FlightMethod& method, ServerContext* context,
+ GrpcServerCallContext& flight_context) {
+ if (!auth_handler_) {
+ const auto auth_context = context->auth_context();
+ if (auth_context && auth_context->IsPeerAuthenticated()) {
+ auto peer_identity = auth_context->GetPeerIdentity();
+ flight_context.peer_identity_ =
+ peer_identity.empty()
+ ? ""
+ : std::string(peer_identity.front().begin(), peer_identity.front().end());
+ } else {
+ flight_context.peer_identity_ = "";
+ }
+ } else {
+ const auto client_metadata = context->client_metadata();
+ const auto auth_header = client_metadata.find(internal::kGrpcAuthHeader);
+ std::string token;
+ if (auth_header == client_metadata.end()) {
+ token = "";
+ } else {
+ token = std::string(auth_header->second.data(), auth_header->second.length());
+ }
+ GRPC_RETURN_NOT_OK(auth_handler_->IsValid(token, &flight_context.peer_identity_));
+ }
+
+ return MakeCallContext(method, context, flight_context);
+ }
+
+ // Authenticate the client (if applicable) and construct the call context
+ grpc::Status MakeCallContext(const FlightMethod& method, ServerContext* context,
+ GrpcServerCallContext& flight_context) {
+ // Run server middleware
+ const CallInfo info{method};
+ CallHeaders incoming_headers;
+ for (const auto& entry : context->client_metadata()) {
+ incoming_headers.insert(
+ {util::string_view(entry.first.data(), entry.first.length()),
+ util::string_view(entry.second.data(), entry.second.length())});
+ }
+
+ GrpcAddCallHeaders outgoing_headers(context);
+ for (const auto& factory : middleware_) {
+ std::shared_ptr<ServerMiddleware> instance;
+ Status result = factory.second->StartCall(info, incoming_headers, &instance);
+ if (!result.ok()) {
+ // Interceptor rejected call, end the request on all existing
+ // interceptors
+ return flight_context.FinishRequest(result);
+ }
+ if (instance != nullptr) {
+ flight_context.middleware_.push_back(instance);
+ flight_context.middleware_map_.insert({factory.first, instance});
+ instance->SendingHeaders(&outgoing_headers);
+ }
+ }
+
+ return grpc::Status::OK;
+ }
+
+ grpc::Status Handshake(
+ ServerContext* context,
+ grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream) {
+ GrpcServerCallContext flight_context(context);
+ GRPC_RETURN_NOT_GRPC_OK(
+ MakeCallContext(FlightMethod::Handshake, context, flight_context));
+
+ if (!auth_handler_) {
+ RETURN_WITH_MIDDLEWARE(
+ flight_context,
+ grpc::Status(
+ grpc::StatusCode::UNIMPLEMENTED,
+ "This service does not have an authentication mechanism enabled."));
+ }
+ GrpcServerAuthSender outgoing{stream};
+ GrpcServerAuthReader incoming{stream};
+ RETURN_WITH_MIDDLEWARE(flight_context,
+ auth_handler_->Authenticate(&outgoing, &incoming));
+ }
+
+ grpc::Status ListFlights(ServerContext* context, const pb::Criteria* request,
+ ServerWriter<pb::FlightInfo>* writer) {
+ GrpcServerCallContext flight_context(context);
+ GRPC_RETURN_NOT_GRPC_OK(
+ CheckAuth(FlightMethod::ListFlights, context, flight_context));
+
+ // Retrieve the listing from the implementation
+ std::unique_ptr<FlightListing> listing;
+
+ Criteria criteria;
+ if (request) {
+ SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &criteria));
+ }
+ SERVICE_RETURN_NOT_OK(flight_context,
+ server_->ListFlights(flight_context, &criteria, &listing));
+ if (!listing) {
+ // Treat null listing as no flights available
+ RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK);
+ }
+ RETURN_WITH_MIDDLEWARE(flight_context,
+ WriteStream<FlightInfo>(listing.get(), writer));
+ }
+
+ grpc::Status GetFlightInfo(ServerContext* context, const pb::FlightDescriptor* request,
+ pb::FlightInfo* response) {
+ GrpcServerCallContext flight_context(context);
+ GRPC_RETURN_NOT_GRPC_OK(
+ CheckAuth(FlightMethod::GetFlightInfo, context, flight_context));
+
+ CHECK_ARG_NOT_NULL(flight_context, request, "FlightDescriptor cannot be null");
+
+ FlightDescriptor descr;
+ SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &descr));
+
+ std::unique_ptr<FlightInfo> info;
+ SERVICE_RETURN_NOT_OK(flight_context,
+ server_->GetFlightInfo(flight_context, descr, &info));
+
+ if (!info) {
+ // Treat null listing as no flights available
+ RETURN_WITH_MIDDLEWARE(
+ flight_context, grpc::Status(grpc::StatusCode::NOT_FOUND, "Flight not found"));
+ }
+
+ SERVICE_RETURN_NOT_OK(flight_context, internal::ToProto(*info, response));
+ RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK);
+ }
+
+ grpc::Status GetSchema(ServerContext* context, const pb::FlightDescriptor* request,
+ pb::SchemaResult* response) {
+ GrpcServerCallContext flight_context(context);
+ GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::GetSchema, context, flight_context));
+
+ CHECK_ARG_NOT_NULL(flight_context, request, "FlightDescriptor cannot be null");
+
+ FlightDescriptor descr;
+ SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &descr));
+
+ std::unique_ptr<SchemaResult> result;
+ SERVICE_RETURN_NOT_OK(flight_context,
+ server_->GetSchema(flight_context, descr, &result));
+
+ if (!result) {
+ // Treat null listing as no flights available
+ RETURN_WITH_MIDDLEWARE(
+ flight_context, grpc::Status(grpc::StatusCode::NOT_FOUND, "Flight not found"));
+ }
+
+ SERVICE_RETURN_NOT_OK(flight_context, internal::ToProto(*result, response));
+ RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK);
+ }
+
+ grpc::Status DoGet(ServerContext* context, const pb::Ticket* request,
+ ServerWriter<pb::FlightData>* writer) {
+ GrpcServerCallContext flight_context(context);
+ GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoGet, context, flight_context));
+
+ CHECK_ARG_NOT_NULL(flight_context, request, "ticket cannot be null");
+
+ Ticket ticket;
+ SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &ticket));
+
+ std::unique_ptr<FlightDataStream> data_stream;
+ SERVICE_RETURN_NOT_OK(flight_context,
+ server_->DoGet(flight_context, ticket, &data_stream));
+
+ if (!data_stream) {
+ RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status(grpc::StatusCode::NOT_FOUND,
+ "No data in this flight"));
+ }
+
+ // Write the schema as the first message in the stream
+ FlightPayload schema_payload;
+ SERVICE_RETURN_NOT_OK(flight_context, data_stream->GetSchemaPayload(&schema_payload));
+ auto status = internal::WritePayload(schema_payload, writer);
+ if (status.IsIOError()) {
+ // gRPC doesn't give any way for us to know why the message
+ // could not be written.
+ RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK);
+ }
+ SERVICE_RETURN_NOT_OK(flight_context, status);
+
+ // Consume data stream and write out payloads
+ while (true) {
+ FlightPayload payload;
+ SERVICE_RETURN_NOT_OK(flight_context, data_stream->Next(&payload));
+ // End of stream
+ if (payload.ipc_message.metadata == nullptr) break;
+ auto status = internal::WritePayload(payload, writer);
+ // Connection terminated
+ if (status.IsIOError()) break;
+ SERVICE_RETURN_NOT_OK(flight_context, status);
+ }
+ RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK);
+ }
+
+ grpc::Status DoPut(ServerContext* context,
+ grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* reader) {
+ GrpcServerCallContext flight_context(context);
+ GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoPut, context, flight_context));
+
+ auto message_reader = std::unique_ptr<FlightMessageReaderImpl<pb::PutResult>>(
+ new FlightMessageReaderImpl<pb::PutResult>(reader));
+ SERVICE_RETURN_NOT_OK(flight_context, message_reader->Init());
+ auto metadata_writer =
+ std::unique_ptr<FlightMetadataWriter>(new GrpcMetadataWriter(reader));
+ RETURN_WITH_MIDDLEWARE(flight_context,
+ server_->DoPut(flight_context, std::move(message_reader),
+ std::move(metadata_writer)));
+ }
+
+ grpc::Status DoExchange(
+ ServerContext* context,
+ grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* stream) {
+ GrpcServerCallContext flight_context(context);
+ GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoExchange, context, flight_context));
+ auto message_reader = std::unique_ptr<FlightMessageReaderImpl<pb::FlightData>>(
+ new FlightMessageReaderImpl<pb::FlightData>(stream));
+ SERVICE_RETURN_NOT_OK(flight_context, message_reader->Init());
+ auto writer =
+ std::unique_ptr<DoExchangeMessageWriter>(new DoExchangeMessageWriter(stream));
+ RETURN_WITH_MIDDLEWARE(flight_context,
+ server_->DoExchange(flight_context, std::move(message_reader),
+ std::move(writer)));
+ }
+
+ grpc::Status ListActions(ServerContext* context, const pb::Empty* request,
+ ServerWriter<pb::ActionType>* writer) {
+ GrpcServerCallContext flight_context(context);
+ GRPC_RETURN_NOT_GRPC_OK(
+ CheckAuth(FlightMethod::ListActions, context, flight_context));
+ // Retrieve the listing from the implementation
+ std::vector<ActionType> types;
+ SERVICE_RETURN_NOT_OK(flight_context, server_->ListActions(flight_context, &types));
+ RETURN_WITH_MIDDLEWARE(flight_context, WriteStream<ActionType>(types, writer));
+ }
+
+ grpc::Status DoAction(ServerContext* context, const pb::Action* request,
+ ServerWriter<pb::Result>* writer) {
+ GrpcServerCallContext flight_context(context);
+ GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoAction, context, flight_context));
+ CHECK_ARG_NOT_NULL(flight_context, request, "Action cannot be null");
+ Action action;
+ SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &action));
+
+ std::unique_ptr<ResultStream> results;
+ SERVICE_RETURN_NOT_OK(flight_context,
+ server_->DoAction(flight_context, action, &results));
+
+ if (!results) {
+ RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::CANCELLED);
+ }
+
+ while (true) {
+ std::unique_ptr<Result> result;
+ SERVICE_RETURN_NOT_OK(flight_context, results->Next(&result));
+ if (!result) {
+ // No more results
+ break;
+ }
+ pb::Result pb_result;
+ SERVICE_RETURN_NOT_OK(flight_context, internal::ToProto(*result, &pb_result));
+ if (!writer->Write(pb_result)) {
+ // Stream may be closed
+ break;
+ }
+ }
+ RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK);
+ }
+
+ private:
+ std::shared_ptr<ServerAuthHandler> auth_handler_;
+ std::vector<std::pair<std::string, std::shared_ptr<ServerMiddlewareFactory>>>
+ middleware_;
+ FlightServerBase* server_;
+};
+
+} // namespace
+
+FlightMetadataWriter::~FlightMetadataWriter() = default;
+
+//
+// gRPC server lifecycle
+//
+
+#if (ATOMIC_INT_LOCK_FREE != 2 || ATOMIC_POINTER_LOCK_FREE != 2)
+#error "atomic ints and atomic pointers not always lock-free!"
+#endif
+
+using ::arrow::internal::SetSignalHandler;
+using ::arrow::internal::SignalHandler;
+
+#ifdef WIN32
+#define PIPE_WRITE _write
+#define PIPE_READ _read
+#else
+#define PIPE_WRITE write
+#define PIPE_READ read
+#endif
+
+/// RAII guard that manages a self-pipe and a thread that listens on
+/// the self-pipe, shutting down the gRPC server when a signal handler
+/// writes to the pipe.
+class ServerSignalHandler {
+ public:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ServerSignalHandler);
+ ServerSignalHandler() = default;
+
+ /// Create the pipe and handler thread.
+ ///
+ /// \return the fd of the write side of the pipe.
+ template <typename Fn>
+ arrow::Result<int> Init(Fn handler) {
+ ARROW_ASSIGN_OR_RAISE(auto pipe, arrow::internal::CreatePipe());
+#ifndef WIN32
+ // Make write end nonblocking
+ int flags = fcntl(pipe.wfd, F_GETFL);
+ if (flags == -1) {
+ RETURN_NOT_OK(arrow::internal::FileClose(pipe.rfd));
+ RETURN_NOT_OK(arrow::internal::FileClose(pipe.wfd));
+ return arrow::internal::IOErrorFromErrno(
+ errno, "Could not initialize self-pipe to wait for signals");
+ }
+ flags |= O_NONBLOCK;
+ if (fcntl(pipe.wfd, F_SETFL, flags) == -1) {
+ RETURN_NOT_OK(arrow::internal::FileClose(pipe.rfd));
+ RETURN_NOT_OK(arrow::internal::FileClose(pipe.wfd));
+ return arrow::internal::IOErrorFromErrno(
+ errno, "Could not initialize self-pipe to wait for signals");
+ }
+#endif
+ self_pipe_ = pipe;
+ handle_signals_ = std::thread(handler, self_pipe_.rfd);
+ return self_pipe_.wfd;
+ }
+
+ Status Shutdown() {
+ if (self_pipe_.rfd == 0) {
+ // Already closed
+ return Status::OK();
+ }
+ if (PIPE_WRITE(self_pipe_.wfd, "0", 1) < 0 && errno != EAGAIN &&
+ errno != EWOULDBLOCK && errno != EINTR) {
+ return arrow::internal::IOErrorFromErrno(errno, "Could not unblock signal thread");
+ }
+ RETURN_NOT_OK(arrow::internal::FileClose(self_pipe_.rfd));
+ RETURN_NOT_OK(arrow::internal::FileClose(self_pipe_.wfd));
+ handle_signals_.join();
+ self_pipe_.rfd = 0;
+ self_pipe_.wfd = 0;
+ return Status::OK();
+ }
+
+ ~ServerSignalHandler() { ARROW_CHECK_OK(Shutdown()); }
+
+ private:
+ arrow::internal::Pipe self_pipe_;
+ std::thread handle_signals_;
+};
+
+struct FlightServerBase::Impl {
+ std::unique_ptr<FlightServiceImpl> service_;
+ std::unique_ptr<grpc::Server> server_;
+ int port_;
+
+ // Signal handlers (on Windows) and the shutdown handler (other platforms)
+ // are executed in a separate thread, so getting the current thread instance
+ // wouldn't make sense. This means only a single instance can receive signals.
+ static std::atomic<Impl*> running_instance_;
+ // We'll use the self-pipe trick to notify a thread from the signal
+ // handler. The thread will then shut down the gRPC server.
+ int self_pipe_wfd_;
+
+ // Signal handling
+ std::vector<int> signals_;
+ std::vector<SignalHandler> old_signal_handlers_;
+ std::atomic<int> got_signal_;
+
+ static void HandleSignal(int signum) {
+ auto instance = running_instance_.load();
+ if (instance != nullptr) {
+ instance->DoHandleSignal(signum);
+ }
+ }
+
+ void DoHandleSignal(int signum) {
+ got_signal_ = signum;
+ int saved_errno = errno;
+ // Ignore errors - pipe is nonblocking
+ PIPE_WRITE(self_pipe_wfd_, "0", 1);
+ errno = saved_errno;
+ }
+
+ static void WaitForSignals(int fd) {
+ // Wait for a signal handler to write to the pipe
+ int8_t buf[1];
+ while (PIPE_READ(fd, /*buf=*/buf, /*count=*/1) == -1) {
+ if (errno == EINTR) {
+ continue;
+ }
+ ARROW_CHECK_OK(arrow::internal::IOErrorFromErrno(
+ errno, "Error while waiting for shutdown signal"));
+ }
+ auto instance = running_instance_.load();
+ if (instance != nullptr) {
+ instance->server_->Shutdown();
+ }
+ }
+};
+
+std::atomic<FlightServerBase::Impl*> FlightServerBase::Impl::running_instance_;
+
+FlightServerOptions::FlightServerOptions(const Location& location_)
+ : location(location_),
+ auth_handler(nullptr),
+ tls_certificates(),
+ verify_client(false),
+ root_certificates(),
+ middleware(),
+ builder_hook(nullptr) {}
+
+FlightServerOptions::~FlightServerOptions() = default;
+
+FlightServerBase::FlightServerBase() { impl_.reset(new Impl); }
+
+FlightServerBase::~FlightServerBase() {}
+
+Status FlightServerBase::Init(const FlightServerOptions& options) {
+ impl_->service_.reset(
+ new FlightServiceImpl(options.auth_handler, options.middleware, this));
+
+ grpc::ServerBuilder builder;
+ // Allow uploading messages of any length
+ builder.SetMaxReceiveMessageSize(-1);
+
+ const Location& location = options.location;
+ const std::string scheme = location.scheme();
+ if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) {
+ std::stringstream address;
+ address << arrow::internal::UriEncodeHost(location.uri_->host()) << ':'
+ << location.uri_->port_text();
+
+ std::shared_ptr<grpc::ServerCredentials> creds;
+ if (scheme == kSchemeGrpcTls) {
+ grpc::SslServerCredentialsOptions ssl_options;
+ for (const auto& pair : options.tls_certificates) {
+ ssl_options.pem_key_cert_pairs.push_back({pair.pem_key, pair.pem_cert});
+ }
+ if (options.verify_client) {
+ ssl_options.client_certificate_request =
+ GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY;
+ }
+ if (!options.root_certificates.empty()) {
+ ssl_options.pem_root_certs = options.root_certificates;
+ }
+ creds = grpc::SslServerCredentials(ssl_options);
+ } else {
+ creds = grpc::InsecureServerCredentials();
+ }
+
+ builder.AddListeningPort(address.str(), creds, &impl_->port_);
+ } else if (scheme == kSchemeGrpcUnix) {
+ std::stringstream address;
+ address << "unix:" << location.uri_->path();
+ builder.AddListeningPort(address.str(), grpc::InsecureServerCredentials());
+ } else {
+ return Status::NotImplemented("Scheme is not supported: " + scheme);
+ }
+
+ builder.RegisterService(impl_->service_.get());
+
+ // Disable SO_REUSEPORT - it makes debugging/testing a pain as
+ // leftover processes can handle requests on accident
+ builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0);
+
+ if (options.builder_hook) {
+ options.builder_hook(&builder);
+ }
+
+ impl_->server_ = builder.BuildAndStart();
+ if (!impl_->server_) {
+ return Status::UnknownError("Server did not start properly");
+ }
+ return Status::OK();
+}
+
+int FlightServerBase::port() const { return impl_->port_; }
+
+Status FlightServerBase::SetShutdownOnSignals(const std::vector<int> sigs) {
+ impl_->signals_ = sigs;
+ impl_->old_signal_handlers_.clear();
+ return Status::OK();
+}
+
+Status FlightServerBase::Serve() {
+ if (!impl_->server_) {
+ return Status::UnknownError("Server did not start properly");
+ }
+ impl_->got_signal_ = 0;
+ impl_->old_signal_handlers_.clear();
+ impl_->running_instance_ = impl_.get();
+
+ ServerSignalHandler signal_handler;
+ ARROW_ASSIGN_OR_RAISE(impl_->self_pipe_wfd_,
+ signal_handler.Init(&Impl::WaitForSignals));
+ // Override existing signal handlers with our own handler so as to stop the server.
+ for (size_t i = 0; i < impl_->signals_.size(); ++i) {
+ int signum = impl_->signals_[i];
+ SignalHandler new_handler(&Impl::HandleSignal), old_handler;
+ ARROW_ASSIGN_OR_RAISE(old_handler, SetSignalHandler(signum, new_handler));
+ impl_->old_signal_handlers_.push_back(std::move(old_handler));
+ }
+
+ impl_->server_->Wait();
+ impl_->running_instance_ = nullptr;
+
+ // Restore signal handlers
+ for (size_t i = 0; i < impl_->signals_.size(); ++i) {
+ RETURN_NOT_OK(
+ SetSignalHandler(impl_->signals_[i], impl_->old_signal_handlers_[i]).status());
+ }
+ return Status::OK();
+}
+
+int FlightServerBase::GotSignal() const { return impl_->got_signal_; }
+
+Status FlightServerBase::Shutdown() {
+ auto server = impl_->server_.get();
+ if (!server) {
+ return Status::Invalid("Shutdown() on uninitialized FlightServerBase");
+ }
+ impl_->server_->Shutdown();
+ return Status::OK();
+}
+
+Status FlightServerBase::Wait() {
+ impl_->server_->Wait();
+ impl_->running_instance_ = nullptr;
+ return Status::OK();
+}
+
+Status FlightServerBase::ListFlights(const ServerCallContext& context,
+ const Criteria* criteria,
+ std::unique_ptr<FlightListing>* listings) {
+ return Status::NotImplemented("NYI");
+}
+
+Status FlightServerBase::GetFlightInfo(const ServerCallContext& context,
+ const FlightDescriptor& request,
+ std::unique_ptr<FlightInfo>* info) {
+ return Status::NotImplemented("NYI");
+}
+
+Status FlightServerBase::DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr<FlightDataStream>* data_stream) {
+ return Status::NotImplemented("NYI");
+}
+
+Status FlightServerBase::DoPut(const ServerCallContext& context,
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer) {
+ return Status::NotImplemented("NYI");
+}
+
+Status FlightServerBase::DoExchange(const ServerCallContext& context,
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMessageWriter> writer) {
+ return Status::NotImplemented("NYI");
+}
+
+Status FlightServerBase::DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* result) {
+ return Status::NotImplemented("NYI");
+}
+
+Status FlightServerBase::ListActions(const ServerCallContext& context,
+ std::vector<ActionType>* actions) {
+ return Status::NotImplemented("NYI");
+}
+
+Status FlightServerBase::GetSchema(const ServerCallContext& context,
+ const FlightDescriptor& request,
+ std::unique_ptr<SchemaResult>* schema) {
+ return Status::NotImplemented("NYI");
+}
+
+// ----------------------------------------------------------------------
+// Implement RecordBatchStream
+
+class RecordBatchStream::RecordBatchStreamImpl {
+ public:
+ // Stages of the stream when producing payloads
+ enum class Stage {
+ NEW, // The stream has been created, but Next has not been called yet
+ DICTIONARY, // Dictionaries have been collected, and are being sent
+ RECORD_BATCH // Initial have been sent
+ };
+
+ RecordBatchStreamImpl(const std::shared_ptr<RecordBatchReader>& reader,
+ const ipc::IpcWriteOptions& options)
+ : reader_(reader), mapper_(*reader_->schema()), ipc_options_(options) {}
+
+ std::shared_ptr<Schema> schema() { return reader_->schema(); }
+
+ Status GetSchemaPayload(FlightPayload* payload) {
+ return ipc::GetSchemaPayload(*reader_->schema(), ipc_options_, mapper_,
+ &payload->ipc_message);
+ }
+
+ Status Next(FlightPayload* payload) {
+ if (stage_ == Stage::NEW) {
+ RETURN_NOT_OK(reader_->ReadNext(&current_batch_));
+ if (!current_batch_) {
+ // Signal that iteration is over
+ payload->ipc_message.metadata = nullptr;
+ return Status::OK();
+ }
+ ARROW_ASSIGN_OR_RAISE(dictionaries_,
+ ipc::CollectDictionaries(*current_batch_, mapper_));
+ stage_ = Stage::DICTIONARY;
+ }
+
+ if (stage_ == Stage::DICTIONARY) {
+ if (dictionary_index_ == static_cast<int>(dictionaries_.size())) {
+ stage_ = Stage::RECORD_BATCH;
+ return ipc::GetRecordBatchPayload(*current_batch_, ipc_options_,
+ &payload->ipc_message);
+ } else {
+ return GetNextDictionary(payload);
+ }
+ }
+
+ RETURN_NOT_OK(reader_->ReadNext(&current_batch_));
+
+ // TODO(wesm): Delta dictionaries
+ if (!current_batch_) {
+ // Signal that iteration is over
+ payload->ipc_message.metadata = nullptr;
+ return Status::OK();
+ } else {
+ return ipc::GetRecordBatchPayload(*current_batch_, ipc_options_,
+ &payload->ipc_message);
+ }
+ }
+
+ private:
+ Status GetNextDictionary(FlightPayload* payload) {
+ const auto& it = dictionaries_[dictionary_index_++];
+ return ipc::GetDictionaryPayload(it.first, it.second, ipc_options_,
+ &payload->ipc_message);
+ }
+
+ Stage stage_ = Stage::NEW;
+ std::shared_ptr<RecordBatchReader> reader_;
+ ipc::DictionaryFieldMapper mapper_;
+ ipc::IpcWriteOptions ipc_options_;
+ std::shared_ptr<RecordBatch> current_batch_;
+ std::vector<std::pair<int64_t, std::shared_ptr<Array>>> dictionaries_;
+
+ // Index of next dictionary to send
+ int dictionary_index_ = 0;
+};
+
+FlightDataStream::~FlightDataStream() {}
+
+RecordBatchStream::RecordBatchStream(const std::shared_ptr<RecordBatchReader>& reader,
+ const ipc::IpcWriteOptions& options) {
+ impl_.reset(new RecordBatchStreamImpl(reader, options));
+}
+
+RecordBatchStream::~RecordBatchStream() {}
+
+std::shared_ptr<Schema> RecordBatchStream::schema() { return impl_->schema(); }
+
+Status RecordBatchStream::GetSchemaPayload(FlightPayload* payload) {
+ return impl_->GetSchemaPayload(payload);
+}
+
+Status RecordBatchStream::Next(FlightPayload* payload) { return impl_->Next(payload); }
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/server.h b/src/arrow/cpp/src/arrow/flight/server.h
new file mode 100644
index 000000000..96b2da488
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/server.h
@@ -0,0 +1,285 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Interfaces to use for defining Flight RPC servers. API should be considered
+// experimental for now
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/flight/server_auth.h"
+#include "arrow/flight/types.h" // IWYU pragma: keep
+#include "arrow/flight/visibility.h" // IWYU pragma: keep
+#include "arrow/ipc/dictionary.h"
+#include "arrow/ipc/options.h"
+#include "arrow/record_batch.h"
+
+namespace arrow {
+
+class Schema;
+class Status;
+
+namespace flight {
+
+class ServerMiddleware;
+class ServerMiddlewareFactory;
+
+/// \brief Interface that produces a sequence of IPC payloads to be sent in
+/// FlightData protobuf messages
+class ARROW_FLIGHT_EXPORT FlightDataStream {
+ public:
+ virtual ~FlightDataStream();
+
+ virtual std::shared_ptr<Schema> schema() = 0;
+
+ /// \brief Compute FlightPayload containing serialized RecordBatch schema
+ virtual Status GetSchemaPayload(FlightPayload* payload) = 0;
+
+ // When the stream is completed, the last payload written will have null
+ // metadata
+ virtual Status Next(FlightPayload* payload) = 0;
+};
+
+/// \brief A basic implementation of FlightDataStream that will provide
+/// a sequence of FlightData messages to be written to a gRPC stream
+class ARROW_FLIGHT_EXPORT RecordBatchStream : public FlightDataStream {
+ public:
+ /// \param[in] reader produces a sequence of record batches
+ /// \param[in] options IPC options for writing
+ explicit RecordBatchStream(
+ const std::shared_ptr<RecordBatchReader>& reader,
+ const ipc::IpcWriteOptions& options = ipc::IpcWriteOptions::Defaults());
+ ~RecordBatchStream() override;
+
+ std::shared_ptr<Schema> schema() override;
+ Status GetSchemaPayload(FlightPayload* payload) override;
+ Status Next(FlightPayload* payload) override;
+
+ private:
+ class RecordBatchStreamImpl;
+ std::unique_ptr<RecordBatchStreamImpl> impl_;
+};
+
+/// \brief A reader for IPC payloads uploaded by a client. Also allows
+/// reading application-defined metadata via the Flight protocol.
+class ARROW_FLIGHT_EXPORT FlightMessageReader : public MetadataRecordBatchReader {
+ public:
+ /// \brief Get the descriptor for this upload.
+ virtual const FlightDescriptor& descriptor() const = 0;
+};
+
+/// \brief A writer for application-specific metadata sent back to the
+/// client during an upload.
+class ARROW_FLIGHT_EXPORT FlightMetadataWriter {
+ public:
+ virtual ~FlightMetadataWriter();
+ /// \brief Send a message to the client.
+ virtual Status WriteMetadata(const Buffer& app_metadata) = 0;
+};
+
+/// \brief A writer for IPC payloads to a client. Also allows sending
+/// application-defined metadata via the Flight protocol.
+///
+/// This class offers more control compared to FlightDataStream,
+/// including the option to write metadata without data and the
+/// ability to interleave reading and writing.
+class ARROW_FLIGHT_EXPORT FlightMessageWriter : public MetadataRecordBatchWriter {
+ public:
+ virtual ~FlightMessageWriter() = default;
+};
+
+/// \brief Call state/contextual data.
+class ARROW_FLIGHT_EXPORT ServerCallContext {
+ public:
+ virtual ~ServerCallContext() = default;
+ /// \brief The name of the authenticated peer (may be the empty string)
+ virtual const std::string& peer_identity() const = 0;
+ /// \brief The peer address (not validated)
+ virtual const std::string& peer() const = 0;
+ /// \brief Look up a middleware by key. Do not maintain a reference
+ /// to the object beyond the request body.
+ /// \return The middleware, or nullptr if not found.
+ virtual ServerMiddleware* GetMiddleware(const std::string& key) const = 0;
+ /// \brief Check if the current RPC has been cancelled (by the client, by
+ /// a network error, etc.).
+ virtual bool is_cancelled() const = 0;
+};
+
+class ARROW_FLIGHT_EXPORT FlightServerOptions {
+ public:
+ explicit FlightServerOptions(const Location& location_);
+
+ ~FlightServerOptions();
+
+ /// \brief The host & port (or domain socket path) to listen on.
+ /// Use port 0 to bind to an available port.
+ Location location;
+ /// \brief The authentication handler to use.
+ std::shared_ptr<ServerAuthHandler> auth_handler;
+ /// \brief A list of TLS certificate+key pairs to use.
+ std::vector<CertKeyPair> tls_certificates;
+ /// \brief Enable mTLS and require that the client present a certificate.
+ bool verify_client;
+ /// \brief If using mTLS, the PEM-encoded root certificate to use.
+ std::string root_certificates;
+ /// \brief A list of server middleware to apply, along with a key to
+ /// identify them by.
+ ///
+ /// Middleware are always applied in the order provided. Duplicate
+ /// keys are an error.
+ std::vector<std::pair<std::string, std::shared_ptr<ServerMiddlewareFactory>>>
+ middleware;
+
+ /// \brief A Flight implementation-specific callback to customize
+ /// transport-specific options.
+ ///
+ /// Not guaranteed to be called. The type of the parameter is
+ /// specific to the Flight implementation. Users should take care to
+ /// link to the same transport implementation as Flight to avoid
+ /// runtime problems.
+ std::function<void(void*)> builder_hook;
+};
+
+/// \brief Skeleton RPC server implementation which can be used to create
+/// custom servers by implementing its abstract methods
+class ARROW_FLIGHT_EXPORT FlightServerBase {
+ public:
+ FlightServerBase();
+ virtual ~FlightServerBase();
+
+ // Lifecycle methods.
+
+ /// \brief Initialize a Flight server listening at the given location.
+ /// This method must be called before any other method.
+ /// \param[in] options The configuration for this server.
+ Status Init(const FlightServerOptions& options);
+
+ /// \brief Get the port that the Flight server is listening on.
+ /// This method must only be called after Init(). Will return a
+ /// non-positive value if no port exists (e.g. when listening on a
+ /// domain socket).
+ int port() const;
+
+ /// \brief Set the server to stop when receiving any of the given signal
+ /// numbers.
+ /// This method must be called before Serve().
+ Status SetShutdownOnSignals(const std::vector<int> sigs);
+
+ /// \brief Start serving.
+ /// This method blocks until either Shutdown() is called or one of the signals
+ /// registered in SetShutdownOnSignals() is received.
+ Status Serve();
+
+ /// \brief Query whether Serve() was interrupted by a signal.
+ /// This method must be called after Serve() has returned.
+ ///
+ /// \return int the signal number that interrupted Serve(), if any, otherwise 0
+ int GotSignal() const;
+
+ /// \brief Shut down the server. Can be called from signal handler or another
+ /// thread while Serve() blocks.
+ ///
+ /// TODO(wesm): Shutdown with deadline
+ Status Shutdown();
+
+ /// \brief Block until server is terminated with Shutdown.
+ Status Wait();
+
+ // Implement these methods to create your own server. The default
+ // implementations will return a not-implemented result to the client
+
+ /// \brief Retrieve a list of available fields given an optional opaque
+ /// criteria
+ /// \param[in] context The call context.
+ /// \param[in] criteria may be null
+ /// \param[out] listings the returned listings iterator
+ /// \return Status
+ virtual Status ListFlights(const ServerCallContext& context, const Criteria* criteria,
+ std::unique_ptr<FlightListing>* listings);
+
+ /// \brief Retrieve the schema and an access plan for the indicated
+ /// descriptor
+ /// \param[in] context The call context.
+ /// \param[in] request may be null
+ /// \param[out] info the returned flight info provider
+ /// \return Status
+ virtual Status GetFlightInfo(const ServerCallContext& context,
+ const FlightDescriptor& request,
+ std::unique_ptr<FlightInfo>* info);
+
+ /// \brief Retrieve the schema for the indicated descriptor
+ /// \param[in] context The call context.
+ /// \param[in] request may be null
+ /// \param[out] schema the returned flight schema provider
+ /// \return Status
+ virtual Status GetSchema(const ServerCallContext& context,
+ const FlightDescriptor& request,
+ std::unique_ptr<SchemaResult>* schema);
+
+ /// \brief Get a stream of IPC payloads to put on the wire
+ /// \param[in] context The call context.
+ /// \param[in] request an opaque ticket
+ /// \param[out] stream the returned stream provider
+ /// \return Status
+ virtual Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr<FlightDataStream>* stream);
+
+ /// \brief Process a stream of IPC payloads sent from a client
+ /// \param[in] context The call context.
+ /// \param[in] reader a sequence of uploaded record batches
+ /// \param[in] writer send metadata back to the client
+ /// \return Status
+ virtual Status DoPut(const ServerCallContext& context,
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer);
+
+ /// \brief Process a bidirectional stream of IPC payloads
+ /// \param[in] context The call context.
+ /// \param[in] reader a sequence of uploaded record batches
+ /// \param[in] writer send data back to the client
+ /// \return Status
+ virtual Status DoExchange(const ServerCallContext& context,
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMessageWriter> writer);
+
+ /// \brief Execute an action, return stream of zero or more results
+ /// \param[in] context The call context.
+ /// \param[in] action the action to execute, with type and body
+ /// \param[out] result the result iterator
+ /// \return Status
+ virtual Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* result);
+
+ /// \brief Retrieve the list of available actions
+ /// \param[in] context The call context.
+ /// \param[out] actions a vector of available action types
+ /// \return Status
+ virtual Status ListActions(const ServerCallContext& context,
+ std::vector<ActionType>* actions);
+
+ private:
+ struct Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/server_auth.cc b/src/arrow/cpp/src/arrow/flight/server_auth.cc
new file mode 100644
index 000000000..ae03c3d45
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/server_auth.cc
@@ -0,0 +1,37 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/server_auth.h"
+
+namespace arrow {
+namespace flight {
+
+ServerAuthHandler::~ServerAuthHandler() {}
+
+NoOpAuthHandler::~NoOpAuthHandler() {}
+Status NoOpAuthHandler::Authenticate(ServerAuthSender* outgoing,
+ ServerAuthReader* incoming) {
+ return Status::OK();
+}
+
+Status NoOpAuthHandler::IsValid(const std::string& token, std::string* peer_identity) {
+ *peer_identity = "";
+ return Status::OK();
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/server_auth.h b/src/arrow/cpp/src/arrow/flight/server_auth.h
new file mode 100644
index 000000000..b1ccb096d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/server_auth.h
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// \brief Server-side APIs to implement authentication for Flight.
+
+#pragma once
+
+#include <string>
+
+#include "arrow/flight/visibility.h"
+#include "arrow/status.h"
+
+namespace arrow {
+
+namespace flight {
+
+/// \brief A reader for messages from the client during an
+/// authentication handshake.
+class ARROW_FLIGHT_EXPORT ServerAuthReader {
+ public:
+ virtual ~ServerAuthReader() = default;
+ virtual Status Read(std::string* token) = 0;
+};
+
+/// \brief A writer for messages to the client during an
+/// authentication handshake.
+class ARROW_FLIGHT_EXPORT ServerAuthSender {
+ public:
+ virtual ~ServerAuthSender() = default;
+ virtual Status Write(const std::string& message) = 0;
+};
+
+/// \brief An authentication implementation for a Flight service.
+/// Authentication includes both an initial negotiation and a per-call
+/// token validation. Implementations may choose to use either or both
+/// mechanisms.
+/// An implementation may need to track some state, e.g. a mapping of
+/// client tokens to authenticated identities.
+class ARROW_FLIGHT_EXPORT ServerAuthHandler {
+ public:
+ virtual ~ServerAuthHandler();
+ /// \brief Authenticate the client on initial connection. The server
+ /// can send and read responses from the client at any time.
+ virtual Status Authenticate(ServerAuthSender* outgoing, ServerAuthReader* incoming) = 0;
+ /// \brief Validate a per-call client token.
+ /// \param[in] token The client token. May be the empty string if
+ /// the client does not provide a token.
+ /// \param[out] peer_identity The identity of the peer, if this
+ /// authentication method supports it.
+ /// \return Status OK if the token is valid, any other status if
+ /// validation failed
+ virtual Status IsValid(const std::string& token, std::string* peer_identity) = 0;
+};
+
+/// \brief An authentication mechanism that does nothing.
+class ARROW_FLIGHT_EXPORT NoOpAuthHandler : public ServerAuthHandler {
+ public:
+ ~NoOpAuthHandler() override;
+ Status Authenticate(ServerAuthSender* outgoing, ServerAuthReader* incoming) override;
+ Status IsValid(const std::string& token, std::string* peer_identity) override;
+};
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/server_middleware.h b/src/arrow/cpp/src/arrow/flight/server_middleware.h
new file mode 100644
index 000000000..26431aff0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/server_middleware.h
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Interfaces for defining middleware for Flight servers. Currently
+// experimental.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/flight/middleware.h"
+#include "arrow/flight/visibility.h" // IWYU pragma: keep
+#include "arrow/status.h"
+
+namespace arrow {
+namespace flight {
+
+/// \brief Server-side middleware for a call, instantiated per RPC.
+///
+/// Middleware should be fast and must be infallible: there is no way
+/// to reject the call or report errors from the middleware instance.
+class ARROW_FLIGHT_EXPORT ServerMiddleware {
+ public:
+ virtual ~ServerMiddleware() = default;
+
+ /// \brief Unique name of middleware, used as alternative to RTTI
+ /// \return the string name of the middleware
+ virtual std::string name() const = 0;
+
+ /// \brief A callback before headers are sent. Extra headers can be
+ /// added, but existing ones cannot be read.
+ virtual void SendingHeaders(AddCallHeaders* outgoing_headers) = 0;
+
+ /// \brief A callback after the call has completed.
+ virtual void CallCompleted(const Status& status) = 0;
+};
+
+/// \brief A factory for new middleware instances.
+///
+/// If added to a server, this will be called for each RPC (including
+/// Handshake) to give the opportunity to intercept the call.
+///
+/// It is guaranteed that all server middleware methods are called
+/// from the same thread that calls the RPC method implementation.
+class ARROW_FLIGHT_EXPORT ServerMiddlewareFactory {
+ public:
+ virtual ~ServerMiddlewareFactory() = default;
+
+ /// \brief A callback for the start of a new call.
+ ///
+ /// Return a non-OK status to reject the call with the given status.
+ ///
+ /// \param info Information about the call.
+ /// \param incoming_headers Headers sent by the client for this call.
+ /// Do not retain a reference to this object.
+ /// \param[out] middleware The middleware instance for this call. If
+ /// null, no middleware will be added to this call instance from
+ /// this factory.
+ /// \return Status A non-OK status will reject the call with the
+ /// given status. Middleware previously in the chain will have
+ /// their CallCompleted callback called. Other middleware
+ /// factories will not be called.
+ virtual Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ std::shared_ptr<ServerMiddleware>* middleware) = 0;
+};
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/test_integration.cc b/src/arrow/cpp/src/arrow/flight/test_integration.cc
new file mode 100644
index 000000000..29ce5601f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/test_integration.cc
@@ -0,0 +1,270 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/test_integration.h"
+#include "arrow/flight/client_middleware.h"
+#include "arrow/flight/server_middleware.h"
+#include "arrow/flight/test_util.h"
+#include "arrow/flight/types.h"
+#include "arrow/ipc/dictionary.h"
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace arrow {
+namespace flight {
+
+/// \brief The server for the basic auth integration test.
+class AuthBasicProtoServer : public FlightServerBase {
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* result) override {
+ // Respond with the authenticated username.
+ auto buf = Buffer::FromString(context.peer_identity());
+ *result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
+ return Status::OK();
+ }
+};
+
+/// Validate the result of a DoAction.
+Status CheckActionResults(FlightClient* client, const Action& action,
+ std::vector<std::string> results) {
+ std::unique_ptr<ResultStream> stream;
+ RETURN_NOT_OK(client->DoAction(action, &stream));
+ std::unique_ptr<Result> result;
+ for (const std::string& expected : results) {
+ RETURN_NOT_OK(stream->Next(&result));
+ if (!result) {
+ return Status::Invalid("Action result stream ended early");
+ }
+ const auto actual = result->body->ToString();
+ if (expected != actual) {
+ return Status::Invalid("Got wrong result; expected", expected, "but got", actual);
+ }
+ }
+ RETURN_NOT_OK(stream->Next(&result));
+ if (result) {
+ return Status::Invalid("Action result stream had too many entries");
+ }
+ return Status::OK();
+}
+
+// The expected username for the basic auth integration test.
+constexpr auto kAuthUsername = "arrow";
+// The expected password for the basic auth integration test.
+constexpr auto kAuthPassword = "flight";
+
+/// \brief A scenario testing the basic auth protobuf.
+class AuthBasicProtoScenario : public Scenario {
+ Status MakeServer(std::unique_ptr<FlightServerBase>* server,
+ FlightServerOptions* options) override {
+ server->reset(new AuthBasicProtoServer());
+ options->auth_handler =
+ std::make_shared<TestServerBasicAuthHandler>(kAuthUsername, kAuthPassword);
+ return Status::OK();
+ }
+
+ Status MakeClient(FlightClientOptions* options) override { return Status::OK(); }
+
+ Status RunClient(std::unique_ptr<FlightClient> client) override {
+ Action action;
+ std::unique_ptr<ResultStream> stream;
+ std::shared_ptr<FlightStatusDetail> detail;
+ const auto& status = client->DoAction(action, &stream);
+ detail = FlightStatusDetail::UnwrapStatus(status);
+ // This client is unauthenticated and should fail.
+ if (detail == nullptr) {
+ return Status::Invalid("Expected UNAUTHENTICATED but got ", status.ToString());
+ }
+ if (detail->code() != FlightStatusCode::Unauthenticated) {
+ return Status::Invalid("Expected UNAUTHENTICATED but got ", detail->ToString());
+ }
+
+ auto client_handler = std::unique_ptr<ClientAuthHandler>(
+ new TestClientBasicAuthHandler(kAuthUsername, kAuthPassword));
+ RETURN_NOT_OK(client->Authenticate({}, std::move(client_handler)));
+ return CheckActionResults(client.get(), action, {kAuthUsername});
+ }
+};
+
+/// \brief Test middleware that echoes back the value of a particular
+/// incoming header.
+///
+/// In Java, gRPC may consolidate this header with HTTP/2 trailers if
+/// the call fails, but C++ generally doesn't do this. The integration
+/// test confirms the presence of this header to ensure we can read it
+/// regardless of what gRPC does.
+class TestServerMiddleware : public ServerMiddleware {
+ public:
+ explicit TestServerMiddleware(std::string received) : received_(received) {}
+ void SendingHeaders(AddCallHeaders* outgoing_headers) override {
+ outgoing_headers->AddHeader("x-middleware", received_);
+ }
+ void CallCompleted(const Status& status) override {}
+
+ std::string name() const override { return "GrpcTrailersMiddleware"; }
+
+ private:
+ std::string received_;
+};
+
+class TestServerMiddlewareFactory : public ServerMiddlewareFactory {
+ public:
+ Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ std::shared_ptr<ServerMiddleware>* middleware) override {
+ const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>& iter_pair =
+ incoming_headers.equal_range("x-middleware");
+ std::string received = "";
+ if (iter_pair.first != iter_pair.second) {
+ const util::string_view& value = (*iter_pair.first).second;
+ received = std::string(value);
+ }
+ *middleware = std::make_shared<TestServerMiddleware>(received);
+ return Status::OK();
+ }
+};
+
+/// \brief Test middleware that adds a header on every outgoing call,
+/// and gets the value of the expected header sent by the server.
+class TestClientMiddleware : public ClientMiddleware {
+ public:
+ explicit TestClientMiddleware(std::string* received_header)
+ : received_header_(received_header) {}
+
+ void SendingHeaders(AddCallHeaders* outgoing_headers) {
+ outgoing_headers->AddHeader("x-middleware", "expected value");
+ }
+
+ void ReceivedHeaders(const CallHeaders& incoming_headers) {
+ // We expect the server to always send this header. gRPC/Java may
+ // send it in trailers instead of headers, so we expect Flight to
+ // account for this.
+ const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>& iter_pair =
+ incoming_headers.equal_range("x-middleware");
+ if (iter_pair.first != iter_pair.second) {
+ const util::string_view& value = (*iter_pair.first).second;
+ *received_header_ = std::string(value);
+ }
+ }
+
+ void CallCompleted(const Status& status) {}
+
+ private:
+ std::string* received_header_;
+};
+
+class TestClientMiddlewareFactory : public ClientMiddlewareFactory {
+ public:
+ void StartCall(const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware) {
+ *middleware =
+ std::unique_ptr<ClientMiddleware>(new TestClientMiddleware(&received_header_));
+ }
+
+ std::string received_header_;
+};
+
+/// \brief The server used for testing middleware. Implements only one
+/// endpoint, GetFlightInfo, in such a way that it either succeeds or
+/// returns an error based on the input, in order to test both paths.
+class MiddlewareServer : public FlightServerBase {
+ Status GetFlightInfo(const ServerCallContext& context,
+ const FlightDescriptor& descriptor,
+ std::unique_ptr<FlightInfo>* result) override {
+ if (descriptor.type == FlightDescriptor::DescriptorType::CMD &&
+ descriptor.cmd == "success") {
+ // Don't fail
+ std::shared_ptr<Schema> schema = arrow::schema({});
+ Location location;
+ // Return a fake location - the test doesn't read it
+ RETURN_NOT_OK(Location::ForGrpcTcp("localhost", 10010, &location));
+ std::vector<FlightEndpoint> endpoints{FlightEndpoint{{"foo"}, {location}}};
+ ARROW_ASSIGN_OR_RAISE(auto info,
+ FlightInfo::Make(*schema, descriptor, endpoints, -1, -1));
+ *result = std::unique_ptr<FlightInfo>(new FlightInfo(info));
+ return Status::OK();
+ }
+ // Fail the call immediately. In some gRPC implementations, this
+ // means that gRPC sends only HTTP/2 trailers and not headers. We want
+ // Flight middleware to be agnostic to this difference.
+ return Status::UnknownError("Unknown");
+ }
+};
+
+/// \brief The middleware scenario.
+///
+/// This tests that the server and client get expected header values.
+class MiddlewareScenario : public Scenario {
+ Status MakeServer(std::unique_ptr<FlightServerBase>* server,
+ FlightServerOptions* options) override {
+ options->middleware.push_back(
+ {"grpc_trailers", std::make_shared<TestServerMiddlewareFactory>()});
+ server->reset(new MiddlewareServer());
+ return Status::OK();
+ }
+
+ Status MakeClient(FlightClientOptions* options) override {
+ client_middleware_ = std::make_shared<TestClientMiddlewareFactory>();
+ options->middleware.push_back(client_middleware_);
+ return Status::OK();
+ }
+
+ Status RunClient(std::unique_ptr<FlightClient> client) override {
+ std::unique_ptr<FlightInfo> info;
+ // This call is expected to fail. In gRPC/Java, this causes the
+ // server to combine headers and HTTP/2 trailers, so to read the
+ // expected header, Flight must check for both headers and
+ // trailers.
+ if (client->GetFlightInfo(FlightDescriptor::Command(""), &info).ok()) {
+ return Status::Invalid("Expected call to fail");
+ }
+ if (client_middleware_->received_header_ != "expected value") {
+ return Status::Invalid(
+ "Expected to receive header 'x-middleware: expected value', but instead got: '",
+ client_middleware_->received_header_, "'");
+ }
+ std::cerr << "Headers received successfully on failing call." << std::endl;
+
+ // This call should succeed
+ client_middleware_->received_header_ = "";
+ RETURN_NOT_OK(client->GetFlightInfo(FlightDescriptor::Command("success"), &info));
+ if (client_middleware_->received_header_ != "expected value") {
+ return Status::Invalid(
+ "Expected to receive header 'x-middleware: expected value', but instead got '",
+ client_middleware_->received_header_, "'");
+ }
+ std::cerr << "Headers received successfully on passing call." << std::endl;
+ return Status::OK();
+ }
+
+ std::shared_ptr<TestClientMiddlewareFactory> client_middleware_;
+};
+
+Status GetScenario(const std::string& scenario_name, std::shared_ptr<Scenario>* out) {
+ if (scenario_name == "auth:basic_proto") {
+ *out = std::make_shared<AuthBasicProtoScenario>();
+ return Status::OK();
+ } else if (scenario_name == "middleware") {
+ *out = std::make_shared<MiddlewareScenario>();
+ return Status::OK();
+ }
+ return Status::KeyError("Scenario not found: ", scenario_name);
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/test_integration.h b/src/arrow/cpp/src/arrow/flight/test_integration.h
new file mode 100644
index 000000000..5d9bd7fd7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/test_integration.h
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Integration test scenarios for Arrow Flight.
+
+#include "arrow/flight/visibility.h"
+
+#include <memory>
+#include <string>
+
+#include "arrow/flight/client.h"
+#include "arrow/flight/server.h"
+#include "arrow/status.h"
+
+namespace arrow {
+namespace flight {
+
+/// \brief An integration test for Arrow Flight.
+class ARROW_FLIGHT_EXPORT Scenario {
+ public:
+ virtual ~Scenario() = default;
+ /// \brief Set up the server.
+ virtual Status MakeServer(std::unique_ptr<FlightServerBase>* server,
+ FlightServerOptions* options) = 0;
+ /// \brief Set up the client.
+ virtual Status MakeClient(FlightClientOptions* options) = 0;
+ /// \brief Run the scenario as the client.
+ virtual Status RunClient(std::unique_ptr<FlightClient> client) = 0;
+};
+
+/// \brief Get the implementation of an integration test scenario by name.
+Status GetScenario(const std::string& scenario_name, std::shared_ptr<Scenario>* out);
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/test_integration_client.cc b/src/arrow/cpp/src/arrow/flight/test_integration_client.cc
new file mode 100644
index 000000000..6c1d69046
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/test_integration_client.cc
@@ -0,0 +1,244 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Client implementation for Flight integration testing. Loads
+// RecordBatches from the given JSON file and uploads them to the
+// Flight server, which stores the data and schema in memory. The
+// client then requests the data from the server and compares it to
+// the data originally uploaded.
+
+#include <chrono>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <thread>
+
+#include <gflags/gflags.h>
+
+#include "arrow/io/file.h"
+#include "arrow/io/test_common.h"
+#include "arrow/ipc/dictionary.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+#include "arrow/testing/extension_type.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/json_integration.h"
+#include "arrow/util/logging.h"
+
+#include "arrow/flight/api.h"
+#include "arrow/flight/test_integration.h"
+#include "arrow/flight/test_util.h"
+
+DEFINE_string(host, "localhost", "Server port to connect to");
+DEFINE_int32(port, 31337, "Server port to connect to");
+DEFINE_string(path, "", "Resource path to request");
+DEFINE_string(scenario, "", "Integration test scenario to run");
+
+namespace arrow {
+namespace flight {
+
+/// \brief Helper to read all batches from a JsonReader
+Status ReadBatches(std::unique_ptr<testing::IntegrationJsonReader>& reader,
+ std::vector<std::shared_ptr<RecordBatch>>* chunks) {
+ std::shared_ptr<RecordBatch> chunk;
+ for (int i = 0; i < reader->num_record_batches(); i++) {
+ RETURN_NOT_OK(reader->ReadRecordBatch(i, &chunk));
+ RETURN_NOT_OK(chunk->ValidateFull());
+ chunks->push_back(chunk);
+ }
+ return Status::OK();
+}
+
+/// \brief Upload the a list of batches to a Flight server, validating
+/// the application metadata on the side.
+Status UploadBatchesToFlight(const std::vector<std::shared_ptr<RecordBatch>>& chunks,
+ FlightStreamWriter& writer,
+ FlightMetadataReader& metadata_reader) {
+ int counter = 0;
+ for (const auto& chunk : chunks) {
+ std::shared_ptr<Buffer> metadata = Buffer::FromString(std::to_string(counter));
+ RETURN_NOT_OK(writer.WriteWithMetadata(*chunk, metadata));
+ // Wait for the server to ack the result
+ std::shared_ptr<Buffer> ack_metadata;
+ RETURN_NOT_OK(metadata_reader.ReadMetadata(&ack_metadata));
+ if (!ack_metadata) {
+ return Status::Invalid("Expected metadata value: ", metadata->ToString(),
+ " but got nothing.");
+ } else if (!ack_metadata->Equals(*metadata)) {
+ return Status::Invalid("Expected metadata value: ", metadata->ToString(),
+ " but got: ", ack_metadata->ToString());
+ }
+ counter++;
+ }
+ return writer.Close();
+}
+
+/// \brief Retrieve the given Flight and compare to the original expected batches.
+Status ConsumeFlightLocation(
+ const Location& location, const Ticket& ticket,
+ const std::vector<std::shared_ptr<RecordBatch>>& retrieved_data) {
+ std::unique_ptr<FlightClient> read_client;
+ RETURN_NOT_OK(FlightClient::Connect(location, &read_client));
+
+ std::unique_ptr<FlightStreamReader> stream;
+ RETURN_NOT_OK(read_client->DoGet(ticket, &stream));
+
+ int counter = 0;
+ const int expected = static_cast<int>(retrieved_data.size());
+ for (const auto& original_batch : retrieved_data) {
+ FlightStreamChunk chunk;
+ RETURN_NOT_OK(stream->Next(&chunk));
+ if (chunk.data == nullptr) {
+ return Status::Invalid("Got fewer batches than expected, received so far: ",
+ counter, " expected ", expected);
+ }
+
+ if (!original_batch->Equals(*chunk.data)) {
+ return Status::Invalid("Batch ", counter, " does not match");
+ }
+
+ const auto st = chunk.data->ValidateFull();
+ if (!st.ok()) {
+ return Status::Invalid("Batch ", counter, " is not valid: ", st.ToString());
+ }
+
+ if (std::to_string(counter) != chunk.app_metadata->ToString()) {
+ return Status::Invalid("Expected metadata value: " + std::to_string(counter) +
+ " but got: " + chunk.app_metadata->ToString());
+ }
+ counter++;
+ }
+
+ FlightStreamChunk chunk;
+ RETURN_NOT_OK(stream->Next(&chunk));
+ if (chunk.data != nullptr) {
+ return Status::Invalid("Got more batches than the expected ", expected);
+ }
+
+ return Status::OK();
+}
+
+class IntegrationTestScenario : public flight::Scenario {
+ public:
+ Status MakeServer(std::unique_ptr<FlightServerBase>* server,
+ FlightServerOptions* options) override {
+ ARROW_UNUSED(server);
+ ARROW_UNUSED(options);
+ return Status::NotImplemented("Not implemented, see test_integration_server.cc");
+ }
+
+ Status MakeClient(FlightClientOptions* options) override {
+ ARROW_UNUSED(options);
+ return Status::OK();
+ }
+
+ Status RunClient(std::unique_ptr<FlightClient> client) override {
+ // Make sure the required extension types are registered.
+ ExtensionTypeGuard uuid_ext_guard(uuid());
+ ExtensionTypeGuard dict_ext_guard(dict_extension_type());
+
+ FlightDescriptor descr{FlightDescriptor::PATH, "", {FLAGS_path}};
+
+ // 1. Put the data to the server.
+ std::unique_ptr<testing::IntegrationJsonReader> reader;
+ std::cout << "Opening JSON file '" << FLAGS_path << "'" << std::endl;
+ auto in_file = *io::ReadableFile::Open(FLAGS_path);
+ ABORT_NOT_OK(
+ testing::IntegrationJsonReader::Open(default_memory_pool(), in_file, &reader));
+
+ std::shared_ptr<Schema> original_schema = reader->schema();
+ std::vector<std::shared_ptr<RecordBatch>> original_data;
+ ABORT_NOT_OK(ReadBatches(reader, &original_data));
+
+ std::unique_ptr<FlightStreamWriter> write_stream;
+ std::unique_ptr<FlightMetadataReader> metadata_reader;
+ ABORT_NOT_OK(client->DoPut(descr, original_schema, &write_stream, &metadata_reader));
+ ABORT_NOT_OK(UploadBatchesToFlight(original_data, *write_stream, *metadata_reader));
+
+ // 2. Get the ticket for the data.
+ std::unique_ptr<FlightInfo> info;
+ ABORT_NOT_OK(client->GetFlightInfo(descr, &info));
+
+ std::shared_ptr<Schema> schema;
+ ipc::DictionaryMemo dict_memo;
+ ABORT_NOT_OK(info->GetSchema(&dict_memo, &schema));
+
+ if (info->endpoints().size() == 0) {
+ std::cerr << "No endpoints returned from Flight server." << std::endl;
+ return Status::IOError("No endpoints returned from Flight server.");
+ }
+
+ for (const FlightEndpoint& endpoint : info->endpoints()) {
+ const auto& ticket = endpoint.ticket;
+
+ auto locations = endpoint.locations;
+ if (locations.size() == 0) {
+ return Status::IOError("No locations returned from Flight server.");
+ }
+
+ for (const auto& location : locations) {
+ std::cout << "Verifying location " << location.ToString() << std::endl;
+ // 3. Stream data from the server, comparing individual batches.
+ ABORT_NOT_OK(ConsumeFlightLocation(location, ticket, original_data));
+ }
+ }
+ return Status::OK();
+ }
+};
+
+} // namespace flight
+} // namespace arrow
+
+constexpr int kRetries = 3;
+
+arrow::Status RunScenario(arrow::flight::Scenario* scenario) {
+ auto options = arrow::flight::FlightClientOptions::Defaults();
+ std::unique_ptr<arrow::flight::FlightClient> client;
+
+ RETURN_NOT_OK(scenario->MakeClient(&options));
+ arrow::flight::Location location;
+ RETURN_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port, &location));
+ RETURN_NOT_OK(arrow::flight::FlightClient::Connect(location, options, &client));
+ return scenario->RunClient(std::move(client));
+}
+
+int main(int argc, char** argv) {
+ arrow::util::ArrowLog::InstallFailureSignalHandler();
+
+ gflags::SetUsageMessage("Integration testing client for Flight.");
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+ std::shared_ptr<arrow::flight::Scenario> scenario;
+ if (!FLAGS_scenario.empty()) {
+ ARROW_CHECK_OK(arrow::flight::GetScenario(FLAGS_scenario, &scenario));
+ } else {
+ scenario = std::make_shared<arrow::flight::IntegrationTestScenario>();
+ }
+
+ // ARROW-11908: retry a few times in case a client is slow to bring up the server
+ auto status = arrow::Status::OK();
+ for (int i = 0; i < kRetries; i++) {
+ status = RunScenario(scenario.get());
+ if (status.ok()) break;
+ // Failed, wait a bit and try again
+ std::this_thread::sleep_for(std::chrono::milliseconds((i + 1) * 500));
+ }
+ ABORT_NOT_OK(status);
+
+ arrow::util::ArrowLog::UninstallSignalAction();
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/flight/test_integration_server.cc b/src/arrow/cpp/src/arrow/flight/test_integration_server.cc
new file mode 100644
index 000000000..4b904b0eb
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/test_integration_server.cc
@@ -0,0 +1,207 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Server for integration testing.
+
+// Integration testing covers files and scenarios. The former
+// validates that Arrow data survives a round-trip through a Flight
+// service. The latter tests specific features of Arrow Flight.
+
+#include <signal.h>
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include <gflags/gflags.h>
+
+#include "arrow/io/test_common.h"
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+#include "arrow/testing/json_integration.h"
+#include "arrow/util/logging.h"
+
+#include "arrow/flight/internal.h"
+#include "arrow/flight/server.h"
+#include "arrow/flight/server_auth.h"
+#include "arrow/flight/test_integration.h"
+#include "arrow/flight/test_util.h"
+
+DEFINE_int32(port, 31337, "Server port to listen on");
+DEFINE_string(scenario, "", "Integration test senario to run");
+
+namespace arrow {
+namespace flight {
+
+struct IntegrationDataset {
+ std::shared_ptr<Schema> schema;
+ std::vector<std::shared_ptr<RecordBatch>> chunks;
+};
+
+class RecordBatchListReader : public RecordBatchReader {
+ public:
+ explicit RecordBatchListReader(IntegrationDataset dataset)
+ : dataset_(dataset), current_(0) {}
+
+ std::shared_ptr<Schema> schema() const override { return dataset_.schema; }
+
+ Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
+ if (current_ >= dataset_.chunks.size()) {
+ *batch = nullptr;
+ return Status::OK();
+ }
+ *batch = dataset_.chunks[current_];
+ current_++;
+ return Status::OK();
+ }
+
+ private:
+ IntegrationDataset dataset_;
+ uint64_t current_;
+};
+
+class FlightIntegrationTestServer : public FlightServerBase {
+ Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
+ std::unique_ptr<FlightInfo>* info) override {
+ if (request.type == FlightDescriptor::PATH) {
+ if (request.path.size() == 0) {
+ return Status::Invalid("Invalid path");
+ }
+
+ auto data = uploaded_chunks.find(request.path[0]);
+ if (data == uploaded_chunks.end()) {
+ return Status::KeyError("Could not find flight.", request.path[0]);
+ }
+ auto flight = data->second;
+
+ Location server_location;
+ RETURN_NOT_OK(Location::ForGrpcTcp("127.0.0.1", port(), &server_location));
+ FlightEndpoint endpoint1({{request.path[0]}, {server_location}});
+
+ FlightInfo::Data flight_data;
+ RETURN_NOT_OK(internal::SchemaToString(*flight.schema, &flight_data.schema));
+ flight_data.descriptor = request;
+ flight_data.endpoints = {endpoint1};
+ flight_data.total_records = 0;
+ for (const auto& chunk : flight.chunks) {
+ flight_data.total_records += chunk->num_rows();
+ }
+ flight_data.total_bytes = -1;
+ FlightInfo value(flight_data);
+
+ *info = std::unique_ptr<FlightInfo>(new FlightInfo(value));
+ return Status::OK();
+ } else {
+ return Status::NotImplemented(request.type);
+ }
+ }
+
+ Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr<FlightDataStream>* data_stream) override {
+ auto data = uploaded_chunks.find(request.ticket);
+ if (data == uploaded_chunks.end()) {
+ return Status::KeyError("Could not find flight.", request.ticket);
+ }
+ auto flight = data->second;
+
+ *data_stream = std::unique_ptr<FlightDataStream>(
+ new NumberingStream(std::unique_ptr<FlightDataStream>(new RecordBatchStream(
+ std::shared_ptr<RecordBatchReader>(new RecordBatchListReader(flight))))));
+
+ return Status::OK();
+ }
+
+ Status DoPut(const ServerCallContext& context,
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer) override {
+ const FlightDescriptor& descriptor = reader->descriptor();
+
+ if (descriptor.type != FlightDescriptor::DescriptorType::PATH) {
+ return Status::Invalid("Must specify a path");
+ } else if (descriptor.path.size() < 1) {
+ return Status::Invalid("Must specify a path");
+ }
+
+ std::string key = descriptor.path[0];
+
+ IntegrationDataset dataset;
+ ARROW_ASSIGN_OR_RAISE(dataset.schema, reader->GetSchema());
+ arrow::flight::FlightStreamChunk chunk;
+ while (true) {
+ RETURN_NOT_OK(reader->Next(&chunk));
+ if (chunk.data == nullptr) break;
+ RETURN_NOT_OK(chunk.data->ValidateFull());
+ dataset.chunks.push_back(chunk.data);
+ if (chunk.app_metadata) {
+ RETURN_NOT_OK(writer->WriteMetadata(*chunk.app_metadata));
+ }
+ }
+ uploaded_chunks[key] = dataset;
+ return Status::OK();
+ }
+
+ std::unordered_map<std::string, IntegrationDataset> uploaded_chunks;
+};
+
+class IntegrationTestScenario : public Scenario {
+ public:
+ Status MakeServer(std::unique_ptr<FlightServerBase>* server,
+ FlightServerOptions* options) override {
+ server->reset(new FlightIntegrationTestServer());
+ return Status::OK();
+ }
+
+ Status MakeClient(FlightClientOptions* options) override {
+ ARROW_UNUSED(options);
+ return Status::NotImplemented("Not implemented, see test_integration_client.cc");
+ }
+
+ Status RunClient(std::unique_ptr<FlightClient> client) override {
+ ARROW_UNUSED(client);
+ return Status::NotImplemented("Not implemented, see test_integration_client.cc");
+ }
+};
+
+} // namespace flight
+} // namespace arrow
+
+std::unique_ptr<arrow::flight::FlightServerBase> g_server;
+
+int main(int argc, char** argv) {
+ gflags::SetUsageMessage("Integration testing server for Flight.");
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+ std::shared_ptr<arrow::flight::Scenario> scenario;
+
+ if (!FLAGS_scenario.empty()) {
+ ARROW_CHECK_OK(arrow::flight::GetScenario(FLAGS_scenario, &scenario));
+ } else {
+ scenario = std::make_shared<arrow::flight::IntegrationTestScenario>();
+ }
+ arrow::flight::Location location;
+ ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location));
+ arrow::flight::FlightServerOptions options(location);
+
+ ARROW_CHECK_OK(scenario->MakeServer(&g_server, &options));
+
+ ARROW_CHECK_OK(g_server->Init(options));
+ // Exit with a clean error code (0) on SIGTERM
+ ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
+
+ std::cout << "Server listening on localhost:" << g_server->port() << std::endl;
+ ARROW_CHECK_OK(g_server->Serve());
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/flight/test_server.cc b/src/arrow/cpp/src/arrow/flight/test_server.cc
new file mode 100644
index 000000000..2e5b10f84
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/test_server.cc
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Example server implementation to use for unit testing and benchmarking
+// purposes
+
+#include <signal.h>
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include <gflags/gflags.h>
+
+#include "arrow/flight/server.h"
+#include "arrow/flight/test_util.h"
+#include "arrow/flight/types.h"
+#include "arrow/util/logging.h"
+
+DEFINE_int32(port, 31337, "Server port to listen on");
+DEFINE_string(unix, "", "Unix socket path to listen on");
+
+std::unique_ptr<arrow::flight::FlightServerBase> g_server;
+
+int main(int argc, char** argv) {
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+ g_server = arrow::flight::ExampleTestServer();
+
+ arrow::flight::Location location;
+ if (FLAGS_unix.empty()) {
+ ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location));
+ } else {
+ ARROW_CHECK_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_unix, &location));
+ }
+ arrow::flight::FlightServerOptions options(location);
+
+ ARROW_CHECK_OK(g_server->Init(options));
+ // Exit with a clean error code (0) on SIGTERM
+ ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
+
+ if (FLAGS_unix.empty()) {
+ std::cout << "Server listening on localhost:" << FLAGS_port << std::endl;
+ } else {
+ std::cout << "Server listening on unix://" << FLAGS_unix << std::endl;
+ }
+ ARROW_CHECK_OK(g_server->Serve());
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/flight/test_util.cc b/src/arrow/cpp/src/arrow/flight/test_util.cc
new file mode 100644
index 000000000..d10b82807
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/test_util.cc
@@ -0,0 +1,822 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/platform.h"
+
+#ifdef __APPLE__
+#include <limits.h>
+#include <mach-o/dyld.h>
+#endif
+
+#include <algorithm>
+#include <cstdlib>
+#include <sstream>
+
+#include <boost/filesystem.hpp>
+// We need BOOST_USE_WINDOWS_H definition with MinGW when we use
+// boost/process.hpp. See ARROW_BOOST_PROCESS_COMPILE_DEFINITIONS in
+// cpp/cmake_modules/BuildUtils.cmake for details.
+#include <boost/process.hpp>
+
+#include <gtest/gtest.h>
+
+#include "arrow/ipc/test_common.h"
+#include "arrow/testing/generator.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/logging.h"
+
+#include "arrow/flight/api.h"
+#include "arrow/flight/internal.h"
+#include "arrow/flight/test_util.h"
+
+namespace arrow {
+namespace flight {
+
+namespace bp = boost::process;
+namespace fs = boost::filesystem;
+
+namespace {
+
+Status ResolveCurrentExecutable(fs::path* out) {
+ // See https://stackoverflow.com/a/1024937/10194 for various
+ // platform-specific recipes.
+
+ boost::system::error_code ec;
+
+#if defined(__linux__)
+ *out = fs::canonical("/proc/self/exe", ec);
+#elif defined(__APPLE__)
+ char buf[PATH_MAX + 1];
+ uint32_t bufsize = sizeof(buf);
+ if (_NSGetExecutablePath(buf, &bufsize) < 0) {
+ return Status::Invalid("Can't resolve current exe: path too large");
+ }
+ *out = fs::canonical(buf, ec);
+#elif defined(_WIN32)
+ char buf[MAX_PATH + 1];
+ if (!GetModuleFileNameA(NULL, buf, sizeof(buf))) {
+ return Status::Invalid("Can't get executable file path");
+ }
+ *out = fs::canonical(buf, ec);
+#else
+ ARROW_UNUSED(ec);
+ return Status::NotImplemented("Not available on this system");
+#endif
+ if (ec) {
+ // XXX fold this into the Status class?
+ return Status::IOError("Can't resolve current exe: ", ec.message());
+ } else {
+ return Status::OK();
+ }
+}
+
+} // namespace
+
+void TestServer::Start(const std::vector<std::string>& extra_args) {
+ namespace fs = boost::filesystem;
+
+ std::string str_port = std::to_string(port_);
+ std::vector<fs::path> search_path = ::boost::this_process::path();
+ // If possible, prepend current executable directory to search path,
+ // since it's likely that the test server executable is located in
+ // the same directory as the running unit test.
+ fs::path current_exe;
+ Status st = ResolveCurrentExecutable(&current_exe);
+ if (st.ok()) {
+ search_path.insert(search_path.begin(), current_exe.parent_path());
+ } else if (st.IsNotImplemented()) {
+ ARROW_CHECK(st.IsNotImplemented()) << st.ToString();
+ }
+
+ try {
+ if (unix_sock_.empty()) {
+ server_process_ =
+ std::make_shared<bp::child>(bp::search_path(executable_name_, search_path),
+ "-port", str_port, bp::args(extra_args));
+ } else {
+ server_process_ =
+ std::make_shared<bp::child>(bp::search_path(executable_name_, search_path),
+ "-server_unix", unix_sock_, bp::args(extra_args));
+ }
+ } catch (...) {
+ std::stringstream ss;
+ ss << "Failed to launch test server '" << executable_name_ << "', looked in ";
+ for (const auto& path : search_path) {
+ ss << path << " : ";
+ }
+ ARROW_LOG(FATAL) << ss.str();
+ throw;
+ }
+ std::cout << "Server running with pid " << server_process_->id() << std::endl;
+}
+
+int TestServer::Stop() {
+ if (server_process_ && server_process_->valid()) {
+#ifndef _WIN32
+ kill(server_process_->id(), SIGTERM);
+#else
+ // This would use SIGKILL on POSIX, which is more brutal than SIGTERM
+ server_process_->terminate();
+#endif
+ server_process_->wait();
+ return server_process_->exit_code();
+ } else {
+ // Presumably the server wasn't able to start
+ return -1;
+ }
+}
+
+bool TestServer::IsRunning() { return server_process_->running(); }
+
+int TestServer::port() const { return port_; }
+
+const std::string& TestServer::unix_sock() const { return unix_sock_; }
+
+Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr<RecordBatchReader>* out) {
+ if (ticket.ticket == "ticket-ints-1") {
+ BatchVector batches;
+ RETURN_NOT_OK(ExampleIntBatches(&batches));
+ *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
+ return Status::OK();
+ } else if (ticket.ticket == "ticket-floats-1") {
+ BatchVector batches;
+ RETURN_NOT_OK(ExampleFloatBatches(&batches));
+ *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
+ return Status::OK();
+ } else if (ticket.ticket == "ticket-dicts-1") {
+ BatchVector batches;
+ RETURN_NOT_OK(ExampleDictBatches(&batches));
+ *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
+ return Status::OK();
+ } else if (ticket.ticket == "ticket-large-batch-1") {
+ BatchVector batches;
+ RETURN_NOT_OK(ExampleLargeBatches(&batches));
+ *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
+ return Status::OK();
+ } else {
+ return Status::NotImplemented("no stream implemented for ticket: " + ticket.ticket);
+ }
+}
+
+class FlightTestServer : public FlightServerBase {
+ Status ListFlights(const ServerCallContext& context, const Criteria* criteria,
+ std::unique_ptr<FlightListing>* listings) override {
+ std::vector<FlightInfo> flights = ExampleFlightInfo();
+ if (criteria && criteria->expression != "") {
+ // For test purposes, if we get criteria, return no results
+ flights.clear();
+ }
+ *listings = std::unique_ptr<FlightListing>(new SimpleFlightListing(flights));
+ return Status::OK();
+ }
+
+ Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
+ std::unique_ptr<FlightInfo>* out) override {
+ // Test that Arrow-C++ status codes can make it through gRPC
+ if (request.type == FlightDescriptor::DescriptorType::CMD &&
+ request.cmd == "status-outofmemory") {
+ return Status::OutOfMemory("Sentinel");
+ }
+
+ std::vector<FlightInfo> flights = ExampleFlightInfo();
+
+ for (const auto& info : flights) {
+ if (info.descriptor().Equals(request)) {
+ *out = std::unique_ptr<FlightInfo>(new FlightInfo(info));
+ return Status::OK();
+ }
+ }
+ return Status::Invalid("Flight not found: ", request.ToString());
+ }
+
+ Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr<FlightDataStream>* data_stream) override {
+ // Test for ARROW-5095
+ if (request.ticket == "ARROW-5095-fail") {
+ return Status::UnknownError("Server-side error");
+ }
+ if (request.ticket == "ARROW-5095-success") {
+ return Status::OK();
+ }
+ if (request.ticket == "ARROW-13253-DoGet-Batch") {
+ // Make batch > 2GiB in size
+ ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch());
+ ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch}));
+ *data_stream =
+ std::unique_ptr<FlightDataStream>(new RecordBatchStream(std::move(reader)));
+ return Status::OK();
+ }
+
+ std::shared_ptr<RecordBatchReader> batch_reader;
+ RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader));
+
+ *data_stream = std::unique_ptr<FlightDataStream>(new RecordBatchStream(batch_reader));
+ return Status::OK();
+ }
+
+ Status DoPut(const ServerCallContext&, std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer) override {
+ BatchVector batches;
+ return reader->ReadAll(&batches);
+ }
+
+ Status DoExchange(const ServerCallContext& context,
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMessageWriter> writer) override {
+ // Test various scenarios for a DoExchange
+ if (reader->descriptor().type != FlightDescriptor::DescriptorType::CMD) {
+ return Status::Invalid("Must provide a command descriptor");
+ }
+
+ const std::string& cmd = reader->descriptor().cmd;
+ if (cmd == "error") {
+ // Immediately return an error to the client.
+ return Status::NotImplemented("Expected error");
+ } else if (cmd == "get") {
+ return RunExchangeGet(std::move(reader), std::move(writer));
+ } else if (cmd == "put") {
+ return RunExchangePut(std::move(reader), std::move(writer));
+ } else if (cmd == "counter") {
+ return RunExchangeCounter(std::move(reader), std::move(writer));
+ } else if (cmd == "total") {
+ return RunExchangeTotal(std::move(reader), std::move(writer));
+ } else if (cmd == "echo") {
+ return RunExchangeEcho(std::move(reader), std::move(writer));
+ } else if (cmd == "large_batch") {
+ return RunExchangeLargeBatch(std::move(reader), std::move(writer));
+ } else {
+ return Status::NotImplemented("Scenario not implemented: ", cmd);
+ }
+ }
+
+ // A simple example - act like DoGet.
+ Status RunExchangeGet(std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMessageWriter> writer) {
+ RETURN_NOT_OK(writer->Begin(ExampleIntSchema()));
+ BatchVector batches;
+ RETURN_NOT_OK(ExampleIntBatches(&batches));
+ for (const auto& batch : batches) {
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+ }
+ return Status::OK();
+ }
+
+ // A simple example - act like DoPut
+ Status RunExchangePut(std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMessageWriter> writer) {
+ ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
+ if (!schema->Equals(ExampleIntSchema(), false)) {
+ return Status::Invalid("Schema is not as expected");
+ }
+ BatchVector batches;
+ RETURN_NOT_OK(ExampleIntBatches(&batches));
+ FlightStreamChunk chunk;
+ for (const auto& batch : batches) {
+ RETURN_NOT_OK(reader->Next(&chunk));
+ if (!chunk.data) {
+ return Status::Invalid("Expected another batch");
+ }
+ if (!batch->Equals(*chunk.data)) {
+ return Status::Invalid("Batch does not match");
+ }
+ }
+ RETURN_NOT_OK(reader->Next(&chunk));
+ if (chunk.data || chunk.app_metadata) {
+ return Status::Invalid("Too many batches");
+ }
+
+ RETURN_NOT_OK(writer->WriteMetadata(Buffer::FromString("done")));
+ return Status::OK();
+ }
+
+ // Read some number of record batches from the client, send a
+ // metadata message back with the count, then echo the batches back.
+ Status RunExchangeCounter(std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMessageWriter> writer) {
+ std::vector<std::shared_ptr<RecordBatch>> batches;
+ FlightStreamChunk chunk;
+ int chunks = 0;
+ while (true) {
+ RETURN_NOT_OK(reader->Next(&chunk));
+ if (!chunk.data && !chunk.app_metadata) {
+ break;
+ }
+ if (chunk.data) {
+ batches.push_back(chunk.data);
+ chunks++;
+ }
+ }
+
+ // Echo back the number of record batches read.
+ std::shared_ptr<Buffer> buf = Buffer::FromString(std::to_string(chunks));
+ RETURN_NOT_OK(writer->WriteMetadata(buf));
+ // Echo the record batches themselves.
+ if (chunks > 0) {
+ ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
+ RETURN_NOT_OK(writer->Begin(schema));
+
+ for (const auto& batch : batches) {
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+ }
+ }
+
+ return Status::OK();
+ }
+
+ // Read int64 batches from the client, each time sending back a
+ // batch with a running sum of columns.
+ Status RunExchangeTotal(std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMessageWriter> writer) {
+ FlightStreamChunk chunk{};
+ ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
+ // Ensure the schema contains only int64 columns
+ for (const auto& field : schema->fields()) {
+ if (field->type()->id() != Type::type::INT64) {
+ return Status::Invalid("Field is not INT64: ", field->name());
+ }
+ }
+ std::vector<int64_t> sums(schema->num_fields());
+ std::vector<std::shared_ptr<Array>> columns(schema->num_fields());
+ RETURN_NOT_OK(writer->Begin(schema));
+ while (true) {
+ RETURN_NOT_OK(reader->Next(&chunk));
+ if (!chunk.data && !chunk.app_metadata) {
+ break;
+ }
+ if (chunk.data) {
+ if (!chunk.data->schema()->Equals(schema, false)) {
+ // A compliant client implementation would make this impossible
+ return Status::Invalid("Schemas are incompatible");
+ }
+
+ // Update the running totals
+ auto builder = std::make_shared<Int64Builder>();
+ int col_index = 0;
+ for (const auto& column : chunk.data->columns()) {
+ auto arr = std::dynamic_pointer_cast<Int64Array>(column);
+ if (!arr) {
+ return MakeFlightError(FlightStatusCode::Internal, "Could not cast array");
+ }
+ for (int row = 0; row < column->length(); row++) {
+ if (!arr->IsNull(row)) {
+ sums[col_index] += arr->Value(row);
+ }
+ }
+
+ builder->Reset();
+ RETURN_NOT_OK(builder->Append(sums[col_index]));
+ RETURN_NOT_OK(builder->Finish(&columns[col_index]));
+
+ col_index++;
+ }
+
+ // Echo the totals to the client
+ auto response = RecordBatch::Make(schema, /* num_rows */ 1, columns);
+ RETURN_NOT_OK(writer->WriteRecordBatch(*response));
+ }
+ }
+ return Status::OK();
+ }
+
+ // Echo the client's messages back.
+ Status RunExchangeEcho(std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMessageWriter> writer) {
+ FlightStreamChunk chunk;
+ bool begun = false;
+ while (true) {
+ RETURN_NOT_OK(reader->Next(&chunk));
+ if (!chunk.data && !chunk.app_metadata) {
+ break;
+ }
+ if (!begun && chunk.data) {
+ begun = true;
+ RETURN_NOT_OK(writer->Begin(chunk.data->schema()));
+ }
+ if (chunk.data && chunk.app_metadata) {
+ RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata));
+ } else if (chunk.data) {
+ RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data));
+ } else if (chunk.app_metadata) {
+ RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata));
+ }
+ }
+ return Status::OK();
+ }
+
+ // Regression test for ARROW-13253
+ Status RunExchangeLargeBatch(std::unique_ptr<FlightMessageReader>,
+ std::unique_ptr<FlightMessageWriter> writer) {
+ ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch());
+ RETURN_NOT_OK(writer->Begin(batch->schema()));
+ return writer->WriteRecordBatch(*batch);
+ }
+
+ Status RunAction1(const Action& action, std::unique_ptr<ResultStream>* out) {
+ std::vector<Result> results;
+ for (int i = 0; i < 3; ++i) {
+ Result result;
+ std::string value = action.body->ToString() + "-part" + std::to_string(i);
+ result.body = Buffer::FromString(std::move(value));
+ results.push_back(result);
+ }
+ *out = std::unique_ptr<ResultStream>(new SimpleResultStream(std::move(results)));
+ return Status::OK();
+ }
+
+ Status RunAction2(std::unique_ptr<ResultStream>* out) {
+ // Empty
+ *out = std::unique_ptr<ResultStream>(new SimpleResultStream({}));
+ return Status::OK();
+ }
+
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* out) override {
+ if (action.type == "action1") {
+ return RunAction1(action, out);
+ } else if (action.type == "action2") {
+ return RunAction2(out);
+ } else {
+ return Status::NotImplemented(action.type);
+ }
+ }
+
+ Status ListActions(const ServerCallContext& context,
+ std::vector<ActionType>* out) override {
+ std::vector<ActionType> actions = ExampleActionTypes();
+ *out = std::move(actions);
+ return Status::OK();
+ }
+
+ Status GetSchema(const ServerCallContext& context, const FlightDescriptor& request,
+ std::unique_ptr<SchemaResult>* schema) override {
+ std::vector<FlightInfo> flights = ExampleFlightInfo();
+
+ for (const auto& info : flights) {
+ if (info.descriptor().Equals(request)) {
+ *schema =
+ std::unique_ptr<SchemaResult>(new SchemaResult(info.serialized_schema()));
+ return Status::OK();
+ }
+ }
+ return Status::Invalid("Flight not found: ", request.ToString());
+ }
+};
+
+std::unique_ptr<FlightServerBase> ExampleTestServer() {
+ return std::unique_ptr<FlightServerBase>(new FlightTestServer);
+}
+
+Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor,
+ const std::vector<FlightEndpoint>& endpoints, int64_t total_records,
+ int64_t total_bytes, FlightInfo::Data* out) {
+ out->descriptor = descriptor;
+ out->endpoints = endpoints;
+ out->total_records = total_records;
+ out->total_bytes = total_bytes;
+ return internal::SchemaToString(schema, &out->schema);
+}
+
+NumberingStream::NumberingStream(std::unique_ptr<FlightDataStream> stream)
+ : counter_(0), stream_(std::move(stream)) {}
+
+std::shared_ptr<Schema> NumberingStream::schema() { return stream_->schema(); }
+
+Status NumberingStream::GetSchemaPayload(FlightPayload* payload) {
+ return stream_->GetSchemaPayload(payload);
+}
+
+Status NumberingStream::Next(FlightPayload* payload) {
+ RETURN_NOT_OK(stream_->Next(payload));
+ if (payload && payload->ipc_message.type == ipc::MessageType::RECORD_BATCH) {
+ payload->app_metadata = Buffer::FromString(std::to_string(counter_));
+ counter_++;
+ }
+ return Status::OK();
+}
+
+std::shared_ptr<Schema> ExampleIntSchema() {
+ auto f0 = field("f0", int8());
+ auto f1 = field("f1", uint8());
+ auto f2 = field("f2", int16());
+ auto f3 = field("f3", uint16());
+ auto f4 = field("f4", int32());
+ auto f5 = field("f5", uint32());
+ auto f6 = field("f6", int64());
+ auto f7 = field("f7", uint64());
+ return ::arrow::schema({f0, f1, f2, f3, f4, f5, f6, f7});
+}
+
+std::shared_ptr<Schema> ExampleFloatSchema() {
+ auto f0 = field("f0", float16());
+ auto f1 = field("f1", float32());
+ auto f2 = field("f2", float64());
+ return ::arrow::schema({f0, f1, f2});
+}
+
+std::shared_ptr<Schema> ExampleStringSchema() {
+ auto f0 = field("f0", utf8());
+ auto f1 = field("f1", binary());
+ return ::arrow::schema({f0, f1});
+}
+
+std::shared_ptr<Schema> ExampleDictSchema() {
+ std::shared_ptr<RecordBatch> batch;
+ ABORT_NOT_OK(ipc::test::MakeDictionary(&batch));
+ return batch->schema();
+}
+
+std::shared_ptr<Schema> ExampleLargeSchema() {
+ std::vector<std::shared_ptr<arrow::Field>> fields;
+ for (int i = 0; i < 128; i++) {
+ const auto field_name = "f" + std::to_string(i);
+ fields.push_back(arrow::field(field_name, arrow::float64()));
+ }
+ return arrow::schema(fields);
+}
+
+std::vector<FlightInfo> ExampleFlightInfo() {
+ Location location1;
+ Location location2;
+ Location location3;
+ Location location4;
+ Location location5;
+ ARROW_EXPECT_OK(Location::ForGrpcTcp("foo1.bar.com", 12345, &location1));
+ ARROW_EXPECT_OK(Location::ForGrpcTcp("foo2.bar.com", 12345, &location2));
+ ARROW_EXPECT_OK(Location::ForGrpcTcp("foo3.bar.com", 12345, &location3));
+ ARROW_EXPECT_OK(Location::ForGrpcTcp("foo4.bar.com", 12345, &location4));
+ ARROW_EXPECT_OK(Location::ForGrpcTcp("foo5.bar.com", 12345, &location5));
+
+ FlightInfo::Data flight1, flight2, flight3, flight4;
+
+ FlightEndpoint endpoint1({{"ticket-ints-1"}, {location1}});
+ FlightEndpoint endpoint2({{"ticket-ints-2"}, {location2}});
+ FlightEndpoint endpoint3({{"ticket-cmd"}, {location3}});
+ FlightEndpoint endpoint4({{"ticket-dicts-1"}, {location4}});
+ FlightEndpoint endpoint5({{"ticket-floats-1"}, {location5}});
+
+ FlightDescriptor descr1{FlightDescriptor::PATH, "", {"examples", "ints"}};
+ FlightDescriptor descr2{FlightDescriptor::CMD, "my_command", {}};
+ FlightDescriptor descr3{FlightDescriptor::PATH, "", {"examples", "dicts"}};
+ FlightDescriptor descr4{FlightDescriptor::PATH, "", {"examples", "floats"}};
+
+ auto schema1 = ExampleIntSchema();
+ auto schema2 = ExampleStringSchema();
+ auto schema3 = ExampleDictSchema();
+ auto schema4 = ExampleFloatSchema();
+
+ ARROW_EXPECT_OK(
+ MakeFlightInfo(*schema1, descr1, {endpoint1, endpoint2}, 1000, 100000, &flight1));
+ ARROW_EXPECT_OK(MakeFlightInfo(*schema2, descr2, {endpoint3}, 1000, 100000, &flight2));
+ ARROW_EXPECT_OK(MakeFlightInfo(*schema3, descr3, {endpoint4}, -1, -1, &flight3));
+ ARROW_EXPECT_OK(MakeFlightInfo(*schema4, descr4, {endpoint5}, 1000, 100000, &flight4));
+ return {FlightInfo(flight1), FlightInfo(flight2), FlightInfo(flight3),
+ FlightInfo(flight4)};
+}
+
+Status ExampleIntBatches(BatchVector* out) {
+ std::shared_ptr<RecordBatch> batch;
+ for (int i = 0; i < 5; ++i) {
+ // Make all different sizes, use different random seed
+ RETURN_NOT_OK(ipc::test::MakeIntBatchSized(10 + i, &batch, i));
+ out->push_back(batch);
+ }
+ return Status::OK();
+}
+
+Status ExampleFloatBatches(BatchVector* out) {
+ std::shared_ptr<RecordBatch> batch;
+ for (int i = 0; i < 5; ++i) {
+ // Make all different sizes, use different random seed
+ RETURN_NOT_OK(ipc::test::MakeFloatBatchSized(10 + i, &batch, i));
+ out->push_back(batch);
+ }
+ return Status::OK();
+}
+
+Status ExampleDictBatches(BatchVector* out) {
+ // Just the same batch, repeated a few times
+ std::shared_ptr<RecordBatch> batch;
+ for (int i = 0; i < 3; ++i) {
+ RETURN_NOT_OK(ipc::test::MakeDictionary(&batch));
+ out->push_back(batch);
+ }
+ return Status::OK();
+}
+
+Status ExampleNestedBatches(BatchVector* out) {
+ std::shared_ptr<RecordBatch> batch;
+ for (int i = 0; i < 3; ++i) {
+ RETURN_NOT_OK(ipc::test::MakeListRecordBatch(&batch));
+ out->push_back(batch);
+ }
+ return Status::OK();
+}
+
+Status ExampleLargeBatches(BatchVector* out) {
+ const auto array_length = 32768;
+ std::shared_ptr<RecordBatch> batch;
+ std::vector<std::shared_ptr<arrow::Array>> arrays;
+ const auto arr = arrow::ConstantArrayGenerator::Float64(array_length, 1.0);
+ for (int i = 0; i < 128; i++) {
+ arrays.push_back(arr);
+ }
+ auto schema = ExampleLargeSchema();
+ out->push_back(RecordBatch::Make(schema, array_length, arrays));
+ out->push_back(RecordBatch::Make(schema, array_length, arrays));
+ return Status::OK();
+}
+
+arrow::Result<std::shared_ptr<RecordBatch>> VeryLargeBatch() {
+ // In CI, some platforms don't let us allocate one very large
+ // buffer, so allocate a smaller buffer and repeat it a few times
+ constexpr int64_t nbytes = (1ul << 27ul) + 8ul;
+ constexpr int64_t nrows = nbytes / 8;
+ constexpr int64_t ncols = 16;
+ ARROW_ASSIGN_OR_RAISE(auto values, AllocateBuffer(nbytes));
+ std::memset(values->mutable_data(), 0x00, values->capacity());
+ std::vector<std::shared_ptr<Buffer>> buffers = {nullptr, std::move(values)};
+ auto array = std::make_shared<ArrayData>(int64(), nrows, buffers,
+ /*null_count=*/0);
+ std::vector<std::shared_ptr<ArrayData>> arrays(ncols, array);
+ std::vector<std::shared_ptr<Field>> fields(ncols, field("a", int64()));
+ return RecordBatch::Make(schema(std::move(fields)), nrows, std::move(arrays));
+}
+
+std::vector<ActionType> ExampleActionTypes() {
+ return {{"drop", "drop a dataset"}, {"cache", "cache a dataset"}};
+}
+
+TestServerAuthHandler::TestServerAuthHandler(const std::string& username,
+ const std::string& password)
+ : username_(username), password_(password) {}
+
+TestServerAuthHandler::~TestServerAuthHandler() {}
+
+Status TestServerAuthHandler::Authenticate(ServerAuthSender* outgoing,
+ ServerAuthReader* incoming) {
+ std::string token;
+ RETURN_NOT_OK(incoming->Read(&token));
+ if (token != password_) {
+ return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
+ }
+ RETURN_NOT_OK(outgoing->Write(username_));
+ return Status::OK();
+}
+
+Status TestServerAuthHandler::IsValid(const std::string& token,
+ std::string* peer_identity) {
+ if (token != password_) {
+ return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
+ }
+ *peer_identity = username_;
+ return Status::OK();
+}
+
+TestServerBasicAuthHandler::TestServerBasicAuthHandler(const std::string& username,
+ const std::string& password) {
+ basic_auth_.username = username;
+ basic_auth_.password = password;
+}
+
+TestServerBasicAuthHandler::~TestServerBasicAuthHandler() {}
+
+Status TestServerBasicAuthHandler::Authenticate(ServerAuthSender* outgoing,
+ ServerAuthReader* incoming) {
+ std::string token;
+ RETURN_NOT_OK(incoming->Read(&token));
+ BasicAuth incoming_auth;
+ RETURN_NOT_OK(BasicAuth::Deserialize(token, &incoming_auth));
+ if (incoming_auth.username != basic_auth_.username ||
+ incoming_auth.password != basic_auth_.password) {
+ return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
+ }
+ RETURN_NOT_OK(outgoing->Write(basic_auth_.username));
+ return Status::OK();
+}
+
+Status TestServerBasicAuthHandler::IsValid(const std::string& token,
+ std::string* peer_identity) {
+ if (token != basic_auth_.username) {
+ return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
+ }
+ *peer_identity = basic_auth_.username;
+ return Status::OK();
+}
+
+TestClientAuthHandler::TestClientAuthHandler(const std::string& username,
+ const std::string& password)
+ : username_(username), password_(password) {}
+
+TestClientAuthHandler::~TestClientAuthHandler() {}
+
+Status TestClientAuthHandler::Authenticate(ClientAuthSender* outgoing,
+ ClientAuthReader* incoming) {
+ RETURN_NOT_OK(outgoing->Write(password_));
+ std::string username;
+ RETURN_NOT_OK(incoming->Read(&username));
+ if (username != username_) {
+ return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
+ }
+ return Status::OK();
+}
+
+Status TestClientAuthHandler::GetToken(std::string* token) {
+ *token = password_;
+ return Status::OK();
+}
+
+TestClientBasicAuthHandler::TestClientBasicAuthHandler(const std::string& username,
+ const std::string& password) {
+ basic_auth_.username = username;
+ basic_auth_.password = password;
+}
+
+TestClientBasicAuthHandler::~TestClientBasicAuthHandler() {}
+
+Status TestClientBasicAuthHandler::Authenticate(ClientAuthSender* outgoing,
+ ClientAuthReader* incoming) {
+ std::string pb_result;
+ RETURN_NOT_OK(BasicAuth::Serialize(basic_auth_, &pb_result));
+ RETURN_NOT_OK(outgoing->Write(pb_result));
+ RETURN_NOT_OK(incoming->Read(&token_));
+ return Status::OK();
+}
+
+Status TestClientBasicAuthHandler::GetToken(std::string* token) {
+ *token = token_;
+ return Status::OK();
+}
+
+Status ExampleTlsCertificates(std::vector<CertKeyPair>* out) {
+ std::string root;
+ RETURN_NOT_OK(GetTestResourceRoot(&root));
+
+ *out = std::vector<CertKeyPair>();
+ for (int i = 0; i < 2; i++) {
+ try {
+ std::stringstream cert_path;
+ cert_path << root << "/flight/cert" << i << ".pem";
+ std::stringstream key_path;
+ key_path << root << "/flight/cert" << i << ".key";
+
+ std::ifstream cert_file(cert_path.str());
+ if (!cert_file) {
+ return Status::IOError("Could not open certificate: " + cert_path.str());
+ }
+ std::stringstream cert;
+ cert << cert_file.rdbuf();
+
+ std::ifstream key_file(key_path.str());
+ if (!key_file) {
+ return Status::IOError("Could not open key: " + key_path.str());
+ }
+ std::stringstream key;
+ key << key_file.rdbuf();
+
+ out->push_back(CertKeyPair{cert.str(), key.str()});
+ } catch (const std::ifstream::failure& e) {
+ return Status::IOError(e.what());
+ }
+ }
+ return Status::OK();
+}
+
+Status ExampleTlsCertificateRoot(CertKeyPair* out) {
+ std::string root;
+ RETURN_NOT_OK(GetTestResourceRoot(&root));
+
+ std::stringstream path;
+ path << root << "/flight/root-ca.pem";
+
+ try {
+ std::ifstream cert_file(path.str());
+ if (!cert_file) {
+ return Status::IOError("Could not open certificate: " + path.str());
+ }
+ std::stringstream cert;
+ cert << cert_file.rdbuf();
+ out->pem_cert = cert.str();
+ out->pem_key = "";
+ return Status::OK();
+ } catch (const std::ifstream::failure& e) {
+ return Status::IOError(e.what());
+ }
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/test_util.h b/src/arrow/cpp/src/arrow/flight/test_util.h
new file mode 100644
index 000000000..c912c342a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/test_util.h
@@ -0,0 +1,242 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include "arrow/status.h"
+#include "arrow/testing/util.h"
+
+#include "arrow/flight/client_auth.h"
+#include "arrow/flight/server.h"
+#include "arrow/flight/server_auth.h"
+#include "arrow/flight/types.h"
+#include "arrow/flight/visibility.h"
+
+namespace boost {
+namespace process {
+
+class child;
+
+} // namespace process
+} // namespace boost
+
+namespace arrow {
+namespace flight {
+
+// ----------------------------------------------------------------------
+// Fixture to use for running test servers
+
+class ARROW_FLIGHT_EXPORT TestServer {
+ public:
+ explicit TestServer(const std::string& executable_name)
+ : executable_name_(executable_name), port_(::arrow::GetListenPort()) {}
+ TestServer(const std::string& executable_name, int port)
+ : executable_name_(executable_name), port_(port) {}
+ TestServer(const std::string& executable_name, const std::string& unix_sock)
+ : executable_name_(executable_name), unix_sock_(unix_sock) {}
+
+ void Start(const std::vector<std::string>& extra_args);
+ void Start() { Start({}); }
+
+ int Stop();
+
+ bool IsRunning();
+
+ int port() const;
+ const std::string& unix_sock() const;
+
+ private:
+ std::string executable_name_;
+ int port_;
+ std::string unix_sock_;
+ std::shared_ptr<::boost::process::child> server_process_;
+};
+
+/// \brief Create a simple Flight server for testing
+ARROW_FLIGHT_EXPORT
+std::unique_ptr<FlightServerBase> ExampleTestServer();
+
+// ----------------------------------------------------------------------
+// A RecordBatchReader for serving a sequence of in-memory record batches
+
+// Silence warning
+// "non dll-interface class RecordBatchReader used as base for dll-interface class"
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4275)
+#endif
+
+class ARROW_FLIGHT_EXPORT BatchIterator : public RecordBatchReader {
+ public:
+ BatchIterator(const std::shared_ptr<Schema>& schema,
+ const std::vector<std::shared_ptr<RecordBatch>>& batches)
+ : schema_(schema), batches_(batches), position_(0) {}
+
+ std::shared_ptr<Schema> schema() const override { return schema_; }
+
+ Status ReadNext(std::shared_ptr<RecordBatch>* out) override {
+ if (position_ >= batches_.size()) {
+ *out = nullptr;
+ } else {
+ *out = batches_[position_++];
+ }
+ return Status::OK();
+ }
+
+ private:
+ std::shared_ptr<Schema> schema_;
+ std::vector<std::shared_ptr<RecordBatch>> batches_;
+ size_t position_;
+};
+
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+
+// ----------------------------------------------------------------------
+// A FlightDataStream that numbers the record batches
+/// \brief A basic implementation of FlightDataStream that will provide
+/// a sequence of FlightData messages to be written to a gRPC stream
+class ARROW_FLIGHT_EXPORT NumberingStream : public FlightDataStream {
+ public:
+ explicit NumberingStream(std::unique_ptr<FlightDataStream> stream);
+
+ std::shared_ptr<Schema> schema() override;
+ Status GetSchemaPayload(FlightPayload* payload) override;
+ Status Next(FlightPayload* payload) override;
+
+ private:
+ int counter_;
+ std::shared_ptr<FlightDataStream> stream_;
+};
+
+// ----------------------------------------------------------------------
+// Example data for test-server and unit tests
+
+using BatchVector = std::vector<std::shared_ptr<RecordBatch>>;
+
+ARROW_FLIGHT_EXPORT
+std::shared_ptr<Schema> ExampleIntSchema();
+
+ARROW_FLIGHT_EXPORT
+std::shared_ptr<Schema> ExampleStringSchema();
+
+ARROW_FLIGHT_EXPORT
+std::shared_ptr<Schema> ExampleDictSchema();
+
+ARROW_FLIGHT_EXPORT
+std::shared_ptr<Schema> ExampleLargeSchema();
+
+ARROW_FLIGHT_EXPORT
+Status ExampleIntBatches(BatchVector* out);
+
+ARROW_FLIGHT_EXPORT
+Status ExampleFloatBatches(BatchVector* out);
+
+ARROW_FLIGHT_EXPORT
+Status ExampleDictBatches(BatchVector* out);
+
+ARROW_FLIGHT_EXPORT
+Status ExampleNestedBatches(BatchVector* out);
+
+ARROW_FLIGHT_EXPORT
+Status ExampleLargeBatches(BatchVector* out);
+
+ARROW_FLIGHT_EXPORT
+arrow::Result<std::shared_ptr<RecordBatch>> VeryLargeBatch();
+
+ARROW_FLIGHT_EXPORT
+std::vector<FlightInfo> ExampleFlightInfo();
+
+ARROW_FLIGHT_EXPORT
+std::vector<ActionType> ExampleActionTypes();
+
+ARROW_FLIGHT_EXPORT
+Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor,
+ const std::vector<FlightEndpoint>& endpoints, int64_t total_records,
+ int64_t total_bytes, FlightInfo::Data* out);
+
+// ----------------------------------------------------------------------
+// A pair of authentication handlers that check for a predefined password
+// and set the peer identity to a predefined username.
+
+class ARROW_FLIGHT_EXPORT TestServerAuthHandler : public ServerAuthHandler {
+ public:
+ explicit TestServerAuthHandler(const std::string& username,
+ const std::string& password);
+ ~TestServerAuthHandler() override;
+ Status Authenticate(ServerAuthSender* outgoing, ServerAuthReader* incoming) override;
+ Status IsValid(const std::string& token, std::string* peer_identity) override;
+
+ private:
+ std::string username_;
+ std::string password_;
+};
+
+class ARROW_FLIGHT_EXPORT TestServerBasicAuthHandler : public ServerAuthHandler {
+ public:
+ explicit TestServerBasicAuthHandler(const std::string& username,
+ const std::string& password);
+ ~TestServerBasicAuthHandler() override;
+ Status Authenticate(ServerAuthSender* outgoing, ServerAuthReader* incoming) override;
+ Status IsValid(const std::string& token, std::string* peer_identity) override;
+
+ private:
+ BasicAuth basic_auth_;
+};
+
+class ARROW_FLIGHT_EXPORT TestClientAuthHandler : public ClientAuthHandler {
+ public:
+ explicit TestClientAuthHandler(const std::string& username,
+ const std::string& password);
+ ~TestClientAuthHandler() override;
+ Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override;
+ Status GetToken(std::string* token) override;
+
+ private:
+ std::string username_;
+ std::string password_;
+};
+
+class ARROW_FLIGHT_EXPORT TestClientBasicAuthHandler : public ClientAuthHandler {
+ public:
+ explicit TestClientBasicAuthHandler(const std::string& username,
+ const std::string& password);
+ ~TestClientBasicAuthHandler() override;
+ Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override;
+ Status GetToken(std::string* token) override;
+
+ private:
+ BasicAuth basic_auth_;
+ std::string token_;
+};
+
+ARROW_FLIGHT_EXPORT
+Status ExampleTlsCertificates(std::vector<CertKeyPair>* out);
+
+ARROW_FLIGHT_EXPORT
+Status ExampleTlsCertificateRoot(CertKeyPair* out);
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_127.cc b/src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_127.cc
new file mode 100644
index 000000000..3815d13c5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_127.cc
@@ -0,0 +1,36 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Dummy file for checking if TlsCredentialsOptions exists in
+// the grpc_impl::experimental namespace. gRPC versions 1.27-1.31
+// put it here. This is for supporting disabling server
+// validation when using TLS.
+
+#include <grpc/grpc_security_constants.h>
+#include <grpcpp/grpcpp.h>
+#include <grpcpp/security/tls_credentials_options.h>
+
+static grpc_tls_server_verification_option check(
+ const grpc_impl::experimental::TlsCredentialsOptions* options) {
+ grpc_tls_server_verification_option server_opt = options->server_verification_option();
+ return server_opt;
+}
+
+int main(int argc, const char** argv) {
+ grpc_tls_server_verification_option opt = check(nullptr);
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_132.cc b/src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_132.cc
new file mode 100644
index 000000000..d580aba6e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_132.cc
@@ -0,0 +1,36 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Dummy file for checking if TlsCredentialsOptions exists in
+// the grpc::experimental namespace. gRPC versions 1.32 and higher
+// put it here. This is for supporting disabling server
+// validation when using TLS.
+
+#include <grpc/grpc_security_constants.h>
+#include <grpcpp/grpcpp.h>
+#include <grpcpp/security/tls_credentials_options.h>
+
+static grpc_tls_server_verification_option check(
+ const grpc::experimental::TlsCredentialsOptions* options) {
+ grpc_tls_server_verification_option server_opt = options->server_verification_option();
+ return server_opt;
+}
+
+int main(int argc, const char** argv) {
+ grpc_tls_server_verification_option opt = check(nullptr);
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_134.cc b/src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_134.cc
new file mode 100644
index 000000000..4ee2122ef
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_134.cc
@@ -0,0 +1,44 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Dummy file for checking if TlsCredentialsOptions exists in
+// the grpc::experimental namespace. gRPC starting from 1.34
+// put it here. This is for supporting disabling server
+// validation when using TLS.
+
+#include <grpc/grpc_security_constants.h>
+#include <grpcpp/grpcpp.h>
+#include <grpcpp/security/tls_credentials_options.h>
+
+// Dummy file for checking if TlsCredentialsOptions exists in
+// the grpc::experimental namespace. gRPC starting from 1.34
+// puts it here. This is for supporting disabling server
+// validation when using TLS.
+
+static void check() {
+ // In 1.34, there's no parameterless constructor; in 1.36, there's
+ // only a parameterless constructor
+ auto options =
+ std::make_shared<grpc::experimental::TlsChannelCredentialsOptions>(nullptr);
+ options->set_server_verification_option(
+ grpc_tls_server_verification_option::GRPC_TLS_SERVER_VERIFICATION);
+}
+
+int main(int argc, const char** argv) {
+ check();
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_136.cc b/src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_136.cc
new file mode 100644
index 000000000..638eec67b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/try_compile/check_tls_opts_136.cc
@@ -0,0 +1,38 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Dummy file for checking if TlsCredentialsOptions exists in
+// the grpc::experimental namespace. gRPC starting from 1.36
+// puts it here. This is for supporting disabling server
+// validation when using TLS.
+
+#include <grpc/grpc_security_constants.h>
+#include <grpcpp/grpcpp.h>
+#include <grpcpp/security/tls_credentials_options.h>
+
+static void check() {
+ // In 1.34, there's no parameterless constructor; in 1.36, there's
+ // only a parameterless constructor
+ auto options = std::make_shared<grpc::experimental::TlsChannelCredentialsOptions>();
+ options->set_server_verification_option(
+ grpc_tls_server_verification_option::GRPC_TLS_SERVER_VERIFICATION);
+}
+
+int main(int argc, const char** argv) {
+ check();
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/flight/types.cc b/src/arrow/cpp/src/arrow/flight/types.cc
new file mode 100644
index 000000000..313be1229
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/types.cc
@@ -0,0 +1,378 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/types.h"
+
+#include <memory>
+#include <sstream>
+#include <utility>
+
+#include "arrow/buffer.h"
+#include "arrow/flight/serialization_internal.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/dictionary.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+namespace flight {
+
+const char* kSchemeGrpc = "grpc";
+const char* kSchemeGrpcTcp = "grpc+tcp";
+const char* kSchemeGrpcUnix = "grpc+unix";
+const char* kSchemeGrpcTls = "grpc+tls";
+
+const char* kErrorDetailTypeId = "flight::FlightStatusDetail";
+
+const char* FlightStatusDetail::type_id() const { return kErrorDetailTypeId; }
+
+std::string FlightStatusDetail::ToString() const { return CodeAsString(); }
+
+FlightStatusCode FlightStatusDetail::code() const { return code_; }
+
+std::string FlightStatusDetail::extra_info() const { return extra_info_; }
+
+void FlightStatusDetail::set_extra_info(std::string extra_info) {
+ extra_info_ = std::move(extra_info);
+}
+
+std::string FlightStatusDetail::CodeAsString() const {
+ switch (code()) {
+ case FlightStatusCode::Internal:
+ return "Internal";
+ case FlightStatusCode::TimedOut:
+ return "TimedOut";
+ case FlightStatusCode::Cancelled:
+ return "Cancelled";
+ case FlightStatusCode::Unauthenticated:
+ return "Unauthenticated";
+ case FlightStatusCode::Unauthorized:
+ return "Unauthorized";
+ case FlightStatusCode::Unavailable:
+ return "Unavailable";
+ default:
+ return "Unknown";
+ }
+}
+
+std::shared_ptr<FlightStatusDetail> FlightStatusDetail::UnwrapStatus(
+ const arrow::Status& status) {
+ if (!status.detail() || status.detail()->type_id() != kErrorDetailTypeId) {
+ return nullptr;
+ }
+ return std::dynamic_pointer_cast<FlightStatusDetail>(status.detail());
+}
+
+Status MakeFlightError(FlightStatusCode code, std::string message,
+ std::string extra_info) {
+ StatusCode arrow_code = arrow::StatusCode::IOError;
+ return arrow::Status(arrow_code, std::move(message),
+ std::make_shared<FlightStatusDetail>(code, std::move(extra_info)));
+}
+
+bool FlightDescriptor::Equals(const FlightDescriptor& other) const {
+ if (type != other.type) {
+ return false;
+ }
+ switch (type) {
+ case PATH:
+ return path == other.path;
+ case CMD:
+ return cmd == other.cmd;
+ default:
+ return false;
+ }
+}
+
+std::string FlightDescriptor::ToString() const {
+ std::stringstream ss;
+ ss << "FlightDescriptor<";
+ switch (type) {
+ case PATH: {
+ bool first = true;
+ ss << "path = '";
+ for (const auto& p : path) {
+ if (!first) {
+ ss << "/";
+ }
+ first = false;
+ ss << p;
+ }
+ ss << "'";
+ break;
+ }
+ case CMD:
+ ss << "cmd = '" << cmd << "'";
+ break;
+ default:
+ break;
+ }
+ ss << ">";
+ return ss.str();
+}
+
+Status FlightPayload::Validate() const {
+ static constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();
+ if (descriptor && descriptor->size() > kInt32Max) {
+ return Status::CapacityError("Descriptor size overflow (>= 2**31)");
+ }
+ if (app_metadata && app_metadata->size() > kInt32Max) {
+ return Status::CapacityError("app_metadata size overflow (>= 2**31)");
+ }
+ if (ipc_message.body_length > kInt32Max) {
+ return Status::Invalid("Cannot send record batches exceeding 2GiB yet");
+ }
+ return Status::OK();
+}
+
+Status SchemaResult::GetSchema(ipc::DictionaryMemo* dictionary_memo,
+ std::shared_ptr<Schema>* out) const {
+ io::BufferReader schema_reader(raw_schema_);
+ return ipc::ReadSchema(&schema_reader, dictionary_memo).Value(out);
+}
+
+Status FlightDescriptor::SerializeToString(std::string* out) const {
+ pb::FlightDescriptor pb_descriptor;
+ RETURN_NOT_OK(internal::ToProto(*this, &pb_descriptor));
+
+ if (!pb_descriptor.SerializeToString(out)) {
+ return Status::IOError("Serialized descriptor exceeded 2 GiB limit");
+ }
+ return Status::OK();
+}
+
+Status FlightDescriptor::Deserialize(const std::string& serialized,
+ FlightDescriptor* out) {
+ pb::FlightDescriptor pb_descriptor;
+ if (!pb_descriptor.ParseFromString(serialized)) {
+ return Status::Invalid("Not a valid descriptor");
+ }
+ return internal::FromProto(pb_descriptor, out);
+}
+
+bool Ticket::Equals(const Ticket& other) const { return ticket == other.ticket; }
+
+Status Ticket::SerializeToString(std::string* out) const {
+ pb::Ticket pb_ticket;
+ internal::ToProto(*this, &pb_ticket);
+
+ if (!pb_ticket.SerializeToString(out)) {
+ return Status::IOError("Serialized ticket exceeded 2 GiB limit");
+ }
+ return Status::OK();
+}
+
+Status Ticket::Deserialize(const std::string& serialized, Ticket* out) {
+ pb::Ticket pb_ticket;
+ if (!pb_ticket.ParseFromString(serialized)) {
+ return Status::Invalid("Not a valid ticket");
+ }
+ return internal::FromProto(pb_ticket, out);
+}
+
+arrow::Result<FlightInfo> FlightInfo::Make(const Schema& schema,
+ const FlightDescriptor& descriptor,
+ const std::vector<FlightEndpoint>& endpoints,
+ int64_t total_records, int64_t total_bytes) {
+ FlightInfo::Data data;
+ data.descriptor = descriptor;
+ data.endpoints = endpoints;
+ data.total_records = total_records;
+ data.total_bytes = total_bytes;
+ RETURN_NOT_OK(internal::SchemaToString(schema, &data.schema));
+ return FlightInfo(data);
+}
+
+Status FlightInfo::GetSchema(ipc::DictionaryMemo* dictionary_memo,
+ std::shared_ptr<Schema>* out) const {
+ if (reconstructed_schema_) {
+ *out = schema_;
+ return Status::OK();
+ }
+ io::BufferReader schema_reader(data_.schema);
+ RETURN_NOT_OK(ipc::ReadSchema(&schema_reader, dictionary_memo).Value(&schema_));
+ reconstructed_schema_ = true;
+ *out = schema_;
+ return Status::OK();
+}
+
+Status FlightInfo::SerializeToString(std::string* out) const {
+ pb::FlightInfo pb_info;
+ RETURN_NOT_OK(internal::ToProto(*this, &pb_info));
+
+ if (!pb_info.SerializeToString(out)) {
+ return Status::IOError("Serialized FlightInfo exceeded 2 GiB limit");
+ }
+ return Status::OK();
+}
+
+Status FlightInfo::Deserialize(const std::string& serialized,
+ std::unique_ptr<FlightInfo>* out) {
+ pb::FlightInfo pb_info;
+ if (!pb_info.ParseFromString(serialized)) {
+ return Status::Invalid("Not a valid FlightInfo");
+ }
+ FlightInfo::Data data;
+ RETURN_NOT_OK(internal::FromProto(pb_info, &data));
+ out->reset(new FlightInfo(data));
+ return Status::OK();
+}
+
+Location::Location() { uri_ = std::make_shared<arrow::internal::Uri>(); }
+
+Status Location::Parse(const std::string& uri_string, Location* location) {
+ return location->uri_->Parse(uri_string);
+}
+
+Status Location::ForGrpcTcp(const std::string& host, const int port, Location* location) {
+ std::stringstream uri_string;
+ uri_string << "grpc+tcp://" << host << ':' << port;
+ return Location::Parse(uri_string.str(), location);
+}
+
+Status Location::ForGrpcTls(const std::string& host, const int port, Location* location) {
+ std::stringstream uri_string;
+ uri_string << "grpc+tls://" << host << ':' << port;
+ return Location::Parse(uri_string.str(), location);
+}
+
+Status Location::ForGrpcUnix(const std::string& path, Location* location) {
+ std::stringstream uri_string;
+ uri_string << "grpc+unix://" << path;
+ return Location::Parse(uri_string.str(), location);
+}
+
+std::string Location::ToString() const { return uri_->ToString(); }
+std::string Location::scheme() const {
+ std::string scheme = uri_->scheme();
+ if (scheme.empty()) {
+ // Default to grpc+tcp
+ return "grpc+tcp";
+ }
+ return scheme;
+}
+
+bool Location::Equals(const Location& other) const {
+ return ToString() == other.ToString();
+}
+
+bool FlightEndpoint::Equals(const FlightEndpoint& other) const {
+ return ticket == other.ticket && locations == other.locations;
+}
+
+Status MetadataRecordBatchReader::ReadAll(
+ std::vector<std::shared_ptr<RecordBatch>>* batches) {
+ FlightStreamChunk chunk;
+
+ while (true) {
+ RETURN_NOT_OK(Next(&chunk));
+ if (!chunk.data) break;
+ batches->emplace_back(std::move(chunk.data));
+ }
+ return Status::OK();
+}
+
+Status MetadataRecordBatchReader::ReadAll(std::shared_ptr<Table>* table) {
+ std::vector<std::shared_ptr<RecordBatch>> batches;
+ RETURN_NOT_OK(ReadAll(&batches));
+ ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema());
+ return Table::FromRecordBatches(schema, std::move(batches)).Value(table);
+}
+
+Status MetadataRecordBatchWriter::Begin(const std::shared_ptr<Schema>& schema) {
+ return Begin(schema, ipc::IpcWriteOptions::Defaults());
+}
+
+namespace {
+class MetadataRecordBatchReaderAdapter : public RecordBatchReader {
+ public:
+ explicit MetadataRecordBatchReaderAdapter(
+ std::shared_ptr<Schema> schema, std::shared_ptr<MetadataRecordBatchReader> delegate)
+ : schema_(std::move(schema)), delegate_(std::move(delegate)) {}
+ std::shared_ptr<Schema> schema() const override { return schema_; }
+ Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
+ FlightStreamChunk next;
+ while (true) {
+ RETURN_NOT_OK(delegate_->Next(&next));
+ if (!next.data && !next.app_metadata) {
+ // EOS
+ *batch = nullptr;
+ return Status::OK();
+ } else if (next.data) {
+ *batch = std::move(next.data);
+ return Status::OK();
+ }
+ // Got metadata, but no data (which is valid) - read the next message
+ }
+ }
+
+ private:
+ std::shared_ptr<Schema> schema_;
+ std::shared_ptr<MetadataRecordBatchReader> delegate_;
+};
+}; // namespace
+
+arrow::Result<std::shared_ptr<RecordBatchReader>> MakeRecordBatchReader(
+ std::shared_ptr<MetadataRecordBatchReader> reader) {
+ ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
+ return std::make_shared<MetadataRecordBatchReaderAdapter>(std::move(schema),
+ std::move(reader));
+}
+
+SimpleFlightListing::SimpleFlightListing(const std::vector<FlightInfo>& flights)
+ : position_(0), flights_(flights) {}
+
+SimpleFlightListing::SimpleFlightListing(std::vector<FlightInfo>&& flights)
+ : position_(0), flights_(std::move(flights)) {}
+
+Status SimpleFlightListing::Next(std::unique_ptr<FlightInfo>* info) {
+ if (position_ >= static_cast<int>(flights_.size())) {
+ *info = nullptr;
+ return Status::OK();
+ }
+ *info = std::unique_ptr<FlightInfo>(new FlightInfo(std::move(flights_[position_++])));
+ return Status::OK();
+}
+
+SimpleResultStream::SimpleResultStream(std::vector<Result>&& results)
+ : results_(std::move(results)), position_(0) {}
+
+Status SimpleResultStream::Next(std::unique_ptr<Result>* result) {
+ if (position_ >= results_.size()) {
+ *result = nullptr;
+ return Status::OK();
+ }
+ *result = std::unique_ptr<Result>(new Result(std::move(results_[position_++])));
+ return Status::OK();
+}
+
+Status BasicAuth::Deserialize(const std::string& serialized, BasicAuth* out) {
+ pb::BasicAuth pb_result;
+ pb_result.ParseFromString(serialized);
+ return internal::FromProto(pb_result, out);
+}
+
+Status BasicAuth::Serialize(const BasicAuth& basic_auth, std::string* out) {
+ pb::BasicAuth pb_result;
+ RETURN_NOT_OK(internal::ToProto(basic_auth, &pb_result));
+ *out = pb_result.SerializeAsString();
+ return Status::OK();
+}
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/types.h b/src/arrow/cpp/src/arrow/flight/types.h
new file mode 100644
index 000000000..1e3051d5c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/types.h
@@ -0,0 +1,529 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Data structure for Flight RPC. API should be considered experimental for now
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/flight/visibility.h"
+#include "arrow/ipc/options.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/result.h"
+
+namespace arrow {
+
+class Buffer;
+class RecordBatch;
+class Schema;
+class Status;
+class Table;
+
+namespace ipc {
+
+class DictionaryMemo;
+
+} // namespace ipc
+
+namespace internal {
+
+class Uri;
+
+} // namespace internal
+
+namespace flight {
+
+/// \brief A Flight-specific status code.
+enum class FlightStatusCode : int8_t {
+ /// An implementation error has occurred.
+ Internal,
+ /// A request timed out.
+ TimedOut,
+ /// A request was cancelled.
+ Cancelled,
+ /// We are not authenticated to the remote service.
+ Unauthenticated,
+ /// We do not have permission to make this request.
+ Unauthorized,
+ /// The remote service cannot handle this request at the moment.
+ Unavailable,
+ /// A request failed for some other reason
+ Failed
+};
+
+// Silence warning
+// "non dll-interface class RecordBatchReader used as base for dll-interface class"
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4275)
+#endif
+
+/// \brief Flight-specific error information in a Status.
+class ARROW_FLIGHT_EXPORT FlightStatusDetail : public arrow::StatusDetail {
+ public:
+ explicit FlightStatusDetail(FlightStatusCode code) : code_{code} {}
+ explicit FlightStatusDetail(FlightStatusCode code, std::string extra_info)
+ : code_{code}, extra_info_(std::move(extra_info)) {}
+ const char* type_id() const override;
+ std::string ToString() const override;
+
+ /// \brief Get the Flight status code.
+ FlightStatusCode code() const;
+ /// \brief Get the extra error info
+ std::string extra_info() const;
+ /// \brief Get the human-readable name of the status code.
+ std::string CodeAsString() const;
+ /// \brief Set the extra error info
+ void set_extra_info(std::string extra_info);
+
+ /// \brief Try to extract a \a FlightStatusDetail from any Arrow
+ /// status.
+ ///
+ /// \return a \a FlightStatusDetail if it could be unwrapped, \a
+ /// nullptr otherwise
+ static std::shared_ptr<FlightStatusDetail> UnwrapStatus(const arrow::Status& status);
+
+ private:
+ FlightStatusCode code_;
+ std::string extra_info_;
+};
+
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+
+/// \brief Make an appropriate Arrow status for the given
+/// Flight-specific status.
+///
+/// \param code The Flight status code.
+/// \param message The message for the error.
+/// \param extra_info Optional extra binary info for the error (eg protobuf)
+ARROW_FLIGHT_EXPORT
+Status MakeFlightError(FlightStatusCode code, std::string message,
+ std::string extra_info = {});
+
+/// \brief A TLS certificate plus key.
+struct ARROW_FLIGHT_EXPORT CertKeyPair {
+ /// \brief The certificate in PEM format.
+ std::string pem_cert;
+
+ /// \brief The key in PEM format.
+ std::string pem_key;
+};
+
+/// \brief A type of action that can be performed with the DoAction RPC.
+struct ARROW_FLIGHT_EXPORT ActionType {
+ /// \brief The name of the action.
+ std::string type;
+
+ /// \brief A human-readable description of the action.
+ std::string description;
+};
+
+/// \brief Opaque selection criteria for ListFlights RPC
+struct ARROW_FLIGHT_EXPORT Criteria {
+ /// Opaque criteria expression, dependent on server implementation
+ std::string expression;
+};
+
+/// \brief An action to perform with the DoAction RPC
+struct ARROW_FLIGHT_EXPORT Action {
+ /// The action type
+ std::string type;
+
+ /// The action content as a Buffer
+ std::shared_ptr<Buffer> body;
+};
+
+/// \brief Opaque result returned after executing an action
+struct ARROW_FLIGHT_EXPORT Result {
+ std::shared_ptr<Buffer> body;
+};
+
+/// \brief message for simple auth
+struct ARROW_FLIGHT_EXPORT BasicAuth {
+ std::string username;
+ std::string password;
+
+ static Status Deserialize(const std::string& serialized, BasicAuth* out);
+
+ static Status Serialize(const BasicAuth& basic_auth, std::string* out);
+};
+
+/// \brief A request to retrieve or generate a dataset
+struct ARROW_FLIGHT_EXPORT FlightDescriptor {
+ enum DescriptorType {
+ UNKNOWN = 0, /// Unused
+ PATH = 1, /// Named path identifying a dataset
+ CMD = 2 /// Opaque command to generate a dataset
+ };
+
+ /// The descriptor type
+ DescriptorType type;
+
+ /// Opaque value used to express a command. Should only be defined when type
+ /// is CMD
+ std::string cmd;
+
+ /// List of strings identifying a particular dataset. Should only be defined
+ /// when type is PATH
+ std::vector<std::string> path;
+
+ bool Equals(const FlightDescriptor& other) const;
+
+ /// \brief Get a human-readable form of this descriptor.
+ std::string ToString() const;
+
+ /// \brief Get the wire-format representation of this type.
+ ///
+ /// Useful when interoperating with non-Flight systems (e.g. REST
+ /// services) that may want to return Flight types.
+ Status SerializeToString(std::string* out) const;
+
+ /// \brief Parse the wire-format representation of this type.
+ ///
+ /// Useful when interoperating with non-Flight systems (e.g. REST
+ /// services) that may want to return Flight types.
+ static Status Deserialize(const std::string& serialized, FlightDescriptor* out);
+
+ // Convenience factory functions
+
+ static FlightDescriptor Command(const std::string& c) {
+ return FlightDescriptor{CMD, c, {}};
+ }
+
+ static FlightDescriptor Path(const std::vector<std::string>& p) {
+ return FlightDescriptor{PATH, "", p};
+ }
+
+ friend bool operator==(const FlightDescriptor& left, const FlightDescriptor& right) {
+ return left.Equals(right);
+ }
+ friend bool operator!=(const FlightDescriptor& left, const FlightDescriptor& right) {
+ return !(left == right);
+ }
+};
+
+/// \brief Data structure providing an opaque identifier or credential to use
+/// when requesting a data stream with the DoGet RPC
+struct ARROW_FLIGHT_EXPORT Ticket {
+ std::string ticket;
+
+ bool Equals(const Ticket& other) const;
+
+ friend bool operator==(const Ticket& left, const Ticket& right) {
+ return left.Equals(right);
+ }
+ friend bool operator!=(const Ticket& left, const Ticket& right) {
+ return !(left == right);
+ }
+
+ /// \brief Get the wire-format representation of this type.
+ ///
+ /// Useful when interoperating with non-Flight systems (e.g. REST
+ /// services) that may want to return Flight types.
+ Status SerializeToString(std::string* out) const;
+
+ /// \brief Parse the wire-format representation of this type.
+ ///
+ /// Useful when interoperating with non-Flight systems (e.g. REST
+ /// services) that may want to return Flight types.
+ static Status Deserialize(const std::string& serialized, Ticket* out);
+};
+
+class FlightClient;
+class FlightServerBase;
+
+ARROW_FLIGHT_EXPORT
+extern const char* kSchemeGrpc;
+ARROW_FLIGHT_EXPORT
+extern const char* kSchemeGrpcTcp;
+ARROW_FLIGHT_EXPORT
+extern const char* kSchemeGrpcUnix;
+ARROW_FLIGHT_EXPORT
+extern const char* kSchemeGrpcTls;
+
+/// \brief A host location (a URI)
+struct ARROW_FLIGHT_EXPORT Location {
+ public:
+ /// \brief Initialize a blank location.
+ Location();
+
+ /// \brief Initialize a location by parsing a URI string
+ static Status Parse(const std::string& uri_string, Location* location);
+
+ /// \brief Initialize a location for a non-TLS, gRPC-based Flight
+ /// service from a host and port
+ /// \param[in] host The hostname to connect to
+ /// \param[in] port The port
+ /// \param[out] location The resulting location
+ static Status ForGrpcTcp(const std::string& host, const int port, Location* location);
+
+ /// \brief Initialize a location for a TLS-enabled, gRPC-based Flight
+ /// service from a host and port
+ /// \param[in] host The hostname to connect to
+ /// \param[in] port The port
+ /// \param[out] location The resulting location
+ static Status ForGrpcTls(const std::string& host, const int port, Location* location);
+
+ /// \brief Initialize a location for a domain socket-based Flight
+ /// service
+ /// \param[in] path The path to the domain socket
+ /// \param[out] location The resulting location
+ static Status ForGrpcUnix(const std::string& path, Location* location);
+
+ /// \brief Get a representation of this URI as a string.
+ std::string ToString() const;
+
+ /// \brief Get the scheme of this URI.
+ std::string scheme() const;
+
+ bool Equals(const Location& other) const;
+
+ friend bool operator==(const Location& left, const Location& right) {
+ return left.Equals(right);
+ }
+ friend bool operator!=(const Location& left, const Location& right) {
+ return !(left == right);
+ }
+
+ private:
+ friend class FlightClient;
+ friend class FlightServerBase;
+ std::shared_ptr<arrow::internal::Uri> uri_;
+};
+
+/// \brief A flight ticket and list of locations where the ticket can be
+/// redeemed
+struct ARROW_FLIGHT_EXPORT FlightEndpoint {
+ /// Opaque ticket identify; use with DoGet RPC
+ Ticket ticket;
+
+ /// List of locations where ticket can be redeemed. If the list is empty, the
+ /// ticket can only be redeemed on the current service where the ticket was
+ /// generated
+ std::vector<Location> locations;
+
+ bool Equals(const FlightEndpoint& other) const;
+
+ friend bool operator==(const FlightEndpoint& left, const FlightEndpoint& right) {
+ return left.Equals(right);
+ }
+ friend bool operator!=(const FlightEndpoint& left, const FlightEndpoint& right) {
+ return !(left == right);
+ }
+};
+
+/// \brief Staging data structure for messages about to be put on the wire
+///
+/// This structure corresponds to FlightData in the protocol.
+struct ARROW_FLIGHT_EXPORT FlightPayload {
+ std::shared_ptr<Buffer> descriptor;
+ std::shared_ptr<Buffer> app_metadata;
+ ipc::IpcPayload ipc_message;
+
+ /// \brief Check that the payload can be written to the wire.
+ Status Validate() const;
+};
+
+/// \brief Schema result returned after a schema request RPC
+struct ARROW_FLIGHT_EXPORT SchemaResult {
+ public:
+ explicit SchemaResult(std::string schema) : raw_schema_(std::move(schema)) {}
+
+ /// \brief return schema
+ /// \param[in,out] dictionary_memo for dictionary bookkeeping, will
+ /// be modified
+ /// \param[out] out the reconstructed Schema
+ Status GetSchema(ipc::DictionaryMemo* dictionary_memo,
+ std::shared_ptr<Schema>* out) const;
+
+ const std::string& serialized_schema() const { return raw_schema_; }
+
+ private:
+ std::string raw_schema_;
+};
+
+/// \brief The access coordinates for retireval of a dataset, returned by
+/// GetFlightInfo
+class ARROW_FLIGHT_EXPORT FlightInfo {
+ public:
+ struct Data {
+ std::string schema;
+ FlightDescriptor descriptor;
+ std::vector<FlightEndpoint> endpoints;
+ int64_t total_records;
+ int64_t total_bytes;
+ };
+
+ explicit FlightInfo(const Data& data) : data_(data), reconstructed_schema_(false) {}
+ explicit FlightInfo(Data&& data)
+ : data_(std::move(data)), reconstructed_schema_(false) {}
+
+ /// \brief Factory method to construct a FlightInfo.
+ static arrow::Result<FlightInfo> Make(const Schema& schema,
+ const FlightDescriptor& descriptor,
+ const std::vector<FlightEndpoint>& endpoints,
+ int64_t total_records, int64_t total_bytes);
+
+ /// \brief Deserialize the Arrow schema of the dataset, to be passed
+ /// to each call to DoGet. Populate any dictionary encoded fields
+ /// into a DictionaryMemo for bookkeeping
+ /// \param[in,out] dictionary_memo for dictionary bookkeeping, will
+ /// be modified
+ /// \param[out] out the reconstructed Schema
+ Status GetSchema(ipc::DictionaryMemo* dictionary_memo,
+ std::shared_ptr<Schema>* out) const;
+
+ const std::string& serialized_schema() const { return data_.schema; }
+
+ /// The descriptor associated with this flight, may not be set
+ const FlightDescriptor& descriptor() const { return data_.descriptor; }
+
+ /// A list of endpoints associated with the flight (dataset). To consume the
+ /// whole flight, all endpoints must be consumed
+ const std::vector<FlightEndpoint>& endpoints() const { return data_.endpoints; }
+
+ /// The total number of records (rows) in the dataset. If unknown, set to -1
+ int64_t total_records() const { return data_.total_records; }
+
+ /// The total number of bytes in the dataset. If unknown, set to -1
+ int64_t total_bytes() const { return data_.total_bytes; }
+
+ /// \brief Get the wire-format representation of this type.
+ ///
+ /// Useful when interoperating with non-Flight systems (e.g. REST
+ /// services) that may want to return Flight types.
+ Status SerializeToString(std::string* out) const;
+
+ /// \brief Parse the wire-format representation of this type.
+ ///
+ /// Useful when interoperating with non-Flight systems (e.g. REST
+ /// services) that may want to return Flight types.
+ static Status Deserialize(const std::string& serialized,
+ std::unique_ptr<FlightInfo>* out);
+
+ private:
+ Data data_;
+ mutable std::shared_ptr<Schema> schema_;
+ mutable bool reconstructed_schema_;
+};
+
+/// \brief An iterator to FlightInfo instances returned by ListFlights.
+class ARROW_FLIGHT_EXPORT FlightListing {
+ public:
+ virtual ~FlightListing() = default;
+
+ /// \brief Retrieve the next FlightInfo from the iterator.
+ /// \param[out] info A single FlightInfo. Set to \a nullptr if there
+ /// are none left.
+ /// \return Status
+ virtual Status Next(std::unique_ptr<FlightInfo>* info) = 0;
+};
+
+/// \brief An iterator to Result instances returned by DoAction.
+class ARROW_FLIGHT_EXPORT ResultStream {
+ public:
+ virtual ~ResultStream() = default;
+
+ /// \brief Retrieve the next Result from the iterator.
+ /// \param[out] info A single result. Set to \a nullptr if there
+ /// are none left.
+ /// \return Status
+ virtual Status Next(std::unique_ptr<Result>* info) = 0;
+};
+
+/// \brief A holder for a RecordBatch with associated Flight metadata.
+struct ARROW_FLIGHT_EXPORT FlightStreamChunk {
+ public:
+ std::shared_ptr<RecordBatch> data;
+ std::shared_ptr<Buffer> app_metadata;
+};
+
+/// \brief An interface to read Flight data with metadata.
+class ARROW_FLIGHT_EXPORT MetadataRecordBatchReader {
+ public:
+ virtual ~MetadataRecordBatchReader() = default;
+
+ /// \brief Get the schema for this stream.
+ virtual arrow::Result<std::shared_ptr<Schema>> GetSchema() = 0;
+ /// \brief Get the next message from Flight. If the stream is
+ /// finished, then the members of \a FlightStreamChunk will be
+ /// nullptr.
+ virtual Status Next(FlightStreamChunk* next) = 0;
+ /// \brief Consume entire stream as a vector of record batches
+ virtual Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches);
+ /// \brief Consume entire stream as a Table
+ virtual Status ReadAll(std::shared_ptr<Table>* table);
+};
+
+/// \brief Convert a MetadataRecordBatchReader to a regular RecordBatchReader.
+ARROW_FLIGHT_EXPORT
+arrow::Result<std::shared_ptr<RecordBatchReader>> MakeRecordBatchReader(
+ std::shared_ptr<MetadataRecordBatchReader> reader);
+
+/// \brief An interface to write IPC payloads with metadata.
+class ARROW_FLIGHT_EXPORT MetadataRecordBatchWriter : public ipc::RecordBatchWriter {
+ public:
+ virtual ~MetadataRecordBatchWriter() = default;
+ /// \brief Begin writing data with the given schema. Only used with \a DoExchange.
+ virtual Status Begin(const std::shared_ptr<Schema>& schema,
+ const ipc::IpcWriteOptions& options) = 0;
+ virtual Status Begin(const std::shared_ptr<Schema>& schema);
+ virtual Status WriteMetadata(std::shared_ptr<Buffer> app_metadata) = 0;
+ virtual Status WriteWithMetadata(const RecordBatch& batch,
+ std::shared_ptr<Buffer> app_metadata) = 0;
+};
+
+/// \brief A FlightListing implementation based on a vector of
+/// FlightInfo objects.
+///
+/// This can be iterated once, then it is consumed.
+class ARROW_FLIGHT_EXPORT SimpleFlightListing : public FlightListing {
+ public:
+ explicit SimpleFlightListing(const std::vector<FlightInfo>& flights);
+ explicit SimpleFlightListing(std::vector<FlightInfo>&& flights);
+
+ Status Next(std::unique_ptr<FlightInfo>* info) override;
+
+ private:
+ int position_;
+ std::vector<FlightInfo> flights_;
+};
+
+/// \brief A ResultStream implementation based on a vector of
+/// Result objects.
+///
+/// This can be iterated once, then it is consumed.
+class ARROW_FLIGHT_EXPORT SimpleResultStream : public ResultStream {
+ public:
+ explicit SimpleResultStream(std::vector<Result>&& results);
+ Status Next(std::unique_ptr<Result>* result) override;
+
+ private:
+ std::vector<Result> results_;
+ size_t position_;
+};
+
+} // namespace flight
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/flight/visibility.h b/src/arrow/cpp/src/arrow/flight/visibility.h
new file mode 100644
index 000000000..bdee8b751
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/flight/visibility.h
@@ -0,0 +1,48 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#if defined(_WIN32) || defined(__CYGWIN__)
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4251)
+#else
+#pragma GCC diagnostic ignored "-Wattributes"
+#endif
+
+#ifdef ARROW_FLIGHT_STATIC
+#define ARROW_FLIGHT_EXPORT
+#elif defined(ARROW_FLIGHT_EXPORTING)
+#define ARROW_FLIGHT_EXPORT __declspec(dllexport)
+#else
+#define ARROW_FLIGHT_EXPORT __declspec(dllimport)
+#endif
+
+#define ARROW_FLIGHT_NO_EXPORT
+#else // Not Windows
+#ifndef ARROW_FLIGHT_EXPORT
+#define ARROW_FLIGHT_EXPORT __attribute__((visibility("default")))
+#endif
+#ifndef ARROW_FLIGHT_NO_EXPORT
+#define ARROW_FLIGHT_NO_EXPORT __attribute__((visibility("hidden")))
+#endif
+#endif // Non-Windows
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
diff --git a/src/arrow/cpp/src/arrow/gpu/.gitignore b/src/arrow/cpp/src/arrow/gpu/.gitignore
new file mode 100644
index 000000000..0ef3f98c5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/.gitignore
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+cuda_version.h
diff --git a/src/arrow/cpp/src/arrow/gpu/ArrowCUDAConfig.cmake.in b/src/arrow/cpp/src/arrow/gpu/ArrowCUDAConfig.cmake.in
new file mode 100644
index 000000000..67bb58093
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/ArrowCUDAConfig.cmake.in
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This config sets the following variables in your project::
+#
+# ArrowCUDA_FOUND - true if Arrow CUDA found on the system
+#
+# This config sets the following targets in your project::
+#
+# arrow_cuda_shared - for linked as shared library if shared library is built
+# arrow_cuda_static - for linked as static library if static library is built
+
+@PACKAGE_INIT@
+
+include(CMakeFindDependencyMacro)
+find_dependency(Arrow)
+
+# Load targets only once. If we load targets multiple times, CMake reports
+# already existent target error.
+if(NOT (TARGET arrow_cuda_shared OR TARGET arrow_cuda_static))
+ include("${CMAKE_CURRENT_LIST_DIR}/ArrowCUDATargets.cmake")
+endif()
diff --git a/src/arrow/cpp/src/arrow/gpu/CMakeLists.txt b/src/arrow/cpp/src/arrow/gpu/CMakeLists.txt
new file mode 100644
index 000000000..a1c182a58
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/CMakeLists.txt
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#
+# arrow_cuda
+#
+
+add_custom_target(arrow_cuda-all)
+add_custom_target(arrow_cuda)
+add_custom_target(arrow_cuda-benchmarks)
+add_custom_target(arrow_cuda-tests)
+add_dependencies(arrow_cuda-all arrow_cuda arrow_cuda-tests arrow_cuda-benchmarks)
+
+if(DEFINED ENV{CUDA_HOME})
+ set(CUDA_TOOLKIT_ROOT_DIR "$ENV{CUDA_HOME}")
+endif()
+
+find_package(CUDA REQUIRED)
+include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
+
+set(ARROW_CUDA_SRCS cuda_arrow_ipc.cc cuda_context.cc cuda_internal.cc cuda_memory.cc)
+
+set(ARROW_CUDA_SHARED_LINK_LIBS ${CUDA_CUDA_LIBRARY})
+
+add_arrow_lib(arrow_cuda
+ CMAKE_PACKAGE_NAME
+ ArrowCUDA
+ PKG_CONFIG_NAME
+ arrow-cuda
+ SOURCES
+ ${ARROW_CUDA_SRCS}
+ OUTPUTS
+ ARROW_CUDA_LIBRARIES
+ SHARED_LINK_FLAGS
+ ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt
+ SHARED_LINK_LIBS
+ arrow_shared
+ ${ARROW_CUDA_SHARED_LINK_LIBS}
+ # Static arrow_cuda must also link against CUDA shared libs
+ STATIC_LINK_LIBS
+ ${ARROW_CUDA_SHARED_LINK_LIBS})
+
+add_dependencies(arrow_cuda ${ARROW_CUDA_LIBRARIES})
+
+foreach(LIB_TARGET ${ARROW_CUDA_LIBRARIES})
+ target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_EXPORTING)
+endforeach()
+
+# CUDA build version
+configure_file(cuda_version.h.in "${CMAKE_CURRENT_BINARY_DIR}/cuda_version.h" @ONLY)
+
+install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cuda_version.h"
+ DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/arrow/gpu")
+
+arrow_install_all_headers("arrow/gpu")
+
+if(ARROW_BUILD_SHARED)
+ set(ARROW_CUDA_LIBRARY arrow_cuda_shared)
+else()
+ set(ARROW_CUDA_LIBRARY arrow_cuda_static)
+endif()
+
+set(ARROW_CUDA_TEST_LINK_LIBS ${ARROW_CUDA_LIBRARY} ${ARROW_TEST_LINK_LIBS})
+
+if(ARROW_BUILD_TESTS)
+ add_arrow_test(cuda_test STATIC_LINK_LIBS ${ARROW_CUDA_TEST_LINK_LIBS} NO_VALGRIND)
+endif()
+
+if(ARROW_BUILD_BENCHMARKS)
+ cuda_add_executable(arrow-cuda-benchmark cuda_benchmark.cc)
+ target_link_libraries(arrow-cuda-benchmark ${ARROW_CUDA_LIBRARY} GTest::gtest
+ ${ARROW_BENCHMARK_LINK_LIBS})
+ add_dependencies(arrow_cuda-benchmarks arrow-cuda-benchmark)
+endif()
diff --git a/src/arrow/cpp/src/arrow/gpu/arrow-cuda.pc.in b/src/arrow/cpp/src/arrow/gpu/arrow-cuda.pc.in
new file mode 100644
index 000000000..858096f89
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/arrow-cuda.pc.in
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+Name: Apache Arrow CUDA
+Description: CUDA integration library for Apache Arrow
+Version: @ARROW_VERSION@
+Requires: arrow
+Libs: -L${libdir} -larrow_cuda
+Cflags: -I${includedir}
diff --git a/src/arrow/cpp/src/arrow/gpu/cuda_api.h b/src/arrow/cpp/src/arrow/gpu/cuda_api.h
new file mode 100644
index 000000000..33fdaf6b1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/cuda_api.h
@@ -0,0 +1,23 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/gpu/cuda_arrow_ipc.h"
+#include "arrow/gpu/cuda_context.h"
+#include "arrow/gpu/cuda_memory.h"
+#include "arrow/gpu/cuda_version.h"
diff --git a/src/arrow/cpp/src/arrow/gpu/cuda_arrow_ipc.cc b/src/arrow/cpp/src/arrow/gpu/cuda_arrow_ipc.cc
new file mode 100644
index 000000000..a928df013
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/cuda_arrow_ipc.cc
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/gpu/cuda_arrow_ipc.h"
+
+#include <cstdint>
+#include <memory>
+#include <sstream>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/ipc/dictionary.h"
+#include "arrow/ipc/message.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/util/visibility.h"
+
+#include "generated/Message_generated.h"
+
+#include "arrow/gpu/cuda_context.h"
+#include "arrow/gpu/cuda_memory.h"
+
+namespace arrow {
+
+namespace flatbuf = org::apache::arrow::flatbuf;
+
+namespace cuda {
+
+Result<std::shared_ptr<CudaBuffer>> SerializeRecordBatch(const RecordBatch& batch,
+ CudaContext* ctx) {
+ ARROW_ASSIGN_OR_RAISE(auto buf,
+ ipc::SerializeRecordBatch(batch, ctx->memory_manager()));
+ return CudaBuffer::FromBuffer(buf);
+}
+
+Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
+ const std::shared_ptr<Schema>& schema, const ipc::DictionaryMemo* dictionary_memo,
+ const std::shared_ptr<CudaBuffer>& buffer, MemoryPool* pool) {
+ CudaBufferReader cuda_reader(buffer);
+
+ // The pool is only used for metadata allocation
+ ARROW_ASSIGN_OR_RAISE(auto message, ipc::ReadMessage(&cuda_reader, pool));
+ if (!message) {
+ return Status::Invalid("End of stream (message has length 0)");
+ }
+
+ // Zero-copy read on device memory
+ return ipc::ReadRecordBatch(*message, schema, dictionary_memo,
+ ipc::IpcReadOptions::Defaults());
+}
+
+} // namespace cuda
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/gpu/cuda_arrow_ipc.h b/src/arrow/cpp/src/arrow/gpu/cuda_arrow_ipc.h
new file mode 100644
index 000000000..b7200a94b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/cuda_arrow_ipc.h
@@ -0,0 +1,72 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/buffer.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+#include "arrow/gpu/cuda_memory.h"
+
+namespace arrow {
+
+class MemoryPool;
+class RecordBatch;
+class Schema;
+
+namespace ipc {
+
+class Message;
+class DictionaryMemo;
+
+} // namespace ipc
+
+namespace cuda {
+
+/// \defgroup cuda-ipc-functions Functions for CUDA IPC
+///
+/// @{
+
+/// \brief Write record batch message to GPU device memory
+/// \param[in] batch record batch to write
+/// \param[in] ctx CudaContext to allocate device memory from
+/// \return CudaBuffer or Status
+ARROW_EXPORT
+Result<std::shared_ptr<CudaBuffer>> SerializeRecordBatch(const RecordBatch& batch,
+ CudaContext* ctx);
+
+/// \brief ReadRecordBatch specialized to handle metadata on CUDA device
+/// \param[in] schema the Schema for the record batch
+/// \param[in] dictionary_memo DictionaryMemo which has any
+/// dictionaries. Can be nullptr if you are sure there are no
+/// dictionary-encoded fields
+/// \param[in] buffer a CudaBuffer containing the complete IPC message
+/// \param[in] pool a MemoryPool to use for allocating space for the metadata
+/// \return RecordBatch or Status
+ARROW_EXPORT
+Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
+ const std::shared_ptr<Schema>& schema, const ipc::DictionaryMemo* dictionary_memo,
+ const std::shared_ptr<CudaBuffer>& buffer, MemoryPool* pool = default_memory_pool());
+
+/// @}
+
+} // namespace cuda
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/gpu/cuda_benchmark.cc b/src/arrow/cpp/src/arrow/gpu/cuda_benchmark.cc
new file mode 100644
index 000000000..2787d103c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/cuda_benchmark.cc
@@ -0,0 +1,94 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/memory_pool.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+
+#include "arrow/gpu/cuda_api.h"
+
+namespace arrow {
+namespace cuda {
+
+constexpr int64_t kGpuNumber = 0;
+
+static void CudaBufferWriterBenchmark(benchmark::State& state, const int64_t total_bytes,
+ const int64_t chunksize,
+ const int64_t buffer_size) {
+ CudaDeviceManager* manager;
+ ABORT_NOT_OK(CudaDeviceManager::Instance().Value(&manager));
+ std::shared_ptr<CudaContext> context;
+ ABORT_NOT_OK(manager->GetContext(kGpuNumber).Value(&context));
+
+ std::shared_ptr<CudaBuffer> device_buffer;
+ ABORT_NOT_OK(context->Allocate(total_bytes).Value(&device_buffer));
+ CudaBufferWriter writer(device_buffer);
+
+ if (buffer_size > 0) {
+ ABORT_NOT_OK(writer.SetBufferSize(buffer_size));
+ }
+
+ std::shared_ptr<ResizableBuffer> buffer;
+ ASSERT_OK(MakeRandomByteBuffer(total_bytes, default_memory_pool(), &buffer));
+
+ const uint8_t* host_data = buffer->data();
+ while (state.KeepRunning()) {
+ int64_t bytes_written = 0;
+ ABORT_NOT_OK(writer.Seek(0));
+ while (bytes_written < total_bytes) {
+ int64_t bytes_to_write = std::min(chunksize, total_bytes - bytes_written);
+ ABORT_NOT_OK(writer.Write(host_data + bytes_written, bytes_to_write));
+ bytes_written += bytes_to_write;
+ }
+ }
+ state.SetBytesProcessed(int64_t(state.iterations()) * total_bytes);
+}
+
+static void Writer_Buffered(benchmark::State& state) {
+ // 128MB
+ const int64_t kTotalBytes = 1 << 27;
+
+ // 8MB
+ const int64_t kBufferSize = 1 << 23;
+
+ CudaBufferWriterBenchmark(state, kTotalBytes, state.range(0), kBufferSize);
+}
+
+static void Writer_Unbuffered(benchmark::State& state) {
+ // 128MB
+ const int64_t kTotalBytes = 1 << 27;
+ CudaBufferWriterBenchmark(state, kTotalBytes, state.range(0), 0);
+}
+
+// Vary chunk write size from 256 bytes to 64K
+BENCHMARK(Writer_Buffered)->RangeMultiplier(16)->Range(1 << 8, 1 << 16)->UseRealTime();
+
+BENCHMARK(Writer_Unbuffered)
+ ->RangeMultiplier(4)
+ ->RangeMultiplier(16)
+ ->Range(1 << 8, 1 << 16)
+ ->UseRealTime();
+
+} // namespace cuda
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/gpu/cuda_context.cc b/src/arrow/cpp/src/arrow/gpu/cuda_context.cc
new file mode 100644
index 000000000..8cb7e65fa
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/cuda_context.cc
@@ -0,0 +1,646 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/gpu/cuda_context.h"
+
+#include <atomic>
+#include <cstdint>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <cuda.h>
+
+#include "arrow/gpu/cuda_internal.h"
+#include "arrow/gpu/cuda_memory.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace cuda {
+
+using internal::ContextSaver;
+
+namespace {
+
+struct DeviceProperties {
+ int device_number_;
+ CUdevice handle_;
+ int64_t total_memory_;
+ std::string name_;
+
+ Status Init(int device_number) {
+ device_number_ = device_number;
+ CU_RETURN_NOT_OK("cuDeviceGet", cuDeviceGet(&handle_, device_number));
+ size_t total_memory = 0;
+ CU_RETURN_NOT_OK("cuDeviceTotalMem", cuDeviceTotalMem(&total_memory, handle_));
+ total_memory_ = total_memory;
+
+ char buf[200];
+ CU_RETURN_NOT_OK("cuDeviceGetName", cuDeviceGetName(buf, sizeof(buf), device_number));
+ name_.assign(buf);
+ return Status::OK();
+ }
+};
+
+const char kCudaDeviceTypeName[] = "arrow::cuda::CudaDevice";
+
+} // namespace
+
+struct CudaDevice::Impl {
+ DeviceProperties props;
+};
+
+// ----------------------------------------------------------------------
+// CudaContext implementation
+
+class CudaContext::Impl {
+ public:
+ Impl() : bytes_allocated_(0) {}
+
+ Status Init(const std::shared_ptr<CudaDevice>& device) {
+ mm_ = checked_pointer_cast<CudaMemoryManager>(device->default_memory_manager());
+ props_ = &device->impl_->props;
+ own_context_ = true;
+ CU_RETURN_NOT_OK("cuDevicePrimaryCtxRetain",
+ cuDevicePrimaryCtxRetain(&context_, props_->handle_));
+ is_open_ = true;
+ return Status::OK();
+ }
+
+ Status InitShared(const std::shared_ptr<CudaDevice>& device, CUcontext ctx) {
+ mm_ = checked_pointer_cast<CudaMemoryManager>(device->default_memory_manager());
+ props_ = &device->impl_->props;
+ own_context_ = false;
+ context_ = ctx;
+ is_open_ = true;
+ return Status::OK();
+ }
+
+ Status Close() {
+ if (is_open_ && own_context_) {
+ CU_RETURN_NOT_OK("cuDevicePrimaryCtxRelease",
+ cuDevicePrimaryCtxRelease(props_->handle_));
+ }
+ is_open_ = false;
+ return Status::OK();
+ }
+
+ int64_t bytes_allocated() const { return bytes_allocated_.load(); }
+
+ Status Allocate(int64_t nbytes, uint8_t** out) {
+ if (nbytes > 0) {
+ ContextSaver set_temporary(context_);
+ CUdeviceptr data;
+ CU_RETURN_NOT_OK("cuMemAlloc", cuMemAlloc(&data, static_cast<size_t>(nbytes)));
+ bytes_allocated_ += nbytes;
+ *out = reinterpret_cast<uint8_t*>(data);
+ } else {
+ *out = nullptr;
+ }
+ return Status::OK();
+ }
+
+ Status CopyHostToDevice(uintptr_t dst, const void* src, int64_t nbytes) {
+ ContextSaver set_temporary(context_);
+ CU_RETURN_NOT_OK("cuMemcpyHtoD", cuMemcpyHtoD(dst, src, static_cast<size_t>(nbytes)));
+ return Status::OK();
+ }
+
+ Status CopyDeviceToHost(void* dst, uintptr_t src, int64_t nbytes) {
+ ContextSaver set_temporary(context_);
+ CU_RETURN_NOT_OK("cuMemcpyDtoH", cuMemcpyDtoH(dst, src, static_cast<size_t>(nbytes)));
+ return Status::OK();
+ }
+
+ Status CopyDeviceToDevice(uintptr_t dst, uintptr_t src, int64_t nbytes) {
+ ContextSaver set_temporary(context_);
+ CU_RETURN_NOT_OK("cuMemcpyDtoD", cuMemcpyDtoD(dst, src, static_cast<size_t>(nbytes)));
+ return Status::OK();
+ }
+
+ Status CopyDeviceToAnotherDevice(const std::shared_ptr<CudaContext>& dst_ctx,
+ uintptr_t dst, uintptr_t src, int64_t nbytes) {
+ ContextSaver set_temporary(context_);
+ CU_RETURN_NOT_OK("cuMemcpyPeer",
+ cuMemcpyPeer(dst, reinterpret_cast<CUcontext>(dst_ctx->handle()),
+ src, context_, static_cast<size_t>(nbytes)));
+ return Status::OK();
+ }
+
+ Status Synchronize(void) {
+ ContextSaver set_temporary(context_);
+ CU_RETURN_NOT_OK("cuCtxSynchronize", cuCtxSynchronize());
+ return Status::OK();
+ }
+
+ Status Free(void* device_ptr, int64_t nbytes) {
+ CU_RETURN_NOT_OK("cuMemFree", cuMemFree(reinterpret_cast<CUdeviceptr>(device_ptr)));
+ bytes_allocated_ -= nbytes;
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<CudaIpcMemHandle>> ExportIpcBuffer(const void* data,
+ int64_t size) {
+ CUipcMemHandle cu_handle;
+ if (size > 0) {
+ ContextSaver set_temporary(context_);
+ CU_RETURN_NOT_OK(
+ "cuIpcGetMemHandle",
+ cuIpcGetMemHandle(&cu_handle, reinterpret_cast<CUdeviceptr>(data)));
+ }
+ return std::shared_ptr<CudaIpcMemHandle>(new CudaIpcMemHandle(size, &cu_handle));
+ }
+
+ Status OpenIpcBuffer(const CudaIpcMemHandle& ipc_handle, uint8_t** out) {
+ int64_t size = ipc_handle.memory_size();
+ if (size > 0) {
+ auto handle = reinterpret_cast<const CUipcMemHandle*>(ipc_handle.handle());
+ CUdeviceptr data;
+ CU_RETURN_NOT_OK(
+ "cuIpcOpenMemHandle",
+ cuIpcOpenMemHandle(&data, *handle, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS));
+ *out = reinterpret_cast<uint8_t*>(data);
+ } else {
+ *out = nullptr;
+ }
+ return Status::OK();
+ }
+
+ std::shared_ptr<CudaDevice> device() const {
+ return checked_pointer_cast<CudaDevice>(mm_->device());
+ }
+
+ const std::shared_ptr<CudaMemoryManager>& memory_manager() const { return mm_; }
+
+ void* context_handle() const { return reinterpret_cast<void*>(context_); }
+
+ private:
+ std::shared_ptr<CudaMemoryManager> mm_;
+ const DeviceProperties* props_;
+ CUcontext context_;
+ bool is_open_;
+
+ // So that we can utilize a CUcontext that was created outside this library
+ bool own_context_;
+
+ std::atomic<int64_t> bytes_allocated_;
+};
+
+// ----------------------------------------------------------------------
+// CudaDevice implementation
+
+CudaDevice::CudaDevice(Impl impl) : impl_(new Impl(std::move(impl))) {}
+
+const char* CudaDevice::type_name() const { return kCudaDeviceTypeName; }
+
+std::string CudaDevice::ToString() const {
+ std::stringstream ss;
+ ss << "CudaDevice(device_number=" << device_number() << ", name=\"" << device_name()
+ << "\")";
+ return ss.str();
+}
+
+bool CudaDevice::Equals(const Device& other) const {
+ if (!IsCudaDevice(other)) {
+ return false;
+ }
+ return checked_cast<const CudaDevice&>(other).device_number() == device_number();
+}
+
+int CudaDevice::device_number() const { return impl_->props.device_number_; }
+
+std::string CudaDevice::device_name() const { return impl_->props.name_; }
+
+int64_t CudaDevice::total_memory() const { return impl_->props.total_memory_; }
+
+int CudaDevice::handle() const { return impl_->props.handle_; }
+
+Result<std::shared_ptr<CudaDevice>> CudaDevice::Make(int device_number) {
+ ARROW_ASSIGN_OR_RAISE(auto manager, CudaDeviceManager::Instance());
+ return manager->GetDevice(device_number);
+}
+
+std::shared_ptr<MemoryManager> CudaDevice::default_memory_manager() {
+ return CudaMemoryManager::Make(shared_from_this());
+}
+
+Result<std::shared_ptr<CudaContext>> CudaDevice::GetContext() {
+ // XXX should we cache a default context in CudaDevice instance?
+ auto context = std::shared_ptr<CudaContext>(new CudaContext());
+ auto self = checked_pointer_cast<CudaDevice>(shared_from_this());
+ RETURN_NOT_OK(context->impl_->Init(self));
+ return context;
+}
+
+Result<std::shared_ptr<CudaContext>> CudaDevice::GetSharedContext(void* handle) {
+ auto context = std::shared_ptr<CudaContext>(new CudaContext());
+ auto self = checked_pointer_cast<CudaDevice>(shared_from_this());
+ RETURN_NOT_OK(context->impl_->InitShared(self, reinterpret_cast<CUcontext>(handle)));
+ return context;
+}
+
+Result<std::shared_ptr<CudaHostBuffer>> CudaDevice::AllocateHostBuffer(int64_t size) {
+ ARROW_ASSIGN_OR_RAISE(auto context, GetContext());
+ ContextSaver set_temporary(*context);
+ void* ptr;
+ CU_RETURN_NOT_OK("cuMemHostAlloc", cuMemHostAlloc(&ptr, static_cast<size_t>(size),
+ CU_MEMHOSTALLOC_PORTABLE));
+ return std::make_shared<CudaHostBuffer>(reinterpret_cast<uint8_t*>(ptr), size);
+}
+
+bool IsCudaDevice(const Device& device) {
+ return device.type_name() == kCudaDeviceTypeName;
+}
+
+Result<std::shared_ptr<CudaDevice>> AsCudaDevice(const std::shared_ptr<Device>& device) {
+ if (IsCudaDevice(*device)) {
+ return checked_pointer_cast<CudaDevice>(device);
+ } else {
+ return Status::TypeError("Device is not a Cuda device: ", device->ToString());
+ }
+}
+
+// ----------------------------------------------------------------------
+// CudaMemoryManager implementation
+
+std::shared_ptr<CudaMemoryManager> CudaMemoryManager::Make(
+ const std::shared_ptr<Device>& device) {
+ return std::shared_ptr<CudaMemoryManager>(new CudaMemoryManager(device));
+}
+
+std::shared_ptr<CudaDevice> CudaMemoryManager::cuda_device() const {
+ return checked_pointer_cast<CudaDevice>(device_);
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> CudaMemoryManager::GetBufferReader(
+ std::shared_ptr<Buffer> buf) {
+ if (*buf->device() != *device_) {
+ return Status::Invalid(
+ "CudaMemoryManager::GetBufferReader called on foreign buffer "
+ "for device ",
+ buf->device()->ToString());
+ }
+ return std::make_shared<CudaBufferReader>(checked_pointer_cast<CudaBuffer>(buf));
+}
+
+Result<std::shared_ptr<io::OutputStream>> CudaMemoryManager::GetBufferWriter(
+ std::shared_ptr<Buffer> buf) {
+ if (*buf->device() != *device_) {
+ return Status::Invalid(
+ "CudaMemoryManager::GetBufferReader called on foreign buffer "
+ "for device ",
+ buf->device()->ToString());
+ }
+ ARROW_ASSIGN_OR_RAISE(auto cuda_buf, CudaBuffer::FromBuffer(buf));
+ auto writer = std::make_shared<CudaBufferWriter>(cuda_buf);
+ // Use 8MB buffering, which yields generally good performance
+ RETURN_NOT_OK(writer->SetBufferSize(1 << 23));
+ return writer;
+}
+
+Result<std::shared_ptr<Buffer>> CudaMemoryManager::AllocateBuffer(int64_t size) {
+ ARROW_ASSIGN_OR_RAISE(auto context, cuda_device()->GetContext());
+ std::shared_ptr<CudaBuffer> dest;
+ return context->Allocate(size);
+}
+
+Result<std::shared_ptr<Buffer>> CudaMemoryManager::CopyBufferTo(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& to) {
+ if (to->is_cpu()) {
+ // Device-to-CPU copy
+ std::shared_ptr<Buffer> dest;
+ ARROW_ASSIGN_OR_RAISE(auto from_context, cuda_device()->GetContext());
+ ARROW_ASSIGN_OR_RAISE(dest, to->AllocateBuffer(buf->size()));
+ RETURN_NOT_OK(from_context->CopyDeviceToHost(dest->mutable_data(), buf->address(),
+ buf->size()));
+ return dest;
+ }
+ return nullptr;
+}
+
+Result<std::shared_ptr<Buffer>> CudaMemoryManager::CopyBufferFrom(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& from) {
+ if (from->is_cpu()) {
+ // CPU-to-device copy
+ ARROW_ASSIGN_OR_RAISE(auto to_context, cuda_device()->GetContext());
+ ARROW_ASSIGN_OR_RAISE(auto dest, to_context->Allocate(buf->size()));
+ RETURN_NOT_OK(
+ to_context->CopyHostToDevice(dest->address(), buf->data(), buf->size()));
+ return dest;
+ }
+ if (IsCudaMemoryManager(*from)) {
+ // Device-to-device copy
+ ARROW_ASSIGN_OR_RAISE(auto to_context, cuda_device()->GetContext());
+ ARROW_ASSIGN_OR_RAISE(
+ auto from_context,
+ checked_cast<const CudaMemoryManager&>(*from).cuda_device()->GetContext());
+ ARROW_ASSIGN_OR_RAISE(auto dest, to_context->Allocate(buf->size()));
+ if (to_context->handle() == from_context->handle()) {
+ // Same context
+ RETURN_NOT_OK(
+ to_context->CopyDeviceToDevice(dest->address(), buf->address(), buf->size()));
+ } else {
+ // Other context
+ RETURN_NOT_OK(from_context->CopyDeviceToAnotherDevice(to_context, dest->address(),
+ buf->address(), buf->size()));
+ }
+ return dest;
+ }
+ return nullptr;
+}
+
+Result<std::shared_ptr<Buffer>> CudaMemoryManager::ViewBufferTo(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& to) {
+ if (to->is_cpu()) {
+ // Device-on-CPU view
+ ARROW_ASSIGN_OR_RAISE(auto address, GetHostAddress(buf->address()));
+ return std::make_shared<Buffer>(address, buf->size(), to, buf);
+ }
+ return nullptr;
+}
+
+Result<std::shared_ptr<Buffer>> CudaMemoryManager::ViewBufferFrom(
+ const std::shared_ptr<Buffer>& buf, const std::shared_ptr<MemoryManager>& from) {
+ if (from->is_cpu()) {
+ // CPU-on-device view
+ ARROW_ASSIGN_OR_RAISE(auto to_context, cuda_device()->GetContext());
+ ARROW_ASSIGN_OR_RAISE(auto address, GetDeviceAddress(buf->data(), to_context));
+ return std::make_shared<Buffer>(address, buf->size(), shared_from_this(), buf);
+ }
+ return nullptr;
+}
+
+bool IsCudaMemoryManager(const MemoryManager& mm) { return IsCudaDevice(*mm.device()); }
+
+Result<std::shared_ptr<CudaMemoryManager>> AsCudaMemoryManager(
+ const std::shared_ptr<MemoryManager>& mm) {
+ if (IsCudaMemoryManager(*mm)) {
+ return checked_pointer_cast<CudaMemoryManager>(mm);
+ } else {
+ return Status::TypeError("Device is not a Cuda device: ", mm->device()->ToString());
+ }
+}
+
+// ----------------------------------------------------------------------
+// CudaDeviceManager implementation
+
+class CudaDeviceManager::Impl {
+ public:
+ Impl() : host_bytes_allocated_(0) {}
+
+ Status Init() {
+ CU_RETURN_NOT_OK("cuInit", cuInit(0));
+ CU_RETURN_NOT_OK("cuDeviceGetCount", cuDeviceGetCount(&num_devices_));
+
+ devices_.resize(num_devices_);
+ for (int i = 0; i < num_devices_; ++i) {
+ ARROW_ASSIGN_OR_RAISE(devices_[i], MakeDevice(i));
+ }
+ return Status::OK();
+ }
+
+ Status AllocateHost(int device_number, int64_t nbytes, uint8_t** out) {
+ RETURN_NOT_OK(CheckDeviceNum(device_number));
+ ARROW_ASSIGN_OR_RAISE(auto ctx, GetContext(device_number));
+ ContextSaver set_temporary((CUcontext)(ctx.get()->handle()));
+ CU_RETURN_NOT_OK("cuMemHostAlloc", cuMemHostAlloc(reinterpret_cast<void**>(out),
+ static_cast<size_t>(nbytes),
+ CU_MEMHOSTALLOC_PORTABLE));
+ host_bytes_allocated_ += nbytes;
+ return Status::OK();
+ }
+
+ Status FreeHost(void* data, int64_t nbytes) {
+ CU_RETURN_NOT_OK("cuMemFreeHost", cuMemFreeHost(data));
+ host_bytes_allocated_ -= nbytes;
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<CudaContext>> GetContext(int device_number) {
+ RETURN_NOT_OK(CheckDeviceNum(device_number));
+ return devices_[device_number]->GetContext();
+ }
+
+ Result<std::shared_ptr<CudaContext>> GetSharedContext(int device_number, void* handle) {
+ RETURN_NOT_OK(CheckDeviceNum(device_number));
+ return devices_[device_number]->GetSharedContext(handle);
+ }
+
+ Result<std::shared_ptr<CudaDevice>> GetDevice(int device_number) {
+ RETURN_NOT_OK(CheckDeviceNum(device_number));
+ return devices_[device_number];
+ }
+
+ int num_devices() const { return num_devices_; }
+
+ Status CheckDeviceNum(int device_number) const {
+ if (device_number < 0 || device_number >= num_devices_) {
+ return Status::Invalid("Invalid Cuda device number ", device_number,
+ " (should be between 0 and ", num_devices_ - 1,
+ ", inclusive)");
+ }
+ return Status::OK();
+ }
+
+ protected:
+ Result<std::shared_ptr<CudaDevice>> MakeDevice(int device_number) {
+ DeviceProperties props;
+ RETURN_NOT_OK(props.Init(device_number));
+ return std::shared_ptr<CudaDevice>(new CudaDevice({std::move(props)}));
+ }
+
+ private:
+ int num_devices_;
+ std::vector<std::shared_ptr<CudaDevice>> devices_;
+
+ int64_t host_bytes_allocated_;
+};
+
+CudaDeviceManager::CudaDeviceManager() { impl_.reset(new Impl()); }
+
+std::unique_ptr<CudaDeviceManager> CudaDeviceManager::instance_ = nullptr;
+
+Result<CudaDeviceManager*> CudaDeviceManager::Instance() {
+ static std::mutex mutex;
+ static std::atomic<bool> init_end(false);
+
+ if (!init_end) {
+ std::lock_guard<std::mutex> lock(mutex);
+ if (!init_end) {
+ instance_.reset(new CudaDeviceManager());
+ RETURN_NOT_OK(instance_->impl_->Init());
+ init_end = true;
+ }
+ }
+ return instance_.get();
+}
+
+Result<std::shared_ptr<CudaDevice>> CudaDeviceManager::GetDevice(int device_number) {
+ return impl_->GetDevice(device_number);
+}
+
+Result<std::shared_ptr<CudaContext>> CudaDeviceManager::GetContext(int device_number) {
+ return impl_->GetContext(device_number);
+}
+
+Result<std::shared_ptr<CudaContext>> CudaDeviceManager::GetSharedContext(
+ int device_number, void* ctx) {
+ return impl_->GetSharedContext(device_number, ctx);
+}
+
+Result<std::shared_ptr<CudaHostBuffer>> CudaDeviceManager::AllocateHost(int device_number,
+ int64_t nbytes) {
+ uint8_t* data = nullptr;
+ RETURN_NOT_OK(impl_->AllocateHost(device_number, nbytes, &data));
+ return std::make_shared<CudaHostBuffer>(data, nbytes);
+}
+
+Status CudaDeviceManager::FreeHost(void* data, int64_t nbytes) {
+ return impl_->FreeHost(data, nbytes);
+}
+
+int CudaDeviceManager::num_devices() const { return impl_->num_devices(); }
+
+// ----------------------------------------------------------------------
+// CudaContext public API
+
+CudaContext::CudaContext() { impl_.reset(new Impl()); }
+
+CudaContext::~CudaContext() {}
+
+Result<std::shared_ptr<CudaBuffer>> CudaContext::Allocate(int64_t nbytes) {
+ uint8_t* data = nullptr;
+ RETURN_NOT_OK(impl_->Allocate(nbytes, &data));
+ return std::make_shared<CudaBuffer>(data, nbytes, this->shared_from_this(), true);
+}
+
+Result<std::shared_ptr<CudaBuffer>> CudaContext::View(uint8_t* data, int64_t nbytes) {
+ return std::make_shared<CudaBuffer>(data, nbytes, this->shared_from_this(), false);
+}
+
+Result<std::shared_ptr<CudaIpcMemHandle>> CudaContext::ExportIpcBuffer(const void* data,
+ int64_t size) {
+ return impl_->ExportIpcBuffer(data, size);
+}
+
+Status CudaContext::CopyHostToDevice(uintptr_t dst, const void* src, int64_t nbytes) {
+ return impl_->CopyHostToDevice(dst, src, nbytes);
+}
+
+Status CudaContext::CopyHostToDevice(void* dst, const void* src, int64_t nbytes) {
+ return impl_->CopyHostToDevice(reinterpret_cast<uintptr_t>(dst), src, nbytes);
+}
+
+Status CudaContext::CopyDeviceToHost(void* dst, uintptr_t src, int64_t nbytes) {
+ return impl_->CopyDeviceToHost(dst, src, nbytes);
+}
+
+Status CudaContext::CopyDeviceToHost(void* dst, const void* src, int64_t nbytes) {
+ return impl_->CopyDeviceToHost(dst, reinterpret_cast<uintptr_t>(src), nbytes);
+}
+
+Status CudaContext::CopyDeviceToDevice(uintptr_t dst, uintptr_t src, int64_t nbytes) {
+ return impl_->CopyDeviceToDevice(dst, src, nbytes);
+}
+
+Status CudaContext::CopyDeviceToDevice(void* dst, const void* src, int64_t nbytes) {
+ return impl_->CopyDeviceToDevice(reinterpret_cast<uintptr_t>(dst),
+ reinterpret_cast<uintptr_t>(src), nbytes);
+}
+
+Status CudaContext::CopyDeviceToAnotherDevice(const std::shared_ptr<CudaContext>& dst_ctx,
+ uintptr_t dst, uintptr_t src,
+ int64_t nbytes) {
+ return impl_->CopyDeviceToAnotherDevice(dst_ctx, dst, src, nbytes);
+}
+
+Status CudaContext::CopyDeviceToAnotherDevice(const std::shared_ptr<CudaContext>& dst_ctx,
+ void* dst, const void* src,
+ int64_t nbytes) {
+ return impl_->CopyDeviceToAnotherDevice(dst_ctx, reinterpret_cast<uintptr_t>(dst),
+ reinterpret_cast<uintptr_t>(src), nbytes);
+}
+
+Status CudaContext::Synchronize(void) { return impl_->Synchronize(); }
+
+Status CudaContext::Close() { return impl_->Close(); }
+
+Status CudaContext::Free(void* device_ptr, int64_t nbytes) {
+ return impl_->Free(device_ptr, nbytes);
+}
+
+Result<std::shared_ptr<CudaBuffer>> CudaContext::OpenIpcBuffer(
+ const CudaIpcMemHandle& ipc_handle) {
+ if (ipc_handle.memory_size() > 0) {
+ ContextSaver set_temporary(*this);
+ uint8_t* data = nullptr;
+ RETURN_NOT_OK(impl_->OpenIpcBuffer(ipc_handle, &data));
+ // Need to ask the device how big the buffer is
+ size_t allocation_size = 0;
+ CU_RETURN_NOT_OK("cuMemGetAddressRange",
+ cuMemGetAddressRange(nullptr, &allocation_size,
+ reinterpret_cast<CUdeviceptr>(data)));
+ return std::make_shared<CudaBuffer>(data, allocation_size, this->shared_from_this(),
+ true, true);
+ } else {
+ // zero-sized buffer does not own data (which is nullptr), hence
+ // CloseIpcBuffer will not be called (see CudaBuffer::Close).
+ return std::make_shared<CudaBuffer>(nullptr, 0, this->shared_from_this(), false,
+ true);
+ }
+}
+
+Status CudaContext::CloseIpcBuffer(CudaBuffer* buf) {
+ ContextSaver set_temporary(*this);
+ CU_RETURN_NOT_OK("cuIpcCloseMemHandle", cuIpcCloseMemHandle(buf->address()));
+ return Status::OK();
+}
+
+int64_t CudaContext::bytes_allocated() const { return impl_->bytes_allocated(); }
+
+void* CudaContext::handle() const { return impl_->context_handle(); }
+
+std::shared_ptr<CudaDevice> CudaContext::device() const { return impl_->device(); }
+
+std::shared_ptr<CudaMemoryManager> CudaContext::memory_manager() const {
+ return impl_->memory_manager();
+}
+
+int CudaContext::device_number() const { return impl_->device()->device_number(); }
+
+Result<uintptr_t> CudaContext::GetDeviceAddress(uintptr_t addr) {
+ ContextSaver set_temporary(*this);
+ CUdeviceptr ptr;
+ CU_RETURN_NOT_OK("cuPointerGetAttribute",
+ cuPointerGetAttribute(&ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER,
+ static_cast<CUdeviceptr>(addr)));
+ return static_cast<uintptr_t>(ptr);
+}
+
+Result<uintptr_t> CudaContext::GetDeviceAddress(uint8_t* addr) {
+ return GetDeviceAddress(reinterpret_cast<uintptr_t>(addr));
+}
+
+} // namespace cuda
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/gpu/cuda_context.h b/src/arrow/cpp/src/arrow/gpu/cuda_context.h
new file mode 100644
index 000000000..2cff4f57a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/cuda_context.h
@@ -0,0 +1,310 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "arrow/device.h"
+#include "arrow/result.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace cuda {
+
+// Forward declaration
+class CudaContext;
+class CudaDevice;
+class CudaDeviceManager;
+class CudaBuffer;
+class CudaHostBuffer;
+class CudaIpcMemHandle;
+class CudaMemoryManager;
+
+// XXX Should CudaContext be merged into CudaMemoryManager?
+
+class ARROW_EXPORT CudaDeviceManager {
+ public:
+ static Result<CudaDeviceManager*> Instance();
+
+ /// \brief Get a CudaDevice instance for a particular device
+ /// \param[in] device_number the CUDA device number
+ Result<std::shared_ptr<CudaDevice>> GetDevice(int device_number);
+
+ /// \brief Get the CUDA driver context for a particular device
+ /// \param[in] device_number the CUDA device number
+ /// \return cached context
+ Result<std::shared_ptr<CudaContext>> GetContext(int device_number);
+
+ /// \brief Get the shared CUDA driver context for a particular device
+ /// \param[in] device_number the CUDA device number
+ /// \param[in] handle CUDA context handle created by another library
+ /// \return shared context
+ Result<std::shared_ptr<CudaContext>> GetSharedContext(int device_number, void* handle);
+
+ /// \brief Allocate host memory with fast access to given GPU device
+ /// \param[in] device_number the CUDA device number
+ /// \param[in] nbytes number of bytes
+ /// \return Host buffer or Status
+ Result<std::shared_ptr<CudaHostBuffer>> AllocateHost(int device_number, int64_t nbytes);
+
+ /// \brief Free host memory
+ ///
+ /// The given memory pointer must have been allocated with AllocateHost.
+ Status FreeHost(void* data, int64_t nbytes);
+
+ int num_devices() const;
+
+ private:
+ CudaDeviceManager();
+ static std::unique_ptr<CudaDeviceManager> instance_;
+
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+
+ friend class CudaContext;
+ friend class CudaDevice;
+};
+
+/// \brief Device implementation for CUDA
+///
+/// Each CudaDevice instance is tied to a particular CUDA device
+/// (identified by its logical device number).
+class ARROW_EXPORT CudaDevice : public Device {
+ public:
+ const char* type_name() const override;
+ std::string ToString() const override;
+ bool Equals(const Device&) const override;
+ std::shared_ptr<MemoryManager> default_memory_manager() override;
+
+ /// \brief Return a CudaDevice instance for a particular device
+ /// \param[in] device_number the CUDA device number
+ static Result<std::shared_ptr<CudaDevice>> Make(int device_number);
+
+ /// \brief Return the device logical number
+ int device_number() const;
+
+ /// \brief Return the GPU model name
+ std::string device_name() const;
+
+ /// \brief Return total memory on this device
+ int64_t total_memory() const;
+
+ /// \brief Return a raw CUDA device handle
+ ///
+ /// The returned value can be used to expose this device to other libraries.
+ /// It should be interpreted as `CUdevice`.
+ int handle() const;
+
+ /// \brief Get a CUDA driver context for this device
+ ///
+ /// The returned context is associated with the primary CUDA context for the
+ /// device. This is the recommended way of getting a context for a device,
+ /// as it allows interoperating transparently with any library using the
+ /// primary CUDA context API.
+ Result<std::shared_ptr<CudaContext>> GetContext();
+
+ /// \brief Get a CUDA driver context for this device, using an existing handle
+ ///
+ /// The handle is not owned: it will not be released when the CudaContext
+ /// is destroyed. This function should only be used if you need interoperation
+ /// with a library that uses a non-primary context.
+ ///
+ /// \param[in] handle CUDA context handle created by another library
+ Result<std::shared_ptr<CudaContext>> GetSharedContext(void* handle);
+
+ /// \brief Allocate a host-residing, GPU-accessible buffer
+ ///
+ /// The buffer is allocated using this device's primary context.
+ ///
+ /// \param[in] size The buffer size in bytes
+ Result<std::shared_ptr<CudaHostBuffer>> AllocateHostBuffer(int64_t size);
+
+ protected:
+ struct Impl;
+
+ friend class CudaContext;
+ /// \cond FALSE
+ // (note: emits warning on Doxygen < 1.8.15)
+ friend class CudaDeviceManager::Impl;
+ /// \endcond
+
+ explicit CudaDevice(Impl);
+ std::unique_ptr<Impl> impl_;
+};
+
+/// \brief Return whether a device instance is a CudaDevice
+ARROW_EXPORT
+bool IsCudaDevice(const Device& device);
+
+/// \brief Cast a device instance to a CudaDevice
+///
+/// An error is returned if the device is not a CudaDevice.
+ARROW_EXPORT
+Result<std::shared_ptr<CudaDevice>> AsCudaDevice(const std::shared_ptr<Device>& device);
+
+/// \brief MemoryManager implementation for CUDA
+class ARROW_EXPORT CudaMemoryManager : public MemoryManager {
+ public:
+ Result<std::shared_ptr<io::RandomAccessFile>> GetBufferReader(
+ std::shared_ptr<Buffer> buf) override;
+ Result<std::shared_ptr<io::OutputStream>> GetBufferWriter(
+ std::shared_ptr<Buffer> buf) override;
+
+ Result<std::shared_ptr<Buffer>> AllocateBuffer(int64_t size) override;
+
+ /// \brief The CudaDevice instance tied to this MemoryManager
+ ///
+ /// This is a useful shorthand returning a concrete-typed pointer, avoiding
+ /// having to cast the `device()` result.
+ std::shared_ptr<CudaDevice> cuda_device() const;
+
+ protected:
+ using MemoryManager::MemoryManager;
+ static std::shared_ptr<CudaMemoryManager> Make(const std::shared_ptr<Device>& device);
+
+ Result<std::shared_ptr<Buffer>> CopyBufferFrom(
+ const std::shared_ptr<Buffer>& buf,
+ const std::shared_ptr<MemoryManager>& from) override;
+ Result<std::shared_ptr<Buffer>> CopyBufferTo(
+ const std::shared_ptr<Buffer>& buf,
+ const std::shared_ptr<MemoryManager>& to) override;
+ Result<std::shared_ptr<Buffer>> ViewBufferFrom(
+ const std::shared_ptr<Buffer>& buf,
+ const std::shared_ptr<MemoryManager>& from) override;
+ Result<std::shared_ptr<Buffer>> ViewBufferTo(
+ const std::shared_ptr<Buffer>& buf,
+ const std::shared_ptr<MemoryManager>& to) override;
+
+ friend class CudaDevice;
+};
+
+/// \brief Return whether a MemoryManager instance is a CudaMemoryManager
+ARROW_EXPORT
+bool IsCudaMemoryManager(const MemoryManager& mm);
+
+/// \brief Cast a MemoryManager instance to a CudaMemoryManager
+///
+/// An error is returned if the MemoryManager is not a CudaMemoryManager.
+ARROW_EXPORT
+Result<std::shared_ptr<CudaMemoryManager>> AsCudaMemoryManager(
+ const std::shared_ptr<MemoryManager>& mm);
+
+/// \class CudaContext
+/// \brief Object-oriented interface to the low-level CUDA driver API
+class ARROW_EXPORT CudaContext : public std::enable_shared_from_this<CudaContext> {
+ public:
+ ~CudaContext();
+
+ Status Close();
+
+ /// \brief Allocate CUDA memory on GPU device for this context
+ /// \param[in] nbytes number of bytes
+ /// \return the allocated buffer
+ Result<std::shared_ptr<CudaBuffer>> Allocate(int64_t nbytes);
+
+ /// \brief Release CUDA memory on GPU device for this context
+ /// \param[in] device_ptr the buffer address
+ /// \param[in] nbytes number of bytes
+ /// \return Status
+ Status Free(void* device_ptr, int64_t nbytes);
+
+ /// \brief Create a view of CUDA memory on GPU device of this context
+ /// \param[in] data the starting device address
+ /// \param[in] nbytes number of bytes
+ /// \return the view buffer
+ ///
+ /// \note The caller is responsible for allocating and freeing the
+ /// memory as well as ensuring that the memory belongs to the CUDA
+ /// context that this CudaContext instance holds.
+ Result<std::shared_ptr<CudaBuffer>> View(uint8_t* data, int64_t nbytes);
+
+ /// \brief Open existing CUDA IPC memory handle
+ /// \param[in] ipc_handle opaque pointer to CUipcMemHandle (driver API)
+ /// \return a CudaBuffer referencing the IPC segment
+ Result<std::shared_ptr<CudaBuffer>> OpenIpcBuffer(const CudaIpcMemHandle& ipc_handle);
+
+ /// \brief Close memory mapped with IPC buffer
+ /// \param[in] buffer a CudaBuffer referencing
+ /// \return Status
+ Status CloseIpcBuffer(CudaBuffer* buffer);
+
+ /// \brief Block until the all device tasks are completed.
+ Status Synchronize(void);
+
+ int64_t bytes_allocated() const;
+
+ /// \brief Expose CUDA context handle to other libraries
+ void* handle() const;
+
+ /// \brief Return the default memory manager tied to this context's device
+ std::shared_ptr<CudaMemoryManager> memory_manager() const;
+
+ /// \brief Return the device instance associated with this context
+ std::shared_ptr<CudaDevice> device() const;
+
+ /// \brief Return the logical device number
+ int device_number() const;
+
+ /// \brief Return the device address that is reachable from kernels
+ /// running in the context
+ /// \param[in] addr device or host memory address
+ /// \return the device address
+ ///
+ /// The device address is defined as a memory address accessible by
+ /// device. While it is often a device memory address, it can be
+ /// also a host memory address, for instance, when the memory is
+ /// allocated as host memory (using cudaMallocHost or cudaHostAlloc)
+ /// or as managed memory (using cudaMallocManaged) or the host
+ /// memory is page-locked (using cudaHostRegister).
+ Result<uintptr_t> GetDeviceAddress(uint8_t* addr);
+ Result<uintptr_t> GetDeviceAddress(uintptr_t addr);
+
+ private:
+ CudaContext();
+
+ Result<std::shared_ptr<CudaIpcMemHandle>> ExportIpcBuffer(const void* data,
+ int64_t size);
+ Status CopyHostToDevice(void* dst, const void* src, int64_t nbytes);
+ Status CopyHostToDevice(uintptr_t dst, const void* src, int64_t nbytes);
+ Status CopyDeviceToHost(void* dst, const void* src, int64_t nbytes);
+ Status CopyDeviceToHost(void* dst, uintptr_t src, int64_t nbytes);
+ Status CopyDeviceToDevice(void* dst, const void* src, int64_t nbytes);
+ Status CopyDeviceToDevice(uintptr_t dst, uintptr_t src, int64_t nbytes);
+ Status CopyDeviceToAnotherDevice(const std::shared_ptr<CudaContext>& dst_ctx, void* dst,
+ const void* src, int64_t nbytes);
+ Status CopyDeviceToAnotherDevice(const std::shared_ptr<CudaContext>& dst_ctx,
+ uintptr_t dst, uintptr_t src, int64_t nbytes);
+
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+
+ friend class CudaBuffer;
+ friend class CudaBufferReader;
+ friend class CudaBufferWriter;
+ friend class CudaDevice;
+ friend class CudaMemoryManager;
+ /// \cond FALSE
+ // (note: emits warning on Doxygen < 1.8.15)
+ friend class CudaDeviceManager::Impl;
+ /// \endcond
+};
+
+} // namespace cuda
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/gpu/cuda_internal.cc b/src/arrow/cpp/src/arrow/gpu/cuda_internal.cc
new file mode 100644
index 000000000..1e941415f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/cuda_internal.cc
@@ -0,0 +1,66 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/gpu/cuda_internal.h"
+
+#include <sstream>
+#include <string>
+
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace cuda {
+namespace internal {
+
+std::string CudaErrorDescription(CUresult err) {
+ DCHECK_NE(err, CUDA_SUCCESS);
+ std::stringstream ss;
+
+ const char* name = nullptr;
+ auto err_result = cuGetErrorName(err, &name);
+ if (err_result == CUDA_SUCCESS) {
+ DCHECK_NE(name, nullptr);
+ ss << "[" << name << "] ";
+ }
+
+ const char* str = nullptr;
+ err_result = cuGetErrorString(err, &str);
+ if (err_result == CUDA_SUCCESS) {
+ DCHECK_NE(str, nullptr);
+ ss << str;
+ } else {
+ ss << "unknown error";
+ }
+ return ss.str();
+}
+
+Status StatusFromCuda(CUresult res, const char* function_name) {
+ if (res == CUDA_SUCCESS) {
+ return Status::OK();
+ }
+ std::stringstream ss;
+ ss << "Cuda error " << res;
+ if (function_name != nullptr) {
+ ss << " in function '" << function_name << "'";
+ }
+ ss << ": " << CudaErrorDescription(res);
+ return Status::IOError(ss.str());
+}
+
+} // namespace internal
+} // namespace cuda
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/gpu/cuda_internal.h b/src/arrow/cpp/src/arrow/gpu/cuda_internal.h
new file mode 100644
index 000000000..25eb6e06c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/cuda_internal.h
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Non-public header
+
+#pragma once
+
+#include <cassert>
+#include <string>
+
+#include <cuda.h>
+
+#include "arrow/gpu/cuda_context.h"
+#include "arrow/status.h"
+
+namespace arrow {
+namespace cuda {
+namespace internal {
+
+std::string CudaErrorDescription(CUresult err);
+
+Status StatusFromCuda(CUresult res, const char* function_name = nullptr);
+
+#define CU_RETURN_NOT_OK(FUNC_NAME, STMT) \
+ do { \
+ CUresult __res = (STMT); \
+ if (__res != CUDA_SUCCESS) { \
+ return ::arrow::cuda::internal::StatusFromCuda(__res, FUNC_NAME); \
+ } \
+ } while (0)
+
+class ContextSaver {
+ public:
+ explicit ContextSaver(CUcontext new_context) { cuCtxPushCurrent(new_context); }
+ explicit ContextSaver(const CudaContext& context)
+ : ContextSaver(reinterpret_cast<CUcontext>(context.handle())) {}
+
+ ~ContextSaver() {
+ CUcontext unused;
+ cuCtxPopCurrent(&unused);
+ }
+};
+
+} // namespace internal
+} // namespace cuda
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/gpu/cuda_memory.cc b/src/arrow/cpp/src/arrow/gpu/cuda_memory.cc
new file mode 100644
index 000000000..297e4dcf7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/cuda_memory.cc
@@ -0,0 +1,484 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/gpu/cuda_memory.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+#include <mutex>
+#include <utility>
+
+#include <cuda.h>
+
+#include "arrow/buffer.h"
+#include "arrow/io/memory.h"
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+
+#include "arrow/gpu/cuda_context.h"
+#include "arrow/gpu/cuda_internal.h"
+
+namespace arrow {
+namespace cuda {
+
+using internal::ContextSaver;
+
+// ----------------------------------------------------------------------
+// CUDA IPC memory handle
+
+struct CudaIpcMemHandle::CudaIpcMemHandleImpl {
+ explicit CudaIpcMemHandleImpl(const uint8_t* handle) {
+ memcpy(&memory_size, handle, sizeof(memory_size));
+ if (memory_size != 0)
+ memcpy(&ipc_handle, handle + sizeof(memory_size), sizeof(CUipcMemHandle));
+ }
+
+ explicit CudaIpcMemHandleImpl(int64_t memory_size, const void* cu_handle)
+ : memory_size(memory_size) {
+ if (memory_size != 0) {
+ memcpy(&ipc_handle, cu_handle, sizeof(CUipcMemHandle));
+ }
+ }
+
+ CUipcMemHandle ipc_handle; /// initialized only when memory_size != 0
+ int64_t memory_size; /// size of the memory that ipc_handle refers to
+};
+
+CudaIpcMemHandle::CudaIpcMemHandle(const void* handle) {
+ impl_.reset(new CudaIpcMemHandleImpl(reinterpret_cast<const uint8_t*>(handle)));
+}
+
+CudaIpcMemHandle::CudaIpcMemHandle(int64_t memory_size, const void* cu_handle) {
+ impl_.reset(new CudaIpcMemHandleImpl(memory_size, cu_handle));
+}
+
+CudaIpcMemHandle::~CudaIpcMemHandle() {}
+
+Result<std::shared_ptr<CudaIpcMemHandle>> CudaIpcMemHandle::FromBuffer(
+ const void* opaque_handle) {
+ return std::shared_ptr<CudaIpcMemHandle>(new CudaIpcMemHandle(opaque_handle));
+}
+
+Result<std::shared_ptr<Buffer>> CudaIpcMemHandle::Serialize(MemoryPool* pool) const {
+ int64_t size = impl_->memory_size;
+ const size_t handle_size =
+ (size > 0 ? sizeof(int64_t) + sizeof(CUipcMemHandle) : sizeof(int64_t));
+
+ ARROW_ASSIGN_OR_RAISE(auto buffer,
+ AllocateBuffer(static_cast<int64_t>(handle_size), pool));
+ memcpy(buffer->mutable_data(), &impl_->memory_size, sizeof(impl_->memory_size));
+ if (size > 0) {
+ memcpy(buffer->mutable_data() + sizeof(impl_->memory_size), &impl_->ipc_handle,
+ sizeof(impl_->ipc_handle));
+ }
+ return std::move(buffer);
+}
+
+const void* CudaIpcMemHandle::handle() const { return &impl_->ipc_handle; }
+
+int64_t CudaIpcMemHandle::memory_size() const { return impl_->memory_size; }
+
+// ----------------------------------------------------------------------
+
+CudaBuffer::CudaBuffer(uint8_t* data, int64_t size,
+ const std::shared_ptr<CudaContext>& context, bool own_data,
+ bool is_ipc)
+ : Buffer(data, size), context_(context), own_data_(own_data), is_ipc_(is_ipc) {
+ is_mutable_ = true;
+ SetMemoryManager(context_->memory_manager());
+}
+
+CudaBuffer::CudaBuffer(uintptr_t address, int64_t size,
+ const std::shared_ptr<CudaContext>& context, bool own_data,
+ bool is_ipc)
+ : CudaBuffer(reinterpret_cast<uint8_t*>(address), size, context, own_data, is_ipc) {}
+
+CudaBuffer::~CudaBuffer() { ARROW_CHECK_OK(Close()); }
+
+Status CudaBuffer::Close() {
+ if (own_data_) {
+ if (is_ipc_) {
+ return context_->CloseIpcBuffer(this);
+ } else {
+ return context_->Free(const_cast<uint8_t*>(data_), size_);
+ }
+ }
+ return Status::OK();
+}
+
+CudaBuffer::CudaBuffer(const std::shared_ptr<CudaBuffer>& parent, const int64_t offset,
+ const int64_t size)
+ : Buffer(parent, offset, size),
+ context_(parent->context()),
+ own_data_(false),
+ is_ipc_(false) {
+ is_mutable_ = parent->is_mutable();
+}
+
+Result<std::shared_ptr<CudaBuffer>> CudaBuffer::FromBuffer(
+ std::shared_ptr<Buffer> buffer) {
+ int64_t offset = 0, size = buffer->size();
+ bool is_mutable = buffer->is_mutable();
+ std::shared_ptr<CudaBuffer> cuda_buffer;
+
+ // The original CudaBuffer may have been wrapped in another Buffer
+ // (for example through slicing).
+ // TODO check device instead
+ while (!(cuda_buffer = std::dynamic_pointer_cast<CudaBuffer>(buffer))) {
+ const std::shared_ptr<Buffer> parent = buffer->parent();
+ if (!parent) {
+ return Status::TypeError("buffer is not backed by a CudaBuffer");
+ }
+ offset += buffer->address() - parent->address();
+ buffer = parent;
+ }
+ // Re-slice to represent the same memory area
+ if (offset != 0 || cuda_buffer->size() != size || !is_mutable) {
+ cuda_buffer = std::make_shared<CudaBuffer>(std::move(cuda_buffer), offset, size);
+ cuda_buffer->is_mutable_ = is_mutable;
+ }
+ return cuda_buffer;
+}
+
+Status CudaBuffer::CopyToHost(const int64_t position, const int64_t nbytes,
+ void* out) const {
+ return context_->CopyDeviceToHost(out, data_ + position, nbytes);
+}
+
+Status CudaBuffer::CopyFromHost(const int64_t position, const void* data,
+ int64_t nbytes) {
+ if (nbytes > size_ - position) {
+ return Status::Invalid("Copy would overflow buffer");
+ }
+ return context_->CopyHostToDevice(const_cast<uint8_t*>(data_) + position, data, nbytes);
+}
+
+Status CudaBuffer::CopyFromDevice(const int64_t position, const void* data,
+ int64_t nbytes) {
+ if (nbytes > size_ - position) {
+ return Status::Invalid("Copy would overflow buffer");
+ }
+ return context_->CopyDeviceToDevice(const_cast<uint8_t*>(data_) + position, data,
+ nbytes);
+}
+
+Status CudaBuffer::CopyFromAnotherDevice(const std::shared_ptr<CudaContext>& src_ctx,
+ const int64_t position, const void* data,
+ int64_t nbytes) {
+ if (nbytes > size_ - position) {
+ return Status::Invalid("Copy would overflow buffer");
+ }
+ return src_ctx->CopyDeviceToAnotherDevice(
+ context_, const_cast<uint8_t*>(data_) + position, data, nbytes);
+}
+
+Result<std::shared_ptr<CudaIpcMemHandle>> CudaBuffer::ExportForIpc() {
+ if (is_ipc_) {
+ return Status::Invalid("Buffer has already been exported for IPC");
+ }
+ ARROW_ASSIGN_OR_RAISE(auto handle, context_->ExportIpcBuffer(data_, size_));
+ own_data_ = false;
+ return handle;
+}
+
+CudaHostBuffer::~CudaHostBuffer() {
+ auto maybe_manager = CudaDeviceManager::Instance();
+ ARROW_CHECK_OK(maybe_manager.status());
+ ARROW_CHECK_OK((*maybe_manager)->FreeHost(const_cast<uint8_t*>(data_), size_));
+}
+
+Result<uintptr_t> CudaHostBuffer::GetDeviceAddress(
+ const std::shared_ptr<CudaContext>& ctx) {
+ return ::arrow::cuda::GetDeviceAddress(data(), ctx);
+}
+
+// ----------------------------------------------------------------------
+// CudaBufferReader
+
+CudaBufferReader::CudaBufferReader(const std::shared_ptr<Buffer>& buffer)
+ : address_(buffer->address()), size_(buffer->size()), position_(0), is_open_(true) {
+ auto maybe_buffer = CudaBuffer::FromBuffer(buffer);
+ if (ARROW_PREDICT_FALSE(!maybe_buffer.ok())) {
+ throw std::bad_cast();
+ }
+ buffer_ = *std::move(maybe_buffer);
+ context_ = buffer_->context();
+}
+
+Status CudaBufferReader::DoClose() {
+ is_open_ = false;
+ return Status::OK();
+}
+
+bool CudaBufferReader::closed() const { return !is_open_; }
+
+// XXX Only in a certain sense (not on the CPU)...
+bool CudaBufferReader::supports_zero_copy() const { return true; }
+
+Result<int64_t> CudaBufferReader::DoTell() const {
+ RETURN_NOT_OK(CheckClosed());
+ return position_;
+}
+
+Result<int64_t> CudaBufferReader::DoGetSize() {
+ RETURN_NOT_OK(CheckClosed());
+ return size_;
+}
+
+Status CudaBufferReader::DoSeek(int64_t position) {
+ RETURN_NOT_OK(CheckClosed());
+
+ if (position < 0 || position > size_) {
+ return Status::IOError("Seek out of bounds");
+ }
+
+ position_ = position;
+ return Status::OK();
+}
+
+Result<int64_t> CudaBufferReader::DoReadAt(int64_t position, int64_t nbytes,
+ void* buffer) {
+ RETURN_NOT_OK(CheckClosed());
+
+ nbytes = std::min(nbytes, size_ - position);
+ RETURN_NOT_OK(context_->CopyDeviceToHost(buffer, address_ + position, nbytes));
+ return nbytes;
+}
+
+Result<int64_t> CudaBufferReader::DoRead(int64_t nbytes, void* buffer) {
+ RETURN_NOT_OK(CheckClosed());
+
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, DoReadAt(position_, nbytes, buffer));
+ position_ += bytes_read;
+ return bytes_read;
+}
+
+Result<std::shared_ptr<Buffer>> CudaBufferReader::DoReadAt(int64_t position,
+ int64_t nbytes) {
+ RETURN_NOT_OK(CheckClosed());
+
+ int64_t size = std::min(nbytes, size_ - position);
+ return std::make_shared<CudaBuffer>(buffer_, position, size);
+}
+
+Result<std::shared_ptr<Buffer>> CudaBufferReader::DoRead(int64_t nbytes) {
+ RETURN_NOT_OK(CheckClosed());
+
+ int64_t size = std::min(nbytes, size_ - position_);
+ auto buffer = std::make_shared<CudaBuffer>(buffer_, position_, size);
+ position_ += size;
+ return buffer;
+}
+
+// ----------------------------------------------------------------------
+// CudaBufferWriter
+
+class CudaBufferWriter::CudaBufferWriterImpl {
+ public:
+ explicit CudaBufferWriterImpl(const std::shared_ptr<CudaBuffer>& buffer)
+ : context_(buffer->context()),
+ buffer_(buffer),
+ buffer_size_(0),
+ buffer_position_(0) {
+ buffer_ = buffer;
+ ARROW_CHECK(buffer->is_mutable()) << "Must pass mutable buffer";
+ address_ = buffer->mutable_address();
+ size_ = buffer->size();
+ position_ = 0;
+ closed_ = false;
+ }
+
+#define CHECK_CLOSED() \
+ if (closed_) { \
+ return Status::Invalid("Operation on closed CudaBufferWriter"); \
+ }
+
+ Status Seek(int64_t position) {
+ CHECK_CLOSED();
+ if (position < 0 || position >= size_) {
+ return Status::IOError("position out of bounds");
+ }
+ position_ = position;
+ return Status::OK();
+ }
+
+ Status Close() {
+ if (!closed_) {
+ closed_ = true;
+ RETURN_NOT_OK(FlushInternal());
+ }
+ return Status::OK();
+ }
+
+ Status Flush() {
+ CHECK_CLOSED();
+ return FlushInternal();
+ }
+
+ Status FlushInternal() {
+ if (buffer_size_ > 0 && buffer_position_ > 0) {
+ // Only need to flush when the write has been buffered
+ RETURN_NOT_OK(context_->CopyHostToDevice(address_ + position_ - buffer_position_,
+ host_buffer_data_, buffer_position_));
+ buffer_position_ = 0;
+ }
+ return Status::OK();
+ }
+
+ bool closed() const { return closed_; }
+
+ Result<int64_t> Tell() const {
+ CHECK_CLOSED();
+ return position_;
+ }
+
+ Status Write(const void* data, int64_t nbytes) {
+ CHECK_CLOSED();
+ if (nbytes == 0) {
+ return Status::OK();
+ }
+
+ if (buffer_size_ > 0) {
+ if (nbytes + buffer_position_ >= buffer_size_) {
+ // Reach end of buffer, write everything
+ RETURN_NOT_OK(Flush());
+ RETURN_NOT_OK(context_->CopyHostToDevice(address_ + position_, data, nbytes));
+ } else {
+ // Write bytes to buffer
+ std::memcpy(host_buffer_data_ + buffer_position_, data, nbytes);
+ buffer_position_ += nbytes;
+ }
+ } else {
+ // Unbuffered write
+ RETURN_NOT_OK(context_->CopyHostToDevice(address_ + position_, data, nbytes));
+ }
+ position_ += nbytes;
+ return Status::OK();
+ }
+
+ Status WriteAt(int64_t position, const void* data, int64_t nbytes) {
+ std::lock_guard<std::mutex> guard(lock_);
+ CHECK_CLOSED();
+ RETURN_NOT_OK(Seek(position));
+ return Write(data, nbytes);
+ }
+
+ Status SetBufferSize(const int64_t buffer_size) {
+ CHECK_CLOSED();
+ if (buffer_position_ > 0) {
+ // Flush any buffered data
+ RETURN_NOT_OK(Flush());
+ }
+ ARROW_ASSIGN_OR_RAISE(
+ host_buffer_,
+ AllocateCudaHostBuffer(context_.get()->device_number(), buffer_size));
+ host_buffer_data_ = host_buffer_->mutable_data();
+ buffer_size_ = buffer_size;
+ return Status::OK();
+ }
+
+ int64_t buffer_size() const { return buffer_size_; }
+
+ int64_t buffer_position() const { return buffer_position_; }
+
+#undef CHECK_CLOSED
+
+ private:
+ std::shared_ptr<CudaContext> context_;
+ std::shared_ptr<CudaBuffer> buffer_;
+ std::mutex lock_;
+ uintptr_t address_;
+ int64_t size_;
+ int64_t position_;
+ bool closed_;
+
+ // Pinned host buffer for buffering writes on CPU before calling cudaMalloc
+ int64_t buffer_size_;
+ int64_t buffer_position_;
+ std::shared_ptr<CudaHostBuffer> host_buffer_;
+ uint8_t* host_buffer_data_;
+};
+
+CudaBufferWriter::CudaBufferWriter(const std::shared_ptr<CudaBuffer>& buffer) {
+ impl_.reset(new CudaBufferWriterImpl(buffer));
+}
+
+CudaBufferWriter::~CudaBufferWriter() {}
+
+Status CudaBufferWriter::Close() { return impl_->Close(); }
+
+bool CudaBufferWriter::closed() const { return impl_->closed(); }
+
+Status CudaBufferWriter::Flush() { return impl_->Flush(); }
+
+Status CudaBufferWriter::Seek(int64_t position) {
+ if (impl_->buffer_position() > 0) {
+ RETURN_NOT_OK(Flush());
+ }
+ return impl_->Seek(position);
+}
+
+Result<int64_t> CudaBufferWriter::Tell() const { return impl_->Tell(); }
+
+Status CudaBufferWriter::Write(const void* data, int64_t nbytes) {
+ return impl_->Write(data, nbytes);
+}
+
+Status CudaBufferWriter::WriteAt(int64_t position, const void* data, int64_t nbytes) {
+ return impl_->WriteAt(position, data, nbytes);
+}
+
+Status CudaBufferWriter::SetBufferSize(const int64_t buffer_size) {
+ return impl_->SetBufferSize(buffer_size);
+}
+
+int64_t CudaBufferWriter::buffer_size() const { return impl_->buffer_size(); }
+
+int64_t CudaBufferWriter::num_bytes_buffered() const { return impl_->buffer_position(); }
+
+// ----------------------------------------------------------------------
+
+Result<std::shared_ptr<CudaHostBuffer>> AllocateCudaHostBuffer(int device_number,
+ const int64_t size) {
+ ARROW_ASSIGN_OR_RAISE(auto manager, CudaDeviceManager::Instance());
+ return manager->AllocateHost(device_number, size);
+}
+
+Result<uintptr_t> GetDeviceAddress(const uint8_t* cpu_data,
+ const std::shared_ptr<CudaContext>& ctx) {
+ ContextSaver context_saver(*ctx);
+ CUdeviceptr ptr;
+ // XXX should we use cuPointerGetAttribute(CU_POINTER_ATTRIBUTE_DEVICE_POINTER)
+ // instead?
+ CU_RETURN_NOT_OK("cuMemHostGetDevicePointer",
+ cuMemHostGetDevicePointer(&ptr, const_cast<uint8_t*>(cpu_data), 0));
+ return static_cast<uintptr_t>(ptr);
+}
+
+Result<uint8_t*> GetHostAddress(uintptr_t device_ptr) {
+ void* ptr;
+ CU_RETURN_NOT_OK(
+ "cuPointerGetAttribute",
+ cuPointerGetAttribute(&ptr, CU_POINTER_ATTRIBUTE_HOST_POINTER, device_ptr));
+ return static_cast<uint8_t*>(ptr);
+}
+
+} // namespace cuda
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/gpu/cuda_memory.h b/src/arrow/cpp/src/arrow/gpu/cuda_memory.h
new file mode 100644
index 000000000..4efd38894
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/cuda_memory.h
@@ -0,0 +1,260 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/buffer.h"
+#include "arrow/io/concurrency.h"
+#include "arrow/type_fwd.h"
+
+namespace arrow {
+namespace cuda {
+
+class CudaContext;
+class CudaIpcMemHandle;
+
+/// \class CudaBuffer
+/// \brief An Arrow buffer located on a GPU device
+///
+/// Be careful using this in any Arrow code which may not be GPU-aware
+class ARROW_EXPORT CudaBuffer : public Buffer {
+ public:
+ // XXX deprecate?
+ CudaBuffer(uint8_t* data, int64_t size, const std::shared_ptr<CudaContext>& context,
+ bool own_data = false, bool is_ipc = false);
+
+ CudaBuffer(uintptr_t address, int64_t size, const std::shared_ptr<CudaContext>& context,
+ bool own_data = false, bool is_ipc = false);
+
+ CudaBuffer(const std::shared_ptr<CudaBuffer>& parent, const int64_t offset,
+ const int64_t size);
+
+ ~CudaBuffer();
+
+ /// \brief Convert back generic buffer into CudaBuffer
+ /// \param[in] buffer buffer to convert
+ /// \return CudaBuffer or Status
+ ///
+ /// \note This function returns an error if the buffer isn't backed
+ /// by GPU memory
+ static Result<std::shared_ptr<CudaBuffer>> FromBuffer(std::shared_ptr<Buffer> buffer);
+
+ /// \brief Copy memory from GPU device to CPU host
+ /// \param[in] position start position inside buffer to copy bytes from
+ /// \param[in] nbytes number of bytes to copy
+ /// \param[out] out start address of the host memory area to copy to
+ /// \return Status
+ Status CopyToHost(const int64_t position, const int64_t nbytes, void* out) const;
+
+ /// \brief Copy memory to device at position
+ /// \param[in] position start position to copy bytes to
+ /// \param[in] data the host data to copy
+ /// \param[in] nbytes number of bytes to copy
+ /// \return Status
+ Status CopyFromHost(const int64_t position, const void* data, int64_t nbytes);
+
+ /// \brief Copy memory from device to device at position
+ /// \param[in] position start position inside buffer to copy bytes to
+ /// \param[in] data start address of the device memory area to copy from
+ /// \param[in] nbytes number of bytes to copy
+ /// \return Status
+ ///
+ /// \note It is assumed that both source and destination device
+ /// memories have been allocated within the same context.
+ Status CopyFromDevice(const int64_t position, const void* data, int64_t nbytes);
+
+ /// \brief Copy memory from another device to device at position
+ /// \param[in] src_ctx context of the source device memory
+ /// \param[in] position start position inside buffer to copy bytes to
+ /// \param[in] data start address of the another device memory area to copy from
+ /// \param[in] nbytes number of bytes to copy
+ /// \return Status
+ Status CopyFromAnotherDevice(const std::shared_ptr<CudaContext>& src_ctx,
+ const int64_t position, const void* data, int64_t nbytes);
+
+ /// \brief Expose this device buffer as IPC memory which can be used in other processes
+ /// \return Handle or Status
+ ///
+ /// \note After calling this function, this device memory will not be freed
+ /// when the CudaBuffer is destructed
+ virtual Result<std::shared_ptr<CudaIpcMemHandle>> ExportForIpc();
+
+ const std::shared_ptr<CudaContext>& context() const { return context_; }
+
+ protected:
+ std::shared_ptr<CudaContext> context_;
+ bool own_data_;
+ bool is_ipc_;
+
+ virtual Status Close();
+};
+
+/// \class CudaHostBuffer
+/// \brief Device-accessible CPU memory created using cudaHostAlloc
+class ARROW_EXPORT CudaHostBuffer : public MutableBuffer {
+ public:
+ using MutableBuffer::MutableBuffer;
+ ~CudaHostBuffer();
+
+ /// \brief Return a device address the GPU can read this memory from.
+ Result<uintptr_t> GetDeviceAddress(const std::shared_ptr<CudaContext>& ctx);
+};
+
+/// \class CudaIpcHandle
+/// \brief A container for a CUDA IPC handle
+class ARROW_EXPORT CudaIpcMemHandle {
+ public:
+ ~CudaIpcMemHandle();
+
+ /// \brief Create CudaIpcMemHandle from opaque buffer (e.g. from another process)
+ /// \param[in] opaque_handle a CUipcMemHandle as a const void*
+ /// \return Handle or Status
+ static Result<std::shared_ptr<CudaIpcMemHandle>> FromBuffer(const void* opaque_handle);
+
+ /// \brief Write CudaIpcMemHandle to a Buffer
+ /// \param[in] pool a MemoryPool to allocate memory from
+ /// \return Buffer or Status
+ Result<std::shared_ptr<Buffer>> Serialize(
+ MemoryPool* pool = default_memory_pool()) const;
+
+ private:
+ explicit CudaIpcMemHandle(const void* handle);
+ CudaIpcMemHandle(int64_t memory_size, const void* cu_handle);
+
+ struct CudaIpcMemHandleImpl;
+ std::unique_ptr<CudaIpcMemHandleImpl> impl_;
+
+ const void* handle() const;
+ int64_t memory_size() const;
+
+ friend CudaBuffer;
+ friend CudaContext;
+};
+
+/// \class CudaBufferReader
+/// \brief File interface for zero-copy read from CUDA buffers
+///
+/// CAUTION: reading to a Buffer returns a Buffer pointing to device memory.
+/// It will generally not be compatible with Arrow code expecting a buffer
+/// pointing to CPU memory.
+/// Reading to a raw pointer, though, copies device memory into the host
+/// memory pointed to.
+class ARROW_EXPORT CudaBufferReader
+ : public ::arrow::io::internal::RandomAccessFileConcurrencyWrapper<CudaBufferReader> {
+ public:
+ explicit CudaBufferReader(const std::shared_ptr<Buffer>& buffer);
+
+ bool closed() const override;
+
+ bool supports_zero_copy() const override;
+
+ std::shared_ptr<CudaBuffer> buffer() const { return buffer_; }
+
+ protected:
+ friend ::arrow::io::internal::RandomAccessFileConcurrencyWrapper<CudaBufferReader>;
+
+ Status DoClose();
+
+ Result<int64_t> DoRead(int64_t nbytes, void* buffer);
+ Result<std::shared_ptr<Buffer>> DoRead(int64_t nbytes);
+ Result<int64_t> DoReadAt(int64_t position, int64_t nbytes, void* out);
+ Result<std::shared_ptr<Buffer>> DoReadAt(int64_t position, int64_t nbytes);
+
+ Result<int64_t> DoTell() const;
+ Status DoSeek(int64_t position);
+ Result<int64_t> DoGetSize();
+
+ Status CheckClosed() const {
+ if (!is_open_) {
+ return Status::Invalid("Operation forbidden on closed CudaBufferReader");
+ }
+ return Status::OK();
+ }
+
+ std::shared_ptr<CudaBuffer> buffer_;
+ std::shared_ptr<CudaContext> context_;
+ const uintptr_t address_;
+ int64_t size_;
+ int64_t position_;
+ bool is_open_;
+};
+
+/// \class CudaBufferWriter
+/// \brief File interface for writing to CUDA buffers, with optional buffering
+class ARROW_EXPORT CudaBufferWriter : public io::WritableFile {
+ public:
+ explicit CudaBufferWriter(const std::shared_ptr<CudaBuffer>& buffer);
+ ~CudaBufferWriter() override;
+
+ /// \brief Close writer and flush buffered bytes to GPU
+ Status Close() override;
+
+ bool closed() const override;
+
+ /// \brief Flush buffered bytes to GPU
+ Status Flush() override;
+
+ Status Seek(int64_t position) override;
+
+ Status Write(const void* data, int64_t nbytes) override;
+
+ Status WriteAt(int64_t position, const void* data, int64_t nbytes) override;
+
+ Result<int64_t> Tell() const override;
+
+ /// \brief Set CPU buffer size to limit calls to cudaMemcpy
+ /// \param[in] buffer_size the size of CPU buffer to allocate
+ /// \return Status
+ ///
+ /// By default writes are unbuffered
+ Status SetBufferSize(const int64_t buffer_size);
+
+ /// \brief Returns size of host (CPU) buffer, 0 for unbuffered
+ int64_t buffer_size() const;
+
+ /// \brief Returns number of bytes buffered on host
+ int64_t num_bytes_buffered() const;
+
+ private:
+ class CudaBufferWriterImpl;
+ std::unique_ptr<CudaBufferWriterImpl> impl_;
+};
+
+/// \brief Allocate CUDA-accessible memory on CPU host
+///
+/// The GPU will benefit from fast access to this CPU-located buffer,
+/// including fast memory copy.
+///
+/// \param[in] device_number device to expose host memory
+/// \param[in] size number of bytes
+/// \return Host buffer or Status
+ARROW_EXPORT
+Result<std::shared_ptr<CudaHostBuffer>> AllocateCudaHostBuffer(int device_number,
+ const int64_t size);
+
+/// Low-level: get a device address through which the CPU data be accessed.
+Result<uintptr_t> GetDeviceAddress(const uint8_t* cpu_data,
+ const std::shared_ptr<CudaContext>& ctx);
+
+/// Low-level: get a CPU address through which the device data be accessed.
+Result<uint8_t*> GetHostAddress(uintptr_t device_ptr);
+
+} // namespace cuda
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/gpu/cuda_test.cc b/src/arrow/cpp/src/arrow/gpu/cuda_test.cc
new file mode 100644
index 000000000..08cf44284
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/cuda_test.cc
@@ -0,0 +1,626 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <limits>
+#include <string>
+
+#include <cuda.h>
+
+#include "gtest/gtest.h"
+
+#include "arrow/io/memory.h"
+#include "arrow/ipc/api.h"
+#include "arrow/ipc/dictionary.h"
+#include "arrow/ipc/test_common.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+
+#include "arrow/gpu/cuda_api.h"
+#include "arrow/gpu/cuda_internal.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace cuda {
+
+using internal::StatusFromCuda;
+
+#define ASSERT_CUDA_OK(expr) ASSERT_OK(::arrow::cuda::internal::StatusFromCuda((expr)))
+
+constexpr int kGpuNumber = 0;
+// Needs a second GPU installed
+constexpr int kOtherGpuNumber = 1;
+
+template <typename Expected>
+void AssertCudaBufferEquals(const CudaBuffer& buffer, Expected&& expected) {
+ ASSERT_OK_AND_ASSIGN(auto result, AllocateBuffer(buffer.size()));
+ ASSERT_OK(buffer.CopyToHost(0, buffer.size(), result->mutable_data()));
+ AssertBufferEqual(*result, expected);
+}
+
+template <typename Expected>
+void AssertCudaBufferEquals(const Buffer& buffer, Expected&& expected) {
+ ASSERT_TRUE(IsCudaDevice(*buffer.device()));
+ AssertCudaBufferEquals(checked_cast<const CudaBuffer&>(buffer),
+ std::forward<Expected>(expected));
+}
+
+class TestCudaBase : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK_AND_ASSIGN(manager_, CudaDeviceManager::Instance());
+ ASSERT_OK_AND_ASSIGN(device_, manager_->GetDevice(kGpuNumber));
+ // ASSERT_OK(device_->GetContext(kGpuNumber, &context_));
+ ASSERT_OK_AND_ASSIGN(context_, device_->GetContext());
+ ASSERT_OK_AND_ASSIGN(mm_, AsCudaMemoryManager(device_->default_memory_manager()));
+ cpu_device_ = CPUDevice::Instance();
+ cpu_mm_ = cpu_device_->default_memory_manager();
+ }
+
+ void TearDown() {
+ for (auto cu_context : non_primary_contexts_) {
+ ASSERT_CUDA_OK(cuCtxDestroy(cu_context));
+ }
+ }
+
+ Result<CUcontext> NonPrimaryRawContext() {
+ CUcontext ctx;
+ RETURN_NOT_OK(StatusFromCuda(cuCtxCreate(&ctx, /*flags=*/0, device_->handle())));
+ non_primary_contexts_.push_back(ctx);
+ return ctx;
+ }
+
+ Result<std::shared_ptr<CudaContext>> NonPrimaryContext() {
+ ARROW_ASSIGN_OR_RAISE(auto cuctx, NonPrimaryRawContext());
+ return device_->GetSharedContext(cuctx);
+ }
+
+ // Returns nullptr if kOtherGpuNumber does not correspond to an installed GPU
+ Result<std::shared_ptr<CudaDevice>> OtherGpuDevice() {
+ auto maybe_device = CudaDevice::Make(kOtherGpuNumber);
+ if (maybe_device.status().IsInvalid()) {
+ return nullptr;
+ }
+ return maybe_device;
+ }
+
+ protected:
+ CudaDeviceManager* manager_;
+ std::shared_ptr<CudaDevice> device_;
+ std::shared_ptr<CudaMemoryManager> mm_;
+ std::shared_ptr<CudaContext> context_;
+ std::shared_ptr<Device> cpu_device_;
+ std::shared_ptr<MemoryManager> cpu_mm_;
+ std::vector<CUcontext> non_primary_contexts_;
+};
+
+// ------------------------------------------------------------------------
+// Test CudaDevice
+
+class TestCudaDevice : public TestCudaBase {
+ public:
+ void SetUp() { TestCudaBase::SetUp(); }
+};
+
+TEST_F(TestCudaDevice, Basics) {
+ ASSERT_FALSE(device_->is_cpu());
+ ASSERT_TRUE(IsCudaDevice(*device_));
+ ASSERT_EQ(device_->device_number(), kGpuNumber);
+ ASSERT_GE(device_->total_memory(), 1 << 20);
+ ASSERT_NE(device_->device_name(), "");
+ ASSERT_NE(device_->ToString(), "");
+
+ ASSERT_OK_AND_ASSIGN(auto other_device, CudaDevice::Make(kGpuNumber));
+ ASSERT_FALSE(other_device->is_cpu());
+ ASSERT_TRUE(IsCudaDevice(*other_device));
+ ASSERT_EQ(other_device->device_number(), kGpuNumber);
+ ASSERT_EQ(other_device->total_memory(), device_->total_memory());
+ ASSERT_EQ(other_device->handle(), device_->handle());
+ ASSERT_EQ(other_device->device_name(), device_->device_name());
+ ASSERT_EQ(*other_device, *device_);
+
+ ASSERT_FALSE(IsCudaDevice(*cpu_device_));
+
+ // Try another device if possible
+ ASSERT_OK_AND_ASSIGN(other_device, OtherGpuDevice());
+ if (other_device != nullptr) {
+ ASSERT_FALSE(other_device->is_cpu());
+ ASSERT_EQ(other_device->device_number(), kOtherGpuNumber);
+ ASSERT_NE(*other_device, *device_);
+ ASSERT_NE(other_device->handle(), device_->handle());
+ ASSERT_NE(other_device->ToString(), device_->ToString());
+ }
+
+ ASSERT_RAISES(Invalid, CudaDevice::Make(-1));
+ ASSERT_RAISES(Invalid, CudaDevice::Make(99));
+}
+
+TEST_F(TestCudaDevice, Copy) {
+ auto cpu_buffer = Buffer::FromString("some data");
+
+ // CPU -> device
+ ASSERT_OK_AND_ASSIGN(auto other_buffer, Buffer::Copy(cpu_buffer, mm_));
+ ASSERT_EQ(other_buffer->device(), device_);
+ AssertCudaBufferEquals(*other_buffer, "some data");
+
+ // device -> CPU
+ ASSERT_OK_AND_ASSIGN(cpu_buffer, Buffer::Copy(other_buffer, cpu_mm_));
+ ASSERT_TRUE(cpu_buffer->device()->is_cpu());
+ AssertBufferEqual(*cpu_buffer, "some data");
+
+ // device -> device
+ const auto old_address = other_buffer->address();
+ ASSERT_OK_AND_ASSIGN(other_buffer, Buffer::Copy(other_buffer, mm_));
+ ASSERT_EQ(other_buffer->device(), device_);
+ ASSERT_NE(other_buffer->address(), old_address);
+ AssertCudaBufferEquals(*other_buffer, "some data");
+
+ // device (other context) -> device
+ ASSERT_OK_AND_ASSIGN(auto other_context, NonPrimaryContext());
+ ASSERT_OK_AND_ASSIGN(auto cuda_buffer, other_context->Allocate(9));
+ ASSERT_OK(cuda_buffer->CopyFromHost(0, "some data", 9));
+ ASSERT_OK_AND_ASSIGN(other_buffer, Buffer::Copy(cuda_buffer, mm_));
+ ASSERT_EQ(other_buffer->device(), device_);
+ AssertCudaBufferEquals(*other_buffer, "some data");
+ auto other_handle = cuda_buffer->context()->handle();
+ ASSERT_OK_AND_ASSIGN(cuda_buffer, CudaBuffer::FromBuffer(other_buffer));
+ ASSERT_NE(cuda_buffer->context()->handle(), other_handle);
+
+ // device -> other device
+ ASSERT_OK_AND_ASSIGN(auto other_device, OtherGpuDevice());
+ if (other_device != nullptr) {
+ ASSERT_OK_AND_ASSIGN(
+ other_buffer, Buffer::Copy(cuda_buffer, other_device->default_memory_manager()));
+ ASSERT_EQ(other_buffer->device(), other_device);
+ AssertCudaBufferEquals(*other_buffer, "some data");
+ }
+}
+
+// ------------------------------------------------------------------------
+// Test CudaContext
+
+class TestCudaContext : public TestCudaBase {
+ public:
+ void SetUp() { TestCudaBase::SetUp(); }
+};
+
+TEST_F(TestCudaContext, Basics) { ASSERT_EQ(*context_->device(), *device_); }
+
+TEST_F(TestCudaContext, NonPrimaryContext) {
+ ASSERT_OK_AND_ASSIGN(auto other_context, NonPrimaryContext());
+ ASSERT_EQ(*other_context->device(), *device_);
+ ASSERT_NE(other_context->handle(), context_->handle());
+}
+
+TEST_F(TestCudaContext, GetDeviceAddress) {
+ const int64_t kSize = 100;
+ ASSERT_OK_AND_ASSIGN(auto buffer, context_->Allocate(kSize));
+ // GetDeviceAddress() is idempotent on device addresses
+ ASSERT_OK_AND_ASSIGN(auto devptr, context_->GetDeviceAddress(buffer->address()));
+ ASSERT_EQ(devptr, buffer->address());
+}
+
+// ------------------------------------------------------------------------
+// Test CudaBuffer
+
+class TestCudaBuffer : public TestCudaBase {
+ public:
+ void SetUp() { TestCudaBase::SetUp(); }
+};
+
+TEST_F(TestCudaBuffer, Allocate) {
+ const int64_t kSize = 100;
+ std::shared_ptr<CudaBuffer> buffer;
+ ASSERT_OK_AND_ASSIGN(buffer, context_->Allocate(kSize));
+ ASSERT_EQ(buffer->device(), context_->device());
+ ASSERT_EQ(kSize, buffer->size());
+ ASSERT_EQ(kSize, context_->bytes_allocated());
+ ASSERT_FALSE(buffer->is_cpu());
+}
+
+TEST_F(TestCudaBuffer, CopyFromHost) {
+ const int64_t kSize = 1000;
+ std::shared_ptr<CudaBuffer> device_buffer;
+ ASSERT_OK_AND_ASSIGN(device_buffer, context_->Allocate(kSize));
+
+ std::shared_ptr<ResizableBuffer> host_buffer;
+ ASSERT_OK(MakeRandomByteBuffer(kSize, default_memory_pool(), &host_buffer));
+
+ ASSERT_OK(device_buffer->CopyFromHost(0, host_buffer->data(), 500));
+ ASSERT_OK(device_buffer->CopyFromHost(500, host_buffer->data() + 500, kSize - 500));
+
+ AssertCudaBufferEquals(*device_buffer, *host_buffer);
+}
+
+TEST_F(TestCudaBuffer, FromBuffer) {
+ const int64_t kSize = 1000;
+ // Initialize device buffer with random data
+ std::shared_ptr<ResizableBuffer> host_buffer;
+ std::shared_ptr<CudaBuffer> device_buffer;
+ ASSERT_OK_AND_ASSIGN(device_buffer, context_->Allocate(kSize));
+ ASSERT_OK(MakeRandomByteBuffer(kSize, default_memory_pool(), &host_buffer));
+ ASSERT_OK(device_buffer->CopyFromHost(0, host_buffer->data(), 1000));
+ // Sanity check
+ AssertCudaBufferEquals(*device_buffer, *host_buffer);
+
+ // Get generic Buffer from device buffer
+ std::shared_ptr<Buffer> buffer;
+ std::shared_ptr<CudaBuffer> result;
+ buffer = std::static_pointer_cast<Buffer>(device_buffer);
+ ASSERT_OK_AND_ASSIGN(result, CudaBuffer::FromBuffer(buffer));
+ ASSERT_EQ(result->size(), kSize);
+ ASSERT_EQ(result->is_mutable(), true);
+ ASSERT_EQ(result->address(), buffer->address());
+ AssertCudaBufferEquals(*result, *host_buffer);
+
+ buffer = SliceBuffer(device_buffer, 0, kSize);
+ ASSERT_OK_AND_ASSIGN(result, CudaBuffer::FromBuffer(buffer));
+ ASSERT_EQ(result->size(), kSize);
+ ASSERT_EQ(result->is_mutable(), false);
+ ASSERT_EQ(result->address(), buffer->address());
+ AssertCudaBufferEquals(*result, *host_buffer);
+
+ buffer = SliceMutableBuffer(device_buffer, 0, kSize);
+ ASSERT_OK_AND_ASSIGN(result, CudaBuffer::FromBuffer(buffer));
+ ASSERT_EQ(result->size(), kSize);
+ ASSERT_EQ(result->is_mutable(), true);
+ ASSERT_EQ(result->address(), buffer->address());
+ AssertCudaBufferEquals(*result, *host_buffer);
+
+ buffer = SliceMutableBuffer(device_buffer, 3, kSize - 10);
+ buffer = SliceMutableBuffer(buffer, 8, kSize - 20);
+ ASSERT_OK_AND_ASSIGN(result, CudaBuffer::FromBuffer(buffer));
+ ASSERT_EQ(result->size(), kSize - 20);
+ ASSERT_EQ(result->is_mutable(), true);
+ ASSERT_EQ(result->address(), buffer->address());
+ AssertCudaBufferEquals(*result, *SliceBuffer(host_buffer, 11, kSize - 20));
+}
+
+// IPC only supported on Linux
+#if defined(__linux)
+
+TEST_F(TestCudaBuffer, DISABLED_ExportForIpc) {
+ // For this test to work, a second process needs to be spawned
+ const int64_t kSize = 1000;
+ std::shared_ptr<CudaBuffer> device_buffer;
+ ASSERT_OK_AND_ASSIGN(device_buffer, context_->Allocate(kSize));
+
+ std::shared_ptr<ResizableBuffer> host_buffer;
+ ASSERT_OK(MakeRandomByteBuffer(kSize, default_memory_pool(), &host_buffer));
+ ASSERT_OK(device_buffer->CopyFromHost(0, host_buffer->data(), kSize));
+
+ // Export for IPC and serialize
+ std::shared_ptr<CudaIpcMemHandle> ipc_handle;
+ ASSERT_OK_AND_ASSIGN(ipc_handle, device_buffer->ExportForIpc());
+
+ std::shared_ptr<Buffer> serialized_handle;
+ ASSERT_OK_AND_ASSIGN(serialized_handle, ipc_handle->Serialize());
+
+ // Deserialize IPC handle and open
+ std::shared_ptr<CudaIpcMemHandle> ipc_handle2;
+ ASSERT_OK_AND_ASSIGN(ipc_handle2,
+ CudaIpcMemHandle::FromBuffer(serialized_handle->data()));
+
+ std::shared_ptr<CudaBuffer> ipc_buffer;
+ ASSERT_OK_AND_ASSIGN(ipc_buffer, context_->OpenIpcBuffer(*ipc_handle2));
+
+ ASSERT_EQ(kSize, ipc_buffer->size());
+
+ ASSERT_OK_AND_ASSIGN(auto ipc_data, AllocateBuffer(kSize));
+ ASSERT_OK(ipc_buffer->CopyToHost(0, kSize, ipc_data->mutable_data()));
+ ASSERT_EQ(0, std::memcmp(ipc_buffer->data(), host_buffer->data(), kSize));
+}
+
+#endif
+
+// ------------------------------------------------------------------------
+// Test CudaHostBuffer
+
+class TestCudaHostBuffer : public TestCudaBase {
+ public:
+};
+
+TEST_F(TestCudaHostBuffer, AllocateGlobal) {
+ // Allocation using the global AllocateCudaHostBuffer() function
+ std::shared_ptr<CudaHostBuffer> host_buffer;
+ ASSERT_OK_AND_ASSIGN(host_buffer, AllocateCudaHostBuffer(kGpuNumber, 1024));
+
+ ASSERT_TRUE(host_buffer->is_cpu());
+ ASSERT_EQ(host_buffer->memory_manager(), cpu_mm_);
+
+ ASSERT_OK_AND_ASSIGN(auto device_address, host_buffer->GetDeviceAddress(context_));
+ ASSERT_NE(device_address, 0);
+ ASSERT_OK_AND_ASSIGN(auto host_address, GetHostAddress(device_address));
+ ASSERT_EQ(host_address, host_buffer->data());
+}
+
+TEST_F(TestCudaHostBuffer, ViewOnDevice) {
+ ASSERT_OK_AND_ASSIGN(auto host_buffer, device_->AllocateHostBuffer(1024));
+
+ ASSERT_TRUE(host_buffer->is_cpu());
+ ASSERT_EQ(host_buffer->memory_manager(), cpu_mm_);
+
+ // Try to view the host buffer on the device. This should correspond to
+ // GetDeviceAddress() in the previous test.
+ ASSERT_OK_AND_ASSIGN(auto device_buffer, Buffer::View(host_buffer, mm_));
+ ASSERT_FALSE(device_buffer->is_cpu());
+ ASSERT_EQ(device_buffer->memory_manager(), mm_);
+ ASSERT_NE(device_buffer->address(), 0);
+ ASSERT_EQ(device_buffer->size(), host_buffer->size());
+ ASSERT_EQ(device_buffer->parent(), host_buffer);
+
+ // View back the device buffer on the CPU. This should roundtrip.
+ ASSERT_OK_AND_ASSIGN(auto buffer, Buffer::View(device_buffer, cpu_mm_));
+ ASSERT_TRUE(buffer->is_cpu());
+ ASSERT_EQ(buffer->memory_manager(), cpu_mm_);
+ ASSERT_EQ(buffer->address(), host_buffer->address());
+ ASSERT_EQ(buffer->size(), host_buffer->size());
+ ASSERT_EQ(buffer->parent(), device_buffer);
+}
+
+// ------------------------------------------------------------------------
+// Test CudaBufferWriter
+
+class TestCudaBufferWriter : public TestCudaBase {
+ public:
+ void SetUp() { TestCudaBase::SetUp(); }
+
+ void Allocate(const int64_t size) {
+ ASSERT_OK_AND_ASSIGN(device_buffer_, context_->Allocate(size));
+ writer_.reset(new CudaBufferWriter(device_buffer_));
+ }
+
+ void TestWrites(const int64_t total_bytes, const int64_t chunksize,
+ const int64_t buffer_size = 0) {
+ std::shared_ptr<ResizableBuffer> buffer;
+ ASSERT_OK(MakeRandomByteBuffer(total_bytes, default_memory_pool(), &buffer));
+
+ if (buffer_size > 0) {
+ ASSERT_OK(writer_->SetBufferSize(buffer_size));
+ }
+
+ ASSERT_OK_AND_EQ(0, writer_->Tell());
+
+ const uint8_t* host_data = buffer->data();
+ ASSERT_OK(writer_->Write(host_data, chunksize));
+ ASSERT_OK_AND_EQ(chunksize, writer_->Tell());
+
+ ASSERT_OK(writer_->Seek(0));
+ ASSERT_OK_AND_EQ(0, writer_->Tell());
+
+ int64_t position = 0;
+ while (position < total_bytes) {
+ int64_t bytes_to_write = std::min(chunksize, total_bytes - position);
+ ASSERT_OK(writer_->Write(host_data + position, bytes_to_write));
+ position += bytes_to_write;
+ }
+
+ ASSERT_OK(writer_->Flush());
+
+ AssertCudaBufferEquals(*device_buffer_, *buffer);
+ }
+
+ protected:
+ std::shared_ptr<CudaBuffer> device_buffer_;
+ std::unique_ptr<CudaBufferWriter> writer_;
+};
+
+TEST_F(TestCudaBufferWriter, UnbufferedWrites) {
+ const int64_t kTotalSize = 1 << 16;
+ Allocate(kTotalSize);
+ TestWrites(kTotalSize, 1000);
+}
+
+TEST_F(TestCudaBufferWriter, BufferedWrites) {
+ const int64_t kTotalSize = 1 << 16;
+ Allocate(kTotalSize);
+ TestWrites(kTotalSize, 1000, 1 << 12);
+}
+
+TEST_F(TestCudaBufferWriter, EdgeCases) {
+ Allocate(1000);
+
+ std::shared_ptr<ResizableBuffer> buffer;
+ ASSERT_OK(MakeRandomByteBuffer(1000, default_memory_pool(), &buffer));
+ const uint8_t* host_data = buffer->data();
+
+ ASSERT_EQ(0, writer_->buffer_size());
+ ASSERT_OK(writer_->SetBufferSize(100));
+ ASSERT_EQ(100, writer_->buffer_size());
+
+ // Write 0 bytes
+ ASSERT_OK(writer_->Write(host_data, 0));
+ ASSERT_OK_AND_EQ(0, writer_->Tell());
+
+ // Write some data, then change buffer size
+ ASSERT_OK(writer_->Write(host_data, 10));
+ ASSERT_OK(writer_->SetBufferSize(200));
+ ASSERT_EQ(200, writer_->buffer_size());
+
+ ASSERT_EQ(0, writer_->num_bytes_buffered());
+
+ // Write more than buffer size
+ ASSERT_OK(writer_->Write(host_data + 10, 300));
+ ASSERT_EQ(0, writer_->num_bytes_buffered());
+
+ // Write exactly buffer size
+ ASSERT_OK(writer_->Write(host_data + 310, 200));
+ ASSERT_EQ(0, writer_->num_bytes_buffered());
+
+ // Write rest of bytes
+ ASSERT_OK(writer_->Write(host_data + 510, 390));
+ ASSERT_OK(writer_->Write(host_data + 900, 100));
+
+ // Close flushes
+ ASSERT_OK(writer_->Close());
+
+ // Check that everything was written
+ AssertCudaBufferEquals(*device_buffer_, Buffer(host_data, 1000));
+}
+
+// ------------------------------------------------------------------------
+// Test CudaBufferReader
+
+class TestCudaBufferReader : public TestCudaBase {
+ public:
+ void SetUp() { TestCudaBase::SetUp(); }
+};
+
+TEST_F(TestCudaBufferReader, Basics) {
+ std::shared_ptr<CudaBuffer> device_buffer;
+
+ const int64_t size = 1000;
+ ASSERT_OK_AND_ASSIGN(device_buffer, context_->Allocate(size));
+
+ std::shared_ptr<ResizableBuffer> buffer;
+ ASSERT_OK(MakeRandomByteBuffer(1000, default_memory_pool(), &buffer));
+ const uint8_t* host_data = buffer->data();
+
+ ASSERT_OK(device_buffer->CopyFromHost(0, host_data, 1000));
+
+ CudaBufferReader reader(device_buffer);
+
+ uint8_t stack_buffer[100] = {0};
+ ASSERT_OK(reader.Seek(950));
+
+ ASSERT_OK_AND_EQ(950, reader.Tell());
+
+ // Read() to host memory
+ ASSERT_OK_AND_EQ(50, reader.Read(100, stack_buffer));
+ ASSERT_EQ(0, std::memcmp(stack_buffer, host_data + 950, 50));
+ ASSERT_OK_AND_EQ(1000, reader.Tell());
+
+ // ReadAt() to host memory
+ ASSERT_OK_AND_EQ(45, reader.ReadAt(123, 45, stack_buffer));
+ ASSERT_EQ(0, std::memcmp(stack_buffer, host_data + 123, 45));
+ ASSERT_OK_AND_EQ(1000, reader.Tell());
+
+ // Read() to device buffer
+ ASSERT_OK(reader.Seek(925));
+ ASSERT_OK_AND_ASSIGN(auto tmp, reader.Read(100));
+ ASSERT_EQ(75, tmp->size());
+ ASSERT_FALSE(tmp->is_cpu());
+ ASSERT_EQ(*tmp->device(), *device_);
+ ASSERT_OK_AND_EQ(1000, reader.Tell());
+
+ ASSERT_OK(std::dynamic_pointer_cast<CudaBuffer>(tmp)->CopyToHost(0, tmp->size(),
+ stack_buffer));
+ ASSERT_EQ(0, std::memcmp(stack_buffer, host_data + 925, tmp->size()));
+
+ // ReadAt() to device buffer
+ ASSERT_OK(reader.Seek(42));
+ ASSERT_OK_AND_ASSIGN(tmp, reader.ReadAt(980, 30));
+ ASSERT_EQ(20, tmp->size());
+ ASSERT_FALSE(tmp->is_cpu());
+ ASSERT_EQ(*tmp->device(), *device_);
+ ASSERT_OK_AND_EQ(42, reader.Tell());
+
+ ASSERT_OK(std::dynamic_pointer_cast<CudaBuffer>(tmp)->CopyToHost(0, tmp->size(),
+ stack_buffer));
+ ASSERT_EQ(0, std::memcmp(stack_buffer, host_data + 980, tmp->size()));
+}
+
+TEST_F(TestCudaBufferReader, WillNeed) {
+ std::shared_ptr<CudaBuffer> device_buffer;
+
+ const int64_t size = 1000;
+ ASSERT_OK_AND_ASSIGN(device_buffer, context_->Allocate(size));
+
+ CudaBufferReader reader(device_buffer);
+
+ ASSERT_OK(reader.WillNeed({{0, size}}));
+}
+
+// ------------------------------------------------------------------------
+// Test Cuda IPC
+
+class TestCudaArrowIpc : public TestCudaBase {
+ public:
+ void SetUp() {
+ TestCudaBase::SetUp();
+ pool_ = default_memory_pool();
+ }
+
+ protected:
+ MemoryPool* pool_;
+};
+
+TEST_F(TestCudaArrowIpc, BasicWriteRead) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch));
+
+ std::shared_ptr<CudaBuffer> device_serialized;
+ ASSERT_OK_AND_ASSIGN(device_serialized, SerializeRecordBatch(*batch, context_.get()));
+
+ // Test that ReadRecordBatch works properly
+ ipc::DictionaryMemo unused_memo;
+ std::shared_ptr<RecordBatch> device_batch;
+ ASSERT_OK_AND_ASSIGN(device_batch,
+ ReadRecordBatch(batch->schema(), &unused_memo, device_serialized));
+
+ ASSERT_OK(device_batch->Validate());
+
+ // Copy data from device, read batch, and compare
+ int64_t size = device_serialized->size();
+ ASSERT_OK_AND_ASSIGN(auto host_buffer, AllocateBuffer(size, pool_));
+ ASSERT_OK(device_serialized->CopyToHost(0, size, host_buffer->mutable_data()));
+
+ std::shared_ptr<RecordBatch> cpu_batch;
+ io::BufferReader cpu_reader(std::move(host_buffer));
+ ASSERT_OK_AND_ASSIGN(
+ cpu_batch, ipc::ReadRecordBatch(batch->schema(), &unused_memo,
+ ipc::IpcReadOptions::Defaults(), &cpu_reader));
+
+ CompareBatch(*batch, *cpu_batch);
+}
+
+TEST_F(TestCudaArrowIpc, DictionaryWriteRead) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(ipc::test::MakeDictionary(&batch));
+
+ ipc::DictionaryMemo dictionary_memo;
+ ASSERT_OK(ipc::internal::CollectDictionaries(*batch, &dictionary_memo));
+
+ std::shared_ptr<CudaBuffer> device_serialized;
+ ASSERT_OK_AND_ASSIGN(device_serialized, SerializeRecordBatch(*batch, context_.get()));
+
+ // Test that ReadRecordBatch works properly
+ std::shared_ptr<RecordBatch> device_batch;
+ ASSERT_OK_AND_ASSIGN(device_batch, ReadRecordBatch(batch->schema(), &dictionary_memo,
+ device_serialized));
+
+ // Copy data from device, read batch, and compare
+ int64_t size = device_serialized->size();
+ ASSERT_OK_AND_ASSIGN(auto host_buffer, AllocateBuffer(size, pool_));
+ ASSERT_OK(device_serialized->CopyToHost(0, size, host_buffer->mutable_data()));
+
+ std::shared_ptr<RecordBatch> cpu_batch;
+ io::BufferReader cpu_reader(std::move(host_buffer));
+ ASSERT_OK_AND_ASSIGN(
+ cpu_batch, ipc::ReadRecordBatch(batch->schema(), &dictionary_memo,
+ ipc::IpcReadOptions::Defaults(), &cpu_reader));
+
+ CompareBatch(*batch, *cpu_batch);
+}
+
+} // namespace cuda
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/gpu/cuda_version.h.in b/src/arrow/cpp/src/arrow/gpu/cuda_version.h.in
new file mode 100644
index 000000000..bc687685d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/gpu/cuda_version.h.in
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef ARROW_GPU_CUDA_VERSION_H
+#define ARROW_GPU_CUDA_VERSION_H
+
+// Set the CUDA version used to build the library
+#define ARROW_CUDA_ABI_VERSION_MAJOR @CUDA_VERSION_MAJOR@
+#define ARROW_CUDA_ABI_VERSION_MINOR @CUDA_VERSION_MINOR@
+
+#endif // ARROW_GPU_CUDA_VERSION_H
diff --git a/src/arrow/cpp/src/arrow/io/CMakeLists.txt b/src/arrow/cpp/src/arrow/io/CMakeLists.txt
new file mode 100644
index 000000000..1669a9ba6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/CMakeLists.txt
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# ----------------------------------------------------------------------
+# arrow_io : Arrow IO interfaces
+
+add_arrow_test(buffered_test PREFIX "arrow-io")
+add_arrow_test(compressed_test PREFIX "arrow-io")
+add_arrow_test(file_test PREFIX "arrow-io")
+
+if(ARROW_HDFS)
+ add_arrow_test(hdfs_test NO_VALGRIND PREFIX "arrow-io")
+endif()
+
+add_arrow_test(memory_test PREFIX "arrow-io")
+
+add_arrow_benchmark(file_benchmark PREFIX "arrow-io")
+
+if(NOT (${ARROW_SIMD_LEVEL} STREQUAL "NONE"))
+ # This benchmark either requires SSE4.2 or ARMV8 SIMD to be enabled
+ add_arrow_benchmark(memory_benchmark PREFIX "arrow-io")
+endif()
+
+# Headers: top level
+arrow_install_all_headers("arrow/io")
diff --git a/src/arrow/cpp/src/arrow/io/api.h b/src/arrow/cpp/src/arrow/io/api.h
new file mode 100644
index 000000000..d55b2c2d5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/api.h
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/io/buffered.h"
+#include "arrow/io/compressed.h"
+#include "arrow/io/file.h"
+#include "arrow/io/hdfs.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/io/memory.h"
diff --git a/src/arrow/cpp/src/arrow/io/buffered.cc b/src/arrow/cpp/src/arrow/io/buffered.cc
new file mode 100644
index 000000000..7804c130c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/buffered.cc
@@ -0,0 +1,489 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/io/buffered.h"
+
+#include <algorithm>
+#include <cstring>
+#include <memory>
+#include <mutex>
+#include <utility>
+
+#include "arrow/buffer.h"
+#include "arrow/io/util_internal.h"
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace io {
+
+// ----------------------------------------------------------------------
+// BufferedOutputStream implementation
+
+class BufferedBase {
+ public:
+ explicit BufferedBase(MemoryPool* pool)
+ : pool_(pool),
+ is_open_(true),
+ buffer_data_(nullptr),
+ buffer_pos_(0),
+ buffer_size_(0),
+ raw_pos_(-1) {}
+
+ bool closed() const {
+ std::lock_guard<std::mutex> guard(lock_);
+ return !is_open_;
+ }
+
+ Status ResetBuffer() {
+ if (!buffer_) {
+ // On first invocation, or if the buffer has been released, we allocate a
+ // new buffer
+ ARROW_ASSIGN_OR_RAISE(buffer_, AllocateResizableBuffer(buffer_size_, pool_));
+ } else if (buffer_->size() != buffer_size_) {
+ RETURN_NOT_OK(buffer_->Resize(buffer_size_));
+ }
+ buffer_data_ = buffer_->mutable_data();
+ return Status::OK();
+ }
+
+ Status ResizeBuffer(int64_t new_buffer_size) {
+ buffer_size_ = new_buffer_size;
+ return ResetBuffer();
+ }
+
+ void AppendToBuffer(const void* data, int64_t nbytes) {
+ DCHECK_LE(buffer_pos_ + nbytes, buffer_size_);
+ std::memcpy(buffer_data_ + buffer_pos_, data, nbytes);
+ buffer_pos_ += nbytes;
+ }
+
+ int64_t buffer_size() const { return buffer_size_; }
+
+ int64_t buffer_pos() const { return buffer_pos_; }
+
+ protected:
+ MemoryPool* pool_;
+ bool is_open_;
+
+ std::shared_ptr<ResizableBuffer> buffer_;
+ uint8_t* buffer_data_;
+ int64_t buffer_pos_;
+ int64_t buffer_size_;
+
+ mutable int64_t raw_pos_;
+ mutable std::mutex lock_;
+};
+
+class BufferedOutputStream::Impl : public BufferedBase {
+ public:
+ explicit Impl(std::shared_ptr<OutputStream> raw, MemoryPool* pool)
+ : BufferedBase(pool), raw_(std::move(raw)) {}
+
+ Status Close() {
+ std::lock_guard<std::mutex> guard(lock_);
+ if (is_open_) {
+ Status st = FlushUnlocked();
+ is_open_ = false;
+ RETURN_NOT_OK(raw_->Close());
+ return st;
+ }
+ return Status::OK();
+ }
+
+ Status Abort() {
+ std::lock_guard<std::mutex> guard(lock_);
+ if (is_open_) {
+ is_open_ = false;
+ return raw_->Abort();
+ }
+ return Status::OK();
+ }
+
+ Result<int64_t> Tell() const {
+ std::lock_guard<std::mutex> guard(lock_);
+ if (raw_pos_ == -1) {
+ ARROW_ASSIGN_OR_RAISE(raw_pos_, raw_->Tell());
+ DCHECK_GE(raw_pos_, 0);
+ }
+ return raw_pos_ + buffer_pos_;
+ }
+
+ Status Write(const void* data, int64_t nbytes) { return DoWrite(data, nbytes); }
+
+ Status Write(const std::shared_ptr<Buffer>& buffer) {
+ return DoWrite(buffer->data(), buffer->size(), buffer);
+ }
+
+ Status DoWrite(const void* data, int64_t nbytes,
+ const std::shared_ptr<Buffer>& buffer = nullptr) {
+ std::lock_guard<std::mutex> guard(lock_);
+ if (nbytes < 0) {
+ return Status::Invalid("write count should be >= 0");
+ }
+ if (nbytes == 0) {
+ return Status::OK();
+ }
+ if (nbytes + buffer_pos_ >= buffer_size_) {
+ RETURN_NOT_OK(FlushUnlocked());
+ DCHECK_EQ(buffer_pos_, 0);
+ if (nbytes >= buffer_size_) {
+ // Direct write
+ if (buffer) {
+ return raw_->Write(buffer);
+ } else {
+ return raw_->Write(data, nbytes);
+ }
+ }
+ }
+ AppendToBuffer(data, nbytes);
+ return Status::OK();
+ }
+
+ Status FlushUnlocked() {
+ if (buffer_pos_ > 0) {
+ // Invalidate cached raw pos
+ raw_pos_ = -1;
+ RETURN_NOT_OK(raw_->Write(buffer_data_, buffer_pos_));
+ buffer_pos_ = 0;
+ }
+ return Status::OK();
+ }
+
+ Status Flush() {
+ std::lock_guard<std::mutex> guard(lock_);
+ return FlushUnlocked();
+ }
+
+ Result<std::shared_ptr<OutputStream>> Detach() {
+ std::lock_guard<std::mutex> guard(lock_);
+ RETURN_NOT_OK(FlushUnlocked());
+ is_open_ = false;
+ return std::move(raw_);
+ }
+
+ Status SetBufferSize(int64_t new_buffer_size) {
+ std::lock_guard<std::mutex> guard(lock_);
+ if (new_buffer_size <= 0) {
+ return Status::Invalid("Buffer size should be positive");
+ }
+ if (buffer_pos_ >= new_buffer_size) {
+ // If the buffer is shrinking, first flush to the raw OutputStream
+ RETURN_NOT_OK(FlushUnlocked());
+ }
+ return ResizeBuffer(new_buffer_size);
+ }
+
+ std::shared_ptr<OutputStream> raw() const { return raw_; }
+
+ private:
+ std::shared_ptr<OutputStream> raw_;
+};
+
+BufferedOutputStream::BufferedOutputStream(std::shared_ptr<OutputStream> raw,
+ MemoryPool* pool) {
+ impl_.reset(new Impl(std::move(raw), pool));
+}
+
+Result<std::shared_ptr<BufferedOutputStream>> BufferedOutputStream::Create(
+ int64_t buffer_size, MemoryPool* pool, std::shared_ptr<OutputStream> raw) {
+ auto result = std::shared_ptr<BufferedOutputStream>(
+ new BufferedOutputStream(std::move(raw), pool));
+ RETURN_NOT_OK(result->SetBufferSize(buffer_size));
+ return result;
+}
+
+BufferedOutputStream::~BufferedOutputStream() { internal::CloseFromDestructor(this); }
+
+Status BufferedOutputStream::SetBufferSize(int64_t new_buffer_size) {
+ return impl_->SetBufferSize(new_buffer_size);
+}
+
+int64_t BufferedOutputStream::buffer_size() const { return impl_->buffer_size(); }
+
+int64_t BufferedOutputStream::bytes_buffered() const { return impl_->buffer_pos(); }
+
+Result<std::shared_ptr<OutputStream>> BufferedOutputStream::Detach() {
+ return impl_->Detach();
+}
+
+Status BufferedOutputStream::Close() { return impl_->Close(); }
+
+Status BufferedOutputStream::Abort() { return impl_->Abort(); }
+
+bool BufferedOutputStream::closed() const { return impl_->closed(); }
+
+Result<int64_t> BufferedOutputStream::Tell() const { return impl_->Tell(); }
+
+Status BufferedOutputStream::Write(const void* data, int64_t nbytes) {
+ return impl_->Write(data, nbytes);
+}
+
+Status BufferedOutputStream::Write(const std::shared_ptr<Buffer>& data) {
+ return impl_->Write(data);
+}
+
+Status BufferedOutputStream::Flush() { return impl_->Flush(); }
+
+std::shared_ptr<OutputStream> BufferedOutputStream::raw() const { return impl_->raw(); }
+
+// ----------------------------------------------------------------------
+// BufferedInputStream implementation
+
+class BufferedInputStream::Impl : public BufferedBase {
+ public:
+ Impl(std::shared_ptr<InputStream> raw, MemoryPool* pool, int64_t raw_total_bytes_bound)
+ : BufferedBase(pool),
+ raw_(std::move(raw)),
+ raw_read_total_(0),
+ raw_read_bound_(raw_total_bytes_bound),
+ bytes_buffered_(0) {}
+
+ Status Close() {
+ if (is_open_) {
+ is_open_ = false;
+ return raw_->Close();
+ }
+ return Status::OK();
+ }
+
+ Status Abort() {
+ if (is_open_) {
+ is_open_ = false;
+ return raw_->Abort();
+ }
+ return Status::OK();
+ }
+
+ Result<int64_t> Tell() const {
+ if (raw_pos_ == -1) {
+ ARROW_ASSIGN_OR_RAISE(raw_pos_, raw_->Tell());
+ DCHECK_GE(raw_pos_, 0);
+ }
+ // Shift by bytes_buffered to return semantic stream position
+ return raw_pos_ - bytes_buffered_;
+ }
+
+ Status SetBufferSize(int64_t new_buffer_size) {
+ if (new_buffer_size <= 0) {
+ return Status::Invalid("Buffer size should be positive");
+ }
+ if ((buffer_pos_ + bytes_buffered_) >= new_buffer_size) {
+ return Status::Invalid("Cannot shrink read buffer if buffered data remains");
+ }
+ return ResizeBuffer(new_buffer_size);
+ }
+
+ Result<util::string_view> Peek(int64_t nbytes) {
+ if (raw_read_bound_ >= 0) {
+ // Do not try to peek more than the total remaining number of bytes.
+ nbytes = std::min(nbytes, bytes_buffered_ + (raw_read_bound_ - raw_read_total_));
+ }
+
+ if (bytes_buffered_ == 0 && nbytes < buffer_size_) {
+ // Pre-buffer for small reads
+ RETURN_NOT_OK(BufferIfNeeded());
+ }
+
+ // Increase the buffer size if needed.
+ if (nbytes > buffer_->size() - buffer_pos_) {
+ RETURN_NOT_OK(SetBufferSize(nbytes + buffer_pos_));
+ DCHECK(buffer_->size() - buffer_pos_ >= nbytes);
+ }
+ // Read more data when buffer has insufficient left
+ if (nbytes > bytes_buffered_) {
+ int64_t additional_bytes_to_read = nbytes - bytes_buffered_;
+ if (raw_read_bound_ >= 0) {
+ additional_bytes_to_read =
+ std::min(additional_bytes_to_read, raw_read_bound_ - raw_read_total_);
+ }
+ ARROW_ASSIGN_OR_RAISE(
+ int64_t bytes_read,
+ raw_->Read(additional_bytes_to_read,
+ buffer_->mutable_data() + buffer_pos_ + bytes_buffered_));
+ bytes_buffered_ += bytes_read;
+ raw_read_total_ += bytes_read;
+ nbytes = bytes_buffered_;
+ }
+ DCHECK(nbytes <= bytes_buffered_); // Enough bytes available
+ return util::string_view(reinterpret_cast<const char*>(buffer_data_ + buffer_pos_),
+ static_cast<size_t>(nbytes));
+ }
+
+ int64_t bytes_buffered() const { return bytes_buffered_; }
+
+ int64_t buffer_size() const { return buffer_size_; }
+
+ std::shared_ptr<InputStream> Detach() {
+ is_open_ = false;
+ return std::move(raw_);
+ }
+
+ void RewindBuffer() {
+ // Invalidate buffered data, as with a Seek or large Read
+ buffer_pos_ = bytes_buffered_ = 0;
+ }
+
+ Status BufferIfNeeded() {
+ if (bytes_buffered_ == 0) {
+ // Fill buffer
+ if (!buffer_) {
+ RETURN_NOT_OK(ResetBuffer());
+ }
+
+ int64_t bytes_to_buffer = buffer_size_;
+ if (raw_read_bound_ >= 0) {
+ bytes_to_buffer = std::min(buffer_size_, raw_read_bound_ - raw_read_total_);
+ }
+ ARROW_ASSIGN_OR_RAISE(bytes_buffered_, raw_->Read(bytes_to_buffer, buffer_data_));
+ buffer_pos_ = 0;
+ raw_read_total_ += bytes_buffered_;
+
+ // Do not make assumptions about the raw stream position
+ raw_pos_ = -1;
+ }
+ return Status::OK();
+ }
+
+ void ConsumeBuffer(int64_t nbytes) {
+ buffer_pos_ += nbytes;
+ bytes_buffered_ -= nbytes;
+ }
+
+ Result<int64_t> Read(int64_t nbytes, void* out) {
+ if (ARROW_PREDICT_FALSE(nbytes < 0)) {
+ return Status::Invalid("Bytes to read must be positive. Received:", nbytes);
+ }
+
+ if (nbytes < buffer_size_) {
+ // Pre-buffer for small reads
+ RETURN_NOT_OK(BufferIfNeeded());
+ }
+
+ if (nbytes > bytes_buffered_) {
+ // Copy buffered bytes into out, then read rest
+ memcpy(out, buffer_data_ + buffer_pos_, bytes_buffered_);
+
+ int64_t bytes_to_read = nbytes - bytes_buffered_;
+ if (raw_read_bound_ >= 0) {
+ bytes_to_read = std::min(bytes_to_read, raw_read_bound_ - raw_read_total_);
+ }
+ ARROW_ASSIGN_OR_RAISE(
+ int64_t bytes_read,
+ raw_->Read(bytes_to_read, reinterpret_cast<uint8_t*>(out) + bytes_buffered_));
+ raw_read_total_ += bytes_read;
+
+ // Do not make assumptions about the raw stream position
+ raw_pos_ = -1;
+ bytes_read += bytes_buffered_;
+ RewindBuffer();
+ return bytes_read;
+ } else {
+ memcpy(out, buffer_data_ + buffer_pos_, nbytes);
+ ConsumeBuffer(nbytes);
+ return nbytes;
+ }
+ }
+
+ Result<std::shared_ptr<Buffer>> Read(int64_t nbytes) {
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateResizableBuffer(nbytes, pool_));
+
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, Read(nbytes, buffer->mutable_data()));
+
+ if (bytes_read < nbytes) {
+ // Change size but do not reallocate internal capacity
+ RETURN_NOT_OK(buffer->Resize(bytes_read, false /* shrink_to_fit */));
+ buffer->ZeroPadding();
+ }
+ return std::move(buffer);
+ }
+
+ // For providing access to the raw file handles
+ std::shared_ptr<InputStream> raw() const { return raw_; }
+
+ private:
+ std::shared_ptr<InputStream> raw_;
+ int64_t raw_read_total_;
+ int64_t raw_read_bound_;
+
+ // Number of remaining bytes in the buffer, to be reduced on each read from
+ // the buffer
+ int64_t bytes_buffered_;
+};
+
+BufferedInputStream::BufferedInputStream(std::shared_ptr<InputStream> raw,
+ MemoryPool* pool,
+ int64_t raw_total_bytes_bound) {
+ impl_.reset(new Impl(std::move(raw), pool, raw_total_bytes_bound));
+}
+
+BufferedInputStream::~BufferedInputStream() { internal::CloseFromDestructor(this); }
+
+Result<std::shared_ptr<BufferedInputStream>> BufferedInputStream::Create(
+ int64_t buffer_size, MemoryPool* pool, std::shared_ptr<InputStream> raw,
+ int64_t raw_total_bytes_bound) {
+ auto result = std::shared_ptr<BufferedInputStream>(
+ new BufferedInputStream(std::move(raw), pool, raw_total_bytes_bound));
+ RETURN_NOT_OK(result->SetBufferSize(buffer_size));
+ return result;
+}
+
+Status BufferedInputStream::DoClose() { return impl_->Close(); }
+
+Status BufferedInputStream::DoAbort() { return impl_->Abort(); }
+
+bool BufferedInputStream::closed() const { return impl_->closed(); }
+
+std::shared_ptr<InputStream> BufferedInputStream::Detach() { return impl_->Detach(); }
+
+std::shared_ptr<InputStream> BufferedInputStream::raw() const { return impl_->raw(); }
+
+Result<int64_t> BufferedInputStream::DoTell() const { return impl_->Tell(); }
+
+Result<util::string_view> BufferedInputStream::DoPeek(int64_t nbytes) {
+ return impl_->Peek(nbytes);
+}
+
+Status BufferedInputStream::SetBufferSize(int64_t new_buffer_size) {
+ return impl_->SetBufferSize(new_buffer_size);
+}
+
+int64_t BufferedInputStream::bytes_buffered() const { return impl_->bytes_buffered(); }
+
+int64_t BufferedInputStream::buffer_size() const { return impl_->buffer_size(); }
+
+Result<int64_t> BufferedInputStream::DoRead(int64_t nbytes, void* out) {
+ return impl_->Read(nbytes, out);
+}
+
+Result<std::shared_ptr<Buffer>> BufferedInputStream::DoRead(int64_t nbytes) {
+ return impl_->Read(nbytes);
+}
+
+Result<std::shared_ptr<const KeyValueMetadata>> BufferedInputStream::ReadMetadata() {
+ return impl_->raw()->ReadMetadata();
+}
+
+Future<std::shared_ptr<const KeyValueMetadata>> BufferedInputStream::ReadMetadataAsync(
+ const IOContext& io_context) {
+ return impl_->raw()->ReadMetadataAsync(io_context);
+}
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/buffered.h b/src/arrow/cpp/src/arrow/io/buffered.h
new file mode 100644
index 000000000..8116613fa
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/buffered.h
@@ -0,0 +1,167 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Buffered stream implementations
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/io/concurrency.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Buffer;
+class MemoryPool;
+class Status;
+
+namespace io {
+
+class ARROW_EXPORT BufferedOutputStream : public OutputStream {
+ public:
+ ~BufferedOutputStream() override;
+
+ /// \brief Create a buffered output stream wrapping the given output stream.
+ /// \param[in] buffer_size the size of the temporary write buffer
+ /// \param[in] pool a MemoryPool to use for allocations
+ /// \param[in] raw another OutputStream
+ /// \return the created BufferedOutputStream
+ static Result<std::shared_ptr<BufferedOutputStream>> Create(
+ int64_t buffer_size, MemoryPool* pool, std::shared_ptr<OutputStream> raw);
+
+ /// \brief Resize internal buffer
+ /// \param[in] new_buffer_size the new buffer size
+ /// \return Status
+ Status SetBufferSize(int64_t new_buffer_size);
+
+ /// \brief Return the current size of the internal buffer
+ int64_t buffer_size() const;
+
+ /// \brief Return the number of remaining bytes that have not been flushed to
+ /// the raw OutputStream
+ int64_t bytes_buffered() const;
+
+ /// \brief Flush any buffered writes and release the raw
+ /// OutputStream. Further operations on this object are invalid
+ /// \return the underlying OutputStream
+ Result<std::shared_ptr<OutputStream>> Detach();
+
+ // OutputStream interface
+
+ /// \brief Close the buffered output stream. This implicitly closes the
+ /// underlying raw output stream.
+ Status Close() override;
+ Status Abort() override;
+ bool closed() const override;
+
+ Result<int64_t> Tell() const override;
+ // Write bytes to the stream. Thread-safe
+ Status Write(const void* data, int64_t nbytes) override;
+ Status Write(const std::shared_ptr<Buffer>& data) override;
+
+ Status Flush() override;
+
+ /// \brief Return the underlying raw output stream.
+ std::shared_ptr<OutputStream> raw() const;
+
+ private:
+ explicit BufferedOutputStream(std::shared_ptr<OutputStream> raw, MemoryPool* pool);
+
+ class ARROW_NO_EXPORT Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+/// \class BufferedInputStream
+/// \brief An InputStream that performs buffered reads from an unbuffered
+/// InputStream, which can mitigate the overhead of many small reads in some
+/// cases
+class ARROW_EXPORT BufferedInputStream
+ : public internal::InputStreamConcurrencyWrapper<BufferedInputStream> {
+ public:
+ ~BufferedInputStream() override;
+
+ /// \brief Create a BufferedInputStream from a raw InputStream
+ /// \param[in] buffer_size the size of the temporary read buffer
+ /// \param[in] pool a MemoryPool to use for allocations
+ /// \param[in] raw a raw InputStream
+ /// \param[in] raw_read_bound a bound on the maximum number of bytes
+ /// to read from the raw input stream. The default -1 indicates that
+ /// it is unbounded
+ /// \return the created BufferedInputStream
+ static Result<std::shared_ptr<BufferedInputStream>> Create(
+ int64_t buffer_size, MemoryPool* pool, std::shared_ptr<InputStream> raw,
+ int64_t raw_read_bound = -1);
+
+ /// \brief Resize internal read buffer; calls to Read(...) will read at least
+ /// \param[in] new_buffer_size the new read buffer size
+ /// \return Status
+ Status SetBufferSize(int64_t new_buffer_size);
+
+ /// \brief Return the number of remaining bytes in the read buffer
+ int64_t bytes_buffered() const;
+
+ /// \brief Return the current size of the internal buffer
+ int64_t buffer_size() const;
+
+ /// \brief Release the raw InputStream. Any data buffered will be
+ /// discarded. Further operations on this object are invalid
+ /// \return raw the underlying InputStream
+ std::shared_ptr<InputStream> Detach();
+
+ /// \brief Return the unbuffered InputStream
+ std::shared_ptr<InputStream> raw() const;
+
+ // InputStream APIs
+
+ bool closed() const override;
+ Result<std::shared_ptr<const KeyValueMetadata>> ReadMetadata() override;
+ Future<std::shared_ptr<const KeyValueMetadata>> ReadMetadataAsync(
+ const IOContext& io_context) override;
+
+ private:
+ friend InputStreamConcurrencyWrapper<BufferedInputStream>;
+
+ explicit BufferedInputStream(std::shared_ptr<InputStream> raw, MemoryPool* pool,
+ int64_t raw_total_bytes_bound);
+
+ Status DoClose();
+ Status DoAbort() override;
+
+ /// \brief Returns the position of the buffered stream, though the position
+ /// of the unbuffered stream may be further advanced.
+ Result<int64_t> DoTell() const;
+
+ Result<int64_t> DoRead(int64_t nbytes, void* out);
+
+ /// \brief Read into buffer.
+ Result<std::shared_ptr<Buffer>> DoRead(int64_t nbytes);
+
+ /// \brief Return a zero-copy string view referencing buffered data,
+ /// but do not advance the position of the stream. Buffers data and
+ /// expands the buffer size if necessary
+ Result<util::string_view> DoPeek(int64_t nbytes) override;
+
+ class ARROW_NO_EXPORT Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/buffered_test.cc b/src/arrow/cpp/src/arrow/io/buffered_test.cc
new file mode 100644
index 000000000..1fefc261b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/buffered_test.cc
@@ -0,0 +1,667 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef _WIN32
+#include <fcntl.h> // IWYU pragma: keep
+#include <unistd.h>
+#endif
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdio>
+#include <functional>
+#include <iterator>
+#include <memory>
+#include <random>
+#include <string>
+#include <utility>
+#include <valarray>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/io/buffered.h"
+#include "arrow/io/file.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/io/memory.h"
+#include "arrow/io/test_common.h"
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace io {
+
+using ::arrow::internal::TemporaryDir;
+
+static std::string GenerateRandomData(size_t nbytes) {
+ // MSVC doesn't accept uint8_t for std::independent_bits_engine<>
+ typedef unsigned long UInt; // NOLINT
+ std::independent_bits_engine<std::default_random_engine, 8 * sizeof(UInt), UInt> engine;
+
+ std::vector<UInt> data(nbytes / sizeof(UInt) + 1);
+ std::generate(begin(data), end(data), std::ref(engine));
+ return std::string(reinterpret_cast<char*>(data.data()), nbytes);
+}
+
+template <typename FileType>
+class FileTestFixture : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK_AND_ASSIGN(temp_dir_, TemporaryDir::Make("buffered-test-"));
+ path_ = temp_dir_->path()
+ .Join("arrow-test-io-buffered-stream.txt")
+ .ValueOrDie()
+ .ToString();
+ EnsureFileDeleted();
+ }
+
+ void TearDown() { EnsureFileDeleted(); }
+
+ void EnsureFileDeleted() {
+ if (FileExists(path_)) {
+ ARROW_UNUSED(std::remove(path_.c_str()));
+ }
+ }
+
+ void AssertTell(int64_t expected) { ASSERT_OK_AND_EQ(expected, buffered_->Tell()); }
+
+ protected:
+ int fd_;
+ std::shared_ptr<FileType> buffered_;
+ std::string path_;
+ std::unique_ptr<TemporaryDir> temp_dir_;
+};
+
+// ----------------------------------------------------------------------
+// Buffered output tests
+
+constexpr int64_t kDefaultBufferSize = 4096;
+
+class TestBufferedOutputStream : public FileTestFixture<BufferedOutputStream> {
+ public:
+ void OpenBuffered(int64_t buffer_size = kDefaultBufferSize, bool append = false) {
+ // So that any open file is closed
+ buffered_.reset();
+
+ ASSERT_OK_AND_ASSIGN(auto file, FileOutputStream::Open(path_, append));
+ fd_ = file->file_descriptor();
+ if (append) {
+ // Workaround for ARROW-2466 ("append" flag doesn't set file pos)
+#if defined(_MSC_VER)
+ _lseeki64(fd_, 0, SEEK_END);
+#else
+ lseek(fd_, 0, SEEK_END);
+#endif
+ }
+ ASSERT_OK_AND_ASSIGN(buffered_, BufferedOutputStream::Create(
+ buffer_size, default_memory_pool(), file));
+ }
+
+ void WriteChunkwise(const std::string& datastr, const std::valarray<int64_t>& sizes) {
+ const char* data = datastr.data();
+ const int64_t data_size = static_cast<int64_t>(datastr.size());
+ int64_t data_pos = 0;
+ auto size_it = std::begin(sizes);
+
+ // Write datastr, chunk by chunk, until exhausted
+ while (true) {
+ int64_t size = *size_it++;
+ if (size_it == std::end(sizes)) {
+ size_it = std::begin(sizes);
+ }
+ if (data_pos + size > data_size) {
+ break;
+ }
+ ASSERT_OK(buffered_->Write(data + data_pos, size));
+ data_pos += size;
+ }
+ ASSERT_OK(buffered_->Write(data + data_pos, data_size - data_pos));
+ }
+};
+
+TEST_F(TestBufferedOutputStream, DestructorClosesFile) {
+ OpenBuffered();
+ ASSERT_FALSE(FileIsClosed(fd_));
+ buffered_.reset();
+ ASSERT_TRUE(FileIsClosed(fd_));
+}
+
+TEST_F(TestBufferedOutputStream, Detach) {
+ OpenBuffered();
+ const std::string datastr = "1234568790";
+
+ ASSERT_OK(buffered_->Write(datastr.data(), 10));
+
+ ASSERT_OK_AND_ASSIGN(auto detached_stream, buffered_->Detach());
+
+ // Destroying the stream does not close the file because we have detached
+ buffered_.reset();
+ ASSERT_FALSE(FileIsClosed(fd_));
+
+ ASSERT_OK(detached_stream->Close());
+ ASSERT_TRUE(FileIsClosed(fd_));
+
+ AssertFileContents(path_, datastr);
+}
+
+TEST_F(TestBufferedOutputStream, ExplicitCloseClosesFile) {
+ OpenBuffered();
+ ASSERT_FALSE(buffered_->closed());
+ ASSERT_FALSE(FileIsClosed(fd_));
+ ASSERT_OK(buffered_->Close());
+ ASSERT_TRUE(buffered_->closed());
+ ASSERT_TRUE(FileIsClosed(fd_));
+ // Idempotency
+ ASSERT_OK(buffered_->Close());
+ ASSERT_TRUE(buffered_->closed());
+ ASSERT_TRUE(FileIsClosed(fd_));
+}
+
+TEST_F(TestBufferedOutputStream, InvalidWrites) {
+ OpenBuffered();
+
+ const char* data = "";
+ ASSERT_RAISES(Invalid, buffered_->Write(data, -1));
+}
+
+TEST_F(TestBufferedOutputStream, TinyWrites) {
+ OpenBuffered();
+
+ const std::string datastr = "1234568790";
+ const char* data = datastr.data();
+
+ ASSERT_OK(buffered_->Write(data, 2));
+ ASSERT_OK(buffered_->Write(data + 2, 6));
+ ASSERT_OK(buffered_->Close());
+
+ AssertFileContents(path_, datastr.substr(0, 8));
+}
+
+TEST_F(TestBufferedOutputStream, SmallWrites) {
+ OpenBuffered();
+
+ // Data here should be larger than BufferedOutputStream's buffer size
+ const std::string data = GenerateRandomData(200000);
+ const std::valarray<int64_t> sizes = {1, 1, 2, 3, 5, 8, 13};
+
+ WriteChunkwise(data, sizes);
+ ASSERT_OK(buffered_->Close());
+
+ AssertFileContents(path_, data);
+}
+
+TEST_F(TestBufferedOutputStream, MixedWrites) {
+ OpenBuffered();
+
+ const std::string data = GenerateRandomData(300000);
+ const std::valarray<int64_t> sizes = {1, 1, 2, 3, 70000};
+
+ WriteChunkwise(data, sizes);
+ ASSERT_OK(buffered_->Close());
+
+ AssertFileContents(path_, data);
+}
+
+TEST_F(TestBufferedOutputStream, LargeWrites) {
+ OpenBuffered();
+
+ const std::string data = GenerateRandomData(800000);
+ const std::valarray<int64_t> sizes = {10000, 60000, 70000};
+
+ WriteChunkwise(data, sizes);
+ ASSERT_OK(buffered_->Close());
+
+ AssertFileContents(path_, data);
+}
+
+TEST_F(TestBufferedOutputStream, Flush) {
+ OpenBuffered();
+
+ const std::string datastr = "1234568790";
+ const char* data = datastr.data();
+
+ ASSERT_OK(buffered_->Write(data, datastr.size()));
+ ASSERT_OK(buffered_->Flush());
+
+ AssertFileContents(path_, datastr);
+
+ ASSERT_OK(buffered_->Close());
+}
+
+TEST_F(TestBufferedOutputStream, SetBufferSize) {
+ OpenBuffered(20);
+
+ ASSERT_EQ(20, buffered_->buffer_size());
+
+ const std::string datastr = "1234568790abcdefghij";
+ const char* data = datastr.data();
+
+ // Write part of the data, then shrink buffer size to make sure it gets
+ // flushed
+ ASSERT_OK(buffered_->Write(data, 10));
+ ASSERT_OK(buffered_->SetBufferSize(10));
+
+ ASSERT_EQ(10, buffered_->buffer_size());
+
+ // Shrink buffer, write some buffered bytes, then expand buffer
+ ASSERT_OK(buffered_->SetBufferSize(5));
+ ASSERT_OK(buffered_->Write(data + 10, 3));
+ ASSERT_OK(buffered_->SetBufferSize(10));
+ ASSERT_EQ(3, buffered_->bytes_buffered());
+
+ ASSERT_OK(buffered_->Write(data + 13, 7));
+ ASSERT_OK(buffered_->Flush());
+
+ AssertFileContents(path_, datastr);
+ ASSERT_OK(buffered_->Close());
+}
+
+TEST_F(TestBufferedOutputStream, Tell) {
+ OpenBuffered();
+
+ AssertTell(0);
+ AssertTell(0);
+ WriteChunkwise(std::string(100, 'x'), {1, 1, 2, 3, 5, 8});
+ AssertTell(100);
+ WriteChunkwise(std::string(100000, 'x'), {60000});
+ AssertTell(100100);
+
+ ASSERT_OK(buffered_->Close());
+
+ OpenBuffered(kDefaultBufferSize, true /* append */);
+ AssertTell(100100);
+ WriteChunkwise(std::string(90, 'x'), {1, 1, 2, 3, 5, 8});
+ AssertTell(100190);
+
+ ASSERT_OK(buffered_->Close());
+
+ OpenBuffered();
+ AssertTell(0);
+}
+
+TEST_F(TestBufferedOutputStream, TruncatesFile) {
+ OpenBuffered();
+
+ const std::string datastr = "1234568790";
+ ASSERT_OK(buffered_->Write(datastr.data(), datastr.size()));
+ ASSERT_OK(buffered_->Close());
+
+ AssertFileContents(path_, datastr);
+
+ OpenBuffered();
+ AssertFileContents(path_, "");
+}
+
+// ----------------------------------------------------------------------
+// BufferedInputStream tests
+
+const char kExample1[] = "informaticacrobaticsimmolation";
+
+class TestBufferedInputStream : public FileTestFixture<BufferedInputStream> {
+ public:
+ void SetUp() {
+ FileTestFixture<BufferedInputStream>::SetUp();
+ local_pool_ = MemoryPool::CreateDefault();
+ }
+
+ void MakeExample1(int64_t buffer_size, MemoryPool* pool = default_memory_pool()) {
+ test_data_ = kExample1;
+
+ ASSERT_OK_AND_ASSIGN(auto file_out, FileOutputStream::Open(path_));
+ ASSERT_OK(file_out->Write(test_data_));
+ ASSERT_OK(file_out->Close());
+
+ ASSERT_OK_AND_ASSIGN(auto file_in, ReadableFile::Open(path_));
+ raw_ = file_in;
+ ASSERT_OK_AND_ASSIGN(buffered_, BufferedInputStream::Create(buffer_size, pool, raw_));
+ }
+
+ protected:
+ std::unique_ptr<MemoryPool> local_pool_;
+ std::string test_data_;
+ std::shared_ptr<InputStream> raw_;
+};
+
+TEST_F(TestBufferedInputStream, InvalidReads) {
+ const int64_t kBufferSize = 10;
+ MakeExample1(kBufferSize);
+ ASSERT_EQ(kBufferSize, buffered_->buffer_size());
+ std::vector<char> buf(test_data_.size());
+ ASSERT_RAISES(Invalid, buffered_->Read(-1, buf.data()));
+}
+
+TEST_F(TestBufferedInputStream, BasicOperation) {
+ const int64_t kBufferSize = 10;
+ MakeExample1(kBufferSize);
+ ASSERT_EQ(kBufferSize, buffered_->buffer_size());
+
+ ASSERT_OK_AND_EQ(0, buffered_->Tell());
+
+ // Nothing in the buffer
+ ASSERT_EQ(0, buffered_->bytes_buffered());
+
+ std::vector<char> buf(test_data_.size());
+ ASSERT_OK_AND_EQ(0, buffered_->Read(0, buf.data()));
+ ASSERT_OK_AND_EQ(4, buffered_->Read(4, buf.data()));
+ ASSERT_EQ(0, memcmp(buf.data(), test_data_.data(), 4));
+
+ // 6 bytes remaining in buffer
+ ASSERT_EQ(6, buffered_->bytes_buffered());
+
+ // This make sure Peek() works well when buffered bytes are not enough
+ ASSERT_OK_AND_ASSIGN(auto peek, buffered_->Peek(8));
+ ASSERT_EQ(8, peek.size());
+ ASSERT_EQ('r', peek.data()[0]);
+ ASSERT_EQ('m', peek.data()[1]);
+ ASSERT_EQ('a', peek.data()[2]);
+ ASSERT_EQ('t', peek.data()[3]);
+ ASSERT_EQ('i', peek.data()[4]);
+ ASSERT_EQ('c', peek.data()[5]);
+ ASSERT_EQ('a', peek.data()[6]);
+ ASSERT_EQ('c', peek.data()[7]);
+
+ // Buffered position is 4
+ ASSERT_OK_AND_EQ(4, buffered_->Tell());
+
+ // Raw position actually 12
+ ASSERT_OK_AND_EQ(12, raw_->Tell());
+
+ // Reading to end of buffered bytes does not cause any more data to be
+ // buffered
+ ASSERT_OK_AND_EQ(8, buffered_->Read(8, buf.data()));
+ ASSERT_EQ(0, memcmp(buf.data(), test_data_.data() + 4, 8));
+
+ ASSERT_EQ(0, buffered_->bytes_buffered());
+
+ // Read to EOF, exceeding buffer size
+ ASSERT_OK_AND_EQ(18, buffered_->Read(18, buf.data()));
+ ASSERT_EQ(0, memcmp(buf.data(), test_data_.data() + 12, 18));
+ ASSERT_EQ(0, buffered_->bytes_buffered());
+
+ // Read to EOF
+ ASSERT_OK_AND_EQ(0, buffered_->Read(1, buf.data()));
+ ASSERT_OK_AND_EQ(test_data_.size(), buffered_->Tell());
+
+ // Peek at EOF
+ ASSERT_OK_AND_ASSIGN(peek, buffered_->Peek(10));
+ ASSERT_EQ(0, peek.size());
+
+ // Calling Close closes raw_
+ ASSERT_OK(buffered_->Close());
+ ASSERT_TRUE(buffered_->raw()->closed());
+}
+
+TEST_F(TestBufferedInputStream, Detach) {
+ MakeExample1(10);
+ auto raw = buffered_->Detach();
+ ASSERT_OK(buffered_->Close());
+ ASSERT_FALSE(raw->closed());
+}
+
+TEST_F(TestBufferedInputStream, ReadBuffer) {
+ const int64_t kBufferSize = 10;
+ MakeExample1(kBufferSize);
+
+ std::shared_ptr<Buffer> buf;
+
+ // Read exceeding buffer size
+ ASSERT_OK_AND_ASSIGN(buf, buffered_->Read(15));
+ ASSERT_EQ(0, memcmp(buf->data(), test_data_.data(), 15));
+ ASSERT_EQ(0, buffered_->bytes_buffered());
+
+ // Buffered reads
+ ASSERT_OK_AND_ASSIGN(buf, buffered_->Read(6));
+ ASSERT_EQ(6, buf->size());
+ ASSERT_EQ(0, memcmp(buf->data(), test_data_.data() + 15, 6));
+ ASSERT_EQ(4, buffered_->bytes_buffered());
+
+ ASSERT_OK_AND_ASSIGN(buf, buffered_->Read(4));
+ ASSERT_EQ(4, buf->size());
+ ASSERT_EQ(0, memcmp(buf->data(), test_data_.data() + 21, 4));
+ ASSERT_EQ(0, buffered_->bytes_buffered());
+}
+
+TEST_F(TestBufferedInputStream, SetBufferSize) {
+ MakeExample1(5);
+
+ std::shared_ptr<Buffer> buf;
+ ASSERT_OK_AND_ASSIGN(buf, buffered_->Read(5));
+ ASSERT_EQ(5, buf->size());
+
+ // Increase buffer size
+ ASSERT_OK(buffered_->SetBufferSize(10));
+ ASSERT_EQ(10, buffered_->buffer_size());
+ ASSERT_OK_AND_ASSIGN(buf, buffered_->Read(6));
+ ASSERT_EQ(4, buffered_->bytes_buffered());
+
+ // Consume until 5 byte left
+ ASSERT_OK(buffered_->Read(15));
+
+ // Read at EOF so there will be only 5 bytes in the buffer
+ ASSERT_OK(buffered_->Read(2));
+
+ // Cannot shrink buffer if it would destroy data
+ ASSERT_RAISES(Invalid, buffered_->SetBufferSize(4));
+
+ // Shrinking to exactly number of buffered bytes is ok
+ ASSERT_OK(buffered_->SetBufferSize(5));
+}
+
+class TestBufferedInputStreamBound : public ::testing::Test {
+ public:
+ void SetUp() { CreateExample(/*bounded=*/true); }
+
+ void CreateExample(bool bounded = true) {
+ // Create a buffer larger than source size, to check that the
+ // stream end is respected
+ ASSERT_OK_AND_ASSIGN(auto buf, AllocateResizableBuffer(source_size_ + 10));
+ ASSERT_LT(source_size_, buf->size());
+ for (int i = 0; i < source_size_; i++) {
+ buf->mutable_data()[i] = static_cast<uint8_t>(i);
+ }
+ source_ = std::make_shared<BufferReader>(std::move(buf));
+ ASSERT_OK(source_->Advance(stream_offset_));
+ ASSERT_OK_AND_ASSIGN(
+ stream_, BufferedInputStream::Create(chunk_size_, default_memory_pool(), source_,
+ bounded ? stream_size_ : -1));
+ }
+
+ protected:
+ int64_t source_size_ = 256;
+ int64_t stream_offset_ = 10;
+ int64_t stream_size_ = source_size_ - stream_offset_;
+ int64_t chunk_size_ = 50;
+ std::shared_ptr<InputStream> source_;
+ std::shared_ptr<BufferedInputStream> stream_;
+};
+
+TEST_F(TestBufferedInputStreamBound, Basics) {
+ std::shared_ptr<Buffer> buffer;
+ util::string_view view;
+
+ // source is at offset 10
+ ASSERT_OK_AND_ASSIGN(view, stream_->Peek(10));
+ ASSERT_EQ(10, view.size());
+ for (int i = 0; i < 10; i++) {
+ ASSERT_EQ(10 + i, view[i]) << i;
+ }
+
+ ASSERT_OK_AND_ASSIGN(buffer, stream_->Read(10));
+ ASSERT_EQ(10, buffer->size());
+ for (int i = 0; i < 10; i++) {
+ ASSERT_EQ(10 + i, (*buffer)[i]) << i;
+ }
+
+ ASSERT_OK_AND_ASSIGN(buffer, stream_->Read(10));
+ ASSERT_EQ(10, buffer->size());
+ for (int i = 0; i < 10; i++) {
+ ASSERT_EQ(20 + i, (*buffer)[i]) << i;
+ }
+ ASSERT_OK(stream_->Advance(5));
+ ASSERT_OK(stream_->Advance(5));
+
+ // source is at offset 40
+ // read across buffer boundary. buffer size is 50
+ ASSERT_OK_AND_ASSIGN(buffer, stream_->Read(20));
+ ASSERT_EQ(20, buffer->size());
+ for (int i = 0; i < 20; i++) {
+ ASSERT_EQ(40 + i, (*buffer)[i]) << i;
+ }
+
+ // read more than original chunk size
+ ASSERT_OK_AND_ASSIGN(buffer, stream_->Read(60));
+ ASSERT_EQ(60, buffer->size());
+ for (int i = 0; i < 60; i++) {
+ ASSERT_EQ(60 + i, (*buffer)[i]) << i;
+ }
+
+ ASSERT_OK(stream_->Advance(120));
+
+ // source is at offset 240
+ // read outside of source boundary. source size is 256
+ ASSERT_OK_AND_ASSIGN(buffer, stream_->Read(30));
+
+ ASSERT_EQ(16, buffer->size());
+ for (int i = 0; i < 16; i++) {
+ ASSERT_EQ(240 + i, (*buffer)[i]) << i;
+ }
+ // Stream exhausted
+ ASSERT_OK_AND_ASSIGN(buffer, stream_->Read(1));
+ ASSERT_EQ(0, buffer->size());
+}
+
+TEST_F(TestBufferedInputStreamBound, LargeFirstPeek) {
+ // Test a first peek larger than chunk size
+ std::shared_ptr<Buffer> buffer;
+ util::string_view view;
+ int64_t n = 70;
+ ASSERT_GT(n, chunk_size_);
+
+ // source is at offset 10
+ ASSERT_OK_AND_ASSIGN(view, stream_->Peek(n));
+ ASSERT_EQ(n, static_cast<int>(view.size()));
+ for (int i = 0; i < n; i++) {
+ ASSERT_EQ(10 + i, view[i]) << i;
+ }
+
+ ASSERT_OK_AND_ASSIGN(view, stream_->Peek(n));
+ ASSERT_EQ(n, static_cast<int>(view.size()));
+ for (int i = 0; i < n; i++) {
+ ASSERT_EQ(10 + i, view[i]) << i;
+ }
+
+ ASSERT_OK_AND_ASSIGN(buffer, stream_->Read(n));
+ ASSERT_EQ(n, buffer->size());
+ for (int i = 0; i < n; i++) {
+ ASSERT_EQ(10 + i, (*buffer)[i]) << i;
+ }
+ // source is at offset 10 + n
+ ASSERT_OK_AND_ASSIGN(buffer, stream_->Read(20));
+ ASSERT_EQ(20, buffer->size());
+ for (int i = 0; i < 20; i++) {
+ ASSERT_EQ(10 + n + i, (*buffer)[i]) << i;
+ }
+}
+
+TEST_F(TestBufferedInputStreamBound, UnboundedPeek) {
+ CreateExample(/*bounded=*/false);
+
+ util::string_view view;
+ ASSERT_OK_AND_ASSIGN(view, stream_->Peek(10));
+ ASSERT_EQ(10, view.size());
+ ASSERT_EQ(50, stream_->bytes_buffered());
+
+ ASSERT_OK(stream_->Read(10));
+
+ // Peek into buffered bytes
+ ASSERT_OK_AND_ASSIGN(view, stream_->Peek(40));
+ ASSERT_EQ(40, view.size());
+ ASSERT_EQ(40, stream_->bytes_buffered());
+ ASSERT_EQ(50, stream_->buffer_size());
+
+ // Peek past buffered bytes
+ ASSERT_OK_AND_ASSIGN(view, stream_->Peek(41));
+ ASSERT_EQ(41, view.size());
+ ASSERT_EQ(41, stream_->bytes_buffered());
+ ASSERT_EQ(51, stream_->buffer_size());
+
+ // Peek to the end of the buffer
+ ASSERT_OK_AND_ASSIGN(view, stream_->Peek(246));
+ ASSERT_EQ(246, view.size());
+ ASSERT_EQ(246, stream_->bytes_buffered());
+ ASSERT_EQ(256, stream_->buffer_size());
+
+ // Larger peek returns the same, expands the buffer, but there is no
+ // more data to buffer
+ ASSERT_OK_AND_ASSIGN(view, stream_->Peek(300));
+ ASSERT_EQ(246, view.size());
+ ASSERT_EQ(246, stream_->bytes_buffered());
+ ASSERT_EQ(310, stream_->buffer_size());
+}
+
+TEST_F(TestBufferedInputStreamBound, OneByteReads) {
+ for (int i = 0; i < stream_size_; ++i) {
+ ASSERT_OK_AND_ASSIGN(auto buffer, stream_->Read(1));
+ ASSERT_EQ(1, buffer->size());
+ ASSERT_EQ(10 + i, (*buffer)[0]) << i;
+ }
+ // Stream exhausted
+ ASSERT_OK_AND_ASSIGN(auto buffer, stream_->Read(1));
+ ASSERT_EQ(0, buffer->size());
+}
+
+TEST_F(TestBufferedInputStreamBound, BufferExactlyExhausted) {
+ // Test exhausting the buffer exactly then issuing further reads (PARQUET-1571).
+ std::shared_ptr<Buffer> buffer;
+
+ // source is at offset 10
+ int64_t n = 10;
+ ASSERT_OK_AND_ASSIGN(buffer, stream_->Read(n));
+ ASSERT_EQ(n, buffer->size());
+ for (int i = 0; i < n; i++) {
+ ASSERT_EQ(10 + i, (*buffer)[i]) << i;
+ }
+ // source is at offset 20
+ // Exhaust buffer exactly
+ n = stream_->bytes_buffered();
+ ASSERT_OK_AND_ASSIGN(buffer, stream_->Read(n));
+ ASSERT_EQ(n, buffer->size());
+ for (int i = 0; i < n; i++) {
+ ASSERT_EQ(20 + i, (*buffer)[i]) << i;
+ }
+
+ // source is at offset 20 + n
+ // Read new buffer
+ ASSERT_OK_AND_ASSIGN(buffer, stream_->Read(10));
+ ASSERT_EQ(10, buffer->size());
+ for (int i = 0; i < 10; i++) {
+ ASSERT_EQ(20 + n + i, (*buffer)[i]) << i;
+ }
+
+ // source is at offset 30 + n
+ ASSERT_OK_AND_ASSIGN(buffer, stream_->Read(10));
+ ASSERT_EQ(10, buffer->size());
+ for (int i = 0; i < 10; i++) {
+ ASSERT_EQ(30 + n + i, (*buffer)[i]) << i;
+ }
+}
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/caching.cc b/src/arrow/cpp/src/arrow/io/caching.cc
new file mode 100644
index 000000000..722026ccd
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/caching.cc
@@ -0,0 +1,318 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <atomic>
+#include <cmath>
+#include <mutex>
+#include <utility>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/io/caching.h"
+#include "arrow/io/util_internal.h"
+#include "arrow/result.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace io {
+
+CacheOptions CacheOptions::Defaults() {
+ return CacheOptions{internal::ReadRangeCache::kDefaultHoleSizeLimit,
+ internal::ReadRangeCache::kDefaultRangeSizeLimit,
+ /*lazy=*/false};
+}
+
+CacheOptions CacheOptions::LazyDefaults() {
+ return CacheOptions{internal::ReadRangeCache::kDefaultHoleSizeLimit,
+ internal::ReadRangeCache::kDefaultRangeSizeLimit,
+ /*lazy=*/true};
+}
+
+CacheOptions CacheOptions::MakeFromNetworkMetrics(int64_t time_to_first_byte_millis,
+ int64_t transfer_bandwidth_mib_per_sec,
+ double ideal_bandwidth_utilization_frac,
+ int64_t max_ideal_request_size_mib) {
+ //
+ // The I/O coalescing algorithm uses two parameters:
+ // 1. hole_size_limit (a.k.a max_io_gap): Max I/O gap/hole size in bytes
+ // 2. range_size_limit (a.k.a ideal_request_size): Ideal I/O Request size in bytes
+ //
+ // These parameters can be derived from network metrics (e.g. S3) as described below:
+ //
+ // In an S3 compatible storage, there are two main metrics:
+ // 1. Seek-time or Time-To-First-Byte (TTFB) in seconds: call setup latency of a new
+ // S3 request
+ // 2. Transfer Bandwidth (BW) for data in bytes/sec
+ //
+ // 1. Computing hole_size_limit:
+ //
+ // hole_size_limit = TTFB * BW
+ //
+ // This is also called Bandwidth-Delay-Product (BDP).
+ // Two byte ranges that have a gap can still be mapped to the same read
+ // if the gap is less than the bandwidth-delay product [TTFB * TransferBandwidth],
+ // i.e. if the Time-To-First-Byte (or call setup latency of a new S3 request) is
+ // expected to be greater than just reading and discarding the extra bytes on an
+ // existing HTTP request.
+ //
+ // 2. Computing range_size_limit:
+ //
+ // We want to have high bandwidth utilization per S3 connections,
+ // i.e. transfer large amounts of data to amortize the seek overhead.
+ // But, we also want to leverage parallelism by slicing very large IO chunks.
+ // We define two more config parameters with suggested default values to control
+ // the slice size and seek to balance the two effects with the goal of maximizing
+ // net data load performance.
+ //
+ // BW_util_frac (ideal bandwidth utilization): Transfer bandwidth utilization fraction
+ // (per connection) to maximize the net data load. 90% is a good default number for
+ // an effective transfer bandwidth.
+ //
+ // MAX_IDEAL_REQUEST_SIZE: The maximum single data request size (in MiB) to maximize
+ // the net data load. 64 MiB is a good default number for the ideal request size.
+ //
+ // The amount of data that needs to be transferred in a single S3 get_object
+ // request to achieve effective bandwidth eff_BW = BW_util_frac * BW is as follows:
+ // eff_BW = range_size_limit / (TTFB + range_size_limit / BW)
+ //
+ // Substituting TTFB = hole_size_limit / BW and eff_BW = BW_util_frac * BW, we get the
+ // following result:
+ // range_size_limit = hole_size_limit * BW_util_frac / (1 - BW_util_frac)
+ //
+ // Applying the MAX_IDEAL_REQUEST_SIZE, we get the following:
+ // range_size_limit = min(MAX_IDEAL_REQUEST_SIZE,
+ // hole_size_limit * BW_util_frac / (1 - BW_util_frac))
+ //
+ DCHECK_GT(time_to_first_byte_millis, 0) << "TTFB must be > 0";
+ DCHECK_GT(transfer_bandwidth_mib_per_sec, 0) << "Transfer bandwidth must be > 0";
+ DCHECK_GT(ideal_bandwidth_utilization_frac, 0)
+ << "Ideal bandwidth utilization fraction must be > 0";
+ DCHECK_LT(ideal_bandwidth_utilization_frac, 1.0)
+ << "Ideal bandwidth utilization fraction must be < 1";
+ DCHECK_GT(max_ideal_request_size_mib, 0) << "Max Ideal request size must be > 0";
+
+ const double time_to_first_byte_sec = time_to_first_byte_millis / 1000.0;
+ const int64_t transfer_bandwidth_bytes_per_sec =
+ transfer_bandwidth_mib_per_sec * 1024 * 1024;
+ const int64_t max_ideal_request_size_bytes = max_ideal_request_size_mib * 1024 * 1024;
+
+ // hole_size_limit = TTFB * BW
+ const auto hole_size_limit = static_cast<int64_t>(
+ std::round(time_to_first_byte_sec * transfer_bandwidth_bytes_per_sec));
+ DCHECK_GT(hole_size_limit, 0) << "Computed hole_size_limit must be > 0";
+
+ // range_size_limit = min(MAX_IDEAL_REQUEST_SIZE,
+ // hole_size_limit * BW_util_frac / (1 - BW_util_frac))
+ const int64_t range_size_limit = std::min(
+ max_ideal_request_size_bytes,
+ static_cast<int64_t>(std::round(hole_size_limit * ideal_bandwidth_utilization_frac /
+ (1 - ideal_bandwidth_utilization_frac))));
+ DCHECK_GT(range_size_limit, 0) << "Computed range_size_limit must be > 0";
+
+ return {hole_size_limit, range_size_limit, false};
+}
+
+namespace internal {
+
+struct RangeCacheEntry {
+ ReadRange range;
+ Future<std::shared_ptr<Buffer>> future;
+
+ RangeCacheEntry() = default;
+ RangeCacheEntry(const ReadRange& range_, Future<std::shared_ptr<Buffer>> future_)
+ : range(range_), future(std::move(future_)) {}
+
+ friend bool operator<(const RangeCacheEntry& left, const RangeCacheEntry& right) {
+ return left.range.offset < right.range.offset;
+ }
+};
+
+struct ReadRangeCache::Impl {
+ std::shared_ptr<RandomAccessFile> file;
+ IOContext ctx;
+ CacheOptions options;
+
+ // Ordered by offset (so as to find a matching region by binary search)
+ std::vector<RangeCacheEntry> entries;
+
+ virtual ~Impl() = default;
+
+ // Get the future corresponding to a range
+ virtual Future<std::shared_ptr<Buffer>> MaybeRead(RangeCacheEntry* entry) {
+ return entry->future;
+ }
+
+ // Make cache entries for ranges
+ virtual std::vector<RangeCacheEntry> MakeCacheEntries(
+ const std::vector<ReadRange>& ranges) {
+ std::vector<RangeCacheEntry> new_entries;
+ new_entries.reserve(ranges.size());
+ for (const auto& range : ranges) {
+ new_entries.emplace_back(range, file->ReadAsync(ctx, range.offset, range.length));
+ }
+ return new_entries;
+ }
+
+ // Add the given ranges to the cache, coalescing them where possible
+ virtual Status Cache(std::vector<ReadRange> ranges) {
+ ranges = internal::CoalesceReadRanges(std::move(ranges), options.hole_size_limit,
+ options.range_size_limit);
+ std::vector<RangeCacheEntry> new_entries = MakeCacheEntries(ranges);
+ // Add new entries, themselves ordered by offset
+ if (entries.size() > 0) {
+ std::vector<RangeCacheEntry> merged(entries.size() + new_entries.size());
+ std::merge(entries.begin(), entries.end(), new_entries.begin(), new_entries.end(),
+ merged.begin());
+ entries = std::move(merged);
+ } else {
+ entries = std::move(new_entries);
+ }
+ // Prefetch immediately, regardless of executor availability, if possible
+ return file->WillNeed(ranges);
+ }
+
+ // Read the given range from the cache, blocking if needed. Cannot read a range
+ // that spans cache entries.
+ virtual Result<std::shared_ptr<Buffer>> Read(ReadRange range) {
+ if (range.length == 0) {
+ static const uint8_t byte = 0;
+ return std::make_shared<Buffer>(&byte, 0);
+ }
+
+ const auto it = std::lower_bound(
+ entries.begin(), entries.end(), range,
+ [](const RangeCacheEntry& entry, const ReadRange& range) {
+ return entry.range.offset + entry.range.length < range.offset + range.length;
+ });
+ if (it != entries.end() && it->range.Contains(range)) {
+ auto fut = MaybeRead(&*it);
+ ARROW_ASSIGN_OR_RAISE(auto buf, fut.result());
+ return SliceBuffer(std::move(buf), range.offset - it->range.offset, range.length);
+ }
+ return Status::Invalid("ReadRangeCache did not find matching cache entry");
+ }
+
+ virtual Future<> Wait() {
+ std::vector<Future<>> futures;
+ for (auto& entry : entries) {
+ futures.emplace_back(MaybeRead(&entry));
+ }
+ return AllComplete(futures);
+ }
+
+ // Return a Future that completes when the given ranges have been read.
+ virtual Future<> WaitFor(std::vector<ReadRange> ranges) {
+ auto end = std::remove_if(ranges.begin(), ranges.end(),
+ [](const ReadRange& range) { return range.length == 0; });
+ ranges.resize(end - ranges.begin());
+ std::vector<Future<>> futures;
+ futures.reserve(ranges.size());
+ for (auto& range : ranges) {
+ const auto it = std::lower_bound(
+ entries.begin(), entries.end(), range,
+ [](const RangeCacheEntry& entry, const ReadRange& range) {
+ return entry.range.offset + entry.range.length < range.offset + range.length;
+ });
+ if (it != entries.end() && it->range.Contains(range)) {
+ futures.push_back(Future<>(MaybeRead(&*it)));
+ } else {
+ return Status::Invalid("Range was not requested for caching: offset=",
+ range.offset, " length=", range.length);
+ }
+ }
+ return AllComplete(futures);
+ }
+};
+
+// Don't read ranges when they're first added. Instead, wait until they're requested
+// (either through Read or WaitFor).
+struct ReadRangeCache::LazyImpl : public ReadRangeCache::Impl {
+ // Protect against concurrent modification of entries[i]->future
+ std::mutex entry_mutex;
+
+ virtual ~LazyImpl() = default;
+
+ Future<std::shared_ptr<Buffer>> MaybeRead(RangeCacheEntry* entry) override {
+ // Called by superclass Read()/WaitFor() so we have the lock
+ if (!entry->future.is_valid()) {
+ entry->future = file->ReadAsync(ctx, entry->range.offset, entry->range.length);
+ }
+ return entry->future;
+ }
+
+ std::vector<RangeCacheEntry> MakeCacheEntries(
+ const std::vector<ReadRange>& ranges) override {
+ std::vector<RangeCacheEntry> new_entries;
+ new_entries.reserve(ranges.size());
+ for (const auto& range : ranges) {
+ // In the lazy variant, don't read data here - later, a call to Read or WaitFor
+ // will call back to MaybeRead (under the lock) which will fill the future.
+ new_entries.emplace_back(range, Future<std::shared_ptr<Buffer>>());
+ }
+ return new_entries;
+ }
+
+ Status Cache(std::vector<ReadRange> ranges) override {
+ std::unique_lock<std::mutex> guard(entry_mutex);
+ return ReadRangeCache::Impl::Cache(std::move(ranges));
+ }
+
+ Result<std::shared_ptr<Buffer>> Read(ReadRange range) override {
+ std::unique_lock<std::mutex> guard(entry_mutex);
+ return ReadRangeCache::Impl::Read(range);
+ }
+
+ Future<> Wait() override {
+ std::unique_lock<std::mutex> guard(entry_mutex);
+ return ReadRangeCache::Impl::Wait();
+ }
+
+ Future<> WaitFor(std::vector<ReadRange> ranges) override {
+ std::unique_lock<std::mutex> guard(entry_mutex);
+ return ReadRangeCache::Impl::WaitFor(std::move(ranges));
+ }
+};
+
+ReadRangeCache::ReadRangeCache(std::shared_ptr<RandomAccessFile> file, IOContext ctx,
+ CacheOptions options)
+ : impl_(options.lazy ? new LazyImpl() : new Impl()) {
+ impl_->file = std::move(file);
+ impl_->ctx = std::move(ctx);
+ impl_->options = options;
+}
+
+ReadRangeCache::~ReadRangeCache() = default;
+
+Status ReadRangeCache::Cache(std::vector<ReadRange> ranges) {
+ return impl_->Cache(std::move(ranges));
+}
+
+Result<std::shared_ptr<Buffer>> ReadRangeCache::Read(ReadRange range) {
+ return impl_->Read(range);
+}
+
+Future<> ReadRangeCache::Wait() { return impl_->Wait(); }
+
+Future<> ReadRangeCache::WaitFor(std::vector<ReadRange> ranges) {
+ return impl_->WaitFor(std::move(ranges));
+}
+
+} // namespace internal
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/caching.h b/src/arrow/cpp/src/arrow/io/caching.h
new file mode 100644
index 000000000..59a9b60e8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/caching.h
@@ -0,0 +1,138 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/io/interfaces.h"
+#include "arrow/util/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace io {
+
+struct ARROW_EXPORT CacheOptions {
+ static constexpr double kDefaultIdealBandwidthUtilizationFrac = 0.9;
+ static constexpr int64_t kDefaultMaxIdealRequestSizeMib = 64;
+
+ /// \brief The maximum distance in bytes between two consecutive
+ /// ranges; beyond this value, ranges are not combined
+ int64_t hole_size_limit;
+ /// \brief The maximum size in bytes of a combined range; if
+ /// combining two consecutive ranges would produce a range of a
+ /// size greater than this, they are not combined
+ int64_t range_size_limit;
+ /// \brief A lazy cache does not perform any I/O until requested.
+ bool lazy;
+
+ bool operator==(const CacheOptions& other) const {
+ return hole_size_limit == other.hole_size_limit &&
+ range_size_limit == other.range_size_limit && lazy == other.lazy;
+ }
+
+ /// \brief Construct CacheOptions from network storage metrics (e.g. S3).
+ ///
+ /// \param[in] time_to_first_byte_millis Seek-time or Time-To-First-Byte (TTFB) in
+ /// milliseconds, also called call setup latency of a new S3 request.
+ /// The value is a positive integer.
+ /// \param[in] transfer_bandwidth_mib_per_sec Data transfer Bandwidth (BW) in MiB/sec.
+ /// The value is a positive integer.
+ /// \param[in] ideal_bandwidth_utilization_frac Transfer bandwidth utilization fraction
+ /// (per connection) to maximize the net data load.
+ /// The value is a positive double precision number less than 1.
+ /// \param[in] max_ideal_request_size_mib The maximum single data request size (in MiB)
+ /// to maximize the net data load.
+ /// The value is a positive integer.
+ /// \return A new instance of CacheOptions.
+ static CacheOptions MakeFromNetworkMetrics(
+ int64_t time_to_first_byte_millis, int64_t transfer_bandwidth_mib_per_sec,
+ double ideal_bandwidth_utilization_frac = kDefaultIdealBandwidthUtilizationFrac,
+ int64_t max_ideal_request_size_mib = kDefaultMaxIdealRequestSizeMib);
+
+ static CacheOptions Defaults();
+ static CacheOptions LazyDefaults();
+};
+
+namespace internal {
+
+/// \brief A read cache designed to hide IO latencies when reading.
+///
+/// This class takes multiple byte ranges that an application expects to read, and
+/// coalesces them into fewer, larger read requests, which benefits performance on some
+/// filesystems, particularly remote ones like Amazon S3. By default, it also issues
+/// these read requests in parallel up front.
+///
+/// To use:
+/// 1. Cache() the ranges you expect to read in the future. Ideally, these ranges have
+/// the exact offset and length that will later be read. The cache will combine those
+/// ranges according to parameters (see constructor).
+///
+/// By default, the cache will also start fetching the combined ranges in parallel in
+/// the background, unless CacheOptions.lazy is set.
+///
+/// 2. Call WaitFor() to be notified when the given ranges have been read. If
+/// CacheOptions.lazy is set, I/O will be triggered in the background here instead.
+/// This can be done in parallel (e.g. if parsing a file, call WaitFor() for each
+/// chunk of the file that can be parsed in parallel).
+///
+/// 3. Call Read() to retrieve the actual data for the given ranges.
+/// A synchronous application may skip WaitFor() and just call Read() - it will still
+/// benefit from coalescing and parallel fetching.
+class ARROW_EXPORT ReadRangeCache {
+ public:
+ static constexpr int64_t kDefaultHoleSizeLimit = 8192;
+ static constexpr int64_t kDefaultRangeSizeLimit = 32 * 1024 * 1024;
+
+ /// Construct a read cache with default
+ explicit ReadRangeCache(std::shared_ptr<RandomAccessFile> file, IOContext ctx)
+ : ReadRangeCache(file, std::move(ctx), CacheOptions::Defaults()) {}
+
+ /// Construct a read cache with given options
+ explicit ReadRangeCache(std::shared_ptr<RandomAccessFile> file, IOContext ctx,
+ CacheOptions options);
+ ~ReadRangeCache();
+
+ /// \brief Cache the given ranges in the background.
+ ///
+ /// The caller must ensure that the ranges do not overlap with each other,
+ /// nor with previously cached ranges. Otherwise, behaviour will be undefined.
+ Status Cache(std::vector<ReadRange> ranges);
+
+ /// \brief Read a range previously given to Cache().
+ Result<std::shared_ptr<Buffer>> Read(ReadRange range);
+
+ /// \brief Wait until all ranges added so far have been cached.
+ Future<> Wait();
+
+ /// \brief Wait until all given ranges have been cached.
+ Future<> WaitFor(std::vector<ReadRange> ranges);
+
+ protected:
+ struct Impl;
+ struct LazyImpl;
+
+ std::unique_ptr<Impl> impl_;
+};
+
+} // namespace internal
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/compressed.cc b/src/arrow/cpp/src/arrow/io/compressed.cc
new file mode 100644
index 000000000..72977f0f2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/compressed.cc
@@ -0,0 +1,450 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/io/compressed.h"
+
+#include <algorithm>
+#include <cstring>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <utility>
+
+#include "arrow/buffer.h"
+#include "arrow/io/util_internal.h"
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/util/compression.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using util::Codec;
+using util::Compressor;
+using util::Decompressor;
+
+namespace io {
+
+// ----------------------------------------------------------------------
+// CompressedOutputStream implementation
+
+class CompressedOutputStream::Impl {
+ public:
+ Impl(MemoryPool* pool, const std::shared_ptr<OutputStream>& raw)
+ : pool_(pool), raw_(raw), is_open_(false), compressed_pos_(0), total_pos_(0) {}
+
+ Status Init(Codec* codec) {
+ ARROW_ASSIGN_OR_RAISE(compressor_, codec->MakeCompressor());
+ ARROW_ASSIGN_OR_RAISE(compressed_, AllocateResizableBuffer(kChunkSize, pool_));
+ compressed_pos_ = 0;
+ is_open_ = true;
+ return Status::OK();
+ }
+
+ Result<int64_t> Tell() const {
+ std::lock_guard<std::mutex> guard(lock_);
+ return total_pos_;
+ }
+
+ std::shared_ptr<OutputStream> raw() const { return raw_; }
+
+ Status FlushCompressed() {
+ if (compressed_pos_ > 0) {
+ RETURN_NOT_OK(raw_->Write(compressed_->data(), compressed_pos_));
+ compressed_pos_ = 0;
+ }
+ return Status::OK();
+ }
+
+ Status Write(const void* data, int64_t nbytes) {
+ std::lock_guard<std::mutex> guard(lock_);
+
+ auto input = reinterpret_cast<const uint8_t*>(data);
+ while (nbytes > 0) {
+ int64_t input_len = nbytes;
+ int64_t output_len = compressed_->size() - compressed_pos_;
+ uint8_t* output = compressed_->mutable_data() + compressed_pos_;
+ ARROW_ASSIGN_OR_RAISE(auto result,
+ compressor_->Compress(input_len, input, output_len, output));
+ compressed_pos_ += result.bytes_written;
+
+ if (result.bytes_read == 0) {
+ // Not enough output, try to flush it and retry
+ if (compressed_pos_ > 0) {
+ RETURN_NOT_OK(FlushCompressed());
+ output_len = compressed_->size() - compressed_pos_;
+ output = compressed_->mutable_data() + compressed_pos_;
+ ARROW_ASSIGN_OR_RAISE(
+ result, compressor_->Compress(input_len, input, output_len, output));
+ compressed_pos_ += result.bytes_written;
+ }
+ }
+ input += result.bytes_read;
+ nbytes -= result.bytes_read;
+ total_pos_ += result.bytes_read;
+ if (compressed_pos_ == compressed_->size()) {
+ // Output buffer full, flush it
+ RETURN_NOT_OK(FlushCompressed());
+ }
+ if (result.bytes_read == 0) {
+ // Need to enlarge output buffer
+ RETURN_NOT_OK(compressed_->Resize(compressed_->size() * 2));
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Flush() {
+ std::lock_guard<std::mutex> guard(lock_);
+
+ while (true) {
+ // Flush compressor
+ int64_t output_len = compressed_->size() - compressed_pos_;
+ uint8_t* output = compressed_->mutable_data() + compressed_pos_;
+ ARROW_ASSIGN_OR_RAISE(auto result, compressor_->Flush(output_len, output));
+ compressed_pos_ += result.bytes_written;
+
+ // Flush compressed output
+ RETURN_NOT_OK(FlushCompressed());
+
+ if (result.should_retry) {
+ // Need to enlarge output buffer
+ RETURN_NOT_OK(compressed_->Resize(compressed_->size() * 2));
+ } else {
+ break;
+ }
+ }
+ return Status::OK();
+ }
+
+ Status FinalizeCompression() {
+ while (true) {
+ // Try to end compressor
+ int64_t output_len = compressed_->size() - compressed_pos_;
+ uint8_t* output = compressed_->mutable_data() + compressed_pos_;
+ ARROW_ASSIGN_OR_RAISE(auto result, compressor_->End(output_len, output));
+ compressed_pos_ += result.bytes_written;
+
+ // Flush compressed output
+ RETURN_NOT_OK(FlushCompressed());
+
+ if (result.should_retry) {
+ // Need to enlarge output buffer
+ RETURN_NOT_OK(compressed_->Resize(compressed_->size() * 2));
+ } else {
+ // Done
+ break;
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Close() {
+ std::lock_guard<std::mutex> guard(lock_);
+
+ if (is_open_) {
+ is_open_ = false;
+ RETURN_NOT_OK(FinalizeCompression());
+ return raw_->Close();
+ } else {
+ return Status::OK();
+ }
+ }
+
+ Status Abort() {
+ std::lock_guard<std::mutex> guard(lock_);
+
+ if (is_open_) {
+ is_open_ = false;
+ return raw_->Abort();
+ } else {
+ return Status::OK();
+ }
+ }
+
+ bool closed() {
+ std::lock_guard<std::mutex> guard(lock_);
+ return !is_open_;
+ }
+
+ private:
+ // Write 64 KB compressed data at a time
+ static const int64_t kChunkSize = 64 * 1024;
+
+ MemoryPool* pool_;
+ std::shared_ptr<OutputStream> raw_;
+ bool is_open_;
+ std::shared_ptr<Compressor> compressor_;
+ std::shared_ptr<ResizableBuffer> compressed_;
+ int64_t compressed_pos_;
+ // Total number of bytes compressed
+ int64_t total_pos_;
+
+ mutable std::mutex lock_;
+};
+
+Result<std::shared_ptr<CompressedOutputStream>> CompressedOutputStream::Make(
+ util::Codec* codec, const std::shared_ptr<OutputStream>& raw, MemoryPool* pool) {
+ // CAUTION: codec is not owned
+ std::shared_ptr<CompressedOutputStream> res(new CompressedOutputStream);
+ res->impl_.reset(new Impl(pool, std::move(raw)));
+ RETURN_NOT_OK(res->impl_->Init(codec));
+ return res;
+}
+
+CompressedOutputStream::~CompressedOutputStream() { internal::CloseFromDestructor(this); }
+
+Status CompressedOutputStream::Close() { return impl_->Close(); }
+
+Status CompressedOutputStream::Abort() { return impl_->Abort(); }
+
+bool CompressedOutputStream::closed() const { return impl_->closed(); }
+
+Result<int64_t> CompressedOutputStream::Tell() const { return impl_->Tell(); }
+
+Status CompressedOutputStream::Write(const void* data, int64_t nbytes) {
+ return impl_->Write(data, nbytes);
+}
+
+Status CompressedOutputStream::Flush() { return impl_->Flush(); }
+
+std::shared_ptr<OutputStream> CompressedOutputStream::raw() const { return impl_->raw(); }
+
+// ----------------------------------------------------------------------
+// CompressedInputStream implementation
+
+class CompressedInputStream::Impl {
+ public:
+ Impl(MemoryPool* pool, const std::shared_ptr<InputStream>& raw)
+ : pool_(pool),
+ raw_(raw),
+ is_open_(true),
+ compressed_pos_(0),
+ decompressed_pos_(0),
+ total_pos_(0) {}
+
+ Status Init(Codec* codec) {
+ ARROW_ASSIGN_OR_RAISE(decompressor_, codec->MakeDecompressor());
+ fresh_decompressor_ = true;
+ return Status::OK();
+ }
+
+ Status Close() {
+ if (is_open_) {
+ is_open_ = false;
+ return raw_->Close();
+ } else {
+ return Status::OK();
+ }
+ }
+
+ Status Abort() {
+ if (is_open_) {
+ is_open_ = false;
+ return raw_->Abort();
+ } else {
+ return Status::OK();
+ }
+ }
+
+ bool closed() { return !is_open_; }
+
+ Result<int64_t> Tell() const { return total_pos_; }
+
+ // Read compressed data if necessary
+ Status EnsureCompressedData() {
+ int64_t compressed_avail = compressed_ ? compressed_->size() - compressed_pos_ : 0;
+ if (compressed_avail == 0) {
+ // No compressed data available, read a full chunk
+ ARROW_ASSIGN_OR_RAISE(compressed_, raw_->Read(kChunkSize));
+ compressed_pos_ = 0;
+ }
+ return Status::OK();
+ }
+
+ // Decompress some data from the compressed_ buffer.
+ // Call this function only if the decompressed_ buffer is empty.
+ Status DecompressData() {
+ int64_t decompress_size = kDecompressSize;
+
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(decompressed_,
+ AllocateResizableBuffer(decompress_size, pool_));
+ decompressed_pos_ = 0;
+
+ int64_t input_len = compressed_->size() - compressed_pos_;
+ const uint8_t* input = compressed_->data() + compressed_pos_;
+ int64_t output_len = decompressed_->size();
+ uint8_t* output = decompressed_->mutable_data();
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto result, decompressor_->Decompress(input_len, input, output_len, output));
+ compressed_pos_ += result.bytes_read;
+ if (result.bytes_read > 0) {
+ fresh_decompressor_ = false;
+ }
+ if (result.bytes_written > 0 || !result.need_more_output || input_len == 0) {
+ RETURN_NOT_OK(decompressed_->Resize(result.bytes_written));
+ break;
+ }
+ DCHECK_EQ(result.bytes_written, 0);
+ // Need to enlarge output buffer
+ decompress_size *= 2;
+ }
+ return Status::OK();
+ }
+
+ // Read a given number of bytes from the decompressed_ buffer.
+ int64_t ReadFromDecompressed(int64_t nbytes, uint8_t* out) {
+ int64_t readable = decompressed_ ? (decompressed_->size() - decompressed_pos_) : 0;
+ int64_t read_bytes = std::min(readable, nbytes);
+
+ if (read_bytes > 0) {
+ memcpy(out, decompressed_->data() + decompressed_pos_, read_bytes);
+ decompressed_pos_ += read_bytes;
+
+ if (decompressed_pos_ == decompressed_->size()) {
+ // Decompressed data is exhausted, release buffer
+ decompressed_.reset();
+ }
+ }
+
+ return read_bytes;
+ }
+
+ // Try to feed more data into the decompressed_ buffer.
+ Status RefillDecompressed(bool* has_data) {
+ // First try to read data from the decompressor
+ if (compressed_) {
+ if (decompressor_->IsFinished()) {
+ // We just went over the end of a previous compressed stream.
+ RETURN_NOT_OK(decompressor_->Reset());
+ fresh_decompressor_ = true;
+ }
+ RETURN_NOT_OK(DecompressData());
+ }
+ if (!decompressed_ || decompressed_->size() == 0) {
+ // Got nothing, need to read more compressed data
+ RETURN_NOT_OK(EnsureCompressedData());
+ if (compressed_pos_ == compressed_->size()) {
+ // No more data to decompress
+ if (!fresh_decompressor_ && !decompressor_->IsFinished()) {
+ return Status::IOError("Truncated compressed stream");
+ }
+ *has_data = false;
+ return Status::OK();
+ }
+ RETURN_NOT_OK(DecompressData());
+ }
+ *has_data = true;
+ return Status::OK();
+ }
+
+ Result<int64_t> Read(int64_t nbytes, void* out) {
+ auto out_data = reinterpret_cast<uint8_t*>(out);
+
+ int64_t total_read = 0;
+ bool decompressor_has_data = true;
+
+ while (nbytes - total_read > 0 && decompressor_has_data) {
+ total_read += ReadFromDecompressed(nbytes - total_read, out_data + total_read);
+
+ if (nbytes == total_read) {
+ break;
+ }
+
+ // At this point, no more decompressed data remains, so we need to
+ // decompress more
+ RETURN_NOT_OK(RefillDecompressed(&decompressor_has_data));
+ }
+
+ total_pos_ += total_read;
+ return total_read;
+ }
+
+ Result<std::shared_ptr<Buffer>> Read(int64_t nbytes) {
+ ARROW_ASSIGN_OR_RAISE(auto buf, AllocateResizableBuffer(nbytes, pool_));
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, Read(nbytes, buf->mutable_data()));
+ RETURN_NOT_OK(buf->Resize(bytes_read));
+ return std::move(buf);
+ }
+
+ std::shared_ptr<InputStream> raw() const { return raw_; }
+
+ private:
+ // Read 64 KB compressed data at a time
+ static const int64_t kChunkSize = 64 * 1024;
+ // Decompress 1 MB at a time
+ static const int64_t kDecompressSize = 1024 * 1024;
+
+ MemoryPool* pool_;
+ std::shared_ptr<InputStream> raw_;
+ bool is_open_;
+ std::shared_ptr<Decompressor> decompressor_;
+ std::shared_ptr<Buffer> compressed_;
+ // Position in compressed buffer
+ int64_t compressed_pos_;
+ std::shared_ptr<ResizableBuffer> decompressed_;
+ // Position in decompressed buffer
+ int64_t decompressed_pos_;
+ // True if the decompressor hasn't read any data yet.
+ bool fresh_decompressor_;
+ // Total number of bytes decompressed
+ int64_t total_pos_;
+};
+
+Result<std::shared_ptr<CompressedInputStream>> CompressedInputStream::Make(
+ Codec* codec, const std::shared_ptr<InputStream>& raw, MemoryPool* pool) {
+ // CAUTION: codec is not owned
+ std::shared_ptr<CompressedInputStream> res(new CompressedInputStream);
+ res->impl_.reset(new Impl(pool, std::move(raw)));
+ RETURN_NOT_OK(res->impl_->Init(codec));
+ return res;
+ return Status::OK();
+}
+
+CompressedInputStream::~CompressedInputStream() { internal::CloseFromDestructor(this); }
+
+Status CompressedInputStream::DoClose() { return impl_->Close(); }
+
+Status CompressedInputStream::DoAbort() { return impl_->Abort(); }
+
+bool CompressedInputStream::closed() const { return impl_->closed(); }
+
+Result<int64_t> CompressedInputStream::DoTell() const { return impl_->Tell(); }
+
+Result<int64_t> CompressedInputStream::DoRead(int64_t nbytes, void* out) {
+ return impl_->Read(nbytes, out);
+}
+
+Result<std::shared_ptr<Buffer>> CompressedInputStream::DoRead(int64_t nbytes) {
+ return impl_->Read(nbytes);
+}
+
+std::shared_ptr<InputStream> CompressedInputStream::raw() const { return impl_->raw(); }
+
+Result<std::shared_ptr<const KeyValueMetadata>> CompressedInputStream::ReadMetadata() {
+ return impl_->raw()->ReadMetadata();
+}
+
+Future<std::shared_ptr<const KeyValueMetadata>> CompressedInputStream::ReadMetadataAsync(
+ const IOContext& io_context) {
+ return impl_->raw()->ReadMetadataAsync(io_context);
+}
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/compressed.h b/src/arrow/cpp/src/arrow/io/compressed.h
new file mode 100644
index 000000000..cd1a7f673
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/compressed.h
@@ -0,0 +1,118 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Compressed stream implementations
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/io/concurrency.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class MemoryPool;
+class Status;
+
+namespace util {
+
+class Codec;
+
+} // namespace util
+
+namespace io {
+
+class ARROW_EXPORT CompressedOutputStream : public OutputStream {
+ public:
+ ~CompressedOutputStream() override;
+
+ /// \brief Create a compressed output stream wrapping the given output stream.
+ static Result<std::shared_ptr<CompressedOutputStream>> Make(
+ util::Codec* codec, const std::shared_ptr<OutputStream>& raw,
+ MemoryPool* pool = default_memory_pool());
+
+ // OutputStream interface
+
+ /// \brief Close the compressed output stream. This implicitly closes the
+ /// underlying raw output stream.
+ Status Close() override;
+ Status Abort() override;
+ bool closed() const override;
+
+ Result<int64_t> Tell() const override;
+
+ Status Write(const void* data, int64_t nbytes) override;
+ /// \cond FALSE
+ using Writable::Write;
+ /// \endcond
+ Status Flush() override;
+
+ /// \brief Return the underlying raw output stream.
+ std::shared_ptr<OutputStream> raw() const;
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(CompressedOutputStream);
+
+ CompressedOutputStream() = default;
+
+ class ARROW_NO_EXPORT Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+class ARROW_EXPORT CompressedInputStream
+ : public internal::InputStreamConcurrencyWrapper<CompressedInputStream> {
+ public:
+ ~CompressedInputStream() override;
+
+ /// \brief Create a compressed input stream wrapping the given input stream.
+ static Result<std::shared_ptr<CompressedInputStream>> Make(
+ util::Codec* codec, const std::shared_ptr<InputStream>& raw,
+ MemoryPool* pool = default_memory_pool());
+
+ // InputStream interface
+
+ bool closed() const override;
+ Result<std::shared_ptr<const KeyValueMetadata>> ReadMetadata() override;
+ Future<std::shared_ptr<const KeyValueMetadata>> ReadMetadataAsync(
+ const IOContext& io_context) override;
+
+ /// \brief Return the underlying raw input stream.
+ std::shared_ptr<InputStream> raw() const;
+
+ private:
+ friend InputStreamConcurrencyWrapper<CompressedInputStream>;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(CompressedInputStream);
+
+ CompressedInputStream() = default;
+
+ /// \brief Close the compressed input stream. This implicitly closes the
+ /// underlying raw input stream.
+ Status DoClose();
+ Status DoAbort() override;
+ Result<int64_t> DoTell() const;
+ Result<int64_t> DoRead(int64_t nbytes, void* out);
+ Result<std::shared_ptr<Buffer>> DoRead(int64_t nbytes);
+
+ class ARROW_NO_EXPORT Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/compressed_test.cc b/src/arrow/cpp/src/arrow/io/compressed_test.cc
new file mode 100644
index 000000000..c4ed60667
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/compressed_test.cc
@@ -0,0 +1,311 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/buffer.h"
+#include "arrow/io/compressed.h"
+#include "arrow/io/memory.h"
+#include "arrow/io/test_common.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/compression.h"
+
+namespace arrow {
+namespace io {
+
+using ::arrow::util::Codec;
+
+#ifdef ARROW_VALGRIND
+// Avoid slowing down tests too much with Valgrind
+static constexpr int64_t RANDOM_DATA_SIZE = 50 * 1024;
+static constexpr int64_t COMPRESSIBLE_DATA_SIZE = 120 * 1024;
+#else
+// The data should be large enough to exercise internal buffers
+static constexpr int64_t RANDOM_DATA_SIZE = 3 * 1024 * 1024;
+static constexpr int64_t COMPRESSIBLE_DATA_SIZE = 8 * 1024 * 1024;
+#endif
+
+std::vector<uint8_t> MakeRandomData(int data_size) {
+ std::vector<uint8_t> data(data_size);
+ random_bytes(data_size, 1234, data.data());
+ return data;
+}
+
+std::vector<uint8_t> MakeCompressibleData(int data_size) {
+ std::string base_data =
+ "Apache Arrow is a cross-language development platform for in-memory data";
+ int nrepeats = static_cast<int>(1 + data_size / base_data.size());
+
+ std::vector<uint8_t> data(base_data.size() * nrepeats);
+ for (int i = 0; i < nrepeats; ++i) {
+ std::memcpy(data.data() + i * base_data.size(), base_data.data(), base_data.size());
+ }
+ data.resize(data_size);
+ return data;
+}
+
+std::shared_ptr<Buffer> CompressDataOneShot(Codec* codec,
+ const std::vector<uint8_t>& data) {
+ int64_t max_compressed_len, compressed_len;
+ max_compressed_len = codec->MaxCompressedLen(data.size(), data.data());
+ auto compressed = *AllocateResizableBuffer(max_compressed_len);
+ compressed_len = *codec->Compress(data.size(), data.data(), max_compressed_len,
+ compressed->mutable_data());
+ ABORT_NOT_OK(compressed->Resize(compressed_len));
+ return std::move(compressed);
+}
+
+Status RunCompressedInputStream(Codec* codec, std::shared_ptr<Buffer> compressed,
+ int64_t* stream_pos, std::vector<uint8_t>* out) {
+ // Create compressed input stream
+ auto buffer_reader = std::make_shared<BufferReader>(compressed);
+ ARROW_ASSIGN_OR_RAISE(auto stream, CompressedInputStream::Make(codec, buffer_reader));
+
+ std::vector<uint8_t> decompressed;
+ int64_t decompressed_size = 0;
+ const int64_t chunk_size = 1111;
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(auto buf, stream->Read(chunk_size));
+ if (buf->size() == 0) {
+ // EOF
+ break;
+ }
+ decompressed.resize(decompressed_size + buf->size());
+ memcpy(decompressed.data() + decompressed_size, buf->data(), buf->size());
+ decompressed_size += buf->size();
+ }
+ if (stream_pos != nullptr) {
+ RETURN_NOT_OK(stream->Tell().Value(stream_pos));
+ }
+ *out = std::move(decompressed);
+ return Status::OK();
+}
+
+Status RunCompressedInputStream(Codec* codec, std::shared_ptr<Buffer> compressed,
+ std::vector<uint8_t>* out) {
+ return RunCompressedInputStream(codec, compressed, nullptr, out);
+}
+
+void CheckCompressedInputStream(Codec* codec, const std::vector<uint8_t>& data) {
+ // Create compressed data
+ auto compressed = CompressDataOneShot(codec, data);
+
+ std::vector<uint8_t> decompressed;
+ int64_t stream_pos = -1;
+ ASSERT_OK(RunCompressedInputStream(codec, compressed, &stream_pos, &decompressed));
+
+ ASSERT_EQ(decompressed.size(), data.size());
+ ASSERT_EQ(decompressed, data);
+ ASSERT_EQ(stream_pos, static_cast<int64_t>(decompressed.size()));
+}
+
+void CheckCompressedOutputStream(Codec* codec, const std::vector<uint8_t>& data,
+ bool do_flush) {
+ // Create compressed output stream
+ ASSERT_OK_AND_ASSIGN(auto buffer_writer, BufferOutputStream::Create());
+ ASSERT_OK_AND_ASSIGN(auto stream, CompressedOutputStream::Make(codec, buffer_writer));
+ ASSERT_OK_AND_EQ(0, stream->Tell());
+
+ const uint8_t* input = data.data();
+ int64_t input_len = data.size();
+ const int64_t chunk_size = 1111;
+ while (input_len > 0) {
+ int64_t nbytes = std::min(chunk_size, input_len);
+ ASSERT_OK(stream->Write(input, nbytes));
+ input += nbytes;
+ input_len -= nbytes;
+ if (do_flush) {
+ ASSERT_OK(stream->Flush());
+ }
+ }
+ ASSERT_OK_AND_EQ(static_cast<int64_t>(data.size()), stream->Tell());
+ ASSERT_OK(stream->Close());
+
+ // Get compressed data and decompress it
+ ASSERT_OK_AND_ASSIGN(auto compressed, buffer_writer->Finish());
+ std::vector<uint8_t> decompressed(data.size());
+ ASSERT_OK(codec->Decompress(compressed->size(), compressed->data(), decompressed.size(),
+ decompressed.data()));
+ ASSERT_EQ(decompressed, data);
+}
+
+class CompressedInputStreamTest : public ::testing::TestWithParam<Compression::type> {
+ protected:
+ Compression::type GetCompression() { return GetParam(); }
+
+ std::unique_ptr<Codec> MakeCodec() { return *Codec::Create(GetCompression()); }
+};
+
+class CompressedOutputStreamTest : public ::testing::TestWithParam<Compression::type> {
+ protected:
+ Compression::type GetCompression() { return GetParam(); }
+
+ std::unique_ptr<Codec> MakeCodec() { return *Codec::Create(GetCompression()); }
+};
+
+TEST_P(CompressedInputStreamTest, CompressibleData) {
+ auto codec = MakeCodec();
+ auto data = MakeCompressibleData(COMPRESSIBLE_DATA_SIZE);
+
+ CheckCompressedInputStream(codec.get(), data);
+}
+
+TEST_P(CompressedInputStreamTest, RandomData) {
+ auto codec = MakeCodec();
+ auto data = MakeRandomData(RANDOM_DATA_SIZE);
+
+ CheckCompressedInputStream(codec.get(), data);
+}
+
+TEST_P(CompressedInputStreamTest, TruncatedData) {
+ auto codec = MakeCodec();
+ auto data = MakeRandomData(10000);
+ auto compressed = CompressDataOneShot(codec.get(), data);
+ auto truncated = SliceBuffer(compressed, 0, compressed->size() - 3);
+
+ std::vector<uint8_t> decompressed;
+ ASSERT_RAISES(IOError, RunCompressedInputStream(codec.get(), truncated, &decompressed));
+}
+
+TEST_P(CompressedInputStreamTest, InvalidData) {
+ auto codec = MakeCodec();
+ auto compressed_data = MakeRandomData(100);
+
+ auto buffer_reader = std::make_shared<BufferReader>(Buffer::Wrap(compressed_data));
+ ASSERT_OK_AND_ASSIGN(auto stream,
+ CompressedInputStream::Make(codec.get(), buffer_reader));
+ ASSERT_RAISES(IOError, stream->Read(1024));
+}
+
+TEST_P(CompressedInputStreamTest, ConcatenatedStreams) {
+ // ARROW-5974: just like the "gunzip", "bzip2" and "xz" commands,
+ // decompressing concatenated compressed streams should yield the entire
+ // original data.
+ auto codec = MakeCodec();
+ auto data1 = MakeCompressibleData(100);
+ auto data2 = MakeCompressibleData(200);
+ auto compressed1 = CompressDataOneShot(codec.get(), data1);
+ auto compressed2 = CompressDataOneShot(codec.get(), data2);
+ std::vector<uint8_t> expected;
+ std::copy(data1.begin(), data1.end(), std::back_inserter(expected));
+ std::copy(data2.begin(), data2.end(), std::back_inserter(expected));
+
+ ASSERT_OK_AND_ASSIGN(auto concatenated, ConcatenateBuffers({compressed1, compressed2}));
+ std::vector<uint8_t> decompressed;
+ ASSERT_OK(RunCompressedInputStream(codec.get(), concatenated, &decompressed));
+ ASSERT_EQ(decompressed.size(), expected.size());
+ ASSERT_EQ(decompressed, expected);
+
+ // Same, but with an empty decompressed stream in the middle
+ auto compressed_empty = CompressDataOneShot(codec.get(), {});
+ ASSERT_OK_AND_ASSIGN(concatenated,
+ ConcatenateBuffers({compressed1, compressed_empty, compressed2}));
+ ASSERT_OK(RunCompressedInputStream(codec.get(), concatenated, &decompressed));
+ ASSERT_EQ(decompressed.size(), expected.size());
+ ASSERT_EQ(decompressed, expected);
+
+ // Same, but with an empty decompressed stream at the end
+ ASSERT_OK_AND_ASSIGN(concatenated,
+ ConcatenateBuffers({compressed1, compressed2, compressed_empty}));
+ ASSERT_OK(RunCompressedInputStream(codec.get(), concatenated, &decompressed));
+ ASSERT_EQ(decompressed.size(), expected.size());
+ ASSERT_EQ(decompressed, expected);
+}
+
+TEST_P(CompressedOutputStreamTest, CompressibleData) {
+ auto codec = MakeCodec();
+ auto data = MakeCompressibleData(COMPRESSIBLE_DATA_SIZE);
+
+ CheckCompressedOutputStream(codec.get(), data, false /* do_flush */);
+ CheckCompressedOutputStream(codec.get(), data, true /* do_flush */);
+}
+
+TEST_P(CompressedOutputStreamTest, RandomData) {
+ auto codec = MakeCodec();
+ auto data = MakeRandomData(RANDOM_DATA_SIZE);
+
+ CheckCompressedOutputStream(codec.get(), data, false /* do_flush */);
+ CheckCompressedOutputStream(codec.get(), data, true /* do_flush */);
+}
+
+// NOTES:
+// - Snappy doesn't support streaming decompression
+// - BZ2 doesn't support one-shot compression
+// - LZ4 raw format doesn't support streaming decompression
+
+#ifdef ARROW_WITH_SNAPPY
+TEST(TestSnappyInputStream, NotImplemented) {
+ std::unique_ptr<Codec> codec;
+ ASSERT_OK_AND_ASSIGN(codec, Codec::Create(Compression::SNAPPY));
+ std::shared_ptr<InputStream> stream = std::make_shared<BufferReader>("");
+ ASSERT_RAISES(NotImplemented, CompressedInputStream::Make(codec.get(), stream));
+}
+
+TEST(TestSnappyOutputStream, NotImplemented) {
+ std::unique_ptr<Codec> codec;
+ ASSERT_OK_AND_ASSIGN(codec, Codec::Create(Compression::SNAPPY));
+ std::shared_ptr<OutputStream> stream = std::make_shared<MockOutputStream>();
+ ASSERT_RAISES(NotImplemented, CompressedOutputStream::Make(codec.get(), stream));
+}
+#endif
+
+#if !defined ARROW_WITH_ZLIB && !defined ARROW_WITH_BROTLI && !defined ARROW_WITH_LZ4 && \
+ !defined ARROW_WITH_ZSTD
+GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CompressedInputStreamTest);
+GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CompressedOutputStreamTest);
+#endif
+
+#ifdef ARROW_WITH_ZLIB
+INSTANTIATE_TEST_SUITE_P(TestGZipInputStream, CompressedInputStreamTest,
+ ::testing::Values(Compression::GZIP));
+INSTANTIATE_TEST_SUITE_P(TestGZipOutputStream, CompressedOutputStreamTest,
+ ::testing::Values(Compression::GZIP));
+#endif
+
+#ifdef ARROW_WITH_BROTLI
+INSTANTIATE_TEST_SUITE_P(TestBrotliInputStream, CompressedInputStreamTest,
+ ::testing::Values(Compression::BROTLI));
+INSTANTIATE_TEST_SUITE_P(TestBrotliOutputStream, CompressedOutputStreamTest,
+ ::testing::Values(Compression::BROTLI));
+#endif
+
+#ifdef ARROW_WITH_LZ4
+INSTANTIATE_TEST_SUITE_P(TestLZ4InputStream, CompressedInputStreamTest,
+ ::testing::Values(Compression::LZ4_FRAME));
+INSTANTIATE_TEST_SUITE_P(TestLZ4OutputStream, CompressedOutputStreamTest,
+ ::testing::Values(Compression::LZ4_FRAME));
+#endif
+
+#ifdef ARROW_WITH_ZSTD
+INSTANTIATE_TEST_SUITE_P(TestZSTDInputStream, CompressedInputStreamTest,
+ ::testing::Values(Compression::ZSTD));
+INSTANTIATE_TEST_SUITE_P(TestZSTDOutputStream, CompressedOutputStreamTest,
+ ::testing::Values(Compression::ZSTD));
+#endif
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/concurrency.h b/src/arrow/cpp/src/arrow/io/concurrency.h
new file mode 100644
index 000000000..b41ad2c13
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/concurrency.h
@@ -0,0 +1,263 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/io/interfaces.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace io {
+namespace internal {
+
+template <class LockType>
+class SharedLockGuard {
+ public:
+ explicit SharedLockGuard(LockType* lock) : lock_(lock) { lock_->LockShared(); }
+
+ ~SharedLockGuard() { lock_->UnlockShared(); }
+
+ protected:
+ LockType* lock_;
+};
+
+template <class LockType>
+class ExclusiveLockGuard {
+ public:
+ explicit ExclusiveLockGuard(LockType* lock) : lock_(lock) { lock_->LockExclusive(); }
+
+ ~ExclusiveLockGuard() { lock_->UnlockExclusive(); }
+
+ protected:
+ LockType* lock_;
+};
+
+// Debug concurrency checker that marks "shared" and "exclusive" code sections,
+// aborting if the concurrency rules get violated. Does nothing in release mode.
+// Note that we intentionally use the same class declaration in debug and
+// release builds in order to avoid runtime failures when e.g. loading a
+// release-built DLL with a debug-built application, or the reverse.
+
+class ARROW_EXPORT SharedExclusiveChecker {
+ public:
+ SharedExclusiveChecker();
+ void LockShared();
+ void UnlockShared();
+ void LockExclusive();
+ void UnlockExclusive();
+
+ SharedLockGuard<SharedExclusiveChecker> shared_guard() {
+ return SharedLockGuard<SharedExclusiveChecker>(this);
+ }
+
+ ExclusiveLockGuard<SharedExclusiveChecker> exclusive_guard() {
+ return ExclusiveLockGuard<SharedExclusiveChecker>(this);
+ }
+
+ protected:
+ struct Impl;
+ std::shared_ptr<Impl> impl_;
+};
+
+// Concurrency wrappers for IO classes that check the correctness of
+// concurrent calls to various methods. It is not necessary to wrap all
+// IO classes with these, only a few core classes that get used in tests.
+//
+// We're not using virtual inheritance here as virtual bases have poorly
+// understood semantic overhead which we'd be passing on to implementers
+// and users of these interfaces. Instead, we just duplicate the method
+// wrappers between those two classes.
+
+template <class Derived>
+class ARROW_EXPORT InputStreamConcurrencyWrapper : public InputStream {
+ public:
+ Status Close() final {
+ auto guard = lock_.exclusive_guard();
+ return derived()->DoClose();
+ }
+
+ Status Abort() final {
+ auto guard = lock_.exclusive_guard();
+ return derived()->DoAbort();
+ }
+
+ Result<int64_t> Tell() const final {
+ auto guard = lock_.exclusive_guard();
+ return derived()->DoTell();
+ }
+
+ Result<int64_t> Read(int64_t nbytes, void* out) final {
+ auto guard = lock_.exclusive_guard();
+ return derived()->DoRead(nbytes, out);
+ }
+
+ Result<std::shared_ptr<Buffer>> Read(int64_t nbytes) final {
+ auto guard = lock_.exclusive_guard();
+ return derived()->DoRead(nbytes);
+ }
+
+ Result<util::string_view> Peek(int64_t nbytes) final {
+ auto guard = lock_.exclusive_guard();
+ return derived()->DoPeek(nbytes);
+ }
+
+ /*
+ Methods to implement in derived class:
+
+ Status DoClose();
+ Result<int64_t> DoTell() const;
+ Result<int64_t> DoRead(int64_t nbytes, void* out);
+ Result<std::shared_ptr<Buffer>> DoRead(int64_t nbytes);
+
+ And optionally:
+
+ Status DoAbort() override;
+ Result<util::string_view> DoPeek(int64_t nbytes) override;
+
+ These methods should be protected in the derived class and
+ InputStreamConcurrencyWrapper declared as a friend with
+
+ friend InputStreamConcurrencyWrapper<derived>;
+ */
+
+ protected:
+ // Default implementations. They are virtual because the derived class may
+ // have derived classes itself.
+ virtual Status DoAbort() { return derived()->DoClose(); }
+
+ virtual Result<util::string_view> DoPeek(int64_t ARROW_ARG_UNUSED(nbytes)) {
+ return Status::NotImplemented("Peek not implemented");
+ }
+
+ Derived* derived() { return ::arrow::internal::checked_cast<Derived*>(this); }
+
+ const Derived* derived() const {
+ return ::arrow::internal::checked_cast<const Derived*>(this);
+ }
+
+ mutable SharedExclusiveChecker lock_;
+};
+
+template <class Derived>
+class ARROW_EXPORT RandomAccessFileConcurrencyWrapper : public RandomAccessFile {
+ public:
+ Status Close() final {
+ auto guard = lock_.exclusive_guard();
+ return derived()->DoClose();
+ }
+
+ Status Abort() final {
+ auto guard = lock_.exclusive_guard();
+ return derived()->DoAbort();
+ }
+
+ Result<int64_t> Tell() const final {
+ auto guard = lock_.exclusive_guard();
+ return derived()->DoTell();
+ }
+
+ Result<int64_t> Read(int64_t nbytes, void* out) final {
+ auto guard = lock_.exclusive_guard();
+ return derived()->DoRead(nbytes, out);
+ }
+
+ Result<std::shared_ptr<Buffer>> Read(int64_t nbytes) final {
+ auto guard = lock_.exclusive_guard();
+ return derived()->DoRead(nbytes);
+ }
+
+ Result<util::string_view> Peek(int64_t nbytes) final {
+ auto guard = lock_.exclusive_guard();
+ return derived()->DoPeek(nbytes);
+ }
+
+ Status Seek(int64_t position) final {
+ auto guard = lock_.exclusive_guard();
+ return derived()->DoSeek(position);
+ }
+
+ Result<int64_t> GetSize() final {
+ auto guard = lock_.shared_guard();
+ return derived()->DoGetSize();
+ }
+
+ // NOTE: ReadAt doesn't use stream pointer, but it is allowed to update it
+ // (it's the case on Windows when using ReadFileEx).
+ // So any method that relies on the current position (even if it doesn't
+ // update it, such as Peek) cannot run in parallel with ReadAt and has
+ // to use the exclusive_guard.
+
+ Result<int64_t> ReadAt(int64_t position, int64_t nbytes, void* out) final {
+ auto guard = lock_.shared_guard();
+ return derived()->DoReadAt(position, nbytes, out);
+ }
+
+ Result<std::shared_ptr<Buffer>> ReadAt(int64_t position, int64_t nbytes) final {
+ auto guard = lock_.shared_guard();
+ return derived()->DoReadAt(position, nbytes);
+ }
+
+ /*
+ Methods to implement in derived class:
+
+ Status DoClose();
+ Result<int64_t> DoTell() const;
+ Result<int64_t> DoRead(int64_t nbytes, void* out);
+ Result<std::shared_ptr<Buffer>> DoRead(int64_t nbytes);
+ Status DoSeek(int64_t position);
+ Result<int64_t> DoGetSize()
+ Result<int64_t> DoReadAt(int64_t position, int64_t nbytes, void* out);
+ Result<std::shared_ptr<Buffer>> DoReadAt(int64_t position, int64_t nbytes);
+
+ And optionally:
+
+ Status DoAbort() override;
+ Result<util::string_view> DoPeek(int64_t nbytes) override;
+
+ These methods should be protected in the derived class and
+ RandomAccessFileConcurrencyWrapper declared as a friend with
+
+ friend RandomAccessFileConcurrencyWrapper<derived>;
+ */
+
+ protected:
+ // Default implementations. They are virtual because the derived class may
+ // have derived classes itself.
+ virtual Status DoAbort() { return derived()->DoClose(); }
+
+ virtual Result<util::string_view> DoPeek(int64_t ARROW_ARG_UNUSED(nbytes)) {
+ return Status::NotImplemented("Peek not implemented");
+ }
+
+ Derived* derived() { return ::arrow::internal::checked_cast<Derived*>(this); }
+
+ const Derived* derived() const {
+ return ::arrow::internal::checked_cast<const Derived*>(this);
+ }
+
+ mutable SharedExclusiveChecker lock_;
+};
+
+} // namespace internal
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/file.cc b/src/arrow/cpp/src/arrow/io/file.cc
new file mode 100644
index 000000000..effbfd30b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/file.cc
@@ -0,0 +1,789 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/windows_compatibility.h" // IWYU pragma: keep
+
+// sys/mman.h not present in Visual Studio or Cygwin
+#ifdef _WIN32
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#include "arrow/io/mman.h"
+#undef Realloc
+#undef Free
+#else
+#include <fcntl.h>
+#include <sys/mman.h>
+#include <unistd.h> // IWYU pragma: keep
+#endif
+
+#include <algorithm>
+#include <atomic>
+#include <cerrno>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <utility>
+
+// ----------------------------------------------------------------------
+// Other Arrow includes
+
+#include "arrow/io/file.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/io/util_internal.h"
+
+#include "arrow/buffer.h"
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/util/future.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::IOErrorFromErrno;
+
+namespace io {
+
+class OSFile {
+ public:
+ OSFile() : fd_(-1), is_open_(false), size_(-1), need_seeking_(false) {}
+
+ ~OSFile() {}
+
+ // Note: only one of the Open* methods below may be called on a given instance
+
+ Status OpenWritable(const std::string& path, bool truncate, bool append,
+ bool write_only) {
+ RETURN_NOT_OK(SetFileName(path));
+
+ ARROW_ASSIGN_OR_RAISE(fd_, ::arrow::internal::FileOpenWritable(file_name_, write_only,
+ truncate, append));
+ is_open_ = true;
+ mode_ = write_only ? FileMode::WRITE : FileMode::READWRITE;
+
+ if (!truncate) {
+ ARROW_ASSIGN_OR_RAISE(size_, ::arrow::internal::FileGetSize(fd_));
+ } else {
+ size_ = 0;
+ }
+ return Status::OK();
+ }
+
+ // This is different from OpenWritable(string, ...) in that it doesn't
+ // truncate nor mandate a seekable file
+ Status OpenWritable(int fd) {
+ auto result = ::arrow::internal::FileGetSize(fd);
+ if (result.ok()) {
+ size_ = *result;
+ } else {
+ // Non-seekable file
+ size_ = -1;
+ }
+ RETURN_NOT_OK(SetFileName(fd));
+ is_open_ = true;
+ mode_ = FileMode::WRITE;
+ fd_ = fd;
+ return Status::OK();
+ }
+
+ Status OpenReadable(const std::string& path) {
+ RETURN_NOT_OK(SetFileName(path));
+
+ ARROW_ASSIGN_OR_RAISE(fd_, ::arrow::internal::FileOpenReadable(file_name_));
+ ARROW_ASSIGN_OR_RAISE(size_, ::arrow::internal::FileGetSize(fd_));
+
+ is_open_ = true;
+ mode_ = FileMode::READ;
+ return Status::OK();
+ }
+
+ Status OpenReadable(int fd) {
+ ARROW_ASSIGN_OR_RAISE(size_, ::arrow::internal::FileGetSize(fd));
+ RETURN_NOT_OK(SetFileName(fd));
+ is_open_ = true;
+ mode_ = FileMode::READ;
+ fd_ = fd;
+ return Status::OK();
+ }
+
+ Status CheckClosed() const {
+ if (!is_open_) {
+ return Status::Invalid("Invalid operation on closed file");
+ }
+ return Status::OK();
+ }
+
+ Status Close() {
+ if (is_open_) {
+ // Even if closing fails, the fd will likely be closed (perhaps it's
+ // already closed).
+ is_open_ = false;
+ int fd = fd_;
+ fd_ = -1;
+ RETURN_NOT_OK(::arrow::internal::FileClose(fd));
+ }
+ return Status::OK();
+ }
+
+ Result<int64_t> Read(int64_t nbytes, void* out) {
+ RETURN_NOT_OK(CheckClosed());
+ RETURN_NOT_OK(CheckPositioned());
+ return ::arrow::internal::FileRead(fd_, reinterpret_cast<uint8_t*>(out), nbytes);
+ }
+
+ Result<int64_t> ReadAt(int64_t position, int64_t nbytes, void* out) {
+ RETURN_NOT_OK(CheckClosed());
+ RETURN_NOT_OK(internal::ValidateRange(position, nbytes));
+ // ReadAt() leaves the file position undefined, so require that we seek
+ // before calling Read() or Write().
+ need_seeking_.store(true);
+ return ::arrow::internal::FileReadAt(fd_, reinterpret_cast<uint8_t*>(out), position,
+ nbytes);
+ }
+
+ Status Seek(int64_t pos) {
+ RETURN_NOT_OK(CheckClosed());
+ if (pos < 0) {
+ return Status::Invalid("Invalid position");
+ }
+ Status st = ::arrow::internal::FileSeek(fd_, pos);
+ if (st.ok()) {
+ need_seeking_.store(false);
+ }
+ return st;
+ }
+
+ Result<int64_t> Tell() const {
+ RETURN_NOT_OK(CheckClosed());
+ return ::arrow::internal::FileTell(fd_);
+ }
+
+ Status Write(const void* data, int64_t length) {
+ RETURN_NOT_OK(CheckClosed());
+
+ std::lock_guard<std::mutex> guard(lock_);
+ RETURN_NOT_OK(CheckPositioned());
+ if (length < 0) {
+ return Status::IOError("Length must be non-negative");
+ }
+ return ::arrow::internal::FileWrite(fd_, reinterpret_cast<const uint8_t*>(data),
+ length);
+ }
+
+ int fd() const { return fd_; }
+
+ bool is_open() const { return is_open_; }
+
+ int64_t size() const { return size_; }
+
+ FileMode::type mode() const { return mode_; }
+
+ std::mutex& lock() { return lock_; }
+
+ protected:
+ Status SetFileName(const std::string& file_name) {
+ return ::arrow::internal::PlatformFilename::FromString(file_name).Value(&file_name_);
+ }
+
+ Status SetFileName(int fd) {
+ std::stringstream ss;
+ ss << "<fd " << fd << ">";
+ return SetFileName(ss.str());
+ }
+
+ Status CheckPositioned() {
+ if (need_seeking_.load()) {
+ return Status::Invalid(
+ "Need seeking after ReadAt() before "
+ "calling implicitly-positioned operation");
+ }
+ return Status::OK();
+ }
+
+ ::arrow::internal::PlatformFilename file_name_;
+
+ std::mutex lock_;
+
+ // File descriptor
+ int fd_;
+
+ FileMode::type mode_;
+
+ bool is_open_;
+ int64_t size_;
+ // Whether ReadAt made the file position non-deterministic.
+ std::atomic<bool> need_seeking_;
+};
+
+// ----------------------------------------------------------------------
+// ReadableFile implementation
+
+class ReadableFile::ReadableFileImpl : public OSFile {
+ public:
+ explicit ReadableFileImpl(MemoryPool* pool) : OSFile(), pool_(pool) {}
+
+ Status Open(const std::string& path) { return OpenReadable(path); }
+ Status Open(int fd) { return OpenReadable(fd); }
+
+ Result<std::shared_ptr<Buffer>> ReadBuffer(int64_t nbytes) {
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateResizableBuffer(nbytes, pool_));
+
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, Read(nbytes, buffer->mutable_data()));
+ if (bytes_read < nbytes) {
+ RETURN_NOT_OK(buffer->Resize(bytes_read));
+ buffer->ZeroPadding();
+ }
+ return std::move(buffer);
+ }
+
+ Result<std::shared_ptr<Buffer>> ReadBufferAt(int64_t position, int64_t nbytes) {
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateResizableBuffer(nbytes, pool_));
+
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read,
+ ReadAt(position, nbytes, buffer->mutable_data()));
+ if (bytes_read < nbytes) {
+ RETURN_NOT_OK(buffer->Resize(bytes_read));
+ buffer->ZeroPadding();
+ }
+ return std::move(buffer);
+ }
+
+ Status WillNeed(const std::vector<ReadRange>& ranges) {
+ auto report_error = [](int errnum, const char* msg) -> Status {
+ if (errnum == EBADF || errnum == EINVAL) {
+ // These are logic errors, so raise them
+ return IOErrorFromErrno(errnum, msg);
+ }
+#ifndef NDEBUG
+ // Other errors may be encountered if the target device or filesystem
+ // does not support fadvise advisory (for example, macOS can return
+ // ENOTTY on macOS: ARROW-13983). Log the error for diagnosis
+ // on debug builds, but avoid bothering the user otherwise.
+ ARROW_LOG(WARNING) << IOErrorFromErrno(errnum, msg).ToString();
+#else
+ ARROW_UNUSED(msg);
+#endif
+ return Status::OK();
+ };
+ RETURN_NOT_OK(CheckClosed());
+ for (const auto& range : ranges) {
+ RETURN_NOT_OK(internal::ValidateRange(range.offset, range.length));
+#if defined(POSIX_FADV_WILLNEED)
+ int ret = posix_fadvise(fd_, range.offset, range.length, POSIX_FADV_WILLNEED);
+ if (ret) {
+ RETURN_NOT_OK(report_error(ret, "posix_fadvise failed"));
+ }
+#elif defined(F_RDADVISE) // macOS, BSD?
+ struct {
+ off_t ra_offset;
+ int ra_count;
+ } radvisory{range.offset, static_cast<int>(range.length)};
+ if (radvisory.ra_count > 0 && fcntl(fd_, F_RDADVISE, &radvisory) == -1) {
+ RETURN_NOT_OK(report_error(errno, "fcntl(fd, F_RDADVISE, ...) failed"));
+ }
+#endif
+ }
+ return Status::OK();
+ }
+
+ private:
+ MemoryPool* pool_;
+};
+
+ReadableFile::ReadableFile(MemoryPool* pool) { impl_.reset(new ReadableFileImpl(pool)); }
+
+ReadableFile::~ReadableFile() { internal::CloseFromDestructor(this); }
+
+Result<std::shared_ptr<ReadableFile>> ReadableFile::Open(const std::string& path,
+ MemoryPool* pool) {
+ auto file = std::shared_ptr<ReadableFile>(new ReadableFile(pool));
+ RETURN_NOT_OK(file->impl_->Open(path));
+ return file;
+}
+
+Result<std::shared_ptr<ReadableFile>> ReadableFile::Open(int fd, MemoryPool* pool) {
+ auto file = std::shared_ptr<ReadableFile>(new ReadableFile(pool));
+ RETURN_NOT_OK(file->impl_->Open(fd));
+ return file;
+}
+
+Status ReadableFile::DoClose() { return impl_->Close(); }
+
+bool ReadableFile::closed() const { return !impl_->is_open(); }
+
+Status ReadableFile::WillNeed(const std::vector<ReadRange>& ranges) {
+ return impl_->WillNeed(ranges);
+}
+
+Result<int64_t> ReadableFile::DoTell() const { return impl_->Tell(); }
+
+Result<int64_t> ReadableFile::DoRead(int64_t nbytes, void* out) {
+ return impl_->Read(nbytes, out);
+}
+
+Result<int64_t> ReadableFile::DoReadAt(int64_t position, int64_t nbytes, void* out) {
+ return impl_->ReadAt(position, nbytes, out);
+}
+
+Result<std::shared_ptr<Buffer>> ReadableFile::DoReadAt(int64_t position, int64_t nbytes) {
+ return impl_->ReadBufferAt(position, nbytes);
+}
+
+Result<std::shared_ptr<Buffer>> ReadableFile::DoRead(int64_t nbytes) {
+ return impl_->ReadBuffer(nbytes);
+}
+
+Result<int64_t> ReadableFile::DoGetSize() { return impl_->size(); }
+
+Status ReadableFile::DoSeek(int64_t pos) { return impl_->Seek(pos); }
+
+int ReadableFile::file_descriptor() const { return impl_->fd(); }
+
+// ----------------------------------------------------------------------
+// FileOutputStream
+
+class FileOutputStream::FileOutputStreamImpl : public OSFile {
+ public:
+ Status Open(const std::string& path, bool append) {
+ const bool truncate = !append;
+ return OpenWritable(path, truncate, append, true /* write_only */);
+ }
+ Status Open(int fd) { return OpenWritable(fd); }
+};
+
+FileOutputStream::FileOutputStream() { impl_.reset(new FileOutputStreamImpl()); }
+
+FileOutputStream::~FileOutputStream() { internal::CloseFromDestructor(this); }
+
+Result<std::shared_ptr<FileOutputStream>> FileOutputStream::Open(const std::string& path,
+ bool append) {
+ auto stream = std::shared_ptr<FileOutputStream>(new FileOutputStream());
+ RETURN_NOT_OK(stream->impl_->Open(path, append));
+ return stream;
+}
+
+Result<std::shared_ptr<FileOutputStream>> FileOutputStream::Open(int fd) {
+ auto stream = std::shared_ptr<FileOutputStream>(new FileOutputStream());
+ RETURN_NOT_OK(stream->impl_->Open(fd));
+ return stream;
+}
+
+Status FileOutputStream::Close() { return impl_->Close(); }
+
+bool FileOutputStream::closed() const { return !impl_->is_open(); }
+
+Result<int64_t> FileOutputStream::Tell() const { return impl_->Tell(); }
+
+Status FileOutputStream::Write(const void* data, int64_t length) {
+ return impl_->Write(data, length);
+}
+
+int FileOutputStream::file_descriptor() const { return impl_->fd(); }
+
+// ----------------------------------------------------------------------
+// Implement MemoryMappedFile
+
+class MemoryMappedFile::MemoryMap
+ : public std::enable_shared_from_this<MemoryMappedFile::MemoryMap> {
+ public:
+ // An object representing the entire memory-mapped region.
+ // It can be sliced in order to return individual subregions, which
+ // will then keep the original region alive as long as necessary.
+ class Region : public Buffer {
+ public:
+ Region(std::shared_ptr<MemoryMappedFile::MemoryMap> memory_map, uint8_t* data,
+ int64_t size)
+ : Buffer(data, size) {
+ is_mutable_ = memory_map->writable();
+ }
+
+ ~Region() {
+ if (data_ != nullptr) {
+ int result = munmap(data(), static_cast<size_t>(size_));
+ ARROW_CHECK_EQ(result, 0) << "munmap failed";
+ }
+ }
+
+ // For convenience
+ uint8_t* data() { return const_cast<uint8_t*>(data_); }
+
+ void Detach() { data_ = nullptr; }
+ };
+
+ MemoryMap() : file_size_(0), map_len_(0) {}
+
+ ~MemoryMap() { ARROW_CHECK_OK(Close()); }
+
+ Status Close() {
+ if (file_->is_open()) {
+ // Lose our reference to the MemoryMappedRegion, so that munmap()
+ // is called as soon as all buffer exports are released.
+ region_.reset();
+ return file_->Close();
+ } else {
+ return Status::OK();
+ }
+ }
+
+ bool closed() const { return !file_->is_open(); }
+
+ Status CheckClosed() const {
+ if (closed()) {
+ return Status::Invalid("Invalid operation on closed file");
+ }
+ return Status::OK();
+ }
+
+ Status Open(const std::string& path, FileMode::type mode, const int64_t offset = 0,
+ const int64_t length = -1) {
+ file_.reset(new OSFile());
+
+ if (mode != FileMode::READ) {
+ // Memory mapping has permission failures if PROT_READ not set
+ prot_flags_ = PROT_READ | PROT_WRITE;
+ map_mode_ = MAP_SHARED;
+ constexpr bool append = false;
+ constexpr bool truncate = false;
+ constexpr bool write_only = false;
+ RETURN_NOT_OK(file_->OpenWritable(path, truncate, append, write_only));
+ } else {
+ prot_flags_ = PROT_READ;
+ map_mode_ = MAP_PRIVATE; // Changes are not to be committed back to the file
+ RETURN_NOT_OK(file_->OpenReadable(path));
+ }
+ map_len_ = offset_ = 0;
+
+ // Memory mapping fails when file size is 0
+ // delay it until the first resize
+ if (file_->size() > 0) {
+ RETURN_NOT_OK(InitMMap(file_->size(), false, offset, length));
+ }
+
+ position_ = 0;
+
+ return Status::OK();
+ }
+
+ // Resize the mmap and file to the specified size.
+ // Resize on memory mapped file region is not supported.
+ Status Resize(const int64_t new_size) {
+ if (!writable()) {
+ return Status::IOError("Cannot resize a readonly memory map");
+ }
+ if (map_len_ != file_size_) {
+ return Status::IOError("Cannot resize a partial memory map");
+ }
+ if (region_.use_count() > 1) {
+ // There are buffer exports currently, the MemoryMapRemap() call
+ // would make the buffers invalid
+ return Status::IOError("Cannot resize memory map while there are active readers");
+ }
+
+ if (new_size == 0) {
+ if (map_len_ > 0) {
+ // Just unmap the mmap and truncate the file to 0 size
+ region_.reset();
+ RETURN_NOT_OK(::arrow::internal::FileTruncate(file_->fd(), 0));
+ map_len_ = offset_ = file_size_ = 0;
+ }
+ position_ = 0;
+ return Status::OK();
+ }
+
+ if (map_len_ > 0) {
+ void* result;
+ auto data = region_->data();
+ RETURN_NOT_OK(::arrow::internal::MemoryMapRemap(data, map_len_, new_size,
+ file_->fd(), &result));
+ region_->Detach(); // avoid munmap() on destruction
+ region_ = std::make_shared<Region>(shared_from_this(),
+ static_cast<uint8_t*>(result), new_size);
+ map_len_ = file_size_ = new_size;
+ offset_ = 0;
+ if (position_ > map_len_) {
+ position_ = map_len_;
+ }
+ } else {
+ DCHECK_EQ(position_, 0);
+ // the mmap is not yet initialized, resize the underlying
+ // file, since it might have been 0-sized
+ RETURN_NOT_OK(InitMMap(new_size, /*resize_file*/ true));
+ }
+ return Status::OK();
+ }
+
+ Status Seek(int64_t position) {
+ if (position < 0) {
+ return Status::Invalid("position is out of bounds");
+ }
+ position_ = position;
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<Buffer>> Slice(int64_t offset, int64_t length) {
+ length = std::max<int64_t>(0, std::min(length, map_len_ - offset));
+
+ if (length > 0) {
+ DCHECK_NE(region_, nullptr);
+ return SliceBuffer(region_, offset, length);
+ } else {
+ return std::make_shared<Buffer>(nullptr, 0);
+ }
+ }
+
+ // map_len_ == file_size_ if memory mapping on the whole file
+ int64_t size() const { return map_len_; }
+
+ int64_t position() { return position_; }
+
+ void advance(int64_t nbytes) { position_ = position_ + nbytes; }
+
+ uint8_t* data() { return region_ ? region_->data() : nullptr; }
+
+ uint8_t* head() { return data() + position_; }
+
+ bool writable() { return file_->mode() != FileMode::READ; }
+
+ bool opened() { return file_->is_open(); }
+
+ int fd() const { return file_->fd(); }
+
+ std::mutex& write_lock() { return file_->lock(); }
+
+ std::mutex& resize_lock() { return resize_lock_; }
+
+ private:
+ // Initialize the mmap and set size, capacity and the data pointers
+ Status InitMMap(int64_t initial_size, bool resize_file = false,
+ const int64_t offset = 0, const int64_t length = -1) {
+ DCHECK(!region_);
+
+ if (resize_file) {
+ RETURN_NOT_OK(::arrow::internal::FileTruncate(file_->fd(), initial_size));
+ }
+
+ size_t mmap_length = static_cast<size_t>(initial_size);
+ if (length > initial_size) {
+ return Status::Invalid("mapping length is beyond file size");
+ }
+ if (length >= 0 && length < initial_size) {
+ // memory mapping a file region
+ mmap_length = static_cast<size_t>(length);
+ }
+
+ void* result = mmap(nullptr, mmap_length, prot_flags_, map_mode_, file_->fd(),
+ static_cast<off_t>(offset));
+ if (result == MAP_FAILED) {
+ return Status::IOError("Memory mapping file failed: ",
+ ::arrow::internal::ErrnoMessage(errno));
+ }
+ map_len_ = mmap_length;
+ offset_ = offset;
+ region_ = std::make_shared<Region>(shared_from_this(), static_cast<uint8_t*>(result),
+ map_len_);
+ file_size_ = initial_size;
+
+ return Status::OK();
+ }
+
+ std::unique_ptr<OSFile> file_;
+ int prot_flags_;
+ int map_mode_;
+
+ std::shared_ptr<Region> region_;
+ int64_t file_size_;
+ int64_t position_;
+ int64_t offset_;
+ int64_t map_len_;
+ std::mutex resize_lock_;
+};
+
+MemoryMappedFile::MemoryMappedFile() {}
+
+MemoryMappedFile::~MemoryMappedFile() { internal::CloseFromDestructor(this); }
+
+Result<std::shared_ptr<MemoryMappedFile>> MemoryMappedFile::Create(
+ const std::string& path, int64_t size) {
+ ARROW_ASSIGN_OR_RAISE(auto file, FileOutputStream::Open(path));
+ RETURN_NOT_OK(::arrow::internal::FileTruncate(file->file_descriptor(), size));
+ RETURN_NOT_OK(file->Close());
+ return MemoryMappedFile::Open(path, FileMode::READWRITE);
+}
+
+Result<std::shared_ptr<MemoryMappedFile>> MemoryMappedFile::Open(const std::string& path,
+ FileMode::type mode) {
+ std::shared_ptr<MemoryMappedFile> result(new MemoryMappedFile());
+
+ result->memory_map_.reset(new MemoryMap());
+ RETURN_NOT_OK(result->memory_map_->Open(path, mode));
+ return result;
+}
+
+Result<std::shared_ptr<MemoryMappedFile>> MemoryMappedFile::Open(const std::string& path,
+ FileMode::type mode,
+ const int64_t offset,
+ const int64_t length) {
+ std::shared_ptr<MemoryMappedFile> result(new MemoryMappedFile());
+
+ result->memory_map_.reset(new MemoryMap());
+ RETURN_NOT_OK(result->memory_map_->Open(path, mode, offset, length));
+ return result;
+}
+
+Result<int64_t> MemoryMappedFile::GetSize() {
+ RETURN_NOT_OK(memory_map_->CheckClosed());
+ return memory_map_->size();
+}
+
+Result<int64_t> MemoryMappedFile::Tell() const {
+ RETURN_NOT_OK(memory_map_->CheckClosed());
+ return memory_map_->position();
+}
+
+Status MemoryMappedFile::Seek(int64_t position) {
+ RETURN_NOT_OK(memory_map_->CheckClosed());
+ return memory_map_->Seek(position);
+}
+
+Status MemoryMappedFile::Close() { return memory_map_->Close(); }
+
+bool MemoryMappedFile::closed() const { return memory_map_->closed(); }
+
+Result<std::shared_ptr<Buffer>> MemoryMappedFile::ReadAt(int64_t position,
+ int64_t nbytes) {
+ RETURN_NOT_OK(memory_map_->CheckClosed());
+ // if the file is writable, we acquire the lock before creating any slices
+ // in case a resize is triggered concurrently, otherwise we wouldn't detect
+ // a change in the use count
+ auto guard_resize = memory_map_->writable()
+ ? std::unique_lock<std::mutex>(memory_map_->resize_lock())
+ : std::unique_lock<std::mutex>();
+
+ ARROW_ASSIGN_OR_RAISE(
+ nbytes, internal::ValidateReadRange(position, nbytes, memory_map_->size()));
+ // Arrange to page data in
+ RETURN_NOT_OK(::arrow::internal::MemoryAdviseWillNeed(
+ {{memory_map_->data() + position, static_cast<size_t>(nbytes)}}));
+ return memory_map_->Slice(position, nbytes);
+}
+
+Result<int64_t> MemoryMappedFile::ReadAt(int64_t position, int64_t nbytes, void* out) {
+ RETURN_NOT_OK(memory_map_->CheckClosed());
+ auto guard_resize = memory_map_->writable()
+ ? std::unique_lock<std::mutex>(memory_map_->resize_lock())
+ : std::unique_lock<std::mutex>();
+
+ ARROW_ASSIGN_OR_RAISE(
+ nbytes, internal::ValidateReadRange(position, nbytes, memory_map_->size()));
+ if (nbytes > 0) {
+ memcpy(out, memory_map_->data() + position, static_cast<size_t>(nbytes));
+ }
+ return nbytes;
+}
+
+Result<int64_t> MemoryMappedFile::Read(int64_t nbytes, void* out) {
+ RETURN_NOT_OK(memory_map_->CheckClosed());
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, ReadAt(memory_map_->position(), nbytes, out));
+ memory_map_->advance(bytes_read);
+ return bytes_read;
+}
+
+Result<std::shared_ptr<Buffer>> MemoryMappedFile::Read(int64_t nbytes) {
+ RETURN_NOT_OK(memory_map_->CheckClosed());
+ ARROW_ASSIGN_OR_RAISE(auto buffer, ReadAt(memory_map_->position(), nbytes));
+ memory_map_->advance(buffer->size());
+ return buffer;
+}
+
+Future<std::shared_ptr<Buffer>> MemoryMappedFile::ReadAsync(const IOContext&,
+ int64_t position,
+ int64_t nbytes) {
+ return Future<std::shared_ptr<Buffer>>::MakeFinished(ReadAt(position, nbytes));
+}
+
+Status MemoryMappedFile::WillNeed(const std::vector<ReadRange>& ranges) {
+ using ::arrow::internal::MemoryRegion;
+
+ RETURN_NOT_OK(memory_map_->CheckClosed());
+ auto guard_resize = memory_map_->writable()
+ ? std::unique_lock<std::mutex>(memory_map_->resize_lock())
+ : std::unique_lock<std::mutex>();
+
+ std::vector<MemoryRegion> regions(ranges.size());
+ for (size_t i = 0; i < ranges.size(); ++i) {
+ const auto& range = ranges[i];
+ ARROW_ASSIGN_OR_RAISE(
+ auto size,
+ internal::ValidateReadRange(range.offset, range.length, memory_map_->size()));
+ DCHECK_NE(memory_map_->data(), nullptr);
+ regions[i] = {const_cast<uint8_t*>(memory_map_->data() + range.offset),
+ static_cast<size_t>(size)};
+ }
+ return ::arrow::internal::MemoryAdviseWillNeed(regions);
+}
+
+bool MemoryMappedFile::supports_zero_copy() const { return true; }
+
+Status MemoryMappedFile::WriteAt(int64_t position, const void* data, int64_t nbytes) {
+ RETURN_NOT_OK(memory_map_->CheckClosed());
+ std::lock_guard<std::mutex> guard(memory_map_->write_lock());
+
+ if (!memory_map_->opened() || !memory_map_->writable()) {
+ return Status::IOError("Unable to write");
+ }
+ RETURN_NOT_OK(internal::ValidateWriteRange(position, nbytes, memory_map_->size()));
+
+ RETURN_NOT_OK(memory_map_->Seek(position));
+ return WriteInternal(data, nbytes);
+}
+
+Status MemoryMappedFile::Write(const void* data, int64_t nbytes) {
+ RETURN_NOT_OK(memory_map_->CheckClosed());
+ std::lock_guard<std::mutex> guard(memory_map_->write_lock());
+
+ if (!memory_map_->opened() || !memory_map_->writable()) {
+ return Status::IOError("Unable to write");
+ }
+ RETURN_NOT_OK(
+ internal::ValidateWriteRange(memory_map_->position(), nbytes, memory_map_->size()));
+
+ return WriteInternal(data, nbytes);
+}
+
+Status MemoryMappedFile::WriteInternal(const void* data, int64_t nbytes) {
+ memcpy(memory_map_->head(), data, static_cast<size_t>(nbytes));
+ memory_map_->advance(nbytes);
+ return Status::OK();
+}
+
+Status MemoryMappedFile::Resize(int64_t new_size) {
+ RETURN_NOT_OK(memory_map_->CheckClosed());
+ std::unique_lock<std::mutex> write_guard(memory_map_->write_lock(), std::defer_lock);
+ std::unique_lock<std::mutex> resize_guard(memory_map_->resize_lock(), std::defer_lock);
+ std::lock(write_guard, resize_guard);
+ RETURN_NOT_OK(memory_map_->Resize(new_size));
+ return Status::OK();
+}
+
+int MemoryMappedFile::file_descriptor() const { return memory_map_->fd(); }
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/file.h b/src/arrow/cpp/src/arrow/io/file.h
new file mode 100644
index 000000000..50d4f2c4d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/file.h
@@ -0,0 +1,221 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// IO interface implementations for OS files
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/io/concurrency.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Buffer;
+class MemoryPool;
+class Status;
+
+namespace io {
+
+/// \brief An operating system file open in write-only mode.
+class ARROW_EXPORT FileOutputStream : public OutputStream {
+ public:
+ ~FileOutputStream() override;
+
+ /// \brief Open a local file for writing, truncating any existing file
+ /// \param[in] path with UTF8 encoding
+ /// \param[in] append append to existing file, otherwise truncate to 0 bytes
+ /// \return an open FileOutputStream
+ ///
+ /// When opening a new file, any existing file with the indicated path is
+ /// truncated to 0 bytes, deleting any existing data
+ static Result<std::shared_ptr<FileOutputStream>> Open(const std::string& path,
+ bool append = false);
+
+ /// \brief Open a file descriptor for writing. The underlying file isn't
+ /// truncated.
+ /// \param[in] fd file descriptor
+ /// \return an open FileOutputStream
+ ///
+ /// The file descriptor becomes owned by the OutputStream, and will be closed
+ /// on Close() or destruction.
+ static Result<std::shared_ptr<FileOutputStream>> Open(int fd);
+
+ // OutputStream interface
+ Status Close() override;
+ bool closed() const override;
+ Result<int64_t> Tell() const override;
+
+ // Write bytes to the stream. Thread-safe
+ Status Write(const void* data, int64_t nbytes) override;
+ /// \cond FALSE
+ using Writable::Write;
+ /// \endcond
+
+ int file_descriptor() const;
+
+ private:
+ FileOutputStream();
+
+ class ARROW_NO_EXPORT FileOutputStreamImpl;
+ std::unique_ptr<FileOutputStreamImpl> impl_;
+};
+
+/// \brief An operating system file open in read-only mode.
+///
+/// Reads through this implementation are unbuffered. If many small reads
+/// need to be issued, it is recommended to use a buffering layer for good
+/// performance.
+class ARROW_EXPORT ReadableFile
+ : public internal::RandomAccessFileConcurrencyWrapper<ReadableFile> {
+ public:
+ ~ReadableFile() override;
+
+ /// \brief Open a local file for reading
+ /// \param[in] path with UTF8 encoding
+ /// \param[in] pool a MemoryPool for memory allocations
+ /// \return ReadableFile instance
+ static Result<std::shared_ptr<ReadableFile>> Open(
+ const std::string& path, MemoryPool* pool = default_memory_pool());
+
+ /// \brief Open a local file for reading
+ /// \param[in] fd file descriptor
+ /// \param[in] pool a MemoryPool for memory allocations
+ /// \return ReadableFile instance
+ ///
+ /// The file descriptor becomes owned by the ReadableFile, and will be closed
+ /// on Close() or destruction.
+ static Result<std::shared_ptr<ReadableFile>> Open(
+ int fd, MemoryPool* pool = default_memory_pool());
+
+ bool closed() const override;
+
+ int file_descriptor() const;
+
+ Status WillNeed(const std::vector<ReadRange>& ranges) override;
+
+ private:
+ friend RandomAccessFileConcurrencyWrapper<ReadableFile>;
+
+ explicit ReadableFile(MemoryPool* pool);
+
+ Status DoClose();
+ Result<int64_t> DoTell() const;
+ Result<int64_t> DoRead(int64_t nbytes, void* buffer);
+ Result<std::shared_ptr<Buffer>> DoRead(int64_t nbytes);
+
+ /// \brief Thread-safe implementation of ReadAt
+ Result<int64_t> DoReadAt(int64_t position, int64_t nbytes, void* out);
+
+ /// \brief Thread-safe implementation of ReadAt
+ Result<std::shared_ptr<Buffer>> DoReadAt(int64_t position, int64_t nbytes);
+
+ Result<int64_t> DoGetSize();
+ Status DoSeek(int64_t position);
+
+ class ARROW_NO_EXPORT ReadableFileImpl;
+ std::unique_ptr<ReadableFileImpl> impl_;
+};
+
+/// \brief A file interface that uses memory-mapped files for memory interactions
+///
+/// This implementation supports zero-copy reads. The same class is used
+/// for both reading and writing.
+///
+/// If opening a file in a writable mode, it is not truncated first as with
+/// FileOutputStream.
+class ARROW_EXPORT MemoryMappedFile : public ReadWriteFileInterface {
+ public:
+ ~MemoryMappedFile() override;
+
+ /// Create new file with indicated size, return in read/write mode
+ static Result<std::shared_ptr<MemoryMappedFile>> Create(const std::string& path,
+ int64_t size);
+
+ // mmap() with whole file
+ static Result<std::shared_ptr<MemoryMappedFile>> Open(const std::string& path,
+ FileMode::type mode);
+
+ // mmap() with a region of file, the offset must be a multiple of the page size
+ static Result<std::shared_ptr<MemoryMappedFile>> Open(const std::string& path,
+ FileMode::type mode,
+ const int64_t offset,
+ const int64_t length);
+
+ Status Close() override;
+
+ bool closed() const override;
+
+ Result<int64_t> Tell() const override;
+
+ Status Seek(int64_t position) override;
+
+ // Required by RandomAccessFile, copies memory into out. Not thread-safe
+ Result<int64_t> Read(int64_t nbytes, void* out) override;
+
+ // Zero copy read, moves position pointer. Not thread-safe
+ Result<std::shared_ptr<Buffer>> Read(int64_t nbytes) override;
+
+ // Zero-copy read, leaves position unchanged. Acquires a reader lock
+ // for the duration of slice creation (typically very short). Is thread-safe.
+ Result<std::shared_ptr<Buffer>> ReadAt(int64_t position, int64_t nbytes) override;
+
+ // Raw copy of the memory at specified position. Thread-safe, but
+ // locks out other readers for the duration of memcpy. Prefer the
+ // zero copy method
+ Result<int64_t> ReadAt(int64_t position, int64_t nbytes, void* out) override;
+
+ // Synchronous ReadAsync override
+ Future<std::shared_ptr<Buffer>> ReadAsync(const IOContext&, int64_t position,
+ int64_t nbytes) override;
+
+ Status WillNeed(const std::vector<ReadRange>& ranges) override;
+
+ bool supports_zero_copy() const override;
+
+ /// Write data at the current position in the file. Thread-safe
+ Status Write(const void* data, int64_t nbytes) override;
+ /// \cond FALSE
+ using Writable::Write;
+ /// \endcond
+
+ /// Set the size of the map to new_size.
+ Status Resize(int64_t new_size);
+
+ /// Write data at a particular position in the file. Thread-safe
+ Status WriteAt(int64_t position, const void* data, int64_t nbytes) override;
+
+ Result<int64_t> GetSize() override;
+
+ int file_descriptor() const;
+
+ private:
+ MemoryMappedFile();
+
+ Status WriteInternal(const void* data, int64_t nbytes);
+
+ class ARROW_NO_EXPORT MemoryMap;
+ std::shared_ptr<MemoryMap> memory_map_;
+};
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/file_benchmark.cc b/src/arrow/cpp/src/arrow/io/file_benchmark.cc
new file mode 100644
index 000000000..5e7e55725
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/file_benchmark.cc
@@ -0,0 +1,301 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/io/buffered.h"
+#include "arrow/io/file.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/windows_compatibility.h"
+
+#include "benchmark/benchmark.h"
+
+#include <algorithm>
+#include <atomic>
+#include <cstdlib>
+#include <iostream>
+#include <thread>
+#include <valarray>
+
+#ifdef _WIN32
+
+#include <io.h>
+
+#else
+
+#include <fcntl.h>
+#include <poll.h>
+#include <unistd.h>
+
+#endif
+
+namespace arrow {
+
+std::string GetNullFile() {
+#ifdef _WIN32
+ return "NUL";
+#else
+ return "/dev/null";
+#endif
+}
+
+const std::valarray<int64_t> small_sizes = {8, 24, 33, 1, 32, 192, 16, 40};
+const std::valarray<int64_t> large_sizes = {8192, 100000};
+
+constexpr int64_t kBufferSize = 4096;
+
+#ifdef _WIN32
+
+class BackgroundReader {
+ // A class that reads data in the background from a file descriptor
+ // (Windows implementation)
+
+ public:
+ static std::shared_ptr<BackgroundReader> StartReader(int fd) {
+ std::shared_ptr<BackgroundReader> reader(new BackgroundReader(fd));
+ reader->worker_.reset(new std::thread([=] { reader->LoopReading(); }));
+ return reader;
+ }
+ void Stop() { ARROW_CHECK(SetEvent(event_)); }
+ void Join() { worker_->join(); }
+
+ ~BackgroundReader() {
+ ABORT_NOT_OK(internal::FileClose(fd_));
+ ARROW_CHECK(CloseHandle(event_));
+ }
+
+ protected:
+ explicit BackgroundReader(int fd) : fd_(fd), total_bytes_(0) {
+ file_handle_ = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
+ ARROW_CHECK_NE(file_handle_, INVALID_HANDLE_VALUE);
+ event_ =
+ CreateEvent(nullptr, /* bManualReset=*/TRUE, /* bInitialState=*/FALSE, nullptr);
+ ARROW_CHECK_NE(event_, INVALID_HANDLE_VALUE);
+ }
+
+ void LoopReading() {
+ const HANDLE handles[] = {file_handle_, event_};
+ while (true) {
+ DWORD ret = WaitForMultipleObjects(2, handles, /* bWaitAll=*/FALSE, INFINITE);
+ ARROW_CHECK_NE(ret, WAIT_FAILED);
+ if (ret == WAIT_OBJECT_0 + 1) {
+ // Got stop request
+ break;
+ } else if (ret == WAIT_OBJECT_0) {
+ // File ready for reading
+ total_bytes_ += *internal::FileRead(fd_, buffer_, buffer_size_);
+ } else {
+ ARROW_LOG(FATAL) << "Unexpected WaitForMultipleObjects return value " << ret;
+ }
+ }
+ }
+
+ int fd_;
+ HANDLE file_handle_, event_;
+ int64_t total_bytes_;
+
+ static const int64_t buffer_size_ = 16384;
+ uint8_t buffer_[buffer_size_];
+
+ std::unique_ptr<std::thread> worker_;
+};
+
+#else
+
+class BackgroundReader {
+ // A class that reads data in the background from a file descriptor
+ // (Unix implementation)
+
+ public:
+ static std::shared_ptr<BackgroundReader> StartReader(int fd) {
+ std::shared_ptr<BackgroundReader> reader(new BackgroundReader(fd));
+ reader->worker_.reset(new std::thread([=] { reader->LoopReading(); }));
+ return reader;
+ }
+ void Stop() {
+ const uint8_t data[] = "x";
+ ABORT_NOT_OK(internal::FileWrite(wakeup_w_, data, 1));
+ }
+ void Join() { worker_->join(); }
+
+ ~BackgroundReader() {
+ for (int fd : {fd_, wakeup_r_, wakeup_w_}) {
+ ABORT_NOT_OK(internal::FileClose(fd));
+ }
+ }
+
+ protected:
+ explicit BackgroundReader(int fd) : fd_(fd), total_bytes_(0) {
+ // Prepare self-pipe trick
+ auto pipe = *internal::CreatePipe();
+ wakeup_r_ = pipe.rfd;
+ wakeup_w_ = pipe.wfd;
+ // Put fd in non-blocking mode
+ fcntl(fd, F_SETFL, O_NONBLOCK);
+ }
+
+ void LoopReading() {
+ struct pollfd pollfds[2];
+ pollfds[0].fd = fd_;
+ pollfds[0].events = POLLIN;
+ pollfds[1].fd = wakeup_r_;
+ pollfds[1].events = POLLIN;
+ while (true) {
+ int ret = poll(pollfds, 2, -1 /* timeout */);
+ if (ret < 1) {
+ std::cerr << "poll() failed with code " << ret << "\n";
+ abort();
+ }
+ if (pollfds[1].revents & POLLIN) {
+ // We're done
+ break;
+ }
+ if (!(pollfds[0].revents & POLLIN)) {
+ continue;
+ }
+ auto result = internal::FileRead(fd_, buffer_, buffer_size_);
+ // There could be a spurious wakeup followed by EAGAIN
+ if (result.ok()) {
+ total_bytes_ += *result;
+ }
+ }
+ }
+
+ int fd_, wakeup_r_, wakeup_w_;
+ int64_t total_bytes_;
+
+ static const int64_t buffer_size_ = 16384;
+ uint8_t buffer_[buffer_size_];
+
+ std::unique_ptr<std::thread> worker_;
+};
+
+#endif
+
+// Set up a pipe with an OutputStream at one end and a BackgroundReader at
+// the other end.
+static void SetupPipeWriter(std::shared_ptr<io::OutputStream>* stream,
+ std::shared_ptr<BackgroundReader>* reader) {
+ auto pipe = *internal::CreatePipe();
+ *stream = *io::FileOutputStream::Open(pipe.wfd);
+ *reader = BackgroundReader::StartReader(pipe.rfd);
+}
+
+static void BenchmarkStreamingWrites(benchmark::State& state,
+ std::valarray<int64_t> sizes,
+ io::OutputStream* stream,
+ BackgroundReader* reader = nullptr) {
+ const std::string datastr(*std::max_element(std::begin(sizes), std::end(sizes)), 'x');
+ const void* data = datastr.data();
+ const int64_t sum_sizes = sizes.sum();
+
+ while (state.KeepRunning()) {
+ for (const int64_t size : sizes) {
+ ABORT_NOT_OK(stream->Write(data, size));
+ }
+ }
+ // For Windows: need to close writer before joining reader thread.
+ ABORT_NOT_OK(stream->Close());
+
+ const int64_t total_bytes = static_cast<int64_t>(state.iterations()) * sum_sizes;
+ state.SetBytesProcessed(total_bytes);
+
+ if (reader != nullptr) {
+ // Wake up and stop
+ reader->Stop();
+ reader->Join();
+ }
+}
+
+// Benchmark writing to /dev/null
+//
+// This situation is irrealistic as the kernel likely doesn't
+// copy the data at all, so we only measure small writes.
+
+static void FileOutputStreamSmallWritesToNull(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto stream = *io::FileOutputStream::Open(GetNullFile());
+
+ BenchmarkStreamingWrites(state, small_sizes, stream.get());
+}
+
+static void BufferedOutputStreamSmallWritesToNull(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto file = *io::FileOutputStream::Open(GetNullFile());
+
+ auto buffered_file =
+ *io::BufferedOutputStream::Create(kBufferSize, default_memory_pool(), file);
+ BenchmarkStreamingWrites(state, small_sizes, buffered_file.get());
+}
+
+// Benchmark writing a pipe
+//
+// This is slightly more realistic than the above
+
+static void FileOutputStreamSmallWritesToPipe(
+ benchmark::State& state) { // NOLINT non-const reference
+ std::shared_ptr<io::OutputStream> stream;
+ std::shared_ptr<BackgroundReader> reader;
+ SetupPipeWriter(&stream, &reader);
+
+ BenchmarkStreamingWrites(state, small_sizes, stream.get(), reader.get());
+}
+
+static void FileOutputStreamLargeWritesToPipe(
+ benchmark::State& state) { // NOLINT non-const reference
+ std::shared_ptr<io::OutputStream> stream;
+ std::shared_ptr<BackgroundReader> reader;
+ SetupPipeWriter(&stream, &reader);
+
+ BenchmarkStreamingWrites(state, large_sizes, stream.get(), reader.get());
+}
+
+static void BufferedOutputStreamSmallWritesToPipe(
+ benchmark::State& state) { // NOLINT non-const reference
+ std::shared_ptr<io::OutputStream> stream;
+ std::shared_ptr<BackgroundReader> reader;
+ SetupPipeWriter(&stream, &reader);
+
+ auto buffered_stream =
+ *io::BufferedOutputStream::Create(kBufferSize, default_memory_pool(), stream);
+ BenchmarkStreamingWrites(state, small_sizes, buffered_stream.get(), reader.get());
+}
+
+static void BufferedOutputStreamLargeWritesToPipe(
+ benchmark::State& state) { // NOLINT non-const reference
+ std::shared_ptr<io::OutputStream> stream;
+ std::shared_ptr<BackgroundReader> reader;
+ SetupPipeWriter(&stream, &reader);
+
+ auto buffered_stream =
+ *io::BufferedOutputStream::Create(kBufferSize, default_memory_pool(), stream);
+
+ BenchmarkStreamingWrites(state, large_sizes, buffered_stream.get(), reader.get());
+}
+
+// We use real time as we don't want to count CPU time spent in the
+// BackgroundReader thread
+
+BENCHMARK(FileOutputStreamSmallWritesToNull)->UseRealTime();
+BENCHMARK(FileOutputStreamSmallWritesToPipe)->UseRealTime();
+BENCHMARK(FileOutputStreamLargeWritesToPipe)->UseRealTime();
+
+BENCHMARK(BufferedOutputStreamSmallWritesToNull)->UseRealTime();
+BENCHMARK(BufferedOutputStreamSmallWritesToPipe)->UseRealTime();
+BENCHMARK(BufferedOutputStreamLargeWritesToPipe)->UseRealTime();
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/file_test.cc b/src/arrow/cpp/src/arrow/io/file_test.cc
new file mode 100644
index 000000000..7d3d1c621
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/file_test.cc
@@ -0,0 +1,1064 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef _WIN32
+#include <fcntl.h> // IWYU pragma: keep
+#include <unistd.h>
+#endif
+
+#include <atomic>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <memory>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/buffer.h"
+#include "arrow/io/file.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/io/test_common.h"
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/future.h"
+#include "arrow/util/io_util.h"
+
+namespace arrow {
+
+using internal::CreatePipe;
+using internal::FileClose;
+using internal::FileGetSize;
+using internal::FileOpenReadable;
+using internal::FileOpenWritable;
+using internal::FileRead;
+using internal::FileSeek;
+using internal::PlatformFilename;
+using internal::TemporaryDir;
+
+namespace io {
+
+class FileTestFixture : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK_AND_ASSIGN(temp_dir_, TemporaryDir::Make("file-test-"));
+ path_ = TempFile("arrow-test-io-file.txt");
+ EnsureFileDeleted();
+ }
+
+ std::string TempFile(arrow::util::string_view path) {
+ return temp_dir_->path().Join(std::string(path)).ValueOrDie().ToString();
+ }
+
+ void TearDown() { EnsureFileDeleted(); }
+
+ void EnsureFileDeleted() {
+ if (FileExists(path_)) {
+ ARROW_UNUSED(std::remove(path_.c_str()));
+ }
+ }
+
+ protected:
+ std::unique_ptr<TemporaryDir> temp_dir_;
+ std::string path_;
+};
+
+// ----------------------------------------------------------------------
+// File output tests
+
+class TestFileOutputStream : public FileTestFixture {
+ public:
+ void OpenFile(bool append = false) {
+ ASSERT_OK_AND_ASSIGN(file_, FileOutputStream::Open(path_, append));
+ }
+
+ void OpenFileDescriptor() {
+ int fd_file;
+ ASSERT_OK_AND_ASSIGN(auto file_name, PlatformFilename::FromString(path_));
+ ASSERT_OK_AND_ASSIGN(fd_file, FileOpenWritable(file_name, true /* write_only */,
+ false /* truncate */));
+ ASSERT_OK_AND_ASSIGN(file_, FileOutputStream::Open(fd_file));
+ }
+
+ protected:
+ std::shared_ptr<FileOutputStream> file_;
+};
+
+#if defined(_MSC_VER)
+TEST_F(TestFileOutputStream, FileNameWideCharConversionRangeException) {
+ // Invalid utf-8 filename
+ std::string file_name = "\x80";
+ ASSERT_RAISES(Invalid, FileOutputStream::Open(file_name));
+ ASSERT_RAISES(Invalid, ReadableFile::Open(file_name));
+}
+
+// TODO add a test with a valid utf-8 filename
+#endif
+
+TEST_F(TestFileOutputStream, DestructorClosesFile) {
+ int fd_file;
+
+ OpenFile();
+ fd_file = file_->file_descriptor();
+ ASSERT_FALSE(FileIsClosed(fd_file));
+ file_.reset();
+ ASSERT_TRUE(FileIsClosed(fd_file));
+
+ OpenFileDescriptor();
+ fd_file = file_->file_descriptor();
+ ASSERT_FALSE(FileIsClosed(fd_file));
+ file_.reset();
+ ASSERT_TRUE(FileIsClosed(fd_file));
+}
+
+TEST_F(TestFileOutputStream, Close) {
+ OpenFile();
+
+ const char* data = "testdata";
+ ASSERT_OK(file_->Write(data, strlen(data)));
+
+ int fd = file_->file_descriptor();
+ ASSERT_FALSE(file_->closed());
+ ASSERT_OK(file_->Close());
+ ASSERT_TRUE(file_->closed());
+ ASSERT_TRUE(FileIsClosed(fd));
+ ASSERT_RAISES(Invalid, file_->Write(data, strlen(data)));
+
+ // Idempotent
+ ASSERT_OK(file_->Close());
+
+ AssertFileContents(path_, data);
+}
+
+TEST_F(TestFileOutputStream, FromFileDescriptor) {
+ OpenFileDescriptor();
+
+ std::string data1 = "test";
+ ASSERT_OK(file_->Write(data1.data(), data1.size()));
+ int fd = file_->file_descriptor();
+ ASSERT_OK(file_->Close());
+ ASSERT_TRUE(FileIsClosed(fd));
+
+ AssertFileContents(path_, data1);
+
+ // Re-open at end of file
+ ASSERT_OK_AND_ASSIGN(auto file_name, PlatformFilename::FromString(path_));
+ ASSERT_OK_AND_ASSIGN(
+ fd, FileOpenWritable(file_name, true /* write_only */, false /* truncate */));
+ ASSERT_OK(FileSeek(fd, 0, SEEK_END));
+ ASSERT_OK_AND_ASSIGN(file_, FileOutputStream::Open(fd));
+
+ std::string data2 = "data";
+ ASSERT_OK(file_->Write(data2.data(), data2.size()));
+ ASSERT_OK(file_->Close());
+
+ AssertFileContents(path_, data1 + data2);
+}
+
+TEST_F(TestFileOutputStream, InvalidWrites) {
+ OpenFile();
+
+ const char* data = "";
+
+ ASSERT_RAISES(IOError, file_->Write(data, -1));
+}
+
+TEST_F(TestFileOutputStream, Tell) {
+ OpenFile();
+
+ ASSERT_OK_AND_EQ(0, file_->Tell());
+
+ const char* data = "testdata";
+ ASSERT_OK(file_->Write(data, 8));
+ ASSERT_OK_AND_EQ(8, file_->Tell());
+}
+
+TEST_F(TestFileOutputStream, TruncatesNewFile) {
+ ASSERT_OK_AND_ASSIGN(file_, FileOutputStream::Open(path_));
+
+ const char* data = "testdata";
+ ASSERT_OK(file_->Write(data, strlen(data)));
+ ASSERT_OK(file_->Close());
+
+ ASSERT_OK_AND_ASSIGN(file_, FileOutputStream::Open(path_));
+ ASSERT_OK(file_->Close());
+
+ AssertFileContents(path_, "");
+}
+
+TEST_F(TestFileOutputStream, Append) {
+ ASSERT_OK_AND_ASSIGN(file_, FileOutputStream::Open(path_));
+ {
+ const char* data = "test";
+ ASSERT_OK(file_->Write(data, strlen(data)));
+ }
+ ASSERT_OK(file_->Close());
+ ASSERT_OK_AND_ASSIGN(file_, FileOutputStream::Open(path_, true /* append */));
+ {
+ const char* data = "data";
+ ASSERT_OK(file_->Write(data, strlen(data)));
+ }
+ ASSERT_OK(file_->Close());
+ AssertFileContents(path_, "testdata");
+}
+
+// ----------------------------------------------------------------------
+// File input tests
+
+class TestReadableFile : public FileTestFixture {
+ public:
+ void OpenFile() { ASSERT_OK_AND_ASSIGN(file_, ReadableFile::Open(path_)); }
+
+ void MakeTestFile() {
+ std::string data = "testdata";
+ std::ofstream stream;
+ stream.open(path_.c_str());
+ stream << data;
+ }
+
+ protected:
+ std::shared_ptr<ReadableFile> file_;
+};
+
+TEST_F(TestReadableFile, DestructorClosesFile) {
+ MakeTestFile();
+
+ int fd;
+ {
+ ASSERT_OK_AND_ASSIGN(auto file, ReadableFile::Open(path_));
+ fd = file->file_descriptor();
+ }
+ ASSERT_TRUE(FileIsClosed(fd));
+}
+
+TEST_F(TestReadableFile, Close) {
+ MakeTestFile();
+ OpenFile();
+
+ int fd = file_->file_descriptor();
+ ASSERT_FALSE(file_->closed());
+ ASSERT_OK(file_->Close());
+ ASSERT_TRUE(file_->closed());
+
+ ASSERT_TRUE(FileIsClosed(fd));
+
+ // Idempotent
+ ASSERT_OK(file_->Close());
+ ASSERT_TRUE(FileIsClosed(fd));
+}
+
+TEST_F(TestReadableFile, FromFileDescriptor) {
+ MakeTestFile();
+
+ int fd = -2;
+ ASSERT_OK_AND_ASSIGN(auto file_name, PlatformFilename::FromString(path_));
+ ASSERT_OK_AND_ASSIGN(fd, FileOpenReadable(file_name));
+ ASSERT_GE(fd, 0);
+ ASSERT_OK(FileSeek(fd, 4));
+
+ ASSERT_OK_AND_ASSIGN(file_, ReadableFile::Open(fd));
+ ASSERT_EQ(file_->file_descriptor(), fd);
+ ASSERT_OK_AND_ASSIGN(auto buf, file_->Read(5));
+ ASSERT_EQ(buf->size(), 4);
+ ASSERT_TRUE(buf->Equals(Buffer("data")));
+
+ ASSERT_FALSE(FileIsClosed(fd));
+ ASSERT_OK(file_->Close());
+ ASSERT_TRUE(FileIsClosed(fd));
+ // Idempotent
+ ASSERT_OK(file_->Close());
+ ASSERT_TRUE(FileIsClosed(fd));
+}
+
+TEST_F(TestReadableFile, Peek) {
+ MakeTestFile();
+ OpenFile();
+
+ // Cannot peek
+ ASSERT_RAISES(NotImplemented, file_->Peek(4));
+}
+
+TEST_F(TestReadableFile, SeekTellSize) {
+ MakeTestFile();
+ OpenFile();
+
+ ASSERT_OK_AND_EQ(0, file_->Tell());
+
+ ASSERT_OK(file_->Seek(4));
+ ASSERT_OK_AND_EQ(4, file_->Tell());
+
+ // Can seek past end of file
+ ASSERT_OK(file_->Seek(100));
+ ASSERT_OK_AND_EQ(100, file_->Tell());
+
+ ASSERT_OK_AND_EQ(8, file_->GetSize());
+
+ ASSERT_OK_AND_EQ(100, file_->Tell());
+
+ // does not support zero copy
+ ASSERT_FALSE(file_->supports_zero_copy());
+}
+
+TEST_F(TestReadableFile, Read) {
+ uint8_t buffer[50];
+
+ MakeTestFile();
+ OpenFile();
+
+ ASSERT_OK_AND_EQ(4, file_->Read(4, buffer));
+ ASSERT_EQ(0, std::memcmp(buffer, "test", 4));
+
+ ASSERT_OK_AND_EQ(4, file_->Read(10, buffer));
+ ASSERT_EQ(0, std::memcmp(buffer, "data", 4));
+
+ // Test incomplete read, ARROW-1094
+ ASSERT_OK_AND_ASSIGN(int64_t size, file_->GetSize());
+
+ ASSERT_OK(file_->Seek(1));
+ ASSERT_OK_AND_ASSIGN(auto buf, file_->Read(size));
+ ASSERT_EQ(size - 1, buf->size());
+
+ ASSERT_OK(file_->Close());
+ ASSERT_RAISES(Invalid, file_->Read(1));
+}
+
+TEST_F(TestReadableFile, ReadAt) {
+ uint8_t buffer[50];
+ const char* test_data = "testdata";
+
+ MakeTestFile();
+ OpenFile();
+
+ ASSERT_OK_AND_EQ(4, file_->ReadAt(0, 4, buffer));
+ ASSERT_EQ(0, std::memcmp(buffer, "test", 4));
+
+ ASSERT_OK_AND_EQ(7, file_->ReadAt(1, 10, buffer));
+ ASSERT_EQ(0, std::memcmp(buffer, "estdata", 7));
+
+ // Check buffer API
+ ASSERT_OK_AND_ASSIGN(auto buffer2, file_->ReadAt(2, 5));
+ ASSERT_EQ(5, buffer2->size());
+
+ Buffer expected(reinterpret_cast<const uint8_t*>(test_data + 2), 5);
+ ASSERT_TRUE(buffer2->Equals(expected));
+
+ // Invalid reads
+ ASSERT_RAISES(Invalid, file_->ReadAt(-1, 1));
+ ASSERT_RAISES(Invalid, file_->ReadAt(1, -1));
+ ASSERT_RAISES(Invalid, file_->ReadAt(-1, 1, buffer));
+ ASSERT_RAISES(Invalid, file_->ReadAt(1, -1, buffer));
+
+ ASSERT_OK(file_->Close());
+ ASSERT_RAISES(Invalid, file_->ReadAt(0, 1));
+}
+
+TEST_F(TestReadableFile, ReadAsync) {
+ MakeTestFile();
+ OpenFile();
+
+ auto fut1 = file_->ReadAsync({}, 1, 10);
+ auto fut2 = file_->ReadAsync({}, 0, 4);
+ ASSERT_OK_AND_ASSIGN(auto buf1, fut1.result());
+ ASSERT_OK_AND_ASSIGN(auto buf2, fut2.result());
+ AssertBufferEqual(*buf1, "estdata");
+ AssertBufferEqual(*buf2, "test");
+}
+
+TEST_F(TestReadableFile, SeekingRequired) {
+ MakeTestFile();
+ OpenFile();
+
+ ASSERT_OK_AND_ASSIGN(auto buffer, file_->ReadAt(0, 4));
+ AssertBufferEqual(*buffer, "test");
+
+ ASSERT_RAISES(Invalid, file_->Read(4));
+ ASSERT_OK(file_->Seek(0));
+ ASSERT_OK_AND_ASSIGN(buffer, file_->Read(4));
+ AssertBufferEqual(*buffer, "test");
+}
+
+TEST_F(TestReadableFile, WillNeed) {
+ MakeTestFile();
+ OpenFile();
+
+ ASSERT_OK(file_->WillNeed({}));
+ ASSERT_OK(file_->WillNeed({{0, 3}, {4, 6}}));
+ ASSERT_OK(file_->WillNeed({{10, 0}}));
+
+ ASSERT_RAISES(Invalid, file_->WillNeed({{-1, -1}}));
+}
+
+TEST_F(TestReadableFile, NonexistentFile) {
+ std::string path = "0xDEADBEEF.txt";
+ auto maybe_file = ReadableFile::Open(path);
+ ASSERT_RAISES(IOError, maybe_file);
+ std::string message = maybe_file.status().message();
+ ASSERT_NE(std::string::npos, message.find(path));
+}
+
+class MyMemoryPool : public MemoryPool {
+ public:
+ MyMemoryPool() : num_allocations_(0) {}
+
+ Status Allocate(int64_t size, uint8_t** out) override {
+ *out = reinterpret_cast<uint8_t*>(std::malloc(size));
+ ++num_allocations_;
+ return Status::OK();
+ }
+
+ void Free(uint8_t* buffer, int64_t size) override { std::free(buffer); }
+
+ Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) override {
+ *ptr = reinterpret_cast<uint8_t*>(std::realloc(*ptr, new_size));
+
+ if (*ptr == NULL) {
+ return Status::OutOfMemory("realloc of size ", new_size, " failed");
+ }
+
+ return Status::OK();
+ }
+
+ int64_t bytes_allocated() const override { return -1; }
+
+ std::string backend_name() const override { return "my"; }
+
+ int64_t num_allocations() const { return num_allocations_.load(); }
+
+ private:
+ std::atomic<int64_t> num_allocations_;
+};
+
+TEST_F(TestReadableFile, CustomMemoryPool) {
+ MakeTestFile();
+
+ MyMemoryPool pool;
+ ASSERT_OK_AND_ASSIGN(file_, ReadableFile::Open(path_, &pool));
+
+ ASSERT_OK_AND_ASSIGN(auto buffer, file_->ReadAt(0, 4));
+ ASSERT_OK_AND_ASSIGN(buffer, file_->ReadAt(4, 8));
+
+ ASSERT_EQ(2, pool.num_allocations());
+}
+
+TEST_F(TestReadableFile, ThreadSafety) {
+ std::string data = "foobar";
+ {
+ std::ofstream stream;
+ stream.open(path_.c_str());
+ stream << data;
+ }
+
+ MyMemoryPool pool;
+ ASSERT_OK_AND_ASSIGN(file_, ReadableFile::Open(path_, &pool));
+
+ std::atomic<int> correct_count(0);
+ int niter = 30000;
+
+ auto ReadData = [&correct_count, &data, &niter, this]() {
+ for (int i = 0; i < niter; ++i) {
+ const int offset = i % 3;
+ ASSERT_OK_AND_ASSIGN(auto buffer, file_->ReadAt(offset, 3));
+ if (0 == memcmp(data.c_str() + offset, buffer->data(), 3)) {
+ correct_count += 1;
+ }
+ }
+ };
+
+ std::thread thread1(ReadData);
+ std::thread thread2(ReadData);
+
+ thread1.join();
+ thread2.join();
+
+ ASSERT_EQ(niter * 2, correct_count);
+}
+
+// ----------------------------------------------------------------------
+// Pipe I/O tests using FileOutputStream
+// (cannot test using ReadableFile as it currently requires seeking)
+
+class TestPipeIO : public ::testing::Test {
+ public:
+ void MakePipe() {
+ ASSERT_OK_AND_ASSIGN(auto pipe, CreatePipe());
+ r_ = pipe.rfd;
+ w_ = pipe.wfd;
+ ASSERT_GE(r_, 0);
+ ASSERT_GE(w_, 0);
+ }
+ void ClosePipe() {
+ if (r_ != -1) {
+ ASSERT_OK(FileClose(r_));
+ r_ = -1;
+ }
+ if (w_ != -1) {
+ ASSERT_OK(FileClose(w_));
+ w_ = -1;
+ }
+ }
+ void TearDown() { ClosePipe(); }
+
+ protected:
+ int r_ = -1, w_ = -1;
+};
+
+TEST_F(TestPipeIO, TestWrite) {
+ std::string data1 = "test", data2 = "data!";
+ std::shared_ptr<FileOutputStream> file;
+ uint8_t buffer[10];
+ int64_t bytes_read;
+
+ MakePipe();
+ ASSERT_OK_AND_ASSIGN(file, FileOutputStream::Open(w_));
+ w_ = -1; // now owned by FileOutputStream
+
+ ASSERT_OK(file->Write(data1.data(), data1.size()));
+ ASSERT_OK_AND_ASSIGN(bytes_read, FileRead(r_, buffer, 4));
+ ASSERT_EQ(bytes_read, 4);
+ ASSERT_EQ(0, std::memcmp(buffer, "test", 4));
+
+ ASSERT_OK(file->Write(Buffer::FromString(std::string(data2))));
+ ASSERT_OK_AND_ASSIGN(bytes_read, FileRead(r_, buffer, 4));
+ ASSERT_EQ(bytes_read, 4);
+ ASSERT_EQ(0, std::memcmp(buffer, "data", 4));
+
+ ASSERT_FALSE(file->closed());
+ ASSERT_OK(file->Close());
+ ASSERT_TRUE(file->closed());
+ ASSERT_OK_AND_ASSIGN(bytes_read, FileRead(r_, buffer, 2));
+ ASSERT_EQ(bytes_read, 1);
+ ASSERT_EQ(0, std::memcmp(buffer, "!", 1));
+ // EOF reached
+ ASSERT_OK_AND_ASSIGN(bytes_read, FileRead(r_, buffer, 2));
+ ASSERT_EQ(bytes_read, 0);
+}
+
+TEST_F(TestPipeIO, ReadableFileFails) {
+ // ReadableFile fails on non-seekable fd
+ ASSERT_RAISES(IOError, ReadableFile::Open(r_));
+}
+
+// ----------------------------------------------------------------------
+// Memory map tests
+
+class TestMemoryMappedFile : public ::testing::Test, public MemoryMapFixture {
+ public:
+ void SetUp() override {
+ ASSERT_OK_AND_ASSIGN(temp_dir_, TemporaryDir::Make("memory-map-test-"));
+ }
+
+ void TearDown() override { MemoryMapFixture::TearDown(); }
+
+ std::string TempFile(arrow::util::string_view path) {
+ return temp_dir_->path().Join(std::string(path)).ValueOrDie().ToString();
+ }
+
+ protected:
+ std::unique_ptr<TemporaryDir> temp_dir_;
+};
+
+TEST_F(TestMemoryMappedFile, InvalidUsages) {}
+
+TEST_F(TestMemoryMappedFile, ZeroSizeFile) {
+ std::string path = TempFile("io-memory-map-zero-size");
+ ASSERT_OK_AND_ASSIGN(auto result, InitMemoryMap(0, path));
+
+ ASSERT_OK_AND_EQ(0, result->Tell());
+}
+
+TEST_F(TestMemoryMappedFile, MapPartFile) {
+ const int64_t buffer_size = 1024;
+ const int64_t unalign_offset = 1024;
+ const int64_t offset = 65536; // make WIN32 happy
+ std::vector<uint8_t> buffer(buffer_size);
+
+ random_bytes(1024, 0, buffer.data());
+
+ const int reps = 128;
+
+ std::string path = TempFile("io-memory-map-offset");
+
+ // file size = 128k
+ CreateFile(path, reps * buffer_size);
+
+ // map failed with unaligned offset
+ ASSERT_RAISES(IOError,
+ MemoryMappedFile::Open(path, FileMode::READWRITE, unalign_offset, 4096));
+
+ // map failed if length is greater than file size
+ ASSERT_RAISES(Invalid,
+ MemoryMappedFile::Open(path, FileMode::READWRITE, offset, 409600));
+
+ // map succeeded with valid file region <64k-68k>
+ ASSERT_OK_AND_ASSIGN(auto result,
+ MemoryMappedFile::Open(path, FileMode::READWRITE, offset, 4096));
+
+ ASSERT_OK_AND_EQ(4096, result->GetSize());
+
+ ASSERT_OK_AND_EQ(0, result->Tell());
+
+ ASSERT_OK(result->Write(buffer.data(), buffer_size));
+ ASSERT_OK_AND_ASSIGN(auto out_buffer, result->ReadAt(0, buffer_size));
+ ASSERT_EQ(0, memcmp(out_buffer->data(), buffer.data(), buffer_size));
+
+ ASSERT_OK_AND_EQ(buffer_size, result->Tell());
+
+ ASSERT_OK(result->Seek(4096));
+ ASSERT_OK_AND_EQ(4096, result->Tell());
+
+ // Resize is not supported
+ ASSERT_RAISES(IOError, result->Resize(4096));
+
+ // Write beyond memory mapped length
+ ASSERT_RAISES(IOError, result->WriteAt(4096, buffer.data(), buffer_size));
+}
+
+TEST_F(TestMemoryMappedFile, WriteRead) {
+ const int64_t buffer_size = 1024;
+ std::vector<uint8_t> buffer(buffer_size);
+ random_bytes(1024, 0, buffer.data());
+
+ const int reps = 5;
+
+ std::string path = TempFile("io-memory-map-write-read-test");
+ ASSERT_OK_AND_ASSIGN(auto result, InitMemoryMap(reps * buffer_size, path));
+
+ int64_t position = 0;
+ for (int i = 0; i < reps; ++i) {
+ ASSERT_OK(result->Write(buffer.data(), buffer_size));
+ ASSERT_OK_AND_ASSIGN(auto out_buffer, result->ReadAt(position, buffer_size));
+
+ ASSERT_EQ(0, memcmp(out_buffer->data(), buffer.data(), buffer_size));
+
+ position += buffer_size;
+ }
+}
+
+TEST_F(TestMemoryMappedFile, ReadAsync) {
+ const int64_t buffer_size = 1024;
+ std::vector<uint8_t> buffer(buffer_size);
+ random_bytes(1024, 0, buffer.data());
+
+ std::string path = TempFile("io-memory-map-read-async-test");
+ ASSERT_OK_AND_ASSIGN(auto mmap, InitMemoryMap(buffer_size, path));
+ ASSERT_OK(mmap->Write(buffer.data(), buffer_size));
+
+ auto fut1 = mmap->ReadAsync({}, 1, 1000);
+ auto fut2 = mmap->ReadAsync({}, 3, 4);
+ ASSERT_EQ(fut1.state(), FutureState::SUCCESS);
+ ASSERT_EQ(fut2.state(), FutureState::SUCCESS);
+ ASSERT_OK_AND_ASSIGN(auto buf1, fut1.result());
+ ASSERT_OK_AND_ASSIGN(auto buf2, fut2.result());
+
+ AssertBufferEqual(*buf1, Buffer(buffer.data() + 1, 1000));
+ AssertBufferEqual(*buf2, Buffer(buffer.data() + 3, 4));
+}
+
+TEST_F(TestMemoryMappedFile, WillNeed) {
+ const int64_t buffer_size = 1024;
+ std::vector<uint8_t> buffer(buffer_size);
+ random_bytes(1024, 0, buffer.data());
+
+ std::string path = TempFile("io-memory-map-will-need-test");
+ ASSERT_OK_AND_ASSIGN(auto mmap, InitMemoryMap(buffer_size, path));
+ ASSERT_OK(mmap->Write(buffer.data(), buffer_size));
+
+ ASSERT_OK(mmap->WillNeed({}));
+ ASSERT_OK(mmap->WillNeed({{0, 4}, {100, 924}}));
+ ASSERT_OK(mmap->WillNeed({{1024, 0}}));
+ ASSERT_RAISES(IOError, mmap->WillNeed({{1025, 1}})); // Out of bounds
+}
+
+TEST_F(TestMemoryMappedFile, InvalidReads) {
+ std::string path = TempFile("io-memory-map-invalid-reads-test");
+ ASSERT_OK_AND_ASSIGN(auto result, InitMemoryMap(4096, path));
+
+ uint8_t buffer[10];
+
+ ASSERT_RAISES(Invalid, result->ReadAt(-1, 1));
+ ASSERT_RAISES(Invalid, result->ReadAt(1, -1));
+ ASSERT_RAISES(Invalid, result->ReadAt(-1, 1, buffer));
+ ASSERT_RAISES(Invalid, result->ReadAt(1, -1, buffer));
+}
+
+TEST_F(TestMemoryMappedFile, WriteResizeRead) {
+ const int64_t buffer_size = 1024;
+ const int reps = 5;
+ std::vector<std::vector<uint8_t>> buffers(reps);
+ for (auto& b : buffers) {
+ b.resize(buffer_size);
+ random_bytes(buffer_size, 0, b.data());
+ }
+
+ std::string path = TempFile("io-memory-map-write-read-test");
+ ASSERT_OK_AND_ASSIGN(auto result, InitMemoryMap(buffer_size, path));
+
+ int64_t position = 0;
+ for (int i = 0; i < reps; ++i) {
+ if (i != 0) {
+ ASSERT_OK(result->Resize(buffer_size * (i + 1)));
+ }
+ ASSERT_OK(result->Write(buffers[i].data(), buffer_size));
+ ASSERT_OK_AND_ASSIGN(auto out_buffer, result->ReadAt(position, buffer_size));
+
+ ASSERT_EQ(out_buffer->size(), buffer_size);
+ ASSERT_EQ(0, memcmp(out_buffer->data(), buffers[i].data(), buffer_size));
+ out_buffer.reset();
+
+ position += buffer_size;
+ }
+}
+
+TEST_F(TestMemoryMappedFile, ResizeRaisesOnExported) {
+ const int64_t buffer_size = 1024;
+ std::vector<uint8_t> buffer(buffer_size);
+ random_bytes(buffer_size, 0, buffer.data());
+
+ std::string path = TempFile("io-memory-map-write-read-test");
+ ASSERT_OK_AND_ASSIGN(auto result, InitMemoryMap(buffer_size, path));
+
+ ASSERT_OK(result->Write(buffer.data(), buffer_size));
+ ASSERT_OK_AND_ASSIGN(auto out_buffer1, result->ReadAt(0, buffer_size));
+ ASSERT_OK_AND_ASSIGN(auto out_buffer2, result->ReadAt(0, buffer_size));
+ ASSERT_EQ(0, memcmp(out_buffer1->data(), buffer.data(), buffer_size));
+ ASSERT_EQ(0, memcmp(out_buffer2->data(), buffer.data(), buffer_size));
+
+ // attempt resize
+ ASSERT_RAISES(IOError, result->Resize(2 * buffer_size));
+
+ out_buffer1.reset();
+
+ ASSERT_RAISES(IOError, result->Resize(2 * buffer_size));
+
+ out_buffer2.reset();
+
+ ASSERT_OK(result->Resize(2 * buffer_size));
+
+ ASSERT_OK_AND_EQ(buffer_size * 2, result->GetSize());
+ ASSERT_OK_AND_EQ(buffer_size * 2, FileGetSize(result->file_descriptor()));
+}
+
+TEST_F(TestMemoryMappedFile, WriteReadZeroInitSize) {
+ const int64_t buffer_size = 1024;
+ std::vector<uint8_t> buffer(buffer_size);
+ random_bytes(buffer_size, 0, buffer.data());
+
+ std::string path = TempFile("io-memory-map-write-read-test");
+ ASSERT_OK_AND_ASSIGN(auto result, InitMemoryMap(0, path));
+
+ ASSERT_OK(result->Resize(buffer_size));
+ ASSERT_OK(result->Write(buffer.data(), buffer_size));
+ ASSERT_OK_AND_ASSIGN(auto out_buffer, result->ReadAt(0, buffer_size));
+ ASSERT_EQ(0, memcmp(out_buffer->data(), buffer.data(), buffer_size));
+
+ ASSERT_OK_AND_EQ(buffer_size, result->GetSize());
+}
+
+TEST_F(TestMemoryMappedFile, WriteThenShrink) {
+ const int64_t buffer_size = 1024;
+ std::vector<uint8_t> buffer(buffer_size);
+ random_bytes(buffer_size, 0, buffer.data());
+
+ std::string path = TempFile("io-memory-map-write-read-test");
+ ASSERT_OK_AND_ASSIGN(auto result, InitMemoryMap(buffer_size * 2, path));
+
+ ASSERT_OK(result->Resize(buffer_size));
+ ASSERT_OK(result->Write(buffer.data(), buffer_size));
+ ASSERT_OK(result->Resize(buffer_size / 2));
+
+ ASSERT_OK_AND_ASSIGN(auto out_buffer, result->ReadAt(0, buffer_size / 2));
+ ASSERT_EQ(0, memcmp(out_buffer->data(), buffer.data(), buffer_size / 2));
+
+ ASSERT_OK_AND_EQ(buffer_size / 2, result->GetSize());
+ ASSERT_OK_AND_EQ(buffer_size / 2, FileGetSize(result->file_descriptor()));
+}
+
+TEST_F(TestMemoryMappedFile, WriteThenShrinkToHalfThenWrite) {
+ const int64_t buffer_size = 1024;
+ std::vector<uint8_t> buffer(buffer_size);
+ random_bytes(buffer_size, 0, buffer.data());
+
+ std::string path = TempFile("io-memory-map-write-read-test");
+ ASSERT_OK_AND_ASSIGN(auto result, InitMemoryMap(buffer_size, path));
+
+ ASSERT_OK(result->Write(buffer.data(), buffer_size));
+ ASSERT_OK(result->Resize(buffer_size / 2));
+
+ ASSERT_OK_AND_EQ(buffer_size / 2, result->Tell());
+
+ ASSERT_OK_AND_ASSIGN(auto out_buffer, result->ReadAt(0, buffer_size / 2));
+ ASSERT_EQ(0, memcmp(out_buffer->data(), buffer.data(), buffer_size / 2));
+ out_buffer.reset();
+
+ // should resume writing directly at the seam
+ ASSERT_OK(result->Resize(buffer_size));
+ ASSERT_OK(result->Write(buffer.data() + buffer_size / 2, buffer_size / 2));
+
+ ASSERT_OK_AND_ASSIGN(out_buffer, result->ReadAt(0, buffer_size));
+ ASSERT_EQ(0, memcmp(out_buffer->data(), buffer.data(), buffer_size));
+
+ ASSERT_OK_AND_EQ(buffer_size, result->GetSize());
+ ASSERT_OK_AND_EQ(buffer_size, FileGetSize(result->file_descriptor()));
+}
+
+TEST_F(TestMemoryMappedFile, ResizeToZeroThanWrite) {
+ const int64_t buffer_size = 1024;
+ std::vector<uint8_t> buffer(buffer_size);
+ random_bytes(buffer_size, 0, buffer.data());
+
+ std::string path = TempFile("io-memory-map-write-read-test");
+ ASSERT_OK_AND_ASSIGN(auto result, InitMemoryMap(buffer_size, path));
+
+ // just a sanity check that writing works ook
+ ASSERT_OK(result->Write(buffer.data(), buffer_size));
+ ASSERT_OK_AND_ASSIGN(auto out_buffer, result->ReadAt(0, buffer_size));
+ ASSERT_EQ(0, memcmp(out_buffer->data(), buffer.data(), buffer_size));
+ out_buffer.reset();
+
+ ASSERT_OK(result->Resize(0));
+ ASSERT_OK_AND_EQ(0, result->GetSize());
+
+ ASSERT_OK_AND_EQ(0, result->Tell());
+
+ ASSERT_OK_AND_EQ(0, FileGetSize(result->file_descriptor()));
+
+ // provision a vector to the buffer size in case ReadAt decides
+ // to read even though it shouldn't
+ std::vector<uint8_t> should_remain_empty(buffer_size);
+ ASSERT_OK_AND_EQ(0, result->ReadAt(0, 1, should_remain_empty.data()));
+
+ // just a sanity check that writing works ook
+ ASSERT_OK(result->Resize(buffer_size));
+ ASSERT_OK(result->Write(buffer.data(), buffer_size));
+ ASSERT_OK_AND_ASSIGN(out_buffer, result->ReadAt(0, buffer_size));
+ ASSERT_EQ(0, memcmp(out_buffer->data(), buffer.data(), buffer_size));
+}
+
+TEST_F(TestMemoryMappedFile, WriteAt) {
+ const int64_t buffer_size = 1024;
+ std::vector<uint8_t> buffer(buffer_size);
+ random_bytes(buffer_size, 0, buffer.data());
+
+ std::string path = TempFile("io-memory-map-write-read-test");
+ ASSERT_OK_AND_ASSIGN(auto result, InitMemoryMap(buffer_size, path));
+
+ ASSERT_OK(result->WriteAt(0, buffer.data(), buffer_size / 2));
+
+ ASSERT_OK(
+ result->WriteAt(buffer_size / 2, buffer.data() + buffer_size / 2, buffer_size / 2));
+
+ ASSERT_OK_AND_ASSIGN(auto out_buffer, result->ReadAt(0, buffer_size));
+
+ ASSERT_EQ(memcmp(out_buffer->data(), buffer.data(), buffer_size), 0);
+}
+
+TEST_F(TestMemoryMappedFile, WriteBeyondEnd) {
+ const int64_t buffer_size = 1024;
+ std::vector<uint8_t> buffer(buffer_size);
+ random_bytes(buffer_size, 0, buffer.data());
+
+ std::string path = TempFile("io-memory-map-write-read-test");
+ ASSERT_OK_AND_ASSIGN(auto result, InitMemoryMap(buffer_size, path));
+
+ ASSERT_OK(result->Seek(1));
+ // Attempt to write beyond end of memory map
+ ASSERT_RAISES(IOError, result->Write(buffer.data(), buffer_size));
+
+ // The position should remain unchanged afterwards
+ ASSERT_OK_AND_EQ(1, result->Tell());
+}
+
+TEST_F(TestMemoryMappedFile, WriteAtBeyondEnd) {
+ const int64_t buffer_size = 1024;
+ std::vector<uint8_t> buffer(buffer_size);
+ random_bytes(buffer_size, 0, buffer.data());
+
+ std::string path = TempFile("io-memory-map-write-read-test");
+ ASSERT_OK_AND_ASSIGN(auto result, InitMemoryMap(buffer_size, path));
+
+ // Attempt to write beyond end of memory map
+ ASSERT_RAISES(IOError, result->WriteAt(1, buffer.data(), buffer_size));
+
+ // The position should remain unchanged afterwards
+ ASSERT_OK_AND_EQ(0, result->Tell());
+}
+
+TEST_F(TestMemoryMappedFile, GetSize) {
+ std::string path = TempFile("io-memory-map-get-size");
+ ASSERT_OK_AND_ASSIGN(auto result, InitMemoryMap(16384, path));
+
+ ASSERT_OK_AND_EQ(16384, result->GetSize());
+
+ ASSERT_OK_AND_EQ(0, result->Tell());
+}
+
+TEST_F(TestMemoryMappedFile, ReadOnly) {
+ const int64_t buffer_size = 1024;
+ std::vector<uint8_t> buffer(buffer_size);
+
+ random_bytes(1024, 0, buffer.data());
+
+ const int reps = 5;
+
+ std::string path = TempFile("ipc-read-only-test");
+ ASSERT_OK_AND_ASSIGN(auto rwmmap, InitMemoryMap(reps * buffer_size, path));
+
+ int64_t position = 0;
+ for (int i = 0; i < reps; ++i) {
+ ASSERT_OK(rwmmap->Write(buffer.data(), buffer_size));
+ position += buffer_size;
+ }
+ ASSERT_OK(rwmmap->Close());
+
+ ASSERT_OK_AND_ASSIGN(auto rommap, MemoryMappedFile::Open(path, FileMode::READ));
+
+ position = 0;
+ for (int i = 0; i < reps; ++i) {
+ ASSERT_OK_AND_ASSIGN(auto out_buffer, rommap->ReadAt(position, buffer_size));
+
+ ASSERT_EQ(0, memcmp(out_buffer->data(), buffer.data(), buffer_size));
+ position += buffer_size;
+ }
+ ASSERT_OK(rommap->Close());
+}
+
+TEST_F(TestMemoryMappedFile, LARGE_MEMORY_TEST(ReadWriteOver4GbFile)) {
+ // ARROW-1096
+ const int64_t buffer_size = 1000 * 1000;
+ std::vector<uint8_t> buffer(buffer_size);
+
+ random_bytes(buffer_size, 0, buffer.data());
+
+ const int64_t reps = 5000;
+
+ std::string path = TempFile("ipc-read-over-4gb-file-test");
+ ASSERT_OK_AND_ASSIGN(auto rwmmap, InitMemoryMap(reps * buffer_size, path));
+ AppendFile(path);
+
+ int64_t position = 0;
+ for (int i = 0; i < reps; ++i) {
+ ASSERT_OK(rwmmap->Write(buffer.data(), buffer_size));
+ position += buffer_size;
+ }
+ ASSERT_OK(rwmmap->Close());
+
+ ASSERT_OK_AND_ASSIGN(auto rommap, MemoryMappedFile::Open(path, FileMode::READ));
+
+ position = 0;
+ for (int i = 0; i < reps; ++i) {
+ ASSERT_OK_AND_ASSIGN(auto out_buffer, rommap->ReadAt(position, buffer_size));
+
+ ASSERT_EQ(0, memcmp(out_buffer->data(), buffer.data(), buffer_size));
+ position += buffer_size;
+ }
+ ASSERT_OK(rommap->Close());
+}
+
+TEST_F(TestMemoryMappedFile, RetainMemoryMapReference) {
+ // ARROW-494
+
+ const int64_t buffer_size = 1024;
+ std::vector<uint8_t> buffer(buffer_size);
+
+ random_bytes(1024, 0, buffer.data());
+
+ std::string path = TempFile("ipc-read-only-test");
+ CreateFile(path, buffer_size);
+
+ {
+ ASSERT_OK_AND_ASSIGN(auto rwmmap, MemoryMappedFile::Open(path, FileMode::READWRITE));
+ ASSERT_OK(rwmmap->Write(buffer.data(), buffer_size));
+ ASSERT_FALSE(rwmmap->closed());
+ ASSERT_OK(rwmmap->Close());
+ ASSERT_TRUE(rwmmap->closed());
+ }
+
+ std::shared_ptr<Buffer> out_buffer;
+
+ {
+ ASSERT_OK_AND_ASSIGN(auto rommap, MemoryMappedFile::Open(path, FileMode::READ));
+ ASSERT_OK_AND_ASSIGN(out_buffer, rommap->Read(buffer_size));
+ ASSERT_FALSE(rommap->closed());
+ ASSERT_OK(rommap->Close());
+ ASSERT_TRUE(rommap->closed());
+ }
+
+ // valgrind will catch if memory is unmapped
+ ASSERT_EQ(0, memcmp(out_buffer->data(), buffer.data(), buffer_size));
+}
+
+TEST_F(TestMemoryMappedFile, InvalidMode) {
+ const int64_t buffer_size = 1024;
+ std::vector<uint8_t> buffer(buffer_size);
+
+ random_bytes(1024, 0, buffer.data());
+
+ std::string path = TempFile("ipc-invalid-mode-test");
+ CreateFile(path, buffer_size);
+
+ ASSERT_OK_AND_ASSIGN(auto rommap, MemoryMappedFile::Open(path, FileMode::READ));
+ ASSERT_RAISES(IOError, rommap->Write(buffer.data(), buffer_size));
+}
+
+TEST_F(TestMemoryMappedFile, InvalidFile) {
+ std::string nonexistent_path = "invalid-file-name-asfd";
+
+ ASSERT_RAISES(IOError, MemoryMappedFile::Open(nonexistent_path, FileMode::READ));
+}
+
+TEST_F(TestMemoryMappedFile, CastableToFileInterface) {
+ std::shared_ptr<MemoryMappedFile> memory_mapped_file;
+ std::shared_ptr<FileInterface> file = memory_mapped_file;
+}
+
+TEST_F(TestMemoryMappedFile, ThreadSafety) {
+ std::string data = "foobar";
+ std::string path = TempFile("ipc-multithreading-test");
+ CreateFile(path, static_cast<int>(data.size()));
+
+ ASSERT_OK_AND_ASSIGN(auto file, MemoryMappedFile::Open(path, FileMode::READWRITE));
+ ASSERT_OK(file->Write(data.c_str(), static_cast<int64_t>(data.size())));
+
+ std::atomic<int> correct_count(0);
+ int niter = 10000;
+
+ auto ReadData = [&correct_count, &data, &file, &niter]() {
+ for (int i = 0; i < niter; ++i) {
+ ASSERT_OK_AND_ASSIGN(auto buffer, file->ReadAt(0, 3));
+ if (0 == memcmp(data.c_str(), buffer->data(), 3)) {
+ correct_count += 1;
+ }
+ }
+ };
+
+ std::thread thread1(ReadData);
+ std::thread thread2(ReadData);
+
+ thread1.join();
+ thread2.join();
+
+ ASSERT_EQ(niter * 2, correct_count);
+}
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/hdfs.cc b/src/arrow/cpp/src/arrow/io/hdfs.cc
new file mode 100644
index 000000000..cd9e91205
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/hdfs.cc
@@ -0,0 +1,738 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cerrno>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/io/hdfs.h"
+#include "arrow/io/hdfs_internal.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/memory_pool.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+
+using std::size_t;
+
+namespace arrow {
+
+using internal::IOErrorFromErrno;
+
+namespace io {
+
+namespace {
+
+std::string TranslateErrno(int error_code) {
+ std::stringstream ss;
+ ss << error_code << " (" << strerror(error_code) << ")";
+ if (error_code == 255) {
+ // Unknown error can occur if the host is correct but the port is not
+ ss << " Please check that you are connecting to the correct HDFS RPC port";
+ }
+ return ss.str();
+}
+
+} // namespace
+
+#define CHECK_FAILURE(RETURN_VALUE, WHAT) \
+ do { \
+ if (RETURN_VALUE == -1) { \
+ return Status::IOError("HDFS ", WHAT, " failed, errno: ", TranslateErrno(errno)); \
+ } \
+ } while (0)
+
+static constexpr int kDefaultHdfsBufferSize = 1 << 16;
+
+// ----------------------------------------------------------------------
+// File reading
+
+class HdfsAnyFileImpl {
+ public:
+ void set_members(const std::string& path, internal::LibHdfsShim* driver, hdfsFS fs,
+ hdfsFile handle) {
+ path_ = path;
+ driver_ = driver;
+ fs_ = fs;
+ file_ = handle;
+ is_open_ = true;
+ }
+
+ Status Seek(int64_t position) {
+ RETURN_NOT_OK(CheckClosed());
+ int ret = driver_->Seek(fs_, file_, position);
+ CHECK_FAILURE(ret, "seek");
+ return Status::OK();
+ }
+
+ Result<int64_t> Tell() {
+ RETURN_NOT_OK(CheckClosed());
+ int64_t ret = driver_->Tell(fs_, file_);
+ CHECK_FAILURE(ret, "tell");
+ return ret;
+ }
+
+ bool is_open() const { return is_open_; }
+
+ protected:
+ Status CheckClosed() {
+ if (!is_open_) {
+ return Status::Invalid("Operation on closed HDFS file");
+ }
+ return Status::OK();
+ }
+
+ std::string path_;
+
+ internal::LibHdfsShim* driver_;
+
+ // For threadsafety
+ std::mutex lock_;
+
+ // These are pointers in libhdfs, so OK to copy
+ hdfsFS fs_;
+ hdfsFile file_;
+
+ bool is_open_;
+};
+
+namespace {
+
+Status GetPathInfoFailed(const std::string& path) {
+ std::stringstream ss;
+ ss << "Calling GetPathInfo for " << path << " failed. errno: " << TranslateErrno(errno);
+ return Status::IOError(ss.str());
+}
+
+} // namespace
+
+// Private implementation for read-only files
+class HdfsReadableFile::HdfsReadableFileImpl : public HdfsAnyFileImpl {
+ public:
+ explicit HdfsReadableFileImpl(MemoryPool* pool) : pool_(pool) {}
+
+ Status Close() {
+ if (is_open_) {
+ // is_open_ must be set to false in the beginning, because the destructor
+ // attempts to close the stream again, and if the first close fails, then
+ // the error doesn't get propagated properly and the second close
+ // initiated by the destructor raises a segfault
+ is_open_ = false;
+ int ret = driver_->CloseFile(fs_, file_);
+ CHECK_FAILURE(ret, "CloseFile");
+ }
+ return Status::OK();
+ }
+
+ bool closed() const { return !is_open_; }
+
+ Result<int64_t> ReadAt(int64_t position, int64_t nbytes, uint8_t* buffer) {
+ RETURN_NOT_OK(CheckClosed());
+ if (!driver_->HasPread()) {
+ std::lock_guard<std::mutex> guard(lock_);
+ RETURN_NOT_OK(Seek(position));
+ return Read(nbytes, buffer);
+ }
+
+ constexpr int64_t kMaxBlockSize = std::numeric_limits<int32_t>::max();
+ int64_t total_bytes = 0;
+ while (nbytes > 0) {
+ const auto block_size = static_cast<tSize>(std::min(kMaxBlockSize, nbytes));
+ tSize ret =
+ driver_->Pread(fs_, file_, static_cast<tOffset>(position), buffer, block_size);
+ CHECK_FAILURE(ret, "read");
+ DCHECK_LE(ret, block_size);
+ if (ret == 0) {
+ break; // EOF
+ }
+ buffer += ret;
+ total_bytes += ret;
+ position += ret;
+ nbytes -= ret;
+ }
+ return total_bytes;
+ }
+
+ Result<std::shared_ptr<Buffer>> ReadAt(int64_t position, int64_t nbytes) {
+ RETURN_NOT_OK(CheckClosed());
+
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateResizableBuffer(nbytes, pool_));
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read,
+ ReadAt(position, nbytes, buffer->mutable_data()));
+ if (bytes_read < nbytes) {
+ RETURN_NOT_OK(buffer->Resize(bytes_read));
+ buffer->ZeroPadding();
+ }
+ return std::move(buffer);
+ }
+
+ Result<int64_t> Read(int64_t nbytes, void* buffer) {
+ RETURN_NOT_OK(CheckClosed());
+
+ int64_t total_bytes = 0;
+ while (total_bytes < nbytes) {
+ tSize ret = driver_->Read(
+ fs_, file_, reinterpret_cast<uint8_t*>(buffer) + total_bytes,
+ static_cast<tSize>(std::min<int64_t>(buffer_size_, nbytes - total_bytes)));
+ CHECK_FAILURE(ret, "read");
+ total_bytes += ret;
+ if (ret == 0) {
+ break;
+ }
+ }
+ return total_bytes;
+ }
+
+ Result<std::shared_ptr<Buffer>> Read(int64_t nbytes) {
+ RETURN_NOT_OK(CheckClosed());
+
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateResizableBuffer(nbytes, pool_));
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, Read(nbytes, buffer->mutable_data()));
+ if (bytes_read < nbytes) {
+ RETURN_NOT_OK(buffer->Resize(bytes_read));
+ }
+ return std::move(buffer);
+ }
+
+ Result<int64_t> GetSize() {
+ RETURN_NOT_OK(CheckClosed());
+
+ hdfsFileInfo* entry = driver_->GetPathInfo(fs_, path_.c_str());
+ if (entry == nullptr) {
+ return GetPathInfoFailed(path_);
+ }
+ int64_t size = entry->mSize;
+ driver_->FreeFileInfo(entry, 1);
+ return size;
+ }
+
+ void set_memory_pool(MemoryPool* pool) { pool_ = pool; }
+
+ void set_buffer_size(int32_t buffer_size) { buffer_size_ = buffer_size; }
+
+ private:
+ MemoryPool* pool_;
+ int32_t buffer_size_;
+};
+
+HdfsReadableFile::HdfsReadableFile(const io::IOContext& io_context) {
+ impl_.reset(new HdfsReadableFileImpl(io_context.pool()));
+}
+
+HdfsReadableFile::~HdfsReadableFile() { DCHECK_OK(impl_->Close()); }
+
+Status HdfsReadableFile::Close() { return impl_->Close(); }
+
+bool HdfsReadableFile::closed() const { return impl_->closed(); }
+
+Result<int64_t> HdfsReadableFile::ReadAt(int64_t position, int64_t nbytes, void* buffer) {
+ return impl_->ReadAt(position, nbytes, reinterpret_cast<uint8_t*>(buffer));
+}
+
+Result<std::shared_ptr<Buffer>> HdfsReadableFile::ReadAt(int64_t position,
+ int64_t nbytes) {
+ return impl_->ReadAt(position, nbytes);
+}
+
+Result<int64_t> HdfsReadableFile::Read(int64_t nbytes, void* buffer) {
+ return impl_->Read(nbytes, buffer);
+}
+
+Result<std::shared_ptr<Buffer>> HdfsReadableFile::Read(int64_t nbytes) {
+ return impl_->Read(nbytes);
+}
+
+Result<int64_t> HdfsReadableFile::GetSize() { return impl_->GetSize(); }
+
+Status HdfsReadableFile::Seek(int64_t position) { return impl_->Seek(position); }
+
+Result<int64_t> HdfsReadableFile::Tell() const { return impl_->Tell(); }
+
+// ----------------------------------------------------------------------
+// File writing
+
+// Private implementation for writable-only files
+class HdfsOutputStream::HdfsOutputStreamImpl : public HdfsAnyFileImpl {
+ public:
+ HdfsOutputStreamImpl() {}
+
+ Status Close() {
+ if (is_open_) {
+ // is_open_ must be set to false in the beginning, because the destructor
+ // attempts to close the stream again, and if the first close fails, then
+ // the error doesn't get propagated properly and the second close
+ // initiated by the destructor raises a segfault
+ is_open_ = false;
+ RETURN_NOT_OK(FlushInternal());
+ int ret = driver_->CloseFile(fs_, file_);
+ CHECK_FAILURE(ret, "CloseFile");
+ }
+ return Status::OK();
+ }
+
+ bool closed() const { return !is_open_; }
+
+ Status Flush() {
+ RETURN_NOT_OK(CheckClosed());
+
+ return FlushInternal();
+ }
+
+ Status Write(const uint8_t* buffer, int64_t nbytes) {
+ RETURN_NOT_OK(CheckClosed());
+
+ constexpr int64_t kMaxBlockSize = std::numeric_limits<int32_t>::max();
+
+ std::lock_guard<std::mutex> guard(lock_);
+ while (nbytes > 0) {
+ const auto block_size = static_cast<tSize>(std::min(kMaxBlockSize, nbytes));
+ tSize ret = driver_->Write(fs_, file_, buffer, block_size);
+ CHECK_FAILURE(ret, "Write");
+ DCHECK_LE(ret, block_size);
+ buffer += ret;
+ nbytes -= ret;
+ }
+ return Status::OK();
+ }
+
+ protected:
+ Status FlushInternal() {
+ int ret = driver_->Flush(fs_, file_);
+ CHECK_FAILURE(ret, "Flush");
+ return Status::OK();
+ }
+};
+
+HdfsOutputStream::HdfsOutputStream() { impl_.reset(new HdfsOutputStreamImpl()); }
+
+HdfsOutputStream::~HdfsOutputStream() { DCHECK_OK(impl_->Close()); }
+
+Status HdfsOutputStream::Close() { return impl_->Close(); }
+
+bool HdfsOutputStream::closed() const { return impl_->closed(); }
+
+Status HdfsOutputStream::Write(const void* buffer, int64_t nbytes) {
+ return impl_->Write(reinterpret_cast<const uint8_t*>(buffer), nbytes);
+}
+
+Status HdfsOutputStream::Flush() { return impl_->Flush(); }
+
+Result<int64_t> HdfsOutputStream::Tell() const { return impl_->Tell(); }
+
+// ----------------------------------------------------------------------
+// HDFS client
+
+// TODO(wesm): this could throw std::bad_alloc in the course of copying strings
+// into the path info object
+static void SetPathInfo(const hdfsFileInfo* input, HdfsPathInfo* out) {
+ out->kind = input->mKind == kObjectKindFile ? ObjectType::FILE : ObjectType::DIRECTORY;
+ out->name = std::string(input->mName);
+ out->owner = std::string(input->mOwner);
+ out->group = std::string(input->mGroup);
+
+ out->last_access_time = static_cast<int32_t>(input->mLastAccess);
+ out->last_modified_time = static_cast<int32_t>(input->mLastMod);
+ out->size = static_cast<int64_t>(input->mSize);
+
+ out->replication = input->mReplication;
+ out->block_size = input->mBlockSize;
+
+ out->permissions = input->mPermissions;
+}
+
+// Private implementation
+class HadoopFileSystem::HadoopFileSystemImpl {
+ public:
+ HadoopFileSystemImpl() : driver_(NULLPTR), port_(0), fs_(NULLPTR) {}
+
+ Status Connect(const HdfsConnectionConfig* config) {
+ RETURN_NOT_OK(ConnectLibHdfs(&driver_));
+
+ // connect to HDFS with the builder object
+ hdfsBuilder* builder = driver_->NewBuilder();
+ if (!config->host.empty()) {
+ driver_->BuilderSetNameNode(builder, config->host.c_str());
+ }
+ driver_->BuilderSetNameNodePort(builder, static_cast<tPort>(config->port));
+ if (!config->user.empty()) {
+ driver_->BuilderSetUserName(builder, config->user.c_str());
+ }
+ if (!config->kerb_ticket.empty()) {
+ driver_->BuilderSetKerbTicketCachePath(builder, config->kerb_ticket.c_str());
+ }
+
+ for (const auto& kv : config->extra_conf) {
+ int ret = driver_->BuilderConfSetStr(builder, kv.first.c_str(), kv.second.c_str());
+ CHECK_FAILURE(ret, "confsetstr");
+ }
+
+ driver_->BuilderSetForceNewInstance(builder);
+ fs_ = driver_->BuilderConnect(builder);
+
+ if (fs_ == nullptr) {
+ return Status::IOError("HDFS connection failed");
+ }
+ namenode_host_ = config->host;
+ port_ = config->port;
+ user_ = config->user;
+ kerb_ticket_ = config->kerb_ticket;
+
+ return Status::OK();
+ }
+
+ Status MakeDirectory(const std::string& path) {
+ int ret = driver_->MakeDirectory(fs_, path.c_str());
+ CHECK_FAILURE(ret, "create directory");
+ return Status::OK();
+ }
+
+ Status Delete(const std::string& path, bool recursive) {
+ int ret = driver_->Delete(fs_, path.c_str(), static_cast<int>(recursive));
+ CHECK_FAILURE(ret, "delete");
+ return Status::OK();
+ }
+
+ Status Disconnect() {
+ int ret = driver_->Disconnect(fs_);
+ CHECK_FAILURE(ret, "hdfsFS::Disconnect");
+ return Status::OK();
+ }
+
+ bool Exists(const std::string& path) {
+ // hdfsExists does not distinguish between RPC failure and the file not
+ // existing
+ int ret = driver_->Exists(fs_, path.c_str());
+ return ret == 0;
+ }
+
+ Status GetCapacity(int64_t* nbytes) {
+ tOffset ret = driver_->GetCapacity(fs_);
+ CHECK_FAILURE(ret, "GetCapacity");
+ *nbytes = ret;
+ return Status::OK();
+ }
+
+ Status GetUsed(int64_t* nbytes) {
+ tOffset ret = driver_->GetUsed(fs_);
+ CHECK_FAILURE(ret, "GetUsed");
+ *nbytes = ret;
+ return Status::OK();
+ }
+
+ Status GetWorkingDirectory(std::string* out) {
+ char buffer[2048];
+ if (driver_->GetWorkingDirectory(fs_, buffer, sizeof(buffer) - 1) == nullptr) {
+ return Status::IOError("HDFS GetWorkingDirectory failed, errno: ",
+ TranslateErrno(errno));
+ }
+ *out = buffer;
+ return Status::OK();
+ }
+
+ Status GetPathInfo(const std::string& path, HdfsPathInfo* info) {
+ hdfsFileInfo* entry = driver_->GetPathInfo(fs_, path.c_str());
+
+ if (entry == nullptr) {
+ return GetPathInfoFailed(path);
+ }
+
+ SetPathInfo(entry, info);
+ driver_->FreeFileInfo(entry, 1);
+
+ return Status::OK();
+ }
+
+ Status Stat(const std::string& path, FileStatistics* stat) {
+ HdfsPathInfo info;
+ RETURN_NOT_OK(GetPathInfo(path, &info));
+
+ stat->size = info.size;
+ stat->kind = info.kind;
+ return Status::OK();
+ }
+
+ Status GetChildren(const std::string& path, std::vector<std::string>* listing) {
+ std::vector<HdfsPathInfo> detailed_listing;
+ RETURN_NOT_OK(ListDirectory(path, &detailed_listing));
+ for (const auto& info : detailed_listing) {
+ listing->push_back(info.name);
+ }
+ return Status::OK();
+ }
+
+ Status ListDirectory(const std::string& path, std::vector<HdfsPathInfo>* listing) {
+ int num_entries = 0;
+ errno = 0;
+ hdfsFileInfo* entries = driver_->ListDirectory(fs_, path.c_str(), &num_entries);
+
+ if (entries == nullptr) {
+ // If the directory is empty, entries is NULL but errno is 0. Non-zero
+ // errno indicates error
+ //
+ // Note: errno is thread-local
+ //
+ // XXX(wesm): ARROW-2300; we found with Hadoop 2.6 that libhdfs would set
+ // errno 2/ENOENT for empty directories. To be more robust to this we
+ // double check this case
+ if ((errno == 0) || (errno == ENOENT && Exists(path))) {
+ num_entries = 0;
+ } else {
+ return Status::IOError("HDFS list directory failed, errno: ",
+ TranslateErrno(errno));
+ }
+ }
+
+ // Allocate additional space for elements
+ int vec_offset = static_cast<int>(listing->size());
+ listing->resize(vec_offset + num_entries);
+
+ for (int i = 0; i < num_entries; ++i) {
+ SetPathInfo(entries + i, &(*listing)[vec_offset + i]);
+ }
+
+ // Free libhdfs file info
+ driver_->FreeFileInfo(entries, num_entries);
+
+ return Status::OK();
+ }
+
+ Status OpenReadable(const std::string& path, int32_t buffer_size,
+ const io::IOContext& io_context,
+ std::shared_ptr<HdfsReadableFile>* file) {
+ errno = 0;
+ hdfsFile handle = driver_->OpenFile(fs_, path.c_str(), O_RDONLY, buffer_size, 0, 0);
+
+ if (handle == nullptr) {
+ if (errno) {
+ return IOErrorFromErrno(errno, "Opening HDFS file '", path, "' failed");
+ } else {
+ return Status::IOError("Opening HDFS file '", path, "' failed");
+ }
+ }
+
+ // std::make_shared does not work with private ctors
+ *file = std::shared_ptr<HdfsReadableFile>(new HdfsReadableFile(io_context));
+ (*file)->impl_->set_members(path, driver_, fs_, handle);
+ (*file)->impl_->set_buffer_size(buffer_size);
+
+ return Status::OK();
+ }
+
+ Status OpenWritable(const std::string& path, bool append, int32_t buffer_size,
+ int16_t replication, int64_t default_block_size,
+ std::shared_ptr<HdfsOutputStream>* file) {
+ int flags = O_WRONLY;
+ if (append) flags |= O_APPEND;
+
+ errno = 0;
+ hdfsFile handle =
+ driver_->OpenFile(fs_, path.c_str(), flags, buffer_size, replication,
+ static_cast<tSize>(default_block_size));
+
+ if (handle == nullptr) {
+ if (errno) {
+ return IOErrorFromErrno(errno, "Opening HDFS file '", path, "' failed");
+ } else {
+ return Status::IOError("Opening HDFS file '", path, "' failed");
+ }
+ }
+
+ // std::make_shared does not work with private ctors
+ *file = std::shared_ptr<HdfsOutputStream>(new HdfsOutputStream());
+ (*file)->impl_->set_members(path, driver_, fs_, handle);
+
+ return Status::OK();
+ }
+
+ Status Rename(const std::string& src, const std::string& dst) {
+ int ret = driver_->Rename(fs_, src.c_str(), dst.c_str());
+ CHECK_FAILURE(ret, "Rename");
+ return Status::OK();
+ }
+
+ Status Copy(const std::string& src, const std::string& dst) {
+ int ret = driver_->Copy(fs_, src.c_str(), fs_, dst.c_str());
+ CHECK_FAILURE(ret, "Rename");
+ return Status::OK();
+ }
+
+ Status Move(const std::string& src, const std::string& dst) {
+ int ret = driver_->Move(fs_, src.c_str(), fs_, dst.c_str());
+ CHECK_FAILURE(ret, "Rename");
+ return Status::OK();
+ }
+
+ Status Chmod(const std::string& path, int mode) {
+ int ret = driver_->Chmod(fs_, path.c_str(), static_cast<short>(mode)); // NOLINT
+ CHECK_FAILURE(ret, "Chmod");
+ return Status::OK();
+ }
+
+ Status Chown(const std::string& path, const char* owner, const char* group) {
+ int ret = driver_->Chown(fs_, path.c_str(), owner, group);
+ CHECK_FAILURE(ret, "Chown");
+ return Status::OK();
+ }
+
+ private:
+ internal::LibHdfsShim* driver_;
+
+ std::string namenode_host_;
+ std::string user_;
+ int port_;
+ std::string kerb_ticket_;
+
+ hdfsFS fs_;
+};
+
+// ----------------------------------------------------------------------
+// Public API for HDFSClient
+
+HadoopFileSystem::HadoopFileSystem() { impl_.reset(new HadoopFileSystemImpl()); }
+
+HadoopFileSystem::~HadoopFileSystem() {}
+
+Status HadoopFileSystem::Connect(const HdfsConnectionConfig* config,
+ std::shared_ptr<HadoopFileSystem>* fs) {
+ // ctor is private, make_shared will not work
+ *fs = std::shared_ptr<HadoopFileSystem>(new HadoopFileSystem());
+
+ RETURN_NOT_OK((*fs)->impl_->Connect(config));
+ return Status::OK();
+}
+
+Status HadoopFileSystem::MakeDirectory(const std::string& path) {
+ return impl_->MakeDirectory(path);
+}
+
+Status HadoopFileSystem::Delete(const std::string& path, bool recursive) {
+ return impl_->Delete(path, recursive);
+}
+
+Status HadoopFileSystem::DeleteDirectory(const std::string& path) {
+ return Delete(path, true);
+}
+
+Status HadoopFileSystem::Disconnect() { return impl_->Disconnect(); }
+
+bool HadoopFileSystem::Exists(const std::string& path) { return impl_->Exists(path); }
+
+Status HadoopFileSystem::GetPathInfo(const std::string& path, HdfsPathInfo* info) {
+ return impl_->GetPathInfo(path, info);
+}
+
+Status HadoopFileSystem::Stat(const std::string& path, FileStatistics* stat) {
+ return impl_->Stat(path, stat);
+}
+
+Status HadoopFileSystem::GetCapacity(int64_t* nbytes) {
+ return impl_->GetCapacity(nbytes);
+}
+
+Status HadoopFileSystem::GetUsed(int64_t* nbytes) { return impl_->GetUsed(nbytes); }
+
+Status HadoopFileSystem::GetWorkingDirectory(std::string* out) {
+ return impl_->GetWorkingDirectory(out);
+}
+
+Status HadoopFileSystem::GetChildren(const std::string& path,
+ std::vector<std::string>* listing) {
+ return impl_->GetChildren(path, listing);
+}
+
+Status HadoopFileSystem::ListDirectory(const std::string& path,
+ std::vector<HdfsPathInfo>* listing) {
+ return impl_->ListDirectory(path, listing);
+}
+
+Status HadoopFileSystem::OpenReadable(const std::string& path, int32_t buffer_size,
+ std::shared_ptr<HdfsReadableFile>* file) {
+ return impl_->OpenReadable(path, buffer_size, io::default_io_context(), file);
+}
+
+Status HadoopFileSystem::OpenReadable(const std::string& path,
+ std::shared_ptr<HdfsReadableFile>* file) {
+ return OpenReadable(path, kDefaultHdfsBufferSize, io::default_io_context(), file);
+}
+
+Status HadoopFileSystem::OpenReadable(const std::string& path, int32_t buffer_size,
+ const io::IOContext& io_context,
+ std::shared_ptr<HdfsReadableFile>* file) {
+ return impl_->OpenReadable(path, buffer_size, io_context, file);
+}
+
+Status HadoopFileSystem::OpenReadable(const std::string& path,
+ const io::IOContext& io_context,
+ std::shared_ptr<HdfsReadableFile>* file) {
+ return OpenReadable(path, kDefaultHdfsBufferSize, io_context, file);
+}
+
+Status HadoopFileSystem::OpenWritable(const std::string& path, bool append,
+ int32_t buffer_size, int16_t replication,
+ int64_t default_block_size,
+ std::shared_ptr<HdfsOutputStream>* file) {
+ return impl_->OpenWritable(path, append, buffer_size, replication, default_block_size,
+ file);
+}
+
+Status HadoopFileSystem::OpenWritable(const std::string& path, bool append,
+ std::shared_ptr<HdfsOutputStream>* file) {
+ return OpenWritable(path, append, 0, 0, 0, file);
+}
+
+Status HadoopFileSystem::Chmod(const std::string& path, int mode) {
+ return impl_->Chmod(path, mode);
+}
+
+Status HadoopFileSystem::Chown(const std::string& path, const char* owner,
+ const char* group) {
+ return impl_->Chown(path, owner, group);
+}
+
+Status HadoopFileSystem::Rename(const std::string& src, const std::string& dst) {
+ return impl_->Rename(src, dst);
+}
+
+Status HadoopFileSystem::Copy(const std::string& src, const std::string& dst) {
+ return impl_->Copy(src, dst);
+}
+
+Status HadoopFileSystem::Move(const std::string& src, const std::string& dst) {
+ return impl_->Move(src, dst);
+}
+
+// ----------------------------------------------------------------------
+// Allow public API users to check whether we are set up correctly
+
+Status HaveLibHdfs() {
+ internal::LibHdfsShim* driver;
+ return internal::ConnectLibHdfs(&driver);
+}
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/hdfs.h b/src/arrow/cpp/src/arrow/io/hdfs.h
new file mode 100644
index 000000000..5244eb052
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/hdfs.h
@@ -0,0 +1,284 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "arrow/io/interfaces.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Buffer;
+class MemoryPool;
+class Status;
+
+namespace io {
+
+class HdfsReadableFile;
+class HdfsOutputStream;
+
+/// DEPRECATED. Use the FileSystem API in arrow::fs instead.
+struct ObjectType {
+ enum type { FILE, DIRECTORY };
+};
+
+/// DEPRECATED. Use the FileSystem API in arrow::fs instead.
+struct ARROW_EXPORT FileStatistics {
+ /// Size of file, -1 if finding length is unsupported
+ int64_t size;
+ ObjectType::type kind;
+};
+
+class ARROW_EXPORT FileSystem {
+ public:
+ virtual ~FileSystem() = default;
+
+ virtual Status MakeDirectory(const std::string& path) = 0;
+
+ virtual Status DeleteDirectory(const std::string& path) = 0;
+
+ virtual Status GetChildren(const std::string& path,
+ std::vector<std::string>* listing) = 0;
+
+ virtual Status Rename(const std::string& src, const std::string& dst) = 0;
+
+ virtual Status Stat(const std::string& path, FileStatistics* stat) = 0;
+};
+
+struct HdfsPathInfo {
+ ObjectType::type kind;
+
+ std::string name;
+ std::string owner;
+ std::string group;
+
+ // Access times in UNIX timestamps (seconds)
+ int64_t size;
+ int64_t block_size;
+
+ int32_t last_modified_time;
+ int32_t last_access_time;
+
+ int16_t replication;
+ int16_t permissions;
+};
+
+struct HdfsConnectionConfig {
+ std::string host;
+ int port;
+ std::string user;
+ std::string kerb_ticket;
+ std::unordered_map<std::string, std::string> extra_conf;
+};
+
+class ARROW_EXPORT HadoopFileSystem : public FileSystem {
+ public:
+ ~HadoopFileSystem() override;
+
+ // Connect to an HDFS cluster given a configuration
+ //
+ // @param config (in): configuration for connecting
+ // @param fs (out): the created client
+ // @returns Status
+ static Status Connect(const HdfsConnectionConfig* config,
+ std::shared_ptr<HadoopFileSystem>* fs);
+
+ // Create directory and all parents
+ //
+ // @param path (in): absolute HDFS path
+ // @returns Status
+ Status MakeDirectory(const std::string& path) override;
+
+ // Delete file or directory
+ // @param path absolute path to data
+ // @param recursive if path is a directory, delete contents as well
+ // @returns error status on failure
+ Status Delete(const std::string& path, bool recursive = false);
+
+ Status DeleteDirectory(const std::string& path) override;
+
+ // Disconnect from cluster
+ //
+ // @returns Status
+ Status Disconnect();
+
+ // @param path (in): absolute HDFS path
+ // @returns bool, true if the path exists, false if not (or on error)
+ bool Exists(const std::string& path);
+
+ // @param path (in): absolute HDFS path
+ // @param info (out)
+ // @returns Status
+ Status GetPathInfo(const std::string& path, HdfsPathInfo* info);
+
+ // @param nbytes (out): total capacity of the filesystem
+ // @returns Status
+ Status GetCapacity(int64_t* nbytes);
+
+ // @param nbytes (out): total bytes used of the filesystem
+ // @returns Status
+ Status GetUsed(int64_t* nbytes);
+
+ Status GetChildren(const std::string& path, std::vector<std::string>* listing) override;
+
+ /// List directory contents
+ ///
+ /// If path is a relative path, returned values will be absolute paths or URIs
+ /// starting from the current working directory.
+ Status ListDirectory(const std::string& path, std::vector<HdfsPathInfo>* listing);
+
+ /// Return the filesystem's current working directory.
+ ///
+ /// The working directory is the base path for all relative paths given to
+ /// other APIs.
+ /// NOTE: this actually returns a URI.
+ Status GetWorkingDirectory(std::string* out);
+
+ /// Change
+ ///
+ /// @param path file path to change
+ /// @param owner pass null for no change
+ /// @param group pass null for no change
+ Status Chown(const std::string& path, const char* owner, const char* group);
+
+ /// Change path permissions
+ ///
+ /// \param path Absolute path in file system
+ /// \param mode Mode bitset
+ /// \return Status
+ Status Chmod(const std::string& path, int mode);
+
+ // Move file or directory from source path to destination path within the
+ // current filesystem
+ Status Rename(const std::string& src, const std::string& dst) override;
+
+ Status Copy(const std::string& src, const std::string& dst);
+
+ Status Move(const std::string& src, const std::string& dst);
+
+ Status Stat(const std::string& path, FileStatistics* stat) override;
+
+ // TODO(wesm): GetWorkingDirectory, SetWorkingDirectory
+
+ // Open an HDFS file in READ mode. Returns error
+ // status if the file is not found.
+ //
+ // @param path complete file path
+ Status OpenReadable(const std::string& path, int32_t buffer_size,
+ std::shared_ptr<HdfsReadableFile>* file);
+
+ Status OpenReadable(const std::string& path, int32_t buffer_size,
+ const io::IOContext& io_context,
+ std::shared_ptr<HdfsReadableFile>* file);
+
+ Status OpenReadable(const std::string& path, std::shared_ptr<HdfsReadableFile>* file);
+
+ Status OpenReadable(const std::string& path, const io::IOContext& io_context,
+ std::shared_ptr<HdfsReadableFile>* file);
+
+ // FileMode::WRITE options
+ // @param path complete file path
+ // @param buffer_size 0 by default
+ // @param replication 0 by default
+ // @param default_block_size 0 by default
+ Status OpenWritable(const std::string& path, bool append, int32_t buffer_size,
+ int16_t replication, int64_t default_block_size,
+ std::shared_ptr<HdfsOutputStream>* file);
+
+ Status OpenWritable(const std::string& path, bool append,
+ std::shared_ptr<HdfsOutputStream>* file);
+
+ private:
+ friend class HdfsReadableFile;
+ friend class HdfsOutputStream;
+
+ class ARROW_NO_EXPORT HadoopFileSystemImpl;
+ std::unique_ptr<HadoopFileSystemImpl> impl_;
+
+ HadoopFileSystem();
+ ARROW_DISALLOW_COPY_AND_ASSIGN(HadoopFileSystem);
+};
+
+class ARROW_EXPORT HdfsReadableFile : public RandomAccessFile {
+ public:
+ ~HdfsReadableFile() override;
+
+ Status Close() override;
+
+ bool closed() const override;
+
+ // NOTE: If you wish to read a particular range of a file in a multithreaded
+ // context, you may prefer to use ReadAt to avoid locking issues
+ Result<int64_t> Read(int64_t nbytes, void* out) override;
+ Result<std::shared_ptr<Buffer>> Read(int64_t nbytes) override;
+ Result<int64_t> ReadAt(int64_t position, int64_t nbytes, void* out) override;
+ Result<std::shared_ptr<Buffer>> ReadAt(int64_t position, int64_t nbytes) override;
+
+ Status Seek(int64_t position) override;
+ Result<int64_t> Tell() const override;
+ Result<int64_t> GetSize() override;
+
+ private:
+ explicit HdfsReadableFile(const io::IOContext&);
+
+ class ARROW_NO_EXPORT HdfsReadableFileImpl;
+ std::unique_ptr<HdfsReadableFileImpl> impl_;
+
+ friend class HadoopFileSystem::HadoopFileSystemImpl;
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(HdfsReadableFile);
+};
+
+// Naming this file OutputStream because it does not support seeking (like the
+// WritableFile interface)
+class ARROW_EXPORT HdfsOutputStream : public OutputStream {
+ public:
+ ~HdfsOutputStream() override;
+
+ Status Close() override;
+
+ bool closed() const override;
+
+ using OutputStream::Write;
+ Status Write(const void* buffer, int64_t nbytes) override;
+
+ Status Flush() override;
+
+ Result<int64_t> Tell() const override;
+
+ private:
+ class ARROW_NO_EXPORT HdfsOutputStreamImpl;
+ std::unique_ptr<HdfsOutputStreamImpl> impl_;
+
+ friend class HadoopFileSystem::HadoopFileSystemImpl;
+
+ HdfsOutputStream();
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(HdfsOutputStream);
+};
+
+Status ARROW_EXPORT HaveLibHdfs();
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/hdfs_internal.cc b/src/arrow/cpp/src/arrow/io/hdfs_internal.cc
new file mode 100644
index 000000000..4592392b8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/hdfs_internal.cc
@@ -0,0 +1,556 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This shim interface to libhdfs (for runtime shared library loading) has been
+// adapted from the SFrame project, released under the ASF-compatible 3-clause
+// BSD license
+//
+// Using this required having the $JAVA_HOME and $HADOOP_HOME environment
+// variables set, so that libjvm and libhdfs can be located easily
+
+// Copyright (C) 2015 Dato, Inc.
+// All rights reserved.
+//
+// This software may be modified and distributed under the terms
+// of the BSD license. See the LICENSE file for details.
+
+#include "arrow/io/hdfs_internal.h"
+
+#include <cstdint>
+#include <cstdlib>
+#include <mutex>
+#include <sstream> // IWYU pragma: keep
+#include <string>
+#include <utility>
+#include <vector>
+
+#ifndef _WIN32
+#include <dlfcn.h>
+#endif
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::GetEnvVarNative;
+using internal::PlatformFilename;
+#ifdef _WIN32
+using internal::WinErrorMessage;
+#endif
+
+namespace io {
+namespace internal {
+
+namespace {
+
+void* GetLibrarySymbol(LibraryHandle handle, const char* symbol) {
+ if (handle == NULL) return NULL;
+#ifndef _WIN32
+ return dlsym(handle, symbol);
+#else
+
+ void* ret = reinterpret_cast<void*>(GetProcAddress(handle, symbol));
+ if (ret == NULL) {
+ // logstream(LOG_INFO) << "GetProcAddress error: "
+ // << get_last_err_str(GetLastError()) << std::endl;
+ }
+ return ret;
+#endif
+}
+
+#define GET_SYMBOL_REQUIRED(SHIM, SYMBOL_NAME) \
+ do { \
+ if (!SHIM->SYMBOL_NAME) { \
+ *reinterpret_cast<void**>(&SHIM->SYMBOL_NAME) = \
+ GetLibrarySymbol(SHIM->handle, "" #SYMBOL_NAME); \
+ } \
+ if (!SHIM->SYMBOL_NAME) \
+ return Status::IOError("Getting symbol " #SYMBOL_NAME "failed"); \
+ } while (0)
+
+#define GET_SYMBOL(SHIM, SYMBOL_NAME) \
+ if (!SHIM->SYMBOL_NAME) { \
+ *reinterpret_cast<void**>(&SHIM->SYMBOL_NAME) = \
+ GetLibrarySymbol(SHIM->handle, "" #SYMBOL_NAME); \
+ }
+
+LibraryHandle libjvm_handle = nullptr;
+
+// Helper functions for dlopens
+Result<std::vector<PlatformFilename>> get_potential_libjvm_paths();
+Result<std::vector<PlatformFilename>> get_potential_libhdfs_paths();
+Result<LibraryHandle> try_dlopen(const std::vector<PlatformFilename>& potential_paths,
+ const char* name);
+
+Result<std::vector<PlatformFilename>> MakeFilenameVector(
+ const std::vector<std::string>& names) {
+ std::vector<PlatformFilename> filenames(names.size());
+ for (size_t i = 0; i < names.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(filenames[i], PlatformFilename::FromString(names[i]));
+ }
+ return filenames;
+}
+
+void AppendEnvVarFilename(const char* var_name,
+ std::vector<PlatformFilename>* filenames) {
+ auto maybe_env_var = GetEnvVarNative(var_name);
+ if (maybe_env_var.ok()) {
+ filenames->emplace_back(std::move(*maybe_env_var));
+ }
+}
+
+void AppendEnvVarFilename(const char* var_name, const char* suffix,
+ std::vector<PlatformFilename>* filenames) {
+ auto maybe_env_var = GetEnvVarNative(var_name);
+ if (maybe_env_var.ok()) {
+ auto maybe_env_var_with_suffix =
+ PlatformFilename(std::move(*maybe_env_var)).Join(suffix);
+ if (maybe_env_var_with_suffix.ok()) {
+ filenames->emplace_back(std::move(*maybe_env_var_with_suffix));
+ }
+ }
+}
+
+void InsertEnvVarFilename(const char* var_name,
+ std::vector<PlatformFilename>* filenames) {
+ auto maybe_env_var = GetEnvVarNative(var_name);
+ if (maybe_env_var.ok()) {
+ filenames->emplace(filenames->begin(), PlatformFilename(std::move(*maybe_env_var)));
+ }
+}
+
+Result<std::vector<PlatformFilename>> get_potential_libhdfs_paths() {
+ std::vector<PlatformFilename> potential_paths;
+ std::string file_name;
+
+// OS-specific file name
+#ifdef _WIN32
+ file_name = "hdfs.dll";
+#elif __APPLE__
+ file_name = "libhdfs.dylib";
+#else
+ file_name = "libhdfs.so";
+#endif
+
+ // Common paths
+ ARROW_ASSIGN_OR_RAISE(auto search_paths, MakeFilenameVector({"", "."}));
+
+ // Path from environment variable
+ AppendEnvVarFilename("HADOOP_HOME", "lib/native", &search_paths);
+ AppendEnvVarFilename("ARROW_LIBHDFS_DIR", &search_paths);
+
+ // All paths with file name
+ for (const auto& path : search_paths) {
+ ARROW_ASSIGN_OR_RAISE(auto full_path, path.Join(file_name));
+ potential_paths.push_back(std::move(full_path));
+ }
+
+ return potential_paths;
+}
+
+Result<std::vector<PlatformFilename>> get_potential_libjvm_paths() {
+ std::vector<PlatformFilename> potential_paths;
+
+ std::vector<PlatformFilename> search_prefixes;
+ std::vector<PlatformFilename> search_suffixes;
+ std::string file_name;
+
+// From heuristics
+#ifdef _WIN32
+ ARROW_ASSIGN_OR_RAISE(search_prefixes, MakeFilenameVector({""}));
+ ARROW_ASSIGN_OR_RAISE(search_suffixes,
+ MakeFilenameVector({"/jre/bin/server", "/bin/server"}));
+ file_name = "jvm.dll";
+#elif __APPLE__
+ ARROW_ASSIGN_OR_RAISE(search_prefixes, MakeFilenameVector({""}));
+ ARROW_ASSIGN_OR_RAISE(search_suffixes,
+ MakeFilenameVector({"/jre/lib/server", "/lib/server"}));
+ file_name = "libjvm.dylib";
+
+// SFrame uses /usr/libexec/java_home to find JAVA_HOME; for now we are
+// expecting users to set an environment variable
+#else
+#if defined(__aarch64__)
+ const std::string prefix_arch{"arm64"};
+ const std::string suffix_arch{"aarch64"};
+#else
+ const std::string prefix_arch{"amd64"};
+ const std::string suffix_arch{"amd64"};
+#endif
+ ARROW_ASSIGN_OR_RAISE(
+ search_prefixes,
+ MakeFilenameVector({
+ "/usr/lib/jvm/default-java", // ubuntu / debian distros
+ "/usr/lib/jvm/java", // rhel6
+ "/usr/lib/jvm", // centos6
+ "/usr/lib64/jvm", // opensuse 13
+ "/usr/local/lib/jvm/default-java", // alt ubuntu / debian distros
+ "/usr/local/lib/jvm/java", // alt rhel6
+ "/usr/local/lib/jvm", // alt centos6
+ "/usr/local/lib64/jvm", // alt opensuse 13
+ "/usr/local/lib/jvm/java-8-openjdk-" +
+ prefix_arch, // alt ubuntu / debian distros
+ "/usr/lib/jvm/java-8-openjdk-" + prefix_arch, // alt ubuntu / debian distros
+ "/usr/local/lib/jvm/java-7-openjdk-" +
+ prefix_arch, // alt ubuntu / debian distros
+ "/usr/lib/jvm/java-7-openjdk-" + prefix_arch, // alt ubuntu / debian distros
+ "/usr/local/lib/jvm/java-6-openjdk-" +
+ prefix_arch, // alt ubuntu / debian distros
+ "/usr/lib/jvm/java-6-openjdk-" + prefix_arch, // alt ubuntu / debian distros
+ "/usr/lib/jvm/java-7-oracle", // alt ubuntu
+ "/usr/lib/jvm/java-8-oracle", // alt ubuntu
+ "/usr/lib/jvm/java-6-oracle", // alt ubuntu
+ "/usr/local/lib/jvm/java-7-oracle", // alt ubuntu
+ "/usr/local/lib/jvm/java-8-oracle", // alt ubuntu
+ "/usr/local/lib/jvm/java-6-oracle", // alt ubuntu
+ "/usr/lib/jvm/default", // alt centos
+ "/usr/java/latest" // alt centos
+ }));
+ ARROW_ASSIGN_OR_RAISE(
+ search_suffixes,
+ MakeFilenameVector({"", "/lib/server", "/jre/lib/" + suffix_arch + "/server",
+ "/lib/" + suffix_arch + "/server"}));
+ file_name = "libjvm.so";
+#endif
+
+ // From direct environment variable
+ InsertEnvVarFilename("JAVA_HOME", &search_prefixes);
+
+ // Generate cross product between search_prefixes, search_suffixes, and file_name
+ for (auto& prefix : search_prefixes) {
+ for (auto& suffix : search_suffixes) {
+ ARROW_ASSIGN_OR_RAISE(auto path, prefix.Join(suffix).Join(file_name));
+ potential_paths.push_back(std::move(path));
+ }
+ }
+
+ return potential_paths;
+}
+
+#ifndef _WIN32
+Result<LibraryHandle> try_dlopen(const std::vector<PlatformFilename>& potential_paths,
+ const char* name) {
+ std::string error_message = "unknown error";
+ LibraryHandle handle;
+
+ for (const auto& p : potential_paths) {
+ handle = dlopen(p.ToNative().c_str(), RTLD_NOW | RTLD_LOCAL);
+
+ if (handle != NULL) {
+ return handle;
+ } else {
+ const char* err_msg = dlerror();
+ if (err_msg != NULL) {
+ error_message = err_msg;
+ }
+ }
+ }
+
+ return Status::IOError("Unable to load ", name, ": ", error_message);
+}
+
+#else
+Result<LibraryHandle> try_dlopen(const std::vector<PlatformFilename>& potential_paths,
+ const char* name) {
+ std::string error_message;
+ LibraryHandle handle;
+
+ for (const auto& p : potential_paths) {
+ handle = LoadLibraryW(p.ToNative().c_str());
+ if (handle != NULL) {
+ return handle;
+ } else {
+ error_message = WinErrorMessage(GetLastError());
+ }
+ }
+
+ return Status::IOError("Unable to load ", name, ": ", error_message);
+}
+#endif // _WIN32
+
+LibHdfsShim libhdfs_shim;
+
+} // namespace
+
+Status LibHdfsShim::GetRequiredSymbols() {
+ GET_SYMBOL_REQUIRED(this, hdfsNewBuilder);
+ GET_SYMBOL_REQUIRED(this, hdfsBuilderSetNameNode);
+ GET_SYMBOL_REQUIRED(this, hdfsBuilderSetNameNodePort);
+ GET_SYMBOL_REQUIRED(this, hdfsBuilderSetUserName);
+ GET_SYMBOL_REQUIRED(this, hdfsBuilderSetKerbTicketCachePath);
+ GET_SYMBOL_REQUIRED(this, hdfsBuilderSetForceNewInstance);
+ GET_SYMBOL_REQUIRED(this, hdfsBuilderConfSetStr);
+ GET_SYMBOL_REQUIRED(this, hdfsBuilderConnect);
+ GET_SYMBOL_REQUIRED(this, hdfsCreateDirectory);
+ GET_SYMBOL_REQUIRED(this, hdfsDelete);
+ GET_SYMBOL_REQUIRED(this, hdfsDisconnect);
+ GET_SYMBOL_REQUIRED(this, hdfsExists);
+ GET_SYMBOL_REQUIRED(this, hdfsFreeFileInfo);
+ GET_SYMBOL_REQUIRED(this, hdfsGetCapacity);
+ GET_SYMBOL_REQUIRED(this, hdfsGetUsed);
+ GET_SYMBOL_REQUIRED(this, hdfsGetPathInfo);
+ GET_SYMBOL_REQUIRED(this, hdfsListDirectory);
+ GET_SYMBOL_REQUIRED(this, hdfsChown);
+ GET_SYMBOL_REQUIRED(this, hdfsChmod);
+
+ // File methods
+ GET_SYMBOL_REQUIRED(this, hdfsCloseFile);
+ GET_SYMBOL_REQUIRED(this, hdfsFlush);
+ GET_SYMBOL_REQUIRED(this, hdfsOpenFile);
+ GET_SYMBOL_REQUIRED(this, hdfsRead);
+ GET_SYMBOL_REQUIRED(this, hdfsSeek);
+ GET_SYMBOL_REQUIRED(this, hdfsTell);
+ GET_SYMBOL_REQUIRED(this, hdfsWrite);
+
+ return Status::OK();
+}
+
+Status ConnectLibHdfs(LibHdfsShim** driver) {
+ static std::mutex lock;
+ std::lock_guard<std::mutex> guard(lock);
+
+ LibHdfsShim* shim = &libhdfs_shim;
+
+ static bool shim_attempted = false;
+ if (!shim_attempted) {
+ shim_attempted = true;
+
+ shim->Initialize();
+
+ ARROW_ASSIGN_OR_RAISE(auto libjvm_potential_paths, get_potential_libjvm_paths());
+ ARROW_ASSIGN_OR_RAISE(libjvm_handle, try_dlopen(libjvm_potential_paths, "libjvm"));
+
+ ARROW_ASSIGN_OR_RAISE(auto libhdfs_potential_paths, get_potential_libhdfs_paths());
+ ARROW_ASSIGN_OR_RAISE(shim->handle, try_dlopen(libhdfs_potential_paths, "libhdfs"));
+ } else if (shim->handle == nullptr) {
+ return Status::IOError("Prior attempt to load libhdfs failed");
+ }
+
+ *driver = shim;
+ return shim->GetRequiredSymbols();
+}
+
+///////////////////////////////////////////////////////////////////////////
+// HDFS thin wrapper methods
+
+hdfsBuilder* LibHdfsShim::NewBuilder(void) { return this->hdfsNewBuilder(); }
+
+void LibHdfsShim::BuilderSetNameNode(hdfsBuilder* bld, const char* nn) {
+ this->hdfsBuilderSetNameNode(bld, nn);
+}
+
+void LibHdfsShim::BuilderSetNameNodePort(hdfsBuilder* bld, tPort port) {
+ this->hdfsBuilderSetNameNodePort(bld, port);
+}
+
+void LibHdfsShim::BuilderSetUserName(hdfsBuilder* bld, const char* userName) {
+ this->hdfsBuilderSetUserName(bld, userName);
+}
+
+void LibHdfsShim::BuilderSetKerbTicketCachePath(hdfsBuilder* bld,
+ const char* kerbTicketCachePath) {
+ this->hdfsBuilderSetKerbTicketCachePath(bld, kerbTicketCachePath);
+}
+
+void LibHdfsShim::BuilderSetForceNewInstance(hdfsBuilder* bld) {
+ this->hdfsBuilderSetForceNewInstance(bld);
+}
+
+hdfsFS LibHdfsShim::BuilderConnect(hdfsBuilder* bld) {
+ return this->hdfsBuilderConnect(bld);
+}
+
+int LibHdfsShim::BuilderConfSetStr(hdfsBuilder* bld, const char* key, const char* val) {
+ return this->hdfsBuilderConfSetStr(bld, key, val);
+}
+
+int LibHdfsShim::Disconnect(hdfsFS fs) { return this->hdfsDisconnect(fs); }
+
+hdfsFile LibHdfsShim::OpenFile(hdfsFS fs, const char* path, int flags, int bufferSize,
+ short replication, tSize blocksize) { // NOLINT
+ return this->hdfsOpenFile(fs, path, flags, bufferSize, replication, blocksize);
+}
+
+int LibHdfsShim::CloseFile(hdfsFS fs, hdfsFile file) {
+ return this->hdfsCloseFile(fs, file);
+}
+
+int LibHdfsShim::Exists(hdfsFS fs, const char* path) {
+ return this->hdfsExists(fs, path);
+}
+
+int LibHdfsShim::Seek(hdfsFS fs, hdfsFile file, tOffset desiredPos) {
+ return this->hdfsSeek(fs, file, desiredPos);
+}
+
+tOffset LibHdfsShim::Tell(hdfsFS fs, hdfsFile file) { return this->hdfsTell(fs, file); }
+
+tSize LibHdfsShim::Read(hdfsFS fs, hdfsFile file, void* buffer, tSize length) {
+ return this->hdfsRead(fs, file, buffer, length);
+}
+
+bool LibHdfsShim::HasPread() {
+ GET_SYMBOL(this, hdfsPread);
+ return this->hdfsPread != nullptr;
+}
+
+tSize LibHdfsShim::Pread(hdfsFS fs, hdfsFile file, tOffset position, void* buffer,
+ tSize length) {
+ GET_SYMBOL(this, hdfsPread);
+ DCHECK(this->hdfsPread);
+ return this->hdfsPread(fs, file, position, buffer, length);
+}
+
+tSize LibHdfsShim::Write(hdfsFS fs, hdfsFile file, const void* buffer, tSize length) {
+ return this->hdfsWrite(fs, file, buffer, length);
+}
+
+int LibHdfsShim::Flush(hdfsFS fs, hdfsFile file) { return this->hdfsFlush(fs, file); }
+
+int LibHdfsShim::Available(hdfsFS fs, hdfsFile file) {
+ GET_SYMBOL(this, hdfsAvailable);
+ if (this->hdfsAvailable)
+ return this->hdfsAvailable(fs, file);
+ else
+ return 0;
+}
+
+int LibHdfsShim::Copy(hdfsFS srcFS, const char* src, hdfsFS dstFS, const char* dst) {
+ GET_SYMBOL(this, hdfsCopy);
+ if (this->hdfsCopy)
+ return this->hdfsCopy(srcFS, src, dstFS, dst);
+ else
+ return 0;
+}
+
+int LibHdfsShim::Move(hdfsFS srcFS, const char* src, hdfsFS dstFS, const char* dst) {
+ GET_SYMBOL(this, hdfsMove);
+ if (this->hdfsMove)
+ return this->hdfsMove(srcFS, src, dstFS, dst);
+ else
+ return 0;
+}
+
+int LibHdfsShim::Delete(hdfsFS fs, const char* path, int recursive) {
+ return this->hdfsDelete(fs, path, recursive);
+}
+
+int LibHdfsShim::Rename(hdfsFS fs, const char* oldPath, const char* newPath) {
+ GET_SYMBOL(this, hdfsRename);
+ if (this->hdfsRename)
+ return this->hdfsRename(fs, oldPath, newPath);
+ else
+ return 0;
+}
+
+char* LibHdfsShim::GetWorkingDirectory(hdfsFS fs, char* buffer, size_t bufferSize) {
+ GET_SYMBOL(this, hdfsGetWorkingDirectory);
+ if (this->hdfsGetWorkingDirectory) {
+ return this->hdfsGetWorkingDirectory(fs, buffer, bufferSize);
+ } else {
+ return NULL;
+ }
+}
+
+int LibHdfsShim::SetWorkingDirectory(hdfsFS fs, const char* path) {
+ GET_SYMBOL(this, hdfsSetWorkingDirectory);
+ if (this->hdfsSetWorkingDirectory) {
+ return this->hdfsSetWorkingDirectory(fs, path);
+ } else {
+ return 0;
+ }
+}
+
+int LibHdfsShim::MakeDirectory(hdfsFS fs, const char* path) {
+ return this->hdfsCreateDirectory(fs, path);
+}
+
+int LibHdfsShim::SetReplication(hdfsFS fs, const char* path, int16_t replication) {
+ GET_SYMBOL(this, hdfsSetReplication);
+ if (this->hdfsSetReplication) {
+ return this->hdfsSetReplication(fs, path, replication);
+ } else {
+ return 0;
+ }
+}
+
+hdfsFileInfo* LibHdfsShim::ListDirectory(hdfsFS fs, const char* path, int* numEntries) {
+ return this->hdfsListDirectory(fs, path, numEntries);
+}
+
+hdfsFileInfo* LibHdfsShim::GetPathInfo(hdfsFS fs, const char* path) {
+ return this->hdfsGetPathInfo(fs, path);
+}
+
+void LibHdfsShim::FreeFileInfo(hdfsFileInfo* hdfsFileInfo, int numEntries) {
+ this->hdfsFreeFileInfo(hdfsFileInfo, numEntries);
+}
+
+char*** LibHdfsShim::GetHosts(hdfsFS fs, const char* path, tOffset start,
+ tOffset length) {
+ GET_SYMBOL(this, hdfsGetHosts);
+ if (this->hdfsGetHosts) {
+ return this->hdfsGetHosts(fs, path, start, length);
+ } else {
+ return NULL;
+ }
+}
+
+void LibHdfsShim::FreeHosts(char*** blockHosts) {
+ GET_SYMBOL(this, hdfsFreeHosts);
+ if (this->hdfsFreeHosts) {
+ this->hdfsFreeHosts(blockHosts);
+ }
+}
+
+tOffset LibHdfsShim::GetDefaultBlockSize(hdfsFS fs) {
+ GET_SYMBOL(this, hdfsGetDefaultBlockSize);
+ if (this->hdfsGetDefaultBlockSize) {
+ return this->hdfsGetDefaultBlockSize(fs);
+ } else {
+ return 0;
+ }
+}
+
+tOffset LibHdfsShim::GetCapacity(hdfsFS fs) { return this->hdfsGetCapacity(fs); }
+
+tOffset LibHdfsShim::GetUsed(hdfsFS fs) { return this->hdfsGetUsed(fs); }
+
+int LibHdfsShim::Chown(hdfsFS fs, const char* path, const char* owner,
+ const char* group) {
+ return this->hdfsChown(fs, path, owner, group);
+}
+
+int LibHdfsShim::Chmod(hdfsFS fs, const char* path, short mode) { // NOLINT
+ return this->hdfsChmod(fs, path, mode);
+}
+
+int LibHdfsShim::Utime(hdfsFS fs, const char* path, tTime mtime, tTime atime) {
+ GET_SYMBOL(this, hdfsUtime);
+ if (this->hdfsUtime) {
+ return this->hdfsUtime(fs, path, mtime, atime);
+ } else {
+ return 0;
+ }
+}
+
+} // namespace internal
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/hdfs_internal.h b/src/arrow/cpp/src/arrow/io/hdfs_internal.h
new file mode 100644
index 000000000..624938231
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/hdfs_internal.h
@@ -0,0 +1,222 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+#include <hdfs.h>
+
+#include "arrow/util/visibility.h"
+#include "arrow/util/windows_compatibility.h" // IWYU pragma: keep
+
+using std::size_t;
+
+struct hdfsBuilder;
+
+namespace arrow {
+
+class Status;
+
+namespace io {
+namespace internal {
+
+#ifndef _WIN32
+typedef void* LibraryHandle;
+#else
+typedef HINSTANCE LibraryHandle;
+#endif
+
+// NOTE(wesm): cpplint does not like use of short and other imprecise C types
+struct LibHdfsShim {
+ LibraryHandle handle;
+
+ hdfsBuilder* (*hdfsNewBuilder)(void);
+ void (*hdfsBuilderSetNameNode)(hdfsBuilder* bld, const char* nn);
+ void (*hdfsBuilderSetNameNodePort)(hdfsBuilder* bld, tPort port);
+ void (*hdfsBuilderSetUserName)(hdfsBuilder* bld, const char* userName);
+ void (*hdfsBuilderSetKerbTicketCachePath)(hdfsBuilder* bld,
+ const char* kerbTicketCachePath);
+ void (*hdfsBuilderSetForceNewInstance)(hdfsBuilder* bld);
+ hdfsFS (*hdfsBuilderConnect)(hdfsBuilder* bld);
+ int (*hdfsBuilderConfSetStr)(hdfsBuilder* bld, const char* key, const char* val);
+
+ int (*hdfsDisconnect)(hdfsFS fs);
+
+ hdfsFile (*hdfsOpenFile)(hdfsFS fs, const char* path, int flags, int bufferSize,
+ short replication, tSize blocksize); // NOLINT
+
+ int (*hdfsCloseFile)(hdfsFS fs, hdfsFile file);
+ int (*hdfsExists)(hdfsFS fs, const char* path);
+ int (*hdfsSeek)(hdfsFS fs, hdfsFile file, tOffset desiredPos);
+ tOffset (*hdfsTell)(hdfsFS fs, hdfsFile file);
+ tSize (*hdfsRead)(hdfsFS fs, hdfsFile file, void* buffer, tSize length);
+ tSize (*hdfsPread)(hdfsFS fs, hdfsFile file, tOffset position, void* buffer,
+ tSize length);
+ tSize (*hdfsWrite)(hdfsFS fs, hdfsFile file, const void* buffer, tSize length);
+ int (*hdfsFlush)(hdfsFS fs, hdfsFile file);
+ int (*hdfsAvailable)(hdfsFS fs, hdfsFile file);
+ int (*hdfsCopy)(hdfsFS srcFS, const char* src, hdfsFS dstFS, const char* dst);
+ int (*hdfsMove)(hdfsFS srcFS, const char* src, hdfsFS dstFS, const char* dst);
+ int (*hdfsDelete)(hdfsFS fs, const char* path, int recursive);
+ int (*hdfsRename)(hdfsFS fs, const char* oldPath, const char* newPath);
+ char* (*hdfsGetWorkingDirectory)(hdfsFS fs, char* buffer, size_t bufferSize);
+ int (*hdfsSetWorkingDirectory)(hdfsFS fs, const char* path);
+ int (*hdfsCreateDirectory)(hdfsFS fs, const char* path);
+ int (*hdfsSetReplication)(hdfsFS fs, const char* path, int16_t replication);
+ hdfsFileInfo* (*hdfsListDirectory)(hdfsFS fs, const char* path, int* numEntries);
+ hdfsFileInfo* (*hdfsGetPathInfo)(hdfsFS fs, const char* path);
+ void (*hdfsFreeFileInfo)(hdfsFileInfo* hdfsFileInfo, int numEntries);
+ char*** (*hdfsGetHosts)(hdfsFS fs, const char* path, tOffset start, tOffset length);
+ void (*hdfsFreeHosts)(char*** blockHosts);
+ tOffset (*hdfsGetDefaultBlockSize)(hdfsFS fs);
+ tOffset (*hdfsGetCapacity)(hdfsFS fs);
+ tOffset (*hdfsGetUsed)(hdfsFS fs);
+ int (*hdfsChown)(hdfsFS fs, const char* path, const char* owner, const char* group);
+ int (*hdfsChmod)(hdfsFS fs, const char* path, short mode); // NOLINT
+ int (*hdfsUtime)(hdfsFS fs, const char* path, tTime mtime, tTime atime);
+
+ void Initialize() {
+ this->handle = nullptr;
+ this->hdfsNewBuilder = nullptr;
+ this->hdfsBuilderSetNameNode = nullptr;
+ this->hdfsBuilderSetNameNodePort = nullptr;
+ this->hdfsBuilderSetUserName = nullptr;
+ this->hdfsBuilderSetKerbTicketCachePath = nullptr;
+ this->hdfsBuilderSetForceNewInstance = nullptr;
+ this->hdfsBuilderConfSetStr = nullptr;
+ this->hdfsBuilderConnect = nullptr;
+ this->hdfsDisconnect = nullptr;
+ this->hdfsOpenFile = nullptr;
+ this->hdfsCloseFile = nullptr;
+ this->hdfsExists = nullptr;
+ this->hdfsSeek = nullptr;
+ this->hdfsTell = nullptr;
+ this->hdfsRead = nullptr;
+ this->hdfsPread = nullptr;
+ this->hdfsWrite = nullptr;
+ this->hdfsFlush = nullptr;
+ this->hdfsAvailable = nullptr;
+ this->hdfsCopy = nullptr;
+ this->hdfsMove = nullptr;
+ this->hdfsDelete = nullptr;
+ this->hdfsRename = nullptr;
+ this->hdfsGetWorkingDirectory = nullptr;
+ this->hdfsSetWorkingDirectory = nullptr;
+ this->hdfsCreateDirectory = nullptr;
+ this->hdfsSetReplication = nullptr;
+ this->hdfsListDirectory = nullptr;
+ this->hdfsGetPathInfo = nullptr;
+ this->hdfsFreeFileInfo = nullptr;
+ this->hdfsGetHosts = nullptr;
+ this->hdfsFreeHosts = nullptr;
+ this->hdfsGetDefaultBlockSize = nullptr;
+ this->hdfsGetCapacity = nullptr;
+ this->hdfsGetUsed = nullptr;
+ this->hdfsChown = nullptr;
+ this->hdfsChmod = nullptr;
+ this->hdfsUtime = nullptr;
+ }
+
+ hdfsBuilder* NewBuilder(void);
+
+ void BuilderSetNameNode(hdfsBuilder* bld, const char* nn);
+
+ void BuilderSetNameNodePort(hdfsBuilder* bld, tPort port);
+
+ void BuilderSetUserName(hdfsBuilder* bld, const char* userName);
+
+ void BuilderSetKerbTicketCachePath(hdfsBuilder* bld, const char* kerbTicketCachePath);
+
+ void BuilderSetForceNewInstance(hdfsBuilder* bld);
+
+ int BuilderConfSetStr(hdfsBuilder* bld, const char* key, const char* val);
+
+ hdfsFS BuilderConnect(hdfsBuilder* bld);
+
+ int Disconnect(hdfsFS fs);
+
+ hdfsFile OpenFile(hdfsFS fs, const char* path, int flags, int bufferSize,
+ short replication, tSize blocksize); // NOLINT
+
+ int CloseFile(hdfsFS fs, hdfsFile file);
+
+ int Exists(hdfsFS fs, const char* path);
+
+ int Seek(hdfsFS fs, hdfsFile file, tOffset desiredPos);
+
+ tOffset Tell(hdfsFS fs, hdfsFile file);
+
+ tSize Read(hdfsFS fs, hdfsFile file, void* buffer, tSize length);
+
+ bool HasPread();
+
+ tSize Pread(hdfsFS fs, hdfsFile file, tOffset position, void* buffer, tSize length);
+
+ tSize Write(hdfsFS fs, hdfsFile file, const void* buffer, tSize length);
+
+ int Flush(hdfsFS fs, hdfsFile file);
+
+ int Available(hdfsFS fs, hdfsFile file);
+
+ int Copy(hdfsFS srcFS, const char* src, hdfsFS dstFS, const char* dst);
+
+ int Move(hdfsFS srcFS, const char* src, hdfsFS dstFS, const char* dst);
+
+ int Delete(hdfsFS fs, const char* path, int recursive);
+
+ int Rename(hdfsFS fs, const char* oldPath, const char* newPath);
+
+ char* GetWorkingDirectory(hdfsFS fs, char* buffer, size_t bufferSize);
+
+ int SetWorkingDirectory(hdfsFS fs, const char* path);
+
+ int MakeDirectory(hdfsFS fs, const char* path);
+
+ int SetReplication(hdfsFS fs, const char* path, int16_t replication);
+
+ hdfsFileInfo* ListDirectory(hdfsFS fs, const char* path, int* numEntries);
+
+ hdfsFileInfo* GetPathInfo(hdfsFS fs, const char* path);
+
+ void FreeFileInfo(hdfsFileInfo* hdfsFileInfo, int numEntries);
+
+ char*** GetHosts(hdfsFS fs, const char* path, tOffset start, tOffset length);
+
+ void FreeHosts(char*** blockHosts);
+
+ tOffset GetDefaultBlockSize(hdfsFS fs);
+ tOffset GetCapacity(hdfsFS fs);
+
+ tOffset GetUsed(hdfsFS fs);
+
+ int Chown(hdfsFS fs, const char* path, const char* owner, const char* group);
+
+ int Chmod(hdfsFS fs, const char* path, short mode); // NOLINT
+
+ int Utime(hdfsFS fs, const char* path, tTime mtime, tTime atime);
+
+ Status GetRequiredSymbols();
+};
+
+// TODO(wesm): Remove these exports when we are linking statically
+Status ARROW_EXPORT ConnectLibHdfs(LibHdfsShim** driver);
+
+} // namespace internal
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/hdfs_test.cc b/src/arrow/cpp/src/arrow/io/hdfs_test.cc
new file mode 100644
index 000000000..2ebf95080
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/hdfs_test.cc
@@ -0,0 +1,464 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <atomic>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <memory>
+#include <sstream> // IWYU pragma: keep
+#include <string>
+#include <thread>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include <boost/filesystem.hpp> // NOLINT
+
+#include "arrow/buffer.h"
+#include "arrow/io/hdfs.h"
+#include "arrow/io/hdfs_internal.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+
+namespace arrow {
+namespace io {
+
+std::vector<uint8_t> RandomData(int64_t size) {
+ std::vector<uint8_t> buffer(size);
+ random_bytes(size, 0, buffer.data());
+ return buffer;
+}
+
+class TestHadoopFileSystem : public ::testing::Test {
+ public:
+ Status MakeScratchDir() {
+ if (client_->Exists(scratch_dir_)) {
+ RETURN_NOT_OK((client_->Delete(scratch_dir_, true)));
+ }
+ return client_->MakeDirectory(scratch_dir_);
+ }
+
+ Status WriteDummyFile(const std::string& path, const uint8_t* buffer, int64_t size,
+ bool append = false, int buffer_size = 0, int16_t replication = 0,
+ int default_block_size = 0) {
+ std::shared_ptr<HdfsOutputStream> file;
+ RETURN_NOT_OK(client_->OpenWritable(path, append, buffer_size, replication,
+ default_block_size, &file));
+
+ RETURN_NOT_OK(file->Write(buffer, size));
+ RETURN_NOT_OK(file->Close());
+
+ return Status::OK();
+ }
+
+ std::string ScratchPath(const std::string& name) {
+ std::stringstream ss;
+ ss << scratch_dir_ << "/" << name;
+ return ss.str();
+ }
+
+ std::string HdfsAbsPath(const std::string& relpath) {
+ std::stringstream ss;
+ ss << "hdfs://" << conf_.host << ":" << conf_.port << relpath;
+ return ss.str();
+ }
+
+ // Set up shared state between unit tests
+ void SetUp() {
+ internal::LibHdfsShim* driver_shim;
+
+ client_ = nullptr;
+ scratch_dir_ =
+ boost::filesystem::unique_path(boost::filesystem::temp_directory_path() /
+ "arrow-hdfs/scratch-%%%%")
+ .string();
+
+ loaded_driver_ = false;
+
+ Status msg = ConnectLibHdfs(&driver_shim);
+ if (!msg.ok()) {
+ if (std::getenv("ARROW_HDFS_TEST_LIBHDFS_REQUIRE")) {
+ FAIL() << "Loading libhdfs failed: " << msg.ToString();
+ } else {
+ std::cout << "Loading libhdfs failed, skipping tests gracefully: "
+ << msg.ToString() << std::endl;
+ }
+ return;
+ }
+
+ loaded_driver_ = true;
+
+ const char* host = std::getenv("ARROW_HDFS_TEST_HOST");
+ const char* port = std::getenv("ARROW_HDFS_TEST_PORT");
+ const char* user = std::getenv("ARROW_HDFS_TEST_USER");
+
+ ASSERT_TRUE(user != nullptr) << "Set ARROW_HDFS_TEST_USER";
+
+ conf_.host = host == nullptr ? "localhost" : host;
+ conf_.user = user;
+ conf_.port = port == nullptr ? 20500 : atoi(port);
+
+ ASSERT_OK(HadoopFileSystem::Connect(&conf_, &client_));
+ }
+
+ void TearDown() {
+ if (client_) {
+ if (client_->Exists(scratch_dir_)) {
+ ARROW_EXPECT_OK(client_->Delete(scratch_dir_, true));
+ }
+ ARROW_EXPECT_OK(client_->Disconnect());
+ }
+ }
+
+ HdfsConnectionConfig conf_;
+ bool loaded_driver_;
+
+ // Resources shared amongst unit tests
+ std::string scratch_dir_;
+ std::shared_ptr<HadoopFileSystem> client_;
+};
+
+#define SKIP_IF_NO_DRIVER() \
+ if (!this->loaded_driver_) { \
+ GTEST_SKIP() << "Driver not loaded, skipping"; \
+ }
+
+TEST_F(TestHadoopFileSystem, ConnectsAgain) {
+ SKIP_IF_NO_DRIVER();
+
+ std::shared_ptr<HadoopFileSystem> client;
+ ASSERT_OK(HadoopFileSystem::Connect(&this->conf_, &client));
+ ASSERT_OK(client->Disconnect());
+}
+
+TEST_F(TestHadoopFileSystem, MultipleClients) {
+ SKIP_IF_NO_DRIVER();
+
+ ASSERT_OK(this->MakeScratchDir());
+
+ std::shared_ptr<HadoopFileSystem> client1;
+ std::shared_ptr<HadoopFileSystem> client2;
+ ASSERT_OK(HadoopFileSystem::Connect(&this->conf_, &client1));
+ ASSERT_OK(HadoopFileSystem::Connect(&this->conf_, &client2));
+ ASSERT_OK(client1->Disconnect());
+
+ // client2 continues to function after equivalent client1 has shutdown
+ std::vector<HdfsPathInfo> listing;
+ ASSERT_OK(client2->ListDirectory(this->scratch_dir_, &listing));
+ ASSERT_OK(client2->Disconnect());
+}
+
+TEST_F(TestHadoopFileSystem, MakeDirectory) {
+ SKIP_IF_NO_DRIVER();
+
+ std::string path = this->ScratchPath("create-directory");
+
+ if (this->client_->Exists(path)) {
+ ASSERT_OK(this->client_->Delete(path, true));
+ }
+
+ ASSERT_OK(this->client_->MakeDirectory(path));
+ ASSERT_TRUE(this->client_->Exists(path));
+ std::vector<HdfsPathInfo> listing;
+ ARROW_EXPECT_OK(this->client_->ListDirectory(path, &listing));
+ ASSERT_EQ(0, listing.size());
+ ARROW_EXPECT_OK(this->client_->Delete(path, true));
+ ASSERT_FALSE(this->client_->Exists(path));
+ ASSERT_RAISES(IOError, this->client_->ListDirectory(path, &listing));
+}
+
+TEST_F(TestHadoopFileSystem, GetCapacityUsed) {
+ SKIP_IF_NO_DRIVER();
+
+ // Who knows what is actually in your DFS cluster, but expect it to have
+ // positive used bytes and capacity
+ int64_t nbytes = 0;
+ ASSERT_OK(this->client_->GetCapacity(&nbytes));
+ ASSERT_LT(0, nbytes);
+
+ ASSERT_OK(this->client_->GetUsed(&nbytes));
+ ASSERT_LT(0, nbytes);
+}
+
+TEST_F(TestHadoopFileSystem, GetPathInfo) {
+ SKIP_IF_NO_DRIVER();
+
+ HdfsPathInfo info;
+
+ ASSERT_OK(this->MakeScratchDir());
+
+ // Directory info
+ ASSERT_OK(this->client_->GetPathInfo(this->scratch_dir_, &info));
+ ASSERT_EQ(ObjectType::DIRECTORY, info.kind);
+ ASSERT_EQ(this->HdfsAbsPath(this->scratch_dir_), info.name);
+ ASSERT_EQ(this->conf_.user, info.owner);
+
+ // TODO(wesm): test group, other attrs
+
+ auto path = this->ScratchPath("test-file");
+
+ const int size = 100;
+
+ std::vector<uint8_t> buffer = RandomData(size);
+
+ ASSERT_OK(this->WriteDummyFile(path, buffer.data(), size));
+ ASSERT_OK(this->client_->GetPathInfo(path, &info));
+
+ ASSERT_EQ(ObjectType::FILE, info.kind);
+ ASSERT_EQ(this->HdfsAbsPath(path), info.name);
+ ASSERT_EQ(this->conf_.user, info.owner);
+ ASSERT_EQ(size, info.size);
+}
+
+TEST_F(TestHadoopFileSystem, GetPathInfoNotExist) {
+ // ARROW-2919: Test that the error message is reasonable
+ SKIP_IF_NO_DRIVER();
+
+ ASSERT_OK(this->MakeScratchDir());
+ auto path = this->ScratchPath("path-does-not-exist");
+
+ HdfsPathInfo info;
+ Status s = this->client_->GetPathInfo(path, &info);
+ ASSERT_TRUE(s.IsIOError());
+
+ const std::string error_message = s.ToString();
+
+ // Check that the file path is found in the error message
+ ASSERT_LT(error_message.find(path), std::string::npos);
+}
+
+TEST_F(TestHadoopFileSystem, AppendToFile) {
+ SKIP_IF_NO_DRIVER();
+
+ ASSERT_OK(this->MakeScratchDir());
+
+ auto path = this->ScratchPath("test-file");
+ const int size = 100;
+
+ std::vector<uint8_t> buffer = RandomData(size);
+ ASSERT_OK(this->WriteDummyFile(path, buffer.data(), size));
+
+ // now append
+ ASSERT_OK(this->WriteDummyFile(path, buffer.data(), size, true));
+
+ HdfsPathInfo info;
+ ASSERT_OK(this->client_->GetPathInfo(path, &info));
+ ASSERT_EQ(size * 2, info.size);
+}
+
+TEST_F(TestHadoopFileSystem, ListDirectory) {
+ SKIP_IF_NO_DRIVER();
+
+ const int size = 100;
+ std::vector<uint8_t> data = RandomData(size);
+
+ auto p1 = this->ScratchPath("test-file-1");
+ auto p2 = this->ScratchPath("test-file-2");
+ auto d1 = this->ScratchPath("test-dir-1");
+
+ ASSERT_OK(this->MakeScratchDir());
+ ASSERT_OK(this->WriteDummyFile(p1, data.data(), size));
+ ASSERT_OK(this->WriteDummyFile(p2, data.data(), size / 2));
+ ASSERT_OK(this->client_->MakeDirectory(d1));
+
+ std::vector<HdfsPathInfo> listing;
+ ASSERT_OK(this->client_->ListDirectory(this->scratch_dir_, &listing));
+
+ // Do it again, appends!
+ ASSERT_OK(this->client_->ListDirectory(this->scratch_dir_, &listing));
+
+ ASSERT_EQ(6, static_cast<int>(listing.size()));
+
+ // Argh, well, shouldn't expect the listing to be in any particular order
+ for (size_t i = 0; i < listing.size(); ++i) {
+ const HdfsPathInfo& info = listing[i];
+ if (info.name == this->HdfsAbsPath(p1)) {
+ ASSERT_EQ(ObjectType::FILE, info.kind);
+ ASSERT_EQ(size, info.size);
+ } else if (info.name == this->HdfsAbsPath(p2)) {
+ ASSERT_EQ(ObjectType::FILE, info.kind);
+ ASSERT_EQ(size / 2, info.size);
+ } else if (info.name == this->HdfsAbsPath(d1)) {
+ ASSERT_EQ(ObjectType::DIRECTORY, info.kind);
+ } else {
+ FAIL() << "Unexpected path: " << info.name;
+ }
+ }
+}
+
+TEST_F(TestHadoopFileSystem, ReadableMethods) {
+ SKIP_IF_NO_DRIVER();
+
+ ASSERT_OK(this->MakeScratchDir());
+
+ auto path = this->ScratchPath("test-file");
+ const int size = 100;
+
+ std::vector<uint8_t> data = RandomData(size);
+ ASSERT_OK(this->WriteDummyFile(path, data.data(), size));
+
+ std::shared_ptr<HdfsReadableFile> file;
+ ASSERT_OK(this->client_->OpenReadable(path, &file));
+
+ // Test GetSize -- move this into its own unit test if ever needed
+ ASSERT_OK_AND_EQ(size, file->GetSize());
+
+ uint8_t buffer[50];
+
+ ASSERT_OK_AND_EQ(50, file->Read(50, buffer));
+ ASSERT_EQ(0, std::memcmp(buffer, data.data(), 50));
+
+ ASSERT_OK_AND_EQ(50, file->Read(50, buffer));
+ ASSERT_EQ(0, std::memcmp(buffer, data.data() + 50, 50));
+
+ // EOF
+ ASSERT_OK_AND_EQ(0, file->Read(1, buffer));
+
+ // ReadAt to EOF
+ ASSERT_OK_AND_EQ(40, file->ReadAt(60, 100, buffer));
+ ASSERT_EQ(0, std::memcmp(buffer, data.data() + 60, 40));
+
+ // Seek, Tell
+ ASSERT_OK(file->Seek(60));
+ ASSERT_OK_AND_EQ(60, file->Tell());
+}
+
+TEST_F(TestHadoopFileSystem, LargeFile) {
+ SKIP_IF_NO_DRIVER();
+
+ ASSERT_OK(this->MakeScratchDir());
+
+ auto path = this->ScratchPath("test-large-file");
+ const int size = 1000000;
+
+ std::vector<uint8_t> data = RandomData(size);
+ ASSERT_OK(this->WriteDummyFile(path, data.data(), size));
+
+ std::shared_ptr<HdfsReadableFile> file;
+ ASSERT_OK(this->client_->OpenReadable(path, &file));
+
+ ASSERT_FALSE(file->closed());
+
+ ASSERT_OK_AND_ASSIGN(auto buffer, AllocateBuffer(size));
+
+ ASSERT_OK_AND_EQ(size, file->Read(size, buffer->mutable_data()));
+ ASSERT_EQ(0, std::memcmp(buffer->data(), data.data(), size));
+
+ // explicit buffer size
+ std::shared_ptr<HdfsReadableFile> file2;
+ ASSERT_OK(this->client_->OpenReadable(path, 1 << 18, &file2));
+
+ ASSERT_OK_AND_ASSIGN(auto buffer2, AllocateBuffer(size));
+
+ ASSERT_OK_AND_EQ(size, file2->Read(size, buffer2->mutable_data()));
+ ASSERT_EQ(0, std::memcmp(buffer2->data(), data.data(), size));
+}
+
+TEST_F(TestHadoopFileSystem, RenameFile) {
+ SKIP_IF_NO_DRIVER();
+ ASSERT_OK(this->MakeScratchDir());
+
+ auto src_path = this->ScratchPath("src-file");
+ auto dst_path = this->ScratchPath("dst-file");
+ const int size = 100;
+
+ std::vector<uint8_t> data = RandomData(size);
+ ASSERT_OK(this->WriteDummyFile(src_path, data.data(), size));
+
+ ASSERT_OK(this->client_->Rename(src_path, dst_path));
+
+ ASSERT_FALSE(this->client_->Exists(src_path));
+ ASSERT_TRUE(this->client_->Exists(dst_path));
+}
+
+TEST_F(TestHadoopFileSystem, ChmodChown) {
+ SKIP_IF_NO_DRIVER();
+ ASSERT_OK(this->MakeScratchDir());
+
+ auto path = this->ScratchPath("path-to-chmod");
+
+ int16_t mode = 0755;
+ const int size = 100;
+
+ std::vector<uint8_t> data = RandomData(size);
+ ASSERT_OK(this->WriteDummyFile(path, data.data(), size));
+
+ HdfsPathInfo info;
+ ASSERT_OK(this->client_->Chmod(path, mode));
+ ASSERT_OK(this->client_->GetPathInfo(path, &info));
+ ASSERT_EQ(mode, info.permissions);
+
+ std::string owner = "hadoop";
+ std::string group = "hadoop";
+ ASSERT_OK(this->client_->Chown(path, owner.c_str(), group.c_str()));
+ ASSERT_OK(this->client_->GetPathInfo(path, &info));
+ ASSERT_EQ("hadoop", info.owner);
+ ASSERT_EQ("hadoop", info.group);
+}
+
+TEST_F(TestHadoopFileSystem, ThreadSafety) {
+ SKIP_IF_NO_DRIVER();
+ ASSERT_OK(this->MakeScratchDir());
+
+ auto src_path = this->ScratchPath("threadsafety");
+
+ std::string data = "foobar";
+ ASSERT_OK(this->WriteDummyFile(src_path, reinterpret_cast<const uint8_t*>(data.c_str()),
+ static_cast<int64_t>(data.size())));
+
+ std::shared_ptr<HdfsReadableFile> file;
+ ASSERT_OK(this->client_->OpenReadable(src_path, &file));
+
+ std::atomic<int> correct_count(0);
+ int niter = 1000;
+
+ auto ReadData = [&file, &correct_count, &data, &niter]() {
+ for (int i = 0; i < niter; ++i) {
+ std::shared_ptr<Buffer> buffer;
+ if (i % 2 == 0) {
+ ASSERT_OK_AND_ASSIGN(buffer, file->ReadAt(3, 3));
+ if (0 == memcmp(data.c_str() + 3, buffer->data(), 3)) {
+ correct_count += 1;
+ }
+ } else {
+ ASSERT_OK_AND_ASSIGN(buffer, file->ReadAt(0, 4));
+ if (0 == memcmp(data.c_str() + 0, buffer->data(), 4)) {
+ correct_count += 1;
+ }
+ }
+ }
+ };
+
+ std::thread thread1(ReadData);
+ std::thread thread2(ReadData);
+ std::thread thread3(ReadData);
+ std::thread thread4(ReadData);
+
+ thread1.join();
+ thread2.join();
+ thread3.join();
+ thread4.join();
+
+ ASSERT_EQ(niter * 4, correct_count);
+}
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/interfaces.cc b/src/arrow/cpp/src/arrow/io/interfaces.cc
new file mode 100644
index 000000000..954c0f37b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/interfaces.cc
@@ -0,0 +1,469 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/io/interfaces.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <iterator>
+#include <list>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <typeinfo>
+#include <utility>
+
+#include "arrow/buffer.h"
+#include "arrow/io/concurrency.h"
+#include "arrow/io/type_fwd.h"
+#include "arrow/io/util_internal.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_pointer_cast;
+using internal::Executor;
+using internal::TaskHints;
+using internal::ThreadPool;
+
+namespace io {
+
+static IOContext g_default_io_context{};
+
+IOContext::IOContext(MemoryPool* pool, StopToken stop_token)
+ : IOContext(pool, internal::GetIOThreadPool(), std::move(stop_token)) {}
+
+const IOContext& default_io_context() { return g_default_io_context; }
+
+int GetIOThreadPoolCapacity() { return internal::GetIOThreadPool()->GetCapacity(); }
+
+Status SetIOThreadPoolCapacity(int threads) {
+ return internal::GetIOThreadPool()->SetCapacity(threads);
+}
+
+FileInterface::~FileInterface() = default;
+
+Status FileInterface::Abort() { return Close(); }
+
+namespace {
+
+class InputStreamBlockIterator {
+ public:
+ InputStreamBlockIterator(std::shared_ptr<InputStream> stream, int64_t block_size)
+ : stream_(std::move(stream)), block_size_(block_size) {}
+
+ Result<std::shared_ptr<Buffer>> Next() {
+ if (done_) {
+ return nullptr;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto out, stream_->Read(block_size_));
+
+ if (out->size() == 0) {
+ done_ = true;
+ stream_.reset();
+ out.reset();
+ }
+
+ return out;
+ }
+
+ protected:
+ std::shared_ptr<InputStream> stream_;
+ int64_t block_size_;
+ bool done_ = false;
+};
+
+} // namespace
+
+const IOContext& Readable::io_context() const { return g_default_io_context; }
+
+Status InputStream::Advance(int64_t nbytes) { return Read(nbytes).status(); }
+
+Result<util::string_view> InputStream::Peek(int64_t ARROW_ARG_UNUSED(nbytes)) {
+ return Status::NotImplemented("Peek not implemented");
+}
+
+bool InputStream::supports_zero_copy() const { return false; }
+
+Result<std::shared_ptr<const KeyValueMetadata>> InputStream::ReadMetadata() {
+ return std::shared_ptr<const KeyValueMetadata>{};
+}
+
+// Default ReadMetadataAsync() implementation: simply issue the read on the context's
+// executor
+Future<std::shared_ptr<const KeyValueMetadata>> InputStream::ReadMetadataAsync(
+ const IOContext& ctx) {
+ auto self = shared_from_this();
+ return DeferNotOk(internal::SubmitIO(ctx, [self] { return self->ReadMetadata(); }));
+}
+
+Future<std::shared_ptr<const KeyValueMetadata>> InputStream::ReadMetadataAsync() {
+ return ReadMetadataAsync(io_context());
+}
+
+Result<Iterator<std::shared_ptr<Buffer>>> MakeInputStreamIterator(
+ std::shared_ptr<InputStream> stream, int64_t block_size) {
+ if (stream->closed()) {
+ return Status::Invalid("Cannot take iterator on closed stream");
+ }
+ DCHECK_GT(block_size, 0);
+ return Iterator<std::shared_ptr<Buffer>>(InputStreamBlockIterator(stream, block_size));
+}
+
+struct RandomAccessFile::Impl {
+ std::mutex lock_;
+};
+
+RandomAccessFile::~RandomAccessFile() = default;
+
+RandomAccessFile::RandomAccessFile() : interface_impl_(new Impl()) {}
+
+Result<int64_t> RandomAccessFile::ReadAt(int64_t position, int64_t nbytes, void* out) {
+ std::lock_guard<std::mutex> lock(interface_impl_->lock_);
+ RETURN_NOT_OK(Seek(position));
+ return Read(nbytes, out);
+}
+
+Result<std::shared_ptr<Buffer>> RandomAccessFile::ReadAt(int64_t position,
+ int64_t nbytes) {
+ std::lock_guard<std::mutex> lock(interface_impl_->lock_);
+ RETURN_NOT_OK(Seek(position));
+ return Read(nbytes);
+}
+
+// Default ReadAsync() implementation: simply issue the read on the context's executor
+Future<std::shared_ptr<Buffer>> RandomAccessFile::ReadAsync(const IOContext& ctx,
+ int64_t position,
+ int64_t nbytes) {
+ auto self = checked_pointer_cast<RandomAccessFile>(shared_from_this());
+ return DeferNotOk(internal::SubmitIO(
+ ctx, [self, position, nbytes] { return self->ReadAt(position, nbytes); }));
+}
+
+Future<std::shared_ptr<Buffer>> RandomAccessFile::ReadAsync(int64_t position,
+ int64_t nbytes) {
+ return ReadAsync(io_context(), position, nbytes);
+}
+
+// Default WillNeed() implementation: no-op
+Status RandomAccessFile::WillNeed(const std::vector<ReadRange>& ranges) {
+ return Status::OK();
+}
+
+Status Writable::Write(util::string_view data) {
+ return Write(data.data(), static_cast<int64_t>(data.size()));
+}
+
+Status Writable::Write(const std::shared_ptr<Buffer>& data) {
+ return Write(data->data(), data->size());
+}
+
+Status Writable::Flush() { return Status::OK(); }
+
+// An InputStream that reads from a delimited range of a RandomAccessFile
+class FileSegmentReader
+ : public internal::InputStreamConcurrencyWrapper<FileSegmentReader> {
+ public:
+ FileSegmentReader(std::shared_ptr<RandomAccessFile> file, int64_t file_offset,
+ int64_t nbytes)
+ : file_(std::move(file)),
+ closed_(false),
+ position_(0),
+ file_offset_(file_offset),
+ nbytes_(nbytes) {
+ FileInterface::set_mode(FileMode::READ);
+ }
+
+ Status CheckOpen() const {
+ if (closed_) {
+ return Status::IOError("Stream is closed");
+ }
+ return Status::OK();
+ }
+
+ Status DoClose() {
+ closed_ = true;
+ return Status::OK();
+ }
+
+ Result<int64_t> DoTell() const {
+ RETURN_NOT_OK(CheckOpen());
+ return position_;
+ }
+
+ bool closed() const override { return closed_; }
+
+ Result<int64_t> DoRead(int64_t nbytes, void* out) {
+ RETURN_NOT_OK(CheckOpen());
+ int64_t bytes_to_read = std::min(nbytes, nbytes_ - position_);
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read,
+ file_->ReadAt(file_offset_ + position_, bytes_to_read, out));
+ position_ += bytes_read;
+ return bytes_read;
+ }
+
+ Result<std::shared_ptr<Buffer>> DoRead(int64_t nbytes) {
+ RETURN_NOT_OK(CheckOpen());
+ int64_t bytes_to_read = std::min(nbytes, nbytes_ - position_);
+ ARROW_ASSIGN_OR_RAISE(auto buffer,
+ file_->ReadAt(file_offset_ + position_, bytes_to_read));
+ position_ += buffer->size();
+ return buffer;
+ }
+
+ private:
+ std::shared_ptr<RandomAccessFile> file_;
+ bool closed_;
+ int64_t position_;
+ int64_t file_offset_;
+ int64_t nbytes_;
+};
+
+std::shared_ptr<InputStream> RandomAccessFile::GetStream(
+ std::shared_ptr<RandomAccessFile> file, int64_t file_offset, int64_t nbytes) {
+ return std::make_shared<FileSegmentReader>(std::move(file), file_offset, nbytes);
+}
+
+// -----------------------------------------------------------------------
+// Implement utilities exported from concurrency.h and util_internal.h
+
+namespace internal {
+
+void CloseFromDestructor(FileInterface* file) {
+ Status st = file->Close();
+ if (!st.ok()) {
+ auto file_type = typeid(*file).name();
+#ifdef NDEBUG
+ ARROW_LOG(ERROR) << "Error ignored when destroying file of type " << file_type << ": "
+ << st;
+#else
+ std::stringstream ss;
+ ss << "When destroying file of type " << file_type << ": " << st.message();
+ ARROW_LOG(FATAL) << st.WithMessage(ss.str());
+#endif
+ }
+}
+
+Result<int64_t> ValidateReadRange(int64_t offset, int64_t size, int64_t file_size) {
+ if (offset < 0 || size < 0) {
+ return Status::Invalid("Invalid read (offset = ", offset, ", size = ", size, ")");
+ }
+ if (offset > file_size) {
+ return Status::IOError("Read out of bounds (offset = ", offset, ", size = ", size,
+ ") in file of size ", file_size);
+ }
+ return std::min(size, file_size - offset);
+}
+
+Status ValidateWriteRange(int64_t offset, int64_t size, int64_t file_size) {
+ if (offset < 0 || size < 0) {
+ return Status::Invalid("Invalid write (offset = ", offset, ", size = ", size, ")");
+ }
+ if (offset + size > file_size) {
+ return Status::IOError("Write out of bounds (offset = ", offset, ", size = ", size,
+ ") in file of size ", file_size);
+ }
+ return Status::OK();
+}
+
+Status ValidateRange(int64_t offset, int64_t size) {
+ if (offset < 0 || size < 0) {
+ return Status::Invalid("Invalid IO range (offset = ", offset, ", size = ", size, ")");
+ }
+ return Status::OK();
+}
+
+#ifndef NDEBUG
+
+// Debug mode concurrency checking
+
+struct SharedExclusiveChecker::Impl {
+ std::mutex mutex;
+ int64_t n_shared = 0;
+ int64_t n_exclusive = 0;
+};
+
+SharedExclusiveChecker::SharedExclusiveChecker() : impl_(new Impl) {}
+
+void SharedExclusiveChecker::LockShared() {
+ std::lock_guard<std::mutex> lock(impl_->mutex);
+ // XXX The error message doesn't really describe the actual situation
+ // (e.g. ReadAt() called while Read() call in progress)
+ ARROW_CHECK_EQ(impl_->n_exclusive, 0)
+ << "Attempted to take shared lock while locked exclusive";
+ ++impl_->n_shared;
+}
+
+void SharedExclusiveChecker::UnlockShared() {
+ std::lock_guard<std::mutex> lock(impl_->mutex);
+ ARROW_CHECK_GT(impl_->n_shared, 0);
+ --impl_->n_shared;
+}
+
+void SharedExclusiveChecker::LockExclusive() {
+ std::lock_guard<std::mutex> lock(impl_->mutex);
+ ARROW_CHECK_EQ(impl_->n_shared, 0)
+ << "Attempted to take exclusive lock while locked shared";
+ ARROW_CHECK_EQ(impl_->n_exclusive, 0)
+ << "Attempted to take exclusive lock while already locked exclusive";
+ ++impl_->n_exclusive;
+}
+
+void SharedExclusiveChecker::UnlockExclusive() {
+ std::lock_guard<std::mutex> lock(impl_->mutex);
+ ARROW_CHECK_EQ(impl_->n_exclusive, 1);
+ --impl_->n_exclusive;
+}
+
+#else
+
+// Release mode no-op concurrency checking
+
+struct SharedExclusiveChecker::Impl {};
+
+SharedExclusiveChecker::SharedExclusiveChecker() {}
+
+void SharedExclusiveChecker::LockShared() {}
+void SharedExclusiveChecker::UnlockShared() {}
+void SharedExclusiveChecker::LockExclusive() {}
+void SharedExclusiveChecker::UnlockExclusive() {}
+
+#endif
+
+static std::shared_ptr<ThreadPool> MakeIOThreadPool() {
+ auto maybe_pool = ThreadPool::MakeEternal(/*threads=*/8);
+ if (!maybe_pool.ok()) {
+ maybe_pool.status().Abort("Failed to create global IO thread pool");
+ }
+ return *std::move(maybe_pool);
+}
+
+ThreadPool* GetIOThreadPool() {
+ static std::shared_ptr<ThreadPool> pool = MakeIOThreadPool();
+ return pool.get();
+}
+
+// -----------------------------------------------------------------------
+// CoalesceReadRanges
+
+namespace {
+
+struct ReadRangeCombiner {
+ std::vector<ReadRange> Coalesce(std::vector<ReadRange> ranges) {
+ if (ranges.empty()) {
+ return ranges;
+ }
+
+ // Remove zero-sized ranges
+ auto end = std::remove_if(ranges.begin(), ranges.end(),
+ [](const ReadRange& range) { return range.length == 0; });
+ // Sort in position order
+ std::sort(ranges.begin(), end,
+ [](const ReadRange& a, const ReadRange& b) { return a.offset < b.offset; });
+ // Remove ranges that overlap 100%
+ end = std::unique(ranges.begin(), end,
+ [](const ReadRange& left, const ReadRange& right) {
+ return right.offset >= left.offset &&
+ right.offset + right.length <= left.offset + left.length;
+ });
+ ranges.resize(end - ranges.begin());
+
+ // Skip further processing if ranges is empty after removing zero-sized ranges.
+ if (ranges.empty()) {
+ return ranges;
+ }
+
+#ifndef NDEBUG
+ for (size_t i = 0; i < ranges.size() - 1; ++i) {
+ const auto& left = ranges[i];
+ const auto& right = ranges[i + 1];
+ DCHECK_LE(left.offset, right.offset);
+ DCHECK_LE(left.offset + left.length, right.offset) << "Some read ranges overlap";
+ }
+#endif
+
+ std::vector<ReadRange> coalesced;
+
+ auto itr = ranges.begin();
+ // Ensure ranges is not empty.
+ DCHECK_LE(itr, ranges.end());
+ // Start of the current coalesced range and end (exclusive) of previous range.
+ // Both are initialized with the start of first range which is a placeholder value.
+ int64_t coalesced_start = itr->offset;
+ int64_t prev_range_end = coalesced_start;
+
+ for (; itr < ranges.end(); ++itr) {
+ const int64_t current_range_start = itr->offset;
+ const int64_t current_range_end = current_range_start + itr->length;
+ // We don't expect to have 0 sized ranges.
+ DCHECK_LT(current_range_start, current_range_end);
+
+ // At this point, the coalesced range is [coalesced_start, prev_range_end).
+ // Stop coalescing if:
+ // - coalesced range is too large, or
+ // - distance (hole/gap) between consecutive ranges is too large.
+ if (current_range_end - coalesced_start > range_size_limit_ ||
+ current_range_start - prev_range_end > hole_size_limit_) {
+ DCHECK_LE(coalesced_start, prev_range_end);
+ // Append the coalesced range only if coalesced range size > 0.
+ if (prev_range_end > coalesced_start) {
+ coalesced.push_back({coalesced_start, prev_range_end - coalesced_start});
+ }
+ // Start a new coalesced range.
+ coalesced_start = current_range_start;
+ }
+
+ // Update the prev_range_end with the current range.
+ prev_range_end = current_range_end;
+ }
+ // Append the coalesced range only if coalesced range size > 0.
+ if (prev_range_end > coalesced_start) {
+ coalesced.push_back({coalesced_start, prev_range_end - coalesced_start});
+ }
+
+ DCHECK_EQ(coalesced.front().offset, ranges.front().offset);
+ DCHECK_EQ(coalesced.back().offset + coalesced.back().length,
+ ranges.back().offset + ranges.back().length);
+ return coalesced;
+ }
+
+ const int64_t hole_size_limit_;
+ const int64_t range_size_limit_;
+};
+
+}; // namespace
+
+std::vector<ReadRange> CoalesceReadRanges(std::vector<ReadRange> ranges,
+ int64_t hole_size_limit,
+ int64_t range_size_limit) {
+ DCHECK_GT(range_size_limit, hole_size_limit);
+
+ ReadRangeCombiner combiner{hole_size_limit, range_size_limit};
+ return combiner.Coalesce(std::move(ranges));
+}
+
+} // namespace internal
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/interfaces.h b/src/arrow/cpp/src/arrow/io/interfaces.h
new file mode 100644
index 000000000..e524afa99
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/interfaces.h
@@ -0,0 +1,340 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/io/type_fwd.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/cancel.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace io {
+
+struct ReadRange {
+ int64_t offset;
+ int64_t length;
+
+ friend bool operator==(const ReadRange& left, const ReadRange& right) {
+ return (left.offset == right.offset && left.length == right.length);
+ }
+ friend bool operator!=(const ReadRange& left, const ReadRange& right) {
+ return !(left == right);
+ }
+
+ bool Contains(const ReadRange& other) const {
+ return (offset <= other.offset && offset + length >= other.offset + other.length);
+ }
+};
+
+/// EXPERIMENTAL: options provider for IO tasks
+///
+/// Includes an Executor (which will be used to execute asynchronous reads),
+/// a MemoryPool (which will be used to allocate buffers when zero copy reads
+/// are not possible), and an external id (in case the executor receives tasks from
+/// multiple sources and must distinguish tasks associated with this IOContext).
+struct ARROW_EXPORT IOContext {
+ // No specified executor: will use a global IO thread pool
+ IOContext() : IOContext(default_memory_pool(), StopToken::Unstoppable()) {}
+
+ explicit IOContext(StopToken stop_token)
+ : IOContext(default_memory_pool(), std::move(stop_token)) {}
+
+ explicit IOContext(MemoryPool* pool, StopToken stop_token = StopToken::Unstoppable());
+
+ explicit IOContext(MemoryPool* pool, ::arrow::internal::Executor* executor,
+ StopToken stop_token = StopToken::Unstoppable(),
+ int64_t external_id = -1)
+ : pool_(pool),
+ executor_(executor),
+ external_id_(external_id),
+ stop_token_(std::move(stop_token)) {}
+
+ explicit IOContext(::arrow::internal::Executor* executor,
+ StopToken stop_token = StopToken::Unstoppable(),
+ int64_t external_id = -1)
+ : pool_(default_memory_pool()),
+ executor_(executor),
+ external_id_(external_id),
+ stop_token_(std::move(stop_token)) {}
+
+ MemoryPool* pool() const { return pool_; }
+
+ ::arrow::internal::Executor* executor() const { return executor_; }
+
+ // An application-specific ID, forwarded to executor task submissions
+ int64_t external_id() const { return external_id_; }
+
+ StopToken stop_token() const { return stop_token_; }
+
+ private:
+ MemoryPool* pool_;
+ ::arrow::internal::Executor* executor_;
+ int64_t external_id_;
+ StopToken stop_token_;
+};
+
+struct ARROW_DEPRECATED("renamed to IOContext in 4.0.0") AsyncContext : public IOContext {
+ using IOContext::IOContext;
+};
+
+class ARROW_EXPORT FileInterface {
+ public:
+ virtual ~FileInterface() = 0;
+
+ /// \brief Close the stream cleanly
+ ///
+ /// For writable streams, this will attempt to flush any pending data
+ /// before releasing the underlying resource.
+ ///
+ /// After Close() is called, closed() returns true and the stream is not
+ /// available for further operations.
+ virtual Status Close() = 0;
+
+ /// \brief Close the stream abruptly
+ ///
+ /// This method does not guarantee that any pending data is flushed.
+ /// It merely releases any underlying resource used by the stream for
+ /// its operation.
+ ///
+ /// After Abort() is called, closed() returns true and the stream is not
+ /// available for further operations.
+ virtual Status Abort();
+
+ /// \brief Return the position in this stream
+ virtual Result<int64_t> Tell() const = 0;
+
+ /// \brief Return whether the stream is closed
+ virtual bool closed() const = 0;
+
+ FileMode::type mode() const { return mode_; }
+
+ protected:
+ FileInterface() : mode_(FileMode::READ) {}
+ FileMode::type mode_;
+ void set_mode(FileMode::type mode) { mode_ = mode; }
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(FileInterface);
+};
+
+class ARROW_EXPORT Seekable {
+ public:
+ virtual ~Seekable() = default;
+ virtual Status Seek(int64_t position) = 0;
+};
+
+class ARROW_EXPORT Writable {
+ public:
+ virtual ~Writable() = default;
+
+ /// \brief Write the given data to the stream
+ ///
+ /// This method always processes the bytes in full. Depending on the
+ /// semantics of the stream, the data may be written out immediately,
+ /// held in a buffer, or written asynchronously. In the case where
+ /// the stream buffers the data, it will be copied. To avoid potentially
+ /// large copies, use the Write variant that takes an owned Buffer.
+ virtual Status Write(const void* data, int64_t nbytes) = 0;
+
+ /// \brief Write the given data to the stream
+ ///
+ /// Since the Buffer owns its memory, this method can avoid a copy if
+ /// buffering is required. See Write(const void*, int64_t) for details.
+ virtual Status Write(const std::shared_ptr<Buffer>& data);
+
+ /// \brief Flush buffered bytes, if any
+ virtual Status Flush();
+
+ Status Write(util::string_view data);
+};
+
+class ARROW_EXPORT Readable {
+ public:
+ virtual ~Readable() = default;
+
+ /// \brief Read data from current file position.
+ ///
+ /// Read at most `nbytes` from the current file position into `out`.
+ /// The number of bytes read is returned.
+ virtual Result<int64_t> Read(int64_t nbytes, void* out) = 0;
+
+ /// \brief Read data from current file position.
+ ///
+ /// Read at most `nbytes` from the current file position. Less bytes may
+ /// be read if EOF is reached. This method updates the current file position.
+ ///
+ /// In some cases (e.g. a memory-mapped file), this method may avoid a
+ /// memory copy.
+ virtual Result<std::shared_ptr<Buffer>> Read(int64_t nbytes) = 0;
+
+ /// EXPERIMENTAL: The IOContext associated with this file.
+ ///
+ /// By default, this is the same as default_io_context(), but it may be
+ /// overriden by subclasses.
+ virtual const IOContext& io_context() const;
+};
+
+class ARROW_EXPORT OutputStream : virtual public FileInterface, public Writable {
+ protected:
+ OutputStream() = default;
+};
+
+class ARROW_EXPORT InputStream : virtual public FileInterface,
+ virtual public Readable,
+ public std::enable_shared_from_this<InputStream> {
+ public:
+ /// \brief Advance or skip stream indicated number of bytes
+ /// \param[in] nbytes the number to move forward
+ /// \return Status
+ Status Advance(int64_t nbytes);
+
+ /// \brief Return zero-copy string_view to upcoming bytes.
+ ///
+ /// Do not modify the stream position. The view becomes invalid after
+ /// any operation on the stream. May trigger buffering if the requested
+ /// size is larger than the number of buffered bytes.
+ ///
+ /// May return NotImplemented on streams that don't support it.
+ ///
+ /// \param[in] nbytes the maximum number of bytes to see
+ virtual Result<util::string_view> Peek(int64_t nbytes);
+
+ /// \brief Return true if InputStream is capable of zero copy Buffer reads
+ ///
+ /// Zero copy reads imply the use of Buffer-returning Read() overloads.
+ virtual bool supports_zero_copy() const;
+
+ /// \brief Read and return stream metadata
+ ///
+ /// If the stream implementation doesn't support metadata, empty metadata
+ /// is returned. Note that it is allowed to return a null pointer rather
+ /// than an allocated empty metadata.
+ virtual Result<std::shared_ptr<const KeyValueMetadata>> ReadMetadata();
+
+ /// \brief Read stream metadata asynchronously
+ virtual Future<std::shared_ptr<const KeyValueMetadata>> ReadMetadataAsync(
+ const IOContext& io_context);
+ Future<std::shared_ptr<const KeyValueMetadata>> ReadMetadataAsync();
+
+ protected:
+ InputStream() = default;
+};
+
+class ARROW_EXPORT RandomAccessFile : public InputStream, public Seekable {
+ public:
+ /// Necessary because we hold a std::unique_ptr
+ ~RandomAccessFile() override;
+
+ /// \brief Create an isolated InputStream that reads a segment of a
+ /// RandomAccessFile. Multiple such stream can be created and used
+ /// independently without interference
+ /// \param[in] file a file instance
+ /// \param[in] file_offset the starting position in the file
+ /// \param[in] nbytes the extent of bytes to read. The file should have
+ /// sufficient bytes available
+ static std::shared_ptr<InputStream> GetStream(std::shared_ptr<RandomAccessFile> file,
+ int64_t file_offset, int64_t nbytes);
+
+ /// \brief Return the total file size in bytes.
+ ///
+ /// This method does not read or move the current file position, so is safe
+ /// to call concurrently with e.g. ReadAt().
+ virtual Result<int64_t> GetSize() = 0;
+
+ /// \brief Read data from given file position.
+ ///
+ /// At most `nbytes` bytes are read. The number of bytes read is returned
+ /// (it can be less than `nbytes` if EOF is reached).
+ ///
+ /// This method can be safely called from multiple threads concurrently.
+ /// It is unspecified whether this method updates the file position or not.
+ ///
+ /// The default RandomAccessFile-provided implementation uses Seek() and Read(),
+ /// but subclasses may override it with a more efficient implementation
+ /// that doesn't depend on implicit file positioning.
+ ///
+ /// \param[in] position Where to read bytes from
+ /// \param[in] nbytes The number of bytes to read
+ /// \param[out] out The buffer to read bytes into
+ /// \return The number of bytes read, or an error
+ virtual Result<int64_t> ReadAt(int64_t position, int64_t nbytes, void* out);
+
+ /// \brief Read data from given file position.
+ ///
+ /// At most `nbytes` bytes are read, but it can be less if EOF is reached.
+ ///
+ /// \param[in] position Where to read bytes from
+ /// \param[in] nbytes The number of bytes to read
+ /// \return A buffer containing the bytes read, or an error
+ virtual Result<std::shared_ptr<Buffer>> ReadAt(int64_t position, int64_t nbytes);
+
+ /// EXPERIMENTAL: Read data asynchronously.
+ virtual Future<std::shared_ptr<Buffer>> ReadAsync(const IOContext&, int64_t position,
+ int64_t nbytes);
+
+ /// EXPERIMENTAL: Read data asynchronously, using the file's IOContext.
+ Future<std::shared_ptr<Buffer>> ReadAsync(int64_t position, int64_t nbytes);
+
+ /// EXPERIMENTAL: Inform that the given ranges may be read soon.
+ ///
+ /// Some implementations might arrange to prefetch some of the data.
+ /// However, no guarantee is made and the default implementation does nothing.
+ /// For robust prefetching, use ReadAt() or ReadAsync().
+ virtual Status WillNeed(const std::vector<ReadRange>& ranges);
+
+ protected:
+ RandomAccessFile();
+
+ private:
+ struct ARROW_NO_EXPORT Impl;
+ std::unique_ptr<Impl> interface_impl_;
+};
+
+class ARROW_EXPORT WritableFile : public OutputStream, public Seekable {
+ public:
+ virtual Status WriteAt(int64_t position, const void* data, int64_t nbytes) = 0;
+
+ protected:
+ WritableFile() = default;
+};
+
+class ARROW_EXPORT ReadWriteFileInterface : public RandomAccessFile, public WritableFile {
+ protected:
+ ReadWriteFileInterface() { RandomAccessFile::set_mode(FileMode::READWRITE); }
+};
+
+/// \brief Return an iterator on an input stream
+///
+/// The iterator yields a fixed-size block on each Next() call, except the
+/// last block in the stream which may be smaller.
+/// Once the end of stream is reached, Next() returns nullptr
+/// (unlike InputStream::Read() which returns an empty buffer).
+ARROW_EXPORT
+Result<Iterator<std::shared_ptr<Buffer>>> MakeInputStreamIterator(
+ std::shared_ptr<InputStream> stream, int64_t block_size);
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/memory.cc b/src/arrow/cpp/src/arrow/io/memory.cc
new file mode 100644
index 000000000..6495242e6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/memory.cc
@@ -0,0 +1,388 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/io/memory.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <mutex>
+#include <utility>
+
+#include "arrow/buffer.h"
+#include "arrow/io/util_internal.h"
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/util/future.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/memory.h"
+
+namespace arrow {
+namespace io {
+
+// ----------------------------------------------------------------------
+// OutputStream that writes to resizable buffer
+
+static constexpr int64_t kBufferMinimumSize = 256;
+
+BufferOutputStream::BufferOutputStream()
+ : is_open_(false), capacity_(0), position_(0), mutable_data_(nullptr) {}
+
+BufferOutputStream::BufferOutputStream(const std::shared_ptr<ResizableBuffer>& buffer)
+ : buffer_(buffer),
+ is_open_(true),
+ capacity_(buffer->size()),
+ position_(0),
+ mutable_data_(buffer->mutable_data()) {}
+
+Result<std::shared_ptr<BufferOutputStream>> BufferOutputStream::Create(
+ int64_t initial_capacity, MemoryPool* pool) {
+ // ctor is private, so cannot use make_shared
+ auto ptr = std::shared_ptr<BufferOutputStream>(new BufferOutputStream);
+ RETURN_NOT_OK(ptr->Reset(initial_capacity, pool));
+ return ptr;
+}
+
+Status BufferOutputStream::Reset(int64_t initial_capacity, MemoryPool* pool) {
+ ARROW_ASSIGN_OR_RAISE(buffer_, AllocateResizableBuffer(initial_capacity, pool));
+ is_open_ = true;
+ capacity_ = initial_capacity;
+ position_ = 0;
+ mutable_data_ = buffer_->mutable_data();
+ return Status::OK();
+}
+
+BufferOutputStream::~BufferOutputStream() {
+ if (buffer_) {
+ internal::CloseFromDestructor(this);
+ }
+}
+
+Status BufferOutputStream::Close() {
+ if (is_open_) {
+ is_open_ = false;
+ if (position_ < capacity_) {
+ RETURN_NOT_OK(buffer_->Resize(position_, false));
+ }
+ }
+ return Status::OK();
+}
+
+bool BufferOutputStream::closed() const { return !is_open_; }
+
+Result<std::shared_ptr<Buffer>> BufferOutputStream::Finish() {
+ RETURN_NOT_OK(Close());
+ buffer_->ZeroPadding();
+ is_open_ = false;
+ return std::move(buffer_);
+}
+
+Result<int64_t> BufferOutputStream::Tell() const { return position_; }
+
+Status BufferOutputStream::Write(const void* data, int64_t nbytes) {
+ if (ARROW_PREDICT_FALSE(!is_open_)) {
+ return Status::IOError("OutputStream is closed");
+ }
+ DCHECK(buffer_);
+ if (ARROW_PREDICT_TRUE(nbytes > 0)) {
+ if (ARROW_PREDICT_FALSE(position_ + nbytes >= capacity_)) {
+ RETURN_NOT_OK(Reserve(nbytes));
+ }
+ memcpy(mutable_data_ + position_, data, nbytes);
+ position_ += nbytes;
+ }
+ return Status::OK();
+}
+
+Status BufferOutputStream::Reserve(int64_t nbytes) {
+ // Always overallocate by doubling. It seems that it is a better growth
+ // strategy, at least for memory_benchmark.cc.
+ // This may be because it helps match the allocator's allocation buckets
+ // more exactly. Or perhaps it hits a sweet spot in jemalloc.
+ int64_t new_capacity = std::max(kBufferMinimumSize, capacity_);
+ while (new_capacity < position_ + nbytes) {
+ new_capacity = new_capacity * 2;
+ }
+ if (new_capacity > capacity_) {
+ RETURN_NOT_OK(buffer_->Resize(new_capacity));
+ capacity_ = new_capacity;
+ mutable_data_ = buffer_->mutable_data();
+ }
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// OutputStream that doesn't write anything
+
+Status MockOutputStream::Close() {
+ is_open_ = false;
+ return Status::OK();
+}
+
+bool MockOutputStream::closed() const { return !is_open_; }
+
+Result<int64_t> MockOutputStream::Tell() const { return extent_bytes_written_; }
+
+Status MockOutputStream::Write(const void* data, int64_t nbytes) {
+ extent_bytes_written_ += nbytes;
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// In-memory buffer writer
+
+static constexpr int kMemcopyDefaultNumThreads = 1;
+static constexpr int64_t kMemcopyDefaultBlocksize = 64;
+static constexpr int64_t kMemcopyDefaultThreshold = 1024 * 1024;
+
+class FixedSizeBufferWriter::FixedSizeBufferWriterImpl {
+ public:
+ /// Input buffer must be mutable, will abort if not
+
+ /// Input buffer must be mutable, will abort if not
+ explicit FixedSizeBufferWriterImpl(const std::shared_ptr<Buffer>& buffer)
+ : is_open_(true),
+ memcopy_num_threads_(kMemcopyDefaultNumThreads),
+ memcopy_blocksize_(kMemcopyDefaultBlocksize),
+ memcopy_threshold_(kMemcopyDefaultThreshold) {
+ buffer_ = buffer;
+ ARROW_CHECK(buffer->is_mutable()) << "Must pass mutable buffer";
+ mutable_data_ = buffer->mutable_data();
+ size_ = buffer->size();
+ position_ = 0;
+ }
+
+ Status Close() {
+ is_open_ = false;
+ return Status::OK();
+ }
+
+ bool closed() const { return !is_open_; }
+
+ Status Seek(int64_t position) {
+ if (position < 0 || position > size_) {
+ return Status::IOError("Seek out of bounds");
+ }
+ position_ = position;
+ return Status::OK();
+ }
+
+ Result<int64_t> Tell() { return position_; }
+
+ Status Write(const void* data, int64_t nbytes) {
+ RETURN_NOT_OK(internal::ValidateWriteRange(position_, nbytes, size_));
+ if (nbytes > memcopy_threshold_ && memcopy_num_threads_ > 1) {
+ ::arrow::internal::parallel_memcopy(mutable_data_ + position_,
+ reinterpret_cast<const uint8_t*>(data), nbytes,
+ memcopy_blocksize_, memcopy_num_threads_);
+ } else {
+ memcpy(mutable_data_ + position_, data, nbytes);
+ }
+ position_ += nbytes;
+ return Status::OK();
+ }
+
+ Status WriteAt(int64_t position, const void* data, int64_t nbytes) {
+ std::lock_guard<std::mutex> guard(lock_);
+ RETURN_NOT_OK(internal::ValidateWriteRange(position, nbytes, size_));
+ RETURN_NOT_OK(Seek(position));
+ return Write(data, nbytes);
+ }
+
+ void set_memcopy_threads(int num_threads) { memcopy_num_threads_ = num_threads; }
+
+ void set_memcopy_blocksize(int64_t blocksize) { memcopy_blocksize_ = blocksize; }
+
+ void set_memcopy_threshold(int64_t threshold) { memcopy_threshold_ = threshold; }
+
+ private:
+ std::mutex lock_;
+ std::shared_ptr<Buffer> buffer_;
+ uint8_t* mutable_data_;
+ int64_t size_;
+ int64_t position_;
+ bool is_open_;
+
+ int memcopy_num_threads_;
+ int64_t memcopy_blocksize_;
+ int64_t memcopy_threshold_;
+};
+
+FixedSizeBufferWriter::FixedSizeBufferWriter(const std::shared_ptr<Buffer>& buffer)
+ : impl_(new FixedSizeBufferWriterImpl(buffer)) {}
+
+FixedSizeBufferWriter::~FixedSizeBufferWriter() = default;
+
+Status FixedSizeBufferWriter::Close() { return impl_->Close(); }
+
+bool FixedSizeBufferWriter::closed() const { return impl_->closed(); }
+
+Status FixedSizeBufferWriter::Seek(int64_t position) { return impl_->Seek(position); }
+
+Result<int64_t> FixedSizeBufferWriter::Tell() const { return impl_->Tell(); }
+
+Status FixedSizeBufferWriter::Write(const void* data, int64_t nbytes) {
+ return impl_->Write(data, nbytes);
+}
+
+Status FixedSizeBufferWriter::WriteAt(int64_t position, const void* data,
+ int64_t nbytes) {
+ return impl_->WriteAt(position, data, nbytes);
+}
+
+void FixedSizeBufferWriter::set_memcopy_threads(int num_threads) {
+ impl_->set_memcopy_threads(num_threads);
+}
+
+void FixedSizeBufferWriter::set_memcopy_blocksize(int64_t blocksize) {
+ impl_->set_memcopy_blocksize(blocksize);
+}
+
+void FixedSizeBufferWriter::set_memcopy_threshold(int64_t threshold) {
+ impl_->set_memcopy_threshold(threshold);
+}
+
+// ----------------------------------------------------------------------
+// In-memory buffer reader
+
+BufferReader::BufferReader(std::shared_ptr<Buffer> buffer)
+ : buffer_(std::move(buffer)),
+ data_(buffer_ ? buffer_->data() : reinterpret_cast<const uint8_t*>("")),
+ size_(buffer_ ? buffer_->size() : 0),
+ position_(0),
+ is_open_(true) {}
+
+BufferReader::BufferReader(const uint8_t* data, int64_t size)
+ : buffer_(nullptr), data_(data), size_(size), position_(0), is_open_(true) {}
+
+BufferReader::BufferReader(const Buffer& buffer)
+ : BufferReader(buffer.data(), buffer.size()) {}
+
+BufferReader::BufferReader(const util::string_view& data)
+ : BufferReader(reinterpret_cast<const uint8_t*>(data.data()),
+ static_cast<int64_t>(data.size())) {}
+
+Status BufferReader::DoClose() {
+ is_open_ = false;
+ return Status::OK();
+}
+
+bool BufferReader::closed() const { return !is_open_; }
+
+Result<int64_t> BufferReader::DoTell() const {
+ RETURN_NOT_OK(CheckClosed());
+ return position_;
+}
+
+Result<util::string_view> BufferReader::DoPeek(int64_t nbytes) {
+ RETURN_NOT_OK(CheckClosed());
+
+ const int64_t bytes_available = std::min(nbytes, size_ - position_);
+ return util::string_view(reinterpret_cast<const char*>(data_) + position_,
+ static_cast<size_t>(bytes_available));
+}
+
+bool BufferReader::supports_zero_copy() const { return true; }
+
+Status BufferReader::WillNeed(const std::vector<ReadRange>& ranges) {
+ using ::arrow::internal::MemoryRegion;
+
+ RETURN_NOT_OK(CheckClosed());
+
+ std::vector<MemoryRegion> regions(ranges.size());
+ for (size_t i = 0; i < ranges.size(); ++i) {
+ const auto& range = ranges[i];
+ ARROW_ASSIGN_OR_RAISE(auto size,
+ internal::ValidateReadRange(range.offset, range.length, size_));
+ regions[i] = {const_cast<uint8_t*>(data_ + range.offset), static_cast<size_t>(size)};
+ }
+ const auto st = ::arrow::internal::MemoryAdviseWillNeed(regions);
+ if (st.IsIOError()) {
+ // Ignore any system-level errors, in case the memory area isn't madvise()-able
+ return Status::OK();
+ }
+ return st;
+}
+
+Future<std::shared_ptr<Buffer>> BufferReader::ReadAsync(const IOContext&,
+ int64_t position,
+ int64_t nbytes) {
+ return Future<std::shared_ptr<Buffer>>::MakeFinished(DoReadAt(position, nbytes));
+}
+
+Result<int64_t> BufferReader::DoReadAt(int64_t position, int64_t nbytes, void* buffer) {
+ RETURN_NOT_OK(CheckClosed());
+
+ ARROW_ASSIGN_OR_RAISE(nbytes, internal::ValidateReadRange(position, nbytes, size_));
+ DCHECK_GE(nbytes, 0);
+ if (nbytes) {
+ memcpy(buffer, data_ + position, nbytes);
+ }
+ return nbytes;
+}
+
+Result<std::shared_ptr<Buffer>> BufferReader::DoReadAt(int64_t position, int64_t nbytes) {
+ RETURN_NOT_OK(CheckClosed());
+
+ ARROW_ASSIGN_OR_RAISE(nbytes, internal::ValidateReadRange(position, nbytes, size_));
+ DCHECK_GE(nbytes, 0);
+
+ // Arrange for data to be paged in
+ // RETURN_NOT_OK(::arrow::internal::MemoryAdviseWillNeed(
+ // {{const_cast<uint8_t*>(data_ + position), static_cast<size_t>(nbytes)}}));
+
+ if (nbytes > 0 && buffer_ != nullptr) {
+ return SliceBuffer(buffer_, position, nbytes);
+ } else {
+ return std::make_shared<Buffer>(data_ + position, nbytes);
+ }
+}
+
+Result<int64_t> BufferReader::DoRead(int64_t nbytes, void* out) {
+ RETURN_NOT_OK(CheckClosed());
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, DoReadAt(position_, nbytes, out));
+ position_ += bytes_read;
+ return bytes_read;
+}
+
+Result<std::shared_ptr<Buffer>> BufferReader::DoRead(int64_t nbytes) {
+ RETURN_NOT_OK(CheckClosed());
+ ARROW_ASSIGN_OR_RAISE(auto buffer, DoReadAt(position_, nbytes));
+ position_ += buffer->size();
+ return buffer;
+}
+
+Result<int64_t> BufferReader::DoGetSize() {
+ RETURN_NOT_OK(CheckClosed());
+ return size_;
+}
+
+Status BufferReader::DoSeek(int64_t position) {
+ RETURN_NOT_OK(CheckClosed());
+
+ if (position < 0 || position > size_) {
+ return Status::IOError("Seek out of bounds");
+ }
+
+ position_ = position;
+ return Status::OK();
+}
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/memory.h b/src/arrow/cpp/src/arrow/io/memory.h
new file mode 100644
index 000000000..8213439ef
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/memory.h
@@ -0,0 +1,197 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Public API for different memory sharing / IO mechanisms
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "arrow/io/concurrency.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Status;
+
+namespace io {
+
+/// \brief An output stream that writes to a resizable buffer
+class ARROW_EXPORT BufferOutputStream : public OutputStream {
+ public:
+ explicit BufferOutputStream(const std::shared_ptr<ResizableBuffer>& buffer);
+
+ /// \brief Create in-memory output stream with indicated capacity using a
+ /// memory pool
+ /// \param[in] initial_capacity the initial allocated internal capacity of
+ /// the OutputStream
+ /// \param[in,out] pool a MemoryPool to use for allocations
+ /// \return the created stream
+ static Result<std::shared_ptr<BufferOutputStream>> Create(
+ int64_t initial_capacity = 4096, MemoryPool* pool = default_memory_pool());
+
+ ~BufferOutputStream() override;
+
+ // Implement the OutputStream interface
+
+ /// Close the stream, preserving the buffer (retrieve it with Finish()).
+ Status Close() override;
+ bool closed() const override;
+ Result<int64_t> Tell() const override;
+ Status Write(const void* data, int64_t nbytes) override;
+
+ /// \cond FALSE
+ using OutputStream::Write;
+ /// \endcond
+
+ /// Close the stream and return the buffer
+ Result<std::shared_ptr<Buffer>> Finish();
+
+ /// \brief Initialize state of OutputStream with newly allocated memory and
+ /// set position to 0
+ /// \param[in] initial_capacity the starting allocated capacity
+ /// \param[in,out] pool the memory pool to use for allocations
+ /// \return Status
+ Status Reset(int64_t initial_capacity = 1024, MemoryPool* pool = default_memory_pool());
+
+ int64_t capacity() const { return capacity_; }
+
+ private:
+ BufferOutputStream();
+
+ // Ensures there is sufficient space available to write nbytes
+ Status Reserve(int64_t nbytes);
+
+ std::shared_ptr<ResizableBuffer> buffer_;
+ bool is_open_;
+ int64_t capacity_;
+ int64_t position_;
+ uint8_t* mutable_data_;
+};
+
+/// \brief A helper class to track the size of allocations
+///
+/// Writes to this stream do not copy or retain any data, they just bump
+/// a size counter that can be later used to know exactly which data size
+/// needs to be allocated for actual writing.
+class ARROW_EXPORT MockOutputStream : public OutputStream {
+ public:
+ MockOutputStream() : extent_bytes_written_(0), is_open_(true) {}
+
+ // Implement the OutputStream interface
+ Status Close() override;
+ bool closed() const override;
+ Result<int64_t> Tell() const override;
+ Status Write(const void* data, int64_t nbytes) override;
+ /// \cond FALSE
+ using Writable::Write;
+ /// \endcond
+
+ int64_t GetExtentBytesWritten() const { return extent_bytes_written_; }
+
+ private:
+ int64_t extent_bytes_written_;
+ bool is_open_;
+};
+
+/// \brief An output stream that writes into a fixed-size mutable buffer
+class ARROW_EXPORT FixedSizeBufferWriter : public WritableFile {
+ public:
+ /// Input buffer must be mutable, will abort if not
+ explicit FixedSizeBufferWriter(const std::shared_ptr<Buffer>& buffer);
+ ~FixedSizeBufferWriter() override;
+
+ Status Close() override;
+ bool closed() const override;
+ Status Seek(int64_t position) override;
+ Result<int64_t> Tell() const override;
+ Status Write(const void* data, int64_t nbytes) override;
+ /// \cond FALSE
+ using Writable::Write;
+ /// \endcond
+
+ Status WriteAt(int64_t position, const void* data, int64_t nbytes) override;
+
+ void set_memcopy_threads(int num_threads);
+ void set_memcopy_blocksize(int64_t blocksize);
+ void set_memcopy_threshold(int64_t threshold);
+
+ protected:
+ class FixedSizeBufferWriterImpl;
+ std::unique_ptr<FixedSizeBufferWriterImpl> impl_;
+};
+
+/// \class BufferReader
+/// \brief Random access zero-copy reads on an arrow::Buffer
+class ARROW_EXPORT BufferReader
+ : public internal::RandomAccessFileConcurrencyWrapper<BufferReader> {
+ public:
+ explicit BufferReader(std::shared_ptr<Buffer> buffer);
+ explicit BufferReader(const Buffer& buffer);
+ BufferReader(const uint8_t* data, int64_t size);
+
+ /// \brief Instantiate from std::string or arrow::util::string_view. Does not
+ /// own data
+ explicit BufferReader(const util::string_view& data);
+
+ bool closed() const override;
+
+ bool supports_zero_copy() const override;
+
+ std::shared_ptr<Buffer> buffer() const { return buffer_; }
+
+ // Synchronous ReadAsync override
+ Future<std::shared_ptr<Buffer>> ReadAsync(const IOContext&, int64_t position,
+ int64_t nbytes) override;
+ Status WillNeed(const std::vector<ReadRange>& ranges) override;
+
+ protected:
+ friend RandomAccessFileConcurrencyWrapper<BufferReader>;
+
+ Status DoClose();
+
+ Result<int64_t> DoRead(int64_t nbytes, void* buffer);
+ Result<std::shared_ptr<Buffer>> DoRead(int64_t nbytes);
+ Result<int64_t> DoReadAt(int64_t position, int64_t nbytes, void* out);
+ Result<std::shared_ptr<Buffer>> DoReadAt(int64_t position, int64_t nbytes);
+ Result<util::string_view> DoPeek(int64_t nbytes) override;
+
+ Result<int64_t> DoTell() const;
+ Status DoSeek(int64_t position);
+ Result<int64_t> DoGetSize();
+
+ Status CheckClosed() const {
+ if (!is_open_) {
+ return Status::Invalid("Operation forbidden on closed BufferReader");
+ }
+ return Status::OK();
+ }
+
+ std::shared_ptr<Buffer> buffer_;
+ const uint8_t* data_;
+ int64_t size_;
+ int64_t position_;
+ bool is_open_;
+};
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/memory_benchmark.cc b/src/arrow/cpp/src/arrow/io/memory_benchmark.cc
new file mode 100644
index 000000000..fbb34f386
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/memory_benchmark.cc
@@ -0,0 +1,359 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <iostream>
+
+#include "arrow/io/memory.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/cpu_info.h"
+#include "arrow/util/simd.h"
+
+#include "benchmark/benchmark.h"
+
+namespace arrow {
+
+using internal::CpuInfo;
+static CpuInfo* cpu_info = CpuInfo::GetInstance();
+
+static const int kNumCores = cpu_info->num_cores();
+static const int64_t kL1Size = cpu_info->CacheSize(CpuInfo::L1_CACHE);
+static const int64_t kL2Size = cpu_info->CacheSize(CpuInfo::L2_CACHE);
+static const int64_t kL3Size = cpu_info->CacheSize(CpuInfo::L3_CACHE);
+
+constexpr size_t kMemoryPerCore = 32 * 1024 * 1024;
+using BufferPtr = std::shared_ptr<Buffer>;
+
+#ifdef ARROW_WITH_BENCHMARKS_REFERENCE
+#ifndef _MSC_VER
+
+#ifdef ARROW_HAVE_SSE4_2
+
+#ifdef ARROW_HAVE_AVX512
+
+using VectorType = __m512i;
+#define VectorSet _mm512_set1_epi32
+#define VectorLoad _mm512_stream_load_si512
+#define VectorLoadAsm(SRC, DST) \
+ asm volatile("vmovaps %[src], %[dst]" : [ dst ] "=v"(DST) : [ src ] "m"(SRC) :)
+#define VectorStreamLoad _mm512_stream_load_si512
+#define VectorStreamLoadAsm(SRC, DST) \
+ asm volatile("vmovntdqa %[src], %[dst]" : [ dst ] "=v"(DST) : [ src ] "m"(SRC) :)
+#define VectorStreamWrite _mm512_stream_si512
+
+#else
+
+#ifdef ARROW_HAVE_AVX2
+
+using VectorType = __m256i;
+#define VectorSet _mm256_set1_epi32
+#define VectorLoad _mm256_stream_load_si256
+#define VectorLoadAsm(SRC, DST) \
+ asm volatile("vmovaps %[src], %[dst]" : [ dst ] "=v"(DST) : [ src ] "m"(SRC) :)
+#define VectorStreamLoad _mm256_stream_load_si256
+#define VectorStreamLoadAsm(SRC, DST) \
+ asm volatile("vmovntdqa %[src], %[dst]" : [ dst ] "=v"(DST) : [ src ] "m"(SRC) :)
+#define VectorStreamWrite _mm256_stream_si256
+
+#else // ARROW_HAVE_AVX2 not set
+
+using VectorType = __m128i;
+#define VectorSet _mm_set1_epi32
+#define VectorLoad _mm_stream_load_si128
+#define VectorLoadAsm(SRC, DST) \
+ asm volatile("movaps %[src], %[dst]" : [ dst ] "=x"(DST) : [ src ] "m"(SRC) :)
+#define VectorStreamLoad _mm_stream_load_si128
+#define VectorStreamLoadAsm(SRC, DST) \
+ asm volatile("movntdqa %[src], %[dst]" : [ dst ] "=x"(DST) : [ src ] "m"(SRC) :)
+#define VectorStreamWrite _mm_stream_si128
+
+#endif // ARROW_HAVE_AVX2
+#endif // ARROW_HAVE_AVX512
+
+static void Read(void* src, void* dst, size_t size) {
+ const auto simd = static_cast<VectorType*>(src);
+ VectorType a, b, c, d;
+ (void)dst;
+
+ for (size_t i = 0; i < size / sizeof(VectorType); i += 4) {
+ VectorLoadAsm(simd[i], a);
+ VectorLoadAsm(simd[i + 1], b);
+ VectorLoadAsm(simd[i + 2], c);
+ VectorLoadAsm(simd[i + 3], d);
+ }
+
+ memset(&a, 0, sizeof(a));
+ memset(&b, 0, sizeof(b));
+ memset(&c, 0, sizeof(c));
+ memset(&d, 0, sizeof(d));
+
+ benchmark::DoNotOptimize(a + b + c + d);
+}
+
+// See http://codearcana.com/posts/2013/05/18/achieving-maximum-memory-bandwidth.html
+// for the usage of stream loads/writes. Or section 6.1, page 47 of
+// https://akkadia.org/drepper/cpumemory.pdf .
+static void StreamRead(void* src, void* dst, size_t size) {
+ auto simd = static_cast<VectorType*>(src);
+ VectorType a, b, c, d;
+ (void)dst;
+
+ memset(&a, 0, sizeof(a));
+ memset(&b, 0, sizeof(b));
+ memset(&c, 0, sizeof(c));
+ memset(&d, 0, sizeof(d));
+
+ for (size_t i = 0; i < size / sizeof(VectorType); i += 4) {
+ VectorStreamLoadAsm(simd[i], a);
+ VectorStreamLoadAsm(simd[i + 1], b);
+ VectorStreamLoadAsm(simd[i + 2], c);
+ VectorStreamLoadAsm(simd[i + 3], d);
+ }
+
+ benchmark::DoNotOptimize(a + b + c + d);
+}
+
+static void StreamWrite(void* src, void* dst, size_t size) {
+ auto simd = static_cast<VectorType*>(dst);
+ const VectorType ones = VectorSet(1);
+ (void)src;
+
+ for (size_t i = 0; i < size / sizeof(VectorType); i += 4) {
+ VectorStreamWrite(&simd[i], ones);
+ VectorStreamWrite(&simd[i + 1], ones);
+ VectorStreamWrite(&simd[i + 2], ones);
+ VectorStreamWrite(&simd[i + 3], ones);
+ }
+}
+
+static void StreamReadWrite(void* src, void* dst, size_t size) {
+ auto src_simd = static_cast<VectorType*>(src);
+ auto dst_simd = static_cast<VectorType*>(dst);
+
+ for (size_t i = 0; i < size / sizeof(VectorType); i += 4) {
+ VectorStreamWrite(&dst_simd[i], VectorStreamLoad(&src_simd[i]));
+ VectorStreamWrite(&dst_simd[i + 1], VectorStreamLoad(&src_simd[i + 1]));
+ VectorStreamWrite(&dst_simd[i + 2], VectorStreamLoad(&src_simd[i + 2]));
+ VectorStreamWrite(&dst_simd[i + 3], VectorStreamLoad(&src_simd[i + 3]));
+ }
+}
+
+#endif // ARROW_HAVE_SSE4_2
+
+#ifdef ARROW_HAVE_ARMV8_CRYPTO
+
+using VectorType = uint8x16_t;
+using VectorTypeDual = uint8x16x2_t;
+
+#define VectorSet vdupq_n_u8
+#define VectorLoadAsm vld1q_u8
+
+static void armv8_stream_load_pair(VectorType* src, VectorType* dst) {
+ asm volatile("LDNP %[reg1], %[reg2], [%[from]]\n\t"
+ : [ reg1 ] "+r"(*dst), [ reg2 ] "+r"(*(dst + 1))
+ : [ from ] "r"(src));
+}
+
+static void armv8_stream_store_pair(VectorType* src, VectorType* dst) {
+ asm volatile("STNP %[reg1], %[reg2], [%[to]]\n\t"
+ : [ to ] "+r"(dst)
+ : [ reg1 ] "r"(*src), [ reg2 ] "r"(*(src + 1))
+ : "memory");
+}
+
+static void armv8_stream_ldst_pair(VectorType* src, VectorType* dst) {
+ asm volatile(
+ "LDNP q1, q2, [%[from]]\n\t"
+ "STNP q1, q2, [%[to]]\n\t"
+ : [ from ] "+r"(src), [ to ] "+r"(dst)
+ :
+ : "memory", "v0", "v1", "v2", "v3");
+}
+
+static void Read(void* src, void* dst, size_t size) {
+ const auto simd = static_cast<uint8_t*>(src);
+ VectorType a;
+ (void)dst;
+
+ memset(&a, 0, sizeof(a));
+
+ for (size_t i = 0; i < size; i += sizeof(VectorType)) {
+ a = VectorLoadAsm(simd + i);
+ }
+
+ benchmark::DoNotOptimize(a);
+}
+
+// See http://codearcana.com/posts/2013/05/18/achieving-maximum-memory-bandwidth.html
+// for the usage of stream loads/writes. Or section 6.1, page 47 of
+// https://akkadia.org/drepper/cpumemory.pdf .
+static void StreamRead(void* src, void* dst, size_t size) {
+ auto simd = static_cast<VectorType*>(src);
+ VectorType a[2];
+ (void)dst;
+
+ memset(&a, 0, sizeof(VectorTypeDual));
+
+ for (size_t i = 0; i < size / sizeof(VectorType); i += 2) {
+ armv8_stream_load_pair(simd + i, a);
+ }
+
+ benchmark::DoNotOptimize(a);
+}
+
+static void StreamWrite(void* src, void* dst, size_t size) {
+ auto simd = static_cast<VectorType*>(dst);
+ VectorType ones[2];
+ (void)src;
+
+ ones[0] = VectorSet(1);
+ ones[1] = VectorSet(1);
+
+ for (size_t i = 0; i < size / sizeof(VectorType); i += 2) {
+ armv8_stream_store_pair(static_cast<VectorType*>(ones), simd + i);
+ }
+}
+
+static void StreamReadWrite(void* src, void* dst, size_t size) {
+ auto src_simd = static_cast<VectorType*>(src);
+ auto dst_simd = static_cast<VectorType*>(dst);
+
+ for (size_t i = 0; i < size / sizeof(VectorType); i += 2) {
+ armv8_stream_ldst_pair(src_simd + i, dst_simd + i);
+ }
+}
+
+#endif // ARROW_HAVE_ARMV8_CRYPTO
+
+static void PlatformMemcpy(void* src, void* dst, size_t size) { memcpy(src, dst, size); }
+
+using ApplyFn = decltype(Read);
+
+template <ApplyFn Apply>
+static void MemoryBandwidth(benchmark::State& state) { // NOLINT non-const reference
+ const size_t buffer_size = state.range(0);
+ BufferPtr src, dst;
+
+ dst = *AllocateBuffer(buffer_size);
+ src = *AllocateBuffer(buffer_size);
+ random_bytes(buffer_size, 0, src->mutable_data());
+
+ while (state.KeepRunning()) {
+ Apply(src->mutable_data(), dst->mutable_data(), buffer_size);
+ }
+
+ state.SetBytesProcessed(state.iterations() * buffer_size);
+}
+
+#ifdef ARROW_HAVE_SSE4_2
+static void SetCacheBandwidthArgs(benchmark::internal::Benchmark* bench) {
+ auto cache_sizes = {kL1Size, kL2Size, kL3Size};
+ for (auto size : cache_sizes) {
+ bench->Arg(size / 2);
+ bench->Arg(size);
+ bench->Arg(size * 2);
+ }
+
+ bench->ArgName("size");
+}
+
+BENCHMARK_TEMPLATE(MemoryBandwidth, Read)->Apply(SetCacheBandwidthArgs);
+#endif // ARROW_HAVE_SSE4_2
+
+static void SetMemoryBandwidthArgs(benchmark::internal::Benchmark* bench) {
+ // `UseRealTime` is required due to threads, otherwise the cumulative CPU time
+ // is used which will skew the results by the number of threads.
+ bench->Arg(kMemoryPerCore)->ThreadRange(1, kNumCores)->UseRealTime();
+}
+
+BENCHMARK_TEMPLATE(MemoryBandwidth, StreamRead)->Apply(SetMemoryBandwidthArgs);
+BENCHMARK_TEMPLATE(MemoryBandwidth, StreamWrite)->Apply(SetMemoryBandwidthArgs);
+BENCHMARK_TEMPLATE(MemoryBandwidth, StreamReadWrite)->Apply(SetMemoryBandwidthArgs);
+BENCHMARK_TEMPLATE(MemoryBandwidth, PlatformMemcpy)->Apply(SetMemoryBandwidthArgs);
+
+#endif // _MSC_VER
+#endif // ARROW_WITH_BENCHMARKS_REFERENCE
+
+static void ParallelMemoryCopy(benchmark::State& state) { // NOLINT non-const reference
+ const int64_t n_threads = state.range(0);
+ const int64_t buffer_size = kMemoryPerCore;
+
+ auto src = *AllocateBuffer(buffer_size);
+ std::shared_ptr<Buffer> dst = *AllocateBuffer(buffer_size);
+
+ random_bytes(buffer_size, 0, src->mutable_data());
+
+ while (state.KeepRunning()) {
+ io::FixedSizeBufferWriter writer(dst);
+ writer.set_memcopy_threads(static_cast<int>(n_threads));
+ ABORT_NOT_OK(writer.Write(src->data(), src->size()));
+ }
+
+ state.SetBytesProcessed(int64_t(state.iterations()) * buffer_size);
+}
+
+BENCHMARK(ParallelMemoryCopy)
+ ->RangeMultiplier(2)
+ ->Range(1, kNumCores)
+ ->ArgName("threads")
+ ->UseRealTime();
+
+static void BenchmarkBufferOutputStream(
+ const std::string& datum,
+ benchmark::State& state) { // NOLINT non-const reference
+ const void* raw_data = datum.data();
+ int64_t raw_nbytes = static_cast<int64_t>(datum.size());
+ // Write approx. 32 MB to each BufferOutputStream
+ int64_t num_raw_values = (1 << 25) / raw_nbytes;
+ for (auto _ : state) {
+ auto stream = *io::BufferOutputStream::Create(1024);
+ for (int64_t i = 0; i < num_raw_values; ++i) {
+ ABORT_NOT_OK(stream->Write(raw_data, raw_nbytes));
+ }
+ ABORT_NOT_OK(stream->Finish());
+ }
+ state.SetBytesProcessed(int64_t(state.iterations()) * num_raw_values * raw_nbytes);
+}
+
+static void BufferOutputStreamTinyWrites(
+ benchmark::State& state) { // NOLINT non-const reference
+ // A 8-byte datum
+ return BenchmarkBufferOutputStream("abdefghi", state);
+}
+
+static void BufferOutputStreamSmallWrites(
+ benchmark::State& state) { // NOLINT non-const reference
+ // A 700-byte datum
+ std::string datum;
+ for (int i = 0; i < 100; ++i) {
+ datum += "abcdefg";
+ }
+ return BenchmarkBufferOutputStream(datum, state);
+}
+
+static void BufferOutputStreamLargeWrites(
+ benchmark::State& state) { // NOLINT non-const reference
+ // A 1.5MB datum
+ std::string datum(1500000, 'x');
+ return BenchmarkBufferOutputStream(datum, state);
+}
+
+BENCHMARK(BufferOutputStreamTinyWrites)->UseRealTime();
+BENCHMARK(BufferOutputStreamSmallWrites)->UseRealTime();
+BENCHMARK(BufferOutputStreamLargeWrites)->UseRealTime();
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/memory_test.cc b/src/arrow/cpp/src/arrow/io/memory_test.cc
new file mode 100644
index 000000000..bd62761c7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/memory_test.cc
@@ -0,0 +1,883 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <chrono>
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <functional>
+#include <memory>
+#include <ostream>
+#include <random>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/buffer.h"
+#include "arrow/io/caching.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/io/memory.h"
+#include "arrow/io/slow.h"
+#include "arrow/io/transform.h"
+#include "arrow/io/util_internal.h"
+#include "arrow/status.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/parallel.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace io {
+
+std::ostream& operator<<(std::ostream& os, const ReadRange& range) {
+ return os << "<offset=" << range.offset << ", length=" << range.length << ">";
+}
+
+class TestBufferOutputStream : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK_AND_ASSIGN(buffer_, AllocateResizableBuffer(0));
+ stream_.reset(new BufferOutputStream(buffer_));
+ }
+
+ protected:
+ std::shared_ptr<ResizableBuffer> buffer_;
+ std::unique_ptr<OutputStream> stream_;
+};
+
+TEST_F(TestBufferOutputStream, DtorCloses) {
+ std::string data = "data123456";
+
+ const int K = 100;
+ for (int i = 0; i < K; ++i) {
+ ARROW_EXPECT_OK(stream_->Write(data));
+ }
+
+ stream_ = nullptr;
+ ASSERT_EQ(static_cast<int64_t>(K * data.size()), buffer_->size());
+}
+
+TEST_F(TestBufferOutputStream, CloseResizes) {
+ std::string data = "data123456";
+
+ const int K = 100;
+ for (int i = 0; i < K; ++i) {
+ ARROW_EXPECT_OK(stream_->Write(data));
+ }
+
+ ASSERT_OK(stream_->Close());
+ ASSERT_EQ(static_cast<int64_t>(K * data.size()), buffer_->size());
+}
+
+TEST_F(TestBufferOutputStream, WriteAfterFinish) {
+ std::string data = "data123456";
+ ASSERT_OK(stream_->Write(data));
+
+ auto buffer_stream = checked_cast<BufferOutputStream*>(stream_.get());
+
+ ASSERT_OK(buffer_stream->Finish());
+
+ ASSERT_RAISES(IOError, stream_->Write(data));
+}
+
+TEST_F(TestBufferOutputStream, Reset) {
+ std::string data = "data123456";
+
+ auto stream = checked_cast<BufferOutputStream*>(stream_.get());
+
+ ASSERT_OK(stream->Write(data));
+
+ ASSERT_OK_AND_ASSIGN(auto buffer, stream->Finish());
+ ASSERT_EQ(buffer->size(), static_cast<int64_t>(data.size()));
+
+ ASSERT_OK(stream->Reset(2048));
+ ASSERT_OK(stream->Write(data));
+ ASSERT_OK(stream->Write(data));
+ ASSERT_OK_AND_ASSIGN(auto buffer2, stream->Finish());
+
+ ASSERT_EQ(buffer2->size(), static_cast<int64_t>(data.size() * 2));
+}
+
+TEST(TestFixedSizeBufferWriter, Basics) {
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> buffer, AllocateBuffer(1024));
+
+ FixedSizeBufferWriter writer(buffer);
+
+ ASSERT_OK_AND_EQ(0, writer.Tell());
+
+ std::string data = "data123456";
+ auto nbytes = static_cast<int64_t>(data.size());
+ ASSERT_OK(writer.Write(data.c_str(), nbytes));
+
+ ASSERT_OK_AND_EQ(nbytes, writer.Tell());
+
+ ASSERT_OK(writer.Seek(4));
+ ASSERT_OK_AND_EQ(4, writer.Tell());
+
+ ASSERT_OK(writer.Seek(1024));
+ ASSERT_OK_AND_EQ(1024, writer.Tell());
+
+ // Write out of bounds
+ ASSERT_RAISES(IOError, writer.Write(data.c_str(), 1));
+
+ // Seek out of bounds
+ ASSERT_RAISES(IOError, writer.Seek(-1));
+ ASSERT_RAISES(IOError, writer.Seek(1025));
+
+ ASSERT_OK(writer.Close());
+}
+
+TEST(TestFixedSizeBufferWriter, InvalidWrites) {
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> buffer, AllocateBuffer(1024));
+
+ FixedSizeBufferWriter writer(buffer);
+ const uint8_t data[10]{};
+ ASSERT_RAISES(Invalid, writer.WriteAt(-1, data, 1));
+ ASSERT_RAISES(Invalid, writer.WriteAt(1, data, -1));
+}
+
+TEST(TestBufferReader, FromStrings) {
+ // ARROW-3291: construct BufferReader from std::string or
+ // arrow::util::string_view
+
+ std::string data = "data123456";
+ auto view = util::string_view(data);
+
+ BufferReader reader1(data);
+ BufferReader reader2(view);
+
+ std::shared_ptr<Buffer> piece;
+ ASSERT_OK_AND_ASSIGN(piece, reader1.Read(4));
+ ASSERT_EQ(0, memcmp(piece->data(), data.data(), 4));
+
+ ASSERT_OK(reader2.Seek(2));
+ ASSERT_OK_AND_ASSIGN(piece, reader2.Read(4));
+ ASSERT_EQ(0, memcmp(piece->data(), data.data() + 2, 4));
+}
+
+TEST(TestBufferReader, FromNullBuffer) {
+ std::shared_ptr<Buffer> buf;
+ BufferReader reader(buf);
+ ASSERT_OK_AND_EQ(0, reader.GetSize());
+ ASSERT_OK_AND_ASSIGN(auto piece, reader.Read(10));
+ ASSERT_EQ(0, piece->size());
+}
+
+TEST(TestBufferReader, Seeking) {
+ std::string data = "data123456";
+
+ BufferReader reader(data);
+ ASSERT_OK_AND_EQ(0, reader.Tell());
+
+ ASSERT_OK(reader.Seek(9));
+ ASSERT_OK_AND_EQ(9, reader.Tell());
+
+ ASSERT_OK(reader.Seek(10));
+ ASSERT_OK_AND_EQ(10, reader.Tell());
+
+ ASSERT_RAISES(IOError, reader.Seek(11));
+ ASSERT_OK_AND_EQ(10, reader.Tell());
+}
+
+TEST(TestBufferReader, Peek) {
+ std::string data = "data123456";
+
+ BufferReader reader(std::make_shared<Buffer>(data));
+
+ util::string_view view;
+
+ ASSERT_OK_AND_ASSIGN(view, reader.Peek(4));
+
+ ASSERT_EQ(4, view.size());
+ ASSERT_EQ(data.substr(0, 4), std::string(view));
+
+ ASSERT_OK_AND_ASSIGN(view, reader.Peek(20));
+ ASSERT_EQ(data.size(), view.size());
+ ASSERT_EQ(data, std::string(view));
+}
+
+TEST(TestBufferReader, ReadAsync) {
+ std::string data = "data123456";
+
+ BufferReader reader(std::make_shared<Buffer>(data));
+
+ auto fut1 = reader.ReadAsync({}, 2, 6);
+ auto fut2 = reader.ReadAsync({}, 1, 4);
+ ASSERT_EQ(fut1.state(), FutureState::SUCCESS);
+ ASSERT_EQ(fut2.state(), FutureState::SUCCESS);
+ ASSERT_OK_AND_ASSIGN(auto buf, fut1.result());
+ AssertBufferEqual(*buf, "ta1234");
+ ASSERT_OK_AND_ASSIGN(buf, fut2.result());
+ AssertBufferEqual(*buf, "ata1");
+}
+
+TEST(TestBufferReader, InvalidReads) {
+ std::string data = "data123456";
+ BufferReader reader(std::make_shared<Buffer>(data));
+ uint8_t buffer[10];
+
+ ASSERT_RAISES(Invalid, reader.ReadAt(-1, 1));
+ ASSERT_RAISES(Invalid, reader.ReadAt(1, -1));
+ ASSERT_RAISES(Invalid, reader.ReadAt(-1, 1, buffer));
+ ASSERT_RAISES(Invalid, reader.ReadAt(1, -1, buffer));
+
+ ASSERT_RAISES(Invalid, reader.ReadAsync({}, -1, 1).result());
+ ASSERT_RAISES(Invalid, reader.ReadAsync({}, 1, -1).result());
+}
+
+TEST(TestBufferReader, RetainParentReference) {
+ // ARROW-387
+ std::string data = "data123456";
+
+ std::shared_ptr<Buffer> slice1;
+ std::shared_ptr<Buffer> slice2;
+ {
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> buffer,
+ AllocateBuffer(static_cast<int64_t>(data.size())));
+ std::memcpy(buffer->mutable_data(), data.c_str(), data.size());
+ BufferReader reader(buffer);
+ ASSERT_OK_AND_ASSIGN(slice1, reader.Read(4));
+ ASSERT_OK_AND_ASSIGN(slice2, reader.Read(6));
+ }
+
+ ASSERT_TRUE(slice1->parent() != nullptr);
+
+ ASSERT_EQ(0, std::memcmp(slice1->data(), data.c_str(), 4));
+ ASSERT_EQ(0, std::memcmp(slice2->data(), data.c_str() + 4, 6));
+}
+
+TEST(TestBufferReader, WillNeed) {
+ {
+ std::string data = "data123456";
+ BufferReader reader(std::make_shared<Buffer>(data));
+
+ ASSERT_OK(reader.WillNeed({}));
+ ASSERT_OK(reader.WillNeed({{0, 4}, {4, 6}}));
+ ASSERT_OK(reader.WillNeed({{10, 0}}));
+ ASSERT_RAISES(IOError, reader.WillNeed({{11, 1}})); // Out of bounds
+ }
+ {
+ std::string data = "data123456";
+ BufferReader reader(reinterpret_cast<const uint8_t*>(data.data()),
+ static_cast<int64_t>(data.size()));
+
+ ASSERT_OK(reader.WillNeed({{0, 4}, {4, 6}}));
+ ASSERT_RAISES(IOError, reader.WillNeed({{11, 1}})); // Out of bounds
+ }
+}
+
+TEST(TestRandomAccessFile, GetStream) {
+ std::string data = "data1data2data3data4data5";
+
+ auto buf = std::make_shared<Buffer>(data);
+ auto file = std::make_shared<BufferReader>(buf);
+
+ std::shared_ptr<InputStream> stream1, stream2;
+
+ stream1 = RandomAccessFile::GetStream(file, 0, 10);
+ stream2 = RandomAccessFile::GetStream(file, 9, 16);
+
+ ASSERT_OK_AND_EQ(0, stream1->Tell());
+
+ std::shared_ptr<Buffer> buf2;
+ uint8_t buf3[20];
+
+ ASSERT_OK_AND_EQ(4, stream2->Read(4, buf3));
+ ASSERT_EQ(0, std::memcmp(buf3, "2dat", 4));
+ ASSERT_OK_AND_EQ(4, stream2->Tell());
+
+ ASSERT_OK_AND_EQ(6, stream1->Read(6, buf3));
+ ASSERT_EQ(0, std::memcmp(buf3, "data1d", 6));
+ ASSERT_OK_AND_EQ(6, stream1->Tell());
+
+ ASSERT_OK_AND_ASSIGN(buf2, stream1->Read(2));
+ ASSERT_TRUE(SliceBuffer(buf, 6, 2)->Equals(*buf2));
+
+ // Read to end of each stream
+ ASSERT_OK_AND_EQ(2, stream1->Read(4, buf3));
+ ASSERT_EQ(0, std::memcmp(buf3, "a2", 2));
+ ASSERT_OK_AND_EQ(10, stream1->Tell());
+
+ ASSERT_OK_AND_EQ(0, stream1->Read(1, buf3));
+ ASSERT_OK_AND_EQ(10, stream1->Tell());
+
+ // stream2 had its extent limited
+ ASSERT_OK_AND_ASSIGN(buf2, stream2->Read(20));
+ ASSERT_TRUE(SliceBuffer(buf, 13, 12)->Equals(*buf2));
+
+ ASSERT_OK_AND_ASSIGN(buf2, stream2->Read(1));
+ ASSERT_EQ(0, buf2->size());
+ ASSERT_OK_AND_EQ(16, stream2->Tell());
+
+ ASSERT_OK(stream1->Close());
+
+ // idempotent
+ ASSERT_OK(stream1->Close());
+ ASSERT_TRUE(stream1->closed());
+
+ // Check whether closed
+ ASSERT_RAISES(IOError, stream1->Tell());
+ ASSERT_RAISES(IOError, stream1->Read(1));
+ ASSERT_RAISES(IOError, stream1->Read(1, buf3));
+}
+
+TEST(TestMemcopy, ParallelMemcopy) {
+#if defined(ARROW_VALGRIND)
+ // Compensate for Valgrind's slowness
+ constexpr int64_t THRESHOLD = 32 * 1024;
+#else
+ constexpr int64_t THRESHOLD = 1024 * 1024;
+#endif
+
+ for (int i = 0; i < 5; ++i) {
+ // randomize size so the memcopy alignment is tested
+ int64_t total_size = 3 * THRESHOLD + std::rand() % 100;
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> buffer1, AllocateBuffer(total_size));
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> buffer2, AllocateBuffer(total_size));
+
+ random_bytes(total_size, 0, buffer2->mutable_data());
+
+ io::FixedSizeBufferWriter writer(buffer1);
+ writer.set_memcopy_threads(4);
+ writer.set_memcopy_threshold(THRESHOLD);
+ ASSERT_OK(writer.Write(buffer2->data(), buffer2->size()));
+
+ ASSERT_EQ(0, memcmp(buffer1->data(), buffer2->data(), buffer1->size()));
+ }
+}
+
+// -----------------------------------------------------------------------
+// Test slow streams
+
+template <typename SlowStreamType>
+void TestSlowInputStream() {
+ using clock = std::chrono::high_resolution_clock;
+
+ auto stream = std::make_shared<BufferReader>(util::string_view("abcdefghijkl"));
+ const double latency = 0.6;
+ auto slow = std::make_shared<SlowStreamType>(stream, latency);
+
+ ASSERT_FALSE(slow->closed());
+ auto t1 = clock::now();
+ ASSERT_OK_AND_ASSIGN(auto buf, slow->Read(6));
+ auto t2 = clock::now();
+ AssertBufferEqual(*buf, "abcdef");
+ auto dt = std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count();
+#ifdef ARROW_WITH_TIMING_TESTS
+ ASSERT_LT(dt, latency * 3); // likely
+ ASSERT_GT(dt, latency / 3); // likely
+#else
+ ARROW_UNUSED(dt);
+#endif
+
+ ASSERT_OK_AND_ASSIGN(util::string_view view, slow->Peek(4));
+ ASSERT_EQ(view, util::string_view("ghij"));
+
+ ASSERT_OK(slow->Close());
+ ASSERT_TRUE(slow->closed());
+ ASSERT_TRUE(stream->closed());
+ ASSERT_OK(slow->Close());
+ ASSERT_TRUE(slow->closed());
+ ASSERT_TRUE(stream->closed());
+}
+
+TEST(TestSlowInputStream, Basics) { TestSlowInputStream<SlowInputStream>(); }
+
+TEST(TestSlowRandomAccessFile, Basics) { TestSlowInputStream<SlowRandomAccessFile>(); }
+
+// -----------------------------------------------------------------------
+// Test transform streams
+
+struct DoublingTransform {
+ // A transform that duplicates every byte
+ Result<std::shared_ptr<Buffer>> operator()(const std::shared_ptr<Buffer>& buf) {
+ ARROW_ASSIGN_OR_RAISE(auto dest, AllocateBuffer(buf->size() * 2));
+ const uint8_t* data = buf->data();
+ uint8_t* out_data = dest->mutable_data();
+ for (int64_t i = 0; i < buf->size(); ++i) {
+ out_data[i * 2] = data[i];
+ out_data[i * 2 + 1] = data[i];
+ }
+ return std::shared_ptr<Buffer>(std::move(dest));
+ }
+};
+
+struct SwappingTransform {
+ // A transform that swaps every pair of bytes
+ Result<std::shared_ptr<Buffer>> operator()(const std::shared_ptr<Buffer>& buf) {
+ int64_t dest_size = BitUtil::RoundDown(buf->size() + has_pending_, 2);
+ ARROW_ASSIGN_OR_RAISE(auto dest, AllocateBuffer(dest_size));
+ const uint8_t* data = buf->data();
+ uint8_t* out_data = dest->mutable_data();
+ if (has_pending_ && dest_size > 0) {
+ *out_data++ = *data++;
+ *out_data++ = pending_byte_;
+ dest_size -= 2;
+ }
+ for (int64_t i = 0; i < dest_size; i += 2) {
+ out_data[i] = data[i + 1];
+ out_data[i + 1] = data[i];
+ }
+ has_pending_ = has_pending_ ^ (buf->size() & 1);
+ if (has_pending_) {
+ pending_byte_ = buf->data()[buf->size() - 1];
+ }
+ return std::shared_ptr<Buffer>(std::move(dest));
+ }
+
+ protected:
+ bool has_pending_ = 0;
+ uint8_t pending_byte_ = 0;
+};
+
+struct BaseShrinkingTransform {
+ // A transform that keeps one byte every N bytes
+ explicit BaseShrinkingTransform(int64_t keep_every) : keep_every_(keep_every) {}
+
+ Result<std::shared_ptr<Buffer>> operator()(const std::shared_ptr<Buffer>& buf) {
+ int64_t dest_size = (buf->size() - skip_bytes_ + keep_every_ - 1) / keep_every_;
+ ARROW_ASSIGN_OR_RAISE(auto dest, AllocateBuffer(dest_size));
+ const uint8_t* data = buf->data() + skip_bytes_;
+ uint8_t* out_data = dest->mutable_data();
+ for (int64_t i = 0; i < dest_size; ++i) {
+ out_data[i] = data[i * keep_every_];
+ }
+ if (dest_size > 0) {
+ skip_bytes_ = skip_bytes_ + dest_size * keep_every_ - buf->size();
+ } else {
+ skip_bytes_ = skip_bytes_ - buf->size();
+ }
+ DCHECK_GE(skip_bytes_, 0);
+ DCHECK_LT(skip_bytes_, keep_every_);
+ return std::shared_ptr<Buffer>(std::move(dest));
+ }
+
+ protected:
+ int64_t skip_bytes_ = 0;
+ const int64_t keep_every_;
+};
+
+template <int N>
+struct ShrinkingTransform : public BaseShrinkingTransform {
+ ShrinkingTransform() : BaseShrinkingTransform(N) {}
+};
+
+template <typename T>
+class TestTransformInputStream : public ::testing::Test {
+ public:
+ TransformInputStream::TransformFunc transform() const { return T(); }
+
+ void TestEmptyStream() {
+ auto wrapped = std::make_shared<BufferReader>(util::string_view());
+ auto stream = std::make_shared<TransformInputStream>(wrapped, transform());
+
+ ASSERT_OK_AND_EQ(0, stream->Tell());
+ ASSERT_OK_AND_ASSIGN(auto buf, stream->Read(123));
+ ASSERT_EQ(buf->size(), 0);
+ ASSERT_OK_AND_ASSIGN(buf, stream->Read(0));
+ ASSERT_EQ(buf->size(), 0);
+ ASSERT_OK_AND_EQ(0, stream->Read(5, out_data_));
+ ASSERT_OK_AND_EQ(0, stream->Tell());
+ }
+
+ void TestBasics() {
+ auto src = Buffer::FromString("1234567890abcdefghi");
+ ASSERT_OK_AND_ASSIGN(auto expected, this->transform()(src));
+
+ auto stream = std::make_shared<TransformInputStream>(
+ std::make_shared<BufferReader>(src), this->transform());
+ std::shared_ptr<Buffer> actual;
+ AccumulateReads(stream, 200, &actual);
+ AssertBufferEqual(*actual, *expected);
+ }
+
+ void TestClose() {
+ auto src = Buffer::FromString("1234567890abcdefghi");
+ auto stream = std::make_shared<TransformInputStream>(
+ std::make_shared<BufferReader>(src), this->transform());
+ ASSERT_FALSE(stream->closed());
+ ASSERT_OK(stream->Close());
+ ASSERT_TRUE(stream->closed());
+ ASSERT_RAISES(Invalid, stream->Read(1));
+ ASSERT_RAISES(Invalid, stream->Read(1, out_data_));
+ ASSERT_RAISES(Invalid, stream->Tell());
+ ASSERT_OK(stream->Close());
+ ASSERT_TRUE(stream->closed());
+ }
+
+ void TestChunked() {
+ auto src = Buffer::FromString("1234567890abcdefghi");
+ ASSERT_OK_AND_ASSIGN(auto expected, this->transform()(src));
+
+ auto stream = std::make_shared<TransformInputStream>(
+ std::make_shared<BufferReader>(src), this->transform());
+ std::shared_ptr<Buffer> actual;
+ AccumulateReads(stream, 5, &actual);
+ AssertBufferEqual(*actual, *expected);
+ }
+
+ void TestStressChunked() {
+ ASSERT_OK_AND_ASSIGN(auto unique_src, AllocateBuffer(1000));
+ auto src = std::shared_ptr<Buffer>(std::move(unique_src));
+ random_bytes(src->size(), /*seed=*/42, src->mutable_data());
+
+ ASSERT_OK_AND_ASSIGN(auto expected, this->transform()(src));
+
+ std::default_random_engine gen(42);
+ std::uniform_int_distribution<int> chunk_sizes(0, 20);
+
+ auto stream = std::make_shared<TransformInputStream>(
+ std::make_shared<BufferReader>(src), this->transform());
+ std::shared_ptr<Buffer> actual;
+ AccumulateReads(
+ stream, [&]() -> int64_t { return chunk_sizes(gen); }, &actual);
+ AssertBufferEqual(*actual, *expected);
+ }
+
+ void AccumulateReads(const std::shared_ptr<InputStream>& stream,
+ std::function<int64_t()> gen_chunk_sizes,
+ std::shared_ptr<Buffer>* out) {
+ std::vector<std::shared_ptr<Buffer>> buffers;
+ int64_t total_size = 0;
+ while (true) {
+ const int64_t chunk_size = gen_chunk_sizes();
+ ASSERT_OK_AND_ASSIGN(auto buf, stream->Read(chunk_size));
+ const int64_t buf_size = buf->size();
+ total_size += buf_size;
+ ASSERT_OK_AND_EQ(total_size, stream->Tell());
+ if (chunk_size > 0 && buf_size == 0) {
+ // EOF
+ break;
+ }
+ buffers.push_back(std::move(buf));
+ if (buf_size < chunk_size) {
+ // Short read should imply EOF on next read
+ ASSERT_OK_AND_ASSIGN(auto buf, stream->Read(100));
+ ASSERT_EQ(buf->size(), 0);
+ break;
+ }
+ }
+ ASSERT_OK_AND_ASSIGN(*out, ConcatenateBuffers(buffers));
+ }
+
+ void AccumulateReads(const std::shared_ptr<InputStream>& stream, int64_t chunk_size,
+ std::shared_ptr<Buffer>* out) {
+ return AccumulateReads(
+ stream, [=]() { return chunk_size; }, out);
+ }
+
+ protected:
+ uint8_t* out_data_[10];
+};
+
+using TransformTypes =
+ ::testing::Types<DoublingTransform, SwappingTransform, ShrinkingTransform<2>,
+ ShrinkingTransform<3>, ShrinkingTransform<7>>;
+
+TYPED_TEST_SUITE(TestTransformInputStream, TransformTypes);
+
+TYPED_TEST(TestTransformInputStream, EmptyStream) { this->TestEmptyStream(); }
+
+TYPED_TEST(TestTransformInputStream, Basics) { this->TestBasics(); }
+
+TYPED_TEST(TestTransformInputStream, Close) { this->TestClose(); }
+
+TYPED_TEST(TestTransformInputStream, Chunked) { this->TestChunked(); }
+
+TYPED_TEST(TestTransformInputStream, StressChunked) { this->TestStressChunked(); }
+
+static Result<std::shared_ptr<Buffer>> FailingTransform(
+ const std::shared_ptr<Buffer>& buf) {
+ return Status::UnknownError("Failed transform");
+}
+
+TEST(TestTransformInputStream, FailingTransform) {
+ auto src = Buffer::FromString("1234567890abcdefghi");
+ auto stream = std::make_shared<TransformInputStream>(
+ std::make_shared<BufferReader>(src), FailingTransform);
+ ASSERT_RAISES(UnknownError, stream->Read(5));
+}
+
+// -----------------------------------------------------------------------
+// Test various utilities
+
+TEST(TestInputStreamIterator, Basics) {
+ auto reader = std::make_shared<BufferReader>(Buffer::FromString("data123456"));
+ ASSERT_OK_AND_ASSIGN(auto it, MakeInputStreamIterator(reader, /*block_size=*/3));
+ std::shared_ptr<Buffer> buf;
+ ASSERT_OK_AND_ASSIGN(buf, it.Next());
+ AssertBufferEqual(*buf, "dat");
+ ASSERT_OK_AND_ASSIGN(buf, it.Next());
+ AssertBufferEqual(*buf, "a12");
+ ASSERT_OK_AND_ASSIGN(buf, it.Next());
+ AssertBufferEqual(*buf, "345");
+ ASSERT_OK_AND_ASSIGN(buf, it.Next());
+ AssertBufferEqual(*buf, "6");
+ ASSERT_OK_AND_ASSIGN(buf, it.Next());
+ ASSERT_EQ(buf, nullptr);
+ ASSERT_OK_AND_ASSIGN(buf, it.Next());
+ ASSERT_EQ(buf, nullptr);
+}
+
+TEST(TestInputStreamIterator, Closed) {
+ auto reader = std::make_shared<BufferReader>(Buffer::FromString("data123456"));
+ ASSERT_OK(reader->Close());
+ ASSERT_RAISES(Invalid, MakeInputStreamIterator(reader, 3));
+
+ reader = std::make_shared<BufferReader>(Buffer::FromString("data123456"));
+ ASSERT_OK_AND_ASSIGN(auto it, MakeInputStreamIterator(reader, /*block_size=*/3));
+ ASSERT_OK_AND_ASSIGN(auto buf, it.Next());
+ AssertBufferEqual(*buf, "dat");
+ // Close stream and read from iterator
+ ASSERT_OK(reader->Close());
+ ASSERT_RAISES(Invalid, it.Next().status());
+}
+
+TEST(CoalesceReadRanges, Basics) {
+ auto check = [](std::vector<ReadRange> ranges,
+ std::vector<ReadRange> expected) -> void {
+ const int64_t hole_size_limit = 9;
+ const int64_t range_size_limit = 99;
+ auto coalesced =
+ internal::CoalesceReadRanges(ranges, hole_size_limit, range_size_limit);
+ ASSERT_EQ(coalesced, expected);
+ };
+
+ check({}, {});
+ // Zero sized range that ends up in empty list
+ check({{110, 0}}, {});
+ // Combination on 1 zero sized range and 1 non-zero sized range
+ check({{110, 10}, {120, 0}}, {{110, 10}});
+ // 1 non-zero sized range
+ check({{110, 10}}, {{110, 10}});
+ // No holes + unordered ranges
+ check({{130, 10}, {110, 10}, {120, 10}}, {{110, 30}});
+ // No holes
+ check({{110, 10}, {120, 10}, {130, 10}}, {{110, 30}});
+ // Small holes only
+ check({{110, 11}, {130, 11}, {150, 11}}, {{110, 51}});
+ // Large holes
+ check({{110, 10}, {130, 10}}, {{110, 10}, {130, 10}});
+ check({{110, 11}, {130, 11}, {150, 10}, {170, 11}, {190, 11}}, {{110, 50}, {170, 31}});
+
+ // With zero-sized ranges
+ check({{110, 11}, {130, 0}, {130, 11}, {145, 0}, {150, 11}, {200, 0}}, {{110, 51}});
+
+ // No holes but large ranges
+ check({{110, 100}, {210, 100}}, {{110, 100}, {210, 100}});
+ // Small holes and large range in the middle (*)
+ check({{110, 10}, {120, 11}, {140, 100}, {240, 11}, {260, 11}},
+ {{110, 21}, {140, 100}, {240, 31}});
+ // Mid-size ranges that would turn large after coalescing
+ check({{100, 50}, {150, 50}}, {{100, 50}, {150, 50}});
+ check({{100, 30}, {130, 30}, {160, 30}, {190, 30}, {220, 30}}, {{100, 90}, {190, 60}});
+
+ // Same as (*) but unsorted
+ check({{140, 100}, {120, 11}, {240, 11}, {110, 10}, {260, 11}},
+ {{110, 21}, {140, 100}, {240, 31}});
+
+ // Completely overlapping ranges should be eliminated
+ check({{20, 5}, {20, 5}, {21, 2}}, {{20, 5}});
+}
+
+class CountingBufferReader : public BufferReader {
+ public:
+ using BufferReader::BufferReader;
+ Future<std::shared_ptr<Buffer>> ReadAsync(const IOContext& context, int64_t position,
+ int64_t nbytes) override {
+ read_count_++;
+ return BufferReader::ReadAsync(context, position, nbytes);
+ }
+ int64_t read_count() const { return read_count_; }
+
+ private:
+ int64_t read_count_ = 0;
+};
+
+TEST(RangeReadCache, Basics) {
+ std::string data = "abcdefghijklmnopqrstuvwxyz";
+
+ CacheOptions options = CacheOptions::Defaults();
+ options.hole_size_limit = 2;
+ options.range_size_limit = 10;
+
+ for (auto lazy : std::vector<bool>{false, true}) {
+ SCOPED_TRACE(lazy);
+ options.lazy = lazy;
+ auto file = std::make_shared<CountingBufferReader>(Buffer(data));
+ internal::ReadRangeCache cache(file, {}, options);
+
+ ASSERT_OK(cache.Cache({{1, 2}, {3, 2}, {8, 2}, {20, 2}, {25, 0}}));
+ ASSERT_OK(cache.Cache({{10, 4}, {14, 0}, {15, 4}}));
+
+ ASSERT_OK_AND_ASSIGN(auto buf, cache.Read({20, 2}));
+ AssertBufferEqual(*buf, "uv");
+ ASSERT_OK_AND_ASSIGN(buf, cache.Read({1, 2}));
+ AssertBufferEqual(*buf, "bc");
+ ASSERT_OK_AND_ASSIGN(buf, cache.Read({3, 2}));
+ AssertBufferEqual(*buf, "de");
+ ASSERT_OK_AND_ASSIGN(buf, cache.Read({8, 2}));
+ AssertBufferEqual(*buf, "ij");
+ ASSERT_OK_AND_ASSIGN(buf, cache.Read({10, 4}));
+ AssertBufferEqual(*buf, "klmn");
+ ASSERT_OK_AND_ASSIGN(buf, cache.Read({15, 4}));
+ AssertBufferEqual(*buf, "pqrs");
+ ASSERT_FINISHES_OK(cache.WaitFor({{15, 1}, {16, 3}, {25, 0}, {1, 2}}));
+ // Zero-sized
+ ASSERT_OK_AND_ASSIGN(buf, cache.Read({14, 0}));
+ AssertBufferEqual(*buf, "");
+ ASSERT_OK_AND_ASSIGN(buf, cache.Read({25, 0}));
+ AssertBufferEqual(*buf, "");
+
+ // Non-cached ranges
+ ASSERT_RAISES(Invalid, cache.Read({20, 3}));
+ ASSERT_RAISES(Invalid, cache.Read({19, 3}));
+ ASSERT_RAISES(Invalid, cache.Read({0, 3}));
+ ASSERT_RAISES(Invalid, cache.Read({25, 2}));
+ ASSERT_FINISHES_AND_RAISES(Invalid, cache.WaitFor({{25, 2}}));
+ ASSERT_FINISHES_AND_RAISES(Invalid, cache.WaitFor({{1, 2}, {25, 2}}));
+
+ ASSERT_FINISHES_OK(cache.Wait());
+ // 8 ranges should lead to less than 8 reads
+ ASSERT_LT(file->read_count(), 8);
+ }
+}
+
+TEST(RangeReadCache, Concurrency) {
+ std::string data = "abcdefghijklmnopqrstuvwxyz";
+
+ auto file = std::make_shared<BufferReader>(Buffer(data));
+ std::vector<ReadRange> ranges{{1, 2}, {3, 2}, {8, 2}, {20, 2},
+ {25, 0}, {10, 4}, {14, 0}, {15, 4}};
+
+ for (auto lazy : std::vector<bool>{false, true}) {
+ SCOPED_TRACE(lazy);
+ CacheOptions options = CacheOptions::Defaults();
+ options.hole_size_limit = 2;
+ options.range_size_limit = 10;
+ options.lazy = lazy;
+
+ {
+ internal::ReadRangeCache cache(file, {}, options);
+ ASSERT_OK(cache.Cache(ranges));
+ std::vector<Future<std::shared_ptr<Buffer>>> futures;
+ for (const auto& range : ranges) {
+ futures.push_back(
+ cache.WaitFor({range}).Then([&cache, range]() { return cache.Read(range); }));
+ }
+ for (auto fut : futures) {
+ ASSERT_FINISHES_OK(fut);
+ }
+ }
+ {
+ internal::ReadRangeCache cache(file, {}, options);
+ ASSERT_OK(cache.Cache(ranges));
+ ASSERT_OK(arrow::internal::ParallelFor(
+ static_cast<int>(ranges.size()),
+ [&](int index) { return cache.Read(ranges[index]).status(); }));
+ }
+ }
+}
+
+TEST(RangeReadCache, Lazy) {
+ std::string data = "abcdefghijklmnopqrstuvwxyz";
+
+ auto file = std::make_shared<CountingBufferReader>(Buffer(data));
+ CacheOptions options = CacheOptions::LazyDefaults();
+ options.hole_size_limit = 2;
+ options.range_size_limit = 10;
+ internal::ReadRangeCache cache(file, {}, options);
+
+ ASSERT_OK(cache.Cache({{1, 2}, {3, 2}, {8, 2}, {20, 2}, {25, 0}}));
+ ASSERT_OK(cache.Cache({{10, 4}, {14, 0}, {15, 4}}));
+
+ // Lazy cache doesn't fetch ranges until requested
+ ASSERT_EQ(0, file->read_count());
+
+ ASSERT_OK_AND_ASSIGN(auto buf, cache.Read({20, 2}));
+ AssertBufferEqual(*buf, "uv");
+ ASSERT_EQ(1, file->read_count());
+
+ ASSERT_OK_AND_ASSIGN(buf, cache.Read({1, 4}));
+ AssertBufferEqual(*buf, "bcde");
+ ASSERT_EQ(2, file->read_count());
+
+ // Requested ranges are still cached
+ ASSERT_OK_AND_ASSIGN(buf, cache.Read({1, 4}));
+ ASSERT_EQ(2, file->read_count());
+
+ // Non-cached ranges
+ ASSERT_RAISES(Invalid, cache.Read({20, 3}));
+ ASSERT_RAISES(Invalid, cache.Read({19, 3}));
+ ASSERT_RAISES(Invalid, cache.Read({0, 3}));
+ ASSERT_RAISES(Invalid, cache.Read({25, 2}));
+
+ // Can asynchronously kick off a read (though BufferReader::ReadAsync is synchronous so
+ // it will increment the read count here)
+ ASSERT_FINISHES_OK(cache.WaitFor({{10, 2}, {15, 4}}));
+ ASSERT_EQ(3, file->read_count());
+ ASSERT_OK_AND_ASSIGN(buf, cache.Read({10, 2}));
+ ASSERT_EQ(3, file->read_count());
+}
+
+TEST(CacheOptions, Basics) {
+ auto check = [](const CacheOptions actual, const double expected_hole_size_limit_MiB,
+ const double expected_range_size_limit_MiB) -> void {
+ const CacheOptions expected = {
+ static_cast<int64_t>(std::round(expected_hole_size_limit_MiB * 1024 * 1024)),
+ static_cast<int64_t>(std::round(expected_range_size_limit_MiB * 1024 * 1024)),
+ /*lazy=*/false};
+ ASSERT_EQ(actual, expected);
+ };
+
+ // Test: normal usage.
+ // TTFB = 5 ms, BW = 500 MiB/s,
+ // we expect hole_size_limit = 2.5 MiB, and range_size_limit = 22.5 MiB
+ check(CacheOptions::MakeFromNetworkMetrics(5, 500), 2.5, 22.5);
+ // Test: custom bandwidth utilization.
+ // TTFB = 5 ms, BW = 500 MiB/s, BW_utilization = 75%,
+ // we expect a change in range_size_limit = 7.5 MiB.
+ check(CacheOptions::MakeFromNetworkMetrics(5, 500, .75), 2.5, 7.5);
+ // Test: custom max_ideal_request_size, range_size_limit gets capped.
+ // TTFB = 5 ms, BW = 500 MiB/s, BW_utilization = 75%, max_ideal_request_size = 5 MiB,
+ // we expect the range_size_limit to be capped at 5 MiB.
+ check(CacheOptions::MakeFromNetworkMetrics(5, 500, .75, 5), 2.5, 5);
+}
+
+TEST(IOThreadPool, Capacity) {
+ // Simple sanity check
+ auto pool = internal::GetIOThreadPool();
+ int capacity = pool->GetCapacity();
+ ASSERT_GT(capacity, 0);
+ ASSERT_EQ(GetIOThreadPoolCapacity(), capacity);
+ ASSERT_OK(SetIOThreadPoolCapacity(capacity + 1));
+ ASSERT_EQ(GetIOThreadPoolCapacity(), capacity + 1);
+}
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/mman.h b/src/arrow/cpp/src/arrow/io/mman.h
new file mode 100644
index 000000000..9b06ac8e7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/mman.h
@@ -0,0 +1,169 @@
+// Copyright https://code.google.com/p/mman-win32/
+//
+// Licensed under the MIT License;
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/MIT
+
+#pragma once
+
+#include "arrow/util/windows_compatibility.h"
+
+#include <errno.h>
+#include <io.h>
+#include <sys/types.h>
+
+#include <cstdint>
+
+#define PROT_NONE 0
+#define PROT_READ 1
+#define PROT_WRITE 2
+#define PROT_EXEC 4
+
+#define MAP_FILE 0
+#define MAP_SHARED 1
+#define MAP_PRIVATE 2
+#define MAP_TYPE 0xf
+#define MAP_FIXED 0x10
+#define MAP_ANONYMOUS 0x20
+#define MAP_ANON MAP_ANONYMOUS
+
+#define MAP_FAILED ((void*)-1)
+
+/* Flags for msync. */
+#define MS_ASYNC 1
+#define MS_SYNC 2
+#define MS_INVALIDATE 4
+
+#ifndef FILE_MAP_EXECUTE
+#define FILE_MAP_EXECUTE 0x0020
+#endif
+
+static inline int __map_mman_error(const DWORD err, const int deferr) {
+ if (err == 0) return 0;
+ // TODO: implement
+ return err;
+}
+
+static inline DWORD __map_mmap_prot_page(const int prot) {
+ DWORD protect = 0;
+
+ if (prot == PROT_NONE) return protect;
+
+ if ((prot & PROT_EXEC) != 0) {
+ protect = ((prot & PROT_WRITE) != 0) ? PAGE_EXECUTE_READWRITE : PAGE_EXECUTE_READ;
+ } else {
+ protect = ((prot & PROT_WRITE) != 0) ? PAGE_READWRITE : PAGE_READONLY;
+ }
+
+ return protect;
+}
+
+static inline DWORD __map_mmap_prot_file(const int prot) {
+ DWORD desiredAccess = 0;
+
+ if (prot == PROT_NONE) return desiredAccess;
+
+ if ((prot & PROT_READ) != 0) desiredAccess |= FILE_MAP_READ;
+ if ((prot & PROT_WRITE) != 0) desiredAccess |= FILE_MAP_WRITE;
+ if ((prot & PROT_EXEC) != 0) desiredAccess |= FILE_MAP_EXECUTE;
+
+ return desiredAccess;
+}
+
+static inline void* mmap(void* addr, size_t len, int prot, int flags, int fildes,
+ off_t off) {
+ HANDLE fm, h;
+
+ void* map = MAP_FAILED;
+ const uint64_t off64 = static_cast<uint64_t>(off);
+ const uint64_t maxSize = off64 + len;
+
+ const DWORD dwFileOffsetLow = static_cast<DWORD>(off64 & 0xFFFFFFFFUL);
+ const DWORD dwFileOffsetHigh = static_cast<DWORD>((off64 >> 32) & 0xFFFFFFFFUL);
+ const DWORD dwMaxSizeLow = static_cast<DWORD>(maxSize & 0xFFFFFFFFUL);
+ const DWORD dwMaxSizeHigh = static_cast<DWORD>((maxSize >> 32) & 0xFFFFFFFFUL);
+
+ const DWORD protect = __map_mmap_prot_page(prot);
+ const DWORD desiredAccess = __map_mmap_prot_file(prot);
+
+ errno = 0;
+
+ if (len == 0
+ /* Unsupported flag combinations */
+ || (flags & MAP_FIXED) != 0
+ /* Unsupported protection combinations */
+ || prot == PROT_EXEC) {
+ errno = EINVAL;
+ return MAP_FAILED;
+ }
+
+ h = ((flags & MAP_ANONYMOUS) == 0) ? (HANDLE)_get_osfhandle(fildes)
+ : INVALID_HANDLE_VALUE;
+
+ if ((flags & MAP_ANONYMOUS) == 0 && h == INVALID_HANDLE_VALUE) {
+ errno = EBADF;
+ return MAP_FAILED;
+ }
+
+ fm = CreateFileMapping(h, NULL, protect, dwMaxSizeHigh, dwMaxSizeLow, NULL);
+
+ if (fm == NULL) {
+ errno = __map_mman_error(GetLastError(), EPERM);
+ return MAP_FAILED;
+ }
+
+ map = MapViewOfFile(fm, desiredAccess, dwFileOffsetHigh, dwFileOffsetLow, len);
+
+ CloseHandle(fm);
+
+ if (map == NULL) {
+ errno = __map_mman_error(GetLastError(), EPERM);
+ return MAP_FAILED;
+ }
+
+ return map;
+}
+
+static inline int munmap(void* addr, size_t len) {
+ if (UnmapViewOfFile(addr)) return 0;
+
+ errno = __map_mman_error(GetLastError(), EPERM);
+
+ return -1;
+}
+
+static inline int mprotect(void* addr, size_t len, int prot) {
+ DWORD newProtect = __map_mmap_prot_page(prot);
+ DWORD oldProtect = 0;
+
+ if (VirtualProtect(addr, len, newProtect, &oldProtect)) return 0;
+
+ errno = __map_mman_error(GetLastError(), EPERM);
+
+ return -1;
+}
+
+static inline int msync(void* addr, size_t len, int flags) {
+ if (FlushViewOfFile(addr, len)) return 0;
+
+ errno = __map_mman_error(GetLastError(), EPERM);
+
+ return -1;
+}
+
+static inline int mlock(const void* addr, size_t len) {
+ if (VirtualLock((LPVOID)addr, len)) return 0;
+
+ errno = __map_mman_error(GetLastError(), EPERM);
+
+ return -1;
+}
+
+static inline int munlock(const void* addr, size_t len) {
+ if (VirtualUnlock((LPVOID)addr, len)) return 0;
+
+ errno = __map_mman_error(GetLastError(), EPERM);
+
+ return -1;
+}
diff --git a/src/arrow/cpp/src/arrow/io/slow.cc b/src/arrow/cpp/src/arrow/io/slow.cc
new file mode 100644
index 000000000..1042691fa
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/slow.cc
@@ -0,0 +1,148 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/io/slow.h"
+
+#include <algorithm>
+#include <cstring>
+#include <mutex>
+#include <random>
+#include <thread>
+#include <utility>
+
+#include "arrow/buffer.h"
+#include "arrow/io/util_internal.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace io {
+
+// Multiply the average by this ratio to get the intended standard deviation
+static constexpr double kStandardDeviationRatio = 0.1;
+
+class LatencyGeneratorImpl : public LatencyGenerator {
+ public:
+ ~LatencyGeneratorImpl() override = default;
+
+ LatencyGeneratorImpl(double average_latency, int32_t seed)
+ : gen_(static_cast<decltype(gen_)::result_type>(seed)),
+ latency_dist_(average_latency, average_latency * kStandardDeviationRatio) {}
+
+ double NextLatency() override {
+ // std::random distributions are unlikely to be thread-safe, and
+ // a RandomAccessFile may be called from multiple threads
+ std::lock_guard<std::mutex> lock(mutex_);
+ return std::max<double>(0.0, latency_dist_(gen_));
+ }
+
+ private:
+ std::default_random_engine gen_;
+ std::normal_distribution<double> latency_dist_;
+ std::mutex mutex_;
+};
+
+LatencyGenerator::~LatencyGenerator() {}
+
+void LatencyGenerator::Sleep() {
+ std::this_thread::sleep_for(std::chrono::duration<double>(NextLatency()));
+}
+
+std::shared_ptr<LatencyGenerator> LatencyGenerator::Make(double average_latency) {
+ return std::make_shared<LatencyGeneratorImpl>(
+ average_latency, static_cast<int32_t>(::arrow::internal::GetRandomSeed()));
+}
+
+std::shared_ptr<LatencyGenerator> LatencyGenerator::Make(double average_latency,
+ int32_t seed) {
+ return std::make_shared<LatencyGeneratorImpl>(average_latency, seed);
+}
+
+//////////////////////////////////////////////////////////////////////////
+// SlowInputStream implementation
+
+SlowInputStream::~SlowInputStream() { internal::CloseFromDestructor(this); }
+
+Status SlowInputStream::Close() { return stream_->Close(); }
+
+Status SlowInputStream::Abort() { return stream_->Abort(); }
+
+bool SlowInputStream::closed() const { return stream_->closed(); }
+
+Result<int64_t> SlowInputStream::Tell() const { return stream_->Tell(); }
+
+Result<int64_t> SlowInputStream::Read(int64_t nbytes, void* out) {
+ latencies_->Sleep();
+ return stream_->Read(nbytes, out);
+}
+
+Result<std::shared_ptr<Buffer>> SlowInputStream::Read(int64_t nbytes) {
+ latencies_->Sleep();
+ return stream_->Read(nbytes);
+}
+
+Result<util::string_view> SlowInputStream::Peek(int64_t nbytes) {
+ return stream_->Peek(nbytes);
+}
+
+//////////////////////////////////////////////////////////////////////////
+// SlowRandomAccessFile implementation
+
+SlowRandomAccessFile::~SlowRandomAccessFile() { internal::CloseFromDestructor(this); }
+
+Status SlowRandomAccessFile::Close() { return stream_->Close(); }
+
+Status SlowRandomAccessFile::Abort() { return stream_->Abort(); }
+
+bool SlowRandomAccessFile::closed() const { return stream_->closed(); }
+
+Result<int64_t> SlowRandomAccessFile::GetSize() { return stream_->GetSize(); }
+
+Status SlowRandomAccessFile::Seek(int64_t position) { return stream_->Seek(position); }
+
+Result<int64_t> SlowRandomAccessFile::Tell() const { return stream_->Tell(); }
+
+Result<int64_t> SlowRandomAccessFile::Read(int64_t nbytes, void* out) {
+ latencies_->Sleep();
+ return stream_->Read(nbytes, out);
+}
+
+Result<std::shared_ptr<Buffer>> SlowRandomAccessFile::Read(int64_t nbytes) {
+ latencies_->Sleep();
+ return stream_->Read(nbytes);
+}
+
+Result<int64_t> SlowRandomAccessFile::ReadAt(int64_t position, int64_t nbytes,
+ void* out) {
+ latencies_->Sleep();
+ return stream_->ReadAt(position, nbytes, out);
+}
+
+Result<std::shared_ptr<Buffer>> SlowRandomAccessFile::ReadAt(int64_t position,
+ int64_t nbytes) {
+ latencies_->Sleep();
+ return stream_->ReadAt(position, nbytes);
+}
+
+Result<util::string_view> SlowRandomAccessFile::Peek(int64_t nbytes) {
+ return stream_->Peek(nbytes);
+}
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/slow.h b/src/arrow/cpp/src/arrow/io/slow.h
new file mode 100644
index 000000000..b0c02a85a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/slow.h
@@ -0,0 +1,118 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Slow stream implementations, mainly for testing and benchmarking
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "arrow/io/interfaces.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Buffer;
+class Status;
+
+namespace io {
+
+class ARROW_EXPORT LatencyGenerator {
+ public:
+ virtual ~LatencyGenerator();
+
+ void Sleep();
+
+ virtual double NextLatency() = 0;
+
+ static std::shared_ptr<LatencyGenerator> Make(double average_latency);
+ static std::shared_ptr<LatencyGenerator> Make(double average_latency, int32_t seed);
+};
+
+// XXX use ConcurrencyWrapper? It could increase chances of finding a race.
+
+template <class StreamType>
+class ARROW_EXPORT SlowInputStreamBase : public StreamType {
+ public:
+ SlowInputStreamBase(std::shared_ptr<StreamType> stream,
+ std::shared_ptr<LatencyGenerator> latencies)
+ : stream_(std::move(stream)), latencies_(std::move(latencies)) {}
+
+ SlowInputStreamBase(std::shared_ptr<StreamType> stream, double average_latency)
+ : stream_(std::move(stream)), latencies_(LatencyGenerator::Make(average_latency)) {}
+
+ SlowInputStreamBase(std::shared_ptr<StreamType> stream, double average_latency,
+ int32_t seed)
+ : stream_(std::move(stream)),
+ latencies_(LatencyGenerator::Make(average_latency, seed)) {}
+
+ protected:
+ std::shared_ptr<StreamType> stream_;
+ std::shared_ptr<LatencyGenerator> latencies_;
+};
+
+/// \brief An InputStream wrapper that makes reads slower.
+///
+/// Read() calls are made slower by an average latency (in seconds).
+/// Actual latencies form a normal distribution closely centered
+/// on the average latency.
+/// Other calls are forwarded directly.
+class ARROW_EXPORT SlowInputStream : public SlowInputStreamBase<InputStream> {
+ public:
+ ~SlowInputStream() override;
+
+ using SlowInputStreamBase<InputStream>::SlowInputStreamBase;
+
+ Status Close() override;
+ Status Abort() override;
+ bool closed() const override;
+
+ Result<int64_t> Read(int64_t nbytes, void* out) override;
+ Result<std::shared_ptr<Buffer>> Read(int64_t nbytes) override;
+ Result<util::string_view> Peek(int64_t nbytes) override;
+
+ Result<int64_t> Tell() const override;
+};
+
+/// \brief A RandomAccessFile wrapper that makes reads slower.
+///
+/// Similar to SlowInputStream, but allows random access and seeking.
+class ARROW_EXPORT SlowRandomAccessFile : public SlowInputStreamBase<RandomAccessFile> {
+ public:
+ ~SlowRandomAccessFile() override;
+
+ using SlowInputStreamBase<RandomAccessFile>::SlowInputStreamBase;
+
+ Status Close() override;
+ Status Abort() override;
+ bool closed() const override;
+
+ Result<int64_t> Read(int64_t nbytes, void* out) override;
+ Result<std::shared_ptr<Buffer>> Read(int64_t nbytes) override;
+ Result<int64_t> ReadAt(int64_t position, int64_t nbytes, void* out) override;
+ Result<std::shared_ptr<Buffer>> ReadAt(int64_t position, int64_t nbytes) override;
+ Result<util::string_view> Peek(int64_t nbytes) override;
+
+ Result<int64_t> GetSize() override;
+ Status Seek(int64_t position) override;
+ Result<int64_t> Tell() const override;
+};
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/stdio.cc b/src/arrow/cpp/src/arrow/io/stdio.cc
new file mode 100644
index 000000000..7ef4843a2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/stdio.cc
@@ -0,0 +1,95 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/io/stdio.h"
+
+#include <iostream>
+
+#include "arrow/buffer.h"
+#include "arrow/result.h"
+
+namespace arrow {
+namespace io {
+
+//
+// StdoutStream implementation
+//
+
+StdoutStream::StdoutStream() : pos_(0) { set_mode(FileMode::WRITE); }
+
+Status StdoutStream::Close() { return Status::OK(); }
+
+bool StdoutStream::closed() const { return false; }
+
+Result<int64_t> StdoutStream::Tell() const { return pos_; }
+
+Status StdoutStream::Write(const void* data, int64_t nbytes) {
+ pos_ += nbytes;
+ std::cout.write(reinterpret_cast<const char*>(data), nbytes);
+ return Status::OK();
+}
+
+//
+// StderrStream implementation
+//
+
+StderrStream::StderrStream() : pos_(0) { set_mode(FileMode::WRITE); }
+
+Status StderrStream::Close() { return Status::OK(); }
+
+bool StderrStream::closed() const { return false; }
+
+Result<int64_t> StderrStream::Tell() const { return pos_; }
+
+Status StderrStream::Write(const void* data, int64_t nbytes) {
+ pos_ += nbytes;
+ std::cerr.write(reinterpret_cast<const char*>(data), nbytes);
+ return Status::OK();
+}
+
+//
+// StdinStream implementation
+//
+
+StdinStream::StdinStream() : pos_(0) { set_mode(FileMode::READ); }
+
+Status StdinStream::Close() { return Status::OK(); }
+
+bool StdinStream::closed() const { return false; }
+
+Result<int64_t> StdinStream::Tell() const { return pos_; }
+
+Result<int64_t> StdinStream::Read(int64_t nbytes, void* out) {
+ std::cin.read(reinterpret_cast<char*>(out), nbytes);
+ if (std::cin) {
+ pos_ += nbytes;
+ return nbytes;
+ } else {
+ return 0;
+ }
+}
+
+Result<std::shared_ptr<Buffer>> StdinStream::Read(int64_t nbytes) {
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateResizableBuffer(nbytes));
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, Read(nbytes, buffer->mutable_data()));
+ ARROW_RETURN_NOT_OK(buffer->Resize(bytes_read, false));
+ buffer->ZeroPadding();
+ return std::move(buffer);
+}
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/stdio.h b/src/arrow/cpp/src/arrow/io/stdio.h
new file mode 100644
index 000000000..9484ac771
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/stdio.h
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+
+#include "arrow/io/interfaces.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace io {
+
+// Output stream that just writes to stdout.
+class ARROW_EXPORT StdoutStream : public OutputStream {
+ public:
+ StdoutStream();
+ ~StdoutStream() override {}
+
+ Status Close() override;
+ bool closed() const override;
+
+ Result<int64_t> Tell() const override;
+
+ Status Write(const void* data, int64_t nbytes) override;
+
+ private:
+ int64_t pos_;
+};
+
+// Output stream that just writes to stderr.
+class ARROW_EXPORT StderrStream : public OutputStream {
+ public:
+ StderrStream();
+ ~StderrStream() override {}
+
+ Status Close() override;
+ bool closed() const override;
+
+ Result<int64_t> Tell() const override;
+
+ Status Write(const void* data, int64_t nbytes) override;
+
+ private:
+ int64_t pos_;
+};
+
+// Input stream that just reads from stdin.
+class ARROW_EXPORT StdinStream : public InputStream {
+ public:
+ StdinStream();
+ ~StdinStream() override {}
+
+ Status Close() override;
+ bool closed() const override;
+
+ Result<int64_t> Tell() const override;
+
+ Result<int64_t> Read(int64_t nbytes, void* out) override;
+
+ Result<std::shared_ptr<Buffer>> Read(int64_t nbytes) override;
+
+ private:
+ int64_t pos_;
+};
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/test_common.cc b/src/arrow/cpp/src/arrow/io/test_common.cc
new file mode 100644
index 000000000..0a9686a28
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/test_common.cc
@@ -0,0 +1,121 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/io/test_common.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <fstream> // IWYU pragma: keep
+
+#ifdef _WIN32
+#include <crtdbg.h>
+#include <io.h>
+#else
+#include <fcntl.h>
+#endif
+
+#include "arrow/buffer.h"
+#include "arrow/io/file.h"
+#include "arrow/io/memory.h"
+#include "arrow/memory_pool.h"
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+namespace io {
+
+void AssertFileContents(const std::string& path, const std::string& contents) {
+ ASSERT_OK_AND_ASSIGN(auto rf, ReadableFile::Open(path));
+ ASSERT_OK_AND_ASSIGN(int64_t size, rf->GetSize());
+ ASSERT_EQ(size, contents.size());
+
+ ASSERT_OK_AND_ASSIGN(auto actual_data, rf->Read(size));
+ ASSERT_TRUE(actual_data->Equals(Buffer(contents)));
+}
+
+bool FileExists(const std::string& path) { return std::ifstream(path.c_str()).good(); }
+
+#if defined(_WIN32)
+static void InvalidParamHandler(const wchar_t* expr, const wchar_t* func,
+ const wchar_t* source_file, unsigned int source_line,
+ uintptr_t reserved) {
+ wprintf(L"Invalid parameter in function '%s'. Source: '%s' line %d expression '%s'\n",
+ func, source_file, source_line, expr);
+}
+#endif
+
+bool FileIsClosed(int fd) {
+#if defined(_WIN32)
+ // Disables default behavior on wrong params which causes the application to crash
+ // https://msdn.microsoft.com/en-us/library/ksazx244.aspx
+ _set_invalid_parameter_handler(InvalidParamHandler);
+
+ // Disables possible assertion alert box on invalid input arguments
+ _CrtSetReportMode(_CRT_ASSERT, 0);
+
+ int new_fd = _dup(fd);
+ if (new_fd == -1) {
+ return errno == EBADF;
+ }
+ _close(new_fd);
+ return false;
+#else
+ if (-1 != fcntl(fd, F_GETFD)) {
+ return false;
+ }
+ return errno == EBADF;
+#endif
+}
+
+Status ZeroMemoryMap(MemoryMappedFile* file) {
+ constexpr int64_t kBufferSize = 512;
+ static constexpr uint8_t kZeroBytes[kBufferSize] = {0};
+
+ RETURN_NOT_OK(file->Seek(0));
+ int64_t position = 0;
+ ARROW_ASSIGN_OR_RAISE(int64_t file_size, file->GetSize());
+
+ int64_t chunksize;
+ while (position < file_size) {
+ chunksize = std::min(kBufferSize, file_size - position);
+ RETURN_NOT_OK(file->Write(kZeroBytes, chunksize));
+ position += chunksize;
+ }
+ return Status::OK();
+}
+
+void MemoryMapFixture::TearDown() {
+ for (auto path : tmp_files_) {
+ ARROW_UNUSED(std::remove(path.c_str()));
+ }
+}
+
+void MemoryMapFixture::CreateFile(const std::string& path, int64_t size) {
+ ASSERT_OK(MemoryMappedFile::Create(path, size));
+ tmp_files_.push_back(path);
+}
+
+Result<std::shared_ptr<MemoryMappedFile>> MemoryMapFixture::InitMemoryMap(
+ int64_t size, const std::string& path) {
+ ARROW_ASSIGN_OR_RAISE(auto mmap, MemoryMappedFile::Create(path, size));
+ tmp_files_.push_back(path);
+ return mmap;
+}
+
+void MemoryMapFixture::AppendFile(const std::string& path) { tmp_files_.push_back(path); }
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/test_common.h b/src/arrow/cpp/src/arrow/io/test_common.h
new file mode 100644
index 000000000..ba263a3ad
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/test_common.h
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/testing/visibility.h"
+#include "arrow/type_fwd.h"
+
+namespace arrow {
+namespace io {
+
+class MemoryMappedFile;
+
+ARROW_TESTING_EXPORT
+void AssertFileContents(const std::string& path, const std::string& contents);
+
+ARROW_TESTING_EXPORT bool FileExists(const std::string& path);
+
+ARROW_TESTING_EXPORT bool FileIsClosed(int fd);
+
+ARROW_TESTING_EXPORT
+Status ZeroMemoryMap(MemoryMappedFile* file);
+
+class ARROW_TESTING_EXPORT MemoryMapFixture {
+ public:
+ void TearDown();
+
+ void CreateFile(const std::string& path, int64_t size);
+
+ Result<std::shared_ptr<MemoryMappedFile>> InitMemoryMap(int64_t size,
+ const std::string& path);
+
+ void AppendFile(const std::string& path);
+
+ private:
+ std::vector<std::string> tmp_files_;
+};
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/transform.cc b/src/arrow/cpp/src/arrow/io/transform.cc
new file mode 100644
index 000000000..3fdf5a7a9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/transform.cc
@@ -0,0 +1,162 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/io/transform.h"
+
+#include <algorithm>
+#include <cstring>
+#include <mutex>
+#include <random>
+#include <thread>
+#include <utility>
+
+#include "arrow/buffer.h"
+#include "arrow/io/util_internal.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace io {
+
+struct TransformInputStream::Impl {
+ std::shared_ptr<InputStream> wrapped_;
+ TransformInputStream::TransformFunc transform_;
+ std::shared_ptr<Buffer> pending_;
+ int64_t pos_ = 0;
+ bool closed_ = false;
+
+ Impl(std::shared_ptr<InputStream> wrapped,
+ TransformInputStream::TransformFunc transform)
+ : wrapped_(std::move(wrapped)), transform_(std::move(transform)) {}
+
+ void Close() {
+ closed_ = true;
+ pending_.reset();
+ }
+
+ Status CheckClosed() const {
+ if (closed_) {
+ return Status::Invalid("Operation on closed file");
+ }
+ return Status::OK();
+ }
+};
+
+TransformInputStream::TransformInputStream(std::shared_ptr<InputStream> wrapped,
+ TransformInputStream::TransformFunc transform)
+ : impl_(new Impl{std::move(wrapped), std::move(transform)}) {}
+
+TransformInputStream::~TransformInputStream() {}
+
+Status TransformInputStream::Close() {
+ impl_->Close();
+ return impl_->wrapped_->Close();
+}
+
+Status TransformInputStream::Abort() { return impl_->wrapped_->Abort(); }
+
+bool TransformInputStream::closed() const { return impl_->closed_; }
+
+Result<std::shared_ptr<Buffer>> TransformInputStream::Read(int64_t nbytes) {
+ RETURN_NOT_OK(impl_->CheckClosed());
+
+ ARROW_ASSIGN_OR_RAISE(auto buf, AllocateResizableBuffer(nbytes));
+ ARROW_ASSIGN_OR_RAISE(auto bytes_read, this->Read(nbytes, buf->mutable_data()));
+ if (bytes_read < nbytes) {
+ RETURN_NOT_OK(buf->Resize(bytes_read, /*shrink_to_fit=*/true));
+ }
+ return std::shared_ptr<Buffer>(std::move(buf));
+}
+
+Result<int64_t> TransformInputStream::Read(int64_t nbytes, void* out) {
+ RETURN_NOT_OK(impl_->CheckClosed());
+
+ if (nbytes == 0) {
+ return 0;
+ }
+
+ int64_t avail_size = 0;
+ std::vector<std::shared_ptr<Buffer>> avail;
+ if (impl_->pending_) {
+ avail.push_back(impl_->pending_);
+ avail_size += impl_->pending_->size();
+ }
+ // Accumulate enough transformed data to satisfy read
+ while (avail_size < nbytes) {
+ ARROW_ASSIGN_OR_RAISE(auto buf, impl_->wrapped_->Read(nbytes));
+ const bool have_eof = (buf->size() == 0);
+ // Even if EOF is met, let the transform function run a last time
+ // (for example to flush internal buffers)
+ ARROW_ASSIGN_OR_RAISE(buf, impl_->transform_(std::move(buf)));
+ avail_size += buf->size();
+ avail.push_back(std::move(buf));
+ if (have_eof) {
+ break;
+ }
+ }
+ DCHECK(!avail.empty());
+
+ // Coalesce buffer data
+ uint8_t* out_data = reinterpret_cast<uint8_t*>(out);
+ int64_t copied_bytes = 0;
+ for (size_t i = 0; i < avail.size() - 1; ++i) {
+ // All buffers except the last fit fully into `nbytes`
+ const auto buf = std::move(avail[i]);
+ DCHECK_LE(buf->size(), nbytes);
+ memcpy(out_data, buf->data(), static_cast<size_t>(buf->size()));
+ out_data += buf->size();
+ nbytes -= buf->size();
+ copied_bytes += buf->size();
+ }
+ {
+ // Last buffer: splice into `out` and `pending_`
+ const auto buf = std::move(avail.back());
+ const int64_t to_copy = std::min(buf->size(), nbytes);
+ memcpy(out_data, buf->data(), static_cast<size_t>(to_copy));
+ copied_bytes += to_copy;
+ if (buf->size() > to_copy) {
+ impl_->pending_ = SliceBuffer(buf, to_copy);
+ } else {
+ impl_->pending_.reset();
+ }
+ }
+ impl_->pos_ += copied_bytes;
+ return copied_bytes;
+}
+
+Result<int64_t> TransformInputStream::Tell() const {
+ RETURN_NOT_OK(impl_->CheckClosed());
+
+ return impl_->pos_;
+}
+
+Result<std::shared_ptr<const KeyValueMetadata>> TransformInputStream::ReadMetadata() {
+ RETURN_NOT_OK(impl_->CheckClosed());
+
+ return impl_->wrapped_->ReadMetadata();
+}
+
+Future<std::shared_ptr<const KeyValueMetadata>> TransformInputStream::ReadMetadataAsync(
+ const IOContext& io_context) {
+ RETURN_NOT_OK(impl_->CheckClosed());
+
+ return impl_->wrapped_->ReadMetadataAsync(io_context);
+}
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/transform.h b/src/arrow/cpp/src/arrow/io/transform.h
new file mode 100644
index 000000000..c117f2759
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/transform.h
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Slow stream implementations, mainly for testing and benchmarking
+
+#pragma once
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <utility>
+
+#include "arrow/io/interfaces.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace io {
+
+class ARROW_EXPORT TransformInputStream : public InputStream {
+ public:
+ using TransformFunc =
+ std::function<Result<std::shared_ptr<Buffer>>(const std::shared_ptr<Buffer>&)>;
+
+ TransformInputStream(std::shared_ptr<InputStream> wrapped, TransformFunc transform);
+ ~TransformInputStream() override;
+
+ Status Close() override;
+ Status Abort() override;
+ bool closed() const override;
+
+ Result<int64_t> Read(int64_t nbytes, void* out) override;
+ Result<std::shared_ptr<Buffer>> Read(int64_t nbytes) override;
+
+ Result<std::shared_ptr<const KeyValueMetadata>> ReadMetadata() override;
+ Future<std::shared_ptr<const KeyValueMetadata>> ReadMetadataAsync(
+ const IOContext& io_context) override;
+
+ Result<int64_t> Tell() const override;
+
+ protected:
+ struct Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/type_fwd.h b/src/arrow/cpp/src/arrow/io/type_fwd.h
new file mode 100644
index 000000000..a2fd33bf3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/type_fwd.h
@@ -0,0 +1,79 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace io {
+
+struct FileMode {
+ enum type { READ, WRITE, READWRITE };
+};
+
+struct IOContext;
+struct CacheOptions;
+
+/// EXPERIMENTAL: convenience global singleton for default IOContext settings
+ARROW_EXPORT
+const IOContext& default_io_context();
+
+/// \brief Get the capacity of the global I/O thread pool
+///
+/// Return the number of worker threads in the thread pool to which
+/// Arrow dispatches various I/O-bound tasks. This is an ideal number,
+/// not necessarily the exact number of threads at a given point in time.
+///
+/// You can change this number using SetIOThreadPoolCapacity().
+ARROW_EXPORT int GetIOThreadPoolCapacity();
+
+/// \brief Set the capacity of the global I/O thread pool
+///
+/// Set the number of worker threads in the thread pool to which
+/// Arrow dispatches various I/O-bound tasks.
+///
+/// The current number is returned by GetIOThreadPoolCapacity().
+ARROW_EXPORT Status SetIOThreadPoolCapacity(int threads);
+
+class FileInterface;
+class Seekable;
+class Writable;
+class Readable;
+class OutputStream;
+class FileOutputStream;
+class InputStream;
+class ReadableFile;
+class RandomAccessFile;
+class MemoryMappedFile;
+class WritableFile;
+class ReadWriteFileInterface;
+
+class LatencyGenerator;
+
+class BufferReader;
+
+class BufferInputStream;
+class BufferOutputStream;
+class CompressedInputStream;
+class CompressedOutputStream;
+class BufferedInputStream;
+class BufferedOutputStream;
+
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/io/util_internal.h b/src/arrow/cpp/src/arrow/io/util_internal.h
new file mode 100644
index 000000000..b1d75d1d0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/io/util_internal.h
@@ -0,0 +1,66 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/io/interfaces.h"
+#include "arrow/util/thread_pool.h"
+#include "arrow/util/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace io {
+namespace internal {
+
+ARROW_EXPORT void CloseFromDestructor(FileInterface* file);
+
+// Validate a (offset, size) region (as given to ReadAt) against
+// the file size. Return the actual read size.
+ARROW_EXPORT Result<int64_t> ValidateReadRange(int64_t offset, int64_t size,
+ int64_t file_size);
+// Validate a (offset, size) region (as given to WriteAt) against
+// the file size. Short writes are not allowed.
+ARROW_EXPORT Status ValidateWriteRange(int64_t offset, int64_t size, int64_t file_size);
+
+// Validate a (offset, size) region (as given to ReadAt or WriteAt), without
+// knowing the file size.
+ARROW_EXPORT Status ValidateRange(int64_t offset, int64_t size);
+
+ARROW_EXPORT
+std::vector<ReadRange> CoalesceReadRanges(std::vector<ReadRange> ranges,
+ int64_t hole_size_limit,
+ int64_t range_size_limit);
+
+ARROW_EXPORT
+::arrow::internal::ThreadPool* GetIOThreadPool();
+
+template <typename... SubmitArgs>
+auto SubmitIO(IOContext io_context, SubmitArgs&&... submit_args)
+ -> decltype(std::declval<::arrow::internal::Executor*>()->Submit(submit_args...)) {
+ ::arrow::internal::TaskHints hints;
+ hints.external_id = io_context.external_id();
+ return io_context.executor()->Submit(hints, io_context.stop_token(),
+ std::forward<SubmitArgs>(submit_args)...);
+}
+
+} // namespace internal
+} // namespace io
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/CMakeLists.txt b/src/arrow/cpp/src/arrow/ipc/CMakeLists.txt
new file mode 100644
index 000000000..495018ec0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/CMakeLists.txt
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#
+# Messaging and interprocess communication
+
+add_custom_target(arrow_ipc)
+
+function(ADD_ARROW_IPC_TEST REL_TEST_NAME)
+ set(options)
+ set(one_value_args PREFIX)
+ set(multi_value_args LABELS)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+
+ if(ARG_PREFIX)
+ set(PREFIX ${ARG_PREFIX})
+ else()
+ set(PREFIX "arrow-ipc")
+ endif()
+
+ add_arrow_test(${REL_TEST_NAME}
+ EXTRA_LINK_LIBS
+ ${ARROW_DATASET_TEST_LINK_LIBS}
+ PREFIX
+ ${PREFIX}
+ ${ARG_UNPARSED_ARGUMENTS})
+endfunction()
+
+add_arrow_test(feather_test)
+add_arrow_ipc_test(json_simple_test)
+add_arrow_ipc_test(read_write_test)
+add_arrow_ipc_test(tensor_test)
+
+# Headers: top level
+arrow_install_all_headers("arrow/ipc")
+
+if(ARROW_BUILD_STATIC)
+ set(ARROW_UTIL_LIB arrow_static)
+else()
+ set(ARROW_UTIL_LIB arrow_shared)
+endif()
+
+if(ARROW_BUILD_UTILITIES OR ARROW_BUILD_INTEGRATION)
+ add_executable(arrow-file-to-stream file_to_stream.cc)
+ target_link_libraries(arrow-file-to-stream ${ARROW_UTIL_LIB})
+ add_executable(arrow-stream-to-file stream_to_file.cc)
+ target_link_libraries(arrow-stream-to-file ${ARROW_UTIL_LIB})
+
+ if(ARROW_BUILD_INTEGRATION)
+ add_dependencies(arrow-integration arrow-file-to-stream)
+ add_dependencies(arrow-integration arrow-stream-to-file)
+ endif()
+endif()
+
+add_arrow_benchmark(read_write_benchmark PREFIX "arrow-ipc")
+
+if(ARROW_FUZZING)
+ add_executable(arrow-ipc-generate-fuzz-corpus generate_fuzz_corpus.cc)
+ target_link_libraries(arrow-ipc-generate-fuzz-corpus ${ARROW_UTIL_LIB}
+ ${ARROW_TEST_LINK_LIBS})
+
+ add_executable(arrow-ipc-generate-tensor-fuzz-corpus generate_tensor_fuzz_corpus.cc)
+ target_link_libraries(arrow-ipc-generate-tensor-fuzz-corpus ${ARROW_UTIL_LIB}
+ ${ARROW_TEST_LINK_LIBS})
+endif()
+
+add_arrow_fuzz_target(file_fuzz PREFIX "arrow-ipc")
+add_arrow_fuzz_target(stream_fuzz PREFIX "arrow-ipc")
+add_arrow_fuzz_target(tensor_stream_fuzz PREFIX "arrow-ipc")
diff --git a/src/arrow/cpp/src/arrow/ipc/api.h b/src/arrow/cpp/src/arrow/ipc/api.h
new file mode 100644
index 000000000..b5690aed8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/api.h
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/ipc/dictionary.h"
+#include "arrow/ipc/feather.h"
+#include "arrow/ipc/json_simple.h"
+#include "arrow/ipc/message.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
diff --git a/src/arrow/cpp/src/arrow/ipc/dictionary.cc b/src/arrow/cpp/src/arrow/ipc/dictionary.cc
new file mode 100644
index 000000000..3ab2c8b38
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/dictionary.cc
@@ -0,0 +1,412 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/ipc/dictionary.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <set>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/array/validate.h"
+#include "arrow/extension_type.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace std {
+template <>
+struct hash<arrow::FieldPath> {
+ size_t operator()(const arrow::FieldPath& path) const { return path.hash(); }
+};
+} // namespace std
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace ipc {
+
+using internal::FieldPosition;
+
+// ----------------------------------------------------------------------
+// DictionaryFieldMapper implementation
+
+struct DictionaryFieldMapper::Impl {
+ using FieldPathMap = std::unordered_map<FieldPath, int64_t>;
+
+ FieldPathMap field_path_to_id;
+
+ void ImportSchema(const Schema& schema) {
+ ImportFields(FieldPosition(), schema.fields());
+ }
+
+ Status AddSchemaFields(const Schema& schema) {
+ if (!field_path_to_id.empty()) {
+ return Status::Invalid("Non-empty DictionaryFieldMapper");
+ }
+ ImportSchema(schema);
+ return Status::OK();
+ }
+
+ Status AddField(int64_t id, std::vector<int> field_path) {
+ const auto pair = field_path_to_id.emplace(FieldPath(std::move(field_path)), id);
+ if (!pair.second) {
+ return Status::KeyError("Field already mapped to id");
+ }
+ return Status::OK();
+ }
+
+ Result<int64_t> GetFieldId(std::vector<int> field_path) const {
+ const auto it = field_path_to_id.find(FieldPath(std::move(field_path)));
+ if (it == field_path_to_id.end()) {
+ return Status::KeyError("Dictionary field not found");
+ }
+ return it->second;
+ }
+
+ int num_fields() const { return static_cast<int>(field_path_to_id.size()); }
+
+ int num_dicts() const {
+ std::set<int64_t> uniqueIds;
+
+ for (auto& kv : field_path_to_id) {
+ uniqueIds.insert(kv.second);
+ }
+
+ return static_cast<int>(uniqueIds.size());
+ }
+
+ private:
+ void ImportFields(const FieldPosition& pos,
+ const std::vector<std::shared_ptr<Field>>& fields) {
+ for (int i = 0; i < static_cast<int>(fields.size()); ++i) {
+ ImportField(pos.child(i), *fields[i]);
+ }
+ }
+
+ void ImportField(const FieldPosition& pos, const Field& field) {
+ const DataType* type = field.type().get();
+ if (type->id() == Type::EXTENSION) {
+ type = checked_cast<const ExtensionType&>(*type).storage_type().get();
+ }
+ if (type->id() == Type::DICTIONARY) {
+ InsertPath(pos);
+ // Import nested dictionaries
+ ImportFields(pos,
+ checked_cast<const DictionaryType&>(*type).value_type()->fields());
+ } else {
+ ImportFields(pos, type->fields());
+ }
+ }
+
+ void InsertPath(const FieldPosition& pos) {
+ const int64_t id = field_path_to_id.size();
+ const auto pair = field_path_to_id.emplace(FieldPath(pos.path()), id);
+ DCHECK(pair.second); // was inserted
+ ARROW_UNUSED(pair);
+ }
+};
+
+DictionaryFieldMapper::DictionaryFieldMapper() : impl_(new Impl) {}
+
+DictionaryFieldMapper::DictionaryFieldMapper(const Schema& schema) : impl_(new Impl) {
+ impl_->ImportSchema(schema);
+}
+
+DictionaryFieldMapper::~DictionaryFieldMapper() {}
+
+Status DictionaryFieldMapper::AddSchemaFields(const Schema& schema) {
+ return impl_->AddSchemaFields(schema);
+}
+
+Status DictionaryFieldMapper::AddField(int64_t id, std::vector<int> field_path) {
+ return impl_->AddField(id, std::move(field_path));
+}
+
+Result<int64_t> DictionaryFieldMapper::GetFieldId(std::vector<int> field_path) const {
+ return impl_->GetFieldId(std::move(field_path));
+}
+
+int DictionaryFieldMapper::num_fields() const { return impl_->num_fields(); }
+
+int DictionaryFieldMapper::num_dicts() const { return impl_->num_dicts(); }
+
+// ----------------------------------------------------------------------
+// DictionaryMemo implementation
+
+namespace {
+
+bool HasUnresolvedNestedDict(const ArrayData& data) {
+ if (data.type->id() == Type::DICTIONARY) {
+ if (data.dictionary == nullptr) {
+ return true;
+ }
+ if (HasUnresolvedNestedDict(*data.dictionary)) {
+ return true;
+ }
+ }
+ for (const auto& child : data.child_data) {
+ if (HasUnresolvedNestedDict(*child)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace
+
+struct DictionaryMemo::Impl {
+ // Map of dictionary id to dictionary array(s) (several in case of deltas)
+ std::unordered_map<int64_t, ArrayDataVector> id_to_dictionary_;
+ std::unordered_map<int64_t, std::shared_ptr<DataType>> id_to_type_;
+ DictionaryFieldMapper mapper_;
+
+ Result<decltype(id_to_dictionary_)::iterator> FindDictionary(int64_t id) {
+ auto it = id_to_dictionary_.find(id);
+ if (it == id_to_dictionary_.end()) {
+ return Status::KeyError("Dictionary with id ", id, " not found");
+ }
+ return it;
+ }
+
+ Result<std::shared_ptr<ArrayData>> ReifyDictionary(int64_t id, MemoryPool* pool) {
+ ARROW_ASSIGN_OR_RAISE(auto it, FindDictionary(id));
+ ArrayDataVector* data_vector = &it->second;
+
+ DCHECK(!data_vector->empty());
+ if (data_vector->size() > 1) {
+ // There are deltas, we need to concatenate them to the first dictionary.
+ ArrayVector to_combine;
+ to_combine.reserve(data_vector->size());
+ // IMPORTANT: At this point, the dictionary data may be untrusted.
+ // We need to validate it, as concatenation can crash on invalid or
+ // corrupted data. Full validation is necessary for certain types
+ // (for example nested dictionaries).
+ for (const auto& data : *data_vector) {
+ if (HasUnresolvedNestedDict(*data)) {
+ return Status::NotImplemented(
+ "Encountered delta dictionary with an unresolved nested dictionary");
+ }
+ RETURN_NOT_OK(::arrow::internal::ValidateArray(*data));
+ RETURN_NOT_OK(::arrow::internal::ValidateArrayFull(*data));
+ to_combine.push_back(MakeArray(data));
+ }
+ ARROW_ASSIGN_OR_RAISE(auto combined_dict, Concatenate(to_combine, pool));
+ *data_vector = {combined_dict->data()};
+ }
+
+ return data_vector->back();
+ }
+};
+
+DictionaryMemo::DictionaryMemo() : impl_(new Impl()) {}
+
+DictionaryMemo::~DictionaryMemo() {}
+
+DictionaryFieldMapper& DictionaryMemo::fields() { return impl_->mapper_; }
+
+const DictionaryFieldMapper& DictionaryMemo::fields() const { return impl_->mapper_; }
+
+Result<std::shared_ptr<DataType>> DictionaryMemo::GetDictionaryType(int64_t id) const {
+ const auto it = impl_->id_to_type_.find(id);
+ if (it == impl_->id_to_type_.end()) {
+ return Status::KeyError("No record of dictionary type with id ", id);
+ }
+ return it->second;
+}
+
+// Returns KeyError if dictionary not found
+Result<std::shared_ptr<ArrayData>> DictionaryMemo::GetDictionary(int64_t id,
+ MemoryPool* pool) const {
+ return impl_->ReifyDictionary(id, pool);
+}
+
+Status DictionaryMemo::AddDictionaryType(int64_t id,
+ const std::shared_ptr<DataType>& type) {
+ // AddDictionaryType expects the dict value type
+ DCHECK_NE(type->id(), Type::DICTIONARY);
+ const auto pair = impl_->id_to_type_.emplace(id, type);
+ if (!pair.second && !pair.first->second->Equals(*type)) {
+ return Status::KeyError("Conflicting dictionary types for id ", id);
+ }
+ return Status::OK();
+}
+
+bool DictionaryMemo::HasDictionary(int64_t id) const {
+ const auto it = impl_->id_to_dictionary_.find(id);
+ return it != impl_->id_to_dictionary_.end();
+}
+
+Status DictionaryMemo::AddDictionary(int64_t id,
+ const std::shared_ptr<ArrayData>& dictionary) {
+ const auto pair = impl_->id_to_dictionary_.emplace(id, ArrayDataVector{dictionary});
+ if (!pair.second) {
+ return Status::KeyError("Dictionary with id ", id, " already exists");
+ }
+ return Status::OK();
+}
+
+Status DictionaryMemo::AddDictionaryDelta(int64_t id,
+ const std::shared_ptr<ArrayData>& dictionary) {
+ ARROW_ASSIGN_OR_RAISE(auto it, impl_->FindDictionary(id));
+ it->second.push_back(dictionary);
+ return Status::OK();
+}
+
+Result<bool> DictionaryMemo::AddOrReplaceDictionary(
+ int64_t id, const std::shared_ptr<ArrayData>& dictionary) {
+ ArrayDataVector value{dictionary};
+
+ auto pair = impl_->id_to_dictionary_.emplace(id, value);
+ if (pair.second) {
+ // Inserted
+ return true;
+ } else {
+ // Update existing value
+ pair.first->second = std::move(value);
+ return false;
+ }
+}
+
+// ----------------------------------------------------------------------
+// CollectDictionaries implementation
+
+namespace {
+
+struct DictionaryCollector {
+ const DictionaryFieldMapper& mapper_;
+ DictionaryVector dictionaries_;
+
+ Status WalkChildren(const FieldPosition& position, const DataType& type,
+ const Array& array) {
+ for (int i = 0; i < type.num_fields(); ++i) {
+ auto boxed_child = MakeArray(array.data()->child_data[i]);
+ RETURN_NOT_OK(Visit(position.child(i), type.field(i), boxed_child.get()));
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const FieldPosition& position, const std::shared_ptr<Field>& field,
+ const Array* array) {
+ const DataType* type = array->type().get();
+
+ if (type->id() == Type::EXTENSION) {
+ type = checked_cast<const ExtensionType&>(*type).storage_type().get();
+ array = checked_cast<const ExtensionArray&>(*array).storage().get();
+ }
+ if (type->id() == Type::DICTIONARY) {
+ const auto& dict_array = checked_cast<const DictionaryArray&>(*array);
+ auto dictionary = dict_array.dictionary();
+
+ // Traverse the dictionary to first gather any nested dictionaries
+ // (so that they appear in the output before their parent)
+ const auto& dict_type = checked_cast<const DictionaryType&>(*type);
+ RETURN_NOT_OK(WalkChildren(position, *dict_type.value_type(), *dictionary));
+
+ // Then record the dictionary itself
+ ARROW_ASSIGN_OR_RAISE(int64_t id, mapper_.GetFieldId(position.path()));
+ dictionaries_.emplace_back(id, dictionary);
+ } else {
+ RETURN_NOT_OK(WalkChildren(position, *type, *array));
+ }
+ return Status::OK();
+ }
+
+ Status Collect(const RecordBatch& batch) {
+ FieldPosition position;
+ const Schema& schema = *batch.schema();
+ dictionaries_.reserve(mapper_.num_fields());
+
+ for (int i = 0; i < schema.num_fields(); ++i) {
+ RETURN_NOT_OK(Visit(position.child(i), schema.field(i), batch.column(i).get()));
+ }
+ return Status::OK();
+ }
+};
+
+struct DictionaryResolver {
+ const DictionaryMemo& memo_;
+ MemoryPool* pool_;
+
+ Status VisitChildren(const ArrayDataVector& data_vector, FieldPosition parent_pos) {
+ int i = 0;
+ for (const auto& data : data_vector) {
+ // Some data entries may be missing if reading only a subset of the schema
+ if (data != nullptr) {
+ RETURN_NOT_OK(VisitField(parent_pos.child(i), data.get()));
+ }
+ ++i;
+ }
+ return Status::OK();
+ }
+
+ Status VisitField(FieldPosition field_pos, ArrayData* data) {
+ const DataType* type = data->type.get();
+ if (type->id() == Type::EXTENSION) {
+ type = checked_cast<const ExtensionType&>(*type).storage_type().get();
+ }
+ if (type->id() == Type::DICTIONARY) {
+ ARROW_ASSIGN_OR_RAISE(const int64_t id,
+ memo_.fields().GetFieldId(field_pos.path()));
+ ARROW_ASSIGN_OR_RAISE(data->dictionary, memo_.GetDictionary(id, pool_));
+ // Resolve nested dictionary data
+ RETURN_NOT_OK(VisitField(field_pos, data->dictionary.get()));
+ }
+ // Resolve child data
+ return VisitChildren(data->child_data, field_pos);
+ }
+};
+
+} // namespace
+
+Result<DictionaryVector> CollectDictionaries(const RecordBatch& batch,
+ const DictionaryFieldMapper& mapper) {
+ DictionaryCollector collector{mapper, {}};
+ RETURN_NOT_OK(collector.Collect(batch));
+ return std::move(collector.dictionaries_);
+}
+
+namespace internal {
+
+Status CollectDictionaries(const RecordBatch& batch, DictionaryMemo* memo) {
+ RETURN_NOT_OK(memo->fields().AddSchemaFields(*batch.schema()));
+ ARROW_ASSIGN_OR_RAISE(const auto dictionaries,
+ CollectDictionaries(batch, memo->fields()));
+ for (const auto& pair : dictionaries) {
+ RETURN_NOT_OK(memo->AddDictionary(pair.first, pair.second->data()));
+ }
+ return Status::OK();
+}
+
+} // namespace internal
+
+Status ResolveDictionaries(const ArrayDataVector& columns, const DictionaryMemo& memo,
+ MemoryPool* pool) {
+ DictionaryResolver resolver{memo, pool};
+ return resolver.VisitChildren(columns, FieldPosition());
+}
+
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/dictionary.h b/src/arrow/cpp/src/arrow/ipc/dictionary.h
new file mode 100644
index 000000000..e4287cb19
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/dictionary.h
@@ -0,0 +1,177 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Tools for dictionaries in IPC context
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace ipc {
+
+namespace internal {
+
+class FieldPosition {
+ public:
+ FieldPosition() : parent_(NULLPTR), index_(-1), depth_(0) {}
+
+ FieldPosition child(int index) const { return {this, index}; }
+
+ std::vector<int> path() const {
+ std::vector<int> path(depth_);
+ const FieldPosition* cur = this;
+ for (int i = depth_ - 1; i >= 0; --i) {
+ path[i] = cur->index_;
+ cur = cur->parent_;
+ }
+ return path;
+ }
+
+ protected:
+ FieldPosition(const FieldPosition* parent, int index)
+ : parent_(parent), index_(index), depth_(parent->depth_ + 1) {}
+
+ const FieldPosition* parent_;
+ int index_;
+ int depth_;
+};
+
+} // namespace internal
+
+/// \brief Map fields in a schema to dictionary ids
+///
+/// The mapping is structural, i.e. the field path (as a vector of indices)
+/// is associated to the dictionary id. A dictionary id may be associated
+/// to multiple fields.
+class ARROW_EXPORT DictionaryFieldMapper {
+ public:
+ DictionaryFieldMapper();
+ explicit DictionaryFieldMapper(const Schema& schema);
+ ~DictionaryFieldMapper();
+
+ Status AddSchemaFields(const Schema& schema);
+ Status AddField(int64_t id, std::vector<int> field_path);
+
+ Result<int64_t> GetFieldId(std::vector<int> field_path) const;
+
+ int num_fields() const;
+
+ /// \brief Returns number of unique dictionaries, taking into
+ /// account that different fields can share the same dictionary.
+ int num_dicts() const;
+
+ private:
+ struct Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+using DictionaryVector = std::vector<std::pair<int64_t, std::shared_ptr<Array>>>;
+
+/// \brief Memoization data structure for reading dictionaries from IPC streams
+///
+/// This structure tracks the following associations:
+/// - field position (structural) -> dictionary id
+/// - dictionary id -> value type
+/// - dictionary id -> dictionary (value) data
+///
+/// Together, they allow resolving dictionary data when reading an IPC stream,
+/// using metadata recorded in the schema message and data recorded in the
+/// dictionary batch messages (see ResolveDictionaries).
+///
+/// This structure isn't useful for writing an IPC stream, where only
+/// DictionaryFieldMapper is necessary.
+class ARROW_EXPORT DictionaryMemo {
+ public:
+ DictionaryMemo();
+ ~DictionaryMemo();
+
+ DictionaryFieldMapper& fields();
+ const DictionaryFieldMapper& fields() const;
+
+ /// \brief Return current dictionary corresponding to a particular
+ /// id. Returns KeyError if id not found
+ Result<std::shared_ptr<ArrayData>> GetDictionary(int64_t id, MemoryPool* pool) const;
+
+ /// \brief Return dictionary value type corresponding to a
+ /// particular dictionary id.
+ Result<std::shared_ptr<DataType>> GetDictionaryType(int64_t id) const;
+
+ /// \brief Return true if we have a dictionary for the input id
+ bool HasDictionary(int64_t id) const;
+
+ /// \brief Add a dictionary value type to the memo with a particular id.
+ /// Returns KeyError if a different type is already registered with the same id.
+ Status AddDictionaryType(int64_t id, const std::shared_ptr<DataType>& type);
+
+ /// \brief Add a dictionary to the memo with a particular id. Returns
+ /// KeyError if that dictionary already exists
+ Status AddDictionary(int64_t id, const std::shared_ptr<ArrayData>& dictionary);
+
+ /// \brief Append a dictionary delta to the memo with a particular id. Returns
+ /// KeyError if that dictionary does not exists
+ Status AddDictionaryDelta(int64_t id, const std::shared_ptr<ArrayData>& dictionary);
+
+ /// \brief Add a dictionary to the memo if it does not have one with the id,
+ /// otherwise, replace the dictionary with the new one.
+ ///
+ /// Return true if the dictionary was added, false if replaced.
+ Result<bool> AddOrReplaceDictionary(int64_t id,
+ const std::shared_ptr<ArrayData>& dictionary);
+
+ private:
+ struct Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+// For writing: collect dictionary entries to write to the IPC stream, in order
+// (i.e. inner dictionaries before dependent outer dictionaries).
+ARROW_EXPORT
+Result<DictionaryVector> CollectDictionaries(const RecordBatch& batch,
+ const DictionaryFieldMapper& mapper);
+
+// For reading: resolve all dictionaries in columns, according to the field
+// mapping and dictionary arrays stored in memo.
+// Columns may be sparse, i.e. some entries may be left null
+// (e.g. if an inclusion mask was used).
+ARROW_EXPORT
+Status ResolveDictionaries(const ArrayDataVector& columns, const DictionaryMemo& memo,
+ MemoryPool* pool);
+
+namespace internal {
+
+// Like CollectDictionaries above, but uses the memo's DictionaryFieldMapper
+// and all collected dictionaries are added to the memo using AddDictionary.
+//
+// This is used as a shortcut in some roundtripping tests (to avoid emitting
+// any actual dictionary batches).
+ARROW_EXPORT
+Status CollectDictionaries(const RecordBatch& batch, DictionaryMemo* memo);
+
+} // namespace internal
+
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/feather.cc b/src/arrow/cpp/src/arrow/ipc/feather.cc
new file mode 100644
index 000000000..b1c30eec0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/feather.cc
@@ -0,0 +1,819 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/ipc/feather.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <sstream> // IWYU pragma: keep
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include <flatbuffers/flatbuffers.h>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/chunked_array.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/ipc/metadata_internal.h"
+#include "arrow/ipc/options.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/util.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/visitor_inline.h"
+
+#include "generated/feather_generated.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::make_unique;
+
+class ExtensionType;
+
+namespace ipc {
+namespace feather {
+
+namespace {
+
+using FBB = flatbuffers::FlatBufferBuilder;
+
+constexpr const char* kFeatherV1MagicBytes = "FEA1";
+constexpr const int kFeatherDefaultAlignment = 8;
+const uint8_t kPaddingBytes[kFeatherDefaultAlignment] = {0};
+
+inline int64_t PaddedLength(int64_t nbytes) {
+ static const int64_t alignment = kFeatherDefaultAlignment;
+ return ((nbytes + alignment - 1) / alignment) * alignment;
+}
+
+Status WritePaddedWithOffset(io::OutputStream* stream, const uint8_t* data,
+ int64_t bit_offset, const int64_t length,
+ int64_t* bytes_written) {
+ data = data + bit_offset / 8;
+ uint8_t bit_shift = static_cast<uint8_t>(bit_offset % 8);
+ if (bit_offset == 0) {
+ RETURN_NOT_OK(stream->Write(data, length));
+ } else {
+ constexpr int64_t buffersize = 256;
+ uint8_t buffer[buffersize];
+ const uint8_t lshift = static_cast<uint8_t>(8 - bit_shift);
+ const uint8_t* buffer_end = buffer + buffersize;
+ uint8_t* buffer_it = buffer;
+
+ for (const uint8_t* end = data + length; data != end;) {
+ uint8_t r = static_cast<uint8_t>(*data++ >> bit_shift);
+ uint8_t l = static_cast<uint8_t>(*data << lshift);
+ uint8_t value = l | r;
+ *buffer_it++ = value;
+ if (buffer_it == buffer_end) {
+ RETURN_NOT_OK(stream->Write(buffer, buffersize));
+ buffer_it = buffer;
+ }
+ }
+ if (buffer_it != buffer) {
+ RETURN_NOT_OK(stream->Write(buffer, buffer_it - buffer));
+ }
+ }
+
+ int64_t remainder = PaddedLength(length) - length;
+ if (remainder != 0) {
+ RETURN_NOT_OK(stream->Write(kPaddingBytes, remainder));
+ }
+ *bytes_written = length + remainder;
+ return Status::OK();
+}
+
+Status WritePadded(io::OutputStream* stream, const uint8_t* data, int64_t length,
+ int64_t* bytes_written) {
+ return WritePaddedWithOffset(stream, data, /*bit_offset=*/0, length, bytes_written);
+}
+
+struct ColumnType {
+ enum type { PRIMITIVE, CATEGORY, TIMESTAMP, DATE, TIME };
+};
+
+inline TimeUnit::type FromFlatbufferEnum(fbs::TimeUnit unit) {
+ return static_cast<TimeUnit::type>(static_cast<int>(unit));
+}
+
+/// For compatibility, we need to write any data sometimes just to keep producing
+/// files that can be read with an older reader.
+Status WritePaddedBlank(io::OutputStream* stream, int64_t length,
+ int64_t* bytes_written) {
+ const uint8_t null = 0;
+ for (int64_t i = 0; i < length; i++) {
+ RETURN_NOT_OK(stream->Write(&null, 1));
+ }
+ int64_t remainder = PaddedLength(length) - length;
+ if (remainder != 0) {
+ RETURN_NOT_OK(stream->Write(kPaddingBytes, remainder));
+ }
+ *bytes_written = length + remainder;
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// ReaderV1
+
+class ReaderV1 : public Reader {
+ public:
+ Status Open(const std::shared_ptr<io::RandomAccessFile>& source) {
+ source_ = source;
+
+ ARROW_ASSIGN_OR_RAISE(int64_t size, source->GetSize());
+ int magic_size = static_cast<int>(strlen(kFeatherV1MagicBytes));
+ int footer_size = magic_size + static_cast<int>(sizeof(uint32_t));
+
+ // Now get the footer and verify
+ ARROW_ASSIGN_OR_RAISE(auto buffer, source->ReadAt(size - footer_size, footer_size));
+
+ if (memcmp(buffer->data() + sizeof(uint32_t), kFeatherV1MagicBytes, magic_size)) {
+ return Status::Invalid("Feather file footer incomplete");
+ }
+
+ uint32_t metadata_length = *reinterpret_cast<const uint32_t*>(buffer->data());
+ if (size < magic_size + footer_size + metadata_length) {
+ return Status::Invalid("File is smaller than indicated metadata size");
+ }
+ ARROW_ASSIGN_OR_RAISE(
+ metadata_buffer_,
+ source->ReadAt(size - footer_size - metadata_length, metadata_length));
+
+ metadata_ = fbs::GetCTable(metadata_buffer_->data());
+ return ReadSchema();
+ }
+
+ Status ReadSchema() {
+ std::vector<std::shared_ptr<Field>> fields;
+ for (int i = 0; i < static_cast<int>(metadata_->columns()->size()); ++i) {
+ const fbs::Column* col = metadata_->columns()->Get(i);
+ std::shared_ptr<DataType> type;
+ RETURN_NOT_OK(
+ GetDataType(col->values(), col->metadata_type(), col->metadata(), &type));
+ fields.push_back(::arrow::field(col->name()->str(), type));
+ }
+ schema_ = ::arrow::schema(std::move(fields));
+ return Status::OK();
+ }
+
+ Status GetDataType(const fbs::PrimitiveArray* values, fbs::TypeMetadata metadata_type,
+ const void* metadata, std::shared_ptr<DataType>* out) {
+#define PRIMITIVE_CASE(CAP_TYPE, FACTORY_FUNC) \
+ case fbs::Type::CAP_TYPE: \
+ *out = FACTORY_FUNC(); \
+ break;
+
+ switch (metadata_type) {
+ case fbs::TypeMetadata::CategoryMetadata: {
+ auto meta = static_cast<const fbs::CategoryMetadata*>(metadata);
+
+ std::shared_ptr<DataType> index_type, dict_type;
+ RETURN_NOT_OK(GetDataType(values, fbs::TypeMetadata::NONE, nullptr, &index_type));
+ RETURN_NOT_OK(
+ GetDataType(meta->levels(), fbs::TypeMetadata::NONE, nullptr, &dict_type));
+ *out = dictionary(index_type, dict_type, meta->ordered());
+ break;
+ }
+ case fbs::TypeMetadata::TimestampMetadata: {
+ auto meta = static_cast<const fbs::TimestampMetadata*>(metadata);
+ TimeUnit::type unit = FromFlatbufferEnum(meta->unit());
+ std::string tz;
+ // flatbuffer non-null
+ if (meta->timezone() != 0) {
+ tz = meta->timezone()->str();
+ } else {
+ tz = "";
+ }
+ *out = timestamp(unit, tz);
+ } break;
+ case fbs::TypeMetadata::DateMetadata:
+ *out = date32();
+ break;
+ case fbs::TypeMetadata::TimeMetadata: {
+ auto meta = static_cast<const fbs::TimeMetadata*>(metadata);
+ *out = time32(FromFlatbufferEnum(meta->unit()));
+ } break;
+ default:
+ switch (values->type()) {
+ PRIMITIVE_CASE(BOOL, boolean);
+ PRIMITIVE_CASE(INT8, int8);
+ PRIMITIVE_CASE(INT16, int16);
+ PRIMITIVE_CASE(INT32, int32);
+ PRIMITIVE_CASE(INT64, int64);
+ PRIMITIVE_CASE(UINT8, uint8);
+ PRIMITIVE_CASE(UINT16, uint16);
+ PRIMITIVE_CASE(UINT32, uint32);
+ PRIMITIVE_CASE(UINT64, uint64);
+ PRIMITIVE_CASE(FLOAT, float32);
+ PRIMITIVE_CASE(DOUBLE, float64);
+ PRIMITIVE_CASE(UTF8, utf8);
+ PRIMITIVE_CASE(BINARY, binary);
+ PRIMITIVE_CASE(LARGE_UTF8, large_utf8);
+ PRIMITIVE_CASE(LARGE_BINARY, large_binary);
+ default:
+ return Status::Invalid("Unrecognized type");
+ }
+ break;
+ }
+
+#undef PRIMITIVE_CASE
+
+ return Status::OK();
+ }
+
+ int64_t GetOutputLength(int64_t nbytes) {
+ // XXX: Hack for Feather 0.3.0 for backwards compatibility with old files
+ // Size in-file of written byte buffer
+ if (version() < 2) {
+ // Feather files < 0.3.0
+ return nbytes;
+ } else {
+ return PaddedLength(nbytes);
+ }
+ }
+
+ // Retrieve a primitive array from the data source
+ //
+ // @returns: a Buffer instance, the precise type will depend on the kind of
+ // input data source (which may or may not have memory-map like semantics)
+ Status LoadValues(std::shared_ptr<DataType> type, const fbs::PrimitiveArray* meta,
+ fbs::TypeMetadata metadata_type, const void* metadata,
+ std::shared_ptr<ArrayData>* out) {
+ std::vector<std::shared_ptr<Buffer>> buffers;
+
+ // Buffer data from the source (may or may not perform a copy depending on
+ // input source)
+ ARROW_ASSIGN_OR_RAISE(auto buffer,
+ source_->ReadAt(meta->offset(), meta->total_bytes()));
+
+ int64_t offset = 0;
+
+ if (type->id() == Type::DICTIONARY) {
+ // Load the index type values
+ type = checked_cast<const DictionaryType&>(*type).index_type();
+ }
+
+ // If there are nulls, the null bitmask is first
+ if (meta->null_count() > 0) {
+ int64_t null_bitmap_size = GetOutputLength(BitUtil::BytesForBits(meta->length()));
+ buffers.push_back(SliceBuffer(buffer, offset, null_bitmap_size));
+ offset += null_bitmap_size;
+ } else {
+ buffers.push_back(nullptr);
+ }
+
+ if (is_binary_like(type->id())) {
+ int64_t offsets_size = GetOutputLength((meta->length() + 1) * sizeof(int32_t));
+ buffers.push_back(SliceBuffer(buffer, offset, offsets_size));
+ offset += offsets_size;
+ } else if (is_large_binary_like(type->id())) {
+ int64_t offsets_size = GetOutputLength((meta->length() + 1) * sizeof(int64_t));
+ buffers.push_back(SliceBuffer(buffer, offset, offsets_size));
+ offset += offsets_size;
+ }
+
+ buffers.push_back(SliceBuffer(buffer, offset, buffer->size() - offset));
+
+ *out = ArrayData::Make(type, meta->length(), std::move(buffers), meta->null_count());
+ return Status::OK();
+ }
+
+ int version() const override { return metadata_->version(); }
+ int64_t num_rows() const { return metadata_->num_rows(); }
+
+ std::shared_ptr<Schema> schema() const override { return schema_; }
+
+ Status GetDictionary(int field_index, std::shared_ptr<ArrayData>* out) {
+ const fbs::Column* col_meta = metadata_->columns()->Get(field_index);
+ auto dict_meta = col_meta->metadata_as<fbs::CategoryMetadata>();
+ const auto& dict_type =
+ checked_cast<const DictionaryType&>(*schema_->field(field_index)->type());
+
+ return LoadValues(dict_type.value_type(), dict_meta->levels(),
+ fbs::TypeMetadata::NONE, nullptr, out);
+ }
+
+ Status GetColumn(int field_index, std::shared_ptr<ChunkedArray>* out) {
+ const fbs::Column* col_meta = metadata_->columns()->Get(field_index);
+ std::shared_ptr<ArrayData> data;
+
+ auto type = schema_->field(field_index)->type();
+ RETURN_NOT_OK(LoadValues(type, col_meta->values(), col_meta->metadata_type(),
+ col_meta->metadata(), &data));
+
+ if (type->id() == Type::DICTIONARY) {
+ RETURN_NOT_OK(GetDictionary(field_index, &data->dictionary));
+ data->type = type;
+ }
+ *out = std::make_shared<ChunkedArray>(MakeArray(data));
+ return Status::OK();
+ }
+
+ Status Read(std::shared_ptr<Table>* out) override {
+ std::vector<std::shared_ptr<ChunkedArray>> columns;
+ for (int i = 0; i < static_cast<int>(metadata_->columns()->size()); ++i) {
+ columns.emplace_back();
+ RETURN_NOT_OK(GetColumn(i, &columns.back()));
+ }
+ *out = Table::Make(this->schema(), std::move(columns), this->num_rows());
+ return Status::OK();
+ }
+
+ Status Read(const std::vector<int>& indices, std::shared_ptr<Table>* out) override {
+ std::vector<std::shared_ptr<Field>> fields;
+ std::vector<std::shared_ptr<ChunkedArray>> columns;
+
+ auto my_schema = this->schema();
+ for (auto field_index : indices) {
+ if (field_index < 0 || field_index >= my_schema->num_fields()) {
+ return Status::Invalid("Field index ", field_index, " is out of bounds");
+ }
+ columns.emplace_back();
+ RETURN_NOT_OK(GetColumn(field_index, &columns.back()));
+ fields.push_back(my_schema->field(field_index));
+ }
+ *out = Table::Make(::arrow::schema(std::move(fields)), std::move(columns),
+ this->num_rows());
+ return Status::OK();
+ }
+
+ Status Read(const std::vector<std::string>& names,
+ std::shared_ptr<Table>* out) override {
+ std::vector<std::shared_ptr<Field>> fields;
+ std::vector<std::shared_ptr<ChunkedArray>> columns;
+
+ std::shared_ptr<Schema> sch = this->schema();
+ for (auto name : names) {
+ int field_index = sch->GetFieldIndex(name);
+ if (field_index == -1) {
+ return Status::Invalid("Field named ", name, " is not found");
+ }
+ columns.emplace_back();
+ RETURN_NOT_OK(GetColumn(field_index, &columns.back()));
+ fields.push_back(sch->field(field_index));
+ }
+ *out = Table::Make(::arrow::schema(std::move(fields)), std::move(columns),
+ this->num_rows());
+ return Status::OK();
+ }
+
+ private:
+ std::shared_ptr<io::RandomAccessFile> source_;
+ std::shared_ptr<Buffer> metadata_buffer_;
+ const fbs::CTable* metadata_;
+ std::shared_ptr<Schema> schema_;
+};
+
+// ----------------------------------------------------------------------
+// WriterV1
+
+struct ArrayMetadata {
+ fbs::Type type;
+ int64_t offset;
+ int64_t length;
+ int64_t null_count;
+ int64_t total_bytes;
+};
+
+#define TO_FLATBUFFER_CASE(TYPE) \
+ case Type::TYPE: \
+ return fbs::Type::TYPE;
+
+Result<fbs::Type> ToFlatbufferType(const DataType& type) {
+ switch (type.id()) {
+ TO_FLATBUFFER_CASE(BOOL);
+ TO_FLATBUFFER_CASE(INT8);
+ TO_FLATBUFFER_CASE(INT16);
+ TO_FLATBUFFER_CASE(INT32);
+ TO_FLATBUFFER_CASE(INT64);
+ TO_FLATBUFFER_CASE(UINT8);
+ TO_FLATBUFFER_CASE(UINT16);
+ TO_FLATBUFFER_CASE(UINT32);
+ TO_FLATBUFFER_CASE(UINT64);
+ TO_FLATBUFFER_CASE(FLOAT);
+ TO_FLATBUFFER_CASE(DOUBLE);
+ TO_FLATBUFFER_CASE(LARGE_BINARY);
+ TO_FLATBUFFER_CASE(BINARY);
+ case Type::STRING:
+ return fbs::Type::UTF8;
+ case Type::LARGE_STRING:
+ return fbs::Type::LARGE_UTF8;
+ case Type::DATE32:
+ return fbs::Type::INT32;
+ case Type::TIMESTAMP:
+ return fbs::Type::INT64;
+ case Type::TIME32:
+ return fbs::Type::INT32;
+ case Type::TIME64:
+ return fbs::Type::INT64;
+ default:
+ return Status::TypeError("Unsupported Feather V1 type: ", type.ToString(),
+ ". Use V2 format to serialize all Arrow types.");
+ }
+}
+
+inline flatbuffers::Offset<fbs::PrimitiveArray> GetPrimitiveArray(
+ FBB& fbb, const ArrayMetadata& array) {
+ return fbs::CreatePrimitiveArray(fbb, array.type, fbs::Encoding::PLAIN, array.offset,
+ array.length, array.null_count, array.total_bytes);
+}
+
+// Convert Feather enums to Flatbuffer enums
+inline fbs::TimeUnit ToFlatbufferEnum(TimeUnit::type unit) {
+ return static_cast<fbs::TimeUnit>(static_cast<int>(unit));
+}
+
+const fbs::TypeMetadata COLUMN_TYPE_ENUM_MAPPING[] = {
+ fbs::TypeMetadata::NONE, // PRIMITIVE
+ fbs::TypeMetadata::CategoryMetadata, // CATEGORY
+ fbs::TypeMetadata::TimestampMetadata, // TIMESTAMP
+ fbs::TypeMetadata::DateMetadata, // DATE
+ fbs::TypeMetadata::TimeMetadata // TIME
+};
+
+inline fbs::TypeMetadata ToFlatbufferEnum(ColumnType::type column_type) {
+ return COLUMN_TYPE_ENUM_MAPPING[column_type];
+}
+
+struct ColumnMetadata {
+ flatbuffers::Offset<void> WriteMetadata(FBB& fbb) { // NOLINT
+ switch (this->meta_type) {
+ case ColumnType::PRIMITIVE:
+ // flatbuffer void
+ return 0;
+ case ColumnType::CATEGORY: {
+ auto cat_meta = fbs::CreateCategoryMetadata(
+ fbb, GetPrimitiveArray(fbb, this->category_levels), this->category_ordered);
+ return cat_meta.Union();
+ }
+ case ColumnType::TIMESTAMP: {
+ // flatbuffer void
+ flatbuffers::Offset<flatbuffers::String> tz = 0;
+ if (!this->timezone.empty()) {
+ tz = fbb.CreateString(this->timezone);
+ }
+
+ auto ts_meta =
+ fbs::CreateTimestampMetadata(fbb, ToFlatbufferEnum(this->temporal_unit), tz);
+ return ts_meta.Union();
+ }
+ case ColumnType::DATE: {
+ auto date_meta = fbs::CreateDateMetadata(fbb);
+ return date_meta.Union();
+ }
+ case ColumnType::TIME: {
+ auto time_meta =
+ fbs::CreateTimeMetadata(fbb, ToFlatbufferEnum(this->temporal_unit));
+ return time_meta.Union();
+ }
+ default:
+ // null
+ DCHECK(false);
+ return 0;
+ }
+ }
+
+ ArrayMetadata values;
+ ColumnType::type meta_type;
+
+ ArrayMetadata category_levels;
+ bool category_ordered;
+
+ TimeUnit::type temporal_unit;
+
+ // A timezone name known to the Olson timezone database. For display purposes
+ // because the actual data is all UTC
+ std::string timezone;
+};
+
+Status WriteArrayV1(const Array& values, io::OutputStream* dst, ArrayMetadata* meta);
+
+struct ArrayWriterV1 {
+ const Array& values;
+ io::OutputStream* dst;
+ ArrayMetadata* meta;
+
+ Status WriteBuffer(const uint8_t* buffer, int64_t length, int64_t bit_offset) {
+ int64_t bytes_written = 0;
+ if (buffer) {
+ RETURN_NOT_OK(
+ WritePaddedWithOffset(dst, buffer, bit_offset, length, &bytes_written));
+ } else {
+ RETURN_NOT_OK(WritePaddedBlank(dst, length, &bytes_written));
+ }
+ meta->total_bytes += bytes_written;
+ return Status::OK();
+ }
+
+ template <typename T>
+ typename std::enable_if<
+ is_nested_type<T>::value || is_null_type<T>::value || is_decimal_type<T>::value ||
+ std::is_same<DictionaryType, T>::value || is_duration_type<T>::value ||
+ is_interval_type<T>::value || is_fixed_size_binary_type<T>::value ||
+ std::is_same<Date64Type, T>::value || std::is_same<Time64Type, T>::value ||
+ std::is_same<ExtensionType, T>::value,
+ Status>::type
+ Visit(const T& type) {
+ return Status::NotImplemented(type.ToString());
+ }
+
+ template <typename T>
+ typename std::enable_if<is_number_type<T>::value ||
+ std::is_same<Date32Type, T>::value ||
+ std::is_same<Time32Type, T>::value ||
+ is_timestamp_type<T>::value || is_boolean_type<T>::value,
+ Status>::type
+ Visit(const T&) {
+ const auto& prim_values = checked_cast<const PrimitiveArray&>(values);
+ const auto& fw_type = checked_cast<const FixedWidthType&>(*values.type());
+
+ if (prim_values.values()) {
+ const uint8_t* buffer =
+ prim_values.values()->data() + (prim_values.offset() * fw_type.bit_width() / 8);
+ int64_t bit_offset = (prim_values.offset() * fw_type.bit_width()) % 8;
+ return WriteBuffer(buffer,
+ BitUtil::BytesForBits(values.length() * fw_type.bit_width()),
+ bit_offset);
+ } else {
+ return Status::OK();
+ }
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_base_binary<T, Status> Visit(const T&) {
+ using ArrayType = typename TypeTraits<T>::ArrayType;
+ const auto& ty_values = checked_cast<const ArrayType&>(values);
+
+ using offset_type = typename T::offset_type;
+ const offset_type* offsets_data = nullptr;
+ int64_t values_bytes = 0;
+ if (ty_values.value_offsets()) {
+ offsets_data = ty_values.raw_value_offsets();
+ // All of the data has to be written because we don't have offset
+ // shifting implemented here as with the IPC format
+ values_bytes = offsets_data[values.length()];
+ }
+ RETURN_NOT_OK(WriteBuffer(reinterpret_cast<const uint8_t*>(offsets_data),
+ sizeof(offset_type) * (values.length() + 1),
+ /*bit_offset=*/0));
+
+ const uint8_t* values_buffer = nullptr;
+ if (ty_values.value_data()) {
+ values_buffer = ty_values.value_data()->data();
+ }
+ return WriteBuffer(values_buffer, values_bytes, /*bit_offset=*/0);
+ }
+
+ Status Write() {
+ if (values.type_id() == Type::DICTIONARY) {
+ return WriteArrayV1(*(checked_cast<const DictionaryArray&>(values).indices()), dst,
+ meta);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(meta->type, ToFlatbufferType(*values.type()));
+ ARROW_ASSIGN_OR_RAISE(meta->offset, dst->Tell());
+ meta->length = values.length();
+ meta->null_count = values.null_count();
+ meta->total_bytes = 0;
+
+ // Write the null bitmask
+ if (values.null_count() > 0) {
+ RETURN_NOT_OK(WriteBuffer(values.null_bitmap_data(),
+ BitUtil::BytesForBits(values.length()), values.offset()));
+ }
+ // Write data buffer(s)
+ return VisitTypeInline(*values.type(), this);
+ }
+};
+
+Status WriteArrayV1(const Array& values, io::OutputStream* dst, ArrayMetadata* meta) {
+ std::shared_ptr<Array> sanitized;
+ if (values.type_id() == Type::NA) {
+ // As long as R doesn't support NA, we write this as a StringColumn
+ // to ensure stable roundtrips.
+ sanitized = std::make_shared<StringArray>(values.length(), nullptr, nullptr,
+ values.null_bitmap(), values.null_count());
+ } else {
+ sanitized = MakeArray(values.data());
+ }
+ ArrayWriterV1 visitor{*sanitized, dst, meta};
+ return visitor.Write();
+}
+
+Status WriteColumnV1(const ChunkedArray& values, io::OutputStream* dst,
+ ColumnMetadata* out) {
+ if (values.num_chunks() > 1) {
+ return Status::Invalid("Writing chunked arrays not supported in Feather V1");
+ }
+ const Array& chunk = *values.chunk(0);
+ RETURN_NOT_OK(WriteArrayV1(chunk, dst, &out->values));
+ switch (chunk.type_id()) {
+ case Type::DICTIONARY: {
+ out->meta_type = ColumnType::CATEGORY;
+ auto dictionary = checked_cast<const DictionaryArray&>(chunk).dictionary();
+ RETURN_NOT_OK(WriteArrayV1(*dictionary, dst, &out->category_levels));
+ out->category_ordered =
+ checked_cast<const DictionaryType&>(*chunk.type()).ordered();
+ } break;
+ case Type::DATE32:
+ out->meta_type = ColumnType::DATE;
+ break;
+ case Type::TIME32: {
+ out->meta_type = ColumnType::TIME;
+ out->temporal_unit = checked_cast<const Time32Type&>(*chunk.type()).unit();
+ } break;
+ case Type::TIMESTAMP: {
+ const auto& ts_type = checked_cast<const TimestampType&>(*chunk.type());
+ out->meta_type = ColumnType::TIMESTAMP;
+ out->temporal_unit = ts_type.unit();
+ out->timezone = ts_type.timezone();
+ } break;
+ default:
+ out->meta_type = ColumnType::PRIMITIVE;
+ break;
+ }
+ return Status::OK();
+}
+
+Status WriteFeatherV1(const Table& table, io::OutputStream* dst) {
+ // Preamble
+ int64_t bytes_written;
+ RETURN_NOT_OK(WritePadded(dst, reinterpret_cast<const uint8_t*>(kFeatherV1MagicBytes),
+ strlen(kFeatherV1MagicBytes), &bytes_written));
+
+ // Write columns
+ flatbuffers::FlatBufferBuilder fbb;
+ std::vector<flatbuffers::Offset<fbs::Column>> fb_columns;
+ for (int i = 0; i < table.num_columns(); ++i) {
+ ColumnMetadata col;
+ RETURN_NOT_OK(WriteColumnV1(*table.column(i), dst, &col));
+ auto fb_column = fbs::CreateColumn(
+ fbb, fbb.CreateString(table.field(i)->name()), GetPrimitiveArray(fbb, col.values),
+ ToFlatbufferEnum(col.meta_type), col.WriteMetadata(fbb),
+ /*user_metadata=*/0);
+ fb_columns.push_back(fb_column);
+ }
+
+ // Finalize file footer
+ auto root = fbs::CreateCTable(fbb, /*description=*/0, table.num_rows(),
+ fbb.CreateVector(fb_columns), kFeatherV1Version,
+ /*metadata=*/0);
+ fbb.Finish(root);
+ auto buffer = std::make_shared<Buffer>(fbb.GetBufferPointer(),
+ static_cast<int64_t>(fbb.GetSize()));
+
+ // Writer metadata
+ RETURN_NOT_OK(WritePadded(dst, buffer->data(), buffer->size(), &bytes_written));
+ uint32_t metadata_size = static_cast<uint32_t>(bytes_written);
+
+ // Footer: metadata length, magic bytes
+ RETURN_NOT_OK(dst->Write(&metadata_size, sizeof(uint32_t)));
+ return dst->Write(kFeatherV1MagicBytes, strlen(kFeatherV1MagicBytes));
+}
+
+// ----------------------------------------------------------------------
+// Reader V2
+
+class ReaderV2 : public Reader {
+ public:
+ Status Open(const std::shared_ptr<io::RandomAccessFile>& source) {
+ source_ = source;
+ ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchFileReader::Open(source_));
+ schema_ = reader->schema();
+ return Status::OK();
+ }
+
+ int version() const override { return kFeatherV2Version; }
+
+ std::shared_ptr<Schema> schema() const override { return schema_; }
+
+ Status Read(const IpcReadOptions& options, std::shared_ptr<Table>* out) {
+ ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchFileReader::Open(source_, options));
+ RecordBatchVector batches(reader->num_record_batches());
+ for (int i = 0; i < reader->num_record_batches(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(batches[i], reader->ReadRecordBatch(i));
+ }
+
+ return Table::FromRecordBatches(reader->schema(), batches).Value(out);
+ }
+
+ Status Read(std::shared_ptr<Table>* out) override {
+ return Read(IpcReadOptions::Defaults(), out);
+ }
+
+ Status Read(const std::vector<int>& indices, std::shared_ptr<Table>* out) override {
+ auto options = IpcReadOptions::Defaults();
+ options.included_fields = indices;
+ return Read(options, out);
+ }
+
+ Status Read(const std::vector<std::string>& names,
+ std::shared_ptr<Table>* out) override {
+ std::vector<int> indices;
+ std::shared_ptr<Schema> sch = this->schema();
+ for (auto name : names) {
+ int field_index = sch->GetFieldIndex(name);
+ if (field_index == -1) {
+ return Status::Invalid("Field named ", name, " is not found");
+ }
+ indices.push_back(field_index);
+ }
+ return Read(indices, out);
+ }
+
+ private:
+ std::shared_ptr<io::RandomAccessFile> source_;
+ std::shared_ptr<Schema> schema_;
+};
+
+} // namespace
+
+Result<std::shared_ptr<Reader>> Reader::Open(
+ const std::shared_ptr<io::RandomAccessFile>& source) {
+ // Pathological issue where the file is smaller than header and footer
+ // combined
+ ARROW_ASSIGN_OR_RAISE(int64_t size, source->GetSize());
+ if (size < /* 2 * 4 + 4 */ 12) {
+ return Status::Invalid("File is too small to be a well-formed file");
+ }
+
+ // Determine what kind of file we have. 6 is the max of len(FEA1) and
+ // len(ARROW1)
+ constexpr int magic_size = 6;
+ ARROW_ASSIGN_OR_RAISE(auto buffer, source->ReadAt(0, magic_size));
+
+ if (memcmp(buffer->data(), kFeatherV1MagicBytes, strlen(kFeatherV1MagicBytes)) == 0) {
+ std::shared_ptr<ReaderV1> result = std::make_shared<ReaderV1>();
+ RETURN_NOT_OK(result->Open(source));
+ return result;
+ } else if (memcmp(buffer->data(), internal::kArrowMagicBytes,
+ strlen(internal::kArrowMagicBytes)) == 0) {
+ std::shared_ptr<ReaderV2> result = std::make_shared<ReaderV2>();
+ RETURN_NOT_OK(result->Open(source));
+ return result;
+ } else {
+ return Status::Invalid("Not a Feather V1 or Arrow IPC file");
+ }
+}
+
+WriteProperties WriteProperties::Defaults() {
+ WriteProperties result;
+#ifdef ARROW_WITH_LZ4
+ result.compression = Compression::LZ4_FRAME;
+#else
+ result.compression = Compression::UNCOMPRESSED;
+#endif
+ return result;
+}
+
+Status WriteTable(const Table& table, io::OutputStream* dst,
+ const WriteProperties& properties) {
+ if (properties.version == kFeatherV1Version) {
+ return WriteFeatherV1(table, dst);
+ } else {
+ IpcWriteOptions ipc_options = IpcWriteOptions::Defaults();
+ ipc_options.unify_dictionaries = true;
+ ipc_options.allow_64bit = true;
+ ARROW_ASSIGN_OR_RAISE(
+ ipc_options.codec,
+ util::Codec::Create(properties.compression, properties.compression_level));
+
+ std::shared_ptr<RecordBatchWriter> writer;
+ ARROW_ASSIGN_OR_RAISE(writer, MakeFileWriter(dst, table.schema(), ipc_options));
+ RETURN_NOT_OK(writer->WriteTable(table, properties.chunksize));
+ return writer->Close();
+ }
+}
+
+} // namespace feather
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/feather.fbs b/src/arrow/cpp/src/arrow/ipc/feather.fbs
new file mode 100644
index 000000000..b4076be87
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/feather.fbs
@@ -0,0 +1,156 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// DEPRECATED: Feather V2 is available starting in version 0.17.0 and does not
+/// use this file at all.
+
+namespace arrow.ipc.feather.fbs;
+
+/// Feather is an experimental serialization format implemented using
+/// techniques from Apache Arrow. It was created as a proof-of-concept of an
+/// interoperable file format for storing data frames originating in Python or
+/// R. It enabled the developers to sidestep some of the open design questions
+/// in Arrow from early 2016 and instead create something simple and useful for
+/// the intended use cases.
+
+enum Type : byte {
+ BOOL = 0,
+
+ INT8 = 1,
+ INT16 = 2,
+ INT32 = 3,
+ INT64 = 4,
+
+ UINT8 = 5,
+ UINT16 = 6,
+ UINT32 = 7,
+ UINT64 = 8,
+
+ FLOAT = 9,
+ DOUBLE = 10,
+
+ UTF8 = 11,
+
+ BINARY = 12,
+
+ CATEGORY = 13,
+
+ TIMESTAMP = 14,
+ DATE = 15,
+ TIME = 16,
+
+ LARGE_UTF8 = 17,
+ LARGE_BINARY = 18
+}
+
+enum Encoding : byte {
+ PLAIN = 0,
+
+ /// Data is stored dictionary-encoded
+ /// dictionary size: <INT32 Dictionary size>
+ /// dictionary data: <TYPE primitive array>
+ /// dictionary index: <INT32 primitive array>
+ ///
+ /// TODO: do we care about storing the index values in a smaller typeclass
+ DICTIONARY = 1
+}
+
+enum TimeUnit : byte {
+ SECOND = 0,
+ MILLISECOND = 1,
+ MICROSECOND = 2,
+ NANOSECOND = 3
+}
+
+table PrimitiveArray {
+ type: Type;
+
+ encoding: Encoding = PLAIN;
+
+ /// Relative memory offset of the start of the array data excluding the size
+ /// of the metadata
+ offset: long;
+
+ /// The number of logical values in the array
+ length: long;
+
+ /// The number of observed nulls
+ null_count: long;
+
+ /// The total size of the actual data in the file
+ total_bytes: long;
+
+ /// TODO: Compression
+}
+
+table CategoryMetadata {
+ /// The category codes are presumed to be integers that are valid indexes into
+ /// the levels array
+
+ levels: PrimitiveArray;
+ ordered: bool = false;
+}
+
+table TimestampMetadata {
+ unit: TimeUnit;
+
+ /// Timestamp data is assumed to be UTC, but the time zone is stored here for
+ /// presentation as localized
+ timezone: string;
+}
+
+table DateMetadata {
+}
+
+table TimeMetadata {
+ unit: TimeUnit;
+}
+
+union TypeMetadata {
+ CategoryMetadata,
+ TimestampMetadata,
+ DateMetadata,
+ TimeMetadata,
+}
+
+table Column {
+ name: string;
+ values: PrimitiveArray;
+ metadata: TypeMetadata;
+
+ /// This should (probably) be JSON
+ user_metadata: string;
+}
+
+table CTable {
+ /// Some text (or a name) metadata about what the file is, optional
+ description: string;
+
+ num_rows: long;
+ columns: [Column];
+
+ /// Version number of the Feather format
+ ///
+ /// Internal versions 0, 1, and 2: Implemented in Apache Arrow <= 0.16.0 and
+ /// wesm/feather. Uses "custom" metadata defined in this file.
+ version: int;
+
+ /// Table metadata (likely JSON), not yet used
+ metadata: string;
+}
+
+root_type CTable;
diff --git a/src/arrow/cpp/src/arrow/ipc/feather.h b/src/arrow/cpp/src/arrow/ipc/feather.h
new file mode 100644
index 000000000..a32ff6d0a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/feather.h
@@ -0,0 +1,140 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Public API for the "Feather" file format, originally created at
+// http://github.com/wesm/feather
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/type_fwd.h"
+#include "arrow/util/compression.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Schema;
+class Status;
+class Table;
+
+namespace io {
+
+class OutputStream;
+class RandomAccessFile;
+
+} // namespace io
+
+namespace ipc {
+namespace feather {
+
+static constexpr const int kFeatherV1Version = 2;
+static constexpr const int kFeatherV2Version = 3;
+
+// ----------------------------------------------------------------------
+// Metadata accessor classes
+
+/// \class Reader
+/// \brief An interface for reading columns from Feather files
+class ARROW_EXPORT Reader {
+ public:
+ virtual ~Reader() = default;
+
+ /// \brief Open a Feather file from a RandomAccessFile interface
+ ///
+ /// \param[in] source a RandomAccessFile instance
+ /// \return the table reader
+ static Result<std::shared_ptr<Reader>> Open(
+ const std::shared_ptr<io::RandomAccessFile>& source);
+
+ /// \brief Return the version number of the Feather file
+ virtual int version() const = 0;
+
+ virtual std::shared_ptr<Schema> schema() const = 0;
+
+ /// \brief Read all columns from the file as an arrow::Table.
+ ///
+ /// \param[out] out the returned table
+ /// \return Status
+ ///
+ /// This function is zero-copy if the file source supports zero-copy reads
+ virtual Status Read(std::shared_ptr<Table>* out) = 0;
+
+ /// \brief Read only the specified columns from the file as an arrow::Table.
+ ///
+ /// \param[in] indices the column indices to read
+ /// \param[out] out the returned table
+ /// \return Status
+ ///
+ /// This function is zero-copy if the file source supports zero-copy reads
+ virtual Status Read(const std::vector<int>& indices, std::shared_ptr<Table>* out) = 0;
+
+ /// \brief Read only the specified columns from the file as an arrow::Table.
+ ///
+ /// \param[in] names the column names to read
+ /// \param[out] out the returned table
+ /// \return Status
+ ///
+ /// This function is zero-copy if the file source supports zero-copy reads
+ virtual Status Read(const std::vector<std::string>& names,
+ std::shared_ptr<Table>* out) = 0;
+};
+
+struct ARROW_EXPORT WriteProperties {
+ static WriteProperties Defaults();
+
+ static WriteProperties DefaultsV1() {
+ WriteProperties props = Defaults();
+ props.version = kFeatherV1Version;
+ return props;
+ }
+
+ /// Feather file version number
+ ///
+ /// version 2: "Feather V1" Apache Arrow <= 0.16.0
+ /// version 3: "Feather V2" Apache Arrow > 0.16.0
+ int version = kFeatherV2Version;
+
+ // Parameters for Feather V2 only
+
+ /// Number of rows per intra-file chunk. Use smaller chunksize when you need
+ /// faster random row access
+ int64_t chunksize = 1LL << 16;
+
+ /// Compression type to use. Only UNCOMPRESSED, LZ4_FRAME, and ZSTD are
+ /// supported. The default compression returned by Defaults() is LZ4 if the
+ /// project is built with support for it, otherwise
+ /// UNCOMPRESSED. UNCOMPRESSED is set as the object default here so that if
+ /// WriteProperties::Defaults() is not used, the default constructor for
+ /// WriteProperties will work regardless of the options used to build the C++
+ /// project.
+ Compression::type compression = Compression::UNCOMPRESSED;
+
+ /// Compressor-specific compression level
+ int compression_level = ::arrow::util::kUseDefaultCompressionLevel;
+};
+
+ARROW_EXPORT
+Status WriteTable(const Table& table, io::OutputStream* dst,
+ const WriteProperties& properties = WriteProperties::Defaults());
+
+} // namespace feather
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/feather_test.cc b/src/arrow/cpp/src/arrow/ipc/feather_test.cc
new file mode 100644
index 000000000..e9a3c72c6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/feather_test.cc
@@ -0,0 +1,373 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <utility>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/feather.h"
+#include "arrow/ipc/test_common.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/compression.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace ipc {
+namespace feather {
+
+struct TestParam {
+ TestParam(int arg_version,
+ Compression::type arg_compression = Compression::UNCOMPRESSED)
+ : version(arg_version), compression(arg_compression) {}
+
+ int version;
+ Compression::type compression;
+};
+
+void PrintTo(const TestParam& p, std::ostream* os) {
+ *os << "{version = " << p.version
+ << ", compression = " << ::arrow::util::Codec::GetCodecAsString(p.compression)
+ << "}";
+}
+
+class TestFeatherBase {
+ public:
+ void SetUp() { Initialize(); }
+
+ void Initialize() { ASSERT_OK_AND_ASSIGN(stream_, io::BufferOutputStream::Create()); }
+
+ virtual WriteProperties GetProperties() = 0;
+
+ void DoWrite(const Table& table) {
+ Initialize();
+ ASSERT_OK(WriteTable(table, stream_.get(), GetProperties()));
+ ASSERT_OK_AND_ASSIGN(output_, stream_->Finish());
+ auto buffer = std::make_shared<io::BufferReader>(output_);
+ ASSERT_OK_AND_ASSIGN(reader_, Reader::Open(buffer));
+ }
+
+ void CheckSlice(std::shared_ptr<RecordBatch> batch, int start, int size) {
+ batch = batch->Slice(start, size);
+ ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches({batch}));
+
+ DoWrite(*table);
+ std::shared_ptr<Table> result;
+ ASSERT_OK(reader_->Read(&result));
+ ASSERT_OK(result->ValidateFull());
+ if (table->num_rows() > 0) {
+ AssertTablesEqual(*table, *result);
+ } else {
+ ASSERT_EQ(0, result->num_rows());
+ ASSERT_TRUE(result->schema()->Equals(*table->schema()));
+ }
+ }
+
+ void CheckSlices(std::shared_ptr<RecordBatch> batch) {
+ std::vector<int> starts = {0, 1, 300, 301, 302, 303, 304, 305, 306, 307};
+ std::vector<int> sizes = {0, 1, 7, 8, 30, 32, 100};
+ for (auto start : starts) {
+ for (auto size : sizes) {
+ CheckSlice(batch, start, size);
+ }
+ }
+ }
+
+ void CheckRoundtrip(std::shared_ptr<RecordBatch> batch) {
+ std::vector<std::shared_ptr<RecordBatch>> batches = {batch};
+ ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches(batches));
+
+ DoWrite(*table);
+
+ std::shared_ptr<Table> read_table;
+ ASSERT_OK(reader_->Read(&read_table));
+ ASSERT_OK(read_table->ValidateFull());
+ AssertTablesEqual(*table, *read_table);
+ }
+
+ protected:
+ std::shared_ptr<io::BufferOutputStream> stream_;
+ std::shared_ptr<Reader> reader_;
+ std::shared_ptr<Buffer> output_;
+};
+
+class TestFeather : public ::testing::TestWithParam<TestParam>, public TestFeatherBase {
+ public:
+ void SetUp() { TestFeatherBase::SetUp(); }
+
+ WriteProperties GetProperties() {
+ auto param = GetParam();
+
+ auto props = WriteProperties::Defaults();
+ props.version = param.version;
+
+ // Don't fail if the build doesn't have LZ4_FRAME or ZSTD enabled
+ if (util::Codec::IsAvailable(param.compression)) {
+ props.compression = param.compression;
+ } else {
+ props.compression = Compression::UNCOMPRESSED;
+ }
+ return props;
+ }
+};
+
+class TestFeatherRoundTrip : public ::testing::TestWithParam<ipc::test::MakeRecordBatch*>,
+ public TestFeatherBase {
+ public:
+ void SetUp() { TestFeatherBase::SetUp(); }
+
+ WriteProperties GetProperties() {
+ auto props = WriteProperties::Defaults();
+ props.version = kFeatherV2Version;
+
+ // Don't fail if the build doesn't have LZ4_FRAME or ZSTD enabled
+ if (!util::Codec::IsAvailable(props.compression)) {
+ props.compression = Compression::UNCOMPRESSED;
+ }
+ return props;
+ }
+};
+
+TEST(TestFeatherWriteProperties, Defaults) {
+ auto props = WriteProperties::Defaults();
+
+#ifdef ARROW_WITH_LZ4
+ ASSERT_EQ(Compression::LZ4_FRAME, props.compression);
+#else
+ ASSERT_EQ(Compression::UNCOMPRESSED, props.compression);
+#endif
+}
+
+TEST_P(TestFeather, ReadIndicesOrNames) {
+ std::shared_ptr<RecordBatch> batch1;
+ ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch1));
+
+ ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches({batch1}));
+
+ DoWrite(*table);
+
+ // int32 type is at the column f4 of the result of MakeIntRecordBatch
+ auto expected = Table::Make(schema({field("f4", int32())}), {batch1->column(4)});
+
+ std::shared_ptr<Table> result1, result2;
+
+ std::vector<int> indices = {4};
+ ASSERT_OK(reader_->Read(indices, &result1));
+ AssertTablesEqual(*expected, *result1);
+
+ std::vector<std::string> names = {"f4"};
+ ASSERT_OK(reader_->Read(names, &result2));
+ AssertTablesEqual(*expected, *result2);
+}
+
+TEST_P(TestFeather, EmptyTable) {
+ std::vector<std::shared_ptr<ChunkedArray>> columns;
+ auto table = Table::Make(schema({}), columns, 0);
+
+ DoWrite(*table);
+
+ std::shared_ptr<Table> result;
+ ASSERT_OK(reader_->Read(&result));
+ AssertTablesEqual(*table, *result);
+}
+
+TEST_P(TestFeather, SetNumRows) {
+ std::vector<std::shared_ptr<ChunkedArray>> columns;
+ auto table = Table::Make(schema({}), columns, 1000);
+ DoWrite(*table);
+ std::shared_ptr<Table> result;
+ ASSERT_OK(reader_->Read(&result));
+ ASSERT_EQ(1000, result->num_rows());
+}
+
+TEST_P(TestFeather, PrimitiveIntRoundTrip) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch));
+ CheckRoundtrip(batch);
+}
+
+TEST_P(TestFeather, PrimitiveFloatRoundTrip) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(ipc::test::MakeFloat3264Batch(&batch));
+ CheckRoundtrip(batch);
+}
+
+TEST_P(TestFeather, CategoryRoundtrip) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(ipc::test::MakeDictionaryFlat(&batch));
+ CheckRoundtrip(batch);
+}
+
+TEST_P(TestFeather, TimeTypes) {
+ std::vector<bool> is_valid = {true, true, true, false, true, true, true};
+ auto f0 = field("f0", date32());
+ auto f1 = field("f1", time32(TimeUnit::MILLI));
+ auto f2 = field("f2", timestamp(TimeUnit::NANO));
+ auto f3 = field("f3", timestamp(TimeUnit::SECOND, "US/Los_Angeles"));
+ auto schema = ::arrow::schema({f0, f1, f2, f3});
+
+ std::vector<int64_t> values64_vec = {0, 1, 2, 3, 4, 5, 6};
+ std::shared_ptr<Array> values64;
+ ArrayFromVector<Int64Type, int64_t>(is_valid, values64_vec, &values64);
+
+ std::vector<int32_t> values32_vec = {10, 11, 12, 13, 14, 15, 16};
+ std::shared_ptr<Array> values32;
+ ArrayFromVector<Int32Type, int32_t>(is_valid, values32_vec, &values32);
+
+ std::vector<int32_t> date_values_vec = {20, 21, 22, 23, 24, 25, 26};
+ std::shared_ptr<Array> date_array;
+ ArrayFromVector<Date32Type, int32_t>(is_valid, date_values_vec, &date_array);
+
+ const auto& prim_values64 = checked_cast<const PrimitiveArray&>(*values64);
+ BufferVector buffers64 = {prim_values64.null_bitmap(), prim_values64.values()};
+
+ const auto& prim_values32 = checked_cast<const PrimitiveArray&>(*values32);
+ BufferVector buffers32 = {prim_values32.null_bitmap(), prim_values32.values()};
+
+ // Push date32 ArrayData
+ std::vector<std::shared_ptr<ArrayData>> arrays;
+ arrays.push_back(date_array->data());
+
+ // Create time32 ArrayData
+ arrays.emplace_back(ArrayData::Make(schema->field(1)->type(), values32->length(),
+ BufferVector(buffers32), values32->null_count(),
+ 0));
+
+ // Create timestamp ArrayData
+ for (int i = 2; i < schema->num_fields(); ++i) {
+ arrays.emplace_back(ArrayData::Make(schema->field(i)->type(), values64->length(),
+ BufferVector(buffers64), values64->null_count(),
+ 0));
+ }
+
+ auto batch = RecordBatch::Make(schema, 7, std::move(arrays));
+ CheckRoundtrip(batch);
+}
+
+TEST_P(TestFeather, VLenPrimitiveRoundTrip) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(ipc::test::MakeStringTypesRecordBatch(&batch));
+ CheckRoundtrip(batch);
+}
+
+TEST_P(TestFeather, PrimitiveNullRoundTrip) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(ipc::test::MakeNullRecordBatch(&batch));
+
+ ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches({batch}));
+
+ DoWrite(*table);
+
+ std::shared_ptr<Table> result;
+ ASSERT_OK(reader_->Read(&result));
+
+ if (GetParam().version == kFeatherV1Version) {
+ std::vector<std::shared_ptr<Array>> expected_fields;
+ for (int i = 0; i < batch->num_columns(); ++i) {
+ ASSERT_EQ(batch->column_name(i), reader_->schema()->field(i)->name());
+ ASSERT_OK_AND_ASSIGN(auto expected, MakeArrayOfNull(utf8(), batch->num_rows()));
+ AssertArraysEqual(*expected, *result->column(i)->chunk(0));
+ }
+ } else {
+ AssertTablesEqual(*table, *result);
+ }
+}
+
+TEST_P(TestFeather, SliceIntRoundTrip) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(ipc::test::MakeIntBatchSized(600, &batch));
+ CheckSlices(batch);
+}
+
+TEST_P(TestFeather, SliceFloatRoundTrip) {
+ std::shared_ptr<RecordBatch> batch;
+ // Float16 is not supported by FeatherV1
+ ASSERT_OK(ipc::test::MakeFloat3264BatchSized(600, &batch));
+ CheckSlices(batch);
+}
+
+TEST_P(TestFeather, SliceStringsRoundTrip) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(ipc::test::MakeStringTypesRecordBatch(&batch, /*with_nulls=*/true));
+ CheckSlices(batch);
+}
+
+TEST_P(TestFeather, SliceBooleanRoundTrip) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(ipc::test::MakeBooleanBatchSized(600, &batch));
+ CheckSlices(batch);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ FeatherTests, TestFeather,
+ ::testing::Values(TestParam(kFeatherV1Version), TestParam(kFeatherV2Version),
+ TestParam(kFeatherV2Version, Compression::LZ4_FRAME),
+ TestParam(kFeatherV2Version, Compression::ZSTD)));
+
+namespace {
+
+const std::vector<test::MakeRecordBatch*> kBatchCases = {
+ &ipc::test::MakeIntRecordBatch,
+ &ipc::test::MakeListRecordBatch,
+ &ipc::test::MakeFixedSizeListRecordBatch,
+ &ipc::test::MakeNonNullRecordBatch,
+ &ipc::test::MakeDeeplyNestedList,
+ &ipc::test::MakeStringTypesRecordBatchWithNulls,
+ &ipc::test::MakeStruct,
+ &ipc::test::MakeUnion,
+ &ipc::test::MakeDictionary,
+ &ipc::test::MakeNestedDictionary,
+ &ipc::test::MakeMap,
+ &ipc::test::MakeMapOfDictionary,
+ &ipc::test::MakeDates,
+ &ipc::test::MakeTimestamps,
+ &ipc::test::MakeTimes,
+ &ipc::test::MakeFWBinary,
+ &ipc::test::MakeNull,
+ &ipc::test::MakeDecimal,
+ &ipc::test::MakeBooleanBatch,
+ &ipc::test::MakeFloatBatch,
+ &ipc::test::MakeIntervals};
+
+} // namespace
+
+TEST_P(TestFeatherRoundTrip, RoundTrip) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK((*GetParam())(&batch)); // NOLINT clang-tidy gtest issue
+
+ CheckRoundtrip(batch);
+}
+
+INSTANTIATE_TEST_SUITE_P(FeatherRoundTripTests, TestFeatherRoundTrip,
+ ::testing::ValuesIn(kBatchCases));
+
+} // namespace feather
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/file_fuzz.cc b/src/arrow/cpp/src/arrow/ipc/file_fuzz.cc
new file mode 100644
index 000000000..840d19a4e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/file_fuzz.cc
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+
+#include "arrow/ipc/reader.h"
+#include "arrow/status.h"
+#include "arrow/util/macros.h"
+
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
+ auto status = arrow::ipc::internal::FuzzIpcFile(data, static_cast<int64_t>(size));
+ ARROW_UNUSED(status);
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/ipc/file_to_stream.cc b/src/arrow/cpp/src/arrow/ipc/file_to_stream.cc
new file mode 100644
index 000000000..6ae6a4fa0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/file_to_stream.cc
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "arrow/io/file.h"
+#include "arrow/io/stdio.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+
+namespace arrow {
+
+class RecordBatch;
+
+namespace ipc {
+
+// Reads a file on the file system and prints to stdout the stream version of it.
+Status ConvertToStream(const char* path) {
+ io::StdoutStream sink;
+
+ ARROW_ASSIGN_OR_RAISE(auto in_file, io::ReadableFile::Open(path));
+ ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(in_file.get()));
+ ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeStreamWriter(&sink, reader->schema(),
+ IpcWriteOptions::Defaults()));
+ for (int i = 0; i < reader->num_record_batches(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> chunk, reader->ReadRecordBatch(i));
+ RETURN_NOT_OK(writer->WriteRecordBatch(*chunk));
+ }
+ return writer->Close();
+}
+
+} // namespace ipc
+} // namespace arrow
+
+int main(int argc, char** argv) {
+ if (argc != 2) {
+ std::cerr << "Usage: file-to-stream <input arrow file>" << std::endl;
+ return 1;
+ }
+ arrow::Status status = arrow::ipc::ConvertToStream(argv[1]);
+ if (!status.ok()) {
+ std::cerr << "Could not convert to stream: " << status.ToString() << std::endl;
+ return 1;
+ }
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/ipc/generate_fuzz_corpus.cc b/src/arrow/cpp/src/arrow/ipc/generate_fuzz_corpus.cc
new file mode 100644
index 000000000..9e6400305
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/generate_fuzz_corpus.cc
@@ -0,0 +1,161 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// A command line executable that generates a bunch of valid IPC files
+// containing example record batches. Those are used as fuzzing seeds
+// to make fuzzing more efficient.
+
+#include <cstdlib>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/io/file.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/json_simple.h"
+#include "arrow/ipc/test_common.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/testing/extension_type.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+namespace ipc {
+
+using ::arrow::internal::CreateDir;
+using ::arrow::internal::PlatformFilename;
+using internal::json::ArrayFromJSON;
+
+Result<std::shared_ptr<RecordBatch>> MakeExtensionBatch() {
+ auto array = ExampleUuid();
+ auto md = key_value_metadata({"key1", "key2"}, {"value1", ""});
+ auto schema = ::arrow::schema({field("f0", array->type())}, md);
+ return RecordBatch::Make(schema, array->length(), {array});
+}
+
+Result<std::shared_ptr<RecordBatch>> MakeMapBatch() {
+ std::shared_ptr<Array> array;
+ const char* json_input = R"(
+[
+ [[0, 1], [1, 1], [2, 2], [3, 3], [4, 5], [5, 8]],
+ null,
+ [[0, null], [1, null], [2, 0], [3, 1], [4, null], [5, 2]],
+ []
+ ]
+)";
+ RETURN_NOT_OK(ArrayFromJSON(map(int16(), int32()), json_input, &array));
+ auto schema = ::arrow::schema({field("f0", array->type())});
+ return RecordBatch::Make(schema, array->length(), {array});
+}
+
+Result<std::vector<std::shared_ptr<RecordBatch>>> Batches() {
+ std::vector<std::shared_ptr<RecordBatch>> batches;
+ std::shared_ptr<RecordBatch> batch;
+ std::shared_ptr<Array> array;
+
+ RETURN_NOT_OK(test::MakeNullRecordBatch(&batch));
+ batches.push_back(batch);
+ RETURN_NOT_OK(test::MakeListRecordBatch(&batch));
+ batches.push_back(batch);
+ RETURN_NOT_OK(test::MakeDictionary(&batch));
+ batches.push_back(batch);
+ RETURN_NOT_OK(test::MakeTimestamps(&batch));
+ batches.push_back(batch);
+ RETURN_NOT_OK(test::MakeFWBinary(&batch));
+ batches.push_back(batch);
+ RETURN_NOT_OK(test::MakeStruct(&batch));
+ batches.push_back(batch);
+ RETURN_NOT_OK(test::MakeUnion(&batch));
+ batches.push_back(batch);
+ RETURN_NOT_OK(test::MakeFixedSizeListRecordBatch(&batch));
+ batches.push_back(batch);
+ ARROW_ASSIGN_OR_RAISE(batch, MakeExtensionBatch());
+ batches.push_back(batch);
+ ARROW_ASSIGN_OR_RAISE(batch, MakeMapBatch());
+ batches.push_back(batch);
+
+ return batches;
+}
+
+Result<std::shared_ptr<Buffer>> SerializeRecordBatch(
+ const std::shared_ptr<RecordBatch>& batch, bool is_stream_format) {
+ ARROW_ASSIGN_OR_RAISE(auto sink, io::BufferOutputStream::Create(1024));
+ std::shared_ptr<RecordBatchWriter> writer;
+ if (is_stream_format) {
+ ARROW_ASSIGN_OR_RAISE(writer, MakeStreamWriter(sink, batch->schema()));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(writer, MakeFileWriter(sink, batch->schema()));
+ }
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+ RETURN_NOT_OK(writer->Close());
+ return sink->Finish();
+}
+
+Status DoMain(bool is_stream_format, const std::string& out_dir) {
+ ARROW_ASSIGN_OR_RAISE(auto dir_fn, PlatformFilename::FromString(out_dir));
+ RETURN_NOT_OK(CreateDir(dir_fn));
+
+ int sample_num = 1;
+ auto sample_name = [&]() -> std::string {
+ return "batch-" + std::to_string(sample_num++);
+ };
+
+ ARROW_ASSIGN_OR_RAISE(auto batches, Batches());
+
+ for (const auto& batch : batches) {
+ RETURN_NOT_OK(batch->ValidateFull());
+ ARROW_ASSIGN_OR_RAISE(auto buf, SerializeRecordBatch(batch, is_stream_format));
+ ARROW_ASSIGN_OR_RAISE(auto sample_fn, dir_fn.Join(sample_name()));
+ std::cerr << sample_fn.ToString() << std::endl;
+ ARROW_ASSIGN_OR_RAISE(auto file, io::FileOutputStream::Open(sample_fn.ToString()));
+ RETURN_NOT_OK(file->Write(buf));
+ RETURN_NOT_OK(file->Close());
+ }
+ return Status::OK();
+}
+
+ARROW_NORETURN void Usage() {
+ std::cerr << "Usage: arrow-ipc-generate-fuzz-corpus "
+ << "[-stream|-file] <output directory>" << std::endl;
+ std::exit(2);
+}
+
+int Main(int argc, char** argv) {
+ if (argc != 3) {
+ Usage();
+ }
+ auto opt = std::string(argv[1]);
+ if (opt != "-stream" && opt != "-file") {
+ Usage();
+ }
+ auto out_dir = std::string(argv[2]);
+
+ Status st = DoMain(opt == "-stream", out_dir);
+ if (!st.ok()) {
+ std::cerr << st.ToString() << std::endl;
+ return 1;
+ }
+ return 0;
+}
+
+} // namespace ipc
+} // namespace arrow
+
+int main(int argc, char** argv) { return arrow::ipc::Main(argc, argv); }
diff --git a/src/arrow/cpp/src/arrow/ipc/generate_tensor_fuzz_corpus.cc b/src/arrow/cpp/src/arrow/ipc/generate_tensor_fuzz_corpus.cc
new file mode 100644
index 000000000..dd40ef0ab
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/generate_tensor_fuzz_corpus.cc
@@ -0,0 +1,134 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// A command line executable that generates a bunch of valid IPC files
+// containing example tensors. Those are used as fuzzing seeds to make
+// fuzzing more efficient.
+
+#include <cstdlib>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/io/file.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/test_common.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/result.h"
+#include "arrow/tensor.h"
+#include "arrow/util/io_util.h"
+
+namespace arrow {
+namespace ipc {
+
+using ::arrow::internal::PlatformFilename;
+
+Result<PlatformFilename> PrepareDirectory(const std::string& dir) {
+ ARROW_ASSIGN_OR_RAISE(auto dir_fn, PlatformFilename::FromString(dir));
+ RETURN_NOT_OK(::arrow::internal::CreateDir(dir_fn));
+ return std::move(dir_fn);
+}
+
+Result<std::shared_ptr<Buffer>> MakeSerializedBuffer(
+ std::function<Status(const std::shared_ptr<io::BufferOutputStream>&)> fn) {
+ ARROW_ASSIGN_OR_RAISE(auto sink, io::BufferOutputStream::Create(1024));
+ RETURN_NOT_OK(fn(sink));
+ return sink->Finish();
+}
+
+Result<std::shared_ptr<Buffer>> SerializeTensor(const std::shared_ptr<Tensor>& tensor) {
+ return MakeSerializedBuffer(
+ [&](const std::shared_ptr<io::BufferOutputStream>& sink) -> Status {
+ int32_t metadata_length;
+ int64_t body_length;
+ return ipc::WriteTensor(*tensor, sink.get(), &metadata_length, &body_length);
+ });
+}
+
+Result<std::vector<std::shared_ptr<Tensor>>> Tensors() {
+ std::vector<std::shared_ptr<Tensor>> tensors;
+ std::shared_ptr<Tensor> tensor;
+ std::vector<int64_t> shape = {5, 3, 7};
+ std::shared_ptr<DataType> types[] = {int8(), int16(), int32(), int64(),
+ uint8(), uint16(), uint32(), uint64()};
+ uint32_t seed = 0;
+ for (auto type : types) {
+ RETURN_NOT_OK(
+ test::MakeRandomTensor(type, shape, /*row_major_p=*/true, &tensor, seed++));
+ tensors.push_back(tensor);
+ RETURN_NOT_OK(
+ test::MakeRandomTensor(type, shape, /*row_major_p=*/false, &tensor, seed++));
+ tensors.push_back(tensor);
+ }
+ return tensors;
+}
+
+Status GenerateTensors(const PlatformFilename& dir_fn) {
+ int sample_num = 1;
+ auto sample_name = [&]() -> std::string {
+ return "tensor-" + std::to_string(sample_num++);
+ };
+
+ ARROW_ASSIGN_OR_RAISE(auto tensors, Tensors());
+
+ for (const auto& tensor : tensors) {
+ ARROW_ASSIGN_OR_RAISE(auto buf, SerializeTensor(tensor));
+ ARROW_ASSIGN_OR_RAISE(auto sample_fn, dir_fn.Join(sample_name()));
+ std::cerr << sample_fn.ToString() << std::endl;
+ ARROW_ASSIGN_OR_RAISE(auto file, io::FileOutputStream::Open(sample_fn.ToString()));
+ RETURN_NOT_OK(file->Write(buf));
+ RETURN_NOT_OK(file->Close());
+ }
+ return Status::OK();
+}
+
+Status DoMain(const std::string& out_dir) {
+ ARROW_ASSIGN_OR_RAISE(auto dir_fn, PrepareDirectory(out_dir));
+ return GenerateTensors(dir_fn);
+}
+
+ARROW_NORETURN void Usage() {
+ std::cerr << "Usage: arrow-ipc-generate-tensor-fuzz-corpus "
+ << "-stream <output directory>" << std::endl;
+ std::exit(2);
+}
+
+int Main(int argc, char** argv) {
+ if (argc != 3) {
+ Usage();
+ }
+
+ auto opt = std::string(argv[1]);
+ if (opt != "-stream") {
+ Usage();
+ }
+
+ auto out_dir = std::string(argv[2]);
+
+ Status st = DoMain(out_dir);
+ if (!st.ok()) {
+ std::cerr << st.ToString() << std::endl;
+ return 1;
+ }
+ return 0;
+}
+
+} // namespace ipc
+} // namespace arrow
+
+int main(int argc, char** argv) { return arrow::ipc::Main(argc, argv); }
diff --git a/src/arrow/cpp/src/arrow/ipc/json_simple.cc b/src/arrow/cpp/src/arrow/ipc/json_simple.cc
new file mode 100644
index 000000000..8347b871b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/json_simple.cc
@@ -0,0 +1,994 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <sstream>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/array_dict.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_decimal.h"
+#include "arrow/array/builder_dict.h"
+#include "arrow/array/builder_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/builder_time.h"
+#include "arrow/array/builder_union.h"
+#include "arrow/ipc/json_simple.h"
+#include "arrow/scalar.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/value_parsing.h"
+
+#include "arrow/json/rapidjson_defs.h"
+
+#include <rapidjson/document.h>
+#include <rapidjson/error/en.h>
+#include <rapidjson/rapidjson.h>
+#include <rapidjson/reader.h>
+#include <rapidjson/writer.h>
+
+namespace rj = arrow::rapidjson;
+
+namespace arrow {
+
+using internal::ParseValue;
+
+namespace ipc {
+namespace internal {
+namespace json {
+
+using ::arrow::internal::checked_cast;
+using ::arrow::internal::checked_pointer_cast;
+
+namespace {
+
+constexpr auto kParseFlags = rj::kParseFullPrecisionFlag | rj::kParseNanAndInfFlag;
+
+Status JSONTypeError(const char* expected_type, rj::Type json_type) {
+ return Status::Invalid("Expected ", expected_type, " or null, got JSON type ",
+ json_type);
+}
+
+class Converter {
+ public:
+ virtual ~Converter() = default;
+
+ virtual Status Init() { return Status::OK(); }
+
+ virtual Status AppendValue(const rj::Value& json_obj) = 0;
+
+ Status AppendNull() { return this->builder()->AppendNull(); }
+
+ virtual Status AppendValues(const rj::Value& json_array) = 0;
+
+ virtual std::shared_ptr<ArrayBuilder> builder() = 0;
+
+ virtual Status Finish(std::shared_ptr<Array>* out) {
+ auto builder = this->builder();
+ if (builder->length() == 0) {
+ // Make sure the builder was initialized
+ RETURN_NOT_OK(builder->Resize(1));
+ }
+ return builder->Finish(out);
+ }
+
+ protected:
+ std::shared_ptr<DataType> type_;
+};
+
+Status GetConverter(const std::shared_ptr<DataType>&, std::shared_ptr<Converter>* out);
+
+// CRTP
+template <class Derived>
+class ConcreteConverter : public Converter {
+ public:
+ Status AppendValues(const rj::Value& json_array) override {
+ auto self = static_cast<Derived*>(this);
+ if (!json_array.IsArray()) {
+ return JSONTypeError("array", json_array.GetType());
+ }
+ auto size = json_array.Size();
+ for (uint32_t i = 0; i < size; ++i) {
+ RETURN_NOT_OK(self->AppendValue(json_array[i]));
+ }
+ return Status::OK();
+ }
+
+ const std::shared_ptr<DataType>& value_type() {
+ if (type_->id() != Type::DICTIONARY) {
+ return type_;
+ }
+ return checked_cast<const DictionaryType&>(*type_).value_type();
+ }
+
+ template <typename BuilderType>
+ Status MakeConcreteBuilder(std::shared_ptr<BuilderType>* out) {
+ std::unique_ptr<ArrayBuilder> builder;
+ RETURN_NOT_OK(MakeBuilder(default_memory_pool(), this->type_, &builder));
+ *out = checked_pointer_cast<BuilderType>(std::move(builder));
+ DCHECK(*out);
+ return Status::OK();
+ }
+};
+
+// ------------------------------------------------------------------------
+// Converter for null arrays
+
+class NullConverter final : public ConcreteConverter<NullConverter> {
+ public:
+ explicit NullConverter(const std::shared_ptr<DataType>& type) {
+ type_ = type;
+ builder_ = std::make_shared<NullBuilder>();
+ }
+
+ Status AppendValue(const rj::Value& json_obj) override {
+ if (json_obj.IsNull()) {
+ return AppendNull();
+ }
+ return JSONTypeError("null", json_obj.GetType());
+ }
+
+ std::shared_ptr<ArrayBuilder> builder() override { return builder_; }
+
+ private:
+ std::shared_ptr<NullBuilder> builder_;
+};
+
+// ------------------------------------------------------------------------
+// Converter for boolean arrays
+
+class BooleanConverter final : public ConcreteConverter<BooleanConverter> {
+ public:
+ explicit BooleanConverter(const std::shared_ptr<DataType>& type) {
+ type_ = type;
+ builder_ = std::make_shared<BooleanBuilder>();
+ }
+
+ Status AppendValue(const rj::Value& json_obj) override {
+ if (json_obj.IsNull()) {
+ return AppendNull();
+ }
+ if (json_obj.IsBool()) {
+ return builder_->Append(json_obj.GetBool());
+ }
+ if (json_obj.IsInt()) {
+ return builder_->Append(json_obj.GetInt() != 0);
+ }
+ return JSONTypeError("boolean", json_obj.GetType());
+ }
+
+ std::shared_ptr<ArrayBuilder> builder() override { return builder_; }
+
+ private:
+ std::shared_ptr<BooleanBuilder> builder_;
+};
+
+// ------------------------------------------------------------------------
+// Helpers for numeric converters
+
+// Convert single signed integer value (also {Date,Time}{32,64} and Timestamp)
+template <typename T>
+enable_if_physical_signed_integer<T, Status> ConvertNumber(const rj::Value& json_obj,
+ const DataType& type,
+ typename T::c_type* out) {
+ if (json_obj.IsInt64()) {
+ int64_t v64 = json_obj.GetInt64();
+ *out = static_cast<typename T::c_type>(v64);
+ if (*out == v64) {
+ return Status::OK();
+ } else {
+ return Status::Invalid("Value ", v64, " out of bounds for ", type);
+ }
+ } else {
+ *out = static_cast<typename T::c_type>(0);
+ return JSONTypeError("signed int", json_obj.GetType());
+ }
+}
+
+// Convert single unsigned integer value
+template <typename T>
+enable_if_physical_unsigned_integer<T, Status> ConvertNumber(const rj::Value& json_obj,
+ const DataType& type,
+ typename T::c_type* out) {
+ if (json_obj.IsUint64()) {
+ uint64_t v64 = json_obj.GetUint64();
+ *out = static_cast<typename T::c_type>(v64);
+ if (*out == v64) {
+ return Status::OK();
+ } else {
+ return Status::Invalid("Value ", v64, " out of bounds for ", type);
+ }
+ } else {
+ *out = static_cast<typename T::c_type>(0);
+ return JSONTypeError("unsigned int", json_obj.GetType());
+ }
+}
+
+// Convert single floating point value
+template <typename T>
+enable_if_physical_floating_point<T, Status> ConvertNumber(const rj::Value& json_obj,
+ const DataType& type,
+ typename T::c_type* out) {
+ if (json_obj.IsNumber()) {
+ *out = static_cast<typename T::c_type>(json_obj.GetDouble());
+ return Status::OK();
+ } else {
+ *out = static_cast<typename T::c_type>(0);
+ return JSONTypeError("number", json_obj.GetType());
+ }
+}
+
+// ------------------------------------------------------------------------
+// Converter for int arrays
+
+template <typename Type, typename BuilderType = typename TypeTraits<Type>::BuilderType>
+class IntegerConverter final
+ : public ConcreteConverter<IntegerConverter<Type, BuilderType>> {
+ using c_type = typename Type::c_type;
+
+ static constexpr auto is_signed = std::is_signed<c_type>::value;
+
+ public:
+ explicit IntegerConverter(const std::shared_ptr<DataType>& type) { this->type_ = type; }
+
+ Status Init() override { return this->MakeConcreteBuilder(&builder_); }
+
+ Status AppendValue(const rj::Value& json_obj) override {
+ if (json_obj.IsNull()) {
+ return this->AppendNull();
+ }
+ c_type value;
+ RETURN_NOT_OK(ConvertNumber<Type>(json_obj, *this->type_, &value));
+ return builder_->Append(value);
+ }
+
+ std::shared_ptr<ArrayBuilder> builder() override { return builder_; }
+
+ private:
+ std::shared_ptr<BuilderType> builder_;
+};
+
+// ------------------------------------------------------------------------
+// Converter for float arrays
+
+template <typename Type, typename BuilderType = typename TypeTraits<Type>::BuilderType>
+class FloatConverter final : public ConcreteConverter<FloatConverter<Type, BuilderType>> {
+ using c_type = typename Type::c_type;
+
+ public:
+ explicit FloatConverter(const std::shared_ptr<DataType>& type) { this->type_ = type; }
+
+ Status Init() override { return this->MakeConcreteBuilder(&builder_); }
+
+ Status AppendValue(const rj::Value& json_obj) override {
+ if (json_obj.IsNull()) {
+ return this->AppendNull();
+ }
+ c_type value;
+ RETURN_NOT_OK(ConvertNumber<Type>(json_obj, *this->type_, &value));
+ return builder_->Append(value);
+ }
+
+ std::shared_ptr<ArrayBuilder> builder() override { return builder_; }
+
+ private:
+ std::shared_ptr<BuilderType> builder_;
+};
+
+// ------------------------------------------------------------------------
+// Converter for decimal arrays
+
+template <typename DecimalSubtype, typename DecimalValue, typename BuilderType>
+class DecimalConverter final
+ : public ConcreteConverter<
+ DecimalConverter<DecimalSubtype, DecimalValue, BuilderType>> {
+ public:
+ explicit DecimalConverter(const std::shared_ptr<DataType>& type) {
+ this->type_ = type;
+ decimal_type_ = &checked_cast<const DecimalSubtype&>(*this->value_type());
+ }
+
+ Status Init() override { return this->MakeConcreteBuilder(&builder_); }
+
+ Status AppendValue(const rj::Value& json_obj) override {
+ if (json_obj.IsNull()) {
+ return this->AppendNull();
+ }
+ if (json_obj.IsString()) {
+ int32_t precision, scale;
+ DecimalValue d;
+ auto view = util::string_view(json_obj.GetString(), json_obj.GetStringLength());
+ RETURN_NOT_OK(DecimalValue::FromString(view, &d, &precision, &scale));
+ if (scale != decimal_type_->scale()) {
+ return Status::Invalid("Invalid scale for decimal: expected ",
+ decimal_type_->scale(), ", got ", scale);
+ }
+ return builder_->Append(d);
+ }
+ return JSONTypeError("decimal string", json_obj.GetType());
+ }
+
+ std::shared_ptr<ArrayBuilder> builder() override { return builder_; }
+
+ private:
+ std::shared_ptr<BuilderType> builder_;
+ const DecimalSubtype* decimal_type_;
+};
+
+template <typename BuilderType = typename TypeTraits<Decimal128Type>::BuilderType>
+using Decimal128Converter = DecimalConverter<Decimal128Type, Decimal128, BuilderType>;
+template <typename BuilderType = typename TypeTraits<Decimal256Type>::BuilderType>
+using Decimal256Converter = DecimalConverter<Decimal256Type, Decimal256, BuilderType>;
+
+// ------------------------------------------------------------------------
+// Converter for timestamp arrays
+
+class TimestampConverter final : public ConcreteConverter<TimestampConverter> {
+ public:
+ explicit TimestampConverter(const std::shared_ptr<DataType>& type)
+ : timestamp_type_{checked_cast<const TimestampType*>(type.get())} {
+ this->type_ = type;
+ builder_ = std::make_shared<TimestampBuilder>(type, default_memory_pool());
+ }
+
+ Status AppendValue(const rj::Value& json_obj) override {
+ if (json_obj.IsNull()) {
+ return this->AppendNull();
+ }
+ int64_t value;
+ if (json_obj.IsNumber()) {
+ RETURN_NOT_OK(ConvertNumber<Int64Type>(json_obj, *this->type_, &value));
+ } else if (json_obj.IsString()) {
+ util::string_view view(json_obj.GetString(), json_obj.GetStringLength());
+ if (!ParseValue(*timestamp_type_, view.data(), view.size(), &value)) {
+ return Status::Invalid("couldn't parse timestamp from ", view);
+ }
+ } else {
+ return JSONTypeError("timestamp", json_obj.GetType());
+ }
+ return builder_->Append(value);
+ }
+
+ std::shared_ptr<ArrayBuilder> builder() override { return builder_; }
+
+ private:
+ const TimestampType* timestamp_type_;
+ std::shared_ptr<TimestampBuilder> builder_;
+};
+
+// ------------------------------------------------------------------------
+// Converter for day-time interval arrays
+
+class DayTimeIntervalConverter final
+ : public ConcreteConverter<DayTimeIntervalConverter> {
+ public:
+ explicit DayTimeIntervalConverter(const std::shared_ptr<DataType>& type) {
+ this->type_ = type;
+ builder_ = std::make_shared<DayTimeIntervalBuilder>(default_memory_pool());
+ }
+
+ Status AppendValue(const rj::Value& json_obj) override {
+ if (json_obj.IsNull()) {
+ return this->AppendNull();
+ }
+ DayTimeIntervalType::DayMilliseconds value;
+ if (!json_obj.IsArray()) {
+ return JSONTypeError("array", json_obj.GetType());
+ }
+ if (json_obj.Size() != 2) {
+ return Status::Invalid(
+ "day time interval pair must have exactly two elements, had ", json_obj.Size());
+ }
+ RETURN_NOT_OK(ConvertNumber<Int32Type>(json_obj[0], *this->type_, &value.days));
+ RETURN_NOT_OK(
+ ConvertNumber<Int32Type>(json_obj[1], *this->type_, &value.milliseconds));
+ return builder_->Append(value);
+ }
+
+ std::shared_ptr<ArrayBuilder> builder() override { return builder_; }
+
+ private:
+ std::shared_ptr<DayTimeIntervalBuilder> builder_;
+};
+
+class MonthDayNanoIntervalConverter final
+ : public ConcreteConverter<MonthDayNanoIntervalConverter> {
+ public:
+ explicit MonthDayNanoIntervalConverter(const std::shared_ptr<DataType>& type) {
+ this->type_ = type;
+ builder_ = std::make_shared<MonthDayNanoIntervalBuilder>(default_memory_pool());
+ }
+
+ Status AppendValue(const rj::Value& json_obj) override {
+ if (json_obj.IsNull()) {
+ return this->AppendNull();
+ }
+ MonthDayNanoIntervalType::MonthDayNanos value;
+ if (!json_obj.IsArray()) {
+ return JSONTypeError("array", json_obj.GetType());
+ }
+ if (json_obj.Size() != 3) {
+ return Status::Invalid(
+ "month_day_nano_interval must have exactly 3 elements, had ", json_obj.Size());
+ }
+ RETURN_NOT_OK(ConvertNumber<Int32Type>(json_obj[0], *this->type_, &value.months));
+ RETURN_NOT_OK(ConvertNumber<Int32Type>(json_obj[1], *this->type_, &value.days));
+ RETURN_NOT_OK(
+ ConvertNumber<Int64Type>(json_obj[2], *this->type_, &value.nanoseconds));
+
+ return builder_->Append(value);
+ }
+
+ std::shared_ptr<ArrayBuilder> builder() override { return builder_; }
+
+ private:
+ std::shared_ptr<MonthDayNanoIntervalBuilder> builder_;
+};
+
+// ------------------------------------------------------------------------
+// Converter for binary and string arrays
+
+template <typename Type, typename BuilderType = typename TypeTraits<Type>::BuilderType>
+class StringConverter final
+ : public ConcreteConverter<StringConverter<Type, BuilderType>> {
+ public:
+ explicit StringConverter(const std::shared_ptr<DataType>& type) { this->type_ = type; }
+
+ Status Init() override { return this->MakeConcreteBuilder(&builder_); }
+
+ Status AppendValue(const rj::Value& json_obj) override {
+ if (json_obj.IsNull()) {
+ return this->AppendNull();
+ }
+ if (json_obj.IsString()) {
+ auto view = util::string_view(json_obj.GetString(), json_obj.GetStringLength());
+ return builder_->Append(view);
+ } else {
+ return JSONTypeError("string", json_obj.GetType());
+ }
+ }
+
+ std::shared_ptr<ArrayBuilder> builder() override { return builder_; }
+
+ private:
+ std::shared_ptr<BuilderType> builder_;
+};
+
+// ------------------------------------------------------------------------
+// Converter for fixed-size binary arrays
+
+template <typename BuilderType = typename TypeTraits<FixedSizeBinaryType>::BuilderType>
+class FixedSizeBinaryConverter final
+ : public ConcreteConverter<FixedSizeBinaryConverter<BuilderType>> {
+ public:
+ explicit FixedSizeBinaryConverter(const std::shared_ptr<DataType>& type) {
+ this->type_ = type;
+ }
+
+ Status Init() override { return this->MakeConcreteBuilder(&builder_); }
+
+ Status AppendValue(const rj::Value& json_obj) override {
+ if (json_obj.IsNull()) {
+ return this->AppendNull();
+ }
+ if (json_obj.IsString()) {
+ auto view = util::string_view(json_obj.GetString(), json_obj.GetStringLength());
+ if (view.length() != static_cast<size_t>(builder_->byte_width())) {
+ std::stringstream ss;
+ ss << "Invalid string length " << view.length() << " in JSON input for "
+ << this->type_->ToString();
+ return Status::Invalid(ss.str());
+ }
+ return builder_->Append(view);
+ } else {
+ return JSONTypeError("string", json_obj.GetType());
+ }
+ }
+
+ std::shared_ptr<ArrayBuilder> builder() override { return builder_; }
+
+ private:
+ std::shared_ptr<BuilderType> builder_;
+};
+
+// ------------------------------------------------------------------------
+// Converter for list arrays
+
+template <typename TYPE>
+class ListConverter final : public ConcreteConverter<ListConverter<TYPE>> {
+ public:
+ using BuilderType = typename TypeTraits<TYPE>::BuilderType;
+
+ explicit ListConverter(const std::shared_ptr<DataType>& type) { this->type_ = type; }
+
+ Status Init() override {
+ const auto& list_type = checked_cast<const TYPE&>(*this->type_);
+ RETURN_NOT_OK(GetConverter(list_type.value_type(), &child_converter_));
+ auto child_builder = child_converter_->builder();
+ builder_ =
+ std::make_shared<BuilderType>(default_memory_pool(), child_builder, this->type_);
+ return Status::OK();
+ }
+
+ Status AppendValue(const rj::Value& json_obj) override {
+ if (json_obj.IsNull()) {
+ return this->AppendNull();
+ }
+ RETURN_NOT_OK(builder_->Append());
+ // Extend the child converter with this JSON array
+ return child_converter_->AppendValues(json_obj);
+ }
+
+ std::shared_ptr<ArrayBuilder> builder() override { return builder_; }
+
+ private:
+ std::shared_ptr<BuilderType> builder_;
+ std::shared_ptr<Converter> child_converter_;
+};
+
+// ------------------------------------------------------------------------
+// Converter for map arrays
+
+class MapConverter final : public ConcreteConverter<MapConverter> {
+ public:
+ explicit MapConverter(const std::shared_ptr<DataType>& type) { type_ = type; }
+
+ Status Init() override {
+ const auto& map_type = checked_cast<const MapType&>(*type_);
+ RETURN_NOT_OK(GetConverter(map_type.key_type(), &key_converter_));
+ RETURN_NOT_OK(GetConverter(map_type.item_type(), &item_converter_));
+ auto key_builder = key_converter_->builder();
+ auto item_builder = item_converter_->builder();
+ builder_ = std::make_shared<MapBuilder>(default_memory_pool(), key_builder,
+ item_builder, type_);
+ return Status::OK();
+ }
+
+ Status AppendValue(const rj::Value& json_obj) override {
+ if (json_obj.IsNull()) {
+ return this->AppendNull();
+ }
+ RETURN_NOT_OK(builder_->Append());
+ if (!json_obj.IsArray()) {
+ return JSONTypeError("array", json_obj.GetType());
+ }
+ auto size = json_obj.Size();
+ for (uint32_t i = 0; i < size; ++i) {
+ const auto& json_pair = json_obj[i];
+ if (!json_pair.IsArray()) {
+ return JSONTypeError("array", json_pair.GetType());
+ }
+ if (json_pair.Size() != 2) {
+ return Status::Invalid("key item pair must have exactly two elements, had ",
+ json_pair.Size());
+ }
+ if (json_pair[0].IsNull()) {
+ return Status::Invalid("null key is invalid");
+ }
+ RETURN_NOT_OK(key_converter_->AppendValue(json_pair[0]));
+ RETURN_NOT_OK(item_converter_->AppendValue(json_pair[1]));
+ }
+ return Status::OK();
+ }
+
+ std::shared_ptr<ArrayBuilder> builder() override { return builder_; }
+
+ private:
+ std::shared_ptr<MapBuilder> builder_;
+ std::shared_ptr<Converter> key_converter_, item_converter_;
+};
+
+// ------------------------------------------------------------------------
+// Converter for fixed size list arrays
+
+class FixedSizeListConverter final : public ConcreteConverter<FixedSizeListConverter> {
+ public:
+ explicit FixedSizeListConverter(const std::shared_ptr<DataType>& type) { type_ = type; }
+
+ Status Init() override {
+ const auto& list_type = checked_cast<const FixedSizeListType&>(*type_);
+ list_size_ = list_type.list_size();
+ RETURN_NOT_OK(GetConverter(list_type.value_type(), &child_converter_));
+ auto child_builder = child_converter_->builder();
+ builder_ = std::make_shared<FixedSizeListBuilder>(default_memory_pool(),
+ child_builder, type_);
+ return Status::OK();
+ }
+
+ Status AppendValue(const rj::Value& json_obj) override {
+ if (json_obj.IsNull()) {
+ return this->AppendNull();
+ }
+ RETURN_NOT_OK(builder_->Append());
+ // Extend the child converter with this JSON array
+ RETURN_NOT_OK(child_converter_->AppendValues(json_obj));
+ if (json_obj.GetArray().Size() != static_cast<rj::SizeType>(list_size_)) {
+ return Status::Invalid("incorrect list size ", json_obj.GetArray().Size());
+ }
+ return Status::OK();
+ }
+
+ std::shared_ptr<ArrayBuilder> builder() override { return builder_; }
+
+ private:
+ int32_t list_size_;
+ std::shared_ptr<FixedSizeListBuilder> builder_;
+ std::shared_ptr<Converter> child_converter_;
+};
+
+// ------------------------------------------------------------------------
+// Converter for struct arrays
+
+class StructConverter final : public ConcreteConverter<StructConverter> {
+ public:
+ explicit StructConverter(const std::shared_ptr<DataType>& type) { type_ = type; }
+
+ Status Init() override {
+ std::vector<std::shared_ptr<ArrayBuilder>> child_builders;
+ for (const auto& field : type_->fields()) {
+ std::shared_ptr<Converter> child_converter;
+ RETURN_NOT_OK(GetConverter(field->type(), &child_converter));
+ child_converters_.push_back(child_converter);
+ child_builders.push_back(child_converter->builder());
+ }
+ builder_ = std::make_shared<StructBuilder>(type_, default_memory_pool(),
+ std::move(child_builders));
+ return Status::OK();
+ }
+
+ // Append a JSON value that is either an array of N elements in order
+ // or an object mapping struct names to values (omitted struct members
+ // are mapped to null).
+ Status AppendValue(const rj::Value& json_obj) override {
+ if (json_obj.IsNull()) {
+ return this->AppendNull();
+ }
+ if (json_obj.IsArray()) {
+ auto size = json_obj.Size();
+ auto expected_size = static_cast<uint32_t>(type_->num_fields());
+ if (size != expected_size) {
+ return Status::Invalid("Expected array of size ", expected_size,
+ ", got array of size ", size);
+ }
+ for (uint32_t i = 0; i < size; ++i) {
+ RETURN_NOT_OK(child_converters_[i]->AppendValue(json_obj[i]));
+ }
+ return builder_->Append();
+ }
+ if (json_obj.IsObject()) {
+ auto remaining = json_obj.MemberCount();
+ auto num_children = type_->num_fields();
+ for (int32_t i = 0; i < num_children; ++i) {
+ const auto& field = type_->field(i);
+ auto it = json_obj.FindMember(field->name());
+ if (it != json_obj.MemberEnd()) {
+ --remaining;
+ RETURN_NOT_OK(child_converters_[i]->AppendValue(it->value));
+ } else {
+ RETURN_NOT_OK(child_converters_[i]->AppendNull());
+ }
+ }
+ if (remaining > 0) {
+ rj::StringBuffer sb;
+ rj::Writer<rj::StringBuffer> writer(sb);
+ json_obj.Accept(writer);
+ return Status::Invalid("Unexpected members in JSON object for type ",
+ type_->ToString(), " Object: ", sb.GetString());
+ }
+ return builder_->Append();
+ }
+ return JSONTypeError("array or object", json_obj.GetType());
+ }
+
+ std::shared_ptr<ArrayBuilder> builder() override { return builder_; }
+
+ private:
+ std::shared_ptr<StructBuilder> builder_;
+ std::vector<std::shared_ptr<Converter>> child_converters_;
+};
+
+// ------------------------------------------------------------------------
+// Converter for union arrays
+
+class UnionConverter final : public ConcreteConverter<UnionConverter> {
+ public:
+ explicit UnionConverter(const std::shared_ptr<DataType>& type) { type_ = type; }
+
+ Status Init() override {
+ auto union_type = checked_cast<const UnionType*>(type_.get());
+ mode_ = union_type->mode();
+ type_id_to_child_num_.clear();
+ type_id_to_child_num_.resize(union_type->max_type_code() + 1, -1);
+ int child_i = 0;
+ for (auto type_id : union_type->type_codes()) {
+ type_id_to_child_num_[type_id] = child_i++;
+ }
+ std::vector<std::shared_ptr<ArrayBuilder>> child_builders;
+ for (const auto& field : type_->fields()) {
+ std::shared_ptr<Converter> child_converter;
+ RETURN_NOT_OK(GetConverter(field->type(), &child_converter));
+ child_converters_.push_back(child_converter);
+ child_builders.push_back(child_converter->builder());
+ }
+ if (mode_ == UnionMode::DENSE) {
+ builder_ = std::make_shared<DenseUnionBuilder>(default_memory_pool(),
+ std::move(child_builders), type_);
+ } else {
+ builder_ = std::make_shared<SparseUnionBuilder>(default_memory_pool(),
+ std::move(child_builders), type_);
+ }
+ return Status::OK();
+ }
+
+ // Append a JSON value that must be a 2-long array, containing the type_id
+ // and value of the UnionArray's slot.
+ Status AppendValue(const rj::Value& json_obj) override {
+ if (json_obj.IsNull()) {
+ return this->AppendNull();
+ }
+ if (!json_obj.IsArray()) {
+ return JSONTypeError("array", json_obj.GetType());
+ }
+ if (json_obj.Size() != 2) {
+ return Status::Invalid("Expected [type_id, value] pair, got array of size ",
+ json_obj.Size());
+ }
+ const auto& id_obj = json_obj[0];
+ if (!id_obj.IsInt()) {
+ return JSONTypeError("int", id_obj.GetType());
+ }
+
+ auto id = static_cast<int8_t>(id_obj.GetInt());
+ auto child_num = type_id_to_child_num_[id];
+ if (child_num == -1) {
+ return Status::Invalid("type_id ", id, " not found in ", *type_);
+ }
+
+ auto child_converter = child_converters_[child_num];
+ if (mode_ == UnionMode::SPARSE) {
+ RETURN_NOT_OK(checked_cast<SparseUnionBuilder&>(*builder_).Append(id));
+ for (auto&& other_converter : child_converters_) {
+ if (other_converter != child_converter) {
+ RETURN_NOT_OK(other_converter->AppendNull());
+ }
+ }
+ } else {
+ RETURN_NOT_OK(checked_cast<DenseUnionBuilder&>(*builder_).Append(id));
+ }
+ return child_converter->AppendValue(json_obj[1]);
+ }
+
+ std::shared_ptr<ArrayBuilder> builder() override { return builder_; }
+
+ private:
+ UnionMode::type mode_;
+ std::shared_ptr<ArrayBuilder> builder_;
+ std::vector<std::shared_ptr<Converter>> child_converters_;
+ std::vector<int8_t> type_id_to_child_num_;
+};
+
+// ------------------------------------------------------------------------
+// General conversion functions
+
+Status ConversionNotImplemented(const std::shared_ptr<DataType>& type) {
+ return Status::NotImplemented("JSON conversion to ", type->ToString(),
+ " not implemented");
+}
+
+Status GetDictConverter(const std::shared_ptr<DataType>& type,
+ std::shared_ptr<Converter>* out) {
+ std::shared_ptr<Converter> res;
+
+ const auto value_type = checked_cast<const DictionaryType&>(*type).value_type();
+
+#define SIMPLE_CONVERTER_CASE(ID, CLASS, TYPE) \
+ case ID: \
+ res = std::make_shared<CLASS<DictionaryBuilder<TYPE>>>(type); \
+ break;
+
+#define PARAM_CONVERTER_CASE(ID, CLASS, TYPE) \
+ case ID: \
+ res = std::make_shared<CLASS<TYPE, DictionaryBuilder<TYPE>>>(type); \
+ break;
+
+ switch (value_type->id()) {
+ PARAM_CONVERTER_CASE(Type::INT8, IntegerConverter, Int8Type)
+ PARAM_CONVERTER_CASE(Type::INT16, IntegerConverter, Int16Type)
+ PARAM_CONVERTER_CASE(Type::INT32, IntegerConverter, Int32Type)
+ PARAM_CONVERTER_CASE(Type::INT64, IntegerConverter, Int64Type)
+ PARAM_CONVERTER_CASE(Type::UINT8, IntegerConverter, UInt8Type)
+ PARAM_CONVERTER_CASE(Type::UINT16, IntegerConverter, UInt16Type)
+ PARAM_CONVERTER_CASE(Type::UINT32, IntegerConverter, UInt32Type)
+ PARAM_CONVERTER_CASE(Type::UINT64, IntegerConverter, UInt64Type)
+ PARAM_CONVERTER_CASE(Type::FLOAT, FloatConverter, FloatType)
+ PARAM_CONVERTER_CASE(Type::DOUBLE, FloatConverter, DoubleType)
+ PARAM_CONVERTER_CASE(Type::STRING, StringConverter, StringType)
+ PARAM_CONVERTER_CASE(Type::BINARY, StringConverter, BinaryType)
+ PARAM_CONVERTER_CASE(Type::LARGE_STRING, StringConverter, LargeStringType)
+ PARAM_CONVERTER_CASE(Type::LARGE_BINARY, StringConverter, LargeBinaryType)
+ SIMPLE_CONVERTER_CASE(Type::FIXED_SIZE_BINARY, FixedSizeBinaryConverter,
+ FixedSizeBinaryType)
+ SIMPLE_CONVERTER_CASE(Type::DECIMAL128, Decimal128Converter, Decimal128Type)
+ SIMPLE_CONVERTER_CASE(Type::DECIMAL256, Decimal256Converter, Decimal256Type)
+ default:
+ return ConversionNotImplemented(type);
+ }
+
+#undef SIMPLE_CONVERTER_CASE
+#undef PARAM_CONVERTER_CASE
+
+ RETURN_NOT_OK(res->Init());
+ *out = res;
+ return Status::OK();
+}
+
+Status GetConverter(const std::shared_ptr<DataType>& type,
+ std::shared_ptr<Converter>* out) {
+ if (type->id() == Type::DICTIONARY) {
+ return GetDictConverter(type, out);
+ }
+
+ std::shared_ptr<Converter> res;
+
+#define SIMPLE_CONVERTER_CASE(ID, CLASS) \
+ case ID: \
+ res = std::make_shared<CLASS>(type); \
+ break;
+
+ switch (type->id()) {
+ SIMPLE_CONVERTER_CASE(Type::INT8, IntegerConverter<Int8Type>)
+ SIMPLE_CONVERTER_CASE(Type::INT16, IntegerConverter<Int16Type>)
+ SIMPLE_CONVERTER_CASE(Type::INT32, IntegerConverter<Int32Type>)
+ SIMPLE_CONVERTER_CASE(Type::INT64, IntegerConverter<Int64Type>)
+ SIMPLE_CONVERTER_CASE(Type::UINT8, IntegerConverter<UInt8Type>)
+ SIMPLE_CONVERTER_CASE(Type::UINT16, IntegerConverter<UInt16Type>)
+ SIMPLE_CONVERTER_CASE(Type::UINT32, IntegerConverter<UInt32Type>)
+ SIMPLE_CONVERTER_CASE(Type::UINT64, IntegerConverter<UInt64Type>)
+ SIMPLE_CONVERTER_CASE(Type::TIMESTAMP, TimestampConverter)
+ SIMPLE_CONVERTER_CASE(Type::DATE32, IntegerConverter<Date32Type>)
+ SIMPLE_CONVERTER_CASE(Type::DATE64, IntegerConverter<Date64Type>)
+ SIMPLE_CONVERTER_CASE(Type::TIME32, IntegerConverter<Time32Type>)
+ SIMPLE_CONVERTER_CASE(Type::TIME64, IntegerConverter<Time64Type>)
+ SIMPLE_CONVERTER_CASE(Type::DURATION, IntegerConverter<DurationType>)
+ SIMPLE_CONVERTER_CASE(Type::NA, NullConverter)
+ SIMPLE_CONVERTER_CASE(Type::BOOL, BooleanConverter)
+ SIMPLE_CONVERTER_CASE(Type::HALF_FLOAT, IntegerConverter<HalfFloatType>)
+ SIMPLE_CONVERTER_CASE(Type::FLOAT, FloatConverter<FloatType>)
+ SIMPLE_CONVERTER_CASE(Type::DOUBLE, FloatConverter<DoubleType>)
+ SIMPLE_CONVERTER_CASE(Type::LIST, ListConverter<ListType>)
+ SIMPLE_CONVERTER_CASE(Type::LARGE_LIST, ListConverter<LargeListType>)
+ SIMPLE_CONVERTER_CASE(Type::MAP, MapConverter)
+ SIMPLE_CONVERTER_CASE(Type::FIXED_SIZE_LIST, FixedSizeListConverter)
+ SIMPLE_CONVERTER_CASE(Type::STRUCT, StructConverter)
+ SIMPLE_CONVERTER_CASE(Type::STRING, StringConverter<StringType>)
+ SIMPLE_CONVERTER_CASE(Type::BINARY, StringConverter<BinaryType>)
+ SIMPLE_CONVERTER_CASE(Type::LARGE_STRING, StringConverter<LargeStringType>)
+ SIMPLE_CONVERTER_CASE(Type::LARGE_BINARY, StringConverter<LargeBinaryType>)
+ SIMPLE_CONVERTER_CASE(Type::FIXED_SIZE_BINARY, FixedSizeBinaryConverter<>)
+ SIMPLE_CONVERTER_CASE(Type::DECIMAL128, Decimal128Converter<>)
+ SIMPLE_CONVERTER_CASE(Type::DECIMAL256, Decimal256Converter<>)
+ SIMPLE_CONVERTER_CASE(Type::SPARSE_UNION, UnionConverter)
+ SIMPLE_CONVERTER_CASE(Type::DENSE_UNION, UnionConverter)
+ SIMPLE_CONVERTER_CASE(Type::INTERVAL_MONTHS, IntegerConverter<MonthIntervalType>)
+ SIMPLE_CONVERTER_CASE(Type::INTERVAL_DAY_TIME, DayTimeIntervalConverter)
+ SIMPLE_CONVERTER_CASE(Type::INTERVAL_MONTH_DAY_NANO, MonthDayNanoIntervalConverter)
+ default:
+ return ConversionNotImplemented(type);
+ }
+
+#undef SIMPLE_CONVERTER_CASE
+
+ RETURN_NOT_OK(res->Init());
+ *out = res;
+ return Status::OK();
+}
+
+} // namespace
+
+Status ArrayFromJSON(const std::shared_ptr<DataType>& type, util::string_view json_string,
+ std::shared_ptr<Array>* out) {
+ std::shared_ptr<Converter> converter;
+ RETURN_NOT_OK(GetConverter(type, &converter));
+
+ rj::Document json_doc;
+ json_doc.Parse<kParseFlags>(json_string.data(), json_string.length());
+ if (json_doc.HasParseError()) {
+ return Status::Invalid("JSON parse error at offset ", json_doc.GetErrorOffset(), ": ",
+ GetParseError_En(json_doc.GetParseError()));
+ }
+
+ // The JSON document should be an array, append it
+ RETURN_NOT_OK(converter->AppendValues(json_doc));
+ return converter->Finish(out);
+}
+
+Status ArrayFromJSON(const std::shared_ptr<DataType>& type,
+ const std::string& json_string, std::shared_ptr<Array>* out) {
+ return ArrayFromJSON(type, util::string_view(json_string), out);
+}
+
+Status ArrayFromJSON(const std::shared_ptr<DataType>& type, const char* json_string,
+ std::shared_ptr<Array>* out) {
+ return ArrayFromJSON(type, util::string_view(json_string), out);
+}
+
+Status DictArrayFromJSON(const std::shared_ptr<DataType>& type,
+ util::string_view indices_json,
+ util::string_view dictionary_json, std::shared_ptr<Array>* out) {
+ if (type->id() != Type::DICTIONARY) {
+ return Status::TypeError("DictArrayFromJSON requires dictionary type, got ", *type);
+ }
+
+ const auto& dictionary_type = checked_cast<const DictionaryType&>(*type);
+
+ std::shared_ptr<Array> indices, dictionary;
+ RETURN_NOT_OK(ArrayFromJSON(dictionary_type.index_type(), indices_json, &indices));
+ RETURN_NOT_OK(
+ ArrayFromJSON(dictionary_type.value_type(), dictionary_json, &dictionary));
+
+ return DictionaryArray::FromArrays(type, std::move(indices), std::move(dictionary))
+ .Value(out);
+}
+
+Status ScalarFromJSON(const std::shared_ptr<DataType>& type,
+ util::string_view json_string, std::shared_ptr<Scalar>* out) {
+ std::shared_ptr<Converter> converter;
+ RETURN_NOT_OK(GetConverter(type, &converter));
+
+ rj::Document json_doc;
+ json_doc.Parse<kParseFlags>(json_string.data(), json_string.length());
+ if (json_doc.HasParseError()) {
+ return Status::Invalid("JSON parse error at offset ", json_doc.GetErrorOffset(), ": ",
+ GetParseError_En(json_doc.GetParseError()));
+ }
+
+ std::shared_ptr<Array> array;
+ RETURN_NOT_OK(converter->AppendValue(json_doc));
+ RETURN_NOT_OK(converter->Finish(&array));
+ DCHECK_EQ(array->length(), 1);
+ ARROW_ASSIGN_OR_RAISE(*out, array->GetScalar(0));
+ return Status::OK();
+}
+
+Status DictScalarFromJSON(const std::shared_ptr<DataType>& type,
+ util::string_view index_json, util::string_view dictionary_json,
+ std::shared_ptr<Scalar>* out) {
+ if (type->id() != Type::DICTIONARY) {
+ return Status::TypeError("DictScalarFromJSON requires dictionary type, got ", *type);
+ }
+
+ const auto& dictionary_type = checked_cast<const DictionaryType&>(*type);
+
+ std::shared_ptr<Scalar> index;
+ std::shared_ptr<Array> dictionary;
+ RETURN_NOT_OK(ScalarFromJSON(dictionary_type.index_type(), index_json, &index));
+ RETURN_NOT_OK(
+ ArrayFromJSON(dictionary_type.value_type(), dictionary_json, &dictionary));
+
+ *out = DictionaryScalar::Make(std::move(index), std::move(dictionary));
+ return Status::OK();
+}
+
+} // namespace json
+} // namespace internal
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/json_simple.h b/src/arrow/cpp/src/arrow/ipc/json_simple.h
new file mode 100644
index 000000000..8269bd653
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/json_simple.h
@@ -0,0 +1,66 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Implement a simple JSON representation format for arrays
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/status.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Array;
+class DataType;
+
+namespace ipc {
+namespace internal {
+namespace json {
+
+ARROW_EXPORT
+Status ArrayFromJSON(const std::shared_ptr<DataType>&, const std::string& json,
+ std::shared_ptr<Array>* out);
+
+ARROW_EXPORT
+Status ArrayFromJSON(const std::shared_ptr<DataType>&, util::string_view json,
+ std::shared_ptr<Array>* out);
+
+ARROW_EXPORT
+Status ArrayFromJSON(const std::shared_ptr<DataType>&, const char* json,
+ std::shared_ptr<Array>* out);
+
+ARROW_EXPORT
+Status DictArrayFromJSON(const std::shared_ptr<DataType>&, util::string_view indices_json,
+ util::string_view dictionary_json, std::shared_ptr<Array>* out);
+
+ARROW_EXPORT
+Status ScalarFromJSON(const std::shared_ptr<DataType>&, util::string_view json,
+ std::shared_ptr<Scalar>* out);
+
+ARROW_EXPORT
+Status DictScalarFromJSON(const std::shared_ptr<DataType>&, util::string_view index_json,
+ util::string_view dictionary_json,
+ std::shared_ptr<Scalar>* out);
+
+} // namespace json
+} // namespace internal
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/json_simple_test.cc b/src/arrow/cpp/src/arrow/ipc/json_simple_test.cc
new file mode 100644
index 000000000..34c300faa
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/json_simple_test.cc
@@ -0,0 +1,1415 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_decimal.h"
+#include "arrow/array/builder_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/builder_time.h"
+#include "arrow/ipc/json_simple.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bitmap_builders.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+
+#if defined(_MSC_VER)
+// "warning C4307: '+': integral constant overflow"
+#pragma warning(disable : 4307)
+#endif
+
+namespace arrow {
+namespace ipc {
+namespace internal {
+namespace json {
+
+using ::arrow::internal::BytesToBits;
+using ::arrow::internal::checked_cast;
+using ::arrow::internal::checked_pointer_cast;
+
+// Avoid undefined behaviour on signed overflow
+template <typename Signed>
+Signed SafeSignedAdd(Signed u, Signed v) {
+ using Unsigned = typename std::make_unsigned<Signed>::type;
+ return static_cast<Signed>(static_cast<Unsigned>(u) + static_cast<Unsigned>(v));
+}
+
+// Special case for 8-bit ints (must output their decimal value, not the
+// corresponding ASCII character)
+void JSONArrayInternal(std::ostream* ss, int8_t value) {
+ *ss << static_cast<int16_t>(value);
+}
+
+void JSONArrayInternal(std::ostream* ss, uint8_t value) {
+ *ss << static_cast<int16_t>(value);
+}
+
+template <typename Value>
+void JSONArrayInternal(std::ostream* ss, Value&& value) {
+ *ss << value;
+}
+
+template <typename Value, typename... Tail>
+void JSONArrayInternal(std::ostream* ss, Value&& value, Tail&&... tail) {
+ JSONArrayInternal(ss, std::forward<Value>(value));
+ *ss << ", ";
+ JSONArrayInternal(ss, std::forward<Tail>(tail)...);
+}
+
+template <typename... Args>
+std::string JSONArray(Args&&... args) {
+ std::stringstream ss;
+ ss << "[";
+ JSONArrayInternal(&ss, std::forward<Args>(args)...);
+ ss << "]";
+ return ss.str();
+}
+
+template <typename T, typename C_TYPE = typename T::c_type>
+void AssertJSONArray(const std::shared_ptr<DataType>& type, const std::string& json,
+ const std::vector<C_TYPE>& values) {
+ std::shared_ptr<Array> actual, expected;
+
+ ASSERT_OK(ArrayFromJSON(type, json, &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<T, C_TYPE>(type, values, &expected);
+ AssertArraysEqual(*expected, *actual);
+}
+
+template <typename T, typename C_TYPE = typename T::c_type>
+void AssertJSONArray(const std::shared_ptr<DataType>& type, const std::string& json,
+ const std::vector<bool>& is_valid,
+ const std::vector<C_TYPE>& values) {
+ std::shared_ptr<Array> actual, expected;
+
+ ASSERT_OK(ArrayFromJSON(type, json, &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<T, C_TYPE>(type, is_valid, values, &expected);
+ AssertArraysEqual(*expected, *actual);
+}
+
+void AssertJSONDictArray(const std::shared_ptr<DataType>& index_type,
+ const std::shared_ptr<DataType>& value_type,
+ const std::string& json,
+ const std::string& expected_indices_json,
+ const std::string& expected_values_json) {
+ auto type = dictionary(index_type, value_type);
+ std::shared_ptr<Array> actual, expected_indices, expected_values;
+
+ ASSERT_OK(ArrayFromJSON(index_type, expected_indices_json, &expected_indices));
+ ASSERT_OK(ArrayFromJSON(value_type, expected_values_json, &expected_values));
+
+ ASSERT_OK(ArrayFromJSON(type, json, &actual));
+ ASSERT_OK(actual->ValidateFull());
+
+ const auto& dict_array = checked_cast<const DictionaryArray&>(*actual);
+ AssertArraysEqual(*expected_indices, *dict_array.indices());
+ AssertArraysEqual(*expected_values, *dict_array.dictionary());
+}
+
+template <typename T, typename C_TYPE = typename T::c_type>
+void AssertJSONScalar(const std::shared_ptr<DataType>& type, const std::string& json,
+ const bool is_valid, const C_TYPE value) {
+ SCOPED_TRACE(json);
+ std::shared_ptr<Scalar> actual, expected;
+
+ ASSERT_OK(ScalarFromJSON(type, json, &actual));
+ if (is_valid) {
+ ASSERT_OK_AND_ASSIGN(expected, MakeScalar(type, value));
+ } else {
+ expected = MakeNullScalar(type);
+ }
+ AssertScalarsEqual(*expected, *actual, /*verbose=*/true);
+}
+
+TEST(TestHelper, JSONArray) {
+ // Test the JSONArray helper func
+ std::string s =
+ JSONArray(123, -4.5, static_cast<int8_t>(-12), static_cast<uint8_t>(34));
+ ASSERT_EQ(s, "[123, -4.5, -12, 34]");
+ s = JSONArray(9223372036854775807LL, 9223372036854775808ULL, -9223372036854775807LL - 1,
+ 18446744073709551615ULL);
+ ASSERT_EQ(s,
+ "[9223372036854775807, 9223372036854775808, -9223372036854775808, "
+ "18446744073709551615]");
+}
+
+TEST(TestHelper, SafeSignedAdd) {
+ ASSERT_EQ(0, SafeSignedAdd<int8_t>(-128, -128));
+ ASSERT_EQ(1, SafeSignedAdd<int8_t>(-128, -127));
+ ASSERT_EQ(-128, SafeSignedAdd<int8_t>(1, 127));
+ ASSERT_EQ(-2147483648LL, SafeSignedAdd<int32_t>(1, 2147483647));
+}
+
+template <typename T>
+class TestIntegers : public ::testing::Test {
+ public:
+ std::shared_ptr<DataType> type() { return TypeTraits<T>::type_singleton(); }
+};
+
+TYPED_TEST_SUITE_P(TestIntegers);
+
+TYPED_TEST_P(TestIntegers, Basics) {
+ using T = TypeParam;
+ using c_type = typename T::c_type;
+
+ std::shared_ptr<Array> expected, actual;
+ auto type = this->type();
+
+ AssertJSONArray<T>(type, "[]", {});
+ AssertJSONArray<T>(type, "[4, 0, 5]", {4, 0, 5});
+ AssertJSONArray<T>(type, "[4, null, 5]", {true, false, true}, {4, 0, 5});
+
+ // Test limits
+ const auto min_val = std::numeric_limits<c_type>::min();
+ const auto max_val = std::numeric_limits<c_type>::max();
+ std::string json_string = JSONArray(0, 1, min_val);
+ AssertJSONArray<T>(type, json_string, {0, 1, min_val});
+ json_string = JSONArray(0, 1, max_val);
+ AssertJSONArray<T>(type, json_string, {0, 1, max_val});
+}
+
+TYPED_TEST_P(TestIntegers, Errors) {
+ std::shared_ptr<Array> array;
+ auto type = this->type();
+
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "0", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "{}", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0.0]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"0\"]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[0]]", &array));
+}
+
+TYPED_TEST_P(TestIntegers, OutOfBounds) {
+ using T = TypeParam;
+ using c_type = typename T::c_type;
+
+ std::shared_ptr<Array> array;
+ auto type = this->type();
+
+ if (type->id() == Type::UINT64) {
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[18446744073709551616]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[-1]", &array));
+ } else if (type->id() == Type::INT64) {
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[9223372036854775808]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[-9223372036854775809]", &array));
+ } else if (std::is_signed<c_type>::value) {
+ const auto lower = SafeSignedAdd<int64_t>(std::numeric_limits<c_type>::min(), -1);
+ const auto upper = SafeSignedAdd<int64_t>(std::numeric_limits<c_type>::max(), +1);
+ auto json_string = JSONArray(lower);
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, json_string, &array));
+ json_string = JSONArray(upper);
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, json_string, &array));
+ } else {
+ const auto upper = static_cast<uint64_t>(std::numeric_limits<c_type>::max()) + 1;
+ auto json_string = JSONArray(upper);
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, json_string, &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[-1]", &array));
+ }
+}
+
+TYPED_TEST_P(TestIntegers, Dictionary) {
+ std::shared_ptr<Array> array;
+ std::shared_ptr<DataType> value_type = this->type();
+
+ if (value_type->id() == Type::HALF_FLOAT) {
+ // Unsupported, skip
+ return;
+ }
+
+ AssertJSONDictArray(int8(), value_type, "[1, 2, 3, null, 3, 1]",
+ /*indices=*/"[0, 1, 2, null, 2, 0]",
+ /*values=*/"[1, 2, 3]");
+}
+
+REGISTER_TYPED_TEST_SUITE_P(TestIntegers, Basics, Errors, OutOfBounds, Dictionary);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt8, TestIntegers, Int8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt16, TestIntegers, Int16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt32, TestIntegers, Int32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt64, TestIntegers, Int64Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt8, TestIntegers, UInt8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt16, TestIntegers, UInt16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt32, TestIntegers, UInt32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt64, TestIntegers, UInt64Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestHalfFloat, TestIntegers, HalfFloatType);
+
+template <typename T>
+class TestStrings : public ::testing::Test {
+ public:
+ std::shared_ptr<DataType> type() { return TypeTraits<T>::type_singleton(); }
+};
+
+TYPED_TEST_SUITE_P(TestStrings);
+
+TYPED_TEST_P(TestStrings, Basics) {
+ using T = TypeParam;
+ auto type = this->type();
+
+ std::shared_ptr<Array> expected, actual;
+
+ AssertJSONArray<T, std::string>(type, "[]", {});
+ AssertJSONArray<T, std::string>(type, "[\"\", \"foo\"]", {"", "foo"});
+ AssertJSONArray<T, std::string>(type, "[\"\", null]", {true, false}, {"", ""});
+ // NUL character in string
+ std::string s = "some";
+ s += '\x00';
+ s += "char";
+ AssertJSONArray<T, std::string>(type, "[\"\", \"some\\u0000char\"]", {"", s});
+ // UTF8 sequence in string
+ AssertJSONArray<T, std::string>(type, "[\"\xc3\xa9\"]", {"\xc3\xa9"});
+
+ if (!T::is_utf8) {
+ // Arbitrary binary (non-UTF8) sequence in string
+ s = "\xff\x9f";
+ AssertJSONArray<T, std::string>(type, "[\"" + s + "\"]", {s});
+ }
+
+ // Bytes < 0x20 can be represented as JSON unicode escapes
+ s = '\x00';
+ s += "\x1f";
+ AssertJSONArray<T, std::string>(type, "[\"\\u0000\\u001f\"]", {s});
+}
+
+TYPED_TEST_P(TestStrings, Errors) {
+ auto type = this->type();
+ std::shared_ptr<Array> array;
+
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[]]", &array));
+}
+
+TYPED_TEST_P(TestStrings, Dictionary) {
+ auto value_type = this->type();
+
+ AssertJSONDictArray(int16(), value_type, R"(["foo", "bar", null, "bar", "foo"])",
+ /*indices=*/"[0, 1, null, 1, 0]",
+ /*values=*/R"(["foo", "bar"])");
+}
+
+REGISTER_TYPED_TEST_SUITE_P(TestStrings, Basics, Errors, Dictionary);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(TestString, TestStrings, StringType);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestBinary, TestStrings, BinaryType);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestLargeString, TestStrings, LargeStringType);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestLargeBinary, TestStrings, LargeBinaryType);
+
+TEST(TestNull, Basics) {
+ std::shared_ptr<DataType> type = null();
+ std::shared_ptr<Array> expected, actual;
+
+ AssertJSONArray<NullType, std::nullptr_t>(type, "[]", {});
+ AssertJSONArray<NullType, std::nullptr_t>(type, "[null, null]", {nullptr, nullptr});
+}
+
+TEST(TestNull, Errors) {
+ std::shared_ptr<DataType> type = null();
+ std::shared_ptr<Array> array;
+
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[]]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[NaN]", &array));
+}
+
+TEST(TestBoolean, Basics) {
+ std::shared_ptr<DataType> type = boolean();
+
+ AssertJSONArray<BooleanType, bool>(type, "[]", {});
+ AssertJSONArray<BooleanType, bool>(type, "[false, true, false]", {false, true, false});
+ AssertJSONArray<BooleanType, bool>(type, "[false, true, null]", {true, true, false},
+ {false, true, false});
+ // Supports integer literal casting
+ AssertJSONArray<BooleanType, bool>(type, "[0, 1, 0]", {false, true, false});
+ AssertJSONArray<BooleanType, bool>(type, "[0, 1, null]", {true, true, false},
+ {false, true, false});
+}
+
+TEST(TestBoolean, Errors) {
+ std::shared_ptr<DataType> type = boolean();
+ std::shared_ptr<Array> array;
+
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0.0]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"true\"]", &array));
+}
+
+TEST(TestFloat, Basics) {
+ std::shared_ptr<DataType> type = float32();
+ std::shared_ptr<Array> expected, actual;
+
+ AssertJSONArray<FloatType>(type, "[]", {});
+ AssertJSONArray<FloatType>(type, "[1, 2.5, -3e4]", {1.0f, 2.5f, -3.0e4f});
+ AssertJSONArray<FloatType>(type, "[-0.0, Inf, -Inf, null]", {true, true, true, false},
+ {-0.0f, INFINITY, -INFINITY, 0.0f});
+
+ // Check NaN separately as AssertArraysEqual simply memcmp's array contents
+ // and NaNs can have many bit representations.
+ ASSERT_OK(ArrayFromJSON(type, "[NaN]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ float value = checked_cast<FloatArray&>(*actual).Value(0);
+ ASSERT_TRUE(std::isnan(value));
+}
+
+TEST(TestFloat, Errors) {
+ std::shared_ptr<DataType> type = float32();
+ std::shared_ptr<Array> array;
+
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[true]", &array));
+}
+
+TEST(TestDouble, Basics) {
+ std::shared_ptr<DataType> type = float64();
+ std::shared_ptr<Array> expected, actual;
+
+ AssertJSONArray<DoubleType>(type, "[]", {});
+ AssertJSONArray<DoubleType>(type, "[1, 2.5, -3e4]", {1.0, 2.5, -3.0e4});
+ AssertJSONArray<DoubleType>(type, "[-0.0, Inf, -Inf, null]", {true, true, true, false},
+ {-0.0, INFINITY, -INFINITY, 0.0});
+
+ ASSERT_OK(ArrayFromJSON(type, "[NaN]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ double value = checked_cast<DoubleArray&>(*actual).Value(0);
+ ASSERT_TRUE(std::isnan(value));
+}
+
+TEST(TestDouble, Errors) {
+ std::shared_ptr<DataType> type = float64();
+ std::shared_ptr<Array> array;
+
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[true]", &array));
+}
+
+TEST(TestTimestamp, Basics) {
+ // Timestamp type
+ auto type = timestamp(TimeUnit::SECOND);
+ AssertJSONArray<TimestampType, int64_t>(
+ type, R"(["1970-01-01","2000-02-29","3989-07-14","1900-02-28"])",
+ {0, 951782400, 63730281600LL, -2203977600LL});
+
+ type = timestamp(TimeUnit::NANO);
+ AssertJSONArray<TimestampType, int64_t>(
+ type, R"(["1970-01-01","2000-02-29","1900-02-28"])",
+ {0, 951782400000000000LL, -2203977600000000000LL});
+}
+
+TEST(TestDate, Basics) {
+ auto type = date32();
+ AssertJSONArray<Date32Type>(type, R"([5, null, 42])", {true, false, true}, {5, 0, 42});
+ type = date64();
+ AssertJSONArray<Date64Type>(type, R"([1, null, 9999999999999])", {true, false, true},
+ {1, 0, 9999999999999LL});
+}
+
+TEST(TestTime, Basics) {
+ auto type = time32(TimeUnit::SECOND);
+ AssertJSONArray<Time32Type>(type, R"([5, null, 42])", {true, false, true}, {5, 0, 42});
+ type = time32(TimeUnit::MILLI);
+ AssertJSONArray<Time32Type>(type, R"([5, null, 42])", {true, false, true}, {5, 0, 42});
+
+ type = time64(TimeUnit::MICRO);
+ AssertJSONArray<Time64Type>(type, R"([1, null, 9999999999999])", {true, false, true},
+ {1, 0, 9999999999999LL});
+ type = time64(TimeUnit::NANO);
+ AssertJSONArray<Time64Type>(type, R"([1, null, 9999999999999])", {true, false, true},
+ {1, 0, 9999999999999LL});
+}
+
+TEST(TestDuration, Basics) {
+ auto type = duration(TimeUnit::SECOND);
+ AssertJSONArray<DurationType>(type, R"([null, -7777777777777, 9999999999999])",
+ {false, true, true},
+ {0, -7777777777777LL, 9999999999999LL});
+ type = duration(TimeUnit::MILLI);
+ AssertJSONArray<DurationType>(type, R"([null, -7777777777777, 9999999999999])",
+ {false, true, true},
+ {0, -7777777777777LL, 9999999999999LL});
+ type = duration(TimeUnit::MICRO);
+ AssertJSONArray<DurationType>(type, R"([null, -7777777777777, 9999999999999])",
+ {false, true, true},
+ {0, -7777777777777LL, 9999999999999LL});
+ type = duration(TimeUnit::NANO);
+ AssertJSONArray<DurationType>(type, R"([null, -7777777777777, 9999999999999])",
+ {false, true, true},
+ {0, -7777777777777LL, 9999999999999LL});
+}
+
+TEST(TestMonthInterval, Basics) {
+ auto type = month_interval();
+ AssertJSONArray<MonthIntervalType>(type, R"([123, -456, null])", {true, true, false},
+ {123, -456, 0});
+}
+
+TEST(TestDayTimeInterval, Basics) {
+ auto type = day_time_interval();
+ AssertJSONArray<DayTimeIntervalType>(type, R"([[1, -600], null])", {true, false},
+ {{1, -600}, {}});
+}
+
+TEST(MonthDayNanoInterval, Basics) {
+ auto type = month_day_nano_interval();
+ AssertJSONArray<MonthDayNanoIntervalType>(type, R"([[1, -600, 5000], null])",
+ {true, false}, {{1, -600, 5000}, {}});
+}
+
+TEST(TestFixedSizeBinary, Basics) {
+ std::shared_ptr<DataType> type = fixed_size_binary(3);
+ std::shared_ptr<Array> expected, actual;
+
+ AssertJSONArray<FixedSizeBinaryType, std::string>(type, "[]", {});
+ AssertJSONArray<FixedSizeBinaryType, std::string>(type, "[\"foo\", \"bar\"]",
+ {"foo", "bar"});
+ AssertJSONArray<FixedSizeBinaryType, std::string>(type, "[null, \"foo\"]",
+ {false, true}, {"", "foo"});
+ // Arbitrary binary (non-UTF8) sequence in string
+ std::string s = "\xff\x9f\xcc";
+ AssertJSONArray<FixedSizeBinaryType, std::string>(type, "[\"" + s + "\"]", {s});
+}
+
+TEST(TestFixedSizeBinary, Errors) {
+ std::shared_ptr<DataType> type = fixed_size_binary(3);
+ std::shared_ptr<Array> array;
+
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[]]", &array));
+ // Invalid length
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"\"]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"abcd\"]", &array));
+}
+
+TEST(TestFixedSizeBinary, Dictionary) {
+ std::shared_ptr<DataType> type = fixed_size_binary(3);
+
+ AssertJSONDictArray(int8(), type, R"(["foo", "bar", "foo", null])",
+ /*indices=*/"[0, 1, 0, null]",
+ /*values=*/R"(["foo", "bar"])");
+
+ // Invalid length
+ std::shared_ptr<Array> array;
+ ASSERT_RAISES(Invalid, ArrayFromJSON(dictionary(int8(), type), R"(["x"])", &array));
+}
+
+template <typename DecimalValue, typename DecimalBuilder>
+void TestDecimalBasic(std::shared_ptr<DataType> type) {
+ std::shared_ptr<Array> expected, actual;
+
+ ASSERT_OK(ArrayFromJSON(type, "[]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ {
+ DecimalBuilder builder(type);
+ ASSERT_OK(builder.Finish(&expected));
+ }
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[\"123.4567\", \"-78.9000\"]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ {
+ DecimalBuilder builder(type);
+ ASSERT_OK(builder.Append(DecimalValue(1234567)));
+ ASSERT_OK(builder.Append(DecimalValue(-789000)));
+ ASSERT_OK(builder.Finish(&expected));
+ }
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[\"123.4567\", null]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ {
+ DecimalBuilder builder(type);
+ ASSERT_OK(builder.Append(DecimalValue(1234567)));
+ ASSERT_OK(builder.AppendNull());
+ ASSERT_OK(builder.Finish(&expected));
+ }
+ AssertArraysEqual(*expected, *actual);
+}
+
+TEST(TestDecimal128, Basics) {
+ TestDecimalBasic<Decimal128, Decimal128Builder>(decimal128(10, 4));
+}
+
+TEST(TestDecimal256, Basics) {
+ TestDecimalBasic<Decimal256, Decimal256Builder>(decimal256(10, 4));
+}
+
+TEST(TestDecimal, Errors) {
+ for (std::shared_ptr<DataType> type : {decimal128(10, 4), decimal256(10, 4)}) {
+ std::shared_ptr<Array> array;
+
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[12.3456]", &array));
+ // Bad scale
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"12.345\"]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"12.34560\"]", &array));
+ }
+}
+
+TEST(TestDecimal, Dictionary) {
+ for (std::shared_ptr<DataType> type : {decimal128(10, 2), decimal256(10, 2)}) {
+ AssertJSONDictArray(int32(), type,
+ R"(["123.45", "-78.90", "-78.90", null, "123.45"])",
+ /*indices=*/"[0, 1, 1, null, 0]",
+ /*values=*/R"(["123.45", "-78.90"])");
+ }
+}
+
+TEST(TestList, IntegerList) {
+ auto pool = default_memory_pool();
+ std::shared_ptr<DataType> type = list(int64());
+ std::shared_ptr<Array> offsets, values, expected, actual;
+
+ ASSERT_OK(ArrayFromJSON(type, "[]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int32Type>({0}, &offsets);
+ ArrayFromVector<Int64Type>({}, &values);
+ ASSERT_OK_AND_ASSIGN(expected, ListArray::FromArrays(*offsets, *values, pool));
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[[4, 5], [], [6]]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int32Type>({0, 2, 2, 3}, &offsets);
+ ArrayFromVector<Int64Type>({4, 5, 6}, &values);
+ ASSERT_OK_AND_ASSIGN(expected, ListArray::FromArrays(*offsets, *values, pool));
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[[], [null], [6, null]]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int32Type>({0, 0, 1, 3}, &offsets);
+ auto is_valid = std::vector<bool>{false, true, false};
+ ArrayFromVector<Int64Type>(is_valid, {0, 6, 0}, &values);
+ ASSERT_OK_AND_ASSIGN(expected, ListArray::FromArrays(*offsets, *values, pool));
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[null, [], null]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ {
+ std::unique_ptr<ArrayBuilder> builder;
+ ASSERT_OK(MakeBuilder(pool, type, &builder));
+ auto& list_builder = checked_cast<ListBuilder&>(*builder);
+ ASSERT_OK(list_builder.AppendNull());
+ ASSERT_OK(list_builder.Append());
+ ASSERT_OK(list_builder.AppendNull());
+ ASSERT_OK(list_builder.Finish(&expected));
+ }
+ AssertArraysEqual(*expected, *actual);
+}
+
+TEST(TestList, IntegerListErrors) {
+ std::shared_ptr<DataType> type = list(int64());
+ std::shared_ptr<Array> array;
+
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[0.0]]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[9223372036854775808]]", &array));
+}
+
+TEST(TestList, NullList) {
+ auto pool = default_memory_pool();
+ std::shared_ptr<DataType> type = list(null());
+ std::shared_ptr<Array> offsets, values, expected, actual;
+
+ ASSERT_OK(ArrayFromJSON(type, "[]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int32Type>({0}, &offsets);
+ values = std::make_shared<NullArray>(0);
+ ASSERT_OK_AND_ASSIGN(expected, ListArray::FromArrays(*offsets, *values, pool));
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[[], [null], [null, null]]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int32Type>({0, 0, 1, 3}, &offsets);
+ values = std::make_shared<NullArray>(3);
+ ASSERT_OK_AND_ASSIGN(expected, ListArray::FromArrays(*offsets, *values, pool));
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[null, [], null]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ {
+ std::unique_ptr<ArrayBuilder> builder;
+ ASSERT_OK(MakeBuilder(pool, type, &builder));
+ auto& list_builder = checked_cast<ListBuilder&>(*builder);
+ ASSERT_OK(list_builder.AppendNull());
+ ASSERT_OK(list_builder.Append());
+ ASSERT_OK(list_builder.AppendNull());
+ ASSERT_OK(list_builder.Finish(&expected));
+ }
+ AssertArraysEqual(*expected, *actual);
+}
+
+TEST(TestList, IntegerListList) {
+ auto pool = default_memory_pool();
+ std::shared_ptr<DataType> type = list(list(uint8()));
+ std::shared_ptr<Array> offsets, values, nested, expected, actual;
+
+ ASSERT_OK(ArrayFromJSON(type, "[[[4], [5, 6]], [[7, 8, 9]]]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int32Type>({0, 1, 3, 6}, &offsets);
+ ArrayFromVector<UInt8Type>({4, 5, 6, 7, 8, 9}, &values);
+ ASSERT_OK_AND_ASSIGN(nested, ListArray::FromArrays(*offsets, *values, pool));
+ ArrayFromVector<Int32Type>({0, 2, 3}, &offsets);
+ ASSERT_OK_AND_ASSIGN(expected, ListArray::FromArrays(*offsets, *nested, pool));
+ ASSERT_EQ(actual->length(), 2);
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[[], [[]], [[4], [], [5, 6]], [[7, 8, 9]]]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int32Type>({0, 0, 1, 1, 3, 6}, &offsets);
+ ArrayFromVector<UInt8Type>({4, 5, 6, 7, 8, 9}, &values);
+ ASSERT_OK_AND_ASSIGN(nested, ListArray::FromArrays(*offsets, *values, pool));
+ ArrayFromVector<Int32Type>({0, 0, 1, 4, 5}, &offsets);
+ ASSERT_OK_AND_ASSIGN(expected, ListArray::FromArrays(*offsets, *nested, pool));
+ ASSERT_EQ(actual->length(), 4);
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[null, [null], [[null]]]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ {
+ std::unique_ptr<ArrayBuilder> builder;
+ ASSERT_OK(MakeBuilder(pool, type, &builder));
+ auto& list_builder = checked_cast<ListBuilder&>(*builder);
+ auto& child_builder = checked_cast<ListBuilder&>(*list_builder.value_builder());
+ ASSERT_OK(list_builder.AppendNull());
+ ASSERT_OK(list_builder.Append());
+ ASSERT_OK(child_builder.AppendNull());
+ ASSERT_OK(list_builder.Append());
+ ASSERT_OK(child_builder.Append());
+ ASSERT_OK(list_builder.Finish(&expected));
+ }
+}
+
+TEST(TestLargeList, Basics) {
+ // Similar as TestList above, only testing the basics
+ auto pool = default_memory_pool();
+ std::shared_ptr<DataType> type = large_list(int16());
+ std::shared_ptr<Array> offsets, values, expected, actual;
+
+ ASSERT_OK(ArrayFromJSON(type, "[[], [null], [6, null]]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int64Type>({0, 0, 1, 3}, &offsets);
+ auto is_valid = std::vector<bool>{false, true, false};
+ ArrayFromVector<Int16Type>(is_valid, {0, 6, 0}, &values);
+ ASSERT_OK_AND_ASSIGN(expected, LargeListArray::FromArrays(*offsets, *values, pool));
+ AssertArraysEqual(*expected, *actual);
+}
+
+TEST(TestMap, IntegerToInteger) {
+ auto type = map(int16(), int16());
+ std::shared_ptr<Array> expected, actual;
+
+ const char* input = R"(
+[
+ [[0, 1], [1, 1], [2, 2], [3, 3], [4, 5], [5, 8]],
+ null,
+ [[0, null], [1, null], [2, 0], [3, 1], [4, null], [5, 2]],
+ []
+ ]
+)";
+ ASSERT_OK(ArrayFromJSON(type, input, &actual));
+
+ std::unique_ptr<ArrayBuilder> builder;
+ ASSERT_OK(MakeBuilder(default_memory_pool(), type, &builder));
+ auto& map_builder = checked_cast<MapBuilder&>(*builder);
+ auto& key_builder = checked_cast<Int16Builder&>(*map_builder.key_builder());
+ auto& item_builder = checked_cast<Int16Builder&>(*map_builder.item_builder());
+
+ ASSERT_OK(map_builder.Append());
+ ASSERT_OK(key_builder.AppendValues({0, 1, 2, 3, 4, 5}));
+ ASSERT_OK(item_builder.AppendValues({1, 1, 2, 3, 5, 8}));
+ ASSERT_OK(map_builder.AppendNull());
+ ASSERT_OK(map_builder.Append());
+ ASSERT_OK(key_builder.AppendValues({0, 1, 2, 3, 4, 5}));
+ ASSERT_OK(item_builder.AppendValues({-1, -1, 0, 1, -1, 2}, {0, 0, 1, 1, 0, 1}));
+ ASSERT_OK(map_builder.Append());
+ ASSERT_OK(map_builder.Finish(&expected));
+
+ ASSERT_ARRAYS_EQUAL(*actual, *expected);
+}
+
+TEST(TestMap, StringToInteger) {
+ auto type = map(utf8(), int32());
+ const char* input = R"(
+[
+ [["joe", 0], ["mark", null]],
+ null,
+ [["cap", 8]],
+ []
+ ]
+)";
+ auto actual = ArrayFromJSON(type, input);
+ std::vector<int32_t> offsets = {0, 2, 2, 3, 3};
+ auto expected_keys = ArrayFromJSON(utf8(), R"(["joe", "mark", "cap"])");
+ auto expected_values = ArrayFromJSON(int32(), "[0, null, 8]");
+ ASSERT_OK_AND_ASSIGN(auto expected_null_bitmap, BytesToBits({1, 0, 1, 1}));
+ auto expected =
+ std::make_shared<MapArray>(type, 4, Buffer::Wrap(offsets), expected_keys,
+ expected_values, expected_null_bitmap, 1);
+ ASSERT_ARRAYS_EQUAL(*actual, *expected);
+}
+
+TEST(TestMap, Errors) {
+ auto type = map(int16(), int16());
+ std::shared_ptr<Array> array;
+
+ // list of pairs isn't an array
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0]", &array));
+ // pair isn't an array
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[0]]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[null]]", &array));
+ // pair with length != 2
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[[0]]]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[[0, 1, 2]]]", &array));
+ // null key
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[[null, 0]]]", &array));
+ // key or value fails to convert
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[[0.0, 0]]]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[[0, 0.0]]]", &array));
+}
+
+TEST(TestMap, IntegerMapToStringList) {
+ auto type = map(map(int16(), int16()), list(utf8()));
+ std::shared_ptr<Array> expected, actual;
+
+ const char* input = R"(
+[
+ [
+ [
+ [],
+ [null, "empty"]
+ ],
+ [
+ [[0, 1]],
+ null
+ ],
+ [
+ [[0, 0], [1, 1]],
+ ["bootstrapping tautology?", "lispy", null, "i can see eternity"]
+ ]
+ ],
+ null
+ ]
+)";
+ ASSERT_OK(ArrayFromJSON(type, input, &actual));
+
+ std::unique_ptr<ArrayBuilder> builder;
+ ASSERT_OK(MakeBuilder(default_memory_pool(), type, &builder));
+ auto& map_builder = checked_cast<MapBuilder&>(*builder);
+ auto& key_builder = checked_cast<MapBuilder&>(*map_builder.key_builder());
+ auto& key_key_builder = checked_cast<Int16Builder&>(*key_builder.key_builder());
+ auto& key_item_builder = checked_cast<Int16Builder&>(*key_builder.item_builder());
+ auto& item_builder = checked_cast<ListBuilder&>(*map_builder.item_builder());
+ auto& item_value_builder = checked_cast<StringBuilder&>(*item_builder.value_builder());
+
+ ASSERT_OK(map_builder.Append());
+ ASSERT_OK(key_builder.Append());
+ ASSERT_OK(item_builder.Append());
+ ASSERT_OK(item_value_builder.AppendNull());
+ ASSERT_OK(item_value_builder.Append("empty"));
+
+ ASSERT_OK(key_builder.Append());
+ ASSERT_OK(item_builder.AppendNull());
+ ASSERT_OK(key_key_builder.AppendValues({0}));
+ ASSERT_OK(key_item_builder.AppendValues({1}));
+
+ ASSERT_OK(key_builder.Append());
+ ASSERT_OK(item_builder.Append());
+ ASSERT_OK(key_key_builder.AppendValues({0, 1}));
+ ASSERT_OK(key_item_builder.AppendValues({0, 1}));
+ ASSERT_OK(item_value_builder.Append("bootstrapping tautology?"));
+ ASSERT_OK(item_value_builder.Append("lispy"));
+ ASSERT_OK(item_value_builder.AppendNull());
+ ASSERT_OK(item_value_builder.Append("i can see eternity"));
+
+ ASSERT_OK(map_builder.AppendNull());
+
+ ASSERT_OK(map_builder.Finish(&expected));
+ ASSERT_ARRAYS_EQUAL(*actual, *expected);
+}
+
+TEST(TestFixedSizeList, IntegerList) {
+ auto pool = default_memory_pool();
+ auto type = fixed_size_list(int64(), 2);
+ std::shared_ptr<Array> values, expected, actual;
+
+ ASSERT_OK(ArrayFromJSON(type, "[]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int64Type>({}, &values);
+ expected = std::make_shared<FixedSizeListArray>(type, 0, values);
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[[4, 5], [0, 0], [6, 7]]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int64Type>({4, 5, 0, 0, 6, 7}, &values);
+ expected = std::make_shared<FixedSizeListArray>(type, 3, values);
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[[null, null], [0, null], [6, null]]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ auto is_valid = std::vector<bool>{false, false, true, false, true, false};
+ ArrayFromVector<Int64Type>(is_valid, {0, 0, 0, 0, 6, 0}, &values);
+ expected = std::make_shared<FixedSizeListArray>(type, 3, values);
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[null, [null, null], null]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ {
+ std::unique_ptr<ArrayBuilder> builder;
+ ASSERT_OK(MakeBuilder(pool, type, &builder));
+ auto& list_builder = checked_cast<FixedSizeListBuilder&>(*builder);
+ auto value_builder = checked_cast<Int64Builder*>(list_builder.value_builder());
+ ASSERT_OK(list_builder.AppendNull());
+ ASSERT_OK(list_builder.Append());
+ ASSERT_OK(value_builder->AppendNull());
+ ASSERT_OK(value_builder->AppendNull());
+ ASSERT_OK(list_builder.AppendNull());
+ ASSERT_OK(list_builder.Finish(&expected));
+ }
+ AssertArraysEqual(*expected, *actual);
+}
+
+TEST(TestFixedSizeList, IntegerListErrors) {
+ std::shared_ptr<DataType> type = fixed_size_list(int64(), 2);
+ std::shared_ptr<Array> array;
+
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[0.0, 1.0]]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[0]]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[9223372036854775808, 0]]", &array));
+}
+
+TEST(TestFixedSizeList, NullList) {
+ auto pool = default_memory_pool();
+ std::shared_ptr<DataType> type = fixed_size_list(null(), 2);
+ std::shared_ptr<Array> values, expected, actual;
+
+ ASSERT_OK(ArrayFromJSON(type, "[]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ values = std::make_shared<NullArray>(0);
+ expected = std::make_shared<FixedSizeListArray>(type, 0, values);
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[[null, null], [null, null], [null, null]]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ values = std::make_shared<NullArray>(6);
+ expected = std::make_shared<FixedSizeListArray>(type, 3, values);
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[null, [null, null], null]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ {
+ std::unique_ptr<ArrayBuilder> builder;
+ ASSERT_OK(MakeBuilder(pool, type, &builder));
+ auto& list_builder = checked_cast<FixedSizeListBuilder&>(*builder);
+ auto value_builder = checked_cast<NullBuilder*>(list_builder.value_builder());
+ ASSERT_OK(list_builder.AppendNull());
+ ASSERT_OK(list_builder.Append());
+ ASSERT_OK(value_builder->AppendNull());
+ ASSERT_OK(value_builder->AppendNull());
+ ASSERT_OK(list_builder.AppendNull());
+ ASSERT_OK(list_builder.Finish(&expected));
+ }
+ AssertArraysEqual(*expected, *actual);
+}
+
+TEST(TestFixedSizeList, IntegerListList) {
+ auto pool = default_memory_pool();
+ auto nested_type = fixed_size_list(uint8(), 2);
+ std::shared_ptr<DataType> type = fixed_size_list(nested_type, 1);
+ std::shared_ptr<Array> values, nested, expected, actual;
+
+ ASSERT_OK(ArrayFromJSON(type, "[[[1, 4]], [[2, 5]], [[3, 6]]]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<UInt8Type>({1, 4, 2, 5, 3, 6}, &values);
+ nested = std::make_shared<FixedSizeListArray>(nested_type, 3, values);
+ expected = std::make_shared<FixedSizeListArray>(type, 3, nested);
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[[[1, null]], [null], null]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ {
+ std::unique_ptr<ArrayBuilder> builder;
+ ASSERT_OK(MakeBuilder(pool, type, &builder));
+ auto& list_builder = checked_cast<FixedSizeListBuilder&>(*builder);
+ auto nested_builder =
+ checked_cast<FixedSizeListBuilder*>(list_builder.value_builder());
+ auto value_builder = checked_cast<UInt8Builder*>(nested_builder->value_builder());
+
+ ASSERT_OK(list_builder.Append());
+ ASSERT_OK(nested_builder->Append());
+ ASSERT_OK(value_builder->Append(1));
+ ASSERT_OK(value_builder->AppendNull());
+
+ ASSERT_OK(list_builder.Append());
+ ASSERT_OK(nested_builder->AppendNull());
+
+ ASSERT_OK(list_builder.AppendNull());
+
+ ASSERT_OK(list_builder.Finish(&expected));
+ }
+ AssertArraysEqual(*expected, *actual);
+}
+
+TEST(TestStruct, SimpleStruct) {
+ auto field_a = field("a", int8());
+ auto field_b = field("b", boolean());
+ std::shared_ptr<DataType> type = struct_({field_a, field_b});
+ std::shared_ptr<Array> a, b, expected, actual;
+ std::shared_ptr<Buffer> null_bitmap;
+ std::vector<bool> is_valid;
+ std::vector<std::shared_ptr<Array>> children;
+
+ // Trivial
+ ASSERT_OK(ArrayFromJSON(type, "[]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int8Type>({}, &a);
+ ArrayFromVector<BooleanType, bool>({}, &b);
+ children.assign({a, b});
+ expected = std::make_shared<StructArray>(type, 0, children);
+ AssertArraysEqual(*expected, *actual);
+
+ // Non-empty
+ ArrayFromVector<Int8Type>({5, 6}, &a);
+ ArrayFromVector<BooleanType, bool>({true, false}, &b);
+ children.assign({a, b});
+ expected = std::make_shared<StructArray>(type, 2, children);
+
+ ASSERT_OK(ArrayFromJSON(type, "[[5, true], [6, false]]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ AssertArraysEqual(*expected, *actual);
+ ASSERT_OK(ArrayFromJSON(type, "[{\"a\": 5, \"b\": true}, {\"b\": false, \"a\": 6}]",
+ &actual));
+ ASSERT_OK(actual->ValidateFull());
+ AssertArraysEqual(*expected, *actual);
+
+ // With nulls
+ is_valid = {false, true, false, false};
+ ArrayFromVector<Int8Type>(is_valid, {0, 5, 6, 0}, &a);
+ is_valid = {false, false, true, false};
+ ArrayFromVector<BooleanType, bool>(is_valid, {false, true, false, false}, &b);
+ children.assign({a, b});
+ BitmapFromVector<bool>({false, true, true, true}, &null_bitmap);
+ expected = std::make_shared<StructArray>(type, 4, children, null_bitmap, 1);
+
+ ASSERT_OK(
+ ArrayFromJSON(type, "[null, [5, null], [null, false], [null, null]]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ AssertArraysEqual(*expected, *actual);
+ // When using object notation, null members can be omitted
+ ASSERT_OK(ArrayFromJSON(type, "[null, {\"a\": 5, \"b\": null}, {\"b\": false}, {}]",
+ &actual));
+ ASSERT_OK(actual->ValidateFull());
+ AssertArraysEqual(*expected, *actual);
+}
+
+TEST(TestStruct, NestedStruct) {
+ auto field_a = field("a", int8());
+ auto field_b = field("b", boolean());
+ auto field_c = field("c", float64());
+ std::shared_ptr<DataType> nested_type = struct_({field_a, field_b});
+ auto field_nested = field("nested", nested_type);
+ std::shared_ptr<DataType> type = struct_({field_nested, field_c});
+ std::shared_ptr<Array> expected, actual;
+ std::shared_ptr<Buffer> null_bitmap;
+ std::vector<bool> is_valid;
+ std::vector<std::shared_ptr<Array>> children(2);
+
+ ASSERT_OK(ArrayFromJSON(type, "[]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int8Type>({}, &children[0]);
+ ArrayFromVector<BooleanType, bool>({}, &children[1]);
+ children[0] = std::make_shared<StructArray>(nested_type, 0, children);
+ ArrayFromVector<DoubleType>({}, &children[1]);
+ expected = std::make_shared<StructArray>(type, 0, children);
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[[[5, true], 1.5], [[6, false], -3e2]]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ ArrayFromVector<Int8Type>({5, 6}, &children[0]);
+ ArrayFromVector<BooleanType, bool>({true, false}, &children[1]);
+ children[0] = std::make_shared<StructArray>(nested_type, 2, children);
+ ArrayFromVector<DoubleType>({1.5, -300.0}, &children[1]);
+ expected = std::make_shared<StructArray>(type, 2, children);
+ AssertArraysEqual(*expected, *actual);
+
+ ASSERT_OK(ArrayFromJSON(type, "[null, [[5, null], null], [null, -3e2]]", &actual));
+ ASSERT_OK(actual->ValidateFull());
+ is_valid = {false, true, false};
+ ArrayFromVector<Int8Type>(is_valid, {0, 5, 0}, &children[0]);
+ is_valid = {false, false, false};
+ ArrayFromVector<BooleanType, bool>(is_valid, {false, false, false}, &children[1]);
+ BitmapFromVector<bool>({false, true, false}, &null_bitmap);
+ children[0] = std::make_shared<StructArray>(nested_type, 3, children, null_bitmap, 2);
+ is_valid = {false, false, true};
+ ArrayFromVector<DoubleType>(is_valid, {0.0, 0.0, -300.0}, &children[1]);
+ BitmapFromVector<bool>({false, true, true}, &null_bitmap);
+ expected = std::make_shared<StructArray>(type, 3, children, null_bitmap, 1);
+ AssertArraysEqual(*expected, *actual);
+}
+
+TEST(TestStruct, Errors) {
+ auto field_a = field("a", int8());
+ auto field_b = field("b", boolean());
+ std::shared_ptr<DataType> type = struct_({field_a, field_b});
+ std::shared_ptr<Array> array;
+
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[0, true]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[0]]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[0, true, 1]]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[true, 0]]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[{\"b\": 0, \"a\": true}]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[{\"c\": 0}]", &array));
+}
+
+TEST(TestDenseUnion, Basics) {
+ auto field_a = field("a", int8());
+ auto field_b = field("b", boolean());
+
+ auto type = dense_union({field_a, field_b}, {4, 8});
+ auto array = checked_pointer_cast<DenseUnionArray>(
+ ArrayFromJSON(type, "[null, [4, 122], [8, true], [4, null], null, [8, false]]"));
+
+ auto expected_types = ArrayFromJSON(int8(), "[4, 4, 8, 4, 4, 8]");
+ auto expected_offsets = ArrayFromJSON(int32(), "[0, 1, 0, 2, 3, 1]");
+ auto expected_a = ArrayFromJSON(int8(), "[null, 122, null, null]");
+ auto expected_b = ArrayFromJSON(boolean(), "[true, false]");
+
+ ASSERT_OK_AND_ASSIGN(
+ auto expected, DenseUnionArray::Make(*expected_types, *expected_offsets,
+ {expected_a, expected_b}, {"a", "b"}, {4, 8}));
+
+ ASSERT_ARRAYS_EQUAL(*expected, *array);
+
+ // ensure that the array is as dense as we expect
+ ASSERT_TRUE(array->value_offsets()->Equals(*expected_offsets->data()->buffers[1]));
+ ASSERT_ARRAYS_EQUAL(*expected_a, *array->field(0));
+ ASSERT_ARRAYS_EQUAL(*expected_b, *array->field(1));
+}
+
+TEST(TestSparseUnion, Basics) {
+ auto field_a = field("a", int8());
+ auto field_b = field("b", boolean());
+
+ auto type = sparse_union({field_a, field_b}, {4, 8});
+ auto array = ArrayFromJSON(type, "[[4, 122], [8, true], [4, null], null, [8, false]]");
+
+ auto expected_types = ArrayFromJSON(int8(), "[4, 8, 4, 4, 8]");
+ auto expected_a = ArrayFromJSON(int8(), "[122, null, null, null, null]");
+ auto expected_b = ArrayFromJSON(boolean(), "[null, true, null, null, false]");
+
+ ASSERT_OK_AND_ASSIGN(auto expected,
+ SparseUnionArray::Make(*expected_types, {expected_a, expected_b},
+ {"a", "b"}, {4, 8}));
+
+ ASSERT_ARRAYS_EQUAL(*expected, *array);
+}
+
+TEST(TestDenseUnion, ListOfUnion) {
+ auto field_a = field("a", int8());
+ auto field_b = field("b", boolean());
+ auto union_type = dense_union({field_a, field_b}, {4, 8});
+ auto list_type = list(union_type);
+ auto array =
+ checked_pointer_cast<ListArray>(ArrayFromJSON(list_type,
+ "["
+ "[[4, 122], [8, true]],"
+ "[[4, null], null, [8, false]]"
+ "]"));
+
+ auto expected_types = ArrayFromJSON(int8(), "[4, 8, 4, 4, 8]");
+ auto expected_offsets = ArrayFromJSON(int32(), "[0, 0, 1, 2, 1]");
+ auto expected_a = ArrayFromJSON(int8(), "[122, null, null]");
+ auto expected_b = ArrayFromJSON(boolean(), "[true, false]");
+
+ ASSERT_OK_AND_ASSIGN(
+ auto expected_values,
+ DenseUnionArray::Make(*expected_types, *expected_offsets, {expected_a, expected_b},
+ {"a", "b"}, {4, 8}));
+ auto expected_list_offsets = ArrayFromJSON(int32(), "[0, 2, 5]");
+ ASSERT_OK_AND_ASSIGN(auto expected,
+ ListArray::FromArrays(*expected_list_offsets, *expected_values));
+
+ ASSERT_ARRAYS_EQUAL(*expected, *array);
+
+ // ensure that the array is as dense as we expect
+ auto array_values = checked_pointer_cast<DenseUnionArray>(array->values());
+ ASSERT_TRUE(array_values->value_offsets()->Equals(
+ *checked_pointer_cast<DenseUnionArray>(expected_values)->value_offsets()));
+ ASSERT_ARRAYS_EQUAL(*expected_a, *array_values->field(0));
+ ASSERT_ARRAYS_EQUAL(*expected_b, *array_values->field(1));
+}
+
+TEST(TestSparseUnion, ListOfUnion) {
+ auto field_a = field("a", int8());
+ auto field_b = field("b", boolean());
+ auto union_type = sparse_union({field_a, field_b}, {4, 8});
+ auto list_type = list(union_type);
+ auto array = ArrayFromJSON(list_type,
+ "["
+ "[[4, 122], [8, true]],"
+ "[[4, null], null, [8, false]]"
+ "]");
+
+ auto expected_types = ArrayFromJSON(int8(), "[4, 8, 4, 4, 8]");
+ auto expected_a = ArrayFromJSON(int8(), "[122, null, null, null, null]");
+ auto expected_b = ArrayFromJSON(boolean(), "[null, true, null, null, false]");
+
+ ASSERT_OK_AND_ASSIGN(auto expected_values,
+ SparseUnionArray::Make(*expected_types, {expected_a, expected_b},
+ {"a", "b"}, {4, 8}));
+ auto expected_list_offsets = ArrayFromJSON(int32(), "[0, 2, 5]");
+ ASSERT_OK_AND_ASSIGN(auto expected,
+ ListArray::FromArrays(*expected_list_offsets, *expected_values));
+
+ ASSERT_ARRAYS_EQUAL(*expected, *array);
+}
+
+TEST(TestDenseUnion, UnionOfStructs) {
+ std::vector<std::shared_ptr<Field>> fields = {
+ field("ab", struct_({field("alpha", float64()), field("bravo", utf8())})),
+ field("wtf", struct_({field("whiskey", int8()), field("tango", float64()),
+ field("foxtrot", list(int8()))})),
+ field("q", struct_({field("quebec", utf8())}))};
+ auto type = dense_union(fields, {0, 23, 47});
+ auto array = checked_pointer_cast<DenseUnionArray>(ArrayFromJSON(type, R"([
+ [0, {"alpha": 0.0, "bravo": "charlie"}],
+ [23, {"whiskey": 99}],
+ [0, {"bravo": "mike"}],
+ null,
+ [23, {"tango": 8.25, "foxtrot": [0, 2, 3]}]
+ ])"));
+
+ auto expected_types = ArrayFromJSON(int8(), "[0, 23, 0, 0, 23]");
+ auto expected_offsets = ArrayFromJSON(int32(), "[0, 0, 1, 2, 1]");
+ ArrayVector expected_fields = {ArrayFromJSON(fields[0]->type(), R"([
+ {"alpha": 0.0, "bravo": "charlie"},
+ {"bravo": "mike"},
+ null
+ ])"),
+ ArrayFromJSON(fields[1]->type(), R"([
+ {"whiskey": 99},
+ {"tango": 8.25, "foxtrot": [0, 2, 3]}
+ ])"),
+ ArrayFromJSON(fields[2]->type(), "[]")};
+
+ ASSERT_OK_AND_ASSIGN(
+ auto expected,
+ DenseUnionArray::Make(*expected_types, *expected_offsets, expected_fields,
+ {"ab", "wtf", "q"}, {0, 23, 47}));
+
+ ASSERT_ARRAYS_EQUAL(*expected, *array);
+
+ // ensure that the array is as dense as we expect
+ ASSERT_TRUE(array->value_offsets()->Equals(*expected_offsets->data()->buffers[1]));
+ for (int i = 0; i < type->num_fields(); ++i) {
+ ASSERT_ARRAYS_EQUAL(*checked_cast<const UnionArray&>(*expected).field(i),
+ *array->field(i));
+ }
+}
+
+TEST(TestSparseUnion, UnionOfStructs) {
+ std::vector<std::shared_ptr<Field>> fields = {
+ field("ab", struct_({field("alpha", float64()), field("bravo", utf8())})),
+ field("wtf", struct_({field("whiskey", int8()), field("tango", float64()),
+ field("foxtrot", list(int8()))})),
+ field("q", struct_({field("quebec", utf8())}))};
+ auto type = sparse_union(fields, {0, 23, 47});
+ auto array = ArrayFromJSON(type, R"([
+ [0, {"alpha": 0.0, "bravo": "charlie"}],
+ [23, {"whiskey": 99}],
+ [0, {"bravo": "mike"}],
+ null,
+ [23, {"tango": 8.25, "foxtrot": [0, 2, 3]}]
+ ])");
+
+ auto expected_types = ArrayFromJSON(int8(), "[0, 23, 0, 0, 23]");
+ ArrayVector expected_fields = {
+ ArrayFromJSON(fields[0]->type(), R"([
+ {"alpha": 0.0, "bravo": "charlie"},
+ null,
+ {"bravo": "mike"},
+ null,
+ null
+ ])"),
+ ArrayFromJSON(fields[1]->type(), R"([
+ null,
+ {"whiskey": 99},
+ null,
+ null,
+ {"tango": 8.25, "foxtrot": [0, 2, 3]}
+ ])"),
+ ArrayFromJSON(fields[2]->type(), "[null, null, null, null, null]")};
+
+ ASSERT_OK_AND_ASSIGN(auto expected,
+ SparseUnionArray::Make(*expected_types, expected_fields,
+ {"ab", "wtf", "q"}, {0, 23, 47}));
+
+ ASSERT_ARRAYS_EQUAL(*expected, *array);
+}
+
+TEST(TestDenseUnion, Errors) {
+ auto field_a = field("a", int8());
+ auto field_b = field("b", boolean());
+ std::shared_ptr<DataType> type = dense_union({field_a, field_b}, {4, 8});
+ std::shared_ptr<Array> array;
+
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"not a valid type_id\"]", &array));
+ ASSERT_RAISES(Invalid,
+ ArrayFromJSON(type, "[[0, 99]]", &array)); // 0 is not one of {4, 8}
+ ASSERT_RAISES(Invalid,
+ ArrayFromJSON(type, "[[4, \"\"]]", &array)); // "" is not a valid int8()
+
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"not a pair\"]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[0]]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[8, true, 1]]", &array));
+}
+
+TEST(TestSparseUnion, Errors) {
+ auto field_a = field("a", int8());
+ auto field_b = field("b", boolean());
+ std::shared_ptr<DataType> type = sparse_union({field_a, field_b}, {4, 8});
+ std::shared_ptr<Array> array;
+
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"not a valid type_id\"]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[0, 99]]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[4, \"\"]]", &array));
+
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[\"not a pair\"]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[0]]", &array));
+ ASSERT_RAISES(Invalid, ArrayFromJSON(type, "[[8, true, 1]]", &array));
+}
+
+TEST(TestNestedDictionary, ListOfDict) {
+ auto index_type = int8();
+ auto value_type = utf8();
+ auto dict_type = dictionary(index_type, value_type);
+ auto type = list(dict_type);
+
+ std::shared_ptr<Array> array, expected, indices, values, dicts, offsets;
+
+ ASSERT_OK(ArrayFromJSON(type, R"([["ab", "cd", null], null, ["cd", "cd"]])", &array));
+ ASSERT_OK(array->ValidateFull());
+
+ // Build expected array
+ ASSERT_OK(ArrayFromJSON(index_type, "[0, 1, null, 1, 1]", &indices));
+ ASSERT_OK(ArrayFromJSON(value_type, R"(["ab", "cd"])", &values));
+ ASSERT_OK_AND_ASSIGN(dicts, DictionaryArray::FromArrays(dict_type, indices, values));
+ ASSERT_OK(ArrayFromJSON(int32(), "[0, null, 3, 5]", &offsets));
+ ASSERT_OK_AND_ASSIGN(expected, ListArray::FromArrays(*offsets, *dicts));
+
+ AssertArraysEqual(*expected, *array, /*verbose=*/true);
+}
+
+TEST(TestDictArrayFromJSON, Basics) {
+ auto type = dictionary(int32(), utf8());
+ auto array =
+ DictArrayFromJSON(type, "[null, 2, 1, 0]", R"(["whiskey", "tango", "foxtrot"])");
+
+ auto expected_indices = ArrayFromJSON(int32(), "[null, 2, 1, 0]");
+ auto expected_dictionary = ArrayFromJSON(utf8(), R"(["whiskey", "tango", "foxtrot"])");
+
+ ASSERT_ARRAYS_EQUAL(DictionaryArray(type, expected_indices, expected_dictionary),
+ *array);
+}
+
+TEST(TestDictArrayFromJSON, Errors) {
+ auto type = dictionary(int32(), utf8());
+ std::shared_ptr<Array> array;
+
+ ASSERT_RAISES(Invalid,
+ DictArrayFromJSON(type, "[\"not a valid index\"]", "[\"\"]", &array));
+ ASSERT_RAISES(Invalid, DictArrayFromJSON(type, "[0, 1]", "[1]",
+ &array)); // dict value isn't string
+}
+
+TEST(TestScalarFromJSON, Basics) {
+ // Sanity check for common types (not exhaustive)
+ std::shared_ptr<Scalar> scalar;
+ AssertJSONScalar<Int64Type>(int64(), "4", true, 4);
+ AssertJSONScalar<Int64Type>(int64(), "null", false, 0);
+ AssertJSONScalar<StringType, std::shared_ptr<Buffer>>(utf8(), R"("")", true,
+ Buffer::FromString(""));
+ AssertJSONScalar<StringType, std::shared_ptr<Buffer>>(utf8(), R"("foo")", true,
+ Buffer::FromString("foo"));
+ AssertJSONScalar<StringType, std::shared_ptr<Buffer>>(utf8(), R"(null)", false,
+ Buffer::FromString(""));
+ AssertJSONScalar<NullType, std::nullptr_t>(null(), "null", false, nullptr);
+ AssertJSONScalar<BooleanType, bool>(boolean(), "true", true, true);
+ AssertJSONScalar<BooleanType, bool>(boolean(), "false", true, false);
+ AssertJSONScalar<BooleanType, bool>(boolean(), "null", false, false);
+ AssertJSONScalar<BooleanType, bool>(boolean(), "0", true, false);
+ AssertJSONScalar<BooleanType, bool>(boolean(), "1", true, true);
+ AssertJSONScalar<DoubleType>(float64(), "1.0", true, 1.0);
+ AssertJSONScalar<DoubleType>(float64(), "-0.0", true, -0.0);
+ ASSERT_OK(ScalarFromJSON(float64(), "NaN", &scalar));
+ ASSERT_TRUE(std::isnan(checked_cast<DoubleScalar&>(*scalar).value));
+ ASSERT_OK(ScalarFromJSON(float64(), "Inf", &scalar));
+ ASSERT_TRUE(std::isinf(checked_cast<DoubleScalar&>(*scalar).value));
+}
+
+TEST(TestScalarFromJSON, Errors) {
+ std::shared_ptr<Scalar> scalar;
+ ASSERT_RAISES(Invalid, ScalarFromJSON(int64(), "[0]", &scalar));
+ ASSERT_RAISES(Invalid, ScalarFromJSON(int64(), "[9223372036854775808]", &scalar));
+ ASSERT_RAISES(Invalid, ScalarFromJSON(int64(), "[-9223372036854775809]", &scalar));
+ ASSERT_RAISES(Invalid, ScalarFromJSON(uint64(), "[18446744073709551616]", &scalar));
+ ASSERT_RAISES(Invalid, ScalarFromJSON(uint64(), "[-1]", &scalar));
+ ASSERT_RAISES(Invalid, ScalarFromJSON(binary(), "0", &scalar));
+ ASSERT_RAISES(Invalid, ScalarFromJSON(binary(), "[]", &scalar));
+ ASSERT_RAISES(Invalid, ScalarFromJSON(boolean(), "0.0", &scalar));
+ ASSERT_RAISES(Invalid, ScalarFromJSON(boolean(), "\"true\"", &scalar));
+}
+
+TEST(TestDictScalarFromJSON, Basics) {
+ auto type = dictionary(int32(), utf8());
+ auto dict = R"(["whiskey", "tango", "foxtrot"])";
+ auto expected_dictionary = ArrayFromJSON(utf8(), dict);
+
+ for (auto index : {"null", "2", "1", "0"}) {
+ auto scalar = DictScalarFromJSON(type, index, dict);
+ auto expected_index = ScalarFromJSON(int32(), index);
+ AssertScalarsEqual(*DictionaryScalar::Make(expected_index, expected_dictionary),
+ *scalar, /*verbose=*/true);
+ ASSERT_OK(scalar->ValidateFull());
+ }
+}
+
+TEST(TestDictScalarFromJSON, Errors) {
+ auto type = dictionary(int32(), utf8());
+ std::shared_ptr<Scalar> scalar;
+
+ ASSERT_RAISES(Invalid,
+ DictScalarFromJSON(type, "\"not a valid index\"", "[\"\"]", &scalar));
+ ASSERT_RAISES(Invalid, DictScalarFromJSON(type, "0", "[1]",
+ &scalar)); // dict value isn't string
+}
+
+} // namespace json
+} // namespace internal
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/message.cc b/src/arrow/cpp/src/arrow/ipc/message.cc
new file mode 100644
index 000000000..197556efc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/message.cc
@@ -0,0 +1,931 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/ipc/message.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/device.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/ipc/metadata_internal.h"
+#include "arrow/ipc/options.h"
+#include "arrow/ipc/util.h"
+#include "arrow/status.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/ubsan.h"
+
+#include "generated/Message_generated.h"
+
+namespace arrow {
+
+class KeyValueMetadata;
+class MemoryPool;
+
+namespace ipc {
+
+class Message::MessageImpl {
+ public:
+ explicit MessageImpl(std::shared_ptr<Buffer> metadata, std::shared_ptr<Buffer> body)
+ : metadata_(std::move(metadata)), message_(nullptr), body_(std::move(body)) {}
+
+ Status Open() {
+ RETURN_NOT_OK(
+ internal::VerifyMessage(metadata_->data(), metadata_->size(), &message_));
+
+ // Check that the metadata version is supported
+ if (message_->version() < internal::kMinMetadataVersion) {
+ return Status::Invalid("Old metadata version not supported");
+ }
+
+ if (message_->version() > flatbuf::MetadataVersion::MAX) {
+ return Status::Invalid("Unsupported future MetadataVersion: ",
+ static_cast<int16_t>(message_->version()));
+ }
+
+ if (message_->custom_metadata() != nullptr) {
+ // Deserialize from Flatbuffers if first time called
+ std::shared_ptr<KeyValueMetadata> md;
+ RETURN_NOT_OK(internal::GetKeyValueMetadata(message_->custom_metadata(), &md));
+ custom_metadata_ = std::move(md); // const-ify
+ }
+
+ return Status::OK();
+ }
+
+ MessageType type() const {
+ switch (message_->header_type()) {
+ case flatbuf::MessageHeader::Schema:
+ return MessageType::SCHEMA;
+ case flatbuf::MessageHeader::DictionaryBatch:
+ return MessageType::DICTIONARY_BATCH;
+ case flatbuf::MessageHeader::RecordBatch:
+ return MessageType::RECORD_BATCH;
+ case flatbuf::MessageHeader::Tensor:
+ return MessageType::TENSOR;
+ case flatbuf::MessageHeader::SparseTensor:
+ return MessageType::SPARSE_TENSOR;
+ default:
+ return MessageType::NONE;
+ }
+ }
+
+ MetadataVersion version() const {
+ return internal::GetMetadataVersion(message_->version());
+ }
+
+ const void* header() const { return message_->header(); }
+
+ int64_t body_length() const { return message_->bodyLength(); }
+
+ std::shared_ptr<Buffer> body() const { return body_; }
+
+ std::shared_ptr<Buffer> metadata() const { return metadata_; }
+
+ const std::shared_ptr<const KeyValueMetadata>& custom_metadata() const {
+ return custom_metadata_;
+ }
+
+ private:
+ // The Flatbuffer metadata
+ std::shared_ptr<Buffer> metadata_;
+ const flatbuf::Message* message_;
+
+ // The reconstructed custom_metadata field from the Message Flatbuffer
+ std::shared_ptr<const KeyValueMetadata> custom_metadata_;
+
+ // The message body, if any
+ std::shared_ptr<Buffer> body_;
+};
+
+Message::Message(std::shared_ptr<Buffer> metadata, std::shared_ptr<Buffer> body) {
+ impl_.reset(new MessageImpl(std::move(metadata), std::move(body)));
+}
+
+Result<std::unique_ptr<Message>> Message::Open(std::shared_ptr<Buffer> metadata,
+ std::shared_ptr<Buffer> body) {
+ std::unique_ptr<Message> result(new Message(std::move(metadata), std::move(body)));
+ RETURN_NOT_OK(result->impl_->Open());
+ return std::move(result);
+}
+
+Message::~Message() {}
+
+std::shared_ptr<Buffer> Message::body() const { return impl_->body(); }
+
+int64_t Message::body_length() const { return impl_->body_length(); }
+
+std::shared_ptr<Buffer> Message::metadata() const { return impl_->metadata(); }
+
+MessageType Message::type() const { return impl_->type(); }
+
+MetadataVersion Message::metadata_version() const { return impl_->version(); }
+
+const void* Message::header() const { return impl_->header(); }
+
+const std::shared_ptr<const KeyValueMetadata>& Message::custom_metadata() const {
+ return impl_->custom_metadata();
+}
+
+bool Message::Equals(const Message& other) const {
+ int64_t metadata_bytes = std::min(metadata()->size(), other.metadata()->size());
+
+ if (!metadata()->Equals(*other.metadata(), metadata_bytes)) {
+ return false;
+ }
+
+ // Compare bodies, if they have them
+ auto this_body = body();
+ auto other_body = other.body();
+
+ const bool this_has_body = (this_body != nullptr) && (this_body->size() > 0);
+ const bool other_has_body = (other_body != nullptr) && (other_body->size() > 0);
+
+ if (this_has_body && other_has_body) {
+ return this_body->Equals(*other_body);
+ } else if (this_has_body ^ other_has_body) {
+ // One has a body but not the other
+ return false;
+ } else {
+ // Neither has a body
+ return true;
+ }
+}
+
+Status MaybeAlignMetadata(std::shared_ptr<Buffer>* metadata) {
+ if (reinterpret_cast<uintptr_t>((*metadata)->data()) % 8 != 0) {
+ // If the metadata memory is not aligned, we copy it here to avoid
+ // potential UBSAN issues from Flatbuffers
+ ARROW_ASSIGN_OR_RAISE(*metadata, (*metadata)->CopySlice(0, (*metadata)->size()));
+ }
+ return Status::OK();
+}
+
+Status CheckMetadataAndGetBodyLength(const Buffer& metadata, int64_t* body_length) {
+ const flatbuf::Message* fb_message = nullptr;
+ RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &fb_message));
+ *body_length = fb_message->bodyLength();
+ if (*body_length < 0) {
+ return Status::IOError("Invalid IPC message: negative bodyLength");
+ }
+ return Status::OK();
+}
+
+Result<std::unique_ptr<Message>> Message::ReadFrom(std::shared_ptr<Buffer> metadata,
+ io::InputStream* stream) {
+ std::unique_ptr<Message> result;
+ auto listener = std::make_shared<AssignMessageDecoderListener>(&result);
+ MessageDecoder decoder(listener, MessageDecoder::State::METADATA, metadata->size());
+ ARROW_RETURN_NOT_OK(decoder.Consume(metadata));
+
+ ARROW_ASSIGN_OR_RAISE(auto body, stream->Read(decoder.next_required_size()));
+ if (body->size() < decoder.next_required_size()) {
+ return Status::IOError("Expected to be able to read ", decoder.next_required_size(),
+ " bytes for message body, got ", body->size());
+ }
+ RETURN_NOT_OK(decoder.Consume(body));
+ return std::move(result);
+}
+
+Result<std::unique_ptr<Message>> Message::ReadFrom(const int64_t offset,
+ std::shared_ptr<Buffer> metadata,
+ io::RandomAccessFile* file) {
+ std::unique_ptr<Message> result;
+ auto listener = std::make_shared<AssignMessageDecoderListener>(&result);
+ MessageDecoder decoder(listener, MessageDecoder::State::METADATA, metadata->size());
+ ARROW_RETURN_NOT_OK(decoder.Consume(metadata));
+
+ ARROW_ASSIGN_OR_RAISE(auto body, file->ReadAt(offset, decoder.next_required_size()));
+ if (body->size() < decoder.next_required_size()) {
+ return Status::IOError("Expected to be able to read ", decoder.next_required_size(),
+ " bytes for message body, got ", body->size());
+ }
+ RETURN_NOT_OK(decoder.Consume(body));
+ return std::move(result);
+}
+
+Status WritePadding(io::OutputStream* stream, int64_t nbytes) {
+ while (nbytes > 0) {
+ const int64_t bytes_to_write = std::min<int64_t>(nbytes, kArrowAlignment);
+ RETURN_NOT_OK(stream->Write(kPaddingBytes, bytes_to_write));
+ nbytes -= bytes_to_write;
+ }
+ return Status::OK();
+}
+
+Status Message::SerializeTo(io::OutputStream* stream, const IpcWriteOptions& options,
+ int64_t* output_length) const {
+ int32_t metadata_length = 0;
+ RETURN_NOT_OK(WriteMessage(*metadata(), options, stream, &metadata_length));
+
+ *output_length = metadata_length;
+
+ auto body_buffer = body();
+ if (body_buffer) {
+ RETURN_NOT_OK(stream->Write(body_buffer));
+ *output_length += body_buffer->size();
+
+ DCHECK_GE(this->body_length(), body_buffer->size());
+
+ int64_t remainder = this->body_length() - body_buffer->size();
+ RETURN_NOT_OK(WritePadding(stream, remainder));
+ *output_length += remainder;
+ }
+ return Status::OK();
+}
+
+bool Message::Verify() const {
+ const flatbuf::Message* unused;
+ return internal::VerifyMessage(metadata()->data(), metadata()->size(), &unused).ok();
+}
+
+std::string FormatMessageType(MessageType type) {
+ switch (type) {
+ case MessageType::SCHEMA:
+ return "schema";
+ case MessageType::RECORD_BATCH:
+ return "record batch";
+ case MessageType::DICTIONARY_BATCH:
+ return "dictionary";
+ case MessageType::TENSOR:
+ return "tensor";
+ case MessageType::SPARSE_TENSOR:
+ return "sparse tensor";
+ default:
+ break;
+ }
+ return "unknown";
+}
+
+Result<std::unique_ptr<Message>> ReadMessage(int64_t offset, int32_t metadata_length,
+ io::RandomAccessFile* file) {
+ std::unique_ptr<Message> result;
+ auto listener = std::make_shared<AssignMessageDecoderListener>(&result);
+ MessageDecoder decoder(listener);
+
+ if (metadata_length < decoder.next_required_size()) {
+ return Status::Invalid("metadata_length should be at least ",
+ decoder.next_required_size());
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto metadata, file->ReadAt(offset, metadata_length));
+ if (metadata->size() < metadata_length) {
+ return Status::Invalid("Expected to read ", metadata_length,
+ " metadata bytes but got ", metadata->size());
+ }
+ ARROW_RETURN_NOT_OK(decoder.Consume(metadata));
+
+ switch (decoder.state()) {
+ case MessageDecoder::State::INITIAL:
+ return std::move(result);
+ case MessageDecoder::State::METADATA_LENGTH:
+ return Status::Invalid("metadata length is missing. File offset: ", offset,
+ ", metadata length: ", metadata_length);
+ case MessageDecoder::State::METADATA:
+ return Status::Invalid("flatbuffer size ", decoder.next_required_size(),
+ " invalid. File offset: ", offset,
+ ", metadata length: ", metadata_length);
+ case MessageDecoder::State::BODY: {
+ ARROW_ASSIGN_OR_RAISE(auto body, file->ReadAt(offset + metadata_length,
+ decoder.next_required_size()));
+ if (body->size() < decoder.next_required_size()) {
+ return Status::IOError("Expected to be able to read ",
+ decoder.next_required_size(),
+ " bytes for message body, got ", body->size());
+ }
+ RETURN_NOT_OK(decoder.Consume(body));
+ return std::move(result);
+ }
+ case MessageDecoder::State::EOS:
+ return Status::Invalid("Unexpected empty message in IPC file format");
+ default:
+ return Status::Invalid("Unexpected state: ", decoder.state());
+ }
+}
+
+Future<std::shared_ptr<Message>> ReadMessageAsync(int64_t offset, int32_t metadata_length,
+ int64_t body_length,
+ io::RandomAccessFile* file,
+ const io::IOContext& context) {
+ struct State {
+ std::unique_ptr<Message> result;
+ std::shared_ptr<MessageDecoderListener> listener;
+ std::shared_ptr<MessageDecoder> decoder;
+ };
+ auto state = std::make_shared<State>();
+ state->listener = std::make_shared<AssignMessageDecoderListener>(&state->result);
+ state->decoder = std::make_shared<MessageDecoder>(state->listener);
+
+ if (metadata_length < state->decoder->next_required_size()) {
+ return Status::Invalid("metadata_length should be at least ",
+ state->decoder->next_required_size());
+ }
+ return file->ReadAsync(context, offset, metadata_length + body_length)
+ .Then([=](std::shared_ptr<Buffer> metadata) -> Result<std::shared_ptr<Message>> {
+ if (metadata->size() < metadata_length) {
+ return Status::Invalid("Expected to read ", metadata_length,
+ " metadata bytes but got ", metadata->size());
+ }
+ ARROW_RETURN_NOT_OK(
+ state->decoder->Consume(SliceBuffer(metadata, 0, metadata_length)));
+ switch (state->decoder->state()) {
+ case MessageDecoder::State::INITIAL:
+ return std::move(state->result);
+ case MessageDecoder::State::METADATA_LENGTH:
+ return Status::Invalid("metadata length is missing. File offset: ", offset,
+ ", metadata length: ", metadata_length);
+ case MessageDecoder::State::METADATA:
+ return Status::Invalid("flatbuffer size ",
+ state->decoder->next_required_size(),
+ " invalid. File offset: ", offset,
+ ", metadata length: ", metadata_length);
+ case MessageDecoder::State::BODY: {
+ auto body = SliceBuffer(metadata, metadata_length, body_length);
+ if (body->size() < state->decoder->next_required_size()) {
+ return Status::IOError("Expected to be able to read ",
+ state->decoder->next_required_size(),
+ " bytes for message body, got ", body->size());
+ }
+ RETURN_NOT_OK(state->decoder->Consume(body));
+ return std::move(state->result);
+ }
+ case MessageDecoder::State::EOS:
+ return Status::Invalid("Unexpected empty message in IPC file format");
+ default:
+ return Status::Invalid("Unexpected state: ", state->decoder->state());
+ }
+ });
+}
+
+Status AlignStream(io::InputStream* stream, int32_t alignment) {
+ ARROW_ASSIGN_OR_RAISE(int64_t position, stream->Tell());
+ return stream->Advance(PaddedLength(position, alignment) - position);
+}
+
+Status AlignStream(io::OutputStream* stream, int32_t alignment) {
+ ARROW_ASSIGN_OR_RAISE(int64_t position, stream->Tell());
+ int64_t remainder = PaddedLength(position, alignment) - position;
+ if (remainder > 0) {
+ return stream->Write(kPaddingBytes, remainder);
+ }
+ return Status::OK();
+}
+
+Status CheckAligned(io::FileInterface* stream, int32_t alignment) {
+ ARROW_ASSIGN_OR_RAISE(int64_t position, stream->Tell());
+ if (position % alignment != 0) {
+ return Status::Invalid("Stream is not aligned pos: ", position,
+ " alignment: ", alignment);
+ } else {
+ return Status::OK();
+ }
+}
+
+Status DecodeMessage(MessageDecoder* decoder, io::InputStream* file) {
+ if (decoder->state() == MessageDecoder::State::INITIAL) {
+ uint8_t continuation[sizeof(int32_t)];
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, file->Read(sizeof(int32_t), &continuation));
+ if (bytes_read == 0) {
+ // EOS without indication
+ return Status::OK();
+ } else if (bytes_read != decoder->next_required_size()) {
+ return Status::Invalid("Corrupted message, only ", bytes_read, " bytes available");
+ }
+ ARROW_RETURN_NOT_OK(decoder->Consume(continuation, bytes_read));
+ }
+
+ if (decoder->state() == MessageDecoder::State::METADATA_LENGTH) {
+ // Valid IPC message, read the message length now
+ uint8_t metadata_length[sizeof(int32_t)];
+ ARROW_ASSIGN_OR_RAISE(int64_t bytes_read,
+ file->Read(sizeof(int32_t), &metadata_length));
+ if (bytes_read != decoder->next_required_size()) {
+ return Status::Invalid("Corrupted metadata length, only ", bytes_read,
+ " bytes available");
+ }
+ ARROW_RETURN_NOT_OK(decoder->Consume(metadata_length, bytes_read));
+ }
+
+ if (decoder->state() == MessageDecoder::State::EOS) {
+ return Status::OK();
+ }
+
+ auto metadata_length = decoder->next_required_size();
+ ARROW_ASSIGN_OR_RAISE(auto metadata, file->Read(metadata_length));
+ if (metadata->size() != metadata_length) {
+ return Status::Invalid("Expected to read ", metadata_length, " metadata bytes, but ",
+ "only read ", metadata->size());
+ }
+ ARROW_RETURN_NOT_OK(decoder->Consume(metadata));
+
+ if (decoder->state() == MessageDecoder::State::BODY) {
+ ARROW_ASSIGN_OR_RAISE(auto body, file->Read(decoder->next_required_size()));
+ if (body->size() < decoder->next_required_size()) {
+ return Status::IOError("Expected to be able to read ",
+ decoder->next_required_size(),
+ " bytes for message body, got ", body->size());
+ }
+ ARROW_RETURN_NOT_OK(decoder->Consume(body));
+ }
+
+ if (decoder->state() == MessageDecoder::State::INITIAL ||
+ decoder->state() == MessageDecoder::State::EOS) {
+ return Status::OK();
+ } else {
+ return Status::Invalid("Failed to decode message");
+ }
+}
+
+Result<std::unique_ptr<Message>> ReadMessage(io::InputStream* file, MemoryPool* pool) {
+ std::unique_ptr<Message> message;
+ auto listener = std::make_shared<AssignMessageDecoderListener>(&message);
+ MessageDecoder decoder(listener, pool);
+ ARROW_RETURN_NOT_OK(DecodeMessage(&decoder, file));
+ if (!message) {
+ return nullptr;
+ } else {
+ return std::move(message);
+ }
+}
+
+Status WriteMessage(const Buffer& message, const IpcWriteOptions& options,
+ io::OutputStream* file, int32_t* message_length) {
+ const int32_t prefix_size = options.write_legacy_ipc_format ? 4 : 8;
+ const int32_t flatbuffer_size = static_cast<int32_t>(message.size());
+
+ int32_t padded_message_length = static_cast<int32_t>(
+ PaddedLength(flatbuffer_size + prefix_size, options.alignment));
+
+ int32_t padding = padded_message_length - flatbuffer_size - prefix_size;
+
+ // The returned message size includes the length prefix, the flatbuffer,
+ // plus padding
+ *message_length = padded_message_length;
+
+ // ARROW-6314: Write continuation / padding token
+ if (!options.write_legacy_ipc_format) {
+ RETURN_NOT_OK(file->Write(&internal::kIpcContinuationToken, sizeof(int32_t)));
+ }
+
+ // Write the flatbuffer size prefix including padding in little endian
+ int32_t padded_flatbuffer_size =
+ BitUtil::ToLittleEndian(padded_message_length - prefix_size);
+ RETURN_NOT_OK(file->Write(&padded_flatbuffer_size, sizeof(int32_t)));
+
+ // Write the flatbuffer
+ RETURN_NOT_OK(file->Write(message.data(), flatbuffer_size));
+ if (padding > 0) {
+ RETURN_NOT_OK(file->Write(kPaddingBytes, padding));
+ }
+
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Implement MessageDecoder
+
+Status MessageDecoderListener::OnInitial() { return Status::OK(); }
+Status MessageDecoderListener::OnMetadataLength() { return Status::OK(); }
+Status MessageDecoderListener::OnMetadata() { return Status::OK(); }
+Status MessageDecoderListener::OnBody() { return Status::OK(); }
+Status MessageDecoderListener::OnEOS() { return Status::OK(); }
+
+static constexpr auto kMessageDecoderNextRequiredSizeInitial = sizeof(int32_t);
+static constexpr auto kMessageDecoderNextRequiredSizeMetadataLength = sizeof(int32_t);
+
+class MessageDecoder::MessageDecoderImpl {
+ public:
+ explicit MessageDecoderImpl(std::shared_ptr<MessageDecoderListener> listener,
+ State initial_state, int64_t initial_next_required_size,
+ MemoryPool* pool)
+ : listener_(std::move(listener)),
+ pool_(pool),
+ state_(initial_state),
+ next_required_size_(initial_next_required_size),
+ chunks_(),
+ buffered_size_(0),
+ metadata_(nullptr) {}
+
+ Status ConsumeData(const uint8_t* data, int64_t size) {
+ if (buffered_size_ == 0) {
+ while (size > 0 && size >= next_required_size_) {
+ auto used_size = next_required_size_;
+ switch (state_) {
+ case State::INITIAL:
+ RETURN_NOT_OK(ConsumeInitialData(data, next_required_size_));
+ break;
+ case State::METADATA_LENGTH:
+ RETURN_NOT_OK(ConsumeMetadataLengthData(data, next_required_size_));
+ break;
+ case State::METADATA: {
+ auto buffer = std::make_shared<Buffer>(data, next_required_size_);
+ RETURN_NOT_OK(ConsumeMetadataBuffer(buffer));
+ } break;
+ case State::BODY: {
+ auto buffer = std::make_shared<Buffer>(data, next_required_size_);
+ RETURN_NOT_OK(ConsumeBodyBuffer(buffer));
+ } break;
+ case State::EOS:
+ return Status::OK();
+ }
+ data += used_size;
+ size -= used_size;
+ }
+ }
+
+ if (size == 0) {
+ return Status::OK();
+ }
+
+ chunks_.push_back(std::make_shared<Buffer>(data, size));
+ buffered_size_ += size;
+ return ConsumeChunks();
+ }
+
+ Status ConsumeBuffer(std::shared_ptr<Buffer> buffer) {
+ if (buffered_size_ == 0) {
+ while (buffer->size() >= next_required_size_) {
+ auto used_size = next_required_size_;
+ switch (state_) {
+ case State::INITIAL:
+ RETURN_NOT_OK(ConsumeInitialBuffer(buffer));
+ break;
+ case State::METADATA_LENGTH:
+ RETURN_NOT_OK(ConsumeMetadataLengthBuffer(buffer));
+ break;
+ case State::METADATA:
+ if (buffer->size() == next_required_size_) {
+ return ConsumeMetadataBuffer(buffer);
+ } else {
+ auto sliced_buffer = SliceBuffer(buffer, 0, next_required_size_);
+ RETURN_NOT_OK(ConsumeMetadataBuffer(sliced_buffer));
+ }
+ break;
+ case State::BODY:
+ if (buffer->size() == next_required_size_) {
+ return ConsumeBodyBuffer(buffer);
+ } else {
+ auto sliced_buffer = SliceBuffer(buffer, 0, next_required_size_);
+ RETURN_NOT_OK(ConsumeBodyBuffer(sliced_buffer));
+ }
+ break;
+ case State::EOS:
+ return Status::OK();
+ }
+ if (buffer->size() == used_size) {
+ return Status::OK();
+ }
+ buffer = SliceBuffer(buffer, used_size);
+ }
+ }
+
+ if (buffer->size() == 0) {
+ return Status::OK();
+ }
+
+ buffered_size_ += buffer->size();
+ chunks_.push_back(std::move(buffer));
+ return ConsumeChunks();
+ }
+
+ int64_t next_required_size() const { return next_required_size_ - buffered_size_; }
+
+ MessageDecoder::State state() const { return state_; }
+
+ private:
+ Status ConsumeChunks() {
+ while (state_ != State::EOS) {
+ if (buffered_size_ < next_required_size_) {
+ return Status::OK();
+ }
+
+ switch (state_) {
+ case State::INITIAL:
+ RETURN_NOT_OK(ConsumeInitialChunks());
+ break;
+ case State::METADATA_LENGTH:
+ RETURN_NOT_OK(ConsumeMetadataLengthChunks());
+ break;
+ case State::METADATA:
+ RETURN_NOT_OK(ConsumeMetadataChunks());
+ break;
+ case State::BODY:
+ RETURN_NOT_OK(ConsumeBodyChunks());
+ break;
+ case State::EOS:
+ return Status::OK();
+ }
+ }
+
+ return Status::OK();
+ }
+
+ Status ConsumeInitialData(const uint8_t* data, int64_t size) {
+ return ConsumeInitial(BitUtil::FromLittleEndian(util::SafeLoadAs<int32_t>(data)));
+ }
+
+ Status ConsumeInitialBuffer(const std::shared_ptr<Buffer>& buffer) {
+ ARROW_ASSIGN_OR_RAISE(auto continuation, ConsumeDataBufferInt32(buffer));
+ return ConsumeInitial(BitUtil::FromLittleEndian(continuation));
+ }
+
+ Status ConsumeInitialChunks() {
+ int32_t continuation = 0;
+ RETURN_NOT_OK(ConsumeDataChunks(sizeof(int32_t), &continuation));
+ return ConsumeInitial(BitUtil::FromLittleEndian(continuation));
+ }
+
+ Status ConsumeInitial(int32_t continuation) {
+ if (continuation == internal::kIpcContinuationToken) {
+ state_ = State::METADATA_LENGTH;
+ next_required_size_ = kMessageDecoderNextRequiredSizeMetadataLength;
+ RETURN_NOT_OK(listener_->OnMetadataLength());
+ // Valid IPC message, read the message length now
+ return Status::OK();
+ } else if (continuation == 0) {
+ state_ = State::EOS;
+ next_required_size_ = 0;
+ RETURN_NOT_OK(listener_->OnEOS());
+ return Status::OK();
+ } else if (continuation > 0) {
+ state_ = State::METADATA;
+ // ARROW-6314: Backwards compatibility for reading old IPC
+ // messages produced prior to version 0.15.0
+ next_required_size_ = continuation;
+ RETURN_NOT_OK(listener_->OnMetadata());
+ return Status::OK();
+ } else {
+ return Status::IOError("Invalid IPC stream: negative continuation token");
+ }
+ }
+
+ Status ConsumeMetadataLengthData(const uint8_t* data, int64_t size) {
+ return ConsumeMetadataLength(
+ BitUtil::FromLittleEndian(util::SafeLoadAs<int32_t>(data)));
+ }
+
+ Status ConsumeMetadataLengthBuffer(const std::shared_ptr<Buffer>& buffer) {
+ ARROW_ASSIGN_OR_RAISE(auto metadata_length, ConsumeDataBufferInt32(buffer));
+ return ConsumeMetadataLength(BitUtil::FromLittleEndian(metadata_length));
+ }
+
+ Status ConsumeMetadataLengthChunks() {
+ int32_t metadata_length = 0;
+ RETURN_NOT_OK(ConsumeDataChunks(sizeof(int32_t), &metadata_length));
+ return ConsumeMetadataLength(BitUtil::FromLittleEndian(metadata_length));
+ }
+
+ Status ConsumeMetadataLength(int32_t metadata_length) {
+ if (metadata_length == 0) {
+ state_ = State::EOS;
+ next_required_size_ = 0;
+ RETURN_NOT_OK(listener_->OnEOS());
+ return Status::OK();
+ } else if (metadata_length > 0) {
+ state_ = State::METADATA;
+ next_required_size_ = metadata_length;
+ RETURN_NOT_OK(listener_->OnMetadata());
+ return Status::OK();
+ } else {
+ return Status::IOError("Invalid IPC message: negative metadata length");
+ }
+ }
+
+ Status ConsumeMetadataBuffer(const std::shared_ptr<Buffer>& buffer) {
+ if (buffer->is_cpu()) {
+ metadata_ = buffer;
+ } else {
+ ARROW_ASSIGN_OR_RAISE(metadata_,
+ Buffer::ViewOrCopy(buffer, CPUDevice::memory_manager(pool_)));
+ }
+ return ConsumeMetadata();
+ }
+
+ Status ConsumeMetadataChunks() {
+ if (chunks_[0]->size() >= next_required_size_) {
+ if (chunks_[0]->size() == next_required_size_) {
+ if (chunks_[0]->is_cpu()) {
+ metadata_ = std::move(chunks_[0]);
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ metadata_,
+ Buffer::ViewOrCopy(chunks_[0], CPUDevice::memory_manager(pool_)));
+ }
+ chunks_.erase(chunks_.begin());
+ } else {
+ metadata_ = SliceBuffer(chunks_[0], 0, next_required_size_);
+ if (!chunks_[0]->is_cpu()) {
+ ARROW_ASSIGN_OR_RAISE(
+ metadata_, Buffer::ViewOrCopy(metadata_, CPUDevice::memory_manager(pool_)));
+ }
+ chunks_[0] = SliceBuffer(chunks_[0], next_required_size_);
+ }
+ buffered_size_ -= next_required_size_;
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto metadata, AllocateBuffer(next_required_size_, pool_));
+ metadata_ = std::shared_ptr<Buffer>(metadata.release());
+ RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, metadata_->mutable_data()));
+ }
+ return ConsumeMetadata();
+ }
+
+ Status ConsumeMetadata() {
+ RETURN_NOT_OK(MaybeAlignMetadata(&metadata_));
+ int64_t body_length = -1;
+ RETURN_NOT_OK(CheckMetadataAndGetBodyLength(*metadata_, &body_length));
+
+ state_ = State::BODY;
+ next_required_size_ = body_length;
+ RETURN_NOT_OK(listener_->OnBody());
+ if (next_required_size_ == 0) {
+ ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(0, pool_));
+ std::shared_ptr<Buffer> shared_body(body.release());
+ return ConsumeBody(&shared_body);
+ } else {
+ return Status::OK();
+ }
+ }
+
+ Status ConsumeBodyBuffer(std::shared_ptr<Buffer> buffer) {
+ return ConsumeBody(&buffer);
+ }
+
+ Status ConsumeBodyChunks() {
+ if (chunks_[0]->size() >= next_required_size_) {
+ auto used_size = next_required_size_;
+ if (chunks_[0]->size() == next_required_size_) {
+ RETURN_NOT_OK(ConsumeBody(&chunks_[0]));
+ chunks_.erase(chunks_.begin());
+ } else {
+ auto body = SliceBuffer(chunks_[0], 0, next_required_size_);
+ RETURN_NOT_OK(ConsumeBody(&body));
+ chunks_[0] = SliceBuffer(chunks_[0], used_size);
+ }
+ buffered_size_ -= used_size;
+ return Status::OK();
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(next_required_size_, pool_));
+ RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, body->mutable_data()));
+ std::shared_ptr<Buffer> shared_body(body.release());
+ return ConsumeBody(&shared_body);
+ }
+ }
+
+ Status ConsumeBody(std::shared_ptr<Buffer>* buffer) {
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message,
+ Message::Open(metadata_, *buffer));
+
+ RETURN_NOT_OK(listener_->OnMessageDecoded(std::move(message)));
+ state_ = State::INITIAL;
+ next_required_size_ = kMessageDecoderNextRequiredSizeInitial;
+ RETURN_NOT_OK(listener_->OnInitial());
+ return Status::OK();
+ }
+
+ Result<int32_t> ConsumeDataBufferInt32(const std::shared_ptr<Buffer>& buffer) {
+ if (buffer->is_cpu()) {
+ return util::SafeLoadAs<int32_t>(buffer->data());
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto cpu_buffer,
+ Buffer::ViewOrCopy(buffer, CPUDevice::memory_manager(pool_)));
+ return util::SafeLoadAs<int32_t>(cpu_buffer->data());
+ }
+ }
+
+ Status ConsumeDataChunks(int64_t nbytes, void* out) {
+ size_t offset = 0;
+ size_t n_used_chunks = 0;
+ auto required_size = nbytes;
+ std::shared_ptr<Buffer> last_chunk;
+ for (auto& chunk : chunks_) {
+ if (!chunk->is_cpu()) {
+ ARROW_ASSIGN_OR_RAISE(
+ chunk, Buffer::ViewOrCopy(chunk, CPUDevice::memory_manager(pool_)));
+ }
+ auto data = chunk->data();
+ auto data_size = chunk->size();
+ auto copy_size = std::min(required_size, data_size);
+ memcpy(static_cast<uint8_t*>(out) + offset, data, copy_size);
+ n_used_chunks++;
+ offset += copy_size;
+ required_size -= copy_size;
+ if (required_size == 0) {
+ if (data_size != copy_size) {
+ last_chunk = SliceBuffer(chunk, copy_size);
+ }
+ break;
+ }
+ }
+ chunks_.erase(chunks_.begin(), chunks_.begin() + n_used_chunks);
+ if (last_chunk.get() != nullptr) {
+ chunks_.insert(chunks_.begin(), std::move(last_chunk));
+ }
+ buffered_size_ -= offset;
+ return Status::OK();
+ }
+
+ std::shared_ptr<MessageDecoderListener> listener_;
+ MemoryPool* pool_;
+ State state_;
+ int64_t next_required_size_;
+ std::vector<std::shared_ptr<Buffer>> chunks_;
+ int64_t buffered_size_;
+ std::shared_ptr<Buffer> metadata_; // Must be CPU buffer
+};
+
+MessageDecoder::MessageDecoder(std::shared_ptr<MessageDecoderListener> listener,
+ MemoryPool* pool) {
+ impl_.reset(new MessageDecoderImpl(std::move(listener), State::INITIAL,
+ kMessageDecoderNextRequiredSizeInitial, pool));
+}
+
+MessageDecoder::MessageDecoder(std::shared_ptr<MessageDecoderListener> listener,
+ State initial_state, int64_t initial_next_required_size,
+ MemoryPool* pool) {
+ impl_.reset(new MessageDecoderImpl(std::move(listener), initial_state,
+ initial_next_required_size, pool));
+}
+
+MessageDecoder::~MessageDecoder() {}
+
+Status MessageDecoder::Consume(const uint8_t* data, int64_t size) {
+ return impl_->ConsumeData(data, size);
+}
+
+Status MessageDecoder::Consume(std::shared_ptr<Buffer> buffer) {
+ return impl_->ConsumeBuffer(buffer);
+}
+
+int64_t MessageDecoder::next_required_size() const { return impl_->next_required_size(); }
+
+MessageDecoder::State MessageDecoder::state() const { return impl_->state(); }
+
+// ----------------------------------------------------------------------
+// Implement InputStream message reader
+
+/// \brief Implementation of MessageReader that reads from InputStream
+class InputStreamMessageReader : public MessageReader, public MessageDecoderListener {
+ public:
+ explicit InputStreamMessageReader(io::InputStream* stream)
+ : stream_(stream),
+ owned_stream_(),
+ message_(),
+ decoder_(std::shared_ptr<InputStreamMessageReader>(this, [](void*) {})) {}
+
+ explicit InputStreamMessageReader(const std::shared_ptr<io::InputStream>& owned_stream)
+ : InputStreamMessageReader(owned_stream.get()) {
+ owned_stream_ = owned_stream;
+ }
+
+ ~InputStreamMessageReader() {}
+
+ Status OnMessageDecoded(std::unique_ptr<Message> message) override {
+ message_ = std::move(message);
+ return Status::OK();
+ }
+
+ Result<std::unique_ptr<Message>> ReadNextMessage() override {
+ ARROW_RETURN_NOT_OK(DecodeMessage(&decoder_, stream_));
+ return std::move(message_);
+ }
+
+ private:
+ io::InputStream* stream_;
+ std::shared_ptr<io::InputStream> owned_stream_;
+ std::unique_ptr<Message> message_;
+ MessageDecoder decoder_;
+};
+
+std::unique_ptr<MessageReader> MessageReader::Open(io::InputStream* stream) {
+ return std::unique_ptr<MessageReader>(new InputStreamMessageReader(stream));
+}
+
+std::unique_ptr<MessageReader> MessageReader::Open(
+ const std::shared_ptr<io::InputStream>& owned_stream) {
+ return std::unique_ptr<MessageReader>(new InputStreamMessageReader(owned_stream));
+}
+
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/message.h b/src/arrow/cpp/src/arrow/ipc/message.h
new file mode 100644
index 000000000..b2683259c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/message.h
@@ -0,0 +1,536 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// C++ object model and user API for interprocess schema messaging
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "arrow/io/type_fwd.h"
+#include "arrow/ipc/type_fwd.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace ipc {
+
+struct IpcWriteOptions;
+
+// Read interface classes. We do not fully deserialize the flatbuffers so that
+// individual fields metadata can be retrieved from very large schema without
+//
+
+/// \class Message
+/// \brief An IPC message including metadata and body
+class ARROW_EXPORT Message {
+ public:
+ /// \brief Construct message, but do not validate
+ ///
+ /// Use at your own risk; Message::Open has more metadata validation
+ Message(std::shared_ptr<Buffer> metadata, std::shared_ptr<Buffer> body);
+
+ ~Message();
+
+ /// \brief Create and validate a Message instance from two buffers
+ ///
+ /// \param[in] metadata a buffer containing the Flatbuffer metadata
+ /// \param[in] body a buffer containing the message body, which may be null
+ /// \return the created message
+ static Result<std::unique_ptr<Message>> Open(std::shared_ptr<Buffer> metadata,
+ std::shared_ptr<Buffer> body);
+
+ /// \brief Read message body and create Message given Flatbuffer metadata
+ /// \param[in] metadata containing a serialized Message flatbuffer
+ /// \param[in] stream an InputStream
+ /// \return the created Message
+ ///
+ /// \note If stream supports zero-copy, this is zero-copy
+ static Result<std::unique_ptr<Message>> ReadFrom(std::shared_ptr<Buffer> metadata,
+ io::InputStream* stream);
+
+ /// \brief Read message body from position in file, and create Message given
+ /// the Flatbuffer metadata
+ /// \param[in] offset the position in the file where the message body starts.
+ /// \param[in] metadata containing a serialized Message flatbuffer
+ /// \param[in] file the seekable file interface to read from
+ /// \return the created Message
+ ///
+ /// \note If file supports zero-copy, this is zero-copy
+ static Result<std::unique_ptr<Message>> ReadFrom(const int64_t offset,
+ std::shared_ptr<Buffer> metadata,
+ io::RandomAccessFile* file);
+
+ /// \brief Return true if message type and contents are equal
+ ///
+ /// \param other another message
+ /// \return true if contents equal
+ bool Equals(const Message& other) const;
+
+ /// \brief the Message metadata
+ ///
+ /// \return buffer
+ std::shared_ptr<Buffer> metadata() const;
+
+ /// \brief Custom metadata serialized in metadata Flatbuffer. Returns nullptr
+ /// when none set
+ const std::shared_ptr<const KeyValueMetadata>& custom_metadata() const;
+
+ /// \brief the Message body, if any
+ ///
+ /// \return buffer is null if no body
+ std::shared_ptr<Buffer> body() const;
+
+ /// \brief The expected body length according to the metadata, for
+ /// verification purposes
+ int64_t body_length() const;
+
+ /// \brief The Message type
+ MessageType type() const;
+
+ /// \brief The Message metadata version
+ MetadataVersion metadata_version() const;
+
+ const void* header() const;
+
+ /// \brief Write length-prefixed metadata and body to output stream
+ ///
+ /// \param[in] file output stream to write to
+ /// \param[in] options IPC writing options including alignment
+ /// \param[out] output_length the number of bytes written
+ /// \return Status
+ Status SerializeTo(io::OutputStream* file, const IpcWriteOptions& options,
+ int64_t* output_length) const;
+
+ /// \brief Return true if the Message metadata passes Flatbuffer validation
+ bool Verify() const;
+
+ /// \brief Whether a given message type needs a body.
+ static bool HasBody(MessageType type) {
+ return type != MessageType::NONE && type != MessageType::SCHEMA;
+ }
+
+ private:
+ // Hide serialization details from user API
+ class MessageImpl;
+ std::unique_ptr<MessageImpl> impl_;
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Message);
+};
+
+ARROW_EXPORT std::string FormatMessageType(MessageType type);
+
+/// \class MessageDecoderListener
+/// \brief An abstract class to listen events from MessageDecoder.
+///
+/// This API is EXPERIMENTAL.
+///
+/// \since 0.17.0
+class ARROW_EXPORT MessageDecoderListener {
+ public:
+ virtual ~MessageDecoderListener() = default;
+
+ /// \brief Called when a message is decoded.
+ ///
+ /// MessageDecoder calls this method when it decodes a message. This
+ /// method is called multiple times when the target stream has
+ /// multiple messages.
+ ///
+ /// \param[in] message a decoded message
+ /// \return Status
+ virtual Status OnMessageDecoded(std::unique_ptr<Message> message) = 0;
+
+ /// \brief Called when the decoder state is changed to
+ /// MessageDecoder::State::INITIAL.
+ ///
+ /// The default implementation just returns arrow::Status::OK().
+ ///
+ /// \return Status
+ virtual Status OnInitial();
+
+ /// \brief Called when the decoder state is changed to
+ /// MessageDecoder::State::METADATA_LENGTH.
+ ///
+ /// The default implementation just returns arrow::Status::OK().
+ ///
+ /// \return Status
+ virtual Status OnMetadataLength();
+
+ /// \brief Called when the decoder state is changed to
+ /// MessageDecoder::State::METADATA.
+ ///
+ /// The default implementation just returns arrow::Status::OK().
+ ///
+ /// \return Status
+ virtual Status OnMetadata();
+
+ /// \brief Called when the decoder state is changed to
+ /// MessageDecoder::State::BODY.
+ ///
+ /// The default implementation just returns arrow::Status::OK().
+ ///
+ /// \return Status
+ virtual Status OnBody();
+
+ /// \brief Called when the decoder state is changed to
+ /// MessageDecoder::State::EOS.
+ ///
+ /// The default implementation just returns arrow::Status::OK().
+ ///
+ /// \return Status
+ virtual Status OnEOS();
+};
+
+/// \class AssignMessageDecoderListener
+/// \brief Assign a message decoded by MessageDecoder.
+///
+/// This API is EXPERIMENTAL.
+///
+/// \since 0.17.0
+class ARROW_EXPORT AssignMessageDecoderListener : public MessageDecoderListener {
+ public:
+ /// \brief Construct a listener that assigns a decoded message to the
+ /// specified location.
+ ///
+ /// \param[in] message a location to store the received message
+ explicit AssignMessageDecoderListener(std::unique_ptr<Message>* message)
+ : message_(message) {}
+
+ virtual ~AssignMessageDecoderListener() = default;
+
+ Status OnMessageDecoded(std::unique_ptr<Message> message) override {
+ *message_ = std::move(message);
+ return Status::OK();
+ }
+
+ private:
+ std::unique_ptr<Message>* message_;
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(AssignMessageDecoderListener);
+};
+
+/// \class MessageDecoder
+/// \brief Push style message decoder that receives data from user.
+///
+/// This API is EXPERIMENTAL.
+///
+/// \since 0.17.0
+class ARROW_EXPORT MessageDecoder {
+ public:
+ /// \brief State for reading a message
+ enum State {
+ /// The initial state. It requires one of the followings as the next data:
+ ///
+ /// * int32_t continuation token
+ /// * int32_t end-of-stream mark (== 0)
+ /// * int32_t metadata length (backward compatibility for
+ /// reading old IPC messages produced prior to version 0.15.0
+ INITIAL,
+
+ /// It requires int32_t metadata length.
+ METADATA_LENGTH,
+
+ /// It requires metadata.
+ METADATA,
+
+ /// It requires message body.
+ BODY,
+
+ /// The end-of-stream state. No more data is processed.
+ EOS,
+ };
+
+ /// \brief Construct a message decoder.
+ ///
+ /// \param[in] listener a MessageDecoderListener that responds events from
+ /// the decoder
+ /// \param[in] pool an optional MemoryPool to copy metadata on the
+ /// CPU, if required
+ explicit MessageDecoder(std::shared_ptr<MessageDecoderListener> listener,
+ MemoryPool* pool = default_memory_pool());
+
+ /// \brief Construct a message decoder with the specified state.
+ ///
+ /// This is a construct for advanced users that know how to decode
+ /// Message.
+ ///
+ /// \param[in] listener a MessageDecoderListener that responds events from
+ /// the decoder
+ /// \param[in] initial_state an initial state of the decode
+ /// \param[in] initial_next_required_size the number of bytes needed
+ /// to run the next action
+ /// \param[in] pool an optional MemoryPool to copy metadata on the
+ /// CPU, if required
+ MessageDecoder(std::shared_ptr<MessageDecoderListener> listener, State initial_state,
+ int64_t initial_next_required_size,
+ MemoryPool* pool = default_memory_pool());
+
+ virtual ~MessageDecoder();
+
+ /// \brief Feed data to the decoder as a raw data.
+ ///
+ /// If the decoder can decode one or more messages by the data, the
+ /// decoder calls listener->OnMessageDecoded() with a decoded
+ /// message multiple times.
+ ///
+ /// If the state of the decoder is changed, corresponding callbacks
+ /// on listener is called:
+ ///
+ /// * MessageDecoder::State::INITIAL: listener->OnInitial()
+ /// * MessageDecoder::State::METADATA_LENGTH: listener->OnMetadataLength()
+ /// * MessageDecoder::State::METADATA: listener->OnMetadata()
+ /// * MessageDecoder::State::BODY: listener->OnBody()
+ /// * MessageDecoder::State::EOS: listener->OnEOS()
+ ///
+ /// \param[in] data a raw data to be processed. This data isn't
+ /// copied. The passed memory must be kept alive through message
+ /// processing.
+ /// \param[in] size raw data size.
+ /// \return Status
+ Status Consume(const uint8_t* data, int64_t size);
+
+ /// \brief Feed data to the decoder as a Buffer.
+ ///
+ /// If the decoder can decode one or more messages by the Buffer,
+ /// the decoder calls listener->OnMessageDecoded() with a decoded
+ /// message multiple times.
+ ///
+ /// \param[in] buffer a Buffer to be processed.
+ /// \return Status
+ Status Consume(std::shared_ptr<Buffer> buffer);
+
+ /// \brief Return the number of bytes needed to advance the state of
+ /// the decoder.
+ ///
+ /// This method is provided for users who want to optimize performance.
+ /// Normal users don't need to use this method.
+ ///
+ /// Here is an example usage for normal users:
+ ///
+ /// ~~~{.cpp}
+ /// decoder.Consume(buffer1);
+ /// decoder.Consume(buffer2);
+ /// decoder.Consume(buffer3);
+ /// ~~~
+ ///
+ /// Decoder has internal buffer. If consumed data isn't enough to
+ /// advance the state of the decoder, consumed data is buffered to
+ /// the internal buffer. It causes performance overhead.
+ ///
+ /// If you pass next_required_size() size data to each Consume()
+ /// call, the decoder doesn't use its internal buffer. It improves
+ /// performance.
+ ///
+ /// Here is an example usage to avoid using internal buffer:
+ ///
+ /// ~~~{.cpp}
+ /// buffer1 = get_data(decoder.next_required_size());
+ /// decoder.Consume(buffer1);
+ /// buffer2 = get_data(decoder.next_required_size());
+ /// decoder.Consume(buffer2);
+ /// ~~~
+ ///
+ /// Users can use this method to avoid creating small
+ /// chunks. Message body must be contiguous data. If users pass
+ /// small chunks to the decoder, the decoder needs concatenate small
+ /// chunks internally. It causes performance overhead.
+ ///
+ /// Here is an example usage to reduce small chunks:
+ ///
+ /// ~~~{.cpp}
+ /// buffer = AllocateResizableBuffer();
+ /// while ((small_chunk = get_data(&small_chunk_size))) {
+ /// auto current_buffer_size = buffer->size();
+ /// buffer->Resize(current_buffer_size + small_chunk_size);
+ /// memcpy(buffer->mutable_data() + current_buffer_size,
+ /// small_chunk,
+ /// small_chunk_size);
+ /// if (buffer->size() < decoder.next_required_size()) {
+ /// continue;
+ /// }
+ /// std::shared_ptr<arrow::Buffer> chunk(buffer.release());
+ /// decoder.Consume(chunk);
+ /// buffer = AllocateResizableBuffer();
+ /// }
+ /// if (buffer->size() > 0) {
+ /// std::shared_ptr<arrow::Buffer> chunk(buffer.release());
+ /// decoder.Consume(chunk);
+ /// }
+ /// ~~~
+ ///
+ /// \return the number of bytes needed to advance the state of the
+ /// decoder
+ int64_t next_required_size() const;
+
+ /// \brief Return the current state of the decoder.
+ ///
+ /// This method is provided for users who want to optimize performance.
+ /// Normal users don't need to use this method.
+ ///
+ /// Decoder doesn't need Buffer to process data on the
+ /// MessageDecoder::State::INITIAL state and the
+ /// MessageDecoder::State::METADATA_LENGTH. Creating Buffer has
+ /// performance overhead. Advanced users can avoid creating Buffer
+ /// by checking the current state of the decoder:
+ ///
+ /// ~~~{.cpp}
+ /// switch (decoder.state()) {
+ /// MessageDecoder::State::INITIAL:
+ /// MessageDecoder::State::METADATA_LENGTH:
+ /// {
+ /// uint8_t data[sizeof(int32_t)];
+ /// auto data_size = input->Read(decoder.next_required_size(), data);
+ /// decoder.Consume(data, data_size);
+ /// }
+ /// break;
+ /// default:
+ /// {
+ /// auto buffer = input->Read(decoder.next_required_size());
+ /// decoder.Consume(buffer);
+ /// }
+ /// break;
+ /// }
+ /// ~~~
+ ///
+ /// \return the current state
+ State state() const;
+
+ private:
+ class MessageDecoderImpl;
+ std::unique_ptr<MessageDecoderImpl> impl_;
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(MessageDecoder);
+};
+
+/// \brief Abstract interface for a sequence of messages
+/// \since 0.5.0
+class ARROW_EXPORT MessageReader {
+ public:
+ virtual ~MessageReader() = default;
+
+ /// \brief Create MessageReader that reads from InputStream
+ static std::unique_ptr<MessageReader> Open(io::InputStream* stream);
+
+ /// \brief Create MessageReader that reads from owned InputStream
+ static std::unique_ptr<MessageReader> Open(
+ const std::shared_ptr<io::InputStream>& owned_stream);
+
+ /// \brief Read next Message from the interface
+ ///
+ /// \return an arrow::ipc::Message instance
+ virtual Result<std::unique_ptr<Message>> ReadNextMessage() = 0;
+};
+
+/// \brief Read encapsulated RPC message from position in file
+///
+/// Read a length-prefixed message flatbuffer starting at the indicated file
+/// offset. If the message has a body with non-zero length, it will also be
+/// read
+///
+/// The metadata_length includes at least the length prefix and the flatbuffer
+///
+/// \param[in] offset the position in the file where the message starts. The
+/// first 4 bytes after the offset are the message length
+/// \param[in] metadata_length the total number of bytes to read from file
+/// \param[in] file the seekable file interface to read from
+/// \return the message read
+ARROW_EXPORT
+Result<std::unique_ptr<Message>> ReadMessage(const int64_t offset,
+ const int32_t metadata_length,
+ io::RandomAccessFile* file);
+
+ARROW_EXPORT
+Future<std::shared_ptr<Message>> ReadMessageAsync(
+ const int64_t offset, const int32_t metadata_length, const int64_t body_length,
+ io::RandomAccessFile* file, const io::IOContext& context = io::default_io_context());
+
+/// \brief Advance stream to an 8-byte offset if its position is not a multiple
+/// of 8 already
+/// \param[in] stream an input stream
+/// \param[in] alignment the byte multiple for the metadata prefix, usually 8
+/// or 64, to ensure the body starts on a multiple of that alignment
+/// \return Status
+ARROW_EXPORT
+Status AlignStream(io::InputStream* stream, int32_t alignment = 8);
+
+/// \brief Advance stream to an 8-byte offset if its position is not a multiple
+/// of 8 already
+/// \param[in] stream an output stream
+/// \param[in] alignment the byte multiple for the metadata prefix, usually 8
+/// or 64, to ensure the body starts on a multiple of that alignment
+/// \return Status
+ARROW_EXPORT
+Status AlignStream(io::OutputStream* stream, int32_t alignment = 8);
+
+/// \brief Return error Status if file position is not a multiple of the
+/// indicated alignment
+ARROW_EXPORT
+Status CheckAligned(io::FileInterface* stream, int32_t alignment = 8);
+
+/// \brief Read encapsulated IPC message (metadata and body) from InputStream
+///
+/// Returns null if there are not enough bytes available or the
+/// message length is 0 (e.g. EOS in a stream)
+///
+/// \param[in] stream an input stream
+/// \param[in] pool an optional MemoryPool to copy metadata on the CPU, if required
+/// \return Message
+ARROW_EXPORT
+Result<std::unique_ptr<Message>> ReadMessage(io::InputStream* stream,
+ MemoryPool* pool = default_memory_pool());
+
+/// \brief Feed data from InputStream to MessageDecoder to decode an
+/// encapsulated IPC message (metadata and body)
+///
+/// This API is EXPERIMENTAL.
+///
+/// \param[in] decoder a decoder
+/// \param[in] stream an input stream
+/// \return Status
+///
+/// \since 0.17.0
+ARROW_EXPORT
+Status DecodeMessage(MessageDecoder* decoder, io::InputStream* stream);
+
+/// Write encapsulated IPC message Does not make assumptions about
+/// whether the stream is aligned already. Can write legacy (pre
+/// version 0.15.0) IPC message if option set
+///
+/// continuation: 0xFFFFFFFF
+/// message_size: int32
+/// message: const void*
+/// padding
+///
+///
+/// \param[in] message a buffer containing the metadata to write
+/// \param[in] options IPC writing options, including alignment and
+/// legacy message support
+/// \param[in,out] file the OutputStream to write to
+/// \param[out] message_length the total size of the payload written including
+/// padding
+/// \return Status
+Status WriteMessage(const Buffer& message, const IpcWriteOptions& options,
+ io::OutputStream* file, int32_t* message_length);
+
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/metadata_internal.cc b/src/arrow/cpp/src/arrow/ipc/metadata_internal.cc
new file mode 100644
index 000000000..f7fd46ee8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/metadata_internal.cc
@@ -0,0 +1,1497 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/ipc/metadata_internal.h"
+
+#include <cstdint>
+#include <memory>
+#include <sstream>
+#include <unordered_map>
+#include <utility>
+
+#include <flatbuffers/flatbuffers.h>
+
+#include "arrow/extension_type.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/ipc/dictionary.h"
+#include "arrow/ipc/message.h"
+#include "arrow/ipc/options.h"
+#include "arrow/ipc/util.h"
+#include "arrow/sparse_tensor.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/ubsan.h"
+#include "arrow/visitor_inline.h"
+
+#include "generated/File_generated.h"
+#include "generated/Message_generated.h"
+#include "generated/Schema_generated.h"
+#include "generated/SparseTensor_generated.h"
+#include "generated/Tensor_generated.h"
+
+namespace arrow {
+
+namespace flatbuf = org::apache::arrow::flatbuf;
+using internal::checked_cast;
+using internal::GetByteWidth;
+
+namespace ipc {
+namespace internal {
+
+using FBB = flatbuffers::FlatBufferBuilder;
+using DictionaryOffset = flatbuffers::Offset<flatbuf::DictionaryEncoding>;
+using FieldOffset = flatbuffers::Offset<flatbuf::Field>;
+using RecordBatchOffset = flatbuffers::Offset<flatbuf::RecordBatch>;
+using SparseTensorOffset = flatbuffers::Offset<flatbuf::SparseTensor>;
+using Offset = flatbuffers::Offset<void>;
+using FBString = flatbuffers::Offset<flatbuffers::String>;
+
+MetadataVersion GetMetadataVersion(flatbuf::MetadataVersion version) {
+ switch (version) {
+ case flatbuf::MetadataVersion::V1:
+ // Arrow 0.1
+ return MetadataVersion::V1;
+ case flatbuf::MetadataVersion::V2:
+ // Arrow 0.2
+ return MetadataVersion::V2;
+ case flatbuf::MetadataVersion::V3:
+ // Arrow 0.3 to 0.7.1
+ return MetadataVersion::V3;
+ case flatbuf::MetadataVersion::V4:
+ // Arrow 0.8 to 0.17
+ return MetadataVersion::V4;
+ case flatbuf::MetadataVersion::V5:
+ // Arrow >= 1.0
+ return MetadataVersion::V5;
+ // Add cases as other versions become available
+ default:
+ return MetadataVersion::V5;
+ }
+}
+
+flatbuf::MetadataVersion MetadataVersionToFlatbuffer(MetadataVersion version) {
+ switch (version) {
+ case MetadataVersion::V1:
+ return flatbuf::MetadataVersion::V1;
+ case MetadataVersion::V2:
+ return flatbuf::MetadataVersion::V2;
+ case MetadataVersion::V3:
+ return flatbuf::MetadataVersion::V3;
+ case MetadataVersion::V4:
+ return flatbuf::MetadataVersion::V4;
+ case MetadataVersion::V5:
+ return flatbuf::MetadataVersion::V5;
+ // Add cases as other versions become available
+ default:
+ return flatbuf::MetadataVersion::V5;
+ }
+}
+
+bool HasValidityBitmap(Type::type type_id, MetadataVersion version) {
+ // In V4, null types have no validity bitmap
+ // In V5 and later, null and union types have no validity bitmap
+ return (version < MetadataVersion::V5) ? (type_id != Type::NA)
+ : ::arrow::internal::HasValidityBitmap(type_id);
+}
+
+namespace {
+
+Status IntFromFlatbuffer(const flatbuf::Int* int_data, std::shared_ptr<DataType>* out) {
+ if (int_data->bitWidth() > 64) {
+ return Status::NotImplemented("Integers with more than 64 bits not implemented");
+ }
+ if (int_data->bitWidth() < 8) {
+ return Status::NotImplemented("Integers with less than 8 bits not implemented");
+ }
+
+ switch (int_data->bitWidth()) {
+ case 8:
+ *out = int_data->is_signed() ? int8() : uint8();
+ break;
+ case 16:
+ *out = int_data->is_signed() ? int16() : uint16();
+ break;
+ case 32:
+ *out = int_data->is_signed() ? int32() : uint32();
+ break;
+ case 64:
+ *out = int_data->is_signed() ? int64() : uint64();
+ break;
+ default:
+ return Status::NotImplemented("Integers not in cstdint are not implemented");
+ }
+ return Status::OK();
+}
+
+Status FloatFromFlatbuffer(const flatbuf::FloatingPoint* float_data,
+ std::shared_ptr<DataType>* out) {
+ if (float_data->precision() == flatbuf::Precision::HALF) {
+ *out = float16();
+ } else if (float_data->precision() == flatbuf::Precision::SINGLE) {
+ *out = float32();
+ } else {
+ *out = float64();
+ }
+ return Status::OK();
+}
+
+Offset IntToFlatbuffer(FBB& fbb, int bitWidth, bool is_signed) {
+ return flatbuf::CreateInt(fbb, bitWidth, is_signed).Union();
+}
+
+Offset FloatToFlatbuffer(FBB& fbb, flatbuf::Precision precision) {
+ return flatbuf::CreateFloatingPoint(fbb, precision).Union();
+}
+
+// ----------------------------------------------------------------------
+// Union implementation
+
+Status UnionFromFlatbuffer(const flatbuf::Union* union_data,
+ const std::vector<std::shared_ptr<Field>>& children,
+ std::shared_ptr<DataType>* out) {
+ UnionMode::type mode =
+ (union_data->mode() == flatbuf::UnionMode::Sparse ? UnionMode::SPARSE
+ : UnionMode::DENSE);
+
+ std::vector<int8_t> type_codes;
+
+ const flatbuffers::Vector<int32_t>* fb_type_ids = union_data->typeIds();
+ if (fb_type_ids == nullptr) {
+ for (int8_t i = 0; i < static_cast<int8_t>(children.size()); ++i) {
+ type_codes.push_back(i);
+ }
+ } else {
+ for (int32_t id : (*fb_type_ids)) {
+ const auto type_code = static_cast<int8_t>(id);
+ if (id != type_code) {
+ return Status::Invalid("union type id out of bounds");
+ }
+ type_codes.push_back(type_code);
+ }
+ }
+
+ if (mode == UnionMode::SPARSE) {
+ ARROW_ASSIGN_OR_RAISE(
+ *out, SparseUnionType::Make(std::move(children), std::move(type_codes)));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ *out, DenseUnionType::Make(std::move(children), std::move(type_codes)));
+ }
+ return Status::OK();
+}
+
+#define INT_TO_FB_CASE(BIT_WIDTH, IS_SIGNED) \
+ *out_type = flatbuf::Type::Int; \
+ *offset = IntToFlatbuffer(fbb, BIT_WIDTH, IS_SIGNED); \
+ break;
+
+static inline flatbuf::TimeUnit ToFlatbufferUnit(TimeUnit::type unit) {
+ switch (unit) {
+ case TimeUnit::SECOND:
+ return flatbuf::TimeUnit::SECOND;
+ case TimeUnit::MILLI:
+ return flatbuf::TimeUnit::MILLISECOND;
+ case TimeUnit::MICRO:
+ return flatbuf::TimeUnit::MICROSECOND;
+ case TimeUnit::NANO:
+ return flatbuf::TimeUnit::NANOSECOND;
+ default:
+ break;
+ }
+ return flatbuf::TimeUnit::MIN;
+}
+
+static inline TimeUnit::type FromFlatbufferUnit(flatbuf::TimeUnit unit) {
+ switch (unit) {
+ case flatbuf::TimeUnit::SECOND:
+ return TimeUnit::SECOND;
+ case flatbuf::TimeUnit::MILLISECOND:
+ return TimeUnit::MILLI;
+ case flatbuf::TimeUnit::MICROSECOND:
+ return TimeUnit::MICRO;
+ case flatbuf::TimeUnit::NANOSECOND:
+ return TimeUnit::NANO;
+ default:
+ break;
+ }
+ // cannot reach
+ return TimeUnit::SECOND;
+}
+
+Status ConcreteTypeFromFlatbuffer(flatbuf::Type type, const void* type_data,
+ const std::vector<std::shared_ptr<Field>>& children,
+ std::shared_ptr<DataType>* out) {
+ switch (type) {
+ case flatbuf::Type::NONE:
+ return Status::Invalid("Type metadata cannot be none");
+ case flatbuf::Type::Null:
+ *out = null();
+ return Status::OK();
+ case flatbuf::Type::Int:
+ return IntFromFlatbuffer(static_cast<const flatbuf::Int*>(type_data), out);
+ case flatbuf::Type::FloatingPoint:
+ return FloatFromFlatbuffer(static_cast<const flatbuf::FloatingPoint*>(type_data),
+ out);
+ case flatbuf::Type::Binary:
+ *out = binary();
+ return Status::OK();
+ case flatbuf::Type::LargeBinary:
+ *out = large_binary();
+ return Status::OK();
+ case flatbuf::Type::FixedSizeBinary: {
+ auto fw_binary = static_cast<const flatbuf::FixedSizeBinary*>(type_data);
+ return FixedSizeBinaryType::Make(fw_binary->byteWidth()).Value(out);
+ }
+ case flatbuf::Type::Utf8:
+ *out = utf8();
+ return Status::OK();
+ case flatbuf::Type::LargeUtf8:
+ *out = large_utf8();
+ return Status::OK();
+ case flatbuf::Type::Bool:
+ *out = boolean();
+ return Status::OK();
+ case flatbuf::Type::Decimal: {
+ auto dec_type = static_cast<const flatbuf::Decimal*>(type_data);
+ if (dec_type->bitWidth() == 128) {
+ return Decimal128Type::Make(dec_type->precision(), dec_type->scale()).Value(out);
+ } else if (dec_type->bitWidth() == 256) {
+ return Decimal256Type::Make(dec_type->precision(), dec_type->scale()).Value(out);
+ } else {
+ return Status::Invalid("Library only supports 128-bit or 256-bit decimal values");
+ }
+ }
+ case flatbuf::Type::Date: {
+ auto date_type = static_cast<const flatbuf::Date*>(type_data);
+ if (date_type->unit() == flatbuf::DateUnit::DAY) {
+ *out = date32();
+ } else {
+ *out = date64();
+ }
+ return Status::OK();
+ }
+ case flatbuf::Type::Time: {
+ auto time_type = static_cast<const flatbuf::Time*>(type_data);
+ TimeUnit::type unit = FromFlatbufferUnit(time_type->unit());
+ int32_t bit_width = time_type->bitWidth();
+ switch (unit) {
+ case TimeUnit::SECOND:
+ case TimeUnit::MILLI:
+ if (bit_width != 32) {
+ return Status::Invalid("Time is 32 bits for second/milli unit");
+ }
+ *out = time32(unit);
+ break;
+ default:
+ if (bit_width != 64) {
+ return Status::Invalid("Time is 64 bits for micro/nano unit");
+ }
+ *out = time64(unit);
+ break;
+ }
+ return Status::OK();
+ }
+ case flatbuf::Type::Timestamp: {
+ auto ts_type = static_cast<const flatbuf::Timestamp*>(type_data);
+ TimeUnit::type unit = FromFlatbufferUnit(ts_type->unit());
+ *out = timestamp(unit, StringFromFlatbuffers(ts_type->timezone()));
+ return Status::OK();
+ }
+ case flatbuf::Type::Duration: {
+ auto duration = static_cast<const flatbuf::Duration*>(type_data);
+ TimeUnit::type unit = FromFlatbufferUnit(duration->unit());
+ *out = arrow::duration(unit);
+ return Status::OK();
+ }
+
+ case flatbuf::Type::Interval: {
+ auto i_type = static_cast<const flatbuf::Interval*>(type_data);
+ switch (i_type->unit()) {
+ case flatbuf::IntervalUnit::YEAR_MONTH: {
+ *out = month_interval();
+ return Status::OK();
+ }
+ case flatbuf::IntervalUnit::DAY_TIME: {
+ *out = day_time_interval();
+ return Status::OK();
+ }
+ case flatbuf::IntervalUnit::MONTH_DAY_NANO: {
+ *out = month_day_nano_interval();
+ return Status::OK();
+ }
+ }
+ return Status::NotImplemented("Unrecognized interval type.");
+ }
+
+ case flatbuf::Type::List:
+ if (children.size() != 1) {
+ return Status::Invalid("List must have exactly 1 child field");
+ }
+ *out = std::make_shared<ListType>(children[0]);
+ return Status::OK();
+ case flatbuf::Type::LargeList:
+ if (children.size() != 1) {
+ return Status::Invalid("LargeList must have exactly 1 child field");
+ }
+ *out = std::make_shared<LargeListType>(children[0]);
+ return Status::OK();
+ case flatbuf::Type::Map:
+ if (children.size() != 1) {
+ return Status::Invalid("Map must have exactly 1 child field");
+ }
+ if (children[0]->nullable() || children[0]->type()->id() != Type::STRUCT ||
+ children[0]->type()->num_fields() != 2) {
+ return Status::Invalid("Map's key-item pairs must be non-nullable structs");
+ }
+ if (children[0]->type()->field(0)->nullable()) {
+ return Status::Invalid("Map's keys must be non-nullable");
+ } else {
+ auto map = static_cast<const flatbuf::Map*>(type_data);
+ *out = std::make_shared<MapType>(children[0]->type()->field(0)->type(),
+ children[0]->type()->field(1)->type(),
+ map->keysSorted());
+ }
+ return Status::OK();
+ case flatbuf::Type::FixedSizeList:
+ if (children.size() != 1) {
+ return Status::Invalid("FixedSizeList must have exactly 1 child field");
+ } else {
+ auto fs_list = static_cast<const flatbuf::FixedSizeList*>(type_data);
+ *out = std::make_shared<FixedSizeListType>(children[0], fs_list->listSize());
+ }
+ return Status::OK();
+ case flatbuf::Type::Struct_:
+ *out = std::make_shared<StructType>(children);
+ return Status::OK();
+ case flatbuf::Type::Union:
+ return UnionFromFlatbuffer(static_cast<const flatbuf::Union*>(type_data), children,
+ out);
+ default:
+ return Status::Invalid("Unrecognized type:" +
+ std::to_string(static_cast<int>(type)));
+ }
+}
+
+Status TensorTypeToFlatbuffer(FBB& fbb, const DataType& type, flatbuf::Type* out_type,
+ Offset* offset) {
+ switch (type.id()) {
+ case Type::UINT8:
+ INT_TO_FB_CASE(8, false);
+ case Type::INT8:
+ INT_TO_FB_CASE(8, true);
+ case Type::UINT16:
+ INT_TO_FB_CASE(16, false);
+ case Type::INT16:
+ INT_TO_FB_CASE(16, true);
+ case Type::UINT32:
+ INT_TO_FB_CASE(32, false);
+ case Type::INT32:
+ INT_TO_FB_CASE(32, true);
+ case Type::UINT64:
+ INT_TO_FB_CASE(64, false);
+ case Type::INT64:
+ INT_TO_FB_CASE(64, true);
+ case Type::HALF_FLOAT:
+ *out_type = flatbuf::Type::FloatingPoint;
+ *offset = FloatToFlatbuffer(fbb, flatbuf::Precision::HALF);
+ break;
+ case Type::FLOAT:
+ *out_type = flatbuf::Type::FloatingPoint;
+ *offset = FloatToFlatbuffer(fbb, flatbuf::Precision::SINGLE);
+ break;
+ case Type::DOUBLE:
+ *out_type = flatbuf::Type::FloatingPoint;
+ *offset = FloatToFlatbuffer(fbb, flatbuf::Precision::DOUBLE);
+ break;
+ default:
+ *out_type = flatbuf::Type::NONE; // Make clang-tidy happy
+ return Status::NotImplemented("Unable to convert type: ", type.ToString());
+ }
+ return Status::OK();
+}
+
+static Status GetDictionaryEncoding(FBB& fbb, const std::shared_ptr<Field>& field,
+ const DictionaryType& type, int64_t dictionary_id,
+ DictionaryOffset* out) {
+ // We assume that the dictionary index type (as an integer) has already been
+ // validated elsewhere, and can safely assume we are dealing with integers
+ const auto& index_type = checked_cast<const IntegerType&>(*type.index_type());
+
+ auto index_type_offset =
+ flatbuf::CreateInt(fbb, index_type.bit_width(), index_type.is_signed());
+
+ *out = flatbuf::CreateDictionaryEncoding(fbb, dictionary_id, index_type_offset,
+ type.ordered());
+ return Status::OK();
+}
+
+static KeyValueOffset AppendKeyValue(FBB& fbb, const std::string& key,
+ const std::string& value) {
+ return flatbuf::CreateKeyValue(fbb, fbb.CreateString(key), fbb.CreateString(value));
+}
+
+static void AppendKeyValueMetadata(FBB& fbb, const KeyValueMetadata& metadata,
+ std::vector<KeyValueOffset>* key_values) {
+ key_values->reserve(metadata.size());
+ for (int i = 0; i < metadata.size(); ++i) {
+ key_values->push_back(AppendKeyValue(fbb, metadata.key(i), metadata.value(i)));
+ }
+}
+
+class FieldToFlatbufferVisitor {
+ public:
+ FieldToFlatbufferVisitor(FBB& fbb, const DictionaryFieldMapper& mapper,
+ const FieldPosition& field_pos)
+ : fbb_(fbb), mapper_(mapper), field_pos_(field_pos) {}
+
+ Status VisitType(const DataType& type) { return VisitTypeInline(type, this); }
+
+ Status Visit(const NullType& type) {
+ fb_type_ = flatbuf::Type::Null;
+ type_offset_ = flatbuf::CreateNull(fbb_).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const BooleanType& type) {
+ fb_type_ = flatbuf::Type::Bool;
+ type_offset_ = flatbuf::CreateBool(fbb_).Union();
+ return Status::OK();
+ }
+
+ template <int BitWidth, bool IsSigned, typename T>
+ Status Visit(const T& type) {
+ fb_type_ = flatbuf::Type::Int;
+ type_offset_ = IntToFlatbuffer(fbb_, BitWidth, IsSigned);
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_integer<T, Status> Visit(const T& type) {
+ constexpr bool is_signed = is_signed_integer_type<T>::value;
+ return Visit<sizeof(typename T::c_type) * 8, is_signed>(type);
+ }
+
+ Status Visit(const HalfFloatType& type) {
+ fb_type_ = flatbuf::Type::FloatingPoint;
+ type_offset_ = FloatToFlatbuffer(fbb_, flatbuf::Precision::HALF);
+ return Status::OK();
+ }
+
+ Status Visit(const FloatType& type) {
+ fb_type_ = flatbuf::Type::FloatingPoint;
+ type_offset_ = FloatToFlatbuffer(fbb_, flatbuf::Precision::SINGLE);
+ return Status::OK();
+ }
+
+ Status Visit(const DoubleType& type) {
+ fb_type_ = flatbuf::Type::FloatingPoint;
+ type_offset_ = FloatToFlatbuffer(fbb_, flatbuf::Precision::DOUBLE);
+ return Status::OK();
+ }
+
+ Status Visit(const FixedSizeBinaryType& type) {
+ const auto& fw_type = checked_cast<const FixedSizeBinaryType&>(type);
+ fb_type_ = flatbuf::Type::FixedSizeBinary;
+ type_offset_ = flatbuf::CreateFixedSizeBinary(fbb_, fw_type.byte_width()).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const BinaryType& type) {
+ fb_type_ = flatbuf::Type::Binary;
+ type_offset_ = flatbuf::CreateBinary(fbb_).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const LargeBinaryType& type) {
+ fb_type_ = flatbuf::Type::LargeBinary;
+ type_offset_ = flatbuf::CreateLargeBinary(fbb_).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const StringType& type) {
+ fb_type_ = flatbuf::Type::Utf8;
+ type_offset_ = flatbuf::CreateUtf8(fbb_).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const LargeStringType& type) {
+ fb_type_ = flatbuf::Type::LargeUtf8;
+ type_offset_ = flatbuf::CreateLargeUtf8(fbb_).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const Date32Type& type) {
+ fb_type_ = flatbuf::Type::Date;
+ type_offset_ = flatbuf::CreateDate(fbb_, flatbuf::DateUnit::DAY).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const Date64Type& type) {
+ fb_type_ = flatbuf::Type::Date;
+ type_offset_ = flatbuf::CreateDate(fbb_, flatbuf::DateUnit::MILLISECOND).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const Time32Type& type) {
+ const auto& time_type = checked_cast<const Time32Type&>(type);
+ fb_type_ = flatbuf::Type::Time;
+ type_offset_ =
+ flatbuf::CreateTime(fbb_, ToFlatbufferUnit(time_type.unit()), 32).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const Time64Type& type) {
+ const auto& time_type = checked_cast<const Time64Type&>(type);
+ fb_type_ = flatbuf::Type::Time;
+ type_offset_ =
+ flatbuf::CreateTime(fbb_, ToFlatbufferUnit(time_type.unit()), 64).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const TimestampType& type) {
+ const auto& ts_type = checked_cast<const TimestampType&>(type);
+ fb_type_ = flatbuf::Type::Timestamp;
+ flatbuf::TimeUnit fb_unit = ToFlatbufferUnit(ts_type.unit());
+ FBString fb_timezone = 0;
+ if (ts_type.timezone().size() > 0) {
+ fb_timezone = fbb_.CreateString(ts_type.timezone());
+ }
+ type_offset_ = flatbuf::CreateTimestamp(fbb_, fb_unit, fb_timezone).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const DurationType& type) {
+ fb_type_ = flatbuf::Type::Duration;
+ flatbuf::TimeUnit fb_unit = ToFlatbufferUnit(type.unit());
+ type_offset_ = flatbuf::CreateDuration(fbb_, fb_unit).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const DayTimeIntervalType& type) {
+ fb_type_ = flatbuf::Type::Interval;
+ type_offset_ = flatbuf::CreateInterval(fbb_, flatbuf::IntervalUnit::DAY_TIME).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const MonthDayNanoIntervalType& type) {
+ fb_type_ = flatbuf::Type::Interval;
+ type_offset_ =
+ flatbuf::CreateInterval(fbb_, flatbuf::IntervalUnit::MONTH_DAY_NANO).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const MonthIntervalType& type) {
+ fb_type_ = flatbuf::Type::Interval;
+ type_offset_ =
+ flatbuf::CreateInterval(fbb_, flatbuf::IntervalUnit::YEAR_MONTH).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal128Type& type) {
+ const auto& dec_type = checked_cast<const Decimal128Type&>(type);
+ fb_type_ = flatbuf::Type::Decimal;
+ type_offset_ = flatbuf::CreateDecimal(fbb_, dec_type.precision(), dec_type.scale(),
+ /*bitWidth=*/128)
+ .Union();
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal256Type& type) {
+ const auto& dec_type = checked_cast<const Decimal256Type&>(type);
+ fb_type_ = flatbuf::Type::Decimal;
+ type_offset_ = flatbuf::CreateDecimal(fbb_, dec_type.precision(), dec_type.scale(),
+ /*bitWith=*/256)
+ .Union();
+ return Status::OK();
+ }
+
+ Status Visit(const ListType& type) {
+ fb_type_ = flatbuf::Type::List;
+ RETURN_NOT_OK(VisitChildFields(type));
+ type_offset_ = flatbuf::CreateList(fbb_).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const LargeListType& type) {
+ fb_type_ = flatbuf::Type::LargeList;
+ RETURN_NOT_OK(VisitChildFields(type));
+ type_offset_ = flatbuf::CreateLargeList(fbb_).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const MapType& type) {
+ fb_type_ = flatbuf::Type::Map;
+ RETURN_NOT_OK(VisitChildFields(type));
+ type_offset_ = flatbuf::CreateMap(fbb_, type.keys_sorted()).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const FixedSizeListType& type) {
+ fb_type_ = flatbuf::Type::FixedSizeList;
+ RETURN_NOT_OK(VisitChildFields(type));
+ type_offset_ = flatbuf::CreateFixedSizeList(fbb_, type.list_size()).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const StructType& type) {
+ fb_type_ = flatbuf::Type::Struct_;
+ RETURN_NOT_OK(VisitChildFields(type));
+ type_offset_ = flatbuf::CreateStruct_(fbb_).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const UnionType& type) {
+ fb_type_ = flatbuf::Type::Union;
+ RETURN_NOT_OK(VisitChildFields(type));
+
+ const auto& union_type = checked_cast<const UnionType&>(type);
+
+ flatbuf::UnionMode mode = union_type.mode() == UnionMode::SPARSE
+ ? flatbuf::UnionMode::Sparse
+ : flatbuf::UnionMode::Dense;
+
+ std::vector<int32_t> type_ids;
+ type_ids.reserve(union_type.type_codes().size());
+ for (uint8_t code : union_type.type_codes()) {
+ type_ids.push_back(code);
+ }
+
+ auto fb_type_ids = fbb_.CreateVector(type_ids.data(), type_ids.size());
+
+ type_offset_ = flatbuf::CreateUnion(fbb_, mode, fb_type_ids).Union();
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryType& type) {
+ // In this library, the dictionary "type" is a logical construct. Here we
+ // pass through to the value type, as we've already captured the index
+ // type in the DictionaryEncoding metadata in the parent field
+ return VisitType(*checked_cast<const DictionaryType&>(type).value_type());
+ }
+
+ Status Visit(const ExtensionType& type) {
+ RETURN_NOT_OK(VisitType(*type.storage_type()));
+ extra_type_metadata_[kExtensionTypeKeyName] = type.extension_name();
+ extra_type_metadata_[kExtensionMetadataKeyName] = type.Serialize();
+ return Status::OK();
+ }
+
+ Status VisitChildFields(const DataType& type) {
+ for (int i = 0; i < type.num_fields(); ++i) {
+ FieldOffset child_offset;
+ FieldToFlatbufferVisitor child_visitor(fbb_, mapper_, field_pos_.child(i));
+ RETURN_NOT_OK(child_visitor.GetResult(type.field(i), &child_offset));
+ children_.push_back(child_offset);
+ }
+ return Status::OK();
+ }
+
+ Status GetResult(const std::shared_ptr<Field>& field, FieldOffset* offset) {
+ RETURN_NOT_OK(VisitType(*field->type()));
+
+ DictionaryOffset dictionary = 0;
+ const DataType* storage_type = field->type().get();
+ if (storage_type->id() == Type::EXTENSION) {
+ storage_type =
+ checked_cast<const ExtensionType&>(*storage_type).storage_type().get();
+ }
+ if (storage_type->id() == Type::DICTIONARY) {
+ ARROW_ASSIGN_OR_RAISE(const auto dictionary_id,
+ mapper_.GetFieldId(field_pos_.path()));
+ RETURN_NOT_OK(GetDictionaryEncoding(
+ fbb_, field, checked_cast<const DictionaryType&>(*storage_type), dictionary_id,
+ &dictionary));
+ }
+
+ auto metadata = field->metadata();
+
+ flatbuffers::Offset<KVVector> fb_custom_metadata;
+ std::vector<KeyValueOffset> key_values;
+ if (metadata != nullptr) {
+ AppendKeyValueMetadata(fbb_, *metadata, &key_values);
+ }
+
+ for (const auto& pair : extra_type_metadata_) {
+ key_values.push_back(AppendKeyValue(fbb_, pair.first, pair.second));
+ }
+
+ if (key_values.size() > 0) {
+ fb_custom_metadata = fbb_.CreateVector(key_values);
+ }
+
+ auto fb_name = fbb_.CreateString(field->name());
+ auto fb_children = fbb_.CreateVector(children_.data(), children_.size());
+ *offset =
+ flatbuf::CreateField(fbb_, fb_name, field->nullable(), fb_type_, type_offset_,
+ dictionary, fb_children, fb_custom_metadata);
+ return Status::OK();
+ }
+
+ private:
+ FBB& fbb_;
+ const DictionaryFieldMapper& mapper_;
+ FieldPosition field_pos_;
+ flatbuf::Type fb_type_;
+ Offset type_offset_;
+ std::vector<FieldOffset> children_;
+ std::unordered_map<std::string, std::string> extra_type_metadata_;
+};
+
+Status FieldFromFlatbuffer(const flatbuf::Field* field, FieldPosition field_pos,
+ DictionaryMemo* dictionary_memo, std::shared_ptr<Field>* out) {
+ std::shared_ptr<DataType> type;
+
+ std::shared_ptr<KeyValueMetadata> metadata;
+ RETURN_NOT_OK(internal::GetKeyValueMetadata(field->custom_metadata(), &metadata));
+
+ // Reconstruct the data type
+ // 1. Data type children
+ FieldVector child_fields;
+ const auto& children = field->children();
+ // As a tolerance, allow for a null children field meaning "no children" (ARROW-12100)
+ if (children != nullptr) {
+ child_fields.resize(children->size());
+ for (int i = 0; i < static_cast<int>(children->size()); ++i) {
+ RETURN_NOT_OK(FieldFromFlatbuffer(children->Get(i), field_pos.child(i),
+ dictionary_memo, &child_fields[i]));
+ }
+ }
+
+ // 2. Top-level concrete data type
+ auto type_data = field->type();
+ CHECK_FLATBUFFERS_NOT_NULL(type_data, "Field.type");
+ RETURN_NOT_OK(
+ ConcreteTypeFromFlatbuffer(field->type_type(), type_data, child_fields, &type));
+
+ // 3. Is it a dictionary type?
+ int64_t dictionary_id = -1;
+ std::shared_ptr<DataType> dict_value_type;
+ const flatbuf::DictionaryEncoding* encoding = field->dictionary();
+ if (encoding != nullptr) {
+ // The field is dictionary-encoded. Construct the DictionaryType
+ // based on the DictionaryEncoding metadata and record in the
+ // dictionary_memo
+ std::shared_ptr<DataType> index_type;
+ auto int_data = encoding->indexType();
+ CHECK_FLATBUFFERS_NOT_NULL(int_data, "DictionaryEncoding.indexType");
+ RETURN_NOT_OK(IntFromFlatbuffer(int_data, &index_type));
+ dict_value_type = type;
+ ARROW_ASSIGN_OR_RAISE(type,
+ DictionaryType::Make(index_type, type, encoding->isOrdered()));
+ dictionary_id = encoding->id();
+ }
+
+ // 4. Is it an extension type?
+ if (metadata != nullptr) {
+ // Look for extension metadata in custom_metadata field
+ int name_index = metadata->FindKey(kExtensionTypeKeyName);
+ if (name_index != -1) {
+ std::shared_ptr<ExtensionType> ext_type =
+ GetExtensionType(metadata->value(name_index));
+ if (ext_type != nullptr) {
+ int data_index = metadata->FindKey(kExtensionMetadataKeyName);
+ std::string type_data = data_index == -1 ? "" : metadata->value(data_index);
+
+ ARROW_ASSIGN_OR_RAISE(type, ext_type->Deserialize(type, type_data));
+ // Remove the metadata, for faithful roundtripping
+ if (data_index != -1) {
+ RETURN_NOT_OK(metadata->DeleteMany({name_index, data_index}));
+ } else {
+ RETURN_NOT_OK(metadata->Delete(name_index));
+ }
+ }
+ // NOTE: if extension type is unknown, we do not raise here and
+ // simply return the storage type.
+ }
+ }
+
+ // Reconstruct field
+ auto field_name = StringFromFlatbuffers(field->name());
+ *out =
+ ::arrow::field(std::move(field_name), type, field->nullable(), std::move(metadata));
+ if (dictionary_id != -1) {
+ // We need both the id -> type mapping (to find the value type when
+ // reading a dictionary batch)
+ // and the field path -> id mapping (to find the dictionary when
+ // reading a record batch)
+ RETURN_NOT_OK(dictionary_memo->fields().AddField(dictionary_id, field_pos.path()));
+ RETURN_NOT_OK(dictionary_memo->AddDictionaryType(dictionary_id, dict_value_type));
+ }
+ return Status::OK();
+}
+
+// will return the endianness of the system we are running on
+// based the NUMPY_API function. See NOTICE.txt
+flatbuf::Endianness endianness() {
+ union {
+ uint32_t i;
+ char c[4];
+ } bint = {0x01020304};
+
+ return bint.c[0] == 1 ? flatbuf::Endianness::Big : flatbuf::Endianness::Little;
+}
+
+flatbuffers::Offset<KVVector> SerializeCustomMetadata(
+ FBB& fbb, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ std::vector<KeyValueOffset> key_values;
+ if (metadata != nullptr) {
+ AppendKeyValueMetadata(fbb, *metadata, &key_values);
+ return fbb.CreateVector(key_values);
+ } else {
+ // null
+ return 0;
+ }
+}
+
+Status SchemaToFlatbuffer(FBB& fbb, const Schema& schema,
+ const DictionaryFieldMapper& mapper,
+ flatbuffers::Offset<flatbuf::Schema>* out) {
+ std::vector<FieldOffset> field_offsets;
+ FieldPosition pos;
+ for (int i = 0; i < schema.num_fields(); ++i) {
+ FieldOffset offset;
+ FieldToFlatbufferVisitor field_visitor(fbb, mapper, pos.child(i));
+ RETURN_NOT_OK(field_visitor.GetResult(schema.field(i), &offset));
+ field_offsets.push_back(offset);
+ }
+
+ auto fb_offsets = fbb.CreateVector(field_offsets);
+ *out = flatbuf::CreateSchema(fbb, endianness(), fb_offsets,
+ SerializeCustomMetadata(fbb, schema.metadata()));
+ return Status::OK();
+}
+
+Result<std::shared_ptr<Buffer>> WriteFBMessage(
+ FBB& fbb, flatbuf::MessageHeader header_type, flatbuffers::Offset<void> header,
+ int64_t body_length, MetadataVersion version,
+ const std::shared_ptr<const KeyValueMetadata>& custom_metadata, MemoryPool* pool) {
+ auto message = flatbuf::CreateMessage(fbb, MetadataVersionToFlatbuffer(version),
+ header_type, header, body_length,
+ SerializeCustomMetadata(fbb, custom_metadata));
+ fbb.Finish(message);
+ return WriteFlatbufferBuilder(fbb, pool);
+}
+
+using FieldNodeVector =
+ flatbuffers::Offset<flatbuffers::Vector<const flatbuf::FieldNode*>>;
+using BufferVector = flatbuffers::Offset<flatbuffers::Vector<const flatbuf::Buffer*>>;
+using BodyCompressionOffset = flatbuffers::Offset<flatbuf::BodyCompression>;
+
+static Status WriteFieldNodes(FBB& fbb, const std::vector<FieldMetadata>& nodes,
+ FieldNodeVector* out) {
+ std::vector<flatbuf::FieldNode> fb_nodes;
+ fb_nodes.reserve(nodes.size());
+
+ for (size_t i = 0; i < nodes.size(); ++i) {
+ const FieldMetadata& node = nodes[i];
+ if (node.offset != 0) {
+ return Status::Invalid("Field metadata for IPC must have offset 0");
+ }
+ fb_nodes.emplace_back(node.length, node.null_count);
+ }
+ *out = fbb.CreateVectorOfStructs(fb_nodes.data(), fb_nodes.size());
+ return Status::OK();
+}
+
+static Status WriteBuffers(FBB& fbb, const std::vector<BufferMetadata>& buffers,
+ BufferVector* out) {
+ std::vector<flatbuf::Buffer> fb_buffers;
+ fb_buffers.reserve(buffers.size());
+
+ for (size_t i = 0; i < buffers.size(); ++i) {
+ const BufferMetadata& buffer = buffers[i];
+ fb_buffers.emplace_back(buffer.offset, buffer.length);
+ }
+ *out = fbb.CreateVectorOfStructs(fb_buffers.data(), fb_buffers.size());
+
+ return Status::OK();
+}
+
+static Status GetBodyCompression(FBB& fbb, const IpcWriteOptions& options,
+ BodyCompressionOffset* out) {
+ if (options.codec != nullptr) {
+ flatbuf::CompressionType codec;
+ if (options.codec->compression_type() == Compression::LZ4_FRAME) {
+ codec = flatbuf::CompressionType::LZ4_FRAME;
+ } else if (options.codec->compression_type() == Compression::ZSTD) {
+ codec = flatbuf::CompressionType::ZSTD;
+ } else {
+ return Status::Invalid("Unsupported IPC compression codec: ",
+ options.codec->name());
+ }
+ *out = flatbuf::CreateBodyCompression(fbb, codec,
+ flatbuf::BodyCompressionMethod::BUFFER);
+ }
+ return Status::OK();
+}
+
+static Status MakeRecordBatch(FBB& fbb, int64_t length, int64_t body_length,
+ const std::vector<FieldMetadata>& nodes,
+ const std::vector<BufferMetadata>& buffers,
+ const IpcWriteOptions& options, RecordBatchOffset* offset) {
+ FieldNodeVector fb_nodes;
+ RETURN_NOT_OK(WriteFieldNodes(fbb, nodes, &fb_nodes));
+
+ BufferVector fb_buffers;
+ RETURN_NOT_OK(WriteBuffers(fbb, buffers, &fb_buffers));
+
+ BodyCompressionOffset fb_compression;
+ RETURN_NOT_OK(GetBodyCompression(fbb, options, &fb_compression));
+
+ *offset = flatbuf::CreateRecordBatch(fbb, length, fb_nodes, fb_buffers, fb_compression);
+ return Status::OK();
+}
+
+Status MakeSparseTensorIndexCOO(FBB& fbb, const SparseCOOIndex& sparse_index,
+ const std::vector<BufferMetadata>& buffers,
+ flatbuf::SparseTensorIndex* fb_sparse_index_type,
+ Offset* fb_sparse_index, size_t* num_buffers) {
+ *fb_sparse_index_type = flatbuf::SparseTensorIndex::SparseTensorIndexCOO;
+
+ // We assume that the value type of indices tensor is an integer.
+ const auto& index_value_type =
+ checked_cast<const IntegerType&>(*sparse_index.indices()->type());
+ auto indices_type_offset =
+ flatbuf::CreateInt(fbb, index_value_type.bit_width(), index_value_type.is_signed());
+
+ auto fb_strides = fbb.CreateVector(sparse_index.indices()->strides().data(),
+ sparse_index.indices()->strides().size());
+
+ const BufferMetadata& indices_metadata = buffers[0];
+ flatbuf::Buffer indices(indices_metadata.offset, indices_metadata.length);
+
+ *fb_sparse_index =
+ flatbuf::CreateSparseTensorIndexCOO(fbb, indices_type_offset, fb_strides, &indices,
+ sparse_index.is_canonical())
+ .Union();
+ *num_buffers = 1;
+ return Status::OK();
+}
+
+template <typename SparseIndexType>
+struct SparseMatrixCompressedAxis {};
+
+template <>
+struct SparseMatrixCompressedAxis<SparseCSRIndex> {
+ constexpr static const auto value = flatbuf::SparseMatrixCompressedAxis::Row;
+};
+
+template <>
+struct SparseMatrixCompressedAxis<SparseCSCIndex> {
+ constexpr static const auto value = flatbuf::SparseMatrixCompressedAxis::Column;
+};
+
+template <typename SparseIndexType>
+Status MakeSparseMatrixIndexCSX(FBB& fbb, const SparseIndexType& sparse_index,
+ const std::vector<BufferMetadata>& buffers,
+ flatbuf::SparseTensorIndex* fb_sparse_index_type,
+ Offset* fb_sparse_index, size_t* num_buffers) {
+ *fb_sparse_index_type = flatbuf::SparseTensorIndex::SparseMatrixIndexCSX;
+
+ // We assume that the value type of indptr tensor is an integer.
+ const auto& indptr_value_type =
+ checked_cast<const IntegerType&>(*sparse_index.indptr()->type());
+ auto indptr_type_offset = flatbuf::CreateInt(fbb, indptr_value_type.bit_width(),
+ indptr_value_type.is_signed());
+
+ const BufferMetadata& indptr_metadata = buffers[0];
+ flatbuf::Buffer indptr(indptr_metadata.offset, indptr_metadata.length);
+
+ // We assume that the value type of indices tensor is an integer.
+ const auto& indices_value_type =
+ checked_cast<const IntegerType&>(*sparse_index.indices()->type());
+ auto indices_type_offset = flatbuf::CreateInt(fbb, indices_value_type.bit_width(),
+ indices_value_type.is_signed());
+
+ const BufferMetadata& indices_metadata = buffers[1];
+ flatbuf::Buffer indices(indices_metadata.offset, indices_metadata.length);
+
+ auto compressedAxis = SparseMatrixCompressedAxis<SparseIndexType>::value;
+ *fb_sparse_index =
+ flatbuf::CreateSparseMatrixIndexCSX(fbb, compressedAxis, indptr_type_offset,
+ &indptr, indices_type_offset, &indices)
+ .Union();
+ *num_buffers = 2;
+ return Status::OK();
+}
+
+Status MakeSparseTensorIndexCSF(FBB& fbb, const SparseCSFIndex& sparse_index,
+ const std::vector<BufferMetadata>& buffers,
+ flatbuf::SparseTensorIndex* fb_sparse_index_type,
+ Offset* fb_sparse_index, size_t* num_buffers) {
+ *fb_sparse_index_type = flatbuf::SparseTensorIndex::SparseTensorIndexCSF;
+ const int ndim = static_cast<int>(sparse_index.axis_order().size());
+
+ // We assume that the value type of indptr tensor is an integer.
+ const auto& indptr_value_type =
+ checked_cast<const IntegerType&>(*sparse_index.indptr()[0]->type());
+ auto indptr_type_offset = flatbuf::CreateInt(fbb, indptr_value_type.bit_width(),
+ indptr_value_type.is_signed());
+
+ // We assume that the value type of indices tensor is an integer.
+ const auto& indices_value_type =
+ checked_cast<const IntegerType&>(*sparse_index.indices()[0]->type());
+ auto indices_type_offset = flatbuf::CreateInt(fbb, indices_value_type.bit_width(),
+ indices_value_type.is_signed());
+
+ const int64_t indptr_elem_size = GetByteWidth(indptr_value_type);
+ const int64_t indices_elem_size = GetByteWidth(indices_value_type);
+
+ int64_t offset = 0;
+ std::vector<flatbuf::Buffer> indptr, indices;
+
+ for (const std::shared_ptr<arrow::Tensor>& tensor : sparse_index.indptr()) {
+ const int64_t size = tensor->data()->size() / indptr_elem_size;
+ const int64_t padded_size = PaddedLength(tensor->data()->size(), kArrowIpcAlignment);
+
+ indptr.push_back({offset, size});
+ offset += padded_size;
+ }
+ for (const std::shared_ptr<arrow::Tensor>& tensor : sparse_index.indices()) {
+ const int64_t size = tensor->data()->size() / indices_elem_size;
+ const int64_t padded_size = PaddedLength(tensor->data()->size(), kArrowIpcAlignment);
+
+ indices.push_back({offset, size});
+ offset += padded_size;
+ }
+
+ auto fb_indices = fbb.CreateVectorOfStructs(indices);
+ auto fb_indptr = fbb.CreateVectorOfStructs(indptr);
+
+ std::vector<int> axis_order;
+ for (int i = 0; i < ndim; ++i) {
+ axis_order.emplace_back(static_cast<int>(sparse_index.axis_order()[i]));
+ }
+ auto fb_axis_order =
+ fbb.CreateVector(arrow::util::MakeNonNull(axis_order.data()), axis_order.size());
+
+ *fb_sparse_index =
+ flatbuf::CreateSparseTensorIndexCSF(fbb, indptr_type_offset, fb_indptr,
+ indices_type_offset, fb_indices, fb_axis_order)
+ .Union();
+ *num_buffers = 2 * ndim - 1;
+ return Status::OK();
+}
+
+Status MakeSparseTensorIndex(FBB& fbb, const SparseIndex& sparse_index,
+ const std::vector<BufferMetadata>& buffers,
+ flatbuf::SparseTensorIndex* fb_sparse_index_type,
+ Offset* fb_sparse_index, size_t* num_buffers) {
+ switch (sparse_index.format_id()) {
+ case SparseTensorFormat::COO:
+ RETURN_NOT_OK(MakeSparseTensorIndexCOO(
+ fbb, checked_cast<const SparseCOOIndex&>(sparse_index), buffers,
+ fb_sparse_index_type, fb_sparse_index, num_buffers));
+ break;
+
+ case SparseTensorFormat::CSR:
+ RETURN_NOT_OK(MakeSparseMatrixIndexCSX(
+ fbb, checked_cast<const SparseCSRIndex&>(sparse_index), buffers,
+ fb_sparse_index_type, fb_sparse_index, num_buffers));
+ break;
+
+ case SparseTensorFormat::CSC:
+ RETURN_NOT_OK(MakeSparseMatrixIndexCSX(
+ fbb, checked_cast<const SparseCSCIndex&>(sparse_index), buffers,
+ fb_sparse_index_type, fb_sparse_index, num_buffers));
+ break;
+
+ case SparseTensorFormat::CSF:
+ RETURN_NOT_OK(MakeSparseTensorIndexCSF(
+ fbb, checked_cast<const SparseCSFIndex&>(sparse_index), buffers,
+ fb_sparse_index_type, fb_sparse_index, num_buffers));
+ break;
+
+ default:
+ *fb_sparse_index_type = flatbuf::SparseTensorIndex::NONE; // Silence warnings
+ std::stringstream ss;
+ ss << "Unsupported sparse tensor format:: " << sparse_index.ToString() << std::endl;
+ return Status::NotImplemented(ss.str());
+ }
+
+ return Status::OK();
+}
+
+Status MakeSparseTensor(FBB& fbb, const SparseTensor& sparse_tensor, int64_t body_length,
+ const std::vector<BufferMetadata>& buffers,
+ SparseTensorOffset* offset) {
+ flatbuf::Type fb_type_type;
+ Offset fb_type;
+ RETURN_NOT_OK(
+ TensorTypeToFlatbuffer(fbb, *sparse_tensor.type(), &fb_type_type, &fb_type));
+
+ using TensorDimOffset = flatbuffers::Offset<flatbuf::TensorDim>;
+ std::vector<TensorDimOffset> dims;
+ for (int i = 0; i < sparse_tensor.ndim(); ++i) {
+ FBString name = fbb.CreateString(sparse_tensor.dim_name(i));
+ dims.push_back(flatbuf::CreateTensorDim(fbb, sparse_tensor.shape()[i], name));
+ }
+
+ auto fb_shape = fbb.CreateVector(dims);
+
+ flatbuf::SparseTensorIndex fb_sparse_index_type;
+ Offset fb_sparse_index;
+ size_t num_index_buffers = 0;
+ RETURN_NOT_OK(MakeSparseTensorIndex(fbb, *sparse_tensor.sparse_index(), buffers,
+ &fb_sparse_index_type, &fb_sparse_index,
+ &num_index_buffers));
+
+ const BufferMetadata& data_metadata = buffers[num_index_buffers];
+ flatbuf::Buffer data(data_metadata.offset, data_metadata.length);
+
+ const int64_t non_zero_length = sparse_tensor.non_zero_length();
+
+ *offset =
+ flatbuf::CreateSparseTensor(fbb, fb_type_type, fb_type, fb_shape, non_zero_length,
+ fb_sparse_index_type, fb_sparse_index, &data);
+
+ return Status::OK();
+}
+
+} // namespace
+
+Status GetKeyValueMetadata(const KVVector* fb_metadata,
+ std::shared_ptr<KeyValueMetadata>* out) {
+ if (fb_metadata == nullptr) {
+ *out = nullptr;
+ return Status::OK();
+ }
+
+ auto metadata = std::make_shared<KeyValueMetadata>();
+
+ metadata->reserve(fb_metadata->size());
+ for (const auto pair : *fb_metadata) {
+ CHECK_FLATBUFFERS_NOT_NULL(pair->key(), "custom_metadata.key");
+ CHECK_FLATBUFFERS_NOT_NULL(pair->value(), "custom_metadata.value");
+ metadata->Append(pair->key()->str(), pair->value()->str());
+ }
+
+ *out = std::move(metadata);
+ return Status::OK();
+}
+
+Status WriteSchemaMessage(const Schema& schema, const DictionaryFieldMapper& mapper,
+ const IpcWriteOptions& options, std::shared_ptr<Buffer>* out) {
+ FBB fbb;
+ flatbuffers::Offset<flatbuf::Schema> fb_schema;
+ RETURN_NOT_OK(SchemaToFlatbuffer(fbb, schema, mapper, &fb_schema));
+ return WriteFBMessage(fbb, flatbuf::MessageHeader::Schema, fb_schema.Union(),
+ /*body_length=*/0, options.metadata_version,
+ /*custom_metadata=*/nullptr, options.memory_pool)
+ .Value(out);
+}
+
+Status WriteRecordBatchMessage(
+ int64_t length, int64_t body_length,
+ const std::shared_ptr<const KeyValueMetadata>& custom_metadata,
+ const std::vector<FieldMetadata>& nodes, const std::vector<BufferMetadata>& buffers,
+ const IpcWriteOptions& options, std::shared_ptr<Buffer>* out) {
+ FBB fbb;
+ RecordBatchOffset record_batch;
+ RETURN_NOT_OK(
+ MakeRecordBatch(fbb, length, body_length, nodes, buffers, options, &record_batch));
+ return WriteFBMessage(fbb, flatbuf::MessageHeader::RecordBatch, record_batch.Union(),
+ body_length, options.metadata_version, custom_metadata,
+ options.memory_pool)
+ .Value(out);
+}
+
+Result<std::shared_ptr<Buffer>> WriteTensorMessage(const Tensor& tensor,
+ int64_t buffer_start_offset,
+ const IpcWriteOptions& options) {
+ using TensorDimOffset = flatbuffers::Offset<flatbuf::TensorDim>;
+ using TensorOffset = flatbuffers::Offset<flatbuf::Tensor>;
+
+ FBB fbb;
+ const int elem_size = GetByteWidth(*tensor.type());
+
+ flatbuf::Type fb_type_type;
+ Offset fb_type;
+ RETURN_NOT_OK(TensorTypeToFlatbuffer(fbb, *tensor.type(), &fb_type_type, &fb_type));
+
+ std::vector<TensorDimOffset> dims;
+ for (int i = 0; i < tensor.ndim(); ++i) {
+ FBString name = fbb.CreateString(tensor.dim_name(i));
+ dims.push_back(flatbuf::CreateTensorDim(fbb, tensor.shape()[i], name));
+ }
+
+ auto fb_shape = fbb.CreateVector(dims.data(), dims.size());
+
+ flatbuffers::Offset<flatbuffers::Vector<int64_t>> fb_strides;
+ fb_strides = fbb.CreateVector(tensor.strides().data(), tensor.strides().size());
+ int64_t body_length = tensor.size() * elem_size;
+ flatbuf::Buffer buffer(buffer_start_offset, body_length);
+
+ TensorOffset fb_tensor =
+ flatbuf::CreateTensor(fbb, fb_type_type, fb_type, fb_shape, fb_strides, &buffer);
+
+ return WriteFBMessage(fbb, flatbuf::MessageHeader::Tensor, fb_tensor.Union(),
+ body_length, options.metadata_version,
+ /*custom_metadata=*/nullptr, options.memory_pool);
+}
+
+Result<std::shared_ptr<Buffer>> WriteSparseTensorMessage(
+ const SparseTensor& sparse_tensor, int64_t body_length,
+ const std::vector<BufferMetadata>& buffers, const IpcWriteOptions& options) {
+ FBB fbb;
+ SparseTensorOffset fb_sparse_tensor;
+ RETURN_NOT_OK(
+ MakeSparseTensor(fbb, sparse_tensor, body_length, buffers, &fb_sparse_tensor));
+ return WriteFBMessage(fbb, flatbuf::MessageHeader::SparseTensor,
+ fb_sparse_tensor.Union(), body_length, options.metadata_version,
+ /*custom_metadata=*/nullptr, options.memory_pool);
+}
+
+Status WriteDictionaryMessage(
+ int64_t id, bool is_delta, int64_t length, int64_t body_length,
+ const std::shared_ptr<const KeyValueMetadata>& custom_metadata,
+ const std::vector<FieldMetadata>& nodes, const std::vector<BufferMetadata>& buffers,
+ const IpcWriteOptions& options, std::shared_ptr<Buffer>* out) {
+ FBB fbb;
+ RecordBatchOffset record_batch;
+ RETURN_NOT_OK(
+ MakeRecordBatch(fbb, length, body_length, nodes, buffers, options, &record_batch));
+ auto dictionary_batch =
+ flatbuf::CreateDictionaryBatch(fbb, id, record_batch, is_delta).Union();
+ return WriteFBMessage(fbb, flatbuf::MessageHeader::DictionaryBatch, dictionary_batch,
+ body_length, options.metadata_version, custom_metadata,
+ options.memory_pool)
+ .Value(out);
+}
+
+static flatbuffers::Offset<flatbuffers::Vector<const flatbuf::Block*>>
+FileBlocksToFlatbuffer(FBB& fbb, const std::vector<FileBlock>& blocks) {
+ std::vector<flatbuf::Block> fb_blocks;
+
+ for (const FileBlock& block : blocks) {
+ fb_blocks.emplace_back(block.offset, block.metadata_length, block.body_length);
+ }
+
+ return fbb.CreateVectorOfStructs(fb_blocks.data(), fb_blocks.size());
+}
+
+Status WriteFileFooter(const Schema& schema, const std::vector<FileBlock>& dictionaries,
+ const std::vector<FileBlock>& record_batches,
+ const std::shared_ptr<const KeyValueMetadata>& metadata,
+ io::OutputStream* out) {
+ FBB fbb;
+
+ flatbuffers::Offset<flatbuf::Schema> fb_schema;
+ DictionaryFieldMapper mapper(schema);
+ RETURN_NOT_OK(SchemaToFlatbuffer(fbb, schema, mapper, &fb_schema));
+
+#ifndef NDEBUG
+ for (size_t i = 0; i < dictionaries.size(); ++i) {
+ DCHECK(BitUtil::IsMultipleOf8(dictionaries[i].offset)) << i;
+ DCHECK(BitUtil::IsMultipleOf8(dictionaries[i].metadata_length)) << i;
+ DCHECK(BitUtil::IsMultipleOf8(dictionaries[i].body_length)) << i;
+ }
+
+ for (size_t i = 0; i < record_batches.size(); ++i) {
+ DCHECK(BitUtil::IsMultipleOf8(record_batches[i].offset)) << i;
+ DCHECK(BitUtil::IsMultipleOf8(record_batches[i].metadata_length)) << i;
+ DCHECK(BitUtil::IsMultipleOf8(record_batches[i].body_length)) << i;
+ }
+#endif
+
+ auto fb_dictionaries = FileBlocksToFlatbuffer(fbb, dictionaries);
+ auto fb_record_batches = FileBlocksToFlatbuffer(fbb, record_batches);
+
+ auto fb_custom_metadata = SerializeCustomMetadata(fbb, metadata);
+
+ auto footer =
+ flatbuf::CreateFooter(fbb, kCurrentMetadataVersion, fb_schema, fb_dictionaries,
+ fb_record_batches, fb_custom_metadata);
+ fbb.Finish(footer);
+
+ int32_t size = fbb.GetSize();
+
+ return out->Write(fbb.GetBufferPointer(), size);
+}
+
+// ----------------------------------------------------------------------
+
+Status GetSchema(const void* opaque_schema, DictionaryMemo* dictionary_memo,
+ std::shared_ptr<Schema>* out) {
+ auto schema = static_cast<const flatbuf::Schema*>(opaque_schema);
+ CHECK_FLATBUFFERS_NOT_NULL(schema, "schema");
+ CHECK_FLATBUFFERS_NOT_NULL(schema->fields(), "Schema.fields");
+ int num_fields = static_cast<int>(schema->fields()->size());
+
+ FieldPosition field_pos;
+
+ std::vector<std::shared_ptr<Field>> fields(num_fields);
+ for (int i = 0; i < num_fields; ++i) {
+ const flatbuf::Field* field = schema->fields()->Get(i);
+ // XXX I don't think this check is necessary (AP)
+ CHECK_FLATBUFFERS_NOT_NULL(field, "DictionaryEncoding.indexType");
+ RETURN_NOT_OK(
+ FieldFromFlatbuffer(field, field_pos.child(i), dictionary_memo, &fields[i]));
+ }
+
+ std::shared_ptr<KeyValueMetadata> metadata;
+ RETURN_NOT_OK(internal::GetKeyValueMetadata(schema->custom_metadata(), &metadata));
+ // set endianess using the value in flatbuf schema
+ auto endianness = schema->endianness() == flatbuf::Endianness::Little
+ ? Endianness::Little
+ : Endianness::Big;
+ *out = ::arrow::schema(std::move(fields), endianness, metadata);
+ return Status::OK();
+}
+
+Status GetTensorMetadata(const Buffer& metadata, std::shared_ptr<DataType>* type,
+ std::vector<int64_t>* shape, std::vector<int64_t>* strides,
+ std::vector<std::string>* dim_names) {
+ const flatbuf::Message* message = nullptr;
+ RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
+ auto tensor = message->header_as_Tensor();
+ if (tensor == nullptr) {
+ return Status::IOError("Header-type of flatbuffer-encoded Message is not Tensor.");
+ }
+
+ flatbuffers::uoffset_t ndim = tensor->shape()->size();
+
+ for (flatbuffers::uoffset_t i = 0; i < ndim; ++i) {
+ auto dim = tensor->shape()->Get(i);
+
+ shape->push_back(dim->size());
+ dim_names->push_back(StringFromFlatbuffers(dim->name()));
+ }
+
+ if (tensor->strides() && tensor->strides()->size() > 0) {
+ if (tensor->strides()->size() != ndim) {
+ return Status::IOError(
+ "The sizes of shape and strides in a tensor are mismatched.");
+ }
+
+ for (decltype(ndim) i = 0; i < ndim; ++i) {
+ strides->push_back(tensor->strides()->Get(i));
+ }
+ }
+
+ auto type_data = tensor->type(); // Required
+ return ConcreteTypeFromFlatbuffer(tensor->type_type(), type_data, {}, type);
+}
+
+Status GetSparseCOOIndexMetadata(const flatbuf::SparseTensorIndexCOO* sparse_index,
+ std::shared_ptr<DataType>* indices_type) {
+ return IntFromFlatbuffer(sparse_index->indicesType(), indices_type);
+}
+
+Status GetSparseCSXIndexMetadata(const flatbuf::SparseMatrixIndexCSX* sparse_index,
+ std::shared_ptr<DataType>* indptr_type,
+ std::shared_ptr<DataType>* indices_type) {
+ RETURN_NOT_OK(IntFromFlatbuffer(sparse_index->indptrType(), indptr_type));
+ RETURN_NOT_OK(IntFromFlatbuffer(sparse_index->indicesType(), indices_type));
+ return Status::OK();
+}
+
+Status GetSparseCSFIndexMetadata(const flatbuf::SparseTensorIndexCSF* sparse_index,
+ std::vector<int64_t>* axis_order,
+ std::vector<int64_t>* indices_size,
+ std::shared_ptr<DataType>* indptr_type,
+ std::shared_ptr<DataType>* indices_type) {
+ RETURN_NOT_OK(IntFromFlatbuffer(sparse_index->indptrType(), indptr_type));
+ RETURN_NOT_OK(IntFromFlatbuffer(sparse_index->indicesType(), indices_type));
+
+ const int ndim = static_cast<int>(sparse_index->axisOrder()->size());
+ for (int i = 0; i < ndim; ++i) {
+ axis_order->push_back(sparse_index->axisOrder()->Get(i));
+ indices_size->push_back(sparse_index->indicesBuffers()->Get(i)->length());
+ }
+
+ return Status::OK();
+}
+
+Status GetSparseTensorMetadata(const Buffer& metadata, std::shared_ptr<DataType>* type,
+ std::vector<int64_t>* shape,
+ std::vector<std::string>* dim_names,
+ int64_t* non_zero_length,
+ SparseTensorFormat::type* sparse_tensor_format_id) {
+ const flatbuf::Message* message = nullptr;
+ RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
+ auto sparse_tensor = message->header_as_SparseTensor();
+ if (sparse_tensor == nullptr) {
+ return Status::IOError(
+ "Header-type of flatbuffer-encoded Message is not SparseTensor.");
+ }
+ int ndim = static_cast<int>(sparse_tensor->shape()->size());
+
+ if (shape || dim_names) {
+ for (int i = 0; i < ndim; ++i) {
+ auto dim = sparse_tensor->shape()->Get(i);
+
+ if (shape) {
+ shape->push_back(dim->size());
+ }
+
+ if (dim_names) {
+ dim_names->push_back(StringFromFlatbuffers(dim->name()));
+ }
+ }
+ }
+
+ if (non_zero_length) {
+ *non_zero_length = sparse_tensor->non_zero_length();
+ }
+
+ if (sparse_tensor_format_id) {
+ switch (sparse_tensor->sparseIndex_type()) {
+ case flatbuf::SparseTensorIndex::SparseTensorIndexCOO:
+ *sparse_tensor_format_id = SparseTensorFormat::COO;
+ break;
+
+ case flatbuf::SparseTensorIndex::SparseMatrixIndexCSX: {
+ auto cs = sparse_tensor->sparseIndex_as_SparseMatrixIndexCSX();
+ switch (cs->compressedAxis()) {
+ case flatbuf::SparseMatrixCompressedAxis::Row:
+ *sparse_tensor_format_id = SparseTensorFormat::CSR;
+ break;
+
+ case flatbuf::SparseMatrixCompressedAxis::Column:
+ *sparse_tensor_format_id = SparseTensorFormat::CSC;
+ break;
+
+ default:
+ return Status::Invalid("Invalid value of SparseMatrixCompressedAxis");
+ }
+ } break;
+
+ case flatbuf::SparseTensorIndex::SparseTensorIndexCSF:
+ *sparse_tensor_format_id = SparseTensorFormat::CSF;
+ break;
+
+ default:
+ return Status::Invalid("Unrecognized sparse index type");
+ }
+ }
+
+ auto type_data = sparse_tensor->type(); // Required
+ if (type) {
+ return ConcreteTypeFromFlatbuffer(sparse_tensor->type_type(), type_data, {}, type);
+ } else {
+ return Status::OK();
+ }
+}
+
+} // namespace internal
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/metadata_internal.h b/src/arrow/cpp/src/arrow/ipc/metadata_internal.h
new file mode 100644
index 000000000..2afa95f6f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/metadata_internal.h
@@ -0,0 +1,228 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Internal metadata serialization matters
+
+#pragma once
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <flatbuffers/flatbuffers.h>
+
+#include "arrow/buffer.h"
+#include "arrow/io/type_fwd.h"
+#include "arrow/ipc/message.h"
+#include "arrow/result.h"
+#include "arrow/sparse_tensor.h"
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+#include "generated/Message_generated.h"
+#include "generated/Schema_generated.h"
+#include "generated/SparseTensor_generated.h" // IWYU pragma: keep
+
+namespace arrow {
+
+namespace flatbuf = org::apache::arrow::flatbuf;
+
+namespace ipc {
+
+class DictionaryFieldMapper;
+class DictionaryMemo;
+
+namespace internal {
+
+using KeyValueOffset = flatbuffers::Offset<flatbuf::KeyValue>;
+using KVVector = flatbuffers::Vector<KeyValueOffset>;
+
+// This 0xFFFFFFFF value is the first 4 bytes of a valid IPC message
+constexpr int32_t kIpcContinuationToken = -1;
+
+static constexpr flatbuf::MetadataVersion kCurrentMetadataVersion =
+ flatbuf::MetadataVersion::V5;
+
+static constexpr flatbuf::MetadataVersion kLatestMetadataVersion =
+ flatbuf::MetadataVersion::V5;
+
+static constexpr flatbuf::MetadataVersion kMinMetadataVersion =
+ flatbuf::MetadataVersion::V4;
+
+// These functions are used in unit tests
+ARROW_EXPORT
+MetadataVersion GetMetadataVersion(flatbuf::MetadataVersion version);
+
+ARROW_EXPORT
+flatbuf::MetadataVersion MetadataVersionToFlatbuffer(MetadataVersion version);
+
+// Whether the type has a validity bitmap in the given IPC version
+bool HasValidityBitmap(Type::type type_id, MetadataVersion version);
+
+static constexpr const char* kArrowMagicBytes = "ARROW1";
+
+struct FieldMetadata {
+ int64_t length;
+ int64_t null_count;
+ int64_t offset;
+};
+
+struct BufferMetadata {
+ /// The relative offset into the memory page to the starting byte of the buffer
+ int64_t offset;
+
+ /// Absolute length in bytes of the buffer
+ int64_t length;
+};
+
+struct FileBlock {
+ int64_t offset;
+ int32_t metadata_length;
+ int64_t body_length;
+};
+
+// Low-level utilities to help with reading Flatbuffers data.
+
+#define CHECK_FLATBUFFERS_NOT_NULL(fb_value, name) \
+ if ((fb_value) == NULLPTR) { \
+ return Status::IOError("Unexpected null field ", name, \
+ " in flatbuffer-encoded metadata"); \
+ }
+
+template <typename T>
+inline uint32_t FlatBuffersVectorSize(const flatbuffers::Vector<T>* vec) {
+ return (vec == NULLPTR) ? 0 : vec->size();
+}
+
+inline std::string StringFromFlatbuffers(const flatbuffers::String* s) {
+ return (s == NULLPTR) ? "" : s->str();
+}
+
+// Read interface classes. We do not fully deserialize the flatbuffers so that
+// individual fields metadata can be retrieved from very large schema without
+//
+
+// Construct a complete Schema from the message and add
+// dictionary-encoded fields to a DictionaryMemo instance. May be
+// expensive for very large schemas if you are only interested in a
+// few fields
+Status GetSchema(const void* opaque_schema, DictionaryMemo* dictionary_memo,
+ std::shared_ptr<Schema>* out);
+
+Status GetTensorMetadata(const Buffer& metadata, std::shared_ptr<DataType>* type,
+ std::vector<int64_t>* shape, std::vector<int64_t>* strides,
+ std::vector<std::string>* dim_names);
+
+// EXPERIMENTAL: Extracting metadata of a SparseCOOIndex from the message
+Status GetSparseCOOIndexMetadata(const flatbuf::SparseTensorIndexCOO* sparse_index,
+ std::shared_ptr<DataType>* indices_type);
+
+// EXPERIMENTAL: Extracting metadata of a SparseCSXIndex from the message
+Status GetSparseCSXIndexMetadata(const flatbuf::SparseMatrixIndexCSX* sparse_index,
+ std::shared_ptr<DataType>* indptr_type,
+ std::shared_ptr<DataType>* indices_type);
+
+// EXPERIMENTAL: Extracting metadata of a SparseCSFIndex from the message
+Status GetSparseCSFIndexMetadata(const flatbuf::SparseTensorIndexCSF* sparse_index,
+ std::vector<int64_t>* axis_order,
+ std::vector<int64_t>* indices_size,
+ std::shared_ptr<DataType>* indptr_type,
+ std::shared_ptr<DataType>* indices_type);
+
+// EXPERIMENTAL: Extracting metadata of a sparse tensor from the message
+Status GetSparseTensorMetadata(const Buffer& metadata, std::shared_ptr<DataType>* type,
+ std::vector<int64_t>* shape,
+ std::vector<std::string>* dim_names, int64_t* length,
+ SparseTensorFormat::type* sparse_tensor_format_id);
+
+Status GetKeyValueMetadata(const KVVector* fb_metadata,
+ std::shared_ptr<KeyValueMetadata>* out);
+
+template <typename RootType>
+bool VerifyFlatbuffers(const uint8_t* data, int64_t size) {
+ // Heuristic: tables in a Arrow flatbuffers buffer must take at least 1 bit
+ // each in average (ARROW-11559).
+ // Especially, the only recursive table (the `Field` table in Schema.fbs)
+ // must have a non-empty `type` member.
+ flatbuffers::Verifier verifier(
+ data, static_cast<size_t>(size),
+ /*max_depth=*/128,
+ /*max_tables=*/static_cast<flatbuffers::uoffset_t>(8 * size));
+ return verifier.VerifyBuffer<RootType>(nullptr);
+}
+
+static inline Status VerifyMessage(const uint8_t* data, int64_t size,
+ const flatbuf::Message** out) {
+ if (!VerifyFlatbuffers<flatbuf::Message>(data, size)) {
+ return Status::IOError("Invalid flatbuffers message.");
+ }
+ *out = flatbuf::GetMessage(data);
+ return Status::OK();
+}
+
+// Serialize arrow::Schema as a Flatbuffer
+Status WriteSchemaMessage(const Schema& schema, const DictionaryFieldMapper& mapper,
+ const IpcWriteOptions& options, std::shared_ptr<Buffer>* out);
+
+// This function is used in a unit test
+ARROW_EXPORT
+Status WriteRecordBatchMessage(
+ const int64_t length, const int64_t body_length,
+ const std::shared_ptr<const KeyValueMetadata>& custom_metadata,
+ const std::vector<FieldMetadata>& nodes, const std::vector<BufferMetadata>& buffers,
+ const IpcWriteOptions& options, std::shared_ptr<Buffer>* out);
+
+Result<std::shared_ptr<Buffer>> WriteTensorMessage(const Tensor& tensor,
+ const int64_t buffer_start_offset,
+ const IpcWriteOptions& options);
+
+Result<std::shared_ptr<Buffer>> WriteSparseTensorMessage(
+ const SparseTensor& sparse_tensor, int64_t body_length,
+ const std::vector<BufferMetadata>& buffers, const IpcWriteOptions& options);
+
+Status WriteFileFooter(const Schema& schema, const std::vector<FileBlock>& dictionaries,
+ const std::vector<FileBlock>& record_batches,
+ const std::shared_ptr<const KeyValueMetadata>& metadata,
+ io::OutputStream* out);
+
+Status WriteDictionaryMessage(
+ const int64_t id, const bool is_delta, const int64_t length,
+ const int64_t body_length,
+ const std::shared_ptr<const KeyValueMetadata>& custom_metadata,
+ const std::vector<FieldMetadata>& nodes, const std::vector<BufferMetadata>& buffers,
+ const IpcWriteOptions& options, std::shared_ptr<Buffer>* out);
+
+static inline Result<std::shared_ptr<Buffer>> WriteFlatbufferBuilder(
+ flatbuffers::FlatBufferBuilder& fbb, // NOLINT non-const reference
+ MemoryPool* pool = default_memory_pool()) {
+ int32_t size = fbb.GetSize();
+
+ ARROW_ASSIGN_OR_RAISE(auto result, AllocateBuffer(size, pool));
+
+ uint8_t* dst = result->mutable_data();
+ memcpy(dst, fbb.GetBufferPointer(), size);
+ return std::move(result);
+}
+
+} // namespace internal
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/options.cc b/src/arrow/cpp/src/arrow/ipc/options.cc
new file mode 100644
index 000000000..e5b14a47f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/options.cc
@@ -0,0 +1,41 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/ipc/options.h"
+
+#include "arrow/status.h"
+
+namespace arrow {
+namespace ipc {
+
+IpcWriteOptions IpcWriteOptions::Defaults() { return IpcWriteOptions(); }
+
+IpcReadOptions IpcReadOptions::Defaults() { return IpcReadOptions(); }
+
+namespace internal {
+
+Status CheckCompressionSupported(Compression::type codec) {
+ if (!(codec == Compression::LZ4_FRAME || codec == Compression::ZSTD)) {
+ return Status::Invalid("Only LZ4_FRAME and ZSTD compression allowed");
+ }
+ return Status::OK();
+}
+
+} // namespace internal
+
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/options.h b/src/arrow/cpp/src/arrow/ipc/options.h
new file mode 100644
index 000000000..2d0c2548f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/options.h
@@ -0,0 +1,160 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <vector>
+
+#include "arrow/ipc/type_fwd.h"
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/compression.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class MemoryPool;
+
+namespace ipc {
+
+// ARROW-109: We set this number arbitrarily to help catch user mistakes. For
+// deeply nested schemas, it is expected the user will indicate explicitly the
+// maximum allowed recursion depth
+constexpr int kMaxNestingDepth = 64;
+
+/// \brief Options for writing Arrow IPC messages
+struct ARROW_EXPORT IpcWriteOptions {
+ /// \brief If true, allow field lengths that don't fit in a signed 32-bit int.
+ ///
+ /// Some implementations may not be able to parse streams created with this option.
+ bool allow_64bit = false;
+
+ /// \brief The maximum permitted schema nesting depth.
+ int max_recursion_depth = kMaxNestingDepth;
+
+ /// \brief Write padding after memory buffers up to this multiple of bytes.
+ int32_t alignment = 8;
+
+ /// \brief Write the pre-0.15.0 IPC message format
+ ///
+ /// This legacy format consists of a 4-byte prefix instead of 8-byte.
+ bool write_legacy_ipc_format = false;
+
+ /// \brief The memory pool to use for allocations made during IPC writing
+ ///
+ /// While Arrow IPC is predominantly zero-copy, it may have to allocate
+ /// memory in some cases (for example if compression is enabled).
+ MemoryPool* memory_pool = default_memory_pool();
+
+ /// \brief Compression codec to use for record batch body buffers
+ ///
+ /// May only be UNCOMPRESSED, LZ4_FRAME and ZSTD.
+ std::shared_ptr<util::Codec> codec;
+
+ /// \brief Use global CPU thread pool to parallelize any computational tasks
+ /// like compression
+ bool use_threads = true;
+
+ /// \brief Whether to emit dictionary deltas
+ ///
+ /// If false, a changed dictionary for a given field will emit a full
+ /// dictionary replacement.
+ /// If true, a changed dictionary will be compared against the previous
+ /// version. If possible, a dictionary delta will be omitted, otherwise
+ /// a full dictionary replacement.
+ ///
+ /// Default is false to maximize stream compatibility.
+ ///
+ /// Also, note that if a changed dictionary is a nested dictionary,
+ /// then a delta is never emitted, for compatibility with the read path.
+ bool emit_dictionary_deltas = false;
+
+ /// \brief Whether to unify dictionaries for the IPC file format
+ ///
+ /// The IPC file format doesn't support dictionary replacements or deltas.
+ /// Therefore, chunks of a column with a dictionary type must have the same
+ /// dictionary in each record batch.
+ ///
+ /// If this option is true, RecordBatchWriter::WriteTable will attempt
+ /// to unify dictionaries across each table column. If this option is
+ /// false, unequal dictionaries across a table column will simply raise
+ /// an error.
+ ///
+ /// Note that enabling this option has a runtime cost. Also, not all types
+ /// currently support dictionary unification.
+ ///
+ /// This option is ignored for IPC streams, which support dictionary replacement
+ /// and deltas.
+ bool unify_dictionaries = false;
+
+ /// \brief Format version to use for IPC messages and their metadata.
+ ///
+ /// Presently using V5 version (readable by 1.0.0 and later).
+ /// V4 is also available (readable by 0.8.0 and later).
+ MetadataVersion metadata_version = MetadataVersion::V5;
+
+ static IpcWriteOptions Defaults();
+};
+
+#ifndef ARROW_NO_DEPRECATED_API
+using IpcOptions = IpcWriteOptions;
+#endif
+
+/// \brief Options for reading Arrow IPC messages
+struct ARROW_EXPORT IpcReadOptions {
+ /// \brief The maximum permitted schema nesting depth.
+ int max_recursion_depth = kMaxNestingDepth;
+
+ /// \brief The memory pool to use for allocations made during IPC reading
+ ///
+ /// While Arrow IPC is predominantly zero-copy, it may have to allocate
+ /// memory in some cases (for example if compression is enabled).
+ MemoryPool* memory_pool = default_memory_pool();
+
+ /// \brief Top-level schema fields to include when deserializing RecordBatch.
+ ///
+ /// If empty (the default), return all deserialized fields.
+ /// If non-empty, the values are the indices of fields in the top-level schema.
+ std::vector<int> included_fields;
+
+ /// \brief Use global CPU thread pool to parallelize any computational tasks
+ /// like decompression
+ bool use_threads = true;
+
+ /// \brief Whether to convert incoming data to platform-native endianness
+ ///
+ /// If the endianness of the received schema is not equal to platform-native
+ /// endianness, then all buffers with endian-sensitive data will be byte-swapped.
+ /// This includes the value buffers of numeric types, temporal types, decimal
+ /// types, as well as the offset buffers of variable-sized binary and list-like
+ /// types.
+ ///
+ /// Endianness conversion is achieved by the RecordBatchFileReader,
+ /// RecordBatchStreamReader and StreamDecoder classes.
+ bool ensure_native_endian = true;
+
+ static IpcReadOptions Defaults();
+};
+
+namespace internal {
+
+Status CheckCompressionSupported(Compression::type codec);
+
+} // namespace internal
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/read_write_benchmark.cc b/src/arrow/cpp/src/arrow/ipc/read_write_benchmark.cc
new file mode 100644
index 000000000..f5cc857ac
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/read_write_benchmark.cc
@@ -0,0 +1,262 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <cstdint>
+#include <sstream>
+#include <string>
+
+#include "arrow/io/file.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/api.h"
+#include "arrow/record_batch.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/type.h"
+
+namespace arrow {
+
+std::shared_ptr<RecordBatch> MakeRecordBatch(int64_t total_size, int64_t num_fields) {
+ int64_t length = total_size / num_fields / sizeof(int64_t);
+ random::RandomArrayGenerator rand(0x4f32a908);
+ auto type = arrow::int64();
+
+ ArrayVector arrays;
+ std::vector<std::shared_ptr<Field>> fields;
+ for (int64_t i = 0; i < num_fields; ++i) {
+ std::stringstream ss;
+ ss << "f" << i;
+ fields.push_back(field(ss.str(), type));
+ arrays.push_back(rand.Int64(length, 0, 100, 0.1));
+ }
+
+ auto schema = std::make_shared<Schema>(fields);
+ return RecordBatch::Make(schema, length, arrays);
+}
+
+static void WriteRecordBatch(benchmark::State& state) { // NOLINT non-const reference
+ // 1MB
+ constexpr int64_t kTotalSize = 1 << 20;
+ auto options = ipc::IpcWriteOptions::Defaults();
+
+ std::shared_ptr<ResizableBuffer> buffer = *AllocateResizableBuffer(1024);
+ auto record_batch = MakeRecordBatch(kTotalSize, state.range(0));
+
+ while (state.KeepRunning()) {
+ io::BufferOutputStream stream(buffer);
+ int32_t metadata_length;
+ int64_t body_length;
+ ABORT_NOT_OK(ipc::WriteRecordBatch(*record_batch, 0, &stream, &metadata_length,
+ &body_length, options));
+ }
+ state.SetBytesProcessed(int64_t(state.iterations()) * kTotalSize);
+}
+
+static void ReadRecordBatch(benchmark::State& state) { // NOLINT non-const reference
+ // 1MB
+ constexpr int64_t kTotalSize = 1 << 20;
+ auto options = ipc::IpcWriteOptions::Defaults();
+
+ std::shared_ptr<ResizableBuffer> buffer = *AllocateResizableBuffer(1024);
+ auto record_batch = MakeRecordBatch(kTotalSize, state.range(0));
+
+ io::BufferOutputStream stream(buffer);
+
+ int32_t metadata_length;
+ int64_t body_length;
+ ABORT_NOT_OK(ipc::WriteRecordBatch(*record_batch, 0, &stream, &metadata_length,
+ &body_length, options));
+
+ ipc::DictionaryMemo empty_memo;
+ while (state.KeepRunning()) {
+ io::BufferReader reader(buffer);
+ ABORT_NOT_OK(ipc::ReadRecordBatch(record_batch->schema(), &empty_memo,
+ ipc::IpcReadOptions::Defaults(), &reader));
+ }
+ state.SetBytesProcessed(int64_t(state.iterations()) * kTotalSize);
+}
+
+static void ReadStream(benchmark::State& state) { // NOLINT non-const reference
+ // 1MB
+ constexpr int64_t kTotalSize = 1 << 20;
+ auto options = ipc::IpcWriteOptions::Defaults();
+
+ std::shared_ptr<ResizableBuffer> buffer = *AllocateResizableBuffer(1024);
+ {
+ // Make Arrow IPC stream
+ auto record_batch = MakeRecordBatch(kTotalSize, state.range(0));
+
+ io::BufferOutputStream stream(buffer);
+
+ auto writer_result = ipc::MakeStreamWriter(&stream, record_batch->schema(), options);
+ ABORT_NOT_OK(writer_result);
+ auto writer = *writer_result;
+ ABORT_NOT_OK(writer->WriteRecordBatch(*record_batch));
+ ABORT_NOT_OK(writer->Close());
+ ABORT_NOT_OK(stream.Close());
+ }
+
+ ipc::DictionaryMemo empty_memo;
+ while (state.KeepRunning()) {
+ io::BufferReader input(buffer);
+ auto reader_result =
+ ipc::RecordBatchStreamReader::Open(&input, ipc::IpcReadOptions::Defaults());
+ ABORT_NOT_OK(reader_result);
+ auto reader = *reader_result;
+ while (true) {
+ std::shared_ptr<RecordBatch> batch;
+ ABORT_NOT_OK(reader->ReadNext(&batch));
+ if (batch.get() == nullptr) {
+ break;
+ }
+ }
+ }
+ state.SetBytesProcessed(int64_t(state.iterations()) * kTotalSize);
+}
+
+static void DecodeStream(benchmark::State& state) { // NOLINT non-const reference
+ // 1MB
+ constexpr int64_t kTotalSize = 1 << 20;
+ auto options = ipc::IpcWriteOptions::Defaults();
+
+ std::shared_ptr<ResizableBuffer> buffer = *AllocateResizableBuffer(1024);
+ auto record_batch = MakeRecordBatch(kTotalSize, state.range(0));
+
+ io::BufferOutputStream stream(buffer);
+
+ auto writer_result = ipc::MakeStreamWriter(&stream, record_batch->schema(), options);
+ ABORT_NOT_OK(writer_result);
+ auto writer = *writer_result;
+ ABORT_NOT_OK(writer->WriteRecordBatch(*record_batch));
+ ABORT_NOT_OK(writer->Close());
+
+ ipc::DictionaryMemo empty_memo;
+ while (state.KeepRunning()) {
+ class NullListener : public ipc::Listener {
+ Status OnRecordBatchDecoded(std::shared_ptr<RecordBatch> batch) override {
+ return Status::OK();
+ }
+ } listener;
+ ipc::StreamDecoder decoder(std::shared_ptr<NullListener>(&listener, [](void*) {}),
+ ipc::IpcReadOptions::Defaults());
+ ABORT_NOT_OK(decoder.Consume(buffer));
+ }
+ state.SetBytesProcessed(int64_t(state.iterations()) * kTotalSize);
+}
+
+#define GENERATE_COMPRESSED_DATA_IN_MEMORY() \
+ constexpr int64_t kBatchSize = 1 << 20; /* 1 MB */ \
+ constexpr int64_t kBatches = 16; \
+ auto options = ipc::IpcWriteOptions::Defaults(); \
+ ASSIGN_OR_ABORT(options.codec, \
+ arrow::util::Codec::Create(arrow::Compression::type::ZSTD)); \
+ std::shared_ptr<ResizableBuffer> buffer = *AllocateResizableBuffer(1024); \
+ { \
+ auto record_batch = MakeRecordBatch(kBatchSize, state.range(0)); \
+ io::BufferOutputStream stream(buffer); \
+ auto writer = *ipc::MakeFileWriter(&stream, record_batch->schema(), options); \
+ for (int i = 0; i < kBatches; i++) { \
+ ABORT_NOT_OK(writer->WriteRecordBatch(*record_batch)); \
+ } \
+ ABORT_NOT_OK(writer->Close()); \
+ ABORT_NOT_OK(stream.Close()); \
+ }
+
+#define GENERATE_DATA_IN_MEMORY() \
+ constexpr int64_t kBatchSize = 1 << 20; /* 1 MB */ \
+ constexpr int64_t kBatches = 1; \
+ auto options = ipc::IpcWriteOptions::Defaults(); \
+ std::shared_ptr<ResizableBuffer> buffer = *AllocateResizableBuffer(1024); \
+ { \
+ auto record_batch = MakeRecordBatch(kBatchSize, state.range(0)); \
+ io::BufferOutputStream stream(buffer); \
+ auto writer = *ipc::MakeFileWriter(&stream, record_batch->schema(), options); \
+ ABORT_NOT_OK(writer->WriteRecordBatch(*record_batch)); \
+ ABORT_NOT_OK(writer->Close()); \
+ ABORT_NOT_OK(stream.Close()); \
+ }
+
+#define GENERATE_DATA_TEMP_FILE() \
+ constexpr int64_t kBatchSize = 1 << 20; /* 1 MB */ \
+ constexpr int64_t kBatches = 16; \
+ auto options = ipc::IpcWriteOptions::Defaults(); \
+ ASSIGN_OR_ABORT(auto sink, io::FileOutputStream::Open("/tmp/benchmark.arrow")); \
+ { \
+ auto record_batch = MakeRecordBatch(kBatchSize, state.range(0)); \
+ auto writer = *ipc::MakeFileWriter(sink, record_batch->schema(), options); \
+ ABORT_NOT_OK(writer->WriteRecordBatch(*record_batch)); \
+ ABORT_NOT_OK(writer->Close()); \
+ ABORT_NOT_OK(sink->Close()); \
+ }
+
+#define READ_DATA_IN_MEMORY() auto input = std::make_shared<io::BufferReader>(buffer);
+#define READ_DATA_TEMP_FILE() \
+ ASSIGN_OR_ABORT(auto input, io::ReadableFile::Open("/tmp/benchmark.arrow"));
+#define READ_DATA_MMAP_FILE() \
+ ASSIGN_OR_ABORT(auto input, io::MemoryMappedFile::Open("/tmp/benchmark.arrow", \
+ io::FileMode::type::READ));
+
+#define READ_SYNC(NAME, GENERATE, READ) \
+ static void NAME(benchmark::State& state) { \
+ GENERATE(); \
+ for (auto _ : state) { \
+ READ(); \
+ auto reader = *ipc::RecordBatchFileReader::Open(input.get(), \
+ ipc::IpcReadOptions::Defaults()); \
+ const int num_batches = reader->num_record_batches(); \
+ for (int i = 0; i < num_batches; ++i) { \
+ auto batch = *reader->ReadRecordBatch(i); \
+ } \
+ } \
+ state.SetBytesProcessed(int64_t(state.iterations()) * kBatchSize * kBatches); \
+ } \
+ BENCHMARK(NAME)->RangeMultiplier(4)->Range(1, 1 << 13)->UseRealTime();
+
+#define READ_ASYNC(NAME, GENERATE, READ) \
+ static void NAME##Async(benchmark::State& state) { \
+ GENERATE(); \
+ for (auto _ : state) { \
+ READ(); \
+ auto reader = *ipc::RecordBatchFileReader::Open(input.get(), \
+ ipc::IpcReadOptions::Defaults()); \
+ ASSIGN_OR_ABORT(auto generator, reader->GetRecordBatchGenerator()); \
+ const int num_batches = reader->num_record_batches(); \
+ for (int i = 0; i < num_batches; ++i) { \
+ auto batch = *generator().result(); \
+ } \
+ } \
+ state.SetBytesProcessed(int64_t(state.iterations()) * kBatchSize * kBatches); \
+ } \
+ BENCHMARK(NAME##Async)->RangeMultiplier(4)->Range(1, 1 << 13)->UseRealTime();
+
+#define READ_BENCHMARK(NAME, GENERATE, READ) \
+ READ_SYNC(NAME, GENERATE, READ); \
+ READ_ASYNC(NAME, GENERATE, READ);
+
+READ_BENCHMARK(ReadFile, GENERATE_DATA_IN_MEMORY, READ_DATA_IN_MEMORY);
+READ_BENCHMARK(ReadTempFile, GENERATE_DATA_TEMP_FILE, READ_DATA_TEMP_FILE);
+READ_BENCHMARK(ReadMmapFile, GENERATE_DATA_TEMP_FILE, READ_DATA_MMAP_FILE);
+READ_BENCHMARK(ReadCompressedFile, GENERATE_COMPRESSED_DATA_IN_MEMORY,
+ READ_DATA_IN_MEMORY);
+
+BENCHMARK(WriteRecordBatch)->RangeMultiplier(4)->Range(1, 1 << 13)->UseRealTime();
+BENCHMARK(ReadRecordBatch)->RangeMultiplier(4)->Range(1, 1 << 13)->UseRealTime();
+BENCHMARK(ReadStream)->RangeMultiplier(4)->Range(1, 1 << 13)->UseRealTime();
+BENCHMARK(DecodeStream)->RangeMultiplier(4)->Range(1, 1 << 13)->UseRealTime();
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/read_write_test.cc b/src/arrow/cpp/src/arrow/ipc/read_write_test.cc
new file mode 100644
index 000000000..70edab1f6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/read_write_test.cc
@@ -0,0 +1,2415 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <unordered_set>
+
+#include <flatbuffers/flatbuffers.h>
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/io/file.h"
+#include "arrow/io/memory.h"
+#include "arrow/io/test_common.h"
+#include "arrow/ipc/message.h"
+#include "arrow/ipc/metadata_internal.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/test_common.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/testing/extension_type.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/key_value_metadata.h"
+
+#include "generated/Message_generated.h" // IWYU pragma: keep
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::GetByteWidth;
+using internal::TemporaryDir;
+
+namespace ipc {
+
+using internal::FieldPosition;
+
+namespace test {
+
+const std::vector<MetadataVersion> kMetadataVersions = {MetadataVersion::V4,
+ MetadataVersion::V5};
+
+class TestMessage : public ::testing::TestWithParam<MetadataVersion> {
+ public:
+ void SetUp() {
+ version_ = GetParam();
+ fb_version_ = internal::MetadataVersionToFlatbuffer(version_);
+ options_ = IpcWriteOptions::Defaults();
+ options_.metadata_version = version_;
+ }
+
+ protected:
+ MetadataVersion version_;
+ flatbuf::MetadataVersion fb_version_;
+ IpcWriteOptions options_;
+};
+
+TEST(TestMessage, Equals) {
+ std::string metadata = "foo";
+ std::string body = "bar";
+
+ auto b1 = std::make_shared<Buffer>(metadata);
+ auto b2 = std::make_shared<Buffer>(metadata);
+ auto b3 = std::make_shared<Buffer>(body);
+ auto b4 = std::make_shared<Buffer>(body);
+
+ Message msg1(b1, b3);
+ Message msg2(b2, b4);
+ Message msg3(b1, nullptr);
+ Message msg4(b2, nullptr);
+
+ ASSERT_TRUE(msg1.Equals(msg2));
+ ASSERT_TRUE(msg3.Equals(msg4));
+
+ ASSERT_FALSE(msg1.Equals(msg3));
+ ASSERT_FALSE(msg3.Equals(msg1));
+
+ // same metadata as msg1, different body
+ Message msg5(b2, b1);
+ ASSERT_FALSE(msg1.Equals(msg5));
+ ASSERT_FALSE(msg5.Equals(msg1));
+}
+
+TEST_P(TestMessage, SerializeTo) {
+ const int64_t body_length = 64;
+
+ flatbuffers::FlatBufferBuilder fbb;
+ fbb.Finish(flatbuf::CreateMessage(fbb, fb_version_, flatbuf::MessageHeader::RecordBatch,
+ 0 /* header */, body_length));
+
+ std::shared_ptr<Buffer> metadata;
+ ASSERT_OK_AND_ASSIGN(metadata, internal::WriteFlatbufferBuilder(fbb));
+
+ std::string body = "abcdef";
+
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr<Message> message,
+ Message::Open(metadata, std::make_shared<Buffer>(body)));
+
+ auto CheckWithAlignment = [&](int32_t alignment) {
+ options_.alignment = alignment;
+ const int32_t prefix_size = 8;
+ int64_t output_length = 0;
+ ASSERT_OK_AND_ASSIGN(auto stream, io::BufferOutputStream::Create(1 << 10));
+ ASSERT_OK(message->SerializeTo(stream.get(), options_, &output_length));
+ ASSERT_EQ(BitUtil::RoundUp(metadata->size() + prefix_size, alignment) + body_length,
+ output_length);
+ ASSERT_OK_AND_EQ(output_length, stream->Tell());
+ ASSERT_OK_AND_ASSIGN(auto buffer, stream->Finish());
+ // chech whether length is written in little endian
+ auto buffer_ptr = buffer.get()->data();
+ ASSERT_EQ(output_length - body_length - prefix_size,
+ BitUtil::FromLittleEndian(*(uint32_t*)(buffer_ptr + 4)));
+ };
+
+ CheckWithAlignment(8);
+ CheckWithAlignment(64);
+}
+
+TEST_P(TestMessage, SerializeCustomMetadata) {
+ std::vector<std::shared_ptr<KeyValueMetadata>> cases = {
+ nullptr, key_value_metadata({}, {}),
+ key_value_metadata({"foo", "bar"}, {"fizz", "buzz"})};
+ for (auto metadata : cases) {
+ std::shared_ptr<Buffer> serialized;
+ ASSERT_OK(internal::WriteRecordBatchMessage(
+ /*length=*/0, /*body_length=*/0, metadata,
+ /*nodes=*/{},
+ /*buffers=*/{}, options_, &serialized));
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr<Message> message,
+ Message::Open(serialized, /*body=*/nullptr));
+
+ if (metadata) {
+ ASSERT_TRUE(message->custom_metadata()->Equals(*metadata));
+ } else {
+ ASSERT_EQ(nullptr, message->custom_metadata());
+ }
+ }
+}
+
+void BuffersOverlapEquals(const Buffer& left, const Buffer& right) {
+ ASSERT_GT(left.size(), 0);
+ ASSERT_GT(right.size(), 0);
+ ASSERT_TRUE(left.Equals(right, std::min(left.size(), right.size())));
+}
+
+TEST_P(TestMessage, LegacyIpcBackwardsCompatibility) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(MakeIntBatchSized(36, &batch));
+
+ auto RoundtripWithOptions = [&](std::shared_ptr<Buffer>* out_serialized,
+ std::unique_ptr<Message>* out) {
+ IpcPayload payload;
+ ASSERT_OK(GetRecordBatchPayload(*batch, options_, &payload));
+
+ ASSERT_OK_AND_ASSIGN(auto stream, io::BufferOutputStream::Create(1 << 20));
+
+ int32_t metadata_length = -1;
+ ASSERT_OK(WriteIpcPayload(payload, options_, stream.get(), &metadata_length));
+
+ ASSERT_OK_AND_ASSIGN(*out_serialized, stream->Finish());
+ io::BufferReader io_reader(*out_serialized);
+ ASSERT_OK(ReadMessage(&io_reader).Value(out));
+ };
+
+ std::shared_ptr<Buffer> serialized, legacy_serialized;
+ std::unique_ptr<Message> message, legacy_message;
+
+ RoundtripWithOptions(&serialized, &message);
+
+ // First 4 bytes 0xFFFFFFFF Continuation marker
+ ASSERT_EQ(-1, util::SafeLoadAs<int32_t>(serialized->data()));
+
+ options_.write_legacy_ipc_format = true;
+ RoundtripWithOptions(&legacy_serialized, &legacy_message);
+
+ // Check that the continuation marker is not written
+ ASSERT_NE(-1, util::SafeLoadAs<int32_t>(legacy_serialized->data()));
+
+ // Have to use the smaller size to exclude padding
+ BuffersOverlapEquals(*legacy_message->metadata(), *message->metadata());
+ ASSERT_TRUE(legacy_message->body()->Equals(*message->body()));
+}
+
+TEST(TestMessage, Verify) {
+ std::string metadata = "invalid";
+ std::string body = "abcdef";
+
+ Message message(std::make_shared<Buffer>(metadata), std::make_shared<Buffer>(body));
+ ASSERT_FALSE(message.Verify());
+}
+
+INSTANTIATE_TEST_SUITE_P(TestMessage, TestMessage,
+ ::testing::ValuesIn(kMetadataVersions));
+
+class TestSchemaMetadata : public ::testing::Test {
+ public:
+ void SetUp() {}
+
+ void CheckSchemaRoundtrip(const Schema& schema) {
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> buffer, SerializeSchema(schema));
+
+ io::BufferReader reader(buffer);
+ DictionaryMemo in_memo;
+ ASSERT_OK_AND_ASSIGN(auto actual_schema, ReadSchema(&reader, &in_memo));
+ AssertSchemaEqual(schema, *actual_schema);
+ }
+};
+
+const std::shared_ptr<DataType> INT32 = std::make_shared<Int32Type>();
+
+TEST_F(TestSchemaMetadata, PrimitiveFields) {
+ auto f0 = field("f0", std::make_shared<Int8Type>());
+ auto f1 = field("f1", std::make_shared<Int16Type>(), false);
+ auto f2 = field("f2", std::make_shared<Int32Type>());
+ auto f3 = field("f3", std::make_shared<Int64Type>());
+ auto f4 = field("f4", std::make_shared<UInt8Type>());
+ auto f5 = field("f5", std::make_shared<UInt16Type>());
+ auto f6 = field("f6", std::make_shared<UInt32Type>());
+ auto f7 = field("f7", std::make_shared<UInt64Type>());
+ auto f8 = field("f8", std::make_shared<FloatType>());
+ auto f9 = field("f9", std::make_shared<DoubleType>(), false);
+ auto f10 = field("f10", std::make_shared<BooleanType>());
+
+ Schema schema({f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10});
+ CheckSchemaRoundtrip(schema);
+}
+
+TEST_F(TestSchemaMetadata, NestedFields) {
+ auto type = list(int32());
+ auto f0 = field("f0", type);
+
+ std::shared_ptr<StructType> type2(
+ new StructType({field("k1", INT32), field("k2", INT32), field("k3", INT32)}));
+ auto f1 = field("f1", type2);
+
+ Schema schema({f0, f1});
+ CheckSchemaRoundtrip(schema);
+}
+
+TEST_F(TestSchemaMetadata, DictionaryFields) {
+ {
+ auto dict_type = dictionary(int8(), int32(), true /* ordered */);
+ auto f0 = field("f0", dict_type);
+ auto f1 = field("f1", list(dict_type));
+
+ Schema schema({f0, f1});
+ CheckSchemaRoundtrip(schema);
+ }
+ {
+ auto dict_type = dictionary(int8(), list(int32()));
+ auto f0 = field("f0", dict_type);
+
+ Schema schema({f0});
+ CheckSchemaRoundtrip(schema);
+ }
+}
+
+TEST_F(TestSchemaMetadata, NestedDictionaryFields) {
+ {
+ auto inner_dict_type = dictionary(int8(), int32(), /*ordered=*/true);
+ auto dict_type = dictionary(int16(), list(inner_dict_type));
+
+ Schema schema({field("f0", dict_type)});
+ CheckSchemaRoundtrip(schema);
+ }
+ {
+ auto dict_type1 = dictionary(int8(), utf8(), /*ordered=*/true);
+ auto dict_type2 = dictionary(int32(), fixed_size_binary(24));
+ auto dict_type3 = dictionary(int32(), binary());
+ auto dict_type4 = dictionary(int8(), decimal(19, 7));
+
+ auto struct_type1 = struct_({field("s1", dict_type1), field("s2", dict_type2)});
+ auto struct_type2 = struct_({field("s3", dict_type3), field("s4", dict_type4)});
+
+ Schema schema({field("f1", dictionary(int32(), struct_type1)),
+ field("f2", dictionary(int32(), struct_type2))});
+ CheckSchemaRoundtrip(schema);
+ }
+}
+
+TEST_F(TestSchemaMetadata, KeyValueMetadata) {
+ auto field_metadata = key_value_metadata({{"key", "value"}});
+ auto schema_metadata = key_value_metadata({{"foo", "bar"}, {"bizz", "buzz"}});
+
+ auto f0 = field("f0", std::make_shared<Int8Type>());
+ auto f1 = field("f1", std::make_shared<Int16Type>(), false, field_metadata);
+
+ Schema schema({f0, f1}, schema_metadata);
+ CheckSchemaRoundtrip(schema);
+}
+
+TEST_F(TestSchemaMetadata, MetadataVersionForwardCompatibility) {
+ // ARROW-9399
+ std::string root;
+ ASSERT_OK(GetTestResourceRoot(&root));
+
+ // schema_v6.arrow with currently non-existent MetadataVersion::V6
+ std::stringstream schema_v6_path;
+ schema_v6_path << root << "/forward-compatibility/schema_v6.arrow";
+
+ ASSERT_OK_AND_ASSIGN(auto schema_v6_file, io::ReadableFile::Open(schema_v6_path.str()));
+
+ DictionaryMemo placeholder_memo;
+ ASSERT_RAISES(Invalid, ReadSchema(schema_v6_file.get(), &placeholder_memo));
+}
+
+const std::vector<test::MakeRecordBatch*> kBatchCases = {
+ &MakeIntRecordBatch,
+ &MakeListRecordBatch,
+ &MakeFixedSizeListRecordBatch,
+ &MakeNonNullRecordBatch,
+ &MakeZeroLengthRecordBatch,
+ &MakeDeeplyNestedList,
+ &MakeStringTypesRecordBatchWithNulls,
+ &MakeStruct,
+ &MakeUnion,
+ &MakeDictionary,
+ &MakeNestedDictionary,
+ &MakeMap,
+ &MakeMapOfDictionary,
+ &MakeDates,
+ &MakeTimestamps,
+ &MakeTimes,
+ &MakeFWBinary,
+ &MakeNull,
+ &MakeDecimal,
+ &MakeBooleanBatch,
+ &MakeFloatBatch,
+ &MakeIntervals,
+ &MakeUuid,
+ &MakeComplex128,
+ &MakeDictExtension};
+
+static int g_file_number = 0;
+
+class ExtensionTypesMixin {
+ public:
+ // Register the extension types required to ensure roundtripping
+ ExtensionTypesMixin() : ext_guard_({uuid(), dict_extension_type(), complex128()}) {}
+
+ protected:
+ ExtensionTypeGuard ext_guard_;
+};
+
+class IpcTestFixture : public io::MemoryMapFixture, public ExtensionTypesMixin {
+ public:
+ void SetUp() {
+ options_ = IpcWriteOptions::Defaults();
+ ASSERT_OK_AND_ASSIGN(temp_dir_, TemporaryDir::Make("ipc-test-"));
+ }
+
+ std::string TempFile(util::string_view file) {
+ return temp_dir_->path().Join(std::string(file)).ValueOrDie().ToString();
+ }
+
+ void DoSchemaRoundTrip(const Schema& schema, std::shared_ptr<Schema>* result) {
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> serialized_schema,
+ SerializeSchema(schema, options_.memory_pool));
+
+ DictionaryMemo in_memo;
+ io::BufferReader buf_reader(serialized_schema);
+ ASSERT_OK_AND_ASSIGN(*result, ReadSchema(&buf_reader, &in_memo));
+ }
+
+ Result<std::shared_ptr<RecordBatch>> DoStandardRoundTrip(
+ const RecordBatch& batch, const IpcWriteOptions& options,
+ DictionaryMemo* dictionary_memo,
+ const IpcReadOptions& read_options = IpcReadOptions::Defaults()) {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> serialized_batch,
+ SerializeRecordBatch(batch, options));
+
+ io::BufferReader buf_reader(serialized_batch);
+ return ReadRecordBatch(batch.schema(), dictionary_memo, read_options, &buf_reader);
+ }
+
+ Result<std::shared_ptr<RecordBatch>> DoLargeRoundTrip(const RecordBatch& batch,
+ bool zero_data) {
+ if (zero_data) {
+ RETURN_NOT_OK(ZeroMemoryMap(mmap_.get()));
+ }
+ RETURN_NOT_OK(mmap_->Seek(0));
+
+ auto options = options_;
+ options.allow_64bit = true;
+
+ ARROW_ASSIGN_OR_RAISE(auto file_writer,
+ MakeFileWriter(mmap_, batch.schema(), options));
+ RETURN_NOT_OK(file_writer->WriteRecordBatch(batch));
+ RETURN_NOT_OK(file_writer->Close());
+
+ ARROW_ASSIGN_OR_RAISE(int64_t offset, mmap_->Tell());
+
+ std::shared_ptr<RecordBatchFileReader> file_reader;
+ ARROW_ASSIGN_OR_RAISE(file_reader, RecordBatchFileReader::Open(mmap_.get(), offset));
+
+ return file_reader->ReadRecordBatch(0);
+ }
+
+ void CheckReadResult(const RecordBatch& result, const RecordBatch& expected) {
+ ASSERT_OK(result.ValidateFull());
+ EXPECT_EQ(expected.num_rows(), result.num_rows());
+
+ ASSERT_TRUE(expected.schema()->Equals(*result.schema()));
+ ASSERT_EQ(expected.num_columns(), result.num_columns())
+ << expected.schema()->ToString() << " result: " << result.schema()->ToString();
+
+ CompareBatchColumnsDetailed(result, expected);
+ }
+
+ void CheckRoundtrip(const RecordBatch& batch,
+ IpcWriteOptions options = IpcWriteOptions::Defaults(),
+ IpcReadOptions read_options = IpcReadOptions::Defaults(),
+ int64_t buffer_size = 1 << 20) {
+ std::stringstream ss;
+ ss << "test-write-row-batch-" << g_file_number++;
+ ASSERT_OK_AND_ASSIGN(
+ mmap_, io::MemoryMapFixture::InitMemoryMap(buffer_size, TempFile(ss.str())));
+
+ std::shared_ptr<Schema> schema_result;
+ DoSchemaRoundTrip(*batch.schema(), &schema_result);
+ ASSERT_TRUE(batch.schema()->Equals(*schema_result));
+
+ DictionaryMemo dictionary_memo;
+ ASSERT_OK(::arrow::ipc::internal::CollectDictionaries(batch, &dictionary_memo));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto result, DoStandardRoundTrip(batch, options, &dictionary_memo, read_options));
+ CheckReadResult(*result, batch);
+
+ ASSERT_OK_AND_ASSIGN(result, DoLargeRoundTrip(batch, /*zero_data=*/true));
+ CheckReadResult(*result, batch);
+ }
+
+ void CheckRoundtrip(const std::shared_ptr<Array>& array,
+ IpcWriteOptions options = IpcWriteOptions::Defaults(),
+ int64_t buffer_size = 1 << 20) {
+ auto f0 = arrow::field("f0", array->type());
+ std::vector<std::shared_ptr<Field>> fields = {f0};
+ auto schema = std::make_shared<Schema>(fields);
+
+ auto batch = RecordBatch::Make(schema, 0, {array});
+ CheckRoundtrip(*batch, options, IpcReadOptions::Defaults(), buffer_size);
+ }
+
+ protected:
+ std::shared_ptr<io::MemoryMappedFile> mmap_;
+ IpcWriteOptions options_;
+ std::unique_ptr<TemporaryDir> temp_dir_;
+};
+
+TEST(MetadataVersion, ForwardsCompatCheck) {
+ // Verify UBSAN is ok with casting out of range metdata version.
+ EXPECT_LT(flatbuf::MetadataVersion::MAX, static_cast<flatbuf::MetadataVersion>(72));
+}
+
+class TestWriteRecordBatch : public ::testing::Test, public IpcTestFixture {
+ public:
+ void SetUp() { IpcTestFixture::SetUp(); }
+ void TearDown() { IpcTestFixture::TearDown(); }
+};
+
+class TestIpcRoundTrip : public ::testing::TestWithParam<MakeRecordBatch*>,
+ public IpcTestFixture {
+ public:
+ void SetUp() { IpcTestFixture::SetUp(); }
+ void TearDown() { IpcTestFixture::TearDown(); }
+
+ void TestMetadataVersion(MetadataVersion expected_version) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(MakeIntRecordBatch(&batch));
+
+ mmap_.reset(); // Ditch previous mmap view, to avoid errors on Windows
+ ASSERT_OK_AND_ASSIGN(mmap_,
+ io::MemoryMapFixture::InitMemoryMap(1 << 16, "test-metadata"));
+
+ int32_t metadata_length;
+ int64_t body_length;
+ const int64_t buffer_offset = 0;
+ ASSERT_OK(WriteRecordBatch(*batch, buffer_offset, mmap_.get(), &metadata_length,
+ &body_length, options_));
+
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr<Message> message,
+ ReadMessage(0, metadata_length, mmap_.get()));
+ ASSERT_EQ(expected_version, message->metadata_version());
+ }
+};
+
+TEST_P(TestIpcRoundTrip, RoundTrip) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK((*GetParam())(&batch)); // NOLINT clang-tidy gtest issue
+
+ for (const auto version : kMetadataVersions) {
+ options_.metadata_version = version;
+ CheckRoundtrip(*batch);
+ }
+}
+
+TEST_F(TestIpcRoundTrip, DefaultMetadataVersion) {
+ TestMetadataVersion(MetadataVersion::V5);
+}
+
+TEST_F(TestIpcRoundTrip, SpecificMetadataVersion) {
+ options_.metadata_version = MetadataVersion::V4;
+ TestMetadataVersion(MetadataVersion::V4);
+ options_.metadata_version = MetadataVersion::V5;
+ TestMetadataVersion(MetadataVersion::V5);
+}
+
+TEST(TestReadMessage, CorruptedSmallInput) {
+ std::string data = "abc";
+ io::BufferReader reader(data);
+ ASSERT_RAISES(Invalid, ReadMessage(&reader));
+
+ // But no error on unsignaled EOS
+ io::BufferReader reader2("");
+ ASSERT_OK_AND_ASSIGN(auto message, ReadMessage(&reader2));
+ ASSERT_EQ(nullptr, message);
+}
+
+TEST(TestMetadata, GetMetadataVersion) {
+ ASSERT_EQ(MetadataVersion::V1,
+ ipc::internal::GetMetadataVersion(flatbuf::MetadataVersion::V1));
+ ASSERT_EQ(MetadataVersion::V2,
+ ipc::internal::GetMetadataVersion(flatbuf::MetadataVersion::V2));
+ ASSERT_EQ(MetadataVersion::V3,
+ ipc::internal::GetMetadataVersion(flatbuf::MetadataVersion::V3));
+ ASSERT_EQ(MetadataVersion::V4,
+ ipc::internal::GetMetadataVersion(flatbuf::MetadataVersion::V4));
+ ASSERT_EQ(MetadataVersion::V5,
+ ipc::internal::GetMetadataVersion(flatbuf::MetadataVersion::V5));
+ ASSERT_EQ(MetadataVersion::V1,
+ ipc::internal::GetMetadataVersion(flatbuf::MetadataVersion::MIN));
+ ASSERT_EQ(MetadataVersion::V5,
+ ipc::internal::GetMetadataVersion(flatbuf::MetadataVersion::MAX));
+}
+
+TEST_P(TestIpcRoundTrip, SliceRoundTrip) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK((*GetParam())(&batch)); // NOLINT clang-tidy gtest issue
+
+ // Skip the zero-length case
+ if (batch->num_rows() < 2) {
+ return;
+ }
+
+ auto sliced_batch = batch->Slice(2, 10);
+ CheckRoundtrip(*sliced_batch);
+}
+
+TEST_P(TestIpcRoundTrip, ZeroLengthArrays) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK((*GetParam())(&batch)); // NOLINT clang-tidy gtest issue
+
+ std::shared_ptr<RecordBatch> zero_length_batch;
+ if (batch->num_rows() > 2) {
+ zero_length_batch = batch->Slice(2, 0);
+ } else {
+ zero_length_batch = batch->Slice(0, 0);
+ }
+
+ CheckRoundtrip(*zero_length_batch);
+
+ // ARROW-544: check binary array
+ ASSERT_OK_AND_ASSIGN(auto value_offsets,
+ AllocateBuffer(sizeof(int32_t), options_.memory_pool));
+ *reinterpret_cast<int32_t*>(value_offsets->mutable_data()) = 0;
+
+ std::shared_ptr<Array> bin_array = std::make_shared<BinaryArray>(
+ 0, std::move(value_offsets), std::make_shared<Buffer>(nullptr, 0),
+ std::make_shared<Buffer>(nullptr, 0));
+
+ // null value_offsets
+ std::shared_ptr<Array> bin_array2 = std::make_shared<BinaryArray>(0, nullptr, nullptr);
+
+ CheckRoundtrip(bin_array);
+ CheckRoundtrip(bin_array2);
+}
+
+TEST_F(TestWriteRecordBatch, WriteWithCompression) {
+ random::RandomArrayGenerator rg(/*seed=*/0);
+
+ // Generate both regular and dictionary encoded because the dictionary batch
+ // gets compressed also
+
+ int64_t length = 500;
+
+ int dict_size = 50;
+ std::shared_ptr<Array> dict = rg.String(dict_size, /*min_length=*/5, /*max_length=*/5,
+ /*null_probability=*/0);
+ std::shared_ptr<Array> indices = rg.Int32(length, /*min=*/0, /*max=*/dict_size - 1,
+ /*null_probability=*/0.1);
+
+ auto dict_type = dictionary(int32(), utf8());
+ auto dict_field = field("f1", dict_type);
+ ASSERT_OK_AND_ASSIGN(auto dict_array,
+ DictionaryArray::FromArrays(dict_type, indices, dict));
+
+ auto schema = ::arrow::schema({field("f0", utf8()), dict_field});
+ auto batch =
+ RecordBatch::Make(schema, length, {rg.String(500, 0, 10, 0.1), dict_array});
+
+ std::vector<Compression::type> codecs = {Compression::LZ4_FRAME, Compression::ZSTD};
+ for (auto codec : codecs) {
+ if (!util::Codec::IsAvailable(codec)) {
+ continue;
+ }
+ IpcWriteOptions write_options = IpcWriteOptions::Defaults();
+ ASSERT_OK_AND_ASSIGN(write_options.codec, util::Codec::Create(codec));
+ CheckRoundtrip(*batch, write_options);
+
+ // Check non-parallel read and write
+ IpcReadOptions read_options = IpcReadOptions::Defaults();
+ write_options.use_threads = false;
+ read_options.use_threads = false;
+ CheckRoundtrip(*batch, write_options, read_options);
+ }
+
+ std::vector<Compression::type> disallowed_codecs = {
+ Compression::BROTLI, Compression::BZ2, Compression::LZ4, Compression::GZIP,
+ Compression::SNAPPY};
+ for (auto codec : disallowed_codecs) {
+ if (!util::Codec::IsAvailable(codec)) {
+ continue;
+ }
+ IpcWriteOptions write_options = IpcWriteOptions::Defaults();
+ ASSERT_OK_AND_ASSIGN(write_options.codec, util::Codec::Create(codec));
+ ASSERT_RAISES(Invalid, SerializeRecordBatch(*batch, write_options));
+ }
+}
+
+TEST_F(TestWriteRecordBatch, SliceTruncatesBinaryOffsets) {
+ // ARROW-6046
+ std::shared_ptr<Array> array;
+ ASSERT_OK(MakeRandomStringArray(500, false, default_memory_pool(), &array));
+
+ auto f0 = field("f0", array->type());
+ auto schema = ::arrow::schema({f0});
+ auto batch = RecordBatch::Make(schema, array->length(), {array});
+ auto sliced_batch = batch->Slice(0, 5);
+
+ ASSERT_OK_AND_ASSIGN(
+ mmap_, io::MemoryMapFixture::InitMemoryMap(/*buffer_size=*/1 << 20,
+ TempFile("test-truncate-offsets")));
+ DictionaryMemo dictionary_memo;
+ ASSERT_OK_AND_ASSIGN(
+ auto result,
+ DoStandardRoundTrip(*sliced_batch, IpcWriteOptions::Defaults(), &dictionary_memo));
+ ASSERT_EQ(6 * sizeof(int32_t), result->column(0)->data()->buffers[1]->size());
+}
+
+TEST_F(TestWriteRecordBatch, SliceTruncatesBuffers) {
+ auto CheckArray = [this](const std::shared_ptr<Array>& array) {
+ auto f0 = field("f0", array->type());
+ auto schema = ::arrow::schema({f0});
+ auto batch = RecordBatch::Make(schema, array->length(), {array});
+ auto sliced_batch = batch->Slice(0, 5);
+
+ int64_t full_size;
+ int64_t sliced_size;
+
+ ASSERT_OK(GetRecordBatchSize(*batch, &full_size));
+ ASSERT_OK(GetRecordBatchSize(*sliced_batch, &sliced_size));
+ ASSERT_TRUE(sliced_size < full_size) << sliced_size << " " << full_size;
+
+ // make sure we can write and read it
+ this->CheckRoundtrip(*sliced_batch);
+ };
+
+ std::shared_ptr<Array> a0, a1;
+ auto pool = default_memory_pool();
+
+ // Integer
+ ASSERT_OK(MakeRandomInt32Array(500, false, pool, &a0));
+ CheckArray(a0);
+
+ // String / Binary
+ {
+ auto s = MakeRandomStringArray(500, false, pool, &a0);
+ ASSERT_TRUE(s.ok());
+ }
+ CheckArray(a0);
+
+ // Boolean
+ ASSERT_OK(MakeRandomBooleanArray(10000, false, &a0));
+ CheckArray(a0);
+
+ // List
+ ASSERT_OK(MakeRandomInt32Array(500, false, pool, &a0));
+ ASSERT_OK(MakeRandomListArray(a0, 200, false, pool, &a1));
+ CheckArray(a1);
+
+ // Struct
+ auto struct_type = struct_({field("f0", a0->type())});
+ std::vector<std::shared_ptr<Array>> struct_children = {a0};
+ a1 = std::make_shared<StructArray>(struct_type, a0->length(), struct_children);
+ CheckArray(a1);
+
+ // Sparse Union
+ auto union_type = sparse_union({field("f0", a0->type())}, {0});
+ std::vector<int32_t> type_ids(a0->length());
+ std::shared_ptr<Buffer> ids_buffer;
+ ASSERT_OK(CopyBufferFromVector(type_ids, default_memory_pool(), &ids_buffer));
+ a1 = std::make_shared<SparseUnionArray>(union_type, a0->length(), struct_children,
+ ids_buffer);
+ CheckArray(a1);
+
+ // Dense union
+ auto dense_union_type = dense_union({field("f0", a0->type())}, {0});
+ std::vector<int32_t> type_offsets;
+ for (int32_t i = 0; i < a0->length(); ++i) {
+ type_offsets.push_back(i);
+ }
+ std::shared_ptr<Buffer> offsets_buffer;
+ ASSERT_OK(CopyBufferFromVector(type_offsets, default_memory_pool(), &offsets_buffer));
+ a1 = std::make_shared<DenseUnionArray>(dense_union_type, a0->length(), struct_children,
+ ids_buffer, offsets_buffer);
+ CheckArray(a1);
+}
+
+TEST_F(TestWriteRecordBatch, RoundtripPreservesBufferSizes) {
+ // ARROW-7975
+ random::RandomArrayGenerator rg(/*seed=*/0);
+
+ int64_t length = 15;
+ auto arr = rg.String(length, 0, 10, 0.1);
+ auto batch = RecordBatch::Make(::arrow::schema({field("f0", utf8())}), length, {arr});
+
+ ASSERT_OK_AND_ASSIGN(
+ mmap_, io::MemoryMapFixture::InitMemoryMap(
+ /*buffer_size=*/1 << 20, TempFile("test-roundtrip-buffer-sizes")));
+ DictionaryMemo dictionary_memo;
+ ASSERT_OK_AND_ASSIGN(
+ auto result,
+ DoStandardRoundTrip(*batch, IpcWriteOptions::Defaults(), &dictionary_memo));
+
+ // Make sure that the validity bitmap is size 2 as expected
+ ASSERT_EQ(2, arr->data()->buffers[0]->size());
+
+ for (size_t i = 0; i < arr->data()->buffers.size(); ++i) {
+ ASSERT_EQ(arr->data()->buffers[i]->size(),
+ result->column(0)->data()->buffers[i]->size());
+ }
+}
+
+void TestGetRecordBatchSize(const IpcWriteOptions& options,
+ std::shared_ptr<RecordBatch> batch) {
+ io::MockOutputStream mock;
+ ipc::IpcPayload payload;
+ int32_t mock_metadata_length = -1;
+ int64_t mock_body_length = -1;
+ int64_t size = -1;
+ ASSERT_OK(WriteRecordBatch(*batch, 0, &mock, &mock_metadata_length, &mock_body_length,
+ options));
+ ASSERT_OK(GetRecordBatchPayload(*batch, options, &payload));
+ int64_t payload_size = GetPayloadSize(payload, options);
+ ASSERT_OK(GetRecordBatchSize(*batch, options, &size));
+ ASSERT_EQ(mock.GetExtentBytesWritten(), size);
+ ASSERT_EQ(mock.GetExtentBytesWritten(), payload_size);
+}
+
+TEST_F(TestWriteRecordBatch, IntegerGetRecordBatchSize) {
+ std::shared_ptr<RecordBatch> batch;
+
+ ASSERT_OK(MakeIntRecordBatch(&batch));
+ TestGetRecordBatchSize(options_, batch);
+
+ ASSERT_OK(MakeListRecordBatch(&batch));
+ TestGetRecordBatchSize(options_, batch);
+
+ ASSERT_OK(MakeZeroLengthRecordBatch(&batch));
+ TestGetRecordBatchSize(options_, batch);
+
+ ASSERT_OK(MakeNonNullRecordBatch(&batch));
+ TestGetRecordBatchSize(options_, batch);
+
+ ASSERT_OK(MakeDeeplyNestedList(&batch));
+ TestGetRecordBatchSize(options_, batch);
+}
+
+class RecursionLimits : public ::testing::Test, public io::MemoryMapFixture {
+ public:
+ void SetUp() {
+ pool_ = default_memory_pool();
+ ASSERT_OK_AND_ASSIGN(temp_dir_, TemporaryDir::Make("ipc-recursion-limits-test-"));
+ }
+
+ std::string TempFile(util::string_view file) {
+ return temp_dir_->path().Join(std::string(file)).ValueOrDie().ToString();
+ }
+
+ void TearDown() { io::MemoryMapFixture::TearDown(); }
+
+ Status WriteToMmap(int recursion_level, bool override_level, int32_t* metadata_length,
+ int64_t* body_length, std::shared_ptr<RecordBatch>* batch,
+ std::shared_ptr<Schema>* schema) {
+ const int batch_length = 5;
+ auto type = int32();
+ std::shared_ptr<Array> array;
+ const bool include_nulls = true;
+ RETURN_NOT_OK(MakeRandomInt32Array(1000, include_nulls, pool_, &array));
+ for (int i = 0; i < recursion_level; ++i) {
+ type = list(type);
+ RETURN_NOT_OK(
+ MakeRandomListArray(array, batch_length, include_nulls, pool_, &array));
+ }
+
+ auto f0 = field("f0", type);
+
+ *schema = ::arrow::schema({f0});
+
+ *batch = RecordBatch::Make(*schema, batch_length, {array});
+
+ std::stringstream ss;
+ ss << "test-write-past-max-recursion-" << g_file_number++;
+ const int memory_map_size = 1 << 20;
+ ARROW_ASSIGN_OR_RAISE(
+ mmap_, io::MemoryMapFixture::InitMemoryMap(memory_map_size, TempFile(ss.str())));
+
+ auto options = IpcWriteOptions::Defaults();
+ if (override_level) {
+ options.max_recursion_depth = recursion_level + 1;
+ }
+ return WriteRecordBatch(**batch, 0, mmap_.get(), metadata_length, body_length,
+ options);
+ }
+
+ protected:
+ std::shared_ptr<io::MemoryMappedFile> mmap_;
+ std::unique_ptr<TemporaryDir> temp_dir_;
+ MemoryPool* pool_;
+};
+
+TEST_F(RecursionLimits, WriteLimit) {
+ int32_t metadata_length = -1;
+ int64_t body_length = -1;
+ std::shared_ptr<Schema> schema;
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_RAISES(Invalid, WriteToMmap((1 << 8) + 1, false, &metadata_length, &body_length,
+ &batch, &schema));
+}
+
+TEST_F(RecursionLimits, ReadLimit) {
+ int32_t metadata_length = -1;
+ int64_t body_length = -1;
+ std::shared_ptr<Schema> schema;
+
+ const int recursion_depth = 64;
+
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(WriteToMmap(recursion_depth, true, &metadata_length, &body_length, &batch,
+ &schema));
+
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr<Message> message,
+ ReadMessage(0, metadata_length, mmap_.get()));
+
+ io::BufferReader reader(message->body());
+
+ DictionaryMemo empty_memo;
+ ASSERT_RAISES(Invalid, ReadRecordBatch(*message->metadata(), schema, &empty_memo,
+ IpcReadOptions::Defaults(), &reader));
+}
+
+// Test fails with a structured exception on Windows + Debug
+#if !defined(_WIN32) || defined(NDEBUG)
+TEST_F(RecursionLimits, StressLimit) {
+ auto CheckDepth = [this](int recursion_depth, bool* it_works) {
+ int32_t metadata_length = -1;
+ int64_t body_length = -1;
+ std::shared_ptr<Schema> schema;
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(WriteToMmap(recursion_depth, true, &metadata_length, &body_length, &batch,
+ &schema));
+
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr<Message> message,
+ ReadMessage(0, metadata_length, mmap_.get()));
+
+ DictionaryMemo empty_memo;
+
+ auto options = IpcReadOptions::Defaults();
+ options.max_recursion_depth = recursion_depth + 1;
+ io::BufferReader reader(message->body());
+ std::shared_ptr<RecordBatch> result;
+ ASSERT_OK_AND_ASSIGN(result, ReadRecordBatch(*message->metadata(), schema,
+ &empty_memo, options, &reader));
+ *it_works = result->Equals(*batch);
+ };
+
+ bool it_works = false;
+ CheckDepth(100, &it_works);
+ ASSERT_TRUE(it_works);
+
+// Mitigate Valgrind's slowness
+#if !defined(ARROW_VALGRIND)
+ CheckDepth(500, &it_works);
+ ASSERT_TRUE(it_works);
+#endif
+}
+#endif // !defined(_WIN32) || defined(NDEBUG)
+
+struct FileWriterHelper {
+ static constexpr bool kIsFileFormat = true;
+
+ Status Init(const std::shared_ptr<Schema>& schema, const IpcWriteOptions& options,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = nullptr) {
+ num_batches_written_ = 0;
+
+ ARROW_ASSIGN_OR_RAISE(buffer_, AllocateResizableBuffer(0));
+ sink_.reset(new io::BufferOutputStream(buffer_));
+ ARROW_ASSIGN_OR_RAISE(writer_,
+ MakeFileWriter(sink_.get(), schema, options, metadata));
+ return Status::OK();
+ }
+
+ Status WriteBatch(const std::shared_ptr<RecordBatch>& batch) {
+ RETURN_NOT_OK(writer_->WriteRecordBatch(*batch));
+ num_batches_written_++;
+ return Status::OK();
+ }
+
+ Status WriteTable(const RecordBatchVector& batches) {
+ num_batches_written_ += static_cast<int>(batches.size());
+ ARROW_ASSIGN_OR_RAISE(auto table, Table::FromRecordBatches(batches));
+ return writer_->WriteTable(*table);
+ }
+
+ Status Finish(WriteStats* out_stats = nullptr) {
+ RETURN_NOT_OK(writer_->Close());
+ if (out_stats) {
+ *out_stats = writer_->stats();
+ }
+ RETURN_NOT_OK(sink_->Close());
+ // Current offset into stream is the end of the file
+ return sink_->Tell().Value(&footer_offset_);
+ }
+
+ virtual Status ReadBatches(const IpcReadOptions& options,
+ RecordBatchVector* out_batches,
+ ReadStats* out_stats = nullptr) {
+ auto buf_reader = std::make_shared<io::BufferReader>(buffer_);
+ ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchFileReader::Open(
+ buf_reader.get(), footer_offset_, options));
+
+ EXPECT_EQ(num_batches_written_, reader->num_record_batches());
+ for (int i = 0; i < num_batches_written_; ++i) {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> chunk,
+ reader->ReadRecordBatch(i));
+ out_batches->push_back(chunk);
+ }
+ if (out_stats) {
+ *out_stats = reader->stats();
+ }
+ return Status::OK();
+ }
+
+ Status ReadSchema(std::shared_ptr<Schema>* out) {
+ return ReadSchema(ipc::IpcReadOptions::Defaults(), out);
+ }
+
+ Status ReadSchema(const IpcReadOptions& read_options, std::shared_ptr<Schema>* out) {
+ auto buf_reader = std::make_shared<io::BufferReader>(buffer_);
+ ARROW_ASSIGN_OR_RAISE(
+ auto reader,
+ RecordBatchFileReader::Open(buf_reader.get(), footer_offset_, read_options));
+
+ *out = reader->schema();
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<const KeyValueMetadata>> ReadFooterMetadata() {
+ auto buf_reader = std::make_shared<io::BufferReader>(buffer_);
+ ARROW_ASSIGN_OR_RAISE(auto reader,
+ RecordBatchFileReader::Open(buf_reader.get(), footer_offset_));
+ return reader->metadata();
+ }
+
+ std::shared_ptr<ResizableBuffer> buffer_;
+ std::unique_ptr<io::BufferOutputStream> sink_;
+ std::shared_ptr<RecordBatchWriter> writer_;
+ int num_batches_written_;
+ int64_t footer_offset_;
+};
+
+struct FileGeneratorWriterHelper : public FileWriterHelper {
+ Status ReadBatches(const IpcReadOptions& options, RecordBatchVector* out_batches,
+ ReadStats* out_stats = nullptr) override {
+ auto buf_reader = std::make_shared<io::BufferReader>(buffer_);
+ AsyncGenerator<std::shared_ptr<RecordBatch>> generator;
+
+ {
+ auto fut =
+ RecordBatchFileReader::OpenAsync(buf_reader.get(), footer_offset_, options);
+ // Do NOT assert OK since some tests check whether this fails properly
+ EXPECT_FINISHES(fut);
+ ARROW_ASSIGN_OR_RAISE(auto reader, fut.result());
+ EXPECT_EQ(num_batches_written_, reader->num_record_batches());
+ // Generator will keep reader alive internally
+ ARROW_ASSIGN_OR_RAISE(generator, reader->GetRecordBatchGenerator());
+ }
+
+ // Generator is async-reentrant
+ std::vector<Future<std::shared_ptr<RecordBatch>>> futures;
+ for (int i = 0; i < num_batches_written_; ++i) {
+ futures.push_back(generator());
+ }
+ auto fut = generator();
+ EXPECT_FINISHES_OK_AND_EQ(nullptr, fut);
+ for (auto& future : futures) {
+ EXPECT_FINISHES_OK_AND_ASSIGN(auto batch, future);
+ out_batches->push_back(batch);
+ }
+
+ // The generator doesn't track stats.
+ EXPECT_EQ(nullptr, out_stats);
+
+ return Status::OK();
+ }
+};
+
+struct StreamWriterHelper {
+ static constexpr bool kIsFileFormat = false;
+
+ Status Init(const std::shared_ptr<Schema>& schema, const IpcWriteOptions& options) {
+ ARROW_ASSIGN_OR_RAISE(buffer_, AllocateResizableBuffer(0));
+ sink_.reset(new io::BufferOutputStream(buffer_));
+ ARROW_ASSIGN_OR_RAISE(writer_, MakeStreamWriter(sink_.get(), schema, options));
+ return Status::OK();
+ }
+
+ Status WriteBatch(const std::shared_ptr<RecordBatch>& batch) {
+ RETURN_NOT_OK(writer_->WriteRecordBatch(*batch));
+ return Status::OK();
+ }
+
+ Status WriteTable(const RecordBatchVector& batches) {
+ ARROW_ASSIGN_OR_RAISE(auto table, Table::FromRecordBatches(batches));
+ return writer_->WriteTable(*table);
+ }
+
+ Status Finish(WriteStats* out_stats = nullptr) {
+ RETURN_NOT_OK(writer_->Close());
+ if (out_stats) {
+ *out_stats = writer_->stats();
+ }
+ return sink_->Close();
+ }
+
+ virtual Status ReadBatches(const IpcReadOptions& options,
+ RecordBatchVector* out_batches,
+ ReadStats* out_stats = nullptr) {
+ auto buf_reader = std::make_shared<io::BufferReader>(buffer_);
+ ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchStreamReader::Open(buf_reader, options))
+ RETURN_NOT_OK(reader->ReadAll(out_batches));
+ if (out_stats) {
+ *out_stats = reader->stats();
+ }
+ return Status::OK();
+ }
+
+ Status ReadSchema(std::shared_ptr<Schema>* out) {
+ return ReadSchema(ipc::IpcReadOptions::Defaults(), out);
+ }
+
+ virtual Status ReadSchema(const IpcReadOptions& read_options,
+ std::shared_ptr<Schema>* out) {
+ auto buf_reader = std::make_shared<io::BufferReader>(buffer_);
+ ARROW_ASSIGN_OR_RAISE(auto reader,
+ RecordBatchStreamReader::Open(buf_reader.get(), read_options));
+ *out = reader->schema();
+ return Status::OK();
+ }
+
+ std::shared_ptr<ResizableBuffer> buffer_;
+ std::unique_ptr<io::BufferOutputStream> sink_;
+ std::shared_ptr<RecordBatchWriter> writer_;
+};
+
+struct StreamDecoderWriterHelper : public StreamWriterHelper {
+ Status ReadBatches(const IpcReadOptions& options, RecordBatchVector* out_batches,
+ ReadStats* out_stats = nullptr) override {
+ auto listener = std::make_shared<CollectListener>();
+ StreamDecoder decoder(listener, options);
+ RETURN_NOT_OK(DoConsume(&decoder));
+ *out_batches = listener->record_batches();
+ if (out_stats) {
+ *out_stats = decoder.stats();
+ }
+ return Status::OK();
+ }
+
+ Status ReadSchema(const IpcReadOptions& read_options,
+ std::shared_ptr<Schema>* out) override {
+ auto listener = std::make_shared<CollectListener>();
+ StreamDecoder decoder(listener, read_options);
+ RETURN_NOT_OK(DoConsume(&decoder));
+ *out = listener->schema();
+ return Status::OK();
+ }
+
+ virtual Status DoConsume(StreamDecoder* decoder) = 0;
+};
+
+struct StreamDecoderDataWriterHelper : public StreamDecoderWriterHelper {
+ Status DoConsume(StreamDecoder* decoder) override {
+ return decoder->Consume(buffer_->data(), buffer_->size());
+ }
+};
+
+struct StreamDecoderBufferWriterHelper : public StreamDecoderWriterHelper {
+ Status DoConsume(StreamDecoder* decoder) override { return decoder->Consume(buffer_); }
+};
+
+struct StreamDecoderSmallChunksWriterHelper : public StreamDecoderWriterHelper {
+ Status DoConsume(StreamDecoder* decoder) override {
+ for (int64_t offset = 0; offset < buffer_->size() - 1; ++offset) {
+ RETURN_NOT_OK(decoder->Consume(buffer_->data() + offset, 1));
+ }
+ return Status::OK();
+ }
+};
+
+struct StreamDecoderLargeChunksWriterHelper : public StreamDecoderWriterHelper {
+ Status DoConsume(StreamDecoder* decoder) override {
+ RETURN_NOT_OK(decoder->Consume(SliceBuffer(buffer_, 0, 1)));
+ RETURN_NOT_OK(decoder->Consume(SliceBuffer(buffer_, 1)));
+ return Status::OK();
+ }
+};
+
+// Parameterized mixin with tests for stream / file writer
+
+template <class WriterHelperType>
+class ReaderWriterMixin : public ExtensionTypesMixin {
+ public:
+ using WriterHelper = WriterHelperType;
+
+ // Check simple RecordBatch roundtripping
+ template <typename Param>
+ void TestRoundTrip(Param&& param, const IpcWriteOptions& options) {
+ std::shared_ptr<RecordBatch> batch1;
+ std::shared_ptr<RecordBatch> batch2;
+ ASSERT_OK(param(&batch1)); // NOLINT clang-tidy gtest issue
+ ASSERT_OK(param(&batch2)); // NOLINT clang-tidy gtest issue
+
+ RecordBatchVector in_batches = {batch1, batch2};
+ RecordBatchVector out_batches;
+
+ WriterHelper writer_helper;
+ ASSERT_OK(RoundTripHelper(writer_helper, in_batches, options,
+ IpcReadOptions::Defaults(), &out_batches));
+ ASSERT_EQ(out_batches.size(), in_batches.size());
+
+ // Compare batches
+ for (size_t i = 0; i < in_batches.size(); ++i) {
+ CompareBatch(*in_batches[i], *out_batches[i]);
+ }
+ }
+
+ template <typename Param>
+ void TestZeroLengthRoundTrip(Param&& param, const IpcWriteOptions& options) {
+ std::shared_ptr<RecordBatch> batch1;
+ std::shared_ptr<RecordBatch> batch2;
+ ASSERT_OK(param(&batch1)); // NOLINT clang-tidy gtest issue
+ ASSERT_OK(param(&batch2)); // NOLINT clang-tidy gtest issue
+ batch1 = batch1->Slice(0, 0);
+ batch2 = batch2->Slice(0, 0);
+
+ RecordBatchVector in_batches = {batch1, batch2};
+ RecordBatchVector out_batches;
+
+ WriterHelper writer_helper;
+ ASSERT_OK(RoundTripHelper(writer_helper, in_batches, options,
+ IpcReadOptions::Defaults(), &out_batches));
+ ASSERT_EQ(out_batches.size(), in_batches.size());
+
+ // Compare batches
+ for (size_t i = 0; i < in_batches.size(); ++i) {
+ CompareBatch(*in_batches[i], *out_batches[i]);
+ }
+ }
+
+ void TestDictionaryRoundtrip() {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(MakeDictionary(&batch));
+
+ WriterHelper writer_helper;
+ RecordBatchVector out_batches;
+ ASSERT_OK(RoundTripHelper(writer_helper, {batch}, IpcWriteOptions::Defaults(),
+ IpcReadOptions::Defaults(), &out_batches));
+ ASSERT_EQ(out_batches.size(), 1);
+
+ // TODO(wesm): This was broken in ARROW-3144. I'm not sure how to
+ // restore the deduplication logic yet because dictionaries are
+ // corresponded to the Schema using Field pointers rather than
+ // DataType as before
+
+ // CheckDictionariesDeduplicated(*out_batches[0]);
+ }
+
+ void TestReadSubsetOfFields() {
+ // Part of ARROW-7979
+ auto a0 = ArrayFromJSON(utf8(), "[\"a0\", null]");
+ auto a1 = ArrayFromJSON(utf8(), "[\"a1\", null]");
+ auto a2 = ArrayFromJSON(utf8(), "[\"a2\", null]");
+ auto a3 = ArrayFromJSON(utf8(), "[\"a3\", null]");
+
+ auto my_schema = schema({field("a0", utf8()), field("a1", utf8()),
+ field("a2", utf8()), field("a3", utf8())},
+ key_value_metadata({"key1"}, {"value1"}));
+ auto batch = RecordBatch::Make(my_schema, a0->length(), {a0, a1, a2, a3});
+
+ IpcReadOptions options = IpcReadOptions::Defaults();
+
+ options.included_fields = {1, 3};
+
+ {
+ WriterHelper writer_helper;
+ RecordBatchVector out_batches;
+ std::shared_ptr<Schema> out_schema;
+ ASSERT_OK(RoundTripHelper(writer_helper, {batch}, IpcWriteOptions::Defaults(),
+ options, &out_batches, &out_schema));
+
+ auto ex_schema = schema({field("a1", utf8()), field("a3", utf8())},
+ key_value_metadata({"key1"}, {"value1"}));
+ AssertSchemaEqual(*ex_schema, *out_schema);
+
+ auto ex_batch = RecordBatch::Make(ex_schema, a0->length(), {a1, a3});
+ AssertBatchesEqual(*ex_batch, *out_batches[0], /*check_metadata=*/true);
+ }
+
+ // Duplicated or unordered indices are normalized when reading
+ options.included_fields = {3, 1, 1};
+
+ {
+ WriterHelper writer_helper;
+ RecordBatchVector out_batches;
+ std::shared_ptr<Schema> out_schema;
+ ASSERT_OK(RoundTripHelper(writer_helper, {batch}, IpcWriteOptions::Defaults(),
+ options, &out_batches, &out_schema));
+
+ auto ex_schema = schema({field("a1", utf8()), field("a3", utf8())},
+ key_value_metadata({"key1"}, {"value1"}));
+ AssertSchemaEqual(*ex_schema, *out_schema);
+
+ auto ex_batch = RecordBatch::Make(ex_schema, a0->length(), {a1, a3});
+ AssertBatchesEqual(*ex_batch, *out_batches[0], /*check_metadata=*/true);
+ }
+
+ // Out of bounds cases
+ options.included_fields = {1, 3, 5};
+ {
+ WriterHelper writer_helper;
+ RecordBatchVector out_batches;
+ ASSERT_RAISES(Invalid,
+ RoundTripHelper(writer_helper, {batch}, IpcWriteOptions::Defaults(),
+ options, &out_batches));
+ }
+ options.included_fields = {1, 3, -1};
+ {
+ WriterHelper writer_helper;
+ RecordBatchVector out_batches;
+ ASSERT_RAISES(Invalid,
+ RoundTripHelper(writer_helper, {batch}, IpcWriteOptions::Defaults(),
+ options, &out_batches));
+ }
+ }
+
+ void TestWriteDifferentSchema() {
+ // Test writing batches with a different schema than the RecordBatchWriter
+ // was initialized with.
+ std::shared_ptr<RecordBatch> batch_ints, batch_bools;
+ ASSERT_OK(MakeIntRecordBatch(&batch_ints));
+ ASSERT_OK(MakeBooleanBatch(&batch_bools));
+
+ std::shared_ptr<Schema> schema = batch_bools->schema();
+ ASSERT_FALSE(schema->HasMetadata());
+ schema = schema->WithMetadata(key_value_metadata({"some_key"}, {"some_value"}));
+
+ WriterHelper writer_helper;
+ ASSERT_OK(writer_helper.Init(schema, IpcWriteOptions::Defaults()));
+ // Writing a record batch with a different schema
+ ASSERT_RAISES(Invalid, writer_helper.WriteBatch(batch_ints));
+ // Writing a record batch with the same schema (except metadata)
+ ASSERT_OK(writer_helper.WriteBatch(batch_bools));
+ ASSERT_OK(writer_helper.Finish());
+
+ // The single successful batch can be read again
+ RecordBatchVector out_batches;
+ ASSERT_OK(writer_helper.ReadBatches(IpcReadOptions::Defaults(), &out_batches));
+ ASSERT_EQ(out_batches.size(), 1);
+ CompareBatch(*out_batches[0], *batch_bools, false /* compare_metadata */);
+ // Metadata from the RecordBatchWriter initialization schema was kept
+ ASSERT_TRUE(out_batches[0]->schema()->Equals(*schema));
+ }
+
+ void TestWriteNoRecordBatches() {
+ // Test writing no batches.
+ auto schema = arrow::schema({field("a", int32())});
+
+ WriterHelper writer_helper;
+ ASSERT_OK(writer_helper.Init(schema, IpcWriteOptions::Defaults()));
+ ASSERT_OK(writer_helper.Finish());
+
+ RecordBatchVector out_batches;
+ ASSERT_OK(writer_helper.ReadBatches(IpcReadOptions::Defaults(), &out_batches));
+ ASSERT_EQ(out_batches.size(), 0);
+
+ std::shared_ptr<Schema> actual_schema;
+ ASSERT_OK(writer_helper.ReadSchema(&actual_schema));
+ AssertSchemaEqual(*actual_schema, *schema);
+ }
+
+ private:
+ Status RoundTripHelper(WriterHelper& writer_helper, const RecordBatchVector& in_batches,
+ const IpcWriteOptions& write_options,
+ const IpcReadOptions& read_options,
+ RecordBatchVector* out_batches,
+ std::shared_ptr<Schema>* out_schema = nullptr) {
+ RETURN_NOT_OK(writer_helper.Init(in_batches[0]->schema(), write_options));
+ for (const auto& batch : in_batches) {
+ RETURN_NOT_OK(writer_helper.WriteBatch(batch));
+ }
+ RETURN_NOT_OK(writer_helper.Finish());
+ RETURN_NOT_OK(writer_helper.ReadBatches(read_options, out_batches));
+ if (out_schema) {
+ RETURN_NOT_OK(writer_helper.ReadSchema(read_options, out_schema));
+ }
+ for (const auto& batch : *out_batches) {
+ RETURN_NOT_OK(batch->ValidateFull());
+ }
+ return Status::OK();
+ }
+
+ void CheckBatchDictionaries(const RecordBatch& batch) {
+ // Check that dictionaries that should be the same are the same
+ auto schema = batch.schema();
+
+ const auto& b0 = checked_cast<const DictionaryArray&>(*batch.column(0));
+ const auto& b1 = checked_cast<const DictionaryArray&>(*batch.column(1));
+
+ ASSERT_EQ(b0.dictionary().get(), b1.dictionary().get());
+
+ // Same dictionary used for list values
+ const auto& b3 = checked_cast<const ListArray&>(*batch.column(3));
+ const auto& b3_value = checked_cast<const DictionaryArray&>(*b3.values());
+ ASSERT_EQ(b0.dictionary().get(), b3_value.dictionary().get());
+ }
+}; // namespace test
+
+class TestFileFormat : public ReaderWriterMixin<FileWriterHelper>,
+ public ::testing::TestWithParam<MakeRecordBatch*> {};
+
+class TestFileFormatGenerator : public ReaderWriterMixin<FileGeneratorWriterHelper>,
+ public ::testing::TestWithParam<MakeRecordBatch*> {};
+
+class TestStreamFormat : public ReaderWriterMixin<StreamWriterHelper>,
+ public ::testing::TestWithParam<MakeRecordBatch*> {};
+
+class TestStreamDecoderData : public ReaderWriterMixin<StreamDecoderDataWriterHelper>,
+ public ::testing::TestWithParam<MakeRecordBatch*> {};
+class TestStreamDecoderBuffer : public ReaderWriterMixin<StreamDecoderBufferWriterHelper>,
+ public ::testing::TestWithParam<MakeRecordBatch*> {};
+class TestStreamDecoderSmallChunks
+ : public ReaderWriterMixin<StreamDecoderSmallChunksWriterHelper>,
+ public ::testing::TestWithParam<MakeRecordBatch*> {};
+class TestStreamDecoderLargeChunks
+ : public ReaderWriterMixin<StreamDecoderLargeChunksWriterHelper>,
+ public ::testing::TestWithParam<MakeRecordBatch*> {};
+
+TEST_P(TestFileFormat, RoundTrip) {
+ TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
+ TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
+
+ IpcWriteOptions options;
+ options.write_legacy_ipc_format = true;
+ TestRoundTrip(*GetParam(), options);
+ TestZeroLengthRoundTrip(*GetParam(), options);
+}
+
+TEST_P(TestFileFormatGenerator, RoundTrip) {
+ TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
+ TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
+
+ IpcWriteOptions options;
+ options.write_legacy_ipc_format = true;
+ TestRoundTrip(*GetParam(), options);
+ TestZeroLengthRoundTrip(*GetParam(), options);
+}
+
+Status MakeDictionaryBatch(std::shared_ptr<RecordBatch>* out) {
+ auto f0_type = arrow::dictionary(int32(), utf8());
+ auto f1_type = arrow::dictionary(int8(), utf8());
+
+ auto dict = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]");
+
+ auto indices0 = ArrayFromJSON(int32(), "[1, 2, null, 0, 2, 0]");
+ auto indices1 = ArrayFromJSON(int8(), "[0, 0, 2, 2, 1, 1]");
+
+ auto a0 = std::make_shared<DictionaryArray>(f0_type, indices0, dict);
+ auto a1 = std::make_shared<DictionaryArray>(f1_type, indices1, dict);
+
+ // construct batch
+ auto schema = ::arrow::schema({field("dict1", f0_type), field("dict2", f1_type)});
+
+ *out = RecordBatch::Make(schema, 6, {a0, a1});
+ return Status::OK();
+}
+
+// A utility that supports reading/writing record batches,
+// and manually specifying dictionaries.
+class DictionaryBatchHelper {
+ public:
+ explicit DictionaryBatchHelper(const Schema& schema) : schema_(schema) {
+ buffer_ = *AllocateResizableBuffer(0);
+ sink_.reset(new io::BufferOutputStream(buffer_));
+ payload_writer_ = *internal::MakePayloadStreamWriter(sink_.get());
+ }
+
+ Status Start() {
+ RETURN_NOT_OK(payload_writer_->Start());
+
+ // write schema
+ IpcPayload payload;
+ DictionaryFieldMapper mapper(schema_);
+ RETURN_NOT_OK(
+ GetSchemaPayload(schema_, IpcWriteOptions::Defaults(), mapper, &payload));
+ return payload_writer_->WritePayload(payload);
+ }
+
+ Status WriteDictionary(int64_t dictionary_id, const std::shared_ptr<Array>& dictionary,
+ bool is_delta) {
+ IpcPayload payload;
+ RETURN_NOT_OK(GetDictionaryPayload(dictionary_id, is_delta, dictionary,
+ IpcWriteOptions::Defaults(), &payload));
+ RETURN_NOT_OK(payload_writer_->WritePayload(payload));
+ return Status::OK();
+ }
+
+ Status WriteBatchPayload(const RecordBatch& batch) {
+ // write record batch payload only
+ IpcPayload payload;
+ RETURN_NOT_OK(GetRecordBatchPayload(batch, IpcWriteOptions::Defaults(), &payload));
+ return payload_writer_->WritePayload(payload);
+ }
+
+ Status Close() {
+ RETURN_NOT_OK(payload_writer_->Close());
+ return sink_->Close();
+ }
+
+ Status ReadBatch(std::shared_ptr<RecordBatch>* out_batch) {
+ auto buf_reader = std::make_shared<io::BufferReader>(buffer_);
+ std::shared_ptr<RecordBatchReader> reader;
+ ARROW_ASSIGN_OR_RAISE(
+ reader, RecordBatchStreamReader::Open(buf_reader, IpcReadOptions::Defaults()))
+ return reader->ReadNext(out_batch);
+ }
+
+ std::unique_ptr<internal::IpcPayloadWriter> payload_writer_;
+ const Schema& schema_;
+ std::shared_ptr<ResizableBuffer> buffer_;
+ std::unique_ptr<io::BufferOutputStream> sink_;
+};
+
+TEST(TestDictionaryBatch, DictionaryDelta) {
+ std::shared_ptr<RecordBatch> in_batch;
+ std::shared_ptr<RecordBatch> out_batch;
+ ASSERT_OK(MakeDictionaryBatch(&in_batch));
+
+ auto dict1 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\"]");
+ auto dict2 = ArrayFromJSON(utf8(), "[\"baz\"]");
+
+ DictionaryBatchHelper helper(*in_batch->schema());
+ ASSERT_OK(helper.Start());
+
+ ASSERT_OK(helper.WriteDictionary(0L, dict1, /*is_delta=*/false));
+ ASSERT_OK(helper.WriteDictionary(0L, dict2, /*is_delta=*/true));
+
+ ASSERT_OK(helper.WriteDictionary(1L, dict1, /*is_delta=*/false));
+ ASSERT_OK(helper.WriteDictionary(1L, dict2, /*is_delta=*/true));
+
+ ASSERT_OK(helper.WriteBatchPayload(*in_batch));
+ ASSERT_OK(helper.Close());
+
+ ASSERT_OK(helper.ReadBatch(&out_batch));
+
+ ASSERT_BATCHES_EQUAL(*in_batch, *out_batch);
+}
+
+TEST(TestDictionaryBatch, DictionaryDeltaWithUnknownId) {
+ std::shared_ptr<RecordBatch> in_batch;
+ std::shared_ptr<RecordBatch> out_batch;
+ ASSERT_OK(MakeDictionaryBatch(&in_batch));
+
+ auto dict1 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\"]");
+ auto dict2 = ArrayFromJSON(utf8(), "[\"baz\"]");
+
+ DictionaryBatchHelper helper(*in_batch->schema());
+ ASSERT_OK(helper.Start());
+
+ ASSERT_OK(helper.WriteDictionary(0L, dict1, /*is_delta=*/false));
+ ASSERT_OK(helper.WriteDictionary(0L, dict2, /*is_delta=*/true));
+
+ /* This delta dictionary does not have a base dictionary previously in stream */
+ ASSERT_OK(helper.WriteDictionary(1L, dict2, /*is_delta=*/true));
+
+ ASSERT_OK(helper.WriteBatchPayload(*in_batch));
+ ASSERT_OK(helper.Close());
+
+ ASSERT_RAISES(KeyError, helper.ReadBatch(&out_batch));
+}
+
+TEST(TestDictionaryBatch, DictionaryReplacement) {
+ std::shared_ptr<RecordBatch> in_batch;
+ std::shared_ptr<RecordBatch> out_batch;
+ ASSERT_OK(MakeDictionaryBatch(&in_batch));
+
+ auto dict = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]");
+ auto dict1 = ArrayFromJSON(utf8(), "[\"foo1\", \"bar1\", \"baz1\"]");
+ auto dict2 = ArrayFromJSON(utf8(), "[\"foo2\", \"bar2\", \"baz2\"]");
+
+ DictionaryBatchHelper helper(*in_batch->schema());
+ ASSERT_OK(helper.Start());
+
+ // the old dictionaries will be overwritten by
+ // the new dictionaries with the same ids.
+ ASSERT_OK(helper.WriteDictionary(0L, dict1, /*is_delta=*/false));
+ ASSERT_OK(helper.WriteDictionary(0L, dict, /*is_delta=*/false));
+
+ ASSERT_OK(helper.WriteDictionary(1L, dict2, /*is_delta=*/false));
+ ASSERT_OK(helper.WriteDictionary(1L, dict, /*is_delta=*/false));
+
+ ASSERT_OK(helper.WriteBatchPayload(*in_batch));
+ ASSERT_OK(helper.Close());
+
+ ASSERT_OK(helper.ReadBatch(&out_batch));
+
+ ASSERT_BATCHES_EQUAL(*in_batch, *out_batch);
+}
+
+TEST_P(TestStreamFormat, RoundTrip) {
+ TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
+ TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
+
+ IpcWriteOptions options;
+ options.write_legacy_ipc_format = true;
+ TestRoundTrip(*GetParam(), options);
+ TestZeroLengthRoundTrip(*GetParam(), options);
+}
+
+TEST_P(TestStreamDecoderData, RoundTrip) {
+ TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
+ TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
+
+ IpcWriteOptions options;
+ options.write_legacy_ipc_format = true;
+ TestRoundTrip(*GetParam(), options);
+ TestZeroLengthRoundTrip(*GetParam(), options);
+}
+
+TEST_P(TestStreamDecoderBuffer, RoundTrip) {
+ TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
+ TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
+
+ IpcWriteOptions options;
+ options.write_legacy_ipc_format = true;
+ TestRoundTrip(*GetParam(), options);
+ TestZeroLengthRoundTrip(*GetParam(), options);
+}
+
+TEST_P(TestStreamDecoderSmallChunks, RoundTrip) {
+ TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
+ TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
+
+ IpcWriteOptions options;
+ options.write_legacy_ipc_format = true;
+ TestRoundTrip(*GetParam(), options);
+ TestZeroLengthRoundTrip(*GetParam(), options);
+}
+
+TEST_P(TestStreamDecoderLargeChunks, RoundTrip) {
+ TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
+ TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults());
+
+ IpcWriteOptions options;
+ options.write_legacy_ipc_format = true;
+ TestRoundTrip(*GetParam(), options);
+ TestZeroLengthRoundTrip(*GetParam(), options);
+}
+
+INSTANTIATE_TEST_SUITE_P(GenericIpcRoundTripTests, TestIpcRoundTrip,
+ ::testing::ValuesIn(kBatchCases));
+INSTANTIATE_TEST_SUITE_P(FileRoundTripTests, TestFileFormat,
+ ::testing::ValuesIn(kBatchCases));
+INSTANTIATE_TEST_SUITE_P(FileRoundTripTests, TestFileFormatGenerator,
+ ::testing::ValuesIn(kBatchCases));
+INSTANTIATE_TEST_SUITE_P(StreamRoundTripTests, TestStreamFormat,
+ ::testing::ValuesIn(kBatchCases));
+INSTANTIATE_TEST_SUITE_P(StreamDecoderDataRoundTripTests, TestStreamDecoderData,
+ ::testing::ValuesIn(kBatchCases));
+INSTANTIATE_TEST_SUITE_P(StreamDecoderBufferRoundTripTests, TestStreamDecoderBuffer,
+ ::testing::ValuesIn(kBatchCases));
+INSTANTIATE_TEST_SUITE_P(StreamDecoderSmallChunksRoundTripTests,
+ TestStreamDecoderSmallChunks, ::testing::ValuesIn(kBatchCases));
+INSTANTIATE_TEST_SUITE_P(StreamDecoderLargeChunksRoundTripTests,
+ TestStreamDecoderLargeChunks, ::testing::ValuesIn(kBatchCases));
+
+TEST(TestIpcFileFormat, FooterMetaData) {
+ // ARROW-6837
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(MakeIntRecordBatch(&batch));
+
+ auto metadata = key_value_metadata({"ARROW:example", "ARROW:example2"},
+ {"something something", "something something2"});
+
+ FileWriterHelper helper;
+ ASSERT_OK(helper.Init(batch->schema(), IpcWriteOptions::Defaults(), metadata));
+ ASSERT_OK(helper.WriteBatch(batch));
+ ASSERT_OK(helper.Finish());
+
+ ASSERT_OK_AND_ASSIGN(auto out_metadata, helper.ReadFooterMetadata());
+ ASSERT_TRUE(out_metadata->Equals(*metadata));
+}
+
+// This test uses uninitialized memory
+
+#if !(defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER))
+TEST_F(TestIpcRoundTrip, LargeRecordBatch) {
+ const int64_t length = static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
+
+ TypedBufferBuilder<bool> data_builder;
+ ASSERT_OK(data_builder.Reserve(length));
+ ASSERT_OK(data_builder.Advance(length));
+ ASSERT_EQ(data_builder.length(), length);
+ ASSERT_OK_AND_ASSIGN(auto data, data_builder.Finish());
+
+ auto array = std::make_shared<BooleanArray>(length, data, nullptr, /*null_count=*/0);
+
+ auto f0 = arrow::field("f0", array->type());
+ std::vector<std::shared_ptr<Field>> fields = {f0};
+ auto schema = std::make_shared<Schema>(fields);
+
+ auto batch = RecordBatch::Make(schema, length, {array});
+
+ std::string path = "test-write-large-record_batch";
+
+ // 512 MB
+ constexpr int64_t kBufferSize = 1 << 29;
+ ASSERT_OK_AND_ASSIGN(mmap_, io::MemoryMapFixture::InitMemoryMap(kBufferSize, path));
+
+ ASSERT_OK_AND_ASSIGN(auto result, DoLargeRoundTrip(*batch, false));
+ CheckReadResult(*result, *batch);
+
+ ASSERT_EQ(length, result->num_rows());
+}
+#endif
+
+TEST_F(TestStreamFormat, DictionaryRoundTrip) { TestDictionaryRoundtrip(); }
+
+TEST_F(TestFileFormat, DictionaryRoundTrip) { TestDictionaryRoundtrip(); }
+
+TEST_F(TestFileFormatGenerator, DictionaryRoundTrip) { TestDictionaryRoundtrip(); }
+
+TEST_F(TestStreamFormat, DifferentSchema) { TestWriteDifferentSchema(); }
+
+TEST_F(TestFileFormat, DifferentSchema) { TestWriteDifferentSchema(); }
+
+TEST_F(TestFileFormatGenerator, DifferentSchema) { TestWriteDifferentSchema(); }
+
+TEST_F(TestStreamFormat, NoRecordBatches) { TestWriteNoRecordBatches(); }
+
+TEST_F(TestFileFormat, NoRecordBatches) { TestWriteNoRecordBatches(); }
+
+TEST_F(TestFileFormatGenerator, NoRecordBatches) { TestWriteNoRecordBatches(); }
+
+TEST_F(TestStreamFormat, ReadFieldSubset) { TestReadSubsetOfFields(); }
+
+TEST_F(TestFileFormat, ReadFieldSubset) { TestReadSubsetOfFields(); }
+
+TEST_F(TestFileFormatGenerator, ReadFieldSubset) { TestReadSubsetOfFields(); }
+
+TEST(TestRecordBatchStreamReader, EmptyStreamWithDictionaries) {
+ // ARROW-6006
+ auto f0 = arrow::field("f0", arrow::dictionary(arrow::int8(), arrow::utf8()));
+ auto schema = arrow::schema({f0});
+
+ ASSERT_OK_AND_ASSIGN(auto stream, io::BufferOutputStream::Create(0));
+
+ ASSERT_OK_AND_ASSIGN(auto writer, MakeStreamWriter(stream, schema));
+ ASSERT_OK(writer->Close());
+
+ ASSERT_OK_AND_ASSIGN(auto buffer, stream->Finish());
+ io::BufferReader buffer_reader(buffer);
+ std::shared_ptr<RecordBatchReader> reader;
+ ASSERT_OK_AND_ASSIGN(reader, RecordBatchStreamReader::Open(&buffer_reader));
+
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(reader->ReadNext(&batch));
+ ASSERT_EQ(nullptr, batch);
+}
+
+// Delimit IPC stream messages and reassemble with the indicated messages
+// included. This way we can remove messages from an IPC stream to test
+// different failure modes or other difficult-to-test behaviors
+void SpliceMessages(std::shared_ptr<Buffer> stream,
+ const std::vector<int>& included_indices,
+ std::shared_ptr<Buffer>* spliced_stream) {
+ ASSERT_OK_AND_ASSIGN(auto out, io::BufferOutputStream::Create(0));
+
+ io::BufferReader buffer_reader(stream);
+ std::unique_ptr<MessageReader> message_reader = MessageReader::Open(&buffer_reader);
+ std::unique_ptr<Message> msg;
+
+ // Parse and reassemble first two messages in stream
+ int message_index = 0;
+ while (true) {
+ ASSERT_OK_AND_ASSIGN(msg, message_reader->ReadNextMessage());
+ if (!msg) {
+ break;
+ }
+
+ if (std::find(included_indices.begin(), included_indices.end(), message_index++) ==
+ included_indices.end()) {
+ // Message being dropped, continue
+ continue;
+ }
+
+ IpcWriteOptions options;
+ IpcPayload payload;
+ payload.type = msg->type();
+ payload.metadata = msg->metadata();
+ payload.body_buffers.push_back(msg->body());
+ payload.body_length = msg->body()->size();
+ int32_t unused_metadata_length = -1;
+ ASSERT_OK(ipc::WriteIpcPayload(payload, options, out.get(), &unused_metadata_length));
+ }
+ ASSERT_OK_AND_ASSIGN(*spliced_stream, out->Finish());
+}
+
+TEST(TestRecordBatchStreamReader, NotEnoughDictionaries) {
+ // ARROW-6126
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(MakeDictionaryFlat(&batch));
+
+ ASSERT_OK_AND_ASSIGN(auto out, io::BufferOutputStream::Create(0));
+ ASSERT_OK_AND_ASSIGN(auto writer, MakeStreamWriter(out, batch->schema()));
+ ASSERT_OK(writer->WriteRecordBatch(*batch));
+ ASSERT_OK(writer->Close());
+
+ // Now let's mangle the stream a little bit and make sure we return the right
+ // error
+ ASSERT_OK_AND_ASSIGN(auto buffer, out->Finish());
+
+ auto AssertFailsWith = [](std::shared_ptr<Buffer> stream, const std::string& ex_error) {
+ io::BufferReader reader(stream);
+ ASSERT_OK_AND_ASSIGN(auto ipc_reader, RecordBatchStreamReader::Open(&reader));
+ std::shared_ptr<RecordBatch> batch;
+ Status s = ipc_reader->ReadNext(&batch);
+ ASSERT_TRUE(s.IsInvalid());
+ ASSERT_EQ(ex_error, s.message().substr(0, ex_error.size()));
+ };
+
+ // Stream terminates before reading all dictionaries
+ std::shared_ptr<Buffer> truncated_stream;
+ SpliceMessages(buffer, {0, 1}, &truncated_stream);
+ std::string ex_message =
+ ("IPC stream ended without reading the expected number (3)"
+ " of dictionaries");
+ AssertFailsWith(truncated_stream, ex_message);
+
+ // One of the dictionaries is missing, then we see a record batch
+ SpliceMessages(buffer, {0, 1, 2, 4}, &truncated_stream);
+ ex_message =
+ ("IPC stream did not have the expected number (3) of dictionaries "
+ "at the start of the stream");
+ AssertFailsWith(truncated_stream, ex_message);
+}
+
+TEST(TestRecordBatchStreamReader, MalformedInput) {
+ const std::string empty_str = "";
+ const std::string garbage_str = "12345678";
+
+ auto empty = std::make_shared<Buffer>(empty_str);
+ auto garbage = std::make_shared<Buffer>(garbage_str);
+
+ io::BufferReader empty_reader(empty);
+ ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&empty_reader));
+
+ io::BufferReader garbage_reader(garbage);
+ ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&garbage_reader));
+}
+
+TEST(TestStreamDecoder, NextRequiredSize) {
+ auto listener = std::make_shared<CollectListener>();
+ StreamDecoder decoder(listener);
+ auto next_required_size = decoder.next_required_size();
+ const uint8_t data[1] = {0};
+ ASSERT_OK(decoder.Consume(data, 1));
+ ASSERT_EQ(next_required_size - 1, decoder.next_required_size());
+}
+
+template <typename WriterHelperType>
+class TestDictionaryReplacement : public ::testing::Test {
+ public:
+ using WriterHelper = WriterHelperType;
+
+ void TestSameDictPointer() {
+ auto type = dictionary(int8(), utf8());
+ auto values = ArrayFromJSON(utf8(), R"(["foo", "bar", "quux"])");
+ auto batch1 = MakeBatch(type, ArrayFromJSON(int8(), "[0, 2, null, 1]"), values);
+ auto batch2 = MakeBatch(type, ArrayFromJSON(int8(), "[1, 0, 0]"), values);
+ CheckRoundtrip({batch1, batch2});
+
+ EXPECT_EQ(read_stats_.num_messages, 4); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 2);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 1);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 0);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+ }
+
+ void TestSameDictValues() {
+ auto type = dictionary(int8(), utf8());
+ // Create two separate dictionaries, but with the same contents
+ auto batch1 = MakeBatch(ArrayFromJSON(type, R"(["foo", "foo", "bar", null])"));
+ auto batch2 = MakeBatch(ArrayFromJSON(type, R"(["foo", "bar", "foo"])"));
+ CheckRoundtrip({batch1, batch2});
+
+ EXPECT_EQ(read_stats_.num_messages, 4); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 2);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 1);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 0);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+ }
+
+ void TestDeltaDict() {
+ auto type = dictionary(int8(), utf8());
+ auto batch1 = MakeBatch(ArrayFromJSON(type, R"(["foo", "foo", "bar", null])"));
+ // Potential delta
+ auto batch2 = MakeBatch(ArrayFromJSON(type, R"(["foo", "bar", "quux", "foo"])"));
+ // Potential delta
+ auto batch3 =
+ MakeBatch(ArrayFromJSON(type, R"(["foo", "bar", "quux", "zzz", "foo"])"));
+ auto batch4 = MakeBatch(ArrayFromJSON(type, R"(["bar", null, "quux", "foo"])"));
+ RecordBatchVector batches{batch1, batch2, batch3, batch4};
+
+ // Emit replacements
+ if (WriterHelper::kIsFileFormat) {
+ CheckWritingFails(batches, 1);
+ } else {
+ CheckRoundtrip(batches);
+ EXPECT_EQ(read_stats_.num_messages, 9); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 4);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 4);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 3);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+ }
+
+ // Emit deltas
+ write_options_.emit_dictionary_deltas = true;
+ if (WriterHelper::kIsFileFormat) {
+ CheckWritingFails(batches, 1);
+ } else {
+ CheckRoundtrip(batches);
+ EXPECT_EQ(read_stats_.num_messages, 9); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 4);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 4);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 1);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 2);
+ }
+
+ // IPC file format: WriteTable should unify dicts
+ RecordBatchVector actual;
+ write_options_.unify_dictionaries = true;
+ ASSERT_OK(RoundTripTable(batches, &actual));
+ if (WriterHelper::kIsFileFormat) {
+ EXPECT_EQ(read_stats_.num_messages, 6); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 4);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 1);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 0);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+ CheckBatchesLogical(batches, actual);
+ } else {
+ EXPECT_EQ(read_stats_.num_messages, 9); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 4);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 4);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 1);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 2);
+ CheckBatches(batches, actual);
+ }
+ }
+
+ void TestSameDictValuesNested() {
+ auto batches = SameValuesNestedDictBatches();
+ CheckRoundtrip(batches);
+
+ EXPECT_EQ(read_stats_.num_messages, 5); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 2);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 2);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 0);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+
+ write_options_.unify_dictionaries = true;
+ CheckRoundtrip(batches);
+ if (WriterHelper::kIsFileFormat) {
+ // This fails because unification of nested dictionaries is not supported.
+ // However, perhaps this should work because the dictionaries are simply equal.
+ CheckWritingTableFails(batches, StatusCode::NotImplemented);
+ } else {
+ CheckRoundtripTable(batches);
+ }
+ }
+
+ void TestDifferentDictValues() {
+ if (WriterHelper::kIsFileFormat) {
+ CheckWritingFails(DifferentOrderDictBatches(), 1);
+ CheckWritingFails(DifferentValuesDictBatches(), 1);
+ } else {
+ CheckRoundtrip(DifferentOrderDictBatches());
+
+ EXPECT_EQ(read_stats_.num_messages, 5); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 2);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 2);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 1);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+
+ CheckRoundtrip(DifferentValuesDictBatches());
+
+ EXPECT_EQ(read_stats_.num_messages, 5); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 2);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 2);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 1);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+ }
+
+ // Same, but single-shot table write
+ if (WriterHelper::kIsFileFormat) {
+ CheckWritingTableFails(DifferentOrderDictBatches());
+ CheckWritingTableFails(DifferentValuesDictBatches());
+
+ write_options_.unify_dictionaries = true;
+ // Will unify dictionaries
+ CheckRoundtripTable(DifferentOrderDictBatches());
+
+ EXPECT_EQ(read_stats_.num_messages, 4); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 2);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 1);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 0);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+
+ CheckRoundtripTable(DifferentValuesDictBatches());
+
+ EXPECT_EQ(read_stats_.num_messages, 4); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 2);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 1);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 0);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+ } else {
+ CheckRoundtripTable(DifferentOrderDictBatches());
+
+ EXPECT_EQ(read_stats_.num_messages, 5); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 2);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 2);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 1);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+
+ CheckRoundtripTable(DifferentValuesDictBatches());
+
+ EXPECT_EQ(read_stats_.num_messages, 5); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 2);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 2);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 1);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+ }
+ }
+
+ void TestDifferentDictValuesNested() {
+ if (WriterHelper::kIsFileFormat) {
+ CheckWritingFails(DifferentValuesNestedDictBatches1(), 1);
+ CheckWritingFails(DifferentValuesNestedDictBatches2(), 1);
+ CheckWritingTableFails(DifferentValuesNestedDictBatches1());
+ CheckWritingTableFails(DifferentValuesNestedDictBatches2());
+
+ write_options_.unify_dictionaries = true;
+ CheckWritingFails(DifferentValuesNestedDictBatches1(), 1);
+ CheckWritingFails(DifferentValuesNestedDictBatches2(), 1);
+ CheckWritingTableFails(DifferentValuesNestedDictBatches1(),
+ StatusCode::NotImplemented);
+ CheckWritingTableFails(DifferentValuesNestedDictBatches2(),
+ StatusCode::NotImplemented);
+ return;
+ }
+ CheckRoundtrip(DifferentValuesNestedDictBatches1());
+
+ EXPECT_EQ(read_stats_.num_messages, 7); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 2);
+ // Both inner and outer dict were replaced
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 4);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 2);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+
+ CheckRoundtrip(DifferentValuesNestedDictBatches2());
+
+ EXPECT_EQ(read_stats_.num_messages, 6); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 2);
+ // Only inner dict was replaced
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 3);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 1);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+ }
+
+ void TestDeltaDictNestedOuter() {
+ // Outer dict changes, inner dict remains the same
+ auto value_type = list(dictionary(int8(), utf8()));
+ auto type = dictionary(int8(), value_type);
+ // Inner dict: ["a", "b"]
+ auto batch1_values = ArrayFromJSON(value_type, R"([["a"], ["b"]])");
+ // Potential delta
+ auto batch2_values = ArrayFromJSON(value_type, R"([["a"], ["b"], ["a", "a"]])");
+ auto batch1 = MakeBatch(type, ArrayFromJSON(int8(), "[1, 0, 1]"), batch1_values);
+ auto batch2 =
+ MakeBatch(type, ArrayFromJSON(int8(), "[2, null, 0, 0]"), batch2_values);
+ RecordBatchVector batches{batch1, batch2};
+
+ if (WriterHelper::kIsFileFormat) {
+ CheckWritingFails(batches, 1);
+ } else {
+ CheckRoundtrip(batches);
+ EXPECT_EQ(read_stats_.num_messages, 6); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 2);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 3);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 1);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+ }
+
+ write_options_.emit_dictionary_deltas = true;
+ if (WriterHelper::kIsFileFormat) {
+ CheckWritingFails(batches, 1);
+ } else {
+ // Outer dict deltas are not emitted as the read path doesn't support them
+ CheckRoundtrip(batches);
+ EXPECT_EQ(read_stats_.num_messages, 6); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 2);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 3);
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 1);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+ }
+ }
+
+ void TestDeltaDictNestedInner() {
+ // Inner dict changes
+ auto value_type = list(dictionary(int8(), utf8()));
+ auto type = dictionary(int8(), value_type);
+ // Inner dict: ["a"]
+ auto batch1_values = ArrayFromJSON(value_type, R"([["a"]])");
+ // Inner dict: ["a", "b"] => potential delta
+ auto batch2_values = ArrayFromJSON(value_type, R"([["a"], ["b"], ["a", "a"]])");
+ // Inner dict: ["a", "b", "c"] => potential delta
+ auto batch3_values = ArrayFromJSON(value_type, R"([["a"], ["b"], ["c"]])");
+ // Inner dict: ["a", "b", "c"]
+ auto batch4_values = ArrayFromJSON(value_type, R"([["a"], ["b", "c"]])");
+ // Inner dict: ["a", "c", "b"] => replacement
+ auto batch5_values = ArrayFromJSON(value_type, R"([["a"], ["c"], ["b"]])");
+ auto batch1 = MakeBatch(type, ArrayFromJSON(int8(), "[0, null, 0]"), batch1_values);
+ auto batch2 = MakeBatch(type, ArrayFromJSON(int8(), "[1, 0, 2]"), batch2_values);
+ auto batch3 = MakeBatch(type, ArrayFromJSON(int8(), "[1, 0, 2]"), batch3_values);
+ auto batch4 = MakeBatch(type, ArrayFromJSON(int8(), "[1, 0, null]"), batch4_values);
+ auto batch5 = MakeBatch(type, ArrayFromJSON(int8(), "[1, 0, 2]"), batch5_values);
+ RecordBatchVector batches{batch1, batch2, batch3, batch4, batch5};
+
+ if (WriterHelper::kIsFileFormat) {
+ CheckWritingFails(batches, 1);
+ } else {
+ CheckRoundtrip(batches);
+ EXPECT_EQ(read_stats_.num_messages, 15); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 5);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 9); // 4 inner + 5 outer
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 7);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 0);
+ }
+
+ write_options_.emit_dictionary_deltas = true;
+ if (WriterHelper::kIsFileFormat) {
+ CheckWritingFails(batches, 1);
+ } else {
+ CheckRoundtrip(batches);
+ EXPECT_EQ(read_stats_.num_messages, 15); // including schema message
+ EXPECT_EQ(read_stats_.num_record_batches, 5);
+ EXPECT_EQ(read_stats_.num_dictionary_batches, 9); // 4 inner + 5 outer
+ EXPECT_EQ(read_stats_.num_replaced_dictionaries, 5);
+ EXPECT_EQ(read_stats_.num_dictionary_deltas, 2);
+ }
+ }
+
+ Status RoundTrip(const RecordBatchVector& in_batches, RecordBatchVector* out_batches) {
+ WriterHelper writer_helper;
+ RETURN_NOT_OK(writer_helper.Init(in_batches[0]->schema(), write_options_));
+ for (const auto& batch : in_batches) {
+ RETURN_NOT_OK(writer_helper.WriteBatch(batch));
+ }
+ RETURN_NOT_OK(writer_helper.Finish(&write_stats_));
+ RETURN_NOT_OK(writer_helper.ReadBatches(read_options_, out_batches, &read_stats_));
+ for (const auto& batch : *out_batches) {
+ RETURN_NOT_OK(batch->ValidateFull());
+ }
+ return Status::OK();
+ }
+
+ Status RoundTripTable(const RecordBatchVector& in_batches,
+ RecordBatchVector* out_batches) {
+ WriterHelper writer_helper;
+ RETURN_NOT_OK(writer_helper.Init(in_batches[0]->schema(), write_options_));
+ // WriteTable is different from a series of WriteBatch for RecordBatchFileWriter
+ RETURN_NOT_OK(writer_helper.WriteTable(in_batches));
+ RETURN_NOT_OK(writer_helper.Finish(&write_stats_));
+ RETURN_NOT_OK(writer_helper.ReadBatches(read_options_, out_batches, &read_stats_));
+ for (const auto& batch : *out_batches) {
+ RETURN_NOT_OK(batch->ValidateFull());
+ }
+ return Status::OK();
+ }
+
+ void CheckBatches(const RecordBatchVector& expected, const RecordBatchVector& actual) {
+ ASSERT_EQ(expected.size(), actual.size());
+ for (size_t i = 0; i < expected.size(); ++i) {
+ AssertBatchesEqual(*expected[i], *actual[i]);
+ }
+ }
+
+ // Check that batches are logically equal, even if e.g. dictionaries
+ // are different.
+ void CheckBatchesLogical(const RecordBatchVector& expected,
+ const RecordBatchVector& actual) {
+ ASSERT_OK_AND_ASSIGN(auto expected_table, Table::FromRecordBatches(expected));
+ ASSERT_OK_AND_ASSIGN(auto actual_table, Table::FromRecordBatches(actual));
+ ASSERT_OK_AND_ASSIGN(expected_table, expected_table->CombineChunks());
+ ASSERT_OK_AND_ASSIGN(actual_table, actual_table->CombineChunks());
+ AssertTablesEqual(*expected_table, *actual_table);
+ }
+
+ void CheckRoundtrip(const RecordBatchVector& in_batches) {
+ RecordBatchVector out_batches;
+ ASSERT_OK(RoundTrip(in_batches, &out_batches));
+ CheckStatsConsistent();
+ CheckBatches(in_batches, out_batches);
+ }
+
+ void CheckRoundtripTable(const RecordBatchVector& in_batches) {
+ RecordBatchVector out_batches;
+ ASSERT_OK(RoundTripTable(in_batches, &out_batches));
+ CheckStatsConsistent();
+ CheckBatchesLogical(in_batches, out_batches);
+ }
+
+ void CheckWritingFails(const RecordBatchVector& in_batches, size_t fails_at_batch_num) {
+ WriterHelper writer_helper;
+ ASSERT_OK(writer_helper.Init(in_batches[0]->schema(), write_options_));
+ for (size_t i = 0; i < fails_at_batch_num; ++i) {
+ ASSERT_OK(writer_helper.WriteBatch(in_batches[i]));
+ }
+ ASSERT_RAISES(Invalid, writer_helper.WriteBatch(in_batches[fails_at_batch_num]));
+ }
+
+ void CheckWritingTableFails(const RecordBatchVector& in_batches,
+ StatusCode expected_error = StatusCode::Invalid) {
+ WriterHelper writer_helper;
+ ASSERT_OK(writer_helper.Init(in_batches[0]->schema(), write_options_));
+ auto st = writer_helper.WriteTable(in_batches);
+ ASSERT_FALSE(st.ok());
+ ASSERT_EQ(st.code(), expected_error);
+ }
+
+ void CheckStatsConsistent() {
+ ASSERT_EQ(read_stats_.num_messages, write_stats_.num_messages);
+ ASSERT_EQ(read_stats_.num_record_batches, write_stats_.num_record_batches);
+ ASSERT_EQ(read_stats_.num_dictionary_batches, write_stats_.num_dictionary_batches);
+ ASSERT_EQ(read_stats_.num_replaced_dictionaries,
+ write_stats_.num_replaced_dictionaries);
+ ASSERT_EQ(read_stats_.num_dictionary_deltas, write_stats_.num_dictionary_deltas);
+ }
+
+ RecordBatchVector DifferentOrderDictBatches() {
+ // Create two separate dictionaries with different order
+ auto type = dictionary(int8(), utf8());
+ auto batch1 = MakeBatch(ArrayFromJSON(type, R"(["foo", "foo", "bar", null])"));
+ auto batch2 = MakeBatch(ArrayFromJSON(type, R"(["bar", "bar", "foo"])"));
+ return {batch1, batch2};
+ }
+
+ RecordBatchVector DifferentValuesDictBatches() {
+ // Create two separate dictionaries with different values
+ auto type = dictionary(int8(), utf8());
+ auto batch1 = MakeBatch(ArrayFromJSON(type, R"(["foo", "foo", "bar", null])"));
+ auto batch2 = MakeBatch(ArrayFromJSON(type, R"(["bar", "quux", "quux"])"));
+ return {batch1, batch2};
+ }
+
+ RecordBatchVector SameValuesNestedDictBatches() {
+ auto value_type = list(dictionary(int8(), utf8()));
+ auto type = dictionary(int8(), value_type);
+ auto batch1_values = ArrayFromJSON(value_type, R"([[], ["a"], ["b"], ["a", "a"]])");
+ auto batch2_values = ArrayFromJSON(value_type, R"([[], ["a"], ["b"], ["a", "a"]])");
+ auto batch1 = MakeBatch(type, ArrayFromJSON(int8(), "[1, 3, 0, 3]"), batch1_values);
+ auto batch2 = MakeBatch(type, ArrayFromJSON(int8(), "[2, null, 2]"), batch2_values);
+ return {batch1, batch2};
+ }
+
+ RecordBatchVector DifferentValuesNestedDictBatches1() {
+ // Inner dictionary values differ
+ auto value_type = list(dictionary(int8(), utf8()));
+ auto type = dictionary(int8(), value_type);
+ auto batch1_values = ArrayFromJSON(value_type, R"([[], ["a"], ["b"], ["a", "a"]])");
+ auto batch2_values = ArrayFromJSON(value_type, R"([[], ["a"], ["c"], ["a", "a"]])");
+ auto batch1 = MakeBatch(type, ArrayFromJSON(int8(), "[1, 3, 0, 3]"), batch1_values);
+ auto batch2 = MakeBatch(type, ArrayFromJSON(int8(), "[2, null, 2]"), batch2_values);
+ return {batch1, batch2};
+ }
+
+ RecordBatchVector DifferentValuesNestedDictBatches2() {
+ // Outer dictionary values differ
+ auto value_type = list(dictionary(int8(), utf8()));
+ auto type = dictionary(int8(), value_type);
+ auto batch1_values = ArrayFromJSON(value_type, R"([[], ["a"], ["b"], ["a", "a"]])");
+ auto batch2_values = ArrayFromJSON(value_type, R"([["a"], ["b"], ["a", "a"]])");
+ auto batch1 = MakeBatch(type, ArrayFromJSON(int8(), "[1, 3, 0, 3]"), batch1_values);
+ auto batch2 = MakeBatch(type, ArrayFromJSON(int8(), "[2, null, 2]"), batch2_values);
+ return {batch1, batch2};
+ }
+
+ // Make one-column batch
+ std::shared_ptr<RecordBatch> MakeBatch(std::shared_ptr<Array> column) {
+ return RecordBatch::Make(schema({field("f", column->type())}), column->length(),
+ {column});
+ }
+
+ // Make one-column batch with a dictionary array
+ std::shared_ptr<RecordBatch> MakeBatch(std::shared_ptr<DataType> type,
+ std::shared_ptr<Array> indices,
+ std::shared_ptr<Array> dictionary) {
+ auto array = *DictionaryArray::FromArrays(std::move(type), std::move(indices),
+ std::move(dictionary));
+ return MakeBatch(std::move(array));
+ }
+
+ protected:
+ IpcWriteOptions write_options_ = IpcWriteOptions::Defaults();
+ IpcReadOptions read_options_ = IpcReadOptions::Defaults();
+ WriteStats write_stats_;
+ ReadStats read_stats_;
+};
+
+using DictionaryReplacementTestTypes =
+ ::testing::Types<StreamWriterHelper, StreamDecoderBufferWriterHelper,
+ FileWriterHelper>;
+
+TYPED_TEST_SUITE(TestDictionaryReplacement, DictionaryReplacementTestTypes);
+
+TYPED_TEST(TestDictionaryReplacement, SameDictPointer) { this->TestSameDictPointer(); }
+
+TYPED_TEST(TestDictionaryReplacement, SameDictValues) { this->TestSameDictValues(); }
+
+TYPED_TEST(TestDictionaryReplacement, DeltaDict) { this->TestDeltaDict(); }
+
+TYPED_TEST(TestDictionaryReplacement, SameDictValuesNested) {
+ this->TestSameDictValuesNested();
+}
+
+TYPED_TEST(TestDictionaryReplacement, DifferentDictValues) {
+ this->TestDifferentDictValues();
+}
+
+TYPED_TEST(TestDictionaryReplacement, DifferentDictValuesNested) {
+ this->TestDifferentDictValuesNested();
+}
+
+TYPED_TEST(TestDictionaryReplacement, DeltaDictNestedOuter) {
+ this->TestDeltaDictNestedOuter();
+}
+
+TYPED_TEST(TestDictionaryReplacement, DeltaDictNestedInner) {
+ this->TestDeltaDictNestedInner();
+}
+
+// ----------------------------------------------------------------------
+// Miscellanea
+
+TEST(FieldPosition, Basics) {
+ FieldPosition pos;
+ ASSERT_EQ(pos.path(), std::vector<int>{});
+ {
+ auto child = pos.child(6);
+ ASSERT_EQ(child.path(), std::vector<int>{6});
+ auto grand_child = child.child(42);
+ ASSERT_EQ(grand_child.path(), (std::vector<int>{6, 42}));
+ }
+ {
+ auto child = pos.child(12);
+ ASSERT_EQ(child.path(), std::vector<int>{12});
+ }
+}
+
+TEST(DictionaryFieldMapper, Basics) {
+ DictionaryFieldMapper mapper;
+
+ ASSERT_EQ(mapper.num_fields(), 0);
+
+ ASSERT_OK(mapper.AddField(42, {0, 1}));
+ ASSERT_OK(mapper.AddField(43, {0, 2}));
+ ASSERT_OK(mapper.AddField(44, {0, 1, 3}));
+ ASSERT_EQ(mapper.num_fields(), 3);
+
+ ASSERT_OK_AND_EQ(42, mapper.GetFieldId({0, 1}));
+ ASSERT_OK_AND_EQ(43, mapper.GetFieldId({0, 2}));
+ ASSERT_OK_AND_EQ(44, mapper.GetFieldId({0, 1, 3}));
+ ASSERT_RAISES(KeyError, mapper.GetFieldId({}));
+ ASSERT_RAISES(KeyError, mapper.GetFieldId({0}));
+ ASSERT_RAISES(KeyError, mapper.GetFieldId({0, 1, 2}));
+ ASSERT_RAISES(KeyError, mapper.GetFieldId({1}));
+
+ ASSERT_OK(mapper.AddField(41, {}));
+ ASSERT_EQ(mapper.num_fields(), 4);
+ ASSERT_OK_AND_EQ(41, mapper.GetFieldId({}));
+ ASSERT_OK_AND_EQ(42, mapper.GetFieldId({0, 1}));
+
+ // Duplicated dictionary ids are allowed
+ ASSERT_OK(mapper.AddField(42, {4, 5, 6}));
+ ASSERT_EQ(mapper.num_fields(), 5);
+ ASSERT_OK_AND_EQ(42, mapper.GetFieldId({0, 1}));
+ ASSERT_OK_AND_EQ(42, mapper.GetFieldId({4, 5, 6}));
+
+ // Duplicated fields paths are not
+ ASSERT_RAISES(KeyError, mapper.AddField(46, {0, 1}));
+}
+
+TEST(DictionaryFieldMapper, FromSchema) {
+ auto f0 = field("f0", int8());
+ auto f1 =
+ field("f1", struct_({field("a", null()), field("b", dictionary(int8(), utf8()))}));
+ auto f2 = field("f2", dictionary(int32(), list(dictionary(int8(), utf8()))));
+
+ Schema schema({f0, f1, f2});
+ DictionaryFieldMapper mapper(schema);
+
+ ASSERT_EQ(mapper.num_fields(), 3);
+ std::unordered_set<int64_t> ids;
+ for (const auto& path : std::vector<std::vector<int>>{{1, 1}, {2}, {2, 0}}) {
+ ASSERT_OK_AND_ASSIGN(const int64_t id, mapper.GetFieldId(path));
+ ids.insert(id);
+ }
+ ASSERT_EQ(ids.size(), 3); // All ids are distinct
+}
+
+static void AssertMemoDictionaryType(const DictionaryMemo& memo, int64_t id,
+ const std::shared_ptr<DataType>& expected) {
+ ASSERT_OK_AND_ASSIGN(const auto actual, memo.GetDictionaryType(id));
+ AssertTypeEqual(*expected, *actual);
+}
+
+TEST(DictionaryMemo, AddDictionaryType) {
+ DictionaryMemo memo;
+ std::shared_ptr<DataType> type;
+
+ ASSERT_RAISES(KeyError, memo.GetDictionaryType(42));
+
+ ASSERT_OK(memo.AddDictionaryType(42, utf8()));
+ ASSERT_OK(memo.AddDictionaryType(43, large_binary()));
+ AssertMemoDictionaryType(memo, 42, utf8());
+ AssertMemoDictionaryType(memo, 43, large_binary());
+
+ // Re-adding same type with different id
+ ASSERT_OK(memo.AddDictionaryType(44, utf8()));
+ AssertMemoDictionaryType(memo, 42, utf8());
+ AssertMemoDictionaryType(memo, 44, utf8());
+
+ // Re-adding same type with same id
+ ASSERT_OK(memo.AddDictionaryType(42, utf8()));
+ AssertMemoDictionaryType(memo, 42, utf8());
+ AssertMemoDictionaryType(memo, 44, utf8());
+
+ // Trying to add different type with same id
+ ASSERT_RAISES(KeyError, memo.AddDictionaryType(42, large_utf8()));
+ AssertMemoDictionaryType(memo, 42, utf8());
+ AssertMemoDictionaryType(memo, 43, large_binary());
+ AssertMemoDictionaryType(memo, 44, utf8());
+}
+
+} // namespace test
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/reader.cc b/src/arrow/cpp/src/arrow/ipc/reader.cc
new file mode 100644
index 000000000..a98f844c7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/reader.cc
@@ -0,0 +1,2095 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/ipc/reader.h"
+
+#include <algorithm>
+#include <climits>
+#include <cstdint>
+#include <cstring>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include <flatbuffers/flatbuffers.h> // IWYU pragma: export
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/extension_type.h"
+#include "arrow/io/caching.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/message.h"
+#include "arrow/ipc/metadata_internal.h"
+#include "arrow/ipc/util.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/sparse_tensor.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/compression.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/parallel.h"
+#include "arrow/util/string.h"
+#include "arrow/util/thread_pool.h"
+#include "arrow/util/ubsan.h"
+#include "arrow/util/vector.h"
+#include "arrow/visitor_inline.h"
+
+#include "generated/File_generated.h" // IWYU pragma: export
+#include "generated/Message_generated.h"
+#include "generated/Schema_generated.h"
+#include "generated/SparseTensor_generated.h"
+
+namespace arrow {
+
+namespace flatbuf = org::apache::arrow::flatbuf;
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+using internal::GetByteWidth;
+
+namespace ipc {
+
+using internal::FileBlock;
+using internal::kArrowMagicBytes;
+
+namespace {
+
+enum class DictionaryKind { New, Delta, Replacement };
+
+Status InvalidMessageType(MessageType expected, MessageType actual) {
+ return Status::IOError("Expected IPC message of type ", FormatMessageType(expected),
+ " but got ", FormatMessageType(actual));
+}
+
+#define CHECK_MESSAGE_TYPE(expected, actual) \
+ do { \
+ if ((actual) != (expected)) { \
+ return InvalidMessageType((expected), (actual)); \
+ } \
+ } while (0)
+
+#define CHECK_HAS_BODY(message) \
+ do { \
+ if ((message).body() == nullptr) { \
+ return Status::IOError("Expected body in IPC message of type ", \
+ FormatMessageType((message).type())); \
+ } \
+ } while (0)
+
+#define CHECK_HAS_NO_BODY(message) \
+ do { \
+ if ((message).body_length() != 0) { \
+ return Status::IOError("Unexpected body in IPC message of type ", \
+ FormatMessageType((message).type())); \
+ } \
+ } while (0)
+
+} // namespace
+
+// ----------------------------------------------------------------------
+// Record batch read path
+
+/// \brief Structure to keep common arguments to be passed
+struct IpcReadContext {
+ IpcReadContext(DictionaryMemo* memo, const IpcReadOptions& option, bool swap,
+ MetadataVersion version = MetadataVersion::V5,
+ Compression::type kind = Compression::UNCOMPRESSED)
+ : dictionary_memo(memo),
+ options(option),
+ metadata_version(version),
+ compression(kind),
+ swap_endian(swap) {}
+
+ DictionaryMemo* dictionary_memo;
+
+ const IpcReadOptions& options;
+
+ MetadataVersion metadata_version;
+
+ Compression::type compression;
+
+ /// \brief LoadRecordBatch() or LoadRecordBatchSubset() swaps endianness of elements
+ /// if this flag is true
+ const bool swap_endian;
+};
+
+/// The field_index and buffer_index are incremented based on how much of the
+/// batch is "consumed" (through nested data reconstruction, for example)
+class ArrayLoader {
+ public:
+ explicit ArrayLoader(const flatbuf::RecordBatch* metadata,
+ MetadataVersion metadata_version, const IpcReadOptions& options,
+ io::RandomAccessFile* file)
+ : metadata_(metadata),
+ metadata_version_(metadata_version),
+ file_(file),
+ max_recursion_depth_(options.max_recursion_depth) {}
+
+ Status ReadBuffer(int64_t offset, int64_t length, std::shared_ptr<Buffer>* out) {
+ if (skip_io_) {
+ return Status::OK();
+ }
+ if (offset < 0) {
+ return Status::Invalid("Negative offset for reading buffer ", buffer_index_);
+ }
+ if (length < 0) {
+ return Status::Invalid("Negative length for reading buffer ", buffer_index_);
+ }
+ // This construct permits overriding GetBuffer at compile time
+ if (!BitUtil::IsMultipleOf8(offset)) {
+ return Status::Invalid("Buffer ", buffer_index_,
+ " did not start on 8-byte aligned offset: ", offset);
+ }
+ return file_->ReadAt(offset, length).Value(out);
+ }
+
+ Status LoadType(const DataType& type) { return VisitTypeInline(type, this); }
+
+ Status Load(const Field* field, ArrayData* out) {
+ if (max_recursion_depth_ <= 0) {
+ return Status::Invalid("Max recursion depth reached");
+ }
+
+ field_ = field;
+ out_ = out;
+ out_->type = field_->type();
+ return LoadType(*field_->type());
+ }
+
+ Status SkipField(const Field* field) {
+ ArrayData dummy;
+ skip_io_ = true;
+ Status status = Load(field, &dummy);
+ skip_io_ = false;
+ return status;
+ }
+
+ Status GetBuffer(int buffer_index, std::shared_ptr<Buffer>* out) {
+ auto buffers = metadata_->buffers();
+ CHECK_FLATBUFFERS_NOT_NULL(buffers, "RecordBatch.buffers");
+ if (buffer_index >= static_cast<int>(buffers->size())) {
+ return Status::IOError("buffer_index out of range.");
+ }
+ const flatbuf::Buffer* buffer = buffers->Get(buffer_index);
+ if (buffer->length() == 0) {
+ // Should never return a null buffer here.
+ // (zero-sized buffer allocations are cheap)
+ return AllocateBuffer(0).Value(out);
+ } else {
+ return ReadBuffer(buffer->offset(), buffer->length(), out);
+ }
+ }
+
+ Status GetFieldMetadata(int field_index, ArrayData* out) {
+ auto nodes = metadata_->nodes();
+ CHECK_FLATBUFFERS_NOT_NULL(nodes, "Table.nodes");
+ // pop off a field
+ if (field_index >= static_cast<int>(nodes->size())) {
+ return Status::Invalid("Ran out of field metadata, likely malformed");
+ }
+ const flatbuf::FieldNode* node = nodes->Get(field_index);
+
+ out->length = node->length();
+ out->null_count = node->null_count();
+ out->offset = 0;
+ return Status::OK();
+ }
+
+ Status LoadCommon(Type::type type_id) {
+ // This only contains the length and null count, which we need to figure
+ // out what to do with the buffers. For example, if null_count == 0, then
+ // we can skip that buffer without reading from shared memory
+ RETURN_NOT_OK(GetFieldMetadata(field_index_++, out_));
+
+ if (internal::HasValidityBitmap(type_id, metadata_version_)) {
+ // Extract null_bitmap which is common to all arrays except for unions
+ // and nulls.
+ if (out_->null_count != 0) {
+ RETURN_NOT_OK(GetBuffer(buffer_index_, &out_->buffers[0]));
+ }
+ buffer_index_++;
+ }
+ return Status::OK();
+ }
+
+ template <typename TYPE>
+ Status LoadPrimitive(Type::type type_id) {
+ out_->buffers.resize(2);
+
+ RETURN_NOT_OK(LoadCommon(type_id));
+ if (out_->length > 0) {
+ RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));
+ } else {
+ buffer_index_++;
+ out_->buffers[1].reset(new Buffer(nullptr, 0));
+ }
+ return Status::OK();
+ }
+
+ template <typename TYPE>
+ Status LoadBinary(Type::type type_id) {
+ out_->buffers.resize(3);
+
+ RETURN_NOT_OK(LoadCommon(type_id));
+ RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));
+ return GetBuffer(buffer_index_++, &out_->buffers[2]);
+ }
+
+ template <typename TYPE>
+ Status LoadList(const TYPE& type) {
+ out_->buffers.resize(2);
+
+ RETURN_NOT_OK(LoadCommon(type.id()));
+ RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1]));
+
+ const int num_children = type.num_fields();
+ if (num_children != 1) {
+ return Status::Invalid("Wrong number of children: ", num_children);
+ }
+
+ return LoadChildren(type.fields());
+ }
+
+ Status LoadChildren(const std::vector<std::shared_ptr<Field>>& child_fields) {
+ ArrayData* parent = out_;
+
+ parent->child_data.resize(child_fields.size());
+ for (int i = 0; i < static_cast<int>(child_fields.size()); ++i) {
+ parent->child_data[i] = std::make_shared<ArrayData>();
+ --max_recursion_depth_;
+ RETURN_NOT_OK(Load(child_fields[i].get(), parent->child_data[i].get()));
+ ++max_recursion_depth_;
+ }
+ out_ = parent;
+ return Status::OK();
+ }
+
+ Status Visit(const NullType& type) {
+ out_->buffers.resize(1);
+
+ // ARROW-6379: NullType has no buffers in the IPC payload
+ return GetFieldMetadata(field_index_++, out_);
+ }
+
+ template <typename T>
+ enable_if_t<std::is_base_of<FixedWidthType, T>::value &&
+ !std::is_base_of<FixedSizeBinaryType, T>::value &&
+ !std::is_base_of<DictionaryType, T>::value,
+ Status>
+ Visit(const T& type) {
+ return LoadPrimitive<T>(type.id());
+ }
+
+ template <typename T>
+ enable_if_base_binary<T, Status> Visit(const T& type) {
+ return LoadBinary<T>(type.id());
+ }
+
+ Status Visit(const FixedSizeBinaryType& type) {
+ out_->buffers.resize(2);
+ RETURN_NOT_OK(LoadCommon(type.id()));
+ return GetBuffer(buffer_index_++, &out_->buffers[1]);
+ }
+
+ template <typename T>
+ enable_if_var_size_list<T, Status> Visit(const T& type) {
+ return LoadList(type);
+ }
+
+ Status Visit(const MapType& type) {
+ RETURN_NOT_OK(LoadList(type));
+ return MapArray::ValidateChildData(out_->child_data);
+ }
+
+ Status Visit(const FixedSizeListType& type) {
+ out_->buffers.resize(1);
+
+ RETURN_NOT_OK(LoadCommon(type.id()));
+
+ const int num_children = type.num_fields();
+ if (num_children != 1) {
+ return Status::Invalid("Wrong number of children: ", num_children);
+ }
+
+ return LoadChildren(type.fields());
+ }
+
+ Status Visit(const StructType& type) {
+ out_->buffers.resize(1);
+ RETURN_NOT_OK(LoadCommon(type.id()));
+ return LoadChildren(type.fields());
+ }
+
+ Status Visit(const UnionType& type) {
+ int n_buffers = type.mode() == UnionMode::SPARSE ? 2 : 3;
+ out_->buffers.resize(n_buffers);
+
+ RETURN_NOT_OK(LoadCommon(type.id()));
+
+ // With metadata V4, we can get a validity bitmap.
+ // Trying to fix up union data to do without the top-level validity bitmap
+ // is hairy:
+ // - type ids must be rewritten to all have valid values (even for former
+ // null slots)
+ // - sparse union children must have their validity bitmaps rewritten
+ // by ANDing the top-level validity bitmap
+ // - dense union children must be rewritten (at least one of them)
+ // to insert the required null slots that were formerly omitted
+ // So instead we bail out.
+ if (out_->null_count != 0 && out_->buffers[0] != nullptr) {
+ return Status::Invalid(
+ "Cannot read pre-1.0.0 Union array with top-level validity bitmap");
+ }
+ out_->buffers[0] = nullptr;
+ out_->null_count = 0;
+
+ if (out_->length > 0) {
+ RETURN_NOT_OK(GetBuffer(buffer_index_, &out_->buffers[1]));
+ if (type.mode() == UnionMode::DENSE) {
+ RETURN_NOT_OK(GetBuffer(buffer_index_ + 1, &out_->buffers[2]));
+ }
+ }
+ buffer_index_ += n_buffers - 1;
+ return LoadChildren(type.fields());
+ }
+
+ Status Visit(const DictionaryType& type) {
+ // out_->dictionary will be filled later in ResolveDictionaries()
+ return LoadType(*type.index_type());
+ }
+
+ Status Visit(const ExtensionType& type) { return LoadType(*type.storage_type()); }
+
+ private:
+ const flatbuf::RecordBatch* metadata_;
+ const MetadataVersion metadata_version_;
+ io::RandomAccessFile* file_;
+ int max_recursion_depth_;
+ int buffer_index_ = 0;
+ int field_index_ = 0;
+ bool skip_io_ = false;
+
+ const Field* field_;
+ ArrayData* out_;
+};
+
+Result<std::shared_ptr<Buffer>> DecompressBuffer(const std::shared_ptr<Buffer>& buf,
+ const IpcReadOptions& options,
+ util::Codec* codec) {
+ if (buf == nullptr || buf->size() == 0) {
+ return buf;
+ }
+
+ if (buf->size() < 8) {
+ return Status::Invalid(
+ "Likely corrupted message, compressed buffers "
+ "are larger than 8 bytes by construction");
+ }
+
+ const uint8_t* data = buf->data();
+ int64_t compressed_size = buf->size() - sizeof(int64_t);
+ int64_t uncompressed_size = BitUtil::FromLittleEndian(util::SafeLoadAs<int64_t>(data));
+
+ ARROW_ASSIGN_OR_RAISE(auto uncompressed,
+ AllocateBuffer(uncompressed_size, options.memory_pool));
+
+ ARROW_ASSIGN_OR_RAISE(
+ int64_t actual_decompressed,
+ codec->Decompress(compressed_size, data + sizeof(int64_t), uncompressed_size,
+ uncompressed->mutable_data()));
+ if (actual_decompressed != uncompressed_size) {
+ return Status::Invalid("Failed to fully decompress buffer, expected ",
+ uncompressed_size, " bytes but decompressed ",
+ actual_decompressed);
+ }
+
+ return std::move(uncompressed);
+}
+
+Status DecompressBuffers(Compression::type compression, const IpcReadOptions& options,
+ ArrayDataVector* fields) {
+ struct BufferAccumulator {
+ using BufferPtrVector = std::vector<std::shared_ptr<Buffer>*>;
+
+ void AppendFrom(const ArrayDataVector& fields) {
+ for (const auto& field : fields) {
+ for (auto& buffer : field->buffers) {
+ buffers_.push_back(&buffer);
+ }
+ AppendFrom(field->child_data);
+ }
+ }
+
+ BufferPtrVector Get(const ArrayDataVector& fields) && {
+ AppendFrom(fields);
+ return std::move(buffers_);
+ }
+
+ BufferPtrVector buffers_;
+ };
+
+ // Flatten all buffers
+ auto buffers = BufferAccumulator{}.Get(*fields);
+
+ std::unique_ptr<util::Codec> codec;
+ ARROW_ASSIGN_OR_RAISE(codec, util::Codec::Create(compression));
+
+ return ::arrow::internal::OptionalParallelFor(
+ options.use_threads, static_cast<int>(buffers.size()), [&](int i) {
+ ARROW_ASSIGN_OR_RAISE(*buffers[i],
+ DecompressBuffer(*buffers[i], options, codec.get()));
+ return Status::OK();
+ });
+}
+
+Result<std::shared_ptr<RecordBatch>> LoadRecordBatchSubset(
+ const flatbuf::RecordBatch* metadata, const std::shared_ptr<Schema>& schema,
+ const std::vector<bool>* inclusion_mask, const IpcReadContext& context,
+ io::RandomAccessFile* file) {
+ ArrayLoader loader(metadata, context.metadata_version, context.options, file);
+
+ ArrayDataVector columns(schema->num_fields());
+ ArrayDataVector filtered_columns;
+ FieldVector filtered_fields;
+ std::shared_ptr<Schema> filtered_schema;
+
+ for (int i = 0; i < schema->num_fields(); ++i) {
+ const Field& field = *schema->field(i);
+ if (!inclusion_mask || (*inclusion_mask)[i]) {
+ // Read field
+ auto column = std::make_shared<ArrayData>();
+ RETURN_NOT_OK(loader.Load(&field, column.get()));
+ if (metadata->length() != column->length) {
+ return Status::IOError("Array length did not match record batch length");
+ }
+ columns[i] = std::move(column);
+ if (inclusion_mask) {
+ filtered_columns.push_back(columns[i]);
+ filtered_fields.push_back(schema->field(i));
+ }
+ } else {
+ // Skip field. This logic must be executed to advance the state of the
+ // loader to the next field
+ RETURN_NOT_OK(loader.SkipField(&field));
+ }
+ }
+
+ // Dictionary resolution needs to happen on the unfiltered columns,
+ // because fields are mapped structurally (by path in the original schema).
+ RETURN_NOT_OK(ResolveDictionaries(columns, *context.dictionary_memo,
+ context.options.memory_pool));
+
+ if (inclusion_mask) {
+ filtered_schema = ::arrow::schema(std::move(filtered_fields), schema->metadata());
+ columns.clear();
+ } else {
+ filtered_schema = schema;
+ filtered_columns = std::move(columns);
+ }
+ if (context.compression != Compression::UNCOMPRESSED) {
+ RETURN_NOT_OK(
+ DecompressBuffers(context.compression, context.options, &filtered_columns));
+ }
+
+ // swap endian in a set of ArrayData if necessary (swap_endian == true)
+ if (context.swap_endian) {
+ for (int i = 0; i < static_cast<int>(filtered_columns.size()); ++i) {
+ ARROW_ASSIGN_OR_RAISE(filtered_columns[i],
+ arrow::internal::SwapEndianArrayData(filtered_columns[i]));
+ }
+ }
+ return RecordBatch::Make(std::move(filtered_schema), metadata->length(),
+ std::move(filtered_columns));
+}
+
+Result<std::shared_ptr<RecordBatch>> LoadRecordBatch(
+ const flatbuf::RecordBatch* metadata, const std::shared_ptr<Schema>& schema,
+ const std::vector<bool>& inclusion_mask, const IpcReadContext& context,
+ io::RandomAccessFile* file) {
+ if (inclusion_mask.size() > 0) {
+ return LoadRecordBatchSubset(metadata, schema, &inclusion_mask, context, file);
+ } else {
+ return LoadRecordBatchSubset(metadata, schema, /*param_name=*/nullptr, context, file);
+ }
+}
+
+// ----------------------------------------------------------------------
+// Array loading
+
+Status GetCompression(const flatbuf::RecordBatch* batch, Compression::type* out) {
+ *out = Compression::UNCOMPRESSED;
+ const flatbuf::BodyCompression* compression = batch->compression();
+ if (compression != nullptr) {
+ if (compression->method() != flatbuf::BodyCompressionMethod::BUFFER) {
+ // Forward compatibility
+ return Status::Invalid("This library only supports BUFFER compression method");
+ }
+
+ if (compression->codec() == flatbuf::CompressionType::LZ4_FRAME) {
+ *out = Compression::LZ4_FRAME;
+ } else if (compression->codec() == flatbuf::CompressionType::ZSTD) {
+ *out = Compression::ZSTD;
+ } else {
+ return Status::Invalid("Unsupported codec in RecordBatch::compression metadata");
+ }
+ return Status::OK();
+ }
+ return Status::OK();
+}
+
+Status GetCompressionExperimental(const flatbuf::Message* message,
+ Compression::type* out) {
+ *out = Compression::UNCOMPRESSED;
+ if (message->custom_metadata() != nullptr) {
+ // TODO: Ensure this deserialization only ever happens once
+ std::shared_ptr<KeyValueMetadata> metadata;
+ RETURN_NOT_OK(internal::GetKeyValueMetadata(message->custom_metadata(), &metadata));
+ int index = metadata->FindKey("ARROW:experimental_compression");
+ if (index != -1) {
+ // Arrow 0.17 stored string in upper case, internal utils now require lower case
+ auto name = arrow::internal::AsciiToLower(metadata->value(index));
+ ARROW_ASSIGN_OR_RAISE(*out, util::Codec::GetCompressionType(name));
+ }
+ return internal::CheckCompressionSupported(*out);
+ }
+ return Status::OK();
+}
+
+static Status ReadContiguousPayload(io::InputStream* file,
+ std::unique_ptr<Message>* message) {
+ ARROW_ASSIGN_OR_RAISE(*message, ReadMessage(file));
+ if (*message == nullptr) {
+ return Status::Invalid("Unable to read metadata at offset");
+ }
+ return Status::OK();
+}
+
+Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
+ const std::shared_ptr<Schema>& schema, const DictionaryMemo* dictionary_memo,
+ const IpcReadOptions& options, io::InputStream* file) {
+ std::unique_ptr<Message> message;
+ RETURN_NOT_OK(ReadContiguousPayload(file, &message));
+ CHECK_HAS_BODY(*message);
+ ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
+ return ReadRecordBatch(*message->metadata(), schema, dictionary_memo, options,
+ reader.get());
+}
+
+Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
+ const Message& message, const std::shared_ptr<Schema>& schema,
+ const DictionaryMemo* dictionary_memo, const IpcReadOptions& options) {
+ CHECK_MESSAGE_TYPE(MessageType::RECORD_BATCH, message.type());
+ CHECK_HAS_BODY(message);
+ ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
+ return ReadRecordBatch(*message.metadata(), schema, dictionary_memo, options,
+ reader.get());
+}
+
+Result<std::shared_ptr<RecordBatch>> ReadRecordBatchInternal(
+ const Buffer& metadata, const std::shared_ptr<Schema>& schema,
+ const std::vector<bool>& inclusion_mask, IpcReadContext& context,
+ io::RandomAccessFile* file) {
+ const flatbuf::Message* message = nullptr;
+ RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
+ auto batch = message->header_as_RecordBatch();
+ if (batch == nullptr) {
+ return Status::IOError(
+ "Header-type of flatbuffer-encoded Message is not RecordBatch.");
+ }
+
+ Compression::type compression;
+ RETURN_NOT_OK(GetCompression(batch, &compression));
+ if (context.compression == Compression::UNCOMPRESSED &&
+ message->version() == flatbuf::MetadataVersion::V4) {
+ // Possibly obtain codec information from experimental serialization format
+ // in 0.17.x
+ RETURN_NOT_OK(GetCompressionExperimental(message, &compression));
+ }
+ context.compression = compression;
+ context.metadata_version = internal::GetMetadataVersion(message->version());
+ return LoadRecordBatch(batch, schema, inclusion_mask, context, file);
+}
+
+// If we are selecting only certain fields, populate an inclusion mask for fast lookups.
+// Additionally, drop deselected fields from the reader's schema.
+Status GetInclusionMaskAndOutSchema(const std::shared_ptr<Schema>& full_schema,
+ const std::vector<int>& included_indices,
+ std::vector<bool>* inclusion_mask,
+ std::shared_ptr<Schema>* out_schema) {
+ inclusion_mask->clear();
+ if (included_indices.empty()) {
+ *out_schema = full_schema;
+ return Status::OK();
+ }
+
+ inclusion_mask->resize(full_schema->num_fields(), false);
+
+ auto included_indices_sorted = included_indices;
+ std::sort(included_indices_sorted.begin(), included_indices_sorted.end());
+
+ FieldVector included_fields;
+ for (int i : included_indices_sorted) {
+ // Ignore out of bounds indices
+ if (i < 0 || i >= full_schema->num_fields()) {
+ return Status::Invalid("Out of bounds field index: ", i);
+ }
+
+ if (inclusion_mask->at(i)) continue;
+
+ inclusion_mask->at(i) = true;
+ included_fields.push_back(full_schema->field(i));
+ }
+
+ *out_schema = schema(std::move(included_fields), full_schema->endianness(),
+ full_schema->metadata());
+ return Status::OK();
+}
+
+Status UnpackSchemaMessage(const void* opaque_schema, const IpcReadOptions& options,
+ DictionaryMemo* dictionary_memo,
+ std::shared_ptr<Schema>* schema,
+ std::shared_ptr<Schema>* out_schema,
+ std::vector<bool>* field_inclusion_mask, bool* swap_endian) {
+ RETURN_NOT_OK(internal::GetSchema(opaque_schema, dictionary_memo, schema));
+
+ // If we are selecting only certain fields, populate the inclusion mask now
+ // for fast lookups
+ RETURN_NOT_OK(GetInclusionMaskAndOutSchema(*schema, options.included_fields,
+ field_inclusion_mask, out_schema));
+ *swap_endian = options.ensure_native_endian && !out_schema->get()->is_native_endian();
+ if (*swap_endian) {
+ // create a new schema with native endianness before swapping endian in ArrayData
+ *schema = schema->get()->WithEndianness(Endianness::Native);
+ *out_schema = out_schema->get()->WithEndianness(Endianness::Native);
+ }
+ return Status::OK();
+}
+
+Status UnpackSchemaMessage(const Message& message, const IpcReadOptions& options,
+ DictionaryMemo* dictionary_memo,
+ std::shared_ptr<Schema>* schema,
+ std::shared_ptr<Schema>* out_schema,
+ std::vector<bool>* field_inclusion_mask, bool* swap_endian) {
+ CHECK_MESSAGE_TYPE(MessageType::SCHEMA, message.type());
+ CHECK_HAS_NO_BODY(message);
+
+ return UnpackSchemaMessage(message.header(), options, dictionary_memo, schema,
+ out_schema, field_inclusion_mask, swap_endian);
+}
+
+Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
+ const Buffer& metadata, const std::shared_ptr<Schema>& schema,
+ const DictionaryMemo* dictionary_memo, const IpcReadOptions& options,
+ io::RandomAccessFile* file) {
+ std::shared_ptr<Schema> out_schema;
+ // Empty means do not use
+ std::vector<bool> inclusion_mask;
+ IpcReadContext context(const_cast<DictionaryMemo*>(dictionary_memo), options, false);
+ RETURN_NOT_OK(GetInclusionMaskAndOutSchema(schema, context.options.included_fields,
+ &inclusion_mask, &out_schema));
+ return ReadRecordBatchInternal(metadata, schema, inclusion_mask, context, file);
+}
+
+Status ReadDictionary(const Buffer& metadata, const IpcReadContext& context,
+ DictionaryKind* kind, io::RandomAccessFile* file) {
+ const flatbuf::Message* message = nullptr;
+ RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
+ const auto dictionary_batch = message->header_as_DictionaryBatch();
+ if (dictionary_batch == nullptr) {
+ return Status::IOError(
+ "Header-type of flatbuffer-encoded Message is not DictionaryBatch.");
+ }
+
+ // The dictionary is embedded in a record batch with a single column
+ const auto batch_meta = dictionary_batch->data();
+
+ CHECK_FLATBUFFERS_NOT_NULL(batch_meta, "DictionaryBatch.data");
+
+ Compression::type compression;
+ RETURN_NOT_OK(GetCompression(batch_meta, &compression));
+ if (compression == Compression::UNCOMPRESSED &&
+ message->version() == flatbuf::MetadataVersion::V4) {
+ // Possibly obtain codec information from experimental serialization format
+ // in 0.17.x
+ RETURN_NOT_OK(GetCompressionExperimental(message, &compression));
+ }
+
+ const int64_t id = dictionary_batch->id();
+
+ // Look up the dictionary value type, which must have been added to the
+ // DictionaryMemo already prior to invoking this function
+ ARROW_ASSIGN_OR_RAISE(auto value_type, context.dictionary_memo->GetDictionaryType(id));
+
+ // Load the dictionary data from the dictionary batch
+ ArrayLoader loader(batch_meta, internal::GetMetadataVersion(message->version()),
+ context.options, file);
+ auto dict_data = std::make_shared<ArrayData>();
+ const Field dummy_field("", value_type);
+ RETURN_NOT_OK(loader.Load(&dummy_field, dict_data.get()));
+
+ if (compression != Compression::UNCOMPRESSED) {
+ ArrayDataVector dict_fields{dict_data};
+ RETURN_NOT_OK(DecompressBuffers(compression, context.options, &dict_fields));
+ }
+
+ // swap endian in dict_data if necessary (swap_endian == true)
+ if (context.swap_endian) {
+ ARROW_ASSIGN_OR_RAISE(dict_data, ::arrow::internal::SwapEndianArrayData(dict_data));
+ }
+
+ if (dictionary_batch->isDelta()) {
+ if (kind != nullptr) {
+ *kind = DictionaryKind::Delta;
+ }
+ return context.dictionary_memo->AddDictionaryDelta(id, dict_data);
+ }
+ ARROW_ASSIGN_OR_RAISE(bool inserted,
+ context.dictionary_memo->AddOrReplaceDictionary(id, dict_data));
+ if (kind != nullptr) {
+ *kind = inserted ? DictionaryKind::New : DictionaryKind::Replacement;
+ }
+ return Status::OK();
+}
+
+Status ReadDictionary(const Message& message, const IpcReadContext& context,
+ DictionaryKind* kind) {
+ // Only invoke this method if we already know we have a dictionary message
+ DCHECK_EQ(message.type(), MessageType::DICTIONARY_BATCH);
+ CHECK_HAS_BODY(message);
+ ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
+ return ReadDictionary(*message.metadata(), context, kind, reader.get());
+}
+
+// ----------------------------------------------------------------------
+// RecordBatchStreamReader implementation
+
+class RecordBatchStreamReaderImpl : public RecordBatchStreamReader {
+ public:
+ Status Open(std::unique_ptr<MessageReader> message_reader,
+ const IpcReadOptions& options) {
+ message_reader_ = std::move(message_reader);
+ options_ = options;
+
+ // Read schema
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message, ReadNextMessage());
+ if (!message) {
+ return Status::Invalid("Tried reading schema message, was null or length 0");
+ }
+
+ RETURN_NOT_OK(UnpackSchemaMessage(*message, options, &dictionary_memo_, &schema_,
+ &out_schema_, &field_inclusion_mask_,
+ &swap_endian_));
+ return Status::OK();
+ }
+
+ Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
+ if (!have_read_initial_dictionaries_) {
+ RETURN_NOT_OK(ReadInitialDictionaries());
+ }
+
+ if (empty_stream_) {
+ // ARROW-6006: Degenerate case where stream contains no data, we do not
+ // bother trying to read a RecordBatch message from the stream
+ *batch = nullptr;
+ return Status::OK();
+ }
+
+ // Continue to read other dictionaries, if any
+ std::unique_ptr<Message> message;
+ ARROW_ASSIGN_OR_RAISE(message, ReadNextMessage());
+
+ while (message != nullptr && message->type() == MessageType::DICTIONARY_BATCH) {
+ RETURN_NOT_OK(ReadDictionary(*message));
+ ARROW_ASSIGN_OR_RAISE(message, ReadNextMessage());
+ }
+
+ if (message == nullptr) {
+ // End of stream
+ *batch = nullptr;
+ return Status::OK();
+ }
+
+ CHECK_HAS_BODY(*message);
+ ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
+ IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
+ return ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_,
+ context, reader.get())
+ .Value(batch);
+ }
+
+ std::shared_ptr<Schema> schema() const override { return out_schema_; }
+
+ ReadStats stats() const override { return stats_; }
+
+ private:
+ Result<std::unique_ptr<Message>> ReadNextMessage() {
+ ARROW_ASSIGN_OR_RAISE(auto message, message_reader_->ReadNextMessage());
+ if (message) {
+ ++stats_.num_messages;
+ switch (message->type()) {
+ case MessageType::RECORD_BATCH:
+ ++stats_.num_record_batches;
+ break;
+ case MessageType::DICTIONARY_BATCH:
+ ++stats_.num_dictionary_batches;
+ break;
+ default:
+ break;
+ }
+ }
+ return std::move(message);
+ }
+
+ // Read dictionary from dictionary batch
+ Status ReadDictionary(const Message& message) {
+ DictionaryKind kind;
+ IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
+ RETURN_NOT_OK(::arrow::ipc::ReadDictionary(message, context, &kind));
+ switch (kind) {
+ case DictionaryKind::New:
+ break;
+ case DictionaryKind::Delta:
+ ++stats_.num_dictionary_deltas;
+ break;
+ case DictionaryKind::Replacement:
+ ++stats_.num_replaced_dictionaries;
+ break;
+ }
+ return Status::OK();
+ }
+
+ Status ReadInitialDictionaries() {
+ // We must receive all dictionaries before reconstructing the
+ // first record batch. Subsequent dictionary deltas modify the memo
+ std::unique_ptr<Message> message;
+
+ // TODO(wesm): In future, we may want to reconcile the ids in the stream with
+ // those found in the schema
+ const auto num_dicts = dictionary_memo_.fields().num_dicts();
+ for (int i = 0; i < num_dicts; ++i) {
+ ARROW_ASSIGN_OR_RAISE(message, ReadNextMessage());
+ if (!message) {
+ if (i == 0) {
+ /// ARROW-6006: If we fail to find any dictionaries in the stream, then
+ /// it may be that the stream has a schema but no actual data. In such
+ /// case we communicate that we were unable to find the dictionaries
+ /// (but there was no failure otherwise), so the caller can decide what
+ /// to do
+ empty_stream_ = true;
+ break;
+ } else {
+ // ARROW-6126, the stream terminated before receiving the expected
+ // number of dictionaries
+ return Status::Invalid("IPC stream ended without reading the expected number (",
+ num_dicts, ") of dictionaries");
+ }
+ }
+
+ if (message->type() != MessageType::DICTIONARY_BATCH) {
+ return Status::Invalid("IPC stream did not have the expected number (", num_dicts,
+ ") of dictionaries at the start of the stream");
+ }
+ RETURN_NOT_OK(ReadDictionary(*message));
+ }
+
+ have_read_initial_dictionaries_ = true;
+ return Status::OK();
+ }
+
+ std::unique_ptr<MessageReader> message_reader_;
+ IpcReadOptions options_;
+ std::vector<bool> field_inclusion_mask_;
+
+ bool have_read_initial_dictionaries_ = false;
+
+ // Flag to set in case where we fail to observe all dictionaries in a stream,
+ // and so the reader should not attempt to parse any messages
+ bool empty_stream_ = false;
+
+ ReadStats stats_;
+
+ DictionaryMemo dictionary_memo_;
+ std::shared_ptr<Schema> schema_, out_schema_;
+
+ bool swap_endian_;
+};
+
+// ----------------------------------------------------------------------
+// Stream reader constructors
+
+Result<std::shared_ptr<RecordBatchStreamReader>> RecordBatchStreamReader::Open(
+ std::unique_ptr<MessageReader> message_reader, const IpcReadOptions& options) {
+ // Private ctor
+ auto result = std::make_shared<RecordBatchStreamReaderImpl>();
+ RETURN_NOT_OK(result->Open(std::move(message_reader), options));
+ return result;
+}
+
+Result<std::shared_ptr<RecordBatchStreamReader>> RecordBatchStreamReader::Open(
+ io::InputStream* stream, const IpcReadOptions& options) {
+ return Open(MessageReader::Open(stream), options);
+}
+
+Result<std::shared_ptr<RecordBatchStreamReader>> RecordBatchStreamReader::Open(
+ const std::shared_ptr<io::InputStream>& stream, const IpcReadOptions& options) {
+ return Open(MessageReader::Open(stream), options);
+}
+
+// ----------------------------------------------------------------------
+// Reader implementation
+
+// Common functions used in both the random-access file reader and the
+// asynchronous generator
+static inline FileBlock FileBlockFromFlatbuffer(const flatbuf::Block* block) {
+ return FileBlock{block->offset(), block->metaDataLength(), block->bodyLength()};
+}
+
+static Result<std::unique_ptr<Message>> ReadMessageFromBlock(const FileBlock& block,
+ io::RandomAccessFile* file) {
+ if (!BitUtil::IsMultipleOf8(block.offset) ||
+ !BitUtil::IsMultipleOf8(block.metadata_length) ||
+ !BitUtil::IsMultipleOf8(block.body_length)) {
+ return Status::Invalid("Unaligned block in IPC file");
+ }
+
+ // TODO(wesm): this breaks integration tests, see ARROW-3256
+ // DCHECK_EQ((*out)->body_length(), block.body_length);
+
+ ARROW_ASSIGN_OR_RAISE(auto message,
+ ReadMessage(block.offset, block.metadata_length, file));
+ return std::move(message);
+}
+
+static Future<std::shared_ptr<Message>> ReadMessageFromBlockAsync(
+ const FileBlock& block, io::RandomAccessFile* file, const io::IOContext& io_context) {
+ if (!BitUtil::IsMultipleOf8(block.offset) ||
+ !BitUtil::IsMultipleOf8(block.metadata_length) ||
+ !BitUtil::IsMultipleOf8(block.body_length)) {
+ return Status::Invalid("Unaligned block in IPC file");
+ }
+
+ // TODO(wesm): this breaks integration tests, see ARROW-3256
+ // DCHECK_EQ((*out)->body_length(), block.body_length);
+
+ return ReadMessageAsync(block.offset, block.metadata_length, block.body_length, file,
+ io_context);
+}
+
+static Status ReadOneDictionary(Message* message, const IpcReadContext& context) {
+ CHECK_HAS_BODY(*message);
+ ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
+ DictionaryKind kind;
+ RETURN_NOT_OK(ReadDictionary(*message->metadata(), context, &kind, reader.get()));
+ if (kind != DictionaryKind::New) {
+ return Status::Invalid(
+ "Unsupported dictionary replacement or "
+ "dictionary delta in IPC file");
+ }
+ return Status::OK();
+}
+
+class RecordBatchFileReaderImpl;
+
+/// A generator of record batches.
+///
+/// All batches are yielded in order.
+class ARROW_EXPORT IpcFileRecordBatchGenerator {
+ public:
+ using Item = std::shared_ptr<RecordBatch>;
+
+ explicit IpcFileRecordBatchGenerator(
+ std::shared_ptr<RecordBatchFileReaderImpl> state,
+ std::shared_ptr<io::internal::ReadRangeCache> cached_source,
+ const io::IOContext& io_context, arrow::internal::Executor* executor)
+ : state_(std::move(state)),
+ cached_source_(std::move(cached_source)),
+ io_context_(io_context),
+ executor_(executor),
+ index_(0) {}
+
+ Future<Item> operator()();
+ Future<std::shared_ptr<Message>> ReadBlock(const FileBlock& block);
+
+ static Status ReadDictionaries(
+ RecordBatchFileReaderImpl* state,
+ std::vector<std::shared_ptr<Message>> dictionary_messages);
+ static Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
+ RecordBatchFileReaderImpl* state, Message* message);
+
+ private:
+ std::shared_ptr<RecordBatchFileReaderImpl> state_;
+ std::shared_ptr<io::internal::ReadRangeCache> cached_source_;
+ io::IOContext io_context_;
+ arrow::internal::Executor* executor_;
+ int index_;
+ // Odd Future type, but this lets us use All() easily
+ Future<> read_dictionaries_;
+};
+
+class RecordBatchFileReaderImpl : public RecordBatchFileReader {
+ public:
+ RecordBatchFileReaderImpl() : file_(NULLPTR), footer_offset_(0), footer_(NULLPTR) {}
+
+ int num_record_batches() const override {
+ return static_cast<int>(internal::FlatBuffersVectorSize(footer_->recordBatches()));
+ }
+
+ MetadataVersion version() const override {
+ return internal::GetMetadataVersion(footer_->version());
+ }
+
+ Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(int i) override {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, num_record_batches());
+
+ if (!read_dictionaries_) {
+ RETURN_NOT_OK(ReadDictionaries());
+ read_dictionaries_ = true;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto message, ReadMessageFromBlock(GetRecordBatchBlock(i)));
+
+ CHECK_HAS_BODY(*message);
+ ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
+ IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
+ ARROW_ASSIGN_OR_RAISE(auto batch, ReadRecordBatchInternal(
+ *message->metadata(), schema_,
+ field_inclusion_mask_, context, reader.get()));
+ ++stats_.num_record_batches;
+ return batch;
+ }
+
+ Result<int64_t> CountRows() override {
+ int64_t total = 0;
+ for (int i = 0; i < num_record_batches(); i++) {
+ ARROW_ASSIGN_OR_RAISE(auto outer_message,
+ ReadMessageFromBlock(GetRecordBatchBlock(i)));
+ auto metadata = outer_message->metadata();
+ const flatbuf::Message* message = nullptr;
+ RETURN_NOT_OK(
+ internal::VerifyMessage(metadata->data(), metadata->size(), &message));
+ auto batch = message->header_as_RecordBatch();
+ if (batch == nullptr) {
+ return Status::IOError(
+ "Header-type of flatbuffer-encoded Message is not RecordBatch.");
+ }
+ total += batch->length();
+ }
+ return total;
+ }
+
+ Status Open(const std::shared_ptr<io::RandomAccessFile>& file, int64_t footer_offset,
+ const IpcReadOptions& options) {
+ owned_file_ = file;
+ return Open(file.get(), footer_offset, options);
+ }
+
+ Status Open(io::RandomAccessFile* file, int64_t footer_offset,
+ const IpcReadOptions& options) {
+ file_ = file;
+ options_ = options;
+ footer_offset_ = footer_offset;
+ RETURN_NOT_OK(ReadFooter());
+
+ // Get the schema and record any observed dictionaries
+ RETURN_NOT_OK(UnpackSchemaMessage(footer_->schema(), options, &dictionary_memo_,
+ &schema_, &out_schema_, &field_inclusion_mask_,
+ &swap_endian_));
+ ++stats_.num_messages;
+ return Status::OK();
+ }
+
+ Future<> OpenAsync(const std::shared_ptr<io::RandomAccessFile>& file,
+ int64_t footer_offset, const IpcReadOptions& options) {
+ owned_file_ = file;
+ return OpenAsync(file.get(), footer_offset, options);
+ }
+
+ Future<> OpenAsync(io::RandomAccessFile* file, int64_t footer_offset,
+ const IpcReadOptions& options) {
+ file_ = file;
+ options_ = options;
+ footer_offset_ = footer_offset;
+ auto cpu_executor = ::arrow::internal::GetCpuThreadPool();
+ auto self = std::dynamic_pointer_cast<RecordBatchFileReaderImpl>(shared_from_this());
+ return ReadFooterAsync(cpu_executor).Then([self, options]() -> Status {
+ // Get the schema and record any observed dictionaries
+ RETURN_NOT_OK(UnpackSchemaMessage(
+ self->footer_->schema(), options, &self->dictionary_memo_, &self->schema_,
+ &self->out_schema_, &self->field_inclusion_mask_, &self->swap_endian_));
+ ++self->stats_.num_messages;
+ return Status::OK();
+ });
+ }
+
+ std::shared_ptr<Schema> schema() const override { return out_schema_; }
+
+ std::shared_ptr<const KeyValueMetadata> metadata() const override { return metadata_; }
+
+ ReadStats stats() const override { return stats_; }
+
+ Result<AsyncGenerator<std::shared_ptr<RecordBatch>>> GetRecordBatchGenerator(
+ const bool coalesce, const io::IOContext& io_context,
+ const io::CacheOptions cache_options,
+ arrow::internal::Executor* executor) override {
+ auto state = std::dynamic_pointer_cast<RecordBatchFileReaderImpl>(shared_from_this());
+ std::shared_ptr<io::internal::ReadRangeCache> cached_source;
+ if (coalesce) {
+ if (!owned_file_) return Status::Invalid("Cannot coalesce without an owned file");
+ cached_source = std::make_shared<io::internal::ReadRangeCache>(
+ owned_file_, io_context, cache_options);
+ auto num_dictionaries = this->num_dictionaries();
+ auto num_record_batches = this->num_record_batches();
+ std::vector<io::ReadRange> ranges(num_dictionaries + num_record_batches);
+ for (int i = 0; i < num_dictionaries; i++) {
+ auto block = FileBlockFromFlatbuffer(footer_->dictionaries()->Get(i));
+ ranges[i].offset = block.offset;
+ ranges[i].length = block.metadata_length + block.body_length;
+ }
+ for (int i = 0; i < num_record_batches; i++) {
+ auto block = FileBlockFromFlatbuffer(footer_->recordBatches()->Get(i));
+ ranges[num_dictionaries + i].offset = block.offset;
+ ranges[num_dictionaries + i].length = block.metadata_length + block.body_length;
+ }
+ RETURN_NOT_OK(cached_source->Cache(std::move(ranges)));
+ }
+ return IpcFileRecordBatchGenerator(std::move(state), std::move(cached_source),
+ io_context, executor);
+ }
+
+ private:
+ friend AsyncGenerator<std::shared_ptr<Message>> MakeMessageGenerator(
+ std::shared_ptr<RecordBatchFileReaderImpl>, const io::IOContext&);
+ friend class IpcFileRecordBatchGenerator;
+
+ FileBlock GetRecordBatchBlock(int i) const {
+ return FileBlockFromFlatbuffer(footer_->recordBatches()->Get(i));
+ }
+
+ FileBlock GetDictionaryBlock(int i) const {
+ return FileBlockFromFlatbuffer(footer_->dictionaries()->Get(i));
+ }
+
+ Result<std::unique_ptr<Message>> ReadMessageFromBlock(const FileBlock& block) {
+ ARROW_ASSIGN_OR_RAISE(auto message, arrow::ipc::ReadMessageFromBlock(block, file_));
+ ++stats_.num_messages;
+ return std::move(message);
+ }
+
+ Status ReadDictionaries() {
+ // Read all the dictionaries
+ IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
+ for (int i = 0; i < num_dictionaries(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto message, ReadMessageFromBlock(GetDictionaryBlock(i)));
+ RETURN_NOT_OK(ReadOneDictionary(message.get(), context));
+ ++stats_.num_dictionary_batches;
+ }
+ return Status::OK();
+ }
+
+ Status ReadFooter() {
+ auto fut = ReadFooterAsync(/*executor=*/nullptr);
+ return fut.status();
+ }
+
+ Future<> ReadFooterAsync(arrow::internal::Executor* executor) {
+ const int32_t magic_size = static_cast<int>(strlen(kArrowMagicBytes));
+
+ if (footer_offset_ <= magic_size * 2 + 4) {
+ return Status::Invalid("File is too small: ", footer_offset_);
+ }
+
+ int file_end_size = static_cast<int>(magic_size + sizeof(int32_t));
+ auto self = std::dynamic_pointer_cast<RecordBatchFileReaderImpl>(shared_from_this());
+ auto read_magic = file_->ReadAsync(footer_offset_ - file_end_size, file_end_size);
+ if (executor) read_magic = executor->Transfer(std::move(read_magic));
+ return read_magic
+ .Then([=](const std::shared_ptr<Buffer>& buffer)
+ -> Future<std::shared_ptr<Buffer>> {
+ const int64_t expected_footer_size = magic_size + sizeof(int32_t);
+ if (buffer->size() < expected_footer_size) {
+ return Status::Invalid("Unable to read ", expected_footer_size,
+ "from end of file");
+ }
+
+ if (memcmp(buffer->data() + sizeof(int32_t), kArrowMagicBytes, magic_size)) {
+ return Status::Invalid("Not an Arrow file");
+ }
+
+ int32_t footer_length = BitUtil::FromLittleEndian(
+ *reinterpret_cast<const int32_t*>(buffer->data()));
+
+ if (footer_length <= 0 ||
+ footer_length > self->footer_offset_ - magic_size * 2 - 4) {
+ return Status::Invalid("File is smaller than indicated metadata size");
+ }
+
+ // Now read the footer
+ auto read_footer = self->file_->ReadAsync(
+ self->footer_offset_ - footer_length - file_end_size, footer_length);
+ if (executor) read_footer = executor->Transfer(std::move(read_footer));
+ return read_footer;
+ })
+ .Then([=](const std::shared_ptr<Buffer>& buffer) -> Status {
+ self->footer_buffer_ = buffer;
+ const auto data = self->footer_buffer_->data();
+ const auto size = self->footer_buffer_->size();
+ if (!internal::VerifyFlatbuffers<flatbuf::Footer>(data, size)) {
+ return Status::IOError("Verification of flatbuffer-encoded Footer failed.");
+ }
+ self->footer_ = flatbuf::GetFooter(data);
+
+ auto fb_metadata = self->footer_->custom_metadata();
+ if (fb_metadata != nullptr) {
+ std::shared_ptr<KeyValueMetadata> md;
+ RETURN_NOT_OK(internal::GetKeyValueMetadata(fb_metadata, &md));
+ self->metadata_ = std::move(md); // const-ify
+ }
+ return Status::OK();
+ });
+ }
+
+ int num_dictionaries() const {
+ return static_cast<int>(internal::FlatBuffersVectorSize(footer_->dictionaries()));
+ }
+
+ io::RandomAccessFile* file_;
+ IpcReadOptions options_;
+ std::vector<bool> field_inclusion_mask_;
+
+ std::shared_ptr<io::RandomAccessFile> owned_file_;
+
+ // The location where the Arrow file layout ends. May be the end of the file
+ // or some other location if embedded in a larger file.
+ int64_t footer_offset_;
+
+ // Footer metadata
+ std::shared_ptr<Buffer> footer_buffer_;
+ const flatbuf::Footer* footer_;
+ std::shared_ptr<const KeyValueMetadata> metadata_;
+
+ bool read_dictionaries_ = false;
+ DictionaryMemo dictionary_memo_;
+
+ // Reconstructed schema, including any read dictionaries
+ std::shared_ptr<Schema> schema_;
+ // Schema with deselected fields dropped
+ std::shared_ptr<Schema> out_schema_;
+
+ ReadStats stats_;
+
+ bool swap_endian_;
+};
+
+Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
+ io::RandomAccessFile* file, const IpcReadOptions& options) {
+ ARROW_ASSIGN_OR_RAISE(int64_t footer_offset, file->GetSize());
+ return Open(file, footer_offset, options);
+}
+
+Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
+ io::RandomAccessFile* file, int64_t footer_offset, const IpcReadOptions& options) {
+ auto result = std::make_shared<RecordBatchFileReaderImpl>();
+ RETURN_NOT_OK(result->Open(file, footer_offset, options));
+ return result;
+}
+
+Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
+ const std::shared_ptr<io::RandomAccessFile>& file, const IpcReadOptions& options) {
+ ARROW_ASSIGN_OR_RAISE(int64_t footer_offset, file->GetSize());
+ return Open(file, footer_offset, options);
+}
+
+Result<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::Open(
+ const std::shared_ptr<io::RandomAccessFile>& file, int64_t footer_offset,
+ const IpcReadOptions& options) {
+ auto result = std::make_shared<RecordBatchFileReaderImpl>();
+ RETURN_NOT_OK(result->Open(file, footer_offset, options));
+ return result;
+}
+
+Future<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::OpenAsync(
+ const std::shared_ptr<io::RandomAccessFile>& file, const IpcReadOptions& options) {
+ ARROW_ASSIGN_OR_RAISE(int64_t footer_offset, file->GetSize());
+ return OpenAsync(std::move(file), footer_offset, options);
+}
+
+Future<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::OpenAsync(
+ io::RandomAccessFile* file, const IpcReadOptions& options) {
+ ARROW_ASSIGN_OR_RAISE(int64_t footer_offset, file->GetSize());
+ return OpenAsync(file, footer_offset, options);
+}
+
+Future<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::OpenAsync(
+ const std::shared_ptr<io::RandomAccessFile>& file, int64_t footer_offset,
+ const IpcReadOptions& options) {
+ auto result = std::make_shared<RecordBatchFileReaderImpl>();
+ return result->OpenAsync(file, footer_offset, options)
+ .Then([=]() -> Result<std::shared_ptr<RecordBatchFileReader>> { return result; });
+}
+
+Future<std::shared_ptr<RecordBatchFileReader>> RecordBatchFileReader::OpenAsync(
+ io::RandomAccessFile* file, int64_t footer_offset, const IpcReadOptions& options) {
+ auto result = std::make_shared<RecordBatchFileReaderImpl>();
+ return result->OpenAsync(file, footer_offset, options)
+ .Then([=]() -> Result<std::shared_ptr<RecordBatchFileReader>> { return result; });
+}
+
+Future<IpcFileRecordBatchGenerator::Item> IpcFileRecordBatchGenerator::operator()() {
+ auto state = state_;
+ if (!read_dictionaries_.is_valid()) {
+ std::vector<Future<std::shared_ptr<Message>>> messages(state->num_dictionaries());
+ for (int i = 0; i < state->num_dictionaries(); i++) {
+ auto block = FileBlockFromFlatbuffer(state->footer_->dictionaries()->Get(i));
+ messages[i] = ReadBlock(block);
+ }
+ auto read_messages = All(std::move(messages));
+ if (executor_) read_messages = executor_->Transfer(read_messages);
+ read_dictionaries_ = read_messages.Then(
+ [=](const std::vector<Result<std::shared_ptr<Message>>>& maybe_messages)
+ -> Status {
+ ARROW_ASSIGN_OR_RAISE(auto messages,
+ arrow::internal::UnwrapOrRaise(maybe_messages));
+ return ReadDictionaries(state.get(), std::move(messages));
+ });
+ }
+ if (index_ >= state_->num_record_batches()) {
+ return Future<Item>::MakeFinished(IterationTraits<Item>::End());
+ }
+ auto block = FileBlockFromFlatbuffer(state->footer_->recordBatches()->Get(index_++));
+ auto read_message = ReadBlock(block);
+ auto read_messages = read_dictionaries_.Then([read_message]() { return read_message; });
+ // Force transfer. This may be wasteful in some cases, but ensures we get off the
+ // I/O threads as soon as possible, and ensures we don't decode record batches
+ // synchronously in the case that the message read has already finished.
+ if (executor_) {
+ auto executor = executor_;
+ return read_messages.Then(
+ [=](const std::shared_ptr<Message>& message) -> Future<Item> {
+ return DeferNotOk(executor->Submit(
+ [=]() { return ReadRecordBatch(state.get(), message.get()); }));
+ });
+ }
+ return read_messages.Then([=](const std::shared_ptr<Message>& message) -> Result<Item> {
+ return ReadRecordBatch(state.get(), message.get());
+ });
+}
+
+Future<std::shared_ptr<Message>> IpcFileRecordBatchGenerator::ReadBlock(
+ const FileBlock& block) {
+ if (cached_source_) {
+ auto cached_source = cached_source_;
+ io::ReadRange range{block.offset, block.metadata_length + block.body_length};
+ auto pool = state_->options_.memory_pool;
+ return cached_source->WaitFor({range}).Then(
+ [cached_source, pool, range]() -> Result<std::shared_ptr<Message>> {
+ ARROW_ASSIGN_OR_RAISE(auto buffer, cached_source->Read(range));
+ io::BufferReader stream(std::move(buffer));
+ return ReadMessage(&stream, pool);
+ });
+ } else {
+ return ReadMessageFromBlockAsync(block, state_->file_, io_context_);
+ }
+}
+
+Status IpcFileRecordBatchGenerator::ReadDictionaries(
+ RecordBatchFileReaderImpl* state,
+ std::vector<std::shared_ptr<Message>> dictionary_messages) {
+ IpcReadContext context(&state->dictionary_memo_, state->options_, state->swap_endian_);
+ for (const auto& message : dictionary_messages) {
+ RETURN_NOT_OK(ReadOneDictionary(message.get(), context));
+ }
+ return Status::OK();
+}
+
+Result<std::shared_ptr<RecordBatch>> IpcFileRecordBatchGenerator::ReadRecordBatch(
+ RecordBatchFileReaderImpl* state, Message* message) {
+ CHECK_HAS_BODY(*message);
+ ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
+ IpcReadContext context(&state->dictionary_memo_, state->options_, state->swap_endian_);
+ return ReadRecordBatchInternal(*message->metadata(), state->schema_,
+ state->field_inclusion_mask_, context, reader.get());
+}
+
+Status Listener::OnEOS() { return Status::OK(); }
+
+Status Listener::OnSchemaDecoded(std::shared_ptr<Schema> schema) { return Status::OK(); }
+
+Status Listener::OnRecordBatchDecoded(std::shared_ptr<RecordBatch> record_batch) {
+ return Status::NotImplemented("OnRecordBatchDecoded() callback isn't implemented");
+}
+
+class StreamDecoder::StreamDecoderImpl : public MessageDecoderListener {
+ private:
+ enum State {
+ SCHEMA,
+ INITIAL_DICTIONARIES,
+ RECORD_BATCHES,
+ EOS,
+ };
+
+ public:
+ explicit StreamDecoderImpl(std::shared_ptr<Listener> listener, IpcReadOptions options)
+ : listener_(std::move(listener)),
+ options_(std::move(options)),
+ state_(State::SCHEMA),
+ message_decoder_(std::shared_ptr<StreamDecoderImpl>(this, [](void*) {}),
+ options_.memory_pool),
+ n_required_dictionaries_(0) {}
+
+ Status OnMessageDecoded(std::unique_ptr<Message> message) override {
+ ++stats_.num_messages;
+ switch (state_) {
+ case State::SCHEMA:
+ ARROW_RETURN_NOT_OK(OnSchemaMessageDecoded(std::move(message)));
+ break;
+ case State::INITIAL_DICTIONARIES:
+ ARROW_RETURN_NOT_OK(OnInitialDictionaryMessageDecoded(std::move(message)));
+ break;
+ case State::RECORD_BATCHES:
+ ARROW_RETURN_NOT_OK(OnRecordBatchMessageDecoded(std::move(message)));
+ break;
+ case State::EOS:
+ break;
+ }
+ return Status::OK();
+ }
+
+ Status OnEOS() override {
+ state_ = State::EOS;
+ return listener_->OnEOS();
+ }
+
+ Status Consume(const uint8_t* data, int64_t size) {
+ return message_decoder_.Consume(data, size);
+ }
+
+ Status Consume(std::shared_ptr<Buffer> buffer) {
+ return message_decoder_.Consume(std::move(buffer));
+ }
+
+ std::shared_ptr<Schema> schema() const { return out_schema_; }
+
+ int64_t next_required_size() const { return message_decoder_.next_required_size(); }
+
+ ReadStats stats() const { return stats_; }
+
+ private:
+ Status OnSchemaMessageDecoded(std::unique_ptr<Message> message) {
+ RETURN_NOT_OK(UnpackSchemaMessage(*message, options_, &dictionary_memo_, &schema_,
+ &out_schema_, &field_inclusion_mask_,
+ &swap_endian_));
+
+ n_required_dictionaries_ = dictionary_memo_.fields().num_fields();
+ if (n_required_dictionaries_ == 0) {
+ state_ = State::RECORD_BATCHES;
+ RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_));
+ } else {
+ state_ = State::INITIAL_DICTIONARIES;
+ }
+ return Status::OK();
+ }
+
+ Status OnInitialDictionaryMessageDecoded(std::unique_ptr<Message> message) {
+ if (message->type() != MessageType::DICTIONARY_BATCH) {
+ return Status::Invalid("IPC stream did not have the expected number (",
+ dictionary_memo_.fields().num_fields(),
+ ") of dictionaries at the start of the stream");
+ }
+ RETURN_NOT_OK(ReadDictionary(*message));
+ n_required_dictionaries_--;
+ if (n_required_dictionaries_ == 0) {
+ state_ = State::RECORD_BATCHES;
+ ARROW_RETURN_NOT_OK(listener_->OnSchemaDecoded(schema_));
+ }
+ return Status::OK();
+ }
+
+ Status OnRecordBatchMessageDecoded(std::unique_ptr<Message> message) {
+ IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
+ if (message->type() == MessageType::DICTIONARY_BATCH) {
+ return ReadDictionary(*message);
+ } else {
+ CHECK_HAS_BODY(*message);
+ ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
+ IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
+ ARROW_ASSIGN_OR_RAISE(
+ auto batch,
+ ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_,
+ context, reader.get()));
+ ++stats_.num_record_batches;
+ return listener_->OnRecordBatchDecoded(std::move(batch));
+ }
+ }
+
+ // Read dictionary from dictionary batch
+ Status ReadDictionary(const Message& message) {
+ DictionaryKind kind;
+ IpcReadContext context(&dictionary_memo_, options_, swap_endian_);
+ RETURN_NOT_OK(::arrow::ipc::ReadDictionary(message, context, &kind));
+ ++stats_.num_dictionary_batches;
+ switch (kind) {
+ case DictionaryKind::New:
+ break;
+ case DictionaryKind::Delta:
+ ++stats_.num_dictionary_deltas;
+ break;
+ case DictionaryKind::Replacement:
+ ++stats_.num_replaced_dictionaries;
+ break;
+ }
+ return Status::OK();
+ }
+
+ std::shared_ptr<Listener> listener_;
+ const IpcReadOptions options_;
+ State state_;
+ MessageDecoder message_decoder_;
+ std::vector<bool> field_inclusion_mask_;
+ int n_required_dictionaries_;
+ DictionaryMemo dictionary_memo_;
+ std::shared_ptr<Schema> schema_, out_schema_;
+ ReadStats stats_;
+ bool swap_endian_;
+};
+
+StreamDecoder::StreamDecoder(std::shared_ptr<Listener> listener, IpcReadOptions options) {
+ impl_.reset(new StreamDecoderImpl(std::move(listener), options));
+}
+
+StreamDecoder::~StreamDecoder() {}
+
+Status StreamDecoder::Consume(const uint8_t* data, int64_t size) {
+ return impl_->Consume(data, size);
+}
+Status StreamDecoder::Consume(std::shared_ptr<Buffer> buffer) {
+ return impl_->Consume(std::move(buffer));
+}
+
+std::shared_ptr<Schema> StreamDecoder::schema() const { return impl_->schema(); }
+
+int64_t StreamDecoder::next_required_size() const { return impl_->next_required_size(); }
+
+ReadStats StreamDecoder::stats() const { return impl_->stats(); }
+
+Result<std::shared_ptr<Schema>> ReadSchema(io::InputStream* stream,
+ DictionaryMemo* dictionary_memo) {
+ std::unique_ptr<MessageReader> reader = MessageReader::Open(stream);
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Message> message, reader->ReadNextMessage());
+ if (!message) {
+ return Status::Invalid("Tried reading schema message, was null or length 0");
+ }
+ CHECK_MESSAGE_TYPE(MessageType::SCHEMA, message->type());
+ return ReadSchema(*message, dictionary_memo);
+}
+
+Result<std::shared_ptr<Schema>> ReadSchema(const Message& message,
+ DictionaryMemo* dictionary_memo) {
+ std::shared_ptr<Schema> result;
+ RETURN_NOT_OK(internal::GetSchema(message.header(), dictionary_memo, &result));
+ return result;
+}
+
+Result<std::shared_ptr<Tensor>> ReadTensor(io::InputStream* file) {
+ std::unique_ptr<Message> message;
+ RETURN_NOT_OK(ReadContiguousPayload(file, &message));
+ return ReadTensor(*message);
+}
+
+Result<std::shared_ptr<Tensor>> ReadTensor(const Message& message) {
+ std::shared_ptr<DataType> type;
+ std::vector<int64_t> shape;
+ std::vector<int64_t> strides;
+ std::vector<std::string> dim_names;
+ CHECK_HAS_BODY(message);
+ RETURN_NOT_OK(internal::GetTensorMetadata(*message.metadata(), &type, &shape, &strides,
+ &dim_names));
+ return Tensor::Make(type, message.body(), shape, strides, dim_names);
+}
+
+namespace {
+
+Result<std::shared_ptr<SparseIndex>> ReadSparseCOOIndex(
+ const flatbuf::SparseTensor* sparse_tensor, const std::vector<int64_t>& shape,
+ int64_t non_zero_length, io::RandomAccessFile* file) {
+ auto* sparse_index = sparse_tensor->sparseIndex_as_SparseTensorIndexCOO();
+ const auto ndim = static_cast<int64_t>(shape.size());
+
+ std::shared_ptr<DataType> indices_type;
+ RETURN_NOT_OK(internal::GetSparseCOOIndexMetadata(sparse_index, &indices_type));
+ const int64_t indices_elsize = GetByteWidth(*indices_type);
+
+ auto* indices_buffer = sparse_index->indicesBuffer();
+ ARROW_ASSIGN_OR_RAISE(auto indices_data,
+ file->ReadAt(indices_buffer->offset(), indices_buffer->length()));
+ std::vector<int64_t> indices_shape({non_zero_length, ndim});
+ auto* indices_strides = sparse_index->indicesStrides();
+ std::vector<int64_t> strides(2);
+ if (indices_strides && indices_strides->size() > 0) {
+ if (indices_strides->size() != 2) {
+ return Status::Invalid("Wrong size for indicesStrides in SparseCOOIndex");
+ }
+ strides[0] = indices_strides->Get(0);
+ strides[1] = indices_strides->Get(1);
+ } else {
+ // Row-major by default
+ strides[0] = indices_elsize * ndim;
+ strides[1] = indices_elsize;
+ }
+ return SparseCOOIndex::Make(
+ std::make_shared<Tensor>(indices_type, indices_data, indices_shape, strides),
+ sparse_index->isCanonical());
+}
+
+Result<std::shared_ptr<SparseIndex>> ReadSparseCSXIndex(
+ const flatbuf::SparseTensor* sparse_tensor, const std::vector<int64_t>& shape,
+ int64_t non_zero_length, io::RandomAccessFile* file) {
+ if (shape.size() != 2) {
+ return Status::Invalid("Invalid shape length for a sparse matrix");
+ }
+
+ auto* sparse_index = sparse_tensor->sparseIndex_as_SparseMatrixIndexCSX();
+
+ std::shared_ptr<DataType> indptr_type, indices_type;
+ RETURN_NOT_OK(
+ internal::GetSparseCSXIndexMetadata(sparse_index, &indptr_type, &indices_type));
+ const int indptr_byte_width = GetByteWidth(*indptr_type);
+
+ auto* indptr_buffer = sparse_index->indptrBuffer();
+ ARROW_ASSIGN_OR_RAISE(auto indptr_data,
+ file->ReadAt(indptr_buffer->offset(), indptr_buffer->length()));
+
+ auto* indices_buffer = sparse_index->indicesBuffer();
+ ARROW_ASSIGN_OR_RAISE(auto indices_data,
+ file->ReadAt(indices_buffer->offset(), indices_buffer->length()));
+
+ std::vector<int64_t> indices_shape({non_zero_length});
+ const auto indices_minimum_bytes = indices_shape[0] * GetByteWidth(*indices_type);
+ if (indices_minimum_bytes > indices_buffer->length()) {
+ return Status::Invalid("shape is inconsistent to the size of indices buffer");
+ }
+
+ switch (sparse_index->compressedAxis()) {
+ case flatbuf::SparseMatrixCompressedAxis::Row: {
+ std::vector<int64_t> indptr_shape({shape[0] + 1});
+ const int64_t indptr_minimum_bytes = indptr_shape[0] * indptr_byte_width;
+ if (indptr_minimum_bytes > indptr_buffer->length()) {
+ return Status::Invalid("shape is inconsistent to the size of indptr buffer");
+ }
+ return std::make_shared<SparseCSRIndex>(
+ std::make_shared<Tensor>(indptr_type, indptr_data, indptr_shape),
+ std::make_shared<Tensor>(indices_type, indices_data, indices_shape));
+ }
+ case flatbuf::SparseMatrixCompressedAxis::Column: {
+ std::vector<int64_t> indptr_shape({shape[1] + 1});
+ const int64_t indptr_minimum_bytes = indptr_shape[0] * indptr_byte_width;
+ if (indptr_minimum_bytes > indptr_buffer->length()) {
+ return Status::Invalid("shape is inconsistent to the size of indptr buffer");
+ }
+ return std::make_shared<SparseCSCIndex>(
+ std::make_shared<Tensor>(indptr_type, indptr_data, indptr_shape),
+ std::make_shared<Tensor>(indices_type, indices_data, indices_shape));
+ }
+ default:
+ return Status::Invalid("Invalid value of SparseMatrixCompressedAxis");
+ }
+}
+
+Result<std::shared_ptr<SparseIndex>> ReadSparseCSFIndex(
+ const flatbuf::SparseTensor* sparse_tensor, const std::vector<int64_t>& shape,
+ io::RandomAccessFile* file) {
+ auto* sparse_index = sparse_tensor->sparseIndex_as_SparseTensorIndexCSF();
+ const auto ndim = static_cast<int64_t>(shape.size());
+ auto* indptr_buffers = sparse_index->indptrBuffers();
+ auto* indices_buffers = sparse_index->indicesBuffers();
+ std::vector<std::shared_ptr<Buffer>> indptr_data(ndim - 1);
+ std::vector<std::shared_ptr<Buffer>> indices_data(ndim);
+
+ std::shared_ptr<DataType> indptr_type, indices_type;
+ std::vector<int64_t> axis_order, indices_size;
+
+ RETURN_NOT_OK(internal::GetSparseCSFIndexMetadata(
+ sparse_index, &axis_order, &indices_size, &indptr_type, &indices_type));
+ for (int i = 0; i < static_cast<int>(indptr_buffers->size()); ++i) {
+ ARROW_ASSIGN_OR_RAISE(indptr_data[i], file->ReadAt(indptr_buffers->Get(i)->offset(),
+ indptr_buffers->Get(i)->length()));
+ }
+ for (int i = 0; i < static_cast<int>(indices_buffers->size()); ++i) {
+ ARROW_ASSIGN_OR_RAISE(indices_data[i],
+ file->ReadAt(indices_buffers->Get(i)->offset(),
+ indices_buffers->Get(i)->length()));
+ }
+
+ return SparseCSFIndex::Make(indptr_type, indices_type, indices_size, axis_order,
+ indptr_data, indices_data);
+}
+
+Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCOOIndex(
+ const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
+ const std::vector<std::string>& dim_names,
+ const std::shared_ptr<SparseCOOIndex>& sparse_index, int64_t non_zero_length,
+ const std::shared_ptr<Buffer>& data) {
+ return SparseCOOTensor::Make(sparse_index, type, data, shape, dim_names);
+}
+
+Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCSRIndex(
+ const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
+ const std::vector<std::string>& dim_names,
+ const std::shared_ptr<SparseCSRIndex>& sparse_index, int64_t non_zero_length,
+ const std::shared_ptr<Buffer>& data) {
+ return SparseCSRMatrix::Make(sparse_index, type, data, shape, dim_names);
+}
+
+Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCSCIndex(
+ const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
+ const std::vector<std::string>& dim_names,
+ const std::shared_ptr<SparseCSCIndex>& sparse_index, int64_t non_zero_length,
+ const std::shared_ptr<Buffer>& data) {
+ return SparseCSCMatrix::Make(sparse_index, type, data, shape, dim_names);
+}
+
+Result<std::shared_ptr<SparseTensor>> MakeSparseTensorWithSparseCSFIndex(
+ const std::shared_ptr<DataType>& type, const std::vector<int64_t>& shape,
+ const std::vector<std::string>& dim_names,
+ const std::shared_ptr<SparseCSFIndex>& sparse_index,
+ const std::shared_ptr<Buffer>& data) {
+ return SparseCSFTensor::Make(sparse_index, type, data, shape, dim_names);
+}
+
+Status ReadSparseTensorMetadata(const Buffer& metadata,
+ std::shared_ptr<DataType>* out_type,
+ std::vector<int64_t>* out_shape,
+ std::vector<std::string>* out_dim_names,
+ int64_t* out_non_zero_length,
+ SparseTensorFormat::type* out_format_id,
+ const flatbuf::SparseTensor** out_fb_sparse_tensor,
+ const flatbuf::Buffer** out_buffer) {
+ RETURN_NOT_OK(internal::GetSparseTensorMetadata(
+ metadata, out_type, out_shape, out_dim_names, out_non_zero_length, out_format_id));
+
+ const flatbuf::Message* message = nullptr;
+ RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
+
+ auto sparse_tensor = message->header_as_SparseTensor();
+ if (sparse_tensor == nullptr) {
+ return Status::IOError(
+ "Header-type of flatbuffer-encoded Message is not SparseTensor.");
+ }
+ *out_fb_sparse_tensor = sparse_tensor;
+
+ auto buffer = sparse_tensor->data();
+ if (!BitUtil::IsMultipleOf8(buffer->offset())) {
+ return Status::Invalid(
+ "Buffer of sparse index data did not start on 8-byte aligned offset: ",
+ buffer->offset());
+ }
+ *out_buffer = buffer;
+
+ return Status::OK();
+}
+
+} // namespace
+
+namespace internal {
+
+namespace {
+
+Result<size_t> GetSparseTensorBodyBufferCount(SparseTensorFormat::type format_id,
+ const size_t ndim) {
+ switch (format_id) {
+ case SparseTensorFormat::COO:
+ return 2;
+
+ case SparseTensorFormat::CSR:
+ return 3;
+
+ case SparseTensorFormat::CSC:
+ return 3;
+
+ case SparseTensorFormat::CSF:
+ return 2 * ndim;
+
+ default:
+ return Status::Invalid("Unrecognized sparse tensor format");
+ }
+}
+
+Status CheckSparseTensorBodyBufferCount(const IpcPayload& payload,
+ SparseTensorFormat::type sparse_tensor_format_id,
+ const size_t ndim) {
+ size_t expected_body_buffer_count = 0;
+ ARROW_ASSIGN_OR_RAISE(expected_body_buffer_count,
+ GetSparseTensorBodyBufferCount(sparse_tensor_format_id, ndim));
+ if (payload.body_buffers.size() != expected_body_buffer_count) {
+ return Status::Invalid("Invalid body buffer count for a sparse tensor");
+ }
+
+ return Status::OK();
+}
+
+} // namespace
+
+Result<size_t> ReadSparseTensorBodyBufferCount(const Buffer& metadata) {
+ SparseTensorFormat::type format_id;
+ std::vector<int64_t> shape;
+
+ RETURN_NOT_OK(internal::GetSparseTensorMetadata(metadata, nullptr, &shape, nullptr,
+ nullptr, &format_id));
+
+ return GetSparseTensorBodyBufferCount(format_id, static_cast<size_t>(shape.size()));
+}
+
+Result<std::shared_ptr<SparseTensor>> ReadSparseTensorPayload(const IpcPayload& payload) {
+ std::shared_ptr<DataType> type;
+ std::vector<int64_t> shape;
+ std::vector<std::string> dim_names;
+ int64_t non_zero_length;
+ SparseTensorFormat::type sparse_tensor_format_id;
+ const flatbuf::SparseTensor* sparse_tensor;
+ const flatbuf::Buffer* buffer;
+
+ RETURN_NOT_OK(ReadSparseTensorMetadata(*payload.metadata, &type, &shape, &dim_names,
+ &non_zero_length, &sparse_tensor_format_id,
+ &sparse_tensor, &buffer));
+
+ RETURN_NOT_OK(CheckSparseTensorBodyBufferCount(payload, sparse_tensor_format_id,
+ static_cast<size_t>(shape.size())));
+
+ switch (sparse_tensor_format_id) {
+ case SparseTensorFormat::COO: {
+ std::shared_ptr<SparseCOOIndex> sparse_index;
+ std::shared_ptr<DataType> indices_type;
+ RETURN_NOT_OK(internal::GetSparseCOOIndexMetadata(
+ sparse_tensor->sparseIndex_as_SparseTensorIndexCOO(), &indices_type));
+ ARROW_ASSIGN_OR_RAISE(sparse_index,
+ SparseCOOIndex::Make(indices_type, shape, non_zero_length,
+ payload.body_buffers[0]));
+ return MakeSparseTensorWithSparseCOOIndex(type, shape, dim_names, sparse_index,
+ non_zero_length, payload.body_buffers[1]);
+ }
+ case SparseTensorFormat::CSR: {
+ std::shared_ptr<SparseCSRIndex> sparse_index;
+ std::shared_ptr<DataType> indptr_type;
+ std::shared_ptr<DataType> indices_type;
+ RETURN_NOT_OK(internal::GetSparseCSXIndexMetadata(
+ sparse_tensor->sparseIndex_as_SparseMatrixIndexCSX(), &indptr_type,
+ &indices_type));
+ ARROW_CHECK_EQ(indptr_type, indices_type);
+ ARROW_ASSIGN_OR_RAISE(
+ sparse_index,
+ SparseCSRIndex::Make(indices_type, shape, non_zero_length,
+ payload.body_buffers[0], payload.body_buffers[1]));
+ return MakeSparseTensorWithSparseCSRIndex(type, shape, dim_names, sparse_index,
+ non_zero_length, payload.body_buffers[2]);
+ }
+ case SparseTensorFormat::CSC: {
+ std::shared_ptr<SparseCSCIndex> sparse_index;
+ std::shared_ptr<DataType> indptr_type;
+ std::shared_ptr<DataType> indices_type;
+ RETURN_NOT_OK(internal::GetSparseCSXIndexMetadata(
+ sparse_tensor->sparseIndex_as_SparseMatrixIndexCSX(), &indptr_type,
+ &indices_type));
+ ARROW_CHECK_EQ(indptr_type, indices_type);
+ ARROW_ASSIGN_OR_RAISE(
+ sparse_index,
+ SparseCSCIndex::Make(indices_type, shape, non_zero_length,
+ payload.body_buffers[0], payload.body_buffers[1]));
+ return MakeSparseTensorWithSparseCSCIndex(type, shape, dim_names, sparse_index,
+ non_zero_length, payload.body_buffers[2]);
+ }
+ case SparseTensorFormat::CSF: {
+ std::shared_ptr<SparseCSFIndex> sparse_index;
+ std::shared_ptr<DataType> indptr_type, indices_type;
+ std::vector<int64_t> axis_order, indices_size;
+
+ RETURN_NOT_OK(internal::GetSparseCSFIndexMetadata(
+ sparse_tensor->sparseIndex_as_SparseTensorIndexCSF(), &axis_order,
+ &indices_size, &indptr_type, &indices_type));
+ ARROW_CHECK_EQ(indptr_type, indices_type);
+
+ const int64_t ndim = shape.size();
+ std::vector<std::shared_ptr<Buffer>> indptr_data(ndim - 1);
+ std::vector<std::shared_ptr<Buffer>> indices_data(ndim);
+
+ for (int64_t i = 0; i < ndim - 1; ++i) {
+ indptr_data[i] = payload.body_buffers[i];
+ }
+ for (int64_t i = 0; i < ndim; ++i) {
+ indices_data[i] = payload.body_buffers[i + ndim - 1];
+ }
+
+ ARROW_ASSIGN_OR_RAISE(sparse_index,
+ SparseCSFIndex::Make(indptr_type, indices_type, indices_size,
+ axis_order, indptr_data, indices_data));
+ return MakeSparseTensorWithSparseCSFIndex(type, shape, dim_names, sparse_index,
+ payload.body_buffers[2 * ndim - 1]);
+ }
+ default:
+ return Status::Invalid("Unsupported sparse index format");
+ }
+}
+
+} // namespace internal
+
+Result<std::shared_ptr<SparseTensor>> ReadSparseTensor(const Buffer& metadata,
+ io::RandomAccessFile* file) {
+ std::shared_ptr<DataType> type;
+ std::vector<int64_t> shape;
+ std::vector<std::string> dim_names;
+ int64_t non_zero_length;
+ SparseTensorFormat::type sparse_tensor_format_id;
+ const flatbuf::SparseTensor* sparse_tensor;
+ const flatbuf::Buffer* buffer;
+
+ RETURN_NOT_OK(ReadSparseTensorMetadata(metadata, &type, &shape, &dim_names,
+ &non_zero_length, &sparse_tensor_format_id,
+ &sparse_tensor, &buffer));
+
+ ARROW_ASSIGN_OR_RAISE(auto data, file->ReadAt(buffer->offset(), buffer->length()));
+
+ std::shared_ptr<SparseIndex> sparse_index;
+ switch (sparse_tensor_format_id) {
+ case SparseTensorFormat::COO: {
+ ARROW_ASSIGN_OR_RAISE(
+ sparse_index, ReadSparseCOOIndex(sparse_tensor, shape, non_zero_length, file));
+ return MakeSparseTensorWithSparseCOOIndex(
+ type, shape, dim_names, checked_pointer_cast<SparseCOOIndex>(sparse_index),
+ non_zero_length, data);
+ }
+ case SparseTensorFormat::CSR: {
+ ARROW_ASSIGN_OR_RAISE(
+ sparse_index, ReadSparseCSXIndex(sparse_tensor, shape, non_zero_length, file));
+ return MakeSparseTensorWithSparseCSRIndex(
+ type, shape, dim_names, checked_pointer_cast<SparseCSRIndex>(sparse_index),
+ non_zero_length, data);
+ }
+ case SparseTensorFormat::CSC: {
+ ARROW_ASSIGN_OR_RAISE(
+ sparse_index, ReadSparseCSXIndex(sparse_tensor, shape, non_zero_length, file));
+ return MakeSparseTensorWithSparseCSCIndex(
+ type, shape, dim_names, checked_pointer_cast<SparseCSCIndex>(sparse_index),
+ non_zero_length, data);
+ }
+ case SparseTensorFormat::CSF: {
+ ARROW_ASSIGN_OR_RAISE(sparse_index, ReadSparseCSFIndex(sparse_tensor, shape, file));
+ return MakeSparseTensorWithSparseCSFIndex(
+ type, shape, dim_names, checked_pointer_cast<SparseCSFIndex>(sparse_index),
+ data);
+ }
+ default:
+ return Status::Invalid("Unsupported sparse index format");
+ }
+}
+
+Result<std::shared_ptr<SparseTensor>> ReadSparseTensor(const Message& message) {
+ CHECK_HAS_BODY(message);
+ ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body()));
+ return ReadSparseTensor(*message.metadata(), reader.get());
+}
+
+Result<std::shared_ptr<SparseTensor>> ReadSparseTensor(io::InputStream* file) {
+ std::unique_ptr<Message> message;
+ RETURN_NOT_OK(ReadContiguousPayload(file, &message));
+ CHECK_MESSAGE_TYPE(MessageType::SPARSE_TENSOR, message->type());
+ CHECK_HAS_BODY(*message);
+ ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body()));
+ return ReadSparseTensor(*message->metadata(), reader.get());
+}
+
+///////////////////////////////////////////////////////////////////////////
+// Helpers for fuzzing
+
+namespace internal {
+namespace {
+
+Status ValidateFuzzBatch(const RecordBatch& batch) {
+ auto st = batch.ValidateFull();
+ if (st.ok()) {
+ // If the batch is valid, printing should succeed
+ batch.ToString();
+ }
+ return st;
+}
+
+} // namespace
+
+Status FuzzIpcStream(const uint8_t* data, int64_t size) {
+ auto buffer = std::make_shared<Buffer>(data, size);
+ io::BufferReader buffer_reader(buffer);
+
+ std::shared_ptr<RecordBatchReader> batch_reader;
+ ARROW_ASSIGN_OR_RAISE(batch_reader, RecordBatchStreamReader::Open(&buffer_reader));
+ Status st;
+
+ while (true) {
+ std::shared_ptr<arrow::RecordBatch> batch;
+ RETURN_NOT_OK(batch_reader->ReadNext(&batch));
+ if (batch == nullptr) {
+ break;
+ }
+ st &= ValidateFuzzBatch(*batch);
+ }
+
+ return st;
+}
+
+Status FuzzIpcFile(const uint8_t* data, int64_t size) {
+ auto buffer = std::make_shared<Buffer>(data, size);
+ io::BufferReader buffer_reader(buffer);
+
+ std::shared_ptr<RecordBatchFileReader> batch_reader;
+ ARROW_ASSIGN_OR_RAISE(batch_reader, RecordBatchFileReader::Open(&buffer_reader));
+ Status st;
+
+ const int n_batches = batch_reader->num_record_batches();
+ for (int i = 0; i < n_batches; ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto batch, batch_reader->ReadRecordBatch(i));
+ st &= ValidateFuzzBatch(*batch);
+ }
+
+ return st;
+}
+
+Status FuzzIpcTensorStream(const uint8_t* data, int64_t size) {
+ auto buffer = std::make_shared<Buffer>(data, size);
+ io::BufferReader buffer_reader(buffer);
+
+ std::shared_ptr<Tensor> tensor;
+
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(tensor, ReadTensor(&buffer_reader));
+ if (tensor == nullptr) {
+ break;
+ }
+ RETURN_NOT_OK(tensor->Validate());
+ }
+
+ return Status::OK();
+}
+
+} // namespace internal
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/reader.h b/src/arrow/cpp/src/arrow/ipc/reader.h
new file mode 100644
index 000000000..6f2157557
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/reader.h
@@ -0,0 +1,536 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Read Arrow files and streams
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/io/caching.h"
+#include "arrow/io/type_fwd.h"
+#include "arrow/ipc/message.h"
+#include "arrow/ipc/options.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace ipc {
+
+class DictionaryMemo;
+struct IpcPayload;
+
+using RecordBatchReader = ::arrow::RecordBatchReader;
+
+struct ReadStats {
+ /// Number of IPC messages read.
+ int64_t num_messages = 0;
+ /// Number of record batches read.
+ int64_t num_record_batches = 0;
+ /// Number of dictionary batches read.
+ ///
+ /// Note: num_dictionary_batches >= num_dictionary_deltas + num_replaced_dictionaries
+ int64_t num_dictionary_batches = 0;
+
+ /// Number of dictionary deltas read.
+ int64_t num_dictionary_deltas = 0;
+ /// Number of replaced dictionaries (i.e. where a dictionary batch replaces
+ /// an existing dictionary with an unrelated new dictionary).
+ int64_t num_replaced_dictionaries = 0;
+};
+
+/// \brief Synchronous batch stream reader that reads from io::InputStream
+///
+/// This class reads the schema (plus any dictionaries) as the first messages
+/// in the stream, followed by record batches. For more granular zero-copy
+/// reads see the ReadRecordBatch functions
+class ARROW_EXPORT RecordBatchStreamReader : public RecordBatchReader {
+ public:
+ /// Create batch reader from generic MessageReader.
+ /// This will take ownership of the given MessageReader.
+ ///
+ /// \param[in] message_reader a MessageReader implementation
+ /// \param[in] options any IPC reading options (optional)
+ /// \return the created batch reader
+ static Result<std::shared_ptr<RecordBatchStreamReader>> Open(
+ std::unique_ptr<MessageReader> message_reader,
+ const IpcReadOptions& options = IpcReadOptions::Defaults());
+
+ /// \brief Record batch stream reader from InputStream
+ ///
+ /// \param[in] stream an input stream instance. Must stay alive throughout
+ /// lifetime of stream reader
+ /// \param[in] options any IPC reading options (optional)
+ /// \return the created batch reader
+ static Result<std::shared_ptr<RecordBatchStreamReader>> Open(
+ io::InputStream* stream,
+ const IpcReadOptions& options = IpcReadOptions::Defaults());
+
+ /// \brief Open stream and retain ownership of stream object
+ /// \param[in] stream the input stream
+ /// \param[in] options any IPC reading options (optional)
+ /// \return the created batch reader
+ static Result<std::shared_ptr<RecordBatchStreamReader>> Open(
+ const std::shared_ptr<io::InputStream>& stream,
+ const IpcReadOptions& options = IpcReadOptions::Defaults());
+
+ /// \brief Return current read statistics
+ virtual ReadStats stats() const = 0;
+};
+
+/// \brief Reads the record batch file format
+class ARROW_EXPORT RecordBatchFileReader
+ : public std::enable_shared_from_this<RecordBatchFileReader> {
+ public:
+ virtual ~RecordBatchFileReader() = default;
+
+ /// \brief Open a RecordBatchFileReader
+ ///
+ /// Open a file-like object that is assumed to be self-contained; i.e., the
+ /// end of the file interface is the end of the Arrow file. Note that there
+ /// can be any amount of data preceding the Arrow-formatted data, because we
+ /// need only locate the end of the Arrow file stream to discover the metadata
+ /// and then proceed to read the data into memory.
+ static Result<std::shared_ptr<RecordBatchFileReader>> Open(
+ io::RandomAccessFile* file,
+ const IpcReadOptions& options = IpcReadOptions::Defaults());
+
+ /// \brief Open a RecordBatchFileReader
+ /// If the file is embedded within some larger file or memory region, you can
+ /// pass the absolute memory offset to the end of the file (which contains the
+ /// metadata footer). The metadata must have been written with memory offsets
+ /// relative to the start of the containing file
+ ///
+ /// \param[in] file the data source
+ /// \param[in] footer_offset the position of the end of the Arrow file
+ /// \param[in] options options for IPC reading
+ /// \return the returned reader
+ static Result<std::shared_ptr<RecordBatchFileReader>> Open(
+ io::RandomAccessFile* file, int64_t footer_offset,
+ const IpcReadOptions& options = IpcReadOptions::Defaults());
+
+ /// \brief Version of Open that retains ownership of file
+ ///
+ /// \param[in] file the data source
+ /// \param[in] options options for IPC reading
+ /// \return the returned reader
+ static Result<std::shared_ptr<RecordBatchFileReader>> Open(
+ const std::shared_ptr<io::RandomAccessFile>& file,
+ const IpcReadOptions& options = IpcReadOptions::Defaults());
+
+ /// \brief Version of Open that retains ownership of file
+ ///
+ /// \param[in] file the data source
+ /// \param[in] footer_offset the position of the end of the Arrow file
+ /// \param[in] options options for IPC reading
+ /// \return the returned reader
+ static Result<std::shared_ptr<RecordBatchFileReader>> Open(
+ const std::shared_ptr<io::RandomAccessFile>& file, int64_t footer_offset,
+ const IpcReadOptions& options = IpcReadOptions::Defaults());
+
+ /// \brief Open a file asynchronously (owns the file).
+ static Future<std::shared_ptr<RecordBatchFileReader>> OpenAsync(
+ const std::shared_ptr<io::RandomAccessFile>& file,
+ const IpcReadOptions& options = IpcReadOptions::Defaults());
+
+ /// \brief Open a file asynchronously (borrows the file).
+ static Future<std::shared_ptr<RecordBatchFileReader>> OpenAsync(
+ io::RandomAccessFile* file,
+ const IpcReadOptions& options = IpcReadOptions::Defaults());
+
+ /// \brief Open a file asynchronously (owns the file).
+ static Future<std::shared_ptr<RecordBatchFileReader>> OpenAsync(
+ const std::shared_ptr<io::RandomAccessFile>& file, int64_t footer_offset,
+ const IpcReadOptions& options = IpcReadOptions::Defaults());
+
+ /// \brief Open a file asynchronously (borrows the file).
+ static Future<std::shared_ptr<RecordBatchFileReader>> OpenAsync(
+ io::RandomAccessFile* file, int64_t footer_offset,
+ const IpcReadOptions& options = IpcReadOptions::Defaults());
+
+ /// \brief The schema read from the file
+ virtual std::shared_ptr<Schema> schema() const = 0;
+
+ /// \brief Returns the number of record batches in the file
+ virtual int num_record_batches() const = 0;
+
+ /// \brief Return the metadata version from the file metadata
+ virtual MetadataVersion version() const = 0;
+
+ /// \brief Return the contents of the custom_metadata field from the file's
+ /// Footer
+ virtual std::shared_ptr<const KeyValueMetadata> metadata() const = 0;
+
+ /// \brief Read a particular record batch from the file. Does not copy memory
+ /// if the input source supports zero-copy.
+ ///
+ /// \param[in] i the index of the record batch to return
+ /// \return the read batch
+ virtual Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(int i) = 0;
+
+ /// \brief Return current read statistics
+ virtual ReadStats stats() const = 0;
+
+ /// \brief Computes the total number of rows in the file.
+ virtual Result<int64_t> CountRows() = 0;
+
+ /// \brief Get a reentrant generator of record batches.
+ ///
+ /// \param[in] coalesce If true, enable I/O coalescing.
+ /// \param[in] io_context The IOContext to use (controls which thread pool
+ /// is used for I/O).
+ /// \param[in] cache_options Options for coalescing (if enabled).
+ /// \param[in] executor Optionally, an executor to use for decoding record
+ /// batches. This is generally only a benefit for very wide and/or
+ /// compressed batches.
+ virtual Result<AsyncGenerator<std::shared_ptr<RecordBatch>>> GetRecordBatchGenerator(
+ const bool coalesce = false,
+ const io::IOContext& io_context = io::default_io_context(),
+ const io::CacheOptions cache_options = io::CacheOptions::LazyDefaults(),
+ arrow::internal::Executor* executor = NULLPTR) = 0;
+};
+
+/// \brief A general listener class to receive events.
+///
+/// You must implement callback methods for interested events.
+///
+/// This API is EXPERIMENTAL.
+///
+/// \since 0.17.0
+class ARROW_EXPORT Listener {
+ public:
+ virtual ~Listener() = default;
+
+ /// \brief Called when end-of-stream is received.
+ ///
+ /// The default implementation just returns arrow::Status::OK().
+ ///
+ /// \return Status
+ ///
+ /// \see StreamDecoder
+ virtual Status OnEOS();
+
+ /// \brief Called when a record batch is decoded.
+ ///
+ /// The default implementation just returns
+ /// arrow::Status::NotImplemented().
+ ///
+ /// \param[in] record_batch a record batch decoded
+ /// \return Status
+ ///
+ /// \see StreamDecoder
+ virtual Status OnRecordBatchDecoded(std::shared_ptr<RecordBatch> record_batch);
+
+ /// \brief Called when a schema is decoded.
+ ///
+ /// The default implementation just returns arrow::Status::OK().
+ ///
+ /// \param[in] schema a schema decoded
+ /// \return Status
+ ///
+ /// \see StreamDecoder
+ virtual Status OnSchemaDecoded(std::shared_ptr<Schema> schema);
+};
+
+/// \brief Collect schema and record batches decoded by StreamDecoder.
+///
+/// This API is EXPERIMENTAL.
+///
+/// \since 0.17.0
+class ARROW_EXPORT CollectListener : public Listener {
+ public:
+ CollectListener() : schema_(), record_batches_() {}
+ virtual ~CollectListener() = default;
+
+ Status OnSchemaDecoded(std::shared_ptr<Schema> schema) override {
+ schema_ = std::move(schema);
+ return Status::OK();
+ }
+
+ Status OnRecordBatchDecoded(std::shared_ptr<RecordBatch> record_batch) override {
+ record_batches_.push_back(std::move(record_batch));
+ return Status::OK();
+ }
+
+ /// \return the decoded schema
+ std::shared_ptr<Schema> schema() const { return schema_; }
+
+ /// \return the all decoded record batches
+ std::vector<std::shared_ptr<RecordBatch>> record_batches() const {
+ return record_batches_;
+ }
+
+ private:
+ std::shared_ptr<Schema> schema_;
+ std::vector<std::shared_ptr<RecordBatch>> record_batches_;
+};
+
+/// \brief Push style stream decoder that receives data from user.
+///
+/// This class decodes the Apache Arrow IPC streaming format data.
+///
+/// This API is EXPERIMENTAL.
+///
+/// \see https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format
+///
+/// \since 0.17.0
+class ARROW_EXPORT StreamDecoder {
+ public:
+ /// \brief Construct a stream decoder.
+ ///
+ /// \param[in] listener a Listener that must implement
+ /// Listener::OnRecordBatchDecoded() to receive decoded record batches
+ /// \param[in] options any IPC reading options (optional)
+ StreamDecoder(std::shared_ptr<Listener> listener,
+ IpcReadOptions options = IpcReadOptions::Defaults());
+
+ virtual ~StreamDecoder();
+
+ /// \brief Feed data to the decoder as a raw data.
+ ///
+ /// If the decoder can read one or more record batches by the data,
+ /// the decoder calls listener->OnRecordBatchDecoded() with a
+ /// decoded record batch multiple times.
+ ///
+ /// \param[in] data a raw data to be processed. This data isn't
+ /// copied. The passed memory must be kept alive through record
+ /// batch processing.
+ /// \param[in] size raw data size.
+ /// \return Status
+ Status Consume(const uint8_t* data, int64_t size);
+
+ /// \brief Feed data to the decoder as a Buffer.
+ ///
+ /// If the decoder can read one or more record batches by the
+ /// Buffer, the decoder calls listener->RecordBatchReceived() with a
+ /// decoded record batch multiple times.
+ ///
+ /// \param[in] buffer a Buffer to be processed.
+ /// \return Status
+ Status Consume(std::shared_ptr<Buffer> buffer);
+
+ /// \return the shared schema of the record batches in the stream
+ std::shared_ptr<Schema> schema() const;
+
+ /// \brief Return the number of bytes needed to advance the state of
+ /// the decoder.
+ ///
+ /// This method is provided for users who want to optimize performance.
+ /// Normal users don't need to use this method.
+ ///
+ /// Here is an example usage for normal users:
+ ///
+ /// ~~~{.cpp}
+ /// decoder.Consume(buffer1);
+ /// decoder.Consume(buffer2);
+ /// decoder.Consume(buffer3);
+ /// ~~~
+ ///
+ /// Decoder has internal buffer. If consumed data isn't enough to
+ /// advance the state of the decoder, consumed data is buffered to
+ /// the internal buffer. It causes performance overhead.
+ ///
+ /// If you pass next_required_size() size data to each Consume()
+ /// call, the decoder doesn't use its internal buffer. It improves
+ /// performance.
+ ///
+ /// Here is an example usage to avoid using internal buffer:
+ ///
+ /// ~~~{.cpp}
+ /// buffer1 = get_data(decoder.next_required_size());
+ /// decoder.Consume(buffer1);
+ /// buffer2 = get_data(decoder.next_required_size());
+ /// decoder.Consume(buffer2);
+ /// ~~~
+ ///
+ /// Users can use this method to avoid creating small chunks. Record
+ /// batch data must be contiguous data. If users pass small chunks
+ /// to the decoder, the decoder needs concatenate small chunks
+ /// internally. It causes performance overhead.
+ ///
+ /// Here is an example usage to reduce small chunks:
+ ///
+ /// ~~~{.cpp}
+ /// buffer = AllocateResizableBuffer();
+ /// while ((small_chunk = get_data(&small_chunk_size))) {
+ /// auto current_buffer_size = buffer->size();
+ /// buffer->Resize(current_buffer_size + small_chunk_size);
+ /// memcpy(buffer->mutable_data() + current_buffer_size,
+ /// small_chunk,
+ /// small_chunk_size);
+ /// if (buffer->size() < decoder.next_required_size()) {
+ /// continue;
+ /// }
+ /// std::shared_ptr<arrow::Buffer> chunk(buffer.release());
+ /// decoder.Consume(chunk);
+ /// buffer = AllocateResizableBuffer();
+ /// }
+ /// if (buffer->size() > 0) {
+ /// std::shared_ptr<arrow::Buffer> chunk(buffer.release());
+ /// decoder.Consume(chunk);
+ /// }
+ /// ~~~
+ ///
+ /// \return the number of bytes needed to advance the state of the
+ /// decoder
+ int64_t next_required_size() const;
+
+ /// \brief Return current read statistics
+ ReadStats stats() const;
+
+ private:
+ class StreamDecoderImpl;
+ std::unique_ptr<StreamDecoderImpl> impl_;
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(StreamDecoder);
+};
+
+// Generic read functions; does not copy data if the input supports zero copy reads
+
+/// \brief Read Schema from stream serialized as a single IPC message
+/// and populate any dictionary-encoded fields into a DictionaryMemo
+///
+/// \param[in] stream an InputStream
+/// \param[in] dictionary_memo for recording dictionary-encoded fields
+/// \return the output Schema
+///
+/// If record batches follow the schema, it is better to use
+/// RecordBatchStreamReader
+ARROW_EXPORT
+Result<std::shared_ptr<Schema>> ReadSchema(io::InputStream* stream,
+ DictionaryMemo* dictionary_memo);
+
+/// \brief Read Schema from encapsulated Message
+///
+/// \param[in] message the message containing the Schema IPC metadata
+/// \param[in] dictionary_memo DictionaryMemo for recording dictionary-encoded
+/// fields. Can be nullptr if you are sure there are no
+/// dictionary-encoded fields
+/// \return the resulting Schema
+ARROW_EXPORT
+Result<std::shared_ptr<Schema>> ReadSchema(const Message& message,
+ DictionaryMemo* dictionary_memo);
+
+/// Read record batch as encapsulated IPC message with metadata size prefix and
+/// header
+///
+/// \param[in] schema the record batch schema
+/// \param[in] dictionary_memo DictionaryMemo which has any
+/// dictionaries. Can be nullptr if you are sure there are no
+/// dictionary-encoded fields
+/// \param[in] options IPC options for reading
+/// \param[in] stream the file where the batch is located
+/// \return the read record batch
+ARROW_EXPORT
+Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
+ const std::shared_ptr<Schema>& schema, const DictionaryMemo* dictionary_memo,
+ const IpcReadOptions& options, io::InputStream* stream);
+
+/// \brief Read record batch from message
+///
+/// \param[in] message a Message containing the record batch metadata
+/// \param[in] schema the record batch schema
+/// \param[in] dictionary_memo DictionaryMemo which has any
+/// dictionaries. Can be nullptr if you are sure there are no
+/// dictionary-encoded fields
+/// \param[in] options IPC options for reading
+/// \return the read record batch
+ARROW_EXPORT
+Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
+ const Message& message, const std::shared_ptr<Schema>& schema,
+ const DictionaryMemo* dictionary_memo, const IpcReadOptions& options);
+
+/// Read record batch from file given metadata and schema
+///
+/// \param[in] metadata a Message containing the record batch metadata
+/// \param[in] schema the record batch schema
+/// \param[in] dictionary_memo DictionaryMemo which has any
+/// dictionaries. Can be nullptr if you are sure there are no
+/// dictionary-encoded fields
+/// \param[in] file a random access file
+/// \param[in] options options for deserialization
+/// \return the read record batch
+ARROW_EXPORT
+Result<std::shared_ptr<RecordBatch>> ReadRecordBatch(
+ const Buffer& metadata, const std::shared_ptr<Schema>& schema,
+ const DictionaryMemo* dictionary_memo, const IpcReadOptions& options,
+ io::RandomAccessFile* file);
+
+/// \brief Read arrow::Tensor as encapsulated IPC message in file
+///
+/// \param[in] file an InputStream pointed at the start of the message
+/// \return the read tensor
+ARROW_EXPORT
+Result<std::shared_ptr<Tensor>> ReadTensor(io::InputStream* file);
+
+/// \brief EXPERIMENTAL: Read arrow::Tensor from IPC message
+///
+/// \param[in] message a Message containing the tensor metadata and body
+/// \return the read tensor
+ARROW_EXPORT
+Result<std::shared_ptr<Tensor>> ReadTensor(const Message& message);
+
+/// \brief EXPERIMENTAL: Read arrow::SparseTensor as encapsulated IPC message in file
+///
+/// \param[in] file an InputStream pointed at the start of the message
+/// \return the read sparse tensor
+ARROW_EXPORT
+Result<std::shared_ptr<SparseTensor>> ReadSparseTensor(io::InputStream* file);
+
+/// \brief EXPERIMENTAL: Read arrow::SparseTensor from IPC message
+///
+/// \param[in] message a Message containing the tensor metadata and body
+/// \return the read sparse tensor
+ARROW_EXPORT
+Result<std::shared_ptr<SparseTensor>> ReadSparseTensor(const Message& message);
+
+namespace internal {
+
+// These internal APIs may change without warning or deprecation
+
+/// \brief EXPERIMENTAL: Read arrow::SparseTensorFormat::type from a metadata
+/// \param[in] metadata a Buffer containing the sparse tensor metadata
+/// \return the count of the body buffers
+ARROW_EXPORT
+Result<size_t> ReadSparseTensorBodyBufferCount(const Buffer& metadata);
+
+/// \brief EXPERIMENTAL: Read arrow::SparseTensor from an IpcPayload
+/// \param[in] payload a IpcPayload contains a serialized SparseTensor
+/// \return the read sparse tensor
+ARROW_EXPORT
+Result<std::shared_ptr<SparseTensor>> ReadSparseTensorPayload(const IpcPayload& payload);
+
+// For fuzzing targets
+ARROW_EXPORT
+Status FuzzIpcStream(const uint8_t* data, int64_t size);
+ARROW_EXPORT
+Status FuzzIpcTensorStream(const uint8_t* data, int64_t size);
+ARROW_EXPORT
+Status FuzzIpcFile(const uint8_t* data, int64_t size);
+
+} // namespace internal
+
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/stream_fuzz.cc b/src/arrow/cpp/src/arrow/ipc/stream_fuzz.cc
new file mode 100644
index 000000000..e26f3d1f4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/stream_fuzz.cc
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+
+#include "arrow/ipc/reader.h"
+#include "arrow/status.h"
+#include "arrow/util/macros.h"
+
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
+ auto status = arrow::ipc::internal::FuzzIpcStream(data, static_cast<int64_t>(size));
+ ARROW_UNUSED(status);
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/ipc/stream_to_file.cc b/src/arrow/cpp/src/arrow/ipc/stream_to_file.cc
new file mode 100644
index 000000000..40288b687
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/stream_to_file.cc
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "arrow/io/stdio.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+
+namespace arrow {
+namespace ipc {
+
+// Converts a stream from stdin to a file written to standard out.
+// A typical usage would be:
+// $ <program that produces streaming output> | stream-to-file > file.arrow
+Status ConvertToFile() {
+ io::StdinStream input;
+ io::StdoutStream sink;
+
+ ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchStreamReader::Open(&input));
+ ARROW_ASSIGN_OR_RAISE(
+ auto writer, MakeFileWriter(&sink, reader->schema(), IpcWriteOptions::Defaults()));
+ std::shared_ptr<RecordBatch> batch;
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(batch, reader->Next());
+ if (batch == nullptr) break;
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+ }
+ return writer->Close();
+}
+
+} // namespace ipc
+} // namespace arrow
+
+int main(int argc, char** argv) {
+ arrow::Status status = arrow::ipc::ConvertToFile();
+ if (!status.ok()) {
+ std::cerr << "Could not convert to file: " << status.ToString() << std::endl;
+ return 1;
+ }
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/ipc/tensor_stream_fuzz.cc b/src/arrow/cpp/src/arrow/ipc/tensor_stream_fuzz.cc
new file mode 100644
index 000000000..7524940e1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/tensor_stream_fuzz.cc
@@ -0,0 +1,29 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+
+#include "arrow/ipc/reader.h"
+#include "arrow/status.h"
+#include "arrow/util/macros.h"
+
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
+ auto status =
+ arrow::ipc::internal::FuzzIpcTensorStream(data, static_cast<int64_t>(size));
+ ARROW_UNUSED(status);
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/ipc/tensor_test.cc b/src/arrow/cpp/src/arrow/ipc/tensor_test.cc
new file mode 100644
index 000000000..7af1492f6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/tensor_test.cc
@@ -0,0 +1,506 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+#include <unordered_set>
+
+#include <gtest/gtest.h>
+
+#include "arrow/io/file.h"
+#include "arrow/io/memory.h"
+#include "arrow/io/test_common.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/test_common.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/sparse_tensor.h"
+#include "arrow/status.h"
+#include "arrow/tensor.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/io_util.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::GetByteWidth;
+using internal::TemporaryDir;
+
+namespace ipc {
+namespace test {
+
+class BaseTensorTest : public ::testing::Test, public io::MemoryMapFixture {
+ public:
+ void SetUp() { ASSERT_OK_AND_ASSIGN(temp_dir_, TemporaryDir::Make("ipc-test-")); }
+
+ void TearDown() { io::MemoryMapFixture::TearDown(); }
+
+ protected:
+ std::shared_ptr<io::MemoryMappedFile> mmap_;
+ std::unique_ptr<TemporaryDir> temp_dir_;
+};
+
+class TestTensorRoundTrip : public BaseTensorTest {
+ public:
+ void CheckTensorRoundTrip(const Tensor& tensor) {
+ int32_t metadata_length;
+ int64_t body_length;
+ const int elem_size = GetByteWidth(*tensor.type());
+
+ ASSERT_OK(mmap_->Seek(0));
+
+ ASSERT_OK(WriteTensor(tensor, mmap_.get(), &metadata_length, &body_length));
+
+ const int64_t expected_body_length = elem_size * tensor.size();
+ ASSERT_EQ(expected_body_length, body_length);
+
+ ASSERT_OK(mmap_->Seek(0));
+
+ std::shared_ptr<Tensor> result;
+ ASSERT_OK_AND_ASSIGN(result, ReadTensor(mmap_.get()));
+
+ ASSERT_EQ(result->data()->size(), expected_body_length);
+ ASSERT_TRUE(tensor.Equals(*result));
+ }
+
+ protected:
+ std::shared_ptr<io::MemoryMappedFile> mmap_;
+ std::unique_ptr<TemporaryDir> temp_dir_;
+};
+
+TEST_F(TestTensorRoundTrip, BasicRoundtrip) {
+ std::string path = "test-write-tensor";
+ constexpr int64_t kBufferSize = 1 << 20;
+ ASSERT_OK_AND_ASSIGN(mmap_, io::MemoryMapFixture::InitMemoryMap(kBufferSize, path));
+
+ std::vector<int64_t> shape = {4, 6};
+ std::vector<int64_t> strides = {48, 8};
+ std::vector<std::string> dim_names = {"foo", "bar"};
+ int64_t size = 24;
+
+ std::vector<int64_t> values;
+ randint(size, 0, 100, &values);
+
+ auto data = Buffer::Wrap(values);
+
+ Tensor t0(int64(), data, shape, strides, dim_names);
+ Tensor t_no_dims(int64(), data, {}, {}, {});
+ Tensor t_zero_length_dim(int64(), data, {0}, {8}, {"foo"});
+
+ CheckTensorRoundTrip(t0);
+ CheckTensorRoundTrip(t_no_dims);
+ CheckTensorRoundTrip(t_zero_length_dim);
+
+ int64_t serialized_size;
+ ASSERT_OK(GetTensorSize(t0, &serialized_size));
+ ASSERT_TRUE(serialized_size > static_cast<int64_t>(size * sizeof(int64_t)));
+
+ // ARROW-2840: Check that padding/alignment minded
+ std::vector<int64_t> shape_2 = {1, 1};
+ std::vector<int64_t> strides_2 = {8, 8};
+ Tensor t0_not_multiple_64(int64(), data, shape_2, strides_2, dim_names);
+ CheckTensorRoundTrip(t0_not_multiple_64);
+}
+
+TEST_F(TestTensorRoundTrip, NonContiguous) {
+ std::string path = "test-write-tensor-strided";
+ constexpr int64_t kBufferSize = 1 << 20;
+ ASSERT_OK_AND_ASSIGN(mmap_, io::MemoryMapFixture::InitMemoryMap(kBufferSize, path));
+
+ std::vector<int64_t> values;
+ randint(24, 0, 100, &values);
+
+ auto data = Buffer::Wrap(values);
+ Tensor tensor(int64(), data, {4, 3}, {48, 16});
+
+ CheckTensorRoundTrip(tensor);
+}
+
+template <typename IndexValueType>
+class TestSparseTensorRoundTrip : public BaseTensorTest {
+ public:
+ void CheckSparseCOOTensorRoundTrip(const SparseCOOTensor& sparse_tensor) {
+ const int elem_size = GetByteWidth(*sparse_tensor.type());
+ const int index_elem_size = sizeof(typename IndexValueType::c_type);
+
+ int32_t metadata_length;
+ int64_t body_length;
+
+ ASSERT_OK(mmap_->Seek(0));
+
+ ASSERT_OK(
+ WriteSparseTensor(sparse_tensor, mmap_.get(), &metadata_length, &body_length));
+
+ const auto& sparse_index =
+ checked_cast<const SparseCOOIndex&>(*sparse_tensor.sparse_index());
+ const int64_t indices_length =
+ BitUtil::RoundUpToMultipleOf8(index_elem_size * sparse_index.indices()->size());
+ const int64_t data_length =
+ BitUtil::RoundUpToMultipleOf8(elem_size * sparse_tensor.non_zero_length());
+ const int64_t expected_body_length = indices_length + data_length;
+ ASSERT_EQ(expected_body_length, body_length);
+
+ ASSERT_OK(mmap_->Seek(0));
+
+ std::shared_ptr<SparseTensor> result;
+ ASSERT_OK_AND_ASSIGN(result, ReadSparseTensor(mmap_.get()));
+ ASSERT_EQ(SparseTensorFormat::COO, result->format_id());
+
+ const auto& resulted_sparse_index =
+ checked_cast<const SparseCOOIndex&>(*result->sparse_index());
+ ASSERT_EQ(resulted_sparse_index.indices()->data()->size(), indices_length);
+ ASSERT_EQ(resulted_sparse_index.is_canonical(), sparse_index.is_canonical());
+ ASSERT_EQ(result->data()->size(), data_length);
+ ASSERT_TRUE(result->Equals(sparse_tensor));
+ }
+
+ template <typename SparseIndexType>
+ void CheckSparseCSXMatrixRoundTrip(
+ const SparseTensorImpl<SparseIndexType>& sparse_tensor) {
+ static_assert(std::is_same<SparseIndexType, SparseCSRIndex>::value ||
+ std::is_same<SparseIndexType, SparseCSCIndex>::value,
+ "SparseIndexType must be either SparseCSRIndex or SparseCSCIndex");
+
+ const int elem_size = GetByteWidth(*sparse_tensor.type());
+ const int index_elem_size = sizeof(typename IndexValueType::c_type);
+
+ int32_t metadata_length;
+ int64_t body_length;
+
+ ASSERT_OK(mmap_->Seek(0));
+
+ ASSERT_OK(
+ WriteSparseTensor(sparse_tensor, mmap_.get(), &metadata_length, &body_length));
+
+ const auto& sparse_index =
+ checked_cast<const SparseIndexType&>(*sparse_tensor.sparse_index());
+ const int64_t indptr_length =
+ BitUtil::RoundUpToMultipleOf8(index_elem_size * sparse_index.indptr()->size());
+ const int64_t indices_length =
+ BitUtil::RoundUpToMultipleOf8(index_elem_size * sparse_index.indices()->size());
+ const int64_t data_length =
+ BitUtil::RoundUpToMultipleOf8(elem_size * sparse_tensor.non_zero_length());
+ const int64_t expected_body_length = indptr_length + indices_length + data_length;
+ ASSERT_EQ(expected_body_length, body_length);
+
+ ASSERT_OK(mmap_->Seek(0));
+
+ std::shared_ptr<SparseTensor> result;
+ ASSERT_OK_AND_ASSIGN(result, ReadSparseTensor(mmap_.get()));
+
+ constexpr auto expected_format_id =
+ std::is_same<SparseIndexType, SparseCSRIndex>::value ? SparseTensorFormat::CSR
+ : SparseTensorFormat::CSC;
+ ASSERT_EQ(expected_format_id, result->format_id());
+
+ const auto& resulted_sparse_index =
+ checked_cast<const SparseIndexType&>(*result->sparse_index());
+ ASSERT_EQ(resulted_sparse_index.indptr()->data()->size(), indptr_length);
+ ASSERT_EQ(resulted_sparse_index.indices()->data()->size(), indices_length);
+ ASSERT_EQ(result->data()->size(), data_length);
+ ASSERT_TRUE(result->Equals(sparse_tensor));
+ }
+
+ void CheckSparseCSFTensorRoundTrip(const SparseCSFTensor& sparse_tensor) {
+ const int elem_size = GetByteWidth(*sparse_tensor.type());
+ const int index_elem_size = sizeof(typename IndexValueType::c_type);
+
+ int32_t metadata_length;
+ int64_t body_length;
+
+ ASSERT_OK(mmap_->Seek(0));
+
+ ASSERT_OK(
+ WriteSparseTensor(sparse_tensor, mmap_.get(), &metadata_length, &body_length));
+
+ const auto& sparse_index =
+ checked_cast<const SparseCSFIndex&>(*sparse_tensor.sparse_index());
+
+ const int64_t ndim = sparse_index.axis_order().size();
+ int64_t indptr_length = 0;
+ int64_t indices_length = 0;
+
+ for (int64_t i = 0; i < ndim - 1; ++i) {
+ indptr_length += BitUtil::RoundUpToMultipleOf8(index_elem_size *
+ sparse_index.indptr()[i]->size());
+ }
+ for (int64_t i = 0; i < ndim; ++i) {
+ indices_length += BitUtil::RoundUpToMultipleOf8(index_elem_size *
+ sparse_index.indices()[i]->size());
+ }
+ const int64_t data_length =
+ BitUtil::RoundUpToMultipleOf8(elem_size * sparse_tensor.non_zero_length());
+ const int64_t expected_body_length = indptr_length + indices_length + data_length;
+ ASSERT_EQ(expected_body_length, body_length);
+
+ ASSERT_OK(mmap_->Seek(0));
+
+ std::shared_ptr<SparseTensor> result;
+ ASSERT_OK_AND_ASSIGN(result, ReadSparseTensor(mmap_.get()));
+ ASSERT_EQ(SparseTensorFormat::CSF, result->format_id());
+
+ const auto& resulted_sparse_index =
+ checked_cast<const SparseCSFIndex&>(*result->sparse_index());
+
+ int64_t out_indptr_length = 0;
+ int64_t out_indices_length = 0;
+ for (int i = 0; i < ndim - 1; ++i) {
+ out_indptr_length += BitUtil::RoundUpToMultipleOf8(
+ index_elem_size * resulted_sparse_index.indptr()[i]->size());
+ }
+ for (int i = 0; i < ndim; ++i) {
+ out_indices_length += BitUtil::RoundUpToMultipleOf8(
+ index_elem_size * resulted_sparse_index.indices()[i]->size());
+ }
+
+ ASSERT_EQ(out_indptr_length, indptr_length);
+ ASSERT_EQ(out_indices_length, indices_length);
+ ASSERT_EQ(result->data()->size(), data_length);
+ ASSERT_TRUE(resulted_sparse_index.Equals(sparse_index));
+ ASSERT_TRUE(result->Equals(sparse_tensor));
+ }
+
+ protected:
+ std::shared_ptr<SparseCOOIndex> MakeSparseCOOIndex(
+ const std::vector<int64_t>& coords_shape,
+ const std::vector<int64_t>& coords_strides,
+ std::vector<typename IndexValueType::c_type>& coords_values) const {
+ auto coords_data = Buffer::Wrap(coords_values);
+ auto coords = std::make_shared<NumericTensor<IndexValueType>>(
+ coords_data, coords_shape, coords_strides);
+ return std::make_shared<SparseCOOIndex>(coords);
+ }
+
+ template <typename ValueType>
+ Result<std::shared_ptr<SparseCOOTensor>> MakeSparseCOOTensor(
+ const std::shared_ptr<SparseCOOIndex>& si, std::vector<ValueType>& sparse_values,
+ const std::vector<int64_t>& shape,
+ const std::vector<std::string>& dim_names = {}) const {
+ auto data = Buffer::Wrap(sparse_values);
+ return SparseCOOTensor::Make(si, CTypeTraits<ValueType>::type_singleton(), data,
+ shape, dim_names);
+ }
+};
+
+TYPED_TEST_SUITE_P(TestSparseTensorRoundTrip);
+
+TYPED_TEST_P(TestSparseTensorRoundTrip, WithSparseCOOIndexRowMajor) {
+ using IndexValueType = TypeParam;
+ using c_index_value_type = typename IndexValueType::c_type;
+
+ std::string path = "test-write-sparse-coo-tensor";
+ constexpr int64_t kBufferSize = 1 << 20;
+ ASSERT_OK_AND_ASSIGN(this->mmap_,
+ io::MemoryMapFixture::InitMemoryMap(kBufferSize, path));
+
+ // Dense representation:
+ // [
+ // [
+ // 1 0 2 0
+ // 0 3 0 4
+ // 5 0 6 0
+ // ],
+ // [
+ // 0 11 0 12
+ // 13 0 14 0
+ // 0 15 0 16
+ // ]
+ // ]
+ //
+ // Sparse representation:
+ // idx[0] = [0 0 0 0 0 0 1 1 1 1 1 1]
+ // idx[1] = [0 0 1 1 2 2 0 0 1 1 2 2]
+ // idx[2] = [0 2 1 3 0 2 1 3 0 2 1 3]
+ // data = [1 2 3 4 5 6 11 12 13 14 15 16]
+
+ // canonical
+ std::vector<c_index_value_type> coords_values = {0, 0, 0, 0, 0, 2, 0, 1, 1, 0, 1, 3,
+ 0, 2, 0, 0, 2, 2, 1, 0, 1, 1, 0, 3,
+ 1, 1, 0, 1, 1, 2, 1, 2, 1, 1, 2, 3};
+ const int sizeof_index_value = sizeof(c_index_value_type);
+ std::shared_ptr<SparseCOOIndex> si;
+ ASSERT_OK_AND_ASSIGN(
+ si, SparseCOOIndex::Make(TypeTraits<IndexValueType>::type_singleton(), {12, 3},
+ {sizeof_index_value * 3, sizeof_index_value},
+ Buffer::Wrap(coords_values)));
+ ASSERT_TRUE(si->is_canonical());
+
+ std::vector<int64_t> shape = {2, 3, 4};
+ std::vector<std::string> dim_names = {"foo", "bar", "baz"};
+ std::vector<int64_t> values = {1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16};
+ std::shared_ptr<SparseCOOTensor> st;
+ ASSERT_OK_AND_ASSIGN(st, this->MakeSparseCOOTensor(si, values, shape, dim_names));
+
+ this->CheckSparseCOOTensorRoundTrip(*st);
+
+ // non-canonical
+ ASSERT_OK_AND_ASSIGN(
+ si, SparseCOOIndex::Make(TypeTraits<IndexValueType>::type_singleton(), {12, 3},
+ {sizeof_index_value * 3, sizeof_index_value},
+ Buffer::Wrap(coords_values), false));
+ ASSERT_FALSE(si->is_canonical());
+ ASSERT_OK_AND_ASSIGN(st, this->MakeSparseCOOTensor(si, values, shape, dim_names));
+
+ this->CheckSparseCOOTensorRoundTrip(*st);
+}
+
+TYPED_TEST_P(TestSparseTensorRoundTrip, WithSparseCOOIndexColumnMajor) {
+ using IndexValueType = TypeParam;
+ using c_index_value_type = typename IndexValueType::c_type;
+
+ std::string path = "test-write-sparse-coo-tensor";
+ constexpr int64_t kBufferSize = 1 << 20;
+ ASSERT_OK_AND_ASSIGN(this->mmap_,
+ io::MemoryMapFixture::InitMemoryMap(kBufferSize, path));
+
+ // Dense representation:
+ // [
+ // [
+ // 1 0 2 0
+ // 0 3 0 4
+ // 5 0 6 0
+ // ],
+ // [
+ // 0 11 0 12
+ // 13 0 14 0
+ // 0 15 0 16
+ // ]
+ // ]
+ //
+ // Sparse representation:
+ // idx[0] = [0 0 0 0 0 0 1 1 1 1 1 1]
+ // idx[1] = [0 0 1 1 2 2 0 0 1 1 2 2]
+ // idx[2] = [0 2 1 3 0 2 1 3 0 2 1 3]
+ // data = [1 2 3 4 5 6 11 12 13 14 15 16]
+
+ // canonical
+ std::vector<c_index_value_type> coords_values = {0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
+ 0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2,
+ 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3};
+ const int sizeof_index_value = sizeof(c_index_value_type);
+ std::shared_ptr<SparseCOOIndex> si;
+ ASSERT_OK_AND_ASSIGN(
+ si, SparseCOOIndex::Make(TypeTraits<IndexValueType>::type_singleton(), {12, 3},
+ {sizeof_index_value, sizeof_index_value * 12},
+ Buffer::Wrap(coords_values)));
+ ASSERT_TRUE(si->is_canonical());
+
+ std::vector<int64_t> shape = {2, 3, 4};
+ std::vector<std::string> dim_names = {"foo", "bar", "baz"};
+ std::vector<int64_t> values = {1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16};
+
+ std::shared_ptr<SparseCOOTensor> st;
+ ASSERT_OK_AND_ASSIGN(st, this->MakeSparseCOOTensor(si, values, shape, dim_names));
+
+ this->CheckSparseCOOTensorRoundTrip(*st);
+
+ // non-canonical
+ ASSERT_OK_AND_ASSIGN(
+ si, SparseCOOIndex::Make(TypeTraits<IndexValueType>::type_singleton(), {12, 3},
+ {sizeof_index_value, sizeof_index_value * 12},
+ Buffer::Wrap(coords_values), false));
+ ASSERT_FALSE(si->is_canonical());
+ ASSERT_OK_AND_ASSIGN(st, this->MakeSparseCOOTensor(si, values, shape, dim_names));
+
+ this->CheckSparseCOOTensorRoundTrip(*st);
+}
+
+TYPED_TEST_P(TestSparseTensorRoundTrip, WithSparseCSRIndex) {
+ using IndexValueType = TypeParam;
+
+ std::string path = "test-write-sparse-csr-matrix";
+ constexpr int64_t kBufferSize = 1 << 20;
+ ASSERT_OK_AND_ASSIGN(this->mmap_,
+ io::MemoryMapFixture::InitMemoryMap(kBufferSize, path));
+
+ std::vector<int64_t> shape = {4, 6};
+ std::vector<std::string> dim_names = {"foo", "bar", "baz"};
+ std::vector<int64_t> values = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
+ 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};
+
+ auto data = Buffer::Wrap(values);
+ NumericTensor<Int64Type> t(data, shape, {}, dim_names);
+ std::shared_ptr<SparseCSRMatrix> st;
+ ASSERT_OK_AND_ASSIGN(
+ st, SparseCSRMatrix::Make(t, TypeTraits<IndexValueType>::type_singleton()));
+
+ this->CheckSparseCSXMatrixRoundTrip(*st);
+}
+
+TYPED_TEST_P(TestSparseTensorRoundTrip, WithSparseCSCIndex) {
+ using IndexValueType = TypeParam;
+
+ std::string path = "test-write-sparse-csc-matrix";
+ constexpr int64_t kBufferSize = 1 << 20;
+ ASSERT_OK_AND_ASSIGN(this->mmap_,
+ io::MemoryMapFixture::InitMemoryMap(kBufferSize, path));
+
+ std::vector<int64_t> shape = {4, 6};
+ std::vector<std::string> dim_names = {"foo", "bar", "baz"};
+ std::vector<int64_t> values = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
+ 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};
+
+ auto data = Buffer::Wrap(values);
+ NumericTensor<Int64Type> t(data, shape, {}, dim_names);
+ std::shared_ptr<SparseCSCMatrix> st;
+ ASSERT_OK_AND_ASSIGN(
+ st, SparseCSCMatrix::Make(t, TypeTraits<IndexValueType>::type_singleton()));
+
+ this->CheckSparseCSXMatrixRoundTrip(*st);
+}
+
+TYPED_TEST_P(TestSparseTensorRoundTrip, WithSparseCSFIndex) {
+ using IndexValueType = TypeParam;
+
+ std::string path = "test-write-sparse-csf-tensor";
+ constexpr int64_t kBufferSize = 1 << 20;
+ ASSERT_OK_AND_ASSIGN(this->mmap_,
+ io::MemoryMapFixture::InitMemoryMap(kBufferSize, path));
+
+ std::vector<int64_t> shape = {4, 6};
+ std::vector<std::string> dim_names = {"foo", "bar", "baz"};
+ std::vector<int64_t> values = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
+ 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};
+
+ auto data = Buffer::Wrap(values);
+ NumericTensor<Int64Type> t(data, shape, {}, dim_names);
+ std::shared_ptr<SparseCSFTensor> st;
+ ASSERT_OK_AND_ASSIGN(
+ st, SparseCSFTensor::Make(t, TypeTraits<IndexValueType>::type_singleton()));
+
+ this->CheckSparseCSFTensorRoundTrip(*st);
+}
+REGISTER_TYPED_TEST_SUITE_P(TestSparseTensorRoundTrip, WithSparseCOOIndexRowMajor,
+ WithSparseCOOIndexColumnMajor, WithSparseCSRIndex,
+ WithSparseCSCIndex, WithSparseCSFIndex);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt8, TestSparseTensorRoundTrip, Int8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt8, TestSparseTensorRoundTrip, UInt8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt16, TestSparseTensorRoundTrip, Int16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt16, TestSparseTensorRoundTrip, UInt16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt32, TestSparseTensorRoundTrip, Int32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt32, TestSparseTensorRoundTrip, UInt32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt64, TestSparseTensorRoundTrip, Int64Type);
+
+} // namespace test
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/test_common.cc b/src/arrow/cpp/src/arrow/ipc/test_common.cc
new file mode 100644
index 000000000..5068eca00
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/test_common.cc
@@ -0,0 +1,1125 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/builder_time.h"
+#include "arrow/ipc/test_common.h"
+#include "arrow/pretty_print.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/tensor.h"
+#include "arrow/testing/extension_type.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_builders.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace ipc {
+namespace test {
+
+void CompareArraysDetailed(int index, const Array& result, const Array& expected) {
+ if (!expected.Equals(result)) {
+ std::stringstream pp_result;
+ std::stringstream pp_expected;
+
+ ASSERT_OK(PrettyPrint(expected, 0, &pp_expected));
+ ASSERT_OK(PrettyPrint(result, 0, &pp_result));
+
+ FAIL() << "Index: " << index << " Expected: " << pp_expected.str()
+ << "\nGot: " << pp_result.str();
+ }
+}
+
+void CompareBatchColumnsDetailed(const RecordBatch& result, const RecordBatch& expected) {
+ for (int i = 0; i < expected.num_columns(); ++i) {
+ auto left = result.column(i);
+ auto right = expected.column(i);
+ CompareArraysDetailed(i, *left, *right);
+ }
+}
+
+Status MakeRandomInt32Array(int64_t length, bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out, uint32_t seed) {
+ random::RandomArrayGenerator rand(seed);
+ const double null_probability = include_nulls ? 0.5 : 0.0;
+
+ *out = rand.Int32(length, 0, 1000, null_probability);
+
+ return Status::OK();
+}
+
+namespace {
+
+template <typename ArrayType>
+Status MakeRandomArray(int64_t length, bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out, uint32_t seed) {
+ random::RandomArrayGenerator rand(seed);
+ const double null_probability = include_nulls ? 0.5 : 0.0;
+
+ *out = rand.Numeric<ArrayType>(length, 0, 1000, null_probability);
+
+ return Status::OK();
+}
+
+template <>
+Status MakeRandomArray<Int8Type>(int64_t length, bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out, uint32_t seed) {
+ random::RandomArrayGenerator rand(seed);
+ const double null_probability = include_nulls ? 0.5 : 0.0;
+
+ *out = rand.Numeric<Int8Type>(length, 0, 127, null_probability);
+
+ return Status::OK();
+}
+
+template <>
+Status MakeRandomArray<UInt8Type>(int64_t length, bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out, uint32_t seed) {
+ random::RandomArrayGenerator rand(seed);
+ const double null_probability = include_nulls ? 0.5 : 0.0;
+
+ *out = rand.Numeric<UInt8Type>(length, 0, 127, null_probability);
+
+ return Status::OK();
+}
+
+template <typename TypeClass>
+Status MakeListArray(const std::shared_ptr<Array>& child_array, int num_lists,
+ bool include_nulls, MemoryPool* pool, std::shared_ptr<Array>* out) {
+ using offset_type = typename TypeClass::offset_type;
+ using ArrayType = typename TypeTraits<TypeClass>::ArrayType;
+
+ // Create the null list values
+ std::vector<uint8_t> valid_lists(num_lists);
+ const double null_percent = include_nulls ? 0.1 : 0;
+ random_null_bytes(num_lists, null_percent, valid_lists.data());
+
+ // Create list offsets
+ const int max_list_size = 10;
+
+ std::vector<offset_type> list_sizes(num_lists, 0);
+ std::vector<offset_type> offsets(
+ num_lists + 1, 0); // +1 so we can shift for nulls. See partial sum below.
+ const auto seed = static_cast<uint32_t>(child_array->length());
+
+ if (num_lists > 0) {
+ rand_uniform_int(num_lists, seed, 0, max_list_size, list_sizes.data());
+ // make sure sizes are consistent with null
+ std::transform(list_sizes.begin(), list_sizes.end(), valid_lists.begin(),
+ list_sizes.begin(),
+ [](offset_type size, uint8_t valid) { return valid == 0 ? 0 : size; });
+ std::partial_sum(list_sizes.begin(), list_sizes.end(), ++offsets.begin());
+
+ // Force invariants
+ const auto child_length = static_cast<offset_type>(child_array->length());
+ offsets[0] = 0;
+ std::replace_if(
+ offsets.begin(), offsets.end(),
+ [child_length](offset_type offset) { return offset > child_length; },
+ child_length);
+ }
+
+ offsets[num_lists] = static_cast<offset_type>(child_array->length());
+
+ /// TODO(wesm): Implement support for nulls in ListArray::FromArrays
+ std::shared_ptr<Buffer> null_bitmap, offsets_buffer;
+ RETURN_NOT_OK(GetBitmapFromVector(valid_lists, &null_bitmap));
+ RETURN_NOT_OK(CopyBufferFromVector(offsets, pool, &offsets_buffer));
+
+ *out = std::make_shared<ArrayType>(std::make_shared<TypeClass>(child_array->type()),
+ num_lists, offsets_buffer, child_array, null_bitmap,
+ kUnknownNullCount);
+ return (**out).Validate();
+}
+
+} // namespace
+
+Status MakeRandomListArray(const std::shared_ptr<Array>& child_array, int num_lists,
+ bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out) {
+ return MakeListArray<ListType>(child_array, num_lists, include_nulls, pool, out);
+}
+
+Status MakeRandomLargeListArray(const std::shared_ptr<Array>& child_array, int num_lists,
+ bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out) {
+ return MakeListArray<LargeListType>(child_array, num_lists, include_nulls, pool, out);
+}
+
+Status MakeRandomMapArray(const std::shared_ptr<Array>& key_array,
+ const std::shared_ptr<Array>& item_array, int num_maps,
+ bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out) {
+ auto pair_type = struct_(
+ {field("key", key_array->type(), false), field("value", item_array->type())});
+
+ auto pair_array = std::make_shared<StructArray>(pair_type, key_array->length(),
+ ArrayVector{key_array, item_array});
+
+ RETURN_NOT_OK(MakeRandomListArray(pair_array, num_maps, include_nulls, pool, out));
+ auto map_data = (*out)->data();
+ map_data->type = map(key_array->type(), item_array->type());
+ out->reset(new MapArray(map_data));
+ return (**out).Validate();
+}
+
+Status MakeRandomBooleanArray(const int length, bool include_nulls,
+ std::shared_ptr<Array>* out) {
+ std::vector<uint8_t> values(length);
+ random_null_bytes(length, 0.5, values.data());
+ ARROW_ASSIGN_OR_RAISE(auto data, internal::BytesToBits(values));
+
+ if (include_nulls) {
+ std::vector<uint8_t> valid_bytes(length);
+ ARROW_ASSIGN_OR_RAISE(auto null_bitmap, internal::BytesToBits(valid_bytes));
+ random_null_bytes(length, 0.1, valid_bytes.data());
+ *out = std::make_shared<BooleanArray>(length, data, null_bitmap, -1);
+ } else {
+ *out = std::make_shared<BooleanArray>(length, data, NULLPTR, 0);
+ }
+ return Status::OK();
+}
+
+Status MakeBooleanBatchSized(const int length, std::shared_ptr<RecordBatch>* out) {
+ // Make the schema
+ auto f0 = field("f0", boolean());
+ auto f1 = field("f1", boolean());
+ auto schema = ::arrow::schema({f0, f1});
+
+ std::shared_ptr<Array> a0, a1;
+ RETURN_NOT_OK(MakeRandomBooleanArray(length, true, &a0));
+ RETURN_NOT_OK(MakeRandomBooleanArray(length, false, &a1));
+ *out = RecordBatch::Make(schema, length, {a0, a1});
+ return Status::OK();
+}
+
+Status MakeBooleanBatch(std::shared_ptr<RecordBatch>* out) {
+ return MakeBooleanBatchSized(1000, out);
+}
+
+Status MakeIntBatchSized(int length, std::shared_ptr<RecordBatch>* out, uint32_t seed) {
+ // Make the schema
+ auto f0 = field("f0", int8());
+ auto f1 = field("f1", uint8());
+ auto f2 = field("f2", int16());
+ auto f3 = field("f3", uint16());
+ auto f4 = field("f4", int32());
+ auto f5 = field("f5", uint32());
+ auto f6 = field("f6", int64());
+ auto f7 = field("f7", uint64());
+ auto schema = ::arrow::schema({f0, f1, f2, f3, f4, f5, f6, f7});
+
+ // Example data
+ std::shared_ptr<Array> a0, a1, a2, a3, a4, a5, a6, a7;
+ MemoryPool* pool = default_memory_pool();
+ RETURN_NOT_OK(MakeRandomArray<Int8Type>(length, false, pool, &a0, seed));
+ RETURN_NOT_OK(MakeRandomArray<UInt8Type>(length, true, pool, &a1, seed));
+ RETURN_NOT_OK(MakeRandomArray<Int16Type>(length, true, pool, &a2, seed));
+ RETURN_NOT_OK(MakeRandomArray<UInt16Type>(length, false, pool, &a3, seed));
+ RETURN_NOT_OK(MakeRandomArray<Int32Type>(length, false, pool, &a4, seed));
+ RETURN_NOT_OK(MakeRandomArray<UInt32Type>(length, true, pool, &a5, seed));
+ RETURN_NOT_OK(MakeRandomArray<Int64Type>(length, true, pool, &a6, seed));
+ RETURN_NOT_OK(MakeRandomArray<UInt64Type>(length, false, pool, &a7, seed));
+ *out = RecordBatch::Make(schema, length, {a0, a1, a2, a3, a4, a5, a6, a7});
+ return Status::OK();
+}
+
+Status MakeIntRecordBatch(std::shared_ptr<RecordBatch>* out) {
+ return MakeIntBatchSized(10, out);
+}
+
+Status MakeFloat3264BatchSized(int length, std::shared_ptr<RecordBatch>* out,
+ uint32_t seed) {
+ // Make the schema
+ auto f0 = field("f0", float32());
+ auto f1 = field("f1", float64());
+ auto schema = ::arrow::schema({f0, f1});
+
+ // Example data
+ std::shared_ptr<Array> a0, a1;
+ MemoryPool* pool = default_memory_pool();
+ RETURN_NOT_OK(MakeRandomArray<FloatType>(length, false, pool, &a0, seed));
+ RETURN_NOT_OK(MakeRandomArray<DoubleType>(length, true, pool, &a1, seed + 1));
+ *out = RecordBatch::Make(schema, length, {a0, a1});
+ return Status::OK();
+}
+
+Status MakeFloat3264Batch(std::shared_ptr<RecordBatch>* out) {
+ return MakeFloat3264BatchSized(10, out);
+}
+
+Status MakeFloatBatchSized(int length, std::shared_ptr<RecordBatch>* out, uint32_t seed) {
+ // Make the schema
+ auto f0 = field("f0", float16());
+ auto f1 = field("f1", float32());
+ auto f2 = field("f2", float64());
+ auto schema = ::arrow::schema({f0, f1, f2});
+
+ // Example data
+ std::shared_ptr<Array> a0, a1, a2;
+ MemoryPool* pool = default_memory_pool();
+ RETURN_NOT_OK(MakeRandomArray<HalfFloatType>(length, false, pool, &a0, seed));
+ RETURN_NOT_OK(MakeRandomArray<FloatType>(length, false, pool, &a1, seed + 1));
+ RETURN_NOT_OK(MakeRandomArray<DoubleType>(length, true, pool, &a2, seed + 2));
+ *out = RecordBatch::Make(schema, length, {a0, a1, a2});
+ return Status::OK();
+}
+
+Status MakeFloatBatch(std::shared_ptr<RecordBatch>* out) {
+ return MakeFloatBatchSized(10, out);
+}
+
+Status MakeRandomStringArray(int64_t length, bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out) {
+ const std::vector<std::string> values = {"", "", "abc", "123",
+ "efg", "456!@#!@#", "12312"};
+ StringBuilder builder(pool);
+ const size_t values_len = values.size();
+ for (int64_t i = 0; i < length; ++i) {
+ int64_t values_index = i % values_len;
+ if (include_nulls && values_index == 0) {
+ RETURN_NOT_OK(builder.AppendNull());
+ } else {
+ const auto& value = values[values_index];
+ RETURN_NOT_OK(builder.Append(value));
+ }
+ }
+ return builder.Finish(out);
+}
+
+template <class BuilderType>
+static Status MakeBinaryArrayWithUniqueValues(int64_t length, bool include_nulls,
+ MemoryPool* pool,
+ std::shared_ptr<Array>* out) {
+ BuilderType builder(pool);
+ for (int64_t i = 0; i < length; ++i) {
+ if (include_nulls && (i % 7 == 0)) {
+ RETURN_NOT_OK(builder.AppendNull());
+ } else {
+ RETURN_NOT_OK(builder.Append(std::to_string(i)));
+ }
+ }
+ return builder.Finish(out);
+}
+
+Status MakeStringTypesRecordBatch(std::shared_ptr<RecordBatch>* out, bool with_nulls) {
+ const int64_t length = 500;
+ auto f0 = field("strings", utf8());
+ auto f1 = field("binaries", binary());
+ auto f2 = field("large_strings", large_utf8());
+ auto f3 = field("large_binaries", large_binary());
+ auto schema = ::arrow::schema({f0, f1, f2, f3});
+
+ std::shared_ptr<Array> a0, a1, a2, a3;
+ MemoryPool* pool = default_memory_pool();
+
+ // Quirk with RETURN_NOT_OK macro and templated functions
+ {
+ auto s =
+ MakeBinaryArrayWithUniqueValues<StringBuilder>(length, with_nulls, pool, &a0);
+ RETURN_NOT_OK(s);
+ }
+ {
+ auto s =
+ MakeBinaryArrayWithUniqueValues<BinaryBuilder>(length, with_nulls, pool, &a1);
+ RETURN_NOT_OK(s);
+ }
+ {
+ auto s = MakeBinaryArrayWithUniqueValues<LargeStringBuilder>(length, with_nulls, pool,
+ &a2);
+ RETURN_NOT_OK(s);
+ }
+ {
+ auto s = MakeBinaryArrayWithUniqueValues<LargeBinaryBuilder>(length, with_nulls, pool,
+ &a3);
+ RETURN_NOT_OK(s);
+ }
+ *out = RecordBatch::Make(schema, length, {a0, a1, a2, a3});
+ return Status::OK();
+}
+
+Status MakeStringTypesRecordBatchWithNulls(std::shared_ptr<RecordBatch>* out) {
+ return MakeStringTypesRecordBatch(out, true);
+}
+
+Status MakeNullRecordBatch(std::shared_ptr<RecordBatch>* out) {
+ const int64_t length = 500;
+ auto f0 = field("f0", null());
+ auto schema = ::arrow::schema({f0});
+ std::shared_ptr<Array> a0 = std::make_shared<NullArray>(length);
+ *out = RecordBatch::Make(schema, length, {a0});
+ return Status::OK();
+}
+
+Status MakeListRecordBatch(std::shared_ptr<RecordBatch>* out) {
+ // Make the schema
+ auto f0 = field("f0", list(int32()));
+ auto f1 = field("f1", list(list(int32())));
+ auto f2 = field("f2", large_list(int32()));
+ auto schema = ::arrow::schema({f0, f1, f2});
+
+ // Example data
+
+ MemoryPool* pool = default_memory_pool();
+ const int length = 200;
+ std::shared_ptr<Array> leaf_values, list_array, list_list_array, large_list_array;
+ const bool include_nulls = true;
+ RETURN_NOT_OK(MakeRandomInt32Array(1000, include_nulls, pool, &leaf_values));
+ RETURN_NOT_OK(
+ MakeRandomListArray(leaf_values, length, include_nulls, pool, &list_array));
+ RETURN_NOT_OK(
+ MakeRandomListArray(list_array, length, include_nulls, pool, &list_list_array));
+ RETURN_NOT_OK(MakeRandomLargeListArray(leaf_values, length, include_nulls, pool,
+ &large_list_array));
+ *out =
+ RecordBatch::Make(schema, length, {list_array, list_list_array, large_list_array});
+ return Status::OK();
+}
+
+Status MakeFixedSizeListRecordBatch(std::shared_ptr<RecordBatch>* out) {
+ // Make the schema
+ auto f0 = field("f0", fixed_size_list(int32(), 1));
+ auto f1 = field("f1", fixed_size_list(list(int32()), 3));
+ auto f2 = field("f2", int32());
+ auto schema = ::arrow::schema({f0, f1, f2});
+
+ // Example data
+
+ MemoryPool* pool = default_memory_pool();
+ const int length = 200;
+ std::shared_ptr<Array> leaf_values, list_array, list_list_array, flat_array;
+ const bool include_nulls = true;
+ RETURN_NOT_OK(MakeRandomInt32Array(1000, include_nulls, pool, &leaf_values));
+ RETURN_NOT_OK(
+ MakeRandomListArray(leaf_values, length * 3, include_nulls, pool, &list_array));
+ list_list_array = std::make_shared<FixedSizeListArray>(f1->type(), length, list_array);
+ list_array = std::make_shared<FixedSizeListArray>(f0->type(), length,
+ leaf_values->Slice(0, length));
+ RETURN_NOT_OK(MakeRandomInt32Array(length, include_nulls, pool, &flat_array));
+ *out = RecordBatch::Make(schema, length, {list_array, list_list_array, flat_array});
+ return Status::OK();
+}
+
+Status MakeZeroLengthRecordBatch(std::shared_ptr<RecordBatch>* out) {
+ // Make the schema
+ auto f0 = field("f0", list(int32()));
+ auto f1 = field("f1", list(list(int32())));
+ auto f2 = field("f2", int32());
+ auto schema = ::arrow::schema({f0, f1, f2});
+
+ // Example data
+ MemoryPool* pool = default_memory_pool();
+ const bool include_nulls = true;
+ std::shared_ptr<Array> leaf_values, list_array, list_list_array, flat_array;
+ RETURN_NOT_OK(MakeRandomInt32Array(0, include_nulls, pool, &leaf_values));
+ RETURN_NOT_OK(MakeRandomListArray(leaf_values, 0, include_nulls, pool, &list_array));
+ RETURN_NOT_OK(
+ MakeRandomListArray(list_array, 0, include_nulls, pool, &list_list_array));
+ RETURN_NOT_OK(MakeRandomInt32Array(0, include_nulls, pool, &flat_array));
+ *out = RecordBatch::Make(schema, 0, {list_array, list_list_array, flat_array});
+ return Status::OK();
+}
+
+Status MakeNonNullRecordBatch(std::shared_ptr<RecordBatch>* out) {
+ // Make the schema
+ auto f0 = field("f0", list(int32()));
+ auto f1 = field("f1", list(list(int32())));
+ auto f2 = field("f2", int32());
+ auto schema = ::arrow::schema({f0, f1, f2});
+
+ // Example data
+ MemoryPool* pool = default_memory_pool();
+ const int length = 50;
+ std::shared_ptr<Array> leaf_values, list_array, list_list_array, flat_array;
+
+ RETURN_NOT_OK(MakeRandomInt32Array(1000, true, pool, &leaf_values));
+ bool include_nulls = false;
+ RETURN_NOT_OK(
+ MakeRandomListArray(leaf_values, length, include_nulls, pool, &list_array));
+ RETURN_NOT_OK(
+ MakeRandomListArray(list_array, length, include_nulls, pool, &list_list_array));
+ RETURN_NOT_OK(MakeRandomInt32Array(length, include_nulls, pool, &flat_array));
+ *out = RecordBatch::Make(schema, length, {list_array, list_list_array, flat_array});
+ return Status::OK();
+}
+
+Status MakeDeeplyNestedList(std::shared_ptr<RecordBatch>* out) {
+ const int batch_length = 5;
+ auto type = int32();
+
+ MemoryPool* pool = default_memory_pool();
+ std::shared_ptr<Array> array;
+ const bool include_nulls = true;
+ RETURN_NOT_OK(MakeRandomInt32Array(1000, include_nulls, pool, &array));
+ for (int i = 0; i < 63; ++i) {
+ type = std::static_pointer_cast<DataType>(list(type));
+ RETURN_NOT_OK(MakeRandomListArray(array, batch_length, include_nulls, pool, &array));
+ }
+
+ auto f0 = field("f0", type);
+ auto schema = ::arrow::schema({f0});
+ std::vector<std::shared_ptr<Array>> arrays = {array};
+ *out = RecordBatch::Make(schema, batch_length, arrays);
+ return Status::OK();
+}
+
+Status MakeStruct(std::shared_ptr<RecordBatch>* out) {
+ // reuse constructed list columns
+ std::shared_ptr<RecordBatch> list_batch;
+ RETURN_NOT_OK(MakeListRecordBatch(&list_batch));
+ std::vector<std::shared_ptr<Array>> columns = {
+ list_batch->column(0), list_batch->column(1), list_batch->column(2)};
+ auto list_schema = list_batch->schema();
+
+ // Define schema
+ std::shared_ptr<DataType> type(new StructType(
+ {list_schema->field(0), list_schema->field(1), list_schema->field(2)}));
+ auto f0 = field("non_null_struct", type);
+ auto f1 = field("null_struct", type);
+ auto schema = ::arrow::schema({f0, f1});
+
+ // construct individual nullable/non-nullable struct arrays
+ std::shared_ptr<Array> no_nulls(new StructArray(type, list_batch->num_rows(), columns));
+ std::vector<uint8_t> null_bytes(list_batch->num_rows(), 1);
+ null_bytes[0] = 0;
+ ARROW_ASSIGN_OR_RAISE(auto null_bitmap, internal::BytesToBits(null_bytes));
+ std::shared_ptr<Array> with_nulls(
+ new StructArray(type, list_batch->num_rows(), columns, null_bitmap, 1));
+
+ // construct batch
+ std::vector<std::shared_ptr<Array>> arrays = {no_nulls, with_nulls};
+ *out = RecordBatch::Make(schema, list_batch->num_rows(), arrays);
+ return Status::OK();
+}
+
+Status MakeUnion(std::shared_ptr<RecordBatch>* out) {
+ // Define schema
+ std::vector<std::shared_ptr<Field>> union_fields(
+ {field("u0", int32()), field("u1", uint8())});
+
+ std::vector<int8_t> type_codes = {5, 10};
+ auto sparse_type = sparse_union(union_fields, type_codes);
+ auto dense_type = dense_union(union_fields, type_codes);
+
+ auto f0 = field("sparse", sparse_type);
+ auto f1 = field("dense", dense_type);
+
+ auto schema = ::arrow::schema({f0, f1});
+
+ // Create data
+ std::vector<std::shared_ptr<Array>> sparse_children(2);
+ std::vector<std::shared_ptr<Array>> dense_children(2);
+
+ const int64_t length = 7;
+
+ std::shared_ptr<Buffer> type_ids_buffer;
+ std::vector<uint8_t> type_ids = {5, 10, 5, 5, 10, 10, 5};
+ RETURN_NOT_OK(CopyBufferFromVector(type_ids, default_memory_pool(), &type_ids_buffer));
+
+ std::vector<int32_t> u0_values = {0, 1, 2, 3, 4, 5, 6};
+ ArrayFromVector<Int32Type, int32_t>(u0_values, &sparse_children[0]);
+
+ std::vector<uint8_t> u1_values = {10, 11, 12, 13, 14, 15, 16};
+ ArrayFromVector<UInt8Type, uint8_t>(u1_values, &sparse_children[1]);
+
+ // dense children
+ u0_values = {0, 2, 3, 7};
+ ArrayFromVector<Int32Type, int32_t>(u0_values, &dense_children[0]);
+
+ u1_values = {11, 14, 15};
+ ArrayFromVector<UInt8Type, uint8_t>(u1_values, &dense_children[1]);
+
+ std::shared_ptr<Buffer> offsets_buffer;
+ std::vector<int32_t> offsets = {0, 0, 1, 2, 1, 2, 3};
+ RETURN_NOT_OK(CopyBufferFromVector(offsets, default_memory_pool(), &offsets_buffer));
+
+ auto sparse = std::make_shared<SparseUnionArray>(sparse_type, length, sparse_children,
+ type_ids_buffer);
+ auto dense = std::make_shared<DenseUnionArray>(dense_type, length, dense_children,
+ type_ids_buffer, offsets_buffer);
+
+ // construct batch
+ std::vector<std::shared_ptr<Array>> arrays = {sparse, dense};
+ *out = RecordBatch::Make(schema, length, arrays);
+ return Status::OK();
+}
+
+Status MakeDictionary(std::shared_ptr<RecordBatch>* out) {
+ const int64_t length = 6;
+
+ std::vector<bool> is_valid = {true, true, false, true, true, true};
+
+ auto dict_ty = utf8();
+
+ auto dict1 = ArrayFromJSON(dict_ty, "[\"foo\", \"bar\", \"baz\"]");
+ auto dict2 = ArrayFromJSON(dict_ty, "[\"fo\", \"bap\", \"bop\", \"qup\"]");
+
+ auto f0_type = arrow::dictionary(arrow::int32(), dict_ty);
+ auto f1_type = arrow::dictionary(arrow::int8(), dict_ty, true);
+ auto f2_type = arrow::dictionary(arrow::int32(), dict_ty);
+
+ std::shared_ptr<Array> indices0, indices1, indices2;
+ std::vector<int32_t> indices0_values = {1, 2, -1, 0, 2, 0};
+ std::vector<int8_t> indices1_values = {0, 0, 2, 2, 1, 1};
+ std::vector<int32_t> indices2_values = {3, 0, 2, 1, 0, 2};
+
+ ArrayFromVector<Int32Type, int32_t>(is_valid, indices0_values, &indices0);
+ ArrayFromVector<Int8Type, int8_t>(is_valid, indices1_values, &indices1);
+ ArrayFromVector<Int32Type, int32_t>(is_valid, indices2_values, &indices2);
+
+ auto a0 = std::make_shared<DictionaryArray>(f0_type, indices0, dict1);
+ auto a1 = std::make_shared<DictionaryArray>(f1_type, indices1, dict1);
+ auto a2 = std::make_shared<DictionaryArray>(f2_type, indices2, dict2);
+
+ // Lists of dictionary-encoded strings
+ auto f3_type = list(f1_type);
+
+ auto indices3 = ArrayFromJSON(int8(), "[0, 1, 2, 0, 1, 1, 2, 1, 0]");
+ auto offsets3 = ArrayFromJSON(int32(), "[0, 0, 2, 2, 5, 6, 9]");
+
+ std::shared_ptr<Buffer> null_bitmap;
+ RETURN_NOT_OK(GetBitmapFromVector(is_valid, &null_bitmap));
+
+ std::shared_ptr<Array> a3 = std::make_shared<ListArray>(
+ f3_type, length, std::static_pointer_cast<PrimitiveArray>(offsets3)->values(),
+ std::make_shared<DictionaryArray>(f1_type, indices3, dict1), null_bitmap, 1);
+
+ // Dictionary-encoded lists of integers
+ auto dict4_ty = list(int8());
+ auto f4_type = dictionary(int8(), dict4_ty);
+
+ auto indices4 = ArrayFromJSON(int8(), "[0, 1, 2, 0, 2, 2]");
+ auto dict4 = ArrayFromJSON(dict4_ty, "[[44, 55], [], [66]]");
+ auto a4 = std::make_shared<DictionaryArray>(f4_type, indices4, dict4);
+
+ std::vector<std::shared_ptr<Field>> fields = {
+ field("dict1", f0_type), field("dict2", f1_type), field("dict3", f2_type),
+ field("list<encoded utf8>", f3_type), field("encoded list<int8>", f4_type)};
+ std::vector<std::shared_ptr<Array>> arrays = {a0, a1, a2, a3, a4};
+
+ // Ensure all dictionary index types are represented
+ int field_index = 5;
+ for (auto index_ty : all_dictionary_index_types()) {
+ std::stringstream ss;
+ ss << "dict" << field_index++;
+ auto ty = arrow::dictionary(index_ty, dict_ty);
+ auto indices = ArrayFromJSON(index_ty, "[0, 1, 2, 0, 2, 2]");
+ fields.push_back(field(ss.str(), ty));
+ arrays.push_back(std::make_shared<DictionaryArray>(ty, indices, dict1));
+ }
+
+ // construct batch
+ *out = RecordBatch::Make(::arrow::schema(fields), length, arrays);
+ return Status::OK();
+}
+
+Status MakeDictionaryFlat(std::shared_ptr<RecordBatch>* out) {
+ const int64_t length = 6;
+
+ std::vector<bool> is_valid = {true, true, false, true, true, true};
+
+ auto dict_ty = utf8();
+ auto dict1 = ArrayFromJSON(dict_ty, "[\"foo\", \"bar\", \"baz\"]");
+ auto dict2 = ArrayFromJSON(dict_ty, "[\"foo\", \"bar\", \"baz\", \"qux\"]");
+
+ auto f0_type = arrow::dictionary(arrow::int32(), dict_ty);
+ auto f1_type = arrow::dictionary(arrow::int8(), dict_ty);
+ auto f2_type = arrow::dictionary(arrow::int32(), dict_ty);
+
+ std::shared_ptr<Array> indices0, indices1, indices2;
+ std::vector<int32_t> indices0_values = {1, 2, -1, 0, 2, 0};
+ std::vector<int8_t> indices1_values = {0, 0, 2, 2, 1, 1};
+ std::vector<int32_t> indices2_values = {3, 0, 2, 1, 0, 2};
+
+ ArrayFromVector<Int32Type, int32_t>(is_valid, indices0_values, &indices0);
+ ArrayFromVector<Int8Type, int8_t>(is_valid, indices1_values, &indices1);
+ ArrayFromVector<Int32Type, int32_t>(is_valid, indices2_values, &indices2);
+
+ auto a0 = std::make_shared<DictionaryArray>(f0_type, indices0, dict1);
+ auto a1 = std::make_shared<DictionaryArray>(f1_type, indices1, dict1);
+ auto a2 = std::make_shared<DictionaryArray>(f2_type, indices2, dict2);
+
+ // construct batch
+ auto schema = ::arrow::schema(
+ {field("dict1", f0_type), field("dict2", f1_type), field("dict3", f2_type)});
+
+ std::vector<std::shared_ptr<Array>> arrays = {a0, a1, a2};
+ *out = RecordBatch::Make(schema, length, arrays);
+ return Status::OK();
+}
+
+Status MakeNestedDictionary(std::shared_ptr<RecordBatch>* out) {
+ const int64_t length = 7;
+
+ auto values0 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]");
+ auto values1 = ArrayFromJSON(int64(), "[1234567890, 987654321]");
+
+ // NOTE: it is important to test several levels of nesting, with non-trivial
+ // numbers of child fields, to exercise structural mapping of fields to dict ids.
+
+ // Field 0: dict(int32, list(dict(int8, utf8)))
+ ARROW_ASSIGN_OR_RAISE(auto inner0,
+ DictionaryArray::FromArrays(
+ dictionary(int8(), values0->type()),
+ /*indices=*/ArrayFromJSON(int8(), "[0, 1, 2, null, 2, 1, 0]"),
+ /*dictionary=*/values0));
+
+ ARROW_ASSIGN_OR_RAISE(auto nested_values0,
+ ListArray::FromArrays(
+ /*offsets=*/*ArrayFromJSON(int32(), "[0, 3, 3, 6, 7]"),
+ /*values=*/*inner0));
+ ARROW_ASSIGN_OR_RAISE(
+ auto outer0, DictionaryArray::FromArrays(
+ dictionary(int32(), nested_values0->type()),
+ /*indices=*/ArrayFromJSON(int32(), "[0, 1, 3, 3, null, 3, 2]"),
+ /*dictionary=*/nested_values0));
+ DCHECK_EQ(outer0->length(), length);
+
+ // Field 1: struct(a: dict(int8, int64), b: dict(int16, utf8))
+ ARROW_ASSIGN_OR_RAISE(
+ auto inner1, DictionaryArray::FromArrays(
+ dictionary(int8(), values1->type()),
+ /*indices=*/ArrayFromJSON(int8(), "[0, 1, 1, null, null, 1, 0]"),
+ /*dictionary=*/values1));
+ ARROW_ASSIGN_OR_RAISE(
+ auto inner2, DictionaryArray::FromArrays(
+ dictionary(int16(), values0->type()),
+ /*indices=*/ArrayFromJSON(int16(), "[2, 1, null, null, 2, 1, 0]"),
+ /*dictionary=*/values0));
+ ARROW_ASSIGN_OR_RAISE(
+ auto outer1, StructArray::Make({inner1, inner2}, {field("a", inner1->type()),
+ field("b", inner2->type())}));
+ DCHECK_EQ(outer1->length(), length);
+
+ // Field 2: dict(int8, struct(c: dict(int8, int64), d: dict(int16, list(dict(int8,
+ // utf8)))))
+ ARROW_ASSIGN_OR_RAISE(auto nested_values2,
+ ListArray::FromArrays(
+ /*offsets=*/*ArrayFromJSON(int32(), "[0, 1, 5, 5, 7]"),
+ /*values=*/*inner0));
+ ARROW_ASSIGN_OR_RAISE(
+ auto inner3, DictionaryArray::FromArrays(
+ dictionary(int16(), nested_values2->type()),
+ /*indices=*/ArrayFromJSON(int16(), "[0, 1, 3, null, 3, 2, 1]"),
+ /*dictionary=*/nested_values2));
+ ARROW_ASSIGN_OR_RAISE(
+ auto inner4, StructArray::Make({inner1, inner3}, {field("c", inner1->type()),
+ field("d", inner3->type())}));
+ ARROW_ASSIGN_OR_RAISE(auto outer2,
+ DictionaryArray::FromArrays(
+ dictionary(int8(), inner4->type()),
+ /*indices=*/ArrayFromJSON(int8(), "[0, 2, 4, 6, 1, 3, 5]"),
+ /*dictionary=*/inner4));
+ DCHECK_EQ(outer2->length(), length);
+
+ auto schema = ::arrow::schema({
+ field("f0", outer0->type()),
+ field("f1", outer1->type()),
+ field("f2", outer2->type()),
+ });
+ *out = RecordBatch::Make(schema, length, {outer0, outer1, outer2});
+ return Status::OK();
+}
+
+Status MakeMap(std::shared_ptr<RecordBatch>* out) {
+ constexpr int64_t kNumRows = 3;
+ std::shared_ptr<Array> a0, a1;
+
+ auto key_array = ArrayFromJSON(utf8(), R"(["k1", "k2", "k1", "k3", "k1", "k4"])");
+ auto item_array = ArrayFromJSON(int16(), "[0, -1, 2, -3, 4, null]");
+ RETURN_NOT_OK(MakeRandomMapArray(key_array, item_array, kNumRows,
+ /*include_nulls=*/false, default_memory_pool(), &a0));
+ RETURN_NOT_OK(MakeRandomMapArray(key_array, item_array, kNumRows,
+ /*include_nulls=*/true, default_memory_pool(), &a1));
+ auto f0 = field("f0", a0->type());
+ auto f1 = field("f1", a1->type());
+ *out = RecordBatch::Make(::arrow::schema({f0, f1}), kNumRows, {a0, a1});
+ return Status::OK();
+}
+
+Status MakeMapOfDictionary(std::shared_ptr<RecordBatch>* out) {
+ // Exercises ARROW-9660
+ constexpr int64_t kNumRows = 3;
+ std::shared_ptr<Array> a0, a1;
+
+ auto key_array = DictArrayFromJSON(dictionary(int32(), utf8()), "[0, 1, 0, 2, 0, 3]",
+ R"(["k1", "k2", "k3", "k4"])");
+ auto item_array = ArrayFromJSON(int16(), "[0, -1, 2, -3, 4, null]");
+ RETURN_NOT_OK(MakeRandomMapArray(key_array, item_array, kNumRows,
+ /*include_nulls=*/false, default_memory_pool(), &a0));
+ RETURN_NOT_OK(MakeRandomMapArray(key_array, item_array, kNumRows,
+ /*include_nulls=*/true, default_memory_pool(), &a1));
+ auto f0 = field("f0", a0->type());
+ auto f1 = field("f1", a1->type());
+ *out = RecordBatch::Make(::arrow::schema({f0, f1}), kNumRows, {a0, a1});
+ return Status::OK();
+}
+
+Status MakeDates(std::shared_ptr<RecordBatch>* out) {
+ std::vector<bool> is_valid = {true, true, true, false, true, true, true};
+ auto f0 = field("f0", date32());
+ auto f1 = field("f1", date64());
+ auto schema = ::arrow::schema({f0, f1});
+
+ std::vector<int32_t> date32_values = {0, 1, 2, 3, 4, 5, 6};
+ std::shared_ptr<Array> date32_array;
+ ArrayFromVector<Date32Type, int32_t>(is_valid, date32_values, &date32_array);
+
+ std::vector<int64_t> date64_values = {1489269000000, 1489270000000, 1489271000000,
+ 1489272000000, 1489272000000, 1489273000000,
+ 1489274000000};
+ std::shared_ptr<Array> date64_array;
+ ArrayFromVector<Date64Type, int64_t>(is_valid, date64_values, &date64_array);
+
+ *out = RecordBatch::Make(schema, date32_array->length(), {date32_array, date64_array});
+ return Status::OK();
+}
+
+Status MakeTimestamps(std::shared_ptr<RecordBatch>* out) {
+ std::vector<bool> is_valid = {true, true, true, false, true, true, true};
+ auto f0 = field("f0", timestamp(TimeUnit::MILLI));
+ auto f1 = field("f1", timestamp(TimeUnit::NANO, "America/New_York"));
+ auto f2 = field("f2", timestamp(TimeUnit::SECOND));
+ auto schema = ::arrow::schema({f0, f1, f2});
+
+ std::vector<int64_t> ts_values = {1489269000000, 1489270000000, 1489271000000,
+ 1489272000000, 1489272000000, 1489273000000};
+
+ std::shared_ptr<Array> a0, a1, a2;
+ ArrayFromVector<TimestampType, int64_t>(f0->type(), is_valid, ts_values, &a0);
+ ArrayFromVector<TimestampType, int64_t>(f1->type(), is_valid, ts_values, &a1);
+ ArrayFromVector<TimestampType, int64_t>(f2->type(), is_valid, ts_values, &a2);
+
+ *out = RecordBatch::Make(schema, a0->length(), {a0, a1, a2});
+ return Status::OK();
+}
+
+Status MakeIntervals(std::shared_ptr<RecordBatch>* out) {
+ std::vector<bool> is_valid = {true, true, true, false, true, true, true};
+ auto f0 = field("f0", duration(TimeUnit::MILLI));
+ auto f1 = field("f1", duration(TimeUnit::NANO));
+ auto f2 = field("f2", duration(TimeUnit::SECOND));
+ auto f3 = field("f3", day_time_interval());
+ auto f4 = field("f4", month_interval());
+ auto f5 = field("f5", month_day_nano_interval());
+ auto schema = ::arrow::schema({f0, f1, f2, f3, f4, f5});
+
+ std::vector<int64_t> ts_values = {1489269000000, 1489270000000, 1489271000000,
+ 1489272000000, 1489272000000, 1489273000000};
+
+ std::shared_ptr<Array> a0, a1, a2, a3, a4, a5;
+ ArrayFromVector<DurationType, int64_t>(f0->type(), is_valid, ts_values, &a0);
+ ArrayFromVector<DurationType, int64_t>(f1->type(), is_valid, ts_values, &a1);
+ ArrayFromVector<DurationType, int64_t>(f2->type(), is_valid, ts_values, &a2);
+ ArrayFromVector<DayTimeIntervalType, DayTimeIntervalType::DayMilliseconds>(
+ f3->type(), is_valid, {{0, 0}, {0, 1}, {1, 1}, {2, 1}, {3, 4}, {-1, -1}}, &a3);
+ ArrayFromVector<MonthIntervalType, int32_t>(f4->type(), is_valid, {0, -1, 1, 2, -2, 24},
+ &a4);
+ ArrayFromVector<MonthDayNanoIntervalType, MonthDayNanoIntervalType::MonthDayNanos>(
+ f5->type(), is_valid,
+ {{0, 0, 0}, {0, 0, 1}, {-1, 0, 1}, {-1, -2, -3}, {2, 4, 6}, {-3, -4, -5}}, &a5);
+
+ *out = RecordBatch::Make(schema, a0->length(), {a0, a1, a2, a3, a4, a5});
+ return Status::OK();
+}
+
+Status MakeTimes(std::shared_ptr<RecordBatch>* out) {
+ std::vector<bool> is_valid = {true, true, true, false, true, true, true};
+ auto f0 = field("f0", time32(TimeUnit::MILLI));
+ auto f1 = field("f1", time64(TimeUnit::NANO));
+ auto f2 = field("f2", time32(TimeUnit::SECOND));
+ auto f3 = field("f3", time64(TimeUnit::NANO));
+ auto schema = ::arrow::schema({f0, f1, f2, f3});
+
+ std::vector<int32_t> t32_values = {1489269000, 1489270000, 1489271000,
+ 1489272000, 1489272000, 1489273000};
+ std::vector<int64_t> t64_values = {1489269000000, 1489270000000, 1489271000000,
+ 1489272000000, 1489272000000, 1489273000000};
+
+ std::shared_ptr<Array> a0, a1, a2, a3;
+ ArrayFromVector<Time32Type, int32_t>(f0->type(), is_valid, t32_values, &a0);
+ ArrayFromVector<Time64Type, int64_t>(f1->type(), is_valid, t64_values, &a1);
+ ArrayFromVector<Time32Type, int32_t>(f2->type(), is_valid, t32_values, &a2);
+ ArrayFromVector<Time64Type, int64_t>(f3->type(), is_valid, t64_values, &a3);
+
+ *out = RecordBatch::Make(schema, a0->length(), {a0, a1, a2, a3});
+ return Status::OK();
+}
+
+template <typename BuilderType, typename T>
+static void AppendValues(const std::vector<bool>& is_valid, const std::vector<T>& values,
+ BuilderType* builder) {
+ for (size_t i = 0; i < values.size(); ++i) {
+ if (is_valid[i]) {
+ ASSERT_OK(builder->Append(values[i]));
+ } else {
+ ASSERT_OK(builder->AppendNull());
+ }
+ }
+}
+
+Status MakeFWBinary(std::shared_ptr<RecordBatch>* out) {
+ std::vector<bool> is_valid = {true, true, true, false};
+ auto f0 = field("f0", fixed_size_binary(4));
+ auto f1 = field("f1", fixed_size_binary(0));
+ auto schema = ::arrow::schema({f0, f1});
+
+ std::shared_ptr<Array> a1, a2;
+
+ FixedSizeBinaryBuilder b1(f0->type());
+ FixedSizeBinaryBuilder b2(f1->type());
+
+ std::vector<std::string> values1 = {"foo1", "foo2", "foo3", "foo4"};
+ AppendValues(is_valid, values1, &b1);
+
+ std::vector<std::string> values2 = {"", "", "", ""};
+ AppendValues(is_valid, values2, &b2);
+
+ RETURN_NOT_OK(b1.Finish(&a1));
+ RETURN_NOT_OK(b2.Finish(&a2));
+
+ *out = RecordBatch::Make(schema, a1->length(), {a1, a2});
+ return Status::OK();
+}
+
+Status MakeDecimal(std::shared_ptr<RecordBatch>* out) {
+ constexpr int kDecimalPrecision = 38;
+ auto type = decimal(kDecimalPrecision, 4);
+ auto f0 = field("f0", type);
+ auto f1 = field("f1", type);
+ auto schema = ::arrow::schema({f0, f1});
+
+ constexpr int kDecimalSize = 16;
+ constexpr int length = 10;
+
+ std::vector<uint8_t> is_valid_bytes(length);
+
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> data,
+ AllocateBuffer(kDecimalSize * length));
+
+ random_decimals(length, 1, kDecimalPrecision, data->mutable_data());
+ random_null_bytes(length, 0.1, is_valid_bytes.data());
+
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> is_valid,
+ internal::BytesToBits(is_valid_bytes));
+
+ auto a1 = std::make_shared<Decimal128Array>(f0->type(), length, data, is_valid,
+ kUnknownNullCount);
+
+ auto a2 = std::make_shared<Decimal128Array>(f1->type(), length, data);
+
+ *out = RecordBatch::Make(schema, length, {a1, a2});
+ return Status::OK();
+}
+
+Status MakeNull(std::shared_ptr<RecordBatch>* out) {
+ auto f0 = field("f0", null());
+
+ // Also put a non-null field to make sure we handle the null array buffers properly
+ auto f1 = field("f1", int64());
+
+ auto schema = ::arrow::schema({f0, f1});
+
+ auto a1 = std::make_shared<NullArray>(10);
+
+ std::vector<int64_t> int_values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<bool> is_valid = {true, true, true, false, false,
+ true, true, true, true, true};
+ std::shared_ptr<Array> a2;
+ ArrayFromVector<Int64Type, int64_t>(f1->type(), is_valid, int_values, &a2);
+
+ *out = RecordBatch::Make(schema, a1->length(), {a1, a2});
+ return Status::OK();
+}
+
+Status MakeUuid(std::shared_ptr<RecordBatch>* out) {
+ auto uuid_type = uuid();
+ auto storage_type = checked_cast<const ExtensionType&>(*uuid_type).storage_type();
+
+ auto f0 = field("f0", uuid_type);
+ auto f1 = field("f1", uuid_type, /*nullable=*/false);
+ auto schema = ::arrow::schema({f0, f1});
+
+ auto a0 = std::make_shared<UuidArray>(
+ uuid_type, ArrayFromJSON(storage_type, R"(["0123456789abcdef", null])"));
+ auto a1 = std::make_shared<UuidArray>(
+ uuid_type,
+ ArrayFromJSON(storage_type, R"(["ZYXWVUTSRQPONMLK", "JIHGFEDBA9876543"])"));
+
+ *out = RecordBatch::Make(schema, a1->length(), {a0, a1});
+ return Status::OK();
+}
+
+Status MakeComplex128(std::shared_ptr<RecordBatch>* out) {
+ auto type = complex128();
+ auto storage_type = checked_cast<const ExtensionType&>(*type).storage_type();
+
+ auto f0 = field("f0", type);
+ auto f1 = field("f1", type, /*nullable=*/false);
+ auto schema = ::arrow::schema({f0, f1});
+
+ auto a0 = ExtensionType::WrapArray(complex128(),
+ ArrayFromJSON(storage_type, "[[1.0, -2.5], null]"));
+ auto a1 = ExtensionType::WrapArray(
+ complex128(), ArrayFromJSON(storage_type, "[[1.0, -2.5], [3.0, -4.0]]"));
+
+ *out = RecordBatch::Make(schema, a1->length(), {a0, a1});
+ return Status::OK();
+}
+
+Status MakeDictExtension(std::shared_ptr<RecordBatch>* out) {
+ auto type = dict_extension_type();
+ auto storage_type = checked_cast<const ExtensionType&>(*type).storage_type();
+
+ auto f0 = field("f0", type);
+ auto f1 = field("f1", type, /*nullable=*/false);
+ auto schema = ::arrow::schema({f0, f1});
+
+ auto storage0 = std::make_shared<DictionaryArray>(
+ storage_type, ArrayFromJSON(int8(), "[1, 0, null, 1, 1]"),
+ ArrayFromJSON(utf8(), R"(["foo", "bar"])"));
+ auto a0 = std::make_shared<ExtensionArray>(type, storage0);
+
+ auto storage1 = std::make_shared<DictionaryArray>(
+ storage_type, ArrayFromJSON(int8(), "[2, 0, 0, 1, 1]"),
+ ArrayFromJSON(utf8(), R"(["arrow", "parquet", "plasma"])"));
+ auto a1 = std::make_shared<ExtensionArray>(type, storage1);
+
+ *out = RecordBatch::Make(schema, a1->length(), {a0, a1});
+ return Status::OK();
+}
+
+namespace {
+
+template <typename CValueType, typename SeedType, typename DistributionType>
+void FillRandomData(CValueType* data, size_t n, CValueType min, CValueType max,
+ SeedType seed) {
+ std::default_random_engine rng(seed);
+ DistributionType dist(min, max);
+ std::generate(data, data + n,
+ [&dist, &rng] { return static_cast<CValueType>(dist(rng)); });
+}
+
+template <typename CValueType, typename SeedType>
+enable_if_t<std::is_integral<CValueType>::value && std::is_signed<CValueType>::value,
+ void>
+FillRandomData(CValueType* data, size_t n, SeedType seed) {
+ FillRandomData<CValueType, SeedType, std::uniform_int_distribution<CValueType>>(
+ data, n, -1000, 1000, seed);
+}
+
+template <typename CValueType, typename SeedType>
+enable_if_t<std::is_integral<CValueType>::value && std::is_unsigned<CValueType>::value,
+ void>
+FillRandomData(CValueType* data, size_t n, SeedType seed) {
+ FillRandomData<CValueType, SeedType, std::uniform_int_distribution<CValueType>>(
+ data, n, 0, 1000, seed);
+}
+
+template <typename CValueType, typename SeedType>
+enable_if_t<std::is_floating_point<CValueType>::value, void> FillRandomData(
+ CValueType* data, size_t n, SeedType seed) {
+ FillRandomData<CValueType, SeedType, std::uniform_real_distribution<CValueType>>(
+ data, n, -1000, 1000, seed);
+}
+
+} // namespace
+
+Status MakeRandomTensor(const std::shared_ptr<DataType>& type,
+ const std::vector<int64_t>& shape, bool row_major_p,
+ std::shared_ptr<Tensor>* out, uint32_t seed) {
+ const auto& element_type = internal::checked_cast<const FixedWidthType&>(*type);
+ std::vector<int64_t> strides;
+ if (row_major_p) {
+ RETURN_NOT_OK(internal::ComputeRowMajorStrides(element_type, shape, &strides));
+ } else {
+ RETURN_NOT_OK(internal::ComputeColumnMajorStrides(element_type, shape, &strides));
+ }
+
+ const int64_t element_size = element_type.bit_width() / CHAR_BIT;
+ const int64_t len =
+ std::accumulate(shape.begin(), shape.end(), int64_t(1), std::multiplies<int64_t>());
+
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> buf, AllocateBuffer(element_size * len));
+
+ switch (type->id()) {
+ case Type::INT8:
+ FillRandomData<int8_t, uint32_t, std::uniform_int_distribution<int16_t>>(
+ reinterpret_cast<int8_t*>(buf->mutable_data()), len, -128, 127, seed);
+ break;
+ case Type::UINT8:
+ FillRandomData<uint8_t, uint32_t, std::uniform_int_distribution<uint16_t>>(
+ reinterpret_cast<uint8_t*>(buf->mutable_data()), len, 0, 255, seed);
+ break;
+ case Type::INT16:
+ FillRandomData(reinterpret_cast<int16_t*>(buf->mutable_data()), len, seed);
+ break;
+ case Type::UINT16:
+ FillRandomData(reinterpret_cast<uint16_t*>(buf->mutable_data()), len, seed);
+ break;
+ case Type::INT32:
+ FillRandomData(reinterpret_cast<int32_t*>(buf->mutable_data()), len, seed);
+ break;
+ case Type::UINT32:
+ FillRandomData(reinterpret_cast<uint32_t*>(buf->mutable_data()), len, seed);
+ break;
+ case Type::INT64:
+ FillRandomData(reinterpret_cast<int64_t*>(buf->mutable_data()), len, seed);
+ break;
+ case Type::UINT64:
+ FillRandomData(reinterpret_cast<uint64_t*>(buf->mutable_data()), len, seed);
+ break;
+ case Type::HALF_FLOAT:
+ FillRandomData(reinterpret_cast<int16_t*>(buf->mutable_data()), len, seed);
+ break;
+ case Type::FLOAT:
+ FillRandomData(reinterpret_cast<float*>(buf->mutable_data()), len, seed);
+ break;
+ case Type::DOUBLE:
+ FillRandomData(reinterpret_cast<double*>(buf->mutable_data()), len, seed);
+ break;
+ default:
+ return Status::Invalid(type->ToString(), " is not valid data type for a tensor");
+ }
+
+ return Tensor::Make(type, buf, shape, strides).Value(out);
+}
+
+} // namespace test
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/test_common.h b/src/arrow/cpp/src/arrow/ipc/test_common.h
new file mode 100644
index 000000000..48df28b2d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/test_common.h
@@ -0,0 +1,175 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/testing/visibility.h"
+#include "arrow/type.h"
+
+namespace arrow {
+namespace ipc {
+namespace test {
+
+// A typedef used for test parameterization
+typedef Status MakeRecordBatch(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+void CompareArraysDetailed(int index, const Array& result, const Array& expected);
+
+ARROW_TESTING_EXPORT
+void CompareBatchColumnsDetailed(const RecordBatch& result, const RecordBatch& expected);
+
+ARROW_TESTING_EXPORT
+Status MakeRandomInt32Array(int64_t length, bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out, uint32_t seed = 0);
+
+ARROW_TESTING_EXPORT
+Status MakeRandomListArray(const std::shared_ptr<Array>& child_array, int num_lists,
+ bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeRandomLargeListArray(const std::shared_ptr<Array>& child_array, int num_lists,
+ bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeRandomBooleanArray(const int length, bool include_nulls,
+ std::shared_ptr<Array>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeBooleanBatchSized(const int length, std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeBooleanBatch(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeIntBatchSized(int length, std::shared_ptr<RecordBatch>* out,
+ uint32_t seed = 0);
+
+ARROW_TESTING_EXPORT
+Status MakeIntRecordBatch(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeFloat3264BatchSized(int length, std::shared_ptr<RecordBatch>* out,
+ uint32_t seed = 0);
+
+ARROW_TESTING_EXPORT
+Status MakeFloat3264Batch(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeFloatBatchSized(int length, std::shared_ptr<RecordBatch>* out,
+ uint32_t seed = 0);
+
+ARROW_TESTING_EXPORT
+Status MakeFloatBatch(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeRandomStringArray(int64_t length, bool include_nulls, MemoryPool* pool,
+ std::shared_ptr<Array>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeStringTypesRecordBatch(std::shared_ptr<RecordBatch>* out,
+ bool with_nulls = true);
+
+ARROW_TESTING_EXPORT
+Status MakeStringTypesRecordBatchWithNulls(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeNullRecordBatch(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeListRecordBatch(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeFixedSizeListRecordBatch(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeZeroLengthRecordBatch(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeNonNullRecordBatch(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeDeeplyNestedList(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeStruct(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeUnion(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeDictionary(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeDictionaryFlat(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeNestedDictionary(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeMap(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeMapOfDictionary(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeDates(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeTimestamps(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeIntervals(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeTimes(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeFWBinary(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeDecimal(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeNull(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeUuid(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeComplex128(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeDictExtension(std::shared_ptr<RecordBatch>* out);
+
+ARROW_TESTING_EXPORT
+Status MakeRandomTensor(const std::shared_ptr<DataType>& type,
+ const std::vector<int64_t>& shape, bool row_major_p,
+ std::shared_ptr<Tensor>* out, uint32_t seed = 0);
+
+} // namespace test
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/type_fwd.h b/src/arrow/cpp/src/arrow/ipc/type_fwd.h
new file mode 100644
index 000000000..3493c4f14
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/type_fwd.h
@@ -0,0 +1,65 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+namespace arrow {
+namespace ipc {
+
+enum class MetadataVersion : char {
+ /// 0.1.0
+ V1,
+
+ /// 0.2.0
+ V2,
+
+ /// 0.3.0 to 0.7.1
+ V3,
+
+ /// 0.8.0 to 0.17.0
+ V4,
+
+ /// >= 1.0.0
+ V5
+};
+
+class Message;
+enum class MessageType {
+ NONE,
+ SCHEMA,
+ DICTIONARY_BATCH,
+ RECORD_BATCH,
+ TENSOR,
+ SPARSE_TENSOR
+};
+
+struct IpcReadOptions;
+struct IpcWriteOptions;
+
+class MessageReader;
+
+class RecordBatchStreamReader;
+class RecordBatchFileReader;
+class RecordBatchWriter;
+
+namespace feather {
+
+class Reader;
+
+} // namespace feather
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/util.h b/src/arrow/cpp/src/arrow/ipc/util.h
new file mode 100644
index 000000000..709fedbf3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/util.h
@@ -0,0 +1,41 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+
+namespace arrow {
+namespace ipc {
+
+// Buffers are padded to 64-byte boundaries (for SIMD)
+static constexpr int32_t kArrowAlignment = 64;
+
+// Tensors are padded to 64-byte boundaries
+static constexpr int32_t kTensorAlignment = 64;
+
+// Align on 8-byte boundaries in IPC
+static constexpr int32_t kArrowIpcAlignment = 8;
+
+static constexpr uint8_t kPaddingBytes[kArrowAlignment] = {0};
+
+static inline int64_t PaddedLength(int64_t nbytes, int32_t alignment = kArrowAlignment) {
+ return ((nbytes + alignment - 1) / alignment) * alignment;
+}
+
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/writer.cc b/src/arrow/cpp/src/arrow/ipc/writer.cc
new file mode 100644
index 000000000..7b9254b7e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/writer.cc
@@ -0,0 +1,1429 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/ipc/writer.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <sstream>
+#include <string>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/device.h"
+#include "arrow/extension_type.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/dictionary.h"
+#include "arrow/ipc/message.h"
+#include "arrow/ipc/metadata_internal.h"
+#include "arrow/ipc/util.h"
+#include "arrow/record_batch.h"
+#include "arrow/result_internal.h"
+#include "arrow/sparse_tensor.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/compression.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/parallel.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+using internal::CopyBitmap;
+using internal::GetByteWidth;
+
+namespace ipc {
+
+using internal::FileBlock;
+using internal::kArrowMagicBytes;
+
+namespace {
+
+bool HasNestedDict(const ArrayData& data) {
+ if (data.type->id() == Type::DICTIONARY) {
+ return true;
+ }
+ for (const auto& child : data.child_data) {
+ if (HasNestedDict(*child)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+Status GetTruncatedBitmap(int64_t offset, int64_t length,
+ const std::shared_ptr<Buffer> input, MemoryPool* pool,
+ std::shared_ptr<Buffer>* buffer) {
+ if (!input) {
+ *buffer = input;
+ return Status::OK();
+ }
+ int64_t min_length = PaddedLength(BitUtil::BytesForBits(length));
+ if (offset != 0 || min_length < input->size()) {
+ // With a sliced array / non-zero offset, we must copy the bitmap
+ ARROW_ASSIGN_OR_RAISE(*buffer, CopyBitmap(pool, input->data(), offset, length));
+ } else {
+ *buffer = input;
+ }
+ return Status::OK();
+}
+
+Status GetTruncatedBuffer(int64_t offset, int64_t length, int32_t byte_width,
+ const std::shared_ptr<Buffer> input, MemoryPool* pool,
+ std::shared_ptr<Buffer>* buffer) {
+ if (!input) {
+ *buffer = input;
+ return Status::OK();
+ }
+ int64_t padded_length = PaddedLength(length * byte_width);
+ if (offset != 0 || padded_length < input->size()) {
+ *buffer =
+ SliceBuffer(input, offset * byte_width, std::min(padded_length, input->size()));
+ } else {
+ *buffer = input;
+ }
+ return Status::OK();
+}
+
+static inline bool NeedTruncate(int64_t offset, const Buffer* buffer,
+ int64_t min_length) {
+ // buffer can be NULL
+ if (buffer == nullptr) {
+ return false;
+ }
+ return offset != 0 || min_length < buffer->size();
+}
+
+class RecordBatchSerializer {
+ public:
+ RecordBatchSerializer(int64_t buffer_start_offset, const IpcWriteOptions& options,
+ IpcPayload* out)
+ : out_(out),
+ options_(options),
+ max_recursion_depth_(options.max_recursion_depth),
+ buffer_start_offset_(buffer_start_offset) {
+ DCHECK_GT(max_recursion_depth_, 0);
+ }
+
+ virtual ~RecordBatchSerializer() = default;
+
+ Status VisitArray(const Array& arr) {
+ static std::shared_ptr<Buffer> kNullBuffer = std::make_shared<Buffer>(nullptr, 0);
+
+ if (max_recursion_depth_ <= 0) {
+ return Status::Invalid("Max recursion depth reached");
+ }
+
+ if (!options_.allow_64bit && arr.length() > std::numeric_limits<int32_t>::max()) {
+ return Status::CapacityError("Cannot write arrays larger than 2^31 - 1 in length");
+ }
+
+ // push back all common elements
+ field_nodes_.push_back({arr.length(), arr.null_count(), 0});
+
+ // In V4, null types have no validity bitmap
+ // In V5 and later, null and union types have no validity bitmap
+ if (internal::HasValidityBitmap(arr.type_id(), options_.metadata_version)) {
+ if (arr.null_count() > 0) {
+ std::shared_ptr<Buffer> bitmap;
+ RETURN_NOT_OK(GetTruncatedBitmap(arr.offset(), arr.length(), arr.null_bitmap(),
+ options_.memory_pool, &bitmap));
+ out_->body_buffers.emplace_back(bitmap);
+ } else {
+ // Push a dummy zero-length buffer, not to be copied
+ out_->body_buffers.emplace_back(kNullBuffer);
+ }
+ }
+ return VisitType(arr);
+ }
+
+ // Override this for writing dictionary metadata
+ virtual Status SerializeMetadata(int64_t num_rows) {
+ return WriteRecordBatchMessage(num_rows, out_->body_length, custom_metadata_,
+ field_nodes_, buffer_meta_, options_, &out_->metadata);
+ }
+
+ void AppendCustomMetadata(const std::string& key, const std::string& value) {
+ if (!custom_metadata_) {
+ custom_metadata_ = std::make_shared<KeyValueMetadata>();
+ }
+ custom_metadata_->Append(key, value);
+ }
+
+ Status CompressBuffer(const Buffer& buffer, util::Codec* codec,
+ std::shared_ptr<Buffer>* out) {
+ // Convert buffer to uncompressed-length-prefixed compressed buffer
+ int64_t maximum_length = codec->MaxCompressedLen(buffer.size(), buffer.data());
+ ARROW_ASSIGN_OR_RAISE(auto result, AllocateBuffer(maximum_length + sizeof(int64_t)));
+
+ int64_t actual_length;
+ ARROW_ASSIGN_OR_RAISE(actual_length,
+ codec->Compress(buffer.size(), buffer.data(), maximum_length,
+ result->mutable_data() + sizeof(int64_t)));
+ *reinterpret_cast<int64_t*>(result->mutable_data()) =
+ BitUtil::ToLittleEndian(buffer.size());
+ *out = SliceBuffer(std::move(result), /*offset=*/0, actual_length + sizeof(int64_t));
+ return Status::OK();
+ }
+
+ Status CompressBodyBuffers() {
+ RETURN_NOT_OK(
+ internal::CheckCompressionSupported(options_.codec->compression_type()));
+
+ auto CompressOne = [&](size_t i) {
+ if (out_->body_buffers[i]->size() > 0) {
+ RETURN_NOT_OK(CompressBuffer(*out_->body_buffers[i], options_.codec.get(),
+ &out_->body_buffers[i]));
+ }
+ return Status::OK();
+ };
+
+ return ::arrow::internal::OptionalParallelFor(
+ options_.use_threads, static_cast<int>(out_->body_buffers.size()), CompressOne);
+ }
+
+ Status Assemble(const RecordBatch& batch) {
+ if (field_nodes_.size() > 0) {
+ field_nodes_.clear();
+ buffer_meta_.clear();
+ out_->body_buffers.clear();
+ }
+
+ // Perform depth-first traversal of the row-batch
+ for (int i = 0; i < batch.num_columns(); ++i) {
+ RETURN_NOT_OK(VisitArray(*batch.column(i)));
+ }
+
+ if (options_.codec != nullptr) {
+ RETURN_NOT_OK(CompressBodyBuffers());
+ }
+
+ // The position for the start of a buffer relative to the passed frame of
+ // reference. May be 0 or some other position in an address space
+ int64_t offset = buffer_start_offset_;
+
+ buffer_meta_.reserve(out_->body_buffers.size());
+
+ // Construct the buffer metadata for the record batch header
+ for (const auto& buffer : out_->body_buffers) {
+ int64_t size = 0;
+ int64_t padding = 0;
+
+ // The buffer might be null if we are handling zero row lengths.
+ if (buffer) {
+ size = buffer->size();
+ padding = BitUtil::RoundUpToMultipleOf8(size) - size;
+ }
+
+ buffer_meta_.push_back({offset, size});
+ offset += size + padding;
+ }
+
+ out_->body_length = offset - buffer_start_offset_;
+ DCHECK(BitUtil::IsMultipleOf8(out_->body_length));
+
+ // Now that we have computed the locations of all of the buffers in shared
+ // memory, the data header can be converted to a flatbuffer and written out
+ //
+ // Note: The memory written here is prefixed by the size of the flatbuffer
+ // itself as an int32_t.
+ return SerializeMetadata(batch.num_rows());
+ }
+
+ template <typename ArrayType>
+ Status GetZeroBasedValueOffsets(const ArrayType& array,
+ std::shared_ptr<Buffer>* value_offsets) {
+ // Share slicing logic between ListArray, BinaryArray and LargeBinaryArray
+ using offset_type = typename ArrayType::offset_type;
+
+ auto offsets = array.value_offsets();
+
+ int64_t required_bytes = sizeof(offset_type) * (array.length() + 1);
+ if (array.offset() != 0) {
+ // If we have a non-zero offset, then the value offsets do not start at
+ // zero. We must a) create a new offsets array with shifted offsets and
+ // b) slice the values array accordingly
+
+ ARROW_ASSIGN_OR_RAISE(auto shifted_offsets,
+ AllocateBuffer(required_bytes, options_.memory_pool));
+
+ offset_type* dest_offsets =
+ reinterpret_cast<offset_type*>(shifted_offsets->mutable_data());
+ const offset_type start_offset = array.value_offset(0);
+
+ for (int i = 0; i < array.length(); ++i) {
+ dest_offsets[i] = array.value_offset(i) - start_offset;
+ }
+ // Final offset
+ dest_offsets[array.length()] = array.value_offset(array.length()) - start_offset;
+ offsets = std::move(shifted_offsets);
+ } else {
+ // ARROW-6046: Slice offsets to used extent, in case we have a truncated
+ // slice
+ if (offsets != nullptr && offsets->size() > required_bytes) {
+ offsets = SliceBuffer(offsets, 0, required_bytes);
+ }
+ }
+ *value_offsets = std::move(offsets);
+ return Status::OK();
+ }
+
+ Status Visit(const BooleanArray& array) {
+ std::shared_ptr<Buffer> data;
+ RETURN_NOT_OK(GetTruncatedBitmap(array.offset(), array.length(), array.values(),
+ options_.memory_pool, &data));
+ out_->body_buffers.emplace_back(data);
+ return Status::OK();
+ }
+
+ Status Visit(const NullArray& array) { return Status::OK(); }
+
+ template <typename T>
+ typename std::enable_if<is_number_type<typename T::TypeClass>::value ||
+ is_temporal_type<typename T::TypeClass>::value ||
+ is_fixed_size_binary_type<typename T::TypeClass>::value,
+ Status>::type
+ Visit(const T& array) {
+ std::shared_ptr<Buffer> data = array.values();
+
+ const int64_t type_width = GetByteWidth(*array.type());
+ int64_t min_length = PaddedLength(array.length() * type_width);
+
+ if (NeedTruncate(array.offset(), data.get(), min_length)) {
+ // Non-zero offset, slice the buffer
+ const int64_t byte_offset = array.offset() * type_width;
+
+ // Send padding if it's available
+ const int64_t buffer_length =
+ std::min(BitUtil::RoundUpToMultipleOf8(array.length() * type_width),
+ data->size() - byte_offset);
+ data = SliceBuffer(data, byte_offset, buffer_length);
+ }
+ out_->body_buffers.emplace_back(data);
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_base_binary<typename T::TypeClass, Status> Visit(const T& array) {
+ std::shared_ptr<Buffer> value_offsets;
+ RETURN_NOT_OK(GetZeroBasedValueOffsets<T>(array, &value_offsets));
+ auto data = array.value_data();
+
+ int64_t total_data_bytes = 0;
+ if (value_offsets) {
+ total_data_bytes = array.value_offset(array.length()) - array.value_offset(0);
+ }
+ if (NeedTruncate(array.offset(), data.get(), total_data_bytes)) {
+ // Slice the data buffer to include only the range we need now
+ const int64_t start_offset = array.value_offset(0);
+ const int64_t slice_length =
+ std::min(PaddedLength(total_data_bytes), data->size() - start_offset);
+ data = SliceBuffer(data, start_offset, slice_length);
+ }
+
+ out_->body_buffers.emplace_back(value_offsets);
+ out_->body_buffers.emplace_back(data);
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_base_list<typename T::TypeClass, Status> Visit(const T& array) {
+ using offset_type = typename T::offset_type;
+
+ std::shared_ptr<Buffer> value_offsets;
+ RETURN_NOT_OK(GetZeroBasedValueOffsets<T>(array, &value_offsets));
+ out_->body_buffers.emplace_back(value_offsets);
+
+ --max_recursion_depth_;
+ std::shared_ptr<Array> values = array.values();
+
+ offset_type values_offset = 0;
+ offset_type values_length = 0;
+ if (value_offsets) {
+ values_offset = array.value_offset(0);
+ values_length = array.value_offset(array.length()) - values_offset;
+ }
+
+ if (array.offset() != 0 || values_length < values->length()) {
+ // Must also slice the values
+ values = values->Slice(values_offset, values_length);
+ }
+ RETURN_NOT_OK(VisitArray(*values));
+ ++max_recursion_depth_;
+ return Status::OK();
+ }
+
+ Status Visit(const FixedSizeListArray& array) {
+ --max_recursion_depth_;
+ auto size = array.list_type()->list_size();
+ auto values = array.values()->Slice(array.offset() * size, array.length() * size);
+
+ RETURN_NOT_OK(VisitArray(*values));
+ ++max_recursion_depth_;
+ return Status::OK();
+ }
+
+ Status Visit(const StructArray& array) {
+ --max_recursion_depth_;
+ for (int i = 0; i < array.num_fields(); ++i) {
+ std::shared_ptr<Array> field = array.field(i);
+ RETURN_NOT_OK(VisitArray(*field));
+ }
+ ++max_recursion_depth_;
+ return Status::OK();
+ }
+
+ Status Visit(const SparseUnionArray& array) {
+ const int64_t offset = array.offset();
+ const int64_t length = array.length();
+
+ std::shared_ptr<Buffer> type_codes;
+ RETURN_NOT_OK(GetTruncatedBuffer(
+ offset, length, static_cast<int32_t>(sizeof(UnionArray::type_code_t)),
+ array.type_codes(), options_.memory_pool, &type_codes));
+ out_->body_buffers.emplace_back(type_codes);
+
+ --max_recursion_depth_;
+ for (int i = 0; i < array.num_fields(); ++i) {
+ // Sparse union, slicing is done for us by field()
+ RETURN_NOT_OK(VisitArray(*array.field(i)));
+ }
+ ++max_recursion_depth_;
+ return Status::OK();
+ }
+
+ Status Visit(const DenseUnionArray& array) {
+ const int64_t offset = array.offset();
+ const int64_t length = array.length();
+
+ std::shared_ptr<Buffer> type_codes;
+ RETURN_NOT_OK(GetTruncatedBuffer(
+ offset, length, static_cast<int32_t>(sizeof(UnionArray::type_code_t)),
+ array.type_codes(), options_.memory_pool, &type_codes));
+ out_->body_buffers.emplace_back(type_codes);
+
+ --max_recursion_depth_;
+ const auto& type = checked_cast<const UnionType&>(*array.type());
+
+ std::shared_ptr<Buffer> value_offsets;
+ RETURN_NOT_OK(
+ GetTruncatedBuffer(offset, length, static_cast<int32_t>(sizeof(int32_t)),
+ array.value_offsets(), options_.memory_pool, &value_offsets));
+
+ // The Union type codes are not necessary 0-indexed
+ int8_t max_code = 0;
+ for (int8_t code : type.type_codes()) {
+ if (code > max_code) {
+ max_code = code;
+ }
+ }
+
+ // Allocate an array of child offsets. Set all to -1 to indicate that we
+ // haven't observed a first occurrence of a particular child yet
+ std::vector<int32_t> child_offsets(max_code + 1, -1);
+ std::vector<int32_t> child_lengths(max_code + 1, 0);
+
+ if (offset != 0) {
+ // This is an unpleasant case. Because the offsets are different for
+ // each child array, when we have a sliced array, we need to "rebase"
+ // the value_offsets for each array
+
+ const int32_t* unshifted_offsets = array.raw_value_offsets();
+ const int8_t* type_codes = array.raw_type_codes();
+
+ // Allocate the shifted offsets
+ ARROW_ASSIGN_OR_RAISE(
+ auto shifted_offsets_buffer,
+ AllocateBuffer(length * sizeof(int32_t), options_.memory_pool));
+ int32_t* shifted_offsets =
+ reinterpret_cast<int32_t*>(shifted_offsets_buffer->mutable_data());
+
+ // Offsets may not be ascending, so we need to find out the start offset
+ // for each child
+ for (int64_t i = 0; i < length; ++i) {
+ const uint8_t code = type_codes[i];
+ if (child_offsets[code] == -1) {
+ child_offsets[code] = unshifted_offsets[i];
+ } else {
+ child_offsets[code] = std::min(child_offsets[code], unshifted_offsets[i]);
+ }
+ }
+
+ // Now compute shifted offsets by subtracting child offset
+ for (int64_t i = 0; i < length; ++i) {
+ const int8_t code = type_codes[i];
+ shifted_offsets[i] = unshifted_offsets[i] - child_offsets[code];
+ // Update the child length to account for observed value
+ child_lengths[code] = std::max(child_lengths[code], shifted_offsets[i] + 1);
+ }
+
+ value_offsets = std::move(shifted_offsets_buffer);
+ }
+ out_->body_buffers.emplace_back(value_offsets);
+
+ // Visit children and slice accordingly
+ for (int i = 0; i < type.num_fields(); ++i) {
+ std::shared_ptr<Array> child = array.field(i);
+
+ // TODO: ARROW-809, for sliced unions, tricky to know how much to
+ // truncate the children. For now, we are truncating the children to be
+ // no longer than the parent union.
+ if (offset != 0) {
+ const int8_t code = type.type_codes()[i];
+ const int64_t child_offset = child_offsets[code];
+ const int64_t child_length = child_lengths[code];
+
+ if (child_offset > 0) {
+ child = child->Slice(child_offset, child_length);
+ } else if (child_length < child->length()) {
+ // This case includes when child is not encountered at all
+ child = child->Slice(0, child_length);
+ }
+ }
+ RETURN_NOT_OK(VisitArray(*child));
+ }
+ ++max_recursion_depth_;
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryArray& array) {
+ // Dictionary written out separately. Slice offset contained in the indices
+ return VisitType(*array.indices());
+ }
+
+ Status Visit(const ExtensionArray& array) { return VisitType(*array.storage()); }
+
+ Status VisitType(const Array& values) { return VisitArrayInline(values, this); }
+
+ protected:
+ // Destination for output buffers
+ IpcPayload* out_;
+
+ std::shared_ptr<KeyValueMetadata> custom_metadata_;
+
+ std::vector<internal::FieldMetadata> field_nodes_;
+ std::vector<internal::BufferMetadata> buffer_meta_;
+
+ const IpcWriteOptions& options_;
+ int64_t max_recursion_depth_;
+ int64_t buffer_start_offset_;
+};
+
+class DictionarySerializer : public RecordBatchSerializer {
+ public:
+ DictionarySerializer(int64_t dictionary_id, bool is_delta, int64_t buffer_start_offset,
+ const IpcWriteOptions& options, IpcPayload* out)
+ : RecordBatchSerializer(buffer_start_offset, options, out),
+ dictionary_id_(dictionary_id),
+ is_delta_(is_delta) {}
+
+ Status SerializeMetadata(int64_t num_rows) override {
+ return WriteDictionaryMessage(dictionary_id_, is_delta_, num_rows, out_->body_length,
+ custom_metadata_, field_nodes_, buffer_meta_, options_,
+ &out_->metadata);
+ }
+
+ Status Assemble(const std::shared_ptr<Array>& dictionary) {
+ // Make a dummy record batch. A bit tedious as we have to make a schema
+ auto schema = arrow::schema({arrow::field("dictionary", dictionary->type())});
+ auto batch = RecordBatch::Make(std::move(schema), dictionary->length(), {dictionary});
+ return RecordBatchSerializer::Assemble(*batch);
+ }
+
+ private:
+ int64_t dictionary_id_;
+ bool is_delta_;
+};
+
+} // namespace
+
+Status WriteIpcPayload(const IpcPayload& payload, const IpcWriteOptions& options,
+ io::OutputStream* dst, int32_t* metadata_length) {
+ RETURN_NOT_OK(WriteMessage(*payload.metadata, options, dst, metadata_length));
+
+#ifndef NDEBUG
+ RETURN_NOT_OK(CheckAligned(dst));
+#endif
+
+ // Now write the buffers
+ for (size_t i = 0; i < payload.body_buffers.size(); ++i) {
+ const std::shared_ptr<Buffer>& buffer = payload.body_buffers[i];
+ int64_t size = 0;
+ int64_t padding = 0;
+
+ // The buffer might be null if we are handling zero row lengths.
+ if (buffer) {
+ size = buffer->size();
+ padding = BitUtil::RoundUpToMultipleOf8(size) - size;
+ }
+
+ if (size > 0) {
+ RETURN_NOT_OK(dst->Write(buffer));
+ }
+
+ if (padding > 0) {
+ RETURN_NOT_OK(dst->Write(kPaddingBytes, padding));
+ }
+ }
+
+#ifndef NDEBUG
+ RETURN_NOT_OK(CheckAligned(dst));
+#endif
+
+ return Status::OK();
+}
+
+Status GetSchemaPayload(const Schema& schema, const IpcWriteOptions& options,
+ const DictionaryFieldMapper& mapper, IpcPayload* out) {
+ out->type = MessageType::SCHEMA;
+ return internal::WriteSchemaMessage(schema, mapper, options, &out->metadata);
+}
+
+Status GetDictionaryPayload(int64_t id, const std::shared_ptr<Array>& dictionary,
+ const IpcWriteOptions& options, IpcPayload* out) {
+ return GetDictionaryPayload(id, false, dictionary, options, out);
+}
+
+Status GetDictionaryPayload(int64_t id, bool is_delta,
+ const std::shared_ptr<Array>& dictionary,
+ const IpcWriteOptions& options, IpcPayload* out) {
+ out->type = MessageType::DICTIONARY_BATCH;
+ // Frame of reference is 0, see ARROW-384
+ DictionarySerializer assembler(id, is_delta, /*buffer_start_offset=*/0, options, out);
+ return assembler.Assemble(dictionary);
+}
+
+Status GetRecordBatchPayload(const RecordBatch& batch, const IpcWriteOptions& options,
+ IpcPayload* out) {
+ out->type = MessageType::RECORD_BATCH;
+ RecordBatchSerializer assembler(/*buffer_start_offset=*/0, options, out);
+ return assembler.Assemble(batch);
+}
+
+Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset,
+ io::OutputStream* dst, int32_t* metadata_length,
+ int64_t* body_length, const IpcWriteOptions& options) {
+ IpcPayload payload;
+ RecordBatchSerializer assembler(buffer_start_offset, options, &payload);
+ RETURN_NOT_OK(assembler.Assemble(batch));
+
+ // TODO: it's a rough edge that the metadata and body length here are
+ // computed separately
+
+ // The body size is computed in the payload
+ *body_length = payload.body_length;
+
+ return WriteIpcPayload(payload, options, dst, metadata_length);
+}
+
+Status WriteRecordBatchStream(const std::vector<std::shared_ptr<RecordBatch>>& batches,
+ const IpcWriteOptions& options, io::OutputStream* dst) {
+ ASSIGN_OR_RAISE(std::shared_ptr<RecordBatchWriter> writer,
+ MakeStreamWriter(dst, batches[0]->schema(), options));
+ for (const auto& batch : batches) {
+ DCHECK(batch->schema()->Equals(*batches[0]->schema())) << "Schemas unequal";
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+ }
+ RETURN_NOT_OK(writer->Close());
+ return Status::OK();
+}
+
+namespace {
+
+Status WriteTensorHeader(const Tensor& tensor, io::OutputStream* dst,
+ int32_t* metadata_length) {
+ IpcWriteOptions options;
+ options.alignment = kTensorAlignment;
+ std::shared_ptr<Buffer> metadata;
+ ARROW_ASSIGN_OR_RAISE(metadata, internal::WriteTensorMessage(tensor, 0, options));
+ return WriteMessage(*metadata, options, dst, metadata_length);
+}
+
+Status WriteStridedTensorData(int dim_index, int64_t offset, int elem_size,
+ const Tensor& tensor, uint8_t* scratch_space,
+ io::OutputStream* dst) {
+ if (dim_index == tensor.ndim() - 1) {
+ const uint8_t* data_ptr = tensor.raw_data() + offset;
+ const int64_t stride = tensor.strides()[dim_index];
+ for (int64_t i = 0; i < tensor.shape()[dim_index]; ++i) {
+ memcpy(scratch_space + i * elem_size, data_ptr, elem_size);
+ data_ptr += stride;
+ }
+ return dst->Write(scratch_space, elem_size * tensor.shape()[dim_index]);
+ }
+ for (int64_t i = 0; i < tensor.shape()[dim_index]; ++i) {
+ RETURN_NOT_OK(WriteStridedTensorData(dim_index + 1, offset, elem_size, tensor,
+ scratch_space, dst));
+ offset += tensor.strides()[dim_index];
+ }
+ return Status::OK();
+}
+
+Status GetContiguousTensor(const Tensor& tensor, MemoryPool* pool,
+ std::unique_ptr<Tensor>* out) {
+ const int elem_size = GetByteWidth(*tensor.type());
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto scratch_space,
+ AllocateBuffer(tensor.shape()[tensor.ndim() - 1] * elem_size, pool));
+
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ResizableBuffer> contiguous_data,
+ AllocateResizableBuffer(tensor.size() * elem_size, pool));
+
+ io::BufferOutputStream stream(contiguous_data);
+ RETURN_NOT_OK(WriteStridedTensorData(0, 0, elem_size, tensor,
+ scratch_space->mutable_data(), &stream));
+
+ out->reset(new Tensor(tensor.type(), contiguous_data, tensor.shape()));
+
+ return Status::OK();
+}
+
+} // namespace
+
+Status WriteTensor(const Tensor& tensor, io::OutputStream* dst, int32_t* metadata_length,
+ int64_t* body_length) {
+ const int elem_size = GetByteWidth(*tensor.type());
+
+ *body_length = tensor.size() * elem_size;
+
+ // Tensor metadata accounts for padding
+ if (tensor.is_contiguous()) {
+ RETURN_NOT_OK(WriteTensorHeader(tensor, dst, metadata_length));
+ auto data = tensor.data();
+ if (data && data->data()) {
+ RETURN_NOT_OK(dst->Write(data->data(), *body_length));
+ } else {
+ *body_length = 0;
+ }
+ } else {
+ // The tensor written is made contiguous
+ Tensor dummy(tensor.type(), nullptr, tensor.shape());
+ RETURN_NOT_OK(WriteTensorHeader(dummy, dst, metadata_length));
+
+ // TODO: Do we care enough about this temporary allocation to pass in a
+ // MemoryPool to this function?
+ ARROW_ASSIGN_OR_RAISE(auto scratch_space,
+ AllocateBuffer(tensor.shape()[tensor.ndim() - 1] * elem_size));
+
+ RETURN_NOT_OK(WriteStridedTensorData(0, 0, elem_size, tensor,
+ scratch_space->mutable_data(), dst));
+ }
+
+ return Status::OK();
+}
+
+Result<std::unique_ptr<Message>> GetTensorMessage(const Tensor& tensor,
+ MemoryPool* pool) {
+ const Tensor* tensor_to_write = &tensor;
+ std::unique_ptr<Tensor> temp_tensor;
+
+ if (!tensor.is_contiguous()) {
+ RETURN_NOT_OK(GetContiguousTensor(tensor, pool, &temp_tensor));
+ tensor_to_write = temp_tensor.get();
+ }
+
+ IpcWriteOptions options;
+ options.alignment = kTensorAlignment;
+ std::shared_ptr<Buffer> metadata;
+ ARROW_ASSIGN_OR_RAISE(metadata,
+ internal::WriteTensorMessage(*tensor_to_write, 0, options));
+ return std::unique_ptr<Message>(new Message(metadata, tensor_to_write->data()));
+}
+
+namespace internal {
+
+class SparseTensorSerializer {
+ public:
+ SparseTensorSerializer(int64_t buffer_start_offset, IpcPayload* out)
+ : out_(out),
+ buffer_start_offset_(buffer_start_offset),
+ options_(IpcWriteOptions::Defaults()) {}
+
+ ~SparseTensorSerializer() = default;
+
+ Status VisitSparseIndex(const SparseIndex& sparse_index) {
+ switch (sparse_index.format_id()) {
+ case SparseTensorFormat::COO:
+ RETURN_NOT_OK(
+ VisitSparseCOOIndex(checked_cast<const SparseCOOIndex&>(sparse_index)));
+ break;
+
+ case SparseTensorFormat::CSR:
+ RETURN_NOT_OK(
+ VisitSparseCSRIndex(checked_cast<const SparseCSRIndex&>(sparse_index)));
+ break;
+
+ case SparseTensorFormat::CSC:
+ RETURN_NOT_OK(
+ VisitSparseCSCIndex(checked_cast<const SparseCSCIndex&>(sparse_index)));
+ break;
+
+ case SparseTensorFormat::CSF:
+ RETURN_NOT_OK(
+ VisitSparseCSFIndex(checked_cast<const SparseCSFIndex&>(sparse_index)));
+ break;
+
+ default:
+ std::stringstream ss;
+ ss << "Unable to convert type: " << sparse_index.ToString() << std::endl;
+ return Status::NotImplemented(ss.str());
+ }
+
+ return Status::OK();
+ }
+
+ Status SerializeMetadata(const SparseTensor& sparse_tensor) {
+ return WriteSparseTensorMessage(sparse_tensor, out_->body_length, buffer_meta_,
+ options_)
+ .Value(&out_->metadata);
+ }
+
+ Status Assemble(const SparseTensor& sparse_tensor) {
+ if (buffer_meta_.size() > 0) {
+ buffer_meta_.clear();
+ out_->body_buffers.clear();
+ }
+
+ RETURN_NOT_OK(VisitSparseIndex(*sparse_tensor.sparse_index()));
+ out_->body_buffers.emplace_back(sparse_tensor.data());
+
+ int64_t offset = buffer_start_offset_;
+ buffer_meta_.reserve(out_->body_buffers.size());
+
+ for (size_t i = 0; i < out_->body_buffers.size(); ++i) {
+ const Buffer* buffer = out_->body_buffers[i].get();
+ int64_t size = buffer->size();
+ int64_t padding = BitUtil::RoundUpToMultipleOf8(size) - size;
+ buffer_meta_.push_back({offset, size + padding});
+ offset += size + padding;
+ }
+
+ out_->body_length = offset - buffer_start_offset_;
+ DCHECK(BitUtil::IsMultipleOf8(out_->body_length));
+
+ return SerializeMetadata(sparse_tensor);
+ }
+
+ private:
+ Status VisitSparseCOOIndex(const SparseCOOIndex& sparse_index) {
+ out_->body_buffers.emplace_back(sparse_index.indices()->data());
+ return Status::OK();
+ }
+
+ Status VisitSparseCSRIndex(const SparseCSRIndex& sparse_index) {
+ out_->body_buffers.emplace_back(sparse_index.indptr()->data());
+ out_->body_buffers.emplace_back(sparse_index.indices()->data());
+ return Status::OK();
+ }
+
+ Status VisitSparseCSCIndex(const SparseCSCIndex& sparse_index) {
+ out_->body_buffers.emplace_back(sparse_index.indptr()->data());
+ out_->body_buffers.emplace_back(sparse_index.indices()->data());
+ return Status::OK();
+ }
+
+ Status VisitSparseCSFIndex(const SparseCSFIndex& sparse_index) {
+ for (const std::shared_ptr<arrow::Tensor>& indptr : sparse_index.indptr()) {
+ out_->body_buffers.emplace_back(indptr->data());
+ }
+ for (const std::shared_ptr<arrow::Tensor>& indices : sparse_index.indices()) {
+ out_->body_buffers.emplace_back(indices->data());
+ }
+ return Status::OK();
+ }
+
+ IpcPayload* out_;
+
+ std::vector<internal::BufferMetadata> buffer_meta_;
+ int64_t buffer_start_offset_;
+ IpcWriteOptions options_;
+};
+
+} // namespace internal
+
+Status WriteSparseTensor(const SparseTensor& sparse_tensor, io::OutputStream* dst,
+ int32_t* metadata_length, int64_t* body_length) {
+ IpcPayload payload;
+ internal::SparseTensorSerializer writer(0, &payload);
+ RETURN_NOT_OK(writer.Assemble(sparse_tensor));
+
+ *body_length = payload.body_length;
+ return WriteIpcPayload(payload, IpcWriteOptions::Defaults(), dst, metadata_length);
+}
+
+Status GetSparseTensorPayload(const SparseTensor& sparse_tensor, MemoryPool* pool,
+ IpcPayload* out) {
+ internal::SparseTensorSerializer writer(0, out);
+ return writer.Assemble(sparse_tensor);
+}
+
+Result<std::unique_ptr<Message>> GetSparseTensorMessage(const SparseTensor& sparse_tensor,
+ MemoryPool* pool) {
+ IpcPayload payload;
+ RETURN_NOT_OK(GetSparseTensorPayload(sparse_tensor, pool, &payload));
+ return std::unique_ptr<Message>(
+ new Message(std::move(payload.metadata), std::move(payload.body_buffers[0])));
+}
+
+int64_t GetPayloadSize(const IpcPayload& payload, const IpcWriteOptions& options) {
+ const int32_t prefix_size = options.write_legacy_ipc_format ? 4 : 8;
+ const int32_t flatbuffer_size = static_cast<int32_t>(payload.metadata->size());
+ const int32_t padded_message_length = static_cast<int32_t>(
+ PaddedLength(flatbuffer_size + prefix_size, options.alignment));
+ // body_length already accounts for padding
+ return payload.body_length + padded_message_length;
+}
+
+Status GetRecordBatchSize(const RecordBatch& batch, int64_t* size) {
+ return GetRecordBatchSize(batch, IpcWriteOptions::Defaults(), size);
+}
+
+Status GetRecordBatchSize(const RecordBatch& batch, const IpcWriteOptions& options,
+ int64_t* size) {
+ // emulates the behavior of Write without actually writing
+ int32_t metadata_length = 0;
+ int64_t body_length = 0;
+ io::MockOutputStream dst;
+ RETURN_NOT_OK(
+ WriteRecordBatch(batch, 0, &dst, &metadata_length, &body_length, options));
+ *size = dst.GetExtentBytesWritten();
+ return Status::OK();
+}
+
+Status GetTensorSize(const Tensor& tensor, int64_t* size) {
+ // emulates the behavior of Write without actually writing
+ int32_t metadata_length = 0;
+ int64_t body_length = 0;
+ io::MockOutputStream dst;
+ RETURN_NOT_OK(WriteTensor(tensor, &dst, &metadata_length, &body_length));
+ *size = dst.GetExtentBytesWritten();
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+
+RecordBatchWriter::~RecordBatchWriter() {}
+
+Status RecordBatchWriter::WriteTable(const Table& table, int64_t max_chunksize) {
+ TableBatchReader reader(table);
+
+ if (max_chunksize > 0) {
+ reader.set_chunksize(max_chunksize);
+ }
+
+ std::shared_ptr<RecordBatch> batch;
+ while (true) {
+ RETURN_NOT_OK(reader.ReadNext(&batch));
+ if (batch == nullptr) {
+ break;
+ }
+ RETURN_NOT_OK(WriteRecordBatch(*batch));
+ }
+
+ return Status::OK();
+}
+
+Status RecordBatchWriter::WriteTable(const Table& table) { return WriteTable(table, -1); }
+
+// ----------------------------------------------------------------------
+// Payload writer implementation
+
+namespace internal {
+
+IpcPayloadWriter::~IpcPayloadWriter() {}
+
+Status IpcPayloadWriter::Start() { return Status::OK(); }
+
+class ARROW_EXPORT IpcFormatWriter : public RecordBatchWriter {
+ public:
+ // A RecordBatchWriter implementation that writes to a IpcPayloadWriter.
+ IpcFormatWriter(std::unique_ptr<internal::IpcPayloadWriter> payload_writer,
+ const Schema& schema, const IpcWriteOptions& options,
+ bool is_file_format)
+ : payload_writer_(std::move(payload_writer)),
+ schema_(schema),
+ mapper_(schema),
+ is_file_format_(is_file_format),
+ options_(options) {}
+
+ // A Schema-owning constructor variant
+ IpcFormatWriter(std::unique_ptr<internal::IpcPayloadWriter> payload_writer,
+ const std::shared_ptr<Schema>& schema, const IpcWriteOptions& options,
+ bool is_file_format)
+ : IpcFormatWriter(std::move(payload_writer), *schema, options, is_file_format) {
+ shared_schema_ = schema;
+ }
+
+ Status WriteRecordBatch(const RecordBatch& batch) override {
+ if (!batch.schema()->Equals(schema_, false /* check_metadata */)) {
+ return Status::Invalid("Tried to write record batch with different schema");
+ }
+
+ RETURN_NOT_OK(CheckStarted());
+
+ RETURN_NOT_OK(WriteDictionaries(batch));
+
+ IpcPayload payload;
+ RETURN_NOT_OK(GetRecordBatchPayload(batch, options_, &payload));
+ RETURN_NOT_OK(WritePayload(payload));
+ ++stats_.num_record_batches;
+ return Status::OK();
+ }
+
+ Status WriteTable(const Table& table, int64_t max_chunksize) override {
+ if (is_file_format_ && options_.unify_dictionaries) {
+ ARROW_ASSIGN_OR_RAISE(auto unified_table,
+ DictionaryUnifier::UnifyTable(table, options_.memory_pool));
+ return RecordBatchWriter::WriteTable(*unified_table, max_chunksize);
+ } else {
+ return RecordBatchWriter::WriteTable(table, max_chunksize);
+ }
+ }
+
+ Status Close() override {
+ RETURN_NOT_OK(CheckStarted());
+ return payload_writer_->Close();
+ }
+
+ Status Start() {
+ started_ = true;
+ RETURN_NOT_OK(payload_writer_->Start());
+
+ IpcPayload payload;
+ RETURN_NOT_OK(GetSchemaPayload(schema_, options_, mapper_, &payload));
+ return WritePayload(payload);
+ }
+
+ WriteStats stats() const override { return stats_; }
+
+ protected:
+ Status CheckStarted() {
+ if (!started_) {
+ return Start();
+ }
+ return Status::OK();
+ }
+
+ Status WriteDictionaries(const RecordBatch& batch) {
+ ARROW_ASSIGN_OR_RAISE(const auto dictionaries, CollectDictionaries(batch, mapper_));
+ const auto equal_options = EqualOptions().nans_equal(true);
+
+ for (const auto& pair : dictionaries) {
+ int64_t dictionary_id = pair.first;
+ const auto& dictionary = pair.second;
+
+ // If a dictionary with this id was already emitted, check if it was the same.
+ auto* last_dictionary = &last_dictionaries_[dictionary_id];
+ const bool dictionary_exists = (*last_dictionary != nullptr);
+ int64_t delta_start = 0;
+ if (dictionary_exists) {
+ if ((*last_dictionary)->data() == dictionary->data()) {
+ // Fast shortcut for a common case.
+ // Same dictionary data by pointer => no need to emit it again
+ continue;
+ }
+ const int64_t last_length = (*last_dictionary)->length();
+ const int64_t new_length = dictionary->length();
+ if (new_length == last_length &&
+ ((*last_dictionary)->Equals(dictionary, equal_options))) {
+ // Same dictionary by value => no need to emit it again
+ // (while this can have a CPU cost, this code path is required
+ // for the IPC file format)
+ continue;
+ }
+ if (is_file_format_) {
+ return Status::Invalid(
+ "Dictionary replacement detected when writing IPC file format. "
+ "Arrow IPC files only support a single dictionary for a given field "
+ "across all batches.");
+ }
+
+ // (the read path doesn't support outer dictionary deltas, don't emit them)
+ if (new_length > last_length && options_.emit_dictionary_deltas &&
+ !HasNestedDict(*dictionary->data()) &&
+ ((*last_dictionary)
+ ->RangeEquals(dictionary, 0, last_length, 0, equal_options))) {
+ // New dictionary starts with the current dictionary
+ delta_start = last_length;
+ }
+ }
+
+ IpcPayload payload;
+ if (delta_start) {
+ RETURN_NOT_OK(GetDictionaryPayload(dictionary_id, /*is_delta=*/true,
+ dictionary->Slice(delta_start), options_,
+ &payload));
+ } else {
+ RETURN_NOT_OK(
+ GetDictionaryPayload(dictionary_id, dictionary, options_, &payload));
+ }
+ RETURN_NOT_OK(WritePayload(payload));
+ ++stats_.num_dictionary_batches;
+ if (dictionary_exists) {
+ if (delta_start) {
+ ++stats_.num_dictionary_deltas;
+ } else {
+ ++stats_.num_replaced_dictionaries;
+ }
+ }
+
+ // Remember dictionary for next batches
+ *last_dictionary = dictionary;
+ }
+ return Status::OK();
+ }
+
+ Status WritePayload(const IpcPayload& payload) {
+ RETURN_NOT_OK(payload_writer_->WritePayload(payload));
+ ++stats_.num_messages;
+ return Status::OK();
+ }
+
+ std::unique_ptr<IpcPayloadWriter> payload_writer_;
+ std::shared_ptr<Schema> shared_schema_;
+ const Schema& schema_;
+ const DictionaryFieldMapper mapper_;
+ const bool is_file_format_;
+
+ // A map of last-written dictionaries by id.
+ // This is required to avoid the same dictionary again and again,
+ // and also for correctness when writing the IPC file format
+ // (where replacements and deltas are unsupported).
+ // The latter is also why we can't use weak_ptr.
+ std::unordered_map<int64_t, std::shared_ptr<Array>> last_dictionaries_;
+
+ bool started_ = false;
+ IpcWriteOptions options_;
+ WriteStats stats_;
+};
+
+class StreamBookKeeper {
+ public:
+ StreamBookKeeper(const IpcWriteOptions& options, io::OutputStream* sink)
+ : options_(options), sink_(sink), position_(-1) {}
+ StreamBookKeeper(const IpcWriteOptions& options, std::shared_ptr<io::OutputStream> sink)
+ : options_(options),
+ sink_(sink.get()),
+ owned_sink_(std::move(sink)),
+ position_(-1) {}
+
+ Status UpdatePosition() { return sink_->Tell().Value(&position_); }
+
+ Status UpdatePositionCheckAligned() {
+ RETURN_NOT_OK(UpdatePosition());
+ DCHECK_EQ(0, position_ % 8) << "Stream is not aligned";
+ return Status::OK();
+ }
+
+ Status Align(int32_t alignment = kArrowIpcAlignment) {
+ // Adds padding bytes if necessary to ensure all memory blocks are written on
+ // 8-byte (or other alignment) boundaries.
+ int64_t remainder = PaddedLength(position_, alignment) - position_;
+ if (remainder > 0) {
+ return Write(kPaddingBytes, remainder);
+ }
+ return Status::OK();
+ }
+
+ // Write data and update position
+ Status Write(const void* data, int64_t nbytes) {
+ RETURN_NOT_OK(sink_->Write(data, nbytes));
+ position_ += nbytes;
+ return Status::OK();
+ }
+
+ Status WriteEOS() {
+ // End of stream marker
+ constexpr int32_t kZeroLength = 0;
+ if (!options_.write_legacy_ipc_format) {
+ RETURN_NOT_OK(Write(&kIpcContinuationToken, sizeof(int32_t)));
+ }
+ return Write(&kZeroLength, sizeof(int32_t));
+ }
+
+ protected:
+ IpcWriteOptions options_;
+ io::OutputStream* sink_;
+ std::shared_ptr<io::OutputStream> owned_sink_;
+ int64_t position_;
+};
+
+/// A IpcPayloadWriter implementation that writes to an IPC stream
+/// (with an end-of-stream marker)
+class PayloadStreamWriter : public IpcPayloadWriter, protected StreamBookKeeper {
+ public:
+ PayloadStreamWriter(io::OutputStream* sink,
+ const IpcWriteOptions& options = IpcWriteOptions::Defaults())
+ : StreamBookKeeper(options, sink) {}
+ PayloadStreamWriter(std::shared_ptr<io::OutputStream> sink,
+ const IpcWriteOptions& options = IpcWriteOptions::Defaults())
+ : StreamBookKeeper(options, std::move(sink)) {}
+
+ ~PayloadStreamWriter() override = default;
+
+ Status WritePayload(const IpcPayload& payload) override {
+#ifndef NDEBUG
+ // Catch bug fixed in ARROW-3236
+ RETURN_NOT_OK(UpdatePositionCheckAligned());
+#endif
+
+ int32_t metadata_length = 0; // unused
+ RETURN_NOT_OK(WriteIpcPayload(payload, options_, sink_, &metadata_length));
+ RETURN_NOT_OK(UpdatePositionCheckAligned());
+ return Status::OK();
+ }
+
+ Status Close() override { return WriteEOS(); }
+};
+
+/// A IpcPayloadWriter implementation that writes to a IPC file
+/// (with a footer as defined in File.fbs)
+class PayloadFileWriter : public internal::IpcPayloadWriter, protected StreamBookKeeper {
+ public:
+ PayloadFileWriter(const IpcWriteOptions& options, const std::shared_ptr<Schema>& schema,
+ const std::shared_ptr<const KeyValueMetadata>& metadata,
+ io::OutputStream* sink)
+ : StreamBookKeeper(options, sink), schema_(schema), metadata_(metadata) {}
+ PayloadFileWriter(const IpcWriteOptions& options, const std::shared_ptr<Schema>& schema,
+ const std::shared_ptr<const KeyValueMetadata>& metadata,
+ std::shared_ptr<io::OutputStream> sink)
+ : StreamBookKeeper(options, std::move(sink)),
+ schema_(schema),
+ metadata_(metadata) {}
+
+ ~PayloadFileWriter() override = default;
+
+ Status WritePayload(const IpcPayload& payload) override {
+#ifndef NDEBUG
+ // Catch bug fixed in ARROW-3236
+ RETURN_NOT_OK(UpdatePositionCheckAligned());
+#endif
+
+ // Metadata length must include padding, it's computed by WriteIpcPayload()
+ FileBlock block = {position_, 0, payload.body_length};
+ RETURN_NOT_OK(WriteIpcPayload(payload, options_, sink_, &block.metadata_length));
+ RETURN_NOT_OK(UpdatePositionCheckAligned());
+
+ // Record position and size of some message types, to list them in the footer
+ switch (payload.type) {
+ case MessageType::DICTIONARY_BATCH:
+ dictionaries_.push_back(block);
+ break;
+ case MessageType::RECORD_BATCH:
+ record_batches_.push_back(block);
+ break;
+ default:
+ break;
+ }
+
+ return Status::OK();
+ }
+
+ Status Start() override {
+ // ARROW-3236: The initial position -1 needs to be updated to the stream's
+ // current position otherwise an incorrect amount of padding will be
+ // written to new files.
+ RETURN_NOT_OK(UpdatePosition());
+
+ // It is only necessary to align to 8-byte boundary at the start of the file
+ RETURN_NOT_OK(Write(kArrowMagicBytes, strlen(kArrowMagicBytes)));
+ RETURN_NOT_OK(Align());
+
+ return Status::OK();
+ }
+
+ Status Close() override {
+ // Write 0 EOS message for compatibility with sequential readers
+ RETURN_NOT_OK(WriteEOS());
+
+ // Write file footer
+ RETURN_NOT_OK(UpdatePosition());
+ int64_t initial_position = position_;
+ RETURN_NOT_OK(
+ WriteFileFooter(*schema_, dictionaries_, record_batches_, metadata_, sink_));
+
+ // Write footer length
+ RETURN_NOT_OK(UpdatePosition());
+ int32_t footer_length = static_cast<int32_t>(position_ - initial_position);
+ if (footer_length <= 0) {
+ return Status::Invalid("Invalid file footer");
+ }
+
+ // write footer length in little endian
+ footer_length = BitUtil::ToLittleEndian(footer_length);
+ RETURN_NOT_OK(Write(&footer_length, sizeof(int32_t)));
+
+ // Write magic bytes to end file
+ return Write(kArrowMagicBytes, strlen(kArrowMagicBytes));
+ }
+
+ protected:
+ std::shared_ptr<Schema> schema_;
+ std::shared_ptr<const KeyValueMetadata> metadata_;
+ std::vector<FileBlock> dictionaries_;
+ std::vector<FileBlock> record_batches_;
+};
+
+} // namespace internal
+
+Result<std::shared_ptr<RecordBatchWriter>> MakeStreamWriter(
+ io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options) {
+ return std::make_shared<internal::IpcFormatWriter>(
+ ::arrow::internal::make_unique<internal::PayloadStreamWriter>(sink, options),
+ schema, options, /*is_file_format=*/false);
+}
+
+Result<std::shared_ptr<RecordBatchWriter>> MakeStreamWriter(
+ std::shared_ptr<io::OutputStream> sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options) {
+ return std::make_shared<internal::IpcFormatWriter>(
+ ::arrow::internal::make_unique<internal::PayloadStreamWriter>(std::move(sink),
+ options),
+ schema, options, /*is_file_format=*/false);
+}
+
+Result<std::shared_ptr<RecordBatchWriter>> NewStreamWriter(
+ io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options) {
+ return MakeStreamWriter(sink, schema, options);
+}
+
+Result<std::shared_ptr<RecordBatchWriter>> MakeFileWriter(
+ io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options,
+ const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ return std::make_shared<internal::IpcFormatWriter>(
+ ::arrow::internal::make_unique<internal::PayloadFileWriter>(options, schema,
+ metadata, sink),
+ schema, options, /*is_file_format=*/true);
+}
+
+Result<std::shared_ptr<RecordBatchWriter>> MakeFileWriter(
+ std::shared_ptr<io::OutputStream> sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options,
+ const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ return std::make_shared<internal::IpcFormatWriter>(
+ ::arrow::internal::make_unique<internal::PayloadFileWriter>(
+ options, schema, metadata, std::move(sink)),
+ schema, options, /*is_file_format=*/true);
+}
+
+Result<std::shared_ptr<RecordBatchWriter>> NewFileWriter(
+ io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options,
+ const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ return MakeFileWriter(sink, schema, options, metadata);
+}
+
+namespace internal {
+
+Result<std::unique_ptr<RecordBatchWriter>> OpenRecordBatchWriter(
+ std::unique_ptr<IpcPayloadWriter> sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options) {
+ // XXX should we call Start()?
+ return ::arrow::internal::make_unique<internal::IpcFormatWriter>(
+ std::move(sink), schema, options, /*is_file_format=*/false);
+}
+
+Result<std::unique_ptr<IpcPayloadWriter>> MakePayloadStreamWriter(
+ io::OutputStream* sink, const IpcWriteOptions& options) {
+ return ::arrow::internal::make_unique<internal::PayloadStreamWriter>(sink, options);
+}
+
+Result<std::unique_ptr<IpcPayloadWriter>> MakePayloadFileWriter(
+ io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options,
+ const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ return ::arrow::internal::make_unique<internal::PayloadFileWriter>(options, schema,
+ metadata, sink);
+}
+
+} // namespace internal
+
+// ----------------------------------------------------------------------
+// Serialization public APIs
+
+Result<std::shared_ptr<Buffer>> SerializeRecordBatch(const RecordBatch& batch,
+ std::shared_ptr<MemoryManager> mm) {
+ auto options = IpcWriteOptions::Defaults();
+ int64_t size = 0;
+ RETURN_NOT_OK(GetRecordBatchSize(batch, options, &size));
+ ARROW_ASSIGN_OR_RAISE(auto buffer, mm->AllocateBuffer(size));
+ ARROW_ASSIGN_OR_RAISE(auto writer, Buffer::GetWriter(buffer));
+
+ // XXX Should we have a helper function for getting a MemoryPool
+ // for any MemoryManager (not only CPU)?
+ if (mm->is_cpu()) {
+ options.memory_pool = checked_pointer_cast<CPUMemoryManager>(mm)->pool();
+ }
+ RETURN_NOT_OK(SerializeRecordBatch(batch, options, writer.get()));
+ RETURN_NOT_OK(writer->Close());
+ return buffer;
+}
+
+Result<std::shared_ptr<Buffer>> SerializeRecordBatch(const RecordBatch& batch,
+ const IpcWriteOptions& options) {
+ int64_t size = 0;
+ RETURN_NOT_OK(GetRecordBatchSize(batch, options, &size));
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> buffer,
+ AllocateBuffer(size, options.memory_pool));
+
+ io::FixedSizeBufferWriter stream(buffer);
+ RETURN_NOT_OK(SerializeRecordBatch(batch, options, &stream));
+ return buffer;
+}
+
+Status SerializeRecordBatch(const RecordBatch& batch, const IpcWriteOptions& options,
+ io::OutputStream* out) {
+ int32_t metadata_length = 0;
+ int64_t body_length = 0;
+ return WriteRecordBatch(batch, 0, out, &metadata_length, &body_length, options);
+}
+
+Result<std::shared_ptr<Buffer>> SerializeSchema(const Schema& schema, MemoryPool* pool) {
+ ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create(1024, pool));
+
+ auto options = IpcWriteOptions::Defaults();
+ const bool is_file_format = false; // indifferent as we don't write dictionaries
+ internal::IpcFormatWriter writer(
+ ::arrow::internal::make_unique<internal::PayloadStreamWriter>(stream.get()), schema,
+ options, is_file_format);
+ RETURN_NOT_OK(writer.Start());
+ return stream->Finish();
+}
+
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/ipc/writer.h b/src/arrow/cpp/src/arrow/ipc/writer.h
new file mode 100644
index 000000000..e976b41a1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/ipc/writer.h
@@ -0,0 +1,459 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Implement Arrow streaming binary format
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "arrow/ipc/dictionary.h" // IWYU pragma: export
+#include "arrow/ipc/message.h"
+#include "arrow/ipc/options.h"
+#include "arrow/result.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Array;
+class Buffer;
+class MemoryManager;
+class MemoryPool;
+class RecordBatch;
+class Schema;
+class Status;
+class Table;
+class Tensor;
+class SparseTensor;
+
+namespace io {
+
+class OutputStream;
+
+} // namespace io
+
+namespace ipc {
+
+/// \brief Intermediate data structure with metadata header, and zero
+/// or more buffers for the message body.
+struct IpcPayload {
+ MessageType type = MessageType::NONE;
+ std::shared_ptr<Buffer> metadata;
+ std::vector<std::shared_ptr<Buffer>> body_buffers;
+ int64_t body_length = 0;
+};
+
+struct WriteStats {
+ /// Number of IPC messages written.
+ int64_t num_messages = 0;
+ /// Number of record batches written.
+ int64_t num_record_batches = 0;
+ /// Number of dictionary batches written.
+ ///
+ /// Note: num_dictionary_batches >= num_dictionary_deltas + num_replaced_dictionaries
+ int64_t num_dictionary_batches = 0;
+
+ /// Number of dictionary deltas written.
+ int64_t num_dictionary_deltas = 0;
+ /// Number of replaced dictionaries (i.e. where a dictionary batch replaces
+ /// an existing dictionary with an unrelated new dictionary).
+ int64_t num_replaced_dictionaries = 0;
+};
+
+/// \class RecordBatchWriter
+/// \brief Abstract interface for writing a stream of record batches
+class ARROW_EXPORT RecordBatchWriter {
+ public:
+ virtual ~RecordBatchWriter();
+
+ /// \brief Write a record batch to the stream
+ ///
+ /// \param[in] batch the record batch to write to the stream
+ /// \return Status
+ virtual Status WriteRecordBatch(const RecordBatch& batch) = 0;
+
+ /// \brief Write possibly-chunked table by creating sequence of record batches
+ /// \param[in] table table to write
+ /// \return Status
+ Status WriteTable(const Table& table);
+
+ /// \brief Write Table with a particular chunksize
+ /// \param[in] table table to write
+ /// \param[in] max_chunksize maximum length of table chunks. To indicate
+ /// that no maximum should be enforced, pass -1.
+ /// \return Status
+ virtual Status WriteTable(const Table& table, int64_t max_chunksize);
+
+ /// \brief Perform any logic necessary to finish the stream
+ ///
+ /// \return Status
+ virtual Status Close() = 0;
+
+ /// \brief Return current write statistics
+ virtual WriteStats stats() const = 0;
+};
+
+/// \defgroup record-batch-writer-factories Functions for creating RecordBatchWriter
+/// instances
+///
+/// @{
+
+/// Create a new IPC stream writer from stream sink and schema. User is
+/// responsible for closing the actual OutputStream.
+///
+/// \param[in] sink output stream to write to
+/// \param[in] schema the schema of the record batches to be written
+/// \param[in] options options for serialization
+/// \return Result<std::shared_ptr<RecordBatchWriter>>
+ARROW_EXPORT
+Result<std::shared_ptr<RecordBatchWriter>> MakeStreamWriter(
+ io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options = IpcWriteOptions::Defaults());
+
+/// Create a new IPC stream writer from stream sink and schema. User is
+/// responsible for closing the actual OutputStream.
+///
+/// \param[in] sink output stream to write to
+/// \param[in] schema the schema of the record batches to be written
+/// \param[in] options options for serialization
+/// \return Result<std::shared_ptr<RecordBatchWriter>>
+ARROW_EXPORT
+Result<std::shared_ptr<RecordBatchWriter>> MakeStreamWriter(
+ std::shared_ptr<io::OutputStream> sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options = IpcWriteOptions::Defaults());
+
+/// Create a new IPC file writer from stream sink and schema
+///
+/// \param[in] sink output stream to write to
+/// \param[in] schema the schema of the record batches to be written
+/// \param[in] options options for serialization, optional
+/// \param[in] metadata custom metadata for File Footer, optional
+/// \return Result<std::shared_ptr<RecordBatchWriter>>
+ARROW_EXPORT
+Result<std::shared_ptr<RecordBatchWriter>> MakeFileWriter(
+ io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options = IpcWriteOptions::Defaults(),
+ const std::shared_ptr<const KeyValueMetadata>& metadata = NULLPTR);
+
+/// Create a new IPC file writer from stream sink and schema
+///
+/// \param[in] sink output stream to write to
+/// \param[in] schema the schema of the record batches to be written
+/// \param[in] options options for serialization, optional
+/// \param[in] metadata custom metadata for File Footer, optional
+/// \return Result<std::shared_ptr<RecordBatchWriter>>
+ARROW_EXPORT
+Result<std::shared_ptr<RecordBatchWriter>> MakeFileWriter(
+ std::shared_ptr<io::OutputStream> sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options = IpcWriteOptions::Defaults(),
+ const std::shared_ptr<const KeyValueMetadata>& metadata = NULLPTR);
+
+/// @}
+
+ARROW_DEPRECATED("Deprecated in 3.0.0. Use MakeStreamWriter")
+ARROW_EXPORT
+Result<std::shared_ptr<RecordBatchWriter>> NewStreamWriter(
+ io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options = IpcWriteOptions::Defaults());
+
+ARROW_DEPRECATED("Deprecated in 2.0.0. Use MakeFileWriter")
+ARROW_EXPORT
+Result<std::shared_ptr<RecordBatchWriter>> NewFileWriter(
+ io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options = IpcWriteOptions::Defaults(),
+ const std::shared_ptr<const KeyValueMetadata>& metadata = NULLPTR);
+
+/// \brief Low-level API for writing a record batch (without schema)
+/// to an OutputStream as encapsulated IPC message. See Arrow format
+/// documentation for more detail.
+///
+/// \param[in] batch the record batch to write
+/// \param[in] buffer_start_offset the start offset to use in the buffer metadata,
+/// generally should be 0
+/// \param[in] dst an OutputStream
+/// \param[out] metadata_length the size of the length-prefixed flatbuffer
+/// including padding to a 64-byte boundary
+/// \param[out] body_length the size of the contiguous buffer block plus
+/// \param[in] options options for serialization
+/// \return Status
+ARROW_EXPORT
+Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset,
+ io::OutputStream* dst, int32_t* metadata_length,
+ int64_t* body_length, const IpcWriteOptions& options);
+
+/// \brief Serialize record batch as encapsulated IPC message in a new buffer
+///
+/// \param[in] batch the record batch
+/// \param[in] options the IpcWriteOptions to use for serialization
+/// \return the serialized message
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> SerializeRecordBatch(const RecordBatch& batch,
+ const IpcWriteOptions& options);
+
+/// \brief Serialize record batch as encapsulated IPC message in a new buffer
+///
+/// \param[in] batch the record batch
+/// \param[in] mm a MemoryManager to allocate memory from
+/// \return the serialized message
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> SerializeRecordBatch(const RecordBatch& batch,
+ std::shared_ptr<MemoryManager> mm);
+
+/// \brief Write record batch to OutputStream
+///
+/// \param[in] batch the record batch to write
+/// \param[in] options the IpcWriteOptions to use for serialization
+/// \param[in] out the OutputStream to write the output to
+/// \return Status
+///
+/// If writing to pre-allocated memory, you can use
+/// arrow::ipc::GetRecordBatchSize to compute how much space is required
+ARROW_EXPORT
+Status SerializeRecordBatch(const RecordBatch& batch, const IpcWriteOptions& options,
+ io::OutputStream* out);
+
+/// \brief Serialize schema as encapsulated IPC message
+///
+/// \param[in] schema the schema to write
+/// \param[in] pool a MemoryPool to allocate memory from
+/// \return the serialized schema
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> SerializeSchema(const Schema& schema,
+ MemoryPool* pool = default_memory_pool());
+
+/// \brief Write multiple record batches to OutputStream, including schema
+/// \param[in] batches a vector of batches. Must all have same schema
+/// \param[in] options options for serialization
+/// \param[out] dst an OutputStream
+/// \return Status
+ARROW_EXPORT
+Status WriteRecordBatchStream(const std::vector<std::shared_ptr<RecordBatch>>& batches,
+ const IpcWriteOptions& options, io::OutputStream* dst);
+
+/// \brief Compute the number of bytes needed to write an IPC payload
+/// including metadata
+///
+/// \param[in] payload the IPC payload to write
+/// \param[in] options write options
+/// \return the size of the complete encapsulated message
+ARROW_EXPORT
+int64_t GetPayloadSize(const IpcPayload& payload,
+ const IpcWriteOptions& options = IpcWriteOptions::Defaults());
+
+/// \brief Compute the number of bytes needed to write a record batch including metadata
+///
+/// \param[in] batch the record batch to write
+/// \param[out] size the size of the complete encapsulated message
+/// \return Status
+ARROW_EXPORT
+Status GetRecordBatchSize(const RecordBatch& batch, int64_t* size);
+
+/// \brief Compute the number of bytes needed to write a record batch including metadata
+///
+/// \param[in] batch the record batch to write
+/// \param[in] options options for serialization
+/// \param[out] size the size of the complete encapsulated message
+/// \return Status
+ARROW_EXPORT
+Status GetRecordBatchSize(const RecordBatch& batch, const IpcWriteOptions& options,
+ int64_t* size);
+
+/// \brief Compute the number of bytes needed to write a tensor including metadata
+///
+/// \param[in] tensor the tensor to write
+/// \param[out] size the size of the complete encapsulated message
+/// \return Status
+ARROW_EXPORT
+Status GetTensorSize(const Tensor& tensor, int64_t* size);
+
+/// \brief EXPERIMENTAL: Convert arrow::Tensor to a Message with minimal memory
+/// allocation
+///
+/// \param[in] tensor the Tensor to write
+/// \param[in] pool MemoryPool to allocate space for metadata
+/// \return the resulting Message
+ARROW_EXPORT
+Result<std::unique_ptr<Message>> GetTensorMessage(const Tensor& tensor, MemoryPool* pool);
+
+/// \brief Write arrow::Tensor as a contiguous message.
+///
+/// The metadata and body are written assuming 64-byte alignment. It is the
+/// user's responsibility to ensure that the OutputStream has been aligned
+/// to a 64-byte multiple before writing the message.
+///
+/// The message is written out as followed:
+/// \code
+/// <metadata size> <metadata> <tensor data>
+/// \endcode
+///
+/// \param[in] tensor the Tensor to write
+/// \param[in] dst the OutputStream to write to
+/// \param[out] metadata_length the actual metadata length, including padding
+/// \param[out] body_length the actual message body length
+/// \return Status
+ARROW_EXPORT
+Status WriteTensor(const Tensor& tensor, io::OutputStream* dst, int32_t* metadata_length,
+ int64_t* body_length);
+
+/// \brief EXPERIMENTAL: Convert arrow::SparseTensor to a Message with minimal memory
+/// allocation
+///
+/// The message is written out as followed:
+/// \code
+/// <metadata size> <metadata> <sparse index> <sparse tensor body>
+/// \endcode
+///
+/// \param[in] sparse_tensor the SparseTensor to write
+/// \param[in] pool MemoryPool to allocate space for metadata
+/// \return the resulting Message
+ARROW_EXPORT
+Result<std::unique_ptr<Message>> GetSparseTensorMessage(const SparseTensor& sparse_tensor,
+ MemoryPool* pool);
+
+/// \brief EXPERIMENTAL: Write arrow::SparseTensor as a contiguous message. The metadata,
+/// sparse index, and body are written assuming 64-byte alignment. It is the
+/// user's responsibility to ensure that the OutputStream has been aligned
+/// to a 64-byte multiple before writing the message.
+///
+/// \param[in] sparse_tensor the SparseTensor to write
+/// \param[in] dst the OutputStream to write to
+/// \param[out] metadata_length the actual metadata length, including padding
+/// \param[out] body_length the actual message body length
+/// \return Status
+ARROW_EXPORT
+Status WriteSparseTensor(const SparseTensor& sparse_tensor, io::OutputStream* dst,
+ int32_t* metadata_length, int64_t* body_length);
+
+/// \brief Compute IpcPayload for the given schema
+/// \param[in] schema the Schema that is being serialized
+/// \param[in] options options for serialization
+/// \param[in] mapper object mapping dictionary fields to dictionary ids
+/// \param[out] out the returned vector of IpcPayloads
+/// \return Status
+ARROW_EXPORT
+Status GetSchemaPayload(const Schema& schema, const IpcWriteOptions& options,
+ const DictionaryFieldMapper& mapper, IpcPayload* out);
+
+/// \brief Compute IpcPayload for a dictionary
+/// \param[in] id the dictionary id
+/// \param[in] dictionary the dictionary values
+/// \param[in] options options for serialization
+/// \param[out] payload the output IpcPayload
+/// \return Status
+ARROW_EXPORT
+Status GetDictionaryPayload(int64_t id, const std::shared_ptr<Array>& dictionary,
+ const IpcWriteOptions& options, IpcPayload* payload);
+
+/// \brief Compute IpcPayload for a dictionary
+/// \param[in] id the dictionary id
+/// \param[in] is_delta whether the dictionary is a delta dictionary
+/// \param[in] dictionary the dictionary values
+/// \param[in] options options for serialization
+/// \param[out] payload the output IpcPayload
+/// \return Status
+ARROW_EXPORT
+Status GetDictionaryPayload(int64_t id, bool is_delta,
+ const std::shared_ptr<Array>& dictionary,
+ const IpcWriteOptions& options, IpcPayload* payload);
+
+/// \brief Compute IpcPayload for the given record batch
+/// \param[in] batch the RecordBatch that is being serialized
+/// \param[in] options options for serialization
+/// \param[out] out the returned IpcPayload
+/// \return Status
+ARROW_EXPORT
+Status GetRecordBatchPayload(const RecordBatch& batch, const IpcWriteOptions& options,
+ IpcPayload* out);
+
+/// \brief Write an IPC payload to the given stream.
+/// \param[in] payload the payload to write
+/// \param[in] options options for serialization
+/// \param[in] dst The stream to write the payload to.
+/// \param[out] metadata_length the length of the serialized metadata
+/// \return Status
+ARROW_EXPORT
+Status WriteIpcPayload(const IpcPayload& payload, const IpcWriteOptions& options,
+ io::OutputStream* dst, int32_t* metadata_length);
+
+/// \brief Compute IpcPayload for the given sparse tensor
+/// \param[in] sparse_tensor the SparseTensor that is being serialized
+/// \param[in,out] pool for any required temporary memory allocations
+/// \param[out] out the returned IpcPayload
+/// \return Status
+ARROW_EXPORT
+Status GetSparseTensorPayload(const SparseTensor& sparse_tensor, MemoryPool* pool,
+ IpcPayload* out);
+
+namespace internal {
+
+// These internal APIs may change without warning or deprecation
+
+class ARROW_EXPORT IpcPayloadWriter {
+ public:
+ virtual ~IpcPayloadWriter();
+
+ // Default implementation is a no-op
+ virtual Status Start();
+
+ virtual Status WritePayload(const IpcPayload& payload) = 0;
+
+ virtual Status Close() = 0;
+};
+
+/// Create a new IPC payload stream writer from stream sink. User is
+/// responsible for closing the actual OutputStream.
+///
+/// \param[in] sink output stream to write to
+/// \param[in] options options for serialization
+/// \return Result<std::shared_ptr<IpcPayloadWriter>>
+ARROW_EXPORT
+Result<std::unique_ptr<IpcPayloadWriter>> MakePayloadStreamWriter(
+ io::OutputStream* sink, const IpcWriteOptions& options = IpcWriteOptions::Defaults());
+
+/// Create a new IPC payload file writer from stream sink.
+///
+/// \param[in] sink output stream to write to
+/// \param[in] schema the schema of the record batches to be written
+/// \param[in] options options for serialization, optional
+/// \param[in] metadata custom metadata for File Footer, optional
+/// \return Status
+ARROW_EXPORT
+Result<std::unique_ptr<IpcPayloadWriter>> MakePayloadFileWriter(
+ io::OutputStream* sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options = IpcWriteOptions::Defaults(),
+ const std::shared_ptr<const KeyValueMetadata>& metadata = NULLPTR);
+
+/// Create a new RecordBatchWriter from IpcPayloadWriter and schema.
+///
+/// The format is implicitly the IPC stream format (allowing dictionary
+/// replacement and deltas).
+///
+/// \param[in] sink the IpcPayloadWriter to write to
+/// \param[in] schema the schema of the record batches to be written
+/// \param[in] options options for serialization
+/// \return Result<std::unique_ptr<RecordBatchWriter>>
+ARROW_EXPORT
+Result<std::unique_ptr<RecordBatchWriter>> OpenRecordBatchWriter(
+ std::unique_ptr<IpcPayloadWriter> sink, const std::shared_ptr<Schema>& schema,
+ const IpcWriteOptions& options = IpcWriteOptions::Defaults());
+
+} // namespace internal
+} // namespace ipc
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/CMakeLists.txt b/src/arrow/cpp/src/arrow/json/CMakeLists.txt
new file mode 100644
index 000000000..f09b15ce5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/CMakeLists.txt
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+add_arrow_test(test
+ SOURCES
+ chunked_builder_test.cc
+ chunker_test.cc
+ converter_test.cc
+ parser_test.cc
+ reader_test.cc
+ PREFIX
+ "arrow-json")
+
+add_arrow_benchmark(parser_benchmark PREFIX "arrow-json")
+arrow_install_all_headers("arrow/json")
+
+# pkg-config support
+arrow_add_pkg_config("arrow-json")
diff --git a/src/arrow/cpp/src/arrow/json/api.h b/src/arrow/cpp/src/arrow/json/api.h
new file mode 100644
index 000000000..47b56684b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/api.h
@@ -0,0 +1,21 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/json/options.h"
+#include "arrow/json/reader.h"
diff --git a/src/arrow/cpp/src/arrow/json/arrow-json.pc.in b/src/arrow/cpp/src/arrow/json/arrow-json.pc.in
new file mode 100644
index 000000000..ace2a07a3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/arrow-json.pc.in
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+Name: Apache Arrow JSON
+Description: JSON reader module for Apache Arrow
+Version: @ARROW_VERSION@
+Requires: arrow
diff --git a/src/arrow/cpp/src/arrow/json/chunked_builder.cc b/src/arrow/cpp/src/arrow/json/chunked_builder.cc
new file mode 100644
index 000000000..e95041ea0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/chunked_builder.cc
@@ -0,0 +1,470 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/json/chunked_builder.h"
+
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/json/converter.h"
+#include "arrow/table.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/task_group.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::TaskGroup;
+
+namespace json {
+
+class NonNestedChunkedArrayBuilder : public ChunkedArrayBuilder {
+ public:
+ NonNestedChunkedArrayBuilder(const std::shared_ptr<TaskGroup>& task_group,
+ std::shared_ptr<Converter> converter)
+ : ChunkedArrayBuilder(task_group), converter_(std::move(converter)) {}
+
+ Status Finish(std::shared_ptr<ChunkedArray>* out) override {
+ RETURN_NOT_OK(task_group_->Finish());
+ *out = std::make_shared<ChunkedArray>(std::move(chunks_), converter_->out_type());
+ chunks_.clear();
+ return Status::OK();
+ }
+
+ Status ReplaceTaskGroup(const std::shared_ptr<TaskGroup>& task_group) override {
+ RETURN_NOT_OK(task_group_->Finish());
+ task_group_ = task_group;
+ return Status::OK();
+ }
+
+ protected:
+ ArrayVector chunks_;
+ std::mutex mutex_;
+ std::shared_ptr<Converter> converter_;
+};
+
+class TypedChunkedArrayBuilder
+ : public NonNestedChunkedArrayBuilder,
+ public std::enable_shared_from_this<TypedChunkedArrayBuilder> {
+ public:
+ using NonNestedChunkedArrayBuilder::NonNestedChunkedArrayBuilder;
+
+ void Insert(int64_t block_index, const std::shared_ptr<Field>&,
+ const std::shared_ptr<Array>& unconverted) override {
+ std::unique_lock<std::mutex> lock(mutex_);
+ if (chunks_.size() <= static_cast<size_t>(block_index)) {
+ chunks_.resize(static_cast<size_t>(block_index) + 1, nullptr);
+ }
+ lock.unlock();
+
+ auto self = shared_from_this();
+
+ task_group_->Append([self, block_index, unconverted] {
+ std::shared_ptr<Array> converted;
+ RETURN_NOT_OK(self->converter_->Convert(unconverted, &converted));
+ std::unique_lock<std::mutex> lock(self->mutex_);
+ self->chunks_[block_index] = std::move(converted);
+ return Status::OK();
+ });
+ }
+};
+
+class InferringChunkedArrayBuilder
+ : public NonNestedChunkedArrayBuilder,
+ public std::enable_shared_from_this<InferringChunkedArrayBuilder> {
+ public:
+ InferringChunkedArrayBuilder(const std::shared_ptr<TaskGroup>& task_group,
+ const PromotionGraph* promotion_graph,
+ std::shared_ptr<Converter> converter)
+ : NonNestedChunkedArrayBuilder(task_group, std::move(converter)),
+ promotion_graph_(promotion_graph) {}
+
+ void Insert(int64_t block_index, const std::shared_ptr<Field>& unconverted_field,
+ const std::shared_ptr<Array>& unconverted) override {
+ std::unique_lock<std::mutex> lock(mutex_);
+ if (chunks_.size() <= static_cast<size_t>(block_index)) {
+ chunks_.resize(static_cast<size_t>(block_index) + 1, nullptr);
+ unconverted_.resize(chunks_.size(), nullptr);
+ unconverted_fields_.resize(chunks_.size(), nullptr);
+ }
+ unconverted_[block_index] = unconverted;
+ unconverted_fields_[block_index] = unconverted_field;
+ lock.unlock();
+ ScheduleConvertChunk(block_index);
+ }
+
+ void ScheduleConvertChunk(int64_t block_index) {
+ auto self = shared_from_this();
+ task_group_->Append([self, block_index] {
+ return self->TryConvertChunk(static_cast<size_t>(block_index));
+ });
+ }
+
+ Status TryConvertChunk(size_t block_index) {
+ std::unique_lock<std::mutex> lock(mutex_);
+ auto converter = converter_;
+ auto unconverted = unconverted_[block_index];
+ auto unconverted_field = unconverted_fields_[block_index];
+ std::shared_ptr<Array> converted;
+
+ lock.unlock();
+ Status st = converter->Convert(unconverted, &converted);
+ lock.lock();
+
+ if (converter != converter_) {
+ // another task promoted converter; reconvert
+ lock.unlock();
+ ScheduleConvertChunk(block_index);
+ return Status::OK();
+ }
+
+ if (st.ok()) {
+ // conversion succeeded
+ chunks_[block_index] = std::move(converted);
+ return Status::OK();
+ }
+
+ auto promoted_type =
+ promotion_graph_->Promote(converter_->out_type(), unconverted_field);
+ if (promoted_type == nullptr) {
+ // converter failed, no promotion available
+ return st;
+ }
+ RETURN_NOT_OK(MakeConverter(promoted_type, converter_->pool(), &converter_));
+
+ size_t nchunks = chunks_.size();
+ for (size_t i = 0; i < nchunks; ++i) {
+ if (i != block_index && chunks_[i]) {
+ // We're assuming the chunk was converted using the wrong type
+ // (which should be true unless the executor reorders tasks)
+ chunks_[i].reset();
+ lock.unlock();
+ ScheduleConvertChunk(i);
+ lock.lock();
+ }
+ }
+ lock.unlock();
+ ScheduleConvertChunk(block_index);
+ return Status::OK();
+ }
+
+ Status Finish(std::shared_ptr<ChunkedArray>* out) override {
+ RETURN_NOT_OK(NonNestedChunkedArrayBuilder::Finish(out));
+ unconverted_.clear();
+ return Status::OK();
+ }
+
+ private:
+ ArrayVector unconverted_;
+ std::vector<std::shared_ptr<Field>> unconverted_fields_;
+ const PromotionGraph* promotion_graph_;
+};
+
+class ChunkedListArrayBuilder : public ChunkedArrayBuilder {
+ public:
+ ChunkedListArrayBuilder(const std::shared_ptr<TaskGroup>& task_group, MemoryPool* pool,
+ std::shared_ptr<ChunkedArrayBuilder> value_builder,
+ const std::shared_ptr<Field>& value_field)
+ : ChunkedArrayBuilder(task_group),
+ pool_(pool),
+ value_builder_(std::move(value_builder)),
+ value_field_(value_field) {}
+
+ Status ReplaceTaskGroup(const std::shared_ptr<TaskGroup>& task_group) override {
+ RETURN_NOT_OK(task_group_->Finish());
+ RETURN_NOT_OK(value_builder_->ReplaceTaskGroup(task_group));
+ task_group_ = task_group;
+ return Status::OK();
+ }
+
+ void Insert(int64_t block_index, const std::shared_ptr<Field>&,
+ const std::shared_ptr<Array>& unconverted) override {
+ std::unique_lock<std::mutex> lock(mutex_);
+
+ if (null_bitmap_chunks_.size() <= static_cast<size_t>(block_index)) {
+ null_bitmap_chunks_.resize(static_cast<size_t>(block_index) + 1, nullptr);
+ offset_chunks_.resize(null_bitmap_chunks_.size(), nullptr);
+ }
+
+ if (unconverted->type_id() == Type::NA) {
+ auto st = InsertNull(block_index, unconverted->length());
+ if (!st.ok()) {
+ task_group_->Append([st] { return st; });
+ }
+ return;
+ }
+
+ DCHECK_EQ(unconverted->type_id(), Type::LIST);
+ const auto& list_array = checked_cast<const ListArray&>(*unconverted);
+
+ null_bitmap_chunks_[block_index] = unconverted->null_bitmap();
+ offset_chunks_[block_index] = list_array.value_offsets();
+
+ value_builder_->Insert(block_index, list_array.list_type()->value_field(),
+ list_array.values());
+ }
+
+ Status Finish(std::shared_ptr<ChunkedArray>* out) override {
+ RETURN_NOT_OK(task_group_->Finish());
+
+ std::shared_ptr<ChunkedArray> value_array;
+ RETURN_NOT_OK(value_builder_->Finish(&value_array));
+
+ auto type = list(value_field_->WithType(value_array->type())->WithMetadata(nullptr));
+ ArrayVector chunks(null_bitmap_chunks_.size());
+ for (size_t i = 0; i < null_bitmap_chunks_.size(); ++i) {
+ auto value_chunk = value_array->chunk(static_cast<int>(i));
+ auto length = offset_chunks_[i]->size() / sizeof(int32_t) - 1;
+ chunks[i] = std::make_shared<ListArray>(type, length, offset_chunks_[i],
+ value_chunk, null_bitmap_chunks_[i]);
+ }
+
+ *out = std::make_shared<ChunkedArray>(std::move(chunks), type);
+ return Status::OK();
+ }
+
+ private:
+ // call from Insert() only, with mutex_ locked
+ Status InsertNull(int64_t block_index, int64_t length) {
+ value_builder_->Insert(block_index, value_field_, std::make_shared<NullArray>(0));
+
+ ARROW_ASSIGN_OR_RAISE(null_bitmap_chunks_[block_index],
+ AllocateEmptyBitmap(length, pool_));
+
+ int64_t offsets_length = (length + 1) * sizeof(int32_t);
+ ARROW_ASSIGN_OR_RAISE(offset_chunks_[block_index],
+ AllocateBuffer(offsets_length, pool_));
+ std::memset(offset_chunks_[block_index]->mutable_data(), 0, offsets_length);
+
+ return Status::OK();
+ }
+
+ std::mutex mutex_;
+ MemoryPool* pool_;
+ std::shared_ptr<ChunkedArrayBuilder> value_builder_;
+ BufferVector offset_chunks_, null_bitmap_chunks_;
+ std::shared_ptr<Field> value_field_;
+};
+
+class ChunkedStructArrayBuilder : public ChunkedArrayBuilder {
+ public:
+ ChunkedStructArrayBuilder(
+ const std::shared_ptr<TaskGroup>& task_group, MemoryPool* pool,
+ const PromotionGraph* promotion_graph,
+ std::vector<std::pair<std::string, std::shared_ptr<ChunkedArrayBuilder>>>
+ name_builders)
+ : ChunkedArrayBuilder(task_group), pool_(pool), promotion_graph_(promotion_graph) {
+ for (auto&& name_builder : name_builders) {
+ auto index = static_cast<int>(name_to_index_.size());
+ name_to_index_.emplace(std::move(name_builder.first), index);
+ child_builders_.emplace_back(std::move(name_builder.second));
+ }
+ }
+
+ void Insert(int64_t block_index, const std::shared_ptr<Field>&,
+ const std::shared_ptr<Array>& unconverted) override {
+ std::unique_lock<std::mutex> lock(mutex_);
+
+ if (null_bitmap_chunks_.size() <= static_cast<size_t>(block_index)) {
+ null_bitmap_chunks_.resize(static_cast<size_t>(block_index) + 1, nullptr);
+ chunk_lengths_.resize(null_bitmap_chunks_.size(), -1);
+ child_absent_.resize(null_bitmap_chunks_.size(), std::vector<bool>(0));
+ }
+ null_bitmap_chunks_[block_index] = unconverted->null_bitmap();
+ chunk_lengths_[block_index] = unconverted->length();
+
+ if (unconverted->type_id() == Type::NA) {
+ auto maybe_buffer = AllocateBitmap(unconverted->length(), pool_);
+ if (maybe_buffer.ok()) {
+ null_bitmap_chunks_[block_index] = *std::move(maybe_buffer);
+ std::memset(null_bitmap_chunks_[block_index]->mutable_data(), 0,
+ null_bitmap_chunks_[block_index]->size());
+ } else {
+ Status st = maybe_buffer.status();
+ task_group_->Append([st] { return st; });
+ }
+
+ // absent fields will be inserted at Finish
+ return;
+ }
+
+ const auto& struct_array = checked_cast<const StructArray&>(*unconverted);
+ if (promotion_graph_ == nullptr) {
+ // If unexpected fields are ignored or result in an error then all parsers will emit
+ // columns exclusively in the ordering specified in ParseOptions::explicit_schema,
+ // so child_builders_ is immutable and no associative lookup is necessary.
+ for (int i = 0; i < unconverted->num_fields(); ++i) {
+ child_builders_[i]->Insert(block_index, unconverted->type()->field(i),
+ struct_array.field(i));
+ }
+ } else {
+ auto st = InsertChildren(block_index, struct_array);
+ if (!st.ok()) {
+ return task_group_->Append([st] { return st; });
+ }
+ }
+ }
+
+ Status Finish(std::shared_ptr<ChunkedArray>* out) override {
+ RETURN_NOT_OK(task_group_->Finish());
+
+ if (promotion_graph_ != nullptr) {
+ // insert absent child chunks
+ for (auto&& name_index : name_to_index_) {
+ auto child_builder = child_builders_[name_index.second].get();
+
+ RETURN_NOT_OK(child_builder->ReplaceTaskGroup(TaskGroup::MakeSerial()));
+
+ for (size_t i = 0; i < chunk_lengths_.size(); ++i) {
+ if (child_absent_[i].size() > static_cast<size_t>(name_index.second) &&
+ !child_absent_[i][name_index.second]) {
+ continue;
+ }
+ auto empty = std::make_shared<NullArray>(chunk_lengths_[i]);
+ child_builder->Insert(i, promotion_graph_->Null(name_index.first), empty);
+ }
+ }
+ }
+
+ std::vector<std::shared_ptr<Field>> fields(name_to_index_.size());
+ std::vector<std::shared_ptr<ChunkedArray>> child_arrays(name_to_index_.size());
+ for (auto&& name_index : name_to_index_) {
+ auto child_builder = child_builders_[name_index.second].get();
+
+ std::shared_ptr<ChunkedArray> child_array;
+ RETURN_NOT_OK(child_builder->Finish(&child_array));
+
+ child_arrays[name_index.second] = child_array;
+ fields[name_index.second] = field(name_index.first, child_array->type());
+ }
+
+ auto type = struct_(std::move(fields));
+ ArrayVector chunks(null_bitmap_chunks_.size());
+ for (size_t i = 0; i < null_bitmap_chunks_.size(); ++i) {
+ ArrayVector child_chunks;
+ for (const auto& child_array : child_arrays) {
+ child_chunks.push_back(child_array->chunk(static_cast<int>(i)));
+ }
+ chunks[i] = std::make_shared<StructArray>(type, chunk_lengths_[i], child_chunks,
+ null_bitmap_chunks_[i]);
+ }
+
+ *out = std::make_shared<ChunkedArray>(std::move(chunks), type);
+ return Status::OK();
+ }
+
+ Status ReplaceTaskGroup(const std::shared_ptr<TaskGroup>& task_group) override {
+ RETURN_NOT_OK(task_group_->Finish());
+ for (auto&& child_builder : child_builders_) {
+ RETURN_NOT_OK(child_builder->ReplaceTaskGroup(task_group));
+ }
+ task_group_ = task_group;
+ return Status::OK();
+ }
+
+ private:
+ // Insert children associatively by name; the unconverted block may have unexpected or
+ // differently ordered fields
+ // call from Insert() only, with mutex_ locked
+ Status InsertChildren(int64_t block_index, const StructArray& unconverted) {
+ const auto& fields = unconverted.type()->fields();
+
+ for (int i = 0; i < unconverted.num_fields(); ++i) {
+ auto it = name_to_index_.find(fields[i]->name());
+
+ if (it == name_to_index_.end()) {
+ // add a new field to this builder
+ auto type = promotion_graph_->Infer(fields[i]);
+ DCHECK_NE(type, nullptr)
+ << "invalid unconverted_field encountered in conversion: "
+ << fields[i]->name() << ":" << *fields[i]->type();
+
+ auto new_index = static_cast<int>(name_to_index_.size());
+ it = name_to_index_.emplace(fields[i]->name(), new_index).first;
+
+ std::shared_ptr<ChunkedArrayBuilder> child_builder;
+ RETURN_NOT_OK(MakeChunkedArrayBuilder(task_group_, pool_, promotion_graph_, type,
+ &child_builder));
+ child_builders_.emplace_back(std::move(child_builder));
+ }
+
+ auto unconverted_field = unconverted.type()->field(i);
+ child_builders_[it->second]->Insert(block_index, unconverted_field,
+ unconverted.field(i));
+
+ child_absent_[block_index].resize(child_builders_.size(), true);
+ child_absent_[block_index][it->second] = false;
+ }
+
+ return Status::OK();
+ }
+
+ std::mutex mutex_;
+ MemoryPool* pool_;
+ const PromotionGraph* promotion_graph_;
+ std::unordered_map<std::string, int> name_to_index_;
+ std::vector<std::shared_ptr<ChunkedArrayBuilder>> child_builders_;
+ std::vector<std::vector<bool>> child_absent_;
+ BufferVector null_bitmap_chunks_;
+ std::vector<int64_t> chunk_lengths_;
+};
+
+Status MakeChunkedArrayBuilder(const std::shared_ptr<TaskGroup>& task_group,
+ MemoryPool* pool, const PromotionGraph* promotion_graph,
+ const std::shared_ptr<DataType>& type,
+ std::shared_ptr<ChunkedArrayBuilder>* out) {
+ if (type->id() == Type::STRUCT) {
+ std::vector<std::pair<std::string, std::shared_ptr<ChunkedArrayBuilder>>>
+ child_builders;
+ for (const auto& f : type->fields()) {
+ std::shared_ptr<ChunkedArrayBuilder> child_builder;
+ RETURN_NOT_OK(MakeChunkedArrayBuilder(task_group, pool, promotion_graph, f->type(),
+ &child_builder));
+ child_builders.emplace_back(f->name(), std::move(child_builder));
+ }
+ *out = std::make_shared<ChunkedStructArrayBuilder>(task_group, pool, promotion_graph,
+ std::move(child_builders));
+ return Status::OK();
+ }
+ if (type->id() == Type::LIST) {
+ const auto& list_type = checked_cast<const ListType&>(*type);
+ std::shared_ptr<ChunkedArrayBuilder> value_builder;
+ RETURN_NOT_OK(MakeChunkedArrayBuilder(task_group, pool, promotion_graph,
+ list_type.value_type(), &value_builder));
+ *out = std::make_shared<ChunkedListArrayBuilder>(
+ task_group, pool, std::move(value_builder), list_type.value_field());
+ return Status::OK();
+ }
+ std::shared_ptr<Converter> converter;
+ RETURN_NOT_OK(MakeConverter(type, pool, &converter));
+ if (promotion_graph) {
+ *out = std::make_shared<InferringChunkedArrayBuilder>(task_group, promotion_graph,
+ std::move(converter));
+ } else {
+ *out = std::make_shared<TypedChunkedArrayBuilder>(task_group, std::move(converter));
+ }
+ return Status::OK();
+}
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/chunked_builder.h b/src/arrow/cpp/src/arrow/json/chunked_builder.h
new file mode 100644
index 000000000..93b327bf3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/chunked_builder.h
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace json {
+
+class PromotionGraph;
+
+class ARROW_EXPORT ChunkedArrayBuilder {
+ public:
+ virtual ~ChunkedArrayBuilder() = default;
+
+ /// Spawn a task that will try to convert and insert the given JSON block
+ virtual void Insert(int64_t block_index,
+ const std::shared_ptr<Field>& unconverted_field,
+ const std::shared_ptr<Array>& unconverted) = 0;
+
+ /// Return the final chunked array.
+ /// Every chunk must be inserted before this is called!
+ virtual Status Finish(std::shared_ptr<ChunkedArray>* out) = 0;
+
+ /// Finish current task group and substitute a new one
+ virtual Status ReplaceTaskGroup(
+ const std::shared_ptr<arrow::internal::TaskGroup>& task_group) = 0;
+
+ protected:
+ explicit ChunkedArrayBuilder(
+ const std::shared_ptr<arrow::internal::TaskGroup>& task_group)
+ : task_group_(task_group) {}
+
+ std::shared_ptr<arrow::internal::TaskGroup> task_group_;
+};
+
+/// create a chunked builder
+///
+/// if unexpected fields and promotion need to be handled, promotion_graph must be
+/// non-null
+ARROW_EXPORT Status MakeChunkedArrayBuilder(
+ const std::shared_ptr<arrow::internal::TaskGroup>& task_group, MemoryPool* pool,
+ const PromotionGraph* promotion_graph, const std::shared_ptr<DataType>& type,
+ std::shared_ptr<ChunkedArrayBuilder>* out);
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/chunked_builder_test.cc b/src/arrow/cpp/src/arrow/json/chunked_builder_test.cc
new file mode 100644
index 000000000..d04f0d5c9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/chunked_builder_test.cc
@@ -0,0 +1,450 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/json/chunked_builder.h"
+#include "arrow/json/converter.h"
+#include "arrow/json/options.h"
+#include "arrow/json/test_common.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/task_group.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+namespace json {
+
+using util::string_view;
+
+using internal::checked_cast;
+using internal::GetCpuThreadPool;
+using internal::TaskGroup;
+
+void AssertBuilding(const std::shared_ptr<ChunkedArrayBuilder>& builder,
+ const std::vector<std::string>& chunks,
+ std::shared_ptr<ChunkedArray>* out) {
+ ArrayVector unconverted;
+
+ auto options = ParseOptions::Defaults();
+ for (const auto& chunk : chunks) {
+ std::shared_ptr<Array> parsed;
+ ASSERT_OK(ParseFromString(options, chunk, &parsed));
+ unconverted.push_back(parsed);
+ }
+
+ int64_t i = 0;
+ for (const auto& parsed : unconverted) {
+ builder->Insert(i, field("", parsed->type()), parsed);
+ ++i;
+ }
+ ASSERT_OK(builder->Finish(out));
+ ASSERT_OK((*out)->ValidateFull());
+}
+
+std::shared_ptr<ChunkedArray> ExtractField(const std::string& name,
+ const ChunkedArray& columns) {
+ auto chunks = columns.chunks();
+ for (auto& chunk : chunks) {
+ chunk = checked_cast<const StructArray&>(*chunk).GetFieldByName(name);
+ }
+ const auto& struct_type = checked_cast<const StructType&>(*columns.type());
+ return std::make_shared<ChunkedArray>(chunks, struct_type.GetFieldByName(name)->type());
+}
+
+void AssertFieldEqual(const std::vector<std::string>& path,
+ const std::shared_ptr<ChunkedArray>& columns,
+ const ChunkedArray& expected) {
+ ASSERT_EQ(expected.num_chunks(), columns->num_chunks()) << "# chunks unequal";
+ std::shared_ptr<ChunkedArray> actual = columns;
+ for (const auto& name : path) {
+ actual = ExtractField(name, *actual);
+ }
+ AssertChunkedEqual(expected, *actual);
+}
+
+TEST(ChunkedArrayBuilder, Empty) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), nullptr,
+ struct_({field("a", int32())}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder, {}, &actual);
+
+ ChunkedArray expected({}, int32());
+ AssertFieldEqual({"a"}, actual, expected);
+}
+
+TEST(ChunkedArrayBuilder, Basics) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), nullptr,
+ struct_({field("a", int32())}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder, {RowsOfOneColumn("a", {123, -456})}, &actual);
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<Int32Type>({{123, -456}}, &expected);
+ AssertFieldEqual({"a"}, actual, *expected);
+}
+
+TEST(ChunkedArrayBuilder, Insert) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), nullptr,
+ struct_({field("a", int32())}), &builder));
+
+ auto options = ParseOptions::Defaults();
+ std::shared_ptr<ChunkedArray> actual, expected;
+
+ std::shared_ptr<Array> parsed;
+ ASSERT_OK(ParseFromString(options, RowsOfOneColumn("a", {-456}), &parsed));
+ builder->Insert(1, field("", parsed->type()), parsed);
+ ASSERT_OK(ParseFromString(options, RowsOfOneColumn("a", {123}), &parsed));
+ builder->Insert(0, field("", parsed->type()), parsed);
+
+ ASSERT_OK(builder->Finish(&actual));
+
+ ChunkedArrayFromVector<Int32Type>({{123}, {-456}}, &expected);
+ AssertFieldEqual({"a"}, actual, *expected);
+}
+
+TEST(ChunkedArrayBuilder, MultipleChunks) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), nullptr,
+ struct_({field("a", int32())}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder,
+ {
+ RowsOfOneColumn("a", {1, 2, 3}),
+ RowsOfOneColumn("a", {4, 5}),
+ },
+ &actual);
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<Int32Type>({{1, 2, 3}, {4, 5}}, &expected);
+ AssertFieldEqual({"a"}, actual, *expected);
+}
+
+TEST(ChunkedArrayBuilder, MultipleChunksParallel) {
+ auto tg = TaskGroup::MakeThreaded(GetCpuThreadPool());
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), nullptr,
+ struct_({field("a", int32())}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder,
+ {
+ RowsOfOneColumn("a", {1, 2}),
+ RowsOfOneColumn("a", {3}),
+ RowsOfOneColumn("a", {4, 5}),
+ RowsOfOneColumn("a", {6, 7}),
+ },
+ &actual);
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<Int32Type>({{1, 2}, {3}, {4, 5}, {6, 7}}, &expected);
+ AssertFieldEqual({"a"}, actual, *expected);
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Tests for type-inferring chunked array builders
+
+TEST(InferringChunkedArrayBuilder, Empty) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), GetPromotionGraph(),
+ struct_({}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder, {}, &actual);
+
+ ASSERT_TRUE(actual->type()->Equals(*struct_({})));
+ ASSERT_EQ(actual->num_chunks(), 0);
+}
+
+TEST(InferringChunkedArrayBuilder, SingleChunkNull) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), GetPromotionGraph(),
+ struct_({}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder,
+ {
+ "{}\n" + RowsOfOneColumn("a", {"null", "null"}),
+ },
+ &actual);
+
+ ASSERT_TRUE(actual->type()->Equals(*struct_({field("a", null())})));
+ ASSERT_EQ(actual->length(), 3);
+}
+
+TEST(InferringChunkedArrayBuilder, MultipleChunkNull) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), GetPromotionGraph(),
+ struct_({}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder,
+ {
+ "{}\n{}\n",
+ "{}\n" + RowsOfOneColumn("a", {"null", "null"}),
+ RowsOfOneColumn("a", {"null"}),
+ RowsOfOneColumn("a", {"null", "null"}) + "{}\n",
+ },
+ &actual);
+
+ ASSERT_TRUE(actual->type()->Equals(*struct_({field("a", null())})));
+ ASSERT_EQ(actual->length(), 9);
+}
+
+TEST(InferringChunkedArrayBuilder, SingleChunkInteger) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), GetPromotionGraph(),
+ struct_({}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(
+ builder,
+ {
+ "{}\n" + RowsOfOneColumn("a", {123, 456}) + RowsOfOneColumn("a", {"null"}),
+ },
+ &actual);
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<Int64Type>({{false, true, true, false}}, {{0, 123, 456, 0}},
+ &expected);
+ AssertFieldEqual({"a"}, actual, *expected);
+}
+
+TEST(InferringChunkedArrayBuilder, MultipleChunkInteger) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), GetPromotionGraph(),
+ struct_({}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder,
+ {
+ "{}\n{}\n",
+ RowsOfOneColumn("a", {"null"}),
+ "{}\n" + RowsOfOneColumn("a", {123, 456}),
+ },
+ &actual);
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<Int64Type>({{false, false}, {false}, {false, true, true}},
+ {{0, 0}, {0}, {0, 123, 456}}, &expected);
+ AssertFieldEqual({"a"}, actual, *expected);
+}
+
+TEST(InferringChunkedArrayBuilder, SingleChunkDouble) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), GetPromotionGraph(),
+ struct_({}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(
+ builder,
+ {
+ "{}\n" + RowsOfOneColumn("a", {0.0, 12.5}) + RowsOfOneColumn("a", {"null"}),
+ },
+ &actual);
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<DoubleType>({{false, true, true, false}},
+ {{0.0, 0.0, 12.5, 0.0}}, &expected);
+ AssertFieldEqual({"a"}, actual, *expected);
+}
+
+TEST(InferringChunkedArrayBuilder, MultipleChunkDouble) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), GetPromotionGraph(),
+ struct_({}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder,
+ {
+ "{}\n{}\n",
+ RowsOfOneColumn("a", {"null"}),
+ RowsOfOneColumn("a", {8}),
+ RowsOfOneColumn("a", {"null", "12.5"}),
+ },
+ &actual);
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<DoubleType>({{false, false}, {false}, {true}, {false, true}},
+ {{0.0, 0.0}, {0.0}, {8.0}, {0.0, 12.5}}, &expected);
+ AssertFieldEqual({"a"}, actual, *expected);
+}
+
+TEST(InferringChunkedArrayBuilder, SingleChunkTimestamp) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), GetPromotionGraph(),
+ struct_({}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder,
+ {
+ "{}\n" + RowsOfOneColumn("a", {"null", "\"1970-01-01\"",
+ "\"2018-11-13 17:11:10\""}),
+ },
+ &actual);
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<TimestampType>(timestamp(TimeUnit::SECOND),
+ {{false, false, true, true}},
+ {{0, 0, 0, 1542129070}}, &expected);
+ AssertFieldEqual({"a"}, actual, *expected);
+}
+
+TEST(InferringChunkedArrayBuilder, MultipleChunkTimestamp) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), GetPromotionGraph(),
+ struct_({}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder,
+ {
+ "{}\n{}\n",
+ RowsOfOneColumn("a", {"null"}),
+ RowsOfOneColumn("a", {"\"1970-01-01\""}),
+ RowsOfOneColumn("a", {"\"2018-11-13 17:11:10\""}),
+ },
+ &actual);
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<TimestampType>(timestamp(TimeUnit::SECOND),
+ {{false, false}, {false}, {true}, {true}},
+ {{0, 0}, {0}, {0}, {1542129070}}, &expected);
+ AssertFieldEqual({"a"}, actual, *expected);
+}
+
+TEST(InferringChunkedArrayBuilder, SingleChunkString) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), GetPromotionGraph(),
+ struct_({}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(
+ builder,
+ {
+ "{}\n" + RowsOfOneColumn("a", {"null", "\"\"", "null", "\"foo\"", "\"baré\""}),
+ },
+ &actual);
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<StringType, std::string>(
+ {{false, false, true, false, true, true}}, {{"", "", "", "", "foo", "baré"}},
+ &expected);
+ AssertFieldEqual({"a"}, actual, *expected);
+}
+
+TEST(InferringChunkedArrayBuilder, MultipleChunkString) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), GetPromotionGraph(),
+ struct_({}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder,
+ {
+ "{}\n{}\n",
+ RowsOfOneColumn("a", {"\"\"", "null"}),
+ RowsOfOneColumn("a", {"\"1970-01-01\""}),
+ RowsOfOneColumn("a", {"\"\"", "\"baré\""}),
+ },
+ &actual);
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<StringType, std::string>(
+ {{false, false}, {true, false}, {true}, {true, true}},
+ {{"", ""}, {"", ""}, {"1970-01-01"}, {"", "baré"}}, &expected);
+ AssertFieldEqual({"a"}, actual, *expected);
+}
+
+TEST(InferringChunkedArrayBuilder, MultipleChunkIntegerParallel) {
+ auto tg = TaskGroup::MakeThreaded(GetCpuThreadPool());
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), GetPromotionGraph(),
+ struct_({}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ std::vector<std::string> chunks;
+ std::vector<std::vector<int>> expected_chunks;
+ for (int i = 0; i < 1 << 10; ++i) {
+ expected_chunks.push_back({i, i + 1, i + 2, i + 3});
+ chunks.push_back(RowsOfOneColumn("a", {i, i + 1, i + 2, i + 3}));
+ }
+ AssertBuilding(builder, chunks, &actual);
+
+ std::shared_ptr<ChunkedArray> expected;
+ ChunkedArrayFromVector<Int64Type>(expected_chunks, &expected);
+ AssertFieldEqual({"a"}, actual, *expected);
+}
+
+TEST(InferringChunkedArrayBuilder, SingleChunkList) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), GetPromotionGraph(),
+ struct_({}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder,
+ {
+ std::string("{}\n") + "{\"a\": []}\n" + "{\"a\": [1, 2]}\n",
+ },
+ &actual);
+
+ auto expected = ChunkedArrayFromJSON(list(int64()), {"[null, [], [1, 2]]"});
+ AssertFieldEqual({"a"}, actual, *expected);
+}
+
+TEST(InferringChunkedArrayBuilder, MultipleChunkList) {
+ auto tg = TaskGroup::MakeSerial();
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ ASSERT_OK(MakeChunkedArrayBuilder(tg, default_memory_pool(), GetPromotionGraph(),
+ struct_({}), &builder));
+
+ std::shared_ptr<ChunkedArray> actual;
+ AssertBuilding(builder, {"{}\n", "{\"a\": []}\n", "{\"a\": [1, 2]}\n", "{}\n"},
+ &actual);
+
+ auto expected =
+ ChunkedArrayFromJSON(list(int64()), {"[null]", "[[]]", "[[1, 2]]", "[null]"});
+ AssertFieldEqual({"a"}, actual, *expected);
+}
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/chunker.cc b/src/arrow/cpp/src/arrow/json/chunker.cc
new file mode 100644
index 000000000..b4b4d31eb
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/chunker.cc
@@ -0,0 +1,186 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/json/chunker.h"
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "arrow/json/rapidjson_defs.h"
+#include "rapidjson/reader.h"
+
+#include "arrow/buffer.h"
+#include "arrow/json/options.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+
+using internal::make_unique;
+using util::string_view;
+
+namespace json {
+
+namespace rj = arrow::rapidjson;
+
+static size_t ConsumeWhitespace(string_view view) {
+#ifdef RAPIDJSON_SIMD
+ auto data = view.data();
+ auto nonws_begin = rj::SkipWhitespace_SIMD(data, data + view.size());
+ return nonws_begin - data;
+#else
+ auto ws_count = view.find_first_not_of(" \t\r\n");
+ if (ws_count == string_view::npos) {
+ return view.size();
+ } else {
+ return ws_count;
+ }
+#endif
+}
+
+/// RapidJson custom stream for reading JSON stored in multiple buffers
+/// http://rapidjson.org/md_doc_stream.html#CustomStream
+class MultiStringStream {
+ public:
+ using Ch = char;
+ explicit MultiStringStream(std::vector<string_view> strings)
+ : strings_(std::move(strings)) {
+ std::reverse(strings_.begin(), strings_.end());
+ }
+ explicit MultiStringStream(const BufferVector& buffers) : strings_(buffers.size()) {
+ for (size_t i = 0; i < buffers.size(); ++i) {
+ strings_[i] = string_view(*buffers[i]);
+ }
+ std::reverse(strings_.begin(), strings_.end());
+ }
+ char Peek() const {
+ if (strings_.size() == 0) return '\0';
+ return strings_.back()[0];
+ }
+ char Take() {
+ if (strings_.size() == 0) return '\0';
+ char taken = strings_.back()[0];
+ if (strings_.back().size() == 1) {
+ strings_.pop_back();
+ } else {
+ strings_.back() = strings_.back().substr(1);
+ }
+ ++index_;
+ return taken;
+ }
+ size_t Tell() { return index_; }
+ void Put(char) { ARROW_LOG(FATAL) << "not implemented"; }
+ void Flush() { ARROW_LOG(FATAL) << "not implemented"; }
+ char* PutBegin() {
+ ARROW_LOG(FATAL) << "not implemented";
+ return nullptr;
+ }
+ size_t PutEnd(char*) {
+ ARROW_LOG(FATAL) << "not implemented";
+ return 0;
+ }
+
+ private:
+ size_t index_ = 0;
+ std::vector<string_view> strings_;
+};
+
+template <typename Stream>
+static size_t ConsumeWholeObject(Stream&& stream) {
+ static constexpr unsigned parse_flags = rj::kParseIterativeFlag |
+ rj::kParseStopWhenDoneFlag |
+ rj::kParseNumbersAsStringsFlag;
+ rj::BaseReaderHandler<rj::UTF8<>> handler;
+ rj::Reader reader;
+ // parse a single JSON object
+ switch (reader.Parse<parse_flags>(stream, handler).Code()) {
+ case rj::kParseErrorNone:
+ return stream.Tell();
+ case rj::kParseErrorDocumentEmpty:
+ return 0;
+ default:
+ // rapidjson emitted an error, the most recent object was partial
+ return string_view::npos;
+ }
+}
+
+namespace {
+
+// A BoundaryFinder implementation that assumes JSON objects can contain raw newlines,
+// and uses actual JSON parsing to delimit them.
+class ParsingBoundaryFinder : public BoundaryFinder {
+ public:
+ Status FindFirst(string_view partial, string_view block, int64_t* out_pos) override {
+ // NOTE: We could bubble up JSON parse errors here, but the actual parsing
+ // step will detect them later anyway.
+ auto length = ConsumeWholeObject(MultiStringStream({partial, block}));
+ if (length == string_view::npos) {
+ *out_pos = -1;
+ } else {
+ DCHECK_GE(length, partial.size());
+ DCHECK_LE(length, partial.size() + block.size());
+ *out_pos = static_cast<int64_t>(length - partial.size());
+ }
+ return Status::OK();
+ }
+
+ Status FindLast(util::string_view block, int64_t* out_pos) override {
+ const size_t block_length = block.size();
+ size_t consumed_length = 0;
+ while (consumed_length < block_length) {
+ rj::MemoryStream ms(reinterpret_cast<const char*>(block.data()), block.size());
+ using InputStream = rj::EncodedInputStream<rj::UTF8<>, rj::MemoryStream>;
+ auto length = ConsumeWholeObject(InputStream(ms));
+ if (length == string_view::npos || length == 0) {
+ // found incomplete object or block is empty
+ break;
+ }
+ consumed_length += length;
+ block = block.substr(length);
+ }
+ if (consumed_length == 0) {
+ *out_pos = -1;
+ } else {
+ consumed_length += ConsumeWhitespace(block);
+ DCHECK_LE(consumed_length, block_length);
+ *out_pos = static_cast<int64_t>(consumed_length);
+ }
+ return Status::OK();
+ }
+
+ Status FindNth(util::string_view partial, util::string_view block, int64_t count,
+ int64_t* out_pos, int64_t* num_found) override {
+ return Status::NotImplemented("ParsingBoundaryFinder::FindNth");
+ }
+};
+
+} // namespace
+
+std::unique_ptr<Chunker> MakeChunker(const ParseOptions& options) {
+ std::shared_ptr<BoundaryFinder> delimiter;
+ if (options.newlines_in_values) {
+ delimiter = std::make_shared<ParsingBoundaryFinder>();
+ } else {
+ delimiter = MakeNewlineBoundaryFinder();
+ }
+ return std::unique_ptr<Chunker>(new Chunker(std::move(delimiter)));
+}
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/chunker.h b/src/arrow/cpp/src/arrow/json/chunker.h
new file mode 100644
index 000000000..9ed85126d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/chunker.h
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/util/delimiting.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace json {
+
+struct ParseOptions;
+
+ARROW_EXPORT
+std::unique_ptr<Chunker> MakeChunker(const ParseOptions& options);
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/chunker_test.cc b/src/arrow/cpp/src/arrow/json/chunker_test.cc
new file mode 100644
index 000000000..1b4ea4d08
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/chunker_test.cc
@@ -0,0 +1,276 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/buffer.h"
+#include "arrow/json/chunker.h"
+#include "arrow/json/test_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace json {
+
+// Use no nested objects and no string literals containing braces in this test.
+// This way the positions of '{' and '}' can be used as simple proxies
+// for object begin/end.
+
+using util::string_view;
+
+template <typename Lines>
+static std::shared_ptr<Buffer> join(Lines&& lines, std::string delimiter,
+ bool delimiter_at_end = true) {
+ std::shared_ptr<Buffer> joined;
+ BufferVector line_buffers;
+ auto delimiter_buffer = std::make_shared<Buffer>(delimiter);
+ for (const auto& line : lines) {
+ line_buffers.push_back(std::make_shared<Buffer>(line));
+ line_buffers.push_back(delimiter_buffer);
+ }
+ if (!delimiter_at_end) {
+ line_buffers.pop_back();
+ }
+ return *ConcatenateBuffers(line_buffers);
+}
+
+static bool WhitespaceOnly(string_view s) {
+ return s.find_first_not_of(" \t\r\n") == string_view::npos;
+}
+
+static bool WhitespaceOnly(const std::shared_ptr<Buffer>& b) {
+ return WhitespaceOnly(string_view(*b));
+}
+
+static std::size_t ConsumeWholeObject(std::shared_ptr<Buffer>* buf) {
+ auto str = string_view(**buf);
+ auto fail = [buf] {
+ *buf = nullptr;
+ return string_view::npos;
+ };
+ if (WhitespaceOnly(str)) return fail();
+ auto open_brace = str.find_first_not_of(" \t\r\n");
+ if (str.at(open_brace) != '{') return fail();
+ auto close_brace = str.find_first_of("}");
+ if (close_brace == string_view::npos) return fail();
+ auto length = close_brace + 1;
+ *buf = SliceBuffer(*buf, length);
+ return length;
+}
+
+void AssertOnlyWholeObjects(Chunker& chunker, std::shared_ptr<Buffer> whole, int* count) {
+ *count = 0;
+ while (whole && !WhitespaceOnly(whole)) {
+ auto buf = whole;
+ if (ConsumeWholeObject(&whole) == string_view::npos) {
+ FAIL() << "Not a whole JSON object: '" << buf->ToString() << "'";
+ }
+ ++*count;
+ }
+}
+
+void AssertWholeObjects(Chunker& chunker, const std::shared_ptr<Buffer>& block,
+ int expected_count) {
+ std::shared_ptr<Buffer> whole, partial;
+ ASSERT_OK(chunker.Process(block, &whole, &partial));
+ int count;
+ AssertOnlyWholeObjects(chunker, whole, &count);
+ ASSERT_EQ(count, expected_count);
+}
+
+void AssertChunking(Chunker& chunker, std::shared_ptr<Buffer> buf, int total_count) {
+ // First chunkize whole JSON block
+ AssertWholeObjects(chunker, buf, total_count);
+
+ // Then chunkize incomplete substrings of the block
+ for (int i = 0; i < total_count; ++i) {
+ // ensure shearing the closing brace off the last object causes it to be chunked out
+ auto last_brace = string_view(*buf).find_last_of('}');
+ AssertWholeObjects(chunker, SliceBuffer(buf, 0, last_brace), total_count - i - 1);
+
+ // ensure skipping one object reduces the count by one
+ ASSERT_NE(ConsumeWholeObject(&buf), string_view::npos);
+ AssertWholeObjects(chunker, buf, total_count - i - 1);
+ }
+}
+
+void AssertChunkingBlockSize(Chunker& chunker, std::shared_ptr<Buffer> buf,
+ int64_t block_size, int expected_count) {
+ std::shared_ptr<Buffer> partial = Buffer::FromString({});
+ int64_t pos = 0;
+ int total_count = 0;
+ while (pos < buf->size()) {
+ int count;
+ auto block = SliceBuffer(buf, pos, std::min(block_size, buf->size() - pos));
+ pos += block->size();
+ std::shared_ptr<Buffer> completion, whole, next_partial;
+
+ if (pos == buf->size()) {
+ // Last block
+ ASSERT_OK(chunker.ProcessFinal(partial, block, &completion, &whole));
+ } else {
+ std::shared_ptr<Buffer> starts_with_whole;
+ ASSERT_OK(
+ chunker.ProcessWithPartial(partial, block, &completion, &starts_with_whole));
+ ASSERT_OK(chunker.Process(starts_with_whole, &whole, &next_partial));
+ }
+ // partial + completion should be a valid JSON block
+ ASSERT_OK_AND_ASSIGN(partial, ConcatenateBuffers({partial, completion}));
+ AssertOnlyWholeObjects(chunker, partial, &count);
+ total_count += count;
+ // whole should be a valid JSON block
+ AssertOnlyWholeObjects(chunker, whole, &count);
+ total_count += count;
+ partial = next_partial;
+ }
+ ASSERT_EQ(pos, buf->size());
+ ASSERT_EQ(total_count, expected_count);
+}
+
+void AssertStraddledChunking(Chunker& chunker, const std::shared_ptr<Buffer>& buf) {
+ auto first_half = SliceBuffer(buf, 0, buf->size() / 2);
+ auto second_half = SliceBuffer(buf, buf->size() / 2);
+ AssertChunking(chunker, first_half, 1);
+ std::shared_ptr<Buffer> first_whole, partial;
+ ASSERT_OK(chunker.Process(first_half, &first_whole, &partial));
+ ASSERT_TRUE(string_view(*first_half).starts_with(string_view(*first_whole)));
+ std::shared_ptr<Buffer> completion, rest;
+ ASSERT_OK(chunker.ProcessWithPartial(partial, second_half, &completion, &rest));
+ ASSERT_TRUE(string_view(*second_half).starts_with(string_view(*completion)));
+ std::shared_ptr<Buffer> straddling;
+ ASSERT_OK_AND_ASSIGN(straddling, ConcatenateBuffers({partial, completion}));
+ auto length = ConsumeWholeObject(&straddling);
+ ASSERT_NE(length, string_view::npos);
+ ASSERT_NE(length, 0);
+ auto final_whole = SliceBuffer(second_half, completion->size());
+ ASSERT_EQ(string_view(*final_whole), string_view(*rest));
+ length = ConsumeWholeObject(&final_whole);
+ ASSERT_NE(length, string_view::npos);
+ ASSERT_NE(length, 0);
+}
+
+std::unique_ptr<Chunker> MakeChunker(bool newlines_in_values) {
+ auto options = ParseOptions::Defaults();
+ options.newlines_in_values = newlines_in_values;
+ return MakeChunker(options);
+}
+
+class BaseChunkerTest : public ::testing::TestWithParam<bool> {
+ protected:
+ void SetUp() override { chunker_ = MakeChunker(GetParam()); }
+
+ std::unique_ptr<Chunker> chunker_;
+};
+
+INSTANTIATE_TEST_SUITE_P(NoNewlineChunkerTest, BaseChunkerTest, ::testing::Values(false));
+
+INSTANTIATE_TEST_SUITE_P(ChunkerTest, BaseChunkerTest, ::testing::Values(true));
+
+constexpr int object_count = 4;
+constexpr int min_block_size = 28;
+
+static const std::vector<std::string>& lines() {
+ // clang-format off
+ static const std::vector<std::string> l = {
+ R"({"0":"ab","1":"c","2":""})",
+ R"({"0":"def","1":"","2":"gh"})",
+ R"({"0":null})",
+ R"({"0":"","1":"ij","2":"kl"})"
+ };
+ // clang-format on
+ return l;
+}
+
+TEST_P(BaseChunkerTest, Basics) {
+ AssertChunking(*chunker_, join(lines(), "\n"), object_count);
+}
+
+TEST_P(BaseChunkerTest, BlockSizes) {
+ auto check_block_sizes = [&](std::shared_ptr<Buffer> data) {
+ for (int64_t block_size = min_block_size; block_size < min_block_size + 30;
+ ++block_size) {
+ AssertChunkingBlockSize(*chunker_, data, block_size, object_count);
+ }
+ };
+
+ check_block_sizes(join(lines(), "\n"));
+ check_block_sizes(join(lines(), "\r\n"));
+ // Without ending newline
+ check_block_sizes(join(lines(), "\n", false));
+ check_block_sizes(join(lines(), "\r\n", false));
+}
+
+TEST_P(BaseChunkerTest, Empty) {
+ auto empty = std::make_shared<Buffer>("\n");
+ AssertChunking(*chunker_, empty, 0);
+ empty = std::make_shared<Buffer>("\n\n");
+ AssertChunking(*chunker_, empty, 0);
+}
+
+TEST(ChunkerTest, PrettyPrinted) {
+ std::string pretty[object_count];
+ std::transform(std::begin(lines()), std::end(lines()), std::begin(pretty), PrettyPrint);
+ auto chunker = MakeChunker(true);
+ AssertChunking(*chunker, join(pretty, "\n"), object_count);
+}
+
+TEST(ChunkerTest, SingleLine) {
+ auto chunker = MakeChunker(true);
+ auto single_line = join(lines(), "");
+ AssertChunking(*chunker, single_line, object_count);
+}
+
+TEST_P(BaseChunkerTest, Straddling) {
+ AssertStraddledChunking(*chunker_, join(lines(), "\n"));
+}
+
+TEST(ChunkerTest, StraddlingPrettyPrinted) {
+ std::string pretty[object_count];
+ std::transform(std::begin(lines()), std::end(lines()), std::begin(pretty), PrettyPrint);
+ auto chunker = MakeChunker(true);
+ AssertStraddledChunking(*chunker, join(pretty, "\n"));
+}
+
+TEST(ChunkerTest, StraddlingSingleLine) {
+ auto chunker = MakeChunker(true);
+ AssertStraddledChunking(*chunker, join(lines(), ""));
+}
+
+TEST_P(BaseChunkerTest, StraddlingEmpty) {
+ auto all = join(lines(), "\n");
+
+ auto first = SliceBuffer(all, 0, lines()[0].size() + 1);
+ std::shared_ptr<Buffer> first_whole, partial;
+ ASSERT_OK(chunker_->Process(first, &first_whole, &partial));
+ ASSERT_TRUE(WhitespaceOnly(partial));
+
+ auto others = SliceBuffer(all, first->size());
+ std::shared_ptr<Buffer> completion, rest;
+ ASSERT_OK(chunker_->ProcessWithPartial(partial, others, &completion, &rest));
+ ASSERT_EQ(completion->size(), 0);
+ ASSERT_TRUE(rest->Equals(*others));
+}
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/converter.cc b/src/arrow/cpp/src/arrow/json/converter.cc
new file mode 100644
index 000000000..a2f584c0b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/converter.cc
@@ -0,0 +1,362 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/json/converter.h"
+
+#include <memory>
+#include <utility>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_decimal.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/builder_time.h"
+#include "arrow/json/parser.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/value_parsing.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using util::string_view;
+
+namespace json {
+
+template <typename... Args>
+Status GenericConversionError(const DataType& type, Args&&... args) {
+ return Status::Invalid("Failed of conversion of JSON to ", type,
+ std::forward<Args>(args)...);
+}
+
+namespace {
+
+const DictionaryArray& GetDictionaryArray(const std::shared_ptr<Array>& in) {
+ DCHECK_EQ(in->type_id(), Type::DICTIONARY);
+ auto dict_type = checked_cast<const DictionaryType*>(in->type().get());
+ DCHECK_EQ(dict_type->index_type()->id(), Type::INT32);
+ DCHECK_EQ(dict_type->value_type()->id(), Type::STRING);
+ return checked_cast<const DictionaryArray&>(*in);
+}
+
+template <typename ValidVisitor, typename NullVisitor>
+Status VisitDictionaryEntries(const DictionaryArray& dict_array,
+ ValidVisitor&& visit_valid, NullVisitor&& visit_null) {
+ const StringArray& dict = checked_cast<const StringArray&>(*dict_array.dictionary());
+ const Int32Array& indices = checked_cast<const Int32Array&>(*dict_array.indices());
+ for (int64_t i = 0; i < indices.length(); ++i) {
+ if (indices.IsValid(i)) {
+ RETURN_NOT_OK(visit_valid(dict.GetView(indices.GetView(i))));
+ } else {
+ RETURN_NOT_OK(visit_null());
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+// base class for types which accept and output non-nested types
+class PrimitiveConverter : public Converter {
+ public:
+ PrimitiveConverter(MemoryPool* pool, std::shared_ptr<DataType> out_type)
+ : Converter(pool, out_type) {}
+};
+
+class NullConverter : public PrimitiveConverter {
+ public:
+ using PrimitiveConverter::PrimitiveConverter;
+
+ Status Convert(const std::shared_ptr<Array>& in, std::shared_ptr<Array>* out) override {
+ if (in->type_id() != Type::NA) {
+ return GenericConversionError(*out_type_, " from ", *in->type());
+ }
+ *out = in;
+ return Status::OK();
+ }
+};
+
+class BooleanConverter : public PrimitiveConverter {
+ public:
+ using PrimitiveConverter::PrimitiveConverter;
+
+ Status Convert(const std::shared_ptr<Array>& in, std::shared_ptr<Array>* out) override {
+ if (in->type_id() == Type::NA) {
+ return MakeArrayOfNull(boolean(), in->length(), pool_).Value(out);
+ }
+ if (in->type_id() != Type::BOOL) {
+ return GenericConversionError(*out_type_, " from ", *in->type());
+ }
+ *out = in;
+ return Status::OK();
+ }
+};
+
+template <typename T>
+class NumericConverter : public PrimitiveConverter {
+ public:
+ using value_type = typename T::c_type;
+
+ NumericConverter(MemoryPool* pool, const std::shared_ptr<DataType>& type)
+ : PrimitiveConverter(pool, type), numeric_type_(checked_cast<const T&>(*type)) {}
+
+ Status Convert(const std::shared_ptr<Array>& in, std::shared_ptr<Array>* out) override {
+ if (in->type_id() == Type::NA) {
+ return MakeArrayOfNull(out_type_, in->length(), pool_).Value(out);
+ }
+ const auto& dict_array = GetDictionaryArray(in);
+
+ using Builder = typename TypeTraits<T>::BuilderType;
+ Builder builder(out_type_, pool_);
+ RETURN_NOT_OK(builder.Resize(dict_array.indices()->length()));
+
+ auto visit_valid = [&](string_view repr) {
+ value_type value;
+ if (!arrow::internal::ParseValue(numeric_type_, repr.data(), repr.size(), &value)) {
+ return GenericConversionError(*out_type_, ", couldn't parse:", repr);
+ }
+
+ builder.UnsafeAppend(value);
+ return Status::OK();
+ };
+
+ auto visit_null = [&]() {
+ builder.UnsafeAppendNull();
+ return Status::OK();
+ };
+
+ RETURN_NOT_OK(VisitDictionaryEntries(dict_array, visit_valid, visit_null));
+ return builder.Finish(out);
+ }
+
+ const T& numeric_type_;
+};
+
+template <typename T>
+class DecimalConverter : public PrimitiveConverter {
+ public:
+ using value_type = typename TypeTraits<T>::BuilderType::ValueType;
+
+ DecimalConverter(MemoryPool* pool, const std::shared_ptr<DataType>& type)
+ : PrimitiveConverter(pool, type) {}
+
+ Status Convert(const std::shared_ptr<Array>& in, std::shared_ptr<Array>* out) override {
+ if (in->type_id() == Type::NA) {
+ return MakeArrayOfNull(out_type_, in->length(), pool_).Value(out);
+ }
+ const auto& dict_array = GetDictionaryArray(in);
+
+ using Builder = typename TypeTraits<T>::BuilderType;
+ Builder builder(out_type_, pool_);
+ RETURN_NOT_OK(builder.Resize(dict_array.indices()->length()));
+
+ auto visit_valid = [&builder](string_view repr) {
+ ARROW_ASSIGN_OR_RAISE(value_type value,
+ TypeTraits<T>::BuilderType::ValueType::FromString(repr));
+ builder.UnsafeAppend(value);
+ return Status::OK();
+ };
+
+ auto visit_null = [&builder]() {
+ builder.UnsafeAppendNull();
+ return Status::OK();
+ };
+
+ RETURN_NOT_OK(VisitDictionaryEntries(dict_array, visit_valid, visit_null));
+ return builder.Finish(out);
+ }
+};
+
+template <typename DateTimeType>
+class DateTimeConverter : public PrimitiveConverter {
+ public:
+ DateTimeConverter(MemoryPool* pool, const std::shared_ptr<DataType>& type)
+ : PrimitiveConverter(pool, type), converter_(pool, repr_type()) {}
+
+ Status Convert(const std::shared_ptr<Array>& in, std::shared_ptr<Array>* out) override {
+ if (in->type_id() == Type::NA) {
+ return MakeArrayOfNull(out_type_, in->length(), pool_).Value(out);
+ }
+
+ std::shared_ptr<Array> repr;
+ RETURN_NOT_OK(converter_.Convert(in, &repr));
+
+ auto out_data = repr->data()->Copy();
+ out_data->type = out_type_;
+ *out = MakeArray(out_data);
+
+ return Status::OK();
+ }
+
+ private:
+ using ReprType = typename CTypeTraits<typename DateTimeType::c_type>::ArrowType;
+ static std::shared_ptr<DataType> repr_type() {
+ return TypeTraits<ReprType>::type_singleton();
+ }
+ NumericConverter<ReprType> converter_;
+};
+
+template <typename T>
+class BinaryConverter : public PrimitiveConverter {
+ public:
+ using PrimitiveConverter::PrimitiveConverter;
+
+ Status Convert(const std::shared_ptr<Array>& in, std::shared_ptr<Array>* out) override {
+ if (in->type_id() == Type::NA) {
+ return MakeArrayOfNull(out_type_, in->length(), pool_).Value(out);
+ }
+ const auto& dict_array = GetDictionaryArray(in);
+
+ using Builder = typename TypeTraits<T>::BuilderType;
+ Builder builder(out_type_, pool_);
+ RETURN_NOT_OK(builder.Resize(dict_array.indices()->length()));
+
+ // TODO(bkietz) this can be computed during parsing at low cost
+ int64_t data_length = 0;
+ auto visit_lengths_valid = [&](string_view value) {
+ data_length += value.size();
+ return Status::OK();
+ };
+
+ auto visit_lengths_null = [&]() {
+ // no-op
+ return Status::OK();
+ };
+
+ RETURN_NOT_OK(
+ VisitDictionaryEntries(dict_array, visit_lengths_valid, visit_lengths_null));
+ RETURN_NOT_OK(builder.ReserveData(data_length));
+
+ auto visit_valid = [&](string_view value) {
+ builder.UnsafeAppend(value);
+ return Status::OK();
+ };
+
+ auto visit_null = [&]() {
+ builder.UnsafeAppendNull();
+ return Status::OK();
+ };
+
+ RETURN_NOT_OK(VisitDictionaryEntries(dict_array, visit_valid, visit_null));
+ return builder.Finish(out);
+ }
+};
+
+Status MakeConverter(const std::shared_ptr<DataType>& out_type, MemoryPool* pool,
+ std::shared_ptr<Converter>* out) {
+ switch (out_type->id()) {
+#define CONVERTER_CASE(TYPE_ID, CONVERTER_TYPE) \
+ case TYPE_ID: \
+ *out = std::make_shared<CONVERTER_TYPE>(pool, out_type); \
+ break
+ CONVERTER_CASE(Type::NA, NullConverter);
+ CONVERTER_CASE(Type::BOOL, BooleanConverter);
+ CONVERTER_CASE(Type::INT8, NumericConverter<Int8Type>);
+ CONVERTER_CASE(Type::INT16, NumericConverter<Int16Type>);
+ CONVERTER_CASE(Type::INT32, NumericConverter<Int32Type>);
+ CONVERTER_CASE(Type::INT64, NumericConverter<Int64Type>);
+ CONVERTER_CASE(Type::UINT8, NumericConverter<UInt8Type>);
+ CONVERTER_CASE(Type::UINT16, NumericConverter<UInt16Type>);
+ CONVERTER_CASE(Type::UINT32, NumericConverter<UInt32Type>);
+ CONVERTER_CASE(Type::UINT64, NumericConverter<UInt64Type>);
+ CONVERTER_CASE(Type::FLOAT, NumericConverter<FloatType>);
+ CONVERTER_CASE(Type::DOUBLE, NumericConverter<DoubleType>);
+ CONVERTER_CASE(Type::TIMESTAMP, NumericConverter<TimestampType>);
+ CONVERTER_CASE(Type::TIME32, DateTimeConverter<Time32Type>);
+ CONVERTER_CASE(Type::TIME64, DateTimeConverter<Time64Type>);
+ CONVERTER_CASE(Type::DATE32, DateTimeConverter<Date32Type>);
+ CONVERTER_CASE(Type::DATE64, DateTimeConverter<Date64Type>);
+ CONVERTER_CASE(Type::BINARY, BinaryConverter<BinaryType>);
+ CONVERTER_CASE(Type::STRING, BinaryConverter<StringType>);
+ CONVERTER_CASE(Type::LARGE_BINARY, BinaryConverter<LargeBinaryType>);
+ CONVERTER_CASE(Type::LARGE_STRING, BinaryConverter<LargeStringType>);
+ CONVERTER_CASE(Type::DECIMAL128, DecimalConverter<Decimal128Type>);
+ CONVERTER_CASE(Type::DECIMAL256, DecimalConverter<Decimal256Type>);
+ default:
+ return Status::NotImplemented("JSON conversion to ", *out_type,
+ " is not supported");
+#undef CONVERTER_CASE
+ }
+ return Status::OK();
+}
+
+const PromotionGraph* GetPromotionGraph() {
+ static struct : PromotionGraph {
+ std::shared_ptr<Field> Null(const std::string& name) const override {
+ return field(name, null(), true, Kind::Tag(Kind::kNull));
+ }
+
+ std::shared_ptr<DataType> Infer(
+ const std::shared_ptr<Field>& unexpected_field) const override {
+ auto kind = Kind::FromTag(unexpected_field->metadata());
+ switch (kind) {
+ case Kind::kNull:
+ return null();
+
+ case Kind::kBoolean:
+ return boolean();
+
+ case Kind::kNumber:
+ return int64();
+
+ case Kind::kString:
+ return timestamp(TimeUnit::SECOND);
+
+ case Kind::kArray: {
+ const auto& type = checked_cast<const ListType&>(*unexpected_field->type());
+ auto value_field = type.value_field();
+ return list(value_field->WithType(Infer(value_field)));
+ }
+ case Kind::kObject: {
+ auto fields = unexpected_field->type()->fields();
+ for (auto& field : fields) {
+ field = field->WithType(Infer(field));
+ }
+ return struct_(std::move(fields));
+ }
+ default:
+ return nullptr;
+ }
+ }
+
+ std::shared_ptr<DataType> Promote(
+ const std::shared_ptr<DataType>& failed,
+ const std::shared_ptr<Field>& unexpected_field) const override {
+ switch (failed->id()) {
+ case Type::NA:
+ return Infer(unexpected_field);
+
+ case Type::TIMESTAMP:
+ return utf8();
+
+ case Type::INT64:
+ return float64();
+
+ default:
+ return nullptr;
+ }
+ }
+ } impl;
+
+ return &impl;
+}
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/converter.h b/src/arrow/cpp/src/arrow/json/converter.h
new file mode 100644
index 000000000..9a812dd3c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/converter.h
@@ -0,0 +1,94 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/status.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Array;
+class DataType;
+class Field;
+class MemoryPool;
+
+namespace json {
+
+/// \brief interface for conversion of Arrays
+///
+/// Converters are not required to be correct for arbitrary input- only
+/// for unconverted arrays emitted by a corresponding parser.
+class ARROW_EXPORT Converter {
+ public:
+ virtual ~Converter() = default;
+
+ /// convert an array
+ /// on failure, this converter may be promoted to another converter which
+ /// *can* convert the given input.
+ virtual Status Convert(const std::shared_ptr<Array>& in,
+ std::shared_ptr<Array>* out) = 0;
+
+ std::shared_ptr<DataType> out_type() const { return out_type_; }
+
+ MemoryPool* pool() { return pool_; }
+
+ protected:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Converter);
+
+ Converter(MemoryPool* pool, const std::shared_ptr<DataType>& out_type)
+ : pool_(pool), out_type_(out_type) {}
+
+ MemoryPool* pool_;
+ std::shared_ptr<DataType> out_type_;
+};
+
+/// \brief produce a single converter to the specified out_type
+ARROW_EXPORT Status MakeConverter(const std::shared_ptr<DataType>& out_type,
+ MemoryPool* pool, std::shared_ptr<Converter>* out);
+
+class ARROW_EXPORT PromotionGraph {
+ public:
+ virtual ~PromotionGraph() = default;
+
+ /// \brief produce a valid field which will be inferred as null
+ virtual std::shared_ptr<Field> Null(const std::string& name) const = 0;
+
+ /// \brief given an unexpected field encountered during parsing, return a type to which
+ /// it may be convertible (may return null if none is available)
+ virtual std::shared_ptr<DataType> Infer(
+ const std::shared_ptr<Field>& unexpected_field) const = 0;
+
+ /// \brief given a type to which conversion failed, return a promoted type to which
+ /// conversion may succeed (may return null if none is available)
+ virtual std::shared_ptr<DataType> Promote(
+ const std::shared_ptr<DataType>& failed,
+ const std::shared_ptr<Field>& unexpected_field) const = 0;
+
+ protected:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(PromotionGraph);
+ PromotionGraph() = default;
+};
+
+ARROW_EXPORT const PromotionGraph* GetPromotionGraph();
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/converter_test.cc b/src/arrow/cpp/src/arrow/json/converter_test.cc
new file mode 100644
index 000000000..030f2a7bc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/converter_test.cc
@@ -0,0 +1,214 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/json/converter.h"
+
+#include <gtest/gtest.h>
+
+#include <string>
+
+#include "arrow/json/options.h"
+#include "arrow/json/test_common.h"
+
+namespace arrow {
+namespace json {
+
+Result<std::shared_ptr<Array>> Convert(std::shared_ptr<DataType> type,
+ std::shared_ptr<Array> unconverted) {
+ std::shared_ptr<Array> converted;
+ // convert the array
+ std::shared_ptr<Converter> converter;
+ RETURN_NOT_OK(MakeConverter(type, default_memory_pool(), &converter));
+ RETURN_NOT_OK(converter->Convert(unconverted, &converted));
+ RETURN_NOT_OK(converted->ValidateFull());
+ return converted;
+}
+
+// bool, null are trivial pass throughs
+
+TEST(ConverterTest, Integers) {
+ for (auto int_type : {int8(), int16(), int32(), int64()}) {
+ ParseOptions options;
+ options.explicit_schema = schema({field("", int_type)});
+
+ std::string json_source = R"(
+ {"" : -0}
+ {"" : null}
+ {"" : -1}
+ {"" : 32}
+ {"" : -45}
+ {"" : 12}
+ {"" : -64}
+ {"" : 124}
+ )";
+
+ std::shared_ptr<StructArray> parse_array;
+ ASSERT_OK(ParseFromString(options, json_source, &parse_array));
+
+ // call to convert
+ ASSERT_OK_AND_ASSIGN(auto converted,
+ Convert(int_type, parse_array->GetFieldByName("")));
+
+ // assert equality
+ auto expected = ArrayFromJSON(int_type, R"([
+ -0, null, -1, 32, -45, 12, -64, 124])");
+
+ AssertArraysEqual(*expected, *converted);
+ }
+}
+
+TEST(ConverterTest, UnsignedIntegers) {
+ for (auto uint_type : {uint8(), uint16(), uint32(), uint64()}) {
+ ParseOptions options;
+ options.explicit_schema = schema({field("", uint_type)});
+
+ std::string json_source = R"(
+ {"" : 0}
+ {"" : null}
+ {"" : 1}
+ {"" : 32}
+ {"" : 45}
+ {"" : 12}
+ {"" : 64}
+ {"" : 124}
+ )";
+
+ std::shared_ptr<StructArray> parse_array;
+ ASSERT_OK(ParseFromString(options, json_source, &parse_array));
+
+ // call to convert
+ ASSERT_OK_AND_ASSIGN(auto converted,
+ Convert(uint_type, parse_array->GetFieldByName("")));
+
+ // assert equality
+ auto expected = ArrayFromJSON(uint_type, R"([
+ 0, null, 1, 32, 45, 12, 64, 124])");
+
+ AssertArraysEqual(*expected, *converted);
+ }
+}
+
+TEST(ConverterTest, Floats) {
+ for (auto float_type : {float32(), float64()}) {
+ ParseOptions options;
+ options.explicit_schema = schema({field("", float_type)});
+
+ std::string json_source = R"(
+ {"" : 0}
+ {"" : -0.0}
+ {"" : null}
+ {"" : 32.0}
+ {"" : 1e5}
+ )";
+
+ std::shared_ptr<StructArray> parse_array;
+ ASSERT_OK(ParseFromString(options, json_source, &parse_array));
+
+ // call to convert
+ ASSERT_OK_AND_ASSIGN(auto converted,
+ Convert(float_type, parse_array->GetFieldByName("")));
+
+ // assert equality
+ auto expected = ArrayFromJSON(float_type, R"([
+ 0, -0.0, null, 32.0, 1e5])");
+
+ AssertArraysEqual(*expected, *converted);
+ }
+}
+
+TEST(ConverterTest, StringAndLargeString) {
+ for (auto string_type : {utf8(), large_utf8()}) {
+ ParseOptions options;
+ options.explicit_schema = schema({field("", string_type)});
+
+ std::string json_source = R"(
+ {"" : "a"}
+ {"" : "b c"}
+ {"" : null}
+ {"" : "d e f"}
+ {"" : "g"}
+ )";
+
+ std::shared_ptr<StructArray> parse_array;
+ ASSERT_OK(ParseFromString(options, json_source, &parse_array));
+
+ // call to convert
+ ASSERT_OK_AND_ASSIGN(auto converted,
+ Convert(string_type, parse_array->GetFieldByName("")));
+
+ // assert equality
+ auto expected = ArrayFromJSON(string_type, R"([
+ "a", "b c", null, "d e f", "g"])");
+
+ AssertArraysEqual(*expected, *converted);
+ }
+}
+
+TEST(ConverterTest, Timestamp) {
+ auto timestamp_type = timestamp(TimeUnit::SECOND);
+
+ ParseOptions options;
+ options.explicit_schema = schema({field("", timestamp_type)});
+
+ std::string json_source = R"(
+ {"" : null}
+ {"" : "1970-01-01"}
+ {"" : "2018-11-13 17:11:10"}
+ )";
+
+ std::shared_ptr<StructArray> parse_array;
+ ASSERT_OK(ParseFromString(options, json_source, &parse_array));
+
+ // call to convert
+ ASSERT_OK_AND_ASSIGN(auto converted,
+ Convert(timestamp_type, parse_array->GetFieldByName("")));
+
+ // assert equality
+ auto expected = ArrayFromJSON(timestamp_type, R"([
+ null, "1970-01-01", "2018-11-13 17:11:10"])");
+
+ AssertArraysEqual(*expected, *converted);
+}
+
+TEST(ConverterTest, Decimal128And256) {
+ for (auto decimal_type : {decimal128(38, 10), decimal256(38, 10)}) {
+ ParseOptions options;
+ options.explicit_schema = schema({field("", decimal_type)});
+
+ std::string json_source = R"(
+ {"" : "02.0000000000"}
+ {"" : "30.0000000000"}
+ )";
+
+ std::shared_ptr<StructArray> parse_array;
+ ASSERT_OK(ParseFromString(options, json_source, &parse_array));
+
+ // call to convert
+ ASSERT_OK_AND_ASSIGN(auto converted,
+ Convert(decimal_type, parse_array->GetFieldByName("")));
+
+ // assert equality
+ auto expected = ArrayFromJSON(decimal_type, R"([
+ "02.0000000000",
+ "30.0000000000"])");
+
+ AssertArraysEqual(*expected, *converted);
+ }
+}
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/object_parser.cc b/src/arrow/cpp/src/arrow/json/object_parser.cc
new file mode 100644
index 000000000..c857cd537
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/object_parser.cc
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/json/object_parser.h"
+#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
+
+#include <rapidjson/document.h>
+
+namespace arrow {
+namespace json {
+namespace internal {
+
+namespace rj = arrow::rapidjson;
+
+class ObjectParser::Impl {
+ public:
+ Status Parse(arrow::util::string_view json) {
+ document_.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()),
+ static_cast<size_t>(json.size()));
+
+ if (document_.HasParseError()) {
+ return Status::Invalid("Json parse error (offset ", document_.GetErrorOffset(),
+ "): ", document_.GetParseError());
+ }
+ if (!document_.IsObject()) {
+ return Status::TypeError("Not a json object");
+ }
+ return Status::OK();
+ }
+
+ Result<std::string> GetString(const char* key) const {
+ if (!document_.HasMember(key)) {
+ return Status::KeyError("Key '", key, "' does not exist");
+ }
+ if (!document_[key].IsString()) {
+ return Status::TypeError("Key '", key, "' is not a string");
+ }
+ return document_[key].GetString();
+ }
+
+ Result<bool> GetBool(const char* key) const {
+ if (!document_.HasMember(key)) {
+ return Status::KeyError("Key '", key, "' does not exist");
+ }
+ if (!document_[key].IsBool()) {
+ return Status::TypeError("Key '", key, "' is not a boolean");
+ }
+ return document_[key].GetBool();
+ }
+
+ private:
+ rj::Document document_;
+};
+
+ObjectParser::ObjectParser() : impl_(new ObjectParser::Impl()) {}
+
+ObjectParser::~ObjectParser() = default;
+
+Status ObjectParser::Parse(arrow::util::string_view json) { return impl_->Parse(json); }
+
+Result<std::string> ObjectParser::GetString(const char* key) const {
+ return impl_->GetString(key);
+}
+
+Result<bool> ObjectParser::GetBool(const char* key) const { return impl_->GetBool(key); }
+
+} // namespace internal
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/object_parser.h b/src/arrow/cpp/src/arrow/json/object_parser.h
new file mode 100644
index 000000000..ef9320165
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/object_parser.h
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/result.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace json {
+namespace internal {
+
+/// This class is a helper to parse a json object from a string.
+/// It uses rapidjson::Document in implementation.
+class ARROW_EXPORT ObjectParser {
+ public:
+ ObjectParser();
+ ~ObjectParser();
+
+ Status Parse(arrow::util::string_view json);
+
+ Result<std::string> GetString(const char* key) const;
+ Result<bool> GetBool(const char* key) const;
+
+ private:
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+} // namespace internal
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/object_writer.cc b/src/arrow/cpp/src/arrow/json/object_writer.cc
new file mode 100644
index 000000000..06d09f81e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/object_writer.cc
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/json/object_writer.h"
+#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
+
+#include <rapidjson/document.h>
+#include <rapidjson/stringbuffer.h>
+#include <rapidjson/writer.h>
+
+namespace rj = arrow::rapidjson;
+
+namespace arrow {
+namespace json {
+namespace internal {
+
+class ObjectWriter::Impl {
+ public:
+ Impl() : root_(rj::kObjectType) {}
+
+ void SetString(arrow::util::string_view key, arrow::util::string_view value) {
+ rj::Document::AllocatorType& allocator = document_.GetAllocator();
+
+ rj::Value str_key(key.data(), allocator);
+ rj::Value str_value(value.data(), allocator);
+
+ root_.AddMember(str_key, str_value, allocator);
+ }
+
+ void SetBool(arrow::util::string_view key, bool value) {
+ rj::Document::AllocatorType& allocator = document_.GetAllocator();
+
+ rj::Value str_key(key.data(), allocator);
+
+ root_.AddMember(str_key, value, allocator);
+ }
+
+ std::string Serialize() {
+ rj::StringBuffer buffer;
+ rj::Writer<rj::StringBuffer> writer(buffer);
+ root_.Accept(writer);
+
+ return buffer.GetString();
+ }
+
+ private:
+ rj::Document document_;
+ rj::Value root_;
+};
+
+ObjectWriter::ObjectWriter() : impl_(new ObjectWriter::Impl()) {}
+
+ObjectWriter::~ObjectWriter() = default;
+
+void ObjectWriter::SetString(arrow::util::string_view key,
+ arrow::util::string_view value) {
+ impl_->SetString(key, value);
+}
+
+void ObjectWriter::SetBool(arrow::util::string_view key, bool value) {
+ impl_->SetBool(key, value);
+}
+
+std::string ObjectWriter::Serialize() { return impl_->Serialize(); }
+
+} // namespace internal
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/object_writer.h b/src/arrow/cpp/src/arrow/json/object_writer.h
new file mode 100644
index 000000000..55ff0ce52
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/object_writer.h
@@ -0,0 +1,48 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/util/string_view.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace json {
+namespace internal {
+
+/// This class is a helper to serialize a json object to a string.
+/// It uses rapidjson in implementation.
+class ARROW_EXPORT ObjectWriter {
+ public:
+ ObjectWriter();
+ ~ObjectWriter();
+
+ void SetString(arrow::util::string_view key, arrow::util::string_view value);
+ void SetBool(arrow::util::string_view key, bool value);
+
+ std::string Serialize();
+
+ private:
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+} // namespace internal
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/options.cc b/src/arrow/cpp/src/arrow/json/options.cc
new file mode 100644
index 000000000..dc5e628b1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/options.cc
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/json/options.h"
+
+namespace arrow {
+namespace json {
+
+ParseOptions ParseOptions::Defaults() { return ParseOptions(); }
+
+ReadOptions ReadOptions::Defaults() { return ReadOptions(); }
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/options.h b/src/arrow/cpp/src/arrow/json/options.h
new file mode 100644
index 000000000..d7edab9ce
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/options.h
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/json/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class DataType;
+class Schema;
+
+namespace json {
+
+enum class UnexpectedFieldBehavior : char {
+ /// Unexpected JSON fields are ignored
+ Ignore,
+ /// Unexpected JSON fields error out
+ Error,
+ /// Unexpected JSON fields are type-inferred and included in the output
+ InferType
+};
+
+struct ARROW_EXPORT ParseOptions {
+ // Parsing options
+
+ /// Optional explicit schema (disables type inference on those fields)
+ std::shared_ptr<Schema> explicit_schema;
+
+ /// Whether objects may be printed across multiple lines (for example pretty-printed)
+ ///
+ /// If true, parsing may be slower.
+ bool newlines_in_values = false;
+
+ /// How JSON fields outside of explicit_schema (if given) are treated
+ UnexpectedFieldBehavior unexpected_field_behavior = UnexpectedFieldBehavior::InferType;
+
+ /// Create parsing options with default values
+ static ParseOptions Defaults();
+};
+
+struct ARROW_EXPORT ReadOptions {
+ // Reader options
+
+ /// Whether to use the global CPU thread pool
+ bool use_threads = true;
+ /// Block size we request from the IO layer; also determines the size of
+ /// chunks when use_threads is true
+ int32_t block_size = 1 << 20; // 1 MB
+
+ /// Create read options with default values
+ static ReadOptions Defaults();
+};
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/parser.cc b/src/arrow/cpp/src/arrow/json/parser.cc
new file mode 100644
index 000000000..16a2fa1ce
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/parser.cc
@@ -0,0 +1,1107 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/json/parser.h"
+
+#include <functional>
+#include <limits>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "arrow/json/rapidjson_defs.h"
+#include "rapidjson/error/en.h"
+#include "rapidjson/reader.h"
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/type.h"
+#include "arrow/util/bitset_stack.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/trie.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::BitsetStack;
+using internal::checked_cast;
+using internal::make_unique;
+using util::string_view;
+
+namespace json {
+
+namespace rj = arrow::rapidjson;
+
+template <typename... T>
+static Status ParseError(T&&... t) {
+ return Status::Invalid("JSON parse error: ", std::forward<T>(t)...);
+}
+
+const std::string& Kind::Name(Kind::type kind) {
+ static const std::string names[] = {"null", "boolean", "number",
+ "string", "array", "object"};
+
+ return names[kind];
+}
+
+const std::shared_ptr<const KeyValueMetadata>& Kind::Tag(Kind::type kind) {
+ static const std::shared_ptr<const KeyValueMetadata> tags[] = {
+ key_value_metadata({{"json_kind", Kind::Name(Kind::kNull)}}),
+ key_value_metadata({{"json_kind", Kind::Name(Kind::kBoolean)}}),
+ key_value_metadata({{"json_kind", Kind::Name(Kind::kNumber)}}),
+ key_value_metadata({{"json_kind", Kind::Name(Kind::kString)}}),
+ key_value_metadata({{"json_kind", Kind::Name(Kind::kArray)}}),
+ key_value_metadata({{"json_kind", Kind::Name(Kind::kObject)}}),
+ };
+ return tags[kind];
+}
+
+static arrow::internal::Trie MakeFromTagTrie() {
+ arrow::internal::TrieBuilder builder;
+ for (auto kind : {Kind::kNull, Kind::kBoolean, Kind::kNumber, Kind::kString,
+ Kind::kArray, Kind::kObject}) {
+ DCHECK_OK(builder.Append(Kind::Name(kind)));
+ }
+ auto name_to_kind = builder.Finish();
+ DCHECK_OK(name_to_kind.Validate());
+ return name_to_kind;
+}
+
+Kind::type Kind::FromTag(const std::shared_ptr<const KeyValueMetadata>& tag) {
+ static arrow::internal::Trie name_to_kind = MakeFromTagTrie();
+ DCHECK_NE(tag->FindKey("json_kind"), -1);
+ util::string_view name = tag->value(tag->FindKey("json_kind"));
+ DCHECK_NE(name_to_kind.Find(name), -1);
+ return static_cast<Kind::type>(name_to_kind.Find(name));
+}
+
+Status Kind::ForType(const DataType& type, Kind::type* kind) {
+ struct {
+ Status Visit(const NullType&) { return SetKind(Kind::kNull); }
+ Status Visit(const BooleanType&) { return SetKind(Kind::kBoolean); }
+ Status Visit(const NumberType&) { return SetKind(Kind::kNumber); }
+ Status Visit(const TimeType&) { return SetKind(Kind::kNumber); }
+ Status Visit(const DateType&) { return SetKind(Kind::kNumber); }
+ Status Visit(const BinaryType&) { return SetKind(Kind::kString); }
+ Status Visit(const LargeBinaryType&) { return SetKind(Kind::kString); }
+ Status Visit(const TimestampType&) { return SetKind(Kind::kString); }
+ Status Visit(const FixedSizeBinaryType&) { return SetKind(Kind::kString); }
+ Status Visit(const DictionaryType& dict_type) {
+ return Kind::ForType(*dict_type.value_type(), kind_);
+ }
+ Status Visit(const ListType&) { return SetKind(Kind::kArray); }
+ Status Visit(const StructType&) { return SetKind(Kind::kObject); }
+ Status Visit(const DataType& not_impl) {
+ return Status::NotImplemented("JSON parsing of ", not_impl);
+ }
+ Status SetKind(Kind::type kind) {
+ *kind_ = kind;
+ return Status::OK();
+ }
+ Kind::type* kind_;
+ } visitor = {kind};
+ return VisitTypeInline(type, &visitor);
+}
+
+/// \brief ArrayBuilder for parsed but unconverted arrays
+template <Kind::type>
+class RawArrayBuilder;
+
+/// \brief packed pointer to a RawArrayBuilder
+///
+/// RawArrayBuilders are stored in HandlerBase,
+/// which allows storage of their indices (uint32_t) instead of a full pointer.
+/// BuilderPtr is also tagged with the json kind and nullable properties
+/// so those can be accessed before dereferencing the builder.
+struct BuilderPtr {
+ BuilderPtr() : BuilderPtr(BuilderPtr::null) {}
+ BuilderPtr(Kind::type k, uint32_t i, bool n) : index(i), kind(k), nullable(n) {}
+
+ BuilderPtr(const BuilderPtr&) = default;
+ BuilderPtr& operator=(const BuilderPtr&) = default;
+ BuilderPtr(BuilderPtr&&) = default;
+ BuilderPtr& operator=(BuilderPtr&&) = default;
+
+ // index of builder in its arena
+ // OR the length of that builder if kind == Kind::kNull
+ // (we don't allocate an arena for nulls since they're trivial)
+ uint32_t index;
+ Kind::type kind;
+ bool nullable;
+
+ bool operator==(BuilderPtr other) const {
+ return kind == other.kind && index == other.index;
+ }
+
+ bool operator!=(BuilderPtr other) const { return !(other == *this); }
+
+ operator bool() const { return *this != null; }
+
+ bool operator!() const { return *this == null; }
+
+ // The static BuilderPtr for null type data
+ static const BuilderPtr null;
+};
+
+const BuilderPtr BuilderPtr::null(Kind::kNull, 0, true);
+
+template <>
+class RawArrayBuilder<Kind::kBoolean> {
+ public:
+ explicit RawArrayBuilder(MemoryPool* pool)
+ : data_builder_(pool), null_bitmap_builder_(pool) {}
+
+ Status Append(bool value) {
+ RETURN_NOT_OK(data_builder_.Append(value));
+ return null_bitmap_builder_.Append(true);
+ }
+
+ Status AppendNull() {
+ RETURN_NOT_OK(data_builder_.Append(false));
+ return null_bitmap_builder_.Append(false);
+ }
+
+ Status AppendNull(int64_t count) {
+ RETURN_NOT_OK(data_builder_.Append(count, false));
+ return null_bitmap_builder_.Append(count, false);
+ }
+
+ Status Finish(std::shared_ptr<Array>* out) {
+ auto size = length();
+ auto null_count = null_bitmap_builder_.false_count();
+ std::shared_ptr<Buffer> data, null_bitmap;
+ RETURN_NOT_OK(data_builder_.Finish(&data));
+ RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
+ *out = MakeArray(ArrayData::Make(boolean(), size, {null_bitmap, data}, null_count));
+ return Status::OK();
+ }
+
+ int64_t length() { return null_bitmap_builder_.length(); }
+
+ private:
+ TypedBufferBuilder<bool> data_builder_;
+ TypedBufferBuilder<bool> null_bitmap_builder_;
+};
+
+/// \brief builder for strings or unconverted numbers
+///
+/// Both of these are represented in the builder as an index only;
+/// the actual characters are stored in a single StringArray (into which
+/// an index refers). This means building is faster since we don't do
+/// allocation for string/number characters but accessing is strided.
+///
+/// On completion the indices and the character storage are combined
+/// into a dictionary-encoded array, which is a convenient container
+/// for indices referring into another array.
+class ScalarBuilder {
+ public:
+ explicit ScalarBuilder(MemoryPool* pool)
+ : values_length_(0), data_builder_(pool), null_bitmap_builder_(pool) {}
+
+ Status Append(int32_t index, int32_t value_length) {
+ RETURN_NOT_OK(data_builder_.Append(index));
+ values_length_ += value_length;
+ return null_bitmap_builder_.Append(true);
+ }
+
+ Status AppendNull() {
+ RETURN_NOT_OK(data_builder_.Append(0));
+ return null_bitmap_builder_.Append(false);
+ }
+
+ Status AppendNull(int64_t count) {
+ RETURN_NOT_OK(data_builder_.Append(count, 0));
+ return null_bitmap_builder_.Append(count, false);
+ }
+
+ Status Finish(std::shared_ptr<Array>* out) {
+ auto size = length();
+ auto null_count = null_bitmap_builder_.false_count();
+ std::shared_ptr<Buffer> data, null_bitmap;
+ RETURN_NOT_OK(data_builder_.Finish(&data));
+ RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
+ *out = MakeArray(ArrayData::Make(int32(), size, {null_bitmap, data}, null_count));
+ return Status::OK();
+ }
+
+ int64_t length() { return null_bitmap_builder_.length(); }
+
+ int32_t values_length() { return values_length_; }
+
+ private:
+ int32_t values_length_;
+ TypedBufferBuilder<int32_t> data_builder_;
+ TypedBufferBuilder<bool> null_bitmap_builder_;
+};
+
+template <>
+class RawArrayBuilder<Kind::kNumber> : public ScalarBuilder {
+ public:
+ using ScalarBuilder::ScalarBuilder;
+};
+
+template <>
+class RawArrayBuilder<Kind::kString> : public ScalarBuilder {
+ public:
+ using ScalarBuilder::ScalarBuilder;
+};
+
+template <>
+class RawArrayBuilder<Kind::kArray> {
+ public:
+ explicit RawArrayBuilder(MemoryPool* pool)
+ : offset_builder_(pool), null_bitmap_builder_(pool) {}
+
+ Status Append(int32_t child_length) {
+ RETURN_NOT_OK(offset_builder_.Append(offset_));
+ offset_ += child_length;
+ return null_bitmap_builder_.Append(true);
+ }
+
+ Status AppendNull() {
+ RETURN_NOT_OK(offset_builder_.Append(offset_));
+ return null_bitmap_builder_.Append(false);
+ }
+
+ Status AppendNull(int64_t count) {
+ RETURN_NOT_OK(offset_builder_.Append(count, offset_));
+ return null_bitmap_builder_.Append(count, false);
+ }
+
+ Status Finish(std::function<Status(BuilderPtr, std::shared_ptr<Array>*)> finish_child,
+ std::shared_ptr<Array>* out) {
+ RETURN_NOT_OK(offset_builder_.Append(offset_));
+ auto size = length();
+ auto null_count = null_bitmap_builder_.false_count();
+ std::shared_ptr<Buffer> offsets, null_bitmap;
+ RETURN_NOT_OK(offset_builder_.Finish(&offsets));
+ RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
+ std::shared_ptr<Array> values;
+ RETURN_NOT_OK(finish_child(value_builder_, &values));
+ auto type = list(field("item", values->type(), value_builder_.nullable,
+ Kind::Tag(value_builder_.kind)));
+ *out = MakeArray(ArrayData::Make(type, size, {null_bitmap, offsets}, {values->data()},
+ null_count));
+ return Status::OK();
+ }
+
+ BuilderPtr value_builder() const { return value_builder_; }
+
+ void value_builder(BuilderPtr builder) { value_builder_ = builder; }
+
+ int64_t length() { return null_bitmap_builder_.length(); }
+
+ private:
+ BuilderPtr value_builder_ = BuilderPtr::null;
+ int32_t offset_ = 0;
+ TypedBufferBuilder<int32_t> offset_builder_;
+ TypedBufferBuilder<bool> null_bitmap_builder_;
+};
+
+template <>
+class RawArrayBuilder<Kind::kObject> {
+ public:
+ explicit RawArrayBuilder(MemoryPool* pool) : null_bitmap_builder_(pool) {}
+
+ Status Append() { return null_bitmap_builder_.Append(true); }
+
+ Status AppendNull() { return null_bitmap_builder_.Append(false); }
+
+ Status AppendNull(int64_t count) { return null_bitmap_builder_.Append(count, false); }
+
+ std::string FieldName(int i) const {
+ for (const auto& name_index : name_to_index_) {
+ if (name_index.second == i) {
+ return name_index.first;
+ }
+ }
+ return "";
+ }
+
+ int GetFieldIndex(const std::string& name) const {
+ auto it = name_to_index_.find(name);
+ if (it == name_to_index_.end()) {
+ return -1;
+ }
+ return it->second;
+ }
+
+ int AddField(std::string name, BuilderPtr builder) {
+ auto index = num_fields();
+ field_builders_.push_back(builder);
+ name_to_index_.emplace(std::move(name), index);
+ return index;
+ }
+
+ int num_fields() const { return static_cast<int>(field_builders_.size()); }
+
+ BuilderPtr field_builder(int index) const { return field_builders_[index]; }
+
+ void field_builder(int index, BuilderPtr builder) { field_builders_[index] = builder; }
+
+ Status Finish(std::function<Status(BuilderPtr, std::shared_ptr<Array>*)> finish_child,
+ std::shared_ptr<Array>* out) {
+ auto size = length();
+ auto null_count = null_bitmap_builder_.false_count();
+ std::shared_ptr<Buffer> null_bitmap;
+ RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
+
+ std::vector<string_view> field_names(num_fields());
+ for (const auto& name_index : name_to_index_) {
+ field_names[name_index.second] = name_index.first;
+ }
+
+ std::vector<std::shared_ptr<Field>> fields(num_fields());
+ std::vector<std::shared_ptr<ArrayData>> child_data(num_fields());
+ for (int i = 0; i < num_fields(); ++i) {
+ std::shared_ptr<Array> field_values;
+ RETURN_NOT_OK(finish_child(field_builders_[i], &field_values));
+ child_data[i] = field_values->data();
+ fields[i] = field(std::string(field_names[i]), field_values->type(),
+ field_builders_[i].nullable, Kind::Tag(field_builders_[i].kind));
+ }
+
+ *out = MakeArray(ArrayData::Make(struct_(std::move(fields)), size, {null_bitmap},
+ std::move(child_data), null_count));
+ return Status::OK();
+ }
+
+ int64_t length() { return null_bitmap_builder_.length(); }
+
+ private:
+ std::vector<BuilderPtr> field_builders_;
+ std::unordered_map<std::string, int> name_to_index_;
+ TypedBufferBuilder<bool> null_bitmap_builder_;
+};
+
+class RawBuilderSet {
+ public:
+ explicit RawBuilderSet(MemoryPool* pool) : pool_(pool) {}
+
+ /// Retrieve a pointer to a builder from a BuilderPtr
+ template <Kind::type kind>
+ enable_if_t<kind != Kind::kNull, RawArrayBuilder<kind>*> Cast(BuilderPtr builder) {
+ DCHECK_EQ(builder.kind, kind);
+ return arena<kind>().data() + builder.index;
+ }
+
+ /// construct a builder of statically defined kind
+ template <Kind::type kind>
+ Status MakeBuilder(int64_t leading_nulls, BuilderPtr* builder) {
+ builder->index = static_cast<uint32_t>(arena<kind>().size());
+ builder->kind = kind;
+ builder->nullable = true;
+ arena<kind>().emplace_back(RawArrayBuilder<kind>(pool_));
+ return Cast<kind>(*builder)->AppendNull(leading_nulls);
+ }
+
+ /// construct a builder of whatever kind corresponds to a DataType
+ Status MakeBuilder(const DataType& t, int64_t leading_nulls, BuilderPtr* builder) {
+ Kind::type kind;
+ RETURN_NOT_OK(Kind::ForType(t, &kind));
+ switch (kind) {
+ case Kind::kNull:
+ *builder = BuilderPtr(Kind::kNull, static_cast<uint32_t>(leading_nulls), true);
+ return Status::OK();
+
+ case Kind::kBoolean:
+ return MakeBuilder<Kind::kBoolean>(leading_nulls, builder);
+
+ case Kind::kNumber:
+ return MakeBuilder<Kind::kNumber>(leading_nulls, builder);
+
+ case Kind::kString:
+ return MakeBuilder<Kind::kString>(leading_nulls, builder);
+
+ case Kind::kArray: {
+ RETURN_NOT_OK(MakeBuilder<Kind::kArray>(leading_nulls, builder));
+ const auto& list_type = checked_cast<const ListType&>(t);
+
+ BuilderPtr value_builder;
+ RETURN_NOT_OK(MakeBuilder(*list_type.value_type(), 0, &value_builder));
+ value_builder.nullable = list_type.value_field()->nullable();
+
+ Cast<Kind::kArray>(*builder)->value_builder(value_builder);
+ return Status::OK();
+ }
+ case Kind::kObject: {
+ RETURN_NOT_OK(MakeBuilder<Kind::kObject>(leading_nulls, builder));
+ const auto& struct_type = checked_cast<const StructType&>(t);
+
+ for (const auto& f : struct_type.fields()) {
+ BuilderPtr field_builder;
+ RETURN_NOT_OK(MakeBuilder(*f->type(), leading_nulls, &field_builder));
+ field_builder.nullable = f->nullable();
+
+ Cast<Kind::kObject>(*builder)->AddField(f->name(), field_builder);
+ }
+ return Status::OK();
+ }
+ default:
+ return Status::NotImplemented("invalid builder type");
+ }
+ }
+
+ /// Appending null is slightly tricky since null count is stored inline
+ /// for builders of Kind::kNull. Append nulls using this helper
+ Status AppendNull(BuilderPtr parent, int field_index, BuilderPtr builder) {
+ if (ARROW_PREDICT_FALSE(!builder.nullable)) {
+ return ParseError("a required field was null");
+ }
+ switch (builder.kind) {
+ case Kind::kNull: {
+ DCHECK_EQ(builder, parent.kind == Kind::kArray
+ ? Cast<Kind::kArray>(parent)->value_builder()
+ : Cast<Kind::kObject>(parent)->field_builder(field_index));
+
+ // increment null count stored inline
+ builder.index += 1;
+
+ // update the parent, since changing builder doesn't affect parent
+ if (parent.kind == Kind::kArray) {
+ Cast<Kind::kArray>(parent)->value_builder(builder);
+ } else {
+ Cast<Kind::kObject>(parent)->field_builder(field_index, builder);
+ }
+ return Status::OK();
+ }
+ case Kind::kBoolean:
+ return Cast<Kind::kBoolean>(builder)->AppendNull();
+
+ case Kind::kNumber:
+ return Cast<Kind::kNumber>(builder)->AppendNull();
+
+ case Kind::kString:
+ return Cast<Kind::kString>(builder)->AppendNull();
+
+ case Kind::kArray:
+ return Cast<Kind::kArray>(builder)->AppendNull();
+
+ case Kind::kObject: {
+ auto struct_builder = Cast<Kind::kObject>(builder);
+ RETURN_NOT_OK(struct_builder->AppendNull());
+
+ for (int i = 0; i < struct_builder->num_fields(); ++i) {
+ auto field_builder = struct_builder->field_builder(i);
+ RETURN_NOT_OK(AppendNull(builder, i, field_builder));
+ }
+ return Status::OK();
+ }
+ default:
+ return Status::NotImplemented("invalid builder Kind");
+ }
+ }
+
+ Status Finish(const std::shared_ptr<Array>& scalar_values, BuilderPtr builder,
+ std::shared_ptr<Array>* out) {
+ auto finish_children = [this, &scalar_values](BuilderPtr child,
+ std::shared_ptr<Array>* out) {
+ return Finish(scalar_values, child, out);
+ };
+ switch (builder.kind) {
+ case Kind::kNull: {
+ auto length = static_cast<int64_t>(builder.index);
+ *out = std::make_shared<NullArray>(length);
+ return Status::OK();
+ }
+ case Kind::kBoolean:
+ return Cast<Kind::kBoolean>(builder)->Finish(out);
+
+ case Kind::kNumber:
+ return FinishScalar(scalar_values, Cast<Kind::kNumber>(builder), out);
+
+ case Kind::kString:
+ return FinishScalar(scalar_values, Cast<Kind::kString>(builder), out);
+
+ case Kind::kArray:
+ return Cast<Kind::kArray>(builder)->Finish(std::move(finish_children), out);
+
+ case Kind::kObject:
+ return Cast<Kind::kObject>(builder)->Finish(std::move(finish_children), out);
+
+ default:
+ return Status::NotImplemented("invalid builder kind");
+ }
+ }
+
+ private:
+ /// finish a column of scalar values (string or number)
+ Status FinishScalar(const std::shared_ptr<Array>& scalar_values, ScalarBuilder* builder,
+ std::shared_ptr<Array>* out) {
+ std::shared_ptr<Array> indices;
+ // TODO(bkietz) embed builder->values_length() in this output somehow
+ RETURN_NOT_OK(builder->Finish(&indices));
+ auto ty = dictionary(int32(), scalar_values->type());
+ *out = std::make_shared<DictionaryArray>(ty, indices, scalar_values);
+ return Status::OK();
+ }
+
+ template <Kind::type kind>
+ std::vector<RawArrayBuilder<kind>>& arena() {
+ return std::get<static_cast<std::size_t>(kind)>(arenas_);
+ }
+
+ MemoryPool* pool_;
+ std::tuple<std::tuple<>, std::vector<RawArrayBuilder<Kind::kBoolean>>,
+ std::vector<RawArrayBuilder<Kind::kNumber>>,
+ std::vector<RawArrayBuilder<Kind::kString>>,
+ std::vector<RawArrayBuilder<Kind::kArray>>,
+ std::vector<RawArrayBuilder<Kind::kObject>>>
+ arenas_;
+};
+
+/// Three implementations are provided for BlockParser, one for each
+/// UnexpectedFieldBehavior. However most of the logic is identical in each
+/// case, so the majority of the implementation is in this base class
+class HandlerBase : public BlockParser,
+ public rj::BaseReaderHandler<rj::UTF8<>, HandlerBase> {
+ public:
+ explicit HandlerBase(MemoryPool* pool)
+ : BlockParser(pool),
+ builder_set_(pool),
+ field_index_(-1),
+ scalar_values_builder_(pool) {}
+
+ /// Retrieve a pointer to a builder from a BuilderPtr
+ template <Kind::type kind>
+ enable_if_t<kind != Kind::kNull, RawArrayBuilder<kind>*> Cast(BuilderPtr builder) {
+ return builder_set_.Cast<kind>(builder);
+ }
+
+ /// Accessor for a stored error Status
+ Status Error() { return status_; }
+
+ /// \defgroup rapidjson-handler-interface functions expected by rj::Reader
+ ///
+ /// bool Key(const char* data, rj::SizeType size, ...) is omitted since
+ /// the behavior varies greatly between UnexpectedFieldBehaviors
+ ///
+ /// @{
+ bool Null() {
+ status_ = builder_set_.AppendNull(builder_stack_.back(), field_index_, builder_);
+ return status_.ok();
+ }
+
+ bool Bool(bool value) {
+ constexpr auto kind = Kind::kBoolean;
+ if (ARROW_PREDICT_FALSE(builder_.kind != kind)) {
+ status_ = IllegallyChangedTo(kind);
+ return status_.ok();
+ }
+ status_ = Cast<kind>(builder_)->Append(value);
+ return status_.ok();
+ }
+
+ bool RawNumber(const char* data, rj::SizeType size, ...) {
+ status_ = AppendScalar<Kind::kNumber>(builder_, string_view(data, size));
+ return status_.ok();
+ }
+
+ bool String(const char* data, rj::SizeType size, ...) {
+ status_ = AppendScalar<Kind::kString>(builder_, string_view(data, size));
+ return status_.ok();
+ }
+
+ bool StartObject() {
+ status_ = StartObjectImpl();
+ return status_.ok();
+ }
+
+ bool EndObject(...) {
+ status_ = EndObjectImpl();
+ return status_.ok();
+ }
+
+ bool StartArray() {
+ status_ = StartArrayImpl();
+ return status_.ok();
+ }
+
+ bool EndArray(rj::SizeType size) {
+ status_ = EndArrayImpl(size);
+ return status_.ok();
+ }
+ /// @}
+
+ /// \brief Set up builders using an expected Schema
+ Status Initialize(const std::shared_ptr<Schema>& s) {
+ auto type = struct_({});
+ if (s) {
+ type = struct_(s->fields());
+ }
+ return builder_set_.MakeBuilder(*type, 0, &builder_);
+ }
+
+ Status Finish(std::shared_ptr<Array>* parsed) override {
+ std::shared_ptr<Array> scalar_values;
+ RETURN_NOT_OK(scalar_values_builder_.Finish(&scalar_values));
+ return builder_set_.Finish(scalar_values, builder_, parsed);
+ }
+
+ /// \brief Emit path of current field for debugging purposes
+ std::string Path() {
+ std::string path;
+ for (size_t i = 0; i < builder_stack_.size(); ++i) {
+ auto builder = builder_stack_[i];
+ if (builder.kind == Kind::kArray) {
+ path += "/[]";
+ } else {
+ auto struct_builder = Cast<Kind::kObject>(builder);
+ auto field_index = field_index_;
+ if (i + 1 < field_index_stack_.size()) {
+ field_index = field_index_stack_[i + 1];
+ }
+ path += "/" + struct_builder->FieldName(field_index);
+ }
+ }
+ return path;
+ }
+
+ protected:
+ template <typename Handler, typename Stream>
+ Status DoParse(Handler& handler, Stream&& json) {
+ constexpr auto parse_flags = rj::kParseIterativeFlag | rj::kParseNanAndInfFlag |
+ rj::kParseStopWhenDoneFlag |
+ rj::kParseNumbersAsStringsFlag;
+
+ rj::Reader reader;
+
+ for (; num_rows_ < kMaxParserNumRows; ++num_rows_) {
+ auto ok = reader.Parse<parse_flags>(json, handler);
+ switch (ok.Code()) {
+ case rj::kParseErrorNone:
+ // parse the next object
+ continue;
+ case rj::kParseErrorDocumentEmpty:
+ // parsed all objects, finish
+ return Status::OK();
+ case rj::kParseErrorTermination:
+ // handler emitted an error
+ return handler.Error();
+ default:
+ // rj emitted an error
+ return ParseError(rj::GetParseError_En(ok.Code()), " in row ", num_rows_);
+ }
+ }
+ return Status::Invalid("Exceeded maximum rows");
+ }
+
+ template <typename Handler>
+ Status DoParse(Handler& handler, const std::shared_ptr<Buffer>& json) {
+ RETURN_NOT_OK(ReserveScalarStorage(json->size()));
+ rj::MemoryStream ms(reinterpret_cast<const char*>(json->data()), json->size());
+ using InputStream = rj::EncodedInputStream<rj::UTF8<>, rj::MemoryStream>;
+ return DoParse(handler, InputStream(ms));
+ }
+
+ /// \defgroup handlerbase-append-methods append non-nested values
+ ///
+ /// @{
+
+ template <Kind::type kind>
+ Status AppendScalar(BuilderPtr builder, string_view scalar) {
+ if (ARROW_PREDICT_FALSE(builder.kind != kind)) {
+ return IllegallyChangedTo(kind);
+ }
+ auto index = static_cast<int32_t>(scalar_values_builder_.length());
+ auto value_length = static_cast<int32_t>(scalar.size());
+ RETURN_NOT_OK(Cast<kind>(builder)->Append(index, value_length));
+ RETURN_NOT_OK(scalar_values_builder_.Reserve(1));
+ scalar_values_builder_.UnsafeAppend(scalar);
+ return Status::OK();
+ }
+
+ /// @}
+
+ Status StartObjectImpl() {
+ constexpr auto kind = Kind::kObject;
+ if (ARROW_PREDICT_FALSE(builder_.kind != kind)) {
+ return IllegallyChangedTo(kind);
+ }
+ auto struct_builder = Cast<kind>(builder_);
+ absent_fields_stack_.Push(struct_builder->num_fields(), true);
+ StartNested();
+ return struct_builder->Append();
+ }
+
+ /// \brief helper for Key() functions
+ ///
+ /// sets the field builder with name key, or returns false if
+ /// there is no field with that name
+ bool SetFieldBuilder(string_view key, bool* duplicate_keys) {
+ auto parent = Cast<Kind::kObject>(builder_stack_.back());
+ field_index_ = parent->GetFieldIndex(std::string(key));
+ if (ARROW_PREDICT_FALSE(field_index_ == -1)) {
+ return false;
+ }
+ if (field_index_ < absent_fields_stack_.TopSize()) {
+ *duplicate_keys = !absent_fields_stack_[field_index_];
+ } else {
+ // When field_index is beyond the range of absent_fields_stack_ we have a duplicated
+ // field that wasn't declared in schema or previous records.
+ *duplicate_keys = true;
+ }
+ if (*duplicate_keys) {
+ status_ = ParseError("Column(", Path(), ") was specified twice in row ", num_rows_);
+ return false;
+ }
+ builder_ = parent->field_builder(field_index_);
+ absent_fields_stack_[field_index_] = false;
+ return true;
+ }
+
+ Status EndObjectImpl() {
+ auto parent = builder_stack_.back();
+
+ auto expected_count = absent_fields_stack_.TopSize();
+ for (int i = 0; i < expected_count; ++i) {
+ if (!absent_fields_stack_[i]) {
+ continue;
+ }
+ auto field_builder = Cast<Kind::kObject>(parent)->field_builder(i);
+ if (ARROW_PREDICT_FALSE(!field_builder.nullable)) {
+ return ParseError("a required field was absent");
+ }
+ RETURN_NOT_OK(builder_set_.AppendNull(parent, i, field_builder));
+ }
+ absent_fields_stack_.Pop();
+ EndNested();
+ return Status::OK();
+ }
+
+ Status StartArrayImpl() {
+ constexpr auto kind = Kind::kArray;
+ if (ARROW_PREDICT_FALSE(builder_.kind != kind)) {
+ return IllegallyChangedTo(kind);
+ }
+ StartNested();
+ // append to the list builder in EndArrayImpl
+ builder_ = Cast<kind>(builder_)->value_builder();
+ return Status::OK();
+ }
+
+ Status EndArrayImpl(rj::SizeType size) {
+ EndNested();
+ // append to list_builder here
+ auto list_builder = Cast<Kind::kArray>(builder_);
+ return list_builder->Append(size);
+ }
+
+ /// helper method for StartArray and StartObject
+ /// adds the current builder to a stack so its
+ /// children can be visited and parsed.
+ void StartNested() {
+ field_index_stack_.push_back(field_index_);
+ field_index_ = -1;
+ builder_stack_.push_back(builder_);
+ }
+
+ /// helper method for EndArray and EndObject
+ /// replaces the current builder with its parent
+ /// so parsing of the parent can continue
+ void EndNested() {
+ field_index_ = field_index_stack_.back();
+ field_index_stack_.pop_back();
+ builder_ = builder_stack_.back();
+ builder_stack_.pop_back();
+ }
+
+ Status IllegallyChangedTo(Kind::type illegally_changed_to) {
+ return ParseError("Column(", Path(), ") changed from ", Kind::Name(builder_.kind),
+ " to ", Kind::Name(illegally_changed_to), " in row ", num_rows_);
+ }
+
+ /// Reserve storage for scalars, these can occupy almost all of the JSON buffer
+ Status ReserveScalarStorage(int64_t size) override {
+ auto available_storage = scalar_values_builder_.value_data_capacity() -
+ scalar_values_builder_.value_data_length();
+ if (size <= available_storage) {
+ return Status::OK();
+ }
+ return scalar_values_builder_.ReserveData(size - available_storage);
+ }
+
+ Status status_;
+ RawBuilderSet builder_set_;
+ BuilderPtr builder_;
+ // top of this stack is the parent of builder_
+ std::vector<BuilderPtr> builder_stack_;
+ // top of this stack refers to the fields of the highest *StructBuilder*
+ // in builder_stack_ (list builders don't have absent fields)
+ BitsetStack absent_fields_stack_;
+ // index of builder_ within its parent
+ int field_index_;
+ // top of this stack == field_index_
+ std::vector<int> field_index_stack_;
+ StringBuilder scalar_values_builder_;
+};
+
+template <UnexpectedFieldBehavior>
+class Handler;
+
+template <>
+class Handler<UnexpectedFieldBehavior::Error> : public HandlerBase {
+ public:
+ using HandlerBase::HandlerBase;
+
+ Status Parse(const std::shared_ptr<Buffer>& json) override {
+ return DoParse(*this, json);
+ }
+
+ /// \ingroup rapidjson-handler-interface
+ ///
+ /// if an unexpected field is encountered, emit a parse error and bail
+ bool Key(const char* key, rj::SizeType len, ...) {
+ bool duplicate_keys = false;
+ if (ARROW_PREDICT_FALSE(SetFieldBuilder(string_view(key, len), &duplicate_keys))) {
+ return true;
+ }
+ if (!duplicate_keys) {
+ status_ = ParseError("unexpected field");
+ }
+ return false;
+ }
+};
+
+template <>
+class Handler<UnexpectedFieldBehavior::Ignore> : public HandlerBase {
+ public:
+ using HandlerBase::HandlerBase;
+
+ Status Parse(const std::shared_ptr<Buffer>& json) override {
+ return DoParse(*this, json);
+ }
+
+ bool Null() {
+ if (Skipping()) {
+ return true;
+ }
+ return HandlerBase::Null();
+ }
+
+ bool Bool(bool value) {
+ if (Skipping()) {
+ return true;
+ }
+ return HandlerBase::Bool(value);
+ }
+
+ bool RawNumber(const char* data, rj::SizeType size, ...) {
+ if (Skipping()) {
+ return true;
+ }
+ return HandlerBase::RawNumber(data, size);
+ }
+
+ bool String(const char* data, rj::SizeType size, ...) {
+ if (Skipping()) {
+ return true;
+ }
+ return HandlerBase::String(data, size);
+ }
+
+ bool StartObject() {
+ ++depth_;
+ if (Skipping()) {
+ return true;
+ }
+ return HandlerBase::StartObject();
+ }
+
+ /// \ingroup rapidjson-handler-interface
+ ///
+ /// if an unexpected field is encountered, skip until its value has been consumed
+ bool Key(const char* key, rj::SizeType len, ...) {
+ MaybeStopSkipping();
+ if (Skipping()) {
+ return true;
+ }
+ bool duplicate_keys = false;
+ if (ARROW_PREDICT_TRUE(SetFieldBuilder(string_view(key, len), &duplicate_keys))) {
+ return true;
+ }
+ if (ARROW_PREDICT_FALSE(duplicate_keys)) {
+ return false;
+ }
+ skip_depth_ = depth_;
+ return true;
+ }
+
+ bool EndObject(...) {
+ MaybeStopSkipping();
+ --depth_;
+ if (Skipping()) {
+ return true;
+ }
+ return HandlerBase::EndObject();
+ }
+
+ bool StartArray() {
+ if (Skipping()) {
+ return true;
+ }
+ return HandlerBase::StartArray();
+ }
+
+ bool EndArray(rj::SizeType size) {
+ if (Skipping()) {
+ return true;
+ }
+ return HandlerBase::EndArray(size);
+ }
+
+ private:
+ bool Skipping() { return depth_ >= skip_depth_; }
+
+ void MaybeStopSkipping() {
+ if (skip_depth_ == depth_) {
+ skip_depth_ = std::numeric_limits<int>::max();
+ }
+ }
+
+ int depth_ = 0;
+ int skip_depth_ = std::numeric_limits<int>::max();
+};
+
+template <>
+class Handler<UnexpectedFieldBehavior::InferType> : public HandlerBase {
+ public:
+ using HandlerBase::HandlerBase;
+
+ Status Parse(const std::shared_ptr<Buffer>& json) override {
+ return DoParse(*this, json);
+ }
+
+ bool Bool(bool value) {
+ if (ARROW_PREDICT_FALSE(MaybePromoteFromNull<Kind::kBoolean>())) {
+ return false;
+ }
+ return HandlerBase::Bool(value);
+ }
+
+ bool RawNumber(const char* data, rj::SizeType size, ...) {
+ if (ARROW_PREDICT_FALSE(MaybePromoteFromNull<Kind::kNumber>())) {
+ return false;
+ }
+ return HandlerBase::RawNumber(data, size);
+ }
+
+ bool String(const char* data, rj::SizeType size, ...) {
+ if (ARROW_PREDICT_FALSE(MaybePromoteFromNull<Kind::kString>())) {
+ return false;
+ }
+ return HandlerBase::String(data, size);
+ }
+
+ bool StartObject() {
+ if (ARROW_PREDICT_FALSE(MaybePromoteFromNull<Kind::kObject>())) {
+ return false;
+ }
+ return HandlerBase::StartObject();
+ }
+
+ /// \ingroup rapidjson-handler-interface
+ ///
+ /// If an unexpected field is encountered, add a new builder to
+ /// the current parent builder. It is added as a NullBuilder with
+ /// (parent.length - 1) leading nulls. The next value parsed
+ /// will probably trigger promotion of this field from null
+ bool Key(const char* key, rj::SizeType len, ...) {
+ bool duplicate_keys = false;
+ if (ARROW_PREDICT_TRUE(SetFieldBuilder(string_view(key, len), &duplicate_keys))) {
+ return true;
+ }
+ if (ARROW_PREDICT_FALSE(duplicate_keys)) {
+ return false;
+ }
+ auto struct_builder = Cast<Kind::kObject>(builder_stack_.back());
+ auto leading_nulls = static_cast<uint32_t>(struct_builder->length() - 1);
+ builder_ = BuilderPtr(Kind::kNull, leading_nulls, true);
+ field_index_ = struct_builder->AddField(std::string(key, len), builder_);
+ return true;
+ }
+
+ bool StartArray() {
+ if (ARROW_PREDICT_FALSE(MaybePromoteFromNull<Kind::kArray>())) {
+ return false;
+ }
+ return HandlerBase::StartArray();
+ }
+
+ private:
+ // return true if a terminal error was encountered
+ template <Kind::type kind>
+ bool MaybePromoteFromNull() {
+ if (ARROW_PREDICT_TRUE(builder_.kind != Kind::kNull)) {
+ return false;
+ }
+ auto parent = builder_stack_.back();
+ if (parent.kind == Kind::kArray) {
+ auto list_builder = Cast<Kind::kArray>(parent);
+ DCHECK_EQ(list_builder->value_builder(), builder_);
+ status_ = builder_set_.MakeBuilder<kind>(builder_.index, &builder_);
+ if (ARROW_PREDICT_FALSE(!status_.ok())) {
+ return true;
+ }
+ list_builder = Cast<Kind::kArray>(parent);
+ list_builder->value_builder(builder_);
+ } else {
+ auto struct_builder = Cast<Kind::kObject>(parent);
+ DCHECK_EQ(struct_builder->field_builder(field_index_), builder_);
+ status_ = builder_set_.MakeBuilder<kind>(builder_.index, &builder_);
+ if (ARROW_PREDICT_FALSE(!status_.ok())) {
+ return true;
+ }
+ struct_builder = Cast<Kind::kObject>(parent);
+ struct_builder->field_builder(field_index_, builder_);
+ }
+ return false;
+ }
+};
+
+Status BlockParser::Make(MemoryPool* pool, const ParseOptions& options,
+ std::unique_ptr<BlockParser>* out) {
+ DCHECK(options.unexpected_field_behavior == UnexpectedFieldBehavior::InferType ||
+ options.explicit_schema != nullptr);
+
+ switch (options.unexpected_field_behavior) {
+ case UnexpectedFieldBehavior::Ignore: {
+ *out = make_unique<Handler<UnexpectedFieldBehavior::Ignore>>(pool);
+ break;
+ }
+ case UnexpectedFieldBehavior::Error: {
+ *out = make_unique<Handler<UnexpectedFieldBehavior::Error>>(pool);
+ break;
+ }
+ case UnexpectedFieldBehavior::InferType:
+ *out = make_unique<Handler<UnexpectedFieldBehavior::InferType>>(pool);
+ break;
+ }
+ return static_cast<HandlerBase&>(**out).Initialize(options.explicit_schema);
+}
+
+Status BlockParser::Make(const ParseOptions& options, std::unique_ptr<BlockParser>* out) {
+ return BlockParser::Make(default_memory_pool(), options, out);
+}
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/parser.h b/src/arrow/cpp/src/arrow/json/parser.h
new file mode 100644
index 000000000..4dd14e4b8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/parser.h
@@ -0,0 +1,101 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/json/options.h"
+#include "arrow/status.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Array;
+class Buffer;
+class MemoryPool;
+class KeyValueMetadata;
+class ResizableBuffer;
+
+namespace json {
+
+struct Kind {
+ enum type : uint8_t { kNull, kBoolean, kNumber, kString, kArray, kObject };
+
+ static const std::string& Name(Kind::type);
+
+ static const std::shared_ptr<const KeyValueMetadata>& Tag(Kind::type);
+
+ static Kind::type FromTag(const std::shared_ptr<const KeyValueMetadata>& tag);
+
+ static Status ForType(const DataType& type, Kind::type* kind);
+};
+
+constexpr int32_t kMaxParserNumRows = 100000;
+
+/// \class BlockParser
+/// \brief A reusable block-based parser for JSON data
+///
+/// The parser takes a block of newline delimited JSON data and extracts Arrays
+/// of unconverted strings which can be fed to a Converter to obtain a usable Array.
+///
+/// Note that in addition to parse errors (such as malformed JSON) some conversion
+/// errors are caught at parse time:
+/// - A null value in non-nullable column
+/// - Change in the JSON kind of a column. For example, if an explicit schema is provided
+/// which stipulates that field "a" is integral, a row of {"a": "not a number"} will
+/// result in an error. This also applies to fields outside an explicit schema.
+class ARROW_EXPORT BlockParser {
+ public:
+ virtual ~BlockParser() = default;
+
+ /// \brief Reserve storage for scalars parsed from a block of json
+ virtual Status ReserveScalarStorage(int64_t nbytes) = 0;
+
+ /// \brief Parse a block of data
+ virtual Status Parse(const std::shared_ptr<Buffer>& json) = 0;
+
+ /// \brief Extract parsed data
+ virtual Status Finish(std::shared_ptr<Array>* parsed) = 0;
+
+ /// \brief Return the number of parsed rows
+ int32_t num_rows() const { return num_rows_; }
+
+ /// \brief Construct a BlockParser
+ ///
+ /// \param[in] pool MemoryPool to use when constructing parsed array
+ /// \param[in] options ParseOptions to use when parsing JSON
+ /// \param[out] out constructed BlockParser
+ static Status Make(MemoryPool* pool, const ParseOptions& options,
+ std::unique_ptr<BlockParser>* out);
+
+ static Status Make(const ParseOptions& options, std::unique_ptr<BlockParser>* out);
+
+ protected:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(BlockParser);
+
+ explicit BlockParser(MemoryPool* pool) : pool_(pool) {}
+
+ MemoryPool* pool_;
+ int32_t num_rows_ = 0;
+};
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/parser_benchmark.cc b/src/arrow/cpp/src/arrow/json/parser_benchmark.cc
new file mode 100644
index 000000000..9b7047d78
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/parser_benchmark.cc
@@ -0,0 +1,164 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <string>
+
+#include "arrow/json/chunker.h"
+#include "arrow/json/options.h"
+#include "arrow/json/parser.h"
+#include "arrow/json/reader.h"
+#include "arrow/json/test_common.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+namespace json {
+
+std::shared_ptr<Schema> TestSchema() {
+ return schema({field("int", int32()), field("str", utf8())});
+}
+
+constexpr int seed = 0x432432;
+
+std::string TestJsonData(int num_rows, bool pretty = false) {
+ std::default_random_engine engine(seed);
+ std::string json;
+ for (int i = 0; i < num_rows; ++i) {
+ StringBuffer sb;
+ Writer writer(sb);
+ ABORT_NOT_OK(Generate(TestSchema(), engine, &writer));
+ json += pretty ? PrettyPrint(sb.GetString()) : sb.GetString();
+ json += "\n";
+ }
+
+ return json;
+}
+
+static void BenchmarkJSONChunking(benchmark::State& state,
+ const std::shared_ptr<Buffer>& json,
+ ParseOptions options) { // NOLINT non-const reference
+ auto chunker = MakeChunker(options);
+
+ for (auto _ : state) {
+ std::shared_ptr<Buffer> chunked, partial;
+ ABORT_NOT_OK(chunker->Process(json, &chunked, &partial));
+ }
+
+ state.SetBytesProcessed(state.iterations() * json->size());
+}
+
+static void ChunkJSONPrettyPrinted(
+ benchmark::State& state) { // NOLINT non-const reference
+ const int32_t num_rows = 5000;
+
+ auto options = ParseOptions::Defaults();
+ options.newlines_in_values = true;
+ options.explicit_schema = TestSchema();
+
+ auto json = TestJsonData(num_rows, /* pretty */ true);
+ BenchmarkJSONChunking(state, std::make_shared<Buffer>(json), options);
+}
+
+static void ChunkJSONLineDelimited(
+ benchmark::State& state) { // NOLINT non-const reference
+ const int32_t num_rows = 5000;
+
+ auto options = ParseOptions::Defaults();
+ options.newlines_in_values = false;
+ options.explicit_schema = TestSchema();
+
+ auto json = TestJsonData(num_rows);
+ BenchmarkJSONChunking(state, std::make_shared<Buffer>(json), options);
+ state.SetBytesProcessed(0);
+}
+
+static void BenchmarkJSONParsing(benchmark::State& state, // NOLINT non-const reference
+ const std::shared_ptr<Buffer>& json, int32_t num_rows,
+ ParseOptions options) {
+ for (auto _ : state) {
+ std::unique_ptr<BlockParser> parser;
+ ABORT_NOT_OK(BlockParser::Make(options, &parser));
+ ABORT_NOT_OK(parser->Parse(json));
+
+ std::shared_ptr<Array> parsed;
+ ABORT_NOT_OK(parser->Finish(&parsed));
+ }
+ state.SetBytesProcessed(state.iterations() * json->size());
+}
+
+static void ParseJSONBlockWithSchema(
+ benchmark::State& state) { // NOLINT non-const reference
+ const int32_t num_rows = 5000;
+ auto options = ParseOptions::Defaults();
+ options.unexpected_field_behavior = UnexpectedFieldBehavior::Error;
+ options.explicit_schema = TestSchema();
+
+ auto json = TestJsonData(num_rows);
+ BenchmarkJSONParsing(state, std::make_shared<Buffer>(json), num_rows, options);
+}
+
+static void BenchmarkJSONReading(benchmark::State& state, // NOLINT non-const reference
+ const std::string& json, int32_t num_rows,
+ ReadOptions read_options, ParseOptions parse_options) {
+ for (auto _ : state) {
+ std::shared_ptr<io::InputStream> input;
+ ABORT_NOT_OK(MakeStream(json, &input));
+
+ ASSERT_OK_AND_ASSIGN(auto reader, TableReader::Make(default_memory_pool(), input,
+ read_options, parse_options));
+
+ std::shared_ptr<Table> table = *reader->Read();
+ }
+
+ state.SetBytesProcessed(state.iterations() * json.size());
+}
+
+static void BenchmarkReadJSONBlockWithSchema(
+ benchmark::State& state, bool use_threads) { // NOLINT non-const reference
+ const int32_t num_rows = 500000;
+ auto read_options = ReadOptions::Defaults();
+ read_options.use_threads = use_threads;
+
+ auto parse_options = ParseOptions::Defaults();
+ parse_options.unexpected_field_behavior = UnexpectedFieldBehavior::Error;
+ parse_options.explicit_schema = TestSchema();
+
+ auto json = TestJsonData(num_rows);
+ BenchmarkJSONReading(state, json, num_rows, read_options, parse_options);
+}
+
+static void ReadJSONBlockWithSchemaSingleThread(
+ benchmark::State& state) { // NOLINT non-const reference
+ BenchmarkReadJSONBlockWithSchema(state, false);
+}
+
+static void ReadJSONBlockWithSchemaMultiThread(
+ benchmark::State& state) { // NOLINT non-const reference
+ BenchmarkReadJSONBlockWithSchema(state, true);
+}
+
+BENCHMARK(ChunkJSONPrettyPrinted);
+BENCHMARK(ChunkJSONLineDelimited);
+BENCHMARK(ParseJSONBlockWithSchema);
+
+BENCHMARK(ReadJSONBlockWithSchemaSingleThread);
+BENCHMARK(ReadJSONBlockWithSchemaMultiThread)->UseRealTime();
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/parser_test.cc b/src/arrow/cpp/src/arrow/json/parser_test.cc
new file mode 100644
index 000000000..2a44ed837
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/parser_test.cc
@@ -0,0 +1,265 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/json/parser.h"
+
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/json/options.h"
+#include "arrow/json/test_common.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace json {
+
+using util::string_view;
+
+void AssertUnconvertedStructArraysEqual(const StructArray& expected,
+ const StructArray& actual);
+
+void AssertUnconvertedArraysEqual(const Array& expected, const Array& actual) {
+ switch (actual.type_id()) {
+ case Type::BOOL:
+ case Type::NA:
+ return AssertArraysEqual(expected, actual);
+ case Type::DICTIONARY: {
+ ASSERT_EQ(expected.type_id(), Type::STRING);
+ std::shared_ptr<Array> actual_decoded;
+ ASSERT_OK(DecodeStringDictionary(checked_cast<const DictionaryArray&>(actual),
+ &actual_decoded));
+ return AssertArraysEqual(expected, *actual_decoded);
+ }
+ case Type::LIST: {
+ ASSERT_EQ(expected.type_id(), Type::LIST);
+ ASSERT_EQ(expected.null_count(), actual.null_count());
+ if (expected.null_count() != 0) {
+ AssertBufferEqual(*expected.null_bitmap(), *actual.null_bitmap());
+ }
+ const auto& expected_offsets = expected.data()->buffers[1];
+ const auto& actual_offsets = actual.data()->buffers[1];
+ AssertBufferEqual(*expected_offsets, *actual_offsets);
+ auto expected_values = checked_cast<const ListArray&>(expected).values();
+ auto actual_values = checked_cast<const ListArray&>(actual).values();
+ return AssertUnconvertedArraysEqual(*expected_values, *actual_values);
+ }
+ case Type::STRUCT:
+ ASSERT_EQ(expected.type_id(), Type::STRUCT);
+ return AssertUnconvertedStructArraysEqual(
+ checked_cast<const StructArray&>(expected),
+ checked_cast<const StructArray&>(actual));
+ default:
+ FAIL();
+ }
+}
+
+void AssertUnconvertedStructArraysEqual(const StructArray& expected,
+ const StructArray& actual) {
+ ASSERT_EQ(expected.num_fields(), actual.num_fields());
+ for (int i = 0; i < expected.num_fields(); ++i) {
+ auto expected_name = expected.type()->field(i)->name();
+ auto actual_name = actual.type()->field(i)->name();
+ ASSERT_EQ(expected_name, actual_name);
+ AssertUnconvertedArraysEqual(*expected.field(i), *actual.field(i));
+ }
+}
+
+void AssertParseColumns(ParseOptions options, string_view src_str,
+ const std::vector<std::shared_ptr<Field>>& fields,
+ const std::vector<std::string>& columns_json) {
+ std::shared_ptr<Array> parsed;
+ ASSERT_OK(ParseFromString(options, src_str, &parsed));
+ auto struct_array = std::static_pointer_cast<StructArray>(parsed);
+ for (size_t i = 0; i < fields.size(); ++i) {
+ auto column_expected = ArrayFromJSON(fields[i]->type(), columns_json[i]);
+ auto column = struct_array->GetFieldByName(fields[i]->name());
+ AssertUnconvertedArraysEqual(*column_expected, *column);
+ }
+}
+
+// TODO(bkietz) parameterize (at least some of) these tests over UnexpectedFieldBehavior
+
+TEST(BlockParserWithSchema, Basics) {
+ auto options = ParseOptions::Defaults();
+ options.explicit_schema =
+ schema({field("hello", float64()), field("world", boolean()), field("yo", utf8())});
+ options.unexpected_field_behavior = UnexpectedFieldBehavior::Ignore;
+ AssertParseColumns(
+ options, scalars_only_src(),
+ {field("hello", utf8()), field("world", boolean()), field("yo", utf8())},
+ {"[\"3.5\", \"3.25\", \"3.125\", \"0.0\"]", "[false, null, null, true]",
+ "[\"thing\", null, \"\xe5\xbf\x8d\", null]"});
+}
+
+TEST(BlockParserWithSchema, Empty) {
+ auto options = ParseOptions::Defaults();
+ options.explicit_schema =
+ schema({field("hello", float64()), field("world", boolean()), field("yo", utf8())});
+ options.unexpected_field_behavior = UnexpectedFieldBehavior::Ignore;
+ AssertParseColumns(
+ options, "",
+ {field("hello", utf8()), field("world", boolean()), field("yo", utf8())},
+ {"[]", "[]", "[]"});
+}
+
+TEST(BlockParserWithSchema, SkipFieldsOutsideSchema) {
+ auto options = ParseOptions::Defaults();
+ options.explicit_schema = schema({field("hello", float64()), field("yo", utf8())});
+ options.unexpected_field_behavior = UnexpectedFieldBehavior::Ignore;
+ AssertParseColumns(options, scalars_only_src(),
+ {field("hello", utf8()), field("yo", utf8())},
+ {"[\"3.5\", \"3.25\", \"3.125\", \"0.0\"]",
+ "[\"thing\", null, \"\xe5\xbf\x8d\", null]"});
+}
+
+class BlockParserTypeError : public ::testing::TestWithParam<UnexpectedFieldBehavior> {
+ public:
+ ParseOptions Options(std::shared_ptr<Schema> explicit_schema) {
+ auto options = ParseOptions::Defaults();
+ options.explicit_schema = std::move(explicit_schema);
+ options.unexpected_field_behavior = GetParam();
+ return options;
+ }
+};
+
+TEST_P(BlockParserTypeError, FailOnInconvertible) {
+ auto options = Options(schema({field("a", int32())}));
+ std::shared_ptr<Array> parsed;
+ Status error = ParseFromString(options, "{\"a\":0}\n{\"a\":true}", &parsed);
+ ASSERT_RAISES(Invalid, error);
+ EXPECT_THAT(
+ error.message(),
+ testing::StartsWith(
+ "JSON parse error: Column(/a) changed from number to boolean in row 1"));
+}
+
+TEST_P(BlockParserTypeError, FailOnNestedInconvertible) {
+ auto options = Options(schema({field("a", list(struct_({field("b", int32())})))}));
+ std::shared_ptr<Array> parsed;
+ Status error =
+ ParseFromString(options, "{\"a\":[{\"b\":0}]}\n{\"a\":[{\"b\":true}]}", &parsed);
+ ASSERT_RAISES(Invalid, error);
+ EXPECT_THAT(
+ error.message(),
+ testing::StartsWith(
+ "JSON parse error: Column(/a/[]/b) changed from number to boolean in row 1"));
+}
+
+TEST_P(BlockParserTypeError, FailOnDuplicateKeys) {
+ std::shared_ptr<Array> parsed;
+ Status error = ParseFromString(Options(schema({field("a", int32())})),
+ "{\"a\":0, \"a\":1}\n", &parsed);
+ ASSERT_RAISES(Invalid, error);
+ EXPECT_THAT(
+ error.message(),
+ testing::StartsWith("JSON parse error: Column(/a) was specified twice in row 0"));
+}
+
+TEST_P(BlockParserTypeError, FailOnDuplicateKeysNoSchema) {
+ std::shared_ptr<Array> parsed;
+ Status error =
+ ParseFromString(ParseOptions::Defaults(), "{\"a\":0, \"a\":1}\n", &parsed);
+
+ ASSERT_RAISES(Invalid, error);
+ EXPECT_THAT(
+ error.message(),
+ testing::StartsWith("JSON parse error: Column(/a) was specified twice in row 0"));
+}
+
+INSTANTIATE_TEST_SUITE_P(BlockParserTypeError, BlockParserTypeError,
+ ::testing::Values(UnexpectedFieldBehavior::Ignore,
+ UnexpectedFieldBehavior::Error,
+ UnexpectedFieldBehavior::InferType));
+
+TEST(BlockParserWithSchema, Nested) {
+ auto options = ParseOptions::Defaults();
+ options.explicit_schema = schema({field("yo", utf8()), field("arr", list(int32())),
+ field("nuf", struct_({field("ps", int32())}))});
+ options.unexpected_field_behavior = UnexpectedFieldBehavior::Ignore;
+ AssertParseColumns(options, nested_src(),
+ {field("yo", utf8()), field("arr", list(utf8())),
+ field("nuf", struct_({field("ps", utf8())}))},
+ {"[\"thing\", null, \"\xe5\xbf\x8d\", null]",
+ R"([["1", "2", "3"], ["2"], [], null])",
+ R"([{"ps":null}, {}, {"ps":"78"}, {"ps":"90"}])"});
+}
+
+TEST(BlockParserWithSchema, FailOnIncompleteJson) {
+ auto options = ParseOptions::Defaults();
+ options.explicit_schema = schema({field("a", int32())});
+ options.unexpected_field_behavior = UnexpectedFieldBehavior::Ignore;
+ std::shared_ptr<Array> parsed;
+ ASSERT_RAISES(Invalid, ParseFromString(options, "{\"a\":0, \"b\"", &parsed));
+}
+
+TEST(BlockParser, Basics) {
+ auto options = ParseOptions::Defaults();
+ options.unexpected_field_behavior = UnexpectedFieldBehavior::InferType;
+ AssertParseColumns(
+ options, scalars_only_src(),
+ {field("hello", utf8()), field("world", boolean()), field("yo", utf8())},
+ {"[\"3.5\", \"3.25\", \"3.125\", \"0.0\"]", "[false, null, null, true]",
+ "[\"thing\", null, \"\xe5\xbf\x8d\", null]"});
+}
+
+TEST(BlockParser, Nested) {
+ auto options = ParseOptions::Defaults();
+ options.unexpected_field_behavior = UnexpectedFieldBehavior::InferType;
+ AssertParseColumns(options, nested_src(),
+ {field("yo", utf8()), field("arr", list(utf8())),
+ field("nuf", struct_({field("ps", utf8())}))},
+ {"[\"thing\", null, \"\xe5\xbf\x8d\", null]",
+ R"([["1", "2", "3"], ["2"], [], null])",
+ R"([{"ps":null}, {}, {"ps":"78"}, {"ps":"90"}])"});
+}
+
+TEST(BlockParser, Null) {
+ auto options = ParseOptions::Defaults();
+ options.unexpected_field_behavior = UnexpectedFieldBehavior::InferType;
+ AssertParseColumns(
+ options, null_src(),
+ {field("plain", null()), field("list1", list(null())), field("list2", list(null())),
+ field("struct", struct_({field("plain", null())}))},
+ {"[null, null]", "[[], []]", "[[], [null]]",
+ R"([{"plain": null}, {"plain": null}])"});
+}
+
+TEST(BlockParser, AdHoc) {
+ auto options = ParseOptions::Defaults();
+ options.unexpected_field_behavior = UnexpectedFieldBehavior::InferType;
+ AssertParseColumns(
+ options, R"({"a": [1], "b": {"c": true, "d": "1991-02-03"}}
+{"a": [], "b": {"c": false, "d": "2019-04-01"}}
+)",
+ {field("a", list(utf8())),
+ field("b", struct_({field("c", boolean()), field("d", utf8())}))},
+ {R"([["1"], []])",
+ R"([{"c":true, "d": "1991-02-03"}, {"c":false, "d":"2019-04-01"}])"});
+}
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/rapidjson_defs.h b/src/arrow/cpp/src/arrow/json/rapidjson_defs.h
new file mode 100644
index 000000000..9ed81d000
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/rapidjson_defs.h
@@ -0,0 +1,43 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Include this file before including any RapidJSON headers.
+
+#pragma once
+
+#define RAPIDJSON_HAS_STDSTRING 1
+#define RAPIDJSON_HAS_CXX11_RVALUE_REFS 1
+#define RAPIDJSON_HAS_CXX11_RANGE_FOR 1
+
+// rapidjson will be defined in namespace arrow::rapidjson
+#define RAPIDJSON_NAMESPACE arrow::rapidjson
+#define RAPIDJSON_NAMESPACE_BEGIN \
+ namespace arrow { \
+ namespace rapidjson {
+#define RAPIDJSON_NAMESPACE_END \
+ } \
+ }
+
+// enable SIMD whitespace skipping, if available
+#if defined(ARROW_HAVE_SSE4_2)
+#define RAPIDJSON_SSE2 1
+#define RAPIDJSON_SSE42 1
+#endif
+
+#if defined(ARROW_HAVE_NEON)
+#define RAPIDJSON_NEON 1
+#endif
diff --git a/src/arrow/cpp/src/arrow/json/reader.cc b/src/arrow/cpp/src/arrow/json/reader.cc
new file mode 100644
index 000000000..18aed0235
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/reader.cc
@@ -0,0 +1,218 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/json/reader.h"
+
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/json/chunked_builder.h"
+#include "arrow/json/chunker.h"
+#include "arrow/json/converter.h"
+#include "arrow/json/parser.h"
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/task_group.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using util::string_view;
+
+using internal::checked_cast;
+using internal::GetCpuThreadPool;
+using internal::TaskGroup;
+using internal::ThreadPool;
+
+namespace json {
+
+class TableReaderImpl : public TableReader,
+ public std::enable_shared_from_this<TableReaderImpl> {
+ public:
+ TableReaderImpl(MemoryPool* pool, const ReadOptions& read_options,
+ const ParseOptions& parse_options,
+ std::shared_ptr<TaskGroup> task_group)
+ : pool_(pool),
+ read_options_(read_options),
+ parse_options_(parse_options),
+ chunker_(MakeChunker(parse_options_)),
+ task_group_(std::move(task_group)) {}
+
+ Status Init(std::shared_ptr<io::InputStream> input) {
+ ARROW_ASSIGN_OR_RAISE(auto it,
+ io::MakeInputStreamIterator(input, read_options_.block_size));
+ return MakeReadaheadIterator(std::move(it), task_group_->parallelism())
+ .Value(&block_iterator_);
+ }
+
+ Result<std::shared_ptr<Table>> Read() override {
+ RETURN_NOT_OK(MakeBuilder());
+
+ ARROW_ASSIGN_OR_RAISE(auto block, block_iterator_.Next());
+ if (block == nullptr) {
+ return Status::Invalid("Empty JSON file");
+ }
+
+ auto self = shared_from_this();
+ auto empty = std::make_shared<Buffer>("");
+
+ int64_t block_index = 0;
+ std::shared_ptr<Buffer> partial = empty;
+
+ while (block != nullptr) {
+ std::shared_ptr<Buffer> next_block, whole, completion, next_partial;
+
+ ARROW_ASSIGN_OR_RAISE(next_block, block_iterator_.Next());
+
+ if (next_block == nullptr) {
+ // End of file reached => compute completion from penultimate block
+ RETURN_NOT_OK(chunker_->ProcessFinal(partial, block, &completion, &whole));
+ } else {
+ std::shared_ptr<Buffer> starts_with_whole;
+ // Get completion of partial from previous block.
+ RETURN_NOT_OK(chunker_->ProcessWithPartial(partial, block, &completion,
+ &starts_with_whole));
+
+ // Get all whole objects entirely inside the current buffer
+ RETURN_NOT_OK(chunker_->Process(starts_with_whole, &whole, &next_partial));
+ }
+
+ // Launch parse task
+ task_group_->Append([self, partial, completion, whole, block_index] {
+ return self->ParseAndInsert(partial, completion, whole, block_index);
+ });
+ block_index++;
+
+ partial = next_partial;
+ block = next_block;
+ }
+
+ std::shared_ptr<ChunkedArray> array;
+ RETURN_NOT_OK(builder_->Finish(&array));
+ return Table::FromChunkedStructArray(array);
+ }
+
+ private:
+ Status MakeBuilder() {
+ auto type = parse_options_.explicit_schema
+ ? struct_(parse_options_.explicit_schema->fields())
+ : struct_({});
+
+ auto promotion_graph =
+ parse_options_.unexpected_field_behavior == UnexpectedFieldBehavior::InferType
+ ? GetPromotionGraph()
+ : nullptr;
+
+ return MakeChunkedArrayBuilder(task_group_, pool_, promotion_graph, type, &builder_);
+ }
+
+ Status ParseAndInsert(const std::shared_ptr<Buffer>& partial,
+ const std::shared_ptr<Buffer>& completion,
+ const std::shared_ptr<Buffer>& whole, int64_t block_index) {
+ std::unique_ptr<BlockParser> parser;
+ RETURN_NOT_OK(BlockParser::Make(pool_, parse_options_, &parser));
+ RETURN_NOT_OK(parser->ReserveScalarStorage(partial->size() + completion->size() +
+ whole->size()));
+
+ if (partial->size() != 0 || completion->size() != 0) {
+ std::shared_ptr<Buffer> straddling;
+ if (partial->size() == 0) {
+ straddling = completion;
+ } else if (completion->size() == 0) {
+ straddling = partial;
+ } else {
+ ARROW_ASSIGN_OR_RAISE(straddling,
+ ConcatenateBuffers({partial, completion}, pool_));
+ }
+ RETURN_NOT_OK(parser->Parse(straddling));
+ }
+
+ if (whole->size() != 0) {
+ RETURN_NOT_OK(parser->Parse(whole));
+ }
+
+ std::shared_ptr<Array> parsed;
+ RETURN_NOT_OK(parser->Finish(&parsed));
+ builder_->Insert(block_index, field("", parsed->type()), parsed);
+ return Status::OK();
+ }
+
+ MemoryPool* pool_;
+ ReadOptions read_options_;
+ ParseOptions parse_options_;
+ std::unique_ptr<Chunker> chunker_;
+ std::shared_ptr<TaskGroup> task_group_;
+ Iterator<std::shared_ptr<Buffer>> block_iterator_;
+ std::shared_ptr<ChunkedArrayBuilder> builder_;
+};
+
+Result<std::shared_ptr<TableReader>> TableReader::Make(
+ MemoryPool* pool, std::shared_ptr<io::InputStream> input,
+ const ReadOptions& read_options, const ParseOptions& parse_options) {
+ std::shared_ptr<TableReaderImpl> ptr;
+ if (read_options.use_threads) {
+ ptr = std::make_shared<TableReaderImpl>(pool, read_options, parse_options,
+ TaskGroup::MakeThreaded(GetCpuThreadPool()));
+ } else {
+ ptr = std::make_shared<TableReaderImpl>(pool, read_options, parse_options,
+ TaskGroup::MakeSerial());
+ }
+ RETURN_NOT_OK(ptr->Init(input));
+ return ptr;
+}
+
+Result<std::shared_ptr<RecordBatch>> ParseOne(ParseOptions options,
+ std::shared_ptr<Buffer> json) {
+ std::unique_ptr<BlockParser> parser;
+ RETURN_NOT_OK(BlockParser::Make(options, &parser));
+ RETURN_NOT_OK(parser->Parse(json));
+ std::shared_ptr<Array> parsed;
+ RETURN_NOT_OK(parser->Finish(&parsed));
+
+ auto type =
+ options.explicit_schema ? struct_(options.explicit_schema->fields()) : struct_({});
+ auto promotion_graph =
+ options.unexpected_field_behavior == UnexpectedFieldBehavior::InferType
+ ? GetPromotionGraph()
+ : nullptr;
+ std::shared_ptr<ChunkedArrayBuilder> builder;
+ RETURN_NOT_OK(MakeChunkedArrayBuilder(TaskGroup::MakeSerial(), default_memory_pool(),
+ promotion_graph, type, &builder));
+
+ builder->Insert(0, field("", type), parsed);
+ std::shared_ptr<ChunkedArray> converted_chunked;
+ RETURN_NOT_OK(builder->Finish(&converted_chunked));
+ const auto& converted = checked_cast<const StructArray&>(*converted_chunked->chunk(0));
+
+ std::vector<std::shared_ptr<Array>> columns(converted.num_fields());
+ for (int i = 0; i < converted.num_fields(); ++i) {
+ columns[i] = converted.field(i);
+ }
+ return RecordBatch::Make(schema(converted.type()->fields()), converted.length(),
+ std::move(columns));
+}
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/reader.h b/src/arrow/cpp/src/arrow/json/reader.h
new file mode 100644
index 000000000..3374931a0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/reader.h
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/json/options.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Buffer;
+class MemoryPool;
+class Table;
+class RecordBatch;
+class Array;
+class DataType;
+
+namespace io {
+class InputStream;
+} // namespace io
+
+namespace json {
+
+/// A class that reads an entire JSON file into a Arrow Table
+///
+/// The file is expected to consist of individual line-separated JSON objects
+class ARROW_EXPORT TableReader {
+ public:
+ virtual ~TableReader() = default;
+
+ /// Read the entire JSON file and convert it to a Arrow Table
+ virtual Result<std::shared_ptr<Table>> Read() = 0;
+
+ /// Create a TableReader instance
+ static Result<std::shared_ptr<TableReader>> Make(MemoryPool* pool,
+ std::shared_ptr<io::InputStream> input,
+ const ReadOptions&,
+ const ParseOptions&);
+};
+
+ARROW_EXPORT Result<std::shared_ptr<RecordBatch>> ParseOne(ParseOptions options,
+ std::shared_ptr<Buffer> json);
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/reader_test.cc b/src/arrow/cpp/src/arrow/json/reader_test.cc
new file mode 100644
index 000000000..976343b52
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/reader_test.cc
@@ -0,0 +1,278 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/io/interfaces.h"
+#include "arrow/json/options.h"
+#include "arrow/json/reader.h"
+#include "arrow/json/test_common.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+namespace json {
+
+using util::string_view;
+
+using internal::checked_cast;
+
+class ReaderTest : public ::testing::TestWithParam<bool> {
+ public:
+ void SetUpReader() {
+ read_options_.use_threads = GetParam();
+ ASSERT_OK_AND_ASSIGN(reader_, TableReader::Make(default_memory_pool(), input_,
+ read_options_, parse_options_));
+ }
+
+ void SetUpReader(util::string_view input) {
+ ASSERT_OK(MakeStream(input, &input_));
+ SetUpReader();
+ }
+
+ std::shared_ptr<ChunkedArray> ChunkedFromJSON(const std::shared_ptr<Field>& field,
+ const std::vector<std::string>& data) {
+ ArrayVector chunks(data.size());
+ for (size_t i = 0; i < chunks.size(); ++i) {
+ chunks[i] = ArrayFromJSON(field->type(), data[i]);
+ }
+ return std::make_shared<ChunkedArray>(std::move(chunks));
+ }
+
+ ParseOptions parse_options_ = ParseOptions::Defaults();
+ ReadOptions read_options_ = ReadOptions::Defaults();
+ std::shared_ptr<io::InputStream> input_;
+ std::shared_ptr<TableReader> reader_;
+ std::shared_ptr<Table> table_;
+};
+
+INSTANTIATE_TEST_SUITE_P(ReaderTest, ReaderTest, ::testing::Values(false, true));
+
+TEST_P(ReaderTest, Empty) {
+ SetUpReader("{}\n{}\n");
+ ASSERT_OK_AND_ASSIGN(table_, reader_->Read());
+
+ auto expected_table = Table::Make(schema({}), ArrayVector(), 2);
+ AssertTablesEqual(*expected_table, *table_);
+}
+
+TEST_P(ReaderTest, EmptyNoNewlineAtEnd) {
+ SetUpReader("{}\n{}");
+ ASSERT_OK_AND_ASSIGN(table_, reader_->Read());
+
+ auto expected_table = Table::Make(schema({}), ArrayVector(), 2);
+ AssertTablesEqual(*expected_table, *table_);
+}
+
+TEST_P(ReaderTest, EmptyManyNewlines) {
+ SetUpReader("{}\n\r\n{}\n\r\n");
+ ASSERT_OK_AND_ASSIGN(table_, reader_->Read());
+
+ auto expected_table = Table::Make(schema({}), ArrayVector(), 2);
+ AssertTablesEqual(*expected_table, *table_);
+}
+
+TEST_P(ReaderTest, Basics) {
+ parse_options_.unexpected_field_behavior = UnexpectedFieldBehavior::InferType;
+ auto src = scalars_only_src();
+ SetUpReader(src);
+ ASSERT_OK_AND_ASSIGN(table_, reader_->Read());
+
+ auto schema = ::arrow::schema(
+ {field("hello", float64()), field("world", boolean()), field("yo", utf8())});
+
+ auto expected_table = Table::Make(
+ schema, {
+ ArrayFromJSON(schema->field(0)->type(), "[3.5, 3.25, 3.125, 0.0]"),
+ ArrayFromJSON(schema->field(1)->type(), "[false, null, null, true]"),
+ ArrayFromJSON(schema->field(2)->type(),
+ "[\"thing\", null, \"\xe5\xbf\x8d\", null]"),
+ });
+ AssertTablesEqual(*expected_table, *table_);
+}
+
+TEST_P(ReaderTest, Nested) {
+ parse_options_.unexpected_field_behavior = UnexpectedFieldBehavior::InferType;
+ auto src = nested_src();
+ SetUpReader(src);
+ ASSERT_OK_AND_ASSIGN(table_, reader_->Read());
+
+ auto schema = ::arrow::schema({field("hello", float64()), field("world", boolean()),
+ field("yo", utf8()), field("arr", list(int64())),
+ field("nuf", struct_({field("ps", int64())}))});
+
+ auto a0 = ArrayFromJSON(schema->field(0)->type(), "[3.5, 3.25, 3.125, 0.0]");
+ auto a1 = ArrayFromJSON(schema->field(1)->type(), "[false, null, null, true]");
+ auto a2 = ArrayFromJSON(schema->field(2)->type(),
+ "[\"thing\", null, \"\xe5\xbf\x8d\", null]");
+ auto a3 = ArrayFromJSON(schema->field(3)->type(), "[[1, 2, 3], [2], [], null]");
+ auto a4 = ArrayFromJSON(schema->field(4)->type(),
+ R"([{"ps":null}, null, {"ps":78}, {"ps":90}])");
+ auto expected_table = Table::Make(schema, {a0, a1, a2, a3, a4});
+ AssertTablesEqual(*expected_table, *table_);
+}
+
+TEST_P(ReaderTest, PartialSchema) {
+ parse_options_.unexpected_field_behavior = UnexpectedFieldBehavior::InferType;
+ parse_options_.explicit_schema =
+ schema({field("nuf", struct_({field("absent", date32())})),
+ field("arr", list(float32()))});
+ auto src = nested_src();
+ SetUpReader(src);
+ ASSERT_OK_AND_ASSIGN(table_, reader_->Read());
+
+ auto schema = ::arrow::schema(
+ {field("nuf", struct_({field("absent", date32()), field("ps", int64())})),
+ field("arr", list(float32())), field("hello", float64()),
+ field("world", boolean()), field("yo", utf8())});
+
+ auto expected_table = Table::Make(
+ schema,
+ {
+ // NB: explicitly declared fields will appear first
+ ArrayFromJSON(
+ schema->field(0)->type(),
+ R"([{"absent":null,"ps":null}, null, {"absent":null,"ps":78}, {"absent":null,"ps":90}])"),
+ ArrayFromJSON(schema->field(1)->type(), R"([[1, 2, 3], [2], [], null])"),
+ // ...followed by undeclared fields
+ ArrayFromJSON(schema->field(2)->type(), "[3.5, 3.25, 3.125, 0.0]"),
+ ArrayFromJSON(schema->field(3)->type(), "[false, null, null, true]"),
+ ArrayFromJSON(schema->field(4)->type(),
+ "[\"thing\", null, \"\xe5\xbf\x8d\", null]"),
+ });
+ AssertTablesEqual(*expected_table, *table_);
+}
+
+TEST_P(ReaderTest, TypeInference) {
+ parse_options_.unexpected_field_behavior = UnexpectedFieldBehavior::InferType;
+ SetUpReader(R"(
+ {"ts":null, "f": null}
+ {"ts":"1970-01-01", "f": 3}
+ {"ts":"2018-11-13 17:11:10", "f":3.125}
+ )");
+ ASSERT_OK_AND_ASSIGN(table_, reader_->Read());
+
+ auto schema =
+ ::arrow::schema({field("ts", timestamp(TimeUnit::SECOND)), field("f", float64())});
+ auto expected_table = Table::Make(
+ schema, {ArrayFromJSON(schema->field(0)->type(),
+ R"([null, "1970-01-01", "2018-11-13 17:11:10"])"),
+ ArrayFromJSON(schema->field(1)->type(), R"([null, 3, 3.125])")});
+ AssertTablesEqual(*expected_table, *table_);
+}
+
+TEST_P(ReaderTest, MultipleChunks) {
+ parse_options_.unexpected_field_behavior = UnexpectedFieldBehavior::InferType;
+
+ auto src = scalars_only_src();
+ read_options_.block_size = static_cast<int>(src.length() / 3);
+
+ SetUpReader(src);
+ ASSERT_OK_AND_ASSIGN(table_, reader_->Read());
+
+ auto schema = ::arrow::schema(
+ {field("hello", float64()), field("world", boolean()), field("yo", utf8())});
+
+ // there is an empty chunk because the last block of the file is " "
+ auto expected_table = Table::Make(
+ schema,
+ {
+ ChunkedFromJSON(schema->field(0), {"[3.5]", "[3.25]", "[3.125, 0.0]", "[]"}),
+ ChunkedFromJSON(schema->field(1), {"[false]", "[null]", "[null, true]", "[]"}),
+ ChunkedFromJSON(schema->field(2),
+ {"[\"thing\"]", "[null]", "[\"\xe5\xbf\x8d\", null]", "[]"}),
+ });
+ AssertTablesEqual(*expected_table, *table_);
+}
+
+TEST(ReaderTest, MultipleChunksParallel) {
+ int64_t count = 1 << 10;
+
+ ParseOptions parse_options;
+ parse_options.unexpected_field_behavior = UnexpectedFieldBehavior::InferType;
+ ReadOptions read_options;
+ read_options.block_size =
+ static_cast<int>(count / 2); // there will be about two dozen blocks
+
+ std::string json;
+ for (int i = 0; i < count; ++i) {
+ json += "{\"a\":" + std::to_string(i) + "}\n";
+ }
+ std::shared_ptr<io::InputStream> input;
+ std::shared_ptr<TableReader> reader;
+
+ read_options.use_threads = true;
+ ASSERT_OK(MakeStream(json, &input));
+ ASSERT_OK_AND_ASSIGN(reader, TableReader::Make(default_memory_pool(), input,
+ read_options, parse_options));
+ ASSERT_OK_AND_ASSIGN(auto threaded, reader->Read());
+
+ read_options.use_threads = false;
+ ASSERT_OK(MakeStream(json, &input));
+ ASSERT_OK_AND_ASSIGN(reader, TableReader::Make(default_memory_pool(), input,
+ read_options, parse_options));
+ ASSERT_OK_AND_ASSIGN(auto serial, reader->Read());
+
+ ASSERT_EQ(serial->column(0)->type()->id(), Type::INT64);
+ int expected = 0;
+ for (auto chunk : serial->column(0)->chunks()) {
+ for (int64_t i = 0; i < chunk->length(); ++i) {
+ ASSERT_EQ(checked_cast<const Int64Array*>(chunk.get())->GetView(i), expected)
+ << " at index " << i;
+ ++expected;
+ }
+ }
+
+ AssertTablesEqual(*serial, *threaded);
+}
+
+TEST(ReaderTest, ListArrayWithFewValues) {
+ // ARROW-7647
+ ParseOptions parse_options;
+ parse_options.unexpected_field_behavior = UnexpectedFieldBehavior::InferType;
+ ReadOptions read_options;
+
+ auto expected_batch = RecordBatchFromJSON(
+ schema({field("a", list(int64())),
+ field("b", struct_({field("c", boolean()),
+ field("d", timestamp(TimeUnit::SECOND))}))}),
+ R"([
+ {"a": [1], "b": {"c": true, "d": "1991-02-03"}},
+ {"a": [], "b": {"c": false, "d": "2019-04-01"}}
+ ])");
+ ASSERT_OK_AND_ASSIGN(auto expected_table, Table::FromRecordBatches({expected_batch}));
+
+ std::string json = R"({"a": [1], "b": {"c": true, "d": "1991-02-03"}}
+{"a": [], "b": {"c": false, "d": "2019-04-01"}}
+)";
+ std::shared_ptr<io::InputStream> input;
+ ASSERT_OK(MakeStream(json, &input));
+
+ read_options.use_threads = false;
+ ASSERT_OK_AND_ASSIGN(auto reader, TableReader::Make(default_memory_pool(), input,
+ read_options, parse_options));
+
+ ASSERT_OK_AND_ASSIGN(auto actual_table, reader->Read());
+ AssertTablesEqual(*actual_table, *expected_table);
+}
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/test_common.h b/src/arrow/cpp/src/arrow/json/test_common.h
new file mode 100644
index 000000000..488da071d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/test_common.h
@@ -0,0 +1,260 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <random>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/io/memory.h"
+#include "arrow/json/converter.h"
+#include "arrow/json/options.h"
+#include "arrow/json/parser.h"
+#include "arrow/json/rapidjson_defs.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/string_view.h"
+#include "arrow/visitor_inline.h"
+#include "rapidjson/document.h"
+#include "rapidjson/prettywriter.h"
+#include "rapidjson/reader.h"
+#include "rapidjson/writer.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace json {
+
+namespace rj = arrow::rapidjson;
+
+using rj::StringBuffer;
+using util::string_view;
+using Writer = rj::Writer<StringBuffer>;
+
+inline static Status OK(bool ok) { return ok ? Status::OK() : Status::Invalid(""); }
+
+template <typename Engine>
+inline static Status Generate(const std::shared_ptr<DataType>& type, Engine& e,
+ Writer* writer);
+
+template <typename Engine>
+inline static Status Generate(const std::vector<std::shared_ptr<Field>>& fields,
+ Engine& e, Writer* writer);
+
+template <typename Engine>
+inline static Status Generate(const std::shared_ptr<Schema>& schm, Engine& e,
+ Writer* writer) {
+ return Generate(schm->fields(), e, writer);
+}
+
+template <typename Engine>
+struct GenerateImpl {
+ Status Visit(const NullType&) { return OK(writer.Null()); }
+
+ Status Visit(const BooleanType&) {
+ return OK(writer.Bool(std::uniform_int_distribution<uint16_t>{}(e)&1));
+ }
+
+ template <typename T>
+ enable_if_physical_unsigned_integer<T, Status> Visit(const T&) {
+ auto val = std::uniform_int_distribution<>{}(e);
+ return OK(writer.Uint64(static_cast<typename T::c_type>(val)));
+ }
+
+ template <typename T>
+ enable_if_physical_signed_integer<T, Status> Visit(const T&) {
+ auto val = std::uniform_int_distribution<>{}(e);
+ return OK(writer.Int64(static_cast<typename T::c_type>(val)));
+ }
+
+ template <typename T>
+ enable_if_physical_floating_point<T, Status> Visit(const T&) {
+ auto val = std::normal_distribution<typename T::c_type>{0, 1 << 10}(e);
+ return OK(writer.Double(val));
+ }
+
+ template <typename T>
+ enable_if_base_binary<T, Status> Visit(const T&) {
+ auto size = std::poisson_distribution<>{4}(e);
+ std::uniform_int_distribution<uint16_t> gen_char(32, 127); // FIXME generate UTF8
+ std::string s(size, '\0');
+ for (char& ch : s) ch = static_cast<char>(gen_char(e));
+ return OK(writer.String(s.c_str()));
+ }
+
+ template <typename T>
+ enable_if_list_like<T, Status> Visit(const T& t) {
+ auto size = std::poisson_distribution<>{4}(e);
+ writer.StartArray();
+ for (int i = 0; i < size; ++i) RETURN_NOT_OK(Generate(t.value_type(), e, &writer));
+ return OK(writer.EndArray(size));
+ }
+
+ Status Visit(const StructType& t) { return Generate(t.fields(), e, &writer); }
+
+ Status Visit(const DayTimeIntervalType& t) { return NotImplemented(t); }
+
+ Status Visit(const MonthDayNanoIntervalType& t) { return NotImplemented(t); }
+
+ Status Visit(const DictionaryType& t) { return NotImplemented(t); }
+
+ Status Visit(const ExtensionType& t) { return NotImplemented(t); }
+
+ Status Visit(const Decimal128Type& t) { return NotImplemented(t); }
+
+ Status Visit(const FixedSizeBinaryType& t) { return NotImplemented(t); }
+
+ Status Visit(const UnionType& t) { return NotImplemented(t); }
+
+ Status NotImplemented(const DataType& t) {
+ return Status::NotImplemented("random generation of arrays of type ", t);
+ }
+
+ Engine& e;
+ rj::Writer<rj::StringBuffer>& writer;
+};
+
+template <typename Engine>
+inline static Status Generate(const std::shared_ptr<DataType>& type, Engine& e,
+ Writer* writer) {
+ if (std::uniform_real_distribution<>{0, 1}(e) < .2) {
+ // one out of 5 chance of null, anywhere
+ writer->Null();
+ return Status::OK();
+ }
+ GenerateImpl<Engine> visitor = {e, *writer};
+ return VisitTypeInline(*type, &visitor);
+}
+
+template <typename Engine>
+inline static Status Generate(const std::vector<std::shared_ptr<Field>>& fields,
+ Engine& e, Writer* writer) {
+ RETURN_NOT_OK(OK(writer->StartObject()));
+ for (const auto& f : fields) {
+ writer->Key(f->name().c_str());
+ RETURN_NOT_OK(Generate(f->type(), e, writer));
+ }
+ return OK(writer->EndObject(static_cast<int>(fields.size())));
+}
+
+inline static Status MakeStream(string_view src_str,
+ std::shared_ptr<io::InputStream>* out) {
+ auto src = std::make_shared<Buffer>(src_str);
+ *out = std::make_shared<io::BufferReader>(src);
+ return Status::OK();
+}
+
+// scalar values (numbers and strings) are parsed into a
+// dictionary<index:int32, value:string>. This can be decoded for ease of comparison
+inline static Status DecodeStringDictionary(const DictionaryArray& dict_array,
+ std::shared_ptr<Array>* decoded) {
+ const StringArray& dict = checked_cast<const StringArray&>(*dict_array.dictionary());
+ const Int32Array& indices = checked_cast<const Int32Array&>(*dict_array.indices());
+ StringBuilder builder;
+ RETURN_NOT_OK(builder.Resize(indices.length()));
+ for (int64_t i = 0; i < indices.length(); ++i) {
+ if (indices.IsNull(i)) {
+ builder.UnsafeAppendNull();
+ continue;
+ }
+ auto value = dict.GetView(indices.GetView(i));
+ RETURN_NOT_OK(builder.ReserveData(value.size()));
+ builder.UnsafeAppend(value);
+ }
+ return builder.Finish(decoded);
+}
+
+inline static Status ParseFromString(ParseOptions options, string_view src_str,
+ std::shared_ptr<Array>* parsed) {
+ auto src = std::make_shared<Buffer>(src_str);
+ std::unique_ptr<BlockParser> parser;
+ RETURN_NOT_OK(BlockParser::Make(options, &parser));
+ RETURN_NOT_OK(parser->Parse(src));
+ return parser->Finish(parsed);
+}
+
+inline static Status ParseFromString(ParseOptions options, string_view src_str,
+ std::shared_ptr<StructArray>* parsed) {
+ std::shared_ptr<Array> parsed_non_struct;
+ RETURN_NOT_OK(ParseFromString(options, src_str, &parsed_non_struct));
+ *parsed = internal::checked_pointer_cast<StructArray>(parsed_non_struct);
+ return Status::OK();
+}
+
+static inline std::string PrettyPrint(string_view one_line) {
+ rj::Document document;
+
+ // Must pass size to avoid ASAN issues.
+ document.Parse(one_line.data(), one_line.size());
+ rj::StringBuffer sb;
+ rj::PrettyWriter<rj::StringBuffer> writer(sb);
+ document.Accept(writer);
+ return sb.GetString();
+}
+
+template <typename T>
+std::string RowsOfOneColumn(util::string_view name, std::initializer_list<T> values,
+ decltype(std::to_string(*values.begin()))* = nullptr) {
+ std::stringstream ss;
+ for (auto value : values) {
+ ss << R"({")" << name << R"(":)" << std::to_string(value) << "}\n";
+ }
+ return ss.str();
+}
+
+inline std::string RowsOfOneColumn(util::string_view name,
+ std::initializer_list<std::string> values) {
+ std::stringstream ss;
+ for (auto value : values) {
+ ss << R"({")" << name << R"(":)" << value << "}\n";
+ }
+ return ss.str();
+}
+
+inline static std::string scalars_only_src() {
+ return R"(
+ { "hello": 3.5, "world": false, "yo": "thing" }
+ { "hello": 3.25, "world": null }
+ { "hello": 3.125, "world": null, "yo": "\u5fcd" }
+ { "hello": 0.0, "world": true, "yo": null }
+ )";
+}
+
+inline static std::string nested_src() {
+ return R"(
+ { "hello": 3.5, "world": false, "yo": "thing", "arr": [1, 2, 3], "nuf": {} }
+ { "hello": 3.25, "world": null, "arr": [2], "nuf": null }
+ { "hello": 3.125, "world": null, "yo": "\u5fcd", "arr": [], "nuf": { "ps": 78 } }
+ { "hello": 0.0, "world": true, "yo": null, "arr": null, "nuf": { "ps": 90 } }
+ )";
+}
+
+inline static std::string null_src() {
+ return R"(
+ { "plain": null, "list1": [], "list2": [], "struct": { "plain": null } }
+ { "plain": null, "list1": [], "list2": [null], "struct": {} }
+ )";
+}
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/json/type_fwd.h b/src/arrow/cpp/src/arrow/json/type_fwd.h
new file mode 100644
index 000000000..67e2e1bb4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/json/type_fwd.h
@@ -0,0 +1,26 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+namespace arrow {
+namespace json {
+
+class TableReader;
+struct ReadOptions;
+struct ParseOptions;
+
+} // namespace json
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/memory_pool.cc b/src/arrow/cpp/src/arrow/memory_pool.cc
new file mode 100644
index 000000000..c80e8f6f6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/memory_pool.cc
@@ -0,0 +1,797 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/memory_pool.h"
+
+#include <algorithm> // IWYU pragma: keep
+#include <atomic>
+#include <cstdlib> // IWYU pragma: keep
+#include <cstring> // IWYU pragma: keep
+#include <iostream> // IWYU pragma: keep
+#include <limits>
+#include <memory>
+
+#if defined(sun) || defined(__sun)
+#include <stdlib.h>
+#endif
+
+#include "arrow/buffer.h"
+#include "arrow/io/util_internal.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h" // IWYU pragma: keep
+#include "arrow/util/optional.h"
+#include "arrow/util/string.h"
+#include "arrow/util/thread_pool.h"
+
+#ifdef __GLIBC__
+#include <malloc.h>
+#endif
+
+#ifdef ARROW_JEMALLOC
+// Needed to support jemalloc 3 and 4
+#define JEMALLOC_MANGLE
+// Explicitly link to our version of jemalloc
+#include "jemalloc_ep/dist/include/jemalloc/jemalloc.h"
+#endif
+
+#ifdef ARROW_MIMALLOC
+#include <mimalloc.h>
+#endif
+
+#ifdef ARROW_JEMALLOC
+
+// Compile-time configuration for jemalloc options.
+// Note the prefix ("je_arrow_") must match the symbol prefix given when
+// building jemalloc.
+// See discussion in https://github.com/jemalloc/jemalloc/issues/1621
+
+// ARROW-6910(wesm): we found that jemalloc's default behavior with respect to
+// dirty / muzzy pages (see definitions of these in the jemalloc documentation)
+// conflicted with user expectations, and would even cause memory use problems
+// in some cases. By enabling the background_thread option and reducing the
+// decay time from 10 seconds to 1 seconds, memory is released more
+// aggressively (and in the background) to the OS. This can be configured
+// further by using the arrow::jemalloc_set_decay_ms API
+
+#undef USE_JEMALLOC_BACKGROUND_THREAD
+#ifndef __APPLE__
+// ARROW-6977: jemalloc's background_thread isn't always enabled on macOS
+#define USE_JEMALLOC_BACKGROUND_THREAD
+#endif
+
+// In debug mode, add memory poisoning on alloc / free
+#ifdef NDEBUG
+#define JEMALLOC_DEBUG_OPTIONS ""
+#else
+#define JEMALLOC_DEBUG_OPTIONS ",junk:true"
+#endif
+
+const char* je_arrow_malloc_conf =
+ ("oversize_threshold:0"
+#ifdef USE_JEMALLOC_BACKGROUND_THREAD
+ ",dirty_decay_ms:1000"
+ ",muzzy_decay_ms:1000"
+ ",background_thread:true"
+#else
+ // ARROW-6994: return memory immediately to the OS if the
+ // background_thread option isn't available
+ ",dirty_decay_ms:0"
+ ",muzzy_decay_ms:0"
+#endif
+ JEMALLOC_DEBUG_OPTIONS); // NOLINT: whitespace/parens
+
+#endif // ARROW_JEMALLOC
+
+namespace arrow {
+
+namespace {
+
+constexpr size_t kAlignment = 64;
+
+constexpr char kDefaultBackendEnvVar[] = "ARROW_DEFAULT_MEMORY_POOL";
+
+enum class MemoryPoolBackend : uint8_t { System, Jemalloc, Mimalloc };
+
+struct SupportedBackend {
+ const char* name;
+ MemoryPoolBackend backend;
+};
+
+// See ARROW-12248 for why we use static in-function singletons rather than
+// global constants below (in SupportedBackends() and UserSelectedBackend()).
+// In some contexts (especially R bindings) `default_memory_pool()` may be
+// called before all globals are initialized, and then the ARROW_DEFAULT_MEMORY_POOL
+// environment variable would be ignored.
+
+const std::vector<SupportedBackend>& SupportedBackends() {
+ static std::vector<SupportedBackend> backends = {
+ // ARROW-12316: Apple => mimalloc first, then jemalloc
+ // non-Apple => jemalloc first, then mimalloc
+#if defined(ARROW_JEMALLOC) && !defined(__APPLE__)
+ {"jemalloc", MemoryPoolBackend::Jemalloc},
+#endif
+#ifdef ARROW_MIMALLOC
+ {"mimalloc", MemoryPoolBackend::Mimalloc},
+#endif
+#if defined(ARROW_JEMALLOC) && defined(__APPLE__)
+ {"jemalloc", MemoryPoolBackend::Jemalloc},
+#endif
+ {"system", MemoryPoolBackend::System}
+ };
+ return backends;
+}
+
+// Return the MemoryPoolBackend selected by the user through the
+// ARROW_DEFAULT_MEMORY_POOL environment variable, if any.
+util::optional<MemoryPoolBackend> UserSelectedBackend() {
+ static auto user_selected_backend = []() -> util::optional<MemoryPoolBackend> {
+ auto unsupported_backend = [](const std::string& name) {
+ std::vector<std::string> supported;
+ for (const auto backend : SupportedBackends()) {
+ supported.push_back(std::string("'") + backend.name + "'");
+ }
+ ARROW_LOG(WARNING) << "Unsupported backend '" << name << "' specified in "
+ << kDefaultBackendEnvVar << " (supported backends are "
+ << internal::JoinStrings(supported, ", ") << ")";
+ };
+
+ auto maybe_name = internal::GetEnvVar(kDefaultBackendEnvVar);
+ if (!maybe_name.ok()) {
+ return {};
+ }
+ const auto name = *std::move(maybe_name);
+ if (name.empty()) {
+ // An empty environment variable is considered missing
+ return {};
+ }
+ const auto found = std::find_if(
+ SupportedBackends().begin(), SupportedBackends().end(),
+ [&](const SupportedBackend& backend) { return name == backend.name; });
+ if (found != SupportedBackends().end()) {
+ return found->backend;
+ }
+ unsupported_backend(name);
+ return {};
+ }();
+
+ return user_selected_backend;
+}
+
+MemoryPoolBackend DefaultBackend() {
+ auto backend = UserSelectedBackend();
+ if (backend.has_value()) {
+ return backend.value();
+ }
+ struct SupportedBackend default_backend = SupportedBackends().front();
+ return default_backend.backend;
+}
+
+// A static piece of memory for 0-size allocations, so as to return
+// an aligned non-null pointer.
+alignas(kAlignment) static uint8_t zero_size_area[1];
+
+// Helper class directing allocations to the standard system allocator.
+class SystemAllocator {
+ public:
+ // Allocate memory according to the alignment requirements for Arrow
+ // (as of May 2016 64 bytes)
+ static Status AllocateAligned(int64_t size, uint8_t** out) {
+ if (size == 0) {
+ *out = zero_size_area;
+ return Status::OK();
+ }
+#ifdef _WIN32
+ // Special code path for Windows
+ *out = reinterpret_cast<uint8_t*>(
+ _aligned_malloc(static_cast<size_t>(size), kAlignment));
+ if (!*out) {
+ return Status::OutOfMemory("malloc of size ", size, " failed");
+ }
+#elif defined(sun) || defined(__sun)
+ *out = reinterpret_cast<uint8_t*>(memalign(kAlignment, static_cast<size_t>(size)));
+ if (!*out) {
+ return Status::OutOfMemory("malloc of size ", size, " failed");
+ }
+#else
+ const int result = posix_memalign(reinterpret_cast<void**>(out), kAlignment,
+ static_cast<size_t>(size));
+ if (result == ENOMEM) {
+ return Status::OutOfMemory("malloc of size ", size, " failed");
+ }
+
+ if (result == EINVAL) {
+ return Status::Invalid("invalid alignment parameter: ", kAlignment);
+ }
+#endif
+ return Status::OK();
+ }
+
+ static Status ReallocateAligned(int64_t old_size, int64_t new_size, uint8_t** ptr) {
+ uint8_t* previous_ptr = *ptr;
+ if (previous_ptr == zero_size_area) {
+ DCHECK_EQ(old_size, 0);
+ return AllocateAligned(new_size, ptr);
+ }
+ if (new_size == 0) {
+ DeallocateAligned(previous_ptr, old_size);
+ *ptr = zero_size_area;
+ return Status::OK();
+ }
+ // Note: We cannot use realloc() here as it doesn't guarantee alignment.
+
+ // Allocate new chunk
+ uint8_t* out = nullptr;
+ RETURN_NOT_OK(AllocateAligned(new_size, &out));
+ DCHECK(out);
+ // Copy contents and release old memory chunk
+ memcpy(out, *ptr, static_cast<size_t>(std::min(new_size, old_size)));
+#ifdef _WIN32
+ _aligned_free(*ptr);
+#else
+ free(*ptr);
+#endif // defined(_WIN32)
+ *ptr = out;
+ return Status::OK();
+ }
+
+ static void DeallocateAligned(uint8_t* ptr, int64_t size) {
+ if (ptr == zero_size_area) {
+ DCHECK_EQ(size, 0);
+ } else {
+#ifdef _WIN32
+ _aligned_free(ptr);
+#else
+ free(ptr);
+#endif
+ }
+ }
+
+ static void ReleaseUnused() {
+#ifdef __GLIBC__
+ // The return value of malloc_trim is not an error but to inform
+ // you if memory was actually released or not, which we do not care about here
+ ARROW_UNUSED(malloc_trim(0));
+#endif
+ }
+};
+
+#ifdef ARROW_JEMALLOC
+
+// Helper class directing allocations to the jemalloc allocator.
+class JemallocAllocator {
+ public:
+ static Status AllocateAligned(int64_t size, uint8_t** out) {
+ if (size == 0) {
+ *out = zero_size_area;
+ return Status::OK();
+ }
+ *out = reinterpret_cast<uint8_t*>(
+ mallocx(static_cast<size_t>(size), MALLOCX_ALIGN(kAlignment)));
+ if (*out == NULL) {
+ return Status::OutOfMemory("malloc of size ", size, " failed");
+ }
+ return Status::OK();
+ }
+
+ static Status ReallocateAligned(int64_t old_size, int64_t new_size, uint8_t** ptr) {
+ uint8_t* previous_ptr = *ptr;
+ if (previous_ptr == zero_size_area) {
+ DCHECK_EQ(old_size, 0);
+ return AllocateAligned(new_size, ptr);
+ }
+ if (new_size == 0) {
+ DeallocateAligned(previous_ptr, old_size);
+ *ptr = zero_size_area;
+ return Status::OK();
+ }
+ *ptr = reinterpret_cast<uint8_t*>(
+ rallocx(*ptr, static_cast<size_t>(new_size), MALLOCX_ALIGN(kAlignment)));
+ if (*ptr == NULL) {
+ *ptr = previous_ptr;
+ return Status::OutOfMemory("realloc of size ", new_size, " failed");
+ }
+ return Status::OK();
+ }
+
+ static void DeallocateAligned(uint8_t* ptr, int64_t size) {
+ if (ptr == zero_size_area) {
+ DCHECK_EQ(size, 0);
+ } else {
+ dallocx(ptr, MALLOCX_ALIGN(kAlignment));
+ }
+ }
+
+ static void ReleaseUnused() {
+ mallctl("arena." ARROW_STRINGIFY(MALLCTL_ARENAS_ALL) ".purge", NULL, NULL, NULL, 0);
+ }
+};
+
+#endif // defined(ARROW_JEMALLOC)
+
+#ifdef ARROW_MIMALLOC
+
+// Helper class directing allocations to the mimalloc allocator.
+class MimallocAllocator {
+ public:
+ static Status AllocateAligned(int64_t size, uint8_t** out) {
+ if (size == 0) {
+ *out = zero_size_area;
+ return Status::OK();
+ }
+ *out = reinterpret_cast<uint8_t*>(
+ mi_malloc_aligned(static_cast<size_t>(size), kAlignment));
+ if (*out == NULL) {
+ return Status::OutOfMemory("malloc of size ", size, " failed");
+ }
+ return Status::OK();
+ }
+
+ static void ReleaseUnused() { mi_collect(true); }
+
+ static Status ReallocateAligned(int64_t old_size, int64_t new_size, uint8_t** ptr) {
+ uint8_t* previous_ptr = *ptr;
+ if (previous_ptr == zero_size_area) {
+ DCHECK_EQ(old_size, 0);
+ return AllocateAligned(new_size, ptr);
+ }
+ if (new_size == 0) {
+ DeallocateAligned(previous_ptr, old_size);
+ *ptr = zero_size_area;
+ return Status::OK();
+ }
+ *ptr = reinterpret_cast<uint8_t*>(
+ mi_realloc_aligned(previous_ptr, static_cast<size_t>(new_size), kAlignment));
+ if (*ptr == NULL) {
+ *ptr = previous_ptr;
+ return Status::OutOfMemory("realloc of size ", new_size, " failed");
+ }
+ return Status::OK();
+ }
+
+ static void DeallocateAligned(uint8_t* ptr, int64_t size) {
+ if (ptr == zero_size_area) {
+ DCHECK_EQ(size, 0);
+ } else {
+ mi_free(ptr);
+ }
+ }
+};
+
+#endif // defined(ARROW_MIMALLOC)
+
+} // namespace
+
+int64_t MemoryPool::max_memory() const { return -1; }
+
+///////////////////////////////////////////////////////////////////////
+// MemoryPool implementation that delegates its core duty
+// to an Allocator class.
+
+#ifndef NDEBUG
+static constexpr uint8_t kAllocPoison = 0xBC;
+static constexpr uint8_t kReallocPoison = 0xBD;
+static constexpr uint8_t kDeallocPoison = 0xBE;
+#endif
+
+template <typename Allocator>
+class BaseMemoryPoolImpl : public MemoryPool {
+ public:
+ ~BaseMemoryPoolImpl() override {}
+
+ Status Allocate(int64_t size, uint8_t** out) override {
+ if (size < 0) {
+ return Status::Invalid("negative malloc size");
+ }
+ if (static_cast<uint64_t>(size) >= std::numeric_limits<size_t>::max()) {
+ return Status::CapacityError("malloc size overflows size_t");
+ }
+ RETURN_NOT_OK(Allocator::AllocateAligned(size, out));
+#ifndef NDEBUG
+ // Poison data
+ if (size > 0) {
+ DCHECK_NE(*out, nullptr);
+ (*out)[0] = kAllocPoison;
+ (*out)[size - 1] = kAllocPoison;
+ }
+#endif
+
+ stats_.UpdateAllocatedBytes(size);
+ return Status::OK();
+ }
+
+ Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) override {
+ if (new_size < 0) {
+ return Status::Invalid("negative realloc size");
+ }
+ if (static_cast<uint64_t>(new_size) >= std::numeric_limits<size_t>::max()) {
+ return Status::CapacityError("realloc overflows size_t");
+ }
+ RETURN_NOT_OK(Allocator::ReallocateAligned(old_size, new_size, ptr));
+#ifndef NDEBUG
+ // Poison data
+ if (new_size > old_size) {
+ DCHECK_NE(*ptr, nullptr);
+ (*ptr)[old_size] = kReallocPoison;
+ (*ptr)[new_size - 1] = kReallocPoison;
+ }
+#endif
+
+ stats_.UpdateAllocatedBytes(new_size - old_size);
+ return Status::OK();
+ }
+
+ void Free(uint8_t* buffer, int64_t size) override {
+#ifndef NDEBUG
+ // Poison data
+ if (size > 0) {
+ DCHECK_NE(buffer, nullptr);
+ buffer[0] = kDeallocPoison;
+ buffer[size - 1] = kDeallocPoison;
+ }
+#endif
+ Allocator::DeallocateAligned(buffer, size);
+
+ stats_.UpdateAllocatedBytes(-size);
+ }
+
+ void ReleaseUnused() override { Allocator::ReleaseUnused(); }
+
+ int64_t bytes_allocated() const override { return stats_.bytes_allocated(); }
+
+ int64_t max_memory() const override { return stats_.max_memory(); }
+
+ protected:
+ internal::MemoryPoolStats stats_;
+};
+
+class SystemMemoryPool : public BaseMemoryPoolImpl<SystemAllocator> {
+ public:
+ std::string backend_name() const override { return "system"; }
+};
+
+#ifdef ARROW_JEMALLOC
+class JemallocMemoryPool : public BaseMemoryPoolImpl<JemallocAllocator> {
+ public:
+ std::string backend_name() const override { return "jemalloc"; }
+};
+#endif
+
+#ifdef ARROW_MIMALLOC
+class MimallocMemoryPool : public BaseMemoryPoolImpl<MimallocAllocator> {
+ public:
+ std::string backend_name() const override { return "mimalloc"; }
+};
+#endif
+
+std::unique_ptr<MemoryPool> MemoryPool::CreateDefault() {
+ auto backend = DefaultBackend();
+ switch (backend) {
+ case MemoryPoolBackend::System:
+ return std::unique_ptr<MemoryPool>(new SystemMemoryPool);
+#ifdef ARROW_JEMALLOC
+ case MemoryPoolBackend::Jemalloc:
+ return std::unique_ptr<MemoryPool>(new JemallocMemoryPool);
+#endif
+#ifdef ARROW_MIMALLOC
+ case MemoryPoolBackend::Mimalloc:
+ return std::unique_ptr<MemoryPool>(new MimallocMemoryPool);
+#endif
+ default:
+ ARROW_LOG(FATAL) << "Internal error: cannot create default memory pool";
+ return nullptr;
+ }
+}
+
+static struct GlobalState {
+ ~GlobalState() { finalizing.store(true, std::memory_order_relaxed); }
+
+ bool is_finalizing() const { return finalizing.load(std::memory_order_relaxed); }
+
+ std::atomic<bool> finalizing{false}; // constructed first, destroyed last
+
+ SystemMemoryPool system_pool;
+#ifdef ARROW_JEMALLOC
+ JemallocMemoryPool jemalloc_pool;
+#endif
+#ifdef ARROW_MIMALLOC
+ MimallocMemoryPool mimalloc_pool;
+#endif
+} global_state;
+
+MemoryPool* system_memory_pool() { return &global_state.system_pool; }
+
+Status jemalloc_memory_pool(MemoryPool** out) {
+#ifdef ARROW_JEMALLOC
+ *out = &global_state.jemalloc_pool;
+ return Status::OK();
+#else
+ return Status::NotImplemented("This Arrow build does not enable jemalloc");
+#endif
+}
+
+Status mimalloc_memory_pool(MemoryPool** out) {
+#ifdef ARROW_MIMALLOC
+ *out = &global_state.mimalloc_pool;
+ return Status::OK();
+#else
+ return Status::NotImplemented("This Arrow build does not enable mimalloc");
+#endif
+}
+
+MemoryPool* default_memory_pool() {
+ auto backend = DefaultBackend();
+ switch (backend) {
+ case MemoryPoolBackend::System:
+ return &global_state.system_pool;
+#ifdef ARROW_JEMALLOC
+ case MemoryPoolBackend::Jemalloc:
+ return &global_state.jemalloc_pool;
+#endif
+#ifdef ARROW_MIMALLOC
+ case MemoryPoolBackend::Mimalloc:
+ return &global_state.mimalloc_pool;
+#endif
+ default:
+ ARROW_LOG(FATAL) << "Internal error: cannot create default memory pool";
+ return nullptr;
+ }
+}
+
+#define RETURN_IF_JEMALLOC_ERROR(ERR) \
+ do { \
+ if (err != 0) { \
+ return Status::UnknownError(std::strerror(ERR)); \
+ } \
+ } while (0)
+
+Status jemalloc_set_decay_ms(int ms) {
+#ifdef ARROW_JEMALLOC
+ ssize_t decay_time_ms = static_cast<ssize_t>(ms);
+
+ int err = mallctl("arenas.dirty_decay_ms", nullptr, nullptr, &decay_time_ms,
+ sizeof(decay_time_ms));
+ RETURN_IF_JEMALLOC_ERROR(err);
+ err = mallctl("arenas.muzzy_decay_ms", nullptr, nullptr, &decay_time_ms,
+ sizeof(decay_time_ms));
+ RETURN_IF_JEMALLOC_ERROR(err);
+
+ return Status::OK();
+#else
+ return Status::Invalid("jemalloc support is not built");
+#endif
+}
+
+///////////////////////////////////////////////////////////////////////
+// LoggingMemoryPool implementation
+
+LoggingMemoryPool::LoggingMemoryPool(MemoryPool* pool) : pool_(pool) {}
+
+Status LoggingMemoryPool::Allocate(int64_t size, uint8_t** out) {
+ Status s = pool_->Allocate(size, out);
+ std::cout << "Allocate: size = " << size << std::endl;
+ return s;
+}
+
+Status LoggingMemoryPool::Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) {
+ Status s = pool_->Reallocate(old_size, new_size, ptr);
+ std::cout << "Reallocate: old_size = " << old_size << " - new_size = " << new_size
+ << std::endl;
+ return s;
+}
+
+void LoggingMemoryPool::Free(uint8_t* buffer, int64_t size) {
+ pool_->Free(buffer, size);
+ std::cout << "Free: size = " << size << std::endl;
+}
+
+int64_t LoggingMemoryPool::bytes_allocated() const {
+ int64_t nb_bytes = pool_->bytes_allocated();
+ std::cout << "bytes_allocated: " << nb_bytes << std::endl;
+ return nb_bytes;
+}
+
+int64_t LoggingMemoryPool::max_memory() const {
+ int64_t mem = pool_->max_memory();
+ std::cout << "max_memory: " << mem << std::endl;
+ return mem;
+}
+
+std::string LoggingMemoryPool::backend_name() const { return pool_->backend_name(); }
+
+///////////////////////////////////////////////////////////////////////
+// ProxyMemoryPool implementation
+
+class ProxyMemoryPool::ProxyMemoryPoolImpl {
+ public:
+ explicit ProxyMemoryPoolImpl(MemoryPool* pool) : pool_(pool) {}
+
+ Status Allocate(int64_t size, uint8_t** out) {
+ RETURN_NOT_OK(pool_->Allocate(size, out));
+ stats_.UpdateAllocatedBytes(size);
+ return Status::OK();
+ }
+
+ Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) {
+ RETURN_NOT_OK(pool_->Reallocate(old_size, new_size, ptr));
+ stats_.UpdateAllocatedBytes(new_size - old_size);
+ return Status::OK();
+ }
+
+ void Free(uint8_t* buffer, int64_t size) {
+ pool_->Free(buffer, size);
+ stats_.UpdateAllocatedBytes(-size);
+ }
+
+ int64_t bytes_allocated() const { return stats_.bytes_allocated(); }
+
+ int64_t max_memory() const { return stats_.max_memory(); }
+
+ std::string backend_name() const { return pool_->backend_name(); }
+
+ private:
+ MemoryPool* pool_;
+ internal::MemoryPoolStats stats_;
+};
+
+ProxyMemoryPool::ProxyMemoryPool(MemoryPool* pool) {
+ impl_.reset(new ProxyMemoryPoolImpl(pool));
+}
+
+ProxyMemoryPool::~ProxyMemoryPool() {}
+
+Status ProxyMemoryPool::Allocate(int64_t size, uint8_t** out) {
+ return impl_->Allocate(size, out);
+}
+
+Status ProxyMemoryPool::Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) {
+ return impl_->Reallocate(old_size, new_size, ptr);
+}
+
+void ProxyMemoryPool::Free(uint8_t* buffer, int64_t size) {
+ return impl_->Free(buffer, size);
+}
+
+int64_t ProxyMemoryPool::bytes_allocated() const { return impl_->bytes_allocated(); }
+
+int64_t ProxyMemoryPool::max_memory() const { return impl_->max_memory(); }
+
+std::string ProxyMemoryPool::backend_name() const { return impl_->backend_name(); }
+
+std::vector<std::string> SupportedMemoryBackendNames() {
+ std::vector<std::string> supported;
+ for (const auto backend : SupportedBackends()) {
+ supported.push_back(backend.name);
+ }
+ return supported;
+}
+
+// -----------------------------------------------------------------------
+// Pool buffer and allocation
+
+/// A Buffer whose lifetime is tied to a particular MemoryPool
+class PoolBuffer final : public ResizableBuffer {
+ public:
+ explicit PoolBuffer(std::shared_ptr<MemoryManager> mm, MemoryPool* pool)
+ : ResizableBuffer(nullptr, 0, std::move(mm)), pool_(pool) {}
+
+ ~PoolBuffer() override {
+ // Avoid calling pool_->Free if the global pools are destroyed
+ // (XXX this will not work with user-defined pools)
+
+ // This can happen if a Future is destructing on one thread while or
+ // after memory pools are destructed on the main thread (as there is
+ // no guarantee of destructor order between thread/memory pools)
+ uint8_t* ptr = mutable_data();
+ if (ptr && !global_state.is_finalizing()) {
+ pool_->Free(ptr, capacity_);
+ }
+ }
+
+ Status Reserve(const int64_t capacity) override {
+ if (capacity < 0) {
+ return Status::Invalid("Negative buffer capacity: ", capacity);
+ }
+ uint8_t* ptr = mutable_data();
+ if (!ptr || capacity > capacity_) {
+ int64_t new_capacity = BitUtil::RoundUpToMultipleOf64(capacity);
+ if (ptr) {
+ RETURN_NOT_OK(pool_->Reallocate(capacity_, new_capacity, &ptr));
+ } else {
+ RETURN_NOT_OK(pool_->Allocate(new_capacity, &ptr));
+ }
+ data_ = ptr;
+ capacity_ = new_capacity;
+ }
+ return Status::OK();
+ }
+
+ Status Resize(const int64_t new_size, bool shrink_to_fit = true) override {
+ if (ARROW_PREDICT_FALSE(new_size < 0)) {
+ return Status::Invalid("Negative buffer resize: ", new_size);
+ }
+ uint8_t* ptr = mutable_data();
+ if (ptr && shrink_to_fit && new_size <= size_) {
+ // Buffer is non-null and is not growing, so shrink to the requested size without
+ // excess space.
+ int64_t new_capacity = BitUtil::RoundUpToMultipleOf64(new_size);
+ if (capacity_ != new_capacity) {
+ // Buffer hasn't got yet the requested size.
+ RETURN_NOT_OK(pool_->Reallocate(capacity_, new_capacity, &ptr));
+ data_ = ptr;
+ capacity_ = new_capacity;
+ }
+ } else {
+ RETURN_NOT_OK(Reserve(new_size));
+ }
+ size_ = new_size;
+
+ return Status::OK();
+ }
+
+ static std::shared_ptr<PoolBuffer> MakeShared(MemoryPool* pool) {
+ std::shared_ptr<MemoryManager> mm;
+ if (pool == nullptr) {
+ pool = default_memory_pool();
+ mm = default_cpu_memory_manager();
+ } else {
+ mm = CPUDevice::memory_manager(pool);
+ }
+ return std::make_shared<PoolBuffer>(std::move(mm), pool);
+ }
+
+ static std::unique_ptr<PoolBuffer> MakeUnique(MemoryPool* pool) {
+ std::shared_ptr<MemoryManager> mm;
+ if (pool == nullptr) {
+ pool = default_memory_pool();
+ mm = default_cpu_memory_manager();
+ } else {
+ mm = CPUDevice::memory_manager(pool);
+ }
+ return std::unique_ptr<PoolBuffer>(new PoolBuffer(std::move(mm), pool));
+ }
+
+ private:
+ MemoryPool* pool_;
+};
+
+namespace {
+// A utility that does most of the work of the `AllocateBuffer` and
+// `AllocateResizableBuffer` methods. The argument `buffer` should be a smart pointer to
+// a PoolBuffer.
+template <typename BufferPtr, typename PoolBufferPtr>
+inline Result<BufferPtr> ResizePoolBuffer(PoolBufferPtr&& buffer, const int64_t size) {
+ RETURN_NOT_OK(buffer->Resize(size));
+ buffer->ZeroPadding();
+ return std::move(buffer);
+}
+
+} // namespace
+
+Result<std::unique_ptr<Buffer>> AllocateBuffer(const int64_t size, MemoryPool* pool) {
+ return ResizePoolBuffer<std::unique_ptr<Buffer>>(PoolBuffer::MakeUnique(pool), size);
+}
+
+Result<std::unique_ptr<ResizableBuffer>> AllocateResizableBuffer(const int64_t size,
+ MemoryPool* pool) {
+ return ResizePoolBuffer<std::unique_ptr<ResizableBuffer>>(PoolBuffer::MakeUnique(pool),
+ size);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/memory_pool.h b/src/arrow/cpp/src/arrow/memory_pool.h
new file mode 100644
index 000000000..81b1b112d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/memory_pool.h
@@ -0,0 +1,185 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <atomic>
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+namespace internal {
+
+///////////////////////////////////////////////////////////////////////
+// Helper tracking memory statistics
+
+class MemoryPoolStats {
+ public:
+ MemoryPoolStats() : bytes_allocated_(0), max_memory_(0) {}
+
+ int64_t max_memory() const { return max_memory_.load(); }
+
+ int64_t bytes_allocated() const { return bytes_allocated_.load(); }
+
+ inline void UpdateAllocatedBytes(int64_t diff) {
+ auto allocated = bytes_allocated_.fetch_add(diff) + diff;
+ // "maximum" allocated memory is ill-defined in multi-threaded code,
+ // so don't try to be too rigorous here
+ if (diff > 0 && allocated > max_memory_) {
+ max_memory_ = allocated;
+ }
+ }
+
+ protected:
+ std::atomic<int64_t> bytes_allocated_;
+ std::atomic<int64_t> max_memory_;
+};
+
+} // namespace internal
+
+/// Base class for memory allocation on the CPU.
+///
+/// Besides tracking the number of allocated bytes, the allocator also should
+/// take care of the required 64-byte alignment.
+class ARROW_EXPORT MemoryPool {
+ public:
+ virtual ~MemoryPool() = default;
+
+ /// \brief EXPERIMENTAL. Create a new instance of the default MemoryPool
+ static std::unique_ptr<MemoryPool> CreateDefault();
+
+ /// Allocate a new memory region of at least size bytes.
+ ///
+ /// The allocated region shall be 64-byte aligned.
+ virtual Status Allocate(int64_t size, uint8_t** out) = 0;
+
+ /// Resize an already allocated memory section.
+ ///
+ /// As by default most default allocators on a platform don't support aligned
+ /// reallocation, this function can involve a copy of the underlying data.
+ virtual Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) = 0;
+
+ /// Free an allocated region.
+ ///
+ /// @param buffer Pointer to the start of the allocated memory region
+ /// @param size Allocated size located at buffer. An allocator implementation
+ /// may use this for tracking the amount of allocated bytes as well as for
+ /// faster deallocation if supported by its backend.
+ virtual void Free(uint8_t* buffer, int64_t size) = 0;
+
+ /// Return unused memory to the OS
+ ///
+ /// Only applies to allocators that hold onto unused memory. This will be
+ /// best effort, a memory pool may not implement this feature or may be
+ /// unable to fulfill the request due to fragmentation.
+ virtual void ReleaseUnused() {}
+
+ /// The number of bytes that were allocated and not yet free'd through
+ /// this allocator.
+ virtual int64_t bytes_allocated() const = 0;
+
+ /// Return peak memory allocation in this memory pool
+ ///
+ /// \return Maximum bytes allocated. If not known (or not implemented),
+ /// returns -1
+ virtual int64_t max_memory() const;
+
+ /// The name of the backend used by this MemoryPool (e.g. "system" or "jemalloc").
+ virtual std::string backend_name() const = 0;
+
+ protected:
+ MemoryPool() = default;
+};
+
+class ARROW_EXPORT LoggingMemoryPool : public MemoryPool {
+ public:
+ explicit LoggingMemoryPool(MemoryPool* pool);
+ ~LoggingMemoryPool() override = default;
+
+ Status Allocate(int64_t size, uint8_t** out) override;
+ Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) override;
+
+ void Free(uint8_t* buffer, int64_t size) override;
+
+ int64_t bytes_allocated() const override;
+
+ int64_t max_memory() const override;
+
+ std::string backend_name() const override;
+
+ private:
+ MemoryPool* pool_;
+};
+
+/// Derived class for memory allocation.
+///
+/// Tracks the number of bytes and maximum memory allocated through its direct
+/// calls. Actual allocation is delegated to MemoryPool class.
+class ARROW_EXPORT ProxyMemoryPool : public MemoryPool {
+ public:
+ explicit ProxyMemoryPool(MemoryPool* pool);
+ ~ProxyMemoryPool() override;
+
+ Status Allocate(int64_t size, uint8_t** out) override;
+ Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) override;
+
+ void Free(uint8_t* buffer, int64_t size) override;
+
+ int64_t bytes_allocated() const override;
+
+ int64_t max_memory() const override;
+
+ std::string backend_name() const override;
+
+ private:
+ class ProxyMemoryPoolImpl;
+ std::unique_ptr<ProxyMemoryPoolImpl> impl_;
+};
+
+/// \brief Return a process-wide memory pool based on the system allocator.
+ARROW_EXPORT MemoryPool* system_memory_pool();
+
+/// \brief Return a process-wide memory pool based on jemalloc.
+///
+/// May return NotImplemented if jemalloc is not available.
+ARROW_EXPORT Status jemalloc_memory_pool(MemoryPool** out);
+
+/// \brief Set jemalloc memory page purging behavior for future-created arenas
+/// to the indicated number of milliseconds. See dirty_decay_ms and
+/// muzzy_decay_ms options in jemalloc for a description of what these do. The
+/// default is configured to 1000 (1 second) which releases memory more
+/// aggressively to the operating system than the jemalloc default of 10
+/// seconds. If you set the value to 0, dirty / muzzy pages will be released
+/// immediately rather than with a time decay, but this may reduce application
+/// performance.
+ARROW_EXPORT
+Status jemalloc_set_decay_ms(int ms);
+
+/// \brief Return a process-wide memory pool based on mimalloc.
+///
+/// May return NotImplemented if mimalloc is not available.
+ARROW_EXPORT Status mimalloc_memory_pool(MemoryPool** out);
+
+ARROW_EXPORT std::vector<std::string> SupportedMemoryBackendNames();
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/memory_pool_benchmark.cc b/src/arrow/cpp/src/arrow/memory_pool_benchmark.cc
new file mode 100644
index 000000000..ba39310a8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/memory_pool_benchmark.cc
@@ -0,0 +1,129 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/memory_pool.h"
+#include "arrow/result.h"
+#include "arrow/util/logging.h"
+
+#include "benchmark/benchmark.h"
+
+namespace arrow {
+
+struct SystemAlloc {
+ static Result<MemoryPool*> GetAllocator() { return system_memory_pool(); }
+};
+
+#ifdef ARROW_JEMALLOC
+struct Jemalloc {
+ static Result<MemoryPool*> GetAllocator() {
+ MemoryPool* pool;
+ RETURN_NOT_OK(jemalloc_memory_pool(&pool));
+ return pool;
+ }
+};
+#endif
+
+#ifdef ARROW_MIMALLOC
+struct Mimalloc {
+ static Result<MemoryPool*> GetAllocator() {
+ MemoryPool* pool;
+ RETURN_NOT_OK(mimalloc_memory_pool(&pool));
+ return pool;
+ }
+};
+#endif
+
+static void TouchCacheLines(uint8_t* data, int64_t nbytes) {
+ uint8_t total = 0;
+ while (nbytes > 0) {
+ total += *data;
+ data += 64;
+ nbytes -= 64;
+ }
+ benchmark::DoNotOptimize(total);
+}
+
+// Benchmark the cost of accessing always the same memory area.
+// This gives us a lower bound of the potential difference between
+// AllocateTouchDeallocate and AllocateDeallocate.
+static void TouchArea(benchmark::State& state) { // NOLINT non-const reference
+ const int64_t nbytes = state.range(0);
+ MemoryPool* pool = default_memory_pool();
+ uint8_t* data;
+ ARROW_CHECK_OK(pool->Allocate(nbytes, &data));
+
+ for (auto _ : state) {
+ TouchCacheLines(data, nbytes);
+ }
+
+ pool->Free(data, nbytes);
+}
+
+// Benchmark the raw cost of allocating memory.
+// Note this is a best case situation: we always allocate and deallocate exactly
+// the same size, without any other allocator traffic. However, it can be
+// representative of workloads where we routinely create and destroy
+// temporary buffers for intermediate computation results.
+template <typename Alloc>
+static void AllocateDeallocate(benchmark::State& state) { // NOLINT non-const reference
+ const int64_t nbytes = state.range(0);
+ MemoryPool* pool = *Alloc::GetAllocator();
+
+ for (auto _ : state) {
+ uint8_t* data;
+ ARROW_CHECK_OK(pool->Allocate(nbytes, &data));
+ pool->Free(data, nbytes);
+ }
+}
+
+// Benchmark the cost of allocating memory plus accessing it.
+template <typename Alloc>
+static void AllocateTouchDeallocate(
+ benchmark::State& state) { // NOLINT non-const reference
+ const int64_t nbytes = state.range(0);
+ MemoryPool* pool = *Alloc::GetAllocator();
+
+ for (auto _ : state) {
+ uint8_t* data;
+ ARROW_CHECK_OK(pool->Allocate(nbytes, &data));
+ TouchCacheLines(data, nbytes);
+ pool->Free(data, nbytes);
+ }
+}
+
+#define BENCHMARK_ALLOCATE_ARGS \
+ ->RangeMultiplier(16)->Range(4096, 16 * 1024 * 1024)->ArgName("size")->UseRealTime()
+
+#define BENCHMARK_ALLOCATE(benchmark_func, template_param) \
+ BENCHMARK_TEMPLATE(benchmark_func, template_param) BENCHMARK_ALLOCATE_ARGS
+
+BENCHMARK(TouchArea) BENCHMARK_ALLOCATE_ARGS;
+
+BENCHMARK_ALLOCATE(AllocateDeallocate, SystemAlloc);
+BENCHMARK_ALLOCATE(AllocateTouchDeallocate, SystemAlloc);
+
+#ifdef ARROW_JEMALLOC
+BENCHMARK_ALLOCATE(AllocateDeallocate, Jemalloc);
+BENCHMARK_ALLOCATE(AllocateTouchDeallocate, Jemalloc);
+#endif
+
+#ifdef ARROW_MIMALLOC
+BENCHMARK_ALLOCATE(AllocateDeallocate, Mimalloc);
+BENCHMARK_ALLOCATE(AllocateTouchDeallocate, Mimalloc);
+#endif
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/memory_pool_test.cc b/src/arrow/cpp/src/arrow/memory_pool_test.cc
new file mode 100644
index 000000000..3ea35165f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/memory_pool_test.cc
@@ -0,0 +1,174 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+
+#include <gtest/gtest.h>
+
+#include "arrow/memory_pool.h"
+#include "arrow/memory_pool_test.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+
+struct DefaultMemoryPoolFactory {
+ static MemoryPool* memory_pool() { return default_memory_pool(); }
+};
+
+struct SystemMemoryPoolFactory {
+ static MemoryPool* memory_pool() { return system_memory_pool(); }
+};
+
+#ifdef ARROW_JEMALLOC
+struct JemallocMemoryPoolFactory {
+ static MemoryPool* memory_pool() {
+ MemoryPool* pool;
+ ABORT_NOT_OK(jemalloc_memory_pool(&pool));
+ return pool;
+ }
+};
+#endif
+
+#ifdef ARROW_MIMALLOC
+struct MimallocMemoryPoolFactory {
+ static MemoryPool* memory_pool() {
+ MemoryPool* pool;
+ ABORT_NOT_OK(mimalloc_memory_pool(&pool));
+ return pool;
+ }
+};
+#endif
+
+template <typename Factory>
+class TestMemoryPool : public ::arrow::TestMemoryPoolBase {
+ public:
+ MemoryPool* memory_pool() override { return Factory::memory_pool(); }
+};
+
+TYPED_TEST_SUITE_P(TestMemoryPool);
+
+TYPED_TEST_P(TestMemoryPool, MemoryTracking) { this->TestMemoryTracking(); }
+
+TYPED_TEST_P(TestMemoryPool, OOM) {
+#ifndef ADDRESS_SANITIZER
+ this->TestOOM();
+#endif
+}
+
+TYPED_TEST_P(TestMemoryPool, Reallocate) { this->TestReallocate(); }
+
+REGISTER_TYPED_TEST_SUITE_P(TestMemoryPool, MemoryTracking, OOM, Reallocate);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(Default, TestMemoryPool, DefaultMemoryPoolFactory);
+INSTANTIATE_TYPED_TEST_SUITE_P(System, TestMemoryPool, SystemMemoryPoolFactory);
+
+#ifdef ARROW_JEMALLOC
+INSTANTIATE_TYPED_TEST_SUITE_P(Jemalloc, TestMemoryPool, JemallocMemoryPoolFactory);
+#endif
+
+#ifdef ARROW_MIMALLOC
+INSTANTIATE_TYPED_TEST_SUITE_P(Mimalloc, TestMemoryPool, MimallocMemoryPoolFactory);
+#endif
+
+TEST(DefaultMemoryPool, Identity) {
+ // The default memory pool is pointer-identical to one of the backend-specific pools.
+ MemoryPool* pool = default_memory_pool();
+ std::vector<MemoryPool*> specific_pools = {system_memory_pool()};
+#ifdef ARROW_JEMALLOC
+ specific_pools.push_back(nullptr);
+ ASSERT_OK(jemalloc_memory_pool(&specific_pools.back()));
+#endif
+#ifdef ARROW_MIMALLOC
+ specific_pools.push_back(nullptr);
+ ASSERT_OK(mimalloc_memory_pool(&specific_pools.back()));
+#endif
+ ASSERT_NE(std::find(specific_pools.begin(), specific_pools.end(), pool),
+ specific_pools.end());
+}
+
+// Death tests and valgrind are known to not play well 100% of the time. See
+// googletest documentation
+#if !(defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER))
+
+TEST(DefaultMemoryPoolDeathTest, MaxMemory) {
+ MemoryPool* pool = default_memory_pool();
+ uint8_t* data1;
+ uint8_t* data2;
+
+ ASSERT_OK(pool->Allocate(100, &data1));
+ ASSERT_OK(pool->Allocate(50, &data2));
+ pool->Free(data2, 50);
+ ASSERT_OK(pool->Allocate(100, &data2));
+ pool->Free(data1, 100);
+ pool->Free(data2, 100);
+
+ ASSERT_EQ(200, pool->max_memory());
+}
+
+#endif // ARROW_VALGRIND
+
+TEST(LoggingMemoryPool, Logging) {
+ auto pool = MemoryPool::CreateDefault();
+
+ LoggingMemoryPool lp(pool.get());
+
+ uint8_t* data;
+ ASSERT_OK(lp.Allocate(100, &data));
+
+ uint8_t* data2;
+ ASSERT_OK(lp.Allocate(100, &data2));
+
+ lp.Free(data, 100);
+ lp.Free(data2, 100);
+
+ ASSERT_EQ(200, lp.max_memory());
+ ASSERT_EQ(200, pool->max_memory());
+}
+
+TEST(ProxyMemoryPool, Logging) {
+ auto pool = MemoryPool::CreateDefault();
+
+ ProxyMemoryPool pp(pool.get());
+
+ uint8_t* data;
+ ASSERT_OK(pool->Allocate(100, &data));
+
+ uint8_t* data2;
+ ASSERT_OK(pp.Allocate(300, &data2));
+
+ ASSERT_EQ(400, pool->bytes_allocated());
+ ASSERT_EQ(300, pp.bytes_allocated());
+
+ pool->Free(data, 100);
+ pp.Free(data2, 300);
+
+ ASSERT_EQ(0, pool->bytes_allocated());
+ ASSERT_EQ(0, pp.bytes_allocated());
+}
+
+TEST(Jemalloc, SetDirtyPageDecayMillis) {
+ // ARROW-6910
+#ifdef ARROW_JEMALLOC
+ ASSERT_OK(jemalloc_set_decay_ms(0));
+#else
+ ASSERT_RAISES(Invalid, jemalloc_set_decay_ms(0));
+#endif
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/memory_pool_test.h b/src/arrow/cpp/src/arrow/memory_pool_test.h
new file mode 100644
index 000000000..f73f7a028
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/memory_pool_test.h
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+
+#include <gtest/gtest.h>
+
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+
+class TestMemoryPoolBase : public ::testing::Test {
+ public:
+ virtual ::arrow::MemoryPool* memory_pool() = 0;
+
+ void TestMemoryTracking() {
+ auto pool = memory_pool();
+
+ uint8_t* data;
+ ASSERT_OK(pool->Allocate(100, &data));
+ EXPECT_EQ(static_cast<uint64_t>(0), reinterpret_cast<uint64_t>(data) % 64);
+ ASSERT_EQ(100, pool->bytes_allocated());
+
+ uint8_t* data2;
+ ASSERT_OK(pool->Allocate(27, &data2));
+ EXPECT_EQ(static_cast<uint64_t>(0), reinterpret_cast<uint64_t>(data2) % 64);
+ ASSERT_EQ(127, pool->bytes_allocated());
+
+ pool->Free(data, 100);
+ ASSERT_EQ(27, pool->bytes_allocated());
+ pool->Free(data2, 27);
+ ASSERT_EQ(0, pool->bytes_allocated());
+ }
+
+ void TestOOM() {
+ auto pool = memory_pool();
+
+ uint8_t* data;
+ int64_t to_alloc = std::min<uint64_t>(std::numeric_limits<int64_t>::max(),
+ std::numeric_limits<size_t>::max());
+ // subtract 63 to prevent overflow after the size is aligned
+ to_alloc -= 63;
+ ASSERT_RAISES(OutOfMemory, pool->Allocate(to_alloc, &data));
+ }
+
+ void TestReallocate() {
+ auto pool = memory_pool();
+
+ uint8_t* data;
+ ASSERT_OK(pool->Allocate(10, &data));
+ ASSERT_EQ(10, pool->bytes_allocated());
+ data[0] = 35;
+ data[9] = 12;
+
+ // Expand
+ ASSERT_OK(pool->Reallocate(10, 20, &data));
+ ASSERT_EQ(data[9], 12);
+ ASSERT_EQ(20, pool->bytes_allocated());
+
+ // Shrink
+ ASSERT_OK(pool->Reallocate(20, 5, &data));
+ ASSERT_EQ(data[0], 35);
+ ASSERT_EQ(5, pool->bytes_allocated());
+
+ // Free
+ pool->Free(data, 5);
+ ASSERT_EQ(0, pool->bytes_allocated());
+ }
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/pch.h b/src/arrow/cpp/src/arrow/pch.h
new file mode 100644
index 000000000..31da37b82
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/pch.h
@@ -0,0 +1,30 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Often-used headers, for precompiling.
+// If updating this header, please make sure you check compilation speed
+// before checking in. Adding headers which are not used extremely often
+// may incur a slowdown, since it makes the precompiled header heavier to load.
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
diff --git a/src/arrow/cpp/src/arrow/pretty_print.cc b/src/arrow/cpp/src/arrow/pretty_print.cc
new file mode 100644
index 000000000..3ec2961fa
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/pretty_print.cc
@@ -0,0 +1,646 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/pretty_print.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <sstream> // IWYU pragma: keep
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/chunked_array.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/formatting.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/string.h"
+#include "arrow/util/string_view.h"
+#include "arrow/vendored/datetime.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::StringFormatter;
+
+namespace {
+
+class PrettyPrinter {
+ public:
+ PrettyPrinter(const PrettyPrintOptions& options, std::ostream* sink)
+ : options_(options), indent_(options.indent), sink_(sink) {}
+
+ inline void Write(util::string_view data);
+ inline void WriteIndented(util::string_view data);
+ inline void Newline();
+ inline void Indent();
+ inline void IndentAfterNewline();
+ void OpenArray(const Array& array);
+ void CloseArray(const Array& array);
+ void Flush() { (*sink_) << std::flush; }
+
+ PrettyPrintOptions ChildOptions() const {
+ PrettyPrintOptions child_options = options_;
+ child_options.indent = indent_;
+ return child_options;
+ }
+
+ protected:
+ const PrettyPrintOptions& options_;
+ int indent_;
+ std::ostream* sink_;
+};
+
+void PrettyPrinter::OpenArray(const Array& array) {
+ if (!options_.skip_new_lines) {
+ Indent();
+ }
+ (*sink_) << "[";
+ if (array.length() > 0) {
+ Newline();
+ indent_ += options_.indent_size;
+ }
+}
+
+void PrettyPrinter::CloseArray(const Array& array) {
+ if (array.length() > 0) {
+ indent_ -= options_.indent_size;
+ if (!options_.skip_new_lines) {
+ Indent();
+ }
+ }
+ (*sink_) << "]";
+}
+
+void PrettyPrinter::Write(util::string_view data) { (*sink_) << data; }
+
+void PrettyPrinter::WriteIndented(util::string_view data) {
+ Indent();
+ Write(data);
+}
+
+void PrettyPrinter::Newline() {
+ if (options_.skip_new_lines) {
+ return;
+ }
+ (*sink_) << "\n";
+}
+
+void PrettyPrinter::Indent() {
+ for (int i = 0; i < indent_; ++i) {
+ (*sink_) << " ";
+ }
+}
+
+void PrettyPrinter::IndentAfterNewline() {
+ if (options_.skip_new_lines) {
+ return;
+ }
+ Indent();
+}
+
+class ArrayPrinter : public PrettyPrinter {
+ public:
+ ArrayPrinter(const PrettyPrintOptions& options, std::ostream* sink)
+ : PrettyPrinter(options, sink) {}
+
+ private:
+ template <typename FormatFunction>
+ Status WriteValues(const Array& array, FormatFunction&& func,
+ bool indent_non_null_values = true) {
+ // `indent_non_null_values` should be false if `FormatFunction` applies
+ // indentation itself.
+ for (int64_t i = 0; i < array.length(); ++i) {
+ const bool is_last = (i == array.length() - 1);
+ if ((i >= options_.window) && (i < (array.length() - options_.window))) {
+ IndentAfterNewline();
+ (*sink_) << "...";
+ if (!is_last && options_.skip_new_lines) {
+ (*sink_) << ",";
+ }
+ i = array.length() - options_.window - 1;
+ } else if (array.IsNull(i)) {
+ IndentAfterNewline();
+ (*sink_) << options_.null_rep;
+ if (!is_last) {
+ (*sink_) << ",";
+ }
+ } else {
+ if (indent_non_null_values) {
+ IndentAfterNewline();
+ }
+ RETURN_NOT_OK(func(i));
+ if (!is_last) {
+ (*sink_) << ",";
+ }
+ }
+ Newline();
+ }
+ return Status::OK();
+ }
+
+ template <typename ArrayType, typename Formatter>
+ Status WritePrimitiveValues(const ArrayType& array, Formatter* formatter) {
+ auto appender = [&](util::string_view v) { (*sink_) << v; };
+ auto format_func = [&](int64_t i) {
+ (*formatter)(array.GetView(i), appender);
+ return Status::OK();
+ };
+ return WriteValues(array, std::move(format_func));
+ }
+
+ template <typename ArrayType, typename T = typename ArrayType::TypeClass>
+ Status WritePrimitiveValues(const ArrayType& array) {
+ StringFormatter<T> formatter{array.type()};
+ return WritePrimitiveValues(array, &formatter);
+ }
+
+ Status WriteValidityBitmap(const Array& array);
+
+ Status PrintChildren(const std::vector<std::shared_ptr<Array>>& fields, int64_t offset,
+ int64_t length) {
+ for (size_t i = 0; i < fields.size(); ++i) {
+ Newline();
+ Indent();
+ std::stringstream ss;
+ ss << "-- child " << i << " type: " << fields[i]->type()->ToString() << "\n";
+ Write(ss.str());
+
+ std::shared_ptr<Array> field = fields[i];
+ if (offset != 0) {
+ field = field->Slice(offset, length);
+ }
+ RETURN_NOT_OK(PrettyPrint(*field, indent_ + options_.indent_size, sink_));
+ }
+ return Status::OK();
+ }
+
+ //
+ // WriteDataValues(): generic function to write values from an array
+ //
+
+ template <typename ArrayType, typename T = typename ArrayType::TypeClass>
+ enable_if_has_c_type<T, Status> WriteDataValues(const ArrayType& array) {
+ return WritePrimitiveValues(array);
+ }
+
+ Status WriteDataValues(const HalfFloatArray& array) {
+ // XXX do not know how to format half floats yet
+ StringFormatter<Int16Type> formatter{array.type()};
+ return WritePrimitiveValues(array, &formatter);
+ }
+
+ template <typename ArrayType, typename T = typename ArrayType::TypeClass>
+ enable_if_string_like<T, Status> WriteDataValues(const ArrayType& array) {
+ return WriteValues(array, [&](int64_t i) {
+ (*sink_) << "\"" << array.GetView(i) << "\"";
+ return Status::OK();
+ });
+ }
+
+ template <typename ArrayType, typename T = typename ArrayType::TypeClass>
+ enable_if_t<is_binary_like_type<T>::value && !is_decimal_type<T>::value, Status>
+ WriteDataValues(const ArrayType& array) {
+ return WriteValues(array, [&](int64_t i) {
+ (*sink_) << HexEncode(array.GetView(i));
+ return Status::OK();
+ });
+ }
+
+ template <typename ArrayType, typename T = typename ArrayType::TypeClass>
+ enable_if_decimal<T, Status> WriteDataValues(const ArrayType& array) {
+ return WriteValues(array, [&](int64_t i) {
+ (*sink_) << array.FormatValue(i);
+ return Status::OK();
+ });
+ }
+
+ template <typename ArrayType, typename T = typename ArrayType::TypeClass>
+ enable_if_list_like<T, Status> WriteDataValues(const ArrayType& array) {
+ const auto values = array.values();
+ const auto child_options = ChildOptions();
+ ArrayPrinter values_printer(child_options, sink_);
+
+ return WriteValues(
+ array,
+ [&](int64_t i) {
+ // XXX this could be much faster if ArrayPrinter allowed specifying start and
+ // stop endpoints.
+ return values_printer.Print(
+ *values->Slice(array.value_offset(i), array.value_length(i)));
+ },
+ /*indent_non_null_values=*/false);
+ }
+
+ Status WriteDataValues(const MapArray& array) {
+ const auto keys = array.keys();
+ const auto items = array.items();
+ const auto child_options = ChildOptions();
+ ArrayPrinter values_printer(child_options, sink_);
+
+ return WriteValues(
+ array,
+ [&](int64_t i) {
+ Indent();
+ (*sink_) << "keys:";
+ Newline();
+ RETURN_NOT_OK(values_printer.Print(
+ *keys->Slice(array.value_offset(i), array.value_length(i))));
+ Newline();
+ IndentAfterNewline();
+ (*sink_) << "values:";
+ Newline();
+ RETURN_NOT_OK(values_printer.Print(
+ *items->Slice(array.value_offset(i), array.value_length(i))));
+ return Status::OK();
+ },
+ /*indent_non_null_values=*/false);
+ }
+
+ public:
+ template <typename T>
+ enable_if_t<std::is_base_of<PrimitiveArray, T>::value ||
+ std::is_base_of<FixedSizeBinaryArray, T>::value ||
+ std::is_base_of<BinaryArray, T>::value ||
+ std::is_base_of<LargeBinaryArray, T>::value ||
+ std::is_base_of<ListArray, T>::value ||
+ std::is_base_of<LargeListArray, T>::value ||
+ std::is_base_of<MapArray, T>::value ||
+ std::is_base_of<FixedSizeListArray, T>::value,
+ Status>
+ Visit(const T& array) {
+ Status st = array.Validate();
+ if (!st.ok()) {
+ (*sink_) << "<Invalid array: " << st.message() << ">";
+ return Status::OK();
+ }
+
+ OpenArray(array);
+ if (array.length() > 0) {
+ RETURN_NOT_OK(WriteDataValues(array));
+ }
+ CloseArray(array);
+ return Status::OK();
+ }
+
+ Status Visit(const NullArray& array) {
+ (*sink_) << array.length() << " nulls";
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionArray& array) { return Print(*array.storage()); }
+
+ Status Visit(const StructArray& array) {
+ RETURN_NOT_OK(WriteValidityBitmap(array));
+ std::vector<std::shared_ptr<Array>> children;
+ children.reserve(array.num_fields());
+ for (int i = 0; i < array.num_fields(); ++i) {
+ children.emplace_back(array.field(i));
+ }
+ return PrintChildren(children, 0, array.length());
+ }
+
+ Status Visit(const UnionArray& array) {
+ RETURN_NOT_OK(WriteValidityBitmap(array));
+
+ Newline();
+ Indent();
+ Write("-- type_ids: ");
+ UInt8Array type_codes(array.length(), array.type_codes(), nullptr, 0, array.offset());
+ RETURN_NOT_OK(PrettyPrint(type_codes, indent_ + options_.indent_size, sink_));
+
+ if (array.mode() == UnionMode::DENSE) {
+ Newline();
+ Indent();
+ Write("-- value_offsets: ");
+ Int32Array value_offsets(
+ array.length(), checked_cast<const DenseUnionArray&>(array).value_offsets(),
+ nullptr, 0, array.offset());
+ RETURN_NOT_OK(PrettyPrint(value_offsets, indent_ + options_.indent_size, sink_));
+ }
+
+ // Print the children without any offset, because the type ids are absolute
+ std::vector<std::shared_ptr<Array>> children;
+ children.reserve(array.num_fields());
+ for (int i = 0; i < array.num_fields(); ++i) {
+ children.emplace_back(array.field(i));
+ }
+ return PrintChildren(children, 0, array.length() + array.offset());
+ }
+
+ Status Visit(const DictionaryArray& array) {
+ Newline();
+ Indent();
+ Write("-- dictionary:\n");
+ RETURN_NOT_OK(
+ PrettyPrint(*array.dictionary(), indent_ + options_.indent_size, sink_));
+
+ Newline();
+ Indent();
+ Write("-- indices:\n");
+ return PrettyPrint(*array.indices(), indent_ + options_.indent_size, sink_);
+ }
+
+ Status Print(const Array& array) {
+ RETURN_NOT_OK(VisitArrayInline(array, this));
+ Flush();
+ return Status::OK();
+ }
+};
+
+Status ArrayPrinter::WriteValidityBitmap(const Array& array) {
+ Indent();
+ Write("-- is_valid:");
+
+ if (array.null_count() > 0) {
+ Newline();
+ Indent();
+ BooleanArray is_valid(array.length(), array.null_bitmap(), nullptr, 0,
+ array.offset());
+ return PrettyPrint(is_valid, indent_ + options_.indent_size, sink_);
+ } else {
+ Write(" all not null");
+ return Status::OK();
+ }
+}
+
+} // namespace
+
+Status PrettyPrint(const Array& arr, int indent, std::ostream* sink) {
+ PrettyPrintOptions options;
+ options.indent = indent;
+ ArrayPrinter printer(options, sink);
+ return printer.Print(arr);
+}
+
+Status PrettyPrint(const Array& arr, const PrettyPrintOptions& options,
+ std::ostream* sink) {
+ ArrayPrinter printer(options, sink);
+ return printer.Print(arr);
+}
+
+Status PrettyPrint(const Array& arr, const PrettyPrintOptions& options,
+ std::string* result) {
+ std::ostringstream sink;
+ RETURN_NOT_OK(PrettyPrint(arr, options, &sink));
+ *result = sink.str();
+ return Status::OK();
+}
+
+Status PrettyPrint(const ChunkedArray& chunked_arr, const PrettyPrintOptions& options,
+ std::ostream* sink) {
+ int num_chunks = chunked_arr.num_chunks();
+ int indent = options.indent;
+ int window = options.window;
+
+ for (int i = 0; i < indent; ++i) {
+ (*sink) << " ";
+ }
+ (*sink) << "[";
+ if (!options.skip_new_lines) {
+ *sink << "\n";
+ }
+ bool skip_comma = true;
+ for (int i = 0; i < num_chunks; ++i) {
+ if (skip_comma) {
+ skip_comma = false;
+ } else {
+ (*sink) << ",";
+ if (!options.skip_new_lines) {
+ *sink << "\n";
+ }
+ }
+ if ((i >= window) && (i < (num_chunks - window))) {
+ for (int i = 0; i < indent; ++i) {
+ (*sink) << " ";
+ }
+ (*sink) << "...";
+ if (!options.skip_new_lines) {
+ *sink << "\n";
+ }
+ i = num_chunks - window - 1;
+ skip_comma = true;
+ } else {
+ PrettyPrintOptions chunk_options = options;
+ chunk_options.indent += options.indent_size;
+ ArrayPrinter printer(chunk_options, sink);
+ RETURN_NOT_OK(printer.Print(*chunked_arr.chunk(i)));
+ }
+ }
+ if (!options.skip_new_lines) {
+ *sink << "\n";
+ }
+
+ for (int i = 0; i < indent; ++i) {
+ (*sink) << " ";
+ }
+ (*sink) << "]";
+
+ return Status::OK();
+}
+
+Status PrettyPrint(const ChunkedArray& chunked_arr, const PrettyPrintOptions& options,
+ std::string* result) {
+ std::ostringstream sink;
+ RETURN_NOT_OK(PrettyPrint(chunked_arr, options, &sink));
+ *result = sink.str();
+ return Status::OK();
+}
+
+Status PrettyPrint(const RecordBatch& batch, int indent, std::ostream* sink) {
+ for (int i = 0; i < batch.num_columns(); ++i) {
+ const std::string& name = batch.column_name(i);
+ (*sink) << name << ": ";
+ RETURN_NOT_OK(PrettyPrint(*batch.column(i), indent + 2, sink));
+ (*sink) << "\n";
+ }
+ (*sink) << std::flush;
+ return Status::OK();
+}
+
+Status PrettyPrint(const RecordBatch& batch, const PrettyPrintOptions& options,
+ std::ostream* sink) {
+ for (int i = 0; i < batch.num_columns(); ++i) {
+ const std::string& name = batch.column_name(i);
+ PrettyPrintOptions column_options = options;
+ column_options.indent += 2;
+
+ (*sink) << name << ": ";
+ RETURN_NOT_OK(PrettyPrint(*batch.column(i), column_options, sink));
+ (*sink) << "\n";
+ }
+ (*sink) << std::flush;
+ return Status::OK();
+}
+
+Status PrettyPrint(const Table& table, const PrettyPrintOptions& options,
+ std::ostream* sink) {
+ RETURN_NOT_OK(PrettyPrint(*table.schema(), options, sink));
+ (*sink) << "\n";
+ (*sink) << "----\n";
+
+ PrettyPrintOptions column_options = options;
+ column_options.indent += 2;
+ for (int i = 0; i < table.num_columns(); ++i) {
+ for (int j = 0; j < options.indent; ++j) {
+ (*sink) << " ";
+ }
+ (*sink) << table.schema()->field(i)->name() << ":\n";
+ RETURN_NOT_OK(PrettyPrint(*table.column(i), column_options, sink));
+ (*sink) << "\n";
+ }
+ (*sink) << std::flush;
+ return Status::OK();
+}
+
+Status DebugPrint(const Array& arr, int indent) {
+ return PrettyPrint(arr, indent, &std::cerr);
+}
+
+namespace {
+
+class SchemaPrinter : public PrettyPrinter {
+ public:
+ SchemaPrinter(const Schema& schema, const PrettyPrintOptions& options,
+ std::ostream* sink)
+ : PrettyPrinter(options, sink), schema_(schema) {}
+
+ Status PrintType(const DataType& type, bool nullable);
+ Status PrintField(const Field& field);
+
+ void PrintVerboseMetadata(const KeyValueMetadata& metadata) {
+ for (int64_t i = 0; i < metadata.size(); ++i) {
+ Newline();
+ Indent();
+ Write(metadata.key(i) + ": '" + metadata.value(i) + "'");
+ }
+ }
+
+ void PrintTruncatedMetadata(const KeyValueMetadata& metadata) {
+ for (int64_t i = 0; i < metadata.size(); ++i) {
+ Newline();
+ Indent();
+ size_t size = metadata.value(i).size();
+ size_t truncated_size = std::max<size_t>(10, 70 - metadata.key(i).size() - indent_);
+ if (size <= truncated_size) {
+ Write(metadata.key(i) + ": '" + metadata.value(i) + "'");
+ continue;
+ }
+
+ Write(metadata.key(i) + ": '" + metadata.value(i).substr(0, truncated_size) +
+ "' + " + std::to_string(size - truncated_size));
+ }
+ }
+
+ void PrintMetadata(const std::string& metadata_type, const KeyValueMetadata& metadata) {
+ if (metadata.size() > 0) {
+ Newline();
+ Indent();
+ Write(metadata_type);
+ if (options_.truncate_metadata) {
+ PrintTruncatedMetadata(metadata);
+ } else {
+ PrintVerboseMetadata(metadata);
+ }
+ }
+ }
+
+ Status Print() {
+ for (int i = 0; i < schema_.num_fields(); ++i) {
+ if (i > 0) {
+ Newline();
+ Indent();
+ } else {
+ Indent();
+ }
+ RETURN_NOT_OK(PrintField(*schema_.field(i)));
+ }
+
+ if (options_.show_schema_metadata && schema_.metadata() != nullptr) {
+ PrintMetadata("-- schema metadata --", *schema_.metadata());
+ }
+ Flush();
+ return Status::OK();
+ }
+
+ private:
+ const Schema& schema_;
+};
+
+Status SchemaPrinter::PrintType(const DataType& type, bool nullable) {
+ Write(type.ToString());
+ if (!nullable) {
+ Write(" not null");
+ }
+ for (int i = 0; i < type.num_fields(); ++i) {
+ Newline();
+ Indent();
+
+ std::stringstream ss;
+ ss << "child " << i << ", ";
+
+ indent_ += options_.indent_size;
+ WriteIndented(ss.str());
+ RETURN_NOT_OK(PrintField(*type.field(i)));
+ indent_ -= options_.indent_size;
+ }
+ return Status::OK();
+}
+
+Status SchemaPrinter::PrintField(const Field& field) {
+ Write(field.name());
+ Write(": ");
+ RETURN_NOT_OK(PrintType(*field.type(), field.nullable()));
+
+ if (options_.show_field_metadata && field.metadata() != nullptr) {
+ indent_ += options_.indent_size;
+ PrintMetadata("-- field metadata --", *field.metadata());
+ indent_ -= options_.indent_size;
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status PrettyPrint(const Schema& schema, const PrettyPrintOptions& options,
+ std::ostream* sink) {
+ SchemaPrinter printer(schema, options, sink);
+ return printer.Print();
+}
+
+Status PrettyPrint(const Schema& schema, const PrettyPrintOptions& options,
+ std::string* result) {
+ std::ostringstream sink;
+ RETURN_NOT_OK(PrettyPrint(schema, options, &sink));
+ *result = sink.str();
+ return Status::OK();
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/pretty_print.h b/src/arrow/cpp/src/arrow/pretty_print.h
new file mode 100644
index 000000000..1bc086a68
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/pretty_print.h
@@ -0,0 +1,125 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <iosfwd>
+#include <string>
+#include <utility>
+
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Array;
+class ChunkedArray;
+class RecordBatch;
+class Schema;
+class Status;
+class Table;
+
+struct PrettyPrintOptions {
+ PrettyPrintOptions() = default;
+
+ PrettyPrintOptions(int indent_arg, // NOLINT runtime/explicit
+ int window_arg = 10, int indent_size_arg = 2,
+ std::string null_rep_arg = "null", bool skip_new_lines_arg = false,
+ bool truncate_metadata_arg = true)
+ : indent(indent_arg),
+ indent_size(indent_size_arg),
+ window(window_arg),
+ null_rep(std::move(null_rep_arg)),
+ skip_new_lines(skip_new_lines_arg),
+ truncate_metadata(truncate_metadata_arg) {}
+
+ static PrettyPrintOptions Defaults() { return PrettyPrintOptions(); }
+
+ /// Number of spaces to shift entire formatted object to the right
+ int indent = 0;
+
+ /// Size of internal indents
+ int indent_size = 2;
+
+ /// Maximum number of elements to show at the beginning and at the end.
+ int window = 10;
+
+ /// String to use for representing a null value, defaults to "null"
+ std::string null_rep = "null";
+
+ /// Skip new lines between elements, defaults to false
+ bool skip_new_lines = false;
+
+ /// Limit display of each KeyValueMetadata key/value pair to a single line at
+ /// 80 character width
+ bool truncate_metadata = true;
+
+ /// If true, display field metadata when pretty-printing a Schema
+ bool show_field_metadata = true;
+
+ /// If true, display schema metadata when pretty-printing a Schema
+ bool show_schema_metadata = true;
+};
+
+/// \brief Print human-readable representation of RecordBatch
+ARROW_EXPORT
+Status PrettyPrint(const RecordBatch& batch, int indent, std::ostream* sink);
+
+ARROW_EXPORT
+Status PrettyPrint(const RecordBatch& batch, const PrettyPrintOptions& options,
+ std::ostream* sink);
+
+/// \brief Print human-readable representation of Table
+ARROW_EXPORT
+Status PrettyPrint(const Table& table, const PrettyPrintOptions& options,
+ std::ostream* sink);
+
+/// \brief Print human-readable representation of Array
+ARROW_EXPORT
+Status PrettyPrint(const Array& arr, int indent, std::ostream* sink);
+
+/// \brief Print human-readable representation of Array
+ARROW_EXPORT
+Status PrettyPrint(const Array& arr, const PrettyPrintOptions& options,
+ std::ostream* sink);
+
+/// \brief Print human-readable representation of Array
+ARROW_EXPORT
+Status PrettyPrint(const Array& arr, const PrettyPrintOptions& options,
+ std::string* result);
+
+/// \brief Print human-readable representation of ChunkedArray
+ARROW_EXPORT
+Status PrettyPrint(const ChunkedArray& chunked_arr, const PrettyPrintOptions& options,
+ std::ostream* sink);
+
+/// \brief Print human-readable representation of ChunkedArray
+ARROW_EXPORT
+Status PrettyPrint(const ChunkedArray& chunked_arr, const PrettyPrintOptions& options,
+ std::string* result);
+
+ARROW_EXPORT
+Status PrettyPrint(const Schema& schema, const PrettyPrintOptions& options,
+ std::ostream* sink);
+
+ARROW_EXPORT
+Status PrettyPrint(const Schema& schema, const PrettyPrintOptions& options,
+ std::string* result);
+
+ARROW_EXPORT
+Status DebugPrint(const Array& arr, int indent);
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/pretty_print_test.cc b/src/arrow/cpp/src/arrow/pretty_print_test.cc
new file mode 100644
index 000000000..42995de32
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/pretty_print_test.cc
@@ -0,0 +1,1081 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/pretty_print.h"
+
+#include <gtest/gtest.h>
+
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+
+class TestPrettyPrint : public ::testing::Test {
+ public:
+ void SetUp() {}
+
+ void Print(const Array& array) {}
+
+ private:
+ std::ostringstream sink_;
+};
+
+template <typename T>
+void CheckStream(const T& obj, const PrettyPrintOptions& options, const char* expected) {
+ std::ostringstream sink;
+ ASSERT_OK(PrettyPrint(obj, options, &sink));
+ std::string result = sink.str();
+ ASSERT_EQ(std::string(expected, strlen(expected)), result);
+}
+
+void CheckArray(const Array& arr, const PrettyPrintOptions& options, const char* expected,
+ bool check_operator = true) {
+ ARROW_SCOPED_TRACE("For datatype: ", arr.type()->ToString());
+ CheckStream(arr, options, expected);
+
+ if (options.indent == 0 && check_operator) {
+ std::stringstream ss;
+ ss << arr;
+ std::string result = std::string(expected, strlen(expected));
+ ASSERT_EQ(result, ss.str());
+ }
+}
+
+template <typename T>
+void Check(const T& obj, const PrettyPrintOptions& options, const char* expected) {
+ std::string result;
+ ASSERT_OK(PrettyPrint(obj, options, &result));
+ ASSERT_EQ(std::string(expected, strlen(expected)), result);
+}
+
+template <typename TYPE, typename C_TYPE>
+void CheckPrimitive(const std::shared_ptr<DataType>& type,
+ const PrettyPrintOptions& options, const std::vector<bool>& is_valid,
+ const std::vector<C_TYPE>& values, const char* expected,
+ bool check_operator = true) {
+ std::shared_ptr<Array> array;
+ ArrayFromVector<TYPE, C_TYPE>(type, is_valid, values, &array);
+ CheckArray(*array, options, expected, check_operator);
+}
+
+template <typename TYPE, typename C_TYPE>
+void CheckPrimitive(const PrettyPrintOptions& options, const std::vector<bool>& is_valid,
+ const std::vector<C_TYPE>& values, const char* expected,
+ bool check_operator = true) {
+ CheckPrimitive<TYPE, C_TYPE>(TypeTraits<TYPE>::type_singleton(), options, is_valid,
+ values, expected, check_operator);
+}
+
+TEST_F(TestPrettyPrint, PrimitiveType) {
+ std::vector<bool> is_valid = {true, true, false, true, false};
+
+ std::vector<int32_t> values = {0, 1, 2, 3, 4};
+ static const char* expected = R"expected([
+ 0,
+ 1,
+ null,
+ 3,
+ null
+])expected";
+ CheckPrimitive<Int32Type, int32_t>({0, 10}, is_valid, values, expected);
+
+ static const char* expected_na = R"expected([
+ 0,
+ 1,
+ NA,
+ 3,
+ NA
+])expected";
+ CheckPrimitive<Int32Type, int32_t>({0, 10, 2, "NA"}, is_valid, values, expected_na,
+ false);
+
+ static const char* ex_in2 = R"expected( [
+ 0,
+ 1,
+ null,
+ 3,
+ null
+ ])expected";
+ CheckPrimitive<Int32Type, int32_t>({2, 10}, is_valid, values, ex_in2);
+ static const char* ex_in2_w2 = R"expected( [
+ 0,
+ 1,
+ ...
+ 3,
+ null
+ ])expected";
+ CheckPrimitive<Int32Type, int32_t>({2, 2}, is_valid, values, ex_in2_w2);
+
+ std::vector<double> values2 = {0., 1., 2., 3., 4.};
+ static const char* ex2 = R"expected([
+ 0,
+ 1,
+ null,
+ 3,
+ null
+])expected";
+ CheckPrimitive<DoubleType, double>({0, 10}, is_valid, values2, ex2);
+ static const char* ex2_in2 = R"expected( [
+ 0,
+ 1,
+ null,
+ 3,
+ null
+ ])expected";
+ CheckPrimitive<DoubleType, double>({2, 10}, is_valid, values2, ex2_in2);
+
+ std::vector<std::string> values3 = {"foo", "bar", "", "baz", ""};
+ static const char* ex3 = R"expected([
+ "foo",
+ "bar",
+ null,
+ "baz",
+ null
+])expected";
+ CheckPrimitive<StringType, std::string>({0, 10}, is_valid, values3, ex3);
+ CheckPrimitive<LargeStringType, std::string>({0, 10}, is_valid, values3, ex3);
+ static const char* ex3_in2 = R"expected( [
+ "foo",
+ "bar",
+ null,
+ "baz",
+ null
+ ])expected";
+ CheckPrimitive<StringType, std::string>({2, 10}, is_valid, values3, ex3_in2);
+ CheckPrimitive<LargeStringType, std::string>({2, 10}, is_valid, values3, ex3_in2);
+}
+
+TEST_F(TestPrettyPrint, PrimitiveTypeNoNewlines) {
+ std::vector<bool> is_valid = {true, true, false, true, false};
+ std::vector<int32_t> values = {0, 1, 2, 3, 4};
+
+ PrettyPrintOptions options{};
+ options.skip_new_lines = true;
+ options.window = 4;
+
+ const char* expected = "[0,1,null,3,null]";
+ CheckPrimitive<Int32Type, int32_t>(options, is_valid, values, expected, false);
+
+ // With ellipsis
+ is_valid.insert(is_valid.end(), 20, true);
+ is_valid.insert(is_valid.end(), {true, false, true});
+ values.insert(values.end(), 20, 99);
+ values.insert(values.end(), {44, 43, 42});
+
+ expected = "[0,1,null,3,...,99,44,null,42]";
+ CheckPrimitive<Int32Type, int32_t>(options, is_valid, values, expected, false);
+}
+
+TEST_F(TestPrettyPrint, Int8) {
+ static const char* expected = R"expected([
+ 0,
+ 127,
+ -128
+])expected";
+ CheckPrimitive<Int8Type, int8_t>({0, 10}, {true, true, true}, {0, 127, -128}, expected);
+}
+
+TEST_F(TestPrettyPrint, UInt8) {
+ static const char* expected = R"expected([
+ 0,
+ 255
+])expected";
+ CheckPrimitive<UInt8Type, uint8_t>({0, 10}, {true, true}, {0, 255}, expected);
+}
+
+TEST_F(TestPrettyPrint, Int64) {
+ static const char* expected = R"expected([
+ 0,
+ 9223372036854775807,
+ -9223372036854775808
+])expected";
+ CheckPrimitive<Int64Type, int64_t>(
+ {0, 10}, {true, true, true}, {0, 9223372036854775807LL, -9223372036854775807LL - 1},
+ expected);
+}
+
+TEST_F(TestPrettyPrint, UInt64) {
+ static const char* expected = R"expected([
+ 0,
+ 9223372036854775803,
+ 18446744073709551615
+])expected";
+ CheckPrimitive<UInt64Type, uint64_t>(
+ {0, 10}, {true, true, true}, {0, 9223372036854775803ULL, 18446744073709551615ULL},
+ expected);
+}
+
+TEST_F(TestPrettyPrint, DateTimeTypes) {
+ std::vector<bool> is_valid = {true, true, false, true, false};
+
+ {
+ std::vector<int32_t> values = {0, 1, 2, 31, 4};
+ static const char* expected = R"expected([
+ 1970-01-01,
+ 1970-01-02,
+ null,
+ 1970-02-01,
+ null
+])expected";
+ CheckPrimitive<Date32Type, int32_t>({0, 10}, is_valid, values, expected);
+ }
+
+ {
+ constexpr int64_t ms_per_day = 24 * 60 * 60 * 1000;
+ std::vector<int64_t> values = {0 * ms_per_day, 1 * ms_per_day, 2 * ms_per_day,
+ 31 * ms_per_day, 4 * ms_per_day};
+ static const char* expected = R"expected([
+ 1970-01-01,
+ 1970-01-02,
+ null,
+ 1970-02-01,
+ null
+])expected";
+ CheckPrimitive<Date64Type, int64_t>({0, 10}, is_valid, values, expected);
+ }
+
+ {
+ std::vector<int64_t> values = {
+ 0, 1, 2, 678 + 1000000 * (5 + 60 * (4 + 60 * (3 + 24 * int64_t(1)))), 4};
+ static const char* expected = R"expected([
+ 1970-01-01 00:00:00.000000,
+ 1970-01-01 00:00:00.000001,
+ null,
+ 1970-01-02 03:04:05.000678,
+ null
+])expected";
+ CheckPrimitive<TimestampType, int64_t>(timestamp(TimeUnit::MICRO, "Transylvania"),
+ {0, 10}, is_valid, values, expected);
+ }
+
+ {
+ std::vector<int32_t> values = {1, 62, 2, 3 + 60 * (2 + 60 * 1), 4};
+ static const char* expected = R"expected([
+ 00:00:01,
+ 00:01:02,
+ null,
+ 01:02:03,
+ null
+])expected";
+ CheckPrimitive<Time32Type, int32_t>(time32(TimeUnit::SECOND), {0, 10}, is_valid,
+ values, expected);
+ }
+
+ {
+ std::vector<int64_t> values = {
+ 0, 1, 2, 678 + int64_t(1000000000) * (5 + 60 * (4 + 60 * 3)), 4};
+ static const char* expected = R"expected([
+ 00:00:00.000000000,
+ 00:00:00.000000001,
+ null,
+ 03:04:05.000000678,
+ null
+])expected";
+ CheckPrimitive<Time64Type, int64_t>(time64(TimeUnit::NANO), {0, 10}, is_valid, values,
+ expected);
+ }
+}
+
+TEST_F(TestPrettyPrint, TestIntervalTypes) {
+ std::vector<bool> is_valid = {true, true, false, true, false};
+
+ {
+ std::vector<DayTimeIntervalType::DayMilliseconds> values = {
+ {1, 2}, {-3, 4}, {}, {}, {}};
+ static const char* expected = R"expected([
+ 1d2ms,
+ -3d4ms,
+ null,
+ 0d0ms,
+ null
+])expected";
+ CheckPrimitive<DayTimeIntervalType, DayTimeIntervalType::DayMilliseconds>(
+ {0, 10}, is_valid, values, expected);
+ }
+ {
+ std::vector<MonthDayNanoIntervalType::MonthDayNanos> values = {
+ {1, 2, 3}, {-3, 4, -5}, {}, {}, {}};
+ static const char* expected = R"expected([
+ 1M2d3ns,
+ -3M4d-5ns,
+ null,
+ 0M0d0ns,
+ null
+])expected";
+ CheckPrimitive<MonthDayNanoIntervalType, MonthDayNanoIntervalType::MonthDayNanos>(
+ {0, 10}, is_valid, values, expected);
+ }
+}
+
+TEST_F(TestPrettyPrint, DateTimeTypesWithOutOfRangeValues) {
+ // Our vendored date library allows years within [-32767, 32767],
+ // which limits the range of values which can be displayed.
+ const int32_t min_int32 = std::numeric_limits<int32_t>::min();
+ const int32_t max_int32 = std::numeric_limits<int32_t>::max();
+ const int64_t min_int64 = std::numeric_limits<int64_t>::min();
+ const int64_t max_int64 = std::numeric_limits<int64_t>::max();
+
+ const int32_t min_date32 = -12687428;
+ const int32_t max_date32 = 11248737;
+ const int64_t min_date64 = 86400000LL * min_date32;
+ const int64_t max_date64 = 86400000LL * (max_date32 + 1) - 1;
+
+ const int32_t min_time32_seconds = 0;
+ const int32_t max_time32_seconds = 86399;
+ const int32_t min_time32_millis = 0;
+ const int32_t max_time32_millis = 86399999;
+ const int64_t min_time64_micros = 0;
+ const int64_t max_time64_micros = 86399999999LL;
+ const int64_t min_time64_nanos = 0;
+ const int64_t max_time64_nanos = 86399999999999LL;
+
+ const int64_t min_timestamp_seconds = -1096193779200LL;
+ const int64_t max_timestamp_seconds = 971890963199LL;
+ const int64_t min_timestamp_millis = min_timestamp_seconds * 1000;
+ const int64_t max_timestamp_millis = max_timestamp_seconds * 1000 + 999;
+ const int64_t min_timestamp_micros = min_timestamp_millis * 1000;
+ const int64_t max_timestamp_micros = max_timestamp_millis * 1000 + 999;
+
+ std::vector<bool> is_valid = {false, false, false, false, true,
+ true, true, true, true, true};
+
+ // Dates
+ {
+ std::vector<int32_t> values = {min_int32, max_int32, min_date32 - 1, max_date32 + 1,
+ min_int32, max_int32, min_date32 - 1, max_date32 + 1,
+ min_date32, max_date32};
+ static const char* expected = R"expected([
+ null,
+ null,
+ null,
+ null,
+ <value out of range: -2147483648>,
+ <value out of range: 2147483647>,
+ <value out of range: -12687429>,
+ <value out of range: 11248738>,
+ -32767-01-01,
+ 32767-12-31
+])expected";
+ CheckPrimitive<Date32Type, int32_t>({0, 10}, is_valid, values, expected);
+ }
+ {
+ std::vector<int64_t> values = {min_int64, max_int64, min_date64 - 1, max_date64 + 1,
+ min_int64, max_int64, min_date64 - 1, max_date64 + 1,
+ min_date64, max_date64};
+ static const char* expected = R"expected([
+ null,
+ null,
+ null,
+ null,
+ <value out of range: -9223372036854775808>,
+ <value out of range: 9223372036854775807>,
+ <value out of range: -1096193779200001>,
+ <value out of range: 971890963200000>,
+ -32767-01-01,
+ 32767-12-31
+])expected";
+ CheckPrimitive<Date64Type, int64_t>({0, 10}, is_valid, values, expected);
+ }
+
+ // Times
+ {
+ std::vector<int32_t> values = {min_int32,
+ max_int32,
+ min_time32_seconds - 1,
+ max_time32_seconds + 1,
+ min_int32,
+ max_int32,
+ min_time32_seconds - 1,
+ max_time32_seconds + 1,
+ min_time32_seconds,
+ max_time32_seconds};
+ static const char* expected = R"expected([
+ null,
+ null,
+ null,
+ null,
+ <value out of range: -2147483648>,
+ <value out of range: 2147483647>,
+ <value out of range: -1>,
+ <value out of range: 86400>,
+ 00:00:00,
+ 23:59:59
+])expected";
+ CheckPrimitive<Time32Type, int32_t>(time32(TimeUnit::SECOND), {0, 10}, is_valid,
+ values, expected);
+ }
+ {
+ std::vector<int32_t> values = {
+ min_int32, max_int32, min_time32_millis - 1, max_time32_millis + 1,
+ min_int32, max_int32, min_time32_millis - 1, max_time32_millis + 1,
+ min_time32_millis, max_time32_millis};
+ static const char* expected = R"expected([
+ null,
+ null,
+ null,
+ null,
+ <value out of range: -2147483648>,
+ <value out of range: 2147483647>,
+ <value out of range: -1>,
+ <value out of range: 86400000>,
+ 00:00:00.000,
+ 23:59:59.999
+])expected";
+ CheckPrimitive<Time32Type, int32_t>(time32(TimeUnit::MILLI), {0, 10}, is_valid,
+ values, expected);
+ }
+ {
+ std::vector<int64_t> values = {
+ min_int64, max_int64, min_time64_micros - 1, max_time64_micros + 1,
+ min_int64, max_int64, min_time64_micros - 1, max_time64_micros + 1,
+ min_time64_micros, max_time64_micros};
+ static const char* expected = R"expected([
+ null,
+ null,
+ null,
+ null,
+ <value out of range: -9223372036854775808>,
+ <value out of range: 9223372036854775807>,
+ <value out of range: -1>,
+ <value out of range: 86400000000>,
+ 00:00:00.000000,
+ 23:59:59.999999
+])expected";
+ CheckPrimitive<Time64Type, int64_t>(time64(TimeUnit::MICRO), {0, 10}, is_valid,
+ values, expected);
+ }
+ {
+ std::vector<int64_t> values = {
+ min_int64, max_int64, min_time64_nanos - 1, max_time64_nanos + 1,
+ min_int64, max_int64, min_time64_nanos - 1, max_time64_nanos + 1,
+ min_time64_nanos, max_time64_nanos};
+ static const char* expected = R"expected([
+ null,
+ null,
+ null,
+ null,
+ <value out of range: -9223372036854775808>,
+ <value out of range: 9223372036854775807>,
+ <value out of range: -1>,
+ <value out of range: 86400000000000>,
+ 00:00:00.000000000,
+ 23:59:59.999999999
+])expected";
+ CheckPrimitive<Time64Type, int64_t>(time64(TimeUnit::NANO), {0, 10}, is_valid, values,
+ expected);
+ }
+
+ // Timestamps
+ {
+ std::vector<int64_t> values = {min_int64,
+ max_int64,
+ min_timestamp_seconds - 1,
+ max_timestamp_seconds + 1,
+ min_int64,
+ max_int64,
+ min_timestamp_seconds - 1,
+ max_timestamp_seconds + 1,
+ min_timestamp_seconds,
+ max_timestamp_seconds};
+ static const char* expected = R"expected([
+ null,
+ null,
+ null,
+ null,
+ <value out of range: -9223372036854775808>,
+ <value out of range: 9223372036854775807>,
+ <value out of range: -1096193779201>,
+ <value out of range: 971890963200>,
+ -32767-01-01 00:00:00,
+ 32767-12-31 23:59:59
+])expected";
+ CheckPrimitive<TimestampType, int64_t>(timestamp(TimeUnit::SECOND), {0, 10}, is_valid,
+ values, expected);
+ }
+ {
+ std::vector<int64_t> values = {min_int64,
+ max_int64,
+ min_timestamp_millis - 1,
+ max_timestamp_millis + 1,
+ min_int64,
+ max_int64,
+ min_timestamp_millis - 1,
+ max_timestamp_millis + 1,
+ min_timestamp_millis,
+ max_timestamp_millis};
+ static const char* expected = R"expected([
+ null,
+ null,
+ null,
+ null,
+ <value out of range: -9223372036854775808>,
+ <value out of range: 9223372036854775807>,
+ <value out of range: -1096193779200001>,
+ <value out of range: 971890963200000>,
+ -32767-01-01 00:00:00.000,
+ 32767-12-31 23:59:59.999
+])expected";
+ CheckPrimitive<TimestampType, int64_t>(timestamp(TimeUnit::MILLI), {0, 10}, is_valid,
+ values, expected);
+ }
+ {
+ std::vector<int64_t> values = {min_int64,
+ max_int64,
+ min_timestamp_micros - 1,
+ max_timestamp_micros + 1,
+ min_int64,
+ max_int64,
+ min_timestamp_micros - 1,
+ max_timestamp_micros + 1,
+ min_timestamp_micros,
+ max_timestamp_micros};
+ static const char* expected = R"expected([
+ null,
+ null,
+ null,
+ null,
+ <value out of range: -9223372036854775808>,
+ <value out of range: 9223372036854775807>,
+ <value out of range: -1096193779200000001>,
+ <value out of range: 971890963200000000>,
+ -32767-01-01 00:00:00.000000,
+ 32767-12-31 23:59:59.999999
+])expected";
+ CheckPrimitive<TimestampType, int64_t>(timestamp(TimeUnit::MICRO), {0, 10}, is_valid,
+ values, expected);
+ }
+ // Note that while the values below are legal and correct, they used to
+ // trigger an internal signed overflow inside the vendored "date" library
+ // (https://github.com/HowardHinnant/date/issues/696).
+ {
+ std::vector<int64_t> values = {min_int64, max_int64};
+ static const char* expected = R"expected([
+ 1677-09-21 00:12:43.145224192,
+ 2262-04-11 23:47:16.854775807
+])expected";
+ CheckPrimitive<TimestampType, int64_t>(timestamp(TimeUnit::NANO), {0, 10},
+ {true, true}, values, expected);
+ }
+}
+
+TEST_F(TestPrettyPrint, StructTypeBasic) {
+ auto simple_1 = field("one", int32());
+ auto simple_2 = field("two", int32());
+ auto simple_struct = struct_({simple_1, simple_2});
+
+ auto array = ArrayFromJSON(simple_struct, "[[11, 22]]");
+
+ static const char* ex = R"expected(-- is_valid: all not null
+-- child 0 type: int32
+ [
+ 11
+ ]
+-- child 1 type: int32
+ [
+ 22
+ ])expected";
+ CheckStream(*array, {0, 10}, ex);
+
+ static const char* ex_2 = R"expected( -- is_valid: all not null
+ -- child 0 type: int32
+ [
+ 11
+ ]
+ -- child 1 type: int32
+ [
+ 22
+ ])expected";
+ CheckStream(*array, {2, 10}, ex_2);
+}
+
+TEST_F(TestPrettyPrint, StructTypeAdvanced) {
+ auto simple_1 = field("one", int32());
+ auto simple_2 = field("two", int32());
+ auto simple_struct = struct_({simple_1, simple_2});
+
+ auto array = ArrayFromJSON(simple_struct, "[[11, 22], null, [null, 33]]");
+
+ static const char* ex = R"expected(-- is_valid:
+ [
+ true,
+ false,
+ true
+ ]
+-- child 0 type: int32
+ [
+ 11,
+ 0,
+ null
+ ]
+-- child 1 type: int32
+ [
+ 22,
+ 0,
+ 33
+ ])expected";
+ CheckStream(*array, {0, 10}, ex);
+}
+
+TEST_F(TestPrettyPrint, BinaryType) {
+ std::vector<bool> is_valid = {true, true, false, true, true, true};
+ std::vector<std::string> values = {"foo", "bar", "", "baz", "", "\xff"};
+ static const char* ex = "[\n 666F6F,\n 626172,\n null,\n 62617A,\n ,\n FF\n]";
+ CheckPrimitive<BinaryType, std::string>({0}, is_valid, values, ex);
+ CheckPrimitive<LargeBinaryType, std::string>({0}, is_valid, values, ex);
+ static const char* ex_in2 =
+ " [\n 666F6F,\n 626172,\n null,\n 62617A,\n ,\n FF\n ]";
+ CheckPrimitive<BinaryType, std::string>({2}, is_valid, values, ex_in2);
+ CheckPrimitive<LargeBinaryType, std::string>({2}, is_valid, values, ex_in2);
+}
+
+TEST_F(TestPrettyPrint, BinaryNoNewlines) {
+ std::vector<bool> is_valid = {true, true, false, true, true, true};
+ std::vector<std::string> values = {"foo", "bar", "", "baz", "", "\xff"};
+
+ PrettyPrintOptions options{};
+ options.skip_new_lines = true;
+
+ const char* expected = "[666F6F,626172,null,62617A,,FF]";
+ CheckPrimitive<BinaryType, std::string>(options, is_valid, values, expected, false);
+
+ // With ellipsis
+ options.window = 2;
+ expected = "[666F6F,626172,...,,FF]";
+ CheckPrimitive<BinaryType, std::string>(options, is_valid, values, expected, false);
+}
+
+TEST_F(TestPrettyPrint, ListType) {
+ auto list_type = list(int64());
+
+ static const char* ex = R"expected([
+ [
+ null
+ ],
+ [],
+ null,
+ [
+ 4,
+ 6,
+ 7
+ ],
+ [
+ 2,
+ 3
+ ]
+])expected";
+ static const char* ex_2 = R"expected( [
+ [
+ null
+ ],
+ [],
+ null,
+ [
+ 4,
+ 6,
+ 7
+ ],
+ [
+ 2,
+ 3
+ ]
+ ])expected";
+ static const char* ex_3 = R"expected([
+ [
+ null
+ ],
+ ...
+ [
+ 2,
+ 3
+ ]
+])expected";
+
+ auto array = ArrayFromJSON(list_type, "[[null], [], null, [4, 6, 7], [2, 3]]");
+ CheckArray(*array, {0, 10}, ex);
+ CheckArray(*array, {2, 10}, ex_2);
+ CheckStream(*array, {0, 1}, ex_3);
+
+ list_type = large_list(int64());
+ array = ArrayFromJSON(list_type, "[[null], [], null, [4, 6, 7], [2, 3]]");
+ CheckArray(*array, {0, 10}, ex);
+ CheckArray(*array, {2, 10}, ex_2);
+ CheckStream(*array, {0, 1}, ex_3);
+}
+
+TEST_F(TestPrettyPrint, ListTypeNoNewlines) {
+ auto list_type = list(int64());
+ auto empty_array = ArrayFromJSON(list_type, "[]");
+ auto array = ArrayFromJSON(list_type, "[[null], [], null, [4, 5, 6, 7, 8], [2, 3]]");
+
+ PrettyPrintOptions options{};
+ options.skip_new_lines = true;
+ options.null_rep = "NA";
+ CheckArray(*empty_array, options, "[]", false);
+ CheckArray(*array, options, "[[NA],[],NA,[4,5,6,7,8],[2,3]]", false);
+
+ options.window = 2;
+ CheckArray(*empty_array, options, "[]", false);
+ CheckArray(*array, options, "[[NA],[],...,[4,5,...,7,8],[2,3]]", false);
+}
+
+TEST_F(TestPrettyPrint, MapType) {
+ auto map_type = map(utf8(), int64());
+ auto array = ArrayFromJSON(map_type, R"([
+ [["joe", 0], ["mark", null]],
+ null,
+ [["cap", 8]],
+ []
+ ])");
+
+ static const char* ex = R"expected([
+ keys:
+ [
+ "joe",
+ "mark"
+ ]
+ values:
+ [
+ 0,
+ null
+ ],
+ null,
+ keys:
+ [
+ "cap"
+ ]
+ values:
+ [
+ 8
+ ],
+ keys:
+ []
+ values:
+ []
+])expected";
+ CheckArray(*array, {0, 10}, ex);
+}
+
+TEST_F(TestPrettyPrint, FixedSizeListType) {
+ auto list_type = fixed_size_list(int32(), 3);
+ auto array = ArrayFromJSON(list_type,
+ "[[null, 0, 1], [2, 3, null], null, [4, 6, 7], [8, 9, 5]]");
+
+ CheckArray(*array, {0, 10}, R"expected([
+ [
+ null,
+ 0,
+ 1
+ ],
+ [
+ 2,
+ 3,
+ null
+ ],
+ null,
+ [
+ 4,
+ 6,
+ 7
+ ],
+ [
+ 8,
+ 9,
+ 5
+ ]
+])expected");
+ CheckStream(*array, {0, 1}, R"expected([
+ [
+ null,
+ ...
+ 1
+ ],
+ ...
+ [
+ 8,
+ ...
+ 5
+ ]
+])expected");
+}
+
+TEST_F(TestPrettyPrint, FixedSizeBinaryType) {
+ std::vector<bool> is_valid = {true, true, false, true, false};
+
+ auto type = fixed_size_binary(3);
+ auto array = ArrayFromJSON(type, "[\"foo\", \"bar\", null, \"baz\"]");
+
+ static const char* ex = "[\n 666F6F,\n 626172,\n null,\n 62617A\n]";
+ CheckArray(*array, {0, 10}, ex);
+ static const char* ex_2 = " [\n 666F6F,\n ...\n 62617A\n ]";
+ CheckArray(*array, {2, 1}, ex_2);
+}
+
+TEST_F(TestPrettyPrint, DecimalTypes) {
+ int32_t p = 19;
+ int32_t s = 4;
+
+ for (auto type : {decimal128(p, s), decimal256(p, s)}) {
+ auto array = ArrayFromJSON(type, "[\"123.4567\", \"456.7891\", null]");
+
+ static const char* ex = "[\n 123.4567,\n 456.7891,\n null\n]";
+ CheckArray(*array, {0}, ex);
+ }
+}
+
+TEST_F(TestPrettyPrint, DictionaryType) {
+ std::vector<bool> is_valid = {true, true, false, true, true, true};
+
+ std::shared_ptr<Array> dict;
+ std::vector<std::string> dict_values = {"foo", "bar", "baz"};
+ ArrayFromVector<StringType, std::string>(dict_values, &dict);
+ std::shared_ptr<DataType> dict_type = dictionary(int16(), utf8());
+
+ std::shared_ptr<Array> indices;
+ std::vector<int16_t> indices_values = {1, 2, -1, 0, 2, 0};
+ ArrayFromVector<Int16Type, int16_t>(is_valid, indices_values, &indices);
+ auto arr = std::make_shared<DictionaryArray>(dict_type, indices, dict);
+
+ static const char* expected = R"expected(
+-- dictionary:
+ [
+ "foo",
+ "bar",
+ "baz"
+ ]
+-- indices:
+ [
+ 1,
+ 2,
+ null,
+ 0,
+ 2,
+ 0
+ ])expected";
+
+ CheckArray(*arr, {0}, expected);
+}
+
+TEST_F(TestPrettyPrint, ChunkedArrayPrimitiveType) {
+ auto array = ArrayFromJSON(int32(), "[0, 1, null, 3, null]");
+ ChunkedArray chunked_array(array);
+
+ static const char* expected = R"expected([
+ [
+ 0,
+ 1,
+ null,
+ 3,
+ null
+ ]
+])expected";
+ CheckStream(chunked_array, {0}, expected);
+
+ ChunkedArray chunked_array_2({array, array});
+
+ static const char* expected_2 = R"expected([
+ [
+ 0,
+ 1,
+ null,
+ 3,
+ null
+ ],
+ [
+ 0,
+ 1,
+ null,
+ 3,
+ null
+ ]
+])expected";
+
+ CheckStream(chunked_array_2, {0}, expected_2);
+}
+
+TEST_F(TestPrettyPrint, TablePrimitive) {
+ std::shared_ptr<Field> int_field = field("column", int32());
+ auto array = ArrayFromJSON(int_field->type(), "[0, 1, null, 3, null]");
+ auto column = std::make_shared<ChunkedArray>(ArrayVector({array}));
+ std::shared_ptr<Schema> table_schema = schema({int_field});
+ std::shared_ptr<Table> table = Table::Make(table_schema, {column});
+
+ static const char* expected = R"expected(column: int32
+----
+column:
+ [
+ [
+ 0,
+ 1,
+ null,
+ 3,
+ null
+ ]
+ ]
+)expected";
+ CheckStream(*table, {0}, expected);
+}
+
+TEST_F(TestPrettyPrint, SchemaWithDictionary) {
+ std::vector<bool> is_valid = {true, true, false, true, true, true};
+
+ std::shared_ptr<Array> dict;
+ std::vector<std::string> dict_values = {"foo", "bar", "baz"};
+ ArrayFromVector<StringType, std::string>(dict_values, &dict);
+
+ auto simple = field("one", int32());
+ auto simple_dict = field("two", dictionary(int16(), utf8()));
+ auto list_of_dict = field("three", list(simple_dict));
+ auto struct_with_dict = field("four", struct_({simple, simple_dict}));
+
+ auto sch = schema({simple, simple_dict, list_of_dict, struct_with_dict});
+
+ static const char* expected = R"expected(one: int32
+two: dictionary<values=string, indices=int16, ordered=0>
+three: list<two: dictionary<values=string, indices=int16, ordered=0>>
+ child 0, two: dictionary<values=string, indices=int16, ordered=0>
+four: struct<one: int32, two: dictionary<values=string, indices=int16, ordered=0>>
+ child 0, one: int32
+ child 1, two: dictionary<values=string, indices=int16, ordered=0>)expected";
+
+ PrettyPrintOptions options;
+ Check(*sch, options, expected);
+}
+
+TEST_F(TestPrettyPrint, SchemaWithNotNull) {
+ auto simple = field("one", int32());
+ auto non_null = field("two", int32(), false);
+ auto list_simple = field("three", list(int32()));
+ auto list_non_null = field("four", list(int32()), false);
+ auto list_non_null2 = field("five", list(field("item", int32(), false)));
+
+ auto sch = schema({simple, non_null, list_simple, list_non_null, list_non_null2});
+
+ static const char* expected = R"expected(one: int32
+two: int32 not null
+three: list<item: int32>
+ child 0, item: int32
+four: list<item: int32> not null
+ child 0, item: int32
+five: list<item: int32 not null>
+ child 0, item: int32 not null)expected";
+
+ PrettyPrintOptions options;
+ Check(*sch, options, expected);
+}
+
+TEST_F(TestPrettyPrint, SchemaWithMetadata) {
+ // ARROW-7063
+ auto metadata1 = key_value_metadata({"foo1"}, {"bar1"});
+ auto metadata2 = key_value_metadata({"foo2"}, {"bar2"});
+ auto metadata3 = key_value_metadata(
+ {"foo3", "lorem"},
+ {"bar3",
+ R"(Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nulla accumsan vel
+ turpis et mollis. Aliquam tincidunt arcu id tortor blandit blandit. Donec
+ eget leo quis lectus scelerisque varius. Class aptent taciti sociosqu ad
+ litora torquent per conubia nostra, per inceptos himenaeos. Praesent
+ faucibus, diam eu volutpat iaculis, tellus est porta ligula, a efficitur
+ turpis nulla facilisis quam. Aliquam vitae lorem erat. Proin a dolor ac libero
+ dignissim mollis vitae eu mauris. Quisque posuere tellus vitae massa
+ pellentesque sagittis. Aenean feugiat, diam ac dignissim fermentum, lorem
+ sapien commodo massa, vel volutpat orci nisi eu justo. Nulla non blandit
+ sapien. Quisque pretium vestibulum urna eu vehicula.)"});
+ auto my_schema = schema(
+ {field("one", int32(), true, metadata1), field("two", utf8(), false, metadata2)},
+ metadata3);
+
+ PrettyPrintOptions options;
+ static const char* expected = R"(one: int32
+ -- field metadata --
+ foo1: 'bar1'
+two: string not null
+ -- field metadata --
+ foo2: 'bar2'
+-- schema metadata --
+foo3: 'bar3'
+lorem: 'Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nulla ac' + 737)";
+ Check(*my_schema, options, expected);
+
+ static const char* expected_verbose = R"(one: int32
+ -- field metadata --
+ foo1: 'bar1'
+two: string not null
+ -- field metadata --
+ foo2: 'bar2'
+-- schema metadata --
+foo3: 'bar3'
+lorem: 'Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nulla accumsan vel
+ turpis et mollis. Aliquam tincidunt arcu id tortor blandit blandit. Donec
+ eget leo quis lectus scelerisque varius. Class aptent taciti sociosqu ad
+ litora torquent per conubia nostra, per inceptos himenaeos. Praesent
+ faucibus, diam eu volutpat iaculis, tellus est porta ligula, a efficitur
+ turpis nulla facilisis quam. Aliquam vitae lorem erat. Proin a dolor ac libero
+ dignissim mollis vitae eu mauris. Quisque posuere tellus vitae massa
+ pellentesque sagittis. Aenean feugiat, diam ac dignissim fermentum, lorem
+ sapien commodo massa, vel volutpat orci nisi eu justo. Nulla non blandit
+ sapien. Quisque pretium vestibulum urna eu vehicula.')";
+ options.truncate_metadata = false;
+ Check(*my_schema, options, expected_verbose);
+
+ // Metadata that exactly fits
+ auto metadata4 =
+ key_value_metadata({"key"}, {("valuexxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
+ "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")});
+ my_schema = schema({field("f0", int32())}, metadata4);
+ static const char* expected_fits = R"(f0: int32
+-- schema metadata --
+key: 'valuexxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')";
+ options.truncate_metadata = false;
+ Check(*my_schema, options, expected_fits);
+
+ // A large key
+ auto metadata5 = key_value_metadata({"0123456789012345678901234567890123456789"},
+ {("valuexxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
+ "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")});
+ my_schema = schema({field("f0", int32())}, metadata5);
+ static const char* expected_big_key = R"(f0: int32
+-- schema metadata --
+0123456789012345678901234567890123456789: 'valuexxxxxxxxxxxxxxxxxxxxxxxxx' + 40)";
+ options.truncate_metadata = true;
+ Check(*my_schema, options, expected_big_key);
+}
+
+TEST_F(TestPrettyPrint, SchemaIndentation) {
+ // ARROW-6159
+ auto simple = field("one", int32());
+ auto non_null = field("two", int32(), false);
+ auto sch = schema({simple, non_null});
+
+ static const char* expected = R"expected( one: int32
+ two: int32 not null)expected";
+
+ PrettyPrintOptions options(/*indent=*/4);
+ Check(*sch, options, expected);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/public_api_test.cc b/src/arrow/cpp/src/arrow/public_api_test.cc
new file mode 100644
index 000000000..eba14ec66
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/public_api_test.cc
@@ -0,0 +1,93 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <sstream>
+#include <string>
+
+#include "arrow/config.h"
+
+// Include various "api.h" entrypoints and check they don't leak internal symbols
+
+#include "arrow/api.h" // IWYU pragma: keep
+#include "arrow/io/api.h" // IWYU pragma: keep
+#include "arrow/ipc/api.h" // IWYU pragma: keep
+
+#ifdef ARROW_COMPUTE
+#include "arrow/compute/api.h" // IWYU pragma: keep
+#endif
+
+#ifdef ARROW_CSV
+#include "arrow/csv/api.h" // IWYU pragma: keep
+#endif
+
+#ifdef ARROW_DATASET
+#include "arrow/dataset/api.h" // IWYU pragma: keep
+#endif
+
+#ifdef ARROW_FILESYSTEM
+#include "arrow/filesystem/api.h" // IWYU pragma: keep
+#endif
+
+#ifdef ARROW_FLIGHT
+#include "arrow/flight/api.h" // IWYU pragma: keep
+#endif
+
+#ifdef ARROW_JSON
+#include "arrow/json/api.h" // IWYU pragma: keep
+#endif
+
+#ifdef ARROW_PYTHON
+#include "arrow/python/api.h" // IWYU pragma: keep
+#endif
+
+#ifdef DCHECK
+#error "DCHECK should not be visible from Arrow public headers."
+#endif
+
+#ifdef ASSIGN_OR_RAISE
+#error "ASSIGN_OR_RAISE should not be visible from Arrow public headers."
+#endif
+
+#ifdef XSIMD_VERSION_MAJOR
+#error "xsimd should not be visible from Arrow public headers."
+#endif
+
+#ifdef HAS_CHRONO_ROUNDING
+#error "arrow::vendored::date should not be visible from Arrow public headers."
+#endif
+
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
+
+namespace arrow {
+
+TEST(Misc, BuildInfo) {
+ const auto& info = GetBuildInfo();
+ // The runtime version (GetBuildInfo) should have the same major number as the
+ // build-time version (ARROW_VERSION), but may have a greater minor / patch number.
+ ASSERT_EQ(info.version_major, ARROW_VERSION_MAJOR);
+ ASSERT_GE(info.version_minor, ARROW_VERSION_MINOR);
+ ASSERT_GE(info.version_patch, ARROW_VERSION_PATCH);
+ ASSERT_GE(info.version, ARROW_VERSION);
+ ASSERT_LT(info.version, ARROW_VERSION + 1000 * 1000); // Same major version
+ std::stringstream ss;
+ ss << info.version_major << "." << info.version_minor << "." << info.version_patch;
+ ASSERT_THAT(info.version_string, ::testing::HasSubstr(ss.str()));
+ ASSERT_THAT(info.full_so_version, ::testing::HasSubstr(info.so_version));
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/ArrowPythonConfig.cmake.in b/src/arrow/cpp/src/arrow/python/ArrowPythonConfig.cmake.in
new file mode 100644
index 000000000..4cae0c2df
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/ArrowPythonConfig.cmake.in
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This config sets the following variables in your project::
+#
+# ArrowPython_FOUND - true if Arrow Python found on the system
+#
+# This config sets the following targets in your project::
+#
+# arrow_python_shared - for linked as shared library if shared library is built
+# arrow_python_static - for linked as static library if static library is built
+
+@PACKAGE_INIT@
+
+include(CMakeFindDependencyMacro)
+find_dependency(Arrow)
+
+# Load targets only once. If we load targets multiple times, CMake reports
+# already existent target error.
+if(NOT (TARGET arrow_python_shared OR TARGET arrow_python_static))
+ include("${CMAKE_CURRENT_LIST_DIR}/ArrowPythonTargets.cmake")
+endif()
diff --git a/src/arrow/cpp/src/arrow/python/ArrowPythonFlightConfig.cmake.in b/src/arrow/cpp/src/arrow/python/ArrowPythonFlightConfig.cmake.in
new file mode 100644
index 000000000..5dc9deec5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/ArrowPythonFlightConfig.cmake.in
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This config sets the following variables in your project::
+#
+# ArrowPythonFlight_FOUND - true if Arrow Python Flight found on the system
+#
+# This config sets the following targets in your project::
+#
+# arrow_python_flight_shared - for linked as shared library if shared library is built
+# arrow_python_flight_static - for linked as static library if static library is built
+
+@PACKAGE_INIT@
+
+include(CMakeFindDependencyMacro)
+find_dependency(ArrowFlight)
+find_dependency(ArrowPython)
+
+# Load targets only once. If we load targets multiple times, CMake reports
+# already existent target error.
+if(NOT (TARGET arrow_python_flight_shared OR TARGET arrow_python_flight_static))
+ include("${CMAKE_CURRENT_LIST_DIR}/ArrowPythonFlightTargets.cmake")
+endif()
diff --git a/src/arrow/cpp/src/arrow/python/CMakeLists.txt b/src/arrow/cpp/src/arrow/python/CMakeLists.txt
new file mode 100644
index 000000000..40f351b56
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/CMakeLists.txt
@@ -0,0 +1,184 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#
+# arrow_python
+#
+
+find_package(Python3Alt REQUIRED)
+
+add_custom_target(arrow_python-all)
+add_custom_target(arrow_python)
+add_custom_target(arrow_python-tests)
+add_dependencies(arrow_python-all arrow_python arrow_python-tests)
+
+set(ARROW_PYTHON_SRCS
+ arrow_to_pandas.cc
+ benchmark.cc
+ common.cc
+ datetime.cc
+ decimal.cc
+ deserialize.cc
+ extension_type.cc
+ helpers.cc
+ inference.cc
+ init.cc
+ io.cc
+ ipc.cc
+ numpy_convert.cc
+ numpy_to_arrow.cc
+ python_to_arrow.cc
+ pyarrow.cc
+ serialize.cc)
+
+set_source_files_properties(init.cc PROPERTIES SKIP_PRECOMPILE_HEADERS ON
+ SKIP_UNITY_BUILD_INCLUSION ON)
+
+if(ARROW_FILESYSTEM)
+ list(APPEND ARROW_PYTHON_SRCS filesystem.cc)
+endif()
+
+set(ARROW_PYTHON_DEPENDENCIES arrow_dependencies)
+
+if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
+ set_property(SOURCE pyarrow.cc
+ APPEND_STRING
+ PROPERTY COMPILE_FLAGS " -Wno-cast-qual ")
+endif()
+
+set(ARROW_PYTHON_SHARED_LINK_LIBS arrow_shared)
+if(WIN32)
+ list(APPEND ARROW_PYTHON_SHARED_LINK_LIBS ${PYTHON_LIBRARIES} ${PYTHON_OTHER_LIBS})
+endif()
+
+set(ARROW_PYTHON_INCLUDES ${NUMPY_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
+
+add_arrow_lib(arrow_python
+ CMAKE_PACKAGE_NAME
+ ArrowPython
+ PKG_CONFIG_NAME
+ arrow-python
+ SOURCES
+ ${ARROW_PYTHON_SRCS}
+ PRECOMPILED_HEADERS
+ "$<$<COMPILE_LANGUAGE:CXX>:arrow/python/pch.h>"
+ OUTPUTS
+ ARROW_PYTHON_LIBRARIES
+ DEPENDENCIES
+ ${ARROW_PYTHON_DEPENDENCIES}
+ SHARED_LINK_FLAGS
+ ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt
+ SHARED_LINK_LIBS
+ ${ARROW_PYTHON_SHARED_LINK_LIBS}
+ STATIC_LINK_LIBS
+ ${PYTHON_OTHER_LIBS}
+ EXTRA_INCLUDES
+ "${ARROW_PYTHON_INCLUDES}")
+
+add_dependencies(arrow_python ${ARROW_PYTHON_LIBRARIES})
+
+foreach(LIB_TARGET ${ARROW_PYTHON_LIBRARIES})
+ target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_PYTHON_EXPORTING)
+endforeach()
+
+if(ARROW_BUILD_STATIC AND MSVC)
+ target_compile_definitions(arrow_python_static PUBLIC ARROW_STATIC)
+endif()
+
+if(ARROW_FLIGHT AND ARROW_BUILD_SHARED)
+ # Must link to shared libarrow_flight: we don't want to link more than one
+ # copy of gRPC into the eventual Cython shared object, otherwise gRPC calls
+ # fail with weird errors due to multiple copies of global static state (The
+ # other solution is to link gRPC shared everywhere instead of statically only
+ # in Flight)
+ add_arrow_lib(arrow_python_flight
+ CMAKE_PACKAGE_NAME
+ ArrowPythonFlight
+ PKG_CONFIG_NAME
+ arrow-python-flight
+ SOURCES
+ flight.cc
+ OUTPUTS
+ ARROW_PYFLIGHT_LIBRARIES
+ DEPENDENCIES
+ flight_grpc_gen
+ SHARED_LINK_FLAGS
+ ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt
+ SHARED_LINK_LIBS
+ arrow_python_shared
+ arrow_flight_shared
+ STATIC_LINK_LIBS
+ ${PYTHON_OTHER_LIBS}
+ EXTRA_INCLUDES
+ "${ARROW_PYTHON_INCLUDES}")
+
+ add_dependencies(arrow_python ${ARROW_PYFLIGHT_LIBRARIES})
+
+ foreach(LIB_TARGET ${ARROW_PYFLIGHT_LIBRARIES})
+ target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_PYFLIGHT_EXPORTING)
+ endforeach()
+
+ if(ARROW_BUILD_STATIC AND MSVC)
+ target_compile_definitions(arrow_python_flight_static PUBLIC ARROW_STATIC)
+ endif()
+endif()
+
+if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
+ # Clang, be quiet. Python C API has lots of macros
+ set_property(SOURCE ${ARROW_PYTHON_SRCS}
+ APPEND_STRING
+ PROPERTY COMPILE_FLAGS -Wno-parentheses-equality)
+endif()
+
+arrow_install_all_headers("arrow/python")
+
+# ----------------------------------------------------------------------
+
+if(ARROW_BUILD_TESTS)
+ add_library(arrow_python_test_main STATIC util/test_main.cc)
+
+ target_link_libraries(arrow_python_test_main GTest::gtest)
+ target_include_directories(arrow_python_test_main SYSTEM
+ PUBLIC ${ARROW_PYTHON_INCLUDES})
+
+ if(APPLE)
+ target_link_libraries(arrow_python_test_main ${CMAKE_DL_LIBS})
+ set_target_properties(arrow_python_test_main PROPERTIES LINK_FLAGS
+ "-undefined dynamic_lookup")
+ elseif(NOT MSVC)
+ target_link_libraries(arrow_python_test_main pthread ${CMAKE_DL_LIBS})
+ endif()
+
+ if(ARROW_TEST_LINKAGE STREQUAL shared)
+ set(ARROW_PYTHON_TEST_LINK_LIBS arrow_python_test_main arrow_python_shared
+ arrow_testing_shared arrow_shared)
+ else()
+ set(ARROW_PYTHON_TEST_LINK_LIBS arrow_python_test_main arrow_python_static
+ arrow_testing_static arrow_static)
+ endif()
+
+ add_arrow_test(python_test
+ STATIC_LINK_LIBS
+ "${ARROW_PYTHON_TEST_LINK_LIBS}"
+ EXTRA_LINK_LIBS
+ ${PYTHON_LIBRARIES}
+ EXTRA_INCLUDES
+ "${ARROW_PYTHON_INCLUDES}"
+ LABELS
+ "arrow_python-tests"
+ NO_VALGRIND)
+endif()
diff --git a/src/arrow/cpp/src/arrow/python/api.h b/src/arrow/cpp/src/arrow/python/api.h
new file mode 100644
index 000000000..a0b13d6d1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/api.h
@@ -0,0 +1,30 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/python/arrow_to_pandas.h"
+#include "arrow/python/common.h"
+#include "arrow/python/datetime.h"
+#include "arrow/python/deserialize.h"
+#include "arrow/python/helpers.h"
+#include "arrow/python/inference.h"
+#include "arrow/python/io.h"
+#include "arrow/python/numpy_convert.h"
+#include "arrow/python/numpy_to_arrow.h"
+#include "arrow/python/python_to_arrow.h"
+#include "arrow/python/serialize.h"
diff --git a/src/arrow/cpp/src/arrow/python/arrow-python-flight.pc.in b/src/arrow/cpp/src/arrow/python/arrow-python-flight.pc.in
new file mode 100644
index 000000000..fabed1b2d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/arrow-python-flight.pc.in
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+Name: Apache Arrow Python Flight
+Description: Python integration library for Apache Arrow Flight
+Version: @ARROW_VERSION@
+Requires: arrow-python arrow-flight
+Libs: -L${libdir} -larrow_python_flight
diff --git a/src/arrow/cpp/src/arrow/python/arrow-python.pc.in b/src/arrow/cpp/src/arrow/python/arrow-python.pc.in
new file mode 100644
index 000000000..529395198
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/arrow-python.pc.in
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+Name: Apache Arrow Python
+Description: Python integration library for Apache Arrow
+Version: @ARROW_VERSION@
+Requires: arrow
+Libs: -L${libdir} -larrow_python
+Cflags: -I${includedir} -I@PYTHON_INCLUDE_DIRS@
diff --git a/src/arrow/cpp/src/arrow/python/arrow_to_pandas.cc b/src/arrow/cpp/src/arrow/python/arrow_to_pandas.cc
new file mode 100644
index 000000000..3f386ad52
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/arrow_to_pandas.cc
@@ -0,0 +1,2322 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Functions for pandas conversion via NumPy
+
+#include "arrow/python/arrow_to_pandas.h"
+#include "arrow/python/numpy_interop.h" // IWYU pragma: expand
+
+#include <cmath>
+#include <cstdint>
+#include <iostream>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/datum.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/hashing.h"
+#include "arrow/util/int_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/parallel.h"
+#include "arrow/util/string_view.h"
+#include "arrow/visitor_inline.h"
+
+#include "arrow/compute/api.h"
+
+#include "arrow/python/arrow_to_python_internal.h"
+#include "arrow/python/common.h"
+#include "arrow/python/datetime.h"
+#include "arrow/python/decimal.h"
+#include "arrow/python/helpers.h"
+#include "arrow/python/numpy_convert.h"
+#include "arrow/python/numpy_internal.h"
+#include "arrow/python/pyarrow.h"
+#include "arrow/python/python_to_arrow.h"
+#include "arrow/python/type_traits.h"
+
+namespace arrow {
+
+class MemoryPool;
+
+using internal::checked_cast;
+using internal::CheckIndexBounds;
+using internal::GetByteWidth;
+using internal::OptionalParallelFor;
+
+namespace py {
+namespace {
+
+// Fix options for conversion of an inner (child) array.
+PandasOptions MakeInnerOptions(PandasOptions options) {
+ // Make sure conversion of inner dictionary arrays always returns an array,
+ // not a dict {'indices': array, 'dictionary': array, 'ordered': bool}
+ options.decode_dictionaries = true;
+ options.categorical_columns.clear();
+ options.strings_to_categorical = false;
+
+ // In ARROW-7723, we found as a result of ARROW-3789 that second
+ // through microsecond resolution tz-aware timestamps were being promoted to
+ // use the DATETIME_NANO_TZ conversion path, yielding a datetime64[ns] NumPy
+ // array in this function. PyArray_GETITEM returns datetime.datetime for
+ // units second through microsecond but PyLong for nanosecond (because
+ // datetime.datetime does not support nanoseconds).
+ // We force the object conversion to preserve the value of the timezone.
+ // Nanoseconds are returned as integers.
+ options.coerce_temporal_nanoseconds = false;
+
+ return options;
+}
+
+// ----------------------------------------------------------------------
+// PyCapsule code for setting ndarray base to reference C++ object
+
+struct ArrayCapsule {
+ std::shared_ptr<Array> array;
+};
+
+struct BufferCapsule {
+ std::shared_ptr<Buffer> buffer;
+};
+
+void ArrayCapsule_Destructor(PyObject* capsule) {
+ delete reinterpret_cast<ArrayCapsule*>(PyCapsule_GetPointer(capsule, "arrow::Array"));
+}
+
+void BufferCapsule_Destructor(PyObject* capsule) {
+ delete reinterpret_cast<BufferCapsule*>(PyCapsule_GetPointer(capsule, "arrow::Buffer"));
+}
+
+// ----------------------------------------------------------------------
+// pandas 0.x DataFrame conversion internals
+
+using internal::arrow_traits;
+using internal::npy_traits;
+
+template <typename T>
+struct WrapBytes {};
+
+template <>
+struct WrapBytes<StringType> {
+ static inline PyObject* Wrap(const char* data, int64_t length) {
+ return PyUnicode_FromStringAndSize(data, length);
+ }
+};
+
+template <>
+struct WrapBytes<LargeStringType> {
+ static inline PyObject* Wrap(const char* data, int64_t length) {
+ return PyUnicode_FromStringAndSize(data, length);
+ }
+};
+
+template <>
+struct WrapBytes<BinaryType> {
+ static inline PyObject* Wrap(const char* data, int64_t length) {
+ return PyBytes_FromStringAndSize(data, length);
+ }
+};
+
+template <>
+struct WrapBytes<LargeBinaryType> {
+ static inline PyObject* Wrap(const char* data, int64_t length) {
+ return PyBytes_FromStringAndSize(data, length);
+ }
+};
+
+template <>
+struct WrapBytes<FixedSizeBinaryType> {
+ static inline PyObject* Wrap(const char* data, int64_t length) {
+ return PyBytes_FromStringAndSize(data, length);
+ }
+};
+
+static inline bool ListTypeSupported(const DataType& type) {
+ switch (type.id()) {
+ case Type::BOOL:
+ case Type::UINT8:
+ case Type::INT8:
+ case Type::UINT16:
+ case Type::INT16:
+ case Type::UINT32:
+ case Type::INT32:
+ case Type::INT64:
+ case Type::UINT64:
+ case Type::FLOAT:
+ case Type::DOUBLE:
+ case Type::DECIMAL128:
+ case Type::DECIMAL256:
+ case Type::BINARY:
+ case Type::LARGE_BINARY:
+ case Type::STRING:
+ case Type::LARGE_STRING:
+ case Type::DATE32:
+ case Type::DATE64:
+ case Type::STRUCT:
+ case Type::TIME32:
+ case Type::TIME64:
+ case Type::TIMESTAMP:
+ case Type::DURATION:
+ case Type::DICTIONARY:
+ case Type::NA: // empty list
+ // The above types are all supported.
+ return true;
+ case Type::FIXED_SIZE_LIST:
+ case Type::LIST:
+ case Type::LARGE_LIST: {
+ const auto& list_type = checked_cast<const BaseListType&>(type);
+ return ListTypeSupported(*list_type.value_type());
+ }
+ default:
+ break;
+ }
+ return false;
+}
+
+Status CapsulizeArray(const std::shared_ptr<Array>& arr, PyObject** out) {
+ auto capsule = new ArrayCapsule{{arr}};
+ *out = PyCapsule_New(reinterpret_cast<void*>(capsule), "arrow::Array",
+ &ArrayCapsule_Destructor);
+ if (*out == nullptr) {
+ delete capsule;
+ RETURN_IF_PYERROR();
+ }
+ return Status::OK();
+}
+
+Status CapsulizeBuffer(const std::shared_ptr<Buffer>& buffer, PyObject** out) {
+ auto capsule = new BufferCapsule{{buffer}};
+ *out = PyCapsule_New(reinterpret_cast<void*>(capsule), "arrow::Buffer",
+ &BufferCapsule_Destructor);
+ if (*out == nullptr) {
+ delete capsule;
+ RETURN_IF_PYERROR();
+ }
+ return Status::OK();
+}
+
+Status SetNdarrayBase(PyArrayObject* arr, PyObject* base) {
+ if (PyArray_SetBaseObject(arr, base) == -1) {
+ // Error occurred, trust that SetBaseObject sets the error state
+ Py_XDECREF(base);
+ RETURN_IF_PYERROR();
+ }
+ return Status::OK();
+}
+
+Status SetBufferBase(PyArrayObject* arr, const std::shared_ptr<Buffer>& buffer) {
+ PyObject* base;
+ RETURN_NOT_OK(CapsulizeBuffer(buffer, &base));
+ return SetNdarrayBase(arr, base);
+}
+
+inline void set_numpy_metadata(int type, const DataType* datatype, PyArray_Descr* out) {
+ auto metadata = reinterpret_cast<PyArray_DatetimeDTypeMetaData*>(out->c_metadata);
+ if (type == NPY_DATETIME) {
+ if (datatype->id() == Type::TIMESTAMP) {
+ const auto& timestamp_type = checked_cast<const TimestampType&>(*datatype);
+ metadata->meta.base = internal::NumPyFrequency(timestamp_type.unit());
+ } else {
+ DCHECK(false) << "NPY_DATETIME views only supported for Arrow TIMESTAMP types";
+ }
+ } else if (type == NPY_TIMEDELTA) {
+ DCHECK_EQ(datatype->id(), Type::DURATION);
+ const auto& duration_type = checked_cast<const DurationType&>(*datatype);
+ metadata->meta.base = internal::NumPyFrequency(duration_type.unit());
+ }
+}
+
+Status PyArray_NewFromPool(int nd, npy_intp* dims, PyArray_Descr* descr, MemoryPool* pool,
+ PyObject** out) {
+ // ARROW-6570: Allocate memory from MemoryPool for a couple reasons
+ //
+ // * Track allocations
+ // * Get better performance through custom allocators
+ int64_t total_size = descr->elsize;
+ for (int i = 0; i < nd; ++i) {
+ total_size *= dims[i];
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(total_size, pool));
+ *out = PyArray_NewFromDescr(&PyArray_Type, descr, nd, dims,
+ /*strides=*/nullptr,
+ /*data=*/buffer->mutable_data(),
+ /*flags=*/NPY_ARRAY_CARRAY | NPY_ARRAY_WRITEABLE,
+ /*obj=*/nullptr);
+ if (*out == nullptr) {
+ RETURN_IF_PYERROR();
+ // Trust that error set if NULL returned
+ }
+ return SetBufferBase(reinterpret_cast<PyArrayObject*>(*out), std::move(buffer));
+}
+
+template <typename T = void>
+inline const T* GetPrimitiveValues(const Array& arr) {
+ if (arr.length() == 0) {
+ return nullptr;
+ }
+ const int elsize = GetByteWidth(*arr.type());
+ const auto& prim_arr = checked_cast<const PrimitiveArray&>(arr);
+ return reinterpret_cast<const T*>(prim_arr.values()->data() + arr.offset() * elsize);
+}
+
+Status MakeNumPyView(std::shared_ptr<Array> arr, PyObject* py_ref, int npy_type, int ndim,
+ npy_intp* dims, PyObject** out) {
+ PyAcquireGIL lock;
+
+ PyArray_Descr* descr = internal::GetSafeNumPyDtype(npy_type);
+ set_numpy_metadata(npy_type, arr->type().get(), descr);
+ PyObject* result = PyArray_NewFromDescr(
+ &PyArray_Type, descr, ndim, dims, /*strides=*/nullptr,
+ const_cast<void*>(GetPrimitiveValues(*arr)), /*flags=*/0, nullptr);
+ PyArrayObject* np_arr = reinterpret_cast<PyArrayObject*>(result);
+ if (np_arr == nullptr) {
+ // Error occurred, trust that error set
+ return Status::OK();
+ }
+
+ PyObject* base;
+ if (py_ref == nullptr) {
+ // Capsule will be owned by the ndarray, no incref necessary. See
+ // ARROW-1973
+ RETURN_NOT_OK(CapsulizeArray(arr, &base));
+ } else {
+ Py_INCREF(py_ref);
+ base = py_ref;
+ }
+ RETURN_NOT_OK(SetNdarrayBase(np_arr, base));
+
+ // Do not allow Arrow data to be mutated
+ PyArray_CLEARFLAGS(np_arr, NPY_ARRAY_WRITEABLE);
+ *out = result;
+ return Status::OK();
+}
+
+class PandasWriter {
+ public:
+ enum type {
+ OBJECT,
+ UINT8,
+ INT8,
+ UINT16,
+ INT16,
+ UINT32,
+ INT32,
+ UINT64,
+ INT64,
+ HALF_FLOAT,
+ FLOAT,
+ DOUBLE,
+ BOOL,
+ DATETIME_DAY,
+ DATETIME_SECOND,
+ DATETIME_MILLI,
+ DATETIME_MICRO,
+ DATETIME_NANO,
+ DATETIME_NANO_TZ,
+ TIMEDELTA_SECOND,
+ TIMEDELTA_MILLI,
+ TIMEDELTA_MICRO,
+ TIMEDELTA_NANO,
+ CATEGORICAL,
+ EXTENSION
+ };
+
+ PandasWriter(const PandasOptions& options, int64_t num_rows, int num_columns)
+ : options_(options), num_rows_(num_rows), num_columns_(num_columns) {}
+ virtual ~PandasWriter() {}
+
+ void SetBlockData(PyObject* arr) {
+ block_arr_.reset(arr);
+ block_data_ =
+ reinterpret_cast<uint8_t*>(PyArray_DATA(reinterpret_cast<PyArrayObject*>(arr)));
+ }
+
+ /// \brief Either copy or wrap single array to create pandas-compatible array
+ /// for Series or DataFrame. num_columns_ can only be 1. Will try to zero
+ /// copy if possible (or error if not possible and zero_copy_only=True)
+ virtual Status TransferSingle(std::shared_ptr<ChunkedArray> data, PyObject* py_ref) = 0;
+
+ /// \brief Copy ChunkedArray into a multi-column block
+ virtual Status CopyInto(std::shared_ptr<ChunkedArray> data, int64_t rel_placement) = 0;
+
+ Status EnsurePlacementAllocated() {
+ std::lock_guard<std::mutex> guard(allocation_lock_);
+ if (placement_data_ != nullptr) {
+ return Status::OK();
+ }
+ PyAcquireGIL lock;
+
+ npy_intp placement_dims[1] = {num_columns_};
+ PyObject* placement_arr = PyArray_SimpleNew(1, placement_dims, NPY_INT64);
+ RETURN_IF_PYERROR();
+ placement_arr_.reset(placement_arr);
+ placement_data_ = reinterpret_cast<int64_t*>(
+ PyArray_DATA(reinterpret_cast<PyArrayObject*>(placement_arr)));
+ return Status::OK();
+ }
+
+ Status EnsureAllocated() {
+ std::lock_guard<std::mutex> guard(allocation_lock_);
+ if (block_data_ != nullptr) {
+ return Status::OK();
+ }
+ RETURN_NOT_OK(Allocate());
+ return Status::OK();
+ }
+
+ virtual bool CanZeroCopy(const ChunkedArray& data) const { return false; }
+
+ virtual Status Write(std::shared_ptr<ChunkedArray> data, int64_t abs_placement,
+ int64_t rel_placement) {
+ RETURN_NOT_OK(EnsurePlacementAllocated());
+ if (num_columns_ == 1 && options_.allow_zero_copy_blocks) {
+ RETURN_NOT_OK(TransferSingle(data, /*py_ref=*/nullptr));
+ } else {
+ RETURN_NOT_OK(
+ CheckNoZeroCopy("Cannot do zero copy conversion into "
+ "multi-column DataFrame block"));
+ RETURN_NOT_OK(EnsureAllocated());
+ RETURN_NOT_OK(CopyInto(data, rel_placement));
+ }
+ placement_data_[rel_placement] = abs_placement;
+ return Status::OK();
+ }
+
+ virtual Status GetDataFrameResult(PyObject** out) {
+ PyObject* result = PyDict_New();
+ RETURN_IF_PYERROR();
+
+ PyObject* block;
+ RETURN_NOT_OK(GetResultBlock(&block));
+
+ PyDict_SetItemString(result, "block", block);
+ PyDict_SetItemString(result, "placement", placement_arr_.obj());
+
+ RETURN_NOT_OK(AddResultMetadata(result));
+ *out = result;
+ return Status::OK();
+ }
+
+ // Caller steals the reference to this object
+ virtual Status GetSeriesResult(PyObject** out) {
+ RETURN_NOT_OK(MakeBlock1D());
+ // Caller owns the object now
+ *out = block_arr_.detach();
+ return Status::OK();
+ }
+
+ protected:
+ virtual Status AddResultMetadata(PyObject* result) { return Status::OK(); }
+
+ Status MakeBlock1D() {
+ // For Series or for certain DataFrame block types, we need to shape to a
+ // 1D array when there is only one column
+ PyAcquireGIL lock;
+
+ DCHECK_EQ(1, num_columns_);
+
+ npy_intp new_dims[1] = {static_cast<npy_intp>(num_rows_)};
+ PyArray_Dims dims;
+ dims.ptr = new_dims;
+ dims.len = 1;
+
+ PyObject* reshaped = PyArray_Newshape(
+ reinterpret_cast<PyArrayObject*>(block_arr_.obj()), &dims, NPY_ANYORDER);
+ RETURN_IF_PYERROR();
+
+ // ARROW-8801: Here a PyArrayObject is created that is not being managed by
+ // any OwnedRef object. This object is then put in the resulting object
+ // with PyDict_SetItemString, which increments the reference count, so a
+ // memory leak ensues. There are several ways to fix the memory leak but a
+ // simple one is to put the reshaped 1D block array in this OwnedRefNoGIL
+ // so it will be correctly decref'd when this class is destructed.
+ block_arr_.reset(reshaped);
+ return Status::OK();
+ }
+
+ virtual Status GetResultBlock(PyObject** out) {
+ *out = block_arr_.obj();
+ return Status::OK();
+ }
+
+ Status CheckNoZeroCopy(const std::string& message) {
+ if (options_.zero_copy_only) {
+ return Status::Invalid(message);
+ }
+ return Status::OK();
+ }
+
+ Status CheckNotZeroCopyOnly(const ChunkedArray& data) {
+ if (options_.zero_copy_only) {
+ return Status::Invalid("Needed to copy ", data.num_chunks(), " chunks with ",
+ data.null_count(), " nulls, but zero_copy_only was True");
+ }
+ return Status::OK();
+ }
+
+ virtual Status Allocate() {
+ return Status::NotImplemented("Override Allocate in subclasses");
+ }
+
+ Status AllocateNDArray(int npy_type, int ndim = 2) {
+ PyAcquireGIL lock;
+
+ PyObject* block_arr;
+ npy_intp block_dims[2] = {0, 0};
+
+ if (ndim == 2) {
+ block_dims[0] = num_columns_;
+ block_dims[1] = num_rows_;
+ } else {
+ block_dims[0] = num_rows_;
+ }
+ PyArray_Descr* descr = internal::GetSafeNumPyDtype(npy_type);
+ if (PyDataType_REFCHK(descr)) {
+ // ARROW-6876: if the array has refcounted items, let Numpy
+ // own the array memory so as to decref elements on array destruction
+ block_arr = PyArray_SimpleNewFromDescr(ndim, block_dims, descr);
+ RETURN_IF_PYERROR();
+ } else {
+ RETURN_NOT_OK(
+ PyArray_NewFromPool(ndim, block_dims, descr, options_.pool, &block_arr));
+ }
+
+ SetBlockData(block_arr);
+ return Status::OK();
+ }
+
+ void SetDatetimeUnit(NPY_DATETIMEUNIT unit) {
+ PyAcquireGIL lock;
+ auto date_dtype = reinterpret_cast<PyArray_DatetimeDTypeMetaData*>(
+ PyArray_DESCR(reinterpret_cast<PyArrayObject*>(block_arr_.obj()))->c_metadata);
+ date_dtype->meta.base = unit;
+ }
+
+ PandasOptions options_;
+
+ std::mutex allocation_lock_;
+
+ int64_t num_rows_;
+ int num_columns_;
+
+ OwnedRefNoGIL block_arr_;
+ uint8_t* block_data_ = nullptr;
+
+ // ndarray<int32>
+ OwnedRefNoGIL placement_arr_;
+ int64_t* placement_data_ = nullptr;
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(PandasWriter);
+};
+
+template <typename InType, typename OutType>
+inline void ConvertIntegerWithNulls(const PandasOptions& options,
+ const ChunkedArray& data, OutType* out_values) {
+ for (int c = 0; c < data.num_chunks(); c++) {
+ const auto& arr = *data.chunk(c);
+ const InType* in_values = GetPrimitiveValues<InType>(arr);
+ // Upcast to double, set NaN as appropriate
+
+ for (int i = 0; i < arr.length(); ++i) {
+ *out_values++ =
+ arr.IsNull(i) ? static_cast<OutType>(NAN) : static_cast<OutType>(in_values[i]);
+ }
+ }
+}
+
+template <typename T>
+inline void ConvertIntegerNoNullsSameType(const PandasOptions& options,
+ const ChunkedArray& data, T* out_values) {
+ for (int c = 0; c < data.num_chunks(); c++) {
+ const auto& arr = *data.chunk(c);
+ if (arr.length() > 0) {
+ const T* in_values = GetPrimitiveValues<T>(arr);
+ memcpy(out_values, in_values, sizeof(T) * arr.length());
+ out_values += arr.length();
+ }
+ }
+}
+
+template <typename InType, typename OutType>
+inline void ConvertIntegerNoNullsCast(const PandasOptions& options,
+ const ChunkedArray& data, OutType* out_values) {
+ for (int c = 0; c < data.num_chunks(); c++) {
+ const auto& arr = *data.chunk(c);
+ const InType* in_values = GetPrimitiveValues<InType>(arr);
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ *out_values = in_values[i];
+ }
+ }
+}
+
+template <typename T, typename Enable = void>
+struct MemoizationTraits {
+ using Scalar = typename T::c_type;
+};
+
+template <typename T>
+struct MemoizationTraits<T, enable_if_has_string_view<T>> {
+ // For binary, we memoize string_view as a scalar value to avoid having to
+ // unnecessarily copy the memory into the memo table data structure
+ using Scalar = util::string_view;
+};
+
+// Generic Array -> PyObject** converter that handles object deduplication, if
+// requested
+template <typename Type, typename WrapFunction>
+inline Status ConvertAsPyObjects(const PandasOptions& options, const ChunkedArray& data,
+ WrapFunction&& wrap_func, PyObject** out_values) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ using Scalar = typename MemoizationTraits<Type>::Scalar;
+
+ ::arrow::internal::ScalarMemoTable<Scalar> memo_table(options.pool);
+ std::vector<PyObject*> unique_values;
+ int32_t memo_size = 0;
+
+ auto WrapMemoized = [&](const Scalar& value, PyObject** out_values) {
+ int32_t memo_index;
+ RETURN_NOT_OK(memo_table.GetOrInsert(value, &memo_index));
+ if (memo_index == memo_size) {
+ // New entry
+ RETURN_NOT_OK(wrap_func(value, out_values));
+ unique_values.push_back(*out_values);
+ ++memo_size;
+ } else {
+ // Duplicate entry
+ Py_INCREF(unique_values[memo_index]);
+ *out_values = unique_values[memo_index];
+ }
+ return Status::OK();
+ };
+
+ auto WrapUnmemoized = [&](const Scalar& value, PyObject** out_values) {
+ return wrap_func(value, out_values);
+ };
+
+ for (int c = 0; c < data.num_chunks(); c++) {
+ const auto& arr = arrow::internal::checked_cast<const ArrayType&>(*data.chunk(c));
+ if (options.deduplicate_objects) {
+ RETURN_NOT_OK(internal::WriteArrayObjects(arr, WrapMemoized, out_values));
+ } else {
+ RETURN_NOT_OK(internal::WriteArrayObjects(arr, WrapUnmemoized, out_values));
+ }
+ out_values += arr.length();
+ }
+ return Status::OK();
+}
+
+Status ConvertStruct(PandasOptions options, const ChunkedArray& data,
+ PyObject** out_values) {
+ if (data.num_chunks() == 0) {
+ return Status::OK();
+ }
+ // ChunkedArray has at least one chunk
+ auto arr = checked_cast<const StructArray*>(data.chunk(0).get());
+ // Use it to cache the struct type and number of fields for all chunks
+ int32_t num_fields = arr->num_fields();
+ auto array_type = arr->type();
+ std::vector<OwnedRef> fields_data(num_fields * data.num_chunks());
+ OwnedRef dict_item;
+
+ // See notes in MakeInnerOptions.
+ options = MakeInnerOptions(std::move(options));
+ // Don't blindly convert because timestamps in lists are handled differently.
+ options.timestamp_as_object = true;
+
+ for (int c = 0; c < data.num_chunks(); c++) {
+ auto fields_data_offset = c * num_fields;
+ auto arr = checked_cast<const StructArray*>(data.chunk(c).get());
+ // Convert the struct arrays first
+ for (int32_t i = 0; i < num_fields; i++) {
+ const auto field = arr->field(static_cast<int>(i));
+ RETURN_NOT_OK(ConvertArrayToPandas(options, field, nullptr,
+ fields_data[i + fields_data_offset].ref()));
+ DCHECK(PyArray_Check(fields_data[i + fields_data_offset].obj()));
+ }
+
+ // Construct a dictionary for each row
+ const bool has_nulls = data.null_count() > 0;
+ for (int64_t i = 0; i < arr->length(); ++i) {
+ if (has_nulls && arr->IsNull(i)) {
+ Py_INCREF(Py_None);
+ *out_values = Py_None;
+ } else {
+ // Build the new dict object for the row
+ dict_item.reset(PyDict_New());
+ RETURN_IF_PYERROR();
+ for (int32_t field_idx = 0; field_idx < num_fields; ++field_idx) {
+ OwnedRef field_value;
+ auto name = array_type->field(static_cast<int>(field_idx))->name();
+ if (!arr->field(static_cast<int>(field_idx))->IsNull(i)) {
+ // Value exists in child array, obtain it
+ auto array = reinterpret_cast<PyArrayObject*>(
+ fields_data[field_idx + fields_data_offset].obj());
+ auto ptr = reinterpret_cast<const char*>(PyArray_GETPTR1(array, i));
+ field_value.reset(PyArray_GETITEM(array, ptr));
+ RETURN_IF_PYERROR();
+ } else {
+ // Translate the Null to a None
+ Py_INCREF(Py_None);
+ field_value.reset(Py_None);
+ }
+ // PyDict_SetItemString increments reference count
+ auto setitem_result =
+ PyDict_SetItemString(dict_item.obj(), name.c_str(), field_value.obj());
+ RETURN_IF_PYERROR();
+ DCHECK_EQ(setitem_result, 0);
+ }
+ *out_values = dict_item.obj();
+ // Grant ownership to the resulting array
+ Py_INCREF(*out_values);
+ }
+ ++out_values;
+ }
+ }
+ return Status::OK();
+}
+
+Status DecodeDictionaries(MemoryPool* pool, const std::shared_ptr<DataType>& dense_type,
+ ArrayVector* arrays) {
+ compute::ExecContext ctx(pool);
+ compute::CastOptions options;
+ for (size_t i = 0; i < arrays->size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE((*arrays)[i],
+ compute::Cast(*(*arrays)[i], dense_type, options, &ctx));
+ }
+ return Status::OK();
+}
+
+Status DecodeDictionaries(MemoryPool* pool, const std::shared_ptr<DataType>& dense_type,
+ std::shared_ptr<ChunkedArray>* array) {
+ auto chunks = (*array)->chunks();
+ RETURN_NOT_OK(DecodeDictionaries(pool, dense_type, &chunks));
+ *array = std::make_shared<ChunkedArray>(std::move(chunks), dense_type);
+ return Status::OK();
+}
+
+template <typename ListArrayT>
+Status ConvertListsLike(PandasOptions options, const ChunkedArray& data,
+ PyObject** out_values) {
+ // Get column of underlying value arrays
+ ArrayVector value_arrays;
+ for (int c = 0; c < data.num_chunks(); c++) {
+ const auto& arr = checked_cast<const ListArrayT&>(*data.chunk(c));
+ value_arrays.emplace_back(arr.values());
+ }
+ using ListArrayType = typename ListArrayT::TypeClass;
+ const auto& list_type = checked_cast<const ListArrayType&>(*data.type());
+ auto value_type = list_type.value_type();
+
+ auto flat_column = std::make_shared<ChunkedArray>(value_arrays, value_type);
+
+ options = MakeInnerOptions(std::move(options));
+
+ OwnedRefNoGIL owned_numpy_array;
+ RETURN_NOT_OK(ConvertChunkedArrayToPandas(options, flat_column, nullptr,
+ owned_numpy_array.ref()));
+
+ PyObject* numpy_array = owned_numpy_array.obj();
+ DCHECK(PyArray_Check(numpy_array));
+
+ int64_t chunk_offset = 0;
+ for (int c = 0; c < data.num_chunks(); c++) {
+ const auto& arr = checked_cast<const ListArrayT&>(*data.chunk(c));
+
+ const bool has_nulls = data.null_count() > 0;
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ if (has_nulls && arr.IsNull(i)) {
+ Py_INCREF(Py_None);
+ *out_values = Py_None;
+ } else {
+ OwnedRef start(PyLong_FromLongLong(arr.value_offset(i) + chunk_offset));
+ OwnedRef end(PyLong_FromLongLong(arr.value_offset(i + 1) + chunk_offset));
+ OwnedRef slice(PySlice_New(start.obj(), end.obj(), nullptr));
+
+ if (ARROW_PREDICT_FALSE(slice.obj() == nullptr)) {
+ // Fall out of loop, will return from RETURN_IF_PYERROR
+ break;
+ }
+ *out_values = PyObject_GetItem(numpy_array, slice.obj());
+
+ if (*out_values == nullptr) {
+ // Fall out of loop, will return from RETURN_IF_PYERROR
+ break;
+ }
+ }
+ ++out_values;
+ }
+ RETURN_IF_PYERROR();
+
+ chunk_offset += arr.values()->length();
+ }
+
+ return Status::OK();
+}
+
+Status ConvertMap(PandasOptions options, const ChunkedArray& data,
+ PyObject** out_values) {
+ // Get columns of underlying key/item arrays
+ std::vector<std::shared_ptr<Array>> key_arrays;
+ std::vector<std::shared_ptr<Array>> item_arrays;
+ for (int c = 0; c < data.num_chunks(); ++c) {
+ const auto& map_arr = checked_cast<const MapArray&>(*data.chunk(c));
+ key_arrays.emplace_back(map_arr.keys());
+ item_arrays.emplace_back(map_arr.items());
+ }
+
+ const auto& map_type = checked_cast<const MapType&>(*data.type());
+ auto key_type = map_type.key_type();
+ auto item_type = map_type.item_type();
+
+ // ARROW-6899: Convert dictionary-encoded children to dense instead of
+ // failing below. A more efficient conversion than this could be done later
+ if (key_type->id() == Type::DICTIONARY) {
+ auto dense_type = checked_cast<const DictionaryType&>(*key_type).value_type();
+ RETURN_NOT_OK(DecodeDictionaries(options.pool, dense_type, &key_arrays));
+ key_type = dense_type;
+ }
+ if (item_type->id() == Type::DICTIONARY) {
+ auto dense_type = checked_cast<const DictionaryType&>(*item_type).value_type();
+ RETURN_NOT_OK(DecodeDictionaries(options.pool, dense_type, &item_arrays));
+ item_type = dense_type;
+ }
+
+ // See notes in MakeInnerOptions.
+ options = MakeInnerOptions(std::move(options));
+ // Don't blindly convert because timestamps in lists are handled differently.
+ options.timestamp_as_object = true;
+
+ auto flat_keys = std::make_shared<ChunkedArray>(key_arrays, key_type);
+ auto flat_items = std::make_shared<ChunkedArray>(item_arrays, item_type);
+ OwnedRef list_item;
+ OwnedRef key_value;
+ OwnedRef item_value;
+ OwnedRefNoGIL owned_numpy_keys;
+ RETURN_NOT_OK(
+ ConvertChunkedArrayToPandas(options, flat_keys, nullptr, owned_numpy_keys.ref()));
+ OwnedRefNoGIL owned_numpy_items;
+ RETURN_NOT_OK(
+ ConvertChunkedArrayToPandas(options, flat_items, nullptr, owned_numpy_items.ref()));
+ PyArrayObject* py_keys = reinterpret_cast<PyArrayObject*>(owned_numpy_keys.obj());
+ PyArrayObject* py_items = reinterpret_cast<PyArrayObject*>(owned_numpy_items.obj());
+
+ int64_t chunk_offset = 0;
+ for (int c = 0; c < data.num_chunks(); ++c) {
+ const auto& arr = checked_cast<const MapArray&>(*data.chunk(c));
+ const bool has_nulls = data.null_count() > 0;
+
+ // Make a list of key/item pairs for each row in array
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ if (has_nulls && arr.IsNull(i)) {
+ Py_INCREF(Py_None);
+ *out_values = Py_None;
+ } else {
+ int64_t entry_offset = arr.value_offset(i);
+ int64_t num_maps = arr.value_offset(i + 1) - entry_offset;
+
+ // Build the new list object for the row of maps
+ list_item.reset(PyList_New(num_maps));
+ RETURN_IF_PYERROR();
+
+ // Add each key/item pair in the row
+ for (int64_t j = 0; j < num_maps; ++j) {
+ // Get key value, key is non-nullable for a valid row
+ auto ptr_key = reinterpret_cast<const char*>(
+ PyArray_GETPTR1(py_keys, chunk_offset + entry_offset + j));
+ key_value.reset(PyArray_GETITEM(py_keys, ptr_key));
+ RETURN_IF_PYERROR();
+
+ if (item_arrays[c]->IsNull(entry_offset + j)) {
+ // Translate the Null to a None
+ Py_INCREF(Py_None);
+ item_value.reset(Py_None);
+ } else {
+ // Get valid value from item array
+ auto ptr_item = reinterpret_cast<const char*>(
+ PyArray_GETPTR1(py_items, chunk_offset + entry_offset + j));
+ item_value.reset(PyArray_GETITEM(py_items, ptr_item));
+ RETURN_IF_PYERROR();
+ }
+
+ // Add the key/item pair to the list for the row
+ PyList_SET_ITEM(list_item.obj(), j,
+ PyTuple_Pack(2, key_value.obj(), item_value.obj()));
+ RETURN_IF_PYERROR();
+ }
+
+ // Pass ownership to the resulting array
+ *out_values = list_item.detach();
+ }
+ ++out_values;
+ }
+ RETURN_IF_PYERROR();
+
+ chunk_offset += arr.values()->length();
+ }
+
+ return Status::OK();
+}
+
+template <typename InType, typename OutType>
+inline void ConvertNumericNullable(const ChunkedArray& data, InType na_value,
+ OutType* out_values) {
+ for (int c = 0; c < data.num_chunks(); c++) {
+ const auto& arr = *data.chunk(c);
+ const InType* in_values = GetPrimitiveValues<InType>(arr);
+
+ if (arr.null_count() > 0) {
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ *out_values++ = arr.IsNull(i) ? na_value : in_values[i];
+ }
+ } else {
+ memcpy(out_values, in_values, sizeof(InType) * arr.length());
+ out_values += arr.length();
+ }
+ }
+}
+
+template <typename InType, typename OutType>
+inline void ConvertNumericNullableCast(const ChunkedArray& data, InType na_value,
+ OutType* out_values) {
+ for (int c = 0; c < data.num_chunks(); c++) {
+ const auto& arr = *data.chunk(c);
+ const InType* in_values = GetPrimitiveValues<InType>(arr);
+
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ *out_values++ = arr.IsNull(i) ? static_cast<OutType>(na_value)
+ : static_cast<OutType>(in_values[i]);
+ }
+ }
+}
+
+template <int NPY_TYPE>
+class TypedPandasWriter : public PandasWriter {
+ public:
+ using T = typename npy_traits<NPY_TYPE>::value_type;
+
+ using PandasWriter::PandasWriter;
+
+ Status TransferSingle(std::shared_ptr<ChunkedArray> data, PyObject* py_ref) override {
+ if (CanZeroCopy(*data)) {
+ PyObject* wrapped;
+ npy_intp dims[2] = {static_cast<npy_intp>(num_columns_),
+ static_cast<npy_intp>(num_rows_)};
+ RETURN_NOT_OK(
+ MakeNumPyView(data->chunk(0), py_ref, NPY_TYPE, /*ndim=*/2, dims, &wrapped));
+ SetBlockData(wrapped);
+ return Status::OK();
+ } else {
+ RETURN_NOT_OK(CheckNotZeroCopyOnly(*data));
+ RETURN_NOT_OK(EnsureAllocated());
+ return CopyInto(data, /*rel_placement=*/0);
+ }
+ }
+
+ Status CheckTypeExact(const DataType& type, Type::type expected) {
+ if (type.id() != expected) {
+ // TODO(wesm): stringify NumPy / pandas type
+ return Status::NotImplemented("Cannot write Arrow data of type ", type.ToString());
+ }
+ return Status::OK();
+ }
+
+ T* GetBlockColumnStart(int64_t rel_placement) {
+ return reinterpret_cast<T*>(block_data_) + rel_placement * num_rows_;
+ }
+
+ protected:
+ Status Allocate() override { return AllocateNDArray(NPY_TYPE); }
+};
+
+struct ObjectWriterVisitor {
+ const PandasOptions& options;
+ const ChunkedArray& data;
+ PyObject** out_values;
+
+ Status Visit(const NullType& type) {
+ for (int c = 0; c < data.num_chunks(); c++) {
+ std::shared_ptr<Array> arr = data.chunk(c);
+
+ for (int64_t i = 0; i < arr->length(); ++i) {
+ // All values are null
+ Py_INCREF(Py_None);
+ *out_values = Py_None;
+ ++out_values;
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const BooleanType& type) {
+ for (int c = 0; c < data.num_chunks(); c++) {
+ const auto& arr = checked_cast<const BooleanArray&>(*data.chunk(c));
+
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ if (arr.IsNull(i)) {
+ Py_INCREF(Py_None);
+ *out_values++ = Py_None;
+ } else if (arr.Value(i)) {
+ // True
+ Py_INCREF(Py_True);
+ *out_values++ = Py_True;
+ } else {
+ // False
+ Py_INCREF(Py_False);
+ *out_values++ = Py_False;
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_integer<Type, Status> Visit(const Type& type) {
+ using T = typename Type::c_type;
+ auto WrapValue = [](T value, PyObject** out) {
+ *out = std::is_signed<T>::value ? PyLong_FromLongLong(value)
+ : PyLong_FromUnsignedLongLong(value);
+ RETURN_IF_PYERROR();
+ return Status::OK();
+ };
+ return ConvertAsPyObjects<Type>(options, data, WrapValue, out_values);
+ }
+
+ template <typename Type>
+ enable_if_t<is_base_binary_type<Type>::value || is_fixed_size_binary_type<Type>::value,
+ Status>
+ Visit(const Type& type) {
+ auto WrapValue = [](const util::string_view& view, PyObject** out) {
+ *out = WrapBytes<Type>::Wrap(view.data(), view.length());
+ if (*out == nullptr) {
+ PyErr_Clear();
+ return Status::UnknownError("Wrapping ", view, " failed");
+ }
+ return Status::OK();
+ };
+ return ConvertAsPyObjects<Type>(options, data, WrapValue, out_values);
+ }
+
+ template <typename Type>
+ enable_if_date<Type, Status> Visit(const Type& type) {
+ auto WrapValue = [](typename Type::c_type value, PyObject** out) {
+ RETURN_NOT_OK(internal::PyDate_from_int(value, Type::UNIT, out));
+ RETURN_IF_PYERROR();
+ return Status::OK();
+ };
+ return ConvertAsPyObjects<Type>(options, data, WrapValue, out_values);
+ }
+
+ template <typename Type>
+ enable_if_time<Type, Status> Visit(const Type& type) {
+ const TimeUnit::type unit = type.unit();
+ auto WrapValue = [unit](typename Type::c_type value, PyObject** out) {
+ RETURN_NOT_OK(internal::PyTime_from_int(value, unit, out));
+ RETURN_IF_PYERROR();
+ return Status::OK();
+ };
+ return ConvertAsPyObjects<Type>(options, data, WrapValue, out_values);
+ }
+
+ template <typename Type>
+ enable_if_timestamp<Type, Status> Visit(const Type& type) {
+ const TimeUnit::type unit = type.unit();
+ OwnedRef tzinfo;
+
+ auto ConvertTimezoneNaive = [&](typename Type::c_type value, PyObject** out) {
+ RETURN_NOT_OK(internal::PyDateTime_from_int(value, unit, out));
+ RETURN_IF_PYERROR();
+ return Status::OK();
+ };
+ auto ConvertTimezoneAware = [&](typename Type::c_type value, PyObject** out) {
+ PyObject* naive_datetime;
+ RETURN_NOT_OK(ConvertTimezoneNaive(value, &naive_datetime));
+ // convert the timezone naive datetime object to timezone aware
+ *out = PyObject_CallMethod(tzinfo.obj(), "fromutc", "O", naive_datetime);
+ // the timezone naive object is no longer required
+ Py_DECREF(naive_datetime);
+ RETURN_IF_PYERROR();
+ return Status::OK();
+ };
+
+ if (!type.timezone().empty() && !options.ignore_timezone) {
+ // convert timezone aware
+ PyObject* tzobj;
+ ARROW_ASSIGN_OR_RAISE(tzobj, internal::StringToTzinfo(type.timezone()));
+ tzinfo.reset(tzobj);
+ RETURN_IF_PYERROR();
+ RETURN_NOT_OK(
+ ConvertAsPyObjects<Type>(options, data, ConvertTimezoneAware, out_values));
+ } else {
+ // convert timezone naive
+ RETURN_NOT_OK(
+ ConvertAsPyObjects<Type>(options, data, ConvertTimezoneNaive, out_values));
+ }
+
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_t<std::is_same<Type, MonthDayNanoIntervalType>::value, Status> Visit(
+ const Type& type) {
+ OwnedRef args(PyTuple_New(0));
+ OwnedRef kwargs(PyDict_New());
+ RETURN_IF_PYERROR();
+ auto to_date_offset = [&](const MonthDayNanoIntervalType::MonthDayNanos& interval,
+ PyObject** out) {
+ DCHECK(internal::BorrowPandasDataOffsetType() != nullptr);
+ // DateOffset objects do not add nanoseconds component to pd.Timestamp.
+ // as of Pandas 1.3.3
+ // (https://github.com/pandas-dev/pandas/issues/43892).
+ // So convert microseconds and remainder to preserve data
+ // but give users more expected results.
+ int64_t microseconds = interval.nanoseconds / 1000;
+ int64_t nanoseconds;
+ if (interval.nanoseconds >= 0) {
+ nanoseconds = interval.nanoseconds % 1000;
+ } else {
+ nanoseconds = -((-interval.nanoseconds) % 1000);
+ }
+
+ PyDict_SetItemString(kwargs.obj(), "months", PyLong_FromLong(interval.months));
+ PyDict_SetItemString(kwargs.obj(), "days", PyLong_FromLong(interval.days));
+ PyDict_SetItemString(kwargs.obj(), "microseconds",
+ PyLong_FromLongLong(microseconds));
+ PyDict_SetItemString(kwargs.obj(), "nanoseconds", PyLong_FromLongLong(nanoseconds));
+ *out =
+ PyObject_Call(internal::BorrowPandasDataOffsetType(), args.obj(), kwargs.obj());
+ RETURN_IF_PYERROR();
+ return Status::OK();
+ };
+ return ConvertAsPyObjects<MonthDayNanoIntervalType>(options, data, to_date_offset,
+ out_values);
+ }
+
+ Status Visit(const Decimal128Type& type) {
+ OwnedRef decimal;
+ OwnedRef Decimal;
+ RETURN_NOT_OK(internal::ImportModule("decimal", &decimal));
+ RETURN_NOT_OK(internal::ImportFromModule(decimal.obj(), "Decimal", &Decimal));
+ PyObject* decimal_constructor = Decimal.obj();
+
+ for (int c = 0; c < data.num_chunks(); c++) {
+ const auto& arr = checked_cast<const arrow::Decimal128Array&>(*data.chunk(c));
+
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ if (arr.IsNull(i)) {
+ Py_INCREF(Py_None);
+ *out_values++ = Py_None;
+ } else {
+ *out_values++ =
+ internal::DecimalFromString(decimal_constructor, arr.FormatValue(i));
+ RETURN_IF_PYERROR();
+ }
+ }
+ }
+
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal256Type& type) {
+ OwnedRef decimal;
+ OwnedRef Decimal;
+ RETURN_NOT_OK(internal::ImportModule("decimal", &decimal));
+ RETURN_NOT_OK(internal::ImportFromModule(decimal.obj(), "Decimal", &Decimal));
+ PyObject* decimal_constructor = Decimal.obj();
+
+ for (int c = 0; c < data.num_chunks(); c++) {
+ const auto& arr = checked_cast<const arrow::Decimal256Array&>(*data.chunk(c));
+
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ if (arr.IsNull(i)) {
+ Py_INCREF(Py_None);
+ *out_values++ = Py_None;
+ } else {
+ *out_values++ =
+ internal::DecimalFromString(decimal_constructor, arr.FormatValue(i));
+ RETURN_IF_PYERROR();
+ }
+ }
+ }
+
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_t<is_fixed_size_list_type<T>::value || is_var_length_list_type<T>::value,
+ Status>
+ Visit(const T& type) {
+ using ArrayType = typename TypeTraits<T>::ArrayType;
+ if (!ListTypeSupported(*type.value_type())) {
+ return Status::NotImplemented(
+ "Not implemented type for conversion from List to Pandas: ",
+ type.value_type()->ToString());
+ }
+ return ConvertListsLike<ArrayType>(options, data, out_values);
+ }
+
+ Status Visit(const MapType& type) { return ConvertMap(options, data, out_values); }
+
+ Status Visit(const StructType& type) {
+ return ConvertStruct(options, data, out_values);
+ }
+
+ template <typename Type>
+ enable_if_t<is_floating_type<Type>::value ||
+ std::is_same<DictionaryType, Type>::value ||
+ std::is_same<DurationType, Type>::value ||
+ std::is_same<ExtensionType, Type>::value ||
+ (std::is_base_of<IntervalType, Type>::value &&
+ !std::is_same<MonthDayNanoIntervalType, Type>::value) ||
+ std::is_base_of<UnionType, Type>::value,
+ Status>
+ Visit(const Type& type) {
+ return Status::NotImplemented("No implemented conversion to object dtype: ",
+ type.ToString());
+ }
+};
+
+class ObjectWriter : public TypedPandasWriter<NPY_OBJECT> {
+ public:
+ using TypedPandasWriter<NPY_OBJECT>::TypedPandasWriter;
+ Status CopyInto(std::shared_ptr<ChunkedArray> data, int64_t rel_placement) override {
+ PyAcquireGIL lock;
+ ObjectWriterVisitor visitor{this->options_, *data,
+ this->GetBlockColumnStart(rel_placement)};
+ return VisitTypeInline(*data->type(), &visitor);
+ }
+};
+
+static inline bool IsNonNullContiguous(const ChunkedArray& data) {
+ return data.num_chunks() == 1 && data.null_count() == 0;
+}
+
+template <int NPY_TYPE>
+class IntWriter : public TypedPandasWriter<NPY_TYPE> {
+ public:
+ using ArrowType = typename npy_traits<NPY_TYPE>::TypeClass;
+ using TypedPandasWriter<NPY_TYPE>::TypedPandasWriter;
+
+ bool CanZeroCopy(const ChunkedArray& data) const override {
+ return IsNonNullContiguous(data);
+ }
+
+ Status CopyInto(std::shared_ptr<ChunkedArray> data, int64_t rel_placement) override {
+ RETURN_NOT_OK(this->CheckTypeExact(*data->type(), ArrowType::type_id));
+ ConvertIntegerNoNullsSameType<typename ArrowType::c_type>(
+ this->options_, *data, this->GetBlockColumnStart(rel_placement));
+ return Status::OK();
+ }
+};
+
+template <int NPY_TYPE>
+class FloatWriter : public TypedPandasWriter<NPY_TYPE> {
+ public:
+ using ArrowType = typename npy_traits<NPY_TYPE>::TypeClass;
+ using TypedPandasWriter<NPY_TYPE>::TypedPandasWriter;
+ using T = typename ArrowType::c_type;
+
+ bool CanZeroCopy(const ChunkedArray& data) const override {
+ return IsNonNullContiguous(data) && data.type()->id() == ArrowType::type_id;
+ }
+
+ Status CopyInto(std::shared_ptr<ChunkedArray> data, int64_t rel_placement) override {
+ Type::type in_type = data->type()->id();
+ auto out_values = this->GetBlockColumnStart(rel_placement);
+
+#define INTEGER_CASE(IN_TYPE) \
+ ConvertIntegerWithNulls<IN_TYPE, T>(this->options_, *data, out_values); \
+ break;
+
+ switch (in_type) {
+ case Type::UINT8:
+ INTEGER_CASE(uint8_t);
+ case Type::INT8:
+ INTEGER_CASE(int8_t);
+ case Type::UINT16:
+ INTEGER_CASE(uint16_t);
+ case Type::INT16:
+ INTEGER_CASE(int16_t);
+ case Type::UINT32:
+ INTEGER_CASE(uint32_t);
+ case Type::INT32:
+ INTEGER_CASE(int32_t);
+ case Type::UINT64:
+ INTEGER_CASE(uint64_t);
+ case Type::INT64:
+ INTEGER_CASE(int64_t);
+ case Type::HALF_FLOAT:
+ ConvertNumericNullableCast(*data, npy_traits<NPY_TYPE>::na_sentinel, out_values);
+ case Type::FLOAT:
+ ConvertNumericNullableCast(*data, npy_traits<NPY_TYPE>::na_sentinel, out_values);
+ break;
+ case Type::DOUBLE:
+ ConvertNumericNullableCast(*data, npy_traits<NPY_TYPE>::na_sentinel, out_values);
+ break;
+ default:
+ return Status::NotImplemented("Cannot write Arrow data of type ",
+ data->type()->ToString(),
+ " to a Pandas floating point block");
+ }
+
+#undef INTEGER_CASE
+
+ return Status::OK();
+ }
+};
+
+using UInt8Writer = IntWriter<NPY_UINT8>;
+using Int8Writer = IntWriter<NPY_INT8>;
+using UInt16Writer = IntWriter<NPY_UINT16>;
+using Int16Writer = IntWriter<NPY_INT16>;
+using UInt32Writer = IntWriter<NPY_UINT32>;
+using Int32Writer = IntWriter<NPY_INT32>;
+using UInt64Writer = IntWriter<NPY_UINT64>;
+using Int64Writer = IntWriter<NPY_INT64>;
+using Float16Writer = FloatWriter<NPY_FLOAT16>;
+using Float32Writer = FloatWriter<NPY_FLOAT32>;
+using Float64Writer = FloatWriter<NPY_FLOAT64>;
+
+class BoolWriter : public TypedPandasWriter<NPY_BOOL> {
+ public:
+ using TypedPandasWriter<NPY_BOOL>::TypedPandasWriter;
+
+ Status TransferSingle(std::shared_ptr<ChunkedArray> data, PyObject* py_ref) override {
+ RETURN_NOT_OK(
+ CheckNoZeroCopy("Zero copy conversions not possible with "
+ "boolean types"));
+ RETURN_NOT_OK(EnsureAllocated());
+ return CopyInto(data, /*rel_placement=*/0);
+ }
+
+ Status CopyInto(std::shared_ptr<ChunkedArray> data, int64_t rel_placement) override {
+ RETURN_NOT_OK(this->CheckTypeExact(*data->type(), Type::BOOL));
+ auto out_values = this->GetBlockColumnStart(rel_placement);
+ for (int c = 0; c < data->num_chunks(); c++) {
+ const auto& arr = checked_cast<const BooleanArray&>(*data->chunk(c));
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ *out_values++ = static_cast<uint8_t>(arr.Value(i));
+ }
+ }
+ return Status::OK();
+ }
+};
+
+// ----------------------------------------------------------------------
+// Date / timestamp types
+
+template <typename T, int64_t SHIFT>
+inline void ConvertDatetimeLikeNanos(const ChunkedArray& data, int64_t* out_values) {
+ for (int c = 0; c < data.num_chunks(); c++) {
+ const auto& arr = *data.chunk(c);
+ const T* in_values = GetPrimitiveValues<T>(arr);
+
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ *out_values++ = arr.IsNull(i) ? kPandasTimestampNull
+ : (static_cast<int64_t>(in_values[i]) * SHIFT);
+ }
+ }
+}
+
+template <typename T, int SHIFT>
+void ConvertDatesShift(const ChunkedArray& data, int64_t* out_values) {
+ for (int c = 0; c < data.num_chunks(); c++) {
+ const auto& arr = *data.chunk(c);
+ const T* in_values = GetPrimitiveValues<T>(arr);
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ *out_values++ = arr.IsNull(i) ? kPandasTimestampNull
+ : static_cast<int64_t>(in_values[i]) / SHIFT;
+ }
+ }
+}
+
+class DatetimeDayWriter : public TypedPandasWriter<NPY_DATETIME> {
+ public:
+ using TypedPandasWriter<NPY_DATETIME>::TypedPandasWriter;
+
+ Status CopyInto(std::shared_ptr<ChunkedArray> data, int64_t rel_placement) override {
+ int64_t* out_values = this->GetBlockColumnStart(rel_placement);
+ const auto& type = checked_cast<const DateType&>(*data->type());
+ switch (type.unit()) {
+ case DateUnit::DAY:
+ ConvertDatesShift<int32_t, 1LL>(*data, out_values);
+ break;
+ case DateUnit::MILLI:
+ ConvertDatesShift<int64_t, 86400000LL>(*data, out_values);
+ break;
+ }
+ return Status::OK();
+ }
+
+ protected:
+ Status Allocate() override {
+ RETURN_NOT_OK(this->AllocateNDArray(NPY_DATETIME));
+ SetDatetimeUnit(NPY_FR_D);
+ return Status::OK();
+ }
+};
+
+template <TimeUnit::type UNIT>
+class DatetimeWriter : public TypedPandasWriter<NPY_DATETIME> {
+ public:
+ using TypedPandasWriter<NPY_DATETIME>::TypedPandasWriter;
+
+ bool CanZeroCopy(const ChunkedArray& data) const override {
+ if (data.type()->id() == Type::TIMESTAMP) {
+ const auto& type = checked_cast<const TimestampType&>(*data.type());
+ return IsNonNullContiguous(data) && type.unit() == UNIT;
+ } else {
+ return false;
+ }
+ }
+
+ Status CopyInto(std::shared_ptr<ChunkedArray> data, int64_t rel_placement) override {
+ const auto& ts_type = checked_cast<const TimestampType&>(*data->type());
+ DCHECK_EQ(UNIT, ts_type.unit()) << "Should only call instances of this writer "
+ << "with arrays of the correct unit";
+ ConvertNumericNullable<int64_t>(*data, kPandasTimestampNull,
+ this->GetBlockColumnStart(rel_placement));
+ return Status::OK();
+ }
+
+ protected:
+ Status Allocate() override {
+ RETURN_NOT_OK(this->AllocateNDArray(NPY_DATETIME));
+ SetDatetimeUnit(internal::NumPyFrequency(UNIT));
+ return Status::OK();
+ }
+};
+
+using DatetimeSecondWriter = DatetimeWriter<TimeUnit::SECOND>;
+using DatetimeMilliWriter = DatetimeWriter<TimeUnit::MILLI>;
+using DatetimeMicroWriter = DatetimeWriter<TimeUnit::MICRO>;
+
+class DatetimeNanoWriter : public DatetimeWriter<TimeUnit::NANO> {
+ public:
+ using DatetimeWriter<TimeUnit::NANO>::DatetimeWriter;
+
+ Status CopyInto(std::shared_ptr<ChunkedArray> data, int64_t rel_placement) override {
+ Type::type type = data->type()->id();
+ int64_t* out_values = this->GetBlockColumnStart(rel_placement);
+ compute::ExecContext ctx(options_.pool);
+ compute::CastOptions options;
+ if (options_.safe_cast) {
+ options = compute::CastOptions::Safe();
+ } else {
+ options = compute::CastOptions::Unsafe();
+ }
+ Datum out;
+ auto target_type = timestamp(TimeUnit::NANO);
+
+ if (type == Type::DATE32) {
+ // Convert from days since epoch to datetime64[ns]
+ ConvertDatetimeLikeNanos<int32_t, kNanosecondsInDay>(*data, out_values);
+ } else if (type == Type::DATE64) {
+ // Date64Type is millisecond timestamp stored as int64_t
+ // TODO(wesm): Do we want to make sure to zero out the milliseconds?
+ ConvertDatetimeLikeNanos<int64_t, 1000000L>(*data, out_values);
+ } else if (type == Type::TIMESTAMP) {
+ const auto& ts_type = checked_cast<const TimestampType&>(*data->type());
+
+ if (ts_type.unit() == TimeUnit::NANO) {
+ ConvertNumericNullable<int64_t>(*data, kPandasTimestampNull, out_values);
+ } else if (ts_type.unit() == TimeUnit::MICRO || ts_type.unit() == TimeUnit::MILLI ||
+ ts_type.unit() == TimeUnit::SECOND) {
+ ARROW_ASSIGN_OR_RAISE(out, compute::Cast(data, target_type, options, &ctx));
+ ConvertNumericNullable<int64_t>(*out.chunked_array(), kPandasTimestampNull,
+ out_values);
+ } else {
+ return Status::NotImplemented("Unsupported time unit");
+ }
+ } else {
+ return Status::NotImplemented("Cannot write Arrow data of type ",
+ data->type()->ToString(),
+ " to a Pandas datetime block.");
+ }
+ return Status::OK();
+ }
+};
+
+class DatetimeTZWriter : public DatetimeNanoWriter {
+ public:
+ DatetimeTZWriter(const PandasOptions& options, const std::string& timezone,
+ int64_t num_rows)
+ : DatetimeNanoWriter(options, num_rows, 1), timezone_(timezone) {}
+
+ protected:
+ Status GetResultBlock(PyObject** out) override {
+ RETURN_NOT_OK(MakeBlock1D());
+ *out = block_arr_.obj();
+ return Status::OK();
+ }
+
+ Status AddResultMetadata(PyObject* result) override {
+ PyObject* py_tz = PyUnicode_FromStringAndSize(
+ timezone_.c_str(), static_cast<Py_ssize_t>(timezone_.size()));
+ RETURN_IF_PYERROR();
+ PyDict_SetItemString(result, "timezone", py_tz);
+ Py_DECREF(py_tz);
+ return Status::OK();
+ }
+
+ private:
+ std::string timezone_;
+};
+
+template <TimeUnit::type UNIT>
+class TimedeltaWriter : public TypedPandasWriter<NPY_TIMEDELTA> {
+ public:
+ using TypedPandasWriter<NPY_TIMEDELTA>::TypedPandasWriter;
+
+ Status AllocateTimedelta(int ndim) {
+ RETURN_NOT_OK(this->AllocateNDArray(NPY_TIMEDELTA, ndim));
+ SetDatetimeUnit(internal::NumPyFrequency(UNIT));
+ return Status::OK();
+ }
+
+ bool CanZeroCopy(const ChunkedArray& data) const override {
+ const auto& type = checked_cast<const DurationType&>(*data.type());
+ return IsNonNullContiguous(data) && type.unit() == UNIT;
+ }
+
+ Status CopyInto(std::shared_ptr<ChunkedArray> data, int64_t rel_placement) override {
+ const auto& type = checked_cast<const DurationType&>(*data->type());
+ DCHECK_EQ(UNIT, type.unit()) << "Should only call instances of this writer "
+ << "with arrays of the correct unit";
+ ConvertNumericNullable<int64_t>(*data, kPandasTimestampNull,
+ this->GetBlockColumnStart(rel_placement));
+ return Status::OK();
+ }
+
+ protected:
+ Status Allocate() override { return AllocateTimedelta(2); }
+};
+
+using TimedeltaSecondWriter = TimedeltaWriter<TimeUnit::SECOND>;
+using TimedeltaMilliWriter = TimedeltaWriter<TimeUnit::MILLI>;
+using TimedeltaMicroWriter = TimedeltaWriter<TimeUnit::MICRO>;
+
+class TimedeltaNanoWriter : public TimedeltaWriter<TimeUnit::NANO> {
+ public:
+ using TimedeltaWriter<TimeUnit::NANO>::TimedeltaWriter;
+
+ Status CopyInto(std::shared_ptr<ChunkedArray> data, int64_t rel_placement) override {
+ Type::type type = data->type()->id();
+ int64_t* out_values = this->GetBlockColumnStart(rel_placement);
+ if (type == Type::DURATION) {
+ const auto& ts_type = checked_cast<const DurationType&>(*data->type());
+ if (ts_type.unit() == TimeUnit::NANO) {
+ ConvertNumericNullable<int64_t>(*data, kPandasTimestampNull, out_values);
+ } else if (ts_type.unit() == TimeUnit::MICRO) {
+ ConvertDatetimeLikeNanos<int64_t, 1000L>(*data, out_values);
+ } else if (ts_type.unit() == TimeUnit::MILLI) {
+ ConvertDatetimeLikeNanos<int64_t, 1000000L>(*data, out_values);
+ } else if (ts_type.unit() == TimeUnit::SECOND) {
+ ConvertDatetimeLikeNanos<int64_t, 1000000000L>(*data, out_values);
+ } else {
+ return Status::NotImplemented("Unsupported time unit");
+ }
+ } else {
+ return Status::NotImplemented("Cannot write Arrow data of type ",
+ data->type()->ToString(),
+ " to a Pandas timedelta block.");
+ }
+ return Status::OK();
+ }
+};
+
+Status MakeZeroLengthArray(const std::shared_ptr<DataType>& type,
+ std::shared_ptr<Array>* out) {
+ std::unique_ptr<ArrayBuilder> builder;
+ RETURN_NOT_OK(MakeBuilder(default_memory_pool(), type, &builder));
+ RETURN_NOT_OK(builder->Resize(0));
+ return builder->Finish(out);
+}
+
+bool NeedDictionaryUnification(const ChunkedArray& data) {
+ if (data.num_chunks() < 2) {
+ return false;
+ }
+ const auto& arr_first = checked_cast<const DictionaryArray&>(*data.chunk(0));
+ for (int c = 1; c < data.num_chunks(); c++) {
+ const auto& arr = checked_cast<const DictionaryArray&>(*data.chunk(c));
+ if (!(arr_first.dictionary()->Equals(arr.dictionary()))) {
+ return true;
+ }
+ }
+ return false;
+}
+
+template <typename IndexType>
+class CategoricalWriter
+ : public TypedPandasWriter<arrow_traits<IndexType::type_id>::npy_type> {
+ public:
+ using TRAITS = arrow_traits<IndexType::type_id>;
+ using ArrayType = typename TypeTraits<IndexType>::ArrayType;
+ using T = typename TRAITS::T;
+
+ explicit CategoricalWriter(const PandasOptions& options, int64_t num_rows)
+ : TypedPandasWriter<TRAITS::npy_type>(options, num_rows, 1),
+ ordered_(false),
+ needs_copy_(false) {}
+
+ Status CopyInto(std::shared_ptr<ChunkedArray> data, int64_t rel_placement) override {
+ return Status::NotImplemented("categorical type");
+ }
+
+ Status TransferSingle(std::shared_ptr<ChunkedArray> data, PyObject* py_ref) override {
+ const auto& dict_type = checked_cast<const DictionaryType&>(*data->type());
+ std::shared_ptr<Array> dict;
+ if (data->num_chunks() == 0) {
+ // no dictionary values => create empty array
+ RETURN_NOT_OK(this->AllocateNDArray(TRAITS::npy_type, 1));
+ RETURN_NOT_OK(MakeZeroLengthArray(dict_type.value_type(), &dict));
+ } else {
+ DCHECK_EQ(IndexType::type_id, dict_type.index_type()->id());
+ RETURN_NOT_OK(WriteIndices(*data, &dict));
+ }
+
+ PyObject* pydict;
+ RETURN_NOT_OK(ConvertArrayToPandas(this->options_, dict, nullptr, &pydict));
+ dictionary_.reset(pydict);
+ ordered_ = dict_type.ordered();
+ return Status::OK();
+ }
+
+ Status Write(std::shared_ptr<ChunkedArray> data, int64_t abs_placement,
+ int64_t rel_placement) override {
+ RETURN_NOT_OK(this->EnsurePlacementAllocated());
+ RETURN_NOT_OK(TransferSingle(data, /*py_ref=*/nullptr));
+ this->placement_data_[rel_placement] = abs_placement;
+ return Status::OK();
+ }
+
+ Status GetSeriesResult(PyObject** out) override {
+ PyAcquireGIL lock;
+
+ PyObject* result = PyDict_New();
+ RETURN_IF_PYERROR();
+
+ // Expected single array dictionary layout
+ PyDict_SetItemString(result, "indices", this->block_arr_.obj());
+ RETURN_IF_PYERROR();
+ RETURN_NOT_OK(AddResultMetadata(result));
+
+ *out = result;
+ return Status::OK();
+ }
+
+ protected:
+ Status AddResultMetadata(PyObject* result) override {
+ PyDict_SetItemString(result, "dictionary", dictionary_.obj());
+ PyObject* py_ordered = ordered_ ? Py_True : Py_False;
+ Py_INCREF(py_ordered);
+ PyDict_SetItemString(result, "ordered", py_ordered);
+ return Status::OK();
+ }
+
+ Status WriteIndicesUniform(const ChunkedArray& data) {
+ RETURN_NOT_OK(this->AllocateNDArray(TRAITS::npy_type, 1));
+ T* out_values = reinterpret_cast<T*>(this->block_data_);
+
+ for (int c = 0; c < data.num_chunks(); c++) {
+ const auto& arr = checked_cast<const DictionaryArray&>(*data.chunk(c));
+ const auto& indices = checked_cast<const ArrayType&>(*arr.indices());
+ auto values = reinterpret_cast<const T*>(indices.raw_values());
+
+ RETURN_NOT_OK(CheckIndexBounds(*indices.data(), arr.dictionary()->length()));
+ // Null is -1 in CategoricalBlock
+ for (int i = 0; i < arr.length(); ++i) {
+ if (indices.IsValid(i)) {
+ *out_values++ = values[i];
+ } else {
+ *out_values++ = -1;
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ Status WriteIndicesVarying(const ChunkedArray& data, std::shared_ptr<Array>* out_dict) {
+ // Yield int32 indices to allow for dictionary outgrowing the current index
+ // type
+ RETURN_NOT_OK(this->AllocateNDArray(NPY_INT32, 1));
+ auto out_values = reinterpret_cast<int32_t*>(this->block_data_);
+
+ const auto& dict_type = checked_cast<const DictionaryType&>(*data.type());
+
+ ARROW_ASSIGN_OR_RAISE(auto unifier, DictionaryUnifier::Make(dict_type.value_type(),
+ this->options_.pool));
+ for (int c = 0; c < data.num_chunks(); c++) {
+ const auto& arr = checked_cast<const DictionaryArray&>(*data.chunk(c));
+ const auto& indices = checked_cast<const ArrayType&>(*arr.indices());
+ auto values = reinterpret_cast<const T*>(indices.raw_values());
+
+ std::shared_ptr<Buffer> transpose_buffer;
+ RETURN_NOT_OK(unifier->Unify(*arr.dictionary(), &transpose_buffer));
+
+ auto transpose = reinterpret_cast<const int32_t*>(transpose_buffer->data());
+ int64_t dict_length = arr.dictionary()->length();
+
+ RETURN_NOT_OK(CheckIndexBounds(*indices.data(), dict_length));
+
+ // Null is -1 in CategoricalBlock
+ for (int i = 0; i < arr.length(); ++i) {
+ if (indices.IsValid(i)) {
+ *out_values++ = transpose[values[i]];
+ } else {
+ *out_values++ = -1;
+ }
+ }
+ }
+
+ std::shared_ptr<DataType> unused_type;
+ return unifier->GetResult(&unused_type, out_dict);
+ }
+
+ Status WriteIndices(const ChunkedArray& data, std::shared_ptr<Array>* out_dict) {
+ DCHECK_GT(data.num_chunks(), 0);
+
+ // Sniff the first chunk
+ const auto& arr_first = checked_cast<const DictionaryArray&>(*data.chunk(0));
+ const auto indices_first = std::static_pointer_cast<ArrayType>(arr_first.indices());
+
+ if (data.num_chunks() == 1 && indices_first->null_count() == 0) {
+ RETURN_NOT_OK(
+ CheckIndexBounds(*indices_first->data(), arr_first.dictionary()->length()));
+
+ PyObject* wrapped;
+ npy_intp dims[1] = {static_cast<npy_intp>(this->num_rows_)};
+ RETURN_NOT_OK(MakeNumPyView(indices_first, /*py_ref=*/nullptr, TRAITS::npy_type,
+ /*ndim=*/1, dims, &wrapped));
+ this->SetBlockData(wrapped);
+ *out_dict = arr_first.dictionary();
+ } else {
+ RETURN_NOT_OK(this->CheckNotZeroCopyOnly(data));
+ if (NeedDictionaryUnification(data)) {
+ RETURN_NOT_OK(WriteIndicesVarying(data, out_dict));
+ } else {
+ RETURN_NOT_OK(WriteIndicesUniform(data));
+ *out_dict = arr_first.dictionary();
+ }
+ }
+ return Status::OK();
+ }
+
+ OwnedRefNoGIL dictionary_;
+ bool ordered_;
+ bool needs_copy_;
+};
+
+class ExtensionWriter : public PandasWriter {
+ public:
+ using PandasWriter::PandasWriter;
+
+ Status Allocate() override {
+ // no-op
+ return Status::OK();
+ }
+
+ Status TransferSingle(std::shared_ptr<ChunkedArray> data, PyObject* py_ref) override {
+ PyAcquireGIL lock;
+ PyObject* py_array;
+ py_array = wrap_chunked_array(data);
+ py_array_.reset(py_array);
+
+ return Status::OK();
+ }
+
+ Status CopyInto(std::shared_ptr<ChunkedArray> data, int64_t rel_placement) override {
+ return TransferSingle(data, nullptr);
+ }
+
+ Status GetDataFrameResult(PyObject** out) override {
+ PyAcquireGIL lock;
+ PyObject* result = PyDict_New();
+ RETURN_IF_PYERROR();
+
+ PyDict_SetItemString(result, "py_array", py_array_.obj());
+ PyDict_SetItemString(result, "placement", placement_arr_.obj());
+ *out = result;
+ return Status::OK();
+ }
+
+ Status GetSeriesResult(PyObject** out) override {
+ *out = py_array_.detach();
+ return Status::OK();
+ }
+
+ protected:
+ OwnedRefNoGIL py_array_;
+};
+
+Status MakeWriter(const PandasOptions& options, PandasWriter::type writer_type,
+ const DataType& type, int64_t num_rows, int num_columns,
+ std::shared_ptr<PandasWriter>* writer) {
+#define BLOCK_CASE(NAME, TYPE) \
+ case PandasWriter::NAME: \
+ *writer = std::make_shared<TYPE>(options, num_rows, num_columns); \
+ break;
+
+#define CATEGORICAL_CASE(TYPE) \
+ case TYPE::type_id: \
+ *writer = std::make_shared<CategoricalWriter<TYPE>>(options, num_rows); \
+ break;
+
+ switch (writer_type) {
+ case PandasWriter::CATEGORICAL: {
+ const auto& index_type = *checked_cast<const DictionaryType&>(type).index_type();
+ switch (index_type.id()) {
+ CATEGORICAL_CASE(Int8Type);
+ CATEGORICAL_CASE(Int16Type);
+ CATEGORICAL_CASE(Int32Type);
+ CATEGORICAL_CASE(Int64Type);
+ case Type::UINT8:
+ case Type::UINT16:
+ case Type::UINT32:
+ case Type::UINT64:
+ return Status::TypeError(
+ "Converting unsigned dictionary indices to pandas",
+ " not yet supported, index type: ", index_type.ToString());
+ default:
+ // Unreachable
+ DCHECK(false);
+ break;
+ }
+ } break;
+ case PandasWriter::EXTENSION:
+ *writer = std::make_shared<ExtensionWriter>(options, num_rows, num_columns);
+ break;
+ BLOCK_CASE(OBJECT, ObjectWriter);
+ BLOCK_CASE(UINT8, UInt8Writer);
+ BLOCK_CASE(INT8, Int8Writer);
+ BLOCK_CASE(UINT16, UInt16Writer);
+ BLOCK_CASE(INT16, Int16Writer);
+ BLOCK_CASE(UINT32, UInt32Writer);
+ BLOCK_CASE(INT32, Int32Writer);
+ BLOCK_CASE(UINT64, UInt64Writer);
+ BLOCK_CASE(INT64, Int64Writer);
+ BLOCK_CASE(HALF_FLOAT, Float16Writer);
+ BLOCK_CASE(FLOAT, Float32Writer);
+ BLOCK_CASE(DOUBLE, Float64Writer);
+ BLOCK_CASE(BOOL, BoolWriter);
+ BLOCK_CASE(DATETIME_DAY, DatetimeDayWriter);
+ BLOCK_CASE(DATETIME_SECOND, DatetimeSecondWriter);
+ BLOCK_CASE(DATETIME_MILLI, DatetimeMilliWriter);
+ BLOCK_CASE(DATETIME_MICRO, DatetimeMicroWriter);
+ BLOCK_CASE(DATETIME_NANO, DatetimeNanoWriter);
+ BLOCK_CASE(TIMEDELTA_SECOND, TimedeltaSecondWriter);
+ BLOCK_CASE(TIMEDELTA_MILLI, TimedeltaMilliWriter);
+ BLOCK_CASE(TIMEDELTA_MICRO, TimedeltaMicroWriter);
+ BLOCK_CASE(TIMEDELTA_NANO, TimedeltaNanoWriter);
+ case PandasWriter::DATETIME_NANO_TZ: {
+ const auto& ts_type = checked_cast<const TimestampType&>(type);
+ *writer = std::make_shared<DatetimeTZWriter>(options, ts_type.timezone(), num_rows);
+ } break;
+ default:
+ return Status::NotImplemented("Unsupported block type");
+ }
+
+#undef BLOCK_CASE
+#undef CATEGORICAL_CASE
+
+ return Status::OK();
+}
+
+static Status GetPandasWriterType(const ChunkedArray& data, const PandasOptions& options,
+ PandasWriter::type* output_type) {
+#define INTEGER_CASE(NAME) \
+ *output_type = \
+ data.null_count() > 0 \
+ ? options.integer_object_nulls ? PandasWriter::OBJECT : PandasWriter::DOUBLE \
+ : PandasWriter::NAME; \
+ break;
+
+ switch (data.type()->id()) {
+ case Type::BOOL:
+ *output_type = data.null_count() > 0 ? PandasWriter::OBJECT : PandasWriter::BOOL;
+ break;
+ case Type::UINT8:
+ INTEGER_CASE(UINT8);
+ case Type::INT8:
+ INTEGER_CASE(INT8);
+ case Type::UINT16:
+ INTEGER_CASE(UINT16);
+ case Type::INT16:
+ INTEGER_CASE(INT16);
+ case Type::UINT32:
+ INTEGER_CASE(UINT32);
+ case Type::INT32:
+ INTEGER_CASE(INT32);
+ case Type::UINT64:
+ INTEGER_CASE(UINT64);
+ case Type::INT64:
+ INTEGER_CASE(INT64);
+ case Type::HALF_FLOAT:
+ *output_type = PandasWriter::HALF_FLOAT;
+ break;
+ case Type::FLOAT:
+ *output_type = PandasWriter::FLOAT;
+ break;
+ case Type::DOUBLE:
+ *output_type = PandasWriter::DOUBLE;
+ break;
+ case Type::STRING: // fall through
+ case Type::LARGE_STRING: // fall through
+ case Type::BINARY: // fall through
+ case Type::LARGE_BINARY:
+ case Type::NA: // fall through
+ case Type::FIXED_SIZE_BINARY: // fall through
+ case Type::STRUCT: // fall through
+ case Type::TIME32: // fall through
+ case Type::TIME64: // fall through
+ case Type::DECIMAL128: // fall through
+ case Type::DECIMAL256: // fall through
+ case Type::INTERVAL_MONTH_DAY_NANO: // fall through
+ *output_type = PandasWriter::OBJECT;
+ break;
+ case Type::DATE32: // fall through
+ case Type::DATE64:
+ if (options.date_as_object) {
+ *output_type = PandasWriter::OBJECT;
+ } else {
+ *output_type = options.coerce_temporal_nanoseconds ? PandasWriter::DATETIME_NANO
+ : PandasWriter::DATETIME_DAY;
+ }
+ break;
+ case Type::TIMESTAMP: {
+ const auto& ts_type = checked_cast<const TimestampType&>(*data.type());
+ if (options.timestamp_as_object && ts_type.unit() != TimeUnit::NANO) {
+ // Nanoseconds are never out of bounds for pandas, so in that case
+ // we don't convert to object
+ *output_type = PandasWriter::OBJECT;
+ } else if (!ts_type.timezone().empty()) {
+ *output_type = PandasWriter::DATETIME_NANO_TZ;
+ } else if (options.coerce_temporal_nanoseconds) {
+ *output_type = PandasWriter::DATETIME_NANO;
+ } else {
+ switch (ts_type.unit()) {
+ case TimeUnit::SECOND:
+ *output_type = PandasWriter::DATETIME_SECOND;
+ break;
+ case TimeUnit::MILLI:
+ *output_type = PandasWriter::DATETIME_MILLI;
+ break;
+ case TimeUnit::MICRO:
+ *output_type = PandasWriter::DATETIME_MICRO;
+ break;
+ case TimeUnit::NANO:
+ *output_type = PandasWriter::DATETIME_NANO;
+ break;
+ }
+ }
+ } break;
+ case Type::DURATION: {
+ const auto& dur_type = checked_cast<const DurationType&>(*data.type());
+ if (options.coerce_temporal_nanoseconds) {
+ *output_type = PandasWriter::TIMEDELTA_NANO;
+ } else {
+ switch (dur_type.unit()) {
+ case TimeUnit::SECOND:
+ *output_type = PandasWriter::TIMEDELTA_SECOND;
+ break;
+ case TimeUnit::MILLI:
+ *output_type = PandasWriter::TIMEDELTA_MILLI;
+ break;
+ case TimeUnit::MICRO:
+ *output_type = PandasWriter::TIMEDELTA_MICRO;
+ break;
+ case TimeUnit::NANO:
+ *output_type = PandasWriter::TIMEDELTA_NANO;
+ break;
+ }
+ }
+ } break;
+ case Type::FIXED_SIZE_LIST:
+ case Type::LIST:
+ case Type::LARGE_LIST:
+ case Type::MAP: {
+ auto list_type = std::static_pointer_cast<BaseListType>(data.type());
+ if (!ListTypeSupported(*list_type->value_type())) {
+ return Status::NotImplemented("Not implemented type for Arrow list to pandas: ",
+ list_type->value_type()->ToString());
+ }
+ *output_type = PandasWriter::OBJECT;
+ } break;
+ case Type::DICTIONARY:
+ *output_type = PandasWriter::CATEGORICAL;
+ break;
+ case Type::EXTENSION:
+ *output_type = PandasWriter::EXTENSION;
+ break;
+ default:
+ return Status::NotImplemented(
+ "No known equivalent Pandas block for Arrow data of type ",
+ data.type()->ToString(), " is known.");
+ }
+ return Status::OK();
+}
+
+// Construct the exact pandas "BlockManager" memory layout
+//
+// * For each column determine the correct output pandas type
+// * Allocate 2D blocks (ncols x nrows) for each distinct data type in output
+// * Allocate block placement arrays
+// * Write Arrow columns out into each slice of memory; populate block
+// * placement arrays as we go
+class PandasBlockCreator {
+ public:
+ using WriterMap = std::unordered_map<int, std::shared_ptr<PandasWriter>>;
+
+ explicit PandasBlockCreator(const PandasOptions& options, FieldVector fields,
+ ChunkedArrayVector arrays)
+ : options_(options), fields_(std::move(fields)), arrays_(std::move(arrays)) {
+ num_columns_ = static_cast<int>(arrays_.size());
+ if (num_columns_ > 0) {
+ num_rows_ = arrays_[0]->length();
+ }
+ column_block_placement_.resize(num_columns_);
+ }
+ virtual ~PandasBlockCreator() = default;
+
+ virtual Status Convert(PyObject** out) = 0;
+
+ Status AppendBlocks(const WriterMap& blocks, PyObject* list) {
+ for (const auto& it : blocks) {
+ PyObject* item;
+ RETURN_NOT_OK(it.second->GetDataFrameResult(&item));
+ if (PyList_Append(list, item) < 0) {
+ RETURN_IF_PYERROR();
+ }
+
+ // ARROW-1017; PyList_Append increments object refcount
+ Py_DECREF(item);
+ }
+ return Status::OK();
+ }
+
+ protected:
+ PandasOptions options_;
+
+ FieldVector fields_;
+ ChunkedArrayVector arrays_;
+ int num_columns_;
+ int64_t num_rows_;
+
+ // column num -> relative placement within internal block
+ std::vector<int> column_block_placement_;
+};
+
+class ConsolidatedBlockCreator : public PandasBlockCreator {
+ public:
+ using PandasBlockCreator::PandasBlockCreator;
+
+ Status Convert(PyObject** out) override {
+ column_types_.resize(num_columns_);
+ RETURN_NOT_OK(CreateBlocks());
+ RETURN_NOT_OK(WriteTableToBlocks());
+ PyAcquireGIL lock;
+
+ PyObject* result = PyList_New(0);
+ RETURN_IF_PYERROR();
+
+ RETURN_NOT_OK(AppendBlocks(blocks_, result));
+ RETURN_NOT_OK(AppendBlocks(singleton_blocks_, result));
+
+ *out = result;
+ return Status::OK();
+ }
+
+ Status GetBlockType(int column_index, PandasWriter::type* out) {
+ if (options_.extension_columns.count(fields_[column_index]->name())) {
+ *out = PandasWriter::EXTENSION;
+ return Status::OK();
+ } else {
+ return GetPandasWriterType(*arrays_[column_index], options_, out);
+ }
+ }
+
+ Status CreateBlocks() {
+ for (int i = 0; i < num_columns_; ++i) {
+ const DataType& type = *arrays_[i]->type();
+ PandasWriter::type output_type;
+ RETURN_NOT_OK(GetBlockType(i, &output_type));
+
+ int block_placement = 0;
+ std::shared_ptr<PandasWriter> writer;
+ if (output_type == PandasWriter::CATEGORICAL ||
+ output_type == PandasWriter::DATETIME_NANO_TZ ||
+ output_type == PandasWriter::EXTENSION) {
+ RETURN_NOT_OK(MakeWriter(options_, output_type, type, num_rows_,
+ /*num_columns=*/1, &writer));
+ singleton_blocks_[i] = writer;
+ } else {
+ auto it = block_sizes_.find(output_type);
+ if (it != block_sizes_.end()) {
+ block_placement = it->second;
+ // Increment count
+ ++it->second;
+ } else {
+ // Add key to map
+ block_sizes_[output_type] = 1;
+ }
+ }
+ column_types_[i] = output_type;
+ column_block_placement_[i] = block_placement;
+ }
+
+ // Create normal non-categorical blocks
+ for (const auto& it : this->block_sizes_) {
+ PandasWriter::type output_type = static_cast<PandasWriter::type>(it.first);
+ std::shared_ptr<PandasWriter> block;
+ RETURN_NOT_OK(MakeWriter(this->options_, output_type, /*unused*/ *null(), num_rows_,
+ it.second, &block));
+ this->blocks_[output_type] = block;
+ }
+ return Status::OK();
+ }
+
+ Status GetWriter(int i, std::shared_ptr<PandasWriter>* block) {
+ PandasWriter::type output_type = this->column_types_[i];
+ switch (output_type) {
+ case PandasWriter::CATEGORICAL:
+ case PandasWriter::DATETIME_NANO_TZ:
+ case PandasWriter::EXTENSION: {
+ auto it = this->singleton_blocks_.find(i);
+ if (it == this->singleton_blocks_.end()) {
+ return Status::KeyError("No block allocated");
+ }
+ *block = it->second;
+ } break;
+ default:
+ auto it = this->blocks_.find(output_type);
+ if (it == this->blocks_.end()) {
+ return Status::KeyError("No block allocated");
+ }
+ *block = it->second;
+ break;
+ }
+ return Status::OK();
+ }
+
+ Status WriteTableToBlocks() {
+ auto WriteColumn = [this](int i) {
+ std::shared_ptr<PandasWriter> block;
+ RETURN_NOT_OK(this->GetWriter(i, &block));
+ // ARROW-3789 Use std::move on the array to permit self-destructing
+ return block->Write(std::move(arrays_[i]), i, this->column_block_placement_[i]);
+ };
+
+ return OptionalParallelFor(options_.use_threads, num_columns_, WriteColumn);
+ }
+
+ private:
+ // column num -> block type id
+ std::vector<PandasWriter::type> column_types_;
+
+ // block type -> type count
+ std::unordered_map<int, int> block_sizes_;
+ std::unordered_map<int, const DataType*> block_types_;
+
+ // block type -> block
+ WriterMap blocks_;
+
+ WriterMap singleton_blocks_;
+};
+
+/// \brief Create blocks for pandas.DataFrame block manager using one block per
+/// column strategy. This permits some zero-copy optimizations as well as the
+/// ability for the table to "self-destruct" if selected by the user.
+class SplitBlockCreator : public PandasBlockCreator {
+ public:
+ using PandasBlockCreator::PandasBlockCreator;
+
+ Status GetWriter(int i, std::shared_ptr<PandasWriter>* writer) {
+ PandasWriter::type output_type = PandasWriter::OBJECT;
+ const DataType& type = *arrays_[i]->type();
+ if (options_.extension_columns.count(fields_[i]->name())) {
+ output_type = PandasWriter::EXTENSION;
+ } else {
+ // Null count needed to determine output type
+ RETURN_NOT_OK(GetPandasWriterType(*arrays_[i], options_, &output_type));
+ }
+ return MakeWriter(this->options_, output_type, type, num_rows_, 1, writer);
+ }
+
+ Status Convert(PyObject** out) override {
+ PyAcquireGIL lock;
+
+ PyObject* result = PyList_New(0);
+ RETURN_IF_PYERROR();
+
+ for (int i = 0; i < num_columns_; ++i) {
+ std::shared_ptr<PandasWriter> writer;
+ RETURN_NOT_OK(GetWriter(i, &writer));
+ // ARROW-3789 Use std::move on the array to permit self-destructing
+ RETURN_NOT_OK(writer->Write(std::move(arrays_[i]), i, /*rel_placement=*/0));
+
+ PyObject* item;
+ RETURN_NOT_OK(writer->GetDataFrameResult(&item));
+ if (PyList_Append(result, item) < 0) {
+ RETURN_IF_PYERROR();
+ }
+ // PyList_Append increments object refcount
+ Py_DECREF(item);
+ }
+
+ *out = result;
+ return Status::OK();
+ }
+
+ private:
+ std::vector<std::shared_ptr<PandasWriter>> writers_;
+};
+
+Status ConvertCategoricals(const PandasOptions& options, ChunkedArrayVector* arrays,
+ FieldVector* fields) {
+ std::vector<int> columns_to_encode;
+
+ // For Categorical conversions
+ auto EncodeColumn = [&](int j) {
+ int i = columns_to_encode[j];
+ if (options.zero_copy_only) {
+ return Status::Invalid("Need to dictionary encode a column, but ",
+ "only zero-copy conversions allowed");
+ }
+ compute::ExecContext ctx(options.pool);
+ ARROW_ASSIGN_OR_RAISE(
+ Datum out, DictionaryEncode((*arrays)[i],
+ compute::DictionaryEncodeOptions::Defaults(), &ctx));
+ (*arrays)[i] = out.chunked_array();
+ (*fields)[i] = (*fields)[i]->WithType((*arrays)[i]->type());
+ return Status::OK();
+ };
+
+ if (!options.categorical_columns.empty()) {
+ for (int i = 0; i < static_cast<int>(arrays->size()); i++) {
+ if ((*arrays)[i]->type()->id() != Type::DICTIONARY &&
+ options.categorical_columns.count((*fields)[i]->name())) {
+ columns_to_encode.push_back(i);
+ }
+ }
+ }
+ if (options.strings_to_categorical) {
+ for (int i = 0; i < static_cast<int>(arrays->size()); i++) {
+ if (is_base_binary_like((*arrays)[i]->type()->id())) {
+ columns_to_encode.push_back(i);
+ }
+ }
+ }
+ return OptionalParallelFor(options.use_threads,
+ static_cast<int>(columns_to_encode.size()), EncodeColumn);
+}
+
+} // namespace
+
+Status ConvertArrayToPandas(const PandasOptions& options, std::shared_ptr<Array> arr,
+ PyObject* py_ref, PyObject** out) {
+ return ConvertChunkedArrayToPandas(
+ options, std::make_shared<ChunkedArray>(std::move(arr)), py_ref, out);
+}
+
+Status ConvertChunkedArrayToPandas(const PandasOptions& options,
+ std::shared_ptr<ChunkedArray> arr, PyObject* py_ref,
+ PyObject** out) {
+ if (options.decode_dictionaries && arr->type()->id() == Type::DICTIONARY) {
+ const auto& dense_type =
+ checked_cast<const DictionaryType&>(*arr->type()).value_type();
+ RETURN_NOT_OK(DecodeDictionaries(options.pool, dense_type, &arr));
+ DCHECK_NE(arr->type()->id(), Type::DICTIONARY);
+
+ // The original Python DictionaryArray won't own the memory anymore
+ // as we actually built a new array when we decoded the DictionaryArray
+ // thus let the final resulting numpy array own the memory through a Capsule
+ py_ref = nullptr;
+ }
+
+ if (options.strings_to_categorical && is_base_binary_like(arr->type()->id())) {
+ if (options.zero_copy_only) {
+ return Status::Invalid("Need to dictionary encode a column, but ",
+ "only zero-copy conversions allowed");
+ }
+ compute::ExecContext ctx(options.pool);
+ ARROW_ASSIGN_OR_RAISE(
+ Datum out,
+ DictionaryEncode(arr, compute::DictionaryEncodeOptions::Defaults(), &ctx));
+ arr = out.chunked_array();
+ }
+
+ PandasOptions modified_options = options;
+ modified_options.strings_to_categorical = false;
+
+ // ARROW-7596: We permit the hybrid Series/DataFrame code path to do zero copy
+ // optimizations that we do not allow in the default case when converting
+ // Table->DataFrame
+ modified_options.allow_zero_copy_blocks = true;
+
+ PandasWriter::type output_type;
+ RETURN_NOT_OK(GetPandasWriterType(*arr, modified_options, &output_type));
+ if (options.decode_dictionaries) {
+ DCHECK_NE(output_type, PandasWriter::CATEGORICAL);
+ }
+
+ std::shared_ptr<PandasWriter> writer;
+ RETURN_NOT_OK(MakeWriter(modified_options, output_type, *arr->type(), arr->length(),
+ /*num_columns=*/1, &writer));
+ RETURN_NOT_OK(writer->TransferSingle(std::move(arr), py_ref));
+ return writer->GetSeriesResult(out);
+}
+
+Status ConvertTableToPandas(const PandasOptions& options, std::shared_ptr<Table> table,
+ PyObject** out) {
+ ChunkedArrayVector arrays = table->columns();
+ FieldVector fields = table->fields();
+
+ // ARROW-3789: allow "self-destructing" by releasing references to columns as
+ // we convert them to pandas
+ table = nullptr;
+
+ RETURN_NOT_OK(ConvertCategoricals(options, &arrays, &fields));
+
+ PandasOptions modified_options = options;
+ modified_options.strings_to_categorical = false;
+ modified_options.categorical_columns.clear();
+
+ if (options.split_blocks) {
+ modified_options.allow_zero_copy_blocks = true;
+ SplitBlockCreator helper(modified_options, std::move(fields), std::move(arrays));
+ return helper.Convert(out);
+ } else {
+ ConsolidatedBlockCreator helper(modified_options, std::move(fields),
+ std::move(arrays));
+ return helper.Convert(out);
+ }
+}
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/arrow_to_pandas.h b/src/arrow/cpp/src/arrow/python/arrow_to_pandas.h
new file mode 100644
index 000000000..6570364b8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/arrow_to_pandas.h
@@ -0,0 +1,124 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Functions for converting between pandas's NumPy-based data representation
+// and Arrow data structures
+
+#pragma once
+
+#include "arrow/python/platform.h"
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+
+#include "arrow/memory_pool.h"
+#include "arrow/python/visibility.h"
+
+namespace arrow {
+
+class Array;
+class ChunkedArray;
+class Column;
+class DataType;
+class MemoryPool;
+class Status;
+class Table;
+
+namespace py {
+
+struct PandasOptions {
+ /// arrow::MemoryPool to use for memory allocations
+ MemoryPool* pool = default_memory_pool();
+
+ /// If true, we will convert all string columns to categoricals
+ bool strings_to_categorical = false;
+ bool zero_copy_only = false;
+ bool integer_object_nulls = false;
+ bool date_as_object = false;
+ bool timestamp_as_object = false;
+ bool use_threads = false;
+
+ /// Coerce all date and timestamp to datetime64[ns]
+ bool coerce_temporal_nanoseconds = false;
+
+ /// Used to maintain backwards compatibility for
+ /// timezone bugs (see ARROW-9528). Should be removed
+ /// after Arrow 2.0 release.
+ bool ignore_timezone = false;
+
+ /// \brief If true, do not create duplicate PyObject versions of equal
+ /// objects. This only applies to immutable objects like strings or datetime
+ /// objects
+ bool deduplicate_objects = false;
+
+ /// \brief For certain data types, a cast is needed in order to store the
+ /// data in a pandas DataFrame or Series (e.g. timestamps are always stored
+ /// as nanoseconds in pandas). This option controls whether it is a safe
+ /// cast or not.
+ bool safe_cast = true;
+
+ /// \brief If true, create one block per column rather than consolidated
+ /// blocks (1 per data type). Do zero-copy wrapping when there are no
+ /// nulls. pandas currently will consolidate the blocks on its own, causing
+ /// increased memory use, so keep this in mind if you are working on a
+ /// memory-constrained situation.
+ bool split_blocks = false;
+
+ /// \brief If true, allow non-writable zero-copy views to be created for
+ /// single column blocks. This option is also used to provide zero copy for
+ /// Series data
+ bool allow_zero_copy_blocks = false;
+
+ /// \brief If true, attempt to deallocate buffers in passed Arrow object if
+ /// it is the only remaining shared_ptr copy of it. See ARROW-3789 for
+ /// original context for this feature. Only currently implemented for Table
+ /// conversions
+ bool self_destruct = false;
+
+ // Used internally for nested arrays.
+ bool decode_dictionaries = false;
+
+ // Columns that should be casted to categorical
+ std::unordered_set<std::string> categorical_columns;
+
+ // Columns that should be passed through to be converted to
+ // ExtensionArray/Block
+ std::unordered_set<std::string> extension_columns;
+};
+
+ARROW_PYTHON_EXPORT
+Status ConvertArrayToPandas(const PandasOptions& options, std::shared_ptr<Array> arr,
+ PyObject* py_ref, PyObject** out);
+
+ARROW_PYTHON_EXPORT
+Status ConvertChunkedArrayToPandas(const PandasOptions& options,
+ std::shared_ptr<ChunkedArray> col, PyObject* py_ref,
+ PyObject** out);
+
+// Convert a whole table as efficiently as possible to a pandas.DataFrame.
+//
+// The returned Python object is a list of tuples consisting of the exact 2D
+// BlockManager structure of the pandas.DataFrame used as of pandas 0.19.x.
+//
+// tuple item: (indices: ndarray[int32], block: ndarray[TYPE, ndim=2])
+ARROW_PYTHON_EXPORT
+Status ConvertTableToPandas(const PandasOptions& options, std::shared_ptr<Table> table,
+ PyObject** out);
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/arrow_to_python_internal.h b/src/arrow/cpp/src/arrow/python/arrow_to_python_internal.h
new file mode 100644
index 000000000..514cda320
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/arrow_to_python_internal.h
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/array.h"
+#include "arrow/python/platform.h"
+
+namespace arrow {
+namespace py {
+namespace internal {
+// TODO(ARROW-12976): See if we can refactor Pandas ObjectWriter logic
+// to the .cc file and move this there as well if we can.
+
+// Converts array to a sequency of python objects.
+template <typename ArrayType, typename WriteValue, typename Assigner>
+inline Status WriteArrayObjects(const ArrayType& arr, WriteValue&& write_func,
+ Assigner out_values) {
+ // TODO(ARROW-12976): Use visitor here?
+ const bool has_nulls = arr.null_count() > 0;
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ if (has_nulls && arr.IsNull(i)) {
+ Py_INCREF(Py_None);
+ *out_values = Py_None;
+ } else {
+ RETURN_NOT_OK(write_func(arr.GetView(i), out_values));
+ }
+ ++out_values;
+ }
+ return Status::OK();
+}
+
+} // namespace internal
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/benchmark.cc b/src/arrow/cpp/src/arrow/python/benchmark.cc
new file mode 100644
index 000000000..2d29f69d2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/benchmark.cc
@@ -0,0 +1,38 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/python/benchmark.h>
+#include <arrow/python/helpers.h>
+
+namespace arrow {
+namespace py {
+namespace benchmark {
+
+void Benchmark_PandasObjectIsNull(PyObject* list) {
+ if (!PyList_CheckExact(list)) {
+ PyErr_SetString(PyExc_TypeError, "expected a list");
+ return;
+ }
+ Py_ssize_t i, n = PyList_GET_SIZE(list);
+ for (i = 0; i < n; i++) {
+ internal::PandasObjectIsNull(PyList_GET_ITEM(list, i));
+ }
+}
+
+} // namespace benchmark
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/benchmark.h b/src/arrow/cpp/src/arrow/python/benchmark.h
new file mode 100644
index 000000000..8060dd337
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/benchmark.h
@@ -0,0 +1,36 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/python/platform.h"
+
+#include "arrow/python/visibility.h"
+
+namespace arrow {
+namespace py {
+namespace benchmark {
+
+// Micro-benchmark routines for use from ASV
+
+// Run PandasObjectIsNull() once over every object in *list*
+ARROW_PYTHON_EXPORT
+void Benchmark_PandasObjectIsNull(PyObject* list);
+
+} // namespace benchmark
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/common.cc b/src/arrow/cpp/src/arrow/python/common.cc
new file mode 100644
index 000000000..6fe2ed4da
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/common.cc
@@ -0,0 +1,203 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/python/common.h"
+
+#include <cstdlib>
+#include <mutex>
+#include <string>
+
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+#include "arrow/python/helpers.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace py {
+
+static std::mutex memory_pool_mutex;
+static MemoryPool* default_python_pool = nullptr;
+
+void set_default_memory_pool(MemoryPool* pool) {
+ std::lock_guard<std::mutex> guard(memory_pool_mutex);
+ default_python_pool = pool;
+}
+
+MemoryPool* get_memory_pool() {
+ std::lock_guard<std::mutex> guard(memory_pool_mutex);
+ if (default_python_pool) {
+ return default_python_pool;
+ } else {
+ return default_memory_pool();
+ }
+}
+
+// ----------------------------------------------------------------------
+// PythonErrorDetail
+
+namespace {
+
+const char kErrorDetailTypeId[] = "arrow::py::PythonErrorDetail";
+
+// Try to match the Python exception type with an appropriate Status code
+StatusCode MapPyError(PyObject* exc_type) {
+ StatusCode code;
+
+ if (PyErr_GivenExceptionMatches(exc_type, PyExc_MemoryError)) {
+ code = StatusCode::OutOfMemory;
+ } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_IndexError)) {
+ code = StatusCode::IndexError;
+ } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_KeyError)) {
+ code = StatusCode::KeyError;
+ } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_TypeError)) {
+ code = StatusCode::TypeError;
+ } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_ValueError) ||
+ PyErr_GivenExceptionMatches(exc_type, PyExc_OverflowError)) {
+ code = StatusCode::Invalid;
+ } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_EnvironmentError)) {
+ code = StatusCode::IOError;
+ } else if (PyErr_GivenExceptionMatches(exc_type, PyExc_NotImplementedError)) {
+ code = StatusCode::NotImplemented;
+ } else {
+ code = StatusCode::UnknownError;
+ }
+ return code;
+}
+
+// PythonErrorDetail indicates a Python exception was raised.
+class PythonErrorDetail : public StatusDetail {
+ public:
+ const char* type_id() const override { return kErrorDetailTypeId; }
+
+ std::string ToString() const override {
+ // This is simple enough not to need the GIL
+ const auto ty = reinterpret_cast<const PyTypeObject*>(exc_type_.obj());
+ // XXX Should we also print traceback?
+ return std::string("Python exception: ") + ty->tp_name;
+ }
+
+ void RestorePyError() const {
+ Py_INCREF(exc_type_.obj());
+ Py_INCREF(exc_value_.obj());
+ Py_INCREF(exc_traceback_.obj());
+ PyErr_Restore(exc_type_.obj(), exc_value_.obj(), exc_traceback_.obj());
+ }
+
+ PyObject* exc_type() const { return exc_type_.obj(); }
+
+ PyObject* exc_value() const { return exc_value_.obj(); }
+
+ static std::shared_ptr<PythonErrorDetail> FromPyError() {
+ PyObject* exc_type = nullptr;
+ PyObject* exc_value = nullptr;
+ PyObject* exc_traceback = nullptr;
+
+ PyErr_Fetch(&exc_type, &exc_value, &exc_traceback);
+ PyErr_NormalizeException(&exc_type, &exc_value, &exc_traceback);
+ ARROW_CHECK(exc_type)
+ << "PythonErrorDetail::FromPyError called without a Python error set";
+ DCHECK(PyType_Check(exc_type));
+ DCHECK(exc_value); // Ensured by PyErr_NormalizeException, double-check
+ if (exc_traceback == nullptr) {
+ // Needed by PyErr_Restore()
+ Py_INCREF(Py_None);
+ exc_traceback = Py_None;
+ }
+
+ std::shared_ptr<PythonErrorDetail> detail(new PythonErrorDetail);
+ detail->exc_type_.reset(exc_type);
+ detail->exc_value_.reset(exc_value);
+ detail->exc_traceback_.reset(exc_traceback);
+ return detail;
+ }
+
+ protected:
+ PythonErrorDetail() = default;
+
+ OwnedRefNoGIL exc_type_, exc_value_, exc_traceback_;
+};
+
+} // namespace
+
+// ----------------------------------------------------------------------
+// Python exception <-> Status
+
+Status ConvertPyError(StatusCode code) {
+ auto detail = PythonErrorDetail::FromPyError();
+ if (code == StatusCode::UnknownError) {
+ code = MapPyError(detail->exc_type());
+ }
+
+ std::string message;
+ RETURN_NOT_OK(internal::PyObject_StdStringStr(detail->exc_value(), &message));
+ return Status(code, message, detail);
+}
+
+bool IsPyError(const Status& status) {
+ if (status.ok()) {
+ return false;
+ }
+ auto detail = status.detail();
+ bool result = detail != nullptr && detail->type_id() == kErrorDetailTypeId;
+ return result;
+}
+
+void RestorePyError(const Status& status) {
+ ARROW_CHECK(IsPyError(status));
+ const auto& detail = checked_cast<const PythonErrorDetail&>(*status.detail());
+ detail.RestorePyError();
+}
+
+// ----------------------------------------------------------------------
+// PyBuffer
+
+PyBuffer::PyBuffer() : Buffer(nullptr, 0) {}
+
+Status PyBuffer::Init(PyObject* obj) {
+ if (!PyObject_GetBuffer(obj, &py_buf_, PyBUF_ANY_CONTIGUOUS)) {
+ data_ = reinterpret_cast<const uint8_t*>(py_buf_.buf);
+ ARROW_CHECK_NE(data_, nullptr) << "Null pointer in Py_buffer";
+ size_ = py_buf_.len;
+ capacity_ = py_buf_.len;
+ is_mutable_ = !py_buf_.readonly;
+ return Status::OK();
+ } else {
+ return ConvertPyError(StatusCode::Invalid);
+ }
+}
+
+Result<std::shared_ptr<Buffer>> PyBuffer::FromPyObject(PyObject* obj) {
+ PyBuffer* buf = new PyBuffer();
+ std::shared_ptr<Buffer> res(buf);
+ RETURN_NOT_OK(buf->Init(obj));
+ return res;
+}
+
+PyBuffer::~PyBuffer() {
+ if (data_ != nullptr) {
+ PyAcquireGIL lock;
+ PyBuffer_Release(&py_buf_);
+ }
+}
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/common.h b/src/arrow/cpp/src/arrow/python/common.h
new file mode 100644
index 000000000..24dcb130a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/common.h
@@ -0,0 +1,360 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <utility>
+
+#include "arrow/buffer.h"
+#include "arrow/python/pyarrow.h"
+#include "arrow/python/visibility.h"
+#include "arrow/result.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+
+class MemoryPool;
+template <class T>
+class Result;
+
+namespace py {
+
+// Convert current Python error to a Status. The Python error state is cleared
+// and can be restored with RestorePyError().
+ARROW_PYTHON_EXPORT Status ConvertPyError(StatusCode code = StatusCode::UnknownError);
+// Query whether the given Status is a Python error (as wrapped by ConvertPyError()).
+ARROW_PYTHON_EXPORT bool IsPyError(const Status& status);
+// Restore a Python error wrapped in a Status.
+ARROW_PYTHON_EXPORT void RestorePyError(const Status& status);
+
+// Catch a pending Python exception and return the corresponding Status.
+// If no exception is pending, Status::OK() is returned.
+inline Status CheckPyError(StatusCode code = StatusCode::UnknownError) {
+ if (ARROW_PREDICT_TRUE(!PyErr_Occurred())) {
+ return Status::OK();
+ } else {
+ return ConvertPyError(code);
+ }
+}
+
+#define RETURN_IF_PYERROR() ARROW_RETURN_NOT_OK(CheckPyError())
+
+#define PY_RETURN_IF_ERROR(CODE) ARROW_RETURN_NOT_OK(CheckPyError(CODE))
+
+// For Cython, as you can't define template C++ functions in Cython, only use them.
+// This function can set a Python exception. It assumes that T has a (cheap)
+// default constructor.
+template <class T>
+T GetResultValue(Result<T> result) {
+ if (ARROW_PREDICT_TRUE(result.ok())) {
+ return *std::move(result);
+ } else {
+ int r = internal::check_status(result.status()); // takes the GIL
+ assert(r == -1); // should have errored out
+ ARROW_UNUSED(r);
+ return {};
+ }
+}
+
+// A RAII-style helper that ensures the GIL is acquired inside a lexical block.
+class ARROW_PYTHON_EXPORT PyAcquireGIL {
+ public:
+ PyAcquireGIL() : acquired_gil_(false) { acquire(); }
+
+ ~PyAcquireGIL() { release(); }
+
+ void acquire() {
+ if (!acquired_gil_) {
+ state_ = PyGILState_Ensure();
+ acquired_gil_ = true;
+ }
+ }
+
+ // idempotent
+ void release() {
+ if (acquired_gil_) {
+ PyGILState_Release(state_);
+ acquired_gil_ = false;
+ }
+ }
+
+ private:
+ bool acquired_gil_;
+ PyGILState_STATE state_;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(PyAcquireGIL);
+};
+
+// A RAII-style helper that releases the GIL until the end of a lexical block
+class ARROW_PYTHON_EXPORT PyReleaseGIL {
+ public:
+ PyReleaseGIL() { saved_state_ = PyEval_SaveThread(); }
+
+ ~PyReleaseGIL() { PyEval_RestoreThread(saved_state_); }
+
+ private:
+ PyThreadState* saved_state_;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(PyReleaseGIL);
+};
+
+// A helper to call safely into the Python interpreter from arbitrary C++ code.
+// The GIL is acquired, and the current thread's error status is preserved.
+template <typename Function>
+auto SafeCallIntoPython(Function&& func) -> decltype(func()) {
+ PyAcquireGIL lock;
+ PyObject* exc_type;
+ PyObject* exc_value;
+ PyObject* exc_traceback;
+ PyErr_Fetch(&exc_type, &exc_value, &exc_traceback);
+ auto maybe_status = std::forward<Function>(func)();
+ // If the return Status is a "Python error", the current Python error status
+ // describes the error and shouldn't be clobbered.
+ if (!IsPyError(::arrow::internal::GenericToStatus(maybe_status)) &&
+ exc_type != NULLPTR) {
+ PyErr_Restore(exc_type, exc_value, exc_traceback);
+ }
+ return maybe_status;
+}
+
+// A RAII primitive that DECREFs the underlying PyObject* when it
+// goes out of scope.
+class ARROW_PYTHON_EXPORT OwnedRef {
+ public:
+ OwnedRef() : obj_(NULLPTR) {}
+ OwnedRef(OwnedRef&& other) : OwnedRef(other.detach()) {}
+ explicit OwnedRef(PyObject* obj) : obj_(obj) {}
+
+ OwnedRef& operator=(OwnedRef&& other) {
+ obj_ = other.detach();
+ return *this;
+ }
+
+ ~OwnedRef() { reset(); }
+
+ void reset(PyObject* obj) {
+ Py_XDECREF(obj_);
+ obj_ = obj;
+ }
+
+ void reset() { reset(NULLPTR); }
+
+ PyObject* detach() {
+ PyObject* result = obj_;
+ obj_ = NULLPTR;
+ return result;
+ }
+
+ PyObject* obj() const { return obj_; }
+
+ PyObject** ref() { return &obj_; }
+
+ operator bool() const { return obj_ != NULLPTR; }
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(OwnedRef);
+
+ PyObject* obj_;
+};
+
+// Same as OwnedRef, but ensures the GIL is taken when it goes out of scope.
+// This is for situations where the GIL is not always known to be held
+// (e.g. if it is released in the middle of a function for performance reasons)
+class ARROW_PYTHON_EXPORT OwnedRefNoGIL : public OwnedRef {
+ public:
+ OwnedRefNoGIL() : OwnedRef() {}
+ OwnedRefNoGIL(OwnedRefNoGIL&& other) : OwnedRef(other.detach()) {}
+ explicit OwnedRefNoGIL(PyObject* obj) : OwnedRef(obj) {}
+
+ ~OwnedRefNoGIL() {
+ PyAcquireGIL lock;
+ reset();
+ }
+};
+
+template <typename Fn>
+struct BoundFunction;
+
+template <typename... Args>
+struct BoundFunction<void(PyObject*, Args...)> {
+ // We bind `cdef void fn(object, ...)` to get a `Status(...)`
+ // where the Status contains any Python error raised by `fn`
+ using Unbound = void(PyObject*, Args...);
+ using Bound = Status(Args...);
+
+ BoundFunction(Unbound* unbound, PyObject* bound_arg)
+ : bound_arg_(bound_arg), unbound_(unbound) {}
+
+ Status Invoke(Args... args) const {
+ PyAcquireGIL lock;
+ unbound_(bound_arg_.obj(), std::forward<Args>(args)...);
+ RETURN_IF_PYERROR();
+ return Status::OK();
+ }
+
+ Unbound* unbound_;
+ OwnedRefNoGIL bound_arg_;
+};
+
+template <typename Return, typename... Args>
+struct BoundFunction<Return(PyObject*, Args...)> {
+ // We bind `cdef Return fn(object, ...)` to get a `Result<Return>(...)`
+ // where the Result contains any Python error raised by `fn` or the
+ // return value from `fn`.
+ using Unbound = Return(PyObject*, Args...);
+ using Bound = Result<Return>(Args...);
+
+ BoundFunction(Unbound* unbound, PyObject* bound_arg)
+ : bound_arg_(bound_arg), unbound_(unbound) {}
+
+ Result<Return> Invoke(Args... args) const {
+ PyAcquireGIL lock;
+ Return ret = unbound_(bound_arg_.obj(), std::forward<Args>(args)...);
+ RETURN_IF_PYERROR();
+ return ret;
+ }
+
+ Unbound* unbound_;
+ OwnedRefNoGIL bound_arg_;
+};
+
+template <typename OutFn, typename Return, typename... Args>
+std::function<OutFn> BindFunction(Return (*unbound)(PyObject*, Args...),
+ PyObject* bound_arg) {
+ using Fn = BoundFunction<Return(PyObject*, Args...)>;
+
+ static_assert(std::is_same<typename Fn::Bound, OutFn>::value,
+ "requested bound function of unsupported type");
+
+ Py_XINCREF(bound_arg);
+ auto bound_fn = std::make_shared<Fn>(unbound, bound_arg);
+ return
+ [bound_fn](Args... args) { return bound_fn->Invoke(std::forward<Args>(args)...); };
+}
+
+// A temporary conversion of a Python object to a bytes area.
+struct PyBytesView {
+ const char* bytes;
+ Py_ssize_t size;
+ bool is_utf8;
+
+ static Result<PyBytesView> FromString(PyObject* obj, bool check_utf8 = false) {
+ PyBytesView self;
+ ARROW_RETURN_NOT_OK(self.ParseString(obj, check_utf8));
+ return std::move(self);
+ }
+
+ static Result<PyBytesView> FromUnicode(PyObject* obj) {
+ PyBytesView self;
+ ARROW_RETURN_NOT_OK(self.ParseUnicode(obj));
+ return std::move(self);
+ }
+
+ static Result<PyBytesView> FromBinary(PyObject* obj) {
+ PyBytesView self;
+ ARROW_RETURN_NOT_OK(self.ParseBinary(obj));
+ return std::move(self);
+ }
+
+ // View the given Python object as string-like, i.e. str or (utf8) bytes
+ Status ParseString(PyObject* obj, bool check_utf8 = false) {
+ if (PyUnicode_Check(obj)) {
+ return ParseUnicode(obj);
+ } else {
+ ARROW_RETURN_NOT_OK(ParseBinary(obj));
+ if (check_utf8) {
+ // Check the bytes are utf8 utf-8
+ OwnedRef decoded(PyUnicode_FromStringAndSize(bytes, size));
+ if (ARROW_PREDICT_TRUE(!PyErr_Occurred())) {
+ is_utf8 = true;
+ } else {
+ PyErr_Clear();
+ is_utf8 = false;
+ }
+ }
+ return Status::OK();
+ }
+ }
+
+ // View the given Python object as unicode string
+ Status ParseUnicode(PyObject* obj) {
+ // The utf-8 representation is cached on the unicode object
+ bytes = PyUnicode_AsUTF8AndSize(obj, &size);
+ RETURN_IF_PYERROR();
+ is_utf8 = true;
+ return Status::OK();
+ }
+
+ // View the given Python object as binary-like, i.e. bytes
+ Status ParseBinary(PyObject* obj) {
+ if (PyBytes_Check(obj)) {
+ bytes = PyBytes_AS_STRING(obj);
+ size = PyBytes_GET_SIZE(obj);
+ is_utf8 = false;
+ } else if (PyByteArray_Check(obj)) {
+ bytes = PyByteArray_AS_STRING(obj);
+ size = PyByteArray_GET_SIZE(obj);
+ is_utf8 = false;
+ } else if (PyMemoryView_Check(obj)) {
+ PyObject* ref = PyMemoryView_GetContiguous(obj, PyBUF_READ, 'C');
+ RETURN_IF_PYERROR();
+ Py_buffer* buffer = PyMemoryView_GET_BUFFER(ref);
+ bytes = reinterpret_cast<const char*>(buffer->buf);
+ size = buffer->len;
+ is_utf8 = false;
+ } else {
+ return Status::TypeError("Expected bytes, got a '", Py_TYPE(obj)->tp_name,
+ "' object");
+ }
+ return Status::OK();
+ }
+
+ protected:
+ OwnedRef ref;
+};
+
+class ARROW_PYTHON_EXPORT PyBuffer : public Buffer {
+ public:
+ /// While memoryview objects support multi-dimensional buffers, PyBuffer only supports
+ /// one-dimensional byte buffers.
+ ~PyBuffer();
+
+ static Result<std::shared_ptr<Buffer>> FromPyObject(PyObject* obj);
+
+ private:
+ PyBuffer();
+ Status Init(PyObject*);
+
+ Py_buffer py_buf_;
+};
+
+// Return the common PyArrow memory pool
+ARROW_PYTHON_EXPORT void set_default_memory_pool(MemoryPool* pool);
+ARROW_PYTHON_EXPORT MemoryPool* get_memory_pool();
+
+// This is annoying: because C++11 does not allow implicit conversion of string
+// literals to non-const char*, we need to go through some gymnastics to use
+// PyObject_CallMethod without a lot of pain (its arguments are non-const
+// char*)
+template <typename... ArgTypes>
+static inline PyObject* cpp_PyObject_CallMethod(PyObject* obj, const char* method_name,
+ const char* argspec, ArgTypes... args) {
+ return PyObject_CallMethod(obj, const_cast<char*>(method_name),
+ const_cast<char*>(argspec), args...);
+}
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/datetime.cc b/src/arrow/cpp/src/arrow/python/datetime.cc
new file mode 100644
index 000000000..8c954998f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/datetime.cc
@@ -0,0 +1,566 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#include "arrow/python/datetime.h"
+
+#include <algorithm>
+#include <chrono>
+#include <iomanip>
+
+#include "arrow/array.h"
+#include "arrow/python/arrow_to_python_internal.h"
+#include "arrow/python/common.h"
+#include "arrow/python/helpers.h"
+#include "arrow/python/platform.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/value_parsing.h"
+
+namespace arrow {
+namespace py {
+namespace internal {
+
+namespace {
+
+// Same as Regex '([+-])(0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])$'.
+// GCC 4.9 doesn't support regex, so handcode until support for it
+// is dropped.
+bool MatchFixedOffset(const std::string& tz, util::string_view* sign,
+ util::string_view* hour, util::string_view* minute) {
+ if (tz.size() < 5) {
+ return false;
+ }
+ const char* iter = tz.data();
+ if (*iter == '+' || *iter == '-') {
+ *sign = util::string_view(iter, 1);
+ iter++;
+ if (tz.size() < 6) {
+ return false;
+ }
+ }
+ if ((((*iter == '0' || *iter == '1') && *(iter + 1) >= '0' && *(iter + 1) <= '9') ||
+ (*iter == '2' && *(iter + 1) >= '0' && *(iter + 1) <= '3'))) {
+ *hour = util::string_view(iter, 2);
+ iter += 2;
+ } else {
+ return false;
+ }
+ if (*iter != ':') {
+ return false;
+ }
+ iter++;
+
+ if (*iter >= '0' && *iter <= '5' && *(iter + 1) >= '0' && *(iter + 1) <= '9') {
+ *minute = util::string_view(iter, 2);
+ iter += 2;
+ } else {
+ return false;
+ }
+ return iter == (tz.data() + tz.size());
+}
+
+static PyTypeObject MonthDayNanoTupleType = {};
+
+constexpr char* NonConst(const char* st) {
+ // Hack for python versions < 3.7 where members of PyStruct members
+ // where non-const (C++ doesn't like assigning string literals to these types)
+ return const_cast<char*>(st);
+}
+
+static PyStructSequence_Field MonthDayNanoField[] = {
+ {NonConst("months"), NonConst("The number of months in the interval")},
+ {NonConst("days"), NonConst("The number days in the interval")},
+ {NonConst("nanoseconds"), NonConst("The number of nanoseconds in the interval")},
+ {nullptr, nullptr}};
+
+static PyStructSequence_Desc MonthDayNanoTupleDesc = {
+ NonConst("MonthDayNano"),
+ NonConst("A calendar interval consisting of months, days and nanoseconds."),
+ MonthDayNanoField,
+ /*n_in_sequence=*/3};
+
+} // namespace
+
+PyDateTime_CAPI* datetime_api = nullptr;
+
+void InitDatetime() {
+ PyAcquireGIL lock;
+ datetime_api =
+ reinterpret_cast<PyDateTime_CAPI*>(PyCapsule_Import(PyDateTime_CAPSULE_NAME, 0));
+ if (datetime_api == nullptr) {
+ Py_FatalError("Could not import datetime C API");
+ }
+}
+
+// The following code is adapted from
+// https://github.com/numpy/numpy/blob/master/numpy/core/src/multiarray/datetime.c
+
+// Days per month, regular year and leap year
+static int64_t _days_per_month_table[2][12] = {
+ {31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31},
+ {31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}};
+
+static bool is_leapyear(int64_t year) {
+ return (year & 0x3) == 0 && // year % 4 == 0
+ ((year % 100) != 0 || (year % 400) == 0);
+}
+
+// Calculates the days offset from the 1970 epoch.
+static int64_t get_days_from_date(int64_t date_year, int64_t date_month,
+ int64_t date_day) {
+ int64_t i, month;
+ int64_t year, days = 0;
+ int64_t* month_lengths;
+
+ year = date_year - 1970;
+ days = year * 365;
+
+ // Adjust for leap years
+ if (days >= 0) {
+ // 1968 is the closest leap year before 1970.
+ // Exclude the current year, so add 1.
+ year += 1;
+ // Add one day for each 4 years
+ days += year / 4;
+ // 1900 is the closest previous year divisible by 100
+ year += 68;
+ // Subtract one day for each 100 years
+ days -= year / 100;
+ // 1600 is the closest previous year divisible by 400
+ year += 300;
+ // Add one day for each 400 years
+ days += year / 400;
+ } else {
+ // 1972 is the closest later year after 1970.
+ // Include the current year, so subtract 2.
+ year -= 2;
+ // Subtract one day for each 4 years
+ days += year / 4;
+ // 2000 is the closest later year divisible by 100
+ year -= 28;
+ // Add one day for each 100 years
+ days -= year / 100;
+ // 2000 is also the closest later year divisible by 400
+ // Subtract one day for each 400 years
+ days += year / 400;
+ }
+
+ month_lengths = _days_per_month_table[is_leapyear(date_year)];
+ month = date_month - 1;
+
+ // Add the months
+ for (i = 0; i < month; ++i) {
+ days += month_lengths[i];
+ }
+
+ // Add the days
+ days += date_day - 1;
+
+ return days;
+}
+
+// Modifies '*days_' to be the day offset within the year,
+// and returns the year.
+static int64_t days_to_yearsdays(int64_t* days_) {
+ const int64_t days_per_400years = (400 * 365 + 100 - 4 + 1);
+ // Adjust so it's relative to the year 2000 (divisible by 400)
+ int64_t days = (*days_) - (365 * 30 + 7);
+ int64_t year;
+
+ // Break down the 400 year cycle to get the year and day within the year
+ if (days >= 0) {
+ year = 400 * (days / days_per_400years);
+ days = days % days_per_400years;
+ } else {
+ year = 400 * ((days - (days_per_400years - 1)) / days_per_400years);
+ days = days % days_per_400years;
+ if (days < 0) {
+ days += days_per_400years;
+ }
+ }
+
+ // Work out the year/day within the 400 year cycle
+ if (days >= 366) {
+ year += 100 * ((days - 1) / (100 * 365 + 25 - 1));
+ days = (days - 1) % (100 * 365 + 25 - 1);
+ if (days >= 365) {
+ year += 4 * ((days + 1) / (4 * 365 + 1));
+ days = (days + 1) % (4 * 365 + 1);
+ if (days >= 366) {
+ year += (days - 1) / 365;
+ days = (days - 1) % 365;
+ }
+ }
+ }
+
+ *days_ = days;
+ return year + 2000;
+}
+
+// Extracts the month and year and day number from a number of days
+static void get_date_from_days(int64_t days, int64_t* date_year, int64_t* date_month,
+ int64_t* date_day) {
+ int64_t *month_lengths, i;
+
+ *date_year = days_to_yearsdays(&days);
+ month_lengths = _days_per_month_table[is_leapyear(*date_year)];
+
+ for (i = 0; i < 12; ++i) {
+ if (days < month_lengths[i]) {
+ *date_month = i + 1;
+ *date_day = days + 1;
+ return;
+ } else {
+ days -= month_lengths[i];
+ }
+ }
+
+ // Should never get here
+ return;
+}
+
+// Splitting time quantities, for example splitting total seconds into
+// minutes and remaining seconds. After we run
+// int64_t remaining = split_time(total, quotient, &next)
+// we have
+// total = next * quotient + remaining. Handles negative values by propagating
+// them: If total is negative, next will be negative and remaining will
+// always be non-negative.
+static inline int64_t split_time(int64_t total, int64_t quotient, int64_t* next) {
+ int64_t r = total % quotient;
+ if (r < 0) {
+ *next = total / quotient - 1;
+ return r + quotient;
+ } else {
+ *next = total / quotient;
+ return r;
+ }
+}
+
+static inline Status PyTime_convert_int(int64_t val, const TimeUnit::type unit,
+ int64_t* hour, int64_t* minute, int64_t* second,
+ int64_t* microsecond) {
+ switch (unit) {
+ case TimeUnit::NANO:
+ if (val % 1000 != 0) {
+ return Status::Invalid("Value ", val, " has non-zero nanoseconds");
+ }
+ val /= 1000;
+ // fall through
+ case TimeUnit::MICRO:
+ *microsecond = split_time(val, 1000000LL, &val);
+ *second = split_time(val, 60, &val);
+ *minute = split_time(val, 60, hour);
+ break;
+ case TimeUnit::MILLI:
+ *microsecond = split_time(val, 1000, &val) * 1000;
+ // fall through
+ case TimeUnit::SECOND:
+ *second = split_time(val, 60, &val);
+ *minute = split_time(val, 60, hour);
+ break;
+ default:
+ break;
+ }
+ return Status::OK();
+}
+
+static inline Status PyDate_convert_int(int64_t val, const DateUnit unit, int64_t* year,
+ int64_t* month, int64_t* day) {
+ switch (unit) {
+ case DateUnit::MILLI:
+ val /= 86400000LL; // fall through
+ case DateUnit::DAY:
+ get_date_from_days(val, year, month, day);
+ default:
+ break;
+ }
+ return Status::OK();
+}
+
+PyObject* NewMonthDayNanoTupleType() {
+ if (MonthDayNanoTupleType.tp_name == nullptr) {
+ if (PyStructSequence_InitType2(&MonthDayNanoTupleType, &MonthDayNanoTupleDesc) != 0) {
+ Py_FatalError("Could not initialize MonthDayNanoTuple");
+ }
+ }
+ Py_INCREF(&MonthDayNanoTupleType);
+ return (PyObject*)&MonthDayNanoTupleType;
+}
+
+Status PyTime_from_int(int64_t val, const TimeUnit::type unit, PyObject** out) {
+ int64_t hour = 0, minute = 0, second = 0, microsecond = 0;
+ RETURN_NOT_OK(PyTime_convert_int(val, unit, &hour, &minute, &second, &microsecond));
+ *out = PyTime_FromTime(static_cast<int32_t>(hour), static_cast<int32_t>(minute),
+ static_cast<int32_t>(second), static_cast<int32_t>(microsecond));
+ return Status::OK();
+}
+
+Status PyDate_from_int(int64_t val, const DateUnit unit, PyObject** out) {
+ int64_t year = 0, month = 0, day = 0;
+ RETURN_NOT_OK(PyDate_convert_int(val, unit, &year, &month, &day));
+ *out = PyDate_FromDate(static_cast<int32_t>(year), static_cast<int32_t>(month),
+ static_cast<int32_t>(day));
+ return Status::OK();
+}
+
+Status PyDateTime_from_int(int64_t val, const TimeUnit::type unit, PyObject** out) {
+ int64_t hour = 0, minute = 0, second = 0, microsecond = 0;
+ RETURN_NOT_OK(PyTime_convert_int(val, unit, &hour, &minute, &second, &microsecond));
+ int64_t total_days = 0;
+ hour = split_time(hour, 24, &total_days);
+ int64_t year = 0, month = 0, day = 0;
+ get_date_from_days(total_days, &year, &month, &day);
+ *out = PyDateTime_FromDateAndTime(
+ static_cast<int32_t>(year), static_cast<int32_t>(month), static_cast<int32_t>(day),
+ static_cast<int32_t>(hour), static_cast<int32_t>(minute),
+ static_cast<int32_t>(second), static_cast<int32_t>(microsecond));
+ return Status::OK();
+}
+
+int64_t PyDate_to_days(PyDateTime_Date* pydate) {
+ return get_days_from_date(PyDateTime_GET_YEAR(pydate), PyDateTime_GET_MONTH(pydate),
+ PyDateTime_GET_DAY(pydate));
+}
+
+Result<int64_t> PyDateTime_utcoffset_s(PyObject* obj) {
+ // calculate offset from UTC timezone in seconds
+ // supports only PyDateTime_DateTime and PyDateTime_Time objects
+ OwnedRef pyoffset(PyObject_CallMethod(obj, "utcoffset", NULL));
+ RETURN_IF_PYERROR();
+ if (pyoffset.obj() != nullptr && pyoffset.obj() != Py_None) {
+ auto delta = reinterpret_cast<PyDateTime_Delta*>(pyoffset.obj());
+ return internal::PyDelta_to_s(delta);
+ } else {
+ return 0;
+ }
+}
+
+Result<std::string> PyTZInfo_utcoffset_hhmm(PyObject* pytzinfo) {
+ // attempt to convert timezone offset objects to "+/-{hh}:{mm}" format
+ OwnedRef pydelta_object(PyObject_CallMethod(pytzinfo, "utcoffset", "O", Py_None));
+ RETURN_IF_PYERROR();
+
+ if (!PyDelta_Check(pydelta_object.obj())) {
+ return Status::Invalid(
+ "Object returned by tzinfo.utcoffset(None) is not an instance of "
+ "datetime.timedelta");
+ }
+ auto pydelta = reinterpret_cast<PyDateTime_Delta*>(pydelta_object.obj());
+
+ // retrieve the offset as seconds
+ auto total_seconds = internal::PyDelta_to_s(pydelta);
+
+ // determine whether the offset is positive or negative
+ auto sign = (total_seconds < 0) ? "-" : "+";
+ total_seconds = abs(total_seconds);
+
+ // calculate offset components
+ int64_t hours, minutes, seconds;
+ seconds = split_time(total_seconds, 60, &minutes);
+ minutes = split_time(minutes, 60, &hours);
+ if (seconds > 0) {
+ // check there are no remaining seconds
+ return Status::Invalid("Offset must represent whole number of minutes");
+ }
+
+ // construct the timezone string
+ std::stringstream stream;
+ stream << sign << std::setfill('0') << std::setw(2) << hours << ":" << std::setfill('0')
+ << std::setw(2) << minutes;
+ return stream.str();
+}
+
+// Converted from python. See https://github.com/apache/arrow/pull/7604
+// for details.
+Result<PyObject*> StringToTzinfo(const std::string& tz) {
+ util::string_view sign_str, hour_str, minute_str;
+ OwnedRef pytz;
+ RETURN_NOT_OK(internal::ImportModule("pytz", &pytz));
+
+ if (MatchFixedOffset(tz, &sign_str, &hour_str, &minute_str)) {
+ int sign = -1;
+ if (sign_str == "+") {
+ sign = 1;
+ }
+ OwnedRef fixed_offset;
+ RETURN_NOT_OK(internal::ImportFromModule(pytz.obj(), "FixedOffset", &fixed_offset));
+ uint32_t minutes, hours;
+ if (!::arrow::internal::ParseUnsigned(hour_str.data(), hour_str.size(), &hours) ||
+ !::arrow::internal::ParseUnsigned(minute_str.data(), minute_str.size(),
+ &minutes)) {
+ return Status::Invalid("Invalid timezone: ", tz);
+ }
+ OwnedRef total_minutes(PyLong_FromLong(
+ sign * ((static_cast<int>(hours) * 60) + static_cast<int>(minutes))));
+ RETURN_IF_PYERROR();
+ auto tzinfo =
+ PyObject_CallFunctionObjArgs(fixed_offset.obj(), total_minutes.obj(), NULL);
+ RETURN_IF_PYERROR();
+ return tzinfo;
+ }
+
+ OwnedRef timezone;
+ RETURN_NOT_OK(internal::ImportFromModule(pytz.obj(), "timezone", &timezone));
+ OwnedRef py_tz_string(
+ PyUnicode_FromStringAndSize(tz.c_str(), static_cast<Py_ssize_t>(tz.size())));
+ auto tzinfo = PyObject_CallFunctionObjArgs(timezone.obj(), py_tz_string.obj(), NULL);
+ RETURN_IF_PYERROR();
+ return tzinfo;
+}
+
+Result<std::string> TzinfoToString(PyObject* tzinfo) {
+ OwnedRef module_pytz; // import pytz
+ OwnedRef module_datetime; // import datetime
+ OwnedRef class_timezone; // from datetime import timezone
+ OwnedRef class_fixedoffset; // from pytz import _FixedOffset
+
+ // import necessary modules
+ RETURN_NOT_OK(internal::ImportModule("pytz", &module_pytz));
+ RETURN_NOT_OK(internal::ImportModule("datetime", &module_datetime));
+ // import necessary classes
+ RETURN_NOT_OK(
+ internal::ImportFromModule(module_pytz.obj(), "_FixedOffset", &class_fixedoffset));
+ RETURN_NOT_OK(
+ internal::ImportFromModule(module_datetime.obj(), "timezone", &class_timezone));
+
+ // check that it's a valid tzinfo object
+ if (!PyTZInfo_Check(tzinfo)) {
+ return Status::TypeError("Not an instance of datetime.tzinfo");
+ }
+
+ // if tzinfo is an instance of pytz._FixedOffset or datetime.timezone return the
+ // HH:MM offset string representation
+ if (PyObject_IsInstance(tzinfo, class_timezone.obj()) ||
+ PyObject_IsInstance(tzinfo, class_fixedoffset.obj())) {
+ // still recognize datetime.timezone.utc as UTC (instead of +00:00)
+ OwnedRef tzname_object(PyObject_CallMethod(tzinfo, "tzname", "O", Py_None));
+ RETURN_IF_PYERROR();
+ if (PyUnicode_Check(tzname_object.obj())) {
+ std::string result;
+ RETURN_NOT_OK(internal::PyUnicode_AsStdString(tzname_object.obj(), &result));
+ if (result == "UTC") {
+ return result;
+ }
+ }
+ return PyTZInfo_utcoffset_hhmm(tzinfo);
+ }
+
+ // try to look up zone attribute
+ if (PyObject_HasAttrString(tzinfo, "zone")) {
+ OwnedRef zone(PyObject_GetAttrString(tzinfo, "zone"));
+ RETURN_IF_PYERROR();
+ std::string result;
+ RETURN_NOT_OK(internal::PyUnicode_AsStdString(zone.obj(), &result));
+ return result;
+ }
+
+ // attempt to call tzinfo.tzname(None)
+ OwnedRef tzname_object(PyObject_CallMethod(tzinfo, "tzname", "O", Py_None));
+ RETURN_IF_PYERROR();
+ if (PyUnicode_Check(tzname_object.obj())) {
+ std::string result;
+ RETURN_NOT_OK(internal::PyUnicode_AsStdString(tzname_object.obj(), &result));
+ return result;
+ }
+
+ // fall back to HH:MM offset string representation based on tzinfo.utcoffset(None)
+ return PyTZInfo_utcoffset_hhmm(tzinfo);
+}
+
+PyObject* MonthDayNanoIntervalToNamedTuple(
+ const MonthDayNanoIntervalType::MonthDayNanos& interval) {
+ OwnedRef tuple(PyStructSequence_New(&MonthDayNanoTupleType));
+ if (ARROW_PREDICT_FALSE(tuple.obj() == nullptr)) {
+ return nullptr;
+ }
+ PyStructSequence_SetItem(tuple.obj(), /*pos=*/0, PyLong_FromLong(interval.months));
+ PyStructSequence_SetItem(tuple.obj(), /*pos=*/1, PyLong_FromLong(interval.days));
+ PyStructSequence_SetItem(tuple.obj(), /*pos=*/2,
+ PyLong_FromLongLong(interval.nanoseconds));
+ return tuple.detach();
+}
+
+namespace {
+
+// Wrapper around a Python list object that mimics dereference and assignment
+// operations.
+struct PyListAssigner {
+ public:
+ explicit PyListAssigner(PyObject* list) : list_(list) { DCHECK(PyList_Check(list_)); }
+
+ PyListAssigner& operator*() { return *this; }
+
+ void operator=(PyObject* obj) {
+ if (ARROW_PREDICT_FALSE(PyList_SetItem(list_, current_index_, obj) == -1)) {
+ Py_FatalError("list did not have the correct preallocated size.");
+ }
+ }
+
+ PyListAssigner& operator++() {
+ current_index_++;
+ return *this;
+ }
+
+ PyListAssigner& operator+=(int64_t offset) {
+ current_index_ += offset;
+ return *this;
+ }
+
+ private:
+ PyObject* list_;
+ int64_t current_index_ = 0;
+};
+
+} // namespace
+
+Result<PyObject*> MonthDayNanoIntervalArrayToPyList(
+ const MonthDayNanoIntervalArray& array) {
+ OwnedRef out_list(PyList_New(array.length()));
+ RETURN_IF_PYERROR();
+ PyListAssigner out_objects(out_list.obj());
+ auto& interval_array =
+ arrow::internal::checked_cast<const MonthDayNanoIntervalArray&>(array);
+ RETURN_NOT_OK(internal::WriteArrayObjects(
+ interval_array,
+ [&](const MonthDayNanoIntervalType::MonthDayNanos& interval, PyListAssigner& out) {
+ PyObject* tuple = internal::MonthDayNanoIntervalToNamedTuple(interval);
+ if (ARROW_PREDICT_FALSE(tuple == nullptr)) {
+ RETURN_IF_PYERROR();
+ }
+
+ *out = tuple;
+ return Status::OK();
+ },
+ out_objects));
+ return out_list.detach();
+}
+
+Result<PyObject*> MonthDayNanoIntervalScalarToPyObject(
+ const MonthDayNanoIntervalScalar& scalar) {
+ if (scalar.is_valid) {
+ return internal::MonthDayNanoIntervalToNamedTuple(scalar.value);
+ } else {
+ Py_INCREF(Py_None);
+ return Py_None;
+ }
+}
+
+} // namespace internal
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/datetime.h b/src/arrow/cpp/src/arrow/python/datetime.h
new file mode 100644
index 000000000..dd07710aa
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/datetime.h
@@ -0,0 +1,211 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <chrono>
+
+#include "arrow/python/platform.h"
+#include "arrow/python/visibility.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/logging.h"
+
+// By default, PyDateTimeAPI is a *static* variable. This forces
+// PyDateTime_IMPORT to be called in every C/C++ module using the
+// C datetime API. This is error-prone and potentially costly.
+// Instead, we redefine PyDateTimeAPI to point to a global variable,
+// which is initialized once by calling InitDatetime().
+#define PyDateTimeAPI ::arrow::py::internal::datetime_api
+
+namespace arrow {
+namespace py {
+namespace internal {
+
+extern PyDateTime_CAPI* datetime_api;
+
+ARROW_PYTHON_EXPORT
+void InitDatetime();
+
+// Returns the MonthDayNano namedtuple type (increments the reference count).
+ARROW_PYTHON_EXPORT
+PyObject* NewMonthDayNanoTupleType();
+
+ARROW_PYTHON_EXPORT
+inline int64_t PyTime_to_us(PyObject* pytime) {
+ return (PyDateTime_TIME_GET_HOUR(pytime) * 3600000000LL +
+ PyDateTime_TIME_GET_MINUTE(pytime) * 60000000LL +
+ PyDateTime_TIME_GET_SECOND(pytime) * 1000000LL +
+ PyDateTime_TIME_GET_MICROSECOND(pytime));
+}
+
+ARROW_PYTHON_EXPORT
+inline int64_t PyTime_to_s(PyObject* pytime) { return PyTime_to_us(pytime) / 1000000; }
+
+ARROW_PYTHON_EXPORT
+inline int64_t PyTime_to_ms(PyObject* pytime) { return PyTime_to_us(pytime) / 1000; }
+
+ARROW_PYTHON_EXPORT
+inline int64_t PyTime_to_ns(PyObject* pytime) { return PyTime_to_us(pytime) * 1000; }
+
+ARROW_PYTHON_EXPORT
+Status PyTime_from_int(int64_t val, const TimeUnit::type unit, PyObject** out);
+
+ARROW_PYTHON_EXPORT
+Status PyDate_from_int(int64_t val, const DateUnit unit, PyObject** out);
+
+// WARNING: This function returns a naive datetime.
+ARROW_PYTHON_EXPORT
+Status PyDateTime_from_int(int64_t val, const TimeUnit::type unit, PyObject** out);
+
+// This declaration must be the same as in filesystem/filesystem.h
+using TimePoint =
+ std::chrono::time_point<std::chrono::system_clock, std::chrono::nanoseconds>;
+
+ARROW_PYTHON_EXPORT
+int64_t PyDate_to_days(PyDateTime_Date* pydate);
+
+ARROW_PYTHON_EXPORT
+inline int64_t PyDate_to_s(PyDateTime_Date* pydate) {
+ return PyDate_to_days(pydate) * 86400LL;
+}
+
+ARROW_PYTHON_EXPORT
+inline int64_t PyDate_to_ms(PyDateTime_Date* pydate) {
+ return PyDate_to_days(pydate) * 86400000LL;
+}
+
+ARROW_PYTHON_EXPORT
+inline int64_t PyDateTime_to_s(PyDateTime_DateTime* pydatetime) {
+ return (PyDate_to_s(reinterpret_cast<PyDateTime_Date*>(pydatetime)) +
+ PyDateTime_DATE_GET_HOUR(pydatetime) * 3600LL +
+ PyDateTime_DATE_GET_MINUTE(pydatetime) * 60LL +
+ PyDateTime_DATE_GET_SECOND(pydatetime));
+}
+
+ARROW_PYTHON_EXPORT
+inline int64_t PyDateTime_to_ms(PyDateTime_DateTime* pydatetime) {
+ return (PyDateTime_to_s(pydatetime) * 1000LL +
+ PyDateTime_DATE_GET_MICROSECOND(pydatetime) / 1000);
+}
+
+ARROW_PYTHON_EXPORT
+inline int64_t PyDateTime_to_us(PyDateTime_DateTime* pydatetime) {
+ return (PyDateTime_to_s(pydatetime) * 1000000LL +
+ PyDateTime_DATE_GET_MICROSECOND(pydatetime));
+}
+
+ARROW_PYTHON_EXPORT
+inline int64_t PyDateTime_to_ns(PyDateTime_DateTime* pydatetime) {
+ return PyDateTime_to_us(pydatetime) * 1000LL;
+}
+
+ARROW_PYTHON_EXPORT
+inline TimePoint PyDateTime_to_TimePoint(PyDateTime_DateTime* pydatetime) {
+ return TimePoint(TimePoint::duration(PyDateTime_to_ns(pydatetime)));
+}
+
+ARROW_PYTHON_EXPORT
+inline int64_t TimePoint_to_ns(TimePoint val) { return val.time_since_epoch().count(); }
+
+ARROW_PYTHON_EXPORT
+inline TimePoint TimePoint_from_s(double val) {
+ return TimePoint(TimePoint::duration(static_cast<int64_t>(1e9 * val)));
+}
+
+ARROW_PYTHON_EXPORT
+inline TimePoint TimePoint_from_ns(int64_t val) {
+ return TimePoint(TimePoint::duration(val));
+}
+
+ARROW_PYTHON_EXPORT
+inline int64_t PyDelta_to_s(PyDateTime_Delta* pytimedelta) {
+ return (PyDateTime_DELTA_GET_DAYS(pytimedelta) * 86400LL +
+ PyDateTime_DELTA_GET_SECONDS(pytimedelta));
+}
+
+ARROW_PYTHON_EXPORT
+inline int64_t PyDelta_to_ms(PyDateTime_Delta* pytimedelta) {
+ return (PyDelta_to_s(pytimedelta) * 1000LL +
+ PyDateTime_DELTA_GET_MICROSECONDS(pytimedelta) / 1000);
+}
+
+ARROW_PYTHON_EXPORT
+inline int64_t PyDelta_to_us(PyDateTime_Delta* pytimedelta) {
+ return (PyDelta_to_s(pytimedelta) * 1000000LL +
+ PyDateTime_DELTA_GET_MICROSECONDS(pytimedelta));
+}
+
+ARROW_PYTHON_EXPORT
+inline int64_t PyDelta_to_ns(PyDateTime_Delta* pytimedelta) {
+ return PyDelta_to_us(pytimedelta) * 1000LL;
+}
+
+ARROW_PYTHON_EXPORT
+Result<int64_t> PyDateTime_utcoffset_s(PyObject* pydatetime);
+
+/// \brief Convert a time zone name into a time zone object.
+///
+/// Supported input strings are:
+/// * As used in the Olson time zone database (the "tz database" or
+/// "tzdata"), such as "America/New_York"
+/// * An absolute time zone offset of the form +XX:XX or -XX:XX, such as +07:30
+/// GIL must be held when calling this method.
+ARROW_PYTHON_EXPORT
+Result<PyObject*> StringToTzinfo(const std::string& tz);
+
+/// \brief Convert a time zone object to a string representation.
+///
+/// The output strings are:
+/// * An absolute time zone offset of the form +XX:XX or -XX:XX, such as +07:30
+/// if the input object is either an instance of pytz._FixedOffset or
+/// datetime.timedelta
+/// * The timezone's name if the input object's tzname() method returns with a
+/// non-empty timezone name such as "UTC" or "America/New_York"
+///
+/// GIL must be held when calling this method.
+ARROW_PYTHON_EXPORT
+Result<std::string> TzinfoToString(PyObject* pytzinfo);
+
+/// \brief Convert MonthDayNano to a python namedtuple.
+///
+/// Return a named tuple (pyarrow.MonthDayNano) containing attributes
+/// "months", "days", "nanoseconds" in the given order
+/// with values extracted from the fields on interval.
+///
+/// GIL must be held when calling this method.
+ARROW_PYTHON_EXPORT
+PyObject* MonthDayNanoIntervalToNamedTuple(
+ const MonthDayNanoIntervalType::MonthDayNanos& interval);
+
+/// \brief Convert the given Array to a PyList object containing
+/// pyarrow.MonthDayNano objects.
+ARROW_PYTHON_EXPORT
+Result<PyObject*> MonthDayNanoIntervalArrayToPyList(
+ const MonthDayNanoIntervalArray& array);
+
+/// \brief Convert the Scalar obect to a pyarrow.MonthDayNano (or None if
+/// is isn't valid).
+ARROW_PYTHON_EXPORT
+Result<PyObject*> MonthDayNanoIntervalScalarToPyObject(
+ const MonthDayNanoIntervalScalar& scalar);
+
+} // namespace internal
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/decimal.cc b/src/arrow/cpp/src/arrow/python/decimal.cc
new file mode 100644
index 000000000..0c00fcfaa
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/decimal.cc
@@ -0,0 +1,246 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <limits>
+
+#include "arrow/python/common.h"
+#include "arrow/python/decimal.h"
+#include "arrow/python/helpers.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace py {
+namespace internal {
+
+Status ImportDecimalType(OwnedRef* decimal_type) {
+ OwnedRef decimal_module;
+ RETURN_NOT_OK(ImportModule("decimal", &decimal_module));
+ RETURN_NOT_OK(ImportFromModule(decimal_module.obj(), "Decimal", decimal_type));
+ return Status::OK();
+}
+
+Status PythonDecimalToString(PyObject* python_decimal, std::string* out) {
+ // Call Python's str(decimal_object)
+ return PyObject_StdStringStr(python_decimal, out);
+}
+
+// \brief Infer the precision and scale of a Python decimal.Decimal instance
+// \param python_decimal[in] An instance of decimal.Decimal
+// \param precision[out] The value of the inferred precision
+// \param scale[out] The value of the inferred scale
+// \return The status of the operation
+static Status InferDecimalPrecisionAndScale(PyObject* python_decimal, int32_t* precision,
+ int32_t* scale) {
+ DCHECK_NE(python_decimal, NULLPTR);
+ DCHECK_NE(precision, NULLPTR);
+ DCHECK_NE(scale, NULLPTR);
+
+ // TODO(phillipc): Make sure we perform PyDecimal_Check(python_decimal) as a DCHECK
+ OwnedRef as_tuple(PyObject_CallMethod(python_decimal, const_cast<char*>("as_tuple"),
+ const_cast<char*>("")));
+ RETURN_IF_PYERROR();
+ DCHECK(PyTuple_Check(as_tuple.obj()));
+
+ OwnedRef digits(PyObject_GetAttrString(as_tuple.obj(), "digits"));
+ RETURN_IF_PYERROR();
+ DCHECK(PyTuple_Check(digits.obj()));
+
+ const auto num_digits = static_cast<int32_t>(PyTuple_Size(digits.obj()));
+ RETURN_IF_PYERROR();
+
+ OwnedRef py_exponent(PyObject_GetAttrString(as_tuple.obj(), "exponent"));
+ RETURN_IF_PYERROR();
+ DCHECK(IsPyInteger(py_exponent.obj()));
+
+ const auto exponent = static_cast<int32_t>(PyLong_AsLong(py_exponent.obj()));
+ RETURN_IF_PYERROR();
+
+ if (exponent < 0) {
+ // If exponent > num_digits, we have a number with leading zeros
+ // such as 0.01234. Ensure we have enough precision for leading zeros
+ // (which are not included in num_digits).
+ *precision = std::max(num_digits, -exponent);
+ *scale = -exponent;
+ } else {
+ // Trailing zeros are not included in num_digits, need to add to precision.
+ // Note we don't generate negative scales as they are poorly supported
+ // in non-Arrow systems.
+ *precision = num_digits + exponent;
+ *scale = 0;
+ }
+ return Status::OK();
+}
+
+PyObject* DecimalFromString(PyObject* decimal_constructor,
+ const std::string& decimal_string) {
+ DCHECK_NE(decimal_constructor, nullptr);
+
+ auto string_size = decimal_string.size();
+ DCHECK_GT(string_size, 0);
+
+ auto string_bytes = decimal_string.c_str();
+ DCHECK_NE(string_bytes, nullptr);
+
+ return PyObject_CallFunction(decimal_constructor, const_cast<char*>("s#"), string_bytes,
+ static_cast<Py_ssize_t>(string_size));
+}
+
+namespace {
+
+template <typename ArrowDecimal>
+Status DecimalFromStdString(const std::string& decimal_string,
+ const DecimalType& arrow_type, ArrowDecimal* out) {
+ int32_t inferred_precision;
+ int32_t inferred_scale;
+
+ RETURN_NOT_OK(ArrowDecimal::FromString(decimal_string, out, &inferred_precision,
+ &inferred_scale));
+
+ const int32_t precision = arrow_type.precision();
+ const int32_t scale = arrow_type.scale();
+
+ if (scale != inferred_scale) {
+ DCHECK_NE(out, NULLPTR);
+ ARROW_ASSIGN_OR_RAISE(*out, out->Rescale(inferred_scale, scale));
+ }
+
+ auto inferred_scale_delta = inferred_scale - scale;
+ if (ARROW_PREDICT_FALSE((inferred_precision - inferred_scale_delta) > precision)) {
+ return Status::Invalid(
+ "Decimal type with precision ", inferred_precision,
+ " does not fit into precision inferred from first array element: ", precision);
+ }
+
+ return Status::OK();
+}
+
+template <typename ArrowDecimal>
+Status InternalDecimalFromPythonDecimal(PyObject* python_decimal,
+ const DecimalType& arrow_type,
+ ArrowDecimal* out) {
+ DCHECK_NE(python_decimal, NULLPTR);
+ DCHECK_NE(out, NULLPTR);
+
+ std::string string;
+ RETURN_NOT_OK(PythonDecimalToString(python_decimal, &string));
+ return DecimalFromStdString(string, arrow_type, out);
+}
+
+template <typename ArrowDecimal>
+Status InternalDecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type,
+ ArrowDecimal* out) {
+ DCHECK_NE(obj, NULLPTR);
+ DCHECK_NE(out, NULLPTR);
+
+ if (IsPyInteger(obj)) {
+ // TODO: add a fast path for small-ish ints
+ std::string string;
+ RETURN_NOT_OK(PyObject_StdStringStr(obj, &string));
+ return DecimalFromStdString(string, arrow_type, out);
+ } else if (PyDecimal_Check(obj)) {
+ return InternalDecimalFromPythonDecimal<ArrowDecimal>(obj, arrow_type, out);
+ } else {
+ return Status::TypeError("int or Decimal object expected, got ",
+ Py_TYPE(obj)->tp_name);
+ }
+}
+
+} // namespace
+
+Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arrow_type,
+ Decimal128* out) {
+ return InternalDecimalFromPythonDecimal(python_decimal, arrow_type, out);
+}
+
+Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type,
+ Decimal128* out) {
+ return InternalDecimalFromPyObject(obj, arrow_type, out);
+}
+
+Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arrow_type,
+ Decimal256* out) {
+ return InternalDecimalFromPythonDecimal(python_decimal, arrow_type, out);
+}
+
+Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type,
+ Decimal256* out) {
+ return InternalDecimalFromPyObject(obj, arrow_type, out);
+}
+
+bool PyDecimal_Check(PyObject* obj) {
+ static OwnedRef decimal_type;
+ if (!decimal_type.obj()) {
+ ARROW_CHECK_OK(ImportDecimalType(&decimal_type));
+ DCHECK(PyType_Check(decimal_type.obj()));
+ }
+ // PyObject_IsInstance() is slower as it has to check for virtual subclasses
+ const int result =
+ PyType_IsSubtype(Py_TYPE(obj), reinterpret_cast<PyTypeObject*>(decimal_type.obj()));
+ ARROW_CHECK_NE(result, -1) << " error during PyType_IsSubtype check";
+ return result == 1;
+}
+
+bool PyDecimal_ISNAN(PyObject* obj) {
+ DCHECK(PyDecimal_Check(obj)) << "obj is not an instance of decimal.Decimal";
+ OwnedRef is_nan(
+ PyObject_CallMethod(obj, const_cast<char*>("is_nan"), const_cast<char*>("")));
+ return PyObject_IsTrue(is_nan.obj()) == 1;
+}
+
+DecimalMetadata::DecimalMetadata()
+ : DecimalMetadata(std::numeric_limits<int32_t>::min(),
+ std::numeric_limits<int32_t>::min()) {}
+
+DecimalMetadata::DecimalMetadata(int32_t precision, int32_t scale)
+ : precision_(precision), scale_(scale) {}
+
+Status DecimalMetadata::Update(int32_t suggested_precision, int32_t suggested_scale) {
+ const int32_t current_scale = scale_;
+ scale_ = std::max(current_scale, suggested_scale);
+
+ const int32_t current_precision = precision_;
+
+ if (current_precision == std::numeric_limits<int32_t>::min()) {
+ precision_ = suggested_precision;
+ } else {
+ auto num_digits = std::max(current_precision - current_scale,
+ suggested_precision - suggested_scale);
+ precision_ = std::max(num_digits + scale_, current_precision);
+ }
+
+ return Status::OK();
+}
+
+Status DecimalMetadata::Update(PyObject* object) {
+ bool is_decimal = PyDecimal_Check(object);
+
+ if (ARROW_PREDICT_FALSE(!is_decimal || PyDecimal_ISNAN(object))) {
+ return Status::OK();
+ }
+
+ int32_t precision = 0;
+ int32_t scale = 0;
+ RETURN_NOT_OK(InferDecimalPrecisionAndScale(object, &precision, &scale));
+ return Update(precision, scale);
+}
+
+} // namespace internal
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/decimal.h b/src/arrow/cpp/src/arrow/python/decimal.h
new file mode 100644
index 000000000..1187037ae
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/decimal.h
@@ -0,0 +1,128 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+
+#include "arrow/python/visibility.h"
+#include "arrow/type.h"
+
+namespace arrow {
+
+class Decimal128;
+class Decimal256;
+
+namespace py {
+
+class OwnedRef;
+
+//
+// Python Decimal support
+//
+
+namespace internal {
+
+// \brief Import the Python Decimal type
+ARROW_PYTHON_EXPORT
+Status ImportDecimalType(OwnedRef* decimal_type);
+
+// \brief Convert a Python Decimal object to a C++ string
+// \param[in] python_decimal A Python decimal.Decimal instance
+// \param[out] The string representation of the Python Decimal instance
+// \return The status of the operation
+ARROW_PYTHON_EXPORT
+Status PythonDecimalToString(PyObject* python_decimal, std::string* out);
+
+// \brief Convert a C++ std::string to a Python Decimal instance
+// \param[in] decimal_constructor The decimal type object
+// \param[in] decimal_string A decimal string
+// \return An instance of decimal.Decimal
+ARROW_PYTHON_EXPORT
+PyObject* DecimalFromString(PyObject* decimal_constructor,
+ const std::string& decimal_string);
+
+// \brief Convert a Python decimal to an Arrow Decimal128 object
+// \param[in] python_decimal A Python decimal.Decimal instance
+// \param[in] arrow_type An instance of arrow::DecimalType
+// \param[out] out A pointer to a Decimal128
+// \return The status of the operation
+ARROW_PYTHON_EXPORT
+Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arrow_type,
+ Decimal128* out);
+
+// \brief Convert a Python object to an Arrow Decimal128 object
+// \param[in] python_decimal A Python int or decimal.Decimal instance
+// \param[in] arrow_type An instance of arrow::DecimalType
+// \param[out] out A pointer to a Decimal128
+// \return The status of the operation
+ARROW_PYTHON_EXPORT
+Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, Decimal128* out);
+
+// \brief Convert a Python decimal to an Arrow Decimal256 object
+// \param[in] python_decimal A Python decimal.Decimal instance
+// \param[in] arrow_type An instance of arrow::DecimalType
+// \param[out] out A pointer to a Decimal256
+// \return The status of the operation
+ARROW_PYTHON_EXPORT
+Status DecimalFromPythonDecimal(PyObject* python_decimal, const DecimalType& arrow_type,
+ Decimal256* out);
+
+// \brief Convert a Python object to an Arrow Decimal256 object
+// \param[in] python_decimal A Python int or decimal.Decimal instance
+// \param[in] arrow_type An instance of arrow::DecimalType
+// \param[out] out A pointer to a Decimal256
+// \return The status of the operation
+ARROW_PYTHON_EXPORT
+Status DecimalFromPyObject(PyObject* obj, const DecimalType& arrow_type, Decimal256* out);
+
+// \brief Check whether obj is an instance of Decimal
+ARROW_PYTHON_EXPORT
+bool PyDecimal_Check(PyObject* obj);
+
+// \brief Check whether obj is nan. This function will abort the program if the argument
+// is not a Decimal instance
+ARROW_PYTHON_EXPORT
+bool PyDecimal_ISNAN(PyObject* obj);
+
+// \brief Helper class to track and update the precision and scale of a decimal
+class ARROW_PYTHON_EXPORT DecimalMetadata {
+ public:
+ DecimalMetadata();
+ DecimalMetadata(int32_t precision, int32_t scale);
+
+ // \brief Adjust the precision and scale of a decimal type given a new precision and a
+ // new scale \param[in] suggested_precision A candidate precision \param[in]
+ // suggested_scale A candidate scale \return The status of the operation
+ Status Update(int32_t suggested_precision, int32_t suggested_scale);
+
+ // \brief A convenient interface for updating the precision and scale based on a Python
+ // Decimal object \param object A Python Decimal object \return The status of the
+ // operation
+ Status Update(PyObject* object);
+
+ int32_t precision() const { return precision_; }
+ int32_t scale() const { return scale_; }
+
+ private:
+ int32_t precision_;
+ int32_t scale_;
+};
+
+} // namespace internal
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/deserialize.cc b/src/arrow/cpp/src/arrow/python/deserialize.cc
new file mode 100644
index 000000000..961a1686e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/deserialize.cc
@@ -0,0 +1,495 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/python/deserialize.h"
+
+#include "arrow/python/numpy_interop.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <numpy/arrayobject.h>
+#include <numpy/arrayscalars.h>
+
+#include "arrow/array.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/options.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/util.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/table.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/value_parsing.h"
+
+#include "arrow/python/common.h"
+#include "arrow/python/datetime.h"
+#include "arrow/python/helpers.h"
+#include "arrow/python/numpy_convert.h"
+#include "arrow/python/pyarrow.h"
+#include "arrow/python/serialize.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::ParseValue;
+
+namespace py {
+
+Status CallDeserializeCallback(PyObject* context, PyObject* value,
+ PyObject** deserialized_object);
+
+Status DeserializeTuple(PyObject* context, const Array& array, int64_t start_idx,
+ int64_t stop_idx, PyObject* base, const SerializedPyObject& blobs,
+ PyObject** out);
+
+Status DeserializeList(PyObject* context, const Array& array, int64_t start_idx,
+ int64_t stop_idx, PyObject* base, const SerializedPyObject& blobs,
+ PyObject** out);
+
+Status DeserializeSet(PyObject* context, const Array& array, int64_t start_idx,
+ int64_t stop_idx, PyObject* base, const SerializedPyObject& blobs,
+ PyObject** out);
+
+Status DeserializeDict(PyObject* context, const Array& array, int64_t start_idx,
+ int64_t stop_idx, PyObject* base, const SerializedPyObject& blobs,
+ PyObject** out) {
+ const auto& data = checked_cast<const StructArray&>(array);
+ OwnedRef keys, vals;
+ OwnedRef result(PyDict_New());
+ RETURN_IF_PYERROR();
+
+ DCHECK_EQ(2, data.num_fields());
+
+ RETURN_NOT_OK(DeserializeList(context, *data.field(0), start_idx, stop_idx, base, blobs,
+ keys.ref()));
+ RETURN_NOT_OK(DeserializeList(context, *data.field(1), start_idx, stop_idx, base, blobs,
+ vals.ref()));
+ for (int64_t i = start_idx; i < stop_idx; ++i) {
+ // PyDict_SetItem behaves differently from PyList_SetItem and PyTuple_SetItem.
+ // The latter two steal references whereas PyDict_SetItem does not. So we need
+ // to make sure the reference count is decremented by letting the OwnedRef
+ // go out of scope at the end.
+ int ret = PyDict_SetItem(result.obj(), PyList_GET_ITEM(keys.obj(), i - start_idx),
+ PyList_GET_ITEM(vals.obj(), i - start_idx));
+ if (ret != 0) {
+ return ConvertPyError();
+ }
+ }
+ static PyObject* py_type = PyUnicode_FromString("_pytype_");
+ if (PyDict_Contains(result.obj(), py_type)) {
+ RETURN_NOT_OK(CallDeserializeCallback(context, result.obj(), out));
+ } else {
+ *out = result.detach();
+ }
+ return Status::OK();
+}
+
+Status DeserializeArray(int32_t index, PyObject* base, const SerializedPyObject& blobs,
+ PyObject** out) {
+ RETURN_NOT_OK(py::TensorToNdarray(blobs.ndarrays[index], base, out));
+ // Mark the array as immutable
+ OwnedRef flags(PyObject_GetAttrString(*out, "flags"));
+ if (flags.obj() == NULL) {
+ return ConvertPyError();
+ }
+ if (PyObject_SetAttrString(flags.obj(), "writeable", Py_False) < 0) {
+ return ConvertPyError();
+ }
+ return Status::OK();
+}
+
+Status GetValue(PyObject* context, const Array& arr, int64_t index, int8_t type,
+ PyObject* base, const SerializedPyObject& blobs, PyObject** result) {
+ switch (type) {
+ case PythonType::NONE:
+ Py_INCREF(Py_None);
+ *result = Py_None;
+ return Status::OK();
+ case PythonType::BOOL:
+ *result = PyBool_FromLong(checked_cast<const BooleanArray&>(arr).Value(index));
+ return Status::OK();
+ case PythonType::PY2INT:
+ case PythonType::INT: {
+ *result = PyLong_FromSsize_t(checked_cast<const Int64Array&>(arr).Value(index));
+ return Status::OK();
+ }
+ case PythonType::BYTES: {
+ auto view = checked_cast<const BinaryArray&>(arr).GetView(index);
+ *result = PyBytes_FromStringAndSize(view.data(), view.length());
+ return CheckPyError();
+ }
+ case PythonType::STRING: {
+ auto view = checked_cast<const StringArray&>(arr).GetView(index);
+ *result = PyUnicode_FromStringAndSize(view.data(), view.length());
+ return CheckPyError();
+ }
+ case PythonType::HALF_FLOAT: {
+ *result = PyHalf_FromHalf(checked_cast<const HalfFloatArray&>(arr).Value(index));
+ RETURN_IF_PYERROR();
+ return Status::OK();
+ }
+ case PythonType::FLOAT:
+ *result = PyFloat_FromDouble(checked_cast<const FloatArray&>(arr).Value(index));
+ return Status::OK();
+ case PythonType::DOUBLE:
+ *result = PyFloat_FromDouble(checked_cast<const DoubleArray&>(arr).Value(index));
+ return Status::OK();
+ case PythonType::DATE64: {
+ RETURN_NOT_OK(internal::PyDateTime_from_int(
+ checked_cast<const Date64Array&>(arr).Value(index), TimeUnit::MICRO, result));
+ RETURN_IF_PYERROR();
+ return Status::OK();
+ }
+ case PythonType::LIST: {
+ const auto& l = checked_cast<const ListArray&>(arr);
+ return DeserializeList(context, *l.values(), l.value_offset(index),
+ l.value_offset(index + 1), base, blobs, result);
+ }
+ case PythonType::DICT: {
+ const auto& l = checked_cast<const ListArray&>(arr);
+ return DeserializeDict(context, *l.values(), l.value_offset(index),
+ l.value_offset(index + 1), base, blobs, result);
+ }
+ case PythonType::TUPLE: {
+ const auto& l = checked_cast<const ListArray&>(arr);
+ return DeserializeTuple(context, *l.values(), l.value_offset(index),
+ l.value_offset(index + 1), base, blobs, result);
+ }
+ case PythonType::SET: {
+ const auto& l = checked_cast<const ListArray&>(arr);
+ return DeserializeSet(context, *l.values(), l.value_offset(index),
+ l.value_offset(index + 1), base, blobs, result);
+ }
+ case PythonType::TENSOR: {
+ int32_t ref = checked_cast<const Int32Array&>(arr).Value(index);
+ *result = wrap_tensor(blobs.tensors[ref]);
+ return Status::OK();
+ }
+ case PythonType::SPARSECOOTENSOR: {
+ int32_t ref = checked_cast<const Int32Array&>(arr).Value(index);
+ const std::shared_ptr<SparseCOOTensor>& sparse_coo_tensor =
+ arrow::internal::checked_pointer_cast<SparseCOOTensor>(
+ blobs.sparse_tensors[ref]);
+ *result = wrap_sparse_coo_tensor(sparse_coo_tensor);
+ return Status::OK();
+ }
+ case PythonType::SPARSECSRMATRIX: {
+ int32_t ref = checked_cast<const Int32Array&>(arr).Value(index);
+ const std::shared_ptr<SparseCSRMatrix>& sparse_csr_matrix =
+ arrow::internal::checked_pointer_cast<SparseCSRMatrix>(
+ blobs.sparse_tensors[ref]);
+ *result = wrap_sparse_csr_matrix(sparse_csr_matrix);
+ return Status::OK();
+ }
+ case PythonType::SPARSECSCMATRIX: {
+ int32_t ref = checked_cast<const Int32Array&>(arr).Value(index);
+ const std::shared_ptr<SparseCSCMatrix>& sparse_csc_matrix =
+ arrow::internal::checked_pointer_cast<SparseCSCMatrix>(
+ blobs.sparse_tensors[ref]);
+ *result = wrap_sparse_csc_matrix(sparse_csc_matrix);
+ return Status::OK();
+ }
+ case PythonType::SPARSECSFTENSOR: {
+ int32_t ref = checked_cast<const Int32Array&>(arr).Value(index);
+ const std::shared_ptr<SparseCSFTensor>& sparse_csf_tensor =
+ arrow::internal::checked_pointer_cast<SparseCSFTensor>(
+ blobs.sparse_tensors[ref]);
+ *result = wrap_sparse_csf_tensor(sparse_csf_tensor);
+ return Status::OK();
+ }
+ case PythonType::NDARRAY: {
+ int32_t ref = checked_cast<const Int32Array&>(arr).Value(index);
+ return DeserializeArray(ref, base, blobs, result);
+ }
+ case PythonType::BUFFER: {
+ int32_t ref = checked_cast<const Int32Array&>(arr).Value(index);
+ *result = wrap_buffer(blobs.buffers[ref]);
+ return Status::OK();
+ }
+ default: {
+ ARROW_CHECK(false) << "union tag " << type << "' not recognized";
+ }
+ }
+ return Status::OK();
+}
+
+Status GetPythonTypes(const UnionArray& data, std::vector<int8_t>* result) {
+ ARROW_CHECK(result != nullptr);
+ auto type = data.type();
+ for (int i = 0; i < type->num_fields(); ++i) {
+ int8_t tag = 0;
+ const std::string& data = type->field(i)->name();
+ if (!ParseValue<Int8Type>(data.c_str(), data.size(), &tag)) {
+ return Status::SerializationError("Cannot convert string: \"",
+ type->field(i)->name(), "\" to int8_t");
+ }
+ result->push_back(tag);
+ }
+ return Status::OK();
+}
+
+template <typename CreateSequenceFn, typename SetItemFn>
+Status DeserializeSequence(PyObject* context, const Array& array, int64_t start_idx,
+ int64_t stop_idx, PyObject* base,
+ const SerializedPyObject& blobs,
+ CreateSequenceFn&& create_sequence, SetItemFn&& set_item,
+ PyObject** out) {
+ const auto& data = checked_cast<const DenseUnionArray&>(array);
+ OwnedRef result(create_sequence(stop_idx - start_idx));
+ RETURN_IF_PYERROR();
+ const int8_t* type_codes = data.raw_type_codes();
+ const int32_t* value_offsets = data.raw_value_offsets();
+ std::vector<int8_t> python_types;
+ RETURN_NOT_OK(GetPythonTypes(data, &python_types));
+ for (int64_t i = start_idx; i < stop_idx; ++i) {
+ const int64_t offset = value_offsets[i];
+ const uint8_t type = type_codes[i];
+ PyObject* value;
+ RETURN_NOT_OK(GetValue(context, *data.field(type), offset, python_types[type], base,
+ blobs, &value));
+ RETURN_NOT_OK(set_item(result.obj(), i - start_idx, value));
+ }
+ *out = result.detach();
+ return Status::OK();
+}
+
+Status DeserializeList(PyObject* context, const Array& array, int64_t start_idx,
+ int64_t stop_idx, PyObject* base, const SerializedPyObject& blobs,
+ PyObject** out) {
+ return DeserializeSequence(
+ context, array, start_idx, stop_idx, base, blobs,
+ [](int64_t size) { return PyList_New(size); },
+ [](PyObject* seq, int64_t index, PyObject* item) {
+ PyList_SET_ITEM(seq, index, item);
+ return Status::OK();
+ },
+ out);
+}
+
+Status DeserializeTuple(PyObject* context, const Array& array, int64_t start_idx,
+ int64_t stop_idx, PyObject* base, const SerializedPyObject& blobs,
+ PyObject** out) {
+ return DeserializeSequence(
+ context, array, start_idx, stop_idx, base, blobs,
+ [](int64_t size) { return PyTuple_New(size); },
+ [](PyObject* seq, int64_t index, PyObject* item) {
+ PyTuple_SET_ITEM(seq, index, item);
+ return Status::OK();
+ },
+ out);
+}
+
+Status DeserializeSet(PyObject* context, const Array& array, int64_t start_idx,
+ int64_t stop_idx, PyObject* base, const SerializedPyObject& blobs,
+ PyObject** out) {
+ return DeserializeSequence(
+ context, array, start_idx, stop_idx, base, blobs,
+ [](int64_t size) { return PySet_New(nullptr); },
+ [](PyObject* seq, int64_t index, PyObject* item) {
+ int err = PySet_Add(seq, item);
+ Py_DECREF(item);
+ if (err < 0) {
+ RETURN_IF_PYERROR();
+ }
+ return Status::OK();
+ },
+ out);
+}
+
+Status ReadSerializedObject(io::RandomAccessFile* src, SerializedPyObject* out) {
+ int32_t num_tensors;
+ int32_t num_sparse_tensors;
+ int32_t num_ndarrays;
+ int32_t num_buffers;
+
+ // Read number of tensors
+ RETURN_NOT_OK(src->Read(sizeof(int32_t), reinterpret_cast<uint8_t*>(&num_tensors)));
+ RETURN_NOT_OK(
+ src->Read(sizeof(int32_t), reinterpret_cast<uint8_t*>(&num_sparse_tensors)));
+ RETURN_NOT_OK(src->Read(sizeof(int32_t), reinterpret_cast<uint8_t*>(&num_ndarrays)));
+ RETURN_NOT_OK(src->Read(sizeof(int32_t), reinterpret_cast<uint8_t*>(&num_buffers)));
+
+ // Align stream to 8-byte offset
+ RETURN_NOT_OK(ipc::AlignStream(src, ipc::kArrowIpcAlignment));
+ std::shared_ptr<RecordBatchReader> reader;
+ ARROW_ASSIGN_OR_RAISE(reader, ipc::RecordBatchStreamReader::Open(src));
+ RETURN_NOT_OK(reader->ReadNext(&out->batch));
+
+ /// Skip EOS marker
+ RETURN_NOT_OK(src->Advance(4));
+
+ /// Align stream so tensor bodies are 64-byte aligned
+ RETURN_NOT_OK(ipc::AlignStream(src, ipc::kTensorAlignment));
+
+ for (int i = 0; i < num_tensors; ++i) {
+ std::shared_ptr<Tensor> tensor;
+ ARROW_ASSIGN_OR_RAISE(tensor, ipc::ReadTensor(src));
+ RETURN_NOT_OK(ipc::AlignStream(src, ipc::kTensorAlignment));
+ out->tensors.push_back(tensor);
+ }
+
+ for (int i = 0; i < num_sparse_tensors; ++i) {
+ std::shared_ptr<SparseTensor> sparse_tensor;
+ ARROW_ASSIGN_OR_RAISE(sparse_tensor, ipc::ReadSparseTensor(src));
+ RETURN_NOT_OK(ipc::AlignStream(src, ipc::kTensorAlignment));
+ out->sparse_tensors.push_back(sparse_tensor);
+ }
+
+ for (int i = 0; i < num_ndarrays; ++i) {
+ std::shared_ptr<Tensor> ndarray;
+ ARROW_ASSIGN_OR_RAISE(ndarray, ipc::ReadTensor(src));
+ RETURN_NOT_OK(ipc::AlignStream(src, ipc::kTensorAlignment));
+ out->ndarrays.push_back(ndarray);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(int64_t offset, src->Tell());
+ for (int i = 0; i < num_buffers; ++i) {
+ int64_t size;
+ RETURN_NOT_OK(src->ReadAt(offset, sizeof(int64_t), &size));
+ offset += sizeof(int64_t);
+ ARROW_ASSIGN_OR_RAISE(auto buffer, src->ReadAt(offset, size));
+ out->buffers.push_back(buffer);
+ offset += size;
+ }
+
+ return Status::OK();
+}
+
+Status DeserializeObject(PyObject* context, const SerializedPyObject& obj, PyObject* base,
+ PyObject** out) {
+ PyAcquireGIL lock;
+ return DeserializeList(context, *obj.batch->column(0), 0, obj.batch->num_rows(), base,
+ obj, out);
+}
+
+Status GetSerializedFromComponents(int num_tensors,
+ const SparseTensorCounts& num_sparse_tensors,
+ int num_ndarrays, int num_buffers, PyObject* data,
+ SerializedPyObject* out) {
+ PyAcquireGIL gil;
+ const Py_ssize_t data_length = PyList_Size(data);
+ RETURN_IF_PYERROR();
+
+ const Py_ssize_t expected_data_length = 1 + num_tensors * 2 +
+ num_sparse_tensors.num_total_buffers() +
+ num_ndarrays * 2 + num_buffers;
+ if (data_length != expected_data_length) {
+ return Status::Invalid("Invalid number of buffers in data");
+ }
+
+ auto GetBuffer = [&data](Py_ssize_t index, std::shared_ptr<Buffer>* out) {
+ ARROW_CHECK_LE(index, PyList_Size(data));
+ PyObject* py_buf = PyList_GET_ITEM(data, index);
+ return unwrap_buffer(py_buf).Value(out);
+ };
+
+ Py_ssize_t buffer_index = 0;
+
+ // Read the union batch describing object structure
+ {
+ std::shared_ptr<Buffer> data_buffer;
+ RETURN_NOT_OK(GetBuffer(buffer_index++, &data_buffer));
+ gil.release();
+ io::BufferReader buf_reader(data_buffer);
+ std::shared_ptr<RecordBatchReader> reader;
+ ARROW_ASSIGN_OR_RAISE(reader, ipc::RecordBatchStreamReader::Open(&buf_reader));
+ RETURN_NOT_OK(reader->ReadNext(&out->batch));
+ gil.acquire();
+ }
+
+ // Zero-copy reconstruct tensors
+ for (int i = 0; i < num_tensors; ++i) {
+ std::shared_ptr<Buffer> metadata;
+ std::shared_ptr<Buffer> body;
+ std::shared_ptr<Tensor> tensor;
+ RETURN_NOT_OK(GetBuffer(buffer_index++, &metadata));
+ RETURN_NOT_OK(GetBuffer(buffer_index++, &body));
+
+ ipc::Message message(metadata, body);
+
+ ARROW_ASSIGN_OR_RAISE(tensor, ipc::ReadTensor(message));
+ out->tensors.emplace_back(std::move(tensor));
+ }
+
+ // Zero-copy reconstruct sparse tensors
+ for (int i = 0, n = num_sparse_tensors.num_total_tensors(); i < n; ++i) {
+ ipc::IpcPayload payload;
+ RETURN_NOT_OK(GetBuffer(buffer_index++, &payload.metadata));
+
+ ARROW_ASSIGN_OR_RAISE(
+ size_t num_bodies,
+ ipc::internal::ReadSparseTensorBodyBufferCount(*payload.metadata));
+
+ payload.body_buffers.reserve(num_bodies);
+ for (size_t i = 0; i < num_bodies; ++i) {
+ std::shared_ptr<Buffer> body;
+ RETURN_NOT_OK(GetBuffer(buffer_index++, &body));
+ payload.body_buffers.emplace_back(body);
+ }
+
+ std::shared_ptr<SparseTensor> sparse_tensor;
+ ARROW_ASSIGN_OR_RAISE(sparse_tensor, ipc::internal::ReadSparseTensorPayload(payload));
+ out->sparse_tensors.emplace_back(std::move(sparse_tensor));
+ }
+
+ // Zero-copy reconstruct tensors for numpy ndarrays
+ for (int i = 0; i < num_ndarrays; ++i) {
+ std::shared_ptr<Buffer> metadata;
+ std::shared_ptr<Buffer> body;
+ std::shared_ptr<Tensor> tensor;
+ RETURN_NOT_OK(GetBuffer(buffer_index++, &metadata));
+ RETURN_NOT_OK(GetBuffer(buffer_index++, &body));
+
+ ipc::Message message(metadata, body);
+
+ ARROW_ASSIGN_OR_RAISE(tensor, ipc::ReadTensor(message));
+ out->ndarrays.emplace_back(std::move(tensor));
+ }
+
+ // Unwrap and append buffers
+ for (int i = 0; i < num_buffers; ++i) {
+ std::shared_ptr<Buffer> buffer;
+ RETURN_NOT_OK(GetBuffer(buffer_index++, &buffer));
+ out->buffers.emplace_back(std::move(buffer));
+ }
+
+ return Status::OK();
+}
+
+Status DeserializeNdarray(const SerializedPyObject& object,
+ std::shared_ptr<Tensor>* out) {
+ if (object.ndarrays.size() != 1) {
+ return Status::Invalid("Object is not an Ndarray");
+ }
+ *out = object.ndarrays[0];
+ return Status::OK();
+}
+
+Status NdarrayFromBuffer(std::shared_ptr<Buffer> src, std::shared_ptr<Tensor>* out) {
+ io::BufferReader reader(src);
+ SerializedPyObject object;
+ RETURN_NOT_OK(ReadSerializedObject(&reader, &object));
+ return DeserializeNdarray(object, out);
+}
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/deserialize.h b/src/arrow/cpp/src/arrow/python/deserialize.h
new file mode 100644
index 000000000..41b6a13a3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/deserialize.h
@@ -0,0 +1,106 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "arrow/python/serialize.h"
+#include "arrow/python/visibility.h"
+#include "arrow/status.h"
+
+namespace arrow {
+
+class RecordBatch;
+class Tensor;
+
+namespace io {
+
+class RandomAccessFile;
+
+} // namespace io
+
+namespace py {
+
+struct ARROW_PYTHON_EXPORT SparseTensorCounts {
+ int coo;
+ int csr;
+ int csc;
+ int csf;
+ int ndim_csf;
+
+ int num_total_tensors() const { return coo + csr + csc + csf; }
+ int num_total_buffers() const {
+ return coo * 3 + csr * 4 + csc * 4 + 2 * ndim_csf + csf;
+ }
+};
+
+/// \brief Read serialized Python sequence from file interface using Arrow IPC
+/// \param[in] src a RandomAccessFile
+/// \param[out] out the reconstructed data
+/// \return Status
+ARROW_PYTHON_EXPORT
+Status ReadSerializedObject(io::RandomAccessFile* src, SerializedPyObject* out);
+
+/// \brief Reconstruct SerializedPyObject from representation produced by
+/// SerializedPyObject::GetComponents.
+///
+/// \param[in] num_tensors number of tensors in the object
+/// \param[in] num_sparse_tensors number of sparse tensors in the object
+/// \param[in] num_ndarrays number of numpy Ndarrays in the object
+/// \param[in] num_buffers number of buffers in the object
+/// \param[in] data a list containing pyarrow.Buffer instances. It must be 1 +
+/// num_tensors * 2 + num_coo_tensors * 3 + num_csr_tensors * 4 + num_csc_tensors * 4 +
+/// num_csf_tensors * (2 * ndim_csf + 3) + num_buffers in length
+/// \param[out] out the reconstructed object
+/// \return Status
+ARROW_PYTHON_EXPORT
+Status GetSerializedFromComponents(int num_tensors,
+ const SparseTensorCounts& num_sparse_tensors,
+ int num_ndarrays, int num_buffers, PyObject* data,
+ SerializedPyObject* out);
+
+/// \brief Reconstruct Python object from Arrow-serialized representation
+/// \param[in] context Serialization context which contains custom serialization
+/// and deserialization callbacks. Can be any Python object with a
+/// _serialize_callback method for serialization and a _deserialize_callback
+/// method for deserialization. If context is None, no custom serialization
+/// will be attempted.
+/// \param[in] object Object to deserialize
+/// \param[in] base a Python object holding the underlying data that any NumPy
+/// arrays will reference, to avoid premature deallocation
+/// \param[out] out The returned object
+/// \return Status
+/// This acquires the GIL
+ARROW_PYTHON_EXPORT
+Status DeserializeObject(PyObject* context, const SerializedPyObject& object,
+ PyObject* base, PyObject** out);
+
+/// \brief Reconstruct Ndarray from Arrow-serialized representation
+/// \param[in] object Object to deserialize
+/// \param[out] out The deserialized tensor
+/// \return Status
+ARROW_PYTHON_EXPORT
+Status DeserializeNdarray(const SerializedPyObject& object, std::shared_ptr<Tensor>* out);
+
+ARROW_PYTHON_EXPORT
+Status NdarrayFromBuffer(std::shared_ptr<Buffer> src, std::shared_ptr<Tensor>* out);
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/extension_type.cc b/src/arrow/cpp/src/arrow/python/extension_type.cc
new file mode 100644
index 000000000..3ccc171c8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/extension_type.cc
@@ -0,0 +1,217 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <sstream>
+#include <utility>
+
+#include "arrow/python/extension_type.h"
+#include "arrow/python/helpers.h"
+#include "arrow/python/pyarrow.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace py {
+
+namespace {
+
+// Serialize a Python ExtensionType instance
+Status SerializeExtInstance(PyObject* type_instance, std::string* out) {
+ OwnedRef res(
+ cpp_PyObject_CallMethod(type_instance, "__arrow_ext_serialize__", nullptr));
+ if (!res) {
+ return ConvertPyError();
+ }
+ if (!PyBytes_Check(res.obj())) {
+ return Status::TypeError(
+ "__arrow_ext_serialize__ should return bytes object, "
+ "got ",
+ internal::PyObject_StdStringRepr(res.obj()));
+ }
+ *out = internal::PyBytes_AsStdString(res.obj());
+ return Status::OK();
+}
+
+// Deserialize a Python ExtensionType instance
+PyObject* DeserializeExtInstance(PyObject* type_class,
+ std::shared_ptr<DataType> storage_type,
+ const std::string& serialized_data) {
+ OwnedRef storage_ref(wrap_data_type(storage_type));
+ if (!storage_ref) {
+ return nullptr;
+ }
+ OwnedRef data_ref(PyBytes_FromStringAndSize(
+ serialized_data.data(), static_cast<Py_ssize_t>(serialized_data.size())));
+ if (!data_ref) {
+ return nullptr;
+ }
+
+ return cpp_PyObject_CallMethod(type_class, "__arrow_ext_deserialize__", "OO",
+ storage_ref.obj(), data_ref.obj());
+}
+
+} // namespace
+
+static const char* kExtensionName = "arrow.py_extension_type";
+
+std::string PyExtensionType::ToString() const {
+ PyAcquireGIL lock;
+
+ std::stringstream ss;
+ OwnedRef instance(GetInstance());
+ ss << "extension<" << this->extension_name() << "<" << Py_TYPE(instance.obj())->tp_name
+ << ">>";
+ return ss.str();
+}
+
+PyExtensionType::PyExtensionType(std::shared_ptr<DataType> storage_type, PyObject* typ,
+ PyObject* inst)
+ : ExtensionType(storage_type),
+ extension_name_(kExtensionName),
+ type_class_(typ),
+ type_instance_(inst) {}
+
+PyExtensionType::PyExtensionType(std::shared_ptr<DataType> storage_type,
+ std::string extension_name, PyObject* typ,
+ PyObject* inst)
+ : ExtensionType(storage_type),
+ extension_name_(std::move(extension_name)),
+ type_class_(typ),
+ type_instance_(inst) {}
+
+bool PyExtensionType::ExtensionEquals(const ExtensionType& other) const {
+ PyAcquireGIL lock;
+
+ if (other.extension_name() != extension_name()) {
+ return false;
+ }
+ const auto& other_ext = checked_cast<const PyExtensionType&>(other);
+ int res = -1;
+ if (!type_instance_) {
+ if (other_ext.type_instance_) {
+ return false;
+ }
+ // Compare Python types
+ res = PyObject_RichCompareBool(type_class_.obj(), other_ext.type_class_.obj(), Py_EQ);
+ } else {
+ if (!other_ext.type_instance_) {
+ return false;
+ }
+ // Compare Python instances
+ OwnedRef left(GetInstance());
+ OwnedRef right(other_ext.GetInstance());
+ if (!left || !right) {
+ goto error;
+ }
+ res = PyObject_RichCompareBool(left.obj(), right.obj(), Py_EQ);
+ }
+ if (res == -1) {
+ goto error;
+ }
+ return res == 1;
+
+error:
+ // Cannot propagate error
+ PyErr_WriteUnraisable(nullptr);
+ return false;
+}
+
+std::shared_ptr<Array> PyExtensionType::MakeArray(std::shared_ptr<ArrayData> data) const {
+ DCHECK_EQ(data->type->id(), Type::EXTENSION);
+ return std::make_shared<ExtensionArray>(data);
+}
+
+std::string PyExtensionType::Serialize() const {
+ DCHECK(type_instance_);
+ return serialized_;
+}
+
+Result<std::shared_ptr<DataType>> PyExtensionType::Deserialize(
+ std::shared_ptr<DataType> storage_type, const std::string& serialized_data) const {
+ PyAcquireGIL lock;
+
+ if (import_pyarrow()) {
+ return ConvertPyError();
+ }
+ OwnedRef res(DeserializeExtInstance(type_class_.obj(), storage_type, serialized_data));
+ if (!res) {
+ return ConvertPyError();
+ }
+ return unwrap_data_type(res.obj());
+}
+
+PyObject* PyExtensionType::GetInstance() const {
+ if (!type_instance_) {
+ PyErr_SetString(PyExc_TypeError, "Not an instance");
+ return nullptr;
+ }
+ DCHECK(PyWeakref_CheckRef(type_instance_.obj()));
+ PyObject* inst = PyWeakref_GET_OBJECT(type_instance_.obj());
+ if (inst != Py_None) {
+ // Cached instance still alive
+ Py_INCREF(inst);
+ return inst;
+ } else {
+ // Must reconstruct from serialized form
+ // XXX cache again?
+ return DeserializeExtInstance(type_class_.obj(), storage_type_, serialized_);
+ }
+}
+
+Status PyExtensionType::SetInstance(PyObject* inst) const {
+ // Check we have the right type
+ PyObject* typ = reinterpret_cast<PyObject*>(Py_TYPE(inst));
+ if (typ != type_class_.obj()) {
+ return Status::TypeError("Unexpected Python ExtensionType class ",
+ internal::PyObject_StdStringRepr(typ), " expected ",
+ internal::PyObject_StdStringRepr(type_class_.obj()));
+ }
+
+ PyObject* wr = PyWeakref_NewRef(inst, nullptr);
+ if (wr == NULL) {
+ return ConvertPyError();
+ }
+ type_instance_.reset(wr);
+ return SerializeExtInstance(inst, &serialized_);
+}
+
+Status PyExtensionType::FromClass(const std::shared_ptr<DataType> storage_type,
+ const std::string extension_name, PyObject* typ,
+ std::shared_ptr<ExtensionType>* out) {
+ Py_INCREF(typ);
+ out->reset(new PyExtensionType(storage_type, std::move(extension_name), typ));
+ return Status::OK();
+}
+
+Status RegisterPyExtensionType(const std::shared_ptr<DataType>& type) {
+ DCHECK_EQ(type->id(), Type::EXTENSION);
+ auto ext_type = std::dynamic_pointer_cast<ExtensionType>(type);
+ return RegisterExtensionType(ext_type);
+}
+
+Status UnregisterPyExtensionType(const std::string& type_name) {
+ return UnregisterExtensionType(type_name);
+}
+
+std::string PyExtensionName() { return kExtensionName; }
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/extension_type.h b/src/arrow/cpp/src/arrow/python/extension_type.h
new file mode 100644
index 000000000..e433d9aca
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/extension_type.h
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/extension_type.h"
+#include "arrow/python/common.h"
+#include "arrow/python/visibility.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace py {
+
+class ARROW_PYTHON_EXPORT PyExtensionType : public ExtensionType {
+ public:
+ // Implement extensionType API
+ std::string extension_name() const override { return extension_name_; }
+
+ std::string ToString() const override;
+
+ bool ExtensionEquals(const ExtensionType& other) const override;
+
+ std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;
+
+ Result<std::shared_ptr<DataType>> Deserialize(
+ std::shared_ptr<DataType> storage_type,
+ const std::string& serialized) const override;
+
+ std::string Serialize() const override;
+
+ // For use from Cython
+ // Assumes that `typ` is borrowed
+ static Status FromClass(const std::shared_ptr<DataType> storage_type,
+ const std::string extension_name, PyObject* typ,
+ std::shared_ptr<ExtensionType>* out);
+
+ // Return new ref
+ PyObject* GetInstance() const;
+ Status SetInstance(PyObject*) const;
+
+ protected:
+ PyExtensionType(std::shared_ptr<DataType> storage_type, PyObject* typ,
+ PyObject* inst = NULLPTR);
+ PyExtensionType(std::shared_ptr<DataType> storage_type, std::string extension_name,
+ PyObject* typ, PyObject* inst = NULLPTR);
+
+ std::string extension_name_;
+
+ // These fields are mutable because of two-step initialization.
+ mutable OwnedRefNoGIL type_class_;
+ // A weakref or null. Storing a strong reference to the Python extension type
+ // instance would create an unreclaimable reference cycle between Python and C++
+ // (the Python instance has to keep a strong reference to the C++ ExtensionType
+ // in other direction). Instead, we store a weakref to the instance.
+ // If the weakref is dead, we reconstruct the instance from its serialized form.
+ mutable OwnedRefNoGIL type_instance_;
+ // Empty if type_instance_ is null
+ mutable std::string serialized_;
+};
+
+ARROW_PYTHON_EXPORT std::string PyExtensionName();
+
+ARROW_PYTHON_EXPORT Status RegisterPyExtensionType(const std::shared_ptr<DataType>&);
+
+ARROW_PYTHON_EXPORT Status UnregisterPyExtensionType(const std::string& type_name);
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/filesystem.cc b/src/arrow/cpp/src/arrow/python/filesystem.cc
new file mode 100644
index 000000000..8c12f05a0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/filesystem.cc
@@ -0,0 +1,206 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/python/filesystem.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using fs::FileInfo;
+using fs::FileSelector;
+
+namespace py {
+namespace fs {
+
+PyFileSystem::PyFileSystem(PyObject* handler, PyFileSystemVtable vtable)
+ : handler_(handler), vtable_(std::move(vtable)) {
+ Py_INCREF(handler);
+}
+
+PyFileSystem::~PyFileSystem() {}
+
+std::shared_ptr<PyFileSystem> PyFileSystem::Make(PyObject* handler,
+ PyFileSystemVtable vtable) {
+ return std::make_shared<PyFileSystem>(handler, std::move(vtable));
+}
+
+std::string PyFileSystem::type_name() const {
+ std::string result;
+ auto st = SafeCallIntoPython([&]() -> Status {
+ vtable_.get_type_name(handler_.obj(), &result);
+ if (PyErr_Occurred()) {
+ PyErr_WriteUnraisable(handler_.obj());
+ }
+ return Status::OK();
+ });
+ ARROW_UNUSED(st);
+ return result;
+}
+
+bool PyFileSystem::Equals(const FileSystem& other) const {
+ bool result;
+ auto st = SafeCallIntoPython([&]() -> Status {
+ result = vtable_.equals(handler_.obj(), other);
+ if (PyErr_Occurred()) {
+ PyErr_WriteUnraisable(handler_.obj());
+ }
+ return Status::OK();
+ });
+ ARROW_UNUSED(st);
+ return result;
+}
+
+Result<FileInfo> PyFileSystem::GetFileInfo(const std::string& path) {
+ FileInfo info;
+
+ auto st = SafeCallIntoPython([&]() -> Status {
+ vtable_.get_file_info(handler_.obj(), path, &info);
+ return CheckPyError();
+ });
+ RETURN_NOT_OK(st);
+ return info;
+}
+
+Result<std::vector<FileInfo>> PyFileSystem::GetFileInfo(
+ const std::vector<std::string>& paths) {
+ std::vector<FileInfo> infos;
+
+ auto st = SafeCallIntoPython([&]() -> Status {
+ vtable_.get_file_info_vector(handler_.obj(), paths, &infos);
+ return CheckPyError();
+ });
+ RETURN_NOT_OK(st);
+ return infos;
+}
+
+Result<std::vector<FileInfo>> PyFileSystem::GetFileInfo(const FileSelector& select) {
+ std::vector<FileInfo> infos;
+
+ auto st = SafeCallIntoPython([&]() -> Status {
+ vtable_.get_file_info_selector(handler_.obj(), select, &infos);
+ return CheckPyError();
+ });
+ RETURN_NOT_OK(st);
+ return infos;
+}
+
+Status PyFileSystem::CreateDir(const std::string& path, bool recursive) {
+ return SafeCallIntoPython([&]() -> Status {
+ vtable_.create_dir(handler_.obj(), path, recursive);
+ return CheckPyError();
+ });
+}
+
+Status PyFileSystem::DeleteDir(const std::string& path) {
+ return SafeCallIntoPython([&]() -> Status {
+ vtable_.delete_dir(handler_.obj(), path);
+ return CheckPyError();
+ });
+}
+
+Status PyFileSystem::DeleteDirContents(const std::string& path) {
+ return SafeCallIntoPython([&]() -> Status {
+ vtable_.delete_dir_contents(handler_.obj(), path);
+ return CheckPyError();
+ });
+}
+
+Status PyFileSystem::DeleteRootDirContents() {
+ return SafeCallIntoPython([&]() -> Status {
+ vtable_.delete_root_dir_contents(handler_.obj());
+ return CheckPyError();
+ });
+}
+
+Status PyFileSystem::DeleteFile(const std::string& path) {
+ return SafeCallIntoPython([&]() -> Status {
+ vtable_.delete_file(handler_.obj(), path);
+ return CheckPyError();
+ });
+}
+
+Status PyFileSystem::Move(const std::string& src, const std::string& dest) {
+ return SafeCallIntoPython([&]() -> Status {
+ vtable_.move(handler_.obj(), src, dest);
+ return CheckPyError();
+ });
+}
+
+Status PyFileSystem::CopyFile(const std::string& src, const std::string& dest) {
+ return SafeCallIntoPython([&]() -> Status {
+ vtable_.copy_file(handler_.obj(), src, dest);
+ return CheckPyError();
+ });
+}
+
+Result<std::shared_ptr<io::InputStream>> PyFileSystem::OpenInputStream(
+ const std::string& path) {
+ std::shared_ptr<io::InputStream> stream;
+ auto st = SafeCallIntoPython([&]() -> Status {
+ vtable_.open_input_stream(handler_.obj(), path, &stream);
+ return CheckPyError();
+ });
+ RETURN_NOT_OK(st);
+ return stream;
+}
+
+Result<std::shared_ptr<io::RandomAccessFile>> PyFileSystem::OpenInputFile(
+ const std::string& path) {
+ std::shared_ptr<io::RandomAccessFile> stream;
+ auto st = SafeCallIntoPython([&]() -> Status {
+ vtable_.open_input_file(handler_.obj(), path, &stream);
+ return CheckPyError();
+ });
+ RETURN_NOT_OK(st);
+ return stream;
+}
+
+Result<std::shared_ptr<io::OutputStream>> PyFileSystem::OpenOutputStream(
+ const std::string& path, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ std::shared_ptr<io::OutputStream> stream;
+ auto st = SafeCallIntoPython([&]() -> Status {
+ vtable_.open_output_stream(handler_.obj(), path, metadata, &stream);
+ return CheckPyError();
+ });
+ RETURN_NOT_OK(st);
+ return stream;
+}
+
+Result<std::shared_ptr<io::OutputStream>> PyFileSystem::OpenAppendStream(
+ const std::string& path, const std::shared_ptr<const KeyValueMetadata>& metadata) {
+ std::shared_ptr<io::OutputStream> stream;
+ auto st = SafeCallIntoPython([&]() -> Status {
+ vtable_.open_append_stream(handler_.obj(), path, metadata, &stream);
+ return CheckPyError();
+ });
+ RETURN_NOT_OK(st);
+ return stream;
+}
+
+Result<std::string> PyFileSystem::NormalizePath(std::string path) {
+ std::string normalized;
+ auto st = SafeCallIntoPython([&]() -> Status {
+ vtable_.normalize_path(handler_.obj(), path, &normalized);
+ return CheckPyError();
+ });
+ RETURN_NOT_OK(st);
+ return normalized;
+}
+
+} // namespace fs
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/filesystem.h b/src/arrow/cpp/src/arrow/python/filesystem.h
new file mode 100644
index 000000000..e1235f8de
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/filesystem.h
@@ -0,0 +1,126 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/filesystem/filesystem.h"
+#include "arrow/python/common.h"
+#include "arrow/python/visibility.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace py {
+namespace fs {
+
+class ARROW_PYTHON_EXPORT PyFileSystemVtable {
+ public:
+ std::function<void(PyObject*, std::string* out)> get_type_name;
+ std::function<bool(PyObject*, const arrow::fs::FileSystem& other)> equals;
+
+ std::function<void(PyObject*, const std::string& path, arrow::fs::FileInfo* out)>
+ get_file_info;
+ std::function<void(PyObject*, const std::vector<std::string>& paths,
+ std::vector<arrow::fs::FileInfo>* out)>
+ get_file_info_vector;
+ std::function<void(PyObject*, const arrow::fs::FileSelector&,
+ std::vector<arrow::fs::FileInfo>* out)>
+ get_file_info_selector;
+
+ std::function<void(PyObject*, const std::string& path, bool)> create_dir;
+ std::function<void(PyObject*, const std::string& path)> delete_dir;
+ std::function<void(PyObject*, const std::string& path)> delete_dir_contents;
+ std::function<void(PyObject*)> delete_root_dir_contents;
+ std::function<void(PyObject*, const std::string& path)> delete_file;
+ std::function<void(PyObject*, const std::string& src, const std::string& dest)> move;
+ std::function<void(PyObject*, const std::string& src, const std::string& dest)>
+ copy_file;
+
+ std::function<void(PyObject*, const std::string& path,
+ std::shared_ptr<io::InputStream>* out)>
+ open_input_stream;
+ std::function<void(PyObject*, const std::string& path,
+ std::shared_ptr<io::RandomAccessFile>* out)>
+ open_input_file;
+ std::function<void(PyObject*, const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>&,
+ std::shared_ptr<io::OutputStream>* out)>
+ open_output_stream;
+ std::function<void(PyObject*, const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>&,
+ std::shared_ptr<io::OutputStream>* out)>
+ open_append_stream;
+
+ std::function<void(PyObject*, const std::string& path, std::string* out)>
+ normalize_path;
+};
+
+class ARROW_PYTHON_EXPORT PyFileSystem : public arrow::fs::FileSystem {
+ public:
+ PyFileSystem(PyObject* handler, PyFileSystemVtable vtable);
+ ~PyFileSystem() override;
+
+ static std::shared_ptr<PyFileSystem> Make(PyObject* handler, PyFileSystemVtable vtable);
+
+ std::string type_name() const override;
+
+ bool Equals(const FileSystem& other) const override;
+
+ Result<arrow::fs::FileInfo> GetFileInfo(const std::string& path) override;
+ Result<std::vector<arrow::fs::FileInfo>> GetFileInfo(
+ const std::vector<std::string>& paths) override;
+ Result<std::vector<arrow::fs::FileInfo>> GetFileInfo(
+ const arrow::fs::FileSelector& select) override;
+
+ Status CreateDir(const std::string& path, bool recursive = true) override;
+
+ Status DeleteDir(const std::string& path) override;
+ Status DeleteDirContents(const std::string& path) override;
+ Status DeleteRootDirContents() override;
+
+ Status DeleteFile(const std::string& path) override;
+
+ Status Move(const std::string& src, const std::string& dest) override;
+
+ Status CopyFile(const std::string& src, const std::string& dest) override;
+
+ Result<std::shared_ptr<io::InputStream>> OpenInputStream(
+ const std::string& path) override;
+ Result<std::shared_ptr<io::RandomAccessFile>> OpenInputFile(
+ const std::string& path) override;
+ Result<std::shared_ptr<io::OutputStream>> OpenOutputStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = {}) override;
+ Result<std::shared_ptr<io::OutputStream>> OpenAppendStream(
+ const std::string& path,
+ const std::shared_ptr<const KeyValueMetadata>& metadata = {}) override;
+
+ Result<std::string> NormalizePath(std::string path) override;
+
+ PyObject* handler() const { return handler_.obj(); }
+
+ private:
+ OwnedRefNoGIL handler_;
+ PyFileSystemVtable vtable_;
+};
+
+} // namespace fs
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/flight.cc b/src/arrow/cpp/src/arrow/python/flight.cc
new file mode 100644
index 000000000..ee1491e0d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/flight.cc
@@ -0,0 +1,408 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <signal.h>
+#include <utility>
+
+#include "arrow/flight/internal.h"
+#include "arrow/python/flight.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+
+using arrow::flight::FlightPayload;
+
+namespace arrow {
+namespace py {
+namespace flight {
+
+const char* kPyServerMiddlewareName = "arrow.py_server_middleware";
+
+PyServerAuthHandler::PyServerAuthHandler(PyObject* handler,
+ const PyServerAuthHandlerVtable& vtable)
+ : vtable_(vtable) {
+ Py_INCREF(handler);
+ handler_.reset(handler);
+}
+
+Status PyServerAuthHandler::Authenticate(arrow::flight::ServerAuthSender* outgoing,
+ arrow::flight::ServerAuthReader* incoming) {
+ return SafeCallIntoPython([=] {
+ const Status status = vtable_.authenticate(handler_.obj(), outgoing, incoming);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+}
+
+Status PyServerAuthHandler::IsValid(const std::string& token,
+ std::string* peer_identity) {
+ return SafeCallIntoPython([=] {
+ const Status status = vtable_.is_valid(handler_.obj(), token, peer_identity);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+}
+
+PyClientAuthHandler::PyClientAuthHandler(PyObject* handler,
+ const PyClientAuthHandlerVtable& vtable)
+ : vtable_(vtable) {
+ Py_INCREF(handler);
+ handler_.reset(handler);
+}
+
+Status PyClientAuthHandler::Authenticate(arrow::flight::ClientAuthSender* outgoing,
+ arrow::flight::ClientAuthReader* incoming) {
+ return SafeCallIntoPython([=] {
+ const Status status = vtable_.authenticate(handler_.obj(), outgoing, incoming);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+}
+
+Status PyClientAuthHandler::GetToken(std::string* token) {
+ return SafeCallIntoPython([=] {
+ const Status status = vtable_.get_token(handler_.obj(), token);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+}
+
+PyFlightServer::PyFlightServer(PyObject* server, const PyFlightServerVtable& vtable)
+ : vtable_(vtable) {
+ Py_INCREF(server);
+ server_.reset(server);
+}
+
+Status PyFlightServer::ListFlights(
+ const arrow::flight::ServerCallContext& context,
+ const arrow::flight::Criteria* criteria,
+ std::unique_ptr<arrow::flight::FlightListing>* listings) {
+ return SafeCallIntoPython([&] {
+ const Status status =
+ vtable_.list_flights(server_.obj(), context, criteria, listings);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+}
+
+Status PyFlightServer::GetFlightInfo(const arrow::flight::ServerCallContext& context,
+ const arrow::flight::FlightDescriptor& request,
+ std::unique_ptr<arrow::flight::FlightInfo>* info) {
+ return SafeCallIntoPython([&] {
+ const Status status = vtable_.get_flight_info(server_.obj(), context, request, info);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+}
+
+Status PyFlightServer::GetSchema(const arrow::flight::ServerCallContext& context,
+ const arrow::flight::FlightDescriptor& request,
+ std::unique_ptr<arrow::flight::SchemaResult>* result) {
+ return SafeCallIntoPython([&] {
+ const Status status = vtable_.get_schema(server_.obj(), context, request, result);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+}
+
+Status PyFlightServer::DoGet(const arrow::flight::ServerCallContext& context,
+ const arrow::flight::Ticket& request,
+ std::unique_ptr<arrow::flight::FlightDataStream>* stream) {
+ return SafeCallIntoPython([&] {
+ const Status status = vtable_.do_get(server_.obj(), context, request, stream);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+}
+
+Status PyFlightServer::DoPut(
+ const arrow::flight::ServerCallContext& context,
+ std::unique_ptr<arrow::flight::FlightMessageReader> reader,
+ std::unique_ptr<arrow::flight::FlightMetadataWriter> writer) {
+ return SafeCallIntoPython([&] {
+ const Status status =
+ vtable_.do_put(server_.obj(), context, std::move(reader), std::move(writer));
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+}
+
+Status PyFlightServer::DoExchange(
+ const arrow::flight::ServerCallContext& context,
+ std::unique_ptr<arrow::flight::FlightMessageReader> reader,
+ std::unique_ptr<arrow::flight::FlightMessageWriter> writer) {
+ return SafeCallIntoPython([&] {
+ const Status status =
+ vtable_.do_exchange(server_.obj(), context, std::move(reader), std::move(writer));
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+}
+
+Status PyFlightServer::DoAction(const arrow::flight::ServerCallContext& context,
+ const arrow::flight::Action& action,
+ std::unique_ptr<arrow::flight::ResultStream>* result) {
+ return SafeCallIntoPython([&] {
+ const Status status = vtable_.do_action(server_.obj(), context, action, result);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+}
+
+Status PyFlightServer::ListActions(const arrow::flight::ServerCallContext& context,
+ std::vector<arrow::flight::ActionType>* actions) {
+ return SafeCallIntoPython([&] {
+ const Status status = vtable_.list_actions(server_.obj(), context, actions);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+}
+
+Status PyFlightServer::ServeWithSignals() {
+ // Respect the current Python settings, i.e. only interrupt the server if there is
+ // an active signal handler for SIGINT and SIGTERM.
+ std::vector<int> signals;
+ for (const int signum : {SIGINT, SIGTERM}) {
+ ARROW_ASSIGN_OR_RAISE(auto handler, ::arrow::internal::GetSignalHandler(signum));
+ auto cb = handler.callback();
+ if (cb != SIG_DFL && cb != SIG_IGN) {
+ signals.push_back(signum);
+ }
+ }
+ RETURN_NOT_OK(SetShutdownOnSignals(signals));
+
+ // Serve until we got told to shutdown or a signal interrupted us
+ RETURN_NOT_OK(Serve());
+ int signum = GotSignal();
+ if (signum != 0) {
+ // Issue the signal again with Python's signal handlers restored
+ PyAcquireGIL lock;
+ raise(signum);
+ // XXX Ideally we would loop and serve again if no exception was raised.
+ // Unfortunately, gRPC will return immediately if Serve() is called again.
+ ARROW_UNUSED(PyErr_CheckSignals());
+ }
+
+ return Status::OK();
+}
+
+PyFlightResultStream::PyFlightResultStream(PyObject* generator,
+ PyFlightResultStreamCallback callback)
+ : callback_(callback) {
+ Py_INCREF(generator);
+ generator_.reset(generator);
+}
+
+Status PyFlightResultStream::Next(std::unique_ptr<arrow::flight::Result>* result) {
+ return SafeCallIntoPython([=] {
+ const Status status = callback_(generator_.obj(), result);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+}
+
+PyFlightDataStream::PyFlightDataStream(
+ PyObject* data_source, std::unique_ptr<arrow::flight::FlightDataStream> stream)
+ : stream_(std::move(stream)) {
+ Py_INCREF(data_source);
+ data_source_.reset(data_source);
+}
+
+std::shared_ptr<Schema> PyFlightDataStream::schema() { return stream_->schema(); }
+
+Status PyFlightDataStream::GetSchemaPayload(FlightPayload* payload) {
+ return stream_->GetSchemaPayload(payload);
+}
+
+Status PyFlightDataStream::Next(FlightPayload* payload) { return stream_->Next(payload); }
+
+PyGeneratorFlightDataStream::PyGeneratorFlightDataStream(
+ PyObject* generator, std::shared_ptr<arrow::Schema> schema,
+ PyGeneratorFlightDataStreamCallback callback, const ipc::IpcWriteOptions& options)
+ : schema_(schema), mapper_(*schema_), options_(options), callback_(callback) {
+ Py_INCREF(generator);
+ generator_.reset(generator);
+}
+
+std::shared_ptr<Schema> PyGeneratorFlightDataStream::schema() { return schema_; }
+
+Status PyGeneratorFlightDataStream::GetSchemaPayload(FlightPayload* payload) {
+ return ipc::GetSchemaPayload(*schema_, options_, mapper_, &payload->ipc_message);
+}
+
+Status PyGeneratorFlightDataStream::Next(FlightPayload* payload) {
+ return SafeCallIntoPython([=] {
+ const Status status = callback_(generator_.obj(), payload);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+}
+
+// Flight Server Middleware
+
+PyServerMiddlewareFactory::PyServerMiddlewareFactory(PyObject* factory,
+ StartCallCallback start_call)
+ : start_call_(start_call) {
+ Py_INCREF(factory);
+ factory_.reset(factory);
+}
+
+Status PyServerMiddlewareFactory::StartCall(
+ const arrow::flight::CallInfo& info,
+ const arrow::flight::CallHeaders& incoming_headers,
+ std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) {
+ return SafeCallIntoPython([&] {
+ const Status status = start_call_(factory_.obj(), info, incoming_headers, middleware);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+}
+
+PyServerMiddleware::PyServerMiddleware(PyObject* middleware, Vtable vtable)
+ : vtable_(vtable) {
+ Py_INCREF(middleware);
+ middleware_.reset(middleware);
+}
+
+void PyServerMiddleware::SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) {
+ const Status& status = SafeCallIntoPython([&] {
+ const Status status = vtable_.sending_headers(middleware_.obj(), outgoing_headers);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+
+ if (!status.ok()) {
+ ARROW_LOG(WARNING) << "Python server middleware failed in SendingHeaders: " << status;
+ }
+}
+
+void PyServerMiddleware::CallCompleted(const Status& call_status) {
+ const Status& status = SafeCallIntoPython([&] {
+ const Status status = vtable_.call_completed(middleware_.obj(), call_status);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+ if (!status.ok()) {
+ ARROW_LOG(WARNING) << "Python server middleware failed in CallCompleted: " << status;
+ }
+}
+
+std::string PyServerMiddleware::name() const { return kPyServerMiddlewareName; }
+
+PyObject* PyServerMiddleware::py_object() const { return middleware_.obj(); }
+
+// Flight Client Middleware
+
+PyClientMiddlewareFactory::PyClientMiddlewareFactory(PyObject* factory,
+ StartCallCallback start_call)
+ : start_call_(start_call) {
+ Py_INCREF(factory);
+ factory_.reset(factory);
+}
+
+void PyClientMiddlewareFactory::StartCall(
+ const arrow::flight::CallInfo& info,
+ std::unique_ptr<arrow::flight::ClientMiddleware>* middleware) {
+ const Status& status = SafeCallIntoPython([&] {
+ const Status status = start_call_(factory_.obj(), info, middleware);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+ if (!status.ok()) {
+ ARROW_LOG(WARNING) << "Python client middleware failed in StartCall: " << status;
+ }
+}
+
+PyClientMiddleware::PyClientMiddleware(PyObject* middleware, Vtable vtable)
+ : vtable_(vtable) {
+ Py_INCREF(middleware);
+ middleware_.reset(middleware);
+}
+
+void PyClientMiddleware::SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) {
+ const Status& status = SafeCallIntoPython([&] {
+ const Status status = vtable_.sending_headers(middleware_.obj(), outgoing_headers);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+ if (!status.ok()) {
+ ARROW_LOG(WARNING) << "Python client middleware failed in StartCall: " << status;
+ }
+}
+
+void PyClientMiddleware::ReceivedHeaders(
+ const arrow::flight::CallHeaders& incoming_headers) {
+ const Status& status = SafeCallIntoPython([&] {
+ const Status status = vtable_.received_headers(middleware_.obj(), incoming_headers);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+ if (!status.ok()) {
+ ARROW_LOG(WARNING) << "Python client middleware failed in StartCall: " << status;
+ }
+}
+
+void PyClientMiddleware::CallCompleted(const Status& call_status) {
+ const Status& status = SafeCallIntoPython([&] {
+ const Status status = vtable_.call_completed(middleware_.obj(), call_status);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
+ });
+ if (!status.ok()) {
+ ARROW_LOG(WARNING) << "Python client middleware failed in StartCall: " << status;
+ }
+}
+
+Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema,
+ const arrow::flight::FlightDescriptor& descriptor,
+ const std::vector<arrow::flight::FlightEndpoint>& endpoints,
+ int64_t total_records, int64_t total_bytes,
+ std::unique_ptr<arrow::flight::FlightInfo>* out) {
+ arrow::flight::FlightInfo::Data flight_data;
+ RETURN_NOT_OK(arrow::flight::internal::SchemaToString(*schema, &flight_data.schema));
+ flight_data.descriptor = descriptor;
+ flight_data.endpoints = endpoints;
+ flight_data.total_records = total_records;
+ flight_data.total_bytes = total_bytes;
+ arrow::flight::FlightInfo value(flight_data);
+ *out = std::unique_ptr<arrow::flight::FlightInfo>(new arrow::flight::FlightInfo(value));
+ return Status::OK();
+}
+
+Status CreateSchemaResult(const std::shared_ptr<arrow::Schema>& schema,
+ std::unique_ptr<arrow::flight::SchemaResult>* out) {
+ std::string schema_in;
+ RETURN_NOT_OK(arrow::flight::internal::SchemaToString(*schema, &schema_in));
+ arrow::flight::SchemaResult value(schema_in);
+ *out = std::unique_ptr<arrow::flight::SchemaResult>(
+ new arrow::flight::SchemaResult(value));
+ return Status::OK();
+}
+
+Status DeserializeBasicAuth(const std::string& buf,
+ std::unique_ptr<arrow::flight::BasicAuth>* out) {
+ auto basic_auth = new arrow::flight::BasicAuth();
+ *out = std::unique_ptr<arrow::flight::BasicAuth>(basic_auth);
+ return arrow::flight::BasicAuth::Deserialize(buf, basic_auth);
+}
+
+Status SerializeBasicAuth(const arrow::flight::BasicAuth& basic_auth, std::string* out) {
+ return arrow::flight::BasicAuth::Serialize(basic_auth, out);
+}
+
+} // namespace flight
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/flight.h b/src/arrow/cpp/src/arrow/python/flight.h
new file mode 100644
index 000000000..45a090ef4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/flight.h
@@ -0,0 +1,357 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/flight/api.h"
+#include "arrow/ipc/dictionary.h"
+#include "arrow/python/common.h"
+
+#if defined(_WIN32) || defined(__CYGWIN__) // Windows
+#if defined(_MSC_VER)
+#pragma warning(disable : 4251)
+#else
+#pragma GCC diagnostic ignored "-Wattributes"
+#endif
+
+#ifdef ARROW_STATIC
+#define ARROW_PYFLIGHT_EXPORT
+#elif defined(ARROW_PYFLIGHT_EXPORTING)
+#define ARROW_PYFLIGHT_EXPORT __declspec(dllexport)
+#else
+#define ARROW_PYFLIGHT_EXPORT __declspec(dllimport)
+#endif
+
+#else // Not Windows
+#ifndef ARROW_PYFLIGHT_EXPORT
+#define ARROW_PYFLIGHT_EXPORT __attribute__((visibility("default")))
+#endif
+#endif // Non-Windows
+
+namespace arrow {
+
+namespace py {
+
+namespace flight {
+
+ARROW_PYFLIGHT_EXPORT
+extern const char* kPyServerMiddlewareName;
+
+/// \brief A table of function pointers for calling from C++ into
+/// Python.
+class ARROW_PYFLIGHT_EXPORT PyFlightServerVtable {
+ public:
+ std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
+ const arrow::flight::Criteria*,
+ std::unique_ptr<arrow::flight::FlightListing>*)>
+ list_flights;
+ std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
+ const arrow::flight::FlightDescriptor&,
+ std::unique_ptr<arrow::flight::FlightInfo>*)>
+ get_flight_info;
+ std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
+ const arrow::flight::FlightDescriptor&,
+ std::unique_ptr<arrow::flight::SchemaResult>*)>
+ get_schema;
+ std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
+ const arrow::flight::Ticket&,
+ std::unique_ptr<arrow::flight::FlightDataStream>*)>
+ do_get;
+ std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
+ std::unique_ptr<arrow::flight::FlightMessageReader>,
+ std::unique_ptr<arrow::flight::FlightMetadataWriter>)>
+ do_put;
+ std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
+ std::unique_ptr<arrow::flight::FlightMessageReader>,
+ std::unique_ptr<arrow::flight::FlightMessageWriter>)>
+ do_exchange;
+ std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
+ const arrow::flight::Action&,
+ std::unique_ptr<arrow::flight::ResultStream>*)>
+ do_action;
+ std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
+ std::vector<arrow::flight::ActionType>*)>
+ list_actions;
+};
+
+class ARROW_PYFLIGHT_EXPORT PyServerAuthHandlerVtable {
+ public:
+ std::function<Status(PyObject*, arrow::flight::ServerAuthSender*,
+ arrow::flight::ServerAuthReader*)>
+ authenticate;
+ std::function<Status(PyObject*, const std::string&, std::string*)> is_valid;
+};
+
+class ARROW_PYFLIGHT_EXPORT PyClientAuthHandlerVtable {
+ public:
+ std::function<Status(PyObject*, arrow::flight::ClientAuthSender*,
+ arrow::flight::ClientAuthReader*)>
+ authenticate;
+ std::function<Status(PyObject*, std::string*)> get_token;
+};
+
+/// \brief A helper to implement an auth mechanism in Python.
+class ARROW_PYFLIGHT_EXPORT PyServerAuthHandler
+ : public arrow::flight::ServerAuthHandler {
+ public:
+ explicit PyServerAuthHandler(PyObject* handler,
+ const PyServerAuthHandlerVtable& vtable);
+ Status Authenticate(arrow::flight::ServerAuthSender* outgoing,
+ arrow::flight::ServerAuthReader* incoming) override;
+ Status IsValid(const std::string& token, std::string* peer_identity) override;
+
+ private:
+ OwnedRefNoGIL handler_;
+ PyServerAuthHandlerVtable vtable_;
+};
+
+/// \brief A helper to implement an auth mechanism in Python.
+class ARROW_PYFLIGHT_EXPORT PyClientAuthHandler
+ : public arrow::flight::ClientAuthHandler {
+ public:
+ explicit PyClientAuthHandler(PyObject* handler,
+ const PyClientAuthHandlerVtable& vtable);
+ Status Authenticate(arrow::flight::ClientAuthSender* outgoing,
+ arrow::flight::ClientAuthReader* incoming) override;
+ Status GetToken(std::string* token) override;
+
+ private:
+ OwnedRefNoGIL handler_;
+ PyClientAuthHandlerVtable vtable_;
+};
+
+class ARROW_PYFLIGHT_EXPORT PyFlightServer : public arrow::flight::FlightServerBase {
+ public:
+ explicit PyFlightServer(PyObject* server, const PyFlightServerVtable& vtable);
+
+ // Like Serve(), but set up signals and invoke Python signal handlers
+ // if necessary. This function may return with a Python exception set.
+ Status ServeWithSignals();
+
+ Status ListFlights(const arrow::flight::ServerCallContext& context,
+ const arrow::flight::Criteria* criteria,
+ std::unique_ptr<arrow::flight::FlightListing>* listings) override;
+ Status GetFlightInfo(const arrow::flight::ServerCallContext& context,
+ const arrow::flight::FlightDescriptor& request,
+ std::unique_ptr<arrow::flight::FlightInfo>* info) override;
+ Status GetSchema(const arrow::flight::ServerCallContext& context,
+ const arrow::flight::FlightDescriptor& request,
+ std::unique_ptr<arrow::flight::SchemaResult>* result) override;
+ Status DoGet(const arrow::flight::ServerCallContext& context,
+ const arrow::flight::Ticket& request,
+ std::unique_ptr<arrow::flight::FlightDataStream>* stream) override;
+ Status DoPut(const arrow::flight::ServerCallContext& context,
+ std::unique_ptr<arrow::flight::FlightMessageReader> reader,
+ std::unique_ptr<arrow::flight::FlightMetadataWriter> writer) override;
+ Status DoExchange(const arrow::flight::ServerCallContext& context,
+ std::unique_ptr<arrow::flight::FlightMessageReader> reader,
+ std::unique_ptr<arrow::flight::FlightMessageWriter> writer) override;
+ Status DoAction(const arrow::flight::ServerCallContext& context,
+ const arrow::flight::Action& action,
+ std::unique_ptr<arrow::flight::ResultStream>* result) override;
+ Status ListActions(const arrow::flight::ServerCallContext& context,
+ std::vector<arrow::flight::ActionType>* actions) override;
+
+ private:
+ OwnedRefNoGIL server_;
+ PyFlightServerVtable vtable_;
+};
+
+/// \brief A callback that obtains the next result from a Flight action.
+typedef std::function<Status(PyObject*, std::unique_ptr<arrow::flight::Result>*)>
+ PyFlightResultStreamCallback;
+
+/// \brief A ResultStream built around a Python callback.
+class ARROW_PYFLIGHT_EXPORT PyFlightResultStream : public arrow::flight::ResultStream {
+ public:
+ /// \brief Construct a FlightResultStream from a Python object and callback.
+ /// Must only be called while holding the GIL.
+ explicit PyFlightResultStream(PyObject* generator,
+ PyFlightResultStreamCallback callback);
+ Status Next(std::unique_ptr<arrow::flight::Result>* result) override;
+
+ private:
+ OwnedRefNoGIL generator_;
+ PyFlightResultStreamCallback callback_;
+};
+
+/// \brief A wrapper around a FlightDataStream that keeps alive a
+/// Python object backing it.
+class ARROW_PYFLIGHT_EXPORT PyFlightDataStream : public arrow::flight::FlightDataStream {
+ public:
+ /// \brief Construct a FlightDataStream from a Python object and underlying stream.
+ /// Must only be called while holding the GIL.
+ explicit PyFlightDataStream(PyObject* data_source,
+ std::unique_ptr<arrow::flight::FlightDataStream> stream);
+
+ std::shared_ptr<Schema> schema() override;
+ Status GetSchemaPayload(arrow::flight::FlightPayload* payload) override;
+ Status Next(arrow::flight::FlightPayload* payload) override;
+
+ private:
+ OwnedRefNoGIL data_source_;
+ std::unique_ptr<arrow::flight::FlightDataStream> stream_;
+};
+
+class ARROW_PYFLIGHT_EXPORT PyServerMiddlewareFactory
+ : public arrow::flight::ServerMiddlewareFactory {
+ public:
+ /// \brief A callback to create the middleware instance in Python
+ typedef std::function<Status(
+ PyObject*, const arrow::flight::CallInfo& info,
+ const arrow::flight::CallHeaders& incoming_headers,
+ std::shared_ptr<arrow::flight::ServerMiddleware>* middleware)>
+ StartCallCallback;
+
+ /// \brief Must only be called while holding the GIL.
+ explicit PyServerMiddlewareFactory(PyObject* factory, StartCallCallback start_call);
+
+ Status StartCall(const arrow::flight::CallInfo& info,
+ const arrow::flight::CallHeaders& incoming_headers,
+ std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) override;
+
+ private:
+ OwnedRefNoGIL factory_;
+ StartCallCallback start_call_;
+};
+
+class ARROW_PYFLIGHT_EXPORT PyServerMiddleware : public arrow::flight::ServerMiddleware {
+ public:
+ typedef std::function<Status(PyObject*,
+ arrow::flight::AddCallHeaders* outgoing_headers)>
+ SendingHeadersCallback;
+ typedef std::function<Status(PyObject*, const Status& status)> CallCompletedCallback;
+
+ struct Vtable {
+ SendingHeadersCallback sending_headers;
+ CallCompletedCallback call_completed;
+ };
+
+ /// \brief Must only be called while holding the GIL.
+ explicit PyServerMiddleware(PyObject* middleware, Vtable vtable);
+
+ void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override;
+ void CallCompleted(const Status& status) override;
+ std::string name() const override;
+ /// \brief Get the underlying Python object.
+ PyObject* py_object() const;
+
+ private:
+ OwnedRefNoGIL middleware_;
+ Vtable vtable_;
+};
+
+class ARROW_PYFLIGHT_EXPORT PyClientMiddlewareFactory
+ : public arrow::flight::ClientMiddlewareFactory {
+ public:
+ /// \brief A callback to create the middleware instance in Python
+ typedef std::function<Status(
+ PyObject*, const arrow::flight::CallInfo& info,
+ std::unique_ptr<arrow::flight::ClientMiddleware>* middleware)>
+ StartCallCallback;
+
+ /// \brief Must only be called while holding the GIL.
+ explicit PyClientMiddlewareFactory(PyObject* factory, StartCallCallback start_call);
+
+ void StartCall(const arrow::flight::CallInfo& info,
+ std::unique_ptr<arrow::flight::ClientMiddleware>* middleware) override;
+
+ private:
+ OwnedRefNoGIL factory_;
+ StartCallCallback start_call_;
+};
+
+class ARROW_PYFLIGHT_EXPORT PyClientMiddleware : public arrow::flight::ClientMiddleware {
+ public:
+ typedef std::function<Status(PyObject*,
+ arrow::flight::AddCallHeaders* outgoing_headers)>
+ SendingHeadersCallback;
+ typedef std::function<Status(PyObject*,
+ const arrow::flight::CallHeaders& incoming_headers)>
+ ReceivedHeadersCallback;
+ typedef std::function<Status(PyObject*, const Status& status)> CallCompletedCallback;
+
+ struct Vtable {
+ SendingHeadersCallback sending_headers;
+ ReceivedHeadersCallback received_headers;
+ CallCompletedCallback call_completed;
+ };
+
+ /// \brief Must only be called while holding the GIL.
+ explicit PyClientMiddleware(PyObject* factory, Vtable vtable);
+
+ void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override;
+ void ReceivedHeaders(const arrow::flight::CallHeaders& incoming_headers) override;
+ void CallCompleted(const Status& status) override;
+
+ private:
+ OwnedRefNoGIL middleware_;
+ Vtable vtable_;
+};
+
+/// \brief A callback that obtains the next payload from a Flight result stream.
+typedef std::function<Status(PyObject*, arrow::flight::FlightPayload*)>
+ PyGeneratorFlightDataStreamCallback;
+
+/// \brief A FlightDataStream built around a Python callback.
+class ARROW_PYFLIGHT_EXPORT PyGeneratorFlightDataStream
+ : public arrow::flight::FlightDataStream {
+ public:
+ /// \brief Construct a FlightDataStream from a Python object and underlying stream.
+ /// Must only be called while holding the GIL.
+ explicit PyGeneratorFlightDataStream(PyObject* generator,
+ std::shared_ptr<arrow::Schema> schema,
+ PyGeneratorFlightDataStreamCallback callback,
+ const ipc::IpcWriteOptions& options);
+ std::shared_ptr<Schema> schema() override;
+ Status GetSchemaPayload(arrow::flight::FlightPayload* payload) override;
+ Status Next(arrow::flight::FlightPayload* payload) override;
+
+ private:
+ OwnedRefNoGIL generator_;
+ std::shared_ptr<arrow::Schema> schema_;
+ ipc::DictionaryFieldMapper mapper_;
+ ipc::IpcWriteOptions options_;
+ PyGeneratorFlightDataStreamCallback callback_;
+};
+
+ARROW_PYFLIGHT_EXPORT
+Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema,
+ const arrow::flight::FlightDescriptor& descriptor,
+ const std::vector<arrow::flight::FlightEndpoint>& endpoints,
+ int64_t total_records, int64_t total_bytes,
+ std::unique_ptr<arrow::flight::FlightInfo>* out);
+
+ARROW_PYFLIGHT_EXPORT
+Status DeserializeBasicAuth(const std::string& buf,
+ std::unique_ptr<arrow::flight::BasicAuth>* out);
+
+ARROW_PYFLIGHT_EXPORT
+Status SerializeBasicAuth(const arrow::flight::BasicAuth& basic_auth, std::string* out);
+
+/// \brief Create a SchemaResult from schema.
+ARROW_PYFLIGHT_EXPORT
+Status CreateSchemaResult(const std::shared_ptr<arrow::Schema>& schema,
+ std::unique_ptr<arrow::flight::SchemaResult>* out);
+
+} // namespace flight
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/helpers.cc b/src/arrow/cpp/src/arrow/python/helpers.cc
new file mode 100644
index 000000000..c266abc16
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/helpers.cc
@@ -0,0 +1,470 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// helpers.h includes a NumPy header, so we include this first
+#include "arrow/python/numpy_interop.h"
+
+#include "arrow/python/helpers.h"
+
+#include <cmath>
+#include <limits>
+#include <sstream>
+#include <type_traits>
+
+#include "arrow/python/common.h"
+#include "arrow/python/decimal.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace py {
+
+#define GET_PRIMITIVE_TYPE(NAME, FACTORY) \
+ case Type::NAME: \
+ return FACTORY()
+
+std::shared_ptr<DataType> GetPrimitiveType(Type::type type) {
+ switch (type) {
+ case Type::NA:
+ return null();
+ GET_PRIMITIVE_TYPE(UINT8, uint8);
+ GET_PRIMITIVE_TYPE(INT8, int8);
+ GET_PRIMITIVE_TYPE(UINT16, uint16);
+ GET_PRIMITIVE_TYPE(INT16, int16);
+ GET_PRIMITIVE_TYPE(UINT32, uint32);
+ GET_PRIMITIVE_TYPE(INT32, int32);
+ GET_PRIMITIVE_TYPE(UINT64, uint64);
+ GET_PRIMITIVE_TYPE(INT64, int64);
+ GET_PRIMITIVE_TYPE(DATE32, date32);
+ GET_PRIMITIVE_TYPE(DATE64, date64);
+ GET_PRIMITIVE_TYPE(BOOL, boolean);
+ GET_PRIMITIVE_TYPE(HALF_FLOAT, float16);
+ GET_PRIMITIVE_TYPE(FLOAT, float32);
+ GET_PRIMITIVE_TYPE(DOUBLE, float64);
+ GET_PRIMITIVE_TYPE(BINARY, binary);
+ GET_PRIMITIVE_TYPE(STRING, utf8);
+ GET_PRIMITIVE_TYPE(LARGE_BINARY, large_binary);
+ GET_PRIMITIVE_TYPE(LARGE_STRING, large_utf8);
+ GET_PRIMITIVE_TYPE(INTERVAL_MONTH_DAY_NANO, month_day_nano_interval);
+ default:
+ return nullptr;
+ }
+}
+
+PyObject* PyHalf_FromHalf(npy_half value) {
+ PyObject* result = PyArrayScalar_New(Half);
+ if (result != NULL) {
+ PyArrayScalar_ASSIGN(result, Half, value);
+ }
+ return result;
+}
+
+Status PyFloat_AsHalf(PyObject* obj, npy_half* out) {
+ if (PyArray_IsScalar(obj, Half)) {
+ *out = PyArrayScalar_VAL(obj, Half);
+ return Status::OK();
+ } else {
+ // XXX: cannot use npy_double_to_half() without linking with Numpy
+ return Status::TypeError("Expected np.float16 instance");
+ }
+}
+
+namespace internal {
+
+std::string PyBytes_AsStdString(PyObject* obj) {
+ DCHECK(PyBytes_Check(obj));
+ return std::string(PyBytes_AS_STRING(obj), PyBytes_GET_SIZE(obj));
+}
+
+Status PyUnicode_AsStdString(PyObject* obj, std::string* out) {
+ DCHECK(PyUnicode_Check(obj));
+ Py_ssize_t size;
+ // The utf-8 representation is cached on the unicode object
+ const char* data = PyUnicode_AsUTF8AndSize(obj, &size);
+ RETURN_IF_PYERROR();
+ *out = std::string(data, size);
+ return Status::OK();
+}
+
+std::string PyObject_StdStringRepr(PyObject* obj) {
+ OwnedRef unicode_ref(PyObject_Repr(obj));
+ OwnedRef bytes_ref;
+
+ if (unicode_ref) {
+ bytes_ref.reset(
+ PyUnicode_AsEncodedString(unicode_ref.obj(), "utf8", "backslashreplace"));
+ }
+ if (!bytes_ref) {
+ PyErr_Clear();
+ std::stringstream ss;
+ ss << "<object of type '" << Py_TYPE(obj)->tp_name << "' repr() failed>";
+ return ss.str();
+ }
+ return PyBytes_AsStdString(bytes_ref.obj());
+}
+
+Status PyObject_StdStringStr(PyObject* obj, std::string* out) {
+ OwnedRef string_ref(PyObject_Str(obj));
+ RETURN_IF_PYERROR();
+ return PyUnicode_AsStdString(string_ref.obj(), out);
+}
+
+Result<bool> IsModuleImported(const std::string& module_name) {
+ // PyImport_GetModuleDict returns with a borrowed reference
+ OwnedRef key(PyUnicode_FromString(module_name.c_str()));
+ auto is_imported = PyDict_Contains(PyImport_GetModuleDict(), key.obj());
+ RETURN_IF_PYERROR();
+ return is_imported;
+}
+
+Status ImportModule(const std::string& module_name, OwnedRef* ref) {
+ PyObject* module = PyImport_ImportModule(module_name.c_str());
+ RETURN_IF_PYERROR();
+ ref->reset(module);
+ return Status::OK();
+}
+
+Status ImportFromModule(PyObject* module, const std::string& name, OwnedRef* ref) {
+ PyObject* attr = PyObject_GetAttrString(module, name.c_str());
+ RETURN_IF_PYERROR();
+ ref->reset(attr);
+ return Status::OK();
+}
+
+namespace {
+
+Status IntegerOverflowStatus(PyObject* obj, const std::string& overflow_message) {
+ if (overflow_message.empty()) {
+ std::string obj_as_stdstring;
+ RETURN_NOT_OK(PyObject_StdStringStr(obj, &obj_as_stdstring));
+ return Status::Invalid("Value ", obj_as_stdstring,
+ " too large to fit in C integer type");
+ } else {
+ return Status::Invalid(overflow_message);
+ }
+}
+
+Result<OwnedRef> PyObjectToPyInt(PyObject* obj) {
+ // Try to call __index__ or __int__ on `obj`
+ // (starting from Python 3.10, the latter isn't done anymore by PyLong_AsLong*).
+ OwnedRef ref(PyNumber_Index(obj));
+ if (ref) {
+ return std::move(ref);
+ }
+ PyErr_Clear();
+ const auto nb = Py_TYPE(obj)->tp_as_number;
+ if (nb && nb->nb_int) {
+ ref.reset(nb->nb_int(obj));
+ if (!ref) {
+ RETURN_IF_PYERROR();
+ }
+ DCHECK(ref);
+ return std::move(ref);
+ }
+ return Status::TypeError(
+ "object of type ",
+ PyObject_StdStringRepr(reinterpret_cast<PyObject*>(Py_TYPE(obj))),
+ " cannot be converted to int");
+}
+
+// Extract C signed int from Python object
+template <typename Int, enable_if_t<std::is_signed<Int>::value, Int> = 0>
+Status CIntFromPythonImpl(PyObject* obj, Int* out, const std::string& overflow_message) {
+ static_assert(sizeof(Int) <= sizeof(long long), // NOLINT
+ "integer type larger than long long");
+
+ OwnedRef ref;
+ if (!PyLong_Check(obj)) {
+ ARROW_ASSIGN_OR_RAISE(ref, PyObjectToPyInt(obj));
+ obj = ref.obj();
+ }
+
+ if (sizeof(Int) > sizeof(long)) { // NOLINT
+ const auto value = PyLong_AsLongLong(obj);
+ if (ARROW_PREDICT_FALSE(value == -1)) {
+ RETURN_IF_PYERROR();
+ }
+ if (ARROW_PREDICT_FALSE(value < std::numeric_limits<Int>::min() ||
+ value > std::numeric_limits<Int>::max())) {
+ return IntegerOverflowStatus(obj, overflow_message);
+ }
+ *out = static_cast<Int>(value);
+ } else {
+ const auto value = PyLong_AsLong(obj);
+ if (ARROW_PREDICT_FALSE(value == -1)) {
+ RETURN_IF_PYERROR();
+ }
+ if (ARROW_PREDICT_FALSE(value < std::numeric_limits<Int>::min() ||
+ value > std::numeric_limits<Int>::max())) {
+ return IntegerOverflowStatus(obj, overflow_message);
+ }
+ *out = static_cast<Int>(value);
+ }
+ return Status::OK();
+}
+
+// Extract C unsigned int from Python object
+template <typename Int, enable_if_t<std::is_unsigned<Int>::value, Int> = 0>
+Status CIntFromPythonImpl(PyObject* obj, Int* out, const std::string& overflow_message) {
+ static_assert(sizeof(Int) <= sizeof(unsigned long long), // NOLINT
+ "integer type larger than unsigned long long");
+
+ OwnedRef ref;
+ if (!PyLong_Check(obj)) {
+ ARROW_ASSIGN_OR_RAISE(ref, PyObjectToPyInt(obj));
+ obj = ref.obj();
+ }
+
+ if (sizeof(Int) > sizeof(unsigned long)) { // NOLINT
+ const auto value = PyLong_AsUnsignedLongLong(obj);
+ if (ARROW_PREDICT_FALSE(value == static_cast<decltype(value)>(-1))) {
+ RETURN_IF_PYERROR();
+ }
+ if (ARROW_PREDICT_FALSE(value > std::numeric_limits<Int>::max())) {
+ return IntegerOverflowStatus(obj, overflow_message);
+ }
+ *out = static_cast<Int>(value);
+ } else {
+ const auto value = PyLong_AsUnsignedLong(obj);
+ if (ARROW_PREDICT_FALSE(value == static_cast<decltype(value)>(-1))) {
+ RETURN_IF_PYERROR();
+ }
+ if (ARROW_PREDICT_FALSE(value > std::numeric_limits<Int>::max())) {
+ return IntegerOverflowStatus(obj, overflow_message);
+ }
+ *out = static_cast<Int>(value);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+template <typename Int>
+Status CIntFromPython(PyObject* obj, Int* out, const std::string& overflow_message) {
+ if (PyBool_Check(obj)) {
+ return Status::TypeError("Expected integer, got bool");
+ }
+ return CIntFromPythonImpl(obj, out, overflow_message);
+}
+
+template Status CIntFromPython(PyObject*, int8_t*, const std::string&);
+template Status CIntFromPython(PyObject*, int16_t*, const std::string&);
+template Status CIntFromPython(PyObject*, int32_t*, const std::string&);
+template Status CIntFromPython(PyObject*, int64_t*, const std::string&);
+template Status CIntFromPython(PyObject*, uint8_t*, const std::string&);
+template Status CIntFromPython(PyObject*, uint16_t*, const std::string&);
+template Status CIntFromPython(PyObject*, uint32_t*, const std::string&);
+template Status CIntFromPython(PyObject*, uint64_t*, const std::string&);
+
+inline bool MayHaveNaN(PyObject* obj) {
+ // Some core types can be very quickly type-checked and do not allow NaN values
+ const int64_t non_nan_tpflags = Py_TPFLAGS_LONG_SUBCLASS | Py_TPFLAGS_LIST_SUBCLASS |
+ Py_TPFLAGS_TUPLE_SUBCLASS | Py_TPFLAGS_BYTES_SUBCLASS |
+ Py_TPFLAGS_UNICODE_SUBCLASS | Py_TPFLAGS_DICT_SUBCLASS |
+ Py_TPFLAGS_BASE_EXC_SUBCLASS | Py_TPFLAGS_TYPE_SUBCLASS;
+ return !PyType_HasFeature(Py_TYPE(obj), non_nan_tpflags);
+}
+
+bool PyFloat_IsNaN(PyObject* obj) {
+ return PyFloat_Check(obj) && std::isnan(PyFloat_AsDouble(obj));
+}
+
+namespace {
+
+static bool pandas_static_initialized = false;
+
+// Once initialized, these variables hold borrowed references to Pandas static data.
+// We should not use OwnedRef here because Python destructors would be
+// called on a finalized interpreter.
+static PyObject* pandas_NA = nullptr;
+static PyObject* pandas_NaT = nullptr;
+static PyObject* pandas_Timedelta = nullptr;
+static PyObject* pandas_Timestamp = nullptr;
+static PyTypeObject* pandas_NaTType = nullptr;
+static PyObject* pandas_DateOffset = nullptr;
+
+} // namespace
+
+void InitPandasStaticData() {
+ // NOTE: This is called with the GIL held. We needn't (and shouldn't,
+ // to avoid deadlocks) use an additional C++ lock (ARROW-10519).
+ if (pandas_static_initialized) {
+ return;
+ }
+
+ OwnedRef pandas;
+
+ // Import pandas
+ Status s = ImportModule("pandas", &pandas);
+ if (!s.ok()) {
+ return;
+ }
+
+ // Since ImportModule can release the GIL, another thread could have
+ // already initialized the static data.
+ if (pandas_static_initialized) {
+ return;
+ }
+ OwnedRef ref;
+
+ // set NaT sentinel and its type
+ if (ImportFromModule(pandas.obj(), "NaT", &ref).ok()) {
+ pandas_NaT = ref.obj();
+ // PyObject_Type returns a new reference but we trust that pandas.NaT will
+ // outlive our use of this PyObject*
+ pandas_NaTType = Py_TYPE(ref.obj());
+ }
+
+ // retain a reference to Timedelta
+ if (ImportFromModule(pandas.obj(), "Timedelta", &ref).ok()) {
+ pandas_Timedelta = ref.obj();
+ }
+
+ // retain a reference to Timestamp
+ if (ImportFromModule(pandas.obj(), "Timestamp", &ref).ok()) {
+ pandas_Timestamp = ref.obj();
+ }
+
+ // if pandas.NA exists, retain a reference to it
+ if (ImportFromModule(pandas.obj(), "NA", &ref).ok()) {
+ pandas_NA = ref.obj();
+ }
+
+ // Import DateOffset type
+ if (ImportFromModule(pandas.obj(), "DateOffset", &ref).ok()) {
+ pandas_DateOffset = ref.obj();
+ }
+
+ pandas_static_initialized = true;
+}
+
+bool PandasObjectIsNull(PyObject* obj) {
+ if (!MayHaveNaN(obj)) {
+ return false;
+ }
+ if (obj == Py_None) {
+ return true;
+ }
+ if (PyFloat_IsNaN(obj) || (pandas_NA && obj == pandas_NA) ||
+ (pandas_NaTType && PyObject_TypeCheck(obj, pandas_NaTType)) ||
+ (internal::PyDecimal_Check(obj) && internal::PyDecimal_ISNAN(obj))) {
+ return true;
+ }
+ return false;
+}
+
+bool IsPandasTimedelta(PyObject* obj) {
+ return pandas_Timedelta && PyObject_IsInstance(obj, pandas_Timedelta);
+}
+
+bool IsPandasTimestamp(PyObject* obj) {
+ return pandas_Timestamp && PyObject_IsInstance(obj, pandas_Timestamp);
+}
+
+PyObject* BorrowPandasDataOffsetType() { return pandas_DateOffset; }
+
+Status InvalidValue(PyObject* obj, const std::string& why) {
+ auto obj_as_str = PyObject_StdStringRepr(obj);
+ return Status::Invalid("Could not convert ", std::move(obj_as_str), " with type ",
+ Py_TYPE(obj)->tp_name, ": ", why);
+}
+
+Status InvalidType(PyObject* obj, const std::string& why) {
+ auto obj_as_str = PyObject_StdStringRepr(obj);
+ return Status::TypeError("Could not convert ", std::move(obj_as_str), " with type ",
+ Py_TYPE(obj)->tp_name, ": ", why);
+}
+
+Status UnboxIntegerAsInt64(PyObject* obj, int64_t* out) {
+ if (PyLong_Check(obj)) {
+ int overflow = 0;
+ *out = PyLong_AsLongLongAndOverflow(obj, &overflow);
+ if (overflow) {
+ return Status::Invalid("PyLong is too large to fit int64");
+ }
+ } else if (PyArray_IsScalar(obj, Byte)) {
+ *out = reinterpret_cast<PyByteScalarObject*>(obj)->obval;
+ } else if (PyArray_IsScalar(obj, UByte)) {
+ *out = reinterpret_cast<PyUByteScalarObject*>(obj)->obval;
+ } else if (PyArray_IsScalar(obj, Short)) {
+ *out = reinterpret_cast<PyShortScalarObject*>(obj)->obval;
+ } else if (PyArray_IsScalar(obj, UShort)) {
+ *out = reinterpret_cast<PyUShortScalarObject*>(obj)->obval;
+ } else if (PyArray_IsScalar(obj, Int)) {
+ *out = reinterpret_cast<PyIntScalarObject*>(obj)->obval;
+ } else if (PyArray_IsScalar(obj, UInt)) {
+ *out = reinterpret_cast<PyUIntScalarObject*>(obj)->obval;
+ } else if (PyArray_IsScalar(obj, Long)) {
+ *out = reinterpret_cast<PyLongScalarObject*>(obj)->obval;
+ } else if (PyArray_IsScalar(obj, ULong)) {
+ *out = reinterpret_cast<PyULongScalarObject*>(obj)->obval;
+ } else if (PyArray_IsScalar(obj, LongLong)) {
+ *out = reinterpret_cast<PyLongLongScalarObject*>(obj)->obval;
+ } else if (PyArray_IsScalar(obj, Int64)) {
+ *out = reinterpret_cast<PyInt64ScalarObject*>(obj)->obval;
+ } else if (PyArray_IsScalar(obj, ULongLong)) {
+ *out = reinterpret_cast<PyULongLongScalarObject*>(obj)->obval;
+ } else if (PyArray_IsScalar(obj, UInt64)) {
+ *out = reinterpret_cast<PyUInt64ScalarObject*>(obj)->obval;
+ } else {
+ return Status::Invalid("Integer scalar type not recognized");
+ }
+ return Status::OK();
+}
+
+Status IntegerScalarToDoubleSafe(PyObject* obj, double* out) {
+ int64_t value = 0;
+ RETURN_NOT_OK(UnboxIntegerAsInt64(obj, &value));
+
+ constexpr int64_t kDoubleMax = 1LL << 53;
+ constexpr int64_t kDoubleMin = -(1LL << 53);
+
+ if (value < kDoubleMin || value > kDoubleMax) {
+ return Status::Invalid("Integer value ", value, " is outside of the range exactly",
+ " representable by a IEEE 754 double precision value");
+ }
+ *out = static_cast<double>(value);
+ return Status::OK();
+}
+
+Status IntegerScalarToFloat32Safe(PyObject* obj, float* out) {
+ int64_t value = 0;
+ RETURN_NOT_OK(UnboxIntegerAsInt64(obj, &value));
+
+ constexpr int64_t kFloatMax = 1LL << 24;
+ constexpr int64_t kFloatMin = -(1LL << 24);
+
+ if (value < kFloatMin || value > kFloatMax) {
+ return Status::Invalid("Integer value ", value, " is outside of the range exactly",
+ " representable by a IEEE 754 single precision value");
+ }
+ *out = static_cast<float>(value);
+ return Status::OK();
+}
+
+void DebugPrint(PyObject* obj) {
+ std::string repr = PyObject_StdStringRepr(obj);
+ PySys_WriteStderr("%s\n", repr.c_str());
+}
+
+} // namespace internal
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/helpers.h b/src/arrow/cpp/src/arrow/python/helpers.h
new file mode 100644
index 000000000..a8e5f80b6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/helpers.h
@@ -0,0 +1,159 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/python/platform.h"
+
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "arrow/python/numpy_interop.h"
+
+#include <numpy/halffloat.h>
+
+#include "arrow/python/visibility.h"
+#include "arrow/type.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+
+namespace py {
+
+class OwnedRef;
+
+// \brief Get an arrow DataType instance from Arrow's Type::type enum
+// \param[in] type One of the values of Arrow's Type::type enum
+// \return A shared pointer to DataType
+ARROW_PYTHON_EXPORT std::shared_ptr<DataType> GetPrimitiveType(Type::type type);
+
+// \brief Construct a np.float16 object from a npy_half value.
+ARROW_PYTHON_EXPORT PyObject* PyHalf_FromHalf(npy_half value);
+
+// \brief Convert a Python object to a npy_half value.
+ARROW_PYTHON_EXPORT Status PyFloat_AsHalf(PyObject* obj, npy_half* out);
+
+namespace internal {
+
+// \brief Check that a Python module has been already imported
+// \param[in] module_name The name of the module
+Result<bool> IsModuleImported(const std::string& module_name);
+
+// \brief Import a Python module
+// \param[in] module_name The name of the module
+// \param[out] ref The OwnedRef containing the module PyObject*
+ARROW_PYTHON_EXPORT
+Status ImportModule(const std::string& module_name, OwnedRef* ref);
+
+// \brief Import an object from a Python module
+// \param[in] module A Python module
+// \param[in] name The name of the object to import
+// \param[out] ref The OwnedRef containing the \c name attribute of the Python module \c
+// module
+ARROW_PYTHON_EXPORT
+Status ImportFromModule(PyObject* module, const std::string& name, OwnedRef* ref);
+
+// \brief Check whether obj is an integer, independent of Python versions.
+inline bool IsPyInteger(PyObject* obj) { return PyLong_Check(obj); }
+
+// \brief Import symbols from pandas that we need for various type-checking,
+// like pandas.NaT or pandas.NA
+void InitPandasStaticData();
+
+// \brief Use pandas missing value semantics to check if a value is null
+ARROW_PYTHON_EXPORT
+bool PandasObjectIsNull(PyObject* obj);
+
+// \brief Check that obj is a pandas.Timedelta instance
+ARROW_PYTHON_EXPORT
+bool IsPandasTimedelta(PyObject* obj);
+
+// \brief Check that obj is a pandas.Timestamp instance
+bool IsPandasTimestamp(PyObject* obj);
+
+// \brief Returned a borrowed reference to the pandas.tseries.offsets.DateOffset
+PyObject* BorrowPandasDataOffsetType();
+
+// \brief Check whether obj is a floating-point NaN
+ARROW_PYTHON_EXPORT
+bool PyFloat_IsNaN(PyObject* obj);
+
+inline bool IsPyBinary(PyObject* obj) {
+ return PyBytes_Check(obj) || PyByteArray_Check(obj) || PyMemoryView_Check(obj);
+}
+
+// \brief Convert a Python integer into a C integer
+// \param[in] obj A Python integer
+// \param[out] out A pointer to a C integer to hold the result of the conversion
+// \return The status of the operation
+template <typename Int>
+Status CIntFromPython(PyObject* obj, Int* out, const std::string& overflow_message = "");
+
+// \brief Convert a Python unicode string to a std::string
+ARROW_PYTHON_EXPORT
+Status PyUnicode_AsStdString(PyObject* obj, std::string* out);
+
+// \brief Convert a Python bytes object to a std::string
+ARROW_PYTHON_EXPORT
+std::string PyBytes_AsStdString(PyObject* obj);
+
+// \brief Call str() on the given object and return the result as a std::string
+ARROW_PYTHON_EXPORT
+Status PyObject_StdStringStr(PyObject* obj, std::string* out);
+
+// \brief Return the repr() of the given object (always succeeds)
+ARROW_PYTHON_EXPORT
+std::string PyObject_StdStringRepr(PyObject* obj);
+
+// \brief Cast the given size to int32_t, with error checking
+inline Status CastSize(Py_ssize_t size, int32_t* out,
+ const char* error_msg = "Maximum size exceeded (2GB)") {
+ // size is assumed to be positive
+ if (size > std::numeric_limits<int32_t>::max()) {
+ return Status::Invalid(error_msg);
+ }
+ *out = static_cast<int32_t>(size);
+ return Status::OK();
+}
+
+inline Status CastSize(Py_ssize_t size, int64_t* out, const char* error_msg = NULLPTR) {
+ // size is assumed to be positive
+ *out = static_cast<int64_t>(size);
+ return Status::OK();
+}
+
+// \brief Print the Python object's __str__ form along with the passed error
+// message
+ARROW_PYTHON_EXPORT
+Status InvalidValue(PyObject* obj, const std::string& why);
+
+ARROW_PYTHON_EXPORT
+Status InvalidType(PyObject* obj, const std::string& why);
+
+ARROW_PYTHON_EXPORT
+Status IntegerScalarToDoubleSafe(PyObject* obj, double* result);
+ARROW_PYTHON_EXPORT
+Status IntegerScalarToFloat32Safe(PyObject* obj, float* result);
+
+// \brief Print Python object __repr__
+void DebugPrint(PyObject* obj);
+
+} // namespace internal
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/inference.cc b/src/arrow/cpp/src/arrow/python/inference.cc
new file mode 100644
index 000000000..db5f0896a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/inference.cc
@@ -0,0 +1,723 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/python/inference.h"
+#include "arrow/python/numpy_interop.h"
+
+#include <datetime.h>
+
+#include <algorithm>
+#include <limits>
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/status.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/logging.h"
+
+#include "arrow/python/datetime.h"
+#include "arrow/python/decimal.h"
+#include "arrow/python/helpers.h"
+#include "arrow/python/iterators.h"
+#include "arrow/python/numpy_convert.h"
+
+namespace arrow {
+namespace py {
+namespace {
+// Assigns a tuple to interval_types_tuple containing the nametuple for
+// MonthDayNanoIntervalType and if present dateutil's relativedelta and
+// pandas DateOffset.
+Status ImportPresentIntervalTypes(OwnedRefNoGIL* interval_types_tuple) {
+ OwnedRef relative_delta_module;
+ // These are Optional imports so swallow errors.
+ OwnedRef relative_delta_type;
+ // Try to import pandas to get types.
+ internal::InitPandasStaticData();
+ if (internal::ImportModule("dateutil.relativedelta", &relative_delta_module).ok()) {
+ RETURN_NOT_OK(internal::ImportFromModule(relative_delta_module.obj(), "relativedelta",
+ &relative_delta_type));
+ }
+
+ PyObject* date_offset_type = internal::BorrowPandasDataOffsetType();
+ interval_types_tuple->reset(
+ PyTuple_New(1 + (date_offset_type != nullptr ? 1 : 0) +
+ (relative_delta_type.obj() != nullptr ? 1 : 0)));
+ RETURN_IF_PYERROR();
+ int index = 0;
+ PyTuple_SetItem(interval_types_tuple->obj(), index++,
+ internal::NewMonthDayNanoTupleType());
+ RETURN_IF_PYERROR();
+ if (date_offset_type != nullptr) {
+ Py_XINCREF(date_offset_type);
+ PyTuple_SetItem(interval_types_tuple->obj(), index++, date_offset_type);
+ RETURN_IF_PYERROR();
+ }
+ if (relative_delta_type.obj() != nullptr) {
+ PyTuple_SetItem(interval_types_tuple->obj(), index++, relative_delta_type.detach());
+ RETURN_IF_PYERROR();
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+#define _NUMPY_UNIFY_NOOP(DTYPE) \
+ case NPY_##DTYPE: \
+ return OK;
+
+#define _NUMPY_UNIFY_PROMOTE(DTYPE) \
+ case NPY_##DTYPE: \
+ current_type_num_ = dtype; \
+ current_dtype_ = descr; \
+ return OK;
+
+#define _NUMPY_UNIFY_PROMOTE_TO(DTYPE, NEW_TYPE) \
+ case NPY_##DTYPE: \
+ current_type_num_ = NPY_##NEW_TYPE; \
+ current_dtype_ = PyArray_DescrFromType(current_type_num_); \
+ return OK;
+
+// Form a consensus NumPy dtype to use for Arrow conversion for a
+// collection of dtype objects observed one at a time
+class NumPyDtypeUnifier {
+ public:
+ enum Action { OK, INVALID };
+
+ NumPyDtypeUnifier() : current_type_num_(-1), current_dtype_(nullptr) {}
+
+ Status InvalidMix(int new_dtype) {
+ return Status::Invalid("Cannot mix NumPy dtypes ",
+ GetNumPyTypeName(current_type_num_), " and ",
+ GetNumPyTypeName(new_dtype));
+ }
+
+ int Observe_BOOL(PyArray_Descr* descr, int dtype) { return INVALID; }
+
+ int Observe_INT8(PyArray_Descr* descr, int dtype) {
+ switch (dtype) {
+ _NUMPY_UNIFY_PROMOTE(INT16);
+ _NUMPY_UNIFY_PROMOTE(INT32);
+ _NUMPY_UNIFY_PROMOTE(INT64);
+ _NUMPY_UNIFY_PROMOTE(FLOAT32);
+ _NUMPY_UNIFY_PROMOTE(FLOAT64);
+ default:
+ return INVALID;
+ }
+ }
+
+ int Observe_INT16(PyArray_Descr* descr, int dtype) {
+ switch (dtype) {
+ _NUMPY_UNIFY_NOOP(INT8);
+ _NUMPY_UNIFY_PROMOTE(INT32);
+ _NUMPY_UNIFY_PROMOTE(INT64);
+ _NUMPY_UNIFY_NOOP(UINT8);
+ _NUMPY_UNIFY_PROMOTE(FLOAT32);
+ _NUMPY_UNIFY_PROMOTE(FLOAT64);
+ default:
+ return INVALID;
+ }
+ }
+
+ int Observe_INT32(PyArray_Descr* descr, int dtype) {
+ switch (dtype) {
+ _NUMPY_UNIFY_NOOP(INT8);
+ _NUMPY_UNIFY_NOOP(INT16);
+ _NUMPY_UNIFY_PROMOTE(INT32);
+ _NUMPY_UNIFY_PROMOTE(INT64);
+ _NUMPY_UNIFY_NOOP(UINT8);
+ _NUMPY_UNIFY_NOOP(UINT16);
+ _NUMPY_UNIFY_PROMOTE_TO(FLOAT32, FLOAT64);
+ _NUMPY_UNIFY_PROMOTE(FLOAT64);
+ default:
+ return INVALID;
+ }
+ }
+
+ int Observe_INT64(PyArray_Descr* descr, int dtype) {
+ switch (dtype) {
+ _NUMPY_UNIFY_NOOP(INT8);
+ _NUMPY_UNIFY_NOOP(INT16);
+ _NUMPY_UNIFY_NOOP(INT32);
+ _NUMPY_UNIFY_NOOP(INT64);
+ _NUMPY_UNIFY_NOOP(UINT8);
+ _NUMPY_UNIFY_NOOP(UINT16);
+ _NUMPY_UNIFY_NOOP(UINT32);
+ _NUMPY_UNIFY_PROMOTE_TO(FLOAT32, FLOAT64);
+ _NUMPY_UNIFY_PROMOTE(FLOAT64);
+ default:
+ return INVALID;
+ }
+ }
+
+ int Observe_UINT8(PyArray_Descr* descr, int dtype) {
+ switch (dtype) {
+ _NUMPY_UNIFY_PROMOTE(UINT16);
+ _NUMPY_UNIFY_PROMOTE(UINT32);
+ _NUMPY_UNIFY_PROMOTE(UINT64);
+ _NUMPY_UNIFY_PROMOTE(FLOAT32);
+ _NUMPY_UNIFY_PROMOTE(FLOAT64);
+ default:
+ return INVALID;
+ }
+ }
+
+ int Observe_UINT16(PyArray_Descr* descr, int dtype) {
+ switch (dtype) {
+ _NUMPY_UNIFY_NOOP(UINT8);
+ _NUMPY_UNIFY_PROMOTE(UINT32);
+ _NUMPY_UNIFY_PROMOTE(UINT64);
+ _NUMPY_UNIFY_PROMOTE(FLOAT32);
+ _NUMPY_UNIFY_PROMOTE(FLOAT64);
+ default:
+ return INVALID;
+ }
+ }
+
+ int Observe_UINT32(PyArray_Descr* descr, int dtype) {
+ switch (dtype) {
+ _NUMPY_UNIFY_NOOP(UINT8);
+ _NUMPY_UNIFY_NOOP(UINT16);
+ _NUMPY_UNIFY_PROMOTE(UINT64);
+ _NUMPY_UNIFY_PROMOTE_TO(FLOAT32, FLOAT64);
+ _NUMPY_UNIFY_PROMOTE(FLOAT64);
+ default:
+ return INVALID;
+ }
+ }
+
+ int Observe_UINT64(PyArray_Descr* descr, int dtype) {
+ switch (dtype) {
+ _NUMPY_UNIFY_NOOP(UINT8);
+ _NUMPY_UNIFY_NOOP(UINT16);
+ _NUMPY_UNIFY_NOOP(UINT32);
+ _NUMPY_UNIFY_PROMOTE_TO(FLOAT32, FLOAT64);
+ _NUMPY_UNIFY_PROMOTE(FLOAT64);
+ default:
+ return INVALID;
+ }
+ }
+
+ int Observe_FLOAT16(PyArray_Descr* descr, int dtype) {
+ switch (dtype) {
+ _NUMPY_UNIFY_PROMOTE(FLOAT32);
+ _NUMPY_UNIFY_PROMOTE(FLOAT64);
+ default:
+ return INVALID;
+ }
+ }
+
+ int Observe_FLOAT32(PyArray_Descr* descr, int dtype) {
+ switch (dtype) {
+ _NUMPY_UNIFY_NOOP(INT8);
+ _NUMPY_UNIFY_NOOP(INT16);
+ _NUMPY_UNIFY_NOOP(INT32);
+ _NUMPY_UNIFY_NOOP(INT64);
+ _NUMPY_UNIFY_NOOP(UINT8);
+ _NUMPY_UNIFY_NOOP(UINT16);
+ _NUMPY_UNIFY_NOOP(UINT32);
+ _NUMPY_UNIFY_NOOP(UINT64);
+ _NUMPY_UNIFY_PROMOTE(FLOAT64);
+ default:
+ return INVALID;
+ }
+ }
+
+ int Observe_FLOAT64(PyArray_Descr* descr, int dtype) {
+ switch (dtype) {
+ _NUMPY_UNIFY_NOOP(INT8);
+ _NUMPY_UNIFY_NOOP(INT16);
+ _NUMPY_UNIFY_NOOP(INT32);
+ _NUMPY_UNIFY_NOOP(INT64);
+ _NUMPY_UNIFY_NOOP(UINT8);
+ _NUMPY_UNIFY_NOOP(UINT16);
+ _NUMPY_UNIFY_NOOP(UINT32);
+ _NUMPY_UNIFY_NOOP(UINT64);
+ default:
+ return INVALID;
+ }
+ }
+
+ int Observe_DATETIME(PyArray_Descr* dtype_obj) {
+ // TODO: check that units are all the same
+ return OK;
+ }
+
+ Status Observe(PyArray_Descr* descr) {
+ int dtype = fix_numpy_type_num(descr->type_num);
+
+ if (current_type_num_ == -1) {
+ current_dtype_ = descr;
+ current_type_num_ = dtype;
+ return Status::OK();
+ } else if (current_type_num_ == dtype) {
+ return Status::OK();
+ }
+
+#define OBSERVE_CASE(DTYPE) \
+ case NPY_##DTYPE: \
+ action = Observe_##DTYPE(descr, dtype); \
+ break;
+
+ int action = OK;
+ switch (current_type_num_) {
+ OBSERVE_CASE(BOOL);
+ OBSERVE_CASE(INT8);
+ OBSERVE_CASE(INT16);
+ OBSERVE_CASE(INT32);
+ OBSERVE_CASE(INT64);
+ OBSERVE_CASE(UINT8);
+ OBSERVE_CASE(UINT16);
+ OBSERVE_CASE(UINT32);
+ OBSERVE_CASE(UINT64);
+ OBSERVE_CASE(FLOAT16);
+ OBSERVE_CASE(FLOAT32);
+ OBSERVE_CASE(FLOAT64);
+ case NPY_DATETIME:
+ action = Observe_DATETIME(descr);
+ break;
+ default:
+ return Status::NotImplemented("Unsupported numpy type ", GetNumPyTypeName(dtype));
+ }
+
+ if (action == INVALID) {
+ return InvalidMix(dtype);
+ }
+ return Status::OK();
+ }
+
+ bool dtype_was_observed() const { return current_type_num_ != -1; }
+
+ PyArray_Descr* current_dtype() const { return current_dtype_; }
+
+ int current_type_num() const { return current_type_num_; }
+
+ private:
+ int current_type_num_;
+ PyArray_Descr* current_dtype_;
+};
+
+class TypeInferrer {
+ // A type inference visitor for Python values
+ public:
+ // \param validate_interval the number of elements to observe before checking
+ // whether the data is mixed type or has other problems. This helps avoid
+ // excess computation for each element while also making sure we "bail out"
+ // early with long sequences that may have problems up front
+ // \param make_unions permit mixed-type data by creating union types (not yet
+ // implemented)
+ explicit TypeInferrer(bool pandas_null_sentinels = false,
+ int64_t validate_interval = 100, bool make_unions = false)
+ : pandas_null_sentinels_(pandas_null_sentinels),
+ validate_interval_(validate_interval),
+ make_unions_(make_unions),
+ total_count_(0),
+ none_count_(0),
+ bool_count_(0),
+ int_count_(0),
+ date_count_(0),
+ time_count_(0),
+ timestamp_micro_count_(0),
+ duration_count_(0),
+ float_count_(0),
+ binary_count_(0),
+ unicode_count_(0),
+ decimal_count_(0),
+ list_count_(0),
+ struct_count_(0),
+ numpy_dtype_count_(0),
+ interval_count_(0),
+ max_decimal_metadata_(std::numeric_limits<int32_t>::min(),
+ std::numeric_limits<int32_t>::min()),
+ decimal_type_() {
+ ARROW_CHECK_OK(internal::ImportDecimalType(&decimal_type_));
+ ARROW_CHECK_OK(ImportPresentIntervalTypes(&interval_types_));
+ }
+
+ /// \param[in] obj a Python object in the sequence
+ /// \param[out] keep_going if sufficient information has been gathered to
+ /// attempt to begin converting the sequence, *keep_going will be set to true
+ /// to signal to the calling visitor loop to terminate
+ Status Visit(PyObject* obj, bool* keep_going) {
+ ++total_count_;
+
+ if (obj == Py_None || (pandas_null_sentinels_ && internal::PandasObjectIsNull(obj))) {
+ ++none_count_;
+ } else if (PyBool_Check(obj)) {
+ ++bool_count_;
+ *keep_going = make_unions_;
+ } else if (PyFloat_Check(obj)) {
+ ++float_count_;
+ *keep_going = make_unions_;
+ } else if (internal::IsPyInteger(obj)) {
+ ++int_count_;
+ } else if (PyDateTime_Check(obj)) {
+ // infer timezone from the first encountered datetime object
+ if (!timestamp_micro_count_) {
+ OwnedRef tzinfo(PyObject_GetAttrString(obj, "tzinfo"));
+ if (tzinfo.obj() != nullptr && tzinfo.obj() != Py_None) {
+ ARROW_ASSIGN_OR_RAISE(timezone_, internal::TzinfoToString(tzinfo.obj()));
+ }
+ }
+ ++timestamp_micro_count_;
+ *keep_going = make_unions_;
+ } else if (PyDelta_Check(obj)) {
+ ++duration_count_;
+ *keep_going = make_unions_;
+ } else if (PyDate_Check(obj)) {
+ ++date_count_;
+ *keep_going = make_unions_;
+ } else if (PyTime_Check(obj)) {
+ ++time_count_;
+ *keep_going = make_unions_;
+ } else if (internal::IsPyBinary(obj)) {
+ ++binary_count_;
+ *keep_going = make_unions_;
+ } else if (PyUnicode_Check(obj)) {
+ ++unicode_count_;
+ *keep_going = make_unions_;
+ } else if (PyArray_CheckAnyScalarExact(obj)) {
+ RETURN_NOT_OK(VisitDType(PyArray_DescrFromScalar(obj), keep_going));
+ } else if (PySet_Check(obj) || (Py_TYPE(obj) == &PyDictValues_Type)) {
+ RETURN_NOT_OK(VisitSet(obj, keep_going));
+ } else if (PyArray_Check(obj)) {
+ RETURN_NOT_OK(VisitNdarray(obj, keep_going));
+ } else if (PyDict_Check(obj)) {
+ RETURN_NOT_OK(VisitDict(obj));
+ } else if (PyList_Check(obj) ||
+ (PyTuple_Check(obj) &&
+ !PyObject_IsInstance(obj, PyTuple_GetItem(interval_types_.obj(), 0)))) {
+ RETURN_NOT_OK(VisitList(obj, keep_going));
+ } else if (PyObject_IsInstance(obj, decimal_type_.obj())) {
+ RETURN_NOT_OK(max_decimal_metadata_.Update(obj));
+ ++decimal_count_;
+ } else if (PyObject_IsInstance(obj, interval_types_.obj())) {
+ ++interval_count_;
+ } else {
+ return internal::InvalidValue(obj,
+ "did not recognize Python value type when inferring "
+ "an Arrow data type");
+ }
+
+ if (total_count_ % validate_interval_ == 0) {
+ RETURN_NOT_OK(Validate());
+ }
+
+ return Status::OK();
+ }
+
+ // Infer value type from a sequence of values
+ Status VisitSequence(PyObject* obj, PyObject* mask = nullptr) {
+ if (mask == nullptr || mask == Py_None) {
+ return internal::VisitSequence(
+ obj, /*offset=*/0,
+ [this](PyObject* value, bool* keep_going) { return Visit(value, keep_going); });
+ } else {
+ return internal::VisitSequenceMasked(
+ obj, mask, /*offset=*/0,
+ [this](PyObject* value, uint8_t masked, bool* keep_going) {
+ if (!masked) {
+ return Visit(value, keep_going);
+ } else {
+ return Status::OK();
+ }
+ });
+ }
+ }
+
+ // Infer value type from a sequence of values
+ Status VisitIterable(PyObject* obj) {
+ return internal::VisitIterable(obj, [this](PyObject* value, bool* keep_going) {
+ return Visit(value, keep_going);
+ });
+ }
+
+ Status GetType(std::shared_ptr<DataType>* out) {
+ // TODO(wesm): handling forming unions
+ if (make_unions_) {
+ return Status::NotImplemented("Creating union types not yet supported");
+ }
+
+ RETURN_NOT_OK(Validate());
+
+ if (numpy_dtype_count_ > 0) {
+ // All NumPy scalars and Nones/nulls
+ if (numpy_dtype_count_ + none_count_ == total_count_) {
+ std::shared_ptr<DataType> type;
+ RETURN_NOT_OK(NumPyDtypeToArrow(numpy_unifier_.current_dtype(), &type));
+ *out = type;
+ return Status::OK();
+ }
+
+ // The "bad path": data contains a mix of NumPy scalars and
+ // other kinds of scalars. Note this can happen innocuously
+ // because numpy.nan is not a NumPy scalar (it's a built-in
+ // PyFloat)
+
+ // TODO(ARROW-5564): Merge together type unification so this
+ // hack is not necessary
+ switch (numpy_unifier_.current_type_num()) {
+ case NPY_BOOL:
+ bool_count_ += numpy_dtype_count_;
+ break;
+ case NPY_INT8:
+ case NPY_INT16:
+ case NPY_INT32:
+ case NPY_INT64:
+ case NPY_UINT8:
+ case NPY_UINT16:
+ case NPY_UINT32:
+ case NPY_UINT64:
+ int_count_ += numpy_dtype_count_;
+ break;
+ case NPY_FLOAT32:
+ case NPY_FLOAT64:
+ float_count_ += numpy_dtype_count_;
+ break;
+ case NPY_DATETIME:
+ return Status::Invalid(
+ "numpy.datetime64 scalars cannot be mixed "
+ "with other Python scalar values currently");
+ }
+ }
+
+ if (list_count_) {
+ std::shared_ptr<DataType> value_type;
+ RETURN_NOT_OK(list_inferrer_->GetType(&value_type));
+ *out = list(value_type);
+ } else if (struct_count_) {
+ RETURN_NOT_OK(GetStructType(out));
+ } else if (decimal_count_) {
+ if (max_decimal_metadata_.precision() > Decimal128Type::kMaxPrecision) {
+ // the default constructor does not validate the precision and scale
+ ARROW_ASSIGN_OR_RAISE(*out,
+ Decimal256Type::Make(max_decimal_metadata_.precision(),
+ max_decimal_metadata_.scale()));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(*out,
+ Decimal128Type::Make(max_decimal_metadata_.precision(),
+ max_decimal_metadata_.scale()));
+ }
+ } else if (float_count_) {
+ // Prioritize floats before integers
+ *out = float64();
+ } else if (int_count_) {
+ *out = int64();
+ } else if (date_count_) {
+ *out = date32();
+ } else if (time_count_) {
+ *out = time64(TimeUnit::MICRO);
+ } else if (timestamp_micro_count_) {
+ *out = timestamp(TimeUnit::MICRO, timezone_);
+ } else if (duration_count_) {
+ *out = duration(TimeUnit::MICRO);
+ } else if (bool_count_) {
+ *out = boolean();
+ } else if (binary_count_) {
+ *out = binary();
+ } else if (unicode_count_) {
+ *out = utf8();
+ } else if (interval_count_) {
+ *out = month_day_nano_interval();
+ } else {
+ *out = null();
+ }
+ return Status::OK();
+ }
+
+ int64_t total_count() const { return total_count_; }
+
+ protected:
+ Status Validate() const {
+ if (list_count_ > 0) {
+ if (list_count_ + none_count_ != total_count_) {
+ return Status::Invalid("cannot mix list and non-list, non-null values");
+ }
+ RETURN_NOT_OK(list_inferrer_->Validate());
+ } else if (struct_count_ > 0) {
+ if (struct_count_ + none_count_ != total_count_) {
+ return Status::Invalid("cannot mix struct and non-struct, non-null values");
+ }
+ for (const auto& it : struct_inferrers_) {
+ RETURN_NOT_OK(it.second.Validate());
+ }
+ }
+ return Status::OK();
+ }
+
+ Status VisitDType(PyArray_Descr* dtype, bool* keep_going) {
+ // Continue visiting dtypes for now.
+ // TODO(wesm): devise approach for unions
+ ++numpy_dtype_count_;
+ *keep_going = true;
+ return numpy_unifier_.Observe(dtype);
+ }
+
+ Status VisitList(PyObject* obj, bool* keep_going /* unused */) {
+ if (!list_inferrer_) {
+ list_inferrer_.reset(
+ new TypeInferrer(pandas_null_sentinels_, validate_interval_, make_unions_));
+ }
+ ++list_count_;
+ return list_inferrer_->VisitSequence(obj);
+ }
+
+ Status VisitSet(PyObject* obj, bool* keep_going /* unused */) {
+ if (!list_inferrer_) {
+ list_inferrer_.reset(
+ new TypeInferrer(pandas_null_sentinels_, validate_interval_, make_unions_));
+ }
+ ++list_count_;
+ return list_inferrer_->VisitIterable(obj);
+ }
+
+ Status VisitNdarray(PyObject* obj, bool* keep_going) {
+ PyArray_Descr* dtype = PyArray_DESCR(reinterpret_cast<PyArrayObject*>(obj));
+ if (dtype->type_num == NPY_OBJECT) {
+ return VisitList(obj, keep_going);
+ }
+ // Not an object array: infer child Arrow type from dtype
+ if (!list_inferrer_) {
+ list_inferrer_.reset(
+ new TypeInferrer(pandas_null_sentinels_, validate_interval_, make_unions_));
+ }
+ ++list_count_;
+
+ // XXX(wesm): In ARROW-4324 I added accounting to check whether
+ // all of the non-null values have NumPy dtypes, but the
+ // total_count not not being properly incremented here
+ ++(*list_inferrer_).total_count_;
+ return list_inferrer_->VisitDType(dtype, keep_going);
+ }
+
+ Status VisitDict(PyObject* obj) {
+ PyObject* key_obj;
+ PyObject* value_obj;
+ Py_ssize_t pos = 0;
+
+ while (PyDict_Next(obj, &pos, &key_obj, &value_obj)) {
+ std::string key;
+ if (PyUnicode_Check(key_obj)) {
+ RETURN_NOT_OK(internal::PyUnicode_AsStdString(key_obj, &key));
+ } else if (PyBytes_Check(key_obj)) {
+ key = internal::PyBytes_AsStdString(key_obj);
+ } else {
+ return Status::TypeError("Expected dict key of type str or bytes, got '",
+ Py_TYPE(key_obj)->tp_name, "'");
+ }
+ // Get or create visitor for this key
+ auto it = struct_inferrers_.find(key);
+ if (it == struct_inferrers_.end()) {
+ it = struct_inferrers_
+ .insert(
+ std::make_pair(key, TypeInferrer(pandas_null_sentinels_,
+ validate_interval_, make_unions_)))
+ .first;
+ }
+ TypeInferrer* visitor = &it->second;
+
+ // We ignore termination signals from child visitors for now
+ //
+ // TODO(wesm): keep track of whether type inference has terminated for
+ // the child visitors to avoid doing unneeded work
+ bool keep_going = true;
+ RETURN_NOT_OK(visitor->Visit(value_obj, &keep_going));
+ }
+
+ // We do not terminate visiting dicts since we want the union of all
+ // observed keys
+ ++struct_count_;
+ return Status::OK();
+ }
+
+ Status GetStructType(std::shared_ptr<DataType>* out) {
+ std::vector<std::shared_ptr<Field>> fields;
+ for (auto&& it : struct_inferrers_) {
+ std::shared_ptr<DataType> field_type;
+ RETURN_NOT_OK(it.second.GetType(&field_type));
+ fields.emplace_back(field(it.first, field_type));
+ }
+ *out = struct_(fields);
+ return Status::OK();
+ }
+
+ private:
+ bool pandas_null_sentinels_;
+ int64_t validate_interval_;
+ bool make_unions_;
+ int64_t total_count_;
+ int64_t none_count_;
+ int64_t bool_count_;
+ int64_t int_count_;
+ int64_t date_count_;
+ int64_t time_count_;
+ int64_t timestamp_micro_count_;
+ std::string timezone_;
+ int64_t duration_count_;
+ int64_t float_count_;
+ int64_t binary_count_;
+ int64_t unicode_count_;
+ int64_t decimal_count_;
+ int64_t list_count_;
+ int64_t struct_count_;
+ int64_t numpy_dtype_count_;
+ int64_t interval_count_;
+ std::unique_ptr<TypeInferrer> list_inferrer_;
+ std::map<std::string, TypeInferrer> struct_inferrers_;
+
+ // If we observe a strongly-typed value in e.g. a NumPy array, we can store
+ // it here to skip the type counting logic above
+ NumPyDtypeUnifier numpy_unifier_;
+
+ internal::DecimalMetadata max_decimal_metadata_;
+
+ OwnedRefNoGIL decimal_type_;
+ OwnedRefNoGIL interval_types_;
+};
+
+// Non-exhaustive type inference
+Result<std::shared_ptr<DataType>> InferArrowType(PyObject* obj, PyObject* mask,
+ bool pandas_null_sentinels) {
+ if (pandas_null_sentinels) {
+ // ARROW-842: If pandas is not installed then null checks will be less
+ // comprehensive, but that is okay.
+ internal::InitPandasStaticData();
+ }
+
+ std::shared_ptr<DataType> out_type;
+ TypeInferrer inferrer(pandas_null_sentinels);
+ RETURN_NOT_OK(inferrer.VisitSequence(obj, mask));
+ RETURN_NOT_OK(inferrer.GetType(&out_type));
+ if (out_type == nullptr) {
+ return Status::TypeError("Unable to determine data type");
+ } else {
+ return std::move(out_type);
+ }
+}
+
+ARROW_PYTHON_EXPORT
+bool IsPyBool(PyObject* obj) { return internal::PyBoolScalar_Check(obj); }
+
+ARROW_PYTHON_EXPORT
+bool IsPyInt(PyObject* obj) { return internal::PyIntScalar_Check(obj); }
+
+ARROW_PYTHON_EXPORT
+bool IsPyFloat(PyObject* obj) { return internal::PyFloatScalar_Check(obj); }
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/inference.h b/src/arrow/cpp/src/arrow/python/inference.h
new file mode 100644
index 000000000..eff183629
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/inference.h
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Functions for converting between CPython built-in data structures and Arrow
+// data structures
+
+#pragma once
+
+#include "arrow/python/platform.h"
+
+#include <memory>
+
+#include "arrow/python/visibility.h"
+#include "arrow/type.h"
+#include "arrow/util/macros.h"
+
+#include "arrow/python/common.h"
+
+namespace arrow {
+
+class Array;
+class Status;
+
+namespace py {
+
+// These functions take a sequence input, not arbitrary iterables
+
+/// \brief Infer Arrow type from a Python sequence
+/// \param[in] obj the sequence of values
+/// \param[in] mask an optional mask where True values are null. May
+/// be nullptr
+/// \param[in] pandas_null_sentinels use pandas's null value markers
+ARROW_PYTHON_EXPORT
+Result<std::shared_ptr<arrow::DataType>> InferArrowType(PyObject* obj, PyObject* mask,
+ bool pandas_null_sentinels);
+
+/// Checks whether the passed Python object is a boolean scalar
+ARROW_PYTHON_EXPORT
+bool IsPyBool(PyObject* obj);
+
+/// Checks whether the passed Python object is an integer scalar
+ARROW_PYTHON_EXPORT
+bool IsPyInt(PyObject* obj);
+
+/// Checks whether the passed Python object is a float scalar
+ARROW_PYTHON_EXPORT
+bool IsPyFloat(PyObject* obj);
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/init.cc b/src/arrow/cpp/src/arrow/python/init.cc
new file mode 100644
index 000000000..dba293bbe
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/init.cc
@@ -0,0 +1,24 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Trigger the array import (inversion of NO_IMPORT_ARRAY)
+#define NUMPY_IMPORT_ARRAY
+
+#include "arrow/python/init.h"
+#include "arrow/python/numpy_interop.h"
+
+int arrow_init_numpy() { return arrow::py::import_numpy(); }
diff --git a/src/arrow/cpp/src/arrow/python/init.h b/src/arrow/cpp/src/arrow/python/init.h
new file mode 100644
index 000000000..2e6c95486
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/init.h
@@ -0,0 +1,26 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/python/platform.h"
+#include "arrow/python/visibility.h"
+
+extern "C" {
+ARROW_PYTHON_EXPORT
+int arrow_init_numpy();
+}
diff --git a/src/arrow/cpp/src/arrow/python/io.cc b/src/arrow/cpp/src/arrow/python/io.cc
new file mode 100644
index 000000000..73525feed
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/io.cc
@@ -0,0 +1,374 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/python/io.h"
+
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+#include <mutex>
+#include <string>
+
+#include "arrow/io/memory.h"
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+
+#include "arrow/python/common.h"
+#include "arrow/python/pyarrow.h"
+
+namespace arrow {
+
+using arrow::io::TransformInputStream;
+
+namespace py {
+
+// ----------------------------------------------------------------------
+// Python file
+
+// A common interface to a Python file-like object. Must acquire GIL before
+// calling any methods
+class PythonFile {
+ public:
+ explicit PythonFile(PyObject* file) : file_(file), checked_read_buffer_(false) {
+ Py_INCREF(file);
+ }
+
+ Status CheckClosed() const {
+ if (!file_) {
+ return Status::Invalid("operation on closed Python file");
+ }
+ return Status::OK();
+ }
+
+ Status Close() {
+ if (file_) {
+ PyObject* result = cpp_PyObject_CallMethod(file_.obj(), "close", "()");
+ Py_XDECREF(result);
+ file_.reset();
+ PY_RETURN_IF_ERROR(StatusCode::IOError);
+ }
+ return Status::OK();
+ }
+
+ Status Abort() {
+ file_.reset();
+ return Status::OK();
+ }
+
+ bool closed() const {
+ if (!file_) {
+ return true;
+ }
+ PyObject* result = PyObject_GetAttrString(file_.obj(), "closed");
+ if (result == NULL) {
+ // Can't propagate the error, so write it out and return an arbitrary value
+ PyErr_WriteUnraisable(NULL);
+ return true;
+ }
+ int ret = PyObject_IsTrue(result);
+ Py_XDECREF(result);
+ if (ret < 0) {
+ PyErr_WriteUnraisable(NULL);
+ return true;
+ }
+ return ret != 0;
+ }
+
+ Status Seek(int64_t position, int whence) {
+ RETURN_NOT_OK(CheckClosed());
+
+ // whence: 0 for relative to start of file, 2 for end of file
+ PyObject* result = cpp_PyObject_CallMethod(file_.obj(), "seek", "(ni)",
+ static_cast<Py_ssize_t>(position), whence);
+ Py_XDECREF(result);
+ PY_RETURN_IF_ERROR(StatusCode::IOError);
+ return Status::OK();
+ }
+
+ Status Read(int64_t nbytes, PyObject** out) {
+ RETURN_NOT_OK(CheckClosed());
+
+ PyObject* result = cpp_PyObject_CallMethod(file_.obj(), "read", "(n)",
+ static_cast<Py_ssize_t>(nbytes));
+ PY_RETURN_IF_ERROR(StatusCode::IOError);
+ *out = result;
+ return Status::OK();
+ }
+
+ Status ReadBuffer(int64_t nbytes, PyObject** out) {
+ PyObject* result = cpp_PyObject_CallMethod(file_.obj(), "read_buffer", "(n)",
+ static_cast<Py_ssize_t>(nbytes));
+ PY_RETURN_IF_ERROR(StatusCode::IOError);
+ *out = result;
+ return Status::OK();
+ }
+
+ Status Write(const void* data, int64_t nbytes) {
+ RETURN_NOT_OK(CheckClosed());
+
+ // Since the data isn't owned, we have to make a copy
+ PyObject* py_data =
+ PyBytes_FromStringAndSize(reinterpret_cast<const char*>(data), nbytes);
+ PY_RETURN_IF_ERROR(StatusCode::IOError);
+
+ PyObject* result = cpp_PyObject_CallMethod(file_.obj(), "write", "(O)", py_data);
+ Py_XDECREF(py_data);
+ Py_XDECREF(result);
+ PY_RETURN_IF_ERROR(StatusCode::IOError);
+ return Status::OK();
+ }
+
+ Status Write(const std::shared_ptr<Buffer>& buffer) {
+ RETURN_NOT_OK(CheckClosed());
+
+ PyObject* py_data = wrap_buffer(buffer);
+ PY_RETURN_IF_ERROR(StatusCode::IOError);
+
+ PyObject* result = cpp_PyObject_CallMethod(file_.obj(), "write", "(O)", py_data);
+ Py_XDECREF(py_data);
+ Py_XDECREF(result);
+ PY_RETURN_IF_ERROR(StatusCode::IOError);
+ return Status::OK();
+ }
+
+ Result<int64_t> Tell() {
+ RETURN_NOT_OK(CheckClosed());
+
+ PyObject* result = cpp_PyObject_CallMethod(file_.obj(), "tell", "()");
+ PY_RETURN_IF_ERROR(StatusCode::IOError);
+
+ int64_t position = PyLong_AsLongLong(result);
+ Py_DECREF(result);
+
+ // PyLong_AsLongLong can raise OverflowError
+ PY_RETURN_IF_ERROR(StatusCode::IOError);
+ return position;
+ }
+
+ std::mutex& lock() { return lock_; }
+
+ bool HasReadBuffer() {
+ if (!checked_read_buffer_) { // we don't want to check this each time
+ has_read_buffer_ = PyObject_HasAttrString(file_.obj(), "read_buffer") == 1;
+ checked_read_buffer_ = true;
+ }
+ return has_read_buffer_;
+ }
+
+ private:
+ std::mutex lock_;
+ OwnedRefNoGIL file_;
+ bool has_read_buffer_;
+ bool checked_read_buffer_;
+};
+
+// ----------------------------------------------------------------------
+// Seekable input stream
+
+PyReadableFile::PyReadableFile(PyObject* file) { file_.reset(new PythonFile(file)); }
+
+// The destructor does not close the underlying Python file object, as
+// there may be multiple references to it. Instead let the Python
+// destructor do its job.
+PyReadableFile::~PyReadableFile() {}
+
+Status PyReadableFile::Abort() {
+ return SafeCallIntoPython([this]() { return file_->Abort(); });
+}
+
+Status PyReadableFile::Close() {
+ return SafeCallIntoPython([this]() { return file_->Close(); });
+}
+
+bool PyReadableFile::closed() const {
+ bool res;
+ Status st = SafeCallIntoPython([this, &res]() {
+ res = file_->closed();
+ return Status::OK();
+ });
+ return res;
+}
+
+Status PyReadableFile::Seek(int64_t position) {
+ return SafeCallIntoPython([=] { return file_->Seek(position, 0); });
+}
+
+Result<int64_t> PyReadableFile::Tell() const {
+ return SafeCallIntoPython([=]() -> Result<int64_t> { return file_->Tell(); });
+}
+
+Result<int64_t> PyReadableFile::Read(int64_t nbytes, void* out) {
+ return SafeCallIntoPython([=]() -> Result<int64_t> {
+ OwnedRef bytes;
+ RETURN_NOT_OK(file_->Read(nbytes, bytes.ref()));
+ PyObject* bytes_obj = bytes.obj();
+ DCHECK(bytes_obj != NULL);
+
+ Py_buffer py_buf;
+ if (!PyObject_GetBuffer(bytes_obj, &py_buf, PyBUF_ANY_CONTIGUOUS)) {
+ const uint8_t* data = reinterpret_cast<const uint8_t*>(py_buf.buf);
+ std::memcpy(out, data, py_buf.len);
+ int64_t len = py_buf.len;
+ PyBuffer_Release(&py_buf);
+ return len;
+ } else {
+ return Status::TypeError(
+ "Python file read() should have returned a bytes object or an object "
+ "supporting the buffer protocol, got '",
+ Py_TYPE(bytes_obj)->tp_name, "' (did you open the file in binary mode?)");
+ }
+ });
+}
+
+Result<std::shared_ptr<Buffer>> PyReadableFile::Read(int64_t nbytes) {
+ return SafeCallIntoPython([=]() -> Result<std::shared_ptr<Buffer>> {
+ OwnedRef buffer_obj;
+ if (file_->HasReadBuffer()) {
+ RETURN_NOT_OK(file_->ReadBuffer(nbytes, buffer_obj.ref()));
+ } else {
+ RETURN_NOT_OK(file_->Read(nbytes, buffer_obj.ref()));
+ }
+ DCHECK(buffer_obj.obj() != NULL);
+
+ return PyBuffer::FromPyObject(buffer_obj.obj());
+ });
+}
+
+Result<int64_t> PyReadableFile::ReadAt(int64_t position, int64_t nbytes, void* out) {
+ std::lock_guard<std::mutex> guard(file_->lock());
+ return SafeCallIntoPython([=]() -> Result<int64_t> {
+ RETURN_NOT_OK(Seek(position));
+ return Read(nbytes, out);
+ });
+}
+
+Result<std::shared_ptr<Buffer>> PyReadableFile::ReadAt(int64_t position, int64_t nbytes) {
+ std::lock_guard<std::mutex> guard(file_->lock());
+ return SafeCallIntoPython([=]() -> Result<std::shared_ptr<Buffer>> {
+ RETURN_NOT_OK(Seek(position));
+ return Read(nbytes);
+ });
+}
+
+Result<int64_t> PyReadableFile::GetSize() {
+ return SafeCallIntoPython([=]() -> Result<int64_t> {
+ ARROW_ASSIGN_OR_RAISE(int64_t current_position, file_->Tell());
+ RETURN_NOT_OK(file_->Seek(0, 2));
+
+ ARROW_ASSIGN_OR_RAISE(int64_t file_size, file_->Tell());
+ // Restore previous file position
+ RETURN_NOT_OK(file_->Seek(current_position, 0));
+
+ return file_size;
+ });
+}
+
+// ----------------------------------------------------------------------
+// Output stream
+
+PyOutputStream::PyOutputStream(PyObject* file) : position_(0) {
+ file_.reset(new PythonFile(file));
+}
+
+// The destructor does not close the underlying Python file object, as
+// there may be multiple references to it. Instead let the Python
+// destructor do its job.
+PyOutputStream::~PyOutputStream() {}
+
+Status PyOutputStream::Abort() {
+ return SafeCallIntoPython([=]() { return file_->Abort(); });
+}
+
+Status PyOutputStream::Close() {
+ return SafeCallIntoPython([=]() { return file_->Close(); });
+}
+
+bool PyOutputStream::closed() const {
+ bool res;
+ Status st = SafeCallIntoPython([this, &res]() {
+ res = file_->closed();
+ return Status::OK();
+ });
+ return res;
+}
+
+Result<int64_t> PyOutputStream::Tell() const { return position_; }
+
+Status PyOutputStream::Write(const void* data, int64_t nbytes) {
+ return SafeCallIntoPython([=]() {
+ position_ += nbytes;
+ return file_->Write(data, nbytes);
+ });
+}
+
+Status PyOutputStream::Write(const std::shared_ptr<Buffer>& buffer) {
+ return SafeCallIntoPython([=]() {
+ position_ += buffer->size();
+ return file_->Write(buffer);
+ });
+}
+
+// ----------------------------------------------------------------------
+// Foreign buffer
+
+Status PyForeignBuffer::Make(const uint8_t* data, int64_t size, PyObject* base,
+ std::shared_ptr<Buffer>* out) {
+ PyForeignBuffer* buf = new PyForeignBuffer(data, size, base);
+ if (buf == NULL) {
+ return Status::OutOfMemory("could not allocate foreign buffer object");
+ } else {
+ *out = std::shared_ptr<Buffer>(buf);
+ return Status::OK();
+ }
+}
+
+// ----------------------------------------------------------------------
+// TransformInputStream::TransformFunc wrapper
+
+struct TransformFunctionWrapper {
+ TransformFunctionWrapper(TransformCallback cb, PyObject* arg)
+ : cb_(std::move(cb)), arg_(std::make_shared<OwnedRefNoGIL>(arg)) {
+ Py_INCREF(arg);
+ }
+
+ Result<std::shared_ptr<Buffer>> operator()(const std::shared_ptr<Buffer>& src) {
+ return SafeCallIntoPython([=]() -> Result<std::shared_ptr<Buffer>> {
+ std::shared_ptr<Buffer> dest;
+ cb_(arg_->obj(), src, &dest);
+ RETURN_NOT_OK(CheckPyError());
+ return dest;
+ });
+ }
+
+ protected:
+ // Need to wrap OwnedRefNoGIL because std::function needs the callable
+ // to be copy-constructible...
+ TransformCallback cb_;
+ std::shared_ptr<OwnedRefNoGIL> arg_;
+};
+
+std::shared_ptr<::arrow::io::InputStream> MakeTransformInputStream(
+ std::shared_ptr<::arrow::io::InputStream> wrapped, TransformInputStreamVTable vtable,
+ PyObject* handler) {
+ TransformInputStream::TransformFunc transform(
+ TransformFunctionWrapper{std::move(vtable.transform), handler});
+ return std::make_shared<TransformInputStream>(std::move(wrapped), std::move(transform));
+}
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/io.h b/src/arrow/cpp/src/arrow/python/io.h
new file mode 100644
index 000000000..a38d0ca33
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/io.h
@@ -0,0 +1,116 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/io/interfaces.h"
+#include "arrow/io/transform.h"
+
+#include "arrow/python/common.h"
+#include "arrow/python/visibility.h"
+
+namespace arrow {
+namespace py {
+
+class ARROW_NO_EXPORT PythonFile;
+
+class ARROW_PYTHON_EXPORT PyReadableFile : public io::RandomAccessFile {
+ public:
+ explicit PyReadableFile(PyObject* file);
+ ~PyReadableFile() override;
+
+ Status Close() override;
+ Status Abort() override;
+ bool closed() const override;
+
+ Result<int64_t> Read(int64_t nbytes, void* out) override;
+ Result<std::shared_ptr<Buffer>> Read(int64_t nbytes) override;
+
+ // Thread-safe version
+ Result<int64_t> ReadAt(int64_t position, int64_t nbytes, void* out) override;
+
+ // Thread-safe version
+ Result<std::shared_ptr<Buffer>> ReadAt(int64_t position, int64_t nbytes) override;
+
+ Result<int64_t> GetSize() override;
+
+ Status Seek(int64_t position) override;
+
+ Result<int64_t> Tell() const override;
+
+ private:
+ std::unique_ptr<PythonFile> file_;
+};
+
+class ARROW_PYTHON_EXPORT PyOutputStream : public io::OutputStream {
+ public:
+ explicit PyOutputStream(PyObject* file);
+ ~PyOutputStream() override;
+
+ Status Close() override;
+ Status Abort() override;
+ bool closed() const override;
+ Result<int64_t> Tell() const override;
+ Status Write(const void* data, int64_t nbytes) override;
+ Status Write(const std::shared_ptr<Buffer>& buffer) override;
+
+ private:
+ std::unique_ptr<PythonFile> file_;
+ int64_t position_;
+};
+
+// TODO(wesm): seekable output files
+
+// A Buffer subclass that keeps a PyObject reference throughout its
+// lifetime, such that the Python object is kept alive as long as the
+// C++ buffer is still needed.
+// Keeping the reference in a Python wrapper would be incorrect as
+// the Python wrapper can get destroyed even though the wrapped C++
+// buffer is still alive (ARROW-2270).
+class ARROW_PYTHON_EXPORT PyForeignBuffer : public Buffer {
+ public:
+ static Status Make(const uint8_t* data, int64_t size, PyObject* base,
+ std::shared_ptr<Buffer>* out);
+
+ private:
+ PyForeignBuffer(const uint8_t* data, int64_t size, PyObject* base)
+ : Buffer(data, size) {
+ Py_INCREF(base);
+ base_.reset(base);
+ }
+
+ OwnedRefNoGIL base_;
+};
+
+// All this rigamarole because Cython is really poor with std::function<>
+
+using TransformCallback = std::function<void(
+ PyObject*, const std::shared_ptr<Buffer>& src, std::shared_ptr<Buffer>* out)>;
+
+struct TransformInputStreamVTable {
+ TransformCallback transform;
+};
+
+ARROW_PYTHON_EXPORT
+std::shared_ptr<::arrow::io::InputStream> MakeTransformInputStream(
+ std::shared_ptr<::arrow::io::InputStream> wrapped, TransformInputStreamVTable vtable,
+ PyObject* arg);
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/ipc.cc b/src/arrow/cpp/src/arrow/python/ipc.cc
new file mode 100644
index 000000000..2e6c9d912
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/ipc.cc
@@ -0,0 +1,67 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/python/ipc.h"
+
+#include <memory>
+
+#include "arrow/python/pyarrow.h"
+
+namespace arrow {
+namespace py {
+
+PyRecordBatchReader::PyRecordBatchReader() {}
+
+Status PyRecordBatchReader::Init(std::shared_ptr<Schema> schema, PyObject* iterable) {
+ schema_ = std::move(schema);
+
+ iterator_.reset(PyObject_GetIter(iterable));
+ return CheckPyError();
+}
+
+std::shared_ptr<Schema> PyRecordBatchReader::schema() const { return schema_; }
+
+Status PyRecordBatchReader::ReadNext(std::shared_ptr<RecordBatch>* batch) {
+ PyAcquireGIL lock;
+
+ if (!iterator_) {
+ // End of stream
+ batch->reset();
+ return Status::OK();
+ }
+
+ OwnedRef py_batch(PyIter_Next(iterator_.obj()));
+ if (!py_batch) {
+ RETURN_IF_PYERROR();
+ // End of stream
+ batch->reset();
+ iterator_.reset();
+ return Status::OK();
+ }
+
+ return unwrap_batch(py_batch.obj()).Value(batch);
+}
+
+Result<std::shared_ptr<RecordBatchReader>> PyRecordBatchReader::Make(
+ std::shared_ptr<Schema> schema, PyObject* iterable) {
+ auto reader = std::shared_ptr<PyRecordBatchReader>(new PyRecordBatchReader());
+ RETURN_NOT_OK(reader->Init(std::move(schema), iterable));
+ return reader;
+}
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/ipc.h b/src/arrow/cpp/src/arrow/python/ipc.h
new file mode 100644
index 000000000..92232ed83
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/ipc.h
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/python/common.h"
+#include "arrow/python/visibility.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace py {
+
+class ARROW_PYTHON_EXPORT PyRecordBatchReader : public RecordBatchReader {
+ public:
+ std::shared_ptr<Schema> schema() const override;
+
+ Status ReadNext(std::shared_ptr<RecordBatch>* batch) override;
+
+ // For use from Cython
+ // Assumes that `iterable` is borrowed
+ static Result<std::shared_ptr<RecordBatchReader>> Make(std::shared_ptr<Schema>,
+ PyObject* iterable);
+
+ protected:
+ PyRecordBatchReader();
+
+ Status Init(std::shared_ptr<Schema>, PyObject* iterable);
+
+ std::shared_ptr<Schema> schema_;
+ OwnedRefNoGIL iterator_;
+};
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/iterators.h b/src/arrow/cpp/src/arrow/python/iterators.h
new file mode 100644
index 000000000..7b31962da
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/iterators.h
@@ -0,0 +1,194 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <utility>
+
+#include "arrow/array/array_primitive.h"
+
+#include "arrow/python/common.h"
+#include "arrow/python/numpy_internal.h"
+
+namespace arrow {
+namespace py {
+namespace internal {
+
+using arrow::internal::checked_cast;
+
+// Visit the Python sequence, calling the given callable on each element. If
+// the callable returns a non-OK status, iteration stops and the status is
+// returned.
+//
+// The call signature for Visitor must be
+//
+// Visit(PyObject* obj, int64_t index, bool* keep_going)
+//
+// If keep_going is set to false, the iteration terminates
+template <class VisitorFunc>
+inline Status VisitSequenceGeneric(PyObject* obj, int64_t offset, VisitorFunc&& func) {
+ // VisitorFunc may set to false to terminate iteration
+ bool keep_going = true;
+
+ if (PyArray_Check(obj)) {
+ PyArrayObject* arr_obj = reinterpret_cast<PyArrayObject*>(obj);
+ if (PyArray_NDIM(arr_obj) != 1) {
+ return Status::Invalid("Only 1D arrays accepted");
+ }
+
+ if (PyArray_DESCR(arr_obj)->type_num == NPY_OBJECT) {
+ // It's an array object, we can fetch object pointers directly
+ const Ndarray1DIndexer<PyObject*> objects(arr_obj);
+ for (int64_t i = offset; keep_going && i < objects.size(); ++i) {
+ RETURN_NOT_OK(func(objects[i], i, &keep_going));
+ }
+ return Status::OK();
+ }
+ // It's a non-object array, fall back on regular sequence access.
+ // (note PyArray_GETITEM() is slightly different: it returns standard
+ // Python types, not Numpy scalar types)
+ // This code path is inefficient: callers should implement dedicated
+ // logic for non-object arrays.
+ }
+ if (PySequence_Check(obj)) {
+ if (PyList_Check(obj) || PyTuple_Check(obj)) {
+ // Use fast item access
+ const Py_ssize_t size = PySequence_Fast_GET_SIZE(obj);
+ for (Py_ssize_t i = offset; keep_going && i < size; ++i) {
+ PyObject* value = PySequence_Fast_GET_ITEM(obj, i);
+ RETURN_NOT_OK(func(value, static_cast<int64_t>(i), &keep_going));
+ }
+ } else {
+ // Regular sequence: avoid making a potentially large copy
+ const Py_ssize_t size = PySequence_Size(obj);
+ RETURN_IF_PYERROR();
+ for (Py_ssize_t i = offset; keep_going && i < size; ++i) {
+ OwnedRef value_ref(PySequence_ITEM(obj, i));
+ RETURN_IF_PYERROR();
+ RETURN_NOT_OK(func(value_ref.obj(), static_cast<int64_t>(i), &keep_going));
+ }
+ }
+ } else {
+ return Status::TypeError("Object is not a sequence");
+ }
+ return Status::OK();
+}
+
+// Visit sequence with no null mask
+template <class VisitorFunc>
+inline Status VisitSequence(PyObject* obj, int64_t offset, VisitorFunc&& func) {
+ return VisitSequenceGeneric(
+ obj, offset, [&func](PyObject* value, int64_t i /* unused */, bool* keep_going) {
+ return func(value, keep_going);
+ });
+}
+
+/// Visit sequence with null mask
+template <class VisitorFunc>
+inline Status VisitSequenceMasked(PyObject* obj, PyObject* mo, int64_t offset,
+ VisitorFunc&& func) {
+ if (PyArray_Check(mo)) {
+ PyArrayObject* mask = reinterpret_cast<PyArrayObject*>(mo);
+ if (PyArray_NDIM(mask) != 1) {
+ return Status::Invalid("Mask must be 1D array");
+ }
+ if (PyArray_SIZE(mask) != static_cast<int64_t>(PySequence_Size(obj))) {
+ return Status::Invalid("Mask was a different length from sequence being converted");
+ }
+
+ const int dtype = fix_numpy_type_num(PyArray_DESCR(mask)->type_num);
+ if (dtype == NPY_BOOL) {
+ Ndarray1DIndexer<uint8_t> mask_values(mask);
+
+ return VisitSequenceGeneric(
+ obj, offset,
+ [&func, &mask_values](PyObject* value, int64_t i, bool* keep_going) {
+ return func(value, mask_values[i], keep_going);
+ });
+ } else {
+ return Status::TypeError("Mask must be boolean dtype");
+ }
+ } else if (py::is_array(mo)) {
+ auto unwrap_mask_result = unwrap_array(mo);
+ ARROW_RETURN_NOT_OK(unwrap_mask_result);
+ std::shared_ptr<Array> mask_ = unwrap_mask_result.ValueOrDie();
+ if (mask_->type_id() != Type::type::BOOL) {
+ return Status::TypeError("Mask must be an array of booleans");
+ }
+
+ if (mask_->length() != PySequence_Size(obj)) {
+ return Status::Invalid("Mask was a different length from sequence being converted");
+ }
+
+ if (mask_->null_count() != 0) {
+ return Status::TypeError("Mask must be an array of booleans");
+ }
+
+ BooleanArray* boolmask = checked_cast<BooleanArray*>(mask_.get());
+ return VisitSequenceGeneric(
+ obj, offset, [&func, &boolmask](PyObject* value, int64_t i, bool* keep_going) {
+ return func(value, boolmask->Value(i), keep_going);
+ });
+ } else if (PySequence_Check(mo)) {
+ if (PySequence_Size(mo) != PySequence_Size(obj)) {
+ return Status::Invalid("Mask was a different length from sequence being converted");
+ }
+ RETURN_IF_PYERROR();
+
+ return VisitSequenceGeneric(
+ obj, offset, [&func, &mo](PyObject* value, int64_t i, bool* keep_going) {
+ OwnedRef value_ref(PySequence_ITEM(mo, i));
+ if (!PyBool_Check(value_ref.obj()))
+ return Status::TypeError("Mask must be a sequence of booleans");
+ return func(value, value_ref.obj() == Py_True, keep_going);
+ });
+ } else {
+ return Status::Invalid("Null mask must be a NumPy array, Arrow array or a Sequence");
+ }
+
+ return Status::OK();
+}
+
+// Like IterateSequence, but accepts any generic iterable (including
+// non-restartable iterators, e.g. generators).
+//
+// The call signature for VisitorFunc must be Visit(PyObject*, bool*
+// keep_going). If keep_going is set to false, the iteration terminates
+template <class VisitorFunc>
+inline Status VisitIterable(PyObject* obj, VisitorFunc&& func) {
+ if (PySequence_Check(obj)) {
+ // Numpy arrays fall here as well
+ return VisitSequence(obj, /*offset=*/0, std::forward<VisitorFunc>(func));
+ }
+ // Fall back on the iterator protocol
+ OwnedRef iter_ref(PyObject_GetIter(obj));
+ PyObject* iter = iter_ref.obj();
+ RETURN_IF_PYERROR();
+ PyObject* value;
+
+ bool keep_going = true;
+ while (keep_going && (value = PyIter_Next(iter))) {
+ OwnedRef value_ref(value);
+ RETURN_NOT_OK(func(value_ref.obj(), &keep_going));
+ }
+ RETURN_IF_PYERROR(); // __next__() might have raised
+ return Status::OK();
+}
+
+} // namespace internal
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/numpy_convert.cc b/src/arrow/cpp/src/arrow/python/numpy_convert.cc
new file mode 100644
index 000000000..497068076
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/numpy_convert.cc
@@ -0,0 +1,562 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/python/numpy_interop.h"
+
+#include "arrow/python/numpy_convert.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/sparse_tensor.h"
+#include "arrow/tensor.h"
+#include "arrow/type.h"
+#include "arrow/util/logging.h"
+
+#include "arrow/python/common.h"
+#include "arrow/python/pyarrow.h"
+#include "arrow/python/type_traits.h"
+
+namespace arrow {
+namespace py {
+
+NumPyBuffer::NumPyBuffer(PyObject* ao) : Buffer(nullptr, 0) {
+ PyAcquireGIL lock;
+ arr_ = ao;
+ Py_INCREF(ao);
+
+ if (PyArray_Check(ao)) {
+ PyArrayObject* ndarray = reinterpret_cast<PyArrayObject*>(ao);
+ auto ptr = reinterpret_cast<uint8_t*>(PyArray_DATA(ndarray));
+ data_ = const_cast<const uint8_t*>(ptr);
+ size_ = PyArray_SIZE(ndarray) * PyArray_DESCR(ndarray)->elsize;
+ capacity_ = size_;
+ is_mutable_ = !!(PyArray_FLAGS(ndarray) & NPY_ARRAY_WRITEABLE);
+ }
+}
+
+NumPyBuffer::~NumPyBuffer() {
+ PyAcquireGIL lock;
+ Py_XDECREF(arr_);
+}
+
+#define TO_ARROW_TYPE_CASE(NPY_NAME, FACTORY) \
+ case NPY_##NPY_NAME: \
+ *out = FACTORY(); \
+ break;
+
+namespace {
+
+Status GetTensorType(PyObject* dtype, std::shared_ptr<DataType>* out) {
+ if (!PyObject_TypeCheck(dtype, &PyArrayDescr_Type)) {
+ return Status::TypeError("Did not pass numpy.dtype object");
+ }
+ PyArray_Descr* descr = reinterpret_cast<PyArray_Descr*>(dtype);
+ int type_num = fix_numpy_type_num(descr->type_num);
+
+ switch (type_num) {
+ TO_ARROW_TYPE_CASE(BOOL, uint8);
+ TO_ARROW_TYPE_CASE(INT8, int8);
+ TO_ARROW_TYPE_CASE(INT16, int16);
+ TO_ARROW_TYPE_CASE(INT32, int32);
+ TO_ARROW_TYPE_CASE(INT64, int64);
+ TO_ARROW_TYPE_CASE(UINT8, uint8);
+ TO_ARROW_TYPE_CASE(UINT16, uint16);
+ TO_ARROW_TYPE_CASE(UINT32, uint32);
+ TO_ARROW_TYPE_CASE(UINT64, uint64);
+ TO_ARROW_TYPE_CASE(FLOAT16, float16);
+ TO_ARROW_TYPE_CASE(FLOAT32, float32);
+ TO_ARROW_TYPE_CASE(FLOAT64, float64);
+ default: {
+ return Status::NotImplemented("Unsupported numpy type ", descr->type_num);
+ }
+ }
+ return Status::OK();
+}
+
+Status GetNumPyType(const DataType& type, int* type_num) {
+#define NUMPY_TYPE_CASE(ARROW_NAME, NPY_NAME) \
+ case Type::ARROW_NAME: \
+ *type_num = NPY_##NPY_NAME; \
+ break;
+
+ switch (type.id()) {
+ NUMPY_TYPE_CASE(UINT8, UINT8);
+ NUMPY_TYPE_CASE(INT8, INT8);
+ NUMPY_TYPE_CASE(UINT16, UINT16);
+ NUMPY_TYPE_CASE(INT16, INT16);
+ NUMPY_TYPE_CASE(UINT32, UINT32);
+ NUMPY_TYPE_CASE(INT32, INT32);
+ NUMPY_TYPE_CASE(UINT64, UINT64);
+ NUMPY_TYPE_CASE(INT64, INT64);
+ NUMPY_TYPE_CASE(HALF_FLOAT, FLOAT16);
+ NUMPY_TYPE_CASE(FLOAT, FLOAT32);
+ NUMPY_TYPE_CASE(DOUBLE, FLOAT64);
+ default: {
+ return Status::NotImplemented("Unsupported tensor type: ", type.ToString());
+ }
+ }
+#undef NUMPY_TYPE_CASE
+
+ return Status::OK();
+}
+
+} // namespace
+
+Status NumPyDtypeToArrow(PyObject* dtype, std::shared_ptr<DataType>* out) {
+ if (!PyObject_TypeCheck(dtype, &PyArrayDescr_Type)) {
+ return Status::TypeError("Did not pass numpy.dtype object");
+ }
+ PyArray_Descr* descr = reinterpret_cast<PyArray_Descr*>(dtype);
+ return NumPyDtypeToArrow(descr, out);
+}
+
+Status NumPyDtypeToArrow(PyArray_Descr* descr, std::shared_ptr<DataType>* out) {
+ int type_num = fix_numpy_type_num(descr->type_num);
+
+ switch (type_num) {
+ TO_ARROW_TYPE_CASE(BOOL, boolean);
+ TO_ARROW_TYPE_CASE(INT8, int8);
+ TO_ARROW_TYPE_CASE(INT16, int16);
+ TO_ARROW_TYPE_CASE(INT32, int32);
+ TO_ARROW_TYPE_CASE(INT64, int64);
+ TO_ARROW_TYPE_CASE(UINT8, uint8);
+ TO_ARROW_TYPE_CASE(UINT16, uint16);
+ TO_ARROW_TYPE_CASE(UINT32, uint32);
+ TO_ARROW_TYPE_CASE(UINT64, uint64);
+ TO_ARROW_TYPE_CASE(FLOAT16, float16);
+ TO_ARROW_TYPE_CASE(FLOAT32, float32);
+ TO_ARROW_TYPE_CASE(FLOAT64, float64);
+ TO_ARROW_TYPE_CASE(STRING, binary);
+ TO_ARROW_TYPE_CASE(UNICODE, utf8);
+ case NPY_DATETIME: {
+ auto date_dtype =
+ reinterpret_cast<PyArray_DatetimeDTypeMetaData*>(descr->c_metadata);
+ switch (date_dtype->meta.base) {
+ case NPY_FR_s:
+ *out = timestamp(TimeUnit::SECOND);
+ break;
+ case NPY_FR_ms:
+ *out = timestamp(TimeUnit::MILLI);
+ break;
+ case NPY_FR_us:
+ *out = timestamp(TimeUnit::MICRO);
+ break;
+ case NPY_FR_ns:
+ *out = timestamp(TimeUnit::NANO);
+ break;
+ case NPY_FR_D:
+ *out = date32();
+ break;
+ case NPY_FR_GENERIC:
+ return Status::NotImplemented("Unbound or generic datetime64 time unit");
+ default:
+ return Status::NotImplemented("Unsupported datetime64 time unit");
+ }
+ } break;
+ case NPY_TIMEDELTA: {
+ auto timedelta_dtype =
+ reinterpret_cast<PyArray_DatetimeDTypeMetaData*>(descr->c_metadata);
+ switch (timedelta_dtype->meta.base) {
+ case NPY_FR_s:
+ *out = duration(TimeUnit::SECOND);
+ break;
+ case NPY_FR_ms:
+ *out = duration(TimeUnit::MILLI);
+ break;
+ case NPY_FR_us:
+ *out = duration(TimeUnit::MICRO);
+ break;
+ case NPY_FR_ns:
+ *out = duration(TimeUnit::NANO);
+ break;
+ case NPY_FR_GENERIC:
+ return Status::NotImplemented("Unbound or generic timedelta64 time unit");
+ default:
+ return Status::NotImplemented("Unsupported timedelta64 time unit");
+ }
+ } break;
+ default: {
+ return Status::NotImplemented("Unsupported numpy type ", descr->type_num);
+ }
+ }
+
+ return Status::OK();
+}
+
+#undef TO_ARROW_TYPE_CASE
+
+Status NdarrayToTensor(MemoryPool* pool, PyObject* ao,
+ const std::vector<std::string>& dim_names,
+ std::shared_ptr<Tensor>* out) {
+ if (!PyArray_Check(ao)) {
+ return Status::TypeError("Did not pass ndarray object");
+ }
+
+ PyArrayObject* ndarray = reinterpret_cast<PyArrayObject*>(ao);
+
+ // TODO(wesm): What do we want to do with non-contiguous memory and negative strides?
+
+ int ndim = PyArray_NDIM(ndarray);
+
+ std::shared_ptr<Buffer> data = std::make_shared<NumPyBuffer>(ao);
+ std::vector<int64_t> shape(ndim);
+ std::vector<int64_t> strides(ndim);
+
+ npy_intp* array_strides = PyArray_STRIDES(ndarray);
+ npy_intp* array_shape = PyArray_SHAPE(ndarray);
+ for (int i = 0; i < ndim; ++i) {
+ if (array_strides[i] < 0) {
+ return Status::Invalid("Negative ndarray strides not supported");
+ }
+ shape[i] = array_shape[i];
+ strides[i] = array_strides[i];
+ }
+
+ std::shared_ptr<DataType> type;
+ RETURN_NOT_OK(
+ GetTensorType(reinterpret_cast<PyObject*>(PyArray_DESCR(ndarray)), &type));
+ *out = std::make_shared<Tensor>(type, data, shape, strides, dim_names);
+ return Status::OK();
+}
+
+Status TensorToNdarray(const std::shared_ptr<Tensor>& tensor, PyObject* base,
+ PyObject** out) {
+ int type_num = 0;
+ RETURN_NOT_OK(GetNumPyType(*tensor->type(), &type_num));
+ PyArray_Descr* dtype = PyArray_DescrNewFromType(type_num);
+ RETURN_IF_PYERROR();
+
+ const int ndim = tensor->ndim();
+ std::vector<npy_intp> npy_shape(ndim);
+ std::vector<npy_intp> npy_strides(ndim);
+
+ for (int i = 0; i < ndim; ++i) {
+ npy_shape[i] = tensor->shape()[i];
+ npy_strides[i] = tensor->strides()[i];
+ }
+
+ const void* immutable_data = nullptr;
+ if (tensor->data()) {
+ immutable_data = tensor->data()->data();
+ }
+
+ // Remove const =(
+ void* mutable_data = const_cast<void*>(immutable_data);
+
+ int array_flags = 0;
+ if (tensor->is_row_major()) {
+ array_flags |= NPY_ARRAY_C_CONTIGUOUS;
+ }
+ if (tensor->is_column_major()) {
+ array_flags |= NPY_ARRAY_F_CONTIGUOUS;
+ }
+ if (tensor->is_mutable()) {
+ array_flags |= NPY_ARRAY_WRITEABLE;
+ }
+
+ PyObject* result =
+ PyArray_NewFromDescr(&PyArray_Type, dtype, ndim, npy_shape.data(),
+ npy_strides.data(), mutable_data, array_flags, nullptr);
+ RETURN_IF_PYERROR();
+
+ if (base == Py_None || base == nullptr) {
+ base = py::wrap_tensor(tensor);
+ } else {
+ Py_XINCREF(base);
+ }
+ PyArray_SetBaseObject(reinterpret_cast<PyArrayObject*>(result), base);
+ *out = result;
+ return Status::OK();
+}
+
+// Wrap the dense data of a sparse tensor in a ndarray
+static Status SparseTensorDataToNdarray(const SparseTensor& sparse_tensor,
+ std::vector<npy_intp> data_shape, PyObject* base,
+ PyObject** out_data) {
+ int type_num_data = 0;
+ RETURN_NOT_OK(GetNumPyType(*sparse_tensor.type(), &type_num_data));
+ PyArray_Descr* dtype_data = PyArray_DescrNewFromType(type_num_data);
+ RETURN_IF_PYERROR();
+
+ const void* immutable_data = sparse_tensor.data()->data();
+ // Remove const =(
+ void* mutable_data = const_cast<void*>(immutable_data);
+ int array_flags = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS;
+ if (sparse_tensor.is_mutable()) {
+ array_flags |= NPY_ARRAY_WRITEABLE;
+ }
+
+ *out_data = PyArray_NewFromDescr(&PyArray_Type, dtype_data,
+ static_cast<int>(data_shape.size()), data_shape.data(),
+ nullptr, mutable_data, array_flags, nullptr);
+ RETURN_IF_PYERROR();
+ Py_XINCREF(base);
+ PyArray_SetBaseObject(reinterpret_cast<PyArrayObject*>(*out_data), base);
+ return Status::OK();
+}
+
+Status SparseCOOTensorToNdarray(const std::shared_ptr<SparseCOOTensor>& sparse_tensor,
+ PyObject* base, PyObject** out_data,
+ PyObject** out_coords) {
+ const auto& sparse_index = arrow::internal::checked_cast<const SparseCOOIndex&>(
+ *sparse_tensor->sparse_index());
+
+ // Wrap tensor data
+ OwnedRef result_data;
+ RETURN_NOT_OK(SparseTensorDataToNdarray(
+ *sparse_tensor, {static_cast<npy_intp>(sparse_tensor->non_zero_length()), 1}, base,
+ result_data.ref()));
+
+ // Wrap indices
+ PyObject* result_coords;
+ RETURN_NOT_OK(TensorToNdarray(sparse_index.indices(), base, &result_coords));
+
+ *out_data = result_data.detach();
+ *out_coords = result_coords;
+ return Status::OK();
+}
+
+Status SparseCSXMatrixToNdarray(const std::shared_ptr<SparseTensor>& sparse_tensor,
+ PyObject* base, PyObject** out_data,
+ PyObject** out_indptr, PyObject** out_indices) {
+ // Wrap indices
+ OwnedRef result_indptr;
+ OwnedRef result_indices;
+
+ switch (sparse_tensor->format_id()) {
+ case SparseTensorFormat::CSR: {
+ const auto& sparse_index = arrow::internal::checked_cast<const SparseCSRIndex&>(
+ *sparse_tensor->sparse_index());
+ RETURN_NOT_OK(TensorToNdarray(sparse_index.indptr(), base, result_indptr.ref()));
+ RETURN_NOT_OK(TensorToNdarray(sparse_index.indices(), base, result_indices.ref()));
+ break;
+ }
+ case SparseTensorFormat::CSC: {
+ const auto& sparse_index = arrow::internal::checked_cast<const SparseCSCIndex&>(
+ *sparse_tensor->sparse_index());
+ RETURN_NOT_OK(TensorToNdarray(sparse_index.indptr(), base, result_indptr.ref()));
+ RETURN_NOT_OK(TensorToNdarray(sparse_index.indices(), base, result_indices.ref()));
+ break;
+ }
+ default:
+ return Status::NotImplemented("Invalid SparseTensor type.");
+ }
+
+ // Wrap tensor data
+ OwnedRef result_data;
+ RETURN_NOT_OK(SparseTensorDataToNdarray(
+ *sparse_tensor, {static_cast<npy_intp>(sparse_tensor->non_zero_length()), 1}, base,
+ result_data.ref()));
+
+ *out_data = result_data.detach();
+ *out_indptr = result_indptr.detach();
+ *out_indices = result_indices.detach();
+ return Status::OK();
+}
+
+Status SparseCSRMatrixToNdarray(const std::shared_ptr<SparseCSRMatrix>& sparse_tensor,
+ PyObject* base, PyObject** out_data,
+ PyObject** out_indptr, PyObject** out_indices) {
+ return SparseCSXMatrixToNdarray(sparse_tensor, base, out_data, out_indptr, out_indices);
+}
+
+Status SparseCSCMatrixToNdarray(const std::shared_ptr<SparseCSCMatrix>& sparse_tensor,
+ PyObject* base, PyObject** out_data,
+ PyObject** out_indptr, PyObject** out_indices) {
+ return SparseCSXMatrixToNdarray(sparse_tensor, base, out_data, out_indptr, out_indices);
+}
+
+Status SparseCSFTensorToNdarray(const std::shared_ptr<SparseCSFTensor>& sparse_tensor,
+ PyObject* base, PyObject** out_data,
+ PyObject** out_indptr, PyObject** out_indices) {
+ const auto& sparse_index = arrow::internal::checked_cast<const SparseCSFIndex&>(
+ *sparse_tensor->sparse_index());
+
+ // Wrap tensor data
+ OwnedRef result_data;
+ RETURN_NOT_OK(SparseTensorDataToNdarray(
+ *sparse_tensor, {static_cast<npy_intp>(sparse_tensor->non_zero_length()), 1}, base,
+ result_data.ref()));
+
+ // Wrap indices
+ int ndim = static_cast<int>(sparse_index.indices().size());
+ OwnedRef indptr(PyList_New(ndim - 1));
+ OwnedRef indices(PyList_New(ndim));
+ RETURN_IF_PYERROR();
+
+ for (int i = 0; i < ndim - 1; ++i) {
+ PyObject* item;
+ RETURN_NOT_OK(TensorToNdarray(sparse_index.indptr()[i], base, &item));
+ if (PyList_SetItem(indptr.obj(), i, item) < 0) {
+ Py_XDECREF(item);
+ RETURN_IF_PYERROR();
+ }
+ }
+ for (int i = 0; i < ndim; ++i) {
+ PyObject* item;
+ RETURN_NOT_OK(TensorToNdarray(sparse_index.indices()[i], base, &item));
+ if (PyList_SetItem(indices.obj(), i, item) < 0) {
+ Py_XDECREF(item);
+ RETURN_IF_PYERROR();
+ }
+ }
+
+ *out_indptr = indptr.detach();
+ *out_indices = indices.detach();
+ *out_data = result_data.detach();
+ return Status::OK();
+}
+
+Status NdarraysToSparseCOOTensor(MemoryPool* pool, PyObject* data_ao, PyObject* coords_ao,
+ const std::vector<int64_t>& shape,
+ const std::vector<std::string>& dim_names,
+ std::shared_ptr<SparseCOOTensor>* out) {
+ if (!PyArray_Check(data_ao) || !PyArray_Check(coords_ao)) {
+ return Status::TypeError("Did not pass ndarray object");
+ }
+
+ PyArrayObject* ndarray_data = reinterpret_cast<PyArrayObject*>(data_ao);
+ std::shared_ptr<Buffer> data = std::make_shared<NumPyBuffer>(data_ao);
+ std::shared_ptr<DataType> type_data;
+ RETURN_NOT_OK(GetTensorType(reinterpret_cast<PyObject*>(PyArray_DESCR(ndarray_data)),
+ &type_data));
+
+ std::shared_ptr<Tensor> coords;
+ RETURN_NOT_OK(NdarrayToTensor(pool, coords_ao, {}, &coords));
+ ARROW_CHECK_EQ(coords->type_id(), Type::INT64); // Should be ensured by caller
+
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<SparseCOOIndex> sparse_index,
+ SparseCOOIndex::Make(coords));
+ *out = std::make_shared<SparseTensorImpl<SparseCOOIndex>>(sparse_index, type_data, data,
+ shape, dim_names);
+ return Status::OK();
+}
+
+template <class IndexType>
+Status NdarraysToSparseCSXMatrix(MemoryPool* pool, PyObject* data_ao, PyObject* indptr_ao,
+ PyObject* indices_ao, const std::vector<int64_t>& shape,
+ const std::vector<std::string>& dim_names,
+ std::shared_ptr<SparseTensorImpl<IndexType>>* out) {
+ if (!PyArray_Check(data_ao) || !PyArray_Check(indptr_ao) ||
+ !PyArray_Check(indices_ao)) {
+ return Status::TypeError("Did not pass ndarray object");
+ }
+
+ PyArrayObject* ndarray_data = reinterpret_cast<PyArrayObject*>(data_ao);
+ std::shared_ptr<Buffer> data = std::make_shared<NumPyBuffer>(data_ao);
+ std::shared_ptr<DataType> type_data;
+ RETURN_NOT_OK(GetTensorType(reinterpret_cast<PyObject*>(PyArray_DESCR(ndarray_data)),
+ &type_data));
+
+ std::shared_ptr<Tensor> indptr, indices;
+ RETURN_NOT_OK(NdarrayToTensor(pool, indptr_ao, {}, &indptr));
+ RETURN_NOT_OK(NdarrayToTensor(pool, indices_ao, {}, &indices));
+ ARROW_CHECK_EQ(indptr->type_id(), Type::INT64); // Should be ensured by caller
+ ARROW_CHECK_EQ(indices->type_id(), Type::INT64); // Should be ensured by caller
+
+ auto sparse_index = std::make_shared<IndexType>(
+ std::static_pointer_cast<NumericTensor<Int64Type>>(indptr),
+ std::static_pointer_cast<NumericTensor<Int64Type>>(indices));
+ *out = std::make_shared<SparseTensorImpl<IndexType>>(sparse_index, type_data, data,
+ shape, dim_names);
+ return Status::OK();
+}
+
+Status NdarraysToSparseCSFTensor(MemoryPool* pool, PyObject* data_ao, PyObject* indptr_ao,
+ PyObject* indices_ao, const std::vector<int64_t>& shape,
+ const std::vector<int64_t>& axis_order,
+ const std::vector<std::string>& dim_names,
+ std::shared_ptr<SparseCSFTensor>* out) {
+ if (!PyArray_Check(data_ao)) {
+ return Status::TypeError("Did not pass ndarray object for data");
+ }
+ const int ndim = static_cast<const int>(shape.size());
+ PyArrayObject* ndarray_data = reinterpret_cast<PyArrayObject*>(data_ao);
+ std::shared_ptr<Buffer> data = std::make_shared<NumPyBuffer>(data_ao);
+ std::shared_ptr<DataType> type_data;
+ RETURN_NOT_OK(GetTensorType(reinterpret_cast<PyObject*>(PyArray_DESCR(ndarray_data)),
+ &type_data));
+
+ std::vector<std::shared_ptr<Tensor>> indptr(ndim - 1);
+ std::vector<std::shared_ptr<Tensor>> indices(ndim);
+
+ for (int i = 0; i < ndim - 1; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(indptr_ao, i);
+ if (!PyArray_Check(item)) {
+ return Status::TypeError("Did not pass ndarray object for indptr");
+ }
+ RETURN_NOT_OK(NdarrayToTensor(pool, item, {}, &indptr[i]));
+ ARROW_CHECK_EQ(indptr[i]->type_id(), Type::INT64); // Should be ensured by caller
+ }
+
+ for (int i = 0; i < ndim; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(indices_ao, i);
+ if (!PyArray_Check(item)) {
+ return Status::TypeError("Did not pass ndarray object for indices");
+ }
+ RETURN_NOT_OK(NdarrayToTensor(pool, item, {}, &indices[i]));
+ ARROW_CHECK_EQ(indices[i]->type_id(), Type::INT64); // Should be ensured by caller
+ }
+
+ auto sparse_index = std::make_shared<SparseCSFIndex>(indptr, indices, axis_order);
+ *out = std::make_shared<SparseTensorImpl<SparseCSFIndex>>(sparse_index, type_data, data,
+ shape, dim_names);
+ return Status::OK();
+}
+
+Status NdarraysToSparseCSRMatrix(MemoryPool* pool, PyObject* data_ao, PyObject* indptr_ao,
+ PyObject* indices_ao, const std::vector<int64_t>& shape,
+ const std::vector<std::string>& dim_names,
+ std::shared_ptr<SparseCSRMatrix>* out) {
+ return NdarraysToSparseCSXMatrix<SparseCSRIndex>(pool, data_ao, indptr_ao, indices_ao,
+ shape, dim_names, out);
+}
+
+Status NdarraysToSparseCSCMatrix(MemoryPool* pool, PyObject* data_ao, PyObject* indptr_ao,
+ PyObject* indices_ao, const std::vector<int64_t>& shape,
+ const std::vector<std::string>& dim_names,
+ std::shared_ptr<SparseCSCMatrix>* out) {
+ return NdarraysToSparseCSXMatrix<SparseCSCIndex>(pool, data_ao, indptr_ao, indices_ao,
+ shape, dim_names, out);
+}
+
+Status TensorToSparseCOOTensor(const std::shared_ptr<Tensor>& tensor,
+ std::shared_ptr<SparseCOOTensor>* out) {
+ return SparseCOOTensor::Make(*tensor).Value(out);
+}
+
+Status TensorToSparseCSRMatrix(const std::shared_ptr<Tensor>& tensor,
+ std::shared_ptr<SparseCSRMatrix>* out) {
+ return SparseCSRMatrix::Make(*tensor).Value(out);
+}
+
+Status TensorToSparseCSCMatrix(const std::shared_ptr<Tensor>& tensor,
+ std::shared_ptr<SparseCSCMatrix>* out) {
+ return SparseCSCMatrix::Make(*tensor).Value(out);
+}
+
+Status TensorToSparseCSFTensor(const std::shared_ptr<Tensor>& tensor,
+ std::shared_ptr<SparseCSFTensor>* out) {
+ return SparseCSFTensor::Make(*tensor).Value(out);
+}
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/numpy_convert.h b/src/arrow/cpp/src/arrow/python/numpy_convert.h
new file mode 100644
index 000000000..10451077a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/numpy_convert.h
@@ -0,0 +1,120 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Functions for converting between pandas's NumPy-based data representation
+// and Arrow data structures
+
+#pragma once
+
+#include "arrow/python/platform.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/python/visibility.h"
+#include "arrow/sparse_tensor.h"
+
+namespace arrow {
+
+class DataType;
+class MemoryPool;
+class Status;
+class Tensor;
+
+namespace py {
+
+class ARROW_PYTHON_EXPORT NumPyBuffer : public Buffer {
+ public:
+ explicit NumPyBuffer(PyObject* arr);
+ virtual ~NumPyBuffer();
+
+ private:
+ PyObject* arr_;
+};
+
+ARROW_PYTHON_EXPORT
+Status NumPyDtypeToArrow(PyObject* dtype, std::shared_ptr<DataType>* out);
+ARROW_PYTHON_EXPORT
+Status NumPyDtypeToArrow(PyArray_Descr* descr, std::shared_ptr<DataType>* out);
+
+ARROW_PYTHON_EXPORT Status NdarrayToTensor(MemoryPool* pool, PyObject* ao,
+ const std::vector<std::string>& dim_names,
+ std::shared_ptr<Tensor>* out);
+
+ARROW_PYTHON_EXPORT Status TensorToNdarray(const std::shared_ptr<Tensor>& tensor,
+ PyObject* base, PyObject** out);
+
+ARROW_PYTHON_EXPORT Status
+SparseCOOTensorToNdarray(const std::shared_ptr<SparseCOOTensor>& sparse_tensor,
+ PyObject* base, PyObject** out_data, PyObject** out_coords);
+
+Status SparseCSXMatrixToNdarray(const std::shared_ptr<SparseTensor>& sparse_tensor,
+ PyObject* base, PyObject** out_data,
+ PyObject** out_indptr, PyObject** out_indices);
+
+ARROW_PYTHON_EXPORT Status SparseCSRMatrixToNdarray(
+ const std::shared_ptr<SparseCSRMatrix>& sparse_tensor, PyObject* base,
+ PyObject** out_data, PyObject** out_indptr, PyObject** out_indices);
+
+ARROW_PYTHON_EXPORT Status SparseCSCMatrixToNdarray(
+ const std::shared_ptr<SparseCSCMatrix>& sparse_tensor, PyObject* base,
+ PyObject** out_data, PyObject** out_indptr, PyObject** out_indices);
+
+ARROW_PYTHON_EXPORT Status SparseCSFTensorToNdarray(
+ const std::shared_ptr<SparseCSFTensor>& sparse_tensor, PyObject* base,
+ PyObject** out_data, PyObject** out_indptr, PyObject** out_indices);
+
+ARROW_PYTHON_EXPORT Status NdarraysToSparseCOOTensor(
+ MemoryPool* pool, PyObject* data_ao, PyObject* coords_ao,
+ const std::vector<int64_t>& shape, const std::vector<std::string>& dim_names,
+ std::shared_ptr<SparseCOOTensor>* out);
+
+ARROW_PYTHON_EXPORT Status NdarraysToSparseCSRMatrix(
+ MemoryPool* pool, PyObject* data_ao, PyObject* indptr_ao, PyObject* indices_ao,
+ const std::vector<int64_t>& shape, const std::vector<std::string>& dim_names,
+ std::shared_ptr<SparseCSRMatrix>* out);
+
+ARROW_PYTHON_EXPORT Status NdarraysToSparseCSCMatrix(
+ MemoryPool* pool, PyObject* data_ao, PyObject* indptr_ao, PyObject* indices_ao,
+ const std::vector<int64_t>& shape, const std::vector<std::string>& dim_names,
+ std::shared_ptr<SparseCSCMatrix>* out);
+
+ARROW_PYTHON_EXPORT Status NdarraysToSparseCSFTensor(
+ MemoryPool* pool, PyObject* data_ao, PyObject* indptr_ao, PyObject* indices_ao,
+ const std::vector<int64_t>& shape, const std::vector<int64_t>& axis_order,
+ const std::vector<std::string>& dim_names, std::shared_ptr<SparseCSFTensor>* out);
+
+ARROW_PYTHON_EXPORT Status
+TensorToSparseCOOTensor(const std::shared_ptr<Tensor>& tensor,
+ std::shared_ptr<SparseCOOTensor>* csparse_tensor);
+
+ARROW_PYTHON_EXPORT Status
+TensorToSparseCSRMatrix(const std::shared_ptr<Tensor>& tensor,
+ std::shared_ptr<SparseCSRMatrix>* csparse_tensor);
+
+ARROW_PYTHON_EXPORT Status
+TensorToSparseCSCMatrix(const std::shared_ptr<Tensor>& tensor,
+ std::shared_ptr<SparseCSCMatrix>* csparse_tensor);
+
+ARROW_PYTHON_EXPORT Status
+TensorToSparseCSFTensor(const std::shared_ptr<Tensor>& tensor,
+ std::shared_ptr<SparseCSFTensor>* csparse_tensor);
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/numpy_internal.h b/src/arrow/cpp/src/arrow/python/numpy_internal.h
new file mode 100644
index 000000000..973f577cb
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/numpy_internal.h
@@ -0,0 +1,182 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Internal utilities for dealing with NumPy
+
+#pragma once
+
+#include "arrow/python/numpy_interop.h"
+
+#include "arrow/status.h"
+
+#include "arrow/python/platform.h"
+
+#include <cstdint>
+#include <sstream>
+#include <string>
+
+namespace arrow {
+namespace py {
+
+/// Indexing convenience for interacting with strided 1-dim ndarray objects
+template <typename T>
+class Ndarray1DIndexer {
+ public:
+ typedef int64_t size_type;
+
+ Ndarray1DIndexer() : arr_(NULLPTR), data_(NULLPTR) {}
+
+ explicit Ndarray1DIndexer(PyArrayObject* arr) : Ndarray1DIndexer() {
+ arr_ = arr;
+ DCHECK_EQ(1, PyArray_NDIM(arr)) << "Only works with 1-dimensional arrays";
+ Py_INCREF(arr);
+ data_ = reinterpret_cast<uint8_t*>(PyArray_DATA(arr));
+ stride_ = PyArray_STRIDES(arr)[0];
+ }
+
+ ~Ndarray1DIndexer() { Py_XDECREF(arr_); }
+
+ int64_t size() const { return PyArray_SIZE(arr_); }
+
+ const T* data() const { return reinterpret_cast<const T*>(data_); }
+
+ bool is_strided() const { return stride_ != sizeof(T); }
+
+ T& operator[](size_type index) {
+ return *reinterpret_cast<T*>(data_ + index * stride_);
+ }
+ const T& operator[](size_type index) const {
+ return *reinterpret_cast<const T*>(data_ + index * stride_);
+ }
+
+ private:
+ PyArrayObject* arr_;
+ uint8_t* data_;
+ int64_t stride_;
+};
+
+// Handling of Numpy Types by their static numbers
+// (the NPY_TYPES enum and related defines)
+
+static inline std::string GetNumPyTypeName(int npy_type) {
+#define TYPE_CASE(TYPE, NAME) \
+ case NPY_##TYPE: \
+ return NAME;
+
+ switch (npy_type) {
+ TYPE_CASE(BOOL, "bool")
+ TYPE_CASE(INT8, "int8")
+ TYPE_CASE(INT16, "int16")
+ TYPE_CASE(INT32, "int32")
+ TYPE_CASE(INT64, "int64")
+#if !NPY_INT32_IS_INT
+ TYPE_CASE(INT, "intc")
+#endif
+#if !NPY_INT64_IS_LONG_LONG
+ TYPE_CASE(LONGLONG, "longlong")
+#endif
+ TYPE_CASE(UINT8, "uint8")
+ TYPE_CASE(UINT16, "uint16")
+ TYPE_CASE(UINT32, "uint32")
+ TYPE_CASE(UINT64, "uint64")
+#if !NPY_INT32_IS_INT
+ TYPE_CASE(UINT, "uintc")
+#endif
+#if !NPY_INT64_IS_LONG_LONG
+ TYPE_CASE(ULONGLONG, "ulonglong")
+#endif
+ TYPE_CASE(FLOAT16, "float16")
+ TYPE_CASE(FLOAT32, "float32")
+ TYPE_CASE(FLOAT64, "float64")
+ TYPE_CASE(DATETIME, "datetime64")
+ TYPE_CASE(TIMEDELTA, "timedelta64")
+ TYPE_CASE(OBJECT, "object")
+ TYPE_CASE(VOID, "void")
+ default:
+ break;
+ }
+
+#undef TYPE_CASE
+ std::stringstream ss;
+ ss << "unrecognized type (" << npy_type << ") in GetNumPyTypeName";
+ return ss.str();
+}
+
+#define TYPE_VISIT_INLINE(TYPE) \
+ case NPY_##TYPE: \
+ return visitor->template Visit<NPY_##TYPE>(arr);
+
+template <typename VISITOR>
+inline Status VisitNumpyArrayInline(PyArrayObject* arr, VISITOR* visitor) {
+ switch (PyArray_TYPE(arr)) {
+ TYPE_VISIT_INLINE(BOOL);
+ TYPE_VISIT_INLINE(INT8);
+ TYPE_VISIT_INLINE(UINT8);
+ TYPE_VISIT_INLINE(INT16);
+ TYPE_VISIT_INLINE(UINT16);
+ TYPE_VISIT_INLINE(INT32);
+ TYPE_VISIT_INLINE(UINT32);
+ TYPE_VISIT_INLINE(INT64);
+ TYPE_VISIT_INLINE(UINT64);
+#if !NPY_INT32_IS_INT
+ TYPE_VISIT_INLINE(INT);
+ TYPE_VISIT_INLINE(UINT);
+#endif
+#if !NPY_INT64_IS_LONG_LONG
+ TYPE_VISIT_INLINE(LONGLONG);
+ TYPE_VISIT_INLINE(ULONGLONG);
+#endif
+ TYPE_VISIT_INLINE(FLOAT16);
+ TYPE_VISIT_INLINE(FLOAT32);
+ TYPE_VISIT_INLINE(FLOAT64);
+ TYPE_VISIT_INLINE(DATETIME);
+ TYPE_VISIT_INLINE(TIMEDELTA);
+ TYPE_VISIT_INLINE(OBJECT);
+ }
+ return Status::NotImplemented("NumPy type not implemented: ",
+ GetNumPyTypeName(PyArray_TYPE(arr)));
+}
+
+#undef TYPE_VISIT_INLINE
+
+namespace internal {
+
+inline bool PyFloatScalar_Check(PyObject* obj) {
+ return PyFloat_Check(obj) || PyArray_IsScalar(obj, Floating);
+}
+
+inline bool PyIntScalar_Check(PyObject* obj) {
+ return PyLong_Check(obj) || PyArray_IsScalar(obj, Integer);
+}
+
+inline bool PyBoolScalar_Check(PyObject* obj) {
+ return PyBool_Check(obj) || PyArray_IsScalar(obj, Bool);
+}
+
+static inline PyArray_Descr* GetSafeNumPyDtype(int type) {
+ if (type == NPY_DATETIME) {
+ // It is not safe to mutate the result of DescrFromType
+ return PyArray_DescrNewFromType(type);
+ } else {
+ return PyArray_DescrFromType(type);
+ }
+}
+
+} // namespace internal
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/numpy_interop.h b/src/arrow/cpp/src/arrow/python/numpy_interop.h
new file mode 100644
index 000000000..ce7baed25
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/numpy_interop.h
@@ -0,0 +1,96 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/python/platform.h" // IWYU pragma: export
+
+#include <numpy/numpyconfig.h> // IWYU pragma: export
+
+// Don't use the deprecated Numpy functions
+#ifdef NPY_1_7_API_VERSION
+#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
+#else
+#define NPY_ARRAY_NOTSWAPPED NPY_NOTSWAPPED
+#define NPY_ARRAY_ALIGNED NPY_ALIGNED
+#define NPY_ARRAY_WRITEABLE NPY_WRITEABLE
+#define NPY_ARRAY_UPDATEIFCOPY NPY_UPDATEIFCOPY
+#endif
+
+// This is required to be able to access the NumPy C API properly in C++ files
+// other than init.cc.
+#define PY_ARRAY_UNIQUE_SYMBOL arrow_ARRAY_API
+#ifndef NUMPY_IMPORT_ARRAY
+#define NO_IMPORT_ARRAY
+#endif
+
+#include <numpy/arrayobject.h> // IWYU pragma: export
+#include <numpy/arrayscalars.h> // IWYU pragma: export
+#include <numpy/ufuncobject.h> // IWYU pragma: export
+
+// A bit subtle. Numpy has 5 canonical integer types:
+// (or, rather, type pairs: signed and unsigned)
+// NPY_BYTE, NPY_SHORT, NPY_INT, NPY_LONG, NPY_LONGLONG
+// It also has 4 fixed-width integer aliases.
+// When mapping Arrow integer types to these 4 fixed-width aliases,
+// we always miss one of the canonical types (even though it may
+// have the same width as one of the aliases).
+// Which one depends on the platform...
+// On a LP64 system, NPY_INT64 maps to NPY_LONG and
+// NPY_LONGLONG needs to be handled separately.
+// On a LLP64 system, NPY_INT32 maps to NPY_LONG and
+// NPY_INT needs to be handled separately.
+
+#if NPY_BITSOF_LONG == 32 && NPY_BITSOF_LONGLONG == 64
+#define NPY_INT64_IS_LONG_LONG 1
+#else
+#define NPY_INT64_IS_LONG_LONG 0
+#endif
+
+#if NPY_BITSOF_INT == 32 && NPY_BITSOF_LONG == 64
+#define NPY_INT32_IS_INT 1
+#else
+#define NPY_INT32_IS_INT 0
+#endif
+
+namespace arrow {
+namespace py {
+
+inline int import_numpy() {
+#ifdef NUMPY_IMPORT_ARRAY
+ import_array1(-1);
+ import_umath1(-1);
+#endif
+
+ return 0;
+}
+
+// See above about the missing Numpy integer type numbers
+inline int fix_numpy_type_num(int type_num) {
+#if !NPY_INT32_IS_INT && NPY_BITSOF_INT == 32
+ if (type_num == NPY_INT) return NPY_INT32;
+ if (type_num == NPY_UINT) return NPY_UINT32;
+#endif
+#if !NPY_INT64_IS_LONG_LONG && NPY_BITSOF_LONGLONG == 64
+ if (type_num == NPY_LONGLONG) return NPY_INT64;
+ if (type_num == NPY_ULONGLONG) return NPY_UINT64;
+#endif
+ return type_num;
+}
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/numpy_to_arrow.cc b/src/arrow/cpp/src/arrow/python/numpy_to_arrow.cc
new file mode 100644
index 000000000..a382f7663
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/numpy_to_arrow.cc
@@ -0,0 +1,865 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Functions for pandas conversion via NumPy
+
+#include "arrow/python/numpy_to_arrow.h"
+#include "arrow/python/numpy_interop.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_generate.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string.h"
+#include "arrow/util/utf8.h"
+#include "arrow/visitor_inline.h"
+
+#include "arrow/compute/api_scalar.h"
+
+#include "arrow/python/common.h"
+#include "arrow/python/datetime.h"
+#include "arrow/python/helpers.h"
+#include "arrow/python/iterators.h"
+#include "arrow/python/numpy_convert.h"
+#include "arrow/python/numpy_internal.h"
+#include "arrow/python/python_to_arrow.h"
+#include "arrow/python/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::CopyBitmap;
+using internal::GenerateBitsUnrolled;
+
+namespace py {
+
+using internal::NumPyTypeSize;
+
+// ----------------------------------------------------------------------
+// Conversion utilities
+
+namespace {
+
+Status AllocateNullBitmap(MemoryPool* pool, int64_t length,
+ std::shared_ptr<ResizableBuffer>* out) {
+ int64_t null_bytes = BitUtil::BytesForBits(length);
+ ARROW_ASSIGN_OR_RAISE(auto null_bitmap, AllocateResizableBuffer(null_bytes, pool));
+
+ // Padding zeroed by AllocateResizableBuffer
+ memset(null_bitmap->mutable_data(), 0, static_cast<size_t>(null_bytes));
+ *out = std::move(null_bitmap);
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Conversion from NumPy-in-Pandas to Arrow null bitmap
+
+template <int TYPE>
+inline int64_t ValuesToBitmap(PyArrayObject* arr, uint8_t* bitmap) {
+ typedef internal::npy_traits<TYPE> traits;
+ typedef typename traits::value_type T;
+
+ int64_t null_count = 0;
+
+ Ndarray1DIndexer<T> values(arr);
+ for (int i = 0; i < values.size(); ++i) {
+ if (traits::isnull(values[i])) {
+ ++null_count;
+ } else {
+ BitUtil::SetBit(bitmap, i);
+ }
+ }
+
+ return null_count;
+}
+
+class NumPyNullsConverter {
+ public:
+ /// Convert the given array's null values to a null bitmap.
+ /// The null bitmap is only allocated if null values are ever possible.
+ static Status Convert(MemoryPool* pool, PyArrayObject* arr, bool from_pandas,
+ std::shared_ptr<ResizableBuffer>* out_null_bitmap_,
+ int64_t* out_null_count) {
+ NumPyNullsConverter converter(pool, arr, from_pandas);
+ RETURN_NOT_OK(VisitNumpyArrayInline(arr, &converter));
+ *out_null_bitmap_ = converter.null_bitmap_;
+ *out_null_count = converter.null_count_;
+ return Status::OK();
+ }
+
+ template <int TYPE>
+ Status Visit(PyArrayObject* arr) {
+ typedef internal::npy_traits<TYPE> traits;
+
+ const bool null_sentinels_possible =
+ // Always treat Numpy's NaT as null
+ TYPE == NPY_DATETIME || TYPE == NPY_TIMEDELTA ||
+ // Observing pandas's null sentinels
+ (from_pandas_ && traits::supports_nulls);
+
+ if (null_sentinels_possible) {
+ RETURN_NOT_OK(AllocateNullBitmap(pool_, PyArray_SIZE(arr), &null_bitmap_));
+ null_count_ = ValuesToBitmap<TYPE>(arr, null_bitmap_->mutable_data());
+ }
+ return Status::OK();
+ }
+
+ protected:
+ NumPyNullsConverter(MemoryPool* pool, PyArrayObject* arr, bool from_pandas)
+ : pool_(pool),
+ arr_(arr),
+ from_pandas_(from_pandas),
+ null_bitmap_data_(nullptr),
+ null_count_(0) {}
+
+ MemoryPool* pool_;
+ PyArrayObject* arr_;
+ bool from_pandas_;
+ std::shared_ptr<ResizableBuffer> null_bitmap_;
+ uint8_t* null_bitmap_data_;
+ int64_t null_count_;
+};
+
+// Returns null count
+int64_t MaskToBitmap(PyArrayObject* mask, int64_t length, uint8_t* bitmap) {
+ int64_t null_count = 0;
+
+ Ndarray1DIndexer<uint8_t> mask_values(mask);
+ for (int i = 0; i < length; ++i) {
+ if (mask_values[i]) {
+ ++null_count;
+ BitUtil::ClearBit(bitmap, i);
+ } else {
+ BitUtil::SetBit(bitmap, i);
+ }
+ }
+ return null_count;
+}
+
+} // namespace
+
+// ----------------------------------------------------------------------
+// Conversion from NumPy arrays (possibly originating from pandas) to Arrow
+// format. Does not handle NPY_OBJECT dtype arrays; use ConvertPySequence for
+// that
+
+class NumPyConverter {
+ public:
+ NumPyConverter(MemoryPool* pool, PyObject* arr, PyObject* mo,
+ const std::shared_ptr<DataType>& type, bool from_pandas,
+ const compute::CastOptions& cast_options = compute::CastOptions())
+ : pool_(pool),
+ type_(type),
+ arr_(reinterpret_cast<PyArrayObject*>(arr)),
+ dtype_(PyArray_DESCR(arr_)),
+ mask_(nullptr),
+ from_pandas_(from_pandas),
+ cast_options_(cast_options),
+ null_bitmap_data_(nullptr),
+ null_count_(0) {
+ if (mo != nullptr && mo != Py_None) {
+ mask_ = reinterpret_cast<PyArrayObject*>(mo);
+ }
+ length_ = static_cast<int64_t>(PyArray_SIZE(arr_));
+ itemsize_ = static_cast<int>(PyArray_DESCR(arr_)->elsize);
+ stride_ = static_cast<int64_t>(PyArray_STRIDES(arr_)[0]);
+ }
+
+ bool is_strided() const { return itemsize_ != stride_; }
+
+ Status Convert();
+
+ const ArrayVector& result() const { return out_arrays_; }
+
+ template <typename T>
+ enable_if_primitive_ctype<T, Status> Visit(const T& type) {
+ return VisitNative<T>();
+ }
+
+ Status Visit(const HalfFloatType& type) { return VisitNative<UInt16Type>(); }
+
+ Status Visit(const Date32Type& type) { return VisitNative<Date32Type>(); }
+ Status Visit(const Date64Type& type) { return VisitNative<Date64Type>(); }
+ Status Visit(const TimestampType& type) { return VisitNative<TimestampType>(); }
+ Status Visit(const Time32Type& type) { return VisitNative<Int32Type>(); }
+ Status Visit(const Time64Type& type) { return VisitNative<Int64Type>(); }
+ Status Visit(const DurationType& type) { return VisitNative<DurationType>(); }
+
+ Status Visit(const NullType& type) { return TypeNotImplemented(type.ToString()); }
+
+ // NumPy ascii string arrays
+ Status Visit(const BinaryType& type);
+
+ // NumPy unicode arrays
+ Status Visit(const StringType& type);
+
+ Status Visit(const StructType& type);
+
+ Status Visit(const FixedSizeBinaryType& type);
+
+ // Default case
+ Status Visit(const DataType& type) { return TypeNotImplemented(type.ToString()); }
+
+ protected:
+ Status InitNullBitmap() {
+ RETURN_NOT_OK(AllocateNullBitmap(pool_, length_, &null_bitmap_));
+ null_bitmap_data_ = null_bitmap_->mutable_data();
+ return Status::OK();
+ }
+
+ // Called before ConvertData to ensure Numpy input buffer is in expected
+ // Arrow layout
+ template <typename ArrowType>
+ Status PrepareInputData(std::shared_ptr<Buffer>* data);
+
+ // ----------------------------------------------------------------------
+ // Traditional visitor conversion for non-object arrays
+
+ template <typename ArrowType>
+ Status ConvertData(std::shared_ptr<Buffer>* data);
+
+ template <typename T>
+ Status PushBuilderResult(T* builder) {
+ std::shared_ptr<Array> out;
+ RETURN_NOT_OK(builder->Finish(&out));
+ out_arrays_.emplace_back(out);
+ return Status::OK();
+ }
+
+ Status PushArray(const std::shared_ptr<ArrayData>& data) {
+ out_arrays_.emplace_back(MakeArray(data));
+ return Status::OK();
+ }
+
+ template <typename ArrowType>
+ Status VisitNative() {
+ if (mask_ != nullptr) {
+ RETURN_NOT_OK(InitNullBitmap());
+ null_count_ = MaskToBitmap(mask_, length_, null_bitmap_data_);
+ } else {
+ RETURN_NOT_OK(NumPyNullsConverter::Convert(pool_, arr_, from_pandas_, &null_bitmap_,
+ &null_count_));
+ }
+
+ std::shared_ptr<Buffer> data;
+ RETURN_NOT_OK(ConvertData<ArrowType>(&data));
+
+ auto arr_data = ArrayData::Make(type_, length_, {null_bitmap_, data}, null_count_, 0);
+ return PushArray(arr_data);
+ }
+
+ Status TypeNotImplemented(std::string type_name) {
+ return Status::NotImplemented("NumPyConverter doesn't implement <", type_name,
+ "> conversion. ");
+ }
+
+ MemoryPool* pool_;
+ std::shared_ptr<DataType> type_;
+ PyArrayObject* arr_;
+ PyArray_Descr* dtype_;
+ PyArrayObject* mask_;
+ int64_t length_;
+ int64_t stride_;
+ int itemsize_;
+
+ bool from_pandas_;
+ compute::CastOptions cast_options_;
+
+ // Used in visitor pattern
+ ArrayVector out_arrays_;
+
+ std::shared_ptr<ResizableBuffer> null_bitmap_;
+ uint8_t* null_bitmap_data_;
+ int64_t null_count_;
+};
+
+Status NumPyConverter::Convert() {
+ if (PyArray_NDIM(arr_) != 1) {
+ return Status::Invalid("only handle 1-dimensional arrays");
+ }
+
+ if (dtype_->type_num == NPY_OBJECT) {
+ // If an object array, convert it like a normal Python sequence
+ PyConversionOptions py_options;
+ py_options.type = type_;
+ py_options.from_pandas = from_pandas_;
+ ARROW_ASSIGN_OR_RAISE(
+ auto chunked_array,
+ ConvertPySequence(reinterpret_cast<PyObject*>(arr_),
+ reinterpret_cast<PyObject*>(mask_), py_options, pool_));
+ out_arrays_ = chunked_array->chunks();
+ return Status::OK();
+ }
+
+ if (type_ == nullptr) {
+ return Status::Invalid("Must pass data type for non-object arrays");
+ }
+
+ // Visit the type to perform conversion
+ return VisitTypeInline(*type_, this);
+}
+
+namespace {
+
+Status CastBuffer(const std::shared_ptr<DataType>& in_type,
+ const std::shared_ptr<Buffer>& input, const int64_t length,
+ const std::shared_ptr<Buffer>& valid_bitmap, const int64_t null_count,
+ const std::shared_ptr<DataType>& out_type,
+ const compute::CastOptions& cast_options, MemoryPool* pool,
+ std::shared_ptr<Buffer>* out) {
+ // Must cast
+ auto tmp_data = ArrayData::Make(in_type, length, {valid_bitmap, input}, null_count);
+ compute::ExecContext context(pool);
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr<Array> casted_array,
+ compute::Cast(*MakeArray(tmp_data), out_type, cast_options, &context));
+ *out = casted_array->data()->buffers[1];
+ return Status::OK();
+}
+
+template <typename FromType, typename ToType>
+Status StaticCastBuffer(const Buffer& input, const int64_t length, MemoryPool* pool,
+ std::shared_ptr<Buffer>* out) {
+ ARROW_ASSIGN_OR_RAISE(auto result, AllocateBuffer(sizeof(ToType) * length, pool));
+
+ auto in_values = reinterpret_cast<const FromType*>(input.data());
+ auto out_values = reinterpret_cast<ToType*>(result->mutable_data());
+ for (int64_t i = 0; i < length; ++i) {
+ *out_values++ = static_cast<ToType>(*in_values++);
+ }
+ *out = std::move(result);
+ return Status::OK();
+}
+
+template <typename T>
+void CopyStridedBytewise(int8_t* input_data, int64_t length, int64_t stride,
+ T* output_data) {
+ // Passing input_data as non-const is a concession to PyObject*
+ for (int64_t i = 0; i < length; ++i) {
+ memcpy(output_data + i, input_data, sizeof(T));
+ input_data += stride;
+ }
+}
+
+template <typename T>
+void CopyStridedNatural(T* input_data, int64_t length, int64_t stride, T* output_data) {
+ // Passing input_data as non-const is a concession to PyObject*
+ int64_t j = 0;
+ for (int64_t i = 0; i < length; ++i) {
+ output_data[i] = input_data[j];
+ j += stride;
+ }
+}
+
+class NumPyStridedConverter {
+ public:
+ static Status Convert(PyArrayObject* arr, int64_t length, MemoryPool* pool,
+ std::shared_ptr<Buffer>* out) {
+ NumPyStridedConverter converter(arr, length, pool);
+ RETURN_NOT_OK(VisitNumpyArrayInline(arr, &converter));
+ *out = converter.buffer_;
+ return Status::OK();
+ }
+ template <int TYPE>
+ Status Visit(PyArrayObject* arr) {
+ using traits = internal::npy_traits<TYPE>;
+ using T = typename traits::value_type;
+
+ ARROW_ASSIGN_OR_RAISE(buffer_, AllocateBuffer(sizeof(T) * length_, pool_));
+
+ const int64_t stride = PyArray_STRIDES(arr)[0];
+ if (stride % sizeof(T) == 0) {
+ const int64_t stride_elements = stride / sizeof(T);
+ CopyStridedNatural(reinterpret_cast<T*>(PyArray_DATA(arr)), length_,
+ stride_elements, reinterpret_cast<T*>(buffer_->mutable_data()));
+ } else {
+ CopyStridedBytewise(reinterpret_cast<int8_t*>(PyArray_DATA(arr)), length_, stride,
+ reinterpret_cast<T*>(buffer_->mutable_data()));
+ }
+ return Status::OK();
+ }
+
+ protected:
+ NumPyStridedConverter(PyArrayObject* arr, int64_t length, MemoryPool* pool)
+ : arr_(arr), length_(length), pool_(pool), buffer_(nullptr) {}
+ PyArrayObject* arr_;
+ int64_t length_;
+ MemoryPool* pool_;
+ std::shared_ptr<Buffer> buffer_;
+};
+
+} // namespace
+
+template <typename ArrowType>
+inline Status NumPyConverter::PrepareInputData(std::shared_ptr<Buffer>* data) {
+ if (PyArray_ISBYTESWAPPED(arr_)) {
+ // TODO
+ return Status::NotImplemented("Byte-swapped arrays not supported");
+ }
+
+ if (dtype_->type_num == NPY_BOOL) {
+ int64_t nbytes = BitUtil::BytesForBits(length_);
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(nbytes, pool_));
+
+ Ndarray1DIndexer<uint8_t> values(arr_);
+ int64_t i = 0;
+ const auto generate = [&values, &i]() -> bool { return values[i++] > 0; };
+ GenerateBitsUnrolled(buffer->mutable_data(), 0, length_, generate);
+
+ *data = std::move(buffer);
+ } else if (is_strided()) {
+ RETURN_NOT_OK(NumPyStridedConverter::Convert(arr_, length_, pool_, data));
+ } else {
+ // Can zero-copy
+ *data = std::make_shared<NumPyBuffer>(reinterpret_cast<PyObject*>(arr_));
+ }
+
+ return Status::OK();
+}
+
+template <typename ArrowType>
+inline Status NumPyConverter::ConvertData(std::shared_ptr<Buffer>* data) {
+ RETURN_NOT_OK(PrepareInputData<ArrowType>(data));
+
+ std::shared_ptr<DataType> input_type;
+ RETURN_NOT_OK(NumPyDtypeToArrow(reinterpret_cast<PyObject*>(dtype_), &input_type));
+
+ if (!input_type->Equals(*type_)) {
+ RETURN_NOT_OK(CastBuffer(input_type, *data, length_, null_bitmap_, null_count_, type_,
+ cast_options_, pool_, data));
+ }
+
+ return Status::OK();
+}
+
+template <>
+inline Status NumPyConverter::ConvertData<Date32Type>(std::shared_ptr<Buffer>* data) {
+ std::shared_ptr<DataType> input_type;
+
+ RETURN_NOT_OK(PrepareInputData<Date32Type>(data));
+
+ auto date_dtype = reinterpret_cast<PyArray_DatetimeDTypeMetaData*>(dtype_->c_metadata);
+ if (dtype_->type_num == NPY_DATETIME) {
+ // If we have inbound datetime64[D] data, this needs to be downcasted
+ // separately here from int64_t to int32_t, because this data is not
+ // supported in compute::Cast
+ if (date_dtype->meta.base == NPY_FR_D) {
+ // TODO(wesm): How pedantic do we really want to be about checking for int32
+ // overflow here?
+ Status s = StaticCastBuffer<int64_t, int32_t>(**data, length_, pool_, data);
+ RETURN_NOT_OK(s);
+ } else {
+ RETURN_NOT_OK(NumPyDtypeToArrow(reinterpret_cast<PyObject*>(dtype_), &input_type));
+ if (!input_type->Equals(*type_)) {
+ // The null bitmap was already computed in VisitNative()
+ RETURN_NOT_OK(CastBuffer(input_type, *data, length_, null_bitmap_, null_count_,
+ type_, cast_options_, pool_, data));
+ }
+ }
+ } else {
+ RETURN_NOT_OK(NumPyDtypeToArrow(reinterpret_cast<PyObject*>(dtype_), &input_type));
+ if (!input_type->Equals(*type_)) {
+ RETURN_NOT_OK(CastBuffer(input_type, *data, length_, null_bitmap_, null_count_,
+ type_, cast_options_, pool_, data));
+ }
+ }
+
+ return Status::OK();
+}
+
+template <>
+inline Status NumPyConverter::ConvertData<Date64Type>(std::shared_ptr<Buffer>* data) {
+ constexpr int64_t kMillisecondsInDay = 86400000;
+ std::shared_ptr<DataType> input_type;
+
+ RETURN_NOT_OK(PrepareInputData<Date64Type>(data));
+
+ auto date_dtype = reinterpret_cast<PyArray_DatetimeDTypeMetaData*>(dtype_->c_metadata);
+ if (dtype_->type_num == NPY_DATETIME) {
+ // If we have inbound datetime64[D] data, this needs to be downcasted
+ // separately here from int64_t to int32_t, because this data is not
+ // supported in compute::Cast
+ if (date_dtype->meta.base == NPY_FR_D) {
+ ARROW_ASSIGN_OR_RAISE(auto result,
+ AllocateBuffer(sizeof(int64_t) * length_, pool_));
+
+ auto in_values = reinterpret_cast<const int64_t*>((*data)->data());
+ auto out_values = reinterpret_cast<int64_t*>(result->mutable_data());
+ for (int64_t i = 0; i < length_; ++i) {
+ *out_values++ = kMillisecondsInDay * (*in_values++);
+ }
+ *data = std::move(result);
+ } else {
+ RETURN_NOT_OK(NumPyDtypeToArrow(reinterpret_cast<PyObject*>(dtype_), &input_type));
+ if (!input_type->Equals(*type_)) {
+ // The null bitmap was already computed in VisitNative()
+ RETURN_NOT_OK(CastBuffer(input_type, *data, length_, null_bitmap_, null_count_,
+ type_, cast_options_, pool_, data));
+ }
+ }
+ } else {
+ RETURN_NOT_OK(NumPyDtypeToArrow(reinterpret_cast<PyObject*>(dtype_), &input_type));
+ if (!input_type->Equals(*type_)) {
+ RETURN_NOT_OK(CastBuffer(input_type, *data, length_, null_bitmap_, null_count_,
+ type_, cast_options_, pool_, data));
+ }
+ }
+
+ return Status::OK();
+}
+
+// Create 16MB chunks for binary data
+constexpr int32_t kBinaryChunksize = 1 << 24;
+
+Status NumPyConverter::Visit(const BinaryType& type) {
+ ::arrow::internal::ChunkedBinaryBuilder builder(kBinaryChunksize, pool_);
+
+ auto data = reinterpret_cast<const uint8_t*>(PyArray_DATA(arr_));
+
+ auto AppendNotNull = [&builder, this](const uint8_t* data) {
+ // This is annoying. NumPy allows strings to have nul-terminators, so
+ // we must check for them here
+ const size_t item_size =
+ strnlen(reinterpret_cast<const char*>(data), static_cast<size_t>(itemsize_));
+ return builder.Append(data, static_cast<int32_t>(item_size));
+ };
+
+ if (mask_ != nullptr) {
+ Ndarray1DIndexer<uint8_t> mask_values(mask_);
+ for (int64_t i = 0; i < length_; ++i) {
+ if (mask_values[i]) {
+ RETURN_NOT_OK(builder.AppendNull());
+ } else {
+ RETURN_NOT_OK(AppendNotNull(data));
+ }
+ data += stride_;
+ }
+ } else {
+ for (int64_t i = 0; i < length_; ++i) {
+ RETURN_NOT_OK(AppendNotNull(data));
+ data += stride_;
+ }
+ }
+
+ ArrayVector result;
+ RETURN_NOT_OK(builder.Finish(&result));
+ for (auto arr : result) {
+ RETURN_NOT_OK(PushArray(arr->data()));
+ }
+ return Status::OK();
+}
+
+Status NumPyConverter::Visit(const FixedSizeBinaryType& type) {
+ auto byte_width = type.byte_width();
+
+ if (itemsize_ != byte_width) {
+ return Status::Invalid("Got bytestring of length ", itemsize_, " (expected ",
+ byte_width, ")");
+ }
+
+ FixedSizeBinaryBuilder builder(::arrow::fixed_size_binary(byte_width), pool_);
+ auto data = reinterpret_cast<const uint8_t*>(PyArray_DATA(arr_));
+
+ if (mask_ != nullptr) {
+ Ndarray1DIndexer<uint8_t> mask_values(mask_);
+ RETURN_NOT_OK(builder.Reserve(length_));
+ for (int64_t i = 0; i < length_; ++i) {
+ if (mask_values[i]) {
+ RETURN_NOT_OK(builder.AppendNull());
+ } else {
+ RETURN_NOT_OK(builder.Append(data));
+ }
+ data += stride_;
+ }
+ } else {
+ for (int64_t i = 0; i < length_; ++i) {
+ RETURN_NOT_OK(builder.Append(data));
+ data += stride_;
+ }
+ }
+
+ std::shared_ptr<Array> result;
+ RETURN_NOT_OK(builder.Finish(&result));
+ return PushArray(result->data());
+}
+
+namespace {
+
+// NumPy unicode is UCS4/UTF32 always
+constexpr int kNumPyUnicodeSize = 4;
+
+Status AppendUTF32(const char* data, int itemsize, int byteorder,
+ ::arrow::internal::ChunkedStringBuilder* builder) {
+ // The binary \x00\x00\x00\x00 indicates a nul terminator in NumPy unicode,
+ // so we need to detect that here to truncate if necessary. Yep.
+ int actual_length = 0;
+ for (; actual_length < itemsize / kNumPyUnicodeSize; ++actual_length) {
+ const char* code_point = data + actual_length * kNumPyUnicodeSize;
+ if ((*code_point == '\0') && (*(code_point + 1) == '\0') &&
+ (*(code_point + 2) == '\0') && (*(code_point + 3) == '\0')) {
+ break;
+ }
+ }
+
+ OwnedRef unicode_obj(PyUnicode_DecodeUTF32(data, actual_length * kNumPyUnicodeSize,
+ nullptr, &byteorder));
+ RETURN_IF_PYERROR();
+ OwnedRef utf8_obj(PyUnicode_AsUTF8String(unicode_obj.obj()));
+ if (utf8_obj.obj() == NULL) {
+ PyErr_Clear();
+ return Status::Invalid("failed converting UTF32 to UTF8");
+ }
+
+ const int32_t length = static_cast<int32_t>(PyBytes_GET_SIZE(utf8_obj.obj()));
+ return builder->Append(
+ reinterpret_cast<const uint8_t*>(PyBytes_AS_STRING(utf8_obj.obj())), length);
+}
+
+} // namespace
+
+Status NumPyConverter::Visit(const StringType& type) {
+ util::InitializeUTF8();
+
+ ::arrow::internal::ChunkedStringBuilder builder(kBinaryChunksize, pool_);
+
+ auto data = reinterpret_cast<const uint8_t*>(PyArray_DATA(arr_));
+
+ char numpy_byteorder = dtype_->byteorder;
+
+ // For Python C API, -1 is little-endian, 1 is big-endian
+ int byteorder = numpy_byteorder == '>' ? 1 : -1;
+
+ PyAcquireGIL gil_lock;
+
+ const bool is_binary_type = dtype_->type_num == NPY_STRING;
+ const bool is_unicode_type = dtype_->type_num == NPY_UNICODE;
+
+ if (!is_binary_type && !is_unicode_type) {
+ const bool is_float_type = dtype_->kind == 'f';
+ if (from_pandas_ && is_float_type) {
+ // in case of from_pandas=True, accept an all-NaN float array as input
+ RETURN_NOT_OK(NumPyNullsConverter::Convert(pool_, arr_, from_pandas_, &null_bitmap_,
+ &null_count_));
+ if (null_count_ == length_) {
+ auto arr = std::make_shared<NullArray>(length_);
+ compute::ExecContext context(pool_);
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr<Array> out,
+ compute::Cast(*arr, arrow::utf8(), cast_options_, &context));
+ out_arrays_.emplace_back(out);
+ return Status::OK();
+ }
+ }
+ std::string dtype_string;
+ RETURN_NOT_OK(internal::PyObject_StdStringStr(reinterpret_cast<PyObject*>(dtype_),
+ &dtype_string));
+ return Status::TypeError("Expected a string or bytes dtype, got ", dtype_string);
+ }
+
+ auto AppendNonNullValue = [&](const uint8_t* data) {
+ if (is_binary_type) {
+ if (ARROW_PREDICT_TRUE(util::ValidateUTF8(data, itemsize_))) {
+ return builder.Append(data, itemsize_);
+ } else {
+ return Status::Invalid("Encountered non-UTF8 binary value: ",
+ HexEncode(data, itemsize_));
+ }
+ } else {
+ // is_unicode_type case
+ return AppendUTF32(reinterpret_cast<const char*>(data), itemsize_, byteorder,
+ &builder);
+ }
+ };
+
+ if (mask_ != nullptr) {
+ Ndarray1DIndexer<uint8_t> mask_values(mask_);
+ for (int64_t i = 0; i < length_; ++i) {
+ if (mask_values[i]) {
+ RETURN_NOT_OK(builder.AppendNull());
+ } else {
+ RETURN_NOT_OK(AppendNonNullValue(data));
+ }
+ data += stride_;
+ }
+ } else {
+ for (int64_t i = 0; i < length_; ++i) {
+ RETURN_NOT_OK(AppendNonNullValue(data));
+ data += stride_;
+ }
+ }
+
+ ArrayVector result;
+ RETURN_NOT_OK(builder.Finish(&result));
+ for (auto arr : result) {
+ RETURN_NOT_OK(PushArray(arr->data()));
+ }
+ return Status::OK();
+}
+
+Status NumPyConverter::Visit(const StructType& type) {
+ std::vector<NumPyConverter> sub_converters;
+ std::vector<OwnedRefNoGIL> sub_arrays;
+
+ {
+ PyAcquireGIL gil_lock;
+
+ // Create converters for each struct type field
+ if (dtype_->fields == NULL || !PyDict_Check(dtype_->fields)) {
+ return Status::TypeError("Expected struct array");
+ }
+
+ for (auto field : type.fields()) {
+ PyObject* tup = PyDict_GetItemString(dtype_->fields, field->name().c_str());
+ if (tup == NULL) {
+ return Status::Invalid("Missing field '", field->name(), "' in struct array");
+ }
+ PyArray_Descr* sub_dtype =
+ reinterpret_cast<PyArray_Descr*>(PyTuple_GET_ITEM(tup, 0));
+ DCHECK(PyObject_TypeCheck(sub_dtype, &PyArrayDescr_Type));
+ int offset = static_cast<int>(PyLong_AsLong(PyTuple_GET_ITEM(tup, 1)));
+ RETURN_IF_PYERROR();
+ Py_INCREF(sub_dtype); /* PyArray_GetField() steals ref */
+ PyObject* sub_array = PyArray_GetField(arr_, sub_dtype, offset);
+ RETURN_IF_PYERROR();
+ sub_arrays.emplace_back(sub_array);
+ sub_converters.emplace_back(pool_, sub_array, nullptr /* mask */, field->type(),
+ from_pandas_);
+ }
+ }
+
+ std::vector<ArrayVector> groups;
+ int64_t null_count = 0;
+
+ // Compute null bitmap and store it as a Boolean Array to include it
+ // in the rechunking below
+ {
+ if (mask_ != nullptr) {
+ RETURN_NOT_OK(InitNullBitmap());
+ null_count = MaskToBitmap(mask_, length_, null_bitmap_data_);
+ }
+ groups.push_back({std::make_shared<BooleanArray>(length_, null_bitmap_)});
+ }
+
+ // Convert child data
+ for (auto& converter : sub_converters) {
+ RETURN_NOT_OK(converter.Convert());
+ groups.push_back(converter.result());
+ const auto& group = groups.back();
+ int64_t n = 0;
+ for (const auto& array : group) {
+ n += array->length();
+ }
+ }
+ // Ensure the different array groups are chunked consistently
+ groups = ::arrow::internal::RechunkArraysConsistently(groups);
+ for (const auto& group : groups) {
+ int64_t n = 0;
+ for (const auto& array : group) {
+ n += array->length();
+ }
+ }
+
+ // Make struct array chunks by combining groups
+ size_t ngroups = groups.size();
+ size_t nchunks = groups[0].size();
+ for (size_t chunk = 0; chunk < nchunks; chunk++) {
+ // First group has the null bitmaps as Boolean Arrays
+ const auto& null_data = groups[0][chunk]->data();
+ DCHECK_EQ(null_data->type->id(), Type::BOOL);
+ DCHECK_EQ(null_data->buffers.size(), 2);
+ const auto& null_buffer = null_data->buffers[1];
+ // Careful: the rechunked null bitmap may have a non-zero offset
+ // to its buffer, and it may not even start on a byte boundary
+ int64_t null_offset = null_data->offset;
+ std::shared_ptr<Buffer> fixed_null_buffer;
+
+ if (!null_buffer) {
+ fixed_null_buffer = null_buffer;
+ } else if (null_offset % 8 == 0) {
+ fixed_null_buffer =
+ std::make_shared<Buffer>(null_buffer,
+ // byte offset
+ null_offset / 8,
+ // byte size
+ BitUtil::BytesForBits(null_data->length));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ fixed_null_buffer,
+ CopyBitmap(pool_, null_buffer->data(), null_offset, null_data->length));
+ }
+
+ // Create struct array chunk and populate it
+ auto arr_data =
+ ArrayData::Make(type_, null_data->length, null_count ? kUnknownNullCount : 0, 0);
+ arr_data->buffers.push_back(fixed_null_buffer);
+ // Append child chunks
+ for (size_t i = 1; i < ngroups; i++) {
+ arr_data->child_data.push_back(groups[i][chunk]->data());
+ }
+ RETURN_NOT_OK(PushArray(arr_data));
+ }
+
+ return Status::OK();
+}
+
+Status NdarrayToArrow(MemoryPool* pool, PyObject* ao, PyObject* mo, bool from_pandas,
+ const std::shared_ptr<DataType>& type,
+ const compute::CastOptions& cast_options,
+ std::shared_ptr<ChunkedArray>* out) {
+ if (!PyArray_Check(ao)) {
+ // This code path cannot be reached by Python unit tests currently so this
+ // is only a sanity check.
+ return Status::TypeError("Input object was not a NumPy array");
+ }
+ if (PyArray_NDIM(reinterpret_cast<PyArrayObject*>(ao)) != 1) {
+ return Status::Invalid("only handle 1-dimensional arrays");
+ }
+
+ NumPyConverter converter(pool, ao, mo, type, from_pandas, cast_options);
+ RETURN_NOT_OK(converter.Convert());
+ const auto& output_arrays = converter.result();
+ DCHECK_GT(output_arrays.size(), 0);
+ *out = std::make_shared<ChunkedArray>(output_arrays);
+ return Status::OK();
+}
+
+Status NdarrayToArrow(MemoryPool* pool, PyObject* ao, PyObject* mo, bool from_pandas,
+ const std::shared_ptr<DataType>& type,
+ std::shared_ptr<ChunkedArray>* out) {
+ return NdarrayToArrow(pool, ao, mo, from_pandas, type, compute::CastOptions(), out);
+}
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/numpy_to_arrow.h b/src/arrow/cpp/src/arrow/python/numpy_to_arrow.h
new file mode 100644
index 000000000..b6cd093e5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/numpy_to_arrow.h
@@ -0,0 +1,72 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Converting from pandas memory representation to Arrow data structures
+
+#pragma once
+
+#include "arrow/python/platform.h"
+
+#include <memory>
+
+#include "arrow/compute/api.h"
+#include "arrow/python/visibility.h"
+
+namespace arrow {
+
+class Array;
+class ChunkedArray;
+class DataType;
+class MemoryPool;
+class Status;
+
+namespace py {
+
+/// Convert NumPy arrays to Arrow. If target data type is not known, pass a
+/// type with null
+///
+/// \param[in] pool Memory pool for any memory allocations
+/// \param[in] ao an ndarray with the array data
+/// \param[in] mo an ndarray with a null mask (True is null), optional
+/// \param[in] from_pandas If true, use pandas's null sentinels to determine
+/// whether values are null
+/// \param[in] type a specific type to cast to, may be null
+/// \param[in] cast_options casting options
+/// \param[out] out a ChunkedArray, to accommodate chunked output
+ARROW_PYTHON_EXPORT
+Status NdarrayToArrow(MemoryPool* pool, PyObject* ao, PyObject* mo, bool from_pandas,
+ const std::shared_ptr<DataType>& type,
+ const compute::CastOptions& cast_options,
+ std::shared_ptr<ChunkedArray>* out);
+
+/// Safely convert NumPy arrays to Arrow. If target data type is not known,
+/// pass a type with null.
+///
+/// \param[in] pool Memory pool for any memory allocations
+/// \param[in] ao an ndarray with the array data
+/// \param[in] mo an ndarray with a null mask (True is null), optional
+/// \param[in] from_pandas If true, use pandas's null sentinels to determine
+/// whether values are null
+/// \param[in] type a specific type to cast to, may be null
+/// \param[out] out a ChunkedArray, to accommodate chunked output
+ARROW_PYTHON_EXPORT
+Status NdarrayToArrow(MemoryPool* pool, PyObject* ao, PyObject* mo, bool from_pandas,
+ const std::shared_ptr<DataType>& type,
+ std::shared_ptr<ChunkedArray>* out);
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/pch.h b/src/arrow/cpp/src/arrow/python/pch.h
new file mode 100644
index 000000000..d1d688b4f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/pch.h
@@ -0,0 +1,24 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Often-used headers, for precompiling.
+// If updating this header, please make sure you check compilation speed
+// before checking in. Adding headers which are not used extremely often
+// may incur a slowdown, since it makes the precompiled header heavier to load.
+
+#include "arrow/pch.h"
+#include "arrow/python/platform.h"
diff --git a/src/arrow/cpp/src/arrow/python/platform.h b/src/arrow/cpp/src/arrow/python/platform.h
new file mode 100644
index 000000000..80f7e6081
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/platform.h
@@ -0,0 +1,36 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Functions for converting between pandas's NumPy-based data representation
+// and Arrow data structures
+
+#pragma once
+
+// If PY_SSIZE_T_CLEAN is defined, argument parsing functions treat #-specifier
+// to mean Py_ssize_t (defining this to suppress deprecation warning)
+#define PY_SSIZE_T_CLEAN
+
+#include <Python.h> // IWYU pragma: export
+#include <datetime.h>
+
+// Work around C2528 error
+#ifdef _MSC_VER
+#if _MSC_VER >= 1900
+#undef timezone
+#endif
+#endif
+
diff --git a/src/arrow/cpp/src/arrow/python/pyarrow.cc b/src/arrow/cpp/src/arrow/python/pyarrow.cc
new file mode 100644
index 000000000..c3244b74b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/pyarrow.cc
@@ -0,0 +1,90 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/python/pyarrow.h"
+
+#include <memory>
+#include <utility>
+
+#include "arrow/array.h"
+#include "arrow/table.h"
+#include "arrow/tensor.h"
+#include "arrow/type.h"
+
+#include "arrow/python/common.h"
+#include "arrow/python/datetime.h"
+namespace {
+#include "arrow/python/pyarrow_api.h"
+}
+
+namespace arrow {
+namespace py {
+
+static Status UnwrapError(PyObject* obj, const char* expected_type) {
+ return Status::TypeError("Could not unwrap ", expected_type,
+ " from Python object of type '", Py_TYPE(obj)->tp_name, "'");
+}
+
+int import_pyarrow() {
+ internal::InitDatetime();
+ return ::import_pyarrow__lib();
+}
+
+#define DEFINE_WRAP_FUNCTIONS(FUNC_SUFFIX, TYPE_NAME) \
+ bool is_##FUNC_SUFFIX(PyObject* obj) { return ::pyarrow_is_##FUNC_SUFFIX(obj) != 0; } \
+ \
+ PyObject* wrap_##FUNC_SUFFIX(const std::shared_ptr<TYPE_NAME>& src) { \
+ return ::pyarrow_wrap_##FUNC_SUFFIX(src); \
+ } \
+ Result<std::shared_ptr<TYPE_NAME>> unwrap_##FUNC_SUFFIX(PyObject* obj) { \
+ auto out = ::pyarrow_unwrap_##FUNC_SUFFIX(obj); \
+ if (out) { \
+ return std::move(out); \
+ } else { \
+ return UnwrapError(obj, #TYPE_NAME); \
+ } \
+ }
+
+DEFINE_WRAP_FUNCTIONS(buffer, Buffer)
+
+DEFINE_WRAP_FUNCTIONS(data_type, DataType)
+DEFINE_WRAP_FUNCTIONS(field, Field)
+DEFINE_WRAP_FUNCTIONS(schema, Schema)
+
+DEFINE_WRAP_FUNCTIONS(scalar, Scalar)
+
+DEFINE_WRAP_FUNCTIONS(array, Array)
+DEFINE_WRAP_FUNCTIONS(chunked_array, ChunkedArray)
+
+DEFINE_WRAP_FUNCTIONS(sparse_coo_tensor, SparseCOOTensor)
+DEFINE_WRAP_FUNCTIONS(sparse_csc_matrix, SparseCSCMatrix)
+DEFINE_WRAP_FUNCTIONS(sparse_csf_tensor, SparseCSFTensor)
+DEFINE_WRAP_FUNCTIONS(sparse_csr_matrix, SparseCSRMatrix)
+DEFINE_WRAP_FUNCTIONS(tensor, Tensor)
+
+DEFINE_WRAP_FUNCTIONS(batch, RecordBatch)
+DEFINE_WRAP_FUNCTIONS(table, Table)
+
+#undef DEFINE_WRAP_FUNCTIONS
+
+namespace internal {
+
+int check_status(const Status& status) { return ::pyarrow_internal_check_status(status); }
+
+} // namespace internal
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/pyarrow.h b/src/arrow/cpp/src/arrow/python/pyarrow.h
new file mode 100644
index 000000000..4c365081d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/pyarrow.h
@@ -0,0 +1,84 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/python/platform.h"
+
+#include <memory>
+
+#include "arrow/python/visibility.h"
+
+#include "arrow/sparse_tensor.h"
+
+// Work around ARROW-2317 (C linkage warning from Cython)
+extern "C++" {
+
+namespace arrow {
+
+class Array;
+class Buffer;
+class DataType;
+class Field;
+class RecordBatch;
+class Schema;
+class Status;
+class Table;
+class Tensor;
+
+namespace py {
+
+// Returns 0 on success, -1 on error.
+ARROW_PYTHON_EXPORT int import_pyarrow();
+
+#define DECLARE_WRAP_FUNCTIONS(FUNC_SUFFIX, TYPE_NAME) \
+ ARROW_PYTHON_EXPORT bool is_##FUNC_SUFFIX(PyObject*); \
+ ARROW_PYTHON_EXPORT Result<std::shared_ptr<TYPE_NAME>> unwrap_##FUNC_SUFFIX( \
+ PyObject*); \
+ ARROW_PYTHON_EXPORT PyObject* wrap_##FUNC_SUFFIX(const std::shared_ptr<TYPE_NAME>&);
+
+DECLARE_WRAP_FUNCTIONS(buffer, Buffer)
+
+DECLARE_WRAP_FUNCTIONS(data_type, DataType)
+DECLARE_WRAP_FUNCTIONS(field, Field)
+DECLARE_WRAP_FUNCTIONS(schema, Schema)
+
+DECLARE_WRAP_FUNCTIONS(scalar, Scalar)
+
+DECLARE_WRAP_FUNCTIONS(array, Array)
+DECLARE_WRAP_FUNCTIONS(chunked_array, ChunkedArray)
+
+DECLARE_WRAP_FUNCTIONS(sparse_coo_tensor, SparseCOOTensor)
+DECLARE_WRAP_FUNCTIONS(sparse_csc_matrix, SparseCSCMatrix)
+DECLARE_WRAP_FUNCTIONS(sparse_csf_tensor, SparseCSFTensor)
+DECLARE_WRAP_FUNCTIONS(sparse_csr_matrix, SparseCSRMatrix)
+DECLARE_WRAP_FUNCTIONS(tensor, Tensor)
+
+DECLARE_WRAP_FUNCTIONS(batch, RecordBatch)
+DECLARE_WRAP_FUNCTIONS(table, Table)
+
+#undef DECLARE_WRAP_FUNCTIONS
+
+namespace internal {
+
+ARROW_PYTHON_EXPORT int check_status(const Status& status);
+
+} // namespace internal
+} // namespace py
+} // namespace arrow
+
+} // extern "C++"
diff --git a/src/arrow/cpp/src/arrow/python/pyarrow_api.h b/src/arrow/cpp/src/arrow/python/pyarrow_api.h
new file mode 100644
index 000000000..947431200
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/pyarrow_api.h
@@ -0,0 +1,239 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// DO NOT EDIT THIS FILE. Update from pyarrow/lib_api.h after pyarrow build
+// This is used to be able to call back into Cython code from C++.
+
+/* Generated by Cython 0.29.15 */
+
+#ifndef __PYX_HAVE_API__pyarrow__lib
+#define __PYX_HAVE_API__pyarrow__lib
+#ifdef __MINGW64__
+#define MS_WIN64
+#endif
+#include "Python.h"
+#include "pyarrow_lib.h"
+
+static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_scalar)(std::shared_ptr< arrow::Scalar> const &) = 0;
+#define pyarrow_wrap_scalar __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_scalar
+static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_array)(std::shared_ptr< arrow::Array> const &) = 0;
+#define pyarrow_wrap_array __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_array
+static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_chunked_array)(std::shared_ptr< arrow::ChunkedArray> const &) = 0;
+#define pyarrow_wrap_chunked_array __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_chunked_array
+static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_batch)(std::shared_ptr< arrow::RecordBatch> const &) = 0;
+#define pyarrow_wrap_batch __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_batch
+static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_buffer)(std::shared_ptr< arrow::Buffer> const &) = 0;
+#define pyarrow_wrap_buffer __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_buffer
+static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_data_type)(std::shared_ptr< arrow::DataType> const &) = 0;
+#define pyarrow_wrap_data_type __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_data_type
+static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_field)(std::shared_ptr< arrow::Field> const &) = 0;
+#define pyarrow_wrap_field __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_field
+static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_resizable_buffer)(std::shared_ptr< arrow::ResizableBuffer> const &) = 0;
+#define pyarrow_wrap_resizable_buffer __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_resizable_buffer
+static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_schema)(std::shared_ptr< arrow::Schema> const &) = 0;
+#define pyarrow_wrap_schema __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_schema
+static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_table)(std::shared_ptr< arrow::Table> const &) = 0;
+#define pyarrow_wrap_table __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_table
+static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_tensor)(std::shared_ptr< arrow::Tensor> const &) = 0;
+#define pyarrow_wrap_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_tensor
+static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_coo_tensor)(std::shared_ptr< arrow::SparseCOOTensor> const &) = 0;
+#define pyarrow_wrap_sparse_coo_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_coo_tensor
+static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csr_matrix)(std::shared_ptr< arrow::SparseCSRMatrix> const &) = 0;
+#define pyarrow_wrap_sparse_csr_matrix __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csr_matrix
+static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csc_matrix)(std::shared_ptr< arrow::SparseCSCMatrix> const &) = 0;
+#define pyarrow_wrap_sparse_csc_matrix __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csc_matrix
+static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csf_tensor)(std::shared_ptr< arrow::SparseCSFTensor> const &) = 0;
+#define pyarrow_wrap_sparse_csf_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csf_tensor
+static std::shared_ptr< arrow::Scalar> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_scalar)(PyObject *) = 0;
+#define pyarrow_unwrap_scalar __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_scalar
+static std::shared_ptr< arrow::Array> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_array)(PyObject *) = 0;
+#define pyarrow_unwrap_array __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_array
+static std::shared_ptr< arrow::ChunkedArray> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_chunked_array)(PyObject *) = 0;
+#define pyarrow_unwrap_chunked_array __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_chunked_array
+static std::shared_ptr< arrow::RecordBatch> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_batch)(PyObject *) = 0;
+#define pyarrow_unwrap_batch __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_batch
+static std::shared_ptr< arrow::Buffer> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_buffer)(PyObject *) = 0;
+#define pyarrow_unwrap_buffer __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_buffer
+static std::shared_ptr< arrow::DataType> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_data_type)(PyObject *) = 0;
+#define pyarrow_unwrap_data_type __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_data_type
+static std::shared_ptr< arrow::Field> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_field)(PyObject *) = 0;
+#define pyarrow_unwrap_field __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_field
+static std::shared_ptr< arrow::Schema> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_schema)(PyObject *) = 0;
+#define pyarrow_unwrap_schema __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_schema
+static std::shared_ptr< arrow::Table> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_table)(PyObject *) = 0;
+#define pyarrow_unwrap_table __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_table
+static std::shared_ptr< arrow::Tensor> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_tensor)(PyObject *) = 0;
+#define pyarrow_unwrap_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_tensor
+static std::shared_ptr< arrow::SparseCOOTensor> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_coo_tensor)(PyObject *) = 0;
+#define pyarrow_unwrap_sparse_coo_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_coo_tensor
+static std::shared_ptr< arrow::SparseCSRMatrix> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csr_matrix)(PyObject *) = 0;
+#define pyarrow_unwrap_sparse_csr_matrix __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csr_matrix
+static std::shared_ptr< arrow::SparseCSCMatrix> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csc_matrix)(PyObject *) = 0;
+#define pyarrow_unwrap_sparse_csc_matrix __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csc_matrix
+static std::shared_ptr< arrow::SparseCSFTensor> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csf_tensor)(PyObject *) = 0;
+#define pyarrow_unwrap_sparse_csf_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csf_tensor
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_internal_check_status)(arrow::Status const &) = 0;
+#define pyarrow_internal_check_status __pyx_api_f_7pyarrow_3lib_pyarrow_internal_check_status
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_buffer)(PyObject *) = 0;
+#define pyarrow_is_buffer __pyx_api_f_7pyarrow_3lib_pyarrow_is_buffer
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_data_type)(PyObject *) = 0;
+#define pyarrow_is_data_type __pyx_api_f_7pyarrow_3lib_pyarrow_is_data_type
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_metadata)(PyObject *) = 0;
+#define pyarrow_is_metadata __pyx_api_f_7pyarrow_3lib_pyarrow_is_metadata
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_field)(PyObject *) = 0;
+#define pyarrow_is_field __pyx_api_f_7pyarrow_3lib_pyarrow_is_field
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_schema)(PyObject *) = 0;
+#define pyarrow_is_schema __pyx_api_f_7pyarrow_3lib_pyarrow_is_schema
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_array)(PyObject *) = 0;
+#define pyarrow_is_array __pyx_api_f_7pyarrow_3lib_pyarrow_is_array
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_chunked_array)(PyObject *) = 0;
+#define pyarrow_is_chunked_array __pyx_api_f_7pyarrow_3lib_pyarrow_is_chunked_array
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_scalar)(PyObject *) = 0;
+#define pyarrow_is_scalar __pyx_api_f_7pyarrow_3lib_pyarrow_is_scalar
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_tensor)(PyObject *) = 0;
+#define pyarrow_is_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_is_tensor
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_coo_tensor)(PyObject *) = 0;
+#define pyarrow_is_sparse_coo_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_coo_tensor
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csr_matrix)(PyObject *) = 0;
+#define pyarrow_is_sparse_csr_matrix __pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csr_matrix
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csc_matrix)(PyObject *) = 0;
+#define pyarrow_is_sparse_csc_matrix __pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csc_matrix
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csf_tensor)(PyObject *) = 0;
+#define pyarrow_is_sparse_csf_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csf_tensor
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_table)(PyObject *) = 0;
+#define pyarrow_is_table __pyx_api_f_7pyarrow_3lib_pyarrow_is_table
+static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_batch)(PyObject *) = 0;
+#define pyarrow_is_batch __pyx_api_f_7pyarrow_3lib_pyarrow_is_batch
+#if !defined(__Pyx_PyIdentifier_FromString)
+#if PY_MAJOR_VERSION < 3
+ #define __Pyx_PyIdentifier_FromString(s) PyString_FromString(s)
+#else
+ #define __Pyx_PyIdentifier_FromString(s) PyUnicode_FromString(s)
+#endif
+#endif
+
+#ifndef __PYX_HAVE_RT_ImportFunction
+#define __PYX_HAVE_RT_ImportFunction
+static int __Pyx_ImportFunction(PyObject *module, const char *funcname, void (**f)(void), const char *sig) {
+ PyObject *d = 0;
+ PyObject *cobj = 0;
+ union {
+ void (*fp)(void);
+ void *p;
+ } tmp;
+ d = PyObject_GetAttrString(module, (char *)"__pyx_capi__");
+ if (!d)
+ goto bad;
+ cobj = PyDict_GetItemString(d, funcname);
+ if (!cobj) {
+ PyErr_Format(PyExc_ImportError,
+ "%.200s does not export expected C function %.200s",
+ PyModule_GetName(module), funcname);
+ goto bad;
+ }
+#if PY_VERSION_HEX >= 0x02070000
+ if (!PyCapsule_IsValid(cobj, sig)) {
+ PyErr_Format(PyExc_TypeError,
+ "C function %.200s.%.200s has wrong signature (expected %.500s, got %.500s)",
+ PyModule_GetName(module), funcname, sig, PyCapsule_GetName(cobj));
+ goto bad;
+ }
+ tmp.p = PyCapsule_GetPointer(cobj, sig);
+#else
+ {const char *desc, *s1, *s2;
+ desc = (const char *)PyCObject_GetDesc(cobj);
+ if (!desc)
+ goto bad;
+ s1 = desc; s2 = sig;
+ while (*s1 != '\0' && *s1 == *s2) { s1++; s2++; }
+ if (*s1 != *s2) {
+ PyErr_Format(PyExc_TypeError,
+ "C function %.200s.%.200s has wrong signature (expected %.500s, got %.500s)",
+ PyModule_GetName(module), funcname, sig, desc);
+ goto bad;
+ }
+ tmp.p = PyCObject_AsVoidPtr(cobj);}
+#endif
+ *f = tmp.fp;
+ if (!(*f))
+ goto bad;
+ Py_DECREF(d);
+ return 0;
+bad:
+ Py_XDECREF(d);
+ return -1;
+}
+#endif
+
+
+static int import_pyarrow__lib(void) {
+ PyObject *module = 0;
+ module = PyImport_ImportModule("pyarrow.lib");
+ if (!module) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_wrap_scalar", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_scalar, "PyObject *(std::shared_ptr< arrow::Scalar> const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_wrap_array", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_array, "PyObject *(std::shared_ptr< arrow::Array> const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_wrap_chunked_array", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_chunked_array, "PyObject *(std::shared_ptr< arrow::ChunkedArray> const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_wrap_batch", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_batch, "PyObject *(std::shared_ptr< arrow::RecordBatch> const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_wrap_buffer", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_buffer, "PyObject *(std::shared_ptr< arrow::Buffer> const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_wrap_data_type", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_data_type, "PyObject *(std::shared_ptr< arrow::DataType> const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_wrap_field", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_field, "PyObject *(std::shared_ptr< arrow::Field> const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_wrap_resizable_buffer", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_resizable_buffer, "PyObject *(std::shared_ptr< arrow::ResizableBuffer> const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_wrap_schema", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_schema, "PyObject *(std::shared_ptr< arrow::Schema> const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_wrap_table", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_table, "PyObject *(std::shared_ptr< arrow::Table> const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_wrap_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_tensor, "PyObject *(std::shared_ptr< arrow::Tensor> const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_wrap_sparse_coo_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_coo_tensor, "PyObject *(std::shared_ptr< arrow::SparseCOOTensor> const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_wrap_sparse_csr_matrix", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csr_matrix, "PyObject *(std::shared_ptr< arrow::SparseCSRMatrix> const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_wrap_sparse_csc_matrix", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csc_matrix, "PyObject *(std::shared_ptr< arrow::SparseCSCMatrix> const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_wrap_sparse_csf_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csf_tensor, "PyObject *(std::shared_ptr< arrow::SparseCSFTensor> const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_unwrap_scalar", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_scalar, "std::shared_ptr< arrow::Scalar> (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_unwrap_array", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_array, "std::shared_ptr< arrow::Array> (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_unwrap_chunked_array", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_chunked_array, "std::shared_ptr< arrow::ChunkedArray> (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_unwrap_batch", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_batch, "std::shared_ptr< arrow::RecordBatch> (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_unwrap_buffer", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_buffer, "std::shared_ptr< arrow::Buffer> (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_unwrap_data_type", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_data_type, "std::shared_ptr< arrow::DataType> (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_unwrap_field", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_field, "std::shared_ptr< arrow::Field> (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_unwrap_schema", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_schema, "std::shared_ptr< arrow::Schema> (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_unwrap_table", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_table, "std::shared_ptr< arrow::Table> (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_unwrap_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_tensor, "std::shared_ptr< arrow::Tensor> (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_unwrap_sparse_coo_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_coo_tensor, "std::shared_ptr< arrow::SparseCOOTensor> (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_unwrap_sparse_csr_matrix", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csr_matrix, "std::shared_ptr< arrow::SparseCSRMatrix> (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_unwrap_sparse_csc_matrix", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csc_matrix, "std::shared_ptr< arrow::SparseCSCMatrix> (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_unwrap_sparse_csf_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csf_tensor, "std::shared_ptr< arrow::SparseCSFTensor> (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_internal_check_status", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_internal_check_status, "int (arrow::Status const &)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_is_buffer", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_buffer, "int (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_is_data_type", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_data_type, "int (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_is_metadata", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_metadata, "int (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_is_field", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_field, "int (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_is_schema", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_schema, "int (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_is_array", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_array, "int (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_is_chunked_array", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_chunked_array, "int (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_is_scalar", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_scalar, "int (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_is_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_tensor, "int (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_is_sparse_coo_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_coo_tensor, "int (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_is_sparse_csr_matrix", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csr_matrix, "int (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_is_sparse_csc_matrix", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csc_matrix, "int (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_is_sparse_csf_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csf_tensor, "int (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_is_table", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_table, "int (PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction(module, "pyarrow_is_batch", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_batch, "int (PyObject *)") < 0) goto bad;
+ Py_DECREF(module); module = 0;
+ return 0;
+ bad:
+ Py_XDECREF(module);
+ return -1;
+}
+
+#endif /* !__PYX_HAVE_API__pyarrow__lib */
diff --git a/src/arrow/cpp/src/arrow/python/pyarrow_lib.h b/src/arrow/cpp/src/arrow/python/pyarrow_lib.h
new file mode 100644
index 000000000..fa5941447
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/pyarrow_lib.h
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// DO NOT EDIT THIS FILE. Update from pyarrow/lib.h after pyarrow build
+
+/* Generated by Cython 0.29.15 */
+
+#ifndef __PYX_HAVE__pyarrow__lib
+#define __PYX_HAVE__pyarrow__lib
+
+#include "Python.h"
+
+#ifndef __PYX_HAVE_API__pyarrow__lib
+
+#ifndef __PYX_EXTERN_C
+ #ifdef __cplusplus
+ #define __PYX_EXTERN_C extern "C"
+ #else
+ #define __PYX_EXTERN_C extern
+ #endif
+#endif
+
+#ifndef DL_IMPORT
+ #define DL_IMPORT(_T) _T
+#endif
+
+__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_scalar(std::shared_ptr< arrow::Scalar> const &);
+__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_array(std::shared_ptr< arrow::Array> const &);
+__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_chunked_array(std::shared_ptr< arrow::ChunkedArray> const &);
+__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_batch(std::shared_ptr< arrow::RecordBatch> const &);
+__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_buffer(std::shared_ptr< arrow::Buffer> const &);
+__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_data_type(std::shared_ptr< arrow::DataType> const &);
+__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_field(std::shared_ptr< arrow::Field> const &);
+__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_resizable_buffer(std::shared_ptr< arrow::ResizableBuffer> const &);
+__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_schema(std::shared_ptr< arrow::Schema> const &);
+__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_table(std::shared_ptr< arrow::Table> const &);
+__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_tensor(std::shared_ptr< arrow::Tensor> const &);
+__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_sparse_coo_tensor(std::shared_ptr< arrow::SparseCOOTensor> const &);
+__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_sparse_csr_matrix(std::shared_ptr< arrow::SparseCSRMatrix> const &);
+__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_sparse_csc_matrix(std::shared_ptr< arrow::SparseCSCMatrix> const &);
+__PYX_EXTERN_C PyObject *__pyx_f_7pyarrow_3lib_pyarrow_wrap_sparse_csf_tensor(std::shared_ptr< arrow::SparseCSFTensor> const &);
+__PYX_EXTERN_C std::shared_ptr< arrow::Scalar> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_scalar(PyObject *);
+__PYX_EXTERN_C std::shared_ptr< arrow::Array> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_array(PyObject *);
+__PYX_EXTERN_C std::shared_ptr< arrow::ChunkedArray> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_chunked_array(PyObject *);
+__PYX_EXTERN_C std::shared_ptr< arrow::RecordBatch> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_batch(PyObject *);
+__PYX_EXTERN_C std::shared_ptr< arrow::Buffer> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_buffer(PyObject *);
+__PYX_EXTERN_C std::shared_ptr< arrow::DataType> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_data_type(PyObject *);
+__PYX_EXTERN_C std::shared_ptr< arrow::Field> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_field(PyObject *);
+__PYX_EXTERN_C std::shared_ptr< arrow::Schema> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_schema(PyObject *);
+__PYX_EXTERN_C std::shared_ptr< arrow::Table> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_table(PyObject *);
+__PYX_EXTERN_C std::shared_ptr< arrow::Tensor> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_tensor(PyObject *);
+__PYX_EXTERN_C std::shared_ptr< arrow::SparseCOOTensor> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_sparse_coo_tensor(PyObject *);
+__PYX_EXTERN_C std::shared_ptr< arrow::SparseCSRMatrix> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csr_matrix(PyObject *);
+__PYX_EXTERN_C std::shared_ptr< arrow::SparseCSCMatrix> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csc_matrix(PyObject *);
+__PYX_EXTERN_C std::shared_ptr< arrow::SparseCSFTensor> __pyx_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csf_tensor(PyObject *);
+
+#endif /* !__PYX_HAVE_API__pyarrow__lib */
+
+/* WARNING: the interface of the module init function changed in CPython 3.5. */
+/* It now returns a PyModuleDef instance instead of a PyModule instance. */
+
+#if PY_MAJOR_VERSION < 3
+PyMODINIT_FUNC initlib(void);
+#else
+PyMODINIT_FUNC PyInit_lib(void);
+#endif
+
+#endif /* !__PYX_HAVE__pyarrow__lib */
diff --git a/src/arrow/cpp/src/arrow/python/python_test.cc b/src/arrow/cpp/src/arrow/python/python_test.cc
new file mode 100644
index 000000000..c465fabc6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/python_test.cc
@@ -0,0 +1,599 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gtest/gtest.h"
+
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include "arrow/python/platform.h"
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/optional.h"
+
+#include "arrow/python/arrow_to_pandas.h"
+#include "arrow/python/decimal.h"
+#include "arrow/python/helpers.h"
+#include "arrow/python/numpy_convert.h"
+#include "arrow/python/numpy_interop.h"
+#include "arrow/python/python_to_arrow.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace py {
+
+TEST(OwnedRef, TestMoves) {
+ std::vector<OwnedRef> vec;
+ PyObject *u, *v;
+ u = PyList_New(0);
+ v = PyList_New(0);
+
+ {
+ OwnedRef ref(u);
+ vec.push_back(std::move(ref));
+ ASSERT_EQ(ref.obj(), nullptr);
+ }
+ vec.emplace_back(v);
+ ASSERT_EQ(Py_REFCNT(u), 1);
+ ASSERT_EQ(Py_REFCNT(v), 1);
+}
+
+TEST(OwnedRefNoGIL, TestMoves) {
+ PyAcquireGIL lock;
+ lock.release();
+
+ {
+ std::vector<OwnedRef> vec;
+ PyObject *u, *v;
+ {
+ lock.acquire();
+ u = PyList_New(0);
+ v = PyList_New(0);
+ lock.release();
+ }
+ {
+ OwnedRefNoGIL ref(u);
+ vec.push_back(std::move(ref));
+ ASSERT_EQ(ref.obj(), nullptr);
+ }
+ vec.emplace_back(v);
+ ASSERT_EQ(Py_REFCNT(u), 1);
+ ASSERT_EQ(Py_REFCNT(v), 1);
+ }
+}
+
+std::string FormatPythonException(const std::string& exc_class_name) {
+ std::stringstream ss;
+ ss << "Python exception: ";
+ ss << exc_class_name;
+ return ss.str();
+}
+
+TEST(CheckPyError, TestStatus) {
+ Status st;
+
+ auto check_error = [](Status& st, const char* expected_message = "some error",
+ std::string expected_detail = "") {
+ st = CheckPyError();
+ ASSERT_EQ(st.message(), expected_message);
+ ASSERT_FALSE(PyErr_Occurred());
+ if (expected_detail.size() > 0) {
+ auto detail = st.detail();
+ ASSERT_NE(detail, nullptr);
+ ASSERT_EQ(detail->ToString(), expected_detail);
+ }
+ };
+
+ for (PyObject* exc_type : {PyExc_Exception, PyExc_SyntaxError}) {
+ PyErr_SetString(exc_type, "some error");
+ check_error(st);
+ ASSERT_TRUE(st.IsUnknownError());
+ }
+
+ PyErr_SetString(PyExc_TypeError, "some error");
+ check_error(st, "some error", FormatPythonException("TypeError"));
+ ASSERT_TRUE(st.IsTypeError());
+
+ PyErr_SetString(PyExc_ValueError, "some error");
+ check_error(st);
+ ASSERT_TRUE(st.IsInvalid());
+
+ PyErr_SetString(PyExc_KeyError, "some error");
+ check_error(st, "'some error'");
+ ASSERT_TRUE(st.IsKeyError());
+
+ for (PyObject* exc_type : {PyExc_OSError, PyExc_IOError}) {
+ PyErr_SetString(exc_type, "some error");
+ check_error(st);
+ ASSERT_TRUE(st.IsIOError());
+ }
+
+ PyErr_SetString(PyExc_NotImplementedError, "some error");
+ check_error(st, "some error", FormatPythonException("NotImplementedError"));
+ ASSERT_TRUE(st.IsNotImplemented());
+
+ // No override if a specific status code is given
+ PyErr_SetString(PyExc_TypeError, "some error");
+ st = CheckPyError(StatusCode::SerializationError);
+ ASSERT_TRUE(st.IsSerializationError());
+ ASSERT_EQ(st.message(), "some error");
+ ASSERT_FALSE(PyErr_Occurred());
+}
+
+TEST(CheckPyError, TestStatusNoGIL) {
+ PyAcquireGIL lock;
+ {
+ Status st;
+ PyErr_SetString(PyExc_ZeroDivisionError, "zzzt");
+ st = ConvertPyError();
+ ASSERT_FALSE(PyErr_Occurred());
+ lock.release();
+ ASSERT_TRUE(st.IsUnknownError());
+ ASSERT_EQ(st.message(), "zzzt");
+ ASSERT_EQ(st.detail()->ToString(), FormatPythonException("ZeroDivisionError"));
+ }
+}
+
+TEST(RestorePyError, Basics) {
+ PyErr_SetString(PyExc_ZeroDivisionError, "zzzt");
+ auto st = ConvertPyError();
+ ASSERT_FALSE(PyErr_Occurred());
+ ASSERT_TRUE(st.IsUnknownError());
+ ASSERT_EQ(st.message(), "zzzt");
+ ASSERT_EQ(st.detail()->ToString(), FormatPythonException("ZeroDivisionError"));
+
+ RestorePyError(st);
+ ASSERT_TRUE(PyErr_Occurred());
+ PyObject* exc_type;
+ PyObject* exc_value;
+ PyObject* exc_traceback;
+ PyErr_Fetch(&exc_type, &exc_value, &exc_traceback);
+ ASSERT_TRUE(PyErr_GivenExceptionMatches(exc_type, PyExc_ZeroDivisionError));
+ std::string py_message;
+ ASSERT_OK(internal::PyObject_StdStringStr(exc_value, &py_message));
+ ASSERT_EQ(py_message, "zzzt");
+}
+
+TEST(PyBuffer, InvalidInputObject) {
+ std::shared_ptr<Buffer> res;
+ PyObject* input = Py_None;
+ auto old_refcnt = Py_REFCNT(input);
+ {
+ Status st = PyBuffer::FromPyObject(input).status();
+ ASSERT_TRUE(IsPyError(st)) << st.ToString();
+ ASSERT_FALSE(PyErr_Occurred());
+ }
+ ASSERT_EQ(old_refcnt, Py_REFCNT(input));
+}
+
+// Because of how it is declared, the Numpy C API instance initialized
+// within libarrow_python.dll may not be visible in this test under Windows
+// ("unresolved external symbol arrow_ARRAY_API referenced").
+#ifndef _WIN32
+TEST(PyBuffer, NumpyArray) {
+ const npy_intp dims[1] = {10};
+
+ OwnedRef arr_ref(PyArray_SimpleNew(1, dims, NPY_FLOAT));
+ PyObject* arr = arr_ref.obj();
+ ASSERT_NE(arr, nullptr);
+ auto old_refcnt = Py_REFCNT(arr);
+
+ ASSERT_OK_AND_ASSIGN(auto buf, PyBuffer::FromPyObject(arr));
+ ASSERT_TRUE(buf->is_cpu());
+ ASSERT_EQ(buf->data(), PyArray_DATA(reinterpret_cast<PyArrayObject*>(arr)));
+ ASSERT_TRUE(buf->is_mutable());
+ ASSERT_EQ(buf->mutable_data(), buf->data());
+ ASSERT_EQ(old_refcnt + 1, Py_REFCNT(arr));
+ buf.reset();
+ ASSERT_EQ(old_refcnt, Py_REFCNT(arr));
+
+ // Read-only
+ PyArray_CLEARFLAGS(reinterpret_cast<PyArrayObject*>(arr), NPY_ARRAY_WRITEABLE);
+ ASSERT_OK_AND_ASSIGN(buf, PyBuffer::FromPyObject(arr));
+ ASSERT_TRUE(buf->is_cpu());
+ ASSERT_EQ(buf->data(), PyArray_DATA(reinterpret_cast<PyArrayObject*>(arr)));
+ ASSERT_FALSE(buf->is_mutable());
+ ASSERT_EQ(old_refcnt + 1, Py_REFCNT(arr));
+ buf.reset();
+ ASSERT_EQ(old_refcnt, Py_REFCNT(arr));
+}
+
+TEST(NumPyBuffer, NumpyArray) {
+ npy_intp dims[1] = {10};
+
+ OwnedRef arr_ref(PyArray_SimpleNew(1, dims, NPY_FLOAT));
+ PyObject* arr = arr_ref.obj();
+ ASSERT_NE(arr, nullptr);
+ auto old_refcnt = Py_REFCNT(arr);
+
+ auto buf = std::make_shared<NumPyBuffer>(arr);
+ ASSERT_TRUE(buf->is_cpu());
+ ASSERT_EQ(buf->data(), PyArray_DATA(reinterpret_cast<PyArrayObject*>(arr)));
+ ASSERT_TRUE(buf->is_mutable());
+ ASSERT_EQ(buf->mutable_data(), buf->data());
+ ASSERT_EQ(old_refcnt + 1, Py_REFCNT(arr));
+ buf.reset();
+ ASSERT_EQ(old_refcnt, Py_REFCNT(arr));
+
+ // Read-only
+ PyArray_CLEARFLAGS(reinterpret_cast<PyArrayObject*>(arr), NPY_ARRAY_WRITEABLE);
+ buf = std::make_shared<NumPyBuffer>(arr);
+ ASSERT_TRUE(buf->is_cpu());
+ ASSERT_EQ(buf->data(), PyArray_DATA(reinterpret_cast<PyArrayObject*>(arr)));
+ ASSERT_FALSE(buf->is_mutable());
+ ASSERT_EQ(old_refcnt + 1, Py_REFCNT(arr));
+ buf.reset();
+ ASSERT_EQ(old_refcnt, Py_REFCNT(arr));
+}
+#endif
+
+class DecimalTest : public ::testing::Test {
+ public:
+ DecimalTest() : lock_(), decimal_constructor_() {
+ OwnedRef decimal_module;
+
+ Status status = internal::ImportModule("decimal", &decimal_module);
+ ARROW_CHECK_OK(status);
+
+ status = internal::ImportFromModule(decimal_module.obj(), "Decimal",
+ &decimal_constructor_);
+ ARROW_CHECK_OK(status);
+ }
+
+ OwnedRef CreatePythonDecimal(const std::string& string_value) {
+ OwnedRef ref(internal::DecimalFromString(decimal_constructor_.obj(), string_value));
+ return ref;
+ }
+
+ PyObject* decimal_constructor() const { return decimal_constructor_.obj(); }
+
+ private:
+ PyAcquireGIL lock_;
+ OwnedRef decimal_constructor_;
+};
+
+TEST_F(DecimalTest, TestPythonDecimalToString) {
+ std::string decimal_string("-39402950693754869342983");
+
+ OwnedRef python_object(this->CreatePythonDecimal(decimal_string));
+ ASSERT_NE(python_object.obj(), nullptr);
+
+ std::string string_result;
+ ASSERT_OK(internal::PythonDecimalToString(python_object.obj(), &string_result));
+}
+
+TEST_F(DecimalTest, TestInferPrecisionAndScale) {
+ std::string decimal_string("-394029506937548693.42983");
+ OwnedRef python_decimal(this->CreatePythonDecimal(decimal_string));
+
+ internal::DecimalMetadata metadata;
+ ASSERT_OK(metadata.Update(python_decimal.obj()));
+
+ const auto expected_precision =
+ static_cast<int32_t>(decimal_string.size() - 2); // 1 for -, 1 for .
+ const int32_t expected_scale = 5;
+
+ ASSERT_EQ(expected_precision, metadata.precision());
+ ASSERT_EQ(expected_scale, metadata.scale());
+}
+
+TEST_F(DecimalTest, TestInferPrecisionAndNegativeScale) {
+ std::string decimal_string("-3.94042983E+10");
+ OwnedRef python_decimal(this->CreatePythonDecimal(decimal_string));
+
+ internal::DecimalMetadata metadata;
+ ASSERT_OK(metadata.Update(python_decimal.obj()));
+
+ const auto expected_precision = 11;
+ const int32_t expected_scale = 0;
+
+ ASSERT_EQ(expected_precision, metadata.precision());
+ ASSERT_EQ(expected_scale, metadata.scale());
+}
+
+TEST_F(DecimalTest, TestInferAllLeadingZeros) {
+ std::string decimal_string("0.001");
+ OwnedRef python_decimal(this->CreatePythonDecimal(decimal_string));
+
+ internal::DecimalMetadata metadata;
+ ASSERT_OK(metadata.Update(python_decimal.obj()));
+ ASSERT_EQ(3, metadata.precision());
+ ASSERT_EQ(3, metadata.scale());
+}
+
+TEST_F(DecimalTest, TestInferAllLeadingZerosExponentialNotationPositive) {
+ std::string decimal_string("0.01E5");
+ OwnedRef python_decimal(this->CreatePythonDecimal(decimal_string));
+ internal::DecimalMetadata metadata;
+ ASSERT_OK(metadata.Update(python_decimal.obj()));
+ ASSERT_EQ(4, metadata.precision());
+ ASSERT_EQ(0, metadata.scale());
+}
+
+TEST_F(DecimalTest, TestInferAllLeadingZerosExponentialNotationNegative) {
+ std::string decimal_string("0.01E3");
+ OwnedRef python_decimal(this->CreatePythonDecimal(decimal_string));
+ internal::DecimalMetadata metadata;
+ ASSERT_OK(metadata.Update(python_decimal.obj()));
+ ASSERT_EQ(2, metadata.precision());
+ ASSERT_EQ(0, metadata.scale());
+}
+
+TEST(PandasConversionTest, TestObjectBlockWriteFails) {
+ StringBuilder builder;
+ const char value[] = {'\xf1', '\0'};
+
+ for (int i = 0; i < 1000; ++i) {
+ ASSERT_OK(builder.Append(value, static_cast<int32_t>(strlen(value))));
+ }
+
+ std::shared_ptr<Array> arr;
+ ASSERT_OK(builder.Finish(&arr));
+
+ auto f1 = field("f1", utf8());
+ auto f2 = field("f2", utf8());
+ auto f3 = field("f3", utf8());
+ std::vector<std::shared_ptr<Field>> fields = {f1, f2, f3};
+ std::vector<std::shared_ptr<Array>> cols = {arr, arr, arr};
+
+ auto schema = ::arrow::schema(fields);
+ auto table = Table::Make(schema, cols);
+
+ Status st;
+ Py_BEGIN_ALLOW_THREADS;
+ PyObject* out;
+ PandasOptions options;
+ options.use_threads = true;
+ st = ConvertTableToPandas(options, table, &out);
+ Py_END_ALLOW_THREADS;
+ ASSERT_RAISES(UnknownError, st);
+}
+
+TEST(BuiltinConversionTest, TestMixedTypeFails) {
+ OwnedRef list_ref(PyList_New(3));
+ PyObject* list = list_ref.obj();
+
+ ASSERT_NE(list, nullptr);
+
+ PyObject* str = PyUnicode_FromString("abc");
+ ASSERT_NE(str, nullptr);
+
+ PyObject* integer = PyLong_FromLong(1234L);
+ ASSERT_NE(integer, nullptr);
+
+ PyObject* doub = PyFloat_FromDouble(123.0234);
+ ASSERT_NE(doub, nullptr);
+
+ // This steals a reference to each object, so we don't need to decref them later
+ // just the list
+ ASSERT_EQ(PyList_SetItem(list, 0, str), 0);
+ ASSERT_EQ(PyList_SetItem(list, 1, integer), 0);
+ ASSERT_EQ(PyList_SetItem(list, 2, doub), 0);
+
+ ASSERT_RAISES(TypeError, ConvertPySequence(list, nullptr, {}));
+}
+
+template <typename DecimalValue>
+void DecimalTestFromPythonDecimalRescale(std::shared_ptr<DataType> type,
+ OwnedRef python_decimal,
+ ::arrow::util::optional<int> expected) {
+ DecimalValue value;
+ const auto& decimal_type = checked_cast<const DecimalType&>(*type);
+
+ if (expected.has_value()) {
+ ASSERT_OK(
+ internal::DecimalFromPythonDecimal(python_decimal.obj(), decimal_type, &value));
+ ASSERT_EQ(expected.value(), value);
+
+ ASSERT_OK(internal::DecimalFromPyObject(python_decimal.obj(), decimal_type, &value));
+ ASSERT_EQ(expected.value(), value);
+ } else {
+ ASSERT_RAISES(Invalid, internal::DecimalFromPythonDecimal(python_decimal.obj(),
+ decimal_type, &value));
+ ASSERT_RAISES(Invalid, internal::DecimalFromPyObject(python_decimal.obj(),
+ decimal_type, &value));
+ }
+}
+
+TEST_F(DecimalTest, FromPythonDecimalRescaleNotTruncateable) {
+ // We fail when truncating values that would lose data if cast to a decimal type with
+ // lower scale
+ DecimalTestFromPythonDecimalRescale<Decimal128>(::arrow::decimal128(10, 2),
+ this->CreatePythonDecimal("1.001"), {});
+ DecimalTestFromPythonDecimalRescale<Decimal256>(::arrow::decimal256(10, 2),
+ this->CreatePythonDecimal("1.001"), {});
+}
+
+TEST_F(DecimalTest, FromPythonDecimalRescaleTruncateable) {
+ // We allow truncation of values that do not lose precision when dividing by 10 * the
+ // difference between the scales, e.g., 1.000 -> 1.00
+ DecimalTestFromPythonDecimalRescale<Decimal128>(
+ ::arrow::decimal128(10, 2), this->CreatePythonDecimal("1.000"), 100);
+ DecimalTestFromPythonDecimalRescale<Decimal256>(
+ ::arrow::decimal256(10, 2), this->CreatePythonDecimal("1.000"), 100);
+}
+
+TEST_F(DecimalTest, FromPythonNegativeDecimalRescale) {
+ DecimalTestFromPythonDecimalRescale<Decimal128>(
+ ::arrow::decimal128(10, 9), this->CreatePythonDecimal("-1.000"), -1000000000);
+ DecimalTestFromPythonDecimalRescale<Decimal256>(
+ ::arrow::decimal256(10, 9), this->CreatePythonDecimal("-1.000"), -1000000000);
+}
+
+TEST_F(DecimalTest, Decimal128FromPythonInteger) {
+ Decimal128 value;
+ OwnedRef python_long(PyLong_FromLong(42));
+ auto type = ::arrow::decimal128(10, 2);
+ const auto& decimal_type = checked_cast<const DecimalType&>(*type);
+ ASSERT_OK(internal::DecimalFromPyObject(python_long.obj(), decimal_type, &value));
+ ASSERT_EQ(4200, value);
+}
+
+TEST_F(DecimalTest, Decimal256FromPythonInteger) {
+ Decimal256 value;
+ OwnedRef python_long(PyLong_FromLong(42));
+ auto type = ::arrow::decimal256(10, 2);
+ const auto& decimal_type = checked_cast<const DecimalType&>(*type);
+ ASSERT_OK(internal::DecimalFromPyObject(python_long.obj(), decimal_type, &value));
+ ASSERT_EQ(4200, value);
+}
+
+TEST_F(DecimalTest, TestDecimal128OverflowFails) {
+ Decimal128 value;
+ OwnedRef python_decimal(
+ this->CreatePythonDecimal("9999999999999999999999999999999999999.9"));
+ internal::DecimalMetadata metadata;
+ ASSERT_OK(metadata.Update(python_decimal.obj()));
+ ASSERT_EQ(38, metadata.precision());
+ ASSERT_EQ(1, metadata.scale());
+
+ auto type = ::arrow::decimal(38, 38);
+ const auto& decimal_type = checked_cast<const DecimalType&>(*type);
+ ASSERT_RAISES(Invalid, internal::DecimalFromPythonDecimal(python_decimal.obj(),
+ decimal_type, &value));
+}
+
+TEST_F(DecimalTest, TestDecimal256OverflowFails) {
+ Decimal256 value;
+ OwnedRef python_decimal(this->CreatePythonDecimal(
+ "999999999999999999999999999999999999999999999999999999999999999999999999999.9"));
+ internal::DecimalMetadata metadata;
+ ASSERT_OK(metadata.Update(python_decimal.obj()));
+ ASSERT_EQ(76, metadata.precision());
+ ASSERT_EQ(1, metadata.scale());
+
+ auto type = ::arrow::decimal(76, 76);
+ const auto& decimal_type = checked_cast<const DecimalType&>(*type);
+ ASSERT_RAISES(Invalid, internal::DecimalFromPythonDecimal(python_decimal.obj(),
+ decimal_type, &value));
+}
+
+TEST_F(DecimalTest, TestNoneAndNaN) {
+ OwnedRef list_ref(PyList_New(4));
+ PyObject* list = list_ref.obj();
+
+ ASSERT_NE(list, nullptr);
+
+ PyObject* constructor = this->decimal_constructor();
+ PyObject* decimal_value = internal::DecimalFromString(constructor, "1.234");
+ ASSERT_NE(decimal_value, nullptr);
+
+ Py_INCREF(Py_None);
+ PyObject* missing_value1 = Py_None;
+ ASSERT_NE(missing_value1, nullptr);
+
+ PyObject* missing_value2 = PyFloat_FromDouble(NPY_NAN);
+ ASSERT_NE(missing_value2, nullptr);
+
+ PyObject* missing_value3 = internal::DecimalFromString(constructor, "nan");
+ ASSERT_NE(missing_value3, nullptr);
+
+ // This steals a reference to each object, so we don't need to decref them later,
+ // just the list
+ ASSERT_EQ(0, PyList_SetItem(list, 0, decimal_value));
+ ASSERT_EQ(0, PyList_SetItem(list, 1, missing_value1));
+ ASSERT_EQ(0, PyList_SetItem(list, 2, missing_value2));
+ ASSERT_EQ(0, PyList_SetItem(list, 3, missing_value3));
+
+ PyConversionOptions options;
+ ASSERT_RAISES(TypeError, ConvertPySequence(list, nullptr, options));
+
+ options.from_pandas = true;
+ ASSERT_OK_AND_ASSIGN(auto chunked, ConvertPySequence(list, nullptr, options));
+ ASSERT_EQ(chunked->num_chunks(), 1);
+
+ auto arr = chunked->chunk(0);
+ ASSERT_TRUE(arr->IsValid(0));
+ ASSERT_TRUE(arr->IsNull(1));
+ ASSERT_TRUE(arr->IsNull(2));
+ ASSERT_TRUE(arr->IsNull(3));
+}
+
+TEST_F(DecimalTest, TestMixedPrecisionAndScale) {
+ std::vector<std::string> strings{{"0.001", "1.01E5", "1.01E5"}};
+
+ OwnedRef list_ref(PyList_New(static_cast<Py_ssize_t>(strings.size())));
+ PyObject* list = list_ref.obj();
+
+ ASSERT_NE(list, nullptr);
+
+ // PyList_SetItem steals a reference to the item so we don't decref it later
+ PyObject* decimal_constructor = this->decimal_constructor();
+ for (Py_ssize_t i = 0; i < static_cast<Py_ssize_t>(strings.size()); ++i) {
+ const int result = PyList_SetItem(
+ list, i, internal::DecimalFromString(decimal_constructor, strings.at(i)));
+ ASSERT_EQ(0, result);
+ }
+
+ ASSERT_OK_AND_ASSIGN(auto arr, ConvertPySequence(list, nullptr, {}))
+ const auto& type = checked_cast<const DecimalType&>(*arr->type());
+
+ int32_t expected_precision = 9;
+ int32_t expected_scale = 3;
+ ASSERT_EQ(expected_precision, type.precision());
+ ASSERT_EQ(expected_scale, type.scale());
+}
+
+TEST_F(DecimalTest, TestMixedPrecisionAndScaleSequenceConvert) {
+ PyObject* value1 = this->CreatePythonDecimal("0.01").detach();
+ ASSERT_NE(value1, nullptr);
+
+ PyObject* value2 = this->CreatePythonDecimal("0.001").detach();
+ ASSERT_NE(value2, nullptr);
+
+ OwnedRef list_ref(PyList_New(2));
+ PyObject* list = list_ref.obj();
+
+ // This steals a reference to each object, so we don't need to decref them later
+ // just the list
+ ASSERT_EQ(PyList_SetItem(list, 0, value1), 0);
+ ASSERT_EQ(PyList_SetItem(list, 1, value2), 0);
+
+ ASSERT_OK_AND_ASSIGN(auto arr, ConvertPySequence(list, nullptr, {}));
+ const auto& type = checked_cast<const Decimal128Type&>(*arr->type());
+ ASSERT_EQ(3, type.precision());
+ ASSERT_EQ(3, type.scale());
+}
+
+TEST_F(DecimalTest, SimpleInference) {
+ OwnedRef value(this->CreatePythonDecimal("0.01"));
+ ASSERT_NE(value.obj(), nullptr);
+ internal::DecimalMetadata metadata;
+ ASSERT_OK(metadata.Update(value.obj()));
+ ASSERT_EQ(2, metadata.precision());
+ ASSERT_EQ(2, metadata.scale());
+}
+
+TEST_F(DecimalTest, UpdateWithNaN) {
+ internal::DecimalMetadata metadata;
+ OwnedRef nan_value(this->CreatePythonDecimal("nan"));
+ ASSERT_OK(metadata.Update(nan_value.obj()));
+ ASSERT_EQ(std::numeric_limits<int32_t>::min(), metadata.precision());
+ ASSERT_EQ(std::numeric_limits<int32_t>::min(), metadata.scale());
+}
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/python_to_arrow.cc b/src/arrow/cpp/src/arrow/python/python_to_arrow.cc
new file mode 100644
index 000000000..10250d165
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/python_to_arrow.cc
@@ -0,0 +1,1179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/python/python_to_arrow.h"
+#include "arrow/python/numpy_interop.h"
+
+#include <datetime.h>
+
+#include <algorithm>
+#include <limits>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_decimal.h"
+#include "arrow/array/builder_dict.h"
+#include "arrow/array/builder_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/builder_time.h"
+#include "arrow/chunked_array.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/converter.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/logging.h"
+
+#include "arrow/python/datetime.h"
+#include "arrow/python/decimal.h"
+#include "arrow/python/helpers.h"
+#include "arrow/python/inference.h"
+#include "arrow/python/iterators.h"
+#include "arrow/python/numpy_convert.h"
+#include "arrow/python/type_traits.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+using internal::Converter;
+using internal::DictionaryConverter;
+using internal::ListConverter;
+using internal::PrimitiveConverter;
+using internal::StructConverter;
+
+using internal::MakeChunker;
+using internal::MakeConverter;
+
+namespace py {
+
+namespace {
+enum class MonthDayNanoField { kMonths, kWeeksAndDays, kDaysOnly, kNanoseconds };
+
+template <MonthDayNanoField field>
+struct MonthDayNanoTraits;
+
+struct MonthDayNanoAttrData {
+ const char* name;
+ const int64_t multiplier;
+};
+
+template <>
+struct MonthDayNanoTraits<MonthDayNanoField::kMonths> {
+ using c_type = int32_t;
+ static const MonthDayNanoAttrData attrs[];
+};
+
+const MonthDayNanoAttrData MonthDayNanoTraits<MonthDayNanoField::kMonths>::attrs[] = {
+ {"years", 1}, {"months", /*months_in_year=*/12}, {nullptr, 0}};
+
+template <>
+struct MonthDayNanoTraits<MonthDayNanoField::kWeeksAndDays> {
+ using c_type = int32_t;
+ static const MonthDayNanoAttrData attrs[];
+};
+
+const MonthDayNanoAttrData MonthDayNanoTraits<MonthDayNanoField::kWeeksAndDays>::attrs[] =
+ {{"weeks", 1}, {"days", /*days_in_week=*/7}, {nullptr, 0}};
+
+template <>
+struct MonthDayNanoTraits<MonthDayNanoField::kDaysOnly> {
+ using c_type = int32_t;
+ static const MonthDayNanoAttrData attrs[];
+};
+
+const MonthDayNanoAttrData MonthDayNanoTraits<MonthDayNanoField::kDaysOnly>::attrs[] = {
+ {"days", 1}, {nullptr, 0}};
+
+template <>
+struct MonthDayNanoTraits<MonthDayNanoField::kNanoseconds> {
+ using c_type = int64_t;
+ static const MonthDayNanoAttrData attrs[];
+};
+
+const MonthDayNanoAttrData MonthDayNanoTraits<MonthDayNanoField::kNanoseconds>::attrs[] =
+ {{"hours", 1},
+ {"minutes", /*minutes_in_hours=*/60},
+ {"seconds", /*seconds_in_minute=*/60},
+ {"milliseconds", /*milliseconds_in_seconds*/ 1000},
+ {"microseconds", /*microseconds_in_millseconds=*/1000},
+ {"nanoseconds", /*nanoseconds_in_microseconds=*/1000},
+ {nullptr, 0}};
+
+template <MonthDayNanoField field>
+struct PopulateMonthDayNano {
+ using Traits = MonthDayNanoTraits<field>;
+ using field_c_type = typename Traits::c_type;
+
+ static Status Field(PyObject* obj, field_c_type* out, bool* found_attrs) {
+ *out = 0;
+ for (const MonthDayNanoAttrData* attr = &Traits::attrs[0]; attr->multiplier != 0;
+ ++attr) {
+ if (attr->multiplier != 1 &&
+ ::arrow::internal::MultiplyWithOverflow(
+ static_cast<field_c_type>(attr->multiplier), *out, out)) {
+ return Status::Invalid("Overflow on: ", (attr - 1)->name,
+ " for: ", internal::PyObject_StdStringRepr(obj));
+ }
+
+ OwnedRef field_value(PyObject_GetAttrString(obj, attr->name));
+ if (field_value.obj() == nullptr) {
+ // No attribute present, skip to the next one.
+ PyErr_Clear();
+ continue;
+ }
+ RETURN_IF_PYERROR();
+ *found_attrs = true;
+ field_c_type value;
+ RETURN_NOT_OK(internal::CIntFromPython(field_value.obj(), &value, attr->name));
+ if (::arrow::internal::AddWithOverflow(*out, value, out)) {
+ return Status::Invalid("Overflow on: ", attr->name,
+ " for: ", internal::PyObject_StdStringRepr(obj));
+ }
+ }
+
+ return Status::OK();
+ }
+};
+
+// Utility for converting single python objects to their intermediate C representations
+// which can be fed to the typed builders
+class PyValue {
+ public:
+ // Type aliases for shorter signature definitions
+ using I = PyObject*;
+ using O = PyConversionOptions;
+
+ // Used for null checking before actually converting the values
+ static bool IsNull(const O& options, I obj) {
+ if (options.from_pandas) {
+ return internal::PandasObjectIsNull(obj);
+ } else {
+ return obj == Py_None;
+ }
+ }
+
+ // Used for post-conversion numpy NaT sentinel checking
+ static bool IsNaT(const TimestampType*, int64_t value) {
+ return internal::npy_traits<NPY_DATETIME>::isnull(value);
+ }
+
+ // Used for post-conversion numpy NaT sentinel checking
+ static bool IsNaT(const DurationType*, int64_t value) {
+ return internal::npy_traits<NPY_TIMEDELTA>::isnull(value);
+ }
+
+ static Result<std::nullptr_t> Convert(const NullType*, const O&, I obj) {
+ if (obj == Py_None) {
+ return nullptr;
+ } else {
+ return Status::Invalid("Invalid null value");
+ }
+ }
+
+ static Result<bool> Convert(const BooleanType*, const O&, I obj) {
+ if (obj == Py_True) {
+ return true;
+ } else if (obj == Py_False) {
+ return false;
+ } else if (PyArray_IsScalar(obj, Bool)) {
+ return reinterpret_cast<PyBoolScalarObject*>(obj)->obval == NPY_TRUE;
+ } else {
+ return internal::InvalidValue(obj, "tried to convert to boolean");
+ }
+ }
+
+ template <typename T>
+ static enable_if_integer<T, Result<typename T::c_type>> Convert(const T* type, const O&,
+ I obj) {
+ typename T::c_type value;
+ auto status = internal::CIntFromPython(obj, &value);
+ if (ARROW_PREDICT_TRUE(status.ok())) {
+ return value;
+ } else if (!internal::PyIntScalar_Check(obj)) {
+ std::stringstream ss;
+ ss << "tried to convert to " << type->ToString();
+ return internal::InvalidValue(obj, ss.str());
+ } else {
+ return status;
+ }
+ }
+
+ static Result<uint16_t> Convert(const HalfFloatType*, const O&, I obj) {
+ uint16_t value;
+ RETURN_NOT_OK(PyFloat_AsHalf(obj, &value));
+ return value;
+ }
+
+ static Result<float> Convert(const FloatType*, const O&, I obj) {
+ float value;
+ if (internal::PyFloatScalar_Check(obj)) {
+ value = static_cast<float>(PyFloat_AsDouble(obj));
+ RETURN_IF_PYERROR();
+ } else if (internal::PyIntScalar_Check(obj)) {
+ RETURN_NOT_OK(internal::IntegerScalarToFloat32Safe(obj, &value));
+ } else {
+ return internal::InvalidValue(obj, "tried to convert to float32");
+ }
+ return value;
+ }
+
+ static Result<double> Convert(const DoubleType*, const O&, I obj) {
+ double value;
+ if (PyFloat_Check(obj)) {
+ value = PyFloat_AS_DOUBLE(obj);
+ } else if (internal::PyFloatScalar_Check(obj)) {
+ // Other kinds of float-y things
+ value = PyFloat_AsDouble(obj);
+ RETURN_IF_PYERROR();
+ } else if (internal::PyIntScalar_Check(obj)) {
+ RETURN_NOT_OK(internal::IntegerScalarToDoubleSafe(obj, &value));
+ } else {
+ return internal::InvalidValue(obj, "tried to convert to double");
+ }
+ return value;
+ }
+
+ static Result<Decimal128> Convert(const Decimal128Type* type, const O&, I obj) {
+ Decimal128 value;
+ RETURN_NOT_OK(internal::DecimalFromPyObject(obj, *type, &value));
+ return value;
+ }
+
+ static Result<Decimal256> Convert(const Decimal256Type* type, const O&, I obj) {
+ Decimal256 value;
+ RETURN_NOT_OK(internal::DecimalFromPyObject(obj, *type, &value));
+ return value;
+ }
+
+ static Result<int32_t> Convert(const Date32Type*, const O&, I obj) {
+ int32_t value;
+ if (PyDate_Check(obj)) {
+ auto pydate = reinterpret_cast<PyDateTime_Date*>(obj);
+ value = static_cast<int32_t>(internal::PyDate_to_days(pydate));
+ } else {
+ RETURN_NOT_OK(
+ internal::CIntFromPython(obj, &value, "Integer too large for date32"));
+ }
+ return value;
+ }
+
+ static Result<int64_t> Convert(const Date64Type*, const O&, I obj) {
+ int64_t value;
+ if (PyDateTime_Check(obj)) {
+ auto pydate = reinterpret_cast<PyDateTime_DateTime*>(obj);
+ value = internal::PyDateTime_to_ms(pydate);
+ // Truncate any intraday milliseconds
+ // TODO: introduce an option for this
+ value -= value % 86400000LL;
+ } else if (PyDate_Check(obj)) {
+ auto pydate = reinterpret_cast<PyDateTime_Date*>(obj);
+ value = internal::PyDate_to_ms(pydate);
+ } else {
+ RETURN_NOT_OK(
+ internal::CIntFromPython(obj, &value, "Integer too large for date64"));
+ }
+ return value;
+ }
+
+ static Result<int32_t> Convert(const Time32Type* type, const O&, I obj) {
+ int32_t value;
+ if (PyTime_Check(obj)) {
+ switch (type->unit()) {
+ case TimeUnit::SECOND:
+ value = static_cast<int32_t>(internal::PyTime_to_s(obj));
+ break;
+ case TimeUnit::MILLI:
+ value = static_cast<int32_t>(internal::PyTime_to_ms(obj));
+ break;
+ default:
+ return Status::UnknownError("Invalid time unit");
+ }
+ } else {
+ RETURN_NOT_OK(internal::CIntFromPython(obj, &value, "Integer too large for int32"));
+ }
+ return value;
+ }
+
+ static Result<int64_t> Convert(const Time64Type* type, const O&, I obj) {
+ int64_t value;
+ if (PyTime_Check(obj)) {
+ switch (type->unit()) {
+ case TimeUnit::MICRO:
+ value = internal::PyTime_to_us(obj);
+ break;
+ case TimeUnit::NANO:
+ value = internal::PyTime_to_ns(obj);
+ break;
+ default:
+ return Status::UnknownError("Invalid time unit");
+ }
+ } else {
+ RETURN_NOT_OK(internal::CIntFromPython(obj, &value, "Integer too large for int64"));
+ }
+ return value;
+ }
+
+ static Result<int64_t> Convert(const TimestampType* type, const O& options, I obj) {
+ int64_t value, offset;
+ if (PyDateTime_Check(obj)) {
+ if (ARROW_PREDICT_FALSE(options.ignore_timezone)) {
+ offset = 0;
+ } else {
+ ARROW_ASSIGN_OR_RAISE(offset, internal::PyDateTime_utcoffset_s(obj));
+ }
+ auto dt = reinterpret_cast<PyDateTime_DateTime*>(obj);
+ switch (type->unit()) {
+ case TimeUnit::SECOND:
+ value = internal::PyDateTime_to_s(dt) - offset;
+ break;
+ case TimeUnit::MILLI:
+ value = internal::PyDateTime_to_ms(dt) - offset * 1000LL;
+ break;
+ case TimeUnit::MICRO:
+ value = internal::PyDateTime_to_us(dt) - offset * 1000000LL;
+ break;
+ case TimeUnit::NANO:
+ if (internal::IsPandasTimestamp(obj)) {
+ // pd.Timestamp value attribute contains the offset from unix epoch
+ // so no adjustment for timezone is need.
+ OwnedRef nanos(PyObject_GetAttrString(obj, "value"));
+ RETURN_IF_PYERROR();
+ RETURN_NOT_OK(internal::CIntFromPython(nanos.obj(), &value));
+ } else {
+ // Conversion to nanoseconds can overflow -> check multiply of microseconds
+ value = internal::PyDateTime_to_us(dt);
+ if (arrow::internal::MultiplyWithOverflow(value, 1000LL, &value)) {
+ return internal::InvalidValue(obj,
+ "out of bounds for nanosecond resolution");
+ }
+
+ // Adjust with offset and check for overflow
+ if (arrow::internal::SubtractWithOverflow(value, offset * 1000000000LL,
+ &value)) {
+ return internal::InvalidValue(obj,
+ "out of bounds for nanosecond resolution");
+ }
+ }
+ break;
+ default:
+ return Status::UnknownError("Invalid time unit");
+ }
+ } else if (PyArray_CheckAnyScalarExact(obj)) {
+ // validate that the numpy scalar has np.datetime64 dtype
+ std::shared_ptr<DataType> numpy_type;
+ RETURN_NOT_OK(NumPyDtypeToArrow(PyArray_DescrFromScalar(obj), &numpy_type));
+ if (!numpy_type->Equals(*type)) {
+ return Status::NotImplemented("Expected np.datetime64 but got: ",
+ numpy_type->ToString());
+ }
+ return reinterpret_cast<PyDatetimeScalarObject*>(obj)->obval;
+ } else {
+ RETURN_NOT_OK(internal::CIntFromPython(obj, &value));
+ }
+ return value;
+ }
+
+ static Result<MonthDayNanoIntervalType::MonthDayNanos> Convert(
+ const MonthDayNanoIntervalType* /*type*/, const O& /*options*/, I obj) {
+ MonthDayNanoIntervalType::MonthDayNanos output;
+ bool found_attrs = false;
+ RETURN_NOT_OK(PopulateMonthDayNano<MonthDayNanoField::kMonths>::Field(
+ obj, &output.months, &found_attrs));
+ // on relativeoffset weeks is a property calculated from days. On
+ // DateOffset is is a field on its own. timedelta doesn't have a weeks
+ // attribute.
+ PyObject* pandas_date_offset_type = internal::BorrowPandasDataOffsetType();
+ bool is_date_offset = pandas_date_offset_type == (PyObject*)Py_TYPE(obj);
+ if (!is_date_offset) {
+ RETURN_NOT_OK(PopulateMonthDayNano<MonthDayNanoField::kDaysOnly>::Field(
+ obj, &output.days, &found_attrs));
+ } else {
+ RETURN_NOT_OK(PopulateMonthDayNano<MonthDayNanoField::kWeeksAndDays>::Field(
+ obj, &output.days, &found_attrs));
+ }
+ RETURN_NOT_OK(PopulateMonthDayNano<MonthDayNanoField::kNanoseconds>::Field(
+ obj, &output.nanoseconds, &found_attrs));
+
+ if (ARROW_PREDICT_FALSE(!found_attrs) && !is_date_offset) {
+ // date_offset can have zero fields.
+ return Status::TypeError("No temporal attributes found on object.");
+ }
+ return output;
+ }
+
+ static Result<int64_t> Convert(const DurationType* type, const O&, I obj) {
+ int64_t value;
+ if (PyDelta_Check(obj)) {
+ auto dt = reinterpret_cast<PyDateTime_Delta*>(obj);
+ switch (type->unit()) {
+ case TimeUnit::SECOND:
+ value = internal::PyDelta_to_s(dt);
+ break;
+ case TimeUnit::MILLI:
+ value = internal::PyDelta_to_ms(dt);
+ break;
+ case TimeUnit::MICRO:
+ value = internal::PyDelta_to_us(dt);
+ break;
+ case TimeUnit::NANO:
+ if (internal::IsPandasTimedelta(obj)) {
+ OwnedRef nanos(PyObject_GetAttrString(obj, "value"));
+ RETURN_IF_PYERROR();
+ RETURN_NOT_OK(internal::CIntFromPython(nanos.obj(), &value));
+ } else {
+ value = internal::PyDelta_to_ns(dt);
+ }
+ break;
+ default:
+ return Status::UnknownError("Invalid time unit");
+ }
+ } else if (PyArray_CheckAnyScalarExact(obj)) {
+ // validate that the numpy scalar has np.datetime64 dtype
+ std::shared_ptr<DataType> numpy_type;
+ RETURN_NOT_OK(NumPyDtypeToArrow(PyArray_DescrFromScalar(obj), &numpy_type));
+ if (!numpy_type->Equals(*type)) {
+ return Status::NotImplemented("Expected np.timedelta64 but got: ",
+ numpy_type->ToString());
+ }
+ return reinterpret_cast<PyTimedeltaScalarObject*>(obj)->obval;
+ } else {
+ RETURN_NOT_OK(internal::CIntFromPython(obj, &value));
+ }
+ return value;
+ }
+
+ // The binary-like intermediate representation is PyBytesView because it keeps temporary
+ // python objects alive (non-contiguous memoryview) and stores whether the original
+ // object was unicode encoded or not, which is used for unicode -> bytes coersion if
+ // there is a non-unicode object observed.
+
+ static Status Convert(const BaseBinaryType*, const O&, I obj, PyBytesView& view) {
+ return view.ParseString(obj);
+ }
+
+ static Status Convert(const FixedSizeBinaryType* type, const O&, I obj,
+ PyBytesView& view) {
+ ARROW_RETURN_NOT_OK(view.ParseString(obj));
+ if (view.size != type->byte_width()) {
+ std::stringstream ss;
+ ss << "expected to be length " << type->byte_width() << " was " << view.size;
+ return internal::InvalidValue(obj, ss.str());
+ } else {
+ return Status::OK();
+ }
+ }
+
+ template <typename T>
+ static enable_if_string<T, Status> Convert(const T*, const O& options, I obj,
+ PyBytesView& view) {
+ if (options.strict) {
+ // Strict conversion, force output to be unicode / utf8 and validate that
+ // any binary values are utf8
+ ARROW_RETURN_NOT_OK(view.ParseString(obj, true));
+ if (!view.is_utf8) {
+ return internal::InvalidValue(obj, "was not a utf8 string");
+ }
+ return Status::OK();
+ } else {
+ // Non-strict conversion; keep track of whether values are unicode or bytes
+ return view.ParseString(obj);
+ }
+ }
+
+ static Result<bool> Convert(const DataType* type, const O&, I obj) {
+ return Status::NotImplemented("PyValue::Convert is not implemented for type ", type);
+ }
+};
+
+// The base Converter class is a mixin with predefined behavior and constructors.
+class PyConverter : public Converter<PyObject*, PyConversionOptions> {
+ public:
+ // Iterate over the input values and defer the conversion to the Append method
+ Status Extend(PyObject* values, int64_t size, int64_t offset = 0) override {
+ DCHECK_GE(size, offset);
+ /// Ensure we've allocated enough space
+ RETURN_NOT_OK(this->Reserve(size - offset));
+ // Iterate over the items adding each one
+ return internal::VisitSequence(
+ values, offset,
+ [this](PyObject* item, bool* /* unused */) { return this->Append(item); });
+ }
+
+ // Convert and append a sequence of values masked with a numpy array
+ Status ExtendMasked(PyObject* values, PyObject* mask, int64_t size,
+ int64_t offset = 0) override {
+ DCHECK_GE(size, offset);
+ /// Ensure we've allocated enough space
+ RETURN_NOT_OK(this->Reserve(size - offset));
+ // Iterate over the items adding each one
+ return internal::VisitSequenceMasked(
+ values, mask, offset, [this](PyObject* item, bool is_masked, bool* /* unused */) {
+ if (is_masked) {
+ return this->AppendNull();
+ } else {
+ // This will also apply the null-checking convention in the event
+ // that the value is not masked
+ return this->Append(item); // perhaps use AppendValue instead?
+ }
+ });
+ }
+};
+
+template <typename T, typename Enable = void>
+class PyPrimitiveConverter;
+
+template <typename T>
+class PyListConverter;
+
+template <typename U, typename Enable = void>
+class PyDictionaryConverter;
+
+class PyStructConverter;
+
+template <typename T, typename Enable = void>
+struct PyConverterTrait;
+
+template <typename T>
+struct PyConverterTrait<
+ T, enable_if_t<(!is_nested_type<T>::value && !is_interval_type<T>::value &&
+ !is_extension_type<T>::value) ||
+ std::is_same<T, MonthDayNanoIntervalType>::value>> {
+ using type = PyPrimitiveConverter<T>;
+};
+
+template <typename T>
+struct PyConverterTrait<T, enable_if_list_like<T>> {
+ using type = PyListConverter<T>;
+};
+
+template <>
+struct PyConverterTrait<StructType> {
+ using type = PyStructConverter;
+};
+
+template <>
+struct PyConverterTrait<DictionaryType> {
+ template <typename T>
+ using dictionary_type = PyDictionaryConverter<T>;
+};
+
+template <typename T>
+class PyPrimitiveConverter<T, enable_if_null<T>>
+ : public PrimitiveConverter<T, PyConverter> {
+ public:
+ Status Append(PyObject* value) override {
+ if (PyValue::IsNull(this->options_, value)) {
+ return this->primitive_builder_->AppendNull();
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
+ return this->primitive_builder_->Append(converted);
+ }
+ }
+};
+
+template <typename T>
+class PyPrimitiveConverter<
+ T, enable_if_t<is_boolean_type<T>::value || is_number_type<T>::value ||
+ is_decimal_type<T>::value || is_date_type<T>::value ||
+ is_time_type<T>::value ||
+ std::is_same<MonthDayNanoIntervalType, T>::value>>
+ : public PrimitiveConverter<T, PyConverter> {
+ public:
+ Status Append(PyObject* value) override {
+ // Since the required space has been already allocated in the Extend functions we can
+ // rely on the Unsafe builder API which improves the performance.
+ if (PyValue::IsNull(this->options_, value)) {
+ this->primitive_builder_->UnsafeAppendNull();
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
+ this->primitive_builder_->UnsafeAppend(converted);
+ }
+ return Status::OK();
+ }
+};
+
+template <typename T>
+class PyPrimitiveConverter<
+ T, enable_if_t<is_timestamp_type<T>::value || is_duration_type<T>::value>>
+ : public PrimitiveConverter<T, PyConverter> {
+ public:
+ Status Append(PyObject* value) override {
+ if (PyValue::IsNull(this->options_, value)) {
+ this->primitive_builder_->UnsafeAppendNull();
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
+ // Numpy NaT sentinels can be checked after the conversion
+ if (PyArray_CheckAnyScalarExact(value) &&
+ PyValue::IsNaT(this->primitive_type_, converted)) {
+ this->primitive_builder_->UnsafeAppendNull();
+ } else {
+ this->primitive_builder_->UnsafeAppend(converted);
+ }
+ }
+ return Status::OK();
+ }
+};
+
+template <typename T>
+class PyPrimitiveConverter<T, enable_if_t<std::is_same<T, FixedSizeBinaryType>::value>>
+ : public PrimitiveConverter<T, PyConverter> {
+ public:
+ Status Append(PyObject* value) override {
+ if (PyValue::IsNull(this->options_, value)) {
+ this->primitive_builder_->UnsafeAppendNull();
+ } else {
+ ARROW_RETURN_NOT_OK(
+ PyValue::Convert(this->primitive_type_, this->options_, value, view_));
+ ARROW_RETURN_NOT_OK(this->primitive_builder_->ReserveData(view_.size));
+ this->primitive_builder_->UnsafeAppend(view_.bytes);
+ }
+ return Status::OK();
+ }
+
+ protected:
+ PyBytesView view_;
+};
+
+template <typename T>
+class PyPrimitiveConverter<T, enable_if_base_binary<T>>
+ : public PrimitiveConverter<T, PyConverter> {
+ public:
+ using OffsetType = typename T::offset_type;
+
+ Status Append(PyObject* value) override {
+ if (PyValue::IsNull(this->options_, value)) {
+ this->primitive_builder_->UnsafeAppendNull();
+ } else {
+ ARROW_RETURN_NOT_OK(
+ PyValue::Convert(this->primitive_type_, this->options_, value, view_));
+ if (!view_.is_utf8) {
+ // observed binary value
+ observed_binary_ = true;
+ }
+ // Since we don't know the varying length input size in advance, we need to
+ // reserve space in the value builder one by one. ReserveData raises CapacityError
+ // if the value would not fit into the array.
+ ARROW_RETURN_NOT_OK(this->primitive_builder_->ReserveData(view_.size));
+ this->primitive_builder_->UnsafeAppend(view_.bytes,
+ static_cast<OffsetType>(view_.size));
+ }
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<Array>> ToArray() override {
+ ARROW_ASSIGN_OR_RAISE(auto array, (PrimitiveConverter<T, PyConverter>::ToArray()));
+ if (observed_binary_) {
+ // if we saw any non-unicode, cast results to BinaryArray
+ auto binary_type = TypeTraits<typename T::PhysicalType>::type_singleton();
+ return array->View(binary_type);
+ } else {
+ return array;
+ }
+ }
+
+ protected:
+ PyBytesView view_;
+ bool observed_binary_ = false;
+};
+
+template <typename U>
+class PyDictionaryConverter<U, enable_if_has_c_type<U>>
+ : public DictionaryConverter<U, PyConverter> {
+ public:
+ Status Append(PyObject* value) override {
+ if (PyValue::IsNull(this->options_, value)) {
+ return this->value_builder_->AppendNull();
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto converted,
+ PyValue::Convert(this->value_type_, this->options_, value));
+ return this->value_builder_->Append(converted);
+ }
+ }
+};
+
+template <typename U>
+class PyDictionaryConverter<U, enable_if_has_string_view<U>>
+ : public DictionaryConverter<U, PyConverter> {
+ public:
+ Status Append(PyObject* value) override {
+ if (PyValue::IsNull(this->options_, value)) {
+ return this->value_builder_->AppendNull();
+ } else {
+ ARROW_RETURN_NOT_OK(
+ PyValue::Convert(this->value_type_, this->options_, value, view_));
+ return this->value_builder_->Append(view_.bytes, static_cast<int32_t>(view_.size));
+ }
+ }
+
+ protected:
+ PyBytesView view_;
+};
+
+template <typename T>
+class PyListConverter : public ListConverter<T, PyConverter, PyConverterTrait> {
+ public:
+ Status Append(PyObject* value) override {
+ if (PyValue::IsNull(this->options_, value)) {
+ return this->list_builder_->AppendNull();
+ }
+
+ RETURN_NOT_OK(this->list_builder_->Append());
+ if (PyArray_Check(value)) {
+ RETURN_NOT_OK(AppendNdarray(value));
+ } else if (PySequence_Check(value)) {
+ RETURN_NOT_OK(AppendSequence(value));
+ } else if (PySet_Check(value) || (Py_TYPE(value) == &PyDictValues_Type)) {
+ RETURN_NOT_OK(AppendIterable(value));
+ } else {
+ return internal::InvalidType(
+ value, "was not a sequence or recognized null for conversion to list type");
+ }
+
+ return ValidateBuilder(this->list_type_);
+ }
+
+ protected:
+ Status ValidateBuilder(const MapType*) {
+ if (this->list_builder_->key_builder()->null_count() > 0) {
+ return Status::Invalid("Invalid Map: key field can not contain null values");
+ } else {
+ return Status::OK();
+ }
+ }
+
+ Status ValidateBuilder(const BaseListType*) { return Status::OK(); }
+
+ Status AppendSequence(PyObject* value) {
+ int64_t size = static_cast<int64_t>(PySequence_Size(value));
+ RETURN_NOT_OK(this->list_builder_->ValidateOverflow(size));
+ return this->value_converter_->Extend(value, size);
+ }
+
+ Status AppendIterable(PyObject* value) {
+ PyObject* iterator = PyObject_GetIter(value);
+ OwnedRef iter_ref(iterator);
+ while (PyObject* item = PyIter_Next(iterator)) {
+ OwnedRef item_ref(item);
+ RETURN_NOT_OK(this->value_converter_->Reserve(1));
+ RETURN_NOT_OK(this->value_converter_->Append(item));
+ }
+ return Status::OK();
+ }
+
+ Status AppendNdarray(PyObject* value) {
+ PyArrayObject* ndarray = reinterpret_cast<PyArrayObject*>(value);
+ if (PyArray_NDIM(ndarray) != 1) {
+ return Status::Invalid("Can only convert 1-dimensional array values");
+ }
+ const int64_t size = PyArray_SIZE(ndarray);
+ RETURN_NOT_OK(this->list_builder_->ValidateOverflow(size));
+
+ const auto value_type = this->value_converter_->builder()->type();
+ switch (value_type->id()) {
+// If the value type does not match the expected NumPy dtype, then fall through
+// to a slower PySequence-based path
+#define LIST_FAST_CASE(TYPE_ID, TYPE, NUMPY_TYPE) \
+ case Type::TYPE_ID: { \
+ if (PyArray_DESCR(ndarray)->type_num != NUMPY_TYPE) { \
+ return this->value_converter_->Extend(value, size); \
+ } \
+ return AppendNdarrayTyped<TYPE, NUMPY_TYPE>(ndarray); \
+ }
+ LIST_FAST_CASE(BOOL, BooleanType, NPY_BOOL)
+ LIST_FAST_CASE(UINT8, UInt8Type, NPY_UINT8)
+ LIST_FAST_CASE(INT8, Int8Type, NPY_INT8)
+ LIST_FAST_CASE(UINT16, UInt16Type, NPY_UINT16)
+ LIST_FAST_CASE(INT16, Int16Type, NPY_INT16)
+ LIST_FAST_CASE(UINT32, UInt32Type, NPY_UINT32)
+ LIST_FAST_CASE(INT32, Int32Type, NPY_INT32)
+ LIST_FAST_CASE(UINT64, UInt64Type, NPY_UINT64)
+ LIST_FAST_CASE(INT64, Int64Type, NPY_INT64)
+ LIST_FAST_CASE(HALF_FLOAT, HalfFloatType, NPY_FLOAT16)
+ LIST_FAST_CASE(FLOAT, FloatType, NPY_FLOAT)
+ LIST_FAST_CASE(DOUBLE, DoubleType, NPY_DOUBLE)
+ LIST_FAST_CASE(TIMESTAMP, TimestampType, NPY_DATETIME)
+ LIST_FAST_CASE(DURATION, DurationType, NPY_TIMEDELTA)
+#undef LIST_FAST_CASE
+ default: {
+ return this->value_converter_->Extend(value, size);
+ }
+ }
+ }
+
+ template <typename ArrowType, int NUMPY_TYPE>
+ Status AppendNdarrayTyped(PyArrayObject* ndarray) {
+ // no need to go through the conversion
+ using NumpyTrait = internal::npy_traits<NUMPY_TYPE>;
+ using NumpyType = typename NumpyTrait::value_type;
+ using ValueBuilderType = typename TypeTraits<ArrowType>::BuilderType;
+
+ const bool null_sentinels_possible =
+ // Always treat Numpy's NaT as null
+ NUMPY_TYPE == NPY_DATETIME || NUMPY_TYPE == NPY_TIMEDELTA ||
+ // Observing pandas's null sentinels
+ (this->options_.from_pandas && NumpyTrait::supports_nulls);
+
+ auto value_builder =
+ checked_cast<ValueBuilderType*>(this->value_converter_->builder().get());
+
+ Ndarray1DIndexer<NumpyType> values(ndarray);
+ if (null_sentinels_possible) {
+ for (int64_t i = 0; i < values.size(); ++i) {
+ if (NumpyTrait::isnull(values[i])) {
+ RETURN_NOT_OK(value_builder->AppendNull());
+ } else {
+ RETURN_NOT_OK(value_builder->Append(values[i]));
+ }
+ }
+ } else if (!values.is_strided()) {
+ RETURN_NOT_OK(value_builder->AppendValues(values.data(), values.size()));
+ } else {
+ for (int64_t i = 0; i < values.size(); ++i) {
+ RETURN_NOT_OK(value_builder->Append(values[i]));
+ }
+ }
+ return Status::OK();
+ }
+};
+
+class PyStructConverter : public StructConverter<PyConverter, PyConverterTrait> {
+ public:
+ Status Append(PyObject* value) override {
+ if (PyValue::IsNull(this->options_, value)) {
+ return this->struct_builder_->AppendNull();
+ }
+ switch (input_kind_) {
+ case InputKind::DICT:
+ RETURN_NOT_OK(this->struct_builder_->Append());
+ return AppendDict(value);
+ case InputKind::TUPLE:
+ RETURN_NOT_OK(this->struct_builder_->Append());
+ return AppendTuple(value);
+ case InputKind::ITEMS:
+ RETURN_NOT_OK(this->struct_builder_->Append());
+ return AppendItems(value);
+ default:
+ RETURN_NOT_OK(InferInputKind(value));
+ return Append(value);
+ }
+ }
+
+ protected:
+ Status Init(MemoryPool* pool) override {
+ RETURN_NOT_OK((StructConverter<PyConverter, PyConverterTrait>::Init(pool)));
+
+ // Store the field names as a PyObjects for dict matching
+ num_fields_ = this->struct_type_->num_fields();
+ bytes_field_names_.reset(PyList_New(num_fields_));
+ unicode_field_names_.reset(PyList_New(num_fields_));
+ RETURN_IF_PYERROR();
+
+ for (int i = 0; i < num_fields_; i++) {
+ const auto& field_name = this->struct_type_->field(i)->name();
+ PyObject* bytes = PyBytes_FromStringAndSize(field_name.c_str(), field_name.size());
+ PyObject* unicode =
+ PyUnicode_FromStringAndSize(field_name.c_str(), field_name.size());
+ RETURN_IF_PYERROR();
+ PyList_SET_ITEM(bytes_field_names_.obj(), i, bytes);
+ PyList_SET_ITEM(unicode_field_names_.obj(), i, unicode);
+ }
+ return Status::OK();
+ }
+
+ Status InferInputKind(PyObject* value) {
+ // Infer input object's type, note that heterogeneous sequences are not allowed
+ if (PyDict_Check(value)) {
+ input_kind_ = InputKind::DICT;
+ } else if (PyTuple_Check(value)) {
+ input_kind_ = InputKind::TUPLE;
+ } else if (PySequence_Check(value)) {
+ input_kind_ = InputKind::ITEMS;
+ } else {
+ return internal::InvalidType(value,
+ "was not a dict, tuple, or recognized null value "
+ "for conversion to struct type");
+ }
+ return Status::OK();
+ }
+
+ Status InferKeyKind(PyObject* items) {
+ for (int i = 0; i < PySequence_Length(items); i++) {
+ // retrieve the key from the passed key-value pairs
+ ARROW_ASSIGN_OR_RAISE(auto pair, GetKeyValuePair(items, i));
+
+ // check key exists between the unicode field names
+ bool do_contain = PySequence_Contains(unicode_field_names_.obj(), pair.first);
+ RETURN_IF_PYERROR();
+ if (do_contain) {
+ key_kind_ = KeyKind::UNICODE;
+ return Status::OK();
+ }
+
+ // check key exists between the bytes field names
+ do_contain = PySequence_Contains(bytes_field_names_.obj(), pair.first);
+ RETURN_IF_PYERROR();
+ if (do_contain) {
+ key_kind_ = KeyKind::BYTES;
+ return Status::OK();
+ }
+ }
+ return Status::OK();
+ }
+
+ Status AppendEmpty() {
+ for (int i = 0; i < num_fields_; i++) {
+ RETURN_NOT_OK(this->children_[i]->Append(Py_None));
+ }
+ return Status::OK();
+ }
+
+ Status AppendTuple(PyObject* tuple) {
+ if (!PyTuple_Check(tuple)) {
+ return internal::InvalidType(tuple, "was expecting a tuple");
+ }
+ if (PyTuple_GET_SIZE(tuple) != num_fields_) {
+ return Status::Invalid("Tuple size must be equal to number of struct fields");
+ }
+ for (int i = 0; i < num_fields_; i++) {
+ PyObject* value = PyTuple_GET_ITEM(tuple, i);
+ RETURN_NOT_OK(this->children_[i]->Append(value));
+ }
+ return Status::OK();
+ }
+
+ Status AppendDict(PyObject* dict) {
+ if (!PyDict_Check(dict)) {
+ return internal::InvalidType(dict, "was expecting a dict");
+ }
+ switch (key_kind_) {
+ case KeyKind::UNICODE:
+ return AppendDict(dict, unicode_field_names_.obj());
+ case KeyKind::BYTES:
+ return AppendDict(dict, bytes_field_names_.obj());
+ default:
+ RETURN_NOT_OK(InferKeyKind(PyDict_Items(dict)));
+ if (key_kind_ == KeyKind::UNKNOWN) {
+ // was unable to infer the type which means that all keys are absent
+ return AppendEmpty();
+ } else {
+ return AppendDict(dict);
+ }
+ }
+ }
+
+ Status AppendItems(PyObject* items) {
+ if (!PySequence_Check(items)) {
+ return internal::InvalidType(items, "was expecting a sequence of key-value items");
+ }
+ switch (key_kind_) {
+ case KeyKind::UNICODE:
+ return AppendItems(items, unicode_field_names_.obj());
+ case KeyKind::BYTES:
+ return AppendItems(items, bytes_field_names_.obj());
+ default:
+ RETURN_NOT_OK(InferKeyKind(items));
+ if (key_kind_ == KeyKind::UNKNOWN) {
+ // was unable to infer the type which means that all keys are absent
+ return AppendEmpty();
+ } else {
+ return AppendItems(items);
+ }
+ }
+ }
+
+ Status AppendDict(PyObject* dict, PyObject* field_names) {
+ // NOTE we're ignoring any extraneous dict items
+ for (int i = 0; i < num_fields_; i++) {
+ PyObject* name = PyList_GET_ITEM(field_names, i); // borrowed
+ PyObject* value = PyDict_GetItem(dict, name); // borrowed
+ if (value == NULL) {
+ RETURN_IF_PYERROR();
+ }
+ RETURN_NOT_OK(this->children_[i]->Append(value ? value : Py_None));
+ }
+ return Status::OK();
+ }
+
+ Result<std::pair<PyObject*, PyObject*>> GetKeyValuePair(PyObject* seq, int index) {
+ PyObject* pair = PySequence_GetItem(seq, index);
+ RETURN_IF_PYERROR();
+ if (!PyTuple_Check(pair) || PyTuple_Size(pair) != 2) {
+ return internal::InvalidType(pair, "was expecting tuple of (key, value) pair");
+ }
+ PyObject* key = PyTuple_GetItem(pair, 0);
+ RETURN_IF_PYERROR();
+ PyObject* value = PyTuple_GetItem(pair, 1);
+ RETURN_IF_PYERROR();
+ return std::make_pair(key, value);
+ }
+
+ Status AppendItems(PyObject* items, PyObject* field_names) {
+ auto length = static_cast<int>(PySequence_Size(items));
+ RETURN_IF_PYERROR();
+
+ // append the values for the defined fields
+ for (int i = 0; i < std::min(num_fields_, length); i++) {
+ // retrieve the key-value pair
+ ARROW_ASSIGN_OR_RAISE(auto pair, GetKeyValuePair(items, i));
+
+ // validate that the key and the field name are equal
+ PyObject* name = PyList_GET_ITEM(field_names, i);
+ bool are_equal = PyObject_RichCompareBool(pair.first, name, Py_EQ);
+ RETURN_IF_PYERROR();
+
+ // finally append to the respective child builder
+ if (are_equal) {
+ RETURN_NOT_OK(this->children_[i]->Append(pair.second));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto key_view, PyBytesView::FromString(pair.first));
+ ARROW_ASSIGN_OR_RAISE(auto name_view, PyBytesView::FromString(name));
+ return Status::Invalid("The expected field name is `", name_view.bytes, "` but `",
+ key_view.bytes, "` was given");
+ }
+ }
+ // insert null values for missing fields
+ for (int i = length; i < num_fields_; i++) {
+ RETURN_NOT_OK(this->children_[i]->AppendNull());
+ }
+ return Status::OK();
+ }
+
+ // Whether we're converting from a sequence of dicts or tuples or list of pairs
+ enum class InputKind { UNKNOWN, DICT, TUPLE, ITEMS } input_kind_ = InputKind::UNKNOWN;
+ // Whether the input dictionary keys' type is python bytes or unicode
+ enum class KeyKind { UNKNOWN, BYTES, UNICODE } key_kind_ = KeyKind::UNKNOWN;
+ // Store the field names as a PyObjects for dict matching
+ OwnedRef bytes_field_names_;
+ OwnedRef unicode_field_names_;
+ // Store the number of fields for later reuse
+ int num_fields_;
+};
+
+// Convert *obj* to a sequence if necessary
+// Fill *size* to its length. If >= 0 on entry, *size* is an upper size
+// bound that may lead to truncation.
+Status ConvertToSequenceAndInferSize(PyObject* obj, PyObject** seq, int64_t* size) {
+ if (PySequence_Check(obj)) {
+ // obj is already a sequence
+ int64_t real_size = static_cast<int64_t>(PySequence_Size(obj));
+ if (*size < 0) {
+ *size = real_size;
+ } else {
+ *size = std::min(real_size, *size);
+ }
+ Py_INCREF(obj);
+ *seq = obj;
+ } else if (*size < 0) {
+ // unknown size, exhaust iterator
+ *seq = PySequence_List(obj);
+ RETURN_IF_PYERROR();
+ *size = static_cast<int64_t>(PyList_GET_SIZE(*seq));
+ } else {
+ // size is known but iterator could be infinite
+ Py_ssize_t i, n = *size;
+ PyObject* iter = PyObject_GetIter(obj);
+ RETURN_IF_PYERROR();
+ OwnedRef iter_ref(iter);
+ PyObject* lst = PyList_New(n);
+ RETURN_IF_PYERROR();
+ for (i = 0; i < n; i++) {
+ PyObject* item = PyIter_Next(iter);
+ if (!item) break;
+ PyList_SET_ITEM(lst, i, item);
+ }
+ // Shrink list if len(iterator) < size
+ if (i < n && PyList_SetSlice(lst, i, n, NULL)) {
+ Py_DECREF(lst);
+ return Status::UnknownError("failed to resize list");
+ }
+ *seq = lst;
+ *size = std::min<int64_t>(i, *size);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Result<std::shared_ptr<ChunkedArray>> ConvertPySequence(PyObject* obj, PyObject* mask,
+ PyConversionOptions options,
+ MemoryPool* pool) {
+ PyAcquireGIL lock;
+
+ PyObject* seq;
+ OwnedRef tmp_seq_nanny;
+
+ ARROW_ASSIGN_OR_RAISE(auto is_pandas_imported, internal::IsModuleImported("pandas"));
+ if (is_pandas_imported) {
+ // If pandas has been already imported initialize the static pandas objects to
+ // support converting from pd.Timedelta and pd.Timestamp objects
+ internal::InitPandasStaticData();
+ }
+
+ int64_t size = options.size;
+ RETURN_NOT_OK(ConvertToSequenceAndInferSize(obj, &seq, &size));
+ tmp_seq_nanny.reset(seq);
+
+ // In some cases, type inference may be "loose", like strings. If the user
+ // passed pa.string(), then we will error if we encounter any non-UTF8
+ // value. If not, then we will allow the result to be a BinaryArray
+ if (options.type == nullptr) {
+ ARROW_ASSIGN_OR_RAISE(options.type, InferArrowType(seq, mask, options.from_pandas));
+ options.strict = false;
+ } else {
+ options.strict = true;
+ }
+ DCHECK_GE(size, 0);
+
+ ARROW_ASSIGN_OR_RAISE(auto converter, (MakeConverter<PyConverter, PyConverterTrait>(
+ options.type, options, pool)));
+ if (converter->may_overflow()) {
+ // The converter hierarchy contains binary- or list-like builders which can overflow
+ // depending on the input values. Wrap the converter with a chunker which detects
+ // the overflow and automatically creates new chunks.
+ ARROW_ASSIGN_OR_RAISE(auto chunked_converter, MakeChunker(std::move(converter)));
+ if (mask != nullptr && mask != Py_None) {
+ RETURN_NOT_OK(chunked_converter->ExtendMasked(seq, mask, size));
+ } else {
+ RETURN_NOT_OK(chunked_converter->Extend(seq, size));
+ }
+ return chunked_converter->ToChunkedArray();
+ } else {
+ // If the converter can't overflow spare the capacity error checking on the hot-path,
+ // this improves the performance roughly by ~10% for primitive types.
+ if (mask != nullptr && mask != Py_None) {
+ RETURN_NOT_OK(converter->ExtendMasked(seq, mask, size));
+ } else {
+ RETURN_NOT_OK(converter->Extend(seq, size));
+ }
+ return converter->ToChunkedArray();
+ }
+}
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/python_to_arrow.h b/src/arrow/cpp/src/arrow/python/python_to_arrow.h
new file mode 100644
index 000000000..d167996ba
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/python_to_arrow.h
@@ -0,0 +1,80 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Functions for converting between CPython built-in data structures and Arrow
+// data structures
+
+#pragma once
+
+#include "arrow/python/platform.h"
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/python/visibility.h"
+#include "arrow/type.h"
+#include "arrow/util/macros.h"
+
+#include "arrow/python/common.h"
+
+namespace arrow {
+
+class Array;
+class Status;
+
+namespace py {
+
+struct PyConversionOptions {
+ PyConversionOptions() = default;
+
+ PyConversionOptions(const std::shared_ptr<DataType>& type, int64_t size,
+ MemoryPool* pool, bool from_pandas)
+ : type(type), size(size), from_pandas(from_pandas) {}
+
+ // Set to null if to be inferred
+ std::shared_ptr<DataType> type;
+
+ // Default is -1, which indicates the size should the same as the input sequence
+ int64_t size = -1;
+
+ bool from_pandas = false;
+
+ /// Used to maintain backwards compatibility for
+ /// timezone bugs (see ARROW-9528). Should be removed
+ /// after Arrow 2.0 release.
+ bool ignore_timezone = false;
+
+ bool strict = false;
+};
+
+/// \brief Convert sequence (list, generator, NumPy array with dtype object) of
+/// Python objects.
+/// \param[in] obj the sequence to convert
+/// \param[in] mask a NumPy array of true/false values to indicate whether
+/// values in the sequence are null (true) or not null (false). This parameter
+/// may be null
+/// \param[in] options various conversion options
+/// \param[in] pool MemoryPool to use for allocations
+/// \return Result ChunkedArray
+ARROW_PYTHON_EXPORT
+Result<std::shared_ptr<ChunkedArray>> ConvertPySequence(
+ PyObject* obj, PyObject* mask, PyConversionOptions options,
+ MemoryPool* pool = default_memory_pool());
+
+} // namespace py
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/serialize.cc b/src/arrow/cpp/src/arrow/python/serialize.cc
new file mode 100644
index 000000000..ad079cbd9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/serialize.cc
@@ -0,0 +1,798 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/python/serialize.h"
+#include "arrow/python/numpy_interop.h"
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <numpy/arrayobject.h>
+#include <numpy/arrayscalars.h>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/builder_union.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/util.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/tensor.h"
+#include "arrow/util/logging.h"
+
+#include "arrow/python/common.h"
+#include "arrow/python/datetime.h"
+#include "arrow/python/helpers.h"
+#include "arrow/python/iterators.h"
+#include "arrow/python/numpy_convert.h"
+#include "arrow/python/platform.h"
+#include "arrow/python/pyarrow.h"
+
+constexpr int32_t kMaxRecursionDepth = 100;
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace py {
+
+class SequenceBuilder;
+class DictBuilder;
+
+Status Append(PyObject* context, PyObject* elem, SequenceBuilder* builder,
+ int32_t recursion_depth, SerializedPyObject* blobs_out);
+
+// A Sequence is a heterogeneous collections of elements. It can contain
+// scalar Python types, lists, tuples, dictionaries, tensors and sparse tensors.
+class SequenceBuilder {
+ public:
+ explicit SequenceBuilder(MemoryPool* pool = default_memory_pool())
+ : pool_(pool),
+ types_(::arrow::int8(), pool),
+ offsets_(::arrow::int32(), pool),
+ type_map_(PythonType::NUM_PYTHON_TYPES, -1) {
+ auto null_builder = std::make_shared<NullBuilder>(pool);
+ auto initial_ty = dense_union({field("0", null())});
+ builder_.reset(new DenseUnionBuilder(pool, {null_builder}, initial_ty));
+ }
+
+ // Appending a none to the sequence
+ Status AppendNone() { return builder_->AppendNull(); }
+
+ template <typename BuilderType, typename MakeBuilderFn>
+ Status CreateAndUpdate(std::shared_ptr<BuilderType>* child_builder, int8_t tag,
+ MakeBuilderFn make_builder) {
+ if (!*child_builder) {
+ child_builder->reset(make_builder());
+ std::ostringstream convert;
+ convert.imbue(std::locale::classic());
+ convert << static_cast<int>(tag);
+ type_map_[tag] = builder_->AppendChild(*child_builder, convert.str());
+ }
+ return builder_->Append(type_map_[tag]);
+ }
+
+ template <typename BuilderType, typename T>
+ Status AppendPrimitive(std::shared_ptr<BuilderType>* child_builder, const T val,
+ int8_t tag) {
+ RETURN_NOT_OK(
+ CreateAndUpdate(child_builder, tag, [this]() { return new BuilderType(pool_); }));
+ return (*child_builder)->Append(val);
+ }
+
+ // Appending a boolean to the sequence
+ Status AppendBool(const bool data) {
+ return AppendPrimitive(&bools_, data, PythonType::BOOL);
+ }
+
+ // Appending an int64_t to the sequence
+ Status AppendInt64(const int64_t data) {
+ return AppendPrimitive(&ints_, data, PythonType::INT);
+ }
+
+ // Append a list of bytes to the sequence
+ Status AppendBytes(const uint8_t* data, int32_t length) {
+ RETURN_NOT_OK(CreateAndUpdate(&bytes_, PythonType::BYTES,
+ [this]() { return new BinaryBuilder(pool_); }));
+ return bytes_->Append(data, length);
+ }
+
+ // Appending a string to the sequence
+ Status AppendString(const char* data, int32_t length) {
+ RETURN_NOT_OK(CreateAndUpdate(&strings_, PythonType::STRING,
+ [this]() { return new StringBuilder(pool_); }));
+ return strings_->Append(data, length);
+ }
+
+ // Appending a half_float to the sequence
+ Status AppendHalfFloat(const npy_half data) {
+ return AppendPrimitive(&half_floats_, data, PythonType::HALF_FLOAT);
+ }
+
+ // Appending a float to the sequence
+ Status AppendFloat(const float data) {
+ return AppendPrimitive(&floats_, data, PythonType::FLOAT);
+ }
+
+ // Appending a double to the sequence
+ Status AppendDouble(const double data) {
+ return AppendPrimitive(&doubles_, data, PythonType::DOUBLE);
+ }
+
+ // Appending a Date64 timestamp to the sequence
+ Status AppendDate64(const int64_t timestamp) {
+ return AppendPrimitive(&date64s_, timestamp, PythonType::DATE64);
+ }
+
+ // Appending a tensor to the sequence
+ //
+ // \param tensor_index Index of the tensor in the object.
+ Status AppendTensor(const int32_t tensor_index) {
+ RETURN_NOT_OK(CreateAndUpdate(&tensor_indices_, PythonType::TENSOR,
+ [this]() { return new Int32Builder(pool_); }));
+ return tensor_indices_->Append(tensor_index);
+ }
+
+ // Appending a sparse coo tensor to the sequence
+ //
+ // \param sparse_coo_tensor_index Index of the sparse coo tensor in the object.
+ Status AppendSparseCOOTensor(const int32_t sparse_coo_tensor_index) {
+ RETURN_NOT_OK(CreateAndUpdate(&sparse_coo_tensor_indices_,
+ PythonType::SPARSECOOTENSOR,
+ [this]() { return new Int32Builder(pool_); }));
+ return sparse_coo_tensor_indices_->Append(sparse_coo_tensor_index);
+ }
+
+ // Appending a sparse csr matrix to the sequence
+ //
+ // \param sparse_csr_matrix_index Index of the sparse csr matrix in the object.
+ Status AppendSparseCSRMatrix(const int32_t sparse_csr_matrix_index) {
+ RETURN_NOT_OK(CreateAndUpdate(&sparse_csr_matrix_indices_,
+ PythonType::SPARSECSRMATRIX,
+ [this]() { return new Int32Builder(pool_); }));
+ return sparse_csr_matrix_indices_->Append(sparse_csr_matrix_index);
+ }
+
+ // Appending a sparse csc matrix to the sequence
+ //
+ // \param sparse_csc_matrix_index Index of the sparse csc matrix in the object.
+ Status AppendSparseCSCMatrix(const int32_t sparse_csc_matrix_index) {
+ RETURN_NOT_OK(CreateAndUpdate(&sparse_csc_matrix_indices_,
+ PythonType::SPARSECSCMATRIX,
+ [this]() { return new Int32Builder(pool_); }));
+ return sparse_csc_matrix_indices_->Append(sparse_csc_matrix_index);
+ }
+
+ // Appending a sparse csf tensor to the sequence
+ //
+ // \param sparse_csf_tensor_index Index of the sparse csf tensor in the object.
+ Status AppendSparseCSFTensor(const int32_t sparse_csf_tensor_index) {
+ RETURN_NOT_OK(CreateAndUpdate(&sparse_csf_tensor_indices_,
+ PythonType::SPARSECSFTENSOR,
+ [this]() { return new Int32Builder(pool_); }));
+ return sparse_csf_tensor_indices_->Append(sparse_csf_tensor_index);
+ }
+
+ // Appending a numpy ndarray to the sequence
+ //
+ // \param tensor_index Index of the tensor in the object.
+ Status AppendNdarray(const int32_t ndarray_index) {
+ RETURN_NOT_OK(CreateAndUpdate(&ndarray_indices_, PythonType::NDARRAY,
+ [this]() { return new Int32Builder(pool_); }));
+ return ndarray_indices_->Append(ndarray_index);
+ }
+
+ // Appending a buffer to the sequence
+ //
+ // \param buffer_index Index of the buffer in the object.
+ Status AppendBuffer(const int32_t buffer_index) {
+ RETURN_NOT_OK(CreateAndUpdate(&buffer_indices_, PythonType::BUFFER,
+ [this]() { return new Int32Builder(pool_); }));
+ return buffer_indices_->Append(buffer_index);
+ }
+
+ Status AppendSequence(PyObject* context, PyObject* sequence, int8_t tag,
+ std::shared_ptr<ListBuilder>& target_sequence,
+ std::unique_ptr<SequenceBuilder>& values, int32_t recursion_depth,
+ SerializedPyObject* blobs_out) {
+ if (recursion_depth >= kMaxRecursionDepth) {
+ return Status::NotImplemented(
+ "This object exceeds the maximum recursion depth. It may contain itself "
+ "recursively.");
+ }
+ RETURN_NOT_OK(CreateAndUpdate(&target_sequence, tag, [this, &values]() {
+ values.reset(new SequenceBuilder(pool_));
+ return new ListBuilder(pool_, values->builder());
+ }));
+ RETURN_NOT_OK(target_sequence->Append());
+ return internal::VisitIterable(
+ sequence, [&](PyObject* obj, bool* keep_going /* unused */) {
+ return Append(context, obj, values.get(), recursion_depth, blobs_out);
+ });
+ }
+
+ Status AppendList(PyObject* context, PyObject* list, int32_t recursion_depth,
+ SerializedPyObject* blobs_out) {
+ return AppendSequence(context, list, PythonType::LIST, lists_, list_values_,
+ recursion_depth + 1, blobs_out);
+ }
+
+ Status AppendTuple(PyObject* context, PyObject* tuple, int32_t recursion_depth,
+ SerializedPyObject* blobs_out) {
+ return AppendSequence(context, tuple, PythonType::TUPLE, tuples_, tuple_values_,
+ recursion_depth + 1, blobs_out);
+ }
+
+ Status AppendSet(PyObject* context, PyObject* set, int32_t recursion_depth,
+ SerializedPyObject* blobs_out) {
+ return AppendSequence(context, set, PythonType::SET, sets_, set_values_,
+ recursion_depth + 1, blobs_out);
+ }
+
+ Status AppendDict(PyObject* context, PyObject* dict, int32_t recursion_depth,
+ SerializedPyObject* blobs_out);
+
+ // Finish building the sequence and return the result.
+ // Input arrays may be nullptr
+ Status Finish(std::shared_ptr<Array>* out) { return builder_->Finish(out); }
+
+ std::shared_ptr<DenseUnionBuilder> builder() { return builder_; }
+
+ private:
+ MemoryPool* pool_;
+
+ Int8Builder types_;
+ Int32Builder offsets_;
+
+ /// Mapping from PythonType to child index
+ std::vector<int8_t> type_map_;
+
+ std::shared_ptr<BooleanBuilder> bools_;
+ std::shared_ptr<Int64Builder> ints_;
+ std::shared_ptr<BinaryBuilder> bytes_;
+ std::shared_ptr<StringBuilder> strings_;
+ std::shared_ptr<HalfFloatBuilder> half_floats_;
+ std::shared_ptr<FloatBuilder> floats_;
+ std::shared_ptr<DoubleBuilder> doubles_;
+ std::shared_ptr<Date64Builder> date64s_;
+
+ std::unique_ptr<SequenceBuilder> list_values_;
+ std::shared_ptr<ListBuilder> lists_;
+ std::unique_ptr<DictBuilder> dict_values_;
+ std::shared_ptr<ListBuilder> dicts_;
+ std::unique_ptr<SequenceBuilder> tuple_values_;
+ std::shared_ptr<ListBuilder> tuples_;
+ std::unique_ptr<SequenceBuilder> set_values_;
+ std::shared_ptr<ListBuilder> sets_;
+
+ std::shared_ptr<Int32Builder> tensor_indices_;
+ std::shared_ptr<Int32Builder> sparse_coo_tensor_indices_;
+ std::shared_ptr<Int32Builder> sparse_csr_matrix_indices_;
+ std::shared_ptr<Int32Builder> sparse_csc_matrix_indices_;
+ std::shared_ptr<Int32Builder> sparse_csf_tensor_indices_;
+ std::shared_ptr<Int32Builder> ndarray_indices_;
+ std::shared_ptr<Int32Builder> buffer_indices_;
+
+ std::shared_ptr<DenseUnionBuilder> builder_;
+};
+
+// Constructing dictionaries of key/value pairs. Sequences of
+// keys and values are built separately using a pair of
+// SequenceBuilders. The resulting Arrow representation
+// can be obtained via the Finish method.
+class DictBuilder {
+ public:
+ explicit DictBuilder(MemoryPool* pool = nullptr) : keys_(pool), vals_(pool) {
+ builder_.reset(new StructBuilder(struct_({field("keys", dense_union(FieldVector{})),
+ field("vals", dense_union(FieldVector{}))}),
+ pool, {keys_.builder(), vals_.builder()}));
+ }
+
+ // Builder for the keys of the dictionary
+ SequenceBuilder& keys() { return keys_; }
+ // Builder for the values of the dictionary
+ SequenceBuilder& vals() { return vals_; }
+
+ // Construct an Arrow StructArray representing the dictionary.
+ // Contains a field "keys" for the keys and "vals" for the values.
+ Status Finish(std::shared_ptr<Array>* out) { return builder_->Finish(out); }
+
+ std::shared_ptr<StructBuilder> builder() { return builder_; }
+
+ private:
+ SequenceBuilder keys_;
+ SequenceBuilder vals_;
+ std::shared_ptr<StructBuilder> builder_;
+};
+
+Status SequenceBuilder::AppendDict(PyObject* context, PyObject* dict,
+ int32_t recursion_depth,
+ SerializedPyObject* blobs_out) {
+ if (recursion_depth >= kMaxRecursionDepth) {
+ return Status::NotImplemented(
+ "This object exceeds the maximum recursion depth. It may contain itself "
+ "recursively.");
+ }
+ RETURN_NOT_OK(CreateAndUpdate(&dicts_, PythonType::DICT, [this]() {
+ dict_values_.reset(new DictBuilder(pool_));
+ return new ListBuilder(pool_, dict_values_->builder());
+ }));
+ RETURN_NOT_OK(dicts_->Append());
+ PyObject* key;
+ PyObject* value;
+ Py_ssize_t pos = 0;
+ while (PyDict_Next(dict, &pos, &key, &value)) {
+ RETURN_NOT_OK(dict_values_->builder()->Append());
+ RETURN_NOT_OK(
+ Append(context, key, &dict_values_->keys(), recursion_depth + 1, blobs_out));
+ RETURN_NOT_OK(
+ Append(context, value, &dict_values_->vals(), recursion_depth + 1, blobs_out));
+ }
+
+ // This block is used to decrement the reference counts of the results
+ // returned by the serialization callback, which is called in AppendArray,
+ // in DeserializeDict and in Append
+ static PyObject* py_type = PyUnicode_FromString("_pytype_");
+ if (PyDict_Contains(dict, py_type)) {
+ // If the dictionary contains the key "_pytype_", then the user has to
+ // have registered a callback.
+ if (context == Py_None) {
+ return Status::Invalid("No serialization callback set");
+ }
+ Py_XDECREF(dict);
+ }
+ return Status::OK();
+}
+
+Status CallCustomCallback(PyObject* context, PyObject* method_name, PyObject* elem,
+ PyObject** result) {
+ if (context == Py_None) {
+ *result = NULL;
+ return Status::SerializationError("error while calling callback on ",
+ internal::PyObject_StdStringRepr(elem),
+ ": handler not registered");
+ } else {
+ *result = PyObject_CallMethodObjArgs(context, method_name, elem, NULL);
+ return CheckPyError();
+ }
+}
+
+Status CallSerializeCallback(PyObject* context, PyObject* value,
+ PyObject** serialized_object) {
+ OwnedRef method_name(PyUnicode_FromString("_serialize_callback"));
+ RETURN_NOT_OK(CallCustomCallback(context, method_name.obj(), value, serialized_object));
+ if (!PyDict_Check(*serialized_object)) {
+ return Status::TypeError("serialization callback must return a valid dictionary");
+ }
+ return Status::OK();
+}
+
+Status CallDeserializeCallback(PyObject* context, PyObject* value,
+ PyObject** deserialized_object) {
+ OwnedRef method_name(PyUnicode_FromString("_deserialize_callback"));
+ return CallCustomCallback(context, method_name.obj(), value, deserialized_object);
+}
+
+Status AppendArray(PyObject* context, PyArrayObject* array, SequenceBuilder* builder,
+ int32_t recursion_depth, SerializedPyObject* blobs_out);
+
+template <typename NumpyScalarObject>
+Status AppendIntegerScalar(PyObject* obj, SequenceBuilder* builder) {
+ int64_t value = reinterpret_cast<NumpyScalarObject*>(obj)->obval;
+ return builder->AppendInt64(value);
+}
+
+// Append a potentially 64-bit wide unsigned Numpy scalar.
+// Must check for overflow as we reinterpret it as signed int64.
+template <typename NumpyScalarObject>
+Status AppendLargeUnsignedScalar(PyObject* obj, SequenceBuilder* builder) {
+ constexpr uint64_t max_value = std::numeric_limits<int64_t>::max();
+
+ uint64_t value = reinterpret_cast<NumpyScalarObject*>(obj)->obval;
+ if (value > max_value) {
+ return Status::Invalid("cannot serialize Numpy uint64 scalar >= 2**63");
+ }
+ return builder->AppendInt64(static_cast<int64_t>(value));
+}
+
+Status AppendScalar(PyObject* obj, SequenceBuilder* builder) {
+ if (PyArray_IsScalar(obj, Bool)) {
+ return builder->AppendBool(reinterpret_cast<PyBoolScalarObject*>(obj)->obval != 0);
+ } else if (PyArray_IsScalar(obj, Half)) {
+ return builder->AppendHalfFloat(reinterpret_cast<PyHalfScalarObject*>(obj)->obval);
+ } else if (PyArray_IsScalar(obj, Float)) {
+ return builder->AppendFloat(reinterpret_cast<PyFloatScalarObject*>(obj)->obval);
+ } else if (PyArray_IsScalar(obj, Double)) {
+ return builder->AppendDouble(reinterpret_cast<PyDoubleScalarObject*>(obj)->obval);
+ }
+ if (PyArray_IsScalar(obj, Byte)) {
+ return AppendIntegerScalar<PyByteScalarObject>(obj, builder);
+ } else if (PyArray_IsScalar(obj, Short)) {
+ return AppendIntegerScalar<PyShortScalarObject>(obj, builder);
+ } else if (PyArray_IsScalar(obj, Int)) {
+ return AppendIntegerScalar<PyIntScalarObject>(obj, builder);
+ } else if (PyArray_IsScalar(obj, Long)) {
+ return AppendIntegerScalar<PyLongScalarObject>(obj, builder);
+ } else if (PyArray_IsScalar(obj, LongLong)) {
+ return AppendIntegerScalar<PyLongLongScalarObject>(obj, builder);
+ } else if (PyArray_IsScalar(obj, Int64)) {
+ return AppendIntegerScalar<PyInt64ScalarObject>(obj, builder);
+ } else if (PyArray_IsScalar(obj, UByte)) {
+ return AppendIntegerScalar<PyUByteScalarObject>(obj, builder);
+ } else if (PyArray_IsScalar(obj, UShort)) {
+ return AppendIntegerScalar<PyUShortScalarObject>(obj, builder);
+ } else if (PyArray_IsScalar(obj, UInt)) {
+ return AppendIntegerScalar<PyUIntScalarObject>(obj, builder);
+ } else if (PyArray_IsScalar(obj, ULong)) {
+ return AppendLargeUnsignedScalar<PyULongScalarObject>(obj, builder);
+ } else if (PyArray_IsScalar(obj, ULongLong)) {
+ return AppendLargeUnsignedScalar<PyULongLongScalarObject>(obj, builder);
+ } else if (PyArray_IsScalar(obj, UInt64)) {
+ return AppendLargeUnsignedScalar<PyUInt64ScalarObject>(obj, builder);
+ }
+ return Status::NotImplemented("Numpy scalar type not recognized");
+}
+
+Status Append(PyObject* context, PyObject* elem, SequenceBuilder* builder,
+ int32_t recursion_depth, SerializedPyObject* blobs_out) {
+ // The bool case must precede the int case (PyInt_Check passes for bools)
+ if (PyBool_Check(elem)) {
+ RETURN_NOT_OK(builder->AppendBool(elem == Py_True));
+ } else if (PyArray_DescrFromScalar(elem)->type_num == NPY_HALF) {
+ npy_half halffloat = reinterpret_cast<PyHalfScalarObject*>(elem)->obval;
+ RETURN_NOT_OK(builder->AppendHalfFloat(halffloat));
+ } else if (PyFloat_Check(elem)) {
+ RETURN_NOT_OK(builder->AppendDouble(PyFloat_AS_DOUBLE(elem)));
+ } else if (PyLong_Check(elem)) {
+ int overflow = 0;
+ int64_t data = PyLong_AsLongLongAndOverflow(elem, &overflow);
+ if (!overflow) {
+ RETURN_NOT_OK(builder->AppendInt64(data));
+ } else {
+ // Attempt to serialize the object using the custom callback.
+ PyObject* serialized_object;
+ // The reference count of serialized_object will be decremented in SerializeDict
+ RETURN_NOT_OK(CallSerializeCallback(context, elem, &serialized_object));
+ RETURN_NOT_OK(
+ builder->AppendDict(context, serialized_object, recursion_depth, blobs_out));
+ }
+ } else if (PyBytes_Check(elem)) {
+ auto data = reinterpret_cast<uint8_t*>(PyBytes_AS_STRING(elem));
+ int32_t size = -1;
+ RETURN_NOT_OK(internal::CastSize(PyBytes_GET_SIZE(elem), &size));
+ RETURN_NOT_OK(builder->AppendBytes(data, size));
+ } else if (PyUnicode_Check(elem)) {
+ ARROW_ASSIGN_OR_RAISE(auto view, PyBytesView::FromUnicode(elem));
+ int32_t size = -1;
+ RETURN_NOT_OK(internal::CastSize(view.size, &size));
+ RETURN_NOT_OK(builder->AppendString(view.bytes, size));
+ } else if (PyList_CheckExact(elem)) {
+ RETURN_NOT_OK(builder->AppendList(context, elem, recursion_depth, blobs_out));
+ } else if (PyDict_CheckExact(elem)) {
+ RETURN_NOT_OK(builder->AppendDict(context, elem, recursion_depth, blobs_out));
+ } else if (PyTuple_CheckExact(elem)) {
+ RETURN_NOT_OK(builder->AppendTuple(context, elem, recursion_depth, blobs_out));
+ } else if (PySet_Check(elem)) {
+ RETURN_NOT_OK(builder->AppendSet(context, elem, recursion_depth, blobs_out));
+ } else if (PyArray_IsScalar(elem, Generic)) {
+ RETURN_NOT_OK(AppendScalar(elem, builder));
+ } else if (PyArray_CheckExact(elem)) {
+ RETURN_NOT_OK(AppendArray(context, reinterpret_cast<PyArrayObject*>(elem), builder,
+ recursion_depth, blobs_out));
+ } else if (elem == Py_None) {
+ RETURN_NOT_OK(builder->AppendNone());
+ } else if (PyDateTime_Check(elem)) {
+ PyDateTime_DateTime* datetime = reinterpret_cast<PyDateTime_DateTime*>(elem);
+ RETURN_NOT_OK(builder->AppendDate64(internal::PyDateTime_to_us(datetime)));
+ } else if (is_buffer(elem)) {
+ RETURN_NOT_OK(builder->AppendBuffer(static_cast<int32_t>(blobs_out->buffers.size())));
+ ARROW_ASSIGN_OR_RAISE(auto buffer, unwrap_buffer(elem));
+ blobs_out->buffers.push_back(buffer);
+ } else if (is_tensor(elem)) {
+ RETURN_NOT_OK(builder->AppendTensor(static_cast<int32_t>(blobs_out->tensors.size())));
+ ARROW_ASSIGN_OR_RAISE(auto tensor, unwrap_tensor(elem));
+ blobs_out->tensors.push_back(tensor);
+ } else if (is_sparse_coo_tensor(elem)) {
+ RETURN_NOT_OK(builder->AppendSparseCOOTensor(
+ static_cast<int32_t>(blobs_out->sparse_tensors.size())));
+ ARROW_ASSIGN_OR_RAISE(auto tensor, unwrap_sparse_coo_tensor(elem));
+ blobs_out->sparse_tensors.push_back(tensor);
+ } else if (is_sparse_csr_matrix(elem)) {
+ RETURN_NOT_OK(builder->AppendSparseCSRMatrix(
+ static_cast<int32_t>(blobs_out->sparse_tensors.size())));
+ ARROW_ASSIGN_OR_RAISE(auto matrix, unwrap_sparse_csr_matrix(elem));
+ blobs_out->sparse_tensors.push_back(matrix);
+ } else if (is_sparse_csc_matrix(elem)) {
+ RETURN_NOT_OK(builder->AppendSparseCSCMatrix(
+ static_cast<int32_t>(blobs_out->sparse_tensors.size())));
+ ARROW_ASSIGN_OR_RAISE(auto matrix, unwrap_sparse_csc_matrix(elem));
+ blobs_out->sparse_tensors.push_back(matrix);
+ } else if (is_sparse_csf_tensor(elem)) {
+ RETURN_NOT_OK(builder->AppendSparseCSFTensor(
+ static_cast<int32_t>(blobs_out->sparse_tensors.size())));
+ ARROW_ASSIGN_OR_RAISE(auto tensor, unwrap_sparse_csf_tensor(elem));
+ blobs_out->sparse_tensors.push_back(tensor);
+ } else {
+ // Attempt to serialize the object using the custom callback.
+ PyObject* serialized_object;
+ // The reference count of serialized_object will be decremented in SerializeDict
+ RETURN_NOT_OK(CallSerializeCallback(context, elem, &serialized_object));
+ RETURN_NOT_OK(
+ builder->AppendDict(context, serialized_object, recursion_depth, blobs_out));
+ }
+ return Status::OK();
+}
+
+Status AppendArray(PyObject* context, PyArrayObject* array, SequenceBuilder* builder,
+ int32_t recursion_depth, SerializedPyObject* blobs_out) {
+ int dtype = PyArray_TYPE(array);
+ switch (dtype) {
+ case NPY_UINT8:
+ case NPY_INT8:
+ case NPY_UINT16:
+ case NPY_INT16:
+ case NPY_UINT32:
+ case NPY_INT32:
+ case NPY_UINT64:
+ case NPY_INT64:
+ case NPY_HALF:
+ case NPY_FLOAT:
+ case NPY_DOUBLE: {
+ RETURN_NOT_OK(
+ builder->AppendNdarray(static_cast<int32_t>(blobs_out->ndarrays.size())));
+ std::shared_ptr<Tensor> tensor;
+ RETURN_NOT_OK(NdarrayToTensor(default_memory_pool(),
+ reinterpret_cast<PyObject*>(array), {}, &tensor));
+ blobs_out->ndarrays.push_back(tensor);
+ } break;
+ default: {
+ PyObject* serialized_object;
+ // The reference count of serialized_object will be decremented in SerializeDict
+ RETURN_NOT_OK(CallSerializeCallback(context, reinterpret_cast<PyObject*>(array),
+ &serialized_object));
+ RETURN_NOT_OK(builder->AppendDict(context, serialized_object, recursion_depth + 1,
+ blobs_out));
+ }
+ }
+ return Status::OK();
+}
+
+std::shared_ptr<RecordBatch> MakeBatch(std::shared_ptr<Array> data) {
+ auto field = std::make_shared<Field>("list", data->type());
+ auto schema = ::arrow::schema({field});
+ return RecordBatch::Make(schema, data->length(), {data});
+}
+
+Status SerializeObject(PyObject* context, PyObject* sequence, SerializedPyObject* out) {
+ PyAcquireGIL lock;
+ SequenceBuilder builder;
+ RETURN_NOT_OK(internal::VisitIterable(
+ sequence, [&](PyObject* obj, bool* keep_going /* unused */) {
+ return Append(context, obj, &builder, 0, out);
+ }));
+ std::shared_ptr<Array> array;
+ RETURN_NOT_OK(builder.Finish(&array));
+ out->batch = MakeBatch(array);
+ return Status::OK();
+}
+
+Status SerializeNdarray(std::shared_ptr<Tensor> tensor, SerializedPyObject* out) {
+ std::shared_ptr<Array> array;
+ SequenceBuilder builder;
+ RETURN_NOT_OK(builder.AppendNdarray(static_cast<int32_t>(out->ndarrays.size())));
+ out->ndarrays.push_back(tensor);
+ RETURN_NOT_OK(builder.Finish(&array));
+ out->batch = MakeBatch(array);
+ return Status::OK();
+}
+
+Status WriteNdarrayHeader(std::shared_ptr<DataType> dtype,
+ const std::vector<int64_t>& shape, int64_t tensor_num_bytes,
+ io::OutputStream* dst) {
+ auto empty_tensor = std::make_shared<Tensor>(
+ dtype, std::make_shared<Buffer>(nullptr, tensor_num_bytes), shape);
+ SerializedPyObject serialized_tensor;
+ RETURN_NOT_OK(SerializeNdarray(empty_tensor, &serialized_tensor));
+ return serialized_tensor.WriteTo(dst);
+}
+
+SerializedPyObject::SerializedPyObject()
+ : ipc_options(ipc::IpcWriteOptions::Defaults()) {}
+
+Status SerializedPyObject::WriteTo(io::OutputStream* dst) {
+ int32_t num_tensors = static_cast<int32_t>(this->tensors.size());
+ int32_t num_sparse_tensors = static_cast<int32_t>(this->sparse_tensors.size());
+ int32_t num_ndarrays = static_cast<int32_t>(this->ndarrays.size());
+ int32_t num_buffers = static_cast<int32_t>(this->buffers.size());
+ RETURN_NOT_OK(
+ dst->Write(reinterpret_cast<const uint8_t*>(&num_tensors), sizeof(int32_t)));
+ RETURN_NOT_OK(
+ dst->Write(reinterpret_cast<const uint8_t*>(&num_sparse_tensors), sizeof(int32_t)));
+ RETURN_NOT_OK(
+ dst->Write(reinterpret_cast<const uint8_t*>(&num_ndarrays), sizeof(int32_t)));
+ RETURN_NOT_OK(
+ dst->Write(reinterpret_cast<const uint8_t*>(&num_buffers), sizeof(int32_t)));
+
+ // Align stream to 8-byte offset
+ RETURN_NOT_OK(ipc::AlignStream(dst, ipc::kArrowIpcAlignment));
+ RETURN_NOT_OK(ipc::WriteRecordBatchStream({this->batch}, this->ipc_options, dst));
+
+ // Align stream to 64-byte offset so tensor bodies are 64-byte aligned
+ RETURN_NOT_OK(ipc::AlignStream(dst, ipc::kTensorAlignment));
+
+ int32_t metadata_length;
+ int64_t body_length;
+ for (const auto& tensor : this->tensors) {
+ RETURN_NOT_OK(ipc::WriteTensor(*tensor, dst, &metadata_length, &body_length));
+ RETURN_NOT_OK(ipc::AlignStream(dst, ipc::kTensorAlignment));
+ }
+
+ for (const auto& sparse_tensor : this->sparse_tensors) {
+ RETURN_NOT_OK(
+ ipc::WriteSparseTensor(*sparse_tensor, dst, &metadata_length, &body_length));
+ RETURN_NOT_OK(ipc::AlignStream(dst, ipc::kTensorAlignment));
+ }
+
+ for (const auto& tensor : this->ndarrays) {
+ RETURN_NOT_OK(ipc::WriteTensor(*tensor, dst, &metadata_length, &body_length));
+ RETURN_NOT_OK(ipc::AlignStream(dst, ipc::kTensorAlignment));
+ }
+
+ for (const auto& buffer : this->buffers) {
+ int64_t size = buffer->size();
+ RETURN_NOT_OK(dst->Write(reinterpret_cast<const uint8_t*>(&size), sizeof(int64_t)));
+ RETURN_NOT_OK(dst->Write(buffer->data(), size));
+ }
+
+ return Status::OK();
+}
+
+namespace {
+
+Status CountSparseTensors(
+ const std::vector<std::shared_ptr<SparseTensor>>& sparse_tensors, PyObject** out) {
+ OwnedRef num_sparse_tensors(PyDict_New());
+ size_t num_coo = 0;
+ size_t num_csr = 0;
+ size_t num_csc = 0;
+ size_t num_csf = 0;
+ size_t ndim_csf = 0;
+
+ for (const auto& sparse_tensor : sparse_tensors) {
+ switch (sparse_tensor->format_id()) {
+ case SparseTensorFormat::COO:
+ ++num_coo;
+ break;
+ case SparseTensorFormat::CSR:
+ ++num_csr;
+ break;
+ case SparseTensorFormat::CSC:
+ ++num_csc;
+ break;
+ case SparseTensorFormat::CSF:
+ ++num_csf;
+ ndim_csf += sparse_tensor->ndim();
+ break;
+ }
+ }
+
+ PyDict_SetItemString(num_sparse_tensors.obj(), "coo", PyLong_FromSize_t(num_coo));
+ PyDict_SetItemString(num_sparse_tensors.obj(), "csr", PyLong_FromSize_t(num_csr));
+ PyDict_SetItemString(num_sparse_tensors.obj(), "csc", PyLong_FromSize_t(num_csc));
+ PyDict_SetItemString(num_sparse_tensors.obj(), "csf", PyLong_FromSize_t(num_csf));
+ PyDict_SetItemString(num_sparse_tensors.obj(), "ndim_csf", PyLong_FromSize_t(ndim_csf));
+ RETURN_IF_PYERROR();
+
+ *out = num_sparse_tensors.detach();
+ return Status::OK();
+}
+
+} // namespace
+
+Status SerializedPyObject::GetComponents(MemoryPool* memory_pool, PyObject** out) {
+ PyAcquireGIL py_gil;
+
+ OwnedRef result(PyDict_New());
+ PyObject* buffers = PyList_New(0);
+ PyObject* num_sparse_tensors = nullptr;
+
+ // TODO(wesm): Not sure how pedantic we need to be about checking the return
+ // values of these functions. There are other places where we do not check
+ // PyDict_SetItem/SetItemString return value, but these failures would be
+ // quite esoteric
+ PyDict_SetItemString(result.obj(), "num_tensors",
+ PyLong_FromSize_t(this->tensors.size()));
+ RETURN_NOT_OK(CountSparseTensors(this->sparse_tensors, &num_sparse_tensors));
+ PyDict_SetItemString(result.obj(), "num_sparse_tensors", num_sparse_tensors);
+ PyDict_SetItemString(result.obj(), "ndim_csf", num_sparse_tensors);
+ PyDict_SetItemString(result.obj(), "num_ndarrays",
+ PyLong_FromSize_t(this->ndarrays.size()));
+ PyDict_SetItemString(result.obj(), "num_buffers",
+ PyLong_FromSize_t(this->buffers.size()));
+ PyDict_SetItemString(result.obj(), "data", buffers);
+ RETURN_IF_PYERROR();
+
+ Py_DECREF(buffers);
+
+ auto PushBuffer = [&buffers](const std::shared_ptr<Buffer>& buffer) {
+ PyObject* wrapped_buffer = wrap_buffer(buffer);
+ RETURN_IF_PYERROR();
+ if (PyList_Append(buffers, wrapped_buffer) < 0) {
+ Py_DECREF(wrapped_buffer);
+ RETURN_IF_PYERROR();
+ }
+ Py_DECREF(wrapped_buffer);
+ return Status::OK();
+ };
+
+ constexpr int64_t kInitialCapacity = 1024;
+
+ // Write the record batch describing the object structure
+ py_gil.release();
+ ARROW_ASSIGN_OR_RAISE(auto stream,
+ io::BufferOutputStream::Create(kInitialCapacity, memory_pool));
+ RETURN_NOT_OK(
+ ipc::WriteRecordBatchStream({this->batch}, this->ipc_options, stream.get()));
+ ARROW_ASSIGN_OR_RAISE(auto buffer, stream->Finish());
+ py_gil.acquire();
+
+ RETURN_NOT_OK(PushBuffer(buffer));
+
+ // For each tensor, get a metadata buffer and a buffer for the body
+ for (const auto& tensor : this->tensors) {
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<ipc::Message> message,
+ ipc::GetTensorMessage(*tensor, memory_pool));
+ RETURN_NOT_OK(PushBuffer(message->metadata()));
+ RETURN_NOT_OK(PushBuffer(message->body()));
+ }
+
+ // For each sparse tensor, get a metadata buffer and buffers containing index and data
+ for (const auto& sparse_tensor : this->sparse_tensors) {
+ ipc::IpcPayload payload;
+ RETURN_NOT_OK(ipc::GetSparseTensorPayload(*sparse_tensor, memory_pool, &payload));
+ RETURN_NOT_OK(PushBuffer(payload.metadata));
+ for (const auto& body : payload.body_buffers) {
+ RETURN_NOT_OK(PushBuffer(body));
+ }
+ }
+
+ // For each ndarray, get a metadata buffer and a buffer for the body
+ for (const auto& ndarray : this->ndarrays) {
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<ipc::Message> message,
+ ipc::GetTensorMessage(*ndarray, memory_pool));
+ RETURN_NOT_OK(PushBuffer(message->metadata()));
+ RETURN_NOT_OK(PushBuffer(message->body()));
+ }
+
+ for (const auto& buf : this->buffers) {
+ RETURN_NOT_OK(PushBuffer(buf));
+ }
+
+ *out = result.detach();
+ return Status::OK();
+}
+
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/serialize.h b/src/arrow/cpp/src/arrow/python/serialize.h
new file mode 100644
index 000000000..fd207d3e0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/serialize.h
@@ -0,0 +1,145 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "arrow/ipc/options.h"
+#include "arrow/python/visibility.h"
+#include "arrow/sparse_tensor.h"
+#include "arrow/status.h"
+
+// Forward declaring PyObject, see
+// https://mail.python.org/pipermail/python-dev/2003-August/037601.html
+#ifndef PyObject_HEAD
+struct _object;
+typedef _object PyObject;
+#endif
+
+namespace arrow {
+
+class Buffer;
+class DataType;
+class MemoryPool;
+class RecordBatch;
+class Tensor;
+
+namespace io {
+
+class OutputStream;
+
+} // namespace io
+
+namespace py {
+
+struct ARROW_PYTHON_EXPORT SerializedPyObject {
+ std::shared_ptr<RecordBatch> batch;
+ std::vector<std::shared_ptr<Tensor>> tensors;
+ std::vector<std::shared_ptr<SparseTensor>> sparse_tensors;
+ std::vector<std::shared_ptr<Tensor>> ndarrays;
+ std::vector<std::shared_ptr<Buffer>> buffers;
+ ipc::IpcWriteOptions ipc_options;
+
+ SerializedPyObject();
+
+ /// \brief Write serialized Python object to OutputStream
+ /// \param[in,out] dst an OutputStream
+ /// \return Status
+ Status WriteTo(io::OutputStream* dst);
+
+ /// \brief Convert SerializedPyObject to a dict containing the message
+ /// components as Buffer instances with minimal memory allocation
+ ///
+ /// {
+ /// 'num_tensors': M,
+ /// 'num_sparse_tensors': N,
+ /// 'num_buffers': K,
+ /// 'data': [Buffer]
+ /// }
+ ///
+ /// Each tensor is written as two buffers, one for the metadata and one for
+ /// the body. Therefore, the number of buffers in 'data' is 2 * M + 2 * N + K + 1,
+ /// with the first buffer containing the serialized record batch containing
+ /// the UnionArray that describes the whole object
+ Status GetComponents(MemoryPool* pool, PyObject** out);
+};
+
+/// \brief Serialize Python sequence as a SerializedPyObject.
+/// \param[in] context Serialization context which contains custom serialization
+/// and deserialization callbacks. Can be any Python object with a
+/// _serialize_callback method for serialization and a _deserialize_callback
+/// method for deserialization. If context is None, no custom serialization
+/// will be attempted.
+/// \param[in] sequence A Python sequence object to serialize to Arrow data
+/// structures
+/// \param[out] out The serialized representation
+/// \return Status
+///
+/// Release GIL before calling
+ARROW_PYTHON_EXPORT
+Status SerializeObject(PyObject* context, PyObject* sequence, SerializedPyObject* out);
+
+/// \brief Serialize an Arrow Tensor as a SerializedPyObject.
+/// \param[in] tensor Tensor to be serialized
+/// \param[out] out The serialized representation
+/// \return Status
+ARROW_PYTHON_EXPORT
+Status SerializeTensor(std::shared_ptr<Tensor> tensor, py::SerializedPyObject* out);
+
+/// \brief Write the Tensor metadata header to an OutputStream.
+/// \param[in] dtype DataType of the Tensor
+/// \param[in] shape The shape of the tensor
+/// \param[in] tensor_num_bytes The length of the Tensor data in bytes
+/// \param[in] dst The OutputStream to write the Tensor header to
+/// \return Status
+ARROW_PYTHON_EXPORT
+Status WriteNdarrayHeader(std::shared_ptr<DataType> dtype,
+ const std::vector<int64_t>& shape, int64_t tensor_num_bytes,
+ io::OutputStream* dst);
+
+struct PythonType {
+ enum type {
+ NONE,
+ BOOL,
+ INT,
+ PY2INT, // Kept for compatibility
+ BYTES,
+ STRING,
+ HALF_FLOAT,
+ FLOAT,
+ DOUBLE,
+ DATE64,
+ LIST,
+ DICT,
+ TUPLE,
+ SET,
+ TENSOR,
+ NDARRAY,
+ BUFFER,
+ SPARSECOOTENSOR,
+ SPARSECSRMATRIX,
+ SPARSECSCMATRIX,
+ SPARSECSFTENSOR,
+ NUM_PYTHON_TYPES
+ };
+};
+
+} // namespace py
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/type_traits.h b/src/arrow/cpp/src/arrow/python/type_traits.h
new file mode 100644
index 000000000..a941577f7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/type_traits.h
@@ -0,0 +1,350 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Internal header
+
+#pragma once
+
+#include "arrow/python/platform.h"
+
+#include <cstdint>
+#include <limits>
+
+#include "arrow/python/numpy_interop.h"
+
+#include <numpy/halffloat.h>
+
+#include "arrow/type_fwd.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace py {
+
+static constexpr int64_t kPandasTimestampNull = std::numeric_limits<int64_t>::min();
+constexpr int64_t kNanosecondsInDay = 86400000000000LL;
+
+namespace internal {
+
+//
+// Type traits for Numpy -> Arrow equivalence
+//
+template <int TYPE>
+struct npy_traits {};
+
+template <>
+struct npy_traits<NPY_BOOL> {
+ typedef uint8_t value_type;
+ using TypeClass = BooleanType;
+ using BuilderClass = BooleanBuilder;
+
+ static constexpr bool supports_nulls = false;
+ static inline bool isnull(uint8_t v) { return false; }
+};
+
+#define NPY_INT_DECL(TYPE, CapType, T) \
+ template <> \
+ struct npy_traits<NPY_##TYPE> { \
+ typedef T value_type; \
+ using TypeClass = CapType##Type; \
+ using BuilderClass = CapType##Builder; \
+ \
+ static constexpr bool supports_nulls = false; \
+ static inline bool isnull(T v) { return false; } \
+ };
+
+NPY_INT_DECL(INT8, Int8, int8_t);
+NPY_INT_DECL(INT16, Int16, int16_t);
+NPY_INT_DECL(INT32, Int32, int32_t);
+NPY_INT_DECL(INT64, Int64, int64_t);
+
+NPY_INT_DECL(UINT8, UInt8, uint8_t);
+NPY_INT_DECL(UINT16, UInt16, uint16_t);
+NPY_INT_DECL(UINT32, UInt32, uint32_t);
+NPY_INT_DECL(UINT64, UInt64, uint64_t);
+
+#if !NPY_INT32_IS_INT && NPY_BITSOF_INT == 32
+NPY_INT_DECL(INT, Int32, int32_t);
+NPY_INT_DECL(UINT, UInt32, uint32_t);
+#endif
+#if !NPY_INT64_IS_LONG_LONG && NPY_BITSOF_LONGLONG == 64
+NPY_INT_DECL(LONGLONG, Int64, int64_t);
+NPY_INT_DECL(ULONGLONG, UInt64, uint64_t);
+#endif
+
+template <>
+struct npy_traits<NPY_FLOAT16> {
+ typedef npy_half value_type;
+ using TypeClass = HalfFloatType;
+ using BuilderClass = HalfFloatBuilder;
+
+ static constexpr npy_half na_sentinel = NPY_HALF_NAN;
+
+ static constexpr bool supports_nulls = true;
+
+ static inline bool isnull(npy_half v) { return v == NPY_HALF_NAN; }
+};
+
+template <>
+struct npy_traits<NPY_FLOAT32> {
+ typedef float value_type;
+ using TypeClass = FloatType;
+ using BuilderClass = FloatBuilder;
+
+ // We need to use quiet_NaN here instead of the NAN macro as on Windows
+ // the NAN macro leads to "division-by-zero" compile-time error with clang.
+ static constexpr float na_sentinel = std::numeric_limits<float>::quiet_NaN();
+
+ static constexpr bool supports_nulls = true;
+
+ static inline bool isnull(float v) { return v != v; }
+};
+
+template <>
+struct npy_traits<NPY_FLOAT64> {
+ typedef double value_type;
+ using TypeClass = DoubleType;
+ using BuilderClass = DoubleBuilder;
+
+ static constexpr double na_sentinel = std::numeric_limits<double>::quiet_NaN();
+
+ static constexpr bool supports_nulls = true;
+
+ static inline bool isnull(double v) { return v != v; }
+};
+
+template <>
+struct npy_traits<NPY_DATETIME> {
+ typedef int64_t value_type;
+ using TypeClass = TimestampType;
+ using BuilderClass = TimestampBuilder;
+
+ static constexpr bool supports_nulls = true;
+
+ static inline bool isnull(int64_t v) {
+ // NaT = -2**63
+ // = -0x8000000000000000
+ // = -9223372036854775808;
+ // = std::numeric_limits<int64_t>::min()
+ return v == std::numeric_limits<int64_t>::min();
+ }
+};
+
+template <>
+struct npy_traits<NPY_TIMEDELTA> {
+ typedef int64_t value_type;
+ using TypeClass = DurationType;
+ using BuilderClass = DurationBuilder;
+
+ static constexpr bool supports_nulls = true;
+
+ static inline bool isnull(int64_t v) {
+ // NaT = -2**63 = std::numeric_limits<int64_t>::min()
+ return v == std::numeric_limits<int64_t>::min();
+ }
+};
+
+template <>
+struct npy_traits<NPY_OBJECT> {
+ typedef PyObject* value_type;
+ static constexpr bool supports_nulls = true;
+
+ static inline bool isnull(PyObject* v) { return v == Py_None; }
+};
+
+//
+// Type traits for Arrow -> Numpy equivalence
+// Note *supports_nulls* means the equivalent Numpy type support nulls
+//
+template <int TYPE>
+struct arrow_traits {};
+
+template <>
+struct arrow_traits<Type::BOOL> {
+ static constexpr int npy_type = NPY_BOOL;
+ static constexpr bool supports_nulls = false;
+ typedef typename npy_traits<NPY_BOOL>::value_type T;
+};
+
+#define INT_DECL(TYPE) \
+ template <> \
+ struct arrow_traits<Type::TYPE> { \
+ static constexpr int npy_type = NPY_##TYPE; \
+ static constexpr bool supports_nulls = false; \
+ static constexpr double na_value = std::numeric_limits<double>::quiet_NaN(); \
+ typedef typename npy_traits<NPY_##TYPE>::value_type T; \
+ };
+
+INT_DECL(INT8);
+INT_DECL(INT16);
+INT_DECL(INT32);
+INT_DECL(INT64);
+INT_DECL(UINT8);
+INT_DECL(UINT16);
+INT_DECL(UINT32);
+INT_DECL(UINT64);
+
+template <>
+struct arrow_traits<Type::HALF_FLOAT> {
+ static constexpr int npy_type = NPY_FLOAT16;
+ static constexpr bool supports_nulls = true;
+ static constexpr uint16_t na_value = NPY_HALF_NAN;
+ typedef typename npy_traits<NPY_FLOAT16>::value_type T;
+};
+
+template <>
+struct arrow_traits<Type::FLOAT> {
+ static constexpr int npy_type = NPY_FLOAT32;
+ static constexpr bool supports_nulls = true;
+ static constexpr float na_value = std::numeric_limits<float>::quiet_NaN();
+ typedef typename npy_traits<NPY_FLOAT32>::value_type T;
+};
+
+template <>
+struct arrow_traits<Type::DOUBLE> {
+ static constexpr int npy_type = NPY_FLOAT64;
+ static constexpr bool supports_nulls = true;
+ static constexpr double na_value = std::numeric_limits<double>::quiet_NaN();
+ typedef typename npy_traits<NPY_FLOAT64>::value_type T;
+};
+
+template <>
+struct arrow_traits<Type::TIMESTAMP> {
+ static constexpr int npy_type = NPY_DATETIME;
+ static constexpr int64_t npy_shift = 1;
+
+ static constexpr bool supports_nulls = true;
+ static constexpr int64_t na_value = kPandasTimestampNull;
+ typedef typename npy_traits<NPY_DATETIME>::value_type T;
+};
+
+template <>
+struct arrow_traits<Type::DURATION> {
+ static constexpr int npy_type = NPY_TIMEDELTA;
+ static constexpr int64_t npy_shift = 1;
+
+ static constexpr bool supports_nulls = true;
+ static constexpr int64_t na_value = kPandasTimestampNull;
+ typedef typename npy_traits<NPY_TIMEDELTA>::value_type T;
+};
+
+template <>
+struct arrow_traits<Type::DATE32> {
+ // Data stores as FR_D day unit
+ static constexpr int npy_type = NPY_DATETIME;
+ static constexpr int64_t npy_shift = 1;
+
+ static constexpr bool supports_nulls = true;
+ typedef typename npy_traits<NPY_DATETIME>::value_type T;
+
+ static constexpr int64_t na_value = kPandasTimestampNull;
+ static inline bool isnull(int64_t v) { return npy_traits<NPY_DATETIME>::isnull(v); }
+};
+
+template <>
+struct arrow_traits<Type::DATE64> {
+ // Data stores as FR_D day unit
+ static constexpr int npy_type = NPY_DATETIME;
+
+ // There are 1000 * 60 * 60 * 24 = 86400000ms in a day
+ static constexpr int64_t npy_shift = 86400000;
+
+ static constexpr bool supports_nulls = true;
+ typedef typename npy_traits<NPY_DATETIME>::value_type T;
+
+ static constexpr int64_t na_value = kPandasTimestampNull;
+ static inline bool isnull(int64_t v) { return npy_traits<NPY_DATETIME>::isnull(v); }
+};
+
+template <>
+struct arrow_traits<Type::TIME32> {
+ static constexpr int npy_type = NPY_OBJECT;
+ static constexpr bool supports_nulls = true;
+ static constexpr int64_t na_value = kPandasTimestampNull;
+ typedef typename npy_traits<NPY_DATETIME>::value_type T;
+};
+
+template <>
+struct arrow_traits<Type::TIME64> {
+ static constexpr int npy_type = NPY_OBJECT;
+ static constexpr bool supports_nulls = true;
+ typedef typename npy_traits<NPY_DATETIME>::value_type T;
+};
+
+template <>
+struct arrow_traits<Type::STRING> {
+ static constexpr int npy_type = NPY_OBJECT;
+ static constexpr bool supports_nulls = true;
+};
+
+template <>
+struct arrow_traits<Type::BINARY> {
+ static constexpr int npy_type = NPY_OBJECT;
+ static constexpr bool supports_nulls = true;
+};
+
+static inline NPY_DATETIMEUNIT NumPyFrequency(TimeUnit::type unit) {
+ switch (unit) {
+ case TimestampType::Unit::SECOND:
+ return NPY_FR_s;
+ case TimestampType::Unit::MILLI:
+ return NPY_FR_ms;
+ break;
+ case TimestampType::Unit::MICRO:
+ return NPY_FR_us;
+ default:
+ // NANO
+ return NPY_FR_ns;
+ }
+}
+
+static inline int NumPyTypeSize(int npy_type) {
+ npy_type = fix_numpy_type_num(npy_type);
+
+ switch (npy_type) {
+ case NPY_BOOL:
+ case NPY_INT8:
+ case NPY_UINT8:
+ return 1;
+ case NPY_INT16:
+ case NPY_UINT16:
+ return 2;
+ case NPY_INT32:
+ case NPY_UINT32:
+ return 4;
+ case NPY_INT64:
+ case NPY_UINT64:
+ return 8;
+ case NPY_FLOAT16:
+ return 2;
+ case NPY_FLOAT32:
+ return 4;
+ case NPY_FLOAT64:
+ return 8;
+ case NPY_DATETIME:
+ return 8;
+ case NPY_OBJECT:
+ return sizeof(void*);
+ default:
+ ARROW_CHECK(false) << "unhandled numpy type";
+ break;
+ }
+ return -1;
+}
+
+} // namespace internal
+} // namespace py
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/python/util/CMakeLists.txt b/src/arrow/cpp/src/arrow/python/util/CMakeLists.txt
new file mode 100644
index 000000000..74141bebc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/util/CMakeLists.txt
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#
+# arrow/python_test_main
+#
+
+if(PYARROW_BUILD_TESTS)
+ add_library(arrow/python_test_main STATIC test_main.cc)
+
+ if(APPLE)
+ target_link_libraries(arrow/python_test_main GTest::gtest dl)
+ set_target_properties(arrow/python_test_main PROPERTIES LINK_FLAGS
+ "-undefined dynamic_lookup")
+ else()
+ target_link_libraries(arrow/python_test_main GTest::gtest pthread dl)
+ endif()
+endif()
diff --git a/src/arrow/cpp/src/arrow/python/util/test_main.cc b/src/arrow/cpp/src/arrow/python/util/test_main.cc
new file mode 100644
index 000000000..dd7f379bd
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/util/test_main.cc
@@ -0,0 +1,41 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/python/platform.h"
+
+#include <gtest/gtest.h>
+
+#include "arrow/python/datetime.h"
+#include "arrow/python/init.h"
+#include "arrow/python/pyarrow.h"
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+
+ Py_Initialize();
+ int ret = arrow_init_numpy();
+ if (ret != 0) {
+ return ret;
+ }
+ ::arrow::py::internal::InitDatetime();
+
+ ret = RUN_ALL_TESTS();
+
+ Py_Finalize();
+
+ return ret;
+}
diff --git a/src/arrow/cpp/src/arrow/python/visibility.h b/src/arrow/cpp/src/arrow/python/visibility.h
new file mode 100644
index 000000000..c0b343c70
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/python/visibility.h
@@ -0,0 +1,39 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#if defined(_WIN32) || defined(__CYGWIN__) // Windows
+#if defined(_MSC_VER)
+#pragma warning(disable : 4251)
+#else
+#pragma GCC diagnostic ignored "-Wattributes"
+#endif
+
+#ifdef ARROW_STATIC
+#define ARROW_PYTHON_EXPORT
+#elif defined(ARROW_PYTHON_EXPORTING)
+#define ARROW_PYTHON_EXPORT __declspec(dllexport)
+#else
+#define ARROW_PYTHON_EXPORT __declspec(dllimport)
+#endif
+
+#else // Not Windows
+#ifndef ARROW_PYTHON_EXPORT
+#define ARROW_PYTHON_EXPORT __attribute__((visibility("default")))
+#endif
+#endif // Non-Windows
diff --git a/src/arrow/cpp/src/arrow/record_batch.cc b/src/arrow/cpp/src/arrow/record_batch.cc
new file mode 100644
index 000000000..66f9e932b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/record_batch.cc
@@ -0,0 +1,367 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/record_batch.h"
+
+#include <algorithm>
+#include <cstdlib>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+
+#include "arrow/array.h"
+#include "arrow/array/validate.h"
+#include "arrow/pretty_print.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/util/atomic_shared_ptr.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/vector.h"
+
+namespace arrow {
+
+Result<std::shared_ptr<RecordBatch>> RecordBatch::AddColumn(
+ int i, std::string field_name, const std::shared_ptr<Array>& column) const {
+ auto field = ::arrow::field(std::move(field_name), column->type());
+ return AddColumn(i, field, column);
+}
+
+std::shared_ptr<Array> RecordBatch::GetColumnByName(const std::string& name) const {
+ auto i = schema_->GetFieldIndex(name);
+ return i == -1 ? NULLPTR : column(i);
+}
+
+int RecordBatch::num_columns() const { return schema_->num_fields(); }
+
+/// \class SimpleRecordBatch
+/// \brief A basic, non-lazy in-memory record batch
+class SimpleRecordBatch : public RecordBatch {
+ public:
+ SimpleRecordBatch(std::shared_ptr<Schema> schema, int64_t num_rows,
+ std::vector<std::shared_ptr<Array>> columns)
+ : RecordBatch(std::move(schema), num_rows), boxed_columns_(std::move(columns)) {
+ columns_.resize(boxed_columns_.size());
+ for (size_t i = 0; i < columns_.size(); ++i) {
+ columns_[i] = boxed_columns_[i]->data();
+ }
+ }
+
+ SimpleRecordBatch(const std::shared_ptr<Schema>& schema, int64_t num_rows,
+ std::vector<std::shared_ptr<ArrayData>> columns)
+ : RecordBatch(std::move(schema), num_rows), columns_(std::move(columns)) {
+ boxed_columns_.resize(schema_->num_fields());
+ }
+
+ const std::vector<std::shared_ptr<Array>>& columns() const override {
+ for (int i = 0; i < num_columns(); ++i) {
+ // Force all columns to be boxed
+ column(i);
+ }
+ return boxed_columns_;
+ }
+
+ std::shared_ptr<Array> column(int i) const override {
+ std::shared_ptr<Array> result = internal::atomic_load(&boxed_columns_[i]);
+ if (!result) {
+ result = MakeArray(columns_[i]);
+ internal::atomic_store(&boxed_columns_[i], result);
+ }
+ return result;
+ }
+
+ std::shared_ptr<ArrayData> column_data(int i) const override { return columns_[i]; }
+
+ const ArrayDataVector& column_data() const override { return columns_; }
+
+ Result<std::shared_ptr<RecordBatch>> AddColumn(
+ int i, const std::shared_ptr<Field>& field,
+ const std::shared_ptr<Array>& column) const override {
+ ARROW_CHECK(field != nullptr);
+ ARROW_CHECK(column != nullptr);
+
+ if (!field->type()->Equals(column->type())) {
+ return Status::TypeError("Column data type ", field->type()->name(),
+ " does not match field data type ",
+ column->type()->name());
+ }
+ if (column->length() != num_rows_) {
+ return Status::Invalid(
+ "Added column's length must match record batch's length. Expected length ",
+ num_rows_, " but got length ", column->length());
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->AddField(i, field));
+ return RecordBatch::Make(std::move(new_schema), num_rows_,
+ internal::AddVectorElement(columns_, i, column->data()));
+ }
+
+ Result<std::shared_ptr<RecordBatch>> SetColumn(
+ int i, const std::shared_ptr<Field>& field,
+ const std::shared_ptr<Array>& column) const override {
+ ARROW_CHECK(field != nullptr);
+ ARROW_CHECK(column != nullptr);
+
+ if (!field->type()->Equals(column->type())) {
+ return Status::TypeError("Column data type ", field->type()->name(),
+ " does not match field data type ",
+ column->type()->name());
+ }
+ if (column->length() != num_rows_) {
+ return Status::Invalid(
+ "Added column's length must match record batch's length. Expected length ",
+ num_rows_, " but got length ", column->length());
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->SetField(i, field));
+ return RecordBatch::Make(std::move(new_schema), num_rows_,
+ internal::ReplaceVectorElement(columns_, i, column->data()));
+ }
+
+ Result<std::shared_ptr<RecordBatch>> RemoveColumn(int i) const override {
+ ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->RemoveField(i));
+ return RecordBatch::Make(std::move(new_schema), num_rows_,
+ internal::DeleteVectorElement(columns_, i));
+ }
+
+ std::shared_ptr<RecordBatch> ReplaceSchemaMetadata(
+ const std::shared_ptr<const KeyValueMetadata>& metadata) const override {
+ auto new_schema = schema_->WithMetadata(metadata);
+ return RecordBatch::Make(std::move(new_schema), num_rows_, columns_);
+ }
+
+ std::shared_ptr<RecordBatch> Slice(int64_t offset, int64_t length) const override {
+ std::vector<std::shared_ptr<ArrayData>> arrays;
+ arrays.reserve(num_columns());
+ for (const auto& field : columns_) {
+ arrays.emplace_back(field->Slice(offset, length));
+ }
+ int64_t num_rows = std::min(num_rows_ - offset, length);
+ return std::make_shared<SimpleRecordBatch>(schema_, num_rows, std::move(arrays));
+ }
+
+ Status Validate() const override {
+ if (static_cast<int>(columns_.size()) != schema_->num_fields()) {
+ return Status::Invalid("Number of columns did not match schema");
+ }
+ return RecordBatch::Validate();
+ }
+
+ private:
+ std::vector<std::shared_ptr<ArrayData>> columns_;
+
+ // Caching boxed array data
+ mutable std::vector<std::shared_ptr<Array>> boxed_columns_;
+};
+
+RecordBatch::RecordBatch(const std::shared_ptr<Schema>& schema, int64_t num_rows)
+ : schema_(schema), num_rows_(num_rows) {}
+
+std::shared_ptr<RecordBatch> RecordBatch::Make(
+ std::shared_ptr<Schema> schema, int64_t num_rows,
+ std::vector<std::shared_ptr<Array>> columns) {
+ DCHECK_EQ(schema->num_fields(), static_cast<int>(columns.size()));
+ return std::make_shared<SimpleRecordBatch>(std::move(schema), num_rows, columns);
+}
+
+std::shared_ptr<RecordBatch> RecordBatch::Make(
+ std::shared_ptr<Schema> schema, int64_t num_rows,
+ std::vector<std::shared_ptr<ArrayData>> columns) {
+ DCHECK_EQ(schema->num_fields(), static_cast<int>(columns.size()));
+ return std::make_shared<SimpleRecordBatch>(std::move(schema), num_rows,
+ std::move(columns));
+}
+
+Result<std::shared_ptr<RecordBatch>> RecordBatch::FromStructArray(
+ const std::shared_ptr<Array>& array) {
+ if (array->type_id() != Type::STRUCT) {
+ return Status::TypeError("Cannot construct record batch from array of type ",
+ *array->type());
+ }
+ if (array->null_count() != 0) {
+ return Status::Invalid(
+ "Unable to construct record batch from a StructArray with non-zero nulls.");
+ }
+ return Make(arrow::schema(array->type()->fields()), array->length(),
+ array->data()->child_data);
+}
+
+Result<std::shared_ptr<StructArray>> RecordBatch::ToStructArray() const {
+ if (num_columns() != 0) {
+ return StructArray::Make(columns(), schema()->fields());
+ }
+ return std::make_shared<StructArray>(arrow::struct_({}), num_rows_,
+ std::vector<std::shared_ptr<Array>>{},
+ /*null_bitmap=*/nullptr,
+ /*null_count=*/0,
+ /*offset=*/0);
+}
+
+const std::string& RecordBatch::column_name(int i) const {
+ return schema_->field(i)->name();
+}
+
+bool RecordBatch::Equals(const RecordBatch& other, bool check_metadata) const {
+ if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) {
+ return false;
+ }
+
+ if (check_metadata) {
+ if (!schema_->Equals(*other.schema(), /*check_metadata=*/true)) {
+ return false;
+ }
+ }
+
+ for (int i = 0; i < num_columns(); ++i) {
+ if (!column(i)->Equals(other.column(i))) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool RecordBatch::ApproxEquals(const RecordBatch& other) const {
+ if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) {
+ return false;
+ }
+
+ for (int i = 0; i < num_columns(); ++i) {
+ if (!column(i)->ApproxEquals(other.column(i))) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+Result<std::shared_ptr<RecordBatch>> RecordBatch::SelectColumns(
+ const std::vector<int>& indices) const {
+ int n = static_cast<int>(indices.size());
+
+ FieldVector fields(n);
+ ArrayVector columns(n);
+
+ for (int i = 0; i < n; i++) {
+ int pos = indices[i];
+ if (pos < 0 || pos > num_columns() - 1) {
+ return Status::Invalid("Invalid column index ", pos, " to select columns.");
+ }
+ fields[i] = schema()->field(pos);
+ columns[i] = column(pos);
+ }
+
+ auto new_schema =
+ std::make_shared<arrow::Schema>(std::move(fields), schema()->metadata());
+ return RecordBatch::Make(std::move(new_schema), num_rows(), std::move(columns));
+}
+
+std::shared_ptr<RecordBatch> RecordBatch::Slice(int64_t offset) const {
+ return Slice(offset, this->num_rows() - offset);
+}
+
+std::string RecordBatch::ToString() const {
+ std::stringstream ss;
+ ARROW_CHECK_OK(PrettyPrint(*this, 0, &ss));
+ return ss.str();
+}
+
+Status RecordBatch::Validate() const {
+ for (int i = 0; i < num_columns(); ++i) {
+ const auto& array = *this->column(i);
+ if (array.length() != num_rows_) {
+ return Status::Invalid("Number of rows in column ", i,
+ " did not match batch: ", array.length(), " vs ", num_rows_);
+ }
+ const auto& schema_type = *schema_->field(i)->type();
+ if (!array.type()->Equals(schema_type)) {
+ return Status::Invalid("Column ", i,
+ " type not match schema: ", array.type()->ToString(), " vs ",
+ schema_type.ToString());
+ }
+ RETURN_NOT_OK(internal::ValidateArray(array));
+ }
+ return Status::OK();
+}
+
+Status RecordBatch::ValidateFull() const {
+ RETURN_NOT_OK(Validate());
+ for (int i = 0; i < num_columns(); ++i) {
+ const auto& array = *this->column(i);
+ RETURN_NOT_OK(internal::ValidateArrayFull(array));
+ }
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Base record batch reader
+
+Status RecordBatchReader::ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches) {
+ while (true) {
+ std::shared_ptr<RecordBatch> batch;
+ RETURN_NOT_OK(ReadNext(&batch));
+ if (!batch) {
+ break;
+ }
+ batches->emplace_back(std::move(batch));
+ }
+ return Status::OK();
+}
+
+Status RecordBatchReader::ReadAll(std::shared_ptr<Table>* table) {
+ std::vector<std::shared_ptr<RecordBatch>> batches;
+ RETURN_NOT_OK(ReadAll(&batches));
+ return Table::FromRecordBatches(schema(), std::move(batches)).Value(table);
+}
+
+class SimpleRecordBatchReader : public RecordBatchReader {
+ public:
+ SimpleRecordBatchReader(Iterator<std::shared_ptr<RecordBatch>> it,
+ std::shared_ptr<Schema> schema)
+ : schema_(std::move(schema)), it_(std::move(it)) {}
+
+ SimpleRecordBatchReader(std::vector<std::shared_ptr<RecordBatch>> batches,
+ std::shared_ptr<Schema> schema)
+ : schema_(std::move(schema)), it_(MakeVectorIterator(std::move(batches))) {}
+
+ Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
+ return it_.Next().Value(batch);
+ }
+
+ std::shared_ptr<Schema> schema() const override { return schema_; }
+
+ protected:
+ std::shared_ptr<Schema> schema_;
+ Iterator<std::shared_ptr<RecordBatch>> it_;
+};
+
+Result<std::shared_ptr<RecordBatchReader>> RecordBatchReader::Make(
+ std::vector<std::shared_ptr<RecordBatch>> batches, std::shared_ptr<Schema> schema) {
+ if (schema == nullptr) {
+ if (batches.size() == 0 || batches[0] == nullptr) {
+ return Status::Invalid("Cannot infer schema from empty vector or nullptr");
+ }
+
+ schema = batches[0]->schema();
+ }
+
+ return std::make_shared<SimpleRecordBatchReader>(std::move(batches), schema);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/record_batch.h b/src/arrow/cpp/src/arrow/record_batch.h
new file mode 100644
index 000000000..3173eee10
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/record_batch.h
@@ -0,0 +1,241 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+/// \class RecordBatch
+/// \brief Collection of equal-length arrays matching a particular Schema
+///
+/// A record batch is table-like data structure that is semantically a sequence
+/// of fields, each a contiguous Arrow array
+class ARROW_EXPORT RecordBatch {
+ public:
+ virtual ~RecordBatch() = default;
+
+ /// \param[in] schema The record batch schema
+ /// \param[in] num_rows length of fields in the record batch. Each array
+ /// should have the same length as num_rows
+ /// \param[in] columns the record batch fields as vector of arrays
+ static std::shared_ptr<RecordBatch> Make(std::shared_ptr<Schema> schema,
+ int64_t num_rows,
+ std::vector<std::shared_ptr<Array>> columns);
+
+ /// \brief Construct record batch from vector of internal data structures
+ /// \since 0.5.0
+ ///
+ /// This class is intended for internal use, or advanced users.
+ ///
+ /// \param schema the record batch schema
+ /// \param num_rows the number of semantic rows in the record batch. This
+ /// should be equal to the length of each field
+ /// \param columns the data for the batch's columns
+ static std::shared_ptr<RecordBatch> Make(
+ std::shared_ptr<Schema> schema, int64_t num_rows,
+ std::vector<std::shared_ptr<ArrayData>> columns);
+
+ /// \brief Convert record batch to struct array
+ ///
+ /// Create a struct array whose child arrays are the record batch's columns.
+ /// Note that the record batch's top-level field metadata cannot be reflected
+ /// in the resulting struct array.
+ Result<std::shared_ptr<StructArray>> ToStructArray() const;
+
+ /// \brief Construct record batch from struct array
+ ///
+ /// This constructs a record batch using the child arrays of the given
+ /// array, which must be a struct array. Note that the struct array's own
+ /// null bitmap is not reflected in the resulting record batch.
+ static Result<std::shared_ptr<RecordBatch>> FromStructArray(
+ const std::shared_ptr<Array>& array);
+
+ /// \brief Determine if two record batches are exactly equal
+ ///
+ /// \param[in] other the RecordBatch to compare with
+ /// \param[in] check_metadata if true, check that Schema metadata is the same
+ /// \return true if batches are equal
+ bool Equals(const RecordBatch& other, bool check_metadata = false) const;
+
+ /// \brief Determine if two record batches are approximately equal
+ bool ApproxEquals(const RecordBatch& other) const;
+
+ /// \return the record batch's schema
+ const std::shared_ptr<Schema>& schema() const { return schema_; }
+
+ /// \brief Retrieve all columns at once
+ virtual const std::vector<std::shared_ptr<Array>>& columns() const = 0;
+
+ /// \brief Retrieve an array from the record batch
+ /// \param[in] i field index, does not boundscheck
+ /// \return an Array object
+ virtual std::shared_ptr<Array> column(int i) const = 0;
+
+ /// \brief Retrieve an array from the record batch
+ /// \param[in] name field name
+ /// \return an Array or null if no field was found
+ std::shared_ptr<Array> GetColumnByName(const std::string& name) const;
+
+ /// \brief Retrieve an array's internal data from the record batch
+ /// \param[in] i field index, does not boundscheck
+ /// \return an internal ArrayData object
+ virtual std::shared_ptr<ArrayData> column_data(int i) const = 0;
+
+ /// \brief Retrieve all arrays' internal data from the record batch.
+ virtual const ArrayDataVector& column_data() const = 0;
+
+ /// \brief Add column to the record batch, producing a new RecordBatch
+ ///
+ /// \param[in] i field index, which will be boundschecked
+ /// \param[in] field field to be added
+ /// \param[in] column column to be added
+ virtual Result<std::shared_ptr<RecordBatch>> AddColumn(
+ int i, const std::shared_ptr<Field>& field,
+ const std::shared_ptr<Array>& column) const = 0;
+
+ /// \brief Add new nullable column to the record batch, producing a new
+ /// RecordBatch.
+ ///
+ /// For non-nullable columns, use the Field-based version of this method.
+ ///
+ /// \param[in] i field index, which will be boundschecked
+ /// \param[in] field_name name of field to be added
+ /// \param[in] column column to be added
+ virtual Result<std::shared_ptr<RecordBatch>> AddColumn(
+ int i, std::string field_name, const std::shared_ptr<Array>& column) const;
+
+ /// \brief Replace a column in the record batch, producing a new RecordBatch
+ ///
+ /// \param[in] i field index, does boundscheck
+ /// \param[in] field field to be replaced
+ /// \param[in] column column to be replaced
+ virtual Result<std::shared_ptr<RecordBatch>> SetColumn(
+ int i, const std::shared_ptr<Field>& field,
+ const std::shared_ptr<Array>& column) const = 0;
+
+ /// \brief Remove column from the record batch, producing a new RecordBatch
+ ///
+ /// \param[in] i field index, does boundscheck
+ virtual Result<std::shared_ptr<RecordBatch>> RemoveColumn(int i) const = 0;
+
+ virtual std::shared_ptr<RecordBatch> ReplaceSchemaMetadata(
+ const std::shared_ptr<const KeyValueMetadata>& metadata) const = 0;
+
+ /// \brief Name in i-th column
+ const std::string& column_name(int i) const;
+
+ /// \return the number of columns in the table
+ int num_columns() const;
+
+ /// \return the number of rows (the corresponding length of each column)
+ int64_t num_rows() const { return num_rows_; }
+
+ /// \brief Slice each of the arrays in the record batch
+ /// \param[in] offset the starting offset to slice, through end of batch
+ /// \return new record batch
+ virtual std::shared_ptr<RecordBatch> Slice(int64_t offset) const;
+
+ /// \brief Slice each of the arrays in the record batch
+ /// \param[in] offset the starting offset to slice
+ /// \param[in] length the number of elements to slice from offset
+ /// \return new record batch
+ virtual std::shared_ptr<RecordBatch> Slice(int64_t offset, int64_t length) const = 0;
+
+ /// \return PrettyPrint representation suitable for debugging
+ std::string ToString() const;
+
+ /// \brief Return new record batch with specified columns
+ Result<std::shared_ptr<RecordBatch>> SelectColumns(
+ const std::vector<int>& indices) const;
+
+ /// \brief Perform cheap validation checks to determine obvious inconsistencies
+ /// within the record batch's schema and internal data.
+ ///
+ /// This is O(k) where k is the total number of fields and array descendents.
+ ///
+ /// \return Status
+ virtual Status Validate() const;
+
+ /// \brief Perform extensive validation checks to determine inconsistencies
+ /// within the record batch's schema and internal data.
+ ///
+ /// This is potentially O(k*n) where n is the number of rows.
+ ///
+ /// \return Status
+ virtual Status ValidateFull() const;
+
+ protected:
+ RecordBatch(const std::shared_ptr<Schema>& schema, int64_t num_rows);
+
+ std::shared_ptr<Schema> schema_;
+ int64_t num_rows_;
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(RecordBatch);
+};
+
+/// \brief Abstract interface for reading stream of record batches
+class ARROW_EXPORT RecordBatchReader {
+ public:
+ using ValueType = std::shared_ptr<RecordBatch>;
+
+ virtual ~RecordBatchReader() = default;
+
+ /// \return the shared schema of the record batches in the stream
+ virtual std::shared_ptr<Schema> schema() const = 0;
+
+ /// \brief Read the next record batch in the stream. Return null for batch
+ /// when reaching end of stream
+ ///
+ /// \param[out] batch the next loaded batch, null at end of stream
+ /// \return Status
+ virtual Status ReadNext(std::shared_ptr<RecordBatch>* batch) = 0;
+
+ /// \brief Iterator interface
+ Result<std::shared_ptr<RecordBatch>> Next() {
+ std::shared_ptr<RecordBatch> batch;
+ ARROW_RETURN_NOT_OK(ReadNext(&batch));
+ return batch;
+ }
+
+ /// \brief Consume entire stream as a vector of record batches
+ Status ReadAll(RecordBatchVector* batches);
+
+ /// \brief Read all batches and concatenate as arrow::Table
+ Status ReadAll(std::shared_ptr<Table>* table);
+
+ /// \brief Create a RecordBatchReader from a vector of RecordBatch.
+ ///
+ /// \param[in] batches the vector of RecordBatch to read from
+ /// \param[in] schema schema to conform to. Will be inferred from the first
+ /// element if not provided.
+ static Result<std::shared_ptr<RecordBatchReader>> Make(
+ RecordBatchVector batches, std::shared_ptr<Schema> schema = NULLPTR);
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/record_batch_test.cc b/src/arrow/cpp/src/arrow/record_batch_test.cc
new file mode 100644
index 000000000..9de57f183
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/record_batch_test.cc
@@ -0,0 +1,320 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/record_batch.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/data.h"
+#include "arrow/array/util.h"
+#include "arrow/chunked_array.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+
+class TestRecordBatch : public TestBase {};
+
+TEST_F(TestRecordBatch, Equals) {
+ const int length = 10;
+
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8());
+ auto f2 = field("f2", int16());
+
+ auto metadata = key_value_metadata({"foo"}, {"bar"});
+
+ std::vector<std::shared_ptr<Field>> fields = {f0, f1, f2};
+ auto schema = ::arrow::schema({f0, f1, f2});
+ auto schema2 = ::arrow::schema({f0, f1});
+ auto schema3 = ::arrow::schema({f0, f1, f2}, metadata);
+
+ auto a0 = MakeRandomArray<Int32Array>(length);
+ auto a1 = MakeRandomArray<UInt8Array>(length);
+ auto a2 = MakeRandomArray<Int16Array>(length);
+
+ auto b1 = RecordBatch::Make(schema, length, {a0, a1, a2});
+ auto b2 = RecordBatch::Make(schema3, length, {a0, a1, a2});
+ auto b3 = RecordBatch::Make(schema2, length, {a0, a1});
+ auto b4 = RecordBatch::Make(schema, length, {a0, a1, a1});
+
+ ASSERT_TRUE(b1->Equals(*b1));
+ ASSERT_FALSE(b1->Equals(*b3));
+ ASSERT_FALSE(b1->Equals(*b4));
+
+ // Different metadata
+ ASSERT_TRUE(b1->Equals(*b2));
+ ASSERT_FALSE(b1->Equals(*b2, /*check_metadata=*/true));
+}
+
+TEST_F(TestRecordBatch, Validate) {
+ const int length = 10;
+
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8());
+ auto f2 = field("f2", int16());
+
+ auto schema = ::arrow::schema({f0, f1, f2});
+
+ auto a0 = MakeRandomArray<Int32Array>(length);
+ auto a1 = MakeRandomArray<UInt8Array>(length);
+ auto a2 = MakeRandomArray<Int16Array>(length);
+ auto a3 = MakeRandomArray<Int16Array>(5);
+
+ auto b1 = RecordBatch::Make(schema, length, {a0, a1, a2});
+
+ ASSERT_OK(b1->ValidateFull());
+
+ // Length mismatch
+ auto b2 = RecordBatch::Make(schema, length, {a0, a1, a3});
+ ASSERT_RAISES(Invalid, b2->ValidateFull());
+
+ // Type mismatch
+ auto b3 = RecordBatch::Make(schema, length, {a0, a1, a0});
+ ASSERT_RAISES(Invalid, b3->ValidateFull());
+}
+
+TEST_F(TestRecordBatch, Slice) {
+ const int length = 7;
+
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8());
+ auto f2 = field("f2", int8());
+
+ std::vector<std::shared_ptr<Field>> fields = {f0, f1, f2};
+ auto schema = ::arrow::schema(fields);
+
+ auto a0 = MakeRandomArray<Int32Array>(length);
+ auto a1 = MakeRandomArray<UInt8Array>(length);
+ auto a2 = ArrayFromJSON(int8(), "[0, 1, 2, 3, 4, 5, 6]");
+
+ auto batch = RecordBatch::Make(schema, length, {a0, a1, a2});
+
+ auto batch_slice = batch->Slice(2);
+ auto batch_slice2 = batch->Slice(1, 5);
+
+ ASSERT_EQ(batch_slice->num_rows(), batch->num_rows() - 2);
+
+ for (int i = 0; i < batch->num_columns(); ++i) {
+ ASSERT_EQ(2, batch_slice->column(i)->offset());
+ ASSERT_EQ(length - 2, batch_slice->column(i)->length());
+
+ ASSERT_EQ(1, batch_slice2->column(i)->offset());
+ ASSERT_EQ(5, batch_slice2->column(i)->length());
+ }
+
+ // ARROW-9143: RecordBatch::Slice was incorrectly setting a2's
+ // ArrayData::null_count to kUnknownNullCount
+ ASSERT_EQ(batch_slice->column(2)->data()->null_count, 0);
+ ASSERT_EQ(batch_slice2->column(2)->data()->null_count, 0);
+}
+
+TEST_F(TestRecordBatch, AddColumn) {
+ const int length = 10;
+
+ auto field1 = field("f1", int32());
+ auto field2 = field("f2", uint8());
+ auto field3 = field("f3", int16());
+
+ auto schema1 = ::arrow::schema({field1, field2});
+ auto schema2 = ::arrow::schema({field2, field3});
+ auto schema3 = ::arrow::schema({field2});
+
+ auto array1 = MakeRandomArray<Int32Array>(length);
+ auto array2 = MakeRandomArray<UInt8Array>(length);
+ auto array3 = MakeRandomArray<Int16Array>(length);
+
+ auto batch1 = RecordBatch::Make(schema1, length, {array1, array2});
+ auto batch2 = RecordBatch::Make(schema2, length, {array2, array3});
+ auto batch3 = RecordBatch::Make(schema3, length, {array2});
+
+ const RecordBatch& batch = *batch3;
+
+ // Negative tests with invalid index
+ ASSERT_RAISES(Invalid, batch.AddColumn(5, field1, array1));
+ ASSERT_RAISES(Invalid, batch.AddColumn(2, field1, array1));
+ ASSERT_RAISES(Invalid, batch.AddColumn(-1, field1, array1));
+
+ // Negative test with wrong length
+ auto longer_col = MakeRandomArray<Int32Array>(length + 1);
+ ASSERT_RAISES(Invalid, batch.AddColumn(0, field1, longer_col));
+
+ // Negative test with mismatch type
+ ASSERT_RAISES(TypeError, batch.AddColumn(0, field1, array2));
+
+ ASSERT_OK_AND_ASSIGN(auto new_batch, batch.AddColumn(0, field1, array1));
+ AssertBatchesEqual(*new_batch, *batch1);
+
+ ASSERT_OK_AND_ASSIGN(new_batch, batch.AddColumn(1, field3, array3));
+ AssertBatchesEqual(*new_batch, *batch2);
+
+ ASSERT_OK_AND_ASSIGN(auto new_batch2, batch.AddColumn(1, "f3", array3));
+ AssertBatchesEqual(*new_batch2, *new_batch);
+
+ ASSERT_TRUE(new_batch2->schema()->field(1)->nullable());
+}
+
+TEST_F(TestRecordBatch, SetColumn) {
+ const int length = 10;
+
+ auto field1 = field("f1", int32());
+ auto field2 = field("f2", uint8());
+ auto field3 = field("f3", int16());
+
+ auto schema1 = ::arrow::schema({field1, field2});
+ auto schema2 = ::arrow::schema({field1, field3});
+ auto schema3 = ::arrow::schema({field3, field2});
+
+ auto array1 = MakeRandomArray<Int32Array>(length);
+ auto array2 = MakeRandomArray<UInt8Array>(length);
+ auto array3 = MakeRandomArray<Int16Array>(length);
+
+ auto batch1 = RecordBatch::Make(schema1, length, {array1, array2});
+ auto batch2 = RecordBatch::Make(schema2, length, {array1, array3});
+ auto batch3 = RecordBatch::Make(schema3, length, {array3, array2});
+
+ const RecordBatch& batch = *batch1;
+
+ // Negative tests with invalid index
+ ASSERT_RAISES(Invalid, batch.SetColumn(5, field1, array1));
+ ASSERT_RAISES(Invalid, batch.SetColumn(-1, field1, array1));
+
+ // Negative test with wrong length
+ auto longer_col = MakeRandomArray<Int32Array>(length + 1);
+ ASSERT_RAISES(Invalid, batch.SetColumn(0, field1, longer_col));
+
+ // Negative test with mismatch type
+ ASSERT_RAISES(TypeError, batch.SetColumn(0, field1, array2));
+
+ ASSERT_OK_AND_ASSIGN(auto new_batch, batch.SetColumn(1, field3, array3));
+ AssertBatchesEqual(*new_batch, *batch2);
+
+ ASSERT_OK_AND_ASSIGN(new_batch, batch.SetColumn(0, field3, array3));
+ AssertBatchesEqual(*new_batch, *batch3);
+}
+
+TEST_F(TestRecordBatch, RemoveColumn) {
+ const int length = 10;
+
+ auto field1 = field("f1", int32());
+ auto field2 = field("f2", uint8());
+ auto field3 = field("f3", int16());
+
+ auto schema1 = ::arrow::schema({field1, field2, field3});
+ auto schema2 = ::arrow::schema({field2, field3});
+ auto schema3 = ::arrow::schema({field1, field3});
+ auto schema4 = ::arrow::schema({field1, field2});
+
+ auto array1 = MakeRandomArray<Int32Array>(length);
+ auto array2 = MakeRandomArray<UInt8Array>(length);
+ auto array3 = MakeRandomArray<Int16Array>(length);
+
+ auto batch1 = RecordBatch::Make(schema1, length, {array1, array2, array3});
+ auto batch2 = RecordBatch::Make(schema2, length, {array2, array3});
+ auto batch3 = RecordBatch::Make(schema3, length, {array1, array3});
+ auto batch4 = RecordBatch::Make(schema4, length, {array1, array2});
+
+ const RecordBatch& batch = *batch1;
+ std::shared_ptr<RecordBatch> result;
+
+ // Negative tests with invalid index
+ ASSERT_RAISES(Invalid, batch.RemoveColumn(3));
+ ASSERT_RAISES(Invalid, batch.RemoveColumn(-1));
+
+ ASSERT_OK_AND_ASSIGN(auto new_batch, batch.RemoveColumn(0));
+ AssertBatchesEqual(*new_batch, *batch2);
+
+ ASSERT_OK_AND_ASSIGN(new_batch, batch.RemoveColumn(1));
+ AssertBatchesEqual(*new_batch, *batch3);
+
+ ASSERT_OK_AND_ASSIGN(new_batch, batch.RemoveColumn(2));
+ AssertBatchesEqual(*new_batch, *batch4);
+}
+
+TEST_F(TestRecordBatch, SelectColumns) {
+ const int length = 10;
+
+ auto field1 = field("f1", int32());
+ auto field2 = field("f2", uint8());
+ auto field3 = field("f3", int16());
+
+ auto schema1 = ::arrow::schema({field1, field2, field3});
+
+ auto array1 = MakeRandomArray<Int32Array>(length);
+ auto array2 = MakeRandomArray<UInt8Array>(length);
+ auto array3 = MakeRandomArray<Int16Array>(length);
+
+ auto batch = RecordBatch::Make(schema1, length, {array1, array2, array3});
+
+ ASSERT_OK_AND_ASSIGN(auto subset, batch->SelectColumns({0, 2}));
+ ASSERT_OK(subset->ValidateFull());
+
+ auto expected_schema = ::arrow::schema({schema1->field(0), schema1->field(2)});
+ auto expected =
+ RecordBatch::Make(expected_schema, length, {batch->column(0), batch->column(2)});
+ ASSERT_TRUE(subset->Equals(*expected));
+
+ // Out of bounds indices
+ ASSERT_RAISES(Invalid, batch->SelectColumns({0, 3}));
+ ASSERT_RAISES(Invalid, batch->SelectColumns({-1}));
+}
+
+TEST_F(TestRecordBatch, RemoveColumnEmpty) {
+ const int length = 10;
+
+ auto field1 = field("f1", int32());
+ auto schema1 = ::arrow::schema({field1});
+ auto array1 = MakeRandomArray<Int32Array>(length);
+ auto batch1 = RecordBatch::Make(schema1, length, {array1});
+
+ ASSERT_OK_AND_ASSIGN(auto empty, batch1->RemoveColumn(0));
+ ASSERT_EQ(batch1->num_rows(), empty->num_rows());
+
+ ASSERT_OK_AND_ASSIGN(auto added, empty->AddColumn(0, field1, array1));
+ AssertBatchesEqual(*added, *batch1);
+}
+
+TEST_F(TestRecordBatch, ToFromEmptyStructArray) {
+ auto batch1 =
+ RecordBatch::Make(::arrow::schema({}), 10, std::vector<std::shared_ptr<Array>>{});
+ ASSERT_OK_AND_ASSIGN(auto struct_array, batch1->ToStructArray());
+ ASSERT_EQ(10, struct_array->length());
+ ASSERT_OK_AND_ASSIGN(auto batch2, RecordBatch::FromStructArray(struct_array));
+ ASSERT_TRUE(batch1->Equals(*batch2));
+}
+
+TEST_F(TestRecordBatch, FromStructArrayInvalidType) {
+ ASSERT_RAISES(TypeError, RecordBatch::FromStructArray(MakeRandomArray<Int32Array>(10)));
+}
+
+TEST_F(TestRecordBatch, FromStructArrayInvalidNullCount) {
+ auto struct_array =
+ ArrayFromJSON(struct_({field("f1", int32())}), R"([{"f1": 1}, null])");
+ ASSERT_RAISES(Invalid, RecordBatch::FromStructArray(struct_array));
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/result.cc b/src/arrow/cpp/src/arrow/result.cc
new file mode 100644
index 000000000..0bb65acb8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/result.cc
@@ -0,0 +1,36 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/result.h"
+
+#include <string>
+
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+namespace internal {
+
+void DieWithMessage(const std::string& msg) { ARROW_LOG(FATAL) << msg; }
+
+void InvalidValueOrDie(const Status& st) {
+ DieWithMessage(std::string("ValueOrDie called on an error: ") + st.ToString());
+}
+
+} // namespace internal
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/result.h b/src/arrow/cpp/src/arrow/result.h
new file mode 100644
index 000000000..7fdbeea4b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/result.h
@@ -0,0 +1,512 @@
+//
+// Copyright 2017 Asylo authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// Adapted from Asylo
+
+#pragma once
+
+#include <cstddef>
+#include <new>
+#include <string>
+#include <type_traits>
+#include <utility>
+
+#include "arrow/status.h"
+#include "arrow/util/aligned_storage.h"
+#include "arrow/util/compare.h"
+
+namespace arrow {
+
+template <typename>
+struct EnsureResult;
+
+namespace internal {
+
+ARROW_EXPORT void DieWithMessage(const std::string& msg);
+
+ARROW_EXPORT void InvalidValueOrDie(const Status& st);
+
+} // namespace internal
+
+/// A class for representing either a usable value, or an error.
+///
+/// A Result object either contains a value of type `T` or a Status object
+/// explaining why such a value is not present. The type `T` must be
+/// copy-constructible and/or move-constructible.
+///
+/// The state of a Result object may be determined by calling ok() or
+/// status(). The ok() method returns true if the object contains a valid value.
+/// The status() method returns the internal Status object. A Result object
+/// that contains a valid value will return an OK Status for a call to status().
+///
+/// A value of type `T` may be extracted from a Result object through a call
+/// to ValueOrDie(). This function should only be called if a call to ok()
+/// returns true. Sample usage:
+///
+/// ```
+/// arrow::Result<Foo> result = CalculateFoo();
+/// if (result.ok()) {
+/// Foo foo = result.ValueOrDie();
+/// foo.DoSomethingCool();
+/// } else {
+/// ARROW_LOG(ERROR) << result.status();
+/// }
+/// ```
+///
+/// If `T` is a move-only type, like `std::unique_ptr<>`, then the value should
+/// only be extracted after invoking `std::move()` on the Result object.
+/// Sample usage:
+///
+/// ```
+/// arrow::Result<std::unique_ptr<Foo>> result = CalculateFoo();
+/// if (result.ok()) {
+/// std::unique_ptr<Foo> foo = std::move(result).ValueOrDie();
+/// foo->DoSomethingCool();
+/// } else {
+/// ARROW_LOG(ERROR) << result.status();
+/// }
+/// ```
+///
+/// Result is provided for the convenience of implementing functions that
+/// return some value but may fail during execution. For instance, consider a
+/// function with the following signature:
+///
+/// ```
+/// arrow::Status CalculateFoo(int *output);
+/// ```
+///
+/// This function may instead be written as:
+///
+/// ```
+/// arrow::Result<int> CalculateFoo();
+/// ```
+template <class T>
+class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable<Result<T>> {
+ template <typename U>
+ friend class Result;
+
+ static_assert(!std::is_same<T, Status>::value,
+ "this assert indicates you have probably made a metaprogramming error");
+
+ public:
+ using ValueType = T;
+
+ /// Constructs a Result object that contains a non-OK status.
+ ///
+ /// This constructor is marked `explicit` to prevent attempts to `return {}`
+ /// from a function with a return type of, for example,
+ /// `Result<std::vector<int>>`. While `return {}` seems like it would return
+ /// an empty vector, it will actually invoke the default constructor of
+ /// Result.
+ explicit Result() noexcept // NOLINT(runtime/explicit)
+ : status_(Status::UnknownError("Uninitialized Result<T>")) {}
+
+ ~Result() noexcept { Destroy(); }
+
+ /// Constructs a Result object with the given non-OK Status object. All
+ /// calls to ValueOrDie() on this object will abort. The given `status` must
+ /// not be an OK status, otherwise this constructor will abort.
+ ///
+ /// This constructor is not declared explicit so that a function with a return
+ /// type of `Result<T>` can return a Status object, and the status will be
+ /// implicitly converted to the appropriate return type as a matter of
+ /// convenience.
+ ///
+ /// \param status The non-OK Status object to initialize to.
+ Result(const Status& status) noexcept // NOLINT(runtime/explicit)
+ : status_(status) {
+ if (ARROW_PREDICT_FALSE(status.ok())) {
+ internal::DieWithMessage(std::string("Constructed with a non-error status: ") +
+ status.ToString());
+ }
+ }
+
+ /// Constructs a Result object that contains `value`. The resulting object
+ /// is considered to have an OK status. The wrapped element can be accessed
+ /// with ValueOrDie().
+ ///
+ /// This constructor is made implicit so that a function with a return type of
+ /// `Result<T>` can return an object of type `U &&`, implicitly converting
+ /// it to a `Result<T>` object.
+ ///
+ /// Note that `T` must be implicitly constructible from `U`, and `U` must not
+ /// be a (cv-qualified) Status or Status-reference type. Due to C++
+ /// reference-collapsing rules and perfect-forwarding semantics, this
+ /// constructor matches invocations that pass `value` either as a const
+ /// reference or as an rvalue reference. Since Result needs to work for both
+ /// reference and rvalue-reference types, the constructor uses perfect
+ /// forwarding to avoid invalidating arguments that were passed by reference.
+ /// See http://thbecker.net/articles/rvalue_references/section_08.html for
+ /// additional details.
+ ///
+ /// \param value The value to initialize to.
+ template <typename U,
+ typename E = typename std::enable_if<
+ std::is_constructible<T, U>::value && std::is_convertible<U, T>::value &&
+ !std::is_same<typename std::remove_reference<
+ typename std::remove_cv<U>::type>::type,
+ Status>::value>::type>
+ Result(U&& value) noexcept { // NOLINT(runtime/explicit)
+ ConstructValue(std::forward<U>(value));
+ }
+
+ /// Constructs a Result object that contains `value`. The resulting object
+ /// is considered to have an OK status. The wrapped element can be accessed
+ /// with ValueOrDie().
+ ///
+ /// This constructor is made implicit so that a function with a return type of
+ /// `Result<T>` can return an object of type `T`, implicitly converting
+ /// it to a `Result<T>` object.
+ ///
+ /// \param value The value to initialize to.
+ // NOTE `Result(U&& value)` above should be sufficient, but some compilers
+ // fail matching it.
+ Result(T&& value) noexcept { // NOLINT(runtime/explicit)
+ ConstructValue(std::move(value));
+ }
+
+ /// Copy constructor.
+ ///
+ /// This constructor needs to be explicitly defined because the presence of
+ /// the move-assignment operator deletes the default copy constructor. In such
+ /// a scenario, since the deleted copy constructor has stricter binding rules
+ /// than the templated copy constructor, the templated constructor cannot act
+ /// as a copy constructor, and any attempt to copy-construct a `Result`
+ /// object results in a compilation error.
+ ///
+ /// \param other The value to copy from.
+ Result(const Result& other) noexcept : status_(other.status_) {
+ if (ARROW_PREDICT_TRUE(status_.ok())) {
+ ConstructValue(other.ValueUnsafe());
+ }
+ }
+
+ /// Templatized constructor that constructs a `Result<T>` from a const
+ /// reference to a `Result<U>`.
+ ///
+ /// `T` must be implicitly constructible from `const U &`.
+ ///
+ /// \param other The value to copy from.
+ template <typename U, typename E = typename std::enable_if<
+ std::is_constructible<T, const U&>::value &&
+ std::is_convertible<U, T>::value>::type>
+ Result(const Result<U>& other) noexcept : status_(other.status_) {
+ if (ARROW_PREDICT_TRUE(status_.ok())) {
+ ConstructValue(other.ValueUnsafe());
+ }
+ }
+
+ /// Copy-assignment operator.
+ ///
+ /// \param other The Result object to copy.
+ Result& operator=(const Result& other) noexcept {
+ // Check for self-assignment.
+ if (ARROW_PREDICT_FALSE(this == &other)) {
+ return *this;
+ }
+ Destroy();
+ status_ = other.status_;
+ if (ARROW_PREDICT_TRUE(status_.ok())) {
+ ConstructValue(other.ValueUnsafe());
+ }
+ return *this;
+ }
+
+ /// Templatized constructor which constructs a `Result<T>` by moving the
+ /// contents of a `Result<U>`. `T` must be implicitly constructible from `U
+ /// &&`.
+ ///
+ /// Sets `other` to contain a non-OK status with a`StatusError::Invalid`
+ /// error code.
+ ///
+ /// \param other The Result object to move from and set to a non-OK status.
+ template <typename U,
+ typename E = typename std::enable_if<std::is_constructible<T, U&&>::value &&
+ std::is_convertible<U, T>::value>::type>
+ Result(Result<U>&& other) noexcept {
+ if (ARROW_PREDICT_TRUE(other.status_.ok())) {
+ status_ = std::move(other.status_);
+ ConstructValue(other.MoveValueUnsafe());
+ } else {
+ // If we moved the status, the other status may become ok but the other
+ // value hasn't been constructed => crash on other destructor.
+ status_ = other.status_;
+ }
+ }
+
+ /// Move-assignment operator.
+ ///
+ /// Sets `other` to an invalid state..
+ ///
+ /// \param other The Result object to assign from and set to a non-OK
+ /// status.
+ Result& operator=(Result&& other) noexcept {
+ // Check for self-assignment.
+ if (ARROW_PREDICT_FALSE(this == &other)) {
+ return *this;
+ }
+ Destroy();
+ if (ARROW_PREDICT_TRUE(other.status_.ok())) {
+ status_ = std::move(other.status_);
+ ConstructValue(other.MoveValueUnsafe());
+ } else {
+ // If we moved the status, the other status may become ok but the other
+ // value hasn't been constructed => crash on other destructor.
+ status_ = other.status_;
+ }
+ return *this;
+ }
+
+ /// Compare to another Result.
+ bool Equals(const Result& other) const {
+ if (ARROW_PREDICT_TRUE(status_.ok())) {
+ return other.status_.ok() && ValueUnsafe() == other.ValueUnsafe();
+ }
+ return status_ == other.status_;
+ }
+
+ /// Indicates whether the object contains a `T` value. Generally instead
+ /// of accessing this directly you will want to use ASSIGN_OR_RAISE defined
+ /// below.
+ ///
+ /// \return True if this Result object's status is OK (i.e. a call to ok()
+ /// returns true). If this function returns true, then it is safe to access
+ /// the wrapped element through a call to ValueOrDie().
+ constexpr bool ok() const { return status_.ok(); }
+
+ /// \brief Equivalent to ok().
+ // operator bool() const { return ok(); }
+
+ /// Gets the stored status object, or an OK status if a `T` value is stored.
+ ///
+ /// \return The stored non-OK status object, or an OK status if this object
+ /// has a value.
+ constexpr const Status& status() const { return status_; }
+
+ /// Gets the stored `T` value.
+ ///
+ /// This method should only be called if this Result object's status is OK
+ /// (i.e. a call to ok() returns true), otherwise this call will abort.
+ ///
+ /// \return The stored `T` value.
+ const T& ValueOrDie() const& {
+ if (ARROW_PREDICT_FALSE(!ok())) {
+ internal::InvalidValueOrDie(status_);
+ }
+ return ValueUnsafe();
+ }
+ const T& operator*() const& { return ValueOrDie(); }
+ const T* operator->() const { return &ValueOrDie(); }
+
+ /// Gets a mutable reference to the stored `T` value.
+ ///
+ /// This method should only be called if this Result object's status is OK
+ /// (i.e. a call to ok() returns true), otherwise this call will abort.
+ ///
+ /// \return The stored `T` value.
+ T& ValueOrDie() & {
+ if (ARROW_PREDICT_FALSE(!ok())) {
+ internal::InvalidValueOrDie(status_);
+ }
+ return ValueUnsafe();
+ }
+ T& operator*() & { return ValueOrDie(); }
+ T* operator->() { return &ValueOrDie(); }
+
+ /// Moves and returns the internally-stored `T` value.
+ ///
+ /// This method should only be called if this Result object's status is OK
+ /// (i.e. a call to ok() returns true), otherwise this call will abort. The
+ /// Result object is invalidated after this call and will be updated to
+ /// contain a non-OK status.
+ ///
+ /// \return The stored `T` value.
+ T ValueOrDie() && {
+ if (ARROW_PREDICT_FALSE(!ok())) {
+ internal::InvalidValueOrDie(status_);
+ }
+ return MoveValueUnsafe();
+ }
+ T operator*() && { return std::move(*this).ValueOrDie(); }
+
+ /// Helper method for implementing Status returning functions in terms of semantically
+ /// equivalent Result returning functions. For example:
+ ///
+ /// Status GetInt(int *out) { return GetInt().Value(out); }
+ template <typename U, typename E = typename std::enable_if<
+ std::is_constructible<U, T>::value>::type>
+ Status Value(U* out) && {
+ if (!ok()) {
+ return status();
+ }
+ *out = U(MoveValueUnsafe());
+ return Status::OK();
+ }
+
+ /// Move and return the internally stored value or alternative if an error is stored.
+ T ValueOr(T alternative) && {
+ if (!ok()) {
+ return alternative;
+ }
+ return MoveValueUnsafe();
+ }
+
+ /// Retrieve the value if ok(), falling back to an alternative generated by the provided
+ /// factory
+ template <typename G>
+ T ValueOrElse(G&& generate_alternative) && {
+ if (ok()) {
+ return MoveValueUnsafe();
+ }
+ return generate_alternative();
+ }
+
+ /// Apply a function to the internally stored value to produce a new result or propagate
+ /// the stored error.
+ template <typename M>
+ typename EnsureResult<decltype(std::declval<M&&>()(std::declval<T&&>()))>::type Map(
+ M&& m) && {
+ if (!ok()) {
+ return status();
+ }
+ return std::forward<M>(m)(MoveValueUnsafe());
+ }
+
+ /// Apply a function to the internally stored value to produce a new result or propagate
+ /// the stored error.
+ template <typename M>
+ typename EnsureResult<decltype(std::declval<M&&>()(std::declval<const T&>()))>::type
+ Map(M&& m) const& {
+ if (!ok()) {
+ return status();
+ }
+ return std::forward<M>(m)(ValueUnsafe());
+ }
+
+ /// Cast the internally stored value to produce a new result or propagate the stored
+ /// error.
+ template <typename U, typename E = typename std::enable_if<
+ std::is_constructible<U, T>::value>::type>
+ Result<U> As() && {
+ if (!ok()) {
+ return status();
+ }
+ return U(MoveValueUnsafe());
+ }
+
+ /// Cast the internally stored value to produce a new result or propagate the stored
+ /// error.
+ template <typename U, typename E = typename std::enable_if<
+ std::is_constructible<U, const T&>::value>::type>
+ Result<U> As() const& {
+ if (!ok()) {
+ return status();
+ }
+ return U(ValueUnsafe());
+ }
+
+ constexpr const T& ValueUnsafe() const& { return *storage_.get(); }
+
+#if __cpp_constexpr >= 201304L // non-const constexpr
+ constexpr T& ValueUnsafe() & { return *storage_.get(); }
+#else
+ T& ValueUnsafe() & { return *storage_.get(); }
+#endif
+
+ T ValueUnsafe() && { return MoveValueUnsafe(); }
+
+ T MoveValueUnsafe() { return std::move(*storage_.get()); }
+
+ private:
+ Status status_; // pointer-sized
+ internal::AlignedStorage<T> storage_;
+
+ template <typename U>
+ void ConstructValue(U&& u) noexcept {
+ storage_.construct(std::forward<U>(u));
+ }
+
+ void Destroy() noexcept {
+ if (ARROW_PREDICT_TRUE(status_.ok())) {
+ static_assert(offsetof(Result<T>, status_) == 0,
+ "Status is guaranteed to be at the start of Result<>");
+ storage_.destroy();
+ }
+ }
+};
+
+#define ARROW_ASSIGN_OR_RAISE_IMPL(result_name, lhs, rexpr) \
+ auto&& result_name = (rexpr); \
+ ARROW_RETURN_IF_(!(result_name).ok(), (result_name).status(), ARROW_STRINGIFY(rexpr)); \
+ lhs = std::move(result_name).ValueUnsafe();
+
+#define ARROW_ASSIGN_OR_RAISE_NAME(x, y) ARROW_CONCAT(x, y)
+
+/// \brief Execute an expression that returns a Result, extracting its value
+/// into the variable defined by `lhs` (or returning a Status on error).
+///
+/// Example: Assigning to a new value:
+/// ARROW_ASSIGN_OR_RAISE(auto value, MaybeGetValue(arg));
+///
+/// Example: Assigning to an existing value:
+/// ValueType value;
+/// ARROW_ASSIGN_OR_RAISE(value, MaybeGetValue(arg));
+///
+/// WARNING: ARROW_ASSIGN_OR_RAISE expands into multiple statements;
+/// it cannot be used in a single statement (e.g. as the body of an if
+/// statement without {})!
+///
+/// WARNING: ARROW_ASSIGN_OR_RAISE `std::move`s its right operand. If you have
+/// an lvalue Result which you *don't* want to move out of cast appropriately.
+///
+/// WARNING: ARROW_ASSIGN_OR_RAISE is not a single expression; it will not
+/// maintain lifetimes of all temporaries in `rexpr` (e.g.
+/// `ARROW_ASSIGN_OR_RAISE(auto x, MakeTemp().GetResultRef());`
+/// will most likely segfault)!
+#define ARROW_ASSIGN_OR_RAISE(lhs, rexpr) \
+ ARROW_ASSIGN_OR_RAISE_IMPL(ARROW_ASSIGN_OR_RAISE_NAME(_error_or_value, __COUNTER__), \
+ lhs, rexpr);
+
+namespace internal {
+
+template <typename T>
+inline const Status& GenericToStatus(const Result<T>& res) {
+ return res.status();
+}
+
+template <typename T>
+inline Status GenericToStatus(Result<T>&& res) {
+ return std::move(res).status();
+}
+
+} // namespace internal
+
+template <typename T, typename R = typename EnsureResult<T>::type>
+R ToResult(T t) {
+ return R(std::move(t));
+}
+
+template <typename T>
+struct EnsureResult {
+ using type = Result<T>;
+};
+
+template <typename T>
+struct EnsureResult<Result<T>> {
+ using type = Result<T>;
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/result_internal.h b/src/arrow/cpp/src/arrow/result_internal.h
new file mode 100644
index 000000000..7550f945d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/result_internal.h
@@ -0,0 +1,22 @@
+//
+// Copyright 2017 Asylo authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+#pragma once
+
+#include "arrow/result.h"
+
+#ifndef ASSIGN_OR_RAISE
+#define ASSIGN_OR_RAISE(lhs, rhs) ARROW_ASSIGN_OR_RAISE(lhs, rhs)
+#endif
diff --git a/src/arrow/cpp/src/arrow/result_test.cc b/src/arrow/cpp/src/arrow/result_test.cc
new file mode 100644
index 000000000..cb645bc74
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/result_test.cc
@@ -0,0 +1,799 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/result.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_compat.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+
+namespace arrow {
+
+namespace {
+
+using ::testing::Eq;
+
+StatusCode kErrorCode = StatusCode::Invalid;
+constexpr const char* kErrorMessage = "Invalid argument";
+
+const int kIntElement = 42;
+constexpr const char* kStringElement =
+ "The Answer to the Ultimate Question of Life, the Universe, and Everything";
+
+// A data type without a default constructor.
+struct Foo {
+ int bar;
+ std::string baz;
+
+ explicit Foo(int value) : bar(value), baz(kStringElement) {}
+
+ bool operator==(const Foo& other) const {
+ return (bar == other.bar) && (baz == other.baz);
+ }
+};
+
+// A data type with only copy constructors.
+struct CopyOnlyDataType {
+ explicit CopyOnlyDataType(int x) : data(x) {}
+
+ CopyOnlyDataType(const CopyOnlyDataType& other) = default;
+ CopyOnlyDataType& operator=(const CopyOnlyDataType& other) = default;
+
+ int data;
+};
+
+struct ImplicitlyCopyConvertible {
+ ImplicitlyCopyConvertible(const CopyOnlyDataType& co) // NOLINT runtime/explicit
+ : copy_only(co) {}
+
+ CopyOnlyDataType copy_only;
+};
+
+// A data type with only move constructors.
+struct MoveOnlyDataType {
+ explicit MoveOnlyDataType(int x) : data(new int(x)) {}
+
+ MoveOnlyDataType(const MoveOnlyDataType& other) = delete;
+ MoveOnlyDataType& operator=(const MoveOnlyDataType& other) = delete;
+
+ MoveOnlyDataType(MoveOnlyDataType&& other) { MoveFrom(&other); }
+ MoveOnlyDataType& operator=(MoveOnlyDataType&& other) {
+ MoveFrom(&other);
+ return *this;
+ }
+
+ ~MoveOnlyDataType() { Destroy(); }
+
+ void Destroy() {
+ if (data != nullptr) {
+ delete data;
+ data = nullptr;
+ }
+ }
+
+ void MoveFrom(MoveOnlyDataType* other) {
+ Destroy();
+ data = other->data;
+ other->data = nullptr;
+ }
+
+ int* data = nullptr;
+};
+
+struct ImplicitlyMoveConvertible {
+ ImplicitlyMoveConvertible(MoveOnlyDataType&& mo) // NOLINT runtime/explicit
+ : move_only(std::move(mo)) {}
+
+ MoveOnlyDataType move_only;
+};
+
+// A data type with dynamically-allocated data.
+struct HeapAllocatedObject {
+ int* value;
+
+ HeapAllocatedObject() {
+ value = new int;
+ *value = kIntElement;
+ }
+
+ HeapAllocatedObject(const HeapAllocatedObject& other) {
+ value = new int;
+ *value = *other.value;
+ }
+
+ HeapAllocatedObject& operator=(const HeapAllocatedObject& other) {
+ *value = *other.value;
+ return *this;
+ }
+
+ HeapAllocatedObject(HeapAllocatedObject&& other) {
+ value = other.value;
+ other.value = nullptr;
+ }
+
+ ~HeapAllocatedObject() { delete value; }
+
+ bool operator==(const HeapAllocatedObject& other) const {
+ return *value == *other.value;
+ }
+};
+
+// Constructs a Foo.
+struct FooCtor {
+ using value_type = Foo;
+
+ Foo operator()() { return Foo(kIntElement); }
+};
+
+// Constructs a HeapAllocatedObject.
+struct HeapAllocatedObjectCtor {
+ using value_type = HeapAllocatedObject;
+
+ HeapAllocatedObject operator()() { return HeapAllocatedObject(); }
+};
+
+// Constructs an integer.
+struct IntCtor {
+ using value_type = int;
+
+ int operator()() { return kIntElement; }
+};
+
+// Constructs a string.
+struct StringCtor {
+ using value_type = std::string;
+
+ std::string operator()() { return std::string(kStringElement); }
+};
+
+// Constructs a vector of strings.
+struct StringVectorCtor {
+ using value_type = std::vector<std::string>;
+
+ std::vector<std::string> operator()() { return {kStringElement, kErrorMessage}; }
+};
+
+// Returns an rvalue reference to the Result<T> object pointed to by
+// |result|.
+template <class T>
+Result<T>&& MoveResult(Result<T>* result) {
+ return std::move(*result);
+}
+
+// A test fixture is required for typed tests.
+template <typename T>
+class ResultTest : public ::testing::Test {};
+
+using TestTypes = ::testing::Types<IntCtor, FooCtor, StringCtor, StringVectorCtor,
+ HeapAllocatedObjectCtor>;
+
+TYPED_TEST_SUITE(ResultTest, TestTypes);
+
+// Verify that the default constructor for Result constructs an object with a
+// non-ok status.
+TYPED_TEST(ResultTest, ConstructorDefault) {
+ Result<typename TypeParam::value_type> result;
+ EXPECT_FALSE(result.ok());
+ EXPECT_EQ(result.status().code(), StatusCode::UnknownError);
+}
+
+// Verify that Result can be constructed from a Status object.
+TYPED_TEST(ResultTest, ConstructorStatus) {
+ Result<typename TypeParam::value_type> result(Status(kErrorCode, kErrorMessage));
+
+ EXPECT_FALSE(result.ok());
+ EXPECT_FALSE(result.status().ok());
+ EXPECT_EQ(result.status().code(), kErrorCode);
+ EXPECT_EQ(result.status().message(), kErrorMessage);
+}
+
+// Verify that Result can be constructed from an object of its element type.
+TYPED_TEST(ResultTest, ConstructorElementConstReference) {
+ typename TypeParam::value_type value = TypeParam()();
+ Result<typename TypeParam::value_type> result(value);
+
+ ASSERT_TRUE(result.ok());
+ ASSERT_TRUE(result.status().ok());
+ EXPECT_EQ(result.ValueOrDie(), value);
+ EXPECT_EQ(*result, value);
+}
+
+// Verify that Result can be constructed from an rvalue reference of an object
+// of its element type.
+TYPED_TEST(ResultTest, ConstructorElementRValue) {
+ typename TypeParam::value_type value = TypeParam()();
+ typename TypeParam::value_type value_copy(value);
+ Result<typename TypeParam::value_type> result(std::move(value));
+
+ ASSERT_TRUE(result.ok());
+ ASSERT_TRUE(result.status().ok());
+
+ // Compare to a copy of the original value, since the original was moved.
+ EXPECT_EQ(result.ValueOrDie(), value_copy);
+}
+
+// Verify that Result can be copy-constructed from a Result with a non-ok
+// status.
+TYPED_TEST(ResultTest, CopyConstructorNonOkStatus) {
+ Result<typename TypeParam::value_type> result1 = Status(kErrorCode, kErrorMessage);
+ Result<typename TypeParam::value_type> result2(result1);
+
+ EXPECT_EQ(result1.ok(), result2.ok());
+ EXPECT_EQ(result1.status().message(), result2.status().message());
+}
+
+// Verify that Result can be copy-constructed from a Result with an ok
+// status.
+TYPED_TEST(ResultTest, CopyConstructorOkStatus) {
+ typename TypeParam::value_type value = TypeParam()();
+ Result<typename TypeParam::value_type> result1(value);
+ Result<typename TypeParam::value_type> result2(result1);
+
+ EXPECT_EQ(result1.ok(), result2.ok());
+ ASSERT_TRUE(result2.ok());
+ EXPECT_EQ(result1.ValueOrDie(), result2.ValueOrDie());
+ EXPECT_EQ(*result1, *result2);
+}
+
+// Verify that copy-assignment of a Result with a non-ok is working as
+// expected.
+TYPED_TEST(ResultTest, CopyAssignmentNonOkStatus) {
+ Result<typename TypeParam::value_type> result1(Status(kErrorCode, kErrorMessage));
+ typename TypeParam::value_type value = TypeParam()();
+ Result<typename TypeParam::value_type> result2(value);
+
+ // Invoke the copy-assignment operator.
+ result2 = result1;
+ EXPECT_EQ(result1.ok(), result2.ok());
+ EXPECT_EQ(result1.status().message(), result2.status().message());
+}
+
+// Verify that copy-assignment of a Result with an ok status is working as
+// expected.
+TYPED_TEST(ResultTest, CopyAssignmentOkStatus) {
+ typename TypeParam::value_type value = TypeParam()();
+ Result<typename TypeParam::value_type> result1(value);
+ Result<typename TypeParam::value_type> result2(Status(kErrorCode, kErrorMessage));
+
+ // Invoke the copy-assignment operator.
+ result2 = result1;
+ EXPECT_EQ(result1.ok(), result2.ok());
+ ASSERT_TRUE(result2.ok());
+ EXPECT_EQ(result1.ValueOrDie(), result2.ValueOrDie());
+ EXPECT_EQ(*result1, *result2);
+}
+
+// Verify that copy-assignment of a Result with a non-ok status to itself is
+// properly handled.
+TYPED_TEST(ResultTest, CopyAssignmentSelfNonOkStatus) {
+ Status status(kErrorCode, kErrorMessage);
+ Result<typename TypeParam::value_type> result(status);
+ result = *&result;
+
+ EXPECT_FALSE(result.ok());
+ EXPECT_EQ(result.status().code(), status.code());
+}
+
+// Verify that copy-assignment of a Result with an ok status to itself is
+// properly handled.
+TYPED_TEST(ResultTest, CopyAssignmentSelfOkStatus) {
+ typename TypeParam::value_type value = TypeParam()();
+ Result<typename TypeParam::value_type> result(value);
+ result = *&result;
+
+ ASSERT_TRUE(result.ok());
+ EXPECT_EQ(result.ValueOrDie(), value);
+ EXPECT_EQ(*result, value);
+}
+
+// Verify that Result can be move-constructed from a Result with a non-ok
+// status.
+TYPED_TEST(ResultTest, MoveConstructorNonOkStatus) {
+ Status status(kErrorCode, kErrorMessage);
+ Result<typename TypeParam::value_type> result1(status);
+ Result<typename TypeParam::value_type> result2(std::move(result1));
+
+ // Verify that the destination object contains the status previously held by
+ // the donor.
+ EXPECT_FALSE(result2.ok());
+ EXPECT_EQ(result2.status().code(), status.code());
+}
+
+// Verify that Result can be move-constructed from a Result with an ok
+// status.
+TYPED_TEST(ResultTest, MoveConstructorOkStatus) {
+ typename TypeParam::value_type value = TypeParam()();
+ Result<typename TypeParam::value_type> result1(value);
+ Result<typename TypeParam::value_type> result2(std::move(result1));
+
+ // The destination object should possess the value previously held by the
+ // donor.
+ ASSERT_TRUE(result2.ok());
+ EXPECT_EQ(result2.ValueOrDie(), value);
+}
+
+// Verify that move-assignment from a Result with a non-ok status is working
+// as expected.
+TYPED_TEST(ResultTest, MoveAssignmentOperatorNonOkStatus) {
+ Status status(kErrorCode, kErrorMessage);
+ Result<typename TypeParam::value_type> result1(status);
+ typename TypeParam::value_type value = TypeParam()();
+ Result<typename TypeParam::value_type> result2(value);
+
+ // Invoke the move-assignment operator.
+ result2 = std::move(result1);
+
+ // Verify that the destination object contains the status previously held by
+ // the donor.
+ EXPECT_FALSE(result2.ok());
+ EXPECT_EQ(result2.status().code(), status.code());
+}
+
+// Verify that move-assignment from a Result with an ok status is working as
+// expected.
+TYPED_TEST(ResultTest, MoveAssignmentOperatorOkStatus) {
+ typename TypeParam::value_type value = TypeParam()();
+ Result<typename TypeParam::value_type> result1(value);
+ Result<typename TypeParam::value_type> result2(Status(kErrorCode, kErrorMessage));
+
+ // Invoke the move-assignment operator.
+ result2 = std::move(result1);
+
+ // The destination object should possess the value previously held by the
+ // donor.
+ ASSERT_TRUE(result2.ok());
+ EXPECT_EQ(result2.ValueOrDie(), value);
+}
+
+// Verify that move-assignment of a Result with a non-ok status to itself is
+// handled properly.
+TYPED_TEST(ResultTest, MoveAssignmentSelfNonOkStatus) {
+ Status status(kErrorCode, kErrorMessage);
+ Result<typename TypeParam::value_type> result(status);
+
+ result = MoveResult(&result);
+
+ EXPECT_FALSE(result.ok());
+ EXPECT_EQ(result.status().code(), status.code());
+}
+
+// Verify that move-assignment of a Result with an ok-status to itself is
+// handled properly.
+TYPED_TEST(ResultTest, MoveAssignmentSelfOkStatus) {
+ typename TypeParam::value_type value = TypeParam()();
+ Result<typename TypeParam::value_type> result(value);
+
+ result = MoveResult(&result);
+
+ ASSERT_TRUE(result.ok());
+ EXPECT_EQ(result.ValueOrDie(), value);
+}
+
+// Tests for move-only types. These tests use std::unique_ptr<> as the
+// test type, since it is valuable to support this type in the Asylo infra.
+// These tests are not part of the typed test suite for the following reasons:
+// * std::unique_ptr<> cannot be used as a type in tests that expect
+// the test type to support copy operations.
+// * std::unique_ptr<> provides an equality operator that checks equality of
+// the underlying ptr. Consequently, it is difficult to generalize existing
+// tests that verify ValueOrDie() functionality using equality comparisons.
+
+// Verify that a Result object can be constructed from a move-only type.
+TEST(ResultTest, InitializationMoveOnlyType) {
+ std::unique_ptr<std::string> value(new std::string(kStringElement));
+ auto str = value.get();
+ Result<std::unique_ptr<std::string>> result(std::move(value));
+
+ ASSERT_TRUE(result.ok());
+ EXPECT_EQ(result.ValueOrDie().get(), str);
+}
+
+// Verify that a Result object can be move-constructed from a move-only type.
+TEST(ResultTest, MoveConstructorMoveOnlyType) {
+ std::unique_ptr<std::string> value(new std::string(kStringElement));
+ auto str = value.get();
+ Result<std::unique_ptr<std::string>> result1(std::move(value));
+ Result<std::unique_ptr<std::string>> result2(std::move(result1));
+
+ // The destination object should possess the value previously held by the
+ // donor.
+ ASSERT_TRUE(result2.ok());
+ EXPECT_EQ(result2.ValueOrDie().get(), str);
+}
+
+// Verify that a Result object can be move-assigned to from a Result object
+// containing a move-only type.
+TEST(ResultTest, MoveAssignmentMoveOnlyType) {
+ std::unique_ptr<std::string> value(new std::string(kStringElement));
+ auto str = value.get();
+ Result<std::unique_ptr<std::string>> result1(std::move(value));
+ Result<std::unique_ptr<std::string>> result2(Status(kErrorCode, kErrorMessage));
+
+ // Invoke the move-assignment operator.
+ result2 = std::move(result1);
+
+ // The destination object should possess the value previously held by the
+ // donor.
+ ASSERT_TRUE(result2.ok());
+ EXPECT_EQ(result2.ValueOrDie().get(), str);
+}
+
+// Verify that a value can be moved out of a Result object via ValueOrDie().
+TEST(ResultTest, ValueOrDieMovedValue) {
+ std::unique_ptr<std::string> value(new std::string(kStringElement));
+ auto str = value.get();
+ Result<std::unique_ptr<std::string>> result(std::move(value));
+
+ std::unique_ptr<std::string> moved_value = std::move(result).ValueOrDie();
+ EXPECT_EQ(moved_value.get(), str);
+ EXPECT_EQ(*moved_value, kStringElement);
+}
+
+// Verify that a Result<T> is implicitly constructible from some U, where T is
+// a type which has an implicit constructor taking a const U &.
+TEST(ResultTest, TemplateValueCopyConstruction) {
+ CopyOnlyDataType copy_only(kIntElement);
+ Result<ImplicitlyCopyConvertible> result(copy_only);
+
+ EXPECT_TRUE(result.ok());
+ EXPECT_EQ(result.ValueOrDie().copy_only.data, kIntElement);
+}
+
+// Verify that a Result<T> is implicitly constructible from some U, where T is
+// a type which has an implicit constructor taking a U &&.
+TEST(ResultTest, TemplateValueMoveConstruction) {
+ MoveOnlyDataType move_only(kIntElement);
+ Result<ImplicitlyMoveConvertible> result(std::move(move_only));
+
+ EXPECT_TRUE(result.ok());
+ EXPECT_EQ(*result.ValueOrDie().move_only.data, kIntElement);
+}
+
+// Verify that an error rvalue Result<T> allows access if an alternative is provided
+TEST(ResultTest, ErrorRvalueValueOrAlternative) {
+ Result<MoveOnlyDataType> result = Status::Invalid("");
+
+ EXPECT_FALSE(result.ok());
+ EXPECT_EQ(*std::move(result).ValueOr(MoveOnlyDataType{kIntElement}).data, kIntElement);
+}
+
+// Verify that an ok rvalue Result<T> will ignore a provided alternative
+TEST(ResultTest, OkRvalueValueOrAlternative) {
+ Result<MoveOnlyDataType> result = MoveOnlyDataType{kIntElement};
+
+ EXPECT_TRUE(result.ok());
+ EXPECT_EQ(*std::move(result).ValueOr(MoveOnlyDataType{kIntElement - 1}).data,
+ kIntElement);
+}
+
+// Verify that an error rvalue Result<T> allows access if an alternative factory is
+// provided
+TEST(ResultTest, ErrorRvalueValueOrGeneratedAlternative) {
+ Result<MoveOnlyDataType> result = Status::Invalid("");
+
+ EXPECT_FALSE(result.ok());
+ auto out = std::move(result).ValueOrElse([] { return MoveOnlyDataType{kIntElement}; });
+ EXPECT_EQ(*out.data, kIntElement);
+}
+
+// Verify that an ok rvalue Result<T> allows access if an alternative factory is provided
+TEST(ResultTest, OkRvalueValueOrGeneratedAlternative) {
+ Result<MoveOnlyDataType> result = MoveOnlyDataType{kIntElement};
+
+ EXPECT_TRUE(result.ok());
+ auto out =
+ std::move(result).ValueOrElse([] { return MoveOnlyDataType{kIntElement - 1}; });
+ EXPECT_EQ(*out.data, kIntElement);
+}
+
+// Verify that a Result<T> can be unpacked to T
+TEST(ResultTest, StatusReturnAdapterCopyValue) {
+ Result<CopyOnlyDataType> result(CopyOnlyDataType{kIntElement});
+ CopyOnlyDataType copy_only{0};
+
+ EXPECT_TRUE(std::move(result).Value(&copy_only).ok());
+ EXPECT_EQ(copy_only.data, kIntElement);
+}
+
+// Verify that a Result<T> can be unpacked to some U, where U is
+// a type which has a constructor taking a const T &.
+TEST(ResultTest, StatusReturnAdapterCopyAndConvertValue) {
+ Result<CopyOnlyDataType> result(CopyOnlyDataType{kIntElement});
+ ImplicitlyCopyConvertible implicitly_convertible(CopyOnlyDataType{0});
+
+ EXPECT_TRUE(std::move(result).Value(&implicitly_convertible).ok());
+ EXPECT_EQ(implicitly_convertible.copy_only.data, kIntElement);
+}
+
+// Verify that a Result<T> can be unpacked to T
+TEST(ResultTest, StatusReturnAdapterMoveValue) {
+ {
+ Result<MoveOnlyDataType> result(MoveOnlyDataType{kIntElement});
+ MoveOnlyDataType move_only{0};
+
+ EXPECT_TRUE(std::move(result).Value(&move_only).ok());
+ EXPECT_EQ(*move_only.data, kIntElement);
+ }
+ {
+ Result<MoveOnlyDataType> result(MoveOnlyDataType{kIntElement});
+ auto move_only = std::move(result).ValueOrDie();
+ EXPECT_EQ(*move_only.data, kIntElement);
+ }
+ {
+ Result<MoveOnlyDataType> result(MoveOnlyDataType{kIntElement});
+ auto move_only = *std::move(result);
+ EXPECT_EQ(*move_only.data, kIntElement);
+ }
+}
+
+// Verify that a Result<T> can be unpacked to some U, where U is
+// a type which has a constructor taking a T &&.
+TEST(ResultTest, StatusReturnAdapterMoveAndConvertValue) {
+ Result<MoveOnlyDataType> result(MoveOnlyDataType{kIntElement});
+ ImplicitlyMoveConvertible implicitly_convertible(MoveOnlyDataType{0});
+
+ EXPECT_TRUE(std::move(result).Value(&implicitly_convertible).ok());
+ EXPECT_EQ(*implicitly_convertible.move_only.data, kIntElement);
+}
+
+// Verify that a Result<T> can be queried for a stored value or an alternative.
+TEST(ResultTest, ValueOrAlternative) {
+ EXPECT_EQ(Result<MoveOnlyDataType>(MoveOnlyDataType{kIntElement})
+ .ValueOr(MoveOnlyDataType{0})
+ .data[0],
+ kIntElement);
+
+ EXPECT_EQ(
+ Result<MoveOnlyDataType>(Status::Invalid("")).ValueOr(MoveOnlyDataType{0}).data[0],
+ 0);
+}
+
+TEST(ResultTest, MapFunctionToConstValue) {
+ static auto error = Status::Invalid("some error message");
+
+ const Result<MoveOnlyDataType> result(MoveOnlyDataType{kIntElement});
+
+ auto const_mapped =
+ result.Map([](const MoveOnlyDataType& m) -> Result<int> { return *m.data; });
+ EXPECT_TRUE(const_mapped.ok());
+ EXPECT_EQ(const_mapped.ValueOrDie(), kIntElement);
+
+ auto const_error =
+ result.Map([](const MoveOnlyDataType& m) -> Result<int> { return error; });
+ EXPECT_FALSE(const_error.ok());
+ EXPECT_EQ(const_error.status(), error);
+}
+
+TEST(ResultTest, MapFunctionToRrefValue) {
+ static auto error = Status::Invalid("some error message");
+
+ auto result = [] { return Result<MoveOnlyDataType>(MoveOnlyDataType{kIntElement}); };
+
+ auto move_mapped =
+ result().Map([](MoveOnlyDataType m) -> Result<int> { return std::move(*m.data); });
+ EXPECT_TRUE(move_mapped.ok());
+ EXPECT_EQ(move_mapped.ValueOrDie(), kIntElement);
+
+ auto move_error = result().Map([](MoveOnlyDataType m) -> Result<int> { return error; });
+ EXPECT_FALSE(move_error.ok());
+ EXPECT_EQ(move_error.status(), error);
+}
+
+TEST(ResultTest, MapFunctionToConstError) {
+ static auto error = Status::Invalid("some error message");
+ static auto other_error = Status::Invalid("some other error message");
+
+ const Result<MoveOnlyDataType> result(error);
+
+ auto const_mapped =
+ result.Map([](const MoveOnlyDataType& m) -> Result<int> { return *m.data; });
+ EXPECT_FALSE(const_mapped.ok());
+ EXPECT_EQ(const_mapped.status(), error); // error is *not* replaced by a value
+
+ auto const_error =
+ result.Map([](const MoveOnlyDataType& m) -> Result<int> { return other_error; });
+ EXPECT_FALSE(const_error.ok());
+ EXPECT_EQ(const_error.status(), error); // error is *not* replaced by other_error
+}
+
+TEST(ResultTest, MapFunctionToRrefError) {
+ static auto error = Status::Invalid("some error message");
+ static auto other_error = Status::Invalid("some other error message");
+
+ auto result = [] { return Result<MoveOnlyDataType>(error); };
+
+ auto move_mapped =
+ result().Map([](MoveOnlyDataType m) -> Result<int> { return std::move(*m.data); });
+ EXPECT_FALSE(move_mapped.ok());
+ EXPECT_EQ(move_mapped.status(), error); // error is *not* replaced by a value
+
+ auto move_error =
+ result().Map([](MoveOnlyDataType m) -> Result<int> { return other_error; });
+ EXPECT_FALSE(move_error.ok());
+ EXPECT_EQ(move_error.status(), error); // error is *not* replaced by other_error
+}
+
+// Verify that a Result<U> is assignable to a Result<T>, where T
+// is a type which has an implicit constructor taking a const U &.
+TEST(ResultTest, TemplateCopyAssign) {
+ CopyOnlyDataType copy_only(kIntElement);
+ Result<CopyOnlyDataType> result(copy_only);
+
+ Result<ImplicitlyCopyConvertible> result2 = result;
+
+ EXPECT_TRUE(result.ok());
+ EXPECT_EQ(result.ValueOrDie().data, kIntElement);
+ EXPECT_TRUE(result2.ok());
+ EXPECT_EQ(result2.ValueOrDie().copy_only.data, kIntElement);
+}
+
+// Verify that a Result<U> is assignable to a Result<T>, where T is a type
+// which has an implicit constructor taking a U &&.
+TEST(ResultTest, TemplateMoveAssign) {
+ MoveOnlyDataType move_only(kIntElement);
+ Result<MoveOnlyDataType> result(std::move(move_only));
+
+ Result<ImplicitlyMoveConvertible> result2 = std::move(result);
+
+ EXPECT_TRUE(result2.ok());
+ EXPECT_EQ(*result2.ValueOrDie().move_only.data, kIntElement);
+}
+
+// Verify that a Result<U> is constructible from a Result<T>, where T is a
+// type which has an implicit constructor taking a const U &.
+TEST(ResultTest, TemplateCopyConstruct) {
+ CopyOnlyDataType copy_only(kIntElement);
+ Result<CopyOnlyDataType> result(copy_only);
+ Result<ImplicitlyCopyConvertible> result2(result);
+
+ EXPECT_TRUE(result.ok());
+ EXPECT_EQ(result.ValueOrDie().data, kIntElement);
+ EXPECT_TRUE(result2.ok());
+ EXPECT_EQ(result2.ValueOrDie().copy_only.data, kIntElement);
+}
+
+// Verify that a Result<U> is constructible from a Result<T>, where T is a
+// type which has an implicit constructor taking a U &&.
+TEST(ResultTest, TemplateMoveConstruct) {
+ MoveOnlyDataType move_only(kIntElement);
+ Result<MoveOnlyDataType> result(std::move(move_only));
+ Result<ImplicitlyMoveConvertible> result2(std::move(result));
+
+ EXPECT_TRUE(result2.ok());
+ EXPECT_EQ(*result2.ValueOrDie().move_only.data, kIntElement);
+}
+
+TEST(ResultTest, Equality) {
+ EXPECT_EQ(Result<int>(), Result<int>());
+ EXPECT_EQ(Result<int>(3), Result<int>(3));
+ EXPECT_EQ(Result<int>(Status::Invalid("error")), Result<int>(Status::Invalid("error")));
+
+ EXPECT_NE(Result<int>(), Result<int>(3));
+ EXPECT_NE(Result<int>(Status::Invalid("error")), Result<int>(3));
+ EXPECT_NE(Result<int>(3333), Result<int>(0));
+ EXPECT_NE(Result<int>(Status::Invalid("error")),
+ Result<int>(Status::Invalid("other error")));
+
+ {
+ Result<int> moved_from(3);
+ auto moved_to = std::move(moved_from);
+ EXPECT_EQ(moved_to, Result<int>(3));
+ }
+ {
+ Result<std::vector<int>> a, b, c;
+ a = std::vector<int>{1, 2, 3, 4, 5};
+ b = std::vector<int>{1, 2, 3, 4, 5};
+ c = std::vector<int>{1, 2, 3, 4};
+ EXPECT_EQ(a, b);
+ EXPECT_NE(a, c);
+
+ c = std::move(b);
+ EXPECT_EQ(a, c);
+ EXPECT_EQ(c.ValueOrDie(), (std::vector<int>{1, 2, 3, 4, 5}));
+ EXPECT_NE(a, b); // b's value was moved
+ }
+}
+
+TEST(ResultTest, ViewAsStatus) {
+ Result<int> ok(3);
+ Result<int> err(Status::Invalid("error"));
+
+ auto ViewAsStatus = [](const void* ptr) { return static_cast<const Status*>(ptr); };
+
+ EXPECT_EQ(ViewAsStatus(&ok), &ok.status());
+ EXPECT_EQ(ViewAsStatus(&err), &err.status());
+}
+
+TEST(ResultTest, MatcherExamples) {
+ EXPECT_THAT(Result<int>(Status::Invalid("arbitrary error")),
+ Raises(StatusCode::Invalid));
+
+ EXPECT_THAT(Result<int>(Status::Invalid("arbitrary error")),
+ Raises(StatusCode::Invalid, testing::HasSubstr("arbitrary")));
+
+ // message doesn't match, so no match
+ EXPECT_THAT(
+ Result<int>(Status::Invalid("arbitrary error")),
+ testing::Not(Raises(StatusCode::Invalid, testing::HasSubstr("reasonable"))));
+
+ // different error code, so no match
+ EXPECT_THAT(Result<int>(Status::TypeError("arbitrary error")),
+ testing::Not(Raises(StatusCode::Invalid)));
+
+ // not an error, so no match
+ EXPECT_THAT(Result<int>(333), testing::Not(Raises(StatusCode::Invalid)));
+
+ EXPECT_THAT(Result<std::string>("hello world"),
+ ResultWith(testing::HasSubstr("hello")));
+
+ EXPECT_THAT(Result<std::string>(Status::Invalid("XXX")),
+ testing::Not(ResultWith(testing::HasSubstr("hello"))));
+
+ // holds a value, but that value doesn't match the given pattern
+ EXPECT_THAT(Result<std::string>("foo bar"),
+ testing::Not(ResultWith(testing::HasSubstr("hello"))));
+}
+
+TEST(ResultTest, MatcherDescriptions) {
+ testing::Matcher<Result<std::string>> matcher = ResultWith(testing::HasSubstr("hello"));
+
+ {
+ std::stringstream ss;
+ matcher.DescribeTo(&ss);
+ EXPECT_THAT(ss.str(), testing::StrEq("value has substring \"hello\""));
+ }
+
+ {
+ std::stringstream ss;
+ matcher.DescribeNegationTo(&ss);
+ EXPECT_THAT(ss.str(), testing::StrEq("value has no substring \"hello\""));
+ }
+}
+
+TEST(ResultTest, MatcherExplanations) {
+ testing::Matcher<Result<std::string>> matcher = ResultWith(testing::HasSubstr("hello"));
+
+ {
+ testing::StringMatchResultListener listener;
+ EXPECT_TRUE(matcher.MatchAndExplain(Result<std::string>("hello world"), &listener));
+ EXPECT_THAT(listener.str(), testing::StrEq("whose value \"hello world\" matches"));
+ }
+
+ {
+ testing::StringMatchResultListener listener;
+ EXPECT_FALSE(matcher.MatchAndExplain(Result<std::string>("foo bar"), &listener));
+ EXPECT_THAT(listener.str(), testing::StrEq("whose value \"foo bar\" doesn't match"));
+ }
+
+ {
+ testing::StringMatchResultListener listener;
+ EXPECT_FALSE(matcher.MatchAndExplain(Status::TypeError("XXX"), &listener));
+ EXPECT_THAT(listener.str(),
+ testing::StrEq("whose error \"Type error: XXX\" doesn't match"));
+ }
+}
+
+} // namespace
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/scalar.cc b/src/arrow/cpp/src/arrow/scalar.cc
new file mode 100644
index 000000000..df8badae6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/scalar.cc
@@ -0,0 +1,1008 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/scalar.h"
+
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+
+#include "arrow/array.h"
+#include "arrow/array/util.h"
+#include "arrow/buffer.h"
+#include "arrow/compare.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/formatting.h"
+#include "arrow/util/hashing.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/time.h"
+#include "arrow/util/unreachable.h"
+#include "arrow/util/utf8.h"
+#include "arrow/util/value_parsing.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+bool Scalar::Equals(const Scalar& other, const EqualOptions& options) const {
+ return ScalarEquals(*this, other, options);
+}
+
+bool Scalar::ApproxEquals(const Scalar& other, const EqualOptions& options) const {
+ return ScalarApproxEquals(*this, other, options);
+}
+
+namespace {
+
+// Implementation of Scalar::hash()
+struct ScalarHashImpl {
+ Status Visit(const NullScalar& s) { return Status::OK(); }
+
+ template <typename T>
+ Status Visit(const internal::PrimitiveScalar<T>& s) {
+ return ValueHash(s);
+ }
+
+ Status Visit(const BaseBinaryScalar& s) { return BufferHash(*s.value); }
+
+ template <typename T>
+ Status Visit(const TemporalScalar<T>& s) {
+ return ValueHash(s);
+ }
+
+ Status Visit(const DayTimeIntervalScalar& s) {
+ return StdHash(s.value.days) & StdHash(s.value.milliseconds);
+ }
+
+ Status Visit(const MonthDayNanoIntervalScalar& s) {
+ return StdHash(s.value.days) & StdHash(s.value.months) & StdHash(s.value.nanoseconds);
+ }
+
+ Status Visit(const Decimal128Scalar& s) {
+ return StdHash(s.value.low_bits()) & StdHash(s.value.high_bits());
+ }
+
+ Status Visit(const Decimal256Scalar& s) {
+ Status status = Status::OK();
+ // endianness doesn't affect result
+ for (uint64_t elem : s.value.native_endian_array()) {
+ status &= StdHash(elem);
+ }
+ return status;
+ }
+
+ Status Visit(const BaseListScalar& s) { return ArrayHash(*s.value); }
+
+ Status Visit(const StructScalar& s) {
+ for (const auto& child : s.value) {
+ AccumulateHashFrom(*child);
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryScalar& s) {
+ AccumulateHashFrom(*s.value.index);
+ return Status::OK();
+ }
+
+ Status Visit(const UnionScalar& s) {
+ // type_code is ignored when comparing for equality, so do not hash it either
+ AccumulateHashFrom(*s.value);
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionScalar& s) {
+ AccumulateHashFrom(*s.value);
+ return Status::OK();
+ }
+
+ template <typename T>
+ Status StdHash(const T& t) {
+ static std::hash<T> hash;
+ hash_ ^= hash(t);
+ return Status::OK();
+ }
+
+ template <typename S>
+ Status ValueHash(const S& s) {
+ return StdHash(s.value);
+ }
+
+ Status BufferHash(const Buffer& b) {
+ hash_ ^= internal::ComputeStringHash<1>(b.data(), b.size());
+ return Status::OK();
+ }
+
+ Status ArrayHash(const Array& a) { return ArrayHash(*a.data()); }
+
+ Status ArrayHash(const ArrayData& a) {
+ RETURN_NOT_OK(StdHash(a.length) & StdHash(a.GetNullCount()));
+ if (a.buffers[0] != nullptr) {
+ // We can't visit values without unboxing the whole array, so only hash
+ // the null bitmap for now.
+ RETURN_NOT_OK(BufferHash(*a.buffers[0]));
+ }
+ for (const auto& child : a.child_data) {
+ RETURN_NOT_OK(ArrayHash(*child));
+ }
+ return Status::OK();
+ }
+
+ explicit ScalarHashImpl(const Scalar& scalar) : hash_(scalar.type->Hash()) {
+ AccumulateHashFrom(scalar);
+ }
+
+ void AccumulateHashFrom(const Scalar& scalar) {
+ // Note we already injected the type in ScalarHashImpl::ScalarHashImpl
+ if (scalar.is_valid) {
+ DCHECK_OK(VisitScalarInline(scalar, this));
+ }
+ }
+
+ size_t hash_;
+};
+
+struct ScalarBoundsCheckImpl {
+ int64_t min_value;
+ int64_t max_value;
+ int64_t actual_value = -1;
+ bool ok = true;
+
+ ScalarBoundsCheckImpl(int64_t min_value, int64_t max_value)
+ : min_value(min_value), max_value(max_value) {}
+
+ Status Visit(const Scalar&) {
+ Unreachable();
+ return Status::NotImplemented("");
+ }
+
+ template <typename ScalarType, typename Type = typename ScalarType::TypeClass>
+ enable_if_integer<Type, Status> Visit(const ScalarType& scalar) {
+ actual_value = static_cast<int64_t>(scalar.value);
+ ok = (actual_value >= min_value && actual_value <= max_value);
+ return Status::OK();
+ }
+};
+
+// Implementation of Scalar::Validate() and Scalar::ValidateFull()
+struct ScalarValidateImpl {
+ const bool full_validation_;
+
+ explicit ScalarValidateImpl(bool full_validation) : full_validation_(full_validation) {
+ ::arrow::util::InitializeUTF8();
+ }
+
+ Status Validate(const Scalar& scalar) {
+ if (!scalar.type) {
+ return Status::Invalid("scalar lacks a type");
+ }
+ return VisitScalarInline(scalar, this);
+ }
+
+ Status Visit(const NullScalar& s) {
+ if (s.is_valid) {
+ return Status::Invalid("null scalar should have is_valid = false");
+ }
+ return Status::OK();
+ }
+
+ template <typename T>
+ Status Visit(const internal::PrimitiveScalar<T>& s) {
+ return Status::OK();
+ }
+
+ Status Visit(const BaseBinaryScalar& s) { return ValidateBinaryScalar(s); }
+
+ Status Visit(const StringScalar& s) { return ValidateStringScalar(s); }
+
+ Status Visit(const LargeStringScalar& s) { return ValidateStringScalar(s); }
+
+ Status Visit(const FixedSizeBinaryScalar& s) {
+ RETURN_NOT_OK(ValidateBinaryScalar(s));
+ if (s.is_valid) {
+ const auto& byte_width =
+ checked_cast<const FixedSizeBinaryType&>(*s.type).byte_width();
+ if (s.value->size() != byte_width) {
+ return Status::Invalid(s.type->ToString(), " scalar should have a value of size ",
+ byte_width, ", got ", s.value->size());
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal128Scalar& s) {
+ // XXX validate precision?
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal256Scalar& s) {
+ // XXX validate precision?
+ return Status::OK();
+ }
+
+ Status Visit(const BaseListScalar& s) { return ValidateBaseListScalar(s); }
+
+ Status Visit(const FixedSizeListScalar& s) {
+ RETURN_NOT_OK(ValidateBaseListScalar(s));
+ if (s.is_valid) {
+ const auto& list_type = checked_cast<const FixedSizeListType&>(*s.type);
+ if (s.value->length() != list_type.list_size()) {
+ return Status::Invalid(s.type->ToString(),
+ " scalar should have a child value of length ",
+ list_type.list_size(), ", got ", s.value->length());
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const StructScalar& s) {
+ if (!s.is_valid) {
+ if (!s.value.empty()) {
+ return Status::Invalid(s.type->ToString(),
+ " scalar is marked null but has child values");
+ }
+ return Status::OK();
+ }
+ const int num_fields = s.type->num_fields();
+ const auto& fields = s.type->fields();
+ if (fields.size() != s.value.size()) {
+ return Status::Invalid("non-null ", s.type->ToString(), " scalar should have ",
+ num_fields, " child values, got ", s.value.size());
+ }
+ for (int i = 0; i < num_fields; ++i) {
+ if (!s.value[i]) {
+ return Status::Invalid("non-null ", s.type->ToString(),
+ " scalar has missing child value at index ", i);
+ }
+ const auto st = Validate(*s.value[i]);
+ if (!st.ok()) {
+ return st.WithMessage(s.type->ToString(),
+ " scalar fails validation for child at index ", i, ": ",
+ st.message());
+ }
+ if (!s.value[i]->type->Equals(*fields[i]->type())) {
+ return Status::Invalid(
+ s.type->ToString(), " scalar should have a child value of type ",
+ fields[i]->type()->ToString(), "at index ", i, ", got ", s.value[i]->type);
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryScalar& s) {
+ const auto& dict_type = checked_cast<const DictionaryType&>(*s.type);
+
+ // Validate index
+ if (!s.value.index) {
+ return Status::Invalid(s.type->ToString(), " scalar doesn't have an index value");
+ }
+ {
+ const auto st = Validate(*s.value.index);
+ if (!st.ok()) {
+ return st.WithMessage(s.type->ToString(),
+ " scalar fails validation for index value: ", st.message());
+ }
+ }
+ if (!s.value.index->type->Equals(*dict_type.index_type())) {
+ return Status::Invalid(
+ s.type->ToString(), " scalar should have an index value of type ",
+ dict_type.index_type()->ToString(), ", got ", s.value.index->type->ToString());
+ }
+ if (s.is_valid && !s.value.index->is_valid) {
+ return Status::Invalid("non-null ", s.type->ToString(),
+ " scalar has null index value");
+ }
+ if (!s.is_valid && s.value.index->is_valid) {
+ return Status::Invalid("null ", s.type->ToString(),
+ " scalar has non-null index value");
+ }
+
+ // Validate dictionary
+ if (!s.value.dictionary) {
+ return Status::Invalid(s.type->ToString(),
+ " scalar doesn't have a dictionary value");
+ }
+ {
+ const auto st = full_validation_ ? s.value.dictionary->ValidateFull()
+ : s.value.dictionary->Validate();
+ if (!st.ok()) {
+ return st.WithMessage(
+ s.type->ToString(),
+ " scalar fails validation for dictionary value: ", st.message());
+ }
+ }
+ if (!s.value.dictionary->type()->Equals(*dict_type.value_type())) {
+ return Status::Invalid(s.type->ToString(),
+ " scalar should have a dictionary value of type ",
+ dict_type.value_type()->ToString(), ", got ",
+ s.value.dictionary->type()->ToString());
+ }
+
+ // Check index is in bounds
+ if (full_validation_ && s.value.index->is_valid) {
+ ScalarBoundsCheckImpl bounds_checker{0, s.value.dictionary->length() - 1};
+ RETURN_NOT_OK(VisitScalarInline(*s.value.index, &bounds_checker));
+ if (!bounds_checker.ok) {
+ return Status::Invalid(s.type->ToString(), " scalar index value out of bounds: ",
+ bounds_checker.actual_value);
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const UnionScalar& s) {
+ RETURN_NOT_OK(ValidateOptionalValue(s));
+ const int type_code = s.type_code; // avoid 8-bit int types for printing
+ const auto& union_type = checked_cast<const UnionType&>(*s.type);
+ const auto& child_ids = union_type.child_ids();
+ if (type_code < 0 || type_code >= static_cast<int64_t>(child_ids.size()) ||
+ child_ids[type_code] == UnionType::kInvalidChildId) {
+ return Status::Invalid(s.type->ToString(), " scalar has invalid type code ",
+ type_code);
+ }
+ if (s.is_valid) {
+ const auto& field_type = *union_type.field(child_ids[type_code])->type();
+ if (!field_type.Equals(*s.value->type)) {
+ return Status::Invalid(s.type->ToString(), " scalar with type code ", type_code,
+ " should have an underlying value of type ",
+ field_type.ToString(), ", got ",
+ s.value->type->ToString());
+ }
+ const auto st = Validate(*s.value);
+ if (!st.ok()) {
+ return st.WithMessage(
+ s.type->ToString(),
+ " scalar fails validation for underlying value: ", st.message());
+ }
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionScalar& s) {
+ if (!s.is_valid) {
+ if (s.value) {
+ return Status::Invalid("null ", s.type->ToString(), " scalar has storage value");
+ }
+ return Status::OK();
+ }
+
+ if (!s.value) {
+ return Status::Invalid("non-null ", s.type->ToString(),
+ " scalar doesn't have storage value");
+ }
+ if (!s.value->is_valid) {
+ return Status::Invalid("non-null ", s.type->ToString(),
+ " scalar has null storage value");
+ }
+ const auto st = Validate(*s.value);
+ if (!st.ok()) {
+ return st.WithMessage(s.type->ToString(),
+ " scalar fails validation for storage value: ", st.message());
+ }
+ return Status::OK();
+ }
+
+ Status ValidateStringScalar(const BaseBinaryScalar& s) {
+ RETURN_NOT_OK(ValidateBinaryScalar(s));
+ if (s.is_valid && full_validation_) {
+ if (!::arrow::util::ValidateUTF8(s.value->data(), s.value->size())) {
+ return Status::Invalid(s.type->ToString(), " scalar contains invalid UTF8 data");
+ }
+ }
+ return Status::OK();
+ }
+
+ Status ValidateBinaryScalar(const BaseBinaryScalar& s) {
+ return ValidateOptionalValue(s);
+ }
+
+ Status ValidateBaseListScalar(const BaseListScalar& s) {
+ RETURN_NOT_OK(ValidateOptionalValue(s));
+ if (s.is_valid) {
+ const auto st = full_validation_ ? s.value->ValidateFull() : s.value->Validate();
+ if (!st.ok()) {
+ return st.WithMessage(s.type->ToString(),
+ " scalar fails validation for value: ", st.message());
+ }
+
+ const auto& list_type = checked_cast<const BaseListType&>(*s.type);
+ const auto& value_type = *list_type.value_type();
+ if (!s.value->type()->Equals(value_type)) {
+ return Status::Invalid(
+ list_type.ToString(), " scalar should have a value of type ",
+ value_type.ToString(), ", got ", s.value->type()->ToString());
+ }
+ }
+ return Status::OK();
+ }
+
+ template <typename ScalarType>
+ Status ValidateOptionalValue(const ScalarType& s) {
+ return ValidateOptionalValue(s, s.value, "value");
+ }
+
+ template <typename ScalarType, typename ValueType>
+ Status ValidateOptionalValue(const ScalarType& s, const ValueType& value,
+ const char* value_desc) {
+ if (s.is_valid && !s.value) {
+ return Status::Invalid(s.type->ToString(),
+ " scalar is marked valid but doesn't have a ", value_desc);
+ }
+ if (!s.is_valid && s.value) {
+ return Status::Invalid(s.type->ToString(), " scalar is marked null but has a ",
+ value_desc);
+ }
+ return Status::OK();
+ }
+};
+
+} // namespace
+
+size_t Scalar::hash() const { return ScalarHashImpl(*this).hash_; }
+
+Status Scalar::Validate() const {
+ return ScalarValidateImpl(/*full_validation=*/false).Validate(*this);
+}
+
+Status Scalar::ValidateFull() const {
+ return ScalarValidateImpl(/*full_validation=*/true).Validate(*this);
+}
+
+StringScalar::StringScalar(std::string s)
+ : StringScalar(Buffer::FromString(std::move(s))) {}
+
+LargeStringScalar::LargeStringScalar(std::string s)
+ : LargeStringScalar(Buffer::FromString(std::move(s))) {}
+
+FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::shared_ptr<Buffer> value,
+ std::shared_ptr<DataType> type)
+ : BinaryScalar(std::move(value), std::move(type)) {
+ ARROW_CHECK_EQ(checked_cast<const FixedSizeBinaryType&>(*this->type).byte_width(),
+ this->value->size());
+}
+
+BaseListScalar::BaseListScalar(std::shared_ptr<Array> value,
+ std::shared_ptr<DataType> type)
+ : Scalar{std::move(type), true}, value(std::move(value)) {
+ ARROW_CHECK(this->type->field(0)->type()->Equals(this->value->type()));
+}
+
+ListScalar::ListScalar(std::shared_ptr<Array> value)
+ : BaseListScalar(value, list(value->type())) {}
+
+LargeListScalar::LargeListScalar(std::shared_ptr<Array> value)
+ : BaseListScalar(value, large_list(value->type())) {}
+
+inline std::shared_ptr<DataType> MakeMapType(const std::shared_ptr<DataType>& pair_type) {
+ ARROW_CHECK_EQ(pair_type->id(), Type::STRUCT);
+ ARROW_CHECK_EQ(pair_type->num_fields(), 2);
+ return map(pair_type->field(0)->type(), pair_type->field(1)->type());
+}
+
+MapScalar::MapScalar(std::shared_ptr<Array> value)
+ : BaseListScalar(value, MakeMapType(value->type())) {}
+
+FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr<Array> value,
+ std::shared_ptr<DataType> type)
+ : BaseListScalar(value, std::move(type)) {
+ ARROW_CHECK_EQ(this->value->length(),
+ checked_cast<const FixedSizeListType&>(*this->type).list_size());
+}
+
+FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr<Array> value)
+ : BaseListScalar(
+ value, fixed_size_list(value->type(), static_cast<int32_t>(value->length()))) {}
+
+Result<std::shared_ptr<StructScalar>> StructScalar::Make(
+ ScalarVector values, std::vector<std::string> field_names) {
+ if (values.size() != field_names.size()) {
+ return Status::Invalid("Mismatching number of field names and child scalars");
+ }
+
+ FieldVector fields(field_names.size());
+ for (size_t i = 0; i < fields.size(); ++i) {
+ fields[i] = arrow::field(std::move(field_names[i]), values[i]->type);
+ }
+
+ return std::make_shared<StructScalar>(std::move(values), struct_(std::move(fields)));
+}
+
+Result<std::shared_ptr<Scalar>> StructScalar::field(FieldRef ref) const {
+ ARROW_ASSIGN_OR_RAISE(auto path, ref.FindOne(*type));
+ if (path.indices().size() != 1) {
+ return Status::NotImplemented("retrieval of nested fields from StructScalar");
+ }
+ auto index = path.indices()[0];
+ if (is_valid) {
+ return value[index];
+ } else {
+ const auto& struct_type = checked_cast<const StructType&>(*this->type);
+ const auto& field_type = struct_type.field(index)->type();
+ return MakeNullScalar(field_type);
+ }
+}
+
+DictionaryScalar::DictionaryScalar(std::shared_ptr<DataType> type)
+ : internal::PrimitiveScalarBase(std::move(type)),
+ value{MakeNullScalar(checked_cast<const DictionaryType&>(*this->type).index_type()),
+ MakeArrayOfNull(checked_cast<const DictionaryType&>(*this->type).value_type(),
+ 0)
+ .ValueOrDie()} {}
+
+Result<std::shared_ptr<Scalar>> DictionaryScalar::GetEncodedValue() const {
+ const auto& dict_type = checked_cast<DictionaryType&>(*type);
+
+ if (!is_valid) {
+ return MakeNullScalar(dict_type.value_type());
+ }
+
+ int64_t index_value = 0;
+ switch (dict_type.index_type()->id()) {
+ case Type::UINT8:
+ index_value =
+ static_cast<int64_t>(checked_cast<const UInt8Scalar&>(*value.index).value);
+ break;
+ case Type::INT8:
+ index_value =
+ static_cast<int64_t>(checked_cast<const Int8Scalar&>(*value.index).value);
+ break;
+ case Type::UINT16:
+ index_value =
+ static_cast<int64_t>(checked_cast<const UInt16Scalar&>(*value.index).value);
+ break;
+ case Type::INT16:
+ index_value =
+ static_cast<int64_t>(checked_cast<const Int16Scalar&>(*value.index).value);
+ break;
+ case Type::UINT32:
+ index_value =
+ static_cast<int64_t>(checked_cast<const UInt32Scalar&>(*value.index).value);
+ break;
+ case Type::INT32:
+ index_value =
+ static_cast<int64_t>(checked_cast<const Int32Scalar&>(*value.index).value);
+ break;
+ case Type::UINT64:
+ index_value =
+ static_cast<int64_t>(checked_cast<const UInt64Scalar&>(*value.index).value);
+ break;
+ case Type::INT64:
+ index_value =
+ static_cast<int64_t>(checked_cast<const Int64Scalar&>(*value.index).value);
+ break;
+ default:
+ return Status::TypeError("Not implemented dictionary index type");
+ break;
+ }
+ return value.dictionary->GetScalar(index_value);
+}
+
+std::shared_ptr<DictionaryScalar> DictionaryScalar::Make(std::shared_ptr<Scalar> index,
+ std::shared_ptr<Array> dict) {
+ auto type = dictionary(index->type, dict->type());
+ auto is_valid = index->is_valid;
+ return std::make_shared<DictionaryScalar>(ValueType{std::move(index), std::move(dict)},
+ std::move(type), is_valid);
+}
+
+namespace {
+
+template <typename T>
+using scalar_constructor_has_arrow_type =
+ std::is_constructible<typename TypeTraits<T>::ScalarType, std::shared_ptr<DataType>>;
+
+template <typename T, typename R = void>
+using enable_if_scalar_constructor_has_arrow_type =
+ typename std::enable_if<scalar_constructor_has_arrow_type<T>::value, R>::type;
+
+template <typename T, typename R = void>
+using enable_if_scalar_constructor_has_no_arrow_type =
+ typename std::enable_if<!scalar_constructor_has_arrow_type<T>::value, R>::type;
+
+struct MakeNullImpl {
+ template <typename T, typename ScalarType = typename TypeTraits<T>::ScalarType>
+ enable_if_scalar_constructor_has_arrow_type<T, Status> Visit(const T&) {
+ out_ = std::make_shared<ScalarType>(type_);
+ return Status::OK();
+ }
+
+ template <typename T, typename ScalarType = typename TypeTraits<T>::ScalarType>
+ enable_if_scalar_constructor_has_no_arrow_type<T, Status> Visit(const T&) {
+ out_ = std::make_shared<ScalarType>();
+ return Status::OK();
+ }
+
+ Status Visit(const SparseUnionType& type) { return MakeUnionScalar(type); }
+
+ Status Visit(const DenseUnionType& type) { return MakeUnionScalar(type); }
+
+ template <typename T, typename ScalarType = typename TypeTraits<T>::ScalarType>
+ Status MakeUnionScalar(const T& type) {
+ if (type.num_fields() == 0) {
+ return Status::Invalid("Cannot make scalar of empty union type");
+ }
+ out_ = std::make_shared<ScalarType>(type.type_codes()[0], type_);
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionType& type) {
+ out_ = std::make_shared<ExtensionScalar>(type_);
+ return Status::OK();
+ }
+
+ std::shared_ptr<Scalar> Finish() && {
+ // Should not fail.
+ DCHECK_OK(VisitTypeInline(*type_, this));
+ return std::move(out_);
+ }
+
+ std::shared_ptr<DataType> type_;
+ std::shared_ptr<Scalar> out_;
+};
+
+} // namespace
+
+std::shared_ptr<Scalar> MakeNullScalar(std::shared_ptr<DataType> type) {
+ return MakeNullImpl{std::move(type), nullptr}.Finish();
+}
+
+std::string Scalar::ToString() const {
+ if (!this->is_valid) {
+ return "null";
+ }
+ if (type->id() == Type::DICTIONARY) {
+ auto dict_scalar = checked_cast<const DictionaryScalar*>(this);
+ return dict_scalar->value.dictionary->ToString() + "[" +
+ dict_scalar->value.index->ToString() + "]";
+ }
+ auto maybe_repr = CastTo(utf8());
+ if (maybe_repr.ok()) {
+ return checked_cast<const StringScalar&>(*maybe_repr.ValueOrDie()).value->ToString();
+ }
+ return "...";
+}
+
+struct ScalarParseImpl {
+ template <typename T, typename = internal::enable_if_parseable<T>>
+ Status Visit(const T& t) {
+ typename internal::StringConverter<T>::value_type value;
+ if (!internal::ParseValue(t, s_.data(), s_.size(), &value)) {
+ return Status::Invalid("error parsing '", s_, "' as scalar of type ", t);
+ }
+ return Finish(value);
+ }
+
+ Status Visit(const BinaryType&) { return FinishWithBuffer(); }
+
+ Status Visit(const LargeBinaryType&) { return FinishWithBuffer(); }
+
+ Status Visit(const FixedSizeBinaryType&) { return FinishWithBuffer(); }
+
+ Status Visit(const DictionaryType& t) {
+ ARROW_ASSIGN_OR_RAISE(auto value, Scalar::Parse(t.value_type(), s_));
+ return Finish(std::move(value));
+ }
+
+ Status Visit(const DataType& t) {
+ return Status::NotImplemented("parsing scalars of type ", t);
+ }
+
+ template <typename Arg>
+ Status Finish(Arg&& arg) {
+ return MakeScalar(std::move(type_), std::forward<Arg>(arg)).Value(&out_);
+ }
+
+ Status FinishWithBuffer() { return Finish(Buffer::FromString(std::string(s_))); }
+
+ Result<std::shared_ptr<Scalar>> Finish() && {
+ RETURN_NOT_OK(VisitTypeInline(*type_, this));
+ return std::move(out_);
+ }
+
+ ScalarParseImpl(std::shared_ptr<DataType> type, util::string_view s)
+ : type_(std::move(type)), s_(s) {}
+
+ std::shared_ptr<DataType> type_;
+ util::string_view s_;
+ std::shared_ptr<Scalar> out_;
+};
+
+Result<std::shared_ptr<Scalar>> Scalar::Parse(const std::shared_ptr<DataType>& type,
+ util::string_view s) {
+ return ScalarParseImpl{type, s}.Finish();
+}
+
+namespace internal {
+Status CheckBufferLength(const FixedSizeBinaryType* t, const std::shared_ptr<Buffer>* b) {
+ return t->byte_width() == (*b)->size()
+ ? Status::OK()
+ : Status::Invalid("buffer length ", (*b)->size(), " is not compatible with ",
+ *t);
+}
+} // namespace internal
+
+namespace {
+// CastImpl(...) assumes `to` points to a non null scalar of the correct type with
+// uninitialized value
+
+// helper for StringFormatter
+template <typename Formatter, typename ScalarType>
+std::shared_ptr<Buffer> FormatToBuffer(Formatter&& formatter, const ScalarType& from) {
+ if (!from.is_valid) {
+ return Buffer::FromString("null");
+ }
+ return formatter(from.value, [&](util::string_view v) {
+ return Buffer::FromString(std::string(v));
+ });
+}
+
+// error fallback
+Status CastImpl(const Scalar& from, Scalar* to) {
+ return Status::NotImplemented("casting scalars of type ", *from.type, " to type ",
+ *to->type);
+}
+
+// numeric to numeric
+template <typename From, typename To>
+Status CastImpl(const NumericScalar<From>& from, NumericScalar<To>* to) {
+ to->value = static_cast<typename To::c_type>(from.value);
+ return Status::OK();
+}
+
+// numeric to boolean
+template <typename T>
+Status CastImpl(const NumericScalar<T>& from, BooleanScalar* to) {
+ constexpr auto zero = static_cast<typename T::c_type>(0);
+ to->value = from.value != zero;
+ return Status::OK();
+}
+
+// boolean to numeric
+template <typename T>
+Status CastImpl(const BooleanScalar& from, NumericScalar<T>* to) {
+ to->value = static_cast<typename T::c_type>(from.value);
+ return Status::OK();
+}
+
+// numeric to temporal
+template <typename From, typename To>
+typename std::enable_if<std::is_base_of<TemporalType, To>::value &&
+ !std::is_same<DayTimeIntervalType, To>::value &&
+ !std::is_same<MonthDayNanoIntervalType, To>::value,
+ Status>::type
+CastImpl(const NumericScalar<From>& from, TemporalScalar<To>* to) {
+ to->value = static_cast<typename To::c_type>(from.value);
+ return Status::OK();
+}
+
+// temporal to numeric
+template <typename From, typename To>
+typename std::enable_if<std::is_base_of<TemporalType, From>::value &&
+ !std::is_same<DayTimeIntervalType, From>::value &&
+ !std::is_same<MonthDayNanoIntervalType, From>::value,
+ Status>::type
+CastImpl(const TemporalScalar<From>& from, NumericScalar<To>* to) {
+ to->value = static_cast<typename To::c_type>(from.value);
+ return Status::OK();
+}
+
+// timestamp to timestamp
+Status CastImpl(const TimestampScalar& from, TimestampScalar* to) {
+ return util::ConvertTimestampValue(from.type, to->type, from.value).Value(&to->value);
+}
+
+template <typename TypeWithTimeUnit>
+std::shared_ptr<DataType> AsTimestampType(const std::shared_ptr<DataType>& type) {
+ return timestamp(checked_cast<const TypeWithTimeUnit&>(*type).unit());
+}
+
+// duration to duration
+Status CastImpl(const DurationScalar& from, DurationScalar* to) {
+ return util::ConvertTimestampValue(AsTimestampType<DurationType>(from.type),
+ AsTimestampType<DurationType>(to->type), from.value)
+ .Value(&to->value);
+}
+
+// time to time
+template <typename F, typename ToScalar, typename T = typename ToScalar::TypeClass>
+enable_if_time<T, Status> CastImpl(const TimeScalar<F>& from, ToScalar* to) {
+ return util::ConvertTimestampValue(AsTimestampType<F>(from.type),
+ AsTimestampType<T>(to->type), from.value)
+ .Value(&to->value);
+}
+
+constexpr int64_t kMillisecondsInDay = 86400000;
+
+// date to date
+Status CastImpl(const Date32Scalar& from, Date64Scalar* to) {
+ to->value = from.value * kMillisecondsInDay;
+ return Status::OK();
+}
+Status CastImpl(const Date64Scalar& from, Date32Scalar* to) {
+ to->value = static_cast<int32_t>(from.value / kMillisecondsInDay);
+ return Status::OK();
+}
+
+// timestamp to date
+Status CastImpl(const TimestampScalar& from, Date64Scalar* to) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto millis,
+ util::ConvertTimestampValue(from.type, timestamp(TimeUnit::MILLI), from.value));
+ to->value = millis - millis % kMillisecondsInDay;
+ return Status::OK();
+}
+Status CastImpl(const TimestampScalar& from, Date32Scalar* to) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto millis,
+ util::ConvertTimestampValue(from.type, timestamp(TimeUnit::MILLI), from.value));
+ to->value = static_cast<int32_t>(millis / kMillisecondsInDay);
+ return Status::OK();
+}
+
+// date to timestamp
+template <typename D>
+Status CastImpl(const DateScalar<D>& from, TimestampScalar* to) {
+ int64_t millis = from.value;
+ if (std::is_same<D, Date32Type>::value) {
+ millis *= kMillisecondsInDay;
+ }
+ return util::ConvertTimestampValue(timestamp(TimeUnit::MILLI), to->type, millis)
+ .Value(&to->value);
+}
+
+// string to any
+template <typename ScalarType>
+Status CastImpl(const StringScalar& from, ScalarType* to) {
+ ARROW_ASSIGN_OR_RAISE(auto out,
+ Scalar::Parse(to->type, util::string_view(*from.value)));
+ to->value = std::move(checked_cast<ScalarType&>(*out).value);
+ return Status::OK();
+}
+
+// binary to string
+Status CastImpl(const BinaryScalar& from, StringScalar* to) {
+ to->value = from.value;
+ return Status::OK();
+}
+
+// formattable to string
+template <typename ScalarType, typename T = typename ScalarType::TypeClass,
+ typename Formatter = internal::StringFormatter<T>,
+ // note: Value unused but necessary to trigger SFINAE if Formatter is
+ // undefined
+ typename Value = typename Formatter::value_type>
+Status CastImpl(const ScalarType& from, StringScalar* to) {
+ to->value = FormatToBuffer(Formatter{from.type}, from);
+ return Status::OK();
+}
+
+Status CastImpl(const Decimal128Scalar& from, StringScalar* to) {
+ auto from_type = checked_cast<const Decimal128Type*>(from.type.get());
+ to->value = Buffer::FromString(from.value.ToString(from_type->scale()));
+ return Status::OK();
+}
+
+Status CastImpl(const Decimal256Scalar& from, StringScalar* to) {
+ auto from_type = checked_cast<const Decimal256Type*>(from.type.get());
+ to->value = Buffer::FromString(from.value.ToString(from_type->scale()));
+ return Status::OK();
+}
+
+Status CastImpl(const StructScalar& from, StringScalar* to) {
+ std::stringstream ss;
+ ss << '{';
+ for (int i = 0; static_cast<size_t>(i) < from.value.size(); i++) {
+ if (i > 0) ss << ", ";
+ ss << from.type->field(i)->name() << ':' << from.type->field(i)->type()->ToString()
+ << " = " << from.value[i]->ToString();
+ }
+ ss << '}';
+ to->value = Buffer::FromString(ss.str());
+ return Status::OK();
+}
+
+Status CastImpl(const UnionScalar& from, StringScalar* to) {
+ const auto& union_ty = checked_cast<const UnionType&>(*from.type);
+ std::stringstream ss;
+ ss << "union{" << union_ty.field(union_ty.child_ids()[from.type_code])->ToString()
+ << " = " << from.value->ToString() << '}';
+ to->value = Buffer::FromString(ss.str());
+ return Status::OK();
+}
+
+struct CastImplVisitor {
+ Status NotImplemented() {
+ return Status::NotImplemented("cast to ", *to_type_, " from ", *from_.type);
+ }
+
+ const Scalar& from_;
+ const std::shared_ptr<DataType>& to_type_;
+ Scalar* out_;
+};
+
+template <typename ToType>
+struct FromTypeVisitor : CastImplVisitor {
+ using ToScalar = typename TypeTraits<ToType>::ScalarType;
+
+ FromTypeVisitor(const Scalar& from, const std::shared_ptr<DataType>& to_type,
+ Scalar* out)
+ : CastImplVisitor{from, to_type, out} {}
+
+ template <typename FromType>
+ Status Visit(const FromType&) {
+ return CastImpl(checked_cast<const typename TypeTraits<FromType>::ScalarType&>(from_),
+ checked_cast<ToScalar*>(out_));
+ }
+
+ // identity cast only for parameter free types
+ template <typename T1 = ToType>
+ typename std::enable_if<TypeTraits<T1>::is_parameter_free, Status>::type Visit(
+ const ToType&) {
+ checked_cast<ToScalar*>(out_)->value = checked_cast<const ToScalar&>(from_).value;
+ return Status::OK();
+ }
+
+ Status Visit(const NullType&) { return NotImplemented(); }
+ Status Visit(const DictionaryType&) { return NotImplemented(); }
+ Status Visit(const ExtensionType&) { return NotImplemented(); }
+};
+
+struct ToTypeVisitor : CastImplVisitor {
+ ToTypeVisitor(const Scalar& from, const std::shared_ptr<DataType>& to_type, Scalar* out)
+ : CastImplVisitor{from, to_type, out} {}
+
+ template <typename ToType>
+ Status Visit(const ToType&) {
+ FromTypeVisitor<ToType> unpack_from_type{from_, to_type_, out_};
+ return VisitTypeInline(*from_.type, &unpack_from_type);
+ }
+
+ Status Visit(const NullType&) {
+ if (from_.is_valid) {
+ return Status::Invalid("attempting to cast non-null scalar to NullScalar");
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryType& dict_type) {
+ auto& out = checked_cast<DictionaryScalar*>(out_)->value;
+ ARROW_ASSIGN_OR_RAISE(auto cast_value, from_.CastTo(dict_type.value_type()));
+ ARROW_ASSIGN_OR_RAISE(out.dictionary, MakeArrayFromScalar(*cast_value, 1));
+ return Int32Scalar(0).CastTo(dict_type.index_type()).Value(&out.index);
+ }
+
+ Status Visit(const ExtensionType&) { return NotImplemented(); }
+};
+
+} // namespace
+
+Result<std::shared_ptr<Scalar>> Scalar::CastTo(std::shared_ptr<DataType> to) const {
+ std::shared_ptr<Scalar> out = MakeNullScalar(to);
+ if (is_valid) {
+ out->is_valid = true;
+ ToTypeVisitor unpack_to_type{*this, to, out.get()};
+ RETURN_NOT_OK(VisitTypeInline(*to, &unpack_to_type));
+ }
+ return out;
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/scalar.h b/src/arrow/cpp/src/arrow/scalar.h
new file mode 100644
index 000000000..be8df6b64
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/scalar.h
@@ -0,0 +1,636 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Object model for scalar (non-Array) values. Not intended for use with large
+// amounts of data
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/compare.h"
+#include "arrow/extension_type.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/compare.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Array;
+
+/// \brief Base class for scalar values
+///
+/// A Scalar represents a single value with a specific DataType.
+/// Scalars are useful for passing single value inputs to compute functions,
+/// or for representing individual array elements (with a non-trivial
+/// wrapping cost, though).
+struct ARROW_EXPORT Scalar : public util::EqualityComparable<Scalar> {
+ virtual ~Scalar() = default;
+
+ explicit Scalar(std::shared_ptr<DataType> type) : type(std::move(type)) {}
+
+ /// \brief The type of the scalar value
+ std::shared_ptr<DataType> type;
+
+ /// \brief Whether the value is valid (not null) or not
+ bool is_valid = false;
+
+ using util::EqualityComparable<Scalar>::operator==;
+ using util::EqualityComparable<Scalar>::Equals;
+ bool Equals(const Scalar& other,
+ const EqualOptions& options = EqualOptions::Defaults()) const;
+
+ bool ApproxEquals(const Scalar& other,
+ const EqualOptions& options = EqualOptions::Defaults()) const;
+
+ struct ARROW_EXPORT Hash {
+ size_t operator()(const Scalar& scalar) const { return scalar.hash(); }
+
+ size_t operator()(const std::shared_ptr<Scalar>& scalar) const {
+ return scalar->hash();
+ }
+ };
+
+ size_t hash() const;
+
+ std::string ToString() const;
+
+ /// \brief Perform cheap validation checks
+ ///
+ /// This is O(k) where k is the number of descendents.
+ ///
+ /// \return Status
+ Status Validate() const;
+
+ /// \brief Perform extensive data validation checks
+ ///
+ /// This is potentially O(k*n) where k is the number of descendents and n
+ /// is the length of descendents (if list scalars are involved).
+ ///
+ /// \return Status
+ Status ValidateFull() const;
+
+ static Result<std::shared_ptr<Scalar>> Parse(const std::shared_ptr<DataType>& type,
+ util::string_view repr);
+
+ // TODO(bkietz) add compute::CastOptions
+ Result<std::shared_ptr<Scalar>> CastTo(std::shared_ptr<DataType> to) const;
+
+ protected:
+ Scalar(std::shared_ptr<DataType> type, bool is_valid)
+ : type(std::move(type)), is_valid(is_valid) {}
+};
+
+/// \defgroup concrete-scalar-classes Concrete Scalar subclasses
+///
+/// @{
+
+/// \brief A scalar value for NullType. Never valid
+struct ARROW_EXPORT NullScalar : public Scalar {
+ public:
+ using TypeClass = NullType;
+
+ NullScalar() : Scalar{null(), false} {}
+};
+
+/// @}
+
+namespace internal {
+
+struct ARROW_EXPORT PrimitiveScalarBase : public Scalar {
+ using Scalar::Scalar;
+ /// \brief Get a mutable pointer to the value of this scalar. May be null.
+ virtual void* mutable_data() = 0;
+ /// \brief Get an immutable view of the value of this scalar as bytes.
+ virtual util::string_view view() const = 0;
+};
+
+template <typename T, typename CType = typename T::c_type>
+struct ARROW_EXPORT PrimitiveScalar : public PrimitiveScalarBase {
+ using PrimitiveScalarBase::PrimitiveScalarBase;
+ using TypeClass = T;
+ using ValueType = CType;
+
+ // Non-null constructor.
+ PrimitiveScalar(ValueType value, std::shared_ptr<DataType> type)
+ : PrimitiveScalarBase(std::move(type), true), value(value) {}
+
+ explicit PrimitiveScalar(std::shared_ptr<DataType> type)
+ : PrimitiveScalarBase(std::move(type), false) {}
+
+ ValueType value{};
+
+ void* mutable_data() override { return &value; }
+ util::string_view view() const override {
+ return util::string_view(reinterpret_cast<const char*>(&value), sizeof(ValueType));
+ };
+};
+
+} // namespace internal
+
+/// \addtogroup concrete-scalar-classes Concrete Scalar subclasses
+///
+/// @{
+
+struct ARROW_EXPORT BooleanScalar : public internal::PrimitiveScalar<BooleanType, bool> {
+ using Base = internal::PrimitiveScalar<BooleanType, bool>;
+ using Base::Base;
+
+ explicit BooleanScalar(bool value) : Base(value, boolean()) {}
+
+ BooleanScalar() : Base(boolean()) {}
+};
+
+template <typename T>
+struct NumericScalar : public internal::PrimitiveScalar<T> {
+ using Base = typename internal::PrimitiveScalar<T>;
+ using Base::Base;
+ using TypeClass = typename Base::TypeClass;
+ using ValueType = typename Base::ValueType;
+
+ explicit NumericScalar(ValueType value)
+ : Base(value, TypeTraits<T>::type_singleton()) {}
+
+ NumericScalar() : Base(TypeTraits<T>::type_singleton()) {}
+};
+
+struct ARROW_EXPORT Int8Scalar : public NumericScalar<Int8Type> {
+ using NumericScalar<Int8Type>::NumericScalar;
+};
+
+struct ARROW_EXPORT Int16Scalar : public NumericScalar<Int16Type> {
+ using NumericScalar<Int16Type>::NumericScalar;
+};
+
+struct ARROW_EXPORT Int32Scalar : public NumericScalar<Int32Type> {
+ using NumericScalar<Int32Type>::NumericScalar;
+};
+
+struct ARROW_EXPORT Int64Scalar : public NumericScalar<Int64Type> {
+ using NumericScalar<Int64Type>::NumericScalar;
+};
+
+struct ARROW_EXPORT UInt8Scalar : public NumericScalar<UInt8Type> {
+ using NumericScalar<UInt8Type>::NumericScalar;
+};
+
+struct ARROW_EXPORT UInt16Scalar : public NumericScalar<UInt16Type> {
+ using NumericScalar<UInt16Type>::NumericScalar;
+};
+
+struct ARROW_EXPORT UInt32Scalar : public NumericScalar<UInt32Type> {
+ using NumericScalar<UInt32Type>::NumericScalar;
+};
+
+struct ARROW_EXPORT UInt64Scalar : public NumericScalar<UInt64Type> {
+ using NumericScalar<UInt64Type>::NumericScalar;
+};
+
+struct ARROW_EXPORT HalfFloatScalar : public NumericScalar<HalfFloatType> {
+ using NumericScalar<HalfFloatType>::NumericScalar;
+};
+
+struct ARROW_EXPORT FloatScalar : public NumericScalar<FloatType> {
+ using NumericScalar<FloatType>::NumericScalar;
+};
+
+struct ARROW_EXPORT DoubleScalar : public NumericScalar<DoubleType> {
+ using NumericScalar<DoubleType>::NumericScalar;
+};
+
+struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase {
+ using internal::PrimitiveScalarBase::PrimitiveScalarBase;
+ using ValueType = std::shared_ptr<Buffer>;
+
+ std::shared_ptr<Buffer> value;
+
+ void* mutable_data() override {
+ return value ? reinterpret_cast<void*>(value->mutable_data()) : NULLPTR;
+ }
+ util::string_view view() const override {
+ return value ? util::string_view(*value) : util::string_view();
+ }
+
+ protected:
+ BaseBinaryScalar(std::shared_ptr<Buffer> value, std::shared_ptr<DataType> type)
+ : internal::PrimitiveScalarBase{std::move(type), true}, value(std::move(value)) {}
+};
+
+struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar {
+ using BaseBinaryScalar::BaseBinaryScalar;
+ using TypeClass = BinaryType;
+
+ BinaryScalar(std::shared_ptr<Buffer> value, std::shared_ptr<DataType> type)
+ : BaseBinaryScalar(std::move(value), std::move(type)) {}
+
+ explicit BinaryScalar(std::shared_ptr<Buffer> value)
+ : BinaryScalar(std::move(value), binary()) {}
+
+ BinaryScalar() : BinaryScalar(binary()) {}
+};
+
+struct ARROW_EXPORT StringScalar : public BinaryScalar {
+ using BinaryScalar::BinaryScalar;
+ using TypeClass = StringType;
+
+ explicit StringScalar(std::shared_ptr<Buffer> value)
+ : StringScalar(std::move(value), utf8()) {}
+
+ explicit StringScalar(std::string s);
+
+ StringScalar() : StringScalar(utf8()) {}
+};
+
+struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar {
+ using BaseBinaryScalar::BaseBinaryScalar;
+ using TypeClass = LargeBinaryType;
+
+ LargeBinaryScalar(std::shared_ptr<Buffer> value, std::shared_ptr<DataType> type)
+ : BaseBinaryScalar(std::move(value), std::move(type)) {}
+
+ explicit LargeBinaryScalar(std::shared_ptr<Buffer> value)
+ : LargeBinaryScalar(std::move(value), large_binary()) {}
+
+ LargeBinaryScalar() : LargeBinaryScalar(large_binary()) {}
+};
+
+struct ARROW_EXPORT LargeStringScalar : public LargeBinaryScalar {
+ using LargeBinaryScalar::LargeBinaryScalar;
+ using TypeClass = LargeStringType;
+
+ explicit LargeStringScalar(std::shared_ptr<Buffer> value)
+ : LargeStringScalar(std::move(value), large_utf8()) {}
+
+ explicit LargeStringScalar(std::string s);
+
+ LargeStringScalar() : LargeStringScalar(large_utf8()) {}
+};
+
+struct ARROW_EXPORT FixedSizeBinaryScalar : public BinaryScalar {
+ using TypeClass = FixedSizeBinaryType;
+
+ FixedSizeBinaryScalar(std::shared_ptr<Buffer> value, std::shared_ptr<DataType> type);
+
+ explicit FixedSizeBinaryScalar(std::shared_ptr<DataType> type) : BinaryScalar(type) {}
+};
+
+template <typename T>
+struct ARROW_EXPORT TemporalScalar : internal::PrimitiveScalar<T> {
+ using internal::PrimitiveScalar<T>::PrimitiveScalar;
+ using ValueType = typename TemporalScalar<T>::ValueType;
+
+ explicit TemporalScalar(ValueType value, std::shared_ptr<DataType> type)
+ : internal::PrimitiveScalar<T>(std::move(value), type) {}
+};
+
+template <typename T>
+struct ARROW_EXPORT DateScalar : public TemporalScalar<T> {
+ using TemporalScalar<T>::TemporalScalar;
+ using ValueType = typename TemporalScalar<T>::ValueType;
+
+ explicit DateScalar(ValueType value)
+ : TemporalScalar<T>(std::move(value), TypeTraits<T>::type_singleton()) {}
+ DateScalar() : TemporalScalar<T>(TypeTraits<T>::type_singleton()) {}
+};
+
+struct ARROW_EXPORT Date32Scalar : public DateScalar<Date32Type> {
+ using DateScalar<Date32Type>::DateScalar;
+};
+
+struct ARROW_EXPORT Date64Scalar : public DateScalar<Date64Type> {
+ using DateScalar<Date64Type>::DateScalar;
+};
+
+template <typename T>
+struct ARROW_EXPORT TimeScalar : public TemporalScalar<T> {
+ using TemporalScalar<T>::TemporalScalar;
+};
+
+struct ARROW_EXPORT Time32Scalar : public TimeScalar<Time32Type> {
+ using TimeScalar<Time32Type>::TimeScalar;
+};
+
+struct ARROW_EXPORT Time64Scalar : public TimeScalar<Time64Type> {
+ using TimeScalar<Time64Type>::TimeScalar;
+};
+
+struct ARROW_EXPORT TimestampScalar : public TemporalScalar<TimestampType> {
+ using TemporalScalar<TimestampType>::TemporalScalar;
+};
+
+template <typename T>
+struct ARROW_EXPORT IntervalScalar : public TemporalScalar<T> {
+ using TemporalScalar<T>::TemporalScalar;
+ using ValueType = typename TemporalScalar<T>::ValueType;
+
+ explicit IntervalScalar(ValueType value)
+ : TemporalScalar<T>(value, TypeTraits<T>::type_singleton()) {}
+ IntervalScalar() : TemporalScalar<T>(TypeTraits<T>::type_singleton()) {}
+};
+
+struct ARROW_EXPORT MonthIntervalScalar : public IntervalScalar<MonthIntervalType> {
+ using IntervalScalar<MonthIntervalType>::IntervalScalar;
+};
+
+struct ARROW_EXPORT DayTimeIntervalScalar : public IntervalScalar<DayTimeIntervalType> {
+ using IntervalScalar<DayTimeIntervalType>::IntervalScalar;
+};
+
+struct ARROW_EXPORT MonthDayNanoIntervalScalar
+ : public IntervalScalar<MonthDayNanoIntervalType> {
+ using IntervalScalar<MonthDayNanoIntervalType>::IntervalScalar;
+};
+
+struct ARROW_EXPORT DurationScalar : public TemporalScalar<DurationType> {
+ using TemporalScalar<DurationType>::TemporalScalar;
+};
+
+struct ARROW_EXPORT Decimal128Scalar : public internal::PrimitiveScalarBase {
+ using internal::PrimitiveScalarBase::PrimitiveScalarBase;
+ using TypeClass = Decimal128Type;
+ using ValueType = Decimal128;
+
+ Decimal128Scalar(Decimal128 value, std::shared_ptr<DataType> type)
+ : internal::PrimitiveScalarBase(std::move(type), true), value(value) {}
+
+ void* mutable_data() override {
+ return reinterpret_cast<void*>(value.mutable_native_endian_bytes());
+ }
+ util::string_view view() const override {
+ return util::string_view(reinterpret_cast<const char*>(value.native_endian_bytes()),
+ 16);
+ }
+
+ Decimal128 value;
+};
+
+struct ARROW_EXPORT Decimal256Scalar : public internal::PrimitiveScalarBase {
+ using internal::PrimitiveScalarBase::PrimitiveScalarBase;
+ using TypeClass = Decimal256Type;
+ using ValueType = Decimal256;
+
+ Decimal256Scalar(Decimal256 value, std::shared_ptr<DataType> type)
+ : internal::PrimitiveScalarBase(std::move(type), true), value(value) {}
+
+ void* mutable_data() override {
+ return reinterpret_cast<void*>(value.mutable_native_endian_bytes());
+ }
+ util::string_view view() const override {
+ const std::array<uint64_t, 4>& bytes = value.native_endian_array();
+ return util::string_view(reinterpret_cast<const char*>(bytes.data()),
+ bytes.size() * sizeof(uint64_t));
+ }
+
+ Decimal256 value;
+};
+
+struct ARROW_EXPORT BaseListScalar : public Scalar {
+ using Scalar::Scalar;
+ using ValueType = std::shared_ptr<Array>;
+
+ BaseListScalar(std::shared_ptr<Array> value, std::shared_ptr<DataType> type);
+
+ std::shared_ptr<Array> value;
+};
+
+struct ARROW_EXPORT ListScalar : public BaseListScalar {
+ using TypeClass = ListType;
+ using BaseListScalar::BaseListScalar;
+
+ explicit ListScalar(std::shared_ptr<Array> value);
+};
+
+struct ARROW_EXPORT LargeListScalar : public BaseListScalar {
+ using TypeClass = LargeListType;
+ using BaseListScalar::BaseListScalar;
+
+ explicit LargeListScalar(std::shared_ptr<Array> value);
+};
+
+struct ARROW_EXPORT MapScalar : public BaseListScalar {
+ using TypeClass = MapType;
+ using BaseListScalar::BaseListScalar;
+
+ explicit MapScalar(std::shared_ptr<Array> value);
+};
+
+struct ARROW_EXPORT FixedSizeListScalar : public BaseListScalar {
+ using TypeClass = FixedSizeListType;
+ using BaseListScalar::BaseListScalar;
+
+ FixedSizeListScalar(std::shared_ptr<Array> value, std::shared_ptr<DataType> type);
+
+ explicit FixedSizeListScalar(std::shared_ptr<Array> value);
+};
+
+struct ARROW_EXPORT StructScalar : public Scalar {
+ using TypeClass = StructType;
+ using ValueType = std::vector<std::shared_ptr<Scalar>>;
+
+ ScalarVector value;
+
+ Result<std::shared_ptr<Scalar>> field(FieldRef ref) const;
+
+ StructScalar(ValueType value, std::shared_ptr<DataType> type)
+ : Scalar(std::move(type), true), value(std::move(value)) {}
+
+ static Result<std::shared_ptr<StructScalar>> Make(ValueType value,
+ std::vector<std::string> field_names);
+
+ explicit StructScalar(std::shared_ptr<DataType> type) : Scalar(std::move(type)) {}
+};
+
+struct ARROW_EXPORT UnionScalar : public Scalar {
+ using Scalar::Scalar;
+ using ValueType = std::shared_ptr<Scalar>;
+
+ ValueType value;
+ int8_t type_code;
+
+ UnionScalar(int8_t type_code, std::shared_ptr<DataType> type)
+ : Scalar(std::move(type), false), type_code(type_code) {}
+
+ UnionScalar(ValueType value, int8_t type_code, std::shared_ptr<DataType> type)
+ : Scalar(std::move(type), true), value(std::move(value)), type_code(type_code) {}
+};
+
+struct ARROW_EXPORT SparseUnionScalar : public UnionScalar {
+ using UnionScalar::UnionScalar;
+ using TypeClass = SparseUnionType;
+};
+
+struct ARROW_EXPORT DenseUnionScalar : public UnionScalar {
+ using UnionScalar::UnionScalar;
+ using TypeClass = DenseUnionType;
+};
+
+/// \brief A Scalar value for DictionaryType
+///
+/// `is_valid` denotes the validity of the `index`, regardless of
+/// the corresponding value in the `dictionary`.
+struct ARROW_EXPORT DictionaryScalar : public internal::PrimitiveScalarBase {
+ using TypeClass = DictionaryType;
+ struct ValueType {
+ std::shared_ptr<Scalar> index;
+ std::shared_ptr<Array> dictionary;
+ } value;
+
+ explicit DictionaryScalar(std::shared_ptr<DataType> type);
+
+ DictionaryScalar(ValueType value, std::shared_ptr<DataType> type, bool is_valid = true)
+ : internal::PrimitiveScalarBase(std::move(type), is_valid),
+ value(std::move(value)) {}
+
+ static std::shared_ptr<DictionaryScalar> Make(std::shared_ptr<Scalar> index,
+ std::shared_ptr<Array> dict);
+
+ Result<std::shared_ptr<Scalar>> GetEncodedValue() const;
+
+ void* mutable_data() override {
+ return internal::checked_cast<internal::PrimitiveScalarBase&>(*value.index)
+ .mutable_data();
+ }
+ util::string_view view() const override {
+ return internal::checked_cast<const internal::PrimitiveScalarBase&>(*value.index)
+ .view();
+ }
+};
+
+/// \brief A Scalar value for ExtensionType
+///
+/// The value is the underlying storage scalar.
+/// `is_valid` must only be true if `value` is non-null and `value->is_valid` is true
+struct ARROW_EXPORT ExtensionScalar : public Scalar {
+ using Scalar::Scalar;
+ using TypeClass = ExtensionType;
+ using ValueType = std::shared_ptr<Scalar>;
+
+ ExtensionScalar(std::shared_ptr<Scalar> storage, std::shared_ptr<DataType> type)
+ : Scalar(std::move(type), true), value(std::move(storage)) {}
+
+ std::shared_ptr<Scalar> value;
+};
+
+/// @}
+
+namespace internal {
+
+inline Status CheckBufferLength(...) { return Status::OK(); }
+
+ARROW_EXPORT Status CheckBufferLength(const FixedSizeBinaryType* t,
+ const std::shared_ptr<Buffer>* b);
+
+} // namespace internal
+
+template <typename ValueRef>
+struct MakeScalarImpl;
+
+/// \defgroup scalar-factories Scalar factory functions
+///
+/// @{
+
+/// \brief Scalar factory for null scalars
+ARROW_EXPORT
+std::shared_ptr<Scalar> MakeNullScalar(std::shared_ptr<DataType> type);
+
+/// \brief Scalar factory for non-null scalars
+template <typename Value>
+Result<std::shared_ptr<Scalar>> MakeScalar(std::shared_ptr<DataType> type,
+ Value&& value) {
+ return MakeScalarImpl<Value&&>{type, std::forward<Value>(value), NULLPTR}.Finish();
+}
+
+/// \brief Type-inferring scalar factory for non-null scalars
+///
+/// Construct a Scalar instance with a DataType determined by the input C++ type.
+/// (for example Int8Scalar for a int8_t input).
+/// Only non-parametric primitive types and String are supported.
+template <typename Value, typename Traits = CTypeTraits<typename std::decay<Value>::type>,
+ typename ScalarType = typename Traits::ScalarType,
+ typename Enable = decltype(ScalarType(std::declval<Value>(),
+ Traits::type_singleton()))>
+std::shared_ptr<Scalar> MakeScalar(Value value) {
+ return std::make_shared<ScalarType>(std::move(value), Traits::type_singleton());
+}
+
+inline std::shared_ptr<Scalar> MakeScalar(std::string value) {
+ return std::make_shared<StringScalar>(std::move(value));
+}
+
+/// @}
+
+template <typename ValueRef>
+struct MakeScalarImpl {
+ template <typename T, typename ScalarType = typename TypeTraits<T>::ScalarType,
+ typename ValueType = typename ScalarType::ValueType,
+ typename Enable = typename std::enable_if<
+ std::is_constructible<ScalarType, ValueType,
+ std::shared_ptr<DataType>>::value &&
+ std::is_convertible<ValueRef, ValueType>::value>::type>
+ Status Visit(const T& t) {
+ ARROW_RETURN_NOT_OK(internal::CheckBufferLength(&t, &value_));
+ // `static_cast<ValueRef>` makes a rvalue if ValueRef is `ValueType&&`
+ out_ = std::make_shared<ScalarType>(
+ static_cast<ValueType>(static_cast<ValueRef>(value_)), std::move(type_));
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionType& t) {
+ ARROW_ASSIGN_OR_RAISE(auto storage,
+ MakeScalar(t.storage_type(), static_cast<ValueRef>(value_)));
+ out_ = std::make_shared<ExtensionScalar>(std::move(storage), type_);
+ return Status::OK();
+ }
+
+ // Enable constructing string/binary scalars (but not decimal, etc) from std::string
+ template <typename T>
+ enable_if_t<
+ std::is_same<typename std::remove_reference<ValueRef>::type, std::string>::value &&
+ (is_base_binary_type<T>::value || std::is_same<T, FixedSizeBinaryType>::value),
+ Status>
+ Visit(const T& t) {
+ using ScalarType = typename TypeTraits<T>::ScalarType;
+ out_ = std::make_shared<ScalarType>(Buffer::FromString(std::move(value_)),
+ std::move(type_));
+ return Status::OK();
+ }
+
+ Status Visit(const DataType& t) {
+ return Status::NotImplemented("constructing scalars of type ", t,
+ " from unboxed values");
+ }
+
+ Result<std::shared_ptr<Scalar>> Finish() && {
+ ARROW_RETURN_NOT_OK(VisitTypeInline(*type_, this));
+ return std::move(out_);
+ }
+
+ std::shared_ptr<DataType> type_;
+ ValueRef value_;
+ std::shared_ptr<Scalar> out_;
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/scalar_test.cc b/src/arrow/cpp/src/arrow/scalar_test.cc
new file mode 100644
index 000000000..99bcaec09
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/scalar_test.cc
@@ -0,0 +1,1629 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <limits>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/buffer.h"
+#include "arrow/memory_pool.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/testing/extension_type.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+std::shared_ptr<Scalar> CheckMakeNullScalar(const std::shared_ptr<DataType>& type) {
+ const auto scalar = MakeNullScalar(type);
+ ARROW_EXPECT_OK(scalar->Validate());
+ ARROW_EXPECT_OK(scalar->ValidateFull());
+ AssertTypeEqual(*type, *scalar->type);
+ EXPECT_FALSE(scalar->is_valid);
+ return scalar;
+}
+
+template <typename... MakeScalarArgs>
+void AssertMakeScalar(const Scalar& expected, MakeScalarArgs&&... args) {
+ ASSERT_OK_AND_ASSIGN(auto scalar, MakeScalar(std::forward<MakeScalarArgs>(args)...));
+ ASSERT_OK(scalar->Validate());
+ ASSERT_OK(scalar->ValidateFull());
+ AssertScalarsEqual(expected, *scalar, /*verbose=*/true);
+}
+
+void AssertParseScalar(const std::shared_ptr<DataType>& type, const util::string_view& s,
+ const Scalar& expected) {
+ ASSERT_OK_AND_ASSIGN(auto scalar, Scalar::Parse(type, s));
+ ASSERT_OK(scalar->Validate());
+ ASSERT_OK(scalar->ValidateFull());
+ AssertScalarsEqual(expected, *scalar, /*verbose=*/true);
+}
+
+void AssertValidationFails(const Scalar& scalar) {
+ ASSERT_RAISES(Invalid, scalar.Validate());
+ ASSERT_RAISES(Invalid, scalar.ValidateFull());
+}
+
+TEST(TestNullScalar, Basics) {
+ NullScalar scalar;
+ ASSERT_FALSE(scalar.is_valid);
+ ASSERT_TRUE(scalar.type->Equals(*null()));
+ ASSERT_OK(scalar.ValidateFull());
+
+ ASSERT_OK_AND_ASSIGN(auto arr, MakeArrayOfNull(null(), 1));
+ ASSERT_OK_AND_ASSIGN(auto first, arr->GetScalar(0));
+ ASSERT_TRUE(first->Equals(scalar));
+ ASSERT_OK(first->ValidateFull());
+}
+
+TEST(TestNullScalar, ValidateErrors) {
+ NullScalar scalar;
+ scalar.is_valid = true;
+ AssertValidationFails(scalar);
+}
+
+template <typename T>
+class TestNumericScalar : public ::testing::Test {
+ public:
+ TestNumericScalar() = default;
+};
+
+TYPED_TEST_SUITE(TestNumericScalar, NumericArrowTypes);
+
+TYPED_TEST(TestNumericScalar, Basics) {
+ using T = typename TypeParam::c_type;
+ using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
+
+ T value = static_cast<T>(1);
+
+ auto scalar_val = std::make_shared<ScalarType>(value);
+ ASSERT_EQ(value, scalar_val->value);
+ ASSERT_TRUE(scalar_val->is_valid);
+ ASSERT_OK(scalar_val->ValidateFull());
+
+ auto expected_type = TypeTraits<TypeParam>::type_singleton();
+ ASSERT_TRUE(scalar_val->type->Equals(*expected_type));
+
+ T other_value = static_cast<T>(2);
+ auto scalar_other = std::make_shared<ScalarType>(other_value);
+ ASSERT_NE(*scalar_other, *scalar_val);
+
+ scalar_val->value = other_value;
+ ASSERT_EQ(other_value, scalar_val->value);
+ ASSERT_EQ(*scalar_other, *scalar_val);
+
+ ScalarType stack_val;
+ ASSERT_FALSE(stack_val.is_valid);
+ ASSERT_OK(stack_val.ValidateFull());
+
+ auto null_value = std::make_shared<ScalarType>();
+ ASSERT_FALSE(null_value->is_valid);
+ ASSERT_OK(null_value->ValidateFull());
+
+ // Nulls should be equals to itself following Array::Equals
+ ASSERT_EQ(*null_value, stack_val);
+
+ auto dyn_null_value = CheckMakeNullScalar(expected_type);
+ ASSERT_EQ(*null_value, *dyn_null_value);
+
+ // test Array.GetScalar
+ auto arr = ArrayFromJSON(expected_type, "[null, 1, 2]");
+ ASSERT_OK_AND_ASSIGN(auto null, arr->GetScalar(0));
+ ASSERT_OK_AND_ASSIGN(auto one, arr->GetScalar(1));
+ ASSERT_OK_AND_ASSIGN(auto two, arr->GetScalar(2));
+ ASSERT_OK(null->ValidateFull());
+ ASSERT_OK(one->ValidateFull());
+ ASSERT_OK(two->ValidateFull());
+
+ ASSERT_TRUE(null->Equals(*null_value));
+ ASSERT_TRUE(one->Equals(ScalarType(1)));
+ ASSERT_FALSE(one->Equals(ScalarType(2)));
+ ASSERT_TRUE(two->Equals(ScalarType(2)));
+ ASSERT_FALSE(two->Equals(ScalarType(3)));
+
+ ASSERT_TRUE(null->ApproxEquals(*null_value));
+ ASSERT_TRUE(one->ApproxEquals(ScalarType(1)));
+ ASSERT_FALSE(one->ApproxEquals(ScalarType(2)));
+ ASSERT_TRUE(two->ApproxEquals(ScalarType(2)));
+ ASSERT_FALSE(two->ApproxEquals(ScalarType(3)));
+}
+
+TYPED_TEST(TestNumericScalar, Hashing) {
+ using T = typename TypeParam::c_type;
+ using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
+
+ std::unordered_set<std::shared_ptr<Scalar>, Scalar::Hash, Scalar::PtrsEqual> set;
+ set.emplace(std::make_shared<ScalarType>());
+ for (T i = 0; i < 10; ++i) {
+ set.emplace(std::make_shared<ScalarType>(i));
+ }
+
+ ASSERT_FALSE(set.emplace(std::make_shared<ScalarType>()).second);
+ for (T i = 0; i < 10; ++i) {
+ ASSERT_FALSE(set.emplace(std::make_shared<ScalarType>(i)).second);
+ }
+}
+
+TYPED_TEST(TestNumericScalar, MakeScalar) {
+ using T = typename TypeParam::c_type;
+ using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
+ auto type = TypeTraits<TypeParam>::type_singleton();
+
+ std::shared_ptr<Scalar> three = MakeScalar(static_cast<T>(3));
+ ASSERT_OK(three->ValidateFull());
+ ASSERT_EQ(ScalarType(3), *three);
+
+ AssertMakeScalar(ScalarType(3), type, static_cast<T>(3));
+
+ AssertParseScalar(type, "3", ScalarType(3));
+}
+
+template <typename T>
+class TestRealScalar : public ::testing::Test {
+ public:
+ using CType = typename T::c_type;
+ using ScalarType = typename TypeTraits<T>::ScalarType;
+
+ void SetUp() {
+ type_ = TypeTraits<T>::type_singleton();
+
+ scalar_val_ = std::make_shared<ScalarType>(static_cast<CType>(1));
+ ASSERT_TRUE(scalar_val_->is_valid);
+
+ scalar_other_ = std::make_shared<ScalarType>(static_cast<CType>(1.1));
+ ASSERT_TRUE(scalar_other_->is_valid);
+
+ const CType nan_value = std::numeric_limits<CType>::quiet_NaN();
+ scalar_nan_ = std::make_shared<ScalarType>(nan_value);
+ ASSERT_TRUE(scalar_nan_->is_valid);
+
+ const CType other_nan_value = std::numeric_limits<CType>::quiet_NaN();
+ scalar_other_nan_ = std::make_shared<ScalarType>(other_nan_value);
+ ASSERT_TRUE(scalar_other_nan_->is_valid);
+ }
+
+ void TestNanEquals() {
+ EqualOptions options = EqualOptions::Defaults();
+ ASSERT_FALSE(scalar_nan_->Equals(*scalar_val_, options));
+ ASSERT_FALSE(scalar_nan_->Equals(*scalar_nan_, options));
+ ASSERT_FALSE(scalar_nan_->Equals(*scalar_other_nan_, options));
+
+ options = options.nans_equal(true);
+ ASSERT_FALSE(scalar_nan_->Equals(*scalar_val_, options));
+ ASSERT_TRUE(scalar_nan_->Equals(*scalar_nan_, options));
+ ASSERT_TRUE(scalar_nan_->Equals(*scalar_other_nan_, options));
+ }
+
+ void TestApproxEquals() {
+ // The scalars are unequal with the small delta
+ EqualOptions options = EqualOptions::Defaults().atol(0.05);
+ ASSERT_FALSE(scalar_val_->ApproxEquals(*scalar_other_, options));
+ ASSERT_FALSE(scalar_other_->ApproxEquals(*scalar_val_, options));
+ ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options));
+ ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options));
+
+ // After enlarging the delta, they become equal
+ options = options.atol(0.15);
+ ASSERT_TRUE(scalar_val_->ApproxEquals(*scalar_other_, options));
+ ASSERT_TRUE(scalar_other_->ApproxEquals(*scalar_val_, options));
+ ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options));
+ ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options));
+
+ options = options.nans_equal(true);
+ ASSERT_TRUE(scalar_val_->ApproxEquals(*scalar_other_, options));
+ ASSERT_TRUE(scalar_other_->ApproxEquals(*scalar_val_, options));
+ ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options));
+ ASSERT_TRUE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options));
+
+ options = options.atol(0.05);
+ ASSERT_FALSE(scalar_val_->ApproxEquals(*scalar_other_, options));
+ ASSERT_FALSE(scalar_other_->ApproxEquals(*scalar_val_, options));
+ ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options));
+ ASSERT_TRUE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options));
+ }
+
+ void TestStructOf() {
+ auto ty = struct_({field("float", type_)});
+
+ StructScalar struct_val({scalar_val_}, ty);
+ StructScalar struct_other_val({scalar_other_}, ty);
+ StructScalar struct_nan({scalar_nan_}, ty);
+ StructScalar struct_other_nan({scalar_other_nan_}, ty);
+
+ EqualOptions options = EqualOptions::Defaults().atol(0.05);
+ ASSERT_FALSE(struct_val.Equals(struct_other_val, options));
+ ASSERT_FALSE(struct_other_val.Equals(struct_val, options));
+ ASSERT_FALSE(struct_nan.Equals(struct_val, options));
+ ASSERT_FALSE(struct_nan.Equals(struct_nan, options));
+ ASSERT_FALSE(struct_nan.Equals(struct_other_nan, options));
+ ASSERT_FALSE(struct_val.ApproxEquals(struct_other_val, options));
+ ASSERT_FALSE(struct_other_val.ApproxEquals(struct_val, options));
+ ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options));
+ ASSERT_FALSE(struct_nan.ApproxEquals(struct_nan, options));
+ ASSERT_FALSE(struct_nan.ApproxEquals(struct_other_nan, options));
+
+ options = options.atol(0.15);
+ ASSERT_FALSE(struct_val.Equals(struct_other_val, options));
+ ASSERT_FALSE(struct_other_val.Equals(struct_val, options));
+ ASSERT_FALSE(struct_nan.Equals(struct_val, options));
+ ASSERT_FALSE(struct_nan.Equals(struct_nan, options));
+ ASSERT_FALSE(struct_nan.Equals(struct_other_nan, options));
+ ASSERT_TRUE(struct_val.ApproxEquals(struct_other_val, options));
+ ASSERT_TRUE(struct_other_val.ApproxEquals(struct_val, options));
+ ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options));
+ ASSERT_FALSE(struct_nan.ApproxEquals(struct_nan, options));
+ ASSERT_FALSE(struct_nan.ApproxEquals(struct_other_nan, options));
+
+ options = options.nans_equal(true);
+ ASSERT_FALSE(struct_val.Equals(struct_other_val, options));
+ ASSERT_FALSE(struct_other_val.Equals(struct_val, options));
+ ASSERT_FALSE(struct_nan.Equals(struct_val, options));
+ ASSERT_TRUE(struct_nan.Equals(struct_nan, options));
+ ASSERT_TRUE(struct_nan.Equals(struct_other_nan, options));
+ ASSERT_TRUE(struct_val.ApproxEquals(struct_other_val, options));
+ ASSERT_TRUE(struct_other_val.ApproxEquals(struct_val, options));
+ ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options));
+ ASSERT_TRUE(struct_nan.ApproxEquals(struct_nan, options));
+ ASSERT_TRUE(struct_nan.ApproxEquals(struct_other_nan, options));
+
+ options = options.atol(0.05);
+ ASSERT_FALSE(struct_val.Equals(struct_other_val, options));
+ ASSERT_FALSE(struct_other_val.Equals(struct_val, options));
+ ASSERT_FALSE(struct_nan.Equals(struct_val, options));
+ ASSERT_TRUE(struct_nan.Equals(struct_nan, options));
+ ASSERT_TRUE(struct_nan.Equals(struct_other_nan, options));
+ ASSERT_FALSE(struct_val.ApproxEquals(struct_other_val, options));
+ ASSERT_FALSE(struct_other_val.ApproxEquals(struct_val, options));
+ ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options));
+ ASSERT_TRUE(struct_nan.ApproxEquals(struct_nan, options));
+ ASSERT_TRUE(struct_nan.ApproxEquals(struct_other_nan, options));
+ }
+
+ void TestListOf() {
+ auto ty = list(type_);
+
+ ListScalar list_val(ArrayFromJSON(type_, "[0, null, 1.0]"), ty);
+ ListScalar list_other_val(ArrayFromJSON(type_, "[0, null, 1.1]"), ty);
+ ListScalar list_nan(ArrayFromJSON(type_, "[0, null, NaN]"), ty);
+ ListScalar list_other_nan(ArrayFromJSON(type_, "[0, null, NaN]"), ty);
+
+ EqualOptions options = EqualOptions::Defaults().atol(0.05);
+ ASSERT_TRUE(list_val.Equals(list_val, options));
+ ASSERT_FALSE(list_val.Equals(list_other_val, options));
+ ASSERT_FALSE(list_nan.Equals(list_val, options));
+ ASSERT_FALSE(list_nan.Equals(list_nan, options));
+ ASSERT_FALSE(list_nan.Equals(list_other_nan, options));
+ ASSERT_TRUE(list_val.ApproxEquals(list_val, options));
+ ASSERT_FALSE(list_val.ApproxEquals(list_other_val, options));
+ ASSERT_FALSE(list_nan.ApproxEquals(list_val, options));
+ ASSERT_FALSE(list_nan.ApproxEquals(list_nan, options));
+ ASSERT_FALSE(list_nan.ApproxEquals(list_other_nan, options));
+
+ options = options.atol(0.15);
+ ASSERT_TRUE(list_val.Equals(list_val, options));
+ ASSERT_FALSE(list_val.Equals(list_other_val, options));
+ ASSERT_FALSE(list_nan.Equals(list_val, options));
+ ASSERT_FALSE(list_nan.Equals(list_nan, options));
+ ASSERT_FALSE(list_nan.Equals(list_other_nan, options));
+ ASSERT_TRUE(list_val.ApproxEquals(list_val, options));
+ ASSERT_TRUE(list_val.ApproxEquals(list_other_val, options));
+ ASSERT_FALSE(list_nan.ApproxEquals(list_val, options));
+ ASSERT_FALSE(list_nan.ApproxEquals(list_nan, options));
+ ASSERT_FALSE(list_nan.ApproxEquals(list_other_nan, options));
+
+ options = options.nans_equal(true);
+ ASSERT_TRUE(list_val.Equals(list_val, options));
+ ASSERT_FALSE(list_val.Equals(list_other_val, options));
+ ASSERT_FALSE(list_nan.Equals(list_val, options));
+ ASSERT_TRUE(list_nan.Equals(list_nan, options));
+ ASSERT_TRUE(list_nan.Equals(list_other_nan, options));
+ ASSERT_TRUE(list_val.ApproxEquals(list_val, options));
+ ASSERT_TRUE(list_val.ApproxEquals(list_other_val, options));
+ ASSERT_FALSE(list_nan.ApproxEquals(list_val, options));
+ ASSERT_TRUE(list_nan.ApproxEquals(list_nan, options));
+ ASSERT_TRUE(list_nan.ApproxEquals(list_other_nan, options));
+
+ options = options.atol(0.05);
+ ASSERT_TRUE(list_val.Equals(list_val, options));
+ ASSERT_FALSE(list_val.Equals(list_other_val, options));
+ ASSERT_FALSE(list_nan.Equals(list_val, options));
+ ASSERT_TRUE(list_nan.Equals(list_nan, options));
+ ASSERT_TRUE(list_nan.Equals(list_other_nan, options));
+ ASSERT_TRUE(list_val.ApproxEquals(list_val, options));
+ ASSERT_FALSE(list_val.ApproxEquals(list_other_val, options));
+ ASSERT_FALSE(list_nan.ApproxEquals(list_val, options));
+ ASSERT_TRUE(list_nan.ApproxEquals(list_nan, options));
+ ASSERT_TRUE(list_nan.ApproxEquals(list_other_nan, options));
+ }
+
+ protected:
+ std::shared_ptr<DataType> type_;
+ std::shared_ptr<Scalar> scalar_val_, scalar_other_, scalar_nan_, scalar_other_nan_;
+};
+
+TYPED_TEST_SUITE(TestRealScalar, RealArrowTypes);
+
+TYPED_TEST(TestRealScalar, NanEquals) { this->TestNanEquals(); }
+
+TYPED_TEST(TestRealScalar, ApproxEquals) { this->TestApproxEquals(); }
+
+TYPED_TEST(TestRealScalar, StructOf) { this->TestStructOf(); }
+
+TYPED_TEST(TestRealScalar, ListOf) { this->TestListOf(); }
+
+template <typename T>
+class TestDecimalScalar : public ::testing::Test {
+ public:
+ using ScalarType = typename TypeTraits<T>::ScalarType;
+ using ValueType = typename ScalarType::ValueType;
+
+ void TestBasics() {
+ const auto ty = std::make_shared<T>(3, 2);
+ const auto pi = ScalarType(ValueType(314), ty);
+ const auto pi2 = ScalarType(ValueType(628), ty);
+ const auto null = CheckMakeNullScalar(ty);
+
+ ASSERT_OK(pi.ValidateFull());
+ ASSERT_TRUE(pi.is_valid);
+ ASSERT_EQ(pi.value, ValueType("3.14"));
+
+ ASSERT_OK(null->ValidateFull());
+ ASSERT_FALSE(null->is_valid);
+
+ ASSERT_FALSE(pi.Equals(pi2));
+
+ // Test Array::GetScalar
+ auto arr = ArrayFromJSON(ty, "[null, \"3.14\"]");
+ ASSERT_OK_AND_ASSIGN(auto first, arr->GetScalar(0));
+ ASSERT_OK_AND_ASSIGN(auto second, arr->GetScalar(1));
+ ASSERT_OK(first->ValidateFull());
+ ASSERT_OK(second->ValidateFull());
+
+ ASSERT_TRUE(first->Equals(null));
+ ASSERT_FALSE(first->Equals(pi));
+ ASSERT_TRUE(second->Equals(pi));
+ ASSERT_FALSE(second->Equals(null));
+ }
+};
+
+TYPED_TEST_SUITE(TestDecimalScalar, DecimalArrowTypes);
+
+TYPED_TEST(TestDecimalScalar, Basics) { this->TestBasics(); }
+
+TEST(TestBinaryScalar, Basics) {
+ std::string data = "test data";
+ auto buf = std::make_shared<Buffer>(data);
+
+ BinaryScalar value(buf);
+ ASSERT_OK(value.ValidateFull());
+ ASSERT_TRUE(value.value->Equals(*buf));
+ ASSERT_TRUE(value.is_valid);
+ ASSERT_TRUE(value.type->Equals(*binary()));
+
+ auto ref_count = buf.use_count();
+ // Check that destructor doesn't fail to clean up a buffer
+ std::shared_ptr<Scalar> base_ref = std::make_shared<BinaryScalar>(buf);
+ base_ref = nullptr;
+ ASSERT_EQ(ref_count, buf.use_count());
+
+ BinaryScalar null_value;
+ ASSERT_FALSE(null_value.is_valid);
+ ASSERT_EQ(null_value.value, nullptr);
+ ASSERT_OK(null_value.ValidateFull());
+
+ StringScalar value2(buf);
+ ASSERT_OK(value2.ValidateFull());
+ ASSERT_TRUE(value2.value->Equals(*buf));
+ ASSERT_TRUE(value2.is_valid);
+ ASSERT_TRUE(value2.type->Equals(*utf8()));
+
+ // Same buffer, different type.
+ ASSERT_NE(value2, value);
+
+ StringScalar value3(buf);
+ // Same buffer, same type.
+ ASSERT_EQ(value2, value3);
+
+ StringScalar null_value2;
+ ASSERT_FALSE(null_value2.is_valid);
+
+ // test Array.GetScalar
+ auto arr = ArrayFromJSON(binary(), "[null, \"one\", \"two\"]");
+ ASSERT_OK_AND_ASSIGN(auto null, arr->GetScalar(0));
+ ASSERT_OK_AND_ASSIGN(auto one, arr->GetScalar(1));
+ ASSERT_OK_AND_ASSIGN(auto two, arr->GetScalar(2));
+ ASSERT_OK(null->ValidateFull());
+ ASSERT_OK(one->ValidateFull());
+ ASSERT_OK(two->ValidateFull());
+
+ ASSERT_TRUE(null->Equals(null_value));
+ ASSERT_TRUE(one->Equals(BinaryScalar(Buffer::FromString("one"))));
+ ASSERT_TRUE(two->Equals(BinaryScalar(Buffer::FromString("two"))));
+ ASSERT_FALSE(two->Equals(BinaryScalar(Buffer::FromString("else"))));
+}
+
+TEST(TestBinaryScalar, Hashing) {
+ auto FromInt = [](int i) {
+ return std::make_shared<BinaryScalar>(Buffer::FromString(std::to_string(i)));
+ };
+
+ std::unordered_set<std::shared_ptr<Scalar>, Scalar::Hash, Scalar::PtrsEqual> set;
+ set.emplace(std::make_shared<BinaryScalar>());
+ for (int i = 0; i < 10; ++i) {
+ set.emplace(FromInt(i));
+ }
+
+ ASSERT_FALSE(set.emplace(std::make_shared<BinaryScalar>()).second);
+ for (int i = 0; i < 10; ++i) {
+ ASSERT_FALSE(set.emplace(FromInt(i)).second);
+ }
+}
+
+TEST(TestBinaryScalar, ValidateErrors) {
+ // Inconsistent is_valid / value
+ BinaryScalar scalar(Buffer::FromString("xxx"));
+ scalar.is_valid = false;
+ AssertValidationFails(scalar);
+
+ auto null_scalar = MakeNullScalar(binary());
+ null_scalar->is_valid = true;
+ AssertValidationFails(*null_scalar);
+}
+
+template <typename T>
+class TestStringScalar : public ::testing::Test {
+ public:
+ using ScalarType = typename TypeTraits<T>::ScalarType;
+
+ void SetUp() { type_ = TypeTraits<T>::type_singleton(); }
+
+ void TestMakeScalar() {
+ AssertMakeScalar(ScalarType("three"), type_, Buffer::FromString("three"));
+
+ AssertParseScalar(type_, "three", ScalarType("three"));
+ }
+
+ void TestArrayGetScalar() {
+ auto arr = ArrayFromJSON(type_, R"([null, "one", "two"])");
+ ASSERT_OK_AND_ASSIGN(auto null, arr->GetScalar(0));
+ ASSERT_OK_AND_ASSIGN(auto one, arr->GetScalar(1));
+ ASSERT_OK_AND_ASSIGN(auto two, arr->GetScalar(2));
+ ASSERT_OK(null->ValidateFull());
+ ASSERT_OK(one->ValidateFull());
+ ASSERT_OK(two->ValidateFull());
+
+ ASSERT_TRUE(null->Equals(CheckMakeNullScalar(type_)));
+ ASSERT_TRUE(one->Equals(ScalarType("one")));
+ ASSERT_TRUE(two->Equals(ScalarType("two")));
+ ASSERT_FALSE(two->Equals(Int64Scalar(1)));
+ }
+
+ void TestValidateErrors() {
+ // Inconsistent is_valid / value
+ ScalarType scalar(Buffer::FromString("xxx"));
+ scalar.is_valid = false;
+ AssertValidationFails(scalar);
+
+ auto null_scalar = MakeNullScalar(type_);
+ null_scalar->is_valid = true;
+ AssertValidationFails(*null_scalar);
+
+ // Invalid UTF8
+ scalar = ScalarType(Buffer::FromString("\xff"));
+ ASSERT_OK(scalar.Validate());
+ ASSERT_RAISES(Invalid, scalar.ValidateFull());
+ }
+
+ protected:
+ std::shared_ptr<DataType> type_;
+};
+
+TYPED_TEST_SUITE(TestStringScalar, StringArrowTypes);
+
+TYPED_TEST(TestStringScalar, MakeScalar) { this->TestMakeScalar(); }
+
+TYPED_TEST(TestStringScalar, ArrayGetScalar) { this->TestArrayGetScalar(); }
+
+TYPED_TEST(TestStringScalar, ValidateErrors) { this->TestValidateErrors(); }
+
+TEST(TestStringScalar, MakeScalarImplicit) {
+ // MakeScalar("string literal") creates a StringScalar
+ auto three = MakeScalar("three");
+ ASSERT_OK(three->ValidateFull());
+ ASSERT_EQ(StringScalar("three"), *three);
+}
+
+TEST(TestStringScalar, MakeScalarString) {
+ // MakeScalar(std::string) creates a StringScalar via FromBuffer
+ std::string buf = "three";
+ auto three = MakeScalar(std::move(buf));
+ ASSERT_OK(three->ValidateFull());
+ ASSERT_EQ(StringScalar("three"), *three);
+}
+
+TEST(TestFixedSizeBinaryScalar, Basics) {
+ std::string data = "test data";
+ auto buf = std::make_shared<Buffer>(data);
+
+ auto ex_type = fixed_size_binary(9);
+
+ FixedSizeBinaryScalar value(buf, ex_type);
+ ASSERT_OK(value.ValidateFull());
+ ASSERT_TRUE(value.value->Equals(*buf));
+ ASSERT_TRUE(value.is_valid);
+ ASSERT_TRUE(value.type->Equals(*ex_type));
+
+ FixedSizeBinaryScalar null_value(ex_type);
+ ASSERT_OK(null_value.ValidateFull());
+ ASSERT_FALSE(null_value.is_valid);
+ ASSERT_EQ(null_value.value, nullptr);
+
+ // test Array.GetScalar
+ auto ty = fixed_size_binary(3);
+ auto arr = ArrayFromJSON(ty, R"([null, "one", "two"])");
+ ASSERT_OK_AND_ASSIGN(auto null, arr->GetScalar(0));
+ ASSERT_OK_AND_ASSIGN(auto one, arr->GetScalar(1));
+ ASSERT_OK_AND_ASSIGN(auto two, arr->GetScalar(2));
+ ASSERT_OK(null->ValidateFull());
+ ASSERT_OK(one->ValidateFull());
+ ASSERT_OK(two->ValidateFull());
+
+ ASSERT_TRUE(null->Equals(CheckMakeNullScalar(ty)));
+ ASSERT_TRUE(one->Equals(FixedSizeBinaryScalar(Buffer::FromString("one"), ty)));
+ ASSERT_TRUE(two->Equals(FixedSizeBinaryScalar(Buffer::FromString("two"), ty)));
+}
+
+TEST(TestFixedSizeBinaryScalar, MakeScalar) {
+ std::string data = "test data";
+ auto buf = std::make_shared<Buffer>(data);
+ auto type = fixed_size_binary(9);
+
+ AssertMakeScalar(FixedSizeBinaryScalar(buf, type), type, buf);
+
+ AssertParseScalar(type, util::string_view(data), FixedSizeBinaryScalar(buf, type));
+
+ // Wrong length
+ ASSERT_RAISES(Invalid, MakeScalar(type, Buffer::FromString(data.substr(3))).status());
+ ASSERT_RAISES(Invalid, Scalar::Parse(type, util::string_view(data).substr(3)).status());
+}
+
+TEST(TestFixedSizeBinaryScalar, ValidateErrors) {
+ std::string data = "test data";
+ auto buf = std::make_shared<Buffer>(data);
+ auto type = fixed_size_binary(9);
+
+ FixedSizeBinaryScalar scalar(buf, type);
+ ASSERT_OK(scalar.ValidateFull());
+
+ scalar.value = SliceBuffer(buf, 1);
+ AssertValidationFails(scalar);
+}
+
+TEST(TestDateScalars, Basics) {
+ int32_t i32_val = 1;
+ Date32Scalar date32_val(i32_val);
+ Date32Scalar date32_null;
+ ASSERT_OK(date32_val.ValidateFull());
+ ASSERT_OK(date32_null.ValidateFull());
+
+ ASSERT_TRUE(date32_val.type->Equals(*date32()));
+ ASSERT_TRUE(date32_val.is_valid);
+ ASSERT_FALSE(date32_null.is_valid);
+
+ int64_t i64_val = 2;
+ Date64Scalar date64_val(i64_val);
+ Date64Scalar date64_null;
+ ASSERT_OK(date64_val.ValidateFull());
+ ASSERT_OK(date64_null.ValidateFull());
+
+ ASSERT_EQ(i64_val, date64_val.value);
+ ASSERT_TRUE(date64_val.type->Equals(*date64()));
+ ASSERT_TRUE(date64_val.is_valid);
+ ASSERT_FALSE(date64_null.is_valid);
+
+ // test Array.GetScalar
+ for (auto ty : {date32(), date64()}) {
+ auto arr = ArrayFromJSON(ty, "[5, null, 42]");
+ ASSERT_OK_AND_ASSIGN(auto first, arr->GetScalar(0));
+ ASSERT_OK_AND_ASSIGN(auto null, arr->GetScalar(1));
+ ASSERT_OK_AND_ASSIGN(auto last, arr->GetScalar(2));
+ ASSERT_OK(first->ValidateFull());
+ ASSERT_OK(null->ValidateFull());
+ ASSERT_OK(last->ValidateFull());
+
+ ASSERT_TRUE(null->Equals(CheckMakeNullScalar(ty)));
+ ASSERT_TRUE(first->Equals(MakeScalar(ty, 5).ValueOrDie()));
+ ASSERT_TRUE(last->Equals(MakeScalar(ty, 42).ValueOrDie()));
+ ASSERT_FALSE(last->Equals(MakeScalar("string")));
+ }
+}
+
+TEST(TestDateScalars, MakeScalar) {
+ AssertMakeScalar(Date32Scalar(1), date32(), int32_t(1));
+ AssertParseScalar(date32(), "1454-10-22", Date32Scalar(-188171));
+
+ AssertMakeScalar(Date64Scalar(1), date64(), int64_t(1));
+ AssertParseScalar(date64(), "1454-10-22",
+ Date64Scalar(-188171LL * 24 * 60 * 60 * 1000));
+}
+
+TEST(TestTimeScalars, Basics) {
+ auto type1 = time32(TimeUnit::MILLI);
+ auto type2 = time32(TimeUnit::SECOND);
+ auto type3 = time64(TimeUnit::MICRO);
+ auto type4 = time64(TimeUnit::NANO);
+
+ int32_t i32_val = 1;
+ Time32Scalar time32_val(i32_val, type1);
+ Time32Scalar time32_null(type2);
+ ASSERT_OK(time32_val.ValidateFull());
+ ASSERT_OK(time32_null.ValidateFull());
+
+ ASSERT_EQ(i32_val, time32_val.value);
+ ASSERT_TRUE(time32_val.type->Equals(*type1));
+ ASSERT_TRUE(time32_val.is_valid);
+ ASSERT_FALSE(time32_null.is_valid);
+ ASSERT_TRUE(time32_null.type->Equals(*type2));
+
+ int64_t i64_val = 2;
+ Time64Scalar time64_val(i64_val, type3);
+ Time64Scalar time64_null(type4);
+ ASSERT_OK(time64_val.ValidateFull());
+ ASSERT_OK(time64_null.ValidateFull());
+
+ ASSERT_EQ(i64_val, time64_val.value);
+ ASSERT_TRUE(time64_val.type->Equals(*type3));
+ ASSERT_TRUE(time64_val.is_valid);
+ ASSERT_FALSE(time64_null.is_valid);
+ ASSERT_TRUE(time64_null.type->Equals(*type4));
+
+ // test Array.GetScalar
+ for (auto ty : {type1, type2, type3, type4}) {
+ auto arr = ArrayFromJSON(ty, "[5, null, 42]");
+ ASSERT_OK_AND_ASSIGN(auto first, arr->GetScalar(0));
+ ASSERT_OK_AND_ASSIGN(auto null, arr->GetScalar(1));
+ ASSERT_OK_AND_ASSIGN(auto last, arr->GetScalar(2));
+ ASSERT_OK(first->ValidateFull());
+ ASSERT_OK(null->ValidateFull());
+ ASSERT_OK(last->ValidateFull());
+
+ ASSERT_TRUE(null->Equals(CheckMakeNullScalar(ty)));
+ ASSERT_TRUE(first->Equals(MakeScalar(ty, 5).ValueOrDie()));
+ ASSERT_TRUE(last->Equals(MakeScalar(ty, 42).ValueOrDie()));
+ ASSERT_FALSE(last->Equals(MakeScalar("string")));
+ }
+}
+
+TEST(TestTimeScalars, MakeScalar) {
+ auto type1 = time32(TimeUnit::SECOND);
+ auto type2 = time32(TimeUnit::MILLI);
+ auto type3 = time64(TimeUnit::MICRO);
+ auto type4 = time64(TimeUnit::NANO);
+
+ AssertMakeScalar(Time32Scalar(1, type1), type1, int32_t(1));
+ AssertMakeScalar(Time32Scalar(1, type2), type2, int32_t(1));
+ AssertMakeScalar(Time64Scalar(1, type3), type3, int32_t(1));
+ AssertMakeScalar(Time64Scalar(1, type4), type4, int32_t(1));
+
+ int64_t tententen = 60 * (60 * (10) + 10) + 10;
+ AssertParseScalar(type1, "10:10:10",
+ Time32Scalar(static_cast<int32_t>(tententen), type1));
+
+ tententen = 1000 * tententen + 123;
+ AssertParseScalar(type2, "10:10:10.123",
+ Time32Scalar(static_cast<int32_t>(tententen), type2));
+
+ tententen = 1000 * tententen + 456;
+ AssertParseScalar(type3, "10:10:10.123456", Time64Scalar(tententen, type3));
+
+ tententen = 1000 * tententen + 789;
+ AssertParseScalar(type4, "10:10:10.123456789", Time64Scalar(tententen, type4));
+}
+
+TEST(TestTimestampScalars, Basics) {
+ auto type1 = timestamp(TimeUnit::MILLI);
+ auto type2 = timestamp(TimeUnit::SECOND);
+
+ int64_t val1 = 1;
+ int64_t val2 = 2;
+ TimestampScalar ts_val1(val1, type1);
+ TimestampScalar ts_val2(val2, type2);
+ TimestampScalar ts_null(type1);
+ ASSERT_OK(ts_val1.ValidateFull());
+ ASSERT_OK(ts_val2.ValidateFull());
+ ASSERT_OK(ts_null.ValidateFull());
+
+ ASSERT_EQ(val1, ts_val1.value);
+
+ ASSERT_TRUE(ts_val1.type->Equals(*type1));
+ ASSERT_TRUE(ts_val2.type->Equals(*type2));
+ ASSERT_TRUE(ts_val1.is_valid);
+ ASSERT_FALSE(ts_null.is_valid);
+ ASSERT_TRUE(ts_null.type->Equals(*type1));
+
+ ASSERT_NE(ts_val1, ts_val2);
+ ASSERT_NE(ts_val1, ts_null);
+ ASSERT_NE(ts_val2, ts_null);
+
+ // test Array.GetScalar
+ for (auto ty : {type1, type2}) {
+ auto arr = ArrayFromJSON(ty, "[5, null, 42]");
+ ASSERT_OK_AND_ASSIGN(auto first, arr->GetScalar(0));
+ ASSERT_OK_AND_ASSIGN(auto null, arr->GetScalar(1));
+ ASSERT_OK_AND_ASSIGN(auto last, arr->GetScalar(2));
+ ASSERT_OK(first->ValidateFull());
+ ASSERT_OK(null->ValidateFull());
+ ASSERT_OK(last->ValidateFull());
+
+ ASSERT_TRUE(null->Equals(CheckMakeNullScalar(ty)));
+ ASSERT_TRUE(first->Equals(MakeScalar(ty, 5).ValueOrDie()));
+ ASSERT_TRUE(last->Equals(MakeScalar(ty, 42).ValueOrDie()));
+ ASSERT_FALSE(last->Equals(MakeScalar(int64(), 42).ValueOrDie()));
+ }
+}
+
+TEST(TestTimestampScalars, MakeScalar) {
+ auto type1 = timestamp(TimeUnit::MILLI);
+ auto type2 = timestamp(TimeUnit::SECOND);
+ auto type3 = timestamp(TimeUnit::MICRO);
+ auto type4 = timestamp(TimeUnit::NANO);
+
+ util::string_view epoch_plus_1s = "1970-01-01 00:00:01";
+
+ AssertMakeScalar(TimestampScalar(1, type1), type1, int64_t(1));
+ AssertParseScalar(type1, epoch_plus_1s, TimestampScalar(1000, type1));
+
+ AssertMakeScalar(TimestampScalar(1, type2), type2, int64_t(1));
+ AssertParseScalar(type2, epoch_plus_1s, TimestampScalar(1, type2));
+
+ AssertMakeScalar(TimestampScalar(1, type3), type3, int64_t(1));
+ AssertParseScalar(type3, epoch_plus_1s, TimestampScalar(1000 * 1000, type3));
+
+ AssertMakeScalar(TimestampScalar(1, type4), type4, int64_t(1));
+ AssertParseScalar(type4, epoch_plus_1s, TimestampScalar(1000 * 1000 * 1000, type4));
+}
+
+TEST(TestTimestampScalars, Cast) {
+ auto convert = [](TimeUnit::type in, TimeUnit::type out, int64_t value) -> int64_t {
+ auto scalar =
+ TimestampScalar(value, timestamp(in)).CastTo(timestamp(out)).ValueOrDie();
+ return internal::checked_pointer_cast<TimestampScalar>(scalar)->value;
+ };
+
+ EXPECT_EQ(convert(TimeUnit::SECOND, TimeUnit::MILLI, 1), 1000);
+ EXPECT_EQ(convert(TimeUnit::SECOND, TimeUnit::NANO, 1), 1000000000);
+
+ EXPECT_EQ(convert(TimeUnit::NANO, TimeUnit::MICRO, 1234), 1);
+ EXPECT_EQ(convert(TimeUnit::MICRO, TimeUnit::MILLI, 4567), 4);
+
+ ASSERT_OK_AND_ASSIGN(auto str,
+ TimestampScalar(1024, timestamp(TimeUnit::MILLI)).CastTo(utf8()));
+ EXPECT_EQ(*str, StringScalar("1970-01-01 00:00:01.024"));
+ ASSERT_OK_AND_ASSIGN(auto i64,
+ TimestampScalar(1024, timestamp(TimeUnit::MILLI)).CastTo(int64()));
+ EXPECT_EQ(*i64, Int64Scalar(1024));
+
+ constexpr int64_t kMillisecondsInDay = 86400000;
+ ASSERT_OK_AND_ASSIGN(
+ auto d64, TimestampScalar(1024 * kMillisecondsInDay + 3, timestamp(TimeUnit::MILLI))
+ .CastTo(date64()));
+ EXPECT_EQ(*d64, Date64Scalar(1024 * kMillisecondsInDay));
+}
+
+TEST(TestDurationScalars, Basics) {
+ auto type1 = duration(TimeUnit::MILLI);
+ auto type2 = duration(TimeUnit::SECOND);
+
+ int64_t val1 = 1;
+ int64_t val2 = 2;
+ DurationScalar ts_val1(val1, type1);
+ DurationScalar ts_val2(val2, type2);
+ DurationScalar ts_null(type1);
+ ASSERT_OK(ts_val1.ValidateFull());
+ ASSERT_OK(ts_val2.ValidateFull());
+ ASSERT_OK(ts_null.ValidateFull());
+
+ ASSERT_EQ(val1, ts_val1.value);
+
+ ASSERT_TRUE(ts_val1.type->Equals(*type1));
+ ASSERT_TRUE(ts_val2.type->Equals(*type2));
+ ASSERT_TRUE(ts_val1.is_valid);
+ ASSERT_FALSE(ts_null.is_valid);
+ ASSERT_TRUE(ts_null.type->Equals(*type1));
+
+ ASSERT_NE(ts_val1, ts_val2);
+ ASSERT_NE(ts_val1, ts_null);
+ ASSERT_NE(ts_val2, ts_null);
+
+ // test Array.GetScalar
+ for (auto ty : {type1, type2}) {
+ auto arr = ArrayFromJSON(ty, "[5, null, 42]");
+ ASSERT_OK_AND_ASSIGN(auto first, arr->GetScalar(0));
+ ASSERT_OK_AND_ASSIGN(auto null, arr->GetScalar(1));
+ ASSERT_OK_AND_ASSIGN(auto last, arr->GetScalar(2));
+ ASSERT_OK(first->ValidateFull());
+ ASSERT_OK(null->ValidateFull());
+ ASSERT_OK(last->ValidateFull());
+
+ ASSERT_TRUE(null->Equals(CheckMakeNullScalar(ty)));
+ ASSERT_TRUE(first->Equals(MakeScalar(ty, 5).ValueOrDie()));
+ ASSERT_TRUE(last->Equals(MakeScalar(ty, 42).ValueOrDie()));
+ }
+}
+
+TEST(TestMonthIntervalScalars, Basics) {
+ auto type = month_interval();
+
+ int32_t val1 = 1;
+ int32_t val2 = 2;
+ MonthIntervalScalar ts_val1(val1);
+ MonthIntervalScalar ts_val2(val2);
+ MonthIntervalScalar ts_null;
+ ASSERT_OK(ts_val1.ValidateFull());
+ ASSERT_OK(ts_val2.ValidateFull());
+ ASSERT_OK(ts_null.ValidateFull());
+
+ ASSERT_EQ(val1, ts_val1.value);
+
+ ASSERT_TRUE(ts_val1.type->Equals(*type));
+ ASSERT_TRUE(ts_val2.type->Equals(*type));
+ ASSERT_TRUE(ts_val1.is_valid);
+ ASSERT_FALSE(ts_null.is_valid);
+ ASSERT_TRUE(ts_null.type->Equals(*type));
+
+ ASSERT_NE(ts_val1, ts_val2);
+ ASSERT_NE(ts_val1, ts_null);
+ ASSERT_NE(ts_val2, ts_null);
+
+ // test Array.GetScalar
+ auto arr = ArrayFromJSON(type, "[5, null, 42]");
+ ASSERT_OK_AND_ASSIGN(auto first, arr->GetScalar(0));
+ ASSERT_OK_AND_ASSIGN(auto null, arr->GetScalar(1));
+ ASSERT_OK_AND_ASSIGN(auto last, arr->GetScalar(2));
+ ASSERT_OK(first->ValidateFull());
+ ASSERT_OK(null->ValidateFull());
+ ASSERT_OK(last->ValidateFull());
+
+ ASSERT_TRUE(null->Equals(CheckMakeNullScalar(type)));
+ ASSERT_TRUE(first->Equals(MakeScalar(type, 5).ValueOrDie()));
+ ASSERT_TRUE(last->Equals(MakeScalar(type, 42).ValueOrDie()));
+}
+
+TEST(TestDayTimeIntervalScalars, Basics) {
+ auto type = day_time_interval();
+
+ DayTimeIntervalType::DayMilliseconds val1 = {1, 1};
+ DayTimeIntervalType::DayMilliseconds val2 = {2, 2};
+ DayTimeIntervalScalar ts_val1(val1);
+ DayTimeIntervalScalar ts_val2(val2);
+ DayTimeIntervalScalar ts_null;
+ ASSERT_OK(ts_val1.ValidateFull());
+ ASSERT_OK(ts_val2.ValidateFull());
+ ASSERT_OK(ts_null.ValidateFull());
+
+ ASSERT_EQ(val1, ts_val1.value);
+
+ ASSERT_TRUE(ts_val1.type->Equals(*type));
+ ASSERT_TRUE(ts_val2.type->Equals(*type));
+ ASSERT_TRUE(ts_val1.is_valid);
+ ASSERT_FALSE(ts_null.is_valid);
+ ASSERT_TRUE(ts_null.type->Equals(*type));
+
+ ASSERT_NE(ts_val1, ts_val2);
+ ASSERT_NE(ts_val1, ts_null);
+ ASSERT_NE(ts_val2, ts_null);
+
+ // test Array.GetScalar
+ auto arr = ArrayFromJSON(type, "[[2, 2], null]");
+ ASSERT_OK_AND_ASSIGN(auto first, arr->GetScalar(0));
+ ASSERT_OK_AND_ASSIGN(auto null, arr->GetScalar(1));
+ ASSERT_OK(first->ValidateFull());
+ ASSERT_OK(null->ValidateFull());
+
+ ASSERT_TRUE(null->Equals(ts_null));
+ ASSERT_TRUE(first->Equals(ts_val2));
+}
+
+// TODO test HalfFloatScalar
+
+TYPED_TEST(TestNumericScalar, Cast) {
+ auto type = TypeTraits<TypeParam>::type_singleton();
+
+ for (util::string_view repr : {"0", "1", "3"}) {
+ std::shared_ptr<Scalar> scalar;
+ ASSERT_OK_AND_ASSIGN(scalar, Scalar::Parse(type, repr));
+
+ // cast to and from other numeric scalars
+ for (auto other_type : {float32(), int8(), int64(), uint32()}) {
+ std::shared_ptr<Scalar> other_scalar;
+ ASSERT_OK_AND_ASSIGN(other_scalar, Scalar::Parse(other_type, repr));
+
+ ASSERT_OK_AND_ASSIGN(auto cast_to_other, scalar->CastTo(other_type))
+ ASSERT_EQ(*cast_to_other, *other_scalar);
+
+ ASSERT_OK_AND_ASSIGN(auto cast_from_other, other_scalar->CastTo(type))
+ ASSERT_EQ(*cast_from_other, *scalar);
+ }
+
+ ASSERT_OK_AND_ASSIGN(auto cast_from_string,
+ StringScalar(std::string(repr)).CastTo(type));
+ ASSERT_EQ(*cast_from_string, *scalar);
+
+ if (is_integer_type<TypeParam>::value) {
+ ASSERT_OK_AND_ASSIGN(auto cast_to_string, scalar->CastTo(utf8()));
+ ASSERT_EQ(
+ util::string_view(*checked_cast<const StringScalar&>(*cast_to_string).value),
+ repr);
+ }
+ }
+}
+
+template <typename T>
+std::shared_ptr<DataType> MakeListType(std::shared_ptr<DataType> value_type,
+ int32_t list_size) {
+ return std::make_shared<T>(std::move(value_type));
+}
+
+template <>
+std::shared_ptr<DataType> MakeListType<FixedSizeListType>(
+ std::shared_ptr<DataType> value_type, int32_t list_size) {
+ return fixed_size_list(std::move(value_type), list_size);
+}
+
+template <typename T>
+class TestListScalar : public ::testing::Test {
+ public:
+ using ScalarType = typename TypeTraits<T>::ScalarType;
+
+ void SetUp() {
+ // type_ = std::make_shared<T>(int16());
+ type_ = MakeListType<T>(int16(), 3);
+ value_ = ArrayFromJSON(int16(), "[1, 2, null]");
+ }
+
+ void TestBasics() {
+ ScalarType scalar(value_);
+ ASSERT_OK(scalar.ValidateFull());
+ ASSERT_TRUE(scalar.is_valid);
+ AssertTypeEqual(scalar.type, type_);
+
+ auto null_scalar = CheckMakeNullScalar(type_);
+ ASSERT_OK(null_scalar->ValidateFull());
+ ASSERT_FALSE(null_scalar->is_valid);
+ AssertTypeEqual(null_scalar->type, type_);
+ }
+
+ void TestValidateErrors() {
+ // Inconsistent is_valid / value
+ ScalarType scalar(value_);
+ scalar.is_valid = false;
+ AssertValidationFails(scalar);
+
+ scalar = ScalarType(value_);
+ scalar.value = nullptr;
+ AssertValidationFails(scalar);
+
+ // Inconsistent child type
+ scalar = ScalarType(value_);
+ scalar.value = ArrayFromJSON(int32(), "[1, 2, null]");
+ AssertValidationFails(scalar);
+
+ // Invalid UTF8 in child data
+ scalar = ScalarType(ArrayFromJSON(utf8(), "[null, null, \"\xff\"]"));
+ ASSERT_OK(scalar.Validate());
+ ASSERT_RAISES(Invalid, scalar.ValidateFull());
+ }
+
+ protected:
+ std::shared_ptr<DataType> type_;
+ std::shared_ptr<Array> value_;
+};
+
+using ListScalarTestTypes = ::testing::Types<ListType, LargeListType, FixedSizeListType>;
+
+TYPED_TEST_SUITE(TestListScalar, ListScalarTestTypes);
+
+TYPED_TEST(TestListScalar, Basics) { this->TestBasics(); }
+
+TYPED_TEST(TestListScalar, ValidateErrors) { this->TestValidateErrors(); }
+
+TEST(TestFixedSizeListScalar, ValidateErrors) {
+ const auto ty = fixed_size_list(int16(), 3);
+ FixedSizeListScalar scalar(ArrayFromJSON(int16(), "[1, 2, 5]"), ty);
+ ASSERT_OK(scalar.ValidateFull());
+
+ scalar.type = fixed_size_list(int16(), 4);
+ AssertValidationFails(scalar);
+}
+
+TEST(TestMapScalar, Basics) {
+ auto value =
+ ArrayFromJSON(struct_({field("key", utf8(), false), field("value", int8())}),
+ R"([{"key": "a", "value": 1}, {"key": "b", "value": 2}])");
+ auto scalar = MapScalar(value);
+ ASSERT_OK(scalar.ValidateFull());
+
+ auto expected_scalar_type = map(utf8(), field("value", int8()));
+
+ ASSERT_TRUE(scalar.type->Equals(expected_scalar_type));
+ ASSERT_TRUE(value->Equals(scalar.value));
+}
+
+TEST(TestMapScalar, NullScalar) {
+ CheckMakeNullScalar(map(utf8(), field("value", int8())));
+}
+
+TEST(TestStructScalar, FieldAccess) {
+ StructScalar abc({MakeScalar(true), MakeNullScalar(int32()), MakeScalar("hello"),
+ MakeNullScalar(int64())},
+ struct_({field("a", boolean()), field("b", int32()),
+ field("b", utf8()), field("d", int64())}));
+ ASSERT_OK(abc.ValidateFull());
+
+ ASSERT_OK_AND_ASSIGN(auto a, abc.field("a"));
+ AssertScalarsEqual(*a, *abc.value[0]);
+
+ ASSERT_RAISES(Invalid, abc.field("b").status());
+
+ ASSERT_OK_AND_ASSIGN(auto b, abc.field(1));
+ AssertScalarsEqual(*b, *abc.value[1]);
+
+ ASSERT_RAISES(Invalid, abc.field(5).status());
+ ASSERT_RAISES(Invalid, abc.field("c").status());
+
+ ASSERT_OK_AND_ASSIGN(auto d, abc.field("d"));
+ ASSERT_TRUE(d->Equals(MakeNullScalar(int64())));
+ ASSERT_FALSE(d->Equals(MakeScalar(int64(), 12).ValueOrDie()));
+}
+
+TEST(TestStructScalar, NullScalar) {
+ auto ty = struct_({field("a", boolean()), field("b", int32()), field("b", utf8()),
+ field("d", int64())});
+
+ StructScalar null_scalar(ty);
+ ASSERT_OK(null_scalar.ValidateFull());
+ ASSERT_FALSE(null_scalar.is_valid);
+
+ const auto scalar = CheckMakeNullScalar(ty);
+ ASSERT_TRUE(scalar->Equals(null_scalar));
+}
+
+TEST(TestStructScalar, EmptyStruct) {
+ auto ty = struct_({});
+
+ StructScalar null_scalar(ty);
+ ASSERT_OK(null_scalar.ValidateFull());
+ ASSERT_FALSE(null_scalar.is_valid);
+
+ auto scalar = CheckMakeNullScalar(ty);
+ ASSERT_FALSE(scalar->is_valid);
+ ASSERT_TRUE(scalar->Equals(null_scalar));
+
+ StructScalar valid_scalar({}, ty);
+ ASSERT_OK(valid_scalar.ValidateFull());
+ ASSERT_TRUE(valid_scalar.is_valid);
+ ASSERT_FALSE(valid_scalar.Equals(null_scalar));
+
+ auto arr = ArrayFromJSON(ty, "[{}]");
+ ASSERT_TRUE(arr->IsValid(0));
+ ASSERT_OK_AND_ASSIGN(scalar, arr->GetScalar(0));
+ ASSERT_OK(scalar->ValidateFull());
+ ASSERT_TRUE(scalar->is_valid);
+ ASSERT_TRUE(scalar->Equals(valid_scalar));
+}
+
+TEST(TestStructScalar, ValidateErrors) {
+ auto ty = struct_({field("a", utf8())});
+
+ // Inconsistent is_valid / value
+ StructScalar scalar({MakeScalar("hello")}, ty);
+ scalar.is_valid = false;
+ AssertValidationFails(scalar);
+
+ scalar = StructScalar(ty);
+ scalar.is_valid = true;
+ AssertValidationFails(scalar);
+
+ // Inconsistent number of fields
+ scalar = StructScalar({}, ty);
+ AssertValidationFails(scalar);
+
+ scalar = StructScalar({MakeScalar("foo"), MakeScalar("bar")}, ty);
+ AssertValidationFails(scalar);
+
+ // Inconsistent child value type
+ scalar = StructScalar({MakeScalar(42)}, ty);
+ AssertValidationFails(scalar);
+
+ // Child value has invalid UTF8 data
+ scalar = StructScalar({MakeScalar("\xff")}, ty);
+ ASSERT_OK(scalar.Validate());
+ ASSERT_RAISES(Invalid, scalar.ValidateFull());
+}
+
+TEST(TestDictionaryScalar, Basics) {
+ for (auto index_ty : all_dictionary_index_types()) {
+ auto ty = dictionary(index_ty, utf8());
+ auto dict = ArrayFromJSON(utf8(), R"(["alpha", null, "gamma"])");
+
+ DictionaryScalar::ValueType alpha;
+ ASSERT_OK_AND_ASSIGN(alpha.index, MakeScalar(index_ty, 0));
+ alpha.dictionary = dict;
+
+ DictionaryScalar::ValueType gamma;
+ ASSERT_OK_AND_ASSIGN(gamma.index, MakeScalar(index_ty, 2));
+ gamma.dictionary = dict;
+
+ DictionaryScalar::ValueType null_value;
+ ASSERT_OK_AND_ASSIGN(null_value.index, MakeScalar(index_ty, 1));
+ null_value.dictionary = dict;
+
+ auto scalar_null = MakeNullScalar(ty);
+ checked_cast<DictionaryScalar&>(*scalar_null).value.dictionary = dict;
+ ASSERT_OK(scalar_null->ValidateFull());
+
+ auto scalar_alpha = DictionaryScalar(alpha, ty);
+ ASSERT_OK(scalar_alpha.ValidateFull());
+ auto scalar_gamma = DictionaryScalar(gamma, ty);
+ ASSERT_OK(scalar_gamma.ValidateFull());
+
+ // NOTE: index is valid, though corresponding value is null
+ auto scalar_null_value = DictionaryScalar(null_value, ty);
+ ASSERT_OK(scalar_null_value.ValidateFull());
+
+ ASSERT_OK_AND_ASSIGN(
+ auto encoded_null,
+ checked_cast<const DictionaryScalar&>(*scalar_null).GetEncodedValue());
+ ASSERT_OK(encoded_null->ValidateFull());
+ ASSERT_TRUE(encoded_null->Equals(MakeNullScalar(utf8())));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto encoded_null_value,
+ checked_cast<const DictionaryScalar&>(scalar_null_value).GetEncodedValue());
+ ASSERT_OK(encoded_null_value->ValidateFull());
+ ASSERT_TRUE(encoded_null_value->Equals(MakeNullScalar(utf8())));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto encoded_alpha,
+ checked_cast<const DictionaryScalar&>(scalar_alpha).GetEncodedValue());
+ ASSERT_OK(encoded_alpha->ValidateFull());
+ ASSERT_TRUE(encoded_alpha->Equals(MakeScalar("alpha")));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto encoded_gamma,
+ checked_cast<const DictionaryScalar&>(scalar_gamma).GetEncodedValue());
+ ASSERT_OK(encoded_gamma->ValidateFull());
+ ASSERT_TRUE(encoded_gamma->Equals(MakeScalar("gamma")));
+
+ // test Array.GetScalar
+ DictionaryArray arr(ty, ArrayFromJSON(index_ty, "[2, 0, 1, null]"), dict);
+ ASSERT_OK_AND_ASSIGN(auto first, arr.GetScalar(0));
+ ASSERT_OK_AND_ASSIGN(auto second, arr.GetScalar(1));
+ ASSERT_OK_AND_ASSIGN(auto third, arr.GetScalar(1));
+ ASSERT_OK_AND_ASSIGN(auto last, arr.GetScalar(3));
+ ASSERT_OK(first->ValidateFull());
+ ASSERT_OK(second->ValidateFull());
+ ASSERT_OK(last->ValidateFull());
+
+ ASSERT_TRUE(first->is_valid);
+ ASSERT_TRUE(second->is_valid);
+ ASSERT_TRUE(third->is_valid); // valid because of valid index, despite null value
+ ASSERT_FALSE(last->is_valid);
+
+ ASSERT_TRUE(first->Equals(scalar_gamma));
+ ASSERT_TRUE(second->Equals(scalar_alpha));
+ ASSERT_TRUE(last->Equals(scalar_null));
+
+ auto first_dict_scalar = checked_cast<const DictionaryScalar&>(*first);
+ ASSERT_TRUE(first_dict_scalar.value.dictionary->Equals(arr.dictionary()));
+
+ auto second_dict_scalar = checked_cast<const DictionaryScalar&>(*second);
+ ASSERT_TRUE(second_dict_scalar.value.dictionary->Equals(arr.dictionary()));
+ }
+}
+
+TEST(TestDictionaryScalar, ValidateErrors) {
+ auto index_ty = int16();
+ auto value_ty = utf8();
+ auto dict = ArrayFromJSON(value_ty, R"(["alpha", null, "gamma"])");
+ auto dict_ty = dictionary(index_ty, value_ty);
+
+ DictionaryScalar::ValueType alpha;
+ ASSERT_OK_AND_ASSIGN(alpha.index, MakeScalar(index_ty, 0));
+ alpha.dictionary = dict;
+
+ // Valid index, null underlying value
+ DictionaryScalar::ValueType null_value;
+ ASSERT_OK_AND_ASSIGN(null_value.index, MakeScalar(index_ty, 1));
+ null_value.dictionary = dict;
+
+ // Null index, no value
+ DictionaryScalar::ValueType null{MakeNullScalar(index_ty), dict};
+
+ // Inconsistent index type
+ auto scalar = DictionaryScalar(alpha, dictionary(int32(), value_ty));
+ AssertValidationFails(scalar);
+
+ // Inconsistent index type
+ scalar = DictionaryScalar(alpha, dictionary(index_ty, binary()));
+ AssertValidationFails(scalar);
+
+ // Inconsistent is_valid / value
+ scalar = DictionaryScalar(alpha, dict_ty);
+ ASSERT_OK(scalar.ValidateFull());
+ scalar.is_valid = false;
+ AssertValidationFails(scalar);
+
+ scalar = DictionaryScalar(null_value, dict_ty);
+ ASSERT_OK(scalar.ValidateFull());
+ scalar.is_valid = false;
+ AssertValidationFails(scalar);
+
+ scalar = DictionaryScalar(null, dict_ty);
+ AssertValidationFails(scalar);
+ scalar.is_valid = false;
+ ASSERT_OK(scalar.ValidateFull());
+
+ // Index value out of bounds
+ for (int64_t index : {-1, 3}) {
+ DictionaryScalar::ValueType invalid;
+ ASSERT_OK_AND_ASSIGN(invalid.index, MakeScalar(index_ty, index));
+ invalid.dictionary = dict;
+
+ scalar = DictionaryScalar(invalid, dict_ty);
+ ASSERT_OK(scalar.Validate());
+ ASSERT_RAISES(Invalid, scalar.ValidateFull());
+ }
+}
+
+TEST(TestDictionaryScalar, Cast) {
+ for (auto index_ty : all_dictionary_index_types()) {
+ auto ty = dictionary(index_ty, utf8());
+ auto dict = checked_pointer_cast<StringArray>(
+ ArrayFromJSON(utf8(), R"(["alpha", null, "gamma"])"));
+
+ for (int64_t i = 0; i < dict->length(); ++i) {
+ auto alpha =
+ dict->IsValid(i) ? MakeScalar(dict->GetString(i)) : MakeNullScalar(utf8());
+ // Cast string to dict(..., string)
+ ASSERT_OK_AND_ASSIGN(auto cast_alpha, alpha->CastTo(ty));
+ ASSERT_OK(cast_alpha->ValidateFull());
+ ASSERT_OK_AND_ASSIGN(
+ auto roundtripped_alpha,
+ checked_cast<const DictionaryScalar&>(*cast_alpha).GetEncodedValue());
+
+ ASSERT_OK_AND_ASSIGN(auto i_scalar, MakeScalar(index_ty, i));
+ auto alpha_dict = DictionaryScalar({i_scalar, dict}, ty);
+ ASSERT_OK(alpha_dict.ValidateFull());
+ ASSERT_OK_AND_ASSIGN(
+ auto encoded_alpha,
+ checked_cast<const DictionaryScalar&>(alpha_dict).GetEncodedValue());
+
+ AssertScalarsEqual(*alpha, *roundtripped_alpha);
+ AssertScalarsEqual(*encoded_alpha, *roundtripped_alpha);
+
+ // dictionaries differ, though encoded values are identical
+ ASSERT_FALSE(alpha_dict.Equals(cast_alpha));
+ }
+ }
+}
+
+void CheckGetValidUnionScalar(const Array& arr, int64_t index, const Scalar& expected,
+ const Scalar& expected_value) {
+ ASSERT_OK_AND_ASSIGN(auto scalar, arr.GetScalar(index));
+ ASSERT_OK(scalar->ValidateFull());
+ ASSERT_TRUE(scalar->Equals(expected));
+
+ const auto& as_union = checked_cast<const UnionScalar&>(*scalar);
+ ASSERT_TRUE(as_union.is_valid);
+ ASSERT_TRUE(as_union.value->Equals(expected_value));
+}
+
+void CheckGetNullUnionScalar(const Array& arr, int64_t index) {
+ ASSERT_OK_AND_ASSIGN(auto scalar, arr.GetScalar(index));
+ ASSERT_TRUE(scalar->Equals(MakeNullScalar(arr.type())));
+
+ const auto& as_union = checked_cast<const UnionScalar&>(*scalar);
+ ASSERT_FALSE(as_union.is_valid);
+ // XXX in reality, the union array doesn't have a validity bitmap.
+ // Validity is inferred from the underlying child value, which should maybe
+ // be reflected here...
+ ASSERT_EQ(as_union.value, nullptr);
+}
+
+template <typename Type>
+class TestUnionScalar : public ::testing::Test {
+ public:
+ using UnionType = Type;
+ using ScalarType = typename TypeTraits<UnionType>::ScalarType;
+
+ void SetUp() {
+ type_.reset(new UnionType({field("string", utf8()), field("number", uint64()),
+ field("other_number", uint64())},
+ /*type_codes=*/{3, 42, 43}));
+ alpha_ = MakeScalar("alpha");
+ beta_ = MakeScalar("beta");
+ ASSERT_OK_AND_ASSIGN(two_, MakeScalar(uint64(), 2));
+ ASSERT_OK_AND_ASSIGN(three_, MakeScalar(uint64(), 3));
+
+ union_alpha_ = std::make_shared<ScalarType>(alpha_, 3, type_);
+ union_beta_ = std::make_shared<ScalarType>(beta_, 3, type_);
+ union_two_ = std::make_shared<ScalarType>(two_, 42, type_);
+ union_other_two_ = std::make_shared<ScalarType>(two_, 43, type_);
+ union_three_ = std::make_shared<ScalarType>(three_, 42, type_);
+ union_string_null_ = MakeSpecificNullScalar(3);
+ union_number_null_ = MakeSpecificNullScalar(42);
+ }
+
+ void TestValidate() {
+ ASSERT_OK(union_alpha_->ValidateFull());
+ ASSERT_OK(union_beta_->ValidateFull());
+ ASSERT_OK(union_two_->ValidateFull());
+ ASSERT_OK(union_other_two_->ValidateFull());
+ ASSERT_OK(union_three_->ValidateFull());
+ ASSERT_OK(union_string_null_->ValidateFull());
+ ASSERT_OK(union_number_null_->ValidateFull());
+ }
+
+ void TestValidateErrors() {
+ // Type code doesn't exist
+ AssertValidationFails(ScalarType(alpha_, 0, type_));
+ AssertValidationFails(ScalarType(alpha_, 0, type_));
+ AssertValidationFails(ScalarType(0, type_));
+ AssertValidationFails(ScalarType(alpha_, -42, type_));
+ AssertValidationFails(ScalarType(-42, type_));
+
+ // Type code doesn't correspond to child type
+ AssertValidationFails(ScalarType(alpha_, 42, type_));
+ AssertValidationFails(ScalarType(two_, 3, type_));
+
+ // underlying value has invalid UTF8
+ auto invalid_utf8 = std::make_shared<StringScalar>("\xff");
+ auto scalar = std::make_shared<ScalarType>(invalid_utf8, 3, type_);
+ ASSERT_OK(scalar->Validate());
+ ASSERT_RAISES(Invalid, scalar->ValidateFull());
+ }
+
+ void TestEquals() {
+ // Differing values
+ ASSERT_FALSE(union_alpha_->Equals(union_beta_));
+ ASSERT_FALSE(union_two_->Equals(union_three_));
+ // Differing validities
+ ASSERT_FALSE(union_alpha_->Equals(union_string_null_));
+ // Differing types
+ ASSERT_FALSE(union_alpha_->Equals(union_two_));
+ ASSERT_FALSE(union_alpha_->Equals(union_other_two_));
+ // Type codes don't count when comparing union scalars: the underlying values
+ // are identical even though their provenance is different.
+ ASSERT_TRUE(union_two_->Equals(union_other_two_));
+ ASSERT_TRUE(union_string_null_->Equals(union_number_null_));
+ }
+
+ void TestMakeNullScalar() {
+ const auto scalar = MakeNullScalar(type_);
+ const auto& as_union = checked_cast<const UnionScalar&>(*scalar);
+ AssertTypeEqual(type_, as_union.type);
+ ASSERT_FALSE(as_union.is_valid);
+ ASSERT_EQ(as_union.value, nullptr);
+ // Abstractly, the type code must be valid.
+ // Concretely, the first child field is chosen.
+ ASSERT_EQ(as_union.type_code, 3);
+ }
+
+ protected:
+ std::shared_ptr<Scalar> MakeSpecificNullScalar(int8_t type_code) {
+ auto scal = MakeNullScalar(type_);
+ checked_cast<UnionScalar*>(scal.get())->type_code = type_code;
+ return scal;
+ }
+
+ std::shared_ptr<DataType> type_;
+ std::shared_ptr<Scalar> alpha_, beta_, two_, three_;
+ std::shared_ptr<Scalar> union_alpha_, union_beta_, union_two_, union_three_,
+ union_other_two_, union_string_null_, union_number_null_;
+};
+
+TYPED_TEST_SUITE(TestUnionScalar, UnionArrowTypes);
+
+TYPED_TEST(TestUnionScalar, Validate) { this->TestValidate(); }
+
+TYPED_TEST(TestUnionScalar, ValidateErrors) { this->TestValidateErrors(); }
+
+TYPED_TEST(TestUnionScalar, Equals) { this->TestEquals(); }
+
+TYPED_TEST(TestUnionScalar, MakeNullScalar) { this->TestMakeNullScalar(); }
+
+class TestSparseUnionScalar : public TestUnionScalar<SparseUnionType> {};
+
+TEST_F(TestSparseUnionScalar, GetScalar) {
+ ArrayVector children{ArrayFromJSON(utf8(), R"(["alpha", "", "beta", null, "gamma"])"),
+ ArrayFromJSON(uint64(), "[1, 2, 11, 22, null]"),
+ ArrayFromJSON(uint64(), "[100, 101, 102, 103, 104]")};
+
+ auto type_ids = ArrayFromJSON(int8(), "[3, 42, 3, 3, 42]");
+ SparseUnionArray arr(type_, 5, children, type_ids->data()->buffers[1]);
+ ASSERT_OK(arr.ValidateFull());
+
+ CheckGetValidUnionScalar(arr, 0, *union_alpha_, *alpha_);
+ CheckGetValidUnionScalar(arr, 1, *union_two_, *two_);
+ CheckGetValidUnionScalar(arr, 2, *union_beta_, *beta_);
+ CheckGetNullUnionScalar(arr, 3);
+ CheckGetNullUnionScalar(arr, 4);
+}
+
+class TestDenseUnionScalar : public TestUnionScalar<DenseUnionType> {};
+
+TEST_F(TestDenseUnionScalar, GetScalar) {
+ ArrayVector children{ArrayFromJSON(utf8(), R"(["alpha", "beta", null])"),
+ ArrayFromJSON(uint64(), "[2, 3]"), ArrayFromJSON(uint64(), "[]")};
+
+ auto type_ids = ArrayFromJSON(int8(), "[3, 42, 3, 3, 42]");
+ auto offsets = ArrayFromJSON(int32(), "[0, 0, 1, 2, 1]");
+ DenseUnionArray arr(type_, 5, children, type_ids->data()->buffers[1],
+ offsets->data()->buffers[1]);
+ ASSERT_OK(arr.ValidateFull());
+
+ CheckGetValidUnionScalar(arr, 0, *union_alpha_, *alpha_);
+ CheckGetValidUnionScalar(arr, 1, *union_two_, *two_);
+ CheckGetValidUnionScalar(arr, 2, *union_beta_, *beta_);
+ CheckGetNullUnionScalar(arr, 3);
+ CheckGetValidUnionScalar(arr, 4, *union_three_, *three_);
+}
+
+#define UUID_STRING1 "abcdefghijklmnop"
+#define UUID_STRING2 "zyxwvutsrqponmlk"
+
+class TestExtensionScalar : public ::testing::Test {
+ public:
+ void SetUp() {
+ type_ = uuid();
+ storage_type_ = fixed_size_binary(16);
+ uuid_type_ = checked_cast<const UuidType*>(type_.get());
+ }
+
+ protected:
+ ExtensionScalar MakeUuidScalar(util::string_view value) {
+ return ExtensionScalar(std::make_shared<FixedSizeBinaryScalar>(
+ std::make_shared<Buffer>(value), storage_type_),
+ type_);
+ }
+
+ std::shared_ptr<DataType> type_, storage_type_;
+ const UuidType* uuid_type_{nullptr};
+
+ const util::string_view uuid_string1_{UUID_STRING1};
+ const util::string_view uuid_string2_{UUID_STRING2};
+ const util::string_view uuid_json_{"[\"" UUID_STRING1 "\", \"" UUID_STRING2
+ "\", null]"};
+};
+
+#undef UUID_STRING1
+#undef UUID_STRING2
+
+TEST_F(TestExtensionScalar, Basics) {
+ const ExtensionScalar uuid_scalar = MakeUuidScalar(uuid_string1_);
+ ASSERT_OK(uuid_scalar.ValidateFull());
+ ASSERT_TRUE(uuid_scalar.is_valid);
+
+ const ExtensionScalar uuid_scalar2 = MakeUuidScalar(uuid_string2_);
+ ASSERT_OK(uuid_scalar2.ValidateFull());
+ ASSERT_TRUE(uuid_scalar2.is_valid);
+
+ const ExtensionScalar uuid_scalar3 = MakeUuidScalar(uuid_string2_);
+ ASSERT_OK(uuid_scalar2.ValidateFull());
+ ASSERT_TRUE(uuid_scalar2.is_valid);
+
+ const ExtensionScalar null_scalar(type_);
+ ASSERT_OK(null_scalar.ValidateFull());
+ ASSERT_FALSE(null_scalar.is_valid);
+
+ ASSERT_FALSE(uuid_scalar.Equals(uuid_scalar2));
+ ASSERT_TRUE(uuid_scalar2.Equals(uuid_scalar3));
+ ASSERT_FALSE(uuid_scalar.Equals(null_scalar));
+}
+
+TEST_F(TestExtensionScalar, MakeScalar) {
+ const ExtensionScalar null_scalar(type_);
+ const ExtensionScalar uuid_scalar = MakeUuidScalar(uuid_string1_);
+
+ auto scalar = CheckMakeNullScalar(type_);
+ ASSERT_OK(scalar->ValidateFull());
+ ASSERT_FALSE(scalar->is_valid);
+
+ ASSERT_OK_AND_ASSIGN(auto scalar2,
+ MakeScalar(type_, std::make_shared<Buffer>(uuid_string1_)));
+ ASSERT_OK(scalar2->ValidateFull());
+ ASSERT_TRUE(scalar2->is_valid);
+
+ ASSERT_OK_AND_ASSIGN(auto scalar3,
+ MakeScalar(type_, std::make_shared<Buffer>(uuid_string2_)));
+ ASSERT_OK(scalar3->ValidateFull());
+ ASSERT_TRUE(scalar3->is_valid);
+
+ ASSERT_TRUE(scalar->Equals(null_scalar));
+ ASSERT_TRUE(scalar2->Equals(uuid_scalar));
+ ASSERT_FALSE(scalar3->Equals(uuid_scalar));
+}
+
+TEST_F(TestExtensionScalar, GetScalar) {
+ const ExtensionScalar null_scalar(type_);
+ const ExtensionScalar uuid_scalar = MakeUuidScalar(uuid_string1_);
+ const ExtensionScalar uuid_scalar2 = MakeUuidScalar(uuid_string2_);
+
+ auto storage_array = ArrayFromJSON(storage_type_, uuid_json_);
+ auto array = ExtensionType::WrapArray(type_, storage_array);
+
+ ASSERT_OK_AND_ASSIGN(auto scalar, array->GetScalar(0));
+ ASSERT_OK(scalar->ValidateFull());
+ AssertTypeEqual(scalar->type, type_);
+ ASSERT_TRUE(scalar->is_valid);
+ ASSERT_TRUE(scalar->Equals(uuid_scalar));
+ ASSERT_FALSE(scalar->Equals(uuid_scalar2));
+
+ ASSERT_OK_AND_ASSIGN(scalar, array->GetScalar(1));
+ ASSERT_OK(scalar->ValidateFull());
+ AssertTypeEqual(scalar->type, type_);
+ ASSERT_TRUE(scalar->is_valid);
+ ASSERT_TRUE(scalar->Equals(uuid_scalar2));
+ ASSERT_FALSE(scalar->Equals(uuid_scalar));
+
+ ASSERT_OK_AND_ASSIGN(scalar, array->GetScalar(2));
+ ASSERT_OK(scalar->ValidateFull());
+ AssertTypeEqual(scalar->type, type_);
+ ASSERT_FALSE(scalar->is_valid);
+ ASSERT_TRUE(scalar->Equals(null_scalar));
+ ASSERT_FALSE(scalar->Equals(uuid_scalar));
+}
+
+TEST_F(TestExtensionScalar, ValidateErrors) {
+ // Mismatching is_valid and value
+ ExtensionScalar null_scalar(type_);
+ null_scalar.is_valid = true;
+ AssertValidationFails(null_scalar);
+
+ ExtensionScalar uuid_scalar = MakeUuidScalar(uuid_string1_);
+ uuid_scalar.is_valid = false;
+ AssertValidationFails(uuid_scalar);
+
+ // Null storage scalar
+ auto null_storage = std::make_shared<FixedSizeBinaryScalar>(storage_type_);
+ ExtensionScalar scalar(null_storage, type_);
+ scalar.is_valid = true;
+ AssertValidationFails(scalar);
+ scalar.is_valid = false;
+ AssertValidationFails(scalar);
+
+ // Invalid storage scalar (wrong length)
+ auto invalid_storage = std::make_shared<FixedSizeBinaryScalar>(storage_type_);
+ invalid_storage->is_valid = true;
+ invalid_storage->value = std::make_shared<Buffer>("123");
+ AssertValidationFails(*invalid_storage);
+ scalar = ExtensionScalar(invalid_storage, type_);
+ AssertValidationFails(scalar);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/sparse_tensor.cc b/src/arrow/cpp/src/arrow/sparse_tensor.cc
new file mode 100644
index 000000000..03d59c3d7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/sparse_tensor.cc
@@ -0,0 +1,478 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/sparse_tensor.h"
+#include "arrow/tensor/converter.h"
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <numeric>
+
+#include "arrow/compare.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+class MemoryPool;
+
+// ----------------------------------------------------------------------
+// SparseIndex
+
+Status SparseIndex::ValidateShape(const std::vector<int64_t>& shape) const {
+ if (!std::all_of(shape.begin(), shape.end(), [](int64_t x) { return x >= 0; })) {
+ return Status::Invalid("Shape elements must be positive");
+ }
+
+ return Status::OK();
+}
+
+namespace internal {
+namespace {
+
+template <typename IndexValueType>
+Status CheckSparseIndexMaximumValue(const std::vector<int64_t>& shape) {
+ using c_index_value_type = typename IndexValueType::c_type;
+ constexpr int64_t type_max =
+ static_cast<int64_t>(std::numeric_limits<c_index_value_type>::max());
+ auto greater_than_type_max = [&](int64_t x) { return x > type_max; };
+ if (std::any_of(shape.begin(), shape.end(), greater_than_type_max)) {
+ return Status::Invalid("The bit width of the index value type is too small");
+ }
+ return Status::OK();
+}
+
+template <>
+Status CheckSparseIndexMaximumValue<Int64Type>(const std::vector<int64_t>& shape) {
+ return Status::OK();
+}
+
+template <>
+Status CheckSparseIndexMaximumValue<UInt64Type>(const std::vector<int64_t>& shape) {
+ return Status::Invalid("UInt64Type cannot be used as IndexValueType of SparseIndex");
+}
+
+} // namespace
+
+#define CALL_CHECK_MAXIMUM_VALUE(TYPE_CLASS) \
+ case TYPE_CLASS##Type::type_id: \
+ return CheckSparseIndexMaximumValue<TYPE_CLASS##Type>(shape);
+
+Status CheckSparseIndexMaximumValue(const std::shared_ptr<DataType>& index_value_type,
+ const std::vector<int64_t>& shape) {
+ switch (index_value_type->id()) {
+ ARROW_GENERATE_FOR_ALL_INTEGER_TYPES(CALL_CHECK_MAXIMUM_VALUE);
+ default:
+ return Status::TypeError("Unsupported SparseTensor index value type");
+ }
+}
+
+#undef CALL_CHECK_MAXIMUM_VALUE
+
+Status MakeSparseTensorFromTensor(const Tensor& tensor,
+ SparseTensorFormat::type sparse_format_id,
+ const std::shared_ptr<DataType>& index_value_type,
+ MemoryPool* pool,
+ std::shared_ptr<SparseIndex>* out_sparse_index,
+ std::shared_ptr<Buffer>* out_data) {
+ switch (sparse_format_id) {
+ case SparseTensorFormat::COO:
+ return MakeSparseCOOTensorFromTensor(tensor, index_value_type, pool,
+ out_sparse_index, out_data);
+ case SparseTensorFormat::CSR:
+ return MakeSparseCSXMatrixFromTensor(SparseMatrixCompressedAxis::ROW, tensor,
+ index_value_type, pool, out_sparse_index,
+ out_data);
+ case SparseTensorFormat::CSC:
+ return MakeSparseCSXMatrixFromTensor(SparseMatrixCompressedAxis::COLUMN, tensor,
+ index_value_type, pool, out_sparse_index,
+ out_data);
+ case SparseTensorFormat::CSF:
+ return MakeSparseCSFTensorFromTensor(tensor, index_value_type, pool,
+ out_sparse_index, out_data);
+
+ // LCOV_EXCL_START: ignore program failure
+ default:
+ return Status::Invalid("Invalid sparse tensor format");
+ // LCOV_EXCL_STOP
+ }
+}
+
+} // namespace internal
+
+// ----------------------------------------------------------------------
+// SparseCOOIndex
+
+namespace {
+
+inline Status CheckSparseCOOIndexValidity(const std::shared_ptr<DataType>& type,
+ const std::vector<int64_t>& shape,
+ const std::vector<int64_t>& strides) {
+ if (!is_integer(type->id())) {
+ return Status::TypeError("Type of SparseCOOIndex indices must be integer");
+ }
+ if (shape.size() != 2) {
+ return Status::Invalid("SparseCOOIndex indices must be a matrix");
+ }
+
+ RETURN_NOT_OK(internal::CheckSparseIndexMaximumValue(type, shape));
+
+ if (!internal::IsTensorStridesContiguous(type, shape, strides)) {
+ return Status::Invalid("SparseCOOIndex indices must be contiguous");
+ }
+ return Status::OK();
+}
+
+void GetCOOIndexTensorRow(const std::shared_ptr<Tensor>& coords, const int64_t row,
+ std::vector<int64_t>* out_index) {
+ const auto& fw_index_value_type =
+ internal::checked_cast<const FixedWidthType&>(*coords->type());
+ const size_t indices_elsize = fw_index_value_type.bit_width() / CHAR_BIT;
+
+ const auto& shape = coords->shape();
+ const int64_t non_zero_length = shape[0];
+ DCHECK(0 <= row && row < non_zero_length);
+
+ const int64_t ndim = shape[1];
+ out_index->resize(ndim);
+
+ switch (indices_elsize) {
+ case 1: // Int8, UInt8
+ for (int64_t i = 0; i < ndim; ++i) {
+ (*out_index)[i] = static_cast<int64_t>(coords->Value<UInt8Type>({row, i}));
+ }
+ break;
+ case 2: // Int16, UInt16
+ for (int64_t i = 0; i < ndim; ++i) {
+ (*out_index)[i] = static_cast<int64_t>(coords->Value<UInt16Type>({row, i}));
+ }
+ break;
+ case 4: // Int32, UInt32
+ for (int64_t i = 0; i < ndim; ++i) {
+ (*out_index)[i] = static_cast<int64_t>(coords->Value<UInt32Type>({row, i}));
+ }
+ break;
+ case 8: // Int64
+ for (int64_t i = 0; i < ndim; ++i) {
+ (*out_index)[i] = coords->Value<Int64Type>({row, i});
+ }
+ break;
+ default:
+ DCHECK(false) << "Must not reach here";
+ break;
+ }
+}
+
+bool DetectSparseCOOIndexCanonicality(const std::shared_ptr<Tensor>& coords) {
+ DCHECK_EQ(coords->ndim(), 2);
+
+ const auto& shape = coords->shape();
+ const int64_t non_zero_length = shape[0];
+ if (non_zero_length <= 1) return true;
+
+ const int64_t ndim = shape[1];
+ std::vector<int64_t> last_index, index;
+ GetCOOIndexTensorRow(coords, 0, &last_index);
+ for (int64_t i = 1; i < non_zero_length; ++i) {
+ GetCOOIndexTensorRow(coords, i, &index);
+ int64_t j = 0;
+ while (j < ndim) {
+ if (last_index[j] > index[j]) {
+ // last_index > index, so we can detect non-canonical here
+ return false;
+ }
+ if (last_index[j] < index[j]) {
+ // last_index < index, so we can skip the remaining dimensions
+ break;
+ }
+ ++j;
+ }
+ if (j == ndim) {
+ // last_index == index, so we can detect non-canonical here
+ return false;
+ }
+ swap(last_index, index);
+ }
+
+ return true;
+}
+
+} // namespace
+
+Result<std::shared_ptr<SparseCOOIndex>> SparseCOOIndex::Make(
+ const std::shared_ptr<Tensor>& coords, bool is_canonical) {
+ RETURN_NOT_OK(
+ CheckSparseCOOIndexValidity(coords->type(), coords->shape(), coords->strides()));
+ return std::make_shared<SparseCOOIndex>(coords, is_canonical);
+}
+
+Result<std::shared_ptr<SparseCOOIndex>> SparseCOOIndex::Make(
+ const std::shared_ptr<Tensor>& coords) {
+ RETURN_NOT_OK(
+ CheckSparseCOOIndexValidity(coords->type(), coords->shape(), coords->strides()));
+ auto is_canonical = DetectSparseCOOIndexCanonicality(coords);
+ return std::make_shared<SparseCOOIndex>(coords, is_canonical);
+}
+
+Result<std::shared_ptr<SparseCOOIndex>> SparseCOOIndex::Make(
+ const std::shared_ptr<DataType>& indices_type,
+ const std::vector<int64_t>& indices_shape,
+ const std::vector<int64_t>& indices_strides, std::shared_ptr<Buffer> indices_data,
+ bool is_canonical) {
+ RETURN_NOT_OK(
+ CheckSparseCOOIndexValidity(indices_type, indices_shape, indices_strides));
+ return std::make_shared<SparseCOOIndex>(
+ std::make_shared<Tensor>(indices_type, indices_data, indices_shape,
+ indices_strides),
+ is_canonical);
+}
+
+Result<std::shared_ptr<SparseCOOIndex>> SparseCOOIndex::Make(
+ const std::shared_ptr<DataType>& indices_type,
+ const std::vector<int64_t>& indices_shape,
+ const std::vector<int64_t>& indices_strides, std::shared_ptr<Buffer> indices_data) {
+ RETURN_NOT_OK(
+ CheckSparseCOOIndexValidity(indices_type, indices_shape, indices_strides));
+ auto coords = std::make_shared<Tensor>(indices_type, indices_data, indices_shape,
+ indices_strides);
+ auto is_canonical = DetectSparseCOOIndexCanonicality(coords);
+ return std::make_shared<SparseCOOIndex>(coords, is_canonical);
+}
+
+Result<std::shared_ptr<SparseCOOIndex>> SparseCOOIndex::Make(
+ const std::shared_ptr<DataType>& indices_type, const std::vector<int64_t>& shape,
+ int64_t non_zero_length, std::shared_ptr<Buffer> indices_data, bool is_canonical) {
+ auto ndim = static_cast<int64_t>(shape.size());
+ if (!is_integer(indices_type->id())) {
+ return Status::TypeError("Type of SparseCOOIndex indices must be integer");
+ }
+ const int64_t elsize =
+ internal::checked_cast<const IntegerType&>(*indices_type).bit_width() / 8;
+ std::vector<int64_t> indices_shape({non_zero_length, ndim});
+ std::vector<int64_t> indices_strides({elsize * ndim, elsize});
+ return Make(indices_type, indices_shape, indices_strides, indices_data, is_canonical);
+}
+
+Result<std::shared_ptr<SparseCOOIndex>> SparseCOOIndex::Make(
+ const std::shared_ptr<DataType>& indices_type, const std::vector<int64_t>& shape,
+ int64_t non_zero_length, std::shared_ptr<Buffer> indices_data) {
+ auto ndim = static_cast<int64_t>(shape.size());
+ if (!is_integer(indices_type->id())) {
+ return Status::TypeError("Type of SparseCOOIndex indices must be integer");
+ }
+ const int64_t elsize = internal::GetByteWidth(*indices_type);
+ std::vector<int64_t> indices_shape({non_zero_length, ndim});
+ std::vector<int64_t> indices_strides({elsize * ndim, elsize});
+ return Make(indices_type, indices_shape, indices_strides, indices_data);
+}
+
+// Constructor with a contiguous NumericTensor
+SparseCOOIndex::SparseCOOIndex(const std::shared_ptr<Tensor>& coords, bool is_canonical)
+ : SparseIndexBase(), coords_(coords), is_canonical_(is_canonical) {
+ ARROW_CHECK_OK(
+ CheckSparseCOOIndexValidity(coords_->type(), coords_->shape(), coords_->strides()));
+}
+
+std::string SparseCOOIndex::ToString() const { return std::string("SparseCOOIndex"); }
+
+// ----------------------------------------------------------------------
+// SparseCSXIndex
+
+namespace internal {
+
+Status ValidateSparseCSXIndex(const std::shared_ptr<DataType>& indptr_type,
+ const std::shared_ptr<DataType>& indices_type,
+ const std::vector<int64_t>& indptr_shape,
+ const std::vector<int64_t>& indices_shape,
+ char const* type_name) {
+ if (!is_integer(indptr_type->id())) {
+ return Status::TypeError("Type of ", type_name, " indptr must be integer");
+ }
+ if (indptr_shape.size() != 1) {
+ return Status::Invalid(type_name, " indptr must be a vector");
+ }
+ if (!is_integer(indices_type->id())) {
+ return Status::Invalid("Type of ", type_name, " indices must be integer");
+ }
+ if (indices_shape.size() != 1) {
+ return Status::Invalid(type_name, " indices must be a vector");
+ }
+
+ RETURN_NOT_OK(internal::CheckSparseIndexMaximumValue(indptr_type, indptr_shape));
+ RETURN_NOT_OK(internal::CheckSparseIndexMaximumValue(indices_type, indices_shape));
+
+ return Status::OK();
+}
+
+void CheckSparseCSXIndexValidity(const std::shared_ptr<DataType>& indptr_type,
+ const std::shared_ptr<DataType>& indices_type,
+ const std::vector<int64_t>& indptr_shape,
+ const std::vector<int64_t>& indices_shape,
+ char const* type_name) {
+ ARROW_CHECK_OK(ValidateSparseCSXIndex(indptr_type, indices_type, indptr_shape,
+ indices_shape, type_name));
+}
+
+} // namespace internal
+
+// ----------------------------------------------------------------------
+// SparseCSFIndex
+
+namespace {
+
+inline Status CheckSparseCSFIndexValidity(const std::shared_ptr<DataType>& indptr_type,
+ const std::shared_ptr<DataType>& indices_type,
+ const int64_t num_indptrs,
+ const int64_t num_indices,
+ const int64_t axis_order_size) {
+ if (!is_integer(indptr_type->id())) {
+ return Status::TypeError("Type of SparseCSFIndex indptr must be integer");
+ }
+ if (!is_integer(indices_type->id())) {
+ return Status::TypeError("Type of SparseCSFIndex indices must be integer");
+ }
+ if (num_indptrs + 1 != num_indices) {
+ return Status::Invalid(
+ "Length of indices must be equal to length of indptrs + 1 for SparseCSFIndex.");
+ }
+ if (axis_order_size != num_indices) {
+ return Status::Invalid(
+ "Length of indices must be equal to number of dimensions for SparseCSFIndex.");
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Result<std::shared_ptr<SparseCSFIndex>> SparseCSFIndex::Make(
+ const std::shared_ptr<DataType>& indptr_type,
+ const std::shared_ptr<DataType>& indices_type,
+ const std::vector<int64_t>& indices_shapes, const std::vector<int64_t>& axis_order,
+ const std::vector<std::shared_ptr<Buffer>>& indptr_data,
+ const std::vector<std::shared_ptr<Buffer>>& indices_data) {
+ int64_t ndim = axis_order.size();
+ std::vector<std::shared_ptr<Tensor>> indptr(ndim - 1);
+ std::vector<std::shared_ptr<Tensor>> indices(ndim);
+
+ for (int64_t i = 0; i < ndim - 1; ++i)
+ indptr[i] = std::make_shared<Tensor>(indptr_type, indptr_data[i],
+ std::vector<int64_t>({indices_shapes[i] + 1}));
+ for (int64_t i = 0; i < ndim; ++i)
+ indices[i] = std::make_shared<Tensor>(indices_type, indices_data[i],
+ std::vector<int64_t>({indices_shapes[i]}));
+
+ RETURN_NOT_OK(CheckSparseCSFIndexValidity(indptr_type, indices_type, indptr.size(),
+ indices.size(), axis_order.size()));
+
+ for (auto tensor : indptr) {
+ RETURN_NOT_OK(internal::CheckSparseIndexMaximumValue(indptr_type, tensor->shape()));
+ }
+
+ for (auto tensor : indices) {
+ RETURN_NOT_OK(internal::CheckSparseIndexMaximumValue(indices_type, tensor->shape()));
+ }
+
+ return std::make_shared<SparseCSFIndex>(indptr, indices, axis_order);
+}
+
+// Constructor with two index vectors
+SparseCSFIndex::SparseCSFIndex(const std::vector<std::shared_ptr<Tensor>>& indptr,
+ const std::vector<std::shared_ptr<Tensor>>& indices,
+ const std::vector<int64_t>& axis_order)
+ : SparseIndexBase(), indptr_(indptr), indices_(indices), axis_order_(axis_order) {
+ ARROW_CHECK_OK(CheckSparseCSFIndexValidity(indptr_.front()->type(),
+ indices_.front()->type(), indptr_.size(),
+ indices_.size(), axis_order_.size()));
+}
+
+std::string SparseCSFIndex::ToString() const { return std::string("SparseCSFIndex"); }
+
+bool SparseCSFIndex::Equals(const SparseCSFIndex& other) const {
+ for (int64_t i = 0; i < static_cast<int64_t>(indices().size()); ++i) {
+ if (!indices()[i]->Equals(*other.indices()[i])) return false;
+ }
+ for (int64_t i = 0; i < static_cast<int64_t>(indptr().size()); ++i) {
+ if (!indptr()[i]->Equals(*other.indptr()[i])) return false;
+ }
+ return axis_order() == other.axis_order();
+}
+
+// ----------------------------------------------------------------------
+// SparseTensor
+
+// Constructor with all attributes
+SparseTensor::SparseTensor(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<Buffer>& data,
+ const std::vector<int64_t>& shape,
+ const std::shared_ptr<SparseIndex>& sparse_index,
+ const std::vector<std::string>& dim_names)
+ : type_(type),
+ data_(data),
+ shape_(shape),
+ sparse_index_(sparse_index),
+ dim_names_(dim_names) {
+ ARROW_CHECK(is_tensor_supported(type->id()));
+}
+
+const std::string& SparseTensor::dim_name(int i) const {
+ static const std::string kEmpty = "";
+ if (dim_names_.size() == 0) {
+ return kEmpty;
+ } else {
+ ARROW_CHECK_LT(i, static_cast<int>(dim_names_.size()));
+ return dim_names_[i];
+ }
+}
+
+int64_t SparseTensor::size() const {
+ return std::accumulate(shape_.begin(), shape_.end(), 1LL, std::multiplies<int64_t>());
+}
+
+bool SparseTensor::Equals(const SparseTensor& other, const EqualOptions& opts) const {
+ return SparseTensorEquals(*this, other, opts);
+}
+
+Result<std::shared_ptr<Tensor>> SparseTensor::ToTensor(MemoryPool* pool) const {
+ switch (format_id()) {
+ case SparseTensorFormat::COO:
+ return MakeTensorFromSparseCOOTensor(
+ pool, internal::checked_cast<const SparseCOOTensor*>(this));
+ break;
+
+ case SparseTensorFormat::CSR:
+ return MakeTensorFromSparseCSRMatrix(
+ pool, internal::checked_cast<const SparseCSRMatrix*>(this));
+ break;
+
+ case SparseTensorFormat::CSC:
+ return MakeTensorFromSparseCSCMatrix(
+ pool, internal::checked_cast<const SparseCSCMatrix*>(this));
+ break;
+
+ case SparseTensorFormat::CSF:
+ return MakeTensorFromSparseCSFTensor(
+ pool, internal::checked_cast<const SparseCSFTensor*>(this));
+
+ default:
+ return Status::NotImplemented("Unsupported SparseIndex format type");
+ }
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/sparse_tensor.h b/src/arrow/cpp/src/arrow/sparse_tensor.h
new file mode 100644
index 000000000..4ec824dfa
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/sparse_tensor.h
@@ -0,0 +1,617 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/compare.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/tensor.h" // IWYU pragma: export
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class MemoryPool;
+
+namespace internal {
+
+ARROW_EXPORT
+Status CheckSparseIndexMaximumValue(const std::shared_ptr<DataType>& index_value_type,
+ const std::vector<int64_t>& shape);
+
+} // namespace internal
+
+// ----------------------------------------------------------------------
+// SparseIndex class
+
+struct SparseTensorFormat {
+ /// EXPERIMENTAL: The index format type of SparseTensor
+ enum type {
+ /// Coordinate list (COO) format.
+ COO,
+ /// Compressed sparse row (CSR) format.
+ CSR,
+ /// Compressed sparse column (CSC) format.
+ CSC,
+ /// Compressed sparse fiber (CSF) format.
+ CSF
+ };
+};
+
+/// \brief EXPERIMENTAL: The base class for the index of a sparse tensor
+///
+/// SparseIndex describes where the non-zero elements are within a SparseTensor.
+///
+/// There are several ways to represent this. The format_id is used to
+/// distinguish what kind of representation is used. Each possible value of
+/// format_id must have only one corresponding concrete subclass of SparseIndex.
+class ARROW_EXPORT SparseIndex {
+ public:
+ explicit SparseIndex(SparseTensorFormat::type format_id) : format_id_(format_id) {}
+
+ virtual ~SparseIndex() = default;
+
+ /// \brief Return the identifier of the format type
+ SparseTensorFormat::type format_id() const { return format_id_; }
+
+ /// \brief Return the number of non zero values in the sparse tensor related
+ /// to this sparse index
+ virtual int64_t non_zero_length() const = 0;
+
+ /// \brief Return the string representation of the sparse index
+ virtual std::string ToString() const = 0;
+
+ virtual Status ValidateShape(const std::vector<int64_t>& shape) const;
+
+ protected:
+ const SparseTensorFormat::type format_id_;
+};
+
+namespace internal {
+template <typename SparseIndexType>
+class SparseIndexBase : public SparseIndex {
+ public:
+ SparseIndexBase() : SparseIndex(SparseIndexType::format_id) {}
+};
+} // namespace internal
+
+// ----------------------------------------------------------------------
+// SparseCOOIndex class
+
+/// \brief EXPERIMENTAL: The index data for a COO sparse tensor
+///
+/// A COO sparse index manages the location of its non-zero values by their
+/// coordinates.
+class ARROW_EXPORT SparseCOOIndex : public internal::SparseIndexBase<SparseCOOIndex> {
+ public:
+ static constexpr SparseTensorFormat::type format_id = SparseTensorFormat::COO;
+
+ /// \brief Make SparseCOOIndex from a coords tensor and canonicality
+ static Result<std::shared_ptr<SparseCOOIndex>> Make(
+ const std::shared_ptr<Tensor>& coords, bool is_canonical);
+
+ /// \brief Make SparseCOOIndex from a coords tensor with canonicality auto-detection
+ static Result<std::shared_ptr<SparseCOOIndex>> Make(
+ const std::shared_ptr<Tensor>& coords);
+
+ /// \brief Make SparseCOOIndex from raw properties with canonicality auto-detection
+ static Result<std::shared_ptr<SparseCOOIndex>> Make(
+ const std::shared_ptr<DataType>& indices_type,
+ const std::vector<int64_t>& indices_shape,
+ const std::vector<int64_t>& indices_strides, std::shared_ptr<Buffer> indices_data);
+
+ /// \brief Make SparseCOOIndex from raw properties
+ static Result<std::shared_ptr<SparseCOOIndex>> Make(
+ const std::shared_ptr<DataType>& indices_type,
+ const std::vector<int64_t>& indices_shape,
+ const std::vector<int64_t>& indices_strides, std::shared_ptr<Buffer> indices_data,
+ bool is_canonical);
+
+ /// \brief Make SparseCOOIndex from sparse tensor's shape properties and data
+ /// with canonicality auto-detection
+ ///
+ /// The indices_data should be in row-major (C-like) order. If not,
+ /// use the raw properties constructor.
+ static Result<std::shared_ptr<SparseCOOIndex>> Make(
+ const std::shared_ptr<DataType>& indices_type, const std::vector<int64_t>& shape,
+ int64_t non_zero_length, std::shared_ptr<Buffer> indices_data);
+
+ /// \brief Make SparseCOOIndex from sparse tensor's shape properties and data
+ ///
+ /// The indices_data should be in row-major (C-like) order. If not,
+ /// use the raw properties constructor.
+ static Result<std::shared_ptr<SparseCOOIndex>> Make(
+ const std::shared_ptr<DataType>& indices_type, const std::vector<int64_t>& shape,
+ int64_t non_zero_length, std::shared_ptr<Buffer> indices_data, bool is_canonical);
+
+ /// \brief Construct SparseCOOIndex from column-major NumericTensor
+ explicit SparseCOOIndex(const std::shared_ptr<Tensor>& coords, bool is_canonical);
+
+ /// \brief Return a tensor that has the coordinates of the non-zero values
+ ///
+ /// The returned tensor is a N x D tensor where N is the number of non-zero
+ /// values and D is the number of dimensions in the logical data.
+ /// The column at index `i` is a D-tuple of coordinates indicating that the
+ /// logical value at those coordinates should be found at physical index `i`.
+ const std::shared_ptr<Tensor>& indices() const { return coords_; }
+
+ /// \brief Return the number of non zero values in the sparse tensor related
+ /// to this sparse index
+ int64_t non_zero_length() const override { return coords_->shape()[0]; }
+
+ /// \brief Return whether a sparse tensor index is canonical, or not.
+ /// If a sparse tensor index is canonical, it is sorted in the lexicographical order,
+ /// and the corresponding sparse tensor doesn't have duplicated entries.
+ bool is_canonical() const { return is_canonical_; }
+
+ /// \brief Return a string representation of the sparse index
+ std::string ToString() const override;
+
+ /// \brief Return whether the COO indices are equal
+ bool Equals(const SparseCOOIndex& other) const {
+ return indices()->Equals(*other.indices());
+ }
+
+ inline Status ValidateShape(const std::vector<int64_t>& shape) const override {
+ ARROW_RETURN_NOT_OK(SparseIndex::ValidateShape(shape));
+
+ if (static_cast<size_t>(coords_->shape()[1]) == shape.size()) {
+ return Status::OK();
+ }
+
+ return Status::Invalid(
+ "shape length is inconsistent with the coords matrix in COO index");
+ }
+
+ protected:
+ std::shared_ptr<Tensor> coords_;
+ bool is_canonical_;
+};
+
+namespace internal {
+
+/// EXPERIMENTAL: The axis to be compressed
+enum class SparseMatrixCompressedAxis : char {
+ /// The value for CSR matrix
+ ROW,
+ /// The value for CSC matrix
+ COLUMN
+};
+
+ARROW_EXPORT
+Status ValidateSparseCSXIndex(const std::shared_ptr<DataType>& indptr_type,
+ const std::shared_ptr<DataType>& indices_type,
+ const std::vector<int64_t>& indptr_shape,
+ const std::vector<int64_t>& indices_shape,
+ char const* type_name);
+
+ARROW_EXPORT
+void CheckSparseCSXIndexValidity(const std::shared_ptr<DataType>& indptr_type,
+ const std::shared_ptr<DataType>& indices_type,
+ const std::vector<int64_t>& indptr_shape,
+ const std::vector<int64_t>& indices_shape,
+ char const* type_name);
+
+template <typename SparseIndexType, SparseMatrixCompressedAxis COMPRESSED_AXIS>
+class SparseCSXIndex : public SparseIndexBase<SparseIndexType> {
+ public:
+ static constexpr SparseMatrixCompressedAxis kCompressedAxis = COMPRESSED_AXIS;
+
+ /// \brief Make a subclass of SparseCSXIndex from raw properties
+ static Result<std::shared_ptr<SparseIndexType>> Make(
+ const std::shared_ptr<DataType>& indptr_type,
+ const std::shared_ptr<DataType>& indices_type,
+ const std::vector<int64_t>& indptr_shape, const std::vector<int64_t>& indices_shape,
+ std::shared_ptr<Buffer> indptr_data, std::shared_ptr<Buffer> indices_data) {
+ ARROW_RETURN_NOT_OK(ValidateSparseCSXIndex(indptr_type, indices_type, indptr_shape,
+ indices_shape,
+ SparseIndexType::kTypeName));
+ return std::make_shared<SparseIndexType>(
+ std::make_shared<Tensor>(indptr_type, indptr_data, indptr_shape),
+ std::make_shared<Tensor>(indices_type, indices_data, indices_shape));
+ }
+
+ /// \brief Make a subclass of SparseCSXIndex from raw properties
+ static Result<std::shared_ptr<SparseIndexType>> Make(
+ const std::shared_ptr<DataType>& indices_type,
+ const std::vector<int64_t>& indptr_shape, const std::vector<int64_t>& indices_shape,
+ std::shared_ptr<Buffer> indptr_data, std::shared_ptr<Buffer> indices_data) {
+ return Make(indices_type, indices_type, indptr_shape, indices_shape, indptr_data,
+ indices_data);
+ }
+
+ /// \brief Make a subclass of SparseCSXIndex from sparse tensor's shape properties and
+ /// data
+ static Result<std::shared_ptr<SparseIndexType>> Make(
+ const std::shared_ptr<DataType>& indptr_type,
+ const std::shared_ptr<DataType>& indices_type, const std::vector<int64_t>& shape,
+ int64_t non_zero_length, std::shared_ptr<Buffer> indptr_data,
+ std::shared_ptr<Buffer> indices_data) {
+ std::vector<int64_t> indptr_shape({shape[0] + 1});
+ std::vector<int64_t> indices_shape({non_zero_length});
+ return Make(indptr_type, indices_type, indptr_shape, indices_shape, indptr_data,
+ indices_data);
+ }
+
+ /// \brief Make a subclass of SparseCSXIndex from sparse tensor's shape properties and
+ /// data
+ static Result<std::shared_ptr<SparseIndexType>> Make(
+ const std::shared_ptr<DataType>& indices_type, const std::vector<int64_t>& shape,
+ int64_t non_zero_length, std::shared_ptr<Buffer> indptr_data,
+ std::shared_ptr<Buffer> indices_data) {
+ return Make(indices_type, indices_type, shape, non_zero_length, indptr_data,
+ indices_data);
+ }
+
+ /// \brief Construct SparseCSXIndex from two index vectors
+ explicit SparseCSXIndex(const std::shared_ptr<Tensor>& indptr,
+ const std::shared_ptr<Tensor>& indices)
+ : SparseIndexBase<SparseIndexType>(), indptr_(indptr), indices_(indices) {
+ CheckSparseCSXIndexValidity(indptr_->type(), indices_->type(), indptr_->shape(),
+ indices_->shape(), SparseIndexType::kTypeName);
+ }
+
+ /// \brief Return a 1D tensor of indptr vector
+ const std::shared_ptr<Tensor>& indptr() const { return indptr_; }
+
+ /// \brief Return a 1D tensor of indices vector
+ const std::shared_ptr<Tensor>& indices() const { return indices_; }
+
+ /// \brief Return the number of non zero values in the sparse tensor related
+ /// to this sparse index
+ int64_t non_zero_length() const override { return indices_->shape()[0]; }
+
+ /// \brief Return a string representation of the sparse index
+ std::string ToString() const override {
+ return std::string(SparseIndexType::kTypeName);
+ }
+
+ /// \brief Return whether the CSR indices are equal
+ bool Equals(const SparseIndexType& other) const {
+ return indptr()->Equals(*other.indptr()) && indices()->Equals(*other.indices());
+ }
+
+ inline Status ValidateShape(const std::vector<int64_t>& shape) const override {
+ ARROW_RETURN_NOT_OK(SparseIndex::ValidateShape(shape));
+
+ if (shape.size() < 2) {
+ return Status::Invalid("shape length is too short");
+ }
+
+ if (shape.size() > 2) {
+ return Status::Invalid("shape length is too long");
+ }
+
+ if (indptr_->shape()[0] == shape[static_cast<int64_t>(kCompressedAxis)] + 1) {
+ return Status::OK();
+ }
+
+ return Status::Invalid("shape length is inconsistent with the ", ToString());
+ }
+
+ protected:
+ std::shared_ptr<Tensor> indptr_;
+ std::shared_ptr<Tensor> indices_;
+};
+
+} // namespace internal
+
+// ----------------------------------------------------------------------
+// SparseCSRIndex class
+
+/// \brief EXPERIMENTAL: The index data for a CSR sparse matrix
+///
+/// A CSR sparse index manages the location of its non-zero values by two
+/// vectors.
+///
+/// The first vector, called indptr, represents the range of the rows; the i-th
+/// row spans from indptr[i] to indptr[i+1] in the corresponding value vector.
+/// So the length of an indptr vector is the number of rows + 1.
+///
+/// The other vector, called indices, represents the column indices of the
+/// corresponding non-zero values. So the length of an indices vector is same
+/// as the number of non-zero-values.
+class ARROW_EXPORT SparseCSRIndex
+ : public internal::SparseCSXIndex<SparseCSRIndex,
+ internal::SparseMatrixCompressedAxis::ROW> {
+ public:
+ using BaseClass =
+ internal::SparseCSXIndex<SparseCSRIndex, internal::SparseMatrixCompressedAxis::ROW>;
+
+ static constexpr SparseTensorFormat::type format_id = SparseTensorFormat::CSR;
+ static constexpr char const* kTypeName = "SparseCSRIndex";
+
+ using SparseCSXIndex::kCompressedAxis;
+ using SparseCSXIndex::Make;
+ using SparseCSXIndex::SparseCSXIndex;
+};
+
+// ----------------------------------------------------------------------
+// SparseCSCIndex class
+
+/// \brief EXPERIMENTAL: The index data for a CSC sparse matrix
+///
+/// A CSC sparse index manages the location of its non-zero values by two
+/// vectors.
+///
+/// The first vector, called indptr, represents the range of the column; the i-th
+/// column spans from indptr[i] to indptr[i+1] in the corresponding value vector.
+/// So the length of an indptr vector is the number of columns + 1.
+///
+/// The other vector, called indices, represents the row indices of the
+/// corresponding non-zero values. So the length of an indices vector is same
+/// as the number of non-zero-values.
+class ARROW_EXPORT SparseCSCIndex
+ : public internal::SparseCSXIndex<SparseCSCIndex,
+ internal::SparseMatrixCompressedAxis::COLUMN> {
+ public:
+ using BaseClass =
+ internal::SparseCSXIndex<SparseCSCIndex,
+ internal::SparseMatrixCompressedAxis::COLUMN>;
+
+ static constexpr SparseTensorFormat::type format_id = SparseTensorFormat::CSC;
+ static constexpr char const* kTypeName = "SparseCSCIndex";
+
+ using SparseCSXIndex::kCompressedAxis;
+ using SparseCSXIndex::Make;
+ using SparseCSXIndex::SparseCSXIndex;
+};
+
+// ----------------------------------------------------------------------
+// SparseCSFIndex class
+
+/// \brief EXPERIMENTAL: The index data for a CSF sparse tensor
+///
+/// A CSF sparse index manages the location of its non-zero values by set of
+/// prefix trees. Each path from a root to leaf forms one tensor non-zero index.
+/// CSF is implemented with three vectors.
+///
+/// Vectors inptr and indices contain N-1 and N buffers respectively, where N is the
+/// number of dimensions. Axis_order is a vector of integers of length N. Indptr and
+/// indices describe the set of prefix trees. Trees traverse dimensions in order given by
+/// axis_order.
+class ARROW_EXPORT SparseCSFIndex : public internal::SparseIndexBase<SparseCSFIndex> {
+ public:
+ static constexpr SparseTensorFormat::type format_id = SparseTensorFormat::CSF;
+ static constexpr char const* kTypeName = "SparseCSFIndex";
+
+ /// \brief Make SparseCSFIndex from raw properties
+ static Result<std::shared_ptr<SparseCSFIndex>> Make(
+ const std::shared_ptr<DataType>& indptr_type,
+ const std::shared_ptr<DataType>& indices_type,
+ const std::vector<int64_t>& indices_shapes, const std::vector<int64_t>& axis_order,
+ const std::vector<std::shared_ptr<Buffer>>& indptr_data,
+ const std::vector<std::shared_ptr<Buffer>>& indices_data);
+
+ /// \brief Make SparseCSFIndex from raw properties
+ static Result<std::shared_ptr<SparseCSFIndex>> Make(
+ const std::shared_ptr<DataType>& indices_type,
+ const std::vector<int64_t>& indices_shapes, const std::vector<int64_t>& axis_order,
+ const std::vector<std::shared_ptr<Buffer>>& indptr_data,
+ const std::vector<std::shared_ptr<Buffer>>& indices_data) {
+ return Make(indices_type, indices_type, indices_shapes, axis_order, indptr_data,
+ indices_data);
+ }
+
+ /// \brief Construct SparseCSFIndex from two index vectors
+ explicit SparseCSFIndex(const std::vector<std::shared_ptr<Tensor>>& indptr,
+ const std::vector<std::shared_ptr<Tensor>>& indices,
+ const std::vector<int64_t>& axis_order);
+
+ /// \brief Return a 1D vector of indptr tensors
+ const std::vector<std::shared_ptr<Tensor>>& indptr() const { return indptr_; }
+
+ /// \brief Return a 1D vector of indices tensors
+ const std::vector<std::shared_ptr<Tensor>>& indices() const { return indices_; }
+
+ /// \brief Return a 1D vector specifying the order of axes
+ const std::vector<int64_t>& axis_order() const { return axis_order_; }
+
+ /// \brief Return the number of non zero values in the sparse tensor related
+ /// to this sparse index
+ int64_t non_zero_length() const override { return indices_.back()->shape()[0]; }
+
+ /// \brief Return a string representation of the sparse index
+ std::string ToString() const override;
+
+ /// \brief Return whether the CSF indices are equal
+ bool Equals(const SparseCSFIndex& other) const;
+
+ protected:
+ std::vector<std::shared_ptr<Tensor>> indptr_;
+ std::vector<std::shared_ptr<Tensor>> indices_;
+ std::vector<int64_t> axis_order_;
+};
+
+// ----------------------------------------------------------------------
+// SparseTensor class
+
+/// \brief EXPERIMENTAL: The base class of sparse tensor container
+class ARROW_EXPORT SparseTensor {
+ public:
+ virtual ~SparseTensor() = default;
+
+ SparseTensorFormat::type format_id() const { return sparse_index_->format_id(); }
+
+ /// \brief Return a value type of the sparse tensor
+ std::shared_ptr<DataType> type() const { return type_; }
+
+ /// \brief Return a buffer that contains the value vector of the sparse tensor
+ std::shared_ptr<Buffer> data() const { return data_; }
+
+ /// \brief Return an immutable raw data pointer
+ const uint8_t* raw_data() const { return data_->data(); }
+
+ /// \brief Return a mutable raw data pointer
+ uint8_t* raw_mutable_data() const { return data_->mutable_data(); }
+
+ /// \brief Return a shape vector of the sparse tensor
+ const std::vector<int64_t>& shape() const { return shape_; }
+
+ /// \brief Return a sparse index of the sparse tensor
+ const std::shared_ptr<SparseIndex>& sparse_index() const { return sparse_index_; }
+
+ /// \brief Return a number of dimensions of the sparse tensor
+ int ndim() const { return static_cast<int>(shape_.size()); }
+
+ /// \brief Return a vector of dimension names
+ const std::vector<std::string>& dim_names() const { return dim_names_; }
+
+ /// \brief Return the name of the i-th dimension
+ const std::string& dim_name(int i) const;
+
+ /// \brief Total number of value cells in the sparse tensor
+ int64_t size() const;
+
+ /// \brief Return true if the underlying data buffer is mutable
+ bool is_mutable() const { return data_->is_mutable(); }
+
+ /// \brief Total number of non-zero cells in the sparse tensor
+ int64_t non_zero_length() const {
+ return sparse_index_ ? sparse_index_->non_zero_length() : 0;
+ }
+
+ /// \brief Return whether sparse tensors are equal
+ bool Equals(const SparseTensor& other,
+ const EqualOptions& = EqualOptions::Defaults()) const;
+
+ /// \brief Return dense representation of sparse tensor as tensor
+ ///
+ /// The returned Tensor has row-major order (C-like).
+ Result<std::shared_ptr<Tensor>> ToTensor(MemoryPool* pool) const;
+ Result<std::shared_ptr<Tensor>> ToTensor() const {
+ return ToTensor(default_memory_pool());
+ }
+
+ protected:
+ // Constructor with all attributes
+ SparseTensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
+ const std::vector<int64_t>& shape,
+ const std::shared_ptr<SparseIndex>& sparse_index,
+ const std::vector<std::string>& dim_names);
+
+ std::shared_ptr<DataType> type_;
+ std::shared_ptr<Buffer> data_;
+ std::vector<int64_t> shape_;
+ std::shared_ptr<SparseIndex> sparse_index_;
+
+ // These names are optional
+ std::vector<std::string> dim_names_;
+};
+
+// ----------------------------------------------------------------------
+// SparseTensorImpl class
+
+namespace internal {
+
+ARROW_EXPORT
+Status MakeSparseTensorFromTensor(const Tensor& tensor,
+ SparseTensorFormat::type sparse_format_id,
+ const std::shared_ptr<DataType>& index_value_type,
+ MemoryPool* pool,
+ std::shared_ptr<SparseIndex>* out_sparse_index,
+ std::shared_ptr<Buffer>* out_data);
+
+} // namespace internal
+
+/// \brief EXPERIMENTAL: Concrete sparse tensor implementation classes with sparse index
+/// type
+template <typename SparseIndexType>
+class SparseTensorImpl : public SparseTensor {
+ public:
+ virtual ~SparseTensorImpl() = default;
+
+ /// \brief Construct a sparse tensor from physical data buffer and logical index
+ SparseTensorImpl(const std::shared_ptr<SparseIndexType>& sparse_index,
+ const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape,
+ const std::vector<std::string>& dim_names)
+ : SparseTensor(type, data, shape, sparse_index, dim_names) {}
+
+ /// \brief Construct an empty sparse tensor
+ SparseTensorImpl(const std::shared_ptr<DataType>& type,
+ const std::vector<int64_t>& shape,
+ const std::vector<std::string>& dim_names = {})
+ : SparseTensorImpl(NULLPTR, type, NULLPTR, shape, dim_names) {}
+
+ /// \brief Create a SparseTensor with full parameters
+ static inline Result<std::shared_ptr<SparseTensorImpl<SparseIndexType>>> Make(
+ const std::shared_ptr<SparseIndexType>& sparse_index,
+ const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
+ const std::vector<int64_t>& shape, const std::vector<std::string>& dim_names) {
+ if (!is_tensor_supported(type->id())) {
+ return Status::Invalid(type->ToString(),
+ " is not valid data type for a sparse tensor");
+ }
+ ARROW_RETURN_NOT_OK(sparse_index->ValidateShape(shape));
+ if (dim_names.size() > 0 && dim_names.size() != shape.size()) {
+ return Status::Invalid("dim_names length is inconsistent with shape");
+ }
+ return std::make_shared<SparseTensorImpl<SparseIndexType>>(sparse_index, type, data,
+ shape, dim_names);
+ }
+
+ /// \brief Create a sparse tensor from a dense tensor
+ ///
+ /// The dense tensor is re-encoded as a sparse index and a physical
+ /// data buffer for the non-zero value.
+ static inline Result<std::shared_ptr<SparseTensorImpl<SparseIndexType>>> Make(
+ const Tensor& tensor, const std::shared_ptr<DataType>& index_value_type,
+ MemoryPool* pool = default_memory_pool()) {
+ std::shared_ptr<SparseIndex> sparse_index;
+ std::shared_ptr<Buffer> data;
+ ARROW_RETURN_NOT_OK(internal::MakeSparseTensorFromTensor(
+ tensor, SparseIndexType::format_id, index_value_type, pool, &sparse_index,
+ &data));
+ return std::make_shared<SparseTensorImpl<SparseIndexType>>(
+ internal::checked_pointer_cast<SparseIndexType>(sparse_index), tensor.type(),
+ data, tensor.shape(), tensor.dim_names_);
+ }
+
+ static inline Result<std::shared_ptr<SparseTensorImpl<SparseIndexType>>> Make(
+ const Tensor& tensor, MemoryPool* pool = default_memory_pool()) {
+ return Make(tensor, int64(), pool);
+ }
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(SparseTensorImpl);
+};
+
+/// \brief EXPERIMENTAL: Type alias for COO sparse tensor
+using SparseCOOTensor = SparseTensorImpl<SparseCOOIndex>;
+
+/// \brief EXPERIMENTAL: Type alias for CSR sparse matrix
+using SparseCSRMatrix = SparseTensorImpl<SparseCSRIndex>;
+
+/// \brief EXPERIMENTAL: Type alias for CSC sparse matrix
+using SparseCSCMatrix = SparseTensorImpl<SparseCSCIndex>;
+
+/// \brief EXPERIMENTAL: Type alias for CSF sparse matrix
+using SparseCSFTensor = SparseTensorImpl<SparseCSFIndex>;
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/sparse_tensor_test.cc b/src/arrow/cpp/src/arrow/sparse_tensor_test.cc
new file mode 100644
index 000000000..219cdd934
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/sparse_tensor_test.cc
@@ -0,0 +1,1678 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Unit tests for DataType (and subclasses), Field, and Schema
+
+#include <cmath>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <iostream>
+
+#include <gtest/gtest.h>
+
+#include "arrow/sparse_tensor.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/sort.h"
+
+namespace arrow {
+
+static inline void CheckSparseIndexFormatType(SparseTensorFormat::type expected,
+ const SparseTensor& sparse_tensor) {
+ ASSERT_EQ(expected, sparse_tensor.format_id());
+ ASSERT_EQ(expected, sparse_tensor.sparse_index()->format_id());
+}
+
+static inline void AssertCOOIndex(const std::shared_ptr<Tensor>& sidx, const int64_t nth,
+ const std::vector<int64_t>& expected_values) {
+ int64_t n = static_cast<int64_t>(expected_values.size());
+ for (int64_t i = 0; i < n; ++i) {
+ ASSERT_EQ(expected_values[i], sidx->Value<Int64Type>({nth, i}));
+ }
+}
+
+//-----------------------------------------------------------------------------
+// SparseCOOIndex
+
+TEST(TestSparseCOOIndex, MakeRowMajorCanonical) {
+ std::vector<int32_t> values = {0, 0, 0, 0, 0, 2, 0, 1, 1, 0, 1, 3, 0, 2, 0, 0, 2, 2,
+ 1, 0, 1, 1, 0, 3, 1, 1, 0, 1, 1, 2, 1, 2, 1, 1, 2, 3};
+ auto data = Buffer::Wrap(values);
+ std::vector<int64_t> shape = {12, 3};
+ std::vector<int64_t> strides = {3 * sizeof(int32_t), sizeof(int32_t)}; // Row-major
+
+ // OK
+ std::shared_ptr<SparseCOOIndex> si;
+ ASSERT_OK_AND_ASSIGN(si, SparseCOOIndex::Make(int32(), shape, strides, data));
+ ASSERT_EQ(shape, si->indices()->shape());
+ ASSERT_EQ(strides, si->indices()->strides());
+ ASSERT_EQ(data->data(), si->indices()->raw_data());
+ ASSERT_TRUE(si->is_canonical());
+
+ // Non-integer type
+ auto res = SparseCOOIndex::Make(float32(), shape, strides, data);
+ ASSERT_RAISES(TypeError, res);
+
+ // Non-matrix indices
+ res = SparseCOOIndex::Make(int32(), {4, 3, 4}, strides, data);
+ ASSERT_RAISES(Invalid, res);
+
+ // Non-contiguous indices
+ res = SparseCOOIndex::Make(int32(), {6, 3}, {6 * sizeof(int32_t), 2 * sizeof(int32_t)},
+ data);
+ ASSERT_RAISES(Invalid, res);
+
+ // Make from sparse tensor properties
+ // (shape is arbitrary 3-dim, non-zero length = 12)
+ ASSERT_OK_AND_ASSIGN(si, SparseCOOIndex::Make(int32(), {99, 99, 99}, 12, data));
+ ASSERT_EQ(shape, si->indices()->shape());
+ ASSERT_EQ(strides, si->indices()->strides());
+ ASSERT_EQ(data->data(), si->indices()->raw_data());
+}
+
+TEST(TestSparseCOOIndex, MakeRowMajorNonCanonical) {
+ std::vector<int32_t> values = {0, 0, 0, 0, 0, 2, 0, 1, 1, 0, 1, 3, 0, 2, 0, 1, 0, 1,
+ 0, 2, 2, 1, 0, 3, 1, 1, 0, 1, 1, 2, 1, 2, 1, 1, 2, 3};
+ auto data = Buffer::Wrap(values);
+ std::vector<int64_t> shape = {12, 3};
+ std::vector<int64_t> strides = {3 * sizeof(int32_t), sizeof(int32_t)}; // Row-major
+
+ // OK
+ std::shared_ptr<SparseCOOIndex> si;
+ ASSERT_OK_AND_ASSIGN(si, SparseCOOIndex::Make(int32(), shape, strides, data));
+ ASSERT_EQ(shape, si->indices()->shape());
+ ASSERT_EQ(strides, si->indices()->strides());
+ ASSERT_EQ(data->data(), si->indices()->raw_data());
+ ASSERT_FALSE(si->is_canonical());
+}
+
+TEST(TestSparseCOOIndex, MakeColumnMajorCanonical) {
+ std::vector<int32_t> values = {0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 2, 2,
+ 0, 0, 1, 1, 2, 2, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3};
+ auto data = Buffer::Wrap(values);
+ std::vector<int64_t> shape = {12, 3};
+ std::vector<int64_t> strides = {sizeof(int32_t), 12 * sizeof(int32_t)}; // Column-major
+
+ // OK
+ std::shared_ptr<SparseCOOIndex> si;
+ ASSERT_OK_AND_ASSIGN(si, SparseCOOIndex::Make(int32(), shape, strides, data));
+ ASSERT_EQ(shape, si->indices()->shape());
+ ASSERT_EQ(strides, si->indices()->strides());
+ ASSERT_EQ(data->data(), si->indices()->raw_data());
+ ASSERT_TRUE(si->is_canonical());
+}
+
+TEST(TestSparseCOOIndex, MakeColumnMajorNonCanonical) {
+ std::vector<int32_t> values = {0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 2, 0,
+ 2, 0, 1, 1, 2, 2, 0, 2, 1, 3, 0, 1, 2, 3, 0, 2, 1, 3};
+ auto data = Buffer::Wrap(values);
+ std::vector<int64_t> shape = {12, 3};
+ std::vector<int64_t> strides = {sizeof(int32_t), 12 * sizeof(int32_t)}; // Column-major
+
+ // OK
+ std::shared_ptr<SparseCOOIndex> si;
+ ASSERT_OK_AND_ASSIGN(si, SparseCOOIndex::Make(int32(), shape, strides, data));
+ ASSERT_EQ(shape, si->indices()->shape());
+ ASSERT_EQ(strides, si->indices()->strides());
+ ASSERT_EQ(data->data(), si->indices()->raw_data());
+ ASSERT_FALSE(si->is_canonical());
+}
+
+TEST(TestSparseCOOIndex, MakeEmptyIndex) {
+ std::vector<int32_t> values = {};
+ auto data = Buffer::Wrap(values);
+ std::vector<int64_t> shape = {0, 3};
+ std::vector<int64_t> strides = {sizeof(int32_t), sizeof(int32_t)}; // Empty strides
+
+ // OK
+ std::shared_ptr<SparseCOOIndex> si;
+ ASSERT_OK_AND_ASSIGN(si, SparseCOOIndex::Make(int32(), shape, strides, data));
+ ASSERT_EQ(shape, si->indices()->shape());
+ ASSERT_EQ(strides, si->indices()->strides());
+ ASSERT_EQ(data->data(), si->indices()->raw_data());
+ ASSERT_TRUE(si->is_canonical());
+}
+
+TEST(TestSparseCSRIndex, Make) {
+ std::vector<int32_t> indptr_values = {0, 2, 4, 6, 8, 10, 12};
+ std::vector<int32_t> indices_values = {0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3};
+ auto indptr_data = Buffer::Wrap(indptr_values);
+ auto indices_data = Buffer::Wrap(indices_values);
+ std::vector<int64_t> indptr_shape = {7};
+ std::vector<int64_t> indices_shape = {12};
+
+ // OK
+ std::shared_ptr<SparseCSRIndex> si;
+ ASSERT_OK_AND_ASSIGN(si, SparseCSRIndex::Make(int32(), indptr_shape, indices_shape,
+ indptr_data, indices_data));
+ ASSERT_EQ(indptr_shape, si->indptr()->shape());
+ ASSERT_EQ(indptr_data->data(), si->indptr()->raw_data());
+ ASSERT_EQ(indices_shape, si->indices()->shape());
+ ASSERT_EQ(indices_data->data(), si->indices()->raw_data());
+ ASSERT_EQ(std::string("SparseCSRIndex"), si->ToString());
+
+ // Non-integer type
+ auto res = SparseCSRIndex::Make(float32(), indptr_shape, indices_shape, indptr_data,
+ indices_data);
+ ASSERT_RAISES(TypeError, res);
+
+ // Non-vector indptr shape
+ ASSERT_RAISES(Invalid, SparseCSRIndex::Make(int32(), {1, 2}, indices_shape, indptr_data,
+ indices_data));
+
+ // Non-vector indices shape
+ ASSERT_RAISES(Invalid, SparseCSRIndex::Make(int32(), indptr_shape, {1, 2}, indptr_data,
+ indices_data));
+}
+
+TEST(TestSparseCSCIndex, Make) {
+ std::vector<int32_t> indptr_values = {0, 2, 4, 6, 8, 10, 12};
+ std::vector<int32_t> indices_values = {0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3};
+ auto indptr_data = Buffer::Wrap(indptr_values);
+ auto indices_data = Buffer::Wrap(indices_values);
+ std::vector<int64_t> indptr_shape = {7};
+ std::vector<int64_t> indices_shape = {12};
+
+ // OK
+ std::shared_ptr<SparseCSCIndex> si;
+ ASSERT_OK_AND_ASSIGN(si, SparseCSCIndex::Make(int32(), indptr_shape, indices_shape,
+ indptr_data, indices_data));
+ ASSERT_EQ(indptr_shape, si->indptr()->shape());
+ ASSERT_EQ(indptr_data->data(), si->indptr()->raw_data());
+ ASSERT_EQ(indices_shape, si->indices()->shape());
+ ASSERT_EQ(indices_data->data(), si->indices()->raw_data());
+ ASSERT_EQ(std::string("SparseCSCIndex"), si->ToString());
+
+ // Non-integer type
+ ASSERT_RAISES(TypeError, SparseCSCIndex::Make(float32(), indptr_shape, indices_shape,
+ indptr_data, indices_data));
+
+ // Non-vector indptr shape
+ ASSERT_RAISES(Invalid, SparseCSCIndex::Make(int32(), {1, 2}, indices_shape, indptr_data,
+ indices_data));
+
+ // Non-vector indices shape
+ ASSERT_RAISES(Invalid, SparseCSCIndex::Make(int32(), indptr_shape, {1, 2}, indptr_data,
+ indices_data));
+}
+
+template <typename ValueType>
+class TestSparseTensorBase : public ::testing::Test {
+ protected:
+ std::vector<int64_t> shape_;
+ std::vector<std::string> dim_names_;
+};
+
+//-----------------------------------------------------------------------------
+// SparseCOOTensor
+
+template <typename IndexValueType, typename ValueType = Int64Type>
+class TestSparseCOOTensorBase : public TestSparseTensorBase<ValueType> {
+ public:
+ using c_value_type = typename ValueType::c_type;
+
+ void SetUp() {
+ shape_ = {2, 3, 4};
+ dim_names_ = {"foo", "bar", "baz"};
+
+ // Dense representation:
+ // [
+ // [
+ // 1 0 2 0
+ // 0 3 0 4
+ // 5 0 6 0
+ // ],
+ // [
+ // 0 11 0 12
+ // 13 0 14 0
+ // 0 15 0 16
+ // ]
+ // ]
+ dense_values_ = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
+ 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};
+ auto dense_data = Buffer::Wrap(dense_values_);
+ NumericTensor<ValueType> dense_tensor(dense_data, shape_, {}, dim_names_);
+ ASSERT_OK_AND_ASSIGN(sparse_tensor_from_dense_,
+ SparseCOOTensor::Make(
+ dense_tensor, TypeTraits<IndexValueType>::type_singleton()));
+ }
+
+ protected:
+ using TestSparseTensorBase<ValueType>::shape_;
+ using TestSparseTensorBase<ValueType>::dim_names_;
+ std::vector<c_value_type> dense_values_;
+ std::shared_ptr<SparseCOOTensor> sparse_tensor_from_dense_;
+};
+
+class TestSparseCOOTensor : public TestSparseCOOTensorBase<Int64Type> {};
+
+TEST_F(TestSparseCOOTensor, CreationEmptyTensor) {
+ SparseCOOTensor st1(int64(), this->shape_);
+ SparseCOOTensor st2(int64(), this->shape_, this->dim_names_);
+
+ ASSERT_EQ(0, st1.non_zero_length());
+ ASSERT_EQ(0, st2.non_zero_length());
+
+ ASSERT_EQ(24, st1.size());
+ ASSERT_EQ(24, st2.size());
+
+ ASSERT_EQ(std::vector<std::string>({"foo", "bar", "baz"}), st2.dim_names());
+ ASSERT_EQ("foo", st2.dim_name(0));
+ ASSERT_EQ("bar", st2.dim_name(1));
+ ASSERT_EQ("baz", st2.dim_name(2));
+
+ ASSERT_EQ(std::vector<std::string>({}), st1.dim_names());
+ ASSERT_EQ("", st1.dim_name(0));
+ ASSERT_EQ("", st1.dim_name(1));
+ ASSERT_EQ("", st1.dim_name(2));
+}
+
+TEST_F(TestSparseCOOTensor, CreationFromZeroTensor) {
+ const auto dense_size =
+ std::accumulate(this->shape_.begin(), this->shape_.end(), int64_t(1),
+ [](int64_t a, int64_t x) { return a * x; });
+ std::vector<int64_t> dense_values(dense_size, 0);
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Tensor> t_zero,
+ Tensor::Make(int64(), Buffer::Wrap(dense_values), this->shape_));
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<SparseCOOTensor> st_zero,
+ SparseCOOTensor::Make(*t_zero, int64()));
+
+ ASSERT_EQ(0, st_zero->non_zero_length());
+ ASSERT_EQ(dense_size, st_zero->size());
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Tensor> t, st_zero->ToTensor());
+ ASSERT_TRUE(t->Equals(*t_zero));
+}
+
+TEST_F(TestSparseCOOTensor, CreationFromNumericTensor) {
+ auto st = this->sparse_tensor_from_dense_;
+ CheckSparseIndexFormatType(SparseTensorFormat::COO, *st);
+
+ ASSERT_EQ(12, st->non_zero_length());
+ ASSERT_TRUE(st->is_mutable());
+
+ auto* raw_data = reinterpret_cast<const int64_t*>(st->raw_data());
+ AssertNumericDataEqual(raw_data, {1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16});
+
+ auto si = internal::checked_pointer_cast<SparseCOOIndex>(st->sparse_index());
+ ASSERT_EQ(std::string("SparseCOOIndex"), si->ToString());
+ ASSERT_TRUE(si->is_canonical());
+
+ std::shared_ptr<Tensor> sidx = si->indices();
+ ASSERT_EQ(std::vector<int64_t>({12, 3}), sidx->shape());
+ ASSERT_TRUE(sidx->is_row_major());
+
+ AssertCOOIndex(sidx, 0, {0, 0, 0});
+ AssertCOOIndex(sidx, 1, {0, 0, 2});
+ AssertCOOIndex(sidx, 2, {0, 1, 1});
+ AssertCOOIndex(sidx, 10, {1, 2, 1});
+ AssertCOOIndex(sidx, 11, {1, 2, 3});
+}
+
+TEST_F(TestSparseCOOTensor, CreationFromNumericTensor1D) {
+ auto dense_data = Buffer::Wrap(this->dense_values_);
+ std::vector<int64_t> dense_shape({static_cast<int64_t>(this->dense_values_.size())});
+ NumericTensor<Int64Type> dense_vector(dense_data, dense_shape);
+
+ std::shared_ptr<SparseCOOTensor> st;
+ ASSERT_OK_AND_ASSIGN(st, SparseCOOTensor::Make(dense_vector));
+
+ ASSERT_EQ(12, st->non_zero_length());
+ ASSERT_TRUE(st->is_mutable());
+
+ auto* raw_data = reinterpret_cast<const int64_t*>(st->raw_data());
+ AssertNumericDataEqual(raw_data, {1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16});
+
+ auto si = internal::checked_pointer_cast<SparseCOOIndex>(st->sparse_index());
+ ASSERT_TRUE(si->is_canonical());
+
+ auto sidx = si->indices();
+ ASSERT_EQ(std::vector<int64_t>({12, 1}), sidx->shape());
+
+ AssertCOOIndex(sidx, 0, {0});
+ AssertCOOIndex(sidx, 1, {2});
+ AssertCOOIndex(sidx, 2, {5});
+ AssertCOOIndex(sidx, 10, {21});
+ AssertCOOIndex(sidx, 11, {23});
+}
+
+TEST_F(TestSparseCOOTensor, CreationFromTensor) {
+ std::shared_ptr<Buffer> buffer = Buffer::Wrap(this->dense_values_);
+ Tensor tensor(int64(), buffer, this->shape_, {}, this->dim_names_);
+
+ std::shared_ptr<SparseCOOTensor> st;
+ ASSERT_OK_AND_ASSIGN(st, SparseCOOTensor::Make(tensor));
+
+ ASSERT_EQ(12, st->non_zero_length());
+ ASSERT_TRUE(st->is_mutable());
+
+ ASSERT_EQ(std::vector<std::string>({"foo", "bar", "baz"}), st->dim_names());
+ ASSERT_EQ("foo", st->dim_name(0));
+ ASSERT_EQ("bar", st->dim_name(1));
+ ASSERT_EQ("baz", st->dim_name(2));
+
+ ASSERT_TRUE(st->Equals(*this->sparse_tensor_from_dense_));
+
+ auto si = internal::checked_pointer_cast<SparseCOOIndex>(st->sparse_index());
+ ASSERT_TRUE(si->is_canonical());
+}
+
+TEST_F(TestSparseCOOTensor, CreationFromNonContiguousTensor) {
+ std::vector<int64_t> values = {1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 3, 0, 0, 0, 4, 0,
+ 5, 0, 0, 0, 6, 0, 0, 0, 0, 0, 11, 0, 0, 0, 12, 0,
+ 13, 0, 0, 0, 14, 0, 0, 0, 0, 0, 15, 0, 0, 0, 16, 0};
+ std::vector<int64_t> strides = {192, 64, 16};
+ std::shared_ptr<Buffer> buffer = Buffer::Wrap(values);
+ Tensor tensor(int64(), buffer, this->shape_, strides);
+
+ std::shared_ptr<SparseCOOTensor> st;
+ ASSERT_OK_AND_ASSIGN(st, SparseCOOTensor::Make(tensor));
+
+ ASSERT_EQ(12, st->non_zero_length());
+ ASSERT_TRUE(st->is_mutable());
+
+ ASSERT_TRUE(st->Equals(*this->sparse_tensor_from_dense_));
+
+ auto si = internal::checked_pointer_cast<SparseCOOIndex>(st->sparse_index());
+ ASSERT_TRUE(si->is_canonical());
+}
+
+TEST_F(TestSparseCOOTensor, TestToTensor) {
+ std::vector<int64_t> values = {1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4};
+ std::vector<int64_t> shape({4, 3, 2, 1});
+ std::shared_ptr<Buffer> buffer = Buffer::Wrap(values);
+ Tensor tensor(int64(), buffer, shape, {}, this->dim_names_);
+
+ std::shared_ptr<SparseCOOTensor> sparse_tensor;
+ ASSERT_OK_AND_ASSIGN(sparse_tensor, SparseCOOTensor::Make(tensor));
+
+ ASSERT_EQ(5, sparse_tensor->non_zero_length());
+ ASSERT_TRUE(sparse_tensor->is_mutable());
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Tensor> dense_tensor, sparse_tensor->ToTensor());
+ ASSERT_TRUE(tensor.Equals(*dense_tensor));
+}
+
+template <typename ValueType>
+class TestSparseCOOTensorEquality : public TestSparseTensorBase<ValueType> {
+ public:
+ void SetUp() {
+ shape_ = {2, 3, 4};
+ values1_ = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
+ 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};
+ values2_ = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
+ 0, 0, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};
+ auto buffer1 = Buffer::Wrap(values1_);
+ auto buffer2 = Buffer::Wrap(values2_);
+ DCHECK_OK(NumericTensor<ValueType>::Make(buffer1, this->shape_).Value(&tensor1_));
+ DCHECK_OK(NumericTensor<ValueType>::Make(buffer2, this->shape_).Value(&tensor2_));
+ }
+
+ protected:
+ using TestSparseTensorBase<ValueType>::shape_;
+ std::vector<typename ValueType::c_type> values1_;
+ std::vector<typename ValueType::c_type> values2_;
+ std::shared_ptr<NumericTensor<ValueType>> tensor1_;
+ std::shared_ptr<NumericTensor<ValueType>> tensor2_;
+};
+
+template <typename ValueType>
+class TestIntegerSparseCOOTensorEquality : public TestSparseCOOTensorEquality<ValueType> {
+};
+
+TYPED_TEST_SUITE_P(TestIntegerSparseCOOTensorEquality);
+
+TYPED_TEST_P(TestIntegerSparseCOOTensorEquality, TestEquality) {
+ using ValueType = TypeParam;
+ static_assert(is_integer_type<ValueType>::value, "Integer type is required");
+
+ std::shared_ptr<SparseCOOTensor> st1, st2, st3;
+ ASSERT_OK_AND_ASSIGN(st1, SparseCOOTensor::Make(*this->tensor1_));
+ ASSERT_OK_AND_ASSIGN(st2, SparseCOOTensor::Make(*this->tensor2_));
+ ASSERT_OK_AND_ASSIGN(st3, SparseCOOTensor::Make(*this->tensor1_));
+
+ ASSERT_TRUE(st1->Equals(*st1));
+ ASSERT_FALSE(st1->Equals(*st2));
+ ASSERT_TRUE(st1->Equals(*st3));
+}
+
+REGISTER_TYPED_TEST_SUITE_P(TestIntegerSparseCOOTensorEquality, TestEquality);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt8, TestIntegerSparseCOOTensorEquality, Int8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt8, TestIntegerSparseCOOTensorEquality, UInt8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt16, TestIntegerSparseCOOTensorEquality, Int16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt16, TestIntegerSparseCOOTensorEquality,
+ UInt16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt32, TestIntegerSparseCOOTensorEquality, Int32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt32, TestIntegerSparseCOOTensorEquality,
+ UInt32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt64, TestIntegerSparseCOOTensorEquality, Int64Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt64, TestIntegerSparseCOOTensorEquality,
+ UInt64Type);
+
+template <typename ValueType>
+class TestFloatingSparseCOOTensorEquality
+ : public TestSparseCOOTensorEquality<ValueType> {};
+
+TYPED_TEST_SUITE_P(TestFloatingSparseCOOTensorEquality);
+
+TYPED_TEST_P(TestFloatingSparseCOOTensorEquality, TestEquality) {
+ using ValueType = TypeParam;
+ using c_value_type = typename ValueType::c_type;
+ static_assert(is_floating_type<ValueType>::value, "Float type is required");
+
+ std::shared_ptr<SparseCOOTensor> st1, st2, st3;
+ ASSERT_OK_AND_ASSIGN(st1, SparseCOOTensor::Make(*this->tensor1_));
+ ASSERT_OK_AND_ASSIGN(st2, SparseCOOTensor::Make(*this->tensor2_));
+ ASSERT_OK_AND_ASSIGN(st3, SparseCOOTensor::Make(*this->tensor1_));
+
+ ASSERT_TRUE(st1->Equals(*st1));
+ ASSERT_FALSE(st1->Equals(*st2));
+ ASSERT_TRUE(st1->Equals(*st3));
+
+ // sparse tensors with NaNs
+ const c_value_type nan_value = static_cast<c_value_type>(NAN);
+ this->values2_[13] = nan_value;
+ EXPECT_TRUE(std::isnan(this->tensor2_->Value({1, 0, 1})));
+
+ std::shared_ptr<SparseCOOTensor> st4;
+ ASSERT_OK_AND_ASSIGN(st4, SparseCOOTensor::Make(*this->tensor2_));
+ EXPECT_FALSE(st4->Equals(*st4)); // same object
+ EXPECT_TRUE(st4->Equals(*st4, EqualOptions().nans_equal(true))); // same object
+
+ std::vector<c_value_type> values5 = this->values2_;
+ std::shared_ptr<SparseCOOTensor> st5;
+ std::shared_ptr<Buffer> buffer5 = Buffer::Wrap(values5);
+ NumericTensor<ValueType> tensor5(buffer5, this->shape_);
+ ASSERT_OK_AND_ASSIGN(st5, SparseCOOTensor::Make(tensor5));
+ EXPECT_FALSE(st4->Equals(*st5)); // different memory
+ EXPECT_TRUE(st4->Equals(*st5, EqualOptions().nans_equal(true))); // different memory
+}
+
+REGISTER_TYPED_TEST_SUITE_P(TestFloatingSparseCOOTensorEquality, TestEquality);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(TestFloat, TestFloatingSparseCOOTensorEquality, FloatType);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestDouble, TestFloatingSparseCOOTensorEquality,
+ DoubleType);
+
+template <typename IndexValueType>
+class TestSparseCOOTensorForIndexValueType
+ : public TestSparseCOOTensorBase<IndexValueType> {
+ public:
+ using c_index_value_type = typename IndexValueType::c_type;
+
+ void SetUp() override {
+ TestSparseCOOTensorBase<IndexValueType>::SetUp();
+
+ // Sparse representation:
+ // idx[0] = [0 0 0 0 0 0 1 1 1 1 1 1]
+ // idx[1] = [0 0 1 1 2 2 0 0 1 1 2 2]
+ // idx[2] = [0 2 1 3 0 2 1 3 0 2 1 3]
+ // data = [1 2 3 4 5 6 11 12 13 14 15 16]
+
+ coords_values_row_major_ = {0, 0, 0, 0, 0, 2, 0, 1, 1, 0, 1, 3, 0, 2, 0, 0, 2, 2,
+ 1, 0, 1, 1, 0, 3, 1, 1, 0, 1, 1, 2, 1, 2, 1, 1, 2, 3};
+
+ coords_values_col_major_ = {0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 2, 2,
+ 0, 0, 1, 1, 2, 2, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3};
+ }
+
+ std::shared_ptr<DataType> index_data_type() const {
+ return TypeTraits<IndexValueType>::type_singleton();
+ }
+
+ protected:
+ std::vector<c_index_value_type> coords_values_row_major_;
+ std::vector<c_index_value_type> coords_values_col_major_;
+
+ Result<std::shared_ptr<SparseCOOIndex>> MakeSparseCOOIndex(
+ const std::vector<int64_t>& shape, const std::vector<int64_t>& strides,
+ const std::vector<c_index_value_type>& values) const {
+ return SparseCOOIndex::Make(index_data_type(), shape, strides, Buffer::Wrap(values));
+ }
+
+ template <typename CValueType>
+ Result<std::shared_ptr<SparseCOOTensor>> MakeSparseTensor(
+ const std::shared_ptr<SparseCOOIndex>& si,
+ std::vector<CValueType>& sparse_values) const {
+ auto data = Buffer::Wrap(sparse_values);
+ return SparseCOOTensor::Make(si, CTypeTraits<CValueType>::type_singleton(), data,
+ this->shape_, this->dim_names_);
+ }
+};
+
+TYPED_TEST_SUITE_P(TestSparseCOOTensorForIndexValueType);
+
+TYPED_TEST_P(TestSparseCOOTensorForIndexValueType, Make) {
+ using IndexValueType = TypeParam;
+ using c_index_value_type = typename IndexValueType::c_type;
+
+ constexpr int sizeof_index_value = sizeof(c_index_value_type);
+ ASSERT_OK_AND_ASSIGN(
+ std::shared_ptr<SparseCOOIndex> si,
+ this->MakeSparseCOOIndex({12, 3}, {sizeof_index_value * 3, sizeof_index_value},
+ this->coords_values_row_major_));
+
+ std::vector<int64_t> sparse_values = {1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16};
+ auto sparse_data = Buffer::Wrap(sparse_values);
+
+ std::shared_ptr<SparseCOOTensor> st;
+
+ // OK
+ ASSERT_OK_AND_ASSIGN(st, SparseCOOTensor::Make(si, int64(), sparse_data, this->shape_,
+ this->dim_names_));
+ ASSERT_EQ(int64(), st->type());
+ ASSERT_EQ(this->shape_, st->shape());
+ ASSERT_EQ(this->dim_names_, st->dim_names());
+ ASSERT_EQ(sparse_data->data(), st->raw_data());
+ ASSERT_TRUE(
+ internal::checked_pointer_cast<SparseCOOIndex>(st->sparse_index())->Equals(*si));
+
+ // OK with an empty dim_names
+ ASSERT_OK_AND_ASSIGN(st,
+ SparseCOOTensor::Make(si, int64(), sparse_data, this->shape_, {}));
+ ASSERT_EQ(int64(), st->type());
+ ASSERT_EQ(this->shape_, st->shape());
+ ASSERT_EQ(std::vector<std::string>{}, st->dim_names());
+ ASSERT_EQ(sparse_data->data(), st->raw_data());
+ ASSERT_TRUE(
+ internal::checked_pointer_cast<SparseCOOIndex>(st->sparse_index())->Equals(*si));
+
+ // invalid data type
+ auto res = SparseCOOTensor::Make(si, binary(), sparse_data, this->shape_, {});
+ ASSERT_RAISES(Invalid, res);
+
+ // negative items in shape
+ res = SparseCOOTensor::Make(si, int64(), sparse_data, {2, -3, 4}, {});
+ ASSERT_RAISES(Invalid, res);
+
+ // sparse index and ndim (shape length) are inconsistent
+ res = SparseCOOTensor::Make(si, int64(), sparse_data, {6, 4}, {});
+ ASSERT_RAISES(Invalid, res);
+
+ // shape and dim_names are inconsistent
+ res = SparseCOOTensor::Make(si, int64(), sparse_data, this->shape_,
+ std::vector<std::string>{"foo"});
+ ASSERT_RAISES(Invalid, res);
+}
+
+TYPED_TEST_P(TestSparseCOOTensorForIndexValueType, CreationWithRowMajorIndex) {
+ using IndexValueType = TypeParam;
+ using c_index_value_type = typename IndexValueType::c_type;
+
+ constexpr int sizeof_index_value = sizeof(c_index_value_type);
+ ASSERT_OK_AND_ASSIGN(
+ std::shared_ptr<SparseCOOIndex> si,
+ this->MakeSparseCOOIndex({12, 3}, {sizeof_index_value * 3, sizeof_index_value},
+ this->coords_values_row_major_));
+
+ std::vector<int64_t> sparse_values = {1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16};
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<SparseCOOTensor> st,
+ this->MakeSparseTensor(si, sparse_values));
+
+ ASSERT_EQ(std::vector<std::string>({"foo", "bar", "baz"}), st->dim_names());
+ ASSERT_EQ("foo", st->dim_name(0));
+ ASSERT_EQ("bar", st->dim_name(1));
+ ASSERT_EQ("baz", st->dim_name(2));
+
+ ASSERT_TRUE(st->Equals(*this->sparse_tensor_from_dense_));
+}
+
+TYPED_TEST_P(TestSparseCOOTensorForIndexValueType, CreationWithColumnMajorIndex) {
+ using IndexValueType = TypeParam;
+ using c_index_value_type = typename IndexValueType::c_type;
+
+ constexpr int sizeof_index_value = sizeof(c_index_value_type);
+ ASSERT_OK_AND_ASSIGN(
+ std::shared_ptr<SparseCOOIndex> si,
+ this->MakeSparseCOOIndex({12, 3}, {sizeof_index_value, sizeof_index_value * 12},
+ this->coords_values_col_major_));
+
+ std::vector<int64_t> sparse_values = {1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16};
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<SparseCOOTensor> st,
+ this->MakeSparseTensor(si, sparse_values));
+
+ ASSERT_EQ(std::vector<std::string>({"foo", "bar", "baz"}), st->dim_names());
+ ASSERT_EQ("foo", st->dim_name(0));
+ ASSERT_EQ("bar", st->dim_name(1));
+ ASSERT_EQ("baz", st->dim_name(2));
+
+ ASSERT_TRUE(st->Equals(*this->sparse_tensor_from_dense_));
+}
+
+TYPED_TEST_P(TestSparseCOOTensorForIndexValueType,
+ EqualityBetweenRowAndColumnMajorIndices) {
+ using IndexValueType = TypeParam;
+ using c_index_value_type = typename IndexValueType::c_type;
+
+ // Row-major COO index
+ const std::vector<int64_t> coords_shape = {12, 3};
+ constexpr int sizeof_index_value = sizeof(c_index_value_type);
+ ASSERT_OK_AND_ASSIGN(
+ std::shared_ptr<SparseCOOIndex> si_row_major,
+ this->MakeSparseCOOIndex(coords_shape, {sizeof_index_value * 3, sizeof_index_value},
+ this->coords_values_row_major_));
+
+ // Column-major COO index
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<SparseCOOIndex> si_col_major,
+ this->MakeSparseCOOIndex(
+ coords_shape, {sizeof_index_value, sizeof_index_value * 12},
+ this->coords_values_col_major_));
+
+ std::vector<int64_t> sparse_values_1 = {1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16};
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<SparseCOOTensor> st1,
+ this->MakeSparseTensor(si_row_major, sparse_values_1));
+
+ std::vector<int64_t> sparse_values_2 = sparse_values_1;
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<SparseCOOTensor> st2,
+ this->MakeSparseTensor(si_row_major, sparse_values_2));
+
+ ASSERT_TRUE(st2->Equals(*st1));
+}
+
+REGISTER_TYPED_TEST_SUITE_P(TestSparseCOOTensorForIndexValueType, Make,
+ CreationWithRowMajorIndex, CreationWithColumnMajorIndex,
+ EqualityBetweenRowAndColumnMajorIndices);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt8, TestSparseCOOTensorForIndexValueType, Int8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt8, TestSparseCOOTensorForIndexValueType,
+ UInt8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt16, TestSparseCOOTensorForIndexValueType,
+ Int16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt16, TestSparseCOOTensorForIndexValueType,
+ UInt16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt32, TestSparseCOOTensorForIndexValueType,
+ Int32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt32, TestSparseCOOTensorForIndexValueType,
+ UInt32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt64, TestSparseCOOTensorForIndexValueType,
+ Int64Type);
+
+TEST(TestSparseCOOTensorForUInt64Index, Make) {
+ std::vector<int64_t> dense_values = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
+ 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};
+ Tensor dense_tensor(uint64(), Buffer::Wrap(dense_values), {2, 3, 4});
+ ASSERT_RAISES(Invalid, SparseCOOTensor::Make(dense_tensor, uint64()));
+}
+
+template <typename IndexValueType>
+class TestSparseCSRMatrixBase : public TestSparseTensorBase<Int64Type> {
+ public:
+ void SetUp() {
+ shape_ = {6, 4};
+ dim_names_ = {"foo", "bar"};
+
+ // Dense representation:
+ // [
+ // 1 0 2 0
+ // 0 3 0 4
+ // 5 0 6 0
+ // 0 11 0 12
+ // 13 0 14 0
+ // 0 15 0 16
+ // ]
+ dense_values_ = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
+ 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};
+ auto dense_data = Buffer::Wrap(dense_values_);
+ NumericTensor<Int64Type> dense_tensor(dense_data, shape_, {}, dim_names_);
+ ASSERT_OK_AND_ASSIGN(sparse_tensor_from_dense_,
+ SparseCSRMatrix::Make(
+ dense_tensor, TypeTraits<IndexValueType>::type_singleton()));
+ }
+
+ protected:
+ std::vector<int64_t> dense_values_;
+ std::shared_ptr<SparseCSRMatrix> sparse_tensor_from_dense_;
+};
+
+class TestSparseCSRMatrix : public TestSparseCSRMatrixBase<Int64Type> {};
+
+TEST_F(TestSparseCSRMatrix, CreationFromZeroTensor) {
+ const auto dense_size =
+ std::accumulate(this->shape_.begin(), this->shape_.end(), int64_t(1),
+ [](int64_t a, int64_t x) { return a * x; });
+ std::vector<int64_t> dense_values(dense_size, 0);
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Tensor> t_zero,
+ Tensor::Make(int64(), Buffer::Wrap(dense_values), this->shape_));
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<SparseCSRMatrix> st_zero,
+ SparseCSRMatrix::Make(*t_zero, int64()));
+
+ ASSERT_EQ(0, st_zero->non_zero_length());
+ ASSERT_EQ(dense_size, st_zero->size());
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Tensor> t, st_zero->ToTensor());
+ ASSERT_TRUE(t->Equals(*t_zero));
+}
+
+TEST_F(TestSparseCSRMatrix, CreationFromNumericTensor2D) {
+ std::shared_ptr<Buffer> buffer = Buffer::Wrap(this->dense_values_);
+ NumericTensor<Int64Type> tensor(buffer, this->shape_);
+
+ std::shared_ptr<SparseCSRMatrix> st1;
+ ASSERT_OK_AND_ASSIGN(st1, SparseCSRMatrix::Make(tensor));
+
+ auto st2 = this->sparse_tensor_from_dense_;
+
+ CheckSparseIndexFormatType(SparseTensorFormat::CSR, *st1);
+
+ ASSERT_EQ(12, st1->non_zero_length());
+ ASSERT_TRUE(st1->is_mutable());
+
+ ASSERT_EQ(std::vector<std::string>({"foo", "bar"}), st2->dim_names());
+ ASSERT_EQ("foo", st2->dim_name(0));
+ ASSERT_EQ("bar", st2->dim_name(1));
+
+ ASSERT_EQ(std::vector<std::string>({}), st1->dim_names());
+ ASSERT_EQ("", st1->dim_name(0));
+ ASSERT_EQ("", st1->dim_name(1));
+ ASSERT_EQ("", st1->dim_name(2));
+
+ const int64_t* raw_data = reinterpret_cast<const int64_t*>(st1->raw_data());
+ AssertNumericDataEqual(raw_data, {1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16});
+
+ auto si = internal::checked_pointer_cast<SparseCSRIndex>(st1->sparse_index());
+ ASSERT_EQ(std::string("SparseCSRIndex"), si->ToString());
+ ASSERT_EQ(1, si->indptr()->ndim());
+ ASSERT_EQ(1, si->indices()->ndim());
+
+ const int64_t* indptr_begin =
+ reinterpret_cast<const int64_t*>(si->indptr()->raw_data());
+ std::vector<int64_t> indptr_values(indptr_begin,
+ indptr_begin + si->indptr()->shape()[0]);
+
+ ASSERT_EQ(7, indptr_values.size());
+ ASSERT_EQ(std::vector<int64_t>({0, 2, 4, 6, 8, 10, 12}), indptr_values);
+
+ const int64_t* indices_begin =
+ reinterpret_cast<const int64_t*>(si->indices()->raw_data());
+ std::vector<int64_t> indices_values(indices_begin,
+ indices_begin + si->indices()->shape()[0]);
+
+ ASSERT_EQ(12, indices_values.size());
+ ASSERT_EQ(std::vector<int64_t>({0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3}), indices_values);
+}
+
+TEST_F(TestSparseCSRMatrix, CreationFromNonContiguousTensor) {
+ std::vector<int64_t> values = {1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 3, 0, 0, 0, 4, 0,
+ 5, 0, 0, 0, 6, 0, 0, 0, 0, 0, 11, 0, 0, 0, 12, 0,
+ 13, 0, 0, 0, 14, 0, 0, 0, 0, 0, 15, 0, 0, 0, 16, 0};
+ std::vector<int64_t> strides = {64, 16};
+ std::shared_ptr<Buffer> buffer = Buffer::Wrap(values);
+ Tensor tensor(int64(), buffer, this->shape_, strides);
+
+ std::shared_ptr<SparseCSRMatrix> st;
+ ASSERT_OK_AND_ASSIGN(st, SparseCSRMatrix::Make(tensor));
+
+ ASSERT_EQ(12, st->non_zero_length());
+ ASSERT_TRUE(st->is_mutable());
+
+ const int64_t* raw_data = reinterpret_cast<const int64_t*>(st->raw_data());
+ AssertNumericDataEqual(raw_data, {1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16});
+
+ auto si = internal::checked_pointer_cast<SparseCSRIndex>(st->sparse_index());
+ ASSERT_EQ(1, si->indptr()->ndim());
+ ASSERT_EQ(1, si->indices()->ndim());
+
+ const int64_t* indptr_begin =
+ reinterpret_cast<const int64_t*>(si->indptr()->raw_data());
+ std::vector<int64_t> indptr_values(indptr_begin,
+ indptr_begin + si->indptr()->shape()[0]);
+
+ ASSERT_EQ(7, indptr_values.size());
+ ASSERT_EQ(std::vector<int64_t>({0, 2, 4, 6, 8, 10, 12}), indptr_values);
+
+ const int64_t* indices_begin =
+ reinterpret_cast<const int64_t*>(si->indices()->raw_data());
+ std::vector<int64_t> indices_values(indices_begin,
+ indices_begin + si->indices()->shape()[0]);
+
+ ASSERT_EQ(12, indices_values.size());
+ ASSERT_EQ(std::vector<int64_t>({0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3}), indices_values);
+
+ ASSERT_TRUE(st->Equals(*this->sparse_tensor_from_dense_));
+}
+
+TEST_F(TestSparseCSRMatrix, TestToTensor) {
+ std::vector<int64_t> values = {1, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 1,
+ 0, 2, 0, 0, 0, 0, 0, 3, 0, 0, 0, 1};
+ std::vector<int64_t> shape({6, 4});
+ std::shared_ptr<Buffer> buffer = Buffer::Wrap(values);
+ Tensor tensor(int64(), buffer, shape, {}, this->dim_names_);
+
+ std::shared_ptr<SparseCSRMatrix> sparse_tensor;
+ ASSERT_OK_AND_ASSIGN(sparse_tensor, SparseCSRMatrix::Make(tensor));
+
+ ASSERT_EQ(7, sparse_tensor->non_zero_length());
+ ASSERT_TRUE(sparse_tensor->is_mutable());
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Tensor> dense_tensor, sparse_tensor->ToTensor());
+ ASSERT_TRUE(tensor.Equals(*dense_tensor));
+}
+
+template <typename ValueType>
+class TestSparseCSRMatrixEquality : public TestSparseTensorBase<ValueType> {
+ public:
+ void SetUp() {
+ shape_ = {6, 4};
+ values1_ = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
+ 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};
+ values2_ = {9, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
+ 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};
+ auto buffer1 = Buffer::Wrap(values1_);
+ auto buffer2 = Buffer::Wrap(values2_);
+ DCHECK_OK(NumericTensor<ValueType>::Make(buffer1, this->shape_).Value(&tensor1_));
+ DCHECK_OK(NumericTensor<ValueType>::Make(buffer2, this->shape_).Value(&tensor2_));
+ }
+
+ protected:
+ using TestSparseTensorBase<ValueType>::shape_;
+ std::vector<typename ValueType::c_type> values1_;
+ std::vector<typename ValueType::c_type> values2_;
+ std::shared_ptr<NumericTensor<ValueType>> tensor1_;
+ std::shared_ptr<NumericTensor<ValueType>> tensor2_;
+};
+
+template <typename ValueType>
+class TestIntegerSparseCSRMatrixEquality : public TestSparseCSRMatrixEquality<ValueType> {
+};
+
+TYPED_TEST_SUITE_P(TestIntegerSparseCSRMatrixEquality);
+
+TYPED_TEST_P(TestIntegerSparseCSRMatrixEquality, TestEquality) {
+ using ValueType = TypeParam;
+ static_assert(is_integer_type<ValueType>::value, "Integer type is required");
+
+ std::shared_ptr<SparseCSRMatrix> st1, st2, st3;
+ ASSERT_OK_AND_ASSIGN(st1, SparseCSRMatrix::Make(*this->tensor1_));
+ ASSERT_OK_AND_ASSIGN(st2, SparseCSRMatrix::Make(*this->tensor2_));
+ ASSERT_OK_AND_ASSIGN(st3, SparseCSRMatrix::Make(*this->tensor1_));
+
+ ASSERT_TRUE(st1->Equals(*st1));
+ ASSERT_FALSE(st1->Equals(*st2));
+ ASSERT_TRUE(st1->Equals(*st3));
+}
+
+REGISTER_TYPED_TEST_SUITE_P(TestIntegerSparseCSRMatrixEquality, TestEquality);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt8, TestIntegerSparseCSRMatrixEquality, Int8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt8, TestIntegerSparseCSRMatrixEquality, UInt8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt16, TestIntegerSparseCSRMatrixEquality, Int16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt16, TestIntegerSparseCSRMatrixEquality,
+ UInt16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt32, TestIntegerSparseCSRMatrixEquality, Int32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt32, TestIntegerSparseCSRMatrixEquality,
+ UInt32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt64, TestIntegerSparseCSRMatrixEquality, Int64Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt64, TestIntegerSparseCSRMatrixEquality,
+ UInt64Type);
+
+template <typename ValueType>
+class TestFloatingSparseCSRMatrixEquality
+ : public TestSparseCSRMatrixEquality<ValueType> {};
+
+TYPED_TEST_SUITE_P(TestFloatingSparseCSRMatrixEquality);
+
+TYPED_TEST_P(TestFloatingSparseCSRMatrixEquality, TestEquality) {
+ using ValueType = TypeParam;
+ using c_value_type = typename ValueType::c_type;
+ static_assert(is_floating_type<ValueType>::value, "Float type is required");
+
+ std::shared_ptr<SparseCSRMatrix> st1, st2, st3;
+ ASSERT_OK_AND_ASSIGN(st1, SparseCSRMatrix::Make(*this->tensor1_));
+ ASSERT_OK_AND_ASSIGN(st2, SparseCSRMatrix::Make(*this->tensor2_));
+ ASSERT_OK_AND_ASSIGN(st3, SparseCSRMatrix::Make(*this->tensor1_));
+
+ ASSERT_TRUE(st1->Equals(*st1));
+ ASSERT_FALSE(st1->Equals(*st2));
+ ASSERT_TRUE(st1->Equals(*st3));
+
+ // sparse tensors with NaNs
+ const c_value_type nan_value = static_cast<c_value_type>(NAN);
+ this->values2_[13] = nan_value;
+ EXPECT_TRUE(std::isnan(this->tensor2_->Value({3, 1})));
+
+ std::shared_ptr<SparseCSRMatrix> st4;
+ ASSERT_OK_AND_ASSIGN(st4, SparseCSRMatrix::Make(*this->tensor2_));
+ EXPECT_FALSE(st4->Equals(*st4)); // same object
+ EXPECT_TRUE(st4->Equals(*st4, EqualOptions().nans_equal(true))); // same object
+
+ std::vector<c_value_type> values5 = this->values2_;
+ std::shared_ptr<SparseCSRMatrix> st5;
+ std::shared_ptr<Buffer> buffer5 = Buffer::Wrap(values5);
+ NumericTensor<ValueType> tensor5(buffer5, this->shape_);
+ ASSERT_OK_AND_ASSIGN(st5, SparseCSRMatrix::Make(tensor5));
+ EXPECT_FALSE(st4->Equals(*st5)); // different memory
+ EXPECT_TRUE(st4->Equals(*st5, EqualOptions().nans_equal(true))); // different memory
+}
+
+REGISTER_TYPED_TEST_SUITE_P(TestFloatingSparseCSRMatrixEquality, TestEquality);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(TestFloat, TestFloatingSparseCSRMatrixEquality, FloatType);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestDouble, TestFloatingSparseCSRMatrixEquality,
+ DoubleType);
+
+template <typename IndexValueType>
+class TestSparseCSRMatrixForIndexValueType
+ : public TestSparseCSRMatrixBase<IndexValueType> {};
+
+TYPED_TEST_SUITE_P(TestSparseCSRMatrixForIndexValueType);
+
+TYPED_TEST_P(TestSparseCSRMatrixForIndexValueType, Make) {
+ using IndexValueType = TypeParam;
+ using c_index_value_type = typename IndexValueType::c_type;
+
+ // Sparse representation:
+ std::vector<c_index_value_type> indptr_values = {0, 2, 4, 6, 8, 10, 12};
+ std::vector<c_index_value_type> indices_values = {0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3};
+
+ std::vector<int64_t> indptr_shape = {7};
+ std::vector<int64_t> indices_shape = {12};
+
+ std::shared_ptr<SparseCSRIndex> si;
+ ASSERT_OK_AND_ASSIGN(
+ si, SparseCSRIndex::Make(TypeTraits<IndexValueType>::type_singleton(), indptr_shape,
+ indices_shape, Buffer::Wrap(indptr_values),
+ Buffer::Wrap(indices_values)));
+
+ std::vector<int64_t> sparse_values = {1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16};
+ auto sparse_data = Buffer::Wrap(sparse_values);
+
+ std::shared_ptr<SparseCSRMatrix> sm;
+
+ // OK
+ ASSERT_OK(
+ SparseCSRMatrix::Make(si, int64(), sparse_data, this->shape_, this->dim_names_));
+
+ // OK with an empty dim_names
+ ASSERT_OK(SparseCSRMatrix::Make(si, int64(), sparse_data, this->shape_, {}));
+
+ // invalid data type
+ ASSERT_RAISES(Invalid,
+ SparseCSRMatrix::Make(si, binary(), sparse_data, this->shape_, {}));
+
+ // empty shape
+ ASSERT_RAISES(Invalid, SparseCSRMatrix::Make(si, int64(), sparse_data, {}, {}));
+
+ // 1D shape
+ ASSERT_RAISES(Invalid, SparseCSRMatrix::Make(si, int64(), sparse_data, {24}, {}));
+
+ // negative items in shape
+ ASSERT_RAISES(Invalid, SparseCSRMatrix::Make(si, int64(), sparse_data, {6, -4}, {}));
+
+ // sparse index and ndim (shape length) are inconsistent
+ ASSERT_RAISES(Invalid, SparseCSRMatrix::Make(si, int64(), sparse_data, {4, 6}, {}));
+
+ // shape and dim_names are inconsistent
+ ASSERT_RAISES(Invalid, SparseCSRMatrix::Make(si, int64(), sparse_data, this->shape_,
+ std::vector<std::string>{"foo"}));
+}
+
+REGISTER_TYPED_TEST_SUITE_P(TestSparseCSRMatrixForIndexValueType, Make);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt8, TestSparseCSRMatrixForIndexValueType, Int8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt8, TestSparseCSRMatrixForIndexValueType,
+ UInt8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt16, TestSparseCSRMatrixForIndexValueType,
+ Int16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt16, TestSparseCSRMatrixForIndexValueType,
+ UInt16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt32, TestSparseCSRMatrixForIndexValueType,
+ Int32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt32, TestSparseCSRMatrixForIndexValueType,
+ UInt32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt64, TestSparseCSRMatrixForIndexValueType,
+ Int64Type);
+
+TEST(TestSparseCSRMatrixForUInt64Index, Make) {
+ std::vector<int64_t> dense_values = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
+ 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};
+ Tensor dense_tensor(uint64(), Buffer::Wrap(dense_values), {6, 4});
+ ASSERT_RAISES(Invalid, SparseCSRMatrix::Make(dense_tensor, uint64()));
+}
+
+template <typename IndexValueType>
+class TestSparseCSCMatrixBase : public TestSparseTensorBase<Int64Type> {
+ public:
+ void SetUp() {
+ shape_ = {6, 4};
+ dim_names_ = {"foo", "bar"};
+
+ // Dense representation:
+ // [
+ // 1 0 2 0
+ // 0 3 0 4
+ // 5 0 6 0
+ // 0 11 0 12
+ // 13 0 14 0
+ // 0 15 0 16
+ // ]
+ dense_values_ = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
+ 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};
+ auto dense_data = Buffer::Wrap(dense_values_);
+ NumericTensor<Int64Type> dense_tensor(dense_data, shape_, {}, dim_names_);
+ ASSERT_OK_AND_ASSIGN(sparse_tensor_from_dense_,
+ SparseCSCMatrix::Make(
+ dense_tensor, TypeTraits<IndexValueType>::type_singleton()));
+ }
+
+ protected:
+ std::vector<int64_t> dense_values_;
+ std::shared_ptr<SparseCSCMatrix> sparse_tensor_from_dense_;
+};
+
+class TestSparseCSCMatrix : public TestSparseCSCMatrixBase<Int64Type> {};
+
+TEST_F(TestSparseCSCMatrix, CreationFromZeroTensor) {
+ const auto dense_size =
+ std::accumulate(this->shape_.begin(), this->shape_.end(), int64_t(1),
+ [](int64_t a, int64_t x) { return a * x; });
+ std::vector<int64_t> dense_values(dense_size, 0);
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Tensor> t_zero,
+ Tensor::Make(int64(), Buffer::Wrap(dense_values), this->shape_));
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<SparseCSCMatrix> st_zero,
+ SparseCSCMatrix::Make(*t_zero, int64()));
+
+ ASSERT_EQ(0, st_zero->non_zero_length());
+ ASSERT_EQ(dense_size, st_zero->size());
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Tensor> t, st_zero->ToTensor());
+ ASSERT_TRUE(t->Equals(*t_zero));
+}
+
+TEST_F(TestSparseCSCMatrix, CreationFromNumericTensor2D) {
+ std::shared_ptr<Buffer> buffer = Buffer::Wrap(this->dense_values_);
+ NumericTensor<Int64Type> tensor(buffer, this->shape_);
+
+ std::shared_ptr<SparseCSCMatrix> st1;
+ ASSERT_OK_AND_ASSIGN(st1, SparseCSCMatrix::Make(tensor));
+
+ auto st2 = this->sparse_tensor_from_dense_;
+
+ CheckSparseIndexFormatType(SparseTensorFormat::CSC, *st1);
+
+ ASSERT_EQ(12, st1->non_zero_length());
+ ASSERT_TRUE(st1->is_mutable());
+
+ ASSERT_EQ(std::vector<std::string>({"foo", "bar"}), st2->dim_names());
+ ASSERT_EQ("foo", st2->dim_name(0));
+ ASSERT_EQ("bar", st2->dim_name(1));
+
+ ASSERT_EQ(std::vector<std::string>({}), st1->dim_names());
+ ASSERT_EQ("", st1->dim_name(0));
+ ASSERT_EQ("", st1->dim_name(1));
+ ASSERT_EQ("", st1->dim_name(2));
+
+ const int64_t* raw_data = reinterpret_cast<const int64_t*>(st1->raw_data());
+ AssertNumericDataEqual(raw_data, {1, 5, 13, 3, 11, 15, 2, 6, 14, 4, 12, 16});
+
+ auto si = internal::checked_pointer_cast<SparseCSCIndex>(st1->sparse_index());
+ ASSERT_EQ(std::string("SparseCSCIndex"), si->ToString());
+ ASSERT_EQ(1, si->indptr()->ndim());
+ ASSERT_EQ(1, si->indices()->ndim());
+
+ const int64_t* indptr_begin =
+ reinterpret_cast<const int64_t*>(si->indptr()->raw_data());
+ std::vector<int64_t> indptr_values(indptr_begin,
+ indptr_begin + si->indptr()->shape()[0]);
+
+ ASSERT_EQ(5, indptr_values.size());
+ ASSERT_EQ(std::vector<int64_t>({0, 3, 6, 9, 12}), indptr_values);
+
+ const int64_t* indices_begin =
+ reinterpret_cast<const int64_t*>(si->indices()->raw_data());
+ std::vector<int64_t> indices_values(indices_begin,
+ indices_begin + si->indices()->shape()[0]);
+
+ ASSERT_EQ(12, indices_values.size());
+ ASSERT_EQ(std::vector<int64_t>({0, 2, 4, 1, 3, 5, 0, 2, 4, 1, 3, 5}), indices_values);
+}
+
+TEST_F(TestSparseCSCMatrix, CreationFromNonContiguousTensor) {
+ std::vector<int64_t> values = {1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 3, 0, 0, 0, 4, 0,
+ 5, 0, 0, 0, 6, 0, 0, 0, 0, 0, 11, 0, 0, 0, 12, 0,
+ 13, 0, 0, 0, 14, 0, 0, 0, 0, 0, 15, 0, 0, 0, 16, 0};
+ std::vector<int64_t> strides = {64, 16};
+ std::shared_ptr<Buffer> buffer = Buffer::Wrap(values);
+ Tensor tensor(int64(), buffer, this->shape_, strides);
+
+ std::shared_ptr<SparseCSCMatrix> st;
+ ASSERT_OK_AND_ASSIGN(st, SparseCSCMatrix::Make(tensor));
+
+ ASSERT_EQ(12, st->non_zero_length());
+ ASSERT_TRUE(st->is_mutable());
+
+ const int64_t* raw_data = reinterpret_cast<const int64_t*>(st->raw_data());
+ AssertNumericDataEqual(raw_data, {1, 5, 13, 3, 11, 15, 2, 6, 14, 4, 12, 16});
+
+ auto si = internal::checked_pointer_cast<SparseCSCIndex>(st->sparse_index());
+ ASSERT_EQ(1, si->indptr()->ndim());
+ ASSERT_EQ(1, si->indices()->ndim());
+
+ const int64_t* indptr_begin =
+ reinterpret_cast<const int64_t*>(si->indptr()->raw_data());
+ std::vector<int64_t> indptr_values(indptr_begin,
+ indptr_begin + si->indptr()->shape()[0]);
+
+ ASSERT_EQ(5, indptr_values.size());
+ ASSERT_EQ(std::vector<int64_t>({0, 3, 6, 9, 12}), indptr_values);
+
+ const int64_t* indices_begin =
+ reinterpret_cast<const int64_t*>(si->indices()->raw_data());
+ std::vector<int64_t> indices_values(indices_begin,
+ indices_begin + si->indices()->shape()[0]);
+
+ ASSERT_EQ(12, indices_values.size());
+ ASSERT_EQ(std::vector<int64_t>({0, 2, 4, 1, 3, 5, 0, 2, 4, 1, 3, 5}), indices_values);
+
+ ASSERT_TRUE(st->Equals(*this->sparse_tensor_from_dense_));
+}
+
+TEST_F(TestSparseCSCMatrix, TestToTensor) {
+ std::vector<int64_t> values = {1, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 1,
+ 0, 2, 0, 0, 0, 0, 0, 3, 0, 0, 0, 1};
+ std::vector<int64_t> shape({6, 4});
+ std::shared_ptr<Buffer> buffer = Buffer::Wrap(values);
+ Tensor tensor(int64(), buffer, shape, {}, this->dim_names_);
+
+ std::shared_ptr<SparseCSCMatrix> sparse_tensor;
+ ASSERT_OK_AND_ASSIGN(sparse_tensor, SparseCSCMatrix::Make(tensor));
+
+ ASSERT_EQ(7, sparse_tensor->non_zero_length());
+ ASSERT_TRUE(sparse_tensor->is_mutable());
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Tensor> dense_tensor, sparse_tensor->ToTensor());
+ ASSERT_TRUE(tensor.Equals(*dense_tensor));
+}
+
+template <typename ValueType>
+class TestSparseCSCMatrixEquality : public TestSparseTensorBase<ValueType> {
+ public:
+ void SetUp() {
+ shape_ = {6, 4};
+ values1_ = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
+ 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};
+ values2_ = {9, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
+ 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};
+ auto buffer1 = Buffer::Wrap(values1_);
+ auto buffer2 = Buffer::Wrap(values2_);
+ DCHECK_OK(NumericTensor<ValueType>::Make(buffer1, shape_).Value(&tensor1_));
+ DCHECK_OK(NumericTensor<ValueType>::Make(buffer2, shape_).Value(&tensor2_));
+ }
+
+ protected:
+ using TestSparseTensorBase<ValueType>::shape_;
+ std::vector<typename ValueType::c_type> values1_;
+ std::vector<typename ValueType::c_type> values2_;
+ std::shared_ptr<NumericTensor<ValueType>> tensor1_;
+ std::shared_ptr<NumericTensor<ValueType>> tensor2_;
+};
+
+template <typename ValueType>
+class TestIntegerSparseCSCMatrixEquality : public TestSparseCSCMatrixEquality<ValueType> {
+};
+
+TYPED_TEST_SUITE_P(TestIntegerSparseCSCMatrixEquality);
+
+TYPED_TEST_P(TestIntegerSparseCSCMatrixEquality, TestEquality) {
+ using ValueType = TypeParam;
+ static_assert(is_integer_type<ValueType>::value, "Integer type is required");
+
+ std::shared_ptr<SparseCSCMatrix> st1, st2, st3;
+ ASSERT_OK_AND_ASSIGN(st1, SparseCSCMatrix::Make(*this->tensor1_));
+ ASSERT_OK_AND_ASSIGN(st2, SparseCSCMatrix::Make(*this->tensor2_));
+ ASSERT_OK_AND_ASSIGN(st3, SparseCSCMatrix::Make(*this->tensor1_));
+
+ ASSERT_TRUE(st1->Equals(*st1));
+ ASSERT_FALSE(st1->Equals(*st2));
+ ASSERT_TRUE(st1->Equals(*st3));
+}
+
+REGISTER_TYPED_TEST_SUITE_P(TestIntegerSparseCSCMatrixEquality, TestEquality);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt8, TestIntegerSparseCSCMatrixEquality, Int8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt8, TestIntegerSparseCSCMatrixEquality, UInt8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt16, TestIntegerSparseCSCMatrixEquality, Int16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt16, TestIntegerSparseCSCMatrixEquality,
+ UInt16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt32, TestIntegerSparseCSCMatrixEquality, Int32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt32, TestIntegerSparseCSCMatrixEquality,
+ UInt32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt64, TestIntegerSparseCSCMatrixEquality, Int64Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt64, TestIntegerSparseCSCMatrixEquality,
+ UInt64Type);
+
+template <typename ValueType>
+class TestFloatingSparseCSCMatrixEquality
+ : public TestSparseCSCMatrixEquality<ValueType> {};
+
+TYPED_TEST_SUITE_P(TestFloatingSparseCSCMatrixEquality);
+
+TYPED_TEST_P(TestFloatingSparseCSCMatrixEquality, TestEquality) {
+ using ValueType = TypeParam;
+ using c_value_type = typename ValueType::c_type;
+ static_assert(is_floating_type<ValueType>::value, "Float type is required");
+
+ std::shared_ptr<SparseCSCMatrix> st1, st2, st3;
+ ASSERT_OK_AND_ASSIGN(st1, SparseCSCMatrix::Make(*this->tensor1_));
+ ASSERT_OK_AND_ASSIGN(st2, SparseCSCMatrix::Make(*this->tensor2_));
+ ASSERT_OK_AND_ASSIGN(st3, SparseCSCMatrix::Make(*this->tensor1_));
+
+ ASSERT_TRUE(st1->Equals(*st1));
+ ASSERT_FALSE(st1->Equals(*st2));
+ ASSERT_TRUE(st1->Equals(*st3));
+
+ // sparse tensors with NaNs
+ const c_value_type nan_value = static_cast<c_value_type>(NAN);
+ this->values2_[13] = nan_value;
+ EXPECT_TRUE(std::isnan(this->tensor2_->Value({3, 1})));
+
+ std::shared_ptr<SparseCSCMatrix> st4;
+ ASSERT_OK_AND_ASSIGN(st4, SparseCSCMatrix::Make(*this->tensor2_));
+ EXPECT_FALSE(st4->Equals(*st4)); // same object
+ EXPECT_TRUE(st4->Equals(*st4, EqualOptions().nans_equal(true))); // same object
+
+ std::vector<c_value_type> values5 = this->values2_;
+ std::shared_ptr<SparseCSCMatrix> st5;
+ std::shared_ptr<Buffer> buffer5 = Buffer::Wrap(values5);
+ NumericTensor<ValueType> tensor5(buffer5, this->shape_);
+ ASSERT_OK_AND_ASSIGN(st5, SparseCSCMatrix::Make(tensor5));
+ EXPECT_FALSE(st4->Equals(*st5)); // different memory
+ EXPECT_TRUE(st4->Equals(*st5, EqualOptions().nans_equal(true))); // different memory
+}
+
+REGISTER_TYPED_TEST_SUITE_P(TestFloatingSparseCSCMatrixEquality, TestEquality);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(TestFloat, TestFloatingSparseCSCMatrixEquality, FloatType);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestDouble, TestFloatingSparseCSCMatrixEquality,
+ DoubleType);
+
+template <typename ValueType>
+class TestSparseCSFTensorEquality : public TestSparseTensorBase<ValueType> {
+ public:
+ void SetUp() {
+ shape_ = {2, 3, 4, 5};
+
+ values1_[0][0][0][1] = 1;
+ values1_[0][0][0][2] = 2;
+ values1_[0][1][0][0] = 3;
+ values1_[0][1][0][2] = 4;
+ values1_[0][1][1][0] = 5;
+ values1_[1][1][1][0] = 6;
+ values1_[1][1][1][1] = 7;
+ values1_[1][1][1][2] = 8;
+
+ length_ = sizeof(values1_);
+
+ values2_[0][0][0][1] = 1;
+ values2_[0][0][0][2] = 2;
+ values2_[0][1][0][0] = 3;
+ values2_[0][1][0][2] = 9;
+ values2_[0][1][1][0] = 5;
+ values2_[1][1][1][0] = 6;
+ values2_[1][1][1][1] = 7;
+ values2_[1][1][1][2] = 8;
+
+ auto buffer1 = Buffer::Wrap(values1_, length_);
+ auto buffer2 = Buffer::Wrap(values2_, length_);
+
+ DCHECK_OK(NumericTensor<ValueType>::Make(buffer1, shape_).Value(&tensor1_));
+ DCHECK_OK(NumericTensor<ValueType>::Make(buffer2, shape_).Value(&tensor2_));
+ }
+
+ protected:
+ using TestSparseTensorBase<ValueType>::shape_;
+ typename ValueType::c_type values1_[2][3][4][5] = {};
+ typename ValueType::c_type values2_[2][3][4][5] = {};
+ int64_t length_;
+ std::shared_ptr<NumericTensor<ValueType>> tensor1_;
+ std::shared_ptr<NumericTensor<ValueType>> tensor2_;
+};
+
+template <typename ValueType>
+class TestIntegerSparseCSFTensorEquality : public TestSparseCSFTensorEquality<ValueType> {
+};
+
+TYPED_TEST_SUITE_P(TestIntegerSparseCSFTensorEquality);
+
+TYPED_TEST_P(TestIntegerSparseCSFTensorEquality, TestEquality) {
+ using ValueType = TypeParam;
+ static_assert(is_integer_type<ValueType>::value, "Integer type is required");
+
+ std::shared_ptr<SparseCSFTensor> st1, st2, st3;
+ ASSERT_OK_AND_ASSIGN(st1, SparseCSFTensor::Make(*this->tensor1_));
+ ASSERT_OK_AND_ASSIGN(st2, SparseCSFTensor::Make(*this->tensor2_));
+ ASSERT_OK_AND_ASSIGN(st3, SparseCSFTensor::Make(*this->tensor1_));
+
+ ASSERT_TRUE(st1->Equals(*st1));
+ ASSERT_FALSE(st1->Equals(*st2));
+ ASSERT_TRUE(st1->Equals(*st3));
+}
+
+REGISTER_TYPED_TEST_SUITE_P(TestIntegerSparseCSFTensorEquality, TestEquality);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt8, TestIntegerSparseCSFTensorEquality, Int8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt8, TestIntegerSparseCSFTensorEquality, UInt8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt16, TestIntegerSparseCSFTensorEquality, Int16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt16, TestIntegerSparseCSFTensorEquality,
+ UInt16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt32, TestIntegerSparseCSFTensorEquality, Int32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt32, TestIntegerSparseCSFTensorEquality,
+ UInt32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt64, TestIntegerSparseCSFTensorEquality, Int64Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt64, TestIntegerSparseCSFTensorEquality,
+ UInt64Type);
+
+template <typename ValueType>
+class TestFloatingSparseCSFTensorEquality
+ : public TestSparseCSFTensorEquality<ValueType> {};
+
+TYPED_TEST_SUITE_P(TestFloatingSparseCSFTensorEquality);
+
+TYPED_TEST_P(TestFloatingSparseCSFTensorEquality, TestEquality) {
+ using ValueType = TypeParam;
+ using c_value_type = typename ValueType::c_type;
+ static_assert(is_floating_type<ValueType>::value, "Floating type is required");
+
+ std::shared_ptr<SparseCSFTensor> st1, st2, st3;
+ ASSERT_OK_AND_ASSIGN(st1, SparseCSFTensor::Make(*this->tensor1_));
+ ASSERT_OK_AND_ASSIGN(st2, SparseCSFTensor::Make(*this->tensor2_));
+ ASSERT_OK_AND_ASSIGN(st3, SparseCSFTensor::Make(*this->tensor1_));
+
+ ASSERT_TRUE(st1->Equals(*st1));
+ ASSERT_FALSE(st1->Equals(*st2));
+ ASSERT_TRUE(st1->Equals(*st3));
+
+ // sparse tensors with NaNs
+ const c_value_type nan_value = static_cast<c_value_type>(NAN);
+ this->values2_[1][1][1][1] = nan_value;
+ EXPECT_TRUE(std::isnan(this->tensor2_->Value({1, 1, 1, 1})));
+
+ std::shared_ptr<SparseCSFTensor> st4;
+ ASSERT_OK_AND_ASSIGN(st4, SparseCSFTensor::Make(*this->tensor2_));
+ EXPECT_FALSE(st4->Equals(*st4)); // same object
+ EXPECT_TRUE(st4->Equals(*st4, EqualOptions().nans_equal(true))); // same object
+
+ c_value_type values5[2][3][4][5] = {};
+ std::copy_n(&this->values2_[0][0][0][0], this->length_ / sizeof(c_value_type),
+ &values5[0][0][0][0]);
+ std::shared_ptr<SparseCSFTensor> st5;
+ std::shared_ptr<Buffer> buffer5 = Buffer::Wrap(values5, sizeof(values5));
+ NumericTensor<ValueType> tensor5(buffer5, this->shape_);
+ ASSERT_OK_AND_ASSIGN(st5, SparseCSFTensor::Make(tensor5));
+ EXPECT_FALSE(st4->Equals(*st5)); // different memory
+ EXPECT_TRUE(st4->Equals(*st5, EqualOptions().nans_equal(true))); // different memory
+}
+
+REGISTER_TYPED_TEST_SUITE_P(TestFloatingSparseCSFTensorEquality, TestEquality);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(TestFloat, TestFloatingSparseCSFTensorEquality, FloatType);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestDouble, TestFloatingSparseCSFTensorEquality,
+ DoubleType);
+
+template <typename IndexValueType>
+class TestSparseCSFTensorBase : public TestSparseTensorBase<Int16Type> {
+ public:
+ void SetUp() {
+ dim_names_ = {"a", "b", "c", "d"};
+ shape_ = {2, 3, 4, 5};
+
+ dense_values_[0][0][0][1] = 1;
+ dense_values_[0][0][0][2] = 2;
+ dense_values_[0][1][0][0] = 3;
+ dense_values_[0][1][0][2] = 4;
+ dense_values_[0][1][1][0] = 5;
+ dense_values_[1][1][1][0] = 6;
+ dense_values_[1][1][1][1] = 7;
+ dense_values_[1][1][1][2] = 8;
+
+ auto dense_buffer = Buffer::Wrap(dense_values_, sizeof(dense_values_));
+ Tensor dense_tensor_(int16(), dense_buffer, shape_, {}, dim_names_);
+ ASSERT_OK_AND_ASSIGN(
+ sparse_tensor_from_dense_,
+ SparseCSFTensor::Make(dense_tensor_,
+ TypeTraits<IndexValueType>::type_singleton()));
+ }
+
+ protected:
+ std::vector<int64_t> shape_;
+ std::vector<std::string> dim_names_;
+ int16_t dense_values_[2][3][4][5] = {};
+ std::shared_ptr<SparseCSFTensor> sparse_tensor_from_dense_;
+};
+
+class TestSparseCSFTensor : public TestSparseCSFTensorBase<Int64Type> {};
+
+TEST_F(TestSparseCSFTensor, CreationFromZeroTensor) {
+ const auto dense_size =
+ std::accumulate(this->shape_.begin(), this->shape_.end(), int64_t(1),
+ [](int64_t a, int64_t x) { return a * x; });
+ std::vector<int64_t> dense_values(dense_size, 0);
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Tensor> t_zero,
+ Tensor::Make(int64(), Buffer::Wrap(dense_values), this->shape_));
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<SparseCSFTensor> st_zero,
+ SparseCSFTensor::Make(*t_zero, int64()));
+
+ ASSERT_EQ(0, st_zero->non_zero_length());
+ ASSERT_EQ(dense_size, st_zero->size());
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Tensor> t, st_zero->ToTensor());
+ ASSERT_TRUE(t->Equals(*t_zero));
+}
+
+template <typename IndexValueType>
+class TestSparseCSFTensorForIndexValueType
+ : public TestSparseCSFTensorBase<IndexValueType> {
+ protected:
+ std::shared_ptr<SparseCSFIndex> MakeSparseCSFIndex(
+ const std::vector<int64_t>& axis_order,
+ const std::vector<std::vector<typename IndexValueType::c_type>>& indptr_values,
+ const std::vector<std::vector<typename IndexValueType::c_type>>& indices_values)
+ const {
+ int64_t ndim = axis_order.size();
+ std::vector<std::shared_ptr<Tensor>> indptr(ndim - 1);
+ std::vector<std::shared_ptr<Tensor>> indices(ndim);
+
+ for (int64_t i = 0; i < ndim - 1; ++i) {
+ indptr[i] = std::make_shared<Tensor>(
+ TypeTraits<IndexValueType>::type_singleton(), Buffer::Wrap(indptr_values[i]),
+ std::vector<int64_t>({static_cast<int64_t>(indptr_values[i].size())}));
+ }
+ for (int64_t i = 0; i < ndim; ++i) {
+ indices[i] = std::make_shared<Tensor>(
+ TypeTraits<IndexValueType>::type_singleton(), Buffer::Wrap(indices_values[i]),
+ std::vector<int64_t>({static_cast<int64_t>(indices_values[i].size())}));
+ }
+ return std::make_shared<SparseCSFIndex>(indptr, indices, axis_order);
+ }
+
+ template <typename CValueType>
+ std::shared_ptr<SparseCSFTensor> MakeSparseTensor(
+ const std::shared_ptr<SparseCSFIndex>& si, std::vector<CValueType>& sparse_values,
+ const std::vector<int64_t>& shape,
+ const std::vector<std::string>& dim_names) const {
+ auto data_buffer = Buffer::Wrap(sparse_values);
+ return std::make_shared<SparseCSFTensor>(
+ si, CTypeTraits<CValueType>::type_singleton(), data_buffer, shape, dim_names);
+ }
+};
+
+TYPED_TEST_SUITE_P(TestSparseCSFTensorForIndexValueType);
+
+TYPED_TEST_P(TestSparseCSFTensorForIndexValueType, TestCreateSparseTensor) {
+ using IndexValueType = TypeParam;
+ using c_index_value_type = typename IndexValueType::c_type;
+
+ std::vector<int64_t> shape = {2, 3, 4, 5};
+ std::vector<std::string> dim_names = {"a", "b", "c", "d"};
+ std::vector<int64_t> axis_order = {0, 1, 2, 3};
+ std::vector<int16_t> sparse_values = {1, 2, 3, 4, 5, 6, 7, 8};
+ std::vector<std::vector<c_index_value_type>> indptr_values = {
+ {0, 2, 3}, {0, 1, 3, 4}, {0, 2, 4, 5, 8}};
+ std::vector<std::vector<c_index_value_type>> indices_values = {
+ {0, 1}, {0, 1, 1}, {0, 0, 1, 1}, {1, 2, 0, 2, 0, 0, 1, 2}};
+
+ auto si = this->MakeSparseCSFIndex(axis_order, indptr_values, indices_values);
+ auto st = this->MakeSparseTensor(si, sparse_values, shape, dim_names);
+
+ ASSERT_TRUE(st->Equals(*this->sparse_tensor_from_dense_));
+}
+
+TYPED_TEST_P(TestSparseCSFTensorForIndexValueType, TestTensorToSparseTensor) {
+ std::vector<std::string> dim_names = {"a", "b", "c", "d"};
+ ASSERT_EQ(8, this->sparse_tensor_from_dense_->non_zero_length());
+ ASSERT_TRUE(this->sparse_tensor_from_dense_->is_mutable());
+ ASSERT_EQ(dim_names, this->sparse_tensor_from_dense_->dim_names());
+}
+
+TYPED_TEST_P(TestSparseCSFTensorForIndexValueType, TestSparseTensorToTensor) {
+ std::vector<int64_t> shape = {2, 3, 4, 5};
+ auto dense_buffer = Buffer::Wrap(this->dense_values_, sizeof(this->dense_values_));
+ Tensor dense_tensor(int16(), dense_buffer, shape, {}, this->dim_names_);
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Tensor> dt,
+ this->sparse_tensor_from_dense_->ToTensor());
+ ASSERT_TRUE(dense_tensor.Equals(*dt));
+ ASSERT_EQ(dense_tensor.dim_names(), dt->dim_names());
+}
+
+TYPED_TEST_P(TestSparseCSFTensorForIndexValueType, TestRoundTrip) {
+ using IndexValueType = TypeParam;
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Tensor> dt,
+ this->sparse_tensor_from_dense_->ToTensor());
+ std::shared_ptr<SparseCSFTensor> st;
+ ASSERT_OK_AND_ASSIGN(
+ st, SparseCSFTensor::Make(*dt, TypeTraits<IndexValueType>::type_singleton()));
+
+ ASSERT_TRUE(st->Equals(*this->sparse_tensor_from_dense_));
+}
+
+TYPED_TEST_P(TestSparseCSFTensorForIndexValueType, TestAlternativeAxisOrder) {
+ using IndexValueType = TypeParam;
+ using c_index_value_type = typename IndexValueType::c_type;
+
+ std::vector<int16_t> dense_values = {1, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 5};
+ std::vector<int64_t> shape = {4, 6};
+ std::vector<std::string> dim_names = {"a", "b"};
+ std::shared_ptr<Buffer> dense_buffer = Buffer::Wrap(dense_values);
+ Tensor tensor(int16(), dense_buffer, shape, {}, dim_names);
+
+ // Axis order 1
+ std::vector<int64_t> axis_order_1 = {0, 1};
+ std::vector<int16_t> sparse_values_1 = {1, 3, 2, 4, 5};
+ std::vector<std::vector<c_index_value_type>> indptr_values_1 = {{0, 2, 3, 5}};
+ std::vector<std::vector<c_index_value_type>> indices_values_1 = {{0, 1, 3},
+ {0, 3, 1, 3, 5}};
+ auto si_1 = this->MakeSparseCSFIndex(axis_order_1, indptr_values_1, indices_values_1);
+ auto st_1 = this->MakeSparseTensor(si_1, sparse_values_1, shape, dim_names);
+
+ // Axis order 2
+ std::vector<int64_t> axis_order_2 = {1, 0};
+ std::vector<int16_t> sparse_values_2 = {1, 2, 3, 4, 5};
+ std::vector<std::vector<c_index_value_type>> indptr_values_2 = {{0, 1, 2, 4, 5}};
+ std::vector<std::vector<c_index_value_type>> indices_values_2 = {{0, 1, 3, 5},
+ {0, 1, 0, 3, 3}};
+ auto si_2 = this->MakeSparseCSFIndex(axis_order_2, indptr_values_2, indices_values_2);
+ auto st_2 = this->MakeSparseTensor(si_2, sparse_values_2, shape, dim_names);
+
+ std::shared_ptr<Tensor> dt_1, dt_2;
+ ASSERT_OK_AND_ASSIGN(dt_1, st_1->ToTensor());
+ ASSERT_OK_AND_ASSIGN(dt_2, st_2->ToTensor());
+
+ ASSERT_FALSE(st_1->Equals(*st_2));
+ ASSERT_TRUE(dt_1->Equals(*dt_2));
+ ASSERT_TRUE(dt_1->Equals(tensor));
+}
+
+TYPED_TEST_P(TestSparseCSFTensorForIndexValueType, TestNonAscendingShape) {
+ using IndexValueType = TypeParam;
+ using c_index_value_type = typename IndexValueType::c_type;
+
+ std::vector<int64_t> shape = {5, 2, 3, 4};
+ int16_t dense_values[5][2][3][4] = {}; // zero-initialized
+ dense_values[0][0][0][1] = 1;
+ dense_values[0][0][0][2] = 2;
+ dense_values[0][1][0][0] = 3;
+ dense_values[0][1][0][2] = 4;
+ dense_values[0][1][1][0] = 5;
+ dense_values[1][1][1][0] = 6;
+ dense_values[1][1][1][1] = 7;
+ dense_values[1][1][1][2] = 8;
+ auto dense_buffer = Buffer::Wrap(dense_values, sizeof(dense_values));
+ Tensor dense_tensor(int16(), dense_buffer, shape, {}, this->dim_names_);
+
+ std::shared_ptr<SparseCSFTensor> sparse_tensor;
+ ASSERT_OK_AND_ASSIGN(
+ sparse_tensor,
+ SparseCSFTensor::Make(dense_tensor, TypeTraits<IndexValueType>::type_singleton()));
+
+ std::vector<std::vector<c_index_value_type>> indptr_values = {
+ {0, 1, 3}, {0, 2, 4, 7}, {0, 1, 2, 3, 4, 6, 7, 8}};
+ std::vector<std::vector<c_index_value_type>> indices_values = {
+ {0, 1}, {0, 0, 1}, {1, 2, 0, 2, 0, 1, 2}, {0, 0, 0, 0, 0, 1, 1, 1}};
+ std::vector<int64_t> axis_order = {1, 2, 3, 0};
+ std::vector<int16_t> sparse_values = {1, 2, 3, 4, 5, 6, 7, 8};
+ auto si = this->MakeSparseCSFIndex(axis_order, indptr_values, indices_values);
+ auto st = this->MakeSparseTensor(si, sparse_values, shape, this->dim_names_);
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Tensor> dt, st->ToTensor());
+ ASSERT_TRUE(dt->Equals(dense_tensor));
+ ASSERT_TRUE(st->Equals(*sparse_tensor));
+}
+
+REGISTER_TYPED_TEST_SUITE_P(TestSparseCSFTensorForIndexValueType, TestCreateSparseTensor,
+ TestTensorToSparseTensor, TestSparseTensorToTensor,
+ TestAlternativeAxisOrder, TestNonAscendingShape,
+ TestRoundTrip);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt8, TestSparseCSFTensorForIndexValueType, Int8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt8, TestSparseCSFTensorForIndexValueType,
+ UInt8Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt16, TestSparseCSFTensorForIndexValueType,
+ Int16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt16, TestSparseCSFTensorForIndexValueType,
+ UInt16Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt32, TestSparseCSFTensorForIndexValueType,
+ Int32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt32, TestSparseCSFTensorForIndexValueType,
+ UInt32Type);
+INSTANTIATE_TYPED_TEST_SUITE_P(TestInt64, TestSparseCSFTensorForIndexValueType,
+ Int64Type);
+
+TEST(TestSparseCSFMatrixForUInt64Index, Make) {
+ int16_t dense_values[2][3][4][5] = {};
+ dense_values[0][0][0][1] = 1;
+ dense_values[0][0][0][2] = 2;
+ dense_values[0][1][0][0] = 3;
+ dense_values[0][1][0][2] = 4;
+ dense_values[0][1][1][0] = 5;
+ dense_values[1][1][1][0] = 6;
+ dense_values[1][1][1][1] = 7;
+ dense_values[1][1][1][2] = 8;
+
+ Tensor dense_tensor(uint64(), Buffer::Wrap(dense_values, sizeof(dense_values)),
+ {2, 3, 4, 5});
+ ASSERT_RAISES(Invalid, SparseCSFTensor::Make(dense_tensor, uint64()));
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/status.cc b/src/arrow/cpp/src/arrow/status.cc
new file mode 100644
index 000000000..0f02cb57a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/status.cc
@@ -0,0 +1,143 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// A Status encapsulates the result of an operation. It may indicate success,
+// or it may indicate an error with an associated error message.
+//
+// Multiple threads can invoke const methods on a Status without
+// external synchronization, but if any of the threads may call a
+// non-const method, all threads accessing the same Status must use
+// external synchronization.
+
+#include "arrow/status.h"
+
+#include <cassert>
+#include <cstdlib>
+#include <iostream>
+#include <sstream>
+
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+Status::Status(StatusCode code, const std::string& msg)
+ : Status::Status(code, msg, nullptr) {}
+
+Status::Status(StatusCode code, std::string msg, std::shared_ptr<StatusDetail> detail) {
+ ARROW_CHECK_NE(code, StatusCode::OK) << "Cannot construct ok status with message";
+ state_ = new State;
+ state_->code = code;
+ state_->msg = std::move(msg);
+ if (detail != nullptr) {
+ state_->detail = std::move(detail);
+ }
+}
+
+void Status::CopyFrom(const Status& s) {
+ delete state_;
+ if (s.state_ == nullptr) {
+ state_ = nullptr;
+ } else {
+ state_ = new State(*s.state_);
+ }
+}
+
+std::string Status::CodeAsString() const {
+ if (state_ == nullptr) {
+ return "OK";
+ }
+ return CodeAsString(code());
+}
+
+std::string Status::CodeAsString(StatusCode code) {
+ const char* type;
+ switch (code) {
+ case StatusCode::OK:
+ type = "OK";
+ break;
+ case StatusCode::OutOfMemory:
+ type = "Out of memory";
+ break;
+ case StatusCode::KeyError:
+ type = "Key error";
+ break;
+ case StatusCode::TypeError:
+ type = "Type error";
+ break;
+ case StatusCode::Invalid:
+ type = "Invalid";
+ break;
+ case StatusCode::Cancelled:
+ type = "Cancelled";
+ break;
+ case StatusCode::IOError:
+ type = "IOError";
+ break;
+ case StatusCode::CapacityError:
+ type = "Capacity error";
+ break;
+ case StatusCode::IndexError:
+ type = "Index error";
+ break;
+ case StatusCode::UnknownError:
+ type = "Unknown error";
+ break;
+ case StatusCode::NotImplemented:
+ type = "NotImplemented";
+ break;
+ case StatusCode::SerializationError:
+ type = "Serialization error";
+ break;
+ case StatusCode::CodeGenError:
+ type = "CodeGenError in Gandiva";
+ break;
+ case StatusCode::ExpressionValidationError:
+ type = "ExpressionValidationError";
+ break;
+ case StatusCode::ExecutionError:
+ type = "ExecutionError in Gandiva";
+ break;
+ default:
+ type = "Unknown";
+ break;
+ }
+ return std::string(type);
+}
+
+std::string Status::ToString() const {
+ std::string result(CodeAsString());
+ if (state_ == nullptr) {
+ return result;
+ }
+ result += ": ";
+ result += state_->msg;
+ if (state_->detail != nullptr) {
+ result += ". Detail: ";
+ result += state_->detail->ToString();
+ }
+
+ return result;
+}
+
+void Status::Abort() const { Abort(std::string()); }
+
+void Status::Abort(const std::string& message) const {
+ std::cerr << "-- Arrow Fatal Error --\n";
+ if (!message.empty()) {
+ std::cerr << message << "\n";
+ }
+ std::cerr << ToString() << std::endl;
+ std::abort();
+}
+
+#ifdef ARROW_EXTRA_ERROR_CONTEXT
+void Status::AddContextLine(const char* filename, int line, const char* expr) {
+ ARROW_CHECK(!ok()) << "Cannot add context line to ok status";
+ std::stringstream ss;
+ ss << "\n" << filename << ":" << line << " " << expr;
+ state_->msg += ss.str();
+}
+#endif
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/status.h b/src/arrow/cpp/src/arrow/status.h
new file mode 100644
index 000000000..056d60d6f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/status.h
@@ -0,0 +1,451 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// A Status encapsulates the result of an operation. It may indicate success,
+// or it may indicate an error with an associated error message.
+//
+// Multiple threads can invoke const methods on a Status without
+// external synchronization, but if any of the threads may call a
+// non-const method, all threads accessing the same Status must use
+// external synchronization.
+
+// Adapted from Apache Kudu, TensorFlow
+
+#pragma once
+
+#include <cstring>
+#include <iosfwd>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "arrow/util/compare.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_builder.h"
+#include "arrow/util/visibility.h"
+
+#ifdef ARROW_EXTRA_ERROR_CONTEXT
+
+/// \brief Return with given status if condition is met.
+#define ARROW_RETURN_IF_(condition, status, expr) \
+ do { \
+ if (ARROW_PREDICT_FALSE(condition)) { \
+ ::arrow::Status _st = (status); \
+ _st.AddContextLine(__FILE__, __LINE__, expr); \
+ return _st; \
+ } \
+ } while (0)
+
+#else
+
+#define ARROW_RETURN_IF_(condition, status, _) \
+ do { \
+ if (ARROW_PREDICT_FALSE(condition)) { \
+ return (status); \
+ } \
+ } while (0)
+
+#endif // ARROW_EXTRA_ERROR_CONTEXT
+
+#define ARROW_RETURN_IF(condition, status) \
+ ARROW_RETURN_IF_(condition, status, ARROW_STRINGIFY(status))
+
+/// \brief Propagate any non-successful Status to the caller
+#define ARROW_RETURN_NOT_OK(status) \
+ do { \
+ ::arrow::Status __s = ::arrow::internal::GenericToStatus(status); \
+ ARROW_RETURN_IF_(!__s.ok(), __s, ARROW_STRINGIFY(status)); \
+ } while (false)
+
+#define RETURN_NOT_OK_ELSE(s, else_) \
+ do { \
+ ::arrow::Status _s = ::arrow::internal::GenericToStatus(s); \
+ if (!_s.ok()) { \
+ else_; \
+ return _s; \
+ } \
+ } while (false)
+
+// This is an internal-use macro and should not be used in public headers.
+#ifndef RETURN_NOT_OK
+#define RETURN_NOT_OK(s) ARROW_RETURN_NOT_OK(s)
+#endif
+
+namespace arrow {
+
+enum class StatusCode : char {
+ OK = 0,
+ OutOfMemory = 1,
+ KeyError = 2,
+ TypeError = 3,
+ Invalid = 4,
+ IOError = 5,
+ CapacityError = 6,
+ IndexError = 7,
+ Cancelled = 8,
+ UnknownError = 9,
+ NotImplemented = 10,
+ SerializationError = 11,
+ RError = 13,
+ // Gandiva range of errors
+ CodeGenError = 40,
+ ExpressionValidationError = 41,
+ ExecutionError = 42,
+ // Continue generic codes.
+ AlreadyExists = 45
+};
+
+/// \brief An opaque class that allows subsystems to retain
+/// additional information inside the Status.
+class ARROW_EXPORT StatusDetail {
+ public:
+ virtual ~StatusDetail() = default;
+ /// \brief Return a unique id for the type of the StatusDetail
+ /// (effectively a poor man's substitute for RTTI).
+ virtual const char* type_id() const = 0;
+ /// \brief Produce a human-readable description of this status.
+ virtual std::string ToString() const = 0;
+
+ bool operator==(const StatusDetail& other) const noexcept {
+ return std::string(type_id()) == other.type_id() && ToString() == other.ToString();
+ }
+};
+
+/// \brief Status outcome object (success or error)
+///
+/// The Status object is an object holding the outcome of an operation.
+/// The outcome is represented as a StatusCode, either success
+/// (StatusCode::OK) or an error (any other of the StatusCode enumeration values).
+///
+/// Additionally, if an error occurred, a specific error message is generally
+/// attached.
+class ARROW_MUST_USE_TYPE ARROW_EXPORT Status : public util::EqualityComparable<Status>,
+ public util::ToStringOstreamable<Status> {
+ public:
+ // Create a success status.
+ Status() noexcept : state_(NULLPTR) {}
+ ~Status() noexcept {
+ // ARROW-2400: On certain compilers, splitting off the slow path improves
+ // performance significantly.
+ if (ARROW_PREDICT_FALSE(state_ != NULL)) {
+ DeleteState();
+ }
+ }
+
+ Status(StatusCode code, const std::string& msg);
+ /// \brief Pluggable constructor for use by sub-systems. detail cannot be null.
+ Status(StatusCode code, std::string msg, std::shared_ptr<StatusDetail> detail);
+
+ // Copy the specified status.
+ inline Status(const Status& s);
+ inline Status& operator=(const Status& s);
+
+ // Move the specified status.
+ inline Status(Status&& s) noexcept;
+ inline Status& operator=(Status&& s) noexcept;
+
+ inline bool Equals(const Status& s) const;
+
+ // AND the statuses.
+ inline Status operator&(const Status& s) const noexcept;
+ inline Status operator&(Status&& s) const noexcept;
+ inline Status& operator&=(const Status& s) noexcept;
+ inline Status& operator&=(Status&& s) noexcept;
+
+ /// Return a success status
+ static Status OK() { return Status(); }
+
+ template <typename... Args>
+ static Status FromArgs(StatusCode code, Args&&... args) {
+ return Status(code, util::StringBuilder(std::forward<Args>(args)...));
+ }
+
+ template <typename... Args>
+ static Status FromDetailAndArgs(StatusCode code, std::shared_ptr<StatusDetail> detail,
+ Args&&... args) {
+ return Status(code, util::StringBuilder(std::forward<Args>(args)...),
+ std::move(detail));
+ }
+
+ /// Return an error status for out-of-memory conditions
+ template <typename... Args>
+ static Status OutOfMemory(Args&&... args) {
+ return Status::FromArgs(StatusCode::OutOfMemory, std::forward<Args>(args)...);
+ }
+
+ /// Return an error status for failed key lookups (e.g. column name in a table)
+ template <typename... Args>
+ static Status KeyError(Args&&... args) {
+ return Status::FromArgs(StatusCode::KeyError, std::forward<Args>(args)...);
+ }
+
+ /// Return an error status for type errors (such as mismatching data types)
+ template <typename... Args>
+ static Status TypeError(Args&&... args) {
+ return Status::FromArgs(StatusCode::TypeError, std::forward<Args>(args)...);
+ }
+
+ /// Return an error status for unknown errors
+ template <typename... Args>
+ static Status UnknownError(Args&&... args) {
+ return Status::FromArgs(StatusCode::UnknownError, std::forward<Args>(args)...);
+ }
+
+ /// Return an error status when an operation or a combination of operation and
+ /// data types is unimplemented
+ template <typename... Args>
+ static Status NotImplemented(Args&&... args) {
+ return Status::FromArgs(StatusCode::NotImplemented, std::forward<Args>(args)...);
+ }
+
+ /// Return an error status for invalid data (for example a string that fails parsing)
+ template <typename... Args>
+ static Status Invalid(Args&&... args) {
+ return Status::FromArgs(StatusCode::Invalid, std::forward<Args>(args)...);
+ }
+
+ /// Return an error status for cancelled operation
+ template <typename... Args>
+ static Status Cancelled(Args&&... args) {
+ return Status::FromArgs(StatusCode::Cancelled, std::forward<Args>(args)...);
+ }
+
+ /// Return an error status when an index is out of bounds
+ template <typename... Args>
+ static Status IndexError(Args&&... args) {
+ return Status::FromArgs(StatusCode::IndexError, std::forward<Args>(args)...);
+ }
+
+ /// Return an error status when a container's capacity would exceed its limits
+ template <typename... Args>
+ static Status CapacityError(Args&&... args) {
+ return Status::FromArgs(StatusCode::CapacityError, std::forward<Args>(args)...);
+ }
+
+ /// Return an error status when some IO-related operation failed
+ template <typename... Args>
+ static Status IOError(Args&&... args) {
+ return Status::FromArgs(StatusCode::IOError, std::forward<Args>(args)...);
+ }
+
+ /// Return an error status when some (de)serialization operation failed
+ template <typename... Args>
+ static Status SerializationError(Args&&... args) {
+ return Status::FromArgs(StatusCode::SerializationError, std::forward<Args>(args)...);
+ }
+
+ template <typename... Args>
+ static Status RError(Args&&... args) {
+ return Status::FromArgs(StatusCode::RError, std::forward<Args>(args)...);
+ }
+
+ template <typename... Args>
+ static Status CodeGenError(Args&&... args) {
+ return Status::FromArgs(StatusCode::CodeGenError, std::forward<Args>(args)...);
+ }
+
+ template <typename... Args>
+ static Status ExpressionValidationError(Args&&... args) {
+ return Status::FromArgs(StatusCode::ExpressionValidationError,
+ std::forward<Args>(args)...);
+ }
+
+ template <typename... Args>
+ static Status ExecutionError(Args&&... args) {
+ return Status::FromArgs(StatusCode::ExecutionError, std::forward<Args>(args)...);
+ }
+
+ template <typename... Args>
+ static Status AlreadyExists(Args&&... args) {
+ return Status::FromArgs(StatusCode::AlreadyExists, std::forward<Args>(args)...);
+ }
+
+ /// Return true iff the status indicates success.
+ bool ok() const { return (state_ == NULLPTR); }
+
+ /// Return true iff the status indicates an out-of-memory error.
+ bool IsOutOfMemory() const { return code() == StatusCode::OutOfMemory; }
+ /// Return true iff the status indicates a key lookup error.
+ bool IsKeyError() const { return code() == StatusCode::KeyError; }
+ /// Return true iff the status indicates invalid data.
+ bool IsInvalid() const { return code() == StatusCode::Invalid; }
+ /// Return true iff the status indicates a cancelled operation.
+ bool IsCancelled() const { return code() == StatusCode::Cancelled; }
+ /// Return true iff the status indicates an IO-related failure.
+ bool IsIOError() const { return code() == StatusCode::IOError; }
+ /// Return true iff the status indicates a container reaching capacity limits.
+ bool IsCapacityError() const { return code() == StatusCode::CapacityError; }
+ /// Return true iff the status indicates an out of bounds index.
+ bool IsIndexError() const { return code() == StatusCode::IndexError; }
+ /// Return true iff the status indicates a type error.
+ bool IsTypeError() const { return code() == StatusCode::TypeError; }
+ /// Return true iff the status indicates an unknown error.
+ bool IsUnknownError() const { return code() == StatusCode::UnknownError; }
+ /// Return true iff the status indicates an unimplemented operation.
+ bool IsNotImplemented() const { return code() == StatusCode::NotImplemented; }
+ /// Return true iff the status indicates a (de)serialization failure
+ bool IsSerializationError() const { return code() == StatusCode::SerializationError; }
+ /// Return true iff the status indicates a R-originated error.
+ bool IsRError() const { return code() == StatusCode::RError; }
+
+ bool IsCodeGenError() const { return code() == StatusCode::CodeGenError; }
+
+ bool IsExpressionValidationError() const {
+ return code() == StatusCode::ExpressionValidationError;
+ }
+
+ bool IsExecutionError() const { return code() == StatusCode::ExecutionError; }
+ bool IsAlreadyExists() const { return code() == StatusCode::AlreadyExists; }
+
+ /// \brief Return a string representation of this status suitable for printing.
+ ///
+ /// The string "OK" is returned for success.
+ std::string ToString() const;
+
+ /// \brief Return a string representation of the status code, without the message
+ /// text or POSIX code information.
+ std::string CodeAsString() const;
+ static std::string CodeAsString(StatusCode);
+
+ /// \brief Return the StatusCode value attached to this status.
+ StatusCode code() const { return ok() ? StatusCode::OK : state_->code; }
+
+ /// \brief Return the specific error message attached to this status.
+ const std::string& message() const {
+ static const std::string no_message = "";
+ return ok() ? no_message : state_->msg;
+ }
+
+ /// \brief Return the status detail attached to this message.
+ const std::shared_ptr<StatusDetail>& detail() const {
+ static std::shared_ptr<StatusDetail> no_detail = NULLPTR;
+ return state_ ? state_->detail : no_detail;
+ }
+
+ /// \brief Return a new Status copying the existing status, but
+ /// updating with the existing detail.
+ Status WithDetail(std::shared_ptr<StatusDetail> new_detail) const {
+ return Status(code(), message(), std::move(new_detail));
+ }
+
+ /// \brief Return a new Status with changed message, copying the
+ /// existing status code and detail.
+ template <typename... Args>
+ Status WithMessage(Args&&... args) const {
+ return FromArgs(code(), std::forward<Args>(args)...).WithDetail(detail());
+ }
+
+ [[noreturn]] void Abort() const;
+ [[noreturn]] void Abort(const std::string& message) const;
+
+#ifdef ARROW_EXTRA_ERROR_CONTEXT
+ void AddContextLine(const char* filename, int line, const char* expr);
+#endif
+
+ private:
+ struct State {
+ StatusCode code;
+ std::string msg;
+ std::shared_ptr<StatusDetail> detail;
+ };
+ // OK status has a `NULL` state_. Otherwise, `state_` points to
+ // a `State` structure containing the error code and message(s)
+ State* state_;
+
+ void DeleteState() {
+ delete state_;
+ state_ = NULLPTR;
+ }
+ void CopyFrom(const Status& s);
+ inline void MoveFrom(Status& s);
+};
+
+void Status::MoveFrom(Status& s) {
+ delete state_;
+ state_ = s.state_;
+ s.state_ = NULLPTR;
+}
+
+Status::Status(const Status& s)
+ : state_((s.state_ == NULLPTR) ? NULLPTR : new State(*s.state_)) {}
+
+Status& Status::operator=(const Status& s) {
+ // The following condition catches both aliasing (when this == &s),
+ // and the common case where both s and *this are ok.
+ if (state_ != s.state_) {
+ CopyFrom(s);
+ }
+ return *this;
+}
+
+Status::Status(Status&& s) noexcept : state_(s.state_) { s.state_ = NULLPTR; }
+
+Status& Status::operator=(Status&& s) noexcept {
+ MoveFrom(s);
+ return *this;
+}
+
+bool Status::Equals(const Status& s) const {
+ if (state_ == s.state_) {
+ return true;
+ }
+
+ if (ok() || s.ok()) {
+ return false;
+ }
+
+ if (detail() != s.detail()) {
+ if ((detail() && !s.detail()) || (!detail() && s.detail())) {
+ return false;
+ }
+ return *detail() == *s.detail();
+ }
+
+ return code() == s.code() && message() == s.message();
+}
+
+/// \cond FALSE
+// (note: emits warnings on Doxygen < 1.8.15,
+// see https://github.com/doxygen/doxygen/issues/6295)
+Status Status::operator&(const Status& s) const noexcept {
+ if (ok()) {
+ return s;
+ } else {
+ return *this;
+ }
+}
+
+Status Status::operator&(Status&& s) const noexcept {
+ if (ok()) {
+ return std::move(s);
+ } else {
+ return *this;
+ }
+}
+
+Status& Status::operator&=(const Status& s) noexcept {
+ if (ok() && !s.ok()) {
+ CopyFrom(s);
+ }
+ return *this;
+}
+
+Status& Status::operator&=(Status&& s) noexcept {
+ if (ok() && !s.ok()) {
+ MoveFrom(s);
+ }
+ return *this;
+}
+/// \endcond
+
+namespace internal {
+
+// Extract Status from Status or Result<T>
+// Useful for the status check macros such as RETURN_NOT_OK.
+inline const Status& GenericToStatus(const Status& st) { return st; }
+inline Status GenericToStatus(Status&& st) { return std::move(st); }
+
+} // namespace internal
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/status_test.cc b/src/arrow/cpp/src/arrow/status_test.cc
new file mode 100644
index 000000000..10a79d9b9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/status_test.cc
@@ -0,0 +1,212 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <sstream>
+
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
+
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+
+namespace arrow {
+
+namespace {
+
+class TestStatusDetail : public StatusDetail {
+ public:
+ const char* type_id() const override { return "type_id"; }
+ std::string ToString() const override { return "a specific detail message"; }
+};
+
+} // namespace
+
+TEST(StatusTest, TestCodeAndMessage) {
+ Status ok = Status::OK();
+ ASSERT_EQ(StatusCode::OK, ok.code());
+ Status file_error = Status::IOError("file error");
+ ASSERT_EQ(StatusCode::IOError, file_error.code());
+ ASSERT_EQ("file error", file_error.message());
+}
+
+TEST(StatusTest, TestToString) {
+ Status file_error = Status::IOError("file error");
+ ASSERT_EQ("IOError: file error", file_error.ToString());
+
+ std::stringstream ss;
+ ss << file_error;
+ ASSERT_EQ(file_error.ToString(), ss.str());
+}
+
+TEST(StatusTest, TestToStringWithDetail) {
+ Status status(StatusCode::IOError, "summary", std::make_shared<TestStatusDetail>());
+ ASSERT_EQ("IOError: summary. Detail: a specific detail message", status.ToString());
+
+ std::stringstream ss;
+ ss << status;
+ ASSERT_EQ(status.ToString(), ss.str());
+}
+
+TEST(StatusTest, TestWithDetail) {
+ Status status(StatusCode::IOError, "summary");
+ auto detail = std::make_shared<TestStatusDetail>();
+ Status new_status = status.WithDetail(detail);
+
+ ASSERT_EQ(new_status.code(), status.code());
+ ASSERT_EQ(new_status.message(), status.message());
+ ASSERT_EQ(new_status.detail(), detail);
+}
+
+TEST(StatusTest, AndStatus) {
+ Status a = Status::OK();
+ Status b = Status::OK();
+ Status c = Status::Invalid("invalid value");
+ Status d = Status::IOError("file error");
+
+ Status res;
+ res = a & b;
+ ASSERT_TRUE(res.ok());
+ res = a & c;
+ ASSERT_TRUE(res.IsInvalid());
+ res = d & c;
+ ASSERT_TRUE(res.IsIOError());
+
+ res = Status::OK();
+ res &= c;
+ ASSERT_TRUE(res.IsInvalid());
+ res &= d;
+ ASSERT_TRUE(res.IsInvalid());
+
+ // With rvalues
+ res = Status::OK() & Status::Invalid("foo");
+ ASSERT_TRUE(res.IsInvalid());
+ res = Status::Invalid("foo") & Status::OK();
+ ASSERT_TRUE(res.IsInvalid());
+ res = Status::Invalid("foo") & Status::IOError("bar");
+ ASSERT_TRUE(res.IsInvalid());
+
+ res = Status::OK();
+ res &= Status::OK();
+ ASSERT_TRUE(res.ok());
+ res &= Status::Invalid("foo");
+ ASSERT_TRUE(res.IsInvalid());
+ res &= Status::IOError("bar");
+ ASSERT_TRUE(res.IsInvalid());
+}
+
+TEST(StatusTest, TestEquality) {
+ ASSERT_EQ(Status(), Status::OK());
+ ASSERT_EQ(Status::Invalid("error"), Status::Invalid("error"));
+
+ ASSERT_NE(Status::Invalid("error"), Status::OK());
+ ASSERT_NE(Status::Invalid("error"), Status::Invalid("other error"));
+}
+
+TEST(StatusTest, MatcherExamples) {
+ EXPECT_THAT(Status::Invalid("arbitrary error"), Raises(StatusCode::Invalid));
+
+ EXPECT_THAT(Status::Invalid("arbitrary error"),
+ Raises(StatusCode::Invalid, testing::HasSubstr("arbitrary")));
+
+ // message doesn't match, so no match
+ EXPECT_THAT(
+ Status::Invalid("arbitrary error"),
+ testing::Not(Raises(StatusCode::Invalid, testing::HasSubstr("reasonable"))));
+
+ // different error code, so no match
+ EXPECT_THAT(Status::TypeError("arbitrary error"),
+ testing::Not(Raises(StatusCode::Invalid)));
+
+ // not an error, so no match
+ EXPECT_THAT(Status::OK(), testing::Not(Raises(StatusCode::Invalid)));
+}
+
+TEST(StatusTest, MatcherDescriptions) {
+ testing::Matcher<Status> matcher = Raises(StatusCode::Invalid);
+
+ {
+ std::stringstream ss;
+ matcher.DescribeTo(&ss);
+ EXPECT_THAT(ss.str(), testing::StrEq("raises StatusCode::Invalid"));
+ }
+
+ {
+ std::stringstream ss;
+ matcher.DescribeNegationTo(&ss);
+ EXPECT_THAT(ss.str(), testing::StrEq("does not raise StatusCode::Invalid"));
+ }
+}
+
+TEST(StatusTest, MessageMatcherDescriptions) {
+ testing::Matcher<Status> matcher =
+ Raises(StatusCode::Invalid, testing::HasSubstr("arbitrary"));
+
+ {
+ std::stringstream ss;
+ matcher.DescribeTo(&ss);
+ EXPECT_THAT(
+ ss.str(),
+ testing::StrEq(
+ "raises StatusCode::Invalid and message has substring \"arbitrary\""));
+ }
+
+ {
+ std::stringstream ss;
+ matcher.DescribeNegationTo(&ss);
+ EXPECT_THAT(ss.str(), testing::StrEq("does not raise StatusCode::Invalid or message "
+ "has no substring \"arbitrary\""));
+ }
+}
+
+TEST(StatusTest, MatcherExplanations) {
+ testing::Matcher<Status> matcher = Raises(StatusCode::Invalid);
+
+ {
+ testing::StringMatchResultListener listener;
+ EXPECT_TRUE(matcher.MatchAndExplain(Status::Invalid("XXX"), &listener));
+ EXPECT_THAT(listener.str(), testing::StrEq("whose value \"Invalid: XXX\" matches"));
+ }
+
+ {
+ testing::StringMatchResultListener listener;
+ EXPECT_FALSE(matcher.MatchAndExplain(Status::OK(), &listener));
+ EXPECT_THAT(listener.str(), testing::StrEq("whose value \"OK\" doesn't match"));
+ }
+
+ {
+ testing::StringMatchResultListener listener;
+ EXPECT_FALSE(matcher.MatchAndExplain(Status::TypeError("XXX"), &listener));
+ EXPECT_THAT(listener.str(),
+ testing::StrEq("whose value \"Type error: XXX\" doesn't match"));
+ }
+}
+
+TEST(StatusTest, TestDetailEquality) {
+ const auto status_with_detail =
+ arrow::Status(StatusCode::IOError, "", std::make_shared<TestStatusDetail>());
+ const auto status_with_detail2 =
+ arrow::Status(StatusCode::IOError, "", std::make_shared<TestStatusDetail>());
+ const auto status_without_detail = arrow::Status::IOError("");
+
+ ASSERT_EQ(*status_with_detail.detail(), *status_with_detail2.detail());
+ ASSERT_EQ(status_with_detail, status_with_detail2);
+ ASSERT_NE(status_with_detail, status_without_detail);
+ ASSERT_NE(status_without_detail, status_with_detail);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/stl.h b/src/arrow/cpp/src/arrow/stl.h
new file mode 100644
index 000000000..a1582ed29
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/stl.h
@@ -0,0 +1,466 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstddef>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_base.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/chunked_array.h"
+#include "arrow/compute/api.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+
+class Schema;
+
+namespace stl {
+
+namespace internal {
+
+template <typename T, typename = void>
+struct is_optional_like : public std::false_type {};
+
+template <typename T, typename = void>
+struct is_dereferencable : public std::false_type {};
+
+template <typename T>
+struct is_dereferencable<T, arrow::internal::void_t<decltype(*std::declval<T>())>>
+ : public std::true_type {};
+
+template <typename T>
+struct is_optional_like<
+ T, typename std::enable_if<
+ std::is_constructible<bool, T>::value && is_dereferencable<T>::value &&
+ !std::is_array<typename std::remove_reference<T>::type>::value>::type>
+ : public std::true_type {};
+
+template <size_t N, typename Tuple>
+using BareTupleElement =
+ typename std::decay<typename std::tuple_element<N, Tuple>::type>::type;
+
+} // namespace internal
+
+template <typename T, typename R = void>
+using enable_if_optional_like =
+ typename std::enable_if<internal::is_optional_like<T>::value, R>::type;
+
+/// Traits meta class to map standard C/C++ types to equivalent Arrow types.
+template <typename T, typename Enable = void>
+struct ConversionTraits {};
+
+/// Returns builder type for given standard C/C++ type.
+template <typename CType>
+using CBuilderType =
+ typename TypeTraits<typename ConversionTraits<CType>::ArrowType>::BuilderType;
+
+/// Default implementation of AppendListValues.
+///
+/// This function can be specialized by user to take advantage of appending
+/// contiguous ranges while appending. This default implementation will call
+/// ConversionTraits<ValueCType>::AppendRow() for each value in the range.
+template <typename ValueCType, typename Range>
+inline Status AppendListValues(CBuilderType<ValueCType>& value_builder,
+ Range&& cell_range) {
+ for (auto const& value : cell_range) {
+ ARROW_RETURN_NOT_OK(ConversionTraits<ValueCType>::AppendRow(value_builder, value));
+ }
+ return Status::OK();
+}
+
+#define ARROW_STL_CONVERSION(CType_, ArrowType_) \
+ template <> \
+ struct ConversionTraits<CType_> : public CTypeTraits<CType_> { \
+ static Status AppendRow(typename TypeTraits<ArrowType_>::BuilderType& builder, \
+ CType_ cell) { \
+ return builder.Append(cell); \
+ } \
+ static CType_ GetEntry(const typename TypeTraits<ArrowType_>::ArrayType& array, \
+ size_t j) { \
+ return array.Value(j); \
+ } \
+ }; \
+ \
+ template <> \
+ inline Status AppendListValues<CType_, const std::vector<CType_>&>( \
+ typename TypeTraits<ArrowType_>::BuilderType & value_builder, \
+ const std::vector<CType_>& cell_range) { \
+ return value_builder.AppendValues(cell_range); \
+ }
+
+ARROW_STL_CONVERSION(bool, BooleanType)
+ARROW_STL_CONVERSION(int8_t, Int8Type)
+ARROW_STL_CONVERSION(int16_t, Int16Type)
+ARROW_STL_CONVERSION(int32_t, Int32Type)
+ARROW_STL_CONVERSION(int64_t, Int64Type)
+ARROW_STL_CONVERSION(uint8_t, UInt8Type)
+ARROW_STL_CONVERSION(uint16_t, UInt16Type)
+ARROW_STL_CONVERSION(uint32_t, UInt32Type)
+ARROW_STL_CONVERSION(uint64_t, UInt64Type)
+ARROW_STL_CONVERSION(float, FloatType)
+ARROW_STL_CONVERSION(double, DoubleType)
+
+template <>
+struct ConversionTraits<std::string> : public CTypeTraits<std::string> {
+ static Status AppendRow(StringBuilder& builder, const std::string& cell) {
+ return builder.Append(cell);
+ }
+ static std::string GetEntry(const StringArray& array, size_t j) {
+ return array.GetString(j);
+ }
+};
+
+/// Append cell range elements as a single value to the list builder.
+///
+/// Cell range will be added to child builder using AppendListValues<ValueCType>()
+/// if provided. AppendListValues<ValueCType>() has a default implementation, but
+/// it can be specialized by users.
+template <typename ValueCType, typename ListBuilderType, typename Range>
+Status AppendCellRange(ListBuilderType& builder, Range&& cell_range) {
+ constexpr bool is_list_builder = std::is_same<ListBuilderType, ListBuilder>::value;
+ constexpr bool is_large_list_builder =
+ std::is_same<ListBuilderType, LargeListBuilder>::value;
+ static_assert(
+ is_list_builder || is_large_list_builder,
+ "Builder type must be either ListBuilder or LargeListBuilder for appending "
+ "multiple rows.");
+
+ using ChildBuilderType = CBuilderType<ValueCType>;
+ ARROW_RETURN_NOT_OK(builder.Append());
+ auto& value_builder =
+ ::arrow::internal::checked_cast<ChildBuilderType&>(*builder.value_builder());
+
+ // XXX: Remove appended value before returning if status isn't OK?
+ return AppendListValues<ValueCType>(value_builder, std::forward<Range>(cell_range));
+}
+
+template <typename ValueCType>
+struct ConversionTraits<std::vector<ValueCType>>
+ : public CTypeTraits<std::vector<ValueCType>> {
+ static Status AppendRow(ListBuilder& builder, const std::vector<ValueCType>& cell) {
+ return AppendCellRange<ValueCType>(builder, cell);
+ }
+
+ static std::vector<ValueCType> GetEntry(const ListArray& array, size_t j) {
+ using ElementArrayType =
+ typename TypeTraits<typename ConversionTraits<ValueCType>::ArrowType>::ArrayType;
+
+ const ElementArrayType& value_array =
+ ::arrow::internal::checked_cast<const ElementArrayType&>(*array.values());
+
+ std::vector<ValueCType> vec(array.value_length(j));
+ for (int64_t i = 0; i < array.value_length(j); i++) {
+ vec[i] =
+ ConversionTraits<ValueCType>::GetEntry(value_array, array.value_offset(j) + i);
+ }
+ return vec;
+ }
+};
+
+template <typename Optional>
+struct ConversionTraits<Optional, enable_if_optional_like<Optional>>
+ : public CTypeTraits<typename std::decay<decltype(*std::declval<Optional>())>::type> {
+ using OptionalInnerType =
+ typename std::decay<decltype(*std::declval<Optional>())>::type;
+ using typename CTypeTraits<OptionalInnerType>::ArrowType;
+ using CTypeTraits<OptionalInnerType>::type_singleton;
+
+ static Status AppendRow(typename TypeTraits<ArrowType>::BuilderType& builder,
+ const Optional& cell) {
+ if (cell) {
+ return ConversionTraits<OptionalInnerType>::AppendRow(builder, *cell);
+ } else {
+ return builder.AppendNull();
+ }
+ }
+};
+
+/// Build an arrow::Schema based upon the types defined in a std::tuple-like structure.
+///
+/// While the type information is available at compile-time, we still need to add the
+/// column names at runtime, thus these methods are not constexpr.
+template <typename Tuple, std::size_t N = std::tuple_size<Tuple>::value>
+struct SchemaFromTuple {
+ using Element = internal::BareTupleElement<N - 1, Tuple>;
+
+ // Implementations that take a vector-like object for the column names.
+
+ /// Recursively build a vector of arrow::Field from the defined types.
+ ///
+ /// In most cases MakeSchema is the better entrypoint for the Schema creation.
+ static std::vector<std::shared_ptr<Field>> MakeSchemaRecursion(
+ const std::vector<std::string>& names) {
+ std::vector<std::shared_ptr<Field>> ret =
+ SchemaFromTuple<Tuple, N - 1>::MakeSchemaRecursion(names);
+ auto type = ConversionTraits<Element>::type_singleton();
+ ret.push_back(field(names[N - 1], type, internal::is_optional_like<Element>::value));
+ return ret;
+ }
+
+ /// Build a Schema from the types of the tuple-like structure passed in as template
+ /// parameter assign the column names at runtime.
+ ///
+ /// An example usage of this API can look like the following:
+ ///
+ /// \code{.cpp}
+ /// using TupleType = std::tuple<int, std::vector<std::string>>;
+ /// std::shared_ptr<Schema> schema =
+ /// SchemaFromTuple<TupleType>::MakeSchema({"int_column", "list_of_strings_column"});
+ /// \endcode
+ static std::shared_ptr<Schema> MakeSchema(const std::vector<std::string>& names) {
+ return std::make_shared<Schema>(MakeSchemaRecursion(names));
+ }
+
+ // Implementations that take a tuple-like object for the column names.
+
+ /// Recursively build a vector of arrow::Field from the defined types.
+ ///
+ /// In most cases MakeSchema is the better entrypoint for the Schema creation.
+ template <typename NamesTuple>
+ static std::vector<std::shared_ptr<Field>> MakeSchemaRecursionT(
+ const NamesTuple& names) {
+ using std::get;
+
+ std::vector<std::shared_ptr<Field>> ret =
+ SchemaFromTuple<Tuple, N - 1>::MakeSchemaRecursionT(names);
+ std::shared_ptr<DataType> type = ConversionTraits<Element>::type_singleton();
+ ret.push_back(
+ field(get<N - 1>(names), type, internal::is_optional_like<Element>::value));
+ return ret;
+ }
+
+ /// Build a Schema from the types of the tuple-like structure passed in as template
+ /// parameter assign the column names at runtime.
+ ///
+ /// An example usage of this API can look like the following:
+ ///
+ /// \code{.cpp}
+ /// using TupleType = std::tuple<int, std::vector<std::string>>;
+ /// std::shared_ptr<Schema> schema =
+ /// SchemaFromTuple<TupleType>::MakeSchema({"int_column", "list_of_strings_column"});
+ /// \endcode
+ template <typename NamesTuple>
+ static std::shared_ptr<Schema> MakeSchema(const NamesTuple& names) {
+ return std::make_shared<Schema>(MakeSchemaRecursionT<NamesTuple>(names));
+ }
+};
+
+template <typename Tuple>
+struct SchemaFromTuple<Tuple, 0> {
+ static std::vector<std::shared_ptr<Field>> MakeSchemaRecursion(
+ const std::vector<std::string>& names) {
+ std::vector<std::shared_ptr<Field>> ret;
+ ret.reserve(names.size());
+ return ret;
+ }
+
+ template <typename NamesTuple>
+ static std::vector<std::shared_ptr<Field>> MakeSchemaRecursionT(
+ const NamesTuple& names) {
+ std::vector<std::shared_ptr<Field>> ret;
+ ret.reserve(std::tuple_size<NamesTuple>::value);
+ return ret;
+ }
+};
+
+namespace internal {
+
+template <typename Tuple, std::size_t N = std::tuple_size<Tuple>::value>
+struct CreateBuildersRecursive {
+ static Status Make(MemoryPool* pool,
+ std::vector<std::unique_ptr<ArrayBuilder>>* builders) {
+ using Element = BareTupleElement<N - 1, Tuple>;
+ std::shared_ptr<DataType> type = ConversionTraits<Element>::type_singleton();
+ ARROW_RETURN_NOT_OK(MakeBuilder(pool, type, &builders->at(N - 1)));
+
+ return CreateBuildersRecursive<Tuple, N - 1>::Make(pool, builders);
+ }
+};
+
+template <typename Tuple>
+struct CreateBuildersRecursive<Tuple, 0> {
+ static Status Make(MemoryPool*, std::vector<std::unique_ptr<ArrayBuilder>>*) {
+ return Status::OK();
+ }
+};
+
+template <typename Tuple, std::size_t N = std::tuple_size<Tuple>::value>
+struct RowIterator {
+ static Status Append(const std::vector<std::unique_ptr<ArrayBuilder>>& builders,
+ const Tuple& row) {
+ using std::get;
+ using Element = BareTupleElement<N - 1, Tuple>;
+ using BuilderType =
+ typename TypeTraits<typename ConversionTraits<Element>::ArrowType>::BuilderType;
+
+ BuilderType& builder =
+ ::arrow::internal::checked_cast<BuilderType&>(*builders[N - 1]);
+ ARROW_RETURN_NOT_OK(ConversionTraits<Element>::AppendRow(builder, get<N - 1>(row)));
+
+ return RowIterator<Tuple, N - 1>::Append(builders, row);
+ }
+};
+
+template <typename Tuple>
+struct RowIterator<Tuple, 0> {
+ static Status Append(const std::vector<std::unique_ptr<ArrayBuilder>>& builders,
+ const Tuple& row) {
+ return Status::OK();
+ }
+};
+
+template <typename Tuple, std::size_t N = std::tuple_size<Tuple>::value>
+struct EnsureColumnTypes {
+ static Status Cast(const Table& table, std::shared_ptr<Table>* table_owner,
+ const compute::CastOptions& cast_options, compute::ExecContext* ctx,
+ std::reference_wrapper<const ::arrow::Table>* result) {
+ using Element = BareTupleElement<N - 1, Tuple>;
+ std::shared_ptr<DataType> expected_type = ConversionTraits<Element>::type_singleton();
+
+ if (!table.schema()->field(N - 1)->type()->Equals(*expected_type)) {
+ ARROW_ASSIGN_OR_RAISE(
+ Datum casted,
+ compute::Cast(table.column(N - 1), expected_type, cast_options, ctx));
+ auto new_field = table.schema()->field(N - 1)->WithType(expected_type);
+ ARROW_ASSIGN_OR_RAISE(*table_owner,
+ table.SetColumn(N - 1, new_field, casted.chunked_array()));
+ *result = **table_owner;
+ }
+
+ return EnsureColumnTypes<Tuple, N - 1>::Cast(result->get(), table_owner, cast_options,
+ ctx, result);
+ }
+};
+
+template <typename Tuple>
+struct EnsureColumnTypes<Tuple, 0> {
+ static Status Cast(const Table& table, std::shared_ptr<Table>* table_owner,
+ const compute::CastOptions& cast_options, compute::ExecContext* ctx,
+ std::reference_wrapper<const ::arrow::Table>* result) {
+ return Status::OK();
+ }
+};
+
+template <typename Range, typename Tuple, std::size_t N = std::tuple_size<Tuple>::value>
+struct TupleSetter {
+ static void Fill(const Table& table, Range* rows) {
+ using std::get;
+ using Element = typename std::tuple_element<N - 1, Tuple>::type;
+ using ArrayType =
+ typename TypeTraits<typename ConversionTraits<Element>::ArrowType>::ArrayType;
+
+ auto iter = rows->begin();
+ const ChunkedArray& chunked_array = *table.column(N - 1);
+ for (int i = 0; i < chunked_array.num_chunks(); i++) {
+ const ArrayType& array =
+ ::arrow::internal::checked_cast<const ArrayType&>(*chunked_array.chunk(i));
+ for (int64_t j = 0; j < array.length(); j++) {
+ get<N - 1>(*iter++) = ConversionTraits<Element>::GetEntry(array, j);
+ }
+ }
+
+ return TupleSetter<Range, Tuple, N - 1>::Fill(table, rows);
+ }
+};
+
+template <typename Range, typename Tuple>
+struct TupleSetter<Range, Tuple, 0> {
+ static void Fill(const Table& table, Range* rows) {}
+};
+
+} // namespace internal
+
+template <typename Range>
+Status TableFromTupleRange(MemoryPool* pool, Range&& rows,
+ const std::vector<std::string>& names,
+ std::shared_ptr<Table>* table) {
+ using row_type = typename std::iterator_traits<decltype(std::begin(rows))>::value_type;
+ constexpr std::size_t n_columns = std::tuple_size<row_type>::value;
+
+ std::shared_ptr<Schema> schema = SchemaFromTuple<row_type>::MakeSchema(names);
+
+ std::vector<std::unique_ptr<ArrayBuilder>> builders(n_columns);
+ ARROW_RETURN_NOT_OK(internal::CreateBuildersRecursive<row_type>::Make(pool, &builders));
+
+ for (auto const& row : rows) {
+ ARROW_RETURN_NOT_OK(internal::RowIterator<row_type>::Append(builders, row));
+ }
+
+ std::vector<std::shared_ptr<Array>> arrays;
+ for (auto const& builder : builders) {
+ std::shared_ptr<Array> array;
+ ARROW_RETURN_NOT_OK(builder->Finish(&array));
+ arrays.emplace_back(array);
+ }
+
+ *table = Table::Make(std::move(schema), std::move(arrays));
+
+ return Status::OK();
+}
+
+template <typename Range>
+Status TupleRangeFromTable(const Table& table, const compute::CastOptions& cast_options,
+ compute::ExecContext* ctx, Range* rows) {
+ using row_type = typename std::decay<decltype(*std::begin(*rows))>::type;
+ constexpr std::size_t n_columns = std::tuple_size<row_type>::value;
+
+ if (table.schema()->num_fields() != n_columns) {
+ std::stringstream ss;
+ ss << "Number of columns in the table does not match the width of the target: ";
+ ss << table.schema()->num_fields() << " != " << n_columns;
+ return Status::Invalid(ss.str());
+ }
+
+ // TODO: Use std::size with C++17
+ if (rows->size() != static_cast<size_t>(table.num_rows())) {
+ std::stringstream ss;
+ ss << "Number of rows in the table does not match the size of the target: ";
+ ss << table.num_rows() << " != " << rows->size();
+ return Status::Invalid(ss.str());
+ }
+
+ // Check that all columns have the correct type, otherwise cast them.
+ std::shared_ptr<Table> table_owner;
+ std::reference_wrapper<const ::arrow::Table> current_table(table);
+
+ ARROW_RETURN_NOT_OK(internal::EnsureColumnTypes<row_type>::Cast(
+ table, &table_owner, cast_options, ctx, &current_table));
+
+ internal::TupleSetter<Range, row_type>::Fill(current_table.get(), rows);
+
+ return Status::OK();
+}
+
+} // namespace stl
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/stl_allocator.h b/src/arrow/cpp/src/arrow/stl_allocator.h
new file mode 100644
index 000000000..b5ad2b534
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/stl_allocator.h
@@ -0,0 +1,153 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "arrow/memory_pool.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace stl {
+
+/// \brief A STL allocator delegating allocations to a Arrow MemoryPool
+template <class T>
+class allocator {
+ public:
+ using value_type = T;
+ using pointer = T*;
+ using const_pointer = const T*;
+ using reference = T&;
+ using const_reference = const T&;
+ using size_type = std::size_t;
+ using difference_type = std::ptrdiff_t;
+
+ template <class U>
+ struct rebind {
+ using other = allocator<U>;
+ };
+
+ /// \brief Construct an allocator from the default MemoryPool
+ allocator() noexcept : pool_(default_memory_pool()) {}
+ /// \brief Construct an allocator from the given MemoryPool
+ explicit allocator(MemoryPool* pool) noexcept : pool_(pool) {}
+
+ template <class U>
+ allocator(const allocator<U>& rhs) noexcept : pool_(rhs.pool()) {}
+
+ ~allocator() { pool_ = NULLPTR; }
+
+ pointer address(reference r) const noexcept { return std::addressof(r); }
+
+ const_pointer address(const_reference r) const noexcept { return std::addressof(r); }
+
+ pointer allocate(size_type n, const void* /*hint*/ = NULLPTR) {
+ uint8_t* data;
+ Status s = pool_->Allocate(n * sizeof(T), &data);
+ if (!s.ok()) throw std::bad_alloc();
+ return reinterpret_cast<pointer>(data);
+ }
+
+ void deallocate(pointer p, size_type n) {
+ pool_->Free(reinterpret_cast<uint8_t*>(p), n * sizeof(T));
+ }
+
+ size_type size_max() const noexcept { return size_type(-1) / sizeof(T); }
+
+ template <class U, class... Args>
+ void construct(U* p, Args&&... args) {
+ new (reinterpret_cast<void*>(p)) U(std::forward<Args>(args)...);
+ }
+
+ template <class U>
+ void destroy(U* p) {
+ p->~U();
+ }
+
+ MemoryPool* pool() const noexcept { return pool_; }
+
+ private:
+ MemoryPool* pool_;
+};
+
+/// \brief A MemoryPool implementation delegating allocations to a STL allocator
+///
+/// Note that STL allocators don't provide a resizing operation, and therefore
+/// any buffer resizes will do a full reallocation and copy.
+template <typename Allocator = std::allocator<uint8_t>>
+class STLMemoryPool : public MemoryPool {
+ public:
+ /// \brief Construct a memory pool from the given allocator
+ explicit STLMemoryPool(const Allocator& alloc) : alloc_(alloc) {}
+
+ Status Allocate(int64_t size, uint8_t** out) override {
+ try {
+ *out = alloc_.allocate(size);
+ } catch (std::bad_alloc& e) {
+ return Status::OutOfMemory(e.what());
+ }
+ stats_.UpdateAllocatedBytes(size);
+ return Status::OK();
+ }
+
+ Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) override {
+ uint8_t* old_ptr = *ptr;
+ try {
+ *ptr = alloc_.allocate(new_size);
+ } catch (std::bad_alloc& e) {
+ return Status::OutOfMemory(e.what());
+ }
+ memcpy(*ptr, old_ptr, std::min(old_size, new_size));
+ alloc_.deallocate(old_ptr, old_size);
+ stats_.UpdateAllocatedBytes(new_size - old_size);
+ return Status::OK();
+ }
+
+ void Free(uint8_t* buffer, int64_t size) override {
+ alloc_.deallocate(buffer, size);
+ stats_.UpdateAllocatedBytes(-size);
+ }
+
+ int64_t bytes_allocated() const override { return stats_.bytes_allocated(); }
+
+ int64_t max_memory() const override { return stats_.max_memory(); }
+
+ std::string backend_name() const override { return "stl"; }
+
+ private:
+ Allocator alloc_;
+ arrow::internal::MemoryPoolStats stats_;
+};
+
+template <class T1, class T2>
+bool operator==(const allocator<T1>& lhs, const allocator<T2>& rhs) noexcept {
+ return lhs.pool() == rhs.pool();
+}
+
+template <class T1, class T2>
+bool operator!=(const allocator<T1>& lhs, const allocator<T2>& rhs) noexcept {
+ return !(lhs == rhs);
+}
+
+} // namespace stl
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/stl_iterator.h b/src/arrow/cpp/src/arrow/stl_iterator.h
new file mode 100644
index 000000000..6225a89aa
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/stl_iterator.h
@@ -0,0 +1,146 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstddef>
+#include <iterator>
+#include <utility>
+
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/optional.h"
+
+namespace arrow {
+namespace stl {
+
+namespace detail {
+
+template <typename ArrayType>
+struct DefaultValueAccessor {
+ using ValueType = decltype(std::declval<ArrayType>().GetView(0));
+
+ ValueType operator()(const ArrayType& array, int64_t index) {
+ return array.GetView(index);
+ }
+};
+
+} // namespace detail
+
+template <typename ArrayType,
+ typename ValueAccessor = detail::DefaultValueAccessor<ArrayType>>
+class ArrayIterator {
+ public:
+ using value_type = arrow::util::optional<typename ValueAccessor::ValueType>;
+ using difference_type = int64_t;
+ using pointer = value_type*;
+ using reference = value_type&;
+ using iterator_category = std::random_access_iterator_tag;
+
+ // Some algorithms need to default-construct an iterator
+ ArrayIterator() : array_(NULLPTR), index_(0) {}
+
+ explicit ArrayIterator(const ArrayType& array, int64_t index = 0)
+ : array_(&array), index_(index) {}
+
+ // Value access
+ value_type operator*() const {
+ return array_->IsNull(index_) ? value_type{} : array_->GetView(index_);
+ }
+
+ value_type operator[](difference_type n) const {
+ return array_->IsNull(index_ + n) ? value_type{} : array_->GetView(index_ + n);
+ }
+
+ int64_t index() const { return index_; }
+
+ // Forward / backward
+ ArrayIterator& operator++() {
+ ++index_;
+ return *this;
+ }
+ ArrayIterator& operator--() {
+ --index_;
+ return *this;
+ }
+ ArrayIterator operator++(int) {
+ ArrayIterator tmp(*this);
+ ++index_;
+ return tmp;
+ }
+ ArrayIterator operator--(int) {
+ ArrayIterator tmp(*this);
+ --index_;
+ return tmp;
+ }
+
+ // Arithmetic
+ difference_type operator-(const ArrayIterator& other) const {
+ return index_ - other.index_;
+ }
+ ArrayIterator operator+(difference_type n) const {
+ return ArrayIterator(*array_, index_ + n);
+ }
+ ArrayIterator operator-(difference_type n) const {
+ return ArrayIterator(*array_, index_ - n);
+ }
+ friend inline ArrayIterator operator+(difference_type diff,
+ const ArrayIterator& other) {
+ return ArrayIterator(*other.array_, diff + other.index_);
+ }
+ friend inline ArrayIterator operator-(difference_type diff,
+ const ArrayIterator& other) {
+ return ArrayIterator(*other.array_, diff - other.index_);
+ }
+ ArrayIterator& operator+=(difference_type n) {
+ index_ += n;
+ return *this;
+ }
+ ArrayIterator& operator-=(difference_type n) {
+ index_ -= n;
+ return *this;
+ }
+
+ // Comparisons
+ bool operator==(const ArrayIterator& other) const { return index_ == other.index_; }
+ bool operator!=(const ArrayIterator& other) const { return index_ != other.index_; }
+ bool operator<(const ArrayIterator& other) const { return index_ < other.index_; }
+ bool operator>(const ArrayIterator& other) const { return index_ > other.index_; }
+ bool operator<=(const ArrayIterator& other) const { return index_ <= other.index_; }
+ bool operator>=(const ArrayIterator& other) const { return index_ >= other.index_; }
+
+ private:
+ const ArrayType* array_;
+ int64_t index_;
+};
+
+} // namespace stl
+} // namespace arrow
+
+namespace std {
+
+template <typename ArrayType>
+struct iterator_traits<::arrow::stl::ArrayIterator<ArrayType>> {
+ using IteratorType = ::arrow::stl::ArrayIterator<ArrayType>;
+ using difference_type = typename IteratorType::difference_type;
+ using value_type = typename IteratorType::value_type;
+ using pointer = typename IteratorType::pointer;
+ using reference = typename IteratorType::reference;
+ using iterator_category = typename IteratorType::iterator_category;
+};
+
+} // namespace std
diff --git a/src/arrow/cpp/src/arrow/stl_iterator_test.cc b/src/arrow/cpp/src/arrow/stl_iterator_test.cc
new file mode 100644
index 000000000..864011dbf
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/stl_iterator_test.cc
@@ -0,0 +1,252 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+
+#include <gtest/gtest.h>
+
+#include "arrow/stl.h"
+#include "arrow/stl_iterator.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+using util::nullopt;
+using util::optional;
+
+namespace stl {
+
+TEST(ArrayIterator, Basics) {
+ auto array =
+ checked_pointer_cast<Int32Array>(ArrayFromJSON(int32(), "[4, 5, null, 6]"));
+
+ ArrayIterator<Int32Array> it(*array);
+ optional<int32_t> v = *it;
+ ASSERT_EQ(v, 4);
+ ASSERT_EQ(it[0], 4);
+ ++it;
+ ASSERT_EQ(it[0], 5);
+ ASSERT_EQ(*it, 5);
+ ASSERT_EQ(it[1], nullopt);
+ ASSERT_EQ(it[2], 6);
+}
+
+TEST(ArrayIterator, Arithmetic) {
+ auto array = checked_pointer_cast<Int32Array>(
+ ArrayFromJSON(int32(), "[4, 5, null, 6, null, 7]"));
+
+ ArrayIterator<Int32Array> it(*array);
+ auto it2 = it + 2;
+ ASSERT_EQ(*it, 4);
+ ASSERT_EQ(*it2, nullopt);
+ ASSERT_EQ(it2 - it, 2);
+ ASSERT_EQ(it - it2, -2);
+ auto it3 = it++;
+ ASSERT_EQ(it2 - it, 1);
+ ASSERT_EQ(it2 - it3, 2);
+ ASSERT_EQ(*it3, 4);
+ ASSERT_EQ(*it, 5);
+ auto it4 = ++it;
+ ASSERT_EQ(*it, nullopt);
+ ASSERT_EQ(*it4, nullopt);
+ ASSERT_EQ(it2 - it, 0);
+ ASSERT_EQ(it2 - it4, 0);
+ auto it5 = it + 3;
+ ASSERT_EQ(*it5, 7);
+ ASSERT_EQ(*(it5 - 2), 6);
+ ASSERT_EQ(*(it5 + (-2)), 6);
+ auto it6 = (--it5)--;
+ ASSERT_EQ(*it6, nullopt);
+ ASSERT_EQ(*it5, 6);
+ ASSERT_EQ(it6 - it5, 1);
+}
+
+TEST(ArrayIterator, Comparison) {
+ auto array = checked_pointer_cast<Int32Array>(
+ ArrayFromJSON(int32(), "[4, 5, null, 6, null, 7]"));
+
+ auto it = ArrayIterator<Int32Array>(*array) + 2;
+ auto it2 = ArrayIterator<Int32Array>(*array) + 2;
+ auto it3 = ArrayIterator<Int32Array>(*array) + 4;
+
+ ASSERT_TRUE(it == it2);
+ ASSERT_TRUE(it <= it2);
+ ASSERT_TRUE(it >= it2);
+ ASSERT_FALSE(it != it2);
+ ASSERT_FALSE(it < it2);
+ ASSERT_FALSE(it > it2);
+
+ ASSERT_FALSE(it == it3);
+ ASSERT_TRUE(it <= it3);
+ ASSERT_FALSE(it >= it3);
+ ASSERT_TRUE(it != it3);
+ ASSERT_TRUE(it < it3);
+ ASSERT_FALSE(it > it3);
+}
+
+TEST(ArrayIterator, BeginEnd) {
+ auto array =
+ checked_pointer_cast<Int32Array>(ArrayFromJSON(int32(), "[4, 5, null, 6]"));
+ std::vector<optional<int32_t>> values;
+ for (auto it = array->begin(); it != array->end(); ++it) {
+ values.push_back(*it);
+ }
+ std::vector<optional<int32_t>> expected{4, 5, {}, 6};
+ ASSERT_EQ(values, expected);
+}
+
+TEST(ArrayIterator, RangeFor) {
+ auto array =
+ checked_pointer_cast<Int32Array>(ArrayFromJSON(int32(), "[4, 5, null, 6]"));
+ std::vector<optional<int32_t>> values;
+ for (const auto v : *array) {
+ values.push_back(v);
+ }
+ std::vector<optional<int32_t>> expected{4, 5, {}, 6};
+ ASSERT_EQ(values, expected);
+}
+
+TEST(ArrayIterator, String) {
+ auto array = checked_pointer_cast<StringArray>(
+ ArrayFromJSON(utf8(), R"(["foo", "bar", null, "quux"])"));
+ std::vector<optional<util::string_view>> values;
+ for (const auto v : *array) {
+ values.push_back(v);
+ }
+ std::vector<optional<util::string_view>> expected{"foo", "bar", {}, "quux"};
+ ASSERT_EQ(values, expected);
+}
+
+TEST(ArrayIterator, Boolean) {
+ auto array = checked_pointer_cast<BooleanArray>(
+ ArrayFromJSON(boolean(), "[true, null, null, false]"));
+ std::vector<optional<bool>> values;
+ for (const auto v : *array) {
+ values.push_back(v);
+ }
+ std::vector<optional<bool>> expected{true, {}, {}, false};
+ ASSERT_EQ(values, expected);
+}
+
+TEST(ArrayIterator, FixedSizeBinary) {
+ auto array = checked_pointer_cast<FixedSizeBinaryArray>(
+ ArrayFromJSON(fixed_size_binary(3), R"(["foo", "bar", null, "quu"])"));
+ std::vector<optional<util::string_view>> values;
+ for (const auto v : *array) {
+ values.push_back(v);
+ }
+ std::vector<optional<util::string_view>> expected{"foo", "bar", {}, "quu"};
+ ASSERT_EQ(values, expected);
+}
+
+// Test compatibility with various STL algorithms
+
+TEST(ArrayIterator, StdFind) {
+ auto array = checked_pointer_cast<StringArray>(
+ ArrayFromJSON(utf8(), R"(["foo", "bar", null, "quux"])"));
+
+ auto it = std::find(array->begin(), array->end(), "bar");
+ ASSERT_EQ(it.index(), 1);
+ it = std::find(array->begin(), array->end(), nullopt);
+ ASSERT_EQ(it.index(), 2);
+ it = std::find(array->begin(), array->end(), "zzz");
+ ASSERT_EQ(it, array->end());
+}
+
+TEST(ArrayIterator, StdCountIf) {
+ auto array = checked_pointer_cast<BooleanArray>(
+ ArrayFromJSON(boolean(), "[true, null, null, false]"));
+
+ auto n = std::count_if(array->begin(), array->end(),
+ [](optional<bool> v) { return !v.has_value(); });
+ ASSERT_EQ(n, 2);
+}
+
+TEST(ArrayIterator, StdCopy) {
+ auto array = checked_pointer_cast<BooleanArray>(
+ ArrayFromJSON(boolean(), "[true, null, null, false]"));
+ std::vector<optional<bool>> values;
+ std::copy(array->begin() + 1, array->end(), std::back_inserter(values));
+ std::vector<optional<bool>> expected{{}, {}, false};
+ ASSERT_EQ(values, expected);
+}
+
+TEST(ArrayIterator, StdPartitionPoint) {
+ auto array = checked_pointer_cast<DoubleArray>(
+ ArrayFromJSON(float64(), "[4.5, 2.5, 1e100, 3, null, null, null, null, null]"));
+ auto it = std::partition_point(array->begin(), array->end(),
+ [](optional<double> v) { return v.has_value(); });
+ ASSERT_EQ(it.index(), 4);
+ ASSERT_EQ(*it, nullopt);
+
+ array =
+ checked_pointer_cast<DoubleArray>(ArrayFromJSON(float64(), "[null, null, null]"));
+ it = std::partition_point(array->begin(), array->end(),
+ [](optional<double> v) { return v.has_value(); });
+ ASSERT_EQ(it, array->begin());
+ it = std::partition_point(array->begin(), array->end(),
+ [](optional<double> v) { return !v.has_value(); });
+ ASSERT_EQ(it, array->end());
+}
+
+TEST(ArrayIterator, StdEqualRange) {
+ auto array = checked_pointer_cast<Int8Array>(
+ ArrayFromJSON(int8(), "[1, 4, 5, 5, 5, 7, 8, null, null, null]"));
+ auto cmp_lt = [](optional<int8_t> u, optional<int8_t> v) {
+ return u.has_value() && (!v.has_value() || *u < *v);
+ };
+
+ auto pair = std::equal_range(array->begin(), array->end(), nullopt, cmp_lt);
+ ASSERT_EQ(pair.first, array->end() - 3);
+ ASSERT_EQ(pair.second, array->end());
+ pair = std::equal_range(array->begin(), array->end(), 6, cmp_lt);
+ ASSERT_EQ(pair.first, array->begin() + 5);
+ ASSERT_EQ(pair.second, pair.first);
+ pair = std::equal_range(array->begin(), array->end(), 5, cmp_lt);
+ ASSERT_EQ(pair.first, array->begin() + 2);
+ ASSERT_EQ(pair.second, array->begin() + 5);
+ pair = std::equal_range(array->begin(), array->end(), 1, cmp_lt);
+ ASSERT_EQ(pair.first, array->begin());
+ ASSERT_EQ(pair.second, array->begin() + 1);
+ pair = std::equal_range(array->begin(), array->end(), 0, cmp_lt);
+ ASSERT_EQ(pair.first, array->begin());
+ ASSERT_EQ(pair.second, pair.first);
+}
+
+TEST(ArrayIterator, StdMerge) {
+ auto array1 = checked_pointer_cast<Int8Array>(
+ ArrayFromJSON(int8(), "[1, 4, 5, 5, 7, null, null, null]"));
+ auto array2 =
+ checked_pointer_cast<Int8Array>(ArrayFromJSON(int8(), "[-1, 3, 3, 6, 42]"));
+ auto cmp_lt = [](optional<int8_t> u, optional<int8_t> v) {
+ return u.has_value() && (!v.has_value() || *u < *v);
+ };
+
+ std::vector<optional<int8_t>> values;
+ std::merge(array1->begin(), array1->end(), array2->begin(), array2->end(),
+ std::back_inserter(values), cmp_lt);
+ std::vector<optional<int8_t>> expected{-1, 1, 3, 3, 4, 5, 5, 6, 7, 42, {}, {}, {}};
+ ASSERT_EQ(values, expected);
+}
+
+} // namespace stl
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/stl_test.cc b/src/arrow/cpp/src/arrow/stl_test.cc
new file mode 100644
index 000000000..159d1d983
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/stl_test.cc
@@ -0,0 +1,558 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <new>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/memory_pool.h"
+#include "arrow/stl.h"
+#include "arrow/stl_allocator.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/optional.h"
+
+using primitive_types_tuple = std::tuple<int8_t, int16_t, int32_t, int64_t, uint8_t,
+ uint16_t, uint32_t, uint64_t, bool, std::string>;
+
+using raw_pointer_optional_types_tuple =
+ std::tuple<int8_t*, int16_t*, int32_t*, int64_t*, uint8_t*, uint16_t*, uint32_t*,
+ uint64_t*, bool*, std::string*>;
+
+struct CustomType {
+ int8_t i8;
+ int16_t i16;
+ int32_t i32;
+ int64_t i64;
+ uint8_t u8;
+ uint16_t u16;
+ uint32_t u32;
+ uint64_t u64;
+ bool b;
+ std::string s;
+
+#define ARROW_CUSTOM_TYPE_TIED std::tie(i8, i16, i32, i64, u8, u16, u32, u64, b, s)
+ auto tie() const -> decltype(ARROW_CUSTOM_TYPE_TIED) { return ARROW_CUSTOM_TYPE_TIED; }
+#undef ARROW_CUSTOM_TYPE_TIED
+};
+
+// Mock optional object returning null, "yes", "no", null, "yes", "no", ...
+// Note: This mock optional object will advance its state every time it's casted
+// to bool. Successive castings to bool may give inconsistent results. It
+// doesn't mock entire optional logic. It is used only for ensuring user
+// specialization isn't broken with templated Optionals.
+struct CustomOptionalTypeMock {
+ static int counter;
+ mutable bool was_casted_once_ = false;
+
+ CustomOptionalTypeMock() = default;
+ explicit operator bool() const {
+ if (!was_casted_once_) {
+ was_casted_once_ = true;
+ counter++;
+ return counter % 3 != 0;
+ }
+ ADD_FAILURE() << "A CustomOptionalTypeMock should be casted to bool only once.";
+ return false;
+ }
+ std::string operator*() const {
+ switch (counter % 3) {
+ case 0:
+ ADD_FAILURE() << "Optional dereferenced in null value";
+ break;
+ case 1:
+ return "yes";
+ case 2:
+ return "no";
+ }
+ return "error";
+ }
+};
+
+int CustomOptionalTypeMock::counter = -1;
+
+// This is for testing appending list values with custom types
+struct TestInt32Type {
+ int32_t value;
+};
+
+namespace arrow {
+
+using optional_types_tuple =
+ std::tuple<util::optional<int8_t>, util::optional<int16_t>, util::optional<int32_t>,
+ util::optional<int64_t>, util::optional<uint8_t>, util::optional<uint16_t>,
+ util::optional<uint32_t>, util::optional<uint64_t>, util::optional<bool>,
+ util::optional<std::string>>;
+
+template <>
+struct CTypeTraits<CustomOptionalTypeMock> {
+ using ArrowType = ::arrow::StringType;
+
+ static std::shared_ptr<::arrow::DataType> type_singleton() { return ::arrow::utf8(); }
+};
+
+template <>
+struct CTypeTraits<TestInt32Type> {
+ using ArrowType = ::arrow::Int32Type;
+
+ static std::shared_ptr<::arrow::DataType> type_singleton() { return ::arrow::int32(); }
+};
+
+namespace stl {
+
+template <>
+struct ConversionTraits<CustomOptionalTypeMock>
+ : public CTypeTraits<CustomOptionalTypeMock> {
+ static Status AppendRow(typename TypeTraits<ArrowType>::BuilderType& builder,
+ const CustomOptionalTypeMock& cell) {
+ if (cell) {
+ return builder.Append("mock " + *cell);
+ } else {
+ return builder.AppendNull();
+ }
+ }
+};
+
+template <>
+struct ConversionTraits<TestInt32Type> : public CTypeTraits<TestInt32Type> {
+ // AppendRow is not needed, explicitly elide an implementation
+};
+
+template <>
+Status AppendListValues<TestInt32Type, const std::vector<TestInt32Type>&>(
+ Int32Builder& value_builder, const std::vector<TestInt32Type>& cell_range) {
+ return value_builder.AppendValues(reinterpret_cast<const int32_t*>(cell_range.data()),
+ cell_range.size());
+}
+
+TEST(TestSchemaFromTuple, PrimitiveTypesVector) {
+ Schema expected_schema(
+ {field("column1", int8(), false), field("column2", int16(), false),
+ field("column3", int32(), false), field("column4", int64(), false),
+ field("column5", uint8(), false), field("column6", uint16(), false),
+ field("column7", uint32(), false), field("column8", uint64(), false),
+ field("column9", boolean(), false), field("column10", utf8(), false)});
+
+ std::shared_ptr<Schema> schema = SchemaFromTuple<primitive_types_tuple>::MakeSchema(
+ std::vector<std::string>({"column1", "column2", "column3", "column4", "column5",
+ "column6", "column7", "column8", "column9", "column10"}));
+ ASSERT_TRUE(expected_schema.Equals(*schema));
+}
+
+TEST(TestSchemaFromTuple, PrimitiveTypesTuple) {
+ Schema expected_schema(
+ {field("column1", int8(), false), field("column2", int16(), false),
+ field("column3", int32(), false), field("column4", int64(), false),
+ field("column5", uint8(), false), field("column6", uint16(), false),
+ field("column7", uint32(), false), field("column8", uint64(), false),
+ field("column9", boolean(), false), field("column10", utf8(), false)});
+
+ std::shared_ptr<Schema> schema = SchemaFromTuple<primitive_types_tuple>::MakeSchema(
+ std::make_tuple("column1", "column2", "column3", "column4", "column5", "column6",
+ "column7", "column8", "column9", "column10"));
+ ASSERT_TRUE(expected_schema.Equals(*schema));
+}
+
+TEST(TestSchemaFromTuple, SimpleList) {
+ Schema expected_schema({field("column1", list(utf8()), false)});
+ std::shared_ptr<Schema> schema =
+ SchemaFromTuple<std::tuple<std::vector<std::string>>>::MakeSchema({"column1"});
+
+ ASSERT_TRUE(expected_schema.Equals(*schema));
+}
+
+TEST(TestSchemaFromTuple, NestedList) {
+ Schema expected_schema({field("column1", list(list(boolean())), false)});
+ std::shared_ptr<Schema> schema =
+ SchemaFromTuple<std::tuple<std::vector<std::vector<bool>>>>::MakeSchema(
+ {"column1"});
+
+ ASSERT_TRUE(expected_schema.Equals(*schema));
+}
+
+TEST(TestTableFromTupleVector, PrimitiveTypes) {
+ std::vector<std::string> names{"column1", "column2", "column3", "column4", "column5",
+ "column6", "column7", "column8", "column9", "column10"};
+ std::vector<primitive_types_tuple> rows{
+ primitive_types_tuple(-1, -2, -3, -4, 1, 2, 3, 4, true, "Tests"),
+ primitive_types_tuple(-10, -20, -30, -40, 10, 20, 30, 40, false, "Other")};
+ std::shared_ptr<Table> table;
+ ASSERT_OK(TableFromTupleRange(default_memory_pool(), rows, names, &table));
+
+ std::shared_ptr<Schema> expected_schema =
+ schema({field("column1", int8(), false), field("column2", int16(), false),
+ field("column3", int32(), false), field("column4", int64(), false),
+ field("column5", uint8(), false), field("column6", uint16(), false),
+ field("column7", uint32(), false), field("column8", uint64(), false),
+ field("column9", boolean(), false), field("column10", utf8(), false)});
+
+ // Construct expected arrays
+ std::shared_ptr<Array> int8_array = ArrayFromJSON(int8(), "[-1, -10]");
+ std::shared_ptr<Array> int16_array = ArrayFromJSON(int16(), "[-2, -20]");
+ std::shared_ptr<Array> int32_array = ArrayFromJSON(int32(), "[-3, -30]");
+ std::shared_ptr<Array> int64_array = ArrayFromJSON(int64(), "[-4, -40]");
+ std::shared_ptr<Array> uint8_array = ArrayFromJSON(uint8(), "[1, 10]");
+ std::shared_ptr<Array> uint16_array = ArrayFromJSON(uint16(), "[2, 20]");
+ std::shared_ptr<Array> uint32_array = ArrayFromJSON(uint32(), "[3, 30]");
+ std::shared_ptr<Array> uint64_array = ArrayFromJSON(uint64(), "[4, 40]");
+ std::shared_ptr<Array> bool_array = ArrayFromJSON(boolean(), "[true, false]");
+ std::shared_ptr<Array> string_array = ArrayFromJSON(utf8(), R"(["Tests", "Other"])");
+ auto expected_table =
+ Table::Make(expected_schema,
+ {int8_array, int16_array, int32_array, int64_array, uint8_array,
+ uint16_array, uint32_array, uint64_array, bool_array, string_array});
+
+ ASSERT_TRUE(expected_table->Equals(*table));
+}
+
+TEST(TestTableFromTupleVector, ListType) {
+ using tuple_type = std::tuple<std::vector<int64_t>>;
+
+ auto expected_schema =
+ std::shared_ptr<Schema>(new Schema({field("column1", list(int64()), false)}));
+ std::shared_ptr<Array> expected_array =
+ ArrayFromJSON(list(int64()), "[[1, 1, 2, 34], [2, -4]]");
+ std::shared_ptr<Table> expected_table = Table::Make(expected_schema, {expected_array});
+
+ std::vector<tuple_type> rows{tuple_type(std::vector<int64_t>{1, 1, 2, 34}),
+ tuple_type(std::vector<int64_t>{2, -4})};
+ std::vector<std::string> names{"column1"};
+
+ std::shared_ptr<Table> table;
+ ASSERT_OK(TableFromTupleRange(default_memory_pool(), rows, names, &table));
+ ASSERT_TRUE(expected_table->Equals(*table));
+}
+
+TEST(TestTableFromTupleVector, ReferenceTuple) {
+ std::vector<std::string> names{"column1", "column2", "column3", "column4", "column5",
+ "column6", "column7", "column8", "column9", "column10"};
+ std::vector<CustomType> rows{
+ {-1, -2, -3, -4, 1, 2, 3, 4, true, std::string("Tests")},
+ {-10, -20, -30, -40, 10, 20, 30, 40, false, std::string("Other")}};
+ std::vector<decltype(rows[0].tie())> rng_rows{
+ rows[0].tie(),
+ rows[1].tie(),
+ };
+ std::shared_ptr<Table> table;
+ ASSERT_OK(TableFromTupleRange(default_memory_pool(), rng_rows, names, &table));
+
+ std::shared_ptr<Schema> expected_schema =
+ schema({field("column1", int8(), false), field("column2", int16(), false),
+ field("column3", int32(), false), field("column4", int64(), false),
+ field("column5", uint8(), false), field("column6", uint16(), false),
+ field("column7", uint32(), false), field("column8", uint64(), false),
+ field("column9", boolean(), false), field("column10", utf8(), false)});
+
+ // Construct expected arrays
+ std::shared_ptr<Array> int8_array = ArrayFromJSON(int8(), "[-1, -10]");
+ std::shared_ptr<Array> int16_array = ArrayFromJSON(int16(), "[-2, -20]");
+ std::shared_ptr<Array> int32_array = ArrayFromJSON(int32(), "[-3, -30]");
+ std::shared_ptr<Array> int64_array = ArrayFromJSON(int64(), "[-4, -40]");
+ std::shared_ptr<Array> uint8_array = ArrayFromJSON(uint8(), "[1, 10]");
+ std::shared_ptr<Array> uint16_array = ArrayFromJSON(uint16(), "[2, 20]");
+ std::shared_ptr<Array> uint32_array = ArrayFromJSON(uint32(), "[3, 30]");
+ std::shared_ptr<Array> uint64_array = ArrayFromJSON(uint64(), "[4, 40]");
+ std::shared_ptr<Array> bool_array = ArrayFromJSON(boolean(), "[true, false]");
+ std::shared_ptr<Array> string_array = ArrayFromJSON(utf8(), R"(["Tests", "Other"])");
+ auto expected_table =
+ Table::Make(expected_schema,
+ {int8_array, int16_array, int32_array, int64_array, uint8_array,
+ uint16_array, uint32_array, uint64_array, bool_array, string_array});
+
+ ASSERT_TRUE(expected_table->Equals(*table));
+}
+
+TEST(TestTableFromTupleVector, NullableTypesWithBoostOptional) {
+ std::vector<std::string> names{"column1", "column2", "column3", "column4", "column5",
+ "column6", "column7", "column8", "column9", "column10"};
+ using types_tuple = optional_types_tuple;
+ std::vector<types_tuple> rows{
+ types_tuple(-1, -2, -3, -4, 1, 2, 3, 4, true, std::string("Tests")),
+ types_tuple(-10, -20, -30, -40, 10, 20, 30, 40, false, std::string("Other")),
+ types_tuple(util::nullopt, util::nullopt, util::nullopt, util::nullopt,
+ util::nullopt, util::nullopt, util::nullopt, util::nullopt,
+ util::nullopt, util::nullopt),
+ };
+ std::shared_ptr<Table> table;
+ ASSERT_OK(TableFromTupleRange(default_memory_pool(), rows, names, &table));
+
+ std::shared_ptr<Schema> expected_schema =
+ schema({field("column1", int8(), true), field("column2", int16(), true),
+ field("column3", int32(), true), field("column4", int64(), true),
+ field("column5", uint8(), true), field("column6", uint16(), true),
+ field("column7", uint32(), true), field("column8", uint64(), true),
+ field("column9", boolean(), true), field("column10", utf8(), true)});
+
+ // Construct expected arrays
+ std::shared_ptr<Array> int8_array = ArrayFromJSON(int8(), "[-1, -10, null]");
+ std::shared_ptr<Array> int16_array = ArrayFromJSON(int16(), "[-2, -20, null]");
+ std::shared_ptr<Array> int32_array = ArrayFromJSON(int32(), "[-3, -30, null]");
+ std::shared_ptr<Array> int64_array = ArrayFromJSON(int64(), "[-4, -40, null]");
+ std::shared_ptr<Array> uint8_array = ArrayFromJSON(uint8(), "[1, 10, null]");
+ std::shared_ptr<Array> uint16_array = ArrayFromJSON(uint16(), "[2, 20, null]");
+ std::shared_ptr<Array> uint32_array = ArrayFromJSON(uint32(), "[3, 30, null]");
+ std::shared_ptr<Array> uint64_array = ArrayFromJSON(uint64(), "[4, 40, null]");
+ std::shared_ptr<Array> bool_array = ArrayFromJSON(boolean(), "[true, false, null]");
+ std::shared_ptr<Array> string_array =
+ ArrayFromJSON(utf8(), R"(["Tests", "Other", null])");
+ auto expected_table =
+ Table::Make(expected_schema,
+ {int8_array, int16_array, int32_array, int64_array, uint8_array,
+ uint16_array, uint32_array, uint64_array, bool_array, string_array});
+
+ ASSERT_TRUE(expected_table->Equals(*table));
+}
+
+TEST(TestTableFromTupleVector, NullableTypesWithRawPointer) {
+ std::vector<std::string> names{"column1", "column2", "column3", "column4", "column5",
+ "column6", "column7", "column8", "column9", "column10"};
+ std::vector<primitive_types_tuple> data_rows{
+ primitive_types_tuple(-1, -2, -3, -4, 1, 2, 3, 4, true, std::string("Tests")),
+ primitive_types_tuple(-10, -20, -30, -40, 10, 20, 30, 40, false,
+ std::string("Other")),
+ };
+ std::vector<raw_pointer_optional_types_tuple> pointer_rows;
+ for (auto& row : data_rows) {
+ pointer_rows.emplace_back(
+ std::addressof(std::get<0>(row)), std::addressof(std::get<1>(row)),
+ std::addressof(std::get<2>(row)), std::addressof(std::get<3>(row)),
+ std::addressof(std::get<4>(row)), std::addressof(std::get<5>(row)),
+ std::addressof(std::get<6>(row)), std::addressof(std::get<7>(row)),
+ std::addressof(std::get<8>(row)), std::addressof(std::get<9>(row)));
+ }
+ pointer_rows.emplace_back(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
+ nullptr, nullptr, nullptr);
+ std::shared_ptr<Table> table;
+ ASSERT_OK(TableFromTupleRange(default_memory_pool(), pointer_rows, names, &table));
+
+ std::shared_ptr<Schema> expected_schema =
+ schema({field("column1", int8(), true), field("column2", int16(), true),
+ field("column3", int32(), true), field("column4", int64(), true),
+ field("column5", uint8(), true), field("column6", uint16(), true),
+ field("column7", uint32(), true), field("column8", uint64(), true),
+ field("column9", boolean(), true), field("column10", utf8(), true)});
+
+ // Construct expected arrays
+ std::shared_ptr<Array> int8_array = ArrayFromJSON(int8(), "[-1, -10, null]");
+ std::shared_ptr<Array> int16_array = ArrayFromJSON(int16(), "[-2, -20, null]");
+ std::shared_ptr<Array> int32_array = ArrayFromJSON(int32(), "[-3, -30, null]");
+ std::shared_ptr<Array> int64_array = ArrayFromJSON(int64(), "[-4, -40, null]");
+ std::shared_ptr<Array> uint8_array = ArrayFromJSON(uint8(), "[1, 10, null]");
+ std::shared_ptr<Array> uint16_array = ArrayFromJSON(uint16(), "[2, 20, null]");
+ std::shared_ptr<Array> uint32_array = ArrayFromJSON(uint32(), "[3, 30, null]");
+ std::shared_ptr<Array> uint64_array = ArrayFromJSON(uint64(), "[4, 40, null]");
+ std::shared_ptr<Array> bool_array = ArrayFromJSON(boolean(), "[true, false, null]");
+ std::shared_ptr<Array> string_array =
+ ArrayFromJSON(utf8(), R"(["Tests", "Other", null])");
+ auto expected_table =
+ Table::Make(expected_schema,
+ {int8_array, int16_array, int32_array, int64_array, uint8_array,
+ uint16_array, uint32_array, uint64_array, bool_array, string_array});
+
+ ASSERT_TRUE(expected_table->Equals(*table));
+}
+
+TEST(TestTableFromTupleVector, NullableTypesDoNotBreakUserSpecialization) {
+ std::vector<std::string> names{"column1"};
+ std::vector<std::tuple<CustomOptionalTypeMock>> rows(3);
+ std::shared_ptr<Table> table;
+ ASSERT_OK(TableFromTupleRange(default_memory_pool(), rows, names, &table));
+
+ std::shared_ptr<Schema> expected_schema = schema({field("column1", utf8(), true)});
+ std::shared_ptr<Array> string_array =
+ ArrayFromJSON(utf8(), R"([null, "mock yes", "mock no"])");
+ auto expected_table = Table::Make(expected_schema, {string_array});
+
+ ASSERT_TRUE(expected_table->Equals(*table));
+}
+
+TEST(TestTableFromTupleVector, AppendingMultipleRows) {
+ using row_type = std::tuple<std::vector<TestInt32Type>>;
+ std::vector<std::string> names{"column1"};
+ std::vector<row_type> rows = {
+ row_type{{{1}, {2}, {3}}}, //
+ row_type{{{10}, {20}, {30}}} //
+ };
+ std::shared_ptr<Table> table;
+ ASSERT_OK(TableFromTupleRange(default_memory_pool(), rows, names, &table));
+
+ std::shared_ptr<Schema> expected_schema =
+ schema({field("column1", list(int32()), false)});
+ std::shared_ptr<Array> int_array =
+ ArrayFromJSON(list(int32()), "[[1, 2, 3], [10, 20, 30]]");
+ auto expected_table = Table::Make(expected_schema, {int_array});
+
+ ASSERT_TRUE(expected_table->Equals(*table));
+}
+
+TEST(TestTupleVectorFromTable, PrimitiveTypes) {
+ compute::ExecContext ctx;
+ compute::CastOptions cast_options;
+
+ std::vector<primitive_types_tuple> expected_rows{
+ primitive_types_tuple(-1, -2, -3, -4, 1, 2, 3, 4, true, "Tests"),
+ primitive_types_tuple(-10, -20, -30, -40, 10, 20, 30, 40, false, "Other")};
+
+ std::shared_ptr<Schema> schema = std::shared_ptr<Schema>(
+ new Schema({field("column1", int8(), false), field("column2", int16(), false),
+ field("column3", int32(), false), field("column4", int64(), false),
+ field("column5", uint8(), false), field("column6", uint16(), false),
+ field("column7", uint32(), false), field("column8", uint64(), false),
+ field("column9", boolean(), false), field("column10", utf8(), false)}));
+
+ // Construct expected arrays
+ std::shared_ptr<Array> int8_array;
+ ArrayFromVector<Int8Type, int8_t>({-1, -10}, &int8_array);
+ std::shared_ptr<Array> int16_array;
+ ArrayFromVector<Int16Type, int16_t>({-2, -20}, &int16_array);
+ std::shared_ptr<Array> int32_array;
+ ArrayFromVector<Int32Type, int32_t>({-3, -30}, &int32_array);
+ std::shared_ptr<Array> int64_array;
+ ArrayFromVector<Int64Type, int64_t>({-4, -40}, &int64_array);
+ std::shared_ptr<Array> uint8_array;
+ ArrayFromVector<UInt8Type, uint8_t>({1, 10}, &uint8_array);
+ std::shared_ptr<Array> uint16_array;
+ ArrayFromVector<UInt16Type, uint16_t>({2, 20}, &uint16_array);
+ std::shared_ptr<Array> uint32_array;
+ ArrayFromVector<UInt32Type, uint32_t>({3, 30}, &uint32_array);
+ std::shared_ptr<Array> uint64_array;
+ ArrayFromVector<UInt64Type, uint64_t>({4, 40}, &uint64_array);
+ std::shared_ptr<Array> bool_array;
+ ArrayFromVector<BooleanType, bool>({true, false}, &bool_array);
+ std::shared_ptr<Array> string_array;
+ ArrayFromVector<StringType, std::string>({"Tests", "Other"}, &string_array);
+ auto table = Table::Make(
+ schema, {int8_array, int16_array, int32_array, int64_array, uint8_array,
+ uint16_array, uint32_array, uint64_array, bool_array, string_array});
+
+ std::vector<primitive_types_tuple> rows(2);
+ ASSERT_OK(TupleRangeFromTable(*table, cast_options, &ctx, &rows));
+ ASSERT_EQ(rows, expected_rows);
+
+ // The number of rows must match
+ std::vector<primitive_types_tuple> too_few_rows(1);
+ ASSERT_RAISES(Invalid, TupleRangeFromTable(*table, cast_options, &ctx, &too_few_rows));
+
+ // The number of columns must match
+ ASSERT_OK_AND_ASSIGN(auto corrupt_table, table->RemoveColumn(0));
+ ASSERT_RAISES(Invalid, TupleRangeFromTable(*corrupt_table, cast_options, &ctx, &rows));
+}
+
+TEST(TestTupleVectorFromTable, ListType) {
+ using tuple_type = std::tuple<std::vector<int64_t>>;
+
+ compute::ExecContext ctx;
+ compute::CastOptions cast_options;
+ auto expected_schema =
+ std::shared_ptr<Schema>(new Schema({field("column1", list(int64()), false)}));
+ std::shared_ptr<Array> expected_array =
+ ArrayFromJSON(list(int64()), "[[1, 1, 2, 34], [2, -4]]");
+ std::shared_ptr<Table> table = Table::Make(expected_schema, {expected_array});
+
+ std::vector<tuple_type> expected_rows{tuple_type(std::vector<int64_t>{1, 1, 2, 34}),
+ tuple_type(std::vector<int64_t>{2, -4})};
+
+ std::vector<tuple_type> rows(2);
+ ASSERT_OK(TupleRangeFromTable(*table, cast_options, &ctx, &rows));
+ ASSERT_EQ(rows, expected_rows);
+}
+
+TEST(TestTupleVectorFromTable, CastingNeeded) {
+ using tuple_type = std::tuple<std::vector<int64_t>>;
+
+ compute::ExecContext ctx;
+ compute::CastOptions cast_options;
+ auto expected_schema =
+ std::shared_ptr<Schema>(new Schema({field("column1", list(int16()), false)}));
+ std::shared_ptr<Array> expected_array =
+ ArrayFromJSON(list(int16()), "[[1, 1, 2, 34], [2, -4]]");
+ std::shared_ptr<Table> table = Table::Make(expected_schema, {expected_array});
+
+ std::vector<tuple_type> expected_rows{tuple_type(std::vector<int64_t>{1, 1, 2, 34}),
+ tuple_type(std::vector<int64_t>{2, -4})};
+
+ std::vector<tuple_type> rows(2);
+ ASSERT_OK(TupleRangeFromTable(*table, cast_options, &ctx, &rows));
+ ASSERT_EQ(rows, expected_rows);
+}
+
+TEST(STLMemoryPool, Base) {
+ std::allocator<uint8_t> allocator;
+ STLMemoryPool<std::allocator<uint8_t>> pool(allocator);
+
+ uint8_t* data = nullptr;
+ ASSERT_OK(pool.Allocate(100, &data));
+ ASSERT_EQ(pool.max_memory(), 100);
+ ASSERT_EQ(pool.bytes_allocated(), 100);
+ ASSERT_NE(data, nullptr);
+
+ ASSERT_OK(pool.Reallocate(100, 150, &data));
+ ASSERT_EQ(pool.max_memory(), 150);
+ ASSERT_EQ(pool.bytes_allocated(), 150);
+
+ pool.Free(data, 150);
+
+ ASSERT_EQ(pool.max_memory(), 150);
+ ASSERT_EQ(pool.bytes_allocated(), 0);
+}
+
+TEST(allocator, MemoryTracking) {
+ auto pool = default_memory_pool();
+ allocator<uint64_t> alloc;
+ uint64_t* data = alloc.allocate(100);
+
+ ASSERT_EQ(100 * sizeof(uint64_t), pool->bytes_allocated());
+
+ alloc.deallocate(data, 100);
+ ASSERT_EQ(0, pool->bytes_allocated());
+}
+
+#if !(defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER) || defined(ARROW_JEMALLOC))
+
+TEST(allocator, TestOOM) {
+ allocator<uint64_t> alloc;
+ uint64_t to_alloc = std::numeric_limits<uint64_t>::max() / 2;
+ ASSERT_THROW(alloc.allocate(to_alloc), std::bad_alloc);
+}
+
+TEST(stl_allocator, MaxMemory) {
+ auto pool = MemoryPool::CreateDefault();
+
+ allocator<uint8_t> alloc(pool.get());
+ uint8_t* data = alloc.allocate(1000);
+ uint8_t* data2 = alloc.allocate(1000);
+
+ alloc.deallocate(data, 1000);
+ alloc.deallocate(data2, 1000);
+
+ ASSERT_EQ(2000, pool->max_memory());
+}
+
+#endif // !(defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER)
+ // || defined(ARROW_JEMALLOC))
+
+} // namespace stl
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/symbols.map b/src/arrow/cpp/src/arrow/symbols.map
new file mode 100644
index 000000000..7262cc6a8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/symbols.map
@@ -0,0 +1,38 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+{
+ global:
+ extern "C++" {
+ # The leading asterisk is required for symbols such as
+ # "typeinfo for arrow::SomeClass".
+ # Unfortunately this will also catch template specializations
+ # (from e.g. STL or Flatbuffers) involving Arrow types.
+ *arrow::*;
+ *arrow_vendored::*;
+ };
+ # Also export C-level helpers
+ arrow_*;
+ pyarrow_*;
+
+ # Symbols marked as 'local' are not exported by the DSO and thus may not
+ # be used by client applications. Everything except the above falls here.
+ # This ensures we hide symbols of static dependencies.
+ local:
+ *;
+
+};
diff --git a/src/arrow/cpp/src/arrow/table.cc b/src/arrow/cpp/src/arrow/table.cc
new file mode 100644
index 000000000..a8a45e9ed
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/table.cc
@@ -0,0 +1,641 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/table.h"
+
+#include <algorithm>
+#include <cstdlib>
+#include <limits>
+#include <memory>
+#include <sstream>
+#include <utility>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_binary.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/array/util.h"
+#include "arrow/chunked_array.h"
+#include "arrow/pretty_print.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/vector.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+class KeyValueMetadata;
+class MemoryPool;
+struct ArrayData;
+
+// ----------------------------------------------------------------------
+// Table methods
+
+/// \class SimpleTable
+/// \brief A basic, non-lazy in-memory table, like SimpleRecordBatch
+class SimpleTable : public Table {
+ public:
+ SimpleTable(std::shared_ptr<Schema> schema,
+ std::vector<std::shared_ptr<ChunkedArray>> columns, int64_t num_rows = -1)
+ : columns_(std::move(columns)) {
+ schema_ = std::move(schema);
+ if (num_rows < 0) {
+ if (columns_.size() == 0) {
+ num_rows_ = 0;
+ } else {
+ num_rows_ = columns_[0]->length();
+ }
+ } else {
+ num_rows_ = num_rows;
+ }
+ }
+
+ SimpleTable(std::shared_ptr<Schema> schema,
+ const std::vector<std::shared_ptr<Array>>& columns, int64_t num_rows = -1) {
+ schema_ = std::move(schema);
+ if (num_rows < 0) {
+ if (columns.size() == 0) {
+ num_rows_ = 0;
+ } else {
+ num_rows_ = columns[0]->length();
+ }
+ } else {
+ num_rows_ = num_rows;
+ }
+
+ columns_.resize(columns.size());
+ for (size_t i = 0; i < columns.size(); ++i) {
+ columns_[i] = std::make_shared<ChunkedArray>(columns[i]);
+ }
+ }
+
+ std::shared_ptr<ChunkedArray> column(int i) const override { return columns_[i]; }
+
+ const std::vector<std::shared_ptr<ChunkedArray>>& columns() const override {
+ return columns_;
+ }
+
+ std::shared_ptr<Table> Slice(int64_t offset, int64_t length) const override {
+ auto sliced = columns_;
+ int64_t num_rows = length;
+ for (auto& column : sliced) {
+ column = column->Slice(offset, length);
+ num_rows = column->length();
+ }
+ return Table::Make(schema_, std::move(sliced), num_rows);
+ }
+
+ Result<std::shared_ptr<Table>> RemoveColumn(int i) const override {
+ ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->RemoveField(i));
+
+ return Table::Make(std::move(new_schema), internal::DeleteVectorElement(columns_, i),
+ this->num_rows());
+ }
+
+ Result<std::shared_ptr<Table>> AddColumn(
+ int i, std::shared_ptr<Field> field_arg,
+ std::shared_ptr<ChunkedArray> col) const override {
+ DCHECK(col != nullptr);
+
+ if (col->length() != num_rows_) {
+ return Status::Invalid(
+ "Added column's length must match table's length. Expected length ", num_rows_,
+ " but got length ", col->length());
+ }
+
+ if (!field_arg->type()->Equals(col->type())) {
+ return Status::Invalid("Field type did not match data type");
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->AddField(i, field_arg));
+ return Table::Make(std::move(new_schema),
+ internal::AddVectorElement(columns_, i, std::move(col)));
+ }
+
+ Result<std::shared_ptr<Table>> SetColumn(
+ int i, std::shared_ptr<Field> field_arg,
+ std::shared_ptr<ChunkedArray> col) const override {
+ DCHECK(col != nullptr);
+
+ if (col->length() != num_rows_) {
+ return Status::Invalid(
+ "Added column's length must match table's length. Expected length ", num_rows_,
+ " but got length ", col->length());
+ }
+
+ if (!field_arg->type()->Equals(col->type())) {
+ return Status::Invalid("Field type did not match data type");
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->SetField(i, field_arg));
+ return Table::Make(std::move(new_schema),
+ internal::ReplaceVectorElement(columns_, i, std::move(col)));
+ }
+
+ std::shared_ptr<Table> ReplaceSchemaMetadata(
+ const std::shared_ptr<const KeyValueMetadata>& metadata) const override {
+ auto new_schema = schema_->WithMetadata(metadata);
+ return Table::Make(std::move(new_schema), columns_);
+ }
+
+ Result<std::shared_ptr<Table>> Flatten(MemoryPool* pool) const override {
+ std::vector<std::shared_ptr<Field>> flattened_fields;
+ std::vector<std::shared_ptr<ChunkedArray>> flattened_columns;
+ for (int i = 0; i < num_columns(); ++i) {
+ std::vector<std::shared_ptr<Field>> new_fields = field(i)->Flatten();
+ ARROW_ASSIGN_OR_RAISE(auto new_columns, column(i)->Flatten(pool));
+ DCHECK_EQ(new_columns.size(), new_fields.size());
+ for (size_t j = 0; j < new_columns.size(); ++j) {
+ flattened_fields.push_back(new_fields[j]);
+ flattened_columns.push_back(new_columns[j]);
+ }
+ }
+ auto flattened_schema =
+ std::make_shared<Schema>(std::move(flattened_fields), schema_->metadata());
+ return Table::Make(std::move(flattened_schema), std::move(flattened_columns));
+ }
+
+ Status Validate() const override {
+ RETURN_NOT_OK(ValidateMeta());
+ for (int i = 0; i < num_columns(); ++i) {
+ const ChunkedArray* col = columns_[i].get();
+ Status st = col->Validate();
+ if (!st.ok()) {
+ std::stringstream ss;
+ ss << "Column " << i << ": " << st.message();
+ return st.WithMessage(ss.str());
+ }
+ }
+ return Status::OK();
+ }
+
+ Status ValidateFull() const override {
+ RETURN_NOT_OK(ValidateMeta());
+ for (int i = 0; i < num_columns(); ++i) {
+ const ChunkedArray* col = columns_[i].get();
+ Status st = col->ValidateFull();
+ if (!st.ok()) {
+ std::stringstream ss;
+ ss << "Column " << i << ": " << st.message();
+ return st.WithMessage(ss.str());
+ }
+ }
+ return Status::OK();
+ }
+
+ protected:
+ Status ValidateMeta() const {
+ // Make sure columns and schema are consistent
+ if (static_cast<int>(columns_.size()) != schema_->num_fields()) {
+ return Status::Invalid("Number of columns did not match schema");
+ }
+ for (int i = 0; i < num_columns(); ++i) {
+ const ChunkedArray* col = columns_[i].get();
+ if (col == nullptr) {
+ return Status::Invalid("Column ", i, " was null");
+ }
+ if (!col->type()->Equals(*schema_->field(i)->type())) {
+ return Status::Invalid("Column data for field ", i, " with type ",
+ col->type()->ToString(), " is inconsistent with schema ",
+ schema_->field(i)->type()->ToString());
+ }
+ }
+
+ // Make sure columns are all the same length, and validate them
+ for (int i = 0; i < num_columns(); ++i) {
+ const ChunkedArray* col = columns_[i].get();
+ if (col->length() != num_rows_) {
+ return Status::Invalid("Column ", i, " named ", field(i)->name(),
+ " expected length ", num_rows_, " but got length ",
+ col->length());
+ }
+ Status st = col->Validate();
+ if (!st.ok()) {
+ std::stringstream ss;
+ ss << "Column " << i << ": " << st.message();
+ return st.WithMessage(ss.str());
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ std::vector<std::shared_ptr<ChunkedArray>> columns_;
+};
+
+Table::Table() : num_rows_(0) {}
+
+std::vector<std::shared_ptr<Field>> Table::fields() const {
+ std::vector<std::shared_ptr<Field>> result;
+ for (int i = 0; i < this->num_columns(); ++i) {
+ result.emplace_back(this->field(i));
+ }
+ return result;
+}
+
+std::shared_ptr<Table> Table::Make(std::shared_ptr<Schema> schema,
+ std::vector<std::shared_ptr<ChunkedArray>> columns,
+ int64_t num_rows) {
+ return std::make_shared<SimpleTable>(std::move(schema), std::move(columns), num_rows);
+}
+
+std::shared_ptr<Table> Table::Make(std::shared_ptr<Schema> schema,
+ const std::vector<std::shared_ptr<Array>>& arrays,
+ int64_t num_rows) {
+ return std::make_shared<SimpleTable>(std::move(schema), arrays, num_rows);
+}
+
+Result<std::shared_ptr<Table>> Table::FromRecordBatchReader(RecordBatchReader* reader) {
+ std::shared_ptr<Table> table = nullptr;
+ RETURN_NOT_OK(reader->ReadAll(&table));
+ return table;
+}
+
+Result<std::shared_ptr<Table>> Table::FromRecordBatches(
+ std::shared_ptr<Schema> schema,
+ const std::vector<std::shared_ptr<RecordBatch>>& batches) {
+ const int nbatches = static_cast<int>(batches.size());
+ const int ncolumns = static_cast<int>(schema->num_fields());
+
+ int64_t num_rows = 0;
+ for (int i = 0; i < nbatches; ++i) {
+ if (!batches[i]->schema()->Equals(*schema, false)) {
+ return Status::Invalid("Schema at index ", static_cast<int>(i),
+ " was different: \n", schema->ToString(), "\nvs\n",
+ batches[i]->schema()->ToString());
+ }
+ num_rows += batches[i]->num_rows();
+ }
+
+ std::vector<std::shared_ptr<ChunkedArray>> columns(ncolumns);
+ std::vector<std::shared_ptr<Array>> column_arrays(nbatches);
+
+ for (int i = 0; i < ncolumns; ++i) {
+ for (int j = 0; j < nbatches; ++j) {
+ column_arrays[j] = batches[j]->column(i);
+ }
+ columns[i] = std::make_shared<ChunkedArray>(column_arrays, schema->field(i)->type());
+ }
+
+ return Table::Make(std::move(schema), std::move(columns), num_rows);
+}
+
+Result<std::shared_ptr<Table>> Table::FromRecordBatches(
+ const std::vector<std::shared_ptr<RecordBatch>>& batches) {
+ if (batches.size() == 0) {
+ return Status::Invalid("Must pass at least one record batch or an explicit Schema");
+ }
+
+ return FromRecordBatches(batches[0]->schema(), batches);
+}
+
+Result<std::shared_ptr<Table>> Table::FromChunkedStructArray(
+ const std::shared_ptr<ChunkedArray>& array) {
+ auto type = array->type();
+ if (type->id() != Type::STRUCT) {
+ return Status::Invalid("Expected a chunked struct array, got ", *type);
+ }
+ int num_columns = type->num_fields();
+ int num_chunks = array->num_chunks();
+
+ const auto& struct_chunks = array->chunks();
+ std::vector<std::shared_ptr<ChunkedArray>> columns(num_columns);
+ for (int i = 0; i < num_columns; ++i) {
+ ArrayVector chunks(num_chunks);
+ std::transform(struct_chunks.begin(), struct_chunks.end(), chunks.begin(),
+ [i](const std::shared_ptr<Array>& struct_chunk) {
+ return static_cast<const StructArray&>(*struct_chunk).field(i);
+ });
+ columns[i] =
+ std::make_shared<ChunkedArray>(std::move(chunks), type->field(i)->type());
+ }
+
+ return Table::Make(::arrow::schema(type->fields()), std::move(columns),
+ array->length());
+}
+
+std::vector<std::string> Table::ColumnNames() const {
+ std::vector<std::string> names(num_columns());
+ for (int i = 0; i < num_columns(); ++i) {
+ names[i] = field(i)->name();
+ }
+ return names;
+}
+
+Result<std::shared_ptr<Table>> Table::RenameColumns(
+ const std::vector<std::string>& names) const {
+ if (names.size() != static_cast<size_t>(num_columns())) {
+ return Status::Invalid("tried to rename a table of ", num_columns(),
+ " columns but only ", names.size(), " names were provided");
+ }
+ std::vector<std::shared_ptr<ChunkedArray>> columns(num_columns());
+ std::vector<std::shared_ptr<Field>> fields(num_columns());
+ for (int i = 0; i < num_columns(); ++i) {
+ columns[i] = column(i);
+ fields[i] = field(i)->WithName(names[i]);
+ }
+ return Table::Make(::arrow::schema(std::move(fields)), std::move(columns), num_rows());
+}
+
+Result<std::shared_ptr<Table>> Table::SelectColumns(
+ const std::vector<int>& indices) const {
+ int n = static_cast<int>(indices.size());
+
+ std::vector<std::shared_ptr<ChunkedArray>> columns(n);
+ std::vector<std::shared_ptr<Field>> fields(n);
+ for (int i = 0; i < n; i++) {
+ int pos = indices[i];
+ if (pos < 0 || pos > num_columns() - 1) {
+ return Status::Invalid("Invalid column index ", pos, " to select columns.");
+ }
+ columns[i] = column(pos);
+ fields[i] = field(pos);
+ }
+
+ auto new_schema =
+ std::make_shared<arrow::Schema>(std::move(fields), schema()->metadata());
+ return Table::Make(std::move(new_schema), std::move(columns), num_rows());
+}
+
+std::string Table::ToString() const {
+ std::stringstream ss;
+ ARROW_CHECK_OK(PrettyPrint(*this, 0, &ss));
+ return ss.str();
+}
+
+Result<std::shared_ptr<Table>> ConcatenateTables(
+ const std::vector<std::shared_ptr<Table>>& tables,
+ const ConcatenateTablesOptions options, MemoryPool* memory_pool) {
+ if (tables.size() == 0) {
+ return Status::Invalid("Must pass at least one table");
+ }
+
+ std::vector<std::shared_ptr<Table>> promoted_tables;
+ const std::vector<std::shared_ptr<Table>>* tables_to_concat = &tables;
+ if (options.unify_schemas) {
+ std::vector<std::shared_ptr<Schema>> schemas;
+ schemas.reserve(tables.size());
+ for (const auto& t : tables) {
+ schemas.push_back(t->schema());
+ }
+
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Schema> unified_schema,
+ UnifySchemas(schemas, options.field_merge_options));
+
+ promoted_tables.reserve(tables.size());
+ for (const auto& t : tables) {
+ promoted_tables.emplace_back();
+ ARROW_ASSIGN_OR_RAISE(promoted_tables.back(),
+ PromoteTableToSchema(t, unified_schema, memory_pool));
+ }
+ tables_to_concat = &promoted_tables;
+ } else {
+ auto first_schema = tables[0]->schema();
+ for (size_t i = 1; i < tables.size(); ++i) {
+ if (!tables[i]->schema()->Equals(*first_schema, false)) {
+ return Status::Invalid("Schema at index ", i, " was different: \n",
+ first_schema->ToString(), "\nvs\n",
+ tables[i]->schema()->ToString());
+ }
+ }
+ }
+
+ std::shared_ptr<Schema> schema = tables_to_concat->front()->schema();
+
+ const int ncolumns = schema->num_fields();
+
+ std::vector<std::shared_ptr<ChunkedArray>> columns(ncolumns);
+ for (int i = 0; i < ncolumns; ++i) {
+ std::vector<std::shared_ptr<Array>> column_arrays;
+ for (const auto& table : *tables_to_concat) {
+ const std::vector<std::shared_ptr<Array>>& chunks = table->column(i)->chunks();
+ for (const auto& chunk : chunks) {
+ column_arrays.push_back(chunk);
+ }
+ }
+ columns[i] = std::make_shared<ChunkedArray>(column_arrays, schema->field(i)->type());
+ }
+ return Table::Make(std::move(schema), std::move(columns));
+}
+
+Result<std::shared_ptr<Table>> PromoteTableToSchema(const std::shared_ptr<Table>& table,
+ const std::shared_ptr<Schema>& schema,
+ MemoryPool* pool) {
+ const std::shared_ptr<Schema> current_schema = table->schema();
+ if (current_schema->Equals(*schema, /*check_metadata=*/false)) {
+ return table->ReplaceSchemaMetadata(schema->metadata());
+ }
+
+ // fields_seen[i] == true iff that field is also in `schema`.
+ std::vector<bool> fields_seen(current_schema->num_fields(), false);
+
+ std::vector<std::shared_ptr<ChunkedArray>> columns;
+ columns.reserve(schema->num_fields());
+ const int64_t num_rows = table->num_rows();
+ auto AppendColumnOfNulls = [pool, &columns,
+ num_rows](const std::shared_ptr<DataType>& type) {
+ // TODO(bkietz): share the zero-filled buffers as much as possible across
+ // the null-filled arrays created here.
+ ARROW_ASSIGN_OR_RAISE(auto array_of_nulls, MakeArrayOfNull(type, num_rows, pool));
+ columns.push_back(std::make_shared<ChunkedArray>(array_of_nulls));
+ return Status::OK();
+ };
+
+ for (const auto& field : schema->fields()) {
+ const std::vector<int> field_indices =
+ current_schema->GetAllFieldIndices(field->name());
+ if (field_indices.empty()) {
+ RETURN_NOT_OK(AppendColumnOfNulls(field->type()));
+ continue;
+ }
+
+ if (field_indices.size() > 1) {
+ return Status::Invalid(
+ "PromoteTableToSchema cannot handle schemas with duplicate fields: ",
+ field->name());
+ }
+
+ const int field_index = field_indices[0];
+ const auto& current_field = current_schema->field(field_index);
+ if (!field->nullable() && current_field->nullable()) {
+ return Status::Invalid("Unable to promote field ", current_field->name(),
+ ": it was nullable but the target schema was not.");
+ }
+
+ fields_seen[field_index] = true;
+ if (current_field->type()->Equals(field->type())) {
+ columns.push_back(table->column(field_index));
+ continue;
+ }
+
+ if (current_field->type()->id() == Type::NA) {
+ RETURN_NOT_OK(AppendColumnOfNulls(field->type()));
+ continue;
+ }
+
+ return Status::Invalid("Unable to promote field ", field->name(),
+ ": incompatible types: ", field->type()->ToString(), " vs ",
+ current_field->type()->ToString());
+ }
+
+ auto unseen_field_iter = std::find(fields_seen.begin(), fields_seen.end(), false);
+ if (unseen_field_iter != fields_seen.end()) {
+ const size_t unseen_field_index = unseen_field_iter - fields_seen.begin();
+ return Status::Invalid(
+ "Incompatible schemas: field ",
+ current_schema->field(static_cast<int>(unseen_field_index))->name(),
+ " did not exist in the new schema.");
+ }
+
+ return Table::Make(schema, std::move(columns));
+}
+
+bool Table::Equals(const Table& other, bool check_metadata) const {
+ if (this == &other) {
+ return true;
+ }
+ if (!schema_->Equals(*other.schema(), check_metadata)) {
+ return false;
+ }
+ if (this->num_columns() != other.num_columns()) {
+ return false;
+ }
+
+ for (int i = 0; i < this->num_columns(); i++) {
+ if (!this->column(i)->Equals(other.column(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+Result<std::shared_ptr<Table>> Table::CombineChunks(MemoryPool* pool) const {
+ const int ncolumns = num_columns();
+ std::vector<std::shared_ptr<ChunkedArray>> compacted_columns(ncolumns);
+ for (int i = 0; i < ncolumns; ++i) {
+ const auto& col = column(i);
+ if (col->num_chunks() <= 1) {
+ compacted_columns[i] = col;
+ continue;
+ }
+
+ if (is_binary_like(col->type()->id())) {
+ // ARROW-5744 Allow binary columns to be combined into multiple chunks to avoid
+ // buffer overflow
+ ArrayVector chunks;
+ int chunk_i = 0;
+ while (chunk_i < col->num_chunks()) {
+ ArrayVector safe_chunks;
+ int64_t data_length = 0;
+ for (; chunk_i < col->num_chunks(); ++chunk_i) {
+ const auto& chunk = col->chunk(chunk_i);
+ data_length += checked_cast<const BinaryArray&>(*chunk).total_values_length();
+ if (data_length >= kBinaryMemoryLimit) {
+ break;
+ }
+ safe_chunks.push_back(chunk);
+ }
+ chunks.emplace_back();
+ ARROW_ASSIGN_OR_RAISE(chunks.back(), Concatenate(safe_chunks, pool));
+ }
+ compacted_columns[i] = std::make_shared<ChunkedArray>(std::move(chunks));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto compacted, Concatenate(col->chunks(), pool));
+ compacted_columns[i] = std::make_shared<ChunkedArray>(compacted);
+ }
+ }
+ return Table::Make(schema(), std::move(compacted_columns), num_rows_);
+}
+
+// ----------------------------------------------------------------------
+// Convert a table to a sequence of record batches
+
+TableBatchReader::TableBatchReader(const Table& table)
+ : table_(table),
+ column_data_(table.num_columns()),
+ chunk_numbers_(table.num_columns(), 0),
+ chunk_offsets_(table.num_columns(), 0),
+ absolute_row_position_(0),
+ max_chunksize_(std::numeric_limits<int64_t>::max()) {
+ for (int i = 0; i < table.num_columns(); ++i) {
+ column_data_[i] = table.column(i).get();
+ }
+}
+
+std::shared_ptr<Schema> TableBatchReader::schema() const { return table_.schema(); }
+
+void TableBatchReader::set_chunksize(int64_t chunksize) { max_chunksize_ = chunksize; }
+
+Status TableBatchReader::ReadNext(std::shared_ptr<RecordBatch>* out) {
+ if (absolute_row_position_ == table_.num_rows()) {
+ *out = nullptr;
+ return Status::OK();
+ }
+
+ // Determine the minimum contiguous slice across all columns
+ int64_t chunksize = std::min(table_.num_rows(), max_chunksize_);
+ std::vector<const Array*> chunks(table_.num_columns());
+ for (int i = 0; i < table_.num_columns(); ++i) {
+ auto chunk = column_data_[i]->chunk(chunk_numbers_[i]).get();
+ int64_t chunk_remaining = chunk->length() - chunk_offsets_[i];
+
+ if (chunk_remaining < chunksize) {
+ chunksize = chunk_remaining;
+ }
+
+ chunks[i] = chunk;
+ }
+
+ // Slice chunks and advance chunk index as appropriate
+ std::vector<std::shared_ptr<ArrayData>> batch_data(table_.num_columns());
+
+ for (int i = 0; i < table_.num_columns(); ++i) {
+ // Exhausted chunk
+ const Array* chunk = chunks[i];
+ const int64_t offset = chunk_offsets_[i];
+ std::shared_ptr<ArrayData> slice_data;
+ if ((chunk->length() - offset) == chunksize) {
+ ++chunk_numbers_[i];
+ chunk_offsets_[i] = 0;
+ if (offset > 0) {
+ // Need to slice
+ slice_data = chunk->Slice(offset, chunksize)->data();
+ } else {
+ // No slice
+ slice_data = chunk->data();
+ }
+ } else {
+ chunk_offsets_[i] += chunksize;
+ slice_data = chunk->Slice(offset, chunksize)->data();
+ }
+ batch_data[i] = std::move(slice_data);
+ }
+
+ absolute_row_position_ += chunksize;
+ *out = RecordBatch::Make(table_.schema(), chunksize, std::move(batch_data));
+
+ return Status::OK();
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/table.h b/src/arrow/cpp/src/arrow/table.h
new file mode 100644
index 000000000..b313e9262
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/table.h
@@ -0,0 +1,295 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/chunked_array.h" // IWYU pragma: keep
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Array;
+class ChunkedArray;
+class KeyValueMetadata;
+class MemoryPool;
+
+/// \class Table
+/// \brief Logical table as sequence of chunked arrays
+class ARROW_EXPORT Table {
+ public:
+ virtual ~Table() = default;
+
+ /// \brief Construct a Table from schema and columns
+ ///
+ /// If columns is zero-length, the table's number of rows is zero
+ ///
+ /// \param[in] schema The table schema (column types)
+ /// \param[in] columns The table's columns as chunked arrays
+ /// \param[in] num_rows number of rows in table, -1 (default) to infer from columns
+ static std::shared_ptr<Table> Make(std::shared_ptr<Schema> schema,
+ std::vector<std::shared_ptr<ChunkedArray>> columns,
+ int64_t num_rows = -1);
+
+ /// \brief Construct a Table from schema and arrays
+ ///
+ /// \param[in] schema The table schema (column types)
+ /// \param[in] arrays The table's columns as arrays
+ /// \param[in] num_rows number of rows in table, -1 (default) to infer from columns
+ static std::shared_ptr<Table> Make(std::shared_ptr<Schema> schema,
+ const std::vector<std::shared_ptr<Array>>& arrays,
+ int64_t num_rows = -1);
+
+ /// \brief Construct a Table from a RecordBatchReader.
+ ///
+ /// \param[in] reader the arrow::Schema for each batch
+ static Result<std::shared_ptr<Table>> FromRecordBatchReader(RecordBatchReader* reader);
+
+ /// \brief Construct a Table from RecordBatches, using schema supplied by the first
+ /// RecordBatch.
+ ///
+ /// \param[in] batches a std::vector of record batches
+ static Result<std::shared_ptr<Table>> FromRecordBatches(
+ const std::vector<std::shared_ptr<RecordBatch>>& batches);
+
+ /// \brief Construct a Table from RecordBatches, using supplied schema. There may be
+ /// zero record batches
+ ///
+ /// \param[in] schema the arrow::Schema for each batch
+ /// \param[in] batches a std::vector of record batches
+ static Result<std::shared_ptr<Table>> FromRecordBatches(
+ std::shared_ptr<Schema> schema,
+ const std::vector<std::shared_ptr<RecordBatch>>& batches);
+
+ /// \brief Construct a Table from a chunked StructArray. One column will be produced
+ /// for each field of the StructArray.
+ ///
+ /// \param[in] array a chunked StructArray
+ static Result<std::shared_ptr<Table>> FromChunkedStructArray(
+ const std::shared_ptr<ChunkedArray>& array);
+
+ /// \brief Return the table schema
+ const std::shared_ptr<Schema>& schema() const { return schema_; }
+
+ /// \brief Return a column by index
+ virtual std::shared_ptr<ChunkedArray> column(int i) const = 0;
+
+ /// \brief Return vector of all columns for table
+ virtual const std::vector<std::shared_ptr<ChunkedArray>>& columns() const = 0;
+
+ /// Return a column's field by index
+ std::shared_ptr<Field> field(int i) const { return schema_->field(i); }
+
+ /// \brief Return vector of all fields for table
+ std::vector<std::shared_ptr<Field>> fields() const;
+
+ /// \brief Construct a zero-copy slice of the table with the
+ /// indicated offset and length
+ ///
+ /// \param[in] offset the index of the first row in the constructed
+ /// slice
+ /// \param[in] length the number of rows of the slice. If there are not enough
+ /// rows in the table, the length will be adjusted accordingly
+ ///
+ /// \return a new object wrapped in std::shared_ptr<Table>
+ virtual std::shared_ptr<Table> Slice(int64_t offset, int64_t length) const = 0;
+
+ /// \brief Slice from first row at offset until end of the table
+ std::shared_ptr<Table> Slice(int64_t offset) const { return Slice(offset, num_rows_); }
+
+ /// \brief Return a column by name
+ /// \param[in] name field name
+ /// \return an Array or null if no field was found
+ std::shared_ptr<ChunkedArray> GetColumnByName(const std::string& name) const {
+ auto i = schema_->GetFieldIndex(name);
+ return i == -1 ? NULLPTR : column(i);
+ }
+
+ /// \brief Remove column from the table, producing a new Table
+ virtual Result<std::shared_ptr<Table>> RemoveColumn(int i) const = 0;
+
+ /// \brief Add column to the table, producing a new Table
+ virtual Result<std::shared_ptr<Table>> AddColumn(
+ int i, std::shared_ptr<Field> field_arg,
+ std::shared_ptr<ChunkedArray> column) const = 0;
+
+ /// \brief Replace a column in the table, producing a new Table
+ virtual Result<std::shared_ptr<Table>> SetColumn(
+ int i, std::shared_ptr<Field> field_arg,
+ std::shared_ptr<ChunkedArray> column) const = 0;
+
+ /// \brief Return names of all columns
+ std::vector<std::string> ColumnNames() const;
+
+ /// \brief Rename columns with provided names
+ Result<std::shared_ptr<Table>> RenameColumns(
+ const std::vector<std::string>& names) const;
+
+ /// \brief Return new table with specified columns
+ Result<std::shared_ptr<Table>> SelectColumns(const std::vector<int>& indices) const;
+
+ /// \brief Replace schema key-value metadata with new metadata
+ /// \since 0.5.0
+ ///
+ /// \param[in] metadata new KeyValueMetadata
+ /// \return new Table
+ virtual std::shared_ptr<Table> ReplaceSchemaMetadata(
+ const std::shared_ptr<const KeyValueMetadata>& metadata) const = 0;
+
+ /// \brief Flatten the table, producing a new Table. Any column with a
+ /// struct type will be flattened into multiple columns
+ ///
+ /// \param[in] pool The pool for buffer allocations, if any
+ virtual Result<std::shared_ptr<Table>> Flatten(
+ MemoryPool* pool = default_memory_pool()) const = 0;
+
+ /// \return PrettyPrint representation suitable for debugging
+ std::string ToString() const;
+
+ /// \brief Perform cheap validation checks to determine obvious inconsistencies
+ /// within the table's schema and internal data.
+ ///
+ /// This is O(k*m) where k is the total number of field descendents,
+ /// and m is the number of chunks.
+ ///
+ /// \return Status
+ virtual Status Validate() const = 0;
+
+ /// \brief Perform extensive validation checks to determine inconsistencies
+ /// within the table's schema and internal data.
+ ///
+ /// This is O(k*n) where k is the total number of field descendents,
+ /// and n is the number of rows.
+ ///
+ /// \return Status
+ virtual Status ValidateFull() const = 0;
+
+ /// \brief Return the number of columns in the table
+ int num_columns() const { return schema_->num_fields(); }
+
+ /// \brief Return the number of rows (equal to each column's logical length)
+ int64_t num_rows() const { return num_rows_; }
+
+ /// \brief Determine if tables are equal
+ ///
+ /// Two tables can be equal only if they have equal schemas.
+ /// However, they may be equal even if they have different chunkings.
+ bool Equals(const Table& other, bool check_metadata = false) const;
+
+ /// \brief Make a new table by combining the chunks this table has.
+ ///
+ /// All the underlying chunks in the ChunkedArray of each column are
+ /// concatenated into zero or one chunk.
+ ///
+ /// \param[in] pool The pool for buffer allocations
+ Result<std::shared_ptr<Table>> CombineChunks(
+ MemoryPool* pool = default_memory_pool()) const;
+
+ protected:
+ Table();
+
+ std::shared_ptr<Schema> schema_;
+ int64_t num_rows_;
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Table);
+};
+
+/// \brief Compute a stream of record batches from a (possibly chunked) Table
+///
+/// The conversion is zero-copy: each record batch is a view over a slice
+/// of the table's columns.
+class ARROW_EXPORT TableBatchReader : public RecordBatchReader {
+ public:
+ /// \brief Construct a TableBatchReader for the given table
+ explicit TableBatchReader(const Table& table);
+
+ std::shared_ptr<Schema> schema() const override;
+
+ Status ReadNext(std::shared_ptr<RecordBatch>* out) override;
+
+ /// \brief Set the desired maximum chunk size of record batches
+ ///
+ /// The actual chunk size of each record batch may be smaller, depending
+ /// on actual chunking characteristics of each table column.
+ void set_chunksize(int64_t chunksize);
+
+ private:
+ const Table& table_;
+ std::vector<ChunkedArray*> column_data_;
+ std::vector<int> chunk_numbers_;
+ std::vector<int64_t> chunk_offsets_;
+ int64_t absolute_row_position_;
+ int64_t max_chunksize_;
+};
+
+/// \defgroup concat-tables ConcatenateTables function.
+///
+/// ConcatenateTables function.
+/// @{
+
+/// \brief Controls the behavior of ConcatenateTables().
+struct ARROW_EXPORT ConcatenateTablesOptions {
+ /// If true, the schemas of the tables will be first unified with fields of
+ /// the same name being merged, according to `field_merge_options`, then each
+ /// table will be promoted to the unified schema before being concatenated.
+ /// Otherwise, all tables should have the same schema. Each column in the output table
+ /// is the result of concatenating the corresponding columns in all input tables.
+ bool unify_schemas = false;
+
+ Field::MergeOptions field_merge_options = Field::MergeOptions::Defaults();
+
+ static ConcatenateTablesOptions Defaults() { return {}; }
+};
+
+/// \brief Construct table from multiple input tables.
+ARROW_EXPORT
+Result<std::shared_ptr<Table>> ConcatenateTables(
+ const std::vector<std::shared_ptr<Table>>& tables,
+ ConcatenateTablesOptions options = ConcatenateTablesOptions::Defaults(),
+ MemoryPool* memory_pool = default_memory_pool());
+
+/// \brief Promotes a table to conform to the given schema.
+///
+/// If a field in the schema does not have a corresponding column in the
+/// table, a column of nulls will be added to the resulting table.
+/// If the corresponding column is of type Null, it will be promoted to
+/// the type specified by schema, with null values filled.
+/// Returns an error:
+/// - if the corresponding column's type is not compatible with the
+/// schema.
+/// - if there is a column in the table that does not exist in the schema.
+///
+/// \param[in] table the input Table
+/// \param[in] schema the target schema to promote to
+/// \param[in] pool The memory pool to be used if null-filled arrays need to
+/// be created.
+ARROW_EXPORT
+Result<std::shared_ptr<Table>> PromoteTableToSchema(
+ const std::shared_ptr<Table>& table, const std::shared_ptr<Schema>& schema,
+ MemoryPool* pool = default_memory_pool());
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/table_builder.cc b/src/arrow/cpp/src/arrow/table_builder.cc
new file mode 100644
index 000000000..c026c3557
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/table_builder.cc
@@ -0,0 +1,113 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/table_builder.h"
+
+#include <memory>
+#include <utility>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/builder_base.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+// ----------------------------------------------------------------------
+// RecordBatchBuilder
+
+RecordBatchBuilder::RecordBatchBuilder(const std::shared_ptr<Schema>& schema,
+ MemoryPool* pool, int64_t initial_capacity)
+ : schema_(schema), initial_capacity_(initial_capacity), pool_(pool) {}
+
+Status RecordBatchBuilder::Make(const std::shared_ptr<Schema>& schema, MemoryPool* pool,
+ std::unique_ptr<RecordBatchBuilder>* builder) {
+ return Make(schema, pool, kMinBuilderCapacity, builder);
+}
+
+Status RecordBatchBuilder::Make(const std::shared_ptr<Schema>& schema, MemoryPool* pool,
+ int64_t initial_capacity,
+ std::unique_ptr<RecordBatchBuilder>* builder) {
+ builder->reset(new RecordBatchBuilder(schema, pool, initial_capacity));
+ RETURN_NOT_OK((*builder)->CreateBuilders());
+ return (*builder)->InitBuilders();
+}
+
+Status RecordBatchBuilder::Flush(bool reset_builders,
+ std::shared_ptr<RecordBatch>* batch) {
+ std::vector<std::shared_ptr<Array>> fields;
+ fields.resize(this->num_fields());
+
+ int64_t length = 0;
+ for (int i = 0; i < this->num_fields(); ++i) {
+ RETURN_NOT_OK(raw_field_builders_[i]->Finish(&fields[i]));
+ if (i > 0 && fields[i]->length() != length) {
+ return Status::Invalid("All fields must be same length when calling Flush");
+ }
+ length = fields[i]->length();
+ }
+
+ // For certain types like dictionaries, types may not be fully
+ // determined before we have flushed. Make sure that the RecordBatch
+ // gets the correct types in schema.
+ // See: #ARROW-9969
+ std::vector<std::shared_ptr<Field>> schema_fields(schema_->fields());
+ for (int i = 0; i < this->num_fields(); ++i) {
+ if (!schema_fields[i]->type()->Equals(fields[i]->type())) {
+ schema_fields[i] = schema_fields[i]->WithType(fields[i]->type());
+ }
+ }
+ std::shared_ptr<Schema> schema =
+ std::make_shared<Schema>(std::move(schema_fields), schema_->metadata());
+
+ *batch = RecordBatch::Make(std::move(schema), length, std::move(fields));
+ if (reset_builders) {
+ return InitBuilders();
+ } else {
+ return Status::OK();
+ }
+}
+
+Status RecordBatchBuilder::Flush(std::shared_ptr<RecordBatch>* batch) {
+ return Flush(true, batch);
+}
+
+void RecordBatchBuilder::SetInitialCapacity(int64_t capacity) {
+ ARROW_CHECK_GT(capacity, 0) << "Initial capacity must be positive";
+ initial_capacity_ = capacity;
+}
+
+Status RecordBatchBuilder::CreateBuilders() {
+ field_builders_.resize(this->num_fields());
+ raw_field_builders_.resize(this->num_fields());
+ for (int i = 0; i < this->num_fields(); ++i) {
+ RETURN_NOT_OK(MakeBuilder(pool_, schema_->field(i)->type(), &field_builders_[i]));
+ raw_field_builders_[i] = field_builders_[i].get();
+ }
+ return Status::OK();
+}
+
+Status RecordBatchBuilder::InitBuilders() {
+ for (int i = 0; i < this->num_fields(); ++i) {
+ RETURN_NOT_OK(raw_field_builders_[i]->Reserve(initial_capacity_));
+ }
+ return Status::OK();
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/table_builder.h b/src/arrow/cpp/src/arrow/table_builder.h
new file mode 100644
index 000000000..db130d389
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/table_builder.h
@@ -0,0 +1,110 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "arrow/array/builder_base.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class MemoryPool;
+class RecordBatch;
+
+/// \class RecordBatchBuilder
+/// \brief Helper class for creating record batches iteratively given a known
+/// schema
+class ARROW_EXPORT RecordBatchBuilder {
+ public:
+ /// \brief Create an initialize a RecordBatchBuilder
+ /// \param[in] schema The schema for the record batch
+ /// \param[in] pool A MemoryPool to use for allocations
+ /// \param[in] builder the created builder instance
+ static Status Make(const std::shared_ptr<Schema>& schema, MemoryPool* pool,
+ std::unique_ptr<RecordBatchBuilder>* builder);
+
+ /// \brief Create an initialize a RecordBatchBuilder
+ /// \param[in] schema The schema for the record batch
+ /// \param[in] pool A MemoryPool to use for allocations
+ /// \param[in] initial_capacity The initial capacity for the builders
+ /// \param[in] builder the created builder instance
+ static Status Make(const std::shared_ptr<Schema>& schema, MemoryPool* pool,
+ int64_t initial_capacity,
+ std::unique_ptr<RecordBatchBuilder>* builder);
+
+ /// \brief Get base pointer to field builder
+ /// \param i the field index
+ /// \return pointer to ArrayBuilder
+ ArrayBuilder* GetField(int i) { return raw_field_builders_[i]; }
+
+ /// \brief Return field builder casted to indicated specific builder type
+ /// \param i the field index
+ /// \return pointer to template type
+ template <typename T>
+ T* GetFieldAs(int i) {
+ return internal::checked_cast<T*>(raw_field_builders_[i]);
+ }
+
+ /// \brief Finish current batch and optionally reset
+ /// \param[in] reset_builders the resulting RecordBatch
+ /// \param[out] batch the resulting RecordBatch
+ /// \return Status
+ Status Flush(bool reset_builders, std::shared_ptr<RecordBatch>* batch);
+
+ /// \brief Finish current batch and reset
+ /// \param[out] batch the resulting RecordBatch
+ /// \return Status
+ Status Flush(std::shared_ptr<RecordBatch>* batch);
+
+ /// \brief Set the initial capacity for new builders
+ void SetInitialCapacity(int64_t capacity);
+
+ /// \brief The initial capacity for builders
+ int64_t initial_capacity() const { return initial_capacity_; }
+
+ /// \brief The number of fields in the schema
+ int num_fields() const { return schema_->num_fields(); }
+
+ /// \brief The number of fields in the schema
+ std::shared_ptr<Schema> schema() const { return schema_; }
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(RecordBatchBuilder);
+
+ RecordBatchBuilder(const std::shared_ptr<Schema>& schema, MemoryPool* pool,
+ int64_t initial_capacity);
+
+ Status CreateBuilders();
+ Status InitBuilders();
+
+ std::shared_ptr<Schema> schema_;
+ int64_t initial_capacity_;
+ MemoryPool* pool_;
+
+ std::vector<std::unique_ptr<ArrayBuilder>> field_builders_;
+ std::vector<ArrayBuilder*> raw_field_builders_;
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/table_builder_test.cc b/src/arrow/cpp/src/arrow/table_builder_test.cc
new file mode 100644
index 000000000..c73091312
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/table_builder_test.cc
@@ -0,0 +1,182 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_dict.h"
+#include "arrow/array/builder_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/table_builder.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+class TestRecordBatchBuilder : public TestBase {
+ public:
+};
+
+std::shared_ptr<Schema> ExampleSchema1() {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", utf8());
+ auto f2 = field("f1", list(int8()));
+ return ::arrow::schema({f0, f1, f2});
+}
+
+template <typename BuilderType, typename T>
+void AppendValues(BuilderType* builder, const std::vector<T>& values,
+ const std::vector<bool>& is_valid) {
+ for (size_t i = 0; i < values.size(); ++i) {
+ if (is_valid.size() == 0 || is_valid[i]) {
+ ASSERT_OK(builder->Append(values[i]));
+ } else {
+ ASSERT_OK(builder->AppendNull());
+ }
+ }
+}
+
+template <typename ValueType, typename T>
+void AppendList(ListBuilder* builder, const std::vector<std::vector<T>>& values,
+ const std::vector<bool>& is_valid) {
+ auto values_builder = checked_cast<ValueType*>(builder->value_builder());
+
+ for (size_t i = 0; i < values.size(); ++i) {
+ if (is_valid.size() == 0 || is_valid[i]) {
+ ASSERT_OK(builder->Append());
+ AppendValues<ValueType, T>(values_builder, values[i], {});
+ } else {
+ ASSERT_OK(builder->AppendNull());
+ }
+ }
+}
+
+TEST_F(TestRecordBatchBuilder, Basics) {
+ auto schema = ExampleSchema1();
+
+ std::unique_ptr<RecordBatchBuilder> builder;
+ ASSERT_OK(RecordBatchBuilder::Make(schema, pool_, &builder));
+
+ std::vector<bool> is_valid = {false, true, true, true};
+ std::vector<int32_t> f0_values = {0, 1, 2, 3};
+ std::vector<std::string> f1_values = {"a", "bb", "ccc", "dddd"};
+ std::vector<std::vector<int8_t>> f2_values = {{}, {0, 1}, {}, {2}};
+
+ std::shared_ptr<Array> a0, a1, a2;
+
+ // Make the expected record batch
+ auto AppendData = [&](Int32Builder* b0, StringBuilder* b1, ListBuilder* b2) {
+ AppendValues<Int32Builder, int32_t>(b0, f0_values, is_valid);
+ AppendValues<StringBuilder, std::string>(b1, f1_values, is_valid);
+ AppendList<Int8Builder, int8_t>(b2, f2_values, is_valid);
+ };
+
+ Int32Builder ex_b0;
+ StringBuilder ex_b1;
+ ListBuilder ex_b2(pool_, std::unique_ptr<Int8Builder>(new Int8Builder(pool_)));
+
+ AppendData(&ex_b0, &ex_b1, &ex_b2);
+ ASSERT_OK(ex_b0.Finish(&a0));
+ ASSERT_OK(ex_b1.Finish(&a1));
+ ASSERT_OK(ex_b2.Finish(&a2));
+
+ auto expected = RecordBatch::Make(schema, 4, {a0, a1, a2});
+
+ // Builder attributes
+ ASSERT_EQ(3, builder->num_fields());
+ ASSERT_EQ(schema.get(), builder->schema().get());
+
+ const int kIter = 3;
+ for (int i = 0; i < kIter; ++i) {
+ AppendData(builder->GetFieldAs<Int32Builder>(0),
+ checked_cast<StringBuilder*>(builder->GetField(1)),
+ builder->GetFieldAs<ListBuilder>(2));
+
+ std::shared_ptr<RecordBatch> batch;
+
+ if (i == kIter - 1) {
+ // Do not flush in last iteration
+ ASSERT_OK(builder->Flush(false, &batch));
+ } else {
+ ASSERT_OK(builder->Flush(&batch));
+ }
+
+ ASSERT_BATCHES_EQUAL(*expected, *batch);
+ }
+
+ // Test setting initial capacity
+ builder->SetInitialCapacity(4096);
+ ASSERT_EQ(4096, builder->initial_capacity());
+}
+
+TEST_F(TestRecordBatchBuilder, InvalidFieldLength) {
+ auto schema = ExampleSchema1();
+
+ std::unique_ptr<RecordBatchBuilder> builder;
+ ASSERT_OK(RecordBatchBuilder::Make(schema, pool_, &builder));
+
+ std::vector<bool> is_valid = {false, true, true, true};
+ std::vector<int32_t> f0_values = {0, 1, 2, 3};
+
+ AppendValues<Int32Builder, int32_t>(builder->GetFieldAs<Int32Builder>(0), f0_values,
+ is_valid);
+
+ std::shared_ptr<RecordBatch> dummy;
+ ASSERT_RAISES(Invalid, builder->Flush(&dummy));
+}
+
+// In #ARROW-9969 dictionary types were not updated
+// in schema when the index width grew.
+TEST_F(TestRecordBatchBuilder, DictionaryTypes) {
+ const int num_rows = static_cast<int>(UINT8_MAX) + 2;
+ std::vector<std::string> f0_values;
+ std::vector<bool> is_valid(num_rows, true);
+ for (int i = 0; i < num_rows; i++) {
+ f0_values.push_back(std::to_string(i));
+ }
+
+ auto f0 = field("f0", dictionary(int8(), utf8()));
+
+ auto schema = ::arrow::schema({f0});
+
+ std::unique_ptr<RecordBatchBuilder> builder;
+ ASSERT_OK(RecordBatchBuilder::Make(schema, pool_, &builder));
+
+ auto b0 = builder->GetFieldAs<StringDictionaryBuilder>(0);
+
+ AppendValues<StringDictionaryBuilder, std::string>(b0, f0_values, is_valid);
+
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(builder->Flush(&batch));
+
+ AssertTypeEqual(batch->column(0)->type(), batch->schema()->field(0)->type());
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/table_test.cc b/src/arrow/cpp/src/arrow/table_test.cc
new file mode 100644
index 000000000..88f739625
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/table_test.cc
@@ -0,0 +1,753 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/data.h"
+#include "arrow/array/util.h"
+#include "arrow/chunked_array.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+
+class TestTable : public TestBase {
+ public:
+ void MakeExample1(int length) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8());
+ auto f2 = field("f2", int16());
+
+ std::vector<std::shared_ptr<Field>> fields = {f0, f1, f2};
+ schema_ = std::make_shared<Schema>(fields);
+
+ arrays_ = {MakeRandomArray<Int32Array>(length), MakeRandomArray<UInt8Array>(length),
+ MakeRandomArray<Int16Array>(length)};
+
+ columns_ = {std::make_shared<ChunkedArray>(arrays_[0]),
+ std::make_shared<ChunkedArray>(arrays_[1]),
+ std::make_shared<ChunkedArray>(arrays_[2])};
+ }
+
+ protected:
+ std::shared_ptr<Table> table_;
+ std::shared_ptr<Schema> schema_;
+
+ std::vector<std::shared_ptr<Array>> arrays_;
+ std::vector<std::shared_ptr<ChunkedArray>> columns_;
+};
+
+TEST_F(TestTable, EmptySchema) {
+ auto empty_schema = ::arrow::schema({});
+ table_ = Table::Make(empty_schema, columns_);
+ ASSERT_OK(table_->ValidateFull());
+ ASSERT_EQ(0, table_->num_rows());
+ ASSERT_EQ(0, table_->num_columns());
+}
+
+TEST_F(TestTable, Ctors) {
+ const int length = 100;
+ MakeExample1(length);
+
+ table_ = Table::Make(schema_, columns_);
+ ASSERT_OK(table_->ValidateFull());
+ ASSERT_EQ(length, table_->num_rows());
+ ASSERT_EQ(3, table_->num_columns());
+
+ auto array_ctor = Table::Make(schema_, arrays_);
+ ASSERT_TRUE(table_->Equals(*array_ctor));
+
+ table_ = Table::Make(schema_, columns_, length);
+ ASSERT_OK(table_->ValidateFull());
+ ASSERT_EQ(length, table_->num_rows());
+
+ table_ = Table::Make(schema_, arrays_);
+ ASSERT_OK(table_->ValidateFull());
+ ASSERT_EQ(length, table_->num_rows());
+ ASSERT_EQ(3, table_->num_columns());
+}
+
+TEST_F(TestTable, Metadata) {
+ const int length = 100;
+ MakeExample1(length);
+
+ table_ = Table::Make(schema_, columns_);
+
+ ASSERT_TRUE(table_->schema()->Equals(*schema_));
+
+ auto col = table_->column(0);
+ ASSERT_EQ(schema_->field(0)->type(), col->type());
+}
+
+TEST_F(TestTable, InvalidColumns) {
+ // Check that columns are all the same length
+ const int length = 100;
+ MakeExample1(length);
+
+ table_ = Table::Make(schema_, columns_, length - 1);
+ ASSERT_RAISES(Invalid, table_->ValidateFull());
+
+ columns_.clear();
+
+ // Wrong number of columns
+ table_ = Table::Make(schema_, columns_, length);
+ ASSERT_RAISES(Invalid, table_->ValidateFull());
+
+ columns_ = {std::make_shared<ChunkedArray>(MakeRandomArray<Int32Array>(length)),
+ std::make_shared<ChunkedArray>(MakeRandomArray<UInt8Array>(length)),
+ std::make_shared<ChunkedArray>(MakeRandomArray<Int16Array>(length - 1))};
+
+ table_ = Table::Make(schema_, columns_, length);
+ ASSERT_RAISES(Invalid, table_->ValidateFull());
+}
+
+TEST_F(TestTable, AllColumnsAndFields) {
+ const int length = 100;
+ MakeExample1(length);
+ table_ = Table::Make(schema_, columns_);
+
+ auto columns = table_->columns();
+ auto fields = table_->fields();
+
+ for (int i = 0; i < table_->num_columns(); ++i) {
+ AssertChunkedEqual(*table_->column(i), *columns[i]);
+ AssertFieldEqual(*table_->field(i), *fields[i]);
+ }
+
+ // Zero length
+ std::vector<std::shared_ptr<Array>> t2_columns;
+ auto t2 = Table::Make(::arrow::schema({}), t2_columns);
+ columns = t2->columns();
+ fields = t2->fields();
+
+ ASSERT_EQ(0, columns.size());
+ ASSERT_EQ(0, fields.size());
+}
+
+TEST_F(TestTable, Equals) {
+ const int length = 100;
+ MakeExample1(length);
+
+ table_ = Table::Make(schema_, columns_);
+
+ ASSERT_TRUE(table_->Equals(*table_));
+ // Differing schema
+ auto f0 = field("f3", int32());
+ auto f1 = field("f4", uint8());
+ auto f2 = field("f5", int16());
+ std::vector<std::shared_ptr<Field>> fields = {f0, f1, f2};
+ auto other_schema = std::make_shared<Schema>(fields);
+ auto other = Table::Make(other_schema, columns_);
+ ASSERT_FALSE(table_->Equals(*other));
+ // Differing columns
+ std::vector<std::shared_ptr<ChunkedArray>> other_columns = {
+ std::make_shared<ChunkedArray>(MakeRandomArray<Int32Array>(length, 10)),
+ std::make_shared<ChunkedArray>(MakeRandomArray<UInt8Array>(length, 10)),
+ std::make_shared<ChunkedArray>(MakeRandomArray<Int16Array>(length, 10))};
+
+ other = Table::Make(schema_, other_columns);
+ ASSERT_FALSE(table_->Equals(*other));
+
+ // Differring schema metadata
+ other_schema = schema_->WithMetadata(::arrow::key_value_metadata({"key"}, {"value"}));
+ other = Table::Make(other_schema, columns_);
+ ASSERT_TRUE(table_->Equals(*other));
+ ASSERT_FALSE(table_->Equals(*other, /*check_metadata=*/true));
+}
+
+TEST_F(TestTable, FromRecordBatches) {
+ const int64_t length = 10;
+ MakeExample1(length);
+
+ auto batch1 = RecordBatch::Make(schema_, length, arrays_);
+
+ ASSERT_OK_AND_ASSIGN(auto result, Table::FromRecordBatches({batch1}));
+
+ auto expected = Table::Make(schema_, columns_);
+ ASSERT_TRUE(result->Equals(*expected));
+
+ std::vector<std::shared_ptr<ChunkedArray>> other_columns;
+ for (int i = 0; i < schema_->num_fields(); ++i) {
+ std::vector<std::shared_ptr<Array>> col_arrays = {arrays_[i], arrays_[i]};
+ other_columns.push_back(std::make_shared<ChunkedArray>(col_arrays));
+ }
+
+ ASSERT_OK_AND_ASSIGN(result, Table::FromRecordBatches({batch1, batch1}));
+ expected = Table::Make(schema_, other_columns);
+ ASSERT_TRUE(result->Equals(*expected));
+
+ // Error states
+ std::vector<std::shared_ptr<RecordBatch>> empty_batches;
+ ASSERT_RAISES(Invalid, Table::FromRecordBatches(empty_batches));
+
+ auto other_schema = ::arrow::schema({schema_->field(0), schema_->field(1)});
+
+ std::vector<std::shared_ptr<Array>> other_arrays = {arrays_[0], arrays_[1]};
+ auto batch2 = RecordBatch::Make(other_schema, length, other_arrays);
+ ASSERT_RAISES(Invalid, Table::FromRecordBatches({batch1, batch2}));
+}
+
+TEST_F(TestTable, FromRecordBatchesZeroLength) {
+ // ARROW-2307
+ MakeExample1(10);
+
+ ASSERT_OK_AND_ASSIGN(auto result, Table::FromRecordBatches(schema_, {}));
+
+ ASSERT_EQ(0, result->num_rows());
+ ASSERT_TRUE(result->schema()->Equals(*schema_));
+}
+
+TEST_F(TestTable, CombineChunksZeroColumn) {
+ // ARROW-11232
+ auto record_batch = RecordBatch::Make(schema({}), /*num_rows=*/10,
+ std::vector<std::shared_ptr<Array>>{});
+
+ ASSERT_OK_AND_ASSIGN(
+ auto table,
+ Table::FromRecordBatches(record_batch->schema(), {record_batch, record_batch}));
+ ASSERT_EQ(20, table->num_rows());
+
+ ASSERT_OK_AND_ASSIGN(auto combined, table->CombineChunks());
+
+ EXPECT_EQ(20, combined->num_rows());
+ EXPECT_TRUE(combined->Equals(*table));
+}
+
+TEST_F(TestTable, CombineChunksZeroRow) {
+ MakeExample1(10);
+
+ ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches(schema_, {}));
+ ASSERT_EQ(0, table->num_rows());
+
+ ASSERT_OK_AND_ASSIGN(auto compacted, table->CombineChunks());
+
+ EXPECT_TRUE(compacted->Equals(*table));
+}
+
+TEST_F(TestTable, CombineChunks) {
+ MakeExample1(10);
+ auto batch1 = RecordBatch::Make(schema_, 10, arrays_);
+
+ MakeExample1(15);
+ auto batch2 = RecordBatch::Make(schema_, 15, arrays_);
+
+ ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches({batch1, batch2}));
+ for (int i = 0; i < table->num_columns(); ++i) {
+ ASSERT_EQ(2, table->column(i)->num_chunks());
+ }
+
+ ASSERT_OK_AND_ASSIGN(auto compacted, table->CombineChunks());
+
+ EXPECT_TRUE(compacted->Equals(*table));
+ for (int i = 0; i < compacted->num_columns(); ++i) {
+ EXPECT_EQ(1, compacted->column(i)->num_chunks());
+ }
+}
+
+TEST_F(TestTable, LARGE_MEMORY_TEST(CombineChunksStringColumn)) {
+ schema_ = schema({field("str", utf8())});
+ arrays_ = {nullptr};
+
+ std::string value(1 << 16, '-');
+
+ auto num_rows = kBinaryMemoryLimit / static_cast<int64_t>(value.size());
+ StringBuilder builder;
+ ASSERT_OK(builder.Resize(num_rows));
+ ASSERT_OK(builder.ReserveData(value.size() * num_rows));
+ for (int i = 0; i < num_rows; ++i) builder.UnsafeAppend(value);
+ ASSERT_OK(builder.Finish(&arrays_[0]));
+
+ auto batch = RecordBatch::Make(schema_, num_rows, arrays_);
+
+ ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches({batch, batch}));
+ ASSERT_EQ(table->column(0)->num_chunks(), 2);
+
+ ASSERT_OK_AND_ASSIGN(auto compacted, table->CombineChunks());
+ EXPECT_TRUE(compacted->Equals(*table));
+
+ // can't compact these columns any further; they contain too much character data
+ ASSERT_EQ(compacted->column(0)->num_chunks(), 2);
+}
+
+TEST_F(TestTable, ConcatenateTables) {
+ const int64_t length = 10;
+
+ MakeExample1(length);
+ auto batch1 = RecordBatch::Make(schema_, length, arrays_);
+
+ // generate different data
+ MakeExample1(length);
+ auto batch2 = RecordBatch::Make(schema_, length, arrays_);
+
+ ASSERT_OK_AND_ASSIGN(auto t1, Table::FromRecordBatches({batch1}));
+ ASSERT_OK_AND_ASSIGN(auto t2, Table::FromRecordBatches({batch2}));
+
+ ASSERT_OK_AND_ASSIGN(auto result, ConcatenateTables({t1, t2}));
+ ASSERT_OK_AND_ASSIGN(auto expected, Table::FromRecordBatches({batch1, batch2}));
+ AssertTablesEqual(*expected, *result);
+
+ // Error states
+ std::vector<std::shared_ptr<Table>> empty_tables;
+ ASSERT_RAISES(Invalid, ConcatenateTables(empty_tables));
+
+ auto other_schema = ::arrow::schema({schema_->field(0), schema_->field(1)});
+
+ std::vector<std::shared_ptr<Array>> other_arrays = {arrays_[0], arrays_[1]};
+ auto batch3 = RecordBatch::Make(other_schema, length, other_arrays);
+ ASSERT_OK_AND_ASSIGN(auto t3, Table::FromRecordBatches({batch3}));
+
+ ASSERT_RAISES(Invalid, ConcatenateTables({t1, t3}));
+}
+
+std::shared_ptr<Table> MakeTableWithOneNullFilledColumn(
+ const std::string& column_name, const std::shared_ptr<DataType>& data_type,
+ const int length) {
+ auto array_of_nulls = *MakeArrayOfNull(data_type, length);
+ return Table::Make(schema({field(column_name, data_type)}), {array_of_nulls});
+}
+
+using TestPromoteTableToSchema = TestTable;
+
+TEST_F(TestPromoteTableToSchema, IdenticalSchema) {
+ const int length = 10;
+ auto metadata =
+ std::shared_ptr<KeyValueMetadata>(new KeyValueMetadata({"foo"}, {"bar"}));
+ MakeExample1(length);
+ std::shared_ptr<Table> table = Table::Make(schema_, arrays_);
+
+ ASSERT_OK_AND_ASSIGN(auto result,
+ PromoteTableToSchema(table, schema_->WithMetadata(metadata)));
+
+ std::shared_ptr<Table> expected = table->ReplaceSchemaMetadata(metadata);
+
+ ASSERT_TRUE(result->Equals(*expected));
+}
+
+// The promoted table's fields are ordered the same as the promote-to schema.
+TEST_F(TestPromoteTableToSchema, FieldsReorderedAfterPromotion) {
+ const int length = 10;
+ MakeExample1(length);
+
+ std::vector<std::shared_ptr<Field>> reversed_fields(schema_->fields().crbegin(),
+ schema_->fields().crend());
+ std::vector<std::shared_ptr<Array>> reversed_arrays(arrays_.crbegin(), arrays_.crend());
+
+ std::shared_ptr<Table> table = Table::Make(schema(reversed_fields), reversed_arrays);
+
+ ASSERT_OK_AND_ASSIGN(auto result, PromoteTableToSchema(table, schema_));
+
+ ASSERT_TRUE(result->schema()->Equals(*schema_));
+}
+
+TEST_F(TestPromoteTableToSchema, PromoteNullTypeField) {
+ const int length = 10;
+ auto metadata =
+ std::shared_ptr<KeyValueMetadata>(new KeyValueMetadata({"foo"}, {"bar"}));
+ auto table_with_null_column = MakeTableWithOneNullFilledColumn("field", null(), length)
+ ->ReplaceSchemaMetadata(metadata);
+ auto promoted_schema = schema({field("field", int32())});
+
+ ASSERT_OK_AND_ASSIGN(auto result,
+ PromoteTableToSchema(table_with_null_column, promoted_schema));
+
+ ASSERT_TRUE(
+ result->Equals(*MakeTableWithOneNullFilledColumn("field", int32(), length)));
+}
+
+TEST_F(TestPromoteTableToSchema, AddMissingField) {
+ const int length = 10;
+ auto f0 = field("f0", int32());
+ auto table = Table::Make(schema({}), std::vector<std::shared_ptr<Array>>(), length);
+ auto promoted_schema = schema({field("field", int32())});
+
+ ASSERT_OK_AND_ASSIGN(auto result, PromoteTableToSchema(table, promoted_schema));
+
+ ASSERT_TRUE(
+ result->Equals(*MakeTableWithOneNullFilledColumn("field", int32(), length)));
+}
+
+TEST_F(TestPromoteTableToSchema, IncompatibleTypes) {
+ const int length = 10;
+ auto table = MakeTableWithOneNullFilledColumn("field", int32(), length);
+
+ // Invalid promotion: int32 to null.
+ ASSERT_RAISES(Invalid, PromoteTableToSchema(table, schema({field("field", null())})));
+
+ // Invalid promotion: int32 to uint32.
+ ASSERT_RAISES(Invalid, PromoteTableToSchema(table, schema({field("field", uint32())})));
+}
+
+TEST_F(TestPromoteTableToSchema, IncompatibleNullity) {
+ const int length = 10;
+ auto table = MakeTableWithOneNullFilledColumn("field", int32(), length);
+ ASSERT_RAISES(Invalid,
+ PromoteTableToSchema(
+ table, schema({field("field", uint32())->WithNullable(false)})));
+}
+
+TEST_F(TestPromoteTableToSchema, DuplicateFieldNames) {
+ const int length = 10;
+
+ auto table = Table::Make(
+ schema({field("field", int32()), field("field", null())}),
+ {MakeRandomArray<Int32Array>(length), MakeRandomArray<NullArray>(length)});
+
+ ASSERT_RAISES(Invalid, PromoteTableToSchema(table, schema({field("field", int32())})));
+}
+
+TEST_F(TestPromoteTableToSchema, TableFieldAbsentFromSchema) {
+ const int length = 10;
+
+ auto table =
+ Table::Make(schema({field("f0", int32())}), {MakeRandomArray<Int32Array>(length)});
+
+ std::shared_ptr<Table> result;
+ ASSERT_RAISES(Invalid, PromoteTableToSchema(table, schema({field("f1", int32())})));
+}
+
+class ConcatenateTablesWithPromotionTest : public TestTable {
+ protected:
+ ConcatenateTablesOptions GetOptions() {
+ ConcatenateTablesOptions options;
+ options.unify_schemas = true;
+ return options;
+ }
+
+ void MakeExample2(int length) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", null());
+
+ std::vector<std::shared_ptr<Field>> fields = {f0, f1};
+ schema_ = std::make_shared<Schema>(fields);
+
+ arrays_ = {MakeRandomArray<Int32Array>(length), MakeRandomArray<NullArray>(length)};
+
+ columns_ = {std::make_shared<ChunkedArray>(arrays_[0]),
+ std::make_shared<ChunkedArray>(arrays_[1])};
+ }
+
+ void AssertTablesEqualUnorderedFields(const Table& lhs, const Table& rhs) {
+ ASSERT_EQ(lhs.schema()->num_fields(), rhs.schema()->num_fields());
+ if (lhs.schema()->metadata()) {
+ ASSERT_NE(nullptr, rhs.schema()->metadata());
+ ASSERT_TRUE(lhs.schema()->metadata()->Equals(*rhs.schema()->metadata()));
+ } else {
+ ASSERT_EQ(nullptr, rhs.schema()->metadata());
+ }
+ for (int i = 0; i < lhs.schema()->num_fields(); ++i) {
+ const auto& lhs_field = lhs.schema()->field(i);
+ const auto& rhs_field = rhs.schema()->GetFieldByName(lhs_field->name());
+ ASSERT_NE(nullptr, rhs_field);
+ ASSERT_TRUE(lhs_field->Equals(rhs_field, true));
+ const auto& lhs_column = lhs.column(i);
+ const auto& rhs_column = rhs.GetColumnByName(lhs_field->name());
+ AssertChunkedEqual(*lhs_column, *rhs_column);
+ }
+ }
+};
+
+TEST_F(ConcatenateTablesWithPromotionTest, Simple) {
+ const int64_t length = 10;
+
+ MakeExample1(length);
+ auto batch1 = RecordBatch::Make(schema_, length, arrays_);
+
+ ASSERT_OK_AND_ASSIGN(auto f1_nulls, MakeArrayOfNull(schema_->field(1)->type(), length));
+ ASSERT_OK_AND_ASSIGN(auto f2_nulls, MakeArrayOfNull(schema_->field(2)->type(), length));
+
+ MakeExample2(length);
+ auto batch2 = RecordBatch::Make(schema_, length, arrays_);
+
+ auto batch2_null_filled =
+ RecordBatch::Make(batch1->schema(), length, {arrays_[0], f1_nulls, f2_nulls});
+
+ ASSERT_OK_AND_ASSIGN(auto t1, Table::FromRecordBatches({batch1}));
+ ASSERT_OK_AND_ASSIGN(auto t2, Table::FromRecordBatches({batch2}));
+ ASSERT_OK_AND_ASSIGN(auto t3, Table::FromRecordBatches({batch2_null_filled}));
+
+ ASSERT_OK_AND_ASSIGN(auto result, ConcatenateTables({t1, t2}, GetOptions()));
+ ASSERT_OK_AND_ASSIGN(auto expected, ConcatenateTables({t1, t3}));
+ AssertTablesEqualUnorderedFields(*expected, *result);
+
+ ASSERT_OK_AND_ASSIGN(result, ConcatenateTables({t2, t1}, GetOptions()));
+ ASSERT_OK_AND_ASSIGN(expected, ConcatenateTables({t3, t1}));
+ AssertTablesEqualUnorderedFields(*expected, *result);
+}
+
+TEST_F(TestTable, Slice) {
+ const int64_t length = 10;
+
+ MakeExample1(length);
+ auto batch = RecordBatch::Make(schema_, length, arrays_);
+
+ ASSERT_OK_AND_ASSIGN(auto half, Table::FromRecordBatches({batch}));
+ ASSERT_OK_AND_ASSIGN(auto whole, Table::FromRecordBatches({batch, batch}));
+ ASSERT_OK_AND_ASSIGN(auto three, Table::FromRecordBatches({batch, batch, batch}));
+
+ AssertTablesEqual(*whole->Slice(0, length), *half);
+ AssertTablesEqual(*whole->Slice(length), *half);
+ AssertTablesEqual(*whole->Slice(length / 3, 2 * (length - length / 3)),
+ *three->Slice(length + length / 3, 2 * (length - length / 3)));
+}
+
+TEST_F(TestTable, RemoveColumn) {
+ const int64_t length = 10;
+ MakeExample1(length);
+
+ auto table_sp = Table::Make(schema_, columns_);
+ const Table& table = *table_sp;
+
+ ASSERT_OK_AND_ASSIGN(auto result, table.RemoveColumn(0));
+
+ auto ex_schema = ::arrow::schema({schema_->field(1), schema_->field(2)});
+ std::vector<std::shared_ptr<ChunkedArray>> ex_columns = {table.column(1),
+ table.column(2)};
+
+ auto expected = Table::Make(ex_schema, ex_columns);
+ ASSERT_TRUE(result->Equals(*expected));
+
+ ASSERT_OK_AND_ASSIGN(result, table.RemoveColumn(1));
+ ex_schema = ::arrow::schema({schema_->field(0), schema_->field(2)});
+ ex_columns = {table.column(0), table.column(2)};
+
+ expected = Table::Make(ex_schema, ex_columns);
+ ASSERT_TRUE(result->Equals(*expected));
+
+ ASSERT_OK_AND_ASSIGN(result, table.RemoveColumn(2));
+ ex_schema = ::arrow::schema({schema_->field(0), schema_->field(1)});
+ ex_columns = {table.column(0), table.column(1)};
+ expected = Table::Make(ex_schema, ex_columns);
+ ASSERT_TRUE(result->Equals(*expected));
+}
+
+TEST_F(TestTable, SetColumn) {
+ const int64_t length = 10;
+ MakeExample1(length);
+
+ auto table_sp = Table::Make(schema_, columns_);
+ const Table& table = *table_sp;
+
+ ASSERT_OK_AND_ASSIGN(auto result,
+ table.SetColumn(0, schema_->field(1), table.column(1)));
+
+ auto ex_schema =
+ ::arrow::schema({schema_->field(1), schema_->field(1), schema_->field(2)});
+
+ auto expected =
+ Table::Make(ex_schema, {table.column(1), table.column(1), table.column(2)});
+ ASSERT_TRUE(result->Equals(*expected));
+}
+
+TEST_F(TestTable, RenameColumns) {
+ MakeExample1(10);
+ auto table = Table::Make(schema_, columns_);
+ EXPECT_THAT(table->ColumnNames(), testing::ElementsAre("f0", "f1", "f2"));
+
+ ASSERT_OK_AND_ASSIGN(auto renamed, table->RenameColumns({"zero", "one", "two"}));
+ EXPECT_THAT(renamed->ColumnNames(), testing::ElementsAre("zero", "one", "two"));
+ ASSERT_OK(renamed->ValidateFull());
+
+ ASSERT_RAISES(Invalid, table->RenameColumns({"hello", "world"}));
+}
+
+TEST_F(TestTable, SelectColumns) {
+ MakeExample1(10);
+ auto table = Table::Make(schema_, columns_);
+
+ ASSERT_OK_AND_ASSIGN(auto subset, table->SelectColumns({0, 2}));
+ ASSERT_OK(subset->ValidateFull());
+
+ auto expexted_schema = ::arrow::schema({schema_->field(0), schema_->field(2)});
+ auto expected = Table::Make(expexted_schema, {table->column(0), table->column(2)});
+ ASSERT_TRUE(subset->Equals(*expected));
+
+ // Out of bounds indices
+ ASSERT_RAISES(Invalid, table->SelectColumns({0, 3}));
+ ASSERT_RAISES(Invalid, table->SelectColumns({-1}));
+}
+
+TEST_F(TestTable, RemoveColumnEmpty) {
+ // ARROW-1865
+ const int64_t length = 10;
+
+ auto f0 = field("f0", int32());
+ auto schema = ::arrow::schema({f0});
+ auto a0 = MakeRandomArray<Int32Array>(length);
+
+ auto table = Table::Make(schema, {std::make_shared<ChunkedArray>(a0)});
+
+ ASSERT_OK_AND_ASSIGN(auto empty, table->RemoveColumn(0));
+
+ ASSERT_EQ(table->num_rows(), empty->num_rows());
+
+ ASSERT_OK_AND_ASSIGN(auto added, empty->AddColumn(0, f0, table->column(0)));
+ ASSERT_EQ(table->num_rows(), added->num_rows());
+}
+
+TEST_F(TestTable, AddColumn) {
+ const int64_t length = 10;
+ MakeExample1(length);
+
+ auto table_sp = Table::Make(schema_, columns_);
+ const Table& table = *table_sp;
+
+ auto f0 = schema_->field(0);
+
+ // Some negative tests with invalid index
+ ASSERT_RAISES(Invalid, table.AddColumn(10, f0, columns_[0]));
+ ASSERT_RAISES(Invalid, table.AddColumn(4, f0, columns_[0]));
+ ASSERT_RAISES(Invalid, table.AddColumn(-1, f0, columns_[0]));
+
+ // Add column with wrong length
+ auto longer_col =
+ std::make_shared<ChunkedArray>(MakeRandomArray<Int32Array>(length + 1));
+ ASSERT_RAISES(Invalid, table.AddColumn(0, f0, longer_col));
+
+ // Add column 0 in different places
+ ASSERT_OK_AND_ASSIGN(auto result, table.AddColumn(0, f0, columns_[0]));
+ auto ex_schema = ::arrow::schema(
+ {schema_->field(0), schema_->field(0), schema_->field(1), schema_->field(2)});
+
+ auto expected = Table::Make(
+ ex_schema, {table.column(0), table.column(0), table.column(1), table.column(2)});
+ ASSERT_TRUE(result->Equals(*expected));
+
+ ASSERT_OK_AND_ASSIGN(result, table.AddColumn(1, f0, columns_[0]));
+ ex_schema = ::arrow::schema(
+ {schema_->field(0), schema_->field(0), schema_->field(1), schema_->field(2)});
+
+ expected = Table::Make(
+ ex_schema, {table.column(0), table.column(0), table.column(1), table.column(2)});
+ ASSERT_TRUE(result->Equals(*expected));
+
+ ASSERT_OK_AND_ASSIGN(result, table.AddColumn(2, f0, columns_[0]));
+ ex_schema = ::arrow::schema(
+ {schema_->field(0), schema_->field(1), schema_->field(0), schema_->field(2)});
+ expected = Table::Make(
+ ex_schema, {table.column(0), table.column(1), table.column(0), table.column(2)});
+ ASSERT_TRUE(result->Equals(*expected));
+
+ ASSERT_OK_AND_ASSIGN(result, table.AddColumn(3, f0, columns_[0]));
+ ex_schema = ::arrow::schema(
+ {schema_->field(0), schema_->field(1), schema_->field(2), schema_->field(0)});
+ expected = Table::Make(
+ ex_schema, {table.column(0), table.column(1), table.column(2), table.column(0)});
+ ASSERT_TRUE(result->Equals(*expected));
+}
+
+class TestTableBatchReader : public TestBase {};
+
+TEST_F(TestTableBatchReader, ReadNext) {
+ ArrayVector c1, c2;
+
+ auto a1 = MakeRandomArray<Int32Array>(10);
+ auto a2 = MakeRandomArray<Int32Array>(20);
+ auto a3 = MakeRandomArray<Int32Array>(30);
+ auto a4 = MakeRandomArray<Int32Array>(10);
+
+ auto sch1 = arrow::schema({field("f1", int32()), field("f2", int32())});
+
+ std::vector<std::shared_ptr<ChunkedArray>> columns;
+
+ std::shared_ptr<RecordBatch> batch;
+
+ std::vector<std::shared_ptr<Array>> arrays_1 = {a1, a4, a2};
+ std::vector<std::shared_ptr<Array>> arrays_2 = {a2, a2};
+ columns = {std::make_shared<ChunkedArray>(arrays_1),
+ std::make_shared<ChunkedArray>(arrays_2)};
+ auto t1 = Table::Make(sch1, columns);
+
+ TableBatchReader i1(*t1);
+
+ ASSERT_OK(i1.ReadNext(&batch));
+ ASSERT_EQ(10, batch->num_rows());
+
+ ASSERT_OK(i1.ReadNext(&batch));
+ ASSERT_EQ(10, batch->num_rows());
+
+ ASSERT_OK(i1.ReadNext(&batch));
+ ASSERT_EQ(20, batch->num_rows());
+
+ ASSERT_OK(i1.ReadNext(&batch));
+ ASSERT_EQ(nullptr, batch);
+
+ arrays_1 = {a1};
+ arrays_2 = {a4};
+ columns = {std::make_shared<ChunkedArray>(arrays_1),
+ std::make_shared<ChunkedArray>(arrays_2)};
+ auto t2 = Table::Make(sch1, columns);
+
+ TableBatchReader i2(*t2);
+
+ ASSERT_OK(i2.ReadNext(&batch));
+ ASSERT_EQ(10, batch->num_rows());
+
+ // Ensure non-sliced
+ ASSERT_EQ(a1->data().get(), batch->column_data(0).get());
+ ASSERT_EQ(a4->data().get(), batch->column_data(1).get());
+
+ ASSERT_OK(i1.ReadNext(&batch));
+ ASSERT_EQ(nullptr, batch);
+}
+
+TEST_F(TestTableBatchReader, Chunksize) {
+ auto a1 = MakeRandomArray<Int32Array>(10);
+ auto a2 = MakeRandomArray<Int32Array>(20);
+ auto a3 = MakeRandomArray<Int32Array>(10);
+
+ auto sch1 = arrow::schema({field("f1", int32())});
+
+ std::vector<std::shared_ptr<Array>> arrays = {a1, a2, a3};
+ auto t1 = Table::Make(sch1, {std::make_shared<ChunkedArray>(arrays)});
+
+ TableBatchReader i1(*t1);
+
+ i1.set_chunksize(15);
+
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(i1.ReadNext(&batch));
+ ASSERT_OK(batch->ValidateFull());
+ ASSERT_EQ(10, batch->num_rows());
+
+ ASSERT_OK(i1.ReadNext(&batch));
+ ASSERT_OK(batch->ValidateFull());
+ ASSERT_EQ(15, batch->num_rows());
+
+ ASSERT_OK(i1.ReadNext(&batch));
+ ASSERT_OK(batch->ValidateFull());
+ ASSERT_EQ(5, batch->num_rows());
+
+ ASSERT_OK(i1.ReadNext(&batch));
+ ASSERT_OK(batch->ValidateFull());
+ ASSERT_EQ(10, batch->num_rows());
+
+ ASSERT_OK(i1.ReadNext(&batch));
+ ASSERT_EQ(nullptr, batch);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/tensor.cc b/src/arrow/cpp/src/arrow/tensor.cc
new file mode 100644
index 000000000..30ae8c465
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/tensor.cc
@@ -0,0 +1,342 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/tensor.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/logging.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace internal {
+
+Status ComputeRowMajorStrides(const FixedWidthType& type,
+ const std::vector<int64_t>& shape,
+ std::vector<int64_t>* strides) {
+ const int byte_width = GetByteWidth(type);
+ const size_t ndim = shape.size();
+
+ int64_t remaining = 0;
+ if (!shape.empty() && shape.front() > 0) {
+ remaining = byte_width;
+ for (size_t i = 1; i < ndim; ++i) {
+ if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
+ return Status::Invalid(
+ "Row-major strides computed from shape would not fit in 64-bit integer");
+ }
+ }
+ }
+
+ if (remaining == 0) {
+ strides->assign(shape.size(), byte_width);
+ return Status::OK();
+ }
+
+ strides->push_back(remaining);
+ for (size_t i = 1; i < ndim; ++i) {
+ remaining /= shape[i];
+ strides->push_back(remaining);
+ }
+
+ return Status::OK();
+}
+
+Status ComputeColumnMajorStrides(const FixedWidthType& type,
+ const std::vector<int64_t>& shape,
+ std::vector<int64_t>* strides) {
+ const int byte_width = internal::GetByteWidth(type);
+ const size_t ndim = shape.size();
+
+ int64_t total = 0;
+ if (!shape.empty() && shape.back() > 0) {
+ total = byte_width;
+ for (size_t i = 0; i < ndim - 1; ++i) {
+ if (internal::MultiplyWithOverflow(total, shape[i], &total)) {
+ return Status::Invalid(
+ "Column-major strides computed from shape would not fit in 64-bit "
+ "integer");
+ }
+ }
+ }
+
+ if (total == 0) {
+ strides->assign(shape.size(), byte_width);
+ return Status::OK();
+ }
+
+ total = byte_width;
+ for (size_t i = 0; i < ndim - 1; ++i) {
+ strides->push_back(total);
+ total *= shape[i];
+ }
+ strides->push_back(total);
+
+ return Status::OK();
+}
+
+} // namespace internal
+
+namespace {
+
+inline bool IsTensorStridesRowMajor(const std::shared_ptr<DataType>& type,
+ const std::vector<int64_t>& shape,
+ const std::vector<int64_t>& strides) {
+ std::vector<int64_t> c_strides;
+ const auto& fw_type = checked_cast<const FixedWidthType&>(*type);
+ if (internal::ComputeRowMajorStrides(fw_type, shape, &c_strides).ok()) {
+ return strides == c_strides;
+ } else {
+ return false;
+ }
+}
+
+inline bool IsTensorStridesColumnMajor(const std::shared_ptr<DataType>& type,
+ const std::vector<int64_t>& shape,
+ const std::vector<int64_t>& strides) {
+ std::vector<int64_t> f_strides;
+ const auto& fw_type = checked_cast<const FixedWidthType&>(*type);
+ if (internal::ComputeColumnMajorStrides(fw_type, shape, &f_strides).ok()) {
+ return strides == f_strides;
+ } else {
+ return false;
+ }
+}
+
+inline Status CheckTensorValidity(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<Buffer>& data,
+ const std::vector<int64_t>& shape) {
+ if (!type) {
+ return Status::Invalid("Null type is supplied");
+ }
+ if (!is_tensor_supported(type->id())) {
+ return Status::Invalid(type->ToString(), " is not valid data type for a tensor");
+ }
+ if (!data) {
+ return Status::Invalid("Null data is supplied");
+ }
+ if (!std::all_of(shape.begin(), shape.end(), [](int64_t x) { return x >= 0; })) {
+ return Status::Invalid("Shape elements must be positive");
+ }
+ return Status::OK();
+}
+
+Status CheckTensorStridesValidity(const std::shared_ptr<Buffer>& data,
+ const std::vector<int64_t>& shape,
+ const std::vector<int64_t>& strides,
+ const std::shared_ptr<DataType>& type) {
+ if (strides.size() != shape.size()) {
+ return Status::Invalid("strides must have the same length as shape");
+ }
+ if (data->size() == 0 && std::find(shape.begin(), shape.end(), 0) != shape.end()) {
+ return Status::OK();
+ }
+
+ // Check the largest offset can be computed without overflow
+ const size_t ndim = shape.size();
+ int64_t largest_offset = 0;
+ for (size_t i = 0; i < ndim; ++i) {
+ if (shape[i] == 0) continue;
+ if (strides[i] < 0) {
+ // TODO(mrkn): Support negative strides for sharing views
+ return Status::Invalid("negative strides not supported");
+ }
+
+ int64_t dim_offset;
+ if (!internal::MultiplyWithOverflow(shape[i] - 1, strides[i], &dim_offset)) {
+ if (!internal::AddWithOverflow(largest_offset, dim_offset, &largest_offset)) {
+ continue;
+ }
+ }
+
+ return Status::Invalid(
+ "offsets computed from shape and strides would not fit in 64-bit integer");
+ }
+
+ const int byte_width = internal::GetByteWidth(*type);
+ if (largest_offset > data->size() - byte_width) {
+ return Status::Invalid("strides must not involve buffer over run");
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+namespace internal {
+
+bool IsTensorStridesContiguous(const std::shared_ptr<DataType>& type,
+ const std::vector<int64_t>& shape,
+ const std::vector<int64_t>& strides) {
+ return IsTensorStridesRowMajor(type, shape, strides) ||
+ IsTensorStridesColumnMajor(type, shape, strides);
+}
+
+Status ValidateTensorParameters(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<Buffer>& data,
+ const std::vector<int64_t>& shape,
+ const std::vector<int64_t>& strides,
+ const std::vector<std::string>& dim_names) {
+ RETURN_NOT_OK(CheckTensorValidity(type, data, shape));
+ if (!strides.empty()) {
+ RETURN_NOT_OK(CheckTensorStridesValidity(data, shape, strides, type));
+ } else {
+ std::vector<int64_t> tmp_strides;
+ RETURN_NOT_OK(ComputeRowMajorStrides(checked_cast<const FixedWidthType&>(*type),
+ shape, &tmp_strides));
+ }
+ if (dim_names.size() > shape.size()) {
+ return Status::Invalid("too many dim_names are supplied");
+ }
+ return Status::OK();
+}
+
+} // namespace internal
+
+/// Constructor with strides and dimension names
+Tensor::Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
+ const std::vector<int64_t>& shape, const std::vector<int64_t>& strides,
+ const std::vector<std::string>& dim_names)
+ : type_(type), data_(data), shape_(shape), strides_(strides), dim_names_(dim_names) {
+ ARROW_CHECK(is_tensor_supported(type->id()));
+ if (shape.size() > 0 && strides.size() == 0) {
+ ARROW_CHECK_OK(internal::ComputeRowMajorStrides(
+ checked_cast<const FixedWidthType&>(*type_), shape, &strides_));
+ }
+}
+
+Tensor::Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
+ const std::vector<int64_t>& shape, const std::vector<int64_t>& strides)
+ : Tensor(type, data, shape, strides, {}) {}
+
+Tensor::Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
+ const std::vector<int64_t>& shape)
+ : Tensor(type, data, shape, {}, {}) {}
+
+const std::string& Tensor::dim_name(int i) const {
+ static const std::string kEmpty = "";
+ if (dim_names_.size() == 0) {
+ return kEmpty;
+ } else {
+ ARROW_CHECK_LT(i, static_cast<int>(dim_names_.size()));
+ return dim_names_[i];
+ }
+}
+
+int64_t Tensor::size() const {
+ return std::accumulate(shape_.begin(), shape_.end(), 1LL, std::multiplies<int64_t>());
+}
+
+bool Tensor::is_contiguous() const {
+ return internal::IsTensorStridesContiguous(type_, shape_, strides_);
+}
+
+bool Tensor::is_row_major() const {
+ return IsTensorStridesRowMajor(type_, shape_, strides_);
+}
+
+bool Tensor::is_column_major() const {
+ return IsTensorStridesColumnMajor(type_, shape_, strides_);
+}
+
+Type::type Tensor::type_id() const { return type_->id(); }
+
+bool Tensor::Equals(const Tensor& other, const EqualOptions& opts) const {
+ return TensorEquals(*this, other, opts);
+}
+
+namespace {
+
+template <typename TYPE>
+int64_t StridedTensorCountNonZero(int dim_index, int64_t offset, const Tensor& tensor) {
+ using c_type = typename TYPE::c_type;
+ c_type const zero = c_type(0);
+ int64_t nnz = 0;
+ if (dim_index == tensor.ndim() - 1) {
+ for (int64_t i = 0; i < tensor.shape()[dim_index]; ++i) {
+ auto const* ptr = tensor.raw_data() + offset + i * tensor.strides()[dim_index];
+ auto& elem = *reinterpret_cast<c_type const*>(ptr);
+ if (elem != zero) ++nnz;
+ }
+ return nnz;
+ }
+ for (int64_t i = 0; i < tensor.shape()[dim_index]; ++i) {
+ nnz += StridedTensorCountNonZero<TYPE>(dim_index + 1, offset, tensor);
+ offset += tensor.strides()[dim_index];
+ }
+ return nnz;
+}
+
+template <typename TYPE>
+int64_t ContiguousTensorCountNonZero(const Tensor& tensor) {
+ using c_type = typename TYPE::c_type;
+ auto* data = reinterpret_cast<c_type const*>(tensor.raw_data());
+ return std::count_if(data, data + tensor.size(),
+ [](c_type const& x) { return x != 0; });
+}
+
+template <typename TYPE>
+inline int64_t TensorCountNonZero(const Tensor& tensor) {
+ if (tensor.is_contiguous()) {
+ return ContiguousTensorCountNonZero<TYPE>(tensor);
+ } else {
+ return StridedTensorCountNonZero<TYPE>(0, 0, tensor);
+ }
+}
+
+struct NonZeroCounter {
+ explicit NonZeroCounter(const Tensor& tensor) : tensor_(tensor) {}
+
+ template <typename TYPE>
+ enable_if_number<TYPE, Status> Visit(const TYPE& type) {
+ result = TensorCountNonZero<TYPE>(tensor_);
+ return Status::OK();
+ }
+
+ Status Visit(const DataType& type) {
+ ARROW_CHECK(!is_tensor_supported(type.id()));
+ return Status::NotImplemented("Tensor of ", type.ToString(), " is not implemented");
+ }
+
+ const Tensor& tensor_;
+ int64_t result = 0;
+};
+
+} // namespace
+
+Result<int64_t> Tensor::CountNonZero() const {
+ NonZeroCounter counter(*this);
+ RETURN_NOT_OK(VisitTypeInline(*type(), &counter));
+ return counter.result;
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/tensor.h b/src/arrow/cpp/src/arrow/tensor.h
new file mode 100644
index 000000000..ff6f3735f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/tensor.h
@@ -0,0 +1,246 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/compare.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+static inline bool is_tensor_supported(Type::type type_id) {
+ switch (type_id) {
+ case Type::UINT8:
+ case Type::INT8:
+ case Type::UINT16:
+ case Type::INT16:
+ case Type::UINT32:
+ case Type::INT32:
+ case Type::UINT64:
+ case Type::INT64:
+ case Type::HALF_FLOAT:
+ case Type::FLOAT:
+ case Type::DOUBLE:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+namespace internal {
+
+ARROW_EXPORT
+Status ComputeRowMajorStrides(const FixedWidthType& type,
+ const std::vector<int64_t>& shape,
+ std::vector<int64_t>* strides);
+
+ARROW_EXPORT
+Status ComputeColumnMajorStrides(const FixedWidthType& type,
+ const std::vector<int64_t>& shape,
+ std::vector<int64_t>* strides);
+
+ARROW_EXPORT
+bool IsTensorStridesContiguous(const std::shared_ptr<DataType>& type,
+ const std::vector<int64_t>& shape,
+ const std::vector<int64_t>& strides);
+
+ARROW_EXPORT
+Status ValidateTensorParameters(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<Buffer>& data,
+ const std::vector<int64_t>& shape,
+ const std::vector<int64_t>& strides,
+ const std::vector<std::string>& dim_names);
+
+} // namespace internal
+
+class ARROW_EXPORT Tensor {
+ public:
+ /// \brief Create a Tensor with full parameters
+ ///
+ /// This factory function will return Status::Invalid when the parameters are
+ /// inconsistent
+ ///
+ /// \param[in] type The data type of the tensor values
+ /// \param[in] data The buffer of the tensor content
+ /// \param[in] shape The shape of the tensor
+ /// \param[in] strides The strides of the tensor
+ /// (if this is empty, the data assumed to be row-major)
+ /// \param[in] dim_names The names of the tensor dimensions
+ static inline Result<std::shared_ptr<Tensor>> Make(
+ const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
+ const std::vector<int64_t>& shape, const std::vector<int64_t>& strides = {},
+ const std::vector<std::string>& dim_names = {}) {
+ ARROW_RETURN_NOT_OK(
+ internal::ValidateTensorParameters(type, data, shape, strides, dim_names));
+ return std::make_shared<Tensor>(type, data, shape, strides, dim_names);
+ }
+
+ virtual ~Tensor() = default;
+
+ /// Constructor with no dimension names or strides, data assumed to be row-major
+ Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
+ const std::vector<int64_t>& shape);
+
+ /// Constructor with non-negative strides
+ Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
+ const std::vector<int64_t>& shape, const std::vector<int64_t>& strides);
+
+ /// Constructor with non-negative strides and dimension names
+ Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
+ const std::vector<int64_t>& shape, const std::vector<int64_t>& strides,
+ const std::vector<std::string>& dim_names);
+
+ std::shared_ptr<DataType> type() const { return type_; }
+ std::shared_ptr<Buffer> data() const { return data_; }
+
+ const uint8_t* raw_data() const { return data_->data(); }
+ uint8_t* raw_mutable_data() { return data_->mutable_data(); }
+
+ const std::vector<int64_t>& shape() const { return shape_; }
+ const std::vector<int64_t>& strides() const { return strides_; }
+
+ int ndim() const { return static_cast<int>(shape_.size()); }
+
+ const std::vector<std::string>& dim_names() const { return dim_names_; }
+ const std::string& dim_name(int i) const;
+
+ /// Total number of value cells in the tensor
+ int64_t size() const;
+
+ /// Return true if the underlying data buffer is mutable
+ bool is_mutable() const { return data_->is_mutable(); }
+
+ /// Either row major or column major
+ bool is_contiguous() const;
+
+ /// AKA "C order"
+ bool is_row_major() const;
+
+ /// AKA "Fortran order"
+ bool is_column_major() const;
+
+ Type::type type_id() const;
+
+ bool Equals(const Tensor& other, const EqualOptions& = EqualOptions::Defaults()) const;
+
+ /// Compute the number of non-zero values in the tensor
+ Result<int64_t> CountNonZero() const;
+
+ /// Return the offset of the given index on the given strides
+ static int64_t CalculateValueOffset(const std::vector<int64_t>& strides,
+ const std::vector<int64_t>& index) {
+ const int64_t n = static_cast<int64_t>(index.size());
+ int64_t offset = 0;
+ for (int64_t i = 0; i < n; ++i) {
+ offset += index[i] * strides[i];
+ }
+ return offset;
+ }
+
+ int64_t CalculateValueOffset(const std::vector<int64_t>& index) const {
+ return Tensor::CalculateValueOffset(strides_, index);
+ }
+
+ /// Returns the value at the given index without data-type and bounds checks
+ template <typename ValueType>
+ const typename ValueType::c_type& Value(const std::vector<int64_t>& index) const {
+ using c_type = typename ValueType::c_type;
+ const int64_t offset = CalculateValueOffset(index);
+ const c_type* ptr = reinterpret_cast<const c_type*>(raw_data() + offset);
+ return *ptr;
+ }
+
+ Status Validate() const {
+ return internal::ValidateTensorParameters(type_, data_, shape_, strides_, dim_names_);
+ }
+
+ protected:
+ Tensor() {}
+
+ std::shared_ptr<DataType> type_;
+ std::shared_ptr<Buffer> data_;
+ std::vector<int64_t> shape_;
+ std::vector<int64_t> strides_;
+
+ /// These names are optional
+ std::vector<std::string> dim_names_;
+
+ template <typename SparseIndexType>
+ friend class SparseTensorImpl;
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Tensor);
+};
+
+template <typename TYPE>
+class NumericTensor : public Tensor {
+ public:
+ using TypeClass = TYPE;
+ using value_type = typename TypeClass::c_type;
+
+ /// \brief Create a NumericTensor with full parameters
+ ///
+ /// This factory function will return Status::Invalid when the parameters are
+ /// inconsistent
+ ///
+ /// \param[in] data The buffer of the tensor content
+ /// \param[in] shape The shape of the tensor
+ /// \param[in] strides The strides of the tensor
+ /// (if this is empty, the data assumed to be row-major)
+ /// \param[in] dim_names The names of the tensor dimensions
+ static Result<std::shared_ptr<NumericTensor<TYPE>>> Make(
+ const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape,
+ const std::vector<int64_t>& strides = {},
+ const std::vector<std::string>& dim_names = {}) {
+ ARROW_RETURN_NOT_OK(internal::ValidateTensorParameters(
+ TypeTraits<TYPE>::type_singleton(), data, shape, strides, dim_names));
+ return std::make_shared<NumericTensor<TYPE>>(data, shape, strides, dim_names);
+ }
+
+ /// Constructor with non-negative strides and dimension names
+ NumericTensor(const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape,
+ const std::vector<int64_t>& strides,
+ const std::vector<std::string>& dim_names)
+ : Tensor(TypeTraits<TYPE>::type_singleton(), data, shape, strides, dim_names) {}
+
+ /// Constructor with no dimension names or strides, data assumed to be row-major
+ NumericTensor(const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape)
+ : NumericTensor(data, shape, {}, {}) {}
+
+ /// Constructor with non-negative strides
+ NumericTensor(const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape,
+ const std::vector<int64_t>& strides)
+ : NumericTensor(data, shape, strides, {}) {}
+
+ const value_type& Value(const std::vector<int64_t>& index) const {
+ return Tensor::Value<TypeClass>(index);
+ }
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/tensor/CMakeLists.txt b/src/arrow/cpp/src/arrow/tensor/CMakeLists.txt
new file mode 100644
index 000000000..32381c0bc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/tensor/CMakeLists.txt
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#
+# arrow_tensor
+#
+
+# Headers: top level
+arrow_install_all_headers("arrow/tensor")
+
+add_arrow_benchmark(tensor_conversion_benchmark)
diff --git a/src/arrow/cpp/src/arrow/tensor/converter.h b/src/arrow/cpp/src/arrow/tensor/converter.h
new file mode 100644
index 000000000..408ab2230
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/tensor/converter.h
@@ -0,0 +1,67 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/sparse_tensor.h" // IWYU pragma: export
+
+#include <memory>
+
+namespace arrow {
+namespace internal {
+
+struct SparseTensorConverterMixin {
+ static bool IsNonZero(const uint8_t val) { return val != 0; }
+
+ static void AssignIndex(uint8_t* indices, int64_t val, const int elsize);
+
+ static int64_t GetIndexValue(const uint8_t* value_ptr, const int elsize);
+};
+
+Status MakeSparseCOOTensorFromTensor(const Tensor& tensor,
+ const std::shared_ptr<DataType>& index_value_type,
+ MemoryPool* pool,
+ std::shared_ptr<SparseIndex>* out_sparse_index,
+ std::shared_ptr<Buffer>* out_data);
+
+Status MakeSparseCSXMatrixFromTensor(SparseMatrixCompressedAxis axis,
+ const Tensor& tensor,
+ const std::shared_ptr<DataType>& index_value_type,
+ MemoryPool* pool,
+ std::shared_ptr<SparseIndex>* out_sparse_index,
+ std::shared_ptr<Buffer>* out_data);
+
+Status MakeSparseCSFTensorFromTensor(const Tensor& tensor,
+ const std::shared_ptr<DataType>& index_value_type,
+ MemoryPool* pool,
+ std::shared_ptr<SparseIndex>* out_sparse_index,
+ std::shared_ptr<Buffer>* out_data);
+
+Result<std::shared_ptr<Tensor>> MakeTensorFromSparseCOOTensor(
+ MemoryPool* pool, const SparseCOOTensor* sparse_tensor);
+
+Result<std::shared_ptr<Tensor>> MakeTensorFromSparseCSRMatrix(
+ MemoryPool* pool, const SparseCSRMatrix* sparse_tensor);
+
+Result<std::shared_ptr<Tensor>> MakeTensorFromSparseCSCMatrix(
+ MemoryPool* pool, const SparseCSCMatrix* sparse_tensor);
+
+Result<std::shared_ptr<Tensor>> MakeTensorFromSparseCSFTensor(
+ MemoryPool* pool, const SparseCSFTensor* sparse_tensor);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/tensor/converter_internal.h b/src/arrow/cpp/src/arrow/tensor/converter_internal.h
new file mode 100644
index 000000000..3a87feaf4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/tensor/converter_internal.h
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/tensor/converter.h"
+
+#define DISPATCH(ACTION, index_elsize, value_elsize, ...) \
+ switch (index_elsize) { \
+ case 1: \
+ switch (value_elsize) { \
+ case 1: \
+ ACTION(uint8_t, uint8_t, __VA_ARGS__); \
+ break; \
+ case 2: \
+ ACTION(uint8_t, uint16_t, __VA_ARGS__); \
+ break; \
+ case 4: \
+ ACTION(uint8_t, uint32_t, __VA_ARGS__); \
+ break; \
+ case 8: \
+ ACTION(uint8_t, uint64_t, __VA_ARGS__); \
+ break; \
+ } \
+ break; \
+ case 2: \
+ switch (value_elsize) { \
+ case 1: \
+ ACTION(uint16_t, uint8_t, __VA_ARGS__); \
+ break; \
+ case 2: \
+ ACTION(uint16_t, uint16_t, __VA_ARGS__); \
+ break; \
+ case 4: \
+ ACTION(uint16_t, uint32_t, __VA_ARGS__); \
+ break; \
+ case 8: \
+ ACTION(uint16_t, uint64_t, __VA_ARGS__); \
+ break; \
+ } \
+ break; \
+ case 4: \
+ switch (value_elsize) { \
+ case 1: \
+ ACTION(uint32_t, uint8_t, __VA_ARGS__); \
+ break; \
+ case 2: \
+ ACTION(uint32_t, uint16_t, __VA_ARGS__); \
+ break; \
+ case 4: \
+ ACTION(uint32_t, uint32_t, __VA_ARGS__); \
+ break; \
+ case 8: \
+ ACTION(uint32_t, uint64_t, __VA_ARGS__); \
+ break; \
+ } \
+ break; \
+ case 8: \
+ switch (value_elsize) { \
+ case 1: \
+ ACTION(int64_t, uint8_t, __VA_ARGS__); \
+ break; \
+ case 2: \
+ ACTION(int64_t, uint16_t, __VA_ARGS__); \
+ break; \
+ case 4: \
+ ACTION(int64_t, uint32_t, __VA_ARGS__); \
+ break; \
+ case 8: \
+ ACTION(int64_t, uint64_t, __VA_ARGS__); \
+ break; \
+ } \
+ break; \
+ }
diff --git a/src/arrow/cpp/src/arrow/tensor/coo_converter.cc b/src/arrow/cpp/src/arrow/tensor/coo_converter.cc
new file mode 100644
index 000000000..2124d0a4e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/tensor/coo_converter.cc
@@ -0,0 +1,333 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/tensor/converter_internal.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/macros.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+class MemoryPool;
+
+namespace internal {
+namespace {
+
+template <typename c_index_type>
+inline void IncrementRowMajorIndex(std::vector<c_index_type>& coord,
+ const std::vector<int64_t>& shape) {
+ const int64_t ndim = shape.size();
+ ++coord[ndim - 1];
+ if (coord[ndim - 1] == shape[ndim - 1]) {
+ int64_t d = ndim - 1;
+ while (d > 0 && coord[d] == shape[d]) {
+ coord[d] = 0;
+ ++coord[d - 1];
+ --d;
+ }
+ }
+}
+
+template <typename c_index_type, typename c_value_type>
+void ConvertRowMajorTensor(const Tensor& tensor, c_index_type* indices,
+ c_value_type* values, const int64_t size) {
+ const auto ndim = tensor.ndim();
+ const auto& shape = tensor.shape();
+ const c_value_type* tensor_data =
+ reinterpret_cast<const c_value_type*>(tensor.raw_data());
+
+ constexpr c_value_type zero = 0;
+ std::vector<c_index_type> coord(ndim, 0);
+ for (int64_t n = tensor.size(); n > 0; --n) {
+ const c_value_type x = *tensor_data;
+ if (ARROW_PREDICT_FALSE(x != zero)) {
+ std::copy(coord.begin(), coord.end(), indices);
+ *values++ = x;
+ indices += ndim;
+ }
+
+ IncrementRowMajorIndex(coord, shape);
+ ++tensor_data;
+ }
+}
+
+template <typename c_index_type, typename c_value_type>
+void ConvertColumnMajorTensor(const Tensor& tensor, c_index_type* out_indices,
+ c_value_type* out_values, const int64_t size) {
+ const auto ndim = tensor.ndim();
+ std::vector<c_index_type> indices(ndim * size);
+ std::vector<c_value_type> values(size);
+ ConvertRowMajorTensor(tensor, indices.data(), values.data(), size);
+
+ // transpose indices
+ for (int64_t i = 0; i < size; ++i) {
+ for (int j = 0; j < ndim / 2; ++j) {
+ std::swap(indices[i * ndim + j], indices[i * ndim + ndim - j - 1]);
+ }
+ }
+
+ // sort indices
+ std::vector<int64_t> order(size);
+ std::iota(order.begin(), order.end(), 0);
+ std::sort(order.begin(), order.end(), [&](const int64_t xi, const int64_t yi) {
+ const int64_t x_offset = xi * ndim;
+ const int64_t y_offset = yi * ndim;
+ for (int j = 0; j < ndim; ++j) {
+ const auto x = indices[x_offset + j];
+ const auto y = indices[y_offset + j];
+ if (x < y) return true;
+ if (x > y) return false;
+ }
+ return false;
+ });
+
+ // transfer result
+ const auto* indices_data = indices.data();
+ for (int64_t i = 0; i < size; ++i) {
+ out_values[i] = values[i];
+
+ std::copy_n(indices_data, ndim, out_indices);
+ indices_data += ndim;
+ out_indices += ndim;
+ }
+}
+
+template <typename c_index_type, typename c_value_type>
+void ConvertStridedTensor(const Tensor& tensor, c_index_type* indices,
+ c_value_type* values, const int64_t size) {
+ using ValueType = typename CTypeTraits<c_value_type>::ArrowType;
+ const auto& shape = tensor.shape();
+ const auto ndim = tensor.ndim();
+ std::vector<int64_t> coord(ndim, 0);
+
+ constexpr c_value_type zero = 0;
+ c_value_type x;
+ int64_t i;
+ for (int64_t n = tensor.size(); n > 0; --n) {
+ x = tensor.Value<ValueType>(coord);
+ if (ARROW_PREDICT_FALSE(x != zero)) {
+ *values++ = x;
+ for (i = 0; i < ndim; ++i) {
+ *indices++ = static_cast<c_index_type>(coord[i]);
+ }
+ }
+
+ IncrementRowMajorIndex(coord, shape);
+ }
+}
+
+#define CONVERT_TENSOR(func, index_type, value_type, indices, values, size) \
+ func<index_type, value_type>(tensor_, reinterpret_cast<index_type*>(indices), \
+ reinterpret_cast<value_type*>(values), size)
+
+// Using ARROW_EXPAND is necessary to expand __VA_ARGS__ correctly on VC++.
+#define CONVERT_ROW_MAJOR_TENSOR(index_type, value_type, ...) \
+ ARROW_EXPAND(CONVERT_TENSOR(ConvertRowMajorTensor, index_type, value_type, __VA_ARGS__))
+
+#define CONVERT_COLUMN_MAJOR_TENSOR(index_type, value_type, ...) \
+ ARROW_EXPAND( \
+ CONVERT_TENSOR(ConvertColumnMajorTensor, index_type, value_type, __VA_ARGS__))
+
+#define CONVERT_STRIDED_TENSOR(index_type, value_type, ...) \
+ ARROW_EXPAND(CONVERT_TENSOR(ConvertStridedTensor, index_type, value_type, __VA_ARGS__))
+
+// ----------------------------------------------------------------------
+// SparseTensorConverter for SparseCOOIndex
+
+class SparseCOOTensorConverter : private SparseTensorConverterMixin {
+ using SparseTensorConverterMixin::AssignIndex;
+ using SparseTensorConverterMixin::IsNonZero;
+
+ public:
+ SparseCOOTensorConverter(const Tensor& tensor,
+ const std::shared_ptr<DataType>& index_value_type,
+ MemoryPool* pool)
+ : tensor_(tensor), index_value_type_(index_value_type), pool_(pool) {}
+
+ Status Convert() {
+ RETURN_NOT_OK(::arrow::internal::CheckSparseIndexMaximumValue(index_value_type_,
+ tensor_.shape()));
+
+ const int index_elsize = GetByteWidth(*index_value_type_);
+ const int value_elsize = GetByteWidth(*tensor_.type());
+
+ const int64_t ndim = tensor_.ndim();
+ ARROW_ASSIGN_OR_RAISE(int64_t nonzero_count, tensor_.CountNonZero());
+
+ ARROW_ASSIGN_OR_RAISE(auto indices_buffer,
+ AllocateBuffer(index_elsize * ndim * nonzero_count, pool_));
+ uint8_t* indices = indices_buffer->mutable_data();
+
+ ARROW_ASSIGN_OR_RAISE(auto values_buffer,
+ AllocateBuffer(value_elsize * nonzero_count, pool_));
+ uint8_t* values = values_buffer->mutable_data();
+
+ const uint8_t* tensor_data = tensor_.raw_data();
+ if (ndim <= 1) {
+ const int64_t count = ndim == 0 ? 1 : tensor_.shape()[0];
+ for (int64_t i = 0; i < count; ++i) {
+ if (std::any_of(tensor_data, tensor_data + value_elsize, IsNonZero)) {
+ AssignIndex(indices, i, index_elsize);
+ std::copy_n(tensor_data, value_elsize, values);
+
+ indices += index_elsize;
+ values += value_elsize;
+ }
+ tensor_data += value_elsize;
+ }
+ } else if (tensor_.is_row_major()) {
+ DISPATCH(CONVERT_ROW_MAJOR_TENSOR, index_elsize, value_elsize, indices, values,
+ nonzero_count);
+ } else if (tensor_.is_column_major()) {
+ DISPATCH(CONVERT_COLUMN_MAJOR_TENSOR, index_elsize, value_elsize, indices, values,
+ nonzero_count);
+ } else {
+ DISPATCH(CONVERT_STRIDED_TENSOR, index_elsize, value_elsize, indices, values,
+ nonzero_count);
+ }
+
+ // make results
+ const std::vector<int64_t> indices_shape = {nonzero_count, ndim};
+ std::vector<int64_t> indices_strides;
+ RETURN_NOT_OK(internal::ComputeRowMajorStrides(
+ checked_cast<const FixedWidthType&>(*index_value_type_), indices_shape,
+ &indices_strides));
+ auto coords = std::make_shared<Tensor>(index_value_type_, std::move(indices_buffer),
+ indices_shape, indices_strides);
+ ARROW_ASSIGN_OR_RAISE(sparse_index, SparseCOOIndex::Make(coords, true));
+ data = std::move(values_buffer);
+
+ return Status::OK();
+ }
+
+ std::shared_ptr<SparseCOOIndex> sparse_index;
+ std::shared_ptr<Buffer> data;
+
+ private:
+ const Tensor& tensor_;
+ const std::shared_ptr<DataType>& index_value_type_;
+ MemoryPool* pool_;
+};
+
+} // namespace
+
+void SparseTensorConverterMixin::AssignIndex(uint8_t* indices, int64_t val,
+ const int elsize) {
+ switch (elsize) {
+ case 1:
+ *indices = static_cast<uint8_t>(val);
+ break;
+ case 2:
+ *reinterpret_cast<uint16_t*>(indices) = static_cast<uint16_t>(val);
+ break;
+ case 4:
+ *reinterpret_cast<uint32_t*>(indices) = static_cast<uint32_t>(val);
+ break;
+ case 8:
+ *reinterpret_cast<int64_t*>(indices) = val;
+ break;
+ default:
+ break;
+ }
+}
+
+int64_t SparseTensorConverterMixin::GetIndexValue(const uint8_t* value_ptr,
+ const int elsize) {
+ switch (elsize) {
+ case 1:
+ return *value_ptr;
+
+ case 2:
+ return *reinterpret_cast<const uint16_t*>(value_ptr);
+
+ case 4:
+ return *reinterpret_cast<const uint32_t*>(value_ptr);
+
+ case 8:
+ return *reinterpret_cast<const int64_t*>(value_ptr);
+
+ default:
+ return 0;
+ }
+}
+
+Status MakeSparseCOOTensorFromTensor(const Tensor& tensor,
+ const std::shared_ptr<DataType>& index_value_type,
+ MemoryPool* pool,
+ std::shared_ptr<SparseIndex>* out_sparse_index,
+ std::shared_ptr<Buffer>* out_data) {
+ SparseCOOTensorConverter converter(tensor, index_value_type, pool);
+ RETURN_NOT_OK(converter.Convert());
+
+ *out_sparse_index = checked_pointer_cast<SparseIndex>(converter.sparse_index);
+ *out_data = converter.data;
+ return Status::OK();
+}
+
+Result<std::shared_ptr<Tensor>> MakeTensorFromSparseCOOTensor(
+ MemoryPool* pool, const SparseCOOTensor* sparse_tensor) {
+ const auto& sparse_index =
+ checked_cast<const SparseCOOIndex&>(*sparse_tensor->sparse_index());
+ const auto& coords = sparse_index.indices();
+ const auto* coords_data = coords->raw_data();
+
+ const int index_elsize = GetByteWidth(*coords->type());
+
+ const auto& value_type = checked_cast<const FixedWidthType&>(*sparse_tensor->type());
+ const int value_elsize = GetByteWidth(value_type);
+ ARROW_ASSIGN_OR_RAISE(auto values_buffer,
+ AllocateBuffer(value_elsize * sparse_tensor->size(), pool));
+ auto values = values_buffer->mutable_data();
+ std::fill_n(values, value_elsize * sparse_tensor->size(), 0);
+
+ std::vector<int64_t> strides;
+ RETURN_NOT_OK(ComputeRowMajorStrides(value_type, sparse_tensor->shape(), &strides));
+
+ const auto* raw_data = sparse_tensor->raw_data();
+ const int ndim = sparse_tensor->ndim();
+
+ for (int64_t i = 0; i < sparse_tensor->non_zero_length(); ++i) {
+ int64_t offset = 0;
+
+ for (int j = 0; j < ndim; ++j) {
+ auto index = static_cast<int64_t>(
+ SparseTensorConverterMixin::GetIndexValue(coords_data, index_elsize));
+ offset += index * strides[j];
+ coords_data += index_elsize;
+ }
+
+ std::copy_n(raw_data, value_elsize, values + offset);
+ raw_data += value_elsize;
+ }
+
+ return std::make_shared<Tensor>(sparse_tensor->type(), std::move(values_buffer),
+ sparse_tensor->shape(), strides,
+ sparse_tensor->dim_names());
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/tensor/csf_converter.cc b/src/arrow/cpp/src/arrow/tensor/csf_converter.cc
new file mode 100644
index 000000000..77a71d8a1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/tensor/csf_converter.cc
@@ -0,0 +1,289 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/tensor/converter.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/sort.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+class MemoryPool;
+
+namespace internal {
+namespace {
+
+inline void IncrementIndex(std::vector<int64_t>& coord, const std::vector<int64_t>& shape,
+ const std::vector<int64_t>& axis_order) {
+ const int64_t ndim = shape.size();
+ const int64_t last_axis = axis_order[ndim - 1];
+ ++coord[last_axis];
+ if (coord[last_axis] == shape[last_axis]) {
+ int64_t d = ndim - 1;
+ while (d > 0 && coord[axis_order[d]] == shape[axis_order[d]]) {
+ coord[axis_order[d]] = 0;
+ ++coord[axis_order[d - 1]];
+ --d;
+ }
+ }
+}
+
+// ----------------------------------------------------------------------
+// SparseTensorConverter for SparseCSFIndex
+
+class SparseCSFTensorConverter : private SparseTensorConverterMixin {
+ using SparseTensorConverterMixin::AssignIndex;
+ using SparseTensorConverterMixin::IsNonZero;
+
+ public:
+ SparseCSFTensorConverter(const Tensor& tensor,
+ const std::shared_ptr<DataType>& index_value_type,
+ MemoryPool* pool)
+ : tensor_(tensor), index_value_type_(index_value_type), pool_(pool) {}
+
+ Status Convert() {
+ RETURN_NOT_OK(::arrow::internal::CheckSparseIndexMaximumValue(index_value_type_,
+ tensor_.shape()));
+
+ const int index_elsize = GetByteWidth(*index_value_type_);
+ const int value_elsize = GetByteWidth(*tensor_.type());
+
+ const int64_t ndim = tensor_.ndim();
+ // Axis order as ascending order of dimension size is a good heuristic but is not
+ // necessarily optimal.
+ std::vector<int64_t> axis_order = internal::ArgSort(tensor_.shape());
+ ARROW_ASSIGN_OR_RAISE(int64_t nonzero_count, tensor_.CountNonZero());
+
+ ARROW_ASSIGN_OR_RAISE(auto values_buffer,
+ AllocateBuffer(value_elsize * nonzero_count, pool_));
+ auto* values = values_buffer->mutable_data();
+
+ std::vector<int64_t> counts(ndim, 0);
+ std::vector<int64_t> coord(ndim, 0);
+ std::vector<int64_t> previous_coord(ndim, -1);
+ std::vector<BufferBuilder> indptr_buffer_builders(ndim - 1);
+ std::vector<BufferBuilder> indices_buffer_builders(ndim);
+
+ const auto* tensor_data = tensor_.raw_data();
+ uint8_t index_buffer[sizeof(int64_t)];
+
+ if (ndim <= 1) {
+ return Status::NotImplemented("TODO for ndim <= 1");
+ } else {
+ const auto& shape = tensor_.shape();
+ for (int64_t n = tensor_.size(); n > 0; n--) {
+ const auto offset = tensor_.CalculateValueOffset(coord);
+ const auto xp = tensor_data + offset;
+
+ if (std::any_of(xp, xp + value_elsize, IsNonZero)) {
+ bool tree_split = false;
+
+ std::copy_n(xp, value_elsize, values);
+ values += value_elsize;
+
+ for (int64_t i = 0; i < ndim; ++i) {
+ int64_t dimension = axis_order[i];
+
+ tree_split = tree_split || (coord[dimension] != previous_coord[dimension]);
+ if (tree_split) {
+ if (i < ndim - 1) {
+ AssignIndex(index_buffer, counts[i + 1], index_elsize);
+ RETURN_NOT_OK(
+ indptr_buffer_builders[i].Append(index_buffer, index_elsize));
+ }
+
+ AssignIndex(index_buffer, coord[dimension], index_elsize);
+ RETURN_NOT_OK(
+ indices_buffer_builders[i].Append(index_buffer, index_elsize));
+
+ ++counts[i];
+ }
+ }
+
+ previous_coord = coord;
+ }
+
+ IncrementIndex(coord, shape, axis_order);
+ }
+ }
+
+ for (int64_t column = 0; column < ndim - 1; ++column) {
+ AssignIndex(index_buffer, counts[column + 1], index_elsize);
+ RETURN_NOT_OK(indptr_buffer_builders[column].Append(index_buffer, index_elsize));
+ }
+
+ // make results
+ data = std::move(values_buffer);
+
+ std::vector<std::shared_ptr<Buffer>> indptr_buffers(ndim - 1);
+ std::vector<std::shared_ptr<Buffer>> indices_buffers(ndim);
+ std::vector<int64_t> indptr_shapes(counts.begin(), counts.end() - 1);
+ std::vector<int64_t> indices_shapes = counts;
+
+ for (int64_t column = 0; column < ndim; ++column) {
+ RETURN_NOT_OK(
+ indices_buffer_builders[column].Finish(&indices_buffers[column], true));
+ }
+ for (int64_t column = 0; column < ndim - 1; ++column) {
+ RETURN_NOT_OK(indptr_buffer_builders[column].Finish(&indptr_buffers[column], true));
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ sparse_index, SparseCSFIndex::Make(index_value_type_, indices_shapes, axis_order,
+ indptr_buffers, indices_buffers));
+ return Status::OK();
+ }
+
+ std::shared_ptr<SparseCSFIndex> sparse_index;
+ std::shared_ptr<Buffer> data;
+
+ private:
+ const Tensor& tensor_;
+ const std::shared_ptr<DataType>& index_value_type_;
+ MemoryPool* pool_;
+};
+
+class TensorBuilderFromSparseCSFTensor : private SparseTensorConverterMixin {
+ using SparseTensorConverterMixin::GetIndexValue;
+
+ MemoryPool* pool_;
+ const SparseCSFTensor* sparse_tensor_;
+ const SparseCSFIndex* sparse_index_;
+ const std::vector<std::shared_ptr<Tensor>>& indptr_;
+ const std::vector<std::shared_ptr<Tensor>>& indices_;
+ const std::vector<int64_t>& axis_order_;
+ const std::vector<int64_t>& shape_;
+ const int64_t non_zero_length_;
+ const int ndim_;
+ const int64_t tensor_size_;
+ const FixedWidthType& value_type_;
+ const int value_elsize_;
+ const uint8_t* raw_data_;
+ std::vector<int64_t> strides_;
+ std::shared_ptr<Buffer> values_buffer_;
+ uint8_t* values_;
+
+ public:
+ TensorBuilderFromSparseCSFTensor(const SparseCSFTensor* sparse_tensor, MemoryPool* pool)
+ : pool_(pool),
+ sparse_tensor_(sparse_tensor),
+ sparse_index_(
+ checked_cast<const SparseCSFIndex*>(sparse_tensor->sparse_index().get())),
+ indptr_(sparse_index_->indptr()),
+ indices_(sparse_index_->indices()),
+ axis_order_(sparse_index_->axis_order()),
+ shape_(sparse_tensor->shape()),
+ non_zero_length_(sparse_tensor->non_zero_length()),
+ ndim_(sparse_tensor->ndim()),
+ tensor_size_(sparse_tensor->size()),
+ value_type_(checked_cast<const FixedWidthType&>(*sparse_tensor->type())),
+ value_elsize_(GetByteWidth(value_type_)),
+ raw_data_(sparse_tensor->raw_data()) {}
+
+ int ElementSize(const std::shared_ptr<Tensor>& tensor) const {
+ return GetByteWidth(*tensor->type());
+ }
+
+ Result<std::shared_ptr<Tensor>> Build() {
+ RETURN_NOT_OK(internal::ComputeRowMajorStrides(value_type_, shape_, &strides_));
+
+ ARROW_ASSIGN_OR_RAISE(values_buffer_,
+ AllocateBuffer(value_elsize_ * tensor_size_, pool_));
+ values_ = values_buffer_->mutable_data();
+ std::fill_n(values_, value_elsize_ * tensor_size_, 0);
+
+ const int64_t start = 0;
+ const int64_t stop = indptr_[0]->size() - 1;
+ ExpandValues(0, 0, start, stop);
+
+ return std::make_shared<Tensor>(sparse_tensor_->type(), std::move(values_buffer_),
+ shape_, strides_, sparse_tensor_->dim_names());
+ }
+
+ void ExpandValues(const int64_t dim, const int64_t dim_offset, const int64_t start,
+ const int64_t stop) {
+ const auto& cur_indices = indices_[dim];
+ const int indices_elsize = ElementSize(cur_indices);
+ const auto* indices_data = cur_indices->raw_data() + start * indices_elsize;
+
+ if (dim == ndim_ - 1) {
+ for (auto i = start; i < stop; ++i) {
+ const int64_t index =
+ SparseTensorConverterMixin::GetIndexValue(indices_data, indices_elsize);
+ const int64_t offset = dim_offset + index * strides_[axis_order_[dim]];
+
+ std::copy_n(raw_data_ + i * value_elsize_, value_elsize_, values_ + offset);
+
+ indices_data += indices_elsize;
+ }
+ } else {
+ const auto& cur_indptr = indptr_[dim];
+ const int indptr_elsize = ElementSize(cur_indptr);
+ const auto* indptr_data = cur_indptr->raw_data() + start * indptr_elsize;
+
+ for (int64_t i = start; i < stop; ++i) {
+ const int64_t index =
+ SparseTensorConverterMixin::GetIndexValue(indices_data, indices_elsize);
+ const int64_t offset = dim_offset + index * strides_[axis_order_[dim]];
+ const int64_t next_start = GetIndexValue(indptr_data, indptr_elsize);
+ const int64_t next_stop =
+ GetIndexValue(indptr_data + indptr_elsize, indptr_elsize);
+
+ ExpandValues(dim + 1, offset, next_start, next_stop);
+
+ indices_data += indices_elsize;
+ indptr_data += indptr_elsize;
+ }
+ }
+ }
+};
+
+} // namespace
+
+Status MakeSparseCSFTensorFromTensor(const Tensor& tensor,
+ const std::shared_ptr<DataType>& index_value_type,
+ MemoryPool* pool,
+ std::shared_ptr<SparseIndex>* out_sparse_index,
+ std::shared_ptr<Buffer>* out_data) {
+ SparseCSFTensorConverter converter(tensor, index_value_type, pool);
+ RETURN_NOT_OK(converter.Convert());
+
+ *out_sparse_index = checked_pointer_cast<SparseIndex>(converter.sparse_index);
+ *out_data = converter.data;
+ return Status::OK();
+}
+
+Result<std::shared_ptr<Tensor>> MakeTensorFromSparseCSFTensor(
+ MemoryPool* pool, const SparseCSFTensor* sparse_tensor) {
+ TensorBuilderFromSparseCSFTensor builder(sparse_tensor, pool);
+ return builder.Build();
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/tensor/csx_converter.cc b/src/arrow/cpp/src/arrow/tensor/csx_converter.cc
new file mode 100644
index 000000000..137b5d320
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/tensor/csx_converter.cc
@@ -0,0 +1,241 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/tensor/converter.h"
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+class MemoryPool;
+
+namespace internal {
+namespace {
+
+// ----------------------------------------------------------------------
+// SparseTensorConverter for SparseCSRIndex
+
+class SparseCSXMatrixConverter : private SparseTensorConverterMixin {
+ using SparseTensorConverterMixin::AssignIndex;
+ using SparseTensorConverterMixin::IsNonZero;
+
+ public:
+ SparseCSXMatrixConverter(SparseMatrixCompressedAxis axis, const Tensor& tensor,
+ const std::shared_ptr<DataType>& index_value_type,
+ MemoryPool* pool)
+ : axis_(axis), tensor_(tensor), index_value_type_(index_value_type), pool_(pool) {}
+
+ Status Convert() {
+ RETURN_NOT_OK(::arrow::internal::CheckSparseIndexMaximumValue(index_value_type_,
+ tensor_.shape()));
+
+ const int index_elsize = GetByteWidth(*index_value_type_);
+ const int value_elsize = GetByteWidth(*tensor_.type());
+
+ const int64_t ndim = tensor_.ndim();
+ if (ndim > 2) {
+ return Status::Invalid("Invalid tensor dimension");
+ }
+
+ const int major_axis = static_cast<int>(axis_);
+ const int64_t n_major = tensor_.shape()[major_axis];
+ const int64_t n_minor = tensor_.shape()[1 - major_axis];
+ ARROW_ASSIGN_OR_RAISE(int64_t nonzero_count, tensor_.CountNonZero());
+
+ std::shared_ptr<Buffer> indptr_buffer;
+ std::shared_ptr<Buffer> indices_buffer;
+
+ ARROW_ASSIGN_OR_RAISE(auto values_buffer,
+ AllocateBuffer(value_elsize * nonzero_count, pool_));
+ auto* values = values_buffer->mutable_data();
+
+ const auto* tensor_data = tensor_.raw_data();
+
+ if (ndim <= 1) {
+ return Status::NotImplemented("TODO for ndim <= 1");
+ } else {
+ ARROW_ASSIGN_OR_RAISE(indptr_buffer,
+ AllocateBuffer(index_elsize * (n_major + 1), pool_));
+ auto* indptr = indptr_buffer->mutable_data();
+
+ ARROW_ASSIGN_OR_RAISE(indices_buffer,
+ AllocateBuffer(index_elsize * nonzero_count, pool_));
+ auto* indices = indices_buffer->mutable_data();
+
+ std::vector<int64_t> coords(2);
+ int64_t k = 0;
+ std::fill_n(indptr, index_elsize, 0);
+ indptr += index_elsize;
+ for (int64_t i = 0; i < n_major; ++i) {
+ for (int64_t j = 0; j < n_minor; ++j) {
+ if (axis_ == SparseMatrixCompressedAxis::ROW) {
+ coords = {i, j};
+ } else {
+ coords = {j, i};
+ }
+ const int64_t offset = tensor_.CalculateValueOffset(coords);
+ if (std::any_of(tensor_data + offset, tensor_data + offset + value_elsize,
+ IsNonZero)) {
+ std::copy_n(tensor_data + offset, value_elsize, values);
+ values += value_elsize;
+
+ AssignIndex(indices, j, index_elsize);
+ indices += index_elsize;
+
+ k++;
+ }
+ }
+ AssignIndex(indptr, k, index_elsize);
+ indptr += index_elsize;
+ }
+ }
+
+ std::vector<int64_t> indptr_shape({n_major + 1});
+ std::shared_ptr<Tensor> indptr_tensor =
+ std::make_shared<Tensor>(index_value_type_, indptr_buffer, indptr_shape);
+
+ std::vector<int64_t> indices_shape({nonzero_count});
+ std::shared_ptr<Tensor> indices_tensor =
+ std::make_shared<Tensor>(index_value_type_, indices_buffer, indices_shape);
+
+ if (axis_ == SparseMatrixCompressedAxis::ROW) {
+ sparse_index = std::make_shared<SparseCSRIndex>(indptr_tensor, indices_tensor);
+ } else {
+ sparse_index = std::make_shared<SparseCSCIndex>(indptr_tensor, indices_tensor);
+ }
+ data = std::move(values_buffer);
+
+ return Status::OK();
+ }
+
+ std::shared_ptr<SparseIndex> sparse_index;
+ std::shared_ptr<Buffer> data;
+
+ private:
+ SparseMatrixCompressedAxis axis_;
+ const Tensor& tensor_;
+ const std::shared_ptr<DataType>& index_value_type_;
+ MemoryPool* pool_;
+};
+
+} // namespace
+
+Status MakeSparseCSXMatrixFromTensor(SparseMatrixCompressedAxis axis,
+ const Tensor& tensor,
+ const std::shared_ptr<DataType>& index_value_type,
+ MemoryPool* pool,
+ std::shared_ptr<SparseIndex>* out_sparse_index,
+ std::shared_ptr<Buffer>* out_data) {
+ SparseCSXMatrixConverter converter(axis, tensor, index_value_type, pool);
+ RETURN_NOT_OK(converter.Convert());
+
+ *out_sparse_index = converter.sparse_index;
+ *out_data = converter.data;
+ return Status::OK();
+}
+
+Result<std::shared_ptr<Tensor>> MakeTensorFromSparseCSXMatrix(
+ SparseMatrixCompressedAxis axis, MemoryPool* pool,
+ const std::shared_ptr<Tensor>& indptr, const std::shared_ptr<Tensor>& indices,
+ const int64_t non_zero_length, const std::shared_ptr<DataType>& value_type,
+ const std::vector<int64_t>& shape, const int64_t tensor_size, const uint8_t* raw_data,
+ const std::vector<std::string>& dim_names) {
+ const auto* indptr_data = indptr->raw_data();
+ const auto* indices_data = indices->raw_data();
+
+ const int indptr_elsize = GetByteWidth(*indptr->type());
+ const int indices_elsize = GetByteWidth(*indices->type());
+
+ const auto& fw_value_type = checked_cast<const FixedWidthType&>(*value_type);
+ const int value_elsize = GetByteWidth(fw_value_type);
+ ARROW_ASSIGN_OR_RAISE(auto values_buffer,
+ AllocateBuffer(value_elsize * tensor_size, pool));
+ auto values = values_buffer->mutable_data();
+ std::fill_n(values, value_elsize * tensor_size, 0);
+
+ std::vector<int64_t> strides;
+ RETURN_NOT_OK(ComputeRowMajorStrides(fw_value_type, shape, &strides));
+
+ const auto nc = shape[1];
+
+ int64_t offset = 0;
+ for (int64_t i = 0; i < indptr->size() - 1; ++i) {
+ const auto start =
+ SparseTensorConverterMixin::GetIndexValue(indptr_data, indptr_elsize);
+ const auto stop = SparseTensorConverterMixin::GetIndexValue(
+ indptr_data + indptr_elsize, indptr_elsize);
+
+ for (int64_t j = start; j < stop; ++j) {
+ const auto index = SparseTensorConverterMixin::GetIndexValue(
+ indices_data + j * indices_elsize, indices_elsize);
+ switch (axis) {
+ case SparseMatrixCompressedAxis::ROW:
+ offset = (index + i * nc) * value_elsize;
+ break;
+ case SparseMatrixCompressedAxis::COLUMN:
+ offset = (i + index * nc) * value_elsize;
+ break;
+ }
+
+ std::copy_n(raw_data, value_elsize, values + offset);
+ raw_data += value_elsize;
+ }
+
+ indptr_data += indptr_elsize;
+ }
+
+ return std::make_shared<Tensor>(value_type, std::move(values_buffer), shape, strides,
+ dim_names);
+}
+
+Result<std::shared_ptr<Tensor>> MakeTensorFromSparseCSRMatrix(
+ MemoryPool* pool, const SparseCSRMatrix* sparse_tensor) {
+ const auto& sparse_index =
+ internal::checked_cast<const SparseCSRIndex&>(*sparse_tensor->sparse_index());
+ const auto& indptr = sparse_index.indptr();
+ const auto& indices = sparse_index.indices();
+ const auto non_zero_length = sparse_tensor->non_zero_length();
+ return MakeTensorFromSparseCSXMatrix(
+ SparseMatrixCompressedAxis::ROW, pool, indptr, indices, non_zero_length,
+ sparse_tensor->type(), sparse_tensor->shape(), sparse_tensor->size(),
+ sparse_tensor->raw_data(), sparse_tensor->dim_names());
+}
+
+Result<std::shared_ptr<Tensor>> MakeTensorFromSparseCSCMatrix(
+ MemoryPool* pool, const SparseCSCMatrix* sparse_tensor) {
+ const auto& sparse_index =
+ internal::checked_cast<const SparseCSCIndex&>(*sparse_tensor->sparse_index());
+ const auto& indptr = sparse_index.indptr();
+ const auto& indices = sparse_index.indices();
+ const auto non_zero_length = sparse_tensor->non_zero_length();
+ return MakeTensorFromSparseCSXMatrix(
+ SparseMatrixCompressedAxis::COLUMN, pool, indptr, indices, non_zero_length,
+ sparse_tensor->type(), sparse_tensor->shape(), sparse_tensor->size(),
+ sparse_tensor->raw_data(), sparse_tensor->dim_names());
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/tensor/tensor_conversion_benchmark.cc b/src/arrow/cpp/src/arrow/tensor/tensor_conversion_benchmark.cc
new file mode 100644
index 000000000..8456b2c4e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/tensor/tensor_conversion_benchmark.cc
@@ -0,0 +1,230 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/sparse_tensor.h"
+#include "arrow/testing/gtest_util.h"
+
+#include <random>
+
+namespace arrow {
+
+enum ContiguousType { ROW_MAJOR, COLUMN_MAJOR, STRIDED };
+
+template <ContiguousType contiguous_type, typename ValueType, typename IndexType>
+class TensorConversionFixture : public benchmark::Fixture {
+ protected:
+ using c_value_type = typename ValueType::c_type;
+ using c_index_type = typename IndexType::c_type;
+
+ std::shared_ptr<DataType> value_type_ = TypeTraits<ValueType>::type_singleton();
+ std::shared_ptr<DataType> index_type_ = TypeTraits<IndexType>::type_singleton();
+
+ std::vector<c_value_type> values_;
+ std::shared_ptr<Tensor> tensor_;
+
+ public:
+ void SetUp(const ::benchmark::State& state) {
+ std::vector<int64_t> shape = {30, 8, 20, 9};
+ auto n = std::accumulate(shape.begin(), shape.end(), int64_t(1),
+ [](int64_t acc, int64_t i) { return acc * i; });
+ auto m = n / 100;
+
+ switch (contiguous_type) {
+ case STRIDED:
+ values_.resize(2 * n);
+ for (int64_t i = 0; i < 100; ++i) {
+ values_[2 * i * m] = static_cast<c_value_type>(i);
+ }
+ break;
+ default:
+ values_.resize(n);
+ for (int64_t i = 0; i < 100; ++i) {
+ values_[i * m] = static_cast<c_value_type>(i);
+ }
+ break;
+ }
+
+ std::vector<int64_t> strides;
+ int64_t total = sizeof(c_value_type);
+ switch (contiguous_type) {
+ case ROW_MAJOR:
+ break;
+ case COLUMN_MAJOR: {
+ for (auto i : shape) {
+ strides.push_back(total);
+ total *= i;
+ }
+ break;
+ }
+ case STRIDED: {
+ total *= 2;
+ for (auto i : shape) {
+ strides.push_back(total);
+ total *= i;
+ }
+ break;
+ }
+ }
+ ABORT_NOT_OK(
+ Tensor::Make(value_type_, Buffer::Wrap(values_), shape, strides).Value(&tensor_));
+ }
+
+ void SetUpRowMajor() {}
+};
+
+template <ContiguousType contiguous_type, typename ValueType, typename IndexType>
+class MatrixConversionFixture : public benchmark::Fixture {
+ protected:
+ using c_value_type = typename ValueType::c_type;
+ using c_index_type = typename IndexType::c_type;
+
+ std::shared_ptr<DataType> value_type_ = TypeTraits<ValueType>::type_singleton();
+ std::shared_ptr<DataType> index_type_ = TypeTraits<IndexType>::type_singleton();
+
+ std::vector<c_value_type> values_;
+ std::shared_ptr<Tensor> tensor_;
+
+ public:
+ void SetUp(const ::benchmark::State& state) {
+ std::vector<int64_t> shape = {88, 113};
+ auto n = std::accumulate(shape.begin(), shape.end(), int64_t(1),
+ [](int64_t acc, int64_t i) { return acc * i; });
+ auto m = n / 100;
+
+ switch (contiguous_type) {
+ case STRIDED:
+ values_.resize(2 * n);
+ for (int64_t i = 0; i < 100; ++i) {
+ values_[2 * i * m] = static_cast<c_value_type>(i);
+ }
+ break;
+ default:
+ values_.resize(n);
+ for (int64_t i = 0; i < 100; ++i) {
+ values_[i * m] = static_cast<c_value_type>(i);
+ }
+ break;
+ }
+
+ std::vector<int64_t> strides;
+ int64_t total = sizeof(c_value_type);
+ switch (contiguous_type) {
+ case ROW_MAJOR:
+ break;
+ case COLUMN_MAJOR: {
+ for (auto i : shape) {
+ strides.push_back(total);
+ total *= i;
+ }
+ break;
+ }
+ case STRIDED: {
+ total *= 2;
+ for (auto i : shape) {
+ strides.push_back(total);
+ total *= i;
+ }
+ break;
+ }
+ }
+ ABORT_NOT_OK(Tensor::Make(value_type_, Buffer::Wrap(values_), shape).Value(&tensor_));
+ }
+};
+
+#define DEFINE_TYPED_TENSOR_CONVERSION_FIXTURE(value_type_name) \
+ template <typename IndexType> \
+ using value_type_name##RowMajorTensorConversionFixture = \
+ TensorConversionFixture<ROW_MAJOR, value_type_name##Type, IndexType>; \
+ template <typename IndexType> \
+ using value_type_name##ColumnMajorTensorConversionFixture = \
+ TensorConversionFixture<COLUMN_MAJOR, value_type_name##Type, IndexType>; \
+ template <typename IndexType> \
+ using value_type_name##StridedTensorConversionFixture = \
+ TensorConversionFixture<STRIDED, value_type_name##Type, IndexType>
+
+DEFINE_TYPED_TENSOR_CONVERSION_FIXTURE(Int8);
+DEFINE_TYPED_TENSOR_CONVERSION_FIXTURE(Float);
+DEFINE_TYPED_TENSOR_CONVERSION_FIXTURE(Double);
+
+#define DEFINE_TYPED_MATRIX_CONVERSION_FIXTURE(value_type_name) \
+ template <typename IndexType> \
+ using value_type_name##RowMajorMatrixConversionFixture = \
+ MatrixConversionFixture<ROW_MAJOR, value_type_name##Type, IndexType>; \
+ template <typename IndexType> \
+ using value_type_name##ColumnMajorMatrixConversionFixture = \
+ MatrixConversionFixture<COLUMN_MAJOR, value_type_name##Type, IndexType>; \
+ template <typename IndexType> \
+ using value_type_name##StridedMatrixConversionFixture = \
+ MatrixConversionFixture<STRIDED, value_type_name##Type, IndexType>
+
+DEFINE_TYPED_MATRIX_CONVERSION_FIXTURE(Int8);
+DEFINE_TYPED_MATRIX_CONVERSION_FIXTURE(Float);
+DEFINE_TYPED_MATRIX_CONVERSION_FIXTURE(Double);
+
+#define BENCHMARK_CONVERT_TENSOR_(Contiguous, kind, format, value_type_name, \
+ index_type_name) \
+ BENCHMARK_TEMPLATE_F(value_type_name##Contiguous##kind##ConversionFixture, \
+ ConvertToSparse##format##kind##index_type_name, \
+ index_type_name##Type) \
+ (benchmark::State & state) { /* NOLINT non-const reference */ \
+ std::shared_ptr<Sparse##format##kind> sparse_tensor; \
+ for (auto _ : state) { \
+ ABORT_NOT_OK(Sparse##format##kind::Make(*this->tensor_, this->index_type_) \
+ .Value(&sparse_tensor)); \
+ } \
+ benchmark::DoNotOptimize(sparse_tensor); \
+ state.SetItemsProcessed(state.iterations() * this->tensor_->size()); \
+ state.SetBytesProcessed(state.iterations() * this->tensor_->data()->size()); \
+ }
+
+#define BENCHMARK_CONVERT_TENSOR(kind, format, value_type_name, index_type_name) \
+ BENCHMARK_CONVERT_TENSOR_(RowMajor, kind, format, value_type_name, index_type_name); \
+ BENCHMARK_CONVERT_TENSOR_(ColumnMajor, kind, format, value_type_name, \
+ index_type_name); \
+ BENCHMARK_CONVERT_TENSOR_(Strided, kind, format, value_type_name, index_type_name)
+
+BENCHMARK_CONVERT_TENSOR(Tensor, COO, Int8, Int32);
+BENCHMARK_CONVERT_TENSOR(Tensor, COO, Int8, Int64);
+BENCHMARK_CONVERT_TENSOR(Tensor, COO, Float, Int32);
+BENCHMARK_CONVERT_TENSOR(Tensor, COO, Float, Int64);
+BENCHMARK_CONVERT_TENSOR(Tensor, COO, Double, Int32);
+BENCHMARK_CONVERT_TENSOR(Tensor, COO, Double, Int64);
+
+BENCHMARK_CONVERT_TENSOR(Matrix, CSR, Int8, Int8);
+BENCHMARK_CONVERT_TENSOR(Matrix, CSR, Int8, Int16);
+BENCHMARK_CONVERT_TENSOR(Matrix, CSR, Float, Int32);
+BENCHMARK_CONVERT_TENSOR(Matrix, CSR, Float, Int64);
+BENCHMARK_CONVERT_TENSOR(Matrix, CSR, Double, Int32);
+BENCHMARK_CONVERT_TENSOR(Matrix, CSR, Double, Int64);
+
+BENCHMARK_CONVERT_TENSOR(Matrix, CSC, Int8, Int32);
+BENCHMARK_CONVERT_TENSOR(Matrix, CSC, Int8, Int64);
+BENCHMARK_CONVERT_TENSOR(Matrix, CSC, Float, Int32);
+BENCHMARK_CONVERT_TENSOR(Matrix, CSC, Float, Int64);
+BENCHMARK_CONVERT_TENSOR(Matrix, CSC, Double, Int32);
+BENCHMARK_CONVERT_TENSOR(Matrix, CSC, Double, Int64);
+
+BENCHMARK_CONVERT_TENSOR(Tensor, CSF, Int8, Int32);
+BENCHMARK_CONVERT_TENSOR(Tensor, CSF, Int8, Int64);
+BENCHMARK_CONVERT_TENSOR(Tensor, CSF, Float, Int32);
+BENCHMARK_CONVERT_TENSOR(Tensor, CSF, Float, Int64);
+BENCHMARK_CONVERT_TENSOR(Tensor, CSF, Double, Int32);
+BENCHMARK_CONVERT_TENSOR(Tensor, CSF, Double, Int64);
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/tensor_test.cc b/src/arrow/cpp/src/arrow/tensor_test.cc
new file mode 100644
index 000000000..efb1b8d92
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/tensor_test.cc
@@ -0,0 +1,749 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Unit tests for DataType (and subclasses), Field, and Schema
+
+#include <cmath>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
+
+#include "arrow/buffer.h"
+#include "arrow/tensor.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+
+namespace arrow {
+
+void AssertCountNonZero(const Tensor& t, int64_t expected) {
+ ASSERT_OK_AND_ASSIGN(int64_t count, t.CountNonZero());
+ ASSERT_EQ(count, expected);
+}
+
+TEST(TestComputeRowMajorStrides, ZeroDimension) {
+ std::vector<int64_t> strides;
+
+ std::vector<int64_t> shape1 = {0, 2, 3};
+ ASSERT_OK(arrow::internal::ComputeRowMajorStrides(DoubleType(), shape1, &strides));
+ EXPECT_THAT(strides,
+ testing::ElementsAre(sizeof(double), sizeof(double), sizeof(double)));
+
+ std::vector<int64_t> shape2 = {2, 0, 3};
+ strides.clear();
+ ASSERT_OK(arrow::internal::ComputeRowMajorStrides(DoubleType(), shape2, &strides));
+ EXPECT_THAT(strides,
+ testing::ElementsAre(sizeof(double), sizeof(double), sizeof(double)));
+
+ std::vector<int64_t> shape3 = {2, 3, 0};
+ strides.clear();
+ ASSERT_OK(arrow::internal::ComputeRowMajorStrides(DoubleType(), shape3, &strides));
+ EXPECT_THAT(strides,
+ testing::ElementsAre(sizeof(double), sizeof(double), sizeof(double)));
+}
+
+TEST(TestComputeRowMajorStrides, MaximumSize) {
+ constexpr uint64_t total_length =
+ 1 + static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
+ std::vector<int64_t> shape = {2, 2, static_cast<int64_t>(total_length / 4)};
+
+ std::vector<int64_t> strides;
+ ASSERT_OK(arrow::internal::ComputeRowMajorStrides(Int8Type(), shape, &strides));
+ EXPECT_THAT(strides, testing::ElementsAre(2 * shape[2], shape[2], 1));
+}
+
+TEST(TestComputeRowMajorStrides, OverflowCase) {
+ constexpr uint64_t total_length =
+ 1 + static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
+ std::vector<int64_t> shape = {2, 2, static_cast<int64_t>(total_length / 4)};
+
+ std::vector<int64_t> strides;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr(
+ "Row-major strides computed from shape would not fit in 64-bit integer"),
+ arrow::internal::ComputeRowMajorStrides(Int16Type(), shape, &strides));
+ EXPECT_EQ(0, strides.size());
+}
+
+TEST(TestComputeColumnMajorStrides, ZeroDimension) {
+ std::vector<int64_t> strides;
+
+ std::vector<int64_t> shape1 = {0, 2, 3};
+ ASSERT_OK(arrow::internal::ComputeColumnMajorStrides(DoubleType(), shape1, &strides));
+ EXPECT_THAT(strides,
+ testing::ElementsAre(sizeof(double), sizeof(double), sizeof(double)));
+
+ std::vector<int64_t> shape2 = {2, 0, 3};
+ strides.clear();
+ ASSERT_OK(arrow::internal::ComputeColumnMajorStrides(DoubleType(), shape2, &strides));
+ EXPECT_THAT(strides,
+ testing::ElementsAre(sizeof(double), sizeof(double), sizeof(double)));
+
+ std::vector<int64_t> shape3 = {2, 3, 0};
+ strides.clear();
+ ASSERT_OK(arrow::internal::ComputeColumnMajorStrides(DoubleType(), shape3, &strides));
+ EXPECT_THAT(strides,
+ testing::ElementsAre(sizeof(double), sizeof(double), sizeof(double)));
+}
+
+TEST(TestComputeColumnMajorStrides, MaximumSize) {
+ constexpr uint64_t total_length =
+ 1 + static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
+ std::vector<int64_t> shape = {static_cast<int64_t>(total_length / 4), 2, 2};
+
+ std::vector<int64_t> strides;
+ ASSERT_OK(arrow::internal::ComputeColumnMajorStrides(Int8Type(), shape, &strides));
+ EXPECT_THAT(strides, testing::ElementsAre(1, shape[0], 2 * shape[0]));
+}
+
+TEST(TestComputeColumnMajorStrides, OverflowCase) {
+ constexpr uint64_t total_length =
+ 1 + static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
+ std::vector<int64_t> shape = {static_cast<int64_t>(total_length / 4), 2, 2};
+
+ std::vector<int64_t> strides;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr(
+ "Column-major strides computed from shape would not fit in 64-bit integer"),
+ arrow::internal::ComputeColumnMajorStrides(Int16Type(), shape, &strides));
+ EXPECT_EQ(0, strides.size());
+}
+
+TEST(TestTensor, MakeRowMajor) {
+ std::vector<int64_t> shape = {3, 6};
+ std::vector<int64_t> strides = {sizeof(double) * 6, sizeof(double)};
+ std::vector<double> values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ auto data = Buffer::Wrap(values);
+
+ // without strides and dim_names
+ std::shared_ptr<Tensor> tensor1;
+ ASSERT_OK_AND_ASSIGN(tensor1, Tensor::Make(float64(), data, shape));
+ EXPECT_EQ(float64(), tensor1->type());
+ EXPECT_EQ(shape, tensor1->shape());
+ EXPECT_EQ(strides, tensor1->strides());
+ EXPECT_EQ(std::vector<std::string>{}, tensor1->dim_names());
+ EXPECT_EQ(data->data(), tensor1->raw_data());
+ EXPECT_TRUE(tensor1->is_row_major());
+ EXPECT_FALSE(tensor1->is_column_major());
+ EXPECT_TRUE(tensor1->is_contiguous());
+
+ // without dim_names
+ std::shared_ptr<Tensor> tensor2;
+ ASSERT_OK_AND_ASSIGN(tensor2, Tensor::Make(float64(), data, shape, strides));
+ EXPECT_EQ(float64(), tensor2->type());
+ EXPECT_EQ(shape, tensor2->shape());
+ EXPECT_EQ(strides, tensor2->strides());
+ EXPECT_EQ(std::vector<std::string>{}, tensor2->dim_names());
+ EXPECT_EQ(data->data(), tensor2->raw_data());
+ EXPECT_TRUE(tensor2->Equals(*tensor1));
+ EXPECT_TRUE(tensor2->is_row_major());
+ EXPECT_FALSE(tensor2->is_column_major());
+ EXPECT_TRUE(tensor2->is_contiguous());
+
+ // without strides
+ std::vector<std::string> dim_names = {"foo", "bar"};
+ std::shared_ptr<Tensor> tensor3;
+ ASSERT_OK_AND_ASSIGN(tensor3, Tensor::Make(float64(), data, shape, {}, dim_names));
+ EXPECT_EQ(float64(), tensor3->type());
+ EXPECT_EQ(shape, tensor3->shape());
+ EXPECT_EQ(strides, tensor3->strides());
+ EXPECT_EQ(dim_names, tensor3->dim_names());
+ EXPECT_EQ(data->data(), tensor3->raw_data());
+ EXPECT_TRUE(tensor3->Equals(*tensor1));
+ EXPECT_TRUE(tensor3->Equals(*tensor2));
+
+ // supply all parameters
+ std::shared_ptr<Tensor> tensor4;
+ ASSERT_OK_AND_ASSIGN(tensor4, Tensor::Make(float64(), data, shape, strides, dim_names));
+ EXPECT_EQ(float64(), tensor4->type());
+ EXPECT_EQ(shape, tensor4->shape());
+ EXPECT_EQ(strides, tensor4->strides());
+ EXPECT_EQ(dim_names, tensor4->dim_names());
+ EXPECT_EQ(data->data(), tensor4->raw_data());
+ EXPECT_TRUE(tensor4->Equals(*tensor1));
+ EXPECT_TRUE(tensor4->Equals(*tensor2));
+ EXPECT_TRUE(tensor4->Equals(*tensor3));
+}
+
+TEST(TestTensor, MakeColumnMajor) {
+ std::vector<int64_t> shape = {3, 6};
+ std::vector<int64_t> strides = {sizeof(double), sizeof(double) * 3};
+ std::vector<double> values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ auto data = Buffer::Wrap(values);
+
+ std::shared_ptr<Tensor> tensor;
+ ASSERT_OK_AND_ASSIGN(tensor, Tensor::Make(float64(), data, shape, strides));
+ EXPECT_FALSE(tensor->is_row_major());
+ EXPECT_TRUE(tensor->is_column_major());
+ EXPECT_TRUE(tensor->is_contiguous());
+}
+
+TEST(TestTensor, MakeStrided) {
+ std::vector<int64_t> shape = {3, 6};
+ std::vector<int64_t> strides = {sizeof(double) * 12, sizeof(double) * 2};
+ std::vector<double> values = {1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0,
+ 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0};
+ auto data = Buffer::Wrap(values);
+
+ std::shared_ptr<Tensor> tensor;
+ ASSERT_OK_AND_ASSIGN(tensor, Tensor::Make(float64(), data, shape, strides));
+ EXPECT_FALSE(tensor->is_row_major());
+ EXPECT_FALSE(tensor->is_column_major());
+ EXPECT_FALSE(tensor->is_contiguous());
+}
+
+TEST(TestTensor, MakeZeroDim) {
+ std::vector<int64_t> shape = {};
+ std::vector<double> values = {355 / 113.0};
+ auto data = Buffer::Wrap(values);
+ std::shared_ptr<Tensor> tensor;
+
+ ASSERT_OK_AND_ASSIGN(tensor, Tensor::Make(float64(), data, shape));
+ EXPECT_EQ(1, tensor->size());
+ EXPECT_EQ(shape, tensor->shape());
+ EXPECT_EQ(shape, tensor->strides());
+ EXPECT_EQ(data->data(), tensor->raw_data());
+ EXPECT_EQ(values[0], tensor->Value<DoubleType>({}));
+}
+
+TEST(TestTensor, MakeFailureCases) {
+ std::vector<int64_t> shape = {3, 6};
+ std::vector<double> values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ auto data = Buffer::Wrap(values);
+
+ // null type
+ ASSERT_RAISES(Invalid, Tensor::Make(nullptr, data, shape));
+
+ // invalid type
+ ASSERT_RAISES(Invalid, Tensor::Make(binary(), data, shape));
+
+ // null data
+ ASSERT_RAISES(Invalid, Tensor::Make(float64(), nullptr, shape));
+
+ // negative items in shape
+ ASSERT_RAISES(Invalid, Tensor::Make(float64(), data, {-3, 6}));
+
+ // overflow in positive strides computation
+ constexpr uint64_t total_length =
+ 1 + static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr(
+ "Row-major strides computed from shape would not fit in 64-bit integer"),
+ Tensor::Make(float64(), data, {2, 2, static_cast<int64_t>(total_length / 4)}));
+
+ // negative strides are prohibited
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, testing::HasSubstr("negative strides not supported"),
+ Tensor::Make(float64(), data, {18}, {-(int)sizeof(double)}));
+
+ // invalid stride length
+ ASSERT_RAISES(Invalid, Tensor::Make(float64(), data, shape, {sizeof(double)}));
+ ASSERT_RAISES(Invalid, Tensor::Make(float64(), data, shape,
+ {sizeof(double), sizeof(double), sizeof(double)}));
+
+ // invalid stride values to involve buffer over run
+ ASSERT_RAISES(Invalid, Tensor::Make(float64(), data, shape,
+ {sizeof(double) * 6, sizeof(double) * 2}));
+ ASSERT_RAISES(Invalid, Tensor::Make(float64(), data, shape,
+ {sizeof(double) * 12, sizeof(double)}));
+
+ // too many dim_names are supplied
+ ASSERT_RAISES(Invalid, Tensor::Make(float64(), data, shape, {}, {"foo", "bar", "baz"}));
+}
+
+TEST(TestTensor, ZeroDim) {
+ const int64_t values = 1;
+ std::vector<int64_t> shape = {};
+
+ using T = int64_t;
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> buffer,
+ AllocateBuffer(values * sizeof(T)));
+
+ Tensor t0(int64(), buffer, shape);
+
+ ASSERT_EQ(1, t0.size());
+}
+
+TEST(TestTensor, BasicCtors) {
+ const int64_t values = 24;
+ std::vector<int64_t> shape = {4, 6};
+ std::vector<int64_t> strides = {48, 8};
+ std::vector<std::string> dim_names = {"foo", "bar"};
+
+ using T = int64_t;
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> buffer,
+ AllocateBuffer(values * sizeof(T)));
+
+ Tensor t1(int64(), buffer, shape);
+ Tensor t2(int64(), buffer, shape, strides);
+ Tensor t3(int64(), buffer, shape, strides, dim_names);
+
+ ASSERT_EQ(24, t1.size());
+ ASSERT_TRUE(t1.is_mutable());
+
+ ASSERT_EQ(strides, t1.strides());
+ ASSERT_EQ(strides, t2.strides());
+
+ ASSERT_EQ(std::vector<std::string>({"foo", "bar"}), t3.dim_names());
+ ASSERT_EQ("foo", t3.dim_name(0));
+ ASSERT_EQ("bar", t3.dim_name(1));
+
+ ASSERT_EQ(std::vector<std::string>({}), t1.dim_names());
+ ASSERT_EQ("", t1.dim_name(0));
+ ASSERT_EQ("", t1.dim_name(1));
+}
+
+TEST(TestTensor, IsContiguous) {
+ const int64_t values = 24;
+ std::vector<int64_t> shape = {4, 6};
+ std::vector<int64_t> strides = {48, 8};
+
+ using T = int64_t;
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> buffer,
+ AllocateBuffer(values * sizeof(T)));
+
+ std::vector<int64_t> c_strides = {48, 8};
+ std::vector<int64_t> f_strides = {8, 32};
+ std::vector<int64_t> noncontig_strides = {8, 8};
+ Tensor t1(int64(), buffer, shape, c_strides);
+ Tensor t2(int64(), buffer, shape, f_strides);
+ Tensor t3(int64(), buffer, shape, noncontig_strides);
+
+ ASSERT_TRUE(t1.is_contiguous());
+ ASSERT_TRUE(t2.is_contiguous());
+ ASSERT_FALSE(t3.is_contiguous());
+}
+
+TEST(TestTensor, ZeroSizedTensor) {
+ std::vector<int64_t> shape = {0};
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> buffer, AllocateBuffer(0));
+
+ Tensor t(int64(), buffer, shape);
+ ASSERT_EQ(t.strides().size(), 1);
+}
+
+TEST(TestTensor, CountNonZeroForZeroSizedTensor) {
+ std::vector<int64_t> shape = {0};
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> buffer, AllocateBuffer(0));
+
+ Tensor t(int64(), buffer, shape);
+ AssertCountNonZero(t, 0);
+}
+
+TEST(TestTensor, CountNonZeroForContiguousTensor) {
+ std::vector<int64_t> shape = {4, 6};
+ std::vector<int64_t> values = {1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0,
+ 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16};
+ std::shared_ptr<Buffer> buffer = Buffer::Wrap(values);
+
+ std::vector<int64_t> c_strides = {48, 8};
+ std::vector<int64_t> f_strides = {8, 32};
+ Tensor t1(int64(), buffer, shape, c_strides);
+ Tensor t2(int64(), buffer, shape, f_strides);
+
+ ASSERT_TRUE(t1.is_contiguous());
+ ASSERT_TRUE(t2.is_contiguous());
+ AssertCountNonZero(t1, 12);
+ AssertCountNonZero(t2, 12);
+}
+
+TEST(TestTensor, CountNonZeroForNonContiguousTensor) {
+ std::vector<int64_t> shape = {4, 4};
+ std::vector<int64_t> values = {
+ 1, 0, 2, 0, 0, 3, 0, 4, 5, 0, 6, 0, 7, 0, 8, 0,
+ 0, 11, 0, 12, 13, 0, 14, 0, 0, 15, 0, 16, 0, 15, 0, 16,
+ };
+ std::shared_ptr<Buffer> buffer = Buffer::Wrap(values);
+
+ std::vector<int64_t> noncontig_strides = {64, 16};
+ Tensor t(int64(), buffer, shape, noncontig_strides);
+
+ ASSERT_FALSE(t.is_contiguous());
+ AssertCountNonZero(t, 8);
+}
+
+TEST(TestTensor, ElementAccessInt32) {
+ std::vector<int64_t> shape = {2, 3};
+ std::vector<int32_t> values = {1, 2, 3, 4, 5, 6};
+ std::vector<int64_t> c_strides = {sizeof(int32_t) * 3, sizeof(int32_t)};
+ std::vector<int64_t> f_strides = {sizeof(int32_t), sizeof(int32_t) * 2};
+ Tensor tc(int64(), Buffer::Wrap(values), shape, c_strides);
+ Tensor tf(int64(), Buffer::Wrap(values), shape, f_strides);
+
+ EXPECT_EQ(1, tc.Value<Int32Type>({0, 0}));
+ EXPECT_EQ(2, tc.Value<Int32Type>({0, 1}));
+ EXPECT_EQ(4, tc.Value<Int32Type>({1, 0}));
+
+ EXPECT_EQ(1, tf.Value<Int32Type>({0, 0}));
+ EXPECT_EQ(3, tf.Value<Int32Type>({0, 1}));
+ EXPECT_EQ(2, tf.Value<Int32Type>({1, 0}));
+
+ // Tensor::Value<T>() doesn't prohibit element access if the type T is different from
+ // the value type of the tensor
+ EXPECT_NO_THROW({
+ int32_t x = 3;
+ EXPECT_EQ(*reinterpret_cast<int8_t*>(&x), tc.Value<Int8Type>({0, 2}));
+
+ union {
+ int64_t i64;
+ struct {
+ int32_t first;
+ int32_t second;
+ } i32;
+ } y;
+ y.i32.first = 4;
+ y.i32.second = 5;
+ EXPECT_EQ(y.i64, tc.Value<Int64Type>({1, 0}));
+ });
+}
+
+TEST(TestTensor, EqualsInt64) {
+ std::vector<int64_t> shape = {4, 4};
+
+ std::vector<int64_t> c_values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+ std::vector<int64_t> c_strides = {32, 8};
+ Tensor tc1(int64(), Buffer::Wrap(c_values), shape, c_strides);
+
+ std::vector<int64_t> c_values_2 = c_values;
+ Tensor tc2(int64(), Buffer::Wrap(c_values_2), shape, c_strides);
+
+ std::vector<int64_t> f_values = {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16};
+ Tensor tc3(int64(), Buffer::Wrap(f_values), shape, c_strides);
+
+ Tensor tc4(int64(), Buffer::Wrap(c_values), {8, 2}, {16, 8});
+
+ std::vector<int64_t> f_strides = {8, 32};
+ Tensor tf1(int64(), Buffer::Wrap(f_values), shape, f_strides);
+
+ std::vector<int64_t> f_values_2 = f_values;
+ Tensor tf2(int64(), Buffer::Wrap(f_values_2), shape, f_strides);
+
+ Tensor tf3(int64(), Buffer::Wrap(c_values), shape, f_strides);
+
+ std::vector<int64_t> nc_values = {1, 0, 5, 0, 9, 0, 13, 0, 2, 0, 6, 0, 10, 0, 14, 0,
+ 3, 0, 7, 0, 11, 0, 15, 0, 4, 0, 8, 0, 12, 0, 16, 0};
+ std::vector<int64_t> nc_strides = {16, 64};
+ Tensor tnc(int64(), Buffer::Wrap(nc_values), shape, nc_strides);
+
+ ASSERT_TRUE(tc1.is_contiguous());
+ ASSERT_TRUE(tc1.is_row_major());
+
+ ASSERT_TRUE(tf1.is_contiguous());
+ ASSERT_TRUE(tf1.is_column_major());
+
+ ASSERT_FALSE(tnc.is_contiguous());
+
+ // same object
+ EXPECT_TRUE(tc1.Equals(tc1));
+ EXPECT_TRUE(tf1.Equals(tf1));
+ EXPECT_TRUE(tnc.Equals(tnc));
+
+ // different memory
+ EXPECT_TRUE(tc1.Equals(tc2));
+ EXPECT_TRUE(tf1.Equals(tf2));
+ EXPECT_FALSE(tc1.Equals(tc3));
+
+ // different shapes but same data
+ EXPECT_FALSE(tc1.Equals(tc4));
+
+ // row-major and column-major
+ EXPECT_TRUE(tc1.Equals(tf1));
+ EXPECT_FALSE(tc3.Equals(tf1));
+
+ // row-major and non-contiguous
+ EXPECT_TRUE(tc1.Equals(tnc));
+ EXPECT_FALSE(tc3.Equals(tnc));
+
+ // column-major and non-contiguous
+ EXPECT_TRUE(tf1.Equals(tnc));
+ EXPECT_FALSE(tf3.Equals(tnc));
+
+ // zero-size tensor
+ ASSERT_OK_AND_ASSIGN(auto empty_buffer1, AllocateBuffer(0));
+ ASSERT_OK_AND_ASSIGN(auto empty_buffer2, AllocateBuffer(0));
+ Tensor empty1(int64(), std::move(empty_buffer1), {0});
+ Tensor empty2(int64(), std::move(empty_buffer2), {0});
+ EXPECT_FALSE(empty1.Equals(tc1));
+ EXPECT_TRUE(empty1.Equals(empty2));
+}
+
+template <typename DataType>
+class TestFloatTensor : public ::testing::Test {};
+
+TYPED_TEST_SUITE_P(TestFloatTensor);
+
+TYPED_TEST_P(TestFloatTensor, Equals) {
+ using DataType = TypeParam;
+ using c_data_type = typename DataType::c_type;
+ const int unit_size = sizeof(c_data_type);
+
+ std::vector<int64_t> shape = {4, 4};
+
+ std::vector<c_data_type> c_values = {1, 2, 3, 4, 5, 6, 7, 8,
+ 9, 10, 11, 12, 13, 14, 15, 16};
+ std::vector<int64_t> c_strides = {unit_size * shape[1], unit_size};
+ Tensor tc1(TypeTraits<DataType>::type_singleton(), Buffer::Wrap(c_values), shape,
+ c_strides);
+
+ std::vector<c_data_type> c_values_2 = c_values;
+ Tensor tc2(TypeTraits<DataType>::type_singleton(), Buffer::Wrap(c_values_2), shape,
+ c_strides);
+
+ std::vector<c_data_type> f_values = {1, 5, 9, 13, 2, 6, 10, 14,
+ 3, 7, 11, 15, 4, 8, 12, 16};
+ Tensor tc3(TypeTraits<DataType>::type_singleton(), Buffer::Wrap(f_values), shape,
+ c_strides);
+
+ std::vector<int64_t> f_strides = {unit_size, unit_size * shape[0]};
+ Tensor tf1(TypeTraits<DataType>::type_singleton(), Buffer::Wrap(f_values), shape,
+ f_strides);
+
+ std::vector<c_data_type> f_values_2 = f_values;
+ Tensor tf2(TypeTraits<DataType>::type_singleton(), Buffer::Wrap(f_values_2), shape,
+ f_strides);
+
+ Tensor tf3(TypeTraits<DataType>::type_singleton(), Buffer::Wrap(c_values), shape,
+ f_strides);
+
+ std::vector<c_data_type> nc_values = {1, 0, 5, 0, 9, 0, 13, 0, 2, 0, 6,
+ 0, 10, 0, 14, 0, 3, 0, 7, 0, 11, 0,
+ 15, 0, 4, 0, 8, 0, 12, 0, 16, 0};
+ std::vector<int64_t> nc_strides = {unit_size * 2, unit_size * 2 * shape[0]};
+ Tensor tnc(TypeTraits<DataType>::type_singleton(), Buffer::Wrap(nc_values), shape,
+ nc_strides);
+
+ ASSERT_TRUE(tc1.is_contiguous());
+ ASSERT_TRUE(tc1.is_row_major());
+
+ ASSERT_TRUE(tf1.is_contiguous());
+ ASSERT_TRUE(tf1.is_column_major());
+
+ ASSERT_FALSE(tnc.is_contiguous());
+
+ // same object
+ EXPECT_TRUE(tc1.Equals(tc1));
+ EXPECT_TRUE(tf1.Equals(tf1));
+ EXPECT_TRUE(tnc.Equals(tnc));
+
+ // different memory
+ EXPECT_TRUE(tc1.Equals(tc2));
+ EXPECT_TRUE(tf1.Equals(tf2));
+ EXPECT_FALSE(tc1.Equals(tc3));
+
+ // row-major and column-major
+ EXPECT_TRUE(tc1.Equals(tf1));
+ EXPECT_FALSE(tc3.Equals(tf1));
+
+ // row-major and non-contiguous
+ EXPECT_TRUE(tc1.Equals(tnc));
+ EXPECT_FALSE(tc3.Equals(tnc));
+
+ // column-major and non-contiguous
+ EXPECT_TRUE(tf1.Equals(tnc));
+ EXPECT_FALSE(tf3.Equals(tnc));
+
+ // tensors with NaNs
+ const c_data_type nan_value = static_cast<c_data_type>(NAN);
+ c_values[0] = nan_value;
+ EXPECT_TRUE(std::isnan(tc1.Value<DataType>({0, 0})));
+ EXPECT_FALSE(tc1.Equals(tc1)); // same object
+ EXPECT_TRUE(tc1.Equals(tc1, EqualOptions().nans_equal(true))); // same object
+ EXPECT_FALSE(std::isnan(tc2.Value<DataType>({0, 0})));
+ EXPECT_FALSE(tc1.Equals(tc2)); // different memory
+ EXPECT_FALSE(tc1.Equals(tc2, EqualOptions().nans_equal(true))); // different memory
+
+ c_values_2[0] = nan_value;
+ EXPECT_TRUE(std::isnan(tc2.Value<DataType>({0, 0})));
+ EXPECT_FALSE(tc1.Equals(tc2)); // different memory
+ EXPECT_TRUE(tc1.Equals(tc2, EqualOptions().nans_equal(true))); // different memory
+}
+
+REGISTER_TYPED_TEST_SUITE_P(TestFloatTensor, Equals);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(Float32, TestFloatTensor, FloatType);
+INSTANTIATE_TYPED_TEST_SUITE_P(Float64, TestFloatTensor, DoubleType);
+
+TEST(TestNumericTensor, Make) {
+ std::vector<int64_t> shape = {3, 6};
+ std::vector<int64_t> strides = {sizeof(double) * 6, sizeof(double)};
+ std::vector<double> values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ auto data = Buffer::Wrap(values);
+
+ // without strides and dim_names
+ std::shared_ptr<NumericTensor<DoubleType>> tensor1;
+ ASSERT_OK_AND_ASSIGN(tensor1, NumericTensor<DoubleType>::Make(data, shape));
+ EXPECT_EQ(float64(), tensor1->type());
+ EXPECT_EQ(shape, tensor1->shape());
+ EXPECT_EQ(strides, tensor1->strides());
+ EXPECT_EQ(data->data(), tensor1->raw_data());
+ EXPECT_EQ(std::vector<std::string>{}, tensor1->dim_names());
+
+ // without dim_names
+ std::shared_ptr<NumericTensor<DoubleType>> tensor2;
+ ASSERT_OK_AND_ASSIGN(tensor2, NumericTensor<DoubleType>::Make(data, shape, strides));
+ EXPECT_EQ(float64(), tensor2->type());
+ EXPECT_EQ(shape, tensor2->shape());
+ EXPECT_EQ(strides, tensor2->strides());
+ EXPECT_EQ(std::vector<std::string>{}, tensor2->dim_names());
+ EXPECT_EQ(data->data(), tensor2->raw_data());
+ EXPECT_TRUE(tensor2->Equals(*tensor1));
+
+ // without strides
+ std::vector<std::string> dim_names = {"foo", "bar"};
+ std::shared_ptr<NumericTensor<DoubleType>> tensor3;
+ ASSERT_OK_AND_ASSIGN(tensor3,
+ NumericTensor<DoubleType>::Make(data, shape, {}, dim_names));
+ EXPECT_EQ(float64(), tensor3->type());
+ EXPECT_EQ(shape, tensor3->shape());
+ EXPECT_EQ(strides, tensor3->strides());
+ EXPECT_EQ(dim_names, tensor3->dim_names());
+ EXPECT_EQ(data->data(), tensor3->raw_data());
+ EXPECT_TRUE(tensor3->Equals(*tensor1));
+ EXPECT_TRUE(tensor3->Equals(*tensor2));
+
+ // supply all parameters
+ std::shared_ptr<NumericTensor<DoubleType>> tensor4;
+ ASSERT_OK_AND_ASSIGN(tensor4,
+ NumericTensor<DoubleType>::Make(data, shape, strides, dim_names));
+ EXPECT_EQ(float64(), tensor4->type());
+ EXPECT_EQ(shape, tensor4->shape());
+ EXPECT_EQ(strides, tensor4->strides());
+ EXPECT_EQ(dim_names, tensor4->dim_names());
+ EXPECT_EQ(data->data(), tensor4->raw_data());
+ EXPECT_TRUE(tensor4->Equals(*tensor1));
+ EXPECT_TRUE(tensor4->Equals(*tensor2));
+ EXPECT_TRUE(tensor4->Equals(*tensor3));
+}
+
+TEST(TestNumericTensor, ElementAccessWithRowMajorStrides) {
+ std::vector<int64_t> shape = {3, 4};
+
+ std::vector<int64_t> values_i64 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+ std::shared_ptr<Buffer> buffer_i64(Buffer::Wrap(values_i64));
+ NumericTensor<Int64Type> t_i64(buffer_i64, shape);
+
+ ASSERT_TRUE(t_i64.is_row_major());
+ ASSERT_FALSE(t_i64.is_column_major());
+ ASSERT_TRUE(t_i64.is_contiguous());
+ ASSERT_EQ(1, t_i64.Value({0, 0}));
+ ASSERT_EQ(5, t_i64.Value({1, 0}));
+ ASSERT_EQ(6, t_i64.Value({1, 1}));
+ ASSERT_EQ(11, t_i64.Value({2, 2}));
+
+ std::vector<float> values_f32 = {1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f,
+ 7.1f, 8.1f, 9.1f, 10.1f, 11.1f, 12.1f};
+ std::shared_ptr<Buffer> buffer_f32(Buffer::Wrap(values_f32));
+ NumericTensor<FloatType> t_f32(buffer_f32, shape);
+
+ ASSERT_TRUE(t_f32.is_row_major());
+ ASSERT_FALSE(t_f32.is_column_major());
+ ASSERT_TRUE(t_f32.is_contiguous());
+ ASSERT_EQ(1.1f, t_f32.Value({0, 0}));
+ ASSERT_EQ(5.1f, t_f32.Value({1, 0}));
+ ASSERT_EQ(6.1f, t_f32.Value({1, 1}));
+ ASSERT_EQ(11.1f, t_f32.Value({2, 2}));
+}
+
+TEST(TestNumericTensor, ElementAccessWithColumnMajorStrides) {
+ std::vector<int64_t> shape = {3, 4};
+
+ const int64_t i64_size = sizeof(int64_t);
+ std::vector<int64_t> values_i64 = {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12};
+ std::vector<int64_t> strides_i64 = {i64_size, i64_size * 3};
+ std::shared_ptr<Buffer> buffer_i64(Buffer::Wrap(values_i64));
+ NumericTensor<Int64Type> t_i64(buffer_i64, shape, strides_i64);
+
+ ASSERT_TRUE(t_i64.is_column_major());
+ ASSERT_FALSE(t_i64.is_row_major());
+ ASSERT_TRUE(t_i64.is_contiguous());
+ ASSERT_EQ(1, t_i64.Value({0, 0}));
+ ASSERT_EQ(2, t_i64.Value({0, 1}));
+ ASSERT_EQ(4, t_i64.Value({0, 3}));
+ ASSERT_EQ(5, t_i64.Value({1, 0}));
+ ASSERT_EQ(6, t_i64.Value({1, 1}));
+ ASSERT_EQ(11, t_i64.Value({2, 2}));
+
+ const int64_t f32_size = sizeof(float);
+ std::vector<float> values_f32 = {1.1f, 5.1f, 9.1f, 2.1f, 6.1f, 10.1f,
+ 3.1f, 7.1f, 11.1f, 4.1f, 8.1f, 12.1f};
+ std::vector<int64_t> strides_f32 = {f32_size, f32_size * 3};
+ std::shared_ptr<Buffer> buffer_f32(Buffer::Wrap(values_f32));
+ NumericTensor<FloatType> t_f32(buffer_f32, shape, strides_f32);
+
+ ASSERT_TRUE(t_f32.is_column_major());
+ ASSERT_FALSE(t_f32.is_row_major());
+ ASSERT_TRUE(t_f32.is_contiguous());
+ ASSERT_EQ(1.1f, t_f32.Value({0, 0}));
+ ASSERT_EQ(2.1f, t_f32.Value({0, 1}));
+ ASSERT_EQ(4.1f, t_f32.Value({0, 3}));
+ ASSERT_EQ(5.1f, t_f32.Value({1, 0}));
+ ASSERT_EQ(6.1f, t_f32.Value({1, 1}));
+ ASSERT_EQ(11.1f, t_f32.Value({2, 2}));
+}
+
+TEST(TestNumericTensor, ElementAccessWithNonContiguousStrides) {
+ std::vector<int64_t> shape = {3, 4};
+
+ const int64_t i64_size = sizeof(int64_t);
+ std::vector<int64_t> values_i64 = {1, 2, 3, 4, 0, 0, 5, 6, 7,
+ 8, 0, 0, 9, 10, 11, 12, 0, 0};
+ std::vector<int64_t> strides_i64 = {i64_size * 6, i64_size};
+ std::shared_ptr<Buffer> buffer_i64(Buffer::Wrap(values_i64));
+ NumericTensor<Int64Type> t_i64(buffer_i64, shape, strides_i64);
+
+ ASSERT_FALSE(t_i64.is_contiguous());
+ ASSERT_FALSE(t_i64.is_row_major());
+ ASSERT_FALSE(t_i64.is_column_major());
+ ASSERT_EQ(1, t_i64.Value({0, 0}));
+ ASSERT_EQ(2, t_i64.Value({0, 1}));
+ ASSERT_EQ(4, t_i64.Value({0, 3}));
+ ASSERT_EQ(5, t_i64.Value({1, 0}));
+ ASSERT_EQ(6, t_i64.Value({1, 1}));
+ ASSERT_EQ(11, t_i64.Value({2, 2}));
+
+ const int64_t f32_size = sizeof(float);
+ std::vector<float> values_f32 = {1.1f, 2.1f, 3.1f, 4.1f, 0.0f, 0.0f,
+ 5.1f, 6.1f, 7.1f, 8.1f, 0.0f, 0.0f,
+ 9.1f, 10.1f, 11.1f, 12.1f, 0.0f, 0.0f};
+ std::vector<int64_t> strides_f32 = {f32_size * 6, f32_size};
+ std::shared_ptr<Buffer> buffer_f32(Buffer::Wrap(values_f32));
+ NumericTensor<FloatType> t_f32(buffer_f32, shape, strides_f32);
+
+ ASSERT_FALSE(t_f32.is_contiguous());
+ ASSERT_FALSE(t_f32.is_row_major());
+ ASSERT_FALSE(t_f32.is_column_major());
+ ASSERT_EQ(1.1f, t_f32.Value({0, 0}));
+ ASSERT_EQ(2.1f, t_f32.Value({0, 1}));
+ ASSERT_EQ(4.1f, t_f32.Value({0, 3}));
+ ASSERT_EQ(5.1f, t_f32.Value({1, 0}));
+ ASSERT_EQ(6.1f, t_f32.Value({1, 1}));
+ ASSERT_EQ(11.1f, t_f32.Value({2, 2}));
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/CMakeLists.txt b/src/arrow/cpp/src/arrow/testing/CMakeLists.txt
new file mode 100644
index 000000000..073224d51
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/CMakeLists.txt
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+arrow_install_all_headers("arrow/testing")
+
+if(ARROW_BUILD_TESTS)
+ add_arrow_test(random_test)
+endif()
+
+# json_integration_test is two things at the same time:
+# - an executable that can be called to answer integration test requests
+# - a self-(unit)test for the C++ side of integration testing
+if(ARROW_BUILD_TESTS)
+ add_arrow_test(json_integration_test EXTRA_LINK_LIBS ${GFLAGS_LIBRARIES})
+ add_dependencies(arrow-integration arrow-json-integration-test)
+elseif(ARROW_BUILD_INTEGRATION)
+ add_executable(arrow-json-integration-test json_integration_test.cc)
+ target_link_libraries(arrow-json-integration-test ${ARROW_TEST_LINK_LIBS}
+ ${GFLAGS_LIBRARIES} GTest::gtest)
+
+ add_dependencies(arrow-json-integration-test arrow arrow_testing)
+ add_dependencies(arrow-integration arrow-json-integration-test)
+endif()
diff --git a/src/arrow/cpp/src/arrow/testing/async_test_util.h b/src/arrow/cpp/src/arrow/testing/async_test_util.h
new file mode 100644
index 000000000..b9f5487ed
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/async_test_util.h
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <atomic>
+#include <memory>
+
+#include "arrow/util/async_generator.h"
+#include "arrow/util/future.h"
+
+namespace arrow {
+namespace util {
+
+template <typename T>
+class TrackingGenerator {
+ public:
+ explicit TrackingGenerator(AsyncGenerator<T> source)
+ : state_(std::make_shared<State>(std::move(source))) {}
+
+ Future<T> operator()() {
+ state_->num_read++;
+ return state_->source();
+ }
+
+ int num_read() { return state_->num_read.load(); }
+
+ private:
+ struct State {
+ explicit State(AsyncGenerator<T> source) : source(std::move(source)), num_read(0) {}
+
+ AsyncGenerator<T> source;
+ std::atomic<int> num_read;
+ };
+
+ std::shared_ptr<State> state_;
+};
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/executor_util.h b/src/arrow/cpp/src/arrow/testing/executor_util.h
new file mode 100644
index 000000000..e34fc858d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/executor_util.h
@@ -0,0 +1,55 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+/// An executor which synchronously runs the task as part of the SpawnReal call.
+class MockExecutor : public internal::Executor {
+ public:
+ int GetCapacity() override { return 0; }
+
+ Status SpawnReal(internal::TaskHints hints, internal::FnOnce<void()> task, StopToken,
+ StopCallback&&) override {
+ spawn_count++;
+ std::move(task)();
+ return Status::OK();
+ }
+
+ int spawn_count = 0;
+};
+
+/// An executor which does not actually run the task. Can be used to simulate situations
+/// where the executor schedules a task in a long queue and doesn't get around to running
+/// it for a while
+class DelayedExecutor : public internal::Executor {
+ public:
+ int GetCapacity() override { return 0; }
+
+ Status SpawnReal(internal::TaskHints hints, internal::FnOnce<void()> task, StopToken,
+ StopCallback&&) override {
+ captured_tasks.push_back(std::move(task));
+ return Status::OK();
+ }
+
+ std::vector<internal::FnOnce<void()>> captured_tasks;
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/extension_type.h b/src/arrow/cpp/src/arrow/testing/extension_type.h
new file mode 100644
index 000000000..5afe23400
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/extension_type.h
@@ -0,0 +1,158 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/extension_type.h"
+#include "arrow/testing/visibility.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+
+class ARROW_TESTING_EXPORT UuidArray : public ExtensionArray {
+ public:
+ using ExtensionArray::ExtensionArray;
+};
+
+class ARROW_TESTING_EXPORT UuidType : public ExtensionType {
+ public:
+ UuidType() : ExtensionType(fixed_size_binary(16)) {}
+
+ std::string extension_name() const override { return "uuid"; }
+
+ bool ExtensionEquals(const ExtensionType& other) const override;
+
+ std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;
+
+ Result<std::shared_ptr<DataType>> Deserialize(
+ std::shared_ptr<DataType> storage_type,
+ const std::string& serialized) const override;
+
+ std::string Serialize() const override { return "uuid-serialized"; }
+};
+
+class ARROW_TESTING_EXPORT SmallintArray : public ExtensionArray {
+ public:
+ using ExtensionArray::ExtensionArray;
+};
+
+class ARROW_TESTING_EXPORT SmallintType : public ExtensionType {
+ public:
+ SmallintType() : ExtensionType(int16()) {}
+
+ std::string extension_name() const override { return "smallint"; }
+
+ bool ExtensionEquals(const ExtensionType& other) const override;
+
+ std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;
+
+ Result<std::shared_ptr<DataType>> Deserialize(
+ std::shared_ptr<DataType> storage_type,
+ const std::string& serialized) const override;
+
+ std::string Serialize() const override { return "smallint"; }
+};
+
+class ARROW_TESTING_EXPORT DictExtensionType : public ExtensionType {
+ public:
+ DictExtensionType() : ExtensionType(dictionary(int8(), utf8())) {}
+
+ std::string extension_name() const override { return "dict-extension"; }
+
+ bool ExtensionEquals(const ExtensionType& other) const override;
+
+ std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;
+
+ Result<std::shared_ptr<DataType>> Deserialize(
+ std::shared_ptr<DataType> storage_type,
+ const std::string& serialized) const override;
+
+ std::string Serialize() const override { return "dict-extension-serialized"; }
+};
+
+class ARROW_TESTING_EXPORT Complex128Array : public ExtensionArray {
+ public:
+ using ExtensionArray::ExtensionArray;
+};
+
+class ARROW_TESTING_EXPORT Complex128Type : public ExtensionType {
+ public:
+ Complex128Type()
+ : ExtensionType(struct_({::arrow::field("real", float64(), /*nullable=*/false),
+ ::arrow::field("imag", float64(), /*nullable=*/false)})) {}
+
+ std::string extension_name() const override { return "complex128"; }
+
+ bool ExtensionEquals(const ExtensionType& other) const override;
+
+ std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;
+
+ Result<std::shared_ptr<DataType>> Deserialize(
+ std::shared_ptr<DataType> storage_type,
+ const std::string& serialized) const override;
+
+ std::string Serialize() const override { return "complex128-serialized"; }
+};
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<DataType> uuid();
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<DataType> smallint();
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<DataType> dict_extension_type();
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<DataType> complex128();
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<Array> ExampleUuid();
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<Array> ExampleSmallint();
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<Array> ExampleDictExtension();
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<Array> ExampleComplex128();
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<Array> MakeComplex128(const std::shared_ptr<Array>& real,
+ const std::shared_ptr<Array>& imag);
+
+// A RAII class that registers an extension type on construction
+// and unregisters it on destruction.
+class ARROW_TESTING_EXPORT ExtensionTypeGuard {
+ public:
+ explicit ExtensionTypeGuard(const std::shared_ptr<DataType>& type);
+ explicit ExtensionTypeGuard(const DataTypeVector& types);
+ ~ExtensionTypeGuard();
+ ARROW_DEFAULT_MOVE_AND_ASSIGN(ExtensionTypeGuard);
+
+ protected:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ExtensionTypeGuard);
+
+ std::vector<std::string> extension_names_;
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/future_util.h b/src/arrow/cpp/src/arrow/testing/future_util.h
new file mode 100644
index 000000000..2ca70d054
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/future_util.h
@@ -0,0 +1,142 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/future.h"
+
+// This macro should be called by futures that are expected to
+// complete pretty quickly. arrow::kDefaultAssertFinishesWaitSeconds is the
+// default max wait here. Anything longer than that and it's a questionable unit test
+// anyways.
+#define ASSERT_FINISHES_IMPL(fut) \
+ do { \
+ ASSERT_TRUE(fut.Wait(::arrow::kDefaultAssertFinishesWaitSeconds)); \
+ if (!fut.is_finished()) { \
+ FAIL() << "Future did not finish in a timely fashion"; \
+ } \
+ } while (false)
+
+#define ASSERT_FINISHES_OK(expr) \
+ do { \
+ auto&& _fut = (expr); \
+ ASSERT_TRUE(_fut.Wait(::arrow::kDefaultAssertFinishesWaitSeconds)); \
+ if (!_fut.is_finished()) { \
+ FAIL() << "Future did not finish in a timely fashion"; \
+ } \
+ auto& _st = _fut.status(); \
+ if (!_st.ok()) { \
+ FAIL() << "'" ARROW_STRINGIFY(expr) "' failed with " << _st.ToString(); \
+ } \
+ } while (false)
+
+#define ASSERT_FINISHES_AND_RAISES(ENUM, expr) \
+ do { \
+ auto&& _fut = (expr); \
+ ASSERT_FINISHES_IMPL(_fut); \
+ ASSERT_RAISES(ENUM, _fut.status()); \
+ } while (false)
+
+#define EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(ENUM, matcher, expr) \
+ do { \
+ auto&& fut = (expr); \
+ ASSERT_FINISHES_IMPL(fut); \
+ EXPECT_RAISES_WITH_MESSAGE_THAT(ENUM, matcher, fut.status()); \
+ } while (false)
+
+#define ASSERT_FINISHES_OK_AND_ASSIGN_IMPL(lhs, rexpr, _future_name) \
+ auto _future_name = (rexpr); \
+ ASSERT_FINISHES_IMPL(_future_name); \
+ ASSERT_OK_AND_ASSIGN(lhs, _future_name.result());
+
+#define ASSERT_FINISHES_OK_AND_ASSIGN(lhs, rexpr) \
+ ASSERT_FINISHES_OK_AND_ASSIGN_IMPL(lhs, rexpr, \
+ ARROW_ASSIGN_OR_RAISE_NAME(_fut, __COUNTER__))
+
+#define ASSERT_FINISHES_OK_AND_EQ(expected, expr) \
+ do { \
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto _actual, (expr)); \
+ ASSERT_EQ(expected, _actual); \
+ } while (0)
+
+#define EXPECT_FINISHES_IMPL(fut) \
+ do { \
+ EXPECT_TRUE(fut.Wait(::arrow::kDefaultAssertFinishesWaitSeconds)); \
+ if (!fut.is_finished()) { \
+ ADD_FAILURE() << "Future did not finish in a timely fashion"; \
+ } \
+ } while (false)
+
+#define ON_FINISH_ASSIGN_OR_HANDLE_ERROR_IMPL(handle_error, future_name, lhs, rexpr) \
+ auto future_name = (rexpr); \
+ EXPECT_FINISHES_IMPL(future_name); \
+ handle_error(future_name.status()); \
+ EXPECT_OK_AND_ASSIGN(lhs, future_name.result());
+
+#define EXPECT_FINISHES(expr) \
+ do { \
+ EXPECT_FINISHES_IMPL(expr); \
+ } while (0)
+
+#define EXPECT_FINISHES_OK_AND_ASSIGN(lhs, rexpr) \
+ ON_FINISH_ASSIGN_OR_HANDLE_ERROR_IMPL( \
+ ARROW_EXPECT_OK, ARROW_ASSIGN_OR_RAISE_NAME(_fut, __COUNTER__), lhs, rexpr);
+
+#define EXPECT_FINISHES_OK_AND_EQ(expected, expr) \
+ do { \
+ EXPECT_FINISHES_OK_AND_ASSIGN(auto _actual, (expr)); \
+ EXPECT_EQ(expected, _actual); \
+ } while (0)
+
+namespace arrow {
+
+constexpr double kDefaultAssertFinishesWaitSeconds = 64;
+
+template <typename T>
+void AssertNotFinished(const Future<T>& fut) {
+ ASSERT_FALSE(IsFutureFinished(fut.state()));
+}
+
+template <typename T>
+void AssertFinished(const Future<T>& fut) {
+ ASSERT_TRUE(IsFutureFinished(fut.state()));
+}
+
+// Assert the future is successful *now*
+template <typename T>
+void AssertSuccessful(const Future<T>& fut) {
+ if (IsFutureFinished(fut.state())) {
+ ASSERT_EQ(fut.state(), FutureState::SUCCESS);
+ ASSERT_OK(fut.status());
+ } else {
+ FAIL() << "Expected future to be completed successfully but it was still pending";
+ }
+}
+
+// Assert the future is failed *now*
+template <typename T>
+void AssertFailed(const Future<T>& fut) {
+ if (IsFutureFinished(fut.state())) {
+ ASSERT_EQ(fut.state(), FutureState::FAILURE);
+ ASSERT_FALSE(fut.status().ok());
+ } else {
+ FAIL() << "Expected future to have failed but it was still pending";
+ }
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/generator.cc b/src/arrow/cpp/src/arrow/testing/generator.cc
new file mode 100644
index 000000000..33371d55c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/generator.cc
@@ -0,0 +1,110 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/testing/generator.h"
+
+#include <algorithm>
+#include <memory>
+#include <random>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+
+namespace arrow {
+
+template <typename ArrowType, typename CType = typename TypeTraits<ArrowType>::CType,
+ typename BuilderType = typename TypeTraits<ArrowType>::BuilderType>
+static inline std::shared_ptr<Array> ConstantArray(int64_t size, CType value) {
+ auto type = TypeTraits<ArrowType>::type_singleton();
+ auto builder_fn = [&](BuilderType* builder) { builder->UnsafeAppend(value); };
+ return ArrayFromBuilderVisitor(type, size, builder_fn).ValueOrDie();
+}
+
+std::shared_ptr<arrow::Array> ConstantArrayGenerator::Boolean(int64_t size, bool value) {
+ return ConstantArray<BooleanType>(size, value);
+}
+
+std::shared_ptr<arrow::Array> ConstantArrayGenerator::UInt8(int64_t size, uint8_t value) {
+ return ConstantArray<UInt8Type>(size, value);
+}
+
+std::shared_ptr<arrow::Array> ConstantArrayGenerator::Int8(int64_t size, int8_t value) {
+ return ConstantArray<Int8Type>(size, value);
+}
+
+std::shared_ptr<arrow::Array> ConstantArrayGenerator::UInt16(int64_t size,
+ uint16_t value) {
+ return ConstantArray<UInt16Type>(size, value);
+}
+
+std::shared_ptr<arrow::Array> ConstantArrayGenerator::Int16(int64_t size, int16_t value) {
+ return ConstantArray<Int16Type>(size, value);
+}
+
+std::shared_ptr<arrow::Array> ConstantArrayGenerator::UInt32(int64_t size,
+ uint32_t value) {
+ return ConstantArray<UInt32Type>(size, value);
+}
+
+std::shared_ptr<arrow::Array> ConstantArrayGenerator::Int32(int64_t size, int32_t value) {
+ return ConstantArray<Int32Type>(size, value);
+}
+
+std::shared_ptr<arrow::Array> ConstantArrayGenerator::UInt64(int64_t size,
+ uint64_t value) {
+ return ConstantArray<UInt64Type>(size, value);
+}
+
+std::shared_ptr<arrow::Array> ConstantArrayGenerator::Int64(int64_t size, int64_t value) {
+ return ConstantArray<Int64Type>(size, value);
+}
+
+std::shared_ptr<arrow::Array> ConstantArrayGenerator::Float32(int64_t size, float value) {
+ return ConstantArray<FloatType>(size, value);
+}
+
+std::shared_ptr<arrow::Array> ConstantArrayGenerator::Float64(int64_t size,
+ double value) {
+ return ConstantArray<DoubleType>(size, value);
+}
+
+std::shared_ptr<arrow::Array> ConstantArrayGenerator::String(int64_t size,
+ std::string value) {
+ return ConstantArray<StringType>(size, value);
+}
+
+Result<std::shared_ptr<Array>> ScalarVectorToArray(const ScalarVector& scalars) {
+ if (scalars.empty()) {
+ return Status::NotImplemented("ScalarVectorToArray with no scalars");
+ }
+ std::unique_ptr<arrow::ArrayBuilder> builder;
+ RETURN_NOT_OK(MakeBuilder(default_memory_pool(), scalars[0]->type, &builder));
+ RETURN_NOT_OK(builder->AppendScalars(scalars));
+ std::shared_ptr<Array> out;
+ RETURN_NOT_OK(builder->Finish(&out));
+ return out;
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/generator.h b/src/arrow/cpp/src/arrow/testing/generator.h
new file mode 100644
index 000000000..c30002243
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/generator.h
@@ -0,0 +1,261 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/record_batch.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/visibility.h"
+#include "arrow/type.h"
+
+namespace arrow {
+
+class ARROW_TESTING_EXPORT ConstantArrayGenerator {
+ public:
+ /// \brief Generates a constant BooleanArray
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] value to repeat
+ ///
+ /// \return a generated Array
+ static std::shared_ptr<arrow::Array> Boolean(int64_t size, bool value = false);
+
+ /// \brief Generates a constant UInt8Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] value to repeat
+ ///
+ /// \return a generated Array
+ static std::shared_ptr<arrow::Array> UInt8(int64_t size, uint8_t value = 0);
+
+ /// \brief Generates a constant Int8Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] value to repeat
+ ///
+ /// \return a generated Array
+ static std::shared_ptr<arrow::Array> Int8(int64_t size, int8_t value = 0);
+
+ /// \brief Generates a constant UInt16Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] value to repeat
+ ///
+ /// \return a generated Array
+ static std::shared_ptr<arrow::Array> UInt16(int64_t size, uint16_t value = 0);
+
+ /// \brief Generates a constant UInt16Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] value to repeat
+ ///
+ /// \return a generated Array
+ static std::shared_ptr<arrow::Array> Int16(int64_t size, int16_t value = 0);
+
+ /// \brief Generates a constant UInt32Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] value to repeat
+ ///
+ /// \return a generated Array
+ static std::shared_ptr<arrow::Array> UInt32(int64_t size, uint32_t value = 0);
+
+ /// \brief Generates a constant UInt32Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] value to repeat
+ ///
+ /// \return a generated Array
+ static std::shared_ptr<arrow::Array> Int32(int64_t size, int32_t value = 0);
+
+ /// \brief Generates a constant UInt64Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] value to repeat
+ ///
+ /// \return a generated Array
+ static std::shared_ptr<arrow::Array> UInt64(int64_t size, uint64_t value = 0);
+
+ /// \brief Generates a constant UInt64Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] value to repeat
+ ///
+ /// \return a generated Array
+ static std::shared_ptr<arrow::Array> Int64(int64_t size, int64_t value = 0);
+
+ /// \brief Generates a constant Float32Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] value to repeat
+ ///
+ /// \return a generated Array
+ static std::shared_ptr<arrow::Array> Float32(int64_t size, float value = 0);
+
+ /// \brief Generates a constant Float64Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] value to repeat
+ ///
+ /// \return a generated Array
+ static std::shared_ptr<arrow::Array> Float64(int64_t size, double value = 0);
+
+ /// \brief Generates a constant StringArray
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] value to repeat
+ ///
+ /// \return a generated Array
+ static std::shared_ptr<arrow::Array> String(int64_t size, std::string value = "");
+
+ template <typename ArrowType, typename CType = typename ArrowType::c_type>
+ static std::shared_ptr<arrow::Array> Numeric(int64_t size, CType value = 0) {
+ switch (ArrowType::type_id) {
+ case Type::BOOL:
+ return Boolean(size, static_cast<bool>(value));
+ case Type::UINT8:
+ return UInt8(size, static_cast<uint8_t>(value));
+ case Type::INT8:
+ return Int8(size, static_cast<int8_t>(value));
+ case Type::UINT16:
+ return UInt16(size, static_cast<uint16_t>(value));
+ case Type::INT16:
+ return Int16(size, static_cast<int16_t>(value));
+ case Type::UINT32:
+ return UInt32(size, static_cast<uint32_t>(value));
+ case Type::INT32:
+ return Int32(size, static_cast<int32_t>(value));
+ case Type::UINT64:
+ return UInt64(size, static_cast<uint64_t>(value));
+ case Type::INT64:
+ return Int64(size, static_cast<int64_t>(value));
+ case Type::FLOAT:
+ return Float32(size, static_cast<float>(value));
+ case Type::DOUBLE:
+ return Float64(size, static_cast<double>(value));
+ default:
+ return nullptr;
+ }
+ }
+
+ /// \brief Generates a constant Array of zeroes
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] type the type of the Array
+ ///
+ /// \return a generated Array
+ static std::shared_ptr<arrow::Array> Zeroes(int64_t size,
+ const std::shared_ptr<DataType>& type) {
+ switch (type->id()) {
+ case Type::NA:
+ return std::make_shared<NullArray>(size);
+ case Type::BOOL:
+ return Boolean(size);
+ case Type::UINT8:
+ return UInt8(size);
+ case Type::INT8:
+ return Int8(size);
+ case Type::UINT16:
+ return UInt16(size);
+ case Type::INT16:
+ return Int16(size);
+ case Type::UINT32:
+ return UInt32(size);
+ case Type::INT32:
+ return Int32(size);
+ case Type::UINT64:
+ return UInt64(size);
+ case Type::INT64:
+ return Int64(size);
+ case Type::TIME64:
+ case Type::DATE64:
+ case Type::TIMESTAMP: {
+ EXPECT_OK_AND_ASSIGN(auto viewed, Int64(size)->View(type));
+ return viewed;
+ }
+ case Type::INTERVAL_DAY_TIME:
+ case Type::INTERVAL_MONTHS:
+ case Type::TIME32:
+ case Type::DATE32: {
+ EXPECT_OK_AND_ASSIGN(auto viewed, Int32(size)->View(type));
+ return viewed;
+ }
+ case Type::FLOAT:
+ return Float32(size);
+ case Type::DOUBLE:
+ return Float64(size);
+ case Type::STRING:
+ return String(size);
+ default:
+ return nullptr;
+ }
+ }
+
+ /// \brief Generates a RecordBatch of zeroes
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] schema to conform to
+ ///
+ /// This function is handy to return of RecordBatch of a desired shape.
+ ///
+ /// \return a generated RecordBatch
+ static std::shared_ptr<arrow::RecordBatch> Zeroes(
+ int64_t size, const std::shared_ptr<Schema>& schema) {
+ std::vector<std::shared_ptr<Array>> arrays;
+
+ for (const auto& field : schema->fields()) {
+ arrays.emplace_back(Zeroes(size, field->type()));
+ }
+
+ return RecordBatch::Make(schema, size, arrays);
+ }
+
+ /// \brief Generates a RecordBatchReader by repeating a RecordBatch
+ ///
+ /// \param[in] n_batch the number of times it repeats batch
+ /// \param[in] batch the RecordBatch to repeat
+ ///
+ /// \return a generated RecordBatchReader
+ static std::shared_ptr<arrow::RecordBatchReader> Repeat(
+ int64_t n_batch, const std::shared_ptr<RecordBatch> batch) {
+ std::vector<std::shared_ptr<RecordBatch>> batches(static_cast<size_t>(n_batch),
+ batch);
+ return *RecordBatchReader::Make(batches);
+ }
+
+ /// \brief Generates a RecordBatchReader of zeroes batches
+ ///
+ /// \param[in] n_batch the number of RecordBatch
+ /// \param[in] batch_size the size of each RecordBatch
+ /// \param[in] schema to conform to
+ ///
+ /// \return a generated RecordBatchReader
+ static std::shared_ptr<arrow::RecordBatchReader> Zeroes(
+ int64_t n_batch, int64_t batch_size, const std::shared_ptr<Schema>& schema) {
+ return Repeat(n_batch, Zeroes(batch_size, schema));
+ }
+};
+
+ARROW_TESTING_EXPORT
+Result<std::shared_ptr<Array>> ScalarVectorToArray(const ScalarVector& scalars);
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/gtest_common.h b/src/arrow/cpp/src/arrow/testing/gtest_common.h
new file mode 100644
index 000000000..8b48238ed
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/gtest_common.h
@@ -0,0 +1,128 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_fwd.h"
+
+namespace arrow {
+
+class TestBase : public ::testing::Test {
+ public:
+ void SetUp() {
+ pool_ = default_memory_pool();
+ random_seed_ = 0;
+ }
+
+ std::shared_ptr<Buffer> MakeRandomNullBitmap(int64_t length, int64_t null_count) {
+ const int64_t null_nbytes = BitUtil::BytesForBits(length);
+
+ std::shared_ptr<Buffer> null_bitmap = *AllocateBuffer(null_nbytes, pool_);
+ memset(null_bitmap->mutable_data(), 255, null_nbytes);
+ for (int64_t i = 0; i < null_count; i++) {
+ BitUtil::ClearBit(null_bitmap->mutable_data(), i * (length / null_count));
+ }
+ return null_bitmap;
+ }
+
+ template <typename ArrayType>
+ inline std::shared_ptr<Array> MakeRandomArray(int64_t length, int64_t null_count = 0);
+
+ protected:
+ uint32_t random_seed_;
+ MemoryPool* pool_;
+};
+
+template <typename ArrayType>
+std::shared_ptr<Array> TestBase::MakeRandomArray(int64_t length, int64_t null_count) {
+ const int64_t data_nbytes = length * sizeof(typename ArrayType::value_type);
+ auto data = *AllocateBuffer(data_nbytes, pool_);
+
+ // Fill with random data
+ random_bytes(data_nbytes, random_seed_++, data->mutable_data());
+ std::shared_ptr<Buffer> null_bitmap = MakeRandomNullBitmap(length, null_count);
+
+ return std::make_shared<ArrayType>(length, std::move(data), null_bitmap, null_count);
+}
+
+template <>
+inline std::shared_ptr<Array> TestBase::MakeRandomArray<NullArray>(int64_t length,
+ int64_t null_count) {
+ return std::make_shared<NullArray>(length);
+}
+
+template <>
+inline std::shared_ptr<Array> TestBase::MakeRandomArray<FixedSizeBinaryArray>(
+ int64_t length, int64_t null_count) {
+ const int byte_width = 10;
+ std::shared_ptr<Buffer> null_bitmap = MakeRandomNullBitmap(length, null_count);
+ auto data = *AllocateBuffer(byte_width * length, pool_);
+
+ ::arrow::random_bytes(data->size(), 0, data->mutable_data());
+ return std::make_shared<FixedSizeBinaryArray>(fixed_size_binary(byte_width), length,
+ std::move(data), null_bitmap, null_count);
+}
+
+template <>
+inline std::shared_ptr<Array> TestBase::MakeRandomArray<BinaryArray>(int64_t length,
+ int64_t null_count) {
+ std::vector<uint8_t> valid_bytes(length, 1);
+ for (int64_t i = 0; i < null_count; i++) {
+ valid_bytes[i * 2] = 0;
+ }
+ BinaryBuilder builder(pool_);
+
+ const int kBufferSize = 10;
+ uint8_t buffer[kBufferSize];
+ for (int64_t i = 0; i < length; i++) {
+ if (!valid_bytes[i]) {
+ ARROW_EXPECT_OK(builder.AppendNull());
+ } else {
+ ::arrow::random_bytes(kBufferSize, static_cast<uint32_t>(i), buffer);
+ ARROW_EXPECT_OK(builder.Append(buffer, kBufferSize));
+ }
+ }
+
+ std::shared_ptr<Array> out;
+ ARROW_EXPECT_OK(builder.Finish(&out));
+ return out;
+}
+
+class TestBuilder : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = default_memory_pool(); }
+
+ protected:
+ MemoryPool* pool_;
+ std::shared_ptr<DataType> type_;
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/gtest_compat.h b/src/arrow/cpp/src/arrow/testing/gtest_compat.h
new file mode 100644
index 000000000..c934dd279
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/gtest_compat.h
@@ -0,0 +1,33 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <gtest/gtest.h>
+
+// GTest < 1.11
+#ifndef GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST
+#define GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(A)
+#endif
+// GTest < 1.10
+#ifndef TYPED_TEST_SUITE
+#define TYPED_TEST_SUITE TYPED_TEST_CASE
+#define TYPED_TEST_SUITE_P TYPED_TEST_CASE_P
+#define INSTANTIATE_TEST_SUITE_P INSTANTIATE_TEST_CASE_P
+#define REGISTER_TYPED_TEST_SUITE_P REGISTER_TYPED_TEST_CASE_P
+#define INSTANTIATE_TYPED_TEST_SUITE_P INSTANTIATE_TYPED_TEST_CASE_P
+#endif
diff --git a/src/arrow/cpp/src/arrow/testing/gtest_util.cc b/src/arrow/cpp/src/arrow/testing/gtest_util.cc
new file mode 100644
index 000000000..f471ed140
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/gtest_util.cc
@@ -0,0 +1,1006 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/testing/gtest_util.h"
+
+#include "arrow/testing/extension_type.h"
+
+#ifndef _WIN32
+#include <sys/stat.h> // IWYU pragma: keep
+#include <sys/wait.h> // IWYU pragma: keep
+#include <unistd.h> // IWYU pragma: keep
+#endif
+
+#include <algorithm>
+#include <chrono>
+#include <condition_variable>
+#include <cstdint>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+#include <locale>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/datum.h"
+#include "arrow/ipc/json_simple.h"
+#include "arrow/pretty_print.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/windows_compatibility.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+std::vector<Type::type> AllTypeIds() {
+ return {Type::NA,
+ Type::BOOL,
+ Type::INT8,
+ Type::INT16,
+ Type::INT32,
+ Type::INT64,
+ Type::UINT8,
+ Type::UINT16,
+ Type::UINT32,
+ Type::UINT64,
+ Type::HALF_FLOAT,
+ Type::FLOAT,
+ Type::DOUBLE,
+ Type::DECIMAL128,
+ Type::DECIMAL256,
+ Type::DATE32,
+ Type::DATE64,
+ Type::TIME32,
+ Type::TIME64,
+ Type::TIMESTAMP,
+ Type::INTERVAL_DAY_TIME,
+ Type::INTERVAL_MONTHS,
+ Type::DURATION,
+ Type::STRING,
+ Type::BINARY,
+ Type::LARGE_STRING,
+ Type::LARGE_BINARY,
+ Type::FIXED_SIZE_BINARY,
+ Type::STRUCT,
+ Type::LIST,
+ Type::LARGE_LIST,
+ Type::FIXED_SIZE_LIST,
+ Type::MAP,
+ Type::DENSE_UNION,
+ Type::SPARSE_UNION,
+ Type::DICTIONARY,
+ Type::EXTENSION,
+ Type::INTERVAL_MONTH_DAY_NANO};
+}
+
+template <typename T, typename CompareFunctor>
+void AssertTsSame(const T& expected, const T& actual, CompareFunctor&& compare) {
+ if (!compare(actual, expected)) {
+ std::stringstream pp_expected;
+ std::stringstream pp_actual;
+ ::arrow::PrettyPrintOptions options(/*indent=*/2);
+ options.window = 50;
+ ARROW_EXPECT_OK(PrettyPrint(expected, options, &pp_expected));
+ ARROW_EXPECT_OK(PrettyPrint(actual, options, &pp_actual));
+ FAIL() << "Got: \n" << pp_actual.str() << "\nExpected: \n" << pp_expected.str();
+ }
+}
+
+template <typename CompareFunctor>
+void AssertArraysEqualWith(const Array& expected, const Array& actual, bool verbose,
+ CompareFunctor&& compare) {
+ std::stringstream diff;
+ if (!compare(expected, actual, &diff)) {
+ if (expected.data()->null_count != actual.data()->null_count) {
+ diff << "Null counts differ. Expected " << expected.data()->null_count
+ << " but was " << actual.data()->null_count << "\n";
+ }
+ if (verbose) {
+ ::arrow::PrettyPrintOptions options(/*indent=*/2);
+ options.window = 50;
+ diff << "Expected:\n";
+ ARROW_EXPECT_OK(PrettyPrint(expected, options, &diff));
+ diff << "\nActual:\n";
+ ARROW_EXPECT_OK(PrettyPrint(actual, options, &diff));
+ }
+ FAIL() << diff.str();
+ }
+}
+
+void AssertArraysEqual(const Array& expected, const Array& actual, bool verbose,
+ const EqualOptions& options) {
+ return AssertArraysEqualWith(
+ expected, actual, verbose,
+ [&](const Array& expected, const Array& actual, std::stringstream* diff) {
+ return expected.Equals(actual, options.diff_sink(diff));
+ });
+}
+
+void AssertArraysApproxEqual(const Array& expected, const Array& actual, bool verbose,
+ const EqualOptions& options) {
+ return AssertArraysEqualWith(
+ expected, actual, verbose,
+ [&](const Array& expected, const Array& actual, std::stringstream* diff) {
+ return expected.ApproxEquals(actual, options.diff_sink(diff));
+ });
+}
+
+void AssertScalarsEqual(const Scalar& expected, const Scalar& actual, bool verbose,
+ const EqualOptions& options) {
+ if (!expected.Equals(actual, options)) {
+ std::stringstream diff;
+ if (verbose) {
+ diff << "Expected:\n" << expected.ToString();
+ diff << "\nActual:\n" << actual.ToString();
+ }
+ FAIL() << diff.str();
+ }
+}
+
+void AssertScalarsApproxEqual(const Scalar& expected, const Scalar& actual, bool verbose,
+ const EqualOptions& options) {
+ if (!expected.ApproxEquals(actual, options)) {
+ std::stringstream diff;
+ if (verbose) {
+ diff << "Expected:\n" << expected.ToString();
+ diff << "\nActual:\n" << actual.ToString();
+ }
+ FAIL() << diff.str();
+ }
+}
+
+void AssertBatchesEqual(const RecordBatch& expected, const RecordBatch& actual,
+ bool check_metadata) {
+ AssertTsSame(expected, actual,
+ [&](const RecordBatch& expected, const RecordBatch& actual) {
+ return expected.Equals(actual, check_metadata);
+ });
+}
+
+void AssertBatchesApproxEqual(const RecordBatch& expected, const RecordBatch& actual) {
+ AssertTsSame(expected, actual,
+ [&](const RecordBatch& expected, const RecordBatch& actual) {
+ return expected.ApproxEquals(actual);
+ });
+}
+
+void AssertChunkedEqual(const ChunkedArray& expected, const ChunkedArray& actual) {
+ ASSERT_EQ(expected.num_chunks(), actual.num_chunks()) << "# chunks unequal";
+ if (!actual.Equals(expected)) {
+ std::stringstream diff;
+ for (int i = 0; i < actual.num_chunks(); ++i) {
+ auto c1 = actual.chunk(i);
+ auto c2 = expected.chunk(i);
+ diff << "# chunk " << i << std::endl;
+ ARROW_IGNORE_EXPR(c1->Equals(c2, EqualOptions().diff_sink(&diff)));
+ }
+ FAIL() << diff.str();
+ }
+}
+
+void AssertChunkedEqual(const ChunkedArray& actual, const ArrayVector& expected) {
+ AssertChunkedEqual(ChunkedArray(expected, actual.type()), actual);
+}
+
+void AssertChunkedEquivalent(const ChunkedArray& expected, const ChunkedArray& actual) {
+ // XXX: AssertChunkedEqual in gtest_util.h does not permit the chunk layouts
+ // to be different
+ if (!actual.Equals(expected)) {
+ std::stringstream pp_expected;
+ std::stringstream pp_actual;
+ ::arrow::PrettyPrintOptions options(/*indent=*/2);
+ options.window = 50;
+ ARROW_EXPECT_OK(PrettyPrint(expected, options, &pp_expected));
+ ARROW_EXPECT_OK(PrettyPrint(actual, options, &pp_actual));
+ FAIL() << "Got: \n" << pp_actual.str() << "\nExpected: \n" << pp_expected.str();
+ }
+}
+
+void AssertChunkedApproxEquivalent(const ChunkedArray& expected,
+ const ChunkedArray& actual,
+ const EqualOptions& equal_options) {
+ if (!actual.ApproxEquals(expected, equal_options)) {
+ std::stringstream pp_expected;
+ std::stringstream pp_actual;
+ ::arrow::PrettyPrintOptions options(/*indent=*/2);
+ options.window = 50;
+ ARROW_EXPECT_OK(PrettyPrint(expected, options, &pp_expected));
+ ARROW_EXPECT_OK(PrettyPrint(actual, options, &pp_actual));
+ FAIL() << "Got: \n" << pp_actual.str() << "\nExpected: \n" << pp_expected.str();
+ }
+}
+
+void AssertBufferEqual(const Buffer& buffer, const std::vector<uint8_t>& expected) {
+ ASSERT_EQ(static_cast<size_t>(buffer.size()), expected.size())
+ << "Mismatching buffer size";
+ const uint8_t* buffer_data = buffer.data();
+ for (size_t i = 0; i < expected.size(); ++i) {
+ ASSERT_EQ(buffer_data[i], expected[i]);
+ }
+}
+
+void AssertBufferEqual(const Buffer& buffer, const std::string& expected) {
+ ASSERT_EQ(static_cast<size_t>(buffer.size()), expected.length())
+ << "Mismatching buffer size";
+ const uint8_t* buffer_data = buffer.data();
+ for (size_t i = 0; i < expected.size(); ++i) {
+ ASSERT_EQ(buffer_data[i], expected[i]);
+ }
+}
+
+void AssertBufferEqual(const Buffer& buffer, const Buffer& expected) {
+ ASSERT_EQ(buffer.size(), expected.size()) << "Mismatching buffer size";
+ ASSERT_TRUE(buffer.Equals(expected));
+}
+
+template <typename T>
+std::string ToStringWithMetadata(const T& t, bool show_metadata) {
+ return t.ToString(show_metadata);
+}
+
+std::string ToStringWithMetadata(const DataType& t, bool show_metadata) {
+ return t.ToString();
+}
+
+template <typename T>
+void AssertFingerprintablesEqual(const T& left, const T& right, bool check_metadata,
+ const char* types_plural) {
+ ASSERT_TRUE(left.Equals(right, check_metadata))
+ << types_plural << " '" << ToStringWithMetadata(left, check_metadata) << "' and '"
+ << ToStringWithMetadata(right, check_metadata) << "' should have compared equal";
+ auto lfp = left.fingerprint();
+ auto rfp = right.fingerprint();
+ // Note: all types tested in this file should implement fingerprinting,
+ // except extension types.
+ if (check_metadata) {
+ lfp += left.metadata_fingerprint();
+ rfp += right.metadata_fingerprint();
+ }
+ ASSERT_EQ(lfp, rfp) << "Fingerprints for " << types_plural << " '"
+ << ToStringWithMetadata(left, check_metadata) << "' and '"
+ << ToStringWithMetadata(right, check_metadata)
+ << "' should have compared equal";
+}
+
+template <typename T>
+void AssertFingerprintablesEqual(const std::shared_ptr<T>& left,
+ const std::shared_ptr<T>& right, bool check_metadata,
+ const char* types_plural) {
+ ASSERT_NE(left, nullptr);
+ ASSERT_NE(right, nullptr);
+ AssertFingerprintablesEqual(*left, *right, check_metadata, types_plural);
+}
+
+template <typename T>
+void AssertFingerprintablesNotEqual(const T& left, const T& right, bool check_metadata,
+ const char* types_plural) {
+ ASSERT_FALSE(left.Equals(right, check_metadata))
+ << types_plural << " '" << ToStringWithMetadata(left, check_metadata) << "' and '"
+ << ToStringWithMetadata(right, check_metadata) << "' should have compared unequal";
+ auto lfp = left.fingerprint();
+ auto rfp = right.fingerprint();
+ // Note: all types tested in this file should implement fingerprinting,
+ // except extension types.
+ if (lfp != "" && rfp != "") {
+ if (check_metadata) {
+ lfp += left.metadata_fingerprint();
+ rfp += right.metadata_fingerprint();
+ }
+ ASSERT_NE(lfp, rfp) << "Fingerprints for " << types_plural << " '"
+ << ToStringWithMetadata(left, check_metadata) << "' and '"
+ << ToStringWithMetadata(right, check_metadata)
+ << "' should have compared unequal";
+ }
+}
+
+template <typename T>
+void AssertFingerprintablesNotEqual(const std::shared_ptr<T>& left,
+ const std::shared_ptr<T>& right, bool check_metadata,
+ const char* types_plural) {
+ ASSERT_NE(left, nullptr);
+ ASSERT_NE(right, nullptr);
+ AssertFingerprintablesNotEqual(*left, *right, check_metadata, types_plural);
+}
+
+#define ASSERT_EQUAL_IMPL(NAME, TYPE, PLURAL) \
+ void Assert##NAME##Equal(const TYPE& left, const TYPE& right, bool check_metadata) { \
+ AssertFingerprintablesEqual(left, right, check_metadata, PLURAL); \
+ } \
+ \
+ void Assert##NAME##Equal(const std::shared_ptr<TYPE>& left, \
+ const std::shared_ptr<TYPE>& right, bool check_metadata) { \
+ AssertFingerprintablesEqual(left, right, check_metadata, PLURAL); \
+ } \
+ \
+ void Assert##NAME##NotEqual(const TYPE& left, const TYPE& right, \
+ bool check_metadata) { \
+ AssertFingerprintablesNotEqual(left, right, check_metadata, PLURAL); \
+ } \
+ void Assert##NAME##NotEqual(const std::shared_ptr<TYPE>& left, \
+ const std::shared_ptr<TYPE>& right, bool check_metadata) { \
+ AssertFingerprintablesNotEqual(left, right, check_metadata, PLURAL); \
+ }
+
+ASSERT_EQUAL_IMPL(Type, DataType, "types")
+ASSERT_EQUAL_IMPL(Field, Field, "fields")
+ASSERT_EQUAL_IMPL(Schema, Schema, "schemas")
+#undef ASSERT_EQUAL_IMPL
+
+void AssertDatumsEqual(const Datum& expected, const Datum& actual, bool verbose) {
+ ASSERT_EQ(expected.kind(), actual.kind())
+ << "expected:" << expected.ToString() << " got:" << actual.ToString();
+
+ switch (expected.kind()) {
+ case Datum::SCALAR:
+ AssertScalarsEqual(*expected.scalar(), *actual.scalar(), verbose);
+ break;
+ case Datum::ARRAY: {
+ auto expected_array = expected.make_array();
+ auto actual_array = actual.make_array();
+ AssertArraysEqual(*expected_array, *actual_array, verbose);
+ } break;
+ case Datum::CHUNKED_ARRAY:
+ AssertChunkedEquivalent(*expected.chunked_array(), *actual.chunked_array());
+ break;
+ default:
+ // TODO: Implement better print
+ ASSERT_TRUE(actual.Equals(expected));
+ break;
+ }
+}
+
+void AssertDatumsApproxEqual(const Datum& expected, const Datum& actual, bool verbose,
+ const EqualOptions& options) {
+ ASSERT_EQ(expected.kind(), actual.kind())
+ << "expected:" << expected.ToString() << " got:" << actual.ToString();
+
+ switch (expected.kind()) {
+ case Datum::SCALAR:
+ AssertScalarsApproxEqual(*expected.scalar(), *actual.scalar(), verbose, options);
+ break;
+ case Datum::ARRAY: {
+ auto expected_array = expected.make_array();
+ auto actual_array = actual.make_array();
+ AssertArraysApproxEqual(*expected_array, *actual_array, verbose, options);
+ break;
+ }
+ case Datum::CHUNKED_ARRAY: {
+ auto expected_array = expected.chunked_array();
+ auto actual_array = actual.chunked_array();
+ AssertChunkedApproxEquivalent(*expected_array, *actual_array, options);
+ break;
+ }
+ default:
+ // TODO: Implement better print
+ ASSERT_TRUE(actual.Equals(expected));
+ break;
+ }
+}
+
+std::shared_ptr<Array> ArrayFromJSON(const std::shared_ptr<DataType>& type,
+ util::string_view json) {
+ std::shared_ptr<Array> out;
+ ABORT_NOT_OK(ipc::internal::json::ArrayFromJSON(type, json, &out));
+ return out;
+}
+
+std::shared_ptr<Array> DictArrayFromJSON(const std::shared_ptr<DataType>& type,
+ util::string_view indices_json,
+ util::string_view dictionary_json) {
+ std::shared_ptr<Array> out;
+ ABORT_NOT_OK(
+ ipc::internal::json::DictArrayFromJSON(type, indices_json, dictionary_json, &out));
+ return out;
+}
+
+std::shared_ptr<ChunkedArray> ChunkedArrayFromJSON(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& json) {
+ ArrayVector out_chunks;
+ for (const std::string& chunk_json : json) {
+ out_chunks.push_back(ArrayFromJSON(type, chunk_json));
+ }
+ return std::make_shared<ChunkedArray>(std::move(out_chunks), type);
+}
+
+std::shared_ptr<RecordBatch> RecordBatchFromJSON(const std::shared_ptr<Schema>& schema,
+ util::string_view json) {
+ // Parse as a StructArray
+ auto struct_type = struct_(schema->fields());
+ std::shared_ptr<Array> struct_array = ArrayFromJSON(struct_type, json);
+
+ // Convert StructArray to RecordBatch
+ return *RecordBatch::FromStructArray(struct_array);
+}
+
+std::shared_ptr<Scalar> ScalarFromJSON(const std::shared_ptr<DataType>& type,
+ util::string_view json) {
+ std::shared_ptr<Scalar> out;
+ ABORT_NOT_OK(ipc::internal::json::ScalarFromJSON(type, json, &out));
+ return out;
+}
+
+std::shared_ptr<Scalar> DictScalarFromJSON(const std::shared_ptr<DataType>& type,
+ util::string_view index_json,
+ util::string_view dictionary_json) {
+ std::shared_ptr<Scalar> out;
+ ABORT_NOT_OK(
+ ipc::internal::json::DictScalarFromJSON(type, index_json, dictionary_json, &out));
+ return out;
+}
+
+std::shared_ptr<Table> TableFromJSON(const std::shared_ptr<Schema>& schema,
+ const std::vector<std::string>& json) {
+ std::vector<std::shared_ptr<RecordBatch>> batches;
+ for (const std::string& batch_json : json) {
+ batches.push_back(RecordBatchFromJSON(schema, batch_json));
+ }
+ return *Table::FromRecordBatches(schema, std::move(batches));
+}
+
+Result<util::optional<std::string>> PrintArrayDiff(const ChunkedArray& expected,
+ const ChunkedArray& actual) {
+ if (actual.Equals(expected)) {
+ return util::nullopt;
+ }
+
+ std::stringstream ss;
+ if (expected.length() != actual.length()) {
+ ss << "Expected length " << expected.length() << " but was actually "
+ << actual.length();
+ return ss.str();
+ }
+
+ PrettyPrintOptions options(/*indent=*/2);
+ options.window = 50;
+ RETURN_NOT_OK(internal::ApplyBinaryChunked(
+ actual, expected,
+ [&](const Array& left_piece, const Array& right_piece, int64_t position) {
+ std::stringstream diff;
+ if (!left_piece.Equals(right_piece, EqualOptions().diff_sink(&diff))) {
+ ss << "Unequal at absolute position " << position << "\n" << diff.str();
+ ss << "Expected:\n";
+ ARROW_EXPECT_OK(PrettyPrint(right_piece, options, &ss));
+ ss << "\nActual:\n";
+ ARROW_EXPECT_OK(PrettyPrint(left_piece, options, &ss));
+ }
+ return Status::OK();
+ }));
+ return ss.str();
+}
+
+void AssertTablesEqual(const Table& expected, const Table& actual, bool same_chunk_layout,
+ bool combine_chunks) {
+ ASSERT_EQ(expected.num_columns(), actual.num_columns());
+
+ if (combine_chunks) {
+ auto pool = default_memory_pool();
+ ASSERT_OK_AND_ASSIGN(auto new_expected, expected.CombineChunks(pool));
+ ASSERT_OK_AND_ASSIGN(auto new_actual, actual.CombineChunks(pool));
+
+ AssertTablesEqual(*new_expected, *new_actual, false, false);
+ return;
+ }
+
+ if (same_chunk_layout) {
+ for (int i = 0; i < actual.num_columns(); ++i) {
+ AssertChunkedEqual(*expected.column(i), *actual.column(i));
+ }
+ } else {
+ std::stringstream ss;
+ for (int i = 0; i < actual.num_columns(); ++i) {
+ auto actual_col = actual.column(i);
+ auto expected_col = expected.column(i);
+
+ ASSERT_OK_AND_ASSIGN(auto diff, PrintArrayDiff(*expected_col, *actual_col));
+ if (diff.has_value()) {
+ FAIL() << *diff;
+ }
+ }
+ }
+}
+
+template <typename CompareFunctor>
+void CompareBatchWith(const RecordBatch& left, const RecordBatch& right,
+ bool compare_metadata, CompareFunctor&& compare) {
+ if (!left.schema()->Equals(*right.schema(), compare_metadata)) {
+ FAIL() << "Left schema: " << left.schema()->ToString(compare_metadata)
+ << "\nRight schema: " << right.schema()->ToString(compare_metadata);
+ }
+ ASSERT_EQ(left.num_columns(), right.num_columns())
+ << left.schema()->ToString() << " result: " << right.schema()->ToString();
+ ASSERT_EQ(left.num_rows(), right.num_rows());
+ for (int i = 0; i < left.num_columns(); ++i) {
+ if (!compare(*left.column(i), *right.column(i))) {
+ std::stringstream ss;
+ ss << "Idx: " << i << " Name: " << left.column_name(i);
+ ss << std::endl << "Left: ";
+ ASSERT_OK(PrettyPrint(*left.column(i), 0, &ss));
+ ss << std::endl << "Right: ";
+ ASSERT_OK(PrettyPrint(*right.column(i), 0, &ss));
+ FAIL() << ss.str();
+ }
+ }
+}
+
+void CompareBatch(const RecordBatch& left, const RecordBatch& right,
+ bool compare_metadata) {
+ return CompareBatchWith(
+ left, right, compare_metadata,
+ [](const Array& left, const Array& right) { return left.Equals(right); });
+}
+
+void ApproxCompareBatch(const RecordBatch& left, const RecordBatch& right,
+ bool compare_metadata) {
+ return CompareBatchWith(
+ left, right, compare_metadata,
+ [](const Array& left, const Array& right) { return left.ApproxEquals(right); });
+}
+
+std::shared_ptr<Array> TweakValidityBit(const std::shared_ptr<Array>& array,
+ int64_t index, bool validity) {
+ auto data = array->data()->Copy();
+ if (data->buffers[0] == nullptr) {
+ data->buffers[0] = *AllocateBitmap(data->length);
+ BitUtil::SetBitsTo(data->buffers[0]->mutable_data(), 0, data->length, true);
+ }
+ BitUtil::SetBitTo(data->buffers[0]->mutable_data(), index, validity);
+ data->null_count = kUnknownNullCount;
+ // Need to return a new array, because Array caches the null bitmap pointer
+ return MakeArray(data);
+}
+
+bool LocaleExists(const char* locale) {
+ try {
+ std::locale loc(locale);
+ return true;
+ } catch (std::runtime_error&) {
+ return false;
+ }
+}
+
+class LocaleGuard::Impl {
+ public:
+ explicit Impl(const char* new_locale) : global_locale_(std::locale()) {
+ try {
+ std::locale::global(std::locale(new_locale));
+ } catch (std::runtime_error&) {
+ ARROW_LOG(WARNING) << "Locale unavailable (ignored): '" << new_locale << "'";
+ }
+ }
+
+ ~Impl() { std::locale::global(global_locale_); }
+
+ protected:
+ std::locale global_locale_;
+};
+
+LocaleGuard::LocaleGuard(const char* new_locale) : impl_(new Impl(new_locale)) {}
+
+LocaleGuard::~LocaleGuard() {}
+
+EnvVarGuard::EnvVarGuard(const std::string& name, const std::string& value)
+ : name_(name) {
+ auto maybe_value = arrow::internal::GetEnvVar(name);
+ if (maybe_value.ok()) {
+ was_set_ = true;
+ old_value_ = *std::move(maybe_value);
+ } else {
+ was_set_ = false;
+ }
+ ARROW_CHECK_OK(arrow::internal::SetEnvVar(name, value));
+}
+
+EnvVarGuard::~EnvVarGuard() {
+ if (was_set_) {
+ ARROW_CHECK_OK(arrow::internal::SetEnvVar(name_, old_value_));
+ } else {
+ ARROW_CHECK_OK(arrow::internal::DelEnvVar(name_));
+ }
+}
+
+struct SignalHandlerGuard::Impl {
+ int signum_;
+ internal::SignalHandler old_handler_;
+
+ Impl(int signum, const internal::SignalHandler& handler)
+ : signum_(signum), old_handler_(*internal::SetSignalHandler(signum, handler)) {}
+
+ ~Impl() { ARROW_EXPECT_OK(internal::SetSignalHandler(signum_, old_handler_)); }
+};
+
+SignalHandlerGuard::SignalHandlerGuard(int signum, Callback cb)
+ : SignalHandlerGuard(signum, internal::SignalHandler(cb)) {}
+
+SignalHandlerGuard::SignalHandlerGuard(int signum, const internal::SignalHandler& handler)
+ : impl_(new Impl{signum, handler}) {}
+
+SignalHandlerGuard::~SignalHandlerGuard() = default;
+
+namespace {
+
+// Used to prevent compiler optimizing away side-effect-less statements
+volatile int throw_away = 0;
+
+} // namespace
+
+void AssertZeroPadded(const Array& array) {
+ for (const auto& buffer : array.data()->buffers) {
+ if (buffer) {
+ const int64_t padding = buffer->capacity() - buffer->size();
+ if (padding > 0) {
+ std::vector<uint8_t> zeros(padding);
+ ASSERT_EQ(0, memcmp(buffer->data() + buffer->size(), zeros.data(), padding));
+ }
+ }
+ }
+}
+
+void TestInitialized(const Array& array) { TestInitialized(*array.data()); }
+
+void TestInitialized(const ArrayData& array) {
+ uint8_t total = 0;
+ for (const auto& buffer : array.buffers) {
+ if (buffer && buffer->capacity() > 0) {
+ auto data = buffer->data();
+ for (int64_t i = 0; i < buffer->size(); ++i) {
+ total ^= data[i];
+ }
+ }
+ }
+ uint8_t total_bit = 0;
+ for (uint32_t mask = 1; mask < 256; mask <<= 1) {
+ total_bit ^= (total & mask) != 0;
+ }
+ // This is a dummy condition on all the bits of `total` (which depend on the
+ // entire buffer data). If not all bits are well-defined, Valgrind will
+ // error with "Conditional jump or move depends on uninitialised value(s)".
+ if (total_bit == 0) {
+ ++throw_away;
+ }
+ for (const auto& child : array.child_data) {
+ TestInitialized(*child);
+ }
+ if (array.dictionary) {
+ TestInitialized(*array.dictionary);
+ }
+}
+
+void SleepFor(double seconds) {
+ std::this_thread::sleep_for(
+ std::chrono::nanoseconds(static_cast<int64_t>(seconds * 1e9)));
+}
+
+#ifdef _WIN32
+void SleepABit() {
+ LARGE_INTEGER freq, start, now;
+ QueryPerformanceFrequency(&freq);
+ // 1 ms
+ auto desired = freq.QuadPart / 1000;
+ if (desired <= 0) {
+ // Fallback to STL sleep if high resolution clock not available, tests may fail,
+ // shouldn't really happen
+ SleepFor(1e-3);
+ return;
+ }
+ QueryPerformanceCounter(&start);
+ while (true) {
+ std::this_thread::yield();
+ QueryPerformanceCounter(&now);
+ auto elapsed = now.QuadPart - start.QuadPart;
+ if (elapsed > desired) {
+ break;
+ }
+ }
+}
+#else
+// std::this_thread::sleep_for should be high enough resolution on non-Windows systems
+void SleepABit() { SleepFor(1e-3); }
+#endif
+
+void BusyWait(double seconds, std::function<bool()> predicate) {
+ const double period = 0.001;
+ for (int i = 0; !predicate() && i * period < seconds; ++i) {
+ SleepFor(period);
+ }
+}
+
+Future<> SleepAsync(double seconds) {
+ auto out = Future<>::Make();
+ std::thread([out, seconds]() mutable {
+ SleepFor(seconds);
+ out.MarkFinished();
+ }).detach();
+ return out;
+}
+
+Future<> SleepABitAsync() {
+ auto out = Future<>::Make();
+ std::thread([out]() mutable {
+ SleepABit();
+ out.MarkFinished();
+ }).detach();
+ return out;
+}
+
+///////////////////////////////////////////////////////////////////////////
+// Extension types
+
+bool UuidType::ExtensionEquals(const ExtensionType& other) const {
+ return (other.extension_name() == this->extension_name());
+}
+
+std::shared_ptr<Array> UuidType::MakeArray(std::shared_ptr<ArrayData> data) const {
+ DCHECK_EQ(data->type->id(), Type::EXTENSION);
+ DCHECK_EQ("uuid", static_cast<const ExtensionType&>(*data->type).extension_name());
+ return std::make_shared<UuidArray>(data);
+}
+
+Result<std::shared_ptr<DataType>> UuidType::Deserialize(
+ std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
+ if (serialized != "uuid-serialized") {
+ return Status::Invalid("Type identifier did not match: '", serialized, "'");
+ }
+ if (!storage_type->Equals(*fixed_size_binary(16))) {
+ return Status::Invalid("Invalid storage type for UuidType: ",
+ storage_type->ToString());
+ }
+ return std::make_shared<UuidType>();
+}
+
+bool SmallintType::ExtensionEquals(const ExtensionType& other) const {
+ return (other.extension_name() == this->extension_name());
+}
+
+std::shared_ptr<Array> SmallintType::MakeArray(std::shared_ptr<ArrayData> data) const {
+ DCHECK_EQ(data->type->id(), Type::EXTENSION);
+ DCHECK_EQ("smallint", static_cast<const ExtensionType&>(*data->type).extension_name());
+ return std::make_shared<SmallintArray>(data);
+}
+
+Result<std::shared_ptr<DataType>> SmallintType::Deserialize(
+ std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
+ if (serialized != "smallint") {
+ return Status::Invalid("Type identifier did not match: '", serialized, "'");
+ }
+ if (!storage_type->Equals(*int16())) {
+ return Status::Invalid("Invalid storage type for SmallintType: ",
+ storage_type->ToString());
+ }
+ return std::make_shared<SmallintType>();
+}
+
+bool DictExtensionType::ExtensionEquals(const ExtensionType& other) const {
+ return (other.extension_name() == this->extension_name());
+}
+
+std::shared_ptr<Array> DictExtensionType::MakeArray(
+ std::shared_ptr<ArrayData> data) const {
+ DCHECK_EQ(data->type->id(), Type::EXTENSION);
+ DCHECK(ExtensionEquals(checked_cast<const ExtensionType&>(*data->type)));
+ // No need for a specific ExtensionArray derived class
+ return std::make_shared<ExtensionArray>(data);
+}
+
+Result<std::shared_ptr<DataType>> DictExtensionType::Deserialize(
+ std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
+ if (serialized != "dict-extension-serialized") {
+ return Status::Invalid("Type identifier did not match: '", serialized, "'");
+ }
+ if (!storage_type->Equals(*storage_type_)) {
+ return Status::Invalid("Invalid storage type for DictExtensionType: ",
+ storage_type->ToString());
+ }
+ return std::make_shared<DictExtensionType>();
+}
+
+bool Complex128Type::ExtensionEquals(const ExtensionType& other) const {
+ return (other.extension_name() == this->extension_name());
+}
+
+std::shared_ptr<Array> Complex128Type::MakeArray(std::shared_ptr<ArrayData> data) const {
+ DCHECK_EQ(data->type->id(), Type::EXTENSION);
+ DCHECK(ExtensionEquals(checked_cast<const ExtensionType&>(*data->type)));
+ return std::make_shared<Complex128Array>(data);
+}
+
+Result<std::shared_ptr<DataType>> Complex128Type::Deserialize(
+ std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
+ if (serialized != "complex128-serialized") {
+ return Status::Invalid("Type identifier did not match: '", serialized, "'");
+ }
+ if (!storage_type->Equals(*storage_type_)) {
+ return Status::Invalid("Invalid storage type for Complex128Type: ",
+ storage_type->ToString());
+ }
+ return std::make_shared<Complex128Type>();
+}
+
+std::shared_ptr<DataType> uuid() { return std::make_shared<UuidType>(); }
+
+std::shared_ptr<DataType> smallint() { return std::make_shared<SmallintType>(); }
+
+std::shared_ptr<DataType> dict_extension_type() {
+ return std::make_shared<DictExtensionType>();
+}
+
+std::shared_ptr<DataType> complex128() { return std::make_shared<Complex128Type>(); }
+
+std::shared_ptr<Array> MakeComplex128(const std::shared_ptr<Array>& real,
+ const std::shared_ptr<Array>& imag) {
+ auto type = complex128();
+ std::shared_ptr<Array> storage(
+ new StructArray(checked_cast<const ExtensionType&>(*type).storage_type(),
+ real->length(), {real, imag}));
+ return ExtensionType::WrapArray(type, storage);
+}
+
+std::shared_ptr<Array> ExampleUuid() {
+ auto arr = ArrayFromJSON(
+ fixed_size_binary(16),
+ "[null, \"abcdefghijklmno0\", \"abcdefghijklmno1\", \"abcdefghijklmno2\"]");
+ return ExtensionType::WrapArray(uuid(), arr);
+}
+
+std::shared_ptr<Array> ExampleSmallint() {
+ auto arr = ArrayFromJSON(int16(), "[-32768, null, 1, 2, 3, 4, 32767]");
+ return ExtensionType::WrapArray(smallint(), arr);
+}
+
+std::shared_ptr<Array> ExampleDictExtension() {
+ auto arr = DictArrayFromJSON(dictionary(int8(), utf8()), "[0, 1, null, 1]",
+ R"(["foo", "bar"])");
+ return ExtensionType::WrapArray(dict_extension_type(), arr);
+}
+
+std::shared_ptr<Array> ExampleComplex128() {
+ auto arr = ArrayFromJSON(struct_({field("", float64()), field("", float64())}),
+ "[[1.0, -2.5], null, [3.0, -4.5]]");
+ return ExtensionType::WrapArray(complex128(), arr);
+}
+
+ExtensionTypeGuard::ExtensionTypeGuard(const std::shared_ptr<DataType>& type)
+ : ExtensionTypeGuard(DataTypeVector{type}) {}
+
+ExtensionTypeGuard::ExtensionTypeGuard(const DataTypeVector& types) {
+ for (const auto& type : types) {
+ ARROW_CHECK_EQ(type->id(), Type::EXTENSION);
+ auto ext_type = checked_pointer_cast<ExtensionType>(type);
+
+ ARROW_CHECK_OK(RegisterExtensionType(ext_type));
+ extension_names_.push_back(ext_type->extension_name());
+ DCHECK(!extension_names_.back().empty());
+ }
+}
+
+ExtensionTypeGuard::~ExtensionTypeGuard() {
+ for (const auto& name : extension_names_) {
+ ARROW_CHECK_OK(UnregisterExtensionType(name));
+ }
+}
+
+class GatingTask::Impl : public std::enable_shared_from_this<GatingTask::Impl> {
+ public:
+ explicit Impl(double timeout_seconds)
+ : timeout_seconds_(timeout_seconds), status_(), unlocked_(false) {
+ unlocked_future_ = Future<>::Make();
+ }
+
+ ~Impl() {
+ if (num_running_ != num_launched_) {
+ ADD_FAILURE()
+ << "A GatingTask instance was destroyed but some underlying tasks did not "
+ "start running"
+ << std::endl;
+ } else if (num_finished_ != num_launched_) {
+ ADD_FAILURE()
+ << "A GatingTask instance was destroyed but some underlying tasks did not "
+ "finish running"
+ << std::endl;
+ }
+ }
+
+ std::function<void()> Task() {
+ num_launched_++;
+ auto self = shared_from_this();
+ return [self] { self->RunTask(); };
+ }
+
+ Future<> AsyncTask() {
+ num_launched_++;
+ num_running_++;
+ /// TODO(ARROW-13004) Could maybe implement this check with future chains
+ /// if we check to see if the future has been "consumed" or not
+ num_finished_++;
+ return unlocked_future_;
+ }
+
+ void RunTask() {
+ std::unique_lock<std::mutex> lk(mx_);
+ num_running_++;
+ running_cv_.notify_all();
+ if (!unlocked_cv_.wait_for(
+ lk, std::chrono::nanoseconds(static_cast<int64_t>(timeout_seconds_ * 1e9)),
+ [this] { return unlocked_; })) {
+ status_ &= Status::Invalid("Timed out (" + std::to_string(timeout_seconds_) + "," +
+ std::to_string(unlocked_) +
+ " seconds) waiting for the gating task to be unlocked");
+ }
+ num_finished_++;
+ }
+
+ Status WaitForRunning(int count) {
+ std::unique_lock<std::mutex> lk(mx_);
+ if (running_cv_.wait_for(
+ lk, std::chrono::nanoseconds(static_cast<int64_t>(timeout_seconds_ * 1e9)),
+ [this, count] { return num_running_ >= count; })) {
+ return Status::OK();
+ }
+ return Status::Invalid("Timed out waiting for tasks to launch");
+ }
+
+ Status Unlock() {
+ std::lock_guard<std::mutex> lk(mx_);
+ unlocked_ = true;
+ unlocked_cv_.notify_all();
+ unlocked_future_.MarkFinished();
+ return status_;
+ }
+
+ private:
+ double timeout_seconds_;
+ Status status_;
+ bool unlocked_;
+ std::atomic<int> num_launched_{0};
+ int num_running_ = 0;
+ int num_finished_ = 0;
+ std::mutex mx_;
+ std::condition_variable running_cv_;
+ std::condition_variable unlocked_cv_;
+ Future<> unlocked_future_;
+};
+
+GatingTask::GatingTask(double timeout_seconds) : impl_(new Impl(timeout_seconds)) {}
+
+GatingTask::~GatingTask() {}
+
+std::function<void()> GatingTask::Task() { return impl_->Task(); }
+
+Future<> GatingTask::AsyncTask() { return impl_->AsyncTask(); }
+
+Status GatingTask::Unlock() { return impl_->Unlock(); }
+
+Status GatingTask::WaitForRunning(int count) { return impl_->WaitForRunning(count); }
+
+std::shared_ptr<GatingTask> GatingTask::Make(double timeout_seconds) {
+ return std::make_shared<GatingTask>(timeout_seconds);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/gtest_util.h b/src/arrow/cpp/src/arrow/testing/gtest_util.h
new file mode 100644
index 000000000..da145bdfa
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/gtest_util.h
@@ -0,0 +1,691 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <functional>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/builder_time.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_compat.h"
+#include "arrow/testing/util.h"
+#include "arrow/testing/visibility.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_builder.h"
+#include "arrow/util/type_fwd.h"
+
+// NOTE: failing must be inline in the macros below, to get correct file / line number
+// reporting on test failures.
+
+// NOTE: using a for loop for this macro allows extra failure messages to be
+// appended with operator<<
+#define ASSERT_RAISES(ENUM, expr) \
+ for (::arrow::Status _st = ::arrow::internal::GenericToStatus((expr)); \
+ !_st.Is##ENUM();) \
+ FAIL() << "Expected '" ARROW_STRINGIFY(expr) "' to fail with " ARROW_STRINGIFY( \
+ ENUM) ", but got " \
+ << _st.ToString()
+
+#define ASSERT_RAISES_WITH_MESSAGE(ENUM, message, expr) \
+ do { \
+ auto _res = (expr); \
+ ::arrow::Status _st = ::arrow::internal::GenericToStatus(_res); \
+ if (!_st.Is##ENUM()) { \
+ FAIL() << "Expected '" ARROW_STRINGIFY(expr) "' to fail with " ARROW_STRINGIFY( \
+ ENUM) ", but got " \
+ << _st.ToString(); \
+ } \
+ ASSERT_EQ((message), _st.ToString()); \
+ } while (false)
+
+#define EXPECT_RAISES_WITH_MESSAGE_THAT(ENUM, matcher, expr) \
+ do { \
+ auto _res = (expr); \
+ ::arrow::Status _st = ::arrow::internal::GenericToStatus(_res); \
+ EXPECT_TRUE(_st.Is##ENUM()) << "Expected '" ARROW_STRINGIFY(expr) "' to fail with " \
+ << ARROW_STRINGIFY(ENUM) ", but got " << _st.ToString(); \
+ EXPECT_THAT(_st.ToString(), (matcher)); \
+ } while (false)
+
+#define EXPECT_RAISES_WITH_CODE_AND_MESSAGE_THAT(code, matcher, expr) \
+ do { \
+ auto _res = (expr); \
+ ::arrow::Status _st = ::arrow::internal::GenericToStatus(_res); \
+ EXPECT_EQ(_st.CodeAsString(), Status::CodeAsString(code)); \
+ EXPECT_THAT(_st.ToString(), (matcher)); \
+ } while (false)
+
+#define ASSERT_OK(expr) \
+ for (::arrow::Status _st = ::arrow::internal::GenericToStatus((expr)); !_st.ok();) \
+ FAIL() << "'" ARROW_STRINGIFY(expr) "' failed with " << _st.ToString()
+
+#define ASSERT_OK_NO_THROW(expr) ASSERT_NO_THROW(ASSERT_OK(expr))
+
+#define ARROW_EXPECT_OK(expr) \
+ do { \
+ auto _res = (expr); \
+ ::arrow::Status _st = ::arrow::internal::GenericToStatus(_res); \
+ EXPECT_TRUE(_st.ok()) << "'" ARROW_STRINGIFY(expr) "' failed with " \
+ << _st.ToString(); \
+ } while (false)
+
+#define ASSERT_NOT_OK(expr) \
+ for (::arrow::Status _st = ::arrow::internal::GenericToStatus((expr)); _st.ok();) \
+ FAIL() << "'" ARROW_STRINGIFY(expr) "' did not failed" << _st.ToString()
+
+#define ABORT_NOT_OK(expr) \
+ do { \
+ auto _res = (expr); \
+ ::arrow::Status _st = ::arrow::internal::GenericToStatus(_res); \
+ if (ARROW_PREDICT_FALSE(!_st.ok())) { \
+ _st.Abort(); \
+ } \
+ } while (false);
+
+#define ASSIGN_OR_HANDLE_ERROR_IMPL(handle_error, status_name, lhs, rexpr) \
+ auto&& status_name = (rexpr); \
+ handle_error(status_name.status()); \
+ lhs = std::move(status_name).ValueOrDie();
+
+#define ASSERT_OK_AND_ASSIGN(lhs, rexpr) \
+ ASSIGN_OR_HANDLE_ERROR_IMPL( \
+ ASSERT_OK, ARROW_ASSIGN_OR_RAISE_NAME(_error_or_value, __COUNTER__), lhs, rexpr);
+
+#define ASSIGN_OR_ABORT(lhs, rexpr) \
+ ASSIGN_OR_HANDLE_ERROR_IMPL(ABORT_NOT_OK, \
+ ARROW_ASSIGN_OR_RAISE_NAME(_error_or_value, __COUNTER__), \
+ lhs, rexpr);
+
+#define EXPECT_OK_AND_ASSIGN(lhs, rexpr) \
+ ASSIGN_OR_HANDLE_ERROR_IMPL(ARROW_EXPECT_OK, \
+ ARROW_ASSIGN_OR_RAISE_NAME(_error_or_value, __COUNTER__), \
+ lhs, rexpr);
+
+#define ASSERT_OK_AND_EQ(expected, expr) \
+ do { \
+ ASSERT_OK_AND_ASSIGN(auto _actual, (expr)); \
+ ASSERT_EQ(expected, _actual); \
+ } while (0)
+
+// A generalized version of GTest's SCOPED_TRACE that takes arbitrary arguments.
+// ARROW_SCOPED_TRACE("some variable = ", some_variable, ...)
+
+#define ARROW_SCOPED_TRACE(...) SCOPED_TRACE(::arrow::util::StringBuilder(__VA_ARGS__))
+
+namespace arrow {
+
+// ----------------------------------------------------------------------
+// Useful testing::Types declarations
+
+inline void PrintTo(StatusCode code, std::ostream* os) {
+ *os << Status::CodeAsString(code);
+}
+
+using NumericArrowTypes =
+ ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
+ Int32Type, Int64Type, FloatType, DoubleType>;
+
+using RealArrowTypes = ::testing::Types<FloatType, DoubleType>;
+
+using IntegralArrowTypes = ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type,
+ Int8Type, Int16Type, Int32Type, Int64Type>;
+
+using PhysicalIntegralArrowTypes =
+ ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
+ Int32Type, Int64Type, Date32Type, Date64Type, Time32Type, Time64Type,
+ TimestampType, MonthIntervalType>;
+
+using PrimitiveArrowTypes =
+ ::testing::Types<BooleanType, Int8Type, UInt8Type, Int16Type, UInt16Type, Int32Type,
+ UInt32Type, Int64Type, UInt64Type, FloatType, DoubleType>;
+
+using TemporalArrowTypes =
+ ::testing::Types<Date32Type, Date64Type, TimestampType, Time32Type, Time64Type>;
+
+using DecimalArrowTypes = ::testing::Types<Decimal128Type, Decimal256Type>;
+
+using BinaryArrowTypes =
+ ::testing::Types<BinaryType, LargeBinaryType, StringType, LargeStringType>;
+
+using StringArrowTypes = ::testing::Types<StringType, LargeStringType>;
+
+using ListArrowTypes = ::testing::Types<ListType, LargeListType>;
+
+using UnionArrowTypes = ::testing::Types<SparseUnionType, DenseUnionType>;
+
+class Array;
+class ChunkedArray;
+class RecordBatch;
+class Table;
+struct Datum;
+
+ARROW_TESTING_EXPORT
+std::vector<Type::type> AllTypeIds();
+
+#define ASSERT_ARRAYS_EQUAL(lhs, rhs) AssertArraysEqual((lhs), (rhs))
+#define ASSERT_BATCHES_EQUAL(lhs, rhs) AssertBatchesEqual((lhs), (rhs))
+#define ASSERT_BATCHES_APPROX_EQUAL(lhs, rhs) AssertBatchesApproxEqual((lhs), (rhs))
+#define ASSERT_TABLES_EQUAL(lhs, rhs) AssertTablesEqual((lhs), (rhs))
+
+// If verbose is true, then the arrays will be pretty printed
+ARROW_TESTING_EXPORT void AssertArraysEqual(const Array& expected, const Array& actual,
+ bool verbose = false,
+ const EqualOptions& options = {});
+ARROW_TESTING_EXPORT void AssertArraysApproxEqual(const Array& expected,
+ const Array& actual,
+ bool verbose = false,
+ const EqualOptions& options = {});
+// Returns true when values are both null
+ARROW_TESTING_EXPORT void AssertScalarsEqual(
+ const Scalar& expected, const Scalar& actual, bool verbose = false,
+ const EqualOptions& options = EqualOptions::Defaults());
+ARROW_TESTING_EXPORT void AssertScalarsApproxEqual(
+ const Scalar& expected, const Scalar& actual, bool verbose = false,
+ const EqualOptions& options = EqualOptions::Defaults());
+ARROW_TESTING_EXPORT void AssertBatchesEqual(const RecordBatch& expected,
+ const RecordBatch& actual,
+ bool check_metadata = false);
+ARROW_TESTING_EXPORT void AssertBatchesApproxEqual(const RecordBatch& expected,
+ const RecordBatch& actual);
+ARROW_TESTING_EXPORT void AssertChunkedEqual(const ChunkedArray& expected,
+ const ChunkedArray& actual);
+ARROW_TESTING_EXPORT void AssertChunkedEqual(const ChunkedArray& actual,
+ const ArrayVector& expected);
+// Like ChunkedEqual, but permits different chunk layout
+ARROW_TESTING_EXPORT void AssertChunkedEquivalent(const ChunkedArray& expected,
+ const ChunkedArray& actual);
+ARROW_TESTING_EXPORT void AssertChunkedApproxEquivalent(
+ const ChunkedArray& expected, const ChunkedArray& actual,
+ const EqualOptions& equal_options = EqualOptions::Defaults());
+ARROW_TESTING_EXPORT void AssertBufferEqual(const Buffer& buffer,
+ const std::vector<uint8_t>& expected);
+ARROW_TESTING_EXPORT void AssertBufferEqual(const Buffer& buffer,
+ const std::string& expected);
+ARROW_TESTING_EXPORT void AssertBufferEqual(const Buffer& buffer, const Buffer& expected);
+
+ARROW_TESTING_EXPORT void AssertTypeEqual(const DataType& lhs, const DataType& rhs,
+ bool check_metadata = false);
+ARROW_TESTING_EXPORT void AssertTypeEqual(const std::shared_ptr<DataType>& lhs,
+ const std::shared_ptr<DataType>& rhs,
+ bool check_metadata = false);
+ARROW_TESTING_EXPORT void AssertFieldEqual(const Field& lhs, const Field& rhs,
+ bool check_metadata = false);
+ARROW_TESTING_EXPORT void AssertFieldEqual(const std::shared_ptr<Field>& lhs,
+ const std::shared_ptr<Field>& rhs,
+ bool check_metadata = false);
+ARROW_TESTING_EXPORT void AssertSchemaEqual(const Schema& lhs, const Schema& rhs,
+ bool check_metadata = false);
+ARROW_TESTING_EXPORT void AssertSchemaEqual(const std::shared_ptr<Schema>& lhs,
+ const std::shared_ptr<Schema>& rhs,
+ bool check_metadata = false);
+
+ARROW_TESTING_EXPORT void AssertTypeNotEqual(const DataType& lhs, const DataType& rhs,
+ bool check_metadata = false);
+ARROW_TESTING_EXPORT void AssertTypeNotEqual(const std::shared_ptr<DataType>& lhs,
+ const std::shared_ptr<DataType>& rhs,
+ bool check_metadata = false);
+ARROW_TESTING_EXPORT void AssertFieldNotEqual(const Field& lhs, const Field& rhs,
+ bool check_metadata = false);
+ARROW_TESTING_EXPORT void AssertFieldNotEqual(const std::shared_ptr<Field>& lhs,
+ const std::shared_ptr<Field>& rhs,
+ bool check_metadata = false);
+ARROW_TESTING_EXPORT void AssertSchemaNotEqual(const Schema& lhs, const Schema& rhs,
+ bool check_metadata = false);
+ARROW_TESTING_EXPORT void AssertSchemaNotEqual(const std::shared_ptr<Schema>& lhs,
+ const std::shared_ptr<Schema>& rhs,
+ bool check_metadata = false);
+
+ARROW_TESTING_EXPORT Result<util::optional<std::string>> PrintArrayDiff(
+ const ChunkedArray& expected, const ChunkedArray& actual);
+
+ARROW_TESTING_EXPORT void AssertTablesEqual(const Table& expected, const Table& actual,
+ bool same_chunk_layout = true,
+ bool flatten = false);
+
+ARROW_TESTING_EXPORT void AssertDatumsEqual(const Datum& expected, const Datum& actual,
+ bool verbose = false);
+ARROW_TESTING_EXPORT void AssertDatumsApproxEqual(
+ const Datum& expected, const Datum& actual, bool verbose = false,
+ const EqualOptions& options = EqualOptions::Defaults());
+
+template <typename C_TYPE>
+void AssertNumericDataEqual(const C_TYPE* raw_data,
+ const std::vector<C_TYPE>& expected_values) {
+ for (auto expected : expected_values) {
+ ASSERT_EQ(expected, *raw_data);
+ ++raw_data;
+ }
+}
+
+ARROW_TESTING_EXPORT void CompareBatch(const RecordBatch& left, const RecordBatch& right,
+ bool compare_metadata = true);
+
+ARROW_TESTING_EXPORT void ApproxCompareBatch(const RecordBatch& left,
+ const RecordBatch& right,
+ bool compare_metadata = true);
+
+// Check if the padding of the buffers of the array is zero.
+// Also cause valgrind warnings if the padding bytes are uninitialized.
+ARROW_TESTING_EXPORT void AssertZeroPadded(const Array& array);
+
+// Check if the valid buffer bytes are initialized
+// and cause valgrind warnings otherwise.
+ARROW_TESTING_EXPORT void TestInitialized(const ArrayData& array);
+ARROW_TESTING_EXPORT void TestInitialized(const Array& array);
+
+template <typename BuilderType>
+void FinishAndCheckPadding(BuilderType* builder, std::shared_ptr<Array>* out) {
+ ASSERT_OK_AND_ASSIGN(*out, builder->Finish());
+ AssertZeroPadded(**out);
+ TestInitialized(**out);
+}
+
+#define DECL_T() typedef typename TestFixture::T T;
+
+#define DECL_TYPE() typedef typename TestFixture::Type Type;
+
+// ArrayFromJSON: construct an Array from a simple JSON representation
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<Array> ArrayFromJSON(const std::shared_ptr<DataType>&,
+ util::string_view json);
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<Array> DictArrayFromJSON(const std::shared_ptr<DataType>& type,
+ util::string_view indices_json,
+ util::string_view dictionary_json);
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<RecordBatch> RecordBatchFromJSON(const std::shared_ptr<Schema>&,
+ util::string_view);
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<ChunkedArray> ChunkedArrayFromJSON(const std::shared_ptr<DataType>&,
+ const std::vector<std::string>& json);
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<Scalar> ScalarFromJSON(const std::shared_ptr<DataType>&,
+ util::string_view json);
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<Scalar> DictScalarFromJSON(const std::shared_ptr<DataType>&,
+ util::string_view index_json,
+ util::string_view dictionary_json);
+
+ARROW_TESTING_EXPORT
+std::shared_ptr<Table> TableFromJSON(const std::shared_ptr<Schema>&,
+ const std::vector<std::string>& json);
+
+// ArrayFromVector: construct an Array from vectors of C values
+
+template <typename TYPE, typename C_TYPE = typename TYPE::c_type>
+void ArrayFromVector(const std::shared_ptr<DataType>& type,
+ const std::vector<bool>& is_valid, const std::vector<C_TYPE>& values,
+ std::shared_ptr<Array>* out) {
+ auto type_id = TYPE::type_id;
+ ASSERT_EQ(type_id, type->id())
+ << "template parameter and concrete DataType instance don't agree";
+
+ std::unique_ptr<ArrayBuilder> builder_ptr;
+ ASSERT_OK(MakeBuilder(default_memory_pool(), type, &builder_ptr));
+ // Get the concrete builder class to access its Append() specializations
+ auto& builder = dynamic_cast<typename TypeTraits<TYPE>::BuilderType&>(*builder_ptr);
+
+ for (size_t i = 0; i < values.size(); ++i) {
+ if (is_valid[i]) {
+ ASSERT_OK(builder.Append(values[i]));
+ } else {
+ ASSERT_OK(builder.AppendNull());
+ }
+ }
+ ASSERT_OK(builder.Finish(out));
+}
+
+template <typename TYPE, typename C_TYPE = typename TYPE::c_type>
+void ArrayFromVector(const std::shared_ptr<DataType>& type,
+ const std::vector<C_TYPE>& values, std::shared_ptr<Array>* out) {
+ auto type_id = TYPE::type_id;
+ ASSERT_EQ(type_id, type->id())
+ << "template parameter and concrete DataType instance don't agree";
+
+ std::unique_ptr<ArrayBuilder> builder_ptr;
+ ASSERT_OK(MakeBuilder(default_memory_pool(), type, &builder_ptr));
+ // Get the concrete builder class to access its Append() specializations
+ auto& builder = dynamic_cast<typename TypeTraits<TYPE>::BuilderType&>(*builder_ptr);
+
+ for (size_t i = 0; i < values.size(); ++i) {
+ ASSERT_OK(builder.Append(values[i]));
+ }
+ ASSERT_OK(builder.Finish(out));
+}
+
+// Overloads without a DataType argument, for parameterless types
+
+template <typename TYPE, typename C_TYPE = typename TYPE::c_type>
+void ArrayFromVector(const std::vector<bool>& is_valid, const std::vector<C_TYPE>& values,
+ std::shared_ptr<Array>* out) {
+ auto type = TypeTraits<TYPE>::type_singleton();
+ ArrayFromVector<TYPE, C_TYPE>(type, is_valid, values, out);
+}
+
+template <typename TYPE, typename C_TYPE = typename TYPE::c_type>
+void ArrayFromVector(const std::vector<C_TYPE>& values, std::shared_ptr<Array>* out) {
+ auto type = TypeTraits<TYPE>::type_singleton();
+ ArrayFromVector<TYPE, C_TYPE>(type, values, out);
+}
+
+// ChunkedArrayFromVector: construct a ChunkedArray from vectors of C values
+
+template <typename TYPE, typename C_TYPE = typename TYPE::c_type>
+void ChunkedArrayFromVector(const std::shared_ptr<DataType>& type,
+ const std::vector<std::vector<bool>>& is_valid,
+ const std::vector<std::vector<C_TYPE>>& values,
+ std::shared_ptr<ChunkedArray>* out) {
+ ArrayVector chunks;
+ ASSERT_EQ(is_valid.size(), values.size());
+ for (size_t i = 0; i < values.size(); ++i) {
+ std::shared_ptr<Array> array;
+ ArrayFromVector<TYPE, C_TYPE>(type, is_valid[i], values[i], &array);
+ chunks.push_back(array);
+ }
+ *out = std::make_shared<ChunkedArray>(chunks);
+}
+
+template <typename TYPE, typename C_TYPE = typename TYPE::c_type>
+void ChunkedArrayFromVector(const std::shared_ptr<DataType>& type,
+ const std::vector<std::vector<C_TYPE>>& values,
+ std::shared_ptr<ChunkedArray>* out) {
+ ArrayVector chunks;
+ for (size_t i = 0; i < values.size(); ++i) {
+ std::shared_ptr<Array> array;
+ ArrayFromVector<TYPE, C_TYPE>(type, values[i], &array);
+ chunks.push_back(array);
+ }
+ *out = std::make_shared<ChunkedArray>(chunks);
+}
+
+// Overloads without a DataType argument, for parameterless types
+
+template <typename TYPE, typename C_TYPE = typename TYPE::c_type>
+void ChunkedArrayFromVector(const std::vector<std::vector<bool>>& is_valid,
+ const std::vector<std::vector<C_TYPE>>& values,
+ std::shared_ptr<ChunkedArray>* out) {
+ auto type = TypeTraits<TYPE>::type_singleton();
+ ChunkedArrayFromVector<TYPE, C_TYPE>(type, is_valid, values, out);
+}
+
+template <typename TYPE, typename C_TYPE = typename TYPE::c_type>
+void ChunkedArrayFromVector(const std::vector<std::vector<C_TYPE>>& values,
+ std::shared_ptr<ChunkedArray>* out) {
+ auto type = TypeTraits<TYPE>::type_singleton();
+ ChunkedArrayFromVector<TYPE, C_TYPE>(type, values, out);
+}
+
+template <typename T>
+static inline Status GetBitmapFromVector(const std::vector<T>& is_valid,
+ std::shared_ptr<Buffer>* result) {
+ size_t length = is_valid.size();
+
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateEmptyBitmap(length));
+
+ uint8_t* bitmap = buffer->mutable_data();
+ for (size_t i = 0; i < static_cast<size_t>(length); ++i) {
+ if (is_valid[i]) {
+ BitUtil::SetBit(bitmap, i);
+ }
+ }
+
+ *result = buffer;
+ return Status::OK();
+}
+
+template <typename T>
+inline void BitmapFromVector(const std::vector<T>& is_valid,
+ std::shared_ptr<Buffer>* out) {
+ ASSERT_OK(GetBitmapFromVector(is_valid, out));
+}
+
+// Given an array, return a new identical array except for one validity bit
+// set to a new value.
+// This is useful to force the underlying "value" of null entries to otherwise
+// invalid data and check that errors don't get reported.
+ARROW_TESTING_EXPORT
+std::shared_ptr<Array> TweakValidityBit(const std::shared_ptr<Array>& array,
+ int64_t index, bool validity);
+
+ARROW_TESTING_EXPORT
+void SleepFor(double seconds);
+
+// Sleeps for a very small amount of time. The thread will be yielded
+// at least once ensuring that context switches could happen. It is intended
+// to be used for stress testing parallel code and shouldn't be assumed to do any
+// reliable timing.
+ARROW_TESTING_EXPORT
+void SleepABit();
+
+// Wait until predicate is true or timeout in seconds expires.
+ARROW_TESTING_EXPORT
+void BusyWait(double seconds, std::function<bool()> predicate);
+
+ARROW_TESTING_EXPORT
+Future<> SleepAsync(double seconds);
+
+// \see SleepABit
+ARROW_TESTING_EXPORT
+Future<> SleepABitAsync();
+
+template <typename T>
+std::vector<T> IteratorToVector(Iterator<T> iterator) {
+ EXPECT_OK_AND_ASSIGN(auto out, iterator.ToVector());
+ return out;
+}
+
+ARROW_TESTING_EXPORT
+bool LocaleExists(const char* locale);
+
+// A RAII-style object that switches to a new locale, and switches back
+// to the old locale when going out of scope. Doesn't do anything if the
+// new locale doesn't exist on the local machine.
+// ATTENTION: may crash with an assertion failure on Windows debug builds.
+// See ARROW-6108, also https://gerrit.libreoffice.org/#/c/54110/
+class ARROW_TESTING_EXPORT LocaleGuard {
+ public:
+ explicit LocaleGuard(const char* new_locale);
+ ~LocaleGuard();
+
+ protected:
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+class ARROW_TESTING_EXPORT EnvVarGuard {
+ public:
+ EnvVarGuard(const std::string& name, const std::string& value);
+ ~EnvVarGuard();
+
+ protected:
+ const std::string name_;
+ std::string old_value_;
+ bool was_set_;
+};
+
+namespace internal {
+class SignalHandler;
+}
+
+class ARROW_TESTING_EXPORT SignalHandlerGuard {
+ public:
+ typedef void (*Callback)(int);
+
+ SignalHandlerGuard(int signum, Callback cb);
+ SignalHandlerGuard(int signum, const internal::SignalHandler& handler);
+ ~SignalHandlerGuard();
+
+ protected:
+ struct Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+#ifndef ARROW_LARGE_MEMORY_TESTS
+#define LARGE_MEMORY_TEST(name) DISABLED_##name
+#else
+#define LARGE_MEMORY_TEST(name) name
+#endif
+
+inline void PrintTo(const Status& st, std::ostream* os) { *os << st.ToString(); }
+
+template <typename T>
+void PrintTo(const Result<T>& result, std::ostream* os) {
+ if (result.ok()) {
+ ::testing::internal::UniversalPrint(result.ValueOrDie(), os);
+ } else {
+ *os << result.status();
+ }
+}
+
+// A data type with only move constructors (no copy, no default).
+struct MoveOnlyDataType {
+ explicit MoveOnlyDataType(int x) : data(new int(x)) {}
+
+ MoveOnlyDataType(const MoveOnlyDataType& other) = delete;
+ MoveOnlyDataType& operator=(const MoveOnlyDataType& other) = delete;
+
+ MoveOnlyDataType(MoveOnlyDataType&& other) { MoveFrom(&other); }
+ MoveOnlyDataType& operator=(MoveOnlyDataType&& other) {
+ MoveFrom(&other);
+ return *this;
+ }
+
+ MoveOnlyDataType& operator=(int x) {
+ if (data != nullptr) {
+ delete data;
+ }
+ data = new int(x);
+ return *this;
+ }
+
+ ~MoveOnlyDataType() { Destroy(); }
+
+ void Destroy() {
+ if (data != nullptr) {
+ delete data;
+ data = nullptr;
+ moves = -1;
+ }
+ }
+
+ void MoveFrom(MoveOnlyDataType* other) {
+ Destroy();
+ data = other->data;
+ other->data = nullptr;
+ moves = other->moves + 1;
+ }
+
+ int ToInt() const { return data == nullptr ? -42 : *data; }
+
+ bool operator==(const MoveOnlyDataType& other) const {
+ return data != nullptr && other.data != nullptr && *data == *other.data;
+ }
+ bool operator<(const MoveOnlyDataType& other) const {
+ return data == nullptr || (other.data != nullptr && *data < *other.data);
+ }
+
+ bool operator==(int other) const { return data != nullptr && *data == other; }
+ friend bool operator==(int left, const MoveOnlyDataType& right) {
+ return right == left;
+ }
+
+ int* data = nullptr;
+ int moves = 0;
+};
+
+// A task that blocks until unlocked. Useful for timing tests.
+class ARROW_TESTING_EXPORT GatingTask {
+ public:
+ explicit GatingTask(double timeout_seconds = 10);
+ /// \brief During destruction we wait for all pending tasks to finish
+ ~GatingTask();
+
+ /// \brief Creates a new waiting task (presumably to spawn on a thread). It will return
+ /// invalid if the timeout arrived before the unlock. The task will not complete until
+ /// unlocked or timed out
+ ///
+ /// Note: The GatingTask must outlive any Task instances
+ std::function<void()> Task();
+ /// \brief Creates a new waiting task as a future. The future will not complete
+ /// until unlocked.
+ Future<> AsyncTask();
+ /// \brief Waits until at least count tasks are running.
+ Status WaitForRunning(int count);
+ /// \brief Unlocks all waiting tasks. Returns an invalid status if any waiting task has
+ /// timed out
+ Status Unlock();
+
+ static std::shared_ptr<GatingTask> Make(double timeout_seconds = 10);
+
+ private:
+ class Impl;
+ std::shared_ptr<Impl> impl_;
+};
+
+} // namespace arrow
+
+namespace nonstd {
+namespace sv_lite {
+
+// Without this hint, GTest will print string_views as a container of char
+template <class Char, class Traits = std::char_traits<Char>>
+void PrintTo(const basic_string_view<Char, Traits>& view, std::ostream* os) {
+ *os << view;
+}
+
+} // namespace sv_lite
+
+namespace optional_lite {
+
+template <typename T>
+void PrintTo(const optional<T>& opt, std::ostream* os) {
+ if (opt.has_value()) {
+ *os << "{";
+ ::testing::internal::UniversalPrint(*opt, os);
+ *os << "}";
+ } else {
+ *os << "nullopt";
+ }
+}
+
+inline void PrintTo(const decltype(nullopt)&, std::ostream* os) { *os << "nullopt"; }
+
+} // namespace optional_lite
+} // namespace nonstd
diff --git a/src/arrow/cpp/src/arrow/testing/json_integration.cc b/src/arrow/cpp/src/arrow/testing/json_integration.cc
new file mode 100644
index 000000000..2af094a78
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/json_integration.cc
@@ -0,0 +1,219 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/testing/json_integration.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "arrow/buffer.h"
+#include "arrow/io/file.h"
+#include "arrow/ipc/dictionary.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/testing/json_internal.h"
+#include "arrow/type.h"
+#include "arrow/util/logging.h"
+
+#include <rapidjson/document.h>
+#include <rapidjson/stringbuffer.h>
+#include <rapidjson/writer.h>
+
+using std::size_t;
+
+namespace arrow {
+
+using ipc::DictionaryFieldMapper;
+using ipc::DictionaryMemo;
+
+namespace testing {
+
+// ----------------------------------------------------------------------
+// Writer implementation
+
+class IntegrationJsonWriter::Impl {
+ public:
+ explicit Impl(const std::shared_ptr<Schema>& schema)
+ : schema_(schema), mapper_(*schema), first_batch_written_(false) {
+ writer_.reset(new RjWriter(string_buffer_));
+ }
+
+ Status Start() {
+ writer_->StartObject();
+ RETURN_NOT_OK(json::WriteSchema(*schema_, mapper_, writer_.get()));
+ return Status::OK();
+ }
+
+ Status FirstRecordBatch(const RecordBatch& batch) {
+ ARROW_ASSIGN_OR_RAISE(const auto dictionaries, CollectDictionaries(batch, mapper_));
+
+ // Write dictionaries, if any
+ if (!dictionaries.empty()) {
+ writer_->Key("dictionaries");
+ writer_->StartArray();
+ for (const auto& entry : dictionaries) {
+ RETURN_NOT_OK(json::WriteDictionary(entry.first, entry.second, writer_.get()));
+ }
+ writer_->EndArray();
+ }
+
+ // Record batches
+ writer_->Key("batches");
+ writer_->StartArray();
+ first_batch_written_ = true;
+ return Status::OK();
+ }
+
+ Status Finish(std::string* result) {
+ writer_->EndArray(); // Record batches
+ writer_->EndObject();
+
+ *result = string_buffer_.GetString();
+ return Status::OK();
+ }
+
+ Status WriteRecordBatch(const RecordBatch& batch) {
+ DCHECK_EQ(batch.num_columns(), schema_->num_fields());
+
+ if (!first_batch_written_) {
+ RETURN_NOT_OK(FirstRecordBatch(batch));
+ }
+ return json::WriteRecordBatch(batch, writer_.get());
+ }
+
+ private:
+ std::shared_ptr<Schema> schema_;
+ DictionaryFieldMapper mapper_;
+
+ bool first_batch_written_;
+
+ rj::StringBuffer string_buffer_;
+ std::unique_ptr<RjWriter> writer_;
+};
+
+IntegrationJsonWriter::IntegrationJsonWriter(const std::shared_ptr<Schema>& schema) {
+ impl_.reset(new Impl(schema));
+}
+
+IntegrationJsonWriter::~IntegrationJsonWriter() {}
+
+Status IntegrationJsonWriter::Open(const std::shared_ptr<Schema>& schema,
+ std::unique_ptr<IntegrationJsonWriter>* writer) {
+ *writer = std::unique_ptr<IntegrationJsonWriter>(new IntegrationJsonWriter(schema));
+ return (*writer)->impl_->Start();
+}
+
+Status IntegrationJsonWriter::Finish(std::string* result) {
+ return impl_->Finish(result);
+}
+
+Status IntegrationJsonWriter::WriteRecordBatch(const RecordBatch& batch) {
+ return impl_->WriteRecordBatch(batch);
+}
+
+// ----------------------------------------------------------------------
+// Reader implementation
+
+class IntegrationJsonReader::Impl {
+ public:
+ Impl(MemoryPool* pool, const std::shared_ptr<Buffer>& data)
+ : pool_(pool), data_(data), record_batches_(nullptr) {}
+
+ Status ParseAndReadSchema() {
+ doc_.Parse(reinterpret_cast<const rj::Document::Ch*>(data_->data()),
+ static_cast<size_t>(data_->size()));
+ if (doc_.HasParseError()) {
+ return Status::IOError("JSON parsing failed");
+ }
+
+ RETURN_NOT_OK(json::ReadSchema(doc_, pool_, &dictionary_memo_, &schema_));
+
+ auto it = doc_.FindMember("batches");
+ RETURN_NOT_ARRAY("batches", it, doc_);
+ record_batches_ = &it->value;
+
+ return Status::OK();
+ }
+
+ Status ReadRecordBatch(int i, std::shared_ptr<RecordBatch>* batch) {
+ DCHECK_GE(i, 0) << "i out of bounds";
+ DCHECK_LT(i, static_cast<int>(record_batches_->GetArray().Size()))
+ << "i out of bounds";
+
+ return json::ReadRecordBatch(record_batches_->GetArray()[i], schema_,
+ &dictionary_memo_, pool_, batch);
+ }
+
+ std::shared_ptr<Schema> schema() const { return schema_; }
+
+ int num_record_batches() const {
+ return static_cast<int>(record_batches_->GetArray().Size());
+ }
+
+ private:
+ MemoryPool* pool_;
+ std::shared_ptr<Buffer> data_;
+ rj::Document doc_;
+
+ const rj::Value* record_batches_;
+ std::shared_ptr<Schema> schema_;
+ DictionaryMemo dictionary_memo_;
+};
+
+IntegrationJsonReader::IntegrationJsonReader(MemoryPool* pool,
+ const std::shared_ptr<Buffer>& data) {
+ impl_.reset(new Impl(pool, data));
+}
+
+IntegrationJsonReader::~IntegrationJsonReader() {}
+
+Status IntegrationJsonReader::Open(const std::shared_ptr<Buffer>& data,
+ std::unique_ptr<IntegrationJsonReader>* reader) {
+ return Open(default_memory_pool(), data, reader);
+}
+
+Status IntegrationJsonReader::Open(MemoryPool* pool, const std::shared_ptr<Buffer>& data,
+ std::unique_ptr<IntegrationJsonReader>* reader) {
+ *reader = std::unique_ptr<IntegrationJsonReader>(new IntegrationJsonReader(pool, data));
+ return (*reader)->impl_->ParseAndReadSchema();
+}
+
+Status IntegrationJsonReader::Open(MemoryPool* pool,
+ const std::shared_ptr<io::ReadableFile>& in_file,
+ std::unique_ptr<IntegrationJsonReader>* reader) {
+ ARROW_ASSIGN_OR_RAISE(int64_t file_size, in_file->GetSize());
+ ARROW_ASSIGN_OR_RAISE(auto json_buffer, in_file->Read(file_size));
+ return Open(pool, json_buffer, reader);
+}
+
+std::shared_ptr<Schema> IntegrationJsonReader::schema() const { return impl_->schema(); }
+
+int IntegrationJsonReader::num_record_batches() const {
+ return impl_->num_record_batches();
+}
+
+Status IntegrationJsonReader::ReadRecordBatch(int i,
+ std::shared_ptr<RecordBatch>* batch) const {
+ return impl_->ReadRecordBatch(i, batch);
+}
+
+} // namespace testing
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/json_integration.h b/src/arrow/cpp/src/arrow/testing/json_integration.h
new file mode 100644
index 000000000..3486bb5d9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/json_integration.h
@@ -0,0 +1,129 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Implement Arrow JSON serialization format for integration tests
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/status.h"
+#include "arrow/testing/visibility.h"
+
+namespace arrow {
+
+class Buffer;
+class MemoryPool;
+class RecordBatch;
+class Schema;
+
+namespace io {
+class ReadableFile;
+} // namespace io
+
+namespace testing {
+
+/// \class IntegrationJsonWriter
+/// \brief Write the JSON representation of an Arrow record batch file or stream
+///
+/// This is used for integration testing
+class ARROW_TESTING_EXPORT IntegrationJsonWriter {
+ public:
+ ~IntegrationJsonWriter();
+
+ /// \brief Create a new JSON writer that writes to memory
+ ///
+ /// \param[in] schema the schema of record batches
+ /// \param[out] out the returned writer object
+ /// \return Status
+ static Status Open(const std::shared_ptr<Schema>& schema,
+ std::unique_ptr<IntegrationJsonWriter>* out);
+
+ /// \brief Append a record batch
+ Status WriteRecordBatch(const RecordBatch& batch);
+
+ /// \brief Finish the JSON payload and return as a std::string
+ ///
+ /// \param[out] result the JSON as as a std::string
+ /// \return Status
+ Status Finish(std::string* result);
+
+ private:
+ explicit IntegrationJsonWriter(const std::shared_ptr<Schema>& schema);
+
+ // Hide RapidJSON details from public API
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+/// \class IntegrationJsonReader
+/// \brief Read the JSON representation of an Arrow record batch file or stream
+///
+/// This is used for integration testing
+class ARROW_TESTING_EXPORT IntegrationJsonReader {
+ public:
+ ~IntegrationJsonReader();
+
+ /// \brief Create a new JSON reader
+ ///
+ /// \param[in] pool a MemoryPool to use for buffer allocations
+ /// \param[in] data a Buffer containing the JSON data
+ /// \param[out] reader the returned reader object
+ /// \return Status
+ static Status Open(MemoryPool* pool, const std::shared_ptr<Buffer>& data,
+ std::unique_ptr<IntegrationJsonReader>* reader);
+
+ /// \brief Create a new JSON reader that uses the default memory pool
+ ///
+ /// \param[in] data a Buffer containing the JSON data
+ /// \param[out] reader the returned reader object
+ /// \return Status
+ static Status Open(const std::shared_ptr<Buffer>& data,
+ std::unique_ptr<IntegrationJsonReader>* reader);
+
+ /// \brief Create a new JSON reader from a file
+ ///
+ /// \param[in] pool a MemoryPool to use for buffer allocations
+ /// \param[in] in_file a ReadableFile containing JSON data
+ /// \param[out] reader the returned reader object
+ /// \return Status
+ static Status Open(MemoryPool* pool, const std::shared_ptr<io::ReadableFile>& in_file,
+ std::unique_ptr<IntegrationJsonReader>* reader);
+
+ /// \brief Return the schema read from the JSON
+ std::shared_ptr<Schema> schema() const;
+
+ /// \brief Return the number of record batches
+ int num_record_batches() const;
+
+ /// \brief Read a particular record batch from the file
+ ///
+ /// \param[in] i the record batch index, does not boundscheck
+ /// \param[out] batch the read record batch
+ Status ReadRecordBatch(int i, std::shared_ptr<RecordBatch>* batch) const;
+
+ private:
+ IntegrationJsonReader(MemoryPool* pool, const std::shared_ptr<Buffer>& data);
+
+ // Hide RapidJSON details from public API
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+} // namespace testing
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/json_integration_test.cc b/src/arrow/cpp/src/arrow/testing/json_integration_test.cc
new file mode 100644
index 000000000..556201195
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/json_integration_test.cc
@@ -0,0 +1,1188 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <cstdio>
+#include <cstring>
+#include <fstream> // IWYU pragma: keep
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/io/file.h"
+#include "arrow/ipc/dictionary.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/test_common.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/pretty_print.h"
+#include "arrow/status.h"
+#include "arrow/testing/extension_type.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/json_integration.h"
+#include "arrow/testing/json_internal.h"
+#include "arrow/testing/random.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/io_util.h"
+
+DEFINE_string(arrow, "", "Arrow file name");
+DEFINE_string(json, "", "JSON file name");
+DEFINE_string(
+ mode, "VALIDATE",
+ "Mode of integration testing tool (ARROW_TO_JSON, JSON_TO_ARROW, VALIDATE)");
+DEFINE_bool(integration, false, "Run in integration test mode");
+DEFINE_bool(verbose, true, "Verbose output");
+
+namespace arrow {
+
+using internal::TemporaryDir;
+using ipc::DictionaryFieldMapper;
+using ipc::DictionaryMemo;
+using ipc::IpcWriteOptions;
+using ipc::MetadataVersion;
+
+namespace testing {
+
+using namespace ::arrow::ipc::test; // NOLINT
+
+// Convert JSON file to IPC binary format
+static Status ConvertJsonToArrow(const std::string& json_path,
+ const std::string& arrow_path) {
+ ARROW_ASSIGN_OR_RAISE(auto in_file, io::ReadableFile::Open(json_path));
+ ARROW_ASSIGN_OR_RAISE(auto out_file, io::FileOutputStream::Open(arrow_path));
+
+ ARROW_ASSIGN_OR_RAISE(int64_t file_size, in_file->GetSize());
+ ARROW_ASSIGN_OR_RAISE(auto json_buffer, in_file->Read(file_size));
+
+ std::unique_ptr<IntegrationJsonReader> reader;
+ RETURN_NOT_OK(IntegrationJsonReader::Open(json_buffer, &reader));
+
+ if (FLAGS_verbose) {
+ std::cout << "Found schema:\n"
+ << reader->schema()->ToString(/* show_metadata = */ true) << std::endl;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(out_file, reader->schema(),
+ IpcWriteOptions::Defaults()));
+ for (int i = 0; i < reader->num_record_batches(); ++i) {
+ std::shared_ptr<RecordBatch> batch;
+ RETURN_NOT_OK(reader->ReadRecordBatch(i, &batch));
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+ }
+ return writer->Close();
+}
+
+// Convert IPC binary format to JSON
+static Status ConvertArrowToJson(const std::string& arrow_path,
+ const std::string& json_path) {
+ ARROW_ASSIGN_OR_RAISE(auto in_file, io::ReadableFile::Open(arrow_path));
+ ARROW_ASSIGN_OR_RAISE(auto out_file, io::FileOutputStream::Open(json_path));
+
+ std::shared_ptr<ipc::RecordBatchFileReader> reader;
+ ARROW_ASSIGN_OR_RAISE(reader, ipc::RecordBatchFileReader::Open(in_file.get()));
+
+ if (FLAGS_verbose) {
+ std::cout << "Found schema:\n" << reader->schema()->ToString() << std::endl;
+ }
+
+ std::unique_ptr<IntegrationJsonWriter> writer;
+ RETURN_NOT_OK(IntegrationJsonWriter::Open(reader->schema(), &writer));
+
+ for (int i = 0; i < reader->num_record_batches(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> batch, reader->ReadRecordBatch(i));
+ RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+ }
+
+ std::string result;
+ RETURN_NOT_OK(writer->Finish(&result));
+ return out_file->Write(result.c_str(), static_cast<int64_t>(result.size()));
+}
+
+static Status ValidateArrowVsJson(const std::string& arrow_path,
+ const std::string& json_path) {
+ // Construct JSON reader
+ ARROW_ASSIGN_OR_RAISE(auto json_file, io::ReadableFile::Open(json_path));
+
+ ARROW_ASSIGN_OR_RAISE(int64_t file_size, json_file->GetSize());
+ ARROW_ASSIGN_OR_RAISE(auto json_buffer, json_file->Read(file_size));
+
+ std::unique_ptr<IntegrationJsonReader> json_reader;
+ RETURN_NOT_OK(IntegrationJsonReader::Open(json_buffer, &json_reader));
+
+ // Construct Arrow reader
+ ARROW_ASSIGN_OR_RAISE(auto arrow_file, io::ReadableFile::Open(arrow_path));
+
+ std::shared_ptr<ipc::RecordBatchFileReader> arrow_reader;
+ ARROW_ASSIGN_OR_RAISE(arrow_reader, ipc::RecordBatchFileReader::Open(arrow_file.get()));
+
+ auto json_schema = json_reader->schema();
+ auto arrow_schema = arrow_reader->schema();
+
+ if (!json_schema->Equals(*arrow_schema)) {
+ std::stringstream ss;
+ ss << "JSON schema: \n"
+ << json_schema->ToString(/* show_metadata = */ true) << "\n\n"
+ << "Arrow schema: \n"
+ << arrow_schema->ToString(/* show_metadata = */ true) << "\n";
+
+ if (FLAGS_verbose) {
+ std::cout << ss.str() << std::endl;
+ }
+ return Status::Invalid("Schemas did not match");
+ }
+
+ const int json_nbatches = json_reader->num_record_batches();
+ const int arrow_nbatches = arrow_reader->num_record_batches();
+
+ if (json_nbatches != arrow_nbatches) {
+ return Status::Invalid("Different number of record batches: ", json_nbatches,
+ " (JSON) vs ", arrow_nbatches, " (Arrow)");
+ }
+
+ std::shared_ptr<RecordBatch> arrow_batch;
+ std::shared_ptr<RecordBatch> json_batch;
+ for (int i = 0; i < json_nbatches; ++i) {
+ RETURN_NOT_OK(json_reader->ReadRecordBatch(i, &json_batch));
+ ARROW_ASSIGN_OR_RAISE(arrow_batch, arrow_reader->ReadRecordBatch(i));
+ Status valid_st = json_batch->ValidateFull();
+ if (!valid_st.ok()) {
+ return Status::Invalid("JSON record batch ", i, " did not validate:\n",
+ valid_st.ToString());
+ }
+ valid_st = arrow_batch->ValidateFull();
+ if (!valid_st.ok()) {
+ return Status::Invalid("Arrow record batch ", i, " did not validate:\n",
+ valid_st.ToString());
+ }
+
+ if (!json_batch->ApproxEquals(*arrow_batch)) {
+ std::stringstream ss;
+ ss << "Record batch " << i << " did not match";
+
+ ss << "\nJSON:\n";
+ RETURN_NOT_OK(PrettyPrint(*json_batch, 0, &ss));
+
+ ss << "\nArrow:\n";
+ RETURN_NOT_OK(PrettyPrint(*arrow_batch, 0, &ss));
+ return Status::Invalid(ss.str());
+ }
+ }
+
+ return Status::OK();
+}
+
+Status RunCommand(const std::string& json_path, const std::string& arrow_path,
+ const std::string& command) {
+ // Make sure the required extension types are registered, as they will be
+ // referenced in test data.
+ ExtensionTypeGuard ext_guard({uuid(), dict_extension_type()});
+
+ if (json_path == "") {
+ return Status::Invalid("Must specify json file name");
+ }
+
+ if (arrow_path == "") {
+ return Status::Invalid("Must specify arrow file name");
+ }
+
+ auto file_exists = [](const char* path) { return std::ifstream(path).good(); };
+
+ if (command == "ARROW_TO_JSON") {
+ if (!file_exists(arrow_path.c_str())) {
+ return Status::Invalid("Input file does not exist");
+ }
+
+ return ConvertArrowToJson(arrow_path, json_path);
+ } else if (command == "JSON_TO_ARROW") {
+ if (!file_exists(json_path.c_str())) {
+ return Status::Invalid("Input file does not exist");
+ }
+
+ return ConvertJsonToArrow(json_path, arrow_path);
+ } else if (command == "VALIDATE") {
+ if (!file_exists(json_path.c_str())) {
+ return Status::Invalid("JSON file does not exist");
+ }
+
+ if (!file_exists(arrow_path.c_str())) {
+ return Status::Invalid("Arrow file does not exist");
+ }
+
+ return ValidateArrowVsJson(arrow_path, json_path);
+ } else {
+ return Status::Invalid("Unknown command: ", command);
+ }
+}
+
+class TestJSONIntegration : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK_AND_ASSIGN(temp_dir_, TemporaryDir::Make("json-integration-test-"));
+ }
+
+ std::string mkstemp() {
+ std::stringstream ss;
+ ss << temp_dir_->path().ToString();
+ ss << "file" << ntemp_++;
+ return ss.str();
+ }
+
+ Status WriteJson(const char* data, const std::string& path) {
+ ARROW_ASSIGN_OR_RAISE(auto out_file, io::FileOutputStream::Open(path));
+ return out_file->Write(data, static_cast<int64_t>(strlen(data)));
+ }
+
+ void TearDown() { temp_dir_.reset(); }
+
+ protected:
+ std::unique_ptr<TemporaryDir> temp_dir_;
+ int ntemp_ = 1;
+};
+
+static const char* JSON_EXAMPLE = R"example(
+{
+ "schema": {
+ "fields": [
+ {
+ "name": "foo",
+ "type": {"name": "int", "isSigned": true, "bitWidth": 64},
+ "nullable": true, "children": []
+ },
+ {
+ "name": "bar",
+ "type": {"name": "floatingpoint", "precision": "DOUBLE"},
+ "nullable": true, "children": []
+ }
+ ]
+ },
+ "batches": [
+ {
+ "count": 5,
+ "columns": [
+ {
+ "name": "foo",
+ "count": 5,
+ "DATA": ["1", "2", "3", "4", "5"],
+ "VALIDITY": [1, 0, 1, 1, 1]
+ },
+ {
+ "name": "bar",
+ "count": 5,
+ "DATA": [1.0, 2.0, 3.0, 4.0, 5.0],
+ "VALIDITY": [1, 0, 0, 1, 1]
+ }
+ ]
+ },
+ {
+ "count": 4,
+ "columns": [
+ {
+ "name": "foo",
+ "count": 4,
+ "DATA": ["-1", "0", "9223372036854775807", "-9223372036854775808"],
+ "VALIDITY": [1, 0, 1, 1]
+ },
+ {
+ "name": "bar",
+ "count": 4,
+ "DATA": [1.0, 2.0, 3.0, 4.0],
+ "VALIDITY": [1, 0, 0, 1]
+ }
+ ]
+ }
+ ]
+}
+)example";
+
+static const char* JSON_EXAMPLE2 = R"example(
+{
+ "schema": {
+ "fields": [
+ {
+ "name": "foo",
+ "type": {"name": "int", "isSigned": true, "bitWidth": 32},
+ "nullable": true, "children": [],
+ "metadata": [
+ {"key": "converted_from_time32", "value": "true"}
+ ]
+ }
+ ],
+ "metadata": [
+ {"key": "schema_custom_0", "value": "eh"}
+ ]
+ },
+ "batches": [
+ {
+ "count": 5,
+ "columns": [
+ {
+ "name": "foo",
+ "count": 5,
+ "DATA": [1, 2, 3, 4, 5],
+ "VALIDITY": [1, 0, 1, 1, 1]
+ }
+ ]
+ }
+ ]
+}
+)example";
+
+TEST_F(TestJSONIntegration, ConvertAndValidate) {
+ std::string json_path = this->mkstemp();
+ std::string arrow_path = this->mkstemp();
+
+ ASSERT_OK(WriteJson(JSON_EXAMPLE, json_path));
+
+ ASSERT_OK(RunCommand(json_path, arrow_path, "JSON_TO_ARROW"));
+ ASSERT_OK(RunCommand(json_path, arrow_path, "VALIDATE"));
+
+ // Convert and overwrite
+ ASSERT_OK(RunCommand(json_path, arrow_path, "ARROW_TO_JSON"));
+
+ // Convert back to arrow, and validate
+ ASSERT_OK(RunCommand(json_path, arrow_path, "JSON_TO_ARROW"));
+ ASSERT_OK(RunCommand(json_path, arrow_path, "VALIDATE"));
+}
+
+TEST_F(TestJSONIntegration, ErrorStates) {
+ std::string json_path = this->mkstemp();
+ std::string json_path2 = this->mkstemp();
+ std::string arrow_path = this->mkstemp();
+
+ ASSERT_OK(WriteJson(JSON_EXAMPLE, json_path));
+ ASSERT_OK(WriteJson(JSON_EXAMPLE2, json_path2));
+
+ ASSERT_OK(ConvertJsonToArrow(json_path, arrow_path));
+ ASSERT_RAISES(Invalid, ValidateArrowVsJson(arrow_path, json_path2));
+
+ ASSERT_RAISES(IOError, ValidateArrowVsJson("does_not_exist-1234", json_path2));
+ ASSERT_RAISES(IOError, ValidateArrowVsJson(arrow_path, "does_not_exist-1234"));
+
+ ASSERT_RAISES(Invalid, RunCommand("", arrow_path, "VALIDATE"));
+ ASSERT_RAISES(Invalid, RunCommand(json_path, "", "VALIDATE"));
+}
+
+// A batch with primitive types
+static const char* json_example1 = R"example(
+{
+ "schema": {
+ "fields": [
+ {
+ "name": "foo",
+ "type": {"name": "int", "isSigned": true, "bitWidth": 32},
+ "nullable": true, "children": []
+ },
+ {
+ "name": "bar",
+ "type": {"name": "floatingpoint", "precision": "DOUBLE"},
+ "nullable": true, "children": []
+ }
+ ]
+ },
+ "batches": [
+ {
+ "count": 5,
+ "columns": [
+ {
+ "name": "foo",
+ "count": 5,
+ "DATA": [1, 2, 3, 4, 5],
+ "VALIDITY": [1, 0, 1, 1, 1]
+ },
+ {
+ "name": "bar",
+ "count": 5,
+ "DATA": [1.0, 2.0, 3.0, 4.0, 5.0],
+ "VALIDITY": [1, 0, 0, 1, 1]
+ }
+ ]
+ }
+ ]
+}
+)example";
+
+// A batch with extension types
+static const char* json_example2 = R"example(
+{
+ "schema": {
+ "fields": [
+ {
+ "name": "uuids",
+ "type" : {
+ "name" : "fixedsizebinary",
+ "byteWidth" : 16
+ },
+ "nullable": true,
+ "children" : [],
+ "metadata" : [
+ {"key": "ARROW:extension:name", "value": "uuid"},
+ {"key": "ARROW:extension:metadata", "value": "uuid-serialized"}
+ ]
+ },
+ {
+ "name": "things",
+ "type" : {
+ "name" : "null"
+ },
+ "nullable": true,
+ "children" : [],
+ "metadata" : [
+ {"key": "ARROW:extension:name", "value": "!does not exist!"},
+ {"key": "ARROW:extension:metadata", "value": ""},
+ {"key": "ARROW:integration:allow_unregistered_extension", "value": "true"}
+ ]
+ }
+ ]
+ },
+ "batches": [
+ {
+ "count": 2,
+ "columns": [
+ {
+ "name": "uuids",
+ "count": 2,
+ "DATA": ["30313233343536373839616263646566",
+ "00000000000000000000000000000000"],
+ "VALIDITY": [1, 0]
+ },
+ {
+ "name": "things",
+ "count": 2
+ }
+ ]
+ }
+ ]
+}
+)example";
+
+// A batch with dict-extension types
+static const char* json_example3 = R"example(
+{
+ "schema": {
+ "fields": [
+ {
+ "name": "dict-extensions",
+ "type" : {
+ "name" : "utf8"
+ },
+ "nullable": true,
+ "children" : [],
+ "dictionary": {
+ "id": 0,
+ "indexType": {
+ "name": "int",
+ "isSigned": true,
+ "bitWidth": 8
+ },
+ "isOrdered": false
+ },
+ "metadata" : [
+ {"key": "ARROW:extension:name", "value": "dict-extension"},
+ {"key": "ARROW:extension:metadata", "value": "dict-extension-serialized"}
+ ]
+ }
+ ]
+ },
+ "dictionaries": [
+ {
+ "id": 0,
+ "data": {
+ "count": 3,
+ "columns": [
+ {
+ "name": "DICT0",
+ "count": 3,
+ "VALIDITY": [
+ 1,
+ 1,
+ 1
+ ],
+ "OFFSET": [
+ 0,
+ 3,
+ 6,
+ 10
+ ],
+ "DATA": [
+ "foo",
+ "bar",
+ "quux"
+ ]
+ }
+ ]
+ }
+ }
+ ],
+ "batches": [
+ {
+ "count": 5,
+ "columns": [
+ {
+ "name": "dict-extensions",
+ "count": 5,
+ "DATA": [2, 0, 1, 1, 2],
+ "VALIDITY": [1, 1, 0, 1, 1]
+ }
+ ]
+ }
+ ]
+}
+)example";
+
+// A batch with a map type with non-canonical field names
+static const char* json_example4 = R"example(
+{
+ "schema": {
+ "fields": [
+ {
+ "name": "maps",
+ "type": {
+ "name": "map",
+ "keysSorted": false
+ },
+ "nullable": true,
+ "children": [
+ {
+ "name": "some_entries",
+ "type": {
+ "name": "struct"
+ },
+ "nullable": false,
+ "children": [
+ {
+ "name": "some_key",
+ "type": {
+ "name": "int",
+ "isSigned": true,
+ "bitWidth": 16
+ },
+ "nullable": false,
+ "children": []
+ },
+ {
+ "name": "some_value",
+ "type": {
+ "name": "int",
+ "isSigned": true,
+ "bitWidth": 32
+ },
+ "nullable": true,
+ "children": []
+ }
+ ]
+ }
+ ]
+ }
+ ]
+ },
+ "batches": [
+ {
+ "count": 3,
+ "columns": [
+ {
+ "name": "map_other_names",
+ "count": 3,
+ "VALIDITY": [1, 0, 1],
+ "OFFSET": [0, 3, 3, 5],
+ "children": [
+ {
+ "name": "some_entries",
+ "count": 5,
+ "VALIDITY": [1, 1, 1, 1, 1],
+ "children": [
+ {
+ "name": "some_key",
+ "count": 5,
+ "VALIDITY": [1, 1, 1, 1, 1],
+ "DATA": [11, 22, 33, 44, 55]
+ },
+ {
+ "name": "some_value",
+ "count": 5,
+ "VALIDITY": [1, 1, 0, 1, 1],
+ "DATA": [111, 222, 0, 444, 555]
+ }
+ ]
+ }
+ ]
+ }
+ ]
+ }
+ ]
+}
+)example";
+
+// An empty struct type, with "children" member in batches
+static const char* json_example5 = R"example(
+{
+ "schema": {
+ "fields": [
+ {
+ "name": "empty_struct",
+ "nullable": true,
+ "type": {
+ "name": "struct"
+ },
+ "children": []
+ }
+ ]
+ },
+ "batches": [
+ {
+ "count": 3,
+ "columns": [
+ {
+ "name": "empty_struct",
+ "count": 3,
+ "VALIDITY": [1, 0, 1],
+ "children": []
+ }
+ ]
+ }
+ ]
+}
+)example";
+
+// An empty struct type, without "children" member in batches
+static const char* json_example6 = R"example(
+{
+ "schema": {
+ "fields": [
+ {
+ "name": "empty_struct",
+ "nullable": true,
+ "type": {
+ "name": "struct"
+ },
+ "children": []
+ }
+ ]
+ },
+ "batches": [
+ {
+ "count": 2,
+ "columns": [
+ {
+ "name": "empty_struct",
+ "count": 2,
+ "VALIDITY": [1, 0]
+ }
+ ]
+ }
+ ]
+}
+)example";
+
+void TestSchemaRoundTrip(const Schema& schema) {
+ rj::StringBuffer sb;
+ rj::Writer<rj::StringBuffer> writer(sb);
+
+ DictionaryFieldMapper mapper(schema);
+
+ writer.StartObject();
+ ASSERT_OK(json::WriteSchema(schema, mapper, &writer));
+ writer.EndObject();
+
+ std::string json_schema = sb.GetString();
+
+ rj::Document d;
+ // Pass explicit size to avoid ASAN issues with
+ // SIMD loads in RapidJson.
+ d.Parse(json_schema.data(), json_schema.size());
+
+ DictionaryMemo in_memo;
+ std::shared_ptr<Schema> out;
+ if (!json::ReadSchema(d, default_memory_pool(), &in_memo, &out).ok()) {
+ FAIL() << "Unable to read JSON schema: " << json_schema;
+ }
+
+ if (!schema.Equals(*out)) {
+ FAIL() << "In schema: " << schema.ToString() << "\nOut schema: " << out->ToString();
+ }
+}
+
+void TestArrayRoundTrip(const Array& array) {
+ static std::string name = "dummy";
+
+ rj::StringBuffer sb;
+ rj::Writer<rj::StringBuffer> writer(sb);
+
+ ASSERT_OK(json::WriteArray(name, array, &writer));
+
+ std::string array_as_json = sb.GetString();
+
+ rj::Document d;
+ // Pass explicit size to avoid ASAN issues with
+ // SIMD loads in RapidJson.
+ d.Parse(array_as_json.data(), array_as_json.size());
+
+ if (d.HasParseError()) {
+ FAIL() << "JSON parsing failed";
+ }
+
+ std::shared_ptr<Array> out;
+ ASSERT_OK(json::ReadArray(default_memory_pool(), d, ::arrow::field(name, array.type()),
+ &out));
+
+ // std::cout << array_as_json << std::endl;
+ CompareArraysDetailed(0, *out, array);
+}
+
+template <typename T, typename ValueType>
+void CheckPrimitive(const std::shared_ptr<DataType>& type,
+ const std::vector<bool>& is_valid,
+ const std::vector<ValueType>& values) {
+ MemoryPool* pool = default_memory_pool();
+ typename TypeTraits<T>::BuilderType builder(pool);
+
+ for (size_t i = 0; i < values.size(); ++i) {
+ if (is_valid[i]) {
+ ASSERT_OK(builder.Append(values[i]));
+ } else {
+ ASSERT_OK(builder.AppendNull());
+ }
+ }
+
+ std::shared_ptr<Array> array;
+ ASSERT_OK(builder.Finish(&array));
+ TestArrayRoundTrip(*array);
+}
+
+TEST(TestJsonSchemaWriter, FlatTypes) {
+ // TODO
+ // field("f14", date32())
+ std::vector<std::shared_ptr<Field>> fields = {
+ field("f0", int8()),
+ field("f1", int16(), false),
+ field("f2", int32()),
+ field("f3", int64(), false),
+ field("f4", uint8()),
+ field("f5", uint16()),
+ field("f6", uint32()),
+ field("f7", uint64()),
+ field("f8", float32()),
+ field("f9", float64()),
+ field("f10", utf8()),
+ field("f11", binary()),
+ field("f12", list(int32())),
+ field("f13", struct_({field("s1", int32()), field("s2", utf8())})),
+ field("f15", date64()),
+ field("f16", timestamp(TimeUnit::NANO)),
+ field("f17", time64(TimeUnit::MICRO)),
+ field("f18",
+ dense_union({field("u1", int8()), field("u2", time32(TimeUnit::MILLI))},
+ {0, 1})),
+ field("f19", large_list(uint8())),
+ field("f20", null()),
+ };
+
+ Schema schema(fields);
+ TestSchemaRoundTrip(schema);
+}
+
+template <typename T>
+void PrimitiveTypesCheckOne() {
+ using c_type = typename T::c_type;
+
+ std::vector<bool> is_valid = {true, false, true, true, true, false, true, true};
+ std::vector<c_type> values = {0, 1, 2, 3, 4, 5, 6, 7};
+ CheckPrimitive<T, c_type>(std::make_shared<T>(), is_valid, values);
+}
+
+TEST(TestJsonArrayWriter, NullType) {
+ auto arr = std::make_shared<NullArray>(10);
+ TestArrayRoundTrip(*arr);
+}
+
+TEST(TestJsonArrayWriter, PrimitiveTypes) {
+ PrimitiveTypesCheckOne<Int8Type>();
+ PrimitiveTypesCheckOne<Int16Type>();
+ PrimitiveTypesCheckOne<Int32Type>();
+ PrimitiveTypesCheckOne<Int64Type>();
+ PrimitiveTypesCheckOne<UInt8Type>();
+ PrimitiveTypesCheckOne<UInt16Type>();
+ PrimitiveTypesCheckOne<UInt32Type>();
+ PrimitiveTypesCheckOne<UInt64Type>();
+ PrimitiveTypesCheckOne<FloatType>();
+ PrimitiveTypesCheckOne<DoubleType>();
+
+ std::vector<bool> is_valid = {true, false, true, true, true, false, true, true};
+ std::vector<std::string> values = {"foo", "bar", "", "baz", "qux", "foo", "a", "1"};
+
+ CheckPrimitive<StringType, std::string>(utf8(), is_valid, values);
+ CheckPrimitive<BinaryType, std::string>(binary(), is_valid, values);
+}
+
+TEST(TestJsonArrayWriter, NestedTypes) {
+ auto value_type = int32();
+
+ std::vector<bool> values_is_valid = {true, false, true, true, false, true, true};
+
+ std::vector<int32_t> values = {0, 1, 2, 3, 4, 5, 6};
+ std::shared_ptr<Array> values_array;
+ ArrayFromVector<Int32Type, int32_t>(values_is_valid, values, &values_array);
+
+ std::vector<int16_t> i16_values = {0, 1, 2, 3, 4, 5, 6};
+ std::shared_ptr<Array> i16_values_array;
+ ArrayFromVector<Int16Type, int16_t>(values_is_valid, i16_values, &i16_values_array);
+
+ // List
+ std::vector<bool> list_is_valid = {true, false, true, true, true};
+ std::shared_ptr<Buffer> list_bitmap;
+ ASSERT_OK(GetBitmapFromVector(list_is_valid, &list_bitmap));
+ std::vector<int32_t> offsets = {0, 0, 0, 1, 4, 7};
+ std::shared_ptr<Buffer> offsets_buffer = Buffer::Wrap(offsets);
+ {
+ ListArray list_array(list(value_type), 5, offsets_buffer, values_array, list_bitmap,
+ 1);
+ TestArrayRoundTrip(list_array);
+ }
+
+ // LargeList
+ std::vector<int64_t> large_offsets = {0, 0, 0, 1, 4, 7};
+ std::shared_ptr<Buffer> large_offsets_buffer = Buffer::Wrap(large_offsets);
+ {
+ LargeListArray list_array(large_list(value_type), 5, large_offsets_buffer,
+ values_array, list_bitmap, 1);
+ TestArrayRoundTrip(list_array);
+ }
+
+ // Map
+ auto map_type = map(utf8(), int32());
+ auto keys_array = ArrayFromJSON(utf8(), R"(["a", "b", "c", "d", "a", "b", "c"])");
+
+ MapArray map_array(map_type, 5, offsets_buffer, keys_array, values_array, list_bitmap,
+ 1);
+
+ TestArrayRoundTrip(map_array);
+
+ // FixedSizeList
+ FixedSizeListArray fixed_size_list_array(fixed_size_list(value_type, 2), 3,
+ values_array->Slice(1), list_bitmap, 1);
+
+ TestArrayRoundTrip(fixed_size_list_array);
+
+ // Struct
+ std::vector<bool> struct_is_valid = {true, false, true, true, true, false, true};
+ std::shared_ptr<Buffer> struct_bitmap;
+ ASSERT_OK(GetBitmapFromVector(struct_is_valid, &struct_bitmap));
+
+ auto struct_type =
+ struct_({field("f1", int32()), field("f2", int32()), field("f3", int32())});
+
+ std::vector<std::shared_ptr<Array>> fields = {values_array, values_array, values_array};
+ StructArray struct_array(struct_type, static_cast<int>(struct_is_valid.size()), fields,
+ struct_bitmap, 2);
+ TestArrayRoundTrip(struct_array);
+}
+
+TEST(TestJsonArrayWriter, Unions) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(MakeUnion(&batch));
+
+ for (int i = 0; i < batch->num_columns(); ++i) {
+ TestArrayRoundTrip(*batch->column(i));
+ }
+}
+
+// Data generation for test case below
+void MakeBatchArrays(const std::shared_ptr<Schema>& schema, const int num_rows,
+ std::vector<std::shared_ptr<Array>>* arrays) {
+ const float null_prob = 0.25f;
+ random::RandomArrayGenerator rand(0x564a3bf0);
+
+ *arrays = {rand.Boolean(num_rows, 0.75, null_prob),
+ rand.Int8(num_rows, 0, 100, null_prob),
+ rand.Int32(num_rows, -1000, 1000, null_prob),
+ rand.UInt64(num_rows, 0, 1UL << 16, null_prob)};
+
+ static const int kBufferSize = 10;
+ static uint8_t buffer[kBufferSize];
+ static uint32_t seed = 0;
+ StringBuilder string_builder;
+ for (int i = 0; i < num_rows; ++i) {
+ random_ascii(kBufferSize, seed++, buffer);
+ ASSERT_OK(string_builder.Append(buffer, kBufferSize));
+ }
+ std::shared_ptr<Array> v3;
+ ASSERT_OK(string_builder.Finish(&v3));
+
+ arrays->emplace_back(v3);
+}
+
+TEST(TestJsonFileReadWrite, BasicRoundTrip) {
+ auto v1_type = boolean();
+ auto v2_type = int8();
+ auto v3_type = int32();
+ auto v4_type = uint64();
+ auto v5_type = utf8();
+
+ auto schema =
+ ::arrow::schema({field("f1", v1_type), field("f2", v2_type), field("f3", v3_type),
+ field("f4", v4_type), field("f5", v5_type)});
+
+ std::unique_ptr<IntegrationJsonWriter> writer;
+ ASSERT_OK(IntegrationJsonWriter::Open(schema, &writer));
+
+ const int nbatches = 3;
+ std::vector<std::shared_ptr<RecordBatch>> batches;
+ for (int i = 0; i < nbatches; ++i) {
+ int num_rows = 5 + i * 5;
+ std::vector<std::shared_ptr<Array>> arrays;
+
+ MakeBatchArrays(schema, num_rows, &arrays);
+ auto batch = RecordBatch::Make(schema, num_rows, arrays);
+ batches.push_back(batch);
+ ASSERT_OK(writer->WriteRecordBatch(*batch));
+ }
+
+ std::string result;
+ ASSERT_OK(writer->Finish(&result));
+
+ std::unique_ptr<IntegrationJsonReader> reader;
+
+ auto buffer = std::make_shared<Buffer>(result);
+
+ ASSERT_OK(IntegrationJsonReader::Open(buffer, &reader));
+ ASSERT_TRUE(reader->schema()->Equals(*schema));
+
+ ASSERT_EQ(nbatches, reader->num_record_batches());
+
+ for (int i = 0; i < nbatches; ++i) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(reader->ReadRecordBatch(i, &batch));
+ ASSERT_BATCHES_EQUAL(*batch, *batches[i]);
+ }
+}
+
+static void ReadOneBatchJson(const char* json, const Schema& expected_schema,
+ std::shared_ptr<RecordBatch>* out) {
+ auto buffer = Buffer::Wrap(json, strlen(json));
+
+ std::unique_ptr<IntegrationJsonReader> reader;
+ ASSERT_OK(IntegrationJsonReader::Open(buffer, &reader));
+
+ AssertSchemaEqual(*reader->schema(), expected_schema, /*check_metadata=*/true);
+ ASSERT_EQ(1, reader->num_record_batches());
+
+ ASSERT_OK(reader->ReadRecordBatch(0, out));
+}
+
+TEST(TestJsonFileReadWrite, JsonExample1) {
+ Schema ex_schema({field("foo", int32()), field("bar", float64())});
+
+ std::shared_ptr<RecordBatch> batch;
+ ReadOneBatchJson(json_example1, ex_schema, &batch);
+
+ std::vector<bool> foo_valid = {true, false, true, true, true};
+ std::vector<int32_t> foo_values = {1, 2, 3, 4, 5};
+ std::shared_ptr<Array> foo;
+ ArrayFromVector<Int32Type, int32_t>(foo_valid, foo_values, &foo);
+ ASSERT_TRUE(batch->column(0)->Equals(foo));
+
+ std::vector<bool> bar_valid = {true, false, false, true, true};
+ std::vector<double> bar_values = {1, 2, 3, 4, 5};
+ std::shared_ptr<Array> bar;
+ ArrayFromVector<DoubleType, double>(bar_valid, bar_values, &bar);
+ ASSERT_TRUE(batch->column(1)->Equals(bar));
+}
+
+TEST(TestJsonFileReadWrite, JsonExample2) {
+ // Example 2: two extension types (one registered, one unregistered)
+ auto uuid_type = uuid();
+ auto buffer = Buffer::Wrap(json_example2, strlen(json_example2));
+
+ std::unique_ptr<IntegrationJsonReader> reader;
+ {
+ ExtensionTypeGuard ext_guard(uuid_type);
+
+ ASSERT_OK(IntegrationJsonReader::Open(buffer, &reader));
+ // The second field is an unregistered extension and will be read as
+ // its underlying storage.
+ Schema ex_schema({field("uuids", uuid_type), field("things", null())});
+
+ AssertSchemaEqual(ex_schema, *reader->schema());
+ ASSERT_EQ(1, reader->num_record_batches());
+
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK(reader->ReadRecordBatch(0, &batch));
+
+ auto storage_array =
+ ArrayFromJSON(fixed_size_binary(16), R"(["0123456789abcdef", null])");
+ AssertArraysEqual(*batch->column(0), UuidArray(uuid_type, storage_array));
+
+ AssertArraysEqual(*batch->column(1), NullArray(2));
+ }
+
+ // Should fail now that the Uuid extension is unregistered
+ ASSERT_RAISES(KeyError, IntegrationJsonReader::Open(buffer, &reader));
+}
+
+TEST(TestJsonFileReadWrite, JsonExample3) {
+ // Example 3: An extension type with a dictionary storage type
+ auto dict_ext_type = std::make_shared<DictExtensionType>();
+ ExtensionTypeGuard ext_guard(dict_ext_type);
+ Schema ex_schema({field("dict-extensions", dict_ext_type)});
+
+ std::shared_ptr<RecordBatch> batch;
+ ReadOneBatchJson(json_example3, ex_schema, &batch);
+ auto storage_array = std::make_shared<DictionaryArray>(
+ dict_ext_type->storage_type(), ArrayFromJSON(int8(), "[2, 0, null, 1, 2]"),
+ ArrayFromJSON(utf8(), R"(["foo", "bar", "quux"])"));
+ AssertArraysEqual(*batch->column(0), ExtensionArray(dict_ext_type, storage_array),
+ /*verbose=*/true);
+}
+
+TEST(TestJsonFileReadWrite, JsonExample4) {
+ // Example 4: A map type with non-canonical field names
+ ASSERT_OK_AND_ASSIGN(auto map_type,
+ MapType::Make(field("some_entries",
+ struct_({field("some_key", int16(), false),
+ field("some_value", int32())}),
+ false)));
+ Schema ex_schema({field("maps", map_type)});
+
+ std::shared_ptr<RecordBatch> batch;
+ ReadOneBatchJson(json_example4, ex_schema, &batch);
+
+ auto expected_array = ArrayFromJSON(
+ map(int16(), int32()),
+ R"([[[11, 111], [22, 222], [33, null]], null, [[44, 444], [55, 555]]])");
+ AssertArraysEqual(*batch->column(0), *expected_array);
+}
+
+TEST(TestJsonFileReadWrite, JsonExample5) {
+ // Example 5: An empty struct
+ auto struct_type = struct_(FieldVector{});
+ Schema ex_schema({field("empty_struct", struct_type)});
+
+ std::shared_ptr<RecordBatch> batch;
+ ReadOneBatchJson(json_example5, ex_schema, &batch);
+
+ auto expected_array = ArrayFromJSON(struct_type, "[{}, null, {}]");
+ AssertArraysEqual(*batch->column(0), *expected_array);
+}
+
+TEST(TestJsonFileReadWrite, JsonExample6) {
+ // Example 6: An empty struct
+ auto struct_type = struct_(FieldVector{});
+ Schema ex_schema({field("empty_struct", struct_type)});
+
+ std::shared_ptr<RecordBatch> batch;
+ ReadOneBatchJson(json_example6, ex_schema, &batch);
+
+ auto expected_array = ArrayFromJSON(struct_type, "[{}, null]");
+ AssertArraysEqual(*batch->column(0), *expected_array);
+}
+
+class TestJsonRoundTrip : public ::testing::TestWithParam<MakeRecordBatch*> {
+ public:
+ void SetUp() {}
+ void TearDown() {}
+};
+
+void CheckRoundtrip(const RecordBatch& batch) {
+ ExtensionTypeGuard guard({uuid(), dict_extension_type(), complex128()});
+
+ TestSchemaRoundTrip(*batch.schema());
+
+ std::unique_ptr<IntegrationJsonWriter> writer;
+ ASSERT_OK(IntegrationJsonWriter::Open(batch.schema(), &writer));
+ ASSERT_OK(writer->WriteRecordBatch(batch));
+
+ std::string result;
+ ASSERT_OK(writer->Finish(&result));
+
+ auto buffer = std::make_shared<Buffer>(result);
+
+ std::unique_ptr<IntegrationJsonReader> reader;
+ ASSERT_OK(IntegrationJsonReader::Open(buffer, &reader));
+
+ std::shared_ptr<RecordBatch> result_batch;
+ ASSERT_OK(reader->ReadRecordBatch(0, &result_batch));
+
+ // take care of float rounding error in the text representation
+ ApproxCompareBatch(batch, *result_batch);
+}
+
+TEST_P(TestJsonRoundTrip, RoundTrip) {
+ std::shared_ptr<RecordBatch> batch;
+ ASSERT_OK((*GetParam())(&batch)); // NOLINT clang-tidy gtest issue
+
+ CheckRoundtrip(*batch);
+}
+
+const std::vector<ipc::test::MakeRecordBatch*> kBatchCases = {
+ &MakeIntRecordBatch,
+ &MakeListRecordBatch,
+ &MakeFixedSizeListRecordBatch,
+ &MakeNonNullRecordBatch,
+ &MakeZeroLengthRecordBatch,
+ &MakeDeeplyNestedList,
+ &MakeStringTypesRecordBatchWithNulls,
+ &MakeStruct,
+ &MakeUnion,
+ &MakeDictionary,
+ &MakeNestedDictionary,
+ &MakeMap,
+ &MakeMapOfDictionary,
+ &MakeDates,
+ &MakeTimestamps,
+ &MakeTimes,
+ &MakeFWBinary,
+ &MakeNull,
+ &MakeDecimal,
+ &MakeBooleanBatch,
+ &MakeFloatBatch,
+ &MakeIntervals,
+ &MakeUuid,
+ &MakeComplex128,
+ &MakeDictExtension};
+
+INSTANTIATE_TEST_SUITE_P(TestJsonRoundTrip, TestJsonRoundTrip,
+ ::testing::ValuesIn(kBatchCases));
+
+} // namespace testing
+} // namespace arrow
+
+int main(int argc, char** argv) {
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+ int ret = 0;
+
+ if (FLAGS_integration) {
+ arrow::Status result =
+ arrow::testing::RunCommand(FLAGS_json, FLAGS_arrow, FLAGS_mode);
+ if (!result.ok()) {
+ std::cout << "Error message: " << result.ToString() << std::endl;
+ ret = 1;
+ }
+ } else {
+ ::testing::InitGoogleTest(&argc, argv);
+ ret = RUN_ALL_TESTS();
+ }
+ gflags::ShutDownCommandLineFlags();
+ return ret;
+}
diff --git a/src/arrow/cpp/src/arrow/testing/json_internal.cc b/src/arrow/cpp/src/arrow/testing/json_internal.cc
new file mode 100644
index 000000000..5e5d67a3f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/json_internal.cc
@@ -0,0 +1,1804 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/testing/json_internal.h"
+
+#include <cstdint>
+#include <cstdlib>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_decimal.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/builder_time.h"
+#include "arrow/extension_type.h"
+#include "arrow/ipc/dictionary.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/formatting.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string.h"
+#include "arrow/util/value_parsing.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::ParseValue;
+
+using ipc::DictionaryFieldMapper;
+using ipc::DictionaryMemo;
+using ipc::internal::FieldPosition;
+
+namespace testing {
+namespace json {
+
+namespace {
+
+constexpr char kData[] = "DATA";
+constexpr char kDays[] = "days";
+constexpr char kDayTime[] = "DAY_TIME";
+constexpr char kDuration[] = "duration";
+constexpr char kMilliseconds[] = "milliseconds";
+constexpr char kMonths[] = "months";
+constexpr char kNanoseconds[] = "nanoseconds";
+constexpr char kYearMonth[] = "YEAR_MONTH";
+constexpr char kMonthDayNano[] = "MONTH_DAY_NANO";
+
+std::string GetFloatingPrecisionName(FloatingPointType::Precision precision) {
+ switch (precision) {
+ case FloatingPointType::HALF:
+ return "HALF";
+ case FloatingPointType::SINGLE:
+ return "SINGLE";
+ case FloatingPointType::DOUBLE:
+ return "DOUBLE";
+ default:
+ break;
+ }
+ return "UNKNOWN";
+}
+
+std::string GetTimeUnitName(TimeUnit::type unit) {
+ switch (unit) {
+ case TimeUnit::SECOND:
+ return "SECOND";
+ case TimeUnit::MILLI:
+ return "MILLISECOND";
+ case TimeUnit::MICRO:
+ return "MICROSECOND";
+ case TimeUnit::NANO:
+ return "NANOSECOND";
+ default:
+ break;
+ }
+ return "UNKNOWN";
+}
+
+class SchemaWriter {
+ public:
+ explicit SchemaWriter(const Schema& schema, const DictionaryFieldMapper& mapper,
+ RjWriter* writer)
+ : schema_(schema), mapper_(mapper), writer_(writer) {}
+
+ Status Write() {
+ writer_->Key("schema");
+ writer_->StartObject();
+ writer_->Key("fields");
+ writer_->StartArray();
+
+ FieldPosition field_pos;
+ int i = 0;
+ for (const std::shared_ptr<Field>& field : schema_.fields()) {
+ RETURN_NOT_OK(VisitField(field, field_pos.child(i)));
+ ++i;
+ }
+ writer_->EndArray();
+ WriteKeyValueMetadata(schema_.metadata());
+ writer_->EndObject();
+ return Status::OK();
+ }
+
+ void WriteKeyValueMetadata(
+ const std::shared_ptr<const KeyValueMetadata>& metadata,
+ const std::vector<std::pair<std::string, std::string>>& additional_metadata = {}) {
+ if ((metadata == nullptr || metadata->size() == 0) && additional_metadata.empty()) {
+ return;
+ }
+ writer_->Key("metadata");
+
+ writer_->StartArray();
+ if (metadata != nullptr) {
+ for (int64_t i = 0; i < metadata->size(); ++i) {
+ WriteKeyValue(metadata->key(i), metadata->value(i));
+ }
+ }
+ for (const auto& kv : additional_metadata) {
+ WriteKeyValue(kv.first, kv.second);
+ }
+ writer_->EndArray();
+ }
+
+ void WriteKeyValue(const std::string& key, const std::string& value) {
+ writer_->StartObject();
+
+ writer_->Key("key");
+ writer_->String(key.c_str());
+
+ writer_->Key("value");
+ writer_->String(value.c_str());
+
+ writer_->EndObject();
+ }
+
+ Status WriteDictionaryMetadata(int64_t id, const DictionaryType& type) {
+ writer_->Key("dictionary");
+
+ // Emulate DictionaryEncoding from Schema.fbs
+ writer_->StartObject();
+ writer_->Key("id");
+ writer_->Int(static_cast<int32_t>(id));
+ writer_->Key("indexType");
+
+ writer_->StartObject();
+ RETURN_NOT_OK(VisitType(*type.index_type()));
+ writer_->EndObject();
+
+ writer_->Key("isOrdered");
+ writer_->Bool(type.ordered());
+ writer_->EndObject();
+
+ return Status::OK();
+ }
+
+ Status VisitField(const std::shared_ptr<Field>& field, FieldPosition field_pos) {
+ writer_->StartObject();
+
+ writer_->Key("name");
+ writer_->String(field->name().c_str());
+
+ writer_->Key("nullable");
+ writer_->Bool(field->nullable());
+
+ const DataType* type = field->type().get();
+ std::vector<std::pair<std::string, std::string>> additional_metadata;
+ if (type->id() == Type::EXTENSION) {
+ const auto& ext_type = checked_cast<const ExtensionType&>(*type);
+ type = ext_type.storage_type().get();
+ additional_metadata.emplace_back(kExtensionTypeKeyName, ext_type.extension_name());
+ additional_metadata.emplace_back(kExtensionMetadataKeyName, ext_type.Serialize());
+ }
+
+ // Visit the type
+ writer_->Key("type");
+ writer_->StartObject();
+ RETURN_NOT_OK(VisitType(*type));
+ writer_->EndObject();
+
+ if (type->id() == Type::DICTIONARY) {
+ const auto& dict_type = checked_cast<const DictionaryType&>(*type);
+ // Ensure we visit child fields first so that, in the case of nested
+ // dictionaries, inner dictionaries get a smaller id than outer dictionaries.
+ RETURN_NOT_OK(WriteChildren(dict_type.value_type()->fields(), field_pos));
+ ARROW_ASSIGN_OR_RAISE(const int64_t dictionary_id,
+ mapper_.GetFieldId(field_pos.path()));
+ RETURN_NOT_OK(WriteDictionaryMetadata(dictionary_id, dict_type));
+ } else {
+ RETURN_NOT_OK(WriteChildren(type->fields(), field_pos));
+ }
+
+ WriteKeyValueMetadata(field->metadata(), additional_metadata);
+ writer_->EndObject();
+
+ return Status::OK();
+ }
+
+ Status VisitType(const DataType& type);
+
+ template <typename T>
+ enable_if_t<is_null_type<T>::value || is_primitive_ctype<T>::value ||
+ is_base_binary_type<T>::value || is_base_list_type<T>::value ||
+ is_struct_type<T>::value>
+ WriteTypeMetadata(const T& type) {}
+
+ void WriteTypeMetadata(const MapType& type) {
+ writer_->Key("keysSorted");
+ writer_->Bool(type.keys_sorted());
+ }
+
+ void WriteTypeMetadata(const IntegerType& type) {
+ writer_->Key("bitWidth");
+ writer_->Int(type.bit_width());
+ writer_->Key("isSigned");
+ writer_->Bool(type.is_signed());
+ }
+
+ void WriteTypeMetadata(const FloatingPointType& type) {
+ writer_->Key("precision");
+ writer_->String(GetFloatingPrecisionName(type.precision()));
+ }
+
+ void WriteTypeMetadata(const IntervalType& type) {
+ writer_->Key("unit");
+ switch (type.interval_type()) {
+ case IntervalType::MONTHS:
+ writer_->String(kYearMonth);
+ break;
+ case IntervalType::DAY_TIME:
+ writer_->String(kDayTime);
+ break;
+ case IntervalType::MONTH_DAY_NANO:
+ writer_->String(kMonthDayNano);
+ break;
+ }
+ }
+
+ void WriteTypeMetadata(const TimestampType& type) {
+ writer_->Key("unit");
+ writer_->String(GetTimeUnitName(type.unit()));
+ if (type.timezone().size() > 0) {
+ writer_->Key("timezone");
+ writer_->String(type.timezone());
+ }
+ }
+
+ void WriteTypeMetadata(const DurationType& type) {
+ writer_->Key("unit");
+ writer_->String(GetTimeUnitName(type.unit()));
+ }
+
+ void WriteTypeMetadata(const TimeType& type) {
+ writer_->Key("unit");
+ writer_->String(GetTimeUnitName(type.unit()));
+ writer_->Key("bitWidth");
+ writer_->Int(type.bit_width());
+ }
+
+ void WriteTypeMetadata(const DateType& type) {
+ writer_->Key("unit");
+ switch (type.unit()) {
+ case DateUnit::DAY:
+ writer_->String("DAY");
+ break;
+ case DateUnit::MILLI:
+ writer_->String("MILLISECOND");
+ break;
+ }
+ }
+
+ void WriteTypeMetadata(const FixedSizeBinaryType& type) {
+ writer_->Key("byteWidth");
+ writer_->Int(type.byte_width());
+ }
+
+ void WriteTypeMetadata(const FixedSizeListType& type) {
+ writer_->Key("listSize");
+ writer_->Int(type.list_size());
+ }
+
+ void WriteTypeMetadata(const Decimal128Type& type) {
+ writer_->Key("precision");
+ writer_->Int(type.precision());
+ writer_->Key("scale");
+ writer_->Int(type.scale());
+ }
+
+ void WriteTypeMetadata(const Decimal256Type& type) {
+ writer_->Key("precision");
+ writer_->Int(type.precision());
+ writer_->Key("scale");
+ writer_->Int(type.scale());
+ }
+
+ void WriteTypeMetadata(const UnionType& type) {
+ writer_->Key("mode");
+ switch (type.mode()) {
+ case UnionMode::SPARSE:
+ writer_->String("SPARSE");
+ break;
+ case UnionMode::DENSE:
+ writer_->String("DENSE");
+ break;
+ }
+
+ // Write type ids
+ writer_->Key("typeIds");
+ writer_->StartArray();
+ for (size_t i = 0; i < type.type_codes().size(); ++i) {
+ writer_->Int(type.type_codes()[i]);
+ }
+ writer_->EndArray();
+ }
+
+ // TODO(wesm): Other Type metadata
+
+ template <typename T>
+ void WriteName(const std::string& typeclass, const T& type) {
+ writer_->Key("name");
+ writer_->String(typeclass);
+ WriteTypeMetadata(type);
+ }
+
+ template <typename T>
+ Status WritePrimitive(const std::string& typeclass, const T& type) {
+ WriteName(typeclass, type);
+ return Status::OK();
+ }
+
+ template <typename T>
+ Status WriteVarBytes(const std::string& typeclass, const T& type) {
+ WriteName(typeclass, type);
+ return Status::OK();
+ }
+
+ Status WriteChildren(const std::vector<std::shared_ptr<Field>>& children,
+ FieldPosition field_pos) {
+ writer_->Key("children");
+ writer_->StartArray();
+ int i = 0;
+ for (const std::shared_ptr<Field>& field : children) {
+ RETURN_NOT_OK(VisitField(field, field_pos.child(i)));
+ ++i;
+ }
+ writer_->EndArray();
+ return Status::OK();
+ }
+
+ Status Visit(const NullType& type) { return WritePrimitive("null", type); }
+ Status Visit(const BooleanType& type) { return WritePrimitive("bool", type); }
+ Status Visit(const IntegerType& type) { return WritePrimitive("int", type); }
+
+ Status Visit(const FloatingPointType& type) {
+ return WritePrimitive("floatingpoint", type);
+ }
+
+ Status Visit(const DateType& type) { return WritePrimitive("date", type); }
+ Status Visit(const TimeType& type) { return WritePrimitive("time", type); }
+ Status Visit(const StringType& type) { return WriteVarBytes("utf8", type); }
+ Status Visit(const BinaryType& type) { return WriteVarBytes("binary", type); }
+ Status Visit(const LargeStringType& type) { return WriteVarBytes("largeutf8", type); }
+ Status Visit(const LargeBinaryType& type) { return WriteVarBytes("largebinary", type); }
+ Status Visit(const FixedSizeBinaryType& type) {
+ return WritePrimitive("fixedsizebinary", type);
+ }
+
+ Status Visit(const Decimal128Type& type) { return WritePrimitive("decimal", type); }
+ Status Visit(const Decimal256Type& type) { return WritePrimitive("decimal256", type); }
+ Status Visit(const TimestampType& type) { return WritePrimitive("timestamp", type); }
+ Status Visit(const DurationType& type) { return WritePrimitive(kDuration, type); }
+ Status Visit(const MonthIntervalType& type) { return WritePrimitive("interval", type); }
+ Status Visit(const MonthDayNanoIntervalType& type) {
+ return WritePrimitive("interval", type);
+ }
+
+ Status Visit(const DayTimeIntervalType& type) {
+ return WritePrimitive("interval", type);
+ }
+
+ Status Visit(const ListType& type) {
+ WriteName("list", type);
+ return Status::OK();
+ }
+
+ Status Visit(const LargeListType& type) {
+ WriteName("largelist", type);
+ return Status::OK();
+ }
+
+ Status Visit(const MapType& type) {
+ WriteName("map", type);
+ return Status::OK();
+ }
+
+ Status Visit(const FixedSizeListType& type) {
+ WriteName("fixedsizelist", type);
+ return Status::OK();
+ }
+
+ Status Visit(const StructType& type) {
+ WriteName("struct", type);
+ return Status::OK();
+ }
+
+ Status Visit(const UnionType& type) {
+ WriteName("union", type);
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryType& type) { return VisitType(*type.value_type()); }
+
+ Status Visit(const ExtensionType& type) { return Status::NotImplemented(type.name()); }
+
+ private:
+ const Schema& schema_;
+ const DictionaryFieldMapper& mapper_;
+ RjWriter* writer_;
+};
+
+Status SchemaWriter::VisitType(const DataType& type) {
+ return VisitTypeInline(type, this);
+}
+
+class ArrayWriter {
+ public:
+ ArrayWriter(const std::string& name, const Array& array, RjWriter* writer)
+ : name_(name), array_(array), writer_(writer) {}
+
+ Status Write() { return VisitArray(name_, array_); }
+
+ Status VisitArrayValues(const Array& arr) { return VisitArrayInline(arr, this); }
+
+ Status VisitArray(const std::string& name, const Array& arr) {
+ writer_->StartObject();
+ writer_->Key("name");
+ writer_->String(name);
+
+ writer_->Key("count");
+ writer_->Int(static_cast<int32_t>(arr.length()));
+
+ RETURN_NOT_OK(VisitArrayValues(arr));
+
+ writer_->EndObject();
+ return Status::OK();
+ }
+
+ void WriteRawNumber(util::string_view v) {
+ // Avoid RawNumber() as it misleadingly adds quotes
+ // (see https://github.com/Tencent/rapidjson/pull/1155)
+ writer_->RawValue(v.data(), v.size(), rj::kNumberType);
+ }
+
+ template <typename ArrayType, typename TypeClass = typename ArrayType::TypeClass,
+ typename CType = typename TypeClass::c_type>
+ enable_if_t<is_physical_integer_type<TypeClass>::value &&
+ sizeof(CType) != sizeof(int64_t)>
+ WriteDataValues(const ArrayType& arr) {
+ static const std::string null_string = "0";
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ if (arr.IsValid(i)) {
+ writer_->Int64(arr.Value(i));
+ } else {
+ WriteRawNumber(null_string);
+ }
+ }
+ }
+
+ template <typename ArrayType, typename TypeClass = typename ArrayType::TypeClass,
+ typename CType = typename TypeClass::c_type>
+ enable_if_t<is_physical_integer_type<TypeClass>::value &&
+ sizeof(CType) == sizeof(int64_t)>
+ WriteDataValues(const ArrayType& arr) {
+ ::arrow::internal::StringFormatter<typename CTypeTraits<CType>::ArrowType> fmt;
+
+ static const std::string null_string = "0";
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ if (arr.IsValid(i)) {
+ fmt(arr.Value(i), [&](util::string_view repr) {
+ writer_->String(repr.data(), static_cast<rj::SizeType>(repr.size()));
+ });
+ } else {
+ writer_->String(null_string.data(),
+ static_cast<rj::SizeType>(null_string.size()));
+ }
+ }
+ }
+
+ template <typename ArrayType>
+ enable_if_physical_floating_point<typename ArrayType::TypeClass> WriteDataValues(
+ const ArrayType& arr) {
+ static const std::string null_string = "0";
+ const auto data = arr.raw_values();
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ if (arr.IsValid(i)) {
+ writer_->Double(data[i]);
+ } else {
+ WriteRawNumber(null_string);
+ }
+ }
+ }
+
+ // Binary, encode to hexadecimal.
+ template <typename ArrayType>
+ enable_if_binary_like<typename ArrayType::TypeClass> WriteDataValues(
+ const ArrayType& arr) {
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ writer_->String(HexEncode(arr.GetView(i)));
+ }
+ }
+
+ // UTF8 string, write as is
+ template <typename ArrayType>
+ enable_if_string_like<typename ArrayType::TypeClass> WriteDataValues(
+ const ArrayType& arr) {
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ auto view = arr.GetView(i);
+ writer_->String(view.data(), static_cast<rj::SizeType>(view.size()));
+ }
+ }
+
+ void WriteDataValues(const MonthDayNanoIntervalArray& arr) {
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ writer_->StartObject();
+ if (arr.IsValid(i)) {
+ const MonthDayNanoIntervalType::MonthDayNanos dm = arr.GetValue(i);
+ writer_->Key(kMonths);
+ writer_->Int(dm.months);
+ writer_->Key(kDays);
+ writer_->Int(dm.days);
+ writer_->Key(kNanoseconds);
+ writer_->Int64(dm.nanoseconds);
+ }
+ writer_->EndObject();
+ }
+ }
+
+ void WriteDataValues(const DayTimeIntervalArray& arr) {
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ writer_->StartObject();
+ if (arr.IsValid(i)) {
+ const DayTimeIntervalType::DayMilliseconds dm = arr.GetValue(i);
+ writer_->Key(kDays);
+ writer_->Int(dm.days);
+ writer_->Key(kMilliseconds);
+ writer_->Int(dm.milliseconds);
+ }
+ writer_->EndObject();
+ }
+ }
+
+ void WriteDataValues(const Decimal128Array& arr) {
+ static const char null_string[] = "0";
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ if (arr.IsValid(i)) {
+ const Decimal128 value(arr.GetValue(i));
+ writer_->String(value.ToIntegerString());
+ } else {
+ writer_->String(null_string, sizeof(null_string));
+ }
+ }
+ }
+
+ void WriteDataValues(const Decimal256Array& arr) {
+ static const char null_string[] = "0";
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ if (arr.IsValid(i)) {
+ const Decimal256 value(arr.GetValue(i));
+ writer_->String(value.ToIntegerString());
+ } else {
+ writer_->String(null_string, sizeof(null_string));
+ }
+ }
+ }
+
+ void WriteDataValues(const BooleanArray& arr) {
+ for (int64_t i = 0; i < arr.length(); ++i) {
+ if (arr.IsValid(i)) {
+ writer_->Bool(arr.Value(i));
+ } else {
+ writer_->Bool(false);
+ }
+ }
+ }
+
+ template <typename T>
+ void WriteDataField(const T& arr) {
+ writer_->Key(kData);
+ writer_->StartArray();
+ WriteDataValues(arr);
+ writer_->EndArray();
+ }
+
+ template <typename T>
+ void WriteIntegerField(const char* name, const T* values, int64_t length) {
+ writer_->Key(name);
+ writer_->StartArray();
+ if (sizeof(T) < sizeof(int64_t)) {
+ for (int i = 0; i < length; ++i) {
+ writer_->Int64(values[i]);
+ }
+ } else {
+ // Represent 64-bit integers as strings, as JSON numbers cannot represent
+ // them exactly.
+ ::arrow::internal::StringFormatter<typename CTypeTraits<T>::ArrowType> formatter;
+ auto append = [this](util::string_view v) {
+ writer_->String(v.data(), static_cast<rj::SizeType>(v.size()));
+ return Status::OK();
+ };
+ for (int i = 0; i < length; ++i) {
+ DCHECK_OK(formatter(values[i], append));
+ }
+ }
+ writer_->EndArray();
+ }
+
+ void WriteValidityField(const Array& arr) {
+ writer_->Key("VALIDITY");
+ writer_->StartArray();
+ if (arr.null_count() > 0) {
+ for (int i = 0; i < arr.length(); ++i) {
+ writer_->Int(arr.IsNull(i) ? 0 : 1);
+ }
+ } else {
+ for (int i = 0; i < arr.length(); ++i) {
+ writer_->Int(1);
+ }
+ }
+ writer_->EndArray();
+ }
+
+ void SetNoChildren() {
+ // Nothing. We used to write an empty "children" array member,
+ // but that fails the Java parser (ARROW-11483).
+ }
+
+ Status WriteChildren(const std::vector<std::shared_ptr<Field>>& fields,
+ const std::vector<std::shared_ptr<Array>>& arrays) {
+ // NOTE: the Java parser fails on an empty "children" member (ARROW-11483).
+ if (fields.size() > 0) {
+ writer_->Key("children");
+ writer_->StartArray();
+ for (size_t i = 0; i < fields.size(); ++i) {
+ RETURN_NOT_OK(VisitArray(fields[i]->name(), *arrays[i]));
+ }
+ writer_->EndArray();
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const NullArray& array) {
+ SetNoChildren();
+ return Status::OK();
+ }
+
+ template <typename ArrayType>
+ enable_if_t<std::is_base_of<PrimitiveArray, ArrayType>::value, Status> Visit(
+ const ArrayType& array) {
+ WriteValidityField(array);
+ WriteDataField(array);
+ SetNoChildren();
+ return Status::OK();
+ }
+
+ template <typename ArrayType>
+ enable_if_base_binary<typename ArrayType::TypeClass, Status> Visit(
+ const ArrayType& array) {
+ WriteValidityField(array);
+ WriteIntegerField("OFFSET", array.raw_value_offsets(), array.length() + 1);
+ WriteDataField(array);
+ SetNoChildren();
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryArray& array) {
+ return VisitArrayValues(*array.indices());
+ }
+
+ template <typename ArrayType>
+ enable_if_var_size_list<typename ArrayType::TypeClass, Status> Visit(
+ const ArrayType& array) {
+ WriteValidityField(array);
+ WriteIntegerField("OFFSET", array.raw_value_offsets(), array.length() + 1);
+ return WriteChildren(array.type()->fields(), {array.values()});
+ }
+
+ Status Visit(const FixedSizeListArray& array) {
+ WriteValidityField(array);
+ const auto& type = checked_cast<const FixedSizeListType&>(*array.type());
+ return WriteChildren(type.fields(), {array.values()});
+ }
+
+ Status Visit(const StructArray& array) {
+ WriteValidityField(array);
+ const auto& type = checked_cast<const StructType&>(*array.type());
+ std::vector<std::shared_ptr<Array>> children;
+ children.reserve(array.num_fields());
+ for (int i = 0; i < array.num_fields(); ++i) {
+ children.emplace_back(array.field(i));
+ }
+ return WriteChildren(type.fields(), children);
+ }
+
+ Status Visit(const UnionArray& array) {
+ const auto& type = checked_cast<const UnionType&>(*array.type());
+ WriteIntegerField("TYPE_ID", array.raw_type_codes(), array.length());
+ if (type.mode() == UnionMode::DENSE) {
+ auto offsets = checked_cast<const DenseUnionArray&>(array).raw_value_offsets();
+ WriteIntegerField("OFFSET", offsets, array.length());
+ }
+ std::vector<std::shared_ptr<Array>> children;
+ children.reserve(array.num_fields());
+ for (int i = 0; i < array.num_fields(); ++i) {
+ children.emplace_back(array.field(i));
+ }
+ return WriteChildren(type.fields(), children);
+ }
+
+ Status Visit(const ExtensionArray& array) { return VisitArrayValues(*array.storage()); }
+
+ private:
+ const std::string& name_;
+ const Array& array_;
+ RjWriter* writer_;
+};
+
+Result<TimeUnit::type> GetUnitFromString(const std::string& unit_str) {
+ if (unit_str == "SECOND") {
+ return TimeUnit::SECOND;
+ } else if (unit_str == "MILLISECOND") {
+ return TimeUnit::MILLI;
+ } else if (unit_str == "MICROSECOND") {
+ return TimeUnit::MICRO;
+ } else if (unit_str == "NANOSECOND") {
+ return TimeUnit::NANO;
+ } else {
+ return Status::Invalid("Invalid time unit: ", unit_str);
+ }
+}
+
+template <typename IntType = int>
+Result<IntType> GetMemberInt(const RjObject& obj, const std::string& key) {
+ const auto& it = obj.FindMember(key);
+ RETURN_NOT_INT(key, it, obj);
+ return static_cast<IntType>(it->value.GetInt64());
+}
+
+Result<bool> GetMemberBool(const RjObject& obj, const std::string& key) {
+ const auto& it = obj.FindMember(key);
+ RETURN_NOT_BOOL(key, it, obj);
+ return it->value.GetBool();
+}
+
+Result<std::string> GetMemberString(const RjObject& obj, const std::string& key) {
+ const auto& it = obj.FindMember(key);
+ RETURN_NOT_STRING(key, it, obj);
+ return it->value.GetString();
+}
+
+Result<const RjObject> GetMemberObject(const RjObject& obj, const std::string& key) {
+ const auto& it = obj.FindMember(key);
+ RETURN_NOT_OBJECT(key, it, obj);
+ return it->value.GetObject();
+}
+
+Result<const RjArray> GetMemberArray(const RjObject& obj, const std::string& key,
+ bool allow_absent = false) {
+ static const auto empty_array = rj::Value(rj::kArrayType);
+
+ const auto& it = obj.FindMember(key);
+ if (allow_absent && it == obj.MemberEnd()) {
+ return empty_array.GetArray();
+ }
+ RETURN_NOT_ARRAY(key, it, obj);
+ return it->value.GetArray();
+}
+
+Result<TimeUnit::type> GetMemberTimeUnit(const RjObject& obj, const std::string& key) {
+ ARROW_ASSIGN_OR_RAISE(const auto unit_str, GetMemberString(obj, key));
+ return GetUnitFromString(unit_str);
+}
+
+Status GetInteger(const rj::Value::ConstObject& json_type,
+ std::shared_ptr<DataType>* type) {
+ ARROW_ASSIGN_OR_RAISE(const bool is_signed, GetMemberBool(json_type, "isSigned"));
+ ARROW_ASSIGN_OR_RAISE(const int bit_width, GetMemberInt<int>(json_type, "bitWidth"));
+
+ switch (bit_width) {
+ case 8:
+ *type = is_signed ? int8() : uint8();
+ break;
+ case 16:
+ *type = is_signed ? int16() : uint16();
+ break;
+ case 32:
+ *type = is_signed ? int32() : uint32();
+ break;
+ case 64:
+ *type = is_signed ? int64() : uint64();
+ break;
+ default:
+ return Status::Invalid("Invalid bit width: ", bit_width);
+ }
+ return Status::OK();
+}
+
+Status GetFloatingPoint(const RjObject& json_type, std::shared_ptr<DataType>* type) {
+ ARROW_ASSIGN_OR_RAISE(const auto precision, GetMemberString(json_type, "precision"));
+
+ if (precision == "DOUBLE") {
+ *type = float64();
+ } else if (precision == "SINGLE") {
+ *type = float32();
+ } else if (precision == "HALF") {
+ *type = float16();
+ } else {
+ return Status::Invalid("Invalid precision: ", precision);
+ }
+ return Status::OK();
+}
+
+Status GetMap(const RjObject& json_type,
+ const std::vector<std::shared_ptr<Field>>& children,
+ std::shared_ptr<DataType>* type) {
+ if (children.size() != 1) {
+ return Status::Invalid("Map must have exactly one child");
+ }
+
+ ARROW_ASSIGN_OR_RAISE(const bool keys_sorted, GetMemberBool(json_type, "keysSorted"));
+ return MapType::Make(children[0], keys_sorted).Value(type);
+}
+
+Status GetFixedSizeBinary(const RjObject& json_type, std::shared_ptr<DataType>* type) {
+ ARROW_ASSIGN_OR_RAISE(const int32_t byte_width,
+ GetMemberInt<int32_t>(json_type, "byteWidth"));
+ *type = fixed_size_binary(byte_width);
+ return Status::OK();
+}
+
+Status GetFixedSizeList(const RjObject& json_type,
+ const std::vector<std::shared_ptr<Field>>& children,
+ std::shared_ptr<DataType>* type) {
+ if (children.size() != 1) {
+ return Status::Invalid("FixedSizeList must have exactly one child");
+ }
+
+ ARROW_ASSIGN_OR_RAISE(const int32_t list_size,
+ GetMemberInt<int32_t>(json_type, "listSize"));
+ *type = fixed_size_list(children[0], list_size);
+ return Status::OK();
+}
+
+Status GetDecimal(const RjObject& json_type, std::shared_ptr<DataType>* type) {
+ ARROW_ASSIGN_OR_RAISE(const int32_t precision,
+ GetMemberInt<int32_t>(json_type, "precision"));
+ ARROW_ASSIGN_OR_RAISE(const int32_t scale, GetMemberInt<int32_t>(json_type, "scale"));
+ int32_t bit_width = 128;
+ Result<int32_t> maybe_bit_width = GetMemberInt<int32_t>(json_type, "bitWidth");
+ if (maybe_bit_width.ok()) {
+ bit_width = maybe_bit_width.ValueOrDie();
+ }
+
+ if (bit_width == 128) {
+ *type = decimal128(precision, scale);
+ } else if (bit_width == 256) {
+ *type = decimal256(precision, scale);
+ } else {
+ return Status::Invalid("Only 128 bit and 256 Decimals are supported. Received",
+ bit_width);
+ }
+ return Status::OK();
+}
+
+Status GetDate(const RjObject& json_type, std::shared_ptr<DataType>* type) {
+ ARROW_ASSIGN_OR_RAISE(const auto unit_str, GetMemberString(json_type, "unit"));
+
+ if (unit_str == "DAY") {
+ *type = date32();
+ } else if (unit_str == "MILLISECOND") {
+ *type = date64();
+ } else {
+ return Status::Invalid("Invalid date unit: ", unit_str);
+ }
+ return Status::OK();
+}
+
+Status GetTime(const RjObject& json_type, std::shared_ptr<DataType>* type) {
+ ARROW_ASSIGN_OR_RAISE(const auto unit_str, GetMemberString(json_type, "unit"));
+ ARROW_ASSIGN_OR_RAISE(const int bit_width, GetMemberInt<int>(json_type, "bitWidth"));
+
+ if (unit_str == "SECOND") {
+ *type = time32(TimeUnit::SECOND);
+ } else if (unit_str == "MILLISECOND") {
+ *type = time32(TimeUnit::MILLI);
+ } else if (unit_str == "MICROSECOND") {
+ *type = time64(TimeUnit::MICRO);
+ } else if (unit_str == "NANOSECOND") {
+ *type = time64(TimeUnit::NANO);
+ } else {
+ return Status::Invalid("Invalid time unit: ", unit_str);
+ }
+
+ const auto& fw_type = checked_cast<const FixedWidthType&>(**type);
+
+ if (bit_width != fw_type.bit_width()) {
+ return Status::Invalid("Indicated bit width does not match unit");
+ }
+
+ return Status::OK();
+}
+
+Status GetDuration(const RjObject& json_type, std::shared_ptr<DataType>* type) {
+ ARROW_ASSIGN_OR_RAISE(const TimeUnit::type unit, GetMemberTimeUnit(json_type, "unit"));
+ *type = duration(unit);
+ return Status::OK();
+}
+
+Status GetTimestamp(const RjObject& json_type, std::shared_ptr<DataType>* type) {
+ ARROW_ASSIGN_OR_RAISE(const TimeUnit::type unit, GetMemberTimeUnit(json_type, "unit"));
+
+ const auto& it_tz = json_type.FindMember("timezone");
+ if (it_tz == json_type.MemberEnd()) {
+ *type = timestamp(unit);
+ } else {
+ RETURN_NOT_STRING("timezone", it_tz, json_type);
+ *type = timestamp(unit, it_tz->value.GetString());
+ }
+
+ return Status::OK();
+}
+
+Status GetInterval(const RjObject& json_type, std::shared_ptr<DataType>* type) {
+ ARROW_ASSIGN_OR_RAISE(const auto unit_str, GetMemberString(json_type, "unit"));
+
+ if (unit_str == kDayTime) {
+ *type = day_time_interval();
+ } else if (unit_str == kYearMonth) {
+ *type = month_interval();
+ } else if (unit_str == kMonthDayNano) {
+ *type = month_day_nano_interval();
+ } else {
+ return Status::Invalid("Invalid interval unit: " + unit_str);
+ }
+ return Status::OK();
+}
+
+Status GetUnion(const RjObject& json_type,
+ const std::vector<std::shared_ptr<Field>>& children,
+ std::shared_ptr<DataType>* type) {
+ ARROW_ASSIGN_OR_RAISE(const auto mode_str, GetMemberString(json_type, "mode"));
+
+ UnionMode::type mode;
+ if (mode_str == "SPARSE") {
+ mode = UnionMode::SPARSE;
+ } else if (mode_str == "DENSE") {
+ mode = UnionMode::DENSE;
+ } else {
+ return Status::Invalid("Invalid union mode: ", mode_str);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(const auto json_type_codes, GetMemberArray(json_type, "typeIds"));
+
+ std::vector<int8_t> type_codes;
+ type_codes.reserve(json_type_codes.Size());
+ for (const rj::Value& val : json_type_codes) {
+ if (!val.IsInt()) {
+ return Status::Invalid("Union type codes must be integers");
+ }
+ type_codes.push_back(static_cast<int8_t>(val.GetInt()));
+ }
+
+ if (mode == UnionMode::SPARSE) {
+ *type = sparse_union(std::move(children), std::move(type_codes));
+ } else {
+ *type = dense_union(std::move(children), std::move(type_codes));
+ }
+
+ return Status::OK();
+}
+
+Status GetType(const RjObject& json_type,
+ const std::vector<std::shared_ptr<Field>>& children,
+ std::shared_ptr<DataType>* type) {
+ ARROW_ASSIGN_OR_RAISE(const auto type_name, GetMemberString(json_type, "name"));
+
+ if (type_name == "int") {
+ return GetInteger(json_type, type);
+ } else if (type_name == "floatingpoint") {
+ return GetFloatingPoint(json_type, type);
+ } else if (type_name == "bool") {
+ *type = boolean();
+ } else if (type_name == "utf8") {
+ *type = utf8();
+ } else if (type_name == "binary") {
+ *type = binary();
+ } else if (type_name == "largeutf8") {
+ *type = large_utf8();
+ } else if (type_name == "largebinary") {
+ *type = large_binary();
+ } else if (type_name == "fixedsizebinary") {
+ return GetFixedSizeBinary(json_type, type);
+ } else if (type_name == "decimal") {
+ return GetDecimal(json_type, type);
+ } else if (type_name == "null") {
+ *type = null();
+ } else if (type_name == "date") {
+ return GetDate(json_type, type);
+ } else if (type_name == "time") {
+ return GetTime(json_type, type);
+ } else if (type_name == "timestamp") {
+ return GetTimestamp(json_type, type);
+ } else if (type_name == "interval") {
+ return GetInterval(json_type, type);
+ } else if (type_name == kDuration) {
+ return GetDuration(json_type, type);
+ } else if (type_name == "list") {
+ if (children.size() != 1) {
+ return Status::Invalid("List must have exactly one child");
+ }
+ *type = list(children[0]);
+ } else if (type_name == "largelist") {
+ if (children.size() != 1) {
+ return Status::Invalid("Large list must have exactly one child");
+ }
+ *type = large_list(children[0]);
+ } else if (type_name == "map") {
+ return GetMap(json_type, children, type);
+ } else if (type_name == "fixedsizelist") {
+ return GetFixedSizeList(json_type, children, type);
+ } else if (type_name == "struct") {
+ *type = struct_(children);
+ } else if (type_name == "union") {
+ return GetUnion(json_type, children, type);
+ } else {
+ return Status::Invalid("Unrecognized type name: ", type_name);
+ }
+ return Status::OK();
+}
+
+Status GetField(const rj::Value& obj, FieldPosition field_pos,
+ DictionaryMemo* dictionary_memo, std::shared_ptr<Field>* field);
+
+Status GetFieldsFromArray(const RjArray& json_fields, FieldPosition parent_pos,
+ DictionaryMemo* dictionary_memo,
+ std::vector<std::shared_ptr<Field>>* fields) {
+ fields->resize(json_fields.Size());
+ for (rj::SizeType i = 0; i < json_fields.Size(); ++i) {
+ RETURN_NOT_OK(GetField(json_fields[i], parent_pos.child(static_cast<int>(i)),
+ dictionary_memo, &(*fields)[i]));
+ }
+ return Status::OK();
+}
+
+Status ParseDictionary(const RjObject& obj, int64_t* id, bool* is_ordered,
+ std::shared_ptr<DataType>* index_type) {
+ ARROW_ASSIGN_OR_RAISE(*id, GetMemberInt<int64_t>(obj, "id"));
+ ARROW_ASSIGN_OR_RAISE(*is_ordered, GetMemberBool(obj, "isOrdered"));
+
+ ARROW_ASSIGN_OR_RAISE(const auto json_index_type, GetMemberObject(obj, "indexType"));
+
+ ARROW_ASSIGN_OR_RAISE(const auto type_name, GetMemberString(json_index_type, "name"));
+ if (type_name != "int") {
+ return Status::Invalid("Dictionary indices can only be integers");
+ }
+ return GetInteger(json_index_type, index_type);
+}
+
+template <typename FieldOrStruct>
+Status GetKeyValueMetadata(const FieldOrStruct& field_or_struct,
+ std::shared_ptr<KeyValueMetadata>* out) {
+ out->reset(new KeyValueMetadata);
+ auto it = field_or_struct.FindMember("metadata");
+ if (it == field_or_struct.MemberEnd() || it->value.IsNull()) {
+ return Status::OK();
+ }
+ if (!it->value.IsArray()) {
+ return Status::Invalid("Metadata was not a JSON array");
+ }
+
+ for (const auto& val : it->value.GetArray()) {
+ if (!val.IsObject()) {
+ return Status::Invalid("Metadata KeyValue was not a JSON object");
+ }
+ const auto& key_value_pair = val.GetObject();
+
+ ARROW_ASSIGN_OR_RAISE(const auto key, GetMemberString(key_value_pair, "key"));
+ ARROW_ASSIGN_OR_RAISE(const auto value, GetMemberString(key_value_pair, "value"));
+
+ (*out)->Append(std::move(key), std::move(value));
+ }
+ return Status::OK();
+}
+
+Status GetField(const rj::Value& obj, FieldPosition field_pos,
+ DictionaryMemo* dictionary_memo, std::shared_ptr<Field>* field) {
+ if (!obj.IsObject()) {
+ return Status::Invalid("Field was not a JSON object");
+ }
+ const auto& json_field = obj.GetObject();
+
+ std::shared_ptr<DataType> type;
+
+ ARROW_ASSIGN_OR_RAISE(const auto name, GetMemberString(json_field, "name"));
+ ARROW_ASSIGN_OR_RAISE(const bool nullable, GetMemberBool(json_field, "nullable"));
+
+ ARROW_ASSIGN_OR_RAISE(const auto json_type, GetMemberObject(json_field, "type"));
+ ARROW_ASSIGN_OR_RAISE(const auto json_children, GetMemberArray(json_field, "children"));
+
+ std::vector<std::shared_ptr<Field>> children;
+ RETURN_NOT_OK(GetFieldsFromArray(json_children, field_pos, dictionary_memo, &children));
+ RETURN_NOT_OK(GetType(json_type, children, &type));
+
+ std::shared_ptr<KeyValueMetadata> metadata;
+ RETURN_NOT_OK(GetKeyValueMetadata(json_field, &metadata));
+
+ // Is it a dictionary type?
+ int64_t dictionary_id = -1;
+ std::shared_ptr<DataType> dict_value_type;
+ const auto& it_dictionary = json_field.FindMember("dictionary");
+ if (dictionary_memo != nullptr && it_dictionary != json_field.MemberEnd()) {
+ // Parse dictionary id in JSON and add dictionary field to the
+ // memo, and parse the dictionaries later
+ RETURN_NOT_OBJECT("dictionary", it_dictionary, json_field);
+ bool is_ordered;
+ std::shared_ptr<DataType> index_type;
+ RETURN_NOT_OK(ParseDictionary(it_dictionary->value.GetObject(), &dictionary_id,
+ &is_ordered, &index_type));
+
+ dict_value_type = type;
+ type = ::arrow::dictionary(index_type, type, is_ordered);
+ }
+
+ // Is it an extension type?
+ int ext_name_index = metadata->FindKey(kExtensionTypeKeyName);
+ if (ext_name_index != -1) {
+ const auto& ext_name = metadata->value(ext_name_index);
+ ARROW_ASSIGN_OR_RAISE(auto ext_data, metadata->Get(kExtensionMetadataKeyName));
+
+ auto ext_type = GetExtensionType(ext_name);
+ if (ext_type == nullptr) {
+ // Some integration tests check that unregistered extensions pass through
+ auto maybe_value = metadata->Get("ARROW:integration:allow_unregistered_extension");
+ if (!maybe_value.ok() || *maybe_value != "true") {
+ return Status::KeyError("Extension type '", ext_name, "' not found");
+ }
+ } else {
+ ARROW_ASSIGN_OR_RAISE(type, ext_type->Deserialize(type, ext_data));
+
+ // Remove extension type metadata, for exact roundtripping
+ RETURN_NOT_OK(metadata->Delete(kExtensionTypeKeyName));
+ RETURN_NOT_OK(metadata->Delete(kExtensionMetadataKeyName));
+ }
+ }
+
+ // Create field
+ *field = ::arrow::field(name, type, nullable, metadata);
+ if (dictionary_id != -1) {
+ RETURN_NOT_OK(dictionary_memo->fields().AddField(dictionary_id, field_pos.path()));
+ RETURN_NOT_OK(dictionary_memo->AddDictionaryType(dictionary_id, dict_value_type));
+ }
+
+ return Status::OK();
+}
+
+template <typename T>
+enable_if_boolean<T, bool> UnboxValue(const rj::Value& val) {
+ DCHECK(val.IsBool());
+ return val.GetBool();
+}
+
+template <typename T, typename CType = typename T::c_type>
+enable_if_t<is_physical_integer_type<T>::value && sizeof(CType) != sizeof(int64_t), CType>
+UnboxValue(const rj::Value& val) {
+ DCHECK(val.IsInt64());
+ return static_cast<CType>(val.GetInt64());
+}
+
+template <typename T, typename CType = typename T::c_type>
+enable_if_t<is_physical_integer_type<T>::value && sizeof(CType) == sizeof(int64_t), CType>
+UnboxValue(const rj::Value& val) {
+ DCHECK(val.IsString());
+
+ CType out;
+ bool success = ::arrow::internal::ParseValue<typename CTypeTraits<CType>::ArrowType>(
+ val.GetString(), val.GetStringLength(), &out);
+
+ DCHECK(success);
+ return out;
+}
+
+template <typename T>
+enable_if_physical_floating_point<T, typename T::c_type> UnboxValue(
+ const rj::Value& val) {
+ DCHECK(val.IsFloat());
+ return static_cast<typename T::c_type>(val.GetDouble());
+}
+
+class ArrayReader {
+ public:
+ ArrayReader(const RjObject& obj, MemoryPool* pool, const std::shared_ptr<Field>& field)
+ : obj_(obj), pool_(pool), field_(field), type_(field->type()) {}
+
+ template <typename BuilderType>
+ Status FinishBuilder(BuilderType* builder) {
+ std::shared_ptr<Array> array;
+ RETURN_NOT_OK(builder->Finish(&array));
+ data_ = array->data();
+ return Status::OK();
+ }
+
+ Result<const RjArray> GetDataArray(const RjObject& obj) {
+ ARROW_ASSIGN_OR_RAISE(const auto json_data_arr, GetMemberArray(obj, kData));
+ if (static_cast<int32_t>(json_data_arr.Size()) != length_) {
+ return Status::Invalid("JSON DATA array size differs from advertised array length");
+ }
+ return json_data_arr;
+ }
+
+ template <typename T>
+ enable_if_has_c_type<T, Status> Visit(const T& type) {
+ typename TypeTraits<T>::BuilderType builder(type_, pool_);
+
+ ARROW_ASSIGN_OR_RAISE(const auto json_data_arr, GetDataArray(obj_));
+
+ for (int i = 0; i < length_; ++i) {
+ if (!is_valid_[i]) {
+ RETURN_NOT_OK(builder.AppendNull());
+ continue;
+ }
+ const rj::Value& val = json_data_arr[i];
+ RETURN_NOT_OK(builder.Append(UnboxValue<T>(val)));
+ }
+ return FinishBuilder(&builder);
+ }
+
+ int64_t ParseOffset(const rj::Value& json_offset) {
+ DCHECK(json_offset.IsInt() || json_offset.IsInt64() || json_offset.IsString());
+
+ if (json_offset.IsInt64()) {
+ return json_offset.GetInt64();
+ } else {
+ return UnboxValue<Int64Type>(json_offset);
+ }
+ }
+
+ template <typename T>
+ enable_if_base_binary<T, Status> Visit(const T& type) {
+ typename TypeTraits<T>::BuilderType builder(pool_);
+ using offset_type = typename T::offset_type;
+
+ ARROW_ASSIGN_OR_RAISE(const auto json_data_arr, GetDataArray(obj_));
+ ARROW_ASSIGN_OR_RAISE(const auto json_offsets, GetMemberArray(obj_, "OFFSET"));
+ if (static_cast<int32_t>(json_offsets.Size()) != (length_ + 1)) {
+ return Status::Invalid(
+ "JSON OFFSET array size differs from advertised array length + 1");
+ }
+
+ for (int i = 0; i < length_; ++i) {
+ if (!is_valid_[i]) {
+ RETURN_NOT_OK(builder.AppendNull());
+ continue;
+ }
+ const rj::Value& val = json_data_arr[i];
+ DCHECK(val.IsString());
+
+ int64_t offset_start = ParseOffset(json_offsets[i]);
+ int64_t offset_end = ParseOffset(json_offsets[i + 1]);
+ DCHECK(offset_end >= offset_start);
+
+ if (T::is_utf8) {
+ auto str = val.GetString();
+ DCHECK(std::string(str).size() == static_cast<size_t>(offset_end - offset_start));
+ RETURN_NOT_OK(builder.Append(str));
+ } else {
+ std::string hex_string = val.GetString();
+
+ if (hex_string.size() % 2 != 0) {
+ return Status::Invalid("Expected base16 hex string");
+ }
+ const auto value_len = static_cast<int64_t>(hex_string.size()) / 2;
+
+ ARROW_ASSIGN_OR_RAISE(auto byte_buffer, AllocateBuffer(value_len, pool_));
+
+ const char* hex_data = hex_string.c_str();
+ uint8_t* byte_buffer_data = byte_buffer->mutable_data();
+ for (int64_t j = 0; j < value_len; ++j) {
+ RETURN_NOT_OK(ParseHexValue(hex_data + j * 2, &byte_buffer_data[j]));
+ }
+ RETURN_NOT_OK(
+ builder.Append(byte_buffer_data, static_cast<offset_type>(value_len)));
+ }
+ }
+ return FinishBuilder(&builder);
+ }
+
+ Status Visit(const DayTimeIntervalType& type) {
+ DayTimeIntervalBuilder builder(pool_);
+
+ ARROW_ASSIGN_OR_RAISE(const auto json_data_arr, GetDataArray(obj_));
+
+ for (int i = 0; i < length_; ++i) {
+ if (!is_valid_[i]) {
+ RETURN_NOT_OK(builder.AppendNull());
+ continue;
+ }
+
+ const rj::Value& val = json_data_arr[i];
+ DCHECK(val.IsObject());
+ DayTimeIntervalType::DayMilliseconds dm = {0, 0};
+ dm.days = val[kDays].GetInt();
+ dm.milliseconds = val[kMilliseconds].GetInt();
+ RETURN_NOT_OK(builder.Append(dm));
+ }
+ return FinishBuilder(&builder);
+ }
+
+ Status Visit(const MonthDayNanoIntervalType& type) {
+ MonthDayNanoIntervalBuilder builder(pool_);
+
+ ARROW_ASSIGN_OR_RAISE(const auto json_data_arr, GetDataArray(obj_));
+
+ for (int i = 0; i < length_; ++i) {
+ if (!is_valid_[i]) {
+ RETURN_NOT_OK(builder.AppendNull());
+ continue;
+ }
+
+ const rj::Value& val = json_data_arr[i];
+ DCHECK(val.IsObject());
+ MonthDayNanoIntervalType::MonthDayNanos dm = {0, 0, 0};
+ dm.months = val[kMonths].GetInt();
+ dm.days = val[kDays].GetInt();
+ dm.nanoseconds = val[kNanoseconds].GetInt64();
+ RETURN_NOT_OK(builder.Append(dm));
+ }
+ return FinishBuilder(&builder);
+ }
+
+ template <typename T>
+ enable_if_t<is_fixed_size_binary_type<T>::value && !is_decimal_type<T>::value, Status>
+ Visit(const T& type) {
+ typename TypeTraits<T>::BuilderType builder(type_, pool_);
+
+ ARROW_ASSIGN_OR_RAISE(const auto json_data_arr, GetDataArray(obj_));
+
+ int32_t byte_width = type.byte_width();
+
+ // Allocate space for parsed values
+ ARROW_ASSIGN_OR_RAISE(auto byte_buffer, AllocateBuffer(byte_width, pool_));
+ uint8_t* byte_buffer_data = byte_buffer->mutable_data();
+
+ for (int i = 0; i < length_; ++i) {
+ if (!is_valid_[i]) {
+ RETURN_NOT_OK(builder.AppendNull());
+ } else {
+ const rj::Value& val = json_data_arr[i];
+ DCHECK(val.IsString())
+ << "Found non-string JSON value when parsing FixedSizeBinary value";
+ std::string hex_string = val.GetString();
+ if (static_cast<int32_t>(hex_string.size()) != byte_width * 2) {
+ DCHECK(false) << "Expected size: " << byte_width * 2
+ << " got: " << hex_string.size();
+ }
+ const char* hex_data = hex_string.c_str();
+
+ for (int32_t j = 0; j < byte_width; ++j) {
+ RETURN_NOT_OK(ParseHexValue(hex_data + j * 2, &byte_buffer_data[j]));
+ }
+ RETURN_NOT_OK(builder.Append(byte_buffer_data));
+ }
+ }
+ return FinishBuilder(&builder);
+ }
+
+ template <typename T>
+ enable_if_decimal<T, Status> Visit(const T& type) {
+ typename TypeTraits<T>::BuilderType builder(type_, pool_);
+
+ ARROW_ASSIGN_OR_RAISE(const auto json_data_arr, GetDataArray(obj_));
+
+ for (int i = 0; i < length_; ++i) {
+ if (!is_valid_[i]) {
+ RETURN_NOT_OK(builder.AppendNull());
+ } else {
+ const rj::Value& val = json_data_arr[i];
+ DCHECK(val.IsString())
+ << "Found non-string JSON value when parsing Decimal128 value";
+ DCHECK_GT(val.GetStringLength(), 0)
+ << "Empty string found when parsing Decimal128 value";
+
+ using Value = typename TypeTraits<T>::ScalarType::ValueType;
+ Value value;
+ ARROW_ASSIGN_OR_RAISE(value, Value::FromString(val.GetString()));
+ RETURN_NOT_OK(builder.Append(value));
+ }
+ }
+
+ return FinishBuilder(&builder);
+ }
+
+ template <typename T>
+ Status GetIntArray(const RjArray& json_array, const int32_t length,
+ std::shared_ptr<Buffer>* out) {
+ using ArrowType = typename CTypeTraits<T>::ArrowType;
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(length * sizeof(T), pool_));
+
+ T* values = reinterpret_cast<T*>(buffer->mutable_data());
+ if (sizeof(T) < sizeof(int64_t)) {
+ for (int i = 0; i < length; ++i) {
+ const rj::Value& val = json_array[i];
+ DCHECK(val.IsInt() || val.IsInt64());
+ if (val.IsInt()) {
+ values[i] = static_cast<T>(val.GetInt());
+ } else {
+ values[i] = static_cast<T>(val.GetInt64());
+ }
+ }
+ } else {
+ // Read 64-bit integers as strings, as JSON numbers cannot represent
+ // them exactly.
+ for (int i = 0; i < length; ++i) {
+ const rj::Value& val = json_array[i];
+ DCHECK(val.IsString());
+ if (!ParseValue<ArrowType>(val.GetString(), val.GetStringLength(), &values[i])) {
+ return Status::Invalid("Failed to parse integer: '",
+ std::string(val.GetString(), val.GetStringLength()),
+ "'");
+ }
+ }
+ }
+
+ *out = std::move(buffer);
+ return Status::OK();
+ }
+
+ template <typename T>
+ Status CreateList(const std::shared_ptr<DataType>& type) {
+ using offset_type = typename T::offset_type;
+
+ RETURN_NOT_OK(InitializeData(2));
+
+ RETURN_NOT_OK(GetNullBitmap());
+ ARROW_ASSIGN_OR_RAISE(const auto json_offsets, GetMemberArray(obj_, "OFFSET"));
+ RETURN_NOT_OK(
+ GetIntArray<offset_type>(json_offsets, length_ + 1, &data_->buffers[1]));
+ RETURN_NOT_OK(GetChildren(obj_, *type));
+ return Status::OK();
+ }
+
+ template <typename T>
+ enable_if_var_size_list<T, Status> Visit(const T& type) {
+ return CreateList<T>(type_);
+ }
+
+ Status Visit(const MapType& type) {
+ auto list_type = std::make_shared<ListType>(type.value_field());
+ RETURN_NOT_OK(CreateList<ListType>(list_type));
+ data_->type = type_;
+ return Status::OK();
+ }
+
+ Status Visit(const FixedSizeListType& type) {
+ RETURN_NOT_OK(InitializeData(1));
+ RETURN_NOT_OK(GetNullBitmap());
+
+ RETURN_NOT_OK(GetChildren(obj_, type));
+ DCHECK_EQ(data_->child_data[0]->length, type.list_size() * length_);
+ return Status::OK();
+ }
+
+ Status Visit(const StructType& type) {
+ RETURN_NOT_OK(InitializeData(1));
+
+ RETURN_NOT_OK(GetNullBitmap());
+ RETURN_NOT_OK(GetChildren(obj_, type));
+ return Status::OK();
+ }
+
+ Status GetUnionTypeIds() {
+ ARROW_ASSIGN_OR_RAISE(const auto json_type_ids, GetMemberArray(obj_, "TYPE_ID"));
+ return GetIntArray<uint8_t>(json_type_ids, length_, &data_->buffers[1]);
+ }
+
+ Status Visit(const SparseUnionType& type) {
+ RETURN_NOT_OK(InitializeData(2));
+
+ RETURN_NOT_OK(GetNullBitmap());
+ RETURN_NOT_OK(GetUnionTypeIds());
+ RETURN_NOT_OK(GetChildren(obj_, type));
+ return Status::OK();
+ }
+
+ Status Visit(const DenseUnionType& type) {
+ RETURN_NOT_OK(InitializeData(3));
+
+ RETURN_NOT_OK(GetNullBitmap());
+ RETURN_NOT_OK(GetUnionTypeIds());
+ RETURN_NOT_OK(GetChildren(obj_, type));
+
+ ARROW_ASSIGN_OR_RAISE(const auto json_offsets, GetMemberArray(obj_, "OFFSET"));
+ return GetIntArray<int32_t>(json_offsets, length_, &data_->buffers[2]);
+ }
+
+ Status Visit(const NullType& type) {
+ data_ = std::make_shared<NullArray>(length_)->data();
+ return Status::OK();
+ }
+
+ Status Visit(const DictionaryType& type) {
+ ArrayReader parser(obj_, pool_, ::arrow::field("indices", type.index_type()));
+ ARROW_ASSIGN_OR_RAISE(data_, parser.Parse());
+
+ data_->type = field_->type();
+ // data_->dictionary will be filled later by ResolveDictionaries()
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionType& type) {
+ ArrayReader parser(obj_, pool_, field_->WithType(type.storage_type()));
+ ARROW_ASSIGN_OR_RAISE(data_, parser.Parse());
+ data_->type = type_;
+ // If the storage array is a dictionary array, lookup its dictionary id
+ // using the extension field.
+ // (the field is looked up by pointer, so the Field instance constructed
+ // above wouldn't work)
+ return Status::OK();
+ }
+
+ Status InitializeData(int num_buffers) {
+ data_ = std::make_shared<ArrayData>(type_, length_);
+ data_->buffers.resize(num_buffers);
+ return Status::OK();
+ }
+
+ Status GetNullBitmap() {
+ const int64_t length = static_cast<int64_t>(is_valid_.size());
+
+ ARROW_ASSIGN_OR_RAISE(data_->buffers[0], AllocateEmptyBitmap(length, pool_));
+ uint8_t* bitmap = data_->buffers[0]->mutable_data();
+
+ data_->null_count = 0;
+ for (int64_t i = 0; i < length; ++i) {
+ if (is_valid_[i]) {
+ BitUtil::SetBit(bitmap, i);
+ } else {
+ ++data_->null_count;
+ }
+ }
+ if (data_->null_count == 0) {
+ data_->buffers[0].reset();
+ }
+
+ return Status::OK();
+ }
+ Status GetChildren(const RjObject& obj, const DataType& type) {
+ ARROW_ASSIGN_OR_RAISE(const auto json_children,
+ GetMemberArray(obj, "children", /*allow_absent=*/true));
+
+ if (type.num_fields() != static_cast<int>(json_children.Size())) {
+ return Status::Invalid("Expected ", type.num_fields(), " children, but got ",
+ json_children.Size());
+ }
+
+ data_->child_data.resize(type.num_fields());
+ for (int i = 0; i < type.num_fields(); ++i) {
+ const rj::Value& json_child = json_children[i];
+ DCHECK(json_child.IsObject());
+ const auto& child_obj = json_child.GetObject();
+
+ std::shared_ptr<Field> child_field = type.field(i);
+
+ auto it = json_child.FindMember("name");
+ RETURN_NOT_STRING("name", it, json_child);
+
+ DCHECK_EQ(it->value.GetString(), child_field->name());
+ ArrayReader child_reader(child_obj, pool_, child_field);
+ ARROW_ASSIGN_OR_RAISE(data_->child_data[i], child_reader.Parse());
+ }
+
+ return Status::OK();
+ }
+
+ Status ParseValidityBitmap() {
+ ARROW_ASSIGN_OR_RAISE(const auto json_validity, GetMemberArray(obj_, "VALIDITY"));
+ if (static_cast<int>(json_validity.Size()) != length_) {
+ return Status::Invalid("JSON VALIDITY size differs from advertised array length");
+ }
+ is_valid_.reserve(json_validity.Size());
+ for (const rj::Value& val : json_validity) {
+ DCHECK(val.IsInt());
+ is_valid_.push_back(val.GetInt() != 0);
+ }
+ return Status::OK();
+ }
+
+ Result<std::shared_ptr<ArrayData>> Parse() {
+ ARROW_ASSIGN_OR_RAISE(length_, GetMemberInt<int32_t>(obj_, "count"));
+
+ if (::arrow::internal::HasValidityBitmap(type_->id())) {
+ // Null and union types don't have a validity bitmap
+ RETURN_NOT_OK(ParseValidityBitmap());
+ }
+
+ RETURN_NOT_OK(VisitTypeInline(*type_, this));
+ return data_;
+ }
+
+ private:
+ const RjObject& obj_;
+ MemoryPool* pool_;
+ std::shared_ptr<Field> field_;
+ std::shared_ptr<DataType> type_;
+
+ // Parsed common attributes
+ std::vector<bool> is_valid_;
+ int32_t length_;
+ std::shared_ptr<ArrayData> data_;
+};
+
+Result<std::shared_ptr<ArrayData>> ReadArrayData(MemoryPool* pool,
+ const rj::Value& json_array,
+ const std::shared_ptr<Field>& field) {
+ if (!json_array.IsObject()) {
+ return Status::Invalid("Array element was not a JSON object");
+ }
+ auto obj = json_array.GetObject();
+ ArrayReader parser(obj, pool, field);
+ return parser.Parse();
+}
+
+Status ReadDictionary(const RjObject& obj, MemoryPool* pool,
+ DictionaryMemo* dictionary_memo) {
+ ARROW_ASSIGN_OR_RAISE(int64_t dictionary_id, GetMemberInt<int64_t>(obj, "id"));
+
+ ARROW_ASSIGN_OR_RAISE(const auto batch_obj, GetMemberObject(obj, "data"));
+
+ ARROW_ASSIGN_OR_RAISE(auto value_type,
+ dictionary_memo->GetDictionaryType(dictionary_id));
+
+ ARROW_ASSIGN_OR_RAISE(const int64_t num_rows,
+ GetMemberInt<int64_t>(batch_obj, "count"));
+ ARROW_ASSIGN_OR_RAISE(const auto json_columns, GetMemberArray(batch_obj, "columns"));
+ if (json_columns.Size() != 1) {
+ return Status::Invalid("Dictionary batch must contain only one column");
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto dict_data,
+ ReadArrayData(pool, json_columns[0], field("dummy", value_type)));
+ if (num_rows != dict_data->length) {
+ return Status::Invalid("Dictionary batch length mismatch: advertised (", num_rows,
+ ") != actual (", dict_data->length, ")");
+ }
+ return dictionary_memo->AddDictionary(dictionary_id, dict_data);
+}
+
+Status ReadDictionaries(const rj::Value& doc, MemoryPool* pool,
+ DictionaryMemo* dictionary_memo) {
+ auto it = doc.FindMember("dictionaries");
+ if (it == doc.MemberEnd()) {
+ // No dictionaries
+ return Status::OK();
+ }
+
+ RETURN_NOT_ARRAY("dictionaries", it, doc);
+ const auto& dictionary_array = it->value.GetArray();
+
+ for (const rj::Value& val : dictionary_array) {
+ DCHECK(val.IsObject());
+ RETURN_NOT_OK(ReadDictionary(val.GetObject(), pool, dictionary_memo));
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status ReadSchema(const rj::Value& json_schema, MemoryPool* pool,
+ DictionaryMemo* dictionary_memo, std::shared_ptr<Schema>* schema) {
+ DCHECK(json_schema.IsObject());
+ ARROW_ASSIGN_OR_RAISE(const auto obj_schema,
+ GetMemberObject(json_schema.GetObject(), "schema"));
+
+ ARROW_ASSIGN_OR_RAISE(const auto json_fields, GetMemberArray(obj_schema, "fields"));
+
+ std::shared_ptr<KeyValueMetadata> metadata;
+ RETURN_NOT_OK(GetKeyValueMetadata(obj_schema, &metadata));
+
+ std::vector<std::shared_ptr<Field>> fields;
+ RETURN_NOT_OK(
+ GetFieldsFromArray(json_fields, FieldPosition(), dictionary_memo, &fields));
+
+ // Read the dictionaries (if any) and cache in the memo
+ RETURN_NOT_OK(ReadDictionaries(json_schema, pool, dictionary_memo));
+
+ *schema = ::arrow::schema(fields, metadata);
+ return Status::OK();
+}
+
+Status ReadArray(MemoryPool* pool, const rj::Value& json_array,
+ const std::shared_ptr<Field>& field, std::shared_ptr<Array>* out) {
+ ARROW_ASSIGN_OR_RAISE(auto data, ReadArrayData(pool, json_array, field));
+ *out = MakeArray(data);
+ return Status::OK();
+}
+
+Status ReadRecordBatch(const rj::Value& json_obj, const std::shared_ptr<Schema>& schema,
+ DictionaryMemo* dictionary_memo, MemoryPool* pool,
+ std::shared_ptr<RecordBatch>* batch) {
+ DCHECK(json_obj.IsObject());
+ const auto& batch_obj = json_obj.GetObject();
+
+ ARROW_ASSIGN_OR_RAISE(const int64_t num_rows,
+ GetMemberInt<int64_t>(batch_obj, "count"));
+
+ ARROW_ASSIGN_OR_RAISE(const auto json_columns, GetMemberArray(batch_obj, "columns"));
+
+ ArrayDataVector columns(json_columns.Size());
+ for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
+ ARROW_ASSIGN_OR_RAISE(columns[i],
+ ReadArrayData(pool, json_columns[i], schema->field(i)));
+ }
+
+ RETURN_NOT_OK(ResolveDictionaries(columns, *dictionary_memo, pool));
+
+ *batch = RecordBatch::Make(schema, num_rows, columns);
+ return Status::OK();
+}
+
+Status WriteSchema(const Schema& schema, const DictionaryFieldMapper& mapper,
+ RjWriter* json_writer) {
+ SchemaWriter converter(schema, mapper, json_writer);
+ return converter.Write();
+}
+
+Status WriteDictionary(int64_t id, const std::shared_ptr<Array>& dictionary,
+ RjWriter* writer) {
+ writer->StartObject();
+ writer->Key("id");
+ writer->Int(static_cast<int32_t>(id));
+ writer->Key("data");
+
+ // Make a dummy record batch. A bit tedious as we have to make a schema
+ auto schema = ::arrow::schema({arrow::field("dictionary", dictionary->type())});
+ auto batch = RecordBatch::Make(schema, dictionary->length(), {dictionary});
+ RETURN_NOT_OK(WriteRecordBatch(*batch, writer));
+ writer->EndObject();
+ return Status::OK();
+}
+
+Status WriteRecordBatch(const RecordBatch& batch, RjWriter* writer) {
+ writer->StartObject();
+ writer->Key("count");
+ writer->Int(static_cast<int32_t>(batch.num_rows()));
+
+ writer->Key("columns");
+ writer->StartArray();
+
+ for (int i = 0; i < batch.num_columns(); ++i) {
+ const std::shared_ptr<Array>& column = batch.column(i);
+
+ DCHECK_EQ(batch.num_rows(), column->length())
+ << "Array length did not match record batch length: " << batch.num_rows()
+ << " != " << column->length() << " " << batch.column_name(i);
+
+ RETURN_NOT_OK(WriteArray(batch.column_name(i), *column, writer));
+ }
+
+ writer->EndArray();
+ writer->EndObject();
+ return Status::OK();
+}
+
+Status WriteArray(const std::string& name, const Array& array, RjWriter* json_writer) {
+ ArrayWriter converter(name, array, json_writer);
+ return converter.Write();
+}
+
+} // namespace json
+} // namespace testing
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/json_internal.h b/src/arrow/cpp/src/arrow/testing/json_internal.h
new file mode 100644
index 000000000..0870dd1e7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/json_internal.h
@@ -0,0 +1,126 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
+
+#include <rapidjson/document.h> // IWYU pragma: export
+#include <rapidjson/encodings.h> // IWYU pragma: export
+#include <rapidjson/error/en.h> // IWYU pragma: export
+#include <rapidjson/rapidjson.h> // IWYU pragma: export
+#include <rapidjson/stringbuffer.h> // IWYU pragma: export
+#include <rapidjson/writer.h> // IWYU pragma: export
+
+#include "arrow/status.h" // IWYU pragma: export
+#include "arrow/testing/visibility.h"
+#include "arrow/type_fwd.h" // IWYU pragma: keep
+
+namespace rj = arrow::rapidjson;
+using RjWriter = rj::Writer<rj::StringBuffer>;
+using RjArray = rj::Value::ConstArray;
+using RjObject = rj::Value::ConstObject;
+
+#define RETURN_NOT_FOUND(TOK, NAME, PARENT) \
+ if (NAME == (PARENT).MemberEnd()) { \
+ return Status::Invalid("field ", TOK, " not found"); \
+ }
+
+#define RETURN_NOT_STRING(TOK, NAME, PARENT) \
+ RETURN_NOT_FOUND(TOK, NAME, PARENT); \
+ if (!NAME->value.IsString()) { \
+ return Status::Invalid("field was not a string line ", __LINE__); \
+ }
+
+#define RETURN_NOT_BOOL(TOK, NAME, PARENT) \
+ RETURN_NOT_FOUND(TOK, NAME, PARENT); \
+ if (!NAME->value.IsBool()) { \
+ return Status::Invalid("field was not a boolean line ", __LINE__); \
+ }
+
+#define RETURN_NOT_INT(TOK, NAME, PARENT) \
+ RETURN_NOT_FOUND(TOK, NAME, PARENT); \
+ if (!NAME->value.IsInt()) { \
+ return Status::Invalid("field was not an int line ", __LINE__); \
+ }
+
+#define RETURN_NOT_ARRAY(TOK, NAME, PARENT) \
+ RETURN_NOT_FOUND(TOK, NAME, PARENT); \
+ if (!NAME->value.IsArray()) { \
+ return Status::Invalid("field was not an array line ", __LINE__); \
+ }
+
+#define RETURN_NOT_OBJECT(TOK, NAME, PARENT) \
+ RETURN_NOT_FOUND(TOK, NAME, PARENT); \
+ if (!NAME->value.IsObject()) { \
+ return Status::Invalid("field was not an object line ", __LINE__); \
+ }
+
+namespace arrow {
+
+class Array;
+class Field;
+class MemoryPool;
+class RecordBatch;
+class Schema;
+
+namespace ipc {
+
+class DictionaryFieldMapper;
+class DictionaryMemo;
+
+} // namespace ipc
+
+namespace testing {
+namespace json {
+
+/// \brief Append integration test Schema format to rapidjson writer
+ARROW_TESTING_EXPORT
+Status WriteSchema(const Schema& schema, const ipc::DictionaryFieldMapper& mapper,
+ RjWriter* writer);
+
+ARROW_TESTING_EXPORT
+Status WriteDictionary(int64_t id, const std::shared_ptr<Array>& dictionary,
+ RjWriter* writer);
+
+ARROW_TESTING_EXPORT
+Status WriteRecordBatch(const RecordBatch& batch, RjWriter* writer);
+
+ARROW_TESTING_EXPORT
+Status WriteArray(const std::string& name, const Array& array, RjWriter* writer);
+
+ARROW_TESTING_EXPORT
+Status ReadSchema(const rj::Value& json_obj, MemoryPool* pool,
+ ipc::DictionaryMemo* dictionary_memo, std::shared_ptr<Schema>* schema);
+
+ARROW_TESTING_EXPORT
+Status ReadRecordBatch(const rj::Value& json_obj, const std::shared_ptr<Schema>& schema,
+ ipc::DictionaryMemo* dict_memo, MemoryPool* pool,
+ std::shared_ptr<RecordBatch>* batch);
+
+// NOTE: Doesn't work with dictionary arrays, use ReadRecordBatch instead.
+ARROW_TESTING_EXPORT
+Status ReadArray(MemoryPool* pool, const rj::Value& json_obj,
+ const std::shared_ptr<Field>& type, std::shared_ptr<Array>* array);
+
+} // namespace json
+} // namespace testing
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/matchers.h b/src/arrow/cpp/src/arrow/testing/matchers.h
new file mode 100644
index 000000000..b64269ea7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/matchers.h
@@ -0,0 +1,237 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <gmock/gmock-matchers.h>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/future.h"
+
+namespace arrow {
+
+template <typename ResultMatcher>
+class FutureMatcher {
+ public:
+ explicit FutureMatcher(ResultMatcher result_matcher, double wait_seconds)
+ : result_matcher_(std::move(result_matcher)), wait_seconds_(wait_seconds) {}
+
+ template <typename Fut,
+ typename ValueType = typename std::decay<Fut>::type::ValueType>
+ operator testing::Matcher<Fut>() const { // NOLINT runtime/explicit
+ struct Impl : testing::MatcherInterface<const Fut&> {
+ explicit Impl(const ResultMatcher& result_matcher, double wait_seconds)
+ : result_matcher_(testing::MatcherCast<Result<ValueType>>(result_matcher)),
+ wait_seconds_(wait_seconds) {}
+
+ void DescribeTo(::std::ostream* os) const override {
+ *os << "value ";
+ result_matcher_.DescribeTo(os);
+ }
+
+ void DescribeNegationTo(::std::ostream* os) const override {
+ *os << "value ";
+ result_matcher_.DescribeNegationTo(os);
+ }
+
+ bool MatchAndExplain(const Fut& fut,
+ testing::MatchResultListener* listener) const override {
+ if (!fut.Wait(wait_seconds_)) {
+ *listener << "which didn't finish within " << wait_seconds_ << " seconds";
+ return false;
+ }
+ return result_matcher_.MatchAndExplain(fut.result(), listener);
+ }
+
+ const testing::Matcher<Result<ValueType>> result_matcher_;
+ const double wait_seconds_;
+ };
+
+ return testing::Matcher<Fut>(new Impl(result_matcher_, wait_seconds_));
+ }
+
+ private:
+ const ResultMatcher result_matcher_;
+ const double wait_seconds_;
+};
+
+template <typename ValueMatcher>
+class ResultMatcher {
+ public:
+ explicit ResultMatcher(ValueMatcher value_matcher)
+ : value_matcher_(std::move(value_matcher)) {}
+
+ template <typename Res,
+ typename ValueType = typename std::decay<Res>::type::ValueType>
+ operator testing::Matcher<Res>() const { // NOLINT runtime/explicit
+ struct Impl : testing::MatcherInterface<const Res&> {
+ explicit Impl(const ValueMatcher& value_matcher)
+ : value_matcher_(testing::MatcherCast<ValueType>(value_matcher)) {}
+
+ void DescribeTo(::std::ostream* os) const override {
+ *os << "value ";
+ value_matcher_.DescribeTo(os);
+ }
+
+ void DescribeNegationTo(::std::ostream* os) const override {
+ *os << "value ";
+ value_matcher_.DescribeNegationTo(os);
+ }
+
+ bool MatchAndExplain(const Res& maybe_value,
+ testing::MatchResultListener* listener) const override {
+ if (!maybe_value.status().ok()) {
+ *listener << "whose error "
+ << testing::PrintToString(maybe_value.status().ToString())
+ << " doesn't match";
+ return false;
+ }
+ const ValueType& value = maybe_value.ValueOrDie();
+ testing::StringMatchResultListener value_listener;
+ const bool match = value_matcher_.MatchAndExplain(value, &value_listener);
+ *listener << "whose value " << testing::PrintToString(value)
+ << (match ? " matches" : " doesn't match");
+ testing::internal::PrintIfNotEmpty(value_listener.str(), listener->stream());
+ return match;
+ }
+
+ const testing::Matcher<ValueType> value_matcher_;
+ };
+
+ return testing::Matcher<Res>(new Impl(value_matcher_));
+ }
+
+ private:
+ const ValueMatcher value_matcher_;
+};
+
+class ErrorMatcher {
+ public:
+ explicit ErrorMatcher(StatusCode code,
+ util::optional<testing::Matcher<std::string>> message_matcher)
+ : code_(code), message_matcher_(std::move(message_matcher)) {}
+
+ template <typename Res>
+ operator testing::Matcher<Res>() const { // NOLINT runtime/explicit
+ struct Impl : testing::MatcherInterface<const Res&> {
+ explicit Impl(StatusCode code,
+ util::optional<testing::Matcher<std::string>> message_matcher)
+ : code_(code), message_matcher_(std::move(message_matcher)) {}
+
+ void DescribeTo(::std::ostream* os) const override {
+ *os << "raises StatusCode::" << Status::CodeAsString(code_);
+ if (message_matcher_) {
+ *os << " and message ";
+ message_matcher_->DescribeTo(os);
+ }
+ }
+
+ void DescribeNegationTo(::std::ostream* os) const override {
+ *os << "does not raise StatusCode::" << Status::CodeAsString(code_);
+ if (message_matcher_) {
+ *os << " or message ";
+ message_matcher_->DescribeNegationTo(os);
+ }
+ }
+
+ bool MatchAndExplain(const Res& maybe_value,
+ testing::MatchResultListener* listener) const override {
+ const Status& status = internal::GenericToStatus(maybe_value);
+ testing::StringMatchResultListener value_listener;
+
+ bool match = status.code() == code_;
+ if (message_matcher_) {
+ match = match &&
+ message_matcher_->MatchAndExplain(status.message(), &value_listener);
+ }
+
+ *listener << "whose value " << testing::PrintToString(status.ToString())
+ << (match ? " matches" : " doesn't match");
+ testing::internal::PrintIfNotEmpty(value_listener.str(), listener->stream());
+ return match;
+ }
+
+ const StatusCode code_;
+ const util::optional<testing::Matcher<std::string>> message_matcher_;
+ };
+
+ return testing::Matcher<Res>(new Impl(code_, message_matcher_));
+ }
+
+ private:
+ const StatusCode code_;
+ const util::optional<testing::Matcher<std::string>> message_matcher_;
+};
+
+class OkMatcher {
+ public:
+ template <typename Res>
+ operator testing::Matcher<Res>() const { // NOLINT runtime/explicit
+ struct Impl : testing::MatcherInterface<const Res&> {
+ void DescribeTo(::std::ostream* os) const override { *os << "is ok"; }
+
+ void DescribeNegationTo(::std::ostream* os) const override { *os << "is not ok"; }
+
+ bool MatchAndExplain(const Res& maybe_value,
+ testing::MatchResultListener* listener) const override {
+ const Status& status = internal::GenericToStatus(maybe_value);
+ testing::StringMatchResultListener value_listener;
+
+ const bool match = status.ok();
+ *listener << "whose value " << testing::PrintToString(status.ToString())
+ << (match ? " matches" : " doesn't match");
+ testing::internal::PrintIfNotEmpty(value_listener.str(), listener->stream());
+ return match;
+ }
+ };
+
+ return testing::Matcher<Res>(new Impl());
+ }
+};
+
+// Returns a matcher that waits on a Future (by default for 16 seconds)
+// then applies a matcher to the result.
+template <typename ResultMatcher>
+FutureMatcher<ResultMatcher> Finishes(
+ const ResultMatcher& result_matcher,
+ double wait_seconds = kDefaultAssertFinishesWaitSeconds) {
+ return FutureMatcher<ResultMatcher>(result_matcher, wait_seconds);
+}
+
+// Returns a matcher that matches the value of a successful Result<T>.
+template <typename ValueMatcher>
+ResultMatcher<ValueMatcher> ResultWith(const ValueMatcher& value_matcher) {
+ return ResultMatcher<ValueMatcher>(value_matcher);
+}
+
+// Returns a matcher that matches an ok Status or Result<T>.
+inline OkMatcher Ok() { return {}; }
+
+// Returns a matcher that matches the StatusCode of a Status or Result<T>.
+// Do not use Raises(StatusCode::OK) to match a non error code.
+inline ErrorMatcher Raises(StatusCode code) { return ErrorMatcher(code, util::nullopt); }
+
+// Returns a matcher that matches the StatusCode and message of a Status or Result<T>.
+template <typename MessageMatcher>
+ErrorMatcher Raises(StatusCode code, const MessageMatcher& message_matcher) {
+ return ErrorMatcher(code, testing::MatcherCast<std::string>(message_matcher));
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/pch.h b/src/arrow/cpp/src/arrow/testing/pch.h
new file mode 100644
index 000000000..d6c3c7496
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/pch.h
@@ -0,0 +1,26 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Often-used headers, for precompiling.
+// If updating this header, please make sure you check compilation speed
+// before checking in. Adding headers which are not used extremely often
+// may incur a slowdown, since it makes the precompiled header heavier to load.
+
+#include "arrow/pch.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
diff --git a/src/arrow/cpp/src/arrow/testing/random.cc b/src/arrow/cpp/src/arrow/testing/random.cc
new file mode 100644
index 000000000..ce6ec1a6e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/random.cc
@@ -0,0 +1,949 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/testing/random.h"
+
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <array>
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <random>
+#include <type_traits>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_decimal.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/buffer.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/pcg_random.h"
+#include "arrow/util/value_parsing.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace random {
+
+namespace {
+
+template <typename ValueType, typename DistributionType>
+struct GenerateOptions {
+ GenerateOptions(SeedType seed, ValueType min, ValueType max, double probability,
+ double nan_probability = 0.0)
+ : min_(min),
+ max_(max),
+ seed_(seed),
+ probability_(probability),
+ nan_probability_(nan_probability) {}
+
+ void GenerateData(uint8_t* buffer, size_t n) {
+ GenerateTypedData(reinterpret_cast<ValueType*>(buffer), n);
+ }
+
+ template <typename V>
+ typename std::enable_if<!std::is_floating_point<V>::value>::type GenerateTypedData(
+ V* data, size_t n) {
+ GenerateTypedDataNoNan(data, n);
+ }
+
+ template <typename V>
+ typename std::enable_if<std::is_floating_point<V>::value>::type GenerateTypedData(
+ V* data, size_t n) {
+ if (nan_probability_ == 0.0) {
+ GenerateTypedDataNoNan(data, n);
+ return;
+ }
+ pcg32_fast rng(seed_++);
+ DistributionType dist(min_, max_);
+ ::arrow::random::bernoulli_distribution nan_dist(nan_probability_);
+ const ValueType nan_value = std::numeric_limits<ValueType>::quiet_NaN();
+
+ // A static cast is required due to the int16 -> int8 handling.
+ std::generate(data, data + n, [&] {
+ return nan_dist(rng) ? nan_value : static_cast<ValueType>(dist(rng));
+ });
+ }
+
+ void GenerateTypedDataNoNan(ValueType* data, size_t n) {
+ pcg32_fast rng(seed_++);
+ DistributionType dist(min_, max_);
+
+ // A static cast is required due to the int16 -> int8 handling.
+ std::generate(data, data + n, [&] { return static_cast<ValueType>(dist(rng)); });
+ }
+
+ void GenerateBitmap(uint8_t* buffer, size_t n, int64_t* null_count) {
+ int64_t count = 0;
+ pcg32_fast rng(seed_++);
+ ::arrow::random::bernoulli_distribution dist(1.0 - probability_);
+
+ for (size_t i = 0; i < n; i++) {
+ if (dist(rng)) {
+ BitUtil::SetBit(buffer, i);
+ } else {
+ count++;
+ }
+ }
+
+ if (null_count != nullptr) *null_count = count;
+ }
+
+ ValueType min_;
+ ValueType max_;
+ SeedType seed_;
+ double probability_;
+ double nan_probability_;
+};
+
+} // namespace
+
+std::shared_ptr<Buffer> RandomArrayGenerator::NullBitmap(int64_t size,
+ double null_probability) {
+ // The bitmap generator does not care about the value distribution since it
+ // only calls the GenerateBitmap method.
+ using GenOpt = GenerateOptions<int, std::uniform_int_distribution<int>>;
+
+ GenOpt null_gen(seed(), 0, 1, null_probability);
+ std::shared_ptr<Buffer> bitmap = *AllocateEmptyBitmap(size);
+ null_gen.GenerateBitmap(bitmap->mutable_data(), size, nullptr);
+
+ return bitmap;
+}
+
+std::shared_ptr<Array> RandomArrayGenerator::Boolean(int64_t size,
+ double true_probability,
+ double null_probability) {
+ // The boolean generator does not care about the value distribution since it
+ // only calls the GenerateBitmap method.
+ using GenOpt = GenerateOptions<int, std::uniform_int_distribution<int>>;
+
+ BufferVector buffers{2};
+ // Need 2 distinct generators such that probabilities are not shared.
+
+ // The "GenerateBitmap" function is written to generate validity bitmaps
+ // parameterized by the null probability, which is the probability of 0. For
+ // boolean data, the true probability is the probability of 1, so to use
+ // GenerateBitmap we must provide the probability of false instead.
+ GenOpt value_gen(seed(), 0, 1, 1 - true_probability);
+
+ GenOpt null_gen(seed(), 0, 1, null_probability);
+
+ int64_t null_count = 0;
+ buffers[0] = *AllocateEmptyBitmap(size);
+ null_gen.GenerateBitmap(buffers[0]->mutable_data(), size, &null_count);
+
+ buffers[1] = *AllocateEmptyBitmap(size);
+ value_gen.GenerateBitmap(buffers[1]->mutable_data(), size, nullptr);
+
+ auto array_data = ArrayData::Make(arrow::boolean(), size, buffers, null_count);
+ return std::make_shared<BooleanArray>(array_data);
+}
+
+template <typename ArrowType, typename OptionType>
+static std::shared_ptr<NumericArray<ArrowType>> GenerateNumericArray(int64_t size,
+ OptionType options) {
+ using CType = typename ArrowType::c_type;
+ auto type = TypeTraits<ArrowType>::type_singleton();
+ BufferVector buffers{2};
+
+ int64_t null_count = 0;
+ buffers[0] = *AllocateEmptyBitmap(size);
+ options.GenerateBitmap(buffers[0]->mutable_data(), size, &null_count);
+
+ buffers[1] = *AllocateBuffer(sizeof(CType) * size);
+ options.GenerateData(buffers[1]->mutable_data(), size);
+
+ auto array_data = ArrayData::Make(type, size, buffers, null_count);
+ return std::make_shared<NumericArray<ArrowType>>(array_data);
+}
+
+#define PRIMITIVE_RAND_IMPL(Name, CType, ArrowType, Distribution) \
+ std::shared_ptr<Array> RandomArrayGenerator::Name(int64_t size, CType min, CType max, \
+ double probability) { \
+ using OptionType = GenerateOptions<CType, Distribution>; \
+ OptionType options(seed(), min, max, probability); \
+ return GenerateNumericArray<ArrowType, OptionType>(size, options); \
+ }
+
+#define PRIMITIVE_RAND_INTEGER_IMPL(Name, CType, ArrowType) \
+ PRIMITIVE_RAND_IMPL(Name, CType, ArrowType, std::uniform_int_distribution<CType>)
+
+// Visual Studio does not implement uniform_int_distribution for char types.
+PRIMITIVE_RAND_IMPL(UInt8, uint8_t, UInt8Type, std::uniform_int_distribution<uint16_t>)
+PRIMITIVE_RAND_IMPL(Int8, int8_t, Int8Type, std::uniform_int_distribution<int16_t>)
+
+PRIMITIVE_RAND_INTEGER_IMPL(UInt16, uint16_t, UInt16Type)
+PRIMITIVE_RAND_INTEGER_IMPL(Int16, int16_t, Int16Type)
+PRIMITIVE_RAND_INTEGER_IMPL(UInt32, uint32_t, UInt32Type)
+PRIMITIVE_RAND_INTEGER_IMPL(Int32, int32_t, Int32Type)
+PRIMITIVE_RAND_INTEGER_IMPL(UInt64, uint64_t, UInt64Type)
+PRIMITIVE_RAND_INTEGER_IMPL(Int64, int64_t, Int64Type)
+// Generate 16bit values for half-float
+PRIMITIVE_RAND_INTEGER_IMPL(Float16, int16_t, HalfFloatType)
+
+std::shared_ptr<Array> RandomArrayGenerator::Float32(int64_t size, float min, float max,
+ double null_probability,
+ double nan_probability) {
+ using OptionType =
+ GenerateOptions<float, ::arrow::random::uniform_real_distribution<float>>;
+ OptionType options(seed(), min, max, null_probability, nan_probability);
+ return GenerateNumericArray<FloatType, OptionType>(size, options);
+}
+
+std::shared_ptr<Array> RandomArrayGenerator::Float64(int64_t size, double min, double max,
+ double null_probability,
+ double nan_probability) {
+ using OptionType =
+ GenerateOptions<double, ::arrow::random::uniform_real_distribution<double>>;
+ OptionType options(seed(), min, max, null_probability, nan_probability);
+ return GenerateNumericArray<DoubleType, OptionType>(size, options);
+}
+
+#undef PRIMITIVE_RAND_INTEGER_IMPL
+#undef PRIMITIVE_RAND_IMPL
+
+namespace {
+
+// A generic generator for random decimal arrays
+template <typename DecimalType>
+struct DecimalGenerator {
+ using DecimalBuilderType = typename TypeTraits<DecimalType>::BuilderType;
+ using DecimalValue = typename DecimalBuilderType::ValueType;
+
+ std::shared_ptr<DataType> type_;
+ RandomArrayGenerator* rng_;
+
+ static uint64_t MaxDecimalInteger(int32_t digits) {
+ // Need to decrement *after* the cast to uint64_t because, while
+ // 10**x is exactly representable in a double for x <= 19,
+ // 10**x - 1 is not.
+ return static_cast<uint64_t>(std::ceil(std::pow(10.0, digits))) - 1;
+ }
+
+ std::shared_ptr<Array> MakeRandomArray(int64_t size, double null_probability) {
+ // 10**19 fits in a 64-bit unsigned integer
+ static constexpr int32_t kMaxDigitsInInteger = 19;
+ static constexpr int kNumIntegers = DecimalType::kByteWidth / 8;
+
+ static_assert(
+ kNumIntegers ==
+ (DecimalType::kMaxPrecision + kMaxDigitsInInteger - 1) / kMaxDigitsInInteger,
+ "inconsistent decimal metadata: kMaxPrecision doesn't match kByteWidth");
+
+ // First generate separate random values for individual components:
+ // boolean sign (including null-ness), and uint64 "digits" in big endian order.
+ const auto& decimal_type = checked_cast<const DecimalType&>(*type_);
+
+ const auto sign_array = checked_pointer_cast<BooleanArray>(
+ rng_->Boolean(size, /*true_probability=*/0.5, null_probability));
+ std::array<std::shared_ptr<UInt64Array>, kNumIntegers> digit_arrays;
+
+ auto remaining_digits = decimal_type.precision();
+ for (int i = kNumIntegers - 1; i >= 0; --i) {
+ const auto digits = std::min(kMaxDigitsInInteger, remaining_digits);
+ digit_arrays[i] = checked_pointer_cast<UInt64Array>(
+ rng_->UInt64(size, 0, MaxDecimalInteger(digits)));
+ DCHECK_EQ(digit_arrays[i]->null_count(), 0);
+ remaining_digits -= digits;
+ }
+
+ // Second compute decimal values from the individual components,
+ // building up a decimal array.
+ DecimalBuilderType builder(type_);
+ ABORT_NOT_OK(builder.Reserve(size));
+
+ const DecimalValue kDigitsMultiplier =
+ DecimalValue::GetScaleMultiplier(kMaxDigitsInInteger);
+
+ for (int64_t i = 0; i < size; ++i) {
+ if (sign_array->IsValid(i)) {
+ DecimalValue dec_value{0};
+ for (int j = 0; j < kNumIntegers; ++j) {
+ dec_value =
+ dec_value * kDigitsMultiplier + DecimalValue(digit_arrays[j]->Value(i));
+ }
+ if (sign_array->Value(i)) {
+ builder.UnsafeAppend(dec_value.Negate());
+ } else {
+ builder.UnsafeAppend(dec_value);
+ }
+ } else {
+ builder.UnsafeAppendNull();
+ }
+ }
+ std::shared_ptr<Array> array;
+ ABORT_NOT_OK(builder.Finish(&array));
+ return array;
+ }
+};
+
+} // namespace
+
+std::shared_ptr<Array> RandomArrayGenerator::Decimal128(std::shared_ptr<DataType> type,
+ int64_t size,
+ double null_probability) {
+ DecimalGenerator<Decimal128Type> gen{type, this};
+ return gen.MakeRandomArray(size, null_probability);
+}
+
+std::shared_ptr<Array> RandomArrayGenerator::Decimal256(std::shared_ptr<DataType> type,
+ int64_t size,
+ double null_probability) {
+ DecimalGenerator<Decimal256Type> gen{type, this};
+ return gen.MakeRandomArray(size, null_probability);
+}
+
+template <typename TypeClass>
+static std::shared_ptr<Array> GenerateBinaryArray(RandomArrayGenerator* gen, int64_t size,
+ int32_t min_length, int32_t max_length,
+ double null_probability) {
+ using offset_type = typename TypeClass::offset_type;
+ using BuilderType = typename TypeTraits<TypeClass>::BuilderType;
+ using OffsetArrowType = typename CTypeTraits<offset_type>::ArrowType;
+ using OffsetArrayType = typename TypeTraits<OffsetArrowType>::ArrayType;
+
+ if (null_probability < 0 || null_probability > 1) {
+ ABORT_NOT_OK(Status::Invalid("null_probability must be between 0 and 1"));
+ }
+
+ auto lengths = std::dynamic_pointer_cast<OffsetArrayType>(
+ gen->Numeric<OffsetArrowType>(size, min_length, max_length, null_probability));
+
+ // Visual Studio does not implement uniform_int_distribution for char types.
+ using GenOpt = GenerateOptions<uint8_t, std::uniform_int_distribution<uint16_t>>;
+ GenOpt options(gen->seed(), static_cast<uint8_t>('A'), static_cast<uint8_t>('z'),
+ /*null_probability=*/0);
+
+ std::vector<uint8_t> str_buffer(max_length);
+ BuilderType builder;
+
+ for (int64_t i = 0; i < size; ++i) {
+ if (lengths->IsValid(i)) {
+ options.GenerateData(str_buffer.data(), lengths->Value(i));
+ ABORT_NOT_OK(builder.Append(str_buffer.data(), lengths->Value(i)));
+ } else {
+ ABORT_NOT_OK(builder.AppendNull());
+ }
+ }
+
+ std::shared_ptr<Array> result;
+ ABORT_NOT_OK(builder.Finish(&result));
+ return result;
+}
+
+std::shared_ptr<Array> RandomArrayGenerator::String(int64_t size, int32_t min_length,
+ int32_t max_length,
+ double null_probability) {
+ return GenerateBinaryArray<StringType>(this, size, min_length, max_length,
+ null_probability);
+}
+
+std::shared_ptr<Array> RandomArrayGenerator::LargeString(int64_t size, int32_t min_length,
+ int32_t max_length,
+ double null_probability) {
+ return GenerateBinaryArray<LargeStringType>(this, size, min_length, max_length,
+ null_probability);
+}
+
+std::shared_ptr<Array> RandomArrayGenerator::BinaryWithRepeats(int64_t size,
+ int64_t unique,
+ int32_t min_length,
+ int32_t max_length,
+ double null_probability) {
+ auto strings =
+ StringWithRepeats(size, unique, min_length, max_length, null_probability);
+ std::shared_ptr<Array> out;
+ return *strings->View(binary());
+}
+
+std::shared_ptr<Array> RandomArrayGenerator::StringWithRepeats(int64_t size,
+ int64_t unique,
+ int32_t min_length,
+ int32_t max_length,
+ double null_probability) {
+ ARROW_CHECK_LE(unique, size);
+
+ // Generate a random string dictionary without any nulls
+ auto array = String(unique, min_length, max_length, /*null_probability=*/0);
+ auto dictionary = std::dynamic_pointer_cast<StringArray>(array);
+
+ // Generate random indices to sample the dictionary with
+ auto id_array = Int64(size, 0, unique - 1, null_probability);
+ auto indices = std::dynamic_pointer_cast<Int64Array>(id_array);
+ StringBuilder builder;
+
+ for (int64_t i = 0; i < size; ++i) {
+ if (indices->IsValid(i)) {
+ const auto index = indices->Value(i);
+ const auto value = dictionary->GetView(index);
+ ABORT_NOT_OK(builder.Append(value));
+ } else {
+ ABORT_NOT_OK(builder.AppendNull());
+ }
+ }
+
+ std::shared_ptr<Array> result;
+ ABORT_NOT_OK(builder.Finish(&result));
+ return result;
+}
+
+std::shared_ptr<Array> RandomArrayGenerator::FixedSizeBinary(int64_t size,
+ int32_t byte_width,
+ double null_probability) {
+ if (null_probability < 0 || null_probability > 1) {
+ ABORT_NOT_OK(Status::Invalid("null_probability must be between 0 and 1"));
+ }
+
+ // Visual Studio does not implement uniform_int_distribution for char types.
+ using GenOpt = GenerateOptions<uint8_t, std::uniform_int_distribution<uint16_t>>;
+ GenOpt options(seed(), static_cast<uint8_t>('A'), static_cast<uint8_t>('z'),
+ null_probability);
+
+ int64_t null_count = 0;
+ auto null_bitmap = *AllocateEmptyBitmap(size);
+ auto data_buffer = *AllocateBuffer(size * byte_width);
+ options.GenerateBitmap(null_bitmap->mutable_data(), size, &null_count);
+ options.GenerateData(data_buffer->mutable_data(), size * byte_width);
+
+ auto type = fixed_size_binary(byte_width);
+ return std::make_shared<FixedSizeBinaryArray>(type, size, std::move(data_buffer),
+ std::move(null_bitmap), null_count);
+}
+
+namespace {
+template <typename OffsetArrayType>
+std::shared_ptr<Array> GenerateOffsets(SeedType seed, int64_t size,
+ typename OffsetArrayType::value_type first_offset,
+ typename OffsetArrayType::value_type last_offset,
+ double null_probability, bool force_empty_nulls) {
+ using GenOpt = GenerateOptions<
+ typename OffsetArrayType::value_type,
+ std::uniform_int_distribution<typename OffsetArrayType::value_type>>;
+ GenOpt options(seed, first_offset, last_offset, null_probability);
+
+ BufferVector buffers{2};
+
+ int64_t null_count = 0;
+
+ buffers[0] = *AllocateEmptyBitmap(size);
+ uint8_t* null_bitmap = buffers[0]->mutable_data();
+ options.GenerateBitmap(null_bitmap, size, &null_count);
+ // Make sure the first and last entry are non-null
+ for (const int64_t offset : std::vector<int64_t>{0, size - 1}) {
+ if (!arrow::BitUtil::GetBit(null_bitmap, offset)) {
+ arrow::BitUtil::SetBit(null_bitmap, offset);
+ --null_count;
+ }
+ }
+
+ buffers[1] = *AllocateBuffer(sizeof(typename OffsetArrayType::value_type) * size);
+ auto data =
+ reinterpret_cast<typename OffsetArrayType::value_type*>(buffers[1]->mutable_data());
+ options.GenerateTypedData(data, size);
+ // Ensure offsets are in increasing order
+ std::sort(data, data + size);
+ // Ensure first and last offsets are as required
+ DCHECK_GE(data[0], first_offset);
+ DCHECK_LE(data[size - 1], last_offset);
+ data[0] = first_offset;
+ data[size - 1] = last_offset;
+
+ if (force_empty_nulls) {
+ arrow::internal::BitmapReader reader(null_bitmap, 0, size);
+ for (int64_t i = 0; i < size; ++i) {
+ if (reader.IsNotSet()) {
+ // Ensure a null entry corresponds to a 0-sized list extent
+ // (note this can be neither the first nor the last list entry, see above)
+ data[i + 1] = data[i];
+ }
+ reader.Next();
+ }
+ }
+
+ auto array_data = ArrayData::Make(
+ std::make_shared<typename OffsetArrayType::TypeClass>(), size, buffers, null_count);
+ return std::make_shared<OffsetArrayType>(array_data);
+}
+
+template <typename OffsetArrayType>
+std::shared_ptr<Array> OffsetsFromLengthsArray(OffsetArrayType* lengths,
+ bool force_empty_nulls) {
+ DCHECK(lengths->length() == 0 || !lengths->IsNull(0));
+ DCHECK(lengths->length() == 0 || !lengths->IsNull(lengths->length() - 1));
+ // Need N + 1 offsets for N items
+ int64_t size = lengths->length() + 1;
+ BufferVector buffers{2};
+
+ int64_t null_count = 0;
+
+ buffers[0] = *AllocateEmptyBitmap(size);
+ uint8_t* null_bitmap = buffers[0]->mutable_data();
+ // Make sure the first and last entry are non-null
+ arrow::BitUtil::SetBit(null_bitmap, 0);
+ arrow::BitUtil::SetBit(null_bitmap, size - 1);
+
+ buffers[1] = *AllocateBuffer(sizeof(typename OffsetArrayType::value_type) * size);
+ auto data =
+ reinterpret_cast<typename OffsetArrayType::value_type*>(buffers[1]->mutable_data());
+ data[0] = 0;
+ int index = 1;
+ for (const auto& length : *lengths) {
+ if (length.has_value()) {
+ arrow::BitUtil::SetBit(null_bitmap, index);
+ data[index] = data[index - 1] + *length;
+ DCHECK_GE(*length, 0);
+ } else {
+ data[index] = data[index - 1];
+ null_count++;
+ }
+ index++;
+ }
+
+ if (force_empty_nulls) {
+ arrow::internal::BitmapReader reader(null_bitmap, 0, size);
+ for (int64_t i = 0; i < size; ++i) {
+ if (reader.IsNotSet()) {
+ // Ensure a null entry corresponds to a 0-sized list extent
+ // (note this can be neither the first nor the last list entry, see above)
+ data[i + 1] = data[i];
+ }
+ reader.Next();
+ }
+ }
+
+ auto array_data = ArrayData::Make(
+ std::make_shared<typename OffsetArrayType::TypeClass>(), size, buffers, null_count);
+ return std::make_shared<OffsetArrayType>(array_data);
+}
+} // namespace
+
+std::shared_ptr<Array> RandomArrayGenerator::Offsets(int64_t size, int32_t first_offset,
+ int32_t last_offset,
+ double null_probability,
+ bool force_empty_nulls) {
+ return GenerateOffsets<NumericArray<Int32Type>>(seed(), size, first_offset, last_offset,
+ null_probability, force_empty_nulls);
+}
+
+std::shared_ptr<Array> RandomArrayGenerator::LargeOffsets(int64_t size,
+ int64_t first_offset,
+ int64_t last_offset,
+ double null_probability,
+ bool force_empty_nulls) {
+ return GenerateOffsets<NumericArray<Int64Type>>(seed(), size, first_offset, last_offset,
+ null_probability, force_empty_nulls);
+}
+
+std::shared_ptr<Array> RandomArrayGenerator::List(const Array& values, int64_t size,
+ double null_probability,
+ bool force_empty_nulls) {
+ auto offsets = Offsets(size + 1, static_cast<int32_t>(values.offset()),
+ static_cast<int32_t>(values.offset() + values.length()),
+ null_probability, force_empty_nulls);
+ return *::arrow::ListArray::FromArrays(*offsets, values);
+}
+
+std::shared_ptr<Array> RandomArrayGenerator::Map(const std::shared_ptr<Array>& keys,
+ const std::shared_ptr<Array>& items,
+ int64_t size, double null_probability,
+ bool force_empty_nulls) {
+ DCHECK_EQ(keys->length(), items->length());
+ auto offsets = Offsets(size + 1, static_cast<int32_t>(keys->offset()),
+ static_cast<int32_t>(keys->offset() + keys->length()),
+ null_probability, force_empty_nulls);
+ return *::arrow::MapArray::FromArrays(offsets, keys, items);
+}
+
+std::shared_ptr<Array> RandomArrayGenerator::SparseUnion(const ArrayVector& fields,
+ int64_t size) {
+ DCHECK_GT(fields.size(), 0);
+ // Trivial type codes map
+ std::vector<UnionArray::type_code_t> type_codes(fields.size());
+ std::iota(type_codes.begin(), type_codes.end(), 0);
+
+ // Generate array of type ids
+ auto type_ids = Int8(size, 0, static_cast<int8_t>(fields.size() - 1));
+ return *SparseUnionArray::Make(*type_ids, fields, type_codes);
+}
+
+std::shared_ptr<Array> RandomArrayGenerator::DenseUnion(const ArrayVector& fields,
+ int64_t size) {
+ DCHECK_GT(fields.size(), 0);
+ // Trivial type codes map
+ std::vector<UnionArray::type_code_t> type_codes(fields.size());
+ std::iota(type_codes.begin(), type_codes.end(), 0);
+
+ // Generate array of type ids
+ auto type_ids = Int8(size, 0, static_cast<int8_t>(fields.size() - 1));
+
+ // Generate array of offsets
+ const auto& concrete_ids = checked_cast<const Int8Array&>(*type_ids);
+ Int32Builder offsets_builder;
+ ABORT_NOT_OK(offsets_builder.Reserve(size));
+ std::vector<int32_t> last_offsets(fields.size(), 0);
+ for (int64_t i = 0; i < size; ++i) {
+ const auto field_id = concrete_ids.Value(i);
+ offsets_builder.UnsafeAppend(last_offsets[field_id]++);
+ }
+ std::shared_ptr<Array> offsets;
+ ABORT_NOT_OK(offsets_builder.Finish(&offsets));
+
+ return *DenseUnionArray::Make(*type_ids, *offsets, fields, type_codes);
+}
+
+namespace {
+
+// Helper for RandomArrayGenerator::ArrayOf: extract some C value from
+// a given metadata key.
+template <typename T, typename ArrowType = typename CTypeTraits<T>::ArrowType>
+enable_if_parameter_free<ArrowType, T> GetMetadata(const KeyValueMetadata* metadata,
+ const std::string& key,
+ T default_value) {
+ if (!metadata) return default_value;
+ const auto index = metadata->FindKey(key);
+ if (index < 0) return default_value;
+ const auto& value = metadata->value(index);
+ T output{};
+ if (!internal::ParseValue<ArrowType>(value.data(), value.length(), &output)) {
+ ABORT_NOT_OK(Status::Invalid("Could not parse ", key, " = ", value));
+ }
+ return output;
+}
+
+} // namespace
+
+std::shared_ptr<Array> RandomArrayGenerator::ArrayOf(std::shared_ptr<DataType> type,
+ int64_t size,
+ double null_probability) {
+ auto metadata =
+ key_value_metadata({"null_probability"}, {std::to_string(null_probability)});
+ auto field = ::arrow::field("", std::move(type), std::move(metadata));
+ return ArrayOf(*field, size);
+}
+
+std::shared_ptr<Array> RandomArrayGenerator::ArrayOf(const Field& field, int64_t length) {
+#define VALIDATE_RANGE(PARAM, MIN, MAX) \
+ if (PARAM < MIN || PARAM > MAX) { \
+ ABORT_NOT_OK(Status::Invalid(field.ToString(), ": ", ARROW_STRINGIFY(PARAM), \
+ " must be in [", MIN, ", ", MAX, " ] but got ", \
+ PARAM)); \
+ }
+#define VALIDATE_MIN_MAX(MIN, MAX) \
+ if (MIN > MAX) { \
+ ABORT_NOT_OK( \
+ Status::Invalid(field.ToString(), ": min ", MIN, " must be <= max ", MAX)); \
+ }
+#define GENERATE_INTEGRAL_CASE_VIEW(BASE_TYPE, VIEW_TYPE) \
+ case VIEW_TYPE::type_id: { \
+ const BASE_TYPE::c_type min_value = GetMetadata<BASE_TYPE::c_type>( \
+ field.metadata().get(), "min", std::numeric_limits<BASE_TYPE::c_type>::min()); \
+ const BASE_TYPE::c_type max_value = GetMetadata<BASE_TYPE::c_type>( \
+ field.metadata().get(), "max", std::numeric_limits<BASE_TYPE::c_type>::max()); \
+ VALIDATE_MIN_MAX(min_value, max_value); \
+ return *Numeric<BASE_TYPE>(length, min_value, max_value, null_probability) \
+ ->View(field.type()); \
+ }
+#define GENERATE_INTEGRAL_CASE(ARROW_TYPE) \
+ GENERATE_INTEGRAL_CASE_VIEW(ARROW_TYPE, ARROW_TYPE)
+#define GENERATE_FLOATING_CASE(ARROW_TYPE, GENERATOR_FUNC) \
+ case ARROW_TYPE::type_id: { \
+ const ARROW_TYPE::c_type min_value = GetMetadata<ARROW_TYPE::c_type>( \
+ field.metadata().get(), "min", std::numeric_limits<ARROW_TYPE::c_type>::min()); \
+ const ARROW_TYPE::c_type max_value = GetMetadata<ARROW_TYPE::c_type>( \
+ field.metadata().get(), "max", std::numeric_limits<ARROW_TYPE::c_type>::max()); \
+ const double nan_probability = \
+ GetMetadata<double>(field.metadata().get(), "nan_probability", 0); \
+ VALIDATE_MIN_MAX(min_value, max_value); \
+ VALIDATE_RANGE(nan_probability, 0.0, 1.0); \
+ return GENERATOR_FUNC(length, min_value, max_value, null_probability, \
+ nan_probability); \
+ }
+
+ // Don't use compute::Sum since that may not get built
+#define GENERATE_LIST_CASE(ARRAY_TYPE) \
+ case ARRAY_TYPE::TypeClass::type_id: { \
+ const auto min_length = GetMetadata<ARRAY_TYPE::TypeClass::offset_type>( \
+ field.metadata().get(), "min_length", 0); \
+ const auto max_length = GetMetadata<ARRAY_TYPE::TypeClass::offset_type>( \
+ field.metadata().get(), "max_length", 20); \
+ const auto lengths = internal::checked_pointer_cast< \
+ CTypeTraits<ARRAY_TYPE::TypeClass::offset_type>::ArrayType>( \
+ Numeric<CTypeTraits<ARRAY_TYPE::TypeClass::offset_type>::ArrowType>( \
+ length, min_length, max_length, null_probability)); \
+ int64_t values_length = 0; \
+ for (const auto& length : *lengths) { \
+ if (length.has_value()) values_length += *length; \
+ } \
+ const auto force_empty_nulls = \
+ GetMetadata<bool>(field.metadata().get(), "force_empty_nulls", false); \
+ const auto values = \
+ ArrayOf(*internal::checked_pointer_cast<ARRAY_TYPE::TypeClass>(field.type()) \
+ ->value_field(), \
+ values_length); \
+ const auto offsets = OffsetsFromLengthsArray(lengths.get(), force_empty_nulls); \
+ return *ARRAY_TYPE::FromArrays(*offsets, *values); \
+ }
+
+ const double null_probability =
+ field.nullable()
+ ? GetMetadata<double>(field.metadata().get(), "null_probability", 0.01)
+ : 0.0;
+ VALIDATE_RANGE(null_probability, 0.0, 1.0);
+ switch (field.type()->id()) {
+ case Type::type::NA: {
+ return std::make_shared<NullArray>(length);
+ }
+
+ case Type::type::BOOL: {
+ const double true_probability =
+ GetMetadata<double>(field.metadata().get(), "true_probability", 0.5);
+ return Boolean(length, true_probability, null_probability);
+ }
+
+ GENERATE_INTEGRAL_CASE(UInt8Type);
+ GENERATE_INTEGRAL_CASE(Int8Type);
+ GENERATE_INTEGRAL_CASE(UInt16Type);
+ GENERATE_INTEGRAL_CASE(Int16Type);
+ GENERATE_INTEGRAL_CASE(UInt32Type);
+ GENERATE_INTEGRAL_CASE(Int32Type);
+ GENERATE_INTEGRAL_CASE(UInt64Type);
+ GENERATE_INTEGRAL_CASE(Int64Type);
+ GENERATE_INTEGRAL_CASE_VIEW(Int16Type, HalfFloatType);
+ GENERATE_FLOATING_CASE(FloatType, Float32);
+ GENERATE_FLOATING_CASE(DoubleType, Float64);
+
+ case Type::type::STRING:
+ case Type::type::BINARY: {
+ const auto min_length =
+ GetMetadata<int32_t>(field.metadata().get(), "min_length", 0);
+ const auto max_length =
+ GetMetadata<int32_t>(field.metadata().get(), "max_length", 20);
+ const auto unique_values =
+ GetMetadata<int32_t>(field.metadata().get(), "unique", -1);
+ if (unique_values > 0) {
+ return *StringWithRepeats(length, unique_values, min_length, max_length,
+ null_probability)
+ ->View(field.type());
+ }
+ return *String(length, min_length, max_length, null_probability)
+ ->View(field.type());
+ }
+
+ case Type::type::DECIMAL128:
+ return Decimal128(field.type(), length, null_probability);
+
+ case Type::type::DECIMAL256:
+ return Decimal256(field.type(), length, null_probability);
+
+ case Type::type::FIXED_SIZE_BINARY: {
+ auto byte_width =
+ internal::checked_pointer_cast<FixedSizeBinaryType>(field.type())->byte_width();
+ return *FixedSizeBinary(length, byte_width, null_probability)->View(field.type());
+ }
+
+ GENERATE_INTEGRAL_CASE_VIEW(Int32Type, Date32Type);
+ GENERATE_INTEGRAL_CASE_VIEW(Int64Type, Date64Type);
+ GENERATE_INTEGRAL_CASE_VIEW(Int64Type, TimestampType);
+ GENERATE_INTEGRAL_CASE_VIEW(Int32Type, Time32Type);
+ GENERATE_INTEGRAL_CASE_VIEW(Int64Type, Time64Type);
+ GENERATE_INTEGRAL_CASE_VIEW(Int32Type, MonthIntervalType);
+
+ // This isn't as flexible as it could be, but the array-of-structs layout of this
+ // type means it's not a (useful) composition of other generators
+ GENERATE_INTEGRAL_CASE_VIEW(Int64Type, DayTimeIntervalType);
+ case Type::type::INTERVAL_MONTH_DAY_NANO: {
+ return *FixedSizeBinary(length, /*byte_width=*/16, null_probability)
+ ->View(month_day_nano_interval());
+ }
+
+ GENERATE_LIST_CASE(ListArray);
+
+ case Type::type::STRUCT: {
+ ArrayVector child_arrays(field.type()->num_fields());
+ std::vector<std::string> field_names;
+ for (int i = 0; i < field.type()->num_fields(); i++) {
+ const auto& child_field = field.type()->field(i);
+ child_arrays[i] = ArrayOf(*child_field, length);
+ field_names.push_back(child_field->name());
+ }
+ return *StructArray::Make(child_arrays, field_names,
+ NullBitmap(length, null_probability));
+ }
+
+ case Type::type::SPARSE_UNION:
+ case Type::type::DENSE_UNION: {
+ ArrayVector child_arrays(field.type()->num_fields());
+ for (int i = 0; i < field.type()->num_fields(); i++) {
+ const auto& child_field = field.type()->field(i);
+ child_arrays[i] = ArrayOf(*child_field, length);
+ }
+ auto array = field.type()->id() == Type::type::SPARSE_UNION
+ ? SparseUnion(child_arrays, length)
+ : DenseUnion(child_arrays, length);
+ return *array->View(field.type());
+ }
+
+ case Type::type::DICTIONARY: {
+ const auto values_length =
+ GetMetadata<int64_t>(field.metadata().get(), "values", 4);
+ auto dict_type = internal::checked_pointer_cast<DictionaryType>(field.type());
+ // TODO: no way to control generation of dictionary
+ auto values =
+ ArrayOf(*arrow::field("temporary", dict_type->value_type(), /*nullable=*/false),
+ values_length);
+ auto merged = field.metadata() ? field.metadata() : key_value_metadata({}, {});
+ if (merged->Contains("min"))
+ ABORT_NOT_OK(Status::Invalid(field.ToString(), ": cannot specify min"));
+ if (merged->Contains("max"))
+ ABORT_NOT_OK(Status::Invalid(field.ToString(), ": cannot specify max"));
+ merged = merged->Merge(*key_value_metadata(
+ {{"min", "0"}, {"max", std::to_string(values_length - 1)}}));
+ auto indices = ArrayOf(
+ *arrow::field("temporary", dict_type->index_type(), field.nullable(), merged),
+ length);
+ return *DictionaryArray::FromArrays(field.type(), indices, values);
+ }
+
+ case Type::type::MAP: {
+ const auto values_length = GetMetadata<int32_t>(field.metadata().get(), "values",
+ static_cast<int32_t>(length));
+ const auto force_empty_nulls =
+ GetMetadata<bool>(field.metadata().get(), "force_empty_nulls", false);
+ auto map_type = internal::checked_pointer_cast<MapType>(field.type());
+ auto keys = ArrayOf(*map_type->key_field(), values_length);
+ auto items = ArrayOf(*map_type->item_field(), values_length);
+ // need N + 1 offsets to have N values
+ auto offsets =
+ Offsets(length + 1, 0, values_length, null_probability, force_empty_nulls);
+ return *MapArray::FromArrays(map_type, offsets, keys, items);
+ }
+
+ case Type::type::EXTENSION:
+ // Could be supported by generating the storage type (though any extension
+ // invariants wouldn't be preserved)
+ break;
+
+ case Type::type::FIXED_SIZE_LIST: {
+ auto list_type = internal::checked_pointer_cast<FixedSizeListType>(field.type());
+ const int64_t values_length = list_type->list_size() * length;
+ auto values = ArrayOf(*list_type->value_field(), values_length);
+ auto null_bitmap = NullBitmap(length, null_probability);
+ return std::make_shared<FixedSizeListArray>(list_type, length, values, null_bitmap);
+ }
+
+ GENERATE_INTEGRAL_CASE_VIEW(Int64Type, DurationType);
+
+ case Type::type::LARGE_STRING:
+ case Type::type::LARGE_BINARY: {
+ const auto min_length =
+ GetMetadata<int32_t>(field.metadata().get(), "min_length", 0);
+ const auto max_length =
+ GetMetadata<int32_t>(field.metadata().get(), "max_length", 20);
+ const auto unique_values =
+ GetMetadata<int32_t>(field.metadata().get(), "unique", -1);
+ if (unique_values > 0) {
+ ABORT_NOT_OK(
+ Status::NotImplemented("Generating random array with repeated values for "
+ "large string/large binary types"));
+ }
+ return *LargeString(length, min_length, max_length, null_probability)
+ ->View(field.type());
+ }
+
+ GENERATE_LIST_CASE(LargeListArray);
+
+ default:
+ break;
+ }
+#undef GENERATE_INTEGRAL_CASE_VIEW
+#undef GENERATE_INTEGRAL_CASE
+#undef GENERATE_FLOATING_CASE
+#undef GENERATE_LIST_CASE
+#undef VALIDATE_RANGE
+#undef VALIDATE_MIN_MAX
+
+ ABORT_NOT_OK(
+ Status::NotImplemented("Generating random array for field ", field.ToString()));
+ return nullptr;
+}
+
+std::shared_ptr<arrow::RecordBatch> RandomArrayGenerator::BatchOf(
+ const FieldVector& fields, int64_t length) {
+ std::vector<std::shared_ptr<Array>> arrays(fields.size());
+ for (size_t i = 0; i < fields.size(); i++) {
+ const auto& field = fields[i];
+ arrays[i] = ArrayOf(*field, length);
+ }
+ return RecordBatch::Make(schema(fields), length, std::move(arrays));
+}
+
+std::shared_ptr<arrow::Array> GenerateArray(const Field& field, int64_t length,
+ SeedType seed) {
+ return RandomArrayGenerator(seed).ArrayOf(field, length);
+}
+
+std::shared_ptr<arrow::RecordBatch> GenerateBatch(const FieldVector& fields,
+ int64_t length, SeedType seed) {
+ return RandomArrayGenerator(seed).BatchOf(fields, length);
+}
+} // namespace random
+
+void rand_day_millis(int64_t N, std::vector<DayTimeIntervalType::DayMilliseconds>* out) {
+ const int random_seed = 0;
+ arrow::random::pcg32_fast gen(random_seed);
+ std::uniform_int_distribution<int32_t> d(std::numeric_limits<int32_t>::min(),
+ std::numeric_limits<int32_t>::max());
+ out->resize(N, {});
+ std::generate(out->begin(), out->end(), [&d, &gen] {
+ DayTimeIntervalType::DayMilliseconds tmp;
+ tmp.days = d(gen);
+ tmp.milliseconds = d(gen);
+ return tmp;
+ });
+}
+
+void rand_month_day_nanos(int64_t N,
+ std::vector<MonthDayNanoIntervalType::MonthDayNanos>* out) {
+ const int random_seed = 0;
+ arrow::random::pcg32_fast gen(random_seed);
+ std::uniform_int_distribution<int64_t> d(std::numeric_limits<int64_t>::min(),
+ std::numeric_limits<int64_t>::max());
+ out->resize(N, {});
+ std::generate(out->begin(), out->end(), [&d, &gen] {
+ MonthDayNanoIntervalType::MonthDayNanos tmp;
+ tmp.months = static_cast<int32_t>(d(gen));
+ tmp.days = static_cast<int32_t>(d(gen));
+ tmp.nanoseconds = d(gen);
+ return tmp;
+ });
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/random.h b/src/arrow/cpp/src/arrow/testing/random.h
new file mode 100644
index 000000000..c77ae2525
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/random.h
@@ -0,0 +1,489 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <random>
+#include <vector>
+
+#include "arrow/testing/uniform_real.h"
+#include "arrow/testing/visibility.h"
+#include "arrow/type.h"
+
+namespace arrow {
+
+class Array;
+
+namespace random {
+
+using SeedType = int32_t;
+constexpr SeedType kSeedMax = std::numeric_limits<SeedType>::max();
+
+class ARROW_TESTING_EXPORT RandomArrayGenerator {
+ public:
+ explicit RandomArrayGenerator(SeedType seed)
+ : seed_distribution_(static_cast<SeedType>(1), kSeedMax), seed_rng_(seed) {}
+
+ /// \brief Generate a null bitmap
+ ///
+ /// \param[in] size the size of the bitmap to generate
+ /// \param[in] null_probability the probability of a bit being zero
+ ///
+ /// \return a generated Buffer
+ std::shared_ptr<Buffer> NullBitmap(int64_t size, double null_probability = 0);
+
+ /// \brief Generate a random BooleanArray
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] true_probability the probability of a value being 1 / bit-set
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> Boolean(int64_t size, double true_probability,
+ double null_probability = 0);
+
+ /// \brief Generate a random UInt8Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] min the lower bound of the uniform distribution
+ /// \param[in] max the upper bound of the uniform distribution
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> UInt8(int64_t size, uint8_t min, uint8_t max,
+ double null_probability = 0);
+
+ /// \brief Generate a random Int8Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] min the lower bound of the uniform distribution
+ /// \param[in] max the upper bound of the uniform distribution
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> Int8(int64_t size, int8_t min, int8_t max,
+ double null_probability = 0);
+
+ /// \brief Generate a random UInt16Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] min the lower bound of the uniform distribution
+ /// \param[in] max the upper bound of the uniform distribution
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> UInt16(int64_t size, uint16_t min, uint16_t max,
+ double null_probability = 0);
+
+ /// \brief Generate a random Int16Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] min the lower bound of the uniform distribution
+ /// \param[in] max the upper bound of the uniform distribution
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> Int16(int64_t size, int16_t min, int16_t max,
+ double null_probability = 0);
+
+ /// \brief Generate a random UInt32Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] min the lower bound of the uniform distribution
+ /// \param[in] max the upper bound of the uniform distribution
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> UInt32(int64_t size, uint32_t min, uint32_t max,
+ double null_probability = 0);
+
+ /// \brief Generate a random Int32Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] min the lower bound of the uniform distribution
+ /// \param[in] max the upper bound of the uniform distribution
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> Int32(int64_t size, int32_t min, int32_t max,
+ double null_probability = 0);
+
+ /// \brief Generate a random UInt64Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] min the lower bound of the uniform distribution
+ /// \param[in] max the upper bound of the uniform distribution
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> UInt64(int64_t size, uint64_t min, uint64_t max,
+ double null_probability = 0);
+
+ /// \brief Generate a random Int64Array
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] min the lower bound of the uniform distribution
+ /// \param[in] max the upper bound of the uniform distribution
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> Int64(int64_t size, int64_t min, int64_t max,
+ double null_probability = 0);
+
+ /// \brief Generate a random HalfFloatArray
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] min the lower bound of the distribution
+ /// \param[in] max the upper bound of the distribution
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> Float16(int64_t size, int16_t min, int16_t max,
+ double null_probability = 0);
+
+ /// \brief Generate a random FloatArray
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] min the lower bound of the uniform distribution
+ /// \param[in] max the upper bound of the uniform distribution
+ /// \param[in] null_probability the probability of a value being null
+ /// \param[in] nan_probability the probability of a value being NaN
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> Float32(int64_t size, float min, float max,
+ double null_probability = 0, double nan_probability = 0);
+
+ /// \brief Generate a random DoubleArray
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] min the lower bound of the uniform distribution
+ /// \param[in] max the upper bound of the uniform distribution
+ /// \param[in] null_probability the probability of a value being null
+ /// \param[in] nan_probability the probability of a value being NaN
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> Float64(int64_t size, double min, double max,
+ double null_probability = 0, double nan_probability = 0);
+
+ template <typename ArrowType, typename CType = typename ArrowType::c_type>
+ std::shared_ptr<Array> Numeric(int64_t size, CType min, CType max,
+ double null_probability = 0) {
+ switch (ArrowType::type_id) {
+ case Type::UINT8:
+ return UInt8(size, static_cast<uint8_t>(min), static_cast<uint8_t>(max),
+ null_probability);
+ case Type::INT8:
+ return Int8(size, static_cast<int8_t>(min), static_cast<int8_t>(max),
+ null_probability);
+ case Type::UINT16:
+ return UInt16(size, static_cast<uint16_t>(min), static_cast<uint16_t>(max),
+ null_probability);
+ case Type::INT16:
+ return Int16(size, static_cast<int16_t>(min), static_cast<int16_t>(max),
+ null_probability);
+ case Type::UINT32:
+ return UInt32(size, static_cast<uint32_t>(min), static_cast<uint32_t>(max),
+ null_probability);
+ case Type::INT32:
+ return Int32(size, static_cast<int32_t>(min), static_cast<int32_t>(max),
+ null_probability);
+ case Type::UINT64:
+ return UInt64(size, static_cast<uint64_t>(min), static_cast<uint64_t>(max),
+ null_probability);
+ case Type::INT64:
+ return Int64(size, static_cast<int64_t>(min), static_cast<int64_t>(max),
+ null_probability);
+ case Type::HALF_FLOAT:
+ return Float16(size, static_cast<int16_t>(min), static_cast<int16_t>(max),
+ null_probability);
+ case Type::FLOAT:
+ return Float32(size, static_cast<float>(min), static_cast<float>(max),
+ null_probability);
+ case Type::DOUBLE:
+ return Float64(size, static_cast<double>(min), static_cast<double>(max),
+ null_probability);
+ default:
+ return nullptr;
+ }
+ }
+
+ /// \brief Generate a random Decimal128Array
+ ///
+ /// \param[in] type the type of the array to generate
+ /// (must be an instance of Decimal128Type)
+ /// \param[in] size the size of the array to generate
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> Decimal128(std::shared_ptr<DataType> type, int64_t size,
+ double null_probability = 0);
+
+ /// \brief Generate a random Decimal256Array
+ ///
+ /// \param[in] type the type of the array to generate
+ /// (must be an instance of Decimal256Type)
+ /// \param[in] size the size of the array to generate
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> Decimal256(std::shared_ptr<DataType> type, int64_t size,
+ double null_probability = 0);
+
+ /// \brief Generate an array of offsets (for use in e.g. ListArray::FromArrays)
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] first_offset the first offset value (usually 0)
+ /// \param[in] last_offset the last offset value (usually the size of the child array)
+ /// \param[in] null_probability the probability of an offset being null
+ /// \param[in] force_empty_nulls if true, null offsets must have 0 "length"
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> Offsets(int64_t size, int32_t first_offset, int32_t last_offset,
+ double null_probability = 0,
+ bool force_empty_nulls = false);
+
+ std::shared_ptr<Array> LargeOffsets(int64_t size, int64_t first_offset,
+ int64_t last_offset, double null_probability = 0,
+ bool force_empty_nulls = false);
+
+ /// \brief Generate a random StringArray
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] min_length the lower bound of the string length
+ /// determined by the uniform distribution
+ /// \param[in] max_length the upper bound of the string length
+ /// determined by the uniform distribution
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> String(int64_t size, int32_t min_length, int32_t max_length,
+ double null_probability = 0);
+
+ /// \brief Generate a random LargeStringArray
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] min_length the lower bound of the string length
+ /// determined by the uniform distribution
+ /// \param[in] max_length the upper bound of the string length
+ /// determined by the uniform distribution
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> LargeString(int64_t size, int32_t min_length, int32_t max_length,
+ double null_probability = 0);
+
+ /// \brief Generate a random StringArray with repeated values
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] unique the number of unique string values used
+ /// to populate the array
+ /// \param[in] min_length the lower bound of the string length
+ /// determined by the uniform distribution
+ /// \param[in] max_length the upper bound of the string length
+ /// determined by the uniform distribution
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> StringWithRepeats(int64_t size, int64_t unique,
+ int32_t min_length, int32_t max_length,
+ double null_probability = 0);
+
+ /// \brief Like StringWithRepeats but return BinaryArray
+ std::shared_ptr<Array> BinaryWithRepeats(int64_t size, int64_t unique,
+ int32_t min_length, int32_t max_length,
+ double null_probability = 0);
+
+ /// \brief Generate a random FixedSizeBinaryArray
+ ///
+ /// \param[in] size the size of the array to generate
+ /// \param[in] byte_width the byte width of fixed-size binary items
+ /// \param[in] null_probability the probability of a value being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> FixedSizeBinary(int64_t size, int32_t byte_width,
+ double null_probability = 0);
+
+ /// \brief Generate a random ListArray
+ ///
+ /// \param[in] values The underlying values array
+ /// \param[in] size The size of the generated list array
+ /// \param[in] null_probability the probability of a list value being null
+ /// \param[in] force_empty_nulls if true, null list entries must have 0 length
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> List(const Array& values, int64_t size, double null_probability,
+ bool force_empty_nulls = false);
+
+ /// \brief Generate a random MapArray
+ ///
+ /// \param[in] keys The underlying keys array
+ /// \param[in] items The underlying items array
+ /// \param[in] size The size of the generated map array
+ /// \param[in] null_probability the probability of a map value being null
+ /// \param[in] force_empty_nulls if true, null map entries must have 0 length
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> Map(const std::shared_ptr<Array>& keys,
+ const std::shared_ptr<Array>& items, int64_t size,
+ double null_probability, bool force_empty_nulls = false);
+
+ /// \brief Generate a random SparseUnionArray
+ ///
+ /// The type ids are chosen randomly, according to a uniform distribution,
+ /// amongst the given child fields.
+ ///
+ /// \param[in] fields Vector of Arrays containing the data for each union field
+ /// \param[in] size The size of the generated sparse union array
+ std::shared_ptr<Array> SparseUnion(const ArrayVector& fields, int64_t size);
+
+ /// \brief Generate a random DenseUnionArray
+ ///
+ /// The type ids are chosen randomly, according to a uniform distribution,
+ /// amongst the given child fields. The offsets are incremented along
+ /// each child field.
+ ///
+ /// \param[in] fields Vector of Arrays containing the data for each union field
+ /// \param[in] size The size of the generated sparse union array
+ std::shared_ptr<Array> DenseUnion(const ArrayVector& fields, int64_t size);
+
+ /// \brief Generate a random Array of the specified type, size, and null_probability.
+ ///
+ /// Generation parameters other than size and null_probability are determined based on
+ /// the type of Array to be generated.
+ /// If boolean the probabilities of true,false values are 0.25,0.75 respectively.
+ /// If numeric min,max will be the least and greatest representable values.
+ /// If string min_length,max_length will be 0,sqrt(size) respectively.
+ ///
+ /// \param[in] type the type of Array to generate
+ /// \param[in] size the size of the Array to generate
+ /// \param[in] null_probability the probability of a slot being null
+ /// \return a generated Array
+ std::shared_ptr<Array> ArrayOf(std::shared_ptr<DataType> type, int64_t size,
+ double null_probability);
+
+ /// \brief Generate an array with random data based on the given field. See BatchOf
+ /// for usage info.
+ std::shared_ptr<Array> ArrayOf(const Field& field, int64_t size);
+
+ /// \brief Generate a record batch with random data of the specified length.
+ ///
+ /// Generation options are read from key-value metadata for each field, and may be
+ /// specified at any nesting level. For example, generation options for the child
+ /// values of a list array can be specified by constructing the list type with
+ /// list(field("item", int8(), options_metadata))
+ ///
+ /// The following options are supported:
+ ///
+ /// For all types except NullType:
+ /// - null_probability (double): range [0.0, 1.0] the probability of a null value.
+ /// Default/value is 0.0 if the field is marked non-nullable, else it is 0.01
+ ///
+ /// For all numeric types T:
+ /// - min (T::c_type): the minimum value to generate (inclusive), default
+ /// std::numeric_limits<T::c_type>::min()
+ /// - max (T::c_type): the maximum value to generate (inclusive), default
+ /// std::numeric_limits<T::c_type>::max()
+ /// Note this means that, for example, min/max are int16_t values for HalfFloatType.
+ ///
+ /// For floating point types T for which is_physical_floating_type<T>:
+ /// - nan_probability (double): range [0.0, 1.0] the probability of a NaN value.
+ ///
+ /// For BooleanType:
+ /// - true_probability (double): range [0.0, 1.0] the probability of a true.
+ ///
+ /// For DictionaryType:
+ /// - values (int32_t): the size of the dictionary.
+ /// Other properties are passed to the generator for the dictionary indices. However,
+ /// min and max cannot be specified. Note it is not possible to otherwise customize
+ /// the generation of dictionary values.
+ ///
+ /// For list, string, and binary types T, including their large variants:
+ /// - min_length (T::offset_type): the minimum length of the child to generate,
+ /// default 0
+ /// - max_length (T::offset_type): the minimum length of the child to generate,
+ /// default 1024
+ ///
+ /// For string and binary types T (not including their large variants):
+ /// - unique (int32_t): if positive, this many distinct values will be generated
+ /// and all array values will be one of these values, default -1
+ ///
+ /// For MapType:
+ /// - values (int32_t): the number of key-value pairs to generate, which will be
+ /// partitioned among the array values.
+ std::shared_ptr<arrow::RecordBatch> BatchOf(const FieldVector& fields, int64_t size);
+
+ SeedType seed() { return seed_distribution_(seed_rng_); }
+
+ private:
+ std::uniform_int_distribution<SeedType> seed_distribution_;
+ std::default_random_engine seed_rng_;
+};
+
+/// Generate an array with random data. See RandomArrayGenerator::BatchOf.
+ARROW_TESTING_EXPORT
+std::shared_ptr<arrow::RecordBatch> GenerateBatch(const FieldVector& fields, int64_t size,
+ SeedType seed);
+
+/// Generate an array with random data. See RandomArrayGenerator::BatchOf.
+ARROW_TESTING_EXPORT
+std::shared_ptr<arrow::Array> GenerateArray(const Field& field, int64_t size,
+ SeedType seed);
+
+} // namespace random
+
+//
+// Assorted functions
+//
+
+ARROW_TESTING_EXPORT
+void rand_day_millis(int64_t N, std::vector<DayTimeIntervalType::DayMilliseconds>* out);
+ARROW_TESTING_EXPORT
+void rand_month_day_nanos(int64_t N,
+ std::vector<MonthDayNanoIntervalType::MonthDayNanos>* out);
+
+template <typename T, typename U>
+void randint(int64_t N, T lower, T upper, std::vector<U>* out) {
+ const int random_seed = 0;
+ std::default_random_engine gen(random_seed);
+ std::uniform_int_distribution<T> d(lower, upper);
+ out->resize(N, static_cast<T>(0));
+ std::generate(out->begin(), out->end(), [&d, &gen] { return static_cast<U>(d(gen)); });
+}
+
+template <typename T, typename U>
+void random_real(int64_t n, uint32_t seed, T min_value, T max_value,
+ std::vector<U>* out) {
+ std::default_random_engine gen(seed);
+ ::arrow::random::uniform_real_distribution<T> d(min_value, max_value);
+ out->resize(n, static_cast<T>(0));
+ std::generate(out->begin(), out->end(), [&d, &gen] { return static_cast<U>(d(gen)); });
+}
+
+template <typename T, typename U>
+void rand_uniform_int(int64_t n, uint32_t seed, T min_value, T max_value, U* out) {
+ assert(out || (n == 0));
+ std::default_random_engine gen(seed);
+ std::uniform_int_distribution<T> d(min_value, max_value);
+ std::generate(out, out + n, [&d, &gen] { return static_cast<U>(d(gen)); });
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/random_test.cc b/src/arrow/cpp/src/arrow/testing/random_test.cc
new file mode 100644
index 000000000..002c6c9b7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/random_test.cc
@@ -0,0 +1,513 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_decimal.h"
+#include "arrow/record_batch.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/pcg_random.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace random {
+
+// Use short arrays since especially in debug mode, generating list(list()) is slow
+constexpr int64_t kExpectedLength = 24;
+
+class RandomArrayTest : public ::testing::TestWithParam<std::shared_ptr<Field>> {
+ protected:
+ std::shared_ptr<Field> GetField() { return GetParam(); }
+};
+
+TEST_P(RandomArrayTest, GenerateArray) {
+ auto field = GetField();
+ auto array = GenerateArray(*field, kExpectedLength, 0xDEADBEEF);
+ AssertTypeEqual(field->type(), array->type());
+ ASSERT_EQ(kExpectedLength, array->length());
+ ASSERT_OK(array->ValidateFull());
+}
+
+TEST_P(RandomArrayTest, GenerateBatch) {
+ auto field = GetField();
+ auto batch = GenerateBatch({field}, kExpectedLength, 0xDEADBEEF);
+ AssertSchemaEqual(schema({field}), batch->schema());
+ auto array = batch->column(0);
+ ASSERT_EQ(kExpectedLength, array->length());
+ ASSERT_OK(array->ValidateFull());
+}
+
+TEST_P(RandomArrayTest, GenerateZeroLengthArray) {
+ auto field = GetField();
+ if (field->type()->id() == Type::type::DENSE_UNION) {
+ GTEST_SKIP() << "Cannot generate zero-length dense union arrays";
+ }
+ auto array = GenerateArray(*field, 0, 0xDEADBEEF);
+ AssertTypeEqual(field->type(), array->type());
+ ASSERT_EQ(0, array->length());
+ ASSERT_OK(array->ValidateFull());
+}
+
+TEST_P(RandomArrayTest, GenerateArrayWithZeroNullProbability) {
+ auto field =
+ GetField()->WithMetadata(key_value_metadata({{"null_probability", "0.0"}}));
+ if (field->type()->id() == Type::type::NA) {
+ GTEST_SKIP() << "Cannot generate non-null null arrays";
+ }
+ auto batch = GenerateBatch({field}, kExpectedLength, 0xDEADBEEF);
+ AssertSchemaEqual(schema({field}), batch->schema());
+ auto array = batch->column(0);
+ ASSERT_OK(array->ValidateFull());
+ ASSERT_EQ(0, array->null_count());
+}
+
+TEST_P(RandomArrayTest, GenerateNonNullableArray) {
+ auto field = GetField()->WithNullable(false);
+ if (field->type()->id() == Type::type::NA) {
+ GTEST_SKIP() << "Cannot generate non-null null arrays";
+ }
+ auto batch = GenerateBatch({field}, kExpectedLength, 0xDEADBEEF);
+ AssertSchemaEqual(schema({field}), batch->schema());
+ auto array = batch->column(0);
+ ASSERT_OK(array->ValidateFull());
+ ASSERT_EQ(0, array->null_count());
+}
+
+auto values = ::testing::Values(
+ field("null", null()), field("bool", boolean()), field("uint8", uint8()),
+ field("int8", int8()), field("uint16", uint16()), field("int16", int16()),
+ field("uint32", uint32()), field("int32", int32()), field("uint64", uint64()),
+ field("int64", int64()), field("float16", float16()), field("float32", float32()),
+ field("float64", float64()), field("string", utf8()), field("binary", binary()),
+ field("fixed_size_binary", fixed_size_binary(8)),
+ field("decimal128", decimal128(8, 3)), field("decimal128", decimal128(29, -5)),
+ field("decimal256", decimal256(16, 4)), field("decimal256", decimal256(57, -6)),
+ field("date32", date32()), field("date64", date64()),
+ field("timestampns", timestamp(TimeUnit::NANO)),
+ field("timestamps", timestamp(TimeUnit::SECOND, "America/Phoenix")),
+ field("time32ms", time32(TimeUnit::MILLI)), field("time64ns", time64(TimeUnit::NANO)),
+ field("time32s", time32(TimeUnit::SECOND)),
+ field("time64us", time64(TimeUnit::MICRO)), field("month_interval", month_interval()),
+ field("daytime_interval", day_time_interval()),
+ field("month_day_nano_interval", month_day_nano_interval()),
+ field("listint8", list(int8())), field("listlistint8", list(list(int8()))),
+ field("listint8emptynulls", list(int8()), true,
+ key_value_metadata({{"force_empty_nulls", "true"}})),
+ field("listint81024values", list(int8()), true,
+ key_value_metadata({{"values", "1024"}})),
+ field("structints", struct_({
+ field("int8", int8()),
+ field("int16", int16()),
+ field("int32", int32()),
+ })),
+ field("structnested", struct_({
+ field("string", utf8()),
+ field("list", list(int64())),
+ field("timestamp", timestamp(TimeUnit::MILLI)),
+ })),
+ field("sparseunion", sparse_union({
+ field("int8", int8()),
+ field("int16", int16()),
+ field("int32", int32()),
+ })),
+ field("denseunion", dense_union({
+ field("int8", int8()),
+ field("int16", int16()),
+ field("int32", int32()),
+ })),
+ field("dictionary", dictionary(int8(), utf8())), field("map", map(int8(), utf8())),
+ field("fixedsizelist", fixed_size_list(int8(), 4)),
+ field("durationns", duration(TimeUnit::NANO)), field("largestring", large_utf8()),
+ field("largebinary", large_binary()),
+ field("largelistlistint8", large_list(list(int8()))));
+
+INSTANTIATE_TEST_SUITE_P(
+ TestRandomArrayGeneration, RandomArrayTest, values,
+ [](const ::testing::TestParamInfo<RandomArrayTest::ParamType>& info) {
+ return std::to_string(info.index) + info.param->name();
+ });
+
+template <typename T>
+class RandomNumericArrayTest : public ::testing::Test {
+ protected:
+ std::shared_ptr<Field> GetField() { return field("field0", std::make_shared<T>()); }
+
+ std::shared_ptr<NumericArray<T>> Downcast(std::shared_ptr<Array> array) {
+ return internal::checked_pointer_cast<NumericArray<T>>(array);
+ }
+};
+
+using NumericTypes =
+ ::testing::Types<UInt8Type, Int8Type, UInt16Type, Int16Type, UInt32Type, Int32Type,
+ HalfFloatType, FloatType, DoubleType>;
+TYPED_TEST_SUITE(RandomNumericArrayTest, NumericTypes);
+
+TYPED_TEST(RandomNumericArrayTest, GenerateMinMax) {
+ auto field = this->GetField()->WithMetadata(
+ key_value_metadata({{"min", "0"}, {"max", "127"}, {"nan_probability", "0.0"}}));
+ auto batch = GenerateBatch({field}, kExpectedLength, 0xDEADBEEF);
+ ASSERT_OK(batch->ValidateFull());
+ AssertSchemaEqual(schema({field}), batch->schema());
+ auto array = this->Downcast(batch->column(0));
+ for (auto slot : *array) {
+ if (!slot.has_value()) continue;
+ ASSERT_GE(slot, typename TypeParam::c_type(0));
+ ASSERT_LE(slot, typename TypeParam::c_type(127));
+ }
+}
+
+TYPED_TEST(RandomNumericArrayTest, EmptyRange) {
+ auto field =
+ this->GetField()->WithMetadata(key_value_metadata({{"min", "42"}, {"max", "42"}}));
+ auto batch = GenerateBatch({field}, kExpectedLength, 0xcafe);
+ ASSERT_OK(batch->ValidateFull());
+ AssertSchemaEqual(schema({field}), batch->schema());
+ auto array = this->Downcast(batch->column(0));
+ for (auto slot : *array) {
+ if (!slot.has_value()) continue;
+ ASSERT_EQ(slot, typename TypeParam::c_type(42));
+ }
+}
+
+template <typename DecimalType>
+class RandomDecimalArrayTest : public ::testing::Test {
+ protected:
+ using ArrayType = typename TypeTraits<DecimalType>::ArrayType;
+ using DecimalValue = typename TypeTraits<DecimalType>::BuilderType::ValueType;
+
+ constexpr static int32_t max_precision() { return DecimalType::kMaxPrecision; }
+
+ std::shared_ptr<DataType> type(int32_t precision, int32_t scale) {
+ return std::make_shared<DecimalType>(precision, scale);
+ }
+
+ void CheckArray(const Array& array) {
+ ASSERT_OK(array.ValidateFull());
+
+ const auto& type = checked_cast<const DecimalType&>(*array.type());
+ const auto& values = checked_cast<const ArrayType&>(array);
+
+ const DecimalValue limit = DecimalValue::GetScaleMultiplier(type.precision());
+ const DecimalValue neg_limit = DecimalValue(limit).Negate();
+ const DecimalValue half_limit = limit / DecimalValue(2);
+ const DecimalValue neg_half_limit = DecimalValue(half_limit).Negate();
+
+ // Check that random-generated values:
+ // - satisfy the requested precision
+ // - at least sometimes are close to the max allowable values for precision
+ // - sometimes are negative
+ int64_t non_nulls = 0;
+ int64_t over_half = 0;
+ int64_t negative = 0;
+
+ for (int64_t i = 0; i < values.length(); ++i) {
+ if (values.IsNull(i)) {
+ continue;
+ }
+ ++non_nulls;
+ const DecimalValue value(values.GetValue(i));
+ ASSERT_LT(value, limit);
+ ASSERT_GT(value, neg_limit);
+ if (value >= half_limit || value <= neg_half_limit) {
+ ++over_half;
+ }
+ if (value.Sign() < 0) {
+ ++negative;
+ }
+ }
+
+ ASSERT_GE(over_half, non_nulls * 0.3);
+ ASSERT_LE(over_half, non_nulls * 0.7);
+ ASSERT_GE(negative, non_nulls * 0.3);
+ ASSERT_LE(negative, non_nulls * 0.7);
+ }
+};
+
+using DecimalTypes = ::testing::Types<Decimal128Type, Decimal256Type>;
+TYPED_TEST_SUITE(RandomDecimalArrayTest, DecimalTypes);
+
+TYPED_TEST(RandomDecimalArrayTest, Basic) {
+ random::RandomArrayGenerator rng(42);
+
+ for (const int32_t precision :
+ {1, 2, 5, 9, 18, 19, 25, this->max_precision() - 1, this->max_precision()}) {
+ ARROW_SCOPED_TRACE("precision = ", precision);
+ const auto type = this->type(precision, 5);
+ auto array = rng.ArrayOf(type, /*size=*/1000, /*null_probability=*/0.2);
+ this->CheckArray(*array);
+ }
+}
+
+// Test all the supported options
+TEST(TypeSpecificTests, BoolTrueProbability) {
+ auto field =
+ arrow::field("bool", boolean(), key_value_metadata({{"true_probability", "1.0"}}));
+ auto base_array = GenerateArray(*field, kExpectedLength, 0xDEADBEEF);
+ AssertTypeEqual(field->type(), base_array->type());
+ auto array = internal::checked_pointer_cast<BooleanArray>(base_array);
+ ASSERT_OK(array->ValidateFull());
+ for (const auto& value : *array) {
+ ASSERT_TRUE(!value.has_value() || *value);
+ }
+}
+
+TEST(TypeSpecificTests, DictionaryValues) {
+ auto field = arrow::field("dictionary", dictionary(int8(), utf8()),
+ key_value_metadata({{"values", "16"}}));
+ auto base_array = GenerateArray(*field, kExpectedLength, 0xDEADBEEF);
+ AssertTypeEqual(field->type(), base_array->type());
+ auto array = internal::checked_pointer_cast<DictionaryArray>(base_array);
+ ASSERT_OK(array->ValidateFull());
+ ASSERT_EQ(16, array->dictionary()->length());
+}
+
+TEST(TypeSpecificTests, Float32Nan) {
+ auto field = arrow::field("float32", float32(),
+ key_value_metadata({{"nan_probability", "1.0"}}));
+ auto base_array = GenerateArray(*field, kExpectedLength, 0xDEADBEEF);
+ AssertTypeEqual(field->type(), base_array->type());
+ auto array = internal::checked_pointer_cast<NumericArray<FloatType>>(base_array);
+ ASSERT_OK(array->ValidateFull());
+ for (const auto& value : *array) {
+ ASSERT_TRUE(!value.has_value() || std::isnan(*value));
+ }
+}
+
+TEST(TypeSpecificTests, Float64Nan) {
+ auto field = arrow::field("float64", float64(),
+ key_value_metadata({{"nan_probability", "1.0"}}));
+ auto base_array = GenerateArray(*field, kExpectedLength, 0xDEADBEEF);
+ AssertTypeEqual(field->type(), base_array->type());
+ auto array = internal::checked_pointer_cast<NumericArray<DoubleType>>(base_array);
+ ASSERT_OK(array->ValidateFull());
+ for (const auto& value : *array) {
+ ASSERT_TRUE(!value.has_value() || std::isnan(*value));
+ }
+}
+
+TEST(TypeSpecificTests, ListLengths) {
+ {
+ auto field =
+ arrow::field("list", list(int8()),
+ key_value_metadata({{"min_length", "1"}, {"max_length", "1"}}));
+ auto base_array = GenerateArray(*field, kExpectedLength, 0xDEADBEEF);
+ AssertTypeEqual(field->type(), base_array->type());
+ auto array = internal::checked_pointer_cast<ListArray>(base_array);
+ ASSERT_OK(array->ValidateFull());
+ ASSERT_EQ(array->length(), kExpectedLength);
+ for (int i = 0; i < kExpectedLength; i++) {
+ if (!array->IsNull(i)) {
+ ASSERT_EQ(1, array->value_length(i));
+ }
+ }
+ }
+ {
+ auto field =
+ arrow::field("list", large_list(int8()),
+ key_value_metadata({{"min_length", "10"}, {"max_length", "10"}}));
+ auto base_array = GenerateArray(*field, kExpectedLength, 0xDEADBEEF);
+ AssertTypeEqual(field->type(), base_array->type());
+ auto array = internal::checked_pointer_cast<LargeListArray>(base_array);
+ ASSERT_EQ(array->length(), kExpectedLength);
+ ASSERT_OK(array->ValidateFull());
+ for (int i = 0; i < kExpectedLength; i++) {
+ if (!array->IsNull(i)) {
+ ASSERT_EQ(10, array->value_length(i));
+ }
+ }
+ }
+}
+
+TEST(TypeSpecificTests, MapValues) {
+ auto field =
+ arrow::field("map", map(int8(), int8()), key_value_metadata({{"values", "4"}}));
+ auto base_array = GenerateArray(*field, kExpectedLength, 0xDEADBEEF);
+ AssertTypeEqual(field->type(), base_array->type());
+ auto array = internal::checked_pointer_cast<MapArray>(base_array);
+ ASSERT_OK(array->ValidateFull());
+ ASSERT_EQ(4, array->keys()->length());
+ ASSERT_EQ(4, array->items()->length());
+}
+
+TEST(TypeSpecificTests, RepeatedStrings) {
+ auto field = arrow::field("string", utf8(), key_value_metadata({{"unique", "1"}}));
+ auto base_array = GenerateArray(*field, kExpectedLength, 0xDEADBEEF);
+ AssertTypeEqual(field->type(), base_array->type());
+ auto array = internal::checked_pointer_cast<StringArray>(base_array);
+ ASSERT_OK(array->ValidateFull());
+ util::string_view singular_value = array->GetView(0);
+ for (auto slot : *array) {
+ if (!slot.has_value()) continue;
+ ASSERT_EQ(slot, singular_value);
+ }
+ // N.B. LargeString does not support unique
+}
+
+TEST(TypeSpecificTests, StringLengths) {
+ {
+ auto field = arrow::field(
+ "list", utf8(), key_value_metadata({{"min_length", "1"}, {"max_length", "1"}}));
+ auto base_array = GenerateArray(*field, kExpectedLength, 0xDEADBEEF);
+ AssertTypeEqual(field->type(), base_array->type());
+ auto array = internal::checked_pointer_cast<StringArray>(base_array);
+ ASSERT_OK(array->ValidateFull());
+ for (int i = 0; i < kExpectedLength; i++) {
+ if (!array->IsNull(i)) {
+ ASSERT_EQ(1, array->value_length(i));
+ }
+ }
+ }
+ {
+ auto field = arrow::field(
+ "list", binary(), key_value_metadata({{"min_length", "1"}, {"max_length", "1"}}));
+ auto base_array = GenerateArray(*field, kExpectedLength, 0xDEADBEEF);
+ AssertTypeEqual(field->type(), base_array->type());
+ auto array = internal::checked_pointer_cast<BinaryArray>(base_array);
+ ASSERT_OK(array->ValidateFull());
+ for (int i = 0; i < kExpectedLength; i++) {
+ if (!array->IsNull(i)) {
+ ASSERT_EQ(1, array->value_length(i));
+ }
+ }
+ }
+ {
+ auto field =
+ arrow::field("list", large_utf8(),
+ key_value_metadata({{"min_length", "10"}, {"max_length", "10"}}));
+ auto base_array = GenerateArray(*field, kExpectedLength, 0xDEADBEEF);
+ AssertTypeEqual(field->type(), base_array->type());
+ auto array = internal::checked_pointer_cast<LargeStringArray>(base_array);
+ ASSERT_OK(array->ValidateFull());
+ for (int i = 0; i < kExpectedLength; i++) {
+ if (!array->IsNull(i)) {
+ ASSERT_EQ(10, array->value_length(i));
+ }
+ }
+ }
+ {
+ auto field =
+ arrow::field("list", large_binary(),
+ key_value_metadata({{"min_length", "10"}, {"max_length", "10"}}));
+ auto base_array = GenerateArray(*field, kExpectedLength, 0xDEADBEEF);
+ AssertTypeEqual(field->type(), base_array->type());
+ auto array = internal::checked_pointer_cast<LargeBinaryArray>(base_array);
+ ASSERT_OK(array->ValidateFull());
+ for (int i = 0; i < kExpectedLength; i++) {
+ if (!array->IsNull(i)) {
+ ASSERT_EQ(10, array->value_length(i));
+ }
+ }
+ }
+}
+
+TEST(RandomList, Basics) {
+ random::RandomArrayGenerator rng(42);
+ for (const double null_probability : {0.0, 0.1, 0.98}) {
+ SCOPED_TRACE("null_probability = " + std::to_string(null_probability));
+ auto values = rng.Int16(1234, 0, 10000, null_probability);
+ auto array = rng.List(*values, 45, null_probability);
+ ASSERT_OK(array->ValidateFull());
+ ASSERT_EQ(array->length(), 45);
+ const auto& list_array = checked_cast<const ListArray&>(*array);
+ ASSERT_EQ(list_array.values()->length(), 1234);
+ int64_t null_count = 0;
+ for (int64_t i = 0; i < array->length(); ++i) {
+ null_count += array->IsNull(i);
+ }
+ ASSERT_EQ(null_count, array->data()->null_count);
+ }
+}
+
+template <typename T>
+class UniformRealTest : public ::testing::Test {
+ protected:
+ void VerifyDist(int seed, T a, T b) {
+ pcg32_fast rng(seed);
+ ::arrow::random::uniform_real_distribution<T> dist(a, b);
+
+ const int kCount = 5000;
+ T min = std::numeric_limits<T>::max();
+ T max = std::numeric_limits<T>::lowest();
+ double sum = 0;
+ double square_sum = 0;
+ for (int i = 0; i < kCount; ++i) {
+ const T v = dist(rng);
+ min = std::min(min, v);
+ max = std::max(max, v);
+ sum += v;
+ square_sum += static_cast<double>(v) * v;
+ }
+
+ ASSERT_GE(min, a);
+ ASSERT_LT(max, b);
+
+ // verify E(X), E(X^2) is near theory
+ const double E_X = (a + b) / 2.0;
+ const double E_X2 = 1.0 / 12 * (a - b) * (a - b) + E_X * E_X;
+ ASSERT_NEAR(sum / kCount, E_X, std::abs(E_X) * 0.02);
+ ASSERT_NEAR(square_sum / kCount, E_X2, E_X2 * 0.02);
+ }
+};
+
+using RealCTypes = ::testing::Types<float, double>;
+TYPED_TEST_SUITE(UniformRealTest, RealCTypes);
+
+TYPED_TEST(UniformRealTest, Basic) {
+ int seed = 42;
+ this->VerifyDist(seed++, 0, 1);
+ this->VerifyDist(seed++, -3, 1);
+ this->VerifyDist(seed++, -123456, 654321);
+}
+
+TEST(BernoulliTest, Basic) {
+ int seed = 42;
+
+ // count #trues (values less than p), p = 0 ~ 1
+ auto count = [&seed](double p, int total) {
+ pcg32_fast rng(seed++);
+ ::arrow::random::bernoulli_distribution dist(p);
+ int cnt = 0;
+ for (int i = 0; i < total; ++i) {
+ cnt += dist(rng);
+ }
+ return cnt;
+ };
+
+ ASSERT_EQ(count(0, 1000), 0);
+ ASSERT_EQ(count(1, 1000), 1000);
+
+ // verify #trues is near p*total
+ auto verify = [&count](double p, int total, double dev) {
+ const int cnt = count(p, total);
+ const int min = std::max(0, static_cast<int>(total * p * (1 - dev)));
+ const int max = std::min(total, static_cast<int>(total * p * (1 + dev)));
+ ASSERT_TRUE(cnt >= min && cnt <= max);
+ };
+
+ for (double p = 0.1; p < 0.95; p += 0.1) {
+ verify(p, 5000, 0.1);
+ }
+}
+
+} // namespace random
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/uniform_real.h b/src/arrow/cpp/src/arrow/testing/uniform_real.h
new file mode 100644
index 000000000..155cb16b6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/uniform_real.h
@@ -0,0 +1,84 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Random real generation is very slow on Arm if built with clang + libstdc++
+// due to software emulated long double arithmetic.
+// This file ports some random real libs from llvm libc++ library, which are
+// free from long double calculation.
+// It improves performance significantly on both Arm (~100x) and x86 (~8x) in
+// generating random reals when built with clang + gnu libstdc++.
+// Based on: https://github.com/llvm/llvm-project/tree/main/libcxx
+
+#pragma once
+
+#include <limits>
+
+#include <arrow/util/bit_util.h>
+
+namespace arrow {
+namespace random {
+
+namespace detail {
+
+// std::generate_canonical, simplified
+// https://en.cppreference.com/w/cpp/numeric/random/generate_canonical
+template <typename RealType, typename Rng>
+RealType generate_canonical(Rng& rng) {
+ const size_t b = std::numeric_limits<RealType>::digits;
+ const size_t log2R = 63 - ::arrow::BitUtil::CountLeadingZeros(
+ static_cast<uint64_t>(Rng::max() - Rng::min()) + 1);
+ const size_t k = b / log2R + (b % log2R != 0) + (b == 0);
+ const RealType r = static_cast<RealType>(Rng::max() - Rng::min()) + 1;
+ RealType base = r;
+ RealType sp = static_cast<RealType>(rng() - Rng::min());
+ for (size_t i = 1; i < k; ++i, base *= r) {
+ sp += (rng() - Rng::min()) * base;
+ }
+ return sp / base;
+}
+
+} // namespace detail
+
+// std::uniform_real_distribution, simplified
+// https://en.cppreference.com/w/cpp/numeric/random/uniform_real_distribution
+template <typename RealType = double>
+struct uniform_real_distribution {
+ const RealType a, b;
+
+ explicit uniform_real_distribution(RealType a = 0, RealType b = 1) : a(a), b(b) {}
+
+ template <typename Rng>
+ RealType operator()(Rng& rng) {
+ return (b - a) * detail::generate_canonical<RealType>(rng) + a;
+ }
+};
+
+// std::bernoulli_distribution, simplified
+// https://en.cppreference.com/w/cpp/numeric/random/bernoulli_distribution
+struct bernoulli_distribution {
+ const double p;
+
+ explicit bernoulli_distribution(double p = 0.5) : p(p) {}
+
+ template <class Rng>
+ bool operator()(Rng& rng) {
+ return detail::generate_canonical<double>(rng) < p;
+ }
+};
+
+} // namespace random
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/util.cc b/src/arrow/cpp/src/arrow/testing/util.cc
new file mode 100644
index 000000000..9e3e27174
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/util.cc
@@ -0,0 +1,188 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/testing/util.h"
+
+#include <chrono>
+#include <cstring>
+#include <random>
+
+#ifdef _WIN32
+// clang-format off
+// (prevent include reordering)
+#include "arrow/util/windows_compatibility.h"
+#include <winsock2.h>
+// clang-format on
+#else
+#include <arpa/inet.h> // IWYU pragma: keep
+#include <netinet/in.h> // IWYU pragma: keep
+#include <sys/socket.h> // IWYU pragma: keep
+#include <sys/stat.h> // IWYU pragma: keep
+#include <sys/types.h> // IWYU pragma: keep
+#include <sys/wait.h> // IWYU pragma: keep
+#include <unistd.h> // IWYU pragma: keep
+#endif
+
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/pcg_random.h"
+
+namespace arrow {
+
+using random::pcg32_fast;
+
+uint64_t random_seed() {
+ return std::chrono::high_resolution_clock::now().time_since_epoch().count();
+}
+
+void random_null_bytes(int64_t n, double pct_null, uint8_t* null_bytes) {
+ const int random_seed = 0;
+ pcg32_fast gen(random_seed);
+ ::arrow::random::uniform_real_distribution<double> d(0.0, 1.0);
+ std::generate(null_bytes, null_bytes + n,
+ [&d, &gen, &pct_null] { return d(gen) > pct_null; });
+}
+
+void random_is_valid(int64_t n, double pct_null, std::vector<bool>* is_valid,
+ int random_seed) {
+ pcg32_fast gen(random_seed);
+ ::arrow::random::uniform_real_distribution<double> d(0.0, 1.0);
+ is_valid->resize(n, false);
+ std::generate(is_valid->begin(), is_valid->end(),
+ [&d, &gen, &pct_null] { return d(gen) > pct_null; });
+}
+
+void random_bytes(int64_t n, uint32_t seed, uint8_t* out) {
+ pcg32_fast gen(seed);
+ std::uniform_int_distribution<uint32_t> d(0, std::numeric_limits<uint8_t>::max());
+ std::generate(out, out + n, [&d, &gen] { return static_cast<uint8_t>(d(gen)); });
+}
+
+std::string random_string(int64_t n, uint32_t seed) {
+ std::string s;
+ s.resize(static_cast<size_t>(n));
+ random_bytes(n, seed, reinterpret_cast<uint8_t*>(&s[0]));
+ return s;
+}
+
+void random_decimals(int64_t n, uint32_t seed, int32_t precision, uint8_t* out) {
+ pcg32_fast gen(seed);
+ std::uniform_int_distribution<uint32_t> d(0, std::numeric_limits<uint8_t>::max());
+ const int32_t required_bytes = DecimalType::DecimalSize(precision);
+ constexpr int32_t byte_width = 16;
+ std::fill(out, out + byte_width * n, '\0');
+
+ for (int64_t i = 0; i < n; ++i, out += byte_width) {
+ std::generate(out, out + required_bytes,
+ [&d, &gen] { return static_cast<uint8_t>(d(gen)); });
+
+ // sign extend if the sign bit is set for the last byte generated
+ // 0b10000000 == 0x80 == 128
+ if ((out[required_bytes - 1] & '\x80') != 0) {
+ std::fill(out + required_bytes, out + byte_width, '\xFF');
+ }
+ }
+}
+
+void random_ascii(int64_t n, uint32_t seed, uint8_t* out) {
+ rand_uniform_int(n, seed, static_cast<int32_t>('A'), static_cast<int32_t>('z'), out);
+}
+
+int64_t CountNulls(const std::vector<uint8_t>& valid_bytes) {
+ return static_cast<int64_t>(std::count(valid_bytes.cbegin(), valid_bytes.cend(), '\0'));
+}
+
+Status MakeRandomByteBuffer(int64_t length, MemoryPool* pool,
+ std::shared_ptr<ResizableBuffer>* out, uint32_t seed) {
+ ARROW_ASSIGN_OR_RAISE(auto result, AllocateResizableBuffer(length, pool));
+ random_bytes(length, seed, result->mutable_data());
+ *out = std::move(result);
+ return Status::OK();
+}
+
+Status GetTestResourceRoot(std::string* out) {
+ const char* c_root = std::getenv("ARROW_TEST_DATA");
+ if (!c_root) {
+ return Status::IOError(
+ "Test resources not found, set ARROW_TEST_DATA to <repo root>/testing/data");
+ }
+ *out = std::string(c_root);
+ return Status::OK();
+}
+
+int GetListenPort() {
+ // Get a new available port number by binding a socket to an ephemeral port
+ // and then closing it. Since ephemeral port allocation tends to avoid
+ // reusing port numbers, this should give a different port number
+ // every time, even across processes.
+ struct sockaddr_in sin;
+#ifdef _WIN32
+ SOCKET sock_fd;
+ auto sin_len = static_cast<int>(sizeof(sin));
+ auto errno_message = []() -> std::string {
+ return internal::WinErrorMessage(WSAGetLastError());
+ };
+#else
+#define INVALID_SOCKET -1
+#define SOCKET_ERROR -1
+ int sock_fd;
+ auto sin_len = static_cast<socklen_t>(sizeof(sin));
+ auto errno_message = []() -> std::string { return internal::ErrnoMessage(errno); };
+#endif
+
+#ifdef _WIN32
+ WSADATA wsa_data;
+ if (WSAStartup(0x0202, &wsa_data) != 0) {
+ ARROW_LOG(FATAL) << "Failed to initialize Windows Sockets";
+ }
+#endif
+
+ sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ if (sock_fd == INVALID_SOCKET) {
+ Status::IOError("Failed to create TCP socket: ", errno_message()).Abort();
+ }
+ // First bind to ('0.0.0.0', 0)
+ memset(&sin, 0, sizeof(sin));
+ sin.sin_family = AF_INET;
+ if (bind(sock_fd, reinterpret_cast<struct sockaddr*>(&sin), sin_len) == SOCKET_ERROR) {
+ Status::IOError("bind() failed: ", errno_message()).Abort();
+ }
+ // Then get actual bound port number
+ if (getsockname(sock_fd, reinterpret_cast<struct sockaddr*>(&sin), &sin_len) ==
+ SOCKET_ERROR) {
+ Status::IOError("getsockname() failed: ", errno_message()).Abort();
+ }
+ int port = ntohs(sin.sin_port);
+#ifdef _WIN32
+ closesocket(sock_fd);
+#else
+ close(sock_fd);
+#endif
+
+ return port;
+}
+
+const std::vector<std::shared_ptr<DataType>>& all_dictionary_index_types() {
+ static std::vector<std::shared_ptr<DataType>> types = {
+ int8(), uint8(), int16(), uint16(), int32(), uint32(), int64(), uint64()};
+ return types;
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/util.h b/src/arrow/cpp/src/arrow/testing/util.h
new file mode 100644
index 000000000..05fb8c68e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/util.h
@@ -0,0 +1,190 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/builder_primitive.h"
+#include "arrow/record_batch.h"
+#include "arrow/status.h"
+#include "arrow/testing/visibility.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+template <typename T>
+Status CopyBufferFromVector(const std::vector<T>& values, MemoryPool* pool,
+ std::shared_ptr<Buffer>* result) {
+ int64_t nbytes = static_cast<int>(values.size()) * sizeof(T);
+
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(nbytes, pool));
+ auto immutable_data = reinterpret_cast<const uint8_t*>(values.data());
+ std::copy(immutable_data, immutable_data + nbytes, buffer->mutable_data());
+ memset(buffer->mutable_data() + nbytes, 0,
+ static_cast<size_t>(buffer->capacity() - nbytes));
+
+ *result = std::move(buffer);
+ return Status::OK();
+}
+
+// Sets approximately pct_null of the first n bytes in null_bytes to zero
+// and the rest to non-zero (true) values.
+ARROW_TESTING_EXPORT void random_null_bytes(int64_t n, double pct_null,
+ uint8_t* null_bytes);
+ARROW_TESTING_EXPORT void random_is_valid(int64_t n, double pct_null,
+ std::vector<bool>* is_valid,
+ int random_seed = 0);
+ARROW_TESTING_EXPORT void random_bytes(int64_t n, uint32_t seed, uint8_t* out);
+ARROW_TESTING_EXPORT std::string random_string(int64_t n, uint32_t seed);
+ARROW_TESTING_EXPORT int32_t DecimalSize(int32_t precision);
+ARROW_TESTING_EXPORT void random_decimals(int64_t n, uint32_t seed, int32_t precision,
+ uint8_t* out);
+ARROW_TESTING_EXPORT void random_ascii(int64_t n, uint32_t seed, uint8_t* out);
+ARROW_TESTING_EXPORT int64_t CountNulls(const std::vector<uint8_t>& valid_bytes);
+
+ARROW_TESTING_EXPORT Status MakeRandomByteBuffer(int64_t length, MemoryPool* pool,
+ std::shared_ptr<ResizableBuffer>* out,
+ uint32_t seed = 0);
+
+ARROW_TESTING_EXPORT uint64_t random_seed();
+
+template <class T, class Builder>
+Status MakeArray(const std::vector<uint8_t>& valid_bytes, const std::vector<T>& values,
+ int64_t size, Builder* builder, std::shared_ptr<Array>* out) {
+ // Append the first 1000
+ for (int64_t i = 0; i < size; ++i) {
+ if (valid_bytes[i] > 0) {
+ RETURN_NOT_OK(builder->Append(values[i]));
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ }
+ }
+ return builder->Finish(out);
+}
+
+#define DECL_T() typedef typename TestFixture::T T;
+
+#define DECL_TYPE() typedef typename TestFixture::Type Type;
+
+// ----------------------------------------------------------------------
+// A RecordBatchReader for serving a sequence of in-memory record batches
+
+class BatchIterator : public RecordBatchReader {
+ public:
+ BatchIterator(const std::shared_ptr<Schema>& schema,
+ const std::vector<std::shared_ptr<RecordBatch>>& batches)
+ : schema_(schema), batches_(batches), position_(0) {}
+
+ std::shared_ptr<Schema> schema() const override { return schema_; }
+
+ Status ReadNext(std::shared_ptr<RecordBatch>* out) override {
+ if (position_ >= batches_.size()) {
+ *out = nullptr;
+ } else {
+ *out = batches_[position_++];
+ }
+ return Status::OK();
+ }
+
+ private:
+ std::shared_ptr<Schema> schema_;
+ std::vector<std::shared_ptr<RecordBatch>> batches_;
+ size_t position_;
+};
+
+template <typename Fn>
+struct VisitBuilderImpl {
+ template <typename T, typename BuilderType = typename TypeTraits<T>::BuilderType,
+ // need to let SFINAE drop this Visit when it would result in
+ // [](NullBuilder*){}(double_builder)
+ typename = decltype(std::declval<Fn>()(std::declval<BuilderType*>()))>
+ Status Visit(const T&) {
+ fn_(internal::checked_cast<BuilderType*>(builder_));
+ return Status::OK();
+ }
+
+ Status Visit(const DataType& t) {
+ return Status::NotImplemented("visiting builders of type ", t);
+ }
+
+ Status Visit() { return VisitTypeInline(*builder_->type(), this); }
+
+ ArrayBuilder* builder_;
+ Fn fn_;
+};
+
+template <typename Fn>
+Status VisitBuilder(ArrayBuilder* builder, Fn&& fn) {
+ return VisitBuilderImpl<Fn>{builder, std::forward<Fn>(fn)}.Visit();
+}
+
+template <typename Fn>
+Result<std::shared_ptr<Array>> ArrayFromBuilderVisitor(
+ const std::shared_ptr<DataType>& type, int64_t initial_capacity,
+ int64_t visitor_repetitions, Fn&& fn) {
+ std::unique_ptr<ArrayBuilder> builder;
+ RETURN_NOT_OK(MakeBuilder(default_memory_pool(), type, &builder));
+
+ if (initial_capacity != 0) {
+ RETURN_NOT_OK(builder->Resize(initial_capacity));
+ }
+
+ for (int64_t i = 0; i < visitor_repetitions; ++i) {
+ RETURN_NOT_OK(VisitBuilder(builder.get(), std::forward<Fn>(fn)));
+ }
+
+ std::shared_ptr<Array> out;
+ RETURN_NOT_OK(builder->Finish(&out));
+ return std::move(out);
+}
+
+template <typename Fn>
+Result<std::shared_ptr<Array>> ArrayFromBuilderVisitor(
+ const std::shared_ptr<DataType>& type, int64_t length, Fn&& fn) {
+ return ArrayFromBuilderVisitor(type, length, length, std::forward<Fn>(fn));
+}
+
+static inline std::vector<std::shared_ptr<DataType> (*)(FieldVector, std::vector<int8_t>)>
+UnionTypeFactories() {
+ return {sparse_union, dense_union};
+}
+
+// Return the value of the ARROW_TEST_DATA environment variable or return error
+// Status
+ARROW_TESTING_EXPORT Status GetTestResourceRoot(std::string*);
+
+// Get a TCP port number to listen on. This is a different number every time,
+// as reusing the same port across tests can produce spurious bind errors on
+// Windows.
+ARROW_TESTING_EXPORT int GetListenPort();
+
+ARROW_TESTING_EXPORT
+const std::vector<std::shared_ptr<DataType>>& all_dictionary_index_types();
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/testing/visibility.h b/src/arrow/cpp/src/arrow/testing/visibility.h
new file mode 100644
index 000000000..1b2aa7cd8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/testing/visibility.h
@@ -0,0 +1,48 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#if defined(_WIN32) || defined(__CYGWIN__)
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4251)
+#else
+#pragma GCC diagnostic ignored "-Wattributes"
+#endif
+
+#ifdef ARROW_TESTING_STATIC
+#define ARROW_TESTING_EXPORT
+#elif defined(ARROW_TESTING_EXPORTING)
+#define ARROW_TESTING_EXPORT __declspec(dllexport)
+#else
+#define ARROW_TESTING_EXPORT __declspec(dllimport)
+#endif
+
+#define ARROW_TESTING_NO_EXPORT
+#else // Not Windows
+#ifndef ARROW_TESTING_EXPORT
+#define ARROW_TESTING_EXPORT __attribute__((visibility("default")))
+#endif
+#ifndef ARROW_TESTING_NO_EXPORT
+#define ARROW_TESTING_NO_EXPORT __attribute__((visibility("hidden")))
+#endif
+#endif // Non-Windows
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
diff --git a/src/arrow/cpp/src/arrow/type.cc b/src/arrow/cpp/src/arrow/type.cc
new file mode 100644
index 000000000..ab5e15ed7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/type.cc
@@ -0,0 +1,2428 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/type.h"
+
+#include <algorithm>
+#include <climits>
+#include <cstddef>
+#include <limits>
+#include <mutex>
+#include <ostream>
+#include <sstream> // IWYU pragma: keep
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/compare.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/hash_util.h"
+#include "arrow/util/hashing.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/range.h"
+#include "arrow/util/vector.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+
+constexpr Type::type NullType::type_id;
+constexpr Type::type ListType::type_id;
+constexpr Type::type LargeListType::type_id;
+
+constexpr Type::type MapType::type_id;
+
+constexpr Type::type FixedSizeListType::type_id;
+
+constexpr Type::type BinaryType::type_id;
+
+constexpr Type::type LargeBinaryType::type_id;
+
+constexpr Type::type StringType::type_id;
+
+constexpr Type::type LargeStringType::type_id;
+
+constexpr Type::type FixedSizeBinaryType::type_id;
+
+constexpr Type::type StructType::type_id;
+
+constexpr Type::type Decimal128Type::type_id;
+
+constexpr Type::type Decimal256Type::type_id;
+
+constexpr Type::type SparseUnionType::type_id;
+
+constexpr Type::type DenseUnionType::type_id;
+
+constexpr Type::type Date32Type::type_id;
+
+constexpr Type::type Date64Type::type_id;
+
+constexpr Type::type Time32Type::type_id;
+
+constexpr Type::type Time64Type::type_id;
+
+constexpr Type::type TimestampType::type_id;
+
+constexpr Type::type MonthIntervalType::type_id;
+
+constexpr Type::type DayTimeIntervalType::type_id;
+
+constexpr Type::type MonthDayNanoIntervalType::type_id;
+
+constexpr Type::type DurationType::type_id;
+
+constexpr Type::type DictionaryType::type_id;
+
+namespace internal {
+
+struct TypeIdToTypeNameVisitor {
+ std::string out;
+
+ template <typename ArrowType>
+ Status Visit(const ArrowType*) {
+ out = ArrowType::type_name();
+ return Status::OK();
+ }
+};
+
+std::string ToTypeName(Type::type id) {
+ TypeIdToTypeNameVisitor visitor;
+
+ ARROW_CHECK_OK(VisitTypeIdInline(id, &visitor));
+ return std::move(visitor.out);
+}
+
+std::string ToString(Type::type id) {
+ switch (id) {
+#define TO_STRING_CASE(_id) \
+ case Type::_id: \
+ return ARROW_STRINGIFY(_id);
+
+ TO_STRING_CASE(NA)
+ TO_STRING_CASE(BOOL)
+ TO_STRING_CASE(INT8)
+ TO_STRING_CASE(INT16)
+ TO_STRING_CASE(INT32)
+ TO_STRING_CASE(INT64)
+ TO_STRING_CASE(UINT8)
+ TO_STRING_CASE(UINT16)
+ TO_STRING_CASE(UINT32)
+ TO_STRING_CASE(UINT64)
+ TO_STRING_CASE(HALF_FLOAT)
+ TO_STRING_CASE(FLOAT)
+ TO_STRING_CASE(DOUBLE)
+ TO_STRING_CASE(DECIMAL128)
+ TO_STRING_CASE(DECIMAL256)
+ TO_STRING_CASE(DATE32)
+ TO_STRING_CASE(DATE64)
+ TO_STRING_CASE(TIME32)
+ TO_STRING_CASE(TIME64)
+ TO_STRING_CASE(TIMESTAMP)
+ TO_STRING_CASE(INTERVAL_DAY_TIME)
+ TO_STRING_CASE(INTERVAL_MONTH_DAY_NANO)
+ TO_STRING_CASE(INTERVAL_MONTHS)
+ TO_STRING_CASE(DURATION)
+ TO_STRING_CASE(STRING)
+ TO_STRING_CASE(BINARY)
+ TO_STRING_CASE(LARGE_STRING)
+ TO_STRING_CASE(LARGE_BINARY)
+ TO_STRING_CASE(FIXED_SIZE_BINARY)
+ TO_STRING_CASE(STRUCT)
+ TO_STRING_CASE(LIST)
+ TO_STRING_CASE(LARGE_LIST)
+ TO_STRING_CASE(FIXED_SIZE_LIST)
+ TO_STRING_CASE(MAP)
+ TO_STRING_CASE(DENSE_UNION)
+ TO_STRING_CASE(SPARSE_UNION)
+ TO_STRING_CASE(DICTIONARY)
+ TO_STRING_CASE(EXTENSION)
+
+#undef TO_STRING_CASE
+
+ default:
+ ARROW_LOG(FATAL) << "Unhandled type id: " << id;
+ return "";
+ }
+}
+
+std::string ToString(TimeUnit::type unit) {
+ switch (unit) {
+ case TimeUnit::SECOND:
+ return "s";
+ case TimeUnit::MILLI:
+ return "ms";
+ case TimeUnit::MICRO:
+ return "us";
+ case TimeUnit::NANO:
+ return "ns";
+ default:
+ DCHECK(false);
+ return "";
+ }
+}
+
+int GetByteWidth(const DataType& type) {
+ const auto& fw_type = checked_cast<const FixedWidthType&>(type);
+ return fw_type.bit_width() / CHAR_BIT;
+}
+
+} // namespace internal
+
+namespace {
+
+struct PhysicalTypeVisitor {
+ const std::shared_ptr<DataType>& real_type;
+ std::shared_ptr<DataType> result;
+
+ Status Visit(const DataType&) {
+ result = real_type;
+ return Status::OK();
+ }
+
+ template <typename Type, typename PhysicalType = typename Type::PhysicalType>
+ Status Visit(const Type&) {
+ result = TypeTraits<PhysicalType>::type_singleton();
+ return Status::OK();
+ }
+};
+
+} // namespace
+
+std::shared_ptr<DataType> GetPhysicalType(const std::shared_ptr<DataType>& real_type) {
+ PhysicalTypeVisitor visitor{real_type, {}};
+ ARROW_CHECK_OK(VisitTypeInline(*real_type, &visitor));
+ return std::move(visitor.result);
+}
+
+namespace {
+
+using internal::checked_cast;
+
+// Merges `existing` and `other` if one of them is of NullType, otherwise
+// returns nullptr.
+// - if `other` if of NullType or is nullable, the unified field will be nullable.
+// - if `existing` is of NullType but other is not, the unified field will
+// have `other`'s type and will be nullable
+std::shared_ptr<Field> MaybePromoteNullTypes(const Field& existing, const Field& other) {
+ if (existing.type()->id() != Type::NA && other.type()->id() != Type::NA) {
+ return nullptr;
+ }
+ if (existing.type()->id() == Type::NA) {
+ return other.WithNullable(true)->WithMetadata(existing.metadata());
+ }
+ // `other` must be null.
+ return existing.WithNullable(true);
+}
+} // namespace
+
+Field::~Field() {}
+
+bool Field::HasMetadata() const {
+ return (metadata_ != nullptr) && (metadata_->size() > 0);
+}
+
+std::shared_ptr<Field> Field::WithMetadata(
+ const std::shared_ptr<const KeyValueMetadata>& metadata) const {
+ return std::make_shared<Field>(name_, type_, nullable_, metadata);
+}
+
+std::shared_ptr<Field> Field::WithMergedMetadata(
+ const std::shared_ptr<const KeyValueMetadata>& metadata) const {
+ std::shared_ptr<const KeyValueMetadata> merged_metadata;
+ if (metadata_) {
+ merged_metadata = metadata_->Merge(*metadata);
+ } else {
+ merged_metadata = metadata;
+ }
+ return std::make_shared<Field>(name_, type_, nullable_, merged_metadata);
+}
+
+std::shared_ptr<Field> Field::RemoveMetadata() const {
+ return std::make_shared<Field>(name_, type_, nullable_);
+}
+
+std::shared_ptr<Field> Field::WithType(const std::shared_ptr<DataType>& type) const {
+ return std::make_shared<Field>(name_, type, nullable_, metadata_);
+}
+
+std::shared_ptr<Field> Field::WithName(const std::string& name) const {
+ return std::make_shared<Field>(name, type_, nullable_, metadata_);
+}
+
+std::shared_ptr<Field> Field::WithNullable(const bool nullable) const {
+ return std::make_shared<Field>(name_, type_, nullable, metadata_);
+}
+
+Result<std::shared_ptr<Field>> Field::MergeWith(const Field& other,
+ MergeOptions options) const {
+ if (name() != other.name()) {
+ return Status::Invalid("Field ", name(), " doesn't have the same name as ",
+ other.name());
+ }
+
+ if (Equals(other, /*check_metadata=*/false)) {
+ return Copy();
+ }
+
+ if (options.promote_nullability) {
+ if (type()->Equals(other.type())) {
+ return Copy()->WithNullable(nullable() || other.nullable());
+ }
+ std::shared_ptr<Field> promoted = MaybePromoteNullTypes(*this, other);
+ if (promoted) return promoted;
+ }
+
+ return Status::Invalid("Unable to merge: Field ", name(),
+ " has incompatible types: ", type()->ToString(), " vs ",
+ other.type()->ToString());
+}
+
+Result<std::shared_ptr<Field>> Field::MergeWith(const std::shared_ptr<Field>& other,
+ MergeOptions options) const {
+ DCHECK_NE(other, nullptr);
+ return MergeWith(*other, options);
+}
+
+std::vector<std::shared_ptr<Field>> Field::Flatten() const {
+ std::vector<std::shared_ptr<Field>> flattened;
+ if (type_->id() == Type::STRUCT) {
+ for (const auto& child : type_->fields()) {
+ auto flattened_child = child->Copy();
+ flattened.push_back(flattened_child);
+ flattened_child->name_.insert(0, name() + ".");
+ flattened_child->nullable_ |= nullable_;
+ }
+ } else {
+ flattened.push_back(this->Copy());
+ }
+ return flattened;
+}
+
+std::shared_ptr<Field> Field::Copy() const {
+ return ::arrow::field(name_, type_, nullable_, metadata_);
+}
+
+bool Field::Equals(const Field& other, bool check_metadata) const {
+ if (this == &other) {
+ return true;
+ }
+ if (this->name_ == other.name_ && this->nullable_ == other.nullable_ &&
+ this->type_->Equals(*other.type_.get(), check_metadata)) {
+ if (!check_metadata) {
+ return true;
+ } else if (this->HasMetadata() && other.HasMetadata()) {
+ return metadata_->Equals(*other.metadata_);
+ } else if (!this->HasMetadata() && !other.HasMetadata()) {
+ return true;
+ } else {
+ return false;
+ }
+ }
+ return false;
+}
+
+bool Field::Equals(const std::shared_ptr<Field>& other, bool check_metadata) const {
+ return Equals(*other.get(), check_metadata);
+}
+
+bool Field::IsCompatibleWith(const Field& other) const { return MergeWith(other).ok(); }
+
+bool Field::IsCompatibleWith(const std::shared_ptr<Field>& other) const {
+ DCHECK_NE(other, nullptr);
+ return IsCompatibleWith(*other);
+}
+
+std::string Field::ToString(bool show_metadata) const {
+ std::stringstream ss;
+ ss << name_ << ": " << type_->ToString();
+ if (!nullable_) {
+ ss << " not null";
+ }
+ if (show_metadata && metadata_) {
+ ss << metadata_->ToString();
+ }
+ return ss.str();
+}
+
+DataType::~DataType() {}
+
+bool DataType::Equals(const DataType& other, bool check_metadata) const {
+ return TypeEquals(*this, other, check_metadata);
+}
+
+bool DataType::Equals(const std::shared_ptr<DataType>& other) const {
+ if (!other) {
+ return false;
+ }
+ return Equals(*other.get());
+}
+
+size_t DataType::Hash() const {
+ static constexpr size_t kHashSeed = 0;
+ size_t result = kHashSeed;
+ internal::hash_combine(result, this->fingerprint());
+ return result;
+}
+
+std::ostream& operator<<(std::ostream& os, const DataType& type) {
+ os << type.ToString();
+ return os;
+}
+
+FloatingPointType::Precision HalfFloatType::precision() const {
+ return FloatingPointType::HALF;
+}
+
+FloatingPointType::Precision FloatType::precision() const {
+ return FloatingPointType::SINGLE;
+}
+
+FloatingPointType::Precision DoubleType::precision() const {
+ return FloatingPointType::DOUBLE;
+}
+
+std::ostream& operator<<(std::ostream& os,
+ DayTimeIntervalType::DayMilliseconds interval) {
+ os << interval.days << "d" << interval.milliseconds << "ms";
+ return os;
+}
+
+std::ostream& operator<<(std::ostream& os,
+ MonthDayNanoIntervalType::MonthDayNanos interval) {
+ os << interval.months << "M" << interval.days << "d" << interval.nanoseconds << "ns";
+ return os;
+}
+
+std::string ListType::ToString() const {
+ std::stringstream s;
+ s << "list<" << value_field()->ToString() << ">";
+ return s.str();
+}
+
+std::string LargeListType::ToString() const {
+ std::stringstream s;
+ s << "large_list<" << value_field()->ToString() << ">";
+ return s.str();
+}
+
+MapType::MapType(std::shared_ptr<DataType> key_type, std::shared_ptr<DataType> item_type,
+ bool keys_sorted)
+ : MapType(::arrow::field("key", std::move(key_type), false),
+ ::arrow::field("value", std::move(item_type)), keys_sorted) {}
+
+MapType::MapType(std::shared_ptr<DataType> key_type, std::shared_ptr<Field> item_field,
+ bool keys_sorted)
+ : MapType(::arrow::field("key", std::move(key_type), false), std::move(item_field),
+ keys_sorted) {}
+
+MapType::MapType(std::shared_ptr<Field> key_field, std::shared_ptr<Field> item_field,
+ bool keys_sorted)
+ : MapType(
+ ::arrow::field("entries",
+ struct_({std::move(key_field), std::move(item_field)}), false),
+ keys_sorted) {}
+
+MapType::MapType(std::shared_ptr<Field> value_field, bool keys_sorted)
+ : ListType(std::move(value_field)), keys_sorted_(keys_sorted) {
+ id_ = type_id;
+}
+
+Result<std::shared_ptr<DataType>> MapType::Make(std::shared_ptr<Field> value_field,
+ bool keys_sorted) {
+ const auto& value_type = *value_field->type();
+ if (value_field->nullable() || value_type.id() != Type::STRUCT) {
+ return Status::TypeError("Map entry field should be non-nullable struct");
+ }
+ const auto& struct_type = checked_cast<const StructType&>(value_type);
+ if (struct_type.num_fields() != 2) {
+ return Status::TypeError("Map entry field should have two children (got ",
+ struct_type.num_fields(), ")");
+ }
+ if (struct_type.field(0)->nullable()) {
+ return Status::TypeError("Map key field should be non-nullable");
+ }
+ return std::make_shared<MapType>(std::move(value_field), keys_sorted);
+}
+
+std::string MapType::ToString() const {
+ std::stringstream s;
+
+ const auto print_field_name = [](std::ostream& os, const Field& field,
+ const char* std_name) {
+ if (field.name() != std_name) {
+ os << " ('" << field.name() << "')";
+ }
+ };
+ const auto print_field = [&](std::ostream& os, const Field& field,
+ const char* std_name) {
+ os << field.type()->ToString();
+ print_field_name(os, field, std_name);
+ };
+
+ s << "map<";
+ print_field(s, *key_field(), "key");
+ s << ", ";
+ print_field(s, *item_field(), "value");
+ if (keys_sorted_) {
+ s << ", keys_sorted";
+ }
+ print_field_name(s, *value_field(), "entries");
+ s << ">";
+ return s.str();
+}
+
+std::string FixedSizeListType::ToString() const {
+ std::stringstream s;
+ s << "fixed_size_list<" << value_field()->ToString() << ">[" << list_size_ << "]";
+ return s.str();
+}
+
+std::string BinaryType::ToString() const { return "binary"; }
+
+std::string LargeBinaryType::ToString() const { return "large_binary"; }
+
+std::string StringType::ToString() const { return "string"; }
+
+std::string LargeStringType::ToString() const { return "large_string"; }
+
+int FixedSizeBinaryType::bit_width() const { return CHAR_BIT * byte_width(); }
+
+Result<std::shared_ptr<DataType>> FixedSizeBinaryType::Make(int32_t byte_width) {
+ if (byte_width < 0) {
+ return Status::Invalid("Negative FixedSizeBinaryType byte width");
+ }
+ if (byte_width > std::numeric_limits<int>::max() / CHAR_BIT) {
+ // bit_width() would overflow
+ return Status::Invalid("byte width of FixedSizeBinaryType too large");
+ }
+ return std::make_shared<FixedSizeBinaryType>(byte_width);
+}
+
+std::string FixedSizeBinaryType::ToString() const {
+ std::stringstream ss;
+ ss << "fixed_size_binary[" << byte_width_ << "]";
+ return ss.str();
+}
+
+// ----------------------------------------------------------------------
+// Date types
+
+DateType::DateType(Type::type type_id) : TemporalType(type_id) {}
+
+Date32Type::Date32Type() : DateType(Type::DATE32) {}
+
+Date64Type::Date64Type() : DateType(Type::DATE64) {}
+
+std::string Date64Type::ToString() const { return std::string("date64[ms]"); }
+
+std::string Date32Type::ToString() const { return std::string("date32[day]"); }
+
+// ----------------------------------------------------------------------
+// Time types
+
+TimeType::TimeType(Type::type type_id, TimeUnit::type unit)
+ : TemporalType(type_id), unit_(unit) {}
+
+Time32Type::Time32Type(TimeUnit::type unit) : TimeType(Type::TIME32, unit) {
+ ARROW_CHECK(unit == TimeUnit::SECOND || unit == TimeUnit::MILLI)
+ << "Must be seconds or milliseconds";
+}
+
+std::string Time32Type::ToString() const {
+ std::stringstream ss;
+ ss << "time32[" << this->unit_ << "]";
+ return ss.str();
+}
+
+Time64Type::Time64Type(TimeUnit::type unit) : TimeType(Type::TIME64, unit) {
+ ARROW_CHECK(unit == TimeUnit::MICRO || unit == TimeUnit::NANO)
+ << "Must be microseconds or nanoseconds";
+}
+
+std::string Time64Type::ToString() const {
+ std::stringstream ss;
+ ss << "time64[" << this->unit_ << "]";
+ return ss.str();
+}
+
+std::ostream& operator<<(std::ostream& os, TimeUnit::type unit) {
+ switch (unit) {
+ case TimeUnit::SECOND:
+ os << "s";
+ break;
+ case TimeUnit::MILLI:
+ os << "ms";
+ break;
+ case TimeUnit::MICRO:
+ os << "us";
+ break;
+ case TimeUnit::NANO:
+ os << "ns";
+ break;
+ }
+ return os;
+}
+
+// ----------------------------------------------------------------------
+// Timestamp types
+
+std::string TimestampType::ToString() const {
+ std::stringstream ss;
+ ss << "timestamp[" << this->unit_;
+ if (this->timezone_.size() > 0) {
+ ss << ", tz=" << this->timezone_;
+ }
+ ss << "]";
+ return ss.str();
+}
+
+// Duration types
+std::string DurationType::ToString() const {
+ std::stringstream ss;
+ ss << "duration[" << this->unit_ << "]";
+ return ss.str();
+}
+
+// ----------------------------------------------------------------------
+// Union type
+
+constexpr int8_t UnionType::kMaxTypeCode;
+constexpr int UnionType::kInvalidChildId;
+
+UnionMode::type UnionType::mode() const {
+ return id_ == Type::SPARSE_UNION ? UnionMode::SPARSE : UnionMode::DENSE;
+}
+
+UnionType::UnionType(std::vector<std::shared_ptr<Field>> fields,
+ std::vector<int8_t> type_codes, Type::type id)
+ : NestedType(id),
+ type_codes_(std::move(type_codes)),
+ child_ids_(kMaxTypeCode + 1, kInvalidChildId) {
+ children_ = std::move(fields);
+ DCHECK_OK(ValidateParameters(children_, type_codes_, mode()));
+ for (int child_id = 0; child_id < static_cast<int>(type_codes_.size()); ++child_id) {
+ const auto type_code = type_codes_[child_id];
+ child_ids_[type_code] = child_id;
+ }
+}
+
+Status UnionType::ValidateParameters(const std::vector<std::shared_ptr<Field>>& fields,
+ const std::vector<int8_t>& type_codes,
+ UnionMode::type mode) {
+ if (fields.size() != type_codes.size()) {
+ return Status::Invalid("Union should get the same number of fields as type codes");
+ }
+ for (const auto type_code : type_codes) {
+ if (type_code < 0 || type_code > kMaxTypeCode) {
+ return Status::Invalid("Union type code out of bounds");
+ }
+ }
+ return Status::OK();
+}
+
+DataTypeLayout UnionType::layout() const {
+ if (mode() == UnionMode::SPARSE) {
+ return DataTypeLayout(
+ {DataTypeLayout::AlwaysNull(), DataTypeLayout::FixedWidth(sizeof(uint8_t))});
+ } else {
+ return DataTypeLayout({DataTypeLayout::AlwaysNull(),
+ DataTypeLayout::FixedWidth(sizeof(uint8_t)),
+ DataTypeLayout::FixedWidth(sizeof(int32_t))});
+ }
+}
+
+uint8_t UnionType::max_type_code() const {
+ return type_codes_.size() == 0
+ ? 0
+ : *std::max_element(type_codes_.begin(), type_codes_.end());
+}
+
+std::string UnionType::ToString() const {
+ std::stringstream s;
+
+ s << name() << "<";
+
+ for (size_t i = 0; i < children_.size(); ++i) {
+ if (i) {
+ s << ", ";
+ }
+ s << children_[i]->ToString() << "=" << static_cast<int>(type_codes_[i]);
+ }
+ s << ">";
+ return s.str();
+}
+
+SparseUnionType::SparseUnionType(std::vector<std::shared_ptr<Field>> fields,
+ std::vector<int8_t> type_codes)
+ : UnionType(fields, type_codes, Type::SPARSE_UNION) {}
+
+Result<std::shared_ptr<DataType>> SparseUnionType::Make(
+ std::vector<std::shared_ptr<Field>> fields, std::vector<int8_t> type_codes) {
+ RETURN_NOT_OK(ValidateParameters(fields, type_codes, UnionMode::SPARSE));
+ return std::make_shared<SparseUnionType>(fields, type_codes);
+}
+
+DenseUnionType::DenseUnionType(std::vector<std::shared_ptr<Field>> fields,
+ std::vector<int8_t> type_codes)
+ : UnionType(fields, type_codes, Type::DENSE_UNION) {}
+
+Result<std::shared_ptr<DataType>> DenseUnionType::Make(
+ std::vector<std::shared_ptr<Field>> fields, std::vector<int8_t> type_codes) {
+ RETURN_NOT_OK(ValidateParameters(fields, type_codes, UnionMode::DENSE));
+ return std::make_shared<DenseUnionType>(fields, type_codes);
+}
+
+// ----------------------------------------------------------------------
+// Struct type
+
+namespace {
+
+std::unordered_multimap<std::string, int> CreateNameToIndexMap(
+ const std::vector<std::shared_ptr<Field>>& fields) {
+ std::unordered_multimap<std::string, int> name_to_index;
+ for (size_t i = 0; i < fields.size(); ++i) {
+ name_to_index.emplace(fields[i]->name(), static_cast<int>(i));
+ }
+ return name_to_index;
+}
+
+template <int NotFoundValue = -1, int DuplicateFoundValue = -1>
+int LookupNameIndex(const std::unordered_multimap<std::string, int>& name_to_index,
+ const std::string& name) {
+ auto p = name_to_index.equal_range(name);
+ auto it = p.first;
+ if (it == p.second) {
+ // Not found
+ return NotFoundValue;
+ }
+ auto index = it->second;
+ if (++it != p.second) {
+ // Duplicate field name
+ return DuplicateFoundValue;
+ }
+ return index;
+}
+
+} // namespace
+
+class StructType::Impl {
+ public:
+ explicit Impl(const std::vector<std::shared_ptr<Field>>& fields)
+ : name_to_index_(CreateNameToIndexMap(fields)) {}
+
+ const std::unordered_multimap<std::string, int> name_to_index_;
+};
+
+StructType::StructType(const std::vector<std::shared_ptr<Field>>& fields)
+ : NestedType(Type::STRUCT), impl_(new Impl(fields)) {
+ children_ = fields;
+}
+
+StructType::~StructType() {}
+
+std::string StructType::ToString() const {
+ std::stringstream s;
+ s << "struct<";
+ for (int i = 0; i < this->num_fields(); ++i) {
+ if (i > 0) {
+ s << ", ";
+ }
+ std::shared_ptr<Field> field = this->field(i);
+ s << field->ToString();
+ }
+ s << ">";
+ return s.str();
+}
+
+std::shared_ptr<Field> StructType::GetFieldByName(const std::string& name) const {
+ int i = GetFieldIndex(name);
+ return i == -1 ? nullptr : children_[i];
+}
+
+int StructType::GetFieldIndex(const std::string& name) const {
+ return LookupNameIndex(impl_->name_to_index_, name);
+}
+
+std::vector<int> StructType::GetAllFieldIndices(const std::string& name) const {
+ std::vector<int> result;
+ auto p = impl_->name_to_index_.equal_range(name);
+ for (auto it = p.first; it != p.second; ++it) {
+ result.push_back(it->second);
+ }
+ if (result.size() > 1) {
+ std::sort(result.begin(), result.end());
+ }
+ return result;
+}
+
+std::vector<std::shared_ptr<Field>> StructType::GetAllFieldsByName(
+ const std::string& name) const {
+ std::vector<std::shared_ptr<Field>> result;
+ auto p = impl_->name_to_index_.equal_range(name);
+ for (auto it = p.first; it != p.second; ++it) {
+ result.push_back(children_[it->second]);
+ }
+ return result;
+}
+
+Result<std::shared_ptr<DataType>> DecimalType::Make(Type::type type_id, int32_t precision,
+ int32_t scale) {
+ if (type_id == Type::DECIMAL128) {
+ return Decimal128Type::Make(precision, scale);
+ } else if (type_id == Type::DECIMAL256) {
+ return Decimal256Type::Make(precision, scale);
+ } else {
+ return Status::Invalid("Not a decimal type_id: ", type_id);
+ }
+}
+
+// Taken from the Apache Impala codebase. The comments next
+// to the return values are the maximum value that can be represented in 2's
+// complement with the returned number of bytes.
+int32_t DecimalType::DecimalSize(int32_t precision) {
+ DCHECK_GE(precision, 1) << "decimal precision must be greater than or equal to 1, got "
+ << precision;
+
+ // Generated in python with:
+ // >>> decimal_size = lambda prec: int(math.ceil((prec * math.log2(10) + 1) / 8))
+ // >>> [-1] + [decimal_size(i) for i in range(1, 77)]
+ constexpr int32_t kBytes[] = {
+ -1, 1, 1, 2, 2, 3, 3, 4, 4, 4, 5, 5, 6, 6, 6, 7, 7, 8, 8, 9,
+ 9, 9, 10, 10, 11, 11, 11, 12, 12, 13, 13, 13, 14, 14, 15, 15, 16, 16, 16, 17,
+ 17, 18, 18, 18, 19, 19, 20, 20, 21, 21, 21, 22, 22, 23, 23, 23, 24, 24, 25, 25,
+ 26, 26, 26, 27, 27, 28, 28, 28, 29, 29, 30, 30, 31, 31, 31, 32, 32};
+
+ if (precision <= 76) {
+ return kBytes[precision];
+ }
+ return static_cast<int32_t>(std::ceil((precision / 8.0) * std::log2(10) + 1));
+}
+
+// ----------------------------------------------------------------------
+// Decimal128 type
+
+Decimal128Type::Decimal128Type(int32_t precision, int32_t scale)
+ : DecimalType(type_id, 16, precision, scale) {
+ ARROW_CHECK_GE(precision, kMinPrecision);
+ ARROW_CHECK_LE(precision, kMaxPrecision);
+}
+
+Result<std::shared_ptr<DataType>> Decimal128Type::Make(int32_t precision, int32_t scale) {
+ if (precision < kMinPrecision || precision > kMaxPrecision) {
+ return Status::Invalid("Decimal precision out of range: ", precision);
+ }
+ return std::make_shared<Decimal128Type>(precision, scale);
+}
+
+// ----------------------------------------------------------------------
+// Decimal256 type
+
+Decimal256Type::Decimal256Type(int32_t precision, int32_t scale)
+ : DecimalType(type_id, 32, precision, scale) {
+ ARROW_CHECK_GE(precision, kMinPrecision);
+ ARROW_CHECK_LE(precision, kMaxPrecision);
+}
+
+Result<std::shared_ptr<DataType>> Decimal256Type::Make(int32_t precision, int32_t scale) {
+ if (precision < kMinPrecision || precision > kMaxPrecision) {
+ return Status::Invalid("Decimal precision out of range: ", precision);
+ }
+ return std::make_shared<Decimal256Type>(precision, scale);
+}
+
+// ----------------------------------------------------------------------
+// Dictionary-encoded type
+
+Status DictionaryType::ValidateParameters(const DataType& index_type,
+ const DataType& value_type) {
+ if (!is_integer(index_type.id())) {
+ return Status::TypeError("Dictionary index type should be integer, got ",
+ index_type.ToString());
+ }
+ return Status::OK();
+}
+
+int DictionaryType::bit_width() const {
+ return checked_cast<const FixedWidthType&>(*index_type_).bit_width();
+}
+
+Result<std::shared_ptr<DataType>> DictionaryType::Make(
+ const std::shared_ptr<DataType>& index_type,
+ const std::shared_ptr<DataType>& value_type, bool ordered) {
+ RETURN_NOT_OK(ValidateParameters(*index_type, *value_type));
+ return std::make_shared<DictionaryType>(index_type, value_type, ordered);
+}
+
+DictionaryType::DictionaryType(const std::shared_ptr<DataType>& index_type,
+ const std::shared_ptr<DataType>& value_type, bool ordered)
+ : FixedWidthType(Type::DICTIONARY),
+ index_type_(index_type),
+ value_type_(value_type),
+ ordered_(ordered) {
+ ARROW_CHECK_OK(ValidateParameters(*index_type_, *value_type_));
+}
+
+DataTypeLayout DictionaryType::layout() const {
+ auto layout = index_type_->layout();
+ layout.has_dictionary = true;
+ return layout;
+}
+
+std::string DictionaryType::ToString() const {
+ std::stringstream ss;
+ ss << this->name() << "<values=" << value_type_->ToString()
+ << ", indices=" << index_type_->ToString() << ", ordered=" << ordered_ << ">";
+ return ss.str();
+}
+
+// ----------------------------------------------------------------------
+// Null type
+
+std::string NullType::ToString() const { return name(); }
+
+// ----------------------------------------------------------------------
+// FieldRef
+
+size_t FieldPath::hash() const {
+ return internal::ComputeStringHash<0>(indices().data(), indices().size() * sizeof(int));
+}
+
+std::string FieldPath::ToString() const {
+ if (this->indices().empty()) {
+ return "FieldPath(empty)";
+ }
+
+ std::string repr = "FieldPath(";
+ for (auto index : this->indices()) {
+ repr += std::to_string(index) + " ";
+ }
+ repr.back() = ')';
+ return repr;
+}
+
+struct FieldPathGetImpl {
+ static const DataType& GetType(const ArrayData& data) { return *data.type; }
+
+ static void Summarize(const FieldVector& fields, std::stringstream* ss) {
+ *ss << "{ ";
+ for (const auto& field : fields) {
+ *ss << field->ToString() << ", ";
+ }
+ *ss << "}";
+ }
+
+ template <typename T>
+ static void Summarize(const std::vector<T>& columns, std::stringstream* ss) {
+ *ss << "{ ";
+ for (const auto& column : columns) {
+ *ss << GetType(*column) << ", ";
+ }
+ *ss << "}";
+ }
+
+ template <typename T>
+ static Status IndexError(const FieldPath* path, int out_of_range_depth,
+ const std::vector<T>& children) {
+ std::stringstream ss;
+ ss << "index out of range. ";
+
+ ss << "indices=[ ";
+ int depth = 0;
+ for (int i : path->indices()) {
+ if (depth != out_of_range_depth) {
+ ss << i << " ";
+ continue;
+ }
+ ss << ">" << i << "< ";
+ ++depth;
+ }
+ ss << "] ";
+
+ if (std::is_same<T, std::shared_ptr<Field>>::value) {
+ ss << "fields were: ";
+ } else {
+ ss << "columns had types: ";
+ }
+ Summarize(children, &ss);
+
+ return Status::IndexError(ss.str());
+ }
+
+ template <typename T, typename GetChildren>
+ static Result<T> Get(const FieldPath* path, const std::vector<T>* children,
+ GetChildren&& get_children, int* out_of_range_depth) {
+ if (path->indices().empty()) {
+ return Status::Invalid("empty indices cannot be traversed");
+ }
+
+ int depth = 0;
+ const T* out;
+ for (int index : path->indices()) {
+ if (children == nullptr) {
+ return Status::NotImplemented("Get child data of non-struct array");
+ }
+
+ if (index < 0 || static_cast<size_t>(index) >= children->size()) {
+ *out_of_range_depth = depth;
+ return nullptr;
+ }
+
+ out = &children->at(index);
+ children = get_children(*out);
+ ++depth;
+ }
+
+ return *out;
+ }
+
+ template <typename T, typename GetChildren>
+ static Result<T> Get(const FieldPath* path, const std::vector<T>* children,
+ GetChildren&& get_children) {
+ int out_of_range_depth = -1;
+ ARROW_ASSIGN_OR_RAISE(auto child,
+ Get(path, children, std::forward<GetChildren>(get_children),
+ &out_of_range_depth));
+ if (child != nullptr) {
+ return std::move(child);
+ }
+ return IndexError(path, out_of_range_depth, *children);
+ }
+
+ static Result<std::shared_ptr<Field>> Get(const FieldPath* path,
+ const FieldVector& fields) {
+ return FieldPathGetImpl::Get(path, &fields, [](const std::shared_ptr<Field>& field) {
+ return &field->type()->fields();
+ });
+ }
+
+ static Result<std::shared_ptr<ArrayData>> Get(const FieldPath* path,
+ const ArrayDataVector& child_data) {
+ return FieldPathGetImpl::Get(
+ path, &child_data,
+ [](const std::shared_ptr<ArrayData>& data) -> const ArrayDataVector* {
+ if (data->type->id() != Type::STRUCT) {
+ return nullptr;
+ }
+ return &data->child_data;
+ });
+ }
+};
+
+Result<std::shared_ptr<Field>> FieldPath::Get(const Schema& schema) const {
+ return FieldPathGetImpl::Get(this, schema.fields());
+}
+
+Result<std::shared_ptr<Field>> FieldPath::Get(const Field& field) const {
+ return FieldPathGetImpl::Get(this, field.type()->fields());
+}
+
+Result<std::shared_ptr<Field>> FieldPath::Get(const DataType& type) const {
+ return FieldPathGetImpl::Get(this, type.fields());
+}
+
+Result<std::shared_ptr<Field>> FieldPath::Get(const FieldVector& fields) const {
+ return FieldPathGetImpl::Get(this, fields);
+}
+
+Result<std::shared_ptr<Array>> FieldPath::Get(const RecordBatch& batch) const {
+ ARROW_ASSIGN_OR_RAISE(auto data, FieldPathGetImpl::Get(this, batch.column_data()));
+ return MakeArray(std::move(data));
+}
+
+Result<std::shared_ptr<Array>> FieldPath::Get(const Array& array) const {
+ ARROW_ASSIGN_OR_RAISE(auto data, Get(*array.data()));
+ return MakeArray(std::move(data));
+}
+
+Result<std::shared_ptr<ArrayData>> FieldPath::Get(const ArrayData& data) const {
+ if (data.type->id() != Type::STRUCT) {
+ return Status::NotImplemented("Get child data of non-struct array");
+ }
+ return FieldPathGetImpl::Get(this, data.child_data);
+}
+
+FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) {
+ DCHECK_GT(util::get<FieldPath>(impl_).indices().size(), 0);
+}
+
+void FieldRef::Flatten(std::vector<FieldRef> children) {
+ // flatten children
+ struct Visitor {
+ void operator()(std::string* name) { *out++ = FieldRef(std::move(*name)); }
+
+ void operator()(FieldPath* indices) { *out++ = FieldRef(std::move(*indices)); }
+
+ void operator()(std::vector<FieldRef>* children) {
+ for (auto& child : *children) {
+ util::visit(*this, &child.impl_);
+ }
+ }
+
+ std::back_insert_iterator<std::vector<FieldRef>> out;
+ };
+
+ std::vector<FieldRef> out;
+ Visitor visitor{std::back_inserter(out)};
+ visitor(&children);
+
+ DCHECK(!out.empty());
+ DCHECK(std::none_of(out.begin(), out.end(),
+ [](const FieldRef& ref) { return ref.IsNested(); }));
+
+ if (out.size() == 1) {
+ impl_ = std::move(out[0].impl_);
+ } else {
+ impl_ = std::move(out);
+ }
+}
+
+Result<FieldRef> FieldRef::FromDotPath(const std::string& dot_path_arg) {
+ if (dot_path_arg.empty()) {
+ return Status::Invalid("Dot path was empty");
+ }
+
+ std::vector<FieldRef> children;
+
+ util::string_view dot_path = dot_path_arg;
+
+ auto parse_name = [&] {
+ std::string name;
+ for (;;) {
+ auto segment_end = dot_path.find_first_of("\\[.");
+ if (segment_end == util::string_view::npos) {
+ // dot_path doesn't contain any other special characters; consume all
+ name.append(dot_path.begin(), dot_path.end());
+ dot_path = "";
+ break;
+ }
+
+ if (dot_path[segment_end] != '\\') {
+ // segment_end points to a subscript for a new FieldRef
+ name.append(dot_path.begin(), segment_end);
+ dot_path = dot_path.substr(segment_end);
+ break;
+ }
+
+ if (dot_path.size() == segment_end + 1) {
+ // dot_path ends with backslash; consume it all
+ name.append(dot_path.begin(), dot_path.end());
+ dot_path = "";
+ break;
+ }
+
+ // append all characters before backslash, then the character which follows it
+ name.append(dot_path.begin(), segment_end);
+ name.push_back(dot_path[segment_end + 1]);
+ dot_path = dot_path.substr(segment_end + 2);
+ }
+ return name;
+ };
+
+ while (!dot_path.empty()) {
+ auto subscript = dot_path[0];
+ dot_path = dot_path.substr(1);
+ switch (subscript) {
+ case '.': {
+ // next element is a name
+ children.emplace_back(parse_name());
+ continue;
+ }
+ case '[': {
+ auto subscript_end = dot_path.find_first_not_of("0123456789");
+ if (subscript_end == util::string_view::npos || dot_path[subscript_end] != ']') {
+ return Status::Invalid("Dot path '", dot_path_arg,
+ "' contained an unterminated index");
+ }
+ children.emplace_back(std::atoi(dot_path.data()));
+ dot_path = dot_path.substr(subscript_end + 1);
+ continue;
+ }
+ default:
+ return Status::Invalid("Dot path must begin with '[' or '.', got '", dot_path_arg,
+ "'");
+ }
+ }
+
+ FieldRef out;
+ out.Flatten(std::move(children));
+ return out;
+}
+
+size_t FieldRef::hash() const {
+ struct Visitor : std::hash<std::string> {
+ using std::hash<std::string>::operator();
+
+ size_t operator()(const FieldPath& path) { return path.hash(); }
+
+ size_t operator()(const std::vector<FieldRef>& children) {
+ size_t hash = 0;
+
+ for (const FieldRef& child : children) {
+ hash ^= child.hash();
+ }
+
+ return hash;
+ }
+ };
+
+ return util::visit(Visitor{}, impl_);
+}
+
+std::string FieldRef::ToString() const {
+ struct Visitor {
+ std::string operator()(const FieldPath& path) { return path.ToString(); }
+
+ std::string operator()(const std::string& name) { return "Name(" + name + ")"; }
+
+ std::string operator()(const std::vector<FieldRef>& children) {
+ std::string repr = "Nested(";
+ for (const auto& child : children) {
+ repr += child.ToString() + " ";
+ }
+ repr.resize(repr.size() - 1);
+ repr += ")";
+ return repr;
+ }
+ };
+
+ return "FieldRef." + util::visit(Visitor{}, impl_);
+}
+
+std::vector<FieldPath> FieldRef::FindAll(const Schema& schema) const {
+ if (auto name = this->name()) {
+ return internal::MapVector([](int i) { return FieldPath{i}; },
+ schema.GetAllFieldIndices(*name));
+ }
+ return FindAll(schema.fields());
+}
+
+std::vector<FieldPath> FieldRef::FindAll(const Field& field) const {
+ return FindAll(field.type()->fields());
+}
+
+std::vector<FieldPath> FieldRef::FindAll(const DataType& type) const {
+ return FindAll(type.fields());
+}
+
+std::vector<FieldPath> FieldRef::FindAll(const FieldVector& fields) const {
+ struct Visitor {
+ std::vector<FieldPath> operator()(const FieldPath& path) {
+ // skip long IndexError construction if path is out of range
+ int out_of_range_depth;
+ auto maybe_field = FieldPathGetImpl::Get(
+ &path, &fields_,
+ [](const std::shared_ptr<Field>& field) { return &field->type()->fields(); },
+ &out_of_range_depth);
+
+ DCHECK_OK(maybe_field.status());
+
+ if (maybe_field.ValueOrDie() != nullptr) {
+ return {path};
+ }
+ return {};
+ }
+
+ std::vector<FieldPath> operator()(const std::string& name) {
+ std::vector<FieldPath> out;
+
+ for (int i = 0; i < static_cast<int>(fields_.size()); ++i) {
+ if (fields_[i]->name() == name) {
+ out.push_back({i});
+ }
+ }
+
+ return out;
+ }
+
+ struct Matches {
+ // referents[i] is referenced by prefixes[i]
+ std::vector<FieldPath> prefixes;
+ FieldVector referents;
+
+ Matches(std::vector<FieldPath> matches, const FieldVector& fields) {
+ for (auto& match : matches) {
+ Add({}, std::move(match), fields);
+ }
+ }
+
+ Matches() = default;
+
+ size_t size() const { return referents.size(); }
+
+ void Add(const FieldPath& prefix, const FieldPath& suffix,
+ const FieldVector& fields) {
+ auto maybe_field = suffix.Get(fields);
+ DCHECK_OK(maybe_field.status());
+ referents.push_back(std::move(maybe_field).ValueOrDie());
+
+ std::vector<int> concatenated_indices(prefix.indices().size() +
+ suffix.indices().size());
+ auto it = concatenated_indices.begin();
+ for (auto path : {&prefix, &suffix}) {
+ it = std::copy(path->indices().begin(), path->indices().end(), it);
+ }
+ prefixes.emplace_back(std::move(concatenated_indices));
+ }
+ };
+
+ std::vector<FieldPath> operator()(const std::vector<FieldRef>& refs) {
+ DCHECK_GE(refs.size(), 1);
+ Matches matches(refs.front().FindAll(fields_), fields_);
+
+ for (auto ref_it = refs.begin() + 1; ref_it != refs.end(); ++ref_it) {
+ Matches next_matches;
+ for (size_t i = 0; i < matches.size(); ++i) {
+ const auto& referent = *matches.referents[i];
+
+ for (const FieldPath& match : ref_it->FindAll(referent)) {
+ next_matches.Add(matches.prefixes[i], match, referent.type()->fields());
+ }
+ }
+ matches = std::move(next_matches);
+ }
+
+ return matches.prefixes;
+ }
+
+ const FieldVector& fields_;
+ };
+
+ return util::visit(Visitor{fields}, impl_);
+}
+
+std::vector<FieldPath> FieldRef::FindAll(const ArrayData& array) const {
+ return FindAll(*array.type);
+}
+
+std::vector<FieldPath> FieldRef::FindAll(const Array& array) const {
+ return FindAll(*array.type());
+}
+
+std::vector<FieldPath> FieldRef::FindAll(const RecordBatch& batch) const {
+ return FindAll(*batch.schema());
+}
+
+void PrintTo(const FieldRef& ref, std::ostream* os) { *os << ref.ToString(); }
+
+// ----------------------------------------------------------------------
+// Schema implementation
+
+std::string EndiannessToString(Endianness endianness) {
+ switch (endianness) {
+ case Endianness::Little:
+ return "little";
+ case Endianness::Big:
+ return "big";
+ default:
+ DCHECK(false) << "invalid endianness";
+ return "???";
+ }
+}
+
+class Schema::Impl {
+ public:
+ Impl(std::vector<std::shared_ptr<Field>> fields, Endianness endianness,
+ std::shared_ptr<const KeyValueMetadata> metadata)
+ : fields_(std::move(fields)),
+ endianness_(endianness),
+ name_to_index_(CreateNameToIndexMap(fields_)),
+ metadata_(std::move(metadata)) {}
+
+ std::vector<std::shared_ptr<Field>> fields_;
+ Endianness endianness_;
+ std::unordered_multimap<std::string, int> name_to_index_;
+ std::shared_ptr<const KeyValueMetadata> metadata_;
+};
+
+Schema::Schema(std::vector<std::shared_ptr<Field>> fields, Endianness endianness,
+ std::shared_ptr<const KeyValueMetadata> metadata)
+ : detail::Fingerprintable(),
+ impl_(new Impl(std::move(fields), endianness, std::move(metadata))) {}
+
+Schema::Schema(std::vector<std::shared_ptr<Field>> fields,
+ std::shared_ptr<const KeyValueMetadata> metadata)
+ : detail::Fingerprintable(),
+ impl_(new Impl(std::move(fields), Endianness::Native, std::move(metadata))) {}
+
+Schema::Schema(const Schema& schema)
+ : detail::Fingerprintable(), impl_(new Impl(*schema.impl_)) {}
+
+Schema::~Schema() = default;
+
+std::shared_ptr<Schema> Schema::WithEndianness(Endianness endianness) const {
+ return std::make_shared<Schema>(impl_->fields_, endianness, impl_->metadata_);
+}
+
+Endianness Schema::endianness() const { return impl_->endianness_; }
+
+bool Schema::is_native_endian() const { return impl_->endianness_ == Endianness::Native; }
+
+int Schema::num_fields() const { return static_cast<int>(impl_->fields_.size()); }
+
+const std::shared_ptr<Field>& Schema::field(int i) const {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, num_fields());
+ return impl_->fields_[i];
+}
+
+const std::vector<std::shared_ptr<Field>>& Schema::fields() const {
+ return impl_->fields_;
+}
+
+bool Schema::Equals(const Schema& other, bool check_metadata) const {
+ if (this == &other) {
+ return true;
+ }
+
+ // checks endianness equality
+ if (endianness() != other.endianness()) {
+ return false;
+ }
+
+ // checks field equality
+ if (num_fields() != other.num_fields()) {
+ return false;
+ }
+
+ if (check_metadata) {
+ const auto& metadata_fp = metadata_fingerprint();
+ const auto& other_metadata_fp = other.metadata_fingerprint();
+ if (metadata_fp != other_metadata_fp) {
+ return false;
+ }
+ }
+
+ // Fast path using fingerprints, if possible
+ const auto& fp = fingerprint();
+ const auto& other_fp = other.fingerprint();
+ if (!fp.empty() && !other_fp.empty()) {
+ return fp == other_fp;
+ }
+
+ // Fall back on field-by-field comparison
+ for (int i = 0; i < num_fields(); ++i) {
+ if (!field(i)->Equals(*other.field(i).get(), check_metadata)) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool Schema::Equals(const std::shared_ptr<Schema>& other, bool check_metadata) const {
+ if (other == nullptr) {
+ return false;
+ }
+
+ return Equals(*other, check_metadata);
+}
+
+std::shared_ptr<Field> Schema::GetFieldByName(const std::string& name) const {
+ int i = GetFieldIndex(name);
+ return i == -1 ? nullptr : impl_->fields_[i];
+}
+
+int Schema::GetFieldIndex(const std::string& name) const {
+ return LookupNameIndex(impl_->name_to_index_, name);
+}
+
+std::vector<int> Schema::GetAllFieldIndices(const std::string& name) const {
+ std::vector<int> result;
+ auto p = impl_->name_to_index_.equal_range(name);
+ for (auto it = p.first; it != p.second; ++it) {
+ result.push_back(it->second);
+ }
+ if (result.size() > 1) {
+ std::sort(result.begin(), result.end());
+ }
+ return result;
+}
+
+Status Schema::CanReferenceFieldsByNames(const std::vector<std::string>& names) const {
+ for (const auto& name : names) {
+ if (GetFieldByName(name) == nullptr) {
+ return Status::Invalid("Field named '", name,
+ "' not found or not unique in the schema.");
+ }
+ }
+
+ return Status::OK();
+}
+
+std::vector<std::shared_ptr<Field>> Schema::GetAllFieldsByName(
+ const std::string& name) const {
+ std::vector<std::shared_ptr<Field>> result;
+ auto p = impl_->name_to_index_.equal_range(name);
+ for (auto it = p.first; it != p.second; ++it) {
+ result.push_back(impl_->fields_[it->second]);
+ }
+ return result;
+}
+
+Result<std::shared_ptr<Schema>> Schema::AddField(
+ int i, const std::shared_ptr<Field>& field) const {
+ if (i < 0 || i > this->num_fields()) {
+ return Status::Invalid("Invalid column index to add field.");
+ }
+
+ return std::make_shared<Schema>(internal::AddVectorElement(impl_->fields_, i, field),
+ impl_->metadata_);
+}
+
+Result<std::shared_ptr<Schema>> Schema::SetField(
+ int i, const std::shared_ptr<Field>& field) const {
+ if (i < 0 || i > this->num_fields()) {
+ return Status::Invalid("Invalid column index to add field.");
+ }
+
+ return std::make_shared<Schema>(
+ internal::ReplaceVectorElement(impl_->fields_, i, field), impl_->metadata_);
+}
+
+Result<std::shared_ptr<Schema>> Schema::RemoveField(int i) const {
+ if (i < 0 || i >= this->num_fields()) {
+ return Status::Invalid("Invalid column index to remove field.");
+ }
+
+ return std::make_shared<Schema>(internal::DeleteVectorElement(impl_->fields_, i),
+ impl_->metadata_);
+}
+
+bool Schema::HasMetadata() const {
+ return (impl_->metadata_ != nullptr) && (impl_->metadata_->size() > 0);
+}
+
+bool Schema::HasDistinctFieldNames() const {
+ auto fields = field_names();
+ std::unordered_set<std::string> names{fields.cbegin(), fields.cend()};
+ return names.size() == fields.size();
+}
+
+std::shared_ptr<Schema> Schema::WithMetadata(
+ const std::shared_ptr<const KeyValueMetadata>& metadata) const {
+ return std::make_shared<Schema>(impl_->fields_, metadata);
+}
+
+const std::shared_ptr<const KeyValueMetadata>& Schema::metadata() const {
+ return impl_->metadata_;
+}
+
+std::shared_ptr<Schema> Schema::RemoveMetadata() const {
+ return std::make_shared<Schema>(impl_->fields_);
+}
+
+std::string Schema::ToString(bool show_metadata) const {
+ std::stringstream buffer;
+
+ int i = 0;
+ for (const auto& field : impl_->fields_) {
+ if (i > 0) {
+ buffer << std::endl;
+ }
+ buffer << field->ToString(show_metadata);
+ ++i;
+ }
+
+ if (impl_->endianness_ != Endianness::Native) {
+ buffer << "\n-- endianness: " << EndiannessToString(impl_->endianness_) << " --";
+ }
+
+ if (show_metadata && HasMetadata()) {
+ buffer << impl_->metadata_->ToString();
+ }
+
+ return buffer.str();
+}
+
+std::vector<std::string> Schema::field_names() const {
+ std::vector<std::string> names;
+ for (const auto& field : impl_->fields_) {
+ names.push_back(field->name());
+ }
+ return names;
+}
+
+class SchemaBuilder::Impl {
+ public:
+ friend class SchemaBuilder;
+ Impl(ConflictPolicy policy, Field::MergeOptions field_merge_options)
+ : policy_(policy), field_merge_options_(field_merge_options) {}
+
+ Impl(std::vector<std::shared_ptr<Field>> fields,
+ std::shared_ptr<const KeyValueMetadata> metadata, ConflictPolicy conflict_policy,
+ Field::MergeOptions field_merge_options)
+ : fields_(std::move(fields)),
+ name_to_index_(CreateNameToIndexMap(fields_)),
+ metadata_(std::move(metadata)),
+ policy_(conflict_policy),
+ field_merge_options_(field_merge_options) {}
+
+ Status AddField(const std::shared_ptr<Field>& field) {
+ DCHECK_NE(field, nullptr);
+
+ // Short-circuit, no lookup needed.
+ if (policy_ == CONFLICT_APPEND) {
+ return AppendField(field);
+ }
+
+ auto name = field->name();
+ constexpr int kNotFound = -1;
+ constexpr int kDuplicateFound = -2;
+ auto i = LookupNameIndex<kNotFound, kDuplicateFound>(name_to_index_, name);
+
+ if (i == kNotFound) {
+ return AppendField(field);
+ }
+
+ // From this point, there's one or more field in the builder that exists with
+ // the same name.
+
+ if (policy_ == CONFLICT_IGNORE) {
+ // The ignore policy is more generous when there's duplicate in the builder.
+ return Status::OK();
+ } else if (policy_ == CONFLICT_ERROR) {
+ return Status::Invalid("Duplicate found, policy dictate to treat as an error");
+ }
+
+ if (i == kDuplicateFound) {
+ // Cannot merge/replace when there's more than one field in the builder
+ // because we can't decide which to merge/replace.
+ return Status::Invalid("Cannot merge field ", name,
+ " more than one field with same name exists");
+ }
+
+ DCHECK_GE(i, 0);
+
+ if (policy_ == CONFLICT_REPLACE) {
+ fields_[i] = field;
+ } else if (policy_ == CONFLICT_MERGE) {
+ ARROW_ASSIGN_OR_RAISE(fields_[i], fields_[i]->MergeWith(field));
+ }
+
+ return Status::OK();
+ }
+
+ Status AppendField(const std::shared_ptr<Field>& field) {
+ name_to_index_.emplace(field->name(), static_cast<int>(fields_.size()));
+ fields_.push_back(field);
+ return Status::OK();
+ }
+
+ void Reset() {
+ fields_.clear();
+ name_to_index_.clear();
+ metadata_.reset();
+ }
+
+ private:
+ std::vector<std::shared_ptr<Field>> fields_;
+ std::unordered_multimap<std::string, int> name_to_index_;
+ std::shared_ptr<const KeyValueMetadata> metadata_;
+ ConflictPolicy policy_;
+ Field::MergeOptions field_merge_options_;
+};
+
+SchemaBuilder::SchemaBuilder(ConflictPolicy policy,
+ Field::MergeOptions field_merge_options) {
+ impl_ = internal::make_unique<Impl>(policy, field_merge_options);
+}
+
+SchemaBuilder::SchemaBuilder(std::vector<std::shared_ptr<Field>> fields,
+ ConflictPolicy policy,
+ Field::MergeOptions field_merge_options) {
+ impl_ = internal::make_unique<Impl>(std::move(fields), nullptr, policy,
+ field_merge_options);
+}
+
+SchemaBuilder::SchemaBuilder(const std::shared_ptr<Schema>& schema, ConflictPolicy policy,
+ Field::MergeOptions field_merge_options) {
+ std::shared_ptr<const KeyValueMetadata> metadata;
+ if (schema->HasMetadata()) {
+ metadata = schema->metadata()->Copy();
+ }
+
+ impl_ = internal::make_unique<Impl>(schema->fields(), std::move(metadata), policy,
+ field_merge_options);
+}
+
+SchemaBuilder::~SchemaBuilder() {}
+
+SchemaBuilder::ConflictPolicy SchemaBuilder::policy() const { return impl_->policy_; }
+
+void SchemaBuilder::SetPolicy(SchemaBuilder::ConflictPolicy resolution) {
+ impl_->policy_ = resolution;
+}
+
+Status SchemaBuilder::AddField(const std::shared_ptr<Field>& field) {
+ return impl_->AddField(field);
+}
+
+Status SchemaBuilder::AddFields(const std::vector<std::shared_ptr<Field>>& fields) {
+ for (const auto& field : fields) {
+ RETURN_NOT_OK(AddField(field));
+ }
+
+ return Status::OK();
+}
+
+Status SchemaBuilder::AddSchema(const std::shared_ptr<Schema>& schema) {
+ DCHECK_NE(schema, nullptr);
+ return AddFields(schema->fields());
+}
+
+Status SchemaBuilder::AddSchemas(const std::vector<std::shared_ptr<Schema>>& schemas) {
+ for (const auto& schema : schemas) {
+ RETURN_NOT_OK(AddSchema(schema));
+ }
+
+ return Status::OK();
+}
+
+Status SchemaBuilder::AddMetadata(const KeyValueMetadata& metadata) {
+ impl_->metadata_ = metadata.Copy();
+ return Status::OK();
+}
+
+Result<std::shared_ptr<Schema>> SchemaBuilder::Finish() const {
+ return schema(impl_->fields_, impl_->metadata_);
+}
+
+void SchemaBuilder::Reset() { impl_->Reset(); }
+
+Result<std::shared_ptr<Schema>> SchemaBuilder::Merge(
+ const std::vector<std::shared_ptr<Schema>>& schemas, ConflictPolicy policy) {
+ SchemaBuilder builder{policy};
+ RETURN_NOT_OK(builder.AddSchemas(schemas));
+ return builder.Finish();
+}
+
+Status SchemaBuilder::AreCompatible(const std::vector<std::shared_ptr<Schema>>& schemas,
+ ConflictPolicy policy) {
+ return Merge(schemas, policy).status();
+}
+
+std::shared_ptr<Schema> schema(std::vector<std::shared_ptr<Field>> fields,
+ std::shared_ptr<const KeyValueMetadata> metadata) {
+ return std::make_shared<Schema>(std::move(fields), std::move(metadata));
+}
+
+std::shared_ptr<Schema> schema(std::vector<std::shared_ptr<Field>> fields,
+ Endianness endianness,
+ std::shared_ptr<const KeyValueMetadata> metadata) {
+ return std::make_shared<Schema>(std::move(fields), endianness, std::move(metadata));
+}
+
+Result<std::shared_ptr<Schema>> UnifySchemas(
+ const std::vector<std::shared_ptr<Schema>>& schemas,
+ const Field::MergeOptions field_merge_options) {
+ if (schemas.empty()) {
+ return Status::Invalid("Must provide at least one schema to unify.");
+ }
+
+ if (!schemas[0]->HasDistinctFieldNames()) {
+ return Status::Invalid("Can't unify schema with duplicate field names.");
+ }
+
+ SchemaBuilder builder{schemas[0], SchemaBuilder::CONFLICT_MERGE, field_merge_options};
+
+ for (size_t i = 1; i < schemas.size(); i++) {
+ const auto& schema = schemas[i];
+ if (!schema->HasDistinctFieldNames()) {
+ return Status::Invalid("Can't unify schema with duplicate field names.");
+ }
+ RETURN_NOT_OK(builder.AddSchema(schema));
+ }
+
+ return builder.Finish();
+}
+
+// ----------------------------------------------------------------------
+// Fingerprint computations
+
+namespace detail {
+
+Fingerprintable::~Fingerprintable() {
+ delete fingerprint_.load();
+ delete metadata_fingerprint_.load();
+}
+
+template <typename ComputeFingerprint>
+static const std::string& LoadFingerprint(std::atomic<std::string*>* fingerprint,
+ ComputeFingerprint&& compute_fingerprint) {
+ auto new_p = new std::string(std::forward<ComputeFingerprint>(compute_fingerprint)());
+ // Since fingerprint() and metadata_fingerprint() return a *reference* to the
+ // allocated string, the first allocation ever should never be replaced by another
+ // one. Hence the compare_exchange_strong() against nullptr.
+ std::string* expected = nullptr;
+ if (fingerprint->compare_exchange_strong(expected, new_p)) {
+ return *new_p;
+ } else {
+ delete new_p;
+ DCHECK_NE(expected, nullptr);
+ return *expected;
+ }
+}
+
+const std::string& Fingerprintable::LoadFingerprintSlow() const {
+ return LoadFingerprint(&fingerprint_, [this]() { return ComputeFingerprint(); });
+}
+
+const std::string& Fingerprintable::LoadMetadataFingerprintSlow() const {
+ return LoadFingerprint(&metadata_fingerprint_,
+ [this]() { return ComputeMetadataFingerprint(); });
+}
+
+} // namespace detail
+
+static inline std::string TypeIdFingerprint(const DataType& type) {
+ auto c = static_cast<int>(type.id()) + 'A';
+ DCHECK_GE(c, 0);
+ DCHECK_LT(c, 128); // Unlikely to happen any soon
+ // Prefix with an unusual character in order to disambiguate
+ std::string s{'@', static_cast<char>(c)};
+ return s;
+}
+
+static char TimeUnitFingerprint(TimeUnit::type unit) {
+ switch (unit) {
+ case TimeUnit::SECOND:
+ return 's';
+ case TimeUnit::MILLI:
+ return 'm';
+ case TimeUnit::MICRO:
+ return 'u';
+ case TimeUnit::NANO:
+ return 'n';
+ default:
+ DCHECK(false) << "Unexpected TimeUnit";
+ return '\0';
+ }
+}
+
+static char IntervalTypeFingerprint(IntervalType::type unit) {
+ switch (unit) {
+ case IntervalType::DAY_TIME:
+ return 'd';
+ case IntervalType::MONTHS:
+ return 'M';
+ case IntervalType::MONTH_DAY_NANO:
+ return 'N';
+ default:
+ DCHECK(false) << "Unexpected IntervalType::type";
+ return '\0';
+ }
+}
+
+static void AppendMetadataFingerprint(const KeyValueMetadata& metadata,
+ std::stringstream* ss) {
+ // Compute metadata fingerprint. KeyValueMetadata is not immutable,
+ // so we don't cache the result on the metadata instance.
+ const auto pairs = metadata.sorted_pairs();
+ if (!pairs.empty()) {
+ *ss << "!{";
+ for (const auto& p : pairs) {
+ const auto& k = p.first;
+ const auto& v = p.second;
+ // Since metadata strings can contain arbitrary characters, prefix with
+ // string length to disambiguate.
+ *ss << k.length() << ':' << k << ':';
+ *ss << v.length() << ':' << v << ';';
+ }
+ *ss << '}';
+ }
+}
+
+std::string Field::ComputeFingerprint() const {
+ const auto& type_fingerprint = type_->fingerprint();
+ if (type_fingerprint.empty()) {
+ // Underlying DataType doesn't support fingerprinting.
+ return "";
+ }
+ std::stringstream ss;
+ ss << 'F';
+ if (nullable_) {
+ ss << 'n';
+ } else {
+ ss << 'N';
+ }
+ ss << name_;
+ ss << '{' << type_fingerprint << '}';
+ return ss.str();
+}
+
+std::string Field::ComputeMetadataFingerprint() const {
+ std::stringstream ss;
+ if (metadata_) {
+ AppendMetadataFingerprint(*metadata_, &ss);
+ }
+ const auto& type_fingerprint = type_->metadata_fingerprint();
+ if (!type_fingerprint.empty()) {
+ ss << "+{" << type_->metadata_fingerprint() << "}";
+ }
+ return ss.str();
+}
+
+std::string Schema::ComputeFingerprint() const {
+ std::stringstream ss;
+ ss << "S{";
+ for (const auto& field : fields()) {
+ const auto& field_fingerprint = field->fingerprint();
+ if (field_fingerprint.empty()) {
+ return "";
+ }
+ ss << field_fingerprint << ";";
+ }
+ ss << (endianness() == Endianness::Little ? "L" : "B");
+ ss << "}";
+ return ss.str();
+}
+
+std::string Schema::ComputeMetadataFingerprint() const {
+ std::stringstream ss;
+ if (HasMetadata()) {
+ AppendMetadataFingerprint(*metadata(), &ss);
+ }
+ ss << "S{";
+ for (const auto& field : fields()) {
+ const auto& field_fingerprint = field->metadata_fingerprint();
+ ss << field_fingerprint << ";";
+ }
+ ss << "}";
+ return ss.str();
+}
+
+void PrintTo(const Schema& s, std::ostream* os) { *os << s; }
+
+std::string DataType::ComputeFingerprint() const {
+ // Default implementation returns empty string, signalling non-implemented
+ // functionality.
+ return "";
+}
+
+std::string DataType::ComputeMetadataFingerprint() const {
+ // Whatever the data type, metadata can only be found on child fields
+ std::string s;
+ for (const auto& child : children_) {
+ s += child->metadata_fingerprint() + ";";
+ }
+ return s;
+}
+
+#define PARAMETER_LESS_FINGERPRINT(TYPE_CLASS) \
+ std::string TYPE_CLASS##Type::ComputeFingerprint() const { \
+ return TypeIdFingerprint(*this); \
+ }
+
+PARAMETER_LESS_FINGERPRINT(Null)
+PARAMETER_LESS_FINGERPRINT(Boolean)
+PARAMETER_LESS_FINGERPRINT(Int8)
+PARAMETER_LESS_FINGERPRINT(Int16)
+PARAMETER_LESS_FINGERPRINT(Int32)
+PARAMETER_LESS_FINGERPRINT(Int64)
+PARAMETER_LESS_FINGERPRINT(UInt8)
+PARAMETER_LESS_FINGERPRINT(UInt16)
+PARAMETER_LESS_FINGERPRINT(UInt32)
+PARAMETER_LESS_FINGERPRINT(UInt64)
+PARAMETER_LESS_FINGERPRINT(HalfFloat)
+PARAMETER_LESS_FINGERPRINT(Float)
+PARAMETER_LESS_FINGERPRINT(Double)
+PARAMETER_LESS_FINGERPRINT(Binary)
+PARAMETER_LESS_FINGERPRINT(LargeBinary)
+PARAMETER_LESS_FINGERPRINT(String)
+PARAMETER_LESS_FINGERPRINT(LargeString)
+PARAMETER_LESS_FINGERPRINT(Date32)
+PARAMETER_LESS_FINGERPRINT(Date64)
+
+#undef PARAMETER_LESS_FINGERPRINT
+
+std::string DictionaryType::ComputeFingerprint() const {
+ const auto& index_fingerprint = index_type_->fingerprint();
+ const auto& value_fingerprint = value_type_->fingerprint();
+ std::string ordered_fingerprint = ordered_ ? "1" : "0";
+
+ DCHECK(!index_fingerprint.empty()); // it's an integer type
+ if (!value_fingerprint.empty()) {
+ return TypeIdFingerprint(*this) + index_fingerprint + value_fingerprint +
+ ordered_fingerprint;
+ }
+ return ordered_fingerprint;
+}
+
+std::string ListType::ComputeFingerprint() const {
+ const auto& child_fingerprint = children_[0]->fingerprint();
+ if (!child_fingerprint.empty()) {
+ return TypeIdFingerprint(*this) + "{" + child_fingerprint + "}";
+ }
+ return "";
+}
+
+std::string LargeListType::ComputeFingerprint() const {
+ const auto& child_fingerprint = children_[0]->fingerprint();
+ if (!child_fingerprint.empty()) {
+ return TypeIdFingerprint(*this) + "{" + child_fingerprint + "}";
+ }
+ return "";
+}
+
+std::string MapType::ComputeFingerprint() const {
+ const auto& key_fingerprint = key_type()->fingerprint();
+ const auto& item_fingerprint = item_type()->fingerprint();
+ if (!key_fingerprint.empty() && !item_fingerprint.empty()) {
+ if (keys_sorted_) {
+ return TypeIdFingerprint(*this) + "s{" + key_fingerprint + item_fingerprint + "}";
+ } else {
+ return TypeIdFingerprint(*this) + "{" + key_fingerprint + item_fingerprint + "}";
+ }
+ }
+ return "";
+}
+
+std::string FixedSizeListType::ComputeFingerprint() const {
+ const auto& child_fingerprint = children_[0]->fingerprint();
+ if (!child_fingerprint.empty()) {
+ std::stringstream ss;
+ ss << TypeIdFingerprint(*this) << "[" << list_size_ << "]"
+ << "{" << child_fingerprint << "}";
+ return ss.str();
+ }
+ return "";
+}
+
+std::string FixedSizeBinaryType::ComputeFingerprint() const {
+ std::stringstream ss;
+ ss << TypeIdFingerprint(*this) << "[" << byte_width_ << "]";
+ return ss.str();
+}
+
+std::string DecimalType::ComputeFingerprint() const {
+ std::stringstream ss;
+ ss << TypeIdFingerprint(*this) << "[" << byte_width_ << "," << precision_ << ","
+ << scale_ << "]";
+ return ss.str();
+}
+
+std::string StructType::ComputeFingerprint() const {
+ std::stringstream ss;
+ ss << TypeIdFingerprint(*this) << "{";
+ for (const auto& child : children_) {
+ const auto& child_fingerprint = child->fingerprint();
+ if (child_fingerprint.empty()) {
+ return "";
+ }
+ ss << child_fingerprint << ";";
+ }
+ ss << "}";
+ return ss.str();
+}
+
+std::string UnionType::ComputeFingerprint() const {
+ std::stringstream ss;
+ ss << TypeIdFingerprint(*this);
+ switch (mode()) {
+ case UnionMode::SPARSE:
+ ss << "[s";
+ break;
+ case UnionMode::DENSE:
+ ss << "[d";
+ break;
+ default:
+ DCHECK(false) << "Unexpected UnionMode";
+ }
+ for (const auto code : type_codes_) {
+ // Represent code as integer, not raw character
+ ss << ':' << static_cast<int32_t>(code);
+ }
+ ss << "]{";
+ for (const auto& child : children_) {
+ const auto& child_fingerprint = child->fingerprint();
+ if (child_fingerprint.empty()) {
+ return "";
+ }
+ ss << child_fingerprint << ";";
+ }
+ ss << "}";
+ return ss.str();
+}
+
+std::string TimeType::ComputeFingerprint() const {
+ std::stringstream ss;
+ ss << TypeIdFingerprint(*this) << TimeUnitFingerprint(unit_);
+ return ss.str();
+}
+
+std::string TimestampType::ComputeFingerprint() const {
+ std::stringstream ss;
+ ss << TypeIdFingerprint(*this) << TimeUnitFingerprint(unit_) << timezone_.length()
+ << ':' << timezone_;
+ return ss.str();
+}
+
+std::string IntervalType::ComputeFingerprint() const {
+ std::stringstream ss;
+ ss << TypeIdFingerprint(*this) << IntervalTypeFingerprint(interval_type());
+ return ss.str();
+}
+
+std::string DurationType::ComputeFingerprint() const {
+ std::stringstream ss;
+ ss << TypeIdFingerprint(*this) << TimeUnitFingerprint(unit_);
+ return ss.str();
+}
+
+// ----------------------------------------------------------------------
+// Visitors and factory functions
+
+Status DataType::Accept(TypeVisitor* visitor) const {
+ return VisitTypeInline(*this, visitor);
+}
+
+#define TYPE_FACTORY(NAME, KLASS) \
+ std::shared_ptr<DataType> NAME() { \
+ static std::shared_ptr<DataType> result = std::make_shared<KLASS>(); \
+ return result; \
+ }
+
+TYPE_FACTORY(null, NullType)
+TYPE_FACTORY(boolean, BooleanType)
+TYPE_FACTORY(int8, Int8Type)
+TYPE_FACTORY(uint8, UInt8Type)
+TYPE_FACTORY(int16, Int16Type)
+TYPE_FACTORY(uint16, UInt16Type)
+TYPE_FACTORY(int32, Int32Type)
+TYPE_FACTORY(uint32, UInt32Type)
+TYPE_FACTORY(int64, Int64Type)
+TYPE_FACTORY(uint64, UInt64Type)
+TYPE_FACTORY(float16, HalfFloatType)
+TYPE_FACTORY(float32, FloatType)
+TYPE_FACTORY(float64, DoubleType)
+TYPE_FACTORY(utf8, StringType)
+TYPE_FACTORY(large_utf8, LargeStringType)
+TYPE_FACTORY(binary, BinaryType)
+TYPE_FACTORY(large_binary, LargeBinaryType)
+TYPE_FACTORY(date64, Date64Type)
+TYPE_FACTORY(date32, Date32Type)
+
+std::shared_ptr<DataType> fixed_size_binary(int32_t byte_width) {
+ return std::make_shared<FixedSizeBinaryType>(byte_width);
+}
+
+std::shared_ptr<DataType> duration(TimeUnit::type unit) {
+ return std::make_shared<DurationType>(unit);
+}
+
+std::shared_ptr<DataType> day_time_interval() {
+ return std::make_shared<DayTimeIntervalType>();
+}
+
+std::shared_ptr<DataType> month_day_nano_interval() {
+ return std::make_shared<MonthDayNanoIntervalType>();
+}
+
+std::shared_ptr<DataType> month_interval() {
+ return std::make_shared<MonthIntervalType>();
+}
+
+std::shared_ptr<DataType> timestamp(TimeUnit::type unit) {
+ return std::make_shared<TimestampType>(unit);
+}
+
+std::shared_ptr<DataType> timestamp(TimeUnit::type unit, const std::string& timezone) {
+ return std::make_shared<TimestampType>(unit, timezone);
+}
+
+std::shared_ptr<DataType> time32(TimeUnit::type unit) {
+ return std::make_shared<Time32Type>(unit);
+}
+
+std::shared_ptr<DataType> time64(TimeUnit::type unit) {
+ return std::make_shared<Time64Type>(unit);
+}
+
+std::shared_ptr<DataType> list(const std::shared_ptr<DataType>& value_type) {
+ return std::make_shared<ListType>(value_type);
+}
+
+std::shared_ptr<DataType> list(const std::shared_ptr<Field>& value_field) {
+ return std::make_shared<ListType>(value_field);
+}
+
+std::shared_ptr<DataType> large_list(const std::shared_ptr<DataType>& value_type) {
+ return std::make_shared<LargeListType>(value_type);
+}
+
+std::shared_ptr<DataType> large_list(const std::shared_ptr<Field>& value_field) {
+ return std::make_shared<LargeListType>(value_field);
+}
+
+std::shared_ptr<DataType> map(std::shared_ptr<DataType> key_type,
+ std::shared_ptr<DataType> item_type, bool keys_sorted) {
+ return std::make_shared<MapType>(std::move(key_type), std::move(item_type),
+ keys_sorted);
+}
+
+std::shared_ptr<DataType> map(std::shared_ptr<DataType> key_type,
+ std::shared_ptr<Field> item_field, bool keys_sorted) {
+ return std::make_shared<MapType>(std::move(key_type), std::move(item_field),
+ keys_sorted);
+}
+
+std::shared_ptr<DataType> fixed_size_list(const std::shared_ptr<DataType>& value_type,
+ int32_t list_size) {
+ return std::make_shared<FixedSizeListType>(value_type, list_size);
+}
+
+std::shared_ptr<DataType> fixed_size_list(const std::shared_ptr<Field>& value_field,
+ int32_t list_size) {
+ return std::make_shared<FixedSizeListType>(value_field, list_size);
+}
+
+std::shared_ptr<DataType> struct_(const std::vector<std::shared_ptr<Field>>& fields) {
+ return std::make_shared<StructType>(fields);
+}
+
+std::shared_ptr<DataType> sparse_union(FieldVector child_fields,
+ std::vector<int8_t> type_codes) {
+ if (type_codes.empty()) {
+ type_codes = internal::Iota(static_cast<int8_t>(child_fields.size()));
+ }
+ return std::make_shared<SparseUnionType>(std::move(child_fields),
+ std::move(type_codes));
+}
+std::shared_ptr<DataType> dense_union(FieldVector child_fields,
+ std::vector<int8_t> type_codes) {
+ if (type_codes.empty()) {
+ type_codes = internal::Iota(static_cast<int8_t>(child_fields.size()));
+ }
+ return std::make_shared<DenseUnionType>(std::move(child_fields), std::move(type_codes));
+}
+
+FieldVector FieldsFromArraysAndNames(std::vector<std::string> names,
+ const ArrayVector& arrays) {
+ FieldVector fields(arrays.size());
+ int i = 0;
+ if (names.empty()) {
+ for (const auto& array : arrays) {
+ fields[i] = field(std::to_string(i), array->type());
+ ++i;
+ }
+ } else {
+ DCHECK_EQ(names.size(), arrays.size());
+ for (const auto& array : arrays) {
+ fields[i] = field(std::move(names[i]), array->type());
+ ++i;
+ }
+ }
+ return fields;
+}
+
+std::shared_ptr<DataType> sparse_union(const ArrayVector& children,
+ std::vector<std::string> field_names,
+ std::vector<int8_t> type_codes) {
+ if (type_codes.empty()) {
+ type_codes = internal::Iota(static_cast<int8_t>(children.size()));
+ }
+ auto fields = FieldsFromArraysAndNames(std::move(field_names), children);
+ return sparse_union(std::move(fields), std::move(type_codes));
+}
+
+std::shared_ptr<DataType> dense_union(const ArrayVector& children,
+ std::vector<std::string> field_names,
+ std::vector<int8_t> type_codes) {
+ if (type_codes.empty()) {
+ type_codes = internal::Iota(static_cast<int8_t>(children.size()));
+ }
+ auto fields = FieldsFromArraysAndNames(std::move(field_names), children);
+ return dense_union(std::move(fields), std::move(type_codes));
+}
+
+std::shared_ptr<DataType> dictionary(const std::shared_ptr<DataType>& index_type,
+ const std::shared_ptr<DataType>& dict_type,
+ bool ordered) {
+ return std::make_shared<DictionaryType>(index_type, dict_type, ordered);
+}
+
+std::shared_ptr<Field> field(std::string name, std::shared_ptr<DataType> type,
+ bool nullable,
+ std::shared_ptr<const KeyValueMetadata> metadata) {
+ return std::make_shared<Field>(std::move(name), std::move(type), nullable,
+ std::move(metadata));
+}
+
+std::shared_ptr<Field> field(std::string name, std::shared_ptr<DataType> type,
+ std::shared_ptr<const KeyValueMetadata> metadata) {
+ return std::make_shared<Field>(std::move(name), std::move(type), /*nullable=*/true,
+ std::move(metadata));
+}
+
+std::shared_ptr<DataType> decimal(int32_t precision, int32_t scale) {
+ return precision <= Decimal128Type::kMaxPrecision ? decimal128(precision, scale)
+ : decimal256(precision, scale);
+}
+
+std::shared_ptr<DataType> decimal128(int32_t precision, int32_t scale) {
+ return std::make_shared<Decimal128Type>(precision, scale);
+}
+
+std::shared_ptr<DataType> decimal256(int32_t precision, int32_t scale) {
+ return std::make_shared<Decimal256Type>(precision, scale);
+}
+
+std::string Decimal128Type::ToString() const {
+ std::stringstream s;
+ s << "decimal128(" << precision_ << ", " << scale_ << ")";
+ return s.str();
+}
+
+std::string Decimal256Type::ToString() const {
+ std::stringstream s;
+ s << "decimal256(" << precision_ << ", " << scale_ << ")";
+ return s.str();
+}
+
+namespace {
+
+std::vector<std::shared_ptr<DataType>> g_signed_int_types;
+std::vector<std::shared_ptr<DataType>> g_unsigned_int_types;
+std::vector<std::shared_ptr<DataType>> g_int_types;
+std::vector<std::shared_ptr<DataType>> g_floating_types;
+std::vector<std::shared_ptr<DataType>> g_numeric_types;
+std::vector<std::shared_ptr<DataType>> g_base_binary_types;
+std::vector<std::shared_ptr<DataType>> g_temporal_types;
+std::vector<std::shared_ptr<DataType>> g_interval_types;
+std::vector<std::shared_ptr<DataType>> g_primitive_types;
+std::once_flag static_data_initialized;
+
+template <typename T>
+void Extend(const std::vector<T>& values, std::vector<T>* out) {
+ out->insert(out->end(), values.begin(), values.end());
+}
+
+void InitStaticData() {
+ // Signed int types
+ g_signed_int_types = {int8(), int16(), int32(), int64()};
+
+ // Unsigned int types
+ g_unsigned_int_types = {uint8(), uint16(), uint32(), uint64()};
+
+ // All int types
+ Extend(g_unsigned_int_types, &g_int_types);
+ Extend(g_signed_int_types, &g_int_types);
+
+ // Floating point types
+ g_floating_types = {float32(), float64()};
+
+ // Numeric types
+ Extend(g_int_types, &g_numeric_types);
+ Extend(g_floating_types, &g_numeric_types);
+
+ // Temporal types
+ g_temporal_types = {date32(),
+ date64(),
+ time32(TimeUnit::SECOND),
+ time32(TimeUnit::MILLI),
+ time64(TimeUnit::MICRO),
+ time64(TimeUnit::NANO),
+ timestamp(TimeUnit::SECOND),
+ timestamp(TimeUnit::MILLI),
+ timestamp(TimeUnit::MICRO),
+ timestamp(TimeUnit::NANO)};
+
+ // Interval types
+ g_interval_types = {day_time_interval(), month_interval()};
+
+ // Base binary types (without FixedSizeBinary)
+ g_base_binary_types = {binary(), utf8(), large_binary(), large_utf8()};
+
+ // Non-parametric, non-nested types. This also DOES NOT include
+ //
+ // * Decimal
+ // * Fixed Size Binary
+ // * Time32
+ // * Time64
+ // * Timestamp
+ g_primitive_types = {null(), boolean(), date32(), date64()};
+ Extend(g_numeric_types, &g_primitive_types);
+ Extend(g_base_binary_types, &g_primitive_types);
+}
+
+} // namespace
+
+const std::vector<std::shared_ptr<DataType>>& BaseBinaryTypes() {
+ std::call_once(static_data_initialized, InitStaticData);
+ return g_base_binary_types;
+}
+
+const std::vector<std::shared_ptr<DataType>>& StringTypes() {
+ static DataTypeVector types = {utf8(), large_utf8()};
+ return types;
+}
+
+const std::vector<std::shared_ptr<DataType>>& SignedIntTypes() {
+ std::call_once(static_data_initialized, InitStaticData);
+ return g_signed_int_types;
+}
+
+const std::vector<std::shared_ptr<DataType>>& UnsignedIntTypes() {
+ std::call_once(static_data_initialized, InitStaticData);
+ return g_unsigned_int_types;
+}
+
+const std::vector<std::shared_ptr<DataType>>& IntTypes() {
+ std::call_once(static_data_initialized, InitStaticData);
+ return g_int_types;
+}
+
+const std::vector<std::shared_ptr<DataType>>& FloatingPointTypes() {
+ std::call_once(static_data_initialized, InitStaticData);
+ return g_floating_types;
+}
+
+const std::vector<std::shared_ptr<DataType>>& NumericTypes() {
+ std::call_once(static_data_initialized, InitStaticData);
+ return g_numeric_types;
+}
+
+const std::vector<std::shared_ptr<DataType>>& TemporalTypes() {
+ std::call_once(static_data_initialized, InitStaticData);
+ return g_temporal_types;
+}
+
+const std::vector<std::shared_ptr<DataType>>& IntervalTypes() {
+ std::call_once(static_data_initialized, InitStaticData);
+ return g_interval_types;
+}
+
+const std::vector<std::shared_ptr<DataType>>& PrimitiveTypes() {
+ std::call_once(static_data_initialized, InitStaticData);
+ return g_primitive_types;
+}
+
+const std::vector<TimeUnit::type>& TimeUnit::values() {
+ static std::vector<TimeUnit::type> units = {TimeUnit::SECOND, TimeUnit::MILLI,
+ TimeUnit::MICRO, TimeUnit::NANO};
+ return units;
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/type.h b/src/arrow/cpp/src/arrow/type.h
new file mode 100644
index 000000000..23e6c7e9e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/type.h
@@ -0,0 +1,2041 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <atomic>
+#include <climits>
+#include <cstdint>
+#include <iosfwd>
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/result.h"
+#include "arrow/type_fwd.h" // IWYU pragma: export
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/variant.h"
+#include "arrow/util/visibility.h"
+#include "arrow/visitor.h" // IWYU pragma: keep
+
+namespace arrow {
+namespace detail {
+
+/// \defgroup numeric-datatypes Datatypes for numeric data
+/// @{
+/// @}
+
+/// \defgroup binary-datatypes Datatypes for binary/string data
+/// @{
+/// @}
+
+/// \defgroup temporal-datatypes Datatypes for temporal data
+/// @{
+/// @}
+
+/// \defgroup nested-datatypes Datatypes for nested data
+/// @{
+/// @}
+
+class ARROW_EXPORT Fingerprintable {
+ public:
+ virtual ~Fingerprintable();
+
+ const std::string& fingerprint() const {
+ auto p = fingerprint_.load();
+ if (ARROW_PREDICT_TRUE(p != NULLPTR)) {
+ return *p;
+ }
+ return LoadFingerprintSlow();
+ }
+
+ const std::string& metadata_fingerprint() const {
+ auto p = metadata_fingerprint_.load();
+ if (ARROW_PREDICT_TRUE(p != NULLPTR)) {
+ return *p;
+ }
+ return LoadMetadataFingerprintSlow();
+ }
+
+ protected:
+ const std::string& LoadFingerprintSlow() const;
+ const std::string& LoadMetadataFingerprintSlow() const;
+
+ virtual std::string ComputeFingerprint() const = 0;
+ virtual std::string ComputeMetadataFingerprint() const = 0;
+
+ mutable std::atomic<std::string*> fingerprint_;
+ mutable std::atomic<std::string*> metadata_fingerprint_;
+};
+
+} // namespace detail
+
+/// EXPERIMENTAL: Layout specification for a data type
+struct ARROW_EXPORT DataTypeLayout {
+ enum BufferKind { FIXED_WIDTH, VARIABLE_WIDTH, BITMAP, ALWAYS_NULL };
+
+ /// Layout specification for a single data type buffer
+ struct BufferSpec {
+ BufferKind kind;
+ int64_t byte_width; // For FIXED_WIDTH
+
+ bool operator==(const BufferSpec& other) const {
+ return kind == other.kind &&
+ (kind != FIXED_WIDTH || byte_width == other.byte_width);
+ }
+ bool operator!=(const BufferSpec& other) const { return !(*this == other); }
+ };
+
+ static BufferSpec FixedWidth(int64_t w) { return BufferSpec{FIXED_WIDTH, w}; }
+ static BufferSpec VariableWidth() { return BufferSpec{VARIABLE_WIDTH, -1}; }
+ static BufferSpec Bitmap() { return BufferSpec{BITMAP, -1}; }
+ static BufferSpec AlwaysNull() { return BufferSpec{ALWAYS_NULL, -1}; }
+
+ /// A vector of buffer layout specifications, one for each expected buffer
+ std::vector<BufferSpec> buffers;
+ /// Whether this type expects an associated dictionary array.
+ bool has_dictionary = false;
+
+ explicit DataTypeLayout(std::vector<BufferSpec> v) : buffers(std::move(v)) {}
+};
+
+/// \brief Base class for all data types
+///
+/// Data types in this library are all *logical*. They can be expressed as
+/// either a primitive physical type (bytes or bits of some fixed size), a
+/// nested type consisting of other data types, or another data type (e.g. a
+/// timestamp encoded as an int64).
+///
+/// Simple datatypes may be entirely described by their Type::type id, but
+/// complex datatypes are usually parametric.
+class ARROW_EXPORT DataType : public detail::Fingerprintable {
+ public:
+ explicit DataType(Type::type id) : detail::Fingerprintable(), id_(id) {}
+ ~DataType() override;
+
+ /// \brief Return whether the types are equal
+ ///
+ /// Types that are logically convertible from one to another (e.g. List<UInt8>
+ /// and Binary) are NOT equal.
+ bool Equals(const DataType& other, bool check_metadata = false) const;
+
+ /// \brief Return whether the types are equal
+ bool Equals(const std::shared_ptr<DataType>& other) const;
+
+ /// \brief Return the child field at index i.
+ const std::shared_ptr<Field>& field(int i) const { return children_[i]; }
+
+ /// \brief Return the children fields associated with this type.
+ const std::vector<std::shared_ptr<Field>>& fields() const { return children_; }
+
+ /// \brief Return the number of children fields associated with this type.
+ int num_fields() const { return static_cast<int>(children_.size()); }
+
+ Status Accept(TypeVisitor* visitor) const;
+
+ /// \brief A string representation of the type, including any children
+ virtual std::string ToString() const = 0;
+
+ /// \brief Return hash value (excluding metadata in child fields)
+ size_t Hash() const;
+
+ /// \brief A string name of the type, omitting any child fields
+ ///
+ /// \since 0.7.0
+ virtual std::string name() const = 0;
+
+ /// \brief Return the data type layout. Children are not included.
+ ///
+ /// \note Experimental API
+ virtual DataTypeLayout layout() const = 0;
+
+ /// \brief Return the type category
+ Type::type id() const { return id_; }
+
+ protected:
+ // Dummy version that returns a null string (indicating not implemented).
+ // Subclasses should override for fast equality checks.
+ std::string ComputeFingerprint() const override;
+
+ // Generic versions that works for all regular types, nested or not.
+ std::string ComputeMetadataFingerprint() const override;
+
+ Type::type id_;
+ std::vector<std::shared_ptr<Field>> children_;
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(DataType);
+};
+
+ARROW_EXPORT
+std::ostream& operator<<(std::ostream& os, const DataType& type);
+
+/// \brief Return the compatible physical data type
+///
+/// Some types may have distinct logical meanings but the exact same physical
+/// representation. For example, TimestampType has Int64Type as a physical
+/// type (defined as TimestampType::PhysicalType).
+///
+/// The return value is as follows:
+/// - if a `PhysicalType` alias exists in the concrete type class, return
+/// an instance of `PhysicalType`.
+/// - otherwise, return the input type itself.
+std::shared_ptr<DataType> GetPhysicalType(const std::shared_ptr<DataType>& type);
+
+/// \brief Base class for all fixed-width data types
+class ARROW_EXPORT FixedWidthType : public DataType {
+ public:
+ using DataType::DataType;
+
+ virtual int bit_width() const = 0;
+};
+
+/// \brief Base class for all data types representing primitive values
+class ARROW_EXPORT PrimitiveCType : public FixedWidthType {
+ public:
+ using FixedWidthType::FixedWidthType;
+};
+
+/// \brief Base class for all numeric data types
+class ARROW_EXPORT NumberType : public PrimitiveCType {
+ public:
+ using PrimitiveCType::PrimitiveCType;
+};
+
+/// \brief Base class for all integral data types
+class ARROW_EXPORT IntegerType : public NumberType {
+ public:
+ using NumberType::NumberType;
+ virtual bool is_signed() const = 0;
+};
+
+/// \brief Base class for all floating-point data types
+class ARROW_EXPORT FloatingPointType : public NumberType {
+ public:
+ using NumberType::NumberType;
+ enum Precision { HALF, SINGLE, DOUBLE };
+ virtual Precision precision() const = 0;
+};
+
+/// \brief Base class for all parametric data types
+class ParametricType {};
+
+class ARROW_EXPORT NestedType : public DataType, public ParametricType {
+ public:
+ using DataType::DataType;
+};
+
+/// \brief The combination of a field name and data type, with optional metadata
+///
+/// Fields are used to describe the individual constituents of a
+/// nested DataType or a Schema.
+///
+/// A field's metadata is represented by a KeyValueMetadata instance,
+/// which holds arbitrary key-value pairs.
+class ARROW_EXPORT Field : public detail::Fingerprintable {
+ public:
+ Field(std::string name, std::shared_ptr<DataType> type, bool nullable = true,
+ std::shared_ptr<const KeyValueMetadata> metadata = NULLPTR)
+ : detail::Fingerprintable(),
+ name_(std::move(name)),
+ type_(std::move(type)),
+ nullable_(nullable),
+ metadata_(std::move(metadata)) {}
+
+ ~Field() override;
+
+ /// \brief Return the field's attached metadata
+ std::shared_ptr<const KeyValueMetadata> metadata() const { return metadata_; }
+
+ /// \brief Return whether the field has non-empty metadata
+ bool HasMetadata() const;
+
+ /// \brief Return a copy of this field with the given metadata attached to it
+ std::shared_ptr<Field> WithMetadata(
+ const std::shared_ptr<const KeyValueMetadata>& metadata) const;
+
+ /// \brief EXPERIMENTAL: Return a copy of this field with the given metadata
+ /// merged with existing metadata (any colliding keys will be overridden by
+ /// the passed metadata)
+ std::shared_ptr<Field> WithMergedMetadata(
+ const std::shared_ptr<const KeyValueMetadata>& metadata) const;
+
+ /// \brief Return a copy of this field without any metadata attached to it
+ std::shared_ptr<Field> RemoveMetadata() const;
+
+ /// \brief Return a copy of this field with the replaced type.
+ std::shared_ptr<Field> WithType(const std::shared_ptr<DataType>& type) const;
+
+ /// \brief Return a copy of this field with the replaced name.
+ std::shared_ptr<Field> WithName(const std::string& name) const;
+
+ /// \brief Return a copy of this field with the replaced nullability.
+ std::shared_ptr<Field> WithNullable(bool nullable) const;
+
+ /// \brief Options that control the behavior of `MergeWith`.
+ /// Options are to be added to allow type conversions, including integer
+ /// widening, promotion from integer to float, or conversion to or from boolean.
+ struct MergeOptions {
+ /// If true, a Field of NullType can be unified with a Field of another type.
+ /// The unified field will be of the other type and become nullable.
+ /// Nullability will be promoted to the looser option (nullable if one is not
+ /// nullable).
+ bool promote_nullability = true;
+
+ static MergeOptions Defaults() { return MergeOptions(); }
+ };
+
+ /// \brief Merge the current field with a field of the same name.
+ ///
+ /// The two fields must be compatible, i.e:
+ /// - have the same name
+ /// - have the same type, or of compatible types according to `options`.
+ ///
+ /// The metadata of the current field is preserved; the metadata of the other
+ /// field is discarded.
+ Result<std::shared_ptr<Field>> MergeWith(
+ const Field& other, MergeOptions options = MergeOptions::Defaults()) const;
+ Result<std::shared_ptr<Field>> MergeWith(
+ const std::shared_ptr<Field>& other,
+ MergeOptions options = MergeOptions::Defaults()) const;
+
+ std::vector<std::shared_ptr<Field>> Flatten() const;
+
+ /// \brief Indicate if fields are equals.
+ ///
+ /// \param[in] other field to check equality with.
+ /// \param[in] check_metadata controls if it should check for metadata
+ /// equality.
+ ///
+ /// \return true if fields are equal, false otherwise.
+ bool Equals(const Field& other, bool check_metadata = false) const;
+ bool Equals(const std::shared_ptr<Field>& other, bool check_metadata = false) const;
+
+ /// \brief Indicate if fields are compatibles.
+ ///
+ /// See the criteria of MergeWith.
+ ///
+ /// \return true if fields are compatible, false otherwise.
+ bool IsCompatibleWith(const Field& other) const;
+ bool IsCompatibleWith(const std::shared_ptr<Field>& other) const;
+
+ /// \brief Return a string representation ot the field
+ /// \param[in] show_metadata when true, if KeyValueMetadata is non-empty,
+ /// print keys and values in the output
+ std::string ToString(bool show_metadata = false) const;
+
+ /// \brief Return the field name
+ const std::string& name() const { return name_; }
+ /// \brief Return the field data type
+ const std::shared_ptr<DataType>& type() const { return type_; }
+ /// \brief Return whether the field is nullable
+ bool nullable() const { return nullable_; }
+
+ std::shared_ptr<Field> Copy() const;
+
+ private:
+ std::string ComputeFingerprint() const override;
+ std::string ComputeMetadataFingerprint() const override;
+
+ // Field name
+ std::string name_;
+
+ // The field's data type
+ std::shared_ptr<DataType> type_;
+
+ // Fields can be nullable
+ bool nullable_;
+
+ // The field's metadata, if any
+ std::shared_ptr<const KeyValueMetadata> metadata_;
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Field);
+};
+
+namespace detail {
+
+template <typename DERIVED, typename BASE, Type::type TYPE_ID, typename C_TYPE>
+class ARROW_EXPORT CTypeImpl : public BASE {
+ public:
+ static constexpr Type::type type_id = TYPE_ID;
+ using c_type = C_TYPE;
+ using PhysicalType = DERIVED;
+
+ CTypeImpl() : BASE(TYPE_ID) {}
+
+ int bit_width() const override { return static_cast<int>(sizeof(C_TYPE) * CHAR_BIT); }
+
+ DataTypeLayout layout() const override {
+ return DataTypeLayout(
+ {DataTypeLayout::Bitmap(), DataTypeLayout::FixedWidth(sizeof(C_TYPE))});
+ }
+
+ std::string name() const override { return DERIVED::type_name(); }
+
+ std::string ToString() const override { return this->name(); }
+};
+
+template <typename DERIVED, typename BASE, Type::type TYPE_ID, typename C_TYPE>
+constexpr Type::type CTypeImpl<DERIVED, BASE, TYPE_ID, C_TYPE>::type_id;
+
+template <typename DERIVED, Type::type TYPE_ID, typename C_TYPE>
+class IntegerTypeImpl : public detail::CTypeImpl<DERIVED, IntegerType, TYPE_ID, C_TYPE> {
+ bool is_signed() const override { return std::is_signed<C_TYPE>::value; }
+};
+
+} // namespace detail
+
+/// Concrete type class for always-null data
+class ARROW_EXPORT NullType : public DataType {
+ public:
+ static constexpr Type::type type_id = Type::NA;
+
+ static constexpr const char* type_name() { return "null"; }
+
+ NullType() : DataType(Type::NA) {}
+
+ std::string ToString() const override;
+
+ DataTypeLayout layout() const override {
+ return DataTypeLayout({DataTypeLayout::AlwaysNull()});
+ }
+
+ std::string name() const override { return "null"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// Concrete type class for boolean data
+class ARROW_EXPORT BooleanType
+ : public detail::CTypeImpl<BooleanType, PrimitiveCType, Type::BOOL, bool> {
+ public:
+ static constexpr const char* type_name() { return "bool"; }
+
+ // BooleanType within arrow use a single bit instead of the C 8-bits layout.
+ int bit_width() const final { return 1; }
+
+ DataTypeLayout layout() const override {
+ return DataTypeLayout({DataTypeLayout::Bitmap(), DataTypeLayout::Bitmap()});
+ }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// \addtogroup numeric-datatypes
+///
+/// @{
+
+/// Concrete type class for unsigned 8-bit integer data
+class ARROW_EXPORT UInt8Type
+ : public detail::IntegerTypeImpl<UInt8Type, Type::UINT8, uint8_t> {
+ public:
+ static constexpr const char* type_name() { return "uint8"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// Concrete type class for signed 8-bit integer data
+class ARROW_EXPORT Int8Type
+ : public detail::IntegerTypeImpl<Int8Type, Type::INT8, int8_t> {
+ public:
+ static constexpr const char* type_name() { return "int8"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// Concrete type class for unsigned 16-bit integer data
+class ARROW_EXPORT UInt16Type
+ : public detail::IntegerTypeImpl<UInt16Type, Type::UINT16, uint16_t> {
+ public:
+ static constexpr const char* type_name() { return "uint16"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// Concrete type class for signed 16-bit integer data
+class ARROW_EXPORT Int16Type
+ : public detail::IntegerTypeImpl<Int16Type, Type::INT16, int16_t> {
+ public:
+ static constexpr const char* type_name() { return "int16"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// Concrete type class for unsigned 32-bit integer data
+class ARROW_EXPORT UInt32Type
+ : public detail::IntegerTypeImpl<UInt32Type, Type::UINT32, uint32_t> {
+ public:
+ static constexpr const char* type_name() { return "uint32"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// Concrete type class for signed 32-bit integer data
+class ARROW_EXPORT Int32Type
+ : public detail::IntegerTypeImpl<Int32Type, Type::INT32, int32_t> {
+ public:
+ static constexpr const char* type_name() { return "int32"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// Concrete type class for unsigned 64-bit integer data
+class ARROW_EXPORT UInt64Type
+ : public detail::IntegerTypeImpl<UInt64Type, Type::UINT64, uint64_t> {
+ public:
+ static constexpr const char* type_name() { return "uint64"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// Concrete type class for signed 64-bit integer data
+class ARROW_EXPORT Int64Type
+ : public detail::IntegerTypeImpl<Int64Type, Type::INT64, int64_t> {
+ public:
+ static constexpr const char* type_name() { return "int64"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// Concrete type class for 16-bit floating-point data
+class ARROW_EXPORT HalfFloatType
+ : public detail::CTypeImpl<HalfFloatType, FloatingPointType, Type::HALF_FLOAT,
+ uint16_t> {
+ public:
+ Precision precision() const override;
+ static constexpr const char* type_name() { return "halffloat"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// Concrete type class for 32-bit floating-point data (C "float")
+class ARROW_EXPORT FloatType
+ : public detail::CTypeImpl<FloatType, FloatingPointType, Type::FLOAT, float> {
+ public:
+ Precision precision() const override;
+ static constexpr const char* type_name() { return "float"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// Concrete type class for 64-bit floating-point data (C "double")
+class ARROW_EXPORT DoubleType
+ : public detail::CTypeImpl<DoubleType, FloatingPointType, Type::DOUBLE, double> {
+ public:
+ Precision precision() const override;
+ static constexpr const char* type_name() { return "double"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// @}
+
+/// \brief Base class for all variable-size binary data types
+class ARROW_EXPORT BaseBinaryType : public DataType {
+ public:
+ using DataType::DataType;
+};
+
+constexpr int64_t kBinaryMemoryLimit = std::numeric_limits<int32_t>::max() - 1;
+
+/// \addtogroup binary-datatypes
+///
+/// @{
+
+/// \brief Concrete type class for variable-size binary data
+class ARROW_EXPORT BinaryType : public BaseBinaryType {
+ public:
+ static constexpr Type::type type_id = Type::BINARY;
+ static constexpr bool is_utf8 = false;
+ using offset_type = int32_t;
+ using PhysicalType = BinaryType;
+
+ static constexpr const char* type_name() { return "binary"; }
+
+ BinaryType() : BinaryType(Type::BINARY) {}
+
+ DataTypeLayout layout() const override {
+ return DataTypeLayout({DataTypeLayout::Bitmap(),
+ DataTypeLayout::FixedWidth(sizeof(offset_type)),
+ DataTypeLayout::VariableWidth()});
+ }
+
+ std::string ToString() const override;
+ std::string name() const override { return "binary"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+
+ // Allow subclasses like StringType to change the logical type.
+ explicit BinaryType(Type::type logical_type) : BaseBinaryType(logical_type) {}
+};
+
+/// \brief Concrete type class for large variable-size binary data
+class ARROW_EXPORT LargeBinaryType : public BaseBinaryType {
+ public:
+ static constexpr Type::type type_id = Type::LARGE_BINARY;
+ static constexpr bool is_utf8 = false;
+ using offset_type = int64_t;
+ using PhysicalType = LargeBinaryType;
+
+ static constexpr const char* type_name() { return "large_binary"; }
+
+ LargeBinaryType() : LargeBinaryType(Type::LARGE_BINARY) {}
+
+ DataTypeLayout layout() const override {
+ return DataTypeLayout({DataTypeLayout::Bitmap(),
+ DataTypeLayout::FixedWidth(sizeof(offset_type)),
+ DataTypeLayout::VariableWidth()});
+ }
+
+ std::string ToString() const override;
+ std::string name() const override { return "large_binary"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+
+ // Allow subclasses like LargeStringType to change the logical type.
+ explicit LargeBinaryType(Type::type logical_type) : BaseBinaryType(logical_type) {}
+};
+
+/// \brief Concrete type class for variable-size string data, utf8-encoded
+class ARROW_EXPORT StringType : public BinaryType {
+ public:
+ static constexpr Type::type type_id = Type::STRING;
+ static constexpr bool is_utf8 = true;
+ using PhysicalType = BinaryType;
+
+ static constexpr const char* type_name() { return "utf8"; }
+
+ StringType() : BinaryType(Type::STRING) {}
+
+ std::string ToString() const override;
+ std::string name() const override { return "utf8"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// \brief Concrete type class for large variable-size string data, utf8-encoded
+class ARROW_EXPORT LargeStringType : public LargeBinaryType {
+ public:
+ static constexpr Type::type type_id = Type::LARGE_STRING;
+ static constexpr bool is_utf8 = true;
+ using PhysicalType = LargeBinaryType;
+
+ static constexpr const char* type_name() { return "large_utf8"; }
+
+ LargeStringType() : LargeBinaryType(Type::LARGE_STRING) {}
+
+ std::string ToString() const override;
+ std::string name() const override { return "large_utf8"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// \brief Concrete type class for fixed-size binary data
+class ARROW_EXPORT FixedSizeBinaryType : public FixedWidthType, public ParametricType {
+ public:
+ static constexpr Type::type type_id = Type::FIXED_SIZE_BINARY;
+ static constexpr bool is_utf8 = false;
+
+ static constexpr const char* type_name() { return "fixed_size_binary"; }
+
+ explicit FixedSizeBinaryType(int32_t byte_width)
+ : FixedWidthType(Type::FIXED_SIZE_BINARY), byte_width_(byte_width) {}
+ explicit FixedSizeBinaryType(int32_t byte_width, Type::type override_type_id)
+ : FixedWidthType(override_type_id), byte_width_(byte_width) {}
+
+ std::string ToString() const override;
+ std::string name() const override { return "fixed_size_binary"; }
+
+ DataTypeLayout layout() const override {
+ return DataTypeLayout(
+ {DataTypeLayout::Bitmap(), DataTypeLayout::FixedWidth(byte_width())});
+ }
+
+ int32_t byte_width() const { return byte_width_; }
+ int bit_width() const override;
+
+ // Validating constructor
+ static Result<std::shared_ptr<DataType>> Make(int32_t byte_width);
+
+ protected:
+ std::string ComputeFingerprint() const override;
+
+ int32_t byte_width_;
+};
+
+/// @}
+
+/// \addtogroup numeric-datatypes
+///
+/// @{
+
+/// \brief Base type class for (fixed-size) decimal data
+class ARROW_EXPORT DecimalType : public FixedSizeBinaryType {
+ public:
+ explicit DecimalType(Type::type type_id, int32_t byte_width, int32_t precision,
+ int32_t scale)
+ : FixedSizeBinaryType(byte_width, type_id), precision_(precision), scale_(scale) {}
+
+ /// Constructs concrete decimal types
+ static Result<std::shared_ptr<DataType>> Make(Type::type type_id, int32_t precision,
+ int32_t scale);
+
+ int32_t precision() const { return precision_; }
+ int32_t scale() const { return scale_; }
+
+ /// \brief Returns the number of bytes needed for precision.
+ ///
+ /// precision must be >= 1
+ static int32_t DecimalSize(int32_t precision);
+
+ protected:
+ std::string ComputeFingerprint() const override;
+
+ int32_t precision_;
+ int32_t scale_;
+};
+
+/// \brief Concrete type class for 128-bit decimal data
+///
+/// Arrow decimals are fixed-point decimal numbers encoded as a scaled
+/// integer. The precision is the number of significant digits that the
+/// decimal type can represent; the scale is the number of digits after
+/// the decimal point (note the scale can be negative).
+///
+/// As an example, `Decimal128Type(7, 3)` can exactly represent the numbers
+/// 1234.567 and -1234.567 (encoded internally as the 128-bit integers
+/// 1234567 and -1234567, respectively), but neither 12345.67 nor 123.4567.
+///
+/// Decimal128Type has a maximum precision of 38 significant digits
+/// (also available as Decimal128Type::kMaxPrecision).
+/// If higher precision is needed, consider using Decimal256Type.
+class ARROW_EXPORT Decimal128Type : public DecimalType {
+ public:
+ static constexpr Type::type type_id = Type::DECIMAL128;
+
+ static constexpr const char* type_name() { return "decimal128"; }
+
+ /// Decimal128Type constructor that aborts on invalid input.
+ explicit Decimal128Type(int32_t precision, int32_t scale);
+
+ /// Decimal128Type constructor that returns an error on invalid input.
+ static Result<std::shared_ptr<DataType>> Make(int32_t precision, int32_t scale);
+
+ std::string ToString() const override;
+ std::string name() const override { return "decimal128"; }
+
+ static constexpr int32_t kMinPrecision = 1;
+ static constexpr int32_t kMaxPrecision = 38;
+ static constexpr int32_t kByteWidth = 16;
+};
+
+/// \brief Concrete type class for 256-bit decimal data
+///
+/// Arrow decimals are fixed-point decimal numbers encoded as a scaled
+/// integer. The precision is the number of significant digits that the
+/// decimal type can represent; the scale is the number of digits after
+/// the decimal point (note the scale can be negative).
+///
+/// Decimal256Type has a maximum precision of 76 significant digits.
+/// (also available as Decimal256Type::kMaxPrecision).
+///
+/// For most use cases, the maximum precision offered by Decimal128Type
+/// is sufficient, and it will result in a more compact and more efficient
+/// encoding.
+class ARROW_EXPORT Decimal256Type : public DecimalType {
+ public:
+ static constexpr Type::type type_id = Type::DECIMAL256;
+
+ static constexpr const char* type_name() { return "decimal256"; }
+
+ /// Decimal256Type constructor that aborts on invalid input.
+ explicit Decimal256Type(int32_t precision, int32_t scale);
+
+ /// Decimal256Type constructor that returns an error on invalid input.
+ static Result<std::shared_ptr<DataType>> Make(int32_t precision, int32_t scale);
+
+ std::string ToString() const override;
+ std::string name() const override { return "decimal256"; }
+
+ static constexpr int32_t kMinPrecision = 1;
+ static constexpr int32_t kMaxPrecision = 76;
+ static constexpr int32_t kByteWidth = 32;
+};
+
+/// @}
+
+/// \addtogroup nested-datatypes
+///
+/// @{
+
+/// \brief Base class for all variable-size list data types
+class ARROW_EXPORT BaseListType : public NestedType {
+ public:
+ using NestedType::NestedType;
+ std::shared_ptr<Field> value_field() const { return children_[0]; }
+
+ std::shared_ptr<DataType> value_type() const { return children_[0]->type(); }
+};
+
+/// \brief Concrete type class for list data
+///
+/// List data is nested data where each value is a variable number of
+/// child items. Lists can be recursively nested, for example
+/// list(list(int32)).
+class ARROW_EXPORT ListType : public BaseListType {
+ public:
+ static constexpr Type::type type_id = Type::LIST;
+ using offset_type = int32_t;
+
+ static constexpr const char* type_name() { return "list"; }
+
+ // List can contain any other logical value type
+ explicit ListType(const std::shared_ptr<DataType>& value_type)
+ : ListType(std::make_shared<Field>("item", value_type)) {}
+
+ explicit ListType(const std::shared_ptr<Field>& value_field) : BaseListType(type_id) {
+ children_ = {value_field};
+ }
+
+ DataTypeLayout layout() const override {
+ return DataTypeLayout(
+ {DataTypeLayout::Bitmap(), DataTypeLayout::FixedWidth(sizeof(offset_type))});
+ }
+
+ std::string ToString() const override;
+
+ std::string name() const override { return "list"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// \brief Concrete type class for large list data
+///
+/// LargeListType is like ListType but with 64-bit rather than 32-bit offsets.
+class ARROW_EXPORT LargeListType : public BaseListType {
+ public:
+ static constexpr Type::type type_id = Type::LARGE_LIST;
+ using offset_type = int64_t;
+
+ static constexpr const char* type_name() { return "large_list"; }
+
+ // List can contain any other logical value type
+ explicit LargeListType(const std::shared_ptr<DataType>& value_type)
+ : LargeListType(std::make_shared<Field>("item", value_type)) {}
+
+ explicit LargeListType(const std::shared_ptr<Field>& value_field)
+ : BaseListType(type_id) {
+ children_ = {value_field};
+ }
+
+ DataTypeLayout layout() const override {
+ return DataTypeLayout(
+ {DataTypeLayout::Bitmap(), DataTypeLayout::FixedWidth(sizeof(offset_type))});
+ }
+
+ std::string ToString() const override;
+
+ std::string name() const override { return "large_list"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// \brief Concrete type class for map data
+///
+/// Map data is nested data where each value is a variable number of
+/// key-item pairs. Its physical representation is the same as
+/// a list of `{key, item}` structs.
+///
+/// Maps can be recursively nested, for example map(utf8, map(utf8, int32)).
+class ARROW_EXPORT MapType : public ListType {
+ public:
+ static constexpr Type::type type_id = Type::MAP;
+
+ static constexpr const char* type_name() { return "map"; }
+
+ MapType(std::shared_ptr<DataType> key_type, std::shared_ptr<DataType> item_type,
+ bool keys_sorted = false);
+
+ MapType(std::shared_ptr<DataType> key_type, std::shared_ptr<Field> item_field,
+ bool keys_sorted = false);
+
+ MapType(std::shared_ptr<Field> key_field, std::shared_ptr<Field> item_field,
+ bool keys_sorted = false);
+
+ explicit MapType(std::shared_ptr<Field> value_field, bool keys_sorted = false);
+
+ // Validating constructor
+ static Result<std::shared_ptr<DataType>> Make(std::shared_ptr<Field> value_field,
+ bool keys_sorted = false);
+
+ std::shared_ptr<Field> key_field() const { return value_type()->field(0); }
+ std::shared_ptr<DataType> key_type() const { return key_field()->type(); }
+
+ std::shared_ptr<Field> item_field() const { return value_type()->field(1); }
+ std::shared_ptr<DataType> item_type() const { return item_field()->type(); }
+
+ std::string ToString() const override;
+
+ std::string name() const override { return "map"; }
+
+ bool keys_sorted() const { return keys_sorted_; }
+
+ private:
+ std::string ComputeFingerprint() const override;
+
+ bool keys_sorted_;
+};
+
+/// \brief Concrete type class for fixed size list data
+class ARROW_EXPORT FixedSizeListType : public BaseListType {
+ public:
+ static constexpr Type::type type_id = Type::FIXED_SIZE_LIST;
+ using offset_type = int32_t;
+
+ static constexpr const char* type_name() { return "fixed_size_list"; }
+
+ // List can contain any other logical value type
+ FixedSizeListType(const std::shared_ptr<DataType>& value_type, int32_t list_size)
+ : FixedSizeListType(std::make_shared<Field>("item", value_type), list_size) {}
+
+ FixedSizeListType(const std::shared_ptr<Field>& value_field, int32_t list_size)
+ : BaseListType(type_id), list_size_(list_size) {
+ children_ = {value_field};
+ }
+
+ DataTypeLayout layout() const override {
+ return DataTypeLayout({DataTypeLayout::Bitmap()});
+ }
+
+ std::string ToString() const override;
+
+ std::string name() const override { return "fixed_size_list"; }
+
+ int32_t list_size() const { return list_size_; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+
+ int32_t list_size_;
+};
+
+/// \brief Concrete type class for struct data
+class ARROW_EXPORT StructType : public NestedType {
+ public:
+ static constexpr Type::type type_id = Type::STRUCT;
+
+ static constexpr const char* type_name() { return "struct"; }
+
+ explicit StructType(const std::vector<std::shared_ptr<Field>>& fields);
+
+ ~StructType() override;
+
+ DataTypeLayout layout() const override {
+ return DataTypeLayout({DataTypeLayout::Bitmap()});
+ }
+
+ std::string ToString() const override;
+ std::string name() const override { return "struct"; }
+
+ /// Returns null if name not found
+ std::shared_ptr<Field> GetFieldByName(const std::string& name) const;
+
+ /// Return all fields having this name
+ std::vector<std::shared_ptr<Field>> GetAllFieldsByName(const std::string& name) const;
+
+ /// Returns -1 if name not found or if there are multiple fields having the
+ /// same name
+ int GetFieldIndex(const std::string& name) const;
+
+ /// \brief Return the indices of all fields having this name in sorted order
+ std::vector<int> GetAllFieldIndices(const std::string& name) const;
+
+ private:
+ std::string ComputeFingerprint() const override;
+
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+/// \brief Base type class for union data
+class ARROW_EXPORT UnionType : public NestedType {
+ public:
+ static constexpr int8_t kMaxTypeCode = 127;
+ static constexpr int kInvalidChildId = -1;
+
+ static Result<std::shared_ptr<DataType>> Make(
+ const std::vector<std::shared_ptr<Field>>& fields,
+ const std::vector<int8_t>& type_codes, UnionMode::type mode = UnionMode::SPARSE) {
+ if (mode == UnionMode::SPARSE) {
+ return sparse_union(fields, type_codes);
+ } else {
+ return dense_union(fields, type_codes);
+ }
+ }
+
+ DataTypeLayout layout() const override;
+
+ std::string ToString() const override;
+
+ /// The array of logical type ids.
+ ///
+ /// For example, the first type in the union might be denoted by the id 5
+ /// (instead of 0).
+ const std::vector<int8_t>& type_codes() const { return type_codes_; }
+
+ /// An array mapping logical type ids to physical child ids.
+ const std::vector<int>& child_ids() const { return child_ids_; }
+
+ uint8_t max_type_code() const;
+
+ UnionMode::type mode() const;
+
+ protected:
+ UnionType(std::vector<std::shared_ptr<Field>> fields, std::vector<int8_t> type_codes,
+ Type::type id);
+
+ static Status ValidateParameters(const std::vector<std::shared_ptr<Field>>& fields,
+ const std::vector<int8_t>& type_codes,
+ UnionMode::type mode);
+
+ private:
+ std::string ComputeFingerprint() const override;
+
+ std::vector<int8_t> type_codes_;
+ std::vector<int> child_ids_;
+};
+
+/// \brief Concrete type class for sparse union data
+///
+/// A sparse union is a nested type where each logical value is taken from
+/// a single child. A buffer of 8-bit type ids indicates which child
+/// a given logical value is to be taken from.
+///
+/// In a sparse union, each child array should have the same length as the
+/// union array, regardless of the actual number of union values that
+/// refer to it.
+///
+/// Note that, unlike most other types, unions don't have a top-level validity bitmap.
+class ARROW_EXPORT SparseUnionType : public UnionType {
+ public:
+ static constexpr Type::type type_id = Type::SPARSE_UNION;
+
+ static constexpr const char* type_name() { return "sparse_union"; }
+
+ SparseUnionType(std::vector<std::shared_ptr<Field>> fields,
+ std::vector<int8_t> type_codes);
+
+ // A constructor variant that validates input parameters
+ static Result<std::shared_ptr<DataType>> Make(
+ std::vector<std::shared_ptr<Field>> fields, std::vector<int8_t> type_codes);
+
+ std::string name() const override { return "sparse_union"; }
+};
+
+/// \brief Concrete type class for dense union data
+///
+/// A dense union is a nested type where each logical value is taken from
+/// a single child, at a specific offset. A buffer of 8-bit type ids
+/// indicates which child a given logical value is to be taken from,
+/// and a buffer of 32-bit offsets indicates at which physical position
+/// in the given child array the logical value is to be taken from.
+///
+/// Unlike a sparse union, a dense union allows encoding only the child array
+/// values which are actually referred to by the union array. This is
+/// counterbalanced by the additional footprint of the offsets buffer, and
+/// the additional indirection cost when looking up values.
+///
+/// Note that, unlike most other types, unions don't have a top-level validity bitmap.
+class ARROW_EXPORT DenseUnionType : public UnionType {
+ public:
+ static constexpr Type::type type_id = Type::DENSE_UNION;
+
+ static constexpr const char* type_name() { return "dense_union"; }
+
+ DenseUnionType(std::vector<std::shared_ptr<Field>> fields,
+ std::vector<int8_t> type_codes);
+
+ // A constructor variant that validates input parameters
+ static Result<std::shared_ptr<DataType>> Make(
+ std::vector<std::shared_ptr<Field>> fields, std::vector<int8_t> type_codes);
+
+ std::string name() const override { return "dense_union"; }
+};
+
+/// @}
+
+// ----------------------------------------------------------------------
+// Date and time types
+
+/// \addtogroup temporal-datatypes
+///
+/// @{
+
+/// \brief Base type for all date and time types
+class ARROW_EXPORT TemporalType : public FixedWidthType {
+ public:
+ using FixedWidthType::FixedWidthType;
+
+ DataTypeLayout layout() const override {
+ return DataTypeLayout(
+ {DataTypeLayout::Bitmap(), DataTypeLayout::FixedWidth(bit_width() / 8)});
+ }
+};
+
+/// \brief Base type class for date data
+class ARROW_EXPORT DateType : public TemporalType {
+ public:
+ virtual DateUnit unit() const = 0;
+
+ protected:
+ explicit DateType(Type::type type_id);
+};
+
+/// Concrete type class for 32-bit date data (as number of days since UNIX epoch)
+class ARROW_EXPORT Date32Type : public DateType {
+ public:
+ static constexpr Type::type type_id = Type::DATE32;
+ static constexpr DateUnit UNIT = DateUnit::DAY;
+ using c_type = int32_t;
+ using PhysicalType = Int32Type;
+
+ static constexpr const char* type_name() { return "date32"; }
+
+ Date32Type();
+
+ int bit_width() const override { return static_cast<int>(sizeof(c_type) * CHAR_BIT); }
+
+ std::string ToString() const override;
+
+ std::string name() const override { return "date32"; }
+ DateUnit unit() const override { return UNIT; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+/// Concrete type class for 64-bit date data (as number of milliseconds since UNIX epoch)
+class ARROW_EXPORT Date64Type : public DateType {
+ public:
+ static constexpr Type::type type_id = Type::DATE64;
+ static constexpr DateUnit UNIT = DateUnit::MILLI;
+ using c_type = int64_t;
+ using PhysicalType = Int64Type;
+
+ static constexpr const char* type_name() { return "date64"; }
+
+ Date64Type();
+
+ int bit_width() const override { return static_cast<int>(sizeof(c_type) * CHAR_BIT); }
+
+ std::string ToString() const override;
+
+ std::string name() const override { return "date64"; }
+ DateUnit unit() const override { return UNIT; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+};
+
+ARROW_EXPORT
+std::ostream& operator<<(std::ostream& os, TimeUnit::type unit);
+
+/// Base type class for time data
+class ARROW_EXPORT TimeType : public TemporalType, public ParametricType {
+ public:
+ TimeUnit::type unit() const { return unit_; }
+
+ protected:
+ TimeType(Type::type type_id, TimeUnit::type unit);
+ std::string ComputeFingerprint() const override;
+
+ TimeUnit::type unit_;
+};
+
+/// Concrete type class for 32-bit time data (as number of seconds or milliseconds
+/// since midnight)
+class ARROW_EXPORT Time32Type : public TimeType {
+ public:
+ static constexpr Type::type type_id = Type::TIME32;
+ using c_type = int32_t;
+ using PhysicalType = Int32Type;
+
+ static constexpr const char* type_name() { return "time32"; }
+
+ int bit_width() const override { return static_cast<int>(sizeof(c_type) * CHAR_BIT); }
+
+ explicit Time32Type(TimeUnit::type unit = TimeUnit::MILLI);
+
+ std::string ToString() const override;
+
+ std::string name() const override { return "time32"; }
+};
+
+/// Concrete type class for 64-bit time data (as number of microseconds or nanoseconds
+/// since midnight)
+class ARROW_EXPORT Time64Type : public TimeType {
+ public:
+ static constexpr Type::type type_id = Type::TIME64;
+ using c_type = int64_t;
+ using PhysicalType = Int64Type;
+
+ static constexpr const char* type_name() { return "time64"; }
+
+ int bit_width() const override { return static_cast<int>(sizeof(c_type) * CHAR_BIT); }
+
+ explicit Time64Type(TimeUnit::type unit = TimeUnit::NANO);
+
+ std::string ToString() const override;
+
+ std::string name() const override { return "time64"; }
+};
+
+/// \brief Concrete type class for datetime data (as number of seconds, milliseconds,
+/// microseconds or nanoseconds since UNIX epoch)
+///
+/// If supplied, the timezone string should take either the form (i) "Area/Location",
+/// with values drawn from the names in the IANA Time Zone Database (such as
+/// "Europe/Zurich"); or (ii) "(+|-)HH:MM" indicating an absolute offset from GMT
+/// (such as "-08:00"). To indicate a native UTC timestamp, one of the strings "UTC",
+/// "Etc/UTC" or "+00:00" should be used.
+///
+/// If any non-empty string is supplied as the timezone for a TimestampType, then the
+/// Arrow field containing that timestamp type (and by extension the column associated
+/// with such a field) is considered "timezone-aware". The integer arrays that comprise
+/// a timezone-aware column must contain UTC normalized datetime values, regardless of
+/// the contents of their timezone string. More precisely, (i) the producer of a
+/// timezone-aware column must populate its constituent arrays with valid UTC values
+/// (performing offset conversions from non-UTC values if necessary); and (ii) the
+/// consumer of a timezone-aware column may assume that the column's values are directly
+/// comparable (that is, with no offset adjustment required) to the values of any other
+/// timezone-aware column or to any other valid UTC datetime value (provided all values
+/// are expressed in the same units).
+///
+/// If a TimestampType is constructed without a timezone (or, equivalently, if the
+/// timezone supplied is an empty string) then the resulting Arrow field (column) is
+/// considered "timezone-naive". The producer of a timezone-naive column may populate
+/// its constituent integer arrays with datetime values from any timezone; the consumer
+/// of a timezone-naive column should make no assumptions about the interoperability or
+/// comparability of the values of such a column with those of any other timestamp
+/// column or datetime value.
+///
+/// If a timezone-aware field contains a recognized timezone, its values may be
+/// localized to that locale upon display; the values of timezone-naive fields must
+/// always be displayed "as is", with no localization performed on them.
+class ARROW_EXPORT TimestampType : public TemporalType, public ParametricType {
+ public:
+ using Unit = TimeUnit;
+
+ static constexpr Type::type type_id = Type::TIMESTAMP;
+ using c_type = int64_t;
+ using PhysicalType = Int64Type;
+
+ static constexpr const char* type_name() { return "timestamp"; }
+
+ int bit_width() const override { return static_cast<int>(sizeof(int64_t) * CHAR_BIT); }
+
+ explicit TimestampType(TimeUnit::type unit = TimeUnit::MILLI)
+ : TemporalType(Type::TIMESTAMP), unit_(unit) {}
+
+ explicit TimestampType(TimeUnit::type unit, const std::string& timezone)
+ : TemporalType(Type::TIMESTAMP), unit_(unit), timezone_(timezone) {}
+
+ std::string ToString() const override;
+ std::string name() const override { return "timestamp"; }
+
+ TimeUnit::type unit() const { return unit_; }
+ const std::string& timezone() const { return timezone_; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+
+ private:
+ TimeUnit::type unit_;
+ std::string timezone_;
+};
+
+// Base class for the different kinds of calendar intervals.
+class ARROW_EXPORT IntervalType : public TemporalType, public ParametricType {
+ public:
+ enum type { MONTHS, DAY_TIME, MONTH_DAY_NANO };
+
+ virtual type interval_type() const = 0;
+
+ protected:
+ explicit IntervalType(Type::type subtype) : TemporalType(subtype) {}
+ std::string ComputeFingerprint() const override;
+};
+
+/// \brief Represents a number of months.
+///
+/// Type representing a number of months. Corresponds to YearMonth type
+/// in Schema.fbs (years are defined as 12 months).
+class ARROW_EXPORT MonthIntervalType : public IntervalType {
+ public:
+ static constexpr Type::type type_id = Type::INTERVAL_MONTHS;
+ using c_type = int32_t;
+ using PhysicalType = Int32Type;
+
+ static constexpr const char* type_name() { return "month_interval"; }
+
+ IntervalType::type interval_type() const override { return IntervalType::MONTHS; }
+
+ int bit_width() const override { return static_cast<int>(sizeof(c_type) * CHAR_BIT); }
+
+ MonthIntervalType() : IntervalType(type_id) {}
+
+ std::string ToString() const override { return name(); }
+ std::string name() const override { return "month_interval"; }
+};
+
+/// \brief Represents a number of days and milliseconds (fraction of day).
+class ARROW_EXPORT DayTimeIntervalType : public IntervalType {
+ public:
+ struct DayMilliseconds {
+ int32_t days = 0;
+ int32_t milliseconds = 0;
+ constexpr DayMilliseconds() = default;
+ constexpr DayMilliseconds(int32_t days, int32_t milliseconds)
+ : days(days), milliseconds(milliseconds) {}
+ bool operator==(DayMilliseconds other) const {
+ return this->days == other.days && this->milliseconds == other.milliseconds;
+ }
+ bool operator!=(DayMilliseconds other) const { return !(*this == other); }
+ bool operator<(DayMilliseconds other) const {
+ return this->days < other.days || this->milliseconds < other.milliseconds;
+ }
+ };
+ using c_type = DayMilliseconds;
+ using PhysicalType = DayTimeIntervalType;
+
+ static_assert(sizeof(DayMilliseconds) == 8,
+ "DayMilliseconds struct assumed to be of size 8 bytes");
+ static constexpr Type::type type_id = Type::INTERVAL_DAY_TIME;
+
+ static constexpr const char* type_name() { return "day_time_interval"; }
+
+ IntervalType::type interval_type() const override { return IntervalType::DAY_TIME; }
+
+ DayTimeIntervalType() : IntervalType(type_id) {}
+
+ int bit_width() const override { return static_cast<int>(sizeof(c_type) * CHAR_BIT); }
+
+ std::string ToString() const override { return name(); }
+ std::string name() const override { return "day_time_interval"; }
+};
+
+ARROW_EXPORT
+std::ostream& operator<<(std::ostream& os, DayTimeIntervalType::DayMilliseconds interval);
+
+/// \brief Represents a number of months, days and nanoseconds between
+/// two dates.
+///
+/// All fields are independent from one another.
+class ARROW_EXPORT MonthDayNanoIntervalType : public IntervalType {
+ public:
+ struct MonthDayNanos {
+ int32_t months;
+ int32_t days;
+ int64_t nanoseconds;
+ bool operator==(MonthDayNanos other) const {
+ return this->months == other.months && this->days == other.days &&
+ this->nanoseconds == other.nanoseconds;
+ }
+ bool operator!=(MonthDayNanos other) const { return !(*this == other); }
+ };
+ using c_type = MonthDayNanos;
+ using PhysicalType = MonthDayNanoIntervalType;
+
+ static_assert(sizeof(MonthDayNanos) == 16,
+ "MonthDayNanos struct assumed to be of size 16 bytes");
+ static constexpr Type::type type_id = Type::INTERVAL_MONTH_DAY_NANO;
+
+ static constexpr const char* type_name() { return "month_day_nano_interval"; }
+
+ IntervalType::type interval_type() const override {
+ return IntervalType::MONTH_DAY_NANO;
+ }
+
+ MonthDayNanoIntervalType() : IntervalType(type_id) {}
+
+ int bit_width() const override { return static_cast<int>(sizeof(c_type) * CHAR_BIT); }
+
+ std::string ToString() const override { return name(); }
+ std::string name() const override { return "month_day_nano_interval"; }
+};
+
+ARROW_EXPORT
+std::ostream& operator<<(std::ostream& os,
+ MonthDayNanoIntervalType::MonthDayNanos interval);
+
+/// \brief Represents an elapsed time without any relation to a calendar artifact.
+class ARROW_EXPORT DurationType : public TemporalType, public ParametricType {
+ public:
+ using Unit = TimeUnit;
+
+ static constexpr Type::type type_id = Type::DURATION;
+ using c_type = int64_t;
+ using PhysicalType = Int64Type;
+
+ static constexpr const char* type_name() { return "duration"; }
+
+ int bit_width() const override { return static_cast<int>(sizeof(int64_t) * CHAR_BIT); }
+
+ explicit DurationType(TimeUnit::type unit = TimeUnit::MILLI)
+ : TemporalType(Type::DURATION), unit_(unit) {}
+
+ std::string ToString() const override;
+ std::string name() const override { return "duration"; }
+
+ TimeUnit::type unit() const { return unit_; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
+
+ private:
+ TimeUnit::type unit_;
+};
+
+/// @}
+
+// ----------------------------------------------------------------------
+// Dictionary type (for representing categorical or dictionary-encoded
+// in memory)
+
+/// \brief Dictionary-encoded value type with data-dependent
+/// dictionary. Indices are represented by any integer types.
+class ARROW_EXPORT DictionaryType : public FixedWidthType {
+ public:
+ static constexpr Type::type type_id = Type::DICTIONARY;
+
+ static constexpr const char* type_name() { return "dictionary"; }
+
+ DictionaryType(const std::shared_ptr<DataType>& index_type,
+ const std::shared_ptr<DataType>& value_type, bool ordered = false);
+
+ // A constructor variant that validates its input parameters
+ static Result<std::shared_ptr<DataType>> Make(
+ const std::shared_ptr<DataType>& index_type,
+ const std::shared_ptr<DataType>& value_type, bool ordered = false);
+
+ std::string ToString() const override;
+ std::string name() const override { return "dictionary"; }
+
+ int bit_width() const override;
+
+ DataTypeLayout layout() const override;
+
+ const std::shared_ptr<DataType>& index_type() const { return index_type_; }
+ const std::shared_ptr<DataType>& value_type() const { return value_type_; }
+
+ bool ordered() const { return ordered_; }
+
+ protected:
+ static Status ValidateParameters(const DataType& index_type,
+ const DataType& value_type);
+
+ std::string ComputeFingerprint() const override;
+
+ // Must be an integer type (not currently checked)
+ std::shared_ptr<DataType> index_type_;
+ std::shared_ptr<DataType> value_type_;
+ bool ordered_;
+};
+
+// ----------------------------------------------------------------------
+// FieldRef
+
+/// \class FieldPath
+///
+/// Represents a path to a nested field using indices of child fields.
+/// For example, given indices {5, 9, 3} the field would be retrieved with
+/// schema->field(5)->type()->field(9)->type()->field(3)
+///
+/// Attempting to retrieve a child field using a FieldPath which is not valid for
+/// a given schema will raise an error. Invalid FieldPaths include:
+/// - an index is out of range
+/// - the path is empty (note: a default constructed FieldPath will be empty)
+///
+/// FieldPaths provide a number of accessors for drilling down to potentially nested
+/// children. They are overloaded for convenience to support Schema (returns a field),
+/// DataType (returns a child field), Field (returns a child field of this field's type)
+/// Array (returns a child array), RecordBatch (returns a column).
+class ARROW_EXPORT FieldPath {
+ public:
+ FieldPath() = default;
+
+ FieldPath(std::vector<int> indices) // NOLINT runtime/explicit
+ : indices_(std::move(indices)) {}
+
+ FieldPath(std::initializer_list<int> indices) // NOLINT runtime/explicit
+ : indices_(std::move(indices)) {}
+
+ std::string ToString() const;
+
+ size_t hash() const;
+ struct Hash {
+ size_t operator()(const FieldPath& path) const { return path.hash(); }
+ };
+
+ bool empty() const { return indices_.empty(); }
+ bool operator==(const FieldPath& other) const { return indices() == other.indices(); }
+ bool operator!=(const FieldPath& other) const { return indices() != other.indices(); }
+
+ const std::vector<int>& indices() const { return indices_; }
+ int operator[](size_t i) const { return indices_[i]; }
+ std::vector<int>::const_iterator begin() const { return indices_.begin(); }
+ std::vector<int>::const_iterator end() const { return indices_.end(); }
+
+ /// \brief Retrieve the referenced child Field from a Schema, Field, or DataType
+ Result<std::shared_ptr<Field>> Get(const Schema& schema) const;
+ Result<std::shared_ptr<Field>> Get(const Field& field) const;
+ Result<std::shared_ptr<Field>> Get(const DataType& type) const;
+ Result<std::shared_ptr<Field>> Get(const FieldVector& fields) const;
+
+ /// \brief Retrieve the referenced column from a RecordBatch or Table
+ Result<std::shared_ptr<Array>> Get(const RecordBatch& batch) const;
+
+ /// \brief Retrieve the referenced child from an Array or ArrayData
+ Result<std::shared_ptr<Array>> Get(const Array& array) const;
+ Result<std::shared_ptr<ArrayData>> Get(const ArrayData& data) const;
+
+ private:
+ std::vector<int> indices_;
+};
+
+/// \class FieldRef
+/// \brief Descriptor of a (potentially nested) field within a schema.
+///
+/// Unlike FieldPath (which exclusively uses indices of child fields), FieldRef may
+/// reference a field by name. It is intended to replace parameters like `int field_index`
+/// and `const std::string& field_name`; it can be implicitly constructed from either a
+/// field index or a name.
+///
+/// Nested fields can be referenced as well. Given
+/// schema({field("a", struct_({field("n", null())})), field("b", int32())})
+///
+/// the following all indicate the nested field named "n":
+/// FieldRef ref1(0, 0);
+/// FieldRef ref2("a", 0);
+/// FieldRef ref3("a", "n");
+/// FieldRef ref4(0, "n");
+/// ARROW_ASSIGN_OR_RAISE(FieldRef ref5,
+/// FieldRef::FromDotPath(".a[0]"));
+///
+/// FieldPaths matching a FieldRef are retrieved using the member function FindAll.
+/// Multiple matches are possible because field names may be duplicated within a schema.
+/// For example:
+/// Schema a_is_ambiguous({field("a", int32()), field("a", float32())});
+/// auto matches = FieldRef("a").FindAll(a_is_ambiguous);
+/// assert(matches.size() == 2);
+/// assert(matches[0].Get(a_is_ambiguous)->Equals(a_is_ambiguous.field(0)));
+/// assert(matches[1].Get(a_is_ambiguous)->Equals(a_is_ambiguous.field(1)));
+///
+/// Convenience accessors are available which raise a helpful error if the field is not
+/// found or ambiguous, and for immediately calling FieldPath::Get to retrieve any
+/// matching children:
+/// auto maybe_match = FieldRef("struct", "field_i32").FindOneOrNone(schema);
+/// auto maybe_column = FieldRef("struct", "field_i32").GetOne(some_table);
+class ARROW_EXPORT FieldRef {
+ public:
+ FieldRef() = default;
+
+ /// Construct a FieldRef using a string of indices. The reference will be retrieved as:
+ /// schema.fields[self.indices[0]].type.fields[self.indices[1]] ...
+ ///
+ /// Empty indices are not valid.
+ FieldRef(FieldPath indices); // NOLINT runtime/explicit
+
+ /// Construct a by-name FieldRef. Multiple fields may match a by-name FieldRef:
+ /// [f for f in schema.fields where f.name == self.name]
+ FieldRef(std::string name) : impl_(std::move(name)) {} // NOLINT runtime/explicit
+ FieldRef(const char* name) : impl_(std::string(name)) {} // NOLINT runtime/explicit
+
+ /// Equivalent to a single index string of indices.
+ FieldRef(int index) : impl_(FieldPath({index})) {} // NOLINT runtime/explicit
+
+ /// Convenience constructor for nested FieldRefs: each argument will be used to
+ /// construct a FieldRef
+ template <typename A0, typename A1, typename... A>
+ FieldRef(A0&& a0, A1&& a1, A&&... a) {
+ Flatten({// cpplint thinks the following are constructor decls
+ FieldRef(std::forward<A0>(a0)), // NOLINT runtime/explicit
+ FieldRef(std::forward<A1>(a1)), // NOLINT runtime/explicit
+ FieldRef(std::forward<A>(a))...}); // NOLINT runtime/explicit
+ }
+
+ /// Parse a dot path into a FieldRef.
+ ///
+ /// dot_path = '.' name
+ /// | '[' digit+ ']'
+ /// | dot_path+
+ ///
+ /// Examples:
+ /// ".alpha" => FieldRef("alpha")
+ /// "[2]" => FieldRef(2)
+ /// ".beta[3]" => FieldRef("beta", 3)
+ /// "[5].gamma.delta[7]" => FieldRef(5, "gamma", "delta", 7)
+ /// ".hello world" => FieldRef("hello world")
+ /// R"(.\[y\]\\tho\.\)" => FieldRef(R"([y]\tho.\)")
+ ///
+ /// Note: When parsing a name, a '\' preceding any other character will be dropped from
+ /// the resulting name. Therefore if a name must contain the characters '.', '\', or '['
+ /// those must be escaped with a preceding '\'.
+ static Result<FieldRef> FromDotPath(const std::string& dot_path);
+
+ bool Equals(const FieldRef& other) const { return impl_ == other.impl_; }
+ bool operator==(const FieldRef& other) const { return Equals(other); }
+
+ std::string ToString() const;
+
+ size_t hash() const;
+ struct Hash {
+ size_t operator()(const FieldRef& ref) const { return ref.hash(); }
+ };
+
+ explicit operator bool() const { return Equals(FieldPath{}); }
+ bool operator!() const { return !Equals(FieldPath{}); }
+
+ bool IsFieldPath() const { return util::holds_alternative<FieldPath>(impl_); }
+ bool IsName() const { return util::holds_alternative<std::string>(impl_); }
+ bool IsNested() const {
+ if (IsName()) return false;
+ if (IsFieldPath()) return util::get<FieldPath>(impl_).indices().size() > 1;
+ return true;
+ }
+
+ const FieldPath* field_path() const {
+ return IsFieldPath() ? &util::get<FieldPath>(impl_) : NULLPTR;
+ }
+ const std::string* name() const {
+ return IsName() ? &util::get<std::string>(impl_) : NULLPTR;
+ }
+
+ /// \brief Retrieve FieldPath of every child field which matches this FieldRef.
+ std::vector<FieldPath> FindAll(const Schema& schema) const;
+ std::vector<FieldPath> FindAll(const Field& field) const;
+ std::vector<FieldPath> FindAll(const DataType& type) const;
+ std::vector<FieldPath> FindAll(const FieldVector& fields) const;
+
+ /// \brief Convenience function which applies FindAll to arg's type or schema.
+ std::vector<FieldPath> FindAll(const ArrayData& array) const;
+ std::vector<FieldPath> FindAll(const Array& array) const;
+ std::vector<FieldPath> FindAll(const RecordBatch& batch) const;
+
+ /// \brief Convenience function: raise an error if matches is empty.
+ template <typename T>
+ Status CheckNonEmpty(const std::vector<FieldPath>& matches, const T& root) const {
+ if (matches.empty()) {
+ return Status::Invalid("No match for ", ToString(), " in ", root.ToString());
+ }
+ return Status::OK();
+ }
+
+ /// \brief Convenience function: raise an error if matches contains multiple FieldPaths.
+ template <typename T>
+ Status CheckNonMultiple(const std::vector<FieldPath>& matches, const T& root) const {
+ if (matches.size() > 1) {
+ return Status::Invalid("Multiple matches for ", ToString(), " in ",
+ root.ToString());
+ }
+ return Status::OK();
+ }
+
+ /// \brief Retrieve FieldPath of a single child field which matches this
+ /// FieldRef. Emit an error if none or multiple match.
+ template <typename T>
+ Result<FieldPath> FindOne(const T& root) const {
+ auto matches = FindAll(root);
+ ARROW_RETURN_NOT_OK(CheckNonEmpty(matches, root));
+ ARROW_RETURN_NOT_OK(CheckNonMultiple(matches, root));
+ return std::move(matches[0]);
+ }
+
+ /// \brief Retrieve FieldPath of a single child field which matches this
+ /// FieldRef. Emit an error if multiple match. An empty (invalid) FieldPath
+ /// will be returned if none match.
+ template <typename T>
+ Result<FieldPath> FindOneOrNone(const T& root) const {
+ auto matches = FindAll(root);
+ ARROW_RETURN_NOT_OK(CheckNonMultiple(matches, root));
+ if (matches.empty()) {
+ return FieldPath();
+ }
+ return std::move(matches[0]);
+ }
+
+ template <typename T>
+ using GetType = decltype(std::declval<FieldPath>().Get(std::declval<T>()).ValueOrDie());
+
+ /// \brief Get all children matching this FieldRef.
+ template <typename T>
+ std::vector<GetType<T>> GetAll(const T& root) const {
+ std::vector<GetType<T>> out;
+ for (const auto& match : FindAll(root)) {
+ out.push_back(match.Get(root).ValueOrDie());
+ }
+ return out;
+ }
+
+ /// \brief Get the single child matching this FieldRef.
+ /// Emit an error if none or multiple match.
+ template <typename T>
+ Result<GetType<T>> GetOne(const T& root) const {
+ ARROW_ASSIGN_OR_RAISE(auto match, FindOne(root));
+ return match.Get(root).ValueOrDie();
+ }
+
+ /// \brief Get the single child matching this FieldRef.
+ /// Return nullptr if none match, emit an error if multiple match.
+ template <typename T>
+ Result<GetType<T>> GetOneOrNone(const T& root) const {
+ ARROW_ASSIGN_OR_RAISE(auto match, FindOneOrNone(root));
+ if (match.empty()) {
+ return static_cast<GetType<T>>(NULLPTR);
+ }
+ return match.Get(root).ValueOrDie();
+ }
+
+ private:
+ void Flatten(std::vector<FieldRef> children);
+
+ util::Variant<FieldPath, std::string, std::vector<FieldRef>> impl_;
+
+ ARROW_EXPORT friend void PrintTo(const FieldRef& ref, std::ostream* os);
+};
+
+// ----------------------------------------------------------------------
+// Schema
+
+enum class Endianness {
+ Little = 0,
+ Big = 1,
+#if ARROW_LITTLE_ENDIAN
+ Native = Little
+#else
+ Native = Big
+#endif
+};
+
+/// \class Schema
+/// \brief Sequence of arrow::Field objects describing the columns of a record
+/// batch or table data structure
+class ARROW_EXPORT Schema : public detail::Fingerprintable,
+ public util::EqualityComparable<Schema>,
+ public util::ToStringOstreamable<Schema> {
+ public:
+ explicit Schema(FieldVector fields, Endianness endianness,
+ std::shared_ptr<const KeyValueMetadata> metadata = NULLPTR);
+
+ explicit Schema(FieldVector fields,
+ std::shared_ptr<const KeyValueMetadata> metadata = NULLPTR);
+
+ Schema(const Schema&);
+
+ ~Schema() override;
+
+ /// Returns true if all of the schema fields are equal
+ bool Equals(const Schema& other, bool check_metadata = false) const;
+ bool Equals(const std::shared_ptr<Schema>& other, bool check_metadata = false) const;
+
+ /// \brief Set endianness in the schema
+ ///
+ /// \return new Schema
+ std::shared_ptr<Schema> WithEndianness(Endianness endianness) const;
+
+ /// \brief Return endianness in the schema
+ Endianness endianness() const;
+
+ /// \brief Indicate if endianness is equal to platform-native endianness
+ bool is_native_endian() const;
+
+ /// \brief Return the number of fields (columns) in the schema
+ int num_fields() const;
+
+ /// Return the ith schema element. Does not boundscheck
+ const std::shared_ptr<Field>& field(int i) const;
+
+ const FieldVector& fields() const;
+
+ std::vector<std::string> field_names() const;
+
+ /// Returns null if name not found
+ std::shared_ptr<Field> GetFieldByName(const std::string& name) const;
+
+ /// \brief Return the indices of all fields having this name in sorted order
+ FieldVector GetAllFieldsByName(const std::string& name) const;
+
+ /// Returns -1 if name not found
+ int GetFieldIndex(const std::string& name) const;
+
+ /// Return the indices of all fields having this name
+ std::vector<int> GetAllFieldIndices(const std::string& name) const;
+
+ /// Indicate if fields named `names` can be found unambiguously in the schema.
+ Status CanReferenceFieldsByNames(const std::vector<std::string>& names) const;
+
+ /// \brief The custom key-value metadata, if any
+ ///
+ /// \return metadata may be null
+ const std::shared_ptr<const KeyValueMetadata>& metadata() const;
+
+ /// \brief Render a string representation of the schema suitable for debugging
+ /// \param[in] show_metadata when true, if KeyValueMetadata is non-empty,
+ /// print keys and values in the output
+ std::string ToString(bool show_metadata = false) const;
+
+ Result<std::shared_ptr<Schema>> AddField(int i,
+ const std::shared_ptr<Field>& field) const;
+ Result<std::shared_ptr<Schema>> RemoveField(int i) const;
+ Result<std::shared_ptr<Schema>> SetField(int i,
+ const std::shared_ptr<Field>& field) const;
+
+ /// \brief Replace key-value metadata with new metadata
+ ///
+ /// \param[in] metadata new KeyValueMetadata
+ /// \return new Schema
+ std::shared_ptr<Schema> WithMetadata(
+ const std::shared_ptr<const KeyValueMetadata>& metadata) const;
+
+ /// \brief Return copy of Schema without the KeyValueMetadata
+ std::shared_ptr<Schema> RemoveMetadata() const;
+
+ /// \brief Indicate that the Schema has non-empty KevValueMetadata
+ bool HasMetadata() const;
+
+ /// \brief Indicate that the Schema has distinct field names.
+ bool HasDistinctFieldNames() const;
+
+ protected:
+ std::string ComputeFingerprint() const override;
+ std::string ComputeMetadataFingerprint() const override;
+
+ private:
+ ARROW_EXPORT friend void PrintTo(const Schema& s, std::ostream* os);
+
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+ARROW_EXPORT
+std::string EndiannessToString(Endianness endianness);
+
+// ----------------------------------------------------------------------
+
+/// \brief Convenience class to incrementally construct/merge schemas.
+///
+/// This class amortizes the cost of validating field name conflicts by
+/// maintaining the mapping. The caller also controls the conflict resolution
+/// scheme.
+class ARROW_EXPORT SchemaBuilder {
+ public:
+ // Indicate how field conflict(s) should be resolved when building a schema. A
+ // conflict arise when a field is added to the builder and one or more field(s)
+ // with the same name already exists.
+ enum ConflictPolicy {
+ // Ignore the conflict and append the field. This is the default behavior of the
+ // Schema constructor and the `arrow::schema` factory function.
+ CONFLICT_APPEND = 0,
+ // Keep the existing field and ignore the newer one.
+ CONFLICT_IGNORE,
+ // Replace the existing field with the newer one.
+ CONFLICT_REPLACE,
+ // Merge the fields. The merging behavior can be controlled by `Field::MergeOptions`
+ // specified at construction time. Also see documentation of `Field::MergeWith`.
+ CONFLICT_MERGE,
+ // Refuse the new field and error out.
+ CONFLICT_ERROR
+ };
+
+ /// \brief Construct an empty SchemaBuilder
+ /// `field_merge_options` is only effective when `conflict_policy` == `CONFLICT_MERGE`.
+ SchemaBuilder(
+ ConflictPolicy conflict_policy = CONFLICT_APPEND,
+ Field::MergeOptions field_merge_options = Field::MergeOptions::Defaults());
+ /// \brief Construct a SchemaBuilder from a list of fields
+ /// `field_merge_options` is only effective when `conflict_policy` == `CONFLICT_MERGE`.
+ SchemaBuilder(
+ std::vector<std::shared_ptr<Field>> fields,
+ ConflictPolicy conflict_policy = CONFLICT_APPEND,
+ Field::MergeOptions field_merge_options = Field::MergeOptions::Defaults());
+ /// \brief Construct a SchemaBuilder from a schema, preserving the metadata
+ /// `field_merge_options` is only effective when `conflict_policy` == `CONFLICT_MERGE`.
+ SchemaBuilder(
+ const std::shared_ptr<Schema>& schema,
+ ConflictPolicy conflict_policy = CONFLICT_APPEND,
+ Field::MergeOptions field_merge_options = Field::MergeOptions::Defaults());
+
+ /// \brief Return the conflict resolution method.
+ ConflictPolicy policy() const;
+
+ /// \brief Set the conflict resolution method.
+ void SetPolicy(ConflictPolicy resolution);
+
+ /// \brief Add a field to the constructed schema.
+ ///
+ /// \param[in] field to add to the constructed Schema.
+ /// \return A failure if encountered.
+ Status AddField(const std::shared_ptr<Field>& field);
+
+ /// \brief Add multiple fields to the constructed schema.
+ ///
+ /// \param[in] fields to add to the constructed Schema.
+ /// \return The first failure encountered, if any.
+ Status AddFields(const std::vector<std::shared_ptr<Field>>& fields);
+
+ /// \brief Add fields of a Schema to the constructed Schema.
+ ///
+ /// \param[in] schema to take fields to add to the constructed Schema.
+ /// \return The first failure encountered, if any.
+ Status AddSchema(const std::shared_ptr<Schema>& schema);
+
+ /// \brief Add fields of multiple Schemas to the constructed Schema.
+ ///
+ /// \param[in] schemas to take fields to add to the constructed Schema.
+ /// \return The first failure encountered, if any.
+ Status AddSchemas(const std::vector<std::shared_ptr<Schema>>& schemas);
+
+ Status AddMetadata(const KeyValueMetadata& metadata);
+
+ /// \brief Return the constructed Schema.
+ ///
+ /// The builder internal state is not affected by invoking this method, i.e.
+ /// a single builder can yield multiple incrementally constructed schemas.
+ ///
+ /// \return the constructed schema.
+ Result<std::shared_ptr<Schema>> Finish() const;
+
+ /// \brief Merge schemas in a unified schema according to policy.
+ static Result<std::shared_ptr<Schema>> Merge(
+ const std::vector<std::shared_ptr<Schema>>& schemas,
+ ConflictPolicy policy = CONFLICT_MERGE);
+
+ /// \brief Indicate if schemas are compatible to merge according to policy.
+ static Status AreCompatible(const std::vector<std::shared_ptr<Schema>>& schemas,
+ ConflictPolicy policy = CONFLICT_MERGE);
+
+ /// \brief Reset internal state with an empty schema (and metadata).
+ void Reset();
+
+ ~SchemaBuilder();
+
+ private:
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+
+ Status AppendField(const std::shared_ptr<Field>& field);
+};
+
+/// \brief Unifies schemas by merging fields by name.
+///
+/// The behavior of field merging can be controlled via `Field::MergeOptions`.
+///
+/// The resulting schema will contain the union of fields from all schemas.
+/// Fields with the same name will be merged. See `Field::MergeOptions`.
+/// - They are expected to be mergeable under provided `field_merge_options`.
+/// - The unified field will inherit the metadata from the schema where
+/// that field is first defined.
+/// - The first N fields in the schema will be ordered the same as the
+/// N fields in the first schema.
+/// The resulting schema will inherit its metadata from the first input schema.
+/// Returns an error if:
+/// - Any input schema contains fields with duplicate names.
+/// - Fields of the same name are not mergeable.
+ARROW_EXPORT
+Result<std::shared_ptr<Schema>> UnifySchemas(
+ const std::vector<std::shared_ptr<Schema>>& schemas,
+ Field::MergeOptions field_merge_options = Field::MergeOptions::Defaults());
+
+namespace internal {
+
+static inline bool HasValidityBitmap(Type::type id) {
+ switch (id) {
+ case Type::NA:
+ case Type::DENSE_UNION:
+ case Type::SPARSE_UNION:
+ return false;
+ default:
+ return true;
+ }
+}
+
+ARROW_EXPORT
+std::string ToString(Type::type id);
+
+ARROW_EXPORT
+std::string ToTypeName(Type::type id);
+
+ARROW_EXPORT
+std::string ToString(TimeUnit::type unit);
+
+ARROW_EXPORT
+int GetByteWidth(const DataType& type);
+
+} // namespace internal
+
+// Helpers to get instances of data types based on general categories
+
+ARROW_EXPORT
+const std::vector<std::shared_ptr<DataType>>& SignedIntTypes();
+ARROW_EXPORT
+const std::vector<std::shared_ptr<DataType>>& UnsignedIntTypes();
+ARROW_EXPORT
+const std::vector<std::shared_ptr<DataType>>& IntTypes();
+ARROW_EXPORT
+const std::vector<std::shared_ptr<DataType>>& FloatingPointTypes();
+// Number types without boolean
+ARROW_EXPORT
+const std::vector<std::shared_ptr<DataType>>& NumericTypes();
+// Binary and string-like types (except fixed-size binary)
+ARROW_EXPORT
+const std::vector<std::shared_ptr<DataType>>& BaseBinaryTypes();
+ARROW_EXPORT
+const std::vector<std::shared_ptr<DataType>>& StringTypes();
+// Temporal types including time and timestamps for each unit
+ARROW_EXPORT
+const std::vector<std::shared_ptr<DataType>>& TemporalTypes();
+// Interval types
+ARROW_EXPORT
+const std::vector<std::shared_ptr<DataType>>& IntervalTypes();
+// Integer, floating point, base binary, and temporal
+ARROW_EXPORT
+const std::vector<std::shared_ptr<DataType>>& PrimitiveTypes();
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/type_benchmark.cc b/src/arrow/cpp/src/arrow/type_benchmark.cc
new file mode 100644
index 000000000..de90577ff
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/type_benchmark.cc
@@ -0,0 +1,439 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <exception>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+
+static void TypeEqualsSimple(benchmark::State& state) { // NOLINT non-const reference
+ auto a = uint8();
+ auto b = uint8();
+ auto c = float64();
+
+ int64_t total = 0;
+ for (auto _ : state) {
+ total += a->Equals(*b);
+ total += a->Equals(*c);
+ }
+ benchmark::DoNotOptimize(total);
+ state.SetItemsProcessed(state.iterations() * 2);
+}
+
+static void TypeEqualsComplex(benchmark::State& state) { // NOLINT non-const reference
+ auto fa1 = field("as", list(float16()));
+ auto fa2 = field("as", list(float16()));
+ auto fb1 = field("bs", utf8());
+ auto fb2 = field("bs", utf8());
+ auto fc1 = field("cs", list(fixed_size_binary(10)));
+ auto fc2 = field("cs", list(fixed_size_binary(10)));
+ auto fc3 = field("cs", list(fixed_size_binary(11)));
+
+ auto a = struct_({fa1, fb1, fc1});
+ auto b = struct_({fa2, fb2, fc2});
+ auto c = struct_({fa2, fb2, fc3});
+
+ int64_t total = 0;
+ for (auto _ : state) {
+ total += a->Equals(*b);
+ total += a->Equals(*c);
+ }
+ benchmark::DoNotOptimize(total);
+ state.SetItemsProcessed(state.iterations() * 2);
+}
+
+static void TypeEqualsWithMetadata(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto md1 = key_value_metadata({"k1", "k2"}, {"some value1", "some value2"});
+ auto md2 = key_value_metadata({"k1", "k2"}, {"some value1", "some value2"});
+ auto md3 = key_value_metadata({"k2", "k1"}, {"some value2", "some value1"});
+
+ auto fa1 = field("as", list(float16()));
+ auto fa2 = field("as", list(float16()));
+ auto fb1 = field("bs", utf8(), /*nullable=*/true, md1);
+ auto fb2 = field("bs", utf8(), /*nullable=*/true, md2);
+ auto fb3 = field("bs", utf8(), /*nullable=*/true, md3);
+
+ auto a = struct_({fa1, fb1});
+ auto b = struct_({fa2, fb2});
+ auto c = struct_({fa2, fb3});
+
+ int64_t total = 0;
+ for (auto _ : state) {
+ total += a->Equals(*b);
+ total += a->Equals(*c);
+ }
+ benchmark::DoNotOptimize(total);
+ state.SetItemsProcessed(state.iterations() * 2);
+}
+
+static std::vector<std::shared_ptr<Schema>> SampleSchemas() {
+ auto fa1 = field("as", list(float16()));
+ auto fa2 = field("as", list(float16()));
+ auto fb1 = field("bs", utf8());
+ auto fb2 = field("bs", utf8());
+ auto fc1 = field("cs", list(fixed_size_binary(10)));
+ auto fc2 = field("cs", list(fixed_size_binary(10)));
+ auto fd1 = field("ds", decimal(19, 5));
+ auto fd2 = field("ds", decimal(19, 5));
+ auto fe1 = field("es", map(utf8(), int32()));
+ auto fe2 = field("es", map(utf8(), int32()));
+ auto ff1 = field("fs", dictionary(int8(), binary()));
+ auto ff2 = field("fs", dictionary(int8(), binary()));
+ auto fg1 = field(
+ "gs", struct_({field("A", int8()), field("B", int16()), field("C", float32())}));
+ auto fg2 = field(
+ "gs", struct_({field("A", int8()), field("B", int16()), field("C", float32())}));
+ auto fh1 = field("hs", large_binary());
+ auto fh2 = field("hs", large_binary());
+
+ auto fz1 = field("zs", duration(TimeUnit::MICRO));
+ auto fz2 = field("zs", duration(TimeUnit::MICRO));
+ auto fz3 = field("zs", duration(TimeUnit::NANO));
+
+ auto schema1 = ::arrow::schema({fa1, fb1, fc1, fd1, fe1, ff1, fg1, fh1, fz1});
+ auto schema2 = ::arrow::schema({fa2, fb2, fc2, fd2, fe2, ff2, fg2, fh2, fz2});
+ auto schema3 = ::arrow::schema({fa2, fb2, fc2, fd2, fe2, ff2, fg2, fh2, fz3});
+
+ return {schema1, schema2, schema3};
+}
+
+static void SchemaEquals(benchmark::State& state) { // NOLINT non-const reference
+ auto schemas = SampleSchemas();
+
+ auto schema1 = schemas[0];
+ auto schema2 = schemas[1];
+ auto schema3 = schemas[2];
+
+ int64_t total = 0;
+ for (auto _ : state) {
+ total += schema1->Equals(*schema2, /*check_metadata =*/false);
+ total += schema1->Equals(*schema3, /*check_metadata =*/false);
+ }
+ benchmark::DoNotOptimize(total);
+ state.SetItemsProcessed(state.iterations() * 2);
+}
+
+static void SchemaEqualsWithMetadata(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto schemas = SampleSchemas();
+
+ auto schema1 = schemas[0];
+ auto schema2 = schemas[1];
+ auto schema3 = schemas[2];
+
+ auto md1 = key_value_metadata({"k1", "k2"}, {"some value1", "some value2"});
+ auto md2 = key_value_metadata({"k1", "k2"}, {"some value1", "some value2"});
+ auto md3 = key_value_metadata({"k2", "k1"}, {"some value2", "some value1"});
+
+ schema1 = schema1->WithMetadata(md1);
+ schema2 = schema1->WithMetadata(md2);
+ schema3 = schema1->WithMetadata(md3);
+
+ int64_t total = 0;
+ for (auto _ : state) {
+ total += schema1->Equals(*schema2);
+ total += schema1->Equals(*schema3);
+ }
+ benchmark::DoNotOptimize(total);
+ state.SetItemsProcessed(state.iterations() * 2);
+}
+
+// ------------------------------------------------------------------------
+// Micro-benchmark various error reporting schemes
+
+#if (defined(__GNUC__) || defined(__APPLE__))
+#define ARROW_NO_INLINE __attribute__((noinline))
+#elif defined(_MSC_VER)
+#define ARROW_NO_INLINE __declspec(noinline)
+#else
+#define ARROW_NO_INLINE
+#warning Missing "noinline" attribute, no-inline benchmarks may be bogus
+#endif
+
+inline int64_t Accumulate(int64_t partial, int32_t value) {
+ // Something non-trivial to avoid vectorization
+ return partial + value + (partial >> 5) * value;
+}
+
+std::vector<int32_t> RandomIntegers() {
+ std::default_random_engine gen(42);
+ // Make 42 extremely unlikely (to make error Status allocation negligible)
+ std::uniform_int_distribution<int32_t> dist(0, 100000);
+
+ std::vector<int32_t> integers(6000);
+ std::generate(integers.begin(), integers.end(), [&]() { return dist(gen); });
+ return integers;
+}
+
+inline int32_t NoError(int32_t v) { return v + 1; }
+
+ARROW_NO_INLINE int32_t NoErrorNoInline(int32_t v) { return v + 1; }
+
+inline std::pair<bool, int32_t> ErrorAsBool(int32_t v) {
+ return {ARROW_PREDICT_FALSE(v == 42), v + 1};
+}
+
+ARROW_NO_INLINE std::pair<bool, int32_t> ErrorAsBoolNoInline(int32_t v) {
+ return {ARROW_PREDICT_FALSE(v == 42), v + 1};
+}
+
+inline Status ErrorAsStatus(int32_t v, int32_t* out) {
+ if (ARROW_PREDICT_FALSE(v == 42)) {
+ return Status::Invalid("42");
+ }
+ *out = v + 1;
+ return Status::OK();
+}
+
+ARROW_NO_INLINE Status ErrorAsStatusNoInline(int32_t v, int32_t* out) {
+ if (ARROW_PREDICT_FALSE(v == 42)) {
+ return Status::Invalid("42");
+ }
+ *out = v + 1;
+ return Status::OK();
+}
+
+inline Result<int32_t> ErrorAsResult(int32_t v) {
+ if (ARROW_PREDICT_FALSE(v == 42)) {
+ return Status::Invalid("42");
+ }
+ return v + 1;
+}
+
+ARROW_NO_INLINE Result<int32_t> ErrorAsResultNoInline(int32_t v) {
+ if (ARROW_PREDICT_FALSE(v == 42)) {
+ return Status::Invalid("42");
+ }
+ return v + 1;
+}
+
+inline int32_t ErrorAsException(int32_t v) {
+ if (ARROW_PREDICT_FALSE(v == 42)) {
+ throw std::invalid_argument("42");
+ }
+ return v + 1;
+}
+
+ARROW_NO_INLINE int32_t ErrorAsExceptionNoInline(int32_t v) {
+ if (ARROW_PREDICT_FALSE(v == 42)) {
+ throw std::invalid_argument("42");
+ }
+ return v + 1;
+}
+
+static void ErrorSchemeNoError(benchmark::State& state) { // NOLINT non-const reference
+ auto integers = RandomIntegers();
+
+ for (auto _ : state) {
+ int64_t total = 0;
+ for (const auto v : integers) {
+ total = Accumulate(total, NoError(v));
+ }
+ benchmark::DoNotOptimize(total);
+ }
+
+ state.SetItemsProcessed(state.iterations() * integers.size());
+}
+
+static void ErrorSchemeNoErrorNoInline(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto integers = RandomIntegers();
+
+ for (auto _ : state) {
+ int64_t total = 0;
+ for (const auto v : integers) {
+ total = Accumulate(total, NoErrorNoInline(v));
+ }
+ benchmark::DoNotOptimize(total);
+ }
+
+ state.SetItemsProcessed(state.iterations() * integers.size());
+}
+
+static void ErrorSchemeBool(benchmark::State& state) { // NOLINT non-const reference
+ auto integers = RandomIntegers();
+
+ for (auto _ : state) {
+ int64_t total = 0;
+ for (const auto v : integers) {
+ auto pair = ErrorAsBool(v);
+ if (!ARROW_PREDICT_FALSE(pair.first)) {
+ total = Accumulate(total, pair.second);
+ }
+ }
+ benchmark::DoNotOptimize(total);
+ }
+
+ state.SetItemsProcessed(state.iterations() * integers.size());
+}
+
+static void ErrorSchemeBoolNoInline(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto integers = RandomIntegers();
+
+ for (auto _ : state) {
+ int64_t total = 0;
+ for (const auto v : integers) {
+ auto pair = ErrorAsBoolNoInline(v);
+ if (!ARROW_PREDICT_FALSE(pair.first)) {
+ total = Accumulate(total, pair.second);
+ }
+ }
+ benchmark::DoNotOptimize(total);
+ }
+
+ state.SetItemsProcessed(state.iterations() * integers.size());
+}
+
+static void ErrorSchemeStatus(benchmark::State& state) { // NOLINT non-const reference
+ auto integers = RandomIntegers();
+
+ for (auto _ : state) {
+ int64_t total = 0;
+ for (const auto v : integers) {
+ int32_t value = 0;
+ if (ARROW_PREDICT_TRUE(ErrorAsStatus(v, &value).ok())) {
+ total = Accumulate(total, value);
+ }
+ }
+ benchmark::DoNotOptimize(total);
+ }
+
+ state.SetItemsProcessed(state.iterations() * integers.size());
+}
+
+static void ErrorSchemeStatusNoInline(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto integers = RandomIntegers();
+
+ for (auto _ : state) {
+ int64_t total = 0;
+ for (const auto v : integers) {
+ int32_t value;
+ if (ARROW_PREDICT_TRUE(ErrorAsStatusNoInline(v, &value).ok())) {
+ total = Accumulate(total, value);
+ }
+ }
+ benchmark::DoNotOptimize(total);
+ }
+
+ state.SetItemsProcessed(state.iterations() * integers.size());
+}
+
+static void ErrorSchemeResult(benchmark::State& state) { // NOLINT non-const reference
+ auto integers = RandomIntegers();
+
+ for (auto _ : state) {
+ int64_t total = 0;
+ for (const auto v : integers) {
+ auto maybe_value = ErrorAsResult(v);
+ if (ARROW_PREDICT_TRUE(maybe_value.ok())) {
+ total = Accumulate(total, *std::move(maybe_value));
+ }
+ }
+ benchmark::DoNotOptimize(total);
+ }
+
+ state.SetItemsProcessed(state.iterations() * integers.size());
+}
+
+static void ErrorSchemeResultNoInline(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto integers = RandomIntegers();
+
+ for (auto _ : state) {
+ int64_t total = 0;
+ for (const auto v : integers) {
+ auto maybe_value = ErrorAsResultNoInline(v);
+ if (ARROW_PREDICT_TRUE(maybe_value.ok())) {
+ total = Accumulate(total, *std::move(maybe_value));
+ }
+ }
+ benchmark::DoNotOptimize(total);
+ }
+
+ state.SetItemsProcessed(state.iterations() * integers.size());
+}
+
+static void ErrorSchemeException(benchmark::State& state) { // NOLINT non-const reference
+ auto integers = RandomIntegers();
+
+ for (auto _ : state) {
+ int64_t total = 0;
+ for (const auto v : integers) {
+ try {
+ total = Accumulate(total, ErrorAsException(v));
+ } catch (const std::exception&) {
+ }
+ }
+ benchmark::DoNotOptimize(total);
+ }
+
+ state.SetItemsProcessed(state.iterations() * integers.size());
+}
+
+static void ErrorSchemeExceptionNoInline(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto integers = RandomIntegers();
+
+ for (auto _ : state) {
+ int64_t total = 0;
+ for (const auto v : integers) {
+ try {
+ total = Accumulate(total, ErrorAsExceptionNoInline(v));
+ } catch (const std::exception&) {
+ }
+ }
+ benchmark::DoNotOptimize(total);
+ }
+
+ state.SetItemsProcessed(state.iterations() * integers.size());
+}
+
+BENCHMARK(TypeEqualsSimple);
+BENCHMARK(TypeEqualsComplex);
+BENCHMARK(TypeEqualsWithMetadata);
+BENCHMARK(SchemaEquals);
+BENCHMARK(SchemaEqualsWithMetadata);
+
+BENCHMARK(ErrorSchemeNoError);
+BENCHMARK(ErrorSchemeBool);
+BENCHMARK(ErrorSchemeStatus);
+BENCHMARK(ErrorSchemeResult);
+BENCHMARK(ErrorSchemeException);
+
+BENCHMARK(ErrorSchemeNoErrorNoInline);
+BENCHMARK(ErrorSchemeBoolNoInline);
+BENCHMARK(ErrorSchemeStatusNoInline);
+BENCHMARK(ErrorSchemeResultNoInline);
+BENCHMARK(ErrorSchemeExceptionNoInline);
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/type_fwd.h b/src/arrow/cpp/src/arrow/type_fwd.h
new file mode 100644
index 000000000..45afd7af2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/type_fwd.h
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+template <typename T>
+class Iterator;
+template <typename T>
+struct IterationTraits;
+
+template <typename T>
+class Result;
+
+class Status;
+
+namespace internal {
+struct Empty;
+} // namespace internal
+template <typename T = internal::Empty>
+class Future;
+
+namespace util {
+class Codec;
+} // namespace util
+
+class Buffer;
+class Device;
+class MemoryManager;
+class MemoryPool;
+class MutableBuffer;
+class ResizableBuffer;
+
+using BufferVector = std::vector<std::shared_ptr<Buffer>>;
+
+class DataType;
+class Field;
+class FieldRef;
+class KeyValueMetadata;
+enum class Endianness;
+class Schema;
+
+using DataTypeVector = std::vector<std::shared_ptr<DataType>>;
+using FieldVector = std::vector<std::shared_ptr<Field>>;
+
+class Array;
+struct ArrayData;
+class ArrayBuilder;
+struct Scalar;
+
+using ArrayDataVector = std::vector<std::shared_ptr<ArrayData>>;
+using ArrayVector = std::vector<std::shared_ptr<Array>>;
+using ScalarVector = std::vector<std::shared_ptr<Scalar>>;
+
+class ChunkedArray;
+class RecordBatch;
+class RecordBatchReader;
+class Table;
+
+struct Datum;
+struct ValueDescr;
+
+using ChunkedArrayVector = std::vector<std::shared_ptr<ChunkedArray>>;
+using RecordBatchVector = std::vector<std::shared_ptr<RecordBatch>>;
+using RecordBatchIterator = Iterator<std::shared_ptr<RecordBatch>>;
+
+class DictionaryType;
+class DictionaryArray;
+struct DictionaryScalar;
+
+class NullType;
+class NullArray;
+class NullBuilder;
+struct NullScalar;
+
+class FixedWidthType;
+
+class BooleanType;
+class BooleanArray;
+class BooleanBuilder;
+struct BooleanScalar;
+
+class BinaryType;
+class BinaryArray;
+class BinaryBuilder;
+struct BinaryScalar;
+
+class LargeBinaryType;
+class LargeBinaryArray;
+class LargeBinaryBuilder;
+struct LargeBinaryScalar;
+
+class FixedSizeBinaryType;
+class FixedSizeBinaryArray;
+class FixedSizeBinaryBuilder;
+struct FixedSizeBinaryScalar;
+
+class StringType;
+class StringArray;
+class StringBuilder;
+struct StringScalar;
+
+class LargeStringType;
+class LargeStringArray;
+class LargeStringBuilder;
+struct LargeStringScalar;
+
+class ListType;
+class ListArray;
+class ListBuilder;
+struct ListScalar;
+
+class LargeListType;
+class LargeListArray;
+class LargeListBuilder;
+struct LargeListScalar;
+
+class MapType;
+class MapArray;
+class MapBuilder;
+struct MapScalar;
+
+class FixedSizeListType;
+class FixedSizeListArray;
+class FixedSizeListBuilder;
+struct FixedSizeListScalar;
+
+class StructType;
+class StructArray;
+class StructBuilder;
+struct StructScalar;
+
+class Decimal128;
+class Decimal256;
+class DecimalType;
+class Decimal128Type;
+class Decimal256Type;
+class Decimal128Array;
+class Decimal256Array;
+class Decimal128Builder;
+class Decimal256Builder;
+struct Decimal128Scalar;
+struct Decimal256Scalar;
+
+struct UnionMode {
+ enum type { SPARSE, DENSE };
+};
+
+class SparseUnionType;
+class SparseUnionArray;
+class SparseUnionBuilder;
+struct SparseUnionScalar;
+
+class DenseUnionType;
+class DenseUnionArray;
+class DenseUnionBuilder;
+struct DenseUnionScalar;
+
+template <typename TypeClass>
+class NumericArray;
+
+template <typename TypeClass>
+class NumericBuilder;
+
+template <typename TypeClass>
+class NumericTensor;
+
+#define _NUMERIC_TYPE_DECL(KLASS) \
+ class KLASS##Type; \
+ using KLASS##Array = NumericArray<KLASS##Type>; \
+ using KLASS##Builder = NumericBuilder<KLASS##Type>; \
+ struct KLASS##Scalar; \
+ using KLASS##Tensor = NumericTensor<KLASS##Type>;
+
+_NUMERIC_TYPE_DECL(Int8)
+_NUMERIC_TYPE_DECL(Int16)
+_NUMERIC_TYPE_DECL(Int32)
+_NUMERIC_TYPE_DECL(Int64)
+_NUMERIC_TYPE_DECL(UInt8)
+_NUMERIC_TYPE_DECL(UInt16)
+_NUMERIC_TYPE_DECL(UInt32)
+_NUMERIC_TYPE_DECL(UInt64)
+_NUMERIC_TYPE_DECL(HalfFloat)
+_NUMERIC_TYPE_DECL(Float)
+_NUMERIC_TYPE_DECL(Double)
+
+#undef _NUMERIC_TYPE_DECL
+
+enum class DateUnit : char { DAY = 0, MILLI = 1 };
+
+class DateType;
+class Date32Type;
+using Date32Array = NumericArray<Date32Type>;
+using Date32Builder = NumericBuilder<Date32Type>;
+struct Date32Scalar;
+
+class Date64Type;
+using Date64Array = NumericArray<Date64Type>;
+using Date64Builder = NumericBuilder<Date64Type>;
+struct Date64Scalar;
+
+struct ARROW_EXPORT TimeUnit {
+ /// The unit for a time or timestamp DataType
+ enum type { SECOND = 0, MILLI = 1, MICRO = 2, NANO = 3 };
+
+ /// Iterate over all valid time units
+ static const std::vector<TimeUnit::type>& values();
+};
+
+class TimeType;
+class Time32Type;
+using Time32Array = NumericArray<Time32Type>;
+using Time32Builder = NumericBuilder<Time32Type>;
+struct Time32Scalar;
+
+class Time64Type;
+using Time64Array = NumericArray<Time64Type>;
+using Time64Builder = NumericBuilder<Time64Type>;
+struct Time64Scalar;
+
+class TimestampType;
+using TimestampArray = NumericArray<TimestampType>;
+using TimestampBuilder = NumericBuilder<TimestampType>;
+struct TimestampScalar;
+
+class MonthIntervalType;
+using MonthIntervalArray = NumericArray<MonthIntervalType>;
+using MonthIntervalBuilder = NumericBuilder<MonthIntervalType>;
+struct MonthIntervalScalar;
+
+class DayTimeIntervalType;
+class DayTimeIntervalArray;
+class DayTimeIntervalBuilder;
+struct DayTimeIntervalScalar;
+
+class MonthDayNanoIntervalType;
+class MonthDayNanoIntervalArray;
+class MonthDayNanoIntervalBuilder;
+struct MonthDayNanoIntervalScalar;
+
+class DurationType;
+using DurationArray = NumericArray<DurationType>;
+using DurationBuilder = NumericBuilder<DurationType>;
+struct DurationScalar;
+
+class ExtensionType;
+class ExtensionArray;
+struct ExtensionScalar;
+
+class Tensor;
+class SparseTensor;
+
+// ----------------------------------------------------------------------
+
+struct Type {
+ /// \brief Main data type enumeration
+ ///
+ /// This enumeration provides a quick way to interrogate the category
+ /// of a DataType instance.
+ enum type {
+ /// A NULL type having no physical storage
+ NA = 0,
+
+ /// Boolean as 1 bit, LSB bit-packed ordering
+ BOOL,
+
+ /// Unsigned 8-bit little-endian integer
+ UINT8,
+
+ /// Signed 8-bit little-endian integer
+ INT8,
+
+ /// Unsigned 16-bit little-endian integer
+ UINT16,
+
+ /// Signed 16-bit little-endian integer
+ INT16,
+
+ /// Unsigned 32-bit little-endian integer
+ UINT32,
+
+ /// Signed 32-bit little-endian integer
+ INT32,
+
+ /// Unsigned 64-bit little-endian integer
+ UINT64,
+
+ /// Signed 64-bit little-endian integer
+ INT64,
+
+ /// 2-byte floating point value
+ HALF_FLOAT,
+
+ /// 4-byte floating point value
+ FLOAT,
+
+ /// 8-byte floating point value
+ DOUBLE,
+
+ /// UTF8 variable-length string as List<Char>
+ STRING,
+
+ /// Variable-length bytes (no guarantee of UTF8-ness)
+ BINARY,
+
+ /// Fixed-size binary. Each value occupies the same number of bytes
+ FIXED_SIZE_BINARY,
+
+ /// int32_t days since the UNIX epoch
+ DATE32,
+
+ /// int64_t milliseconds since the UNIX epoch
+ DATE64,
+
+ /// Exact timestamp encoded with int64 since UNIX epoch
+ /// Default unit millisecond
+ TIMESTAMP,
+
+ /// Time as signed 32-bit integer, representing either seconds or
+ /// milliseconds since midnight
+ TIME32,
+
+ /// Time as signed 64-bit integer, representing either microseconds or
+ /// nanoseconds since midnight
+ TIME64,
+
+ /// YEAR_MONTH interval in SQL style
+ INTERVAL_MONTHS,
+
+ /// DAY_TIME interval in SQL style
+ INTERVAL_DAY_TIME,
+
+ /// Precision- and scale-based decimal type with 128 bits.
+ DECIMAL128,
+
+ /// Defined for backward-compatibility.
+ DECIMAL = DECIMAL128,
+
+ /// Precision- and scale-based decimal type with 256 bits.
+ DECIMAL256,
+
+ /// A list of some logical data type
+ LIST,
+
+ /// Struct of logical types
+ STRUCT,
+
+ /// Sparse unions of logical types
+ SPARSE_UNION,
+
+ /// Dense unions of logical types
+ DENSE_UNION,
+
+ /// Dictionary-encoded type, also called "categorical" or "factor"
+ /// in other programming languages. Holds the dictionary value
+ /// type but not the dictionary itself, which is part of the
+ /// ArrayData struct
+ DICTIONARY,
+
+ /// Map, a repeated struct logical type
+ MAP,
+
+ /// Custom data type, implemented by user
+ EXTENSION,
+
+ /// Fixed size list of some logical type
+ FIXED_SIZE_LIST,
+
+ /// Measure of elapsed time in either seconds, milliseconds, microseconds
+ /// or nanoseconds.
+ DURATION,
+
+ /// Like STRING, but with 64-bit offsets
+ LARGE_STRING,
+
+ /// Like BINARY, but with 64-bit offsets
+ LARGE_BINARY,
+
+ /// Like LIST, but with 64-bit offsets
+ LARGE_LIST,
+
+ /// Calendar interval type with three fields.
+ INTERVAL_MONTH_DAY_NANO,
+
+ // Leave this at the end
+ MAX_ID
+ };
+};
+
+/// \defgroup type-factories Factory functions for creating data types
+///
+/// Factory functions for creating data types
+/// @{
+
+/// \brief Return a NullType instance
+std::shared_ptr<DataType> ARROW_EXPORT null();
+/// \brief Return a BooleanType instance
+std::shared_ptr<DataType> ARROW_EXPORT boolean();
+/// \brief Return a Int8Type instance
+std::shared_ptr<DataType> ARROW_EXPORT int8();
+/// \brief Return a Int16Type instance
+std::shared_ptr<DataType> ARROW_EXPORT int16();
+/// \brief Return a Int32Type instance
+std::shared_ptr<DataType> ARROW_EXPORT int32();
+/// \brief Return a Int64Type instance
+std::shared_ptr<DataType> ARROW_EXPORT int64();
+/// \brief Return a UInt8Type instance
+std::shared_ptr<DataType> ARROW_EXPORT uint8();
+/// \brief Return a UInt16Type instance
+std::shared_ptr<DataType> ARROW_EXPORT uint16();
+/// \brief Return a UInt32Type instance
+std::shared_ptr<DataType> ARROW_EXPORT uint32();
+/// \brief Return a UInt64Type instance
+std::shared_ptr<DataType> ARROW_EXPORT uint64();
+/// \brief Return a HalfFloatType instance
+std::shared_ptr<DataType> ARROW_EXPORT float16();
+/// \brief Return a FloatType instance
+std::shared_ptr<DataType> ARROW_EXPORT float32();
+/// \brief Return a DoubleType instance
+std::shared_ptr<DataType> ARROW_EXPORT float64();
+/// \brief Return a StringType instance
+std::shared_ptr<DataType> ARROW_EXPORT utf8();
+/// \brief Return a LargeStringType instance
+std::shared_ptr<DataType> ARROW_EXPORT large_utf8();
+/// \brief Return a BinaryType instance
+std::shared_ptr<DataType> ARROW_EXPORT binary();
+/// \brief Return a LargeBinaryType instance
+std::shared_ptr<DataType> ARROW_EXPORT large_binary();
+/// \brief Return a Date32Type instance
+std::shared_ptr<DataType> ARROW_EXPORT date32();
+/// \brief Return a Date64Type instance
+std::shared_ptr<DataType> ARROW_EXPORT date64();
+
+/// \brief Create a FixedSizeBinaryType instance.
+ARROW_EXPORT
+std::shared_ptr<DataType> fixed_size_binary(int32_t byte_width);
+
+/// \brief Create a DecimalType instance depending on the precision
+///
+/// If the precision is greater than 38, a Decimal256Type is returned,
+/// otherwise a Decimal128Type.
+ARROW_EXPORT
+std::shared_ptr<DataType> decimal(int32_t precision, int32_t scale);
+
+/// \brief Create a Decimal128Type instance
+ARROW_EXPORT
+std::shared_ptr<DataType> decimal128(int32_t precision, int32_t scale);
+
+/// \brief Create a Decimal256Type instance
+ARROW_EXPORT
+std::shared_ptr<DataType> decimal256(int32_t precision, int32_t scale);
+
+/// \brief Create a ListType instance from its child Field type
+ARROW_EXPORT
+std::shared_ptr<DataType> list(const std::shared_ptr<Field>& value_type);
+
+/// \brief Create a ListType instance from its child DataType
+ARROW_EXPORT
+std::shared_ptr<DataType> list(const std::shared_ptr<DataType>& value_type);
+
+/// \brief Create a LargeListType instance from its child Field type
+ARROW_EXPORT
+std::shared_ptr<DataType> large_list(const std::shared_ptr<Field>& value_type);
+
+/// \brief Create a LargeListType instance from its child DataType
+ARROW_EXPORT
+std::shared_ptr<DataType> large_list(const std::shared_ptr<DataType>& value_type);
+
+/// \brief Create a MapType instance from its key and value DataTypes
+ARROW_EXPORT
+std::shared_ptr<DataType> map(std::shared_ptr<DataType> key_type,
+ std::shared_ptr<DataType> item_type,
+ bool keys_sorted = false);
+
+/// \brief Create a MapType instance from its key DataType and value field.
+///
+/// The field override is provided to communicate nullability of the value.
+ARROW_EXPORT
+std::shared_ptr<DataType> map(std::shared_ptr<DataType> key_type,
+ std::shared_ptr<Field> item_field,
+ bool keys_sorted = false);
+
+/// \brief Create a FixedSizeListType instance from its child Field type
+ARROW_EXPORT
+std::shared_ptr<DataType> fixed_size_list(const std::shared_ptr<Field>& value_type,
+ int32_t list_size);
+
+/// \brief Create a FixedSizeListType instance from its child DataType
+ARROW_EXPORT
+std::shared_ptr<DataType> fixed_size_list(const std::shared_ptr<DataType>& value_type,
+ int32_t list_size);
+/// \brief Return a Duration instance (naming use _type to avoid namespace conflict with
+/// built in time classes).
+std::shared_ptr<DataType> ARROW_EXPORT duration(TimeUnit::type unit);
+
+/// \brief Return a DayTimeIntervalType instance
+std::shared_ptr<DataType> ARROW_EXPORT day_time_interval();
+
+/// \brief Return a MonthIntervalType instance
+std::shared_ptr<DataType> ARROW_EXPORT month_interval();
+
+/// \brief Return a MonthDayNanoIntervalType instance
+std::shared_ptr<DataType> ARROW_EXPORT month_day_nano_interval();
+
+/// \brief Create a TimestampType instance from its unit
+ARROW_EXPORT
+std::shared_ptr<DataType> timestamp(TimeUnit::type unit);
+
+/// \brief Create a TimestampType instance from its unit and timezone
+ARROW_EXPORT
+std::shared_ptr<DataType> timestamp(TimeUnit::type unit, const std::string& timezone);
+
+/// \brief Create a 32-bit time type instance
+///
+/// Unit can be either SECOND or MILLI
+std::shared_ptr<DataType> ARROW_EXPORT time32(TimeUnit::type unit);
+
+/// \brief Create a 64-bit time type instance
+///
+/// Unit can be either MICRO or NANO
+std::shared_ptr<DataType> ARROW_EXPORT time64(TimeUnit::type unit);
+
+/// \brief Create a StructType instance
+std::shared_ptr<DataType> ARROW_EXPORT
+struct_(const std::vector<std::shared_ptr<Field>>& fields);
+
+/// \brief Create a SparseUnionType instance
+std::shared_ptr<DataType> ARROW_EXPORT sparse_union(FieldVector child_fields,
+ std::vector<int8_t> type_codes = {});
+/// \brief Create a SparseUnionType instance
+std::shared_ptr<DataType> ARROW_EXPORT
+sparse_union(const ArrayVector& children, std::vector<std::string> field_names = {},
+ std::vector<int8_t> type_codes = {});
+
+/// \brief Create a DenseUnionType instance
+std::shared_ptr<DataType> ARROW_EXPORT dense_union(FieldVector child_fields,
+ std::vector<int8_t> type_codes = {});
+/// \brief Create a DenseUnionType instance
+std::shared_ptr<DataType> ARROW_EXPORT
+dense_union(const ArrayVector& children, std::vector<std::string> field_names = {},
+ std::vector<int8_t> type_codes = {});
+
+/// \brief Create a DictionaryType instance
+/// \param[in] index_type the type of the dictionary indices (must be
+/// a signed integer)
+/// \param[in] dict_type the type of the values in the variable dictionary
+/// \param[in] ordered true if the order of the dictionary values has
+/// semantic meaning and should be preserved where possible
+ARROW_EXPORT
+std::shared_ptr<DataType> dictionary(const std::shared_ptr<DataType>& index_type,
+ const std::shared_ptr<DataType>& dict_type,
+ bool ordered = false);
+
+/// @}
+
+/// \defgroup schema-factories Factory functions for fields and schemas
+///
+/// Factory functions for fields and schemas
+/// @{
+
+/// \brief Create a Field instance
+///
+/// \param name the field name
+/// \param type the field value type
+/// \param nullable whether the values are nullable, default true
+/// \param metadata any custom key-value metadata, default null
+std::shared_ptr<Field> ARROW_EXPORT
+field(std::string name, std::shared_ptr<DataType> type, bool nullable = true,
+ std::shared_ptr<const KeyValueMetadata> metadata = NULLPTR);
+
+/// \brief Create a Field instance with metadata
+///
+/// The field will be assumed to be nullable.
+///
+/// \param name the field name
+/// \param type the field value type
+/// \param metadata any custom key-value metadata
+std::shared_ptr<Field> ARROW_EXPORT
+field(std::string name, std::shared_ptr<DataType> type,
+ std::shared_ptr<const KeyValueMetadata> metadata);
+
+/// \brief Create a Schema instance
+///
+/// \param fields the schema's fields
+/// \param metadata any custom key-value metadata, default null
+/// \return schema shared_ptr to Schema
+ARROW_EXPORT
+std::shared_ptr<Schema> schema(
+ std::vector<std::shared_ptr<Field>> fields,
+ std::shared_ptr<const KeyValueMetadata> metadata = NULLPTR);
+
+/// \brief Create a Schema instance
+///
+/// \param fields the schema's fields
+/// \param endianness the endianness of the data
+/// \param metadata any custom key-value metadata, default null
+/// \return schema shared_ptr to Schema
+ARROW_EXPORT
+std::shared_ptr<Schema> schema(
+ std::vector<std::shared_ptr<Field>> fields, Endianness endianness,
+ std::shared_ptr<const KeyValueMetadata> metadata = NULLPTR);
+
+/// @}
+
+/// Return the process-wide default memory pool.
+ARROW_EXPORT MemoryPool* default_memory_pool();
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/type_test.cc b/src/arrow/cpp/src/arrow/type_test.cc
new file mode 100644
index 000000000..f8294fc6d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/type_test.cc
@@ -0,0 +1,1792 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Unit tests for DataType (and subclasses), Field, and Schema
+
+#include <algorithm>
+#include <cctype>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include <gmock/gmock.h>
+
+#include "arrow/memory_pool.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+
+using testing::ElementsAre;
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+TEST(TestTypeId, AllTypeIds) {
+ const auto all_ids = AllTypeIds();
+ ASSERT_EQ(static_cast<int>(all_ids.size()), Type::MAX_ID);
+}
+
+template <typename ReprFunc>
+void CheckTypeIdReprs(ReprFunc&& repr_func, bool expect_uppercase) {
+ std::unordered_set<std::string> unique_reprs;
+ const auto all_ids = AllTypeIds();
+ for (const auto id : all_ids) {
+ std::string repr = repr_func(id);
+ ASSERT_TRUE(std::all_of(repr.begin(), repr.end(),
+ [=](const char c) {
+ return c == '_' || std::isdigit(c) ||
+ (expect_uppercase ? std::isupper(c)
+ : std::islower(c));
+ }))
+ << "Invalid type id repr: '" << repr << "'";
+ unique_reprs.insert(std::move(repr));
+ }
+ // No duplicates
+ ASSERT_EQ(unique_reprs.size(), all_ids.size());
+}
+
+TEST(TestTypeId, ToString) {
+ // Should be all uppercase strings (corresponding to the enum member names)
+ CheckTypeIdReprs([](Type::type id) { return internal::ToString(id); },
+ /* expect_uppercase=*/true);
+}
+
+TEST(TestTypeId, ToTypeName) {
+ // Should be all lowercase strings (corresponding to TypeClass::type_name())
+ CheckTypeIdReprs([](Type::type id) { return internal::ToTypeName(id); },
+ /* expect_uppercase=*/false);
+}
+
+TEST(TestField, Basics) {
+ Field f0("f0", int32());
+ Field f0_nn("f0", int32(), false);
+
+ ASSERT_EQ(f0.name(), "f0");
+ ASSERT_EQ(f0.type()->ToString(), int32()->ToString());
+
+ ASSERT_TRUE(f0.nullable());
+ ASSERT_FALSE(f0_nn.nullable());
+}
+
+TEST(TestField, ToString) {
+ auto metadata = key_value_metadata({"foo", "bar"}, {"bizz", "buzz"});
+ auto f0 = field("f0", int32(), false, metadata);
+
+ std::string result = f0->ToString(/*print_metadata=*/true);
+ std::string expected = R"(f0: int32 not null
+-- metadata --
+foo: bizz
+bar: buzz)";
+ ASSERT_EQ(expected, result);
+
+ result = f0->ToString();
+ expected = "f0: int32 not null";
+ ASSERT_EQ(expected, result);
+}
+
+TEST(TestField, Equals) {
+ auto meta1 = key_value_metadata({{"a", "1"}, {"b", "2"}});
+ // Different from meta1
+ auto meta2 = key_value_metadata({{"a", "1"}, {"b", "3"}});
+ // Equal to meta1, though in different order
+ auto meta3 = key_value_metadata({{"b", "2"}, {"a", "1"}});
+
+ Field f0("f0", int32());
+ Field f0_nn("f0", int32(), false);
+ Field f0_other("f0", int32());
+ Field f0_with_meta1("f0", int32(), true, meta1);
+ Field f0_with_meta2("f0", int32(), true, meta2);
+ Field f0_with_meta3("f0", int32(), true, meta3);
+
+ AssertFieldEqual(f0, f0_other);
+ AssertFieldNotEqual(f0, f0_nn);
+ AssertFieldNotEqual(f0, f0_with_meta1, /*check_metadata=*/true);
+ AssertFieldNotEqual(f0_with_meta1, f0_with_meta2, /*check_metadata=*/true);
+ AssertFieldEqual(f0_with_meta1, f0_with_meta3, /*check_metadata=*/true);
+
+ AssertFieldEqual(f0, f0_with_meta1);
+ AssertFieldEqual(f0, f0_with_meta2);
+ AssertFieldEqual(f0_with_meta1, f0_with_meta2);
+}
+
+#define ASSERT_COMPATIBLE_IMPL(NAME, TYPE, PLURAL) \
+ void Assert##NAME##Compatible(const TYPE& left, const TYPE& right) { \
+ ASSERT_TRUE(left.IsCompatibleWith(right)) \
+ << PLURAL << left.ToString() << "' and '" << right.ToString() \
+ << "' should be compatible"; \
+ } \
+ \
+ void Assert##NAME##Compatible(const std::shared_ptr<TYPE>& left, \
+ const std::shared_ptr<TYPE>& right) { \
+ ASSERT_NE(left, nullptr); \
+ ASSERT_NE(right, nullptr); \
+ Assert##NAME##Compatible(*left, *right); \
+ } \
+ \
+ void Assert##NAME##NotCompatible(const TYPE& left, const TYPE& right) { \
+ ASSERT_FALSE(left.IsCompatibleWith(right)) \
+ << PLURAL << left.ToString() << "' and '" << right.ToString() \
+ << "' should not be compatible"; \
+ } \
+ \
+ void Assert##NAME##NotCompatible(const std::shared_ptr<TYPE>& left, \
+ const std::shared_ptr<TYPE>& right) { \
+ ASSERT_NE(left, nullptr); \
+ ASSERT_NE(right, nullptr); \
+ Assert##NAME##NotCompatible(*left, *right); \
+ }
+
+ASSERT_COMPATIBLE_IMPL(Field, Field, "fields")
+#undef ASSERT_COMPATIBLE_IMPL
+
+TEST(TestField, IsCompatibleWith) {
+ auto meta1 = key_value_metadata({{"a", "1"}, {"b", "2"}});
+ // Different from meta1
+ auto meta2 = key_value_metadata({{"a", "1"}, {"b", "3"}});
+ // Equal to meta1, though in different order
+ auto meta3 = key_value_metadata({{"b", "2"}, {"a", "1"}});
+
+ Field f0("f0", int32());
+ Field f0_nn("f0", int32(), false);
+ Field f0_nt("f0", null());
+ Field f0_other("f0", int32());
+ Field f0_with_meta1("f0", int32(), true, meta1);
+ Field f0_with_meta2("f0", int32(), true, meta2);
+ Field f0_with_meta3("f0", int32(), true, meta3);
+ Field other("other", int64());
+
+ AssertFieldCompatible(f0, f0_other);
+ AssertFieldCompatible(f0, f0_with_meta1);
+ AssertFieldCompatible(f0, f0_nn);
+ AssertFieldCompatible(f0, f0_nt);
+ AssertFieldCompatible(f0_nt, f0_with_meta1);
+ AssertFieldCompatible(f0_with_meta1, f0_with_meta2);
+ AssertFieldCompatible(f0_with_meta1, f0_with_meta3);
+ AssertFieldNotCompatible(f0, other);
+}
+
+TEST(TestField, TestMetadataConstruction) {
+ auto metadata = key_value_metadata({"foo", "bar"}, {"bizz", "buzz"});
+ auto metadata2 = metadata->Copy();
+ auto f0 = field("f0", int32(), true, metadata);
+ auto f1 = field("f0", int32(), true, metadata2);
+ ASSERT_TRUE(metadata->Equals(*f0->metadata()));
+ AssertFieldEqual(f0, f1);
+}
+
+TEST(TestField, TestWithMetadata) {
+ auto metadata = key_value_metadata({"foo", "bar"}, {"bizz", "buzz"});
+ auto f0 = field("f0", int32());
+ auto f1 = field("f0", int32(), true, metadata);
+ std::shared_ptr<Field> f2 = f0->WithMetadata(metadata);
+
+ AssertFieldEqual(f0, f2);
+ AssertFieldNotEqual(f0, f2, /*check_metadata=*/true);
+
+ AssertFieldEqual(f1, f2);
+ AssertFieldEqual(f1, f2, /*check_metadata=*/true);
+
+ // Ensure pointer equality for zero-copy
+ ASSERT_EQ(metadata.get(), f1->metadata().get());
+}
+
+TEST(TestField, TestWithMergedMetadata) {
+ auto metadata = key_value_metadata({"foo", "bar"}, {"bizz", "buzz"});
+ auto f0 = field("f0", int32(), true, metadata);
+ auto f1 = field("f0", int32());
+
+ auto metadata2 = key_value_metadata({"bar", "baz"}, {"bozz", "bazz"});
+
+ auto f2 = f0->WithMergedMetadata(metadata2);
+ auto expected = field("f0", int32(), true, metadata->Merge(*metadata2));
+ AssertFieldEqual(expected, f2);
+
+ auto f3 = f1->WithMergedMetadata(metadata2);
+ expected = field("f0", int32(), true, metadata2);
+ AssertFieldEqual(expected, f3);
+}
+
+TEST(TestField, TestRemoveMetadata) {
+ auto metadata = key_value_metadata({"foo", "bar"}, {"bizz", "buzz"});
+ auto f0 = field("f0", int32());
+ auto f1 = field("f0", int32(), true, metadata);
+ std::shared_ptr<Field> f2 = f1->RemoveMetadata();
+ ASSERT_EQ(f2->metadata(), nullptr);
+}
+
+TEST(TestField, TestEmptyMetadata) {
+ // Empty metadata should be equivalent to no metadata at all
+ auto metadata1 = key_value_metadata({});
+ auto metadata2 = key_value_metadata({"foo"}, {"foo value"});
+
+ auto f0 = field("f0", int32());
+ auto f1 = field("f0", int32(), true, metadata1);
+ auto f2 = field("f0", int32(), true, metadata2);
+
+ AssertFieldEqual(f0, f1);
+ AssertFieldEqual(f0, f2);
+ AssertFieldEqual(f0, f1, /*check_metadata =*/true);
+ AssertFieldNotEqual(f0, f2, /*check_metadata =*/true);
+}
+
+TEST(TestField, TestFlatten) {
+ auto metadata = key_value_metadata({"foo", "bar"}, {"bizz", "buzz"});
+ auto f0 = field("f0", int32(), true /* nullable */, metadata);
+ auto vec = f0->Flatten();
+ ASSERT_EQ(vec.size(), 1);
+ AssertFieldEqual(vec[0], f0);
+
+ auto f1 = field("f1", float64(), false /* nullable */);
+ auto ff = field("nest", struct_({f0, f1}));
+ vec = ff->Flatten();
+ ASSERT_EQ(vec.size(), 2);
+ auto expected0 = field("nest.f0", int32(), true /* nullable */, metadata);
+ // nullable parent implies nullable flattened child
+ auto expected1 = field("nest.f1", float64(), true /* nullable */);
+ AssertFieldEqual(vec[0], expected0);
+ AssertFieldEqual(vec[1], expected1);
+
+ ff = field("nest", struct_({f0, f1}), false /* nullable */);
+ vec = ff->Flatten();
+ ASSERT_EQ(vec.size(), 2);
+ expected0 = field("nest.f0", int32(), true /* nullable */, metadata);
+ expected1 = field("nest.f1", float64(), false /* nullable */);
+ AssertFieldEqual(vec[0], expected0);
+ AssertFieldEqual(vec[1], expected1);
+}
+
+TEST(TestField, TestReplacement) {
+ auto metadata = key_value_metadata({"foo", "bar"}, {"bizz", "buzz"});
+ auto f0 = field("f0", int32(), true, metadata);
+ auto fzero = f0->WithType(utf8());
+ auto f1 = f0->WithName("f1");
+
+ AssertFieldNotEqual(f0, fzero);
+ AssertFieldNotCompatible(f0, fzero);
+ AssertFieldNotEqual(fzero, f1);
+ AssertFieldNotCompatible(fzero, f1);
+ AssertFieldNotEqual(f1, f0);
+ AssertFieldNotCompatible(f1, f0);
+
+ ASSERT_EQ(fzero->name(), "f0");
+ AssertTypeEqual(fzero->type(), utf8());
+ ASSERT_TRUE(fzero->metadata()->Equals(*metadata));
+
+ ASSERT_EQ(f1->name(), "f1");
+ AssertTypeEqual(f1->type(), int32());
+ ASSERT_TRUE(f1->metadata()->Equals(*metadata));
+}
+
+TEST(TestField, TestMerge) {
+ auto metadata1 = key_value_metadata({"foo"}, {"v"});
+ auto metadata2 = key_value_metadata({"bar"}, {"v"});
+ {
+ // different name.
+ ASSERT_RAISES(Invalid, field("f0", int32())->MergeWith(field("f1", int32())));
+ }
+ {
+ // Same type.
+ auto f1 = field("f", int32())->WithMetadata(metadata1);
+ auto f2 = field("f", int32())->WithMetadata(metadata2);
+ std::shared_ptr<Field> result;
+ ASSERT_OK_AND_ASSIGN(result, f1->MergeWith(f2));
+ ASSERT_TRUE(result->Equals(f1));
+ ASSERT_OK_AND_ASSIGN(result, f2->MergeWith(f1));
+ ASSERT_TRUE(result->Equals(f2));
+ }
+ {
+ // promote_nullability == false
+ auto f = field("f", int32());
+ auto null_field = field("f", null());
+ Field::MergeOptions options;
+ options.promote_nullability = false;
+ ASSERT_RAISES(Invalid, f->MergeWith(null_field, options));
+ ASSERT_RAISES(Invalid, null_field->MergeWith(f, options));
+
+ // Also rejects fields with different nullability.
+ ASSERT_RAISES(Invalid,
+ f->WithNullable(true)->MergeWith(f->WithNullable(false), options));
+ }
+ {
+ // promote_nullability == true; merge with a null field.
+ Field::MergeOptions options;
+ options.promote_nullability = true;
+ auto f = field("f", int32())->WithNullable(false)->WithMetadata(metadata1);
+ auto null_field = field("f", null())->WithMetadata(metadata2);
+
+ std::shared_ptr<Field> result;
+ ASSERT_OK_AND_ASSIGN(result, f->MergeWith(null_field, options));
+ ASSERT_TRUE(result->Equals(f->WithNullable(true)->WithMetadata(metadata1)));
+ ASSERT_OK_AND_ASSIGN(result, null_field->MergeWith(f, options));
+ ASSERT_TRUE(result->Equals(f->WithNullable(true)->WithMetadata(metadata2)));
+ }
+ {
+ // promote_nullability == true; merge a nullable field and a in-nullable field.
+ Field::MergeOptions options;
+ options.promote_nullability = true;
+ auto f1 = field("f", int32())->WithNullable(false);
+ auto f2 = field("f", int32())->WithNullable(true);
+ std::shared_ptr<Field> result;
+ ASSERT_OK_AND_ASSIGN(result, f1->MergeWith(f2, options));
+ ASSERT_TRUE(result->Equals(f1->WithNullable(true)));
+ ASSERT_OK_AND_ASSIGN(result, f2->MergeWith(f1, options));
+ ASSERT_TRUE(result->Equals(f2));
+ }
+}
+
+TEST(TestFieldPath, Basics) {
+ auto f0 = field("alpha", int32());
+ auto f1 = field("beta", int32());
+ auto f2 = field("alpha", int32());
+ auto f3 = field("beta", int32());
+ Schema s({f0, f1, f2, f3});
+
+ // retrieving a field with single-element FieldPath is equivalent to Schema::field
+ for (int index = 0; index < s.num_fields(); ++index) {
+ ASSERT_OK_AND_EQ(s.field(index), FieldPath({index}).Get(s));
+ }
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ testing::HasSubstr("empty indices cannot be traversed"),
+ FieldPath().Get(s));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(IndexError, testing::HasSubstr("index out of range"),
+ FieldPath({s.num_fields() * 2}).Get(s));
+}
+
+TEST(TestFieldRef, Basics) {
+ auto f0 = field("alpha", int32());
+ auto f1 = field("beta", int32());
+ auto f2 = field("alpha", int32());
+ auto f3 = field("beta", int32());
+ Schema s({f0, f1, f2, f3});
+
+ // lookup by index returns Indices{index}
+ for (int index = 0; index < s.num_fields(); ++index) {
+ EXPECT_THAT(FieldRef(index).FindAll(s), ElementsAre(FieldPath{index}));
+ }
+ // out of range index results in a failure to match
+ EXPECT_THAT(FieldRef(s.num_fields() * 2).FindAll(s), ElementsAre());
+
+ // lookup by name returns the Indices of both matching fields
+ EXPECT_THAT(FieldRef("alpha").FindAll(s), ElementsAre(FieldPath{0}, FieldPath{2}));
+ EXPECT_THAT(FieldRef("beta").FindAll(s), ElementsAre(FieldPath{1}, FieldPath{3}));
+}
+
+TEST(TestFieldRef, FromDotPath) {
+ ASSERT_OK_AND_EQ(FieldRef("alpha"), FieldRef::FromDotPath(R"(.alpha)"));
+
+ ASSERT_OK_AND_EQ(FieldRef("", ""), FieldRef::FromDotPath(R"(..)"));
+
+ ASSERT_OK_AND_EQ(FieldRef(2), FieldRef::FromDotPath(R"([2])"));
+
+ ASSERT_OK_AND_EQ(FieldRef("beta", 3), FieldRef::FromDotPath(R"(.beta[3])"));
+
+ ASSERT_OK_AND_EQ(FieldRef(5, "gamma", "delta", 7),
+ FieldRef::FromDotPath(R"([5].gamma.delta[7])"));
+
+ ASSERT_OK_AND_EQ(FieldRef("hello world"), FieldRef::FromDotPath(R"(.hello world)"));
+
+ ASSERT_OK_AND_EQ(FieldRef(R"([y]\tho.\)"), FieldRef::FromDotPath(R"(.\[y\]\\tho\.\)"));
+
+ ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"()"));
+ ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"(alpha)"));
+ ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"([134234)"));
+ ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"([1stuf])"));
+}
+
+TEST(TestFieldPath, Nested) {
+ auto f0 = field("alpha", int32());
+ auto f1_0 = field("alpha", int32());
+ auto f1 = field("beta", struct_({f1_0}));
+ auto f2_0 = field("alpha", int32());
+ auto f2_1_0 = field("alpha", int32());
+ auto f2_1_1 = field("alpha", int32());
+ auto f2_1 = field("gamma", struct_({f2_1_0, f2_1_1}));
+ auto f2 = field("beta", struct_({f2_0, f2_1}));
+ Schema s({f0, f1, f2});
+
+ // retrieving fields with nested indices
+ EXPECT_EQ(FieldPath({0}).Get(s), f0);
+ EXPECT_EQ(FieldPath({1, 0}).Get(s), f1_0);
+ EXPECT_EQ(FieldPath({2, 0}).Get(s), f2_0);
+ EXPECT_EQ(FieldPath({2, 1, 0}).Get(s), f2_1_0);
+ EXPECT_EQ(FieldPath({2, 1, 1}).Get(s), f2_1_1);
+}
+
+TEST(TestFieldRef, Nested) {
+ auto f0 = field("alpha", int32());
+ auto f1_0 = field("alpha", int32());
+ auto f1 = field("beta", struct_({f1_0}));
+ auto f2_0 = field("alpha", int32());
+ auto f2_1_0 = field("alpha", int32());
+ auto f2_1_1 = field("alpha", int32());
+ auto f2_1 = field("gamma", struct_({f2_1_0, f2_1_1}));
+ auto f2 = field("beta", struct_({f2_0, f2_1}));
+ Schema s({f0, f1, f2});
+
+ EXPECT_THAT(FieldRef("beta", "alpha").FindAll(s),
+ ElementsAre(FieldPath{1, 0}, FieldPath{2, 0}));
+ EXPECT_THAT(FieldRef("beta", "gamma", "alpha").FindAll(s),
+ ElementsAre(FieldPath{2, 1, 0}, FieldPath{2, 1, 1}));
+}
+
+using TestSchema = ::testing::Test;
+
+TEST_F(TestSchema, Basics) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f1_optional = field("f1", uint8());
+
+ auto f2 = field("f2", utf8());
+
+ auto schema = ::arrow::schema({f0, f1, f2});
+
+ ASSERT_EQ(3, schema->num_fields());
+ AssertFieldEqual(*f0, *schema->field(0));
+ AssertFieldEqual(*f1, *schema->field(1));
+ AssertFieldEqual(*f2, *schema->field(2));
+
+ auto schema2 = ::arrow::schema({f0, f1, f2});
+
+ std::vector<std::shared_ptr<Field>> fields3 = {f0, f1_optional, f2};
+ auto schema3 = std::make_shared<Schema>(fields3);
+ AssertSchemaEqual(schema, schema2);
+ AssertSchemaNotEqual(schema, schema3);
+
+ ASSERT_EQ(schema->fingerprint(), schema2->fingerprint());
+ ASSERT_NE(schema->fingerprint(), schema3->fingerprint());
+
+ auto schema4 = ::arrow::schema({f0}, Endianness::Little);
+ auto schema5 = ::arrow::schema({f0}, Endianness::Little);
+ auto schema6 = ::arrow::schema({f0}, Endianness::Big);
+ auto schema7 = ::arrow::schema({f0});
+
+ AssertSchemaEqual(schema4, schema5);
+ AssertSchemaNotEqual(schema4, schema6);
+#if ARROW_LITTLE_ENDIAN
+ AssertSchemaEqual(schema4, schema7);
+ AssertSchemaNotEqual(schema6, schema7);
+#else
+ AssertSchemaNotEqual(schema4, schema6);
+ AssertSchemaEqual(schema6, schema7);
+#endif
+
+ ASSERT_EQ(schema4->fingerprint(), schema5->fingerprint());
+ ASSERT_NE(schema4->fingerprint(), schema6->fingerprint());
+#if ARROW_LITTLE_ENDIAN
+ ASSERT_EQ(schema4->fingerprint(), schema7->fingerprint());
+ ASSERT_NE(schema6->fingerprint(), schema7->fingerprint());
+#else
+ ASSERT_NE(schema4->fingerprint(), schema7->fingerprint());
+ ASSERT_EQ(schema6->fingerprint(), schema7->fingerprint());
+#endif
+}
+
+TEST_F(TestSchema, ToString) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f2 = field("f2", utf8());
+ auto f3 = field("f3", list(int16()));
+
+ auto metadata = key_value_metadata({"foo"}, {"bar"});
+ auto schema = ::arrow::schema({f0, f1, f2, f3}, metadata);
+
+ std::string result = schema->ToString();
+ std::string expected = R"(f0: int32
+f1: uint8 not null
+f2: string
+f3: list<item: int16>)";
+
+ ASSERT_EQ(expected, result);
+
+ result = schema->ToString(/*print_metadata=*/true);
+ std::string expected_with_metadata = expected + R"(
+-- metadata --
+foo: bar)";
+
+ ASSERT_EQ(expected_with_metadata, result);
+
+ // With swapped endianness
+#if ARROW_LITTLE_ENDIAN
+ schema = schema->WithEndianness(Endianness::Big);
+ expected = R"(f0: int32
+f1: uint8 not null
+f2: string
+f3: list<item: int16>
+-- endianness: big --)";
+#else
+ schema = schema->WithEndianness(Endianness::Little);
+ expected = R"(f0: int32
+f1: uint8 not null
+f2: string
+f3: list<item: int16>
+-- endianness: little --)";
+#endif
+
+ result = schema->ToString();
+ ASSERT_EQ(expected, result);
+
+ result = schema->ToString(/*print_metadata=*/true);
+ expected_with_metadata = expected + R"(
+-- metadata --
+foo: bar)";
+
+ ASSERT_EQ(expected_with_metadata, result);
+}
+
+TEST_F(TestSchema, GetFieldByName) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f2 = field("f2", utf8());
+ auto f3 = field("f3", list(int16()));
+
+ auto schema = ::arrow::schema({f0, f1, f2, f3});
+
+ std::shared_ptr<Field> result;
+
+ result = schema->GetFieldByName("f1");
+ AssertFieldEqual(f1, result);
+
+ result = schema->GetFieldByName("f3");
+ AssertFieldEqual(f3, result);
+
+ result = schema->GetFieldByName("not-found");
+ ASSERT_EQ(result, nullptr);
+}
+
+TEST_F(TestSchema, GetFieldIndex) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f2 = field("f2", utf8());
+ auto f3 = field("f3", list(int16()));
+
+ auto schema = ::arrow::schema({f0, f1, f2, f3});
+
+ ASSERT_EQ(0, schema->GetFieldIndex(f0->name()));
+ ASSERT_EQ(1, schema->GetFieldIndex(f1->name()));
+ ASSERT_EQ(2, schema->GetFieldIndex(f2->name()));
+ ASSERT_EQ(3, schema->GetFieldIndex(f3->name()));
+ ASSERT_EQ(-1, schema->GetFieldIndex("not-found"));
+}
+
+TEST_F(TestSchema, GetFieldDuplicates) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f2 = field("f2", utf8());
+ auto f3 = field("f1", list(int16()));
+
+ auto schema = ::arrow::schema({f0, f1, f2, f3});
+
+ ASSERT_EQ(0, schema->GetFieldIndex(f0->name()));
+ ASSERT_EQ(-1, schema->GetFieldIndex(f1->name())); // duplicate
+ ASSERT_EQ(2, schema->GetFieldIndex(f2->name()));
+ ASSERT_EQ(-1, schema->GetFieldIndex("not-found"));
+ ASSERT_EQ(std::vector<int>{0}, schema->GetAllFieldIndices(f0->name()));
+ ASSERT_EQ(std::vector<int>({1, 3}), schema->GetAllFieldIndices(f1->name()));
+
+ ASSERT_TRUE(::arrow::schema({f0, f1, f2})->HasDistinctFieldNames());
+ ASSERT_FALSE(schema->HasDistinctFieldNames());
+
+ std::vector<std::shared_ptr<Field>> results;
+
+ results = schema->GetAllFieldsByName(f0->name());
+ ASSERT_EQ(results.size(), 1);
+ AssertFieldEqual(results[0], f0);
+
+ results = schema->GetAllFieldsByName(f1->name());
+ ASSERT_EQ(results.size(), 2);
+ if (results[0]->type()->id() == Type::UINT8) {
+ AssertFieldEqual(results[0], f1);
+ AssertFieldEqual(results[1], f3);
+ } else {
+ AssertFieldEqual(results[0], f3);
+ AssertFieldEqual(results[1], f1);
+ }
+
+ results = schema->GetAllFieldsByName("not-found");
+ ASSERT_EQ(results.size(), 0);
+}
+
+TEST_F(TestSchema, CanReferenceFieldsByNames) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f2 = field("f2", utf8());
+ auto f3 = field("f1", list(int16()));
+
+ auto schema = ::arrow::schema({f0, f1, f2, f3});
+
+ ASSERT_OK(schema->CanReferenceFieldsByNames({"f0", "f2"}));
+ ASSERT_OK(schema->CanReferenceFieldsByNames({"f2", "f0"}));
+
+ // Not found
+ ASSERT_RAISES(Invalid, schema->CanReferenceFieldsByNames({"nope"}));
+ ASSERT_RAISES(Invalid, schema->CanReferenceFieldsByNames({"f0", "nope"}));
+ // Duplicates
+ ASSERT_RAISES(Invalid, schema->CanReferenceFieldsByNames({"f1"}));
+ ASSERT_RAISES(Invalid, schema->CanReferenceFieldsByNames({"f0", "f1"}));
+ // Both
+ ASSERT_RAISES(Invalid, schema->CanReferenceFieldsByNames({"f0", "f1", "nope"}));
+}
+
+TEST_F(TestSchema, TestMetadataConstruction) {
+ auto metadata0 = key_value_metadata({{"foo", "bar"}, {"bizz", "buzz"}});
+ auto metadata1 = key_value_metadata({{"foo", "baz"}});
+
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f2 = field("f2", utf8(), true);
+ auto f3 = field("f2", utf8(), true, metadata1->Copy());
+
+ auto schema0 = ::arrow::schema({f0, f1, f2}, metadata0);
+ auto schema1 = ::arrow::schema({f0, f1, f2}, metadata1);
+ auto schema2 = ::arrow::schema({f0, f1, f2}, metadata0->Copy());
+ auto schema3 = ::arrow::schema({f0, f1, f3}, metadata0->Copy());
+
+ ASSERT_TRUE(metadata0->Equals(*schema0->metadata()));
+ ASSERT_TRUE(metadata1->Equals(*schema1->metadata()));
+ ASSERT_TRUE(metadata0->Equals(*schema2->metadata()));
+ AssertSchemaEqual(schema0, schema2);
+
+ AssertSchemaEqual(schema0, schema1);
+ AssertSchemaNotEqual(schema0, schema1, /*check_metadata=*/true);
+
+ AssertSchemaEqual(schema2, schema1);
+ AssertSchemaNotEqual(schema2, schema1, /*check_metadata=*/true);
+
+ // Field has different metatadata
+ AssertSchemaEqual(schema2, schema3);
+ AssertSchemaNotEqual(schema2, schema3, /*check_metadata=*/true);
+
+ ASSERT_EQ(schema0->fingerprint(), schema1->fingerprint());
+ ASSERT_EQ(schema0->fingerprint(), schema2->fingerprint());
+ ASSERT_EQ(schema0->fingerprint(), schema3->fingerprint());
+ ASSERT_NE(schema0->metadata_fingerprint(), schema1->metadata_fingerprint());
+ ASSERT_EQ(schema0->metadata_fingerprint(), schema2->metadata_fingerprint());
+ ASSERT_NE(schema0->metadata_fingerprint(), schema3->metadata_fingerprint());
+}
+
+TEST_F(TestSchema, TestNestedMetadataComparison) {
+ auto item0 = field("item", int32(), true);
+ auto item1 = field("item", int32(), true, key_value_metadata({{"foo", "baz"}}));
+
+ Schema schema0({field("f", list(item0))});
+ Schema schema1({field("f", list(item1))});
+
+ ASSERT_EQ(schema0.fingerprint(), schema1.fingerprint());
+ ASSERT_NE(schema0.metadata_fingerprint(), schema1.metadata_fingerprint());
+
+ AssertSchemaEqual(schema0, schema1);
+ AssertSchemaNotEqual(schema0, schema1, /* check_metadata = */ true);
+}
+
+TEST_F(TestSchema, TestDeeplyNestedMetadataComparison) {
+ auto item0 = field("item", int32(), true);
+ auto item1 = field("item", int32(), true, key_value_metadata({{"foo", "baz"}}));
+
+ Schema schema0(
+ {field("f", list(list(sparse_union({field("struct", struct_({item0}))}))))});
+ Schema schema1(
+ {field("f", list(list(sparse_union({field("struct", struct_({item1}))}))))});
+
+ ASSERT_EQ(schema0.fingerprint(), schema1.fingerprint());
+ ASSERT_NE(schema0.metadata_fingerprint(), schema1.metadata_fingerprint());
+
+ AssertSchemaEqual(schema0, schema1);
+ AssertSchemaNotEqual(schema0, schema1, /* check_metadata = */ true);
+}
+
+TEST_F(TestSchema, TestFieldsDifferOnlyInMetadata) {
+ auto f0 = field("f", utf8(), true, nullptr);
+ auto f1 = field("f", utf8(), true, key_value_metadata({{"foo", "baz"}}));
+
+ Schema schema0({f0, f1});
+ Schema schema1({f1, f0});
+
+ AssertSchemaEqual(schema0, schema1);
+ AssertSchemaNotEqual(schema0, schema1, /* check_metadata = */ true);
+
+ ASSERT_EQ(schema0.fingerprint(), schema1.fingerprint());
+ ASSERT_NE(schema0.metadata_fingerprint(), schema1.metadata_fingerprint());
+}
+
+TEST_F(TestSchema, TestEmptyMetadata) {
+ // Empty metadata should be equivalent to no metadata at all
+ auto f1 = field("f1", int32());
+ auto metadata1 = key_value_metadata({});
+ auto metadata2 = key_value_metadata({"foo"}, {"foo value"});
+
+ auto schema1 = ::arrow::schema({f1});
+ auto schema2 = ::arrow::schema({f1}, metadata1);
+ auto schema3 = ::arrow::schema({f1}, metadata2);
+
+ AssertSchemaEqual(schema1, schema2);
+ AssertSchemaNotEqual(schema1, schema3, /*check_metadata=*/true);
+
+ ASSERT_EQ(schema1->fingerprint(), schema2->fingerprint());
+ ASSERT_EQ(schema1->fingerprint(), schema3->fingerprint());
+ ASSERT_EQ(schema1->metadata_fingerprint(), schema2->metadata_fingerprint());
+ ASSERT_NE(schema1->metadata_fingerprint(), schema3->metadata_fingerprint());
+}
+
+TEST_F(TestSchema, TestWithMetadata) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f2 = field("f2", utf8());
+ std::vector<std::shared_ptr<Field>> fields = {f0, f1, f2};
+ auto metadata = key_value_metadata({"foo", "bar"}, {"bizz", "buzz"});
+ auto schema = std::make_shared<Schema>(fields);
+ std::shared_ptr<Schema> new_schema = schema->WithMetadata(metadata);
+ ASSERT_TRUE(metadata->Equals(*new_schema->metadata()));
+
+ // Not copied
+ ASSERT_TRUE(metadata.get() == new_schema->metadata().get());
+}
+
+TEST_F(TestSchema, TestRemoveMetadata) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f2 = field("f2", utf8());
+ std::vector<std::shared_ptr<Field>> fields = {f0, f1, f2};
+ auto schema = std::make_shared<Schema>(fields);
+ std::shared_ptr<Schema> new_schema = schema->RemoveMetadata();
+ ASSERT_TRUE(new_schema->metadata() == nullptr);
+}
+
+void AssertSchemaBuilderYield(const SchemaBuilder& builder,
+ const std::shared_ptr<Schema>& expected) {
+ ASSERT_OK_AND_ASSIGN(auto schema, builder.Finish());
+ AssertSchemaEqual(schema, expected);
+}
+
+TEST(TestSchemaBuilder, DefaultBehavior) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f2 = field("f2", utf8());
+
+ SchemaBuilder builder;
+ ASSERT_OK(builder.AddField(f0));
+ ASSERT_OK(builder.AddField(f1));
+ ASSERT_OK(builder.AddField(f2));
+ AssertSchemaBuilderYield(builder, schema({f0, f1, f2}));
+
+ builder.Reset();
+ ASSERT_OK(builder.AddFields({f0, f1, f2->WithNullable(false)}));
+ AssertSchemaBuilderYield(builder, schema({f0, f1, f2->WithNullable(false)}));
+
+ builder.Reset();
+ ASSERT_OK(builder.AddSchema(schema({f2, f0})));
+ AssertSchemaBuilderYield(builder, schema({f2, f0}));
+
+ builder.Reset();
+ ASSERT_OK(builder.AddSchemas({schema({f1, f2}), schema({f2, f0})}));
+ AssertSchemaBuilderYield(builder, schema({f1, f2, f2, f0}));
+}
+
+TEST(TestSchemaBuilder, WithMetadata) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto metadata = key_value_metadata({{"foo", "bar"}});
+
+ SchemaBuilder builder;
+ ASSERT_OK(builder.AddMetadata(*metadata));
+ ASSERT_OK_AND_ASSIGN(auto schema, builder.Finish());
+ AssertSchemaEqual(schema, ::arrow::schema({})->WithMetadata(metadata));
+
+ ASSERT_OK(builder.AddField(f0));
+ ASSERT_OK_AND_ASSIGN(schema, builder.Finish());
+ AssertSchemaEqual(schema, ::arrow::schema({f0})->WithMetadata(metadata));
+
+ SchemaBuilder other_builder{::arrow::schema({})->WithMetadata(metadata)};
+ ASSERT_OK(other_builder.AddField(f1));
+ ASSERT_OK_AND_ASSIGN(schema, other_builder.Finish());
+ AssertSchemaEqual(schema, ::arrow::schema({f1})->WithMetadata(metadata));
+
+ other_builder.Reset();
+ ASSERT_OK(other_builder.AddField(f1->WithMetadata(metadata)));
+ ASSERT_OK_AND_ASSIGN(schema, other_builder.Finish());
+ AssertSchemaEqual(schema, ::arrow::schema({f1->WithMetadata(metadata)}));
+}
+
+TEST(TestSchemaBuilder, IncrementalConstruction) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f2 = field("f2", utf8());
+
+ SchemaBuilder builder;
+ std::shared_ptr<Schema> actual;
+
+ ASSERT_OK_AND_ASSIGN(actual, builder.Finish());
+ AssertSchemaEqual(actual, ::arrow::schema({}));
+
+ ASSERT_OK(builder.AddField(f0));
+ ASSERT_OK_AND_ASSIGN(actual, builder.Finish());
+ AssertSchemaEqual(actual, ::arrow::schema({f0}));
+
+ ASSERT_OK(builder.AddField(f1));
+ ASSERT_OK_AND_ASSIGN(actual, builder.Finish());
+ AssertSchemaEqual(actual, ::arrow::schema({f0, f1}));
+
+ ASSERT_OK(builder.AddField(f2));
+ AssertSchemaBuilderYield(builder, schema({f0, f1, f2}));
+}
+
+TEST(TestSchemaBuilder, PolicyIgnore) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8());
+ auto f0_req = field("f0", utf8(), false);
+
+ SchemaBuilder builder{SchemaBuilder::CONFLICT_IGNORE};
+
+ ASSERT_OK(builder.AddFields({f0, f1}));
+ AssertSchemaBuilderYield(builder, schema({f0, f1}));
+
+ ASSERT_OK(builder.AddField(f0_req));
+ AssertSchemaBuilderYield(builder, schema({f0, f1}));
+
+ ASSERT_OK(builder.AddField(f0));
+ AssertSchemaBuilderYield(builder, schema({f0, f1}));
+}
+
+TEST(TestSchemaBuilder, PolicyReplace) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8());
+ auto f0_req = field("f0", utf8(), false);
+
+ SchemaBuilder builder{SchemaBuilder::CONFLICT_REPLACE};
+
+ ASSERT_OK(builder.AddFields({f0, f1}));
+ AssertSchemaBuilderYield(builder, schema({f0, f1}));
+
+ ASSERT_OK(builder.AddField(f0_req));
+ AssertSchemaBuilderYield(builder, schema({f0_req, f1}));
+
+ ASSERT_OK(builder.AddField(f0));
+ AssertSchemaBuilderYield(builder, schema({f0, f1}));
+}
+
+TEST(TestSchemaBuilder, PolicyMerge) {
+ auto f0 = field("f0", int32(), true);
+ auto f1 = field("f1", uint8());
+ // Same as f0, but not required.
+ auto f0_opt = field("f0", int32());
+ // Another type, can't merge
+ auto f0_other = field("f0", utf8(), false);
+
+ SchemaBuilder builder{SchemaBuilder::CONFLICT_MERGE};
+
+ ASSERT_OK(builder.AddFields({f0, f1}));
+ AssertSchemaBuilderYield(builder, schema({f0, f1}));
+
+ ASSERT_OK(builder.AddField(f0_opt));
+ AssertSchemaBuilderYield(builder, schema({f0_opt, f1}));
+
+ // Unsupported merge with a different type
+ ASSERT_RAISES(Invalid, builder.AddField(f0_other));
+ // Builder should still contain state
+ AssertSchemaBuilderYield(builder, schema({f0, f1}));
+
+ builder.Reset();
+ // Create a schema with duplicate fields
+ builder.SetPolicy(SchemaBuilder::CONFLICT_APPEND);
+ ASSERT_OK(builder.AddFields({f0, f0}));
+
+ builder.SetPolicy(SchemaBuilder::CONFLICT_MERGE);
+ // Even if the field is compatible, it can't know with which field to merge.
+ ASSERT_RAISES(Invalid, builder.AddField(f0_opt));
+
+ AssertSchemaBuilderYield(builder, schema({f0, f0}));
+}
+
+TEST(TestSchemaBuilder, PolicyError) {
+ auto f0 = field("f0", int32(), true);
+ auto f1 = field("f1", uint8());
+ // Same as f0, but not required.
+ auto f0_opt = field("f0", int32());
+ // Another type, can't merge
+ auto f0_other = field("f0", utf8(), false);
+
+ SchemaBuilder builder{SchemaBuilder::CONFLICT_ERROR};
+
+ ASSERT_OK(builder.AddFields({f0, f1}));
+ AssertSchemaBuilderYield(builder, schema({f0, f1}));
+
+ ASSERT_RAISES(Invalid, builder.AddField(f0));
+ ASSERT_RAISES(Invalid, builder.AddField(f0_opt));
+ ASSERT_RAISES(Invalid, builder.AddField(f0_other));
+ AssertSchemaBuilderYield(builder, schema({f0, f1}));
+}
+
+TEST(TestSchemaBuilder, Merge) {
+ auto f0 = field("f0", int32(), true);
+ auto f1 = field("f1", uint8());
+ // Same as f0, but not required.
+ auto f0_opt = field("f0", int32());
+ // Another type, can't merge
+ auto f0_other = field("f0", utf8(), false);
+
+ auto s1 = schema({f0, f1});
+ auto s2 = schema({f1, f0});
+ auto s3 = schema({f0_opt});
+ auto broken = schema({f0_other});
+
+ ASSERT_OK_AND_ASSIGN(auto schema, SchemaBuilder::Merge({s1, s2, s3}));
+ ASSERT_OK(SchemaBuilder::AreCompatible({s1, s2, s3}));
+ AssertSchemaEqual(schema, ::arrow::schema({f0_opt, f1}));
+
+ ASSERT_OK_AND_ASSIGN(schema, SchemaBuilder::Merge({s2, s3, s1}));
+ AssertSchemaEqual(schema, ::arrow::schema({f1, f0_opt}));
+
+ ASSERT_RAISES(Invalid, SchemaBuilder::Merge({s3, broken}));
+ ASSERT_RAISES(Invalid, SchemaBuilder::AreCompatible({s3, broken}));
+}
+
+class TestUnifySchemas : public TestSchema {
+ protected:
+ void AssertSchemaEqualsUnorderedFields(const Schema& lhs, const Schema& rhs) {
+ if (lhs.metadata()) {
+ ASSERT_NE(nullptr, rhs.metadata());
+ ASSERT_TRUE(lhs.metadata()->Equals(*rhs.metadata()));
+ } else {
+ ASSERT_EQ(nullptr, rhs.metadata());
+ }
+ ASSERT_EQ(lhs.num_fields(), rhs.num_fields());
+ for (int i = 0; i < lhs.num_fields(); ++i) {
+ auto lhs_field = lhs.field(i);
+ auto rhs_field = rhs.GetFieldByName(lhs_field->name());
+ ASSERT_NE(nullptr, rhs_field);
+ ASSERT_TRUE(lhs_field->Equals(rhs_field, true))
+ << lhs_field->ToString() << " vs " << rhs_field->ToString();
+ }
+ }
+};
+
+TEST_F(TestUnifySchemas, EmptyInput) { ASSERT_RAISES(Invalid, UnifySchemas({})); }
+
+TEST_F(TestUnifySchemas, IdenticalSchemas) {
+ auto int32_field = field("int32_field", int32());
+ auto uint8_field = field("uint8_field", uint8(), false);
+ auto utf8_field = field("utf8_field", utf8());
+ std::vector<std::string> keys{"foo"};
+ std::vector<std::string> vals{"bar"};
+ auto metadata = std::make_shared<KeyValueMetadata>(keys, vals);
+
+ auto schema1 = schema({int32_field, uint8_field, utf8_field});
+ auto schema2 = schema({int32_field, uint8_field, utf8_field->WithMetadata(metadata)})
+ ->WithMetadata(metadata);
+
+ ASSERT_OK_AND_ASSIGN(auto result, UnifySchemas({schema1, schema2}));
+ // Using Schema::Equals to make sure the ordering of fields is not changed.
+ ASSERT_TRUE(result->Equals(*schema1, /*check_metadata=*/true));
+
+ ASSERT_OK_AND_ASSIGN(result, UnifySchemas({schema2, schema1}));
+ // Using Schema::Equals to make sure the ordering of fields is not changed.
+ ASSERT_TRUE(result->Equals(*schema2, /*check_metadata=*/true));
+}
+
+TEST_F(TestUnifySchemas, FieldOrderingSameAsTheFirstSchema) {
+ auto int32_field = field("int32_field", int32());
+ auto uint8_field = field("uint8_field", uint8(), false);
+ auto utf8_field = field("utf8_field", utf8());
+ auto binary_field = field("binary_field", binary());
+
+ auto schema1 = schema({int32_field, uint8_field, utf8_field});
+ // schema2 only differs from schema1 in field ordering.
+ auto schema2 = schema({uint8_field, int32_field, utf8_field});
+ auto schema3 = schema({binary_field});
+
+ ASSERT_OK_AND_ASSIGN(auto result, UnifySchemas({schema1, schema2, schema3}));
+
+ ASSERT_EQ(4, result->num_fields());
+ ASSERT_TRUE(int32_field->Equals(result->field(0)));
+ ASSERT_TRUE(uint8_field->Equals(result->field(1)));
+ ASSERT_TRUE(utf8_field->Equals(result->field(2)));
+ ASSERT_TRUE(binary_field->Equals(result->field(3)));
+}
+
+TEST_F(TestUnifySchemas, MissingField) {
+ auto int32_field = field("int32_field", int32());
+ auto uint8_field = field("uint8_field", uint8(), false);
+ auto utf8_field = field("utf8_field", utf8());
+ auto metadata1 = key_value_metadata({"foo"}, {"bar"});
+ auto metadata2 = key_value_metadata({"q"}, {"42"});
+
+ auto schema1 = schema({int32_field, uint8_field})->WithMetadata(metadata1);
+ auto schema2 = schema({uint8_field, utf8_field->WithMetadata(metadata2)});
+ auto schema3 = schema({int32_field->WithMetadata(metadata1), uint8_field, utf8_field});
+
+ ASSERT_OK_AND_ASSIGN(auto result, UnifySchemas({schema1, schema2}));
+ AssertSchemaEqualsUnorderedFields(
+ *result, *schema({int32_field, uint8_field, utf8_field->WithMetadata(metadata2)})
+ ->WithMetadata(metadata1));
+}
+
+TEST_F(TestUnifySchemas, PromoteNullTypeField) {
+ auto metadata = key_value_metadata({"foo"}, {"bar"});
+ auto null_field = field("f", null());
+ auto int32_field = field("f", int32(), /*nullable=*/false);
+
+ auto schema1 = schema({null_field->WithMetadata(metadata)});
+ auto schema2 = schema({int32_field});
+
+ ASSERT_OK_AND_ASSIGN(auto result, UnifySchemas({schema1, schema2}));
+ AssertSchemaEqualsUnorderedFields(
+ *result, *schema({int32_field->WithMetadata(metadata)->WithNullable(true)}));
+
+ ASSERT_OK_AND_ASSIGN(result, UnifySchemas({schema2, schema1}));
+ AssertSchemaEqualsUnorderedFields(*result, *schema({int32_field->WithNullable(true)}));
+}
+
+TEST_F(TestUnifySchemas, MoreSchemas) {
+ auto int32_field = field("int32_field", int32());
+ auto uint8_field = field("uint8_field", uint8(), false);
+ auto utf8_field = field("utf8_field", utf8());
+
+ ASSERT_OK_AND_ASSIGN(
+ auto result,
+ UnifySchemas({schema({int32_field}), schema({uint8_field}), schema({utf8_field})}));
+ AssertSchemaEqualsUnorderedFields(
+ *result, *schema({int32_field->WithNullable(true), uint8_field->WithNullable(false),
+ utf8_field->WithNullable(true)}));
+}
+
+TEST_F(TestUnifySchemas, IncompatibleTypes) {
+ auto int32_field = field("f", int32());
+ auto uint8_field = field("f", uint8(), false);
+
+ auto schema1 = schema({int32_field});
+ auto schema2 = schema({uint8_field});
+
+ ASSERT_RAISES(Invalid, UnifySchemas({schema1, schema2}));
+}
+
+TEST_F(TestUnifySchemas, DuplicateFieldNames) {
+ auto int32_field = field("int32_field", int32());
+ auto utf8_field = field("utf8_field", utf8());
+
+ auto schema1 = schema({int32_field, utf8_field});
+ auto schema2 = schema({int32_field, int32_field, utf8_field});
+
+ ASSERT_RAISES(Invalid, UnifySchemas({schema1, schema2}));
+}
+
+#define PRIMITIVE_TEST(KLASS, CTYPE, ENUM, NAME) \
+ TEST(TypesTest, ARROW_CONCAT(TestPrimitive_, ENUM)) { \
+ KLASS tp; \
+ \
+ ASSERT_EQ(tp.id(), Type::ENUM); \
+ ASSERT_EQ(tp.ToString(), std::string(NAME)); \
+ \
+ using CType = TypeTraits<KLASS>::CType; \
+ static_assert(std::is_same<CType, CTYPE>::value, "Not the same c-type!"); \
+ \
+ using DerivedArrowType = CTypeTraits<CTYPE>::ArrowType; \
+ static_assert(std::is_same<DerivedArrowType, KLASS>::value, \
+ "Not the same arrow-type!"); \
+ }
+
+PRIMITIVE_TEST(Int8Type, int8_t, INT8, "int8");
+PRIMITIVE_TEST(Int16Type, int16_t, INT16, "int16");
+PRIMITIVE_TEST(Int32Type, int32_t, INT32, "int32");
+PRIMITIVE_TEST(Int64Type, int64_t, INT64, "int64");
+PRIMITIVE_TEST(UInt8Type, uint8_t, UINT8, "uint8");
+PRIMITIVE_TEST(UInt16Type, uint16_t, UINT16, "uint16");
+PRIMITIVE_TEST(UInt32Type, uint32_t, UINT32, "uint32");
+PRIMITIVE_TEST(UInt64Type, uint64_t, UINT64, "uint64");
+
+PRIMITIVE_TEST(FloatType, float, FLOAT, "float");
+PRIMITIVE_TEST(DoubleType, double, DOUBLE, "double");
+
+PRIMITIVE_TEST(BooleanType, bool, BOOL, "bool");
+
+TEST(TestBinaryType, ToString) {
+ BinaryType t1;
+ BinaryType e1;
+ StringType t2;
+ AssertTypeEqual(t1, e1);
+ AssertTypeNotEqual(t1, t2);
+ ASSERT_EQ(t1.id(), Type::BINARY);
+ ASSERT_EQ(t1.ToString(), std::string("binary"));
+}
+
+TEST(TestStringType, ToString) {
+ StringType str;
+ ASSERT_EQ(str.id(), Type::STRING);
+ ASSERT_EQ(str.ToString(), std::string("string"));
+}
+
+TEST(TestLargeBinaryTypes, ToString) {
+ BinaryType bt1;
+ LargeBinaryType t1;
+ LargeBinaryType e1;
+ LargeStringType t2;
+ AssertTypeEqual(t1, e1);
+ AssertTypeNotEqual(t1, t2);
+ AssertTypeNotEqual(t1, bt1);
+ ASSERT_EQ(t1.id(), Type::LARGE_BINARY);
+ ASSERT_EQ(t1.ToString(), std::string("large_binary"));
+ ASSERT_EQ(t2.id(), Type::LARGE_STRING);
+ ASSERT_EQ(t2.ToString(), std::string("large_string"));
+}
+
+TEST(TestFixedSizeBinaryType, ToString) {
+ auto t = fixed_size_binary(10);
+ ASSERT_EQ(t->id(), Type::FIXED_SIZE_BINARY);
+ ASSERT_EQ("fixed_size_binary[10]", t->ToString());
+}
+
+TEST(TestFixedSizeBinaryType, Equals) {
+ auto t1 = fixed_size_binary(10);
+ auto t2 = fixed_size_binary(10);
+ auto t3 = fixed_size_binary(3);
+
+ AssertTypeEqual(*t1, *t2);
+ AssertTypeNotEqual(*t1, *t3);
+}
+
+TEST(TestListType, Basics) {
+ std::shared_ptr<DataType> vt = std::make_shared<UInt8Type>();
+
+ ListType list_type(vt);
+ ASSERT_EQ(list_type.id(), Type::LIST);
+
+ ASSERT_EQ("list", list_type.name());
+ ASSERT_EQ("list<item: uint8>", list_type.ToString());
+
+ ASSERT_EQ(list_type.value_type()->id(), vt->id());
+ ASSERT_EQ(list_type.value_type()->id(), vt->id());
+
+ std::shared_ptr<DataType> st = std::make_shared<StringType>();
+ std::shared_ptr<DataType> lt = std::make_shared<ListType>(st);
+ ASSERT_EQ("list<item: string>", lt->ToString());
+
+ ListType lt2(lt);
+ ASSERT_EQ("list<item: list<item: string>>", lt2.ToString());
+}
+
+TEST(TestLargeListType, Basics) {
+ std::shared_ptr<DataType> vt = std::make_shared<UInt8Type>();
+
+ LargeListType list_type(vt);
+ ASSERT_EQ(list_type.id(), Type::LARGE_LIST);
+
+ ASSERT_EQ("large_list", list_type.name());
+ ASSERT_EQ("large_list<item: uint8>", list_type.ToString());
+
+ ASSERT_EQ(list_type.value_type()->id(), vt->id());
+ ASSERT_EQ(list_type.value_type()->id(), vt->id());
+
+ std::shared_ptr<DataType> st = std::make_shared<StringType>();
+ std::shared_ptr<DataType> lt = std::make_shared<LargeListType>(st);
+ ASSERT_EQ("large_list<item: string>", lt->ToString());
+
+ LargeListType lt2(lt);
+ ASSERT_EQ("large_list<item: large_list<item: string>>", lt2.ToString());
+}
+
+TEST(TestMapType, Basics) {
+ std::shared_ptr<DataType> kt = std::make_shared<StringType>();
+ std::shared_ptr<DataType> it = std::make_shared<UInt8Type>();
+
+ MapType map_type(kt, it);
+ ASSERT_EQ(map_type.id(), Type::MAP);
+
+ ASSERT_EQ("map", map_type.name());
+ ASSERT_EQ("map<string, uint8>", map_type.ToString());
+
+ ASSERT_EQ(map_type.key_type()->id(), kt->id());
+ ASSERT_EQ(map_type.item_type()->id(), it->id());
+ ASSERT_EQ(map_type.value_type()->id(), Type::STRUCT);
+
+ std::shared_ptr<DataType> mt = std::make_shared<MapType>(it, kt);
+ ASSERT_EQ("map<uint8, string>", mt->ToString());
+
+ MapType mt2(kt, mt, /*keys_sorted=*/true);
+ ASSERT_EQ("map<string, map<uint8, string>, keys_sorted>", mt2.ToString());
+ AssertTypeNotEqual(map_type, mt2);
+ MapType mt3(kt, mt);
+ ASSERT_EQ("map<string, map<uint8, string>>", mt3.ToString());
+ AssertTypeNotEqual(mt2, mt3);
+ MapType mt4(kt, mt);
+ AssertTypeEqual(mt3, mt4);
+
+ // Field names are indifferent when comparing map types
+ ASSERT_OK_AND_ASSIGN(
+ auto mt5,
+ MapType::Make(field(
+ "some_entries",
+ struct_({field("some_key", kt, false), field("some_value", mt)}), false)));
+ AssertTypeEqual(mt3, *mt5);
+}
+
+TEST(TestFixedSizeListType, Basics) {
+ std::shared_ptr<DataType> vt = std::make_shared<UInt8Type>();
+
+ FixedSizeListType fixed_size_list_type(vt, 4);
+ ASSERT_EQ(fixed_size_list_type.id(), Type::FIXED_SIZE_LIST);
+
+ ASSERT_EQ(4, fixed_size_list_type.list_size());
+ ASSERT_EQ("fixed_size_list", fixed_size_list_type.name());
+ ASSERT_EQ("fixed_size_list<item: uint8>[4]", fixed_size_list_type.ToString());
+
+ ASSERT_EQ(fixed_size_list_type.value_type()->id(), vt->id());
+ ASSERT_EQ(fixed_size_list_type.value_type()->id(), vt->id());
+
+ std::shared_ptr<DataType> st = std::make_shared<StringType>();
+ std::shared_ptr<DataType> lt = std::make_shared<FixedSizeListType>(st, 3);
+ ASSERT_EQ("fixed_size_list<item: string>[3]", lt->ToString());
+
+ FixedSizeListType lt2(lt, 7);
+ ASSERT_EQ("fixed_size_list<item: fixed_size_list<item: string>[3]>[7]", lt2.ToString());
+}
+
+TEST(TestFixedSizeListType, Equals) {
+ auto t1 = fixed_size_list(int8(), 3);
+ auto t2 = fixed_size_list(int8(), 3);
+ auto t3 = fixed_size_list(int8(), 4);
+ auto t4 = fixed_size_list(int16(), 4);
+ auto t5 = fixed_size_list(list(int16()), 4);
+ auto t6 = fixed_size_list(list(int16()), 4);
+ auto t7 = fixed_size_list(list(int32()), 4);
+
+ AssertTypeEqual(t1, t2);
+ AssertTypeNotEqual(t2, t3);
+ AssertTypeNotEqual(t3, t4);
+ AssertTypeNotEqual(t4, t5);
+ AssertTypeEqual(t5, t6);
+ AssertTypeNotEqual(t6, t7);
+}
+
+TEST(TestDateTypes, Attrs) {
+ auto t1 = date32();
+ auto t2 = date64();
+
+ ASSERT_EQ("date32[day]", t1->ToString());
+ ASSERT_EQ("date64[ms]", t2->ToString());
+
+ ASSERT_EQ(32, checked_cast<const FixedWidthType&>(*t1).bit_width());
+ ASSERT_EQ(64, checked_cast<const FixedWidthType&>(*t2).bit_width());
+}
+
+TEST(TestTimeType, Equals) {
+ Time32Type t0;
+ Time32Type t1(TimeUnit::SECOND);
+ Time32Type t2(TimeUnit::MILLI);
+ Time64Type t3(TimeUnit::MICRO);
+ Time64Type t4(TimeUnit::NANO);
+ Time64Type t5(TimeUnit::MICRO);
+
+ ASSERT_EQ(32, t0.bit_width());
+ ASSERT_EQ(64, t3.bit_width());
+
+ AssertTypeEqual(t0, t2);
+ AssertTypeEqual(t1, t1);
+ AssertTypeNotEqual(t1, t3);
+ AssertTypeNotEqual(t3, t4);
+ AssertTypeEqual(t3, t5);
+}
+
+TEST(TestTimeType, ToString) {
+ auto t1 = time32(TimeUnit::MILLI);
+ auto t2 = time64(TimeUnit::NANO);
+ auto t3 = time32(TimeUnit::SECOND);
+ auto t4 = time64(TimeUnit::MICRO);
+
+ ASSERT_EQ("time32[ms]", t1->ToString());
+ ASSERT_EQ("time64[ns]", t2->ToString());
+ ASSERT_EQ("time32[s]", t3->ToString());
+ ASSERT_EQ("time64[us]", t4->ToString());
+}
+
+TEST(TestMonthIntervalType, Equals) {
+ MonthIntervalType t1;
+ MonthIntervalType t2;
+ DayTimeIntervalType t3;
+
+ AssertTypeEqual(t1, t2);
+ AssertTypeNotEqual(t1, t3);
+}
+
+TEST(TestMonthIntervalType, ToString) {
+ auto t1 = month_interval();
+
+ ASSERT_EQ("month_interval", t1->ToString());
+}
+
+TEST(TestDayTimeIntervalType, Equals) {
+ DayTimeIntervalType t1;
+ DayTimeIntervalType t2;
+ MonthIntervalType t3;
+
+ AssertTypeEqual(t1, t2);
+ AssertTypeNotEqual(t1, t3);
+}
+
+TEST(TestDayTimeIntervalType, ToString) {
+ auto t1 = day_time_interval();
+
+ ASSERT_EQ("day_time_interval", t1->ToString());
+}
+
+TEST(TestMonthDayNanoIntervalType, Equals) {
+ MonthDayNanoIntervalType t1;
+ MonthDayNanoIntervalType t2;
+ MonthIntervalType t3;
+ DayTimeIntervalType t4;
+
+ AssertTypeEqual(t1, t2);
+ AssertTypeNotEqual(t1, t3);
+ AssertTypeNotEqual(t1, t4);
+}
+
+TEST(TestMonthDayNanoIntervalType, ToString) {
+ auto t1 = month_day_nano_interval();
+
+ ASSERT_EQ("month_day_nano_interval", t1->ToString());
+}
+
+TEST(TestDurationType, Equals) {
+ DurationType t1;
+ DurationType t2;
+ DurationType t3(TimeUnit::NANO);
+ DurationType t4(TimeUnit::NANO);
+
+ AssertTypeEqual(t1, t2);
+ AssertTypeNotEqual(t1, t3);
+ AssertTypeEqual(t3, t4);
+}
+
+TEST(TestDurationType, ToString) {
+ auto t1 = duration(TimeUnit::MILLI);
+ auto t2 = duration(TimeUnit::NANO);
+ auto t3 = duration(TimeUnit::SECOND);
+ auto t4 = duration(TimeUnit::MICRO);
+
+ ASSERT_EQ("duration[ms]", t1->ToString());
+ ASSERT_EQ("duration[ns]", t2->ToString());
+ ASSERT_EQ("duration[s]", t3->ToString());
+ ASSERT_EQ("duration[us]", t4->ToString());
+}
+
+TEST(TestTimestampType, Equals) {
+ TimestampType t1;
+ TimestampType t2;
+ TimestampType t3(TimeUnit::NANO);
+ TimestampType t4(TimeUnit::NANO);
+
+ DurationType dt1;
+ DurationType dt2(TimeUnit::NANO);
+
+ AssertTypeEqual(t1, t2);
+ AssertTypeNotEqual(t1, t3);
+ AssertTypeEqual(t3, t4);
+
+ AssertTypeNotEqual(t1, dt1);
+ AssertTypeNotEqual(t3, dt2);
+}
+
+TEST(TestTimestampType, ToString) {
+ auto t1 = timestamp(TimeUnit::MILLI);
+ auto t2 = timestamp(TimeUnit::NANO, "US/Eastern");
+ auto t3 = timestamp(TimeUnit::SECOND);
+ auto t4 = timestamp(TimeUnit::MICRO);
+
+ ASSERT_EQ("timestamp[ms]", t1->ToString());
+ ASSERT_EQ("timestamp[ns, tz=US/Eastern]", t2->ToString());
+ ASSERT_EQ("timestamp[s]", t3->ToString());
+ ASSERT_EQ("timestamp[us]", t4->ToString());
+}
+
+TEST(TestListType, Equals) {
+ auto t1 = list(utf8());
+ auto t2 = list(utf8());
+ auto t3 = list(binary());
+ auto t4 = large_list(binary());
+ auto t5 = large_list(binary());
+ auto t6 = large_list(float64());
+
+ AssertTypeEqual(*t1, *t2);
+ AssertTypeNotEqual(*t1, *t3);
+ AssertTypeNotEqual(*t3, *t4);
+ AssertTypeEqual(*t4, *t5);
+ AssertTypeNotEqual(*t5, *t6);
+}
+
+TEST(TestListType, Metadata) {
+ auto md1 = key_value_metadata({"foo", "bar"}, {"foo value", "bar value"});
+ auto md2 = key_value_metadata({"foo", "bar"}, {"foo value", "bar value"});
+ auto md3 = key_value_metadata({"foo"}, {"foo value"});
+
+ auto f1 = field("item", utf8(), /*nullable =*/true, md1);
+ auto f2 = field("item", utf8(), /*nullable =*/true, md2);
+ auto f3 = field("item", utf8(), /*nullable =*/true, md3);
+ auto f4 = field("item", utf8());
+ auto f5 = field("item", utf8(), /*nullable =*/false, md1);
+
+ auto t1 = list(f1);
+ auto t2 = list(f2);
+ auto t3 = list(f3);
+ auto t4 = list(f4);
+ auto t5 = list(f5);
+
+ AssertTypeEqual(*t1, *t2);
+ AssertTypeEqual(*t1, *t2, /*check_metadata =*/false);
+
+ AssertTypeEqual(*t1, *t3);
+ AssertTypeNotEqual(*t1, *t3, /*check_metadata =*/true);
+
+ AssertTypeEqual(*t1, *t4);
+ AssertTypeNotEqual(*t1, *t4, /*check_metadata =*/true);
+
+ AssertTypeNotEqual(*t1, *t5);
+ AssertTypeNotEqual(*t1, *t5, /*check_metadata =*/true);
+}
+
+TEST(TestNestedType, Equals) {
+ auto create_struct = [](std::string inner_name,
+ std::string struct_name) -> std::shared_ptr<Field> {
+ auto f_type = field(inner_name, int32());
+ std::vector<std::shared_ptr<Field>> fields = {f_type};
+ auto s_type = std::make_shared<StructType>(fields);
+ return field(struct_name, s_type);
+ };
+
+ auto create_union = [](std::string inner_name,
+ std::string union_name) -> std::shared_ptr<Field> {
+ auto f_type = field(inner_name, int32());
+ std::vector<std::shared_ptr<Field>> fields = {f_type};
+ std::vector<int8_t> codes = {42};
+ return field(union_name, sparse_union(fields, codes));
+ };
+
+ auto s0 = create_struct("f0", "s0");
+ auto s0_other = create_struct("f0", "s0");
+ auto s0_bad = create_struct("f1", "s0");
+ auto s1 = create_struct("f1", "s1");
+
+ AssertFieldEqual(*s0, *s0_other);
+ AssertFieldNotEqual(*s0, *s1);
+ AssertFieldNotEqual(*s0, *s0_bad);
+
+ auto u0 = create_union("f0", "u0");
+ auto u0_other = create_union("f0", "u0");
+ auto u0_bad = create_union("f1", "u0");
+ auto u1 = create_union("f1", "u1");
+
+ AssertFieldEqual(*u0, *u0_other);
+ AssertFieldNotEqual(*u0, *u1);
+ AssertFieldNotEqual(*u0, *u0_bad);
+}
+
+TEST(TestStructType, Basics) {
+ auto f0_type = int32();
+ auto f0 = field("f0", f0_type);
+
+ auto f1_type = utf8();
+ auto f1 = field("f1", f1_type);
+
+ auto f2_type = uint8();
+ auto f2 = field("f2", f2_type);
+
+ std::vector<std::shared_ptr<Field>> fields = {f0, f1, f2};
+
+ StructType struct_type(fields);
+
+ ASSERT_TRUE(struct_type.field(0)->Equals(f0));
+ ASSERT_TRUE(struct_type.field(1)->Equals(f1));
+ ASSERT_TRUE(struct_type.field(2)->Equals(f2));
+
+ ASSERT_EQ(struct_type.ToString(), "struct<f0: int32, f1: string, f2: uint8>");
+
+ // TODO(wesm): out of bounds for field(...)
+}
+
+TEST(TestStructType, GetFieldByName) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f2 = field("f2", utf8());
+ auto f3 = field("f3", list(int16()));
+
+ StructType struct_type({f0, f1, f2, f3});
+ std::shared_ptr<Field> result;
+
+ result = struct_type.GetFieldByName("f1");
+ ASSERT_EQ(f1, result);
+
+ result = struct_type.GetFieldByName("f3");
+ ASSERT_EQ(f3, result);
+
+ result = struct_type.GetFieldByName("not-found");
+ ASSERT_EQ(result, nullptr);
+}
+
+TEST(TestStructType, GetFieldIndex) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f2 = field("f2", utf8());
+ auto f3 = field("f3", list(int16()));
+
+ StructType struct_type({f0, f1, f2, f3});
+
+ ASSERT_EQ(0, struct_type.GetFieldIndex(f0->name()));
+ ASSERT_EQ(1, struct_type.GetFieldIndex(f1->name()));
+ ASSERT_EQ(2, struct_type.GetFieldIndex(f2->name()));
+ ASSERT_EQ(3, struct_type.GetFieldIndex(f3->name()));
+ ASSERT_EQ(-1, struct_type.GetFieldIndex("not-found"));
+}
+
+TEST(TestStructType, GetFieldDuplicates) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", int64());
+ auto f2 = field("f1", utf8());
+ StructType struct_type({f0, f1, f2});
+
+ ASSERT_EQ(0, struct_type.GetFieldIndex("f0"));
+ ASSERT_EQ(-1, struct_type.GetFieldIndex("f1"));
+ ASSERT_EQ(std::vector<int>{0}, struct_type.GetAllFieldIndices(f0->name()));
+ ASSERT_EQ(std::vector<int>({1, 2}), struct_type.GetAllFieldIndices(f1->name()));
+
+ std::vector<std::shared_ptr<Field>> results;
+
+ results = struct_type.GetAllFieldsByName(f0->name());
+ ASSERT_EQ(results.size(), 1);
+ ASSERT_TRUE(results[0]->Equals(f0));
+
+ results = struct_type.GetAllFieldsByName(f1->name());
+ ASSERT_EQ(results.size(), 2);
+ if (results[0]->type()->id() == Type::INT64) {
+ ASSERT_TRUE(results[0]->Equals(f1));
+ ASSERT_TRUE(results[1]->Equals(f2));
+ } else {
+ ASSERT_TRUE(results[0]->Equals(f2));
+ ASSERT_TRUE(results[1]->Equals(f1));
+ }
+
+ results = struct_type.GetAllFieldsByName("not-found");
+ ASSERT_EQ(results.size(), 0);
+}
+
+TEST(TestStructType, TestFieldsDifferOnlyInMetadata) {
+ auto f0 = field("f", utf8(), true, nullptr);
+ auto f1 = field("f", utf8(), true, key_value_metadata({{"foo", "baz"}}));
+
+ StructType s0({f0, f1});
+ StructType s1({f1, f0});
+
+ AssertTypeEqual(s0, s1);
+ AssertTypeNotEqual(s0, s1, /* check_metadata = */ true);
+
+ ASSERT_EQ(s0.fingerprint(), s1.fingerprint());
+ ASSERT_NE(s0.metadata_fingerprint(), s1.metadata_fingerprint());
+}
+
+TEST(TestUnionType, Basics) {
+ auto f0_type = int32();
+ auto f0 = field("f0", f0_type);
+ auto f1_type = utf8();
+ auto f1 = field("f1", f1_type);
+ auto f2_type = uint8();
+ auto f2 = field("f2", f2_type);
+
+ std::vector<std::shared_ptr<Field>> fields = {f0, f1, f2};
+ std::vector<int8_t> type_codes1 = {0, 1, 2};
+ std::vector<int8_t> type_codes2 = {10, 11, 12};
+ std::vector<int> child_ids1(128, -1);
+ std::vector<int> child_ids2(128, -1);
+ child_ids1[0] = 0;
+ child_ids1[1] = 1;
+ child_ids1[2] = 2;
+ child_ids2[10] = 0;
+ child_ids2[11] = 1;
+ child_ids2[12] = 2;
+
+ auto ty1 = checked_pointer_cast<UnionType>(dense_union(fields));
+ auto ty2 = checked_pointer_cast<UnionType>(dense_union(fields, type_codes1));
+ auto ty3 = checked_pointer_cast<UnionType>(dense_union(fields, type_codes2));
+ auto ty4 = checked_pointer_cast<UnionType>(sparse_union(fields));
+ auto ty5 = checked_pointer_cast<UnionType>(sparse_union(fields, type_codes1));
+ auto ty6 = checked_pointer_cast<UnionType>(sparse_union(fields, type_codes2));
+
+ ASSERT_EQ(ty1->type_codes(), type_codes1);
+ ASSERT_EQ(ty2->type_codes(), type_codes1);
+ ASSERT_EQ(ty3->type_codes(), type_codes2);
+ ASSERT_EQ(ty4->type_codes(), type_codes1);
+ ASSERT_EQ(ty5->type_codes(), type_codes1);
+ ASSERT_EQ(ty6->type_codes(), type_codes2);
+
+ ASSERT_EQ(ty1->child_ids(), child_ids1);
+ ASSERT_EQ(ty2->child_ids(), child_ids1);
+ ASSERT_EQ(ty3->child_ids(), child_ids2);
+ ASSERT_EQ(ty4->child_ids(), child_ids1);
+ ASSERT_EQ(ty5->child_ids(), child_ids1);
+ ASSERT_EQ(ty6->child_ids(), child_ids2);
+}
+
+TEST(TestDictionaryType, Basics) {
+ auto value_type = int32();
+
+ std::shared_ptr<DictionaryType> type1 =
+ std::dynamic_pointer_cast<DictionaryType>(dictionary(int16(), value_type));
+
+ auto type2 = std::dynamic_pointer_cast<DictionaryType>(
+ ::arrow::dictionary(int16(), type1, true));
+
+ ASSERT_TRUE(int16()->Equals(type1->index_type()));
+ ASSERT_TRUE(type1->value_type()->Equals(value_type));
+
+ ASSERT_TRUE(int16()->Equals(type2->index_type()));
+ ASSERT_TRUE(type2->value_type()->Equals(type1));
+
+ ASSERT_EQ("dictionary<values=int32, indices=int16, ordered=0>", type1->ToString());
+ ASSERT_EQ(
+ "dictionary<values="
+ "dictionary<values=int32, indices=int16, ordered=0>, "
+ "indices=int16, ordered=1>",
+ type2->ToString());
+}
+
+TEST(TestDictionaryType, Equals) {
+ auto t1 = dictionary(int8(), int32());
+ auto t2 = dictionary(int8(), int32());
+ auto t3 = dictionary(int16(), int32());
+ auto t4 = dictionary(int8(), int16());
+
+ AssertTypeEqual(*t1, *t2);
+ AssertTypeNotEqual(*t1, *t3);
+ AssertTypeNotEqual(*t1, *t4);
+
+ auto t5 = dictionary(int8(), int32(), /*ordered=*/false);
+ auto t6 = dictionary(int8(), int32(), /*ordered=*/true);
+ AssertTypeNotEqual(*t5, *t6);
+}
+
+TEST(TypesTest, TestDecimal128Small) {
+ Decimal128Type t1(8, 4);
+
+ EXPECT_EQ(t1.id(), Type::DECIMAL128);
+ EXPECT_EQ(t1.precision(), 8);
+ EXPECT_EQ(t1.scale(), 4);
+
+ EXPECT_EQ(t1.ToString(), std::string("decimal128(8, 4)"));
+
+ // Test properties
+ EXPECT_EQ(t1.byte_width(), 16);
+ EXPECT_EQ(t1.bit_width(), 128);
+}
+
+TEST(TypesTest, TestDecimal128Medium) {
+ Decimal128Type t1(12, 5);
+
+ EXPECT_EQ(t1.id(), Type::DECIMAL128);
+ EXPECT_EQ(t1.precision(), 12);
+ EXPECT_EQ(t1.scale(), 5);
+
+ EXPECT_EQ(t1.ToString(), std::string("decimal128(12, 5)"));
+
+ // Test properties
+ EXPECT_EQ(t1.byte_width(), 16);
+ EXPECT_EQ(t1.bit_width(), 128);
+}
+
+TEST(TypesTest, TestDecimal128Large) {
+ Decimal128Type t1(27, 7);
+
+ EXPECT_EQ(t1.id(), Type::DECIMAL128);
+ EXPECT_EQ(t1.precision(), 27);
+ EXPECT_EQ(t1.scale(), 7);
+
+ EXPECT_EQ(t1.ToString(), std::string("decimal128(27, 7)"));
+
+ // Test properties
+ EXPECT_EQ(t1.byte_width(), 16);
+ EXPECT_EQ(t1.bit_width(), 128);
+}
+
+TEST(TypesTest, TestDecimal256Small) {
+ Decimal256Type t1(8, 4);
+
+ EXPECT_EQ(t1.id(), Type::DECIMAL256);
+ EXPECT_EQ(t1.precision(), 8);
+ EXPECT_EQ(t1.scale(), 4);
+
+ EXPECT_EQ(t1.ToString(), std::string("decimal256(8, 4)"));
+
+ // Test properties
+ EXPECT_EQ(t1.byte_width(), 32);
+ EXPECT_EQ(t1.bit_width(), 256);
+}
+
+TEST(TypesTest, TestDecimal256Medium) {
+ Decimal256Type t1(12, 5);
+
+ EXPECT_EQ(t1.id(), Type::DECIMAL256);
+ EXPECT_EQ(t1.precision(), 12);
+ EXPECT_EQ(t1.scale(), 5);
+
+ EXPECT_EQ(t1.ToString(), std::string("decimal256(12, 5)"));
+
+ // Test properties
+ EXPECT_EQ(t1.byte_width(), 32);
+ EXPECT_EQ(t1.bit_width(), 256);
+}
+
+TEST(TypesTest, TestDecimal256Large) {
+ Decimal256Type t1(76, 38);
+
+ EXPECT_EQ(t1.id(), Type::DECIMAL256);
+ EXPECT_EQ(t1.precision(), 76);
+ EXPECT_EQ(t1.scale(), 38);
+
+ EXPECT_EQ(t1.ToString(), std::string("decimal256(76, 38)"));
+
+ // Test properties
+ EXPECT_EQ(t1.byte_width(), 32);
+ EXPECT_EQ(t1.bit_width(), 256);
+}
+
+TEST(TypesTest, TestDecimalEquals) {
+ Decimal128Type t1(8, 4);
+ Decimal128Type t2(8, 4);
+ Decimal128Type t3(8, 5);
+ Decimal128Type t4(27, 5);
+
+ Decimal256Type t5(8, 4);
+ Decimal256Type t6(8, 4);
+ Decimal256Type t7(8, 5);
+ Decimal256Type t8(27, 5);
+
+ FixedSizeBinaryType t9(16);
+ FixedSizeBinaryType t10(32);
+
+ AssertTypeEqual(t1, t2);
+ AssertTypeNotEqual(t1, t3);
+ AssertTypeNotEqual(t1, t4);
+ AssertTypeNotEqual(t1, t9);
+
+ AssertTypeEqual(t5, t6);
+ AssertTypeNotEqual(t5, t1);
+ AssertTypeNotEqual(t5, t7);
+ AssertTypeNotEqual(t5, t8);
+ AssertTypeNotEqual(t5, t10);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/type_traits.h b/src/arrow/cpp/src/arrow/type_traits.h
new file mode 100644
index 000000000..56d059287
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/type_traits.h
@@ -0,0 +1,1059 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "arrow/type.h"
+#include "arrow/util/bit_util.h"
+
+namespace arrow {
+
+//
+// Per-type id type lookup
+//
+
+template <Type::type id>
+struct TypeIdTraits {};
+
+#define TYPE_ID_TRAIT(_id, _typeclass) \
+ template <> \
+ struct TypeIdTraits<Type::_id> { \
+ using Type = _typeclass; \
+ };
+
+TYPE_ID_TRAIT(NA, NullType)
+TYPE_ID_TRAIT(BOOL, BooleanType)
+TYPE_ID_TRAIT(INT8, Int8Type)
+TYPE_ID_TRAIT(INT16, Int16Type)
+TYPE_ID_TRAIT(INT32, Int32Type)
+TYPE_ID_TRAIT(INT64, Int64Type)
+TYPE_ID_TRAIT(UINT8, UInt8Type)
+TYPE_ID_TRAIT(UINT16, UInt16Type)
+TYPE_ID_TRAIT(UINT32, UInt32Type)
+TYPE_ID_TRAIT(UINT64, UInt64Type)
+TYPE_ID_TRAIT(HALF_FLOAT, HalfFloatType)
+TYPE_ID_TRAIT(FLOAT, FloatType)
+TYPE_ID_TRAIT(DOUBLE, DoubleType)
+TYPE_ID_TRAIT(STRING, StringType)
+TYPE_ID_TRAIT(BINARY, BinaryType)
+TYPE_ID_TRAIT(LARGE_STRING, LargeStringType)
+TYPE_ID_TRAIT(LARGE_BINARY, LargeBinaryType)
+TYPE_ID_TRAIT(FIXED_SIZE_BINARY, FixedSizeBinaryType)
+TYPE_ID_TRAIT(DATE32, Date32Type)
+TYPE_ID_TRAIT(DATE64, Date64Type)
+TYPE_ID_TRAIT(TIME32, Time32Type)
+TYPE_ID_TRAIT(TIME64, Time64Type)
+TYPE_ID_TRAIT(TIMESTAMP, TimestampType)
+TYPE_ID_TRAIT(INTERVAL_DAY_TIME, DayTimeIntervalType)
+TYPE_ID_TRAIT(INTERVAL_MONTH_DAY_NANO, MonthDayNanoIntervalType)
+TYPE_ID_TRAIT(INTERVAL_MONTHS, MonthIntervalType)
+TYPE_ID_TRAIT(DURATION, DurationType)
+TYPE_ID_TRAIT(DECIMAL128, Decimal128Type)
+TYPE_ID_TRAIT(DECIMAL256, Decimal256Type)
+TYPE_ID_TRAIT(STRUCT, StructType)
+TYPE_ID_TRAIT(LIST, ListType)
+TYPE_ID_TRAIT(LARGE_LIST, LargeListType)
+TYPE_ID_TRAIT(FIXED_SIZE_LIST, FixedSizeListType)
+TYPE_ID_TRAIT(MAP, MapType)
+TYPE_ID_TRAIT(DENSE_UNION, DenseUnionType)
+TYPE_ID_TRAIT(SPARSE_UNION, SparseUnionType)
+TYPE_ID_TRAIT(DICTIONARY, DictionaryType)
+TYPE_ID_TRAIT(EXTENSION, ExtensionType)
+
+#undef TYPE_ID_TRAIT
+
+//
+// Per-type type traits
+//
+
+template <typename T>
+struct TypeTraits {};
+
+template <typename T>
+struct CTypeTraits {};
+
+template <>
+struct TypeTraits<NullType> {
+ using ArrayType = NullArray;
+ using BuilderType = NullBuilder;
+ using ScalarType = NullScalar;
+
+ static constexpr int64_t bytes_required(int64_t) { return 0; }
+ constexpr static bool is_parameter_free = true;
+ static inline std::shared_ptr<DataType> type_singleton() { return null(); }
+};
+
+template <>
+struct TypeTraits<BooleanType> {
+ using ArrayType = BooleanArray;
+ using BuilderType = BooleanBuilder;
+ using ScalarType = BooleanScalar;
+ using CType = bool;
+
+ static constexpr int64_t bytes_required(int64_t elements) {
+ return BitUtil::BytesForBits(elements);
+ }
+ constexpr static bool is_parameter_free = true;
+ static inline std::shared_ptr<DataType> type_singleton() { return boolean(); }
+};
+
+template <>
+struct CTypeTraits<bool> : public TypeTraits<BooleanType> {
+ using ArrowType = BooleanType;
+};
+
+#define PRIMITIVE_TYPE_TRAITS_DEF_(CType_, ArrowType_, ArrowArrayType, ArrowBuilderType, \
+ ArrowScalarType, ArrowTensorType, SingletonFn) \
+ template <> \
+ struct TypeTraits<ArrowType_> { \
+ using ArrayType = ArrowArrayType; \
+ using BuilderType = ArrowBuilderType; \
+ using ScalarType = ArrowScalarType; \
+ using TensorType = ArrowTensorType; \
+ using CType = ArrowType_::c_type; \
+ static constexpr int64_t bytes_required(int64_t elements) { \
+ return elements * static_cast<int64_t>(sizeof(CType)); \
+ } \
+ constexpr static bool is_parameter_free = true; \
+ static inline std::shared_ptr<DataType> type_singleton() { return SingletonFn(); } \
+ }; \
+ \
+ template <> \
+ struct CTypeTraits<CType_> : public TypeTraits<ArrowType_> { \
+ using ArrowType = ArrowType_; \
+ };
+
+#define PRIMITIVE_TYPE_TRAITS_DEF(CType, ArrowShort, SingletonFn) \
+ PRIMITIVE_TYPE_TRAITS_DEF_( \
+ CType, ARROW_CONCAT(ArrowShort, Type), ARROW_CONCAT(ArrowShort, Array), \
+ ARROW_CONCAT(ArrowShort, Builder), ARROW_CONCAT(ArrowShort, Scalar), \
+ ARROW_CONCAT(ArrowShort, Tensor), SingletonFn)
+
+PRIMITIVE_TYPE_TRAITS_DEF(uint8_t, UInt8, uint8)
+PRIMITIVE_TYPE_TRAITS_DEF(int8_t, Int8, int8)
+PRIMITIVE_TYPE_TRAITS_DEF(uint16_t, UInt16, uint16)
+PRIMITIVE_TYPE_TRAITS_DEF(int16_t, Int16, int16)
+PRIMITIVE_TYPE_TRAITS_DEF(uint32_t, UInt32, uint32)
+PRIMITIVE_TYPE_TRAITS_DEF(int32_t, Int32, int32)
+PRIMITIVE_TYPE_TRAITS_DEF(uint64_t, UInt64, uint64)
+PRIMITIVE_TYPE_TRAITS_DEF(int64_t, Int64, int64)
+PRIMITIVE_TYPE_TRAITS_DEF(float, Float, float32)
+PRIMITIVE_TYPE_TRAITS_DEF(double, Double, float64)
+
+#undef PRIMITIVE_TYPE_TRAITS_DEF
+#undef PRIMITIVE_TYPE_TRAITS_DEF_
+
+template <>
+struct TypeTraits<Date64Type> {
+ using ArrayType = Date64Array;
+ using BuilderType = Date64Builder;
+ using ScalarType = Date64Scalar;
+ using CType = Date64Type::c_type;
+
+ static constexpr int64_t bytes_required(int64_t elements) {
+ return elements * static_cast<int64_t>(sizeof(int64_t));
+ }
+ constexpr static bool is_parameter_free = true;
+ static inline std::shared_ptr<DataType> type_singleton() { return date64(); }
+};
+
+template <>
+struct TypeTraits<Date32Type> {
+ using ArrayType = Date32Array;
+ using BuilderType = Date32Builder;
+ using ScalarType = Date32Scalar;
+ using CType = Date32Type::c_type;
+
+ static constexpr int64_t bytes_required(int64_t elements) {
+ return elements * static_cast<int64_t>(sizeof(int32_t));
+ }
+ constexpr static bool is_parameter_free = true;
+ static inline std::shared_ptr<DataType> type_singleton() { return date32(); }
+};
+
+template <>
+struct TypeTraits<TimestampType> {
+ using ArrayType = TimestampArray;
+ using BuilderType = TimestampBuilder;
+ using ScalarType = TimestampScalar;
+ using CType = TimestampType::c_type;
+
+ static constexpr int64_t bytes_required(int64_t elements) {
+ return elements * static_cast<int64_t>(sizeof(int64_t));
+ }
+ constexpr static bool is_parameter_free = false;
+};
+
+template <>
+struct TypeTraits<DurationType> {
+ using ArrayType = DurationArray;
+ using BuilderType = DurationBuilder;
+ using ScalarType = DurationScalar;
+ using CType = DurationType::c_type;
+
+ static constexpr int64_t bytes_required(int64_t elements) {
+ return elements * static_cast<int64_t>(sizeof(int64_t));
+ }
+ constexpr static bool is_parameter_free = false;
+};
+
+template <>
+struct TypeTraits<DayTimeIntervalType> {
+ using ArrayType = DayTimeIntervalArray;
+ using BuilderType = DayTimeIntervalBuilder;
+ using ScalarType = DayTimeIntervalScalar;
+ using CType = DayTimeIntervalType::c_type;
+
+ static constexpr int64_t bytes_required(int64_t elements) {
+ return elements * static_cast<int64_t>(sizeof(DayTimeIntervalType::DayMilliseconds));
+ }
+ constexpr static bool is_parameter_free = true;
+ static std::shared_ptr<DataType> type_singleton() { return day_time_interval(); }
+};
+
+template <>
+struct TypeTraits<MonthDayNanoIntervalType> {
+ using ArrayType = MonthDayNanoIntervalArray;
+ using BuilderType = MonthDayNanoIntervalBuilder;
+ using ScalarType = MonthDayNanoIntervalScalar;
+
+ static constexpr int64_t bytes_required(int64_t elements) {
+ return elements *
+ static_cast<int64_t>(sizeof(MonthDayNanoIntervalType::MonthDayNanos));
+ }
+ constexpr static bool is_parameter_free = true;
+ static std::shared_ptr<DataType> type_singleton() { return month_day_nano_interval(); }
+};
+
+template <>
+struct TypeTraits<MonthIntervalType> {
+ using ArrayType = MonthIntervalArray;
+ using BuilderType = MonthIntervalBuilder;
+ using ScalarType = MonthIntervalScalar;
+ using CType = MonthIntervalType::c_type;
+
+ static constexpr int64_t bytes_required(int64_t elements) {
+ return elements * static_cast<int64_t>(sizeof(int32_t));
+ }
+ constexpr static bool is_parameter_free = true;
+ static std::shared_ptr<DataType> type_singleton() { return month_interval(); }
+};
+
+template <>
+struct TypeTraits<Time32Type> {
+ using ArrayType = Time32Array;
+ using BuilderType = Time32Builder;
+ using ScalarType = Time32Scalar;
+ using CType = Time32Type::c_type;
+
+ static constexpr int64_t bytes_required(int64_t elements) {
+ return elements * static_cast<int64_t>(sizeof(int32_t));
+ }
+ constexpr static bool is_parameter_free = false;
+};
+
+template <>
+struct TypeTraits<Time64Type> {
+ using ArrayType = Time64Array;
+ using BuilderType = Time64Builder;
+ using ScalarType = Time64Scalar;
+ using CType = Time64Type::c_type;
+
+ static constexpr int64_t bytes_required(int64_t elements) {
+ return elements * static_cast<int64_t>(sizeof(int64_t));
+ }
+ constexpr static bool is_parameter_free = false;
+};
+
+template <>
+struct TypeTraits<HalfFloatType> {
+ using ArrayType = HalfFloatArray;
+ using BuilderType = HalfFloatBuilder;
+ using ScalarType = HalfFloatScalar;
+ using TensorType = HalfFloatTensor;
+
+ static constexpr int64_t bytes_required(int64_t elements) {
+ return elements * static_cast<int64_t>(sizeof(uint16_t));
+ }
+ constexpr static bool is_parameter_free = true;
+ static inline std::shared_ptr<DataType> type_singleton() { return float16(); }
+};
+
+template <>
+struct TypeTraits<Decimal128Type> {
+ using ArrayType = Decimal128Array;
+ using BuilderType = Decimal128Builder;
+ using ScalarType = Decimal128Scalar;
+ using CType = Decimal128;
+ constexpr static bool is_parameter_free = false;
+};
+
+template <>
+struct TypeTraits<Decimal256Type> {
+ using ArrayType = Decimal256Array;
+ using BuilderType = Decimal256Builder;
+ using ScalarType = Decimal256Scalar;
+ using CType = Decimal256;
+ constexpr static bool is_parameter_free = false;
+};
+
+template <>
+struct TypeTraits<BinaryType> {
+ using ArrayType = BinaryArray;
+ using BuilderType = BinaryBuilder;
+ using ScalarType = BinaryScalar;
+ using OffsetType = Int32Type;
+ constexpr static bool is_parameter_free = true;
+ static inline std::shared_ptr<DataType> type_singleton() { return binary(); }
+};
+
+template <>
+struct TypeTraits<LargeBinaryType> {
+ using ArrayType = LargeBinaryArray;
+ using BuilderType = LargeBinaryBuilder;
+ using ScalarType = LargeBinaryScalar;
+ using OffsetType = Int64Type;
+ constexpr static bool is_parameter_free = true;
+ static inline std::shared_ptr<DataType> type_singleton() { return large_binary(); }
+};
+
+template <>
+struct TypeTraits<FixedSizeBinaryType> {
+ using ArrayType = FixedSizeBinaryArray;
+ using BuilderType = FixedSizeBinaryBuilder;
+ using ScalarType = FixedSizeBinaryScalar;
+ // FixedSizeBinary doesn't have offsets per se, but string length is int32 sized
+ using OffsetType = Int32Type;
+ constexpr static bool is_parameter_free = false;
+};
+
+template <>
+struct TypeTraits<StringType> {
+ using ArrayType = StringArray;
+ using BuilderType = StringBuilder;
+ using ScalarType = StringScalar;
+ using OffsetType = Int32Type;
+ constexpr static bool is_parameter_free = true;
+ static inline std::shared_ptr<DataType> type_singleton() { return utf8(); }
+};
+
+template <>
+struct TypeTraits<LargeStringType> {
+ using ArrayType = LargeStringArray;
+ using BuilderType = LargeStringBuilder;
+ using ScalarType = LargeStringScalar;
+ using OffsetType = Int64Type;
+ constexpr static bool is_parameter_free = true;
+ static inline std::shared_ptr<DataType> type_singleton() { return large_utf8(); }
+};
+
+template <>
+struct CTypeTraits<std::string> : public TypeTraits<StringType> {
+ using ArrowType = StringType;
+};
+
+template <>
+struct CTypeTraits<const char*> : public CTypeTraits<std::string> {};
+
+template <size_t N>
+struct CTypeTraits<const char (&)[N]> : public CTypeTraits<std::string> {};
+
+template <>
+struct CTypeTraits<DayTimeIntervalType::DayMilliseconds>
+ : public TypeTraits<DayTimeIntervalType> {
+ using ArrowType = DayTimeIntervalType;
+};
+
+template <>
+struct TypeTraits<ListType> {
+ using ArrayType = ListArray;
+ using BuilderType = ListBuilder;
+ using ScalarType = ListScalar;
+ using OffsetType = Int32Type;
+ using OffsetArrayType = Int32Array;
+ using OffsetBuilderType = Int32Builder;
+ using OffsetScalarType = Int32Scalar;
+ constexpr static bool is_parameter_free = false;
+};
+
+template <>
+struct TypeTraits<LargeListType> {
+ using ArrayType = LargeListArray;
+ using BuilderType = LargeListBuilder;
+ using ScalarType = LargeListScalar;
+ using OffsetType = Int64Type;
+ using OffsetArrayType = Int64Array;
+ using OffsetBuilderType = Int64Builder;
+ using OffsetScalarType = Int64Scalar;
+ constexpr static bool is_parameter_free = false;
+};
+
+template <>
+struct TypeTraits<MapType> {
+ using ArrayType = MapArray;
+ using BuilderType = MapBuilder;
+ using ScalarType = MapScalar;
+ using OffsetType = Int32Type;
+ using OffsetArrayType = Int32Array;
+ using OffsetBuilderType = Int32Builder;
+ constexpr static bool is_parameter_free = false;
+};
+
+template <>
+struct TypeTraits<FixedSizeListType> {
+ using ArrayType = FixedSizeListArray;
+ using BuilderType = FixedSizeListBuilder;
+ using ScalarType = FixedSizeListScalar;
+ constexpr static bool is_parameter_free = false;
+};
+
+template <typename CType>
+struct CTypeTraits<std::vector<CType>> : public TypeTraits<ListType> {
+ using ArrowType = ListType;
+
+ static inline std::shared_ptr<DataType> type_singleton() {
+ return list(CTypeTraits<CType>::type_singleton());
+ }
+};
+
+template <>
+struct TypeTraits<StructType> {
+ using ArrayType = StructArray;
+ using BuilderType = StructBuilder;
+ using ScalarType = StructScalar;
+ constexpr static bool is_parameter_free = false;
+};
+
+template <>
+struct TypeTraits<SparseUnionType> {
+ using ArrayType = SparseUnionArray;
+ using BuilderType = SparseUnionBuilder;
+ using ScalarType = SparseUnionScalar;
+ constexpr static bool is_parameter_free = false;
+};
+
+template <>
+struct TypeTraits<DenseUnionType> {
+ using ArrayType = DenseUnionArray;
+ using BuilderType = DenseUnionBuilder;
+ using ScalarType = DenseUnionScalar;
+ constexpr static bool is_parameter_free = false;
+};
+
+template <>
+struct TypeTraits<DictionaryType> {
+ using ArrayType = DictionaryArray;
+ using ScalarType = DictionaryScalar;
+ constexpr static bool is_parameter_free = false;
+};
+
+template <>
+struct TypeTraits<ExtensionType> {
+ using ArrayType = ExtensionArray;
+ using ScalarType = ExtensionScalar;
+ constexpr static bool is_parameter_free = false;
+};
+
+namespace internal {
+
+template <typename... Ts>
+struct make_void {
+ using type = void;
+};
+
+template <typename... Ts>
+using void_t = typename make_void<Ts...>::type;
+
+} // namespace internal
+
+//
+// Useful type predicates
+//
+
+// only in C++14
+template <bool B, typename T = void>
+using enable_if_t = typename std::enable_if<B, T>::type;
+
+template <typename T>
+using is_null_type = std::is_same<NullType, T>;
+
+template <typename T, typename R = void>
+using enable_if_null = enable_if_t<is_null_type<T>::value, R>;
+
+template <typename T>
+using is_boolean_type = std::is_same<BooleanType, T>;
+
+template <typename T, typename R = void>
+using enable_if_boolean = enable_if_t<is_boolean_type<T>::value, R>;
+
+template <typename T>
+using is_number_type = std::is_base_of<NumberType, T>;
+
+template <typename T, typename R = void>
+using enable_if_number = enable_if_t<is_number_type<T>::value, R>;
+
+template <typename T>
+using is_integer_type = std::is_base_of<IntegerType, T>;
+
+template <typename T, typename R = void>
+using enable_if_integer = enable_if_t<is_integer_type<T>::value, R>;
+
+template <typename T>
+using is_signed_integer_type =
+ std::integral_constant<bool, is_integer_type<T>::value &&
+ std::is_signed<typename T::c_type>::value>;
+
+template <typename T, typename R = void>
+using enable_if_signed_integer = enable_if_t<is_signed_integer_type<T>::value, R>;
+
+template <typename T>
+using is_unsigned_integer_type =
+ std::integral_constant<bool, is_integer_type<T>::value &&
+ std::is_unsigned<typename T::c_type>::value>;
+
+template <typename T, typename R = void>
+using enable_if_unsigned_integer = enable_if_t<is_unsigned_integer_type<T>::value, R>;
+
+// Note this will also include HalfFloatType which is represented by a
+// non-floating point primitive (uint16_t).
+template <typename T>
+using is_floating_type = std::is_base_of<FloatingPointType, T>;
+
+template <typename T, typename R = void>
+using enable_if_floating_point = enable_if_t<is_floating_type<T>::value, R>;
+
+// Half floats are special in that they behave physically like an unsigned
+// integer.
+template <typename T>
+using is_half_float_type = std::is_same<HalfFloatType, T>;
+
+template <typename T, typename R = void>
+using enable_if_half_float = enable_if_t<is_half_float_type<T>::value, R>;
+
+// Binary Types
+
+// Base binary refers to Binary/LargeBinary/String/LargeString
+template <typename T>
+using is_base_binary_type = std::is_base_of<BaseBinaryType, T>;
+
+template <typename T, typename R = void>
+using enable_if_base_binary = enable_if_t<is_base_binary_type<T>::value, R>;
+
+// Any binary excludes string from Base binary
+template <typename T>
+using is_binary_type =
+ std::integral_constant<bool, std::is_same<BinaryType, T>::value ||
+ std::is_same<LargeBinaryType, T>::value>;
+
+template <typename T, typename R = void>
+using enable_if_binary = enable_if_t<is_binary_type<T>::value, R>;
+
+template <typename T>
+using is_string_type =
+ std::integral_constant<bool, std::is_same<StringType, T>::value ||
+ std::is_same<LargeStringType, T>::value>;
+
+template <typename T, typename R = void>
+using enable_if_string = enable_if_t<is_string_type<T>::value, R>;
+
+template <typename T>
+using is_string_like_type =
+ std::integral_constant<bool, is_base_binary_type<T>::value && T::is_utf8>;
+
+template <typename T, typename R = void>
+using enable_if_string_like = enable_if_t<is_string_like_type<T>::value, R>;
+
+template <typename T, typename U, typename R = void>
+using enable_if_same = enable_if_t<std::is_same<T, U>::value, R>;
+
+// Note that this also includes DecimalType
+template <typename T>
+using is_fixed_size_binary_type = std::is_base_of<FixedSizeBinaryType, T>;
+
+template <typename T, typename R = void>
+using enable_if_fixed_size_binary = enable_if_t<is_fixed_size_binary_type<T>::value, R>;
+
+template <typename T>
+using is_binary_like_type =
+ std::integral_constant<bool, (is_base_binary_type<T>::value &&
+ !is_string_like_type<T>::value) ||
+ is_fixed_size_binary_type<T>::value>;
+
+template <typename T, typename R = void>
+using enable_if_binary_like = enable_if_t<is_binary_like_type<T>::value, R>;
+
+template <typename T>
+using is_decimal_type = std::is_base_of<DecimalType, T>;
+
+template <typename T, typename R = void>
+using enable_if_decimal = enable_if_t<is_decimal_type<T>::value, R>;
+
+template <typename T>
+using is_decimal128_type = std::is_base_of<Decimal128Type, T>;
+
+template <typename T, typename R = void>
+using enable_if_decimal128 = enable_if_t<is_decimal128_type<T>::value, R>;
+
+template <typename T>
+using is_decimal256_type = std::is_base_of<Decimal256Type, T>;
+
+template <typename T, typename R = void>
+using enable_if_decimal256 = enable_if_t<is_decimal256_type<T>::value, R>;
+
+// Nested Types
+
+template <typename T>
+using is_nested_type = std::is_base_of<NestedType, T>;
+
+template <typename T, typename R = void>
+using enable_if_nested = enable_if_t<is_nested_type<T>::value, R>;
+
+template <typename T, typename R = void>
+using enable_if_not_nested = enable_if_t<!is_nested_type<T>::value, R>;
+
+template <typename T>
+using is_var_length_list_type =
+ std::integral_constant<bool, std::is_base_of<LargeListType, T>::value ||
+ std::is_base_of<ListType, T>::value>;
+
+template <typename T, typename R = void>
+using enable_if_var_size_list = enable_if_t<is_var_length_list_type<T>::value, R>;
+
+// DEPRECATED use is_var_length_list_type.
+template <typename T>
+using is_base_list_type = is_var_length_list_type<T>;
+
+// DEPRECATED use enable_if_var_size_list
+template <typename T, typename R = void>
+using enable_if_base_list = enable_if_var_size_list<T, R>;
+
+template <typename T>
+using is_fixed_size_list_type = std::is_same<FixedSizeListType, T>;
+
+template <typename T, typename R = void>
+using enable_if_fixed_size_list = enable_if_t<is_fixed_size_list_type<T>::value, R>;
+
+template <typename T>
+using is_list_type =
+ std::integral_constant<bool, std::is_same<T, ListType>::value ||
+ std::is_same<T, LargeListType>::value ||
+ std::is_same<T, FixedSizeListType>::value>;
+
+template <typename T, typename R = void>
+using enable_if_list_type = enable_if_t<is_list_type<T>::value, R>;
+
+template <typename T>
+using is_list_like_type =
+ std::integral_constant<bool, is_base_list_type<T>::value ||
+ is_fixed_size_list_type<T>::value>;
+
+template <typename T, typename R = void>
+using enable_if_list_like = enable_if_t<is_list_like_type<T>::value, R>;
+
+template <typename T>
+using is_struct_type = std::is_base_of<StructType, T>;
+
+template <typename T, typename R = void>
+using enable_if_struct = enable_if_t<is_struct_type<T>::value, R>;
+
+template <typename T>
+using is_union_type = std::is_base_of<UnionType, T>;
+
+template <typename T, typename R = void>
+using enable_if_union = enable_if_t<is_union_type<T>::value, R>;
+
+// TemporalTypes
+
+template <typename T>
+using is_temporal_type = std::is_base_of<TemporalType, T>;
+
+template <typename T, typename R = void>
+using enable_if_temporal = enable_if_t<is_temporal_type<T>::value, R>;
+
+template <typename T>
+using is_date_type = std::is_base_of<DateType, T>;
+
+template <typename T, typename R = void>
+using enable_if_date = enable_if_t<is_date_type<T>::value, R>;
+
+template <typename T>
+using is_time_type = std::is_base_of<TimeType, T>;
+
+template <typename T, typename R = void>
+using enable_if_time = enable_if_t<is_time_type<T>::value, R>;
+
+template <typename T>
+using is_timestamp_type = std::is_base_of<TimestampType, T>;
+
+template <typename T, typename R = void>
+using enable_if_timestamp = enable_if_t<is_timestamp_type<T>::value, R>;
+
+template <typename T>
+using is_duration_type = std::is_base_of<DurationType, T>;
+
+template <typename T, typename R = void>
+using enable_if_duration = enable_if_t<is_duration_type<T>::value, R>;
+
+template <typename T>
+using is_interval_type = std::is_base_of<IntervalType, T>;
+
+template <typename T, typename R = void>
+using enable_if_interval = enable_if_t<is_interval_type<T>::value, R>;
+
+template <typename T>
+using is_dictionary_type = std::is_base_of<DictionaryType, T>;
+
+template <typename T, typename R = void>
+using enable_if_dictionary = enable_if_t<is_dictionary_type<T>::value, R>;
+
+template <typename T>
+using is_extension_type = std::is_base_of<ExtensionType, T>;
+
+template <typename T, typename R = void>
+using enable_if_extension = enable_if_t<is_extension_type<T>::value, R>;
+
+// Attribute differentiation
+
+template <typename T>
+using is_primitive_ctype = std::is_base_of<PrimitiveCType, T>;
+
+template <typename T, typename R = void>
+using enable_if_primitive_ctype = enable_if_t<is_primitive_ctype<T>::value, R>;
+
+template <typename T>
+using has_c_type = std::integral_constant<bool, is_primitive_ctype<T>::value ||
+ is_temporal_type<T>::value>;
+
+template <typename T, typename R = void>
+using enable_if_has_c_type = enable_if_t<has_c_type<T>::value, R>;
+
+template <typename T>
+using has_string_view =
+ std::integral_constant<bool, std::is_same<BinaryType, T>::value ||
+ std::is_same<LargeBinaryType, T>::value ||
+ std::is_same<StringType, T>::value ||
+ std::is_same<LargeStringType, T>::value ||
+ std::is_same<FixedSizeBinaryType, T>::value>;
+
+template <typename T, typename R = void>
+using enable_if_has_string_view = enable_if_t<has_string_view<T>::value, R>;
+
+template <typename T>
+using is_8bit_int = std::integral_constant<bool, std::is_same<UInt8Type, T>::value ||
+ std::is_same<Int8Type, T>::value>;
+
+template <typename T, typename R = void>
+using enable_if_8bit_int = enable_if_t<is_8bit_int<T>::value, R>;
+
+template <typename T>
+using is_parameter_free_type =
+ std::integral_constant<bool, TypeTraits<T>::is_parameter_free>;
+
+template <typename T, typename R = void>
+using enable_if_parameter_free = enable_if_t<is_parameter_free_type<T>::value, R>;
+
+// Physical representation quirks
+
+template <typename T>
+using is_physical_signed_integer_type =
+ std::integral_constant<bool,
+ is_signed_integer_type<T>::value ||
+ (is_temporal_type<T>::value && has_c_type<T>::value &&
+ std::is_integral<typename T::c_type>::value)>;
+
+template <typename T, typename R = void>
+using enable_if_physical_signed_integer =
+ enable_if_t<is_physical_signed_integer_type<T>::value, R>;
+
+template <typename T>
+using is_physical_unsigned_integer_type =
+ std::integral_constant<bool, is_unsigned_integer_type<T>::value ||
+ is_half_float_type<T>::value>;
+
+template <typename T, typename R = void>
+using enable_if_physical_unsigned_integer =
+ enable_if_t<is_physical_unsigned_integer_type<T>::value, R>;
+
+template <typename T>
+using is_physical_integer_type =
+ std::integral_constant<bool, is_physical_unsigned_integer_type<T>::value ||
+ is_physical_signed_integer_type<T>::value>;
+
+template <typename T, typename R = void>
+using enable_if_physical_integer = enable_if_t<is_physical_integer_type<T>::value, R>;
+
+// Like is_floating_type but excluding half-floats which don't have a
+// float-like c type.
+template <typename T>
+using is_physical_floating_type =
+ std::integral_constant<bool,
+ is_floating_type<T>::value && !is_half_float_type<T>::value>;
+
+template <typename T, typename R = void>
+using enable_if_physical_floating_point =
+ enable_if_t<is_physical_floating_type<T>::value, R>;
+
+static inline bool is_integer(Type::type type_id) {
+ switch (type_id) {
+ case Type::UINT8:
+ case Type::INT8:
+ case Type::UINT16:
+ case Type::INT16:
+ case Type::UINT32:
+ case Type::INT32:
+ case Type::UINT64:
+ case Type::INT64:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+static inline bool is_signed_integer(Type::type type_id) {
+ switch (type_id) {
+ case Type::INT8:
+ case Type::INT16:
+ case Type::INT32:
+ case Type::INT64:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+static inline bool is_unsigned_integer(Type::type type_id) {
+ switch (type_id) {
+ case Type::UINT8:
+ case Type::UINT16:
+ case Type::UINT32:
+ case Type::UINT64:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+static inline bool is_floating(Type::type type_id) {
+ switch (type_id) {
+ case Type::HALF_FLOAT:
+ case Type::FLOAT:
+ case Type::DOUBLE:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+static inline bool is_decimal(Type::type type_id) {
+ switch (type_id) {
+ case Type::DECIMAL128:
+ case Type::DECIMAL256:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+static inline bool is_primitive(Type::type type_id) {
+ switch (type_id) {
+ case Type::BOOL:
+ case Type::UINT8:
+ case Type::INT8:
+ case Type::UINT16:
+ case Type::INT16:
+ case Type::UINT32:
+ case Type::INT32:
+ case Type::UINT64:
+ case Type::INT64:
+ case Type::HALF_FLOAT:
+ case Type::FLOAT:
+ case Type::DOUBLE:
+ case Type::DATE32:
+ case Type::DATE64:
+ case Type::TIME32:
+ case Type::TIME64:
+ case Type::TIMESTAMP:
+ case Type::DURATION:
+ case Type::INTERVAL_MONTHS:
+ case Type::INTERVAL_MONTH_DAY_NANO:
+ case Type::INTERVAL_DAY_TIME:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+static inline bool is_base_binary_like(Type::type type_id) {
+ switch (type_id) {
+ case Type::BINARY:
+ case Type::LARGE_BINARY:
+ case Type::STRING:
+ case Type::LARGE_STRING:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+static inline bool is_binary_like(Type::type type_id) {
+ switch (type_id) {
+ case Type::BINARY:
+ case Type::STRING:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+static inline bool is_large_binary_like(Type::type type_id) {
+ switch (type_id) {
+ case Type::LARGE_BINARY:
+ case Type::LARGE_STRING:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+static inline bool is_dictionary(Type::type type_id) {
+ return type_id == Type::DICTIONARY;
+}
+
+static inline bool is_fixed_size_binary(Type::type type_id) {
+ switch (type_id) {
+ case Type::DECIMAL128:
+ case Type::DECIMAL256:
+ case Type::FIXED_SIZE_BINARY:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+static inline bool is_fixed_width(Type::type type_id) {
+ return is_primitive(type_id) || is_dictionary(type_id) || is_fixed_size_binary(type_id);
+}
+
+static inline int bit_width(Type::type type_id) {
+ switch (type_id) {
+ case Type::BOOL:
+ return 1;
+ case Type::UINT8:
+ case Type::INT8:
+ return 8;
+ case Type::UINT16:
+ case Type::INT16:
+ return 16;
+ case Type::UINT32:
+ case Type::INT32:
+ case Type::DATE32:
+ case Type::TIME32:
+ return 32;
+ case Type::UINT64:
+ case Type::INT64:
+ case Type::DATE64:
+ case Type::TIME64:
+ case Type::TIMESTAMP:
+ case Type::DURATION:
+ return 64;
+
+ case Type::HALF_FLOAT:
+ return 16;
+ case Type::FLOAT:
+ return 32;
+ case Type::DOUBLE:
+ return 64;
+
+ case Type::INTERVAL_MONTHS:
+ return 32;
+ case Type::INTERVAL_DAY_TIME:
+ return 64;
+ case Type::INTERVAL_MONTH_DAY_NANO:
+ return 128;
+
+ case Type::DECIMAL128:
+ return 128;
+ case Type::DECIMAL256:
+ return 256;
+
+ default:
+ break;
+ }
+ return 0;
+}
+
+static inline bool is_nested(Type::type type_id) {
+ switch (type_id) {
+ case Type::LIST:
+ case Type::LARGE_LIST:
+ case Type::FIXED_SIZE_LIST:
+ case Type::MAP:
+ case Type::STRUCT:
+ case Type::SPARSE_UNION:
+ case Type::DENSE_UNION:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+static inline bool is_union(Type::type type_id) {
+ switch (type_id) {
+ case Type::SPARSE_UNION:
+ case Type::DENSE_UNION:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+static inline int offset_bit_width(Type::type type_id) {
+ switch (type_id) {
+ case Type::STRING:
+ case Type::BINARY:
+ case Type::LIST:
+ case Type::MAP:
+ case Type::DENSE_UNION:
+ return 32;
+ case Type::LARGE_STRING:
+ case Type::LARGE_BINARY:
+ case Type::LARGE_LIST:
+ return 64;
+ default:
+ break;
+ }
+ return 0;
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/CMakeLists.txt b/src/arrow/cpp/src/arrow/util/CMakeLists.txt
new file mode 100644
index 000000000..6d36fde93
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/CMakeLists.txt
@@ -0,0 +1,100 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#
+# arrow_util
+#
+
+# Headers: top level
+arrow_install_all_headers("arrow/util")
+
+#
+# arrow_test_main
+#
+
+if(WIN32)
+ # This manifest enables long file paths on Windows 10+
+ # See https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file#enable-long-paths-in-windows-10-version-1607-and-later
+ if(MSVC)
+ set(IO_UTIL_TEST_SOURCES io_util_test.cc io_util_test.manifest)
+ else()
+ set(IO_UTIL_TEST_SOURCES io_util_test.cc io_util_test.rc)
+ endif()
+else()
+ set(IO_UTIL_TEST_SOURCES io_util_test.cc)
+endif()
+
+add_arrow_test(utility-test
+ SOURCES
+ align_util_test.cc
+ async_generator_test.cc
+ async_util_test.cc
+ bit_block_counter_test.cc
+ bit_util_test.cc
+ cache_test.cc
+ checked_cast_test.cc
+ compression_test.cc
+ decimal_test.cc
+ formatting_util_test.cc
+ key_value_metadata_test.cc
+ hashing_test.cc
+ int_util_test.cc
+ ${IO_UTIL_TEST_SOURCES}
+ iterator_test.cc
+ logging_test.cc
+ queue_test.cc
+ range_test.cc
+ reflection_test.cc
+ rle_encoding_test.cc
+ small_vector_test.cc
+ stl_util_test.cc
+ string_test.cc
+ tdigest_test.cc
+ test_common.cc
+ time_test.cc
+ trie_test.cc
+ uri_test.cc
+ utf8_util_test.cc
+ value_parsing_test.cc
+ variant_test.cc)
+
+add_arrow_test(threading-utility-test
+ SOURCES
+ cancel_test.cc
+ counting_semaphore_test.cc
+ future_test.cc
+ task_group_test.cc
+ thread_pool_test.cc)
+
+add_arrow_benchmark(bit_block_counter_benchmark)
+add_arrow_benchmark(bit_util_benchmark)
+add_arrow_benchmark(bitmap_reader_benchmark)
+add_arrow_benchmark(cache_benchmark)
+add_arrow_benchmark(compression_benchmark)
+add_arrow_benchmark(decimal_benchmark)
+add_arrow_benchmark(hashing_benchmark)
+add_arrow_benchmark(int_util_benchmark)
+add_arrow_benchmark(machine_benchmark)
+add_arrow_benchmark(queue_benchmark)
+add_arrow_benchmark(range_benchmark)
+add_arrow_benchmark(small_vector_benchmark)
+add_arrow_benchmark(tdigest_benchmark)
+add_arrow_benchmark(thread_pool_benchmark)
+add_arrow_benchmark(trie_benchmark)
+add_arrow_benchmark(utf8_util_benchmark)
+add_arrow_benchmark(value_parsing_benchmark)
+add_arrow_benchmark(variant_benchmark)
diff --git a/src/arrow/cpp/src/arrow/util/algorithm.h b/src/arrow/cpp/src/arrow/util/algorithm.h
new file mode 100644
index 000000000..2a0e6ba70
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/algorithm.h
@@ -0,0 +1,33 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/result.h"
+
+namespace arrow {
+
+template <typename InputIterator, typename OutputIterator, typename UnaryOperation>
+Status MaybeTransform(InputIterator first, InputIterator last, OutputIterator out,
+ UnaryOperation unary_op) {
+ for (; first != last; ++first, (void)++out) {
+ ARROW_ASSIGN_OR_RAISE(*out, unary_op(*first));
+ }
+ return Status::OK();
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/align_util.h b/src/arrow/cpp/src/arrow/util/align_util.h
new file mode 100644
index 000000000..4c25a1a17
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/align_util.h
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+
+#include "arrow/util/bit_util.h"
+
+namespace arrow {
+namespace internal {
+
+struct BitmapWordAlignParams {
+ int64_t leading_bits;
+ int64_t trailing_bits;
+ int64_t trailing_bit_offset;
+ const uint8_t* aligned_start;
+ int64_t aligned_bits;
+ int64_t aligned_words;
+};
+
+// Compute parameters for accessing a bitmap using aligned word instructions.
+// The returned parameters describe:
+// - a leading area of size `leading_bits` before the aligned words
+// - a word-aligned area of size `aligned_bits`
+// - a trailing area of size `trailing_bits` after the aligned words
+template <uint64_t ALIGN_IN_BYTES>
+inline BitmapWordAlignParams BitmapWordAlign(const uint8_t* data, int64_t bit_offset,
+ int64_t length) {
+ static_assert(BitUtil::IsPowerOf2(ALIGN_IN_BYTES),
+ "ALIGN_IN_BYTES should be a positive power of two");
+ constexpr uint64_t ALIGN_IN_BITS = ALIGN_IN_BYTES * 8;
+
+ BitmapWordAlignParams p;
+
+ // Compute a "bit address" that we can align up to ALIGN_IN_BITS.
+ // We don't care about losing the upper bits since we are only interested in the
+ // difference between both addresses.
+ const uint64_t bit_addr =
+ reinterpret_cast<size_t>(data) * 8 + static_cast<uint64_t>(bit_offset);
+ const uint64_t aligned_bit_addr = BitUtil::RoundUpToPowerOf2(bit_addr, ALIGN_IN_BITS);
+
+ p.leading_bits = std::min<int64_t>(length, aligned_bit_addr - bit_addr);
+ p.aligned_words = (length - p.leading_bits) / ALIGN_IN_BITS;
+ p.aligned_bits = p.aligned_words * ALIGN_IN_BITS;
+ p.trailing_bits = length - p.leading_bits - p.aligned_bits;
+ p.trailing_bit_offset = bit_offset + p.leading_bits + p.aligned_bits;
+
+ p.aligned_start = data + (bit_offset + p.leading_bits) / 8;
+ return p;
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/align_util_test.cc b/src/arrow/cpp/src/arrow/util/align_util_test.cc
new file mode 100644
index 000000000..2f6380c62
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/align_util_test.cc
@@ -0,0 +1,150 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/util/align_util.h"
+
+namespace arrow {
+namespace internal {
+
+template <int64_t NBYTES>
+void CheckBitmapWordAlign(const uint8_t* data, int64_t bit_offset, int64_t length,
+ BitmapWordAlignParams expected) {
+ auto p = BitmapWordAlign<static_cast<uint64_t>(NBYTES)>(data, bit_offset, length);
+
+ ASSERT_EQ(p.leading_bits, expected.leading_bits);
+ ASSERT_EQ(p.trailing_bits, expected.trailing_bits);
+ if (p.trailing_bits > 0) {
+ // Only relevant if trailing_bits > 0
+ ASSERT_EQ(p.trailing_bit_offset, expected.trailing_bit_offset);
+ }
+ ASSERT_EQ(p.aligned_bits, expected.aligned_bits);
+ ASSERT_EQ(p.aligned_words, expected.aligned_words);
+ if (p.aligned_bits > 0) {
+ // Only relevant if aligned_bits > 0
+ ASSERT_EQ(p.aligned_start, expected.aligned_start);
+ }
+
+ // Invariants
+ ASSERT_LT(p.leading_bits, NBYTES * 8);
+ ASSERT_LT(p.trailing_bits, NBYTES * 8);
+ ASSERT_EQ(p.leading_bits + p.aligned_bits + p.trailing_bits, length);
+ ASSERT_EQ(p.aligned_bits, NBYTES * 8 * p.aligned_words);
+ if (p.aligned_bits > 0) {
+ ASSERT_EQ(reinterpret_cast<size_t>(p.aligned_start) & (NBYTES - 1), 0);
+ }
+ if (p.trailing_bits > 0) {
+ ASSERT_EQ(p.trailing_bit_offset, bit_offset + p.leading_bits + p.aligned_bits);
+ ASSERT_EQ(p.trailing_bit_offset + p.trailing_bits, bit_offset + length);
+ }
+}
+
+TEST(BitmapWordAlign, AlignedDataStart) {
+ alignas(8) char buf[136];
+
+ // A 8-byte aligned pointer
+ const uint8_t* P = reinterpret_cast<const uint8_t*>(buf);
+ const uint8_t* A = P;
+
+ // {leading_bits, trailing_bits, trailing_bit_offset,
+ // aligned_start, aligned_bits, aligned_words}
+ CheckBitmapWordAlign<8>(P, 0, 0, {0, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 0, 13, {0, 13, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 0, 63, {0, 63, 0, A, 0, 0});
+
+ CheckBitmapWordAlign<8>(P, 0, 64, {0, 0, 0, A, 64, 1});
+ CheckBitmapWordAlign<8>(P, 0, 73, {0, 9, 64, A, 64, 1});
+ CheckBitmapWordAlign<8>(P, 0, 191, {0, 63, 128, A, 128, 2});
+
+ CheckBitmapWordAlign<8>(P, 5, 0, {0, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 5, 13, {13, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 5, 59, {59, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 5, 60, {59, 1, 64, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 5, 64, {59, 5, 64, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 5, 122, {59, 63, 64, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 5, 123, {59, 0, 64, A + 8, 64, 1});
+ CheckBitmapWordAlign<8>(P, 5, 314, {59, 63, 256, A + 8, 192, 3});
+ CheckBitmapWordAlign<8>(P, 5, 315, {59, 0, 320, A + 8, 256, 4});
+
+ CheckBitmapWordAlign<8>(P, 63, 0, {0, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 63, 1, {1, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 63, 2, {1, 1, 64, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 63, 64, {1, 63, 64, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 63, 65, {1, 0, 128, A + 8, 64, 1});
+ CheckBitmapWordAlign<8>(P, 63, 128, {1, 63, 128, A + 8, 64, 1});
+ CheckBitmapWordAlign<8>(P, 63, 129, {1, 0, 192, A + 8, 128, 2});
+
+ CheckBitmapWordAlign<8>(P, 1024, 0, {0, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 1024, 130, {0, 2, 1152, A + 128, 128, 2});
+
+ CheckBitmapWordAlign<8>(P, 1025, 1, {1, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 1025, 63, {63, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 1025, 64, {63, 1, 1088, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 1025, 128, {63, 1, 1152, A + 136, 64, 1});
+}
+
+TEST(BitmapWordAlign, UnalignedDataStart) {
+ alignas(8) char buf[136];
+
+ const uint8_t* P = reinterpret_cast<const uint8_t*>(buf) + 1;
+ const uint8_t* A = P + 7;
+
+ // {leading_bits, trailing_bits, trailing_bit_offset,
+ // aligned_start, aligned_bits, aligned_words}
+ CheckBitmapWordAlign<8>(P, 0, 0, {0, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 0, 13, {13, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 0, 56, {56, 0, 0, A, 0, 0});
+
+ CheckBitmapWordAlign<8>(P, 0, 57, {56, 1, 56, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 0, 119, {56, 63, 56, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 0, 120, {56, 0, 120, A, 64, 1});
+ CheckBitmapWordAlign<8>(P, 0, 184, {56, 0, 184, A, 128, 2});
+ CheckBitmapWordAlign<8>(P, 0, 185, {56, 1, 184, A, 128, 2});
+
+ CheckBitmapWordAlign<8>(P, 55, 0, {0, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 55, 1, {1, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 55, 2, {1, 1, 56, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 55, 66, {1, 1, 120, A, 64, 1});
+
+ // (P + 56 bits) is 64-bit aligned
+ CheckBitmapWordAlign<8>(P, 56, 0, {0, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 56, 1, {0, 1, 56, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 56, 63, {0, 63, 56, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 56, 191, {0, 63, 184, A, 128, 2});
+
+ // (P + 1016 bits) is 64-bit aligned
+ CheckBitmapWordAlign<8>(P, 1016, 0, {0, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 1016, 5, {0, 5, 1016, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 1016, 63, {0, 63, 1016, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 1016, 64, {0, 0, 1080, A + 120, 64, 1});
+ CheckBitmapWordAlign<8>(P, 1016, 129, {0, 1, 1144, A + 120, 128, 2});
+
+ CheckBitmapWordAlign<8>(P, 1017, 0, {0, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 1017, 1, {1, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 1017, 63, {63, 0, 0, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 1017, 64, {63, 1, 1080, A, 0, 0});
+ CheckBitmapWordAlign<8>(P, 1017, 128, {63, 1, 1144, A + 128, 64, 1});
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/aligned_storage.h b/src/arrow/cpp/src/arrow/util/aligned_storage.h
new file mode 100644
index 000000000..f6acb36c9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/aligned_storage.h
@@ -0,0 +1,127 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstring>
+#include <type_traits>
+#include <utility>
+
+#include "arrow/util/launder.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace internal {
+
+template <typename T>
+class AlignedStorage {
+ public:
+ static constexpr bool can_memcpy = std::is_trivial<T>::value;
+
+#if __cpp_constexpr >= 201304L // non-const constexpr
+ constexpr T* get() noexcept { return launder(reinterpret_cast<T*>(&data_)); }
+#else
+ T* get() noexcept { return launder(reinterpret_cast<T*>(&data_)); }
+#endif
+
+ constexpr const T* get() const noexcept {
+ return launder(reinterpret_cast<const T*>(&data_));
+ }
+
+ void destroy() noexcept {
+ if (!std::is_trivially_destructible<T>::value) {
+ get()->~T();
+ }
+ }
+
+ template <typename... A>
+ void construct(A&&... args) noexcept {
+ new (&data_) T(std::forward<A>(args)...);
+ }
+
+ template <typename V>
+ void assign(V&& v) noexcept {
+ *get() = std::forward<V>(v);
+ }
+
+ void move_construct(AlignedStorage* other) noexcept {
+ new (&data_) T(std::move(*other->get()));
+ }
+
+ void move_assign(AlignedStorage* other) noexcept { *get() = std::move(*other->get()); }
+
+ template <bool CanMemcpy = can_memcpy>
+ static typename std::enable_if<CanMemcpy>::type move_construct_several(
+ AlignedStorage* ARROW_RESTRICT src, AlignedStorage* ARROW_RESTRICT dest, size_t n,
+ size_t memcpy_length) noexcept {
+ memcpy(dest->get(), src->get(), memcpy_length * sizeof(T));
+ }
+
+ template <bool CanMemcpy = can_memcpy>
+ static typename std::enable_if<CanMemcpy>::type
+ move_construct_several_and_destroy_source(AlignedStorage* ARROW_RESTRICT src,
+ AlignedStorage* ARROW_RESTRICT dest, size_t n,
+ size_t memcpy_length) noexcept {
+ memcpy(dest->get(), src->get(), memcpy_length * sizeof(T));
+ }
+
+ template <bool CanMemcpy = can_memcpy>
+ static typename std::enable_if<!CanMemcpy>::type move_construct_several(
+ AlignedStorage* ARROW_RESTRICT src, AlignedStorage* ARROW_RESTRICT dest, size_t n,
+ size_t memcpy_length) noexcept {
+ for (size_t i = 0; i < n; ++i) {
+ new (dest[i].get()) T(std::move(*src[i].get()));
+ }
+ }
+
+ template <bool CanMemcpy = can_memcpy>
+ static typename std::enable_if<!CanMemcpy>::type
+ move_construct_several_and_destroy_source(AlignedStorage* ARROW_RESTRICT src,
+ AlignedStorage* ARROW_RESTRICT dest, size_t n,
+ size_t memcpy_length) noexcept {
+ for (size_t i = 0; i < n; ++i) {
+ new (dest[i].get()) T(std::move(*src[i].get()));
+ src[i].destroy();
+ }
+ }
+
+ static void move_construct_several(AlignedStorage* ARROW_RESTRICT src,
+ AlignedStorage* ARROW_RESTRICT dest,
+ size_t n) noexcept {
+ move_construct_several(src, dest, n, n);
+ }
+
+ static void move_construct_several_and_destroy_source(
+ AlignedStorage* ARROW_RESTRICT src, AlignedStorage* ARROW_RESTRICT dest,
+ size_t n) noexcept {
+ move_construct_several_and_destroy_source(src, dest, n, n);
+ }
+
+ static void destroy_several(AlignedStorage* p, size_t n) noexcept {
+ if (!std::is_trivially_destructible<T>::value) {
+ for (size_t i = 0; i < n; ++i) {
+ p[i].destroy();
+ }
+ }
+ }
+
+ private:
+ typename std::aligned_storage<sizeof(T), alignof(T)>::type data_;
+};
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/async_generator.h b/src/arrow/cpp/src/arrow/util/async_generator.h
new file mode 100644
index 000000000..0948e5537
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/async_generator.h
@@ -0,0 +1,1804 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <atomic>
+#include <cassert>
+#include <cstring>
+#include <deque>
+#include <limits>
+#include <queue>
+
+#include "arrow/util/async_util.h"
+#include "arrow/util/functional.h"
+#include "arrow/util/future.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/mutex.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/queue.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+// The methods in this file create, modify, and utilize AsyncGenerator which is an
+// iterator of futures. This allows an asynchronous source (like file input) to be run
+// through a pipeline in the same way that iterators can be used to create pipelined
+// workflows.
+//
+// In order to support pipeline parallelism we introduce the concept of asynchronous
+// reentrancy. This is different than synchronous reentrancy. With synchronous code a
+// function is reentrant if the function can be called again while a previous call to that
+// function is still running. Unless otherwise specified none of these generators are
+// synchronously reentrant. Care should be taken to avoid calling them in such a way (and
+// the utilities Visit/Collect/Await take care to do this).
+//
+// Asynchronous reentrancy on the other hand means the function is called again before the
+// future returned by the function is marked finished (but after the call to get the
+// future returns). Some of these generators are async-reentrant while others (e.g.
+// those that depend on ordered processing like decompression) are not. Read the MakeXYZ
+// function comments to determine which generators support async reentrancy.
+//
+// Note: Generators that are not asynchronously reentrant can still support readahead
+// (\see MakeSerialReadaheadGenerator).
+//
+// Readahead operators, and some other operators, may introduce queueing. Any operators
+// that introduce buffering should detail the amount of buffering they introduce in their
+// MakeXYZ function comments.
+template <typename T>
+using AsyncGenerator = std::function<Future<T>()>;
+
+template <typename T>
+struct IterationTraits<AsyncGenerator<T>> {
+ /// \brief by default when iterating through a sequence of AsyncGenerator<T>,
+ /// an empty function indicates the end of iteration.
+ static AsyncGenerator<T> End() { return AsyncGenerator<T>(); }
+
+ static bool IsEnd(const AsyncGenerator<T>& val) { return !val; }
+};
+
+template <typename T>
+Future<T> AsyncGeneratorEnd() {
+ return Future<T>::MakeFinished(IterationTraits<T>::End());
+}
+
+/// returning a future that completes when all have been visited
+template <typename T, typename Visitor>
+Future<> VisitAsyncGenerator(AsyncGenerator<T> generator, Visitor visitor) {
+ struct LoopBody {
+ struct Callback {
+ Result<ControlFlow<>> operator()(const T& next) {
+ if (IsIterationEnd(next)) {
+ return Break();
+ } else {
+ auto visited = visitor(next);
+ if (visited.ok()) {
+ return Continue();
+ } else {
+ return visited;
+ }
+ }
+ }
+
+ Visitor visitor;
+ };
+
+ Future<ControlFlow<>> operator()() {
+ Callback callback{visitor};
+ auto next = generator();
+ return next.Then(std::move(callback));
+ }
+
+ AsyncGenerator<T> generator;
+ Visitor visitor;
+ };
+
+ return Loop(LoopBody{std::move(generator), std::move(visitor)});
+}
+
+/// \brief Wait for an async generator to complete, discarding results.
+template <typename T>
+Future<> DiscardAllFromAsyncGenerator(AsyncGenerator<T> generator) {
+ std::function<Status(T)> visitor = [](const T&) { return Status::OK(); };
+ return VisitAsyncGenerator(generator, visitor);
+}
+
+/// \brief Collect the results of an async generator into a vector
+template <typename T>
+Future<std::vector<T>> CollectAsyncGenerator(AsyncGenerator<T> generator) {
+ auto vec = std::make_shared<std::vector<T>>();
+ struct LoopBody {
+ Future<ControlFlow<std::vector<T>>> operator()() {
+ auto next = generator_();
+ auto vec = vec_;
+ return next.Then([vec](const T& result) -> Result<ControlFlow<std::vector<T>>> {
+ if (IsIterationEnd(result)) {
+ return Break(*vec);
+ } else {
+ vec->push_back(result);
+ return Continue();
+ }
+ });
+ }
+ AsyncGenerator<T> generator_;
+ std::shared_ptr<std::vector<T>> vec_;
+ };
+ return Loop(LoopBody{std::move(generator), std::move(vec)});
+}
+
+/// \see MakeMappedGenerator
+template <typename T, typename V>
+class MappingGenerator {
+ public:
+ MappingGenerator(AsyncGenerator<T> source, std::function<Future<V>(const T&)> map)
+ : state_(std::make_shared<State>(std::move(source), std::move(map))) {}
+
+ Future<V> operator()() {
+ auto future = Future<V>::Make();
+ bool should_trigger;
+ {
+ auto guard = state_->mutex.Lock();
+ if (state_->finished) {
+ return AsyncGeneratorEnd<V>();
+ }
+ should_trigger = state_->waiting_jobs.empty();
+ state_->waiting_jobs.push_back(future);
+ }
+ if (should_trigger) {
+ state_->source().AddCallback(Callback{state_});
+ }
+ return future;
+ }
+
+ private:
+ struct State {
+ State(AsyncGenerator<T> source, std::function<Future<V>(const T&)> map)
+ : source(std::move(source)),
+ map(std::move(map)),
+ waiting_jobs(),
+ mutex(),
+ finished(false) {}
+
+ void Purge() {
+ // This might be called by an original callback (if the source iterator fails or
+ // ends) or by a mapped callback (if the map function fails or ends prematurely).
+ // Either way it should only be called once and after finished is set so there is no
+ // need to guard access to `waiting_jobs`.
+ while (!waiting_jobs.empty()) {
+ waiting_jobs.front().MarkFinished(IterationTraits<V>::End());
+ waiting_jobs.pop_front();
+ }
+ }
+
+ AsyncGenerator<T> source;
+ std::function<Future<V>(const T&)> map;
+ std::deque<Future<V>> waiting_jobs;
+ util::Mutex mutex;
+ bool finished;
+ };
+
+ struct Callback;
+
+ struct MappedCallback {
+ void operator()(const Result<V>& maybe_next) {
+ bool end = !maybe_next.ok() || IsIterationEnd(*maybe_next);
+ bool should_purge = false;
+ if (end) {
+ {
+ auto guard = state->mutex.Lock();
+ should_purge = !state->finished;
+ state->finished = true;
+ }
+ }
+ sink.MarkFinished(maybe_next);
+ if (should_purge) {
+ state->Purge();
+ }
+ }
+ std::shared_ptr<State> state;
+ Future<V> sink;
+ };
+
+ struct Callback {
+ void operator()(const Result<T>& maybe_next) {
+ Future<V> sink;
+ bool end = !maybe_next.ok() || IsIterationEnd(*maybe_next);
+ bool should_purge = false;
+ bool should_trigger;
+ {
+ auto guard = state->mutex.Lock();
+ // A MappedCallback may have purged or be purging the queue;
+ // we shouldn't do anything here.
+ if (state->finished) return;
+ if (end) {
+ should_purge = !state->finished;
+ state->finished = true;
+ }
+ sink = state->waiting_jobs.front();
+ state->waiting_jobs.pop_front();
+ should_trigger = !end && !state->waiting_jobs.empty();
+ }
+ if (should_purge) {
+ state->Purge();
+ }
+ if (should_trigger) {
+ state->source().AddCallback(Callback{state});
+ }
+ if (maybe_next.ok()) {
+ const T& val = maybe_next.ValueUnsafe();
+ if (IsIterationEnd(val)) {
+ sink.MarkFinished(IterationTraits<V>::End());
+ } else {
+ Future<V> mapped_fut = state->map(val);
+ mapped_fut.AddCallback(MappedCallback{std::move(state), std::move(sink)});
+ }
+ } else {
+ sink.MarkFinished(maybe_next.status());
+ }
+ }
+
+ std::shared_ptr<State> state;
+ };
+
+ std::shared_ptr<State> state_;
+};
+
+/// \brief Create a generator that will apply the map function to each element of
+/// source. The map function is not called on the end token.
+///
+/// Note: This function makes a copy of `map` for each item
+/// Note: Errors returned from the `map` function will be propagated
+///
+/// If the source generator is async-reentrant then this generator will be also
+template <typename T, typename MapFn,
+ typename Mapped = detail::result_of_t<MapFn(const T&)>,
+ typename V = typename EnsureFuture<Mapped>::type::ValueType>
+AsyncGenerator<V> MakeMappedGenerator(AsyncGenerator<T> source_generator, MapFn map) {
+ struct MapCallback {
+ MapFn map_;
+
+ Future<V> operator()(const T& val) { return ToFuture(map_(val)); }
+ };
+
+ return MappingGenerator<T, V>(std::move(source_generator), MapCallback{std::move(map)});
+}
+
+/// \brief Create a generator that will apply the map function to
+/// each element of source. The map function is not called on the end
+/// token. The result of the map function should be another
+/// generator; all these generators will then be flattened to produce
+/// a single stream of items.
+///
+/// Note: This function makes a copy of `map` for each item
+/// Note: Errors returned from the `map` function will be propagated
+///
+/// If the source generator is async-reentrant then this generator will be also
+template <typename T, typename MapFn,
+ typename Mapped = detail::result_of_t<MapFn(const T&)>,
+ typename V = typename EnsureFuture<Mapped>::type::ValueType>
+AsyncGenerator<T> MakeFlatMappedGenerator(AsyncGenerator<T> source_generator, MapFn map) {
+ return MakeConcatenatedGenerator(
+ MakeMappedGenerator(std::move(source_generator), std::move(map)));
+}
+
+/// \see MakeSequencingGenerator
+template <typename T, typename ComesAfter, typename IsNext>
+class SequencingGenerator {
+ public:
+ SequencingGenerator(AsyncGenerator<T> source, ComesAfter compare, IsNext is_next,
+ T initial_value)
+ : state_(std::make_shared<State>(std::move(source), std::move(compare),
+ std::move(is_next), std::move(initial_value))) {}
+
+ Future<T> operator()() {
+ {
+ auto guard = state_->mutex.Lock();
+ // We can send a result immediately if the top of the queue is either an
+ // error or the next item
+ if (!state_->queue.empty() &&
+ (!state_->queue.top().ok() ||
+ state_->is_next(state_->previous_value, *state_->queue.top()))) {
+ auto result = std::move(state_->queue.top());
+ if (result.ok()) {
+ state_->previous_value = *result;
+ }
+ state_->queue.pop();
+ return Future<T>::MakeFinished(result);
+ }
+ if (state_->finished) {
+ return AsyncGeneratorEnd<T>();
+ }
+ // The next item is not in the queue so we will need to wait
+ auto new_waiting_fut = Future<T>::Make();
+ state_->waiting_future = new_waiting_fut;
+ guard.Unlock();
+ state_->source().AddCallback(Callback{state_});
+ return new_waiting_fut;
+ }
+ }
+
+ private:
+ struct WrappedComesAfter {
+ bool operator()(const Result<T>& left, const Result<T>& right) {
+ if (!left.ok() || !right.ok()) {
+ // Should never happen
+ return false;
+ }
+ return compare(*left, *right);
+ }
+ ComesAfter compare;
+ };
+
+ struct State {
+ State(AsyncGenerator<T> source, ComesAfter compare, IsNext is_next, T initial_value)
+ : source(std::move(source)),
+ is_next(std::move(is_next)),
+ previous_value(std::move(initial_value)),
+ waiting_future(),
+ queue(WrappedComesAfter{compare}),
+ finished(false),
+ mutex() {}
+
+ AsyncGenerator<T> source;
+ IsNext is_next;
+ T previous_value;
+ Future<T> waiting_future;
+ std::priority_queue<Result<T>, std::vector<Result<T>>, WrappedComesAfter> queue;
+ bool finished;
+ util::Mutex mutex;
+ };
+
+ class Callback {
+ public:
+ explicit Callback(std::shared_ptr<State> state) : state_(std::move(state)) {}
+
+ void operator()(const Result<T> result) {
+ Future<T> to_deliver;
+ bool finished;
+ {
+ auto guard = state_->mutex.Lock();
+ bool ready_to_deliver = false;
+ if (!result.ok()) {
+ // Clear any cached results
+ while (!state_->queue.empty()) {
+ state_->queue.pop();
+ }
+ ready_to_deliver = true;
+ state_->finished = true;
+ } else if (IsIterationEnd<T>(result.ValueUnsafe())) {
+ ready_to_deliver = state_->queue.empty();
+ state_->finished = true;
+ } else {
+ ready_to_deliver = state_->is_next(state_->previous_value, *result);
+ }
+
+ if (ready_to_deliver && state_->waiting_future.is_valid()) {
+ to_deliver = state_->waiting_future;
+ if (result.ok()) {
+ state_->previous_value = *result;
+ }
+ } else {
+ state_->queue.push(result);
+ }
+ // Capture state_->finished so we can access it outside the mutex
+ finished = state_->finished;
+ }
+ // Must deliver result outside of the mutex
+ if (to_deliver.is_valid()) {
+ to_deliver.MarkFinished(result);
+ } else {
+ // Otherwise, if we didn't get the next item (or a terminal item), we
+ // need to keep looking
+ if (!finished) {
+ state_->source().AddCallback(Callback{state_});
+ }
+ }
+ }
+
+ private:
+ const std::shared_ptr<State> state_;
+ };
+
+ const std::shared_ptr<State> state_;
+};
+
+/// \brief Buffer an AsyncGenerator to return values in sequence order ComesAfter
+/// and IsNext determine the sequence order.
+///
+/// ComesAfter should be a BinaryPredicate that only returns true if a comes after b
+///
+/// IsNext should be a BinaryPredicate that returns true, given `a` and `b`, only if
+/// `b` follows immediately after `a`. It should return true given `initial_value` and
+/// `b` if `b` is the first item in the sequence.
+///
+/// This operator will queue unboundedly while waiting for the next item. It is intended
+/// for jittery sources that might scatter an ordered sequence. It is NOT intended to
+/// sort. Using it to try and sort could result in excessive RAM usage. This generator
+/// will queue up to N blocks where N is the max "out of order"ness of the source.
+///
+/// For example, if the source is 1,6,2,5,4,3 it will queue 3 blocks because 3 is 3
+/// blocks beyond where it belongs.
+///
+/// This generator is not async-reentrant but it consists only of a simple log(n)
+/// insertion into a priority queue.
+template <typename T, typename ComesAfter, typename IsNext>
+AsyncGenerator<T> MakeSequencingGenerator(AsyncGenerator<T> source_generator,
+ ComesAfter compare, IsNext is_next,
+ T initial_value) {
+ return SequencingGenerator<T, ComesAfter, IsNext>(
+ std::move(source_generator), std::move(compare), std::move(is_next),
+ std::move(initial_value));
+}
+
+/// \see MakeTransformedGenerator
+template <typename T, typename V>
+class TransformingGenerator {
+ // The transforming generator state will be referenced as an async generator but will
+ // also be referenced via callback to various futures. If the async generator owner
+ // moves it around we need the state to be consistent for future callbacks.
+ struct TransformingGeneratorState
+ : std::enable_shared_from_this<TransformingGeneratorState> {
+ TransformingGeneratorState(AsyncGenerator<T> generator, Transformer<T, V> transformer)
+ : generator_(std::move(generator)),
+ transformer_(std::move(transformer)),
+ last_value_(),
+ finished_() {}
+
+ Future<V> operator()() {
+ while (true) {
+ auto maybe_next_result = Pump();
+ if (!maybe_next_result.ok()) {
+ return Future<V>::MakeFinished(maybe_next_result.status());
+ }
+ auto maybe_next = std::move(maybe_next_result).ValueUnsafe();
+ if (maybe_next.has_value()) {
+ return Future<V>::MakeFinished(*std::move(maybe_next));
+ }
+
+ auto next_fut = generator_();
+ // If finished already, process results immediately inside the loop to avoid
+ // stack overflow
+ if (next_fut.is_finished()) {
+ auto next_result = next_fut.result();
+ if (next_result.ok()) {
+ last_value_ = *next_result;
+ } else {
+ return Future<V>::MakeFinished(next_result.status());
+ }
+ // Otherwise, if not finished immediately, add callback to process results
+ } else {
+ auto self = this->shared_from_this();
+ return next_fut.Then([self](const T& next_result) {
+ self->last_value_ = next_result;
+ return (*self)();
+ });
+ }
+ }
+ }
+
+ // See comment on TransformingIterator::Pump
+ Result<util::optional<V>> Pump() {
+ if (!finished_ && last_value_.has_value()) {
+ ARROW_ASSIGN_OR_RAISE(TransformFlow<V> next, transformer_(*last_value_));
+ if (next.ReadyForNext()) {
+ if (IsIterationEnd(*last_value_)) {
+ finished_ = true;
+ }
+ last_value_.reset();
+ }
+ if (next.Finished()) {
+ finished_ = true;
+ }
+ if (next.HasValue()) {
+ return next.Value();
+ }
+ }
+ if (finished_) {
+ return IterationTraits<V>::End();
+ }
+ return util::nullopt;
+ }
+
+ AsyncGenerator<T> generator_;
+ Transformer<T, V> transformer_;
+ util::optional<T> last_value_;
+ bool finished_;
+ };
+
+ public:
+ explicit TransformingGenerator(AsyncGenerator<T> generator,
+ Transformer<T, V> transformer)
+ : state_(std::make_shared<TransformingGeneratorState>(std::move(generator),
+ std::move(transformer))) {}
+
+ Future<V> operator()() { return (*state_)(); }
+
+ protected:
+ std::shared_ptr<TransformingGeneratorState> state_;
+};
+
+/// \brief Transform an async generator using a transformer function returning a new
+/// AsyncGenerator
+///
+/// The transform function here behaves exactly the same as the transform function in
+/// MakeTransformedIterator and you can safely use the same transform function to
+/// transform both synchronous and asynchronous streams.
+///
+/// This generator is not async-reentrant
+///
+/// This generator may queue up to 1 instance of T but will not delay
+template <typename T, typename V>
+AsyncGenerator<V> MakeTransformedGenerator(AsyncGenerator<T> generator,
+ Transformer<T, V> transformer) {
+ return TransformingGenerator<T, V>(generator, transformer);
+}
+
+/// \see MakeSerialReadaheadGenerator
+template <typename T>
+class SerialReadaheadGenerator {
+ public:
+ SerialReadaheadGenerator(AsyncGenerator<T> source_generator, int max_readahead)
+ : state_(std::make_shared<State>(std::move(source_generator), max_readahead)) {}
+
+ Future<T> operator()() {
+ if (state_->first_) {
+ // Lazy generator, need to wait for the first ask to prime the pump
+ state_->first_ = false;
+ auto next = state_->source_();
+ return next.Then(Callback{state_}, ErrCallback{state_});
+ }
+
+ // This generator is not async-reentrant. We won't be called until the last
+ // future finished so we know there is something in the queue
+ auto finished = state_->finished_.load();
+ if (finished && state_->readahead_queue_.IsEmpty()) {
+ return AsyncGeneratorEnd<T>();
+ }
+
+ std::shared_ptr<Future<T>> next;
+ if (!state_->readahead_queue_.Read(next)) {
+ return Status::UnknownError("Could not read from readahead_queue");
+ }
+
+ auto last_available = state_->spaces_available_.fetch_add(1);
+ if (last_available == 0 && !finished) {
+ // Reader idled out, we need to restart it
+ ARROW_RETURN_NOT_OK(state_->Pump(state_));
+ }
+ return *next;
+ }
+
+ private:
+ struct State {
+ State(AsyncGenerator<T> source, int max_readahead)
+ : first_(true),
+ source_(std::move(source)),
+ finished_(false),
+ // There is one extra "space" for the in-flight request
+ spaces_available_(max_readahead + 1),
+ // The SPSC queue has size-1 "usable" slots so we need to overallocate 1
+ readahead_queue_(max_readahead + 1) {}
+
+ Status Pump(const std::shared_ptr<State>& self) {
+ // Can't do readahead_queue.write(source().Then(...)) because then the
+ // callback might run immediately and add itself to the queue before this gets added
+ // to the queue messing up the order.
+ auto next_slot = std::make_shared<Future<T>>();
+ auto written = readahead_queue_.Write(next_slot);
+ if (!written) {
+ return Status::UnknownError("Could not write to readahead_queue");
+ }
+ // If this Pump is being called from a callback it is possible for the source to
+ // poll and read from the queue between the Write and this spot where we fill the
+ // value in. However, it is not possible for the future to read this value we are
+ // writing. That is because this callback (the callback for future X) must be
+ // finished before future X is marked complete and this source is not pulled
+ // reentrantly so it will not poll for future X+1 until this callback has completed.
+ *next_slot = source_().Then(Callback{self}, ErrCallback{self});
+ return Status::OK();
+ }
+
+ // Only accessed by the consumer end
+ bool first_;
+ // Accessed by both threads
+ AsyncGenerator<T> source_;
+ std::atomic<bool> finished_;
+ // The queue has a size but it is not atomic. We keep track of how many spaces are
+ // left in the queue here so we know if we've just written the last value and we need
+ // to stop reading ahead or if we've just read from a full queue and we need to
+ // restart reading ahead
+ std::atomic<uint32_t> spaces_available_;
+ // Needs to be a queue of shared_ptr and not Future because we set the value of the
+ // future after we add it to the queue
+ util::SpscQueue<std::shared_ptr<Future<T>>> readahead_queue_;
+ };
+
+ struct Callback {
+ Result<T> operator()(const T& next) {
+ if (IsIterationEnd(next)) {
+ state_->finished_.store(true);
+ return next;
+ }
+ auto last_available = state_->spaces_available_.fetch_sub(1);
+ if (last_available > 1) {
+ ARROW_RETURN_NOT_OK(state_->Pump(state_));
+ }
+ return next;
+ }
+
+ std::shared_ptr<State> state_;
+ };
+
+ struct ErrCallback {
+ Result<T> operator()(const Status& st) {
+ state_->finished_.store(true);
+ return st;
+ }
+
+ std::shared_ptr<State> state_;
+ };
+
+ std::shared_ptr<State> state_;
+};
+
+/// \see MakeFromFuture
+template <typename T>
+class FutureFirstGenerator {
+ public:
+ explicit FutureFirstGenerator(Future<AsyncGenerator<T>> future)
+ : state_(std::make_shared<State>(std::move(future))) {}
+
+ Future<T> operator()() {
+ if (state_->source_) {
+ return state_->source_();
+ } else {
+ auto state = state_;
+ return state_->future_.Then([state](const AsyncGenerator<T>& source) {
+ state->source_ = source;
+ return state->source_();
+ });
+ }
+ }
+
+ private:
+ struct State {
+ explicit State(Future<AsyncGenerator<T>> future) : future_(future), source_() {}
+
+ Future<AsyncGenerator<T>> future_;
+ AsyncGenerator<T> source_;
+ };
+
+ std::shared_ptr<State> state_;
+};
+
+/// \brief Transform a Future<AsyncGenerator<T>> into an AsyncGenerator<T>
+/// that waits for the future to complete as part of the first item.
+///
+/// This generator is not async-reentrant (even if the generator yielded by future is)
+///
+/// This generator does not queue
+template <typename T>
+AsyncGenerator<T> MakeFromFuture(Future<AsyncGenerator<T>> future) {
+ return FutureFirstGenerator<T>(std::move(future));
+}
+
+/// \brief Create a generator that will pull from the source into a queue. Unlike
+/// MakeReadaheadGenerator this will not pull reentrantly from the source.
+///
+/// The source generator does not need to be async-reentrant
+///
+/// This generator is not async-reentrant (even if the source is)
+///
+/// This generator may queue up to max_readahead additional instances of T
+template <typename T>
+AsyncGenerator<T> MakeSerialReadaheadGenerator(AsyncGenerator<T> source_generator,
+ int max_readahead) {
+ return SerialReadaheadGenerator<T>(std::move(source_generator), max_readahead);
+}
+
+/// \brief Create a generator that immediately pulls from the source
+///
+/// Typical generators do not pull from their source until they themselves
+/// are pulled. This generator does not follow that convention and will call
+/// generator() once before it returns. The returned generator will otherwise
+/// mirror the source.
+///
+/// This generator forwards aysnc-reentrant pressure to the source
+/// This generator buffers one item (the first result) until it is delivered.
+template <typename T>
+AsyncGenerator<T> MakeAutoStartingGenerator(AsyncGenerator<T> generator) {
+ struct AutostartGenerator {
+ Future<T> operator()() {
+ if (first_future->is_valid()) {
+ Future<T> result = *first_future;
+ *first_future = Future<T>();
+ return result;
+ }
+ return source();
+ }
+
+ std::shared_ptr<Future<T>> first_future;
+ AsyncGenerator<T> source;
+ };
+
+ std::shared_ptr<Future<T>> first_future = std::make_shared<Future<T>>(generator());
+ return AutostartGenerator{std::move(first_future), std::move(generator)};
+}
+
+/// \see MakeReadaheadGenerator
+template <typename T>
+class ReadaheadGenerator {
+ public:
+ ReadaheadGenerator(AsyncGenerator<T> source_generator, int max_readahead)
+ : state_(std::make_shared<State>(std::move(source_generator), max_readahead)) {}
+
+ Future<T> AddMarkFinishedContinuation(Future<T> fut) {
+ auto state = state_;
+ return fut.Then(
+ [state](const T& result) -> Result<T> {
+ state->MarkFinishedIfDone(result);
+ return result;
+ },
+ [state](const Status& err) -> Result<T> {
+ state->finished.store(true);
+ return err;
+ });
+ }
+
+ Future<T> operator()() {
+ if (state_->readahead_queue.empty()) {
+ // This is the first request, let's pump the underlying queue
+ for (int i = 0; i < state_->max_readahead; i++) {
+ auto next = state_->source_generator();
+ auto next_after_check = AddMarkFinishedContinuation(std::move(next));
+ state_->readahead_queue.push(std::move(next_after_check));
+ }
+ }
+ // Pop one and add one
+ auto result = state_->readahead_queue.front();
+ state_->readahead_queue.pop();
+ if (state_->finished.load()) {
+ state_->readahead_queue.push(AsyncGeneratorEnd<T>());
+ } else {
+ auto back_of_queue = state_->source_generator();
+ auto back_of_queue_after_check =
+ AddMarkFinishedContinuation(std::move(back_of_queue));
+ state_->readahead_queue.push(std::move(back_of_queue_after_check));
+ }
+ return result;
+ }
+
+ private:
+ struct State {
+ State(AsyncGenerator<T> source_generator, int max_readahead)
+ : source_generator(std::move(source_generator)), max_readahead(max_readahead) {
+ finished.store(false);
+ }
+
+ void MarkFinishedIfDone(const T& next_result) {
+ if (IsIterationEnd(next_result)) {
+ finished.store(true);
+ }
+ }
+
+ AsyncGenerator<T> source_generator;
+ int max_readahead;
+ std::atomic<bool> finished;
+ std::queue<Future<T>> readahead_queue;
+ };
+
+ std::shared_ptr<State> state_;
+};
+
+/// \brief A generator where the producer pushes items on a queue.
+///
+/// No back-pressure is applied, so this generator is mostly useful when
+/// producing the values is neither CPU- nor memory-expensive (e.g. fetching
+/// filesystem metadata).
+///
+/// This generator is not async-reentrant.
+template <typename T>
+class PushGenerator {
+ struct State {
+ explicit State(util::BackpressureOptions backpressure)
+ : backpressure(std::move(backpressure)) {}
+
+ void OpenBackpressureIfFreeUnlocked(util::Mutex::Guard&& guard) {
+ if (backpressure.toggle && result_q.size() < backpressure.resume_if_below) {
+ // Open might trigger callbacks so release the lock first
+ guard.Unlock();
+ backpressure.toggle->Open();
+ }
+ }
+
+ void CloseBackpressureIfFullUnlocked() {
+ if (backpressure.toggle && result_q.size() > backpressure.pause_if_above) {
+ backpressure.toggle->Close();
+ }
+ }
+
+ util::BackpressureOptions backpressure;
+ util::Mutex mutex;
+ std::deque<Result<T>> result_q;
+ util::optional<Future<T>> consumer_fut;
+ bool finished = false;
+ };
+
+ public:
+ /// Producer API for PushGenerator
+ class Producer {
+ public:
+ explicit Producer(const std::shared_ptr<State>& state) : weak_state_(state) {}
+
+ /// \brief Push a value on the queue
+ ///
+ /// True is returned if the value was pushed, false if the generator is
+ /// already closed or destroyed. If the latter, it is recommended to stop
+ /// producing any further values.
+ bool Push(Result<T> result) {
+ auto state = weak_state_.lock();
+ if (!state) {
+ // Generator was destroyed
+ return false;
+ }
+ auto lock = state->mutex.Lock();
+ if (state->finished) {
+ // Closed early
+ return false;
+ }
+ if (state->consumer_fut.has_value()) {
+ auto fut = std::move(state->consumer_fut.value());
+ state->consumer_fut.reset();
+ lock.Unlock(); // unlock before potentially invoking a callback
+ fut.MarkFinished(std::move(result));
+ } else {
+ state->result_q.push_back(std::move(result));
+ state->CloseBackpressureIfFullUnlocked();
+ }
+ return true;
+ }
+
+ /// \brief Tell the consumer we have finished producing
+ ///
+ /// It is allowed to call this and later call Push() again ("early close").
+ /// In this case, calls to Push() after the queue is closed are silently
+ /// ignored. This can help implementing non-trivial cancellation cases.
+ ///
+ /// True is returned on success, false if the generator is already closed
+ /// or destroyed.
+ bool Close() {
+ auto state = weak_state_.lock();
+ if (!state) {
+ // Generator was destroyed
+ return false;
+ }
+ auto lock = state->mutex.Lock();
+ if (state->finished) {
+ // Already closed
+ return false;
+ }
+ state->finished = true;
+ if (state->consumer_fut.has_value()) {
+ auto fut = std::move(state->consumer_fut.value());
+ state->consumer_fut.reset();
+ lock.Unlock(); // unlock before potentially invoking a callback
+ fut.MarkFinished(IterationTraits<T>::End());
+ }
+ return true;
+ }
+
+ /// Return whether the generator was closed or destroyed.
+ bool is_closed() const {
+ auto state = weak_state_.lock();
+ if (!state) {
+ // Generator was destroyed
+ return true;
+ }
+ auto lock = state->mutex.Lock();
+ return state->finished;
+ }
+
+ private:
+ const std::weak_ptr<State> weak_state_;
+ };
+
+ explicit PushGenerator(util::BackpressureOptions backpressure = {})
+ : state_(std::make_shared<State>(std::move(backpressure))) {}
+
+ /// Read an item from the queue
+ Future<T> operator()() const {
+ auto lock = state_->mutex.Lock();
+ assert(!state_->consumer_fut.has_value()); // Non-reentrant
+ if (!state_->result_q.empty()) {
+ auto fut = Future<T>::MakeFinished(std::move(state_->result_q.front()));
+ state_->result_q.pop_front();
+ state_->OpenBackpressureIfFreeUnlocked(std::move(lock));
+ return fut;
+ }
+ if (state_->finished) {
+ return AsyncGeneratorEnd<T>();
+ }
+ auto fut = Future<T>::Make();
+ state_->consumer_fut = fut;
+ return fut;
+ }
+
+ /// \brief Return producer-side interface
+ ///
+ /// The returned object must be used by the producer to push values on the queue.
+ /// Only a single Producer object should be instantiated.
+ Producer producer() { return Producer{state_}; }
+
+ private:
+ const std::shared_ptr<State> state_;
+};
+
+/// \brief Create a generator that pulls reentrantly from a source
+/// This generator will pull reentrantly from a source, ensuring that max_readahead
+/// requests are active at any given time.
+///
+/// The source generator must be async-reentrant
+///
+/// This generator itself is async-reentrant.
+///
+/// This generator may queue up to max_readahead instances of T
+template <typename T>
+AsyncGenerator<T> MakeReadaheadGenerator(AsyncGenerator<T> source_generator,
+ int max_readahead) {
+ return ReadaheadGenerator<T>(std::move(source_generator), max_readahead);
+}
+
+/// \brief Creates a generator that will yield finished futures from a vector
+///
+/// This generator is async-reentrant
+template <typename T>
+AsyncGenerator<T> MakeVectorGenerator(std::vector<T> vec) {
+ struct State {
+ explicit State(std::vector<T> vec_) : vec(std::move(vec_)), vec_idx(0) {}
+
+ std::vector<T> vec;
+ std::atomic<std::size_t> vec_idx;
+ };
+
+ auto state = std::make_shared<State>(std::move(vec));
+ return [state]() {
+ auto idx = state->vec_idx.fetch_add(1);
+ if (idx >= state->vec.size()) {
+ // Eagerly return memory
+ state->vec.clear();
+ return AsyncGeneratorEnd<T>();
+ }
+ return Future<T>::MakeFinished(state->vec[idx]);
+ };
+}
+
+/// \see MakeMergedGenerator
+template <typename T>
+class MergedGenerator {
+ public:
+ explicit MergedGenerator(AsyncGenerator<AsyncGenerator<T>> source,
+ int max_subscriptions)
+ : state_(std::make_shared<State>(std::move(source), max_subscriptions)) {}
+
+ Future<T> operator()() {
+ Future<T> waiting_future;
+ std::shared_ptr<DeliveredJob> delivered_job;
+ {
+ auto guard = state_->mutex.Lock();
+ if (!state_->delivered_jobs.empty()) {
+ delivered_job = std::move(state_->delivered_jobs.front());
+ state_->delivered_jobs.pop_front();
+ } else if (state_->finished) {
+ return IterationTraits<T>::End();
+ } else {
+ waiting_future = Future<T>::Make();
+ state_->waiting_jobs.push_back(std::make_shared<Future<T>>(waiting_future));
+ }
+ }
+ if (delivered_job) {
+ // deliverer will be invalid if outer callback encounters an error and delivers a
+ // failed result
+ if (delivered_job->deliverer) {
+ delivered_job->deliverer().AddCallback(
+ InnerCallback{state_, delivered_job->index});
+ }
+ return std::move(delivered_job->value);
+ }
+ if (state_->first) {
+ state_->first = false;
+ for (std::size_t i = 0; i < state_->active_subscriptions.size(); i++) {
+ state_->PullSource().AddCallback(OuterCallback{state_, i});
+ }
+ }
+ return waiting_future;
+ }
+
+ private:
+ struct DeliveredJob {
+ explicit DeliveredJob(AsyncGenerator<T> deliverer_, Result<T> value_,
+ std::size_t index_)
+ : deliverer(deliverer_), value(std::move(value_)), index(index_) {}
+
+ AsyncGenerator<T> deliverer;
+ Result<T> value;
+ std::size_t index;
+ };
+
+ struct State {
+ State(AsyncGenerator<AsyncGenerator<T>> source, int max_subscriptions)
+ : source(std::move(source)),
+ active_subscriptions(max_subscriptions),
+ delivered_jobs(),
+ waiting_jobs(),
+ mutex(),
+ first(true),
+ source_exhausted(false),
+ finished(false),
+ num_active_subscriptions(max_subscriptions) {}
+
+ Future<AsyncGenerator<T>> PullSource() {
+ // Need to guard access to source() so we don't pull sync-reentrantly which
+ // is never valid.
+ auto lock = mutex.Lock();
+ return source();
+ }
+
+ AsyncGenerator<AsyncGenerator<T>> source;
+ // active_subscriptions and delivered_jobs will be bounded by max_subscriptions
+ std::vector<AsyncGenerator<T>> active_subscriptions;
+ std::deque<std::shared_ptr<DeliveredJob>> delivered_jobs;
+ // waiting_jobs is unbounded, reentrant pulls (e.g. AddReadahead) will provide the
+ // backpressure
+ std::deque<std::shared_ptr<Future<T>>> waiting_jobs;
+ util::Mutex mutex;
+ bool first;
+ bool source_exhausted;
+ bool finished;
+ int num_active_subscriptions;
+ };
+
+ struct InnerCallback {
+ void operator()(const Result<T>& maybe_next_ref) {
+ Future<T> next_fut;
+ const Result<T>* maybe_next = &maybe_next_ref;
+
+ while (true) {
+ Future<T> sink;
+ bool sub_finished = maybe_next->ok() && IsIterationEnd(**maybe_next);
+ {
+ auto guard = state->mutex.Lock();
+ if (state->finished) {
+ // We've errored out so just ignore this result and don't keep pumping
+ return;
+ }
+ if (!sub_finished) {
+ if (state->waiting_jobs.empty()) {
+ state->delivered_jobs.push_back(std::make_shared<DeliveredJob>(
+ state->active_subscriptions[index], *maybe_next, index));
+ } else {
+ sink = std::move(*state->waiting_jobs.front());
+ state->waiting_jobs.pop_front();
+ }
+ }
+ }
+ if (sub_finished) {
+ state->PullSource().AddCallback(OuterCallback{state, index});
+ } else if (sink.is_valid()) {
+ sink.MarkFinished(*maybe_next);
+ if (!maybe_next->ok()) return;
+
+ next_fut = state->active_subscriptions[index]();
+ if (next_fut.TryAddCallback([this]() { return *this; })) {
+ return;
+ }
+ // Already completed. Avoid very deep recursion by looping
+ // here instead of relying on the callback.
+ maybe_next = &next_fut.result();
+ continue;
+ }
+ return;
+ }
+ }
+ std::shared_ptr<State> state;
+ std::size_t index;
+ };
+
+ struct OuterCallback {
+ void operator()(const Result<AsyncGenerator<T>>& maybe_next) {
+ bool should_purge = false;
+ bool should_continue = false;
+ Future<T> error_sink;
+ {
+ auto guard = state->mutex.Lock();
+ if (!maybe_next.ok() || IsIterationEnd(*maybe_next)) {
+ state->source_exhausted = true;
+ if (!maybe_next.ok() || --state->num_active_subscriptions == 0) {
+ state->finished = true;
+ should_purge = true;
+ }
+ if (!maybe_next.ok()) {
+ if (state->waiting_jobs.empty()) {
+ state->delivered_jobs.push_back(std::make_shared<DeliveredJob>(
+ AsyncGenerator<T>(), maybe_next.status(), index));
+ } else {
+ error_sink = std::move(*state->waiting_jobs.front());
+ state->waiting_jobs.pop_front();
+ }
+ }
+ } else {
+ state->active_subscriptions[index] = *maybe_next;
+ should_continue = true;
+ }
+ }
+ if (error_sink.is_valid()) {
+ error_sink.MarkFinished(maybe_next.status());
+ }
+ if (should_continue) {
+ (*maybe_next)().AddCallback(InnerCallback{state, index});
+ } else if (should_purge) {
+ // At this point state->finished has been marked true so no one else
+ // will be interacting with waiting_jobs and we can iterate outside lock
+ while (!state->waiting_jobs.empty()) {
+ state->waiting_jobs.front()->MarkFinished(IterationTraits<T>::End());
+ state->waiting_jobs.pop_front();
+ }
+ }
+ }
+ std::shared_ptr<State> state;
+ std::size_t index;
+ };
+
+ std::shared_ptr<State> state_;
+};
+
+/// \brief Create a generator that takes in a stream of generators and pulls from up to
+/// max_subscriptions at a time
+///
+/// Note: This may deliver items out of sequence. For example, items from the third
+/// AsyncGenerator generated by the source may be emitted before some items from the first
+/// AsyncGenerator generated by the source.
+///
+/// This generator will pull from source async-reentrantly unless max_subscriptions is 1
+/// This generator will not pull from the individual subscriptions reentrantly. Add
+/// readahead to the individual subscriptions if that is desired.
+/// This generator is async-reentrant
+///
+/// This generator may queue up to max_subscriptions instances of T
+template <typename T>
+AsyncGenerator<T> MakeMergedGenerator(AsyncGenerator<AsyncGenerator<T>> source,
+ int max_subscriptions) {
+ return MergedGenerator<T>(std::move(source), max_subscriptions);
+}
+
+template <typename T>
+Result<AsyncGenerator<T>> MakeSequencedMergedGenerator(
+ AsyncGenerator<AsyncGenerator<T>> source, int max_subscriptions) {
+ if (max_subscriptions < 0) {
+ return Status::Invalid("max_subscriptions must be a positive integer");
+ }
+ if (max_subscriptions == 1) {
+ return Status::Invalid("Use MakeConcatenatedGenerator if max_subscriptions is 1");
+ }
+ AsyncGenerator<AsyncGenerator<T>> autostarting_source = MakeMappedGenerator(
+ std::move(source),
+ [](const AsyncGenerator<T>& sub) { return MakeAutoStartingGenerator(sub); });
+ AsyncGenerator<AsyncGenerator<T>> sub_readahead =
+ MakeSerialReadaheadGenerator(std::move(autostarting_source), max_subscriptions - 1);
+ return MakeConcatenatedGenerator(std::move(sub_readahead));
+}
+
+/// \brief Create a generator that takes in a stream of generators and pulls from each
+/// one in sequence.
+///
+/// This generator is async-reentrant but will never pull from source reentrantly and
+/// will never pull from any subscription reentrantly.
+///
+/// This generator may queue 1 instance of T
+///
+/// TODO: Could potentially make a bespoke implementation instead of MergedGenerator that
+/// forwards async-reentrant requests instead of buffering them (which is what
+/// MergedGenerator does)
+template <typename T>
+AsyncGenerator<T> MakeConcatenatedGenerator(AsyncGenerator<AsyncGenerator<T>> source) {
+ return MergedGenerator<T>(std::move(source), 1);
+}
+
+template <typename T>
+struct Enumerated {
+ T value;
+ int index;
+ bool last;
+};
+
+template <typename T>
+struct IterationTraits<Enumerated<T>> {
+ static Enumerated<T> End() { return Enumerated<T>{IterationEnd<T>(), -1, false}; }
+ static bool IsEnd(const Enumerated<T>& val) { return val.index < 0; }
+};
+
+/// \see MakeEnumeratedGenerator
+template <typename T>
+class EnumeratingGenerator {
+ public:
+ EnumeratingGenerator(AsyncGenerator<T> source, T initial_value)
+ : state_(std::make_shared<State>(std::move(source), std::move(initial_value))) {}
+
+ Future<Enumerated<T>> operator()() {
+ if (state_->finished) {
+ return AsyncGeneratorEnd<Enumerated<T>>();
+ } else {
+ auto state = state_;
+ return state->source().Then([state](const T& next) {
+ auto finished = IsIterationEnd<T>(next);
+ auto prev = Enumerated<T>{state->prev_value, state->prev_index, finished};
+ state->prev_value = next;
+ state->prev_index++;
+ state->finished = finished;
+ return prev;
+ });
+ }
+ }
+
+ private:
+ struct State {
+ State(AsyncGenerator<T> source, T initial_value)
+ : source(std::move(source)), prev_value(std::move(initial_value)), prev_index(0) {
+ finished = IsIterationEnd<T>(prev_value);
+ }
+
+ AsyncGenerator<T> source;
+ T prev_value;
+ int prev_index;
+ bool finished;
+ };
+
+ std::shared_ptr<State> state_;
+};
+
+/// Wrap items from a source generator with positional information
+///
+/// When used with MakeMergedGenerator and MakeSequencingGenerator this allows items to be
+/// processed in a "first-available" fashion and later resequenced which can reduce the
+/// impact of sources with erratic performance (e.g. a filesystem where some items may
+/// take longer to read than others).
+///
+/// TODO(ARROW-12371) Would require this generator be async-reentrant
+///
+/// \see MakeSequencingGenerator for an example of putting items back in order
+///
+/// This generator is not async-reentrant
+///
+/// This generator buffers one item (so it knows which item is the last item)
+template <typename T>
+AsyncGenerator<Enumerated<T>> MakeEnumeratedGenerator(AsyncGenerator<T> source) {
+ return FutureFirstGenerator<Enumerated<T>>(
+ source().Then([source](const T& initial_value) -> AsyncGenerator<Enumerated<T>> {
+ return EnumeratingGenerator<T>(std::move(source), initial_value);
+ }));
+}
+
+/// \see MakeTransferredGenerator
+template <typename T>
+class TransferringGenerator {
+ public:
+ explicit TransferringGenerator(AsyncGenerator<T> source, internal::Executor* executor)
+ : source_(std::move(source)), executor_(executor) {}
+
+ Future<T> operator()() { return executor_->Transfer(source_()); }
+
+ private:
+ AsyncGenerator<T> source_;
+ internal::Executor* executor_;
+};
+
+/// \brief Transfer a future to an underlying executor.
+///
+/// Continuations run on the returned future will be run on the given executor
+/// if they cannot be run synchronously.
+///
+/// This is often needed to move computation off I/O threads or other external
+/// completion sources and back on to the CPU executor so the I/O thread can
+/// stay busy and focused on I/O
+///
+/// Keep in mind that continuations called on an already completed future will
+/// always be run synchronously and so no transfer will happen in that case.
+///
+/// This generator is async reentrant if the source is
+///
+/// This generator will not queue
+template <typename T>
+AsyncGenerator<T> MakeTransferredGenerator(AsyncGenerator<T> source,
+ internal::Executor* executor) {
+ return TransferringGenerator<T>(std::move(source), executor);
+}
+
+/// \see MakeBackgroundGenerator
+template <typename T>
+class BackgroundGenerator {
+ public:
+ explicit BackgroundGenerator(Iterator<T> it, internal::Executor* io_executor, int max_q,
+ int q_restart)
+ : state_(std::make_shared<State>(io_executor, std::move(it), max_q, q_restart)),
+ cleanup_(std::make_shared<Cleanup>(state_.get())) {}
+
+ Future<T> operator()() {
+ auto guard = state_->mutex.Lock();
+ Future<T> waiting_future;
+ if (state_->queue.empty()) {
+ if (state_->finished) {
+ return AsyncGeneratorEnd<T>();
+ } else {
+ waiting_future = Future<T>::Make();
+ state_->waiting_future = waiting_future;
+ }
+ } else {
+ auto next = Future<T>::MakeFinished(std::move(state_->queue.front()));
+ state_->queue.pop();
+ if (state_->NeedsRestart()) {
+ return state_->RestartTask(state_, std::move(guard), std::move(next));
+ }
+ return next;
+ }
+ // This should only trigger the very first time this method is called
+ if (state_->NeedsRestart()) {
+ return state_->RestartTask(state_, std::move(guard), std::move(waiting_future));
+ }
+ return waiting_future;
+ }
+
+ protected:
+ static constexpr uint64_t kUnlikelyThreadId{std::numeric_limits<uint64_t>::max()};
+
+ struct State {
+ State(internal::Executor* io_executor, Iterator<T> it, int max_q, int q_restart)
+ : io_executor(io_executor),
+ max_q(max_q),
+ q_restart(q_restart),
+ it(std::move(it)),
+ reading(false),
+ finished(false),
+ should_shutdown(false) {}
+
+ void ClearQueue() {
+ while (!queue.empty()) {
+ queue.pop();
+ }
+ }
+
+ bool TaskIsRunning() const { return task_finished.is_valid(); }
+
+ bool NeedsRestart() const {
+ return !finished && !reading && static_cast<int>(queue.size()) <= q_restart;
+ }
+
+ void DoRestartTask(std::shared_ptr<State> state, util::Mutex::Guard guard) {
+ // If we get here we are actually going to start a new task so let's create a
+ // task_finished future for it
+ state->task_finished = Future<>::Make();
+ state->reading = true;
+ auto spawn_status = io_executor->Spawn(
+ [state]() { BackgroundGenerator::WorkerTask(std::move(state)); });
+ if (!spawn_status.ok()) {
+ // If we can't spawn a new task then send an error to the consumer (either via a
+ // waiting future or the queue) and mark ourselves finished
+ state->finished = true;
+ state->task_finished = Future<>();
+ if (waiting_future.has_value()) {
+ auto to_deliver = std::move(waiting_future.value());
+ waiting_future.reset();
+ guard.Unlock();
+ to_deliver.MarkFinished(spawn_status);
+ } else {
+ ClearQueue();
+ queue.push(spawn_status);
+ }
+ }
+ }
+
+ Future<T> RestartTask(std::shared_ptr<State> state, util::Mutex::Guard guard,
+ Future<T> next) {
+ if (TaskIsRunning()) {
+ // If the task is still cleaning up we need to wait for it to finish before
+ // restarting. We also want to block the consumer until we've restarted the
+ // reader to avoid multiple restarts
+ return task_finished.Then([state, next]() {
+ // This may appear dangerous (recursive mutex) but we should be guaranteed the
+ // outer guard has been released by this point. We know...
+ // * task_finished is not already finished (it would be invalid in that case)
+ // * task_finished will not be marked complete until we've given up the mutex
+ auto guard_ = state->mutex.Lock();
+ state->DoRestartTask(state, std::move(guard_));
+ return next;
+ });
+ }
+ // Otherwise we can restart immediately
+ DoRestartTask(std::move(state), std::move(guard));
+ return next;
+ }
+
+ internal::Executor* io_executor;
+ const int max_q;
+ const int q_restart;
+ Iterator<T> it;
+ std::atomic<uint64_t> worker_thread_id{kUnlikelyThreadId};
+
+ // If true, the task is actively pumping items from the queue and does not need a
+ // restart
+ bool reading;
+ // Set to true when a terminal item arrives
+ bool finished;
+ // Signal to the background task to end early because consumers have given up on it
+ bool should_shutdown;
+ // If the queue is empty, the consumer will create a waiting future and wait for it
+ std::queue<Result<T>> queue;
+ util::optional<Future<T>> waiting_future;
+ // Every background task is given a future to complete when it is entirely finished
+ // processing and ready for the next task to start or for State to be destroyed
+ Future<> task_finished;
+ util::Mutex mutex;
+ };
+
+ // Cleanup task that will be run when all consumer references to the generator are lost
+ struct Cleanup {
+ explicit Cleanup(State* state) : state(state) {}
+ ~Cleanup() {
+ /// TODO: Once ARROW-13109 is available then we can be force consumers to spawn and
+ /// there is no need to perform this check.
+ ///
+ /// It's a deadlock if we enter cleanup from
+ /// the worker thread but it can happen if the consumer doesn't transfer away
+ assert(state->worker_thread_id.load() != ::arrow::internal::GetThreadId());
+ Future<> finish_fut;
+ {
+ auto lock = state->mutex.Lock();
+ if (!state->TaskIsRunning()) {
+ return;
+ }
+ // Signal the current task to stop and wait for it to finish
+ state->should_shutdown = true;
+ finish_fut = state->task_finished;
+ }
+ // Using future as a condition variable here
+ Status st = finish_fut.status();
+ ARROW_UNUSED(st);
+ }
+ State* state;
+ };
+
+ static void WorkerTask(std::shared_ptr<State> state) {
+ state->worker_thread_id.store(::arrow::internal::GetThreadId());
+ // We need to capture the state to read while outside the mutex
+ bool reading = true;
+ while (reading) {
+ auto next = state->it.Next();
+ // Need to capture state->waiting_future inside the mutex to mark finished outside
+ Future<T> waiting_future;
+ {
+ auto guard = state->mutex.Lock();
+
+ if (state->should_shutdown) {
+ state->finished = true;
+ break;
+ }
+
+ if (!next.ok() || IsIterationEnd<T>(*next)) {
+ // Terminal item. Mark finished to true, send this last item, and quit
+ state->finished = true;
+ if (!next.ok()) {
+ state->ClearQueue();
+ }
+ }
+ // At this point we are going to send an item. Either we will add it to the
+ // queue or deliver it to a waiting future.
+ if (state->waiting_future.has_value()) {
+ waiting_future = std::move(state->waiting_future.value());
+ state->waiting_future.reset();
+ } else {
+ state->queue.push(std::move(next));
+ // We just filled up the queue so it is time to quit. We may need to notify
+ // a cleanup task so we transition to Quitting
+ if (static_cast<int>(state->queue.size()) >= state->max_q) {
+ state->reading = false;
+ }
+ }
+ reading = state->reading && !state->finished;
+ }
+ // This should happen outside the mutex. Presumably there is a
+ // transferring generator on the other end that will quickly transfer any
+ // callbacks off of this thread so we can continue looping. Still, best not to
+ // rely on that
+ if (waiting_future.is_valid()) {
+ waiting_future.MarkFinished(next);
+ }
+ }
+ // Once we've sent our last item we can notify any waiters that we are done and so
+ // either state can be cleaned up or a new background task can be started
+ Future<> task_finished;
+ {
+ auto guard = state->mutex.Lock();
+ // After we give up the mutex state can be safely deleted. We will no longer
+ // reference it. We can safely transition to idle now.
+ task_finished = state->task_finished;
+ state->task_finished = Future<>();
+ state->worker_thread_id.store(kUnlikelyThreadId);
+ }
+ task_finished.MarkFinished();
+ }
+
+ std::shared_ptr<State> state_;
+ // state_ is held by both the generator and the background thread so it won't be cleaned
+ // up when all consumer references are relinquished. cleanup_ is only held by the
+ // generator so it will be destructed when the last consumer reference is gone. We use
+ // this to cleanup / stop the background generator in case the consuming end stops
+ // listening (e.g. due to a downstream error)
+ std::shared_ptr<Cleanup> cleanup_;
+};
+
+constexpr int kDefaultBackgroundMaxQ = 32;
+constexpr int kDefaultBackgroundQRestart = 16;
+
+/// \brief Create an AsyncGenerator<T> by iterating over an Iterator<T> on a background
+/// thread
+///
+/// The parameter max_q and q_restart control queue size and background thread task
+/// management. If the background task is fast you typically don't want it creating a
+/// thread task for every item. Instead the background thread will run until it fills
+/// up a readahead queue.
+///
+/// Once the queue has filled up the background thread task will terminate (allowing other
+/// I/O tasks to use the thread). Once the queue has been drained enough (specified by
+/// q_restart) then the background thread task will be restarted. If q_restart is too low
+/// then you may exhaust the queue waiting for the background thread task to start running
+/// again. If it is too high then it will be constantly stopping and restarting the
+/// background queue task
+///
+/// The "background thread" is a logical thread and will run as tasks on the io_executor.
+/// This thread may stop and start when the queue fills up but there will only be one
+/// active background thread task at any given time. You MUST transfer away from this
+/// background generator. Otherwise there could be a race condition if a callback on the
+/// background thread deletes the last consumer reference to the background generator. You
+/// can transfer onto the same executor as the background thread, it is only neccesary to
+/// create a new thread task, not to switch executors.
+///
+/// This generator is not async-reentrant
+///
+/// This generator will queue up to max_q blocks
+template <typename T>
+static Result<AsyncGenerator<T>> MakeBackgroundGenerator(
+ Iterator<T> iterator, internal::Executor* io_executor,
+ int max_q = kDefaultBackgroundMaxQ, int q_restart = kDefaultBackgroundQRestart) {
+ if (max_q < q_restart) {
+ return Status::Invalid("max_q must be >= q_restart");
+ }
+ return BackgroundGenerator<T>(std::move(iterator), io_executor, max_q, q_restart);
+}
+
+/// \see MakeGeneratorIterator
+template <typename T>
+class GeneratorIterator {
+ public:
+ explicit GeneratorIterator(AsyncGenerator<T> source) : source_(std::move(source)) {}
+
+ Result<T> Next() { return source_().result(); }
+
+ private:
+ AsyncGenerator<T> source_;
+};
+
+/// \brief Convert an AsyncGenerator<T> to an Iterator<T> which blocks until each future
+/// is finished
+template <typename T>
+Iterator<T> MakeGeneratorIterator(AsyncGenerator<T> source) {
+ return Iterator<T>(GeneratorIterator<T>(std::move(source)));
+}
+
+/// \brief Add readahead to an iterator using a background thread.
+///
+/// Under the hood this is converting the iterator to a generator using
+/// MakeBackgroundGenerator, adding readahead to the converted generator with
+/// MakeReadaheadGenerator, and then converting back to an iterator using
+/// MakeGeneratorIterator.
+template <typename T>
+Result<Iterator<T>> MakeReadaheadIterator(Iterator<T> it, int readahead_queue_size) {
+ ARROW_ASSIGN_OR_RAISE(auto io_executor, internal::ThreadPool::Make(1));
+ auto max_q = readahead_queue_size;
+ auto q_restart = std::max(1, max_q / 2);
+ ARROW_ASSIGN_OR_RAISE(
+ auto background_generator,
+ MakeBackgroundGenerator(std::move(it), io_executor.get(), max_q, q_restart));
+ // Capture io_executor to keep it alive as long as owned_bg_generator is still
+ // referenced
+ AsyncGenerator<T> owned_bg_generator = [io_executor, background_generator]() {
+ return background_generator();
+ };
+ return MakeGeneratorIterator(std::move(owned_bg_generator));
+}
+
+/// \brief Make a generator that returns a single pre-generated future
+///
+/// This generator is async-reentrant.
+template <typename T>
+std::function<Future<T>()> MakeSingleFutureGenerator(Future<T> future) {
+ assert(future.is_valid());
+ auto state = std::make_shared<Future<T>>(std::move(future));
+ return [state]() -> Future<T> {
+ auto fut = std::move(*state);
+ if (fut.is_valid()) {
+ return fut;
+ } else {
+ return AsyncGeneratorEnd<T>();
+ }
+ };
+}
+
+/// \brief Make a generator that immediately ends.
+///
+/// This generator is async-reentrant.
+template <typename T>
+std::function<Future<T>()> MakeEmptyGenerator() {
+ return []() -> Future<T> { return AsyncGeneratorEnd<T>(); };
+}
+
+/// \brief Make a generator that always fails with a given error
+///
+/// This generator is async-reentrant.
+template <typename T>
+AsyncGenerator<T> MakeFailingGenerator(Status st) {
+ assert(!st.ok());
+ auto state = std::make_shared<Status>(std::move(st));
+ return [state]() -> Future<T> {
+ auto st = std::move(*state);
+ if (!st.ok()) {
+ return std::move(st);
+ } else {
+ return AsyncGeneratorEnd<T>();
+ }
+ };
+}
+
+/// \brief Make a generator that always fails with a given error
+///
+/// This overload allows inferring the return type from the argument.
+template <typename T>
+AsyncGenerator<T> MakeFailingGenerator(const Result<T>& result) {
+ return MakeFailingGenerator<T>(result.status());
+}
+
+/// \brief Prepend initial_values onto a generator
+///
+/// This generator is async-reentrant but will buffer requests and will not
+/// pull from following_values async-reentrantly.
+template <typename T>
+AsyncGenerator<T> MakeGeneratorStartsWith(std::vector<T> initial_values,
+ AsyncGenerator<T> following_values) {
+ auto initial_values_vec_gen = MakeVectorGenerator(std::move(initial_values));
+ auto gen_gen = MakeVectorGenerator<AsyncGenerator<T>>(
+ {std::move(initial_values_vec_gen), std::move(following_values)});
+ return MakeConcatenatedGenerator(std::move(gen_gen));
+}
+
+template <typename T>
+struct CancellableGenerator {
+ Future<T> operator()() {
+ if (stop_token.IsStopRequested()) {
+ return stop_token.Poll();
+ }
+ return source();
+ }
+
+ AsyncGenerator<T> source;
+ StopToken stop_token;
+};
+
+/// \brief Allow an async generator to be cancelled
+///
+/// This generator is async-reentrant
+template <typename T>
+AsyncGenerator<T> MakeCancellable(AsyncGenerator<T> source, StopToken stop_token) {
+ return CancellableGenerator<T>{std::move(source), std::move(stop_token)};
+}
+
+template <typename T>
+struct PauseableGenerator {
+ public:
+ PauseableGenerator(AsyncGenerator<T> source, std::shared_ptr<util::AsyncToggle> toggle)
+ : state_(std::make_shared<PauseableGeneratorState>(std::move(source),
+ std::move(toggle))) {}
+
+ Future<T> operator()() { return (*state_)(); }
+
+ private:
+ struct PauseableGeneratorState
+ : public std::enable_shared_from_this<PauseableGeneratorState> {
+ PauseableGeneratorState(AsyncGenerator<T> source,
+ std::shared_ptr<util::AsyncToggle> toggle)
+ : source_(std::move(source)), toggle_(std::move(toggle)) {}
+
+ Future<T> operator()() {
+ std::shared_ptr<PauseableGeneratorState> self = this->shared_from_this();
+ return toggle_->WhenOpen().Then([self] {
+ util::Mutex::Guard guard = self->mutex_.Lock();
+ return self->source_();
+ });
+ }
+
+ AsyncGenerator<T> source_;
+ std::shared_ptr<util::AsyncToggle> toggle_;
+ util::Mutex mutex_;
+ };
+ std::shared_ptr<PauseableGeneratorState> state_;
+};
+
+/// \brief Allow an async generator to be paused
+///
+/// This generator is NOT async-reentrant and calling it in an async-reentrant fashion
+/// may lead to items getting reordered (and potentially truncated if the end token is
+/// reordered ahead of valid items)
+///
+/// This generator forwards async-reentrant pressure
+template <typename T>
+AsyncGenerator<T> MakePauseable(AsyncGenerator<T> source,
+ std::shared_ptr<util::AsyncToggle> toggle) {
+ return PauseableGenerator<T>(std::move(source), std::move(toggle));
+}
+
+template <typename T>
+class DefaultIfEmptyGenerator {
+ public:
+ DefaultIfEmptyGenerator(AsyncGenerator<T> source, T or_value)
+ : state_(std::make_shared<State>(std::move(source), std::move(or_value))) {}
+
+ Future<T> operator()() {
+ if (state_->first) {
+ state_->first = false;
+ struct {
+ T or_value;
+
+ Result<T> operator()(const T& value) {
+ if (IterationTraits<T>::IsEnd(value)) {
+ return std::move(or_value);
+ }
+ return value;
+ }
+ } Continuation;
+ Continuation.or_value = std::move(state_->or_value);
+ return state_->source().Then(std::move(Continuation));
+ }
+ return state_->source();
+ }
+
+ private:
+ struct State {
+ AsyncGenerator<T> source;
+ T or_value;
+ bool first;
+ State(AsyncGenerator<T> source_, T or_value_)
+ : source(std::move(source_)), or_value(std::move(or_value_)), first(true) {}
+ };
+ std::shared_ptr<State> state_;
+};
+
+/// \brief If the generator is empty, return the given value, else
+/// forward the values from the generator.
+///
+/// This generator is async-reentrant.
+template <typename T>
+AsyncGenerator<T> MakeDefaultIfEmptyGenerator(AsyncGenerator<T> source, T or_value) {
+ return DefaultIfEmptyGenerator<T>(std::move(source), std::move(or_value));
+}
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/async_generator_test.cc b/src/arrow/cpp/src/arrow/util/async_generator_test.cc
new file mode 100644
index 000000000..7e5fccd9e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/async_generator_test.cc
@@ -0,0 +1,1842 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <atomic>
+#include <chrono>
+#include <condition_variable>
+#include <mutex>
+#include <random>
+#include <thread>
+#include <unordered_set>
+#include <utility>
+
+#include "arrow/io/slow.h"
+#include "arrow/testing/async_test_util.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/async_util.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/test_common.h"
+#include "arrow/util/vector.h"
+
+namespace arrow {
+
+template <typename T>
+AsyncGenerator<T> AsyncVectorIt(std::vector<T> v) {
+ return MakeVectorGenerator(std::move(v));
+}
+
+template <typename T>
+AsyncGenerator<T> FailsAt(AsyncGenerator<T> src, int failing_index) {
+ auto index = std::make_shared<std::atomic<int>>(0);
+ return [src, index, failing_index]() {
+ auto idx = index->fetch_add(1);
+ if (idx >= failing_index) {
+ return Future<T>::MakeFinished(Status::Invalid("XYZ"));
+ }
+ return src();
+ };
+}
+
+template <typename T>
+AsyncGenerator<T> SlowdownABit(AsyncGenerator<T> source) {
+ return MakeMappedGenerator(std::move(source), [](const T& res) {
+ return SleepABitAsync().Then([res]() { return res; });
+ });
+}
+
+template <typename T>
+AsyncGenerator<T> MakeJittery(AsyncGenerator<T> source) {
+ auto latency_generator = arrow::io::LatencyGenerator::Make(0.01);
+ return MakeMappedGenerator(std::move(source), [latency_generator](const T& res) {
+ auto out = Future<T>::Make();
+ std::thread([out, res, latency_generator]() mutable {
+ latency_generator->Sleep();
+ out.MarkFinished(res);
+ }).detach();
+ return out;
+ });
+}
+
+// Yields items with a small pause between each one from a background thread
+std::function<Future<TestInt>()> BackgroundAsyncVectorIt(
+ std::vector<TestInt> v, bool sleep = true, int max_q = kDefaultBackgroundMaxQ,
+ int q_restart = kDefaultBackgroundQRestart) {
+ auto pool = internal::GetCpuThreadPool();
+ auto slow_iterator = PossiblySlowVectorIt(v, sleep);
+ EXPECT_OK_AND_ASSIGN(
+ auto background,
+ MakeBackgroundGenerator<TestInt>(std::move(slow_iterator), pool, max_q, q_restart));
+ return MakeTransferredGenerator(background, pool);
+}
+
+std::function<Future<TestInt>()> NewBackgroundAsyncVectorIt(std::vector<TestInt> v,
+ bool sleep = true) {
+ auto pool = internal::GetCpuThreadPool();
+ auto iterator = VectorIt(v);
+ auto slow_iterator = MakeTransformedIterator<TestInt, TestInt>(
+ std::move(iterator), [sleep](TestInt item) -> Result<TransformFlow<TestInt>> {
+ if (sleep) {
+ SleepABit();
+ }
+ return TransformYield(item);
+ });
+
+ EXPECT_OK_AND_ASSIGN(auto background,
+ MakeBackgroundGenerator<TestInt>(std::move(slow_iterator), pool));
+ return MakeTransferredGenerator(background, pool);
+}
+
+template <typename T>
+void AssertAsyncGeneratorMatch(std::vector<T> expected, AsyncGenerator<T> actual) {
+ auto vec_future = CollectAsyncGenerator(std::move(actual));
+ EXPECT_OK_AND_ASSIGN(auto vec, vec_future.result());
+ EXPECT_EQ(expected, vec);
+}
+
+template <typename T>
+void AssertGeneratorExhausted(AsyncGenerator<T>& gen) {
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto next, gen());
+ ASSERT_TRUE(IsIterationEnd(next));
+}
+
+// --------------------------------------------------------------------
+// Asynchronous iterator tests
+
+template <typename T>
+class ReentrantCheckerGuard;
+
+template <typename T>
+ReentrantCheckerGuard<T> ExpectNotAccessedReentrantly(AsyncGenerator<T>* generator);
+
+template <typename T>
+class ReentrantChecker {
+ public:
+ Future<T> operator()() {
+ if (state_->generated_unfinished_future.load()) {
+ state_->valid.store(false);
+ }
+ state_->generated_unfinished_future.store(true);
+ auto result = state_->source();
+ return result.Then(Callback{state_});
+ }
+
+ bool valid() { return state_->valid.load(); }
+
+ private:
+ explicit ReentrantChecker(AsyncGenerator<T> source)
+ : state_(std::make_shared<State>(std::move(source))) {}
+
+ friend ReentrantCheckerGuard<T> ExpectNotAccessedReentrantly<T>(
+ AsyncGenerator<T>* generator);
+
+ struct State {
+ explicit State(AsyncGenerator<T> source_)
+ : source(std::move(source_)), generated_unfinished_future(false), valid(true) {}
+
+ AsyncGenerator<T> source;
+ std::atomic<bool> generated_unfinished_future;
+ std::atomic<bool> valid;
+ };
+ struct Callback {
+ Future<T> operator()(const T& result) {
+ state_->generated_unfinished_future.store(false);
+ return result;
+ }
+ std::shared_ptr<State> state_;
+ };
+
+ std::shared_ptr<State> state_;
+};
+
+template <typename T>
+class ReentrantCheckerGuard {
+ public:
+ explicit ReentrantCheckerGuard(ReentrantChecker<T> checker)
+ : checker_(std::move(checker)) {}
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ReentrantCheckerGuard);
+ ReentrantCheckerGuard(ReentrantCheckerGuard&& other) : checker_(other.checker_) {
+ if (other.owner_) {
+ other.owner_ = false;
+ owner_ = true;
+ } else {
+ owner_ = false;
+ }
+ }
+ ReentrantCheckerGuard& operator=(ReentrantCheckerGuard&& other) {
+ checker_ = other.checker_;
+ if (other.owner_) {
+ other.owner_ = false;
+ owner_ = true;
+ } else {
+ owner_ = false;
+ }
+ return *this;
+ }
+
+ ~ReentrantCheckerGuard() {
+ if (owner_ && !checker_.valid()) {
+ ADD_FAILURE() << "A generator was accessed reentrantly when the test asserted it "
+ "should not be.";
+ }
+ }
+
+ private:
+ ReentrantChecker<T> checker_;
+ bool owner_ = true;
+};
+
+template <typename T>
+ReentrantCheckerGuard<T> ExpectNotAccessedReentrantly(AsyncGenerator<T>* generator) {
+ auto reentrant_checker = ReentrantChecker<T>(*generator);
+ *generator = reentrant_checker;
+ return ReentrantCheckerGuard<T>(reentrant_checker);
+}
+
+class GeneratorTestFixture : public ::testing::TestWithParam<bool> {
+ public:
+ ~GeneratorTestFixture() override = default;
+
+ protected:
+ AsyncGenerator<TestInt> MakeSource(const std::vector<TestInt>& items) {
+ std::vector<TestInt> wrapped(items.begin(), items.end());
+ auto gen = AsyncVectorIt(std::move(wrapped));
+ if (IsSlow()) {
+ return SlowdownABit(std::move(gen));
+ }
+ return gen;
+ }
+
+ AsyncGenerator<TestInt> MakeEmptySource() { return MakeSource({}); }
+
+ AsyncGenerator<TestInt> MakeFailingSource() {
+ AsyncGenerator<TestInt> gen = [] {
+ return Future<TestInt>::MakeFinished(Status::Invalid("XYZ"));
+ };
+ if (IsSlow()) {
+ return SlowdownABit(std::move(gen));
+ }
+ return gen;
+ }
+
+ int GetNumItersForStress() {
+ // Run fewer trials for the slow case since they take longer
+ if (IsSlow()) {
+ return 10;
+ } else {
+ return 100;
+ }
+ }
+
+ bool IsSlow() { return GetParam(); }
+};
+
+template <typename T>
+class ManualIteratorControl {
+ public:
+ virtual ~ManualIteratorControl() {}
+ virtual void Push(Result<T> result) = 0;
+ virtual uint32_t times_polled() = 0;
+};
+
+template <typename T>
+class PushIterator : public ManualIteratorControl<T> {
+ public:
+ PushIterator() : state_(std::make_shared<State>()) {}
+ virtual ~PushIterator() {}
+
+ Result<T> Next() {
+ std::unique_lock<std::mutex> lk(state_->mx);
+ state_->times_polled++;
+ if (!state_->cv.wait_for(lk, std::chrono::seconds(300),
+ [&] { return !state_->items.empty(); })) {
+ return Status::Invalid("Timed out waiting for PushIterator");
+ }
+ auto next = std::move(state_->items.front());
+ state_->items.pop();
+ return next;
+ }
+
+ void Push(Result<T> result) override {
+ {
+ std::lock_guard<std::mutex> lg(state_->mx);
+ state_->items.push(std::move(result));
+ }
+ state_->cv.notify_one();
+ }
+
+ uint32_t times_polled() override {
+ std::lock_guard<std::mutex> lg(state_->mx);
+ return state_->times_polled;
+ }
+
+ private:
+ struct State {
+ uint32_t times_polled = 0;
+ std::mutex mx;
+ std::condition_variable cv;
+ std::queue<Result<T>> items;
+ };
+
+ std::shared_ptr<State> state_;
+};
+
+template <typename T>
+Iterator<T> MakePushIterator(std::shared_ptr<ManualIteratorControl<T>>* out) {
+ auto iter = std::make_shared<PushIterator<T>>();
+ *out = iter;
+ return Iterator<T>(*iter);
+}
+
+template <typename T>
+class ManualGenerator {
+ public:
+ ManualGenerator() : times_polled_(std::make_shared<uint32_t>()) {}
+
+ Future<T> operator()() {
+ (*times_polled_)++;
+ return source_();
+ }
+
+ uint32_t times_polled() const { return *times_polled_; }
+ typename PushGenerator<T>::Producer producer() { return source_.producer(); }
+
+ private:
+ PushGenerator<T> source_;
+ std::shared_ptr<uint32_t> times_polled_;
+};
+
+TEST(TestAsyncUtil, Visit) {
+ auto generator = AsyncVectorIt<TestInt>({1, 2, 3});
+ unsigned int sum = 0;
+ auto sum_future = VisitAsyncGenerator<TestInt>(generator, [&sum](TestInt item) {
+ sum += item.value;
+ return Status::OK();
+ });
+ ASSERT_TRUE(sum_future.is_finished());
+ ASSERT_EQ(6, sum);
+}
+
+TEST(TestAsyncUtil, Collect) {
+ std::vector<TestInt> expected = {1, 2, 3};
+ auto generator = AsyncVectorIt(expected);
+ auto collected = CollectAsyncGenerator(generator);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto collected_val, collected);
+ ASSERT_EQ(expected, collected_val);
+}
+
+TEST(TestAsyncUtil, Map) {
+ std::vector<TestInt> input = {1, 2, 3};
+ auto generator = AsyncVectorIt(input);
+ std::function<TestStr(const TestInt&)> mapper = [](const TestInt& in) {
+ return std::to_string(in.value);
+ };
+ auto mapped = MakeMappedGenerator(std::move(generator), mapper);
+ std::vector<TestStr> expected{"1", "2", "3"};
+ AssertAsyncGeneratorMatch(expected, mapped);
+}
+
+TEST(TestAsyncUtil, MapAsync) {
+ std::vector<TestInt> input = {1, 2, 3};
+ auto generator = AsyncVectorIt(input);
+ std::function<Future<TestStr>(const TestInt&)> mapper = [](const TestInt& in) {
+ return SleepAsync(1e-3).Then([in]() { return TestStr(std::to_string(in.value)); });
+ };
+ auto mapped = MakeMappedGenerator(std::move(generator), mapper);
+ std::vector<TestStr> expected{"1", "2", "3"};
+ AssertAsyncGeneratorMatch(expected, mapped);
+}
+
+TEST(TestAsyncUtil, MapReentrant) {
+ std::vector<TestInt> input = {1, 2};
+ auto source = AsyncVectorIt(input);
+ util::TrackingGenerator<TestInt> tracker(std::move(source));
+ source = MakeTransferredGenerator(AsyncGenerator<TestInt>(tracker),
+ internal::GetCpuThreadPool());
+
+ std::atomic<int> map_tasks_running(0);
+ // Mapper blocks until can_proceed is marked finished, should start multiple map tasks
+ Future<> can_proceed = Future<>::Make();
+ std::function<Future<TestStr>(const TestInt&)> mapper = [&](const TestInt& in) {
+ map_tasks_running.fetch_add(1);
+ return can_proceed.Then([in]() { return TestStr(std::to_string(in.value)); });
+ };
+ auto mapped = MakeMappedGenerator(std::move(source), mapper);
+
+ EXPECT_EQ(0, tracker.num_read());
+
+ auto one = mapped();
+ auto two = mapped();
+
+ BusyWait(10, [&] { return map_tasks_running.load() == 2; });
+ EXPECT_EQ(2, map_tasks_running.load());
+ EXPECT_EQ(2, tracker.num_read());
+
+ auto end_one = mapped();
+ auto end_two = mapped();
+
+ can_proceed.MarkFinished();
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto oneval, one);
+ EXPECT_EQ("1", oneval.value);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto twoval, two);
+ EXPECT_EQ("2", twoval.value);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto end, end_one);
+ ASSERT_EQ(IterationTraits<TestStr>::End(), end);
+ ASSERT_FINISHES_OK_AND_ASSIGN(end, end_two);
+ ASSERT_EQ(IterationTraits<TestStr>::End(), end);
+}
+
+TEST(TestAsyncUtil, MapParallelStress) {
+ constexpr int NTASKS = 10;
+ constexpr int NITEMS = 10;
+ for (int i = 0; i < NTASKS; i++) {
+ auto gen = MakeVectorGenerator(RangeVector(NITEMS));
+ gen = SlowdownABit(std::move(gen));
+ auto guard = ExpectNotAccessedReentrantly(&gen);
+ std::function<TestStr(const TestInt&)> mapper = [](const TestInt& in) {
+ SleepABit();
+ return std::to_string(in.value);
+ };
+ auto mapped = MakeMappedGenerator(std::move(gen), mapper);
+ mapped = MakeReadaheadGenerator(mapped, 8);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto collected, CollectAsyncGenerator(mapped));
+ ASSERT_EQ(NITEMS, collected.size());
+ }
+}
+
+TEST(TestAsyncUtil, MapQueuingFailStress) {
+ constexpr int NTASKS = 10;
+ constexpr int NITEMS = 10;
+ for (bool slow : {true, false}) {
+ for (int i = 0; i < NTASKS; i++) {
+ std::shared_ptr<std::atomic<bool>> done = std::make_shared<std::atomic<bool>>();
+ auto inner = AsyncVectorIt(RangeVector(NITEMS));
+ if (slow) inner = MakeJittery(inner);
+ auto gen = FailsAt(inner, NITEMS / 2);
+ std::function<TestStr(const TestInt&)> mapper = [done](const TestInt& in) {
+ if (done->load()) {
+ ADD_FAILURE() << "Callback called after generator sent end signal";
+ }
+ return std::to_string(in.value);
+ };
+ auto mapped = MakeMappedGenerator(std::move(gen), mapper);
+ auto readahead = MakeReadaheadGenerator(std::move(mapped), 8);
+ ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(std::move(readahead)));
+ done->store(true);
+ }
+ }
+}
+
+TEST(TestAsyncUtil, MapTaskFail) {
+ std::vector<TestInt> input = {1, 2, 3};
+ auto generator = AsyncVectorIt(input);
+ std::function<Result<TestStr>(const TestInt&)> mapper =
+ [](const TestInt& in) -> Result<TestStr> {
+ if (in.value == 2) {
+ return Status::Invalid("XYZ");
+ }
+ return TestStr(std::to_string(in.value));
+ };
+ auto mapped = MakeMappedGenerator(std::move(generator), mapper);
+ ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(mapped));
+}
+
+TEST(TestAsyncUtil, MapTaskDelayedFail) {
+ // Regression test for an edge case in MappingGenerator
+ auto push = PushGenerator<TestInt>();
+ auto producer = push.producer();
+ AsyncGenerator<TestInt> generator = push;
+
+ auto delayed = Future<TestStr>::Make();
+ std::function<Future<TestStr>(const TestInt&)> mapper =
+ [=](const TestInt& in) -> Future<TestStr> {
+ if (in.value == 1) return delayed;
+ return TestStr(std::to_string(in.value));
+ };
+ auto mapped = MakeMappedGenerator(std::move(generator), mapper);
+
+ producer.Push(TestInt(1));
+ auto fut = mapped();
+ SleepABit();
+ ASSERT_FALSE(fut.is_finished());
+ // At this point there should be nothing in waiting_jobs, so the
+ // next call will push something to the queue and schedule Callback
+ auto fut2 = mapped();
+ // There's now one job in waiting_jobs. Failing the original task will
+ // purge the queue.
+ delayed.MarkFinished(Status::Invalid("XYZ"));
+ ASSERT_FINISHES_AND_RAISES(Invalid, fut);
+ // However, Callback can still run once we fulfill the remaining
+ // request. Callback needs to see that the generator is finished and
+ // bail out, instead of trying to manipulate waiting_jobs.
+ producer.Push(TestInt(2));
+ ASSERT_FINISHES_OK_AND_EQ(TestStr(), fut2);
+}
+
+TEST(TestAsyncUtil, MapSourceFail) {
+ std::vector<TestInt> input = {1, 2, 3};
+ auto generator = FailsAt(AsyncVectorIt(input), 1);
+ std::function<Result<TestStr>(const TestInt&)> mapper =
+ [](const TestInt& in) -> Result<TestStr> {
+ return TestStr(std::to_string(in.value));
+ };
+ auto mapped = MakeMappedGenerator(std::move(generator), mapper);
+ ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(mapped));
+}
+
+TEST(TestAsyncUtil, Concatenated) {
+ std::vector<TestInt> inputOne{1, 2, 3};
+ std::vector<TestInt> inputTwo{4, 5, 6};
+ std::vector<TestInt> expected{1, 2, 3, 4, 5, 6};
+ auto gen = AsyncVectorIt<AsyncGenerator<TestInt>>(
+ {AsyncVectorIt<TestInt>(inputOne), AsyncVectorIt<TestInt>(inputTwo)});
+ auto concat = MakeConcatenatedGenerator(gen);
+ AssertAsyncGeneratorMatch(expected, concat);
+}
+
+class FromFutureFixture : public GeneratorTestFixture {};
+
+TEST_P(FromFutureFixture, Basic) {
+ auto source = Future<std::vector<TestInt>>::MakeFinished(RangeVector(3));
+ if (IsSlow()) {
+ source = SleepABitAsync().Then(
+ []() -> Result<std::vector<TestInt>> { return RangeVector(3); });
+ }
+ auto slow = IsSlow();
+ auto to_gen = source.Then([slow](const std::vector<TestInt>& vec) {
+ auto vec_gen = MakeVectorGenerator(vec);
+ if (slow) {
+ return SlowdownABit(std::move(vec_gen));
+ }
+ return vec_gen;
+ });
+ auto gen = MakeFromFuture(std::move(to_gen));
+ auto collected = CollectAsyncGenerator(std::move(gen));
+ ASSERT_FINISHES_OK_AND_EQ(RangeVector(3), collected);
+}
+
+INSTANTIATE_TEST_SUITE_P(FromFutureTests, FromFutureFixture,
+ ::testing::Values(false, true));
+
+class MergedGeneratorTestFixture : public GeneratorTestFixture {};
+
+TEST_P(MergedGeneratorTestFixture, Merged) {
+ auto gen = AsyncVectorIt<AsyncGenerator<TestInt>>(
+ {MakeSource({1, 2, 3}), MakeSource({4, 5, 6})});
+
+ auto concat_gen = MakeMergedGenerator(gen, 10);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto concat, CollectAsyncGenerator(concat_gen));
+ auto concat_ints =
+ internal::MapVector([](const TestInt& val) { return val.value; }, concat);
+ std::set<int> concat_set(concat_ints.begin(), concat_ints.end());
+
+ std::set<int> expected{1, 2, 4, 3, 5, 6};
+ ASSERT_EQ(expected, concat_set);
+}
+
+TEST_P(MergedGeneratorTestFixture, MergedInnerFail) {
+ auto gen = AsyncVectorIt<AsyncGenerator<TestInt>>(
+ {MakeSource({1, 2, 3}), MakeFailingSource()});
+ auto merged_gen = MakeMergedGenerator(gen, 10);
+ ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(merged_gen));
+}
+
+TEST_P(MergedGeneratorTestFixture, MergedOuterFail) {
+ auto gen =
+ FailsAt(AsyncVectorIt<AsyncGenerator<TestInt>>(
+ {MakeSource({1, 2, 3}), MakeSource({1, 2, 3}), MakeSource({1, 2, 3})}),
+ 1);
+ auto merged_gen = MakeMergedGenerator(gen, 10);
+ ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(merged_gen));
+}
+
+TEST_P(MergedGeneratorTestFixture, MergedLimitedSubscriptions) {
+ auto gen = AsyncVectorIt<AsyncGenerator<TestInt>>(
+ {MakeSource({1, 2}), MakeSource({3, 4}), MakeSource({5, 6, 7, 8}),
+ MakeSource({9, 10, 11, 12})});
+ util::TrackingGenerator<AsyncGenerator<TestInt>> tracker(std::move(gen));
+ auto merged = MakeMergedGenerator(AsyncGenerator<AsyncGenerator<TestInt>>(tracker), 2);
+
+ SleepABit();
+ // Lazy pull, should not start until first pull
+ ASSERT_EQ(0, tracker.num_read());
+
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto next, merged());
+ ASSERT_TRUE(next.value == 1 || next.value == 3);
+
+ // First 2 values have to come from one of the first 2 sources
+ ASSERT_EQ(2, tracker.num_read());
+ ASSERT_FINISHES_OK_AND_ASSIGN(next, merged());
+ ASSERT_LT(next.value, 5);
+ ASSERT_GT(next.value, 0);
+
+ // By the time five values have been read we should have exhausted at
+ // least one source
+ for (int i = 0; i < 3; i++) {
+ ASSERT_FINISHES_OK_AND_ASSIGN(next, merged());
+ // 9 is possible if we read 1,2,3,4 and then grab 9 while 5 is running slow
+ ASSERT_LT(next.value, 10);
+ ASSERT_GT(next.value, 0);
+ }
+ ASSERT_GT(tracker.num_read(), 2);
+ ASSERT_LT(tracker.num_read(), 5);
+
+ // Read remaining values
+ for (int i = 0; i < 7; i++) {
+ ASSERT_FINISHES_OK_AND_ASSIGN(next, merged());
+ ASSERT_LT(next.value, 13);
+ ASSERT_GT(next.value, 0);
+ }
+
+ AssertGeneratorExhausted(merged);
+}
+
+TEST_P(MergedGeneratorTestFixture, MergedStress) {
+ constexpr int NGENERATORS = 10;
+ constexpr int NITEMS = 10;
+ for (int i = 0; i < GetNumItersForStress(); i++) {
+ std::vector<AsyncGenerator<TestInt>> sources;
+ std::vector<ReentrantCheckerGuard<TestInt>> guards;
+ for (int j = 0; j < NGENERATORS; j++) {
+ auto source = MakeSource(RangeVector(NITEMS));
+ guards.push_back(ExpectNotAccessedReentrantly(&source));
+ sources.push_back(source);
+ }
+ AsyncGenerator<AsyncGenerator<TestInt>> source_gen = AsyncVectorIt(sources);
+ auto outer_gaurd = ExpectNotAccessedReentrantly(&source_gen);
+
+ auto merged = MakeMergedGenerator(source_gen, 4);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto items, CollectAsyncGenerator(merged));
+ ASSERT_EQ(NITEMS * NGENERATORS, items.size());
+ }
+}
+
+TEST_P(MergedGeneratorTestFixture, MergedParallelStress) {
+ constexpr int NGENERATORS = 10;
+ constexpr int NITEMS = 10;
+ for (int i = 0; i < GetNumItersForStress(); i++) {
+ std::vector<AsyncGenerator<TestInt>> sources;
+ for (int j = 0; j < NGENERATORS; j++) {
+ sources.push_back(MakeSource(RangeVector(NITEMS)));
+ }
+ auto merged = MakeMergedGenerator(AsyncVectorIt(sources), 4);
+ merged = MakeReadaheadGenerator(merged, 4);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto items, CollectAsyncGenerator(merged));
+ ASSERT_EQ(NITEMS * NGENERATORS, items.size());
+ }
+}
+
+TEST_P(MergedGeneratorTestFixture, MergedRecursion) {
+ // Regression test for an edge case in MergedGenerator. Ensure if
+ // the source generator returns already-completed futures and there
+ // are many queued pulls (or, the consumer pulls again as part of
+ // the callback), we don't recurse due to AddCallback (leading to an
+ // eventual stack overflow).
+ const int kNumItems = IsSlow() ? 128 : 4096;
+ std::vector<TestInt> items(kNumItems, TestInt(42));
+ auto generator = MakeSource(items);
+ PushGenerator<AsyncGenerator<TestInt>> sources;
+ auto merged = MakeMergedGenerator(AsyncGenerator<AsyncGenerator<TestInt>>(sources), 1);
+ std::vector<Future<TestInt>> pulls;
+ for (int i = 0; i < kNumItems; i++) {
+ pulls.push_back(merged());
+ }
+ sources.producer().Push(generator);
+ sources.producer().Close();
+ for (const auto& fut : pulls) {
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(42), fut);
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(MergedGeneratorTests, MergedGeneratorTestFixture,
+ ::testing::Values(false, true));
+
+class AutoStartingGeneratorTestFixture : public GeneratorTestFixture {};
+
+TEST_P(AutoStartingGeneratorTestFixture, Basic) {
+ AsyncGenerator<TestInt> source = MakeSource({1, 2, 3});
+ util::TrackingGenerator<TestInt> tracked(source);
+ AsyncGenerator<TestInt> gen =
+ MakeAutoStartingGenerator(static_cast<AsyncGenerator<TestInt>>(tracked));
+ ASSERT_EQ(1, tracked.num_read());
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(1), gen());
+ ASSERT_EQ(1, tracked.num_read());
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(2), gen());
+ ASSERT_EQ(2, tracked.num_read());
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(3), gen());
+ ASSERT_EQ(3, tracked.num_read());
+ AssertGeneratorExhausted(gen);
+}
+
+TEST_P(AutoStartingGeneratorTestFixture, CopySafe) {
+ AsyncGenerator<TestInt> source = MakeSource({1, 2, 3});
+ AsyncGenerator<TestInt> gen = MakeAutoStartingGenerator(std::move(source));
+ AsyncGenerator<TestInt> copy = gen;
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(1), gen());
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(2), copy());
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(3), gen());
+ AssertGeneratorExhausted(gen);
+ AssertGeneratorExhausted(copy);
+}
+
+INSTANTIATE_TEST_SUITE_P(AutoStartingGeneratorTests, AutoStartingGeneratorTestFixture,
+ ::testing::Values(false, true));
+
+class SeqMergedGeneratorTestFixture : public ::testing::Test {
+ protected:
+ SeqMergedGeneratorTestFixture() : tracked_source_(push_source_) {}
+
+ void BeginCaptureOutput(AsyncGenerator<TestInt> gen) {
+ finished_ = VisitAsyncGenerator(std::move(gen), [this](TestInt val) {
+ sink_.push_back(val.value);
+ return Status::OK();
+ });
+ }
+
+ void EmitItem(int sub_index, int value) {
+ EXPECT_LT(sub_index, push_subs_.size());
+ push_subs_[sub_index].producer().Push(value);
+ }
+
+ void EmitErrorItem(int sub_index) {
+ EXPECT_LT(sub_index, push_subs_.size());
+ push_subs_[sub_index].producer().Push(Status::Invalid("XYZ"));
+ }
+
+ void EmitSub() {
+ PushGenerator<TestInt> sub;
+ util::TrackingGenerator<TestInt> tracked_sub(sub);
+ tracked_subs_.push_back(tracked_sub);
+ push_subs_.push_back(std::move(sub));
+ push_source_.producer().Push(std::move(tracked_sub));
+ }
+
+ void EmitErrorSub() { push_source_.producer().Push(Status::Invalid("XYZ")); }
+
+ void FinishSub(int sub_index) {
+ EXPECT_LT(sub_index, tracked_subs_.size());
+ push_subs_[sub_index].producer().Close();
+ }
+
+ void FinishSubs() { push_source_.producer().Close(); }
+
+ void AssertFinishedOk() { ASSERT_FINISHES_OK(finished_); }
+
+ void AssertFailed() { ASSERT_FINISHES_AND_RAISES(Invalid, finished_); }
+
+ int NumItemsAskedFor(int sub_index) {
+ EXPECT_LT(sub_index, tracked_subs_.size());
+ return tracked_subs_[sub_index].num_read();
+ }
+
+ int NumSubsAskedFor() { return tracked_source_.num_read(); }
+
+ void AssertRead(std::vector<int> values) {
+ ASSERT_EQ(values.size(), sink_.size());
+ for (std::size_t i = 0; i < sink_.size(); i++) {
+ ASSERT_EQ(values[i], sink_[i]);
+ }
+ }
+
+ PushGenerator<AsyncGenerator<TestInt>> push_source_;
+ std::vector<PushGenerator<TestInt>> push_subs_;
+ std::vector<util::TrackingGenerator<TestInt>> tracked_subs_;
+ util::TrackingGenerator<AsyncGenerator<TestInt>> tracked_source_;
+ Future<> finished_;
+ std::vector<int> sink_;
+};
+
+TEST_F(SeqMergedGeneratorTestFixture, Basic) {
+ ASSERT_OK_AND_ASSIGN(
+ AsyncGenerator<TestInt> gen,
+ MakeSequencedMergedGenerator(
+ static_cast<AsyncGenerator<AsyncGenerator<TestInt>>>(tracked_source_), 4));
+ // Should not initially ask for anything
+ ASSERT_EQ(0, NumSubsAskedFor());
+ BeginCaptureOutput(gen);
+ // Should not read ahead async-reentrantly from source
+ ASSERT_EQ(1, NumSubsAskedFor());
+ EmitSub();
+ ASSERT_EQ(2, NumSubsAskedFor());
+ // Should immediately start polling
+ ASSERT_EQ(1, NumItemsAskedFor(0));
+ EmitSub();
+ EmitSub();
+ EmitSub();
+ EmitSub();
+ // Should limit how many subs it reads ahead
+ ASSERT_EQ(4, NumSubsAskedFor());
+ // Should immediately start polling subs even if they aren't yet active
+ ASSERT_EQ(1, NumItemsAskedFor(1));
+ ASSERT_EQ(1, NumItemsAskedFor(2));
+ ASSERT_EQ(1, NumItemsAskedFor(3));
+ // Items emitted on non-active subs should not be delivered and should not trigger
+ // further polling on the inactive sub
+ EmitItem(1, 0);
+ ASSERT_EQ(1, NumItemsAskedFor(1));
+ AssertRead({});
+ EmitItem(0, 1);
+ AssertRead({1});
+ ASSERT_EQ(2, NumItemsAskedFor(0));
+ EmitItem(0, 2);
+ AssertRead({1, 2});
+ ASSERT_EQ(3, NumItemsAskedFor(0));
+ // On finish it should move to the next sub and pull 1 item
+ FinishSub(0);
+ ASSERT_EQ(5, NumSubsAskedFor());
+ ASSERT_EQ(2, NumItemsAskedFor(1));
+ AssertRead({1, 2, 0});
+ // Now finish all the subs and make sure an empty sub is ok
+ FinishSub(1);
+ FinishSub(2);
+ FinishSub(3);
+ FinishSub(4);
+ ASSERT_EQ(6, NumSubsAskedFor());
+ FinishSubs();
+ AssertFinishedOk();
+}
+
+TEST_F(SeqMergedGeneratorTestFixture, ErrorItem) {
+ ASSERT_OK_AND_ASSIGN(
+ AsyncGenerator<TestInt> gen,
+ MakeSequencedMergedGenerator(
+ static_cast<AsyncGenerator<AsyncGenerator<TestInt>>>(tracked_source_), 4));
+ BeginCaptureOutput(gen);
+ EmitSub();
+ EmitSub();
+ EmitErrorItem(1);
+ // It will still read from the active sub and won't notice the error until it switches
+ // to the failing sub
+ EmitItem(0, 0);
+ AssertRead({0});
+ FinishSub(0);
+ AssertFailed();
+ FinishSub(1);
+ FinishSubs();
+}
+
+TEST_F(SeqMergedGeneratorTestFixture, ErrorSub) {
+ ASSERT_OK_AND_ASSIGN(
+ AsyncGenerator<TestInt> gen,
+ MakeSequencedMergedGenerator(
+ static_cast<AsyncGenerator<AsyncGenerator<TestInt>>>(tracked_source_), 4));
+ BeginCaptureOutput(gen);
+ EmitSub();
+ EmitErrorSub();
+ FinishSub(0);
+ AssertFailed();
+}
+
+TEST(TestAsyncUtil, FromVector) {
+ AsyncGenerator<TestInt> gen;
+ {
+ std::vector<TestInt> input = {1, 2, 3};
+ gen = MakeVectorGenerator(std::move(input));
+ }
+ std::vector<TestInt> expected = {1, 2, 3};
+ AssertAsyncGeneratorMatch(expected, gen);
+}
+
+TEST(TestAsyncUtil, SynchronousFinish) {
+ AsyncGenerator<TestInt> generator = []() {
+ return Future<TestInt>::MakeFinished(IterationTraits<TestInt>::End());
+ };
+ Transformer<TestInt, TestStr> skip_all = [](TestInt value) { return TransformSkip(); };
+ auto transformed = MakeTransformedGenerator(generator, skip_all);
+ auto future = CollectAsyncGenerator(transformed);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto actual, future);
+ ASSERT_EQ(std::vector<TestStr>(), actual);
+}
+
+TEST(TestAsyncUtil, GeneratorIterator) {
+ auto generator = BackgroundAsyncVectorIt({1, 2, 3});
+ auto iterator = MakeGeneratorIterator(std::move(generator));
+ ASSERT_OK_AND_EQ(TestInt(1), iterator.Next());
+ ASSERT_OK_AND_EQ(TestInt(2), iterator.Next());
+ ASSERT_OK_AND_EQ(TestInt(3), iterator.Next());
+ AssertIteratorExhausted(iterator);
+ AssertIteratorExhausted(iterator);
+}
+
+TEST(TestAsyncUtil, MakeTransferredGenerator) {
+ std::mutex mutex;
+ std::condition_variable cv;
+ std::atomic<bool> finished(false);
+
+ ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1));
+
+ // Needs to be a slow source to ensure we don't call Then on a completed
+ AsyncGenerator<TestInt> slow_generator = [&]() {
+ return thread_pool
+ ->Submit([&] {
+ std::unique_lock<std::mutex> lock(mutex);
+ cv.wait_for(lock, std::chrono::duration<double>(30),
+ [&] { return finished.load(); });
+ return IterationTraits<TestInt>::End();
+ })
+ .ValueOrDie();
+ };
+
+ auto transferred =
+ MakeTransferredGenerator<TestInt>(std::move(slow_generator), thread_pool.get());
+
+ auto current_thread_id = std::this_thread::get_id();
+ auto fut = transferred().Then([&current_thread_id](const TestInt&) {
+ ASSERT_NE(current_thread_id, std::this_thread::get_id());
+ });
+
+ {
+ std::lock_guard<std::mutex> lg(mutex);
+ finished.store(true);
+ }
+ cv.notify_one();
+ ASSERT_FINISHES_OK(fut);
+}
+
+// This test is too slow for valgrind
+#if !(defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER))
+
+TEST(TestAsyncUtil, StackOverflow) {
+ int counter = 0;
+ AsyncGenerator<TestInt> generator = [&counter]() {
+ if (counter < 10000) {
+ return Future<TestInt>::MakeFinished(counter++);
+ } else {
+ return Future<TestInt>::MakeFinished(IterationTraits<TestInt>::End());
+ }
+ };
+ Transformer<TestInt, TestStr> discard =
+ [](TestInt next) -> Result<TransformFlow<TestStr>> { return TransformSkip(); };
+ auto transformed = MakeTransformedGenerator(generator, discard);
+ auto collected_future = CollectAsyncGenerator(transformed);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto collected, collected_future);
+ ASSERT_EQ(0, collected.size());
+}
+
+#endif
+
+class BackgroundGeneratorTestFixture : public GeneratorTestFixture {
+ protected:
+ AsyncGenerator<TestInt> Make(const std::vector<TestInt>& it,
+ int max_q = kDefaultBackgroundMaxQ,
+ int q_restart = kDefaultBackgroundQRestart) {
+ bool slow = GetParam();
+ return BackgroundAsyncVectorIt(it, slow, max_q, q_restart);
+ }
+};
+
+TEST_P(BackgroundGeneratorTestFixture, Empty) {
+ auto background = Make({});
+ AssertGeneratorExhausted(background);
+}
+
+TEST_P(BackgroundGeneratorTestFixture, Basic) {
+ std::vector<TestInt> expected = {1, 2, 3};
+ auto background = Make(expected);
+ auto future = CollectAsyncGenerator(background);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto collected, future);
+ ASSERT_EQ(expected, collected);
+}
+
+TEST_P(BackgroundGeneratorTestFixture, BadResult) {
+ std::shared_ptr<ManualIteratorControl<TestInt>> iterator_control;
+ auto iterator = MakePushIterator<TestInt>(&iterator_control);
+ // Enough valid items to fill the queue and then some
+ for (int i = 0; i < 5; i++) {
+ iterator_control->Push(i);
+ }
+ // Next fail
+ iterator_control->Push(Status::Invalid("XYZ"));
+ ASSERT_OK_AND_ASSIGN(
+ auto generator,
+ MakeBackgroundGenerator(std::move(iterator), internal::GetCpuThreadPool(), 4, 2));
+
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(0), generator());
+ // Have not yet restarted so next results should always be valid
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(1), generator());
+ // Next three results may or may not be valid.
+ // The typical case is the call for TestInt(2) restarts a full queue and then maybe
+ // TestInt(3) and TestInt(4) arrive quickly enough to not get pre-empted or maybe
+ // they don't.
+ //
+ // A more bizarre, but possible, case is the checking thread falls behind the producer
+ // thread just so and TestInt(1) arrives and is delivered but before the call for
+ // TestInt(2) happens the background reader reads 2, 3, 4, and 5[err] into the queue so
+ // the queue never fills up and even TestInt(2) is preempted.
+ bool invalid_encountered = false;
+ for (int i = 0; i < 3; i++) {
+ auto next_fut = generator();
+ auto next_result = next_fut.result();
+ if (next_result.ok()) {
+ ASSERT_EQ(TestInt(i + 2), next_result.ValueUnsafe());
+ } else {
+ invalid_encountered = true;
+ break;
+ }
+ }
+ // If both of the next two results are valid then this one will surely be invalid
+ if (!invalid_encountered) {
+ ASSERT_FINISHES_AND_RAISES(Invalid, generator());
+ }
+ AssertGeneratorExhausted(generator);
+}
+
+TEST_P(BackgroundGeneratorTestFixture, InvalidExecutor) {
+ std::vector<TestInt> expected = {1, 2, 3, 4, 5, 6, 7, 8};
+ // Case 1: waiting future
+ auto slow = GetParam();
+ auto it = PossiblySlowVectorIt(expected, slow);
+ ASSERT_OK_AND_ASSIGN(auto invalid_executor, internal::ThreadPool::Make(1));
+ ASSERT_OK(invalid_executor->Shutdown());
+ ASSERT_OK_AND_ASSIGN(auto background, MakeBackgroundGenerator(
+ std::move(it), invalid_executor.get(), 4, 2));
+ ASSERT_FINISHES_AND_RAISES(Invalid, background());
+
+ // Case 2: Queue bad result
+ it = PossiblySlowVectorIt(expected, slow);
+ ASSERT_OK_AND_ASSIGN(invalid_executor, internal::ThreadPool::Make(1));
+ ASSERT_OK_AND_ASSIGN(
+ background, MakeBackgroundGenerator(std::move(it), invalid_executor.get(), 4, 2));
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(1), background());
+ ASSERT_OK(invalid_executor->Shutdown());
+ // Next two are ok because queue is shutdown
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(2), background());
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(3), background());
+ // Now the queue should have tried (and failed) to start back up
+ ASSERT_FINISHES_AND_RAISES(Invalid, background());
+}
+
+TEST_P(BackgroundGeneratorTestFixture, StopAndRestart) {
+ std::shared_ptr<ManualIteratorControl<TestInt>> iterator_control;
+ auto iterator = MakePushIterator<TestInt>(&iterator_control);
+ // Start with 6 items in the source
+ for (int i = 0; i < 6; i++) {
+ iterator_control->Push(i);
+ }
+ iterator_control->Push(IterationEnd<TestInt>());
+
+ ASSERT_OK_AND_ASSIGN(
+ auto generator,
+ MakeBackgroundGenerator(std::move(iterator), internal::GetCpuThreadPool(), 4, 2));
+ SleepABit();
+ // Lazy, should not start until polled once
+ ASSERT_EQ(iterator_control->times_polled(), 0);
+ // First poll should trigger 5 reads (1 for the polled value, 4 for the queue)
+ auto next = generator();
+ BusyWait(10, [&] { return iterator_control->times_polled() >= 5; });
+ // And then stop and not read any more
+ SleepABit();
+ ASSERT_EQ(iterator_control->times_polled(), 5);
+
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(0), next);
+ // One more read should bring q down to 3 and should not restart
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(1), generator());
+ SleepABit();
+ ASSERT_EQ(iterator_control->times_polled(), 5);
+
+ // One more read should bring q down to 2 and that should restart
+ // but it will only read up to 6 because we hit end of stream
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(2), generator());
+ BusyWait(10, [&] { return iterator_control->times_polled() >= 7; });
+ ASSERT_EQ(iterator_control->times_polled(), 7);
+
+ for (int i = 3; i < 6; i++) {
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(i), generator());
+ }
+
+ AssertGeneratorExhausted(generator);
+}
+
+struct TrackingIterator {
+ explicit TrackingIterator(bool slow)
+ : token(std::make_shared<bool>(false)), slow(slow) {}
+
+ Result<TestInt> Next() {
+ if (slow) {
+ SleepABit();
+ }
+ return TestInt(0);
+ }
+ std::weak_ptr<bool> GetWeakTargetRef() { return std::weak_ptr<bool>(token); }
+
+ std::shared_ptr<bool> token;
+ bool slow;
+};
+
+TEST_P(BackgroundGeneratorTestFixture, AbortReading) {
+ // If there is an error downstream then it is likely the chain will abort and the
+ // background generator will lose all references and should abandon reading
+ TrackingIterator source(IsSlow());
+ auto tracker = source.GetWeakTargetRef();
+ auto iter = Iterator<TestInt>(std::move(source));
+ std::shared_ptr<AsyncGenerator<TestInt>> generator;
+ {
+ ASSERT_OK_AND_ASSIGN(
+ auto gen, MakeBackgroundGenerator(std::move(iter), internal::GetCpuThreadPool()));
+ generator = std::make_shared<AsyncGenerator<TestInt>>(gen);
+ }
+
+ // Poll one item to start it up
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(0), (*generator)());
+ ASSERT_FALSE(tracker.expired());
+ // Remove last reference to generator, should trigger and wait for cleanup
+ generator.reset();
+ // Cleanup should have ensured no more reference to the source. It may take a moment
+ // to expire because the background thread has to destruct itself
+ BusyWait(10, [&tracker] { return tracker.expired(); });
+}
+
+TEST_P(BackgroundGeneratorTestFixture, AbortOnIdleBackground) {
+ // Tests what happens when the downstream aborts while the background thread is idle
+ ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1));
+
+ auto source = PossiblySlowVectorIt(RangeVector(100), IsSlow());
+ std::shared_ptr<AsyncGenerator<TestInt>> generator;
+ {
+ ASSERT_OK_AND_ASSIGN(auto gen,
+ MakeBackgroundGenerator(std::move(source), thread_pool.get()));
+ generator = std::make_shared<AsyncGenerator<TestInt>>(gen);
+ }
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(0), (*generator)());
+
+ // The generator should pretty quickly fill up the queue and idle
+ BusyWait(10, [&thread_pool] { return thread_pool->GetNumTasks() == 0; });
+
+ // Now delete the generator and hope we don't deadlock
+ generator.reset();
+}
+
+struct SlowEmptyIterator {
+ Result<TestInt> Next() {
+ if (called_) {
+ return Status::Invalid("Should not have been called twice");
+ }
+ SleepFor(0.1);
+ return IterationTraits<TestInt>::End();
+ }
+
+ private:
+ bool called_ = false;
+};
+
+TEST_P(BackgroundGeneratorTestFixture, BackgroundRepeatEnd) {
+ // Ensure that the background generator properly fulfills the asyncgenerator contract
+ // and can be called after it ends.
+ ASSERT_OK_AND_ASSIGN(auto io_pool, internal::ThreadPool::Make(1));
+
+ bool slow = GetParam();
+ Iterator<TestInt> iterator;
+ if (slow) {
+ iterator = Iterator<TestInt>(SlowEmptyIterator());
+ } else {
+ iterator = MakeEmptyIterator<TestInt>();
+ }
+ ASSERT_OK_AND_ASSIGN(auto background_gen,
+ MakeBackgroundGenerator(std::move(iterator), io_pool.get()));
+
+ background_gen =
+ MakeTransferredGenerator(std::move(background_gen), internal::GetCpuThreadPool());
+
+ auto one = background_gen();
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto one_fin, one);
+ ASSERT_TRUE(IsIterationEnd(one_fin));
+
+ auto two = background_gen();
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto two_fin, two);
+ ASSERT_TRUE(IsIterationEnd(two_fin));
+}
+
+TEST_P(BackgroundGeneratorTestFixture, Stress) {
+ constexpr int NTASKS = 20;
+ constexpr int NITEMS = 20;
+ auto expected = RangeVector(NITEMS);
+ std::vector<Future<std::vector<TestInt>>> futures;
+ for (unsigned int i = 0; i < NTASKS; i++) {
+ auto background = Make(expected, /*max_q=*/4, /*q_restart=*/2);
+ futures.push_back(CollectAsyncGenerator(background));
+ }
+ auto combined = All(futures);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto completed_vectors, combined);
+ for (std::size_t i = 0; i < completed_vectors.size(); i++) {
+ ASSERT_OK_AND_ASSIGN(auto vector, completed_vectors[i]);
+ ASSERT_EQ(vector, expected);
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(BackgroundGeneratorTests, BackgroundGeneratorTestFixture,
+ ::testing::Values(false, true));
+
+TEST(TestAsyncUtil, SerialReadaheadSlowProducer) {
+ AsyncGenerator<TestInt> gen = BackgroundAsyncVectorIt({1, 2, 3, 4, 5});
+ auto guard = ExpectNotAccessedReentrantly(&gen);
+ SerialReadaheadGenerator<TestInt> serial_readahead(gen, 2);
+ AssertAsyncGeneratorMatch({1, 2, 3, 4, 5},
+ static_cast<AsyncGenerator<TestInt>>(serial_readahead));
+}
+
+TEST(TestAsyncUtil, SerialReadaheadSlowConsumer) {
+ int num_delivered = 0;
+ auto source = [&num_delivered]() {
+ if (num_delivered < 5) {
+ return Future<TestInt>::MakeFinished(num_delivered++);
+ } else {
+ return Future<TestInt>::MakeFinished(IterationTraits<TestInt>::End());
+ }
+ };
+ AsyncGenerator<TestInt> serial_readahead = SerialReadaheadGenerator<TestInt>(source, 3);
+ SleepABit();
+ ASSERT_EQ(0, num_delivered);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto next, serial_readahead());
+ ASSERT_EQ(0, next.value);
+ ASSERT_EQ(4, num_delivered);
+ AssertAsyncGeneratorMatch({1, 2, 3, 4}, serial_readahead);
+
+ // Ensure still reads ahead with just 1 slot
+ num_delivered = 0;
+ serial_readahead = SerialReadaheadGenerator<TestInt>(source, 1);
+ ASSERT_FINISHES_OK_AND_ASSIGN(next, serial_readahead());
+ ASSERT_EQ(0, next.value);
+ ASSERT_EQ(2, num_delivered);
+ AssertAsyncGeneratorMatch({1, 2, 3, 4}, serial_readahead);
+}
+
+TEST(TestAsyncUtil, SerialReadaheadStress) {
+ constexpr int NTASKS = 20;
+ constexpr int NITEMS = 50;
+ for (int i = 0; i < NTASKS; i++) {
+ AsyncGenerator<TestInt> gen = BackgroundAsyncVectorIt(RangeVector(NITEMS));
+ auto guard = ExpectNotAccessedReentrantly(&gen);
+ SerialReadaheadGenerator<TestInt> serial_readahead(gen, 2);
+ auto visit_fut =
+ VisitAsyncGenerator<TestInt>(serial_readahead, [](TestInt test_int) -> Status {
+ // Normally sleeping in a visit function would be a faux-pas but we want to slow
+ // the reader down to match the producer to maximize the stress
+ SleepABit();
+ return Status::OK();
+ });
+ ASSERT_FINISHES_OK(visit_fut);
+ }
+}
+
+TEST(TestAsyncUtil, SerialReadaheadStressFast) {
+ constexpr int NTASKS = 20;
+ constexpr int NITEMS = 50;
+ for (int i = 0; i < NTASKS; i++) {
+ AsyncGenerator<TestInt> gen = BackgroundAsyncVectorIt(RangeVector(NITEMS), false);
+ auto guard = ExpectNotAccessedReentrantly(&gen);
+ SerialReadaheadGenerator<TestInt> serial_readahead(gen, 2);
+ auto visit_fut = VisitAsyncGenerator<TestInt>(
+ serial_readahead, [](TestInt test_int) -> Status { return Status::OK(); });
+ ASSERT_FINISHES_OK(visit_fut);
+ }
+}
+
+TEST(TestAsyncUtil, SerialReadaheadStressFailing) {
+ constexpr int NTASKS = 20;
+ constexpr int NITEMS = 50;
+ constexpr int EXPECTED_SUM = 45;
+ for (int i = 0; i < NTASKS; i++) {
+ AsyncGenerator<TestInt> it = BackgroundAsyncVectorIt(RangeVector(NITEMS));
+ AsyncGenerator<TestInt> fails_at_ten = [&it]() {
+ auto next = it();
+ return next.Then([](const TestInt& item) -> Result<TestInt> {
+ if (item.value >= 10) {
+ return Status::Invalid("XYZ");
+ } else {
+ return item;
+ }
+ });
+ };
+ SerialReadaheadGenerator<TestInt> serial_readahead(fails_at_ten, 2);
+ unsigned int sum = 0;
+ auto visit_fut = VisitAsyncGenerator<TestInt>(serial_readahead,
+ [&sum](TestInt test_int) -> Status {
+ sum += test_int.value;
+ // Sleep to maximize stress
+ SleepABit();
+ return Status::OK();
+ });
+ ASSERT_FINISHES_AND_RAISES(Invalid, visit_fut);
+ ASSERT_EQ(EXPECTED_SUM, sum);
+ }
+}
+
+TEST(TestAsyncUtil, Readahead) {
+ int num_delivered = 0;
+ auto source = [&num_delivered]() {
+ if (num_delivered < 5) {
+ return Future<TestInt>::MakeFinished(num_delivered++);
+ } else {
+ return Future<TestInt>::MakeFinished(IterationTraits<TestInt>::End());
+ }
+ };
+ auto readahead = MakeReadaheadGenerator<TestInt>(source, 10);
+ // Should not pump until first item requested
+ ASSERT_EQ(0, num_delivered);
+
+ auto first = readahead();
+ // At this point the pumping should have happened
+ ASSERT_EQ(5, num_delivered);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto first_val, first);
+ ASSERT_EQ(TestInt(0), first_val);
+
+ // Read the rest
+ for (int i = 0; i < 4; i++) {
+ auto next = readahead();
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto next_val, next);
+ ASSERT_EQ(TestInt(i + 1), next_val);
+ }
+
+ // Next should be end
+ auto last = readahead();
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto last_val, last);
+ ASSERT_TRUE(IsIterationEnd(last_val));
+}
+
+TEST(TestAsyncUtil, ReadaheadCopy) {
+ auto source = AsyncVectorIt<TestInt>(RangeVector(6));
+ auto gen = MakeReadaheadGenerator(std::move(source), 2);
+
+ for (int i = 0; i < 2; i++) {
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(i), gen());
+ }
+ auto gen_copy = gen;
+ for (int i = 0; i < 2; i++) {
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(i + 2), gen_copy());
+ }
+ for (int i = 0; i < 2; i++) {
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(i + 4), gen());
+ }
+ AssertGeneratorExhausted(gen);
+ AssertGeneratorExhausted(gen_copy);
+}
+
+TEST(TestAsyncUtil, ReadaheadMove) {
+ auto source = AsyncVectorIt<TestInt>(RangeVector(6));
+ auto gen = MakeReadaheadGenerator(std::move(source), 2);
+
+ for (int i = 0; i < 2; i++) {
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(i), gen());
+ }
+ auto gen_copy = std::move(gen);
+ for (int i = 0; i < 4; i++) {
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(i + 2), gen_copy());
+ }
+ AssertGeneratorExhausted(gen_copy);
+}
+
+TEST(TestAsyncUtil, ReadaheadFailed) {
+ ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(20));
+ std::atomic<int32_t> counter(0);
+ auto gating_task = GatingTask::Make();
+ // All tasks are a little slow. The first task fails.
+ // The readahead will have spawned 9 more tasks and they
+ // should all pass
+ auto source = [&]() -> Future<TestInt> {
+ auto count = counter++;
+ return DeferNotOk(thread_pool->Submit([&, count]() -> Result<TestInt> {
+ gating_task->Task()();
+ if (count == 0) {
+ return Status::Invalid("X");
+ }
+ return TestInt(count);
+ }));
+ };
+ auto readahead = MakeReadaheadGenerator<TestInt>(source, 10);
+ auto should_be_invalid = readahead();
+ // Polling once should allow 10 additional calls to start
+ ASSERT_OK(gating_task->WaitForRunning(11));
+ ASSERT_OK(gating_task->Unlock());
+
+ // Once unlocked the error task should always be the first. Some number of successful
+ // tasks may follow until the end.
+ ASSERT_FINISHES_AND_RAISES(Invalid, should_be_invalid);
+
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto remaining_results, CollectAsyncGenerator(readahead));
+ // Don't need to know the exact number of successful tasks (and it may vary)
+ for (std::size_t i = 0; i < remaining_results.size(); i++) {
+ ASSERT_EQ(TestInt(static_cast<int>(i) + 1), remaining_results[i]);
+ }
+}
+
+class EnumeratorTestFixture : public GeneratorTestFixture {
+ protected:
+ void AssertEnumeratedCorrectly(AsyncGenerator<Enumerated<TestInt>>& gen,
+ int num_items) {
+ auto collected = CollectAsyncGenerator(gen);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto items, collected);
+ EXPECT_EQ(num_items, items.size());
+
+ for (const auto& item : items) {
+ ASSERT_EQ(item.index, item.value.value);
+ bool last = item.index == num_items - 1;
+ ASSERT_EQ(last, item.last);
+ }
+ AssertGeneratorExhausted(gen);
+ }
+};
+
+TEST_P(EnumeratorTestFixture, Basic) {
+ constexpr int NITEMS = 100;
+
+ auto source = MakeSource(RangeVector(NITEMS));
+ auto enumerated = MakeEnumeratedGenerator(std::move(source));
+
+ AssertEnumeratedCorrectly(enumerated, NITEMS);
+}
+
+TEST_P(EnumeratorTestFixture, Empty) {
+ auto source = MakeEmptySource();
+ auto enumerated = MakeEnumeratedGenerator(std::move(source));
+ AssertGeneratorExhausted(enumerated);
+}
+
+TEST_P(EnumeratorTestFixture, Error) {
+ auto source = FailsAt(MakeSource({1, 2, 3}), 1);
+ auto enumerated = MakeEnumeratedGenerator(std::move(source));
+
+ // Even though the first item finishes ok the enumerator buffers it. The error then
+ // takes priority over the buffered result.
+ ASSERT_FINISHES_AND_RAISES(Invalid, enumerated());
+}
+
+INSTANTIATE_TEST_SUITE_P(EnumeratedTests, EnumeratorTestFixture,
+ ::testing::Values(false, true));
+
+class PauseableTestFixture : public GeneratorTestFixture {
+ public:
+ ~PauseableTestFixture() override { generator_.producer().Close(); }
+
+ protected:
+ PauseableTestFixture() : toggle_(std::make_shared<util::AsyncToggle>()) {
+ sink_.clear();
+ counter_ = 0;
+ AsyncGenerator<TestInt> source = GetSource();
+ AsyncGenerator<TestInt> pauseable = MakePauseable(std::move(source), toggle_);
+ finished_ = VisitAsyncGenerator(std::move(pauseable), [this](TestInt val) {
+ std::lock_guard<std::mutex> lg(mutex_);
+ sink_.push_back(val.value);
+ return Status::OK();
+ });
+ }
+
+ void Emit() { generator_.producer().Push(counter_++); }
+
+ void Pause() { toggle_->Close(); }
+
+ void Resume() { toggle_->Open(); }
+
+ int NumCollected() {
+ std::lock_guard<std::mutex> lg(mutex_);
+ // The push generator can desequence things so we check and don't count gaps. It's
+ // a bit inefficient but good enough for this test
+ int count = 0;
+ for (std::size_t i = 0; i < sink_.size(); i++) {
+ int prev_count = count;
+ for (std::size_t j = 0; j < sink_.size(); j++) {
+ if (sink_[j] == count) {
+ count++;
+ break;
+ }
+ }
+ if (prev_count == count) {
+ break;
+ }
+ }
+ return count;
+ }
+
+ void AssertAtLeastNCollected(int target_count) {
+ BusyWait(10, [this, target_count] { return NumCollected() >= target_count; });
+ ASSERT_GE(NumCollected(), target_count);
+ }
+
+ void AssertNoMoreThanNCollected(int target_count) {
+ ASSERT_LE(NumCollected(), target_count);
+ }
+
+ AsyncGenerator<TestInt> GetSource() {
+ const auto& source = static_cast<AsyncGenerator<TestInt>>(generator_);
+ if (IsSlow()) {
+ return SlowdownABit(source);
+ } else {
+ return source;
+ }
+ }
+
+ std::mutex mutex_;
+ int counter_ = 0;
+ PushGenerator<TestInt> generator_;
+ std::shared_ptr<util::AsyncToggle> toggle_;
+ std::vector<int> sink_;
+ Future<> finished_;
+};
+
+INSTANTIATE_TEST_SUITE_P(PauseableTests, PauseableTestFixture,
+ ::testing::Values(false, true));
+
+TEST_P(PauseableTestFixture, PauseBasic) {
+ Emit();
+ Pause();
+ // This emit was asked for before the pause so it will go through
+ Emit();
+ AssertNoMoreThanNCollected(2);
+ // This emit should be blocked by the pause
+ Emit();
+ AssertNoMoreThanNCollected(2);
+ Resume();
+ AssertAtLeastNCollected(3);
+}
+
+class SequencerTestFixture : public GeneratorTestFixture {
+ protected:
+ void RandomShuffle(std::vector<TestInt>& values) {
+ std::default_random_engine gen(seed_++);
+ std::shuffle(values.begin(), values.end(), gen);
+ }
+
+ int seed_ = 42;
+ std::function<bool(const TestInt&, const TestInt&)> cmp_ =
+ [](const TestInt& left, const TestInt& right) { return left.value > right.value; };
+ // Let's increment by 2's to make it interesting
+ std::function<bool(const TestInt&, const TestInt&)> is_next_ =
+ [](const TestInt& left, const TestInt& right) {
+ return left.value + 2 == right.value;
+ };
+};
+
+TEST_P(SequencerTestFixture, SequenceBasic) {
+ // Basic sequencing
+ auto original = MakeSource({6, 4, 2});
+ auto sequenced = MakeSequencingGenerator(original, cmp_, is_next_, TestInt(0));
+ AssertAsyncGeneratorMatch({2, 4, 6}, sequenced);
+
+ // From ordered input
+ original = MakeSource({2, 4, 6});
+ sequenced = MakeSequencingGenerator(original, cmp_, is_next_, TestInt(0));
+ AssertAsyncGeneratorMatch({2, 4, 6}, sequenced);
+}
+
+TEST_P(SequencerTestFixture, SequenceLambda) {
+ auto cmp = [](const TestInt& left, const TestInt& right) {
+ return left.value > right.value;
+ };
+ auto is_next = [](const TestInt& left, const TestInt& right) {
+ return left.value + 2 == right.value;
+ };
+ // Basic sequencing
+ auto original = MakeSource({6, 4, 2});
+ auto sequenced = MakeSequencingGenerator(original, cmp, is_next, TestInt(0));
+ AssertAsyncGeneratorMatch({2, 4, 6}, sequenced);
+}
+
+TEST_P(SequencerTestFixture, SequenceError) {
+ {
+ auto original = MakeSource({6, 4, 2});
+ original = FailsAt(original, 1);
+ auto sequenced = MakeSequencingGenerator(original, cmp_, is_next_, TestInt(0));
+ auto collected = CollectAsyncGenerator(sequenced);
+ ASSERT_FINISHES_AND_RAISES(Invalid, collected);
+ }
+ {
+ // Failure should clear old items out of the queue immediately
+ // shared_ptr versions of cmp_ and is_next_
+ auto cmp = cmp_;
+ std::function<bool(const std::shared_ptr<TestInt>&, const std::shared_ptr<TestInt>&)>
+ ptr_cmp =
+ [cmp](const std::shared_ptr<TestInt>& left,
+ const std::shared_ptr<TestInt>& right) { return cmp(*left, *right); };
+ auto is_next = is_next_;
+ std::function<bool(const std::shared_ptr<TestInt>&, const std::shared_ptr<TestInt>&)>
+ ptr_is_next = [is_next](const std::shared_ptr<TestInt>& left,
+ const std::shared_ptr<TestInt>& right) {
+ return is_next(*left, *right);
+ };
+
+ PushGenerator<std::shared_ptr<TestInt>> source;
+ auto sequenced = MakeSequencingGenerator(
+ static_cast<AsyncGenerator<std::shared_ptr<TestInt>>>(source), ptr_cmp,
+ ptr_is_next, std::make_shared<TestInt>(0));
+
+ auto should_be_cleared = std::make_shared<TestInt>(4);
+ std::weak_ptr<TestInt> ref = should_be_cleared;
+ auto producer = source.producer();
+ auto next_fut = sequenced();
+ producer.Push(std::move(should_be_cleared));
+ producer.Push(Status::Invalid("XYZ"));
+ ASSERT_TRUE(ref.expired());
+
+ ASSERT_FINISHES_AND_RAISES(Invalid, next_fut);
+ }
+ {
+ // Failure should interrupt pumping
+ PushGenerator<TestInt> source;
+ auto sequenced = MakeSequencingGenerator(static_cast<AsyncGenerator<TestInt>>(source),
+ cmp_, is_next_, TestInt(0));
+
+ auto producer = source.producer();
+ auto next_fut = sequenced();
+ producer.Push(TestInt(4));
+ producer.Push(Status::Invalid("XYZ"));
+ producer.Push(TestInt(2));
+ ASSERT_FINISHES_AND_RAISES(Invalid, next_fut);
+ // The sequencer should not have pulled the 2 out of the source because it should
+ // have stopped pumping on error
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(2), source());
+ }
+}
+
+TEST_P(SequencerTestFixture, Readahead) {
+ AsyncGenerator<TestInt> original = MakeSource({4, 2, 0, 6});
+ util::TrackingGenerator<TestInt> tracker(original);
+ AsyncGenerator<TestInt> sequenced = MakeSequencingGenerator(
+ static_cast<AsyncGenerator<TestInt>>(tracker), cmp_, is_next_, TestInt(-2));
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(0), sequenced());
+ ASSERT_EQ(3, tracker.num_read());
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(2), sequenced());
+ ASSERT_EQ(3, tracker.num_read());
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(4), sequenced());
+ ASSERT_EQ(3, tracker.num_read());
+ ASSERT_FINISHES_OK_AND_EQ(TestInt(6), sequenced());
+ ASSERT_EQ(4, tracker.num_read());
+}
+
+TEST_P(SequencerTestFixture, SequenceStress) {
+ constexpr int NITEMS = 100;
+ for (auto task_index = 0; task_index < GetNumItersForStress(); task_index++) {
+ auto input = RangeVector(NITEMS, 2);
+ RandomShuffle(input);
+ auto original = MakeSource(input);
+ auto sequenced = MakeSequencingGenerator(original, cmp_, is_next_, TestInt(-2));
+ AssertAsyncGeneratorMatch(RangeVector(NITEMS, 2), sequenced);
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(SequencerTests, SequencerTestFixture,
+ ::testing::Values(false, true));
+
+TEST(TestAsyncIteratorTransform, SkipSome) {
+ auto original = AsyncVectorIt<TestInt>({1, 2, 3});
+ auto filter = MakeFilter([](TestInt& t) { return t.value != 2; });
+ auto filtered = MakeTransformedGenerator(std::move(original), filter);
+ AssertAsyncGeneratorMatch({"1", "3"}, std::move(filtered));
+}
+
+TEST(PushGenerator, Empty) {
+ PushGenerator<TestInt> gen;
+ auto producer = gen.producer();
+
+ auto fut = gen();
+ AssertNotFinished(fut);
+ ASSERT_FALSE(producer.is_closed());
+ ASSERT_TRUE(producer.Close());
+ ASSERT_TRUE(producer.is_closed());
+ ASSERT_FINISHES_OK_AND_EQ(IterationTraits<TestInt>::End(), fut);
+ ASSERT_FINISHES_OK_AND_EQ(IterationTraits<TestInt>::End(), gen());
+
+ // Close idempotent
+ fut = gen();
+ ASSERT_FALSE(producer.Close());
+ ASSERT_TRUE(producer.is_closed());
+ ASSERT_FINISHES_OK_AND_EQ(IterationTraits<TestInt>::End(), fut);
+ ASSERT_FINISHES_OK_AND_EQ(IterationTraits<TestInt>::End(), gen());
+ ASSERT_FINISHES_OK_AND_EQ(IterationTraits<TestInt>::End(), gen());
+}
+
+TEST(PushGenerator, Success) {
+ PushGenerator<TestInt> gen;
+ auto producer = gen.producer();
+ std::vector<Future<TestInt>> futures;
+
+ ASSERT_TRUE(producer.Push(TestInt{1}));
+ ASSERT_TRUE(producer.Push(TestInt{2}));
+ for (int i = 0; i < 3; ++i) {
+ futures.push_back(gen());
+ }
+ ASSERT_FINISHES_OK_AND_EQ(TestInt{1}, futures[0]);
+ ASSERT_FINISHES_OK_AND_EQ(TestInt{2}, futures[1]);
+ AssertNotFinished(futures[2]);
+
+ ASSERT_TRUE(producer.Push(TestInt{3}));
+ ASSERT_FINISHES_OK_AND_EQ(TestInt{3}, futures[2]);
+ ASSERT_TRUE(producer.Push(TestInt{4}));
+ futures.push_back(gen());
+ ASSERT_FINISHES_OK_AND_EQ(TestInt{4}, futures[3]);
+ ASSERT_TRUE(producer.Push(TestInt{5}));
+
+ ASSERT_FALSE(producer.is_closed());
+ ASSERT_TRUE(producer.Close());
+ ASSERT_TRUE(producer.is_closed());
+ for (int i = 0; i < 4; ++i) {
+ futures.push_back(gen());
+ }
+ ASSERT_FINISHES_OK_AND_EQ(TestInt{5}, futures[4]);
+ for (int i = 5; i < 8; ++i) {
+ ASSERT_FINISHES_OK_AND_EQ(IterationTraits<TestInt>::End(), futures[i]);
+ }
+ ASSERT_FINISHES_OK_AND_EQ(IterationTraits<TestInt>::End(), gen());
+}
+
+TEST(PushGenerator, Errors) {
+ PushGenerator<TestInt> gen;
+ auto producer = gen.producer();
+ std::vector<Future<TestInt>> futures;
+
+ ASSERT_TRUE(producer.Push(TestInt{1}));
+ ASSERT_TRUE(producer.Push(Status::Invalid("2")));
+ for (int i = 0; i < 3; ++i) {
+ futures.push_back(gen());
+ }
+ ASSERT_FINISHES_OK_AND_EQ(TestInt{1}, futures[0]);
+ ASSERT_FINISHES_AND_RAISES(Invalid, futures[1]);
+ AssertNotFinished(futures[2]);
+
+ ASSERT_TRUE(producer.Push(Status::IOError("3")));
+ ASSERT_TRUE(producer.Push(TestInt{4}));
+ ASSERT_FINISHES_AND_RAISES(IOError, futures[2]);
+ futures.push_back(gen());
+ ASSERT_FINISHES_OK_AND_EQ(TestInt{4}, futures[3]);
+
+ ASSERT_FALSE(producer.is_closed());
+ ASSERT_TRUE(producer.Close());
+ ASSERT_TRUE(producer.is_closed());
+ ASSERT_FINISHES_OK_AND_EQ(IterationTraits<TestInt>::End(), gen());
+}
+
+TEST(PushGenerator, CloseEarly) {
+ PushGenerator<TestInt> gen;
+ auto producer = gen.producer();
+ std::vector<Future<TestInt>> futures;
+
+ ASSERT_TRUE(producer.Push(TestInt{1}));
+ ASSERT_TRUE(producer.Push(TestInt{2}));
+ for (int i = 0; i < 3; ++i) {
+ futures.push_back(gen());
+ }
+ ASSERT_FALSE(producer.is_closed());
+ ASSERT_TRUE(producer.Close());
+ ASSERT_TRUE(producer.is_closed());
+ ASSERT_FALSE(producer.Push(TestInt{3}));
+ ASSERT_FALSE(producer.Close());
+ ASSERT_TRUE(producer.is_closed());
+
+ ASSERT_FINISHES_OK_AND_EQ(TestInt{1}, futures[0]);
+ ASSERT_FINISHES_OK_AND_EQ(TestInt{2}, futures[1]);
+ ASSERT_FINISHES_OK_AND_EQ(IterationTraits<TestInt>::End(), futures[2]);
+ ASSERT_FINISHES_OK_AND_EQ(IterationTraits<TestInt>::End(), gen());
+}
+
+TEST(PushGenerator, DanglingProducer) {
+ util::optional<PushGenerator<TestInt>> gen;
+ gen.emplace();
+ auto producer = gen->producer();
+
+ ASSERT_TRUE(producer.Push(TestInt{1}));
+ ASSERT_FALSE(producer.is_closed());
+ gen.reset();
+ ASSERT_TRUE(producer.is_closed());
+ ASSERT_FALSE(producer.Push(TestInt{2}));
+ ASSERT_FALSE(producer.Close());
+}
+
+TEST(PushGenerator, Stress) {
+ const int NTHREADS = 20;
+ const int NVALUES = 2000;
+ const int NFUTURES = NVALUES + 100;
+
+ PushGenerator<TestInt> gen;
+ auto producer = gen.producer();
+
+ std::atomic<int> next_value{0};
+
+ auto producer_worker = [&]() {
+ while (true) {
+ int v = next_value.fetch_add(1);
+ if (v >= NVALUES) {
+ break;
+ }
+ producer.Push(v);
+ }
+ };
+
+ auto producer_main = [&]() {
+ std::vector<std::thread> threads;
+ for (int i = 0; i < NTHREADS; ++i) {
+ threads.emplace_back(producer_worker);
+ }
+ for (auto& thread : threads) {
+ thread.join();
+ }
+ producer.Close();
+ };
+
+ std::vector<Result<TestInt>> results;
+ std::thread thread(producer_main);
+ for (int i = 0; i < NFUTURES; ++i) {
+ results.push_back(gen().result());
+ }
+ thread.join();
+
+ std::unordered_set<int> seen_values;
+ for (int i = 0; i < NVALUES; ++i) {
+ ASSERT_OK_AND_ASSIGN(auto v, results[i]);
+ ASSERT_EQ(seen_values.count(v.value), 0);
+ seen_values.insert(v.value);
+ }
+ for (int i = NVALUES; i < NFUTURES; ++i) {
+ ASSERT_OK_AND_EQ(IterationTraits<TestInt>::End(), results[i]);
+ }
+}
+
+TEST(SingleFutureGenerator, Basics) {
+ auto fut = Future<TestInt>::Make();
+ auto gen = MakeSingleFutureGenerator(fut);
+ auto collect_fut = CollectAsyncGenerator(gen);
+ AssertNotFinished(collect_fut);
+ fut.MarkFinished(TestInt{42});
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto collected, collect_fut);
+ ASSERT_EQ(collected, std::vector<TestInt>{42});
+ // Generator exhausted
+ collect_fut = CollectAsyncGenerator(gen);
+ ASSERT_FINISHES_OK_AND_EQ(std::vector<TestInt>{}, collect_fut);
+}
+
+TEST(FailingGenerator, Basics) {
+ auto gen = MakeFailingGenerator<TestInt>(Status::IOError("zzz"));
+ auto collect_fut = CollectAsyncGenerator(gen);
+ ASSERT_FINISHES_AND_RAISES(IOError, collect_fut);
+ // Generator exhausted
+ collect_fut = CollectAsyncGenerator(gen);
+ ASSERT_FINISHES_OK_AND_EQ(std::vector<TestInt>{}, collect_fut);
+}
+
+TEST(DefaultIfEmptyGenerator, Basics) {
+ std::vector<TestInt> values{1, 2, 3, 4};
+ auto gen = MakeVectorGenerator(values);
+ ASSERT_FINISHES_OK_AND_ASSIGN(
+ auto actual, CollectAsyncGenerator(MakeDefaultIfEmptyGenerator(gen, TestInt(42))));
+ EXPECT_EQ(values, actual);
+
+ gen = MakeVectorGenerator<TestInt>({});
+ ASSERT_FINISHES_OK_AND_ASSIGN(
+ actual, CollectAsyncGenerator(MakeDefaultIfEmptyGenerator(gen, TestInt(42))));
+ EXPECT_EQ(std::vector<TestInt>{42}, actual);
+}
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/async_util.cc b/src/arrow/cpp/src/arrow/util/async_util.cc
new file mode 100644
index 000000000..f5b9bdcbe
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/async_util.cc
@@ -0,0 +1,206 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/async_util.h"
+
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace util {
+
+AsyncDestroyable::AsyncDestroyable() : on_closed_(Future<>::Make()) {}
+
+#ifndef NDEBUG
+AsyncDestroyable::~AsyncDestroyable() {
+ DCHECK(constructed_correctly_) << "An instance of AsyncDestroyable must be created by "
+ "MakeSharedAsync or MakeUniqueAsync";
+}
+#else
+AsyncDestroyable::~AsyncDestroyable() = default;
+#endif
+
+void AsyncDestroyable::Destroy() {
+ DoDestroy().AddCallback([this](const Status& st) {
+ on_closed_.MarkFinished(st);
+ delete this;
+ });
+}
+
+Status AsyncTaskGroup::AddTask(std::function<Result<Future<>>()> task) {
+ auto guard = mutex_.Lock();
+ if (all_tasks_done_.is_finished()) {
+ return Status::Invalid("Attempt to add a task after the task group has completed");
+ }
+ if (!err_.ok()) {
+ return err_;
+ }
+ Result<Future<>> maybe_task_fut = task();
+ if (!maybe_task_fut.ok()) {
+ err_ = maybe_task_fut.status();
+ return err_;
+ }
+ return AddTaskUnlocked(*maybe_task_fut, std::move(guard));
+}
+
+Status AsyncTaskGroup::AddTaskUnlocked(const Future<>& task_fut,
+ util::Mutex::Guard guard) {
+ // If the task is already finished there is nothing to track so lets save
+ // some work and return early
+ if (task_fut.is_finished()) {
+ err_ &= task_fut.status();
+ return err_;
+ }
+ running_tasks_++;
+ guard.Unlock();
+ task_fut.AddCallback([this](const Status& st) {
+ auto guard = mutex_.Lock();
+ err_ &= st;
+ if (--running_tasks_ == 0 && finished_adding_) {
+ guard.Unlock();
+ all_tasks_done_.MarkFinished(err_);
+ }
+ });
+ return Status::OK();
+}
+
+Status AsyncTaskGroup::AddTask(const Future<>& task_fut) {
+ auto guard = mutex_.Lock();
+ if (all_tasks_done_.is_finished()) {
+ return Status::Invalid("Attempt to add a task after the task group has completed");
+ }
+ if (!err_.ok()) {
+ return err_;
+ }
+ return AddTaskUnlocked(task_fut, std::move(guard));
+}
+
+Future<> AsyncTaskGroup::End() {
+ auto guard = mutex_.Lock();
+ finished_adding_ = true;
+ if (running_tasks_ == 0) {
+ all_tasks_done_.MarkFinished(err_);
+ return all_tasks_done_;
+ }
+ return all_tasks_done_;
+}
+
+Future<> AsyncTaskGroup::OnFinished() const { return all_tasks_done_; }
+
+SerializedAsyncTaskGroup::SerializedAsyncTaskGroup() : on_finished_(Future<>::Make()) {}
+
+Status SerializedAsyncTaskGroup::AddTask(std::function<Result<Future<>>()> task) {
+ util::Mutex::Guard guard = mutex_.Lock();
+ ARROW_RETURN_NOT_OK(err_);
+ if (on_finished_.is_finished()) {
+ return Status::Invalid("Attempt to add a task after a task group has finished");
+ }
+ tasks_.push(std::move(task));
+ if (!processing_.is_valid()) {
+ ConsumeAsMuchAsPossibleUnlocked(std::move(guard));
+ }
+ return err_;
+}
+
+Future<> SerializedAsyncTaskGroup::End() {
+ util::Mutex::Guard guard = mutex_.Lock();
+ ended_ = true;
+ if (!processing_.is_valid()) {
+ guard.Unlock();
+ on_finished_.MarkFinished(err_);
+ }
+ return on_finished_;
+}
+
+void SerializedAsyncTaskGroup::ConsumeAsMuchAsPossibleUnlocked(
+ util::Mutex::Guard&& guard) {
+ while (err_.ok() && !tasks_.empty() && TryDrainUnlocked()) {
+ }
+ if (ended_ && tasks_.empty() && !processing_.is_valid()) {
+ guard.Unlock();
+ on_finished_.MarkFinished(err_);
+ }
+}
+
+bool SerializedAsyncTaskGroup::TryDrainUnlocked() {
+ if (processing_.is_valid()) {
+ return false;
+ }
+ std::function<Result<Future<>>()> next_task = std::move(tasks_.front());
+ tasks_.pop();
+ Result<Future<>> maybe_next_fut = next_task();
+ if (!maybe_next_fut.ok()) {
+ err_ &= maybe_next_fut.status();
+ return true;
+ }
+ Future<> next_fut = maybe_next_fut.MoveValueUnsafe();
+ if (next_fut.is_finished()) {
+ err_ &= next_fut.status();
+ return true;
+ }
+ processing_ = std::move(next_fut);
+ processing_.AddCallback([this](const Status& st) {
+ util::Mutex::Guard guard = mutex_.Lock();
+ processing_ = Future<>();
+ err_ &= st;
+ ConsumeAsMuchAsPossibleUnlocked(std::move(guard));
+ });
+ return false;
+}
+
+Future<> AsyncToggle::WhenOpen() {
+ util::Mutex::Guard guard = mutex_.Lock();
+ return when_open_;
+}
+
+void AsyncToggle::Open() {
+ util::Mutex::Guard guard = mutex_.Lock();
+ if (!closed_) {
+ return;
+ }
+ closed_ = false;
+ Future<> to_finish = when_open_;
+ guard.Unlock();
+ to_finish.MarkFinished();
+}
+
+void AsyncToggle::Close() {
+ util::Mutex::Guard guard = mutex_.Lock();
+ if (closed_) {
+ return;
+ }
+ closed_ = true;
+ when_open_ = Future<>::Make();
+}
+
+bool AsyncToggle::IsOpen() {
+ util::Mutex::Guard guard = mutex_.Lock();
+ return !closed_;
+}
+
+BackpressureOptions BackpressureOptions::Make(uint32_t resume_if_below,
+ uint32_t pause_if_above) {
+ auto toggle = std::make_shared<util::AsyncToggle>();
+ return BackpressureOptions{std::move(toggle), resume_if_below, pause_if_above};
+}
+
+BackpressureOptions BackpressureOptions::NoBackpressure() {
+ return BackpressureOptions();
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/async_util.h b/src/arrow/cpp/src/arrow/util/async_util.h
new file mode 100644
index 000000000..29b216830
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/async_util.h
@@ -0,0 +1,258 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <queue>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/future.h"
+#include "arrow/util/mutex.h"
+
+namespace arrow {
+namespace util {
+
+/// Custom deleter for AsyncDestroyable objects
+template <typename T>
+struct DestroyingDeleter {
+ void operator()(T* p) {
+ if (p) {
+ p->Destroy();
+ }
+ }
+};
+
+/// An object which should be asynchronously closed before it is destroyed
+///
+/// Classes can extend this to ensure that the close method is called and completed
+/// before the instance is deleted. This provides smart_ptr / delete semantics for
+/// objects with an asynchronous destructor.
+///
+/// Classes which extend this must be constructed using MakeSharedAsync or MakeUniqueAsync
+class ARROW_EXPORT AsyncDestroyable {
+ public:
+ AsyncDestroyable();
+ virtual ~AsyncDestroyable();
+
+ /// A future which will complete when the AsyncDestroyable has finished and is ready
+ /// to be deleted.
+ ///
+ /// This can be used to ensure all work done by this object has been completed before
+ /// proceeding.
+ Future<> on_closed() { return on_closed_; }
+
+ protected:
+ /// Subclasses should override this and perform any cleanup. Once the future returned
+ /// by this method finishes then this object is eligible for destruction and any
+ /// reference to `this` will be invalid
+ virtual Future<> DoDestroy() = 0;
+
+ private:
+ void Destroy();
+
+ Future<> on_closed_;
+#ifndef NDEBUG
+ bool constructed_correctly_ = false;
+#endif
+
+ template <typename T>
+ friend struct DestroyingDeleter;
+ template <typename T, typename... Args>
+ friend std::shared_ptr<T> MakeSharedAsync(Args&&... args);
+ template <typename T, typename... Args>
+ friend std::unique_ptr<T, DestroyingDeleter<T>> MakeUniqueAsync(Args&&... args);
+};
+
+template <typename T, typename... Args>
+std::shared_ptr<T> MakeSharedAsync(Args&&... args) {
+ static_assert(std::is_base_of<AsyncDestroyable, T>::value,
+ "Nursery::MakeSharedCloseable only works with AsyncDestroyable types");
+ std::shared_ptr<T> ptr(new T(std::forward<Args&&>(args)...), DestroyingDeleter<T>());
+#ifndef NDEBUG
+ ptr->constructed_correctly_ = true;
+#endif
+ return ptr;
+}
+
+template <typename T, typename... Args>
+std::unique_ptr<T, DestroyingDeleter<T>> MakeUniqueAsync(Args&&... args) {
+ static_assert(std::is_base_of<AsyncDestroyable, T>::value,
+ "Nursery::MakeUniqueCloseable only works with AsyncDestroyable types");
+ std::unique_ptr<T, DestroyingDeleter<T>> ptr(new T(std::forward<Args>(args)...),
+ DestroyingDeleter<T>());
+#ifndef NDEBUG
+ ptr->constructed_correctly_ = true;
+#endif
+ return ptr;
+}
+
+/// A utility which keeps track of a collection of asynchronous tasks
+///
+/// This can be used to provide structured concurrency for asynchronous development.
+/// A task group created at a high level can be distributed amongst low level components
+/// which register work to be completed. The high level job can then wait for all work
+/// to be completed before cleaning up.
+class ARROW_EXPORT AsyncTaskGroup {
+ public:
+ /// Add a task to be tracked by this task group
+ ///
+ /// If a previous task has failed then adding a task will fail
+ ///
+ /// If WaitForTasksToFinish has been called and the returned future has been marked
+ /// completed then adding a task will fail.
+ Status AddTask(std::function<Result<Future<>>()> task);
+ /// Add a task that has already been started
+ Status AddTask(const Future<>& task);
+ /// Signal that top level tasks are done being added
+ ///
+ /// It is allowed for tasks to be added after this call provided the future has not yet
+ /// completed. This should be safe as long as the tasks being added are added as part
+ /// of a task that is tracked. As soon as the count of running tasks reaches 0 this
+ /// future will be marked complete.
+ ///
+ /// Any attempt to add a task after the returned future has completed will fail.
+ ///
+ /// The returned future that will finish when all running tasks have finsihed.
+ Future<> End();
+ /// A future that will be finished after End is called and all tasks have completed
+ ///
+ /// This is the same future that is returned by End() but calling this method does
+ /// not indicate that top level tasks are done being added. End() must still be called
+ /// at some point or the future returned will never finish.
+ ///
+ /// This is a utility method for workflows where the finish future needs to be
+ /// referenced before all top level tasks have been queued.
+ Future<> OnFinished() const;
+
+ private:
+ Status AddTaskUnlocked(const Future<>& task, util::Mutex::Guard guard);
+
+ bool finished_adding_ = false;
+ int running_tasks_ = 0;
+ Status err_;
+ Future<> all_tasks_done_ = Future<>::Make();
+ util::Mutex mutex_;
+};
+
+/// A task group which serializes asynchronous tasks in a push-based workflow
+///
+/// Tasks will be executed in the order they are added
+///
+/// This will buffer results in an unlimited fashion so it should be combined
+/// with some kind of backpressure
+class ARROW_EXPORT SerializedAsyncTaskGroup {
+ public:
+ SerializedAsyncTaskGroup();
+ /// Push an item into the serializer and (eventually) into the consumer
+ ///
+ /// The item will not be delivered to the consumer until all previous items have been
+ /// consumed.
+ ///
+ /// If the consumer returns an error then this serializer will go into an error state
+ /// and all subsequent pushes will fail with that error. Pushes that have been queued
+ /// but not delivered will be silently dropped.
+ ///
+ /// \return True if the item was pushed immediately to the consumer, false if it was
+ /// queued
+ Status AddTask(std::function<Result<Future<>>()> task);
+
+ /// Signal that all top level tasks have been added
+ ///
+ /// The returned future that will finish when all tasks have been consumed.
+ Future<> End();
+
+ /// A future that finishes when all queued items have been delivered.
+ ///
+ /// This will return the same future returned by End but will not signal
+ /// that all tasks have been finished. End must be called at some point in order for
+ /// this future to finish.
+ Future<> OnFinished() const { return on_finished_; }
+
+ private:
+ void ConsumeAsMuchAsPossibleUnlocked(util::Mutex::Guard&& guard);
+ bool TryDrainUnlocked();
+
+ Future<> on_finished_;
+ std::queue<std::function<Result<Future<>>()>> tasks_;
+ util::Mutex mutex_;
+ bool ended_ = false;
+ Status err_;
+ Future<> processing_;
+};
+
+class ARROW_EXPORT AsyncToggle {
+ public:
+ /// Get a future that will complete when the toggle next becomes open
+ ///
+ /// If the toggle is open this returns immediately
+ /// If the toggle is closed this future will be unfinished until the next call to Open
+ Future<> WhenOpen();
+ /// \brief Close the toggle
+ ///
+ /// After this call any call to WhenOpen will be delayed until the next open
+ void Close();
+ /// \brief Open the toggle
+ ///
+ /// Note: This call may complete a future, triggering any callbacks, and generally
+ /// should not be done while holding any locks.
+ ///
+ /// Note: If Open is called from multiple threads it could lead to a situation where
+ /// callbacks from the second open finish before callbacks on the first open.
+ ///
+ /// All current waiters will be released to enter, even if another close call
+ /// quickly follows
+ void Open();
+
+ /// \brief Return true if the toggle is currently open
+ bool IsOpen();
+
+ private:
+ Future<> when_open_ = Future<>::MakeFinished();
+ bool closed_ = false;
+ util::Mutex mutex_;
+};
+
+/// \brief Options to control backpressure behavior
+struct ARROW_EXPORT BackpressureOptions {
+ /// \brief Create default options that perform no backpressure
+ BackpressureOptions() : toggle(NULLPTR), resume_if_below(0), pause_if_above(0) {}
+ /// \brief Create options that will perform backpressure
+ ///
+ /// \param toggle A toggle to be shared between the producer and consumer
+ /// \param resume_if_below The producer should resume producing if the backpressure
+ /// queue has fewer than resume_if_below items.
+ /// \param pause_if_above The producer should pause producing if the backpressure
+ /// queue has more than pause_if_above items
+ BackpressureOptions(std::shared_ptr<util::AsyncToggle> toggle, uint32_t resume_if_below,
+ uint32_t pause_if_above)
+ : toggle(std::move(toggle)),
+ resume_if_below(resume_if_below),
+ pause_if_above(pause_if_above) {}
+
+ static BackpressureOptions Make(uint32_t resume_if_below = 32,
+ uint32_t pause_if_above = 64);
+
+ static BackpressureOptions NoBackpressure();
+
+ std::shared_ptr<util::AsyncToggle> toggle;
+ uint32_t resume_if_below;
+ uint32_t pause_if_above;
+};
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/async_util_test.cc b/src/arrow/cpp/src/arrow/util/async_util_test.cc
new file mode 100644
index 000000000..eae4adfdf
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/async_util_test.cc
@@ -0,0 +1,239 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/async_util.h"
+
+#include <gtest/gtest.h>
+
+#include "arrow/result.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+namespace util {
+
+class GatingDestroyable : public AsyncDestroyable {
+ public:
+ GatingDestroyable(Future<> close_future, bool* destroyed)
+ : close_future_(std::move(close_future)), destroyed_(destroyed) {}
+ ~GatingDestroyable() override { *destroyed_ = true; }
+
+ protected:
+ Future<> DoDestroy() override { return close_future_; }
+
+ private:
+ Future<> close_future_;
+ bool* destroyed_;
+};
+
+template <typename Factory>
+void TestAsyncDestroyable(Factory factory) {
+ Future<> gate = Future<>::Make();
+ bool destroyed = false;
+ bool on_closed = false;
+ {
+ auto obj = factory(gate, &destroyed);
+ obj->on_closed().AddCallback([&](const Status& st) { on_closed = true; });
+ ASSERT_FALSE(destroyed);
+ }
+ ASSERT_FALSE(destroyed);
+ ASSERT_FALSE(on_closed);
+ gate.MarkFinished();
+ ASSERT_TRUE(destroyed);
+ ASSERT_TRUE(on_closed);
+}
+
+TEST(AsyncDestroyable, MakeShared) {
+ TestAsyncDestroyable([](Future<> gate, bool* destroyed) {
+ return MakeSharedAsync<GatingDestroyable>(gate, destroyed);
+ });
+}
+
+// The next four tests are corner cases but can sometimes occur when using these types
+// in standard containers on certain versions of the compiler/cpplib. Basically we
+// want to make sure our deleter is ok with null pointers.
+TEST(AsyncDestroyable, DefaultUnique) {
+ std::unique_ptr<GatingDestroyable, DestroyingDeleter<GatingDestroyable>> default_ptr;
+ default_ptr.reset();
+}
+
+TEST(AsyncDestroyable, NullUnique) {
+ std::unique_ptr<GatingDestroyable, DestroyingDeleter<GatingDestroyable>> null_ptr(
+ nullptr);
+ null_ptr.reset();
+}
+
+TEST(AsyncDestroyable, NullShared) {
+ std::shared_ptr<GatingDestroyable> null_ptr(nullptr,
+ DestroyingDeleter<GatingDestroyable>());
+ null_ptr.reset();
+}
+
+TEST(AsyncDestroyable, NullUniqueToShared) {
+ std::unique_ptr<GatingDestroyable, DestroyingDeleter<GatingDestroyable>> null_ptr(
+ nullptr);
+ std::shared_ptr<GatingDestroyable> null_shared = std::move(null_ptr);
+ null_shared.reset();
+}
+
+TEST(AsyncDestroyable, MakeUnique) {
+ TestAsyncDestroyable([](Future<> gate, bool* destroyed) {
+ return MakeUniqueAsync<GatingDestroyable>(gate, destroyed);
+ });
+}
+
+template <typename T>
+class TypedTestAsyncTaskGroup : public ::testing::Test {};
+
+using AsyncTaskGroupTypes = ::testing::Types<AsyncTaskGroup, SerializedAsyncTaskGroup>;
+
+TYPED_TEST_SUITE(TypedTestAsyncTaskGroup, AsyncTaskGroupTypes);
+
+TYPED_TEST(TypedTestAsyncTaskGroup, Basic) {
+ TypeParam task_group;
+ Future<> fut1 = Future<>::Make();
+ Future<> fut2 = Future<>::Make();
+ ASSERT_OK(task_group.AddTask([fut1]() { return fut1; }));
+ ASSERT_OK(task_group.AddTask([fut2]() { return fut2; }));
+ Future<> all_done = task_group.End();
+ AssertNotFinished(all_done);
+ fut1.MarkFinished();
+ AssertNotFinished(all_done);
+ fut2.MarkFinished();
+ ASSERT_FINISHES_OK(all_done);
+}
+
+TYPED_TEST(TypedTestAsyncTaskGroup, NoTasks) {
+ TypeParam task_group;
+ ASSERT_FINISHES_OK(task_group.End());
+}
+
+TYPED_TEST(TypedTestAsyncTaskGroup, OnFinishedDoesNotEnd) {
+ TypeParam task_group;
+ Future<> on_finished = task_group.OnFinished();
+ AssertNotFinished(on_finished);
+ ASSERT_FINISHES_OK(task_group.End());
+ ASSERT_FINISHES_OK(on_finished);
+}
+
+TYPED_TEST(TypedTestAsyncTaskGroup, AddAfterDone) {
+ TypeParam task_group;
+ ASSERT_FINISHES_OK(task_group.End());
+ ASSERT_RAISES(Invalid, task_group.AddTask([] { return Future<>::Make(); }));
+}
+
+TYPED_TEST(TypedTestAsyncTaskGroup, AddAfterWaitButBeforeFinish) {
+ TypeParam task_group;
+ Future<> task_one = Future<>::Make();
+ ASSERT_OK(task_group.AddTask([task_one] { return task_one; }));
+ Future<> finish_fut = task_group.End();
+ AssertNotFinished(finish_fut);
+ Future<> task_two = Future<>::Make();
+ ASSERT_OK(task_group.AddTask([task_two] { return task_two; }));
+ AssertNotFinished(finish_fut);
+ task_one.MarkFinished();
+ AssertNotFinished(finish_fut);
+ task_two.MarkFinished();
+ AssertFinished(finish_fut);
+ ASSERT_FINISHES_OK(finish_fut);
+}
+
+TYPED_TEST(TypedTestAsyncTaskGroup, Error) {
+ TypeParam task_group;
+ Future<> failed_task = Future<>::MakeFinished(Status::Invalid("XYZ"));
+ ASSERT_RAISES(Invalid, task_group.AddTask([failed_task] { return failed_task; }));
+ ASSERT_FINISHES_AND_RAISES(Invalid, task_group.End());
+}
+
+TYPED_TEST(TypedTestAsyncTaskGroup, TaskFactoryFails) {
+ TypeParam task_group;
+ ASSERT_RAISES(Invalid, task_group.AddTask([] { return Status::Invalid("XYZ"); }));
+ ASSERT_RAISES(Invalid, task_group.AddTask([] { return Future<>::Make(); }));
+ ASSERT_FINISHES_AND_RAISES(Invalid, task_group.End());
+}
+
+TYPED_TEST(TypedTestAsyncTaskGroup, AddAfterFailed) {
+ TypeParam task_group;
+ ASSERT_RAISES(Invalid, task_group.AddTask([] {
+ return Future<>::MakeFinished(Status::Invalid("XYZ"));
+ }));
+ ASSERT_RAISES(Invalid, task_group.AddTask([] { return Future<>::Make(); }));
+ ASSERT_FINISHES_AND_RAISES(Invalid, task_group.End());
+}
+
+TEST(StandardAsyncTaskGroup, TaskFinishesAfterError) {
+ AsyncTaskGroup task_group;
+ Future<> fut1 = Future<>::Make();
+ ASSERT_OK(task_group.AddTask([fut1] { return fut1; }));
+ ASSERT_RAISES(Invalid, task_group.AddTask([] {
+ return Future<>::MakeFinished(Status::Invalid("XYZ"));
+ }));
+ Future<> finished_fut = task_group.End();
+ AssertNotFinished(finished_fut);
+ fut1.MarkFinished();
+ ASSERT_FINISHES_AND_RAISES(Invalid, finished_fut);
+}
+
+TEST(StandardAsyncTaskGroup, FailAfterAdd) {
+ AsyncTaskGroup task_group;
+ Future<> will_fail = Future<>::Make();
+ ASSERT_OK(task_group.AddTask([will_fail] { return will_fail; }));
+ Future<> added_later_and_passes = Future<>::Make();
+ ASSERT_OK(
+ task_group.AddTask([added_later_and_passes] { return added_later_and_passes; }));
+ will_fail.MarkFinished(Status::Invalid("XYZ"));
+ ASSERT_RAISES(Invalid, task_group.AddTask([] { return Future<>::Make(); }));
+ Future<> finished_fut = task_group.End();
+ AssertNotFinished(finished_fut);
+ added_later_and_passes.MarkFinished();
+ AssertFinished(finished_fut);
+ ASSERT_FINISHES_AND_RAISES(Invalid, finished_fut);
+}
+
+// The serialized task group can never really get into a "fail after add" scenario
+// because there is no parallelism. So the behavior is a little unique in these scenarios
+
+TEST(SerializedAsyncTaskGroup, TaskFinishesAfterError) {
+ SerializedAsyncTaskGroup task_group;
+ Future<> fut1 = Future<>::Make();
+ ASSERT_OK(task_group.AddTask([fut1] { return fut1; }));
+ ASSERT_OK(
+ task_group.AddTask([] { return Future<>::MakeFinished(Status::Invalid("XYZ")); }));
+ Future<> finished_fut = task_group.End();
+ AssertNotFinished(finished_fut);
+ fut1.MarkFinished();
+ ASSERT_FINISHES_AND_RAISES(Invalid, finished_fut);
+}
+
+TEST(SerializedAsyncTaskGroup, FailAfterAdd) {
+ SerializedAsyncTaskGroup task_group;
+ Future<> will_fail = Future<>::Make();
+ ASSERT_OK(task_group.AddTask([will_fail] { return will_fail; }));
+ Future<> added_later_and_passes = Future<>::Make();
+ bool added_later_and_passes_created = false;
+ ASSERT_OK(task_group.AddTask([added_later_and_passes, &added_later_and_passes_created] {
+ added_later_and_passes_created = true;
+ return added_later_and_passes;
+ }));
+ will_fail.MarkFinished(Status::Invalid("XYZ"));
+ ASSERT_RAISES(Invalid, task_group.AddTask([] { return Future<>::Make(); }));
+ ASSERT_FINISHES_AND_RAISES(Invalid, task_group.End());
+ ASSERT_FALSE(added_later_and_passes_created);
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/atomic_shared_ptr.h b/src/arrow/cpp/src/arrow/util/atomic_shared_ptr.h
new file mode 100644
index 000000000..d93ad921d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/atomic_shared_ptr.h
@@ -0,0 +1,111 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <atomic>
+#include <memory>
+#include <utility>
+
+#include "arrow/type_traits.h"
+
+namespace arrow {
+namespace internal {
+
+// Atomic shared_ptr operations only appeared in libstdc++ since GCC 5,
+// emulate them with unsafe ops if unavailable.
+// See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=57250
+
+template <typename T, typename = void>
+struct is_atomic_load_shared_ptr_available : std::false_type {};
+
+template <typename T>
+struct is_atomic_load_shared_ptr_available<
+ T, void_t<decltype(std::atomic_load(std::declval<const std::shared_ptr<T>*>()))>>
+ : std::true_type {};
+
+template <typename T>
+using enable_if_atomic_load_shared_ptr_available =
+ enable_if_t<is_atomic_load_shared_ptr_available<T>::value, T>;
+
+template <typename T>
+using enable_if_atomic_load_shared_ptr_unavailable =
+ enable_if_t<!is_atomic_load_shared_ptr_available<T>::value, T>;
+
+template <class T>
+enable_if_atomic_load_shared_ptr_available<std::shared_ptr<T>> atomic_load(
+ const std::shared_ptr<T>* p) {
+ return std::atomic_load(p);
+}
+
+template <class T>
+enable_if_atomic_load_shared_ptr_unavailable<std::shared_ptr<T>> atomic_load(
+ const std::shared_ptr<T>* p) {
+ return *p;
+}
+
+template <typename T, typename = void>
+struct is_atomic_store_shared_ptr_available : std::false_type {};
+
+template <typename T>
+struct is_atomic_store_shared_ptr_available<
+ T, void_t<decltype(std::atomic_store(std::declval<std::shared_ptr<T>*>(),
+ std::declval<std::shared_ptr<T>>()))>>
+ : std::true_type {};
+
+template <typename T>
+using enable_if_atomic_store_shared_ptr_available =
+ enable_if_t<is_atomic_store_shared_ptr_available<T>::value, T>;
+
+template <typename T>
+using enable_if_atomic_store_shared_ptr_unavailable =
+ enable_if_t<!is_atomic_store_shared_ptr_available<T>::value, T>;
+
+template <class T>
+void atomic_store(enable_if_atomic_store_shared_ptr_available<std::shared_ptr<T>*> p,
+ std::shared_ptr<T> r) {
+ std::atomic_store(p, std::move(r));
+}
+
+template <class T>
+void atomic_store(enable_if_atomic_store_shared_ptr_unavailable<std::shared_ptr<T>*> p,
+ std::shared_ptr<T> r) {
+ *p = r;
+}
+
+template <class T>
+bool atomic_compare_exchange_strong(
+ enable_if_atomic_store_shared_ptr_available<std::shared_ptr<T>*> p,
+ std::shared_ptr<T>* expected, std::shared_ptr<T> desired) {
+ return std::atomic_compare_exchange_strong(p, expected, std::move(desired));
+}
+
+template <class T>
+bool atomic_compare_exchange_strong(
+ enable_if_atomic_store_shared_ptr_unavailable<std::shared_ptr<T>*> p,
+ std::shared_ptr<T>* expected, std::shared_ptr<T> desired) {
+ if (*p == *expected) {
+ *p = std::move(desired);
+ return true;
+ } else {
+ *expected = *p;
+ return false;
+ }
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/base64.h b/src/arrow/cpp/src/arrow/util/base64.h
new file mode 100644
index 000000000..a46884d17
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/base64.h
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+
+#include "arrow/util/string_view.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace util {
+
+ARROW_EXPORT
+std::string base64_encode(string_view s);
+
+ARROW_EXPORT
+std::string base64_decode(string_view s);
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/basic_decimal.cc b/src/arrow/cpp/src/arrow/util/basic_decimal.cc
new file mode 100644
index 000000000..1832bf5c4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/basic_decimal.cc
@@ -0,0 +1,1381 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/basic_decimal.h"
+
+#include <algorithm>
+#include <array>
+#include <climits>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <iomanip>
+#include <limits>
+#include <string>
+
+#include "arrow/util/bit_util.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/int128_internal.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+
+using internal::SafeLeftShift;
+using internal::SafeSignedAdd;
+
+static const BasicDecimal128 ScaleMultipliers[] = {
+ BasicDecimal128(1LL),
+ BasicDecimal128(10LL),
+ BasicDecimal128(100LL),
+ BasicDecimal128(1000LL),
+ BasicDecimal128(10000LL),
+ BasicDecimal128(100000LL),
+ BasicDecimal128(1000000LL),
+ BasicDecimal128(10000000LL),
+ BasicDecimal128(100000000LL),
+ BasicDecimal128(1000000000LL),
+ BasicDecimal128(10000000000LL),
+ BasicDecimal128(100000000000LL),
+ BasicDecimal128(1000000000000LL),
+ BasicDecimal128(10000000000000LL),
+ BasicDecimal128(100000000000000LL),
+ BasicDecimal128(1000000000000000LL),
+ BasicDecimal128(10000000000000000LL),
+ BasicDecimal128(100000000000000000LL),
+ BasicDecimal128(1000000000000000000LL),
+ BasicDecimal128(0LL, 10000000000000000000ULL),
+ BasicDecimal128(5LL, 7766279631452241920ULL),
+ BasicDecimal128(54LL, 3875820019684212736ULL),
+ BasicDecimal128(542LL, 1864712049423024128ULL),
+ BasicDecimal128(5421LL, 200376420520689664ULL),
+ BasicDecimal128(54210LL, 2003764205206896640ULL),
+ BasicDecimal128(542101LL, 1590897978359414784ULL),
+ BasicDecimal128(5421010LL, 15908979783594147840ULL),
+ BasicDecimal128(54210108LL, 11515845246265065472ULL),
+ BasicDecimal128(542101086LL, 4477988020393345024ULL),
+ BasicDecimal128(5421010862LL, 7886392056514347008ULL),
+ BasicDecimal128(54210108624LL, 5076944270305263616ULL),
+ BasicDecimal128(542101086242LL, 13875954555633532928ULL),
+ BasicDecimal128(5421010862427LL, 9632337040368467968ULL),
+ BasicDecimal128(54210108624275LL, 4089650035136921600ULL),
+ BasicDecimal128(542101086242752LL, 4003012203950112768ULL),
+ BasicDecimal128(5421010862427522LL, 3136633892082024448ULL),
+ BasicDecimal128(54210108624275221LL, 12919594847110692864ULL),
+ BasicDecimal128(542101086242752217LL, 68739955140067328ULL),
+ BasicDecimal128(5421010862427522170LL, 687399551400673280ULL)};
+
+static const BasicDecimal128 ScaleMultipliersHalf[] = {
+ BasicDecimal128(0ULL),
+ BasicDecimal128(5ULL),
+ BasicDecimal128(50ULL),
+ BasicDecimal128(500ULL),
+ BasicDecimal128(5000ULL),
+ BasicDecimal128(50000ULL),
+ BasicDecimal128(500000ULL),
+ BasicDecimal128(5000000ULL),
+ BasicDecimal128(50000000ULL),
+ BasicDecimal128(500000000ULL),
+ BasicDecimal128(5000000000ULL),
+ BasicDecimal128(50000000000ULL),
+ BasicDecimal128(500000000000ULL),
+ BasicDecimal128(5000000000000ULL),
+ BasicDecimal128(50000000000000ULL),
+ BasicDecimal128(500000000000000ULL),
+ BasicDecimal128(5000000000000000ULL),
+ BasicDecimal128(50000000000000000ULL),
+ BasicDecimal128(500000000000000000ULL),
+ BasicDecimal128(5000000000000000000ULL),
+ BasicDecimal128(2LL, 13106511852580896768ULL),
+ BasicDecimal128(27LL, 1937910009842106368ULL),
+ BasicDecimal128(271LL, 932356024711512064ULL),
+ BasicDecimal128(2710LL, 9323560247115120640ULL),
+ BasicDecimal128(27105LL, 1001882102603448320ULL),
+ BasicDecimal128(271050LL, 10018821026034483200ULL),
+ BasicDecimal128(2710505LL, 7954489891797073920ULL),
+ BasicDecimal128(27105054LL, 5757922623132532736ULL),
+ BasicDecimal128(271050543LL, 2238994010196672512ULL),
+ BasicDecimal128(2710505431LL, 3943196028257173504ULL),
+ BasicDecimal128(27105054312LL, 2538472135152631808ULL),
+ BasicDecimal128(271050543121LL, 6937977277816766464ULL),
+ BasicDecimal128(2710505431213LL, 14039540557039009792ULL),
+ BasicDecimal128(27105054312137LL, 11268197054423236608ULL),
+ BasicDecimal128(271050543121376LL, 2001506101975056384ULL),
+ BasicDecimal128(2710505431213761LL, 1568316946041012224ULL),
+ BasicDecimal128(27105054312137610LL, 15683169460410122240ULL),
+ BasicDecimal128(271050543121376108LL, 9257742014424809472ULL),
+ BasicDecimal128(2710505431213761085LL, 343699775700336640ULL)};
+
+#define BasicDecimal256FromLE(v1, v2, v3, v4) \
+ BasicDecimal256(BitUtil::LittleEndianArray::ToNative<uint64_t, 4>(v1, v2, v3, v4))
+
+static const BasicDecimal256 ScaleMultipliersDecimal256[] = {
+ BasicDecimal256FromLE({1ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({10ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({100ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({1000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({10000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({100000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({1000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({10000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({100000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({1000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({10000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({100000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({1000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({10000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({100000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({1000000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({10000000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({100000000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({1000000000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({10000000000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({7766279631452241920ULL, 5ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({3875820019684212736ULL, 54ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({1864712049423024128ULL, 542ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({200376420520689664ULL, 5421ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({2003764205206896640ULL, 54210ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({1590897978359414784ULL, 542101ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({15908979783594147840ULL, 5421010ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({11515845246265065472ULL, 54210108ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({4477988020393345024ULL, 542101086ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({7886392056514347008ULL, 5421010862ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({5076944270305263616ULL, 54210108624ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({13875954555633532928ULL, 542101086242ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({9632337040368467968ULL, 5421010862427ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({4089650035136921600ULL, 54210108624275ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({4003012203950112768ULL, 542101086242752ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({3136633892082024448ULL, 5421010862427522ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({12919594847110692864ULL, 54210108624275221ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({68739955140067328ULL, 542101086242752217ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({687399551400673280ULL, 5421010862427522170ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({6873995514006732800ULL, 17316620476856118468ULL, 2ULL, 0ULL}),
+ BasicDecimal256FromLE({13399722918938673152ULL, 7145508105175220139ULL, 29ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {4870020673419870208ULL, 16114848830623546549ULL, 293ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {11806718586779598848ULL, 13574535716559052564ULL, 2938ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {7386721425538678784ULL, 6618148649623664334ULL, 29387ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {80237960548581376ULL, 10841254275107988496ULL, 293873ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {802379605485813760ULL, 16178822382532126880ULL, 2938735ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {8023796054858137600ULL, 14214271235644855872ULL, 29387358ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {6450984253743169536ULL, 13015503840481697412ULL, 293873587ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {9169610316303040512ULL, 1027829888850112811ULL, 2938735877ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {17909126868192198656ULL, 10278298888501128114ULL, 29387358770ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {13070572018536022016ULL, 10549268516463523069ULL, 293873587705ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {1578511669393358848ULL, 13258964796087472617ULL, 2938735877055ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {15785116693933588480ULL, 3462439444907864858ULL, 29387358770557ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {10277214349659471872ULL, 16177650375369096972ULL, 293873587705571ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {10538423128046960640ULL, 14202551164014556797ULL, 2938735877055718ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {13150510911921848320ULL, 12898303124178706663ULL, 29387358770557187ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {2377900603251621888ULL, 18302566799529756941ULL, 293873587705571876ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {5332261958806667264ULL, 17004971331911604867ULL, 2938735877055718769ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {16429131440647569408ULL, 4029016655730084128ULL, 10940614696847636083ULL, 1ULL}),
+ BasicDecimal256FromLE({16717361816799281152ULL, 3396678409881738056ULL,
+ 17172426599928602752ULL, 15ULL}),
+ BasicDecimal256FromLE({1152921504606846976ULL, 15520040025107828953ULL,
+ 5703569335900062977ULL, 159ULL}),
+ BasicDecimal256FromLE({11529215046068469760ULL, 7626447661401876602ULL,
+ 1695461137871974930ULL, 1593ULL}),
+ BasicDecimal256FromLE({4611686018427387904ULL, 2477500319180559562ULL,
+ 16954611378719749304ULL, 15930ULL}),
+ BasicDecimal256FromLE({9223372036854775808ULL, 6328259118096044006ULL,
+ 3525417123811528497ULL, 159309ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 7942358959831785217ULL, 16807427164405733357ULL, 1593091ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 5636613303479645706ULL, 2053574980671369030ULL, 15930919ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 1025900813667802212ULL, 2089005733004138687ULL, 159309191ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 10259008136678022120ULL, 2443313256331835254ULL, 1593091911ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 10356360998232463120ULL, 5986388489608800929ULL, 15930919111ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 11329889613776873120ULL, 4523652674959354447ULL, 159309191113ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 2618431695511421504ULL, 8343038602174441244ULL, 1593091911132ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 7737572881404663424ULL, 9643409726906205977ULL, 15930919111324ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 3588752519208427776ULL, 4200376900514301694ULL, 159309191113245ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 17440781118374726144ULL, 5110280857723913709ULL, 1593091911132452ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 8387114520361296896ULL, 14209320429820033867ULL, 15930919111324522ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 10084168908774762496ULL, 12965995782233477362ULL, 159309191113245227ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 8607968719199866880ULL, 532749306367912313ULL, 1593091911132452277ULL})};
+
+static const BasicDecimal256 ScaleMultipliersHalfDecimal256[] = {
+ BasicDecimal256FromLE({0ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({5ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({50ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({500ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({5000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({50000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({500000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({5000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({50000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({500000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({5000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({50000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({500000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({5000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({50000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({500000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({5000000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({50000000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({500000000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({5000000000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({13106511852580896768ULL, 2ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({1937910009842106368ULL, 27ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({932356024711512064ULL, 271ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({9323560247115120640ULL, 2710ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({1001882102603448320ULL, 27105ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({10018821026034483200ULL, 271050ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({7954489891797073920ULL, 2710505ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({5757922623132532736ULL, 27105054ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({2238994010196672512ULL, 271050543ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({3943196028257173504ULL, 2710505431ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({2538472135152631808ULL, 27105054312ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({6937977277816766464ULL, 271050543121ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({14039540557039009792ULL, 2710505431213ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({11268197054423236608ULL, 27105054312137ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({2001506101975056384ULL, 271050543121376ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({1568316946041012224ULL, 2710505431213761ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({15683169460410122240ULL, 27105054312137610ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({9257742014424809472ULL, 271050543121376108ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({343699775700336640ULL, 2710505431213761085ULL, 0ULL, 0ULL}),
+ BasicDecimal256FromLE({3436997757003366400ULL, 8658310238428059234ULL, 1ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {15923233496324112384ULL, 12796126089442385877ULL, 14ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {11658382373564710912ULL, 17280796452166549082ULL, 146ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {5903359293389799424ULL, 6787267858279526282ULL, 1469ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {3693360712769339392ULL, 12532446361666607975ULL, 14693ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {40118980274290688ULL, 14643999174408770056ULL, 146936ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {401189802742906880ULL, 17312783228120839248ULL, 1469367ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {4011898027429068800ULL, 7107135617822427936ULL, 14693679ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {3225492126871584768ULL, 15731123957095624514ULL, 146936793ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {13808177195006296064ULL, 9737286981279832213ULL, 1469367938ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {8954563434096099328ULL, 5139149444250564057ULL, 14693679385ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {15758658046122786816ULL, 14498006295086537342ULL, 146936793852ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {10012627871551455232ULL, 15852854434898512116ULL, 1469367938527ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {7892558346966794240ULL, 10954591759308708237ULL, 14693679385278ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {5138607174829735936ULL, 17312197224539324294ULL, 146936793852785ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {14492583600878256128ULL, 7101275582007278398ULL, 1469367938527859ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {15798627492815699968ULL, 15672523598944129139ULL, 14693679385278593ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {10412322338480586752ULL, 9151283399764878470ULL, 146936793852785938ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {11889503016258109440ULL, 17725857702810578241ULL, 1469367938527859384ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {8214565720323784704ULL, 11237880364719817872ULL, 14693679385278593849ULL, 0ULL}),
+ BasicDecimal256FromLE(
+ {8358680908399640576ULL, 1698339204940869028ULL, 17809585336819077184ULL, 7ULL}),
+ BasicDecimal256FromLE({9799832789158199296ULL, 16983392049408690284ULL,
+ 12075156704804807296ULL, 79ULL}),
+ BasicDecimal256FromLE({5764607523034234880ULL, 3813223830700938301ULL,
+ 10071102605790763273ULL, 796ULL}),
+ BasicDecimal256FromLE({2305843009213693952ULL, 1238750159590279781ULL,
+ 8477305689359874652ULL, 7965ULL}),
+ BasicDecimal256FromLE({4611686018427387904ULL, 12387501595902797811ULL,
+ 10986080598760540056ULL, 79654ULL}),
+ BasicDecimal256FromLE({9223372036854775808ULL, 13194551516770668416ULL,
+ 17627085619057642486ULL, 796545ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 2818306651739822853ULL, 10250159527190460323ULL, 7965459ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 9736322443688676914ULL, 10267874903356845151ULL, 79654595ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 5129504068339011060ULL, 10445028665020693435ULL, 796545955ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 14401552535971007368ULL, 12216566281659176272ULL, 7965459555ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 14888316843743212368ULL, 11485198374334453031ULL, 79654595556ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 1309215847755710752ULL, 4171519301087220622ULL, 796545955566ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 13092158477557107520ULL, 4821704863453102988ULL, 7965459555662ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 1794376259604213888ULL, 11323560487111926655ULL, 79654595556622ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 17943762596042138880ULL, 2555140428861956854ULL, 796545955566226ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 13416929297035424256ULL, 7104660214910016933ULL, 7965459555662261ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 5042084454387381248ULL, 15706369927971514489ULL, 79654595556622613ULL}),
+ BasicDecimal256FromLE(
+ {0ULL, 13527356396454709248ULL, 9489746690038731964ULL, 796545955566226138ULL})};
+
+#undef BasicDecimal256FromLE
+
+#ifdef ARROW_USE_NATIVE_INT128
+static constexpr uint64_t kInt64Mask = 0xFFFFFFFFFFFFFFFF;
+#else
+static constexpr uint64_t kInt32Mask = 0xFFFFFFFF;
+#endif
+
+// same as ScaleMultipliers[38] - 1
+static constexpr BasicDecimal128 kMaxValue =
+ BasicDecimal128(5421010862427522170LL, 687399551400673280ULL - 1);
+
+#if ARROW_LITTLE_ENDIAN
+BasicDecimal128::BasicDecimal128(const uint8_t* bytes)
+ : BasicDecimal128(reinterpret_cast<const int64_t*>(bytes)[1],
+ reinterpret_cast<const uint64_t*>(bytes)[0]) {}
+#else
+BasicDecimal128::BasicDecimal128(const uint8_t* bytes)
+ : BasicDecimal128(reinterpret_cast<const int64_t*>(bytes)[0],
+ reinterpret_cast<const uint64_t*>(bytes)[1]) {}
+#endif
+
+constexpr int BasicDecimal128::kBitWidth;
+constexpr int BasicDecimal128::kMaxPrecision;
+constexpr int BasicDecimal128::kMaxScale;
+
+std::array<uint8_t, 16> BasicDecimal128::ToBytes() const {
+ std::array<uint8_t, 16> out{{0}};
+ ToBytes(out.data());
+ return out;
+}
+
+void BasicDecimal128::ToBytes(uint8_t* out) const {
+ DCHECK_NE(out, nullptr);
+#if ARROW_LITTLE_ENDIAN
+ reinterpret_cast<uint64_t*>(out)[0] = low_bits_;
+ reinterpret_cast<int64_t*>(out)[1] = high_bits_;
+#else
+ reinterpret_cast<int64_t*>(out)[0] = high_bits_;
+ reinterpret_cast<uint64_t*>(out)[1] = low_bits_;
+#endif
+}
+
+BasicDecimal128& BasicDecimal128::Negate() {
+ low_bits_ = ~low_bits_ + 1;
+ high_bits_ = ~high_bits_;
+ if (low_bits_ == 0) {
+ high_bits_ = SafeSignedAdd<int64_t>(high_bits_, 1);
+ }
+ return *this;
+}
+
+BasicDecimal128& BasicDecimal128::Abs() { return *this < 0 ? Negate() : *this; }
+
+BasicDecimal128 BasicDecimal128::Abs(const BasicDecimal128& in) {
+ BasicDecimal128 result(in);
+ return result.Abs();
+}
+
+bool BasicDecimal128::FitsInPrecision(int32_t precision) const {
+ DCHECK_GT(precision, 0);
+ DCHECK_LE(precision, 38);
+ return BasicDecimal128::Abs(*this) < ScaleMultipliers[precision];
+}
+
+BasicDecimal128& BasicDecimal128::operator+=(const BasicDecimal128& right) {
+ const uint64_t sum = low_bits_ + right.low_bits_;
+ high_bits_ = SafeSignedAdd<int64_t>(high_bits_, right.high_bits_);
+ if (sum < low_bits_) {
+ high_bits_ = SafeSignedAdd<int64_t>(high_bits_, 1);
+ }
+ low_bits_ = sum;
+ return *this;
+}
+
+BasicDecimal128& BasicDecimal128::operator-=(const BasicDecimal128& right) {
+ const uint64_t diff = low_bits_ - right.low_bits_;
+ high_bits_ -= right.high_bits_;
+ if (diff > low_bits_) {
+ --high_bits_;
+ }
+ low_bits_ = diff;
+ return *this;
+}
+
+BasicDecimal128& BasicDecimal128::operator/=(const BasicDecimal128& right) {
+ BasicDecimal128 remainder;
+ auto s = Divide(right, this, &remainder);
+ DCHECK_EQ(s, DecimalStatus::kSuccess);
+ return *this;
+}
+
+BasicDecimal128& BasicDecimal128::operator|=(const BasicDecimal128& right) {
+ low_bits_ |= right.low_bits_;
+ high_bits_ |= right.high_bits_;
+ return *this;
+}
+
+BasicDecimal128& BasicDecimal128::operator&=(const BasicDecimal128& right) {
+ low_bits_ &= right.low_bits_;
+ high_bits_ &= right.high_bits_;
+ return *this;
+}
+
+BasicDecimal128& BasicDecimal128::operator<<=(uint32_t bits) {
+ if (bits != 0) {
+ if (bits < 64) {
+ high_bits_ = SafeLeftShift(high_bits_, bits);
+ high_bits_ |= (low_bits_ >> (64 - bits));
+ low_bits_ <<= bits;
+ } else if (bits < 128) {
+ high_bits_ = static_cast<int64_t>(low_bits_) << (bits - 64);
+ low_bits_ = 0;
+ } else {
+ high_bits_ = 0;
+ low_bits_ = 0;
+ }
+ }
+ return *this;
+}
+
+BasicDecimal128& BasicDecimal128::operator>>=(uint32_t bits) {
+ if (bits != 0) {
+ if (bits < 64) {
+ low_bits_ >>= bits;
+ low_bits_ |= static_cast<uint64_t>(high_bits_ << (64 - bits));
+ high_bits_ = static_cast<int64_t>(static_cast<uint64_t>(high_bits_) >> bits);
+ } else if (bits < 128) {
+ low_bits_ = static_cast<uint64_t>(high_bits_ >> (bits - 64));
+ high_bits_ = static_cast<int64_t>(high_bits_ >= 0L ? 0L : -1L);
+ } else {
+ high_bits_ = static_cast<int64_t>(high_bits_ >= 0L ? 0L : -1L);
+ low_bits_ = static_cast<uint64_t>(high_bits_);
+ }
+ }
+ return *this;
+}
+
+namespace {
+
+// Convenience wrapper type over 128 bit unsigned integers. We opt not to
+// replace the uint128_t type in int128_internal.h because it would require
+// significantly more implementation work to be done. This class merely
+// provides the minimum necessary set of functions to perform 128+ bit
+// multiplication operations when there may or may not be native support.
+#ifdef ARROW_USE_NATIVE_INT128
+struct uint128_t {
+ uint128_t() {}
+ uint128_t(uint64_t hi, uint64_t lo) : val_((static_cast<__uint128_t>(hi) << 64) | lo) {}
+ explicit uint128_t(const BasicDecimal128& decimal) {
+ val_ = (static_cast<__uint128_t>(decimal.high_bits()) << 64) | decimal.low_bits();
+ }
+
+ explicit uint128_t(uint64_t value) : val_(value) {}
+
+ uint64_t hi() { return val_ >> 64; }
+ uint64_t lo() { return val_ & kInt64Mask; }
+
+ uint128_t& operator+=(const uint128_t& other) {
+ val_ += other.val_;
+ return *this;
+ }
+
+ uint128_t& operator*=(const uint128_t& other) {
+ val_ *= other.val_;
+ return *this;
+ }
+
+ __uint128_t val_;
+};
+
+#else
+// Multiply two 64 bit word components into a 128 bit result, with high bits
+// stored in hi and low bits in lo.
+inline void ExtendAndMultiply(uint64_t x, uint64_t y, uint64_t* hi, uint64_t* lo) {
+ // Perform multiplication on two 64 bit words x and y into a 128 bit result
+ // by splitting up x and y into 32 bit high/low bit components,
+ // allowing us to represent the multiplication as
+ // x * y = x_lo * y_lo + x_hi * y_lo * 2^32 + y_hi * x_lo * 2^32
+ // + x_hi * y_hi * 2^64
+ //
+ // Now, consider the final output as lo_lo || lo_hi || hi_lo || hi_hi
+ // Therefore,
+ // lo_lo is (x_lo * y_lo)_lo,
+ // lo_hi is ((x_lo * y_lo)_hi + (x_hi * y_lo)_lo + (x_lo * y_hi)_lo)_lo,
+ // hi_lo is ((x_hi * y_hi)_lo + (x_hi * y_lo)_hi + (x_lo * y_hi)_hi)_hi,
+ // hi_hi is (x_hi * y_hi)_hi
+ const uint64_t x_lo = x & kInt32Mask;
+ const uint64_t y_lo = y & kInt32Mask;
+ const uint64_t x_hi = x >> 32;
+ const uint64_t y_hi = y >> 32;
+
+ const uint64_t t = x_lo * y_lo;
+ const uint64_t t_lo = t & kInt32Mask;
+ const uint64_t t_hi = t >> 32;
+
+ const uint64_t u = x_hi * y_lo + t_hi;
+ const uint64_t u_lo = u & kInt32Mask;
+ const uint64_t u_hi = u >> 32;
+
+ const uint64_t v = x_lo * y_hi + u_lo;
+ const uint64_t v_hi = v >> 32;
+
+ *hi = x_hi * y_hi + u_hi + v_hi;
+ *lo = (v << 32) + t_lo;
+}
+
+struct uint128_t {
+ uint128_t() {}
+ uint128_t(uint64_t hi, uint64_t lo) : hi_(hi), lo_(lo) {}
+ explicit uint128_t(const BasicDecimal128& decimal) {
+ hi_ = decimal.high_bits();
+ lo_ = decimal.low_bits();
+ }
+
+ uint64_t hi() const { return hi_; }
+ uint64_t lo() const { return lo_; }
+
+ uint128_t& operator+=(const uint128_t& other) {
+ // To deduce the carry bit, we perform "65 bit" addition on the low bits and
+ // seeing if the resulting high bit is 1. This is accomplished by shifting the
+ // low bits to the right by 1 (chopping off the lowest bit), then adding 1 if the
+ // result of adding the two chopped bits would have produced a carry.
+ uint64_t carry = (((lo_ & other.lo_) & 1) + (lo_ >> 1) + (other.lo_ >> 1)) >> 63;
+ hi_ += other.hi_ + carry;
+ lo_ += other.lo_;
+ return *this;
+ }
+
+ uint128_t& operator*=(const uint128_t& other) {
+ uint128_t r;
+ ExtendAndMultiply(lo_, other.lo_, &r.hi_, &r.lo_);
+ r.hi_ += (hi_ * other.lo_) + (lo_ * other.hi_);
+ *this = r;
+ return *this;
+ }
+
+ uint64_t hi_;
+ uint64_t lo_;
+};
+#endif
+
+// Multiplies two N * 64 bit unsigned integer types, represented by a uint64_t
+// array into a same sized output. Elements in the array should be in
+// native endian order, and output will be the same. Overflow in multiplication
+// will result in the lower N * 64 bits of the result being set.
+template <int N>
+inline void MultiplyUnsignedArray(const std::array<uint64_t, N>& lh,
+ const std::array<uint64_t, N>& rh,
+ std::array<uint64_t, N>* result) {
+ const auto lh_le = BitUtil::LittleEndianArray::Make(lh);
+ const auto rh_le = BitUtil::LittleEndianArray::Make(rh);
+ auto result_le = BitUtil::LittleEndianArray::Make(result);
+
+ for (int j = 0; j < N; ++j) {
+ uint64_t carry = 0;
+ for (int i = 0; i < N - j; ++i) {
+ uint128_t tmp(lh_le[i]);
+ tmp *= uint128_t(rh_le[j]);
+ tmp += uint128_t(result_le[i + j]);
+ tmp += uint128_t(carry);
+ result_le[i + j] = tmp.lo();
+ carry = tmp.hi();
+ }
+ }
+}
+
+} // namespace
+
+BasicDecimal128& BasicDecimal128::operator*=(const BasicDecimal128& right) {
+ // Since the max value of BasicDecimal128 is supposed to be 1e38 - 1 and the
+ // min the negation taking the absolute values here should always be safe.
+ const bool negate = Sign() != right.Sign();
+ BasicDecimal128 x = BasicDecimal128::Abs(*this);
+ BasicDecimal128 y = BasicDecimal128::Abs(right);
+ uint128_t r(x);
+ r *= uint128_t{y};
+ high_bits_ = r.hi();
+ low_bits_ = r.lo();
+ if (negate) {
+ Negate();
+ }
+ return *this;
+}
+
+/// Expands the given native endian array of uint64_t into a big endian array of
+/// uint32_t. The value of input array is expected to be non-negative. The result_array
+/// will remove leading zeros from the input array.
+/// \param value_array a native endian array to represent the value
+/// \param result_array a big endian array of length N*2 to set with the value
+/// \result the output length of the array
+template <size_t N>
+static int64_t FillInArray(const std::array<uint64_t, N>& value_array,
+ uint32_t* result_array) {
+ const auto value_array_le = BitUtil::LittleEndianArray::Make(value_array);
+ int64_t next_index = 0;
+ // 1st loop to find out 1st non-negative value in input
+ int64_t i = N - 1;
+ for (; i >= 0; i--) {
+ if (value_array_le[i] != 0) {
+ if (value_array_le[i] <= std::numeric_limits<uint32_t>::max()) {
+ result_array[next_index++] = static_cast<uint32_t>(value_array_le[i]);
+ i--;
+ }
+ break;
+ }
+ }
+ // 2nd loop to fill in the rest of the array.
+ for (int64_t j = i; j >= 0; j--) {
+ result_array[next_index++] = static_cast<uint32_t>(value_array_le[j] >> 32);
+ result_array[next_index++] = static_cast<uint32_t>(value_array_le[j]);
+ }
+ return next_index;
+}
+
+/// Expands the given value into a big endian array of ints so that we can work on
+/// it. The array will be converted to an absolute value and the was_negative
+/// flag will be set appropriately. The array will remove leading zeros from
+/// the value.
+/// \param array a big endian array of length 4 to set with the value
+/// \param was_negative a flag for whether the value was original negative
+/// \result the output length of the array
+static int64_t FillInArray(const BasicDecimal128& value, uint32_t* array,
+ bool& was_negative) {
+ BasicDecimal128 abs_value = BasicDecimal128::Abs(value);
+ was_negative = value.high_bits() < 0;
+ uint64_t high = static_cast<uint64_t>(abs_value.high_bits());
+ uint64_t low = abs_value.low_bits();
+
+ // FillInArray(std::array<uint64_t, N>& value_array, uint32_t* result_array) is not
+ // called here as the following code has better performance, to avoid regression on
+ // BasicDecimal128 Division.
+ if (high != 0) {
+ if (high > std::numeric_limits<uint32_t>::max()) {
+ array[0] = static_cast<uint32_t>(high >> 32);
+ array[1] = static_cast<uint32_t>(high);
+ array[2] = static_cast<uint32_t>(low >> 32);
+ array[3] = static_cast<uint32_t>(low);
+ return 4;
+ }
+
+ array[0] = static_cast<uint32_t>(high);
+ array[1] = static_cast<uint32_t>(low >> 32);
+ array[2] = static_cast<uint32_t>(low);
+ return 3;
+ }
+
+ if (low > std::numeric_limits<uint32_t>::max()) {
+ array[0] = static_cast<uint32_t>(low >> 32);
+ array[1] = static_cast<uint32_t>(low);
+ return 2;
+ }
+
+ if (low == 0) {
+ return 0;
+ }
+
+ array[0] = static_cast<uint32_t>(low);
+ return 1;
+}
+
+/// Expands the given value into a big endian array of ints so that we can work on
+/// it. The array will be converted to an absolute value and the was_negative
+/// flag will be set appropriately. The array will remove leading zeros from
+/// the value.
+/// \param array a big endian array of length 8 to set with the value
+/// \param was_negative a flag for whether the value was original negative
+/// \result the output length of the array
+static int64_t FillInArray(const BasicDecimal256& value, uint32_t* array,
+ bool& was_negative) {
+ BasicDecimal256 positive_value = value;
+ was_negative = false;
+ if (positive_value.IsNegative()) {
+ positive_value.Negate();
+ was_negative = true;
+ }
+ return FillInArray<4>(positive_value.native_endian_array(), array);
+}
+
+/// Shift the number in the array left by bits positions.
+/// \param array the number to shift, must have length elements
+/// \param length the number of entries in the array
+/// \param bits the number of bits to shift (0 <= bits < 32)
+static void ShiftArrayLeft(uint32_t* array, int64_t length, int64_t bits) {
+ if (length > 0 && bits != 0) {
+ for (int64_t i = 0; i < length - 1; ++i) {
+ array[i] = (array[i] << bits) | (array[i + 1] >> (32 - bits));
+ }
+ array[length - 1] <<= bits;
+ }
+}
+
+/// Shift the number in the array right by bits positions.
+/// \param array the number to shift, must have length elements
+/// \param length the number of entries in the array
+/// \param bits the number of bits to shift (0 <= bits < 32)
+static inline void ShiftArrayRight(uint32_t* array, int64_t length, int64_t bits) {
+ if (length > 0 && bits != 0) {
+ for (int64_t i = length - 1; i > 0; --i) {
+ array[i] = (array[i] >> bits) | (array[i - 1] << (32 - bits));
+ }
+ array[0] >>= bits;
+ }
+}
+
+/// \brief Fix the signs of the result and remainder at the end of the division based on
+/// the signs of the dividend and divisor.
+template <class DecimalClass>
+static inline void FixDivisionSigns(DecimalClass* result, DecimalClass* remainder,
+ bool dividend_was_negative,
+ bool divisor_was_negative) {
+ if (dividend_was_negative != divisor_was_negative) {
+ result->Negate();
+ }
+
+ if (dividend_was_negative) {
+ remainder->Negate();
+ }
+}
+
+/// \brief Build a native endian array of uint64_t from a big endian array of uint32_t.
+template <size_t N>
+static DecimalStatus BuildFromArray(std::array<uint64_t, N>* result_array,
+ const uint32_t* array, int64_t length) {
+ for (int64_t i = length - 2 * N - 1; i >= 0; i--) {
+ if (array[i] != 0) {
+ return DecimalStatus::kOverflow;
+ }
+ }
+ int64_t next_index = length - 1;
+ size_t i = 0;
+ auto result_array_le = BitUtil::LittleEndianArray::Make(result_array);
+ for (; i < N && next_index >= 0; i++) {
+ uint64_t lower_bits = array[next_index--];
+ result_array_le[i] =
+ (next_index < 0)
+ ? lower_bits
+ : ((static_cast<uint64_t>(array[next_index--]) << 32) + lower_bits);
+ }
+ for (; i < N; i++) {
+ result_array_le[i] = 0;
+ }
+ return DecimalStatus::kSuccess;
+}
+
+/// \brief Build a BasicDecimal128 from a big endian array of uint32_t.
+static DecimalStatus BuildFromArray(BasicDecimal128* value, const uint32_t* array,
+ int64_t length) {
+ std::array<uint64_t, 2> result_array;
+ auto status = BuildFromArray(&result_array, array, length);
+ if (status != DecimalStatus::kSuccess) {
+ return status;
+ }
+ const auto result_array_le = BitUtil::LittleEndianArray::Make(result_array);
+ *value = {static_cast<int64_t>(result_array_le[1]), result_array_le[0]};
+ return DecimalStatus::kSuccess;
+}
+
+/// \brief Build a BasicDecimal256 from a big endian array of uint32_t.
+static DecimalStatus BuildFromArray(BasicDecimal256* value, const uint32_t* array,
+ int64_t length) {
+ std::array<uint64_t, 4> result_array;
+ auto status = BuildFromArray(&result_array, array, length);
+ if (status != DecimalStatus::kSuccess) {
+ return status;
+ }
+ *value = result_array;
+ return DecimalStatus::kSuccess;
+}
+
+/// \brief Do a division where the divisor fits into a single 32 bit value.
+template <class DecimalClass>
+static inline DecimalStatus SingleDivide(const uint32_t* dividend,
+ int64_t dividend_length, uint32_t divisor,
+ DecimalClass* remainder,
+ bool dividend_was_negative,
+ bool divisor_was_negative,
+ DecimalClass* result) {
+ uint64_t r = 0;
+ constexpr int64_t kDecimalArrayLength = DecimalClass::kBitWidth / sizeof(uint32_t) + 1;
+ uint32_t result_array[kDecimalArrayLength];
+ for (int64_t j = 0; j < dividend_length; j++) {
+ r <<= 32;
+ r += dividend[j];
+ result_array[j] = static_cast<uint32_t>(r / divisor);
+ r %= divisor;
+ }
+ auto status = BuildFromArray(result, result_array, dividend_length);
+ if (status != DecimalStatus::kSuccess) {
+ return status;
+ }
+
+ *remainder = static_cast<int64_t>(r);
+ FixDivisionSigns(result, remainder, dividend_was_negative, divisor_was_negative);
+ return DecimalStatus::kSuccess;
+}
+
+/// \brief Do a decimal division with remainder.
+template <class DecimalClass>
+static inline DecimalStatus DecimalDivide(const DecimalClass& dividend,
+ const DecimalClass& divisor,
+ DecimalClass* result, DecimalClass* remainder) {
+ constexpr int64_t kDecimalArrayLength = DecimalClass::kBitWidth / sizeof(uint32_t);
+ // Split the dividend and divisor into integer pieces so that we can
+ // work on them.
+ uint32_t dividend_array[kDecimalArrayLength + 1];
+ uint32_t divisor_array[kDecimalArrayLength];
+ bool dividend_was_negative;
+ bool divisor_was_negative;
+ // leave an extra zero before the dividend
+ dividend_array[0] = 0;
+ int64_t dividend_length =
+ FillInArray(dividend, dividend_array + 1, dividend_was_negative) + 1;
+ int64_t divisor_length = FillInArray(divisor, divisor_array, divisor_was_negative);
+
+ // Handle some of the easy cases.
+ if (dividend_length <= divisor_length) {
+ *remainder = dividend;
+ *result = 0;
+ return DecimalStatus::kSuccess;
+ }
+
+ if (divisor_length == 0) {
+ return DecimalStatus::kDivideByZero;
+ }
+
+ if (divisor_length == 1) {
+ return SingleDivide(dividend_array, dividend_length, divisor_array[0], remainder,
+ dividend_was_negative, divisor_was_negative, result);
+ }
+
+ int64_t result_length = dividend_length - divisor_length;
+ uint32_t result_array[kDecimalArrayLength];
+ DCHECK_LE(result_length, kDecimalArrayLength);
+
+ // Normalize by shifting both by a multiple of 2 so that
+ // the digit guessing is better. The requirement is that
+ // divisor_array[0] is greater than 2**31.
+ int64_t normalize_bits = BitUtil::CountLeadingZeros(divisor_array[0]);
+ ShiftArrayLeft(divisor_array, divisor_length, normalize_bits);
+ ShiftArrayLeft(dividend_array, dividend_length, normalize_bits);
+
+ // compute each digit in the result
+ for (int64_t j = 0; j < result_length; ++j) {
+ // Guess the next digit. At worst it is two too large
+ uint32_t guess = std::numeric_limits<uint32_t>::max();
+ const auto high_dividend =
+ static_cast<uint64_t>(dividend_array[j]) << 32 | dividend_array[j + 1];
+ if (dividend_array[j] != divisor_array[0]) {
+ guess = static_cast<uint32_t>(high_dividend / divisor_array[0]);
+ }
+
+ // catch all of the cases where guess is two too large and most of the
+ // cases where it is one too large
+ auto rhat = static_cast<uint32_t>(high_dividend -
+ guess * static_cast<uint64_t>(divisor_array[0]));
+ while (static_cast<uint64_t>(divisor_array[1]) * guess >
+ (static_cast<uint64_t>(rhat) << 32) + dividend_array[j + 2]) {
+ --guess;
+ rhat += divisor_array[0];
+ if (static_cast<uint64_t>(rhat) < divisor_array[0]) {
+ break;
+ }
+ }
+
+ // subtract off the guess * divisor from the dividend
+ uint64_t mult = 0;
+ for (int64_t i = divisor_length - 1; i >= 0; --i) {
+ mult += static_cast<uint64_t>(guess) * divisor_array[i];
+ uint32_t prev = dividend_array[j + i + 1];
+ dividend_array[j + i + 1] -= static_cast<uint32_t>(mult);
+ mult >>= 32;
+ if (dividend_array[j + i + 1] > prev) {
+ ++mult;
+ }
+ }
+ uint32_t prev = dividend_array[j];
+ dividend_array[j] -= static_cast<uint32_t>(mult);
+
+ // if guess was too big, we add back divisor
+ if (dividend_array[j] > prev) {
+ --guess;
+ uint32_t carry = 0;
+ for (int64_t i = divisor_length - 1; i >= 0; --i) {
+ const auto sum =
+ static_cast<uint64_t>(divisor_array[i]) + dividend_array[j + i + 1] + carry;
+ dividend_array[j + i + 1] = static_cast<uint32_t>(sum);
+ carry = static_cast<uint32_t>(sum >> 32);
+ }
+ dividend_array[j] += carry;
+ }
+
+ result_array[j] = guess;
+ }
+
+ // denormalize the remainder
+ ShiftArrayRight(dividend_array, dividend_length, normalize_bits);
+
+ // return result and remainder
+ auto status = BuildFromArray(result, result_array, result_length);
+ if (status != DecimalStatus::kSuccess) {
+ return status;
+ }
+ status = BuildFromArray(remainder, dividend_array, dividend_length);
+ if (status != DecimalStatus::kSuccess) {
+ return status;
+ }
+
+ FixDivisionSigns(result, remainder, dividend_was_negative, divisor_was_negative);
+ return DecimalStatus::kSuccess;
+}
+
+DecimalStatus BasicDecimal128::Divide(const BasicDecimal128& divisor,
+ BasicDecimal128* result,
+ BasicDecimal128* remainder) const {
+ return DecimalDivide(*this, divisor, result, remainder);
+}
+
+bool operator==(const BasicDecimal128& left, const BasicDecimal128& right) {
+ return left.high_bits() == right.high_bits() && left.low_bits() == right.low_bits();
+}
+
+bool operator!=(const BasicDecimal128& left, const BasicDecimal128& right) {
+ return !operator==(left, right);
+}
+
+bool operator<(const BasicDecimal128& left, const BasicDecimal128& right) {
+ return left.high_bits() < right.high_bits() ||
+ (left.high_bits() == right.high_bits() && left.low_bits() < right.low_bits());
+}
+
+bool operator<=(const BasicDecimal128& left, const BasicDecimal128& right) {
+ return !operator>(left, right);
+}
+
+bool operator>(const BasicDecimal128& left, const BasicDecimal128& right) {
+ return operator<(right, left);
+}
+
+bool operator>=(const BasicDecimal128& left, const BasicDecimal128& right) {
+ return !operator<(left, right);
+}
+
+BasicDecimal128 operator-(const BasicDecimal128& operand) {
+ BasicDecimal128 result(operand.high_bits(), operand.low_bits());
+ return result.Negate();
+}
+
+BasicDecimal128 operator~(const BasicDecimal128& operand) {
+ BasicDecimal128 result(~operand.high_bits(), ~operand.low_bits());
+ return result;
+}
+
+BasicDecimal128 operator+(const BasicDecimal128& left, const BasicDecimal128& right) {
+ BasicDecimal128 result(left.high_bits(), left.low_bits());
+ result += right;
+ return result;
+}
+
+BasicDecimal128 operator-(const BasicDecimal128& left, const BasicDecimal128& right) {
+ BasicDecimal128 result(left.high_bits(), left.low_bits());
+ result -= right;
+ return result;
+}
+
+BasicDecimal128 operator*(const BasicDecimal128& left, const BasicDecimal128& right) {
+ BasicDecimal128 result(left.high_bits(), left.low_bits());
+ result *= right;
+ return result;
+}
+
+BasicDecimal128 operator/(const BasicDecimal128& left, const BasicDecimal128& right) {
+ BasicDecimal128 remainder;
+ BasicDecimal128 result;
+ auto s = left.Divide(right, &result, &remainder);
+ DCHECK_EQ(s, DecimalStatus::kSuccess);
+ return result;
+}
+
+BasicDecimal128 operator%(const BasicDecimal128& left, const BasicDecimal128& right) {
+ BasicDecimal128 remainder;
+ BasicDecimal128 result;
+ auto s = left.Divide(right, &result, &remainder);
+ DCHECK_EQ(s, DecimalStatus::kSuccess);
+ return remainder;
+}
+
+template <class DecimalClass>
+static bool RescaleWouldCauseDataLoss(const DecimalClass& value, int32_t delta_scale,
+ const DecimalClass& multiplier,
+ DecimalClass* result) {
+ if (delta_scale < 0) {
+ DCHECK_NE(multiplier, 0);
+ DecimalClass remainder;
+ auto status = value.Divide(multiplier, result, &remainder);
+ DCHECK_EQ(status, DecimalStatus::kSuccess);
+ return remainder != 0;
+ }
+
+ *result = value * multiplier;
+ return (value < 0) ? *result > value : *result < value;
+}
+
+template <class DecimalClass>
+DecimalStatus DecimalRescale(const DecimalClass& value, int32_t original_scale,
+ int32_t new_scale, DecimalClass* out) {
+ DCHECK_NE(out, nullptr);
+
+ if (original_scale == new_scale) {
+ *out = value;
+ return DecimalStatus::kSuccess;
+ }
+
+ const int32_t delta_scale = new_scale - original_scale;
+ const int32_t abs_delta_scale = std::abs(delta_scale);
+
+ DecimalClass multiplier = DecimalClass::GetScaleMultiplier(abs_delta_scale);
+
+ const bool rescale_would_cause_data_loss =
+ RescaleWouldCauseDataLoss(value, delta_scale, multiplier, out);
+
+ // Fail if we overflow or truncate
+ if (ARROW_PREDICT_FALSE(rescale_would_cause_data_loss)) {
+ return DecimalStatus::kRescaleDataLoss;
+ }
+
+ return DecimalStatus::kSuccess;
+}
+
+DecimalStatus BasicDecimal128::Rescale(int32_t original_scale, int32_t new_scale,
+ BasicDecimal128* out) const {
+ return DecimalRescale(*this, original_scale, new_scale, out);
+}
+
+void BasicDecimal128::GetWholeAndFraction(int scale, BasicDecimal128* whole,
+ BasicDecimal128* fraction) const {
+ DCHECK_GE(scale, 0);
+ DCHECK_LE(scale, 38);
+
+ BasicDecimal128 multiplier(ScaleMultipliers[scale]);
+ auto s = Divide(multiplier, whole, fraction);
+ DCHECK_EQ(s, DecimalStatus::kSuccess);
+}
+
+const BasicDecimal128& BasicDecimal128::GetScaleMultiplier(int32_t scale) {
+ DCHECK_GE(scale, 0);
+ DCHECK_LE(scale, 38);
+
+ return ScaleMultipliers[scale];
+}
+
+const BasicDecimal128& BasicDecimal128::GetHalfScaleMultiplier(int32_t scale) {
+ DCHECK_GE(scale, 0);
+ DCHECK_LE(scale, 38);
+
+ return ScaleMultipliersHalf[scale];
+}
+
+const BasicDecimal128& BasicDecimal128::GetMaxValue() { return kMaxValue; }
+
+BasicDecimal128 BasicDecimal128::IncreaseScaleBy(int32_t increase_by) const {
+ DCHECK_GE(increase_by, 0);
+ DCHECK_LE(increase_by, 38);
+
+ return (*this) * ScaleMultipliers[increase_by];
+}
+
+BasicDecimal128 BasicDecimal128::ReduceScaleBy(int32_t reduce_by, bool round) const {
+ DCHECK_GE(reduce_by, 0);
+ DCHECK_LE(reduce_by, 38);
+
+ if (reduce_by == 0) {
+ return *this;
+ }
+
+ BasicDecimal128 divisor(ScaleMultipliers[reduce_by]);
+ BasicDecimal128 result;
+ BasicDecimal128 remainder;
+ auto s = Divide(divisor, &result, &remainder);
+ DCHECK_EQ(s, DecimalStatus::kSuccess);
+ if (round) {
+ auto divisor_half = ScaleMultipliersHalf[reduce_by];
+ if (remainder.Abs() >= divisor_half) {
+ if (result > 0) {
+ result += 1;
+ } else {
+ result -= 1;
+ }
+ }
+ }
+ return result;
+}
+
+int32_t BasicDecimal128::CountLeadingBinaryZeros() const {
+ DCHECK_GE(*this, BasicDecimal128(0));
+
+ if (high_bits_ == 0) {
+ return BitUtil::CountLeadingZeros(low_bits_) + 64;
+ } else {
+ return BitUtil::CountLeadingZeros(static_cast<uint64_t>(high_bits_));
+ }
+}
+
+BasicDecimal256::BasicDecimal256(const uint8_t* bytes)
+ : array_({reinterpret_cast<const uint64_t*>(bytes)[0],
+ reinterpret_cast<const uint64_t*>(bytes)[1],
+ reinterpret_cast<const uint64_t*>(bytes)[2],
+ reinterpret_cast<const uint64_t*>(bytes)[3]}) {}
+
+constexpr int BasicDecimal256::kBitWidth;
+constexpr int BasicDecimal256::kMaxPrecision;
+constexpr int BasicDecimal256::kMaxScale;
+
+BasicDecimal256& BasicDecimal256::Negate() {
+ auto array_le = BitUtil::LittleEndianArray::Make(&array_);
+ uint64_t carry = 1;
+ for (size_t i = 0; i < array_.size(); ++i) {
+ uint64_t& elem = array_le[i];
+ elem = ~elem + carry;
+ carry &= (elem == 0);
+ }
+ return *this;
+}
+
+BasicDecimal256& BasicDecimal256::Abs() { return *this < 0 ? Negate() : *this; }
+
+BasicDecimal256 BasicDecimal256::Abs(const BasicDecimal256& in) {
+ BasicDecimal256 result(in);
+ return result.Abs();
+}
+
+BasicDecimal256& BasicDecimal256::operator+=(const BasicDecimal256& right) {
+ auto array_le = BitUtil::LittleEndianArray::Make(&array_);
+ const auto right_array_le = BitUtil::LittleEndianArray::Make(right.array_);
+ uint64_t carry = 0;
+ for (size_t i = 0; i < array_.size(); i++) {
+ const uint64_t right_value = right_array_le[i];
+ uint64_t sum = right_value + carry;
+ carry = 0;
+ if (sum < right_value) {
+ carry += 1;
+ }
+ sum += array_le[i];
+ if (sum < array_le[i]) {
+ carry += 1;
+ }
+ array_le[i] = sum;
+ }
+ return *this;
+}
+
+BasicDecimal256& BasicDecimal256::operator-=(const BasicDecimal256& right) {
+ *this += -right;
+ return *this;
+}
+
+BasicDecimal256& BasicDecimal256::operator<<=(uint32_t bits) {
+ if (bits == 0) {
+ return *this;
+ }
+ int cross_word_shift = bits / 64;
+ if (static_cast<size_t>(cross_word_shift) >= array_.size()) {
+ array_ = {0, 0, 0, 0};
+ return *this;
+ }
+ uint32_t in_word_shift = bits % 64;
+ auto array_le = BitUtil::LittleEndianArray::Make(&array_);
+ for (int i = static_cast<int>(array_.size() - 1); i >= cross_word_shift; i--) {
+ // Account for shifts larger then 64 bits
+ array_le[i] = array_le[i - cross_word_shift];
+ array_le[i] <<= in_word_shift;
+ if (in_word_shift != 0 && i >= cross_word_shift + 1) {
+ array_le[i] |= array_le[i - (cross_word_shift + 1)] >> (64 - in_word_shift);
+ }
+ }
+ for (int i = cross_word_shift - 1; i >= 0; i--) {
+ array_le[i] = 0;
+ }
+ return *this;
+}
+
+std::array<uint8_t, 32> BasicDecimal256::ToBytes() const {
+ std::array<uint8_t, 32> out{{0}};
+ ToBytes(out.data());
+ return out;
+}
+
+void BasicDecimal256::ToBytes(uint8_t* out) const {
+ DCHECK_NE(out, nullptr);
+ reinterpret_cast<uint64_t*>(out)[0] = array_[0];
+ reinterpret_cast<uint64_t*>(out)[1] = array_[1];
+ reinterpret_cast<uint64_t*>(out)[2] = array_[2];
+ reinterpret_cast<uint64_t*>(out)[3] = array_[3];
+}
+
+BasicDecimal256& BasicDecimal256::operator*=(const BasicDecimal256& right) {
+ // Since the max value of BasicDecimal256 is supposed to be 1e76 - 1 and the
+ // min the negation taking the absolute values here should always be safe.
+ const bool negate = Sign() != right.Sign();
+ BasicDecimal256 x = BasicDecimal256::Abs(*this);
+ BasicDecimal256 y = BasicDecimal256::Abs(right);
+
+ std::array<uint64_t, 4> res{0, 0, 0, 0};
+ MultiplyUnsignedArray<4>(x.array_, y.array_, &res);
+ array_ = res;
+ if (negate) {
+ Negate();
+ }
+ return *this;
+}
+
+DecimalStatus BasicDecimal256::Divide(const BasicDecimal256& divisor,
+ BasicDecimal256* result,
+ BasicDecimal256* remainder) const {
+ return DecimalDivide(*this, divisor, result, remainder);
+}
+
+DecimalStatus BasicDecimal256::Rescale(int32_t original_scale, int32_t new_scale,
+ BasicDecimal256* out) const {
+ return DecimalRescale(*this, original_scale, new_scale, out);
+}
+
+BasicDecimal256 BasicDecimal256::IncreaseScaleBy(int32_t increase_by) const {
+ DCHECK_GE(increase_by, 0);
+ DCHECK_LE(increase_by, 76);
+
+ return (*this) * ScaleMultipliersDecimal256[increase_by];
+}
+
+BasicDecimal256 BasicDecimal256::ReduceScaleBy(int32_t reduce_by, bool round) const {
+ DCHECK_GE(reduce_by, 0);
+ DCHECK_LE(reduce_by, 76);
+
+ if (reduce_by == 0) {
+ return *this;
+ }
+
+ BasicDecimal256 divisor(ScaleMultipliersDecimal256[reduce_by]);
+ BasicDecimal256 result;
+ BasicDecimal256 remainder;
+ auto s = Divide(divisor, &result, &remainder);
+ DCHECK_EQ(s, DecimalStatus::kSuccess);
+ if (round) {
+ auto divisor_half = ScaleMultipliersHalfDecimal256[reduce_by];
+ if (remainder.Abs() >= divisor_half) {
+ if (result > 0) {
+ result += 1;
+ } else {
+ result -= 1;
+ }
+ }
+ }
+ return result;
+}
+
+bool BasicDecimal256::FitsInPrecision(int32_t precision) const {
+ DCHECK_GT(precision, 0);
+ DCHECK_LE(precision, 76);
+ return BasicDecimal256::Abs(*this) < ScaleMultipliersDecimal256[precision];
+}
+
+const BasicDecimal256& BasicDecimal256::GetScaleMultiplier(int32_t scale) {
+ DCHECK_GE(scale, 0);
+ DCHECK_LE(scale, 76);
+
+ return ScaleMultipliersDecimal256[scale];
+}
+
+const BasicDecimal256& BasicDecimal256::GetHalfScaleMultiplier(int32_t scale) {
+ DCHECK_GE(scale, 0);
+ DCHECK_LE(scale, 76);
+
+ return ScaleMultipliersHalfDecimal256[scale];
+}
+
+BasicDecimal256 operator*(const BasicDecimal256& left, const BasicDecimal256& right) {
+ BasicDecimal256 result = left;
+ result *= right;
+ return result;
+}
+
+bool operator<(const BasicDecimal256& left, const BasicDecimal256& right) {
+ const auto lhs_le = BitUtil::LittleEndianArray::Make(left.native_endian_array());
+ const auto rhs_le = BitUtil::LittleEndianArray::Make(right.native_endian_array());
+ return lhs_le[3] != rhs_le[3]
+ ? static_cast<int64_t>(lhs_le[3]) < static_cast<int64_t>(rhs_le[3])
+ : lhs_le[2] != rhs_le[2] ? lhs_le[2] < rhs_le[2]
+ : lhs_le[1] != rhs_le[1] ? lhs_le[1] < rhs_le[1]
+ : lhs_le[0] < rhs_le[0];
+}
+
+BasicDecimal256 operator-(const BasicDecimal256& operand) {
+ BasicDecimal256 result(operand);
+ return result.Negate();
+}
+
+BasicDecimal256 operator~(const BasicDecimal256& operand) {
+ const std::array<uint64_t, 4>& arr = operand.native_endian_array();
+ BasicDecimal256 result({~arr[0], ~arr[1], ~arr[2], ~arr[3]});
+ return result;
+}
+
+BasicDecimal256& BasicDecimal256::operator/=(const BasicDecimal256& right) {
+ BasicDecimal256 remainder;
+ auto s = Divide(right, this, &remainder);
+ DCHECK_EQ(s, DecimalStatus::kSuccess);
+ return *this;
+}
+
+BasicDecimal256 operator+(const BasicDecimal256& left, const BasicDecimal256& right) {
+ BasicDecimal256 sum = left;
+ sum += right;
+ return sum;
+}
+
+BasicDecimal256 operator/(const BasicDecimal256& left, const BasicDecimal256& right) {
+ BasicDecimal256 remainder;
+ BasicDecimal256 result;
+ auto s = left.Divide(right, &result, &remainder);
+ DCHECK_EQ(s, DecimalStatus::kSuccess);
+ return result;
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/basic_decimal.h b/src/arrow/cpp/src/arrow/util/basic_decimal.h
new file mode 100644
index 000000000..a4df32855
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/basic_decimal.h
@@ -0,0 +1,494 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <array>
+#include <cstdint>
+#include <limits>
+#include <string>
+#include <type_traits>
+
+#include "arrow/util/endian.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/type_traits.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+enum class DecimalStatus {
+ kSuccess,
+ kDivideByZero,
+ kOverflow,
+ kRescaleDataLoss,
+};
+
+/// Represents a signed 128-bit integer in two's complement.
+///
+/// This class is also compiled into LLVM IR - so, it should not have cpp references like
+/// streams and boost.
+class ARROW_EXPORT BasicDecimal128 {
+ struct LittleEndianArrayTag {};
+
+ public:
+ static constexpr int kBitWidth = 128;
+ static constexpr int kMaxPrecision = 38;
+ static constexpr int kMaxScale = 38;
+
+ // A constructor tag to introduce a little-endian encoded array
+ static constexpr LittleEndianArrayTag LittleEndianArray{};
+
+ /// \brief Create a BasicDecimal128 from the two's complement representation.
+#if ARROW_LITTLE_ENDIAN
+ constexpr BasicDecimal128(int64_t high, uint64_t low) noexcept
+ : low_bits_(low), high_bits_(high) {}
+#else
+ constexpr BasicDecimal128(int64_t high, uint64_t low) noexcept
+ : high_bits_(high), low_bits_(low) {}
+#endif
+
+ /// \brief Create a BasicDecimal256 from the two's complement representation.
+ ///
+ /// Input array is assumed to be in native endianness.
+#if ARROW_LITTLE_ENDIAN
+ constexpr BasicDecimal128(const std::array<uint64_t, 2>& array) noexcept
+ : low_bits_(array[0]), high_bits_(static_cast<int64_t>(array[1])) {}
+#else
+ constexpr BasicDecimal128(const std::array<uint64_t, 2>& array) noexcept
+ : high_bits_(static_cast<int64_t>(array[0])), low_bits_(array[1]) {}
+#endif
+
+ /// \brief Create a BasicDecimal128 from the two's complement representation.
+ ///
+ /// Input array is assumed to be in little endianness, with native endian elements.
+ BasicDecimal128(LittleEndianArrayTag, const std::array<uint64_t, 2>& array) noexcept
+ : BasicDecimal128(BitUtil::LittleEndianArray::ToNative(array)) {}
+
+ /// \brief Empty constructor creates a BasicDecimal128 with a value of 0.
+ constexpr BasicDecimal128() noexcept : BasicDecimal128(0, 0) {}
+
+ /// \brief Convert any integer value into a BasicDecimal128.
+ template <typename T,
+ typename = typename std::enable_if<
+ std::is_integral<T>::value && (sizeof(T) <= sizeof(uint64_t)), T>::type>
+ constexpr BasicDecimal128(T value) noexcept
+ : BasicDecimal128(value >= T{0} ? 0 : -1, static_cast<uint64_t>(value)) { // NOLINT
+ }
+
+ /// \brief Create a BasicDecimal128 from an array of bytes. Bytes are assumed to be in
+ /// native-endian byte order.
+ explicit BasicDecimal128(const uint8_t* bytes);
+
+ /// \brief Negate the current value (in-place)
+ BasicDecimal128& Negate();
+
+ /// \brief Absolute value (in-place)
+ BasicDecimal128& Abs();
+
+ /// \brief Absolute value
+ static BasicDecimal128 Abs(const BasicDecimal128& left);
+
+ /// \brief Add a number to this one. The result is truncated to 128 bits.
+ BasicDecimal128& operator+=(const BasicDecimal128& right);
+
+ /// \brief Subtract a number from this one. The result is truncated to 128 bits.
+ BasicDecimal128& operator-=(const BasicDecimal128& right);
+
+ /// \brief Multiply this number by another number. The result is truncated to 128 bits.
+ BasicDecimal128& operator*=(const BasicDecimal128& right);
+
+ /// Divide this number by right and return the result.
+ ///
+ /// This operation is not destructive.
+ /// The answer rounds to zero. Signs work like:
+ /// 21 / 5 -> 4, 1
+ /// -21 / 5 -> -4, -1
+ /// 21 / -5 -> -4, 1
+ /// -21 / -5 -> 4, -1
+ /// \param[in] divisor the number to divide by
+ /// \param[out] result the quotient
+ /// \param[out] remainder the remainder after the division
+ DecimalStatus Divide(const BasicDecimal128& divisor, BasicDecimal128* result,
+ BasicDecimal128* remainder) const;
+
+ /// \brief In-place division.
+ BasicDecimal128& operator/=(const BasicDecimal128& right);
+
+ /// \brief Bitwise "or" between two BasicDecimal128.
+ BasicDecimal128& operator|=(const BasicDecimal128& right);
+
+ /// \brief Bitwise "and" between two BasicDecimal128.
+ BasicDecimal128& operator&=(const BasicDecimal128& right);
+
+ /// \brief Shift left by the given number of bits.
+ BasicDecimal128& operator<<=(uint32_t bits);
+
+ /// \brief Shift right by the given number of bits. Negative values will
+ BasicDecimal128& operator>>=(uint32_t bits);
+
+ /// \brief Get the high bits of the two's complement representation of the number.
+ inline constexpr int64_t high_bits() const { return high_bits_; }
+
+ /// \brief Get the low bits of the two's complement representation of the number.
+ inline constexpr uint64_t low_bits() const { return low_bits_; }
+
+ /// \brief Get the bits of the two's complement representation of the number.
+ ///
+ /// The 2 elements are in native endian order. The bits within each uint64_t element
+ /// are in native endian order. For example, on a little endian machine,
+ /// BasicDecimal128(123).native_endian_array() = {123, 0};
+ /// but on a big endian machine,
+ /// BasicDecimal128(123).native_endian_array() = {0, 123};
+ inline std::array<uint64_t, 2> native_endian_array() const {
+#if ARROW_LITTLE_ENDIAN
+ return {low_bits_, static_cast<uint64_t>(high_bits_)};
+#else
+ return {static_cast<uint64_t>(high_bits_), low_bits_};
+#endif
+ }
+
+ /// \brief Get the bits of the two's complement representation of the number.
+ ///
+ /// The 2 elements are in little endian order. However, the bits within each
+ /// uint64_t element are in native endian order.
+ /// For example, BasicDecimal128(123).little_endian_array() = {123, 0};
+ inline std::array<uint64_t, 2> little_endian_array() const {
+ return {low_bits_, static_cast<uint64_t>(high_bits_)};
+ }
+
+ inline const uint8_t* native_endian_bytes() const {
+#if ARROW_LITTLE_ENDIAN
+ return reinterpret_cast<const uint8_t*>(&low_bits_);
+#else
+ return reinterpret_cast<const uint8_t*>(&high_bits_);
+#endif
+ }
+
+ inline uint8_t* mutable_native_endian_bytes() {
+#if ARROW_LITTLE_ENDIAN
+ return reinterpret_cast<uint8_t*>(&low_bits_);
+#else
+ return reinterpret_cast<uint8_t*>(&high_bits_);
+#endif
+ }
+
+ /// \brief Return the raw bytes of the value in native-endian byte order.
+ std::array<uint8_t, 16> ToBytes() const;
+ void ToBytes(uint8_t* out) const;
+
+ /// \brief separate the integer and fractional parts for the given scale.
+ void GetWholeAndFraction(int32_t scale, BasicDecimal128* whole,
+ BasicDecimal128* fraction) const;
+
+ /// \brief Scale multiplier for given scale value.
+ static const BasicDecimal128& GetScaleMultiplier(int32_t scale);
+ /// \brief Half-scale multiplier for given scale value.
+ static const BasicDecimal128& GetHalfScaleMultiplier(int32_t scale);
+
+ /// \brief Convert BasicDecimal128 from one scale to another
+ DecimalStatus Rescale(int32_t original_scale, int32_t new_scale,
+ BasicDecimal128* out) const;
+
+ /// \brief Scale up.
+ BasicDecimal128 IncreaseScaleBy(int32_t increase_by) const;
+
+ /// \brief Scale down.
+ /// - If 'round' is true, the right-most digits are dropped and the result value is
+ /// rounded up (+1 for +ve, -1 for -ve) based on the value of the dropped digits
+ /// (>= 10^reduce_by / 2).
+ /// - If 'round' is false, the right-most digits are simply dropped.
+ BasicDecimal128 ReduceScaleBy(int32_t reduce_by, bool round = true) const;
+
+ /// \brief Whether this number fits in the given precision
+ ///
+ /// Return true if the number of significant digits is less or equal to `precision`.
+ bool FitsInPrecision(int32_t precision) const;
+
+ // returns 1 for positive and zero decimal values, -1 for negative decimal values.
+ inline int64_t Sign() const { return 1 | (high_bits_ >> 63); }
+
+ /// \brief count the number of leading binary zeroes.
+ int32_t CountLeadingBinaryZeros() const;
+
+ /// \brief Get the maximum valid unscaled decimal value.
+ static const BasicDecimal128& GetMaxValue();
+
+ /// \brief Get the maximum decimal value (is not a valid value).
+ static inline constexpr BasicDecimal128 GetMaxSentinel() {
+ return BasicDecimal128(/*high=*/std::numeric_limits<int64_t>::max(),
+ /*low=*/std::numeric_limits<uint64_t>::max());
+ }
+ /// \brief Get the minimum decimal value (is not a valid value).
+ static inline constexpr BasicDecimal128 GetMinSentinel() {
+ return BasicDecimal128(/*high=*/std::numeric_limits<int64_t>::min(),
+ /*low=*/std::numeric_limits<uint64_t>::min());
+ }
+
+ private:
+#if ARROW_LITTLE_ENDIAN
+ uint64_t low_bits_;
+ int64_t high_bits_;
+#else
+ int64_t high_bits_;
+ uint64_t low_bits_;
+#endif
+};
+
+ARROW_EXPORT bool operator==(const BasicDecimal128& left, const BasicDecimal128& right);
+ARROW_EXPORT bool operator!=(const BasicDecimal128& left, const BasicDecimal128& right);
+ARROW_EXPORT bool operator<(const BasicDecimal128& left, const BasicDecimal128& right);
+ARROW_EXPORT bool operator<=(const BasicDecimal128& left, const BasicDecimal128& right);
+ARROW_EXPORT bool operator>(const BasicDecimal128& left, const BasicDecimal128& right);
+ARROW_EXPORT bool operator>=(const BasicDecimal128& left, const BasicDecimal128& right);
+
+ARROW_EXPORT BasicDecimal128 operator-(const BasicDecimal128& operand);
+ARROW_EXPORT BasicDecimal128 operator~(const BasicDecimal128& operand);
+ARROW_EXPORT BasicDecimal128 operator+(const BasicDecimal128& left,
+ const BasicDecimal128& right);
+ARROW_EXPORT BasicDecimal128 operator-(const BasicDecimal128& left,
+ const BasicDecimal128& right);
+ARROW_EXPORT BasicDecimal128 operator*(const BasicDecimal128& left,
+ const BasicDecimal128& right);
+ARROW_EXPORT BasicDecimal128 operator/(const BasicDecimal128& left,
+ const BasicDecimal128& right);
+ARROW_EXPORT BasicDecimal128 operator%(const BasicDecimal128& left,
+ const BasicDecimal128& right);
+
+class ARROW_EXPORT BasicDecimal256 {
+ private:
+ // Due to a bug in clang, we have to declare the extend method prior to its
+ // usage.
+ template <typename T>
+ inline static constexpr uint64_t extend(T low_bits) noexcept {
+ return low_bits >= T() ? uint64_t{0} : ~uint64_t{0};
+ }
+
+ struct LittleEndianArrayTag {};
+
+ public:
+ static constexpr int kBitWidth = 256;
+ static constexpr int kMaxPrecision = 76;
+ static constexpr int kMaxScale = 76;
+
+ // A constructor tag to denote a little-endian encoded array
+ static constexpr LittleEndianArrayTag LittleEndianArray{};
+
+ /// \brief Create a BasicDecimal256 from the two's complement representation.
+ ///
+ /// Input array is assumed to be in native endianness.
+ constexpr BasicDecimal256(const std::array<uint64_t, 4>& array) noexcept
+ : array_(array) {}
+
+ /// \brief Create a BasicDecimal256 from the two's complement representation.
+ ///
+ /// Input array is assumed to be in little endianness, with native endian elements.
+ BasicDecimal256(LittleEndianArrayTag, const std::array<uint64_t, 4>& array) noexcept
+ : BasicDecimal256(BitUtil::LittleEndianArray::ToNative(array)) {}
+
+ /// \brief Empty constructor creates a BasicDecimal256 with a value of 0.
+ constexpr BasicDecimal256() noexcept : array_({0, 0, 0, 0}) {}
+
+ /// \brief Convert any integer value into a BasicDecimal256.
+ template <typename T,
+ typename = typename std::enable_if<
+ std::is_integral<T>::value && (sizeof(T) <= sizeof(uint64_t)), T>::type>
+ constexpr BasicDecimal256(T value) noexcept
+ : array_(BitUtil::LittleEndianArray::ToNative<uint64_t, 4>(
+ {static_cast<uint64_t>(value), extend(value), extend(value),
+ extend(value)})) {}
+
+ explicit BasicDecimal256(const BasicDecimal128& value) noexcept
+ : array_(BitUtil::LittleEndianArray::ToNative<uint64_t, 4>(
+ {value.low_bits(), static_cast<uint64_t>(value.high_bits()),
+ extend(value.high_bits()), extend(value.high_bits())})) {}
+
+ /// \brief Create a BasicDecimal256 from an array of bytes. Bytes are assumed to be in
+ /// native-endian byte order.
+ explicit BasicDecimal256(const uint8_t* bytes);
+
+ /// \brief Negate the current value (in-place)
+ BasicDecimal256& Negate();
+
+ /// \brief Absolute value (in-place)
+ BasicDecimal256& Abs();
+
+ /// \brief Absolute value
+ static BasicDecimal256 Abs(const BasicDecimal256& left);
+
+ /// \brief Add a number to this one. The result is truncated to 256 bits.
+ BasicDecimal256& operator+=(const BasicDecimal256& right);
+
+ /// \brief Subtract a number from this one. The result is truncated to 256 bits.
+ BasicDecimal256& operator-=(const BasicDecimal256& right);
+
+ /// \brief Get the bits of the two's complement representation of the number.
+ ///
+ /// The 4 elements are in native endian order. The bits within each uint64_t element
+ /// are in native endian order. For example, on a little endian machine,
+ /// BasicDecimal256(123).native_endian_array() = {123, 0, 0, 0};
+ /// BasicDecimal256(-2).native_endian_array() = {0xFF...FE, 0xFF...FF, 0xFF...FF,
+ /// 0xFF...FF}.
+ /// while on a big endian machine,
+ /// BasicDecimal256(123).native_endian_array() = {0, 0, 0, 123};
+ /// BasicDecimal256(-2).native_endian_array() = {0xFF...FF, 0xFF...FF, 0xFF...FF,
+ /// 0xFF...FE}.
+ inline const std::array<uint64_t, 4>& native_endian_array() const { return array_; }
+
+ /// \brief Get the bits of the two's complement representation of the number.
+ ///
+ /// The 4 elements are in little endian order. However, the bits within each
+ /// uint64_t element are in native endian order.
+ /// For example, BasicDecimal256(123).little_endian_array() = {123, 0};
+ inline const std::array<uint64_t, 4> little_endian_array() const {
+ return BitUtil::LittleEndianArray::FromNative(array_);
+ }
+
+ inline const uint8_t* native_endian_bytes() const {
+ return reinterpret_cast<const uint8_t*>(array_.data());
+ }
+
+ inline uint8_t* mutable_native_endian_bytes() {
+ return reinterpret_cast<uint8_t*>(array_.data());
+ }
+
+ /// \brief Get the lowest bits of the two's complement representation of the number.
+ inline uint64_t low_bits() const { return BitUtil::LittleEndianArray::Make(array_)[0]; }
+
+ /// \brief Return the raw bytes of the value in native-endian byte order.
+ std::array<uint8_t, 32> ToBytes() const;
+ void ToBytes(uint8_t* out) const;
+
+ /// \brief Scale multiplier for given scale value.
+ static const BasicDecimal256& GetScaleMultiplier(int32_t scale);
+ /// \brief Half-scale multiplier for given scale value.
+ static const BasicDecimal256& GetHalfScaleMultiplier(int32_t scale);
+
+ /// \brief Convert BasicDecimal256 from one scale to another
+ DecimalStatus Rescale(int32_t original_scale, int32_t new_scale,
+ BasicDecimal256* out) const;
+
+ /// \brief Scale up.
+ BasicDecimal256 IncreaseScaleBy(int32_t increase_by) const;
+
+ /// \brief Scale down.
+ /// - If 'round' is true, the right-most digits are dropped and the result value is
+ /// rounded up (+1 for positive, -1 for negative) based on the value of the
+ /// dropped digits (>= 10^reduce_by / 2).
+ /// - If 'round' is false, the right-most digits are simply dropped.
+ BasicDecimal256 ReduceScaleBy(int32_t reduce_by, bool round = true) const;
+
+ /// \brief Whether this number fits in the given precision
+ ///
+ /// Return true if the number of significant digits is less or equal to `precision`.
+ bool FitsInPrecision(int32_t precision) const;
+
+ inline int64_t Sign() const {
+ return 1 | (static_cast<int64_t>(BitUtil::LittleEndianArray::Make(array_)[3]) >> 63);
+ }
+
+ inline int64_t IsNegative() const {
+ return static_cast<int64_t>(BitUtil::LittleEndianArray::Make(array_)[3]) < 0;
+ }
+
+ /// \brief Multiply this number by another number. The result is truncated to 256 bits.
+ BasicDecimal256& operator*=(const BasicDecimal256& right);
+
+ /// Divide this number by right and return the result.
+ ///
+ /// This operation is not destructive.
+ /// The answer rounds to zero. Signs work like:
+ /// 21 / 5 -> 4, 1
+ /// -21 / 5 -> -4, -1
+ /// 21 / -5 -> -4, 1
+ /// -21 / -5 -> 4, -1
+ /// \param[in] divisor the number to divide by
+ /// \param[out] result the quotient
+ /// \param[out] remainder the remainder after the division
+ DecimalStatus Divide(const BasicDecimal256& divisor, BasicDecimal256* result,
+ BasicDecimal256* remainder) const;
+
+ /// \brief Shift left by the given number of bits.
+ BasicDecimal256& operator<<=(uint32_t bits);
+
+ /// \brief In-place division.
+ BasicDecimal256& operator/=(const BasicDecimal256& right);
+
+ /// \brief Get the maximum decimal value (is not a valid value).
+ static inline constexpr BasicDecimal256 GetMaxSentinel() {
+#if ARROW_LITTLE_ENDIAN
+ return BasicDecimal256({std::numeric_limits<uint64_t>::max(),
+ std::numeric_limits<uint64_t>::max(),
+ std::numeric_limits<uint64_t>::max(),
+ static_cast<uint64_t>(std::numeric_limits<int64_t>::max())});
+#else
+ return BasicDecimal256({static_cast<uint64_t>(std::numeric_limits<int64_t>::max()),
+ std::numeric_limits<uint64_t>::max(),
+ std::numeric_limits<uint64_t>::max(),
+ std::numeric_limits<uint64_t>::max()});
+#endif
+ }
+ /// \brief Get the minimum decimal value (is not a valid value).
+ static inline constexpr BasicDecimal256 GetMinSentinel() {
+#if ARROW_LITTLE_ENDIAN
+ return BasicDecimal256(
+ {0, 0, 0, static_cast<uint64_t>(std::numeric_limits<int64_t>::min())});
+#else
+ return BasicDecimal256(
+ {static_cast<uint64_t>(std::numeric_limits<int64_t>::min()), 0, 0, 0});
+#endif
+ }
+
+ private:
+ std::array<uint64_t, 4> array_;
+};
+
+ARROW_EXPORT inline bool operator==(const BasicDecimal256& left,
+ const BasicDecimal256& right) {
+ return left.native_endian_array() == right.native_endian_array();
+}
+
+ARROW_EXPORT inline bool operator!=(const BasicDecimal256& left,
+ const BasicDecimal256& right) {
+ return left.native_endian_array() != right.native_endian_array();
+}
+
+ARROW_EXPORT bool operator<(const BasicDecimal256& left, const BasicDecimal256& right);
+
+ARROW_EXPORT inline bool operator<=(const BasicDecimal256& left,
+ const BasicDecimal256& right) {
+ return !operator<(right, left);
+}
+
+ARROW_EXPORT inline bool operator>(const BasicDecimal256& left,
+ const BasicDecimal256& right) {
+ return operator<(right, left);
+}
+
+ARROW_EXPORT inline bool operator>=(const BasicDecimal256& left,
+ const BasicDecimal256& right) {
+ return !operator<(left, right);
+}
+
+ARROW_EXPORT BasicDecimal256 operator-(const BasicDecimal256& operand);
+ARROW_EXPORT BasicDecimal256 operator~(const BasicDecimal256& operand);
+ARROW_EXPORT BasicDecimal256 operator+(const BasicDecimal256& left,
+ const BasicDecimal256& right);
+ARROW_EXPORT BasicDecimal256 operator*(const BasicDecimal256& left,
+ const BasicDecimal256& right);
+ARROW_EXPORT BasicDecimal256 operator/(const BasicDecimal256& left,
+ const BasicDecimal256& right);
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/benchmark_main.cc b/src/arrow/cpp/src/arrow/util/benchmark_main.cc
new file mode 100644
index 000000000..c9739af03
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/benchmark_main.cc
@@ -0,0 +1,24 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+int main(int argc, char** argv) {
+ benchmark::Initialize(&argc, argv);
+ benchmark::RunSpecifiedBenchmarks();
+ return 0;
+}
diff --git a/src/arrow/cpp/src/arrow/util/benchmark_util.h b/src/arrow/cpp/src/arrow/util/benchmark_util.h
new file mode 100644
index 000000000..8379948bc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/benchmark_util.h
@@ -0,0 +1,138 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <string>
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/util/cpu_info.h"
+
+namespace arrow {
+
+using internal::CpuInfo;
+
+static CpuInfo* cpu_info = CpuInfo::GetInstance();
+
+static const int64_t kL1Size = cpu_info->CacheSize(CpuInfo::L1_CACHE);
+static const int64_t kL2Size = cpu_info->CacheSize(CpuInfo::L2_CACHE);
+static const int64_t kL3Size = cpu_info->CacheSize(CpuInfo::L3_CACHE);
+static const int64_t kCantFitInL3Size = kL3Size * 4;
+static const std::vector<int64_t> kMemorySizes = {kL1Size, kL2Size, kL3Size,
+ kCantFitInL3Size};
+
+template <typename Func>
+struct BenchmarkArgsType;
+
+// Pattern matching that extracts the vector element type of Benchmark::Args()
+template <typename Values>
+struct BenchmarkArgsType<benchmark::internal::Benchmark* (
+ benchmark::internal::Benchmark::*)(const std::vector<Values>&)> {
+ using type = Values;
+};
+
+// Benchmark changed its parameter type between releases from
+// int to int64_t. As it doesn't have version macros, we need
+// to apply C++ template magic.
+using ArgsType =
+ typename BenchmarkArgsType<decltype(&benchmark::internal::Benchmark::Args)>::type;
+
+struct GenericItemsArgs {
+ // number of items processed per iteration
+ const int64_t size;
+
+ // proportion of nulls in generated arrays
+ double null_proportion;
+
+ explicit GenericItemsArgs(benchmark::State& state)
+ : size(state.range(0)), state_(state) {
+ if (state.range(1) == 0) {
+ this->null_proportion = 0.0;
+ } else {
+ this->null_proportion = std::min(1., 1. / static_cast<double>(state.range(1)));
+ }
+ }
+
+ ~GenericItemsArgs() {
+ state_.counters["size"] = static_cast<double>(size);
+ state_.counters["null_percent"] = null_proportion * 100;
+ state_.SetItemsProcessed(state_.iterations() * size);
+ }
+
+ private:
+ benchmark::State& state_;
+};
+
+void BenchmarkSetArgsWithSizes(benchmark::internal::Benchmark* bench,
+ const std::vector<int64_t>& sizes = kMemorySizes) {
+ bench->Unit(benchmark::kMicrosecond);
+
+ // 0 is treated as "no nulls"
+ for (const auto size : sizes) {
+ for (const auto inverse_null_proportion :
+ std::vector<ArgsType>({10000, 100, 10, 2, 1, 0})) {
+ bench->Args({static_cast<ArgsType>(size), inverse_null_proportion});
+ }
+ }
+}
+
+void BenchmarkSetArgs(benchmark::internal::Benchmark* bench) {
+ BenchmarkSetArgsWithSizes(bench, kMemorySizes);
+}
+
+void RegressionSetArgs(benchmark::internal::Benchmark* bench) {
+ // Regression do not need to account for cache hierarchy, thus optimize for
+ // the best case.
+ BenchmarkSetArgsWithSizes(bench, {kL1Size});
+}
+
+// RAII struct to handle some of the boilerplate in regression benchmarks
+struct RegressionArgs {
+ // size of memory tested (per iteration) in bytes
+ const int64_t size;
+
+ // proportion of nulls in generated arrays
+ double null_proportion;
+
+ // If size_is_bytes is true, then it's a number of bytes, otherwise it's the
+ // number of items processed (for reporting)
+ explicit RegressionArgs(benchmark::State& state, bool size_is_bytes = true)
+ : size(state.range(0)), state_(state), size_is_bytes_(size_is_bytes) {
+ if (state.range(1) == 0) {
+ this->null_proportion = 0.0;
+ } else {
+ this->null_proportion = std::min(1., 1. / static_cast<double>(state.range(1)));
+ }
+ }
+
+ ~RegressionArgs() {
+ state_.counters["size"] = static_cast<double>(size);
+ state_.counters["null_percent"] = null_proportion * 100;
+ if (size_is_bytes_) {
+ state_.SetBytesProcessed(state_.iterations() * size);
+ } else {
+ state_.SetItemsProcessed(state_.iterations() * size);
+ }
+ }
+
+ private:
+ benchmark::State& state_;
+ bool size_is_bytes_;
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bit_block_counter.cc b/src/arrow/cpp/src/arrow/util/bit_block_counter.cc
new file mode 100644
index 000000000..7b5590f17
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bit_block_counter.cc
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/bit_block_counter.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <type_traits>
+
+#include "arrow/buffer.h"
+#include "arrow/util/bitmap_ops.h"
+
+namespace arrow {
+namespace internal {
+
+BitBlockCount BitBlockCounter::GetBlockSlow(int64_t block_size) noexcept {
+ const int16_t run_length = static_cast<int16_t>(std::min(bits_remaining_, block_size));
+ int16_t popcount = static_cast<int16_t>(CountSetBits(bitmap_, offset_, run_length));
+ bits_remaining_ -= run_length;
+ // This code path should trigger _at most_ 2 times. In the "two times"
+ // case, the first time the run length will be a multiple of 8 by construction
+ bitmap_ += run_length / 8;
+ return {run_length, popcount};
+}
+
+OptionalBitBlockCounter::OptionalBitBlockCounter(const uint8_t* validity_bitmap,
+ int64_t offset, int64_t length)
+ : has_bitmap_(validity_bitmap != nullptr),
+ position_(0),
+ length_(length),
+ counter_(util::MakeNonNull(validity_bitmap), offset, length) {}
+
+OptionalBitBlockCounter::OptionalBitBlockCounter(
+ const std::shared_ptr<Buffer>& validity_bitmap, int64_t offset, int64_t length)
+ : OptionalBitBlockCounter(validity_bitmap ? validity_bitmap->data() : nullptr, offset,
+ length) {}
+
+OptionalBinaryBitBlockCounter::OptionalBinaryBitBlockCounter(const uint8_t* left_bitmap,
+ int64_t left_offset,
+ const uint8_t* right_bitmap,
+ int64_t right_offset,
+ int64_t length)
+ : has_bitmap_(HasBitmapFromBitmaps(left_bitmap != nullptr, right_bitmap != nullptr)),
+ position_(0),
+ length_(length),
+ unary_counter_(
+ util::MakeNonNull(left_bitmap != nullptr ? left_bitmap : right_bitmap),
+ left_bitmap != nullptr ? left_offset : right_offset, length),
+ binary_counter_(util::MakeNonNull(left_bitmap), left_offset,
+ util::MakeNonNull(right_bitmap), right_offset, length) {}
+
+OptionalBinaryBitBlockCounter::OptionalBinaryBitBlockCounter(
+ const std::shared_ptr<Buffer>& left_bitmap, int64_t left_offset,
+ const std::shared_ptr<Buffer>& right_bitmap, int64_t right_offset, int64_t length)
+ : OptionalBinaryBitBlockCounter(
+ left_bitmap ? left_bitmap->data() : nullptr, left_offset,
+ right_bitmap ? right_bitmap->data() : nullptr, right_offset, length) {}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bit_block_counter.h b/src/arrow/cpp/src/arrow/util/bit_block_counter.h
new file mode 100644
index 000000000..460799036
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bit_block_counter.h
@@ -0,0 +1,542 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <memory>
+
+#include "arrow/buffer.h"
+#include "arrow/status.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/ubsan.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace internal {
+namespace detail {
+
+inline uint64_t LoadWord(const uint8_t* bytes) {
+ return BitUtil::ToLittleEndian(util::SafeLoadAs<uint64_t>(bytes));
+}
+
+inline uint64_t ShiftWord(uint64_t current, uint64_t next, int64_t shift) {
+ if (shift == 0) {
+ return current;
+ }
+ return (current >> shift) | (next << (64 - shift));
+}
+
+// These templates are here to help with unit tests
+
+template <typename T>
+struct BitBlockAnd {
+ static T Call(T left, T right) { return left & right; }
+};
+
+template <>
+struct BitBlockAnd<bool> {
+ static bool Call(bool left, bool right) { return left && right; }
+};
+
+template <typename T>
+struct BitBlockAndNot {
+ static T Call(T left, T right) { return left & ~right; }
+};
+
+template <>
+struct BitBlockAndNot<bool> {
+ static bool Call(bool left, bool right) { return left && !right; }
+};
+
+template <typename T>
+struct BitBlockOr {
+ static T Call(T left, T right) { return left | right; }
+};
+
+template <>
+struct BitBlockOr<bool> {
+ static bool Call(bool left, bool right) { return left || right; }
+};
+
+template <typename T>
+struct BitBlockOrNot {
+ static T Call(T left, T right) { return left | ~right; }
+};
+
+template <>
+struct BitBlockOrNot<bool> {
+ static bool Call(bool left, bool right) { return left || !right; }
+};
+
+} // namespace detail
+
+/// \brief Return value from bit block counters: the total number of bits and
+/// the number of set bits.
+struct BitBlockCount {
+ int16_t length;
+ int16_t popcount;
+
+ bool NoneSet() const { return this->popcount == 0; }
+ bool AllSet() const { return this->length == this->popcount; }
+};
+
+/// \brief A class that scans through a true/false bitmap to compute popcounts
+/// 64 or 256 bits at a time. This is used to accelerate processing of
+/// mostly-not-null array data.
+class ARROW_EXPORT BitBlockCounter {
+ public:
+ BitBlockCounter(const uint8_t* bitmap, int64_t start_offset, int64_t length)
+ : bitmap_(util::MakeNonNull(bitmap) + start_offset / 8),
+ bits_remaining_(length),
+ offset_(start_offset % 8) {}
+
+ /// \brief The bit size of each word run
+ static constexpr int64_t kWordBits = 64;
+
+ /// \brief The bit size of four words run
+ static constexpr int64_t kFourWordsBits = kWordBits * 4;
+
+ /// \brief Return the next run of available bits, usually 256. The returned
+ /// pair contains the size of run and the number of true values. The last
+ /// block will have a length less than 256 if the bitmap length is not a
+ /// multiple of 256, and will return 0-length blocks in subsequent
+ /// invocations.
+ BitBlockCount NextFourWords() {
+ using detail::LoadWord;
+ using detail::ShiftWord;
+
+ if (!bits_remaining_) {
+ return {0, 0};
+ }
+ int64_t total_popcount = 0;
+ if (offset_ == 0) {
+ if (bits_remaining_ < kFourWordsBits) {
+ return GetBlockSlow(kFourWordsBits);
+ }
+ total_popcount += BitUtil::PopCount(LoadWord(bitmap_));
+ total_popcount += BitUtil::PopCount(LoadWord(bitmap_ + 8));
+ total_popcount += BitUtil::PopCount(LoadWord(bitmap_ + 16));
+ total_popcount += BitUtil::PopCount(LoadWord(bitmap_ + 24));
+ } else {
+ // When the offset is > 0, we need there to be a word beyond the last
+ // aligned word in the bitmap for the bit shifting logic.
+ if (bits_remaining_ < 5 * kFourWordsBits - offset_) {
+ return GetBlockSlow(kFourWordsBits);
+ }
+ auto current = LoadWord(bitmap_);
+ auto next = LoadWord(bitmap_ + 8);
+ total_popcount += BitUtil::PopCount(ShiftWord(current, next, offset_));
+ current = next;
+ next = LoadWord(bitmap_ + 16);
+ total_popcount += BitUtil::PopCount(ShiftWord(current, next, offset_));
+ current = next;
+ next = LoadWord(bitmap_ + 24);
+ total_popcount += BitUtil::PopCount(ShiftWord(current, next, offset_));
+ current = next;
+ next = LoadWord(bitmap_ + 32);
+ total_popcount += BitUtil::PopCount(ShiftWord(current, next, offset_));
+ }
+ bitmap_ += BitUtil::BytesForBits(kFourWordsBits);
+ bits_remaining_ -= kFourWordsBits;
+ return {256, static_cast<int16_t>(total_popcount)};
+ }
+
+ /// \brief Return the next run of available bits, usually 64. The returned
+ /// pair contains the size of run and the number of true values. The last
+ /// block will have a length less than 64 if the bitmap length is not a
+ /// multiple of 64, and will return 0-length blocks in subsequent
+ /// invocations.
+ BitBlockCount NextWord() {
+ using detail::LoadWord;
+ using detail::ShiftWord;
+
+ if (!bits_remaining_) {
+ return {0, 0};
+ }
+ int64_t popcount = 0;
+ if (offset_ == 0) {
+ if (bits_remaining_ < kWordBits) {
+ return GetBlockSlow(kWordBits);
+ }
+ popcount = BitUtil::PopCount(LoadWord(bitmap_));
+ } else {
+ // When the offset is > 0, we need there to be a word beyond the last
+ // aligned word in the bitmap for the bit shifting logic.
+ if (bits_remaining_ < 2 * kWordBits - offset_) {
+ return GetBlockSlow(kWordBits);
+ }
+ popcount =
+ BitUtil::PopCount(ShiftWord(LoadWord(bitmap_), LoadWord(bitmap_ + 8), offset_));
+ }
+ bitmap_ += kWordBits / 8;
+ bits_remaining_ -= kWordBits;
+ return {64, static_cast<int16_t>(popcount)};
+ }
+
+ private:
+ /// \brief Return block with the requested size when doing word-wise
+ /// computation is not possible due to inadequate bits remaining.
+ BitBlockCount GetBlockSlow(int64_t block_size) noexcept;
+
+ const uint8_t* bitmap_;
+ int64_t bits_remaining_;
+ int64_t offset_;
+};
+
+/// \brief A tool to iterate through a possibly non-existent validity bitmap,
+/// to allow us to write one code path for both the with-nulls and no-nulls
+/// cases without giving up a lot of performance.
+class ARROW_EXPORT OptionalBitBlockCounter {
+ public:
+ // validity_bitmap may be NULLPTR
+ OptionalBitBlockCounter(const uint8_t* validity_bitmap, int64_t offset, int64_t length);
+
+ // validity_bitmap may be null
+ OptionalBitBlockCounter(const std::shared_ptr<Buffer>& validity_bitmap, int64_t offset,
+ int64_t length);
+
+ /// Return block count for next word when the bitmap is available otherwise
+ /// return a block with length up to INT16_MAX when there is no validity
+ /// bitmap (so all the referenced values are not null).
+ BitBlockCount NextBlock() {
+ static constexpr int64_t kMaxBlockSize = std::numeric_limits<int16_t>::max();
+ if (has_bitmap_) {
+ BitBlockCount block = counter_.NextWord();
+ position_ += block.length;
+ return block;
+ } else {
+ int16_t block_size =
+ static_cast<int16_t>(std::min(kMaxBlockSize, length_ - position_));
+ position_ += block_size;
+ // All values are non-null
+ return {block_size, block_size};
+ }
+ }
+
+ // Like NextBlock, but returns a word-sized block even when there is no
+ // validity bitmap
+ BitBlockCount NextWord() {
+ static constexpr int64_t kWordSize = 64;
+ if (has_bitmap_) {
+ BitBlockCount block = counter_.NextWord();
+ position_ += block.length;
+ return block;
+ } else {
+ int16_t block_size = static_cast<int16_t>(std::min(kWordSize, length_ - position_));
+ position_ += block_size;
+ // All values are non-null
+ return {block_size, block_size};
+ }
+ }
+
+ private:
+ const bool has_bitmap_;
+ int64_t position_;
+ int64_t length_;
+ BitBlockCounter counter_;
+};
+
+/// \brief A class that computes popcounts on the result of bitwise operations
+/// between two bitmaps, 64 bits at a time. A 64-bit word is loaded from each
+/// bitmap, then the popcount is computed on e.g. the bitwise-and of the two
+/// words.
+class ARROW_EXPORT BinaryBitBlockCounter {
+ public:
+ BinaryBitBlockCounter(const uint8_t* left_bitmap, int64_t left_offset,
+ const uint8_t* right_bitmap, int64_t right_offset, int64_t length)
+ : left_bitmap_(util::MakeNonNull(left_bitmap) + left_offset / 8),
+ left_offset_(left_offset % 8),
+ right_bitmap_(util::MakeNonNull(right_bitmap) + right_offset / 8),
+ right_offset_(right_offset % 8),
+ bits_remaining_(length) {}
+
+ /// \brief Return the popcount of the bitwise-and of the next run of
+ /// available bits, up to 64. The returned pair contains the size of run and
+ /// the number of true values. The last block will have a length less than 64
+ /// if the bitmap length is not a multiple of 64, and will return 0-length
+ /// blocks in subsequent invocations.
+ BitBlockCount NextAndWord() { return NextWord<detail::BitBlockAnd>(); }
+
+ /// \brief Computes "x & ~y" block for each available run of bits.
+ BitBlockCount NextAndNotWord() { return NextWord<detail::BitBlockAndNot>(); }
+
+ /// \brief Computes "x | y" block for each available run of bits.
+ BitBlockCount NextOrWord() { return NextWord<detail::BitBlockOr>(); }
+
+ /// \brief Computes "x | ~y" block for each available run of bits.
+ BitBlockCount NextOrNotWord() { return NextWord<detail::BitBlockOrNot>(); }
+
+ private:
+ template <template <typename T> class Op>
+ BitBlockCount NextWord() {
+ using detail::LoadWord;
+ using detail::ShiftWord;
+
+ if (!bits_remaining_) {
+ return {0, 0};
+ }
+ // When the offset is > 0, we need there to be a word beyond the last aligned
+ // word in the bitmap for the bit shifting logic.
+ constexpr int64_t kWordBits = BitBlockCounter::kWordBits;
+ const int64_t bits_required_to_use_words =
+ std::max(left_offset_ == 0 ? 64 : 64 + (64 - left_offset_),
+ right_offset_ == 0 ? 64 : 64 + (64 - right_offset_));
+ if (bits_remaining_ < bits_required_to_use_words) {
+ const int16_t run_length =
+ static_cast<int16_t>(std::min(bits_remaining_, kWordBits));
+ int16_t popcount = 0;
+ for (int64_t i = 0; i < run_length; ++i) {
+ if (Op<bool>::Call(BitUtil::GetBit(left_bitmap_, left_offset_ + i),
+ BitUtil::GetBit(right_bitmap_, right_offset_ + i))) {
+ ++popcount;
+ }
+ }
+ // This code path should trigger _at most_ 2 times. In the "two times"
+ // case, the first time the run length will be a multiple of 8.
+ left_bitmap_ += run_length / 8;
+ right_bitmap_ += run_length / 8;
+ bits_remaining_ -= run_length;
+ return {run_length, popcount};
+ }
+
+ int64_t popcount = 0;
+ if (left_offset_ == 0 && right_offset_ == 0) {
+ popcount = BitUtil::PopCount(
+ Op<uint64_t>::Call(LoadWord(left_bitmap_), LoadWord(right_bitmap_)));
+ } else {
+ auto left_word =
+ ShiftWord(LoadWord(left_bitmap_), LoadWord(left_bitmap_ + 8), left_offset_);
+ auto right_word =
+ ShiftWord(LoadWord(right_bitmap_), LoadWord(right_bitmap_ + 8), right_offset_);
+ popcount = BitUtil::PopCount(Op<uint64_t>::Call(left_word, right_word));
+ }
+ left_bitmap_ += kWordBits / 8;
+ right_bitmap_ += kWordBits / 8;
+ bits_remaining_ -= kWordBits;
+ return {64, static_cast<int16_t>(popcount)};
+ }
+
+ const uint8_t* left_bitmap_;
+ int64_t left_offset_;
+ const uint8_t* right_bitmap_;
+ int64_t right_offset_;
+ int64_t bits_remaining_;
+};
+
+class ARROW_EXPORT OptionalBinaryBitBlockCounter {
+ public:
+ // Any bitmap may be NULLPTR
+ OptionalBinaryBitBlockCounter(const uint8_t* left_bitmap, int64_t left_offset,
+ const uint8_t* right_bitmap, int64_t right_offset,
+ int64_t length);
+
+ // Any bitmap may be null
+ OptionalBinaryBitBlockCounter(const std::shared_ptr<Buffer>& left_bitmap,
+ int64_t left_offset,
+ const std::shared_ptr<Buffer>& right_bitmap,
+ int64_t right_offset, int64_t length);
+
+ BitBlockCount NextAndBlock() {
+ static constexpr int64_t kMaxBlockSize = std::numeric_limits<int16_t>::max();
+ switch (has_bitmap_) {
+ case HasBitmap::BOTH: {
+ BitBlockCount block = binary_counter_.NextAndWord();
+ position_ += block.length;
+ return block;
+ }
+ case HasBitmap::ONE: {
+ BitBlockCount block = unary_counter_.NextWord();
+ position_ += block.length;
+ return block;
+ }
+ case HasBitmap::NONE:
+ default: {
+ const int16_t block_size =
+ static_cast<int16_t>(std::min(kMaxBlockSize, length_ - position_));
+ position_ += block_size;
+ // All values are non-null
+ return {block_size, block_size};
+ }
+ }
+ }
+
+ BitBlockCount NextOrNotBlock() {
+ static constexpr int64_t kMaxBlockSize = std::numeric_limits<int16_t>::max();
+ switch (has_bitmap_) {
+ case HasBitmap::BOTH: {
+ BitBlockCount block = binary_counter_.NextOrNotWord();
+ position_ += block.length;
+ return block;
+ }
+ case HasBitmap::ONE: {
+ BitBlockCount block = unary_counter_.NextWord();
+ position_ += block.length;
+ return block;
+ }
+ case HasBitmap::NONE:
+ default: {
+ const int16_t block_size =
+ static_cast<int16_t>(std::min(kMaxBlockSize, length_ - position_));
+ position_ += block_size;
+ // All values are non-null
+ return {block_size, block_size};
+ }
+ }
+ }
+
+ private:
+ enum class HasBitmap : int { BOTH, ONE, NONE };
+
+ const HasBitmap has_bitmap_;
+ int64_t position_;
+ int64_t length_;
+ BitBlockCounter unary_counter_;
+ BinaryBitBlockCounter binary_counter_;
+
+ static HasBitmap HasBitmapFromBitmaps(bool has_left, bool has_right) {
+ switch (static_cast<int>(has_left) + static_cast<int>(has_right)) {
+ case 0:
+ return HasBitmap::NONE;
+ case 1:
+ return HasBitmap::ONE;
+ default: // 2
+ return HasBitmap::BOTH;
+ }
+ }
+};
+
+// Functional-style bit block visitors.
+
+template <typename VisitNotNull, typename VisitNull>
+static Status VisitBitBlocks(const std::shared_ptr<Buffer>& bitmap_buf, int64_t offset,
+ int64_t length, VisitNotNull&& visit_not_null,
+ VisitNull&& visit_null) {
+ const uint8_t* bitmap = NULLPTR;
+ if (bitmap_buf != NULLPTR) {
+ bitmap = bitmap_buf->data();
+ }
+ internal::OptionalBitBlockCounter bit_counter(bitmap, offset, length);
+ int64_t position = 0;
+ while (position < length) {
+ internal::BitBlockCount block = bit_counter.NextBlock();
+ if (block.AllSet()) {
+ for (int64_t i = 0; i < block.length; ++i, ++position) {
+ ARROW_RETURN_NOT_OK(visit_not_null(position));
+ }
+ } else if (block.NoneSet()) {
+ for (int64_t i = 0; i < block.length; ++i, ++position) {
+ ARROW_RETURN_NOT_OK(visit_null());
+ }
+ } else {
+ for (int64_t i = 0; i < block.length; ++i, ++position) {
+ if (BitUtil::GetBit(bitmap, offset + position)) {
+ ARROW_RETURN_NOT_OK(visit_not_null(position));
+ } else {
+ ARROW_RETURN_NOT_OK(visit_null());
+ }
+ }
+ }
+ }
+ return Status::OK();
+}
+
+template <typename VisitNotNull, typename VisitNull>
+static void VisitBitBlocksVoid(const std::shared_ptr<Buffer>& bitmap_buf, int64_t offset,
+ int64_t length, VisitNotNull&& visit_not_null,
+ VisitNull&& visit_null) {
+ const uint8_t* bitmap = NULLPTR;
+ if (bitmap_buf != NULLPTR) {
+ bitmap = bitmap_buf->data();
+ }
+ internal::OptionalBitBlockCounter bit_counter(bitmap, offset, length);
+ int64_t position = 0;
+ while (position < length) {
+ internal::BitBlockCount block = bit_counter.NextBlock();
+ if (block.AllSet()) {
+ for (int64_t i = 0; i < block.length; ++i, ++position) {
+ visit_not_null(position);
+ }
+ } else if (block.NoneSet()) {
+ for (int64_t i = 0; i < block.length; ++i, ++position) {
+ visit_null();
+ }
+ } else {
+ for (int64_t i = 0; i < block.length; ++i, ++position) {
+ if (BitUtil::GetBit(bitmap, offset + position)) {
+ visit_not_null(position);
+ } else {
+ visit_null();
+ }
+ }
+ }
+ }
+}
+
+template <typename VisitNotNull, typename VisitNull>
+static void VisitTwoBitBlocksVoid(const std::shared_ptr<Buffer>& left_bitmap_buf,
+ int64_t left_offset,
+ const std::shared_ptr<Buffer>& right_bitmap_buf,
+ int64_t right_offset, int64_t length,
+ VisitNotNull&& visit_not_null, VisitNull&& visit_null) {
+ if (left_bitmap_buf == NULLPTR || right_bitmap_buf == NULLPTR) {
+ // At most one bitmap is present
+ if (left_bitmap_buf == NULLPTR) {
+ return VisitBitBlocksVoid(right_bitmap_buf, right_offset, length,
+ std::forward<VisitNotNull>(visit_not_null),
+ std::forward<VisitNull>(visit_null));
+ } else {
+ return VisitBitBlocksVoid(left_bitmap_buf, left_offset, length,
+ std::forward<VisitNotNull>(visit_not_null),
+ std::forward<VisitNull>(visit_null));
+ }
+ }
+ // Both bitmaps are present
+ const uint8_t* left_bitmap = left_bitmap_buf->data();
+ const uint8_t* right_bitmap = right_bitmap_buf->data();
+ BinaryBitBlockCounter bit_counter(left_bitmap, left_offset, right_bitmap, right_offset,
+ length);
+ int64_t position = 0;
+ while (position < length) {
+ BitBlockCount block = bit_counter.NextAndWord();
+ if (block.AllSet()) {
+ for (int64_t i = 0; i < block.length; ++i, ++position) {
+ visit_not_null(position);
+ }
+ } else if (block.NoneSet()) {
+ for (int64_t i = 0; i < block.length; ++i, ++position) {
+ visit_null();
+ }
+ } else {
+ for (int64_t i = 0; i < block.length; ++i, ++position) {
+ if (BitUtil::GetBit(left_bitmap, left_offset + position) &&
+ BitUtil::GetBit(right_bitmap, right_offset + position)) {
+ visit_not_null(position);
+ } else {
+ visit_null();
+ }
+ }
+ }
+ }
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bit_block_counter_benchmark.cc b/src/arrow/cpp/src/arrow/util/bit_block_counter_benchmark.cc
new file mode 100644
index 000000000..429993442
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bit_block_counter_benchmark.cc
@@ -0,0 +1,266 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_primitive.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_reader.h"
+
+namespace arrow {
+namespace internal {
+
+struct UnaryBitBlockBenchmark {
+ benchmark::State& state;
+ int64_t offset;
+ int64_t bitmap_length;
+ std::shared_ptr<Array> arr;
+ int64_t expected;
+
+ explicit UnaryBitBlockBenchmark(benchmark::State& state, int64_t offset = 0)
+ : state(state), offset(offset), bitmap_length(1 << 20) {
+ random::RandomArrayGenerator rng(/*seed=*/0);
+ // State parameter is the average number of total values for each null
+ // value. So 100 means that 1 out of 100 on average are null.
+ double null_probability = 1. / static_cast<double>(state.range(0));
+ arr = rng.Int8(bitmap_length, 0, 100, null_probability);
+
+ // Compute the expected result
+ this->expected = 0;
+ const auto& int8_arr = static_cast<const Int8Array&>(*arr);
+ for (int64_t i = this->offset; i < bitmap_length; ++i) {
+ if (int8_arr.IsValid(i)) {
+ this->expected += int8_arr.Value(i);
+ }
+ }
+ }
+
+ template <typename NextBlockFunc>
+ void BenchBitBlockCounter(NextBlockFunc&& next_block) {
+ const auto& int8_arr = static_cast<const Int8Array&>(*arr);
+ const uint8_t* bitmap = arr->null_bitmap_data();
+ for (auto _ : state) {
+ BitBlockCounter scanner(bitmap, this->offset, bitmap_length - this->offset);
+ int64_t result = 0;
+ int64_t position = this->offset;
+ while (true) {
+ BitBlockCount block = next_block(&scanner);
+ if (block.length == 0) {
+ break;
+ }
+ if (block.length == block.popcount) {
+ // All not-null
+ for (int64_t i = 0; i < block.length; ++i) {
+ result += int8_arr.Value(position + i);
+ }
+ } else if (block.popcount > 0) {
+ // Some but not all not-null
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (BitUtil::GetBit(bitmap, position + i)) {
+ result += int8_arr.Value(position + i);
+ }
+ }
+ }
+ position += block.length;
+ }
+ // Sanity check
+ if (result != expected) {
+ std::abort();
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * bitmap_length);
+ }
+
+ void BenchBitmapReader() {
+ const auto& int8_arr = static_cast<const Int8Array&>(*arr);
+ for (auto _ : state) {
+ internal::BitmapReader bit_reader(arr->null_bitmap_data(), this->offset,
+ bitmap_length - this->offset);
+ int64_t result = 0;
+ for (int64_t i = this->offset; i < bitmap_length; ++i) {
+ if (bit_reader.IsSet()) {
+ result += int8_arr.Value(i);
+ }
+ bit_reader.Next();
+ }
+ // Sanity check
+ if (result != expected) {
+ std::abort();
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * bitmap_length);
+ }
+};
+
+struct BinaryBitBlockBenchmark {
+ benchmark::State& state;
+ int64_t offset;
+ int64_t bitmap_length;
+ std::shared_ptr<Array> left;
+ std::shared_ptr<Array> right;
+ int64_t expected;
+ const Int8Array* left_int8;
+ const Int8Array* right_int8;
+
+ explicit BinaryBitBlockBenchmark(benchmark::State& state, int64_t offset = 0)
+ : state(state), offset(offset), bitmap_length(1 << 20) {
+ random::RandomArrayGenerator rng(/*seed=*/0);
+
+ // State parameter is the average number of total values for each null
+ // value. So 100 means that 1 out of 100 on average are null.
+ double null_probability = 1. / static_cast<double>(state.range(0));
+ left = rng.Int8(bitmap_length, 0, 100, null_probability);
+ right = rng.Int8(bitmap_length, 0, 50, null_probability);
+ left_int8 = static_cast<const Int8Array*>(left.get());
+ right_int8 = static_cast<const Int8Array*>(right.get());
+
+ // Compute the expected result
+ expected = 0;
+ for (int64_t i = this->offset; i < bitmap_length; ++i) {
+ if (left_int8->IsValid(i) && right_int8->IsValid(i)) {
+ expected += left_int8->Value(i) + right_int8->Value(i);
+ }
+ }
+ }
+
+ void BenchBitBlockCounter() {
+ const uint8_t* left_bitmap = left->null_bitmap_data();
+ const uint8_t* right_bitmap = right->null_bitmap_data();
+ for (auto _ : state) {
+ BinaryBitBlockCounter scanner(left_bitmap, this->offset, right_bitmap, this->offset,
+ bitmap_length - this->offset);
+ int64_t result = 0;
+ int64_t position = this->offset;
+ while (true) {
+ BitBlockCount block = scanner.NextAndWord();
+ if (block.length == 0) {
+ break;
+ }
+ if (block.length == block.popcount) {
+ // All not-null
+ for (int64_t i = 0; i < block.length; ++i) {
+ result += left_int8->Value(position + i) + right_int8->Value(position + i);
+ }
+ } else if (block.popcount > 0) {
+ // Some but not all not-null
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (BitUtil::GetBit(left_bitmap, position + i) &&
+ BitUtil::GetBit(right_bitmap, position + i)) {
+ result += left_int8->Value(position + i) + right_int8->Value(position + i);
+ }
+ }
+ }
+ position += block.length;
+ }
+ // Sanity check
+ if (result != expected) {
+ std::abort();
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * bitmap_length);
+ }
+
+ void BenchBitmapReader() {
+ for (auto _ : state) {
+ internal::BitmapReader left_reader(left->null_bitmap_data(), this->offset,
+ bitmap_length - this->offset);
+ internal::BitmapReader right_reader(right->null_bitmap_data(), this->offset,
+ bitmap_length - this->offset);
+ int64_t result = 0;
+ for (int64_t i = this->offset; i < bitmap_length; ++i) {
+ if (left_reader.IsSet() && right_reader.IsSet()) {
+ result += left_int8->Value(i) + right_int8->Value(i);
+ }
+ left_reader.Next();
+ right_reader.Next();
+ }
+ // Sanity check
+ if (result != expected) {
+ std::abort();
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * bitmap_length);
+ }
+};
+
+static void BitBlockCounterSum(benchmark::State& state) {
+ UnaryBitBlockBenchmark(state, /*offset=*/0)
+ .BenchBitBlockCounter([](BitBlockCounter* counter) { return counter->NextWord(); });
+}
+
+static void BitBlockCounterSumWithOffset(benchmark::State& state) {
+ UnaryBitBlockBenchmark(state, /*offset=*/4)
+ .BenchBitBlockCounter([](BitBlockCounter* counter) { return counter->NextWord(); });
+}
+
+static void BitBlockCounterFourWordsSum(benchmark::State& state) {
+ UnaryBitBlockBenchmark(state, /*offset=*/0)
+ .BenchBitBlockCounter(
+ [](BitBlockCounter* counter) { return counter->NextFourWords(); });
+}
+
+static void BitBlockCounterFourWordsSumWithOffset(benchmark::State& state) {
+ UnaryBitBlockBenchmark(state, /*offset=*/4)
+ .BenchBitBlockCounter(
+ [](BitBlockCounter* counter) { return counter->NextFourWords(); });
+}
+
+static void BitmapReaderSum(benchmark::State& state) {
+ UnaryBitBlockBenchmark(state, /*offset=*/0).BenchBitmapReader();
+}
+
+static void BitmapReaderSumWithOffset(benchmark::State& state) {
+ UnaryBitBlockBenchmark(state, /*offset=*/4).BenchBitmapReader();
+}
+
+static void BinaryBitBlockCounterSum(benchmark::State& state) {
+ BinaryBitBlockBenchmark(state, /*offset=*/0).BenchBitBlockCounter();
+}
+
+static void BinaryBitBlockCounterSumWithOffset(benchmark::State& state) {
+ BinaryBitBlockBenchmark(state, /*offset=*/4).BenchBitBlockCounter();
+}
+
+static void BinaryBitmapReaderSum(benchmark::State& state) {
+ BinaryBitBlockBenchmark(state, /*offset=*/0).BenchBitmapReader();
+}
+
+static void BinaryBitmapReaderSumWithOffset(benchmark::State& state) {
+ BinaryBitBlockBenchmark(state, /*offset=*/4).BenchBitmapReader();
+}
+
+// Range value: average number of total values per null
+BENCHMARK(BitBlockCounterSum)->Range(2, 1 << 16);
+BENCHMARK(BitBlockCounterSumWithOffset)->Range(2, 1 << 16);
+BENCHMARK(BitBlockCounterFourWordsSum)->Range(2, 1 << 16);
+BENCHMARK(BitBlockCounterFourWordsSumWithOffset)->Range(2, 1 << 16);
+BENCHMARK(BitmapReaderSum)->Range(2, 1 << 16);
+BENCHMARK(BitmapReaderSumWithOffset)->Range(2, 1 << 16);
+BENCHMARK(BinaryBitBlockCounterSum)->Range(2, 1 << 16);
+BENCHMARK(BinaryBitBlockCounterSumWithOffset)->Range(2, 1 << 16);
+BENCHMARK(BinaryBitmapReaderSum)->Range(2, 1 << 16);
+BENCHMARK(BinaryBitmapReaderSumWithOffset)->Range(2, 1 << 16);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bit_block_counter_test.cc b/src/arrow/cpp/src/arrow/util/bit_block_counter_test.cc
new file mode 100644
index 000000000..3fdfa3ed9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bit_block_counter_test.cc
@@ -0,0 +1,417 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+
+#include <gtest/gtest.h>
+
+#include "arrow/buffer.h"
+#include "arrow/memory_pool.h"
+#include "arrow/result.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+
+namespace arrow {
+namespace internal {
+
+class TestBitBlockCounter : public ::testing::Test {
+ public:
+ void Create(int64_t nbytes, int64_t offset, int64_t length) {
+ ASSERT_OK_AND_ASSIGN(buf_, AllocateBuffer(nbytes));
+ // Start with data zeroed out
+ std::memset(buf_->mutable_data(), 0, nbytes);
+ counter_.reset(new BitBlockCounter(buf_->data(), offset, length));
+ }
+
+ protected:
+ std::shared_ptr<Buffer> buf_;
+ std::unique_ptr<BitBlockCounter> counter_;
+};
+
+static constexpr int64_t kWordSize = 64;
+
+TEST_F(TestBitBlockCounter, OneWordBasics) {
+ const int64_t nbytes = 1024;
+
+ Create(nbytes, 0, nbytes * 8);
+
+ int64_t bits_scanned = 0;
+ for (int64_t i = 0; i < nbytes / 8; ++i) {
+ BitBlockCount block = counter_->NextWord();
+ ASSERT_EQ(block.length, kWordSize);
+ ASSERT_EQ(block.popcount, 0);
+ bits_scanned += block.length;
+ }
+ ASSERT_EQ(bits_scanned, 1024 * 8);
+
+ auto block = counter_->NextWord();
+ ASSERT_EQ(block.length, 0);
+ ASSERT_EQ(block.popcount, 0);
+}
+
+TEST_F(TestBitBlockCounter, FourWordsBasics) {
+ const int64_t nbytes = 1024;
+
+ Create(nbytes, 0, nbytes * 8);
+
+ int64_t bits_scanned = 0;
+ for (int64_t i = 0; i < nbytes / 32; ++i) {
+ BitBlockCount block = counter_->NextFourWords();
+ ASSERT_EQ(block.length, 4 * kWordSize);
+ ASSERT_EQ(block.popcount, 0);
+ bits_scanned += block.length;
+ }
+ ASSERT_EQ(bits_scanned, 1024 * 8);
+
+ auto block = counter_->NextFourWords();
+ ASSERT_EQ(block.length, 0);
+ ASSERT_EQ(block.popcount, 0);
+}
+
+TEST_F(TestBitBlockCounter, OneWordWithOffsets) {
+ auto CheckWithOffset = [&](int64_t offset) {
+ const int64_t nwords = 4;
+
+ const int64_t total_bytes = nwords * 8 + 1;
+ // Trim a bit from the end of the bitmap so we can check the remainder bits
+ // behavior
+ Create(total_bytes, offset, nwords * kWordSize - offset - 1);
+
+ // Start with data all set
+ std::memset(buf_->mutable_data(), 0xFF, total_bytes);
+
+ BitBlockCount block = counter_->NextWord();
+ ASSERT_EQ(kWordSize, block.length);
+ ASSERT_EQ(block.popcount, 64);
+
+ // Add a false value to the next word
+ BitUtil::SetBitTo(buf_->mutable_data(), kWordSize + offset, false);
+ block = counter_->NextWord();
+ ASSERT_EQ(block.length, 64);
+ ASSERT_EQ(block.popcount, 63);
+
+ // Set the next word to all false
+ BitUtil::SetBitsTo(buf_->mutable_data(), 2 * kWordSize + offset, kWordSize, false);
+
+ block = counter_->NextWord();
+ ASSERT_EQ(block.length, 64);
+ ASSERT_EQ(block.popcount, 0);
+
+ block = counter_->NextWord();
+ ASSERT_EQ(block.length, kWordSize - offset - 1);
+ ASSERT_EQ(block.length, block.popcount);
+
+ // We can keep calling NextWord safely
+ block = counter_->NextWord();
+ ASSERT_EQ(block.length, 0);
+ ASSERT_EQ(block.popcount, 0);
+ };
+
+ for (int64_t offset_i = 0; offset_i < 8; ++offset_i) {
+ CheckWithOffset(offset_i);
+ }
+}
+
+TEST_F(TestBitBlockCounter, FourWordsWithOffsets) {
+ auto CheckWithOffset = [&](int64_t offset) {
+ const int64_t nwords = 17;
+
+ const int64_t total_bytes = nwords * 8 + 1;
+ // Trim a bit from the end of the bitmap so we can check the remainder bits
+ // behavior
+ Create(total_bytes, offset, nwords * kWordSize - offset - 1);
+
+ // Start with data all set
+ std::memset(buf_->mutable_data(), 0xFF, total_bytes);
+
+ BitBlockCount block = counter_->NextFourWords();
+ ASSERT_EQ(block.length, 4 * kWordSize);
+ ASSERT_EQ(block.popcount, block.length);
+
+ // Add some false values to the next 3 shifted words
+ BitUtil::SetBitTo(buf_->mutable_data(), 4 * kWordSize + offset, false);
+ BitUtil::SetBitTo(buf_->mutable_data(), 5 * kWordSize + offset, false);
+ BitUtil::SetBitTo(buf_->mutable_data(), 6 * kWordSize + offset, false);
+ block = counter_->NextFourWords();
+
+ ASSERT_EQ(block.length, 4 * kWordSize);
+ ASSERT_EQ(block.popcount, 253);
+
+ // Set the next two words to all false
+ BitUtil::SetBitsTo(buf_->mutable_data(), 8 * kWordSize + offset, 2 * kWordSize,
+ false);
+
+ // Block is half set
+ block = counter_->NextFourWords();
+ ASSERT_EQ(block.length, 4 * kWordSize);
+ ASSERT_EQ(block.popcount, 128);
+
+ // Last full block whether offset or no
+ block = counter_->NextFourWords();
+ ASSERT_EQ(block.length, 4 * kWordSize);
+ ASSERT_EQ(block.length, block.popcount);
+
+ // Partial block
+ block = counter_->NextFourWords();
+ ASSERT_EQ(block.length, kWordSize - offset - 1);
+ ASSERT_EQ(block.length, block.popcount);
+
+ // We can keep calling NextFourWords safely
+ block = counter_->NextFourWords();
+ ASSERT_EQ(block.length, 0);
+ ASSERT_EQ(block.popcount, 0);
+ };
+
+ for (int64_t offset_i = 0; offset_i < 8; ++offset_i) {
+ CheckWithOffset(offset_i);
+ }
+}
+
+TEST_F(TestBitBlockCounter, FourWordsRandomData) {
+ const int64_t nbytes = 1024;
+ auto buffer = *AllocateBuffer(nbytes);
+ random_bytes(nbytes, 0, buffer->mutable_data());
+
+ auto CheckWithOffset = [&](int64_t offset) {
+ BitBlockCounter counter(buffer->data(), offset, nbytes * 8 - offset);
+ for (int64_t i = 0; i < nbytes / 32; ++i) {
+ BitBlockCount block = counter.NextFourWords();
+ ASSERT_EQ(block.popcount,
+ CountSetBits(buffer->data(), i * 256 + offset, block.length));
+ }
+ };
+ for (int64_t offset_i = 0; offset_i < 8; ++offset_i) {
+ CheckWithOffset(offset_i);
+ }
+}
+
+template <template <typename T> class Op, typename NextWordFunc>
+void CheckBinaryBitBlockOp(NextWordFunc&& get_next_word) {
+ const int64_t nbytes = 1024;
+ auto left = *AllocateBuffer(nbytes);
+ auto right = *AllocateBuffer(nbytes);
+ random_bytes(nbytes, 0, left->mutable_data());
+ random_bytes(nbytes, 0, right->mutable_data());
+
+ auto CheckWithOffsets = [&](int left_offset, int right_offset) {
+ int64_t overlap_length = nbytes * 8 - std::max(left_offset, right_offset);
+ BinaryBitBlockCounter counter(left->data(), left_offset, right->data(), right_offset,
+ overlap_length);
+ int64_t position = 0;
+ do {
+ BitBlockCount block = get_next_word(&counter);
+ int expected_popcount = 0;
+ for (int j = 0; j < block.length; ++j) {
+ expected_popcount += static_cast<int>(
+ Op<bool>::Call(BitUtil::GetBit(left->data(), position + left_offset + j),
+ BitUtil::GetBit(right->data(), position + right_offset + j)));
+ }
+ ASSERT_EQ(block.popcount, expected_popcount);
+ position += block.length;
+ } while (position < overlap_length);
+ // We made it through all the data
+ ASSERT_EQ(position, overlap_length);
+
+ BitBlockCount block = get_next_word(&counter);
+ ASSERT_EQ(block.length, 0);
+ ASSERT_EQ(block.popcount, 0);
+ };
+
+ for (int left_i = 0; left_i < 8; ++left_i) {
+ for (int right_i = 0; right_i < 8; ++right_i) {
+ CheckWithOffsets(left_i, right_i);
+ }
+ }
+}
+
+TEST(TestBinaryBitBlockCounter, NextAndWord) {
+ CheckBinaryBitBlockOp<detail::BitBlockAnd>(
+ [](BinaryBitBlockCounter* counter) { return counter->NextAndWord(); });
+}
+
+TEST(TestBinaryBitBlockCounter, NextOrWord) {
+ CheckBinaryBitBlockOp<detail::BitBlockOr>(
+ [](BinaryBitBlockCounter* counter) { return counter->NextOrWord(); });
+}
+
+TEST(TestBinaryBitBlockCounter, NextOrNotWord) {
+ CheckBinaryBitBlockOp<detail::BitBlockOrNot>(
+ [](BinaryBitBlockCounter* counter) { return counter->NextOrNotWord(); });
+}
+
+TEST(TestOptionalBitBlockCounter, NextBlock) {
+ const int64_t nbytes = 5000;
+ auto bitmap = *AllocateBitmap(nbytes * 8);
+ random_bytes(nbytes, 0, bitmap->mutable_data());
+
+ OptionalBitBlockCounter optional_counter(bitmap, 0, nbytes * 8);
+ BitBlockCounter bit_counter(bitmap->data(), 0, nbytes * 8);
+
+ while (true) {
+ BitBlockCount block = bit_counter.NextWord();
+ BitBlockCount optional_block = optional_counter.NextBlock();
+ ASSERT_EQ(optional_block.length, block.length);
+ ASSERT_EQ(optional_block.popcount, block.popcount);
+ if (block.length == 0) {
+ break;
+ }
+ }
+
+ BitBlockCount optional_block = optional_counter.NextBlock();
+ ASSERT_EQ(optional_block.length, 0);
+ ASSERT_EQ(optional_block.popcount, 0);
+
+ OptionalBitBlockCounter optional_counter_no_bitmap(nullptr, 0, nbytes * 8);
+ BitBlockCount no_bitmap_block = optional_counter_no_bitmap.NextBlock();
+
+ int16_t max_length = std::numeric_limits<int16_t>::max();
+ ASSERT_EQ(no_bitmap_block.length, max_length);
+ ASSERT_EQ(no_bitmap_block.popcount, max_length);
+ no_bitmap_block = optional_counter_no_bitmap.NextBlock();
+ ASSERT_EQ(no_bitmap_block.length, nbytes * 8 - max_length);
+ ASSERT_EQ(no_bitmap_block.popcount, no_bitmap_block.length);
+}
+
+TEST(TestOptionalBitBlockCounter, NextWord) {
+ const int64_t nbytes = 5000;
+ auto bitmap = *AllocateBitmap(nbytes * 8);
+ random_bytes(nbytes, 0, bitmap->mutable_data());
+
+ OptionalBitBlockCounter optional_counter(bitmap, 0, nbytes * 8);
+ OptionalBitBlockCounter optional_counter_no_bitmap(nullptr, 0, nbytes * 8);
+ BitBlockCounter bit_counter(bitmap->data(), 0, nbytes * 8);
+
+ while (true) {
+ BitBlockCount block = bit_counter.NextWord();
+ BitBlockCount no_bitmap_block = optional_counter_no_bitmap.NextWord();
+ BitBlockCount optional_block = optional_counter.NextWord();
+ ASSERT_EQ(optional_block.length, block.length);
+ ASSERT_EQ(optional_block.popcount, block.popcount);
+
+ ASSERT_EQ(no_bitmap_block.length, block.length);
+ ASSERT_EQ(no_bitmap_block.popcount, block.length);
+ if (block.length == 0) {
+ break;
+ }
+ }
+
+ BitBlockCount optional_block = optional_counter.NextWord();
+ ASSERT_EQ(optional_block.length, 0);
+ ASSERT_EQ(optional_block.popcount, 0);
+}
+
+class TestOptionalBinaryBitBlockCounter : public ::testing::Test {
+ public:
+ void SetUp() {
+ const int64_t nbytes = 5000;
+ ASSERT_OK_AND_ASSIGN(left_bitmap_, AllocateBitmap(nbytes * 8));
+ ASSERT_OK_AND_ASSIGN(right_bitmap_, AllocateBitmap(nbytes * 8));
+ random_bytes(nbytes, 0, left_bitmap_->mutable_data());
+ random_bytes(nbytes, 0, right_bitmap_->mutable_data());
+
+ left_offset_ = 12;
+ right_offset_ = 23;
+ length_ = nbytes * 8 - std::max(left_offset_, right_offset_);
+ }
+
+ protected:
+ std::shared_ptr<Buffer> left_bitmap_, right_bitmap_;
+ int64_t left_offset_;
+ int64_t right_offset_;
+ int64_t length_;
+};
+
+TEST_F(TestOptionalBinaryBitBlockCounter, NextBlockBothBitmaps) {
+ // Both bitmaps present
+ OptionalBinaryBitBlockCounter optional_counter(left_bitmap_, left_offset_,
+ right_bitmap_, right_offset_, length_);
+ BinaryBitBlockCounter bit_counter(left_bitmap_->data(), left_offset_,
+ right_bitmap_->data(), right_offset_, length_);
+
+ while (true) {
+ BitBlockCount block = bit_counter.NextAndWord();
+ BitBlockCount optional_block = optional_counter.NextAndBlock();
+ ASSERT_EQ(optional_block.length, block.length);
+ ASSERT_EQ(optional_block.popcount, block.popcount);
+ if (block.length == 0) {
+ break;
+ }
+ }
+}
+
+TEST_F(TestOptionalBinaryBitBlockCounter, NextBlockLeftBitmap) {
+ // Left bitmap present
+ OptionalBinaryBitBlockCounter optional_counter(left_bitmap_, left_offset_, nullptr,
+ right_offset_, length_);
+ BitBlockCounter bit_counter(left_bitmap_->data(), left_offset_, length_);
+
+ while (true) {
+ BitBlockCount block = bit_counter.NextWord();
+ BitBlockCount optional_block = optional_counter.NextAndBlock();
+ ASSERT_EQ(optional_block.length, block.length);
+ ASSERT_EQ(optional_block.popcount, block.popcount);
+ if (block.length == 0) {
+ break;
+ }
+ }
+}
+
+TEST_F(TestOptionalBinaryBitBlockCounter, NextBlockRightBitmap) {
+ // Right bitmap present
+ OptionalBinaryBitBlockCounter optional_counter(nullptr, left_offset_, right_bitmap_,
+ right_offset_, length_);
+ BitBlockCounter bit_counter(right_bitmap_->data(), right_offset_, length_);
+
+ while (true) {
+ BitBlockCount block = bit_counter.NextWord();
+ BitBlockCount optional_block = optional_counter.NextAndBlock();
+ ASSERT_EQ(optional_block.length, block.length);
+ ASSERT_EQ(optional_block.popcount, block.popcount);
+ if (block.length == 0) {
+ break;
+ }
+ }
+}
+
+TEST_F(TestOptionalBinaryBitBlockCounter, NextBlockNoBitmap) {
+ // No bitmap present
+ OptionalBinaryBitBlockCounter optional_counter(nullptr, left_offset_, nullptr,
+ right_offset_, length_);
+
+ BitBlockCount block = optional_counter.NextAndBlock();
+ ASSERT_EQ(block.length, std::numeric_limits<int16_t>::max());
+ ASSERT_EQ(block.popcount, block.length);
+
+ const int64_t remaining_length = length_ - block.length;
+ block = optional_counter.NextAndBlock();
+ ASSERT_EQ(block.length, remaining_length);
+ ASSERT_EQ(block.popcount, block.length);
+
+ block = optional_counter.NextAndBlock();
+ ASSERT_EQ(block.length, 0);
+ ASSERT_EQ(block.popcount, 0);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bit_run_reader.cc b/src/arrow/cpp/src/arrow/util/bit_run_reader.cc
new file mode 100644
index 000000000..eda6088eb
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bit_run_reader.cc
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/bit_run_reader.h"
+
+#include <cstdint>
+
+#include "arrow/util/bit_util.h"
+
+namespace arrow {
+namespace internal {
+
+#if ARROW_LITTLE_ENDIAN
+
+BitRunReader::BitRunReader(const uint8_t* bitmap, int64_t start_offset, int64_t length)
+ : bitmap_(bitmap + (start_offset / 8)),
+ position_(start_offset % 8),
+ length_(position_ + length) {
+ if (ARROW_PREDICT_FALSE(length == 0)) {
+ word_ = 0;
+ return;
+ }
+
+ // On the initial load if there is an offset we need to account for this when
+ // loading bytes. Every other call to LoadWord() should only occur when
+ // position_ is a multiple of 64.
+ current_run_bit_set_ = !BitUtil::GetBit(bitmap, start_offset);
+ int64_t bits_remaining = length + position_;
+
+ LoadWord(bits_remaining);
+
+ // Prepare for inversion in NextRun.
+ // Clear out any preceding bits.
+ word_ = word_ & ~BitUtil::LeastSignificantBitMask(position_);
+}
+
+#endif
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bit_run_reader.h b/src/arrow/cpp/src/arrow/util/bit_run_reader.h
new file mode 100644
index 000000000..ed9f4fa86
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bit_run_reader.h
@@ -0,0 +1,515 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+#include <string>
+
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace internal {
+
+struct BitRun {
+ int64_t length;
+ // Whether bits are set at this point.
+ bool set;
+
+ std::string ToString() const {
+ return std::string("{Length: ") + std::to_string(length) +
+ ", set=" + std::to_string(set) + "}";
+ }
+};
+
+inline bool operator==(const BitRun& lhs, const BitRun& rhs) {
+ return lhs.length == rhs.length && lhs.set == rhs.set;
+}
+
+inline bool operator!=(const BitRun& lhs, const BitRun& rhs) {
+ return lhs.length != rhs.length || lhs.set != rhs.set;
+}
+
+class BitRunReaderLinear {
+ public:
+ BitRunReaderLinear(const uint8_t* bitmap, int64_t start_offset, int64_t length)
+ : reader_(bitmap, start_offset, length) {}
+
+ BitRun NextRun() {
+ BitRun rl = {/*length=*/0, reader_.IsSet()};
+ // Advance while the values are equal and not at the end of list.
+ while (reader_.position() < reader_.length() && reader_.IsSet() == rl.set) {
+ rl.length++;
+ reader_.Next();
+ }
+ return rl;
+ }
+
+ private:
+ BitmapReader reader_;
+};
+
+#if ARROW_LITTLE_ENDIAN
+/// A convenience class for counting the number of contiguous set/unset bits
+/// in a bitmap.
+class ARROW_EXPORT BitRunReader {
+ public:
+ /// \brief Constructs new BitRunReader.
+ ///
+ /// \param[in] bitmap source data
+ /// \param[in] start_offset bit offset into the source data
+ /// \param[in] length number of bits to copy
+ BitRunReader(const uint8_t* bitmap, int64_t start_offset, int64_t length);
+
+ /// Returns a new BitRun containing the number of contiguous
+ /// bits with the same value. length == 0 indicates the
+ /// end of the bitmap.
+ BitRun NextRun() {
+ if (ARROW_PREDICT_FALSE(position_ >= length_)) {
+ return {/*length=*/0, false};
+ }
+ // This implementation relies on a efficient implementations of
+ // CountTrailingZeros and assumes that runs are more often then
+ // not. The logic is to incrementally find the next bit change
+ // from the current position. This is done by zeroing all
+ // bits in word_ up to position_ and using the TrailingZeroCount
+ // to find the index of the next set bit.
+
+ // The runs alternate on each call, so flip the bit.
+ current_run_bit_set_ = !current_run_bit_set_;
+
+ int64_t start_position = position_;
+ int64_t start_bit_offset = start_position & 63;
+ // Invert the word for proper use of CountTrailingZeros and
+ // clear bits so CountTrailingZeros can do it magic.
+ word_ = ~word_ & ~BitUtil::LeastSignificantBitMask(start_bit_offset);
+
+ // Go forward until the next change from unset to set.
+ int64_t new_bits = BitUtil::CountTrailingZeros(word_) - start_bit_offset;
+ position_ += new_bits;
+
+ if (ARROW_PREDICT_FALSE(BitUtil::IsMultipleOf64(position_)) &&
+ ARROW_PREDICT_TRUE(position_ < length_)) {
+ // Continue extending position while we can advance an entire word.
+ // (updates position_ accordingly).
+ AdvanceUntilChange();
+ }
+
+ return {/*length=*/position_ - start_position, current_run_bit_set_};
+ }
+
+ private:
+ void AdvanceUntilChange() {
+ int64_t new_bits = 0;
+ do {
+ // Advance the position of the bitmap for loading.
+ bitmap_ += sizeof(uint64_t);
+ LoadNextWord();
+ new_bits = BitUtil::CountTrailingZeros(word_);
+ // Continue calculating run length.
+ position_ += new_bits;
+ } while (ARROW_PREDICT_FALSE(BitUtil::IsMultipleOf64(position_)) &&
+ ARROW_PREDICT_TRUE(position_ < length_) && new_bits > 0);
+ }
+
+ void LoadNextWord() { return LoadWord(length_ - position_); }
+
+ // Helper method for Loading the next word.
+ void LoadWord(int64_t bits_remaining) {
+ word_ = 0;
+ // we need at least an extra byte in this case.
+ if (ARROW_PREDICT_TRUE(bits_remaining >= 64)) {
+ std::memcpy(&word_, bitmap_, 8);
+ } else {
+ int64_t bytes_to_load = BitUtil::BytesForBits(bits_remaining);
+ auto word_ptr = reinterpret_cast<uint8_t*>(&word_);
+ std::memcpy(word_ptr, bitmap_, bytes_to_load);
+ // Ensure stoppage at last bit in bitmap by reversing the next higher
+ // order bit.
+ BitUtil::SetBitTo(word_ptr, bits_remaining,
+ !BitUtil::GetBit(word_ptr, bits_remaining - 1));
+ }
+
+ // Two cases:
+ // 1. For unset, CountTrailingZeros works naturally so we don't
+ // invert the word.
+ // 2. Otherwise invert so we can use CountTrailingZeros.
+ if (current_run_bit_set_) {
+ word_ = ~word_;
+ }
+ }
+ const uint8_t* bitmap_;
+ int64_t position_;
+ int64_t length_;
+ uint64_t word_;
+ bool current_run_bit_set_;
+};
+#else
+using BitRunReader = BitRunReaderLinear;
+#endif
+
+struct SetBitRun {
+ int64_t position;
+ int64_t length;
+
+ bool AtEnd() const { return length == 0; }
+
+ std::string ToString() const {
+ return std::string("{pos=") + std::to_string(position) +
+ ", len=" + std::to_string(length) + "}";
+ }
+
+ bool operator==(const SetBitRun& other) const {
+ return position == other.position && length == other.length;
+ }
+ bool operator!=(const SetBitRun& other) const {
+ return position != other.position || length != other.length;
+ }
+};
+
+template <bool Reverse>
+class BaseSetBitRunReader {
+ public:
+ /// \brief Constructs new SetBitRunReader.
+ ///
+ /// \param[in] bitmap source data
+ /// \param[in] start_offset bit offset into the source data
+ /// \param[in] length number of bits to copy
+ ARROW_NOINLINE
+ BaseSetBitRunReader(const uint8_t* bitmap, int64_t start_offset, int64_t length)
+ : bitmap_(util::MakeNonNull(bitmap)),
+ length_(length),
+ remaining_(length_),
+ current_word_(0),
+ current_num_bits_(0) {
+ if (Reverse) {
+ bitmap_ += (start_offset + length) / 8;
+ const int8_t end_bit_offset = static_cast<int8_t>((start_offset + length) % 8);
+ if (length > 0 && end_bit_offset) {
+ // Get LSBs from last byte
+ ++bitmap_;
+ current_num_bits_ =
+ std::min(static_cast<int32_t>(length), static_cast<int32_t>(end_bit_offset));
+ current_word_ = LoadPartialWord(8 - end_bit_offset, current_num_bits_);
+ }
+ } else {
+ bitmap_ += start_offset / 8;
+ const int8_t bit_offset = static_cast<int8_t>(start_offset % 8);
+ if (length > 0 && bit_offset) {
+ // Get MSBs from first byte
+ current_num_bits_ =
+ std::min(static_cast<int32_t>(length), static_cast<int32_t>(8 - bit_offset));
+ current_word_ = LoadPartialWord(bit_offset, current_num_bits_);
+ }
+ }
+ }
+
+ ARROW_NOINLINE
+ SetBitRun NextRun() {
+ int64_t pos = 0;
+ int64_t len = 0;
+ if (current_num_bits_) {
+ const auto run = FindCurrentRun();
+ assert(remaining_ >= 0);
+ if (run.length && current_num_bits_) {
+ // The run ends in current_word_
+ return AdjustRun(run);
+ }
+ pos = run.position;
+ len = run.length;
+ }
+ if (!len) {
+ // We didn't get any ones in current_word_, so we can skip any zeros
+ // in the following words
+ SkipNextZeros();
+ if (remaining_ == 0) {
+ return {0, 0};
+ }
+ assert(current_num_bits_);
+ pos = position();
+ } else if (!current_num_bits_) {
+ if (ARROW_PREDICT_TRUE(remaining_ >= 64)) {
+ current_word_ = LoadFullWord();
+ current_num_bits_ = 64;
+ } else if (remaining_ > 0) {
+ current_word_ = LoadPartialWord(/*bit_offset=*/0, remaining_);
+ current_num_bits_ = static_cast<int32_t>(remaining_);
+ } else {
+ // No bits remaining, perhaps we found a run?
+ return AdjustRun({pos, len});
+ }
+ // If current word starts with a zero, we got a full run
+ if (!(current_word_ & kFirstBit)) {
+ return AdjustRun({pos, len});
+ }
+ }
+ // Current word should now start with a set bit
+ len += CountNextOnes();
+ return AdjustRun({pos, len});
+ }
+
+ protected:
+ int64_t position() const {
+ if (Reverse) {
+ return remaining_;
+ } else {
+ return length_ - remaining_;
+ }
+ }
+
+ SetBitRun AdjustRun(SetBitRun run) {
+ if (Reverse) {
+ assert(run.position >= run.length);
+ run.position -= run.length;
+ }
+ return run;
+ }
+
+ uint64_t LoadFullWord() {
+ uint64_t word;
+ if (Reverse) {
+ bitmap_ -= 8;
+ }
+ memcpy(&word, bitmap_, 8);
+ if (!Reverse) {
+ bitmap_ += 8;
+ }
+ return BitUtil::ToLittleEndian(word);
+ }
+
+ uint64_t LoadPartialWord(int8_t bit_offset, int64_t num_bits) {
+ assert(num_bits > 0);
+ uint64_t word = 0;
+ const int64_t num_bytes = BitUtil::BytesForBits(num_bits);
+ if (Reverse) {
+ // Read in the most significant bytes of the word
+ bitmap_ -= num_bytes;
+ memcpy(reinterpret_cast<char*>(&word) + 8 - num_bytes, bitmap_, num_bytes);
+ // XXX MostSignificantBitmask
+ return (BitUtil::ToLittleEndian(word) << bit_offset) &
+ ~BitUtil::LeastSignificantBitMask(64 - num_bits);
+ } else {
+ memcpy(&word, bitmap_, num_bytes);
+ bitmap_ += num_bytes;
+ return (BitUtil::ToLittleEndian(word) >> bit_offset) &
+ BitUtil::LeastSignificantBitMask(num_bits);
+ }
+ }
+
+ void SkipNextZeros() {
+ assert(current_num_bits_ == 0);
+ while (ARROW_PREDICT_TRUE(remaining_ >= 64)) {
+ current_word_ = LoadFullWord();
+ const auto num_zeros = CountFirstZeros(current_word_);
+ if (num_zeros < 64) {
+ // Run of zeros ends here
+ current_word_ = ConsumeBits(current_word_, num_zeros);
+ current_num_bits_ = 64 - num_zeros;
+ remaining_ -= num_zeros;
+ assert(remaining_ >= 0);
+ assert(current_num_bits_ >= 0);
+ return;
+ }
+ remaining_ -= 64;
+ }
+ // Run of zeros continues in last bitmap word
+ if (remaining_ > 0) {
+ current_word_ = LoadPartialWord(/*bit_offset=*/0, remaining_);
+ current_num_bits_ = static_cast<int32_t>(remaining_);
+ const auto num_zeros =
+ std::min<int32_t>(current_num_bits_, CountFirstZeros(current_word_));
+ current_word_ = ConsumeBits(current_word_, num_zeros);
+ current_num_bits_ -= num_zeros;
+ remaining_ -= num_zeros;
+ assert(remaining_ >= 0);
+ assert(current_num_bits_ >= 0);
+ }
+ }
+
+ int64_t CountNextOnes() {
+ assert(current_word_ & kFirstBit);
+
+ int64_t len;
+ if (~current_word_) {
+ const auto num_ones = CountFirstZeros(~current_word_);
+ assert(num_ones <= current_num_bits_);
+ assert(num_ones <= remaining_);
+ remaining_ -= num_ones;
+ current_word_ = ConsumeBits(current_word_, num_ones);
+ current_num_bits_ -= num_ones;
+ if (current_num_bits_) {
+ // Run of ones ends here
+ return num_ones;
+ }
+ len = num_ones;
+ } else {
+ // current_word_ is all ones
+ remaining_ -= 64;
+ current_num_bits_ = 0;
+ len = 64;
+ }
+
+ while (ARROW_PREDICT_TRUE(remaining_ >= 64)) {
+ current_word_ = LoadFullWord();
+ const auto num_ones = CountFirstZeros(~current_word_);
+ len += num_ones;
+ remaining_ -= num_ones;
+ if (num_ones < 64) {
+ // Run of ones ends here
+ current_word_ = ConsumeBits(current_word_, num_ones);
+ current_num_bits_ = 64 - num_ones;
+ return len;
+ }
+ }
+ // Run of ones continues in last bitmap word
+ if (remaining_ > 0) {
+ current_word_ = LoadPartialWord(/*bit_offset=*/0, remaining_);
+ current_num_bits_ = static_cast<int32_t>(remaining_);
+ const auto num_ones = CountFirstZeros(~current_word_);
+ assert(num_ones <= current_num_bits_);
+ assert(num_ones <= remaining_);
+ current_word_ = ConsumeBits(current_word_, num_ones);
+ current_num_bits_ -= num_ones;
+ remaining_ -= num_ones;
+ len += num_ones;
+ }
+ return len;
+ }
+
+ SetBitRun FindCurrentRun() {
+ // Skip any pending zeros
+ const auto num_zeros = CountFirstZeros(current_word_);
+ if (num_zeros >= current_num_bits_) {
+ remaining_ -= current_num_bits_;
+ current_word_ = 0;
+ current_num_bits_ = 0;
+ return {0, 0};
+ }
+ assert(num_zeros <= remaining_);
+ current_word_ = ConsumeBits(current_word_, num_zeros);
+ current_num_bits_ -= num_zeros;
+ remaining_ -= num_zeros;
+ const int64_t pos = position();
+ // Count any ones
+ const auto num_ones = CountFirstZeros(~current_word_);
+ assert(num_ones <= current_num_bits_);
+ assert(num_ones <= remaining_);
+ current_word_ = ConsumeBits(current_word_, num_ones);
+ current_num_bits_ -= num_ones;
+ remaining_ -= num_ones;
+ return {pos, num_ones};
+ }
+
+ inline int CountFirstZeros(uint64_t word);
+ inline uint64_t ConsumeBits(uint64_t word, int32_t num_bits);
+
+ const uint8_t* bitmap_;
+ const int64_t length_;
+ int64_t remaining_;
+ uint64_t current_word_;
+ int32_t current_num_bits_;
+
+ static constexpr uint64_t kFirstBit = Reverse ? 0x8000000000000000ULL : 1;
+};
+
+template <>
+inline int BaseSetBitRunReader<false>::CountFirstZeros(uint64_t word) {
+ return BitUtil::CountTrailingZeros(word);
+}
+
+template <>
+inline int BaseSetBitRunReader<true>::CountFirstZeros(uint64_t word) {
+ return BitUtil::CountLeadingZeros(word);
+}
+
+template <>
+inline uint64_t BaseSetBitRunReader<false>::ConsumeBits(uint64_t word, int32_t num_bits) {
+ return word >> num_bits;
+}
+
+template <>
+inline uint64_t BaseSetBitRunReader<true>::ConsumeBits(uint64_t word, int32_t num_bits) {
+ return word << num_bits;
+}
+
+using SetBitRunReader = BaseSetBitRunReader</*Reverse=*/false>;
+using ReverseSetBitRunReader = BaseSetBitRunReader</*Reverse=*/true>;
+
+// Functional-style bit run visitors.
+
+// XXX: Try to make this function small so the compiler can inline and optimize
+// the `visit` function, which is normally a hot loop with vectorizable code.
+// - don't inline SetBitRunReader constructor, it doesn't hurt performance
+// - un-inline NextRun hurts 'many null' cases a bit, but improves normal cases
+template <typename Visit>
+inline Status VisitSetBitRuns(const uint8_t* bitmap, int64_t offset, int64_t length,
+ Visit&& visit) {
+ if (bitmap == NULLPTR) {
+ // Assuming all set (as in a null bitmap)
+ return visit(static_cast<int64_t>(0), static_cast<int64_t>(length));
+ }
+ SetBitRunReader reader(bitmap, offset, length);
+ while (true) {
+ const auto run = reader.NextRun();
+ if (run.length == 0) {
+ break;
+ }
+ ARROW_RETURN_NOT_OK(visit(run.position, run.length));
+ }
+ return Status::OK();
+}
+
+template <typename Visit>
+inline void VisitSetBitRunsVoid(const uint8_t* bitmap, int64_t offset, int64_t length,
+ Visit&& visit) {
+ if (bitmap == NULLPTR) {
+ // Assuming all set (as in a null bitmap)
+ visit(static_cast<int64_t>(0), static_cast<int64_t>(length));
+ return;
+ }
+ SetBitRunReader reader(bitmap, offset, length);
+ while (true) {
+ const auto run = reader.NextRun();
+ if (run.length == 0) {
+ break;
+ }
+ visit(run.position, run.length);
+ }
+}
+
+template <typename Visit>
+inline Status VisitSetBitRuns(const std::shared_ptr<Buffer>& bitmap, int64_t offset,
+ int64_t length, Visit&& visit) {
+ return VisitSetBitRuns(bitmap ? bitmap->data() : NULLPTR, offset, length,
+ std::forward<Visit>(visit));
+}
+
+template <typename Visit>
+inline void VisitSetBitRunsVoid(const std::shared_ptr<Buffer>& bitmap, int64_t offset,
+ int64_t length, Visit&& visit) {
+ VisitSetBitRunsVoid(bitmap ? bitmap->data() : NULLPTR, offset, length,
+ std::forward<Visit>(visit));
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bit_stream_utils.h b/src/arrow/cpp/src/arrow/util/bit_stream_utils.h
new file mode 100644
index 000000000..49f602ed8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bit_stream_utils.h
@@ -0,0 +1,513 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// From Apache Impala (incubating) as of 2016-01-29
+
+#pragma once
+
+#include <string.h>
+
+#include <algorithm>
+#include <cstdint>
+
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bpacking.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/ubsan.h"
+
+namespace arrow {
+namespace BitUtil {
+
+/// Utility class to write bit/byte streams. This class can write data to either be
+/// bit packed or byte aligned (and a single stream that has a mix of both).
+/// This class does not allocate memory.
+class BitWriter {
+ public:
+ /// buffer: buffer to write bits to. Buffer should be preallocated with
+ /// 'buffer_len' bytes.
+ BitWriter(uint8_t* buffer, int buffer_len) : buffer_(buffer), max_bytes_(buffer_len) {
+ Clear();
+ }
+
+ void Clear() {
+ buffered_values_ = 0;
+ byte_offset_ = 0;
+ bit_offset_ = 0;
+ }
+
+ /// The number of current bytes written, including the current byte (i.e. may include a
+ /// fraction of a byte). Includes buffered values.
+ int bytes_written() const {
+ return byte_offset_ + static_cast<int>(BitUtil::BytesForBits(bit_offset_));
+ }
+ uint8_t* buffer() const { return buffer_; }
+ int buffer_len() const { return max_bytes_; }
+
+ /// Writes a value to buffered_values_, flushing to buffer_ if necessary. This is bit
+ /// packed. Returns false if there was not enough space. num_bits must be <= 32.
+ bool PutValue(uint64_t v, int num_bits);
+
+ /// Writes v to the next aligned byte using num_bytes. If T is larger than
+ /// num_bytes, the extra high-order bytes will be ignored. Returns false if
+ /// there was not enough space.
+ /// Assume the v is stored in buffer_ as a litte-endian format
+ template <typename T>
+ bool PutAligned(T v, int num_bytes);
+
+ /// Write a Vlq encoded int to the buffer. Returns false if there was not enough
+ /// room. The value is written byte aligned.
+ /// For more details on vlq:
+ /// en.wikipedia.org/wiki/Variable-length_quantity
+ bool PutVlqInt(uint32_t v);
+
+ // Writes an int zigzag encoded.
+ bool PutZigZagVlqInt(int32_t v);
+
+ /// Write a Vlq encoded int64 to the buffer. Returns false if there was not enough
+ /// room. The value is written byte aligned.
+ /// For more details on vlq:
+ /// en.wikipedia.org/wiki/Variable-length_quantity
+ bool PutVlqInt(uint64_t v);
+
+ // Writes an int64 zigzag encoded.
+ bool PutZigZagVlqInt(int64_t v);
+
+ /// Get a pointer to the next aligned byte and advance the underlying buffer
+ /// by num_bytes.
+ /// Returns NULL if there was not enough space.
+ uint8_t* GetNextBytePtr(int num_bytes = 1);
+
+ /// Flushes all buffered values to the buffer. Call this when done writing to
+ /// the buffer. If 'align' is true, buffered_values_ is reset and any future
+ /// writes will be written to the next byte boundary.
+ void Flush(bool align = false);
+
+ private:
+ uint8_t* buffer_;
+ int max_bytes_;
+
+ /// Bit-packed values are initially written to this variable before being memcpy'd to
+ /// buffer_. This is faster than writing values byte by byte directly to buffer_.
+ uint64_t buffered_values_;
+
+ int byte_offset_; // Offset in buffer_
+ int bit_offset_; // Offset in buffered_values_
+};
+
+/// Utility class to read bit/byte stream. This class can read bits or bytes
+/// that are either byte aligned or not. It also has utilities to read multiple
+/// bytes in one read (e.g. encoded int).
+class BitReader {
+ public:
+ /// 'buffer' is the buffer to read from. The buffer's length is 'buffer_len'.
+ BitReader(const uint8_t* buffer, int buffer_len)
+ : buffer_(buffer), max_bytes_(buffer_len), byte_offset_(0), bit_offset_(0) {
+ int num_bytes = std::min(8, max_bytes_ - byte_offset_);
+ memcpy(&buffered_values_, buffer_ + byte_offset_, num_bytes);
+ buffered_values_ = arrow::BitUtil::FromLittleEndian(buffered_values_);
+ }
+
+ BitReader()
+ : buffer_(NULL),
+ max_bytes_(0),
+ buffered_values_(0),
+ byte_offset_(0),
+ bit_offset_(0) {}
+
+ void Reset(const uint8_t* buffer, int buffer_len) {
+ buffer_ = buffer;
+ max_bytes_ = buffer_len;
+ byte_offset_ = 0;
+ bit_offset_ = 0;
+ int num_bytes = std::min(8, max_bytes_ - byte_offset_);
+ memcpy(&buffered_values_, buffer_ + byte_offset_, num_bytes);
+ buffered_values_ = arrow::BitUtil::FromLittleEndian(buffered_values_);
+ }
+
+ /// Gets the next value from the buffer. Returns true if 'v' could be read or false if
+ /// there are not enough bytes left. num_bits must be <= 32.
+ template <typename T>
+ bool GetValue(int num_bits, T* v);
+
+ /// Get a number of values from the buffer. Return the number of values actually read.
+ template <typename T>
+ int GetBatch(int num_bits, T* v, int batch_size);
+
+ /// Reads a 'num_bytes'-sized value from the buffer and stores it in 'v'. T
+ /// needs to be a little-endian native type and big enough to store
+ /// 'num_bytes'. The value is assumed to be byte-aligned so the stream will
+ /// be advanced to the start of the next byte before 'v' is read. Returns
+ /// false if there are not enough bytes left.
+ /// Assume the v was stored in buffer_ as a litte-endian format
+ template <typename T>
+ bool GetAligned(int num_bytes, T* v);
+
+ /// Reads a vlq encoded int from the stream. The encoded int must start at
+ /// the beginning of a byte. Return false if there were not enough bytes in
+ /// the buffer.
+ bool GetVlqInt(uint32_t* v);
+
+ // Reads a zigzag encoded int `into` v.
+ bool GetZigZagVlqInt(int32_t* v);
+
+ /// Reads a vlq encoded int64 from the stream. The encoded int must start at
+ /// the beginning of a byte. Return false if there were not enough bytes in
+ /// the buffer.
+ bool GetVlqInt(uint64_t* v);
+
+ // Reads a zigzag encoded int64 `into` v.
+ bool GetZigZagVlqInt(int64_t* v);
+
+ /// Returns the number of bytes left in the stream, not including the current
+ /// byte (i.e., there may be an additional fraction of a byte).
+ int bytes_left() {
+ return max_bytes_ -
+ (byte_offset_ + static_cast<int>(BitUtil::BytesForBits(bit_offset_)));
+ }
+
+ /// Maximum byte length of a vlq encoded int
+ static constexpr int kMaxVlqByteLength = 5;
+
+ /// Maximum byte length of a vlq encoded int64
+ static constexpr int kMaxVlqByteLengthForInt64 = 10;
+
+ private:
+ const uint8_t* buffer_;
+ int max_bytes_;
+
+ /// Bytes are memcpy'd from buffer_ and values are read from this variable. This is
+ /// faster than reading values byte by byte directly from buffer_.
+ uint64_t buffered_values_;
+
+ int byte_offset_; // Offset in buffer_
+ int bit_offset_; // Offset in buffered_values_
+};
+
+inline bool BitWriter::PutValue(uint64_t v, int num_bits) {
+ // TODO: revisit this limit if necessary (can be raised to 64 by fixing some edge cases)
+ DCHECK_LE(num_bits, 32);
+ DCHECK_EQ(v >> num_bits, 0) << "v = " << v << ", num_bits = " << num_bits;
+
+ if (ARROW_PREDICT_FALSE(byte_offset_ * 8 + bit_offset_ + num_bits > max_bytes_ * 8))
+ return false;
+
+ buffered_values_ |= v << bit_offset_;
+ bit_offset_ += num_bits;
+
+ if (ARROW_PREDICT_FALSE(bit_offset_ >= 64)) {
+ // Flush buffered_values_ and write out bits of v that did not fit
+ buffered_values_ = arrow::BitUtil::ToLittleEndian(buffered_values_);
+ memcpy(buffer_ + byte_offset_, &buffered_values_, 8);
+ buffered_values_ = 0;
+ byte_offset_ += 8;
+ bit_offset_ -= 64;
+ buffered_values_ = v >> (num_bits - bit_offset_);
+ }
+ DCHECK_LT(bit_offset_, 64);
+ return true;
+}
+
+inline void BitWriter::Flush(bool align) {
+ int num_bytes = static_cast<int>(BitUtil::BytesForBits(bit_offset_));
+ DCHECK_LE(byte_offset_ + num_bytes, max_bytes_);
+ auto buffered_values = arrow::BitUtil::ToLittleEndian(buffered_values_);
+ memcpy(buffer_ + byte_offset_, &buffered_values, num_bytes);
+
+ if (align) {
+ buffered_values_ = 0;
+ byte_offset_ += num_bytes;
+ bit_offset_ = 0;
+ }
+}
+
+inline uint8_t* BitWriter::GetNextBytePtr(int num_bytes) {
+ Flush(/* align */ true);
+ DCHECK_LE(byte_offset_, max_bytes_);
+ if (byte_offset_ + num_bytes > max_bytes_) return NULL;
+ uint8_t* ptr = buffer_ + byte_offset_;
+ byte_offset_ += num_bytes;
+ return ptr;
+}
+
+template <typename T>
+inline bool BitWriter::PutAligned(T val, int num_bytes) {
+ uint8_t* ptr = GetNextBytePtr(num_bytes);
+ if (ptr == NULL) return false;
+ val = arrow::BitUtil::ToLittleEndian(val);
+ memcpy(ptr, &val, num_bytes);
+ return true;
+}
+
+namespace detail {
+
+template <typename T>
+inline void GetValue_(int num_bits, T* v, int max_bytes, const uint8_t* buffer,
+ int* bit_offset, int* byte_offset, uint64_t* buffered_values) {
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4800)
+#endif
+ *v = static_cast<T>(BitUtil::TrailingBits(*buffered_values, *bit_offset + num_bits) >>
+ *bit_offset);
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+ *bit_offset += num_bits;
+ if (*bit_offset >= 64) {
+ *byte_offset += 8;
+ *bit_offset -= 64;
+
+ int bytes_remaining = max_bytes - *byte_offset;
+ if (ARROW_PREDICT_TRUE(bytes_remaining >= 8)) {
+ memcpy(buffered_values, buffer + *byte_offset, 8);
+ } else {
+ memcpy(buffered_values, buffer + *byte_offset, bytes_remaining);
+ }
+ *buffered_values = arrow::BitUtil::FromLittleEndian(*buffered_values);
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4800 4805)
+#endif
+ // Read bits of v that crossed into new buffered_values_
+ if (ARROW_PREDICT_TRUE(num_bits - *bit_offset < static_cast<int>(8 * sizeof(T)))) {
+ // if shift exponent(num_bits - *bit_offset) is not less than sizeof(T), *v will not
+ // change and the following code may cause a runtime error that the shift exponent
+ // is too large
+ *v = *v | static_cast<T>(BitUtil::TrailingBits(*buffered_values, *bit_offset)
+ << (num_bits - *bit_offset));
+ }
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+ DCHECK_LE(*bit_offset, 64);
+ }
+}
+
+} // namespace detail
+
+template <typename T>
+inline bool BitReader::GetValue(int num_bits, T* v) {
+ return GetBatch(num_bits, v, 1) == 1;
+}
+
+template <typename T>
+inline int BitReader::GetBatch(int num_bits, T* v, int batch_size) {
+ DCHECK(buffer_ != NULL);
+ DCHECK_LE(num_bits, static_cast<int>(sizeof(T) * 8));
+
+ int bit_offset = bit_offset_;
+ int byte_offset = byte_offset_;
+ uint64_t buffered_values = buffered_values_;
+ int max_bytes = max_bytes_;
+ const uint8_t* buffer = buffer_;
+
+ uint64_t needed_bits = num_bits * batch_size;
+ constexpr uint64_t kBitsPerByte = 8;
+ uint64_t remaining_bits = (max_bytes - byte_offset) * kBitsPerByte - bit_offset;
+ if (remaining_bits < needed_bits) {
+ batch_size = static_cast<int>(remaining_bits) / num_bits;
+ }
+
+ int i = 0;
+ if (ARROW_PREDICT_FALSE(bit_offset != 0)) {
+ for (; i < batch_size && bit_offset != 0; ++i) {
+ detail::GetValue_(num_bits, &v[i], max_bytes, buffer, &bit_offset, &byte_offset,
+ &buffered_values);
+ }
+ }
+
+ if (sizeof(T) == 4) {
+ int num_unpacked =
+ internal::unpack32(reinterpret_cast<const uint32_t*>(buffer + byte_offset),
+ reinterpret_cast<uint32_t*>(v + i), batch_size - i, num_bits);
+ i += num_unpacked;
+ byte_offset += num_unpacked * num_bits / 8;
+ } else if (sizeof(T) == 8 && num_bits > 32) {
+ // Use unpack64 only if num_bits is larger than 32
+ // TODO (ARROW-13677): improve the performance of internal::unpack64
+ // and remove the restriction of num_bits
+ int num_unpacked =
+ internal::unpack64(buffer + byte_offset, reinterpret_cast<uint64_t*>(v + i),
+ batch_size - i, num_bits);
+ i += num_unpacked;
+ byte_offset += num_unpacked * num_bits / 8;
+ } else {
+ // TODO: revisit this limit if necessary
+ DCHECK_LE(num_bits, 32);
+ const int buffer_size = 1024;
+ uint32_t unpack_buffer[buffer_size];
+ while (i < batch_size) {
+ int unpack_size = std::min(buffer_size, batch_size - i);
+ int num_unpacked =
+ internal::unpack32(reinterpret_cast<const uint32_t*>(buffer + byte_offset),
+ unpack_buffer, unpack_size, num_bits);
+ if (num_unpacked == 0) {
+ break;
+ }
+ for (int k = 0; k < num_unpacked; ++k) {
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4800)
+#endif
+ v[i + k] = static_cast<T>(unpack_buffer[k]);
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+ }
+ i += num_unpacked;
+ byte_offset += num_unpacked * num_bits / 8;
+ }
+ }
+
+ int bytes_remaining = max_bytes - byte_offset;
+ if (bytes_remaining >= 8) {
+ memcpy(&buffered_values, buffer + byte_offset, 8);
+ } else {
+ memcpy(&buffered_values, buffer + byte_offset, bytes_remaining);
+ }
+ buffered_values = arrow::BitUtil::FromLittleEndian(buffered_values);
+
+ for (; i < batch_size; ++i) {
+ detail::GetValue_(num_bits, &v[i], max_bytes, buffer, &bit_offset, &byte_offset,
+ &buffered_values);
+ }
+
+ bit_offset_ = bit_offset;
+ byte_offset_ = byte_offset;
+ buffered_values_ = buffered_values;
+
+ return batch_size;
+}
+
+template <typename T>
+inline bool BitReader::GetAligned(int num_bytes, T* v) {
+ if (ARROW_PREDICT_FALSE(num_bytes > static_cast<int>(sizeof(T)))) {
+ return false;
+ }
+
+ int bytes_read = static_cast<int>(BitUtil::BytesForBits(bit_offset_));
+ if (ARROW_PREDICT_FALSE(byte_offset_ + bytes_read + num_bytes > max_bytes_)) {
+ return false;
+ }
+
+ // Advance byte_offset to next unread byte and read num_bytes
+ byte_offset_ += bytes_read;
+ memcpy(v, buffer_ + byte_offset_, num_bytes);
+ *v = arrow::BitUtil::FromLittleEndian(*v);
+ byte_offset_ += num_bytes;
+
+ // Reset buffered_values_
+ bit_offset_ = 0;
+ int bytes_remaining = max_bytes_ - byte_offset_;
+ if (ARROW_PREDICT_TRUE(bytes_remaining >= 8)) {
+ memcpy(&buffered_values_, buffer_ + byte_offset_, 8);
+ } else {
+ memcpy(&buffered_values_, buffer_ + byte_offset_, bytes_remaining);
+ }
+ buffered_values_ = arrow::BitUtil::FromLittleEndian(buffered_values_);
+ return true;
+}
+
+inline bool BitWriter::PutVlqInt(uint32_t v) {
+ bool result = true;
+ while ((v & 0xFFFFFF80UL) != 0UL) {
+ result &= PutAligned<uint8_t>(static_cast<uint8_t>((v & 0x7F) | 0x80), 1);
+ v >>= 7;
+ }
+ result &= PutAligned<uint8_t>(static_cast<uint8_t>(v & 0x7F), 1);
+ return result;
+}
+
+inline bool BitReader::GetVlqInt(uint32_t* v) {
+ uint32_t tmp = 0;
+
+ for (int i = 0; i < kMaxVlqByteLength; i++) {
+ uint8_t byte = 0;
+ if (ARROW_PREDICT_FALSE(!GetAligned<uint8_t>(1, &byte))) {
+ return false;
+ }
+ tmp |= static_cast<uint32_t>(byte & 0x7F) << (7 * i);
+
+ if ((byte & 0x80) == 0) {
+ *v = tmp;
+ return true;
+ }
+ }
+
+ return false;
+}
+
+inline bool BitWriter::PutZigZagVlqInt(int32_t v) {
+ uint32_t u_v = ::arrow::util::SafeCopy<uint32_t>(v);
+ u_v = (u_v << 1) ^ static_cast<uint32_t>(v >> 31);
+ return PutVlqInt(u_v);
+}
+
+inline bool BitReader::GetZigZagVlqInt(int32_t* v) {
+ uint32_t u;
+ if (!GetVlqInt(&u)) return false;
+ u = (u >> 1) ^ (~(u & 1) + 1);
+ *v = ::arrow::util::SafeCopy<int32_t>(u);
+ return true;
+}
+
+inline bool BitWriter::PutVlqInt(uint64_t v) {
+ bool result = true;
+ while ((v & 0xFFFFFFFFFFFFFF80ULL) != 0ULL) {
+ result &= PutAligned<uint8_t>(static_cast<uint8_t>((v & 0x7F) | 0x80), 1);
+ v >>= 7;
+ }
+ result &= PutAligned<uint8_t>(static_cast<uint8_t>(v & 0x7F), 1);
+ return result;
+}
+
+inline bool BitReader::GetVlqInt(uint64_t* v) {
+ uint64_t tmp = 0;
+
+ for (int i = 0; i < kMaxVlqByteLengthForInt64; i++) {
+ uint8_t byte = 0;
+ if (ARROW_PREDICT_FALSE(!GetAligned<uint8_t>(1, &byte))) {
+ return false;
+ }
+ tmp |= static_cast<uint64_t>(byte & 0x7F) << (7 * i);
+
+ if ((byte & 0x80) == 0) {
+ *v = tmp;
+ return true;
+ }
+ }
+
+ return false;
+}
+
+inline bool BitWriter::PutZigZagVlqInt(int64_t v) {
+ uint64_t u_v = ::arrow::util::SafeCopy<uint64_t>(v);
+ u_v = (u_v << 1) ^ static_cast<uint64_t>(v >> 63);
+ return PutVlqInt(u_v);
+}
+
+inline bool BitReader::GetZigZagVlqInt(int64_t* v) {
+ uint64_t u;
+ if (!GetVlqInt(&u)) return false;
+ u = (u >> 1) ^ (~(u & 1) + 1);
+ *v = ::arrow::util::SafeCopy<int64_t>(u);
+ return true;
+}
+
+} // namespace BitUtil
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bit_util.cc b/src/arrow/cpp/src/arrow/util/bit_util.cc
new file mode 100644
index 000000000..aa78da765
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bit_util.cc
@@ -0,0 +1,129 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/bit_util.h"
+
+#include <cstdint>
+#include <cstring>
+
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace BitUtil {
+
+void SetBitsTo(uint8_t* bits, int64_t start_offset, int64_t length, bool bits_are_set) {
+ if (length == 0) {
+ return;
+ }
+
+ const int64_t i_begin = start_offset;
+ const int64_t i_end = start_offset + length;
+ const uint8_t fill_byte = static_cast<uint8_t>(-static_cast<uint8_t>(bits_are_set));
+
+ const int64_t bytes_begin = i_begin / 8;
+ const int64_t bytes_end = i_end / 8 + 1;
+
+ const uint8_t first_byte_mask = kPrecedingBitmask[i_begin % 8];
+ const uint8_t last_byte_mask = kTrailingBitmask[i_end % 8];
+
+ if (bytes_end == bytes_begin + 1) {
+ // set bits within a single byte
+ const uint8_t only_byte_mask =
+ i_end % 8 == 0 ? first_byte_mask
+ : static_cast<uint8_t>(first_byte_mask | last_byte_mask);
+ bits[bytes_begin] &= only_byte_mask;
+ bits[bytes_begin] |= static_cast<uint8_t>(fill_byte & ~only_byte_mask);
+ return;
+ }
+
+ // set/clear trailing bits of first byte
+ bits[bytes_begin] &= first_byte_mask;
+ bits[bytes_begin] |= static_cast<uint8_t>(fill_byte & ~first_byte_mask);
+
+ if (bytes_end - bytes_begin > 2) {
+ // set/clear whole bytes
+ std::memset(bits + bytes_begin + 1, fill_byte,
+ static_cast<size_t>(bytes_end - bytes_begin - 2));
+ }
+
+ if (i_end % 8 == 0) {
+ return;
+ }
+
+ // set/clear leading bits of last byte
+ bits[bytes_end - 1] &= last_byte_mask;
+ bits[bytes_end - 1] |= static_cast<uint8_t>(fill_byte & ~last_byte_mask);
+}
+
+template <bool value>
+void SetBitmapImpl(uint8_t* data, int64_t offset, int64_t length) {
+ // offset length
+ // data |<------------->|
+ // |--------|...|--------|...|--------|
+ // |<--->| |<--->|
+ // pro epi
+ if (ARROW_PREDICT_FALSE(length == 0)) {
+ return;
+ }
+
+ constexpr uint8_t set_byte = value ? UINT8_MAX : 0;
+
+ auto prologue = static_cast<int32_t>(BitUtil::RoundUp(offset, 8) - offset);
+ DCHECK_LT(prologue, 8);
+
+ if (length < prologue) { // special case where a mask is required
+ // offset length
+ // data |<->|
+ // |--------|...|--------|...
+ // mask --> |111|
+ // |<---->|
+ // pro
+ uint8_t mask = BitUtil::kPrecedingBitmask[8 - prologue] ^
+ BitUtil::kPrecedingBitmask[8 - prologue + length];
+ data[offset / 8] = value ? data[offset / 8] | mask : data[offset / 8] & ~mask;
+ return;
+ }
+
+ // align to a byte boundary
+ data[offset / 8] = BitUtil::SpliceWord(8 - prologue, data[offset / 8], set_byte);
+ offset += prologue;
+ length -= prologue;
+
+ // set values per byte
+ DCHECK_EQ(offset % 8, 0);
+ std::memset(data + offset / 8, set_byte, length / 8);
+ offset += BitUtil::RoundDown(length, 8);
+ length -= BitUtil::RoundDown(length, 8);
+
+ // clean up
+ DCHECK_LT(length, 8);
+ if (length > 0) {
+ data[offset / 8] =
+ BitUtil::SpliceWord(static_cast<int32_t>(length), set_byte, data[offset / 8]);
+ }
+}
+
+void SetBitmap(uint8_t* data, int64_t offset, int64_t length) {
+ SetBitmapImpl<true>(data, offset, length);
+}
+
+void ClearBitmap(uint8_t* data, int64_t offset, int64_t length) {
+ SetBitmapImpl<false>(data, offset, length);
+}
+
+} // namespace BitUtil
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bit_util.h b/src/arrow/cpp/src/arrow/util/bit_util.h
new file mode 100644
index 000000000..c306ce782
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bit_util.h
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#if defined(_MSC_VER)
+#include <intrin.h> // IWYU pragma: keep
+#include <nmmintrin.h>
+#pragma intrinsic(_BitScanReverse)
+#pragma intrinsic(_BitScanForward)
+#define ARROW_POPCOUNT64 __popcnt64
+#define ARROW_POPCOUNT32 __popcnt
+#else
+#define ARROW_POPCOUNT64 __builtin_popcountll
+#define ARROW_POPCOUNT32 __builtin_popcount
+#endif
+
+#include <cstdint>
+#include <type_traits>
+
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace detail {
+
+template <typename Integer>
+typename std::make_unsigned<Integer>::type as_unsigned(Integer x) {
+ return static_cast<typename std::make_unsigned<Integer>::type>(x);
+}
+
+} // namespace detail
+
+namespace BitUtil {
+
+// The number of set bits in a given unsigned byte value, pre-computed
+//
+// Generated with the following Python code
+// output = 'static constexpr uint8_t kBytePopcount[] = {{{0}}};'
+// popcounts = [str(bin(i).count('1')) for i in range(0, 256)]
+// print(output.format(', '.join(popcounts)))
+static constexpr uint8_t kBytePopcount[] = {
+ 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3,
+ 4, 4, 5, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4,
+ 4, 5, 4, 5, 5, 6, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4,
+ 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5,
+ 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2,
+ 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5,
+ 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4,
+ 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6,
+ 4, 5, 5, 6, 5, 6, 6, 7, 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8};
+
+static inline uint64_t PopCount(uint64_t bitmap) { return ARROW_POPCOUNT64(bitmap); }
+static inline uint32_t PopCount(uint32_t bitmap) { return ARROW_POPCOUNT32(bitmap); }
+
+//
+// Bit-related computations on integer values
+//
+
+// Returns the ceil of value/divisor
+constexpr int64_t CeilDiv(int64_t value, int64_t divisor) {
+ return (value == 0) ? 0 : 1 + (value - 1) / divisor;
+}
+
+// Return the number of bytes needed to fit the given number of bits
+constexpr int64_t BytesForBits(int64_t bits) {
+ // This formula avoids integer overflow on very large `bits`
+ return (bits >> 3) + ((bits & 7) != 0);
+}
+
+constexpr bool IsPowerOf2(int64_t value) {
+ return value > 0 && (value & (value - 1)) == 0;
+}
+
+constexpr bool IsPowerOf2(uint64_t value) {
+ return value > 0 && (value & (value - 1)) == 0;
+}
+
+// Returns the smallest power of two that contains v. If v is already a
+// power of two, it is returned as is.
+static inline int64_t NextPower2(int64_t n) {
+ // Taken from
+ // http://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
+ n--;
+ n |= n >> 1;
+ n |= n >> 2;
+ n |= n >> 4;
+ n |= n >> 8;
+ n |= n >> 16;
+ n |= n >> 32;
+ n++;
+ return n;
+}
+
+constexpr bool IsMultipleOf64(int64_t n) { return (n & 63) == 0; }
+
+constexpr bool IsMultipleOf8(int64_t n) { return (n & 7) == 0; }
+
+// Returns a mask for the bit_index lower order bits.
+// Only valid for bit_index in the range [0, 64).
+constexpr uint64_t LeastSignificantBitMask(int64_t bit_index) {
+ return (static_cast<uint64_t>(1) << bit_index) - 1;
+}
+
+// Returns 'value' rounded up to the nearest multiple of 'factor'
+constexpr int64_t RoundUp(int64_t value, int64_t factor) {
+ return CeilDiv(value, factor) * factor;
+}
+
+// Returns 'value' rounded down to the nearest multiple of 'factor'
+constexpr int64_t RoundDown(int64_t value, int64_t factor) {
+ return (value / factor) * factor;
+}
+
+// Returns 'value' rounded up to the nearest multiple of 'factor' when factor
+// is a power of two.
+// The result is undefined on overflow, i.e. if `value > 2**64 - factor`,
+// since we cannot return the correct result which would be 2**64.
+constexpr int64_t RoundUpToPowerOf2(int64_t value, int64_t factor) {
+ // DCHECK(value >= 0);
+ // DCHECK(IsPowerOf2(factor));
+ return (value + (factor - 1)) & ~(factor - 1);
+}
+
+constexpr uint64_t RoundUpToPowerOf2(uint64_t value, uint64_t factor) {
+ // DCHECK(IsPowerOf2(factor));
+ return (value + (factor - 1)) & ~(factor - 1);
+}
+
+constexpr int64_t RoundUpToMultipleOf8(int64_t num) { return RoundUpToPowerOf2(num, 8); }
+
+constexpr int64_t RoundUpToMultipleOf64(int64_t num) {
+ return RoundUpToPowerOf2(num, 64);
+}
+
+// Returns the number of bytes covering a sliced bitmap. Find the length
+// rounded to cover full bytes on both extremities.
+//
+// The following example represents a slice (offset=10, length=9)
+//
+// 0 8 16 24
+// |-------|-------|------|
+// [ ] (slice)
+// [ ] (same slice aligned to bytes bounds, length=16)
+//
+// The covering bytes is the length (in bytes) of this new aligned slice.
+constexpr int64_t CoveringBytes(int64_t offset, int64_t length) {
+ return (BitUtil::RoundUp(length + offset, 8) - BitUtil::RoundDown(offset, 8)) / 8;
+}
+
+// Returns the 'num_bits' least-significant bits of 'v'.
+static inline uint64_t TrailingBits(uint64_t v, int num_bits) {
+ if (ARROW_PREDICT_FALSE(num_bits == 0)) return 0;
+ if (ARROW_PREDICT_FALSE(num_bits >= 64)) return v;
+ int n = 64 - num_bits;
+ return (v << n) >> n;
+}
+
+/// \brief Count the number of leading zeros in an unsigned integer.
+static inline int CountLeadingZeros(uint32_t value) {
+#if defined(__clang__) || defined(__GNUC__)
+ if (value == 0) return 32;
+ return static_cast<int>(__builtin_clz(value));
+#elif defined(_MSC_VER)
+ unsigned long index; // NOLINT
+ if (_BitScanReverse(&index, static_cast<unsigned long>(value))) { // NOLINT
+ return 31 - static_cast<int>(index);
+ } else {
+ return 32;
+ }
+#else
+ int bitpos = 0;
+ while (value != 0) {
+ value >>= 1;
+ ++bitpos;
+ }
+ return 32 - bitpos;
+#endif
+}
+
+static inline int CountLeadingZeros(uint64_t value) {
+#if defined(__clang__) || defined(__GNUC__)
+ if (value == 0) return 64;
+ return static_cast<int>(__builtin_clzll(value));
+#elif defined(_MSC_VER)
+ unsigned long index; // NOLINT
+ if (_BitScanReverse64(&index, value)) { // NOLINT
+ return 63 - static_cast<int>(index);
+ } else {
+ return 64;
+ }
+#else
+ int bitpos = 0;
+ while (value != 0) {
+ value >>= 1;
+ ++bitpos;
+ }
+ return 64 - bitpos;
+#endif
+}
+
+static inline int CountTrailingZeros(uint32_t value) {
+#if defined(__clang__) || defined(__GNUC__)
+ if (value == 0) return 32;
+ return static_cast<int>(__builtin_ctzl(value));
+#elif defined(_MSC_VER)
+ unsigned long index; // NOLINT
+ if (_BitScanForward(&index, value)) {
+ return static_cast<int>(index);
+ } else {
+ return 32;
+ }
+#else
+ int bitpos = 0;
+ if (value) {
+ while (value & 1 == 0) {
+ value >>= 1;
+ ++bitpos;
+ }
+ } else {
+ bitpos = 32;
+ }
+ return bitpos;
+#endif
+}
+
+static inline int CountTrailingZeros(uint64_t value) {
+#if defined(__clang__) || defined(__GNUC__)
+ if (value == 0) return 64;
+ return static_cast<int>(__builtin_ctzll(value));
+#elif defined(_MSC_VER)
+ unsigned long index; // NOLINT
+ if (_BitScanForward64(&index, value)) {
+ return static_cast<int>(index);
+ } else {
+ return 64;
+ }
+#else
+ int bitpos = 0;
+ if (value) {
+ while (value & 1 == 0) {
+ value >>= 1;
+ ++bitpos;
+ }
+ } else {
+ bitpos = 64;
+ }
+ return bitpos;
+#endif
+}
+
+// Returns the minimum number of bits needed to represent an unsigned value
+static inline int NumRequiredBits(uint64_t x) { return 64 - CountLeadingZeros(x); }
+
+// Returns ceil(log2(x)).
+static inline int Log2(uint64_t x) {
+ // DCHECK_GT(x, 0);
+ return NumRequiredBits(x - 1);
+}
+
+//
+// Utilities for reading and writing individual bits by their index
+// in a memory area.
+//
+
+// Bitmask selecting the k-th bit in a byte
+static constexpr uint8_t kBitmask[] = {1, 2, 4, 8, 16, 32, 64, 128};
+
+// the bitwise complement version of kBitmask
+static constexpr uint8_t kFlippedBitmask[] = {254, 253, 251, 247, 239, 223, 191, 127};
+
+// Bitmask selecting the (k - 1) preceding bits in a byte
+static constexpr uint8_t kPrecedingBitmask[] = {0, 1, 3, 7, 15, 31, 63, 127};
+static constexpr uint8_t kPrecedingWrappingBitmask[] = {255, 1, 3, 7, 15, 31, 63, 127};
+
+// the bitwise complement version of kPrecedingBitmask
+static constexpr uint8_t kTrailingBitmask[] = {255, 254, 252, 248, 240, 224, 192, 128};
+
+static constexpr bool GetBit(const uint8_t* bits, uint64_t i) {
+ return (bits[i >> 3] >> (i & 0x07)) & 1;
+}
+
+// Gets the i-th bit from a byte. Should only be used with i <= 7.
+static constexpr bool GetBitFromByte(uint8_t byte, uint8_t i) {
+ return byte & kBitmask[i];
+}
+
+static inline void ClearBit(uint8_t* bits, int64_t i) {
+ bits[i / 8] &= kFlippedBitmask[i % 8];
+}
+
+static inline void SetBit(uint8_t* bits, int64_t i) { bits[i / 8] |= kBitmask[i % 8]; }
+
+static inline void SetBitTo(uint8_t* bits, int64_t i, bool bit_is_set) {
+ // https://graphics.stanford.edu/~seander/bithacks.html
+ // "Conditionally set or clear bits without branching"
+ // NOTE: this seems to confuse Valgrind as it reads from potentially
+ // uninitialized memory
+ bits[i / 8] ^= static_cast<uint8_t>(-static_cast<uint8_t>(bit_is_set) ^ bits[i / 8]) &
+ kBitmask[i % 8];
+}
+
+/// \brief set or clear a range of bits quickly
+ARROW_EXPORT
+void SetBitsTo(uint8_t* bits, int64_t start_offset, int64_t length, bool bits_are_set);
+
+/// \brief Sets all bits in the bitmap to true
+ARROW_EXPORT
+void SetBitmap(uint8_t* data, int64_t offset, int64_t length);
+
+/// \brief Clears all bits in the bitmap (set to false)
+ARROW_EXPORT
+void ClearBitmap(uint8_t* data, int64_t offset, int64_t length);
+
+/// Returns a mask with lower i bits set to 1. If i >= sizeof(Word)*8, all-ones will be
+/// returned
+/// ex:
+/// ref: https://stackoverflow.com/a/59523400
+template <typename Word>
+constexpr Word PrecedingWordBitmask(unsigned int const i) {
+ return (static_cast<Word>(i < sizeof(Word) * 8) << (i & (sizeof(Word) * 8 - 1))) - 1;
+}
+static_assert(PrecedingWordBitmask<uint8_t>(0) == 0x00, "");
+static_assert(PrecedingWordBitmask<uint8_t>(4) == 0x0f, "");
+static_assert(PrecedingWordBitmask<uint8_t>(8) == 0xff, "");
+static_assert(PrecedingWordBitmask<uint16_t>(8) == 0x00ff, "");
+
+/// \brief Create a word with low `n` bits from `low` and high `sizeof(Word)-n` bits
+/// from `high`.
+/// Word ret
+/// for (i = 0; i < sizeof(Word)*8; i++){
+/// ret[i]= i < n ? low[i]: high[i];
+/// }
+template <typename Word>
+constexpr Word SpliceWord(int n, Word low, Word high) {
+ return (high & ~PrecedingWordBitmask<Word>(n)) | (low & PrecedingWordBitmask<Word>(n));
+}
+
+} // namespace BitUtil
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bit_util_benchmark.cc b/src/arrow/cpp/src/arrow/util/bit_util_benchmark.cc
new file mode 100644
index 000000000..8a4f3e0c5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bit_util_benchmark.cc
@@ -0,0 +1,560 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <array>
+#include <bitset>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <memory>
+#include <utility>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_primitive.h"
+#include "arrow/buffer.h"
+#include "arrow/result.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap.h"
+#include "arrow/util/bitmap_generate.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/bitmap_visit.h"
+#include "arrow/util/bitmap_writer.h"
+
+namespace arrow {
+namespace BitUtil {
+
+constexpr int64_t kBufferSize = 1024 * 8;
+
+#ifdef ARROW_WITH_BENCHMARKS_REFERENCE
+
+// A naive bitmap reader implementation, meant as a baseline against
+// internal::BitmapReader
+
+class NaiveBitmapReader {
+ public:
+ NaiveBitmapReader(const uint8_t* bitmap, int64_t start_offset, int64_t length)
+ : bitmap_(bitmap), position_(0) {}
+
+ bool IsSet() const { return BitUtil::GetBit(bitmap_, position_); }
+
+ bool IsNotSet() const { return !IsSet(); }
+
+ void Next() { ++position_; }
+
+ private:
+ const uint8_t* bitmap_;
+ uint64_t position_;
+};
+
+// A naive bitmap writer implementation, meant as a baseline against
+// internal::BitmapWriter
+
+class NaiveBitmapWriter {
+ public:
+ NaiveBitmapWriter(uint8_t* bitmap, int64_t start_offset, int64_t length)
+ : bitmap_(bitmap), position_(0) {}
+
+ void Set() {
+ const int64_t byte_offset = position_ / 8;
+ const int64_t bit_offset = position_ % 8;
+ auto bit_set_mask = (1U << bit_offset);
+ bitmap_[byte_offset] = static_cast<uint8_t>(bitmap_[byte_offset] | bit_set_mask);
+ }
+
+ void Clear() {
+ const int64_t byte_offset = position_ / 8;
+ const int64_t bit_offset = position_ % 8;
+ auto bit_clear_mask = 0xFFU ^ (1U << bit_offset);
+ bitmap_[byte_offset] = static_cast<uint8_t>(bitmap_[byte_offset] & bit_clear_mask);
+ }
+
+ void Next() { ++position_; }
+
+ void Finish() {}
+
+ int64_t position() const { return position_; }
+
+ private:
+ uint8_t* bitmap_;
+ int64_t position_;
+};
+
+#endif
+
+static std::shared_ptr<Buffer> CreateRandomBuffer(int64_t nbytes) {
+ auto buffer = *AllocateBuffer(nbytes);
+ memset(buffer->mutable_data(), 0, nbytes);
+ random_bytes(nbytes, /*seed=*/0, buffer->mutable_data());
+ return std::move(buffer);
+}
+
+static std::shared_ptr<Buffer> CreateRandomBitsBuffer(int64_t nbits,
+ int64_t set_percentage) {
+ ::arrow::random::RandomArrayGenerator rag(/*seed=*/23);
+ double set_probability =
+ static_cast<double>(set_percentage == -1 ? 0 : set_percentage) / 100.0;
+ std::shared_ptr<Buffer> buffer =
+ rag.Boolean(nbits, set_probability)->data()->buffers[1];
+
+ if (set_percentage == -1) {
+ internal::BitmapWriter writer(buffer->mutable_data(), /*start_offset=*/0,
+ /*length=*/nbits);
+ for (int x = 0; x < nbits; x++) {
+ if (x % 2 == 0) {
+ writer.Set();
+ } else {
+ writer.Clear();
+ }
+ writer.Next();
+ }
+ }
+ return buffer;
+}
+
+template <typename DoAnd>
+static void BenchmarkAndImpl(benchmark::State& state, DoAnd&& do_and) {
+ int64_t nbytes = state.range(0);
+ int64_t offset = state.range(1);
+
+ std::shared_ptr<Buffer> buffer_1 = CreateRandomBuffer(nbytes);
+ std::shared_ptr<Buffer> buffer_2 = CreateRandomBuffer(nbytes);
+ std::shared_ptr<Buffer> buffer_3 = CreateRandomBuffer(nbytes);
+
+ const int64_t num_bits = nbytes * 8 - offset;
+
+ internal::Bitmap bitmap_1{buffer_1, 0, num_bits};
+ internal::Bitmap bitmap_2{buffer_2, offset, num_bits};
+ internal::Bitmap bitmap_3{buffer_3, 0, num_bits};
+
+ for (auto _ : state) {
+ do_and({bitmap_1, bitmap_2}, &bitmap_3);
+ auto total = internal::CountSetBits(bitmap_3.buffer()->data(), bitmap_3.offset(),
+ bitmap_3.length());
+ benchmark::DoNotOptimize(total);
+ }
+ state.SetBytesProcessed(state.iterations() * nbytes);
+}
+
+static void BenchmarkBitmapAnd(benchmark::State& state) {
+ BenchmarkAndImpl(state, [](const internal::Bitmap(&bitmaps)[2], internal::Bitmap* out) {
+ internal::BitmapAnd(bitmaps[0].buffer()->data(), bitmaps[0].offset(),
+ bitmaps[1].buffer()->data(), bitmaps[1].offset(),
+ bitmaps[0].length(), 0, out->buffer()->mutable_data());
+ });
+}
+
+static void BenchmarkBitmapVisitBitsetAnd(benchmark::State& state) {
+ BenchmarkAndImpl(state, [](const internal::Bitmap(&bitmaps)[2], internal::Bitmap* out) {
+ int64_t i = 0;
+ internal::Bitmap::VisitBits(
+ bitmaps, [&](std::bitset<2> bits) { out->SetBitTo(i++, bits[0] && bits[1]); });
+ });
+}
+
+static void BenchmarkBitmapVisitUInt8And(benchmark::State& state) {
+ BenchmarkAndImpl(state, [](const internal::Bitmap(&bitmaps)[2], internal::Bitmap* out) {
+ int64_t i = 0;
+ internal::Bitmap::VisitWords(bitmaps, [&](std::array<uint8_t, 2> uint8s) {
+ reinterpret_cast<uint8_t*>(out->buffer()->mutable_data())[i++] =
+ uint8s[0] & uint8s[1];
+ });
+ });
+}
+
+static void BenchmarkBitmapVisitUInt64And(benchmark::State& state) {
+ BenchmarkAndImpl(state, [](const internal::Bitmap(&bitmaps)[2], internal::Bitmap* out) {
+ int64_t i = 0;
+ internal::Bitmap::VisitWords(bitmaps, [&](std::array<uint64_t, 2> uint64s) {
+ reinterpret_cast<uint64_t*>(out->buffer()->mutable_data())[i++] =
+ uint64s[0] & uint64s[1];
+ });
+ });
+}
+
+template <typename BitmapReaderType>
+static void BenchmarkBitmapReader(benchmark::State& state, int64_t nbytes) {
+ std::shared_ptr<Buffer> buffer = CreateRandomBuffer(nbytes);
+
+ const int64_t num_bits = nbytes * 8;
+ const uint8_t* bitmap = buffer->data();
+
+ for (auto _ : state) {
+ {
+ BitmapReaderType reader(bitmap, 0, num_bits);
+ int64_t total = 0;
+ for (int64_t i = 0; i < num_bits; i++) {
+ total += reader.IsSet();
+ reader.Next();
+ }
+ benchmark::DoNotOptimize(total);
+ }
+ {
+ BitmapReaderType reader(bitmap, 0, num_bits);
+ int64_t total = 0;
+ for (int64_t i = 0; i < num_bits; i++) {
+ total += !reader.IsNotSet();
+ reader.Next();
+ }
+ benchmark::DoNotOptimize(total);
+ }
+ }
+ state.SetBytesProcessed(2LL * state.iterations() * nbytes);
+}
+
+template <typename BitRunReaderType>
+static void BenchmarkBitRunReader(benchmark::State& state, int64_t set_percentage) {
+ constexpr int64_t kNumBits = 4096;
+ auto buffer = CreateRandomBitsBuffer(kNumBits, set_percentage);
+
+ for (auto _ : state) {
+ {
+ BitRunReaderType reader(buffer->data(), 0, kNumBits);
+ int64_t set_total = 0;
+ internal::BitRun br;
+ do {
+ br = reader.NextRun();
+ set_total += br.set ? br.length : 0;
+ } while (br.length != 0);
+ benchmark::DoNotOptimize(set_total);
+ }
+ }
+ state.SetBytesProcessed(state.iterations() * (kNumBits / 8));
+}
+
+template <typename SetBitRunReaderType>
+static void BenchmarkSetBitRunReader(benchmark::State& state, int64_t set_percentage) {
+ constexpr int64_t kNumBits = 4096;
+ auto buffer = CreateRandomBitsBuffer(kNumBits, set_percentage);
+
+ for (auto _ : state) {
+ {
+ SetBitRunReaderType reader(buffer->data(), 0, kNumBits);
+ int64_t set_total = 0;
+ internal::SetBitRun br;
+ do {
+ br = reader.NextRun();
+ set_total += br.length;
+ } while (br.length != 0);
+ benchmark::DoNotOptimize(set_total);
+ }
+ }
+ state.SetBytesProcessed(state.iterations() * (kNumBits / 8));
+}
+
+template <typename VisitBitsFunctorType>
+static void BenchmarkVisitBits(benchmark::State& state, int64_t nbytes) {
+ std::shared_ptr<Buffer> buffer = CreateRandomBuffer(nbytes);
+
+ const int64_t num_bits = nbytes * 8;
+ const uint8_t* bitmap = buffer->data();
+
+ for (auto _ : state) {
+ {
+ int64_t total = 0;
+ const auto visit = [&total](bool value) -> void { total += value; };
+ VisitBitsFunctorType()(bitmap, 0, num_bits, visit);
+ benchmark::DoNotOptimize(total);
+ }
+ {
+ int64_t total = 0;
+ const auto visit = [&total](bool value) -> void { total += value; };
+ VisitBitsFunctorType()(bitmap, 0, num_bits, visit);
+ benchmark::DoNotOptimize(total);
+ }
+ }
+ state.SetBytesProcessed(2LL * state.iterations() * nbytes);
+}
+
+constexpr bool pattern[] = {false, false, false, true, true, true};
+static_assert(
+ (sizeof(pattern) / sizeof(pattern[0])) % 8 != 0,
+ "pattern must not be a multiple of 8, otherwise gcc can optimize with a memset");
+
+template <typename BitmapWriterType>
+static void BenchmarkBitmapWriter(benchmark::State& state, int64_t nbytes) {
+ std::shared_ptr<Buffer> buffer = CreateRandomBuffer(nbytes);
+
+ const int64_t num_bits = nbytes * 8;
+ uint8_t* bitmap = buffer->mutable_data();
+
+ for (auto _ : state) {
+ BitmapWriterType writer(bitmap, 0, num_bits);
+ int64_t pattern_index = 0;
+ for (int64_t i = 0; i < num_bits; i++) {
+ if (pattern[pattern_index++]) {
+ writer.Set();
+ } else {
+ writer.Clear();
+ }
+ if (pattern_index == sizeof(pattern) / sizeof(bool)) {
+ pattern_index = 0;
+ }
+ writer.Next();
+ }
+ writer.Finish();
+ benchmark::ClobberMemory();
+ }
+ state.SetBytesProcessed(state.iterations() * nbytes);
+}
+
+template <typename GenerateBitsFunctorType>
+static void BenchmarkGenerateBits(benchmark::State& state, int64_t nbytes) {
+ std::shared_ptr<Buffer> buffer = CreateRandomBuffer(nbytes);
+
+ const int64_t num_bits = nbytes * 8;
+ uint8_t* bitmap = buffer->mutable_data();
+
+ while (state.KeepRunning()) {
+ int64_t pattern_index = 0;
+ const auto generate = [&]() -> bool {
+ bool b = pattern[pattern_index++];
+ if (pattern_index == sizeof(pattern) / sizeof(bool)) {
+ pattern_index = 0;
+ }
+ return b;
+ };
+ GenerateBitsFunctorType()(bitmap, 0, num_bits, generate);
+ benchmark::ClobberMemory();
+ }
+ state.SetBytesProcessed(state.iterations() * nbytes);
+}
+
+static void BitmapReader(benchmark::State& state) {
+ BenchmarkBitmapReader<internal::BitmapReader>(state, state.range(0));
+}
+
+static void BitmapUInt64Reader(benchmark::State& state) {
+ const int64_t nbytes = state.range(0);
+ std::shared_ptr<Buffer> buffer = CreateRandomBuffer(nbytes);
+
+ const int64_t num_bits = nbytes * 8;
+ const uint8_t* bitmap = buffer->data();
+
+ for (auto _ : state) {
+ {
+ internal::BitmapUInt64Reader reader(bitmap, 0, num_bits);
+ uint64_t total = 0;
+ for (int64_t i = 0; i < num_bits; i += 64) {
+ total += reader.NextWord();
+ }
+ benchmark::DoNotOptimize(total);
+ }
+ }
+ state.SetBytesProcessed(state.iterations() * nbytes);
+}
+
+static void BitRunReader(benchmark::State& state) {
+ BenchmarkBitRunReader<internal::BitRunReader>(state, state.range(0));
+}
+
+static void BitRunReaderLinear(benchmark::State& state) {
+ BenchmarkBitRunReader<internal::BitRunReaderLinear>(state, state.range(0));
+}
+
+static void SetBitRunReader(benchmark::State& state) {
+ BenchmarkSetBitRunReader<internal::SetBitRunReader>(state, state.range(0));
+}
+
+static void ReverseSetBitRunReader(benchmark::State& state) {
+ BenchmarkSetBitRunReader<internal::ReverseSetBitRunReader>(state, state.range(0));
+}
+
+static void BitmapWriter(benchmark::State& state) {
+ BenchmarkBitmapWriter<internal::BitmapWriter>(state, state.range(0));
+}
+
+static void FirstTimeBitmapWriter(benchmark::State& state) {
+ BenchmarkBitmapWriter<internal::FirstTimeBitmapWriter>(state, state.range(0));
+}
+
+struct GenerateBitsFunctor {
+ template <class Generator>
+ void operator()(uint8_t* bitmap, int64_t start_offset, int64_t length, Generator&& g) {
+ return internal::GenerateBits(bitmap, start_offset, length, g);
+ }
+};
+
+struct GenerateBitsUnrolledFunctor {
+ template <class Generator>
+ void operator()(uint8_t* bitmap, int64_t start_offset, int64_t length, Generator&& g) {
+ return internal::GenerateBitsUnrolled(bitmap, start_offset, length, g);
+ }
+};
+
+struct VisitBitsFunctor {
+ template <class Visitor>
+ void operator()(const uint8_t* bitmap, int64_t start_offset, int64_t length,
+ Visitor&& g) {
+ return internal::VisitBits(bitmap, start_offset, length, g);
+ }
+};
+
+struct VisitBitsUnrolledFunctor {
+ template <class Visitor>
+ void operator()(const uint8_t* bitmap, int64_t start_offset, int64_t length,
+ Visitor&& g) {
+ return internal::VisitBitsUnrolled(bitmap, start_offset, length, g);
+ }
+};
+
+static void GenerateBits(benchmark::State& state) {
+ BenchmarkGenerateBits<GenerateBitsFunctor>(state, state.range(0));
+}
+
+static void GenerateBitsUnrolled(benchmark::State& state) {
+ BenchmarkGenerateBits<GenerateBitsUnrolledFunctor>(state, state.range(0));
+}
+
+static void VisitBits(benchmark::State& state) {
+ BenchmarkVisitBits<VisitBitsFunctor>(state, state.range(0));
+}
+
+static void VisitBitsUnrolled(benchmark::State& state) {
+ BenchmarkVisitBits<VisitBitsUnrolledFunctor>(state, state.range(0));
+}
+
+static void SetBitsTo(benchmark::State& state) {
+ int64_t nbytes = state.range(0);
+ std::shared_ptr<Buffer> buffer = CreateRandomBuffer(nbytes);
+
+ for (auto _ : state) {
+ BitUtil::SetBitsTo(buffer->mutable_data(), /*offset=*/0, nbytes * 8, true);
+ }
+ state.SetBytesProcessed(state.iterations() * nbytes);
+}
+
+template <int64_t OffsetSrc, int64_t OffsetDest = 0>
+static void CopyBitmap(benchmark::State& state) { // NOLINT non-const reference
+ const int64_t buffer_size = state.range(0);
+ const int64_t bits_size = buffer_size * 8;
+ std::shared_ptr<Buffer> buffer = CreateRandomBuffer(buffer_size);
+
+ const uint8_t* src = buffer->data();
+ const int64_t length = bits_size - OffsetSrc;
+
+ auto copy = *AllocateEmptyBitmap(length);
+
+ for (auto _ : state) {
+ internal::CopyBitmap(src, OffsetSrc, length, copy->mutable_data(), OffsetDest);
+ }
+
+ state.SetBytesProcessed(state.iterations() * buffer_size);
+}
+
+static void CopyBitmapWithoutOffset(
+ benchmark::State& state) { // NOLINT non-const reference
+ CopyBitmap<0>(state);
+}
+
+// Trigger the slow path where the source buffer is not byte aligned.
+static void CopyBitmapWithOffset(benchmark::State& state) { // NOLINT non-const reference
+ CopyBitmap<4>(state);
+}
+
+// Trigger the slow path where both source and dest buffer are not byte aligned.
+static void CopyBitmapWithOffsetBoth(benchmark::State& state) { CopyBitmap<3, 7>(state); }
+
+// Benchmark the worst case of comparing two identical bitmap
+template <int64_t Offset = 0>
+static void BitmapEquals(benchmark::State& state) {
+ const int64_t buffer_size = state.range(0);
+ const int64_t bits_size = buffer_size * 8;
+ std::shared_ptr<Buffer> buffer = CreateRandomBuffer(buffer_size);
+
+ const uint8_t* src = buffer->data();
+ const int64_t offset = Offset;
+ const int64_t length = bits_size - offset;
+
+ auto copy = *AllocateEmptyBitmap(length + offset);
+ internal::CopyBitmap(src, 0, length, copy->mutable_data(), offset);
+
+ for (auto _ : state) {
+ auto is_same = internal::BitmapEquals(src, 0, copy->data(), offset, length);
+ benchmark::DoNotOptimize(is_same);
+ }
+
+ state.SetBytesProcessed(state.iterations() * buffer_size);
+}
+
+static void BitmapEqualsWithoutOffset(benchmark::State& state) { BitmapEquals<0>(state); }
+
+static void BitmapEqualsWithOffset(benchmark::State& state) { BitmapEquals<4>(state); }
+
+#ifdef ARROW_WITH_BENCHMARKS_REFERENCE
+static void ReferenceNaiveBitmapReader(benchmark::State& state) {
+ BenchmarkBitmapReader<NaiveBitmapReader>(state, state.range(0));
+}
+
+BENCHMARK(ReferenceNaiveBitmapReader)->Arg(kBufferSize);
+#endif
+
+void SetBitRunReaderPercentageArg(benchmark::internal::Benchmark* bench) {
+ bench->Arg(-1)->Arg(0)->Arg(10)->Arg(25)->Arg(50)->Arg(60)->Arg(75)->Arg(99);
+}
+
+BENCHMARK(BitmapReader)->Arg(kBufferSize);
+BENCHMARK(BitmapUInt64Reader)->Arg(kBufferSize);
+
+BENCHMARK(BitRunReader)->Apply(SetBitRunReaderPercentageArg);
+BENCHMARK(BitRunReaderLinear)->Apply(SetBitRunReaderPercentageArg);
+BENCHMARK(SetBitRunReader)->Apply(SetBitRunReaderPercentageArg);
+BENCHMARK(ReverseSetBitRunReader)->Apply(SetBitRunReaderPercentageArg);
+
+BENCHMARK(VisitBits)->Arg(kBufferSize);
+BENCHMARK(VisitBitsUnrolled)->Arg(kBufferSize);
+BENCHMARK(SetBitsTo)->Arg(2)->Arg(1 << 4)->Arg(1 << 10)->Arg(1 << 17);
+
+#ifdef ARROW_WITH_BENCHMARKS_REFERENCE
+static void ReferenceNaiveBitmapWriter(benchmark::State& state) {
+ BenchmarkBitmapWriter<NaiveBitmapWriter>(state, state.range(0));
+}
+
+BENCHMARK(ReferenceNaiveBitmapWriter)->Arg(kBufferSize);
+#endif
+
+BENCHMARK(BitmapWriter)->Arg(kBufferSize);
+BENCHMARK(FirstTimeBitmapWriter)->Arg(kBufferSize);
+
+BENCHMARK(GenerateBits)->Arg(kBufferSize);
+BENCHMARK(GenerateBitsUnrolled)->Arg(kBufferSize);
+
+BENCHMARK(CopyBitmapWithoutOffset)->Arg(kBufferSize);
+BENCHMARK(CopyBitmapWithOffset)->Arg(kBufferSize);
+BENCHMARK(CopyBitmapWithOffsetBoth)->Arg(kBufferSize);
+
+BENCHMARK(BitmapEqualsWithoutOffset)->Arg(kBufferSize);
+BENCHMARK(BitmapEqualsWithOffset)->Arg(kBufferSize);
+
+#define AND_BENCHMARK_RANGES \
+ { \
+ {kBufferSize * 4, kBufferSize * 16}, { 0, 2 } \
+ }
+BENCHMARK(BenchmarkBitmapAnd)->Ranges(AND_BENCHMARK_RANGES);
+BENCHMARK(BenchmarkBitmapVisitBitsetAnd)->Ranges(AND_BENCHMARK_RANGES);
+BENCHMARK(BenchmarkBitmapVisitUInt8And)->Ranges(AND_BENCHMARK_RANGES);
+BENCHMARK(BenchmarkBitmapVisitUInt64And)->Ranges(AND_BENCHMARK_RANGES);
+
+} // namespace BitUtil
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bit_util_test.cc b/src/arrow/cpp/src/arrow/util/bit_util_test.cc
new file mode 100644
index 000000000..c3fb08321
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bit_util_test.cc
@@ -0,0 +1,2330 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <array>
+#include <climits>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock-matchers.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/data.h"
+#include "arrow/buffer.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_compat.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bit_stream_utils.h"
+#include "arrow/util/bitmap.h"
+#include "arrow/util/bitmap_generate.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/bitmap_visit.h"
+#include "arrow/util/bitmap_writer.h"
+#include "arrow/util/bitset_stack.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/ubsan.h"
+
+namespace arrow {
+
+using internal::BitmapAnd;
+using internal::BitmapAndNot;
+using internal::BitmapOr;
+using internal::BitmapXor;
+using internal::BitsetStack;
+using internal::CopyBitmap;
+using internal::CountSetBits;
+using internal::InvertBitmap;
+using util::SafeCopy;
+
+using ::testing::ElementsAreArray;
+
+namespace internal {
+
+void PrintTo(const BitRun& run, std::ostream* os) { *os << run.ToString(); }
+void PrintTo(const SetBitRun& run, std::ostream* os) { *os << run.ToString(); }
+
+} // namespace internal
+
+template <class BitmapWriter>
+void WriteVectorToWriter(BitmapWriter& writer, const std::vector<int> values) {
+ for (const auto& value : values) {
+ if (value) {
+ writer.Set();
+ } else {
+ writer.Clear();
+ }
+ writer.Next();
+ }
+ writer.Finish();
+}
+
+void BitmapFromVector(const std::vector<int>& values, int64_t bit_offset,
+ std::shared_ptr<Buffer>* out_buffer, int64_t* out_length) {
+ const int64_t length = values.size();
+ *out_length = length;
+ ASSERT_OK_AND_ASSIGN(*out_buffer, AllocateEmptyBitmap(length + bit_offset));
+ auto writer = internal::BitmapWriter((*out_buffer)->mutable_data(), bit_offset, length);
+ WriteVectorToWriter(writer, values);
+}
+
+std::shared_ptr<Buffer> BitmapFromString(const std::string& s) {
+ TypedBufferBuilder<bool> builder;
+ ABORT_NOT_OK(builder.Reserve(s.size()));
+ for (const char c : s) {
+ switch (c) {
+ case '0':
+ builder.UnsafeAppend(false);
+ break;
+ case '1':
+ builder.UnsafeAppend(true);
+ break;
+ case ' ':
+ case '\t':
+ case '\n':
+ case '\r':
+ break;
+ default:
+ ARROW_LOG(FATAL) << "Unexpected character in bitmap string";
+ }
+ }
+ std::shared_ptr<Buffer> buffer;
+ ABORT_NOT_OK(builder.Finish(&buffer));
+ return buffer;
+}
+
+#define ASSERT_READER_SET(reader) \
+ do { \
+ ASSERT_TRUE(reader.IsSet()); \
+ ASSERT_FALSE(reader.IsNotSet()); \
+ reader.Next(); \
+ } while (false)
+
+#define ASSERT_READER_NOT_SET(reader) \
+ do { \
+ ASSERT_FALSE(reader.IsSet()); \
+ ASSERT_TRUE(reader.IsNotSet()); \
+ reader.Next(); \
+ } while (false)
+
+// Assert that a BitmapReader yields the given bit values
+void ASSERT_READER_VALUES(internal::BitmapReader& reader, std::vector<int> values) {
+ for (const auto& value : values) {
+ if (value) {
+ ASSERT_READER_SET(reader);
+ } else {
+ ASSERT_READER_NOT_SET(reader);
+ }
+ }
+}
+
+// Assert equal contents of a memory area and a vector of bytes
+void ASSERT_BYTES_EQ(const uint8_t* left, const std::vector<uint8_t>& right) {
+ auto left_array = std::vector<uint8_t>(left, left + right.size());
+ ASSERT_EQ(left_array, right);
+}
+
+TEST(BitUtilTests, TestIsMultipleOf64) {
+ using BitUtil::IsMultipleOf64;
+ EXPECT_TRUE(IsMultipleOf64(64));
+ EXPECT_TRUE(IsMultipleOf64(0));
+ EXPECT_TRUE(IsMultipleOf64(128));
+ EXPECT_TRUE(IsMultipleOf64(192));
+ EXPECT_FALSE(IsMultipleOf64(23));
+ EXPECT_FALSE(IsMultipleOf64(32));
+}
+
+TEST(BitUtilTests, TestNextPower2) {
+ using BitUtil::NextPower2;
+
+ ASSERT_EQ(8, NextPower2(6));
+ ASSERT_EQ(8, NextPower2(8));
+
+ ASSERT_EQ(1, NextPower2(1));
+ ASSERT_EQ(256, NextPower2(131));
+
+ ASSERT_EQ(1024, NextPower2(1000));
+
+ ASSERT_EQ(4096, NextPower2(4000));
+
+ ASSERT_EQ(65536, NextPower2(64000));
+
+ ASSERT_EQ(1LL << 32, NextPower2((1LL << 32) - 1));
+ ASSERT_EQ(1LL << 31, NextPower2((1LL << 31) - 1));
+ ASSERT_EQ(1LL << 62, NextPower2((1LL << 62) - 1));
+}
+
+TEST(BitUtilTests, BytesForBits) {
+ using BitUtil::BytesForBits;
+
+ ASSERT_EQ(BytesForBits(0), 0);
+ ASSERT_EQ(BytesForBits(1), 1);
+ ASSERT_EQ(BytesForBits(7), 1);
+ ASSERT_EQ(BytesForBits(8), 1);
+ ASSERT_EQ(BytesForBits(9), 2);
+ ASSERT_EQ(BytesForBits(0xffff), 8192);
+ ASSERT_EQ(BytesForBits(0x10000), 8192);
+ ASSERT_EQ(BytesForBits(0x10001), 8193);
+ ASSERT_EQ(BytesForBits(0x7ffffffffffffff8ll), 0x0fffffffffffffffll);
+ ASSERT_EQ(BytesForBits(0x7ffffffffffffff9ll), 0x1000000000000000ll);
+ ASSERT_EQ(BytesForBits(0x7fffffffffffffffll), 0x1000000000000000ll);
+}
+
+TEST(BitmapReader, NormalOperation) {
+ std::shared_ptr<Buffer> buffer;
+ int64_t length;
+
+ for (int64_t offset : {0, 1, 3, 5, 7, 8, 12, 13, 21, 38, 75, 120}) {
+ BitmapFromVector({0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1}, offset, &buffer,
+ &length);
+ ASSERT_EQ(length, 14);
+
+ auto reader = internal::BitmapReader(buffer->mutable_data(), offset, length);
+ ASSERT_READER_VALUES(reader, {0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
+ }
+}
+
+TEST(BitmapReader, DoesNotReadOutOfBounds) {
+ uint8_t bitmap[16] = {0};
+
+ const int length = 128;
+
+ internal::BitmapReader r1(bitmap, 0, length);
+
+ // If this were to read out of bounds, valgrind would tell us
+ for (int i = 0; i < length; ++i) {
+ ASSERT_TRUE(r1.IsNotSet());
+ r1.Next();
+ }
+
+ internal::BitmapReader r2(bitmap, 5, length - 5);
+
+ for (int i = 0; i < (length - 5); ++i) {
+ ASSERT_TRUE(r2.IsNotSet());
+ r2.Next();
+ }
+
+ // Does not access invalid memory
+ internal::BitmapReader r3(nullptr, 0, 0);
+}
+
+class TestBitmapUInt64Reader : public ::testing::Test {
+ public:
+ void AssertWords(const Buffer& buffer, int64_t start_offset, int64_t length,
+ const std::vector<uint64_t>& expected) {
+ internal::BitmapUInt64Reader reader(buffer.data(), start_offset, length);
+ ASSERT_EQ(reader.position(), 0);
+ ASSERT_EQ(reader.length(), length);
+ for (const uint64_t word : expected) {
+ ASSERT_EQ(reader.NextWord(), word);
+ }
+ ASSERT_EQ(reader.position(), length);
+ }
+
+ void Check(const Buffer& buffer, int64_t start_offset, int64_t length) {
+ internal::BitmapUInt64Reader reader(buffer.data(), start_offset, length);
+ for (int64_t i = 0; i < length; i += 64) {
+ ASSERT_EQ(reader.position(), i);
+ const auto nbits = std::min<int64_t>(64, length - i);
+ uint64_t word = reader.NextWord();
+ for (int64_t j = 0; j < nbits; ++j) {
+ ASSERT_EQ(word & 1, BitUtil::GetBit(buffer.data(), start_offset + i + j));
+ word >>= 1;
+ }
+ }
+ ASSERT_EQ(reader.position(), length);
+ }
+
+ void CheckExtensive(const Buffer& buffer) {
+ for (const int64_t offset : kTestOffsets) {
+ for (int64_t length : kTestOffsets) {
+ if (offset + length <= buffer.size()) {
+ Check(buffer, offset, length);
+ length = buffer.size() - offset - length;
+ if (offset + length <= buffer.size()) {
+ Check(buffer, offset, length);
+ }
+ }
+ }
+ }
+ }
+
+ protected:
+ const std::vector<int64_t> kTestOffsets = {0, 1, 6, 7, 8, 33, 62, 63, 64, 65};
+};
+
+TEST_F(TestBitmapUInt64Reader, Empty) {
+ for (const int64_t offset : kTestOffsets) {
+ // Does not access invalid memory
+ internal::BitmapUInt64Reader reader(nullptr, offset, 0);
+ ASSERT_EQ(reader.position(), 0);
+ ASSERT_EQ(reader.length(), 0);
+ }
+}
+
+TEST_F(TestBitmapUInt64Reader, Small) {
+ auto buffer = BitmapFromString(
+ "11111111 10000000 00000000 00000000 00000000 00000000 00000001 11111111"
+ "11111111 10000000 00000000 00000000 00000000 00000000 00000001 11111111"
+ "11111111 10000000 00000000 00000000 00000000 00000000 00000001 11111111"
+ "11111111 10000000 00000000 00000000 00000000 00000000 00000001 11111111");
+
+ // One word
+ AssertWords(*buffer, 0, 9, {0x1ff});
+ AssertWords(*buffer, 1, 9, {0xff});
+ AssertWords(*buffer, 7, 9, {0x3});
+ AssertWords(*buffer, 8, 9, {0x1});
+ AssertWords(*buffer, 9, 9, {0x0});
+
+ AssertWords(*buffer, 54, 10, {0x3fe});
+ AssertWords(*buffer, 54, 9, {0x1fe});
+ AssertWords(*buffer, 54, 8, {0xfe});
+
+ AssertWords(*buffer, 55, 9, {0x1ff});
+ AssertWords(*buffer, 56, 8, {0xff});
+ AssertWords(*buffer, 57, 7, {0x7f});
+ AssertWords(*buffer, 63, 1, {0x1});
+
+ AssertWords(*buffer, 0, 64, {0xff800000000001ffULL});
+
+ // One straddling word
+ AssertWords(*buffer, 54, 12, {0xffe});
+ AssertWords(*buffer, 63, 2, {0x3});
+
+ // One word (start_offset >= 64)
+ AssertWords(*buffer, 96, 64, {0x000001ffff800000ULL});
+
+ // Two words
+ AssertWords(*buffer, 0, 128, {0xff800000000001ffULL, 0xff800000000001ffULL});
+ AssertWords(*buffer, 0, 127, {0xff800000000001ffULL, 0x7f800000000001ffULL});
+ AssertWords(*buffer, 1, 127, {0xffc00000000000ffULL, 0x7fc00000000000ffULL});
+ AssertWords(*buffer, 1, 128, {0xffc00000000000ffULL, 0xffc00000000000ffULL});
+ AssertWords(*buffer, 63, 128, {0xff000000000003ffULL, 0xff000000000003ffULL});
+ AssertWords(*buffer, 63, 65, {0xff000000000003ffULL, 0x1});
+
+ // More than two words
+ AssertWords(*buffer, 0, 256,
+ {0xff800000000001ffULL, 0xff800000000001ffULL, 0xff800000000001ffULL,
+ 0xff800000000001ffULL});
+ AssertWords(*buffer, 1, 255,
+ {0xffc00000000000ffULL, 0xffc00000000000ffULL, 0xffc00000000000ffULL,
+ 0x7fc00000000000ffULL});
+ AssertWords(*buffer, 63, 193,
+ {0xff000000000003ffULL, 0xff000000000003ffULL, 0xff000000000003ffULL, 0x1});
+ AssertWords(*buffer, 63, 192,
+ {0xff000000000003ffULL, 0xff000000000003ffULL, 0xff000000000003ffULL});
+
+ CheckExtensive(*buffer);
+}
+
+TEST_F(TestBitmapUInt64Reader, Random) {
+ random::RandomArrayGenerator rng(42);
+ auto buffer = rng.NullBitmap(500, 0.5);
+ CheckExtensive(*buffer);
+}
+
+class TestSetBitRunReader : public ::testing::Test {
+ public:
+ std::vector<internal::SetBitRun> ReferenceBitRuns(const uint8_t* data,
+ int64_t start_offset,
+ int64_t length) {
+ std::vector<internal::SetBitRun> runs;
+ internal::BitRunReaderLinear reader(data, start_offset, length);
+ int64_t position = 0;
+ while (position < length) {
+ const auto br = reader.NextRun();
+ if (br.set) {
+ runs.push_back({position, br.length});
+ }
+ position += br.length;
+ }
+ return runs;
+ }
+
+ template <typename SetBitRunReaderType>
+ std::vector<internal::SetBitRun> AllBitRuns(SetBitRunReaderType* reader) {
+ std::vector<internal::SetBitRun> runs;
+ auto run = reader->NextRun();
+ while (!run.AtEnd()) {
+ runs.push_back(run);
+ run = reader->NextRun();
+ }
+ return runs;
+ }
+
+ template <typename SetBitRunReaderType>
+ void AssertBitRuns(SetBitRunReaderType* reader,
+ const std::vector<internal::SetBitRun>& expected) {
+ ASSERT_EQ(AllBitRuns(reader), expected);
+ }
+
+ void AssertBitRuns(const uint8_t* data, int64_t start_offset, int64_t length,
+ const std::vector<internal::SetBitRun>& expected) {
+ {
+ internal::SetBitRunReader reader(data, start_offset, length);
+ AssertBitRuns(&reader, expected);
+ }
+ {
+ internal::ReverseSetBitRunReader reader(data, start_offset, length);
+ auto reversed_expected = expected;
+ std::reverse(reversed_expected.begin(), reversed_expected.end());
+ AssertBitRuns(&reader, reversed_expected);
+ }
+ }
+
+ void AssertBitRuns(const Buffer& buffer, int64_t start_offset, int64_t length,
+ const std::vector<internal::SetBitRun>& expected) {
+ AssertBitRuns(buffer.data(), start_offset, length, expected);
+ }
+
+ void CheckAgainstReference(const Buffer& buffer, int64_t start_offset, int64_t length) {
+ const auto expected = ReferenceBitRuns(buffer.data(), start_offset, length);
+ AssertBitRuns(buffer.data(), start_offset, length, expected);
+ }
+
+ struct Range {
+ int64_t offset;
+ int64_t length;
+
+ int64_t end_offset() const { return offset + length; }
+ };
+
+ std::vector<Range> BufferTestRanges(const Buffer& buffer) {
+ const int64_t buffer_size = buffer.size() * 8; // in bits
+ std::vector<Range> ranges;
+ for (const int64_t offset : kTestOffsets) {
+ for (const int64_t length_adjust : kTestOffsets) {
+ int64_t length = std::min(buffer_size - offset, length_adjust);
+ EXPECT_GE(length, 0);
+ ranges.push_back({offset, length});
+ length = std::min(buffer_size - offset, buffer_size - length_adjust);
+ EXPECT_GE(length, 0);
+ ranges.push_back({offset, length});
+ }
+ }
+ return ranges;
+ }
+
+ protected:
+ const std::vector<int64_t> kTestOffsets = {0, 1, 6, 7, 8, 33, 63, 64, 65, 71};
+};
+
+TEST_F(TestSetBitRunReader, Empty) {
+ for (const int64_t offset : kTestOffsets) {
+ // Does not access invalid memory
+ AssertBitRuns(nullptr, offset, 0, {});
+ }
+}
+
+TEST_F(TestSetBitRunReader, OneByte) {
+ auto buffer = BitmapFromString("01101101");
+ AssertBitRuns(*buffer, 0, 8, {{1, 2}, {4, 2}, {7, 1}});
+
+ for (const char* bitmap_string : {"01101101", "10110110", "00000000", "11111111"}) {
+ auto buffer = BitmapFromString(bitmap_string);
+ for (int64_t offset = 0; offset < 8; ++offset) {
+ for (int64_t length = 0; length <= 8 - offset; ++length) {
+ CheckAgainstReference(*buffer, offset, length);
+ }
+ }
+ }
+}
+
+TEST_F(TestSetBitRunReader, Tiny) {
+ auto buffer = BitmapFromString("11100011 10001110 00111000 11100011 10001110 00111000");
+
+ AssertBitRuns(*buffer, 0, 48,
+ {{0, 3}, {6, 3}, {12, 3}, {18, 3}, {24, 3}, {30, 3}, {36, 3}, {42, 3}});
+ AssertBitRuns(*buffer, 0, 46,
+ {{0, 3}, {6, 3}, {12, 3}, {18, 3}, {24, 3}, {30, 3}, {36, 3}, {42, 3}});
+ AssertBitRuns(*buffer, 0, 45,
+ {{0, 3}, {6, 3}, {12, 3}, {18, 3}, {24, 3}, {30, 3}, {36, 3}, {42, 3}});
+ AssertBitRuns(*buffer, 0, 42,
+ {{0, 3}, {6, 3}, {12, 3}, {18, 3}, {24, 3}, {30, 3}, {36, 3}});
+ AssertBitRuns(*buffer, 3, 45,
+ {{3, 3}, {9, 3}, {15, 3}, {21, 3}, {27, 3}, {33, 3}, {39, 3}});
+ AssertBitRuns(*buffer, 3, 43,
+ {{3, 3}, {9, 3}, {15, 3}, {21, 3}, {27, 3}, {33, 3}, {39, 3}});
+ AssertBitRuns(*buffer, 3, 42,
+ {{3, 3}, {9, 3}, {15, 3}, {21, 3}, {27, 3}, {33, 3}, {39, 3}});
+ AssertBitRuns(*buffer, 3, 39, {{3, 3}, {9, 3}, {15, 3}, {21, 3}, {27, 3}, {33, 3}});
+}
+
+TEST_F(TestSetBitRunReader, AllZeros) {
+ const int64_t kBufferSize = 256;
+ ASSERT_OK_AND_ASSIGN(auto buffer, AllocateEmptyBitmap(kBufferSize));
+
+ for (const auto range : BufferTestRanges(*buffer)) {
+ AssertBitRuns(*buffer, range.offset, range.length, {});
+ }
+}
+
+TEST_F(TestSetBitRunReader, AllOnes) {
+ const int64_t kBufferSize = 256;
+ ASSERT_OK_AND_ASSIGN(auto buffer, AllocateEmptyBitmap(kBufferSize));
+ BitUtil::SetBitsTo(buffer->mutable_data(), 0, kBufferSize, true);
+
+ for (const auto range : BufferTestRanges(*buffer)) {
+ if (range.length > 0) {
+ AssertBitRuns(*buffer, range.offset, range.length, {{0, range.length}});
+ } else {
+ AssertBitRuns(*buffer, range.offset, range.length, {});
+ }
+ }
+}
+
+TEST_F(TestSetBitRunReader, Small) {
+ // Ones then zeros then ones
+ const int64_t kBufferSize = 256;
+ const int64_t kOnesLength = 64;
+ const int64_t kSecondOnesStart = kBufferSize - kOnesLength;
+ ASSERT_OK_AND_ASSIGN(auto buffer, AllocateEmptyBitmap(kBufferSize));
+ BitUtil::SetBitsTo(buffer->mutable_data(), 0, kBufferSize, false);
+ BitUtil::SetBitsTo(buffer->mutable_data(), 0, kOnesLength, true);
+ BitUtil::SetBitsTo(buffer->mutable_data(), kSecondOnesStart, kOnesLength, true);
+
+ for (const auto range : BufferTestRanges(*buffer)) {
+ std::vector<internal::SetBitRun> expected;
+ if (range.offset < kOnesLength && range.length > 0) {
+ expected.push_back({0, std::min(kOnesLength - range.offset, range.length)});
+ }
+ if (range.offset + range.length > kSecondOnesStart) {
+ expected.push_back({kSecondOnesStart - range.offset,
+ range.length + range.offset - kSecondOnesStart});
+ }
+ AssertBitRuns(*buffer, range.offset, range.length, expected);
+ }
+}
+
+TEST_F(TestSetBitRunReader, SingleRun) {
+ // One single run of ones, at varying places in the buffer
+ const int64_t kBufferSize = 512;
+ ASSERT_OK_AND_ASSIGN(auto buffer, AllocateEmptyBitmap(kBufferSize));
+
+ for (const auto ones_range : BufferTestRanges(*buffer)) {
+ BitUtil::SetBitsTo(buffer->mutable_data(), 0, kBufferSize, false);
+ BitUtil::SetBitsTo(buffer->mutable_data(), ones_range.offset, ones_range.length,
+ true);
+ for (const auto range : BufferTestRanges(*buffer)) {
+ std::vector<internal::SetBitRun> expected;
+
+ if (range.length && ones_range.length && range.offset < ones_range.end_offset() &&
+ ones_range.offset < range.end_offset()) {
+ // The two ranges intersect
+ const int64_t intersect_start = std::max(range.offset, ones_range.offset);
+ const int64_t intersect_stop =
+ std::min(range.end_offset(), ones_range.end_offset());
+ expected.push_back(
+ {intersect_start - range.offset, intersect_stop - intersect_start});
+ }
+ AssertBitRuns(*buffer, range.offset, range.length, expected);
+ }
+ }
+}
+
+TEST_F(TestSetBitRunReader, Random) {
+ const int64_t kBufferSize = 4096;
+ arrow::random::RandomArrayGenerator rng(42);
+ for (const double set_probability : {0.003, 0.01, 0.1, 0.5, 0.9, 0.99, 0.997}) {
+ auto arr = rng.Boolean(kBufferSize, set_probability);
+ auto buffer = arr->data()->buffers[1];
+ for (const auto range : BufferTestRanges(*buffer)) {
+ CheckAgainstReference(*buffer, range.offset, range.length);
+ }
+ }
+}
+
+TEST(BitRunReader, ZeroLength) {
+ internal::BitRunReader reader(nullptr, /*start_offset=*/0, /*length=*/0);
+
+ EXPECT_EQ(reader.NextRun().length, 0);
+}
+
+TEST(BitRunReader, NormalOperation) {
+ std::vector<int> bm_vector = {1, 0, 1}; // size: 3
+ bm_vector.insert(bm_vector.end(), /*n=*/5, /*val=*/0); // size: 8
+ bm_vector.insert(bm_vector.end(), /*n=*/7, /*val=*/1); // size: 15
+ bm_vector.insert(bm_vector.end(), /*n=*/3, /*val=*/0); // size: 18
+ bm_vector.insert(bm_vector.end(), /*n=*/25, /*val=*/1); // size: 43
+ bm_vector.insert(bm_vector.end(), /*n=*/21, /*val=*/0); // size: 64
+ bm_vector.insert(bm_vector.end(), /*n=*/26, /*val=*/1); // size: 90
+ bm_vector.insert(bm_vector.end(), /*n=*/130, /*val=*/0); // size: 220
+ bm_vector.insert(bm_vector.end(), /*n=*/65, /*val=*/1); // size: 285
+ std::shared_ptr<Buffer> bitmap;
+ int64_t length;
+ BitmapFromVector(bm_vector, /*bit_offset=*/0, &bitmap, &length);
+
+ internal::BitRunReader reader(bitmap->data(), /*start_offset=*/0, /*length=*/length);
+ std::vector<internal::BitRun> results;
+ internal::BitRun rl;
+ do {
+ rl = reader.NextRun();
+ results.push_back(rl);
+ } while (rl.length != 0);
+ EXPECT_EQ(results.back().length, 0);
+ results.pop_back();
+ EXPECT_THAT(results, ElementsAreArray(
+ std::vector<internal::BitRun>{{/*length=*/1, /*set=*/true},
+ {/*length=*/1, /*set=*/false},
+ {/*length=*/1, /*set=*/true},
+ {/*length=*/5, /*set=*/false},
+ {/*length=*/7, /*set=*/true},
+ {/*length=*/3, /*set=*/false},
+ {/*length=*/25, /*set=*/true},
+ {/*length=*/21, /*set=*/false},
+ {/*length=*/26, /*set=*/true},
+ {/*length=*/130, /*set=*/false},
+ {/*length=*/65, /*set=*/true}}));
+}
+
+TEST(BitRunReader, AllFirstByteCombos) {
+ for (int offset = 0; offset < 8; offset++) {
+ for (int64_t x = 0; x < (1 << 8) - 1; x++) {
+ int64_t bits = BitUtil::ToLittleEndian(x);
+ internal::BitRunReader reader(reinterpret_cast<uint8_t*>(&bits),
+ /*start_offset=*/offset,
+ /*length=*/8 - offset);
+ std::vector<internal::BitRun> results;
+ internal::BitRun rl;
+ do {
+ rl = reader.NextRun();
+ results.push_back(rl);
+ } while (rl.length != 0);
+ EXPECT_EQ(results.back().length, 0);
+ results.pop_back();
+ int64_t sum = 0;
+ for (const auto& result : results) {
+ sum += result.length;
+ }
+ ASSERT_EQ(sum, 8 - offset);
+ }
+ }
+}
+
+TEST(BitRunReader, TruncatedAtWord) {
+ std::vector<int> bm_vector;
+ bm_vector.insert(bm_vector.end(), /*n=*/7, /*val=*/1);
+ bm_vector.insert(bm_vector.end(), /*n=*/58, /*val=*/0);
+
+ std::shared_ptr<Buffer> bitmap;
+ int64_t length;
+ BitmapFromVector(bm_vector, /*bit_offset=*/0, &bitmap, &length);
+
+ internal::BitRunReader reader(bitmap->data(), /*start_offset=*/1,
+ /*length=*/63);
+ std::vector<internal::BitRun> results;
+ internal::BitRun rl;
+ do {
+ rl = reader.NextRun();
+ results.push_back(rl);
+ } while (rl.length != 0);
+ EXPECT_EQ(results.back().length, 0);
+ results.pop_back();
+ EXPECT_THAT(results,
+ ElementsAreArray(std::vector<internal::BitRun>{
+ {/*length=*/6, /*set=*/true}, {/*length=*/57, /*set=*/false}}));
+}
+
+TEST(BitRunReader, ScalarComparison) {
+ ::arrow::random::RandomArrayGenerator rag(/*seed=*/23);
+ constexpr int64_t kNumBits = 1000000;
+ std::shared_ptr<Buffer> buffer =
+ rag.Boolean(kNumBits, /*set_probability=*/.4)->data()->buffers[1];
+
+ const uint8_t* bitmap = buffer->data();
+
+ internal::BitRunReader reader(bitmap, 0, kNumBits);
+ internal::BitRunReaderLinear scalar_reader(bitmap, 0, kNumBits);
+ internal::BitRun br, brs;
+ int64_t br_bits = 0;
+ int64_t brs_bits = 0;
+ do {
+ br = reader.NextRun();
+ brs = scalar_reader.NextRun();
+ br_bits += br.length;
+ brs_bits += brs.length;
+ EXPECT_EQ(br.length, brs.length);
+ if (br.length > 0) {
+ EXPECT_EQ(br, brs) << internal::Bitmap(bitmap, 0, kNumBits).ToString() << br_bits
+ << " " << brs_bits;
+ }
+ } while (brs.length != 0);
+ EXPECT_EQ(br_bits, brs_bits);
+}
+
+TEST(BitRunReader, TruncatedWithinWordMultipleOf8Bits) {
+ std::vector<int> bm_vector;
+ bm_vector.insert(bm_vector.end(), /*n=*/7, /*val=*/1);
+ bm_vector.insert(bm_vector.end(), /*n=*/5, /*val=*/0);
+
+ std::shared_ptr<Buffer> bitmap;
+ int64_t length;
+ BitmapFromVector(bm_vector, /*bit_offset=*/0, &bitmap, &length);
+
+ internal::BitRunReader reader(bitmap->data(), /*start_offset=*/1,
+ /*length=*/7);
+ std::vector<internal::BitRun> results;
+ internal::BitRun rl;
+ do {
+ rl = reader.NextRun();
+ results.push_back(rl);
+ } while (rl.length != 0);
+ EXPECT_EQ(results.back().length, 0);
+ results.pop_back();
+ EXPECT_THAT(results, ElementsAreArray(std::vector<internal::BitRun>{
+ {/*length=*/6, /*set=*/true}, {/*length=*/1, /*set=*/false}}));
+}
+
+TEST(BitRunReader, TruncatedWithinWord) {
+ std::vector<int> bm_vector;
+ bm_vector.insert(bm_vector.end(), /*n=*/37 + 40, /*val=*/0);
+ bm_vector.insert(bm_vector.end(), /*n=*/23, /*val=*/1);
+
+ std::shared_ptr<Buffer> bitmap;
+ int64_t length;
+ BitmapFromVector(bm_vector, /*bit_offset=*/0, &bitmap, &length);
+
+ constexpr int64_t kOffset = 37;
+ internal::BitRunReader reader(bitmap->data(), /*start_offset=*/kOffset,
+ /*length=*/53);
+ std::vector<internal::BitRun> results;
+ internal::BitRun rl;
+ do {
+ rl = reader.NextRun();
+ results.push_back(rl);
+ } while (rl.length != 0);
+ EXPECT_EQ(results.back().length, 0);
+ results.pop_back();
+ EXPECT_THAT(results,
+ ElementsAreArray(std::vector<internal::BitRun>{
+ {/*length=*/40, /*set=*/false}, {/*length=*/13, /*set=*/true}}));
+}
+
+TEST(BitRunReader, TruncatedMultipleWords) {
+ std::vector<int> bm_vector = {1, 0, 1}; // size: 3
+ bm_vector.insert(bm_vector.end(), /*n=*/5, /*val=*/0); // size: 8
+ bm_vector.insert(bm_vector.end(), /*n=*/30, /*val=*/1); // size: 38
+ bm_vector.insert(bm_vector.end(), /*n=*/95, /*val=*/0); // size: 133
+ std::shared_ptr<Buffer> bitmap;
+ int64_t length;
+ BitmapFromVector(bm_vector, /*bit_offset=*/0, &bitmap, &length);
+
+ constexpr int64_t kOffset = 5;
+ internal::BitRunReader reader(bitmap->data(), /*start_offset=*/kOffset,
+ /*length=*/length - (kOffset + 3));
+ std::vector<internal::BitRun> results;
+ internal::BitRun rl;
+ do {
+ rl = reader.NextRun();
+ results.push_back(rl);
+ } while (rl.length != 0);
+ EXPECT_EQ(results.back().length, 0);
+ results.pop_back();
+ EXPECT_THAT(results, ElementsAreArray(std::vector<internal::BitRun>{
+ {/*length=*/3, /*set=*/false},
+ {/*length=*/30, /*set=*/true},
+ {/*length=*/92, /*set=*/false}}));
+}
+
+TEST(BitmapWriter, NormalOperation) {
+ for (const auto fill_byte_int : {0x00, 0xff}) {
+ const uint8_t fill_byte = static_cast<uint8_t>(fill_byte_int);
+ {
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ auto writer = internal::BitmapWriter(bitmap, 0, 12);
+ WriteVectorToWriter(writer, {0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1});
+ // {0b00110110, 0b....1010, ........, ........}
+ ASSERT_BYTES_EQ(bitmap, {0x36, static_cast<uint8_t>(0x0a | (fill_byte & 0xf0)),
+ fill_byte, fill_byte});
+ }
+ {
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ auto writer = internal::BitmapWriter(bitmap, 3, 12);
+ WriteVectorToWriter(writer, {0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1});
+ // {0b10110..., 0b.1010001, ........, ........}
+ ASSERT_BYTES_EQ(bitmap, {static_cast<uint8_t>(0xb0 | (fill_byte & 0x07)),
+ static_cast<uint8_t>(0x51 | (fill_byte & 0x80)), fill_byte,
+ fill_byte});
+ }
+ {
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ auto writer = internal::BitmapWriter(bitmap, 20, 12);
+ WriteVectorToWriter(writer, {0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1});
+ // {........, ........, 0b0110...., 0b10100011}
+ ASSERT_BYTES_EQ(bitmap, {fill_byte, fill_byte,
+ static_cast<uint8_t>(0x60 | (fill_byte & 0x0f)), 0xa3});
+ }
+ // 0-length writes
+ for (int64_t pos = 0; pos < 32; ++pos) {
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ auto writer = internal::BitmapWriter(bitmap, pos, 0);
+ WriteVectorToWriter(writer, {});
+ ASSERT_BYTES_EQ(bitmap, {fill_byte, fill_byte, fill_byte, fill_byte});
+ }
+ }
+}
+
+TEST(BitmapWriter, DoesNotWriteOutOfBounds) {
+ uint8_t bitmap[16] = {0};
+
+ const int length = 128;
+
+ int64_t num_values = 0;
+
+ internal::BitmapWriter r1(bitmap, 0, length);
+
+ // If this were to write out of bounds, valgrind would tell us
+ for (int i = 0; i < length; ++i) {
+ r1.Set();
+ r1.Clear();
+ r1.Next();
+ }
+ r1.Finish();
+ num_values = r1.position();
+
+ ASSERT_EQ(length, num_values);
+
+ internal::BitmapWriter r2(bitmap, 5, length - 5);
+
+ for (int i = 0; i < (length - 5); ++i) {
+ r2.Set();
+ r2.Clear();
+ r2.Next();
+ }
+ r2.Finish();
+ num_values = r2.position();
+
+ ASSERT_EQ((length - 5), num_values);
+}
+
+TEST(FirstTimeBitmapWriter, NormalOperation) {
+ for (const auto fill_byte_int : {0x00, 0xff}) {
+ const uint8_t fill_byte = static_cast<uint8_t>(fill_byte_int);
+ {
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ auto writer = internal::FirstTimeBitmapWriter(bitmap, 0, 12);
+ WriteVectorToWriter(writer, {0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1});
+ // {0b00110110, 0b1010, 0, 0}
+ ASSERT_BYTES_EQ(bitmap, {0x36, 0x0a});
+ }
+ {
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ auto writer = internal::FirstTimeBitmapWriter(bitmap, 4, 12);
+ WriteVectorToWriter(writer, {0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1});
+ // {0b00110110, 0b1010, 0, 0}
+ ASSERT_BYTES_EQ(bitmap, {static_cast<uint8_t>(0x60 | (fill_byte & 0x0f)), 0xa3});
+ }
+ // Consecutive write chunks
+ {
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ {
+ auto writer = internal::FirstTimeBitmapWriter(bitmap, 0, 6);
+ WriteVectorToWriter(writer, {0, 1, 1, 0, 1, 1});
+ }
+ {
+ auto writer = internal::FirstTimeBitmapWriter(bitmap, 6, 3);
+ WriteVectorToWriter(writer, {0, 0, 0});
+ }
+ {
+ auto writer = internal::FirstTimeBitmapWriter(bitmap, 9, 3);
+ WriteVectorToWriter(writer, {1, 0, 1});
+ }
+ ASSERT_BYTES_EQ(bitmap, {0x36, 0x0a});
+ }
+ {
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ {
+ auto writer = internal::FirstTimeBitmapWriter(bitmap, 4, 0);
+ WriteVectorToWriter(writer, {});
+ }
+ {
+ auto writer = internal::FirstTimeBitmapWriter(bitmap, 4, 6);
+ WriteVectorToWriter(writer, {0, 1, 1, 0, 1, 1});
+ }
+ {
+ auto writer = internal::FirstTimeBitmapWriter(bitmap, 10, 3);
+ WriteVectorToWriter(writer, {0, 0, 0});
+ }
+ {
+ auto writer = internal::FirstTimeBitmapWriter(bitmap, 13, 0);
+ WriteVectorToWriter(writer, {});
+ }
+ {
+ auto writer = internal::FirstTimeBitmapWriter(bitmap, 13, 3);
+ WriteVectorToWriter(writer, {1, 0, 1});
+ }
+ ASSERT_BYTES_EQ(bitmap, {static_cast<uint8_t>(0x60 | (fill_byte & 0x0f)), 0xa3});
+ }
+ }
+}
+
+std::string BitmapToString(const uint8_t* bitmap, int64_t bit_count) {
+ return arrow::internal::Bitmap(bitmap, /*offset*/ 0, /*length=*/bit_count).ToString();
+}
+
+std::string BitmapToString(const std::vector<uint8_t>& bitmap, int64_t bit_count) {
+ return BitmapToString(bitmap.data(), bit_count);
+}
+
+TEST(FirstTimeBitmapWriter, AppendWordOffsetOverwritesCorrectBitsOnExistingByte) {
+ auto check_append = [](const std::string& expected_bits, int64_t offset) {
+ std::vector<uint8_t> valid_bits = {0x00};
+ constexpr int64_t kBitsAfterAppend = 8;
+ internal::FirstTimeBitmapWriter writer(valid_bits.data(), offset,
+ /*length=*/(8 * valid_bits.size()) - offset);
+ writer.AppendWord(/*word=*/0xFF, /*number_of_bits=*/kBitsAfterAppend - offset);
+ writer.Finish();
+ EXPECT_EQ(BitmapToString(valid_bits, kBitsAfterAppend), expected_bits);
+ };
+ check_append("11111111", 0);
+ check_append("01111111", 1);
+ check_append("00111111", 2);
+ check_append("00011111", 3);
+ check_append("00001111", 4);
+ check_append("00000111", 5);
+ check_append("00000011", 6);
+ check_append("00000001", 7);
+
+ auto check_with_set = [](const std::string& expected_bits, int64_t offset) {
+ std::vector<uint8_t> valid_bits = {0x1};
+ constexpr int64_t kBitsAfterAppend = 8;
+ internal::FirstTimeBitmapWriter writer(valid_bits.data(), offset,
+ /*length=*/(8 * valid_bits.size()) - offset);
+ writer.AppendWord(/*word=*/0xFF, /*number_of_bits=*/kBitsAfterAppend - offset);
+ writer.Finish();
+ EXPECT_EQ(BitmapToString(valid_bits, kBitsAfterAppend), expected_bits);
+ };
+ // 0ffset zero would not be a valid mask.
+ check_with_set("11111111", 1);
+ check_with_set("10111111", 2);
+ check_with_set("10011111", 3);
+ check_with_set("10001111", 4);
+ check_with_set("10000111", 5);
+ check_with_set("10000011", 6);
+ check_with_set("10000001", 7);
+
+ auto check_with_preceding = [](const std::string& expected_bits, int64_t offset) {
+ std::vector<uint8_t> valid_bits = {0xFF};
+ constexpr int64_t kBitsAfterAppend = 8;
+ internal::FirstTimeBitmapWriter writer(valid_bits.data(), offset,
+ /*length=*/(8 * valid_bits.size()) - offset);
+ writer.AppendWord(/*word=*/0xFF, /*number_of_bits=*/kBitsAfterAppend - offset);
+ writer.Finish();
+ EXPECT_EQ(BitmapToString(valid_bits, kBitsAfterAppend), expected_bits);
+ };
+ check_with_preceding("11111111", 0);
+ check_with_preceding("11111111", 1);
+ check_with_preceding("11111111", 2);
+ check_with_preceding("11111111", 3);
+ check_with_preceding("11111111", 4);
+ check_with_preceding("11111111", 5);
+ check_with_preceding("11111111", 6);
+ check_with_preceding("11111111", 7);
+}
+
+TEST(FirstTimeBitmapWriter, AppendZeroBitsHasNoImpact) {
+ std::vector<uint8_t> valid_bits(/*count=*/1, 0);
+ internal::FirstTimeBitmapWriter writer(valid_bits.data(), /*start_offset=*/1,
+ /*length=*/valid_bits.size() * 8);
+ writer.AppendWord(/*word=*/0xFF, /*number_of_bits=*/0);
+ writer.AppendWord(/*word=*/0xFF, /*number_of_bits=*/0);
+ writer.AppendWord(/*word=*/0x01, /*number_of_bits=*/1);
+ writer.Finish();
+ EXPECT_EQ(valid_bits[0], 0x2);
+}
+
+TEST(FirstTimeBitmapWriter, AppendLessThanByte) {
+ {
+ std::vector<uint8_t> valid_bits(/*count*/ 8, 0);
+ internal::FirstTimeBitmapWriter writer(valid_bits.data(), /*start_offset=*/1,
+ /*length=*/8);
+ writer.AppendWord(0xB, 4);
+ writer.Finish();
+ EXPECT_EQ(BitmapToString(valid_bits, /*bit_count=*/8), "01101000");
+ }
+ {
+ // Test with all bits initially set.
+ std::vector<uint8_t> valid_bits(/*count*/ 8, 0xFF);
+ internal::FirstTimeBitmapWriter writer(valid_bits.data(), /*start_offset=*/1,
+ /*length=*/8);
+ writer.AppendWord(0xB, 4);
+ writer.Finish();
+ EXPECT_EQ(BitmapToString(valid_bits, /*bit_count=*/8), "11101000");
+ }
+}
+
+TEST(FirstTimeBitmapWriter, AppendByteThenMore) {
+ {
+ std::vector<uint8_t> valid_bits(/*count*/ 8, 0);
+ internal::FirstTimeBitmapWriter writer(valid_bits.data(), /*start_offset=*/0,
+ /*length=*/9);
+ writer.AppendWord(0xC3, 8);
+ writer.AppendWord(0x01, 1);
+ writer.Finish();
+ EXPECT_EQ(BitmapToString(valid_bits, /*bit_count=*/9), "11000011 1");
+ }
+ {
+ std::vector<uint8_t> valid_bits(/*count*/ 8, 0xFF);
+ internal::FirstTimeBitmapWriter writer(valid_bits.data(), /*start_offset=*/0,
+ /*length=*/9);
+ writer.AppendWord(0xC3, 8);
+ writer.AppendWord(0x01, 1);
+ writer.Finish();
+ EXPECT_EQ(BitmapToString(valid_bits, /*bit_count=*/9), "11000011 1");
+ }
+}
+
+TEST(FirstTimeBitmapWriter, AppendWordShiftsBitsCorrectly) {
+ constexpr uint64_t kPattern = 0x9A9A9A9A9A9A9A9A;
+ auto check_append = [&](const std::string& leading_bits, const std::string& middle_bits,
+ const std::string& trailing_bits, int64_t offset,
+ bool preset_buffer_bits = false) {
+ ASSERT_GE(offset, 8);
+ std::vector<uint8_t> valid_bits(/*count=*/10, preset_buffer_bits ? 0xFF : 0);
+ valid_bits[0] = 0x99;
+ internal::FirstTimeBitmapWriter writer(valid_bits.data(), offset,
+ /*length=*/(9 * sizeof(kPattern)) - offset);
+ writer.AppendWord(/*word=*/kPattern, /*number_of_bits=*/64);
+ writer.Finish();
+ EXPECT_EQ(valid_bits[0], 0x99); // shouldn't get changed.
+ EXPECT_EQ(BitmapToString(valid_bits.data() + 1, /*num_bits=*/8), leading_bits);
+ for (int x = 2; x < 9; x++) {
+ EXPECT_EQ(BitmapToString(valid_bits.data() + x, /*num_bits=*/8), middle_bits)
+ << "x: " << x << " " << offset << " " << BitmapToString(valid_bits.data(), 80);
+ }
+ EXPECT_EQ(BitmapToString(valid_bits.data() + 9, /*num_bits=*/8), trailing_bits);
+ };
+ // Original Pattern = "01011001"
+ check_append(/*leading_bits= */ "01011001", /*middle_bits=*/"01011001",
+ /*trailing_bits=*/"00000000", /*offset=*/8);
+ check_append("00101100", "10101100", "10000000", 9);
+ check_append("00010110", "01010110", "01000000", 10);
+ check_append("00001011", "00101011", "00100000", 11);
+ check_append("00000101", "10010101", "10010000", 12);
+ check_append("00000010", "11001010", "11001000", 13);
+ check_append("00000001", "01100101", "01100100", 14);
+ check_append("00000000", "10110010", "10110010", 15);
+
+ check_append(/*leading_bits= */ "01011001", /*middle_bits=*/"01011001",
+ /*trailing_bits=*/"11111111", /*offset=*/8, /*preset_buffer_bits=*/true);
+ check_append("10101100", "10101100", "10000000", 9, true);
+ check_append("11010110", "01010110", "01000000", 10, true);
+ check_append("11101011", "00101011", "00100000", 11, true);
+ check_append("11110101", "10010101", "10010000", 12, true);
+ check_append("11111010", "11001010", "11001000", 13, true);
+ check_append("11111101", "01100101", "01100100", 14, true);
+ check_append("11111110", "10110010", "10110010", 15, true);
+}
+
+TEST(TestAppendBitmap, AppendWordOnlyAppropriateBytesWritten) {
+ std::vector<uint8_t> valid_bits = {0x00, 0x00};
+
+ uint64_t bitmap = 0x1FF;
+ {
+ internal::FirstTimeBitmapWriter writer(valid_bits.data(), /*start_offset=*/1,
+ /*length=*/(8 * valid_bits.size()) - 1);
+ writer.AppendWord(bitmap, /*number_of_bits*/ 7);
+ writer.Finish();
+ EXPECT_THAT(valid_bits, ElementsAreArray(std::vector<uint8_t>{0xFE, 0x00}));
+ }
+ {
+ internal::FirstTimeBitmapWriter writer(valid_bits.data(), /*start_offset=*/1,
+ /*length=*/(8 * valid_bits.size()) - 1);
+ writer.AppendWord(bitmap, /*number_of_bits*/ 8);
+ writer.Finish();
+ EXPECT_THAT(valid_bits, ElementsAreArray(std::vector<uint8_t>{0xFE, 0x03}));
+ }
+}
+
+// Tests for GenerateBits and GenerateBitsUnrolled
+
+struct GenerateBitsFunctor {
+ template <class Generator>
+ void operator()(uint8_t* bitmap, int64_t start_offset, int64_t length, Generator&& g) {
+ return internal::GenerateBits(bitmap, start_offset, length, g);
+ }
+};
+
+struct GenerateBitsUnrolledFunctor {
+ template <class Generator>
+ void operator()(uint8_t* bitmap, int64_t start_offset, int64_t length, Generator&& g) {
+ return internal::GenerateBitsUnrolled(bitmap, start_offset, length, g);
+ }
+};
+
+template <typename T>
+class TestGenerateBits : public ::testing::Test {};
+
+typedef ::testing::Types<GenerateBitsFunctor, GenerateBitsUnrolledFunctor>
+ GenerateBitsTypes;
+TYPED_TEST_SUITE(TestGenerateBits, GenerateBitsTypes);
+
+TYPED_TEST(TestGenerateBits, NormalOperation) {
+ const int kSourceSize = 256;
+ uint8_t source[kSourceSize];
+ random_bytes(kSourceSize, 0, source);
+
+ const int64_t start_offsets[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 21, 31, 32};
+ const int64_t lengths[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 16,
+ 17, 21, 31, 32, 100, 201, 202, 203, 204, 205, 206, 207};
+ const uint8_t fill_bytes[] = {0x00, 0xff};
+
+ for (const int64_t start_offset : start_offsets) {
+ for (const int64_t length : lengths) {
+ for (const uint8_t fill_byte : fill_bytes) {
+ uint8_t bitmap[kSourceSize + 1];
+ memset(bitmap, fill_byte, kSourceSize + 1);
+ // First call GenerateBits
+ {
+ int64_t ncalled = 0;
+ internal::BitmapReader reader(source, 0, length);
+ TypeParam()(bitmap, start_offset, length, [&]() -> bool {
+ bool b = reader.IsSet();
+ reader.Next();
+ ++ncalled;
+ return b;
+ });
+ ASSERT_EQ(ncalled, length);
+ }
+ // Then check generated contents
+ {
+ internal::BitmapReader source_reader(source, 0, length);
+ internal::BitmapReader result_reader(bitmap, start_offset, length);
+ for (int64_t i = 0; i < length; ++i) {
+ ASSERT_EQ(source_reader.IsSet(), result_reader.IsSet())
+ << "mismatch at bit #" << i;
+ source_reader.Next();
+ result_reader.Next();
+ }
+ }
+ // Check bits preceding generated contents weren't clobbered
+ {
+ internal::BitmapReader reader_before(bitmap, 0, start_offset);
+ for (int64_t i = 0; i < start_offset; ++i) {
+ ASSERT_EQ(reader_before.IsSet(), fill_byte == 0xff)
+ << "mismatch at preceding bit #" << start_offset - i;
+ }
+ }
+ // Check the byte following generated contents wasn't clobbered
+ auto byte_after = bitmap[BitUtil::CeilDiv(start_offset + length, 8)];
+ ASSERT_EQ(byte_after, fill_byte);
+ }
+ }
+ }
+}
+
+// Tests for VisitBits and VisitBitsUnrolled. Based on the tests for GenerateBits and
+// GenerateBitsUnrolled.
+struct VisitBitsFunctor {
+ void operator()(const uint8_t* bitmap, int64_t start_offset, int64_t length,
+ bool* destination) {
+ auto writer = [&](const bool& bit_value) { *destination++ = bit_value; };
+ return internal::VisitBits(bitmap, start_offset, length, writer);
+ }
+};
+
+struct VisitBitsUnrolledFunctor {
+ void operator()(const uint8_t* bitmap, int64_t start_offset, int64_t length,
+ bool* destination) {
+ auto writer = [&](const bool& bit_value) { *destination++ = bit_value; };
+ return internal::VisitBitsUnrolled(bitmap, start_offset, length, writer);
+ }
+};
+
+/* Define a typed test class with some utility members. */
+template <typename T>
+class TestVisitBits : public ::testing::Test {
+ protected:
+ // The bitmap size that will be used throughout the VisitBits tests.
+ static const int64_t kBitmapSizeInBytes = 32;
+
+ // Typedefs for the source and expected destination types in this test.
+ using PackedBitmapType = std::array<uint8_t, kBitmapSizeInBytes>;
+ using UnpackedBitmapType = std::array<bool, 8 * kBitmapSizeInBytes>;
+
+ // Helper functions to generate the source bitmap and expected destination
+ // arrays.
+ static PackedBitmapType generate_packed_bitmap() {
+ PackedBitmapType bitmap;
+ // Assign random values into the source array.
+ random_bytes(kBitmapSizeInBytes, 0, bitmap.data());
+ return bitmap;
+ }
+
+ static UnpackedBitmapType generate_unpacked_bitmap(PackedBitmapType bitmap) {
+ // Use a BitmapReader (tested earlier) to populate the expected
+ // unpacked bitmap.
+ UnpackedBitmapType result;
+ internal::BitmapReader reader(bitmap.data(), 0, 8 * kBitmapSizeInBytes);
+ for (int64_t index = 0; index < 8 * kBitmapSizeInBytes; ++index) {
+ result[index] = reader.IsSet();
+ reader.Next();
+ }
+ return result;
+ }
+
+ // A pre-defined packed bitmap for use in test cases.
+ const PackedBitmapType packed_bitmap_;
+
+ // The expected unpacked bitmap that would be generated if each bit in
+ // the entire source bitmap was correctly unpacked to bytes.
+ const UnpackedBitmapType expected_unpacked_bitmap_;
+
+ // Define a test constructor that populates the packed bitmap and the expected
+ // unpacked bitmap.
+ TestVisitBits()
+ : packed_bitmap_(generate_packed_bitmap()),
+ expected_unpacked_bitmap_(generate_unpacked_bitmap(packed_bitmap_)) {}
+};
+
+using VisitBitsTestTypes = ::testing::Types<VisitBitsFunctor, VisitBitsUnrolledFunctor>;
+TYPED_TEST_SUITE(TestVisitBits, VisitBitsTestTypes);
+
+/* Test bit-unpacking when reading less than eight bits from the input */
+TYPED_TEST(TestVisitBits, NormalOperation) {
+ typename TestFixture::UnpackedBitmapType unpacked_bitmap;
+ const int64_t start_offsets[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 21, 31, 32};
+ const int64_t lengths[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 16,
+ 17, 21, 31, 32, 100, 201, 202, 203, 204, 205, 206, 207};
+ const bool fill_values[] = {false, true};
+
+ for (const bool fill_value : fill_values) {
+ auto is_unmodified = [=](bool value) -> bool { return value == fill_value; };
+
+ for (const int64_t start_offset : start_offsets) {
+ for (const int64_t length : lengths) {
+ std::string failure_info = std::string("fill value: ") +
+ std::to_string(fill_value) +
+ ", start offset: " + std::to_string(start_offset) +
+ ", length: " + std::to_string(length);
+ // Pre-fill the unpacked_bitmap array.
+ unpacked_bitmap.fill(fill_value);
+
+ // Attempt to read bits from the input bitmap into the unpacked_bitmap bitmap.
+ using VisitBitsFunctor = TypeParam;
+ VisitBitsFunctor()(this->packed_bitmap_.data(), start_offset, length,
+ unpacked_bitmap.data() + start_offset);
+
+ // Verify that the correct values have been written in the [start_offset,
+ // start_offset+length) range.
+ EXPECT_TRUE(std::equal(unpacked_bitmap.begin() + start_offset,
+ unpacked_bitmap.begin() + start_offset + length,
+ this->expected_unpacked_bitmap_.begin() + start_offset))
+ << "Invalid bytes unpacked when using " << failure_info;
+
+ // Verify that the unpacked_bitmap array has not changed before or after
+ // the [start_offset, start_offset+length) range.
+ EXPECT_TRUE(std::all_of(unpacked_bitmap.begin(),
+ unpacked_bitmap.begin() + start_offset, is_unmodified))
+ << "Unexpected modification to unpacked_bitmap array before written range "
+ "when using "
+ << failure_info;
+ EXPECT_TRUE(std::all_of(unpacked_bitmap.begin() + start_offset + length,
+ unpacked_bitmap.end(), is_unmodified))
+ << "Unexpected modification to unpacked_bitmap array after written range "
+ "when using "
+ << failure_info;
+ }
+ }
+ }
+}
+
+struct BitmapOperation {
+ virtual Result<std::shared_ptr<Buffer>> Call(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset) const = 0;
+
+ virtual Status Call(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset,
+ uint8_t* out_buffer) const = 0;
+
+ virtual ~BitmapOperation() = default;
+};
+
+struct BitmapAndOp : public BitmapOperation {
+ Result<std::shared_ptr<Buffer>> Call(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset) const override {
+ return BitmapAnd(pool, left, left_offset, right, right_offset, length, out_offset);
+ }
+
+ Status Call(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset,
+ uint8_t* out_buffer) const override {
+ BitmapAnd(left, left_offset, right, right_offset, length, out_offset, out_buffer);
+ return Status::OK();
+ }
+};
+
+struct BitmapOrOp : public BitmapOperation {
+ Result<std::shared_ptr<Buffer>> Call(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset) const override {
+ return BitmapOr(pool, left, left_offset, right, right_offset, length, out_offset);
+ }
+
+ Status Call(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset,
+ uint8_t* out_buffer) const override {
+ BitmapOr(left, left_offset, right, right_offset, length, out_offset, out_buffer);
+ return Status::OK();
+ }
+};
+
+struct BitmapXorOp : public BitmapOperation {
+ Result<std::shared_ptr<Buffer>> Call(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset) const override {
+ return BitmapXor(pool, left, left_offset, right, right_offset, length, out_offset);
+ }
+
+ Status Call(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset,
+ uint8_t* out_buffer) const override {
+ BitmapXor(left, left_offset, right, right_offset, length, out_offset, out_buffer);
+ return Status::OK();
+ }
+};
+
+struct BitmapAndNotOp : public BitmapOperation {
+ Result<std::shared_ptr<Buffer>> Call(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset) const override {
+ return BitmapAndNot(pool, left, left_offset, right, right_offset, length, out_offset);
+ }
+
+ Status Call(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset,
+ uint8_t* out_buffer) const override {
+ BitmapAndNot(left, left_offset, right, right_offset, length, out_offset, out_buffer);
+ return Status::OK();
+ }
+};
+
+class BitmapOp : public TestBase {
+ public:
+ void TestAligned(const BitmapOperation& op, const std::vector<int>& left_bits,
+ const std::vector<int>& right_bits,
+ const std::vector<int>& result_bits) {
+ std::shared_ptr<Buffer> left, right, out;
+ int64_t length;
+
+ for (int64_t left_offset : {0, 1, 3, 5, 7, 8, 13, 21, 38, 75, 120, 65536}) {
+ BitmapFromVector(left_bits, left_offset, &left, &length);
+ for (int64_t right_offset : {left_offset, left_offset + 8, left_offset + 40}) {
+ BitmapFromVector(right_bits, right_offset, &right, &length);
+ for (int64_t out_offset : {left_offset, left_offset + 16, left_offset + 24}) {
+ ASSERT_OK_AND_ASSIGN(
+ out, op.Call(default_memory_pool(), left->mutable_data(), left_offset,
+ right->mutable_data(), right_offset, length, out_offset));
+ auto reader = internal::BitmapReader(out->mutable_data(), out_offset, length);
+ ASSERT_READER_VALUES(reader, result_bits);
+
+ // Clear out buffer and try non-allocating version
+ std::memset(out->mutable_data(), 0, out->size());
+ ASSERT_OK(op.Call(left->mutable_data(), left_offset, right->mutable_data(),
+ right_offset, length, out_offset, out->mutable_data()));
+ reader = internal::BitmapReader(out->mutable_data(), out_offset, length);
+ ASSERT_READER_VALUES(reader, result_bits);
+ }
+ }
+ }
+ }
+
+ void TestUnaligned(const BitmapOperation& op, const std::vector<int>& left_bits,
+ const std::vector<int>& right_bits,
+ const std::vector<int>& result_bits) {
+ std::shared_ptr<Buffer> left, right, out;
+ int64_t length;
+ auto offset_values = {0, 1, 3, 5, 7, 8, 13, 21, 38, 75, 120, 65536};
+
+ for (int64_t left_offset : offset_values) {
+ BitmapFromVector(left_bits, left_offset, &left, &length);
+
+ for (int64_t right_offset : offset_values) {
+ BitmapFromVector(right_bits, right_offset, &right, &length);
+
+ for (int64_t out_offset : offset_values) {
+ ASSERT_OK_AND_ASSIGN(
+ out, op.Call(default_memory_pool(), left->mutable_data(), left_offset,
+ right->mutable_data(), right_offset, length, out_offset));
+ auto reader = internal::BitmapReader(out->mutable_data(), out_offset, length);
+ ASSERT_READER_VALUES(reader, result_bits);
+
+ // Clear out buffer and try non-allocating version
+ std::memset(out->mutable_data(), 0, out->size());
+ ASSERT_OK(op.Call(left->mutable_data(), left_offset, right->mutable_data(),
+ right_offset, length, out_offset, out->mutable_data()));
+ reader = internal::BitmapReader(out->mutable_data(), out_offset, length);
+ ASSERT_READER_VALUES(reader, result_bits);
+ }
+ }
+ }
+ }
+};
+
+TEST_F(BitmapOp, And) {
+ BitmapAndOp op;
+ std::vector<int> left = {0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1};
+ std::vector<int> right = {0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0};
+ std::vector<int> result = {0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0};
+
+ TestAligned(op, left, right, result);
+ TestUnaligned(op, left, right, result);
+}
+
+TEST_F(BitmapOp, Or) {
+ BitmapOrOp op;
+ std::vector<int> left = {0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0};
+ std::vector<int> right = {0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0};
+ std::vector<int> result = {0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0};
+
+ TestAligned(op, left, right, result);
+ TestUnaligned(op, left, right, result);
+}
+
+TEST_F(BitmapOp, Xor) {
+ BitmapXorOp op;
+ std::vector<int> left = {0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1};
+ std::vector<int> right = {0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0};
+ std::vector<int> result = {0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1};
+
+ TestAligned(op, left, right, result);
+ TestUnaligned(op, left, right, result);
+}
+
+TEST_F(BitmapOp, AndNot) {
+ BitmapAndNotOp op;
+ std::vector<int> left = {0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1};
+ std::vector<int> right = {0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0};
+ std::vector<int> result = {0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1};
+
+ TestAligned(op, left, right, result);
+ TestUnaligned(op, left, right, result);
+}
+
+TEST_F(BitmapOp, RandomXor) {
+ const int kBitCount = 1000;
+ uint8_t buffer[kBitCount * 2] = {0};
+
+ random_bytes(kBitCount * 2, 0, buffer);
+
+ std::vector<int> left(kBitCount);
+ std::vector<int> right(kBitCount);
+ std::vector<int> result(kBitCount);
+
+ for (int i = 0; i < kBitCount; ++i) {
+ left[i] = buffer[i] & 1;
+ right[i] = buffer[i + kBitCount] & 1;
+ result[i] = left[i] ^ right[i];
+ }
+
+ BitmapXorOp op;
+ for (int i = 0; i < 3; ++i) {
+ TestAligned(op, left, right, result);
+ TestUnaligned(op, left, right, result);
+
+ left.resize(left.size() * 5 / 11);
+ right.resize(left.size());
+ result.resize(left.size());
+ }
+}
+
+static inline int64_t SlowCountBits(const uint8_t* data, int64_t bit_offset,
+ int64_t length) {
+ int64_t count = 0;
+ for (int64_t i = bit_offset; i < bit_offset + length; ++i) {
+ if (BitUtil::GetBit(data, i)) {
+ ++count;
+ }
+ }
+ return count;
+}
+
+TEST(BitUtilTests, TestCountSetBits) {
+ const int kBufferSize = 1000;
+ alignas(8) uint8_t buffer[kBufferSize] = {0};
+ const int buffer_bits = kBufferSize * 8;
+
+ random_bytes(kBufferSize, 0, buffer);
+
+ // Check start addresses with 64-bit alignment and without
+ for (const uint8_t* data : {buffer, buffer + 1, buffer + 7}) {
+ for (const int num_bits : {buffer_bits - 96, buffer_bits - 101, buffer_bits - 127}) {
+ std::vector<int64_t> offsets = {
+ 0, 12, 16, 32, 37, 63, 64, 128, num_bits - 30, num_bits - 64};
+ for (const int64_t offset : offsets) {
+ int64_t result = CountSetBits(data, offset, num_bits - offset);
+ int64_t expected = SlowCountBits(data, offset, num_bits - offset);
+
+ ASSERT_EQ(expected, result);
+ }
+ }
+ }
+}
+
+TEST(BitUtilTests, TestSetBitsTo) {
+ using BitUtil::SetBitsTo;
+ for (const auto fill_byte_int : {0x00, 0xff}) {
+ const uint8_t fill_byte = static_cast<uint8_t>(fill_byte_int);
+ {
+ // test set within a byte
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ SetBitsTo(bitmap, 2, 2, true);
+ SetBitsTo(bitmap, 4, 2, false);
+ ASSERT_BYTES_EQ(bitmap, {static_cast<uint8_t>((fill_byte & ~0x3C) | 0xC)});
+ }
+ {
+ // test straddling a single byte boundary
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ SetBitsTo(bitmap, 4, 7, true);
+ SetBitsTo(bitmap, 11, 7, false);
+ ASSERT_BYTES_EQ(bitmap, {static_cast<uint8_t>((fill_byte & 0xF) | 0xF0), 0x7,
+ static_cast<uint8_t>(fill_byte & ~0x3)});
+ }
+ {
+ // test byte aligned end
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ SetBitsTo(bitmap, 4, 4, true);
+ SetBitsTo(bitmap, 8, 8, false);
+ ASSERT_BYTES_EQ(bitmap,
+ {static_cast<uint8_t>((fill_byte & 0xF) | 0xF0), 0x00, fill_byte});
+ }
+ {
+ // test byte aligned end, multiple bytes
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ SetBitsTo(bitmap, 0, 24, false);
+ uint8_t false_byte = static_cast<uint8_t>(0);
+ ASSERT_BYTES_EQ(bitmap, {false_byte, false_byte, false_byte, fill_byte});
+ }
+ }
+}
+
+TEST(BitUtilTests, TestSetBitmap) {
+ using BitUtil::SetBitsTo;
+ for (const auto fill_byte_int : {0xff}) {
+ const uint8_t fill_byte = static_cast<uint8_t>(fill_byte_int);
+ {
+ // test set within a byte
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ BitUtil::SetBitmap(bitmap, 2, 2);
+ BitUtil::ClearBitmap(bitmap, 4, 2);
+ ASSERT_BYTES_EQ(bitmap, {static_cast<uint8_t>((fill_byte & ~0x3C) | 0xC)});
+ }
+ {
+ // test straddling a single byte boundary
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ BitUtil::SetBitmap(bitmap, 4, 7);
+ BitUtil::ClearBitmap(bitmap, 11, 7);
+ ASSERT_BYTES_EQ(bitmap, {static_cast<uint8_t>((fill_byte & 0xF) | 0xF0), 0x7,
+ static_cast<uint8_t>(fill_byte & ~0x3)});
+ }
+ {
+ // test byte aligned end
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ BitUtil::SetBitmap(bitmap, 4, 4);
+ BitUtil::ClearBitmap(bitmap, 8, 8);
+ ASSERT_BYTES_EQ(bitmap,
+ {static_cast<uint8_t>((fill_byte & 0xF) | 0xF0), 0x00, fill_byte});
+ }
+ {
+ // test byte aligned end, multiple bytes
+ uint8_t bitmap[] = {fill_byte, fill_byte, fill_byte, fill_byte};
+ BitUtil::ClearBitmap(bitmap, 0, 24);
+ uint8_t false_byte = static_cast<uint8_t>(0);
+ ASSERT_BYTES_EQ(bitmap, {false_byte, false_byte, false_byte, fill_byte});
+ }
+ {
+ // ASAN test against out of bound access (ARROW-13803)
+ uint8_t bitmap[1] = {fill_byte};
+ BitUtil::ClearBitmap(bitmap, 0, 8);
+ ASSERT_EQ(bitmap[0], 0);
+ }
+ }
+}
+
+TEST(BitUtilTests, TestCopyBitmap) {
+ const int kBufferSize = 1000;
+
+ ASSERT_OK_AND_ASSIGN(auto buffer, AllocateBuffer(kBufferSize));
+ memset(buffer->mutable_data(), 0, kBufferSize);
+ random_bytes(kBufferSize, 0, buffer->mutable_data());
+
+ const uint8_t* src = buffer->data();
+
+ std::vector<int64_t> lengths = {kBufferSize * 8 - 4, kBufferSize * 8};
+ std::vector<int64_t> offsets = {0, 12, 16, 32, 37, 63, 64, 128};
+ for (int64_t num_bits : lengths) {
+ for (int64_t offset : offsets) {
+ const int64_t copy_length = num_bits - offset;
+
+ std::shared_ptr<Buffer> copy;
+ ASSERT_OK_AND_ASSIGN(copy,
+ CopyBitmap(default_memory_pool(), src, offset, copy_length));
+
+ for (int64_t i = 0; i < copy_length; ++i) {
+ ASSERT_EQ(BitUtil::GetBit(src, i + offset), BitUtil::GetBit(copy->data(), i));
+ }
+ }
+ }
+}
+
+TEST(BitUtilTests, TestCopyBitmapPreAllocated) {
+ const int kBufferSize = 1000;
+ std::vector<int64_t> lengths = {kBufferSize * 8 - 4, kBufferSize * 8};
+ std::vector<int64_t> offsets = {0, 12, 16, 32, 37, 63, 64, 128};
+
+ ASSERT_OK_AND_ASSIGN(auto buffer, AllocateBuffer(kBufferSize));
+ memset(buffer->mutable_data(), 0, kBufferSize);
+ random_bytes(kBufferSize, 0, buffer->mutable_data());
+ const uint8_t* src = buffer->data();
+
+ // Add 16 byte padding on both sides
+ ASSERT_OK_AND_ASSIGN(auto other_buffer, AllocateBuffer(kBufferSize + 32));
+ memset(other_buffer->mutable_data(), 0, kBufferSize + 32);
+ random_bytes(kBufferSize + 32, 0, other_buffer->mutable_data());
+ const uint8_t* other = other_buffer->data();
+
+ for (int64_t num_bits : lengths) {
+ for (int64_t offset : offsets) {
+ for (int64_t dest_offset : offsets) {
+ const int64_t copy_length = num_bits - offset;
+
+ ASSERT_OK_AND_ASSIGN(auto copy, AllocateBuffer(other_buffer->size()));
+ memcpy(copy->mutable_data(), other_buffer->data(), other_buffer->size());
+ CopyBitmap(src, offset, copy_length, copy->mutable_data(), dest_offset);
+
+ for (int64_t i = 0; i < dest_offset; ++i) {
+ ASSERT_EQ(BitUtil::GetBit(other, i), BitUtil::GetBit(copy->data(), i));
+ }
+ for (int64_t i = 0; i < copy_length; ++i) {
+ ASSERT_EQ(BitUtil::GetBit(src, i + offset),
+ BitUtil::GetBit(copy->data(), i + dest_offset));
+ }
+ for (int64_t i = dest_offset + copy_length; i < (other_buffer->size() * 8); ++i) {
+ ASSERT_EQ(BitUtil::GetBit(other, i), BitUtil::GetBit(copy->data(), i));
+ }
+ }
+ }
+ }
+}
+
+TEST(BitUtilTests, TestCopyAndInvertBitmapPreAllocated) {
+ const int kBufferSize = 1000;
+ std::vector<int64_t> lengths = {kBufferSize * 8 - 4, kBufferSize * 8};
+ std::vector<int64_t> offsets = {0, 12, 16, 32, 37, 63, 64, 128};
+
+ ASSERT_OK_AND_ASSIGN(auto buffer, AllocateBuffer(kBufferSize));
+ memset(buffer->mutable_data(), 0, kBufferSize);
+ random_bytes(kBufferSize, 0, buffer->mutable_data());
+ const uint8_t* src = buffer->data();
+
+ // Add 16 byte padding on both sides
+ ASSERT_OK_AND_ASSIGN(auto other_buffer, AllocateBuffer(kBufferSize + 32));
+ memset(other_buffer->mutable_data(), 0, kBufferSize + 32);
+ random_bytes(kBufferSize + 32, 0, other_buffer->mutable_data());
+ const uint8_t* other = other_buffer->data();
+
+ for (int64_t num_bits : lengths) {
+ for (int64_t offset : offsets) {
+ for (int64_t dest_offset : offsets) {
+ const int64_t copy_length = num_bits - offset;
+
+ ASSERT_OK_AND_ASSIGN(auto copy, AllocateBuffer(other_buffer->size()));
+ memcpy(copy->mutable_data(), other_buffer->data(), other_buffer->size());
+ InvertBitmap(src, offset, copy_length, copy->mutable_data(), dest_offset);
+
+ for (int64_t i = 0; i < dest_offset; ++i) {
+ ASSERT_EQ(BitUtil::GetBit(other, i), BitUtil::GetBit(copy->data(), i));
+ }
+ for (int64_t i = 0; i < copy_length; ++i) {
+ ASSERT_EQ(BitUtil::GetBit(src, i + offset),
+ !BitUtil::GetBit(copy->data(), i + dest_offset));
+ }
+ for (int64_t i = dest_offset + copy_length; i < (other_buffer->size() * 8); ++i) {
+ ASSERT_EQ(BitUtil::GetBit(other, i), BitUtil::GetBit(copy->data(), i));
+ }
+ }
+ }
+ }
+}
+
+TEST(BitUtilTests, TestBitmapEquals) {
+ const int srcBufferSize = 1000;
+
+ ASSERT_OK_AND_ASSIGN(auto src_buffer, AllocateBuffer(srcBufferSize));
+ memset(src_buffer->mutable_data(), 0, srcBufferSize);
+ random_bytes(srcBufferSize, 0, src_buffer->mutable_data());
+ const uint8_t* src = src_buffer->data();
+
+ std::vector<int64_t> lengths = {srcBufferSize * 8 - 4, srcBufferSize * 8};
+ std::vector<int64_t> offsets = {0, 12, 16, 32, 37, 63, 64, 128};
+
+ const auto dstBufferSize = srcBufferSize + BitUtil::BytesForBits(*std::max_element(
+ offsets.cbegin(), offsets.cend()));
+ ASSERT_OK_AND_ASSIGN(auto dst_buffer, AllocateBuffer(dstBufferSize))
+ uint8_t* dst = dst_buffer->mutable_data();
+
+ for (int64_t num_bits : lengths) {
+ for (int64_t offset_src : offsets) {
+ for (int64_t offset_dst : offsets) {
+ const auto bit_length = num_bits - offset_src;
+
+ internal::CopyBitmap(src, offset_src, bit_length, dst, offset_dst);
+ ASSERT_TRUE(internal::BitmapEquals(src, offset_src, dst, offset_dst, bit_length));
+
+ // test negative cases by flip some bit at head and tail
+ for (int64_t offset_flip : offsets) {
+ const auto offset_flip_head = offset_dst + offset_flip;
+ dst[offset_flip_head / 8] ^= 1 << (offset_flip_head % 8);
+ ASSERT_FALSE(
+ internal::BitmapEquals(src, offset_src, dst, offset_dst, bit_length));
+ dst[offset_flip_head / 8] ^= 1 << (offset_flip_head % 8);
+
+ const auto offset_flip_tail = offset_dst + bit_length - offset_flip - 1;
+ dst[offset_flip_tail / 8] ^= 1 << (offset_flip_tail % 8);
+ ASSERT_FALSE(
+ internal::BitmapEquals(src, offset_src, dst, offset_dst, bit_length));
+ dst[offset_flip_tail / 8] ^= 1 << (offset_flip_tail % 8);
+ }
+ }
+ }
+ }
+}
+
+TEST(BitUtil, CeilDiv) {
+ EXPECT_EQ(BitUtil::CeilDiv(0, 1), 0);
+ EXPECT_EQ(BitUtil::CeilDiv(1, 1), 1);
+ EXPECT_EQ(BitUtil::CeilDiv(1, 2), 1);
+ EXPECT_EQ(BitUtil::CeilDiv(0, 8), 0);
+ EXPECT_EQ(BitUtil::CeilDiv(1, 8), 1);
+ EXPECT_EQ(BitUtil::CeilDiv(7, 8), 1);
+ EXPECT_EQ(BitUtil::CeilDiv(8, 8), 1);
+ EXPECT_EQ(BitUtil::CeilDiv(9, 8), 2);
+ EXPECT_EQ(BitUtil::CeilDiv(9, 9), 1);
+ EXPECT_EQ(BitUtil::CeilDiv(10000000000, 10), 1000000000);
+ EXPECT_EQ(BitUtil::CeilDiv(10, 10000000000), 1);
+ EXPECT_EQ(BitUtil::CeilDiv(100000000000, 10000000000), 10);
+
+ // test overflow
+ int64_t value = std::numeric_limits<int64_t>::max() - 1;
+ int64_t divisor = std::numeric_limits<int64_t>::max();
+ EXPECT_EQ(BitUtil::CeilDiv(value, divisor), 1);
+
+ value = std::numeric_limits<int64_t>::max();
+ EXPECT_EQ(BitUtil::CeilDiv(value, divisor), 1);
+}
+
+TEST(BitUtil, RoundUp) {
+ EXPECT_EQ(BitUtil::RoundUp(0, 1), 0);
+ EXPECT_EQ(BitUtil::RoundUp(1, 1), 1);
+ EXPECT_EQ(BitUtil::RoundUp(1, 2), 2);
+ EXPECT_EQ(BitUtil::RoundUp(6, 2), 6);
+ EXPECT_EQ(BitUtil::RoundUp(0, 3), 0);
+ EXPECT_EQ(BitUtil::RoundUp(7, 3), 9);
+ EXPECT_EQ(BitUtil::RoundUp(9, 9), 9);
+ EXPECT_EQ(BitUtil::RoundUp(10000000001, 10), 10000000010);
+ EXPECT_EQ(BitUtil::RoundUp(10, 10000000000), 10000000000);
+ EXPECT_EQ(BitUtil::RoundUp(100000000000, 10000000000), 100000000000);
+
+ // test overflow
+ int64_t value = std::numeric_limits<int64_t>::max() - 1;
+ int64_t divisor = std::numeric_limits<int64_t>::max();
+ EXPECT_EQ(BitUtil::RoundUp(value, divisor), divisor);
+
+ value = std::numeric_limits<int64_t>::max();
+ EXPECT_EQ(BitUtil::RoundUp(value, divisor), divisor);
+}
+
+TEST(BitUtil, RoundDown) {
+ EXPECT_EQ(BitUtil::RoundDown(0, 1), 0);
+ EXPECT_EQ(BitUtil::RoundDown(1, 1), 1);
+ EXPECT_EQ(BitUtil::RoundDown(1, 2), 0);
+ EXPECT_EQ(BitUtil::RoundDown(6, 2), 6);
+ EXPECT_EQ(BitUtil::RoundDown(5, 7), 0);
+ EXPECT_EQ(BitUtil::RoundDown(10, 7), 7);
+ EXPECT_EQ(BitUtil::RoundDown(7, 3), 6);
+ EXPECT_EQ(BitUtil::RoundDown(9, 9), 9);
+ EXPECT_EQ(BitUtil::RoundDown(10000000001, 10), 10000000000);
+ EXPECT_EQ(BitUtil::RoundDown(10, 10000000000), 0);
+ EXPECT_EQ(BitUtil::RoundDown(100000000000, 10000000000), 100000000000);
+
+ for (int i = 0; i < 100; i++) {
+ for (int j = 1; j < 100; j++) {
+ EXPECT_EQ(BitUtil::RoundDown(i, j), i - (i % j));
+ }
+ }
+}
+
+TEST(BitUtil, CoveringBytes) {
+ EXPECT_EQ(BitUtil::CoveringBytes(0, 8), 1);
+ EXPECT_EQ(BitUtil::CoveringBytes(0, 9), 2);
+ EXPECT_EQ(BitUtil::CoveringBytes(1, 7), 1);
+ EXPECT_EQ(BitUtil::CoveringBytes(1, 8), 2);
+ EXPECT_EQ(BitUtil::CoveringBytes(2, 19), 3);
+ EXPECT_EQ(BitUtil::CoveringBytes(7, 18), 4);
+}
+
+TEST(BitUtil, TrailingBits) {
+ EXPECT_EQ(BitUtil::TrailingBits(0xFF, 0), 0);
+ EXPECT_EQ(BitUtil::TrailingBits(0xFF, 1), 1);
+ EXPECT_EQ(BitUtil::TrailingBits(0xFF, 64), 0xFF);
+ EXPECT_EQ(BitUtil::TrailingBits(0xFF, 100), 0xFF);
+ EXPECT_EQ(BitUtil::TrailingBits(0, 1), 0);
+ EXPECT_EQ(BitUtil::TrailingBits(0, 64), 0);
+ EXPECT_EQ(BitUtil::TrailingBits(1LL << 63, 0), 0);
+ EXPECT_EQ(BitUtil::TrailingBits(1LL << 63, 63), 0);
+ EXPECT_EQ(BitUtil::TrailingBits(1LL << 63, 64), 1LL << 63);
+}
+
+TEST(BitUtil, ByteSwap) {
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<uint32_t>(0)), 0);
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<uint32_t>(0x11223344)), 0x44332211);
+
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<int32_t>(0)), 0);
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<int32_t>(0x11223344)), 0x44332211);
+
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<uint64_t>(0)), 0);
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<uint64_t>(0x1122334455667788)),
+ 0x8877665544332211);
+
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<int64_t>(0)), 0);
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<int64_t>(0x1122334455667788)),
+ 0x8877665544332211);
+
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<int16_t>(0)), 0);
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<int16_t>(0x1122)), 0x2211);
+
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<uint16_t>(0)), 0);
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<uint16_t>(0x1122)), 0x2211);
+
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<int8_t>(0)), 0);
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<int8_t>(0x11)), 0x11);
+
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<uint8_t>(0)), 0);
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<uint8_t>(0x11)), 0x11);
+
+ EXPECT_EQ(BitUtil::ByteSwap(static_cast<float>(0)), 0);
+ uint32_t srci32 = 0xaabbccdd, expectedi32 = 0xddccbbaa;
+ float srcf32 = SafeCopy<float>(srci32);
+ float expectedf32 = SafeCopy<float>(expectedi32);
+ EXPECT_EQ(BitUtil::ByteSwap(srcf32), expectedf32);
+ uint64_t srci64 = 0xaabb11223344ccdd, expectedi64 = 0xddcc44332211bbaa;
+ double srcd64 = SafeCopy<double>(srci64);
+ double expectedd64 = SafeCopy<double>(expectedi64);
+ EXPECT_EQ(BitUtil::ByteSwap(srcd64), expectedd64);
+}
+
+TEST(BitUtil, Log2) {
+ EXPECT_EQ(BitUtil::Log2(1), 0);
+ EXPECT_EQ(BitUtil::Log2(2), 1);
+ EXPECT_EQ(BitUtil::Log2(3), 2);
+ EXPECT_EQ(BitUtil::Log2(4), 2);
+ EXPECT_EQ(BitUtil::Log2(5), 3);
+ EXPECT_EQ(BitUtil::Log2(8), 3);
+ EXPECT_EQ(BitUtil::Log2(9), 4);
+ EXPECT_EQ(BitUtil::Log2(INT_MAX), 31);
+ EXPECT_EQ(BitUtil::Log2(UINT_MAX), 32);
+ EXPECT_EQ(BitUtil::Log2(ULLONG_MAX), 64);
+}
+
+TEST(BitUtil, NumRequiredBits) {
+ EXPECT_EQ(BitUtil::NumRequiredBits(0), 0);
+ EXPECT_EQ(BitUtil::NumRequiredBits(1), 1);
+ EXPECT_EQ(BitUtil::NumRequiredBits(2), 2);
+ EXPECT_EQ(BitUtil::NumRequiredBits(3), 2);
+ EXPECT_EQ(BitUtil::NumRequiredBits(4), 3);
+ EXPECT_EQ(BitUtil::NumRequiredBits(5), 3);
+ EXPECT_EQ(BitUtil::NumRequiredBits(7), 3);
+ EXPECT_EQ(BitUtil::NumRequiredBits(8), 4);
+ EXPECT_EQ(BitUtil::NumRequiredBits(9), 4);
+ EXPECT_EQ(BitUtil::NumRequiredBits(UINT_MAX - 1), 32);
+ EXPECT_EQ(BitUtil::NumRequiredBits(UINT_MAX), 32);
+ EXPECT_EQ(BitUtil::NumRequiredBits(static_cast<uint64_t>(UINT_MAX) + 1), 33);
+ EXPECT_EQ(BitUtil::NumRequiredBits(ULLONG_MAX / 2), 63);
+ EXPECT_EQ(BitUtil::NumRequiredBits(ULLONG_MAX / 2 + 1), 64);
+ EXPECT_EQ(BitUtil::NumRequiredBits(ULLONG_MAX - 1), 64);
+ EXPECT_EQ(BitUtil::NumRequiredBits(ULLONG_MAX), 64);
+}
+
+#define U32(x) static_cast<uint32_t>(x)
+#define U64(x) static_cast<uint64_t>(x)
+#define S64(x) static_cast<int64_t>(x)
+
+TEST(BitUtil, CountLeadingZeros) {
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U32(0)), 32);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U32(1)), 31);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U32(2)), 30);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U32(3)), 30);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U32(4)), 29);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U32(7)), 29);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U32(8)), 28);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U32(UINT_MAX / 2)), 1);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U32(UINT_MAX / 2 + 1)), 0);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U32(UINT_MAX)), 0);
+
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U64(0)), 64);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U64(1)), 63);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U64(2)), 62);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U64(3)), 62);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U64(4)), 61);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U64(7)), 61);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U64(8)), 60);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U64(UINT_MAX)), 32);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U64(UINT_MAX) + 1), 31);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U64(ULLONG_MAX / 2)), 1);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U64(ULLONG_MAX / 2 + 1)), 0);
+ EXPECT_EQ(BitUtil::CountLeadingZeros(U64(ULLONG_MAX)), 0);
+}
+
+TEST(BitUtil, CountTrailingZeros) {
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U32(0)), 32);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U32(1) << 31), 31);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U32(1) << 30), 30);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U32(1) << 29), 29);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U32(1) << 28), 28);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U32(8)), 3);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U32(4)), 2);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U32(2)), 1);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U32(1)), 0);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U32(ULONG_MAX)), 0);
+
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U64(0)), 64);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U64(1) << 63), 63);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U64(1) << 62), 62);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U64(1) << 61), 61);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U64(1) << 60), 60);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U64(8)), 3);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U64(4)), 2);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U64(2)), 1);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U64(1)), 0);
+ EXPECT_EQ(BitUtil::CountTrailingZeros(U64(ULLONG_MAX)), 0);
+}
+
+TEST(BitUtil, RoundUpToPowerOf2) {
+ EXPECT_EQ(BitUtil::RoundUpToPowerOf2(S64(7), 8), 8);
+ EXPECT_EQ(BitUtil::RoundUpToPowerOf2(S64(8), 8), 8);
+ EXPECT_EQ(BitUtil::RoundUpToPowerOf2(S64(9), 8), 16);
+
+ EXPECT_EQ(BitUtil::RoundUpToPowerOf2(U64(7), 8), 8);
+ EXPECT_EQ(BitUtil::RoundUpToPowerOf2(U64(8), 8), 8);
+ EXPECT_EQ(BitUtil::RoundUpToPowerOf2(U64(9), 8), 16);
+}
+
+#undef U32
+#undef U64
+#undef S64
+
+static void TestZigZag(int32_t v, std::array<uint8_t, 5> buffer_expect) {
+ uint8_t buffer[BitUtil::BitReader::kMaxVlqByteLength] = {};
+ BitUtil::BitWriter writer(buffer, sizeof(buffer));
+ BitUtil::BitReader reader(buffer, sizeof(buffer));
+ writer.PutZigZagVlqInt(v);
+ EXPECT_THAT(buffer, testing::ElementsAreArray(buffer_expect));
+ int32_t result;
+ EXPECT_TRUE(reader.GetZigZagVlqInt(&result));
+ EXPECT_EQ(v, result);
+}
+
+TEST(BitStreamUtil, ZigZag) {
+ TestZigZag(0, {0, 0, 0, 0, 0});
+ TestZigZag(1, {2, 0, 0, 0, 0});
+ TestZigZag(1234, {164, 19, 0, 0, 0});
+ TestZigZag(-1, {1, 0, 0, 0, 0});
+ TestZigZag(-1234, {163, 19, 0, 0, 0});
+ TestZigZag(std::numeric_limits<int32_t>::max(), {254, 255, 255, 255, 15});
+ TestZigZag(-std::numeric_limits<int32_t>::max(), {253, 255, 255, 255, 15});
+ TestZigZag(std::numeric_limits<int32_t>::min(), {255, 255, 255, 255, 15});
+}
+
+static void TestZigZag64(int64_t v, std::array<uint8_t, 10> buffer_expect) {
+ uint8_t buffer[BitUtil::BitReader::kMaxVlqByteLengthForInt64] = {};
+ BitUtil::BitWriter writer(buffer, sizeof(buffer));
+ BitUtil::BitReader reader(buffer, sizeof(buffer));
+ writer.PutZigZagVlqInt(v);
+ EXPECT_THAT(buffer, testing::ElementsAreArray(buffer_expect));
+ int64_t result;
+ EXPECT_TRUE(reader.GetZigZagVlqInt(&result));
+ EXPECT_EQ(v, result);
+}
+
+TEST(BitStreamUtil, ZigZag64) {
+ TestZigZag64(0, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ TestZigZag64(1, {2, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ TestZigZag64(1234, {164, 19, 0, 0, 0, 0, 0, 0, 0, 0});
+ TestZigZag64(-1, {1, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ TestZigZag64(-1234, {163, 19, 0, 0, 0, 0, 0, 0, 0, 0});
+ TestZigZag64(std::numeric_limits<int64_t>::max(),
+ {254, 255, 255, 255, 255, 255, 255, 255, 255, 1});
+ TestZigZag64(-std::numeric_limits<int64_t>::max(),
+ {253, 255, 255, 255, 255, 255, 255, 255, 255, 1});
+ TestZigZag64(std::numeric_limits<int64_t>::min(),
+ {255, 255, 255, 255, 255, 255, 255, 255, 255, 1});
+}
+
+TEST(BitUtil, RoundTripLittleEndianTest) {
+ uint64_t value = 0xFF;
+
+#if ARROW_LITTLE_ENDIAN
+ uint64_t expected = value;
+#else
+ uint64_t expected = std::numeric_limits<uint64_t>::max() << 56;
+#endif
+
+ uint64_t little_endian_result = BitUtil::ToLittleEndian(value);
+ ASSERT_EQ(expected, little_endian_result);
+
+ uint64_t from_little_endian = BitUtil::FromLittleEndian(little_endian_result);
+ ASSERT_EQ(value, from_little_endian);
+}
+
+TEST(BitUtil, RoundTripBigEndianTest) {
+ uint64_t value = 0xFF;
+
+#if ARROW_LITTLE_ENDIAN
+ uint64_t expected = std::numeric_limits<uint64_t>::max() << 56;
+#else
+ uint64_t expected = value;
+#endif
+
+ uint64_t big_endian_result = BitUtil::ToBigEndian(value);
+ ASSERT_EQ(expected, big_endian_result);
+
+ uint64_t from_big_endian = BitUtil::FromBigEndian(big_endian_result);
+ ASSERT_EQ(value, from_big_endian);
+}
+
+TEST(BitUtil, BitsetStack) {
+ BitsetStack stack;
+ ASSERT_EQ(stack.TopSize(), 0);
+ stack.Push(3, false);
+ ASSERT_EQ(stack.TopSize(), 3);
+ stack[1] = true;
+ stack.Push(5, true);
+ ASSERT_EQ(stack.TopSize(), 5);
+ stack[1] = false;
+ for (int i = 0; i != 5; ++i) {
+ ASSERT_EQ(stack[i], i != 1);
+ }
+ stack.Pop();
+ ASSERT_EQ(stack.TopSize(), 3);
+ for (int i = 0; i != 3; ++i) {
+ ASSERT_EQ(stack[i], i == 1);
+ }
+ stack.Pop();
+ ASSERT_EQ(stack.TopSize(), 0);
+}
+
+TEST(SpliceWord, SpliceWord) {
+ static_assert(
+ BitUtil::PrecedingWordBitmask<uint8_t>(0) == BitUtil::kPrecedingBitmask[0], "");
+ static_assert(
+ BitUtil::PrecedingWordBitmask<uint8_t>(5) == BitUtil::kPrecedingBitmask[5], "");
+ static_assert(BitUtil::PrecedingWordBitmask<uint8_t>(8) == UINT8_MAX, "");
+
+ static_assert(BitUtil::PrecedingWordBitmask<uint64_t>(0) == uint64_t(0), "");
+ static_assert(BitUtil::PrecedingWordBitmask<uint64_t>(33) == 8589934591, "");
+ static_assert(BitUtil::PrecedingWordBitmask<uint64_t>(64) == UINT64_MAX, "");
+ static_assert(BitUtil::PrecedingWordBitmask<uint64_t>(65) == UINT64_MAX, "");
+
+ ASSERT_EQ(BitUtil::SpliceWord<uint8_t>(0, 0x12, 0xef), 0xef);
+ ASSERT_EQ(BitUtil::SpliceWord<uint8_t>(8, 0x12, 0xef), 0x12);
+ ASSERT_EQ(BitUtil::SpliceWord<uint8_t>(3, 0x12, 0xef), 0xea);
+
+ ASSERT_EQ(BitUtil::SpliceWord<uint32_t>(0, 0x12345678, 0xfedcba98), 0xfedcba98);
+ ASSERT_EQ(BitUtil::SpliceWord<uint32_t>(32, 0x12345678, 0xfedcba98), 0x12345678);
+ ASSERT_EQ(BitUtil::SpliceWord<uint32_t>(24, 0x12345678, 0xfedcba98), 0xfe345678);
+
+ ASSERT_EQ(BitUtil::SpliceWord<uint64_t>(0, 0x0123456789abcdef, 0xfedcba9876543210),
+ 0xfedcba9876543210);
+ ASSERT_EQ(BitUtil::SpliceWord<uint64_t>(64, 0x0123456789abcdef, 0xfedcba9876543210),
+ 0x0123456789abcdef);
+ ASSERT_EQ(BitUtil::SpliceWord<uint64_t>(48, 0x0123456789abcdef, 0xfedcba9876543210),
+ 0xfedc456789abcdef);
+}
+
+// test the basic assumption of word level Bitmap::Visit
+TEST(Bitmap, ShiftingWordsOptimization) {
+ // single word
+ {
+ uint64_t word;
+ auto bytes = reinterpret_cast<uint8_t*>(&word);
+ constexpr size_t kBitWidth = sizeof(word) * 8;
+
+ for (int seed = 0; seed < 64; ++seed) {
+ random_bytes(sizeof(word), seed, bytes);
+ uint64_t native_word = BitUtil::FromLittleEndian(word);
+
+ // bits are accessible through simple bit shifting of the word
+ for (size_t i = 0; i < kBitWidth; ++i) {
+ ASSERT_EQ(BitUtil::GetBit(bytes, i), bool((native_word >> i) & 1));
+ }
+
+ // bit offset can therefore be accommodated by shifting the word
+ for (size_t offset = 0; offset < (kBitWidth * 3) / 4; ++offset) {
+ uint64_t shifted_word = arrow::BitUtil::ToLittleEndian(native_word >> offset);
+ auto shifted_bytes = reinterpret_cast<uint8_t*>(&shifted_word);
+ ASSERT_TRUE(
+ internal::BitmapEquals(bytes, offset, shifted_bytes, 0, kBitWidth - offset));
+ }
+ }
+ }
+
+ // two words
+ {
+ uint64_t words[2];
+ auto bytes = reinterpret_cast<uint8_t*>(words);
+ constexpr size_t kBitWidth = sizeof(words[0]) * 8;
+
+ for (int seed = 0; seed < 64; ++seed) {
+ random_bytes(sizeof(words), seed, bytes);
+ uint64_t native_words0 = BitUtil::FromLittleEndian(words[0]);
+ uint64_t native_words1 = BitUtil::FromLittleEndian(words[1]);
+
+ // bits are accessible through simple bit shifting of a word
+ for (size_t i = 0; i < kBitWidth; ++i) {
+ ASSERT_EQ(BitUtil::GetBit(bytes, i), bool((native_words0 >> i) & 1));
+ }
+ for (size_t i = 0; i < kBitWidth; ++i) {
+ ASSERT_EQ(BitUtil::GetBit(bytes, i + kBitWidth), bool((native_words1 >> i) & 1));
+ }
+
+ // bit offset can therefore be accommodated by shifting the word
+ for (size_t offset = 1; offset < (kBitWidth * 3) / 4; offset += 3) {
+ uint64_t shifted_words[2];
+ shifted_words[0] = arrow::BitUtil::ToLittleEndian(
+ native_words0 >> offset | (native_words1 << (kBitWidth - offset)));
+ shifted_words[1] = arrow::BitUtil::ToLittleEndian(native_words1 >> offset);
+ auto shifted_bytes = reinterpret_cast<uint8_t*>(shifted_words);
+
+ // from offset to unshifted word boundary
+ ASSERT_TRUE(
+ internal::BitmapEquals(bytes, offset, shifted_bytes, 0, kBitWidth - offset));
+
+ // from unshifted word boundary to shifted word boundary
+ ASSERT_TRUE(internal::BitmapEquals(bytes, kBitWidth, shifted_bytes,
+ kBitWidth - offset, offset));
+
+ // from shifted word boundary to end
+ ASSERT_TRUE(internal::BitmapEquals(bytes, kBitWidth + offset, shifted_bytes,
+ kBitWidth, kBitWidth - offset));
+ }
+ }
+ }
+}
+
+namespace internal {
+
+static Bitmap Copy(const Bitmap& bitmap, std::shared_ptr<Buffer> storage) {
+ int64_t i = 0;
+ Bitmap bitmaps[] = {bitmap};
+ auto min_offset = Bitmap::VisitWords(bitmaps, [&](std::array<uint64_t, 1> uint64s) {
+ reinterpret_cast<uint64_t*>(storage->mutable_data())[i++] = uint64s[0];
+ });
+ return Bitmap(std::move(storage), min_offset, bitmap.length());
+}
+
+// reconstruct a bitmap from a word-wise visit
+TEST(Bitmap, VisitWords) {
+ constexpr int64_t nbytes = 1 << 10;
+ std::shared_ptr<Buffer> buffer, actual_buffer;
+ for (std::shared_ptr<Buffer>* b : {&buffer, &actual_buffer}) {
+ ASSERT_OK_AND_ASSIGN(*b, AllocateBuffer(nbytes));
+ memset((*b)->mutable_data(), 0, nbytes);
+ }
+ random_bytes(nbytes, 0, buffer->mutable_data());
+
+ constexpr int64_t kBitWidth = 8 * sizeof(uint64_t);
+
+ for (int64_t offset : {0, 1, 2, 5, 17}) {
+ for (int64_t num_bits :
+ {int64_t(13), int64_t(9), kBitWidth - 1, kBitWidth, kBitWidth + 1,
+ nbytes * 8 - offset, nbytes * 6, nbytes * 4}) {
+ Bitmap actual = Copy({buffer, offset, num_bits}, actual_buffer);
+ ASSERT_EQ(actual, Bitmap(buffer->data(), offset, num_bits))
+ << "offset:" << offset << " bits:" << num_bits << std::endl
+ << Bitmap(actual_buffer, 0, num_bits).Diff({buffer, offset, num_bits});
+ }
+ }
+}
+
+#ifndef ARROW_VALGRIND
+
+// This test reads uninitialized memory
+TEST(Bitmap, VisitPartialWords) {
+ uint64_t words[2];
+ constexpr auto nbytes = sizeof(words);
+ constexpr auto nbits = nbytes * 8;
+
+ auto buffer = Buffer::Wrap(words, 2);
+ Bitmap bitmap(buffer, 0, nbits);
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> storage, AllocateBuffer(nbytes));
+
+ // words partially outside the buffer are not accessible, but they are loaded bitwise
+ auto first_byte_was_missing = Bitmap(SliceBuffer(buffer, 1), 0, nbits - 8);
+ ASSERT_EQ(Copy(first_byte_was_missing, storage), bitmap.Slice(8));
+
+ auto last_byte_was_missing = Bitmap(SliceBuffer(buffer, 0, nbytes - 1), 0, nbits - 8);
+ ASSERT_EQ(Copy(last_byte_was_missing, storage), bitmap.Slice(0, nbits - 8));
+}
+
+#endif // ARROW_VALGRIND
+
+TEST(Bitmap, ToString) {
+ uint8_t bitmap[8] = {0xAC, 0xCA, 0, 0, 0, 0, 0, 0};
+ EXPECT_EQ(Bitmap(bitmap, /*bit_offset*/ 0, /*length=*/34).ToString(),
+ "00110101 01010011 00000000 00000000 00");
+ EXPECT_EQ(Bitmap(bitmap, /*bit_offset*/ 0, /*length=*/16).ToString(),
+ "00110101 01010011");
+ EXPECT_EQ(Bitmap(bitmap, /*bit_offset*/ 0, /*length=*/11).ToString(), "00110101 010");
+ EXPECT_EQ(Bitmap(bitmap, /*bit_offset*/ 3, /*length=*/8).ToString(), "10101010");
+}
+
+// compute bitwise AND of bitmaps using word-wise visit
+TEST(Bitmap, VisitWordsAnd) {
+ constexpr int64_t nbytes = 1 << 10;
+ std::shared_ptr<Buffer> buffer, actual_buffer, expected_buffer;
+ for (std::shared_ptr<Buffer>* b : {&buffer, &actual_buffer, &expected_buffer}) {
+ ASSERT_OK_AND_ASSIGN(*b, AllocateBuffer(nbytes));
+ memset((*b)->mutable_data(), 0, nbytes);
+ }
+ random_bytes(nbytes, 0, buffer->mutable_data());
+
+ constexpr int64_t kBitWidth = 8 * sizeof(uint64_t);
+
+ for (int64_t left_offset :
+ {0, 1, 2, 5, 17, int(kBitWidth - 1), int(kBitWidth + 1), int(kBitWidth + 17)}) {
+ for (int64_t right_offset = 0; right_offset < left_offset; ++right_offset) {
+ for (int64_t num_bits :
+ {int64_t(13), int64_t(9), kBitWidth - 1, kBitWidth, kBitWidth + 1,
+ 2 * kBitWidth - 1, 2 * kBitWidth, 2 * kBitWidth + 1, nbytes * 8 - left_offset,
+ 3 * kBitWidth - 1, 3 * kBitWidth, 3 * kBitWidth + 1, nbytes * 6,
+ nbytes * 4}) {
+ Bitmap bitmaps[] = {{buffer, left_offset, num_bits},
+ {buffer, right_offset, num_bits}};
+
+ int64_t i = 0;
+ auto min_offset =
+ Bitmap::VisitWords(bitmaps, [&](std::array<uint64_t, 2> uint64s) {
+ reinterpret_cast<uint64_t*>(actual_buffer->mutable_data())[i++] =
+ uint64s[0] & uint64s[1];
+ });
+
+ BitmapAnd(bitmaps[0].buffer()->data(), bitmaps[0].offset(),
+ bitmaps[1].buffer()->data(), bitmaps[1].offset(), bitmaps[0].length(),
+ 0, expected_buffer->mutable_data());
+
+ ASSERT_TRUE(BitmapEquals(actual_buffer->data(), min_offset,
+ expected_buffer->data(), 0, num_bits))
+ << "left_offset:" << left_offset << " bits:" << num_bits
+ << " right_offset:" << right_offset << std::endl
+ << Bitmap(actual_buffer, 0, num_bits).Diff({expected_buffer, 0, num_bits});
+ }
+ }
+ }
+}
+
+void DoBitmapVisitAndWrite(int64_t part, bool with_offset) {
+ int64_t bits = part * 4;
+
+ random::RandomArrayGenerator rand(/*seed=*/0);
+ auto arrow_data = rand.ArrayOf(boolean(), bits, 0);
+
+ std::shared_ptr<Buffer>& arrow_buffer = arrow_data->data()->buffers[1];
+
+ Bitmap bm0(arrow_buffer, 0, part);
+ Bitmap bm1(arrow_buffer, part * 1, part);
+ Bitmap bm2(arrow_buffer, part * 2, part);
+
+ std::array<Bitmap, 2> out_bms;
+ if (with_offset) {
+ ASSERT_OK_AND_ASSIGN(auto out, AllocateBitmap(part * 4));
+ out_bms[0] = Bitmap(out, part, part);
+ out_bms[1] = Bitmap(out, part * 2, part);
+ } else {
+ ASSERT_OK_AND_ASSIGN(auto out0, AllocateBitmap(part));
+ ASSERT_OK_AND_ASSIGN(auto out1, AllocateBitmap(part));
+ out_bms[0] = Bitmap(out0, 0, part);
+ out_bms[1] = Bitmap(out1, 0, part);
+ }
+
+ // out0 = bm0 & bm1, out1= bm0 | bm2
+ std::array<Bitmap, 3> in_bms{bm0, bm1, bm2};
+ Bitmap::VisitWordsAndWrite(
+ in_bms, &out_bms,
+ [](const std::array<uint64_t, 3>& in, std::array<uint64_t, 2>* out) {
+ out->at(0) = in[0] & in[1];
+ out->at(1) = in[0] | in[2];
+ });
+
+ auto pool = MemoryPool::CreateDefault();
+ ASSERT_OK_AND_ASSIGN(auto exp_0,
+ BitmapAnd(pool.get(), bm0.buffer()->data(), bm0.offset(),
+ bm1.buffer()->data(), bm1.offset(), part, 0));
+ ASSERT_OK_AND_ASSIGN(auto exp_1,
+ BitmapOr(pool.get(), bm0.buffer()->data(), bm0.offset(),
+ bm2.buffer()->data(), bm2.offset(), part, 0));
+
+ ASSERT_TRUE(BitmapEquals(exp_0->data(), 0, out_bms[0].buffer()->data(),
+ out_bms[0].offset(), part))
+ << "exp: " << Bitmap(exp_0->data(), 0, part).ToString() << std::endl
+ << "got: " << out_bms[0].ToString();
+
+ ASSERT_TRUE(BitmapEquals(exp_1->data(), 0, out_bms[1].buffer()->data(),
+ out_bms[1].offset(), part))
+ << "exp: " << Bitmap(exp_1->data(), 0, part).ToString() << std::endl
+ << "got: " << out_bms[1].ToString();
+}
+
+class TestBitmapVisitAndWrite : public ::testing::TestWithParam<int32_t> {};
+
+INSTANTIATE_TEST_SUITE_P(VisitWriteGeneral, TestBitmapVisitAndWrite,
+ testing::Values(199, 256, 1000));
+
+INSTANTIATE_TEST_SUITE_P(VisitWriteEdgeCases, TestBitmapVisitAndWrite,
+ testing::Values(5, 13, 21, 29, 37, 41, 51, 59, 64, 97));
+
+INSTANTIATE_TEST_SUITE_P(VisitWriteEdgeCases2, TestBitmapVisitAndWrite,
+ testing::Values(8, 16, 24, 32, 40, 48, 56, 64));
+
+TEST_P(TestBitmapVisitAndWrite, NoOffset) { DoBitmapVisitAndWrite(GetParam(), false); }
+
+TEST_P(TestBitmapVisitAndWrite, WithOffset) { DoBitmapVisitAndWrite(GetParam(), true); }
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bitmap.cc b/src/arrow/cpp/src/arrow/util/bitmap.cc
new file mode 100644
index 000000000..33d1dee19
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bitmap.cc
@@ -0,0 +1,75 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/bitmap.h"
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+
+#include "arrow/array/array_primitive.h"
+#include "arrow/buffer.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace internal {
+
+std::string Bitmap::ToString() const {
+ std::string out(length_ + ((length_ - 1) / 8), ' ');
+ for (int64_t i = 0; i < length_; ++i) {
+ out[i + (i / 8)] = GetBit(i) ? '1' : '0';
+ }
+ return out;
+}
+
+std::shared_ptr<BooleanArray> Bitmap::ToArray() const {
+ return std::make_shared<BooleanArray>(length_, buffer_, nullptr, 0, offset_);
+}
+
+std::string Bitmap::Diff(const Bitmap& other) const {
+ return ToArray()->Diff(*other.ToArray());
+}
+
+void Bitmap::CopyFrom(const Bitmap& other) {
+ ::arrow::internal::CopyBitmap(other.buffer_->data(), other.offset_, other.length_,
+ buffer_->mutable_data(), offset_);
+}
+
+void Bitmap::CopyFromInverted(const Bitmap& other) {
+ ::arrow::internal::InvertBitmap(other.buffer_->data(), other.offset_, other.length_,
+ buffer_->mutable_data(), offset_);
+}
+
+bool Bitmap::Equals(const Bitmap& other) const {
+ if (length_ != other.length_) {
+ return false;
+ }
+ return BitmapEquals(buffer_->data(), offset_, other.buffer_->data(), other.offset(),
+ length_);
+}
+
+int64_t Bitmap::BitLength(const Bitmap* bitmaps, size_t N) {
+ for (size_t i = 1; i < N; ++i) {
+ DCHECK_EQ(bitmaps[i].length(), bitmaps[0].length());
+ }
+ return bitmaps[0].length();
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bitmap.h b/src/arrow/cpp/src/arrow/util/bitmap.h
new file mode 100644
index 000000000..141f863c0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bitmap.h
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <array>
+#include <bitset>
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "arrow/buffer.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/bitmap_writer.h"
+#include "arrow/util/compare.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/functional.h"
+#include "arrow/util/string_builder.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class BooleanArray;
+
+namespace internal {
+
+class ARROW_EXPORT Bitmap : public util::ToStringOstreamable<Bitmap>,
+ public util::EqualityComparable<Bitmap> {
+ public:
+ template <typename Word>
+ using View = util::basic_string_view<Word>;
+
+ Bitmap() = default;
+
+ Bitmap(std::shared_ptr<Buffer> buffer, int64_t offset, int64_t length)
+ : buffer_(std::move(buffer)), offset_(offset), length_(length) {}
+
+ Bitmap(const void* data, int64_t offset, int64_t length)
+ : buffer_(std::make_shared<Buffer>(static_cast<const uint8_t*>(data),
+ BitUtil::BytesForBits(offset + length))),
+ offset_(offset),
+ length_(length) {}
+
+ Bitmap(void* data, int64_t offset, int64_t length)
+ : buffer_(std::make_shared<MutableBuffer>(static_cast<uint8_t*>(data),
+ BitUtil::BytesForBits(offset + length))),
+ offset_(offset),
+ length_(length) {}
+
+ Bitmap Slice(int64_t offset) const {
+ return Bitmap(buffer_, offset_ + offset, length_ - offset);
+ }
+
+ Bitmap Slice(int64_t offset, int64_t length) const {
+ return Bitmap(buffer_, offset_ + offset, length);
+ }
+
+ std::string ToString() const;
+
+ bool Equals(const Bitmap& other) const;
+
+ std::string Diff(const Bitmap& other) const;
+
+ bool GetBit(int64_t i) const { return BitUtil::GetBit(buffer_->data(), i + offset_); }
+
+ bool operator[](int64_t i) const { return GetBit(i); }
+
+ void SetBitTo(int64_t i, bool v) const {
+ BitUtil::SetBitTo(buffer_->mutable_data(), i + offset_, v);
+ }
+
+ void SetBitsTo(bool v) {
+ BitUtil::SetBitsTo(buffer_->mutable_data(), offset_, length_, v);
+ }
+
+ void CopyFrom(const Bitmap& other);
+ void CopyFromInverted(const Bitmap& other);
+
+ /// \brief Visit bits from each bitmap as bitset<N>
+ ///
+ /// All bitmaps must have identical length.
+ template <size_t N, typename Visitor>
+ static void VisitBits(const Bitmap (&bitmaps)[N], Visitor&& visitor) {
+ int64_t bit_length = BitLength(bitmaps, N);
+ std::bitset<N> bits;
+ for (int64_t bit_i = 0; bit_i < bit_length; ++bit_i) {
+ for (size_t i = 0; i < N; ++i) {
+ bits[i] = bitmaps[i].GetBit(bit_i);
+ }
+ visitor(bits);
+ }
+ }
+
+ /// \brief Visit bits from each bitmap as bitset<N>
+ ///
+ /// All bitmaps must have identical length.
+ template <size_t N, typename Visitor>
+ static void VisitBits(const std::array<Bitmap, N>& bitmaps, Visitor&& visitor) {
+ int64_t bit_length = BitLength(bitmaps);
+ std::bitset<N> bits;
+ for (int64_t bit_i = 0; bit_i < bit_length; ++bit_i) {
+ for (size_t i = 0; i < N; ++i) {
+ bits[i] = bitmaps[i].GetBit(bit_i);
+ }
+ visitor(bits);
+ }
+ }
+
+ /// \brief Visit words of bits from each bitmap as array<Word, N>
+ ///
+ /// All bitmaps must have identical length. The first bit in a visited bitmap
+ /// may be offset within the first visited word, but words will otherwise contain
+ /// densely packed bits loaded from the bitmap. That offset within the first word is
+ /// returned.
+ ///
+ /// TODO(bkietz) allow for early termination
+ // NOTE: this function is efficient on 3+ sufficiently large bitmaps.
+ // It also has a large prolog / epilog overhead and should be used
+ // carefully in other cases.
+ // For 2 bitmaps or less, and/or smaller bitmaps, see also VisitTwoBitBlocksVoid
+ // and BitmapUInt64Reader.
+ template <size_t N, typename Visitor,
+ typename Word = typename std::decay<
+ internal::call_traits::argument_type<0, Visitor&&>>::type::value_type>
+ static int64_t VisitWords(const Bitmap (&bitmaps_arg)[N], Visitor&& visitor) {
+ constexpr int64_t kBitWidth = sizeof(Word) * 8;
+
+ // local, mutable variables which will be sliced/decremented to represent consumption:
+ Bitmap bitmaps[N];
+ int64_t offsets[N];
+ int64_t bit_length = BitLength(bitmaps_arg, N);
+ View<Word> words[N];
+ for (size_t i = 0; i < N; ++i) {
+ bitmaps[i] = bitmaps_arg[i];
+ offsets[i] = bitmaps[i].template word_offset<Word>();
+ assert(offsets[i] >= 0 && offsets[i] < kBitWidth);
+ words[i] = bitmaps[i].template words<Word>();
+ }
+
+ auto consume = [&](int64_t consumed_bits) {
+ for (size_t i = 0; i < N; ++i) {
+ bitmaps[i] = bitmaps[i].Slice(consumed_bits, bit_length - consumed_bits);
+ offsets[i] = bitmaps[i].template word_offset<Word>();
+ assert(offsets[i] >= 0 && offsets[i] < kBitWidth);
+ words[i] = bitmaps[i].template words<Word>();
+ }
+ bit_length -= consumed_bits;
+ };
+
+ std::array<Word, N> visited_words;
+ visited_words.fill(0);
+
+ if (bit_length <= kBitWidth * 2) {
+ // bitmaps fit into one or two words so don't bother with optimization
+ while (bit_length > 0) {
+ auto leading_bits = std::min(bit_length, kBitWidth);
+ SafeLoadWords(bitmaps, 0, leading_bits, false, &visited_words);
+ visitor(visited_words);
+ consume(leading_bits);
+ }
+ return 0;
+ }
+
+ int64_t max_offset = *std::max_element(offsets, offsets + N);
+ int64_t min_offset = *std::min_element(offsets, offsets + N);
+ if (max_offset > 0) {
+ // consume leading bits
+ auto leading_bits = kBitWidth - min_offset;
+ SafeLoadWords(bitmaps, 0, leading_bits, true, &visited_words);
+ visitor(visited_words);
+ consume(leading_bits);
+ }
+ assert(*std::min_element(offsets, offsets + N) == 0);
+
+ int64_t whole_word_count = bit_length / kBitWidth;
+ assert(whole_word_count >= 1);
+
+ if (min_offset == max_offset) {
+ // all offsets were identical, all leading bits have been consumed
+ assert(
+ std::all_of(offsets, offsets + N, [](int64_t offset) { return offset == 0; }));
+
+ for (int64_t word_i = 0; word_i < whole_word_count; ++word_i) {
+ for (size_t i = 0; i < N; ++i) {
+ visited_words[i] = words[i][word_i];
+ }
+ visitor(visited_words);
+ }
+ consume(whole_word_count * kBitWidth);
+ } else {
+ // leading bits from potentially incomplete words have been consumed
+
+ // word_i such that words[i][word_i] and words[i][word_i + 1] are lie entirely
+ // within the bitmap for all i
+ for (int64_t word_i = 0; word_i < whole_word_count - 1; ++word_i) {
+ for (size_t i = 0; i < N; ++i) {
+ if (offsets[i] == 0) {
+ visited_words[i] = words[i][word_i];
+ } else {
+ auto words0 = BitUtil::ToLittleEndian(words[i][word_i]);
+ auto words1 = BitUtil::ToLittleEndian(words[i][word_i + 1]);
+ visited_words[i] = BitUtil::FromLittleEndian(
+ (words0 >> offsets[i]) | (words1 << (kBitWidth - offsets[i])));
+ }
+ }
+ visitor(visited_words);
+ }
+ consume((whole_word_count - 1) * kBitWidth);
+
+ SafeLoadWords(bitmaps, 0, kBitWidth, false, &visited_words);
+
+ visitor(visited_words);
+ consume(kBitWidth);
+ }
+
+ // load remaining bits
+ if (bit_length > 0) {
+ SafeLoadWords(bitmaps, 0, bit_length, false, &visited_words);
+ visitor(visited_words);
+ }
+
+ return min_offset;
+ }
+
+ template <size_t N, size_t M, typename ReaderT, typename WriterT, typename Visitor,
+ typename Word = typename std::decay<
+ internal::call_traits::argument_type<0, Visitor&&>>::type::value_type>
+ static void RunVisitWordsAndWriteLoop(int64_t bit_length,
+ std::array<ReaderT, N>& readers,
+ std::array<WriterT, M>& writers,
+ Visitor&& visitor) {
+ constexpr int64_t kBitWidth = sizeof(Word) * 8;
+
+ std::array<Word, N> visited_words;
+ std::array<Word, M> output_words;
+
+ // every reader will have same number of words, since they are same length'ed
+ // TODO($JIRA) this will be inefficient in some cases. When there are offsets beyond
+ // Word boundary, every Word would have to be created from 2 adjoining Words
+ auto n_words = readers[0].words();
+ bit_length -= n_words * kBitWidth;
+ while (n_words--) {
+ // first collect all words to visited_words array
+ for (size_t i = 0; i < N; i++) {
+ visited_words[i] = readers[i].NextWord();
+ }
+ visitor(visited_words, &output_words);
+ for (size_t i = 0; i < M; i++) {
+ writers[i].PutNextWord(output_words[i]);
+ }
+ }
+
+ // every reader will have same number of trailing bytes, because of the above reason
+ // tailing portion could be more than one word! (ref: BitmapWordReader constructor)
+ // remaining full/ partial words to write
+
+ if (bit_length) {
+ // convert the word visitor lambda to a byte_visitor
+ auto byte_visitor = [&](const std::array<uint8_t, N>& in,
+ std::array<uint8_t, M>* out) {
+ std::array<Word, N> in_words;
+ std::array<Word, M> out_words;
+ std::copy(in.begin(), in.end(), in_words.begin());
+ visitor(in_words, &out_words);
+ for (size_t i = 0; i < M; i++) {
+ out->at(i) = static_cast<uint8_t>(out_words[i]);
+ }
+ };
+
+ std::array<uint8_t, N> visited_bytes;
+ std::array<uint8_t, M> output_bytes;
+ int n_bytes = readers[0].trailing_bytes();
+ while (n_bytes--) {
+ visited_bytes.fill(0);
+ output_bytes.fill(0);
+ int valid_bits;
+ for (size_t i = 0; i < N; i++) {
+ visited_bytes[i] = readers[i].NextTrailingByte(valid_bits);
+ }
+ byte_visitor(visited_bytes, &output_bytes);
+ for (size_t i = 0; i < M; i++) {
+ writers[i].PutNextTrailingByte(output_bytes[i], valid_bits);
+ }
+ }
+ }
+ }
+
+ /// \brief Visit words of bits from each input bitmap as array<Word, N> and collects
+ /// outputs to an array<Word, M>, to be written into the output bitmaps accordingly.
+ ///
+ /// All bitmaps must have identical length. The first bit in a visited bitmap
+ /// may be offset within the first visited word, but words will otherwise contain
+ /// densely packed bits loaded from the bitmap. That offset within the first word is
+ /// returned.
+ /// Visitor is expected to have the following signature
+ /// [](const std::array<Word, N>& in_words, std::array<Word, M>* out_words){...}
+ ///
+ // NOTE: this function is efficient on 3+ sufficiently large bitmaps.
+ // It also has a large prolog / epilog overhead and should be used
+ // carefully in other cases.
+ // For 2 bitmaps or less, and/or smaller bitmaps, see also VisitTwoBitBlocksVoid
+ // and BitmapUInt64Reader.
+ template <size_t N, size_t M, typename Visitor,
+ typename Word = typename std::decay<
+ internal::call_traits::argument_type<0, Visitor&&>>::type::value_type>
+ static void VisitWordsAndWrite(const std::array<Bitmap, N>& bitmaps_arg,
+ std::array<Bitmap, M>* out_bitmaps_arg,
+ Visitor&& visitor) {
+ int64_t bit_length = BitLength(bitmaps_arg);
+ assert(bit_length == BitLength(*out_bitmaps_arg));
+
+ // if both input and output bitmaps have no byte offset, then use special template
+ if (std::all_of(bitmaps_arg.begin(), bitmaps_arg.end(),
+ [](const Bitmap& b) { return b.offset_ % 8 == 0; }) &&
+ std::all_of(out_bitmaps_arg->begin(), out_bitmaps_arg->end(),
+ [](const Bitmap& b) { return b.offset_ % 8 == 0; })) {
+ std::array<BitmapWordReader<Word, /*may_have_byte_offset=*/false>, N> readers;
+ for (size_t i = 0; i < N; ++i) {
+ const Bitmap& in_bitmap = bitmaps_arg[i];
+ readers[i] = BitmapWordReader<Word, /*may_have_byte_offset=*/false>(
+ in_bitmap.buffer_->data(), in_bitmap.offset_, in_bitmap.length_);
+ }
+
+ std::array<BitmapWordWriter<Word, /*may_have_byte_offset=*/false>, M> writers;
+ for (size_t i = 0; i < M; ++i) {
+ const Bitmap& out_bitmap = out_bitmaps_arg->at(i);
+ writers[i] = BitmapWordWriter<Word, /*may_have_byte_offset=*/false>(
+ out_bitmap.buffer_->mutable_data(), out_bitmap.offset_, out_bitmap.length_);
+ }
+
+ RunVisitWordsAndWriteLoop(bit_length, readers, writers, visitor);
+ } else {
+ std::array<BitmapWordReader<Word>, N> readers;
+ for (size_t i = 0; i < N; ++i) {
+ const Bitmap& in_bitmap = bitmaps_arg[i];
+ readers[i] = BitmapWordReader<Word>(in_bitmap.buffer_->data(), in_bitmap.offset_,
+ in_bitmap.length_);
+ }
+
+ std::array<BitmapWordWriter<Word>, M> writers;
+ for (size_t i = 0; i < M; ++i) {
+ const Bitmap& out_bitmap = out_bitmaps_arg->at(i);
+ writers[i] = BitmapWordWriter<Word>(out_bitmap.buffer_->mutable_data(),
+ out_bitmap.offset_, out_bitmap.length_);
+ }
+
+ RunVisitWordsAndWriteLoop(bit_length, readers, writers, visitor);
+ }
+ }
+
+ const std::shared_ptr<Buffer>& buffer() const { return buffer_; }
+
+ /// offset of first bit relative to buffer().data()
+ int64_t offset() const { return offset_; }
+
+ /// number of bits in this Bitmap
+ int64_t length() const { return length_; }
+
+ /// string_view of all bytes which contain any bit in this Bitmap
+ util::bytes_view bytes() const {
+ auto byte_offset = offset_ / 8;
+ auto byte_count = BitUtil::CeilDiv(offset_ + length_, 8) - byte_offset;
+ return util::bytes_view(buffer_->data() + byte_offset, byte_count);
+ }
+
+ private:
+ /// string_view of all Words which contain any bit in this Bitmap
+ ///
+ /// For example, given Word=uint16_t and a bitmap spanning bits [20, 36)
+ /// words() would span bits [16, 48).
+ ///
+ /// 0 16 32 48 64
+ /// |-------|-------|------|------| (buffer)
+ /// [ ] (bitmap)
+ /// |-------|------| (returned words)
+ ///
+ /// \warning The words may contain bytes which lie outside the buffer or are
+ /// uninitialized.
+ template <typename Word>
+ View<Word> words() const {
+ auto bytes_addr = reinterpret_cast<intptr_t>(bytes().data());
+ auto words_addr = bytes_addr - bytes_addr % sizeof(Word);
+ auto word_byte_count =
+ BitUtil::RoundUpToPowerOf2(static_cast<int64_t>(bytes_addr + bytes().size()),
+ static_cast<int64_t>(sizeof(Word))) -
+ words_addr;
+ return View<Word>(reinterpret_cast<const Word*>(words_addr),
+ word_byte_count / sizeof(Word));
+ }
+
+ /// offset of first bit relative to words<Word>().data()
+ template <typename Word>
+ int64_t word_offset() const {
+ return offset_ + 8 * (reinterpret_cast<intptr_t>(buffer_->data()) -
+ reinterpret_cast<intptr_t>(words<Word>().data()));
+ }
+
+ /// load words from bitmaps bitwise
+ template <size_t N, typename Word>
+ static void SafeLoadWords(const Bitmap (&bitmaps)[N], int64_t offset,
+ int64_t out_length, bool set_trailing_bits,
+ std::array<Word, N>* out) {
+ out->fill(0);
+
+ int64_t out_offset = set_trailing_bits ? sizeof(Word) * 8 - out_length : 0;
+
+ Bitmap slices[N], out_bitmaps[N];
+ for (size_t i = 0; i < N; ++i) {
+ slices[i] = bitmaps[i].Slice(offset, out_length);
+ out_bitmaps[i] = Bitmap(&out->at(i), out_offset, out_length);
+ }
+
+ int64_t bit_i = 0;
+ Bitmap::VisitBits(slices, [&](std::bitset<N> bits) {
+ for (size_t i = 0; i < N; ++i) {
+ out_bitmaps[i].SetBitTo(bit_i, bits[i]);
+ }
+ ++bit_i;
+ });
+ }
+
+ std::shared_ptr<BooleanArray> ToArray() const;
+
+ /// assert bitmaps have identical length and return that length
+ static int64_t BitLength(const Bitmap* bitmaps, size_t N);
+
+ template <size_t N>
+ static int64_t BitLength(const std::array<Bitmap, N>& bitmaps) {
+ for (size_t i = 1; i < N; ++i) {
+ assert(bitmaps[i].length() == bitmaps[0].length());
+ }
+ return bitmaps[0].length();
+ }
+
+ std::shared_ptr<Buffer> buffer_;
+ int64_t offset_ = 0, length_ = 0;
+};
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bitmap_builders.cc b/src/arrow/cpp/src/arrow/util/bitmap_builders.cc
new file mode 100644
index 000000000..9a91b7ac6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bitmap_builders.cc
@@ -0,0 +1,72 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/bitmap_builders.h"
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <type_traits>
+#include <utility>
+
+#include "arrow/buffer.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/bit_util.h"
+
+namespace arrow {
+namespace internal {
+
+namespace {
+
+void FillBitsFromBytes(const std::vector<uint8_t>& bytes, uint8_t* bits) {
+ for (size_t i = 0; i < bytes.size(); ++i) {
+ if (bytes[i] > 0) {
+ BitUtil::SetBit(bits, i);
+ }
+ }
+}
+
+} // namespace
+
+Result<std::shared_ptr<Buffer>> BytesToBits(const std::vector<uint8_t>& bytes,
+ MemoryPool* pool) {
+ int64_t bit_length = BitUtil::BytesForBits(bytes.size());
+
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(bit_length, pool));
+ uint8_t* out_buf = buffer->mutable_data();
+ memset(out_buf, 0, static_cast<size_t>(buffer->capacity()));
+ FillBitsFromBytes(bytes, out_buf);
+ return std::move(buffer);
+}
+
+Result<std::shared_ptr<Buffer>> BitmapAllButOne(MemoryPool* pool, int64_t length,
+ int64_t straggler_pos, bool value) {
+ if (straggler_pos < 0 || straggler_pos >= length) {
+ return Status::Invalid("invalid straggler_pos ", straggler_pos);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(BitUtil::BytesForBits(length), pool));
+
+ auto bitmap_data = buffer->mutable_data();
+ BitUtil::SetBitsTo(bitmap_data, 0, length, value);
+ BitUtil::SetBitTo(bitmap_data, straggler_pos, !value);
+ return std::move(buffer);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bitmap_builders.h b/src/arrow/cpp/src/arrow/util/bitmap_builders.h
new file mode 100644
index 000000000..5bd2ad441
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bitmap_builders.h
@@ -0,0 +1,43 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "arrow/result.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace internal {
+
+/// \brief Generate Bitmap with all position to `value` except for one found
+/// at `straggler_pos`.
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> BitmapAllButOne(MemoryPool* pool, int64_t length,
+ int64_t straggler_pos, bool value = true);
+
+/// \brief Convert vector of bytes to bitmap buffer
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> BytesToBits(const std::vector<uint8_t>&,
+ MemoryPool* pool = default_memory_pool());
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bitmap_generate.h b/src/arrow/cpp/src/arrow/util/bitmap_generate.h
new file mode 100644
index 000000000..6b900f246
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bitmap_generate.h
@@ -0,0 +1,111 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/buffer.h"
+#include "arrow/memory_pool.h"
+#include "arrow/result.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace internal {
+
+// A std::generate() like function to write sequential bits into a bitmap area.
+// Bits preceding the bitmap area are preserved, bits following the bitmap
+// area may be clobbered.
+
+template <class Generator>
+void GenerateBits(uint8_t* bitmap, int64_t start_offset, int64_t length, Generator&& g) {
+ if (length == 0) {
+ return;
+ }
+ uint8_t* cur = bitmap + start_offset / 8;
+ uint8_t bit_mask = BitUtil::kBitmask[start_offset % 8];
+ uint8_t current_byte = *cur & BitUtil::kPrecedingBitmask[start_offset % 8];
+
+ for (int64_t index = 0; index < length; ++index) {
+ const bool bit = g();
+ current_byte = bit ? (current_byte | bit_mask) : current_byte;
+ bit_mask = static_cast<uint8_t>(bit_mask << 1);
+ if (bit_mask == 0) {
+ bit_mask = 1;
+ *cur++ = current_byte;
+ current_byte = 0;
+ }
+ }
+ if (bit_mask != 1) {
+ *cur++ = current_byte;
+ }
+}
+
+// Like GenerateBits(), but unrolls its main loop for higher performance.
+
+template <class Generator>
+void GenerateBitsUnrolled(uint8_t* bitmap, int64_t start_offset, int64_t length,
+ Generator&& g) {
+ static_assert(std::is_same<decltype(std::declval<Generator>()()), bool>::value,
+ "Functor passed to GenerateBitsUnrolled must return bool");
+
+ if (length == 0) {
+ return;
+ }
+ uint8_t current_byte;
+ uint8_t* cur = bitmap + start_offset / 8;
+ const uint64_t start_bit_offset = start_offset % 8;
+ uint8_t bit_mask = BitUtil::kBitmask[start_bit_offset];
+ int64_t remaining = length;
+
+ if (bit_mask != 0x01) {
+ current_byte = *cur & BitUtil::kPrecedingBitmask[start_bit_offset];
+ while (bit_mask != 0 && remaining > 0) {
+ current_byte |= g() * bit_mask;
+ bit_mask = static_cast<uint8_t>(bit_mask << 1);
+ --remaining;
+ }
+ *cur++ = current_byte;
+ }
+
+ int64_t remaining_bytes = remaining / 8;
+ uint8_t out_results[8];
+ while (remaining_bytes-- > 0) {
+ for (int i = 0; i < 8; ++i) {
+ out_results[i] = g();
+ }
+ *cur++ = (out_results[0] | out_results[1] << 1 | out_results[2] << 2 |
+ out_results[3] << 3 | out_results[4] << 4 | out_results[5] << 5 |
+ out_results[6] << 6 | out_results[7] << 7);
+ }
+
+ int64_t remaining_bits = remaining % 8;
+ if (remaining_bits) {
+ current_byte = 0;
+ bit_mask = 0x01;
+ while (remaining_bits-- > 0) {
+ current_byte |= g() * bit_mask;
+ bit_mask = static_cast<uint8_t>(bit_mask << 1);
+ }
+ *cur++ = current_byte;
+ }
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bitmap_ops.cc b/src/arrow/cpp/src/arrow/util/bitmap_ops.cc
new file mode 100644
index 000000000..63c8b008f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bitmap_ops.cc
@@ -0,0 +1,387 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/bitmap_ops.h"
+
+#include <cstdint>
+#include <cstring>
+#include <functional>
+#include <memory>
+
+#include "arrow/buffer.h"
+#include "arrow/result.h"
+#include "arrow/util/align_util.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/bitmap_writer.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace internal {
+
+int64_t CountSetBits(const uint8_t* data, int64_t bit_offset, int64_t length) {
+ constexpr int64_t pop_len = sizeof(uint64_t) * 8;
+ DCHECK_GE(bit_offset, 0);
+ int64_t count = 0;
+
+ const auto p = BitmapWordAlign<pop_len / 8>(data, bit_offset, length);
+ for (int64_t i = bit_offset; i < bit_offset + p.leading_bits; ++i) {
+ if (BitUtil::GetBit(data, i)) {
+ ++count;
+ }
+ }
+
+ if (p.aligned_words > 0) {
+ // popcount as much as possible with the widest possible count
+ const uint64_t* u64_data = reinterpret_cast<const uint64_t*>(p.aligned_start);
+ DCHECK_EQ(reinterpret_cast<size_t>(u64_data) & 7, 0);
+ const uint64_t* end = u64_data + p.aligned_words;
+
+ constexpr int64_t kCountUnrollFactor = 4;
+ const int64_t words_rounded = BitUtil::RoundDown(p.aligned_words, kCountUnrollFactor);
+ int64_t count_unroll[kCountUnrollFactor] = {0};
+
+ // Unroll the loop for better performance
+ for (int64_t i = 0; i < words_rounded; i += kCountUnrollFactor) {
+ for (int64_t k = 0; k < kCountUnrollFactor; k++) {
+ count_unroll[k] += BitUtil::PopCount(u64_data[k]);
+ }
+ u64_data += kCountUnrollFactor;
+ }
+ for (int64_t k = 0; k < kCountUnrollFactor; k++) {
+ count += count_unroll[k];
+ }
+
+ // The trailing part
+ for (; u64_data < end; ++u64_data) {
+ count += BitUtil::PopCount(*u64_data);
+ }
+ }
+
+ // Account for left over bits (in theory we could fall back to smaller
+ // versions of popcount but the code complexity is likely not worth it)
+ for (int64_t i = p.trailing_bit_offset; i < bit_offset + length; ++i) {
+ if (BitUtil::GetBit(data, i)) {
+ ++count;
+ }
+ }
+
+ return count;
+}
+
+enum class TransferMode : bool { Copy, Invert };
+
+template <TransferMode mode>
+void TransferBitmap(const uint8_t* data, int64_t offset, int64_t length,
+ int64_t dest_offset, uint8_t* dest) {
+ int64_t bit_offset = offset % 8;
+ int64_t dest_bit_offset = dest_offset % 8;
+
+ if (bit_offset || dest_bit_offset) {
+ auto reader = internal::BitmapWordReader<uint64_t>(data, offset, length);
+ auto writer = internal::BitmapWordWriter<uint64_t>(dest, dest_offset, length);
+
+ auto nwords = reader.words();
+ while (nwords--) {
+ auto word = reader.NextWord();
+ writer.PutNextWord(mode == TransferMode::Invert ? ~word : word);
+ }
+ auto nbytes = reader.trailing_bytes();
+ while (nbytes--) {
+ int valid_bits;
+ auto byte = reader.NextTrailingByte(valid_bits);
+ writer.PutNextTrailingByte(mode == TransferMode::Invert ? ~byte : byte, valid_bits);
+ }
+ } else if (length) {
+ int64_t num_bytes = BitUtil::BytesForBits(length);
+
+ // Shift by its byte offset
+ data += offset / 8;
+ dest += dest_offset / 8;
+
+ // Take care of the trailing bits in the last byte
+ // E.g., if trailing_bits = 5, last byte should be
+ // - low 3 bits: new bits from last byte of data buffer
+ // - high 5 bits: old bits from last byte of dest buffer
+ int64_t trailing_bits = num_bytes * 8 - length;
+ uint8_t trail_mask = (1U << (8 - trailing_bits)) - 1;
+ uint8_t last_data;
+
+ if (mode == TransferMode::Invert) {
+ for (int64_t i = 0; i < num_bytes - 1; i++) {
+ dest[i] = static_cast<uint8_t>(~(data[i]));
+ }
+ last_data = ~data[num_bytes - 1];
+ } else {
+ std::memcpy(dest, data, static_cast<size_t>(num_bytes - 1));
+ last_data = data[num_bytes - 1];
+ }
+
+ // Set last byte
+ dest[num_bytes - 1] &= ~trail_mask;
+ dest[num_bytes - 1] |= last_data & trail_mask;
+ }
+}
+
+template <TransferMode mode>
+Result<std::shared_ptr<Buffer>> TransferBitmap(MemoryPool* pool, const uint8_t* data,
+ int64_t offset, int64_t length) {
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateEmptyBitmap(length, pool));
+ uint8_t* dest = buffer->mutable_data();
+
+ TransferBitmap<mode>(data, offset, length, 0, dest);
+
+ // As we have freshly allocated this bitmap, we should take care of zeroing the
+ // remaining bits.
+ int64_t num_bytes = BitUtil::BytesForBits(length);
+ int64_t bits_to_zero = num_bytes * 8 - length;
+ for (int64_t i = length; i < length + bits_to_zero; ++i) {
+ // Both branches may copy extra bits - unsetting to match specification.
+ BitUtil::ClearBit(dest, i);
+ }
+ return buffer;
+}
+
+void CopyBitmap(const uint8_t* data, int64_t offset, int64_t length, uint8_t* dest,
+ int64_t dest_offset) {
+ TransferBitmap<TransferMode::Copy>(data, offset, length, dest_offset, dest);
+}
+
+void InvertBitmap(const uint8_t* data, int64_t offset, int64_t length, uint8_t* dest,
+ int64_t dest_offset) {
+ TransferBitmap<TransferMode::Invert>(data, offset, length, dest_offset, dest);
+}
+
+Result<std::shared_ptr<Buffer>> CopyBitmap(MemoryPool* pool, const uint8_t* data,
+ int64_t offset, int64_t length) {
+ return TransferBitmap<TransferMode::Copy>(pool, data, offset, length);
+}
+
+Result<std::shared_ptr<Buffer>> InvertBitmap(MemoryPool* pool, const uint8_t* data,
+ int64_t offset, int64_t length) {
+ return TransferBitmap<TransferMode::Invert>(pool, data, offset, length);
+}
+
+bool BitmapEquals(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length) {
+ if (left_offset % 8 == 0 && right_offset % 8 == 0) {
+ // byte aligned, can use memcmp
+ bool bytes_equal =
+ std::memcmp(left + left_offset / 8, right + right_offset / 8, length / 8) == 0;
+ if (!bytes_equal) {
+ return false;
+ }
+ for (int64_t i = (length / 8) * 8; i < length; ++i) {
+ if (BitUtil::GetBit(left, left_offset + i) !=
+ BitUtil::GetBit(right, right_offset + i)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ // Unaligned slow case
+ auto left_reader = internal::BitmapWordReader<uint64_t>(left, left_offset, length);
+ auto right_reader = internal::BitmapWordReader<uint64_t>(right, right_offset, length);
+
+ auto nwords = left_reader.words();
+ while (nwords--) {
+ if (left_reader.NextWord() != right_reader.NextWord()) {
+ return false;
+ }
+ }
+ auto nbytes = left_reader.trailing_bytes();
+ while (nbytes--) {
+ int valid_bits;
+ if (left_reader.NextTrailingByte(valid_bits) !=
+ right_reader.NextTrailingByte(valid_bits)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool OptionalBitmapEquals(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length) {
+ if (left == nullptr && right == nullptr) {
+ return true;
+ } else if (left != nullptr && right != nullptr) {
+ return BitmapEquals(left, left_offset, right, right_offset, length);
+ } else if (left != nullptr) {
+ return CountSetBits(left, left_offset, length) == length;
+ } else {
+ return CountSetBits(right, right_offset, length) == length;
+ }
+}
+
+bool OptionalBitmapEquals(const std::shared_ptr<Buffer>& left, int64_t left_offset,
+ const std::shared_ptr<Buffer>& right, int64_t right_offset,
+ int64_t length) {
+ return OptionalBitmapEquals(left ? left->data() : nullptr, left_offset,
+ right ? right->data() : nullptr, right_offset, length);
+}
+
+namespace {
+
+template <template <typename> class BitOp>
+void AlignedBitmapOp(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, uint8_t* out, int64_t out_offset,
+ int64_t length) {
+ BitOp<uint8_t> op;
+ DCHECK_EQ(left_offset % 8, right_offset % 8);
+ DCHECK_EQ(left_offset % 8, out_offset % 8);
+
+ const int64_t nbytes = BitUtil::BytesForBits(length + left_offset % 8);
+ left += left_offset / 8;
+ right += right_offset / 8;
+ out += out_offset / 8;
+ for (int64_t i = 0; i < nbytes; ++i) {
+ out[i] = op(left[i], right[i]);
+ }
+}
+
+template <template <typename> class BitOp>
+void UnalignedBitmapOp(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, uint8_t* out, int64_t out_offset,
+ int64_t length) {
+ BitOp<uint64_t> op_word;
+ BitOp<uint8_t> op_byte;
+
+ auto left_reader = internal::BitmapWordReader<uint64_t>(left, left_offset, length);
+ auto right_reader = internal::BitmapWordReader<uint64_t>(right, right_offset, length);
+ auto writer = internal::BitmapWordWriter<uint64_t>(out, out_offset, length);
+
+ auto nwords = left_reader.words();
+ while (nwords--) {
+ writer.PutNextWord(op_word(left_reader.NextWord(), right_reader.NextWord()));
+ }
+ auto nbytes = left_reader.trailing_bytes();
+ while (nbytes--) {
+ int left_valid_bits, right_valid_bits;
+ uint8_t left_byte = left_reader.NextTrailingByte(left_valid_bits);
+ uint8_t right_byte = right_reader.NextTrailingByte(right_valid_bits);
+ DCHECK_EQ(left_valid_bits, right_valid_bits);
+ writer.PutNextTrailingByte(op_byte(left_byte, right_byte), left_valid_bits);
+ }
+}
+
+template <template <typename> class BitOp>
+void BitmapOp(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset, uint8_t* dest) {
+ if ((out_offset % 8 == left_offset % 8) && (out_offset % 8 == right_offset % 8)) {
+ // Fast case: can use bytewise AND
+ AlignedBitmapOp<BitOp>(left, left_offset, right, right_offset, dest, out_offset,
+ length);
+ } else {
+ // Unaligned
+ UnalignedBitmapOp<BitOp>(left, left_offset, right, right_offset, dest, out_offset,
+ length);
+ }
+}
+
+template <template <typename> class BitOp>
+Result<std::shared_ptr<Buffer>> BitmapOp(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset) {
+ const int64_t phys_bits = length + out_offset;
+ ARROW_ASSIGN_OR_RAISE(auto out_buffer, AllocateEmptyBitmap(phys_bits, pool));
+ BitmapOp<BitOp>(left, left_offset, right, right_offset, length, out_offset,
+ out_buffer->mutable_data());
+ return out_buffer;
+}
+
+} // namespace
+
+Result<std::shared_ptr<Buffer>> BitmapAnd(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset) {
+ return BitmapOp<std::bit_and>(pool, left, left_offset, right, right_offset, length,
+ out_offset);
+}
+
+void BitmapAnd(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset, uint8_t* out) {
+ BitmapOp<std::bit_and>(left, left_offset, right, right_offset, length, out_offset, out);
+}
+
+Result<std::shared_ptr<Buffer>> BitmapOr(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset) {
+ return BitmapOp<std::bit_or>(pool, left, left_offset, right, right_offset, length,
+ out_offset);
+}
+
+void BitmapOr(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset, uint8_t* out) {
+ BitmapOp<std::bit_or>(left, left_offset, right, right_offset, length, out_offset, out);
+}
+
+Result<std::shared_ptr<Buffer>> BitmapXor(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset) {
+ return BitmapOp<std::bit_xor>(pool, left, left_offset, right, right_offset, length,
+ out_offset);
+}
+
+void BitmapXor(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset, uint8_t* out) {
+ BitmapOp<std::bit_xor>(left, left_offset, right, right_offset, length, out_offset, out);
+}
+
+template <typename T>
+struct AndNotOp {
+ constexpr T operator()(const T& l, const T& r) const { return l & ~r; }
+};
+
+Result<std::shared_ptr<Buffer>> BitmapAndNot(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset) {
+ return BitmapOp<AndNotOp>(pool, left, left_offset, right, right_offset, length,
+ out_offset);
+}
+
+void BitmapAndNot(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset,
+ uint8_t* out) {
+ BitmapOp<AndNotOp>(left, left_offset, right, right_offset, length, out_offset, out);
+}
+
+template <typename T>
+struct OrNotOp {
+ constexpr T operator()(const T& l, const T& r) const { return l | ~r; }
+};
+
+Result<std::shared_ptr<Buffer>> BitmapOrNot(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset) {
+ return BitmapOp<OrNotOp>(pool, left, left_offset, right, right_offset, length,
+ out_offset);
+}
+
+void BitmapOrNot(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset, uint8_t* out) {
+ BitmapOp<OrNotOp>(left, left_offset, right, right_offset, length, out_offset, out);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bitmap_ops.h b/src/arrow/cpp/src/arrow/util/bitmap_ops.h
new file mode 100644
index 000000000..40a7797a2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bitmap_ops.h
@@ -0,0 +1,206 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/result.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Buffer;
+class MemoryPool;
+
+namespace internal {
+
+// ----------------------------------------------------------------------
+// Bitmap utilities
+
+/// Copy a bit range of an existing bitmap
+///
+/// \param[in] pool memory pool to allocate memory from
+/// \param[in] bitmap source data
+/// \param[in] offset bit offset into the source data
+/// \param[in] length number of bits to copy
+///
+/// \return Status message
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> CopyBitmap(MemoryPool* pool, const uint8_t* bitmap,
+ int64_t offset, int64_t length);
+
+/// Copy a bit range of an existing bitmap into an existing bitmap
+///
+/// \param[in] bitmap source data
+/// \param[in] offset bit offset into the source data
+/// \param[in] length number of bits to copy
+/// \param[in] dest_offset bit offset into the destination
+/// \param[out] dest the destination buffer, must have at least space for
+/// (offset + length) bits
+ARROW_EXPORT
+void CopyBitmap(const uint8_t* bitmap, int64_t offset, int64_t length, uint8_t* dest,
+ int64_t dest_offset);
+
+/// Invert a bit range of an existing bitmap into an existing bitmap
+///
+/// \param[in] bitmap source data
+/// \param[in] offset bit offset into the source data
+/// \param[in] length number of bits to copy
+/// \param[in] dest_offset bit offset into the destination
+/// \param[out] dest the destination buffer, must have at least space for
+/// (offset + length) bits
+ARROW_EXPORT
+void InvertBitmap(const uint8_t* bitmap, int64_t offset, int64_t length, uint8_t* dest,
+ int64_t dest_offset);
+
+/// Invert a bit range of an existing bitmap
+///
+/// \param[in] pool memory pool to allocate memory from
+/// \param[in] bitmap source data
+/// \param[in] offset bit offset into the source data
+/// \param[in] length number of bits to copy
+///
+/// \return Status message
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> InvertBitmap(MemoryPool* pool, const uint8_t* bitmap,
+ int64_t offset, int64_t length);
+
+/// Compute the number of 1's in the given data array
+///
+/// \param[in] data a packed LSB-ordered bitmap as a byte array
+/// \param[in] bit_offset a bitwise offset into the bitmap
+/// \param[in] length the number of bits to inspect in the bitmap relative to
+/// the offset
+///
+/// \return The number of set (1) bits in the range
+ARROW_EXPORT
+int64_t CountSetBits(const uint8_t* data, int64_t bit_offset, int64_t length);
+
+ARROW_EXPORT
+bool BitmapEquals(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length);
+
+// Same as BitmapEquals, but considers a NULL bitmap pointer the same as an
+// all-ones bitmap.
+ARROW_EXPORT
+bool OptionalBitmapEquals(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length);
+
+ARROW_EXPORT
+bool OptionalBitmapEquals(const std::shared_ptr<Buffer>& left, int64_t left_offset,
+ const std::shared_ptr<Buffer>& right, int64_t right_offset,
+ int64_t length);
+
+/// \brief Do a "bitmap and" on right and left buffers starting at
+/// their respective bit-offsets for the given bit-length and put
+/// the results in out_buffer starting at the given bit-offset.
+///
+/// out_buffer will be allocated and initialized to zeros using pool before
+/// the operation.
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> BitmapAnd(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset);
+
+/// \brief Do a "bitmap and" on right and left buffers starting at
+/// their respective bit-offsets for the given bit-length and put
+/// the results in out starting at the given bit-offset.
+ARROW_EXPORT
+void BitmapAnd(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset, uint8_t* out);
+
+/// \brief Do a "bitmap or" for the given bit length on right and left buffers
+/// starting at their respective bit-offsets and put the results in out_buffer
+/// starting at the given bit-offset.
+///
+/// out_buffer will be allocated and initialized to zeros using pool before
+/// the operation.
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> BitmapOr(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset);
+
+/// \brief Do a "bitmap or" for the given bit length on right and left buffers
+/// starting at their respective bit-offsets and put the results in out
+/// starting at the given bit-offset.
+ARROW_EXPORT
+void BitmapOr(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset, uint8_t* out);
+
+/// \brief Do a "bitmap xor" for the given bit-length on right and left
+/// buffers starting at their respective bit-offsets and put the results in
+/// out_buffer starting at the given bit offset.
+///
+/// out_buffer will be allocated and initialized to zeros using pool before
+/// the operation.
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> BitmapXor(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset);
+
+/// \brief Do a "bitmap xor" for the given bit-length on right and left
+/// buffers starting at their respective bit-offsets and put the results in
+/// out starting at the given bit offset.
+ARROW_EXPORT
+void BitmapXor(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset, uint8_t* out);
+
+/// \brief Do a "bitmap and not" on right and left buffers starting at
+/// their respective bit-offsets for the given bit-length and put
+/// the results in out_buffer starting at the given bit-offset.
+///
+/// out_buffer will be allocated and initialized to zeros using pool before
+/// the operation.
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> BitmapAndNot(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset);
+
+/// \brief Do a "bitmap and not" on right and left buffers starting at
+/// their respective bit-offsets for the given bit-length and put
+/// the results in out starting at the given bit-offset.
+ARROW_EXPORT
+void BitmapAndNot(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset, uint8_t* out);
+
+/// \brief Do a "bitmap or not" on right and left buffers starting at
+/// their respective bit-offsets for the given bit-length and put
+/// the results in out_buffer starting at the given bit-offset.
+///
+/// out_buffer will be allocated and initialized to zeros using pool before
+/// the operation.
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> BitmapOrNot(MemoryPool* pool, const uint8_t* left,
+ int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length,
+ int64_t out_offset);
+
+/// \brief Do a "bitmap or not" on right and left buffers starting at
+/// their respective bit-offsets for the given bit-length and put
+/// the results in out starting at the given bit-offset.
+ARROW_EXPORT
+void BitmapOrNot(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+ int64_t right_offset, int64_t length, int64_t out_offset, uint8_t* out);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bitmap_reader.h b/src/arrow/cpp/src/arrow/util/bitmap_reader.h
new file mode 100644
index 000000000..55d92d15c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bitmap_reader.h
@@ -0,0 +1,271 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <cstring>
+
+#include "arrow/buffer.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace internal {
+
+class BitmapReader {
+ public:
+ BitmapReader(const uint8_t* bitmap, int64_t start_offset, int64_t length)
+ : bitmap_(bitmap), position_(0), length_(length) {
+ current_byte_ = 0;
+ byte_offset_ = start_offset / 8;
+ bit_offset_ = start_offset % 8;
+ if (length > 0) {
+ current_byte_ = bitmap[byte_offset_];
+ }
+ }
+
+ bool IsSet() const { return (current_byte_ & (1 << bit_offset_)) != 0; }
+
+ bool IsNotSet() const { return (current_byte_ & (1 << bit_offset_)) == 0; }
+
+ void Next() {
+ ++bit_offset_;
+ ++position_;
+ if (ARROW_PREDICT_FALSE(bit_offset_ == 8)) {
+ bit_offset_ = 0;
+ ++byte_offset_;
+ if (ARROW_PREDICT_TRUE(position_ < length_)) {
+ current_byte_ = bitmap_[byte_offset_];
+ }
+ }
+ }
+
+ int64_t position() const { return position_; }
+
+ int64_t length() const { return length_; }
+
+ private:
+ const uint8_t* bitmap_;
+ int64_t position_;
+ int64_t length_;
+
+ uint8_t current_byte_;
+ int64_t byte_offset_;
+ int64_t bit_offset_;
+};
+
+// XXX Cannot name it BitmapWordReader because the name is already used
+// in bitmap_ops.cc
+
+class BitmapUInt64Reader {
+ public:
+ BitmapUInt64Reader(const uint8_t* bitmap, int64_t start_offset, int64_t length)
+ : bitmap_(util::MakeNonNull(bitmap) + start_offset / 8),
+ num_carry_bits_(8 - start_offset % 8),
+ length_(length),
+ remaining_length_(length_) {
+ if (length_ > 0) {
+ // Load carry bits from the first byte's MSBs
+ if (length_ >= num_carry_bits_) {
+ carry_bits_ =
+ LoadPartialWord(static_cast<int8_t>(8 - num_carry_bits_), num_carry_bits_);
+ } else {
+ carry_bits_ = LoadPartialWord(static_cast<int8_t>(8 - num_carry_bits_), length_);
+ }
+ }
+ }
+
+ uint64_t NextWord() {
+ if (ARROW_PREDICT_TRUE(remaining_length_ >= 64 + num_carry_bits_)) {
+ // We can load a full word
+ uint64_t next_word = LoadFullWord();
+ // Carry bits come first, then the (64 - num_carry_bits_) LSBs from next_word
+ uint64_t word = carry_bits_ | (next_word << num_carry_bits_);
+ carry_bits_ = next_word >> (64 - num_carry_bits_);
+ remaining_length_ -= 64;
+ return word;
+ } else if (remaining_length_ > num_carry_bits_) {
+ // We can load a partial word
+ uint64_t next_word =
+ LoadPartialWord(/*bit_offset=*/0, remaining_length_ - num_carry_bits_);
+ uint64_t word = carry_bits_ | (next_word << num_carry_bits_);
+ carry_bits_ = next_word >> (64 - num_carry_bits_);
+ remaining_length_ = std::max<int64_t>(remaining_length_ - 64, 0);
+ return word;
+ } else {
+ remaining_length_ = 0;
+ return carry_bits_;
+ }
+ }
+
+ int64_t position() const { return length_ - remaining_length_; }
+
+ int64_t length() const { return length_; }
+
+ private:
+ uint64_t LoadFullWord() {
+ uint64_t word;
+ memcpy(&word, bitmap_, 8);
+ bitmap_ += 8;
+ return BitUtil::ToLittleEndian(word);
+ }
+
+ uint64_t LoadPartialWord(int8_t bit_offset, int64_t num_bits) {
+ uint64_t word = 0;
+ const int64_t num_bytes = BitUtil::BytesForBits(num_bits);
+ memcpy(&word, bitmap_, num_bytes);
+ bitmap_ += num_bytes;
+ return (BitUtil::ToLittleEndian(word) >> bit_offset) &
+ BitUtil::LeastSignificantBitMask(num_bits);
+ }
+
+ const uint8_t* bitmap_;
+ const int64_t num_carry_bits_; // in [1, 8]
+ const int64_t length_;
+ int64_t remaining_length_;
+ uint64_t carry_bits_;
+};
+
+// BitmapWordReader here is faster than BitmapUInt64Reader (in bitmap_reader.h)
+// on sufficiently large inputs. However, it has a larger prolog / epilog overhead
+// and should probably not be used for small bitmaps.
+
+template <typename Word, bool may_have_byte_offset = true>
+class BitmapWordReader {
+ public:
+ BitmapWordReader() = default;
+ BitmapWordReader(const uint8_t* bitmap, int64_t offset, int64_t length)
+ : offset_(static_cast<int64_t>(may_have_byte_offset) * (offset % 8)),
+ bitmap_(bitmap + offset / 8),
+ bitmap_end_(bitmap_ + BitUtil::BytesForBits(offset_ + length)) {
+ // decrement word count by one as we may touch two adjacent words in one iteration
+ nwords_ = length / (sizeof(Word) * 8) - 1;
+ if (nwords_ < 0) {
+ nwords_ = 0;
+ }
+ trailing_bits_ = static_cast<int>(length - nwords_ * sizeof(Word) * 8);
+ trailing_bytes_ = static_cast<int>(BitUtil::BytesForBits(trailing_bits_));
+
+ if (nwords_ > 0) {
+ current_data.word_ = load<Word>(bitmap_);
+ } else if (length > 0) {
+ current_data.epi.byte_ = load<uint8_t>(bitmap_);
+ }
+ }
+
+ Word NextWord() {
+ bitmap_ += sizeof(Word);
+ const Word next_word = load<Word>(bitmap_);
+ Word word = current_data.word_;
+ if (may_have_byte_offset && offset_) {
+ // combine two adjacent words into one word
+ // |<------ next ----->|<---- current ---->|
+ // +-------------+-----+-------------+-----+
+ // | --- | A | B | --- |
+ // +-------------+-----+-------------+-----+
+ // | | offset
+ // v v
+ // +-----+-------------+
+ // | A | B |
+ // +-----+-------------+
+ // |<------ word ----->|
+ word >>= offset_;
+ word |= next_word << (sizeof(Word) * 8 - offset_);
+ }
+ current_data.word_ = next_word;
+ return word;
+ }
+
+ uint8_t NextTrailingByte(int& valid_bits) {
+ uint8_t byte;
+ assert(trailing_bits_ > 0);
+
+ if (trailing_bits_ <= 8) {
+ // last byte
+ valid_bits = trailing_bits_;
+ trailing_bits_ = 0;
+ byte = 0;
+ internal::BitmapReader reader(bitmap_, offset_, valid_bits);
+ for (int i = 0; i < valid_bits; ++i) {
+ byte >>= 1;
+ if (reader.IsSet()) {
+ byte |= 0x80;
+ }
+ reader.Next();
+ }
+ byte >>= (8 - valid_bits);
+ } else {
+ ++bitmap_;
+ const uint8_t next_byte = load<uint8_t>(bitmap_);
+ byte = current_data.epi.byte_;
+ if (may_have_byte_offset && offset_) {
+ byte >>= offset_;
+ byte |= next_byte << (8 - offset_);
+ }
+ current_data.epi.byte_ = next_byte;
+ trailing_bits_ -= 8;
+ trailing_bytes_--;
+ valid_bits = 8;
+ }
+ return byte;
+ }
+
+ int64_t words() const { return nwords_; }
+ int trailing_bytes() const { return trailing_bytes_; }
+
+ private:
+ int64_t offset_;
+ const uint8_t* bitmap_;
+
+ const uint8_t* bitmap_end_;
+ int64_t nwords_;
+ int trailing_bits_;
+ int trailing_bytes_;
+ union {
+ Word word_;
+ struct {
+#if ARROW_LITTLE_ENDIAN == 0
+ uint8_t padding_bytes_[sizeof(Word) - 1];
+#endif
+ uint8_t byte_;
+ } epi;
+ } current_data;
+
+ template <typename DType>
+ DType load(const uint8_t* bitmap) {
+ assert(bitmap + sizeof(DType) <= bitmap_end_);
+ return BitUtil::ToLittleEndian(util::SafeLoadAs<DType>(bitmap));
+ }
+};
+
+/// \brief Index into a possibly non-existent bitmap
+struct OptionalBitIndexer {
+ const uint8_t* bitmap;
+ const int64_t offset;
+
+ explicit OptionalBitIndexer(const std::shared_ptr<Buffer>& buffer, int64_t offset = 0)
+ : bitmap(buffer == NULLPTR ? NULLPTR : buffer->data()), offset(offset) {}
+
+ bool operator[](int64_t i) const {
+ return bitmap == NULLPTR || BitUtil::GetBit(bitmap, offset + i);
+ }
+};
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bitmap_reader_benchmark.cc b/src/arrow/cpp/src/arrow/util/bitmap_reader_benchmark.cc
new file mode 100644
index 000000000..359653c96
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bitmap_reader_benchmark.cc
@@ -0,0 +1,113 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <bitset>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <memory>
+#include <utility>
+
+#include "arrow/buffer.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_reader.h"
+#include "benchmark/benchmark.h"
+
+namespace arrow {
+namespace BitUtil {
+
+using internal::BitBlockCount;
+using internal::BitBlockCounter;
+using internal::BitmapWordReader;
+
+const int64_t kBufferSize = 1024 * (std::rand() % 25 + 1000);
+
+// const int seed = std::rand();
+
+static std::shared_ptr<Buffer> CreateRandomBuffer(int64_t nbytes) {
+ auto buffer = *AllocateBuffer(nbytes);
+ memset(buffer->mutable_data(), 0, nbytes);
+ random_bytes(nbytes, /*seed=*/0, buffer->mutable_data());
+ return std::move(buffer);
+}
+
+static void BitBlockCounterBench(benchmark::State& state) {
+ int64_t nbytes = state.range(0);
+ std::shared_ptr<Buffer> cond_buf = CreateRandomBuffer(nbytes);
+ for (auto _ : state) {
+ BitBlockCounter counter(cond_buf->data(), 0, nbytes * 8);
+
+ int64_t offset = 0;
+ uint64_t set_bits = 0;
+
+ while (offset < nbytes * 8) {
+ const BitBlockCount& word = counter.NextWord();
+ // if (word.AllSet()) {
+ // set_bits += word.length;
+ // } else if (word.popcount) {
+ // set_bits += word.popcount;
+ // }
+ set_bits += word.popcount;
+ benchmark::DoNotOptimize(set_bits);
+ offset += word.length;
+ }
+ benchmark::ClobberMemory();
+ }
+
+ state.SetBytesProcessed(state.iterations() * nbytes);
+}
+
+static void BitmapWordReaderBench(benchmark::State& state) {
+ int64_t nbytes = state.range(0);
+ std::shared_ptr<Buffer> cond_buf = CreateRandomBuffer(nbytes);
+ for (auto _ : state) {
+ BitmapWordReader<uint64_t> counter(cond_buf->data(), 0, nbytes * 8);
+
+ int64_t set_bits = 0;
+
+ int64_t cnt = counter.words();
+ while (cnt--) {
+ const auto& word = counter.NextWord();
+ // if (word == UINT64_MAX) {
+ // set_bits += sizeof(uint64_t) * 8;
+ // } else if (word) {
+ // set_bits += PopCount(word);
+ // }
+ set_bits += PopCount(word);
+ benchmark::DoNotOptimize(set_bits);
+ }
+
+ cnt = counter.trailing_bytes();
+ while (cnt--) {
+ int valid_bits;
+ const auto& byte = static_cast<uint32_t>(counter.NextTrailingByte(valid_bits));
+ set_bits += PopCount(kPrecedingBitmask[valid_bits] & byte);
+ benchmark::DoNotOptimize(set_bits);
+ }
+ benchmark::ClobberMemory();
+ }
+ state.SetBytesProcessed(state.iterations() * nbytes);
+}
+
+BENCHMARK(BitBlockCounterBench)->Arg(kBufferSize);
+BENCHMARK(BitmapWordReaderBench)->Arg(kBufferSize);
+
+} // namespace BitUtil
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bitmap_visit.h b/src/arrow/cpp/src/arrow/util/bitmap_visit.h
new file mode 100644
index 000000000..8a16993e0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bitmap_visit.h
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_reader.h"
+
+namespace arrow {
+namespace internal {
+
+// A function that visits each bit in a bitmap and calls a visitor function with a
+// boolean representation of that bit. This is intended to be analogous to
+// GenerateBits.
+template <class Visitor>
+void VisitBits(const uint8_t* bitmap, int64_t start_offset, int64_t length,
+ Visitor&& visit) {
+ BitmapReader reader(bitmap, start_offset, length);
+ for (int64_t index = 0; index < length; ++index) {
+ visit(reader.IsSet());
+ reader.Next();
+ }
+}
+
+// Like VisitBits(), but unrolls its main loop for better performance.
+template <class Visitor>
+void VisitBitsUnrolled(const uint8_t* bitmap, int64_t start_offset, int64_t length,
+ Visitor&& visit) {
+ if (length == 0) {
+ return;
+ }
+
+ // Start by visiting any bits preceding the first full byte.
+ int64_t num_bits_before_full_bytes =
+ BitUtil::RoundUpToMultipleOf8(start_offset) - start_offset;
+ // Truncate num_bits_before_full_bytes if it is greater than length.
+ if (num_bits_before_full_bytes > length) {
+ num_bits_before_full_bytes = length;
+ }
+ // Use the non loop-unrolled VisitBits since we don't want to add branches
+ VisitBits<Visitor>(bitmap, start_offset, num_bits_before_full_bytes, visit);
+
+ // Shift the start pointer to the first full byte and compute the
+ // number of full bytes to be read.
+ const uint8_t* first_full_byte = bitmap + BitUtil::CeilDiv(start_offset, 8);
+ const int64_t num_full_bytes = (length - num_bits_before_full_bytes) / 8;
+
+ // Iterate over each full byte of the input bitmap and call the visitor in
+ // a loop-unrolled manner.
+ for (int64_t byte_index = 0; byte_index < num_full_bytes; ++byte_index) {
+ // Get the current bit-packed byte value from the bitmap.
+ const uint8_t byte = *(first_full_byte + byte_index);
+
+ // Execute the visitor function on each bit of the current byte.
+ visit(BitUtil::GetBitFromByte(byte, 0));
+ visit(BitUtil::GetBitFromByte(byte, 1));
+ visit(BitUtil::GetBitFromByte(byte, 2));
+ visit(BitUtil::GetBitFromByte(byte, 3));
+ visit(BitUtil::GetBitFromByte(byte, 4));
+ visit(BitUtil::GetBitFromByte(byte, 5));
+ visit(BitUtil::GetBitFromByte(byte, 6));
+ visit(BitUtil::GetBitFromByte(byte, 7));
+ }
+
+ // Write any leftover bits in the last byte.
+ const int64_t num_bits_after_full_bytes = (length - num_bits_before_full_bytes) % 8;
+ VisitBits<Visitor>(first_full_byte + num_full_bytes, 0, num_bits_after_full_bytes,
+ visit);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bitmap_writer.h b/src/arrow/cpp/src/arrow/util/bitmap_writer.h
new file mode 100644
index 000000000..1df1baa0f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bitmap_writer.h
@@ -0,0 +1,285 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <cstring>
+
+#include "arrow/util/bit_util.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace internal {
+
+class BitmapWriter {
+ // A sequential bitwise writer that preserves surrounding bit values.
+
+ public:
+ BitmapWriter(uint8_t* bitmap, int64_t start_offset, int64_t length)
+ : bitmap_(bitmap), position_(0), length_(length) {
+ byte_offset_ = start_offset / 8;
+ bit_mask_ = BitUtil::kBitmask[start_offset % 8];
+ if (length > 0) {
+ current_byte_ = bitmap[byte_offset_];
+ } else {
+ current_byte_ = 0;
+ }
+ }
+
+ void Set() { current_byte_ |= bit_mask_; }
+
+ void Clear() { current_byte_ &= bit_mask_ ^ 0xFF; }
+
+ void Next() {
+ bit_mask_ = static_cast<uint8_t>(bit_mask_ << 1);
+ ++position_;
+ if (bit_mask_ == 0) {
+ // Finished this byte, need advancing
+ bit_mask_ = 0x01;
+ bitmap_[byte_offset_++] = current_byte_;
+ if (ARROW_PREDICT_TRUE(position_ < length_)) {
+ current_byte_ = bitmap_[byte_offset_];
+ }
+ }
+ }
+
+ void Finish() {
+ // Store current byte if we didn't went past bitmap storage
+ if (length_ > 0 && (bit_mask_ != 0x01 || position_ < length_)) {
+ bitmap_[byte_offset_] = current_byte_;
+ }
+ }
+
+ int64_t position() const { return position_; }
+
+ private:
+ uint8_t* bitmap_;
+ int64_t position_;
+ int64_t length_;
+
+ uint8_t current_byte_;
+ uint8_t bit_mask_;
+ int64_t byte_offset_;
+};
+
+class FirstTimeBitmapWriter {
+ // Like BitmapWriter, but any bit values *following* the bits written
+ // might be clobbered. It is hence faster than BitmapWriter, and can
+ // also avoid false positives with Valgrind.
+
+ public:
+ FirstTimeBitmapWriter(uint8_t* bitmap, int64_t start_offset, int64_t length)
+ : bitmap_(bitmap), position_(0), length_(length) {
+ current_byte_ = 0;
+ byte_offset_ = start_offset / 8;
+ bit_mask_ = BitUtil::kBitmask[start_offset % 8];
+ if (length > 0) {
+ current_byte_ = bitmap[byte_offset_] & BitUtil::kPrecedingBitmask[start_offset % 8];
+ } else {
+ current_byte_ = 0;
+ }
+ }
+
+ /// Appends number_of_bits from word to valid_bits and valid_bits_offset.
+ ///
+ /// \param[in] word The LSB bitmap to append. Any bits past number_of_bits are assumed
+ /// to be unset (i.e. 0).
+ /// \param[in] number_of_bits The number of bits to append from word.
+ void AppendWord(uint64_t word, int64_t number_of_bits) {
+ if (ARROW_PREDICT_FALSE(number_of_bits == 0)) {
+ return;
+ }
+
+ // Location that the first byte needs to be written to.
+ uint8_t* append_position = bitmap_ + byte_offset_;
+
+ // Update state variables except for current_byte_ here.
+ position_ += number_of_bits;
+ int64_t bit_offset = BitUtil::CountTrailingZeros(static_cast<uint32_t>(bit_mask_));
+ bit_mask_ = BitUtil::kBitmask[(bit_offset + number_of_bits) % 8];
+ byte_offset_ += (bit_offset + number_of_bits) / 8;
+
+ if (bit_offset != 0) {
+ // We are in the middle of the byte. This code updates the byte and shifts
+ // bits appropriately within word so it can be memcpy'd below.
+ int64_t bits_to_carry = 8 - bit_offset;
+ // Carry over bits from word to current_byte_. We assume any extra bits in word
+ // unset so no additional accounting is needed for when number_of_bits <
+ // bits_to_carry.
+ current_byte_ |= (word & BitUtil::kPrecedingBitmask[bits_to_carry]) << bit_offset;
+ // Check if everything is transfered into current_byte_.
+ if (ARROW_PREDICT_FALSE(number_of_bits < bits_to_carry)) {
+ return;
+ }
+ *append_position = current_byte_;
+ append_position++;
+ // Move the carry bits off of word.
+ word = word >> bits_to_carry;
+ number_of_bits -= bits_to_carry;
+ }
+ word = BitUtil::ToLittleEndian(word);
+ int64_t bytes_for_word = ::arrow::BitUtil::BytesForBits(number_of_bits);
+ std::memcpy(append_position, &word, bytes_for_word);
+ // At this point, the previous current_byte_ has been written to bitmap_.
+ // The new current_byte_ is either the last relevant byte in 'word'
+ // or cleared if the new position is byte aligned (i.e. a fresh byte).
+ if (bit_mask_ == 0x1) {
+ current_byte_ = 0;
+ } else {
+ current_byte_ = *(append_position + bytes_for_word - 1);
+ }
+ }
+
+ void Set() { current_byte_ |= bit_mask_; }
+
+ void Clear() {}
+
+ void Next() {
+ bit_mask_ = static_cast<uint8_t>(bit_mask_ << 1);
+ ++position_;
+ if (bit_mask_ == 0) {
+ // Finished this byte, need advancing
+ bit_mask_ = 0x01;
+ bitmap_[byte_offset_++] = current_byte_;
+ current_byte_ = 0;
+ }
+ }
+
+ void Finish() {
+ // Store current byte if we didn't went go bitmap storage
+ if (length_ > 0 && (bit_mask_ != 0x01 || position_ < length_)) {
+ bitmap_[byte_offset_] = current_byte_;
+ }
+ }
+
+ int64_t position() const { return position_; }
+
+ private:
+ uint8_t* bitmap_;
+ int64_t position_;
+ int64_t length_;
+
+ uint8_t current_byte_;
+ uint8_t bit_mask_;
+ int64_t byte_offset_;
+};
+
+template <typename Word, bool may_have_byte_offset = true>
+class BitmapWordWriter {
+ public:
+ BitmapWordWriter() = default;
+ BitmapWordWriter(uint8_t* bitmap, int64_t offset, int64_t length)
+ : offset_(static_cast<int64_t>(may_have_byte_offset) * (offset % 8)),
+ bitmap_(bitmap + offset / 8),
+ bitmap_end_(bitmap_ + BitUtil::BytesForBits(offset_ + length)),
+ mask_((1U << offset_) - 1) {
+ if (offset_) {
+ if (length >= static_cast<int>(sizeof(Word) * 8)) {
+ current_data.word_ = load<Word>(bitmap_);
+ } else if (length > 0) {
+ current_data.epi.byte_ = load<uint8_t>(bitmap_);
+ }
+ }
+ }
+
+ void PutNextWord(Word word) {
+ if (may_have_byte_offset && offset_) {
+ // split one word into two adjacent words, don't touch unused bits
+ // |<------ word ----->|
+ // +-----+-------------+
+ // | A | B |
+ // +-----+-------------+
+ // | |
+ // v v offset
+ // +-------------+-----+-------------+-----+
+ // | --- | A | B | --- |
+ // +-------------+-----+-------------+-----+
+ // |<------ next ----->|<---- current ---->|
+ word = (word << offset_) | (word >> (sizeof(Word) * 8 - offset_));
+ Word next_word = load<Word>(bitmap_ + sizeof(Word));
+ current_data.word_ = (current_data.word_ & mask_) | (word & ~mask_);
+ next_word = (next_word & ~mask_) | (word & mask_);
+ store<Word>(bitmap_, current_data.word_);
+ store<Word>(bitmap_ + sizeof(Word), next_word);
+ current_data.word_ = next_word;
+ } else {
+ store<Word>(bitmap_, word);
+ }
+ bitmap_ += sizeof(Word);
+ }
+
+ void PutNextTrailingByte(uint8_t byte, int valid_bits) {
+ if (valid_bits == 8) {
+ if (may_have_byte_offset && offset_) {
+ byte = (byte << offset_) | (byte >> (8 - offset_));
+ uint8_t next_byte = load<uint8_t>(bitmap_ + 1);
+ current_data.epi.byte_ = (current_data.epi.byte_ & mask_) | (byte & ~mask_);
+ next_byte = (next_byte & ~mask_) | (byte & mask_);
+ store<uint8_t>(bitmap_, current_data.epi.byte_);
+ store<uint8_t>(bitmap_ + 1, next_byte);
+ current_data.epi.byte_ = next_byte;
+ } else {
+ store<uint8_t>(bitmap_, byte);
+ }
+ ++bitmap_;
+ } else {
+ assert(valid_bits > 0);
+ assert(valid_bits < 8);
+ assert(bitmap_ + BitUtil::BytesForBits(offset_ + valid_bits) <= bitmap_end_);
+ internal::BitmapWriter writer(bitmap_, offset_, valid_bits);
+ for (int i = 0; i < valid_bits; ++i) {
+ (byte & 0x01) ? writer.Set() : writer.Clear();
+ writer.Next();
+ byte >>= 1;
+ }
+ writer.Finish();
+ }
+ }
+
+ private:
+ int64_t offset_;
+ uint8_t* bitmap_;
+
+ const uint8_t* bitmap_end_;
+ uint64_t mask_;
+ union {
+ Word word_;
+ struct {
+#if ARROW_LITTLE_ENDIAN == 0
+ uint8_t padding_bytes_[sizeof(Word) - 1];
+#endif
+ uint8_t byte_;
+ } epi;
+ } current_data;
+
+ template <typename DType>
+ DType load(const uint8_t* bitmap) {
+ assert(bitmap + sizeof(DType) <= bitmap_end_);
+ return BitUtil::ToLittleEndian(util::SafeLoadAs<DType>(bitmap));
+ }
+
+ template <typename DType>
+ void store(uint8_t* bitmap, DType data) {
+ assert(bitmap + sizeof(DType) <= bitmap_end_);
+ util::SafeStore(bitmap, BitUtil::FromLittleEndian(data));
+ }
+};
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bitset_stack.h b/src/arrow/cpp/src/arrow/util/bitset_stack.h
new file mode 100644
index 000000000..addded949
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bitset_stack.h
@@ -0,0 +1,89 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <array>
+#include <bitset>
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/memory_pool.h"
+#include "arrow/result.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/compare.h"
+#include "arrow/util/functional.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_builder.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/type_traits.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace internal {
+
+/// \brief Store a stack of bitsets efficiently. The top bitset may be
+/// accessed and its bits may be modified, but it may not be resized.
+class BitsetStack {
+ public:
+ using reference = typename std::vector<bool>::reference;
+
+ /// \brief push a bitset onto the stack
+ /// \param size number of bits in the next bitset
+ /// \param value initial value for bits in the pushed bitset
+ void Push(int size, bool value) {
+ offsets_.push_back(bit_count());
+ bits_.resize(bit_count() + size, value);
+ }
+
+ /// \brief number of bits in the bitset at the top of the stack
+ int TopSize() const {
+ if (offsets_.size() == 0) return 0;
+ return bit_count() - offsets_.back();
+ }
+
+ /// \brief pop a bitset off the stack
+ void Pop() {
+ bits_.resize(offsets_.back());
+ offsets_.pop_back();
+ }
+
+ /// \brief get the value of a bit in the top bitset
+ /// \param i index of the bit to access
+ bool operator[](int i) const { return bits_[offsets_.back() + i]; }
+
+ /// \brief get a mutable reference to a bit in the top bitset
+ /// \param i index of the bit to access
+ reference operator[](int i) { return bits_[offsets_.back() + i]; }
+
+ private:
+ int bit_count() const { return static_cast<int>(bits_.size()); }
+ std::vector<bool> bits_;
+ std::vector<int> offsets_;
+};
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bpacking.cc b/src/arrow/cpp/src/arrow/util/bpacking.cc
new file mode 100644
index 000000000..c1b0d706a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking.cc
@@ -0,0 +1,396 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/bpacking.h"
+
+#include "arrow/util/bpacking64_default.h"
+#include "arrow/util/bpacking_default.h"
+#include "arrow/util/cpu_info.h"
+#include "arrow/util/dispatch.h"
+#include "arrow/util/logging.h"
+
+#if defined(ARROW_HAVE_RUNTIME_AVX2)
+#include "arrow/util/bpacking_avx2.h"
+#endif
+#if defined(ARROW_HAVE_RUNTIME_AVX512)
+#include "arrow/util/bpacking_avx512.h"
+#endif
+#if defined(ARROW_HAVE_NEON)
+#include "arrow/util/bpacking_neon.h"
+#endif
+
+namespace arrow {
+namespace internal {
+
+namespace {
+
+int unpack32_default(const uint32_t* in, uint32_t* out, int batch_size, int num_bits) {
+ batch_size = batch_size / 32 * 32;
+ int num_loops = batch_size / 32;
+
+ switch (num_bits) {
+ case 0:
+ for (int i = 0; i < num_loops; ++i) in = nullunpacker32(in, out + i * 32);
+ break;
+ case 1:
+ for (int i = 0; i < num_loops; ++i) in = unpack1_32(in, out + i * 32);
+ break;
+ case 2:
+ for (int i = 0; i < num_loops; ++i) in = unpack2_32(in, out + i * 32);
+ break;
+ case 3:
+ for (int i = 0; i < num_loops; ++i) in = unpack3_32(in, out + i * 32);
+ break;
+ case 4:
+ for (int i = 0; i < num_loops; ++i) in = unpack4_32(in, out + i * 32);
+ break;
+ case 5:
+ for (int i = 0; i < num_loops; ++i) in = unpack5_32(in, out + i * 32);
+ break;
+ case 6:
+ for (int i = 0; i < num_loops; ++i) in = unpack6_32(in, out + i * 32);
+ break;
+ case 7:
+ for (int i = 0; i < num_loops; ++i) in = unpack7_32(in, out + i * 32);
+ break;
+ case 8:
+ for (int i = 0; i < num_loops; ++i) in = unpack8_32(in, out + i * 32);
+ break;
+ case 9:
+ for (int i = 0; i < num_loops; ++i) in = unpack9_32(in, out + i * 32);
+ break;
+ case 10:
+ for (int i = 0; i < num_loops; ++i) in = unpack10_32(in, out + i * 32);
+ break;
+ case 11:
+ for (int i = 0; i < num_loops; ++i) in = unpack11_32(in, out + i * 32);
+ break;
+ case 12:
+ for (int i = 0; i < num_loops; ++i) in = unpack12_32(in, out + i * 32);
+ break;
+ case 13:
+ for (int i = 0; i < num_loops; ++i) in = unpack13_32(in, out + i * 32);
+ break;
+ case 14:
+ for (int i = 0; i < num_loops; ++i) in = unpack14_32(in, out + i * 32);
+ break;
+ case 15:
+ for (int i = 0; i < num_loops; ++i) in = unpack15_32(in, out + i * 32);
+ break;
+ case 16:
+ for (int i = 0; i < num_loops; ++i) in = unpack16_32(in, out + i * 32);
+ break;
+ case 17:
+ for (int i = 0; i < num_loops; ++i) in = unpack17_32(in, out + i * 32);
+ break;
+ case 18:
+ for (int i = 0; i < num_loops; ++i) in = unpack18_32(in, out + i * 32);
+ break;
+ case 19:
+ for (int i = 0; i < num_loops; ++i) in = unpack19_32(in, out + i * 32);
+ break;
+ case 20:
+ for (int i = 0; i < num_loops; ++i) in = unpack20_32(in, out + i * 32);
+ break;
+ case 21:
+ for (int i = 0; i < num_loops; ++i) in = unpack21_32(in, out + i * 32);
+ break;
+ case 22:
+ for (int i = 0; i < num_loops; ++i) in = unpack22_32(in, out + i * 32);
+ break;
+ case 23:
+ for (int i = 0; i < num_loops; ++i) in = unpack23_32(in, out + i * 32);
+ break;
+ case 24:
+ for (int i = 0; i < num_loops; ++i) in = unpack24_32(in, out + i * 32);
+ break;
+ case 25:
+ for (int i = 0; i < num_loops; ++i) in = unpack25_32(in, out + i * 32);
+ break;
+ case 26:
+ for (int i = 0; i < num_loops; ++i) in = unpack26_32(in, out + i * 32);
+ break;
+ case 27:
+ for (int i = 0; i < num_loops; ++i) in = unpack27_32(in, out + i * 32);
+ break;
+ case 28:
+ for (int i = 0; i < num_loops; ++i) in = unpack28_32(in, out + i * 32);
+ break;
+ case 29:
+ for (int i = 0; i < num_loops; ++i) in = unpack29_32(in, out + i * 32);
+ break;
+ case 30:
+ for (int i = 0; i < num_loops; ++i) in = unpack30_32(in, out + i * 32);
+ break;
+ case 31:
+ for (int i = 0; i < num_loops; ++i) in = unpack31_32(in, out + i * 32);
+ break;
+ case 32:
+ for (int i = 0; i < num_loops; ++i) in = unpack32_32(in, out + i * 32);
+ break;
+ default:
+ DCHECK(false) << "Unsupported num_bits";
+ }
+
+ return batch_size;
+}
+
+struct Unpack32DynamicFunction {
+ using FunctionType = decltype(&unpack32_default);
+
+ static std::vector<std::pair<DispatchLevel, FunctionType>> implementations() {
+ return {
+ { DispatchLevel::NONE, unpack32_default }
+#if defined(ARROW_HAVE_RUNTIME_AVX2)
+ , { DispatchLevel::AVX2, unpack32_avx2 }
+#endif
+#if defined(ARROW_HAVE_RUNTIME_AVX512)
+ , { DispatchLevel::AVX512, unpack32_avx512 }
+#endif
+ };
+ }
+};
+
+} // namespace
+
+int unpack32(const uint32_t* in, uint32_t* out, int batch_size, int num_bits) {
+#if defined(ARROW_HAVE_NEON)
+ return unpack32_neon(in, out, batch_size, num_bits);
+#else
+ static DynamicDispatch<Unpack32DynamicFunction> dispatch;
+ return dispatch.func(in, out, batch_size, num_bits);
+#endif
+}
+
+namespace {
+
+int unpack64_default(const uint8_t* in, uint64_t* out, int batch_size, int num_bits) {
+ batch_size = batch_size / 32 * 32;
+ int num_loops = batch_size / 32;
+
+ switch (num_bits) {
+ case 0:
+ for (int i = 0; i < num_loops; ++i) in = unpack0_64(in, out + i * 32);
+ break;
+ case 1:
+ for (int i = 0; i < num_loops; ++i) in = unpack1_64(in, out + i * 32);
+ break;
+ case 2:
+ for (int i = 0; i < num_loops; ++i) in = unpack2_64(in, out + i * 32);
+ break;
+ case 3:
+ for (int i = 0; i < num_loops; ++i) in = unpack3_64(in, out + i * 32);
+ break;
+ case 4:
+ for (int i = 0; i < num_loops; ++i) in = unpack4_64(in, out + i * 32);
+ break;
+ case 5:
+ for (int i = 0; i < num_loops; ++i) in = unpack5_64(in, out + i * 32);
+ break;
+ case 6:
+ for (int i = 0; i < num_loops; ++i) in = unpack6_64(in, out + i * 32);
+ break;
+ case 7:
+ for (int i = 0; i < num_loops; ++i) in = unpack7_64(in, out + i * 32);
+ break;
+ case 8:
+ for (int i = 0; i < num_loops; ++i) in = unpack8_64(in, out + i * 32);
+ break;
+ case 9:
+ for (int i = 0; i < num_loops; ++i) in = unpack9_64(in, out + i * 32);
+ break;
+ case 10:
+ for (int i = 0; i < num_loops; ++i) in = unpack10_64(in, out + i * 32);
+ break;
+ case 11:
+ for (int i = 0; i < num_loops; ++i) in = unpack11_64(in, out + i * 32);
+ break;
+ case 12:
+ for (int i = 0; i < num_loops; ++i) in = unpack12_64(in, out + i * 32);
+ break;
+ case 13:
+ for (int i = 0; i < num_loops; ++i) in = unpack13_64(in, out + i * 32);
+ break;
+ case 14:
+ for (int i = 0; i < num_loops; ++i) in = unpack14_64(in, out + i * 32);
+ break;
+ case 15:
+ for (int i = 0; i < num_loops; ++i) in = unpack15_64(in, out + i * 32);
+ break;
+ case 16:
+ for (int i = 0; i < num_loops; ++i) in = unpack16_64(in, out + i * 32);
+ break;
+ case 17:
+ for (int i = 0; i < num_loops; ++i) in = unpack17_64(in, out + i * 32);
+ break;
+ case 18:
+ for (int i = 0; i < num_loops; ++i) in = unpack18_64(in, out + i * 32);
+ break;
+ case 19:
+ for (int i = 0; i < num_loops; ++i) in = unpack19_64(in, out + i * 32);
+ break;
+ case 20:
+ for (int i = 0; i < num_loops; ++i) in = unpack20_64(in, out + i * 32);
+ break;
+ case 21:
+ for (int i = 0; i < num_loops; ++i) in = unpack21_64(in, out + i * 32);
+ break;
+ case 22:
+ for (int i = 0; i < num_loops; ++i) in = unpack22_64(in, out + i * 32);
+ break;
+ case 23:
+ for (int i = 0; i < num_loops; ++i) in = unpack23_64(in, out + i * 32);
+ break;
+ case 24:
+ for (int i = 0; i < num_loops; ++i) in = unpack24_64(in, out + i * 32);
+ break;
+ case 25:
+ for (int i = 0; i < num_loops; ++i) in = unpack25_64(in, out + i * 32);
+ break;
+ case 26:
+ for (int i = 0; i < num_loops; ++i) in = unpack26_64(in, out + i * 32);
+ break;
+ case 27:
+ for (int i = 0; i < num_loops; ++i) in = unpack27_64(in, out + i * 32);
+ break;
+ case 28:
+ for (int i = 0; i < num_loops; ++i) in = unpack28_64(in, out + i * 32);
+ break;
+ case 29:
+ for (int i = 0; i < num_loops; ++i) in = unpack29_64(in, out + i * 32);
+ break;
+ case 30:
+ for (int i = 0; i < num_loops; ++i) in = unpack30_64(in, out + i * 32);
+ break;
+ case 31:
+ for (int i = 0; i < num_loops; ++i) in = unpack31_64(in, out + i * 32);
+ break;
+ case 32:
+ for (int i = 0; i < num_loops; ++i) in = unpack32_64(in, out + i * 32);
+ break;
+ case 33:
+ for (int i = 0; i < num_loops; ++i) in = unpack33_64(in, out + i * 32);
+ break;
+ case 34:
+ for (int i = 0; i < num_loops; ++i) in = unpack34_64(in, out + i * 32);
+ break;
+ case 35:
+ for (int i = 0; i < num_loops; ++i) in = unpack35_64(in, out + i * 32);
+ break;
+ case 36:
+ for (int i = 0; i < num_loops; ++i) in = unpack36_64(in, out + i * 32);
+ break;
+ case 37:
+ for (int i = 0; i < num_loops; ++i) in = unpack37_64(in, out + i * 32);
+ break;
+ case 38:
+ for (int i = 0; i < num_loops; ++i) in = unpack38_64(in, out + i * 32);
+ break;
+ case 39:
+ for (int i = 0; i < num_loops; ++i) in = unpack39_64(in, out + i * 32);
+ break;
+ case 40:
+ for (int i = 0; i < num_loops; ++i) in = unpack40_64(in, out + i * 32);
+ break;
+ case 41:
+ for (int i = 0; i < num_loops; ++i) in = unpack41_64(in, out + i * 32);
+ break;
+ case 42:
+ for (int i = 0; i < num_loops; ++i) in = unpack42_64(in, out + i * 32);
+ break;
+ case 43:
+ for (int i = 0; i < num_loops; ++i) in = unpack43_64(in, out + i * 32);
+ break;
+ case 44:
+ for (int i = 0; i < num_loops; ++i) in = unpack44_64(in, out + i * 32);
+ break;
+ case 45:
+ for (int i = 0; i < num_loops; ++i) in = unpack45_64(in, out + i * 32);
+ break;
+ case 46:
+ for (int i = 0; i < num_loops; ++i) in = unpack46_64(in, out + i * 32);
+ break;
+ case 47:
+ for (int i = 0; i < num_loops; ++i) in = unpack47_64(in, out + i * 32);
+ break;
+ case 48:
+ for (int i = 0; i < num_loops; ++i) in = unpack48_64(in, out + i * 32);
+ break;
+ case 49:
+ for (int i = 0; i < num_loops; ++i) in = unpack49_64(in, out + i * 32);
+ break;
+ case 50:
+ for (int i = 0; i < num_loops; ++i) in = unpack50_64(in, out + i * 32);
+ break;
+ case 51:
+ for (int i = 0; i < num_loops; ++i) in = unpack51_64(in, out + i * 32);
+ break;
+ case 52:
+ for (int i = 0; i < num_loops; ++i) in = unpack52_64(in, out + i * 32);
+ break;
+ case 53:
+ for (int i = 0; i < num_loops; ++i) in = unpack53_64(in, out + i * 32);
+ break;
+ case 54:
+ for (int i = 0; i < num_loops; ++i) in = unpack54_64(in, out + i * 32);
+ break;
+ case 55:
+ for (int i = 0; i < num_loops; ++i) in = unpack55_64(in, out + i * 32);
+ break;
+ case 56:
+ for (int i = 0; i < num_loops; ++i) in = unpack56_64(in, out + i * 32);
+ break;
+ case 57:
+ for (int i = 0; i < num_loops; ++i) in = unpack57_64(in, out + i * 32);
+ break;
+ case 58:
+ for (int i = 0; i < num_loops; ++i) in = unpack58_64(in, out + i * 32);
+ break;
+ case 59:
+ for (int i = 0; i < num_loops; ++i) in = unpack59_64(in, out + i * 32);
+ break;
+ case 60:
+ for (int i = 0; i < num_loops; ++i) in = unpack60_64(in, out + i * 32);
+ break;
+ case 61:
+ for (int i = 0; i < num_loops; ++i) in = unpack61_64(in, out + i * 32);
+ break;
+ case 62:
+ for (int i = 0; i < num_loops; ++i) in = unpack62_64(in, out + i * 32);
+ break;
+ case 63:
+ for (int i = 0; i < num_loops; ++i) in = unpack63_64(in, out + i * 32);
+ break;
+ case 64:
+ for (int i = 0; i < num_loops; ++i) in = unpack64_64(in, out + i * 32);
+ break;
+ default:
+ DCHECK(false) << "Unsupported num_bits";
+ }
+
+ return batch_size;
+}
+
+} // namespace
+
+int unpack64(const uint8_t* in, uint64_t* out, int batch_size, int num_bits) {
+ // TODO: unpack64_neon, unpack64_avx2 and unpack64_avx512
+ return unpack64_default(in, out, batch_size, num_bits);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bpacking.h b/src/arrow/cpp/src/arrow/util/bpacking.h
new file mode 100644
index 000000000..dd85c1638
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking.h
@@ -0,0 +1,34 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/util/endian.h"
+#include "arrow/util/visibility.h"
+
+#include <stdint.h>
+
+namespace arrow {
+namespace internal {
+
+ARROW_EXPORT
+int unpack32(const uint32_t* in, uint32_t* out, int batch_size, int num_bits);
+ARROW_EXPORT
+int unpack64(const uint8_t* in, uint64_t* out, int batch_size, int num_bits);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bpacking64_codegen.py b/src/arrow/cpp/src/arrow/util/bpacking64_codegen.py
new file mode 100644
index 000000000..f9b06b4d8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking64_codegen.py
@@ -0,0 +1,131 @@
+#!/bin/python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This script is modified from its original version in GitHub. Original source:
+# https://github.com/lemire/FrameOfReference/blob/146948b6058a976bc7767262ad3a2ce201486b93/scripts/turbopacking64.py
+
+# Usage:
+# python bpacking64_codegen.py > bpacking64_default.h
+
+def howmany(bit):
+ """ how many values are we going to pack? """
+ return 32
+
+
+def howmanywords(bit):
+ return (howmany(bit) * bit + 63)//64
+
+
+def howmanybytes(bit):
+ return (howmany(bit) * bit + 7)//8
+
+
+print('''// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This file was generated by script which is modified from its original version in GitHub.
+// Original source:
+// https://github.com/lemire/FrameOfReference/blob/master/scripts/turbopacking64.py
+// The original copyright notice follows.
+
+// This code is released under the
+// Apache License Version 2.0 http://www.apache.org/licenses/.
+// (c) Daniel Lemire 2013
+
+#pragma once
+
+#include "arrow/util/bit_util.h"
+#include "arrow/util/ubsan.h"
+
+namespace arrow {
+namespace internal {
+''')
+
+
+print("inline const uint8_t* unpack0_64(const uint8_t* in, uint64_t* out) {")
+print(" for(int k = 0; k < {0} ; k += 1) {{".format(howmany(0)))
+print(" out[k] = 0;")
+print(" }")
+print(" return in;")
+print("}")
+
+for bit in range(1, 65):
+ print("")
+ print(
+ "inline const uint8_t* unpack{0}_64(const uint8_t* in, uint64_t* out) {{".format(bit))
+
+ if(bit < 64):
+ print(" const uint64_t mask = {0}ULL;".format((1 << bit)-1))
+ maskstr = " & mask"
+ if (bit == 64):
+ maskstr = "" # no need
+
+ for k in range(howmanywords(bit)-1):
+ print(" uint64_t w{0} = util::SafeLoadAs<uint64_t>(in);".format(k))
+ print(" w{0} = arrow::BitUtil::FromLittleEndian(w{0});".format(k))
+ print(" in += 8;".format(k))
+ k = howmanywords(bit) - 1
+ if (bit % 2 == 0):
+ print(" uint64_t w{0} = util::SafeLoadAs<uint64_t>(in);".format(k))
+ print(" w{0} = arrow::BitUtil::FromLittleEndian(w{0});".format(k))
+ print(" in += 8;".format(k))
+ else:
+ print(" uint64_t w{0} = util::SafeLoadAs<uint32_t>(in);".format(k))
+ print(" w{0} = arrow::BitUtil::FromLittleEndian(w{0});".format(k))
+ print(" in += 4;".format(k))
+
+ for j in range(howmany(bit)):
+ firstword = j * bit // 64
+ secondword = (j * bit + bit - 1)//64
+ firstshift = (j*bit) % 64
+ firstshiftstr = " >> {0}".format(firstshift)
+ if(firstshift == 0):
+ firstshiftstr = "" # no need
+ if(firstword == secondword):
+ if(firstshift + bit == 64):
+ print(" out[{0}] = w{1}{2};".format(
+ j, firstword, firstshiftstr, firstshift))
+ else:
+ print(" out[{0}] = (w{1}{2}){3};".format(
+ j, firstword, firstshiftstr, maskstr))
+ else:
+ secondshift = (64-firstshift)
+ print(" out[{0}] = ((w{1}{2}) | (w{3} << {4})){5};".format(
+ j, firstword, firstshiftstr, firstword+1, secondshift, maskstr))
+ print("")
+ print(" return in;")
+ print("}")
+
+print('''
+} // namespace internal
+} // namespace arrow''')
diff --git a/src/arrow/cpp/src/arrow/util/bpacking64_default.h b/src/arrow/cpp/src/arrow/util/bpacking64_default.h
new file mode 100644
index 000000000..81189cb58
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking64_default.h
@@ -0,0 +1,5642 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This file was generated by script which is modified from its original version in
+// GitHub. Original source:
+// https://github.com/lemire/FrameOfReference/blob/146948b6058a976bc7767262ad3a2ce201486b93/scripts/turbopacking64.py
+// The original copyright notice follows.
+
+// This code is released under the
+// Apache License Version 2.0 http://www.apache.org/licenses/.
+// (c) Daniel Lemire 2013
+
+#pragma once
+
+#include "arrow/util/bit_util.h"
+#include "arrow/util/ubsan.h"
+
+namespace arrow {
+namespace internal {
+
+inline const uint8_t* unpack0_64(const uint8_t* in, uint64_t* out) {
+ for (int k = 0; k < 32; k += 1) {
+ out[k] = 0;
+ }
+ return in;
+}
+
+inline const uint8_t* unpack1_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 1ULL;
+ uint64_t w0 = util::SafeLoadAs<uint32_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 1) & mask;
+ out[2] = (w0 >> 2) & mask;
+ out[3] = (w0 >> 3) & mask;
+ out[4] = (w0 >> 4) & mask;
+ out[5] = (w0 >> 5) & mask;
+ out[6] = (w0 >> 6) & mask;
+ out[7] = (w0 >> 7) & mask;
+ out[8] = (w0 >> 8) & mask;
+ out[9] = (w0 >> 9) & mask;
+ out[10] = (w0 >> 10) & mask;
+ out[11] = (w0 >> 11) & mask;
+ out[12] = (w0 >> 12) & mask;
+ out[13] = (w0 >> 13) & mask;
+ out[14] = (w0 >> 14) & mask;
+ out[15] = (w0 >> 15) & mask;
+ out[16] = (w0 >> 16) & mask;
+ out[17] = (w0 >> 17) & mask;
+ out[18] = (w0 >> 18) & mask;
+ out[19] = (w0 >> 19) & mask;
+ out[20] = (w0 >> 20) & mask;
+ out[21] = (w0 >> 21) & mask;
+ out[22] = (w0 >> 22) & mask;
+ out[23] = (w0 >> 23) & mask;
+ out[24] = (w0 >> 24) & mask;
+ out[25] = (w0 >> 25) & mask;
+ out[26] = (w0 >> 26) & mask;
+ out[27] = (w0 >> 27) & mask;
+ out[28] = (w0 >> 28) & mask;
+ out[29] = (w0 >> 29) & mask;
+ out[30] = (w0 >> 30) & mask;
+ out[31] = (w0 >> 31) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack2_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 3ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 2) & mask;
+ out[2] = (w0 >> 4) & mask;
+ out[3] = (w0 >> 6) & mask;
+ out[4] = (w0 >> 8) & mask;
+ out[5] = (w0 >> 10) & mask;
+ out[6] = (w0 >> 12) & mask;
+ out[7] = (w0 >> 14) & mask;
+ out[8] = (w0 >> 16) & mask;
+ out[9] = (w0 >> 18) & mask;
+ out[10] = (w0 >> 20) & mask;
+ out[11] = (w0 >> 22) & mask;
+ out[12] = (w0 >> 24) & mask;
+ out[13] = (w0 >> 26) & mask;
+ out[14] = (w0 >> 28) & mask;
+ out[15] = (w0 >> 30) & mask;
+ out[16] = (w0 >> 32) & mask;
+ out[17] = (w0 >> 34) & mask;
+ out[18] = (w0 >> 36) & mask;
+ out[19] = (w0 >> 38) & mask;
+ out[20] = (w0 >> 40) & mask;
+ out[21] = (w0 >> 42) & mask;
+ out[22] = (w0 >> 44) & mask;
+ out[23] = (w0 >> 46) & mask;
+ out[24] = (w0 >> 48) & mask;
+ out[25] = (w0 >> 50) & mask;
+ out[26] = (w0 >> 52) & mask;
+ out[27] = (w0 >> 54) & mask;
+ out[28] = (w0 >> 56) & mask;
+ out[29] = (w0 >> 58) & mask;
+ out[30] = (w0 >> 60) & mask;
+ out[31] = w0 >> 62;
+
+ return in;
+}
+
+inline const uint8_t* unpack3_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 7ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint32_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 3) & mask;
+ out[2] = (w0 >> 6) & mask;
+ out[3] = (w0 >> 9) & mask;
+ out[4] = (w0 >> 12) & mask;
+ out[5] = (w0 >> 15) & mask;
+ out[6] = (w0 >> 18) & mask;
+ out[7] = (w0 >> 21) & mask;
+ out[8] = (w0 >> 24) & mask;
+ out[9] = (w0 >> 27) & mask;
+ out[10] = (w0 >> 30) & mask;
+ out[11] = (w0 >> 33) & mask;
+ out[12] = (w0 >> 36) & mask;
+ out[13] = (w0 >> 39) & mask;
+ out[14] = (w0 >> 42) & mask;
+ out[15] = (w0 >> 45) & mask;
+ out[16] = (w0 >> 48) & mask;
+ out[17] = (w0 >> 51) & mask;
+ out[18] = (w0 >> 54) & mask;
+ out[19] = (w0 >> 57) & mask;
+ out[20] = (w0 >> 60) & mask;
+ out[21] = ((w0 >> 63) | (w1 << 1)) & mask;
+ out[22] = (w1 >> 2) & mask;
+ out[23] = (w1 >> 5) & mask;
+ out[24] = (w1 >> 8) & mask;
+ out[25] = (w1 >> 11) & mask;
+ out[26] = (w1 >> 14) & mask;
+ out[27] = (w1 >> 17) & mask;
+ out[28] = (w1 >> 20) & mask;
+ out[29] = (w1 >> 23) & mask;
+ out[30] = (w1 >> 26) & mask;
+ out[31] = (w1 >> 29) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack4_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 15ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 4) & mask;
+ out[2] = (w0 >> 8) & mask;
+ out[3] = (w0 >> 12) & mask;
+ out[4] = (w0 >> 16) & mask;
+ out[5] = (w0 >> 20) & mask;
+ out[6] = (w0 >> 24) & mask;
+ out[7] = (w0 >> 28) & mask;
+ out[8] = (w0 >> 32) & mask;
+ out[9] = (w0 >> 36) & mask;
+ out[10] = (w0 >> 40) & mask;
+ out[11] = (w0 >> 44) & mask;
+ out[12] = (w0 >> 48) & mask;
+ out[13] = (w0 >> 52) & mask;
+ out[14] = (w0 >> 56) & mask;
+ out[15] = w0 >> 60;
+ out[16] = (w1)&mask;
+ out[17] = (w1 >> 4) & mask;
+ out[18] = (w1 >> 8) & mask;
+ out[19] = (w1 >> 12) & mask;
+ out[20] = (w1 >> 16) & mask;
+ out[21] = (w1 >> 20) & mask;
+ out[22] = (w1 >> 24) & mask;
+ out[23] = (w1 >> 28) & mask;
+ out[24] = (w1 >> 32) & mask;
+ out[25] = (w1 >> 36) & mask;
+ out[26] = (w1 >> 40) & mask;
+ out[27] = (w1 >> 44) & mask;
+ out[28] = (w1 >> 48) & mask;
+ out[29] = (w1 >> 52) & mask;
+ out[30] = (w1 >> 56) & mask;
+ out[31] = w1 >> 60;
+
+ return in;
+}
+
+inline const uint8_t* unpack5_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 31ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint32_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 5) & mask;
+ out[2] = (w0 >> 10) & mask;
+ out[3] = (w0 >> 15) & mask;
+ out[4] = (w0 >> 20) & mask;
+ out[5] = (w0 >> 25) & mask;
+ out[6] = (w0 >> 30) & mask;
+ out[7] = (w0 >> 35) & mask;
+ out[8] = (w0 >> 40) & mask;
+ out[9] = (w0 >> 45) & mask;
+ out[10] = (w0 >> 50) & mask;
+ out[11] = (w0 >> 55) & mask;
+ out[12] = ((w0 >> 60) | (w1 << 4)) & mask;
+ out[13] = (w1 >> 1) & mask;
+ out[14] = (w1 >> 6) & mask;
+ out[15] = (w1 >> 11) & mask;
+ out[16] = (w1 >> 16) & mask;
+ out[17] = (w1 >> 21) & mask;
+ out[18] = (w1 >> 26) & mask;
+ out[19] = (w1 >> 31) & mask;
+ out[20] = (w1 >> 36) & mask;
+ out[21] = (w1 >> 41) & mask;
+ out[22] = (w1 >> 46) & mask;
+ out[23] = (w1 >> 51) & mask;
+ out[24] = (w1 >> 56) & mask;
+ out[25] = ((w1 >> 61) | (w2 << 3)) & mask;
+ out[26] = (w2 >> 2) & mask;
+ out[27] = (w2 >> 7) & mask;
+ out[28] = (w2 >> 12) & mask;
+ out[29] = (w2 >> 17) & mask;
+ out[30] = (w2 >> 22) & mask;
+ out[31] = (w2 >> 27) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack6_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 63ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 6) & mask;
+ out[2] = (w0 >> 12) & mask;
+ out[3] = (w0 >> 18) & mask;
+ out[4] = (w0 >> 24) & mask;
+ out[5] = (w0 >> 30) & mask;
+ out[6] = (w0 >> 36) & mask;
+ out[7] = (w0 >> 42) & mask;
+ out[8] = (w0 >> 48) & mask;
+ out[9] = (w0 >> 54) & mask;
+ out[10] = ((w0 >> 60) | (w1 << 4)) & mask;
+ out[11] = (w1 >> 2) & mask;
+ out[12] = (w1 >> 8) & mask;
+ out[13] = (w1 >> 14) & mask;
+ out[14] = (w1 >> 20) & mask;
+ out[15] = (w1 >> 26) & mask;
+ out[16] = (w1 >> 32) & mask;
+ out[17] = (w1 >> 38) & mask;
+ out[18] = (w1 >> 44) & mask;
+ out[19] = (w1 >> 50) & mask;
+ out[20] = (w1 >> 56) & mask;
+ out[21] = ((w1 >> 62) | (w2 << 2)) & mask;
+ out[22] = (w2 >> 4) & mask;
+ out[23] = (w2 >> 10) & mask;
+ out[24] = (w2 >> 16) & mask;
+ out[25] = (w2 >> 22) & mask;
+ out[26] = (w2 >> 28) & mask;
+ out[27] = (w2 >> 34) & mask;
+ out[28] = (w2 >> 40) & mask;
+ out[29] = (w2 >> 46) & mask;
+ out[30] = (w2 >> 52) & mask;
+ out[31] = w2 >> 58;
+
+ return in;
+}
+
+inline const uint8_t* unpack7_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 127ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint32_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 7) & mask;
+ out[2] = (w0 >> 14) & mask;
+ out[3] = (w0 >> 21) & mask;
+ out[4] = (w0 >> 28) & mask;
+ out[5] = (w0 >> 35) & mask;
+ out[6] = (w0 >> 42) & mask;
+ out[7] = (w0 >> 49) & mask;
+ out[8] = (w0 >> 56) & mask;
+ out[9] = ((w0 >> 63) | (w1 << 1)) & mask;
+ out[10] = (w1 >> 6) & mask;
+ out[11] = (w1 >> 13) & mask;
+ out[12] = (w1 >> 20) & mask;
+ out[13] = (w1 >> 27) & mask;
+ out[14] = (w1 >> 34) & mask;
+ out[15] = (w1 >> 41) & mask;
+ out[16] = (w1 >> 48) & mask;
+ out[17] = (w1 >> 55) & mask;
+ out[18] = ((w1 >> 62) | (w2 << 2)) & mask;
+ out[19] = (w2 >> 5) & mask;
+ out[20] = (w2 >> 12) & mask;
+ out[21] = (w2 >> 19) & mask;
+ out[22] = (w2 >> 26) & mask;
+ out[23] = (w2 >> 33) & mask;
+ out[24] = (w2 >> 40) & mask;
+ out[25] = (w2 >> 47) & mask;
+ out[26] = (w2 >> 54) & mask;
+ out[27] = ((w2 >> 61) | (w3 << 3)) & mask;
+ out[28] = (w3 >> 4) & mask;
+ out[29] = (w3 >> 11) & mask;
+ out[30] = (w3 >> 18) & mask;
+ out[31] = (w3 >> 25) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack8_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 255ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 8) & mask;
+ out[2] = (w0 >> 16) & mask;
+ out[3] = (w0 >> 24) & mask;
+ out[4] = (w0 >> 32) & mask;
+ out[5] = (w0 >> 40) & mask;
+ out[6] = (w0 >> 48) & mask;
+ out[7] = w0 >> 56;
+ out[8] = (w1)&mask;
+ out[9] = (w1 >> 8) & mask;
+ out[10] = (w1 >> 16) & mask;
+ out[11] = (w1 >> 24) & mask;
+ out[12] = (w1 >> 32) & mask;
+ out[13] = (w1 >> 40) & mask;
+ out[14] = (w1 >> 48) & mask;
+ out[15] = w1 >> 56;
+ out[16] = (w2)&mask;
+ out[17] = (w2 >> 8) & mask;
+ out[18] = (w2 >> 16) & mask;
+ out[19] = (w2 >> 24) & mask;
+ out[20] = (w2 >> 32) & mask;
+ out[21] = (w2 >> 40) & mask;
+ out[22] = (w2 >> 48) & mask;
+ out[23] = w2 >> 56;
+ out[24] = (w3)&mask;
+ out[25] = (w3 >> 8) & mask;
+ out[26] = (w3 >> 16) & mask;
+ out[27] = (w3 >> 24) & mask;
+ out[28] = (w3 >> 32) & mask;
+ out[29] = (w3 >> 40) & mask;
+ out[30] = (w3 >> 48) & mask;
+ out[31] = w3 >> 56;
+
+ return in;
+}
+
+inline const uint8_t* unpack9_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 511ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint32_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 9) & mask;
+ out[2] = (w0 >> 18) & mask;
+ out[3] = (w0 >> 27) & mask;
+ out[4] = (w0 >> 36) & mask;
+ out[5] = (w0 >> 45) & mask;
+ out[6] = (w0 >> 54) & mask;
+ out[7] = ((w0 >> 63) | (w1 << 1)) & mask;
+ out[8] = (w1 >> 8) & mask;
+ out[9] = (w1 >> 17) & mask;
+ out[10] = (w1 >> 26) & mask;
+ out[11] = (w1 >> 35) & mask;
+ out[12] = (w1 >> 44) & mask;
+ out[13] = (w1 >> 53) & mask;
+ out[14] = ((w1 >> 62) | (w2 << 2)) & mask;
+ out[15] = (w2 >> 7) & mask;
+ out[16] = (w2 >> 16) & mask;
+ out[17] = (w2 >> 25) & mask;
+ out[18] = (w2 >> 34) & mask;
+ out[19] = (w2 >> 43) & mask;
+ out[20] = (w2 >> 52) & mask;
+ out[21] = ((w2 >> 61) | (w3 << 3)) & mask;
+ out[22] = (w3 >> 6) & mask;
+ out[23] = (w3 >> 15) & mask;
+ out[24] = (w3 >> 24) & mask;
+ out[25] = (w3 >> 33) & mask;
+ out[26] = (w3 >> 42) & mask;
+ out[27] = (w3 >> 51) & mask;
+ out[28] = ((w3 >> 60) | (w4 << 4)) & mask;
+ out[29] = (w4 >> 5) & mask;
+ out[30] = (w4 >> 14) & mask;
+ out[31] = (w4 >> 23) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack10_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 1023ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 10) & mask;
+ out[2] = (w0 >> 20) & mask;
+ out[3] = (w0 >> 30) & mask;
+ out[4] = (w0 >> 40) & mask;
+ out[5] = (w0 >> 50) & mask;
+ out[6] = ((w0 >> 60) | (w1 << 4)) & mask;
+ out[7] = (w1 >> 6) & mask;
+ out[8] = (w1 >> 16) & mask;
+ out[9] = (w1 >> 26) & mask;
+ out[10] = (w1 >> 36) & mask;
+ out[11] = (w1 >> 46) & mask;
+ out[12] = ((w1 >> 56) | (w2 << 8)) & mask;
+ out[13] = (w2 >> 2) & mask;
+ out[14] = (w2 >> 12) & mask;
+ out[15] = (w2 >> 22) & mask;
+ out[16] = (w2 >> 32) & mask;
+ out[17] = (w2 >> 42) & mask;
+ out[18] = (w2 >> 52) & mask;
+ out[19] = ((w2 >> 62) | (w3 << 2)) & mask;
+ out[20] = (w3 >> 8) & mask;
+ out[21] = (w3 >> 18) & mask;
+ out[22] = (w3 >> 28) & mask;
+ out[23] = (w3 >> 38) & mask;
+ out[24] = (w3 >> 48) & mask;
+ out[25] = ((w3 >> 58) | (w4 << 6)) & mask;
+ out[26] = (w4 >> 4) & mask;
+ out[27] = (w4 >> 14) & mask;
+ out[28] = (w4 >> 24) & mask;
+ out[29] = (w4 >> 34) & mask;
+ out[30] = (w4 >> 44) & mask;
+ out[31] = w4 >> 54;
+
+ return in;
+}
+
+inline const uint8_t* unpack11_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 2047ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint32_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 11) & mask;
+ out[2] = (w0 >> 22) & mask;
+ out[3] = (w0 >> 33) & mask;
+ out[4] = (w0 >> 44) & mask;
+ out[5] = ((w0 >> 55) | (w1 << 9)) & mask;
+ out[6] = (w1 >> 2) & mask;
+ out[7] = (w1 >> 13) & mask;
+ out[8] = (w1 >> 24) & mask;
+ out[9] = (w1 >> 35) & mask;
+ out[10] = (w1 >> 46) & mask;
+ out[11] = ((w1 >> 57) | (w2 << 7)) & mask;
+ out[12] = (w2 >> 4) & mask;
+ out[13] = (w2 >> 15) & mask;
+ out[14] = (w2 >> 26) & mask;
+ out[15] = (w2 >> 37) & mask;
+ out[16] = (w2 >> 48) & mask;
+ out[17] = ((w2 >> 59) | (w3 << 5)) & mask;
+ out[18] = (w3 >> 6) & mask;
+ out[19] = (w3 >> 17) & mask;
+ out[20] = (w3 >> 28) & mask;
+ out[21] = (w3 >> 39) & mask;
+ out[22] = (w3 >> 50) & mask;
+ out[23] = ((w3 >> 61) | (w4 << 3)) & mask;
+ out[24] = (w4 >> 8) & mask;
+ out[25] = (w4 >> 19) & mask;
+ out[26] = (w4 >> 30) & mask;
+ out[27] = (w4 >> 41) & mask;
+ out[28] = (w4 >> 52) & mask;
+ out[29] = ((w4 >> 63) | (w5 << 1)) & mask;
+ out[30] = (w5 >> 10) & mask;
+ out[31] = (w5 >> 21) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack12_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 4095ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 12) & mask;
+ out[2] = (w0 >> 24) & mask;
+ out[3] = (w0 >> 36) & mask;
+ out[4] = (w0 >> 48) & mask;
+ out[5] = ((w0 >> 60) | (w1 << 4)) & mask;
+ out[6] = (w1 >> 8) & mask;
+ out[7] = (w1 >> 20) & mask;
+ out[8] = (w1 >> 32) & mask;
+ out[9] = (w1 >> 44) & mask;
+ out[10] = ((w1 >> 56) | (w2 << 8)) & mask;
+ out[11] = (w2 >> 4) & mask;
+ out[12] = (w2 >> 16) & mask;
+ out[13] = (w2 >> 28) & mask;
+ out[14] = (w2 >> 40) & mask;
+ out[15] = w2 >> 52;
+ out[16] = (w3)&mask;
+ out[17] = (w3 >> 12) & mask;
+ out[18] = (w3 >> 24) & mask;
+ out[19] = (w3 >> 36) & mask;
+ out[20] = (w3 >> 48) & mask;
+ out[21] = ((w3 >> 60) | (w4 << 4)) & mask;
+ out[22] = (w4 >> 8) & mask;
+ out[23] = (w4 >> 20) & mask;
+ out[24] = (w4 >> 32) & mask;
+ out[25] = (w4 >> 44) & mask;
+ out[26] = ((w4 >> 56) | (w5 << 8)) & mask;
+ out[27] = (w5 >> 4) & mask;
+ out[28] = (w5 >> 16) & mask;
+ out[29] = (w5 >> 28) & mask;
+ out[30] = (w5 >> 40) & mask;
+ out[31] = w5 >> 52;
+
+ return in;
+}
+
+inline const uint8_t* unpack13_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 8191ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint32_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 13) & mask;
+ out[2] = (w0 >> 26) & mask;
+ out[3] = (w0 >> 39) & mask;
+ out[4] = ((w0 >> 52) | (w1 << 12)) & mask;
+ out[5] = (w1 >> 1) & mask;
+ out[6] = (w1 >> 14) & mask;
+ out[7] = (w1 >> 27) & mask;
+ out[8] = (w1 >> 40) & mask;
+ out[9] = ((w1 >> 53) | (w2 << 11)) & mask;
+ out[10] = (w2 >> 2) & mask;
+ out[11] = (w2 >> 15) & mask;
+ out[12] = (w2 >> 28) & mask;
+ out[13] = (w2 >> 41) & mask;
+ out[14] = ((w2 >> 54) | (w3 << 10)) & mask;
+ out[15] = (w3 >> 3) & mask;
+ out[16] = (w3 >> 16) & mask;
+ out[17] = (w3 >> 29) & mask;
+ out[18] = (w3 >> 42) & mask;
+ out[19] = ((w3 >> 55) | (w4 << 9)) & mask;
+ out[20] = (w4 >> 4) & mask;
+ out[21] = (w4 >> 17) & mask;
+ out[22] = (w4 >> 30) & mask;
+ out[23] = (w4 >> 43) & mask;
+ out[24] = ((w4 >> 56) | (w5 << 8)) & mask;
+ out[25] = (w5 >> 5) & mask;
+ out[26] = (w5 >> 18) & mask;
+ out[27] = (w5 >> 31) & mask;
+ out[28] = (w5 >> 44) & mask;
+ out[29] = ((w5 >> 57) | (w6 << 7)) & mask;
+ out[30] = (w6 >> 6) & mask;
+ out[31] = (w6 >> 19) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack14_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 16383ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 14) & mask;
+ out[2] = (w0 >> 28) & mask;
+ out[3] = (w0 >> 42) & mask;
+ out[4] = ((w0 >> 56) | (w1 << 8)) & mask;
+ out[5] = (w1 >> 6) & mask;
+ out[6] = (w1 >> 20) & mask;
+ out[7] = (w1 >> 34) & mask;
+ out[8] = (w1 >> 48) & mask;
+ out[9] = ((w1 >> 62) | (w2 << 2)) & mask;
+ out[10] = (w2 >> 12) & mask;
+ out[11] = (w2 >> 26) & mask;
+ out[12] = (w2 >> 40) & mask;
+ out[13] = ((w2 >> 54) | (w3 << 10)) & mask;
+ out[14] = (w3 >> 4) & mask;
+ out[15] = (w3 >> 18) & mask;
+ out[16] = (w3 >> 32) & mask;
+ out[17] = (w3 >> 46) & mask;
+ out[18] = ((w3 >> 60) | (w4 << 4)) & mask;
+ out[19] = (w4 >> 10) & mask;
+ out[20] = (w4 >> 24) & mask;
+ out[21] = (w4 >> 38) & mask;
+ out[22] = ((w4 >> 52) | (w5 << 12)) & mask;
+ out[23] = (w5 >> 2) & mask;
+ out[24] = (w5 >> 16) & mask;
+ out[25] = (w5 >> 30) & mask;
+ out[26] = (w5 >> 44) & mask;
+ out[27] = ((w5 >> 58) | (w6 << 6)) & mask;
+ out[28] = (w6 >> 8) & mask;
+ out[29] = (w6 >> 22) & mask;
+ out[30] = (w6 >> 36) & mask;
+ out[31] = w6 >> 50;
+
+ return in;
+}
+
+inline const uint8_t* unpack15_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 32767ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint32_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 15) & mask;
+ out[2] = (w0 >> 30) & mask;
+ out[3] = (w0 >> 45) & mask;
+ out[4] = ((w0 >> 60) | (w1 << 4)) & mask;
+ out[5] = (w1 >> 11) & mask;
+ out[6] = (w1 >> 26) & mask;
+ out[7] = (w1 >> 41) & mask;
+ out[8] = ((w1 >> 56) | (w2 << 8)) & mask;
+ out[9] = (w2 >> 7) & mask;
+ out[10] = (w2 >> 22) & mask;
+ out[11] = (w2 >> 37) & mask;
+ out[12] = ((w2 >> 52) | (w3 << 12)) & mask;
+ out[13] = (w3 >> 3) & mask;
+ out[14] = (w3 >> 18) & mask;
+ out[15] = (w3 >> 33) & mask;
+ out[16] = (w3 >> 48) & mask;
+ out[17] = ((w3 >> 63) | (w4 << 1)) & mask;
+ out[18] = (w4 >> 14) & mask;
+ out[19] = (w4 >> 29) & mask;
+ out[20] = (w4 >> 44) & mask;
+ out[21] = ((w4 >> 59) | (w5 << 5)) & mask;
+ out[22] = (w5 >> 10) & mask;
+ out[23] = (w5 >> 25) & mask;
+ out[24] = (w5 >> 40) & mask;
+ out[25] = ((w5 >> 55) | (w6 << 9)) & mask;
+ out[26] = (w6 >> 6) & mask;
+ out[27] = (w6 >> 21) & mask;
+ out[28] = (w6 >> 36) & mask;
+ out[29] = ((w6 >> 51) | (w7 << 13)) & mask;
+ out[30] = (w7 >> 2) & mask;
+ out[31] = (w7 >> 17) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack16_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 65535ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 16) & mask;
+ out[2] = (w0 >> 32) & mask;
+ out[3] = w0 >> 48;
+ out[4] = (w1)&mask;
+ out[5] = (w1 >> 16) & mask;
+ out[6] = (w1 >> 32) & mask;
+ out[7] = w1 >> 48;
+ out[8] = (w2)&mask;
+ out[9] = (w2 >> 16) & mask;
+ out[10] = (w2 >> 32) & mask;
+ out[11] = w2 >> 48;
+ out[12] = (w3)&mask;
+ out[13] = (w3 >> 16) & mask;
+ out[14] = (w3 >> 32) & mask;
+ out[15] = w3 >> 48;
+ out[16] = (w4)&mask;
+ out[17] = (w4 >> 16) & mask;
+ out[18] = (w4 >> 32) & mask;
+ out[19] = w4 >> 48;
+ out[20] = (w5)&mask;
+ out[21] = (w5 >> 16) & mask;
+ out[22] = (w5 >> 32) & mask;
+ out[23] = w5 >> 48;
+ out[24] = (w6)&mask;
+ out[25] = (w6 >> 16) & mask;
+ out[26] = (w6 >> 32) & mask;
+ out[27] = w6 >> 48;
+ out[28] = (w7)&mask;
+ out[29] = (w7 >> 16) & mask;
+ out[30] = (w7 >> 32) & mask;
+ out[31] = w7 >> 48;
+
+ return in;
+}
+
+inline const uint8_t* unpack17_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 131071ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint32_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 17) & mask;
+ out[2] = (w0 >> 34) & mask;
+ out[3] = ((w0 >> 51) | (w1 << 13)) & mask;
+ out[4] = (w1 >> 4) & mask;
+ out[5] = (w1 >> 21) & mask;
+ out[6] = (w1 >> 38) & mask;
+ out[7] = ((w1 >> 55) | (w2 << 9)) & mask;
+ out[8] = (w2 >> 8) & mask;
+ out[9] = (w2 >> 25) & mask;
+ out[10] = (w2 >> 42) & mask;
+ out[11] = ((w2 >> 59) | (w3 << 5)) & mask;
+ out[12] = (w3 >> 12) & mask;
+ out[13] = (w3 >> 29) & mask;
+ out[14] = (w3 >> 46) & mask;
+ out[15] = ((w3 >> 63) | (w4 << 1)) & mask;
+ out[16] = (w4 >> 16) & mask;
+ out[17] = (w4 >> 33) & mask;
+ out[18] = ((w4 >> 50) | (w5 << 14)) & mask;
+ out[19] = (w5 >> 3) & mask;
+ out[20] = (w5 >> 20) & mask;
+ out[21] = (w5 >> 37) & mask;
+ out[22] = ((w5 >> 54) | (w6 << 10)) & mask;
+ out[23] = (w6 >> 7) & mask;
+ out[24] = (w6 >> 24) & mask;
+ out[25] = (w6 >> 41) & mask;
+ out[26] = ((w6 >> 58) | (w7 << 6)) & mask;
+ out[27] = (w7 >> 11) & mask;
+ out[28] = (w7 >> 28) & mask;
+ out[29] = (w7 >> 45) & mask;
+ out[30] = ((w7 >> 62) | (w8 << 2)) & mask;
+ out[31] = (w8 >> 15) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack18_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 262143ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 18) & mask;
+ out[2] = (w0 >> 36) & mask;
+ out[3] = ((w0 >> 54) | (w1 << 10)) & mask;
+ out[4] = (w1 >> 8) & mask;
+ out[5] = (w1 >> 26) & mask;
+ out[6] = (w1 >> 44) & mask;
+ out[7] = ((w1 >> 62) | (w2 << 2)) & mask;
+ out[8] = (w2 >> 16) & mask;
+ out[9] = (w2 >> 34) & mask;
+ out[10] = ((w2 >> 52) | (w3 << 12)) & mask;
+ out[11] = (w3 >> 6) & mask;
+ out[12] = (w3 >> 24) & mask;
+ out[13] = (w3 >> 42) & mask;
+ out[14] = ((w3 >> 60) | (w4 << 4)) & mask;
+ out[15] = (w4 >> 14) & mask;
+ out[16] = (w4 >> 32) & mask;
+ out[17] = ((w4 >> 50) | (w5 << 14)) & mask;
+ out[18] = (w5 >> 4) & mask;
+ out[19] = (w5 >> 22) & mask;
+ out[20] = (w5 >> 40) & mask;
+ out[21] = ((w5 >> 58) | (w6 << 6)) & mask;
+ out[22] = (w6 >> 12) & mask;
+ out[23] = (w6 >> 30) & mask;
+ out[24] = ((w6 >> 48) | (w7 << 16)) & mask;
+ out[25] = (w7 >> 2) & mask;
+ out[26] = (w7 >> 20) & mask;
+ out[27] = (w7 >> 38) & mask;
+ out[28] = ((w7 >> 56) | (w8 << 8)) & mask;
+ out[29] = (w8 >> 10) & mask;
+ out[30] = (w8 >> 28) & mask;
+ out[31] = w8 >> 46;
+
+ return in;
+}
+
+inline const uint8_t* unpack19_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 524287ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint32_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 19) & mask;
+ out[2] = (w0 >> 38) & mask;
+ out[3] = ((w0 >> 57) | (w1 << 7)) & mask;
+ out[4] = (w1 >> 12) & mask;
+ out[5] = (w1 >> 31) & mask;
+ out[6] = ((w1 >> 50) | (w2 << 14)) & mask;
+ out[7] = (w2 >> 5) & mask;
+ out[8] = (w2 >> 24) & mask;
+ out[9] = (w2 >> 43) & mask;
+ out[10] = ((w2 >> 62) | (w3 << 2)) & mask;
+ out[11] = (w3 >> 17) & mask;
+ out[12] = (w3 >> 36) & mask;
+ out[13] = ((w3 >> 55) | (w4 << 9)) & mask;
+ out[14] = (w4 >> 10) & mask;
+ out[15] = (w4 >> 29) & mask;
+ out[16] = ((w4 >> 48) | (w5 << 16)) & mask;
+ out[17] = (w5 >> 3) & mask;
+ out[18] = (w5 >> 22) & mask;
+ out[19] = (w5 >> 41) & mask;
+ out[20] = ((w5 >> 60) | (w6 << 4)) & mask;
+ out[21] = (w6 >> 15) & mask;
+ out[22] = (w6 >> 34) & mask;
+ out[23] = ((w6 >> 53) | (w7 << 11)) & mask;
+ out[24] = (w7 >> 8) & mask;
+ out[25] = (w7 >> 27) & mask;
+ out[26] = ((w7 >> 46) | (w8 << 18)) & mask;
+ out[27] = (w8 >> 1) & mask;
+ out[28] = (w8 >> 20) & mask;
+ out[29] = (w8 >> 39) & mask;
+ out[30] = ((w8 >> 58) | (w9 << 6)) & mask;
+ out[31] = (w9 >> 13) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack20_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 1048575ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 20) & mask;
+ out[2] = (w0 >> 40) & mask;
+ out[3] = ((w0 >> 60) | (w1 << 4)) & mask;
+ out[4] = (w1 >> 16) & mask;
+ out[5] = (w1 >> 36) & mask;
+ out[6] = ((w1 >> 56) | (w2 << 8)) & mask;
+ out[7] = (w2 >> 12) & mask;
+ out[8] = (w2 >> 32) & mask;
+ out[9] = ((w2 >> 52) | (w3 << 12)) & mask;
+ out[10] = (w3 >> 8) & mask;
+ out[11] = (w3 >> 28) & mask;
+ out[12] = ((w3 >> 48) | (w4 << 16)) & mask;
+ out[13] = (w4 >> 4) & mask;
+ out[14] = (w4 >> 24) & mask;
+ out[15] = w4 >> 44;
+ out[16] = (w5)&mask;
+ out[17] = (w5 >> 20) & mask;
+ out[18] = (w5 >> 40) & mask;
+ out[19] = ((w5 >> 60) | (w6 << 4)) & mask;
+ out[20] = (w6 >> 16) & mask;
+ out[21] = (w6 >> 36) & mask;
+ out[22] = ((w6 >> 56) | (w7 << 8)) & mask;
+ out[23] = (w7 >> 12) & mask;
+ out[24] = (w7 >> 32) & mask;
+ out[25] = ((w7 >> 52) | (w8 << 12)) & mask;
+ out[26] = (w8 >> 8) & mask;
+ out[27] = (w8 >> 28) & mask;
+ out[28] = ((w8 >> 48) | (w9 << 16)) & mask;
+ out[29] = (w9 >> 4) & mask;
+ out[30] = (w9 >> 24) & mask;
+ out[31] = w9 >> 44;
+
+ return in;
+}
+
+inline const uint8_t* unpack21_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 2097151ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint32_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 21) & mask;
+ out[2] = (w0 >> 42) & mask;
+ out[3] = ((w0 >> 63) | (w1 << 1)) & mask;
+ out[4] = (w1 >> 20) & mask;
+ out[5] = (w1 >> 41) & mask;
+ out[6] = ((w1 >> 62) | (w2 << 2)) & mask;
+ out[7] = (w2 >> 19) & mask;
+ out[8] = (w2 >> 40) & mask;
+ out[9] = ((w2 >> 61) | (w3 << 3)) & mask;
+ out[10] = (w3 >> 18) & mask;
+ out[11] = (w3 >> 39) & mask;
+ out[12] = ((w3 >> 60) | (w4 << 4)) & mask;
+ out[13] = (w4 >> 17) & mask;
+ out[14] = (w4 >> 38) & mask;
+ out[15] = ((w4 >> 59) | (w5 << 5)) & mask;
+ out[16] = (w5 >> 16) & mask;
+ out[17] = (w5 >> 37) & mask;
+ out[18] = ((w5 >> 58) | (w6 << 6)) & mask;
+ out[19] = (w6 >> 15) & mask;
+ out[20] = (w6 >> 36) & mask;
+ out[21] = ((w6 >> 57) | (w7 << 7)) & mask;
+ out[22] = (w7 >> 14) & mask;
+ out[23] = (w7 >> 35) & mask;
+ out[24] = ((w7 >> 56) | (w8 << 8)) & mask;
+ out[25] = (w8 >> 13) & mask;
+ out[26] = (w8 >> 34) & mask;
+ out[27] = ((w8 >> 55) | (w9 << 9)) & mask;
+ out[28] = (w9 >> 12) & mask;
+ out[29] = (w9 >> 33) & mask;
+ out[30] = ((w9 >> 54) | (w10 << 10)) & mask;
+ out[31] = (w10 >> 11) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack22_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 4194303ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 22) & mask;
+ out[2] = ((w0 >> 44) | (w1 << 20)) & mask;
+ out[3] = (w1 >> 2) & mask;
+ out[4] = (w1 >> 24) & mask;
+ out[5] = ((w1 >> 46) | (w2 << 18)) & mask;
+ out[6] = (w2 >> 4) & mask;
+ out[7] = (w2 >> 26) & mask;
+ out[8] = ((w2 >> 48) | (w3 << 16)) & mask;
+ out[9] = (w3 >> 6) & mask;
+ out[10] = (w3 >> 28) & mask;
+ out[11] = ((w3 >> 50) | (w4 << 14)) & mask;
+ out[12] = (w4 >> 8) & mask;
+ out[13] = (w4 >> 30) & mask;
+ out[14] = ((w4 >> 52) | (w5 << 12)) & mask;
+ out[15] = (w5 >> 10) & mask;
+ out[16] = (w5 >> 32) & mask;
+ out[17] = ((w5 >> 54) | (w6 << 10)) & mask;
+ out[18] = (w6 >> 12) & mask;
+ out[19] = (w6 >> 34) & mask;
+ out[20] = ((w6 >> 56) | (w7 << 8)) & mask;
+ out[21] = (w7 >> 14) & mask;
+ out[22] = (w7 >> 36) & mask;
+ out[23] = ((w7 >> 58) | (w8 << 6)) & mask;
+ out[24] = (w8 >> 16) & mask;
+ out[25] = (w8 >> 38) & mask;
+ out[26] = ((w8 >> 60) | (w9 << 4)) & mask;
+ out[27] = (w9 >> 18) & mask;
+ out[28] = (w9 >> 40) & mask;
+ out[29] = ((w9 >> 62) | (w10 << 2)) & mask;
+ out[30] = (w10 >> 20) & mask;
+ out[31] = w10 >> 42;
+
+ return in;
+}
+
+inline const uint8_t* unpack23_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 8388607ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint32_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 23) & mask;
+ out[2] = ((w0 >> 46) | (w1 << 18)) & mask;
+ out[3] = (w1 >> 5) & mask;
+ out[4] = (w1 >> 28) & mask;
+ out[5] = ((w1 >> 51) | (w2 << 13)) & mask;
+ out[6] = (w2 >> 10) & mask;
+ out[7] = (w2 >> 33) & mask;
+ out[8] = ((w2 >> 56) | (w3 << 8)) & mask;
+ out[9] = (w3 >> 15) & mask;
+ out[10] = (w3 >> 38) & mask;
+ out[11] = ((w3 >> 61) | (w4 << 3)) & mask;
+ out[12] = (w4 >> 20) & mask;
+ out[13] = ((w4 >> 43) | (w5 << 21)) & mask;
+ out[14] = (w5 >> 2) & mask;
+ out[15] = (w5 >> 25) & mask;
+ out[16] = ((w5 >> 48) | (w6 << 16)) & mask;
+ out[17] = (w6 >> 7) & mask;
+ out[18] = (w6 >> 30) & mask;
+ out[19] = ((w6 >> 53) | (w7 << 11)) & mask;
+ out[20] = (w7 >> 12) & mask;
+ out[21] = (w7 >> 35) & mask;
+ out[22] = ((w7 >> 58) | (w8 << 6)) & mask;
+ out[23] = (w8 >> 17) & mask;
+ out[24] = (w8 >> 40) & mask;
+ out[25] = ((w8 >> 63) | (w9 << 1)) & mask;
+ out[26] = (w9 >> 22) & mask;
+ out[27] = ((w9 >> 45) | (w10 << 19)) & mask;
+ out[28] = (w10 >> 4) & mask;
+ out[29] = (w10 >> 27) & mask;
+ out[30] = ((w10 >> 50) | (w11 << 14)) & mask;
+ out[31] = (w11 >> 9) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack24_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 16777215ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 24) & mask;
+ out[2] = ((w0 >> 48) | (w1 << 16)) & mask;
+ out[3] = (w1 >> 8) & mask;
+ out[4] = (w1 >> 32) & mask;
+ out[5] = ((w1 >> 56) | (w2 << 8)) & mask;
+ out[6] = (w2 >> 16) & mask;
+ out[7] = w2 >> 40;
+ out[8] = (w3)&mask;
+ out[9] = (w3 >> 24) & mask;
+ out[10] = ((w3 >> 48) | (w4 << 16)) & mask;
+ out[11] = (w4 >> 8) & mask;
+ out[12] = (w4 >> 32) & mask;
+ out[13] = ((w4 >> 56) | (w5 << 8)) & mask;
+ out[14] = (w5 >> 16) & mask;
+ out[15] = w5 >> 40;
+ out[16] = (w6)&mask;
+ out[17] = (w6 >> 24) & mask;
+ out[18] = ((w6 >> 48) | (w7 << 16)) & mask;
+ out[19] = (w7 >> 8) & mask;
+ out[20] = (w7 >> 32) & mask;
+ out[21] = ((w7 >> 56) | (w8 << 8)) & mask;
+ out[22] = (w8 >> 16) & mask;
+ out[23] = w8 >> 40;
+ out[24] = (w9)&mask;
+ out[25] = (w9 >> 24) & mask;
+ out[26] = ((w9 >> 48) | (w10 << 16)) & mask;
+ out[27] = (w10 >> 8) & mask;
+ out[28] = (w10 >> 32) & mask;
+ out[29] = ((w10 >> 56) | (w11 << 8)) & mask;
+ out[30] = (w11 >> 16) & mask;
+ out[31] = w11 >> 40;
+
+ return in;
+}
+
+inline const uint8_t* unpack25_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 33554431ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint32_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 25) & mask;
+ out[2] = ((w0 >> 50) | (w1 << 14)) & mask;
+ out[3] = (w1 >> 11) & mask;
+ out[4] = (w1 >> 36) & mask;
+ out[5] = ((w1 >> 61) | (w2 << 3)) & mask;
+ out[6] = (w2 >> 22) & mask;
+ out[7] = ((w2 >> 47) | (w3 << 17)) & mask;
+ out[8] = (w3 >> 8) & mask;
+ out[9] = (w3 >> 33) & mask;
+ out[10] = ((w3 >> 58) | (w4 << 6)) & mask;
+ out[11] = (w4 >> 19) & mask;
+ out[12] = ((w4 >> 44) | (w5 << 20)) & mask;
+ out[13] = (w5 >> 5) & mask;
+ out[14] = (w5 >> 30) & mask;
+ out[15] = ((w5 >> 55) | (w6 << 9)) & mask;
+ out[16] = (w6 >> 16) & mask;
+ out[17] = ((w6 >> 41) | (w7 << 23)) & mask;
+ out[18] = (w7 >> 2) & mask;
+ out[19] = (w7 >> 27) & mask;
+ out[20] = ((w7 >> 52) | (w8 << 12)) & mask;
+ out[21] = (w8 >> 13) & mask;
+ out[22] = (w8 >> 38) & mask;
+ out[23] = ((w8 >> 63) | (w9 << 1)) & mask;
+ out[24] = (w9 >> 24) & mask;
+ out[25] = ((w9 >> 49) | (w10 << 15)) & mask;
+ out[26] = (w10 >> 10) & mask;
+ out[27] = (w10 >> 35) & mask;
+ out[28] = ((w10 >> 60) | (w11 << 4)) & mask;
+ out[29] = (w11 >> 21) & mask;
+ out[30] = ((w11 >> 46) | (w12 << 18)) & mask;
+ out[31] = (w12 >> 7) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack26_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 67108863ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 26) & mask;
+ out[2] = ((w0 >> 52) | (w1 << 12)) & mask;
+ out[3] = (w1 >> 14) & mask;
+ out[4] = ((w1 >> 40) | (w2 << 24)) & mask;
+ out[5] = (w2 >> 2) & mask;
+ out[6] = (w2 >> 28) & mask;
+ out[7] = ((w2 >> 54) | (w3 << 10)) & mask;
+ out[8] = (w3 >> 16) & mask;
+ out[9] = ((w3 >> 42) | (w4 << 22)) & mask;
+ out[10] = (w4 >> 4) & mask;
+ out[11] = (w4 >> 30) & mask;
+ out[12] = ((w4 >> 56) | (w5 << 8)) & mask;
+ out[13] = (w5 >> 18) & mask;
+ out[14] = ((w5 >> 44) | (w6 << 20)) & mask;
+ out[15] = (w6 >> 6) & mask;
+ out[16] = (w6 >> 32) & mask;
+ out[17] = ((w6 >> 58) | (w7 << 6)) & mask;
+ out[18] = (w7 >> 20) & mask;
+ out[19] = ((w7 >> 46) | (w8 << 18)) & mask;
+ out[20] = (w8 >> 8) & mask;
+ out[21] = (w8 >> 34) & mask;
+ out[22] = ((w8 >> 60) | (w9 << 4)) & mask;
+ out[23] = (w9 >> 22) & mask;
+ out[24] = ((w9 >> 48) | (w10 << 16)) & mask;
+ out[25] = (w10 >> 10) & mask;
+ out[26] = (w10 >> 36) & mask;
+ out[27] = ((w10 >> 62) | (w11 << 2)) & mask;
+ out[28] = (w11 >> 24) & mask;
+ out[29] = ((w11 >> 50) | (w12 << 14)) & mask;
+ out[30] = (w12 >> 12) & mask;
+ out[31] = w12 >> 38;
+
+ return in;
+}
+
+inline const uint8_t* unpack27_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 134217727ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint32_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 27) & mask;
+ out[2] = ((w0 >> 54) | (w1 << 10)) & mask;
+ out[3] = (w1 >> 17) & mask;
+ out[4] = ((w1 >> 44) | (w2 << 20)) & mask;
+ out[5] = (w2 >> 7) & mask;
+ out[6] = (w2 >> 34) & mask;
+ out[7] = ((w2 >> 61) | (w3 << 3)) & mask;
+ out[8] = (w3 >> 24) & mask;
+ out[9] = ((w3 >> 51) | (w4 << 13)) & mask;
+ out[10] = (w4 >> 14) & mask;
+ out[11] = ((w4 >> 41) | (w5 << 23)) & mask;
+ out[12] = (w5 >> 4) & mask;
+ out[13] = (w5 >> 31) & mask;
+ out[14] = ((w5 >> 58) | (w6 << 6)) & mask;
+ out[15] = (w6 >> 21) & mask;
+ out[16] = ((w6 >> 48) | (w7 << 16)) & mask;
+ out[17] = (w7 >> 11) & mask;
+ out[18] = ((w7 >> 38) | (w8 << 26)) & mask;
+ out[19] = (w8 >> 1) & mask;
+ out[20] = (w8 >> 28) & mask;
+ out[21] = ((w8 >> 55) | (w9 << 9)) & mask;
+ out[22] = (w9 >> 18) & mask;
+ out[23] = ((w9 >> 45) | (w10 << 19)) & mask;
+ out[24] = (w10 >> 8) & mask;
+ out[25] = (w10 >> 35) & mask;
+ out[26] = ((w10 >> 62) | (w11 << 2)) & mask;
+ out[27] = (w11 >> 25) & mask;
+ out[28] = ((w11 >> 52) | (w12 << 12)) & mask;
+ out[29] = (w12 >> 15) & mask;
+ out[30] = ((w12 >> 42) | (w13 << 22)) & mask;
+ out[31] = (w13 >> 5) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack28_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 268435455ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 28) & mask;
+ out[2] = ((w0 >> 56) | (w1 << 8)) & mask;
+ out[3] = (w1 >> 20) & mask;
+ out[4] = ((w1 >> 48) | (w2 << 16)) & mask;
+ out[5] = (w2 >> 12) & mask;
+ out[6] = ((w2 >> 40) | (w3 << 24)) & mask;
+ out[7] = (w3 >> 4) & mask;
+ out[8] = (w3 >> 32) & mask;
+ out[9] = ((w3 >> 60) | (w4 << 4)) & mask;
+ out[10] = (w4 >> 24) & mask;
+ out[11] = ((w4 >> 52) | (w5 << 12)) & mask;
+ out[12] = (w5 >> 16) & mask;
+ out[13] = ((w5 >> 44) | (w6 << 20)) & mask;
+ out[14] = (w6 >> 8) & mask;
+ out[15] = w6 >> 36;
+ out[16] = (w7)&mask;
+ out[17] = (w7 >> 28) & mask;
+ out[18] = ((w7 >> 56) | (w8 << 8)) & mask;
+ out[19] = (w8 >> 20) & mask;
+ out[20] = ((w8 >> 48) | (w9 << 16)) & mask;
+ out[21] = (w9 >> 12) & mask;
+ out[22] = ((w9 >> 40) | (w10 << 24)) & mask;
+ out[23] = (w10 >> 4) & mask;
+ out[24] = (w10 >> 32) & mask;
+ out[25] = ((w10 >> 60) | (w11 << 4)) & mask;
+ out[26] = (w11 >> 24) & mask;
+ out[27] = ((w11 >> 52) | (w12 << 12)) & mask;
+ out[28] = (w12 >> 16) & mask;
+ out[29] = ((w12 >> 44) | (w13 << 20)) & mask;
+ out[30] = (w13 >> 8) & mask;
+ out[31] = w13 >> 36;
+
+ return in;
+}
+
+inline const uint8_t* unpack29_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 536870911ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint32_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 29) & mask;
+ out[2] = ((w0 >> 58) | (w1 << 6)) & mask;
+ out[3] = (w1 >> 23) & mask;
+ out[4] = ((w1 >> 52) | (w2 << 12)) & mask;
+ out[5] = (w2 >> 17) & mask;
+ out[6] = ((w2 >> 46) | (w3 << 18)) & mask;
+ out[7] = (w3 >> 11) & mask;
+ out[8] = ((w3 >> 40) | (w4 << 24)) & mask;
+ out[9] = (w4 >> 5) & mask;
+ out[10] = (w4 >> 34) & mask;
+ out[11] = ((w4 >> 63) | (w5 << 1)) & mask;
+ out[12] = (w5 >> 28) & mask;
+ out[13] = ((w5 >> 57) | (w6 << 7)) & mask;
+ out[14] = (w6 >> 22) & mask;
+ out[15] = ((w6 >> 51) | (w7 << 13)) & mask;
+ out[16] = (w7 >> 16) & mask;
+ out[17] = ((w7 >> 45) | (w8 << 19)) & mask;
+ out[18] = (w8 >> 10) & mask;
+ out[19] = ((w8 >> 39) | (w9 << 25)) & mask;
+ out[20] = (w9 >> 4) & mask;
+ out[21] = (w9 >> 33) & mask;
+ out[22] = ((w9 >> 62) | (w10 << 2)) & mask;
+ out[23] = (w10 >> 27) & mask;
+ out[24] = ((w10 >> 56) | (w11 << 8)) & mask;
+ out[25] = (w11 >> 21) & mask;
+ out[26] = ((w11 >> 50) | (w12 << 14)) & mask;
+ out[27] = (w12 >> 15) & mask;
+ out[28] = ((w12 >> 44) | (w13 << 20)) & mask;
+ out[29] = (w13 >> 9) & mask;
+ out[30] = ((w13 >> 38) | (w14 << 26)) & mask;
+ out[31] = (w14 >> 3) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack30_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 1073741823ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 30) & mask;
+ out[2] = ((w0 >> 60) | (w1 << 4)) & mask;
+ out[3] = (w1 >> 26) & mask;
+ out[4] = ((w1 >> 56) | (w2 << 8)) & mask;
+ out[5] = (w2 >> 22) & mask;
+ out[6] = ((w2 >> 52) | (w3 << 12)) & mask;
+ out[7] = (w3 >> 18) & mask;
+ out[8] = ((w3 >> 48) | (w4 << 16)) & mask;
+ out[9] = (w4 >> 14) & mask;
+ out[10] = ((w4 >> 44) | (w5 << 20)) & mask;
+ out[11] = (w5 >> 10) & mask;
+ out[12] = ((w5 >> 40) | (w6 << 24)) & mask;
+ out[13] = (w6 >> 6) & mask;
+ out[14] = ((w6 >> 36) | (w7 << 28)) & mask;
+ out[15] = (w7 >> 2) & mask;
+ out[16] = (w7 >> 32) & mask;
+ out[17] = ((w7 >> 62) | (w8 << 2)) & mask;
+ out[18] = (w8 >> 28) & mask;
+ out[19] = ((w8 >> 58) | (w9 << 6)) & mask;
+ out[20] = (w9 >> 24) & mask;
+ out[21] = ((w9 >> 54) | (w10 << 10)) & mask;
+ out[22] = (w10 >> 20) & mask;
+ out[23] = ((w10 >> 50) | (w11 << 14)) & mask;
+ out[24] = (w11 >> 16) & mask;
+ out[25] = ((w11 >> 46) | (w12 << 18)) & mask;
+ out[26] = (w12 >> 12) & mask;
+ out[27] = ((w12 >> 42) | (w13 << 22)) & mask;
+ out[28] = (w13 >> 8) & mask;
+ out[29] = ((w13 >> 38) | (w14 << 26)) & mask;
+ out[30] = (w14 >> 4) & mask;
+ out[31] = w14 >> 34;
+
+ return in;
+}
+
+inline const uint8_t* unpack31_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 2147483647ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint32_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = (w0 >> 31) & mask;
+ out[2] = ((w0 >> 62) | (w1 << 2)) & mask;
+ out[3] = (w1 >> 29) & mask;
+ out[4] = ((w1 >> 60) | (w2 << 4)) & mask;
+ out[5] = (w2 >> 27) & mask;
+ out[6] = ((w2 >> 58) | (w3 << 6)) & mask;
+ out[7] = (w3 >> 25) & mask;
+ out[8] = ((w3 >> 56) | (w4 << 8)) & mask;
+ out[9] = (w4 >> 23) & mask;
+ out[10] = ((w4 >> 54) | (w5 << 10)) & mask;
+ out[11] = (w5 >> 21) & mask;
+ out[12] = ((w5 >> 52) | (w6 << 12)) & mask;
+ out[13] = (w6 >> 19) & mask;
+ out[14] = ((w6 >> 50) | (w7 << 14)) & mask;
+ out[15] = (w7 >> 17) & mask;
+ out[16] = ((w7 >> 48) | (w8 << 16)) & mask;
+ out[17] = (w8 >> 15) & mask;
+ out[18] = ((w8 >> 46) | (w9 << 18)) & mask;
+ out[19] = (w9 >> 13) & mask;
+ out[20] = ((w9 >> 44) | (w10 << 20)) & mask;
+ out[21] = (w10 >> 11) & mask;
+ out[22] = ((w10 >> 42) | (w11 << 22)) & mask;
+ out[23] = (w11 >> 9) & mask;
+ out[24] = ((w11 >> 40) | (w12 << 24)) & mask;
+ out[25] = (w12 >> 7) & mask;
+ out[26] = ((w12 >> 38) | (w13 << 26)) & mask;
+ out[27] = (w13 >> 5) & mask;
+ out[28] = ((w13 >> 36) | (w14 << 28)) & mask;
+ out[29] = (w14 >> 3) & mask;
+ out[30] = ((w14 >> 34) | (w15 << 30)) & mask;
+ out[31] = (w15 >> 1) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack32_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 4294967295ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = w0 >> 32;
+ out[2] = (w1)&mask;
+ out[3] = w1 >> 32;
+ out[4] = (w2)&mask;
+ out[5] = w2 >> 32;
+ out[6] = (w3)&mask;
+ out[7] = w3 >> 32;
+ out[8] = (w4)&mask;
+ out[9] = w4 >> 32;
+ out[10] = (w5)&mask;
+ out[11] = w5 >> 32;
+ out[12] = (w6)&mask;
+ out[13] = w6 >> 32;
+ out[14] = (w7)&mask;
+ out[15] = w7 >> 32;
+ out[16] = (w8)&mask;
+ out[17] = w8 >> 32;
+ out[18] = (w9)&mask;
+ out[19] = w9 >> 32;
+ out[20] = (w10)&mask;
+ out[21] = w10 >> 32;
+ out[22] = (w11)&mask;
+ out[23] = w11 >> 32;
+ out[24] = (w12)&mask;
+ out[25] = w12 >> 32;
+ out[26] = (w13)&mask;
+ out[27] = w13 >> 32;
+ out[28] = (w14)&mask;
+ out[29] = w14 >> 32;
+ out[30] = (w15)&mask;
+ out[31] = w15 >> 32;
+
+ return in;
+}
+
+inline const uint8_t* unpack33_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 8589934591ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint32_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 33) | (w1 << 31)) & mask;
+ out[2] = (w1 >> 2) & mask;
+ out[3] = ((w1 >> 35) | (w2 << 29)) & mask;
+ out[4] = (w2 >> 4) & mask;
+ out[5] = ((w2 >> 37) | (w3 << 27)) & mask;
+ out[6] = (w3 >> 6) & mask;
+ out[7] = ((w3 >> 39) | (w4 << 25)) & mask;
+ out[8] = (w4 >> 8) & mask;
+ out[9] = ((w4 >> 41) | (w5 << 23)) & mask;
+ out[10] = (w5 >> 10) & mask;
+ out[11] = ((w5 >> 43) | (w6 << 21)) & mask;
+ out[12] = (w6 >> 12) & mask;
+ out[13] = ((w6 >> 45) | (w7 << 19)) & mask;
+ out[14] = (w7 >> 14) & mask;
+ out[15] = ((w7 >> 47) | (w8 << 17)) & mask;
+ out[16] = (w8 >> 16) & mask;
+ out[17] = ((w8 >> 49) | (w9 << 15)) & mask;
+ out[18] = (w9 >> 18) & mask;
+ out[19] = ((w9 >> 51) | (w10 << 13)) & mask;
+ out[20] = (w10 >> 20) & mask;
+ out[21] = ((w10 >> 53) | (w11 << 11)) & mask;
+ out[22] = (w11 >> 22) & mask;
+ out[23] = ((w11 >> 55) | (w12 << 9)) & mask;
+ out[24] = (w12 >> 24) & mask;
+ out[25] = ((w12 >> 57) | (w13 << 7)) & mask;
+ out[26] = (w13 >> 26) & mask;
+ out[27] = ((w13 >> 59) | (w14 << 5)) & mask;
+ out[28] = (w14 >> 28) & mask;
+ out[29] = ((w14 >> 61) | (w15 << 3)) & mask;
+ out[30] = (w15 >> 30) & mask;
+ out[31] = ((w15 >> 63) | (w16 << 1)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack34_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 17179869183ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 34) | (w1 << 30)) & mask;
+ out[2] = (w1 >> 4) & mask;
+ out[3] = ((w1 >> 38) | (w2 << 26)) & mask;
+ out[4] = (w2 >> 8) & mask;
+ out[5] = ((w2 >> 42) | (w3 << 22)) & mask;
+ out[6] = (w3 >> 12) & mask;
+ out[7] = ((w3 >> 46) | (w4 << 18)) & mask;
+ out[8] = (w4 >> 16) & mask;
+ out[9] = ((w4 >> 50) | (w5 << 14)) & mask;
+ out[10] = (w5 >> 20) & mask;
+ out[11] = ((w5 >> 54) | (w6 << 10)) & mask;
+ out[12] = (w6 >> 24) & mask;
+ out[13] = ((w6 >> 58) | (w7 << 6)) & mask;
+ out[14] = (w7 >> 28) & mask;
+ out[15] = ((w7 >> 62) | (w8 << 2)) & mask;
+ out[16] = ((w8 >> 32) | (w9 << 32)) & mask;
+ out[17] = (w9 >> 2) & mask;
+ out[18] = ((w9 >> 36) | (w10 << 28)) & mask;
+ out[19] = (w10 >> 6) & mask;
+ out[20] = ((w10 >> 40) | (w11 << 24)) & mask;
+ out[21] = (w11 >> 10) & mask;
+ out[22] = ((w11 >> 44) | (w12 << 20)) & mask;
+ out[23] = (w12 >> 14) & mask;
+ out[24] = ((w12 >> 48) | (w13 << 16)) & mask;
+ out[25] = (w13 >> 18) & mask;
+ out[26] = ((w13 >> 52) | (w14 << 12)) & mask;
+ out[27] = (w14 >> 22) & mask;
+ out[28] = ((w14 >> 56) | (w15 << 8)) & mask;
+ out[29] = (w15 >> 26) & mask;
+ out[30] = ((w15 >> 60) | (w16 << 4)) & mask;
+ out[31] = w16 >> 30;
+
+ return in;
+}
+
+inline const uint8_t* unpack35_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 34359738367ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint32_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 35) | (w1 << 29)) & mask;
+ out[2] = (w1 >> 6) & mask;
+ out[3] = ((w1 >> 41) | (w2 << 23)) & mask;
+ out[4] = (w2 >> 12) & mask;
+ out[5] = ((w2 >> 47) | (w3 << 17)) & mask;
+ out[6] = (w3 >> 18) & mask;
+ out[7] = ((w3 >> 53) | (w4 << 11)) & mask;
+ out[8] = (w4 >> 24) & mask;
+ out[9] = ((w4 >> 59) | (w5 << 5)) & mask;
+ out[10] = ((w5 >> 30) | (w6 << 34)) & mask;
+ out[11] = (w6 >> 1) & mask;
+ out[12] = ((w6 >> 36) | (w7 << 28)) & mask;
+ out[13] = (w7 >> 7) & mask;
+ out[14] = ((w7 >> 42) | (w8 << 22)) & mask;
+ out[15] = (w8 >> 13) & mask;
+ out[16] = ((w8 >> 48) | (w9 << 16)) & mask;
+ out[17] = (w9 >> 19) & mask;
+ out[18] = ((w9 >> 54) | (w10 << 10)) & mask;
+ out[19] = (w10 >> 25) & mask;
+ out[20] = ((w10 >> 60) | (w11 << 4)) & mask;
+ out[21] = ((w11 >> 31) | (w12 << 33)) & mask;
+ out[22] = (w12 >> 2) & mask;
+ out[23] = ((w12 >> 37) | (w13 << 27)) & mask;
+ out[24] = (w13 >> 8) & mask;
+ out[25] = ((w13 >> 43) | (w14 << 21)) & mask;
+ out[26] = (w14 >> 14) & mask;
+ out[27] = ((w14 >> 49) | (w15 << 15)) & mask;
+ out[28] = (w15 >> 20) & mask;
+ out[29] = ((w15 >> 55) | (w16 << 9)) & mask;
+ out[30] = (w16 >> 26) & mask;
+ out[31] = ((w16 >> 61) | (w17 << 3)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack36_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 68719476735ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 36) | (w1 << 28)) & mask;
+ out[2] = (w1 >> 8) & mask;
+ out[3] = ((w1 >> 44) | (w2 << 20)) & mask;
+ out[4] = (w2 >> 16) & mask;
+ out[5] = ((w2 >> 52) | (w3 << 12)) & mask;
+ out[6] = (w3 >> 24) & mask;
+ out[7] = ((w3 >> 60) | (w4 << 4)) & mask;
+ out[8] = ((w4 >> 32) | (w5 << 32)) & mask;
+ out[9] = (w5 >> 4) & mask;
+ out[10] = ((w5 >> 40) | (w6 << 24)) & mask;
+ out[11] = (w6 >> 12) & mask;
+ out[12] = ((w6 >> 48) | (w7 << 16)) & mask;
+ out[13] = (w7 >> 20) & mask;
+ out[14] = ((w7 >> 56) | (w8 << 8)) & mask;
+ out[15] = w8 >> 28;
+ out[16] = (w9)&mask;
+ out[17] = ((w9 >> 36) | (w10 << 28)) & mask;
+ out[18] = (w10 >> 8) & mask;
+ out[19] = ((w10 >> 44) | (w11 << 20)) & mask;
+ out[20] = (w11 >> 16) & mask;
+ out[21] = ((w11 >> 52) | (w12 << 12)) & mask;
+ out[22] = (w12 >> 24) & mask;
+ out[23] = ((w12 >> 60) | (w13 << 4)) & mask;
+ out[24] = ((w13 >> 32) | (w14 << 32)) & mask;
+ out[25] = (w14 >> 4) & mask;
+ out[26] = ((w14 >> 40) | (w15 << 24)) & mask;
+ out[27] = (w15 >> 12) & mask;
+ out[28] = ((w15 >> 48) | (w16 << 16)) & mask;
+ out[29] = (w16 >> 20) & mask;
+ out[30] = ((w16 >> 56) | (w17 << 8)) & mask;
+ out[31] = w17 >> 28;
+
+ return in;
+}
+
+inline const uint8_t* unpack37_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 137438953471ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint32_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 37) | (w1 << 27)) & mask;
+ out[2] = (w1 >> 10) & mask;
+ out[3] = ((w1 >> 47) | (w2 << 17)) & mask;
+ out[4] = (w2 >> 20) & mask;
+ out[5] = ((w2 >> 57) | (w3 << 7)) & mask;
+ out[6] = ((w3 >> 30) | (w4 << 34)) & mask;
+ out[7] = (w4 >> 3) & mask;
+ out[8] = ((w4 >> 40) | (w5 << 24)) & mask;
+ out[9] = (w5 >> 13) & mask;
+ out[10] = ((w5 >> 50) | (w6 << 14)) & mask;
+ out[11] = (w6 >> 23) & mask;
+ out[12] = ((w6 >> 60) | (w7 << 4)) & mask;
+ out[13] = ((w7 >> 33) | (w8 << 31)) & mask;
+ out[14] = (w8 >> 6) & mask;
+ out[15] = ((w8 >> 43) | (w9 << 21)) & mask;
+ out[16] = (w9 >> 16) & mask;
+ out[17] = ((w9 >> 53) | (w10 << 11)) & mask;
+ out[18] = (w10 >> 26) & mask;
+ out[19] = ((w10 >> 63) | (w11 << 1)) & mask;
+ out[20] = ((w11 >> 36) | (w12 << 28)) & mask;
+ out[21] = (w12 >> 9) & mask;
+ out[22] = ((w12 >> 46) | (w13 << 18)) & mask;
+ out[23] = (w13 >> 19) & mask;
+ out[24] = ((w13 >> 56) | (w14 << 8)) & mask;
+ out[25] = ((w14 >> 29) | (w15 << 35)) & mask;
+ out[26] = (w15 >> 2) & mask;
+ out[27] = ((w15 >> 39) | (w16 << 25)) & mask;
+ out[28] = (w16 >> 12) & mask;
+ out[29] = ((w16 >> 49) | (w17 << 15)) & mask;
+ out[30] = (w17 >> 22) & mask;
+ out[31] = ((w17 >> 59) | (w18 << 5)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack38_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 274877906943ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 38) | (w1 << 26)) & mask;
+ out[2] = (w1 >> 12) & mask;
+ out[3] = ((w1 >> 50) | (w2 << 14)) & mask;
+ out[4] = (w2 >> 24) & mask;
+ out[5] = ((w2 >> 62) | (w3 << 2)) & mask;
+ out[6] = ((w3 >> 36) | (w4 << 28)) & mask;
+ out[7] = (w4 >> 10) & mask;
+ out[8] = ((w4 >> 48) | (w5 << 16)) & mask;
+ out[9] = (w5 >> 22) & mask;
+ out[10] = ((w5 >> 60) | (w6 << 4)) & mask;
+ out[11] = ((w6 >> 34) | (w7 << 30)) & mask;
+ out[12] = (w7 >> 8) & mask;
+ out[13] = ((w7 >> 46) | (w8 << 18)) & mask;
+ out[14] = (w8 >> 20) & mask;
+ out[15] = ((w8 >> 58) | (w9 << 6)) & mask;
+ out[16] = ((w9 >> 32) | (w10 << 32)) & mask;
+ out[17] = (w10 >> 6) & mask;
+ out[18] = ((w10 >> 44) | (w11 << 20)) & mask;
+ out[19] = (w11 >> 18) & mask;
+ out[20] = ((w11 >> 56) | (w12 << 8)) & mask;
+ out[21] = ((w12 >> 30) | (w13 << 34)) & mask;
+ out[22] = (w13 >> 4) & mask;
+ out[23] = ((w13 >> 42) | (w14 << 22)) & mask;
+ out[24] = (w14 >> 16) & mask;
+ out[25] = ((w14 >> 54) | (w15 << 10)) & mask;
+ out[26] = ((w15 >> 28) | (w16 << 36)) & mask;
+ out[27] = (w16 >> 2) & mask;
+ out[28] = ((w16 >> 40) | (w17 << 24)) & mask;
+ out[29] = (w17 >> 14) & mask;
+ out[30] = ((w17 >> 52) | (w18 << 12)) & mask;
+ out[31] = w18 >> 26;
+
+ return in;
+}
+
+inline const uint8_t* unpack39_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 549755813887ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint32_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 39) | (w1 << 25)) & mask;
+ out[2] = (w1 >> 14) & mask;
+ out[3] = ((w1 >> 53) | (w2 << 11)) & mask;
+ out[4] = ((w2 >> 28) | (w3 << 36)) & mask;
+ out[5] = (w3 >> 3) & mask;
+ out[6] = ((w3 >> 42) | (w4 << 22)) & mask;
+ out[7] = (w4 >> 17) & mask;
+ out[8] = ((w4 >> 56) | (w5 << 8)) & mask;
+ out[9] = ((w5 >> 31) | (w6 << 33)) & mask;
+ out[10] = (w6 >> 6) & mask;
+ out[11] = ((w6 >> 45) | (w7 << 19)) & mask;
+ out[12] = (w7 >> 20) & mask;
+ out[13] = ((w7 >> 59) | (w8 << 5)) & mask;
+ out[14] = ((w8 >> 34) | (w9 << 30)) & mask;
+ out[15] = (w9 >> 9) & mask;
+ out[16] = ((w9 >> 48) | (w10 << 16)) & mask;
+ out[17] = (w10 >> 23) & mask;
+ out[18] = ((w10 >> 62) | (w11 << 2)) & mask;
+ out[19] = ((w11 >> 37) | (w12 << 27)) & mask;
+ out[20] = (w12 >> 12) & mask;
+ out[21] = ((w12 >> 51) | (w13 << 13)) & mask;
+ out[22] = ((w13 >> 26) | (w14 << 38)) & mask;
+ out[23] = (w14 >> 1) & mask;
+ out[24] = ((w14 >> 40) | (w15 << 24)) & mask;
+ out[25] = (w15 >> 15) & mask;
+ out[26] = ((w15 >> 54) | (w16 << 10)) & mask;
+ out[27] = ((w16 >> 29) | (w17 << 35)) & mask;
+ out[28] = (w17 >> 4) & mask;
+ out[29] = ((w17 >> 43) | (w18 << 21)) & mask;
+ out[30] = (w18 >> 18) & mask;
+ out[31] = ((w18 >> 57) | (w19 << 7)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack40_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 1099511627775ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 40) | (w1 << 24)) & mask;
+ out[2] = (w1 >> 16) & mask;
+ out[3] = ((w1 >> 56) | (w2 << 8)) & mask;
+ out[4] = ((w2 >> 32) | (w3 << 32)) & mask;
+ out[5] = (w3 >> 8) & mask;
+ out[6] = ((w3 >> 48) | (w4 << 16)) & mask;
+ out[7] = w4 >> 24;
+ out[8] = (w5)&mask;
+ out[9] = ((w5 >> 40) | (w6 << 24)) & mask;
+ out[10] = (w6 >> 16) & mask;
+ out[11] = ((w6 >> 56) | (w7 << 8)) & mask;
+ out[12] = ((w7 >> 32) | (w8 << 32)) & mask;
+ out[13] = (w8 >> 8) & mask;
+ out[14] = ((w8 >> 48) | (w9 << 16)) & mask;
+ out[15] = w9 >> 24;
+ out[16] = (w10)&mask;
+ out[17] = ((w10 >> 40) | (w11 << 24)) & mask;
+ out[18] = (w11 >> 16) & mask;
+ out[19] = ((w11 >> 56) | (w12 << 8)) & mask;
+ out[20] = ((w12 >> 32) | (w13 << 32)) & mask;
+ out[21] = (w13 >> 8) & mask;
+ out[22] = ((w13 >> 48) | (w14 << 16)) & mask;
+ out[23] = w14 >> 24;
+ out[24] = (w15)&mask;
+ out[25] = ((w15 >> 40) | (w16 << 24)) & mask;
+ out[26] = (w16 >> 16) & mask;
+ out[27] = ((w16 >> 56) | (w17 << 8)) & mask;
+ out[28] = ((w17 >> 32) | (w18 << 32)) & mask;
+ out[29] = (w18 >> 8) & mask;
+ out[30] = ((w18 >> 48) | (w19 << 16)) & mask;
+ out[31] = w19 >> 24;
+
+ return in;
+}
+
+inline const uint8_t* unpack41_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 2199023255551ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint32_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 41) | (w1 << 23)) & mask;
+ out[2] = (w1 >> 18) & mask;
+ out[3] = ((w1 >> 59) | (w2 << 5)) & mask;
+ out[4] = ((w2 >> 36) | (w3 << 28)) & mask;
+ out[5] = (w3 >> 13) & mask;
+ out[6] = ((w3 >> 54) | (w4 << 10)) & mask;
+ out[7] = ((w4 >> 31) | (w5 << 33)) & mask;
+ out[8] = (w5 >> 8) & mask;
+ out[9] = ((w5 >> 49) | (w6 << 15)) & mask;
+ out[10] = ((w6 >> 26) | (w7 << 38)) & mask;
+ out[11] = (w7 >> 3) & mask;
+ out[12] = ((w7 >> 44) | (w8 << 20)) & mask;
+ out[13] = (w8 >> 21) & mask;
+ out[14] = ((w8 >> 62) | (w9 << 2)) & mask;
+ out[15] = ((w9 >> 39) | (w10 << 25)) & mask;
+ out[16] = (w10 >> 16) & mask;
+ out[17] = ((w10 >> 57) | (w11 << 7)) & mask;
+ out[18] = ((w11 >> 34) | (w12 << 30)) & mask;
+ out[19] = (w12 >> 11) & mask;
+ out[20] = ((w12 >> 52) | (w13 << 12)) & mask;
+ out[21] = ((w13 >> 29) | (w14 << 35)) & mask;
+ out[22] = (w14 >> 6) & mask;
+ out[23] = ((w14 >> 47) | (w15 << 17)) & mask;
+ out[24] = ((w15 >> 24) | (w16 << 40)) & mask;
+ out[25] = (w16 >> 1) & mask;
+ out[26] = ((w16 >> 42) | (w17 << 22)) & mask;
+ out[27] = (w17 >> 19) & mask;
+ out[28] = ((w17 >> 60) | (w18 << 4)) & mask;
+ out[29] = ((w18 >> 37) | (w19 << 27)) & mask;
+ out[30] = (w19 >> 14) & mask;
+ out[31] = ((w19 >> 55) | (w20 << 9)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack42_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 4398046511103ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 42) | (w1 << 22)) & mask;
+ out[2] = (w1 >> 20) & mask;
+ out[3] = ((w1 >> 62) | (w2 << 2)) & mask;
+ out[4] = ((w2 >> 40) | (w3 << 24)) & mask;
+ out[5] = (w3 >> 18) & mask;
+ out[6] = ((w3 >> 60) | (w4 << 4)) & mask;
+ out[7] = ((w4 >> 38) | (w5 << 26)) & mask;
+ out[8] = (w5 >> 16) & mask;
+ out[9] = ((w5 >> 58) | (w6 << 6)) & mask;
+ out[10] = ((w6 >> 36) | (w7 << 28)) & mask;
+ out[11] = (w7 >> 14) & mask;
+ out[12] = ((w7 >> 56) | (w8 << 8)) & mask;
+ out[13] = ((w8 >> 34) | (w9 << 30)) & mask;
+ out[14] = (w9 >> 12) & mask;
+ out[15] = ((w9 >> 54) | (w10 << 10)) & mask;
+ out[16] = ((w10 >> 32) | (w11 << 32)) & mask;
+ out[17] = (w11 >> 10) & mask;
+ out[18] = ((w11 >> 52) | (w12 << 12)) & mask;
+ out[19] = ((w12 >> 30) | (w13 << 34)) & mask;
+ out[20] = (w13 >> 8) & mask;
+ out[21] = ((w13 >> 50) | (w14 << 14)) & mask;
+ out[22] = ((w14 >> 28) | (w15 << 36)) & mask;
+ out[23] = (w15 >> 6) & mask;
+ out[24] = ((w15 >> 48) | (w16 << 16)) & mask;
+ out[25] = ((w16 >> 26) | (w17 << 38)) & mask;
+ out[26] = (w17 >> 4) & mask;
+ out[27] = ((w17 >> 46) | (w18 << 18)) & mask;
+ out[28] = ((w18 >> 24) | (w19 << 40)) & mask;
+ out[29] = (w19 >> 2) & mask;
+ out[30] = ((w19 >> 44) | (w20 << 20)) & mask;
+ out[31] = w20 >> 22;
+
+ return in;
+}
+
+inline const uint8_t* unpack43_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 8796093022207ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint32_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 43) | (w1 << 21)) & mask;
+ out[2] = ((w1 >> 22) | (w2 << 42)) & mask;
+ out[3] = (w2 >> 1) & mask;
+ out[4] = ((w2 >> 44) | (w3 << 20)) & mask;
+ out[5] = ((w3 >> 23) | (w4 << 41)) & mask;
+ out[6] = (w4 >> 2) & mask;
+ out[7] = ((w4 >> 45) | (w5 << 19)) & mask;
+ out[8] = ((w5 >> 24) | (w6 << 40)) & mask;
+ out[9] = (w6 >> 3) & mask;
+ out[10] = ((w6 >> 46) | (w7 << 18)) & mask;
+ out[11] = ((w7 >> 25) | (w8 << 39)) & mask;
+ out[12] = (w8 >> 4) & mask;
+ out[13] = ((w8 >> 47) | (w9 << 17)) & mask;
+ out[14] = ((w9 >> 26) | (w10 << 38)) & mask;
+ out[15] = (w10 >> 5) & mask;
+ out[16] = ((w10 >> 48) | (w11 << 16)) & mask;
+ out[17] = ((w11 >> 27) | (w12 << 37)) & mask;
+ out[18] = (w12 >> 6) & mask;
+ out[19] = ((w12 >> 49) | (w13 << 15)) & mask;
+ out[20] = ((w13 >> 28) | (w14 << 36)) & mask;
+ out[21] = (w14 >> 7) & mask;
+ out[22] = ((w14 >> 50) | (w15 << 14)) & mask;
+ out[23] = ((w15 >> 29) | (w16 << 35)) & mask;
+ out[24] = (w16 >> 8) & mask;
+ out[25] = ((w16 >> 51) | (w17 << 13)) & mask;
+ out[26] = ((w17 >> 30) | (w18 << 34)) & mask;
+ out[27] = (w18 >> 9) & mask;
+ out[28] = ((w18 >> 52) | (w19 << 12)) & mask;
+ out[29] = ((w19 >> 31) | (w20 << 33)) & mask;
+ out[30] = (w20 >> 10) & mask;
+ out[31] = ((w20 >> 53) | (w21 << 11)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack44_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 17592186044415ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 44) | (w1 << 20)) & mask;
+ out[2] = ((w1 >> 24) | (w2 << 40)) & mask;
+ out[3] = (w2 >> 4) & mask;
+ out[4] = ((w2 >> 48) | (w3 << 16)) & mask;
+ out[5] = ((w3 >> 28) | (w4 << 36)) & mask;
+ out[6] = (w4 >> 8) & mask;
+ out[7] = ((w4 >> 52) | (w5 << 12)) & mask;
+ out[8] = ((w5 >> 32) | (w6 << 32)) & mask;
+ out[9] = (w6 >> 12) & mask;
+ out[10] = ((w6 >> 56) | (w7 << 8)) & mask;
+ out[11] = ((w7 >> 36) | (w8 << 28)) & mask;
+ out[12] = (w8 >> 16) & mask;
+ out[13] = ((w8 >> 60) | (w9 << 4)) & mask;
+ out[14] = ((w9 >> 40) | (w10 << 24)) & mask;
+ out[15] = w10 >> 20;
+ out[16] = (w11)&mask;
+ out[17] = ((w11 >> 44) | (w12 << 20)) & mask;
+ out[18] = ((w12 >> 24) | (w13 << 40)) & mask;
+ out[19] = (w13 >> 4) & mask;
+ out[20] = ((w13 >> 48) | (w14 << 16)) & mask;
+ out[21] = ((w14 >> 28) | (w15 << 36)) & mask;
+ out[22] = (w15 >> 8) & mask;
+ out[23] = ((w15 >> 52) | (w16 << 12)) & mask;
+ out[24] = ((w16 >> 32) | (w17 << 32)) & mask;
+ out[25] = (w17 >> 12) & mask;
+ out[26] = ((w17 >> 56) | (w18 << 8)) & mask;
+ out[27] = ((w18 >> 36) | (w19 << 28)) & mask;
+ out[28] = (w19 >> 16) & mask;
+ out[29] = ((w19 >> 60) | (w20 << 4)) & mask;
+ out[30] = ((w20 >> 40) | (w21 << 24)) & mask;
+ out[31] = w21 >> 20;
+
+ return in;
+}
+
+inline const uint8_t* unpack45_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 35184372088831ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint32_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 45) | (w1 << 19)) & mask;
+ out[2] = ((w1 >> 26) | (w2 << 38)) & mask;
+ out[3] = (w2 >> 7) & mask;
+ out[4] = ((w2 >> 52) | (w3 << 12)) & mask;
+ out[5] = ((w3 >> 33) | (w4 << 31)) & mask;
+ out[6] = (w4 >> 14) & mask;
+ out[7] = ((w4 >> 59) | (w5 << 5)) & mask;
+ out[8] = ((w5 >> 40) | (w6 << 24)) & mask;
+ out[9] = ((w6 >> 21) | (w7 << 43)) & mask;
+ out[10] = (w7 >> 2) & mask;
+ out[11] = ((w7 >> 47) | (w8 << 17)) & mask;
+ out[12] = ((w8 >> 28) | (w9 << 36)) & mask;
+ out[13] = (w9 >> 9) & mask;
+ out[14] = ((w9 >> 54) | (w10 << 10)) & mask;
+ out[15] = ((w10 >> 35) | (w11 << 29)) & mask;
+ out[16] = (w11 >> 16) & mask;
+ out[17] = ((w11 >> 61) | (w12 << 3)) & mask;
+ out[18] = ((w12 >> 42) | (w13 << 22)) & mask;
+ out[19] = ((w13 >> 23) | (w14 << 41)) & mask;
+ out[20] = (w14 >> 4) & mask;
+ out[21] = ((w14 >> 49) | (w15 << 15)) & mask;
+ out[22] = ((w15 >> 30) | (w16 << 34)) & mask;
+ out[23] = (w16 >> 11) & mask;
+ out[24] = ((w16 >> 56) | (w17 << 8)) & mask;
+ out[25] = ((w17 >> 37) | (w18 << 27)) & mask;
+ out[26] = (w18 >> 18) & mask;
+ out[27] = ((w18 >> 63) | (w19 << 1)) & mask;
+ out[28] = ((w19 >> 44) | (w20 << 20)) & mask;
+ out[29] = ((w20 >> 25) | (w21 << 39)) & mask;
+ out[30] = (w21 >> 6) & mask;
+ out[31] = ((w21 >> 51) | (w22 << 13)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack46_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 70368744177663ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 46) | (w1 << 18)) & mask;
+ out[2] = ((w1 >> 28) | (w2 << 36)) & mask;
+ out[3] = (w2 >> 10) & mask;
+ out[4] = ((w2 >> 56) | (w3 << 8)) & mask;
+ out[5] = ((w3 >> 38) | (w4 << 26)) & mask;
+ out[6] = ((w4 >> 20) | (w5 << 44)) & mask;
+ out[7] = (w5 >> 2) & mask;
+ out[8] = ((w5 >> 48) | (w6 << 16)) & mask;
+ out[9] = ((w6 >> 30) | (w7 << 34)) & mask;
+ out[10] = (w7 >> 12) & mask;
+ out[11] = ((w7 >> 58) | (w8 << 6)) & mask;
+ out[12] = ((w8 >> 40) | (w9 << 24)) & mask;
+ out[13] = ((w9 >> 22) | (w10 << 42)) & mask;
+ out[14] = (w10 >> 4) & mask;
+ out[15] = ((w10 >> 50) | (w11 << 14)) & mask;
+ out[16] = ((w11 >> 32) | (w12 << 32)) & mask;
+ out[17] = (w12 >> 14) & mask;
+ out[18] = ((w12 >> 60) | (w13 << 4)) & mask;
+ out[19] = ((w13 >> 42) | (w14 << 22)) & mask;
+ out[20] = ((w14 >> 24) | (w15 << 40)) & mask;
+ out[21] = (w15 >> 6) & mask;
+ out[22] = ((w15 >> 52) | (w16 << 12)) & mask;
+ out[23] = ((w16 >> 34) | (w17 << 30)) & mask;
+ out[24] = (w17 >> 16) & mask;
+ out[25] = ((w17 >> 62) | (w18 << 2)) & mask;
+ out[26] = ((w18 >> 44) | (w19 << 20)) & mask;
+ out[27] = ((w19 >> 26) | (w20 << 38)) & mask;
+ out[28] = (w20 >> 8) & mask;
+ out[29] = ((w20 >> 54) | (w21 << 10)) & mask;
+ out[30] = ((w21 >> 36) | (w22 << 28)) & mask;
+ out[31] = w22 >> 18;
+
+ return in;
+}
+
+inline const uint8_t* unpack47_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 140737488355327ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint32_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 47) | (w1 << 17)) & mask;
+ out[2] = ((w1 >> 30) | (w2 << 34)) & mask;
+ out[3] = (w2 >> 13) & mask;
+ out[4] = ((w2 >> 60) | (w3 << 4)) & mask;
+ out[5] = ((w3 >> 43) | (w4 << 21)) & mask;
+ out[6] = ((w4 >> 26) | (w5 << 38)) & mask;
+ out[7] = (w5 >> 9) & mask;
+ out[8] = ((w5 >> 56) | (w6 << 8)) & mask;
+ out[9] = ((w6 >> 39) | (w7 << 25)) & mask;
+ out[10] = ((w7 >> 22) | (w8 << 42)) & mask;
+ out[11] = (w8 >> 5) & mask;
+ out[12] = ((w8 >> 52) | (w9 << 12)) & mask;
+ out[13] = ((w9 >> 35) | (w10 << 29)) & mask;
+ out[14] = ((w10 >> 18) | (w11 << 46)) & mask;
+ out[15] = (w11 >> 1) & mask;
+ out[16] = ((w11 >> 48) | (w12 << 16)) & mask;
+ out[17] = ((w12 >> 31) | (w13 << 33)) & mask;
+ out[18] = (w13 >> 14) & mask;
+ out[19] = ((w13 >> 61) | (w14 << 3)) & mask;
+ out[20] = ((w14 >> 44) | (w15 << 20)) & mask;
+ out[21] = ((w15 >> 27) | (w16 << 37)) & mask;
+ out[22] = (w16 >> 10) & mask;
+ out[23] = ((w16 >> 57) | (w17 << 7)) & mask;
+ out[24] = ((w17 >> 40) | (w18 << 24)) & mask;
+ out[25] = ((w18 >> 23) | (w19 << 41)) & mask;
+ out[26] = (w19 >> 6) & mask;
+ out[27] = ((w19 >> 53) | (w20 << 11)) & mask;
+ out[28] = ((w20 >> 36) | (w21 << 28)) & mask;
+ out[29] = ((w21 >> 19) | (w22 << 45)) & mask;
+ out[30] = (w22 >> 2) & mask;
+ out[31] = ((w22 >> 49) | (w23 << 15)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack48_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 281474976710655ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 48) | (w1 << 16)) & mask;
+ out[2] = ((w1 >> 32) | (w2 << 32)) & mask;
+ out[3] = w2 >> 16;
+ out[4] = (w3)&mask;
+ out[5] = ((w3 >> 48) | (w4 << 16)) & mask;
+ out[6] = ((w4 >> 32) | (w5 << 32)) & mask;
+ out[7] = w5 >> 16;
+ out[8] = (w6)&mask;
+ out[9] = ((w6 >> 48) | (w7 << 16)) & mask;
+ out[10] = ((w7 >> 32) | (w8 << 32)) & mask;
+ out[11] = w8 >> 16;
+ out[12] = (w9)&mask;
+ out[13] = ((w9 >> 48) | (w10 << 16)) & mask;
+ out[14] = ((w10 >> 32) | (w11 << 32)) & mask;
+ out[15] = w11 >> 16;
+ out[16] = (w12)&mask;
+ out[17] = ((w12 >> 48) | (w13 << 16)) & mask;
+ out[18] = ((w13 >> 32) | (w14 << 32)) & mask;
+ out[19] = w14 >> 16;
+ out[20] = (w15)&mask;
+ out[21] = ((w15 >> 48) | (w16 << 16)) & mask;
+ out[22] = ((w16 >> 32) | (w17 << 32)) & mask;
+ out[23] = w17 >> 16;
+ out[24] = (w18)&mask;
+ out[25] = ((w18 >> 48) | (w19 << 16)) & mask;
+ out[26] = ((w19 >> 32) | (w20 << 32)) & mask;
+ out[27] = w20 >> 16;
+ out[28] = (w21)&mask;
+ out[29] = ((w21 >> 48) | (w22 << 16)) & mask;
+ out[30] = ((w22 >> 32) | (w23 << 32)) & mask;
+ out[31] = w23 >> 16;
+
+ return in;
+}
+
+inline const uint8_t* unpack49_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 562949953421311ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint32_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 49) | (w1 << 15)) & mask;
+ out[2] = ((w1 >> 34) | (w2 << 30)) & mask;
+ out[3] = ((w2 >> 19) | (w3 << 45)) & mask;
+ out[4] = (w3 >> 4) & mask;
+ out[5] = ((w3 >> 53) | (w4 << 11)) & mask;
+ out[6] = ((w4 >> 38) | (w5 << 26)) & mask;
+ out[7] = ((w5 >> 23) | (w6 << 41)) & mask;
+ out[8] = (w6 >> 8) & mask;
+ out[9] = ((w6 >> 57) | (w7 << 7)) & mask;
+ out[10] = ((w7 >> 42) | (w8 << 22)) & mask;
+ out[11] = ((w8 >> 27) | (w9 << 37)) & mask;
+ out[12] = (w9 >> 12) & mask;
+ out[13] = ((w9 >> 61) | (w10 << 3)) & mask;
+ out[14] = ((w10 >> 46) | (w11 << 18)) & mask;
+ out[15] = ((w11 >> 31) | (w12 << 33)) & mask;
+ out[16] = ((w12 >> 16) | (w13 << 48)) & mask;
+ out[17] = (w13 >> 1) & mask;
+ out[18] = ((w13 >> 50) | (w14 << 14)) & mask;
+ out[19] = ((w14 >> 35) | (w15 << 29)) & mask;
+ out[20] = ((w15 >> 20) | (w16 << 44)) & mask;
+ out[21] = (w16 >> 5) & mask;
+ out[22] = ((w16 >> 54) | (w17 << 10)) & mask;
+ out[23] = ((w17 >> 39) | (w18 << 25)) & mask;
+ out[24] = ((w18 >> 24) | (w19 << 40)) & mask;
+ out[25] = (w19 >> 9) & mask;
+ out[26] = ((w19 >> 58) | (w20 << 6)) & mask;
+ out[27] = ((w20 >> 43) | (w21 << 21)) & mask;
+ out[28] = ((w21 >> 28) | (w22 << 36)) & mask;
+ out[29] = (w22 >> 13) & mask;
+ out[30] = ((w22 >> 62) | (w23 << 2)) & mask;
+ out[31] = ((w23 >> 47) | (w24 << 17)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack50_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 1125899906842623ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint64_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 50) | (w1 << 14)) & mask;
+ out[2] = ((w1 >> 36) | (w2 << 28)) & mask;
+ out[3] = ((w2 >> 22) | (w3 << 42)) & mask;
+ out[4] = (w3 >> 8) & mask;
+ out[5] = ((w3 >> 58) | (w4 << 6)) & mask;
+ out[6] = ((w4 >> 44) | (w5 << 20)) & mask;
+ out[7] = ((w5 >> 30) | (w6 << 34)) & mask;
+ out[8] = ((w6 >> 16) | (w7 << 48)) & mask;
+ out[9] = (w7 >> 2) & mask;
+ out[10] = ((w7 >> 52) | (w8 << 12)) & mask;
+ out[11] = ((w8 >> 38) | (w9 << 26)) & mask;
+ out[12] = ((w9 >> 24) | (w10 << 40)) & mask;
+ out[13] = (w10 >> 10) & mask;
+ out[14] = ((w10 >> 60) | (w11 << 4)) & mask;
+ out[15] = ((w11 >> 46) | (w12 << 18)) & mask;
+ out[16] = ((w12 >> 32) | (w13 << 32)) & mask;
+ out[17] = ((w13 >> 18) | (w14 << 46)) & mask;
+ out[18] = (w14 >> 4) & mask;
+ out[19] = ((w14 >> 54) | (w15 << 10)) & mask;
+ out[20] = ((w15 >> 40) | (w16 << 24)) & mask;
+ out[21] = ((w16 >> 26) | (w17 << 38)) & mask;
+ out[22] = (w17 >> 12) & mask;
+ out[23] = ((w17 >> 62) | (w18 << 2)) & mask;
+ out[24] = ((w18 >> 48) | (w19 << 16)) & mask;
+ out[25] = ((w19 >> 34) | (w20 << 30)) & mask;
+ out[26] = ((w20 >> 20) | (w21 << 44)) & mask;
+ out[27] = (w21 >> 6) & mask;
+ out[28] = ((w21 >> 56) | (w22 << 8)) & mask;
+ out[29] = ((w22 >> 42) | (w23 << 22)) & mask;
+ out[30] = ((w23 >> 28) | (w24 << 36)) & mask;
+ out[31] = w24 >> 14;
+
+ return in;
+}
+
+inline const uint8_t* unpack51_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 2251799813685247ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint64_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 8;
+ uint64_t w25 = util::SafeLoadAs<uint32_t>(in);
+ w25 = arrow::BitUtil::FromLittleEndian(w25);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 51) | (w1 << 13)) & mask;
+ out[2] = ((w1 >> 38) | (w2 << 26)) & mask;
+ out[3] = ((w2 >> 25) | (w3 << 39)) & mask;
+ out[4] = (w3 >> 12) & mask;
+ out[5] = ((w3 >> 63) | (w4 << 1)) & mask;
+ out[6] = ((w4 >> 50) | (w5 << 14)) & mask;
+ out[7] = ((w5 >> 37) | (w6 << 27)) & mask;
+ out[8] = ((w6 >> 24) | (w7 << 40)) & mask;
+ out[9] = (w7 >> 11) & mask;
+ out[10] = ((w7 >> 62) | (w8 << 2)) & mask;
+ out[11] = ((w8 >> 49) | (w9 << 15)) & mask;
+ out[12] = ((w9 >> 36) | (w10 << 28)) & mask;
+ out[13] = ((w10 >> 23) | (w11 << 41)) & mask;
+ out[14] = (w11 >> 10) & mask;
+ out[15] = ((w11 >> 61) | (w12 << 3)) & mask;
+ out[16] = ((w12 >> 48) | (w13 << 16)) & mask;
+ out[17] = ((w13 >> 35) | (w14 << 29)) & mask;
+ out[18] = ((w14 >> 22) | (w15 << 42)) & mask;
+ out[19] = (w15 >> 9) & mask;
+ out[20] = ((w15 >> 60) | (w16 << 4)) & mask;
+ out[21] = ((w16 >> 47) | (w17 << 17)) & mask;
+ out[22] = ((w17 >> 34) | (w18 << 30)) & mask;
+ out[23] = ((w18 >> 21) | (w19 << 43)) & mask;
+ out[24] = (w19 >> 8) & mask;
+ out[25] = ((w19 >> 59) | (w20 << 5)) & mask;
+ out[26] = ((w20 >> 46) | (w21 << 18)) & mask;
+ out[27] = ((w21 >> 33) | (w22 << 31)) & mask;
+ out[28] = ((w22 >> 20) | (w23 << 44)) & mask;
+ out[29] = (w23 >> 7) & mask;
+ out[30] = ((w23 >> 58) | (w24 << 6)) & mask;
+ out[31] = ((w24 >> 45) | (w25 << 19)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack52_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 4503599627370495ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint64_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 8;
+ uint64_t w25 = util::SafeLoadAs<uint64_t>(in);
+ w25 = arrow::BitUtil::FromLittleEndian(w25);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 52) | (w1 << 12)) & mask;
+ out[2] = ((w1 >> 40) | (w2 << 24)) & mask;
+ out[3] = ((w2 >> 28) | (w3 << 36)) & mask;
+ out[4] = ((w3 >> 16) | (w4 << 48)) & mask;
+ out[5] = (w4 >> 4) & mask;
+ out[6] = ((w4 >> 56) | (w5 << 8)) & mask;
+ out[7] = ((w5 >> 44) | (w6 << 20)) & mask;
+ out[8] = ((w6 >> 32) | (w7 << 32)) & mask;
+ out[9] = ((w7 >> 20) | (w8 << 44)) & mask;
+ out[10] = (w8 >> 8) & mask;
+ out[11] = ((w8 >> 60) | (w9 << 4)) & mask;
+ out[12] = ((w9 >> 48) | (w10 << 16)) & mask;
+ out[13] = ((w10 >> 36) | (w11 << 28)) & mask;
+ out[14] = ((w11 >> 24) | (w12 << 40)) & mask;
+ out[15] = w12 >> 12;
+ out[16] = (w13)&mask;
+ out[17] = ((w13 >> 52) | (w14 << 12)) & mask;
+ out[18] = ((w14 >> 40) | (w15 << 24)) & mask;
+ out[19] = ((w15 >> 28) | (w16 << 36)) & mask;
+ out[20] = ((w16 >> 16) | (w17 << 48)) & mask;
+ out[21] = (w17 >> 4) & mask;
+ out[22] = ((w17 >> 56) | (w18 << 8)) & mask;
+ out[23] = ((w18 >> 44) | (w19 << 20)) & mask;
+ out[24] = ((w19 >> 32) | (w20 << 32)) & mask;
+ out[25] = ((w20 >> 20) | (w21 << 44)) & mask;
+ out[26] = (w21 >> 8) & mask;
+ out[27] = ((w21 >> 60) | (w22 << 4)) & mask;
+ out[28] = ((w22 >> 48) | (w23 << 16)) & mask;
+ out[29] = ((w23 >> 36) | (w24 << 28)) & mask;
+ out[30] = ((w24 >> 24) | (w25 << 40)) & mask;
+ out[31] = w25 >> 12;
+
+ return in;
+}
+
+inline const uint8_t* unpack53_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 9007199254740991ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint64_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 8;
+ uint64_t w25 = util::SafeLoadAs<uint64_t>(in);
+ w25 = arrow::BitUtil::FromLittleEndian(w25);
+ in += 8;
+ uint64_t w26 = util::SafeLoadAs<uint32_t>(in);
+ w26 = arrow::BitUtil::FromLittleEndian(w26);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 53) | (w1 << 11)) & mask;
+ out[2] = ((w1 >> 42) | (w2 << 22)) & mask;
+ out[3] = ((w2 >> 31) | (w3 << 33)) & mask;
+ out[4] = ((w3 >> 20) | (w4 << 44)) & mask;
+ out[5] = (w4 >> 9) & mask;
+ out[6] = ((w4 >> 62) | (w5 << 2)) & mask;
+ out[7] = ((w5 >> 51) | (w6 << 13)) & mask;
+ out[8] = ((w6 >> 40) | (w7 << 24)) & mask;
+ out[9] = ((w7 >> 29) | (w8 << 35)) & mask;
+ out[10] = ((w8 >> 18) | (w9 << 46)) & mask;
+ out[11] = (w9 >> 7) & mask;
+ out[12] = ((w9 >> 60) | (w10 << 4)) & mask;
+ out[13] = ((w10 >> 49) | (w11 << 15)) & mask;
+ out[14] = ((w11 >> 38) | (w12 << 26)) & mask;
+ out[15] = ((w12 >> 27) | (w13 << 37)) & mask;
+ out[16] = ((w13 >> 16) | (w14 << 48)) & mask;
+ out[17] = (w14 >> 5) & mask;
+ out[18] = ((w14 >> 58) | (w15 << 6)) & mask;
+ out[19] = ((w15 >> 47) | (w16 << 17)) & mask;
+ out[20] = ((w16 >> 36) | (w17 << 28)) & mask;
+ out[21] = ((w17 >> 25) | (w18 << 39)) & mask;
+ out[22] = ((w18 >> 14) | (w19 << 50)) & mask;
+ out[23] = (w19 >> 3) & mask;
+ out[24] = ((w19 >> 56) | (w20 << 8)) & mask;
+ out[25] = ((w20 >> 45) | (w21 << 19)) & mask;
+ out[26] = ((w21 >> 34) | (w22 << 30)) & mask;
+ out[27] = ((w22 >> 23) | (w23 << 41)) & mask;
+ out[28] = ((w23 >> 12) | (w24 << 52)) & mask;
+ out[29] = (w24 >> 1) & mask;
+ out[30] = ((w24 >> 54) | (w25 << 10)) & mask;
+ out[31] = ((w25 >> 43) | (w26 << 21)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack54_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 18014398509481983ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint64_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 8;
+ uint64_t w25 = util::SafeLoadAs<uint64_t>(in);
+ w25 = arrow::BitUtil::FromLittleEndian(w25);
+ in += 8;
+ uint64_t w26 = util::SafeLoadAs<uint64_t>(in);
+ w26 = arrow::BitUtil::FromLittleEndian(w26);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 54) | (w1 << 10)) & mask;
+ out[2] = ((w1 >> 44) | (w2 << 20)) & mask;
+ out[3] = ((w2 >> 34) | (w3 << 30)) & mask;
+ out[4] = ((w3 >> 24) | (w4 << 40)) & mask;
+ out[5] = ((w4 >> 14) | (w5 << 50)) & mask;
+ out[6] = (w5 >> 4) & mask;
+ out[7] = ((w5 >> 58) | (w6 << 6)) & mask;
+ out[8] = ((w6 >> 48) | (w7 << 16)) & mask;
+ out[9] = ((w7 >> 38) | (w8 << 26)) & mask;
+ out[10] = ((w8 >> 28) | (w9 << 36)) & mask;
+ out[11] = ((w9 >> 18) | (w10 << 46)) & mask;
+ out[12] = (w10 >> 8) & mask;
+ out[13] = ((w10 >> 62) | (w11 << 2)) & mask;
+ out[14] = ((w11 >> 52) | (w12 << 12)) & mask;
+ out[15] = ((w12 >> 42) | (w13 << 22)) & mask;
+ out[16] = ((w13 >> 32) | (w14 << 32)) & mask;
+ out[17] = ((w14 >> 22) | (w15 << 42)) & mask;
+ out[18] = ((w15 >> 12) | (w16 << 52)) & mask;
+ out[19] = (w16 >> 2) & mask;
+ out[20] = ((w16 >> 56) | (w17 << 8)) & mask;
+ out[21] = ((w17 >> 46) | (w18 << 18)) & mask;
+ out[22] = ((w18 >> 36) | (w19 << 28)) & mask;
+ out[23] = ((w19 >> 26) | (w20 << 38)) & mask;
+ out[24] = ((w20 >> 16) | (w21 << 48)) & mask;
+ out[25] = (w21 >> 6) & mask;
+ out[26] = ((w21 >> 60) | (w22 << 4)) & mask;
+ out[27] = ((w22 >> 50) | (w23 << 14)) & mask;
+ out[28] = ((w23 >> 40) | (w24 << 24)) & mask;
+ out[29] = ((w24 >> 30) | (w25 << 34)) & mask;
+ out[30] = ((w25 >> 20) | (w26 << 44)) & mask;
+ out[31] = w26 >> 10;
+
+ return in;
+}
+
+inline const uint8_t* unpack55_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 36028797018963967ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint64_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 8;
+ uint64_t w25 = util::SafeLoadAs<uint64_t>(in);
+ w25 = arrow::BitUtil::FromLittleEndian(w25);
+ in += 8;
+ uint64_t w26 = util::SafeLoadAs<uint64_t>(in);
+ w26 = arrow::BitUtil::FromLittleEndian(w26);
+ in += 8;
+ uint64_t w27 = util::SafeLoadAs<uint32_t>(in);
+ w27 = arrow::BitUtil::FromLittleEndian(w27);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 55) | (w1 << 9)) & mask;
+ out[2] = ((w1 >> 46) | (w2 << 18)) & mask;
+ out[3] = ((w2 >> 37) | (w3 << 27)) & mask;
+ out[4] = ((w3 >> 28) | (w4 << 36)) & mask;
+ out[5] = ((w4 >> 19) | (w5 << 45)) & mask;
+ out[6] = ((w5 >> 10) | (w6 << 54)) & mask;
+ out[7] = (w6 >> 1) & mask;
+ out[8] = ((w6 >> 56) | (w7 << 8)) & mask;
+ out[9] = ((w7 >> 47) | (w8 << 17)) & mask;
+ out[10] = ((w8 >> 38) | (w9 << 26)) & mask;
+ out[11] = ((w9 >> 29) | (w10 << 35)) & mask;
+ out[12] = ((w10 >> 20) | (w11 << 44)) & mask;
+ out[13] = ((w11 >> 11) | (w12 << 53)) & mask;
+ out[14] = (w12 >> 2) & mask;
+ out[15] = ((w12 >> 57) | (w13 << 7)) & mask;
+ out[16] = ((w13 >> 48) | (w14 << 16)) & mask;
+ out[17] = ((w14 >> 39) | (w15 << 25)) & mask;
+ out[18] = ((w15 >> 30) | (w16 << 34)) & mask;
+ out[19] = ((w16 >> 21) | (w17 << 43)) & mask;
+ out[20] = ((w17 >> 12) | (w18 << 52)) & mask;
+ out[21] = (w18 >> 3) & mask;
+ out[22] = ((w18 >> 58) | (w19 << 6)) & mask;
+ out[23] = ((w19 >> 49) | (w20 << 15)) & mask;
+ out[24] = ((w20 >> 40) | (w21 << 24)) & mask;
+ out[25] = ((w21 >> 31) | (w22 << 33)) & mask;
+ out[26] = ((w22 >> 22) | (w23 << 42)) & mask;
+ out[27] = ((w23 >> 13) | (w24 << 51)) & mask;
+ out[28] = (w24 >> 4) & mask;
+ out[29] = ((w24 >> 59) | (w25 << 5)) & mask;
+ out[30] = ((w25 >> 50) | (w26 << 14)) & mask;
+ out[31] = ((w26 >> 41) | (w27 << 23)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack56_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 72057594037927935ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint64_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 8;
+ uint64_t w25 = util::SafeLoadAs<uint64_t>(in);
+ w25 = arrow::BitUtil::FromLittleEndian(w25);
+ in += 8;
+ uint64_t w26 = util::SafeLoadAs<uint64_t>(in);
+ w26 = arrow::BitUtil::FromLittleEndian(w26);
+ in += 8;
+ uint64_t w27 = util::SafeLoadAs<uint64_t>(in);
+ w27 = arrow::BitUtil::FromLittleEndian(w27);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 56) | (w1 << 8)) & mask;
+ out[2] = ((w1 >> 48) | (w2 << 16)) & mask;
+ out[3] = ((w2 >> 40) | (w3 << 24)) & mask;
+ out[4] = ((w3 >> 32) | (w4 << 32)) & mask;
+ out[5] = ((w4 >> 24) | (w5 << 40)) & mask;
+ out[6] = ((w5 >> 16) | (w6 << 48)) & mask;
+ out[7] = w6 >> 8;
+ out[8] = (w7)&mask;
+ out[9] = ((w7 >> 56) | (w8 << 8)) & mask;
+ out[10] = ((w8 >> 48) | (w9 << 16)) & mask;
+ out[11] = ((w9 >> 40) | (w10 << 24)) & mask;
+ out[12] = ((w10 >> 32) | (w11 << 32)) & mask;
+ out[13] = ((w11 >> 24) | (w12 << 40)) & mask;
+ out[14] = ((w12 >> 16) | (w13 << 48)) & mask;
+ out[15] = w13 >> 8;
+ out[16] = (w14)&mask;
+ out[17] = ((w14 >> 56) | (w15 << 8)) & mask;
+ out[18] = ((w15 >> 48) | (w16 << 16)) & mask;
+ out[19] = ((w16 >> 40) | (w17 << 24)) & mask;
+ out[20] = ((w17 >> 32) | (w18 << 32)) & mask;
+ out[21] = ((w18 >> 24) | (w19 << 40)) & mask;
+ out[22] = ((w19 >> 16) | (w20 << 48)) & mask;
+ out[23] = w20 >> 8;
+ out[24] = (w21)&mask;
+ out[25] = ((w21 >> 56) | (w22 << 8)) & mask;
+ out[26] = ((w22 >> 48) | (w23 << 16)) & mask;
+ out[27] = ((w23 >> 40) | (w24 << 24)) & mask;
+ out[28] = ((w24 >> 32) | (w25 << 32)) & mask;
+ out[29] = ((w25 >> 24) | (w26 << 40)) & mask;
+ out[30] = ((w26 >> 16) | (w27 << 48)) & mask;
+ out[31] = w27 >> 8;
+
+ return in;
+}
+
+inline const uint8_t* unpack57_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 144115188075855871ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint64_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 8;
+ uint64_t w25 = util::SafeLoadAs<uint64_t>(in);
+ w25 = arrow::BitUtil::FromLittleEndian(w25);
+ in += 8;
+ uint64_t w26 = util::SafeLoadAs<uint64_t>(in);
+ w26 = arrow::BitUtil::FromLittleEndian(w26);
+ in += 8;
+ uint64_t w27 = util::SafeLoadAs<uint64_t>(in);
+ w27 = arrow::BitUtil::FromLittleEndian(w27);
+ in += 8;
+ uint64_t w28 = util::SafeLoadAs<uint32_t>(in);
+ w28 = arrow::BitUtil::FromLittleEndian(w28);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 57) | (w1 << 7)) & mask;
+ out[2] = ((w1 >> 50) | (w2 << 14)) & mask;
+ out[3] = ((w2 >> 43) | (w3 << 21)) & mask;
+ out[4] = ((w3 >> 36) | (w4 << 28)) & mask;
+ out[5] = ((w4 >> 29) | (w5 << 35)) & mask;
+ out[6] = ((w5 >> 22) | (w6 << 42)) & mask;
+ out[7] = ((w6 >> 15) | (w7 << 49)) & mask;
+ out[8] = ((w7 >> 8) | (w8 << 56)) & mask;
+ out[9] = (w8 >> 1) & mask;
+ out[10] = ((w8 >> 58) | (w9 << 6)) & mask;
+ out[11] = ((w9 >> 51) | (w10 << 13)) & mask;
+ out[12] = ((w10 >> 44) | (w11 << 20)) & mask;
+ out[13] = ((w11 >> 37) | (w12 << 27)) & mask;
+ out[14] = ((w12 >> 30) | (w13 << 34)) & mask;
+ out[15] = ((w13 >> 23) | (w14 << 41)) & mask;
+ out[16] = ((w14 >> 16) | (w15 << 48)) & mask;
+ out[17] = ((w15 >> 9) | (w16 << 55)) & mask;
+ out[18] = (w16 >> 2) & mask;
+ out[19] = ((w16 >> 59) | (w17 << 5)) & mask;
+ out[20] = ((w17 >> 52) | (w18 << 12)) & mask;
+ out[21] = ((w18 >> 45) | (w19 << 19)) & mask;
+ out[22] = ((w19 >> 38) | (w20 << 26)) & mask;
+ out[23] = ((w20 >> 31) | (w21 << 33)) & mask;
+ out[24] = ((w21 >> 24) | (w22 << 40)) & mask;
+ out[25] = ((w22 >> 17) | (w23 << 47)) & mask;
+ out[26] = ((w23 >> 10) | (w24 << 54)) & mask;
+ out[27] = (w24 >> 3) & mask;
+ out[28] = ((w24 >> 60) | (w25 << 4)) & mask;
+ out[29] = ((w25 >> 53) | (w26 << 11)) & mask;
+ out[30] = ((w26 >> 46) | (w27 << 18)) & mask;
+ out[31] = ((w27 >> 39) | (w28 << 25)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack58_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 288230376151711743ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint64_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 8;
+ uint64_t w25 = util::SafeLoadAs<uint64_t>(in);
+ w25 = arrow::BitUtil::FromLittleEndian(w25);
+ in += 8;
+ uint64_t w26 = util::SafeLoadAs<uint64_t>(in);
+ w26 = arrow::BitUtil::FromLittleEndian(w26);
+ in += 8;
+ uint64_t w27 = util::SafeLoadAs<uint64_t>(in);
+ w27 = arrow::BitUtil::FromLittleEndian(w27);
+ in += 8;
+ uint64_t w28 = util::SafeLoadAs<uint64_t>(in);
+ w28 = arrow::BitUtil::FromLittleEndian(w28);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 58) | (w1 << 6)) & mask;
+ out[2] = ((w1 >> 52) | (w2 << 12)) & mask;
+ out[3] = ((w2 >> 46) | (w3 << 18)) & mask;
+ out[4] = ((w3 >> 40) | (w4 << 24)) & mask;
+ out[5] = ((w4 >> 34) | (w5 << 30)) & mask;
+ out[6] = ((w5 >> 28) | (w6 << 36)) & mask;
+ out[7] = ((w6 >> 22) | (w7 << 42)) & mask;
+ out[8] = ((w7 >> 16) | (w8 << 48)) & mask;
+ out[9] = ((w8 >> 10) | (w9 << 54)) & mask;
+ out[10] = (w9 >> 4) & mask;
+ out[11] = ((w9 >> 62) | (w10 << 2)) & mask;
+ out[12] = ((w10 >> 56) | (w11 << 8)) & mask;
+ out[13] = ((w11 >> 50) | (w12 << 14)) & mask;
+ out[14] = ((w12 >> 44) | (w13 << 20)) & mask;
+ out[15] = ((w13 >> 38) | (w14 << 26)) & mask;
+ out[16] = ((w14 >> 32) | (w15 << 32)) & mask;
+ out[17] = ((w15 >> 26) | (w16 << 38)) & mask;
+ out[18] = ((w16 >> 20) | (w17 << 44)) & mask;
+ out[19] = ((w17 >> 14) | (w18 << 50)) & mask;
+ out[20] = ((w18 >> 8) | (w19 << 56)) & mask;
+ out[21] = (w19 >> 2) & mask;
+ out[22] = ((w19 >> 60) | (w20 << 4)) & mask;
+ out[23] = ((w20 >> 54) | (w21 << 10)) & mask;
+ out[24] = ((w21 >> 48) | (w22 << 16)) & mask;
+ out[25] = ((w22 >> 42) | (w23 << 22)) & mask;
+ out[26] = ((w23 >> 36) | (w24 << 28)) & mask;
+ out[27] = ((w24 >> 30) | (w25 << 34)) & mask;
+ out[28] = ((w25 >> 24) | (w26 << 40)) & mask;
+ out[29] = ((w26 >> 18) | (w27 << 46)) & mask;
+ out[30] = ((w27 >> 12) | (w28 << 52)) & mask;
+ out[31] = w28 >> 6;
+
+ return in;
+}
+
+inline const uint8_t* unpack59_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 576460752303423487ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint64_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 8;
+ uint64_t w25 = util::SafeLoadAs<uint64_t>(in);
+ w25 = arrow::BitUtil::FromLittleEndian(w25);
+ in += 8;
+ uint64_t w26 = util::SafeLoadAs<uint64_t>(in);
+ w26 = arrow::BitUtil::FromLittleEndian(w26);
+ in += 8;
+ uint64_t w27 = util::SafeLoadAs<uint64_t>(in);
+ w27 = arrow::BitUtil::FromLittleEndian(w27);
+ in += 8;
+ uint64_t w28 = util::SafeLoadAs<uint64_t>(in);
+ w28 = arrow::BitUtil::FromLittleEndian(w28);
+ in += 8;
+ uint64_t w29 = util::SafeLoadAs<uint32_t>(in);
+ w29 = arrow::BitUtil::FromLittleEndian(w29);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 59) | (w1 << 5)) & mask;
+ out[2] = ((w1 >> 54) | (w2 << 10)) & mask;
+ out[3] = ((w2 >> 49) | (w3 << 15)) & mask;
+ out[4] = ((w3 >> 44) | (w4 << 20)) & mask;
+ out[5] = ((w4 >> 39) | (w5 << 25)) & mask;
+ out[6] = ((w5 >> 34) | (w6 << 30)) & mask;
+ out[7] = ((w6 >> 29) | (w7 << 35)) & mask;
+ out[8] = ((w7 >> 24) | (w8 << 40)) & mask;
+ out[9] = ((w8 >> 19) | (w9 << 45)) & mask;
+ out[10] = ((w9 >> 14) | (w10 << 50)) & mask;
+ out[11] = ((w10 >> 9) | (w11 << 55)) & mask;
+ out[12] = (w11 >> 4) & mask;
+ out[13] = ((w11 >> 63) | (w12 << 1)) & mask;
+ out[14] = ((w12 >> 58) | (w13 << 6)) & mask;
+ out[15] = ((w13 >> 53) | (w14 << 11)) & mask;
+ out[16] = ((w14 >> 48) | (w15 << 16)) & mask;
+ out[17] = ((w15 >> 43) | (w16 << 21)) & mask;
+ out[18] = ((w16 >> 38) | (w17 << 26)) & mask;
+ out[19] = ((w17 >> 33) | (w18 << 31)) & mask;
+ out[20] = ((w18 >> 28) | (w19 << 36)) & mask;
+ out[21] = ((w19 >> 23) | (w20 << 41)) & mask;
+ out[22] = ((w20 >> 18) | (w21 << 46)) & mask;
+ out[23] = ((w21 >> 13) | (w22 << 51)) & mask;
+ out[24] = ((w22 >> 8) | (w23 << 56)) & mask;
+ out[25] = (w23 >> 3) & mask;
+ out[26] = ((w23 >> 62) | (w24 << 2)) & mask;
+ out[27] = ((w24 >> 57) | (w25 << 7)) & mask;
+ out[28] = ((w25 >> 52) | (w26 << 12)) & mask;
+ out[29] = ((w26 >> 47) | (w27 << 17)) & mask;
+ out[30] = ((w27 >> 42) | (w28 << 22)) & mask;
+ out[31] = ((w28 >> 37) | (w29 << 27)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack60_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 1152921504606846975ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint64_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 8;
+ uint64_t w25 = util::SafeLoadAs<uint64_t>(in);
+ w25 = arrow::BitUtil::FromLittleEndian(w25);
+ in += 8;
+ uint64_t w26 = util::SafeLoadAs<uint64_t>(in);
+ w26 = arrow::BitUtil::FromLittleEndian(w26);
+ in += 8;
+ uint64_t w27 = util::SafeLoadAs<uint64_t>(in);
+ w27 = arrow::BitUtil::FromLittleEndian(w27);
+ in += 8;
+ uint64_t w28 = util::SafeLoadAs<uint64_t>(in);
+ w28 = arrow::BitUtil::FromLittleEndian(w28);
+ in += 8;
+ uint64_t w29 = util::SafeLoadAs<uint64_t>(in);
+ w29 = arrow::BitUtil::FromLittleEndian(w29);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 60) | (w1 << 4)) & mask;
+ out[2] = ((w1 >> 56) | (w2 << 8)) & mask;
+ out[3] = ((w2 >> 52) | (w3 << 12)) & mask;
+ out[4] = ((w3 >> 48) | (w4 << 16)) & mask;
+ out[5] = ((w4 >> 44) | (w5 << 20)) & mask;
+ out[6] = ((w5 >> 40) | (w6 << 24)) & mask;
+ out[7] = ((w6 >> 36) | (w7 << 28)) & mask;
+ out[8] = ((w7 >> 32) | (w8 << 32)) & mask;
+ out[9] = ((w8 >> 28) | (w9 << 36)) & mask;
+ out[10] = ((w9 >> 24) | (w10 << 40)) & mask;
+ out[11] = ((w10 >> 20) | (w11 << 44)) & mask;
+ out[12] = ((w11 >> 16) | (w12 << 48)) & mask;
+ out[13] = ((w12 >> 12) | (w13 << 52)) & mask;
+ out[14] = ((w13 >> 8) | (w14 << 56)) & mask;
+ out[15] = w14 >> 4;
+ out[16] = (w15)&mask;
+ out[17] = ((w15 >> 60) | (w16 << 4)) & mask;
+ out[18] = ((w16 >> 56) | (w17 << 8)) & mask;
+ out[19] = ((w17 >> 52) | (w18 << 12)) & mask;
+ out[20] = ((w18 >> 48) | (w19 << 16)) & mask;
+ out[21] = ((w19 >> 44) | (w20 << 20)) & mask;
+ out[22] = ((w20 >> 40) | (w21 << 24)) & mask;
+ out[23] = ((w21 >> 36) | (w22 << 28)) & mask;
+ out[24] = ((w22 >> 32) | (w23 << 32)) & mask;
+ out[25] = ((w23 >> 28) | (w24 << 36)) & mask;
+ out[26] = ((w24 >> 24) | (w25 << 40)) & mask;
+ out[27] = ((w25 >> 20) | (w26 << 44)) & mask;
+ out[28] = ((w26 >> 16) | (w27 << 48)) & mask;
+ out[29] = ((w27 >> 12) | (w28 << 52)) & mask;
+ out[30] = ((w28 >> 8) | (w29 << 56)) & mask;
+ out[31] = w29 >> 4;
+
+ return in;
+}
+
+inline const uint8_t* unpack61_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 2305843009213693951ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint64_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 8;
+ uint64_t w25 = util::SafeLoadAs<uint64_t>(in);
+ w25 = arrow::BitUtil::FromLittleEndian(w25);
+ in += 8;
+ uint64_t w26 = util::SafeLoadAs<uint64_t>(in);
+ w26 = arrow::BitUtil::FromLittleEndian(w26);
+ in += 8;
+ uint64_t w27 = util::SafeLoadAs<uint64_t>(in);
+ w27 = arrow::BitUtil::FromLittleEndian(w27);
+ in += 8;
+ uint64_t w28 = util::SafeLoadAs<uint64_t>(in);
+ w28 = arrow::BitUtil::FromLittleEndian(w28);
+ in += 8;
+ uint64_t w29 = util::SafeLoadAs<uint64_t>(in);
+ w29 = arrow::BitUtil::FromLittleEndian(w29);
+ in += 8;
+ uint64_t w30 = util::SafeLoadAs<uint32_t>(in);
+ w30 = arrow::BitUtil::FromLittleEndian(w30);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 61) | (w1 << 3)) & mask;
+ out[2] = ((w1 >> 58) | (w2 << 6)) & mask;
+ out[3] = ((w2 >> 55) | (w3 << 9)) & mask;
+ out[4] = ((w3 >> 52) | (w4 << 12)) & mask;
+ out[5] = ((w4 >> 49) | (w5 << 15)) & mask;
+ out[6] = ((w5 >> 46) | (w6 << 18)) & mask;
+ out[7] = ((w6 >> 43) | (w7 << 21)) & mask;
+ out[8] = ((w7 >> 40) | (w8 << 24)) & mask;
+ out[9] = ((w8 >> 37) | (w9 << 27)) & mask;
+ out[10] = ((w9 >> 34) | (w10 << 30)) & mask;
+ out[11] = ((w10 >> 31) | (w11 << 33)) & mask;
+ out[12] = ((w11 >> 28) | (w12 << 36)) & mask;
+ out[13] = ((w12 >> 25) | (w13 << 39)) & mask;
+ out[14] = ((w13 >> 22) | (w14 << 42)) & mask;
+ out[15] = ((w14 >> 19) | (w15 << 45)) & mask;
+ out[16] = ((w15 >> 16) | (w16 << 48)) & mask;
+ out[17] = ((w16 >> 13) | (w17 << 51)) & mask;
+ out[18] = ((w17 >> 10) | (w18 << 54)) & mask;
+ out[19] = ((w18 >> 7) | (w19 << 57)) & mask;
+ out[20] = ((w19 >> 4) | (w20 << 60)) & mask;
+ out[21] = (w20 >> 1) & mask;
+ out[22] = ((w20 >> 62) | (w21 << 2)) & mask;
+ out[23] = ((w21 >> 59) | (w22 << 5)) & mask;
+ out[24] = ((w22 >> 56) | (w23 << 8)) & mask;
+ out[25] = ((w23 >> 53) | (w24 << 11)) & mask;
+ out[26] = ((w24 >> 50) | (w25 << 14)) & mask;
+ out[27] = ((w25 >> 47) | (w26 << 17)) & mask;
+ out[28] = ((w26 >> 44) | (w27 << 20)) & mask;
+ out[29] = ((w27 >> 41) | (w28 << 23)) & mask;
+ out[30] = ((w28 >> 38) | (w29 << 26)) & mask;
+ out[31] = ((w29 >> 35) | (w30 << 29)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack62_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 4611686018427387903ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint64_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 8;
+ uint64_t w25 = util::SafeLoadAs<uint64_t>(in);
+ w25 = arrow::BitUtil::FromLittleEndian(w25);
+ in += 8;
+ uint64_t w26 = util::SafeLoadAs<uint64_t>(in);
+ w26 = arrow::BitUtil::FromLittleEndian(w26);
+ in += 8;
+ uint64_t w27 = util::SafeLoadAs<uint64_t>(in);
+ w27 = arrow::BitUtil::FromLittleEndian(w27);
+ in += 8;
+ uint64_t w28 = util::SafeLoadAs<uint64_t>(in);
+ w28 = arrow::BitUtil::FromLittleEndian(w28);
+ in += 8;
+ uint64_t w29 = util::SafeLoadAs<uint64_t>(in);
+ w29 = arrow::BitUtil::FromLittleEndian(w29);
+ in += 8;
+ uint64_t w30 = util::SafeLoadAs<uint64_t>(in);
+ w30 = arrow::BitUtil::FromLittleEndian(w30);
+ in += 8;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 62) | (w1 << 2)) & mask;
+ out[2] = ((w1 >> 60) | (w2 << 4)) & mask;
+ out[3] = ((w2 >> 58) | (w3 << 6)) & mask;
+ out[4] = ((w3 >> 56) | (w4 << 8)) & mask;
+ out[5] = ((w4 >> 54) | (w5 << 10)) & mask;
+ out[6] = ((w5 >> 52) | (w6 << 12)) & mask;
+ out[7] = ((w6 >> 50) | (w7 << 14)) & mask;
+ out[8] = ((w7 >> 48) | (w8 << 16)) & mask;
+ out[9] = ((w8 >> 46) | (w9 << 18)) & mask;
+ out[10] = ((w9 >> 44) | (w10 << 20)) & mask;
+ out[11] = ((w10 >> 42) | (w11 << 22)) & mask;
+ out[12] = ((w11 >> 40) | (w12 << 24)) & mask;
+ out[13] = ((w12 >> 38) | (w13 << 26)) & mask;
+ out[14] = ((w13 >> 36) | (w14 << 28)) & mask;
+ out[15] = ((w14 >> 34) | (w15 << 30)) & mask;
+ out[16] = ((w15 >> 32) | (w16 << 32)) & mask;
+ out[17] = ((w16 >> 30) | (w17 << 34)) & mask;
+ out[18] = ((w17 >> 28) | (w18 << 36)) & mask;
+ out[19] = ((w18 >> 26) | (w19 << 38)) & mask;
+ out[20] = ((w19 >> 24) | (w20 << 40)) & mask;
+ out[21] = ((w20 >> 22) | (w21 << 42)) & mask;
+ out[22] = ((w21 >> 20) | (w22 << 44)) & mask;
+ out[23] = ((w22 >> 18) | (w23 << 46)) & mask;
+ out[24] = ((w23 >> 16) | (w24 << 48)) & mask;
+ out[25] = ((w24 >> 14) | (w25 << 50)) & mask;
+ out[26] = ((w25 >> 12) | (w26 << 52)) & mask;
+ out[27] = ((w26 >> 10) | (w27 << 54)) & mask;
+ out[28] = ((w27 >> 8) | (w28 << 56)) & mask;
+ out[29] = ((w28 >> 6) | (w29 << 58)) & mask;
+ out[30] = ((w29 >> 4) | (w30 << 60)) & mask;
+ out[31] = w30 >> 2;
+
+ return in;
+}
+
+inline const uint8_t* unpack63_64(const uint8_t* in, uint64_t* out) {
+ const uint64_t mask = 9223372036854775807ULL;
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint64_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 8;
+ uint64_t w25 = util::SafeLoadAs<uint64_t>(in);
+ w25 = arrow::BitUtil::FromLittleEndian(w25);
+ in += 8;
+ uint64_t w26 = util::SafeLoadAs<uint64_t>(in);
+ w26 = arrow::BitUtil::FromLittleEndian(w26);
+ in += 8;
+ uint64_t w27 = util::SafeLoadAs<uint64_t>(in);
+ w27 = arrow::BitUtil::FromLittleEndian(w27);
+ in += 8;
+ uint64_t w28 = util::SafeLoadAs<uint64_t>(in);
+ w28 = arrow::BitUtil::FromLittleEndian(w28);
+ in += 8;
+ uint64_t w29 = util::SafeLoadAs<uint64_t>(in);
+ w29 = arrow::BitUtil::FromLittleEndian(w29);
+ in += 8;
+ uint64_t w30 = util::SafeLoadAs<uint64_t>(in);
+ w30 = arrow::BitUtil::FromLittleEndian(w30);
+ in += 8;
+ uint64_t w31 = util::SafeLoadAs<uint32_t>(in);
+ w31 = arrow::BitUtil::FromLittleEndian(w31);
+ in += 4;
+ out[0] = (w0)&mask;
+ out[1] = ((w0 >> 63) | (w1 << 1)) & mask;
+ out[2] = ((w1 >> 62) | (w2 << 2)) & mask;
+ out[3] = ((w2 >> 61) | (w3 << 3)) & mask;
+ out[4] = ((w3 >> 60) | (w4 << 4)) & mask;
+ out[5] = ((w4 >> 59) | (w5 << 5)) & mask;
+ out[6] = ((w5 >> 58) | (w6 << 6)) & mask;
+ out[7] = ((w6 >> 57) | (w7 << 7)) & mask;
+ out[8] = ((w7 >> 56) | (w8 << 8)) & mask;
+ out[9] = ((w8 >> 55) | (w9 << 9)) & mask;
+ out[10] = ((w9 >> 54) | (w10 << 10)) & mask;
+ out[11] = ((w10 >> 53) | (w11 << 11)) & mask;
+ out[12] = ((w11 >> 52) | (w12 << 12)) & mask;
+ out[13] = ((w12 >> 51) | (w13 << 13)) & mask;
+ out[14] = ((w13 >> 50) | (w14 << 14)) & mask;
+ out[15] = ((w14 >> 49) | (w15 << 15)) & mask;
+ out[16] = ((w15 >> 48) | (w16 << 16)) & mask;
+ out[17] = ((w16 >> 47) | (w17 << 17)) & mask;
+ out[18] = ((w17 >> 46) | (w18 << 18)) & mask;
+ out[19] = ((w18 >> 45) | (w19 << 19)) & mask;
+ out[20] = ((w19 >> 44) | (w20 << 20)) & mask;
+ out[21] = ((w20 >> 43) | (w21 << 21)) & mask;
+ out[22] = ((w21 >> 42) | (w22 << 22)) & mask;
+ out[23] = ((w22 >> 41) | (w23 << 23)) & mask;
+ out[24] = ((w23 >> 40) | (w24 << 24)) & mask;
+ out[25] = ((w24 >> 39) | (w25 << 25)) & mask;
+ out[26] = ((w25 >> 38) | (w26 << 26)) & mask;
+ out[27] = ((w26 >> 37) | (w27 << 27)) & mask;
+ out[28] = ((w27 >> 36) | (w28 << 28)) & mask;
+ out[29] = ((w28 >> 35) | (w29 << 29)) & mask;
+ out[30] = ((w29 >> 34) | (w30 << 30)) & mask;
+ out[31] = ((w30 >> 33) | (w31 << 31)) & mask;
+
+ return in;
+}
+
+inline const uint8_t* unpack64_64(const uint8_t* in, uint64_t* out) {
+ uint64_t w0 = util::SafeLoadAs<uint64_t>(in);
+ w0 = arrow::BitUtil::FromLittleEndian(w0);
+ in += 8;
+ uint64_t w1 = util::SafeLoadAs<uint64_t>(in);
+ w1 = arrow::BitUtil::FromLittleEndian(w1);
+ in += 8;
+ uint64_t w2 = util::SafeLoadAs<uint64_t>(in);
+ w2 = arrow::BitUtil::FromLittleEndian(w2);
+ in += 8;
+ uint64_t w3 = util::SafeLoadAs<uint64_t>(in);
+ w3 = arrow::BitUtil::FromLittleEndian(w3);
+ in += 8;
+ uint64_t w4 = util::SafeLoadAs<uint64_t>(in);
+ w4 = arrow::BitUtil::FromLittleEndian(w4);
+ in += 8;
+ uint64_t w5 = util::SafeLoadAs<uint64_t>(in);
+ w5 = arrow::BitUtil::FromLittleEndian(w5);
+ in += 8;
+ uint64_t w6 = util::SafeLoadAs<uint64_t>(in);
+ w6 = arrow::BitUtil::FromLittleEndian(w6);
+ in += 8;
+ uint64_t w7 = util::SafeLoadAs<uint64_t>(in);
+ w7 = arrow::BitUtil::FromLittleEndian(w7);
+ in += 8;
+ uint64_t w8 = util::SafeLoadAs<uint64_t>(in);
+ w8 = arrow::BitUtil::FromLittleEndian(w8);
+ in += 8;
+ uint64_t w9 = util::SafeLoadAs<uint64_t>(in);
+ w9 = arrow::BitUtil::FromLittleEndian(w9);
+ in += 8;
+ uint64_t w10 = util::SafeLoadAs<uint64_t>(in);
+ w10 = arrow::BitUtil::FromLittleEndian(w10);
+ in += 8;
+ uint64_t w11 = util::SafeLoadAs<uint64_t>(in);
+ w11 = arrow::BitUtil::FromLittleEndian(w11);
+ in += 8;
+ uint64_t w12 = util::SafeLoadAs<uint64_t>(in);
+ w12 = arrow::BitUtil::FromLittleEndian(w12);
+ in += 8;
+ uint64_t w13 = util::SafeLoadAs<uint64_t>(in);
+ w13 = arrow::BitUtil::FromLittleEndian(w13);
+ in += 8;
+ uint64_t w14 = util::SafeLoadAs<uint64_t>(in);
+ w14 = arrow::BitUtil::FromLittleEndian(w14);
+ in += 8;
+ uint64_t w15 = util::SafeLoadAs<uint64_t>(in);
+ w15 = arrow::BitUtil::FromLittleEndian(w15);
+ in += 8;
+ uint64_t w16 = util::SafeLoadAs<uint64_t>(in);
+ w16 = arrow::BitUtil::FromLittleEndian(w16);
+ in += 8;
+ uint64_t w17 = util::SafeLoadAs<uint64_t>(in);
+ w17 = arrow::BitUtil::FromLittleEndian(w17);
+ in += 8;
+ uint64_t w18 = util::SafeLoadAs<uint64_t>(in);
+ w18 = arrow::BitUtil::FromLittleEndian(w18);
+ in += 8;
+ uint64_t w19 = util::SafeLoadAs<uint64_t>(in);
+ w19 = arrow::BitUtil::FromLittleEndian(w19);
+ in += 8;
+ uint64_t w20 = util::SafeLoadAs<uint64_t>(in);
+ w20 = arrow::BitUtil::FromLittleEndian(w20);
+ in += 8;
+ uint64_t w21 = util::SafeLoadAs<uint64_t>(in);
+ w21 = arrow::BitUtil::FromLittleEndian(w21);
+ in += 8;
+ uint64_t w22 = util::SafeLoadAs<uint64_t>(in);
+ w22 = arrow::BitUtil::FromLittleEndian(w22);
+ in += 8;
+ uint64_t w23 = util::SafeLoadAs<uint64_t>(in);
+ w23 = arrow::BitUtil::FromLittleEndian(w23);
+ in += 8;
+ uint64_t w24 = util::SafeLoadAs<uint64_t>(in);
+ w24 = arrow::BitUtil::FromLittleEndian(w24);
+ in += 8;
+ uint64_t w25 = util::SafeLoadAs<uint64_t>(in);
+ w25 = arrow::BitUtil::FromLittleEndian(w25);
+ in += 8;
+ uint64_t w26 = util::SafeLoadAs<uint64_t>(in);
+ w26 = arrow::BitUtil::FromLittleEndian(w26);
+ in += 8;
+ uint64_t w27 = util::SafeLoadAs<uint64_t>(in);
+ w27 = arrow::BitUtil::FromLittleEndian(w27);
+ in += 8;
+ uint64_t w28 = util::SafeLoadAs<uint64_t>(in);
+ w28 = arrow::BitUtil::FromLittleEndian(w28);
+ in += 8;
+ uint64_t w29 = util::SafeLoadAs<uint64_t>(in);
+ w29 = arrow::BitUtil::FromLittleEndian(w29);
+ in += 8;
+ uint64_t w30 = util::SafeLoadAs<uint64_t>(in);
+ w30 = arrow::BitUtil::FromLittleEndian(w30);
+ in += 8;
+ uint64_t w31 = util::SafeLoadAs<uint64_t>(in);
+ w31 = arrow::BitUtil::FromLittleEndian(w31);
+ in += 8;
+ out[0] = w0;
+ out[1] = w1;
+ out[2] = w2;
+ out[3] = w3;
+ out[4] = w4;
+ out[5] = w5;
+ out[6] = w6;
+ out[7] = w7;
+ out[8] = w8;
+ out[9] = w9;
+ out[10] = w10;
+ out[11] = w11;
+ out[12] = w12;
+ out[13] = w13;
+ out[14] = w14;
+ out[15] = w15;
+ out[16] = w16;
+ out[17] = w17;
+ out[18] = w18;
+ out[19] = w19;
+ out[20] = w20;
+ out[21] = w21;
+ out[22] = w22;
+ out[23] = w23;
+ out[24] = w24;
+ out[25] = w25;
+ out[26] = w26;
+ out[27] = w27;
+ out[28] = w28;
+ out[29] = w29;
+ out[30] = w30;
+ out[31] = w31;
+
+ return in;
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bpacking_avx2.cc b/src/arrow/cpp/src/arrow/util/bpacking_avx2.cc
new file mode 100644
index 000000000..5a3a7bad3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking_avx2.cc
@@ -0,0 +1,31 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/bpacking_avx2.h"
+#include "arrow/util/bpacking_simd256_generated.h"
+#include "arrow/util/bpacking_simd_internal.h"
+
+namespace arrow {
+namespace internal {
+
+int unpack32_avx2(const uint32_t* in, uint32_t* out, int batch_size, int num_bits) {
+ return unpack32_specialized<UnpackBits256<DispatchLevel::AVX2>>(in, out, batch_size,
+ num_bits);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bpacking_avx2.h b/src/arrow/cpp/src/arrow/util/bpacking_avx2.h
new file mode 100644
index 000000000..7a7d8bf8c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking_avx2.h
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <stdint.h>
+
+namespace arrow {
+namespace internal {
+
+int unpack32_avx2(const uint32_t* in, uint32_t* out, int batch_size, int num_bits);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bpacking_avx512.cc b/src/arrow/cpp/src/arrow/util/bpacking_avx512.cc
new file mode 100644
index 000000000..08ccd3fcd
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking_avx512.cc
@@ -0,0 +1,31 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/bpacking_avx512.h"
+#include "arrow/util/bpacking_simd512_generated.h"
+#include "arrow/util/bpacking_simd_internal.h"
+
+namespace arrow {
+namespace internal {
+
+int unpack32_avx512(const uint32_t* in, uint32_t* out, int batch_size, int num_bits) {
+ return unpack32_specialized<UnpackBits512<DispatchLevel::AVX512>>(in, out, batch_size,
+ num_bits);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bpacking_avx512.h b/src/arrow/cpp/src/arrow/util/bpacking_avx512.h
new file mode 100644
index 000000000..96723f803
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking_avx512.h
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <stdint.h>
+
+namespace arrow {
+namespace internal {
+
+int unpack32_avx512(const uint32_t* in, uint32_t* out, int batch_size, int num_bits);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bpacking_default.h b/src/arrow/cpp/src/arrow/util/bpacking_default.h
new file mode 100644
index 000000000..d2516effa
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking_default.h
@@ -0,0 +1,4251 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This file was modified from its original version for inclusion in parquet-cpp.
+// Original source:
+// https://github.com/lemire/FrameOfReference/blob/6ccaf9e97160f9a3b299e23a8ef739e711ef0c71/src/bpacking.cpp
+// The original copyright notice follows.
+
+// This code is released under the
+// Apache License Version 2.0 http://www.apache.org/licenses/.
+// (c) Daniel Lemire 2013
+
+#pragma once
+
+#include "arrow/util/bit_util.h"
+#include "arrow/util/ubsan.h"
+
+namespace arrow {
+namespace internal {
+
+inline const uint32_t* unpack1_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) & 1;
+ out++;
+ *out = (inl >> 1) & 1;
+ out++;
+ *out = (inl >> 2) & 1;
+ out++;
+ *out = (inl >> 3) & 1;
+ out++;
+ *out = (inl >> 4) & 1;
+ out++;
+ *out = (inl >> 5) & 1;
+ out++;
+ *out = (inl >> 6) & 1;
+ out++;
+ *out = (inl >> 7) & 1;
+ out++;
+ *out = (inl >> 8) & 1;
+ out++;
+ *out = (inl >> 9) & 1;
+ out++;
+ *out = (inl >> 10) & 1;
+ out++;
+ *out = (inl >> 11) & 1;
+ out++;
+ *out = (inl >> 12) & 1;
+ out++;
+ *out = (inl >> 13) & 1;
+ out++;
+ *out = (inl >> 14) & 1;
+ out++;
+ *out = (inl >> 15) & 1;
+ out++;
+ *out = (inl >> 16) & 1;
+ out++;
+ *out = (inl >> 17) & 1;
+ out++;
+ *out = (inl >> 18) & 1;
+ out++;
+ *out = (inl >> 19) & 1;
+ out++;
+ *out = (inl >> 20) & 1;
+ out++;
+ *out = (inl >> 21) & 1;
+ out++;
+ *out = (inl >> 22) & 1;
+ out++;
+ *out = (inl >> 23) & 1;
+ out++;
+ *out = (inl >> 24) & 1;
+ out++;
+ *out = (inl >> 25) & 1;
+ out++;
+ *out = (inl >> 26) & 1;
+ out++;
+ *out = (inl >> 27) & 1;
+ out++;
+ *out = (inl >> 28) & 1;
+ out++;
+ *out = (inl >> 29) & 1;
+ out++;
+ *out = (inl >> 30) & 1;
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack2_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 2);
+ out++;
+ *out = (inl >> 2) % (1U << 2);
+ out++;
+ *out = (inl >> 4) % (1U << 2);
+ out++;
+ *out = (inl >> 6) % (1U << 2);
+ out++;
+ *out = (inl >> 8) % (1U << 2);
+ out++;
+ *out = (inl >> 10) % (1U << 2);
+ out++;
+ *out = (inl >> 12) % (1U << 2);
+ out++;
+ *out = (inl >> 14) % (1U << 2);
+ out++;
+ *out = (inl >> 16) % (1U << 2);
+ out++;
+ *out = (inl >> 18) % (1U << 2);
+ out++;
+ *out = (inl >> 20) % (1U << 2);
+ out++;
+ *out = (inl >> 22) % (1U << 2);
+ out++;
+ *out = (inl >> 24) % (1U << 2);
+ out++;
+ *out = (inl >> 26) % (1U << 2);
+ out++;
+ *out = (inl >> 28) % (1U << 2);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 2);
+ out++;
+ *out = (inl >> 2) % (1U << 2);
+ out++;
+ *out = (inl >> 4) % (1U << 2);
+ out++;
+ *out = (inl >> 6) % (1U << 2);
+ out++;
+ *out = (inl >> 8) % (1U << 2);
+ out++;
+ *out = (inl >> 10) % (1U << 2);
+ out++;
+ *out = (inl >> 12) % (1U << 2);
+ out++;
+ *out = (inl >> 14) % (1U << 2);
+ out++;
+ *out = (inl >> 16) % (1U << 2);
+ out++;
+ *out = (inl >> 18) % (1U << 2);
+ out++;
+ *out = (inl >> 20) % (1U << 2);
+ out++;
+ *out = (inl >> 22) % (1U << 2);
+ out++;
+ *out = (inl >> 24) % (1U << 2);
+ out++;
+ *out = (inl >> 26) % (1U << 2);
+ out++;
+ *out = (inl >> 28) % (1U << 2);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack3_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 3);
+ out++;
+ *out = (inl >> 3) % (1U << 3);
+ out++;
+ *out = (inl >> 6) % (1U << 3);
+ out++;
+ *out = (inl >> 9) % (1U << 3);
+ out++;
+ *out = (inl >> 12) % (1U << 3);
+ out++;
+ *out = (inl >> 15) % (1U << 3);
+ out++;
+ *out = (inl >> 18) % (1U << 3);
+ out++;
+ *out = (inl >> 21) % (1U << 3);
+ out++;
+ *out = (inl >> 24) % (1U << 3);
+ out++;
+ *out = (inl >> 27) % (1U << 3);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 1)) << (3 - 1);
+ out++;
+ *out = (inl >> 1) % (1U << 3);
+ out++;
+ *out = (inl >> 4) % (1U << 3);
+ out++;
+ *out = (inl >> 7) % (1U << 3);
+ out++;
+ *out = (inl >> 10) % (1U << 3);
+ out++;
+ *out = (inl >> 13) % (1U << 3);
+ out++;
+ *out = (inl >> 16) % (1U << 3);
+ out++;
+ *out = (inl >> 19) % (1U << 3);
+ out++;
+ *out = (inl >> 22) % (1U << 3);
+ out++;
+ *out = (inl >> 25) % (1U << 3);
+ out++;
+ *out = (inl >> 28) % (1U << 3);
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (3 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 3);
+ out++;
+ *out = (inl >> 5) % (1U << 3);
+ out++;
+ *out = (inl >> 8) % (1U << 3);
+ out++;
+ *out = (inl >> 11) % (1U << 3);
+ out++;
+ *out = (inl >> 14) % (1U << 3);
+ out++;
+ *out = (inl >> 17) % (1U << 3);
+ out++;
+ *out = (inl >> 20) % (1U << 3);
+ out++;
+ *out = (inl >> 23) % (1U << 3);
+ out++;
+ *out = (inl >> 26) % (1U << 3);
+ out++;
+ *out = (inl >> 29);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack4_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 4);
+ out++;
+ *out = (inl >> 4) % (1U << 4);
+ out++;
+ *out = (inl >> 8) % (1U << 4);
+ out++;
+ *out = (inl >> 12) % (1U << 4);
+ out++;
+ *out = (inl >> 16) % (1U << 4);
+ out++;
+ *out = (inl >> 20) % (1U << 4);
+ out++;
+ *out = (inl >> 24) % (1U << 4);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 4);
+ out++;
+ *out = (inl >> 4) % (1U << 4);
+ out++;
+ *out = (inl >> 8) % (1U << 4);
+ out++;
+ *out = (inl >> 12) % (1U << 4);
+ out++;
+ *out = (inl >> 16) % (1U << 4);
+ out++;
+ *out = (inl >> 20) % (1U << 4);
+ out++;
+ *out = (inl >> 24) % (1U << 4);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 4);
+ out++;
+ *out = (inl >> 4) % (1U << 4);
+ out++;
+ *out = (inl >> 8) % (1U << 4);
+ out++;
+ *out = (inl >> 12) % (1U << 4);
+ out++;
+ *out = (inl >> 16) % (1U << 4);
+ out++;
+ *out = (inl >> 20) % (1U << 4);
+ out++;
+ *out = (inl >> 24) % (1U << 4);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 4);
+ out++;
+ *out = (inl >> 4) % (1U << 4);
+ out++;
+ *out = (inl >> 8) % (1U << 4);
+ out++;
+ *out = (inl >> 12) % (1U << 4);
+ out++;
+ *out = (inl >> 16) % (1U << 4);
+ out++;
+ *out = (inl >> 20) % (1U << 4);
+ out++;
+ *out = (inl >> 24) % (1U << 4);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack5_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 5);
+ out++;
+ *out = (inl >> 5) % (1U << 5);
+ out++;
+ *out = (inl >> 10) % (1U << 5);
+ out++;
+ *out = (inl >> 15) % (1U << 5);
+ out++;
+ *out = (inl >> 20) % (1U << 5);
+ out++;
+ *out = (inl >> 25) % (1U << 5);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 3)) << (5 - 3);
+ out++;
+ *out = (inl >> 3) % (1U << 5);
+ out++;
+ *out = (inl >> 8) % (1U << 5);
+ out++;
+ *out = (inl >> 13) % (1U << 5);
+ out++;
+ *out = (inl >> 18) % (1U << 5);
+ out++;
+ *out = (inl >> 23) % (1U << 5);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 1)) << (5 - 1);
+ out++;
+ *out = (inl >> 1) % (1U << 5);
+ out++;
+ *out = (inl >> 6) % (1U << 5);
+ out++;
+ *out = (inl >> 11) % (1U << 5);
+ out++;
+ *out = (inl >> 16) % (1U << 5);
+ out++;
+ *out = (inl >> 21) % (1U << 5);
+ out++;
+ *out = (inl >> 26) % (1U << 5);
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (5 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 5);
+ out++;
+ *out = (inl >> 9) % (1U << 5);
+ out++;
+ *out = (inl >> 14) % (1U << 5);
+ out++;
+ *out = (inl >> 19) % (1U << 5);
+ out++;
+ *out = (inl >> 24) % (1U << 5);
+ out++;
+ *out = (inl >> 29);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (5 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 5);
+ out++;
+ *out = (inl >> 7) % (1U << 5);
+ out++;
+ *out = (inl >> 12) % (1U << 5);
+ out++;
+ *out = (inl >> 17) % (1U << 5);
+ out++;
+ *out = (inl >> 22) % (1U << 5);
+ out++;
+ *out = (inl >> 27);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack6_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 6);
+ out++;
+ *out = (inl >> 6) % (1U << 6);
+ out++;
+ *out = (inl >> 12) % (1U << 6);
+ out++;
+ *out = (inl >> 18) % (1U << 6);
+ out++;
+ *out = (inl >> 24) % (1U << 6);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (6 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 6);
+ out++;
+ *out = (inl >> 10) % (1U << 6);
+ out++;
+ *out = (inl >> 16) % (1U << 6);
+ out++;
+ *out = (inl >> 22) % (1U << 6);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (6 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 6);
+ out++;
+ *out = (inl >> 8) % (1U << 6);
+ out++;
+ *out = (inl >> 14) % (1U << 6);
+ out++;
+ *out = (inl >> 20) % (1U << 6);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 6);
+ out++;
+ *out = (inl >> 6) % (1U << 6);
+ out++;
+ *out = (inl >> 12) % (1U << 6);
+ out++;
+ *out = (inl >> 18) % (1U << 6);
+ out++;
+ *out = (inl >> 24) % (1U << 6);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (6 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 6);
+ out++;
+ *out = (inl >> 10) % (1U << 6);
+ out++;
+ *out = (inl >> 16) % (1U << 6);
+ out++;
+ *out = (inl >> 22) % (1U << 6);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (6 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 6);
+ out++;
+ *out = (inl >> 8) % (1U << 6);
+ out++;
+ *out = (inl >> 14) % (1U << 6);
+ out++;
+ *out = (inl >> 20) % (1U << 6);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack7_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 7);
+ out++;
+ *out = (inl >> 7) % (1U << 7);
+ out++;
+ *out = (inl >> 14) % (1U << 7);
+ out++;
+ *out = (inl >> 21) % (1U << 7);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 3)) << (7 - 3);
+ out++;
+ *out = (inl >> 3) % (1U << 7);
+ out++;
+ *out = (inl >> 10) % (1U << 7);
+ out++;
+ *out = (inl >> 17) % (1U << 7);
+ out++;
+ *out = (inl >> 24) % (1U << 7);
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (7 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 7);
+ out++;
+ *out = (inl >> 13) % (1U << 7);
+ out++;
+ *out = (inl >> 20) % (1U << 7);
+ out++;
+ *out = (inl >> 27);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (7 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 7);
+ out++;
+ *out = (inl >> 9) % (1U << 7);
+ out++;
+ *out = (inl >> 16) % (1U << 7);
+ out++;
+ *out = (inl >> 23) % (1U << 7);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 5)) << (7 - 5);
+ out++;
+ *out = (inl >> 5) % (1U << 7);
+ out++;
+ *out = (inl >> 12) % (1U << 7);
+ out++;
+ *out = (inl >> 19) % (1U << 7);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 1)) << (7 - 1);
+ out++;
+ *out = (inl >> 1) % (1U << 7);
+ out++;
+ *out = (inl >> 8) % (1U << 7);
+ out++;
+ *out = (inl >> 15) % (1U << 7);
+ out++;
+ *out = (inl >> 22) % (1U << 7);
+ out++;
+ *out = (inl >> 29);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (7 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 7);
+ out++;
+ *out = (inl >> 11) % (1U << 7);
+ out++;
+ *out = (inl >> 18) % (1U << 7);
+ out++;
+ *out = (inl >> 25);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack8_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 8);
+ out++;
+ *out = (inl >> 8) % (1U << 8);
+ out++;
+ *out = (inl >> 16) % (1U << 8);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 8);
+ out++;
+ *out = (inl >> 8) % (1U << 8);
+ out++;
+ *out = (inl >> 16) % (1U << 8);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 8);
+ out++;
+ *out = (inl >> 8) % (1U << 8);
+ out++;
+ *out = (inl >> 16) % (1U << 8);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 8);
+ out++;
+ *out = (inl >> 8) % (1U << 8);
+ out++;
+ *out = (inl >> 16) % (1U << 8);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 8);
+ out++;
+ *out = (inl >> 8) % (1U << 8);
+ out++;
+ *out = (inl >> 16) % (1U << 8);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 8);
+ out++;
+ *out = (inl >> 8) % (1U << 8);
+ out++;
+ *out = (inl >> 16) % (1U << 8);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 8);
+ out++;
+ *out = (inl >> 8) % (1U << 8);
+ out++;
+ *out = (inl >> 16) % (1U << 8);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 8);
+ out++;
+ *out = (inl >> 8) % (1U << 8);
+ out++;
+ *out = (inl >> 16) % (1U << 8);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack9_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 9);
+ out++;
+ *out = (inl >> 9) % (1U << 9);
+ out++;
+ *out = (inl >> 18) % (1U << 9);
+ out++;
+ *out = (inl >> 27);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (9 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 9);
+ out++;
+ *out = (inl >> 13) % (1U << 9);
+ out++;
+ *out = (inl >> 22) % (1U << 9);
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (9 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 9);
+ out++;
+ *out = (inl >> 17) % (1U << 9);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 3)) << (9 - 3);
+ out++;
+ *out = (inl >> 3) % (1U << 9);
+ out++;
+ *out = (inl >> 12) % (1U << 9);
+ out++;
+ *out = (inl >> 21) % (1U << 9);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 7)) << (9 - 7);
+ out++;
+ *out = (inl >> 7) % (1U << 9);
+ out++;
+ *out = (inl >> 16) % (1U << 9);
+ out++;
+ *out = (inl >> 25);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (9 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 9);
+ out++;
+ *out = (inl >> 11) % (1U << 9);
+ out++;
+ *out = (inl >> 20) % (1U << 9);
+ out++;
+ *out = (inl >> 29);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (9 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 9);
+ out++;
+ *out = (inl >> 15) % (1U << 9);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 1)) << (9 - 1);
+ out++;
+ *out = (inl >> 1) % (1U << 9);
+ out++;
+ *out = (inl >> 10) % (1U << 9);
+ out++;
+ *out = (inl >> 19) % (1U << 9);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 5)) << (9 - 5);
+ out++;
+ *out = (inl >> 5) % (1U << 9);
+ out++;
+ *out = (inl >> 14) % (1U << 9);
+ out++;
+ *out = (inl >> 23);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack10_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 10);
+ out++;
+ *out = (inl >> 10) % (1U << 10);
+ out++;
+ *out = (inl >> 20) % (1U << 10);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (10 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 10);
+ out++;
+ *out = (inl >> 18) % (1U << 10);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (10 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 10);
+ out++;
+ *out = (inl >> 16) % (1U << 10);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (10 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 10);
+ out++;
+ *out = (inl >> 14) % (1U << 10);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (10 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 10);
+ out++;
+ *out = (inl >> 12) % (1U << 10);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 10);
+ out++;
+ *out = (inl >> 10) % (1U << 10);
+ out++;
+ *out = (inl >> 20) % (1U << 10);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (10 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 10);
+ out++;
+ *out = (inl >> 18) % (1U << 10);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (10 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 10);
+ out++;
+ *out = (inl >> 16) % (1U << 10);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (10 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 10);
+ out++;
+ *out = (inl >> 14) % (1U << 10);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (10 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 10);
+ out++;
+ *out = (inl >> 12) % (1U << 10);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack11_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 11);
+ out++;
+ *out = (inl >> 11) % (1U << 11);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 1)) << (11 - 1);
+ out++;
+ *out = (inl >> 1) % (1U << 11);
+ out++;
+ *out = (inl >> 12) % (1U << 11);
+ out++;
+ *out = (inl >> 23);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (11 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 11);
+ out++;
+ *out = (inl >> 13) % (1U << 11);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 3)) << (11 - 3);
+ out++;
+ *out = (inl >> 3) % (1U << 11);
+ out++;
+ *out = (inl >> 14) % (1U << 11);
+ out++;
+ *out = (inl >> 25);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (11 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 11);
+ out++;
+ *out = (inl >> 15) % (1U << 11);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 5)) << (11 - 5);
+ out++;
+ *out = (inl >> 5) % (1U << 11);
+ out++;
+ *out = (inl >> 16) % (1U << 11);
+ out++;
+ *out = (inl >> 27);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (11 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 11);
+ out++;
+ *out = (inl >> 17) % (1U << 11);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 7)) << (11 - 7);
+ out++;
+ *out = (inl >> 7) % (1U << 11);
+ out++;
+ *out = (inl >> 18) % (1U << 11);
+ out++;
+ *out = (inl >> 29);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (11 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 11);
+ out++;
+ *out = (inl >> 19) % (1U << 11);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 9)) << (11 - 9);
+ out++;
+ *out = (inl >> 9) % (1U << 11);
+ out++;
+ *out = (inl >> 20) % (1U << 11);
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (11 - 10);
+ out++;
+ *out = (inl >> 10) % (1U << 11);
+ out++;
+ *out = (inl >> 21);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack12_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 12);
+ out++;
+ *out = (inl >> 12) % (1U << 12);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (12 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 12);
+ out++;
+ *out = (inl >> 16) % (1U << 12);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (12 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 12);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 12);
+ out++;
+ *out = (inl >> 12) % (1U << 12);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (12 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 12);
+ out++;
+ *out = (inl >> 16) % (1U << 12);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (12 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 12);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 12);
+ out++;
+ *out = (inl >> 12) % (1U << 12);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (12 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 12);
+ out++;
+ *out = (inl >> 16) % (1U << 12);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (12 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 12);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 12);
+ out++;
+ *out = (inl >> 12) % (1U << 12);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (12 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 12);
+ out++;
+ *out = (inl >> 16) % (1U << 12);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (12 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 12);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack13_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 13);
+ out++;
+ *out = (inl >> 13) % (1U << 13);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 7)) << (13 - 7);
+ out++;
+ *out = (inl >> 7) % (1U << 13);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 1)) << (13 - 1);
+ out++;
+ *out = (inl >> 1) % (1U << 13);
+ out++;
+ *out = (inl >> 14) % (1U << 13);
+ out++;
+ *out = (inl >> 27);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (13 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 13);
+ out++;
+ *out = (inl >> 21);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (13 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 13);
+ out++;
+ *out = (inl >> 15) % (1U << 13);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 9)) << (13 - 9);
+ out++;
+ *out = (inl >> 9) % (1U << 13);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 3)) << (13 - 3);
+ out++;
+ *out = (inl >> 3) % (1U << 13);
+ out++;
+ *out = (inl >> 16) % (1U << 13);
+ out++;
+ *out = (inl >> 29);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (13 - 10);
+ out++;
+ *out = (inl >> 10) % (1U << 13);
+ out++;
+ *out = (inl >> 23);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (13 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 13);
+ out++;
+ *out = (inl >> 17) % (1U << 13);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 11)) << (13 - 11);
+ out++;
+ *out = (inl >> 11) % (1U << 13);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 5)) << (13 - 5);
+ out++;
+ *out = (inl >> 5) % (1U << 13);
+ out++;
+ *out = (inl >> 18) % (1U << 13);
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (13 - 12);
+ out++;
+ *out = (inl >> 12) % (1U << 13);
+ out++;
+ *out = (inl >> 25);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (13 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 13);
+ out++;
+ *out = (inl >> 19);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack14_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 14);
+ out++;
+ *out = (inl >> 14) % (1U << 14);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (14 - 10);
+ out++;
+ *out = (inl >> 10) % (1U << 14);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (14 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 14);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (14 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 14);
+ out++;
+ *out = (inl >> 16) % (1U << 14);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (14 - 12);
+ out++;
+ *out = (inl >> 12) % (1U << 14);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (14 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 14);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (14 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 14);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 14);
+ out++;
+ *out = (inl >> 14) % (1U << 14);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (14 - 10);
+ out++;
+ *out = (inl >> 10) % (1U << 14);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (14 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 14);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (14 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 14);
+ out++;
+ *out = (inl >> 16) % (1U << 14);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (14 - 12);
+ out++;
+ *out = (inl >> 12) % (1U << 14);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (14 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 14);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (14 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 14);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack15_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 15);
+ out++;
+ *out = (inl >> 15) % (1U << 15);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 13)) << (15 - 13);
+ out++;
+ *out = (inl >> 13) % (1U << 15);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 11)) << (15 - 11);
+ out++;
+ *out = (inl >> 11) % (1U << 15);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 9)) << (15 - 9);
+ out++;
+ *out = (inl >> 9) % (1U << 15);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 7)) << (15 - 7);
+ out++;
+ *out = (inl >> 7) % (1U << 15);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 5)) << (15 - 5);
+ out++;
+ *out = (inl >> 5) % (1U << 15);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 3)) << (15 - 3);
+ out++;
+ *out = (inl >> 3) % (1U << 15);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 1)) << (15 - 1);
+ out++;
+ *out = (inl >> 1) % (1U << 15);
+ out++;
+ *out = (inl >> 16) % (1U << 15);
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (15 - 14);
+ out++;
+ *out = (inl >> 14) % (1U << 15);
+ out++;
+ *out = (inl >> 29);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (15 - 12);
+ out++;
+ *out = (inl >> 12) % (1U << 15);
+ out++;
+ *out = (inl >> 27);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (15 - 10);
+ out++;
+ *out = (inl >> 10) % (1U << 15);
+ out++;
+ *out = (inl >> 25);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (15 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 15);
+ out++;
+ *out = (inl >> 23);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (15 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 15);
+ out++;
+ *out = (inl >> 21);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (15 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 15);
+ out++;
+ *out = (inl >> 19);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (15 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 15);
+ out++;
+ *out = (inl >> 17);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack16_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack17_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 17);
+ out++;
+ *out = (inl >> 17);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (17 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 17);
+ out++;
+ *out = (inl >> 19);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (17 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 17);
+ out++;
+ *out = (inl >> 21);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (17 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 17);
+ out++;
+ *out = (inl >> 23);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (17 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 17);
+ out++;
+ *out = (inl >> 25);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (17 - 10);
+ out++;
+ *out = (inl >> 10) % (1U << 17);
+ out++;
+ *out = (inl >> 27);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (17 - 12);
+ out++;
+ *out = (inl >> 12) % (1U << 17);
+ out++;
+ *out = (inl >> 29);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (17 - 14);
+ out++;
+ *out = (inl >> 14) % (1U << 17);
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (17 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 1)) << (17 - 1);
+ out++;
+ *out = (inl >> 1) % (1U << 17);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 3)) << (17 - 3);
+ out++;
+ *out = (inl >> 3) % (1U << 17);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 5)) << (17 - 5);
+ out++;
+ *out = (inl >> 5) % (1U << 17);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 7)) << (17 - 7);
+ out++;
+ *out = (inl >> 7) % (1U << 17);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 9)) << (17 - 9);
+ out++;
+ *out = (inl >> 9) % (1U << 17);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 11)) << (17 - 11);
+ out++;
+ *out = (inl >> 11) % (1U << 17);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 13)) << (17 - 13);
+ out++;
+ *out = (inl >> 13) % (1U << 17);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 15)) << (17 - 15);
+ out++;
+ *out = (inl >> 15);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack18_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 18);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (18 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 18);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (18 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 18);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (18 - 12);
+ out++;
+ *out = (inl >> 12) % (1U << 18);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (18 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (18 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 18);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (18 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 18);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (18 - 10);
+ out++;
+ *out = (inl >> 10) % (1U << 18);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (18 - 14);
+ out++;
+ *out = (inl >> 14);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 18);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (18 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 18);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (18 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 18);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (18 - 12);
+ out++;
+ *out = (inl >> 12) % (1U << 18);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (18 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (18 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 18);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (18 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 18);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (18 - 10);
+ out++;
+ *out = (inl >> 10) % (1U << 18);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (18 - 14);
+ out++;
+ *out = (inl >> 14);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack19_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 19);
+ out++;
+ *out = (inl >> 19);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (19 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 19);
+ out++;
+ *out = (inl >> 25);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (19 - 12);
+ out++;
+ *out = (inl >> 12) % (1U << 19);
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 18)) << (19 - 18);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 5)) << (19 - 5);
+ out++;
+ *out = (inl >> 5) % (1U << 19);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 11)) << (19 - 11);
+ out++;
+ *out = (inl >> 11) % (1U << 19);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 17)) << (19 - 17);
+ out++;
+ *out = (inl >> 17);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (19 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 19);
+ out++;
+ *out = (inl >> 23);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (19 - 10);
+ out++;
+ *out = (inl >> 10) % (1U << 19);
+ out++;
+ *out = (inl >> 29);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (19 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 3)) << (19 - 3);
+ out++;
+ *out = (inl >> 3) % (1U << 19);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 9)) << (19 - 9);
+ out++;
+ *out = (inl >> 9) % (1U << 19);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 15)) << (19 - 15);
+ out++;
+ *out = (inl >> 15);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (19 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 19);
+ out++;
+ *out = (inl >> 21);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (19 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 19);
+ out++;
+ *out = (inl >> 27);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (19 - 14);
+ out++;
+ *out = (inl >> 14);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 1)) << (19 - 1);
+ out++;
+ *out = (inl >> 1) % (1U << 19);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 7)) << (19 - 7);
+ out++;
+ *out = (inl >> 7) % (1U << 19);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 13)) << (19 - 13);
+ out++;
+ *out = (inl >> 13);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack20_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (20 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 20);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (20 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (20 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 20);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (20 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (20 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 20);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (20 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (20 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 20);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (20 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (20 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 20);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (20 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (20 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 20);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (20 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (20 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 20);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (20 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (20 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 20);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (20 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack21_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 21);
+ out++;
+ *out = (inl >> 21);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (21 - 10);
+ out++;
+ *out = (inl >> 10) % (1U << 21);
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (21 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 9)) << (21 - 9);
+ out++;
+ *out = (inl >> 9) % (1U << 21);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 19)) << (21 - 19);
+ out++;
+ *out = (inl >> 19);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (21 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 21);
+ out++;
+ *out = (inl >> 29);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 18)) << (21 - 18);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 7)) << (21 - 7);
+ out++;
+ *out = (inl >> 7) % (1U << 21);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 17)) << (21 - 17);
+ out++;
+ *out = (inl >> 17);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (21 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 21);
+ out++;
+ *out = (inl >> 27);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (21 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 5)) << (21 - 5);
+ out++;
+ *out = (inl >> 5) % (1U << 21);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 15)) << (21 - 15);
+ out++;
+ *out = (inl >> 15);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (21 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 21);
+ out++;
+ *out = (inl >> 25);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (21 - 14);
+ out++;
+ *out = (inl >> 14);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 3)) << (21 - 3);
+ out++;
+ *out = (inl >> 3) % (1U << 21);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 13)) << (21 - 13);
+ out++;
+ *out = (inl >> 13);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (21 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 21);
+ out++;
+ *out = (inl >> 23);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (21 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 1)) << (21 - 1);
+ out++;
+ *out = (inl >> 1) % (1U << 21);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 11)) << (21 - 11);
+ out++;
+ *out = (inl >> 11);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack22_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 22);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (22 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (22 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 22);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (22 - 14);
+ out++;
+ *out = (inl >> 14);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (22 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 22);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (22 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (22 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 22);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 18)) << (22 - 18);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (22 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 22);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (22 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (22 - 10);
+ out++;
+ *out = (inl >> 10);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 22);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (22 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (22 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 22);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (22 - 14);
+ out++;
+ *out = (inl >> 14);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (22 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 22);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (22 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (22 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 22);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 18)) << (22 - 18);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (22 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 22);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (22 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (22 - 10);
+ out++;
+ *out = (inl >> 10);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack23_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 23);
+ out++;
+ *out = (inl >> 23);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (23 - 14);
+ out++;
+ *out = (inl >> 14);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 5)) << (23 - 5);
+ out++;
+ *out = (inl >> 5) % (1U << 23);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 19)) << (23 - 19);
+ out++;
+ *out = (inl >> 19);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (23 - 10);
+ out++;
+ *out = (inl >> 10);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 1)) << (23 - 1);
+ out++;
+ *out = (inl >> 1) % (1U << 23);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 15)) << (23 - 15);
+ out++;
+ *out = (inl >> 15);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (23 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 23);
+ out++;
+ *out = (inl >> 29);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (23 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 11)) << (23 - 11);
+ out++;
+ *out = (inl >> 11);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (23 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 23);
+ out++;
+ *out = (inl >> 25);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (23 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 7)) << (23 - 7);
+ out++;
+ *out = (inl >> 7) % (1U << 23);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 21)) << (23 - 21);
+ out++;
+ *out = (inl >> 21);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (23 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 3)) << (23 - 3);
+ out++;
+ *out = (inl >> 3) % (1U << 23);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 17)) << (23 - 17);
+ out++;
+ *out = (inl >> 17);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (23 - 8);
+ out++;
+ *out = (inl >> 8) % (1U << 23);
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 22)) << (23 - 22);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 13)) << (23 - 13);
+ out++;
+ *out = (inl >> 13);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (23 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 23);
+ out++;
+ *out = (inl >> 27);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 18)) << (23 - 18);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 9)) << (23 - 9);
+ out++;
+ *out = (inl >> 9);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack24_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (24 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (24 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (24 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (24 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (24 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (24 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (24 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (24 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (24 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (24 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (24 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (24 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (24 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (24 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (24 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (24 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack25_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 25);
+ out++;
+ *out = (inl >> 25);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 18)) << (25 - 18);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 11)) << (25 - 11);
+ out++;
+ *out = (inl >> 11);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (25 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 25);
+ out++;
+ *out = (inl >> 29);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 22)) << (25 - 22);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 15)) << (25 - 15);
+ out++;
+ *out = (inl >> 15);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (25 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 1)) << (25 - 1);
+ out++;
+ *out = (inl >> 1) % (1U << 25);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 19)) << (25 - 19);
+ out++;
+ *out = (inl >> 19);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (25 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 5)) << (25 - 5);
+ out++;
+ *out = (inl >> 5) % (1U << 25);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 23)) << (25 - 23);
+ out++;
+ *out = (inl >> 23);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (25 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 9)) << (25 - 9);
+ out++;
+ *out = (inl >> 9);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (25 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 25);
+ out++;
+ *out = (inl >> 27);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (25 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 13)) << (25 - 13);
+ out++;
+ *out = (inl >> 13);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (25 - 6);
+ out++;
+ *out = (inl >> 6) % (1U << 25);
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 24)) << (25 - 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 17)) << (25 - 17);
+ out++;
+ *out = (inl >> 17);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (25 - 10);
+ out++;
+ *out = (inl >> 10);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 3)) << (25 - 3);
+ out++;
+ *out = (inl >> 3) % (1U << 25);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 21)) << (25 - 21);
+ out++;
+ *out = (inl >> 21);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (25 - 14);
+ out++;
+ *out = (inl >> 14);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 7)) << (25 - 7);
+ out++;
+ *out = (inl >> 7);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack26_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 26);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (26 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (26 - 14);
+ out++;
+ *out = (inl >> 14);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (26 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (26 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 26);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 22)) << (26 - 22);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (26 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (26 - 10);
+ out++;
+ *out = (inl >> 10);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (26 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 26);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 24)) << (26 - 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 18)) << (26 - 18);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (26 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (26 - 6);
+ out++;
+ *out = (inl >> 6);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 26);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (26 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (26 - 14);
+ out++;
+ *out = (inl >> 14);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (26 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (26 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 26);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 22)) << (26 - 22);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (26 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (26 - 10);
+ out++;
+ *out = (inl >> 10);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (26 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 26);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 24)) << (26 - 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 18)) << (26 - 18);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (26 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (26 - 6);
+ out++;
+ *out = (inl >> 6);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack27_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 27);
+ out++;
+ *out = (inl >> 27);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 22)) << (27 - 22);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 17)) << (27 - 17);
+ out++;
+ *out = (inl >> 17);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (27 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 7)) << (27 - 7);
+ out++;
+ *out = (inl >> 7);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (27 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 27);
+ out++;
+ *out = (inl >> 29);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 24)) << (27 - 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 19)) << (27 - 19);
+ out++;
+ *out = (inl >> 19);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (27 - 14);
+ out++;
+ *out = (inl >> 14);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 9)) << (27 - 9);
+ out++;
+ *out = (inl >> 9);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (27 - 4);
+ out++;
+ *out = (inl >> 4) % (1U << 27);
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 26)) << (27 - 26);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 21)) << (27 - 21);
+ out++;
+ *out = (inl >> 21);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (27 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 11)) << (27 - 11);
+ out++;
+ *out = (inl >> 11);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (27 - 6);
+ out++;
+ *out = (inl >> 6);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 1)) << (27 - 1);
+ out++;
+ *out = (inl >> 1) % (1U << 27);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 23)) << (27 - 23);
+ out++;
+ *out = (inl >> 23);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 18)) << (27 - 18);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 13)) << (27 - 13);
+ out++;
+ *out = (inl >> 13);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (27 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 3)) << (27 - 3);
+ out++;
+ *out = (inl >> 3) % (1U << 27);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 25)) << (27 - 25);
+ out++;
+ *out = (inl >> 25);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (27 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 15)) << (27 - 15);
+ out++;
+ *out = (inl >> 15);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (27 - 10);
+ out++;
+ *out = (inl >> 10);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 5)) << (27 - 5);
+ out++;
+ *out = (inl >> 5);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack28_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 28);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 24)) << (28 - 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (28 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (28 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (28 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (28 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (28 - 4);
+ out++;
+ *out = (inl >> 4);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 28);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 24)) << (28 - 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (28 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (28 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (28 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (28 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (28 - 4);
+ out++;
+ *out = (inl >> 4);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 28);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 24)) << (28 - 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (28 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (28 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (28 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (28 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (28 - 4);
+ out++;
+ *out = (inl >> 4);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 28);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 24)) << (28 - 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (28 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (28 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (28 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (28 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (28 - 4);
+ out++;
+ *out = (inl >> 4);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack29_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 29);
+ out++;
+ *out = (inl >> 29);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 26)) << (29 - 26);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 23)) << (29 - 23);
+ out++;
+ *out = (inl >> 23);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (29 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 17)) << (29 - 17);
+ out++;
+ *out = (inl >> 17);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (29 - 14);
+ out++;
+ *out = (inl >> 14);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 11)) << (29 - 11);
+ out++;
+ *out = (inl >> 11);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (29 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 5)) << (29 - 5);
+ out++;
+ *out = (inl >> 5);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (29 - 2);
+ out++;
+ *out = (inl >> 2) % (1U << 29);
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 28)) << (29 - 28);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 25)) << (29 - 25);
+ out++;
+ *out = (inl >> 25);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 22)) << (29 - 22);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 19)) << (29 - 19);
+ out++;
+ *out = (inl >> 19);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (29 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 13)) << (29 - 13);
+ out++;
+ *out = (inl >> 13);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (29 - 10);
+ out++;
+ *out = (inl >> 10);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 7)) << (29 - 7);
+ out++;
+ *out = (inl >> 7);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (29 - 4);
+ out++;
+ *out = (inl >> 4);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 1)) << (29 - 1);
+ out++;
+ *out = (inl >> 1) % (1U << 29);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 27)) << (29 - 27);
+ out++;
+ *out = (inl >> 27);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 24)) << (29 - 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 21)) << (29 - 21);
+ out++;
+ *out = (inl >> 21);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 18)) << (29 - 18);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 15)) << (29 - 15);
+ out++;
+ *out = (inl >> 15);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (29 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 9)) << (29 - 9);
+ out++;
+ *out = (inl >> 9);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (29 - 6);
+ out++;
+ *out = (inl >> 6);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 3)) << (29 - 3);
+ out++;
+ *out = (inl >> 3);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack30_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 30);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 28)) << (30 - 28);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 26)) << (30 - 26);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 24)) << (30 - 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 22)) << (30 - 22);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (30 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 18)) << (30 - 18);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (30 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (30 - 14);
+ out++;
+ *out = (inl >> 14);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (30 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (30 - 10);
+ out++;
+ *out = (inl >> 10);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (30 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (30 - 6);
+ out++;
+ *out = (inl >> 6);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (30 - 4);
+ out++;
+ *out = (inl >> 4);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (30 - 2);
+ out++;
+ *out = (inl >> 2);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0) % (1U << 30);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 28)) << (30 - 28);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 26)) << (30 - 26);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 24)) << (30 - 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 22)) << (30 - 22);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (30 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 18)) << (30 - 18);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (30 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (30 - 14);
+ out++;
+ *out = (inl >> 14);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (30 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (30 - 10);
+ out++;
+ *out = (inl >> 10);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (30 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (30 - 6);
+ out++;
+ *out = (inl >> 6);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (30 - 4);
+ out++;
+ *out = (inl >> 4);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (30 - 2);
+ out++;
+ *out = (inl >> 2);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack31_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0) % (1U << 31);
+ out++;
+ *out = (inl >> 31);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 30)) << (31 - 30);
+ out++;
+ *out = (inl >> 30);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 29)) << (31 - 29);
+ out++;
+ *out = (inl >> 29);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 28)) << (31 - 28);
+ out++;
+ *out = (inl >> 28);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 27)) << (31 - 27);
+ out++;
+ *out = (inl >> 27);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 26)) << (31 - 26);
+ out++;
+ *out = (inl >> 26);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 25)) << (31 - 25);
+ out++;
+ *out = (inl >> 25);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 24)) << (31 - 24);
+ out++;
+ *out = (inl >> 24);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 23)) << (31 - 23);
+ out++;
+ *out = (inl >> 23);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 22)) << (31 - 22);
+ out++;
+ *out = (inl >> 22);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 21)) << (31 - 21);
+ out++;
+ *out = (inl >> 21);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 20)) << (31 - 20);
+ out++;
+ *out = (inl >> 20);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 19)) << (31 - 19);
+ out++;
+ *out = (inl >> 19);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 18)) << (31 - 18);
+ out++;
+ *out = (inl >> 18);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 17)) << (31 - 17);
+ out++;
+ *out = (inl >> 17);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 16)) << (31 - 16);
+ out++;
+ *out = (inl >> 16);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 15)) << (31 - 15);
+ out++;
+ *out = (inl >> 15);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 14)) << (31 - 14);
+ out++;
+ *out = (inl >> 14);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 13)) << (31 - 13);
+ out++;
+ *out = (inl >> 13);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 12)) << (31 - 12);
+ out++;
+ *out = (inl >> 12);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 11)) << (31 - 11);
+ out++;
+ *out = (inl >> 11);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 10)) << (31 - 10);
+ out++;
+ *out = (inl >> 10);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 9)) << (31 - 9);
+ out++;
+ *out = (inl >> 9);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 8)) << (31 - 8);
+ out++;
+ *out = (inl >> 8);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 7)) << (31 - 7);
+ out++;
+ *out = (inl >> 7);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 6)) << (31 - 6);
+ out++;
+ *out = (inl >> 6);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 5)) << (31 - 5);
+ out++;
+ *out = (inl >> 5);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 4)) << (31 - 4);
+ out++;
+ *out = (inl >> 4);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 3)) << (31 - 3);
+ out++;
+ *out = (inl >> 3);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 2)) << (31 - 2);
+ out++;
+ *out = (inl >> 2);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out |= (inl % (1U << 1)) << (31 - 1);
+ out++;
+ *out = (inl >> 1);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* unpack32_32(const uint32_t* in, uint32_t* out) {
+ uint32_t inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ inl = util::SafeLoad(in);
+ inl = arrow::BitUtil::FromLittleEndian(inl);
+ out++;
+ *out = (inl >> 0);
+ ++in;
+ out++;
+
+ return in;
+}
+
+inline const uint32_t* nullunpacker32(const uint32_t* in, uint32_t* out) {
+ for (int k = 0; k < 32; ++k) {
+ out[k] = 0;
+ }
+ return in;
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bpacking_neon.cc b/src/arrow/cpp/src/arrow/util/bpacking_neon.cc
new file mode 100644
index 000000000..a0bb5dc7a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking_neon.cc
@@ -0,0 +1,31 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/bpacking_neon.h"
+#include "arrow/util/bpacking_simd128_generated.h"
+#include "arrow/util/bpacking_simd_internal.h"
+
+namespace arrow {
+namespace internal {
+
+int unpack32_neon(const uint32_t* in, uint32_t* out, int batch_size, int num_bits) {
+ return unpack32_specialized<UnpackBits128<DispatchLevel::NEON>>(in, out, batch_size,
+ num_bits);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bpacking_neon.h b/src/arrow/cpp/src/arrow/util/bpacking_neon.h
new file mode 100644
index 000000000..9d02cd568
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking_neon.h
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <stdint.h>
+
+namespace arrow {
+namespace internal {
+
+int unpack32_neon(const uint32_t* in, uint32_t* out, int batch_size, int num_bits);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/bpacking_simd128_generated.h b/src/arrow/cpp/src/arrow/util/bpacking_simd128_generated.h
new file mode 100644
index 000000000..dca692971
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking_simd128_generated.h
@@ -0,0 +1,2144 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Automatically generated file; DO NOT EDIT.
+
+#pragma once
+
+#include <cstdint>
+#include <cstring>
+
+#include <xsimd/xsimd.hpp>
+
+#include "arrow/util/dispatch.h"
+#include "arrow/util/ubsan.h"
+
+namespace arrow {
+namespace internal {
+namespace {
+
+using ::arrow::util::SafeLoad;
+
+template <DispatchLevel level>
+struct UnpackBits128 {
+
+#ifdef ARROW_HAVE_NEON
+using simd_arch = xsimd::neon64;
+#else
+using simd_arch = xsimd::sse4_2;
+#endif
+
+using simd_batch = xsimd::batch<uint32_t, simd_arch>;
+
+inline static const uint32_t* unpack0_32(const uint32_t* in, uint32_t* out) {
+ memset(out, 0x0, 32 * sizeof(*out));
+ out += 32;
+
+ return in;
+}
+
+inline static const uint32_t* unpack1_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 1-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 0, 1, 2, 3 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 1-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 4, 5, 6, 7 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 1-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 8, 9, 10, 11 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 1-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 12, 13, 14, 15 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 1-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 16, 17, 18, 19 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 1-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 20, 21, 22, 23 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 1-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 24, 25, 26, 27 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 1-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 28, 29, 30, 31 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 1;
+ return in;
+}
+
+inline static const uint32_t* unpack2_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 2-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 0, 2, 4, 6 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 2-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 8, 10, 12, 14 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 2-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 16, 18, 20, 22 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 2-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 24, 26, 28, 30 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 2-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 2, 4, 6 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 2-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 8, 10, 12, 14 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 2-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 16, 18, 20, 22 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 2-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 24, 26, 28, 30 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 2;
+ return in;
+}
+
+inline static const uint32_t* unpack3_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 3-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 0, 3, 6, 9 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 3-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 12, 15, 18, 21 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 3-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 24, 27, 0, 1 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 3-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 4, 7, 10, 13 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 3-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 16, 19, 22, 25 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 3-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 31 | SafeLoad<uint32_t>(in + 2) << 1, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 28, 0, 2, 5 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 3-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 8, 11, 14, 17 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 3-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 20, 23, 26, 29 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 3;
+ return in;
+}
+
+inline static const uint32_t* unpack4_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xf;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 4-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 0, 4, 8, 12 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 4-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 16, 20, 24, 28 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 4-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 4, 8, 12 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 4-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 16, 20, 24, 28 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 4-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 0, 4, 8, 12 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 4-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 16, 20, 24, 28 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 4-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 0, 4, 8, 12 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 4-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 16, 20, 24, 28 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 4;
+ return in;
+}
+
+inline static const uint32_t* unpack5_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1f;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 5-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 0, 5, 10, 15 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 5-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 20, 25, 0, 3 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 5-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 8, 13, 18, 23 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 5-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 0, 1, 6, 11 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 5-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 31 | SafeLoad<uint32_t>(in + 3) << 1 };
+ shifts = simd_batch{ 16, 21, 26, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 5-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 4, 9, 14, 19 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 5-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 29 | SafeLoad<uint32_t>(in + 4) << 3, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 24, 0, 2, 7 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 5-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 12, 17, 22, 27 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 5;
+ return in;
+}
+
+inline static const uint32_t* unpack6_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3f;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 6-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 0, 6, 12, 18 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 6-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 24, 0, 4, 10 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 6-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 16, 22, 0, 2 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 6-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 8, 14, 20, 26 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 6-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 0, 6, 12, 18 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 6-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 24, 0, 4, 10 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 6-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 28 | SafeLoad<uint32_t>(in + 5) << 4, SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 16, 22, 0, 2 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 6-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 8, 14, 20, 26 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 6;
+ return in;
+}
+
+inline static const uint32_t* unpack7_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7f;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 7-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 0, 7, 14, 21 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 7-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0) >> 28 | SafeLoad<uint32_t>(in + 1) << 4, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 3, 10, 17 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 7-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 31 | SafeLoad<uint32_t>(in + 2) << 1, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 24, 0, 6, 13 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 7-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 27 | SafeLoad<uint32_t>(in + 3) << 5, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 20, 0, 2, 9 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 7-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 16, 23, 0, 5 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 7-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 26 | SafeLoad<uint32_t>(in + 5) << 6, SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 12, 19, 0, 1 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 7-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 29 | SafeLoad<uint32_t>(in + 6) << 3 };
+ shifts = simd_batch{ 8, 15, 22, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 7-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 4, 11, 18, 25 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 7;
+ return in;
+}
+
+inline static const uint32_t* unpack8_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 8-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 0, 8, 16, 24 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 8-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 8, 16, 24 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 8-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 0, 8, 16, 24 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 8-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 0, 8, 16, 24 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 8-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 0, 8, 16, 24 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 8-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 0, 8, 16, 24 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 8-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 0, 8, 16, 24 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 8-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) };
+ shifts = simd_batch{ 0, 8, 16, 24 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 8;
+ return in;
+}
+
+inline static const uint32_t* unpack9_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1ff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 9-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 27 | SafeLoad<uint32_t>(in + 1) << 5 };
+ shifts = simd_batch{ 0, 9, 18, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 9-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 31 | SafeLoad<uint32_t>(in + 2) << 1 };
+ shifts = simd_batch{ 4, 13, 22, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 9-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 26 | SafeLoad<uint32_t>(in + 3) << 6, SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 8, 17, 0, 3 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 9-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 12, 21, 0, 7 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 9-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 25 | SafeLoad<uint32_t>(in + 5) << 7, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 16, 0, 2, 11 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 9-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 29 | SafeLoad<uint32_t>(in + 6) << 3, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 20, 0, 6, 15 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 9-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6) >> 24 | SafeLoad<uint32_t>(in + 7) << 8, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) };
+ shifts = simd_batch{ 0, 1, 10, 19 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 9-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) };
+ shifts = simd_batch{ 0, 5, 14, 23 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 9;
+ return in;
+}
+
+inline static const uint32_t* unpack10_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3ff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 10-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2 };
+ shifts = simd_batch{ 0, 10, 20, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 10-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 8, 18, 0, 6 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 10-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 26 | SafeLoad<uint32_t>(in + 3) << 6, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 16, 0, 4, 14 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 10-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 0, 2, 12, 22 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 10-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 30 | SafeLoad<uint32_t>(in + 6) << 2 };
+ shifts = simd_batch{ 0, 10, 20, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 10-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 28 | SafeLoad<uint32_t>(in + 7) << 4, SafeLoad<uint32_t>(in + 7) };
+ shifts = simd_batch{ 8, 18, 0, 6 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 10-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 26 | SafeLoad<uint32_t>(in + 8) << 6, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) };
+ shifts = simd_batch{ 16, 0, 4, 14 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 10-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 8) >> 24 | SafeLoad<uint32_t>(in + 9) << 8, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) };
+ shifts = simd_batch{ 0, 2, 12, 22 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 10;
+ return in;
+}
+
+inline static const uint32_t* unpack11_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7ff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 11-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 22 | SafeLoad<uint32_t>(in + 1) << 10, SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 11, 0, 1 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 11-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 23 | SafeLoad<uint32_t>(in + 2) << 9, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 12, 0, 2, 13 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 11-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2) >> 24 | SafeLoad<uint32_t>(in + 3) << 8, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 25 | SafeLoad<uint32_t>(in + 4) << 7 };
+ shifts = simd_batch{ 0, 3, 14, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 11-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 26 | SafeLoad<uint32_t>(in + 5) << 6, SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 4, 15, 0, 5 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 11-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 27 | SafeLoad<uint32_t>(in + 6) << 5, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 16, 0, 6, 17 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 11-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6) >> 28 | SafeLoad<uint32_t>(in + 7) << 4, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 29 | SafeLoad<uint32_t>(in + 8) << 3 };
+ shifts = simd_batch{ 0, 7, 18, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 11-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 30 | SafeLoad<uint32_t>(in + 9) << 2, SafeLoad<uint32_t>(in + 9) };
+ shifts = simd_batch{ 8, 19, 0, 9 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 11-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 31 | SafeLoad<uint32_t>(in + 10) << 1, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) };
+ shifts = simd_batch{ 20, 0, 10, 21 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 11;
+ return in;
+}
+
+inline static const uint32_t* unpack12_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xfff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 12-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 24 | SafeLoad<uint32_t>(in + 1) << 8, SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 12, 0, 4 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 12-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 16, 0, 8, 20 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 12-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 0, 12, 0, 4 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 12-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 28 | SafeLoad<uint32_t>(in + 5) << 4, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 16, 0, 8, 20 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 12-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 24 | SafeLoad<uint32_t>(in + 7) << 8, SafeLoad<uint32_t>(in + 7) };
+ shifts = simd_batch{ 0, 12, 0, 4 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 12-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) };
+ shifts = simd_batch{ 16, 0, 8, 20 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 12-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 24 | SafeLoad<uint32_t>(in + 10) << 8, SafeLoad<uint32_t>(in + 10) };
+ shifts = simd_batch{ 0, 12, 0, 4 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 12-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 28 | SafeLoad<uint32_t>(in + 11) << 4, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) };
+ shifts = simd_batch{ 16, 0, 8, 20 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 12;
+ return in;
+}
+
+inline static const uint32_t* unpack13_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1fff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 13-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 26 | SafeLoad<uint32_t>(in + 1) << 6, SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 13, 0, 7 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 13-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1) >> 20 | SafeLoad<uint32_t>(in + 2) << 12, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 27 | SafeLoad<uint32_t>(in + 3) << 5 };
+ shifts = simd_batch{ 0, 1, 14, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 13-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 21 | SafeLoad<uint32_t>(in + 4) << 11, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 8, 0, 2, 15 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 13-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4) >> 28 | SafeLoad<uint32_t>(in + 5) << 4, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 22 | SafeLoad<uint32_t>(in + 6) << 10, SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 0, 9, 0, 3 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 13-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 29 | SafeLoad<uint32_t>(in + 7) << 3, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 23 | SafeLoad<uint32_t>(in + 8) << 9 };
+ shifts = simd_batch{ 16, 0, 10, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 13-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 30 | SafeLoad<uint32_t>(in + 9) << 2, SafeLoad<uint32_t>(in + 9) };
+ shifts = simd_batch{ 4, 17, 0, 11 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 13-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 9) >> 24 | SafeLoad<uint32_t>(in + 10) << 8, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 31 | SafeLoad<uint32_t>(in + 11) << 1 };
+ shifts = simd_batch{ 0, 5, 18, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 13-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 25 | SafeLoad<uint32_t>(in + 12) << 7, SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) };
+ shifts = simd_batch{ 12, 0, 6, 19 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 13;
+ return in;
+}
+
+inline static const uint32_t* unpack14_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3fff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 14-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 28 | SafeLoad<uint32_t>(in + 1) << 4, SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 14, 0, 10 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 14-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1) >> 24 | SafeLoad<uint32_t>(in + 2) << 8, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 20 | SafeLoad<uint32_t>(in + 3) << 12, SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 0, 6, 0, 2 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 14-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 26 | SafeLoad<uint32_t>(in + 5) << 6 };
+ shifts = simd_batch{ 16, 0, 12, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 14-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 22 | SafeLoad<uint32_t>(in + 6) << 10, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 8, 0, 4, 18 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 14-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8) };
+ shifts = simd_batch{ 0, 14, 0, 10 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 14-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 8) >> 24 | SafeLoad<uint32_t>(in + 9) << 8, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 20 | SafeLoad<uint32_t>(in + 10) << 12, SafeLoad<uint32_t>(in + 10) };
+ shifts = simd_batch{ 0, 6, 0, 2 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 14-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 30 | SafeLoad<uint32_t>(in + 11) << 2, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 26 | SafeLoad<uint32_t>(in + 12) << 6 };
+ shifts = simd_batch{ 16, 0, 12, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 14-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 22 | SafeLoad<uint32_t>(in + 13) << 10, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) };
+ shifts = simd_batch{ 8, 0, 4, 18 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 14;
+ return in;
+}
+
+inline static const uint32_t* unpack15_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7fff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 15-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 15, 0, 13 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 15-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 26 | SafeLoad<uint32_t>(in + 3) << 6, SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 0, 11, 0, 9 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 15-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 22 | SafeLoad<uint32_t>(in + 5) << 10, SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 0, 7, 0, 5 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 15-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5) >> 20 | SafeLoad<uint32_t>(in + 6) << 12, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 18 | SafeLoad<uint32_t>(in + 7) << 14, SafeLoad<uint32_t>(in + 7) };
+ shifts = simd_batch{ 0, 3, 0, 1 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 15-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 31 | SafeLoad<uint32_t>(in + 8) << 1, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 29 | SafeLoad<uint32_t>(in + 9) << 3 };
+ shifts = simd_batch{ 16, 0, 14, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 15-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 27 | SafeLoad<uint32_t>(in + 10) << 5, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 25 | SafeLoad<uint32_t>(in + 11) << 7 };
+ shifts = simd_batch{ 12, 0, 10, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 15-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 23 | SafeLoad<uint32_t>(in + 12) << 9, SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 21 | SafeLoad<uint32_t>(in + 13) << 11 };
+ shifts = simd_batch{ 8, 0, 6, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 15-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 19 | SafeLoad<uint32_t>(in + 14) << 13, SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) };
+ shifts = simd_batch{ 4, 0, 2, 17 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 15;
+ return in;
+}
+
+inline static const uint32_t* unpack16_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 16-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 16, 0, 16 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 16-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 0, 16, 0, 16 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 16-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 0, 16, 0, 16 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 16-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) };
+ shifts = simd_batch{ 0, 16, 0, 16 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 16-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) };
+ shifts = simd_batch{ 0, 16, 0, 16 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 16-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) };
+ shifts = simd_batch{ 0, 16, 0, 16 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 16-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) };
+ shifts = simd_batch{ 0, 16, 0, 16 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 16-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) };
+ shifts = simd_batch{ 0, 16, 0, 16 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 16;
+ return in;
+}
+
+inline static const uint32_t* unpack17_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1ffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 17-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 17 | SafeLoad<uint32_t>(in + 1) << 15, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 19 | SafeLoad<uint32_t>(in + 2) << 13 };
+ shifts = simd_batch{ 0, 0, 2, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 17-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 21 | SafeLoad<uint32_t>(in + 3) << 11, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 23 | SafeLoad<uint32_t>(in + 4) << 9 };
+ shifts = simd_batch{ 4, 0, 6, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 17-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 25 | SafeLoad<uint32_t>(in + 5) << 7, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 27 | SafeLoad<uint32_t>(in + 6) << 5 };
+ shifts = simd_batch{ 8, 0, 10, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 17-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 29 | SafeLoad<uint32_t>(in + 7) << 3, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 31 | SafeLoad<uint32_t>(in + 8) << 1 };
+ shifts = simd_batch{ 12, 0, 14, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 17-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 8) >> 16 | SafeLoad<uint32_t>(in + 9) << 16, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 18 | SafeLoad<uint32_t>(in + 10) << 14, SafeLoad<uint32_t>(in + 10) };
+ shifts = simd_batch{ 0, 1, 0, 3 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 17-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 10) >> 20 | SafeLoad<uint32_t>(in + 11) << 12, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 22 | SafeLoad<uint32_t>(in + 12) << 10, SafeLoad<uint32_t>(in + 12) };
+ shifts = simd_batch{ 0, 5, 0, 7 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 17-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 12) >> 24 | SafeLoad<uint32_t>(in + 13) << 8, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 26 | SafeLoad<uint32_t>(in + 14) << 6, SafeLoad<uint32_t>(in + 14) };
+ shifts = simd_batch{ 0, 9, 0, 11 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 17-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 14) >> 28 | SafeLoad<uint32_t>(in + 15) << 4, SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 30 | SafeLoad<uint32_t>(in + 16) << 2, SafeLoad<uint32_t>(in + 16) };
+ shifts = simd_batch{ 0, 13, 0, 15 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 17;
+ return in;
+}
+
+inline static const uint32_t* unpack18_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3ffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 18-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 18 | SafeLoad<uint32_t>(in + 1) << 14, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 22 | SafeLoad<uint32_t>(in + 2) << 10 };
+ shifts = simd_batch{ 0, 0, 4, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 18-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 26 | SafeLoad<uint32_t>(in + 3) << 6, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2 };
+ shifts = simd_batch{ 8, 0, 12, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 18-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4) >> 16 | SafeLoad<uint32_t>(in + 5) << 16, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 20 | SafeLoad<uint32_t>(in + 6) << 12, SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 0, 2, 0, 6 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 18-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6) >> 24 | SafeLoad<uint32_t>(in + 7) << 8, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8) };
+ shifts = simd_batch{ 0, 10, 0, 14 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 18-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 18 | SafeLoad<uint32_t>(in + 10) << 14, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 22 | SafeLoad<uint32_t>(in + 11) << 10 };
+ shifts = simd_batch{ 0, 0, 4, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 18-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 26 | SafeLoad<uint32_t>(in + 12) << 6, SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 30 | SafeLoad<uint32_t>(in + 13) << 2 };
+ shifts = simd_batch{ 8, 0, 12, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 18-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 13) >> 16 | SafeLoad<uint32_t>(in + 14) << 16, SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) >> 20 | SafeLoad<uint32_t>(in + 15) << 12, SafeLoad<uint32_t>(in + 15) };
+ shifts = simd_batch{ 0, 2, 0, 6 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 18-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 15) >> 24 | SafeLoad<uint32_t>(in + 16) << 8, SafeLoad<uint32_t>(in + 16), SafeLoad<uint32_t>(in + 16) >> 28 | SafeLoad<uint32_t>(in + 17) << 4, SafeLoad<uint32_t>(in + 17) };
+ shifts = simd_batch{ 0, 10, 0, 14 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 18;
+ return in;
+}
+
+inline static const uint32_t* unpack19_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7ffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 19-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 19 | SafeLoad<uint32_t>(in + 1) << 13, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 25 | SafeLoad<uint32_t>(in + 2) << 7 };
+ shifts = simd_batch{ 0, 0, 6, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 19-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 31 | SafeLoad<uint32_t>(in + 3) << 1, SafeLoad<uint32_t>(in + 3) >> 18 | SafeLoad<uint32_t>(in + 4) << 14, SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 12, 0, 0, 5 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 19-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4) >> 24 | SafeLoad<uint32_t>(in + 5) << 8, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 30 | SafeLoad<uint32_t>(in + 6) << 2, SafeLoad<uint32_t>(in + 6) >> 17 | SafeLoad<uint32_t>(in + 7) << 15 };
+ shifts = simd_batch{ 0, 11, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 19-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 23 | SafeLoad<uint32_t>(in + 8) << 9, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 29 | SafeLoad<uint32_t>(in + 9) << 3 };
+ shifts = simd_batch{ 4, 0, 10, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 19-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 9) >> 16 | SafeLoad<uint32_t>(in + 10) << 16, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 22 | SafeLoad<uint32_t>(in + 11) << 10, SafeLoad<uint32_t>(in + 11) };
+ shifts = simd_batch{ 0, 3, 0, 9 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 19-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 11) >> 28 | SafeLoad<uint32_t>(in + 12) << 4, SafeLoad<uint32_t>(in + 12) >> 15 | SafeLoad<uint32_t>(in + 13) << 17, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 21 | SafeLoad<uint32_t>(in + 14) << 11 };
+ shifts = simd_batch{ 0, 0, 2, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 19-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) >> 27 | SafeLoad<uint32_t>(in + 15) << 5, SafeLoad<uint32_t>(in + 15) >> 14 | SafeLoad<uint32_t>(in + 16) << 18, SafeLoad<uint32_t>(in + 16) };
+ shifts = simd_batch{ 8, 0, 0, 1 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 19-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 16) >> 20 | SafeLoad<uint32_t>(in + 17) << 12, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 26 | SafeLoad<uint32_t>(in + 18) << 6, SafeLoad<uint32_t>(in + 18) };
+ shifts = simd_batch{ 0, 7, 0, 13 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 19;
+ return in;
+}
+
+inline static const uint32_t* unpack20_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xfffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 20-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 20 | SafeLoad<uint32_t>(in + 1) << 12, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4 };
+ shifts = simd_batch{ 0, 0, 8, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 20-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2) >> 16 | SafeLoad<uint32_t>(in + 3) << 16, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 0, 4, 0, 12 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 20-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 20 | SafeLoad<uint32_t>(in + 6) << 12, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 28 | SafeLoad<uint32_t>(in + 7) << 4 };
+ shifts = simd_batch{ 0, 0, 8, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 20-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7) >> 16 | SafeLoad<uint32_t>(in + 8) << 16, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 24 | SafeLoad<uint32_t>(in + 9) << 8, SafeLoad<uint32_t>(in + 9) };
+ shifts = simd_batch{ 0, 4, 0, 12 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 20-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 20 | SafeLoad<uint32_t>(in + 11) << 12, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 28 | SafeLoad<uint32_t>(in + 12) << 4 };
+ shifts = simd_batch{ 0, 0, 8, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 20-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 12) >> 16 | SafeLoad<uint32_t>(in + 13) << 16, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 24 | SafeLoad<uint32_t>(in + 14) << 8, SafeLoad<uint32_t>(in + 14) };
+ shifts = simd_batch{ 0, 4, 0, 12 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 20-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 20 | SafeLoad<uint32_t>(in + 16) << 12, SafeLoad<uint32_t>(in + 16), SafeLoad<uint32_t>(in + 16) >> 28 | SafeLoad<uint32_t>(in + 17) << 4 };
+ shifts = simd_batch{ 0, 0, 8, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 20-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 17) >> 16 | SafeLoad<uint32_t>(in + 18) << 16, SafeLoad<uint32_t>(in + 18), SafeLoad<uint32_t>(in + 18) >> 24 | SafeLoad<uint32_t>(in + 19) << 8, SafeLoad<uint32_t>(in + 19) };
+ shifts = simd_batch{ 0, 4, 0, 12 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 20;
+ return in;
+}
+
+inline static const uint32_t* unpack21_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1fffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 21-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 21 | SafeLoad<uint32_t>(in + 1) << 11, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 31 | SafeLoad<uint32_t>(in + 2) << 1 };
+ shifts = simd_batch{ 0, 0, 10, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 21-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2) >> 20 | SafeLoad<uint32_t>(in + 3) << 12, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4) >> 19 | SafeLoad<uint32_t>(in + 5) << 13 };
+ shifts = simd_batch{ 0, 9, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 21-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 29 | SafeLoad<uint32_t>(in + 6) << 3, SafeLoad<uint32_t>(in + 6) >> 18 | SafeLoad<uint32_t>(in + 7) << 14, SafeLoad<uint32_t>(in + 7) };
+ shifts = simd_batch{ 8, 0, 0, 7 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 21-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8) >> 17 | SafeLoad<uint32_t>(in + 9) << 15, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 27 | SafeLoad<uint32_t>(in + 10) << 5 };
+ shifts = simd_batch{ 0, 0, 6, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 21-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 10) >> 16 | SafeLoad<uint32_t>(in + 11) << 16, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 26 | SafeLoad<uint32_t>(in + 12) << 6, SafeLoad<uint32_t>(in + 12) >> 15 | SafeLoad<uint32_t>(in + 13) << 17 };
+ shifts = simd_batch{ 0, 5, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 21-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 25 | SafeLoad<uint32_t>(in + 14) << 7, SafeLoad<uint32_t>(in + 14) >> 14 | SafeLoad<uint32_t>(in + 15) << 18, SafeLoad<uint32_t>(in + 15) };
+ shifts = simd_batch{ 4, 0, 0, 3 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 21-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 15) >> 24 | SafeLoad<uint32_t>(in + 16) << 8, SafeLoad<uint32_t>(in + 16) >> 13 | SafeLoad<uint32_t>(in + 17) << 19, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 23 | SafeLoad<uint32_t>(in + 18) << 9 };
+ shifts = simd_batch{ 0, 0, 2, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 21-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 18) >> 12 | SafeLoad<uint32_t>(in + 19) << 20, SafeLoad<uint32_t>(in + 19), SafeLoad<uint32_t>(in + 19) >> 22 | SafeLoad<uint32_t>(in + 20) << 10, SafeLoad<uint32_t>(in + 20) };
+ shifts = simd_batch{ 0, 1, 0, 11 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 21;
+ return in;
+}
+
+inline static const uint32_t* unpack22_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3fffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 22-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 22 | SafeLoad<uint32_t>(in + 1) << 10, SafeLoad<uint32_t>(in + 1) >> 12 | SafeLoad<uint32_t>(in + 2) << 20, SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 0, 0, 0, 2 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 22-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2) >> 24 | SafeLoad<uint32_t>(in + 3) << 8, SafeLoad<uint32_t>(in + 3) >> 14 | SafeLoad<uint32_t>(in + 4) << 18, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 26 | SafeLoad<uint32_t>(in + 5) << 6 };
+ shifts = simd_batch{ 0, 0, 4, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 22-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5) >> 16 | SafeLoad<uint32_t>(in + 6) << 16, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 28 | SafeLoad<uint32_t>(in + 7) << 4, SafeLoad<uint32_t>(in + 7) >> 18 | SafeLoad<uint32_t>(in + 8) << 14 };
+ shifts = simd_batch{ 0, 6, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 22-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 30 | SafeLoad<uint32_t>(in + 9) << 2, SafeLoad<uint32_t>(in + 9) >> 20 | SafeLoad<uint32_t>(in + 10) << 12, SafeLoad<uint32_t>(in + 10) };
+ shifts = simd_batch{ 8, 0, 0, 10 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 22-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 22 | SafeLoad<uint32_t>(in + 12) << 10, SafeLoad<uint32_t>(in + 12) >> 12 | SafeLoad<uint32_t>(in + 13) << 20, SafeLoad<uint32_t>(in + 13) };
+ shifts = simd_batch{ 0, 0, 0, 2 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 22-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 13) >> 24 | SafeLoad<uint32_t>(in + 14) << 8, SafeLoad<uint32_t>(in + 14) >> 14 | SafeLoad<uint32_t>(in + 15) << 18, SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 26 | SafeLoad<uint32_t>(in + 16) << 6 };
+ shifts = simd_batch{ 0, 0, 4, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 22-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 16) >> 16 | SafeLoad<uint32_t>(in + 17) << 16, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 28 | SafeLoad<uint32_t>(in + 18) << 4, SafeLoad<uint32_t>(in + 18) >> 18 | SafeLoad<uint32_t>(in + 19) << 14 };
+ shifts = simd_batch{ 0, 6, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 22-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 19), SafeLoad<uint32_t>(in + 19) >> 30 | SafeLoad<uint32_t>(in + 20) << 2, SafeLoad<uint32_t>(in + 20) >> 20 | SafeLoad<uint32_t>(in + 21) << 12, SafeLoad<uint32_t>(in + 21) };
+ shifts = simd_batch{ 8, 0, 0, 10 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 22;
+ return in;
+}
+
+inline static const uint32_t* unpack23_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7fffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 23-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 23 | SafeLoad<uint32_t>(in + 1) << 9, SafeLoad<uint32_t>(in + 1) >> 14 | SafeLoad<uint32_t>(in + 2) << 18, SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 0, 0, 0, 5 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 23-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2) >> 28 | SafeLoad<uint32_t>(in + 3) << 4, SafeLoad<uint32_t>(in + 3) >> 19 | SafeLoad<uint32_t>(in + 4) << 13, SafeLoad<uint32_t>(in + 4) >> 10 | SafeLoad<uint32_t>(in + 5) << 22, SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 0, 0, 0, 1 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 23-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5) >> 24 | SafeLoad<uint32_t>(in + 6) << 8, SafeLoad<uint32_t>(in + 6) >> 15 | SafeLoad<uint32_t>(in + 7) << 17, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 29 | SafeLoad<uint32_t>(in + 8) << 3 };
+ shifts = simd_batch{ 0, 0, 6, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 23-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 8) >> 20 | SafeLoad<uint32_t>(in + 9) << 12, SafeLoad<uint32_t>(in + 9) >> 11 | SafeLoad<uint32_t>(in + 10) << 21, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 25 | SafeLoad<uint32_t>(in + 11) << 7 };
+ shifts = simd_batch{ 0, 0, 2, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 23-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 11) >> 16 | SafeLoad<uint32_t>(in + 12) << 16, SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 30 | SafeLoad<uint32_t>(in + 13) << 2, SafeLoad<uint32_t>(in + 13) >> 21 | SafeLoad<uint32_t>(in + 14) << 11 };
+ shifts = simd_batch{ 0, 7, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 23-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 14) >> 12 | SafeLoad<uint32_t>(in + 15) << 20, SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 26 | SafeLoad<uint32_t>(in + 16) << 6, SafeLoad<uint32_t>(in + 16) >> 17 | SafeLoad<uint32_t>(in + 17) << 15 };
+ shifts = simd_batch{ 0, 3, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 23-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 31 | SafeLoad<uint32_t>(in + 18) << 1, SafeLoad<uint32_t>(in + 18) >> 22 | SafeLoad<uint32_t>(in + 19) << 10, SafeLoad<uint32_t>(in + 19) >> 13 | SafeLoad<uint32_t>(in + 20) << 19 };
+ shifts = simd_batch{ 8, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 23-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 20), SafeLoad<uint32_t>(in + 20) >> 27 | SafeLoad<uint32_t>(in + 21) << 5, SafeLoad<uint32_t>(in + 21) >> 18 | SafeLoad<uint32_t>(in + 22) << 14, SafeLoad<uint32_t>(in + 22) };
+ shifts = simd_batch{ 4, 0, 0, 9 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 23;
+ return in;
+}
+
+inline static const uint32_t* unpack24_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 24-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 24 | SafeLoad<uint32_t>(in + 1) << 8, SafeLoad<uint32_t>(in + 1) >> 16 | SafeLoad<uint32_t>(in + 2) << 16, SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 0, 0, 0, 8 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 24-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4) >> 16 | SafeLoad<uint32_t>(in + 5) << 16, SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 0, 0, 0, 8 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 24-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 24 | SafeLoad<uint32_t>(in + 7) << 8, SafeLoad<uint32_t>(in + 7) >> 16 | SafeLoad<uint32_t>(in + 8) << 16, SafeLoad<uint32_t>(in + 8) };
+ shifts = simd_batch{ 0, 0, 0, 8 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 24-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 24 | SafeLoad<uint32_t>(in + 10) << 8, SafeLoad<uint32_t>(in + 10) >> 16 | SafeLoad<uint32_t>(in + 11) << 16, SafeLoad<uint32_t>(in + 11) };
+ shifts = simd_batch{ 0, 0, 0, 8 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 24-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 24 | SafeLoad<uint32_t>(in + 13) << 8, SafeLoad<uint32_t>(in + 13) >> 16 | SafeLoad<uint32_t>(in + 14) << 16, SafeLoad<uint32_t>(in + 14) };
+ shifts = simd_batch{ 0, 0, 0, 8 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 24-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 24 | SafeLoad<uint32_t>(in + 16) << 8, SafeLoad<uint32_t>(in + 16) >> 16 | SafeLoad<uint32_t>(in + 17) << 16, SafeLoad<uint32_t>(in + 17) };
+ shifts = simd_batch{ 0, 0, 0, 8 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 24-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 18), SafeLoad<uint32_t>(in + 18) >> 24 | SafeLoad<uint32_t>(in + 19) << 8, SafeLoad<uint32_t>(in + 19) >> 16 | SafeLoad<uint32_t>(in + 20) << 16, SafeLoad<uint32_t>(in + 20) };
+ shifts = simd_batch{ 0, 0, 0, 8 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 24-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 21), SafeLoad<uint32_t>(in + 21) >> 24 | SafeLoad<uint32_t>(in + 22) << 8, SafeLoad<uint32_t>(in + 22) >> 16 | SafeLoad<uint32_t>(in + 23) << 16, SafeLoad<uint32_t>(in + 23) };
+ shifts = simd_batch{ 0, 0, 0, 8 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 24;
+ return in;
+}
+
+inline static const uint32_t* unpack25_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1ffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 25-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 25 | SafeLoad<uint32_t>(in + 1) << 7, SafeLoad<uint32_t>(in + 1) >> 18 | SafeLoad<uint32_t>(in + 2) << 14, SafeLoad<uint32_t>(in + 2) >> 11 | SafeLoad<uint32_t>(in + 3) << 21 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 25-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 29 | SafeLoad<uint32_t>(in + 4) << 3, SafeLoad<uint32_t>(in + 4) >> 22 | SafeLoad<uint32_t>(in + 5) << 10, SafeLoad<uint32_t>(in + 5) >> 15 | SafeLoad<uint32_t>(in + 6) << 17 };
+ shifts = simd_batch{ 4, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 25-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6) >> 8 | SafeLoad<uint32_t>(in + 7) << 24, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 26 | SafeLoad<uint32_t>(in + 8) << 6, SafeLoad<uint32_t>(in + 8) >> 19 | SafeLoad<uint32_t>(in + 9) << 13 };
+ shifts = simd_batch{ 0, 1, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 25-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 9) >> 12 | SafeLoad<uint32_t>(in + 10) << 20, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 30 | SafeLoad<uint32_t>(in + 11) << 2, SafeLoad<uint32_t>(in + 11) >> 23 | SafeLoad<uint32_t>(in + 12) << 9 };
+ shifts = simd_batch{ 0, 5, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 25-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 12) >> 16 | SafeLoad<uint32_t>(in + 13) << 16, SafeLoad<uint32_t>(in + 13) >> 9 | SafeLoad<uint32_t>(in + 14) << 23, SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) >> 27 | SafeLoad<uint32_t>(in + 15) << 5 };
+ shifts = simd_batch{ 0, 0, 2, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 25-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 15) >> 20 | SafeLoad<uint32_t>(in + 16) << 12, SafeLoad<uint32_t>(in + 16) >> 13 | SafeLoad<uint32_t>(in + 17) << 19, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 31 | SafeLoad<uint32_t>(in + 18) << 1 };
+ shifts = simd_batch{ 0, 0, 6, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 25-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 18) >> 24 | SafeLoad<uint32_t>(in + 19) << 8, SafeLoad<uint32_t>(in + 19) >> 17 | SafeLoad<uint32_t>(in + 20) << 15, SafeLoad<uint32_t>(in + 20) >> 10 | SafeLoad<uint32_t>(in + 21) << 22, SafeLoad<uint32_t>(in + 21) };
+ shifts = simd_batch{ 0, 0, 0, 3 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 25-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 21) >> 28 | SafeLoad<uint32_t>(in + 22) << 4, SafeLoad<uint32_t>(in + 22) >> 21 | SafeLoad<uint32_t>(in + 23) << 11, SafeLoad<uint32_t>(in + 23) >> 14 | SafeLoad<uint32_t>(in + 24) << 18, SafeLoad<uint32_t>(in + 24) };
+ shifts = simd_batch{ 0, 0, 0, 7 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 25;
+ return in;
+}
+
+inline static const uint32_t* unpack26_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3ffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 26-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 26 | SafeLoad<uint32_t>(in + 1) << 6, SafeLoad<uint32_t>(in + 1) >> 20 | SafeLoad<uint32_t>(in + 2) << 12, SafeLoad<uint32_t>(in + 2) >> 14 | SafeLoad<uint32_t>(in + 3) << 18 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 26-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3) >> 8 | SafeLoad<uint32_t>(in + 4) << 24, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 28 | SafeLoad<uint32_t>(in + 5) << 4, SafeLoad<uint32_t>(in + 5) >> 22 | SafeLoad<uint32_t>(in + 6) << 10 };
+ shifts = simd_batch{ 0, 2, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 26-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6) >> 16 | SafeLoad<uint32_t>(in + 7) << 16, SafeLoad<uint32_t>(in + 7) >> 10 | SafeLoad<uint32_t>(in + 8) << 22, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 30 | SafeLoad<uint32_t>(in + 9) << 2 };
+ shifts = simd_batch{ 0, 0, 4, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 26-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 9) >> 24 | SafeLoad<uint32_t>(in + 10) << 8, SafeLoad<uint32_t>(in + 10) >> 18 | SafeLoad<uint32_t>(in + 11) << 14, SafeLoad<uint32_t>(in + 11) >> 12 | SafeLoad<uint32_t>(in + 12) << 20, SafeLoad<uint32_t>(in + 12) };
+ shifts = simd_batch{ 0, 0, 0, 6 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 26-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 26 | SafeLoad<uint32_t>(in + 14) << 6, SafeLoad<uint32_t>(in + 14) >> 20 | SafeLoad<uint32_t>(in + 15) << 12, SafeLoad<uint32_t>(in + 15) >> 14 | SafeLoad<uint32_t>(in + 16) << 18 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 26-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 16) >> 8 | SafeLoad<uint32_t>(in + 17) << 24, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 28 | SafeLoad<uint32_t>(in + 18) << 4, SafeLoad<uint32_t>(in + 18) >> 22 | SafeLoad<uint32_t>(in + 19) << 10 };
+ shifts = simd_batch{ 0, 2, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 26-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 19) >> 16 | SafeLoad<uint32_t>(in + 20) << 16, SafeLoad<uint32_t>(in + 20) >> 10 | SafeLoad<uint32_t>(in + 21) << 22, SafeLoad<uint32_t>(in + 21), SafeLoad<uint32_t>(in + 21) >> 30 | SafeLoad<uint32_t>(in + 22) << 2 };
+ shifts = simd_batch{ 0, 0, 4, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 26-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 22) >> 24 | SafeLoad<uint32_t>(in + 23) << 8, SafeLoad<uint32_t>(in + 23) >> 18 | SafeLoad<uint32_t>(in + 24) << 14, SafeLoad<uint32_t>(in + 24) >> 12 | SafeLoad<uint32_t>(in + 25) << 20, SafeLoad<uint32_t>(in + 25) };
+ shifts = simd_batch{ 0, 0, 0, 6 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 26;
+ return in;
+}
+
+inline static const uint32_t* unpack27_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7ffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 27-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 27 | SafeLoad<uint32_t>(in + 1) << 5, SafeLoad<uint32_t>(in + 1) >> 22 | SafeLoad<uint32_t>(in + 2) << 10, SafeLoad<uint32_t>(in + 2) >> 17 | SafeLoad<uint32_t>(in + 3) << 15 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 27-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3) >> 12 | SafeLoad<uint32_t>(in + 4) << 20, SafeLoad<uint32_t>(in + 4) >> 7 | SafeLoad<uint32_t>(in + 5) << 25, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 29 | SafeLoad<uint32_t>(in + 6) << 3 };
+ shifts = simd_batch{ 0, 0, 2, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 27-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6) >> 24 | SafeLoad<uint32_t>(in + 7) << 8, SafeLoad<uint32_t>(in + 7) >> 19 | SafeLoad<uint32_t>(in + 8) << 13, SafeLoad<uint32_t>(in + 8) >> 14 | SafeLoad<uint32_t>(in + 9) << 18, SafeLoad<uint32_t>(in + 9) >> 9 | SafeLoad<uint32_t>(in + 10) << 23 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 27-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 31 | SafeLoad<uint32_t>(in + 11) << 1, SafeLoad<uint32_t>(in + 11) >> 26 | SafeLoad<uint32_t>(in + 12) << 6, SafeLoad<uint32_t>(in + 12) >> 21 | SafeLoad<uint32_t>(in + 13) << 11 };
+ shifts = simd_batch{ 4, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 27-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 13) >> 16 | SafeLoad<uint32_t>(in + 14) << 16, SafeLoad<uint32_t>(in + 14) >> 11 | SafeLoad<uint32_t>(in + 15) << 21, SafeLoad<uint32_t>(in + 15) >> 6 | SafeLoad<uint32_t>(in + 16) << 26, SafeLoad<uint32_t>(in + 16) };
+ shifts = simd_batch{ 0, 0, 0, 1 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 27-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 16) >> 28 | SafeLoad<uint32_t>(in + 17) << 4, SafeLoad<uint32_t>(in + 17) >> 23 | SafeLoad<uint32_t>(in + 18) << 9, SafeLoad<uint32_t>(in + 18) >> 18 | SafeLoad<uint32_t>(in + 19) << 14, SafeLoad<uint32_t>(in + 19) >> 13 | SafeLoad<uint32_t>(in + 20) << 19 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 27-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 20) >> 8 | SafeLoad<uint32_t>(in + 21) << 24, SafeLoad<uint32_t>(in + 21), SafeLoad<uint32_t>(in + 21) >> 30 | SafeLoad<uint32_t>(in + 22) << 2, SafeLoad<uint32_t>(in + 22) >> 25 | SafeLoad<uint32_t>(in + 23) << 7 };
+ shifts = simd_batch{ 0, 3, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 27-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 23) >> 20 | SafeLoad<uint32_t>(in + 24) << 12, SafeLoad<uint32_t>(in + 24) >> 15 | SafeLoad<uint32_t>(in + 25) << 17, SafeLoad<uint32_t>(in + 25) >> 10 | SafeLoad<uint32_t>(in + 26) << 22, SafeLoad<uint32_t>(in + 26) };
+ shifts = simd_batch{ 0, 0, 0, 5 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 27;
+ return in;
+}
+
+inline static const uint32_t* unpack28_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xfffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 28-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 28 | SafeLoad<uint32_t>(in + 1) << 4, SafeLoad<uint32_t>(in + 1) >> 24 | SafeLoad<uint32_t>(in + 2) << 8, SafeLoad<uint32_t>(in + 2) >> 20 | SafeLoad<uint32_t>(in + 3) << 12 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 28-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3) >> 16 | SafeLoad<uint32_t>(in + 4) << 16, SafeLoad<uint32_t>(in + 4) >> 12 | SafeLoad<uint32_t>(in + 5) << 20, SafeLoad<uint32_t>(in + 5) >> 8 | SafeLoad<uint32_t>(in + 6) << 24, SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 0, 0, 0, 4 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 28-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8) >> 24 | SafeLoad<uint32_t>(in + 9) << 8, SafeLoad<uint32_t>(in + 9) >> 20 | SafeLoad<uint32_t>(in + 10) << 12 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 28-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 10) >> 16 | SafeLoad<uint32_t>(in + 11) << 16, SafeLoad<uint32_t>(in + 11) >> 12 | SafeLoad<uint32_t>(in + 12) << 20, SafeLoad<uint32_t>(in + 12) >> 8 | SafeLoad<uint32_t>(in + 13) << 24, SafeLoad<uint32_t>(in + 13) };
+ shifts = simd_batch{ 0, 0, 0, 4 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 28-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) >> 28 | SafeLoad<uint32_t>(in + 15) << 4, SafeLoad<uint32_t>(in + 15) >> 24 | SafeLoad<uint32_t>(in + 16) << 8, SafeLoad<uint32_t>(in + 16) >> 20 | SafeLoad<uint32_t>(in + 17) << 12 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 28-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 17) >> 16 | SafeLoad<uint32_t>(in + 18) << 16, SafeLoad<uint32_t>(in + 18) >> 12 | SafeLoad<uint32_t>(in + 19) << 20, SafeLoad<uint32_t>(in + 19) >> 8 | SafeLoad<uint32_t>(in + 20) << 24, SafeLoad<uint32_t>(in + 20) };
+ shifts = simd_batch{ 0, 0, 0, 4 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 28-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 21), SafeLoad<uint32_t>(in + 21) >> 28 | SafeLoad<uint32_t>(in + 22) << 4, SafeLoad<uint32_t>(in + 22) >> 24 | SafeLoad<uint32_t>(in + 23) << 8, SafeLoad<uint32_t>(in + 23) >> 20 | SafeLoad<uint32_t>(in + 24) << 12 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 28-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 24) >> 16 | SafeLoad<uint32_t>(in + 25) << 16, SafeLoad<uint32_t>(in + 25) >> 12 | SafeLoad<uint32_t>(in + 26) << 20, SafeLoad<uint32_t>(in + 26) >> 8 | SafeLoad<uint32_t>(in + 27) << 24, SafeLoad<uint32_t>(in + 27) };
+ shifts = simd_batch{ 0, 0, 0, 4 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 28;
+ return in;
+}
+
+inline static const uint32_t* unpack29_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1fffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 29-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 29 | SafeLoad<uint32_t>(in + 1) << 3, SafeLoad<uint32_t>(in + 1) >> 26 | SafeLoad<uint32_t>(in + 2) << 6, SafeLoad<uint32_t>(in + 2) >> 23 | SafeLoad<uint32_t>(in + 3) << 9 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 29-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3) >> 20 | SafeLoad<uint32_t>(in + 4) << 12, SafeLoad<uint32_t>(in + 4) >> 17 | SafeLoad<uint32_t>(in + 5) << 15, SafeLoad<uint32_t>(in + 5) >> 14 | SafeLoad<uint32_t>(in + 6) << 18, SafeLoad<uint32_t>(in + 6) >> 11 | SafeLoad<uint32_t>(in + 7) << 21 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 29-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7) >> 8 | SafeLoad<uint32_t>(in + 8) << 24, SafeLoad<uint32_t>(in + 8) >> 5 | SafeLoad<uint32_t>(in + 9) << 27, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 31 | SafeLoad<uint32_t>(in + 10) << 1 };
+ shifts = simd_batch{ 0, 0, 2, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 29-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 10) >> 28 | SafeLoad<uint32_t>(in + 11) << 4, SafeLoad<uint32_t>(in + 11) >> 25 | SafeLoad<uint32_t>(in + 12) << 7, SafeLoad<uint32_t>(in + 12) >> 22 | SafeLoad<uint32_t>(in + 13) << 10, SafeLoad<uint32_t>(in + 13) >> 19 | SafeLoad<uint32_t>(in + 14) << 13 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 29-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 14) >> 16 | SafeLoad<uint32_t>(in + 15) << 16, SafeLoad<uint32_t>(in + 15) >> 13 | SafeLoad<uint32_t>(in + 16) << 19, SafeLoad<uint32_t>(in + 16) >> 10 | SafeLoad<uint32_t>(in + 17) << 22, SafeLoad<uint32_t>(in + 17) >> 7 | SafeLoad<uint32_t>(in + 18) << 25 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 29-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 18) >> 4 | SafeLoad<uint32_t>(in + 19) << 28, SafeLoad<uint32_t>(in + 19), SafeLoad<uint32_t>(in + 19) >> 30 | SafeLoad<uint32_t>(in + 20) << 2, SafeLoad<uint32_t>(in + 20) >> 27 | SafeLoad<uint32_t>(in + 21) << 5 };
+ shifts = simd_batch{ 0, 1, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 29-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 21) >> 24 | SafeLoad<uint32_t>(in + 22) << 8, SafeLoad<uint32_t>(in + 22) >> 21 | SafeLoad<uint32_t>(in + 23) << 11, SafeLoad<uint32_t>(in + 23) >> 18 | SafeLoad<uint32_t>(in + 24) << 14, SafeLoad<uint32_t>(in + 24) >> 15 | SafeLoad<uint32_t>(in + 25) << 17 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 29-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 25) >> 12 | SafeLoad<uint32_t>(in + 26) << 20, SafeLoad<uint32_t>(in + 26) >> 9 | SafeLoad<uint32_t>(in + 27) << 23, SafeLoad<uint32_t>(in + 27) >> 6 | SafeLoad<uint32_t>(in + 28) << 26, SafeLoad<uint32_t>(in + 28) };
+ shifts = simd_batch{ 0, 0, 0, 3 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 29;
+ return in;
+}
+
+inline static const uint32_t* unpack30_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3fffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 30-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2) >> 26 | SafeLoad<uint32_t>(in + 3) << 6 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 30-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4) >> 22 | SafeLoad<uint32_t>(in + 5) << 10, SafeLoad<uint32_t>(in + 5) >> 20 | SafeLoad<uint32_t>(in + 6) << 12, SafeLoad<uint32_t>(in + 6) >> 18 | SafeLoad<uint32_t>(in + 7) << 14 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 30-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7) >> 16 | SafeLoad<uint32_t>(in + 8) << 16, SafeLoad<uint32_t>(in + 8) >> 14 | SafeLoad<uint32_t>(in + 9) << 18, SafeLoad<uint32_t>(in + 9) >> 12 | SafeLoad<uint32_t>(in + 10) << 20, SafeLoad<uint32_t>(in + 10) >> 10 | SafeLoad<uint32_t>(in + 11) << 22 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 30-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 11) >> 8 | SafeLoad<uint32_t>(in + 12) << 24, SafeLoad<uint32_t>(in + 12) >> 6 | SafeLoad<uint32_t>(in + 13) << 26, SafeLoad<uint32_t>(in + 13) >> 4 | SafeLoad<uint32_t>(in + 14) << 28, SafeLoad<uint32_t>(in + 14) };
+ shifts = simd_batch{ 0, 0, 0, 2 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 30-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 30 | SafeLoad<uint32_t>(in + 16) << 2, SafeLoad<uint32_t>(in + 16) >> 28 | SafeLoad<uint32_t>(in + 17) << 4, SafeLoad<uint32_t>(in + 17) >> 26 | SafeLoad<uint32_t>(in + 18) << 6 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 30-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 18) >> 24 | SafeLoad<uint32_t>(in + 19) << 8, SafeLoad<uint32_t>(in + 19) >> 22 | SafeLoad<uint32_t>(in + 20) << 10, SafeLoad<uint32_t>(in + 20) >> 20 | SafeLoad<uint32_t>(in + 21) << 12, SafeLoad<uint32_t>(in + 21) >> 18 | SafeLoad<uint32_t>(in + 22) << 14 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 30-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 22) >> 16 | SafeLoad<uint32_t>(in + 23) << 16, SafeLoad<uint32_t>(in + 23) >> 14 | SafeLoad<uint32_t>(in + 24) << 18, SafeLoad<uint32_t>(in + 24) >> 12 | SafeLoad<uint32_t>(in + 25) << 20, SafeLoad<uint32_t>(in + 25) >> 10 | SafeLoad<uint32_t>(in + 26) << 22 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 30-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 26) >> 8 | SafeLoad<uint32_t>(in + 27) << 24, SafeLoad<uint32_t>(in + 27) >> 6 | SafeLoad<uint32_t>(in + 28) << 26, SafeLoad<uint32_t>(in + 28) >> 4 | SafeLoad<uint32_t>(in + 29) << 28, SafeLoad<uint32_t>(in + 29) };
+ shifts = simd_batch{ 0, 0, 0, 2 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 30;
+ return in;
+}
+
+inline static const uint32_t* unpack31_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7fffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 31-bit bundles 0 to 3
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 31 | SafeLoad<uint32_t>(in + 1) << 1, SafeLoad<uint32_t>(in + 1) >> 30 | SafeLoad<uint32_t>(in + 2) << 2, SafeLoad<uint32_t>(in + 2) >> 29 | SafeLoad<uint32_t>(in + 3) << 3 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 31-bit bundles 4 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3) >> 28 | SafeLoad<uint32_t>(in + 4) << 4, SafeLoad<uint32_t>(in + 4) >> 27 | SafeLoad<uint32_t>(in + 5) << 5, SafeLoad<uint32_t>(in + 5) >> 26 | SafeLoad<uint32_t>(in + 6) << 6, SafeLoad<uint32_t>(in + 6) >> 25 | SafeLoad<uint32_t>(in + 7) << 7 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 31-bit bundles 8 to 11
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7) >> 24 | SafeLoad<uint32_t>(in + 8) << 8, SafeLoad<uint32_t>(in + 8) >> 23 | SafeLoad<uint32_t>(in + 9) << 9, SafeLoad<uint32_t>(in + 9) >> 22 | SafeLoad<uint32_t>(in + 10) << 10, SafeLoad<uint32_t>(in + 10) >> 21 | SafeLoad<uint32_t>(in + 11) << 11 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 31-bit bundles 12 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 11) >> 20 | SafeLoad<uint32_t>(in + 12) << 12, SafeLoad<uint32_t>(in + 12) >> 19 | SafeLoad<uint32_t>(in + 13) << 13, SafeLoad<uint32_t>(in + 13) >> 18 | SafeLoad<uint32_t>(in + 14) << 14, SafeLoad<uint32_t>(in + 14) >> 17 | SafeLoad<uint32_t>(in + 15) << 15 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 31-bit bundles 16 to 19
+ words = simd_batch{ SafeLoad<uint32_t>(in + 15) >> 16 | SafeLoad<uint32_t>(in + 16) << 16, SafeLoad<uint32_t>(in + 16) >> 15 | SafeLoad<uint32_t>(in + 17) << 17, SafeLoad<uint32_t>(in + 17) >> 14 | SafeLoad<uint32_t>(in + 18) << 18, SafeLoad<uint32_t>(in + 18) >> 13 | SafeLoad<uint32_t>(in + 19) << 19 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 31-bit bundles 20 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 19) >> 12 | SafeLoad<uint32_t>(in + 20) << 20, SafeLoad<uint32_t>(in + 20) >> 11 | SafeLoad<uint32_t>(in + 21) << 21, SafeLoad<uint32_t>(in + 21) >> 10 | SafeLoad<uint32_t>(in + 22) << 22, SafeLoad<uint32_t>(in + 22) >> 9 | SafeLoad<uint32_t>(in + 23) << 23 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 31-bit bundles 24 to 27
+ words = simd_batch{ SafeLoad<uint32_t>(in + 23) >> 8 | SafeLoad<uint32_t>(in + 24) << 24, SafeLoad<uint32_t>(in + 24) >> 7 | SafeLoad<uint32_t>(in + 25) << 25, SafeLoad<uint32_t>(in + 25) >> 6 | SafeLoad<uint32_t>(in + 26) << 26, SafeLoad<uint32_t>(in + 26) >> 5 | SafeLoad<uint32_t>(in + 27) << 27 };
+ shifts = simd_batch{ 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ // extract 31-bit bundles 28 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 27) >> 4 | SafeLoad<uint32_t>(in + 28) << 28, SafeLoad<uint32_t>(in + 28) >> 3 | SafeLoad<uint32_t>(in + 29) << 29, SafeLoad<uint32_t>(in + 29) >> 2 | SafeLoad<uint32_t>(in + 30) << 30, SafeLoad<uint32_t>(in + 30) };
+ shifts = simd_batch{ 0, 0, 0, 1 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 4;
+
+ in += 31;
+ return in;
+}
+
+inline static const uint32_t* unpack32_32(const uint32_t* in, uint32_t* out) {
+ memcpy(out, in, 32 * sizeof(*out));
+ in += 32;
+ out += 32;
+
+ return in;
+}
+
+}; // struct UnpackBits128
+
+} // namespace
+} // namespace internal
+} // namespace arrow
+
diff --git a/src/arrow/cpp/src/arrow/util/bpacking_simd256_generated.h b/src/arrow/cpp/src/arrow/util/bpacking_simd256_generated.h
new file mode 100644
index 000000000..9fa0ded98
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking_simd256_generated.h
@@ -0,0 +1,1271 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Automatically generated file; DO NOT EDIT.
+
+#pragma once
+
+#include <cstdint>
+#include <cstring>
+
+#include <xsimd/xsimd.hpp>
+
+#include "arrow/util/dispatch.h"
+#include "arrow/util/ubsan.h"
+
+namespace arrow {
+namespace internal {
+namespace {
+
+using ::arrow::util::SafeLoad;
+
+template <DispatchLevel level>
+struct UnpackBits256 {
+
+using simd_arch = xsimd::avx2;
+using simd_batch = xsimd::batch<uint32_t, simd_arch>;
+
+inline static const uint32_t* unpack0_32(const uint32_t* in, uint32_t* out) {
+ memset(out, 0x0, 32 * sizeof(*out));
+ out += 32;
+
+ return in;
+}
+
+inline static const uint32_t* unpack1_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 1-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 0, 1, 2, 3, 4, 5, 6, 7 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 1-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 8, 9, 10, 11, 12, 13, 14, 15 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 1-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 16, 17, 18, 19, 20, 21, 22, 23 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 1-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 24, 25, 26, 27, 28, 29, 30, 31 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 1;
+ return in;
+}
+
+inline static const uint32_t* unpack2_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 2-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 0, 2, 4, 6, 8, 10, 12, 14 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 2-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 16, 18, 20, 22, 24, 26, 28, 30 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 2-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 2, 4, 6, 8, 10, 12, 14 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 2-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 16, 18, 20, 22, 24, 26, 28, 30 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 2;
+ return in;
+}
+
+inline static const uint32_t* unpack3_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 3-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 0, 3, 6, 9, 12, 15, 18, 21 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 3-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 24, 27, 0, 1, 4, 7, 10, 13 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 3-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 31 | SafeLoad<uint32_t>(in + 2) << 1, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 16, 19, 22, 25, 28, 0, 2, 5 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 3-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 8, 11, 14, 17, 20, 23, 26, 29 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 3;
+ return in;
+}
+
+inline static const uint32_t* unpack4_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xf;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 4-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 0, 4, 8, 12, 16, 20, 24, 28 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 4-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 4, 8, 12, 16, 20, 24, 28 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 4-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 0, 4, 8, 12, 16, 20, 24, 28 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 4-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 0, 4, 8, 12, 16, 20, 24, 28 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 4;
+ return in;
+}
+
+inline static const uint32_t* unpack5_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1f;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 5-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 5, 10, 15, 20, 25, 0, 3 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 5-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 8, 13, 18, 23, 0, 1, 6, 11 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 5-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 31 | SafeLoad<uint32_t>(in + 3) << 1, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 16, 21, 26, 0, 4, 9, 14, 19 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 5-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 29 | SafeLoad<uint32_t>(in + 4) << 3, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 24, 0, 2, 7, 12, 17, 22, 27 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 5;
+ return in;
+}
+
+inline static const uint32_t* unpack6_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3f;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 6-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 6, 12, 18, 24, 0, 4, 10 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 6-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 16, 22, 0, 2, 8, 14, 20, 26 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 6-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 0, 6, 12, 18, 24, 0, 4, 10 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 6-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 28 | SafeLoad<uint32_t>(in + 5) << 4, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 16, 22, 0, 2, 8, 14, 20, 26 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 6;
+ return in;
+}
+
+inline static const uint32_t* unpack7_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7f;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 7-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 28 | SafeLoad<uint32_t>(in + 1) << 4, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 7, 14, 21, 0, 3, 10, 17 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 7-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 31 | SafeLoad<uint32_t>(in + 2) << 1, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 27 | SafeLoad<uint32_t>(in + 3) << 5, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 24, 0, 6, 13, 20, 0, 2, 9 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 7-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 26 | SafeLoad<uint32_t>(in + 5) << 6, SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 16, 23, 0, 5, 12, 19, 0, 1 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 7-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 29 | SafeLoad<uint32_t>(in + 6) << 3, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 8, 15, 22, 0, 4, 11, 18, 25 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 7;
+ return in;
+}
+
+inline static const uint32_t* unpack8_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 8-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 8, 16, 24, 0, 8, 16, 24 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 8-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 0, 8, 16, 24, 0, 8, 16, 24 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 8-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 0, 8, 16, 24, 0, 8, 16, 24 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 8-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) };
+ shifts = simd_batch{ 0, 8, 16, 24, 0, 8, 16, 24 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 8;
+ return in;
+}
+
+inline static const uint32_t* unpack9_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1ff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 9-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 27 | SafeLoad<uint32_t>(in + 1) << 5, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 31 | SafeLoad<uint32_t>(in + 2) << 1 };
+ shifts = simd_batch{ 0, 9, 18, 0, 4, 13, 22, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 9-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 26 | SafeLoad<uint32_t>(in + 3) << 6, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 8, 17, 0, 3, 12, 21, 0, 7 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 9-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 25 | SafeLoad<uint32_t>(in + 5) << 7, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 29 | SafeLoad<uint32_t>(in + 6) << 3, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 16, 0, 2, 11, 20, 0, 6, 15 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 9-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6) >> 24 | SafeLoad<uint32_t>(in + 7) << 8, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) };
+ shifts = simd_batch{ 0, 1, 10, 19, 0, 5, 14, 23 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 9;
+ return in;
+}
+
+inline static const uint32_t* unpack10_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3ff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 10-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 0, 10, 20, 0, 8, 18, 0, 6 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 10-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 26 | SafeLoad<uint32_t>(in + 3) << 6, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 16, 0, 4, 14, 0, 2, 12, 22 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 10-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 30 | SafeLoad<uint32_t>(in + 6) << 2, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 28 | SafeLoad<uint32_t>(in + 7) << 4, SafeLoad<uint32_t>(in + 7) };
+ shifts = simd_batch{ 0, 10, 20, 0, 8, 18, 0, 6 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 10-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 26 | SafeLoad<uint32_t>(in + 8) << 6, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 24 | SafeLoad<uint32_t>(in + 9) << 8, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) };
+ shifts = simd_batch{ 16, 0, 4, 14, 0, 2, 12, 22 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 10;
+ return in;
+}
+
+inline static const uint32_t* unpack11_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7ff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 11-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 22 | SafeLoad<uint32_t>(in + 1) << 10, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 23 | SafeLoad<uint32_t>(in + 2) << 9, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 0, 11, 0, 1, 12, 0, 2, 13 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 11-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2) >> 24 | SafeLoad<uint32_t>(in + 3) << 8, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 25 | SafeLoad<uint32_t>(in + 4) << 7, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 26 | SafeLoad<uint32_t>(in + 5) << 6, SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 0, 3, 14, 0, 4, 15, 0, 5 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 11-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 27 | SafeLoad<uint32_t>(in + 6) << 5, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 28 | SafeLoad<uint32_t>(in + 7) << 4, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 29 | SafeLoad<uint32_t>(in + 8) << 3 };
+ shifts = simd_batch{ 16, 0, 6, 17, 0, 7, 18, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 11-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 30 | SafeLoad<uint32_t>(in + 9) << 2, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 31 | SafeLoad<uint32_t>(in + 10) << 1, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) };
+ shifts = simd_batch{ 8, 19, 0, 9, 20, 0, 10, 21 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 11;
+ return in;
+}
+
+inline static const uint32_t* unpack12_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xfff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 12-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 24 | SafeLoad<uint32_t>(in + 1) << 8, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 0, 12, 0, 4, 16, 0, 8, 20 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 12-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 28 | SafeLoad<uint32_t>(in + 5) << 4, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 0, 12, 0, 4, 16, 0, 8, 20 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 12-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 24 | SafeLoad<uint32_t>(in + 7) << 8, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) };
+ shifts = simd_batch{ 0, 12, 0, 4, 16, 0, 8, 20 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 12-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 24 | SafeLoad<uint32_t>(in + 10) << 8, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 28 | SafeLoad<uint32_t>(in + 11) << 4, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) };
+ shifts = simd_batch{ 0, 12, 0, 4, 16, 0, 8, 20 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 12;
+ return in;
+}
+
+inline static const uint32_t* unpack13_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1fff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 13-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 26 | SafeLoad<uint32_t>(in + 1) << 6, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 20 | SafeLoad<uint32_t>(in + 2) << 12, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 27 | SafeLoad<uint32_t>(in + 3) << 5 };
+ shifts = simd_batch{ 0, 13, 0, 7, 0, 1, 14, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 13-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 21 | SafeLoad<uint32_t>(in + 4) << 11, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 28 | SafeLoad<uint32_t>(in + 5) << 4, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 22 | SafeLoad<uint32_t>(in + 6) << 10, SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 8, 0, 2, 15, 0, 9, 0, 3 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 13-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 29 | SafeLoad<uint32_t>(in + 7) << 3, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 23 | SafeLoad<uint32_t>(in + 8) << 9, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 30 | SafeLoad<uint32_t>(in + 9) << 2, SafeLoad<uint32_t>(in + 9) };
+ shifts = simd_batch{ 16, 0, 10, 0, 4, 17, 0, 11 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 13-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 9) >> 24 | SafeLoad<uint32_t>(in + 10) << 8, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 31 | SafeLoad<uint32_t>(in + 11) << 1, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 25 | SafeLoad<uint32_t>(in + 12) << 7, SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) };
+ shifts = simd_batch{ 0, 5, 18, 0, 12, 0, 6, 19 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 13;
+ return in;
+}
+
+inline static const uint32_t* unpack14_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3fff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 14-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 28 | SafeLoad<uint32_t>(in + 1) << 4, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 24 | SafeLoad<uint32_t>(in + 2) << 8, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 20 | SafeLoad<uint32_t>(in + 3) << 12, SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 0, 14, 0, 10, 0, 6, 0, 2 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 14-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 26 | SafeLoad<uint32_t>(in + 5) << 6, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 22 | SafeLoad<uint32_t>(in + 6) << 10, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 16, 0, 12, 0, 8, 0, 4, 18 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 14-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 24 | SafeLoad<uint32_t>(in + 9) << 8, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 20 | SafeLoad<uint32_t>(in + 10) << 12, SafeLoad<uint32_t>(in + 10) };
+ shifts = simd_batch{ 0, 14, 0, 10, 0, 6, 0, 2 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 14-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 30 | SafeLoad<uint32_t>(in + 11) << 2, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 26 | SafeLoad<uint32_t>(in + 12) << 6, SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 22 | SafeLoad<uint32_t>(in + 13) << 10, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) };
+ shifts = simd_batch{ 16, 0, 12, 0, 8, 0, 4, 18 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 14;
+ return in;
+}
+
+inline static const uint32_t* unpack15_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7fff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 15-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 26 | SafeLoad<uint32_t>(in + 3) << 6, SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 0, 15, 0, 13, 0, 11, 0, 9 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 15-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 22 | SafeLoad<uint32_t>(in + 5) << 10, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 20 | SafeLoad<uint32_t>(in + 6) << 12, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 18 | SafeLoad<uint32_t>(in + 7) << 14, SafeLoad<uint32_t>(in + 7) };
+ shifts = simd_batch{ 0, 7, 0, 5, 0, 3, 0, 1 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 15-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 31 | SafeLoad<uint32_t>(in + 8) << 1, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 29 | SafeLoad<uint32_t>(in + 9) << 3, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 27 | SafeLoad<uint32_t>(in + 10) << 5, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 25 | SafeLoad<uint32_t>(in + 11) << 7 };
+ shifts = simd_batch{ 16, 0, 14, 0, 12, 0, 10, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 15-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 23 | SafeLoad<uint32_t>(in + 12) << 9, SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 21 | SafeLoad<uint32_t>(in + 13) << 11, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 19 | SafeLoad<uint32_t>(in + 14) << 13, SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) };
+ shifts = simd_batch{ 8, 0, 6, 0, 4, 0, 2, 17 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 15;
+ return in;
+}
+
+inline static const uint32_t* unpack16_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 16-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 0, 16, 0, 16, 0, 16, 0, 16 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 16-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) };
+ shifts = simd_batch{ 0, 16, 0, 16, 0, 16, 0, 16 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 16-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) };
+ shifts = simd_batch{ 0, 16, 0, 16, 0, 16, 0, 16 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 16-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) };
+ shifts = simd_batch{ 0, 16, 0, 16, 0, 16, 0, 16 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 16;
+ return in;
+}
+
+inline static const uint32_t* unpack17_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1ffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 17-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 17 | SafeLoad<uint32_t>(in + 1) << 15, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 19 | SafeLoad<uint32_t>(in + 2) << 13, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 21 | SafeLoad<uint32_t>(in + 3) << 11, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 23 | SafeLoad<uint32_t>(in + 4) << 9 };
+ shifts = simd_batch{ 0, 0, 2, 0, 4, 0, 6, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 17-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 25 | SafeLoad<uint32_t>(in + 5) << 7, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 27 | SafeLoad<uint32_t>(in + 6) << 5, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 29 | SafeLoad<uint32_t>(in + 7) << 3, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 31 | SafeLoad<uint32_t>(in + 8) << 1 };
+ shifts = simd_batch{ 8, 0, 10, 0, 12, 0, 14, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 17-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 8) >> 16 | SafeLoad<uint32_t>(in + 9) << 16, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 18 | SafeLoad<uint32_t>(in + 10) << 14, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 20 | SafeLoad<uint32_t>(in + 11) << 12, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 22 | SafeLoad<uint32_t>(in + 12) << 10, SafeLoad<uint32_t>(in + 12) };
+ shifts = simd_batch{ 0, 1, 0, 3, 0, 5, 0, 7 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 17-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 12) >> 24 | SafeLoad<uint32_t>(in + 13) << 8, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 26 | SafeLoad<uint32_t>(in + 14) << 6, SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) >> 28 | SafeLoad<uint32_t>(in + 15) << 4, SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 30 | SafeLoad<uint32_t>(in + 16) << 2, SafeLoad<uint32_t>(in + 16) };
+ shifts = simd_batch{ 0, 9, 0, 11, 0, 13, 0, 15 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 17;
+ return in;
+}
+
+inline static const uint32_t* unpack18_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3ffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 18-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 18 | SafeLoad<uint32_t>(in + 1) << 14, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 22 | SafeLoad<uint32_t>(in + 2) << 10, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 26 | SafeLoad<uint32_t>(in + 3) << 6, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2 };
+ shifts = simd_batch{ 0, 0, 4, 0, 8, 0, 12, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 18-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4) >> 16 | SafeLoad<uint32_t>(in + 5) << 16, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 20 | SafeLoad<uint32_t>(in + 6) << 12, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 24 | SafeLoad<uint32_t>(in + 7) << 8, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8) };
+ shifts = simd_batch{ 0, 2, 0, 6, 0, 10, 0, 14 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 18-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 18 | SafeLoad<uint32_t>(in + 10) << 14, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 22 | SafeLoad<uint32_t>(in + 11) << 10, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 26 | SafeLoad<uint32_t>(in + 12) << 6, SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 30 | SafeLoad<uint32_t>(in + 13) << 2 };
+ shifts = simd_batch{ 0, 0, 4, 0, 8, 0, 12, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 18-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 13) >> 16 | SafeLoad<uint32_t>(in + 14) << 16, SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) >> 20 | SafeLoad<uint32_t>(in + 15) << 12, SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 24 | SafeLoad<uint32_t>(in + 16) << 8, SafeLoad<uint32_t>(in + 16), SafeLoad<uint32_t>(in + 16) >> 28 | SafeLoad<uint32_t>(in + 17) << 4, SafeLoad<uint32_t>(in + 17) };
+ shifts = simd_batch{ 0, 2, 0, 6, 0, 10, 0, 14 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 18;
+ return in;
+}
+
+inline static const uint32_t* unpack19_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7ffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 19-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 19 | SafeLoad<uint32_t>(in + 1) << 13, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 25 | SafeLoad<uint32_t>(in + 2) << 7, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 31 | SafeLoad<uint32_t>(in + 3) << 1, SafeLoad<uint32_t>(in + 3) >> 18 | SafeLoad<uint32_t>(in + 4) << 14, SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 0, 0, 6, 0, 12, 0, 0, 5 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 19-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4) >> 24 | SafeLoad<uint32_t>(in + 5) << 8, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 30 | SafeLoad<uint32_t>(in + 6) << 2, SafeLoad<uint32_t>(in + 6) >> 17 | SafeLoad<uint32_t>(in + 7) << 15, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 23 | SafeLoad<uint32_t>(in + 8) << 9, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 29 | SafeLoad<uint32_t>(in + 9) << 3 };
+ shifts = simd_batch{ 0, 11, 0, 0, 4, 0, 10, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 19-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 9) >> 16 | SafeLoad<uint32_t>(in + 10) << 16, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 22 | SafeLoad<uint32_t>(in + 11) << 10, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 28 | SafeLoad<uint32_t>(in + 12) << 4, SafeLoad<uint32_t>(in + 12) >> 15 | SafeLoad<uint32_t>(in + 13) << 17, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 21 | SafeLoad<uint32_t>(in + 14) << 11 };
+ shifts = simd_batch{ 0, 3, 0, 9, 0, 0, 2, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 19-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) >> 27 | SafeLoad<uint32_t>(in + 15) << 5, SafeLoad<uint32_t>(in + 15) >> 14 | SafeLoad<uint32_t>(in + 16) << 18, SafeLoad<uint32_t>(in + 16), SafeLoad<uint32_t>(in + 16) >> 20 | SafeLoad<uint32_t>(in + 17) << 12, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 26 | SafeLoad<uint32_t>(in + 18) << 6, SafeLoad<uint32_t>(in + 18) };
+ shifts = simd_batch{ 8, 0, 0, 1, 0, 7, 0, 13 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 19;
+ return in;
+}
+
+inline static const uint32_t* unpack20_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xfffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 20-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 20 | SafeLoad<uint32_t>(in + 1) << 12, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2) >> 16 | SafeLoad<uint32_t>(in + 3) << 16, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 0, 0, 8, 0, 0, 4, 0, 12 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 20-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 20 | SafeLoad<uint32_t>(in + 6) << 12, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 28 | SafeLoad<uint32_t>(in + 7) << 4, SafeLoad<uint32_t>(in + 7) >> 16 | SafeLoad<uint32_t>(in + 8) << 16, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 24 | SafeLoad<uint32_t>(in + 9) << 8, SafeLoad<uint32_t>(in + 9) };
+ shifts = simd_batch{ 0, 0, 8, 0, 0, 4, 0, 12 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 20-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 20 | SafeLoad<uint32_t>(in + 11) << 12, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 28 | SafeLoad<uint32_t>(in + 12) << 4, SafeLoad<uint32_t>(in + 12) >> 16 | SafeLoad<uint32_t>(in + 13) << 16, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 24 | SafeLoad<uint32_t>(in + 14) << 8, SafeLoad<uint32_t>(in + 14) };
+ shifts = simd_batch{ 0, 0, 8, 0, 0, 4, 0, 12 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 20-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 20 | SafeLoad<uint32_t>(in + 16) << 12, SafeLoad<uint32_t>(in + 16), SafeLoad<uint32_t>(in + 16) >> 28 | SafeLoad<uint32_t>(in + 17) << 4, SafeLoad<uint32_t>(in + 17) >> 16 | SafeLoad<uint32_t>(in + 18) << 16, SafeLoad<uint32_t>(in + 18), SafeLoad<uint32_t>(in + 18) >> 24 | SafeLoad<uint32_t>(in + 19) << 8, SafeLoad<uint32_t>(in + 19) };
+ shifts = simd_batch{ 0, 0, 8, 0, 0, 4, 0, 12 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 20;
+ return in;
+}
+
+inline static const uint32_t* unpack21_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1fffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 21-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 21 | SafeLoad<uint32_t>(in + 1) << 11, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 31 | SafeLoad<uint32_t>(in + 2) << 1, SafeLoad<uint32_t>(in + 2) >> 20 | SafeLoad<uint32_t>(in + 3) << 12, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4) >> 19 | SafeLoad<uint32_t>(in + 5) << 13 };
+ shifts = simd_batch{ 0, 0, 10, 0, 0, 9, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 21-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 29 | SafeLoad<uint32_t>(in + 6) << 3, SafeLoad<uint32_t>(in + 6) >> 18 | SafeLoad<uint32_t>(in + 7) << 14, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8) >> 17 | SafeLoad<uint32_t>(in + 9) << 15, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 27 | SafeLoad<uint32_t>(in + 10) << 5 };
+ shifts = simd_batch{ 8, 0, 0, 7, 0, 0, 6, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 21-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 10) >> 16 | SafeLoad<uint32_t>(in + 11) << 16, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 26 | SafeLoad<uint32_t>(in + 12) << 6, SafeLoad<uint32_t>(in + 12) >> 15 | SafeLoad<uint32_t>(in + 13) << 17, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 25 | SafeLoad<uint32_t>(in + 14) << 7, SafeLoad<uint32_t>(in + 14) >> 14 | SafeLoad<uint32_t>(in + 15) << 18, SafeLoad<uint32_t>(in + 15) };
+ shifts = simd_batch{ 0, 5, 0, 0, 4, 0, 0, 3 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 21-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 15) >> 24 | SafeLoad<uint32_t>(in + 16) << 8, SafeLoad<uint32_t>(in + 16) >> 13 | SafeLoad<uint32_t>(in + 17) << 19, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 23 | SafeLoad<uint32_t>(in + 18) << 9, SafeLoad<uint32_t>(in + 18) >> 12 | SafeLoad<uint32_t>(in + 19) << 20, SafeLoad<uint32_t>(in + 19), SafeLoad<uint32_t>(in + 19) >> 22 | SafeLoad<uint32_t>(in + 20) << 10, SafeLoad<uint32_t>(in + 20) };
+ shifts = simd_batch{ 0, 0, 2, 0, 0, 1, 0, 11 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 21;
+ return in;
+}
+
+inline static const uint32_t* unpack22_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3fffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 22-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 22 | SafeLoad<uint32_t>(in + 1) << 10, SafeLoad<uint32_t>(in + 1) >> 12 | SafeLoad<uint32_t>(in + 2) << 20, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 24 | SafeLoad<uint32_t>(in + 3) << 8, SafeLoad<uint32_t>(in + 3) >> 14 | SafeLoad<uint32_t>(in + 4) << 18, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 26 | SafeLoad<uint32_t>(in + 5) << 6 };
+ shifts = simd_batch{ 0, 0, 0, 2, 0, 0, 4, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 22-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5) >> 16 | SafeLoad<uint32_t>(in + 6) << 16, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 28 | SafeLoad<uint32_t>(in + 7) << 4, SafeLoad<uint32_t>(in + 7) >> 18 | SafeLoad<uint32_t>(in + 8) << 14, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 30 | SafeLoad<uint32_t>(in + 9) << 2, SafeLoad<uint32_t>(in + 9) >> 20 | SafeLoad<uint32_t>(in + 10) << 12, SafeLoad<uint32_t>(in + 10) };
+ shifts = simd_batch{ 0, 6, 0, 0, 8, 0, 0, 10 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 22-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 22 | SafeLoad<uint32_t>(in + 12) << 10, SafeLoad<uint32_t>(in + 12) >> 12 | SafeLoad<uint32_t>(in + 13) << 20, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 24 | SafeLoad<uint32_t>(in + 14) << 8, SafeLoad<uint32_t>(in + 14) >> 14 | SafeLoad<uint32_t>(in + 15) << 18, SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 26 | SafeLoad<uint32_t>(in + 16) << 6 };
+ shifts = simd_batch{ 0, 0, 0, 2, 0, 0, 4, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 22-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 16) >> 16 | SafeLoad<uint32_t>(in + 17) << 16, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 28 | SafeLoad<uint32_t>(in + 18) << 4, SafeLoad<uint32_t>(in + 18) >> 18 | SafeLoad<uint32_t>(in + 19) << 14, SafeLoad<uint32_t>(in + 19), SafeLoad<uint32_t>(in + 19) >> 30 | SafeLoad<uint32_t>(in + 20) << 2, SafeLoad<uint32_t>(in + 20) >> 20 | SafeLoad<uint32_t>(in + 21) << 12, SafeLoad<uint32_t>(in + 21) };
+ shifts = simd_batch{ 0, 6, 0, 0, 8, 0, 0, 10 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 22;
+ return in;
+}
+
+inline static const uint32_t* unpack23_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7fffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 23-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 23 | SafeLoad<uint32_t>(in + 1) << 9, SafeLoad<uint32_t>(in + 1) >> 14 | SafeLoad<uint32_t>(in + 2) << 18, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 28 | SafeLoad<uint32_t>(in + 3) << 4, SafeLoad<uint32_t>(in + 3) >> 19 | SafeLoad<uint32_t>(in + 4) << 13, SafeLoad<uint32_t>(in + 4) >> 10 | SafeLoad<uint32_t>(in + 5) << 22, SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 0, 0, 0, 5, 0, 0, 0, 1 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 23-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5) >> 24 | SafeLoad<uint32_t>(in + 6) << 8, SafeLoad<uint32_t>(in + 6) >> 15 | SafeLoad<uint32_t>(in + 7) << 17, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 29 | SafeLoad<uint32_t>(in + 8) << 3, SafeLoad<uint32_t>(in + 8) >> 20 | SafeLoad<uint32_t>(in + 9) << 12, SafeLoad<uint32_t>(in + 9) >> 11 | SafeLoad<uint32_t>(in + 10) << 21, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 25 | SafeLoad<uint32_t>(in + 11) << 7 };
+ shifts = simd_batch{ 0, 0, 6, 0, 0, 0, 2, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 23-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 11) >> 16 | SafeLoad<uint32_t>(in + 12) << 16, SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 30 | SafeLoad<uint32_t>(in + 13) << 2, SafeLoad<uint32_t>(in + 13) >> 21 | SafeLoad<uint32_t>(in + 14) << 11, SafeLoad<uint32_t>(in + 14) >> 12 | SafeLoad<uint32_t>(in + 15) << 20, SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 26 | SafeLoad<uint32_t>(in + 16) << 6, SafeLoad<uint32_t>(in + 16) >> 17 | SafeLoad<uint32_t>(in + 17) << 15 };
+ shifts = simd_batch{ 0, 7, 0, 0, 0, 3, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 23-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 31 | SafeLoad<uint32_t>(in + 18) << 1, SafeLoad<uint32_t>(in + 18) >> 22 | SafeLoad<uint32_t>(in + 19) << 10, SafeLoad<uint32_t>(in + 19) >> 13 | SafeLoad<uint32_t>(in + 20) << 19, SafeLoad<uint32_t>(in + 20), SafeLoad<uint32_t>(in + 20) >> 27 | SafeLoad<uint32_t>(in + 21) << 5, SafeLoad<uint32_t>(in + 21) >> 18 | SafeLoad<uint32_t>(in + 22) << 14, SafeLoad<uint32_t>(in + 22) };
+ shifts = simd_batch{ 8, 0, 0, 0, 4, 0, 0, 9 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 23;
+ return in;
+}
+
+inline static const uint32_t* unpack24_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 24-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 24 | SafeLoad<uint32_t>(in + 1) << 8, SafeLoad<uint32_t>(in + 1) >> 16 | SafeLoad<uint32_t>(in + 2) << 16, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4) >> 16 | SafeLoad<uint32_t>(in + 5) << 16, SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 0, 0, 0, 8, 0, 0, 0, 8 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 24-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 24 | SafeLoad<uint32_t>(in + 7) << 8, SafeLoad<uint32_t>(in + 7) >> 16 | SafeLoad<uint32_t>(in + 8) << 16, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 24 | SafeLoad<uint32_t>(in + 10) << 8, SafeLoad<uint32_t>(in + 10) >> 16 | SafeLoad<uint32_t>(in + 11) << 16, SafeLoad<uint32_t>(in + 11) };
+ shifts = simd_batch{ 0, 0, 0, 8, 0, 0, 0, 8 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 24-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 24 | SafeLoad<uint32_t>(in + 13) << 8, SafeLoad<uint32_t>(in + 13) >> 16 | SafeLoad<uint32_t>(in + 14) << 16, SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 24 | SafeLoad<uint32_t>(in + 16) << 8, SafeLoad<uint32_t>(in + 16) >> 16 | SafeLoad<uint32_t>(in + 17) << 16, SafeLoad<uint32_t>(in + 17) };
+ shifts = simd_batch{ 0, 0, 0, 8, 0, 0, 0, 8 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 24-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 18), SafeLoad<uint32_t>(in + 18) >> 24 | SafeLoad<uint32_t>(in + 19) << 8, SafeLoad<uint32_t>(in + 19) >> 16 | SafeLoad<uint32_t>(in + 20) << 16, SafeLoad<uint32_t>(in + 20), SafeLoad<uint32_t>(in + 21), SafeLoad<uint32_t>(in + 21) >> 24 | SafeLoad<uint32_t>(in + 22) << 8, SafeLoad<uint32_t>(in + 22) >> 16 | SafeLoad<uint32_t>(in + 23) << 16, SafeLoad<uint32_t>(in + 23) };
+ shifts = simd_batch{ 0, 0, 0, 8, 0, 0, 0, 8 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 24;
+ return in;
+}
+
+inline static const uint32_t* unpack25_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1ffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 25-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 25 | SafeLoad<uint32_t>(in + 1) << 7, SafeLoad<uint32_t>(in + 1) >> 18 | SafeLoad<uint32_t>(in + 2) << 14, SafeLoad<uint32_t>(in + 2) >> 11 | SafeLoad<uint32_t>(in + 3) << 21, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 29 | SafeLoad<uint32_t>(in + 4) << 3, SafeLoad<uint32_t>(in + 4) >> 22 | SafeLoad<uint32_t>(in + 5) << 10, SafeLoad<uint32_t>(in + 5) >> 15 | SafeLoad<uint32_t>(in + 6) << 17 };
+ shifts = simd_batch{ 0, 0, 0, 0, 4, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 25-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6) >> 8 | SafeLoad<uint32_t>(in + 7) << 24, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 26 | SafeLoad<uint32_t>(in + 8) << 6, SafeLoad<uint32_t>(in + 8) >> 19 | SafeLoad<uint32_t>(in + 9) << 13, SafeLoad<uint32_t>(in + 9) >> 12 | SafeLoad<uint32_t>(in + 10) << 20, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 30 | SafeLoad<uint32_t>(in + 11) << 2, SafeLoad<uint32_t>(in + 11) >> 23 | SafeLoad<uint32_t>(in + 12) << 9 };
+ shifts = simd_batch{ 0, 1, 0, 0, 0, 5, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 25-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 12) >> 16 | SafeLoad<uint32_t>(in + 13) << 16, SafeLoad<uint32_t>(in + 13) >> 9 | SafeLoad<uint32_t>(in + 14) << 23, SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) >> 27 | SafeLoad<uint32_t>(in + 15) << 5, SafeLoad<uint32_t>(in + 15) >> 20 | SafeLoad<uint32_t>(in + 16) << 12, SafeLoad<uint32_t>(in + 16) >> 13 | SafeLoad<uint32_t>(in + 17) << 19, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 31 | SafeLoad<uint32_t>(in + 18) << 1 };
+ shifts = simd_batch{ 0, 0, 2, 0, 0, 0, 6, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 25-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 18) >> 24 | SafeLoad<uint32_t>(in + 19) << 8, SafeLoad<uint32_t>(in + 19) >> 17 | SafeLoad<uint32_t>(in + 20) << 15, SafeLoad<uint32_t>(in + 20) >> 10 | SafeLoad<uint32_t>(in + 21) << 22, SafeLoad<uint32_t>(in + 21), SafeLoad<uint32_t>(in + 21) >> 28 | SafeLoad<uint32_t>(in + 22) << 4, SafeLoad<uint32_t>(in + 22) >> 21 | SafeLoad<uint32_t>(in + 23) << 11, SafeLoad<uint32_t>(in + 23) >> 14 | SafeLoad<uint32_t>(in + 24) << 18, SafeLoad<uint32_t>(in + 24) };
+ shifts = simd_batch{ 0, 0, 0, 3, 0, 0, 0, 7 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 25;
+ return in;
+}
+
+inline static const uint32_t* unpack26_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3ffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 26-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 26 | SafeLoad<uint32_t>(in + 1) << 6, SafeLoad<uint32_t>(in + 1) >> 20 | SafeLoad<uint32_t>(in + 2) << 12, SafeLoad<uint32_t>(in + 2) >> 14 | SafeLoad<uint32_t>(in + 3) << 18, SafeLoad<uint32_t>(in + 3) >> 8 | SafeLoad<uint32_t>(in + 4) << 24, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 28 | SafeLoad<uint32_t>(in + 5) << 4, SafeLoad<uint32_t>(in + 5) >> 22 | SafeLoad<uint32_t>(in + 6) << 10 };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 2, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 26-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6) >> 16 | SafeLoad<uint32_t>(in + 7) << 16, SafeLoad<uint32_t>(in + 7) >> 10 | SafeLoad<uint32_t>(in + 8) << 22, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 30 | SafeLoad<uint32_t>(in + 9) << 2, SafeLoad<uint32_t>(in + 9) >> 24 | SafeLoad<uint32_t>(in + 10) << 8, SafeLoad<uint32_t>(in + 10) >> 18 | SafeLoad<uint32_t>(in + 11) << 14, SafeLoad<uint32_t>(in + 11) >> 12 | SafeLoad<uint32_t>(in + 12) << 20, SafeLoad<uint32_t>(in + 12) };
+ shifts = simd_batch{ 0, 0, 4, 0, 0, 0, 0, 6 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 26-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 26 | SafeLoad<uint32_t>(in + 14) << 6, SafeLoad<uint32_t>(in + 14) >> 20 | SafeLoad<uint32_t>(in + 15) << 12, SafeLoad<uint32_t>(in + 15) >> 14 | SafeLoad<uint32_t>(in + 16) << 18, SafeLoad<uint32_t>(in + 16) >> 8 | SafeLoad<uint32_t>(in + 17) << 24, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 28 | SafeLoad<uint32_t>(in + 18) << 4, SafeLoad<uint32_t>(in + 18) >> 22 | SafeLoad<uint32_t>(in + 19) << 10 };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 2, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 26-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 19) >> 16 | SafeLoad<uint32_t>(in + 20) << 16, SafeLoad<uint32_t>(in + 20) >> 10 | SafeLoad<uint32_t>(in + 21) << 22, SafeLoad<uint32_t>(in + 21), SafeLoad<uint32_t>(in + 21) >> 30 | SafeLoad<uint32_t>(in + 22) << 2, SafeLoad<uint32_t>(in + 22) >> 24 | SafeLoad<uint32_t>(in + 23) << 8, SafeLoad<uint32_t>(in + 23) >> 18 | SafeLoad<uint32_t>(in + 24) << 14, SafeLoad<uint32_t>(in + 24) >> 12 | SafeLoad<uint32_t>(in + 25) << 20, SafeLoad<uint32_t>(in + 25) };
+ shifts = simd_batch{ 0, 0, 4, 0, 0, 0, 0, 6 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 26;
+ return in;
+}
+
+inline static const uint32_t* unpack27_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7ffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 27-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 27 | SafeLoad<uint32_t>(in + 1) << 5, SafeLoad<uint32_t>(in + 1) >> 22 | SafeLoad<uint32_t>(in + 2) << 10, SafeLoad<uint32_t>(in + 2) >> 17 | SafeLoad<uint32_t>(in + 3) << 15, SafeLoad<uint32_t>(in + 3) >> 12 | SafeLoad<uint32_t>(in + 4) << 20, SafeLoad<uint32_t>(in + 4) >> 7 | SafeLoad<uint32_t>(in + 5) << 25, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 29 | SafeLoad<uint32_t>(in + 6) << 3 };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 2, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 27-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6) >> 24 | SafeLoad<uint32_t>(in + 7) << 8, SafeLoad<uint32_t>(in + 7) >> 19 | SafeLoad<uint32_t>(in + 8) << 13, SafeLoad<uint32_t>(in + 8) >> 14 | SafeLoad<uint32_t>(in + 9) << 18, SafeLoad<uint32_t>(in + 9) >> 9 | SafeLoad<uint32_t>(in + 10) << 23, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 31 | SafeLoad<uint32_t>(in + 11) << 1, SafeLoad<uint32_t>(in + 11) >> 26 | SafeLoad<uint32_t>(in + 12) << 6, SafeLoad<uint32_t>(in + 12) >> 21 | SafeLoad<uint32_t>(in + 13) << 11 };
+ shifts = simd_batch{ 0, 0, 0, 0, 4, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 27-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 13) >> 16 | SafeLoad<uint32_t>(in + 14) << 16, SafeLoad<uint32_t>(in + 14) >> 11 | SafeLoad<uint32_t>(in + 15) << 21, SafeLoad<uint32_t>(in + 15) >> 6 | SafeLoad<uint32_t>(in + 16) << 26, SafeLoad<uint32_t>(in + 16), SafeLoad<uint32_t>(in + 16) >> 28 | SafeLoad<uint32_t>(in + 17) << 4, SafeLoad<uint32_t>(in + 17) >> 23 | SafeLoad<uint32_t>(in + 18) << 9, SafeLoad<uint32_t>(in + 18) >> 18 | SafeLoad<uint32_t>(in + 19) << 14, SafeLoad<uint32_t>(in + 19) >> 13 | SafeLoad<uint32_t>(in + 20) << 19 };
+ shifts = simd_batch{ 0, 0, 0, 1, 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 27-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 20) >> 8 | SafeLoad<uint32_t>(in + 21) << 24, SafeLoad<uint32_t>(in + 21), SafeLoad<uint32_t>(in + 21) >> 30 | SafeLoad<uint32_t>(in + 22) << 2, SafeLoad<uint32_t>(in + 22) >> 25 | SafeLoad<uint32_t>(in + 23) << 7, SafeLoad<uint32_t>(in + 23) >> 20 | SafeLoad<uint32_t>(in + 24) << 12, SafeLoad<uint32_t>(in + 24) >> 15 | SafeLoad<uint32_t>(in + 25) << 17, SafeLoad<uint32_t>(in + 25) >> 10 | SafeLoad<uint32_t>(in + 26) << 22, SafeLoad<uint32_t>(in + 26) };
+ shifts = simd_batch{ 0, 3, 0, 0, 0, 0, 0, 5 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 27;
+ return in;
+}
+
+inline static const uint32_t* unpack28_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xfffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 28-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 28 | SafeLoad<uint32_t>(in + 1) << 4, SafeLoad<uint32_t>(in + 1) >> 24 | SafeLoad<uint32_t>(in + 2) << 8, SafeLoad<uint32_t>(in + 2) >> 20 | SafeLoad<uint32_t>(in + 3) << 12, SafeLoad<uint32_t>(in + 3) >> 16 | SafeLoad<uint32_t>(in + 4) << 16, SafeLoad<uint32_t>(in + 4) >> 12 | SafeLoad<uint32_t>(in + 5) << 20, SafeLoad<uint32_t>(in + 5) >> 8 | SafeLoad<uint32_t>(in + 6) << 24, SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 4 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 28-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8) >> 24 | SafeLoad<uint32_t>(in + 9) << 8, SafeLoad<uint32_t>(in + 9) >> 20 | SafeLoad<uint32_t>(in + 10) << 12, SafeLoad<uint32_t>(in + 10) >> 16 | SafeLoad<uint32_t>(in + 11) << 16, SafeLoad<uint32_t>(in + 11) >> 12 | SafeLoad<uint32_t>(in + 12) << 20, SafeLoad<uint32_t>(in + 12) >> 8 | SafeLoad<uint32_t>(in + 13) << 24, SafeLoad<uint32_t>(in + 13) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 4 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 28-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) >> 28 | SafeLoad<uint32_t>(in + 15) << 4, SafeLoad<uint32_t>(in + 15) >> 24 | SafeLoad<uint32_t>(in + 16) << 8, SafeLoad<uint32_t>(in + 16) >> 20 | SafeLoad<uint32_t>(in + 17) << 12, SafeLoad<uint32_t>(in + 17) >> 16 | SafeLoad<uint32_t>(in + 18) << 16, SafeLoad<uint32_t>(in + 18) >> 12 | SafeLoad<uint32_t>(in + 19) << 20, SafeLoad<uint32_t>(in + 19) >> 8 | SafeLoad<uint32_t>(in + 20) << 24, SafeLoad<uint32_t>(in + 20) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 4 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 28-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 21), SafeLoad<uint32_t>(in + 21) >> 28 | SafeLoad<uint32_t>(in + 22) << 4, SafeLoad<uint32_t>(in + 22) >> 24 | SafeLoad<uint32_t>(in + 23) << 8, SafeLoad<uint32_t>(in + 23) >> 20 | SafeLoad<uint32_t>(in + 24) << 12, SafeLoad<uint32_t>(in + 24) >> 16 | SafeLoad<uint32_t>(in + 25) << 16, SafeLoad<uint32_t>(in + 25) >> 12 | SafeLoad<uint32_t>(in + 26) << 20, SafeLoad<uint32_t>(in + 26) >> 8 | SafeLoad<uint32_t>(in + 27) << 24, SafeLoad<uint32_t>(in + 27) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 4 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 28;
+ return in;
+}
+
+inline static const uint32_t* unpack29_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1fffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 29-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 29 | SafeLoad<uint32_t>(in + 1) << 3, SafeLoad<uint32_t>(in + 1) >> 26 | SafeLoad<uint32_t>(in + 2) << 6, SafeLoad<uint32_t>(in + 2) >> 23 | SafeLoad<uint32_t>(in + 3) << 9, SafeLoad<uint32_t>(in + 3) >> 20 | SafeLoad<uint32_t>(in + 4) << 12, SafeLoad<uint32_t>(in + 4) >> 17 | SafeLoad<uint32_t>(in + 5) << 15, SafeLoad<uint32_t>(in + 5) >> 14 | SafeLoad<uint32_t>(in + 6) << 18, SafeLoad<uint32_t>(in + 6) >> 11 | SafeLoad<uint32_t>(in + 7) << 21 };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 29-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7) >> 8 | SafeLoad<uint32_t>(in + 8) << 24, SafeLoad<uint32_t>(in + 8) >> 5 | SafeLoad<uint32_t>(in + 9) << 27, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 31 | SafeLoad<uint32_t>(in + 10) << 1, SafeLoad<uint32_t>(in + 10) >> 28 | SafeLoad<uint32_t>(in + 11) << 4, SafeLoad<uint32_t>(in + 11) >> 25 | SafeLoad<uint32_t>(in + 12) << 7, SafeLoad<uint32_t>(in + 12) >> 22 | SafeLoad<uint32_t>(in + 13) << 10, SafeLoad<uint32_t>(in + 13) >> 19 | SafeLoad<uint32_t>(in + 14) << 13 };
+ shifts = simd_batch{ 0, 0, 2, 0, 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 29-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 14) >> 16 | SafeLoad<uint32_t>(in + 15) << 16, SafeLoad<uint32_t>(in + 15) >> 13 | SafeLoad<uint32_t>(in + 16) << 19, SafeLoad<uint32_t>(in + 16) >> 10 | SafeLoad<uint32_t>(in + 17) << 22, SafeLoad<uint32_t>(in + 17) >> 7 | SafeLoad<uint32_t>(in + 18) << 25, SafeLoad<uint32_t>(in + 18) >> 4 | SafeLoad<uint32_t>(in + 19) << 28, SafeLoad<uint32_t>(in + 19), SafeLoad<uint32_t>(in + 19) >> 30 | SafeLoad<uint32_t>(in + 20) << 2, SafeLoad<uint32_t>(in + 20) >> 27 | SafeLoad<uint32_t>(in + 21) << 5 };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 1, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 29-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 21) >> 24 | SafeLoad<uint32_t>(in + 22) << 8, SafeLoad<uint32_t>(in + 22) >> 21 | SafeLoad<uint32_t>(in + 23) << 11, SafeLoad<uint32_t>(in + 23) >> 18 | SafeLoad<uint32_t>(in + 24) << 14, SafeLoad<uint32_t>(in + 24) >> 15 | SafeLoad<uint32_t>(in + 25) << 17, SafeLoad<uint32_t>(in + 25) >> 12 | SafeLoad<uint32_t>(in + 26) << 20, SafeLoad<uint32_t>(in + 26) >> 9 | SafeLoad<uint32_t>(in + 27) << 23, SafeLoad<uint32_t>(in + 27) >> 6 | SafeLoad<uint32_t>(in + 28) << 26, SafeLoad<uint32_t>(in + 28) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 3 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 29;
+ return in;
+}
+
+inline static const uint32_t* unpack30_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3fffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 30-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2) >> 26 | SafeLoad<uint32_t>(in + 3) << 6, SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4) >> 22 | SafeLoad<uint32_t>(in + 5) << 10, SafeLoad<uint32_t>(in + 5) >> 20 | SafeLoad<uint32_t>(in + 6) << 12, SafeLoad<uint32_t>(in + 6) >> 18 | SafeLoad<uint32_t>(in + 7) << 14 };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 30-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7) >> 16 | SafeLoad<uint32_t>(in + 8) << 16, SafeLoad<uint32_t>(in + 8) >> 14 | SafeLoad<uint32_t>(in + 9) << 18, SafeLoad<uint32_t>(in + 9) >> 12 | SafeLoad<uint32_t>(in + 10) << 20, SafeLoad<uint32_t>(in + 10) >> 10 | SafeLoad<uint32_t>(in + 11) << 22, SafeLoad<uint32_t>(in + 11) >> 8 | SafeLoad<uint32_t>(in + 12) << 24, SafeLoad<uint32_t>(in + 12) >> 6 | SafeLoad<uint32_t>(in + 13) << 26, SafeLoad<uint32_t>(in + 13) >> 4 | SafeLoad<uint32_t>(in + 14) << 28, SafeLoad<uint32_t>(in + 14) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 2 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 30-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 30 | SafeLoad<uint32_t>(in + 16) << 2, SafeLoad<uint32_t>(in + 16) >> 28 | SafeLoad<uint32_t>(in + 17) << 4, SafeLoad<uint32_t>(in + 17) >> 26 | SafeLoad<uint32_t>(in + 18) << 6, SafeLoad<uint32_t>(in + 18) >> 24 | SafeLoad<uint32_t>(in + 19) << 8, SafeLoad<uint32_t>(in + 19) >> 22 | SafeLoad<uint32_t>(in + 20) << 10, SafeLoad<uint32_t>(in + 20) >> 20 | SafeLoad<uint32_t>(in + 21) << 12, SafeLoad<uint32_t>(in + 21) >> 18 | SafeLoad<uint32_t>(in + 22) << 14 };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 30-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 22) >> 16 | SafeLoad<uint32_t>(in + 23) << 16, SafeLoad<uint32_t>(in + 23) >> 14 | SafeLoad<uint32_t>(in + 24) << 18, SafeLoad<uint32_t>(in + 24) >> 12 | SafeLoad<uint32_t>(in + 25) << 20, SafeLoad<uint32_t>(in + 25) >> 10 | SafeLoad<uint32_t>(in + 26) << 22, SafeLoad<uint32_t>(in + 26) >> 8 | SafeLoad<uint32_t>(in + 27) << 24, SafeLoad<uint32_t>(in + 27) >> 6 | SafeLoad<uint32_t>(in + 28) << 26, SafeLoad<uint32_t>(in + 28) >> 4 | SafeLoad<uint32_t>(in + 29) << 28, SafeLoad<uint32_t>(in + 29) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 2 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 30;
+ return in;
+}
+
+inline static const uint32_t* unpack31_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7fffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 31-bit bundles 0 to 7
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 31 | SafeLoad<uint32_t>(in + 1) << 1, SafeLoad<uint32_t>(in + 1) >> 30 | SafeLoad<uint32_t>(in + 2) << 2, SafeLoad<uint32_t>(in + 2) >> 29 | SafeLoad<uint32_t>(in + 3) << 3, SafeLoad<uint32_t>(in + 3) >> 28 | SafeLoad<uint32_t>(in + 4) << 4, SafeLoad<uint32_t>(in + 4) >> 27 | SafeLoad<uint32_t>(in + 5) << 5, SafeLoad<uint32_t>(in + 5) >> 26 | SafeLoad<uint32_t>(in + 6) << 6, SafeLoad<uint32_t>(in + 6) >> 25 | SafeLoad<uint32_t>(in + 7) << 7 };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 31-bit bundles 8 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7) >> 24 | SafeLoad<uint32_t>(in + 8) << 8, SafeLoad<uint32_t>(in + 8) >> 23 | SafeLoad<uint32_t>(in + 9) << 9, SafeLoad<uint32_t>(in + 9) >> 22 | SafeLoad<uint32_t>(in + 10) << 10, SafeLoad<uint32_t>(in + 10) >> 21 | SafeLoad<uint32_t>(in + 11) << 11, SafeLoad<uint32_t>(in + 11) >> 20 | SafeLoad<uint32_t>(in + 12) << 12, SafeLoad<uint32_t>(in + 12) >> 19 | SafeLoad<uint32_t>(in + 13) << 13, SafeLoad<uint32_t>(in + 13) >> 18 | SafeLoad<uint32_t>(in + 14) << 14, SafeLoad<uint32_t>(in + 14) >> 17 | SafeLoad<uint32_t>(in + 15) << 15 };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 31-bit bundles 16 to 23
+ words = simd_batch{ SafeLoad<uint32_t>(in + 15) >> 16 | SafeLoad<uint32_t>(in + 16) << 16, SafeLoad<uint32_t>(in + 16) >> 15 | SafeLoad<uint32_t>(in + 17) << 17, SafeLoad<uint32_t>(in + 17) >> 14 | SafeLoad<uint32_t>(in + 18) << 18, SafeLoad<uint32_t>(in + 18) >> 13 | SafeLoad<uint32_t>(in + 19) << 19, SafeLoad<uint32_t>(in + 19) >> 12 | SafeLoad<uint32_t>(in + 20) << 20, SafeLoad<uint32_t>(in + 20) >> 11 | SafeLoad<uint32_t>(in + 21) << 21, SafeLoad<uint32_t>(in + 21) >> 10 | SafeLoad<uint32_t>(in + 22) << 22, SafeLoad<uint32_t>(in + 22) >> 9 | SafeLoad<uint32_t>(in + 23) << 23 };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ // extract 31-bit bundles 24 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 23) >> 8 | SafeLoad<uint32_t>(in + 24) << 24, SafeLoad<uint32_t>(in + 24) >> 7 | SafeLoad<uint32_t>(in + 25) << 25, SafeLoad<uint32_t>(in + 25) >> 6 | SafeLoad<uint32_t>(in + 26) << 26, SafeLoad<uint32_t>(in + 26) >> 5 | SafeLoad<uint32_t>(in + 27) << 27, SafeLoad<uint32_t>(in + 27) >> 4 | SafeLoad<uint32_t>(in + 28) << 28, SafeLoad<uint32_t>(in + 28) >> 3 | SafeLoad<uint32_t>(in + 29) << 29, SafeLoad<uint32_t>(in + 29) >> 2 | SafeLoad<uint32_t>(in + 30) << 30, SafeLoad<uint32_t>(in + 30) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 1 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 8;
+
+ in += 31;
+ return in;
+}
+
+inline static const uint32_t* unpack32_32(const uint32_t* in, uint32_t* out) {
+ memcpy(out, in, 32 * sizeof(*out));
+ in += 32;
+ out += 32;
+
+ return in;
+}
+
+}; // struct UnpackBits256
+
+} // namespace
+} // namespace internal
+} // namespace arrow
+
diff --git a/src/arrow/cpp/src/arrow/util/bpacking_simd512_generated.h b/src/arrow/cpp/src/arrow/util/bpacking_simd512_generated.h
new file mode 100644
index 000000000..d5d643878
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking_simd512_generated.h
@@ -0,0 +1,837 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Automatically generated file; DO NOT EDIT.
+
+#pragma once
+
+#include <cstdint>
+#include <cstring>
+
+#include <xsimd/xsimd.hpp>
+
+#include "arrow/util/dispatch.h"
+#include "arrow/util/ubsan.h"
+
+namespace arrow {
+namespace internal {
+namespace {
+
+using ::arrow::util::SafeLoad;
+
+template <DispatchLevel level>
+struct UnpackBits512 {
+
+using simd_arch = xsimd::avx512bw;
+using simd_batch = xsimd::batch<uint32_t, simd_arch>;
+
+inline static const uint32_t* unpack0_32(const uint32_t* in, uint32_t* out) {
+ memset(out, 0x0, 32 * sizeof(*out));
+ out += 32;
+
+ return in;
+}
+
+inline static const uint32_t* unpack1_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 1-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 1-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 1;
+ return in;
+}
+
+inline static const uint32_t* unpack2_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 2-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) };
+ shifts = simd_batch{ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 2-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 2;
+ return in;
+}
+
+inline static const uint32_t* unpack3_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 3-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 0, 1, 4, 7, 10, 13 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 3-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 31 | SafeLoad<uint32_t>(in + 2) << 1, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 16, 19, 22, 25, 28, 0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 3;
+ return in;
+}
+
+inline static const uint32_t* unpack4_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xf;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 4-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) };
+ shifts = simd_batch{ 0, 4, 8, 12, 16, 20, 24, 28, 0, 4, 8, 12, 16, 20, 24, 28 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 4-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 0, 4, 8, 12, 16, 20, 24, 28, 0, 4, 8, 12, 16, 20, 24, 28 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 4;
+ return in;
+}
+
+inline static const uint32_t* unpack5_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1f;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 5-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 0, 5, 10, 15, 20, 25, 0, 3, 8, 13, 18, 23, 0, 1, 6, 11 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 5-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 31 | SafeLoad<uint32_t>(in + 3) << 1, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 29 | SafeLoad<uint32_t>(in + 4) << 3, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 16, 21, 26, 0, 4, 9, 14, 19, 24, 0, 2, 7, 12, 17, 22, 27 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 5;
+ return in;
+}
+
+inline static const uint32_t* unpack6_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3f;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 6-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) };
+ shifts = simd_batch{ 0, 6, 12, 18, 24, 0, 4, 10, 16, 22, 0, 2, 8, 14, 20, 26 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 6-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 28 | SafeLoad<uint32_t>(in + 5) << 4, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 0, 6, 12, 18, 24, 0, 4, 10, 16, 22, 0, 2, 8, 14, 20, 26 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 6;
+ return in;
+}
+
+inline static const uint32_t* unpack7_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7f;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 7-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 28 | SafeLoad<uint32_t>(in + 1) << 4, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 31 | SafeLoad<uint32_t>(in + 2) << 1, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 27 | SafeLoad<uint32_t>(in + 3) << 5, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 0, 7, 14, 21, 0, 3, 10, 17, 24, 0, 6, 13, 20, 0, 2, 9 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 7-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 26 | SafeLoad<uint32_t>(in + 5) << 6, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 29 | SafeLoad<uint32_t>(in + 6) << 3, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 16, 23, 0, 5, 12, 19, 0, 1, 8, 15, 22, 0, 4, 11, 18, 25 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 7;
+ return in;
+}
+
+inline static const uint32_t* unpack8_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 8-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) };
+ shifts = simd_batch{ 0, 8, 16, 24, 0, 8, 16, 24, 0, 8, 16, 24, 0, 8, 16, 24 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 8-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) };
+ shifts = simd_batch{ 0, 8, 16, 24, 0, 8, 16, 24, 0, 8, 16, 24, 0, 8, 16, 24 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 8;
+ return in;
+}
+
+inline static const uint32_t* unpack9_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1ff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 9-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 27 | SafeLoad<uint32_t>(in + 1) << 5, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 31 | SafeLoad<uint32_t>(in + 2) << 1, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 26 | SafeLoad<uint32_t>(in + 3) << 6, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 0, 9, 18, 0, 4, 13, 22, 0, 8, 17, 0, 3, 12, 21, 0, 7 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 9-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 25 | SafeLoad<uint32_t>(in + 5) << 7, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 29 | SafeLoad<uint32_t>(in + 6) << 3, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 24 | SafeLoad<uint32_t>(in + 7) << 8, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) };
+ shifts = simd_batch{ 16, 0, 2, 11, 20, 0, 6, 15, 0, 1, 10, 19, 0, 5, 14, 23 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 9;
+ return in;
+}
+
+inline static const uint32_t* unpack10_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3ff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 10-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 26 | SafeLoad<uint32_t>(in + 3) << 6, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) };
+ shifts = simd_batch{ 0, 10, 20, 0, 8, 18, 0, 6, 16, 0, 4, 14, 0, 2, 12, 22 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 10-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 30 | SafeLoad<uint32_t>(in + 6) << 2, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 28 | SafeLoad<uint32_t>(in + 7) << 4, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 26 | SafeLoad<uint32_t>(in + 8) << 6, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 24 | SafeLoad<uint32_t>(in + 9) << 8, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) };
+ shifts = simd_batch{ 0, 10, 20, 0, 8, 18, 0, 6, 16, 0, 4, 14, 0, 2, 12, 22 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 10;
+ return in;
+}
+
+inline static const uint32_t* unpack11_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7ff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 11-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 22 | SafeLoad<uint32_t>(in + 1) << 10, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 23 | SafeLoad<uint32_t>(in + 2) << 9, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 24 | SafeLoad<uint32_t>(in + 3) << 8, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 25 | SafeLoad<uint32_t>(in + 4) << 7, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 26 | SafeLoad<uint32_t>(in + 5) << 6, SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 0, 11, 0, 1, 12, 0, 2, 13, 0, 3, 14, 0, 4, 15, 0, 5 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 11-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 27 | SafeLoad<uint32_t>(in + 6) << 5, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 28 | SafeLoad<uint32_t>(in + 7) << 4, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 29 | SafeLoad<uint32_t>(in + 8) << 3, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 30 | SafeLoad<uint32_t>(in + 9) << 2, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 31 | SafeLoad<uint32_t>(in + 10) << 1, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) };
+ shifts = simd_batch{ 16, 0, 6, 17, 0, 7, 18, 0, 8, 19, 0, 9, 20, 0, 10, 21 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 11;
+ return in;
+}
+
+inline static const uint32_t* unpack12_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xfff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 12-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 24 | SafeLoad<uint32_t>(in + 1) << 8, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 28 | SafeLoad<uint32_t>(in + 5) << 4, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) };
+ shifts = simd_batch{ 0, 12, 0, 4, 16, 0, 8, 20, 0, 12, 0, 4, 16, 0, 8, 20 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 12-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 24 | SafeLoad<uint32_t>(in + 7) << 8, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 24 | SafeLoad<uint32_t>(in + 10) << 8, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 28 | SafeLoad<uint32_t>(in + 11) << 4, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) };
+ shifts = simd_batch{ 0, 12, 0, 4, 16, 0, 8, 20, 0, 12, 0, 4, 16, 0, 8, 20 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 12;
+ return in;
+}
+
+inline static const uint32_t* unpack13_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1fff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 13-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 26 | SafeLoad<uint32_t>(in + 1) << 6, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 20 | SafeLoad<uint32_t>(in + 2) << 12, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 27 | SafeLoad<uint32_t>(in + 3) << 5, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 21 | SafeLoad<uint32_t>(in + 4) << 11, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 28 | SafeLoad<uint32_t>(in + 5) << 4, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 22 | SafeLoad<uint32_t>(in + 6) << 10, SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 0, 13, 0, 7, 0, 1, 14, 0, 8, 0, 2, 15, 0, 9, 0, 3 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 13-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 29 | SafeLoad<uint32_t>(in + 7) << 3, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 23 | SafeLoad<uint32_t>(in + 8) << 9, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 30 | SafeLoad<uint32_t>(in + 9) << 2, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 24 | SafeLoad<uint32_t>(in + 10) << 8, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 31 | SafeLoad<uint32_t>(in + 11) << 1, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 25 | SafeLoad<uint32_t>(in + 12) << 7, SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) };
+ shifts = simd_batch{ 16, 0, 10, 0, 4, 17, 0, 11, 0, 5, 18, 0, 12, 0, 6, 19 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 13;
+ return in;
+}
+
+inline static const uint32_t* unpack14_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3fff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 14-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 28 | SafeLoad<uint32_t>(in + 1) << 4, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 24 | SafeLoad<uint32_t>(in + 2) << 8, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 20 | SafeLoad<uint32_t>(in + 3) << 12, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 26 | SafeLoad<uint32_t>(in + 5) << 6, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 22 | SafeLoad<uint32_t>(in + 6) << 10, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) };
+ shifts = simd_batch{ 0, 14, 0, 10, 0, 6, 0, 2, 16, 0, 12, 0, 8, 0, 4, 18 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 14-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 24 | SafeLoad<uint32_t>(in + 9) << 8, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 20 | SafeLoad<uint32_t>(in + 10) << 12, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 30 | SafeLoad<uint32_t>(in + 11) << 2, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 26 | SafeLoad<uint32_t>(in + 12) << 6, SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 22 | SafeLoad<uint32_t>(in + 13) << 10, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) };
+ shifts = simd_batch{ 0, 14, 0, 10, 0, 6, 0, 2, 16, 0, 12, 0, 8, 0, 4, 18 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 14;
+ return in;
+}
+
+inline static const uint32_t* unpack15_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7fff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 15-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 26 | SafeLoad<uint32_t>(in + 3) << 6, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 22 | SafeLoad<uint32_t>(in + 5) << 10, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 20 | SafeLoad<uint32_t>(in + 6) << 12, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 18 | SafeLoad<uint32_t>(in + 7) << 14, SafeLoad<uint32_t>(in + 7) };
+ shifts = simd_batch{ 0, 15, 0, 13, 0, 11, 0, 9, 0, 7, 0, 5, 0, 3, 0, 1 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 15-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 31 | SafeLoad<uint32_t>(in + 8) << 1, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 29 | SafeLoad<uint32_t>(in + 9) << 3, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 27 | SafeLoad<uint32_t>(in + 10) << 5, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 25 | SafeLoad<uint32_t>(in + 11) << 7, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 23 | SafeLoad<uint32_t>(in + 12) << 9, SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 21 | SafeLoad<uint32_t>(in + 13) << 11, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 19 | SafeLoad<uint32_t>(in + 14) << 13, SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) };
+ shifts = simd_batch{ 16, 0, 14, 0, 12, 0, 10, 0, 8, 0, 6, 0, 4, 0, 2, 17 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 15;
+ return in;
+}
+
+inline static const uint32_t* unpack16_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 16-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) };
+ shifts = simd_batch{ 0, 16, 0, 16, 0, 16, 0, 16, 0, 16, 0, 16, 0, 16, 0, 16 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 16-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) };
+ shifts = simd_batch{ 0, 16, 0, 16, 0, 16, 0, 16, 0, 16, 0, 16, 0, 16, 0, 16 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 16;
+ return in;
+}
+
+inline static const uint32_t* unpack17_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1ffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 17-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 17 | SafeLoad<uint32_t>(in + 1) << 15, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 19 | SafeLoad<uint32_t>(in + 2) << 13, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 21 | SafeLoad<uint32_t>(in + 3) << 11, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 23 | SafeLoad<uint32_t>(in + 4) << 9, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 25 | SafeLoad<uint32_t>(in + 5) << 7, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 27 | SafeLoad<uint32_t>(in + 6) << 5, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 29 | SafeLoad<uint32_t>(in + 7) << 3, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 31 | SafeLoad<uint32_t>(in + 8) << 1 };
+ shifts = simd_batch{ 0, 0, 2, 0, 4, 0, 6, 0, 8, 0, 10, 0, 12, 0, 14, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 17-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 8) >> 16 | SafeLoad<uint32_t>(in + 9) << 16, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 18 | SafeLoad<uint32_t>(in + 10) << 14, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 20 | SafeLoad<uint32_t>(in + 11) << 12, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 22 | SafeLoad<uint32_t>(in + 12) << 10, SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 24 | SafeLoad<uint32_t>(in + 13) << 8, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 26 | SafeLoad<uint32_t>(in + 14) << 6, SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) >> 28 | SafeLoad<uint32_t>(in + 15) << 4, SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 30 | SafeLoad<uint32_t>(in + 16) << 2, SafeLoad<uint32_t>(in + 16) };
+ shifts = simd_batch{ 0, 1, 0, 3, 0, 5, 0, 7, 0, 9, 0, 11, 0, 13, 0, 15 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 17;
+ return in;
+}
+
+inline static const uint32_t* unpack18_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3ffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 18-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 18 | SafeLoad<uint32_t>(in + 1) << 14, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 22 | SafeLoad<uint32_t>(in + 2) << 10, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 26 | SafeLoad<uint32_t>(in + 3) << 6, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4) >> 16 | SafeLoad<uint32_t>(in + 5) << 16, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 20 | SafeLoad<uint32_t>(in + 6) << 12, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 24 | SafeLoad<uint32_t>(in + 7) << 8, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8) };
+ shifts = simd_batch{ 0, 0, 4, 0, 8, 0, 12, 0, 0, 2, 0, 6, 0, 10, 0, 14 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 18-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 18 | SafeLoad<uint32_t>(in + 10) << 14, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 22 | SafeLoad<uint32_t>(in + 11) << 10, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 26 | SafeLoad<uint32_t>(in + 12) << 6, SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 30 | SafeLoad<uint32_t>(in + 13) << 2, SafeLoad<uint32_t>(in + 13) >> 16 | SafeLoad<uint32_t>(in + 14) << 16, SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) >> 20 | SafeLoad<uint32_t>(in + 15) << 12, SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 24 | SafeLoad<uint32_t>(in + 16) << 8, SafeLoad<uint32_t>(in + 16), SafeLoad<uint32_t>(in + 16) >> 28 | SafeLoad<uint32_t>(in + 17) << 4, SafeLoad<uint32_t>(in + 17) };
+ shifts = simd_batch{ 0, 0, 4, 0, 8, 0, 12, 0, 0, 2, 0, 6, 0, 10, 0, 14 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 18;
+ return in;
+}
+
+inline static const uint32_t* unpack19_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7ffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 19-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 19 | SafeLoad<uint32_t>(in + 1) << 13, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 25 | SafeLoad<uint32_t>(in + 2) << 7, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 31 | SafeLoad<uint32_t>(in + 3) << 1, SafeLoad<uint32_t>(in + 3) >> 18 | SafeLoad<uint32_t>(in + 4) << 14, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 24 | SafeLoad<uint32_t>(in + 5) << 8, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 30 | SafeLoad<uint32_t>(in + 6) << 2, SafeLoad<uint32_t>(in + 6) >> 17 | SafeLoad<uint32_t>(in + 7) << 15, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 23 | SafeLoad<uint32_t>(in + 8) << 9, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 29 | SafeLoad<uint32_t>(in + 9) << 3 };
+ shifts = simd_batch{ 0, 0, 6, 0, 12, 0, 0, 5, 0, 11, 0, 0, 4, 0, 10, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 19-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 9) >> 16 | SafeLoad<uint32_t>(in + 10) << 16, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 22 | SafeLoad<uint32_t>(in + 11) << 10, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 28 | SafeLoad<uint32_t>(in + 12) << 4, SafeLoad<uint32_t>(in + 12) >> 15 | SafeLoad<uint32_t>(in + 13) << 17, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 21 | SafeLoad<uint32_t>(in + 14) << 11, SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) >> 27 | SafeLoad<uint32_t>(in + 15) << 5, SafeLoad<uint32_t>(in + 15) >> 14 | SafeLoad<uint32_t>(in + 16) << 18, SafeLoad<uint32_t>(in + 16), SafeLoad<uint32_t>(in + 16) >> 20 | SafeLoad<uint32_t>(in + 17) << 12, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 26 | SafeLoad<uint32_t>(in + 18) << 6, SafeLoad<uint32_t>(in + 18) };
+ shifts = simd_batch{ 0, 3, 0, 9, 0, 0, 2, 0, 8, 0, 0, 1, 0, 7, 0, 13 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 19;
+ return in;
+}
+
+inline static const uint32_t* unpack20_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xfffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 20-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 20 | SafeLoad<uint32_t>(in + 1) << 12, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2) >> 16 | SafeLoad<uint32_t>(in + 3) << 16, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 20 | SafeLoad<uint32_t>(in + 6) << 12, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 28 | SafeLoad<uint32_t>(in + 7) << 4, SafeLoad<uint32_t>(in + 7) >> 16 | SafeLoad<uint32_t>(in + 8) << 16, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 24 | SafeLoad<uint32_t>(in + 9) << 8, SafeLoad<uint32_t>(in + 9) };
+ shifts = simd_batch{ 0, 0, 8, 0, 0, 4, 0, 12, 0, 0, 8, 0, 0, 4, 0, 12 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 20-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 20 | SafeLoad<uint32_t>(in + 11) << 12, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 28 | SafeLoad<uint32_t>(in + 12) << 4, SafeLoad<uint32_t>(in + 12) >> 16 | SafeLoad<uint32_t>(in + 13) << 16, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 24 | SafeLoad<uint32_t>(in + 14) << 8, SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 20 | SafeLoad<uint32_t>(in + 16) << 12, SafeLoad<uint32_t>(in + 16), SafeLoad<uint32_t>(in + 16) >> 28 | SafeLoad<uint32_t>(in + 17) << 4, SafeLoad<uint32_t>(in + 17) >> 16 | SafeLoad<uint32_t>(in + 18) << 16, SafeLoad<uint32_t>(in + 18), SafeLoad<uint32_t>(in + 18) >> 24 | SafeLoad<uint32_t>(in + 19) << 8, SafeLoad<uint32_t>(in + 19) };
+ shifts = simd_batch{ 0, 0, 8, 0, 0, 4, 0, 12, 0, 0, 8, 0, 0, 4, 0, 12 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 20;
+ return in;
+}
+
+inline static const uint32_t* unpack21_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1fffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 21-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 21 | SafeLoad<uint32_t>(in + 1) << 11, SafeLoad<uint32_t>(in + 1), SafeLoad<uint32_t>(in + 1) >> 31 | SafeLoad<uint32_t>(in + 2) << 1, SafeLoad<uint32_t>(in + 2) >> 20 | SafeLoad<uint32_t>(in + 3) << 12, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 30 | SafeLoad<uint32_t>(in + 4) << 2, SafeLoad<uint32_t>(in + 4) >> 19 | SafeLoad<uint32_t>(in + 5) << 13, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 29 | SafeLoad<uint32_t>(in + 6) << 3, SafeLoad<uint32_t>(in + 6) >> 18 | SafeLoad<uint32_t>(in + 7) << 14, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8) >> 17 | SafeLoad<uint32_t>(in + 9) << 15, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 27 | SafeLoad<uint32_t>(in + 10) << 5 };
+ shifts = simd_batch{ 0, 0, 10, 0, 0, 9, 0, 0, 8, 0, 0, 7, 0, 0, 6, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 21-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 10) >> 16 | SafeLoad<uint32_t>(in + 11) << 16, SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 26 | SafeLoad<uint32_t>(in + 12) << 6, SafeLoad<uint32_t>(in + 12) >> 15 | SafeLoad<uint32_t>(in + 13) << 17, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 25 | SafeLoad<uint32_t>(in + 14) << 7, SafeLoad<uint32_t>(in + 14) >> 14 | SafeLoad<uint32_t>(in + 15) << 18, SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 24 | SafeLoad<uint32_t>(in + 16) << 8, SafeLoad<uint32_t>(in + 16) >> 13 | SafeLoad<uint32_t>(in + 17) << 19, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 23 | SafeLoad<uint32_t>(in + 18) << 9, SafeLoad<uint32_t>(in + 18) >> 12 | SafeLoad<uint32_t>(in + 19) << 20, SafeLoad<uint32_t>(in + 19), SafeLoad<uint32_t>(in + 19) >> 22 | SafeLoad<uint32_t>(in + 20) << 10, SafeLoad<uint32_t>(in + 20) };
+ shifts = simd_batch{ 0, 5, 0, 0, 4, 0, 0, 3, 0, 0, 2, 0, 0, 1, 0, 11 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 21;
+ return in;
+}
+
+inline static const uint32_t* unpack22_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3fffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 22-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 22 | SafeLoad<uint32_t>(in + 1) << 10, SafeLoad<uint32_t>(in + 1) >> 12 | SafeLoad<uint32_t>(in + 2) << 20, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 24 | SafeLoad<uint32_t>(in + 3) << 8, SafeLoad<uint32_t>(in + 3) >> 14 | SafeLoad<uint32_t>(in + 4) << 18, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 26 | SafeLoad<uint32_t>(in + 5) << 6, SafeLoad<uint32_t>(in + 5) >> 16 | SafeLoad<uint32_t>(in + 6) << 16, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 28 | SafeLoad<uint32_t>(in + 7) << 4, SafeLoad<uint32_t>(in + 7) >> 18 | SafeLoad<uint32_t>(in + 8) << 14, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 30 | SafeLoad<uint32_t>(in + 9) << 2, SafeLoad<uint32_t>(in + 9) >> 20 | SafeLoad<uint32_t>(in + 10) << 12, SafeLoad<uint32_t>(in + 10) };
+ shifts = simd_batch{ 0, 0, 0, 2, 0, 0, 4, 0, 0, 6, 0, 0, 8, 0, 0, 10 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 22-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 11), SafeLoad<uint32_t>(in + 11) >> 22 | SafeLoad<uint32_t>(in + 12) << 10, SafeLoad<uint32_t>(in + 12) >> 12 | SafeLoad<uint32_t>(in + 13) << 20, SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 24 | SafeLoad<uint32_t>(in + 14) << 8, SafeLoad<uint32_t>(in + 14) >> 14 | SafeLoad<uint32_t>(in + 15) << 18, SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 26 | SafeLoad<uint32_t>(in + 16) << 6, SafeLoad<uint32_t>(in + 16) >> 16 | SafeLoad<uint32_t>(in + 17) << 16, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 28 | SafeLoad<uint32_t>(in + 18) << 4, SafeLoad<uint32_t>(in + 18) >> 18 | SafeLoad<uint32_t>(in + 19) << 14, SafeLoad<uint32_t>(in + 19), SafeLoad<uint32_t>(in + 19) >> 30 | SafeLoad<uint32_t>(in + 20) << 2, SafeLoad<uint32_t>(in + 20) >> 20 | SafeLoad<uint32_t>(in + 21) << 12, SafeLoad<uint32_t>(in + 21) };
+ shifts = simd_batch{ 0, 0, 0, 2, 0, 0, 4, 0, 0, 6, 0, 0, 8, 0, 0, 10 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 22;
+ return in;
+}
+
+inline static const uint32_t* unpack23_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7fffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 23-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 23 | SafeLoad<uint32_t>(in + 1) << 9, SafeLoad<uint32_t>(in + 1) >> 14 | SafeLoad<uint32_t>(in + 2) << 18, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 2) >> 28 | SafeLoad<uint32_t>(in + 3) << 4, SafeLoad<uint32_t>(in + 3) >> 19 | SafeLoad<uint32_t>(in + 4) << 13, SafeLoad<uint32_t>(in + 4) >> 10 | SafeLoad<uint32_t>(in + 5) << 22, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 24 | SafeLoad<uint32_t>(in + 6) << 8, SafeLoad<uint32_t>(in + 6) >> 15 | SafeLoad<uint32_t>(in + 7) << 17, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 29 | SafeLoad<uint32_t>(in + 8) << 3, SafeLoad<uint32_t>(in + 8) >> 20 | SafeLoad<uint32_t>(in + 9) << 12, SafeLoad<uint32_t>(in + 9) >> 11 | SafeLoad<uint32_t>(in + 10) << 21, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 25 | SafeLoad<uint32_t>(in + 11) << 7 };
+ shifts = simd_batch{ 0, 0, 0, 5, 0, 0, 0, 1, 0, 0, 6, 0, 0, 0, 2, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 23-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 11) >> 16 | SafeLoad<uint32_t>(in + 12) << 16, SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 30 | SafeLoad<uint32_t>(in + 13) << 2, SafeLoad<uint32_t>(in + 13) >> 21 | SafeLoad<uint32_t>(in + 14) << 11, SafeLoad<uint32_t>(in + 14) >> 12 | SafeLoad<uint32_t>(in + 15) << 20, SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 26 | SafeLoad<uint32_t>(in + 16) << 6, SafeLoad<uint32_t>(in + 16) >> 17 | SafeLoad<uint32_t>(in + 17) << 15, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 31 | SafeLoad<uint32_t>(in + 18) << 1, SafeLoad<uint32_t>(in + 18) >> 22 | SafeLoad<uint32_t>(in + 19) << 10, SafeLoad<uint32_t>(in + 19) >> 13 | SafeLoad<uint32_t>(in + 20) << 19, SafeLoad<uint32_t>(in + 20), SafeLoad<uint32_t>(in + 20) >> 27 | SafeLoad<uint32_t>(in + 21) << 5, SafeLoad<uint32_t>(in + 21) >> 18 | SafeLoad<uint32_t>(in + 22) << 14, SafeLoad<uint32_t>(in + 22) };
+ shifts = simd_batch{ 0, 7, 0, 0, 0, 3, 0, 0, 8, 0, 0, 0, 4, 0, 0, 9 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 23;
+ return in;
+}
+
+inline static const uint32_t* unpack24_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 24-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 24 | SafeLoad<uint32_t>(in + 1) << 8, SafeLoad<uint32_t>(in + 1) >> 16 | SafeLoad<uint32_t>(in + 2) << 16, SafeLoad<uint32_t>(in + 2), SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4) >> 16 | SafeLoad<uint32_t>(in + 5) << 16, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 6) >> 24 | SafeLoad<uint32_t>(in + 7) << 8, SafeLoad<uint32_t>(in + 7) >> 16 | SafeLoad<uint32_t>(in + 8) << 16, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 24 | SafeLoad<uint32_t>(in + 10) << 8, SafeLoad<uint32_t>(in + 10) >> 16 | SafeLoad<uint32_t>(in + 11) << 16, SafeLoad<uint32_t>(in + 11) };
+ shifts = simd_batch{ 0, 0, 0, 8, 0, 0, 0, 8, 0, 0, 0, 8, 0, 0, 0, 8 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 24-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 12), SafeLoad<uint32_t>(in + 12) >> 24 | SafeLoad<uint32_t>(in + 13) << 8, SafeLoad<uint32_t>(in + 13) >> 16 | SafeLoad<uint32_t>(in + 14) << 16, SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 24 | SafeLoad<uint32_t>(in + 16) << 8, SafeLoad<uint32_t>(in + 16) >> 16 | SafeLoad<uint32_t>(in + 17) << 16, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 18), SafeLoad<uint32_t>(in + 18) >> 24 | SafeLoad<uint32_t>(in + 19) << 8, SafeLoad<uint32_t>(in + 19) >> 16 | SafeLoad<uint32_t>(in + 20) << 16, SafeLoad<uint32_t>(in + 20), SafeLoad<uint32_t>(in + 21), SafeLoad<uint32_t>(in + 21) >> 24 | SafeLoad<uint32_t>(in + 22) << 8, SafeLoad<uint32_t>(in + 22) >> 16 | SafeLoad<uint32_t>(in + 23) << 16, SafeLoad<uint32_t>(in + 23) };
+ shifts = simd_batch{ 0, 0, 0, 8, 0, 0, 0, 8, 0, 0, 0, 8, 0, 0, 0, 8 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 24;
+ return in;
+}
+
+inline static const uint32_t* unpack25_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1ffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 25-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 25 | SafeLoad<uint32_t>(in + 1) << 7, SafeLoad<uint32_t>(in + 1) >> 18 | SafeLoad<uint32_t>(in + 2) << 14, SafeLoad<uint32_t>(in + 2) >> 11 | SafeLoad<uint32_t>(in + 3) << 21, SafeLoad<uint32_t>(in + 3), SafeLoad<uint32_t>(in + 3) >> 29 | SafeLoad<uint32_t>(in + 4) << 3, SafeLoad<uint32_t>(in + 4) >> 22 | SafeLoad<uint32_t>(in + 5) << 10, SafeLoad<uint32_t>(in + 5) >> 15 | SafeLoad<uint32_t>(in + 6) << 17, SafeLoad<uint32_t>(in + 6) >> 8 | SafeLoad<uint32_t>(in + 7) << 24, SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 26 | SafeLoad<uint32_t>(in + 8) << 6, SafeLoad<uint32_t>(in + 8) >> 19 | SafeLoad<uint32_t>(in + 9) << 13, SafeLoad<uint32_t>(in + 9) >> 12 | SafeLoad<uint32_t>(in + 10) << 20, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 30 | SafeLoad<uint32_t>(in + 11) << 2, SafeLoad<uint32_t>(in + 11) >> 23 | SafeLoad<uint32_t>(in + 12) << 9 };
+ shifts = simd_batch{ 0, 0, 0, 0, 4, 0, 0, 0, 0, 1, 0, 0, 0, 5, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 25-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 12) >> 16 | SafeLoad<uint32_t>(in + 13) << 16, SafeLoad<uint32_t>(in + 13) >> 9 | SafeLoad<uint32_t>(in + 14) << 23, SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) >> 27 | SafeLoad<uint32_t>(in + 15) << 5, SafeLoad<uint32_t>(in + 15) >> 20 | SafeLoad<uint32_t>(in + 16) << 12, SafeLoad<uint32_t>(in + 16) >> 13 | SafeLoad<uint32_t>(in + 17) << 19, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 31 | SafeLoad<uint32_t>(in + 18) << 1, SafeLoad<uint32_t>(in + 18) >> 24 | SafeLoad<uint32_t>(in + 19) << 8, SafeLoad<uint32_t>(in + 19) >> 17 | SafeLoad<uint32_t>(in + 20) << 15, SafeLoad<uint32_t>(in + 20) >> 10 | SafeLoad<uint32_t>(in + 21) << 22, SafeLoad<uint32_t>(in + 21), SafeLoad<uint32_t>(in + 21) >> 28 | SafeLoad<uint32_t>(in + 22) << 4, SafeLoad<uint32_t>(in + 22) >> 21 | SafeLoad<uint32_t>(in + 23) << 11, SafeLoad<uint32_t>(in + 23) >> 14 | SafeLoad<uint32_t>(in + 24) << 18, SafeLoad<uint32_t>(in + 24) };
+ shifts = simd_batch{ 0, 0, 2, 0, 0, 0, 6, 0, 0, 0, 0, 3, 0, 0, 0, 7 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 25;
+ return in;
+}
+
+inline static const uint32_t* unpack26_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3ffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 26-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 26 | SafeLoad<uint32_t>(in + 1) << 6, SafeLoad<uint32_t>(in + 1) >> 20 | SafeLoad<uint32_t>(in + 2) << 12, SafeLoad<uint32_t>(in + 2) >> 14 | SafeLoad<uint32_t>(in + 3) << 18, SafeLoad<uint32_t>(in + 3) >> 8 | SafeLoad<uint32_t>(in + 4) << 24, SafeLoad<uint32_t>(in + 4), SafeLoad<uint32_t>(in + 4) >> 28 | SafeLoad<uint32_t>(in + 5) << 4, SafeLoad<uint32_t>(in + 5) >> 22 | SafeLoad<uint32_t>(in + 6) << 10, SafeLoad<uint32_t>(in + 6) >> 16 | SafeLoad<uint32_t>(in + 7) << 16, SafeLoad<uint32_t>(in + 7) >> 10 | SafeLoad<uint32_t>(in + 8) << 22, SafeLoad<uint32_t>(in + 8), SafeLoad<uint32_t>(in + 8) >> 30 | SafeLoad<uint32_t>(in + 9) << 2, SafeLoad<uint32_t>(in + 9) >> 24 | SafeLoad<uint32_t>(in + 10) << 8, SafeLoad<uint32_t>(in + 10) >> 18 | SafeLoad<uint32_t>(in + 11) << 14, SafeLoad<uint32_t>(in + 11) >> 12 | SafeLoad<uint32_t>(in + 12) << 20, SafeLoad<uint32_t>(in + 12) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 4, 0, 0, 0, 0, 6 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 26-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 13), SafeLoad<uint32_t>(in + 13) >> 26 | SafeLoad<uint32_t>(in + 14) << 6, SafeLoad<uint32_t>(in + 14) >> 20 | SafeLoad<uint32_t>(in + 15) << 12, SafeLoad<uint32_t>(in + 15) >> 14 | SafeLoad<uint32_t>(in + 16) << 18, SafeLoad<uint32_t>(in + 16) >> 8 | SafeLoad<uint32_t>(in + 17) << 24, SafeLoad<uint32_t>(in + 17), SafeLoad<uint32_t>(in + 17) >> 28 | SafeLoad<uint32_t>(in + 18) << 4, SafeLoad<uint32_t>(in + 18) >> 22 | SafeLoad<uint32_t>(in + 19) << 10, SafeLoad<uint32_t>(in + 19) >> 16 | SafeLoad<uint32_t>(in + 20) << 16, SafeLoad<uint32_t>(in + 20) >> 10 | SafeLoad<uint32_t>(in + 21) << 22, SafeLoad<uint32_t>(in + 21), SafeLoad<uint32_t>(in + 21) >> 30 | SafeLoad<uint32_t>(in + 22) << 2, SafeLoad<uint32_t>(in + 22) >> 24 | SafeLoad<uint32_t>(in + 23) << 8, SafeLoad<uint32_t>(in + 23) >> 18 | SafeLoad<uint32_t>(in + 24) << 14, SafeLoad<uint32_t>(in + 24) >> 12 | SafeLoad<uint32_t>(in + 25) << 20, SafeLoad<uint32_t>(in + 25) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 4, 0, 0, 0, 0, 6 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 26;
+ return in;
+}
+
+inline static const uint32_t* unpack27_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7ffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 27-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 27 | SafeLoad<uint32_t>(in + 1) << 5, SafeLoad<uint32_t>(in + 1) >> 22 | SafeLoad<uint32_t>(in + 2) << 10, SafeLoad<uint32_t>(in + 2) >> 17 | SafeLoad<uint32_t>(in + 3) << 15, SafeLoad<uint32_t>(in + 3) >> 12 | SafeLoad<uint32_t>(in + 4) << 20, SafeLoad<uint32_t>(in + 4) >> 7 | SafeLoad<uint32_t>(in + 5) << 25, SafeLoad<uint32_t>(in + 5), SafeLoad<uint32_t>(in + 5) >> 29 | SafeLoad<uint32_t>(in + 6) << 3, SafeLoad<uint32_t>(in + 6) >> 24 | SafeLoad<uint32_t>(in + 7) << 8, SafeLoad<uint32_t>(in + 7) >> 19 | SafeLoad<uint32_t>(in + 8) << 13, SafeLoad<uint32_t>(in + 8) >> 14 | SafeLoad<uint32_t>(in + 9) << 18, SafeLoad<uint32_t>(in + 9) >> 9 | SafeLoad<uint32_t>(in + 10) << 23, SafeLoad<uint32_t>(in + 10), SafeLoad<uint32_t>(in + 10) >> 31 | SafeLoad<uint32_t>(in + 11) << 1, SafeLoad<uint32_t>(in + 11) >> 26 | SafeLoad<uint32_t>(in + 12) << 6, SafeLoad<uint32_t>(in + 12) >> 21 | SafeLoad<uint32_t>(in + 13) << 11 };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 4, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 27-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 13) >> 16 | SafeLoad<uint32_t>(in + 14) << 16, SafeLoad<uint32_t>(in + 14) >> 11 | SafeLoad<uint32_t>(in + 15) << 21, SafeLoad<uint32_t>(in + 15) >> 6 | SafeLoad<uint32_t>(in + 16) << 26, SafeLoad<uint32_t>(in + 16), SafeLoad<uint32_t>(in + 16) >> 28 | SafeLoad<uint32_t>(in + 17) << 4, SafeLoad<uint32_t>(in + 17) >> 23 | SafeLoad<uint32_t>(in + 18) << 9, SafeLoad<uint32_t>(in + 18) >> 18 | SafeLoad<uint32_t>(in + 19) << 14, SafeLoad<uint32_t>(in + 19) >> 13 | SafeLoad<uint32_t>(in + 20) << 19, SafeLoad<uint32_t>(in + 20) >> 8 | SafeLoad<uint32_t>(in + 21) << 24, SafeLoad<uint32_t>(in + 21), SafeLoad<uint32_t>(in + 21) >> 30 | SafeLoad<uint32_t>(in + 22) << 2, SafeLoad<uint32_t>(in + 22) >> 25 | SafeLoad<uint32_t>(in + 23) << 7, SafeLoad<uint32_t>(in + 23) >> 20 | SafeLoad<uint32_t>(in + 24) << 12, SafeLoad<uint32_t>(in + 24) >> 15 | SafeLoad<uint32_t>(in + 25) << 17, SafeLoad<uint32_t>(in + 25) >> 10 | SafeLoad<uint32_t>(in + 26) << 22, SafeLoad<uint32_t>(in + 26) };
+ shifts = simd_batch{ 0, 0, 0, 1, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 5 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 27;
+ return in;
+}
+
+inline static const uint32_t* unpack28_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0xfffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 28-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 28 | SafeLoad<uint32_t>(in + 1) << 4, SafeLoad<uint32_t>(in + 1) >> 24 | SafeLoad<uint32_t>(in + 2) << 8, SafeLoad<uint32_t>(in + 2) >> 20 | SafeLoad<uint32_t>(in + 3) << 12, SafeLoad<uint32_t>(in + 3) >> 16 | SafeLoad<uint32_t>(in + 4) << 16, SafeLoad<uint32_t>(in + 4) >> 12 | SafeLoad<uint32_t>(in + 5) << 20, SafeLoad<uint32_t>(in + 5) >> 8 | SafeLoad<uint32_t>(in + 6) << 24, SafeLoad<uint32_t>(in + 6), SafeLoad<uint32_t>(in + 7), SafeLoad<uint32_t>(in + 7) >> 28 | SafeLoad<uint32_t>(in + 8) << 4, SafeLoad<uint32_t>(in + 8) >> 24 | SafeLoad<uint32_t>(in + 9) << 8, SafeLoad<uint32_t>(in + 9) >> 20 | SafeLoad<uint32_t>(in + 10) << 12, SafeLoad<uint32_t>(in + 10) >> 16 | SafeLoad<uint32_t>(in + 11) << 16, SafeLoad<uint32_t>(in + 11) >> 12 | SafeLoad<uint32_t>(in + 12) << 20, SafeLoad<uint32_t>(in + 12) >> 8 | SafeLoad<uint32_t>(in + 13) << 24, SafeLoad<uint32_t>(in + 13) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 4 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 28-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 14), SafeLoad<uint32_t>(in + 14) >> 28 | SafeLoad<uint32_t>(in + 15) << 4, SafeLoad<uint32_t>(in + 15) >> 24 | SafeLoad<uint32_t>(in + 16) << 8, SafeLoad<uint32_t>(in + 16) >> 20 | SafeLoad<uint32_t>(in + 17) << 12, SafeLoad<uint32_t>(in + 17) >> 16 | SafeLoad<uint32_t>(in + 18) << 16, SafeLoad<uint32_t>(in + 18) >> 12 | SafeLoad<uint32_t>(in + 19) << 20, SafeLoad<uint32_t>(in + 19) >> 8 | SafeLoad<uint32_t>(in + 20) << 24, SafeLoad<uint32_t>(in + 20), SafeLoad<uint32_t>(in + 21), SafeLoad<uint32_t>(in + 21) >> 28 | SafeLoad<uint32_t>(in + 22) << 4, SafeLoad<uint32_t>(in + 22) >> 24 | SafeLoad<uint32_t>(in + 23) << 8, SafeLoad<uint32_t>(in + 23) >> 20 | SafeLoad<uint32_t>(in + 24) << 12, SafeLoad<uint32_t>(in + 24) >> 16 | SafeLoad<uint32_t>(in + 25) << 16, SafeLoad<uint32_t>(in + 25) >> 12 | SafeLoad<uint32_t>(in + 26) << 20, SafeLoad<uint32_t>(in + 26) >> 8 | SafeLoad<uint32_t>(in + 27) << 24, SafeLoad<uint32_t>(in + 27) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 4 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 28;
+ return in;
+}
+
+inline static const uint32_t* unpack29_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x1fffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 29-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 29 | SafeLoad<uint32_t>(in + 1) << 3, SafeLoad<uint32_t>(in + 1) >> 26 | SafeLoad<uint32_t>(in + 2) << 6, SafeLoad<uint32_t>(in + 2) >> 23 | SafeLoad<uint32_t>(in + 3) << 9, SafeLoad<uint32_t>(in + 3) >> 20 | SafeLoad<uint32_t>(in + 4) << 12, SafeLoad<uint32_t>(in + 4) >> 17 | SafeLoad<uint32_t>(in + 5) << 15, SafeLoad<uint32_t>(in + 5) >> 14 | SafeLoad<uint32_t>(in + 6) << 18, SafeLoad<uint32_t>(in + 6) >> 11 | SafeLoad<uint32_t>(in + 7) << 21, SafeLoad<uint32_t>(in + 7) >> 8 | SafeLoad<uint32_t>(in + 8) << 24, SafeLoad<uint32_t>(in + 8) >> 5 | SafeLoad<uint32_t>(in + 9) << 27, SafeLoad<uint32_t>(in + 9), SafeLoad<uint32_t>(in + 9) >> 31 | SafeLoad<uint32_t>(in + 10) << 1, SafeLoad<uint32_t>(in + 10) >> 28 | SafeLoad<uint32_t>(in + 11) << 4, SafeLoad<uint32_t>(in + 11) >> 25 | SafeLoad<uint32_t>(in + 12) << 7, SafeLoad<uint32_t>(in + 12) >> 22 | SafeLoad<uint32_t>(in + 13) << 10, SafeLoad<uint32_t>(in + 13) >> 19 | SafeLoad<uint32_t>(in + 14) << 13 };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 29-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 14) >> 16 | SafeLoad<uint32_t>(in + 15) << 16, SafeLoad<uint32_t>(in + 15) >> 13 | SafeLoad<uint32_t>(in + 16) << 19, SafeLoad<uint32_t>(in + 16) >> 10 | SafeLoad<uint32_t>(in + 17) << 22, SafeLoad<uint32_t>(in + 17) >> 7 | SafeLoad<uint32_t>(in + 18) << 25, SafeLoad<uint32_t>(in + 18) >> 4 | SafeLoad<uint32_t>(in + 19) << 28, SafeLoad<uint32_t>(in + 19), SafeLoad<uint32_t>(in + 19) >> 30 | SafeLoad<uint32_t>(in + 20) << 2, SafeLoad<uint32_t>(in + 20) >> 27 | SafeLoad<uint32_t>(in + 21) << 5, SafeLoad<uint32_t>(in + 21) >> 24 | SafeLoad<uint32_t>(in + 22) << 8, SafeLoad<uint32_t>(in + 22) >> 21 | SafeLoad<uint32_t>(in + 23) << 11, SafeLoad<uint32_t>(in + 23) >> 18 | SafeLoad<uint32_t>(in + 24) << 14, SafeLoad<uint32_t>(in + 24) >> 15 | SafeLoad<uint32_t>(in + 25) << 17, SafeLoad<uint32_t>(in + 25) >> 12 | SafeLoad<uint32_t>(in + 26) << 20, SafeLoad<uint32_t>(in + 26) >> 9 | SafeLoad<uint32_t>(in + 27) << 23, SafeLoad<uint32_t>(in + 27) >> 6 | SafeLoad<uint32_t>(in + 28) << 26, SafeLoad<uint32_t>(in + 28) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 29;
+ return in;
+}
+
+inline static const uint32_t* unpack30_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x3fffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 30-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 30 | SafeLoad<uint32_t>(in + 1) << 2, SafeLoad<uint32_t>(in + 1) >> 28 | SafeLoad<uint32_t>(in + 2) << 4, SafeLoad<uint32_t>(in + 2) >> 26 | SafeLoad<uint32_t>(in + 3) << 6, SafeLoad<uint32_t>(in + 3) >> 24 | SafeLoad<uint32_t>(in + 4) << 8, SafeLoad<uint32_t>(in + 4) >> 22 | SafeLoad<uint32_t>(in + 5) << 10, SafeLoad<uint32_t>(in + 5) >> 20 | SafeLoad<uint32_t>(in + 6) << 12, SafeLoad<uint32_t>(in + 6) >> 18 | SafeLoad<uint32_t>(in + 7) << 14, SafeLoad<uint32_t>(in + 7) >> 16 | SafeLoad<uint32_t>(in + 8) << 16, SafeLoad<uint32_t>(in + 8) >> 14 | SafeLoad<uint32_t>(in + 9) << 18, SafeLoad<uint32_t>(in + 9) >> 12 | SafeLoad<uint32_t>(in + 10) << 20, SafeLoad<uint32_t>(in + 10) >> 10 | SafeLoad<uint32_t>(in + 11) << 22, SafeLoad<uint32_t>(in + 11) >> 8 | SafeLoad<uint32_t>(in + 12) << 24, SafeLoad<uint32_t>(in + 12) >> 6 | SafeLoad<uint32_t>(in + 13) << 26, SafeLoad<uint32_t>(in + 13) >> 4 | SafeLoad<uint32_t>(in + 14) << 28, SafeLoad<uint32_t>(in + 14) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 30-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 15), SafeLoad<uint32_t>(in + 15) >> 30 | SafeLoad<uint32_t>(in + 16) << 2, SafeLoad<uint32_t>(in + 16) >> 28 | SafeLoad<uint32_t>(in + 17) << 4, SafeLoad<uint32_t>(in + 17) >> 26 | SafeLoad<uint32_t>(in + 18) << 6, SafeLoad<uint32_t>(in + 18) >> 24 | SafeLoad<uint32_t>(in + 19) << 8, SafeLoad<uint32_t>(in + 19) >> 22 | SafeLoad<uint32_t>(in + 20) << 10, SafeLoad<uint32_t>(in + 20) >> 20 | SafeLoad<uint32_t>(in + 21) << 12, SafeLoad<uint32_t>(in + 21) >> 18 | SafeLoad<uint32_t>(in + 22) << 14, SafeLoad<uint32_t>(in + 22) >> 16 | SafeLoad<uint32_t>(in + 23) << 16, SafeLoad<uint32_t>(in + 23) >> 14 | SafeLoad<uint32_t>(in + 24) << 18, SafeLoad<uint32_t>(in + 24) >> 12 | SafeLoad<uint32_t>(in + 25) << 20, SafeLoad<uint32_t>(in + 25) >> 10 | SafeLoad<uint32_t>(in + 26) << 22, SafeLoad<uint32_t>(in + 26) >> 8 | SafeLoad<uint32_t>(in + 27) << 24, SafeLoad<uint32_t>(in + 27) >> 6 | SafeLoad<uint32_t>(in + 28) << 26, SafeLoad<uint32_t>(in + 28) >> 4 | SafeLoad<uint32_t>(in + 29) << 28, SafeLoad<uint32_t>(in + 29) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 30;
+ return in;
+}
+
+inline static const uint32_t* unpack31_32(const uint32_t* in, uint32_t* out) {
+ uint32_t mask = 0x7fffffff;
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+
+ // extract 31-bit bundles 0 to 15
+ words = simd_batch{ SafeLoad<uint32_t>(in + 0), SafeLoad<uint32_t>(in + 0) >> 31 | SafeLoad<uint32_t>(in + 1) << 1, SafeLoad<uint32_t>(in + 1) >> 30 | SafeLoad<uint32_t>(in + 2) << 2, SafeLoad<uint32_t>(in + 2) >> 29 | SafeLoad<uint32_t>(in + 3) << 3, SafeLoad<uint32_t>(in + 3) >> 28 | SafeLoad<uint32_t>(in + 4) << 4, SafeLoad<uint32_t>(in + 4) >> 27 | SafeLoad<uint32_t>(in + 5) << 5, SafeLoad<uint32_t>(in + 5) >> 26 | SafeLoad<uint32_t>(in + 6) << 6, SafeLoad<uint32_t>(in + 6) >> 25 | SafeLoad<uint32_t>(in + 7) << 7, SafeLoad<uint32_t>(in + 7) >> 24 | SafeLoad<uint32_t>(in + 8) << 8, SafeLoad<uint32_t>(in + 8) >> 23 | SafeLoad<uint32_t>(in + 9) << 9, SafeLoad<uint32_t>(in + 9) >> 22 | SafeLoad<uint32_t>(in + 10) << 10, SafeLoad<uint32_t>(in + 10) >> 21 | SafeLoad<uint32_t>(in + 11) << 11, SafeLoad<uint32_t>(in + 11) >> 20 | SafeLoad<uint32_t>(in + 12) << 12, SafeLoad<uint32_t>(in + 12) >> 19 | SafeLoad<uint32_t>(in + 13) << 13, SafeLoad<uint32_t>(in + 13) >> 18 | SafeLoad<uint32_t>(in + 14) << 14, SafeLoad<uint32_t>(in + 14) >> 17 | SafeLoad<uint32_t>(in + 15) << 15 };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ // extract 31-bit bundles 16 to 31
+ words = simd_batch{ SafeLoad<uint32_t>(in + 15) >> 16 | SafeLoad<uint32_t>(in + 16) << 16, SafeLoad<uint32_t>(in + 16) >> 15 | SafeLoad<uint32_t>(in + 17) << 17, SafeLoad<uint32_t>(in + 17) >> 14 | SafeLoad<uint32_t>(in + 18) << 18, SafeLoad<uint32_t>(in + 18) >> 13 | SafeLoad<uint32_t>(in + 19) << 19, SafeLoad<uint32_t>(in + 19) >> 12 | SafeLoad<uint32_t>(in + 20) << 20, SafeLoad<uint32_t>(in + 20) >> 11 | SafeLoad<uint32_t>(in + 21) << 21, SafeLoad<uint32_t>(in + 21) >> 10 | SafeLoad<uint32_t>(in + 22) << 22, SafeLoad<uint32_t>(in + 22) >> 9 | SafeLoad<uint32_t>(in + 23) << 23, SafeLoad<uint32_t>(in + 23) >> 8 | SafeLoad<uint32_t>(in + 24) << 24, SafeLoad<uint32_t>(in + 24) >> 7 | SafeLoad<uint32_t>(in + 25) << 25, SafeLoad<uint32_t>(in + 25) >> 6 | SafeLoad<uint32_t>(in + 26) << 26, SafeLoad<uint32_t>(in + 26) >> 5 | SafeLoad<uint32_t>(in + 27) << 27, SafeLoad<uint32_t>(in + 27) >> 4 | SafeLoad<uint32_t>(in + 28) << 28, SafeLoad<uint32_t>(in + 28) >> 3 | SafeLoad<uint32_t>(in + 29) << 29, SafeLoad<uint32_t>(in + 29) >> 2 | SafeLoad<uint32_t>(in + 30) << 30, SafeLoad<uint32_t>(in + 30) };
+ shifts = simd_batch{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 };
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += 16;
+
+ in += 31;
+ return in;
+}
+
+inline static const uint32_t* unpack32_32(const uint32_t* in, uint32_t* out) {
+ memcpy(out, in, 32 * sizeof(*out));
+ in += 32;
+ out += 32;
+
+ return in;
+}
+
+}; // struct UnpackBits512
+
+} // namespace
+} // namespace internal
+} // namespace arrow
+
diff --git a/src/arrow/cpp/src/arrow/util/bpacking_simd_codegen.py b/src/arrow/cpp/src/arrow/util/bpacking_simd_codegen.py
new file mode 100644
index 000000000..9bdc22569
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking_simd_codegen.py
@@ -0,0 +1,223 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Usage:
+# python bpacking_simd_codegen.py 128 > bpacking_simd128_generated.h
+# python bpacking_simd_codegen.py 256 > bpacking_simd256_generated.h
+# python bpacking_simd_codegen.py 512 > bpacking_simd512_generated.h
+
+from functools import partial
+import sys
+from textwrap import dedent, indent
+
+
+class UnpackGenerator:
+
+ def __init__(self, simd_width):
+ self.simd_width = simd_width
+ if simd_width % 32 != 0:
+ raise("SIMD bit width should be a multiple of 32")
+ self.simd_byte_width = simd_width // 8
+
+ def print_unpack_bit0_func(self):
+ print(
+ "inline static const uint32_t* unpack0_32(const uint32_t* in, uint32_t* out) {")
+ print(" memset(out, 0x0, 32 * sizeof(*out));")
+ print(" out += 32;")
+ print("")
+ print(" return in;")
+ print("}")
+
+
+ def print_unpack_bit32_func(self):
+ print(
+ "inline static const uint32_t* unpack32_32(const uint32_t* in, uint32_t* out) {")
+ print(" memcpy(out, in, 32 * sizeof(*out));")
+ print(" in += 32;")
+ print(" out += 32;")
+ print("")
+ print(" return in;")
+ print("}")
+
+ def print_unpack_bit_func(self, bit):
+ def p(code):
+ print(indent(code, prefix=' '))
+
+ shift = 0
+ shifts = []
+ in_index = 0
+ inls = []
+ mask = (1 << bit) - 1
+ bracket = "{"
+
+ print(f"inline static const uint32_t* unpack{bit}_32(const uint32_t* in, uint32_t* out) {{")
+ p(dedent(f"""\
+ uint32_t mask = 0x{mask:0x};
+
+ simd_batch masks(mask);
+ simd_batch words, shifts;
+ simd_batch results;
+ """))
+
+ def safe_load(index):
+ return f"SafeLoad<uint32_t>(in + {index})"
+
+ for i in range(32):
+ if shift + bit == 32:
+ shifts.append(shift)
+ inls.append(safe_load(in_index))
+ in_index += 1
+ shift = 0
+ elif shift + bit > 32: # cross the boundary
+ inls.append(
+ f"{safe_load(in_index)} >> {shift} | {safe_load(in_index + 1)} << {32 - shift}")
+ in_index += 1
+ shift = bit - (32 - shift)
+ shifts.append(0) # zero shift
+ else:
+ shifts.append(shift)
+ inls.append(safe_load(in_index))
+ shift += bit
+
+ bytes_per_batch = self.simd_byte_width
+ words_per_batch = bytes_per_batch // 4
+
+ one_word_template = dedent("""\
+ words = simd_batch{{ {words} }};
+ shifts = simd_batch{{ {shifts} }};
+ results = (words >> shifts) & masks;
+ results.store_unaligned(out);
+ out += {words_per_batch};
+ """)
+
+ for start in range(0, 32, words_per_batch):
+ stop = start + words_per_batch;
+ p(f"""// extract {bit}-bit bundles {start} to {stop - 1}""")
+ p(one_word_template.format(
+ words=", ".join(inls[start:stop]),
+ shifts=", ".join(map(str, shifts[start:stop])),
+ words_per_batch=words_per_batch))
+
+ p(dedent(f"""\
+ in += {bit};
+ return in;"""))
+ print("}")
+
+
+def print_copyright():
+ print(dedent("""\
+ // Licensed to the Apache Software Foundation (ASF) under one
+ // or more contributor license agreements. See the NOTICE file
+ // distributed with this work for additional information
+ // regarding copyright ownership. The ASF licenses this file
+ // to you under the Apache License, Version 2.0 (the
+ // "License"); you may not use this file except in compliance
+ // with the License. You may obtain a copy of the License at
+ //
+ // http://www.apache.org/licenses/LICENSE-2.0
+ //
+ // Unless required by applicable law or agreed to in writing,
+ // software distributed under the License is distributed on an
+ // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ // KIND, either express or implied. See the License for the
+ // specific language governing permissions and limitations
+ // under the License.
+ """))
+
+
+def print_note():
+ print("// Automatically generated file; DO NOT EDIT.")
+ print()
+
+
+def main(simd_width):
+ print_copyright()
+ print_note()
+
+ struct_name = f"UnpackBits{simd_width}"
+
+ define_simd_arch = {
+ # ugly format to get aligned output
+ 128: dedent("""\
+ #ifdef ARROW_HAVE_NEON
+ using simd_arch = xsimd::neon64;
+ #else
+ using simd_arch = xsimd::sse4_2;
+ #endif
+ """),
+ 256: "using simd_arch = xsimd::avx2;",
+ 512: "using simd_arch = xsimd::avx512bw;"
+ }
+
+ # NOTE: templating the UnpackBits struct on the dispatch level avoids
+ # potential name collisions if there are several UnpackBits generations
+ # with the same SIMD width on a given architecture.
+
+ print(dedent(f"""\
+ #pragma once
+
+ #include <cstdint>
+ #include <cstring>
+
+ #include <xsimd/xsimd.hpp>
+
+ #include "arrow/util/dispatch.h"
+ #include "arrow/util/ubsan.h"
+
+ namespace arrow {{
+ namespace internal {{
+ namespace {{
+
+ using ::arrow::util::SafeLoad;
+
+ template <DispatchLevel level>
+ struct {struct_name} {{
+
+ {define_simd_arch[simd_width]}
+ using simd_batch = xsimd::batch<uint32_t, simd_arch>;
+ """))
+
+ gen = UnpackGenerator(simd_width)
+ gen.print_unpack_bit0_func()
+ print()
+ for i in range(1, 32):
+ gen.print_unpack_bit_func(i)
+ print()
+ gen.print_unpack_bit32_func()
+ print()
+
+ print(dedent(f"""\
+ }}; // struct {struct_name}
+
+ }} // namespace
+ }} // namespace internal
+ }} // namespace arrow
+ """))
+
+
+if __name__ == '__main__':
+ usage = f"""Usage: {__file__} <SIMD bit-width>"""
+ if len(sys.argv) != 2:
+ raise ValueError(usage)
+ try:
+ simd_width = int(sys.argv[1])
+ except ValueError:
+ raise ValueError(usage)
+
+ main(simd_width)
diff --git a/src/arrow/cpp/src/arrow/util/bpacking_simd_internal.h b/src/arrow/cpp/src/arrow/util/bpacking_simd_internal.h
new file mode 100644
index 000000000..72d23f2d3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/bpacking_simd_internal.h
@@ -0,0 +1,138 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/dispatch.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace internal {
+
+template <typename UnpackBits>
+static int unpack32_specialized(const uint32_t* in, uint32_t* out, int batch_size,
+ int num_bits) {
+ batch_size = batch_size / 32 * 32;
+ int num_loops = batch_size / 32;
+
+ switch (num_bits) {
+ case 0:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack0_32(in, out + i * 32);
+ break;
+ case 1:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack1_32(in, out + i * 32);
+ break;
+ case 2:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack2_32(in, out + i * 32);
+ break;
+ case 3:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack3_32(in, out + i * 32);
+ break;
+ case 4:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack4_32(in, out + i * 32);
+ break;
+ case 5:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack5_32(in, out + i * 32);
+ break;
+ case 6:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack6_32(in, out + i * 32);
+ break;
+ case 7:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack7_32(in, out + i * 32);
+ break;
+ case 8:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack8_32(in, out + i * 32);
+ break;
+ case 9:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack9_32(in, out + i * 32);
+ break;
+ case 10:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack10_32(in, out + i * 32);
+ break;
+ case 11:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack11_32(in, out + i * 32);
+ break;
+ case 12:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack12_32(in, out + i * 32);
+ break;
+ case 13:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack13_32(in, out + i * 32);
+ break;
+ case 14:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack14_32(in, out + i * 32);
+ break;
+ case 15:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack15_32(in, out + i * 32);
+ break;
+ case 16:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack16_32(in, out + i * 32);
+ break;
+ case 17:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack17_32(in, out + i * 32);
+ break;
+ case 18:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack18_32(in, out + i * 32);
+ break;
+ case 19:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack19_32(in, out + i * 32);
+ break;
+ case 20:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack20_32(in, out + i * 32);
+ break;
+ case 21:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack21_32(in, out + i * 32);
+ break;
+ case 22:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack22_32(in, out + i * 32);
+ break;
+ case 23:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack23_32(in, out + i * 32);
+ break;
+ case 24:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack24_32(in, out + i * 32);
+ break;
+ case 25:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack25_32(in, out + i * 32);
+ break;
+ case 26:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack26_32(in, out + i * 32);
+ break;
+ case 27:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack27_32(in, out + i * 32);
+ break;
+ case 28:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack28_32(in, out + i * 32);
+ break;
+ case 29:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack29_32(in, out + i * 32);
+ break;
+ case 30:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack30_32(in, out + i * 32);
+ break;
+ case 31:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack31_32(in, out + i * 32);
+ break;
+ case 32:
+ for (int i = 0; i < num_loops; ++i) in = UnpackBits::unpack32_32(in, out + i * 32);
+ break;
+ default:
+ DCHECK(false) << "Unsupported num_bits";
+ }
+
+ return batch_size;
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/byte_stream_split.h b/src/arrow/cpp/src/arrow/util/byte_stream_split.h
new file mode 100644
index 000000000..28dcce52b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/byte_stream_split.h
@@ -0,0 +1,626 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/util/simd.h"
+#include "arrow/util/ubsan.h"
+
+#include <stdint.h>
+#include <algorithm>
+
+#ifdef ARROW_HAVE_SSE4_2
+// Enable the SIMD for ByteStreamSplit Encoder/Decoder
+#define ARROW_HAVE_SIMD_SPLIT
+#endif // ARROW_HAVE_SSE4_2
+
+namespace arrow {
+namespace util {
+namespace internal {
+
+#if defined(ARROW_HAVE_SSE4_2)
+template <typename T>
+void ByteStreamSplitDecodeSse2(const uint8_t* data, int64_t num_values, int64_t stride,
+ T* out) {
+ constexpr size_t kNumStreams = sizeof(T);
+ static_assert(kNumStreams == 4U || kNumStreams == 8U, "Invalid number of streams.");
+ constexpr size_t kNumStreamsLog2 = (kNumStreams == 8U ? 3U : 2U);
+
+ const int64_t size = num_values * sizeof(T);
+ constexpr int64_t kBlockSize = sizeof(__m128i) * kNumStreams;
+ const int64_t num_blocks = size / kBlockSize;
+ uint8_t* output_data = reinterpret_cast<uint8_t*>(out);
+
+ // First handle suffix.
+ // This helps catch if the simd-based processing overflows into the suffix
+ // since almost surely a test would fail.
+ const int64_t num_processed_elements = (num_blocks * kBlockSize) / kNumStreams;
+ for (int64_t i = num_processed_elements; i < num_values; ++i) {
+ uint8_t gathered_byte_data[kNumStreams];
+ for (size_t b = 0; b < kNumStreams; ++b) {
+ const size_t byte_index = b * stride + i;
+ gathered_byte_data[b] = data[byte_index];
+ }
+ out[i] = arrow::util::SafeLoadAs<T>(&gathered_byte_data[0]);
+ }
+
+ // The blocks get processed hierarchically using the unpack intrinsics.
+ // Example with four streams:
+ // Stage 1: AAAA BBBB CCCC DDDD
+ // Stage 2: ACAC ACAC BDBD BDBD
+ // Stage 3: ABCD ABCD ABCD ABCD
+ __m128i stage[kNumStreamsLog2 + 1U][kNumStreams];
+ constexpr size_t kNumStreamsHalf = kNumStreams / 2U;
+
+ for (int64_t i = 0; i < num_blocks; ++i) {
+ for (size_t j = 0; j < kNumStreams; ++j) {
+ stage[0][j] = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(&data[i * sizeof(__m128i) + j * stride]));
+ }
+ for (size_t step = 0; step < kNumStreamsLog2; ++step) {
+ for (size_t j = 0; j < kNumStreamsHalf; ++j) {
+ stage[step + 1U][j * 2] =
+ _mm_unpacklo_epi8(stage[step][j], stage[step][kNumStreamsHalf + j]);
+ stage[step + 1U][j * 2 + 1U] =
+ _mm_unpackhi_epi8(stage[step][j], stage[step][kNumStreamsHalf + j]);
+ }
+ }
+ for (size_t j = 0; j < kNumStreams; ++j) {
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(
+ &output_data[(i * kNumStreams + j) * sizeof(__m128i)]),
+ stage[kNumStreamsLog2][j]);
+ }
+ }
+}
+
+template <typename T>
+void ByteStreamSplitEncodeSse2(const uint8_t* raw_values, const size_t num_values,
+ uint8_t* output_buffer_raw) {
+ constexpr size_t kNumStreams = sizeof(T);
+ static_assert(kNumStreams == 4U || kNumStreams == 8U, "Invalid number of streams.");
+ __m128i stage[3][kNumStreams];
+ __m128i final_result[kNumStreams];
+
+ const size_t size = num_values * sizeof(T);
+ constexpr size_t kBlockSize = sizeof(__m128i) * kNumStreams;
+ const size_t num_blocks = size / kBlockSize;
+ const __m128i* raw_values_sse = reinterpret_cast<const __m128i*>(raw_values);
+ __m128i* output_buffer_streams[kNumStreams];
+ for (size_t i = 0; i < kNumStreams; ++i) {
+ output_buffer_streams[i] =
+ reinterpret_cast<__m128i*>(&output_buffer_raw[num_values * i]);
+ }
+
+ // First handle suffix.
+ const size_t num_processed_elements = (num_blocks * kBlockSize) / sizeof(T);
+ for (size_t i = num_processed_elements; i < num_values; ++i) {
+ for (size_t j = 0U; j < kNumStreams; ++j) {
+ const uint8_t byte_in_value = raw_values[i * kNumStreams + j];
+ output_buffer_raw[j * num_values + i] = byte_in_value;
+ }
+ }
+ // The current shuffling algorithm diverges for float and double types but the compiler
+ // should be able to remove the branch since only one path is taken for each template
+ // instantiation.
+ // Example run for floats:
+ // Step 0, copy:
+ // 0: ABCD ABCD ABCD ABCD 1: ABCD ABCD ABCD ABCD ...
+ // Step 1: _mm_unpacklo_epi8 and mm_unpackhi_epi8:
+ // 0: AABB CCDD AABB CCDD 1: AABB CCDD AABB CCDD ...
+ // 0: AAAA BBBB CCCC DDDD 1: AAAA BBBB CCCC DDDD ...
+ // Step 3: __mm_unpacklo_epi8 and _mm_unpackhi_epi8:
+ // 0: AAAA AAAA BBBB BBBB 1: CCCC CCCC DDDD DDDD ...
+ // Step 4: __mm_unpacklo_epi64 and _mm_unpackhi_epi64:
+ // 0: AAAA AAAA AAAA AAAA 1: BBBB BBBB BBBB BBBB ...
+ for (size_t block_index = 0; block_index < num_blocks; ++block_index) {
+ // First copy the data to stage 0.
+ for (size_t i = 0; i < kNumStreams; ++i) {
+ stage[0][i] = _mm_loadu_si128(&raw_values_sse[block_index * kNumStreams + i]);
+ }
+
+ // The shuffling of bytes is performed through the unpack intrinsics.
+ // In my measurements this gives better performance then an implementation
+ // which uses the shuffle intrinsics.
+ for (size_t stage_lvl = 0; stage_lvl < 2U; ++stage_lvl) {
+ for (size_t i = 0; i < kNumStreams / 2U; ++i) {
+ stage[stage_lvl + 1][i * 2] =
+ _mm_unpacklo_epi8(stage[stage_lvl][i * 2], stage[stage_lvl][i * 2 + 1]);
+ stage[stage_lvl + 1][i * 2 + 1] =
+ _mm_unpackhi_epi8(stage[stage_lvl][i * 2], stage[stage_lvl][i * 2 + 1]);
+ }
+ }
+ if (kNumStreams == 8U) {
+ // This is the path for double.
+ __m128i tmp[8];
+ for (size_t i = 0; i < 4; ++i) {
+ tmp[i * 2] = _mm_unpacklo_epi32(stage[2][i], stage[2][i + 4]);
+ tmp[i * 2 + 1] = _mm_unpackhi_epi32(stage[2][i], stage[2][i + 4]);
+ }
+
+ for (size_t i = 0; i < 4; ++i) {
+ final_result[i * 2] = _mm_unpacklo_epi32(tmp[i], tmp[i + 4]);
+ final_result[i * 2 + 1] = _mm_unpackhi_epi32(tmp[i], tmp[i + 4]);
+ }
+ } else {
+ // this is the path for float.
+ __m128i tmp[4];
+ for (size_t i = 0; i < 2; ++i) {
+ tmp[i * 2] = _mm_unpacklo_epi8(stage[2][i * 2], stage[2][i * 2 + 1]);
+ tmp[i * 2 + 1] = _mm_unpackhi_epi8(stage[2][i * 2], stage[2][i * 2 + 1]);
+ }
+ for (size_t i = 0; i < 2; ++i) {
+ final_result[i * 2] = _mm_unpacklo_epi64(tmp[i], tmp[i + 2]);
+ final_result[i * 2 + 1] = _mm_unpackhi_epi64(tmp[i], tmp[i + 2]);
+ }
+ }
+ for (size_t i = 0; i < kNumStreams; ++i) {
+ _mm_storeu_si128(&output_buffer_streams[i][block_index], final_result[i]);
+ }
+ }
+}
+#endif // ARROW_HAVE_SSE4_2
+
+#if defined(ARROW_HAVE_AVX2)
+template <typename T>
+void ByteStreamSplitDecodeAvx2(const uint8_t* data, int64_t num_values, int64_t stride,
+ T* out) {
+ constexpr size_t kNumStreams = sizeof(T);
+ static_assert(kNumStreams == 4U || kNumStreams == 8U, "Invalid number of streams.");
+ constexpr size_t kNumStreamsLog2 = (kNumStreams == 8U ? 3U : 2U);
+
+ const int64_t size = num_values * sizeof(T);
+ constexpr int64_t kBlockSize = sizeof(__m256i) * kNumStreams;
+ if (size < kBlockSize) // Back to SSE for small size
+ return ByteStreamSplitDecodeSse2(data, num_values, stride, out);
+ const int64_t num_blocks = size / kBlockSize;
+ uint8_t* output_data = reinterpret_cast<uint8_t*>(out);
+
+ // First handle suffix.
+ const int64_t num_processed_elements = (num_blocks * kBlockSize) / kNumStreams;
+ for (int64_t i = num_processed_elements; i < num_values; ++i) {
+ uint8_t gathered_byte_data[kNumStreams];
+ for (size_t b = 0; b < kNumStreams; ++b) {
+ const size_t byte_index = b * stride + i;
+ gathered_byte_data[b] = data[byte_index];
+ }
+ out[i] = arrow::util::SafeLoadAs<T>(&gathered_byte_data[0]);
+ }
+
+ // Processed hierarchically using unpack intrinsics, then permute intrinsics.
+ __m256i stage[kNumStreamsLog2 + 1U][kNumStreams];
+ __m256i final_result[kNumStreams];
+ constexpr size_t kNumStreamsHalf = kNumStreams / 2U;
+
+ for (int64_t i = 0; i < num_blocks; ++i) {
+ for (size_t j = 0; j < kNumStreams; ++j) {
+ stage[0][j] = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(&data[i * sizeof(__m256i) + j * stride]));
+ }
+
+ for (size_t step = 0; step < kNumStreamsLog2; ++step) {
+ for (size_t j = 0; j < kNumStreamsHalf; ++j) {
+ stage[step + 1U][j * 2] =
+ _mm256_unpacklo_epi8(stage[step][j], stage[step][kNumStreamsHalf + j]);
+ stage[step + 1U][j * 2 + 1U] =
+ _mm256_unpackhi_epi8(stage[step][j], stage[step][kNumStreamsHalf + j]);
+ }
+ }
+
+ if (kNumStreams == 8U) {
+ // path for double, 128i index:
+ // {0x00, 0x08}, {0x01, 0x09}, {0x02, 0x0A}, {0x03, 0x0B},
+ // {0x04, 0x0C}, {0x05, 0x0D}, {0x06, 0x0E}, {0x07, 0x0F},
+ final_result[0] = _mm256_permute2x128_si256(stage[kNumStreamsLog2][0],
+ stage[kNumStreamsLog2][1], 0b00100000);
+ final_result[1] = _mm256_permute2x128_si256(stage[kNumStreamsLog2][2],
+ stage[kNumStreamsLog2][3], 0b00100000);
+ final_result[2] = _mm256_permute2x128_si256(stage[kNumStreamsLog2][4],
+ stage[kNumStreamsLog2][5], 0b00100000);
+ final_result[3] = _mm256_permute2x128_si256(stage[kNumStreamsLog2][6],
+ stage[kNumStreamsLog2][7], 0b00100000);
+ final_result[4] = _mm256_permute2x128_si256(stage[kNumStreamsLog2][0],
+ stage[kNumStreamsLog2][1], 0b00110001);
+ final_result[5] = _mm256_permute2x128_si256(stage[kNumStreamsLog2][2],
+ stage[kNumStreamsLog2][3], 0b00110001);
+ final_result[6] = _mm256_permute2x128_si256(stage[kNumStreamsLog2][4],
+ stage[kNumStreamsLog2][5], 0b00110001);
+ final_result[7] = _mm256_permute2x128_si256(stage[kNumStreamsLog2][6],
+ stage[kNumStreamsLog2][7], 0b00110001);
+ } else {
+ // path for float, 128i index:
+ // {0x00, 0x04}, {0x01, 0x05}, {0x02, 0x06}, {0x03, 0x07}
+ final_result[0] = _mm256_permute2x128_si256(stage[kNumStreamsLog2][0],
+ stage[kNumStreamsLog2][1], 0b00100000);
+ final_result[1] = _mm256_permute2x128_si256(stage[kNumStreamsLog2][2],
+ stage[kNumStreamsLog2][3], 0b00100000);
+ final_result[2] = _mm256_permute2x128_si256(stage[kNumStreamsLog2][0],
+ stage[kNumStreamsLog2][1], 0b00110001);
+ final_result[3] = _mm256_permute2x128_si256(stage[kNumStreamsLog2][2],
+ stage[kNumStreamsLog2][3], 0b00110001);
+ }
+
+ for (size_t j = 0; j < kNumStreams; ++j) {
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(
+ &output_data[(i * kNumStreams + j) * sizeof(__m256i)]),
+ final_result[j]);
+ }
+ }
+}
+
+template <typename T>
+void ByteStreamSplitEncodeAvx2(const uint8_t* raw_values, const size_t num_values,
+ uint8_t* output_buffer_raw) {
+ constexpr size_t kNumStreams = sizeof(T);
+ static_assert(kNumStreams == 4U || kNumStreams == 8U, "Invalid number of streams.");
+ if (kNumStreams == 8U) // Back to SSE, currently no path for double.
+ return ByteStreamSplitEncodeSse2<T>(raw_values, num_values, output_buffer_raw);
+
+ const size_t size = num_values * sizeof(T);
+ constexpr size_t kBlockSize = sizeof(__m256i) * kNumStreams;
+ if (size < kBlockSize) // Back to SSE for small size
+ return ByteStreamSplitEncodeSse2<T>(raw_values, num_values, output_buffer_raw);
+ const size_t num_blocks = size / kBlockSize;
+ const __m256i* raw_values_simd = reinterpret_cast<const __m256i*>(raw_values);
+ __m256i* output_buffer_streams[kNumStreams];
+
+ for (size_t i = 0; i < kNumStreams; ++i) {
+ output_buffer_streams[i] =
+ reinterpret_cast<__m256i*>(&output_buffer_raw[num_values * i]);
+ }
+
+ // First handle suffix.
+ const size_t num_processed_elements = (num_blocks * kBlockSize) / sizeof(T);
+ for (size_t i = num_processed_elements; i < num_values; ++i) {
+ for (size_t j = 0U; j < kNumStreams; ++j) {
+ const uint8_t byte_in_value = raw_values[i * kNumStreams + j];
+ output_buffer_raw[j * num_values + i] = byte_in_value;
+ }
+ }
+
+ // Path for float.
+ // 1. Processed hierarchically to 32i blcok using the unpack intrinsics.
+ // 2. Pack 128i block using _mm256_permutevar8x32_epi32.
+ // 3. Pack final 256i block with _mm256_permute2x128_si256.
+ constexpr size_t kNumUnpack = 3U;
+ __m256i stage[kNumUnpack + 1][kNumStreams];
+ static const __m256i kPermuteMask =
+ _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
+ __m256i permute[kNumStreams];
+ __m256i final_result[kNumStreams];
+
+ for (size_t block_index = 0; block_index < num_blocks; ++block_index) {
+ for (size_t i = 0; i < kNumStreams; ++i) {
+ stage[0][i] = _mm256_loadu_si256(&raw_values_simd[block_index * kNumStreams + i]);
+ }
+
+ for (size_t stage_lvl = 0; stage_lvl < kNumUnpack; ++stage_lvl) {
+ for (size_t i = 0; i < kNumStreams / 2U; ++i) {
+ stage[stage_lvl + 1][i * 2] =
+ _mm256_unpacklo_epi8(stage[stage_lvl][i * 2], stage[stage_lvl][i * 2 + 1]);
+ stage[stage_lvl + 1][i * 2 + 1] =
+ _mm256_unpackhi_epi8(stage[stage_lvl][i * 2], stage[stage_lvl][i * 2 + 1]);
+ }
+ }
+
+ for (size_t i = 0; i < kNumStreams; ++i) {
+ permute[i] = _mm256_permutevar8x32_epi32(stage[kNumUnpack][i], kPermuteMask);
+ }
+
+ final_result[0] = _mm256_permute2x128_si256(permute[0], permute[2], 0b00100000);
+ final_result[1] = _mm256_permute2x128_si256(permute[0], permute[2], 0b00110001);
+ final_result[2] = _mm256_permute2x128_si256(permute[1], permute[3], 0b00100000);
+ final_result[3] = _mm256_permute2x128_si256(permute[1], permute[3], 0b00110001);
+
+ for (size_t i = 0; i < kNumStreams; ++i) {
+ _mm256_storeu_si256(&output_buffer_streams[i][block_index], final_result[i]);
+ }
+ }
+}
+#endif // ARROW_HAVE_AVX2
+
+#if defined(ARROW_HAVE_AVX512)
+template <typename T>
+void ByteStreamSplitDecodeAvx512(const uint8_t* data, int64_t num_values, int64_t stride,
+ T* out) {
+ constexpr size_t kNumStreams = sizeof(T);
+ static_assert(kNumStreams == 4U || kNumStreams == 8U, "Invalid number of streams.");
+ constexpr size_t kNumStreamsLog2 = (kNumStreams == 8U ? 3U : 2U);
+
+ const int64_t size = num_values * sizeof(T);
+ constexpr int64_t kBlockSize = sizeof(__m512i) * kNumStreams;
+ if (size < kBlockSize) // Back to AVX2 for small size
+ return ByteStreamSplitDecodeAvx2(data, num_values, stride, out);
+ const int64_t num_blocks = size / kBlockSize;
+ uint8_t* output_data = reinterpret_cast<uint8_t*>(out);
+
+ // First handle suffix.
+ const int64_t num_processed_elements = (num_blocks * kBlockSize) / kNumStreams;
+ for (int64_t i = num_processed_elements; i < num_values; ++i) {
+ uint8_t gathered_byte_data[kNumStreams];
+ for (size_t b = 0; b < kNumStreams; ++b) {
+ const size_t byte_index = b * stride + i;
+ gathered_byte_data[b] = data[byte_index];
+ }
+ out[i] = arrow::util::SafeLoadAs<T>(&gathered_byte_data[0]);
+ }
+
+ // Processed hierarchically using the unpack, then two shuffles.
+ __m512i stage[kNumStreamsLog2 + 1U][kNumStreams];
+ __m512i shuffle[kNumStreams];
+ __m512i final_result[kNumStreams];
+ constexpr size_t kNumStreamsHalf = kNumStreams / 2U;
+
+ for (int64_t i = 0; i < num_blocks; ++i) {
+ for (size_t j = 0; j < kNumStreams; ++j) {
+ stage[0][j] = _mm512_loadu_si512(
+ reinterpret_cast<const __m512i*>(&data[i * sizeof(__m512i) + j * stride]));
+ }
+
+ for (size_t step = 0; step < kNumStreamsLog2; ++step) {
+ for (size_t j = 0; j < kNumStreamsHalf; ++j) {
+ stage[step + 1U][j * 2] =
+ _mm512_unpacklo_epi8(stage[step][j], stage[step][kNumStreamsHalf + j]);
+ stage[step + 1U][j * 2 + 1U] =
+ _mm512_unpackhi_epi8(stage[step][j], stage[step][kNumStreamsHalf + j]);
+ }
+ }
+
+ if (kNumStreams == 8U) {
+ // path for double, 128i index:
+ // {0x00, 0x04, 0x08, 0x0C}, {0x10, 0x14, 0x18, 0x1C},
+ // {0x01, 0x05, 0x09, 0x0D}, {0x11, 0x15, 0x19, 0x1D},
+ // {0x02, 0x06, 0x0A, 0x0E}, {0x12, 0x16, 0x1A, 0x1E},
+ // {0x03, 0x07, 0x0B, 0x0F}, {0x13, 0x17, 0x1B, 0x1F},
+ shuffle[0] = _mm512_shuffle_i32x4(stage[kNumStreamsLog2][0],
+ stage[kNumStreamsLog2][1], 0b01000100);
+ shuffle[1] = _mm512_shuffle_i32x4(stage[kNumStreamsLog2][2],
+ stage[kNumStreamsLog2][3], 0b01000100);
+ shuffle[2] = _mm512_shuffle_i32x4(stage[kNumStreamsLog2][4],
+ stage[kNumStreamsLog2][5], 0b01000100);
+ shuffle[3] = _mm512_shuffle_i32x4(stage[kNumStreamsLog2][6],
+ stage[kNumStreamsLog2][7], 0b01000100);
+ shuffle[4] = _mm512_shuffle_i32x4(stage[kNumStreamsLog2][0],
+ stage[kNumStreamsLog2][1], 0b11101110);
+ shuffle[5] = _mm512_shuffle_i32x4(stage[kNumStreamsLog2][2],
+ stage[kNumStreamsLog2][3], 0b11101110);
+ shuffle[6] = _mm512_shuffle_i32x4(stage[kNumStreamsLog2][4],
+ stage[kNumStreamsLog2][5], 0b11101110);
+ shuffle[7] = _mm512_shuffle_i32x4(stage[kNumStreamsLog2][6],
+ stage[kNumStreamsLog2][7], 0b11101110);
+
+ final_result[0] = _mm512_shuffle_i32x4(shuffle[0], shuffle[1], 0b10001000);
+ final_result[1] = _mm512_shuffle_i32x4(shuffle[2], shuffle[3], 0b10001000);
+ final_result[2] = _mm512_shuffle_i32x4(shuffle[0], shuffle[1], 0b11011101);
+ final_result[3] = _mm512_shuffle_i32x4(shuffle[2], shuffle[3], 0b11011101);
+ final_result[4] = _mm512_shuffle_i32x4(shuffle[4], shuffle[5], 0b10001000);
+ final_result[5] = _mm512_shuffle_i32x4(shuffle[6], shuffle[7], 0b10001000);
+ final_result[6] = _mm512_shuffle_i32x4(shuffle[4], shuffle[5], 0b11011101);
+ final_result[7] = _mm512_shuffle_i32x4(shuffle[6], shuffle[7], 0b11011101);
+ } else {
+ // path for float, 128i index:
+ // {0x00, 0x04, 0x08, 0x0C}, {0x01, 0x05, 0x09, 0x0D}
+ // {0x02, 0x06, 0x0A, 0x0E}, {0x03, 0x07, 0x0B, 0x0F},
+ shuffle[0] = _mm512_shuffle_i32x4(stage[kNumStreamsLog2][0],
+ stage[kNumStreamsLog2][1], 0b01000100);
+ shuffle[1] = _mm512_shuffle_i32x4(stage[kNumStreamsLog2][2],
+ stage[kNumStreamsLog2][3], 0b01000100);
+ shuffle[2] = _mm512_shuffle_i32x4(stage[kNumStreamsLog2][0],
+ stage[kNumStreamsLog2][1], 0b11101110);
+ shuffle[3] = _mm512_shuffle_i32x4(stage[kNumStreamsLog2][2],
+ stage[kNumStreamsLog2][3], 0b11101110);
+
+ final_result[0] = _mm512_shuffle_i32x4(shuffle[0], shuffle[1], 0b10001000);
+ final_result[1] = _mm512_shuffle_i32x4(shuffle[0], shuffle[1], 0b11011101);
+ final_result[2] = _mm512_shuffle_i32x4(shuffle[2], shuffle[3], 0b10001000);
+ final_result[3] = _mm512_shuffle_i32x4(shuffle[2], shuffle[3], 0b11011101);
+ }
+
+ for (size_t j = 0; j < kNumStreams; ++j) {
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(
+ &output_data[(i * kNumStreams + j) * sizeof(__m512i)]),
+ final_result[j]);
+ }
+ }
+}
+
+template <typename T>
+void ByteStreamSplitEncodeAvx512(const uint8_t* raw_values, const size_t num_values,
+ uint8_t* output_buffer_raw) {
+ constexpr size_t kNumStreams = sizeof(T);
+ static_assert(kNumStreams == 4U || kNumStreams == 8U, "Invalid number of streams.");
+ const size_t size = num_values * sizeof(T);
+ constexpr size_t kBlockSize = sizeof(__m512i) * kNumStreams;
+ if (size < kBlockSize) // Back to AVX2 for small size
+ return ByteStreamSplitEncodeAvx2<T>(raw_values, num_values, output_buffer_raw);
+
+ const size_t num_blocks = size / kBlockSize;
+ const __m512i* raw_values_simd = reinterpret_cast<const __m512i*>(raw_values);
+ __m512i* output_buffer_streams[kNumStreams];
+ for (size_t i = 0; i < kNumStreams; ++i) {
+ output_buffer_streams[i] =
+ reinterpret_cast<__m512i*>(&output_buffer_raw[num_values * i]);
+ }
+
+ // First handle suffix.
+ const size_t num_processed_elements = (num_blocks * kBlockSize) / sizeof(T);
+ for (size_t i = num_processed_elements; i < num_values; ++i) {
+ for (size_t j = 0U; j < kNumStreams; ++j) {
+ const uint8_t byte_in_value = raw_values[i * kNumStreams + j];
+ output_buffer_raw[j * num_values + i] = byte_in_value;
+ }
+ }
+
+ constexpr size_t KNumUnpack = (kNumStreams == 8U) ? 2U : 3U;
+ __m512i final_result[kNumStreams];
+ __m512i unpack[KNumUnpack + 1][kNumStreams];
+ __m512i permutex[kNumStreams];
+ __m512i permutex_mask;
+ if (kNumStreams == 8U) {
+ // use _mm512_set_epi32, no _mm512_set_epi16 for some old gcc version.
+ permutex_mask = _mm512_set_epi32(0x001F0017, 0x000F0007, 0x001E0016, 0x000E0006,
+ 0x001D0015, 0x000D0005, 0x001C0014, 0x000C0004,
+ 0x001B0013, 0x000B0003, 0x001A0012, 0x000A0002,
+ 0x00190011, 0x00090001, 0x00180010, 0x00080000);
+ } else {
+ permutex_mask = _mm512_set_epi32(0x0F, 0x0B, 0x07, 0x03, 0x0E, 0x0A, 0x06, 0x02, 0x0D,
+ 0x09, 0x05, 0x01, 0x0C, 0x08, 0x04, 0x00);
+ }
+
+ for (size_t block_index = 0; block_index < num_blocks; ++block_index) {
+ for (size_t i = 0; i < kNumStreams; ++i) {
+ unpack[0][i] = _mm512_loadu_si512(&raw_values_simd[block_index * kNumStreams + i]);
+ }
+
+ for (size_t unpack_lvl = 0; unpack_lvl < KNumUnpack; ++unpack_lvl) {
+ for (size_t i = 0; i < kNumStreams / 2U; ++i) {
+ unpack[unpack_lvl + 1][i * 2] = _mm512_unpacklo_epi8(
+ unpack[unpack_lvl][i * 2], unpack[unpack_lvl][i * 2 + 1]);
+ unpack[unpack_lvl + 1][i * 2 + 1] = _mm512_unpackhi_epi8(
+ unpack[unpack_lvl][i * 2], unpack[unpack_lvl][i * 2 + 1]);
+ }
+ }
+
+ if (kNumStreams == 8U) {
+ // path for double
+ // 1. unpack to epi16 block
+ // 2. permutexvar_epi16 to 128i block
+ // 3. shuffle 128i to final 512i target, index:
+ // {0x00, 0x04, 0x08, 0x0C}, {0x10, 0x14, 0x18, 0x1C},
+ // {0x01, 0x05, 0x09, 0x0D}, {0x11, 0x15, 0x19, 0x1D},
+ // {0x02, 0x06, 0x0A, 0x0E}, {0x12, 0x16, 0x1A, 0x1E},
+ // {0x03, 0x07, 0x0B, 0x0F}, {0x13, 0x17, 0x1B, 0x1F},
+ for (size_t i = 0; i < kNumStreams; ++i)
+ permutex[i] = _mm512_permutexvar_epi16(permutex_mask, unpack[KNumUnpack][i]);
+
+ __m512i shuffle[kNumStreams];
+ shuffle[0] = _mm512_shuffle_i32x4(permutex[0], permutex[2], 0b01000100);
+ shuffle[1] = _mm512_shuffle_i32x4(permutex[4], permutex[6], 0b01000100);
+ shuffle[2] = _mm512_shuffle_i32x4(permutex[0], permutex[2], 0b11101110);
+ shuffle[3] = _mm512_shuffle_i32x4(permutex[4], permutex[6], 0b11101110);
+ shuffle[4] = _mm512_shuffle_i32x4(permutex[1], permutex[3], 0b01000100);
+ shuffle[5] = _mm512_shuffle_i32x4(permutex[5], permutex[7], 0b01000100);
+ shuffle[6] = _mm512_shuffle_i32x4(permutex[1], permutex[3], 0b11101110);
+ shuffle[7] = _mm512_shuffle_i32x4(permutex[5], permutex[7], 0b11101110);
+
+ final_result[0] = _mm512_shuffle_i32x4(shuffle[0], shuffle[1], 0b10001000);
+ final_result[1] = _mm512_shuffle_i32x4(shuffle[0], shuffle[1], 0b11011101);
+ final_result[2] = _mm512_shuffle_i32x4(shuffle[2], shuffle[3], 0b10001000);
+ final_result[3] = _mm512_shuffle_i32x4(shuffle[2], shuffle[3], 0b11011101);
+ final_result[4] = _mm512_shuffle_i32x4(shuffle[4], shuffle[5], 0b10001000);
+ final_result[5] = _mm512_shuffle_i32x4(shuffle[4], shuffle[5], 0b11011101);
+ final_result[6] = _mm512_shuffle_i32x4(shuffle[6], shuffle[7], 0b10001000);
+ final_result[7] = _mm512_shuffle_i32x4(shuffle[6], shuffle[7], 0b11011101);
+ } else {
+ // Path for float.
+ // 1. Processed hierarchically to 32i blcok using the unpack intrinsics.
+ // 2. Pack 128i block using _mm256_permutevar8x32_epi32.
+ // 3. Pack final 256i block with _mm256_permute2x128_si256.
+ for (size_t i = 0; i < kNumStreams; ++i)
+ permutex[i] = _mm512_permutexvar_epi32(permutex_mask, unpack[KNumUnpack][i]);
+
+ final_result[0] = _mm512_shuffle_i32x4(permutex[0], permutex[2], 0b01000100);
+ final_result[1] = _mm512_shuffle_i32x4(permutex[0], permutex[2], 0b11101110);
+ final_result[2] = _mm512_shuffle_i32x4(permutex[1], permutex[3], 0b01000100);
+ final_result[3] = _mm512_shuffle_i32x4(permutex[1], permutex[3], 0b11101110);
+ }
+
+ for (size_t i = 0; i < kNumStreams; ++i) {
+ _mm512_storeu_si512(&output_buffer_streams[i][block_index], final_result[i]);
+ }
+ }
+}
+#endif // ARROW_HAVE_AVX512
+
+#if defined(ARROW_HAVE_SIMD_SPLIT)
+template <typename T>
+void inline ByteStreamSplitDecodeSimd(const uint8_t* data, int64_t num_values,
+ int64_t stride, T* out) {
+#if defined(ARROW_HAVE_AVX512)
+ return ByteStreamSplitDecodeAvx512(data, num_values, stride, out);
+#elif defined(ARROW_HAVE_AVX2)
+ return ByteStreamSplitDecodeAvx2(data, num_values, stride, out);
+#elif defined(ARROW_HAVE_SSE4_2)
+ return ByteStreamSplitDecodeSse2(data, num_values, stride, out);
+#else
+#error "ByteStreamSplitDecodeSimd not implemented"
+#endif
+}
+
+template <typename T>
+void inline ByteStreamSplitEncodeSimd(const uint8_t* raw_values, const size_t num_values,
+ uint8_t* output_buffer_raw) {
+#if defined(ARROW_HAVE_AVX512)
+ return ByteStreamSplitEncodeAvx512<T>(raw_values, num_values, output_buffer_raw);
+#elif defined(ARROW_HAVE_AVX2)
+ return ByteStreamSplitEncodeAvx2<T>(raw_values, num_values, output_buffer_raw);
+#elif defined(ARROW_HAVE_SSE4_2)
+ return ByteStreamSplitEncodeSse2<T>(raw_values, num_values, output_buffer_raw);
+#else
+#error "ByteStreamSplitEncodeSimd not implemented"
+#endif
+}
+#endif
+
+template <typename T>
+void ByteStreamSplitEncodeScalar(const uint8_t* raw_values, const size_t num_values,
+ uint8_t* output_buffer_raw) {
+ constexpr size_t kNumStreams = sizeof(T);
+ for (size_t i = 0U; i < num_values; ++i) {
+ for (size_t j = 0U; j < kNumStreams; ++j) {
+ const uint8_t byte_in_value = raw_values[i * kNumStreams + j];
+ output_buffer_raw[j * num_values + i] = byte_in_value;
+ }
+ }
+}
+
+template <typename T>
+void ByteStreamSplitDecodeScalar(const uint8_t* data, int64_t num_values, int64_t stride,
+ T* out) {
+ constexpr size_t kNumStreams = sizeof(T);
+ auto output_buffer_raw = reinterpret_cast<uint8_t*>(out);
+
+ for (int64_t i = 0; i < num_values; ++i) {
+ for (size_t b = 0; b < kNumStreams; ++b) {
+ const size_t byte_index = b * stride + i;
+ output_buffer_raw[i * kNumStreams + b] = data[byte_index];
+ }
+ }
+}
+
+template <typename T>
+void inline ByteStreamSplitEncode(const uint8_t* raw_values, const size_t num_values,
+ uint8_t* output_buffer_raw) {
+#if defined(ARROW_HAVE_SIMD_SPLIT)
+ return ByteStreamSplitEncodeSimd<T>(raw_values, num_values, output_buffer_raw);
+#else
+ return ByteStreamSplitEncodeScalar<T>(raw_values, num_values, output_buffer_raw);
+#endif
+}
+
+template <typename T>
+void inline ByteStreamSplitDecode(const uint8_t* data, int64_t num_values, int64_t stride,
+ T* out) {
+#if defined(ARROW_HAVE_SIMD_SPLIT)
+ return ByteStreamSplitDecodeSimd(data, num_values, stride, out);
+#else
+ return ByteStreamSplitDecodeScalar(data, num_values, stride, out);
+#endif
+}
+
+} // namespace internal
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/cache_benchmark.cc b/src/arrow/cpp/src/arrow/util/cache_benchmark.cc
new file mode 100644
index 000000000..7439ee2f5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/cache_benchmark.cc
@@ -0,0 +1,146 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/cache_internal.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace internal {
+
+static constexpr int32_t kCacheSize = 100;
+static constexpr int32_t kSmallKeyLength = 8;
+static constexpr int32_t kLargeKeyLength = 64;
+static constexpr int32_t kSmallValueLength = 16;
+static constexpr int32_t kLargeValueLength = 1024;
+
+static std::vector<std::string> MakeStrings(int64_t nvalues, int64_t min_length,
+ int64_t max_length) {
+ auto rng = ::arrow::random::RandomArrayGenerator(42);
+ auto arr = checked_pointer_cast<StringArray>(rng.String(
+ nvalues, static_cast<int32_t>(min_length), static_cast<int32_t>(max_length)));
+ std::vector<std::string> vec(nvalues);
+ for (int64_t i = 0; i < nvalues; ++i) {
+ vec[i] = arr->GetString(i);
+ }
+ return vec;
+}
+
+static std::vector<std::string> MakeStrings(int64_t nvalues, int64_t length) {
+ return MakeStrings(nvalues, length, length);
+}
+
+template <typename Cache, typename Key, typename Value>
+static void BenchmarkCacheLookups(benchmark::State& state, const std::vector<Key>& keys,
+ const std::vector<Value>& values) {
+ const int32_t nitems = static_cast<int32_t>(keys.size());
+ Cache cache(nitems);
+ for (int32_t i = 0; i < nitems; ++i) {
+ cache.Replace(keys[i], values[i]);
+ }
+
+ for (auto _ : state) {
+ int64_t nfinds = 0;
+ for (const auto& key : keys) {
+ nfinds += (cache.Find(key) != nullptr);
+ }
+ benchmark::DoNotOptimize(nfinds);
+ ARROW_CHECK_EQ(nfinds, nitems);
+ }
+ state.SetItemsProcessed(state.iterations() * nitems);
+}
+
+static void LruCacheLookup(benchmark::State& state) {
+ const auto keys = MakeStrings(kCacheSize, state.range(0));
+ const auto values = MakeStrings(kCacheSize, state.range(1));
+ BenchmarkCacheLookups<LruCache<std::string, std::string>>(state, keys, values);
+}
+
+static void SetCacheArgs(benchmark::internal::Benchmark* bench) {
+ bench->Args({kSmallKeyLength, kSmallValueLength});
+ bench->Args({kSmallKeyLength, kLargeValueLength});
+ bench->Args({kLargeKeyLength, kSmallValueLength});
+ bench->Args({kLargeKeyLength, kLargeValueLength});
+}
+
+BENCHMARK(LruCacheLookup)->Apply(SetCacheArgs);
+
+struct Callable {
+ explicit Callable(std::vector<std::string> values)
+ : index_(0), values_(std::move(values)) {}
+
+ std::string operator()(const std::string& key) {
+ // Return a value unrelated to the key
+ if (++index_ >= static_cast<int64_t>(values_.size())) {
+ index_ = 0;
+ }
+ return values_[index_];
+ }
+
+ private:
+ int64_t index_;
+ std::vector<std::string> values_;
+};
+
+template <typename Memoized>
+static void BenchmarkMemoize(benchmark::State& state, Memoized&& mem,
+ const std::vector<std::string>& keys) {
+ // Prime memoization cache
+ for (const auto& key : keys) {
+ mem(key);
+ }
+
+ for (auto _ : state) {
+ int64_t nbytes = 0;
+ for (const auto& key : keys) {
+ nbytes += static_cast<int64_t>(mem(key).length());
+ }
+ benchmark::DoNotOptimize(nbytes);
+ }
+ state.SetItemsProcessed(state.iterations() * keys.size());
+}
+
+static void MemoizeLruCached(benchmark::State& state) {
+ const auto keys = MakeStrings(kCacheSize, state.range(0));
+ const auto values = MakeStrings(kCacheSize, state.range(1));
+ auto mem = MemoizeLru(Callable(values), kCacheSize);
+ BenchmarkMemoize(state, mem, keys);
+}
+
+static void MemoizeLruCachedThreadUnsafe(benchmark::State& state) {
+ const auto keys = MakeStrings(kCacheSize, state.range(0));
+ const auto values = MakeStrings(kCacheSize, state.range(1));
+ // Emulate recommended usage of MemoizeLruCachedThreadUnsafe
+ // (the compiler is probably able to cache the TLS-looked up value, though)
+ thread_local auto mem = MemoizeLruThreadUnsafe(Callable(values), kCacheSize);
+ BenchmarkMemoize(state, mem, keys);
+}
+
+BENCHMARK(MemoizeLruCached)->Apply(SetCacheArgs);
+BENCHMARK(MemoizeLruCachedThreadUnsafe)->Apply(SetCacheArgs);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/cache_internal.h b/src/arrow/cpp/src/arrow/util/cache_internal.h
new file mode 100644
index 000000000..231fd800b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/cache_internal.h
@@ -0,0 +1,210 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <functional>
+#include <list>
+#include <memory>
+#include <mutex>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "arrow/util/functional.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace internal {
+
+// A LRU (Least recently used) replacement cache
+template <typename Key, typename Value>
+class LruCache {
+ public:
+ explicit LruCache(int32_t capacity) : capacity_(capacity) {
+ // The map size can temporarily exceed the cache capacity, see Replace()
+ map_.reserve(capacity_ + 1);
+ }
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(LruCache);
+ ARROW_DEFAULT_MOVE_AND_ASSIGN(LruCache);
+
+ void Clear() {
+ items_.clear();
+ map_.clear();
+ // The C++ spec doesn't tell whether map_.clear() will shrink the map capacity
+ map_.reserve(capacity_ + 1);
+ }
+
+ int32_t size() const {
+ DCHECK_EQ(items_.size(), map_.size());
+ return static_cast<int32_t>(items_.size());
+ }
+
+ template <typename K>
+ Value* Find(K&& key) {
+ const auto it = map_.find(key);
+ if (it == map_.end()) {
+ return nullptr;
+ } else {
+ // Found => move item at front of the list
+ auto list_it = it->second;
+ items_.splice(items_.begin(), items_, list_it);
+ return &list_it->value;
+ }
+ }
+
+ template <typename K, typename V>
+ std::pair<bool, Value*> Replace(K&& key, V&& value) {
+ // Try to insert temporary iterator
+ auto pair = map_.emplace(std::forward<K>(key), ListIt{});
+ const auto it = pair.first;
+ const bool inserted = pair.second;
+ if (inserted) {
+ // Inserted => push item at front of the list, and update iterator
+ items_.push_front(Item{&it->first, std::forward<V>(value)});
+ it->second = items_.begin();
+ // Did we exceed the cache capacity? If so, remove least recently used item
+ if (static_cast<int32_t>(items_.size()) > capacity_) {
+ const bool erased = map_.erase(*items_.back().key);
+ DCHECK(erased);
+ ARROW_UNUSED(erased);
+ items_.pop_back();
+ }
+ return {true, &it->second->value};
+ } else {
+ // Already exists => move item at front of the list, and update value
+ auto list_it = it->second;
+ items_.splice(items_.begin(), items_, list_it);
+ list_it->value = std::forward<V>(value);
+ return {false, &list_it->value};
+ }
+ }
+
+ private:
+ struct Item {
+ // Pointer to the key inside the unordered_map
+ const Key* key;
+ Value value;
+ };
+ using List = std::list<Item>;
+ using ListIt = typename List::iterator;
+
+ const int32_t capacity_;
+ // In most to least recently used order
+ std::list<Item> items_;
+ std::unordered_map<Key, ListIt> map_;
+};
+
+namespace detail {
+
+template <typename Key, typename Value, typename Cache, typename Func>
+struct ThreadSafeMemoizer {
+ using RetType = Value;
+
+ template <typename F>
+ ThreadSafeMemoizer(F&& func, int32_t cache_capacity)
+ : func_(std::forward<F>(func)), cache_(cache_capacity) {}
+
+ // The memoizer can't return a pointer to the cached value, because
+ // the cache entry may be evicted by another thread.
+
+ Value operator()(const Key& key) {
+ std::unique_lock<std::mutex> lock(mutex_);
+ const Value* value_ptr;
+ value_ptr = cache_.Find(key);
+ if (ARROW_PREDICT_TRUE(value_ptr != nullptr)) {
+ return *value_ptr;
+ }
+ lock.unlock();
+ Value v = func_(key);
+ lock.lock();
+ return *cache_.Replace(key, std::move(v)).second;
+ }
+
+ private:
+ std::mutex mutex_;
+ Func func_;
+ Cache cache_;
+};
+
+template <typename Key, typename Value, typename Cache, typename Func>
+struct ThreadUnsafeMemoizer {
+ using RetType = const Value&;
+
+ template <typename F>
+ ThreadUnsafeMemoizer(F&& func, int32_t cache_capacity)
+ : func_(std::forward<F>(func)), cache_(cache_capacity) {}
+
+ const Value& operator()(const Key& key) {
+ const Value* value_ptr;
+ value_ptr = cache_.Find(key);
+ if (ARROW_PREDICT_TRUE(value_ptr != nullptr)) {
+ return *value_ptr;
+ }
+ return *cache_.Replace(key, func_(key)).second;
+ }
+
+ private:
+ Func func_;
+ Cache cache_;
+};
+
+template <template <typename...> class Cache, template <typename...> class MemoizerType,
+ typename Func,
+ typename Key = typename std::decay<call_traits::argument_type<0, Func>>::type,
+ typename Value = typename std::decay<call_traits::return_type<Func>>::type,
+ typename Memoizer = MemoizerType<Key, Value, Cache<Key, Value>, Func>,
+ typename RetType = typename Memoizer::RetType>
+static std::function<RetType(const Key&)> Memoize(Func&& func, int32_t cache_capacity) {
+ // std::function<> requires copy constructibility
+ struct {
+ RetType operator()(const Key& key) const { return (*memoized_)(key); }
+ std::shared_ptr<Memoizer> memoized_;
+ } shared_memoized = {
+ std::make_shared<Memoizer>(std::forward<Func>(func), cache_capacity)};
+
+ return shared_memoized;
+}
+
+} // namespace detail
+
+// Apply a LRU memoization cache to a callable.
+template <typename Func>
+static auto MemoizeLru(Func&& func, int32_t cache_capacity)
+ -> decltype(detail::Memoize<LruCache, detail::ThreadSafeMemoizer>(
+ std::forward<Func>(func), cache_capacity)) {
+ return detail::Memoize<LruCache, detail::ThreadSafeMemoizer>(std::forward<Func>(func),
+ cache_capacity);
+}
+
+// Like MemoizeLru, but not thread-safe. This version allows for much faster
+// lookups (more than 2x faster), but you'll have to manage thread safety yourself.
+// A recommended usage is to declare per-thread caches using `thread_local`
+// (see cache_benchmark.cc).
+template <typename Func>
+static auto MemoizeLruThreadUnsafe(Func&& func, int32_t cache_capacity)
+ -> decltype(detail::Memoize<LruCache, detail::ThreadUnsafeMemoizer>(
+ std::forward<Func>(func), cache_capacity)) {
+ return detail::Memoize<LruCache, detail::ThreadUnsafeMemoizer>(std::forward<Func>(func),
+ cache_capacity);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/cache_test.cc b/src/arrow/cpp/src/arrow/util/cache_test.cc
new file mode 100644
index 000000000..6b71baa36
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/cache_test.cc
@@ -0,0 +1,290 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <atomic>
+#include <cstdint>
+#include <functional>
+#include <ostream>
+#include <string>
+#include <thread>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/cache_internal.h"
+
+namespace arrow {
+namespace internal {
+
+template <typename K1, typename V1, typename K2, typename V2>
+void AssertPairsEqual(std::pair<K1, V1> left, std::pair<K2, V2> right) {
+ ASSERT_EQ(left.first, right.first);
+ ASSERT_EQ(left.second, right.second);
+}
+
+class IntValue {
+ public:
+ explicit IntValue(int value = 0) : value_(std::make_shared<int>(value)) {}
+
+ IntValue(const IntValue&) = default;
+ IntValue(IntValue&&) = default;
+ IntValue& operator=(const IntValue&) = default;
+ IntValue& operator=(IntValue&&) = default;
+
+ int value() const { return *value_; }
+
+ bool operator==(const IntValue& other) const { return *value_ == *other.value_; }
+ bool operator!=(const IntValue& other) const { return *value_ != *other.value_; }
+
+ friend std::ostream& operator<<(std::ostream& os, IntValue v) {
+ os << "IntValue{" << *v.value_ << "}";
+ return os;
+ }
+
+ private:
+ // The shared_ptr makes it easier to detect lifetime bugs
+ std::shared_ptr<int> value_;
+};
+
+template <typename Value>
+Value Identity(Value&& v) {
+ return std::forward<Value>(v);
+}
+
+class TestLruCache : public ::testing::Test {
+ public:
+ using K = std::string;
+ using V = IntValue;
+ using Cache = LruCache<K, V>;
+
+ K MakeKey(int num) { return std::to_string(num); }
+
+ const V* Find(Cache* cache, int num) { return cache->Find(MakeKey(num)); }
+
+ bool Replace(Cache* cache, int num, int value_num) {
+ auto pair = cache->Replace(MakeKey(num), V{value_num});
+ EXPECT_NE(pair.second, nullptr);
+ EXPECT_EQ(*pair.second, V{value_num});
+ return pair.first;
+ }
+};
+
+TEST_F(TestLruCache, Basics) {
+ Cache cache(10);
+
+ using namespace std::placeholders; // NOLINT [build/namespaces]
+ auto Replace = std::bind(&TestLruCache::Replace, this, &cache, _1, _2);
+ auto Find = std::bind(&TestLruCache::Find, this, &cache, _1);
+
+ ASSERT_EQ(cache.size(), 0);
+ ASSERT_EQ(Find(100), nullptr);
+
+ // Insertions
+ ASSERT_TRUE(Replace(100, 100));
+ ASSERT_TRUE(Replace(101, 101));
+ ASSERT_TRUE(Replace(102, 102));
+ ASSERT_EQ(cache.size(), 3);
+ ASSERT_EQ(*Find(100), V{100});
+ ASSERT_EQ(*Find(101), V{101});
+ ASSERT_EQ(*Find(102), V{102});
+
+ // Replacements
+ ASSERT_FALSE(Replace(100, -100));
+ ASSERT_FALSE(Replace(101, -101));
+ ASSERT_FALSE(Replace(102, -102));
+ ASSERT_EQ(cache.size(), 3);
+ ASSERT_EQ(*Find(100), V{-100});
+ ASSERT_EQ(*Find(101), V{-101});
+ ASSERT_EQ(*Find(102), V{-102});
+
+ ASSERT_EQ(cache.size(), 3);
+ cache.Clear();
+ ASSERT_EQ(cache.size(), 0);
+}
+
+TEST_F(TestLruCache, Eviction) {
+ Cache cache(5);
+
+ using namespace std::placeholders; // NOLINT [build/namespaces]
+ auto Replace = std::bind(&TestLruCache::Replace, this, &cache, _1, _2);
+ auto Find = std::bind(&TestLruCache::Find, this, &cache, _1);
+
+ for (int i = 100; i < 105; ++i) {
+ ASSERT_TRUE(Replace(i, i));
+ }
+ ASSERT_EQ(cache.size(), 5);
+
+ // Access keys in a specific order
+ for (int i : {102, 103, 101, 104, 100}) {
+ ASSERT_EQ(*Find(i), V{i});
+ }
+ // Insert more entries
+ ASSERT_TRUE(Replace(105, 105));
+ ASSERT_TRUE(Replace(106, 106));
+ // The least recently used keys were evicted
+ ASSERT_EQ(Find(102), nullptr);
+ ASSERT_EQ(Find(103), nullptr);
+ for (int i : {100, 101, 104, 105, 106}) {
+ ASSERT_EQ(*Find(i), V{i});
+ }
+
+ // Alternate insertions and replacements
+ // MRU = [106, 105, 104, 101, 100]
+ ASSERT_FALSE(Replace(106, -106));
+ // MRU = [106, 105, 104, 101, 100]
+ ASSERT_FALSE(Replace(100, -100));
+ // MRU = [100, 106, 105, 104, 101]
+ ASSERT_FALSE(Replace(104, -104));
+ // MRU = [104, 100, 106, 105, 101]
+ ASSERT_TRUE(Replace(102, -102));
+ // MRU = [102, 104, 100, 106, 105]
+ ASSERT_TRUE(Replace(101, -101));
+ // MRU = [101, 102, 104, 100, 106]
+ for (int i : {101, 102, 104, 100, 106}) {
+ ASSERT_EQ(*Find(i), V{-i});
+ }
+ ASSERT_EQ(Find(103), nullptr);
+ ASSERT_EQ(Find(105), nullptr);
+
+ // MRU = [106, 100, 104, 102, 101]
+ ASSERT_TRUE(Replace(103, -103));
+ // MRU = [103, 106, 100, 104, 102]
+ ASSERT_TRUE(Replace(105, -105));
+ // MRU = [105, 103, 106, 100, 104]
+ for (int i : {105, 103, 106, 100, 104}) {
+ ASSERT_EQ(*Find(i), V{-i});
+ }
+ ASSERT_EQ(Find(101), nullptr);
+ ASSERT_EQ(Find(102), nullptr);
+}
+
+struct Callable {
+ std::atomic<int> num_calls{0};
+
+ IntValue operator()(const std::string& s) {
+ ++num_calls;
+ return IntValue{std::stoi(s)};
+ }
+};
+
+struct MemoizeLruFactory {
+ template <typename Func,
+ typename RetType = decltype(MemoizeLru(std::declval<Func>(), 0))>
+ RetType operator()(Func&& func, int32_t capacity) {
+ return MemoizeLru(std::forward<Func>(func), capacity);
+ }
+};
+
+struct MemoizeLruThreadUnsafeFactory {
+ template <typename Func,
+ typename RetType = decltype(MemoizeLruThreadUnsafe(std::declval<Func>(), 0))>
+ RetType operator()(Func&& func, int32_t capacity) {
+ return MemoizeLruThreadUnsafe(std::forward<Func>(func), capacity);
+ }
+};
+
+template <typename T>
+class TestMemoizeLru : public ::testing::Test {
+ public:
+ using K = std::string;
+ using V = IntValue;
+ using MemoizerFactory = T;
+
+ K MakeKey(int num) { return std::to_string(num); }
+
+ void TestBasics() {
+ using V = IntValue;
+ Callable c;
+
+ auto mem = factory_(c, 5);
+
+ // Cache fills
+ for (int i = 0; i < 5; ++i) {
+ ASSERT_EQ(mem(MakeKey(i)), V{i});
+ }
+ ASSERT_EQ(c.num_calls, 5);
+
+ // Cache hits
+ for (int i : {1, 3, 4, 0, 2}) {
+ ASSERT_EQ(mem(MakeKey(i)), V{i});
+ }
+ ASSERT_EQ(c.num_calls, 5);
+
+ // Calling with other inputs will cause evictions
+ for (int i = 5; i < 8; ++i) {
+ ASSERT_EQ(mem(MakeKey(i)), V{i});
+ }
+ ASSERT_EQ(c.num_calls, 8);
+ // Hits
+ for (int i : {0, 2, 5, 6, 7}) {
+ ASSERT_EQ(mem(MakeKey(i)), V{i});
+ }
+ ASSERT_EQ(c.num_calls, 8);
+ // Misses
+ for (int i : {1, 3, 4}) {
+ ASSERT_EQ(mem(MakeKey(i)), V{i});
+ }
+ ASSERT_EQ(c.num_calls, 11);
+ }
+
+ protected:
+ MemoizerFactory factory_;
+};
+
+using MemoizeLruTestTypes =
+ ::testing::Types<MemoizeLruFactory, MemoizeLruThreadUnsafeFactory>;
+
+TYPED_TEST_SUITE(TestMemoizeLru, MemoizeLruTestTypes);
+
+TYPED_TEST(TestMemoizeLru, Basics) { this->TestBasics(); }
+
+class TestMemoizeLruThreadSafe : public TestMemoizeLru<MemoizeLruFactory> {};
+
+TEST_F(TestMemoizeLruThreadSafe, Threads) {
+ using V = IntValue;
+ Callable c;
+
+ auto mem = this->factory_(c, 15);
+ const int n_threads = 4;
+#ifdef ARROW_VALGRIND
+ const int n_iters = 10;
+#else
+ const int n_iters = 100;
+#endif
+
+ auto thread_func = [&]() {
+ for (int i = 0; i < n_iters; ++i) {
+ const V& orig_value = mem("1");
+ // Ensure that some replacements are going on
+ // (# distinct keys > cache size)
+ for (int j = 0; j < 30; ++j) {
+ ASSERT_EQ(mem(std::to_string(j)), V{j});
+ }
+ ASSERT_EQ(orig_value, V{1});
+ }
+ };
+ std::vector<std::thread> threads;
+ for (int i = 0; i < n_threads; ++i) {
+ threads.emplace_back(thread_func);
+ }
+ for (auto& thread : threads) {
+ thread.join();
+ }
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/cancel.cc b/src/arrow/cpp/src/arrow/util/cancel.cc
new file mode 100644
index 000000000..874b2c2c8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/cancel.cc
@@ -0,0 +1,226 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/cancel.h"
+
+#include <atomic>
+#include <mutex>
+#include <sstream>
+#include <utility>
+
+#include "arrow/result.h"
+#include "arrow/util/atomic_shared_ptr.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+#if ATOMIC_INT_LOCK_FREE != 2
+#error Lock-free atomic int required for signal safety
+#endif
+
+using internal::ReinstateSignalHandler;
+using internal::SetSignalHandler;
+using internal::SignalHandler;
+
+// NOTE: We care mainly about the making the common case (not cancelled) fast.
+
+struct StopSourceImpl {
+ std::atomic<int> requested_{0}; // will be -1 or signal number if requested
+ std::mutex mutex_;
+ Status cancel_error_;
+};
+
+StopSource::StopSource() : impl_(new StopSourceImpl) {}
+
+StopSource::~StopSource() = default;
+
+void StopSource::RequestStop() { RequestStop(Status::Cancelled("Operation cancelled")); }
+
+void StopSource::RequestStop(Status st) {
+ std::lock_guard<std::mutex> lock(impl_->mutex_);
+ DCHECK(!st.ok());
+ if (!impl_->requested_) {
+ impl_->requested_ = -1;
+ impl_->cancel_error_ = std::move(st);
+ }
+}
+
+void StopSource::RequestStopFromSignal(int signum) {
+ // Only async-signal-safe code allowed here
+ impl_->requested_.store(signum);
+}
+
+void StopSource::Reset() {
+ std::lock_guard<std::mutex> lock(impl_->mutex_);
+ impl_->cancel_error_ = Status::OK();
+ impl_->requested_.store(0);
+}
+
+StopToken StopSource::token() { return StopToken(impl_); }
+
+bool StopToken::IsStopRequested() const {
+ if (!impl_) {
+ return false;
+ }
+ return impl_->requested_.load() != 0;
+}
+
+Status StopToken::Poll() const {
+ if (!impl_) {
+ return Status::OK();
+ }
+ if (!impl_->requested_.load()) {
+ return Status::OK();
+ }
+
+ std::lock_guard<std::mutex> lock(impl_->mutex_);
+ if (impl_->cancel_error_.ok()) {
+ auto signum = impl_->requested_.load();
+ DCHECK_GT(signum, 0);
+ impl_->cancel_error_ = internal::CancelledFromSignal(signum, "Operation cancelled");
+ }
+ return impl_->cancel_error_;
+}
+
+namespace {
+
+struct SignalStopState {
+ struct SavedSignalHandler {
+ int signum;
+ SignalHandler handler;
+ };
+
+ Status RegisterHandlers(const std::vector<int>& signals) {
+ if (!saved_handlers_.empty()) {
+ return Status::Invalid("Signal handlers already registered");
+ }
+ for (int signum : signals) {
+ ARROW_ASSIGN_OR_RAISE(auto handler,
+ SetSignalHandler(signum, SignalHandler{&HandleSignal}));
+ saved_handlers_.push_back({signum, handler});
+ }
+ return Status::OK();
+ }
+
+ void UnregisterHandlers() {
+ auto handlers = std::move(saved_handlers_);
+ for (const auto& h : handlers) {
+ ARROW_CHECK_OK(SetSignalHandler(h.signum, h.handler).status());
+ }
+ }
+
+ ~SignalStopState() {
+ UnregisterHandlers();
+ Disable();
+ }
+
+ StopSource* stop_source() { return stop_source_.get(); }
+
+ bool enabled() { return stop_source_ != nullptr; }
+
+ void Enable() {
+ // Before creating a new StopSource, delete any lingering reference to
+ // the previous one in the trash can. See DoHandleSignal() for details.
+ EmptyTrashCan();
+ internal::atomic_store(&stop_source_, std::make_shared<StopSource>());
+ }
+
+ void Disable() { internal::atomic_store(&stop_source_, NullSource()); }
+
+ static SignalStopState* instance() { return &instance_; }
+
+ private:
+ // For readability
+ std::shared_ptr<StopSource> NullSource() { return nullptr; }
+
+ void EmptyTrashCan() { internal::atomic_store(&trash_can_, NullSource()); }
+
+ static void HandleSignal(int signum) { instance_.DoHandleSignal(signum); }
+
+ void DoHandleSignal(int signum) {
+ // async-signal-safe code only
+ auto source = internal::atomic_load(&stop_source_);
+ if (source) {
+ source->RequestStopFromSignal(signum);
+ // Disable() may have been called in the meantime, but we can't
+ // deallocate a shared_ptr here, so instead move it to a "trash can".
+ // This minimizes the possibility of running a deallocator here,
+ // however it doesn't entirely preclude it.
+ //
+ // Possible case:
+ // - a signal handler (A) starts running, fetches the current source
+ // - Disable() then Enable() are called, emptying the trash can and
+ // replacing the current source
+ // - a signal handler (B) starts running, fetches the current source
+ // - signal handler A resumes, moves its source (the old source) into
+ // the trash can (the only remaining reference)
+ // - signal handler B resumes, moves its source (the current source)
+ // into the trash can. This triggers deallocation of the old source,
+ // since the trash can had the only remaining reference to it.
+ //
+ // This case should be sufficiently unlikely, but we cannot entirely
+ // rule it out. The problem might be solved properly with a lock-free
+ // linked list of StopSources.
+ internal::atomic_store(&trash_can_, std::move(source));
+ }
+ ReinstateSignalHandler(signum, &HandleSignal);
+ }
+
+ std::shared_ptr<StopSource> stop_source_;
+ std::shared_ptr<StopSource> trash_can_;
+
+ std::vector<SavedSignalHandler> saved_handlers_;
+
+ static SignalStopState instance_;
+};
+
+SignalStopState SignalStopState::instance_{};
+
+} // namespace
+
+Result<StopSource*> SetSignalStopSource() {
+ auto stop_state = SignalStopState::instance();
+ if (stop_state->enabled()) {
+ return Status::Invalid("Signal stop source already set up");
+ }
+ stop_state->Enable();
+ return stop_state->stop_source();
+}
+
+void ResetSignalStopSource() {
+ auto stop_state = SignalStopState::instance();
+ DCHECK(stop_state->enabled());
+ stop_state->Disable();
+}
+
+Status RegisterCancellingSignalHandler(const std::vector<int>& signals) {
+ auto stop_state = SignalStopState::instance();
+ if (!stop_state->enabled()) {
+ return Status::Invalid("Signal stop source was not set up");
+ }
+ return stop_state->RegisterHandlers(signals);
+}
+
+void UnregisterCancellingSignalHandler() {
+ auto stop_state = SignalStopState::instance();
+ DCHECK(stop_state->enabled());
+ stop_state->UnregisterHandlers();
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/cancel.h b/src/arrow/cpp/src/arrow/util/cancel.h
new file mode 100644
index 000000000..9e00f673a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/cancel.h
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class StopToken;
+
+struct StopSourceImpl;
+
+/// EXPERIMENTAL
+class ARROW_EXPORT StopSource {
+ public:
+ StopSource();
+ ~StopSource();
+
+ // Consumer API (the side that stops)
+ void RequestStop();
+ void RequestStop(Status error);
+ void RequestStopFromSignal(int signum);
+
+ StopToken token();
+
+ // For internal use only
+ void Reset();
+
+ protected:
+ std::shared_ptr<StopSourceImpl> impl_;
+};
+
+/// EXPERIMENTAL
+class ARROW_EXPORT StopToken {
+ public:
+ // Public for Cython
+ StopToken() {}
+
+ explicit StopToken(std::shared_ptr<StopSourceImpl> impl) : impl_(std::move(impl)) {}
+
+ // A trivial token that never propagates any stop request
+ static StopToken Unstoppable() { return StopToken(); }
+
+ // Producer API (the side that gets asked to stopped)
+ Status Poll() const;
+ bool IsStopRequested() const;
+
+ protected:
+ std::shared_ptr<StopSourceImpl> impl_;
+};
+
+/// EXPERIMENTAL: Set a global StopSource that can receive signals
+///
+/// The only allowed order of calls is the following:
+/// - SetSignalStopSource()
+/// - any number of pairs of (RegisterCancellingSignalHandler,
+/// UnregisterCancellingSignalHandler) calls
+/// - ResetSignalStopSource()
+///
+/// Beware that these settings are process-wide. Typically, only one
+/// thread should call these APIs, even in a multithreaded setting.
+ARROW_EXPORT
+Result<StopSource*> SetSignalStopSource();
+
+/// EXPERIMENTAL: Reset the global signal-receiving StopSource
+///
+/// This will invalidate the pointer returned by SetSignalStopSource.
+ARROW_EXPORT
+void ResetSignalStopSource();
+
+/// EXPERIMENTAL: Register signal handler triggering the signal-receiving StopSource
+ARROW_EXPORT
+Status RegisterCancellingSignalHandler(const std::vector<int>& signals);
+
+/// EXPERIMENTAL: Unregister signal handler set up by RegisterCancellingSignalHandler
+ARROW_EXPORT
+void UnregisterCancellingSignalHandler();
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/cancel_test.cc b/src/arrow/cpp/src/arrow/util/cancel_test.cc
new file mode 100644
index 000000000..b9bf94ba4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/cancel_test.cc
@@ -0,0 +1,308 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <atomic>
+#include <cmath>
+#include <sstream>
+#include <string>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include <signal.h>
+#ifndef _WIN32
+#include <sys/time.h> // for setitimer()
+#endif
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/cancel.h"
+#include "arrow/util/future.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/optional.h"
+
+namespace arrow {
+
+class CancelTest : public ::testing::Test {};
+
+TEST_F(CancelTest, StopBasics) {
+ {
+ StopSource source;
+ StopToken token = source.token();
+ ASSERT_FALSE(token.IsStopRequested());
+ ASSERT_OK(token.Poll());
+
+ source.RequestStop();
+ ASSERT_TRUE(token.IsStopRequested());
+ ASSERT_RAISES(Cancelled, token.Poll());
+ }
+ {
+ StopSource source;
+ StopToken token = source.token();
+ source.RequestStop(Status::IOError("Operation cancelled"));
+ ASSERT_TRUE(token.IsStopRequested());
+ ASSERT_RAISES(IOError, token.Poll());
+ }
+}
+
+TEST_F(CancelTest, StopTokenCopy) {
+ StopSource source;
+ StopToken token = source.token();
+ ASSERT_FALSE(token.IsStopRequested());
+ ASSERT_OK(token.Poll());
+
+ StopToken token2 = token;
+ ASSERT_FALSE(token2.IsStopRequested());
+ ASSERT_OK(token2.Poll());
+
+ source.RequestStop();
+ StopToken token3 = token;
+
+ ASSERT_TRUE(token.IsStopRequested());
+ ASSERT_TRUE(token2.IsStopRequested());
+ ASSERT_TRUE(token3.IsStopRequested());
+ ASSERT_RAISES(Cancelled, token.Poll());
+ ASSERT_EQ(token2.Poll(), token.Poll());
+ ASSERT_EQ(token3.Poll(), token.Poll());
+}
+
+TEST_F(CancelTest, RequestStopTwice) {
+ StopSource source;
+ StopToken token = source.token();
+ source.RequestStop();
+ // Second RequestStop() call is ignored
+ source.RequestStop(Status::IOError("Operation cancelled"));
+ ASSERT_TRUE(token.IsStopRequested());
+ ASSERT_RAISES(Cancelled, token.Poll());
+}
+
+TEST_F(CancelTest, Unstoppable) {
+ StopToken token = StopToken::Unstoppable();
+ ASSERT_FALSE(token.IsStopRequested());
+ ASSERT_OK(token.Poll());
+}
+
+TEST_F(CancelTest, SourceVanishes) {
+ {
+ util::optional<StopSource> source{StopSource()};
+ StopToken token = source->token();
+ ASSERT_FALSE(token.IsStopRequested());
+ ASSERT_OK(token.Poll());
+
+ source.reset();
+ ASSERT_FALSE(token.IsStopRequested());
+ ASSERT_OK(token.Poll());
+ }
+ {
+ util::optional<StopSource> source{StopSource()};
+ StopToken token = source->token();
+ source->RequestStop();
+
+ source.reset();
+ ASSERT_TRUE(token.IsStopRequested());
+ ASSERT_RAISES(Cancelled, token.Poll());
+ }
+}
+
+static void noop_signal_handler(int signum) {
+ internal::ReinstateSignalHandler(signum, &noop_signal_handler);
+}
+
+#ifndef _WIN32
+static util::optional<StopSource> signal_stop_source;
+
+static void signal_handler(int signum) {
+ signal_stop_source->RequestStopFromSignal(signum);
+}
+
+// SIGALRM will be received once after the specified wait
+static void SetITimer(double seconds) {
+ const double fractional = std::modf(seconds, &seconds);
+ struct itimerval it;
+ it.it_value.tv_sec = seconds;
+ it.it_value.tv_usec = 1e6 * fractional;
+ it.it_interval.tv_sec = 0;
+ it.it_interval.tv_usec = 0;
+ ASSERT_EQ(0, setitimer(ITIMER_REAL, &it, nullptr)) << "setitimer failed";
+}
+
+TEST_F(CancelTest, RequestStopFromSignal) {
+ signal_stop_source = StopSource(); // Start with a fresh StopSource
+ StopToken signal_token = signal_stop_source->token();
+ SignalHandlerGuard guard(SIGALRM, &signal_handler);
+
+ // Timer will be triggered once in 100 usecs
+ SetITimer(0.0001);
+
+ BusyWait(1.0, [&]() { return signal_token.IsStopRequested(); });
+ ASSERT_TRUE(signal_token.IsStopRequested());
+ auto st = signal_token.Poll();
+ ASSERT_RAISES(Cancelled, st);
+ ASSERT_EQ(st.message(), "Operation cancelled");
+ ASSERT_EQ(internal::SignalFromStatus(st), SIGALRM);
+}
+#endif
+
+class SignalCancelTest : public CancelTest {
+ public:
+ void SetUp() override {
+ // Setup a dummy signal handler to avoid crashing when receiving signal
+ guard_.emplace(expected_signal_, &noop_signal_handler);
+ ASSERT_OK_AND_ASSIGN(auto stop_source, SetSignalStopSource());
+ stop_token_ = stop_source->token();
+ }
+
+ void TearDown() override {
+ UnregisterCancellingSignalHandler();
+ ResetSignalStopSource();
+ }
+
+ void RegisterHandler() {
+ ASSERT_OK(RegisterCancellingSignalHandler({expected_signal_}));
+ }
+
+#ifdef _WIN32
+ void TriggerSignal() {
+ std::thread([]() { ASSERT_OK(internal::SendSignal(SIGINT)); }).detach();
+ }
+#else
+ // On Unix, use setitimer() to exercise signal-async-safety
+ void TriggerSignal() { SetITimer(0.0001); }
+#endif
+
+ void AssertStopNotRequested() {
+ SleepFor(0.01);
+ ASSERT_FALSE(stop_token_->IsStopRequested());
+ ASSERT_OK(stop_token_->Poll());
+ }
+
+ void AssertStopRequested() {
+ BusyWait(1.0, [&]() { return stop_token_->IsStopRequested(); });
+ ASSERT_TRUE(stop_token_->IsStopRequested());
+ auto st = stop_token_->Poll();
+ ASSERT_RAISES(Cancelled, st);
+ ASSERT_EQ(st.message(), "Operation cancelled");
+ ASSERT_EQ(internal::SignalFromStatus(st), expected_signal_);
+ }
+
+ protected:
+#ifdef _WIN32
+ const int expected_signal_ = SIGINT;
+#else
+ const int expected_signal_ = SIGALRM;
+#endif
+ util::optional<SignalHandlerGuard> guard_;
+ util::optional<StopToken> stop_token_;
+};
+
+TEST_F(SignalCancelTest, Register) {
+ RegisterHandler();
+
+ TriggerSignal();
+ AssertStopRequested();
+}
+
+TEST_F(SignalCancelTest, RegisterUnregister) {
+ // The signal stop source was set up but no handler was registered,
+ // so the token shouldn't be signalled.
+ TriggerSignal();
+ AssertStopNotRequested();
+
+ // Register and then unregister: same
+ RegisterHandler();
+ UnregisterCancellingSignalHandler();
+
+ TriggerSignal();
+ AssertStopNotRequested();
+
+ // Register again and raise the signal: token will be signalled.
+ RegisterHandler();
+
+ TriggerSignal();
+ AssertStopRequested();
+}
+
+TEST_F(CancelTest, ThreadedPollSuccess) {
+ constexpr int kNumThreads = 10;
+
+ std::vector<Status> results(kNumThreads);
+ std::vector<std::thread> threads;
+
+ StopSource source;
+ StopToken token = source.token();
+ std::atomic<bool> terminate_flag{false};
+
+ const auto worker_func = [&](int thread_num) {
+ while (token.Poll().ok() && !terminate_flag) {
+ }
+ results[thread_num] = token.Poll();
+ };
+ for (int i = 0; i < kNumThreads; ++i) {
+ threads.emplace_back(std::bind(worker_func, i));
+ }
+
+ // Let the threads start and hammer on Poll() for a while
+ SleepFor(1e-2);
+ // Tell threads to stop
+ terminate_flag = true;
+ for (auto& thread : threads) {
+ thread.join();
+ }
+
+ for (const auto& st : results) {
+ ASSERT_OK(st);
+ }
+}
+
+TEST_F(CancelTest, ThreadedPollCancel) {
+ constexpr int kNumThreads = 10;
+
+ std::vector<Status> results(kNumThreads);
+ std::vector<std::thread> threads;
+
+ StopSource source;
+ StopToken token = source.token();
+ std::atomic<bool> terminate_flag{false};
+ const auto stop_error = Status::IOError("Operation cancelled");
+
+ const auto worker_func = [&](int thread_num) {
+ while (token.Poll().ok() && !terminate_flag) {
+ }
+ results[thread_num] = token.Poll();
+ };
+
+ for (int i = 0; i < kNumThreads; ++i) {
+ threads.emplace_back(std::bind(worker_func, i));
+ }
+ // Let the threads start
+ SleepFor(1e-2);
+ // Cancel token while threads are hammering on Poll()
+ source.RequestStop(stop_error);
+ // Tell threads to stop
+ terminate_flag = true;
+ for (auto& thread : threads) {
+ thread.join();
+ }
+
+ for (const auto& st : results) {
+ ASSERT_EQ(st, stop_error);
+ }
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/checked_cast.h b/src/arrow/cpp/src/arrow/util/checked_cast.h
new file mode 100644
index 000000000..97f6b61a1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/checked_cast.h
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <type_traits>
+#include <utility>
+
+namespace arrow {
+namespace internal {
+
+template <typename OutputType, typename InputType>
+inline OutputType checked_cast(InputType&& value) {
+ static_assert(std::is_class<typename std::remove_pointer<
+ typename std::remove_reference<InputType>::type>::type>::value,
+ "checked_cast input type must be a class");
+ static_assert(std::is_class<typename std::remove_pointer<
+ typename std::remove_reference<OutputType>::type>::type>::value,
+ "checked_cast output type must be a class");
+#ifdef NDEBUG
+ return static_cast<OutputType>(value);
+#else
+ return dynamic_cast<OutputType>(value);
+#endif
+}
+
+template <class T, class U>
+std::shared_ptr<T> checked_pointer_cast(std::shared_ptr<U> r) noexcept {
+#ifdef NDEBUG
+ return std::static_pointer_cast<T>(std::move(r));
+#else
+ return std::dynamic_pointer_cast<T>(std::move(r));
+#endif
+}
+
+template <class T, class U>
+std::unique_ptr<T> checked_pointer_cast(std::unique_ptr<U> r) noexcept {
+#ifdef NDEBUG
+ return std::unique_ptr<T>(static_cast<T*>(r.release()));
+#else
+ return std::unique_ptr<T>(dynamic_cast<T*>(r.release()));
+#endif
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/checked_cast_test.cc b/src/arrow/cpp/src/arrow/util/checked_cast_test.cc
new file mode 100644
index 000000000..b50a859cb
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/checked_cast_test.cc
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <type_traits>
+#include <typeinfo>
+
+#include <gtest/gtest.h>
+
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+namespace internal {
+
+class Foo {
+ public:
+ virtual ~Foo() = default;
+};
+
+class Bar {};
+class FooSub : public Foo {};
+template <typename T>
+class Baz : public Foo {};
+
+TEST(CheckedCast, TestInvalidSubclassCast) {
+ static_assert(std::is_polymorphic<Foo>::value, "Foo is not polymorphic");
+
+ Foo foo;
+ FooSub foosub;
+ const Foo& foosubref = foosub;
+ Baz<double> baz;
+ const Foo& bazref = baz;
+
+#ifndef NDEBUG // debug mode
+ // illegal pointer cast
+ ASSERT_EQ(nullptr, checked_cast<Bar*>(&foo));
+
+ // illegal reference cast
+ ASSERT_THROW(checked_cast<const Bar&>(foosubref), std::bad_cast);
+
+ // legal reference casts
+ ASSERT_NO_THROW(checked_cast<const FooSub&>(foosubref));
+ ASSERT_NO_THROW(checked_cast<const Baz<double>&>(bazref));
+#else // release mode
+ // failure modes for the invalid casts occur at compile time
+
+ // legal pointer cast
+ ASSERT_NE(nullptr, checked_cast<const FooSub*>(&foosubref));
+
+ // legal reference casts: this is static_cast in a release build, so ASSERT_NO_THROW
+ // doesn't make a whole lot of sense here.
+ auto& x = checked_cast<const FooSub&>(foosubref);
+ ASSERT_EQ(&foosubref, &x);
+
+ auto& y = checked_cast<const Baz<double>&>(bazref);
+ ASSERT_EQ(&bazref, &y);
+#endif
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/compare.h b/src/arrow/cpp/src/arrow/util/compare.h
new file mode 100644
index 000000000..6477bf139
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/compare.h
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <type_traits>
+#include <utility>
+
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace util {
+
+/// CRTP helper for declaring equality comparison. Defines operator== and operator!=
+template <typename T>
+class EqualityComparable {
+ public:
+ ~EqualityComparable() {
+ static_assert(
+ std::is_same<decltype(std::declval<const T>().Equals(std::declval<const T>())),
+ bool>::value,
+ "EqualityComparable depends on the method T::Equals(const T&) const");
+ }
+
+ template <typename... Extra>
+ bool Equals(const std::shared_ptr<T>& other, Extra&&... extra) const {
+ if (other == NULLPTR) {
+ return false;
+ }
+ return cast().Equals(*other, std::forward<Extra>(extra)...);
+ }
+
+ struct PtrsEqual {
+ bool operator()(const std::shared_ptr<T>& l, const std::shared_ptr<T>& r) const {
+ return l->Equals(r);
+ }
+ };
+
+ bool operator==(const T& other) const { return cast().Equals(other); }
+ bool operator!=(const T& other) const { return !(cast() == other); }
+
+ private:
+ const T& cast() const { return static_cast<const T&>(*this); }
+};
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/compression.cc b/src/arrow/cpp/src/arrow/util/compression.cc
new file mode 100644
index 000000000..8db199b4e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/compression.cc
@@ -0,0 +1,261 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/compression.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/compression_internal.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace util {
+
+namespace {
+
+Status CheckSupportsCompressionLevel(Compression::type type) {
+ if (!Codec::SupportsCompressionLevel(type)) {
+ return Status::Invalid(
+ "The specified codec does not support the compression level parameter");
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+int Codec::UseDefaultCompressionLevel() { return kUseDefaultCompressionLevel; }
+
+Status Codec::Init() { return Status::OK(); }
+
+const std::string& Codec::GetCodecAsString(Compression::type t) {
+ static const std::string uncompressed = "uncompressed", snappy = "snappy",
+ gzip = "gzip", lzo = "lzo", brotli = "brotli",
+ lz4_raw = "lz4_raw", lz4 = "lz4", lz4_hadoop = "lz4_hadoop",
+ zstd = "zstd", bz2 = "bz2", unknown = "unknown";
+
+ switch (t) {
+ case Compression::UNCOMPRESSED:
+ return uncompressed;
+ case Compression::SNAPPY:
+ return snappy;
+ case Compression::GZIP:
+ return gzip;
+ case Compression::LZO:
+ return lzo;
+ case Compression::BROTLI:
+ return brotli;
+ case Compression::LZ4:
+ return lz4_raw;
+ case Compression::LZ4_FRAME:
+ return lz4;
+ case Compression::LZ4_HADOOP:
+ return lz4_hadoop;
+ case Compression::ZSTD:
+ return zstd;
+ case Compression::BZ2:
+ return bz2;
+ default:
+ return unknown;
+ }
+}
+
+Result<Compression::type> Codec::GetCompressionType(const std::string& name) {
+ if (name == "uncompressed") {
+ return Compression::UNCOMPRESSED;
+ } else if (name == "gzip") {
+ return Compression::GZIP;
+ } else if (name == "snappy") {
+ return Compression::SNAPPY;
+ } else if (name == "lzo") {
+ return Compression::LZO;
+ } else if (name == "brotli") {
+ return Compression::BROTLI;
+ } else if (name == "lz4_raw") {
+ return Compression::LZ4;
+ } else if (name == "lz4") {
+ return Compression::LZ4_FRAME;
+ } else if (name == "lz4_hadoop") {
+ return Compression::LZ4_HADOOP;
+ } else if (name == "zstd") {
+ return Compression::ZSTD;
+ } else if (name == "bz2") {
+ return Compression::BZ2;
+ } else {
+ return Status::Invalid("Unrecognized compression type: ", name);
+ }
+}
+
+bool Codec::SupportsCompressionLevel(Compression::type codec) {
+ switch (codec) {
+ case Compression::GZIP:
+ case Compression::BROTLI:
+ case Compression::ZSTD:
+ case Compression::BZ2:
+ return true;
+ default:
+ return false;
+ }
+}
+
+Result<int> Codec::MaximumCompressionLevel(Compression::type codec_type) {
+ RETURN_NOT_OK(CheckSupportsCompressionLevel(codec_type));
+ ARROW_ASSIGN_OR_RAISE(auto codec, Codec::Create(codec_type));
+ return codec->maximum_compression_level();
+}
+
+Result<int> Codec::MinimumCompressionLevel(Compression::type codec_type) {
+ RETURN_NOT_OK(CheckSupportsCompressionLevel(codec_type));
+ ARROW_ASSIGN_OR_RAISE(auto codec, Codec::Create(codec_type));
+ return codec->minimum_compression_level();
+}
+
+Result<int> Codec::DefaultCompressionLevel(Compression::type codec_type) {
+ RETURN_NOT_OK(CheckSupportsCompressionLevel(codec_type));
+ ARROW_ASSIGN_OR_RAISE(auto codec, Codec::Create(codec_type));
+ return codec->default_compression_level();
+}
+
+Result<std::unique_ptr<Codec>> Codec::Create(Compression::type codec_type,
+ int compression_level) {
+ if (!IsAvailable(codec_type)) {
+ if (codec_type == Compression::LZO) {
+ return Status::NotImplemented("LZO codec not implemented");
+ }
+
+ auto name = GetCodecAsString(codec_type);
+ if (name == "unknown") {
+ return Status::Invalid("Unrecognized codec");
+ }
+
+ return Status::NotImplemented("Support for codec '", GetCodecAsString(codec_type),
+ "' not built");
+ }
+
+ if (compression_level != kUseDefaultCompressionLevel &&
+ !SupportsCompressionLevel(codec_type)) {
+ return Status::Invalid("Codec '", GetCodecAsString(codec_type),
+ "' doesn't support setting a compression level.");
+ }
+
+ std::unique_ptr<Codec> codec;
+ switch (codec_type) {
+ case Compression::UNCOMPRESSED:
+ return nullptr;
+ case Compression::SNAPPY:
+#ifdef ARROW_WITH_SNAPPY
+ codec = internal::MakeSnappyCodec();
+#endif
+ break;
+ case Compression::GZIP:
+#ifdef ARROW_WITH_ZLIB
+ codec = internal::MakeGZipCodec(compression_level);
+#endif
+ break;
+ case Compression::BROTLI:
+#ifdef ARROW_WITH_BROTLI
+ codec = internal::MakeBrotliCodec(compression_level);
+#endif
+ break;
+ case Compression::LZ4:
+#ifdef ARROW_WITH_LZ4
+ codec = internal::MakeLz4RawCodec();
+#endif
+ break;
+ case Compression::LZ4_FRAME:
+#ifdef ARROW_WITH_LZ4
+ codec = internal::MakeLz4FrameCodec();
+#endif
+ break;
+ case Compression::LZ4_HADOOP:
+#ifdef ARROW_WITH_LZ4
+ codec = internal::MakeLz4HadoopRawCodec();
+#endif
+ break;
+ case Compression::ZSTD:
+#ifdef ARROW_WITH_ZSTD
+ codec = internal::MakeZSTDCodec(compression_level);
+#endif
+ break;
+ case Compression::BZ2:
+#ifdef ARROW_WITH_BZ2
+ codec = internal::MakeBZ2Codec(compression_level);
+#endif
+ break;
+ default:
+ break;
+ }
+
+ DCHECK_NE(codec, nullptr);
+ RETURN_NOT_OK(codec->Init());
+ return std::move(codec);
+}
+
+bool Codec::IsAvailable(Compression::type codec_type) {
+ switch (codec_type) {
+ case Compression::UNCOMPRESSED:
+ return true;
+ case Compression::SNAPPY:
+#ifdef ARROW_WITH_SNAPPY
+ return true;
+#else
+ return false;
+#endif
+ case Compression::GZIP:
+#ifdef ARROW_WITH_ZLIB
+ return true;
+#else
+ return false;
+#endif
+ case Compression::LZO:
+ return false;
+ case Compression::BROTLI:
+#ifdef ARROW_WITH_BROTLI
+ return true;
+#else
+ return false;
+#endif
+ case Compression::LZ4:
+ case Compression::LZ4_FRAME:
+ case Compression::LZ4_HADOOP:
+#ifdef ARROW_WITH_LZ4
+ return true;
+#else
+ return false;
+#endif
+ case Compression::ZSTD:
+#ifdef ARROW_WITH_ZSTD
+ return true;
+#else
+ return false;
+#endif
+ case Compression::BZ2:
+#ifdef ARROW_WITH_BZ2
+ return true;
+#else
+ return false;
+#endif
+ default:
+ return false;
+ }
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/compression.h b/src/arrow/cpp/src/arrow/util/compression.h
new file mode 100644
index 000000000..0832e82a6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/compression.h
@@ -0,0 +1,202 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace util {
+
+constexpr int kUseDefaultCompressionLevel = std::numeric_limits<int>::min();
+
+/// \brief Streaming compressor interface
+///
+class ARROW_EXPORT Compressor {
+ public:
+ virtual ~Compressor() = default;
+
+ struct CompressResult {
+ int64_t bytes_read;
+ int64_t bytes_written;
+ };
+ struct FlushResult {
+ int64_t bytes_written;
+ bool should_retry;
+ };
+ struct EndResult {
+ int64_t bytes_written;
+ bool should_retry;
+ };
+
+ /// \brief Compress some input.
+ ///
+ /// If bytes_read is 0 on return, then a larger output buffer should be supplied.
+ virtual Result<CompressResult> Compress(int64_t input_len, const uint8_t* input,
+ int64_t output_len, uint8_t* output) = 0;
+
+ /// \brief Flush part of the compressed output.
+ ///
+ /// If should_retry is true on return, Flush() should be called again
+ /// with a larger buffer.
+ virtual Result<FlushResult> Flush(int64_t output_len, uint8_t* output) = 0;
+
+ /// \brief End compressing, doing whatever is necessary to end the stream.
+ ///
+ /// If should_retry is true on return, End() should be called again
+ /// with a larger buffer. Otherwise, the Compressor should not be used anymore.
+ ///
+ /// End() implies Flush().
+ virtual Result<EndResult> End(int64_t output_len, uint8_t* output) = 0;
+
+ // XXX add methods for buffer size heuristics?
+};
+
+/// \brief Streaming decompressor interface
+///
+class ARROW_EXPORT Decompressor {
+ public:
+ virtual ~Decompressor() = default;
+
+ struct DecompressResult {
+ // XXX is need_more_output necessary? (Brotli?)
+ int64_t bytes_read;
+ int64_t bytes_written;
+ bool need_more_output;
+ };
+
+ /// \brief Decompress some input.
+ ///
+ /// If need_more_output is true on return, a larger output buffer needs
+ /// to be supplied.
+ virtual Result<DecompressResult> Decompress(int64_t input_len, const uint8_t* input,
+ int64_t output_len, uint8_t* output) = 0;
+
+ /// \brief Return whether the compressed stream is finished.
+ ///
+ /// This is a heuristic. If true is returned, then it is guaranteed
+ /// that the stream is finished. If false is returned, however, it may
+ /// simply be that the underlying library isn't able to provide the information.
+ virtual bool IsFinished() = 0;
+
+ /// \brief Reinitialize decompressor, making it ready for a new compressed stream.
+ virtual Status Reset() = 0;
+
+ // XXX add methods for buffer size heuristics?
+};
+
+/// \brief Compression codec
+class ARROW_EXPORT Codec {
+ public:
+ virtual ~Codec() = default;
+
+ /// \brief Return special value to indicate that a codec implementation
+ /// should use its default compression level
+ static int UseDefaultCompressionLevel();
+
+ /// \brief Return a string name for compression type
+ static const std::string& GetCodecAsString(Compression::type t);
+
+ /// \brief Return compression type for name (all upper case)
+ static Result<Compression::type> GetCompressionType(const std::string& name);
+
+ /// \brief Create a codec for the given compression algorithm
+ static Result<std::unique_ptr<Codec>> Create(
+ Compression::type codec, int compression_level = kUseDefaultCompressionLevel);
+
+ /// \brief Return true if support for indicated codec has been enabled
+ static bool IsAvailable(Compression::type codec);
+
+ /// \brief Return true if indicated codec supports setting a compression level
+ static bool SupportsCompressionLevel(Compression::type codec);
+
+ /// \brief Return the smallest supported compression level for the codec
+ /// Note: This function creates a temporary Codec instance
+ static Result<int> MinimumCompressionLevel(Compression::type codec);
+
+ /// \brief Return the largest supported compression level for the codec
+ /// Note: This function creates a temporary Codec instance
+ static Result<int> MaximumCompressionLevel(Compression::type codec);
+
+ /// \brief Return the default compression level
+ /// Note: This function creates a temporary Codec instance
+ static Result<int> DefaultCompressionLevel(Compression::type codec);
+
+ /// \brief Return the smallest supported compression level
+ virtual int minimum_compression_level() const = 0;
+
+ /// \brief Return the largest supported compression level
+ virtual int maximum_compression_level() const = 0;
+
+ /// \brief Return the default compression level
+ virtual int default_compression_level() const = 0;
+
+ /// \brief One-shot decompression function
+ ///
+ /// output_buffer_len must be correct and therefore be obtained in advance.
+ /// The actual decompressed length is returned.
+ ///
+ /// \note One-shot decompression is not always compatible with streaming
+ /// compression. Depending on the codec (e.g. LZ4), different formats may
+ /// be used.
+ virtual Result<int64_t> Decompress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len,
+ uint8_t* output_buffer) = 0;
+
+ /// \brief One-shot compression function
+ ///
+ /// output_buffer_len must first have been computed using MaxCompressedLen().
+ /// The actual compressed length is returned.
+ ///
+ /// \note One-shot compression is not always compatible with streaming
+ /// decompression. Depending on the codec (e.g. LZ4), different formats may
+ /// be used.
+ virtual Result<int64_t> Compress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output_buffer) = 0;
+
+ virtual int64_t MaxCompressedLen(int64_t input_len, const uint8_t* input) = 0;
+
+ /// \brief Create a streaming compressor instance
+ virtual Result<std::shared_ptr<Compressor>> MakeCompressor() = 0;
+
+ /// \brief Create a streaming compressor instance
+ virtual Result<std::shared_ptr<Decompressor>> MakeDecompressor() = 0;
+
+ /// \brief This Codec's compression type
+ virtual Compression::type compression_type() const = 0;
+
+ /// \brief The name of this Codec's compression type
+ const std::string& name() const { return GetCodecAsString(compression_type()); }
+
+ /// \brief This Codec's compression level, if applicable
+ virtual int compression_level() const { return UseDefaultCompressionLevel(); }
+
+ private:
+ /// \brief Initializes the codec's resources.
+ virtual Status Init();
+};
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/compression_benchmark.cc b/src/arrow/cpp/src/arrow/util/compression_benchmark.cc
new file mode 100644
index 000000000..c76be275f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/compression_benchmark.cc
@@ -0,0 +1,201 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "arrow/result.h"
+#include "arrow/util/compression.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace util {
+
+#ifdef ARROW_WITH_BENCHMARKS_REFERENCE
+
+std::vector<uint8_t> MakeCompressibleData(int data_size) {
+ // XXX This isn't a real-world corpus so doesn't really represent the
+ // comparative qualities of the algorithms
+
+ // First make highly compressible data
+ std::string base_data =
+ "Apache Arrow is a cross-language development platform for in-memory data";
+ int nrepeats = static_cast<int>(1 + data_size / base_data.size());
+
+ std::vector<uint8_t> data(base_data.size() * nrepeats);
+ for (int i = 0; i < nrepeats; ++i) {
+ std::memcpy(data.data() + i * base_data.size(), base_data.data(), base_data.size());
+ }
+ data.resize(data_size);
+
+ // Then randomly mutate some bytes so as to make things harder
+ std::mt19937 engine(42);
+ std::exponential_distribution<> offsets(0.05);
+ std::uniform_int_distribution<> values(0, 255);
+
+ int64_t pos = 0;
+ while (pos < data_size) {
+ data[pos] = static_cast<uint8_t>(values(engine));
+ pos += static_cast<int64_t>(offsets(engine));
+ }
+
+ return data;
+}
+
+int64_t StreamingCompress(Codec* codec, const std::vector<uint8_t>& data,
+ std::vector<uint8_t>* compressed_data = nullptr) {
+ if (compressed_data != nullptr) {
+ compressed_data->clear();
+ compressed_data->shrink_to_fit();
+ }
+ auto compressor = *codec->MakeCompressor();
+
+ const uint8_t* input = data.data();
+ int64_t input_len = data.size();
+ int64_t compressed_size = 0;
+
+ std::vector<uint8_t> output_buffer(1 << 20); // 1 MB
+
+ while (input_len > 0) {
+ auto result = *compressor->Compress(input_len, input, output_buffer.size(),
+ output_buffer.data());
+ input += result.bytes_read;
+ input_len -= result.bytes_read;
+ compressed_size += result.bytes_written;
+ if (compressed_data != nullptr && result.bytes_written > 0) {
+ compressed_data->resize(compressed_data->size() + result.bytes_written);
+ memcpy(compressed_data->data() + compressed_data->size() - result.bytes_written,
+ output_buffer.data(), result.bytes_written);
+ }
+ if (result.bytes_read == 0) {
+ // Need to enlarge output buffer
+ output_buffer.resize(output_buffer.size() * 2);
+ }
+ }
+ while (true) {
+ auto result = *compressor->End(output_buffer.size(), output_buffer.data());
+ compressed_size += result.bytes_written;
+ if (compressed_data != nullptr && result.bytes_written > 0) {
+ compressed_data->resize(compressed_data->size() + result.bytes_written);
+ memcpy(compressed_data->data() + compressed_data->size() - result.bytes_written,
+ output_buffer.data(), result.bytes_written);
+ }
+ if (result.should_retry) {
+ // Need to enlarge output buffer
+ output_buffer.resize(output_buffer.size() * 2);
+ } else {
+ break;
+ }
+ }
+ return compressed_size;
+}
+
+static void StreamingCompression(Compression::type compression,
+ const std::vector<uint8_t>& data,
+ benchmark::State& state) { // NOLINT non-const reference
+ auto codec = *Codec::Create(compression);
+
+ while (state.KeepRunning()) {
+ int64_t compressed_size = StreamingCompress(codec.get(), data);
+ state.counters["ratio"] =
+ static_cast<double>(data.size()) / static_cast<double>(compressed_size);
+ }
+ state.SetBytesProcessed(state.iterations() * data.size());
+}
+
+template <Compression::type COMPRESSION>
+static void ReferenceStreamingCompression(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto data = MakeCompressibleData(8 * 1024 * 1024); // 8 MB
+
+ StreamingCompression(COMPRESSION, data, state);
+}
+
+static void StreamingDecompression(
+ Compression::type compression, const std::vector<uint8_t>& data,
+ benchmark::State& state) { // NOLINT non-const reference
+ auto codec = *Codec::Create(compression);
+
+ std::vector<uint8_t> compressed_data;
+ ARROW_UNUSED(StreamingCompress(codec.get(), data, &compressed_data));
+ state.counters["ratio"] =
+ static_cast<double>(data.size()) / static_cast<double>(compressed_data.size());
+
+ while (state.KeepRunning()) {
+ auto decompressor = *codec->MakeDecompressor();
+
+ const uint8_t* input = compressed_data.data();
+ int64_t input_len = compressed_data.size();
+ int64_t decompressed_size = 0;
+
+ std::vector<uint8_t> output_buffer(1 << 20); // 1 MB
+ while (!decompressor->IsFinished()) {
+ auto result = *decompressor->Decompress(input_len, input, output_buffer.size(),
+ output_buffer.data());
+ input += result.bytes_read;
+ input_len -= result.bytes_read;
+ decompressed_size += result.bytes_written;
+ if (result.need_more_output) {
+ // Enlarge output buffer
+ output_buffer.resize(output_buffer.size() * 2);
+ }
+ }
+ ARROW_CHECK(decompressed_size == static_cast<int64_t>(data.size()));
+ }
+ state.SetBytesProcessed(state.iterations() * data.size());
+}
+
+template <Compression::type COMPRESSION>
+static void ReferenceStreamingDecompression(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto data = MakeCompressibleData(8 * 1024 * 1024); // 8 MB
+
+ StreamingDecompression(COMPRESSION, data, state);
+}
+
+#ifdef ARROW_WITH_ZLIB
+BENCHMARK_TEMPLATE(ReferenceStreamingCompression, Compression::GZIP);
+BENCHMARK_TEMPLATE(ReferenceStreamingDecompression, Compression::GZIP);
+#endif
+
+#ifdef ARROW_WITH_BROTLI
+BENCHMARK_TEMPLATE(ReferenceStreamingCompression, Compression::BROTLI);
+BENCHMARK_TEMPLATE(ReferenceStreamingDecompression, Compression::BROTLI);
+#endif
+
+#ifdef ARROW_WITH_ZSTD
+BENCHMARK_TEMPLATE(ReferenceStreamingCompression, Compression::ZSTD);
+BENCHMARK_TEMPLATE(ReferenceStreamingDecompression, Compression::ZSTD);
+#endif
+
+#ifdef ARROW_WITH_LZ4
+BENCHMARK_TEMPLATE(ReferenceStreamingCompression, Compression::LZ4_FRAME);
+BENCHMARK_TEMPLATE(ReferenceStreamingDecompression, Compression::LZ4_FRAME);
+#endif
+
+#endif
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/compression_brotli.cc b/src/arrow/cpp/src/arrow/util/compression_brotli.cc
new file mode 100644
index 000000000..cb547c2c8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/compression_brotli.cc
@@ -0,0 +1,245 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/compression_internal.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+
+#include <brotli/decode.h>
+#include <brotli/encode.h>
+#include <brotli/types.h>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace util {
+namespace internal {
+
+namespace {
+
+class BrotliDecompressor : public Decompressor {
+ public:
+ ~BrotliDecompressor() override {
+ if (state_ != nullptr) {
+ BrotliDecoderDestroyInstance(state_);
+ }
+ }
+
+ Status Init() {
+ state_ = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr);
+ if (state_ == nullptr) {
+ return BrotliError("Brotli init failed");
+ }
+ return Status::OK();
+ }
+
+ Status Reset() override {
+ if (state_ != nullptr) {
+ BrotliDecoderDestroyInstance(state_);
+ }
+ return Init();
+ }
+
+ Result<DecompressResult> Decompress(int64_t input_len, const uint8_t* input,
+ int64_t output_len, uint8_t* output) override {
+ auto avail_in = static_cast<size_t>(input_len);
+ auto avail_out = static_cast<size_t>(output_len);
+ BrotliDecoderResult ret;
+
+ ret = BrotliDecoderDecompressStream(state_, &avail_in, &input, &avail_out, &output,
+ nullptr /* total_out */);
+ if (ret == BROTLI_DECODER_RESULT_ERROR) {
+ return BrotliError(BrotliDecoderGetErrorCode(state_), "Brotli decompress failed: ");
+ }
+ return DecompressResult{static_cast<int64_t>(input_len - avail_in),
+ static_cast<int64_t>(output_len - avail_out),
+ (ret == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT)};
+ }
+
+ bool IsFinished() override { return BrotliDecoderIsFinished(state_); }
+
+ protected:
+ Status BrotliError(const char* msg) { return Status::IOError(msg); }
+
+ Status BrotliError(BrotliDecoderErrorCode code, const char* prefix_msg) {
+ return Status::IOError(prefix_msg, BrotliDecoderErrorString(code));
+ }
+
+ BrotliDecoderState* state_ = nullptr;
+};
+
+// ----------------------------------------------------------------------
+// Brotli compressor implementation
+
+class BrotliCompressor : public Compressor {
+ public:
+ explicit BrotliCompressor(int compression_level)
+ : compression_level_(compression_level) {}
+
+ ~BrotliCompressor() override {
+ if (state_ != nullptr) {
+ BrotliEncoderDestroyInstance(state_);
+ }
+ }
+
+ Status Init() {
+ state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr);
+ if (state_ == nullptr) {
+ return BrotliError("Brotli init failed");
+ }
+ if (!BrotliEncoderSetParameter(state_, BROTLI_PARAM_QUALITY, compression_level_)) {
+ return BrotliError("Brotli set compression level failed");
+ }
+ return Status::OK();
+ }
+
+ Result<CompressResult> Compress(int64_t input_len, const uint8_t* input,
+ int64_t output_len, uint8_t* output) override {
+ auto avail_in = static_cast<size_t>(input_len);
+ auto avail_out = static_cast<size_t>(output_len);
+ BROTLI_BOOL ret;
+
+ ret = BrotliEncoderCompressStream(state_, BROTLI_OPERATION_PROCESS, &avail_in, &input,
+ &avail_out, &output, nullptr /* total_out */);
+ if (!ret) {
+ return BrotliError("Brotli compress failed");
+ }
+ return CompressResult{static_cast<int64_t>(input_len - avail_in),
+ static_cast<int64_t>(output_len - avail_out)};
+ }
+
+ Result<FlushResult> Flush(int64_t output_len, uint8_t* output) override {
+ size_t avail_in = 0;
+ const uint8_t* next_in = nullptr;
+ auto avail_out = static_cast<size_t>(output_len);
+ BROTLI_BOOL ret;
+
+ ret = BrotliEncoderCompressStream(state_, BROTLI_OPERATION_FLUSH, &avail_in, &next_in,
+ &avail_out, &output, nullptr /* total_out */);
+ if (!ret) {
+ return BrotliError("Brotli flush failed");
+ }
+ return FlushResult{static_cast<int64_t>(output_len - avail_out),
+ !!BrotliEncoderHasMoreOutput(state_)};
+ }
+
+ Result<EndResult> End(int64_t output_len, uint8_t* output) override {
+ size_t avail_in = 0;
+ const uint8_t* next_in = nullptr;
+ auto avail_out = static_cast<size_t>(output_len);
+ BROTLI_BOOL ret;
+
+ ret =
+ BrotliEncoderCompressStream(state_, BROTLI_OPERATION_FINISH, &avail_in, &next_in,
+ &avail_out, &output, nullptr /* total_out */);
+ if (!ret) {
+ return BrotliError("Brotli end failed");
+ }
+ bool should_retry = !!BrotliEncoderHasMoreOutput(state_);
+ DCHECK_EQ(should_retry, !BrotliEncoderIsFinished(state_));
+ return EndResult{static_cast<int64_t>(output_len - avail_out), should_retry};
+ }
+
+ protected:
+ Status BrotliError(const char* msg) { return Status::IOError(msg); }
+
+ BrotliEncoderState* state_ = nullptr;
+
+ private:
+ const int compression_level_;
+};
+
+// ----------------------------------------------------------------------
+// Brotli codec implementation
+
+class BrotliCodec : public Codec {
+ public:
+ explicit BrotliCodec(int compression_level)
+ : compression_level_(compression_level == kUseDefaultCompressionLevel
+ ? kBrotliDefaultCompressionLevel
+ : compression_level) {}
+
+ Result<int64_t> Decompress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output_buffer) override {
+ DCHECK_GE(input_len, 0);
+ DCHECK_GE(output_buffer_len, 0);
+ std::size_t output_size = static_cast<size_t>(output_buffer_len);
+ if (BrotliDecoderDecompress(static_cast<size_t>(input_len), input, &output_size,
+ output_buffer) != BROTLI_DECODER_RESULT_SUCCESS) {
+ return Status::IOError("Corrupt brotli compressed data.");
+ }
+ return output_size;
+ }
+
+ int64_t MaxCompressedLen(int64_t input_len,
+ const uint8_t* ARROW_ARG_UNUSED(input)) override {
+ DCHECK_GE(input_len, 0);
+ return BrotliEncoderMaxCompressedSize(static_cast<size_t>(input_len));
+ }
+
+ Result<int64_t> Compress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output_buffer) override {
+ DCHECK_GE(input_len, 0);
+ DCHECK_GE(output_buffer_len, 0);
+ std::size_t output_size = static_cast<size_t>(output_buffer_len);
+ if (BrotliEncoderCompress(compression_level_, BROTLI_DEFAULT_WINDOW,
+ BROTLI_DEFAULT_MODE, static_cast<size_t>(input_len), input,
+ &output_size, output_buffer) == BROTLI_FALSE) {
+ return Status::IOError("Brotli compression failure.");
+ }
+ return output_size;
+ }
+
+ Result<std::shared_ptr<Compressor>> MakeCompressor() override {
+ auto ptr = std::make_shared<BrotliCompressor>(compression_level_);
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+ }
+
+ Result<std::shared_ptr<Decompressor>> MakeDecompressor() override {
+ auto ptr = std::make_shared<BrotliDecompressor>();
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+ }
+
+ Compression::type compression_type() const override { return Compression::BROTLI; }
+
+ int compression_level() const override { return compression_level_; }
+ int minimum_compression_level() const override { return BROTLI_MIN_QUALITY; }
+ int maximum_compression_level() const override { return BROTLI_MAX_QUALITY; }
+ int default_compression_level() const override {
+ return kBrotliDefaultCompressionLevel;
+ }
+
+ private:
+ const int compression_level_;
+};
+
+} // namespace
+
+std::unique_ptr<Codec> MakeBrotliCodec(int compression_level) {
+ return std::unique_ptr<Codec>(new BrotliCodec(compression_level));
+}
+
+} // namespace internal
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/compression_bz2.cc b/src/arrow/cpp/src/arrow/util/compression_bz2.cc
new file mode 100644
index 000000000..b367f2ff2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/compression_bz2.cc
@@ -0,0 +1,287 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/compression_internal.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <sstream>
+
+// Avoid defining max() macro
+#include "arrow/util/windows_compatibility.h"
+
+#include <bzlib.h>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace util {
+namespace internal {
+
+namespace {
+
+constexpr int kBZ2MinCompressionLevel = 1;
+constexpr int kBZ2MaxCompressionLevel = 9;
+
+// Max number of bytes the bz2 APIs accept at a time
+constexpr auto kSizeLimit =
+ static_cast<int64_t>(std::numeric_limits<unsigned int>::max());
+
+Status BZ2Error(const char* prefix_msg, int bz_result) {
+ ARROW_CHECK(bz_result != BZ_OK && bz_result != BZ_RUN_OK && bz_result != BZ_FLUSH_OK &&
+ bz_result != BZ_FINISH_OK && bz_result != BZ_STREAM_END);
+ StatusCode code;
+ std::stringstream ss;
+ ss << prefix_msg;
+ switch (bz_result) {
+ case BZ_CONFIG_ERROR:
+ code = StatusCode::UnknownError;
+ ss << "bz2 library improperly configured (internal error)";
+ break;
+ case BZ_SEQUENCE_ERROR:
+ code = StatusCode::UnknownError;
+ ss << "wrong sequence of calls to bz2 library (internal error)";
+ break;
+ case BZ_PARAM_ERROR:
+ code = StatusCode::UnknownError;
+ ss << "wrong parameter to bz2 library (internal error)";
+ break;
+ case BZ_MEM_ERROR:
+ code = StatusCode::OutOfMemory;
+ ss << "could not allocate memory for bz2 library";
+ break;
+ case BZ_DATA_ERROR:
+ code = StatusCode::IOError;
+ ss << "invalid bz2 data";
+ break;
+ case BZ_DATA_ERROR_MAGIC:
+ code = StatusCode::IOError;
+ ss << "data is not bz2-compressed (no magic header)";
+ break;
+ default:
+ code = StatusCode::UnknownError;
+ ss << "unknown bz2 error " << bz_result;
+ break;
+ }
+ return Status(code, ss.str());
+}
+
+// ----------------------------------------------------------------------
+// bz2 decompressor implementation
+
+class BZ2Decompressor : public Decompressor {
+ public:
+ BZ2Decompressor() : initialized_(false) {}
+
+ ~BZ2Decompressor() override {
+ if (initialized_) {
+ ARROW_UNUSED(BZ2_bzDecompressEnd(&stream_));
+ }
+ }
+
+ Status Init() {
+ DCHECK(!initialized_);
+ memset(&stream_, 0, sizeof(stream_));
+ int ret;
+ ret = BZ2_bzDecompressInit(&stream_, 0, 0);
+ if (ret != BZ_OK) {
+ return BZ2Error("bz2 decompressor init failed: ", ret);
+ }
+ initialized_ = true;
+ finished_ = false;
+ return Status::OK();
+ }
+
+ Status Reset() override {
+ if (initialized_) {
+ ARROW_UNUSED(BZ2_bzDecompressEnd(&stream_));
+ initialized_ = false;
+ }
+ return Init();
+ }
+
+ Result<DecompressResult> Decompress(int64_t input_len, const uint8_t* input,
+ int64_t output_len, uint8_t* output) override {
+ stream_.next_in = const_cast<char*>(reinterpret_cast<const char*>(input));
+ stream_.avail_in = static_cast<unsigned int>(std::min(input_len, kSizeLimit));
+ stream_.next_out = reinterpret_cast<char*>(output);
+ stream_.avail_out = static_cast<unsigned int>(std::min(output_len, kSizeLimit));
+ int ret;
+
+ ret = BZ2_bzDecompress(&stream_);
+ if (ret == BZ_OK || ret == BZ_STREAM_END) {
+ finished_ = (ret == BZ_STREAM_END);
+ int64_t bytes_read = input_len - stream_.avail_in;
+ int64_t bytes_written = output_len - stream_.avail_out;
+ return DecompressResult{bytes_read, bytes_written,
+ (!finished_ && bytes_read == 0 && bytes_written == 0)};
+ } else {
+ return BZ2Error("bz2 decompress failed: ", ret);
+ }
+ }
+
+ bool IsFinished() override { return finished_; }
+
+ protected:
+ bz_stream stream_;
+ bool initialized_;
+ bool finished_;
+};
+
+// ----------------------------------------------------------------------
+// bz2 compressor implementation
+
+class BZ2Compressor : public Compressor {
+ public:
+ explicit BZ2Compressor(int compression_level)
+ : initialized_(false), compression_level_(compression_level) {}
+
+ ~BZ2Compressor() override {
+ if (initialized_) {
+ ARROW_UNUSED(BZ2_bzCompressEnd(&stream_));
+ }
+ }
+
+ Status Init() {
+ DCHECK(!initialized_);
+ memset(&stream_, 0, sizeof(stream_));
+ int ret;
+ ret = BZ2_bzCompressInit(&stream_, compression_level_, 0, 0);
+ if (ret != BZ_OK) {
+ return BZ2Error("bz2 compressor init failed: ", ret);
+ }
+ initialized_ = true;
+ return Status::OK();
+ }
+
+ Result<CompressResult> Compress(int64_t input_len, const uint8_t* input,
+ int64_t output_len, uint8_t* output) override {
+ stream_.next_in = const_cast<char*>(reinterpret_cast<const char*>(input));
+ stream_.avail_in = static_cast<unsigned int>(std::min(input_len, kSizeLimit));
+ stream_.next_out = reinterpret_cast<char*>(output);
+ stream_.avail_out = static_cast<unsigned int>(std::min(output_len, kSizeLimit));
+ int ret;
+
+ ret = BZ2_bzCompress(&stream_, BZ_RUN);
+ if (ret == BZ_RUN_OK) {
+ return CompressResult{input_len - stream_.avail_in, output_len - stream_.avail_out};
+ } else {
+ return BZ2Error("bz2 compress failed: ", ret);
+ }
+ }
+
+ Result<FlushResult> Flush(int64_t output_len, uint8_t* output) override {
+ stream_.next_in = nullptr;
+ stream_.avail_in = 0;
+ stream_.next_out = reinterpret_cast<char*>(output);
+ stream_.avail_out = static_cast<unsigned int>(std::min(output_len, kSizeLimit));
+ int ret;
+
+ ret = BZ2_bzCompress(&stream_, BZ_FLUSH);
+ if (ret == BZ_RUN_OK || ret == BZ_FLUSH_OK) {
+ return FlushResult{output_len - stream_.avail_out, (ret == BZ_FLUSH_OK)};
+ } else {
+ return BZ2Error("bz2 compress failed: ", ret);
+ }
+ }
+
+ Result<EndResult> End(int64_t output_len, uint8_t* output) override {
+ stream_.next_in = nullptr;
+ stream_.avail_in = 0;
+ stream_.next_out = reinterpret_cast<char*>(output);
+ stream_.avail_out = static_cast<unsigned int>(std::min(output_len, kSizeLimit));
+ int ret;
+
+ ret = BZ2_bzCompress(&stream_, BZ_FINISH);
+ if (ret == BZ_STREAM_END || ret == BZ_FINISH_OK) {
+ return EndResult{output_len - stream_.avail_out, (ret == BZ_FINISH_OK)};
+ } else {
+ return BZ2Error("bz2 compress failed: ", ret);
+ }
+ }
+
+ protected:
+ bz_stream stream_;
+ bool initialized_;
+ int compression_level_;
+};
+
+// ----------------------------------------------------------------------
+// bz2 codec implementation
+
+class BZ2Codec : public Codec {
+ public:
+ explicit BZ2Codec(int compression_level) : compression_level_(compression_level) {
+ compression_level_ = compression_level == kUseDefaultCompressionLevel
+ ? kBZ2DefaultCompressionLevel
+ : compression_level;
+ }
+
+ Result<int64_t> Decompress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output_buffer) override {
+ return Status::NotImplemented("One-shot bz2 decompression not supported");
+ }
+
+ Result<int64_t> Compress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output_buffer) override {
+ return Status::NotImplemented("One-shot bz2 compression not supported");
+ }
+
+ int64_t MaxCompressedLen(int64_t input_len,
+ const uint8_t* ARROW_ARG_UNUSED(input)) override {
+ // Cannot determine upper bound for bz2-compressed data
+ return 0;
+ }
+
+ Result<std::shared_ptr<Compressor>> MakeCompressor() override {
+ auto ptr = std::make_shared<BZ2Compressor>(compression_level_);
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+ }
+
+ Result<std::shared_ptr<Decompressor>> MakeDecompressor() override {
+ auto ptr = std::make_shared<BZ2Decompressor>();
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+ }
+
+ Compression::type compression_type() const override { return Compression::BZ2; }
+
+ int compression_level() const override { return compression_level_; }
+ int minimum_compression_level() const override { return kBZ2MinCompressionLevel; }
+ int maximum_compression_level() const override { return kBZ2MaxCompressionLevel; }
+ int default_compression_level() const override { return kBZ2DefaultCompressionLevel; }
+
+ private:
+ int compression_level_;
+};
+
+} // namespace
+
+std::unique_ptr<Codec> MakeBZ2Codec(int compression_level) {
+ return std::unique_ptr<Codec>(new BZ2Codec(compression_level));
+}
+
+} // namespace internal
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/compression_internal.h b/src/arrow/cpp/src/arrow/util/compression_internal.h
new file mode 100644
index 000000000..268672e14
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/compression_internal.h
@@ -0,0 +1,80 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/util/compression.h" // IWYU pragma: export
+
+namespace arrow {
+namespace util {
+
+// ----------------------------------------------------------------------
+// Internal Codec factories
+
+namespace internal {
+
+// Brotli compression quality is max (11) by default, which is slow.
+// We use 8 as a default as it is the best trade-off for Parquet workload.
+constexpr int kBrotliDefaultCompressionLevel = 8;
+
+// Brotli codec.
+std::unique_ptr<Codec> MakeBrotliCodec(
+ int compression_level = kBrotliDefaultCompressionLevel);
+
+// BZ2 codec.
+constexpr int kBZ2DefaultCompressionLevel = 9;
+std::unique_ptr<Codec> MakeBZ2Codec(int compression_level = kBZ2DefaultCompressionLevel);
+
+// GZip
+constexpr int kGZipDefaultCompressionLevel = 9;
+
+struct GZipFormat {
+ enum type {
+ ZLIB,
+ DEFLATE,
+ GZIP,
+ };
+};
+
+std::unique_ptr<Codec> MakeGZipCodec(int compression_level = kGZipDefaultCompressionLevel,
+ GZipFormat::type format = GZipFormat::GZIP);
+
+// Snappy
+std::unique_ptr<Codec> MakeSnappyCodec();
+
+// Lz4 "raw" format codec.
+std::unique_ptr<Codec> MakeLz4RawCodec();
+
+// Lz4 "Hadoop" format codec (== Lz4 raw codec prefixed with lengths header)
+std::unique_ptr<Codec> MakeLz4HadoopRawCodec();
+
+// Lz4 frame format codec.
+std::unique_ptr<Codec> MakeLz4FrameCodec();
+
+// ZSTD codec.
+
+// XXX level = 1 probably doesn't compress very much
+constexpr int kZSTDDefaultCompressionLevel = 1;
+
+std::unique_ptr<Codec> MakeZSTDCodec(
+ int compression_level = kZSTDDefaultCompressionLevel);
+
+} // namespace internal
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/compression_lz4.cc b/src/arrow/cpp/src/arrow/util/compression_lz4.cc
new file mode 100644
index 000000000..c783e4055
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/compression_lz4.cc
@@ -0,0 +1,495 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/compression.h"
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+
+#include <lz4.h>
+#include <lz4frame.h>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/ubsan.h"
+
+#ifndef LZ4F_HEADER_SIZE_MAX
+#define LZ4F_HEADER_SIZE_MAX 19
+#endif
+
+namespace arrow {
+namespace util {
+
+namespace {
+
+static Status LZ4Error(LZ4F_errorCode_t ret, const char* prefix_msg) {
+ return Status::IOError(prefix_msg, LZ4F_getErrorName(ret));
+}
+
+static LZ4F_preferences_t DefaultPreferences() {
+ LZ4F_preferences_t prefs;
+ memset(&prefs, 0, sizeof(prefs));
+ return prefs;
+}
+
+// ----------------------------------------------------------------------
+// Lz4 frame decompressor implementation
+
+class LZ4Decompressor : public Decompressor {
+ public:
+ LZ4Decompressor() {}
+
+ ~LZ4Decompressor() override {
+ if (ctx_ != nullptr) {
+ ARROW_UNUSED(LZ4F_freeDecompressionContext(ctx_));
+ }
+ }
+
+ Status Init() {
+ LZ4F_errorCode_t ret;
+ finished_ = false;
+
+ ret = LZ4F_createDecompressionContext(&ctx_, LZ4F_VERSION);
+ if (LZ4F_isError(ret)) {
+ return LZ4Error(ret, "LZ4 init failed: ");
+ } else {
+ return Status::OK();
+ }
+ }
+
+ Status Reset() override {
+#if defined(LZ4_VERSION_NUMBER) && LZ4_VERSION_NUMBER >= 10800
+ // LZ4F_resetDecompressionContext appeared in 1.8.0
+ DCHECK_NE(ctx_, nullptr);
+ LZ4F_resetDecompressionContext(ctx_);
+ finished_ = false;
+ return Status::OK();
+#else
+ if (ctx_ != nullptr) {
+ ARROW_UNUSED(LZ4F_freeDecompressionContext(ctx_));
+ }
+ return Init();
+#endif
+ }
+
+ Result<DecompressResult> Decompress(int64_t input_len, const uint8_t* input,
+ int64_t output_len, uint8_t* output) override {
+ auto src = input;
+ auto dst = output;
+ auto src_size = static_cast<size_t>(input_len);
+ auto dst_capacity = static_cast<size_t>(output_len);
+ size_t ret;
+
+ ret =
+ LZ4F_decompress(ctx_, dst, &dst_capacity, src, &src_size, nullptr /* options */);
+ if (LZ4F_isError(ret)) {
+ return LZ4Error(ret, "LZ4 decompress failed: ");
+ }
+ finished_ = (ret == 0);
+ return DecompressResult{static_cast<int64_t>(src_size),
+ static_cast<int64_t>(dst_capacity),
+ (src_size == 0 && dst_capacity == 0)};
+ }
+
+ bool IsFinished() override { return finished_; }
+
+ protected:
+ LZ4F_decompressionContext_t ctx_ = nullptr;
+ bool finished_;
+};
+
+// ----------------------------------------------------------------------
+// Lz4 frame compressor implementation
+
+class LZ4Compressor : public Compressor {
+ public:
+ LZ4Compressor() {}
+
+ ~LZ4Compressor() override {
+ if (ctx_ != nullptr) {
+ ARROW_UNUSED(LZ4F_freeCompressionContext(ctx_));
+ }
+ }
+
+ Status Init() {
+ LZ4F_errorCode_t ret;
+ prefs_ = DefaultPreferences();
+ first_time_ = true;
+
+ ret = LZ4F_createCompressionContext(&ctx_, LZ4F_VERSION);
+ if (LZ4F_isError(ret)) {
+ return LZ4Error(ret, "LZ4 init failed: ");
+ } else {
+ return Status::OK();
+ }
+ }
+
+#define BEGIN_COMPRESS(dst, dst_capacity, output_too_small) \
+ if (first_time_) { \
+ if (dst_capacity < LZ4F_HEADER_SIZE_MAX) { \
+ /* Output too small to write LZ4F header */ \
+ return (output_too_small); \
+ } \
+ ret = LZ4F_compressBegin(ctx_, dst, dst_capacity, &prefs_); \
+ if (LZ4F_isError(ret)) { \
+ return LZ4Error(ret, "LZ4 compress begin failed: "); \
+ } \
+ first_time_ = false; \
+ dst += ret; \
+ dst_capacity -= ret; \
+ bytes_written += static_cast<int64_t>(ret); \
+ }
+
+ Result<CompressResult> Compress(int64_t input_len, const uint8_t* input,
+ int64_t output_len, uint8_t* output) override {
+ auto src = input;
+ auto dst = output;
+ auto src_size = static_cast<size_t>(input_len);
+ auto dst_capacity = static_cast<size_t>(output_len);
+ size_t ret;
+ int64_t bytes_written = 0;
+
+ BEGIN_COMPRESS(dst, dst_capacity, (CompressResult{0, 0}));
+
+ if (dst_capacity < LZ4F_compressBound(src_size, &prefs_)) {
+ // Output too small to compress into
+ return CompressResult{0, bytes_written};
+ }
+ ret = LZ4F_compressUpdate(ctx_, dst, dst_capacity, src, src_size,
+ nullptr /* options */);
+ if (LZ4F_isError(ret)) {
+ return LZ4Error(ret, "LZ4 compress update failed: ");
+ }
+ bytes_written += static_cast<int64_t>(ret);
+ DCHECK_LE(bytes_written, output_len);
+ return CompressResult{input_len, bytes_written};
+ }
+
+ Result<FlushResult> Flush(int64_t output_len, uint8_t* output) override {
+ auto dst = output;
+ auto dst_capacity = static_cast<size_t>(output_len);
+ size_t ret;
+ int64_t bytes_written = 0;
+
+ BEGIN_COMPRESS(dst, dst_capacity, (FlushResult{0, true}));
+
+ if (dst_capacity < LZ4F_compressBound(0, &prefs_)) {
+ // Output too small to flush into
+ return FlushResult{bytes_written, true};
+ }
+
+ ret = LZ4F_flush(ctx_, dst, dst_capacity, nullptr /* options */);
+ if (LZ4F_isError(ret)) {
+ return LZ4Error(ret, "LZ4 flush failed: ");
+ }
+ bytes_written += static_cast<int64_t>(ret);
+ DCHECK_LE(bytes_written, output_len);
+ return FlushResult{bytes_written, false};
+ }
+
+ Result<EndResult> End(int64_t output_len, uint8_t* output) override {
+ auto dst = output;
+ auto dst_capacity = static_cast<size_t>(output_len);
+ size_t ret;
+ int64_t bytes_written = 0;
+
+ BEGIN_COMPRESS(dst, dst_capacity, (EndResult{0, true}));
+
+ if (dst_capacity < LZ4F_compressBound(0, &prefs_)) {
+ // Output too small to end frame into
+ return EndResult{bytes_written, true};
+ }
+
+ ret = LZ4F_compressEnd(ctx_, dst, dst_capacity, nullptr /* options */);
+ if (LZ4F_isError(ret)) {
+ return LZ4Error(ret, "LZ4 end failed: ");
+ }
+ bytes_written += static_cast<int64_t>(ret);
+ DCHECK_LE(bytes_written, output_len);
+ return EndResult{bytes_written, false};
+ }
+
+#undef BEGIN_COMPRESS
+
+ protected:
+ LZ4F_compressionContext_t ctx_ = nullptr;
+ LZ4F_preferences_t prefs_;
+ bool first_time_;
+};
+
+// ----------------------------------------------------------------------
+// Lz4 frame codec implementation
+
+class Lz4FrameCodec : public Codec {
+ public:
+ Lz4FrameCodec() : prefs_(DefaultPreferences()) {}
+
+ int64_t MaxCompressedLen(int64_t input_len,
+ const uint8_t* ARROW_ARG_UNUSED(input)) override {
+ return static_cast<int64_t>(
+ LZ4F_compressFrameBound(static_cast<size_t>(input_len), &prefs_));
+ }
+
+ Result<int64_t> Compress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output_buffer) override {
+ auto output_len =
+ LZ4F_compressFrame(output_buffer, static_cast<size_t>(output_buffer_len), input,
+ static_cast<size_t>(input_len), &prefs_);
+ if (LZ4F_isError(output_len)) {
+ return LZ4Error(output_len, "Lz4 compression failure: ");
+ }
+ return static_cast<int64_t>(output_len);
+ }
+
+ Result<int64_t> Decompress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output_buffer) override {
+ ARROW_ASSIGN_OR_RAISE(auto decomp, MakeDecompressor());
+
+ int64_t total_bytes_written = 0;
+ while (!decomp->IsFinished() && input_len != 0) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto res,
+ decomp->Decompress(input_len, input, output_buffer_len, output_buffer));
+ input += res.bytes_read;
+ input_len -= res.bytes_read;
+ output_buffer += res.bytes_written;
+ output_buffer_len -= res.bytes_written;
+ total_bytes_written += res.bytes_written;
+ if (res.need_more_output) {
+ return Status::IOError("Lz4 decompression buffer too small");
+ }
+ }
+ if (!decomp->IsFinished()) {
+ return Status::IOError("Lz4 compressed input contains less than one frame");
+ }
+ if (input_len != 0) {
+ return Status::IOError("Lz4 compressed input contains more than one frame");
+ }
+ return total_bytes_written;
+ }
+
+ Result<std::shared_ptr<Compressor>> MakeCompressor() override {
+ auto ptr = std::make_shared<LZ4Compressor>();
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+ }
+
+ Result<std::shared_ptr<Decompressor>> MakeDecompressor() override {
+ auto ptr = std::make_shared<LZ4Decompressor>();
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+ }
+
+ Compression::type compression_type() const override { return Compression::LZ4_FRAME; }
+ int minimum_compression_level() const override { return kUseDefaultCompressionLevel; }
+ int maximum_compression_level() const override { return kUseDefaultCompressionLevel; }
+ int default_compression_level() const override { return kUseDefaultCompressionLevel; }
+
+ protected:
+ const LZ4F_preferences_t prefs_;
+};
+
+// ----------------------------------------------------------------------
+// Lz4 "raw" codec implementation
+
+class Lz4Codec : public Codec {
+ public:
+ Result<int64_t> Decompress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output_buffer) override {
+ int64_t decompressed_size = LZ4_decompress_safe(
+ reinterpret_cast<const char*>(input), reinterpret_cast<char*>(output_buffer),
+ static_cast<int>(input_len), static_cast<int>(output_buffer_len));
+ if (decompressed_size < 0) {
+ return Status::IOError("Corrupt Lz4 compressed data.");
+ }
+ return decompressed_size;
+ }
+
+ int64_t MaxCompressedLen(int64_t input_len,
+ const uint8_t* ARROW_ARG_UNUSED(input)) override {
+ return LZ4_compressBound(static_cast<int>(input_len));
+ }
+
+ Result<int64_t> Compress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output_buffer) override {
+ int64_t output_len = LZ4_compress_default(
+ reinterpret_cast<const char*>(input), reinterpret_cast<char*>(output_buffer),
+ static_cast<int>(input_len), static_cast<int>(output_buffer_len));
+ if (output_len == 0) {
+ return Status::IOError("Lz4 compression failure.");
+ }
+ return output_len;
+ }
+
+ Result<std::shared_ptr<Compressor>> MakeCompressor() override {
+ return Status::NotImplemented(
+ "Streaming compression unsupported with LZ4 raw format. "
+ "Try using LZ4 frame format instead.");
+ }
+
+ Result<std::shared_ptr<Decompressor>> MakeDecompressor() override {
+ return Status::NotImplemented(
+ "Streaming decompression unsupported with LZ4 raw format. "
+ "Try using LZ4 frame format instead.");
+ }
+
+ Compression::type compression_type() const override { return Compression::LZ4; }
+ int minimum_compression_level() const override { return kUseDefaultCompressionLevel; }
+ int maximum_compression_level() const override { return kUseDefaultCompressionLevel; }
+ int default_compression_level() const override { return kUseDefaultCompressionLevel; }
+};
+
+// ----------------------------------------------------------------------
+// Lz4 Hadoop "raw" codec implementation
+
+class Lz4HadoopCodec : public Lz4Codec {
+ public:
+ Result<int64_t> Decompress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output_buffer) override {
+ const int64_t decompressed_size =
+ TryDecompressHadoop(input_len, input, output_buffer_len, output_buffer);
+ if (decompressed_size != kNotHadoop) {
+ return decompressed_size;
+ }
+ // Fall back on raw LZ4 codec (for files produces by earlier versions of Parquet C++)
+ return Lz4Codec::Decompress(input_len, input, output_buffer_len, output_buffer);
+ }
+
+ int64_t MaxCompressedLen(int64_t input_len,
+ const uint8_t* ARROW_ARG_UNUSED(input)) override {
+ return kPrefixLength + Lz4Codec::MaxCompressedLen(input_len, nullptr);
+ }
+
+ Result<int64_t> Compress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output_buffer) override {
+ if (output_buffer_len < kPrefixLength) {
+ return Status::Invalid("Output buffer too small for Lz4HadoopCodec compression");
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ int64_t output_len,
+ Lz4Codec::Compress(input_len, input, output_buffer_len - kPrefixLength,
+ output_buffer + kPrefixLength));
+
+ // Prepend decompressed size in bytes and compressed size in bytes
+ // to be compatible with Hadoop Lz4Codec
+ const uint32_t decompressed_size =
+ BitUtil::ToBigEndian(static_cast<uint32_t>(input_len));
+ const uint32_t compressed_size =
+ BitUtil::ToBigEndian(static_cast<uint32_t>(output_len));
+ SafeStore(output_buffer, decompressed_size);
+ SafeStore(output_buffer + sizeof(uint32_t), compressed_size);
+
+ return kPrefixLength + output_len;
+ }
+
+ Result<std::shared_ptr<Compressor>> MakeCompressor() override {
+ return Status::NotImplemented(
+ "Streaming compression unsupported with LZ4 Hadoop raw format. "
+ "Try using LZ4 frame format instead.");
+ }
+
+ Result<std::shared_ptr<Decompressor>> MakeDecompressor() override {
+ return Status::NotImplemented(
+ "Streaming decompression unsupported with LZ4 Hadoop raw format. "
+ "Try using LZ4 frame format instead.");
+ }
+
+ Compression::type compression_type() const override { return Compression::LZ4_HADOOP; }
+
+ protected:
+ // Offset starting at which page data can be read/written
+ static const int64_t kPrefixLength = sizeof(uint32_t) * 2;
+
+ static const int64_t kNotHadoop = -1;
+
+ int64_t TryDecompressHadoop(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output_buffer) {
+ // Parquet files written with the Hadoop Lz4Codec use their own framing.
+ // The input buffer can contain an arbitrary number of "frames", each
+ // with the following structure:
+ // - bytes 0..3: big-endian uint32_t representing the frame decompressed size
+ // - bytes 4..7: big-endian uint32_t representing the frame compressed size
+ // - bytes 8...: frame compressed data
+ //
+ // The Hadoop Lz4Codec source code can be found here:
+ // https://github.com/apache/hadoop/blob/trunk/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-nativetask/src/main/native/src/codec/Lz4Codec.cc
+ int64_t total_decompressed_size = 0;
+
+ while (input_len >= kPrefixLength) {
+ const uint32_t expected_decompressed_size =
+ BitUtil::FromBigEndian(SafeLoadAs<uint32_t>(input));
+ const uint32_t expected_compressed_size =
+ BitUtil::FromBigEndian(SafeLoadAs<uint32_t>(input + sizeof(uint32_t)));
+ input += kPrefixLength;
+ input_len -= kPrefixLength;
+
+ if (input_len < expected_compressed_size) {
+ // Not enough bytes for Hadoop "frame"
+ return kNotHadoop;
+ }
+ if (output_buffer_len < expected_decompressed_size) {
+ // Not enough bytes to hold advertised output => probably not Hadoop
+ return kNotHadoop;
+ }
+ // Try decompressing and compare with expected decompressed length
+ auto maybe_decompressed_size = Lz4Codec::Decompress(
+ expected_compressed_size, input, output_buffer_len, output_buffer);
+ if (!maybe_decompressed_size.ok() ||
+ *maybe_decompressed_size != expected_decompressed_size) {
+ return kNotHadoop;
+ }
+ input += expected_compressed_size;
+ input_len -= expected_compressed_size;
+ output_buffer += expected_decompressed_size;
+ output_buffer_len -= expected_decompressed_size;
+ total_decompressed_size += expected_decompressed_size;
+ }
+
+ if (input_len == 0) {
+ return total_decompressed_size;
+ } else {
+ return kNotHadoop;
+ }
+ }
+};
+
+} // namespace
+
+namespace internal {
+
+std::unique_ptr<Codec> MakeLz4FrameCodec() {
+ return std::unique_ptr<Codec>(new Lz4FrameCodec());
+}
+
+std::unique_ptr<Codec> MakeLz4HadoopRawCodec() {
+ return std::unique_ptr<Codec>(new Lz4HadoopCodec());
+}
+
+std::unique_ptr<Codec> MakeLz4RawCodec() {
+ return std::unique_ptr<Codec>(new Lz4Codec());
+}
+
+} // namespace internal
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/compression_snappy.cc b/src/arrow/cpp/src/arrow/util/compression_snappy.cc
new file mode 100644
index 000000000..3756f957d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/compression_snappy.cc
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/compression_internal.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+
+#include <snappy.h>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+using std::size_t;
+
+namespace arrow {
+namespace util {
+namespace internal {
+
+namespace {
+
+// ----------------------------------------------------------------------
+// Snappy implementation
+
+class SnappyCodec : public Codec {
+ public:
+ Result<int64_t> Decompress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output_buffer) override {
+ size_t decompressed_size;
+ if (!snappy::GetUncompressedLength(reinterpret_cast<const char*>(input),
+ static_cast<size_t>(input_len),
+ &decompressed_size)) {
+ return Status::IOError("Corrupt snappy compressed data.");
+ }
+ if (output_buffer_len < static_cast<int64_t>(decompressed_size)) {
+ return Status::Invalid("Output buffer size (", output_buffer_len, ") must be ",
+ decompressed_size, " or larger.");
+ }
+ if (!snappy::RawUncompress(reinterpret_cast<const char*>(input),
+ static_cast<size_t>(input_len),
+ reinterpret_cast<char*>(output_buffer))) {
+ return Status::IOError("Corrupt snappy compressed data.");
+ }
+ return static_cast<int64_t>(decompressed_size);
+ }
+
+ int64_t MaxCompressedLen(int64_t input_len,
+ const uint8_t* ARROW_ARG_UNUSED(input)) override {
+ DCHECK_GE(input_len, 0);
+ return snappy::MaxCompressedLength(static_cast<size_t>(input_len));
+ }
+
+ Result<int64_t> Compress(int64_t input_len, const uint8_t* input,
+ int64_t ARROW_ARG_UNUSED(output_buffer_len),
+ uint8_t* output_buffer) override {
+ size_t output_size;
+ snappy::RawCompress(reinterpret_cast<const char*>(input),
+ static_cast<size_t>(input_len),
+ reinterpret_cast<char*>(output_buffer), &output_size);
+ return static_cast<int64_t>(output_size);
+ }
+
+ Result<std::shared_ptr<Compressor>> MakeCompressor() override {
+ return Status::NotImplemented("Streaming compression unsupported with Snappy");
+ }
+
+ Result<std::shared_ptr<Decompressor>> MakeDecompressor() override {
+ return Status::NotImplemented("Streaming decompression unsupported with Snappy");
+ }
+
+ Compression::type compression_type() const override { return Compression::SNAPPY; }
+ int minimum_compression_level() const override { return kUseDefaultCompressionLevel; }
+ int maximum_compression_level() const override { return kUseDefaultCompressionLevel; }
+ int default_compression_level() const override { return kUseDefaultCompressionLevel; }
+};
+
+} // namespace
+
+std::unique_ptr<Codec> MakeSnappyCodec() {
+ return std::unique_ptr<Codec>(new SnappyCodec());
+}
+
+} // namespace internal
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/compression_test.cc b/src/arrow/cpp/src/arrow/util/compression_test.cc
new file mode 100644
index 000000000..2dbbf607b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/compression_test.cc
@@ -0,0 +1,635 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <ostream>
+#include <random>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/result.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/compression.h"
+
+namespace arrow {
+namespace util {
+
+std::vector<uint8_t> MakeRandomData(int data_size) {
+ std::vector<uint8_t> data(data_size);
+ random_bytes(data_size, 1234, data.data());
+ return data;
+}
+
+std::vector<uint8_t> MakeCompressibleData(int data_size) {
+ std::string base_data =
+ "Apache Arrow is a cross-language development platform for in-memory data";
+ int nrepeats = static_cast<int>(1 + data_size / base_data.size());
+
+ std::vector<uint8_t> data(base_data.size() * nrepeats);
+ for (int i = 0; i < nrepeats; ++i) {
+ std::memcpy(data.data() + i * base_data.size(), base_data.data(), base_data.size());
+ }
+ data.resize(data_size);
+ return data;
+}
+
+// Check roundtrip of one-shot compression and decompression functions.
+void CheckCodecRoundtrip(std::unique_ptr<Codec>& c1, std::unique_ptr<Codec>& c2,
+ const std::vector<uint8_t>& data, bool check_reverse = true) {
+ int max_compressed_len =
+ static_cast<int>(c1->MaxCompressedLen(data.size(), data.data()));
+ std::vector<uint8_t> compressed(max_compressed_len);
+ std::vector<uint8_t> decompressed(data.size());
+
+ // compress with c1
+ int64_t actual_size;
+ ASSERT_OK_AND_ASSIGN(actual_size, c1->Compress(data.size(), data.data(),
+ max_compressed_len, compressed.data()));
+ compressed.resize(actual_size);
+
+ // decompress with c2
+ int64_t actual_decompressed_size;
+ ASSERT_OK_AND_ASSIGN(actual_decompressed_size,
+ c2->Decompress(compressed.size(), compressed.data(),
+ decompressed.size(), decompressed.data()));
+
+ ASSERT_EQ(data, decompressed);
+ ASSERT_EQ(data.size(), actual_decompressed_size);
+
+ if (check_reverse) {
+ // compress with c2
+ ASSERT_EQ(max_compressed_len,
+ static_cast<int>(c2->MaxCompressedLen(data.size(), data.data())));
+ // Resize to prevent ASAN from detecting container overflow.
+ compressed.resize(max_compressed_len);
+
+ int64_t actual_size2;
+ ASSERT_OK_AND_ASSIGN(
+ actual_size2,
+ c2->Compress(data.size(), data.data(), max_compressed_len, compressed.data()));
+ ASSERT_EQ(actual_size2, actual_size);
+ compressed.resize(actual_size2);
+
+ // decompress with c1
+ int64_t actual_decompressed_size2;
+ ASSERT_OK_AND_ASSIGN(actual_decompressed_size2,
+ c1->Decompress(compressed.size(), compressed.data(),
+ decompressed.size(), decompressed.data()));
+
+ ASSERT_EQ(data, decompressed);
+ ASSERT_EQ(data.size(), actual_decompressed_size2);
+ }
+}
+
+// Check the streaming compressor against one-shot decompression
+
+void CheckStreamingCompressor(Codec* codec, const std::vector<uint8_t>& data) {
+ std::shared_ptr<Compressor> compressor;
+ ASSERT_OK_AND_ASSIGN(compressor, codec->MakeCompressor());
+
+ std::vector<uint8_t> compressed;
+ int64_t compressed_size = 0;
+ const uint8_t* input = data.data();
+ int64_t remaining = data.size();
+
+ compressed.resize(10);
+ bool do_flush = false;
+
+ while (remaining > 0) {
+ // Feed a small amount each time
+ int64_t input_len = std::min(remaining, static_cast<int64_t>(1111));
+ int64_t output_len = compressed.size() - compressed_size;
+ uint8_t* output = compressed.data() + compressed_size;
+ ASSERT_OK_AND_ASSIGN(auto result,
+ compressor->Compress(input_len, input, output_len, output));
+ ASSERT_LE(result.bytes_read, input_len);
+ ASSERT_LE(result.bytes_written, output_len);
+ compressed_size += result.bytes_written;
+ input += result.bytes_read;
+ remaining -= result.bytes_read;
+ if (result.bytes_read == 0) {
+ compressed.resize(compressed.capacity() * 2);
+ }
+ // Once every two iterations, do a flush
+ if (do_flush) {
+ Compressor::FlushResult result;
+ do {
+ output_len = compressed.size() - compressed_size;
+ output = compressed.data() + compressed_size;
+ ASSERT_OK_AND_ASSIGN(result, compressor->Flush(output_len, output));
+ ASSERT_LE(result.bytes_written, output_len);
+ compressed_size += result.bytes_written;
+ if (result.should_retry) {
+ compressed.resize(compressed.capacity() * 2);
+ }
+ } while (result.should_retry);
+ }
+ do_flush = !do_flush;
+ }
+
+ // End the compressed stream
+ Compressor::EndResult result;
+ do {
+ int64_t output_len = compressed.size() - compressed_size;
+ uint8_t* output = compressed.data() + compressed_size;
+ ASSERT_OK_AND_ASSIGN(result, compressor->End(output_len, output));
+ ASSERT_LE(result.bytes_written, output_len);
+ compressed_size += result.bytes_written;
+ if (result.should_retry) {
+ compressed.resize(compressed.capacity() * 2);
+ }
+ } while (result.should_retry);
+
+ // Check decompressing the compressed data
+ std::vector<uint8_t> decompressed(data.size());
+ ASSERT_OK(codec->Decompress(compressed_size, compressed.data(), decompressed.size(),
+ decompressed.data()));
+
+ ASSERT_EQ(data, decompressed);
+}
+
+// Check the streaming decompressor against one-shot compression
+
+void CheckStreamingDecompressor(Codec* codec, const std::vector<uint8_t>& data) {
+ // Create compressed data
+ int64_t max_compressed_len = codec->MaxCompressedLen(data.size(), data.data());
+ std::vector<uint8_t> compressed(max_compressed_len);
+ int64_t compressed_size;
+ ASSERT_OK_AND_ASSIGN(
+ compressed_size,
+ codec->Compress(data.size(), data.data(), max_compressed_len, compressed.data()));
+ compressed.resize(compressed_size);
+
+ // Run streaming decompression
+ std::shared_ptr<Decompressor> decompressor;
+ ASSERT_OK_AND_ASSIGN(decompressor, codec->MakeDecompressor());
+
+ std::vector<uint8_t> decompressed;
+ int64_t decompressed_size = 0;
+ const uint8_t* input = compressed.data();
+ int64_t remaining = compressed.size();
+
+ decompressed.resize(10);
+ while (!decompressor->IsFinished()) {
+ // Feed a small amount each time
+ int64_t input_len = std::min(remaining, static_cast<int64_t>(23));
+ int64_t output_len = decompressed.size() - decompressed_size;
+ uint8_t* output = decompressed.data() + decompressed_size;
+ ASSERT_OK_AND_ASSIGN(auto result,
+ decompressor->Decompress(input_len, input, output_len, output));
+ ASSERT_LE(result.bytes_read, input_len);
+ ASSERT_LE(result.bytes_written, output_len);
+ ASSERT_TRUE(result.need_more_output || result.bytes_written > 0 ||
+ result.bytes_read > 0)
+ << "Decompression not progressing anymore";
+ if (result.need_more_output) {
+ decompressed.resize(decompressed.capacity() * 2);
+ }
+ decompressed_size += result.bytes_written;
+ input += result.bytes_read;
+ remaining -= result.bytes_read;
+ }
+ ASSERT_TRUE(decompressor->IsFinished());
+ ASSERT_EQ(remaining, 0);
+
+ // Check the decompressed data
+ decompressed.resize(decompressed_size);
+ ASSERT_EQ(data.size(), decompressed_size);
+ ASSERT_EQ(data, decompressed);
+}
+
+// Check the streaming compressor and decompressor together
+
+void CheckStreamingRoundtrip(std::shared_ptr<Compressor> compressor,
+ std::shared_ptr<Decompressor> decompressor,
+ const std::vector<uint8_t>& data) {
+ std::default_random_engine engine(42);
+ std::uniform_int_distribution<int> buf_size_distribution(10, 40);
+
+ auto make_buf_size = [&]() -> int64_t { return buf_size_distribution(engine); };
+
+ // Compress...
+
+ std::vector<uint8_t> compressed(1);
+ int64_t compressed_size = 0;
+ {
+ const uint8_t* input = data.data();
+ int64_t remaining = data.size();
+
+ while (remaining > 0) {
+ // Feed a varying amount each time
+ int64_t input_len = std::min(remaining, make_buf_size());
+ int64_t output_len = compressed.size() - compressed_size;
+ uint8_t* output = compressed.data() + compressed_size;
+ ASSERT_OK_AND_ASSIGN(auto result,
+ compressor->Compress(input_len, input, output_len, output));
+ ASSERT_LE(result.bytes_read, input_len);
+ ASSERT_LE(result.bytes_written, output_len);
+ compressed_size += result.bytes_written;
+ input += result.bytes_read;
+ remaining -= result.bytes_read;
+ if (result.bytes_read == 0) {
+ compressed.resize(compressed.capacity() * 2);
+ }
+ }
+ // End the compressed stream
+ Compressor::EndResult result;
+ do {
+ int64_t output_len = compressed.size() - compressed_size;
+ uint8_t* output = compressed.data() + compressed_size;
+ ASSERT_OK_AND_ASSIGN(result, compressor->End(output_len, output));
+ ASSERT_LE(result.bytes_written, output_len);
+ compressed_size += result.bytes_written;
+ if (result.should_retry) {
+ compressed.resize(compressed.capacity() * 2);
+ }
+ } while (result.should_retry);
+
+ compressed.resize(compressed_size);
+ }
+
+ // Then decompress...
+
+ std::vector<uint8_t> decompressed(2);
+ int64_t decompressed_size = 0;
+ {
+ const uint8_t* input = compressed.data();
+ int64_t remaining = compressed.size();
+
+ while (!decompressor->IsFinished()) {
+ // Feed a varying amount each time
+ int64_t input_len = std::min(remaining, make_buf_size());
+ int64_t output_len = decompressed.size() - decompressed_size;
+ uint8_t* output = decompressed.data() + decompressed_size;
+ ASSERT_OK_AND_ASSIGN(
+ auto result, decompressor->Decompress(input_len, input, output_len, output));
+ ASSERT_LE(result.bytes_read, input_len);
+ ASSERT_LE(result.bytes_written, output_len);
+ ASSERT_TRUE(result.need_more_output || result.bytes_written > 0 ||
+ result.bytes_read > 0)
+ << "Decompression not progressing anymore";
+ if (result.need_more_output) {
+ decompressed.resize(decompressed.capacity() * 2);
+ }
+ decompressed_size += result.bytes_written;
+ input += result.bytes_read;
+ remaining -= result.bytes_read;
+ }
+ ASSERT_EQ(remaining, 0);
+ decompressed.resize(decompressed_size);
+ }
+
+ ASSERT_EQ(data.size(), decompressed.size());
+ ASSERT_EQ(data, decompressed);
+}
+
+void CheckStreamingRoundtrip(Codec* codec, const std::vector<uint8_t>& data) {
+ std::shared_ptr<Compressor> compressor;
+ std::shared_ptr<Decompressor> decompressor;
+ ASSERT_OK_AND_ASSIGN(compressor, codec->MakeCompressor());
+ ASSERT_OK_AND_ASSIGN(decompressor, codec->MakeDecompressor());
+
+ CheckStreamingRoundtrip(compressor, decompressor, data);
+}
+
+class CodecTest : public ::testing::TestWithParam<Compression::type> {
+ protected:
+ Compression::type GetCompression() { return GetParam(); }
+
+ std::unique_ptr<Codec> MakeCodec() { return *Codec::Create(GetCompression()); }
+};
+
+TEST(TestCodecMisc, GetCodecAsString) {
+ EXPECT_EQ(Codec::GetCodecAsString(Compression::UNCOMPRESSED), "uncompressed");
+ EXPECT_EQ(Codec::GetCodecAsString(Compression::SNAPPY), "snappy");
+ EXPECT_EQ(Codec::GetCodecAsString(Compression::GZIP), "gzip");
+ EXPECT_EQ(Codec::GetCodecAsString(Compression::LZO), "lzo");
+ EXPECT_EQ(Codec::GetCodecAsString(Compression::BROTLI), "brotli");
+ EXPECT_EQ(Codec::GetCodecAsString(Compression::LZ4), "lz4_raw");
+ EXPECT_EQ(Codec::GetCodecAsString(Compression::LZ4_FRAME), "lz4");
+ EXPECT_EQ(Codec::GetCodecAsString(Compression::ZSTD), "zstd");
+ EXPECT_EQ(Codec::GetCodecAsString(Compression::BZ2), "bz2");
+}
+
+TEST(TestCodecMisc, GetCompressionType) {
+ ASSERT_OK_AND_EQ(Compression::UNCOMPRESSED, Codec::GetCompressionType("uncompressed"));
+ ASSERT_OK_AND_EQ(Compression::SNAPPY, Codec::GetCompressionType("snappy"));
+ ASSERT_OK_AND_EQ(Compression::GZIP, Codec::GetCompressionType("gzip"));
+ ASSERT_OK_AND_EQ(Compression::LZO, Codec::GetCompressionType("lzo"));
+ ASSERT_OK_AND_EQ(Compression::BROTLI, Codec::GetCompressionType("brotli"));
+ ASSERT_OK_AND_EQ(Compression::LZ4, Codec::GetCompressionType("lz4_raw"));
+ ASSERT_OK_AND_EQ(Compression::LZ4_FRAME, Codec::GetCompressionType("lz4"));
+ ASSERT_OK_AND_EQ(Compression::ZSTD, Codec::GetCompressionType("zstd"));
+ ASSERT_OK_AND_EQ(Compression::BZ2, Codec::GetCompressionType("bz2"));
+
+ ASSERT_RAISES(Invalid, Codec::GetCompressionType("unk"));
+ ASSERT_RAISES(Invalid, Codec::GetCompressionType("SNAPPY"));
+}
+
+TEST_P(CodecTest, CodecRoundtrip) {
+ const auto compression = GetCompression();
+ if (compression == Compression::BZ2) {
+ GTEST_SKIP() << "BZ2 does not support one-shot compression";
+ }
+
+ int sizes[] = {0, 10000, 100000};
+
+ // create multiple compressors to try to break them
+ std::unique_ptr<Codec> c1, c2;
+ ASSERT_OK_AND_ASSIGN(c1, Codec::Create(compression));
+ ASSERT_OK_AND_ASSIGN(c2, Codec::Create(compression));
+
+ for (int data_size : sizes) {
+ std::vector<uint8_t> data = MakeRandomData(data_size);
+ CheckCodecRoundtrip(c1, c2, data);
+
+ data = MakeCompressibleData(data_size);
+ CheckCodecRoundtrip(c1, c2, data);
+ }
+}
+
+TEST(TestCodecMisc, SpecifyCompressionLevel) {
+ struct CombinationOption {
+ Compression::type codec;
+ int level;
+ bool expect_success;
+ };
+ constexpr CombinationOption combinations[] = {
+ {Compression::GZIP, 2, true}, {Compression::BROTLI, 10, true},
+ {Compression::ZSTD, 4, true}, {Compression::LZ4, -10, false},
+ {Compression::LZO, -22, false}, {Compression::UNCOMPRESSED, 10, false},
+ {Compression::SNAPPY, 16, false}, {Compression::GZIP, -992, false}};
+
+ std::vector<uint8_t> data = MakeRandomData(2000);
+ for (const auto& combination : combinations) {
+ const auto compression = combination.codec;
+ if (!Codec::IsAvailable(compression)) {
+ // Support for this codec hasn't been built
+ continue;
+ }
+ const auto level = combination.level;
+ const auto expect_success = combination.expect_success;
+ auto result1 = Codec::Create(compression, level);
+ auto result2 = Codec::Create(compression, level);
+ ASSERT_EQ(expect_success, result1.ok());
+ ASSERT_EQ(expect_success, result2.ok());
+ if (expect_success) {
+ CheckCodecRoundtrip(*result1, *result2, data);
+ }
+ }
+}
+
+TEST_P(CodecTest, MinMaxCompressionLevel) {
+ auto type = GetCompression();
+ ASSERT_OK_AND_ASSIGN(auto codec, Codec::Create(type));
+
+ if (Codec::SupportsCompressionLevel(type)) {
+ ASSERT_OK_AND_ASSIGN(auto min_level, Codec::MinimumCompressionLevel(type));
+ ASSERT_OK_AND_ASSIGN(auto max_level, Codec::MaximumCompressionLevel(type));
+ ASSERT_OK_AND_ASSIGN(auto default_level, Codec::DefaultCompressionLevel(type));
+ ASSERT_NE(min_level, Codec::UseDefaultCompressionLevel());
+ ASSERT_NE(max_level, Codec::UseDefaultCompressionLevel());
+ ASSERT_NE(default_level, Codec::UseDefaultCompressionLevel());
+ ASSERT_LT(min_level, max_level);
+ ASSERT_EQ(min_level, codec->minimum_compression_level());
+ ASSERT_EQ(max_level, codec->maximum_compression_level());
+ ASSERT_GE(default_level, min_level);
+ ASSERT_LE(default_level, max_level);
+ } else {
+ ASSERT_RAISES(Invalid, Codec::MinimumCompressionLevel(type));
+ ASSERT_RAISES(Invalid, Codec::MaximumCompressionLevel(type));
+ ASSERT_RAISES(Invalid, Codec::DefaultCompressionLevel(type));
+ ASSERT_EQ(codec->minimum_compression_level(), Codec::UseDefaultCompressionLevel());
+ ASSERT_EQ(codec->maximum_compression_level(), Codec::UseDefaultCompressionLevel());
+ ASSERT_EQ(codec->default_compression_level(), Codec::UseDefaultCompressionLevel());
+ }
+}
+
+TEST_P(CodecTest, OutputBufferIsSmall) {
+ auto type = GetCompression();
+ if (type != Compression::SNAPPY) {
+ return;
+ }
+
+ ASSERT_OK_AND_ASSIGN(auto codec, Codec::Create(type));
+
+ std::vector<uint8_t> data = MakeRandomData(10);
+ auto max_compressed_len = codec->MaxCompressedLen(data.size(), data.data());
+ std::vector<uint8_t> compressed(max_compressed_len);
+ std::vector<uint8_t> decompressed(data.size() - 1);
+
+ int64_t actual_size;
+ ASSERT_OK_AND_ASSIGN(
+ actual_size,
+ codec->Compress(data.size(), data.data(), max_compressed_len, compressed.data()));
+ compressed.resize(actual_size);
+
+ std::stringstream ss;
+ ss << "Invalid: Output buffer size (" << decompressed.size() << ") must be "
+ << data.size() << " or larger.";
+ ASSERT_RAISES_WITH_MESSAGE(Invalid, ss.str(),
+ codec->Decompress(compressed.size(), compressed.data(),
+ decompressed.size(), decompressed.data()));
+}
+
+TEST_P(CodecTest, StreamingCompressor) {
+ if (GetCompression() == Compression::SNAPPY) {
+ GTEST_SKIP() << "snappy doesn't support streaming compression";
+ }
+ if (GetCompression() == Compression::BZ2) {
+ GTEST_SKIP() << "Z2 doesn't support one-shot decompression";
+ }
+ if (GetCompression() == Compression::LZ4 ||
+ GetCompression() == Compression::LZ4_HADOOP) {
+ GTEST_SKIP() << "LZ4 raw format doesn't support streaming compression.";
+ }
+
+ int sizes[] = {0, 10, 100000};
+ for (int data_size : sizes) {
+ auto codec = MakeCodec();
+
+ std::vector<uint8_t> data = MakeRandomData(data_size);
+ CheckStreamingCompressor(codec.get(), data);
+
+ data = MakeCompressibleData(data_size);
+ CheckStreamingCompressor(codec.get(), data);
+ }
+}
+
+TEST_P(CodecTest, StreamingDecompressor) {
+ if (GetCompression() == Compression::SNAPPY) {
+ GTEST_SKIP() << "snappy doesn't support streaming decompression.";
+ }
+ if (GetCompression() == Compression::BZ2) {
+ GTEST_SKIP() << "Z2 doesn't support one-shot compression";
+ }
+ if (GetCompression() == Compression::LZ4 ||
+ GetCompression() == Compression::LZ4_HADOOP) {
+ GTEST_SKIP() << "LZ4 raw format doesn't support streaming decompression.";
+ }
+
+ int sizes[] = {0, 10, 100000};
+ for (int data_size : sizes) {
+ auto codec = MakeCodec();
+
+ std::vector<uint8_t> data = MakeRandomData(data_size);
+ CheckStreamingDecompressor(codec.get(), data);
+
+ data = MakeCompressibleData(data_size);
+ CheckStreamingDecompressor(codec.get(), data);
+ }
+}
+
+TEST_P(CodecTest, StreamingRoundtrip) {
+ if (GetCompression() == Compression::SNAPPY) {
+ GTEST_SKIP() << "snappy doesn't support streaming decompression";
+ }
+ if (GetCompression() == Compression::LZ4 ||
+ GetCompression() == Compression::LZ4_HADOOP) {
+ GTEST_SKIP() << "LZ4 raw format doesn't support streaming compression.";
+ }
+
+ int sizes[] = {0, 10, 100000};
+ for (int data_size : sizes) {
+ auto codec = MakeCodec();
+
+ std::vector<uint8_t> data = MakeRandomData(data_size);
+ CheckStreamingRoundtrip(codec.get(), data);
+
+ data = MakeCompressibleData(data_size);
+ CheckStreamingRoundtrip(codec.get(), data);
+ }
+}
+
+TEST_P(CodecTest, StreamingDecompressorReuse) {
+ if (GetCompression() == Compression::SNAPPY) {
+ GTEST_SKIP() << "snappy doesn't support streaming decompression";
+ }
+ if (GetCompression() == Compression::LZ4 ||
+ GetCompression() == Compression::LZ4_HADOOP) {
+ GTEST_SKIP() << "LZ4 raw format doesn't support streaming decompression.";
+ }
+
+ auto codec = MakeCodec();
+ std::shared_ptr<Compressor> compressor;
+ std::shared_ptr<Decompressor> decompressor;
+ ASSERT_OK_AND_ASSIGN(compressor, codec->MakeCompressor());
+ ASSERT_OK_AND_ASSIGN(decompressor, codec->MakeDecompressor());
+
+ std::vector<uint8_t> data = MakeRandomData(100);
+ CheckStreamingRoundtrip(compressor, decompressor, data);
+ // Decompressor::Reset() should allow reusing decompressor for a new stream
+ ASSERT_OK_AND_ASSIGN(compressor, codec->MakeCompressor());
+ ASSERT_OK(decompressor->Reset());
+ data = MakeRandomData(200);
+ CheckStreamingRoundtrip(compressor, decompressor, data);
+}
+
+TEST_P(CodecTest, StreamingMultiFlush) {
+ // Regression test for ARROW-11937
+ if (GetCompression() == Compression::SNAPPY) {
+ GTEST_SKIP() << "snappy doesn't support streaming decompression";
+ }
+ if (GetCompression() == Compression::LZ4 ||
+ GetCompression() == Compression::LZ4_HADOOP) {
+ GTEST_SKIP() << "LZ4 raw format doesn't support streaming decompression.";
+ }
+ auto type = GetCompression();
+ ASSERT_OK_AND_ASSIGN(auto codec, Codec::Create(type));
+
+ std::shared_ptr<Compressor> compressor;
+ ASSERT_OK_AND_ASSIGN(compressor, codec->MakeCompressor());
+
+ // Grow the buffer and flush again while requested (up to a bounded number of times)
+ std::vector<uint8_t> compressed(1024);
+ Compressor::FlushResult result;
+ int attempts = 0;
+ int64_t actual_size = 0;
+ int64_t output_len = 0;
+ uint8_t* output = compressed.data();
+ do {
+ compressed.resize(compressed.capacity() * 2);
+ output_len = compressed.size() - actual_size;
+ output = compressed.data() + actual_size;
+ ASSERT_OK_AND_ASSIGN(result, compressor->Flush(output_len, output));
+ actual_size += result.bytes_written;
+ attempts++;
+ } while (attempts < 8 && result.should_retry);
+ // The LZ4 codec actually needs this many attempts to settle
+
+ // Flush again having done nothing - should not require retry
+ output_len = compressed.size() - actual_size;
+ output = compressed.data() + actual_size;
+ ASSERT_OK_AND_ASSIGN(result, compressor->Flush(output_len, output));
+ ASSERT_FALSE(result.should_retry);
+}
+
+#if !defined ARROW_WITH_ZLIB && !defined ARROW_WITH_SNAPPY && !defined ARROW_WITH_LZ4 && \
+ !defined ARROW_WITH_BROTLI && !defined ARROW_WITH_BZ2 && !defined ARROW_WITH_ZSTD
+GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CodecTest);
+#endif
+
+#ifdef ARROW_WITH_ZLIB
+INSTANTIATE_TEST_SUITE_P(TestGZip, CodecTest, ::testing::Values(Compression::GZIP));
+#endif
+
+#ifdef ARROW_WITH_SNAPPY
+INSTANTIATE_TEST_SUITE_P(TestSnappy, CodecTest, ::testing::Values(Compression::SNAPPY));
+#endif
+
+#ifdef ARROW_WITH_LZ4
+INSTANTIATE_TEST_SUITE_P(TestLZ4, CodecTest, ::testing::Values(Compression::LZ4));
+INSTANTIATE_TEST_SUITE_P(TestLZ4Hadoop, CodecTest,
+ ::testing::Values(Compression::LZ4_HADOOP));
+#endif
+
+#ifdef ARROW_WITH_LZ4
+INSTANTIATE_TEST_SUITE_P(TestLZ4Frame, CodecTest,
+ ::testing::Values(Compression::LZ4_FRAME));
+#endif
+
+#ifdef ARROW_WITH_BROTLI
+INSTANTIATE_TEST_SUITE_P(TestBrotli, CodecTest, ::testing::Values(Compression::BROTLI));
+#endif
+
+#if ARROW_WITH_BZ2
+INSTANTIATE_TEST_SUITE_P(TestBZ2, CodecTest, ::testing::Values(Compression::BZ2));
+#endif
+
+#ifdef ARROW_WITH_ZSTD
+INSTANTIATE_TEST_SUITE_P(TestZSTD, CodecTest, ::testing::Values(Compression::ZSTD));
+#endif
+
+#ifdef ARROW_WITH_LZ4
+TEST(TestCodecLZ4Hadoop, Compatibility) {
+ // LZ4 Hadoop codec should be able to read back LZ4 raw blocks
+ ASSERT_OK_AND_ASSIGN(auto c1, Codec::Create(Compression::LZ4));
+ ASSERT_OK_AND_ASSIGN(auto c2, Codec::Create(Compression::LZ4_HADOOP));
+
+ std::vector<uint8_t> data = MakeRandomData(100);
+ CheckCodecRoundtrip(c1, c2, data, /*check_reverse=*/false);
+}
+#endif
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/compression_zlib.cc b/src/arrow/cpp/src/arrow/util/compression_zlib.cc
new file mode 100644
index 000000000..e9cb2470e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/compression_zlib.cc
@@ -0,0 +1,507 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/compression_internal.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+
+#include <zconf.h>
+#include <zlib.h>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace util {
+namespace internal {
+
+namespace {
+
+// ----------------------------------------------------------------------
+// gzip implementation
+
+// These are magic numbers from zlib.h. Not clear why they are not defined
+// there.
+
+// Maximum window size
+constexpr int WINDOW_BITS = 15;
+
+// Output Gzip.
+constexpr int GZIP_CODEC = 16;
+
+// Determine if this is libz or gzip from header.
+constexpr int DETECT_CODEC = 32;
+
+constexpr int kGZipMinCompressionLevel = 1;
+constexpr int kGZipMaxCompressionLevel = 9;
+
+int CompressionWindowBitsForFormat(GZipFormat::type format) {
+ int window_bits = WINDOW_BITS;
+ switch (format) {
+ case GZipFormat::DEFLATE:
+ window_bits = -window_bits;
+ break;
+ case GZipFormat::GZIP:
+ window_bits += GZIP_CODEC;
+ break;
+ case GZipFormat::ZLIB:
+ break;
+ }
+ return window_bits;
+}
+
+int DecompressionWindowBitsForFormat(GZipFormat::type format) {
+ if (format == GZipFormat::DEFLATE) {
+ return -WINDOW_BITS;
+ } else {
+ /* If not deflate, autodetect format from header */
+ return WINDOW_BITS | DETECT_CODEC;
+ }
+}
+
+Status ZlibErrorPrefix(const char* prefix_msg, const char* msg) {
+ return Status::IOError(prefix_msg, (msg) ? msg : "(unknown error)");
+}
+
+// ----------------------------------------------------------------------
+// gzip decompressor implementation
+
+class GZipDecompressor : public Decompressor {
+ public:
+ explicit GZipDecompressor(GZipFormat::type format)
+ : format_(format), initialized_(false), finished_(false) {}
+
+ ~GZipDecompressor() override {
+ if (initialized_) {
+ inflateEnd(&stream_);
+ }
+ }
+
+ Status Init() {
+ DCHECK(!initialized_);
+ memset(&stream_, 0, sizeof(stream_));
+ finished_ = false;
+
+ int ret;
+ int window_bits = DecompressionWindowBitsForFormat(format_);
+ if ((ret = inflateInit2(&stream_, window_bits)) != Z_OK) {
+ return ZlibError("zlib inflateInit failed: ");
+ } else {
+ initialized_ = true;
+ return Status::OK();
+ }
+ }
+
+ Status Reset() override {
+ DCHECK(initialized_);
+ finished_ = false;
+ int ret;
+ if ((ret = inflateReset(&stream_)) != Z_OK) {
+ return ZlibError("zlib inflateReset failed: ");
+ } else {
+ return Status::OK();
+ }
+ }
+
+ Result<DecompressResult> Decompress(int64_t input_len, const uint8_t* input,
+ int64_t output_len, uint8_t* output) override {
+ static constexpr auto input_limit =
+ static_cast<int64_t>(std::numeric_limits<uInt>::max());
+ stream_.next_in = const_cast<Bytef*>(reinterpret_cast<const Bytef*>(input));
+ stream_.avail_in = static_cast<uInt>(std::min(input_len, input_limit));
+ stream_.next_out = reinterpret_cast<Bytef*>(output);
+ stream_.avail_out = static_cast<uInt>(std::min(output_len, input_limit));
+ int ret;
+
+ ret = inflate(&stream_, Z_SYNC_FLUSH);
+ if (ret == Z_DATA_ERROR || ret == Z_STREAM_ERROR || ret == Z_MEM_ERROR) {
+ return ZlibError("zlib inflate failed: ");
+ }
+ if (ret == Z_NEED_DICT) {
+ return ZlibError("zlib inflate failed (need preset dictionary): ");
+ }
+ finished_ = (ret == Z_STREAM_END);
+ if (ret == Z_BUF_ERROR) {
+ // No progress was possible
+ return DecompressResult{0, 0, true};
+ } else {
+ ARROW_CHECK(ret == Z_OK || ret == Z_STREAM_END);
+ // Some progress has been made
+ return DecompressResult{input_len - stream_.avail_in,
+ output_len - stream_.avail_out, false};
+ }
+ return Status::OK();
+ }
+
+ bool IsFinished() override { return finished_; }
+
+ protected:
+ Status ZlibError(const char* prefix_msg) {
+ return ZlibErrorPrefix(prefix_msg, stream_.msg);
+ }
+
+ z_stream stream_;
+ GZipFormat::type format_;
+ bool initialized_;
+ bool finished_;
+};
+
+// ----------------------------------------------------------------------
+// gzip compressor implementation
+
+class GZipCompressor : public Compressor {
+ public:
+ explicit GZipCompressor(int compression_level)
+ : initialized_(false), compression_level_(compression_level) {}
+
+ ~GZipCompressor() override {
+ if (initialized_) {
+ deflateEnd(&stream_);
+ }
+ }
+
+ Status Init(GZipFormat::type format) {
+ DCHECK(!initialized_);
+ memset(&stream_, 0, sizeof(stream_));
+
+ int ret;
+ // Initialize to run specified format
+ int window_bits = CompressionWindowBitsForFormat(format);
+ if ((ret = deflateInit2(&stream_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, window_bits,
+ compression_level_, Z_DEFAULT_STRATEGY)) != Z_OK) {
+ return ZlibError("zlib deflateInit failed: ");
+ } else {
+ initialized_ = true;
+ return Status::OK();
+ }
+ }
+
+ Result<CompressResult> Compress(int64_t input_len, const uint8_t* input,
+ int64_t output_len, uint8_t* output) override {
+ DCHECK(initialized_) << "Called on non-initialized stream";
+
+ static constexpr auto input_limit =
+ static_cast<int64_t>(std::numeric_limits<uInt>::max());
+
+ stream_.next_in = const_cast<Bytef*>(reinterpret_cast<const Bytef*>(input));
+ stream_.avail_in = static_cast<uInt>(std::min(input_len, input_limit));
+ stream_.next_out = reinterpret_cast<Bytef*>(output);
+ stream_.avail_out = static_cast<uInt>(std::min(output_len, input_limit));
+
+ int64_t ret = 0;
+ ret = deflate(&stream_, Z_NO_FLUSH);
+ if (ret == Z_STREAM_ERROR) {
+ return ZlibError("zlib compress failed: ");
+ }
+ if (ret == Z_OK) {
+ // Some progress has been made
+ return CompressResult{input_len - stream_.avail_in, output_len - stream_.avail_out};
+ } else {
+ // No progress was possible
+ ARROW_CHECK_EQ(ret, Z_BUF_ERROR);
+ return CompressResult{0, 0};
+ }
+ }
+
+ Result<FlushResult> Flush(int64_t output_len, uint8_t* output) override {
+ DCHECK(initialized_) << "Called on non-initialized stream";
+
+ static constexpr auto input_limit =
+ static_cast<int64_t>(std::numeric_limits<uInt>::max());
+
+ stream_.avail_in = 0;
+ stream_.next_out = reinterpret_cast<Bytef*>(output);
+ stream_.avail_out = static_cast<uInt>(std::min(output_len, input_limit));
+
+ int64_t ret = 0;
+ ret = deflate(&stream_, Z_SYNC_FLUSH);
+ if (ret == Z_STREAM_ERROR) {
+ return ZlibError("zlib flush failed: ");
+ }
+ int64_t bytes_written;
+ if (ret == Z_OK) {
+ bytes_written = output_len - stream_.avail_out;
+ } else {
+ ARROW_CHECK_EQ(ret, Z_BUF_ERROR);
+ bytes_written = 0;
+ }
+ // "If deflate returns with avail_out == 0, this function must be called
+ // again with the same value of the flush parameter and more output space
+ // (updated avail_out), until the flush is complete (deflate returns
+ // with non-zero avail_out)."
+ // "Note that Z_BUF_ERROR is not fatal, and deflate() can be called again
+ // with more input and more output space to continue compressing."
+ return FlushResult{bytes_written, stream_.avail_out == 0};
+ }
+
+ Result<EndResult> End(int64_t output_len, uint8_t* output) override {
+ DCHECK(initialized_) << "Called on non-initialized stream";
+
+ static constexpr auto input_limit =
+ static_cast<int64_t>(std::numeric_limits<uInt>::max());
+
+ stream_.avail_in = 0;
+ stream_.next_out = reinterpret_cast<Bytef*>(output);
+ stream_.avail_out = static_cast<uInt>(std::min(output_len, input_limit));
+
+ int64_t ret = 0;
+ ret = deflate(&stream_, Z_FINISH);
+ if (ret == Z_STREAM_ERROR) {
+ return ZlibError("zlib flush failed: ");
+ }
+ int64_t bytes_written = output_len - stream_.avail_out;
+ if (ret == Z_STREAM_END) {
+ // Flush complete, we can now end the stream
+ initialized_ = false;
+ ret = deflateEnd(&stream_);
+ if (ret == Z_OK) {
+ return EndResult{bytes_written, false};
+ } else {
+ return ZlibError("zlib end failed: ");
+ }
+ } else {
+ // Not everything could be flushed,
+ return EndResult{bytes_written, true};
+ }
+ }
+
+ protected:
+ Status ZlibError(const char* prefix_msg) {
+ return ZlibErrorPrefix(prefix_msg, stream_.msg);
+ }
+
+ z_stream stream_;
+ bool initialized_;
+ int compression_level_;
+};
+
+// ----------------------------------------------------------------------
+// gzip codec implementation
+
+class GZipCodec : public Codec {
+ public:
+ explicit GZipCodec(int compression_level, GZipFormat::type format)
+ : format_(format),
+ compressor_initialized_(false),
+ decompressor_initialized_(false) {
+ compression_level_ = compression_level == kUseDefaultCompressionLevel
+ ? kGZipDefaultCompressionLevel
+ : compression_level;
+ }
+
+ ~GZipCodec() override {
+ EndCompressor();
+ EndDecompressor();
+ }
+
+ Result<std::shared_ptr<Compressor>> MakeCompressor() override {
+ auto ptr = std::make_shared<GZipCompressor>(compression_level_);
+ RETURN_NOT_OK(ptr->Init(format_));
+ return ptr;
+ }
+
+ Result<std::shared_ptr<Decompressor>> MakeDecompressor() override {
+ auto ptr = std::make_shared<GZipDecompressor>(format_);
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+ }
+
+ Status InitCompressor() {
+ EndDecompressor();
+ memset(&stream_, 0, sizeof(stream_));
+
+ int ret;
+ // Initialize to run specified format
+ int window_bits = CompressionWindowBitsForFormat(format_);
+ if ((ret = deflateInit2(&stream_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, window_bits,
+ compression_level_, Z_DEFAULT_STRATEGY)) != Z_OK) {
+ return ZlibErrorPrefix("zlib deflateInit failed: ", stream_.msg);
+ }
+ compressor_initialized_ = true;
+ return Status::OK();
+ }
+
+ void EndCompressor() {
+ if (compressor_initialized_) {
+ (void)deflateEnd(&stream_);
+ }
+ compressor_initialized_ = false;
+ }
+
+ Status InitDecompressor() {
+ EndCompressor();
+ memset(&stream_, 0, sizeof(stream_));
+ int ret;
+
+ // Initialize to run either deflate or zlib/gzip format
+ int window_bits = DecompressionWindowBitsForFormat(format_);
+ if ((ret = inflateInit2(&stream_, window_bits)) != Z_OK) {
+ return ZlibErrorPrefix("zlib inflateInit failed: ", stream_.msg);
+ }
+ decompressor_initialized_ = true;
+ return Status::OK();
+ }
+
+ void EndDecompressor() {
+ if (decompressor_initialized_) {
+ (void)inflateEnd(&stream_);
+ }
+ decompressor_initialized_ = false;
+ }
+
+ Result<int64_t> Decompress(int64_t input_length, const uint8_t* input,
+ int64_t output_buffer_length, uint8_t* output) override {
+ if (!decompressor_initialized_) {
+ RETURN_NOT_OK(InitDecompressor());
+ }
+ if (output_buffer_length == 0) {
+ // The zlib library does not allow *output to be NULL, even when
+ // output_buffer_length is 0 (inflate() will return Z_STREAM_ERROR). We don't
+ // consider this an error, so bail early if no output is expected. Note that we
+ // don't signal an error if the input actually contains compressed data.
+ return 0;
+ }
+
+ // Reset the stream for this block
+ if (inflateReset(&stream_) != Z_OK) {
+ return ZlibErrorPrefix("zlib inflateReset failed: ", stream_.msg);
+ }
+
+ int ret = 0;
+ // gzip can run in streaming mode or non-streaming mode. We only
+ // support the non-streaming use case where we present it the entire
+ // compressed input and a buffer big enough to contain the entire
+ // compressed output. In the case where we don't know the output,
+ // we just make a bigger buffer and try the non-streaming mode
+ // from the beginning again.
+ while (ret != Z_STREAM_END) {
+ stream_.next_in = const_cast<Bytef*>(reinterpret_cast<const Bytef*>(input));
+ stream_.avail_in = static_cast<uInt>(input_length);
+ stream_.next_out = reinterpret_cast<Bytef*>(output);
+ stream_.avail_out = static_cast<uInt>(output_buffer_length);
+
+ // We know the output size. In this case, we can use Z_FINISH
+ // which is more efficient.
+ ret = inflate(&stream_, Z_FINISH);
+ if (ret == Z_STREAM_END || ret != Z_OK) break;
+
+ // Failure, buffer was too small
+ return Status::IOError("Too small a buffer passed to GZipCodec. InputLength=",
+ input_length, " OutputLength=", output_buffer_length);
+ }
+
+ // Failure for some other reason
+ if (ret != Z_STREAM_END) {
+ return ZlibErrorPrefix("GZipCodec failed: ", stream_.msg);
+ }
+
+ return stream_.total_out;
+ }
+
+ int64_t MaxCompressedLen(int64_t input_length,
+ const uint8_t* ARROW_ARG_UNUSED(input)) override {
+ // Must be in compression mode
+ if (!compressor_initialized_) {
+ Status s = InitCompressor();
+ ARROW_CHECK_OK(s);
+ }
+ int64_t max_len = deflateBound(&stream_, static_cast<uLong>(input_length));
+ // ARROW-3514: return a more pessimistic estimate to account for bugs
+ // in old zlib versions.
+ return max_len + 12;
+ }
+
+ Result<int64_t> Compress(int64_t input_length, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output) override {
+ if (!compressor_initialized_) {
+ RETURN_NOT_OK(InitCompressor());
+ }
+ stream_.next_in = const_cast<Bytef*>(reinterpret_cast<const Bytef*>(input));
+ stream_.avail_in = static_cast<uInt>(input_length);
+ stream_.next_out = reinterpret_cast<Bytef*>(output);
+ stream_.avail_out = static_cast<uInt>(output_buffer_len);
+
+ int64_t ret = 0;
+ if ((ret = deflate(&stream_, Z_FINISH)) != Z_STREAM_END) {
+ if (ret == Z_OK) {
+ // Will return Z_OK (and stream.msg NOT set) if stream.avail_out is too
+ // small
+ return Status::IOError("zlib deflate failed, output buffer too small");
+ }
+
+ return ZlibErrorPrefix("zlib deflate failed: ", stream_.msg);
+ }
+
+ if (deflateReset(&stream_) != Z_OK) {
+ return ZlibErrorPrefix("zlib deflateReset failed: ", stream_.msg);
+ }
+
+ // Actual output length
+ return output_buffer_len - stream_.avail_out;
+ }
+
+ Status Init() override {
+ const Status init_compressor_status = InitCompressor();
+ if (!init_compressor_status.ok()) {
+ return init_compressor_status;
+ }
+ return InitDecompressor();
+ }
+
+ Compression::type compression_type() const override { return Compression::GZIP; }
+
+ int compression_level() const override { return compression_level_; }
+ int minimum_compression_level() const override { return kGZipMinCompressionLevel; }
+ int maximum_compression_level() const override { return kGZipMaxCompressionLevel; }
+ int default_compression_level() const override { return kGZipDefaultCompressionLevel; }
+
+ private:
+ // zlib is stateful and the z_stream state variable must be initialized
+ // before
+ z_stream stream_;
+
+ // Realistically, this will always be GZIP, but we leave the option open to
+ // configure
+ GZipFormat::type format_;
+
+ // These variables are mutually exclusive. When the codec is in "compressor"
+ // state, compressor_initialized_ is true while decompressor_initialized_ is
+ // false. When it's decompressing, the opposite is true.
+ //
+ // Indeed, this is slightly hacky, but the alternative is having separate
+ // Compressor and Decompressor classes. If this ever becomes an issue, we can
+ // perform the refactoring then
+ bool compressor_initialized_;
+ bool decompressor_initialized_;
+ int compression_level_;
+};
+
+} // namespace
+
+std::unique_ptr<Codec> MakeGZipCodec(int compression_level, GZipFormat::type format) {
+ return std::unique_ptr<Codec>(new GZipCodec(compression_level, format));
+}
+
+} // namespace internal
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/compression_zstd.cc b/src/arrow/cpp/src/arrow/util/compression_zstd.cc
new file mode 100644
index 000000000..e15ecb4e1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/compression_zstd.cc
@@ -0,0 +1,249 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/compression_internal.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+
+#include <zstd.h>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+using std::size_t;
+
+namespace arrow {
+namespace util {
+namespace internal {
+
+namespace {
+
+Status ZSTDError(size_t ret, const char* prefix_msg) {
+ return Status::IOError(prefix_msg, ZSTD_getErrorName(ret));
+}
+
+// ----------------------------------------------------------------------
+// ZSTD decompressor implementation
+
+class ZSTDDecompressor : public Decompressor {
+ public:
+ ZSTDDecompressor() : stream_(ZSTD_createDStream()) {}
+
+ ~ZSTDDecompressor() override { ZSTD_freeDStream(stream_); }
+
+ Status Init() {
+ finished_ = false;
+ size_t ret = ZSTD_initDStream(stream_);
+ if (ZSTD_isError(ret)) {
+ return ZSTDError(ret, "ZSTD init failed: ");
+ } else {
+ return Status::OK();
+ }
+ }
+
+ Result<DecompressResult> Decompress(int64_t input_len, const uint8_t* input,
+ int64_t output_len, uint8_t* output) override {
+ ZSTD_inBuffer in_buf;
+ ZSTD_outBuffer out_buf;
+
+ in_buf.src = input;
+ in_buf.size = static_cast<size_t>(input_len);
+ in_buf.pos = 0;
+ out_buf.dst = output;
+ out_buf.size = static_cast<size_t>(output_len);
+ out_buf.pos = 0;
+
+ size_t ret;
+ ret = ZSTD_decompressStream(stream_, &out_buf, &in_buf);
+ if (ZSTD_isError(ret)) {
+ return ZSTDError(ret, "ZSTD decompress failed: ");
+ }
+ finished_ = (ret == 0);
+ return DecompressResult{static_cast<int64_t>(in_buf.pos),
+ static_cast<int64_t>(out_buf.pos),
+ in_buf.pos == 0 && out_buf.pos == 0};
+ }
+
+ Status Reset() override { return Init(); }
+
+ bool IsFinished() override { return finished_; }
+
+ protected:
+ ZSTD_DStream* stream_;
+ bool finished_;
+};
+
+// ----------------------------------------------------------------------
+// ZSTD compressor implementation
+
+class ZSTDCompressor : public Compressor {
+ public:
+ explicit ZSTDCompressor(int compression_level)
+ : stream_(ZSTD_createCStream()), compression_level_(compression_level) {}
+
+ ~ZSTDCompressor() override { ZSTD_freeCStream(stream_); }
+
+ Status Init() {
+ size_t ret = ZSTD_initCStream(stream_, compression_level_);
+ if (ZSTD_isError(ret)) {
+ return ZSTDError(ret, "ZSTD init failed: ");
+ } else {
+ return Status::OK();
+ }
+ }
+
+ Result<CompressResult> Compress(int64_t input_len, const uint8_t* input,
+ int64_t output_len, uint8_t* output) override {
+ ZSTD_inBuffer in_buf;
+ ZSTD_outBuffer out_buf;
+
+ in_buf.src = input;
+ in_buf.size = static_cast<size_t>(input_len);
+ in_buf.pos = 0;
+ out_buf.dst = output;
+ out_buf.size = static_cast<size_t>(output_len);
+ out_buf.pos = 0;
+
+ size_t ret;
+ ret = ZSTD_compressStream(stream_, &out_buf, &in_buf);
+ if (ZSTD_isError(ret)) {
+ return ZSTDError(ret, "ZSTD compress failed: ");
+ }
+ return CompressResult{static_cast<int64_t>(in_buf.pos),
+ static_cast<int64_t>(out_buf.pos)};
+ }
+
+ Result<FlushResult> Flush(int64_t output_len, uint8_t* output) override {
+ ZSTD_outBuffer out_buf;
+
+ out_buf.dst = output;
+ out_buf.size = static_cast<size_t>(output_len);
+ out_buf.pos = 0;
+
+ size_t ret;
+ ret = ZSTD_flushStream(stream_, &out_buf);
+ if (ZSTD_isError(ret)) {
+ return ZSTDError(ret, "ZSTD flush failed: ");
+ }
+ return FlushResult{static_cast<int64_t>(out_buf.pos), ret > 0};
+ }
+
+ Result<EndResult> End(int64_t output_len, uint8_t* output) override {
+ ZSTD_outBuffer out_buf;
+
+ out_buf.dst = output;
+ out_buf.size = static_cast<size_t>(output_len);
+ out_buf.pos = 0;
+
+ size_t ret;
+ ret = ZSTD_endStream(stream_, &out_buf);
+ if (ZSTD_isError(ret)) {
+ return ZSTDError(ret, "ZSTD end failed: ");
+ }
+ return EndResult{static_cast<int64_t>(out_buf.pos), ret > 0};
+ }
+
+ protected:
+ ZSTD_CStream* stream_;
+
+ private:
+ int compression_level_;
+};
+
+// ----------------------------------------------------------------------
+// ZSTD codec implementation
+
+class ZSTDCodec : public Codec {
+ public:
+ explicit ZSTDCodec(int compression_level)
+ : compression_level_(compression_level == kUseDefaultCompressionLevel
+ ? kZSTDDefaultCompressionLevel
+ : compression_level) {}
+
+ Result<int64_t> Decompress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output_buffer) override {
+ if (output_buffer == nullptr) {
+ // We may pass a NULL 0-byte output buffer but some zstd versions demand
+ // a valid pointer: https://github.com/facebook/zstd/issues/1385
+ static uint8_t empty_buffer;
+ DCHECK_EQ(output_buffer_len, 0);
+ output_buffer = &empty_buffer;
+ }
+
+ size_t ret = ZSTD_decompress(output_buffer, static_cast<size_t>(output_buffer_len),
+ input, static_cast<size_t>(input_len));
+ if (ZSTD_isError(ret)) {
+ return ZSTDError(ret, "ZSTD decompression failed: ");
+ }
+ if (static_cast<int64_t>(ret) != output_buffer_len) {
+ return Status::IOError("Corrupt ZSTD compressed data.");
+ }
+ return static_cast<int64_t>(ret);
+ }
+
+ int64_t MaxCompressedLen(int64_t input_len,
+ const uint8_t* ARROW_ARG_UNUSED(input)) override {
+ DCHECK_GE(input_len, 0);
+ return ZSTD_compressBound(static_cast<size_t>(input_len));
+ }
+
+ Result<int64_t> Compress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len, uint8_t* output_buffer) override {
+ size_t ret = ZSTD_compress(output_buffer, static_cast<size_t>(output_buffer_len),
+ input, static_cast<size_t>(input_len), compression_level_);
+ if (ZSTD_isError(ret)) {
+ return ZSTDError(ret, "ZSTD compression failed: ");
+ }
+ return static_cast<int64_t>(ret);
+ }
+
+ Result<std::shared_ptr<Compressor>> MakeCompressor() override {
+ auto ptr = std::make_shared<ZSTDCompressor>(compression_level_);
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+ }
+
+ Result<std::shared_ptr<Decompressor>> MakeDecompressor() override {
+ auto ptr = std::make_shared<ZSTDDecompressor>();
+ RETURN_NOT_OK(ptr->Init());
+ return ptr;
+ }
+
+ Compression::type compression_type() const override { return Compression::ZSTD; }
+ int minimum_compression_level() const override { return ZSTD_minCLevel(); }
+ int maximum_compression_level() const override { return ZSTD_maxCLevel(); }
+ int default_compression_level() const override { return kZSTDDefaultCompressionLevel; }
+
+ int compression_level() const override { return compression_level_; }
+
+ private:
+ const int compression_level_;
+};
+
+} // namespace
+
+std::unique_ptr<Codec> MakeZSTDCodec(int compression_level) {
+ return std::unique_ptr<Codec>(new ZSTDCodec(compression_level));
+}
+
+} // namespace internal
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/concurrent_map.h b/src/arrow/cpp/src/arrow/util/concurrent_map.h
new file mode 100644
index 000000000..ff1584552
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/concurrent_map.h
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <unordered_map>
+#include <utility>
+
+#include "arrow/util/mutex.h"
+
+namespace arrow {
+namespace util {
+
+template <typename K, typename V>
+class ConcurrentMap {
+ public:
+ void Insert(const K& key, const V& value) {
+ auto lock = mutex_.Lock();
+ map_.insert({key, value});
+ }
+
+ template <typename ValueFunc>
+ V GetOrInsert(const K& key, ValueFunc&& compute_value_func) {
+ auto lock = mutex_.Lock();
+ auto it = map_.find(key);
+ if (it == map_.end()) {
+ auto pair = map_.emplace(key, compute_value_func());
+ it = pair.first;
+ }
+ return it->second;
+ }
+
+ void Erase(const K& key) {
+ auto lock = mutex_.Lock();
+ map_.erase(key);
+ }
+
+ void Clear() {
+ auto lock = mutex_.Lock();
+ map_.clear();
+ }
+
+ size_t size() const {
+ auto lock = mutex_.Lock();
+ return map_.size();
+ }
+
+ private:
+ std::unordered_map<K, V> map_;
+ mutable arrow::util::Mutex mutex_;
+};
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/config.h.cmake b/src/arrow/cpp/src/arrow/util/config.h.cmake
new file mode 100644
index 000000000..de3b03ccb
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/config.h.cmake
@@ -0,0 +1,48 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#define ARROW_VERSION_MAJOR @ARROW_VERSION_MAJOR@
+#define ARROW_VERSION_MINOR @ARROW_VERSION_MINOR@
+#define ARROW_VERSION_PATCH @ARROW_VERSION_PATCH@
+#define ARROW_VERSION ((ARROW_VERSION_MAJOR * 1000) + ARROW_VERSION_MINOR) * 1000 + ARROW_VERSION_PATCH
+
+#define ARROW_VERSION_STRING "@ARROW_VERSION@"
+
+#define ARROW_SO_VERSION "@ARROW_SO_VERSION@"
+#define ARROW_FULL_SO_VERSION "@ARROW_FULL_SO_VERSION@"
+
+#define ARROW_CXX_COMPILER_ID "@CMAKE_CXX_COMPILER_ID@"
+#define ARROW_CXX_COMPILER_VERSION "@CMAKE_CXX_COMPILER_VERSION@"
+#define ARROW_CXX_COMPILER_FLAGS "@CMAKE_CXX_FLAGS@"
+
+#define ARROW_GIT_ID "@ARROW_GIT_ID@"
+#define ARROW_GIT_DESCRIPTION "@ARROW_GIT_DESCRIPTION@"
+
+#define ARROW_PACKAGE_KIND "@ARROW_PACKAGE_KIND@"
+
+#cmakedefine ARROW_COMPUTE
+#cmakedefine ARROW_CSV
+#cmakedefine ARROW_DATASET
+#cmakedefine ARROW_FILESYSTEM
+#cmakedefine ARROW_FLIGHT
+#cmakedefine ARROW_IPC
+#cmakedefine ARROW_JSON
+
+#cmakedefine ARROW_S3
+#cmakedefine ARROW_USE_NATIVE_INT128
+
+#cmakedefine GRPCPP_PP_INCLUDE
diff --git a/src/arrow/cpp/src/arrow/util/converter.h b/src/arrow/cpp/src/arrow/util/converter.h
new file mode 100644
index 000000000..0b29e0f5b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/converter.h
@@ -0,0 +1,411 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/chunked_array.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+namespace internal {
+
+template <typename BaseConverter, template <typename...> class ConverterTrait>
+static Result<std::unique_ptr<BaseConverter>> MakeConverter(
+ std::shared_ptr<DataType> type, typename BaseConverter::OptionsType options,
+ MemoryPool* pool);
+
+template <typename Input, typename Options>
+class Converter {
+ public:
+ using Self = Converter<Input, Options>;
+ using InputType = Input;
+ using OptionsType = Options;
+
+ virtual ~Converter() = default;
+
+ Status Construct(std::shared_ptr<DataType> type, OptionsType options,
+ MemoryPool* pool) {
+ type_ = std::move(type);
+ options_ = std::move(options);
+ return Init(pool);
+ }
+
+ virtual Status Append(InputType value) { return Status::NotImplemented("Append"); }
+
+ virtual Status Extend(InputType values, int64_t size, int64_t offset = 0) {
+ return Status::NotImplemented("Extend");
+ }
+
+ virtual Status ExtendMasked(InputType values, InputType mask, int64_t size,
+ int64_t offset = 0) {
+ return Status::NotImplemented("ExtendMasked");
+ }
+
+ const std::shared_ptr<ArrayBuilder>& builder() const { return builder_; }
+
+ const std::shared_ptr<DataType>& type() const { return type_; }
+
+ OptionsType options() const { return options_; }
+
+ bool may_overflow() const { return may_overflow_; }
+
+ bool rewind_on_overflow() const { return rewind_on_overflow_; }
+
+ virtual Status Reserve(int64_t additional_capacity) {
+ return builder_->Reserve(additional_capacity);
+ }
+
+ Status AppendNull() { return builder_->AppendNull(); }
+
+ virtual Result<std::shared_ptr<Array>> ToArray() { return builder_->Finish(); }
+
+ virtual Result<std::shared_ptr<Array>> ToArray(int64_t length) {
+ ARROW_ASSIGN_OR_RAISE(auto arr, this->ToArray());
+ return arr->Slice(0, length);
+ }
+
+ virtual Result<std::shared_ptr<ChunkedArray>> ToChunkedArray() {
+ ARROW_ASSIGN_OR_RAISE(auto array, ToArray());
+ std::vector<std::shared_ptr<Array>> chunks = {std::move(array)};
+ return std::make_shared<ChunkedArray>(chunks);
+ }
+
+ protected:
+ virtual Status Init(MemoryPool* pool) { return Status::OK(); }
+
+ std::shared_ptr<DataType> type_;
+ std::shared_ptr<ArrayBuilder> builder_;
+ OptionsType options_;
+ bool may_overflow_ = false;
+ bool rewind_on_overflow_ = false;
+};
+
+template <typename ArrowType, typename BaseConverter>
+class PrimitiveConverter : public BaseConverter {
+ public:
+ using BuilderType = typename TypeTraits<ArrowType>::BuilderType;
+
+ protected:
+ Status Init(MemoryPool* pool) override {
+ this->builder_ = std::make_shared<BuilderType>(this->type_, pool);
+ // Narrow variable-sized binary types may overflow
+ this->may_overflow_ = is_binary_like(this->type_->id());
+ primitive_type_ = checked_cast<const ArrowType*>(this->type_.get());
+ primitive_builder_ = checked_cast<BuilderType*>(this->builder_.get());
+ return Status::OK();
+ }
+
+ const ArrowType* primitive_type_;
+ BuilderType* primitive_builder_;
+};
+
+template <typename ArrowType, typename BaseConverter,
+ template <typename...> class ConverterTrait>
+class ListConverter : public BaseConverter {
+ public:
+ using BuilderType = typename TypeTraits<ArrowType>::BuilderType;
+ using ConverterType = typename ConverterTrait<ArrowType>::type;
+
+ protected:
+ Status Init(MemoryPool* pool) override {
+ list_type_ = checked_cast<const ArrowType*>(this->type_.get());
+ ARROW_ASSIGN_OR_RAISE(value_converter_,
+ (MakeConverter<BaseConverter, ConverterTrait>(
+ list_type_->value_type(), this->options_, pool)));
+ this->builder_ =
+ std::make_shared<BuilderType>(pool, value_converter_->builder(), this->type_);
+ list_builder_ = checked_cast<BuilderType*>(this->builder_.get());
+ // Narrow list types may overflow
+ this->may_overflow_ = this->rewind_on_overflow_ =
+ sizeof(typename ArrowType::offset_type) < sizeof(int64_t);
+ return Status::OK();
+ }
+
+ const ArrowType* list_type_;
+ BuilderType* list_builder_;
+ std::unique_ptr<BaseConverter> value_converter_;
+};
+
+template <typename BaseConverter, template <typename...> class ConverterTrait>
+class StructConverter : public BaseConverter {
+ public:
+ using ConverterType = typename ConverterTrait<StructType>::type;
+
+ Status Reserve(int64_t additional_capacity) override {
+ ARROW_RETURN_NOT_OK(this->builder_->Reserve(additional_capacity));
+ for (const auto& child : children_) {
+ ARROW_RETURN_NOT_OK(child->Reserve(additional_capacity));
+ }
+ return Status::OK();
+ }
+
+ protected:
+ Status Init(MemoryPool* pool) override {
+ std::unique_ptr<BaseConverter> child_converter;
+ std::vector<std::shared_ptr<ArrayBuilder>> child_builders;
+
+ struct_type_ = checked_cast<const StructType*>(this->type_.get());
+ for (const auto& field : struct_type_->fields()) {
+ ARROW_ASSIGN_OR_RAISE(child_converter,
+ (MakeConverter<BaseConverter, ConverterTrait>(
+ field->type(), this->options_, pool)));
+ this->may_overflow_ |= child_converter->may_overflow();
+ this->rewind_on_overflow_ = this->may_overflow_;
+ child_builders.push_back(child_converter->builder());
+ children_.push_back(std::move(child_converter));
+ }
+
+ this->builder_ =
+ std::make_shared<StructBuilder>(this->type_, pool, std::move(child_builders));
+ struct_builder_ = checked_cast<StructBuilder*>(this->builder_.get());
+
+ return Status::OK();
+ }
+
+ const StructType* struct_type_;
+ StructBuilder* struct_builder_;
+ std::vector<std::unique_ptr<BaseConverter>> children_;
+};
+
+template <typename ValueType, typename BaseConverter>
+class DictionaryConverter : public BaseConverter {
+ public:
+ using BuilderType = DictionaryBuilder<ValueType>;
+
+ protected:
+ Status Init(MemoryPool* pool) override {
+ std::unique_ptr<ArrayBuilder> builder;
+ ARROW_RETURN_NOT_OK(MakeDictionaryBuilder(pool, this->type_, NULLPTR, &builder));
+ this->builder_ = std::move(builder);
+ this->may_overflow_ = false;
+ dict_type_ = checked_cast<const DictionaryType*>(this->type_.get());
+ value_type_ = checked_cast<const ValueType*>(dict_type_->value_type().get());
+ value_builder_ = checked_cast<BuilderType*>(this->builder_.get());
+ return Status::OK();
+ }
+
+ const DictionaryType* dict_type_;
+ const ValueType* value_type_;
+ BuilderType* value_builder_;
+};
+
+template <typename BaseConverter, template <typename...> class ConverterTrait>
+struct MakeConverterImpl {
+ template <typename T, typename ConverterType = typename ConverterTrait<T>::type>
+ Status Visit(const T&) {
+ out.reset(new ConverterType());
+ return out->Construct(std::move(type), std::move(options), pool);
+ }
+
+ Status Visit(const DictionaryType& t) {
+ switch (t.value_type()->id()) {
+#define DICTIONARY_CASE(TYPE) \
+ case TYPE::type_id: \
+ out = internal::make_unique< \
+ typename ConverterTrait<DictionaryType>::template dictionary_type<TYPE>>(); \
+ break;
+ DICTIONARY_CASE(BooleanType);
+ DICTIONARY_CASE(Int8Type);
+ DICTIONARY_CASE(Int16Type);
+ DICTIONARY_CASE(Int32Type);
+ DICTIONARY_CASE(Int64Type);
+ DICTIONARY_CASE(UInt8Type);
+ DICTIONARY_CASE(UInt16Type);
+ DICTIONARY_CASE(UInt32Type);
+ DICTIONARY_CASE(UInt64Type);
+ DICTIONARY_CASE(FloatType);
+ DICTIONARY_CASE(DoubleType);
+ DICTIONARY_CASE(BinaryType);
+ DICTIONARY_CASE(StringType);
+ DICTIONARY_CASE(FixedSizeBinaryType);
+#undef DICTIONARY_CASE
+ default:
+ return Status::NotImplemented("DictionaryArray converter for type ", t.ToString(),
+ " not implemented");
+ }
+ return out->Construct(std::move(type), std::move(options), pool);
+ }
+
+ Status Visit(const DataType& t) { return Status::NotImplemented(t.name()); }
+
+ std::shared_ptr<DataType> type;
+ typename BaseConverter::OptionsType options;
+ MemoryPool* pool;
+ std::unique_ptr<BaseConverter> out;
+};
+
+template <typename BaseConverter, template <typename...> class ConverterTrait>
+static Result<std::unique_ptr<BaseConverter>> MakeConverter(
+ std::shared_ptr<DataType> type, typename BaseConverter::OptionsType options,
+ MemoryPool* pool) {
+ MakeConverterImpl<BaseConverter, ConverterTrait> visitor{
+ std::move(type), std::move(options), pool, NULLPTR};
+ ARROW_RETURN_NOT_OK(VisitTypeInline(*visitor.type, &visitor));
+ return std::move(visitor.out);
+}
+
+template <typename Converter>
+class Chunker {
+ public:
+ using InputType = typename Converter::InputType;
+
+ explicit Chunker(std::unique_ptr<Converter> converter)
+ : converter_(std::move(converter)) {}
+
+ Status Reserve(int64_t additional_capacity) {
+ ARROW_RETURN_NOT_OK(converter_->Reserve(additional_capacity));
+ reserved_ += additional_capacity;
+ return Status::OK();
+ }
+
+ Status AppendNull() {
+ auto status = converter_->AppendNull();
+ if (ARROW_PREDICT_FALSE(status.IsCapacityError())) {
+ if (converter_->builder()->length() == 0) {
+ // Builder length == 0 means the individual element is too large to append.
+ // In this case, no need to try again.
+ return status;
+ }
+ ARROW_RETURN_NOT_OK(FinishChunk());
+ return converter_->AppendNull();
+ }
+ ++length_;
+ return status;
+ }
+
+ Status Append(InputType value) {
+ auto status = converter_->Append(value);
+ if (ARROW_PREDICT_FALSE(status.IsCapacityError())) {
+ if (converter_->builder()->length() == 0) {
+ return status;
+ }
+ ARROW_RETURN_NOT_OK(FinishChunk());
+ return Append(value);
+ }
+ ++length_;
+ return status;
+ }
+
+ Status Extend(InputType values, int64_t size, int64_t offset = 0) {
+ while (offset < size) {
+ auto length_before = converter_->builder()->length();
+ auto status = converter_->Extend(values, size, offset);
+ auto length_after = converter_->builder()->length();
+ auto num_converted = length_after - length_before;
+
+ offset += num_converted;
+ length_ += num_converted;
+
+ if (status.IsCapacityError()) {
+ if (converter_->builder()->length() == 0) {
+ // Builder length == 0 means the individual element is too large to append.
+ // In this case, no need to try again.
+ return status;
+ } else if (converter_->rewind_on_overflow()) {
+ // The list-like and binary-like conversion paths may raise a capacity error,
+ // we need to handle them differently. While the binary-like converters check
+ // the capacity before append/extend the list-like converters just check after
+ // append/extend. Thus depending on the implementation semantics we may need
+ // to rewind (slice) the output chunk by one.
+ length_ -= 1;
+ offset -= 1;
+ }
+ ARROW_RETURN_NOT_OK(FinishChunk());
+ } else if (!status.ok()) {
+ return status;
+ }
+ }
+ return Status::OK();
+ }
+
+ Status ExtendMasked(InputType values, InputType mask, int64_t size,
+ int64_t offset = 0) {
+ while (offset < size) {
+ auto length_before = converter_->builder()->length();
+ auto status = converter_->ExtendMasked(values, mask, size, offset);
+ auto length_after = converter_->builder()->length();
+ auto num_converted = length_after - length_before;
+
+ offset += num_converted;
+ length_ += num_converted;
+
+ if (status.IsCapacityError()) {
+ if (converter_->builder()->length() == 0) {
+ // Builder length == 0 means the individual element is too large to append.
+ // In this case, no need to try again.
+ return status;
+ } else if (converter_->rewind_on_overflow()) {
+ // The list-like and binary-like conversion paths may raise a capacity error,
+ // we need to handle them differently. While the binary-like converters check
+ // the capacity before append/extend the list-like converters just check after
+ // append/extend. Thus depending on the implementation semantics we may need
+ // to rewind (slice) the output chunk by one.
+ length_ -= 1;
+ offset -= 1;
+ }
+ ARROW_RETURN_NOT_OK(FinishChunk());
+ } else if (!status.ok()) {
+ return status;
+ }
+ }
+ return Status::OK();
+ }
+
+ Status FinishChunk() {
+ ARROW_ASSIGN_OR_RAISE(auto chunk, converter_->ToArray(length_));
+ chunks_.push_back(chunk);
+ // Reserve space for the remaining items.
+ // Besides being an optimization, it is also required if the converter's
+ // implementation relies on unsafe builder methods in converter->Append().
+ auto remaining = reserved_ - length_;
+ Reset();
+ return Reserve(remaining);
+ }
+
+ Result<std::shared_ptr<ChunkedArray>> ToChunkedArray() {
+ ARROW_RETURN_NOT_OK(FinishChunk());
+ return std::make_shared<ChunkedArray>(chunks_);
+ }
+
+ protected:
+ void Reset() {
+ converter_->builder()->Reset();
+ length_ = 0;
+ reserved_ = 0;
+ }
+
+ int64_t length_ = 0;
+ int64_t reserved_ = 0;
+ std::unique_ptr<Converter> converter_;
+ std::vector<std::shared_ptr<Array>> chunks_;
+};
+
+template <typename T>
+static Result<std::unique_ptr<Chunker<T>>> MakeChunker(std::unique_ptr<T> converter) {
+ return internal::make_unique<Chunker<T>>(std::move(converter));
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/counting_semaphore.cc b/src/arrow/cpp/src/arrow/util/counting_semaphore.cc
new file mode 100644
index 000000000..b3106a6f8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/counting_semaphore.cc
@@ -0,0 +1,126 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/counting_semaphore.h"
+
+#include <chrono>
+#include <condition_variable>
+#include <cstdint>
+#include <iostream>
+#include <mutex>
+
+#include "arrow/status.h"
+
+namespace arrow {
+namespace util {
+
+class CountingSemaphore::Impl {
+ public:
+ Impl(uint32_t initial_avail, double timeout_seconds)
+ : num_permits_(initial_avail), timeout_seconds_(timeout_seconds) {}
+
+ Status Acquire(uint32_t num_permits) {
+ std::unique_lock<std::mutex> lk(mutex_);
+ RETURN_NOT_OK(CheckClosed());
+ num_waiters_ += num_permits;
+ waiter_cv_.notify_all();
+ bool timed_out = !acquirer_cv_.wait_for(
+ lk, std::chrono::nanoseconds(static_cast<int64_t>(timeout_seconds_ * 1e9)),
+ [&] { return closed_ || num_permits <= num_permits_; });
+ num_waiters_ -= num_permits;
+ if (timed_out) {
+ return Status::Invalid("Timed out waiting for semaphore to release ", num_permits,
+ " permits.");
+ }
+ if (closed_) {
+ return Status::Invalid("Semaphore closed while acquiring");
+ }
+ num_permits_ -= num_permits;
+ return Status::OK();
+ }
+
+ Status Release(uint32_t num_permits) {
+ std::lock_guard<std::mutex> lg(mutex_);
+ RETURN_NOT_OK(CheckClosed());
+ num_permits_ += num_permits;
+ acquirer_cv_.notify_all();
+ return Status::OK();
+ }
+
+ Status WaitForWaiters(uint32_t num_waiters) {
+ std::unique_lock<std::mutex> lk(mutex_);
+ RETURN_NOT_OK(CheckClosed());
+ if (waiter_cv_.wait_for(
+ lk, std::chrono::nanoseconds(static_cast<int64_t>(timeout_seconds_ * 1e9)),
+ [&] { return closed_ || num_waiters <= num_waiters_; })) {
+ if (closed_) {
+ return Status::Invalid("Semaphore closed while waiting for waiters");
+ }
+ return Status::OK();
+ }
+ return Status::Invalid("Timed out waiting for ", num_waiters,
+ " to start waiting on semaphore");
+ }
+
+ Status Close() {
+ std::lock_guard<std::mutex> lg(mutex_);
+ RETURN_NOT_OK(CheckClosed());
+ closed_ = true;
+ if (num_waiters_ > 0) {
+ waiter_cv_.notify_all();
+ acquirer_cv_.notify_all();
+ return Status::Invalid(
+ "There were one or more threads waiting on a semaphore when it was closed");
+ }
+ return Status::OK();
+ }
+
+ private:
+ Status CheckClosed() const {
+ if (closed_) {
+ return Status::Invalid("Invalid operation on closed semaphore");
+ }
+ return Status::OK();
+ }
+
+ uint32_t num_permits_;
+ double timeout_seconds_;
+ uint32_t num_waiters_ = 0;
+ bool closed_ = false;
+ std::mutex mutex_;
+ std::condition_variable acquirer_cv_;
+ std::condition_variable waiter_cv_;
+};
+
+CountingSemaphore::CountingSemaphore(uint32_t initial_avail, double timeout_seconds)
+ : impl_(new Impl(initial_avail, timeout_seconds)) {}
+
+CountingSemaphore::~CountingSemaphore() = default;
+
+Status CountingSemaphore::Acquire(uint32_t num_permits) {
+ return impl_->Acquire(num_permits);
+}
+Status CountingSemaphore::Release(uint32_t num_permits) {
+ return impl_->Release(num_permits);
+}
+Status CountingSemaphore::WaitForWaiters(uint32_t num_waiters) {
+ return impl_->WaitForWaiters(num_waiters);
+}
+Status CountingSemaphore::Close() { return impl_->Close(); }
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/counting_semaphore.h b/src/arrow/cpp/src/arrow/util/counting_semaphore.h
new file mode 100644
index 000000000..a3c13cc3b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/counting_semaphore.h
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef ARROW_COUNTING_SEMAPHORE_H
+#define ARROW_COUNTING_SEMAPHORE_H
+
+#include <memory>
+
+#include "arrow/status.h"
+
+namespace arrow {
+namespace util {
+
+/// \brief Simple mutex-based counting semaphore with timeout
+class ARROW_EXPORT CountingSemaphore {
+ public:
+ /// \brief Create an instance with initial_avail starting permits
+ ///
+ /// \param[in] initial_avail The semaphore will start with this many permits available
+ /// \param[in] timeout_seconds A timeout to be applied to all operations. Operations
+ /// will return Status::Invalid if this timeout elapses
+ explicit CountingSemaphore(uint32_t initial_avail = 0, double timeout_seconds = 10);
+ ~CountingSemaphore();
+ /// \brief Block until num_permits permits are available
+ Status Acquire(uint32_t num_permits);
+ /// \brief Make num_permits permits available
+ Status Release(uint32_t num_permits);
+ /// \brief Wait until num_waiters are waiting on permits
+ ///
+ /// This method is non-standard but useful in unit tests to ensure sequencing
+ Status WaitForWaiters(uint32_t num_waiters);
+ /// \brief Immediately time out any waiters
+ ///
+ /// This method will return Status::OK only if there were no waiters to time out.
+ /// Once closed any operation on this instance will return an invalid status.
+ Status Close();
+
+ private:
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+} // namespace util
+} // namespace arrow
+
+#endif // ARROW_COUNTING_SEMAPHORE_H
diff --git a/src/arrow/cpp/src/arrow/util/counting_semaphore_test.cc b/src/arrow/cpp/src/arrow/util/counting_semaphore_test.cc
new file mode 100644
index 000000000..a5fa9f6bd
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/counting_semaphore_test.cc
@@ -0,0 +1,98 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/counting_semaphore.h"
+
+#include <atomic>
+#include <thread>
+#include <vector>
+
+#include "arrow/testing/gtest_util.h"
+#include "gtest/gtest.h"
+
+namespace arrow {
+namespace util {
+
+TEST(CountingSemaphore, Basic) {
+ CountingSemaphore semaphore;
+ std::atomic<bool> acquired{false};
+ std::atomic<bool> started{false};
+ std::thread acquirer([&] {
+ started.store(true);
+ ASSERT_OK(semaphore.Acquire(3));
+ acquired = true;
+ });
+ ASSERT_OK(semaphore.WaitForWaiters(1));
+ ASSERT_TRUE(started.load());
+ ASSERT_FALSE(acquired.load());
+ ASSERT_OK(semaphore.Release(2));
+ SleepABit();
+ ASSERT_FALSE(acquired.load());
+ ASSERT_OK(semaphore.Release(1));
+ BusyWait(10, [&] { return acquired.load(); });
+ ASSERT_TRUE(acquired.load());
+ ASSERT_OK(semaphore.Close());
+ acquirer.join();
+}
+
+TEST(CountingSemaphore, CloseAborts) {
+ CountingSemaphore semaphore;
+ std::atomic<bool> cleanup{false};
+ std::thread acquirer([&] {
+ ASSERT_RAISES(Invalid, semaphore.Acquire(1));
+ cleanup = true;
+ });
+ ASSERT_OK(semaphore.WaitForWaiters(1));
+ ASSERT_FALSE(cleanup.load());
+ ASSERT_RAISES(Invalid, semaphore.Close());
+ BusyWait(10, [&] { return cleanup.load(); });
+ acquirer.join();
+}
+
+TEST(CountingSemaphore, Stress) {
+ constexpr uint32_t NTHREADS = 10;
+ CountingSemaphore semaphore;
+ std::vector<uint32_t> max_allowed_cases = {1, 3};
+ std::atomic<uint32_t> count{0};
+ std::atomic<bool> max_exceeded{false};
+ std::vector<std::thread> threads;
+ for (uint32_t max_allowed : max_allowed_cases) {
+ ASSERT_OK(semaphore.Release(max_allowed));
+ for (uint32_t i = 0; i < NTHREADS; i++) {
+ threads.emplace_back([&] {
+ ASSERT_OK(semaphore.Acquire(1));
+ uint32_t last_count = count.fetch_add(1);
+ if (last_count >= max_allowed) {
+ max_exceeded.store(true);
+ }
+ SleepABit();
+ count.fetch_sub(1);
+ ASSERT_OK(semaphore.Release(1));
+ });
+ }
+ for (auto& thread : threads) {
+ thread.join();
+ }
+ threads.clear();
+ ASSERT_OK(semaphore.Acquire(max_allowed));
+ }
+ ASSERT_OK(semaphore.Close());
+ ASSERT_FALSE(max_exceeded.load());
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/cpu_info.cc b/src/arrow/cpp/src/arrow/util/cpu_info.cc
new file mode 100644
index 000000000..d803521a2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/cpu_info.cc
@@ -0,0 +1,563 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// From Apache Impala (incubating) as of 2016-01-29.
+
+#include "arrow/util/cpu_info.h"
+
+#ifdef __APPLE__
+#include <sys/sysctl.h>
+#endif
+
+#include <stdlib.h>
+#include <string.h>
+
+#ifndef _MSC_VER
+#include <unistd.h>
+#endif
+
+#ifdef _WIN32
+#include <immintrin.h>
+#include <intrin.h>
+#include <array>
+#include <bitset>
+
+#include "arrow/util/windows_compatibility.h"
+#endif
+
+#include <algorithm>
+#include <cctype>
+#include <cerrno>
+#include <cstdint>
+#include <fstream>
+#include <memory>
+#include <mutex>
+#include <string>
+
+#include "arrow/result.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/string.h"
+
+namespace arrow {
+namespace internal {
+
+namespace {
+
+using std::max;
+
+constexpr int64_t kDefaultL1CacheSize = 32 * 1024; // Level 1: 32k
+constexpr int64_t kDefaultL2CacheSize = 256 * 1024; // Level 2: 256k
+constexpr int64_t kDefaultL3CacheSize = 3072 * 1024; // Level 3: 3M
+
+#if defined(__MINGW64_VERSION_MAJOR) && __MINGW64_VERSION_MAJOR < 5
+void __cpuidex(int CPUInfo[4], int function_id, int subfunction_id) {
+ __asm__ __volatile__("cpuid"
+ : "=a"(CPUInfo[0]), "=b"(CPUInfo[1]), "=c"(CPUInfo[2]),
+ "=d"(CPUInfo[3])
+ : "a"(function_id), "c"(subfunction_id));
+}
+
+int64_t _xgetbv(int xcr) {
+ int out = 0;
+ __asm__ __volatile__("xgetbv" : "=a"(out) : "c"(xcr) : "%edx");
+ return out;
+}
+#endif
+
+#ifdef __APPLE__
+util::optional<int64_t> IntegerSysCtlByName(const char* name) {
+ size_t len = sizeof(int64_t);
+ int64_t data = 0;
+ if (sysctlbyname(name, &data, &len, nullptr, 0) == 0) {
+ return data;
+ }
+ // ENOENT is the official errno value for non-existing sysctl's,
+ // but EINVAL and ENOTSUP have been seen in the wild.
+ if (errno != ENOENT && errno != EINVAL && errno != ENOTSUP) {
+ auto st = IOErrorFromErrno(errno, "sysctlbyname failed for '", name, "'");
+ ARROW_LOG(WARNING) << st.ToString();
+ }
+ return util::nullopt;
+}
+#endif
+
+#if defined(__GNUC__) && defined(__linux__) && defined(__aarch64__)
+// There is no direct instruction to get cache size on Arm64 like '__cpuid' on x86;
+// Get Arm64 cache size by reading '/sys/devices/system/cpu/cpu0/cache/index*/size';
+// index* :
+// index0: L1 Dcache
+// index1: L1 Icache
+// index2: L2 cache
+// index3: L3 cache
+const char* kL1CacheSizeFile = "/sys/devices/system/cpu/cpu0/cache/index0/size";
+const char* kL2CacheSizeFile = "/sys/devices/system/cpu/cpu0/cache/index2/size";
+const char* kL3CacheSizeFile = "/sys/devices/system/cpu/cpu0/cache/index3/size";
+
+int64_t GetArm64CacheSize(const char* filename, int64_t default_size = -1) {
+ char* content = nullptr;
+ char* last_char = nullptr;
+ size_t file_len = 0;
+
+ // Read cache file to 'content' for getting cache size.
+ FILE* cache_file = fopen(filename, "r");
+ if (cache_file == nullptr) {
+ return default_size;
+ }
+ int res = getline(&content, &file_len, cache_file);
+ fclose(cache_file);
+ if (res == -1) {
+ return default_size;
+ }
+ std::unique_ptr<char, decltype(&free)> content_guard(content, &free);
+
+ errno = 0;
+ const auto cardinal_num = strtoull(content, &last_char, 0);
+ if (errno != 0) {
+ return default_size;
+ }
+ // kB, MB, or GB
+ int64_t multip = 1;
+ switch (*last_char) {
+ case 'g':
+ case 'G':
+ multip *= 1024;
+ case 'm':
+ case 'M':
+ multip *= 1024;
+ case 'k':
+ case 'K':
+ multip *= 1024;
+ }
+ return cardinal_num * multip;
+}
+#endif
+
+#if !defined(_WIN32) && !defined(__APPLE__)
+struct {
+ std::string name;
+ int64_t flag;
+} flag_mappings[] = {
+#if (defined(__i386) || defined(_M_IX86) || defined(__x86_64__) || defined(_M_X64))
+ {"ssse3", CpuInfo::SSSE3}, {"sse4_1", CpuInfo::SSE4_1},
+ {"sse4_2", CpuInfo::SSE4_2}, {"popcnt", CpuInfo::POPCNT},
+ {"avx", CpuInfo::AVX}, {"avx2", CpuInfo::AVX2},
+ {"avx512f", CpuInfo::AVX512F}, {"avx512cd", CpuInfo::AVX512CD},
+ {"avx512vl", CpuInfo::AVX512VL}, {"avx512dq", CpuInfo::AVX512DQ},
+ {"avx512bw", CpuInfo::AVX512BW}, {"bmi1", CpuInfo::BMI1},
+ {"bmi2", CpuInfo::BMI2},
+#endif
+#if defined(__aarch64__)
+ {"asimd", CpuInfo::ASIMD},
+#endif
+};
+const int64_t num_flags = sizeof(flag_mappings) / sizeof(flag_mappings[0]);
+
+// Helper function to parse for hardware flags.
+// values contains a list of space-separated flags. check to see if the flags we
+// care about are present.
+// Returns a bitmap of flags.
+int64_t ParseCPUFlags(const std::string& values) {
+ int64_t flags = 0;
+ for (int i = 0; i < num_flags; ++i) {
+ if (values.find(flag_mappings[i].name) != std::string::npos) {
+ flags |= flag_mappings[i].flag;
+ }
+ }
+ return flags;
+}
+#endif
+
+#ifdef _WIN32
+bool RetrieveCacheSize(int64_t* cache_sizes) {
+ if (!cache_sizes) {
+ return false;
+ }
+ PSYSTEM_LOGICAL_PROCESSOR_INFORMATION buffer = nullptr;
+ PSYSTEM_LOGICAL_PROCESSOR_INFORMATION buffer_position = nullptr;
+ DWORD buffer_size = 0;
+ size_t offset = 0;
+ typedef BOOL(WINAPI * GetLogicalProcessorInformationFuncPointer)(void*, void*);
+ GetLogicalProcessorInformationFuncPointer func_pointer =
+ (GetLogicalProcessorInformationFuncPointer)GetProcAddress(
+ GetModuleHandle("kernel32"), "GetLogicalProcessorInformation");
+
+ if (!func_pointer) {
+ return false;
+ }
+
+ // Get buffer size
+ if (func_pointer(buffer, &buffer_size) && GetLastError() != ERROR_INSUFFICIENT_BUFFER)
+ return false;
+
+ buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION)malloc(buffer_size);
+
+ if (!buffer || !func_pointer(buffer, &buffer_size)) {
+ return false;
+ }
+
+ buffer_position = buffer;
+ while (offset + sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION) <= buffer_size) {
+ if (RelationCache == buffer_position->Relationship) {
+ PCACHE_DESCRIPTOR cache = &buffer_position->Cache;
+ if (cache->Level >= 1 && cache->Level <= 3) {
+ cache_sizes[cache->Level - 1] += cache->Size;
+ }
+ }
+ offset += sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION);
+ buffer_position++;
+ }
+
+ if (buffer) {
+ free(buffer);
+ }
+ return true;
+}
+
+// Source: https://en.wikipedia.org/wiki/CPUID
+bool RetrieveCPUInfo(int64_t* hardware_flags, std::string* model_name,
+ CpuInfo::Vendor* vendor) {
+ if (!hardware_flags || !model_name || !vendor) {
+ return false;
+ }
+ int register_EAX_id = 1;
+ int highest_valid_id = 0;
+ int highest_extended_valid_id = 0;
+ std::bitset<32> features_ECX;
+ std::array<int, 4> cpu_info;
+
+ // Get highest valid id
+ __cpuid(cpu_info.data(), 0);
+ highest_valid_id = cpu_info[0];
+ // HEX of "GenuineIntel": 47656E75 696E6549 6E74656C
+ // HEX of "AuthenticAMD": 41757468 656E7469 63414D44
+ if (cpu_info[1] == 0x756e6547 && cpu_info[2] == 0x49656e69 &&
+ cpu_info[3] == 0x6c65746e) {
+ *vendor = CpuInfo::Vendor::Intel;
+ } else if (cpu_info[1] == 0x68747541 && cpu_info[2] == 0x69746e65 &&
+ cpu_info[3] == 0x444d4163) {
+ *vendor = CpuInfo::Vendor::AMD;
+ }
+
+ if (highest_valid_id <= register_EAX_id) return false;
+
+ // EAX=1: Processor Info and Feature Bits
+ __cpuidex(cpu_info.data(), register_EAX_id, 0);
+ features_ECX = cpu_info[2];
+
+ // Get highest extended id
+ __cpuid(cpu_info.data(), 0x80000000);
+ highest_extended_valid_id = cpu_info[0];
+
+ // Retrieve CPU model name
+ if (highest_extended_valid_id >= static_cast<int>(0x80000004)) {
+ model_name->clear();
+ for (int i = 0x80000002; i <= static_cast<int>(0x80000004); ++i) {
+ __cpuidex(cpu_info.data(), i, 0);
+ *model_name +=
+ std::string(reinterpret_cast<char*>(cpu_info.data()), sizeof(cpu_info));
+ }
+ }
+
+ bool zmm_enabled = false;
+ if (features_ECX[27]) { // OSXSAVE
+ // Query if the OS supports saving ZMM registers when switching contexts
+ int64_t xcr0 = _xgetbv(0);
+ zmm_enabled = (xcr0 & 0xE0) == 0xE0;
+ }
+
+ if (features_ECX[9]) *hardware_flags |= CpuInfo::SSSE3;
+ if (features_ECX[19]) *hardware_flags |= CpuInfo::SSE4_1;
+ if (features_ECX[20]) *hardware_flags |= CpuInfo::SSE4_2;
+ if (features_ECX[23]) *hardware_flags |= CpuInfo::POPCNT;
+ if (features_ECX[23]) *hardware_flags |= CpuInfo::AVX;
+
+ // cpuid with EAX=7, ECX=0: Extended Features
+ register_EAX_id = 7;
+ if (highest_valid_id > register_EAX_id) {
+ __cpuidex(cpu_info.data(), register_EAX_id, 0);
+ std::bitset<32> features_EBX = cpu_info[1];
+
+ if (features_EBX[3]) *hardware_flags |= CpuInfo::BMI1;
+ if (features_EBX[5]) *hardware_flags |= CpuInfo::AVX2;
+ if (features_EBX[8]) *hardware_flags |= CpuInfo::BMI2;
+ // ARROW-11427: only use AVX512 if enabled by the OS
+ if (zmm_enabled) {
+ if (features_EBX[16]) *hardware_flags |= CpuInfo::AVX512F;
+ if (features_EBX[17]) *hardware_flags |= CpuInfo::AVX512DQ;
+ if (features_EBX[28]) *hardware_flags |= CpuInfo::AVX512CD;
+ if (features_EBX[30]) *hardware_flags |= CpuInfo::AVX512BW;
+ if (features_EBX[31]) *hardware_flags |= CpuInfo::AVX512VL;
+ }
+ }
+
+ return true;
+}
+#endif
+
+} // namespace
+
+CpuInfo::CpuInfo()
+ : hardware_flags_(0),
+ num_cores_(1),
+ model_name_("unknown"),
+ vendor_(Vendor::Unknown) {}
+
+std::unique_ptr<CpuInfo> g_cpu_info;
+static std::once_flag cpuinfo_initialized;
+
+CpuInfo* CpuInfo::GetInstance() {
+ std::call_once(cpuinfo_initialized, []() {
+ g_cpu_info.reset(new CpuInfo);
+ g_cpu_info->Init();
+ });
+ return g_cpu_info.get();
+}
+
+void CpuInfo::Init() {
+ std::string line;
+ std::string name;
+ std::string value;
+
+ float max_mhz = 0;
+ int num_cores = 0;
+
+ memset(&cache_sizes_, 0, sizeof(cache_sizes_));
+
+#ifdef _WIN32
+ SYSTEM_INFO system_info;
+ GetSystemInfo(&system_info);
+ num_cores = system_info.dwNumberOfProcessors;
+
+ LARGE_INTEGER performance_frequency;
+ if (QueryPerformanceFrequency(&performance_frequency)) {
+ max_mhz = static_cast<float>(performance_frequency.QuadPart);
+ }
+#elif defined(__APPLE__)
+ // On macOS, get CPU information from system information base
+ struct SysCtlCpuFeature {
+ const char* name;
+ int64_t flag;
+ };
+ std::vector<SysCtlCpuFeature> features = {
+#if defined(__aarch64__)
+ // ARM64 (note that this is exposed under Rosetta as well)
+ {"hw.optional.neon", ASIMD},
+#else
+ // x86
+ {"hw.optional.sse4_2", SSSE3 | SSE4_1 | SSE4_2 | POPCNT},
+ {"hw.optional.avx1_0", AVX},
+ {"hw.optional.avx2_0", AVX2},
+ {"hw.optional.bmi1", BMI1},
+ {"hw.optional.bmi2", BMI2},
+ {"hw.optional.avx512f", AVX512F},
+ {"hw.optional.avx512cd", AVX512CD},
+ {"hw.optional.avx512dq", AVX512DQ},
+ {"hw.optional.avx512bw", AVX512BW},
+ {"hw.optional.avx512vl", AVX512VL},
+#endif
+ };
+ for (const auto& feature : features) {
+ auto v = IntegerSysCtlByName(feature.name);
+ if (v.value_or(0)) {
+ hardware_flags_ |= feature.flag;
+ }
+ }
+#else
+ // Read from /proc/cpuinfo
+ std::ifstream cpuinfo("/proc/cpuinfo", std::ios::in);
+ while (cpuinfo) {
+ std::getline(cpuinfo, line);
+ size_t colon = line.find(':');
+ if (colon != std::string::npos) {
+ name = TrimString(line.substr(0, colon - 1));
+ value = TrimString(line.substr(colon + 1, std::string::npos));
+ if (name.compare("flags") == 0 || name.compare("Features") == 0) {
+ hardware_flags_ |= ParseCPUFlags(value);
+ } else if (name.compare("cpu MHz") == 0) {
+ // Every core will report a different speed. We'll take the max, assuming
+ // that when impala is running, the core will not be in a lower power state.
+ // TODO: is there a more robust way to do this, such as
+ // Window's QueryPerformanceFrequency()
+ float mhz = static_cast<float>(atof(value.c_str()));
+ max_mhz = max(mhz, max_mhz);
+ } else if (name.compare("processor") == 0) {
+ ++num_cores;
+ } else if (name.compare("model name") == 0) {
+ model_name_ = value;
+ } else if (name.compare("vendor_id") == 0) {
+ if (value.compare("GenuineIntel") == 0) {
+ vendor_ = Vendor::Intel;
+ } else if (value.compare("AuthenticAMD") == 0) {
+ vendor_ = Vendor::AMD;
+ }
+ }
+ }
+ }
+ if (cpuinfo.is_open()) cpuinfo.close();
+#endif
+
+#ifdef __APPLE__
+ // On macOS, get cache size from system information base
+ SetDefaultCacheSize();
+ auto c = IntegerSysCtlByName("hw.l1dcachesize");
+ if (c.has_value()) {
+ cache_sizes_[0] = *c;
+ }
+ c = IntegerSysCtlByName("hw.l2cachesize");
+ if (c.has_value()) {
+ cache_sizes_[1] = *c;
+ }
+ c = IntegerSysCtlByName("hw.l3cachesize");
+ if (c.has_value()) {
+ cache_sizes_[2] = *c;
+ }
+#elif _WIN32
+ if (!RetrieveCacheSize(cache_sizes_)) {
+ SetDefaultCacheSize();
+ }
+ RetrieveCPUInfo(&hardware_flags_, &model_name_, &vendor_);
+#else
+ SetDefaultCacheSize();
+#endif
+
+ if (max_mhz != 0) {
+ cycles_per_ms_ = static_cast<int64_t>(max_mhz);
+#ifndef _WIN32
+ cycles_per_ms_ *= 1000;
+#endif
+ } else {
+ cycles_per_ms_ = 1000000;
+ }
+ original_hardware_flags_ = hardware_flags_;
+
+ if (num_cores > 0) {
+ num_cores_ = num_cores;
+ } else {
+ num_cores_ = 1;
+ }
+
+ // Parse the user simd level
+ ParseUserSimdLevel();
+}
+
+void CpuInfo::VerifyCpuRequirements() {
+#ifdef ARROW_HAVE_SSE4_2
+ if (!IsSupported(CpuInfo::SSSE3)) {
+ DCHECK(false) << "CPU does not support the Supplemental SSE3 instruction set";
+ }
+#endif
+#if defined(ARROW_HAVE_NEON)
+ if (!IsSupported(CpuInfo::ASIMD)) {
+ DCHECK(false) << "CPU does not support the Armv8 Neon instruction set";
+ }
+#endif
+}
+
+bool CpuInfo::CanUseSSE4_2() const {
+#if defined(ARROW_HAVE_SSE4_2)
+ return IsSupported(CpuInfo::SSE4_2);
+#else
+ return false;
+#endif
+}
+
+void CpuInfo::EnableFeature(int64_t flag, bool enable) {
+ if (!enable) {
+ hardware_flags_ &= ~flag;
+ } else {
+ // Can't turn something on that can't be supported
+ DCHECK_NE(original_hardware_flags_ & flag, 0);
+ hardware_flags_ |= flag;
+ }
+}
+
+int64_t CpuInfo::hardware_flags() { return hardware_flags_; }
+
+int64_t CpuInfo::CacheSize(CacheLevel level) { return cache_sizes_[level]; }
+
+int64_t CpuInfo::cycles_per_ms() { return cycles_per_ms_; }
+
+int CpuInfo::num_cores() { return num_cores_; }
+
+std::string CpuInfo::model_name() { return model_name_; }
+
+void CpuInfo::SetDefaultCacheSize() {
+#if defined(_SC_LEVEL1_DCACHE_SIZE) && !defined(__aarch64__)
+ // Call sysconf to query for the cache sizes
+ cache_sizes_[0] = sysconf(_SC_LEVEL1_DCACHE_SIZE);
+ cache_sizes_[1] = sysconf(_SC_LEVEL2_CACHE_SIZE);
+ cache_sizes_[2] = sysconf(_SC_LEVEL3_CACHE_SIZE);
+ ARROW_UNUSED(kDefaultL1CacheSize);
+ ARROW_UNUSED(kDefaultL2CacheSize);
+ ARROW_UNUSED(kDefaultL3CacheSize);
+#elif defined(__GNUC__) && defined(__linux__) && defined(__aarch64__)
+ cache_sizes_[0] = GetArm64CacheSize(kL1CacheSizeFile, kDefaultL1CacheSize);
+ cache_sizes_[1] = GetArm64CacheSize(kL2CacheSizeFile, kDefaultL2CacheSize);
+ cache_sizes_[2] = GetArm64CacheSize(kL3CacheSizeFile, kDefaultL3CacheSize);
+#else
+ // Provide reasonable default values if no info
+ cache_sizes_[0] = kDefaultL1CacheSize;
+ cache_sizes_[1] = kDefaultL2CacheSize;
+ cache_sizes_[2] = kDefaultL3CacheSize;
+#endif
+}
+
+void CpuInfo::ParseUserSimdLevel() {
+ auto maybe_env_var = GetEnvVar("ARROW_USER_SIMD_LEVEL");
+ if (!maybe_env_var.ok()) {
+ // No user settings
+ return;
+ }
+ std::string s = *std::move(maybe_env_var);
+ std::transform(s.begin(), s.end(), s.begin(),
+ [](unsigned char c) { return std::toupper(c); });
+
+ int level = USER_SIMD_MAX;
+ // Parse the level
+ if (s == "AVX512") {
+ level = USER_SIMD_AVX512;
+ } else if (s == "AVX2") {
+ level = USER_SIMD_AVX2;
+ } else if (s == "AVX") {
+ level = USER_SIMD_AVX;
+ } else if (s == "SSE4_2") {
+ level = USER_SIMD_SSE4_2;
+ } else if (s == "NONE") {
+ level = USER_SIMD_NONE;
+ } else if (!s.empty()) {
+ ARROW_LOG(WARNING) << "Invalid value for ARROW_USER_SIMD_LEVEL: " << s;
+ }
+
+ // Disable feature as the level
+ if (level < USER_SIMD_AVX512) { // Disable all AVX512 features
+ EnableFeature(AVX512, false);
+ }
+ if (level < USER_SIMD_AVX2) { // Disable all AVX2 features
+ EnableFeature(AVX2 | BMI2, false);
+ }
+ if (level < USER_SIMD_AVX) { // Disable all AVX features
+ EnableFeature(AVX, false);
+ }
+ if (level < USER_SIMD_SSE4_2) { // Disable all SSE4_2 features
+ EnableFeature(SSE4_2 | BMI1, false);
+ }
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/cpu_info.h b/src/arrow/cpp/src/arrow/util/cpu_info.h
new file mode 100644
index 000000000..83819c255
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/cpu_info.h
@@ -0,0 +1,143 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// From Apache Impala (incubating) as of 2016-01-29. Pared down to a minimal
+// set of functions needed for Apache Arrow / Apache parquet-cpp
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace internal {
+
+/// CpuInfo is an interface to query for cpu information at runtime. The caller can
+/// ask for the sizes of the caches and what hardware features are supported.
+/// On Linux, this information is pulled from a couple of sys files (/proc/cpuinfo and
+/// /sys/devices)
+class ARROW_EXPORT CpuInfo {
+ public:
+ static constexpr int64_t SSSE3 = (1 << 1);
+ static constexpr int64_t SSE4_1 = (1 << 2);
+ static constexpr int64_t SSE4_2 = (1 << 3);
+ static constexpr int64_t POPCNT = (1 << 4);
+ static constexpr int64_t ASIMD = (1 << 5);
+ static constexpr int64_t AVX = (1 << 6);
+ static constexpr int64_t AVX2 = (1 << 7);
+ static constexpr int64_t AVX512F = (1 << 8);
+ static constexpr int64_t AVX512CD = (1 << 9);
+ static constexpr int64_t AVX512VL = (1 << 10);
+ static constexpr int64_t AVX512DQ = (1 << 11);
+ static constexpr int64_t AVX512BW = (1 << 12);
+ static constexpr int64_t BMI1 = (1 << 13);
+ static constexpr int64_t BMI2 = (1 << 14);
+
+ /// Typical AVX512 subsets consists of AVX512F,AVX512BW,AVX512VL,AVX512CD,AVX512DQ
+ static constexpr int64_t AVX512 = AVX512F | AVX512CD | AVX512VL | AVX512DQ | AVX512BW;
+
+ /// Cache enums for L1 (data), L2 and L3
+ enum CacheLevel {
+ L1_CACHE = 0,
+ L2_CACHE = 1,
+ L3_CACHE = 2,
+ };
+
+ enum class Vendor : int { Unknown = 0, Intel, AMD };
+
+ static CpuInfo* GetInstance();
+
+ /// Determine if the CPU meets the minimum CPU requirements and if not, issue an error
+ /// and terminate.
+ void VerifyCpuRequirements();
+
+ /// Returns all the flags for this cpu
+ int64_t hardware_flags();
+
+ /// \brief Returns whether or not the given feature is enabled.
+ ///
+ /// IsSupported() is true iff IsDetected() is also true and the feature
+ /// wasn't disabled by the user (for example by setting the ARROW_USER_SIMD_LEVEL
+ /// environment variable).
+ bool IsSupported(int64_t flags) const { return (hardware_flags_ & flags) == flags; }
+
+ /// Returns whether or not the given feature is available on the CPU.
+ bool IsDetected(int64_t flags) const {
+ return (original_hardware_flags_ & flags) == flags;
+ }
+
+ /// \brief The processor supports SSE4.2 and the Arrow libraries are built
+ /// with support for it
+ bool CanUseSSE4_2() const;
+
+ /// Toggle a hardware feature on and off. It is not valid to turn on a feature
+ /// that the underlying hardware cannot support. This is useful for testing.
+ void EnableFeature(int64_t flag, bool enable);
+
+ /// Returns the size of the cache in KB at this cache level
+ int64_t CacheSize(CacheLevel level);
+
+ /// Returns the number of cpu cycles per millisecond
+ int64_t cycles_per_ms();
+
+ /// Returns the number of cores (including hyper-threaded) on this machine.
+ int num_cores();
+
+ /// Returns the model name of the cpu (e.g. Intel i7-2600)
+ std::string model_name();
+
+ /// Returns the vendor of the cpu.
+ Vendor vendor() const { return vendor_; }
+
+ bool HasEfficientBmi2() const {
+ // BMI2 (pext, pdep) is only efficient on Intel X86 processors.
+ return vendor() == Vendor::Intel && IsSupported(BMI2);
+ }
+
+ private:
+ CpuInfo();
+
+ enum UserSimdLevel {
+ USER_SIMD_NONE = 0,
+ USER_SIMD_SSE4_2,
+ USER_SIMD_AVX,
+ USER_SIMD_AVX2,
+ USER_SIMD_AVX512,
+ USER_SIMD_MAX,
+ };
+
+ void Init();
+
+ /// Inits CPU cache size variables with default values
+ void SetDefaultCacheSize();
+
+ /// Parse the SIMD level by ARROW_USER_SIMD_LEVEL env
+ void ParseUserSimdLevel();
+
+ int64_t hardware_flags_;
+ int64_t original_hardware_flags_;
+ int64_t cache_sizes_[L3_CACHE + 1];
+ int64_t cycles_per_ms_;
+ int num_cores_;
+ std::string model_name_;
+ Vendor vendor_;
+};
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/decimal.cc b/src/arrow/cpp/src/arrow/util/decimal.cc
new file mode 100644
index 000000000..3118db994
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/decimal.cc
@@ -0,0 +1,908 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <array>
+#include <climits>
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <iomanip>
+#include <limits>
+#include <ostream>
+#include <sstream>
+#include <string>
+
+#include "arrow/status.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/formatting.h"
+#include "arrow/util/int128_internal.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/value_parsing.h"
+
+namespace arrow {
+
+using internal::SafeLeftShift;
+using internal::SafeSignedAdd;
+using internal::uint128_t;
+
+Decimal128::Decimal128(const std::string& str) : Decimal128() {
+ *this = Decimal128::FromString(str).ValueOrDie();
+}
+
+static constexpr auto kInt64DecimalDigits =
+ static_cast<size_t>(std::numeric_limits<int64_t>::digits10);
+
+static constexpr uint64_t kUInt64PowersOfTen[kInt64DecimalDigits + 1] = {
+ // clang-format off
+ 1ULL,
+ 10ULL,
+ 100ULL,
+ 1000ULL,
+ 10000ULL,
+ 100000ULL,
+ 1000000ULL,
+ 10000000ULL,
+ 100000000ULL,
+ 1000000000ULL,
+ 10000000000ULL,
+ 100000000000ULL,
+ 1000000000000ULL,
+ 10000000000000ULL,
+ 100000000000000ULL,
+ 1000000000000000ULL,
+ 10000000000000000ULL,
+ 100000000000000000ULL,
+ 1000000000000000000ULL
+ // clang-format on
+};
+
+static constexpr float kFloatPowersOfTen[2 * 38 + 1] = {
+ 1e-38f, 1e-37f, 1e-36f, 1e-35f, 1e-34f, 1e-33f, 1e-32f, 1e-31f, 1e-30f, 1e-29f,
+ 1e-28f, 1e-27f, 1e-26f, 1e-25f, 1e-24f, 1e-23f, 1e-22f, 1e-21f, 1e-20f, 1e-19f,
+ 1e-18f, 1e-17f, 1e-16f, 1e-15f, 1e-14f, 1e-13f, 1e-12f, 1e-11f, 1e-10f, 1e-9f,
+ 1e-8f, 1e-7f, 1e-6f, 1e-5f, 1e-4f, 1e-3f, 1e-2f, 1e-1f, 1e0f, 1e1f,
+ 1e2f, 1e3f, 1e4f, 1e5f, 1e6f, 1e7f, 1e8f, 1e9f, 1e10f, 1e11f,
+ 1e12f, 1e13f, 1e14f, 1e15f, 1e16f, 1e17f, 1e18f, 1e19f, 1e20f, 1e21f,
+ 1e22f, 1e23f, 1e24f, 1e25f, 1e26f, 1e27f, 1e28f, 1e29f, 1e30f, 1e31f,
+ 1e32f, 1e33f, 1e34f, 1e35f, 1e36f, 1e37f, 1e38f};
+
+static constexpr double kDoublePowersOfTen[2 * 38 + 1] = {
+ 1e-38, 1e-37, 1e-36, 1e-35, 1e-34, 1e-33, 1e-32, 1e-31, 1e-30, 1e-29, 1e-28,
+ 1e-27, 1e-26, 1e-25, 1e-24, 1e-23, 1e-22, 1e-21, 1e-20, 1e-19, 1e-18, 1e-17,
+ 1e-16, 1e-15, 1e-14, 1e-13, 1e-12, 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6,
+ 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4, 1e5,
+ 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16,
+ 1e17, 1e18, 1e19, 1e20, 1e21, 1e22, 1e23, 1e24, 1e25, 1e26, 1e27,
+ 1e28, 1e29, 1e30, 1e31, 1e32, 1e33, 1e34, 1e35, 1e36, 1e37, 1e38};
+
+// On the Windows R toolchain, INFINITY is double type instead of float
+static constexpr float kFloatInf = std::numeric_limits<float>::infinity();
+static constexpr float kFloatPowersOfTen76[2 * 76 + 1] = {
+ 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1e-45f, 1e-44f, 1e-43f, 1e-42f,
+ 1e-41f, 1e-40f, 1e-39f, 1e-38f, 1e-37f, 1e-36f, 1e-35f,
+ 1e-34f, 1e-33f, 1e-32f, 1e-31f, 1e-30f, 1e-29f, 1e-28f,
+ 1e-27f, 1e-26f, 1e-25f, 1e-24f, 1e-23f, 1e-22f, 1e-21f,
+ 1e-20f, 1e-19f, 1e-18f, 1e-17f, 1e-16f, 1e-15f, 1e-14f,
+ 1e-13f, 1e-12f, 1e-11f, 1e-10f, 1e-9f, 1e-8f, 1e-7f,
+ 1e-6f, 1e-5f, 1e-4f, 1e-3f, 1e-2f, 1e-1f, 1e0f,
+ 1e1f, 1e2f, 1e3f, 1e4f, 1e5f, 1e6f, 1e7f,
+ 1e8f, 1e9f, 1e10f, 1e11f, 1e12f, 1e13f, 1e14f,
+ 1e15f, 1e16f, 1e17f, 1e18f, 1e19f, 1e20f, 1e21f,
+ 1e22f, 1e23f, 1e24f, 1e25f, 1e26f, 1e27f, 1e28f,
+ 1e29f, 1e30f, 1e31f, 1e32f, 1e33f, 1e34f, 1e35f,
+ 1e36f, 1e37f, 1e38f, kFloatInf, kFloatInf, kFloatInf, kFloatInf,
+ kFloatInf, kFloatInf, kFloatInf, kFloatInf, kFloatInf, kFloatInf, kFloatInf,
+ kFloatInf, kFloatInf, kFloatInf, kFloatInf, kFloatInf, kFloatInf, kFloatInf,
+ kFloatInf, kFloatInf, kFloatInf, kFloatInf, kFloatInf, kFloatInf, kFloatInf,
+ kFloatInf, kFloatInf, kFloatInf, kFloatInf, kFloatInf, kFloatInf, kFloatInf,
+ kFloatInf, kFloatInf, kFloatInf, kFloatInf, kFloatInf, kFloatInf};
+
+static constexpr double kDoublePowersOfTen76[2 * 76 + 1] = {
+ 1e-76, 1e-75, 1e-74, 1e-73, 1e-72, 1e-71, 1e-70, 1e-69, 1e-68, 1e-67, 1e-66, 1e-65,
+ 1e-64, 1e-63, 1e-62, 1e-61, 1e-60, 1e-59, 1e-58, 1e-57, 1e-56, 1e-55, 1e-54, 1e-53,
+ 1e-52, 1e-51, 1e-50, 1e-49, 1e-48, 1e-47, 1e-46, 1e-45, 1e-44, 1e-43, 1e-42, 1e-41,
+ 1e-40, 1e-39, 1e-38, 1e-37, 1e-36, 1e-35, 1e-34, 1e-33, 1e-32, 1e-31, 1e-30, 1e-29,
+ 1e-28, 1e-27, 1e-26, 1e-25, 1e-24, 1e-23, 1e-22, 1e-21, 1e-20, 1e-19, 1e-18, 1e-17,
+ 1e-16, 1e-15, 1e-14, 1e-13, 1e-12, 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5,
+ 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7,
+ 1e8, 1e9, 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19,
+ 1e20, 1e21, 1e22, 1e23, 1e24, 1e25, 1e26, 1e27, 1e28, 1e29, 1e30, 1e31,
+ 1e32, 1e33, 1e34, 1e35, 1e36, 1e37, 1e38, 1e39, 1e40, 1e41, 1e42, 1e43,
+ 1e44, 1e45, 1e46, 1e47, 1e48, 1e49, 1e50, 1e51, 1e52, 1e53, 1e54, 1e55,
+ 1e56, 1e57, 1e58, 1e59, 1e60, 1e61, 1e62, 1e63, 1e64, 1e65, 1e66, 1e67,
+ 1e68, 1e69, 1e70, 1e71, 1e72, 1e73, 1e74, 1e75, 1e76};
+
+namespace {
+
+template <typename Real, typename Derived>
+struct DecimalRealConversion {
+ static Result<Decimal128> FromPositiveReal(Real real, int32_t precision,
+ int32_t scale) {
+ auto x = real;
+ if (scale >= -38 && scale <= 38) {
+ x *= Derived::powers_of_ten()[scale + 38];
+ } else {
+ x *= std::pow(static_cast<Real>(10), static_cast<Real>(scale));
+ }
+ x = std::nearbyint(x);
+ const auto max_abs = Derived::powers_of_ten()[precision + 38];
+ if (x <= -max_abs || x >= max_abs) {
+ return Status::Invalid("Cannot convert ", real,
+ " to Decimal128(precision = ", precision,
+ ", scale = ", scale, "): overflow");
+ }
+ // Extract high and low bits
+ const auto high = std::floor(std::ldexp(x, -64));
+ const auto low = x - std::ldexp(high, 64);
+
+ DCHECK_GE(high, -9.223372036854776e+18); // -2**63
+ DCHECK_LT(high, 9.223372036854776e+18); // 2**63
+ DCHECK_GE(low, 0);
+ DCHECK_LT(low, 1.8446744073709552e+19); // 2**64
+ return Decimal128(static_cast<int64_t>(high), static_cast<uint64_t>(low));
+ }
+
+ static Result<Decimal128> FromReal(Real x, int32_t precision, int32_t scale) {
+ DCHECK_GT(precision, 0);
+ DCHECK_LE(precision, 38);
+
+ if (!std::isfinite(x)) {
+ return Status::Invalid("Cannot convert ", x, " to Decimal128");
+ }
+ if (x < 0) {
+ ARROW_ASSIGN_OR_RAISE(auto dec, FromPositiveReal(-x, precision, scale));
+ return dec.Negate();
+ } else {
+ // Includes negative zero
+ return FromPositiveReal(x, precision, scale);
+ }
+ }
+
+ static Real ToRealPositive(const Decimal128& decimal, int32_t scale) {
+ Real x = static_cast<Real>(decimal.high_bits()) * Derived::two_to_64();
+ x += static_cast<Real>(decimal.low_bits());
+ if (scale >= -38 && scale <= 38) {
+ x *= Derived::powers_of_ten()[-scale + 38];
+ } else {
+ x *= std::pow(static_cast<Real>(10), static_cast<Real>(-scale));
+ }
+ return x;
+ }
+
+ static Real ToReal(Decimal128 decimal, int32_t scale) {
+ if (decimal.high_bits() < 0) {
+ // Convert the absolute value to avoid precision loss
+ decimal.Negate();
+ return -ToRealPositive(decimal, scale);
+ } else {
+ return ToRealPositive(decimal, scale);
+ }
+ }
+};
+
+struct DecimalFloatConversion
+ : public DecimalRealConversion<float, DecimalFloatConversion> {
+ static constexpr const float* powers_of_ten() { return kFloatPowersOfTen; }
+
+ static constexpr float two_to_64() { return 1.8446744e+19f; }
+};
+
+struct DecimalDoubleConversion
+ : public DecimalRealConversion<double, DecimalDoubleConversion> {
+ static constexpr const double* powers_of_ten() { return kDoublePowersOfTen; }
+
+ static constexpr double two_to_64() { return 1.8446744073709552e+19; }
+};
+
+} // namespace
+
+Result<Decimal128> Decimal128::FromReal(float x, int32_t precision, int32_t scale) {
+ return DecimalFloatConversion::FromReal(x, precision, scale);
+}
+
+Result<Decimal128> Decimal128::FromReal(double x, int32_t precision, int32_t scale) {
+ return DecimalDoubleConversion::FromReal(x, precision, scale);
+}
+
+float Decimal128::ToFloat(int32_t scale) const {
+ return DecimalFloatConversion::ToReal(*this, scale);
+}
+
+double Decimal128::ToDouble(int32_t scale) const {
+ return DecimalDoubleConversion::ToReal(*this, scale);
+}
+
+template <size_t n>
+static void AppendLittleEndianArrayToString(const std::array<uint64_t, n>& array,
+ std::string* result) {
+ const auto most_significant_non_zero =
+ find_if(array.rbegin(), array.rend(), [](uint64_t v) { return v != 0; });
+ if (most_significant_non_zero == array.rend()) {
+ result->push_back('0');
+ return;
+ }
+
+ size_t most_significant_elem_idx = &*most_significant_non_zero - array.data();
+ std::array<uint64_t, n> copy = array;
+ constexpr uint32_t k1e9 = 1000000000U;
+ constexpr size_t kNumBits = n * 64;
+ // Segments will contain the array split into groups that map to decimal digits,
+ // in little endian order. Each segment will hold at most 9 decimal digits.
+ // For example, if the input represents 9876543210123456789, then segments will be
+ // [123456789, 876543210, 9].
+ // The max number of segments needed = ceil(kNumBits * log(2) / log(1e9))
+ // = ceil(kNumBits / 29.897352854) <= ceil(kNumBits / 29).
+ std::array<uint32_t, (kNumBits + 28) / 29> segments;
+ size_t num_segments = 0;
+ uint64_t* most_significant_elem = &copy[most_significant_elem_idx];
+ do {
+ // Compute remainder = copy % 1e9 and copy = copy / 1e9.
+ uint32_t remainder = 0;
+ uint64_t* elem = most_significant_elem;
+ do {
+ // Compute dividend = (remainder << 32) | *elem (a virtual 96-bit integer);
+ // *elem = dividend / 1e9;
+ // remainder = dividend % 1e9.
+ uint32_t hi = static_cast<uint32_t>(*elem >> 32);
+ uint32_t lo = static_cast<uint32_t>(*elem & BitUtil::LeastSignificantBitMask(32));
+ uint64_t dividend_hi = (static_cast<uint64_t>(remainder) << 32) | hi;
+ uint64_t quotient_hi = dividend_hi / k1e9;
+ remainder = static_cast<uint32_t>(dividend_hi % k1e9);
+ uint64_t dividend_lo = (static_cast<uint64_t>(remainder) << 32) | lo;
+ uint64_t quotient_lo = dividend_lo / k1e9;
+ remainder = static_cast<uint32_t>(dividend_lo % k1e9);
+ *elem = (quotient_hi << 32) | quotient_lo;
+ } while (elem-- != copy.data());
+
+ segments[num_segments++] = remainder;
+ } while (*most_significant_elem != 0 || most_significant_elem-- != copy.data());
+
+ size_t old_size = result->size();
+ size_t new_size = old_size + num_segments * 9;
+ result->resize(new_size, '0');
+ char* output = &result->at(old_size);
+ const uint32_t* segment = &segments[num_segments - 1];
+ internal::StringFormatter<UInt32Type> format;
+ // First segment is formatted as-is.
+ format(*segment, [&output](util::string_view formatted) {
+ memcpy(output, formatted.data(), formatted.size());
+ output += formatted.size();
+ });
+ while (segment != segments.data()) {
+ --segment;
+ // Right-pad formatted segment such that e.g. 123 is formatted as "000000123".
+ output += 9;
+ format(*segment, [output](util::string_view formatted) {
+ memcpy(output - formatted.size(), formatted.data(), formatted.size());
+ });
+ }
+ result->resize(output - result->data());
+}
+
+std::string Decimal128::ToIntegerString() const {
+ std::string result;
+ if (high_bits() < 0) {
+ result.push_back('-');
+ Decimal128 abs = *this;
+ abs.Negate();
+ AppendLittleEndianArrayToString<2>(
+ {abs.low_bits(), static_cast<uint64_t>(abs.high_bits())}, &result);
+ } else {
+ AppendLittleEndianArrayToString<2>({low_bits(), static_cast<uint64_t>(high_bits())},
+ &result);
+ }
+ return result;
+}
+
+Decimal128::operator int64_t() const {
+ DCHECK(high_bits() == 0 || high_bits() == -1)
+ << "Trying to cast a Decimal128 greater than the value range of a "
+ "int64_t. high_bits_ must be equal to 0 or -1, got: "
+ << high_bits();
+ return static_cast<int64_t>(low_bits());
+}
+
+static void AdjustIntegerStringWithScale(int32_t scale, std::string* str) {
+ if (scale == 0) {
+ return;
+ }
+ DCHECK(str != nullptr);
+ DCHECK(!str->empty());
+ const bool is_negative = str->front() == '-';
+ const auto is_negative_offset = static_cast<int32_t>(is_negative);
+ const auto len = static_cast<int32_t>(str->size());
+ const int32_t num_digits = len - is_negative_offset;
+ const int32_t adjusted_exponent = num_digits - 1 - scale;
+
+ /// Note that the -6 is taken from the Java BigDecimal documentation.
+ if (scale < 0 || adjusted_exponent < -6) {
+ // Example 1:
+ // Precondition: *str = "123", is_negative_offset = 0, num_digits = 3, scale = -2,
+ // adjusted_exponent = 4
+ // After inserting decimal point: *str = "1.23"
+ // After appending exponent: *str = "1.23E+4"
+ // Example 2:
+ // Precondition: *str = "-123", is_negative_offset = 1, num_digits = 3, scale = 9,
+ // adjusted_exponent = -7
+ // After inserting decimal point: *str = "-1.23"
+ // After appending exponent: *str = "-1.23E-7"
+ str->insert(str->begin() + 1 + is_negative_offset, '.');
+ str->push_back('E');
+ if (adjusted_exponent >= 0) {
+ str->push_back('+');
+ }
+ internal::StringFormatter<Int32Type> format;
+ format(adjusted_exponent, [str](util::string_view formatted) {
+ str->append(formatted.data(), formatted.size());
+ });
+ return;
+ }
+
+ if (num_digits > scale) {
+ const auto n = static_cast<size_t>(len - scale);
+ // Example 1:
+ // Precondition: *str = "123", len = num_digits = 3, scale = 1, n = 2
+ // After inserting decimal point: *str = "12.3"
+ // Example 2:
+ // Precondition: *str = "-123", len = 4, num_digits = 3, scale = 1, n = 3
+ // After inserting decimal point: *str = "-12.3"
+ str->insert(str->begin() + n, '.');
+ return;
+ }
+
+ // Example 1:
+ // Precondition: *str = "123", is_negative_offset = 0, num_digits = 3, scale = 4
+ // After insert: *str = "000123"
+ // After setting decimal point: *str = "0.0123"
+ // Example 2:
+ // Precondition: *str = "-123", is_negative_offset = 1, num_digits = 3, scale = 4
+ // After insert: *str = "-000123"
+ // After setting decimal point: *str = "-0.0123"
+ str->insert(is_negative_offset, scale - num_digits + 2, '0');
+ str->at(is_negative_offset + 1) = '.';
+}
+
+std::string Decimal128::ToString(int32_t scale) const {
+ if (ARROW_PREDICT_FALSE(scale < -kMaxScale || scale > kMaxScale)) {
+ return "<scale out of range, cannot format Decimal128 value>";
+ }
+ std::string str(ToIntegerString());
+ AdjustIntegerStringWithScale(scale, &str);
+ return str;
+}
+
+// Iterates over input and for each group of kInt64DecimalDigits multiple out by
+// the appropriate power of 10 necessary to add source parsed as uint64 and
+// then adds the parsed value of source.
+static inline void ShiftAndAdd(const util::string_view& input, uint64_t out[],
+ size_t out_size) {
+ for (size_t posn = 0; posn < input.size();) {
+ const size_t group_size = std::min(kInt64DecimalDigits, input.size() - posn);
+ const uint64_t multiple = kUInt64PowersOfTen[group_size];
+ uint64_t chunk = 0;
+ ARROW_CHECK(
+ internal::ParseValue<UInt64Type>(input.data() + posn, group_size, &chunk));
+
+ for (size_t i = 0; i < out_size; ++i) {
+ uint128_t tmp = out[i];
+ tmp *= multiple;
+ tmp += chunk;
+ out[i] = static_cast<uint64_t>(tmp & 0xFFFFFFFFFFFFFFFFULL);
+ chunk = static_cast<uint64_t>(tmp >> 64);
+ }
+ posn += group_size;
+ }
+}
+
+namespace {
+
+struct DecimalComponents {
+ util::string_view whole_digits;
+ util::string_view fractional_digits;
+ int32_t exponent = 0;
+ char sign = 0;
+ bool has_exponent = false;
+};
+
+inline bool IsSign(char c) { return c == '-' || c == '+'; }
+
+inline bool IsDot(char c) { return c == '.'; }
+
+inline bool IsDigit(char c) { return c >= '0' && c <= '9'; }
+
+inline bool StartsExponent(char c) { return c == 'e' || c == 'E'; }
+
+inline size_t ParseDigitsRun(const char* s, size_t start, size_t size,
+ util::string_view* out) {
+ size_t pos;
+ for (pos = start; pos < size; ++pos) {
+ if (!IsDigit(s[pos])) {
+ break;
+ }
+ }
+ *out = util::string_view(s + start, pos - start);
+ return pos;
+}
+
+bool ParseDecimalComponents(const char* s, size_t size, DecimalComponents* out) {
+ size_t pos = 0;
+
+ if (size == 0) {
+ return false;
+ }
+ // Sign of the number
+ if (IsSign(s[pos])) {
+ out->sign = *(s + pos);
+ ++pos;
+ }
+ // First run of digits
+ pos = ParseDigitsRun(s, pos, size, &out->whole_digits);
+ if (pos == size) {
+ return !out->whole_digits.empty();
+ }
+ // Optional dot (if given in fractional form)
+ bool has_dot = IsDot(s[pos]);
+ if (has_dot) {
+ // Second run of digits
+ ++pos;
+ pos = ParseDigitsRun(s, pos, size, &out->fractional_digits);
+ }
+ if (out->whole_digits.empty() && out->fractional_digits.empty()) {
+ // Need at least some digits (whole or fractional)
+ return false;
+ }
+ if (pos == size) {
+ return true;
+ }
+ // Optional exponent
+ if (StartsExponent(s[pos])) {
+ ++pos;
+ if (pos != size && s[pos] == '+') {
+ ++pos;
+ }
+ out->has_exponent = true;
+ return internal::ParseValue<Int32Type>(s + pos, size - pos, &(out->exponent));
+ }
+ return pos == size;
+}
+
+inline Status ToArrowStatus(DecimalStatus dstatus, int num_bits) {
+ switch (dstatus) {
+ case DecimalStatus::kSuccess:
+ return Status::OK();
+
+ case DecimalStatus::kDivideByZero:
+ return Status::Invalid("Division by 0 in Decimal", num_bits);
+
+ case DecimalStatus::kOverflow:
+ return Status::Invalid("Overflow occurred during Decimal", num_bits, " operation.");
+
+ case DecimalStatus::kRescaleDataLoss:
+ return Status::Invalid("Rescaling Decimal", num_bits,
+ " value would cause data loss");
+ }
+ return Status::OK();
+}
+
+template <typename Decimal>
+Status DecimalFromString(const char* type_name, const util::string_view& s, Decimal* out,
+ int32_t* precision, int32_t* scale) {
+ if (s.empty()) {
+ return Status::Invalid("Empty string cannot be converted to ", type_name);
+ }
+
+ DecimalComponents dec;
+ if (!ParseDecimalComponents(s.data(), s.size(), &dec)) {
+ return Status::Invalid("The string '", s, "' is not a valid ", type_name, " number");
+ }
+
+ // Count number of significant digits (without leading zeros)
+ size_t first_non_zero = dec.whole_digits.find_first_not_of('0');
+ size_t significant_digits = dec.fractional_digits.size();
+ if (first_non_zero != std::string::npos) {
+ significant_digits += dec.whole_digits.size() - first_non_zero;
+ }
+ int32_t parsed_precision = static_cast<int32_t>(significant_digits);
+
+ int32_t parsed_scale = 0;
+ if (dec.has_exponent) {
+ auto adjusted_exponent = dec.exponent;
+ parsed_scale =
+ -adjusted_exponent + static_cast<int32_t>(dec.fractional_digits.size());
+ } else {
+ parsed_scale = static_cast<int32_t>(dec.fractional_digits.size());
+ }
+
+ if (out != nullptr) {
+ static_assert(Decimal::kBitWidth % 64 == 0, "decimal bit-width not a multiple of 64");
+ std::array<uint64_t, Decimal::kBitWidth / 64> little_endian_array{};
+ ShiftAndAdd(dec.whole_digits, little_endian_array.data(), little_endian_array.size());
+ ShiftAndAdd(dec.fractional_digits, little_endian_array.data(),
+ little_endian_array.size());
+ *out = Decimal(BitUtil::LittleEndianArray::ToNative(little_endian_array));
+ if (dec.sign == '-') {
+ out->Negate();
+ }
+ }
+
+ if (parsed_scale < 0) {
+ // Force the scale to zero, to avoid negative scales (due to compatibility issues
+ // with external systems such as databases)
+ if (-parsed_scale > Decimal::kMaxScale) {
+ return Status::Invalid("The string '", s, "' cannot be represented as ", type_name);
+ }
+ if (out != nullptr) {
+ *out *= Decimal::GetScaleMultiplier(-parsed_scale);
+ }
+ parsed_precision -= parsed_scale;
+ parsed_scale = 0;
+ }
+
+ if (precision != nullptr) {
+ *precision = parsed_precision;
+ }
+ if (scale != nullptr) {
+ *scale = parsed_scale;
+ }
+
+ return Status::OK();
+}
+
+} // namespace
+
+Status Decimal128::FromString(const util::string_view& s, Decimal128* out,
+ int32_t* precision, int32_t* scale) {
+ return DecimalFromString("decimal128", s, out, precision, scale);
+}
+
+Status Decimal128::FromString(const std::string& s, Decimal128* out, int32_t* precision,
+ int32_t* scale) {
+ return FromString(util::string_view(s), out, precision, scale);
+}
+
+Status Decimal128::FromString(const char* s, Decimal128* out, int32_t* precision,
+ int32_t* scale) {
+ return FromString(util::string_view(s), out, precision, scale);
+}
+
+Result<Decimal128> Decimal128::FromString(const util::string_view& s) {
+ Decimal128 out;
+ RETURN_NOT_OK(FromString(s, &out, nullptr, nullptr));
+ return std::move(out);
+}
+
+Result<Decimal128> Decimal128::FromString(const std::string& s) {
+ return FromString(util::string_view(s));
+}
+
+Result<Decimal128> Decimal128::FromString(const char* s) {
+ return FromString(util::string_view(s));
+}
+
+// Helper function used by Decimal128::FromBigEndian
+static inline uint64_t UInt64FromBigEndian(const uint8_t* bytes, int32_t length) {
+ // We don't bounds check the length here because this is called by
+ // FromBigEndian that has a Decimal128 as its out parameters and
+ // that function is already checking the length of the bytes and only
+ // passes lengths between zero and eight.
+ uint64_t result = 0;
+ // Using memcpy instead of special casing for length
+ // and doing the conversion in 16, 32 parts, which could
+ // possibly create unaligned memory access on certain platforms
+ memcpy(reinterpret_cast<uint8_t*>(&result) + 8 - length, bytes, length);
+ return ::arrow::BitUtil::FromBigEndian(result);
+}
+
+Result<Decimal128> Decimal128::FromBigEndian(const uint8_t* bytes, int32_t length) {
+ static constexpr int32_t kMinDecimalBytes = 1;
+ static constexpr int32_t kMaxDecimalBytes = 16;
+
+ int64_t high, low;
+
+ if (ARROW_PREDICT_FALSE(length < kMinDecimalBytes || length > kMaxDecimalBytes)) {
+ return Status::Invalid("Length of byte array passed to Decimal128::FromBigEndian ",
+ "was ", length, ", but must be between ", kMinDecimalBytes,
+ " and ", kMaxDecimalBytes);
+ }
+
+ // Bytes are coming in big-endian, so the first byte is the MSB and therefore holds the
+ // sign bit.
+ const bool is_negative = static_cast<int8_t>(bytes[0]) < 0;
+
+ // 1. Extract the high bytes
+ // Stop byte of the high bytes
+ const int32_t high_bits_offset = std::max(0, length - 8);
+ const auto high_bits = UInt64FromBigEndian(bytes, high_bits_offset);
+
+ if (high_bits_offset == 8) {
+ // Avoid undefined shift by 64 below
+ high = high_bits;
+ } else {
+ high = -1 * (is_negative && length < kMaxDecimalBytes);
+ // Shift left enough bits to make room for the incoming int64_t
+ high = SafeLeftShift(high, high_bits_offset * CHAR_BIT);
+ // Preserve the upper bits by inplace OR-ing the int64_t
+ high |= high_bits;
+ }
+
+ // 2. Extract the low bytes
+ // Stop byte of the low bytes
+ const int32_t low_bits_offset = std::min(length, 8);
+ const auto low_bits =
+ UInt64FromBigEndian(bytes + high_bits_offset, length - high_bits_offset);
+
+ if (low_bits_offset == 8) {
+ // Avoid undefined shift by 64 below
+ low = low_bits;
+ } else {
+ // Sign extend the low bits if necessary
+ low = -1 * (is_negative && length < 8);
+ // Shift left enough bits to make room for the incoming int64_t
+ low = SafeLeftShift(low, low_bits_offset * CHAR_BIT);
+ // Preserve the upper bits by inplace OR-ing the int64_t
+ low |= low_bits;
+ }
+
+ return Decimal128(high, static_cast<uint64_t>(low));
+}
+
+Status Decimal128::ToArrowStatus(DecimalStatus dstatus) const {
+ return arrow::ToArrowStatus(dstatus, 128);
+}
+
+std::ostream& operator<<(std::ostream& os, const Decimal128& decimal) {
+ os << decimal.ToIntegerString();
+ return os;
+}
+
+Decimal256::Decimal256(const std::string& str) : Decimal256() {
+ *this = Decimal256::FromString(str).ValueOrDie();
+}
+
+std::string Decimal256::ToIntegerString() const {
+ std::string result;
+ if (IsNegative()) {
+ result.push_back('-');
+ Decimal256 abs = *this;
+ abs.Negate();
+ AppendLittleEndianArrayToString(
+ BitUtil::LittleEndianArray::FromNative(abs.native_endian_array()), &result);
+ } else {
+ AppendLittleEndianArrayToString(
+ BitUtil::LittleEndianArray::FromNative(native_endian_array()), &result);
+ }
+ return result;
+}
+
+std::string Decimal256::ToString(int32_t scale) const {
+ if (ARROW_PREDICT_FALSE(scale < -kMaxScale || scale > kMaxScale)) {
+ return "<scale out of range, cannot format Decimal256 value>";
+ }
+ std::string str(ToIntegerString());
+ AdjustIntegerStringWithScale(scale, &str);
+ return str;
+}
+
+Status Decimal256::FromString(const util::string_view& s, Decimal256* out,
+ int32_t* precision, int32_t* scale) {
+ return DecimalFromString("decimal256", s, out, precision, scale);
+}
+
+Status Decimal256::FromString(const std::string& s, Decimal256* out, int32_t* precision,
+ int32_t* scale) {
+ return FromString(util::string_view(s), out, precision, scale);
+}
+
+Status Decimal256::FromString(const char* s, Decimal256* out, int32_t* precision,
+ int32_t* scale) {
+ return FromString(util::string_view(s), out, precision, scale);
+}
+
+Result<Decimal256> Decimal256::FromString(const util::string_view& s) {
+ Decimal256 out;
+ RETURN_NOT_OK(FromString(s, &out, nullptr, nullptr));
+ return std::move(out);
+}
+
+Result<Decimal256> Decimal256::FromString(const std::string& s) {
+ return FromString(util::string_view(s));
+}
+
+Result<Decimal256> Decimal256::FromString(const char* s) {
+ return FromString(util::string_view(s));
+}
+
+Result<Decimal256> Decimal256::FromBigEndian(const uint8_t* bytes, int32_t length) {
+ static constexpr int32_t kMinDecimalBytes = 1;
+ static constexpr int32_t kMaxDecimalBytes = 32;
+
+ std::array<uint64_t, 4> little_endian_array;
+
+ if (ARROW_PREDICT_FALSE(length < kMinDecimalBytes || length > kMaxDecimalBytes)) {
+ return Status::Invalid("Length of byte array passed to Decimal128::FromBigEndian ",
+ "was ", length, ", but must be between ", kMinDecimalBytes,
+ " and ", kMaxDecimalBytes);
+ }
+
+ // Bytes are coming in big-endian, so the first byte is the MSB and therefore holds the
+ // sign bit.
+ const bool is_negative = static_cast<int8_t>(bytes[0]) < 0;
+
+ for (int word_idx = 0; word_idx < 4; word_idx++) {
+ const int32_t word_length = std::min(length, static_cast<int32_t>(sizeof(uint64_t)));
+
+ if (word_length == 8) {
+ // Full words can be assigned as is (and are UB with the shift below).
+ little_endian_array[word_idx] =
+ UInt64FromBigEndian(bytes + length - word_length, word_length);
+ } else {
+ // Sign extend the word its if necessary
+ uint64_t word = -1 * is_negative;
+ if (length > 0) {
+ // Incorporate the actual values if present.
+ // Shift left enough bits to make room for the incoming int64_t
+ word = SafeLeftShift(word, word_length * CHAR_BIT);
+ // Preserve the upper bits by inplace OR-ing the int64_t
+ word |= UInt64FromBigEndian(bytes + length - word_length, word_length);
+ }
+ little_endian_array[word_idx] = word;
+ }
+ // Move on to the next word.
+ length -= word_length;
+ }
+
+ return Decimal256(BitUtil::LittleEndianArray::ToNative(little_endian_array));
+}
+
+Status Decimal256::ToArrowStatus(DecimalStatus dstatus) const {
+ return arrow::ToArrowStatus(dstatus, 256);
+}
+
+namespace {
+
+template <typename Real, typename Derived>
+struct Decimal256RealConversion {
+ static Result<Decimal256> FromPositiveReal(Real real, int32_t precision,
+ int32_t scale) {
+ auto x = real;
+ if (scale >= -76 && scale <= 76) {
+ x *= Derived::powers_of_ten()[scale + 76];
+ } else {
+ x *= std::pow(static_cast<Real>(10), static_cast<Real>(scale));
+ }
+ x = std::nearbyint(x);
+ const auto max_abs = Derived::powers_of_ten()[precision + 76];
+ if (x >= max_abs) {
+ return Status::Invalid("Cannot convert ", real,
+ " to Decimal256(precision = ", precision,
+ ", scale = ", scale, "): overflow");
+ }
+ // Extract parts
+ const auto part3 = std::floor(std::ldexp(x, -192));
+ x -= std::ldexp(part3, 192);
+ const auto part2 = std::floor(std::ldexp(x, -128));
+ x -= std::ldexp(part2, 128);
+ const auto part1 = std::floor(std::ldexp(x, -64));
+ x -= std::ldexp(part1, 64);
+ const auto part0 = x;
+
+ DCHECK_GE(part3, 0);
+ DCHECK_LT(part3, 1.8446744073709552e+19); // 2**64
+ DCHECK_GE(part2, 0);
+ DCHECK_LT(part2, 1.8446744073709552e+19); // 2**64
+ DCHECK_GE(part1, 0);
+ DCHECK_LT(part1, 1.8446744073709552e+19); // 2**64
+ DCHECK_GE(part0, 0);
+ DCHECK_LT(part0, 1.8446744073709552e+19); // 2**64
+ return Decimal256(BitUtil::LittleEndianArray::ToNative<uint64_t, 4>(
+ {static_cast<uint64_t>(part0), static_cast<uint64_t>(part1),
+ static_cast<uint64_t>(part2), static_cast<uint64_t>(part3)}));
+ }
+
+ static Result<Decimal256> FromReal(Real x, int32_t precision, int32_t scale) {
+ DCHECK_GT(precision, 0);
+ DCHECK_LE(precision, 76);
+
+ if (!std::isfinite(x)) {
+ return Status::Invalid("Cannot convert ", x, " to Decimal256");
+ }
+ if (x < 0) {
+ ARROW_ASSIGN_OR_RAISE(auto dec, FromPositiveReal(-x, precision, scale));
+ return dec.Negate();
+ } else {
+ // Includes negative zero
+ return FromPositiveReal(x, precision, scale);
+ }
+ }
+
+ static Real ToRealPositive(const Decimal256& decimal, int32_t scale) {
+ DCHECK_GE(decimal, 0);
+ Real x = 0;
+ const auto parts_le = BitUtil::LittleEndianArray::Make(decimal.native_endian_array());
+ x += Derived::two_to_192(static_cast<Real>(parts_le[3]));
+ x += Derived::two_to_128(static_cast<Real>(parts_le[2]));
+ x += Derived::two_to_64(static_cast<Real>(parts_le[1]));
+ x += static_cast<Real>(parts_le[0]);
+ if (scale >= -76 && scale <= 76) {
+ x *= Derived::powers_of_ten()[-scale + 76];
+ } else {
+ x *= std::pow(static_cast<Real>(10), static_cast<Real>(-scale));
+ }
+ return x;
+ }
+
+ static Real ToReal(Decimal256 decimal, int32_t scale) {
+ if (decimal.IsNegative()) {
+ // Convert the absolute value to avoid precision loss
+ decimal.Negate();
+ return -ToRealPositive(decimal, scale);
+ } else {
+ return ToRealPositive(decimal, scale);
+ }
+ }
+};
+
+struct Decimal256FloatConversion
+ : public Decimal256RealConversion<float, Decimal256FloatConversion> {
+ static constexpr const float* powers_of_ten() { return kFloatPowersOfTen76; }
+
+ static float two_to_64(float x) { return x * 1.8446744e+19f; }
+ static float two_to_128(float x) { return x == 0 ? 0 : INFINITY; }
+ static float two_to_192(float x) { return x == 0 ? 0 : INFINITY; }
+};
+
+struct Decimal256DoubleConversion
+ : public Decimal256RealConversion<double, Decimal256DoubleConversion> {
+ static constexpr const double* powers_of_ten() { return kDoublePowersOfTen76; }
+
+ static double two_to_64(double x) { return x * 1.8446744073709552e+19; }
+ static double two_to_128(double x) { return x * 3.402823669209385e+38; }
+ static double two_to_192(double x) { return x * 6.277101735386681e+57; }
+};
+
+} // namespace
+
+Result<Decimal256> Decimal256::FromReal(float x, int32_t precision, int32_t scale) {
+ return Decimal256FloatConversion::FromReal(x, precision, scale);
+}
+
+Result<Decimal256> Decimal256::FromReal(double x, int32_t precision, int32_t scale) {
+ return Decimal256DoubleConversion::FromReal(x, precision, scale);
+}
+
+float Decimal256::ToFloat(int32_t scale) const {
+ return Decimal256FloatConversion::ToReal(*this, scale);
+}
+
+double Decimal256::ToDouble(int32_t scale) const {
+ return Decimal256DoubleConversion::ToReal(*this, scale);
+}
+
+std::ostream& operator<<(std::ostream& os, const Decimal256& decimal) {
+ os << decimal.ToIntegerString();
+ return os;
+}
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/decimal.h b/src/arrow/cpp/src/arrow/util/decimal.h
new file mode 100644
index 000000000..da88fbeb3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/decimal.h
@@ -0,0 +1,314 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <iosfwd>
+#include <limits>
+#include <string>
+#include <utility>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/basic_decimal.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+
+/// Represents a signed 128-bit integer in two's complement.
+/// Calculations wrap around and overflow is ignored.
+/// The max decimal precision that can be safely represented is
+/// 38 significant digits.
+///
+/// For a discussion of the algorithms, look at Knuth's volume 2,
+/// Semi-numerical Algorithms section 4.3.1.
+///
+/// Adapted from the Apache ORC C++ implementation
+///
+/// The implementation is split into two parts :
+///
+/// 1. BasicDecimal128
+/// - can be safely compiled to IR without references to libstdc++.
+/// 2. Decimal128
+/// - has additional functionality on top of BasicDecimal128 to deal with
+/// strings and streams.
+class ARROW_EXPORT Decimal128 : public BasicDecimal128 {
+ public:
+ /// \cond FALSE
+ // (need to avoid a duplicate definition in Sphinx)
+ using BasicDecimal128::BasicDecimal128;
+ /// \endcond
+
+ /// \brief constructor creates a Decimal128 from a BasicDecimal128.
+ constexpr Decimal128(const BasicDecimal128& value) noexcept // NOLINT runtime/explicit
+ : BasicDecimal128(value) {}
+
+ /// \brief Parse the number from a base 10 string representation.
+ explicit Decimal128(const std::string& value);
+
+ /// \brief Empty constructor creates a Decimal128 with a value of 0.
+ // This is required on some older compilers.
+ constexpr Decimal128() noexcept : BasicDecimal128() {}
+
+ /// Divide this number by right and return the result.
+ ///
+ /// This operation is not destructive.
+ /// The answer rounds to zero. Signs work like:
+ /// 21 / 5 -> 4, 1
+ /// -21 / 5 -> -4, -1
+ /// 21 / -5 -> -4, 1
+ /// -21 / -5 -> 4, -1
+ /// \param[in] divisor the number to divide by
+ /// \return the pair of the quotient and the remainder
+ Result<std::pair<Decimal128, Decimal128>> Divide(const Decimal128& divisor) const {
+ std::pair<Decimal128, Decimal128> result;
+ auto dstatus = BasicDecimal128::Divide(divisor, &result.first, &result.second);
+ ARROW_RETURN_NOT_OK(ToArrowStatus(dstatus));
+ return std::move(result);
+ }
+
+ /// \brief Convert the Decimal128 value to a base 10 decimal string with the given
+ /// scale.
+ std::string ToString(int32_t scale) const;
+
+ /// \brief Convert the value to an integer string
+ std::string ToIntegerString() const;
+
+ /// \brief Cast this value to an int64_t.
+ explicit operator int64_t() const;
+
+ /// \brief Convert a decimal string to a Decimal128 value, optionally including
+ /// precision and scale if they're passed in and not null.
+ static Status FromString(const util::string_view& s, Decimal128* out,
+ int32_t* precision, int32_t* scale = NULLPTR);
+ static Status FromString(const std::string& s, Decimal128* out, int32_t* precision,
+ int32_t* scale = NULLPTR);
+ static Status FromString(const char* s, Decimal128* out, int32_t* precision,
+ int32_t* scale = NULLPTR);
+ static Result<Decimal128> FromString(const util::string_view& s);
+ static Result<Decimal128> FromString(const std::string& s);
+ static Result<Decimal128> FromString(const char* s);
+
+ static Result<Decimal128> FromReal(double real, int32_t precision, int32_t scale);
+ static Result<Decimal128> FromReal(float real, int32_t precision, int32_t scale);
+
+ /// \brief Convert from a big-endian byte representation. The length must be
+ /// between 1 and 16.
+ /// \return error status if the length is an invalid value
+ static Result<Decimal128> FromBigEndian(const uint8_t* data, int32_t length);
+
+ /// \brief Convert Decimal128 from one scale to another
+ Result<Decimal128> Rescale(int32_t original_scale, int32_t new_scale) const {
+ Decimal128 out;
+ auto dstatus = BasicDecimal128::Rescale(original_scale, new_scale, &out);
+ ARROW_RETURN_NOT_OK(ToArrowStatus(dstatus));
+ return std::move(out);
+ }
+
+ /// \brief Convert to a signed integer
+ template <typename T, typename = internal::EnableIfIsOneOf<T, int32_t, int64_t>>
+ Result<T> ToInteger() const {
+ constexpr auto min_value = std::numeric_limits<T>::min();
+ constexpr auto max_value = std::numeric_limits<T>::max();
+ const auto& self = *this;
+ if (self < min_value || self > max_value) {
+ return Status::Invalid("Invalid cast from Decimal128 to ", sizeof(T),
+ " byte integer");
+ }
+ return static_cast<T>(low_bits());
+ }
+
+ /// \brief Convert to a signed integer
+ template <typename T, typename = internal::EnableIfIsOneOf<T, int32_t, int64_t>>
+ Status ToInteger(T* out) const {
+ return ToInteger<T>().Value(out);
+ }
+
+ /// \brief Convert to a floating-point number (scaled)
+ float ToFloat(int32_t scale) const;
+ /// \brief Convert to a floating-point number (scaled)
+ double ToDouble(int32_t scale) const;
+
+ /// \brief Convert to a floating-point number (scaled)
+ template <typename T>
+ T ToReal(int32_t scale) const {
+ return ToRealConversion<T>::ToReal(*this, scale);
+ }
+
+ friend ARROW_EXPORT std::ostream& operator<<(std::ostream& os,
+ const Decimal128& decimal);
+
+ private:
+ /// Converts internal error code to Status
+ Status ToArrowStatus(DecimalStatus dstatus) const;
+
+ template <typename T>
+ struct ToRealConversion {};
+};
+
+template <>
+struct Decimal128::ToRealConversion<float> {
+ static float ToReal(const Decimal128& dec, int32_t scale) { return dec.ToFloat(scale); }
+};
+
+template <>
+struct Decimal128::ToRealConversion<double> {
+ static double ToReal(const Decimal128& dec, int32_t scale) {
+ return dec.ToDouble(scale);
+ }
+};
+
+/// Represents a signed 256-bit integer in two's complement.
+/// The max decimal precision that can be safely represented is
+/// 76 significant digits.
+///
+/// The implementation is split into two parts :
+///
+/// 1. BasicDecimal256
+/// - can be safely compiled to IR without references to libstdc++.
+/// 2. Decimal256
+/// - (TODO) has additional functionality on top of BasicDecimal256 to deal with
+/// strings and streams.
+class ARROW_EXPORT Decimal256 : public BasicDecimal256 {
+ public:
+ /// \cond FALSE
+ // (need to avoid a duplicate definition in Sphinx)
+ using BasicDecimal256::BasicDecimal256;
+ /// \endcond
+
+ /// \brief constructor creates a Decimal256 from a BasicDecimal256.
+ constexpr Decimal256(const BasicDecimal256& value) noexcept : BasicDecimal256(value) {}
+
+ /// \brief Parse the number from a base 10 string representation.
+ explicit Decimal256(const std::string& value);
+
+ /// \brief Empty constructor creates a Decimal256 with a value of 0.
+ // This is required on some older compilers.
+ constexpr Decimal256() noexcept : BasicDecimal256() {}
+
+ /// \brief Convert the Decimal256 value to a base 10 decimal string with the given
+ /// scale.
+ std::string ToString(int32_t scale) const;
+
+ /// \brief Convert the value to an integer string
+ std::string ToIntegerString() const;
+
+ /// \brief Convert a decimal string to a Decimal256 value, optionally including
+ /// precision and scale if they're passed in and not null.
+ static Status FromString(const util::string_view& s, Decimal256* out,
+ int32_t* precision, int32_t* scale = NULLPTR);
+ static Status FromString(const std::string& s, Decimal256* out, int32_t* precision,
+ int32_t* scale = NULLPTR);
+ static Status FromString(const char* s, Decimal256* out, int32_t* precision,
+ int32_t* scale = NULLPTR);
+ static Result<Decimal256> FromString(const util::string_view& s);
+ static Result<Decimal256> FromString(const std::string& s);
+ static Result<Decimal256> FromString(const char* s);
+
+ /// \brief Convert Decimal256 from one scale to another
+ Result<Decimal256> Rescale(int32_t original_scale, int32_t new_scale) const {
+ Decimal256 out;
+ auto dstatus = BasicDecimal256::Rescale(original_scale, new_scale, &out);
+ ARROW_RETURN_NOT_OK(ToArrowStatus(dstatus));
+ return std::move(out);
+ }
+
+ /// Divide this number by right and return the result.
+ ///
+ /// This operation is not destructive.
+ /// The answer rounds to zero. Signs work like:
+ /// 21 / 5 -> 4, 1
+ /// -21 / 5 -> -4, -1
+ /// 21 / -5 -> -4, 1
+ /// -21 / -5 -> 4, -1
+ /// \param[in] divisor the number to divide by
+ /// \return the pair of the quotient and the remainder
+ Result<std::pair<Decimal256, Decimal256>> Divide(const Decimal256& divisor) const {
+ std::pair<Decimal256, Decimal256> result;
+ auto dstatus = BasicDecimal256::Divide(divisor, &result.first, &result.second);
+ ARROW_RETURN_NOT_OK(ToArrowStatus(dstatus));
+ return std::move(result);
+ }
+
+ /// \brief Convert from a big-endian byte representation. The length must be
+ /// between 1 and 32.
+ /// \return error status if the length is an invalid value
+ static Result<Decimal256> FromBigEndian(const uint8_t* data, int32_t length);
+
+ static Result<Decimal256> FromReal(double real, int32_t precision, int32_t scale);
+ static Result<Decimal256> FromReal(float real, int32_t precision, int32_t scale);
+
+ /// \brief Convert to a floating-point number (scaled).
+ /// May return infinity in case of overflow.
+ float ToFloat(int32_t scale) const;
+ /// \brief Convert to a floating-point number (scaled)
+ double ToDouble(int32_t scale) const;
+
+ /// \brief Convert to a floating-point number (scaled)
+ template <typename T>
+ T ToReal(int32_t scale) const {
+ return ToRealConversion<T>::ToReal(*this, scale);
+ }
+
+ friend ARROW_EXPORT std::ostream& operator<<(std::ostream& os,
+ const Decimal256& decimal);
+
+ private:
+ /// Converts internal error code to Status
+ Status ToArrowStatus(DecimalStatus dstatus) const;
+
+ template <typename T>
+ struct ToRealConversion {};
+};
+
+template <>
+struct Decimal256::ToRealConversion<float> {
+ static float ToReal(const Decimal256& dec, int32_t scale) { return dec.ToFloat(scale); }
+};
+
+template <>
+struct Decimal256::ToRealConversion<double> {
+ static double ToReal(const Decimal256& dec, int32_t scale) {
+ return dec.ToDouble(scale);
+ }
+};
+
+/// For an integer type, return the max number of decimal digits
+/// (=minimal decimal precision) it can represent.
+inline Result<int32_t> MaxDecimalDigitsForInteger(Type::type type_id) {
+ switch (type_id) {
+ case Type::INT8:
+ case Type::UINT8:
+ return 3;
+ case Type::INT16:
+ case Type::UINT16:
+ return 5;
+ case Type::INT32:
+ case Type::UINT32:
+ return 10;
+ case Type::INT64:
+ case Type::UINT64:
+ return 19;
+ default:
+ break;
+ }
+ return Status::Invalid("Not an integer type: ", type_id);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/decimal_benchmark.cc b/src/arrow/cpp/src/arrow/util/decimal_benchmark.cc
new file mode 100644
index 000000000..ddcc1528f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/decimal_benchmark.cc
@@ -0,0 +1,282 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <string>
+#include <vector>
+
+#include "arrow/util/decimal.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace Decimal {
+
+static const std::vector<std::string>& GetValuesAsString() {
+ static const std::vector<std::string> kValues = {"0",
+ "1.23",
+ "12.345e6",
+ "-12.345e-6",
+ "123456789.123456789",
+ "1231234567890.451234567890"};
+ return kValues;
+}
+
+struct DecimalValueAndScale {
+ Decimal128 decimal;
+ int32_t scale;
+};
+
+static std::vector<DecimalValueAndScale> GetDecimalValuesAndScales() {
+ const std::vector<std::string>& value_strs = GetValuesAsString();
+ std::vector<DecimalValueAndScale> result(value_strs.size());
+ for (size_t i = 0; i < value_strs.size(); ++i) {
+ int32_t precision;
+ ARROW_CHECK_OK(Decimal128::FromString(value_strs[i], &result[i].decimal,
+ &result[i].scale, &precision));
+ }
+ return result;
+}
+
+static void FromString(benchmark::State& state) { // NOLINT non-const reference
+ const std::vector<std::string>& values = GetValuesAsString();
+ for (auto _ : state) {
+ for (const auto& value : values) {
+ Decimal128 dec;
+ int32_t scale, precision;
+ benchmark::DoNotOptimize(Decimal128::FromString(value, &dec, &scale, &precision));
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * values.size());
+}
+
+static void ToString(benchmark::State& state) { // NOLINT non-const reference
+ static const std::vector<DecimalValueAndScale> values = GetDecimalValuesAndScales();
+ for (auto _ : state) {
+ for (const DecimalValueAndScale& item : values) {
+ benchmark::DoNotOptimize(item.decimal.ToString(item.scale));
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * values.size());
+}
+
+constexpr int32_t kValueSize = 10;
+
+static void BinaryCompareOp(benchmark::State& state) { // NOLINT non-const reference
+ std::vector<BasicDecimal128> v1, v2;
+ for (int x = 0; x < kValueSize; x++) {
+ v1.emplace_back(100 + x, 100 + x);
+ v2.emplace_back(200 + x, 200 + x);
+ }
+ for (auto _ : state) {
+ for (int x = 0; x < kValueSize; x += 4) {
+ benchmark::DoNotOptimize(v1[x] == v2[x]);
+ benchmark::DoNotOptimize(v1[x + 1] <= v2[x + 1]);
+ benchmark::DoNotOptimize(v1[x + 2] >= v2[x + 2]);
+ benchmark::DoNotOptimize(v1[x + 3] >= v1[x + 3]);
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * kValueSize);
+}
+
+static void BinaryCompareOpConstant(
+ benchmark::State& state) { // NOLINT non-const reference
+ std::vector<BasicDecimal128> v1;
+ for (int x = 0; x < kValueSize; x++) {
+ v1.emplace_back(100 + x, 100 + x);
+ }
+ BasicDecimal128 constant(313, 212);
+ for (auto _ : state) {
+ for (int x = 0; x < kValueSize; x += 4) {
+ benchmark::DoNotOptimize(v1[x] == constant);
+ benchmark::DoNotOptimize(v1[x + 1] <= constant);
+ benchmark::DoNotOptimize(v1[x + 2] >= constant);
+ benchmark::DoNotOptimize(v1[x + 3] != constant);
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * kValueSize);
+}
+
+static void BinaryMathOpAggregate(
+ benchmark::State& state) { // NOLINT non-const reference
+ std::vector<BasicDecimal128> v;
+ for (int x = 0; x < kValueSize; x++) {
+ v.emplace_back(100 + x, 100 + x);
+ }
+
+ for (auto _ : state) {
+ BasicDecimal128 result;
+ for (int x = 0; x < 100; x++) {
+ result += v[x];
+ }
+ benchmark::DoNotOptimize(result);
+ }
+ state.SetItemsProcessed(state.iterations() * kValueSize);
+}
+
+static void BinaryMathOpAdd128(benchmark::State& state) { // NOLINT non-const reference
+ std::vector<BasicDecimal128> v1, v2;
+ for (int x = 0; x < kValueSize; x++) {
+ v1.emplace_back(100 + x, 100 + x);
+ v2.emplace_back(200 + x, 200 + x);
+ }
+
+ for (auto _ : state) {
+ for (int x = 0; x < kValueSize; ++x) {
+ benchmark::DoNotOptimize(v1[x] + v2[x]);
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * kValueSize);
+}
+
+static void BinaryMathOpMultiply128(
+ benchmark::State& state) { // NOLINT non-const reference
+ std::vector<BasicDecimal128> v1, v2;
+ for (int x = 0; x < kValueSize; x++) {
+ v1.emplace_back(100 + x, 100 + x);
+ v2.emplace_back(200 + x, 200 + x);
+ }
+
+ for (auto _ : state) {
+ for (int x = 0; x < kValueSize; ++x) {
+ benchmark::DoNotOptimize(v1[x] * v2[x]);
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * kValueSize);
+}
+
+static void BinaryMathOpDivide128(
+ benchmark::State& state) { // NOLINT non-const reference
+ std::vector<BasicDecimal128> v1, v2;
+ for (int x = 0; x < kValueSize; x++) {
+ v1.emplace_back(100 + x, 100 + x);
+ v2.emplace_back(200 + x, 200 + x);
+ }
+
+ for (auto _ : state) {
+ for (int x = 0; x < kValueSize; ++x) {
+ benchmark::DoNotOptimize(v1[x] / v2[x]);
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * kValueSize);
+}
+
+static void BinaryMathOpAdd256(benchmark::State& state) { // NOLINT non-const reference
+ std::vector<BasicDecimal256> v1, v2;
+ for (uint64_t x = 0; x < kValueSize; x++) {
+ v1.push_back(BasicDecimal256({100 + x, 100 + x, 100 + x, 100 + x}));
+ v2.push_back(BasicDecimal256({200 + x, 200 + x, 200 + x, 200 + x}));
+ }
+
+ for (auto _ : state) {
+ for (int x = 0; x < kValueSize; ++x) {
+ benchmark::DoNotOptimize(v1[x] + v2[x]);
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * kValueSize);
+}
+
+static void BinaryMathOpMultiply256(
+ benchmark::State& state) { // NOLINT non-const reference
+ std::vector<BasicDecimal256> v1, v2;
+ for (uint64_t x = 0; x < kValueSize; x++) {
+ v1.push_back(BasicDecimal256({100 + x, 100 + x, 100 + x, 100 + x}));
+ v2.push_back(BasicDecimal256({200 + x, 200 + x, 200 + x, 200 + x}));
+ }
+
+ for (auto _ : state) {
+ for (int x = 0; x < kValueSize; ++x) {
+ benchmark::DoNotOptimize(v1[x] * v2[x]);
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * kValueSize);
+}
+
+static void BinaryMathOpDivide256(
+ benchmark::State& state) { // NOLINT non-const reference
+ std::vector<BasicDecimal256> v1, v2;
+ for (uint64_t x = 0; x < kValueSize; x++) {
+ v1.push_back(BasicDecimal256({100 + x, 100 + x, 100 + x, 100 + x}));
+ v2.push_back(BasicDecimal256({200 + x, 200 + x, 200 + x, 200 + x}));
+ }
+
+ for (auto _ : state) {
+ for (int x = 0; x < kValueSize; ++x) {
+ benchmark::DoNotOptimize(v1[x] / v2[x]);
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * kValueSize);
+}
+
+static void UnaryOp(benchmark::State& state) { // NOLINT non-const reference
+ std::vector<BasicDecimal128> v;
+ for (int x = 0; x < kValueSize; x++) {
+ v.emplace_back(100 + x, 100 + x);
+ }
+
+ for (auto _ : state) {
+ for (int x = 0; x < kValueSize; x += 2) {
+ benchmark::DoNotOptimize(v[x].Abs());
+ benchmark::DoNotOptimize(v[x + 1].Negate());
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * kValueSize);
+}
+
+static void Constants(benchmark::State& state) { // NOLINT non-const reference
+ BasicDecimal128 d1(-546, 123), d2(-123, 456);
+ for (auto _ : state) {
+ benchmark::DoNotOptimize(BasicDecimal128::GetMaxValue() - d1);
+ benchmark::DoNotOptimize(BasicDecimal128::GetScaleMultiplier(3) + d2);
+ }
+ state.SetItemsProcessed(state.iterations() * 2);
+}
+
+static void BinaryBitOp(benchmark::State& state) { // NOLINT non-const reference
+ std::vector<BasicDecimal128> v1, v2;
+ for (int x = 0; x < kValueSize; x++) {
+ v1.emplace_back(100 + x, 100 + x);
+ v2.emplace_back(200 + x, 200 + x);
+ }
+
+ for (auto _ : state) {
+ for (int x = 0; x < kValueSize; x += 2) {
+ benchmark::DoNotOptimize(v1[x] |= v2[x]);
+ benchmark::DoNotOptimize(v1[x + 1] &= v2[x + 1]);
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * kValueSize);
+}
+
+BENCHMARK(FromString);
+BENCHMARK(ToString);
+BENCHMARK(BinaryMathOpAdd128);
+BENCHMARK(BinaryMathOpMultiply128);
+BENCHMARK(BinaryMathOpDivide128);
+BENCHMARK(BinaryMathOpAdd256);
+BENCHMARK(BinaryMathOpMultiply256);
+BENCHMARK(BinaryMathOpDivide256);
+BENCHMARK(BinaryMathOpAggregate);
+BENCHMARK(BinaryCompareOp);
+BENCHMARK(BinaryCompareOpConstant);
+BENCHMARK(UnaryOp);
+BENCHMARK(Constants);
+BENCHMARK(BinaryBitOp);
+
+} // namespace Decimal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/decimal_test.cc b/src/arrow/cpp/src/arrow/util/decimal_test.cc
new file mode 100644
index 000000000..75716f943
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/decimal_test.cc
@@ -0,0 +1,1939 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <array>
+#include <cmath>
+#include <cstdint>
+#include <ostream>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include <boost/multiprecision/cpp_int.hpp>
+
+#include "arrow/array.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/int128_internal.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::int128_t;
+using internal::uint128_t;
+
+using DecimalTypes = ::testing::Types<Decimal128, Decimal256>;
+
+static const int128_t kInt128Max =
+ (static_cast<int128_t>(INT64_MAX) << 64) + static_cast<int128_t>(UINT64_MAX);
+
+template <typename DecimalType>
+void AssertDecimalFromString(const std::string& s, const DecimalType& expected,
+ int32_t expected_precision, int32_t expected_scale) {
+ ARROW_SCOPED_TRACE("s = '", s, "'");
+ DecimalType d;
+ int32_t precision, scale;
+ ASSERT_OK(DecimalType::FromString(s, &d, &precision, &scale));
+ EXPECT_EQ(expected, d);
+ EXPECT_EQ(expected_precision, precision);
+ EXPECT_EQ(expected_scale, scale);
+}
+
+// Assert that the low bits of an array of integers are equal to `expected_low`,
+// and that all other bits are equal to `expected_high`.
+template <typename T, size_t N, typename U, typename V>
+void AssertArrayBits(const std::array<T, N>& a, U expected_low, V expected_high) {
+ EXPECT_EQ(a[0], expected_low);
+ for (size_t i = 1; i < N; ++i) {
+ EXPECT_EQ(a[i], expected_high);
+ }
+}
+
+Decimal128 Decimal128FromLE(const std::array<uint64_t, 2>& a) {
+ return Decimal128(Decimal128::LittleEndianArray, a);
+}
+
+Decimal256 Decimal256FromLE(const std::array<uint64_t, 4>& a) {
+ return Decimal256(Decimal256::LittleEndianArray, a);
+}
+
+template <typename DecimalType>
+struct DecimalTraits {};
+
+template <>
+struct DecimalTraits<Decimal128> {
+ using ArrowType = Decimal128Type;
+};
+
+template <>
+struct DecimalTraits<Decimal256> {
+ using ArrowType = Decimal256Type;
+};
+
+template <typename DecimalType>
+class DecimalFromStringTest : public ::testing::Test {
+ public:
+ using ArrowType = typename DecimalTraits<DecimalType>::ArrowType;
+ using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+
+ void TestBasics() { AssertDecimalFromString("234.23445", DecimalType(23423445), 8, 5); }
+
+ void TestStringStartingWithPlus() {
+ AssertDecimalFromString("+234.567", DecimalType(234567), 6, 3);
+ AssertDecimalFromString("+2342394230592.232349023094",
+ DecimalType("2342394230592232349023094"), 25, 12);
+ }
+
+ void TestInvalidInput() {
+ for (const std::string invalid_value :
+ {"-", "0.0.0", "0-13-32", "a", "-23092.235-", "-+23092.235", "+-23092.235",
+ "00a", "1e1a", "0.00123D/3", "1.23eA8", "1.23E+3A", "-1.23E--5",
+ "1.2345E+++07"}) {
+ ARROW_SCOPED_TRACE("invalid_value = '", invalid_value, "'");
+ ASSERT_RAISES(Invalid, Decimal128::FromString(invalid_value));
+ }
+ }
+
+ void TestLeadingZerosNoDecimalPoint() {
+ AssertDecimalFromString("0000000", DecimalType(0), 0, 0);
+ }
+
+ void TestLeadingZerosDecimalPoint() {
+ AssertDecimalFromString("000.0000", DecimalType(0), 4, 4);
+ }
+
+ void TestNoLeadingZerosDecimalPoint() {
+ AssertDecimalFromString(".00000", DecimalType(0), 5, 5);
+ }
+
+ void TestNoDecimalPointExponent() {
+ AssertDecimalFromString("1E1", DecimalType(10), 2, 0);
+ }
+
+ void TestWithExponentAndNullptrScale() {
+ const DecimalType expected_value(123);
+ ASSERT_OK_AND_EQ(expected_value, DecimalType::FromString("1.23E-8"));
+ }
+
+ void TestSmallValues() {
+ struct TestValue {
+ std::string s;
+ int64_t expected;
+ int32_t expected_precision;
+ int32_t expected_scale;
+ };
+ for (const auto& tv : std::vector<TestValue>{{"12.3", 123LL, 3, 1},
+ {"0.00123", 123LL, 5, 5},
+ {"1.23E-8", 123LL, 3, 10},
+ {"-1.23E-8", -123LL, 3, 10},
+ {"1.23E+3", 1230LL, 4, 0},
+ {"-1.23E+3", -1230LL, 4, 0},
+ {"1.23E+5", 123000LL, 6, 0},
+ {"1.2345E+7", 12345000LL, 8, 0},
+ {"1.23e-8", 123LL, 3, 10},
+ {"-1.23e-8", -123LL, 3, 10},
+ {"1.23e+3", 1230LL, 4, 0},
+ {"-1.23e+3", -1230LL, 4, 0},
+ {"1.23e+5", 123000LL, 6, 0},
+ {"1.2345e+7", 12345000LL, 8, 0}}) {
+ ARROW_SCOPED_TRACE("s = '", tv.s, "'");
+ AssertDecimalFromString(tv.s, DecimalType(tv.expected), tv.expected_precision,
+ tv.expected_scale);
+ }
+ }
+
+ void CheckRandomValuesRoundTrip(int32_t precision, int32_t scale) {
+ auto rnd = random::RandomArrayGenerator(42);
+ const auto ty = std::make_shared<ArrowType>(precision, scale);
+ const auto array = rnd.ArrayOf(ty, 100, /*null_probability=*/0.0);
+ for (int64_t i = 0; i < array->length(); ++i) {
+ ASSERT_OK_AND_ASSIGN(auto scalar, array->GetScalar(i));
+ const DecimalType& dec_value = checked_cast<const ScalarType&>(*scalar).value;
+ const auto s = dec_value.ToString(scale);
+ ASSERT_OK_AND_ASSIGN(auto round_tripped, DecimalType::FromString(s));
+ ASSERT_EQ(dec_value, round_tripped);
+ }
+ }
+
+ void TestRandomSmallValuesRoundTrip() {
+ for (int32_t scale : {0, 2, 9}) {
+ ARROW_SCOPED_TRACE("scale = ", scale);
+ CheckRandomValuesRoundTrip(9, scale);
+ }
+ }
+
+ void TestRandomValuesRoundTrip() {
+ const auto max_scale = DecimalType::kMaxScale;
+ for (int32_t scale : {0, 3, max_scale / 2, max_scale}) {
+ ARROW_SCOPED_TRACE("scale = ", scale);
+ CheckRandomValuesRoundTrip(DecimalType::kMaxPrecision, scale);
+ }
+ }
+};
+
+TYPED_TEST_SUITE(DecimalFromStringTest, DecimalTypes);
+
+TYPED_TEST(DecimalFromStringTest, Basics) { this->TestBasics(); }
+
+TYPED_TEST(DecimalFromStringTest, StringStartingWithPlus) {
+ this->TestStringStartingWithPlus();
+}
+
+TYPED_TEST(DecimalFromStringTest, InvalidInput) { this->TestInvalidInput(); }
+
+TYPED_TEST(DecimalFromStringTest, LeadingZerosDecimalPoint) {
+ this->TestLeadingZerosDecimalPoint();
+}
+
+TYPED_TEST(DecimalFromStringTest, LeadingZerosNoDecimalPoint) {
+ this->TestLeadingZerosNoDecimalPoint();
+}
+
+TYPED_TEST(DecimalFromStringTest, NoLeadingZerosDecimalPoint) {
+ this->TestNoLeadingZerosDecimalPoint();
+}
+
+TYPED_TEST(DecimalFromStringTest, NoDecimalPointExponent) {
+ this->TestNoDecimalPointExponent();
+}
+
+TYPED_TEST(DecimalFromStringTest, WithExponentAndNullptrScale) {
+ this->TestWithExponentAndNullptrScale();
+}
+
+TYPED_TEST(DecimalFromStringTest, SmallValues) { this->TestSmallValues(); }
+
+TYPED_TEST(DecimalFromStringTest, RandomSmallValuesRoundTrip) {
+ this->TestRandomSmallValuesRoundTrip();
+}
+
+TYPED_TEST(DecimalFromStringTest, RandomValuesRoundTrip) {
+ this->TestRandomValuesRoundTrip();
+}
+
+TEST(Decimal128Test, TestFromStringDecimal128) {
+ std::string string_value("-23049223942343532412");
+ Decimal128 result(string_value);
+ Decimal128 expected(static_cast<int64_t>(-230492239423435324));
+ ASSERT_EQ(result, expected * 100 - 12);
+
+ // Sanity check that our number is actually using more than 64 bits
+ ASSERT_NE(result.high_bits(), 0);
+}
+
+TEST(Decimal128Test, TestFromDecimalString128) {
+ std::string string_value("-23049223942343.532412");
+ Decimal128 result;
+ ASSERT_OK_AND_ASSIGN(result, Decimal128::FromString(string_value));
+ Decimal128 expected(static_cast<int64_t>(-230492239423435324));
+ ASSERT_EQ(result, expected * 100 - 12);
+
+ // Sanity check that our number is actually using more than 64 bits
+ ASSERT_NE(result.high_bits(), 0);
+}
+
+TEST(Decimal128Test, TestStringRoundTrip) {
+ static constexpr uint64_t kTestBits[] = {
+ 0,
+ 1,
+ 999,
+ 1000,
+ std::numeric_limits<int32_t>::max(),
+ (1ull << 31),
+ std::numeric_limits<uint32_t>::max(),
+ (1ull << 32),
+ std::numeric_limits<int64_t>::max(),
+ (1ull << 63),
+ std::numeric_limits<uint64_t>::max(),
+ };
+ static constexpr int32_t kScales[] = {0, 1, 10};
+ for (uint64_t high_bits : kTestBits) {
+ for (uint64_t low_bits : kTestBits) {
+ // When high_bits = 1ull << 63 or std::numeric_limits<uint64_t>::max(), decimal is
+ // negative.
+ Decimal128 decimal(high_bits, low_bits);
+ for (int32_t scale : kScales) {
+ std::string str = decimal.ToString(scale);
+ ASSERT_OK_AND_ASSIGN(Decimal128 result, Decimal128::FromString(str));
+ EXPECT_EQ(decimal, result);
+ }
+ }
+ }
+}
+
+TEST(Decimal128Test, TestDecimal32SignedRoundTrip) {
+ Decimal128 expected("-3402692");
+
+ auto bytes = expected.ToBytes();
+ Decimal128 result(bytes.data());
+ ASSERT_EQ(expected, result);
+}
+
+TEST(Decimal128Test, TestDecimal64SignedRoundTrip) {
+ Decimal128 expected;
+ std::string string_value("-34034293045.921");
+ ASSERT_OK_AND_ASSIGN(expected, Decimal128::FromString(string_value));
+
+ auto bytes = expected.ToBytes();
+ Decimal128 result(bytes.data());
+
+ ASSERT_EQ(expected, result);
+}
+
+TEST(Decimal128Test, TestDecimalStringAndBytesRoundTrip) {
+ Decimal128 expected;
+ std::string string_value("-340282366920938463463374607431.711455");
+ ASSERT_OK_AND_ASSIGN(expected, Decimal128::FromString(string_value));
+
+ std::string expected_string_value("-340282366920938463463374607431711455");
+ Decimal128 expected_underlying_value(expected_string_value);
+
+ ASSERT_EQ(expected, expected_underlying_value);
+
+ auto bytes = expected.ToBytes();
+
+ Decimal128 result(bytes.data());
+
+ ASSERT_EQ(expected, result);
+}
+
+/*
+ Note: generating a number of 64-bit decimal digits from a bigint:
+
+ >>> def dec(x, n):
+ ...: sign = x < 0
+ ...: if sign:
+ ...: x = 2**(64*n) + x
+ ...: a = []
+ ...: for i in range(n-1):
+ ...: x, r = divmod(x, 2**64)
+ ...: a.append(r)
+ ...: assert x < 2**64
+ ...: a.append(x)
+ ...: return a
+ ...:
+ >>> dec(10**37, 2)
+ [68739955140067328, 542101086242752217]
+ >>> dec(-10**37, 2)
+ [18378004118569484288, 17904642987466799398]
+ >>> dec(10**75, 4)
+ [0, 10084168908774762496, 12965995782233477362, 159309191113245227]
+ >>> dec(-10**75, 4)
+ [0, 8362575164934789120, 5480748291476074253, 18287434882596306388]
+*/
+
+TEST(Decimal128Test, FromStringLimits) {
+ // Positive / zero exponent
+ AssertDecimalFromString(
+ "1e37", Decimal128FromLE({68739955140067328ULL, 542101086242752217ULL}), 38, 0);
+ AssertDecimalFromString(
+ "-1e37", Decimal128FromLE({18378004118569484288ULL, 17904642987466799398ULL}), 38,
+ 0);
+ AssertDecimalFromString(
+ "9.87e37", Decimal128FromLE({15251391175463010304ULL, 5350537721215964381ULL}), 38,
+ 0);
+ AssertDecimalFromString(
+ "-9.87e37", Decimal128FromLE({3195352898246541312ULL, 13096206352493587234ULL}), 38,
+ 0);
+ AssertDecimalFromString(
+ "12345678901234567890123456789012345678",
+ Decimal128FromLE({14143994781733811022ULL, 669260594276348691ULL}), 38, 0);
+ AssertDecimalFromString(
+ "-12345678901234567890123456789012345678",
+ Decimal128FromLE({4302749291975740594ULL, 17777483479433202924ULL}), 38, 0);
+
+ // "9..9" (38 times)
+ const auto dec38times9pos =
+ Decimal128FromLE({687399551400673279ULL, 5421010862427522170ULL});
+ // "-9..9" (38 times)
+ const auto dec38times9neg =
+ Decimal128FromLE({17759344522308878337ULL, 13025733211282029445ULL});
+
+ AssertDecimalFromString("99999999999999999999999999999999999999", dec38times9pos, 38,
+ 0);
+ AssertDecimalFromString("-99999999999999999999999999999999999999", dec38times9neg, 38,
+ 0);
+ AssertDecimalFromString("9.9999999999999999999999999999999999999e37", dec38times9pos,
+ 38, 0);
+ AssertDecimalFromString("-9.9999999999999999999999999999999999999e37", dec38times9neg,
+ 38, 0);
+
+ // Positive / zero exponent, precision too large for a non-negative scale
+ ASSERT_RAISES(Invalid, Decimal128::FromString("1e39"));
+ ASSERT_RAISES(Invalid, Decimal128::FromString("-1e39"));
+ ASSERT_RAISES(Invalid, Decimal128::FromString("9e39"));
+ ASSERT_RAISES(Invalid, Decimal128::FromString("-9e39"));
+ ASSERT_RAISES(Invalid, Decimal128::FromString("9.9e40"));
+ ASSERT_RAISES(Invalid, Decimal128::FromString("-9.9e40"));
+ // XXX conversion overflows are currently not detected
+ // ASSERT_RAISES(Invalid, Decimal128::FromString("99e38"));
+ // ASSERT_RAISES(Invalid, Decimal128::FromString("-99e38"));
+ // ASSERT_RAISES(Invalid,
+ // Decimal128::FromString("999999999999999999999999999999999999999e1"));
+ // ASSERT_RAISES(Invalid,
+ // Decimal128::FromString("-999999999999999999999999999999999999999e1"));
+ // ASSERT_RAISES(Invalid,
+ // Decimal128::FromString("999999999999999999999999999999999999999"));
+
+ // No exponent, many fractional digits
+ AssertDecimalFromString("9.9999999999999999999999999999999999999", dec38times9pos, 38,
+ 37);
+ AssertDecimalFromString("-9.9999999999999999999999999999999999999", dec38times9neg, 38,
+ 37);
+ AssertDecimalFromString("0.99999999999999999999999999999999999999", dec38times9pos, 38,
+ 38);
+ AssertDecimalFromString("-0.99999999999999999999999999999999999999", dec38times9neg, 38,
+ 38);
+
+ // Negative exponent
+ AssertDecimalFromString("1e-38", Decimal128FromLE({1, 0}), 1, 38);
+ AssertDecimalFromString(
+ "-1e-38", Decimal128FromLE({18446744073709551615ULL, 18446744073709551615ULL}), 1,
+ 38);
+ AssertDecimalFromString("9.99e-36", Decimal128FromLE({999, 0}), 3, 38);
+ AssertDecimalFromString(
+ "-9.99e-36", Decimal128FromLE({18446744073709550617ULL, 18446744073709551615ULL}),
+ 3, 38);
+ AssertDecimalFromString("987e-38", Decimal128FromLE({987, 0}), 3, 38);
+ AssertDecimalFromString(
+ "-987e-38", Decimal128FromLE({18446744073709550629ULL, 18446744073709551615ULL}), 3,
+ 38);
+ AssertDecimalFromString("99999999999999999999999999999999999999e-37", dec38times9pos,
+ 38, 37);
+ AssertDecimalFromString("-99999999999999999999999999999999999999e-37", dec38times9neg,
+ 38, 37);
+ AssertDecimalFromString("99999999999999999999999999999999999999e-38", dec38times9pos,
+ 38, 38);
+ AssertDecimalFromString("-99999999999999999999999999999999999999e-38", dec38times9neg,
+ 38, 38);
+}
+
+TEST(Decimal256Test, FromStringLimits) {
+ // Positive / zero exponent
+ AssertDecimalFromString(
+ "1e75",
+ Decimal256FromLE(
+ {0, 10084168908774762496ULL, 12965995782233477362ULL, 159309191113245227ULL}),
+ 76, 0);
+ AssertDecimalFromString(
+ "-1e75",
+ Decimal256FromLE(
+ {0, 8362575164934789120ULL, 5480748291476074253ULL, 18287434882596306388ULL}),
+ 76, 0);
+ AssertDecimalFromString(
+ "9.87e75",
+ Decimal256FromLE(
+ {0, 3238743064843046400ULL, 7886074450795240548ULL, 1572381716287730397ULL}),
+ 76, 0);
+ AssertDecimalFromString(
+ "-9.87e75",
+ Decimal256FromLE(
+ {0, 15208001008866505216ULL, 10560669622914311067ULL, 16874362357421821218ULL}),
+ 76, 0);
+
+ AssertDecimalFromString(
+ "1234567890123456789012345678901234567890123456789012345678901234567890123456",
+ Decimal256FromLE({17877984925544397504ULL, 5352188884907840935ULL,
+ 234631617561833724ULL, 196678011949953713ULL}),
+ 76, 0);
+ AssertDecimalFromString(
+ "-1234567890123456789012345678901234567890123456789012345678901234567890123456",
+ Decimal256FromLE({568759148165154112ULL, 13094555188801710680ULL,
+ 18212112456147717891ULL, 18250066061759597902ULL}),
+ 76, 0);
+
+ // "9..9" (76 times)
+ const auto dec76times9pos =
+ Decimal256FromLE({18446744073709551615ULL, 8607968719199866879ULL,
+ 532749306367912313ULL, 1593091911132452277ULL});
+ // "-9..9" (76 times)
+ const auto dec76times9neg = Decimal256FromLE(
+ {1, 9838775354509684736ULL, 17913994767341639302ULL, 16853652162577099338ULL});
+
+ AssertDecimalFromString(
+ "9999999999999999999999999999999999999999999999999999999999999999999999999999",
+ dec76times9pos, 76, 0);
+ AssertDecimalFromString(
+ "-9999999999999999999999999999999999999999999999999999999999999999999999999999",
+ dec76times9neg, 76, 0);
+ AssertDecimalFromString(
+ "9.999999999999999999999999999999999999999999999999999999999999999999999999999e75",
+ dec76times9pos, 76, 0);
+ AssertDecimalFromString(
+ "-9.999999999999999999999999999999999999999999999999999999999999999999999999999e75",
+ dec76times9neg, 76, 0);
+
+ // Positive / zero exponent, precision too large for a non-negative scale
+ ASSERT_RAISES(Invalid, Decimal256::FromString("1e77"));
+ ASSERT_RAISES(Invalid, Decimal256::FromString("-1e77"));
+ ASSERT_RAISES(Invalid, Decimal256::FromString("9e77"));
+ ASSERT_RAISES(Invalid, Decimal256::FromString("-9e77"));
+ ASSERT_RAISES(Invalid, Decimal256::FromString("9.9e78"));
+ ASSERT_RAISES(Invalid, Decimal256::FromString("-9.9e78"));
+
+ // XXX conversion overflows are currently not detected
+ // ASSERT_RAISES(Invalid, Decimal256::FromString("99e76"));
+ // ASSERT_RAISES(Invalid, Decimal256::FromString("-99e76"));
+ // ASSERT_RAISES(Invalid,
+ // Decimal256::FromString("9999999999999999999999999999999999999999999999999999999999999999999999999999e1"));
+ // ASSERT_RAISES(Invalid,
+ // Decimal256::FromString("-9999999999999999999999999999999999999999999999999999999999999999999999999999e1"));
+ // ASSERT_RAISES(Invalid,
+ // Decimal256::FromString("99999999999999999999999999999999999999999999999999999999999999999999999999999"));
+
+ // No exponent, many fractional digits
+ AssertDecimalFromString(
+ "9.999999999999999999999999999999999999999999999999999999999999999999999999999",
+ dec76times9pos, 76, 75);
+ AssertDecimalFromString(
+ "-9.999999999999999999999999999999999999999999999999999999999999999999999999999",
+ dec76times9neg, 76, 75);
+ AssertDecimalFromString(
+ "0.9999999999999999999999999999999999999999999999999999999999999999999999999999",
+ dec76times9pos, 76, 76);
+ AssertDecimalFromString(
+ "-0.9999999999999999999999999999999999999999999999999999999999999999999999999999",
+ dec76times9neg, 76, 76);
+
+ // Negative exponent
+ AssertDecimalFromString("1e-76", Decimal256FromLE({1, 0, 0, 0}), 1, 76);
+ AssertDecimalFromString(
+ "-1e-76",
+ Decimal256FromLE({18446744073709551615ULL, 18446744073709551615ULL,
+ 18446744073709551615ULL, 18446744073709551615ULL}),
+ 1, 76);
+ AssertDecimalFromString("9.99e-74", Decimal256FromLE({999, 0, 0, 0}), 3, 76);
+ AssertDecimalFromString(
+ "-9.99e-74",
+ Decimal256FromLE({18446744073709550617ULL, 18446744073709551615ULL,
+ 18446744073709551615ULL, 18446744073709551615ULL}),
+ 3, 76);
+ AssertDecimalFromString("987e-76", Decimal256FromLE({987, 0, 0, 0}), 3, 76);
+ AssertDecimalFromString(
+ "-987e-76",
+ Decimal256FromLE({18446744073709550629ULL, 18446744073709551615ULL,
+ 18446744073709551615ULL, 18446744073709551615ULL}),
+ 3, 76);
+ AssertDecimalFromString(
+ "9999999999999999999999999999999999999999999999999999999999999999999999999999e-75",
+ dec76times9pos, 76, 75);
+ AssertDecimalFromString(
+ "-9999999999999999999999999999999999999999999999999999999999999999999999999999e-75",
+ dec76times9neg, 76, 75);
+ AssertDecimalFromString(
+ "9999999999999999999999999999999999999999999999999999999999999999999999999999e-76",
+ dec76times9pos, 76, 76);
+ AssertDecimalFromString(
+ "-9999999999999999999999999999999999999999999999999999999999999999999999999999e-76",
+ dec76times9neg, 76, 76);
+}
+
+template <typename DecimalType>
+class DecimalFromIntegerTest : public ::testing::Test {
+ public:
+ template <typename IntegerType>
+ void CheckConstructFrom() {
+ DecimalType value(IntegerType{42});
+ AssertArrayBits(value.little_endian_array(), 42, 0);
+
+ DecimalType max_value(std::numeric_limits<IntegerType>::max());
+ AssertArrayBits(max_value.little_endian_array(),
+ std::numeric_limits<IntegerType>::max(), 0);
+
+ DecimalType min_value(std::numeric_limits<IntegerType>::min());
+ AssertArrayBits(min_value.little_endian_array(),
+ std::numeric_limits<IntegerType>::min(),
+ (std::is_signed<IntegerType>::value ? -1 : 0));
+ }
+
+ void TestConstructibleFromAnyIntegerType() {
+ CheckConstructFrom<char>(); // NOLINT
+ CheckConstructFrom<signed char>(); // NOLINT
+ CheckConstructFrom<unsigned char>(); // NOLINT
+ CheckConstructFrom<short>(); // NOLINT
+ CheckConstructFrom<unsigned short>(); // NOLINT
+ CheckConstructFrom<int>(); // NOLINT
+ CheckConstructFrom<unsigned int>(); // NOLINT
+ CheckConstructFrom<long>(); // NOLINT
+ CheckConstructFrom<unsigned long>(); // NOLINT
+ CheckConstructFrom<long long>(); // NOLINT
+ CheckConstructFrom<unsigned long long>(); // NOLINT
+ }
+
+ void TestConstructibleFromBool() {
+ {
+ DecimalType value(true);
+ AssertArrayBits(value.little_endian_array(), 1, 0);
+ }
+ {
+ DecimalType value(false);
+ AssertArrayBits(value.little_endian_array(), 0, 0);
+ }
+ }
+};
+
+TYPED_TEST_SUITE(DecimalFromIntegerTest, DecimalTypes);
+
+TYPED_TEST(DecimalFromIntegerTest, ConstructibleFromAnyIntegerType) {
+ this->TestConstructibleFromAnyIntegerType();
+}
+
+TYPED_TEST(DecimalFromIntegerTest, ConstructibleFromBool) {
+ this->TestConstructibleFromBool();
+}
+
+TEST(Decimal128Test, Division) {
+ const std::string expected_string_value("-23923094039234029");
+ const Decimal128 value(expected_string_value);
+ const Decimal128 result(value / 3);
+ const Decimal128 expected_value("-7974364679744676");
+ ASSERT_EQ(expected_value, result);
+}
+
+TEST(Decimal128Test, PrintLargePositiveValue) {
+ const std::string string_value("99999999999999999999999999999999999999");
+ const Decimal128 value(string_value);
+ const std::string printed_value = value.ToIntegerString();
+ ASSERT_EQ(string_value, printed_value);
+}
+
+TEST(Decimal128Test, PrintLargeNegativeValue) {
+ const std::string string_value("-99999999999999999999999999999999999999");
+ const Decimal128 value(string_value);
+ const std::string printed_value = value.ToIntegerString();
+ ASSERT_EQ(string_value, printed_value);
+}
+
+TEST(Decimal128Test, PrintMaxValue) {
+ const std::string string_value("170141183460469231731687303715884105727");
+ const Decimal128 value(string_value);
+ const std::string printed_value = value.ToIntegerString();
+ ASSERT_EQ(string_value, printed_value);
+}
+
+TEST(Decimal128Test, PrintMinValue) {
+ const std::string string_value("-170141183460469231731687303715884105728");
+ const Decimal128 value(string_value);
+ const std::string printed_value = value.ToIntegerString();
+ ASSERT_EQ(string_value, printed_value);
+}
+
+struct ToStringTestParam {
+ int64_t test_value;
+ int32_t scale;
+ std::string expected_string;
+
+ // Avoid Valgrind uninitialized memory reads with the default GTest print routine.
+ friend std::ostream& operator<<(std::ostream& os, const ToStringTestParam& param) {
+ return os << "<value: " << param.test_value << ">";
+ }
+};
+
+static const ToStringTestParam kToStringTestData[] = {
+ {0, -1, "0.E+1"},
+ {0, 0, "0"},
+ {0, 1, "0.0"},
+ {0, 6, "0.000000"},
+ {2, 7, "2.E-7"},
+ {2, -1, "2.E+1"},
+ {2, 0, "2"},
+ {2, 1, "0.2"},
+ {2, 6, "0.000002"},
+ {-2, 7, "-2.E-7"},
+ {-2, 7, "-2.E-7"},
+ {-2, -1, "-2.E+1"},
+ {-2, 0, "-2"},
+ {-2, 1, "-0.2"},
+ {-2, 6, "-0.000002"},
+ {-2, 7, "-2.E-7"},
+ {123, -3, "1.23E+5"},
+ {123, -1, "1.23E+3"},
+ {123, 1, "12.3"},
+ {123, 0, "123"},
+ {123, 5, "0.00123"},
+ {123, 8, "0.00000123"},
+ {123, 9, "1.23E-7"},
+ {123, 10, "1.23E-8"},
+ {-123, -3, "-1.23E+5"},
+ {-123, -1, "-1.23E+3"},
+ {-123, 1, "-12.3"},
+ {-123, 0, "-123"},
+ {-123, 5, "-0.00123"},
+ {-123, 8, "-0.00000123"},
+ {-123, 9, "-1.23E-7"},
+ {-123, 10, "-1.23E-8"},
+ {1000000000, -3, "1.000000000E+12"},
+ {1000000000, -1, "1.000000000E+10"},
+ {1000000000, 0, "1000000000"},
+ {1000000000, 1, "100000000.0"},
+ {1000000000, 5, "10000.00000"},
+ {1000000000, 15, "0.000001000000000"},
+ {1000000000, 16, "1.000000000E-7"},
+ {1000000000, 17, "1.000000000E-8"},
+ {-1000000000, -3, "-1.000000000E+12"},
+ {-1000000000, -1, "-1.000000000E+10"},
+ {-1000000000, 0, "-1000000000"},
+ {-1000000000, 1, "-100000000.0"},
+ {-1000000000, 5, "-10000.00000"},
+ {-1000000000, 15, "-0.000001000000000"},
+ {-1000000000, 16, "-1.000000000E-7"},
+ {-1000000000, 17, "-1.000000000E-8"},
+ {1234567890123456789LL, -3, "1.234567890123456789E+21"},
+ {1234567890123456789LL, -1, "1.234567890123456789E+19"},
+ {1234567890123456789LL, 0, "1234567890123456789"},
+ {1234567890123456789LL, 1, "123456789012345678.9"},
+ {1234567890123456789LL, 5, "12345678901234.56789"},
+ {1234567890123456789LL, 24, "0.000001234567890123456789"},
+ {1234567890123456789LL, 25, "1.234567890123456789E-7"},
+ {-1234567890123456789LL, -3, "-1.234567890123456789E+21"},
+ {-1234567890123456789LL, -1, "-1.234567890123456789E+19"},
+ {-1234567890123456789LL, 0, "-1234567890123456789"},
+ {-1234567890123456789LL, 1, "-123456789012345678.9"},
+ {-1234567890123456789LL, 5, "-12345678901234.56789"},
+ {-1234567890123456789LL, 24, "-0.000001234567890123456789"},
+ {-1234567890123456789LL, 25, "-1.234567890123456789E-7"},
+};
+
+class Decimal128ToStringTest : public ::testing::TestWithParam<ToStringTestParam> {};
+
+TEST_P(Decimal128ToStringTest, ToString) {
+ const ToStringTestParam& param = GetParam();
+ const Decimal128 value(param.test_value);
+ const std::string printed_value = value.ToString(param.scale);
+ ASSERT_EQ(param.expected_string, printed_value);
+}
+
+INSTANTIATE_TEST_SUITE_P(Decimal128ToStringTest, Decimal128ToStringTest,
+ ::testing::ValuesIn(kToStringTestData));
+
+template <typename Decimal, typename Real>
+void CheckDecimalFromReal(Real real, int32_t precision, int32_t scale,
+ const std::string& expected) {
+ ASSERT_OK_AND_ASSIGN(auto dec, Decimal::FromReal(real, precision, scale));
+ ASSERT_EQ(dec.ToString(scale), expected);
+}
+
+template <typename Decimal, typename Real>
+void CheckDecimalFromRealIntegerString(Real real, int32_t precision, int32_t scale,
+ const std::string& expected) {
+ ASSERT_OK_AND_ASSIGN(auto dec, Decimal::FromReal(real, precision, scale));
+ ASSERT_EQ(dec.ToIntegerString(), expected);
+}
+
+template <typename Real>
+struct FromRealTestParam {
+ Real real;
+ int32_t precision;
+ int32_t scale;
+ std::string expected;
+
+ // Avoid Valgrind uninitialized memory reads with the default GTest print routine.
+ friend std::ostream& operator<<(std::ostream& os,
+ const FromRealTestParam<Real>& param) {
+ return os << "<real: " << param.real << ">";
+ }
+};
+
+using FromFloatTestParam = FromRealTestParam<float>;
+using FromDoubleTestParam = FromRealTestParam<double>;
+
+// Common tests for Decimal128::FromReal(T, ...) and Decimal256::FromReal(T, ...)
+template <typename T>
+class TestDecimalFromReal : public ::testing::Test {
+ public:
+ using Decimal = typename T::first_type;
+ using Real = typename T::second_type;
+ using ParamType = FromRealTestParam<Real>;
+
+ void TestSuccess() {
+ const std::vector<ParamType> params{
+ // clang-format off
+ {0.0f, 1, 0, "0"},
+ {-0.0f, 1, 0, "0"},
+ {0.0f, 19, 4, "0.0000"},
+ {-0.0f, 19, 4, "0.0000"},
+ {123.0f, 7, 4, "123.0000"},
+ {-123.0f, 7, 4, "-123.0000"},
+ {456.78f, 7, 4, "456.7800"},
+ {-456.78f, 7, 4, "-456.7800"},
+ {456.784f, 5, 2, "456.78"},
+ {-456.784f, 5, 2, "-456.78"},
+ {456.786f, 5, 2, "456.79"},
+ {-456.786f, 5, 2, "-456.79"},
+ {999.99f, 5, 2, "999.99"},
+ {-999.99f, 5, 2, "-999.99"},
+ {123.0f, 19, 0, "123"},
+ {-123.0f, 19, 0, "-123"},
+ {123.4f, 19, 0, "123"},
+ {-123.4f, 19, 0, "-123"},
+ {123.6f, 19, 0, "124"},
+ {-123.6f, 19, 0, "-124"},
+ // 2**62
+ {4.611686e+18f, 19, 0, "4611686018427387904"},
+ {-4.611686e+18f, 19, 0, "-4611686018427387904"},
+ // 2**63
+ {9.223372e+18f, 19, 0, "9223372036854775808"},
+ {-9.223372e+18f, 19, 0, "-9223372036854775808"},
+ // 2**64
+ {1.8446744e+19f, 20, 0, "18446744073709551616"},
+ {-1.8446744e+19f, 20, 0, "-18446744073709551616"}
+ // clang-format on
+ };
+ for (const ParamType& param : params) {
+ CheckDecimalFromReal<Decimal>(param.real, param.precision, param.scale,
+ param.expected);
+ }
+ }
+
+ void TestErrors() {
+ ASSERT_RAISES(Invalid, Decimal::FromReal(INFINITY, 19, 4));
+ ASSERT_RAISES(Invalid, Decimal::FromReal(-INFINITY, 19, 4));
+ ASSERT_RAISES(Invalid, Decimal::FromReal(NAN, 19, 4));
+ // Overflows
+ ASSERT_RAISES(Invalid, Decimal::FromReal(1000.0, 3, 0));
+ ASSERT_RAISES(Invalid, Decimal::FromReal(-1000.0, 3, 0));
+ ASSERT_RAISES(Invalid, Decimal::FromReal(1000.0, 5, 2));
+ ASSERT_RAISES(Invalid, Decimal::FromReal(-1000.0, 5, 2));
+ ASSERT_RAISES(Invalid, Decimal::FromReal(999.996, 5, 2));
+ ASSERT_RAISES(Invalid, Decimal::FromReal(-999.996, 5, 2));
+ ASSERT_RAISES(Invalid, Decimal::FromReal(1e+38, 38, 0));
+ ASSERT_RAISES(Invalid, Decimal::FromReal(-1e+38, 38, 0));
+ }
+};
+
+using RealTypes =
+ ::testing::Types<std::pair<Decimal128, float>, std::pair<Decimal128, double>,
+ std::pair<Decimal256, float>, std::pair<Decimal256, double>>;
+TYPED_TEST_SUITE(TestDecimalFromReal, RealTypes);
+
+TYPED_TEST(TestDecimalFromReal, TestSuccess) { this->TestSuccess(); }
+
+TYPED_TEST(TestDecimalFromReal, TestErrors) { this->TestErrors(); }
+
+// Tests for Decimal128::FromReal(float, ...) and Decimal256::FromReal(float, ...)
+template <typename T>
+class TestDecimalFromRealFloat : public ::testing::Test {
+ protected:
+ std::vector<FromFloatTestParam> GetValues() {
+ return {// 2**63 + 2**40 (exactly representable in a float's 24 bits of precision)
+ FromFloatTestParam{9.223373e+18f, 19, 0, "9223373136366403584"},
+ FromFloatTestParam{-9.223373e+18f, 19, 0, "-9223373136366403584"},
+ FromFloatTestParam{9.223373e+14f, 19, 4, "922337313636640.3584"},
+ FromFloatTestParam{-9.223373e+14f, 19, 4, "-922337313636640.3584"},
+ // 2**64 - 2**40 (exactly representable in a float)
+ FromFloatTestParam{1.8446743e+19f, 20, 0, "18446742974197923840"},
+ FromFloatTestParam{-1.8446743e+19f, 20, 0, "-18446742974197923840"},
+ // 2**64 + 2**41 (exactly representable in a float)
+ FromFloatTestParam{1.8446746e+19f, 20, 0, "18446746272732807168"},
+ FromFloatTestParam{-1.8446746e+19f, 20, 0, "-18446746272732807168"},
+ FromFloatTestParam{1.8446746e+15f, 20, 4, "1844674627273280.7168"},
+ FromFloatTestParam{-1.8446746e+15f, 20, 4, "-1844674627273280.7168"},
+ // Almost 10**38 (minus 2**103)
+ FromFloatTestParam{9.999999e+37f, 38, 0,
+ "99999986661652122824821048795547566080"},
+ FromFloatTestParam{-9.999999e+37f, 38, 0,
+ "-99999986661652122824821048795547566080"}};
+ }
+};
+TYPED_TEST_SUITE(TestDecimalFromRealFloat, DecimalTypes);
+
+TYPED_TEST(TestDecimalFromRealFloat, SuccessConversion) {
+ for (const auto& param : this->GetValues()) {
+ CheckDecimalFromReal<TypeParam>(param.real, param.precision, param.scale,
+ param.expected);
+ }
+}
+
+TYPED_TEST(TestDecimalFromRealFloat, LargeValues) {
+ // Test the entire float range
+ for (int32_t scale = -38; scale <= 38; ++scale) {
+ float real = std::pow(10.0f, static_cast<float>(scale));
+ CheckDecimalFromRealIntegerString<TypeParam>(real, 1, -scale, "1");
+ }
+ for (int32_t scale = -37; scale <= 36; ++scale) {
+ float real = 123.f * std::pow(10.f, static_cast<float>(scale));
+ CheckDecimalFromRealIntegerString<TypeParam>(real, 2, -scale - 1, "12");
+ CheckDecimalFromRealIntegerString<TypeParam>(real, 3, -scale, "123");
+ CheckDecimalFromRealIntegerString<TypeParam>(real, 4, -scale + 1, "1230");
+ }
+}
+
+// Tests for Decimal128::FromReal(double, ...) and Decimal256::FromReal(double, ...)
+template <typename T>
+class TestDecimalFromRealDouble : public ::testing::Test {
+ protected:
+ std::vector<FromDoubleTestParam> GetValues() {
+ return {// 2**63 + 2**11 (exactly representable in a double's 53 bits of precision)
+ FromDoubleTestParam{9.223372036854778e+18, 19, 0, "9223372036854777856"},
+ FromDoubleTestParam{-9.223372036854778e+18, 19, 0, "-9223372036854777856"},
+ FromDoubleTestParam{9.223372036854778e+10, 19, 8, "92233720368.54777856"},
+ FromDoubleTestParam{-9.223372036854778e+10, 19, 8, "-92233720368.54777856"},
+ // 2**64 - 2**11 (exactly representable in a double)
+ FromDoubleTestParam{1.844674407370955e+19, 20, 0, "18446744073709549568"},
+ FromDoubleTestParam{-1.844674407370955e+19, 20, 0, "-18446744073709549568"},
+ // 2**64 + 2**11 (exactly representable in a double)
+ FromDoubleTestParam{1.8446744073709556e+19, 20, 0, "18446744073709555712"},
+ FromDoubleTestParam{-1.8446744073709556e+19, 20, 0, "-18446744073709555712"},
+ FromDoubleTestParam{1.8446744073709556e+15, 20, 4, "1844674407370955.5712"},
+ FromDoubleTestParam{-1.8446744073709556e+15, 20, 4, "-1844674407370955.5712"},
+ // Almost 10**38 (minus 2**73)
+ FromDoubleTestParam{9.999999999999998e+37, 38, 0,
+ "99999999999999978859343891977453174784"},
+ FromDoubleTestParam{-9.999999999999998e+37, 38, 0,
+ "-99999999999999978859343891977453174784"},
+ FromDoubleTestParam{9.999999999999998e+27, 38, 10,
+ "9999999999999997885934389197.7453174784"},
+ FromDoubleTestParam{-9.999999999999998e+27, 38, 10,
+ "-9999999999999997885934389197.7453174784"}};
+ }
+};
+TYPED_TEST_SUITE(TestDecimalFromRealDouble, DecimalTypes);
+
+TYPED_TEST(TestDecimalFromRealDouble, SuccessConversion) {
+ for (const auto& param : this->GetValues()) {
+ CheckDecimalFromReal<TypeParam>(param.real, param.precision, param.scale,
+ param.expected);
+ }
+}
+
+TYPED_TEST(TestDecimalFromRealDouble, LargeValues) {
+ // Test the entire double range
+ for (int32_t scale = -308; scale <= 308; ++scale) {
+ double real = std::pow(10.0, static_cast<double>(scale));
+ CheckDecimalFromRealIntegerString<TypeParam>(real, 1, -scale, "1");
+ }
+ for (int32_t scale = -307; scale <= 306; ++scale) {
+ double real = 123. * std::pow(10.0, static_cast<double>(scale));
+ CheckDecimalFromRealIntegerString<TypeParam>(real, 2, -scale - 1, "12");
+ CheckDecimalFromRealIntegerString<TypeParam>(real, 3, -scale, "123");
+ CheckDecimalFromRealIntegerString<TypeParam>(real, 4, -scale + 1, "1230");
+ }
+}
+
+// Additional values that only apply to Decimal256
+TEST(TestDecimal256FromRealDouble, ExtremeValues) {
+ const std::vector<FromDoubleTestParam> values = {
+ // Almost 10**76
+ FromDoubleTestParam{9.999999999999999e+75, 76, 0,
+ "999999999999999886366330070006442034959750906670402"
+ "8242075715752105414230016"},
+ FromDoubleTestParam{-9.999999999999999e+75, 76, 0,
+ "-999999999999999886366330070006442034959750906670402"
+ "8242075715752105414230016"},
+ FromDoubleTestParam{9.999999999999999e+65, 76, 10,
+ "999999999999999886366330070006442034959750906670402"
+ "824207571575210.5414230016"},
+ FromDoubleTestParam{-9.999999999999999e+65, 76, 10,
+ "-999999999999999886366330070006442034959750906670402"
+ "824207571575210.5414230016"}};
+ for (const auto& param : values) {
+ CheckDecimalFromReal<Decimal256>(param.real, param.precision, param.scale,
+ param.expected);
+ }
+}
+
+template <typename Real>
+struct ToRealTestParam {
+ std::string decimal_value;
+ int32_t scale;
+ Real expected;
+};
+
+using ToFloatTestParam = ToRealTestParam<float>;
+using ToDoubleTestParam = ToRealTestParam<double>;
+
+template <typename Decimal, typename Real>
+void CheckDecimalToReal(const std::string& decimal_value, int32_t scale, Real expected) {
+ Decimal dec(decimal_value);
+ ASSERT_EQ(dec.template ToReal<Real>(scale), expected)
+ << "Decimal value: " << decimal_value << " Scale: " << scale;
+}
+
+template <typename Decimal>
+void CheckDecimalToRealApprox(const std::string& decimal_value, int32_t scale,
+ float expected) {
+ Decimal dec(decimal_value);
+ ASSERT_FLOAT_EQ(dec.template ToReal<float>(scale), expected)
+ << "Decimal value: " << decimal_value << " Scale: " << scale;
+}
+
+template <typename Decimal>
+void CheckDecimalToRealApprox(const std::string& decimal_value, int32_t scale,
+ double expected) {
+ Decimal dec(decimal_value);
+ ASSERT_DOUBLE_EQ(dec.template ToReal<double>(scale), expected)
+ << "Decimal value: " << decimal_value << " Scale: " << scale;
+}
+
+// Common tests for Decimal128::ToReal<T> and Decimal256::ToReal<T>
+template <typename T>
+class TestDecimalToReal : public ::testing::Test {
+ public:
+ using Decimal = typename T::first_type;
+ using Real = typename T::second_type;
+ using ParamType = ToRealTestParam<Real>;
+
+ Real Pow2(int exp) { return std::pow(static_cast<Real>(2), static_cast<Real>(exp)); }
+
+ Real Pow10(int exp) { return std::pow(static_cast<Real>(10), static_cast<Real>(exp)); }
+
+ void TestSuccess() {
+ const std::vector<ParamType> params{
+ // clang-format off
+ {"0", 0, 0.0f},
+ {"0", 10, 0.0f},
+ {"0", -10, 0.0f},
+ {"1", 0, 1.0f},
+ {"12345", 0, 12345.f},
+#ifndef __MINGW32__ // MinGW has precision issues
+ {"12345", 1, 1234.5f},
+#endif
+ {"12345", -3, 12345000.f},
+ // 2**62
+ {"4611686018427387904", 0, Pow2(62)},
+ // 2**63 + 2**62
+ {"13835058055282163712", 0, Pow2(63) + Pow2(62)},
+ // 2**64 + 2**62
+ {"23058430092136939520", 0, Pow2(64) + Pow2(62)},
+ // 10**38 - 2**103
+#ifndef __MINGW32__ // MinGW has precision issues
+ {"99999989858795198174164788026374356992", 0, Pow10(38) - Pow2(103)},
+#endif
+ // clang-format on
+ };
+ for (const ParamType& param : params) {
+ CheckDecimalToReal<Decimal, Real>(param.decimal_value, param.scale, param.expected);
+ if (param.decimal_value != "0") {
+ CheckDecimalToReal<Decimal, Real>("-" + param.decimal_value, param.scale,
+ -param.expected);
+ }
+ }
+ }
+
+ // Test precision of conversions to float values
+ void TestPrecision() {
+ // 2**63 + 2**40 (exactly representable in a float's 24 bits of precision)
+ CheckDecimalToReal<Decimal, Real>("9223373136366403584", 0, 9.223373e+18f);
+ CheckDecimalToReal<Decimal, Real>("-9223373136366403584", 0, -9.223373e+18f);
+ // 2**64 + 2**41 (exactly representable in a float)
+ CheckDecimalToReal<Decimal, Real>("18446746272732807168", 0, 1.8446746e+19f);
+ CheckDecimalToReal<Decimal, Real>("-18446746272732807168", 0, -1.8446746e+19f);
+ }
+
+ // Test conversions with a range of scales
+ void TestLargeValues(int32_t max_scale) {
+ // Note that exact comparisons would succeed on some platforms (Linux, macOS).
+ // Nevertheless, power-of-ten factors are not all exactly representable
+ // in binary floating point.
+ for (int32_t scale = -max_scale; scale <= max_scale; scale++) {
+#ifdef _WIN32
+ // MSVC gives pow(10.f, -45.f) == 0 even though 1e-45f is nonzero
+ if (scale == 45) continue;
+#endif
+ CheckDecimalToRealApprox<Decimal>("1", scale, Pow10(-scale));
+ }
+ for (int32_t scale = -max_scale; scale <= max_scale - 2; scale++) {
+#ifdef _WIN32
+ // MSVC gives pow(10.f, -45.f) == 0 even though 1e-45f is nonzero
+ if (scale == 45) continue;
+#endif
+ const Real factor = static_cast<Real>(123);
+ CheckDecimalToRealApprox<Decimal>("123", scale, factor * Pow10(-scale));
+ }
+ }
+};
+
+TYPED_TEST_SUITE(TestDecimalToReal, RealTypes);
+
+TYPED_TEST(TestDecimalToReal, TestSuccess) { this->TestSuccess(); }
+
+// Custom test for Decimal128::ToReal<float>
+class TestDecimal128ToRealFloat : public TestDecimalToReal<std::pair<Decimal128, float>> {
+};
+TEST_F(TestDecimal128ToRealFloat, LargeValues) { TestLargeValues(/*max_scale=*/38); }
+TEST_F(TestDecimal128ToRealFloat, Precision) { this->TestPrecision(); }
+// Custom test for Decimal256::ToReal<float>
+class TestDecimal256ToRealFloat : public TestDecimalToReal<std::pair<Decimal256, float>> {
+};
+TEST_F(TestDecimal256ToRealFloat, LargeValues) { TestLargeValues(/*max_scale=*/76); }
+TEST_F(TestDecimal256ToRealFloat, Precision) { this->TestPrecision(); }
+
+// ToReal<double> tests are disabled on MinGW because of precision issues in results
+#ifndef __MINGW32__
+
+// Custom test for Decimal128::ToReal<double>
+template <typename DecimalType>
+class TestDecimalToRealDouble : public TestDecimalToReal<std::pair<DecimalType, double>> {
+};
+TYPED_TEST_SUITE(TestDecimalToRealDouble, DecimalTypes);
+
+TYPED_TEST(TestDecimalToRealDouble, LargeValues) {
+ // Note that exact comparisons would succeed on some platforms (Linux, macOS).
+ // Nevertheless, power-of-ten factors are not all exactly representable
+ // in binary floating point.
+ for (int32_t scale = -308; scale <= 308; scale++) {
+ CheckDecimalToRealApprox<TypeParam>("1", scale, this->Pow10(-scale));
+ }
+ for (int32_t scale = -308; scale <= 306; scale++) {
+ const double factor = 123.;
+ CheckDecimalToRealApprox<TypeParam>("123", scale, factor * this->Pow10(-scale));
+ }
+}
+
+TYPED_TEST(TestDecimalToRealDouble, Precision) {
+ // 2**63 + 2**11 (exactly representable in a double's 53 bits of precision)
+ CheckDecimalToReal<TypeParam, double>("9223372036854777856", 0, 9.223372036854778e+18);
+ CheckDecimalToReal<TypeParam, double>("-9223372036854777856", 0,
+ -9.223372036854778e+18);
+ // 2**64 - 2**11 (exactly representable in a double)
+ CheckDecimalToReal<TypeParam, double>("18446744073709549568", 0, 1.844674407370955e+19);
+ CheckDecimalToReal<TypeParam, double>("-18446744073709549568", 0,
+ -1.844674407370955e+19);
+ // 2**64 + 2**11 (exactly representable in a double)
+ CheckDecimalToReal<TypeParam, double>("18446744073709555712", 0,
+ 1.8446744073709556e+19);
+ CheckDecimalToReal<TypeParam, double>("-18446744073709555712", 0,
+ -1.8446744073709556e+19);
+ // Almost 10**38 (minus 2**73)
+ CheckDecimalToReal<TypeParam, double>("99999999999999978859343891977453174784", 0,
+ 9.999999999999998e+37);
+ CheckDecimalToReal<TypeParam, double>("-99999999999999978859343891977453174784", 0,
+ -9.999999999999998e+37);
+ CheckDecimalToReal<TypeParam, double>("99999999999999978859343891977453174784", 10,
+ 9.999999999999998e+27);
+ CheckDecimalToReal<TypeParam, double>("-99999999999999978859343891977453174784", 10,
+ -9.999999999999998e+27);
+ CheckDecimalToReal<TypeParam, double>("99999999999999978859343891977453174784", -10,
+ 9.999999999999998e+47);
+ CheckDecimalToReal<TypeParam, double>("-99999999999999978859343891977453174784", -10,
+ -9.999999999999998e+47);
+}
+
+#endif // __MINGW32__
+
+TEST(Decimal128Test, TestFromBigEndian) {
+ // We test out a variety of scenarios:
+ //
+ // * Positive values that are left shifted
+ // and filled in with the same bit pattern
+ // * Negated of the positive values
+ // * Complement of the positive values
+ //
+ // For the positive values, we can call FromBigEndian
+ // with a length that is less than 16, whereas we must
+ // pass all 16 bytes for the negative and complement.
+ //
+ // We use a number of bit patterns to increase the coverage
+ // of scenarios
+ for (int32_t start : {1, 15, /* 00001111 */
+ 85, /* 01010101 */
+ 127 /* 01111111 */}) {
+ Decimal128 value(start);
+ for (int ii = 0; ii < 16; ++ii) {
+ auto native_endian = value.ToBytes();
+#if ARROW_LITTLE_ENDIAN
+ std::reverse(native_endian.begin(), native_endian.end());
+#endif
+ // Limit the number of bytes we are passing to make
+ // sure that it works correctly. That's why all of the
+ // 'start' values don't have a 1 in the most significant
+ // bit place
+ ASSERT_OK_AND_EQ(value,
+ Decimal128::FromBigEndian(native_endian.data() + 15 - ii, ii + 1));
+
+ // Negate it
+ auto negated = -value;
+ native_endian = negated.ToBytes();
+#if ARROW_LITTLE_ENDIAN
+ // convert to big endian
+ std::reverse(native_endian.begin(), native_endian.end());
+#endif
+ // The sign bit is looked up in the MSB
+ ASSERT_OK_AND_EQ(negated,
+ Decimal128::FromBigEndian(native_endian.data() + 15 - ii, ii + 1));
+
+ // Take the complement
+ auto complement = ~value;
+ native_endian = complement.ToBytes();
+#if ARROW_LITTLE_ENDIAN
+ // convert to big endian
+ std::reverse(native_endian.begin(), native_endian.end());
+#endif
+ ASSERT_OK_AND_EQ(complement, Decimal128::FromBigEndian(native_endian.data(), 16));
+
+ value <<= 8;
+ value += Decimal128(start);
+ }
+ }
+}
+
+TEST(Decimal128Test, TestFromBigEndianBadLength) {
+ ASSERT_RAISES(Invalid, Decimal128::FromBigEndian(0, -1));
+ ASSERT_RAISES(Invalid, Decimal128::FromBigEndian(0, 17));
+}
+
+TEST(Decimal128Test, TestToInteger) {
+ Decimal128 value1("1234");
+ int32_t out1;
+
+ Decimal128 value2("-1234");
+ int64_t out2;
+
+ ASSERT_OK(value1.ToInteger(&out1));
+ ASSERT_EQ(1234, out1);
+
+ ASSERT_OK(value1.ToInteger(&out2));
+ ASSERT_EQ(1234, out2);
+
+ ASSERT_OK(value2.ToInteger(&out1));
+ ASSERT_EQ(-1234, out1);
+
+ ASSERT_OK(value2.ToInteger(&out2));
+ ASSERT_EQ(-1234, out2);
+
+ Decimal128 invalid_int32(static_cast<int64_t>(std::pow(2, 31)));
+ ASSERT_RAISES(Invalid, invalid_int32.ToInteger(&out1));
+
+ Decimal128 invalid_int64("12345678912345678901");
+ ASSERT_RAISES(Invalid, invalid_int64.ToInteger(&out2));
+}
+
+template <typename ArrowType, typename CType = typename ArrowType::c_type>
+std::vector<CType> GetRandomNumbers(int32_t size) {
+ auto rand = random::RandomArrayGenerator(0x5487655);
+ auto x_array = rand.Numeric<ArrowType>(size, static_cast<CType>(0),
+ std::numeric_limits<CType>::max(), 0);
+
+ auto x_ptr = x_array->data()->template GetValues<CType>(1);
+ std::vector<CType> ret;
+ for (int i = 0; i < size; ++i) {
+ ret.push_back(x_ptr[i]);
+ }
+ return ret;
+}
+
+Decimal128 Decimal128FromInt128(int128_t value) {
+ return Decimal128(static_cast<int64_t>(value >> 64),
+ static_cast<uint64_t>(value & 0xFFFFFFFFFFFFFFFFULL));
+}
+
+TEST(Decimal128Test, Multiply) {
+ ASSERT_EQ(Decimal128(60501), Decimal128(301) * Decimal128(201));
+
+ ASSERT_EQ(Decimal128(-60501), Decimal128(-301) * Decimal128(201));
+
+ ASSERT_EQ(Decimal128(-60501), Decimal128(301) * Decimal128(-201));
+
+ ASSERT_EQ(Decimal128(60501), Decimal128(-301) * Decimal128(-201));
+
+ // Test some random numbers.
+ for (auto x : GetRandomNumbers<Int32Type>(16)) {
+ for (auto y : GetRandomNumbers<Int32Type>(16)) {
+ Decimal128 result = Decimal128(x) * Decimal128(y);
+ ASSERT_EQ(Decimal128(static_cast<int64_t>(x) * y), result)
+ << " x: " << x << " y: " << y;
+ // Test by multiplying with an additional 32 bit factor, then additional
+ // factor of 2^30 to test results in the range of -2^123 to 2^123 without overflow.
+ for (auto z : GetRandomNumbers<Int32Type>(32)) {
+ int128_t w = static_cast<int128_t>(x) * y * (1ull << 30);
+ Decimal128 expected = Decimal128FromInt128(static_cast<int128_t>(w) * z);
+ Decimal128 actual = Decimal128FromInt128(w) * Decimal128(z);
+ ASSERT_EQ(expected, actual) << " w: " << x << " * " << y << " * 2^30 z: " << z;
+ }
+ }
+ }
+
+ // Test some edge cases
+ for (auto x : std::vector<int128_t>{-INT64_MAX, -INT32_MAX, 0, INT32_MAX, INT64_MAX}) {
+ for (auto y :
+ std::vector<int128_t>{-INT32_MAX, -32, -2, -1, 0, 1, 2, 32, INT32_MAX}) {
+ Decimal128 decimal_x = Decimal128FromInt128(x);
+ Decimal128 decimal_y = Decimal128FromInt128(y);
+ Decimal128 result = decimal_x * decimal_y;
+ EXPECT_EQ(Decimal128FromInt128(x * y), result)
+ << " x: " << decimal_x << " y: " << decimal_y;
+ }
+ }
+}
+
+TEST(Decimal128Test, Divide) {
+ ASSERT_EQ(Decimal128(66), Decimal128(20100) / Decimal128(301));
+
+ ASSERT_EQ(Decimal128(-66), Decimal128(-20100) / Decimal128(301));
+
+ ASSERT_EQ(Decimal128(-66), Decimal128(20100) / Decimal128(-301));
+
+ ASSERT_EQ(Decimal128(66), Decimal128(-20100) / Decimal128(-301));
+
+ // Test some random numbers.
+ for (auto x : GetRandomNumbers<Int32Type>(16)) {
+ for (auto y : GetRandomNumbers<Int32Type>(16)) {
+ if (y == 0) {
+ continue;
+ }
+
+ Decimal128 result = Decimal128(x) / Decimal128(y);
+ ASSERT_EQ(Decimal128(static_cast<int64_t>(x) / y), result)
+ << " x: " << x << " y: " << y;
+ }
+ }
+
+ // Test some edge cases
+ for (auto x : std::vector<int128_t>{-INT64_MAX, -INT32_MAX, 0, INT32_MAX, INT64_MAX}) {
+ for (auto y : std::vector<int128_t>{-INT32_MAX, -32, -2, -1, 1, 2, 32, INT32_MAX}) {
+ Decimal128 decimal_x = Decimal128FromInt128(x);
+ Decimal128 decimal_y = Decimal128FromInt128(y);
+ Decimal128 result = decimal_x / decimal_y;
+ EXPECT_EQ(Decimal128FromInt128(x / y), result)
+ << " x: " << decimal_x << " y: " << decimal_y;
+ }
+ }
+}
+
+TEST(Decimal128Test, Rescale) {
+ ASSERT_OK_AND_EQ(Decimal128(11100), Decimal128(111).Rescale(0, 2));
+ ASSERT_OK_AND_EQ(Decimal128(111), Decimal128(11100).Rescale(2, 0));
+ ASSERT_OK_AND_EQ(Decimal128(5), Decimal128(500000).Rescale(6, 1));
+ ASSERT_OK_AND_EQ(Decimal128(500000), Decimal128(5).Rescale(1, 6));
+ ASSERT_RAISES(Invalid, Decimal128(555555).Rescale(6, 1));
+
+ // Test some random numbers.
+ for (auto original_scale : GetRandomNumbers<Int16Type>(16)) {
+ for (auto value : GetRandomNumbers<Int32Type>(16)) {
+ Decimal128 unscaled_value = Decimal128(value);
+ Decimal128 scaled_value = unscaled_value;
+ for (int32_t new_scale = original_scale; new_scale < original_scale + 29;
+ new_scale++, scaled_value *= Decimal128(10)) {
+ ASSERT_OK_AND_EQ(scaled_value, unscaled_value.Rescale(original_scale, new_scale));
+ ASSERT_OK_AND_EQ(unscaled_value, scaled_value.Rescale(new_scale, original_scale));
+ }
+ }
+ }
+
+ for (auto original_scale : GetRandomNumbers<Int16Type>(16)) {
+ Decimal128 value(1);
+ for (int32_t new_scale = original_scale; new_scale < original_scale + 39;
+ new_scale++, value *= Decimal128(10)) {
+ Decimal128 negative_value = value * -1;
+ ASSERT_OK_AND_EQ(value, Decimal128(1).Rescale(original_scale, new_scale));
+ ASSERT_OK_AND_EQ(negative_value, Decimal128(-1).Rescale(original_scale, new_scale));
+ ASSERT_OK_AND_EQ(Decimal128(1), value.Rescale(new_scale, original_scale));
+ ASSERT_OK_AND_EQ(Decimal128(-1), negative_value.Rescale(new_scale, original_scale));
+ }
+ }
+}
+
+TEST(Decimal128Test, Mod) {
+ ASSERT_EQ(Decimal128(234), Decimal128(20100) % Decimal128(301));
+
+ ASSERT_EQ(Decimal128(-234), Decimal128(-20100) % Decimal128(301));
+
+ ASSERT_EQ(Decimal128(234), Decimal128(20100) % Decimal128(-301));
+
+ ASSERT_EQ(Decimal128(-234), Decimal128(-20100) % Decimal128(-301));
+
+ // Test some random numbers.
+ for (auto x : GetRandomNumbers<Int32Type>(16)) {
+ for (auto y : GetRandomNumbers<Int32Type>(16)) {
+ if (y == 0) {
+ continue;
+ }
+
+ Decimal128 result = Decimal128(x) % Decimal128(y);
+ ASSERT_EQ(Decimal128(static_cast<int64_t>(x) % y), result)
+ << " x: " << x << " y: " << y;
+ }
+ }
+
+ // Test some edge cases
+ for (auto x : std::vector<int128_t>{-INT64_MAX, -INT32_MAX, 0, INT32_MAX, INT64_MAX}) {
+ for (auto y : std::vector<int128_t>{-INT32_MAX, -32, -2, -1, 1, 2, 32, INT32_MAX}) {
+ Decimal128 decimal_x = Decimal128FromInt128(x);
+ Decimal128 decimal_y = Decimal128FromInt128(y);
+ Decimal128 result = decimal_x % decimal_y;
+ EXPECT_EQ(Decimal128FromInt128(x % y), result)
+ << " x: " << decimal_x << " y: " << decimal_y;
+ }
+ }
+}
+
+TEST(Decimal128Test, Sign) {
+ ASSERT_EQ(1, Decimal128(999999).Sign());
+ ASSERT_EQ(-1, Decimal128(-999999).Sign());
+ ASSERT_EQ(1, Decimal128(0).Sign());
+}
+
+TEST(Decimal128Test, GetWholeAndFraction) {
+ Decimal128 value("123456");
+ Decimal128 whole;
+ Decimal128 fraction;
+ int32_t out;
+
+ value.GetWholeAndFraction(0, &whole, &fraction);
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(123456, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(0, out);
+
+ value.GetWholeAndFraction(1, &whole, &fraction);
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(12345, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(6, out);
+
+ value.GetWholeAndFraction(5, &whole, &fraction);
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(1, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(23456, out);
+
+ value.GetWholeAndFraction(7, &whole, &fraction);
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(0, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(123456, out);
+}
+
+TEST(Decimal128Test, GetWholeAndFractionNegative) {
+ Decimal128 value("-123456");
+ Decimal128 whole;
+ Decimal128 fraction;
+ int32_t out;
+
+ value.GetWholeAndFraction(0, &whole, &fraction);
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(-123456, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(0, out);
+
+ value.GetWholeAndFraction(1, &whole, &fraction);
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(-12345, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(-6, out);
+
+ value.GetWholeAndFraction(5, &whole, &fraction);
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(-1, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(-23456, out);
+
+ value.GetWholeAndFraction(7, &whole, &fraction);
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(0, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(-123456, out);
+}
+
+TEST(Decimal128Test, IncreaseScale) {
+ Decimal128 result;
+ int32_t out;
+
+ result = Decimal128("1234").IncreaseScaleBy(0);
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(1234, out);
+
+ result = Decimal128("1234").IncreaseScaleBy(3);
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(1234000, out);
+
+ result = Decimal128("-1234").IncreaseScaleBy(3);
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(-1234000, out);
+}
+
+TEST(Decimal128Test, ReduceScaleAndRound) {
+ Decimal128 result;
+ int32_t out;
+
+ result = Decimal128("123456").ReduceScaleBy(0);
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(123456, out);
+
+ result = Decimal128("123456").ReduceScaleBy(1, false);
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(12345, out);
+
+ result = Decimal128("123456").ReduceScaleBy(1, true);
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(12346, out);
+
+ result = Decimal128("123451").ReduceScaleBy(1, true);
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(12345, out);
+
+ result = Decimal128("-123789").ReduceScaleBy(2, true);
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(-1238, out);
+
+ result = Decimal128("-123749").ReduceScaleBy(2, true);
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(-1237, out);
+
+ result = Decimal128("-123750").ReduceScaleBy(2, true);
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(-1238, out);
+}
+
+TEST(Decimal128Test, FitsInPrecision) {
+ ASSERT_TRUE(Decimal128("0").FitsInPrecision(1));
+ ASSERT_TRUE(Decimal128("9").FitsInPrecision(1));
+ ASSERT_TRUE(Decimal128("-9").FitsInPrecision(1));
+ ASSERT_FALSE(Decimal128("10").FitsInPrecision(1));
+ ASSERT_FALSE(Decimal128("-10").FitsInPrecision(1));
+
+ ASSERT_TRUE(Decimal128("0").FitsInPrecision(2));
+ ASSERT_TRUE(Decimal128("10").FitsInPrecision(2));
+ ASSERT_TRUE(Decimal128("-10").FitsInPrecision(2));
+ ASSERT_TRUE(Decimal128("99").FitsInPrecision(2));
+ ASSERT_TRUE(Decimal128("-99").FitsInPrecision(2));
+ ASSERT_FALSE(Decimal128("100").FitsInPrecision(2));
+ ASSERT_FALSE(Decimal128("-100").FitsInPrecision(2));
+
+ ASSERT_TRUE(Decimal128("99999999999999999999999999999999999999").FitsInPrecision(38));
+ ASSERT_TRUE(Decimal128("-99999999999999999999999999999999999999").FitsInPrecision(38));
+ ASSERT_FALSE(Decimal128("100000000000000000000000000000000000000").FitsInPrecision(38));
+ ASSERT_FALSE(
+ Decimal128("-100000000000000000000000000000000000000").FitsInPrecision(38));
+}
+
+static constexpr std::array<uint64_t, 4> kSortedDecimal256Bits[] = {
+ {0, 0, 0, 0x8000000000000000ULL}, // min
+ {0xFFFFFFFFFFFFFFFEULL, 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL,
+ 0xFFFFFFFFFFFFFFFFULL}, // -2
+ {0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL,
+ 0xFFFFFFFFFFFFFFFFULL}, // -1
+ {0, 0, 0, 0},
+ {1, 0, 0, 0},
+ {2, 0, 0, 0},
+ {0xFFFFFFFFFFFFFFFFULL, 0, 0, 0},
+ {0, 1, 0, 0},
+ {0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, 0, 0},
+ {0, 0, 1, 0},
+ {0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, 0},
+ {0, 0, 0, 1},
+ {0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL,
+ 0x7FFFFFFFFFFFFFFFULL}, // max
+};
+
+TEST(Decimal256Test, TestComparators) {
+ constexpr size_t num_values =
+ sizeof(kSortedDecimal256Bits) / sizeof(kSortedDecimal256Bits[0]);
+ for (size_t i = 0; i < num_values; ++i) {
+ Decimal256 left(
+ ::arrow::BitUtil::LittleEndianArray::ToNative(kSortedDecimal256Bits[i]));
+ for (size_t j = 0; j < num_values; ++j) {
+ Decimal256 right(
+ ::arrow::BitUtil::LittleEndianArray::ToNative(kSortedDecimal256Bits[j]));
+ EXPECT_EQ(i == j, left == right);
+ EXPECT_EQ(i != j, left != right);
+ EXPECT_EQ(i < j, left < right);
+ EXPECT_EQ(i > j, left > right);
+ EXPECT_EQ(i <= j, left <= right);
+ EXPECT_EQ(i >= j, left >= right);
+ }
+ }
+}
+
+TEST(Decimal256Test, TestToBytesRoundTrip) {
+ for (const std::array<uint64_t, 4>& bits : kSortedDecimal256Bits) {
+ Decimal256 decimal(::arrow::BitUtil::LittleEndianArray::ToNative(bits));
+ EXPECT_EQ(decimal, Decimal256(decimal.ToBytes().data()));
+ }
+}
+
+template <typename T>
+class Decimal256Test : public ::testing::Test {
+ public:
+ Decimal256Test() {}
+};
+
+using Decimal256Types =
+ ::testing::Types<char, unsigned char, short, unsigned short, // NOLINT
+ int, unsigned int, long, unsigned long, // NOLINT
+ long long, unsigned long long // NOLINT
+ >;
+
+TYPED_TEST_SUITE(Decimal256Test, Decimal256Types);
+
+TYPED_TEST(Decimal256Test, ConstructibleFromAnyIntegerType) {
+ using UInt64Array = std::array<uint64_t, 4>;
+ Decimal256 value(TypeParam{42});
+ EXPECT_EQ(UInt64Array({42, 0, 0, 0}),
+ ::arrow::BitUtil::LittleEndianArray::FromNative(value.native_endian_array()));
+
+ TypeParam max = std::numeric_limits<TypeParam>::max();
+ Decimal256 max_value(max);
+ EXPECT_EQ(
+ UInt64Array({static_cast<uint64_t>(max), 0, 0, 0}),
+ ::arrow::BitUtil::LittleEndianArray::FromNative(max_value.native_endian_array()));
+
+ TypeParam min = std::numeric_limits<TypeParam>::min();
+ Decimal256 min_value(min);
+ uint64_t high_bits = std::is_signed<TypeParam>::value ? ~uint64_t{0} : uint64_t{0};
+ EXPECT_EQ(
+ UInt64Array({static_cast<uint64_t>(min), high_bits, high_bits, high_bits}),
+ ::arrow::BitUtil::LittleEndianArray::FromNative(min_value.native_endian_array()));
+}
+
+TEST(Decimal256Test, ConstructibleFromBool) {
+ EXPECT_EQ(Decimal256(0), Decimal256(false));
+ EXPECT_EQ(Decimal256(1), Decimal256(true));
+}
+
+Decimal256 Decimal256FromInt128(int128_t value) {
+ return Decimal256(Decimal128(static_cast<int64_t>(value >> 64),
+ static_cast<uint64_t>(value & 0xFFFFFFFFFFFFFFFFULL)));
+}
+
+TEST(Decimal256Test, Multiply) {
+ using boost::multiprecision::int256_t;
+ using boost::multiprecision::uint256_t;
+
+ ASSERT_EQ(Decimal256(60501), Decimal256(301) * Decimal256(201));
+
+ ASSERT_EQ(Decimal256(-60501), Decimal256(-301) * Decimal256(201));
+
+ ASSERT_EQ(Decimal256(-60501), Decimal256(301) * Decimal256(-201));
+
+ ASSERT_EQ(Decimal256(60501), Decimal256(-301) * Decimal256(-201));
+
+ // Test some random numbers.
+ std::vector<int128_t> left;
+ std::vector<int128_t> right;
+ for (auto x : GetRandomNumbers<Int32Type>(16)) {
+ for (auto y : GetRandomNumbers<Int32Type>(16)) {
+ for (auto z : GetRandomNumbers<Int32Type>(16)) {
+ for (auto w : GetRandomNumbers<Int32Type>(16)) {
+ // Test two 128 bit numbers which have a large amount of bits set.
+ int128_t l = static_cast<uint128_t>(x) << 96 | static_cast<uint128_t>(y) << 64 |
+ static_cast<uint128_t>(z) << 32 | static_cast<uint128_t>(w);
+ int128_t r = static_cast<uint128_t>(w) << 96 | static_cast<uint128_t>(z) << 64 |
+ static_cast<uint128_t>(y) << 32 | static_cast<uint128_t>(x);
+ int256_t expected = int256_t(l) * r;
+ Decimal256 actual = Decimal256FromInt128(l) * Decimal256FromInt128(r);
+ ASSERT_EQ(expected.str(), actual.ToIntegerString())
+ << " " << int256_t(l).str() << " * " << int256_t(r).str();
+ // Test a 96 bit number against a 160 bit number.
+ int128_t s = l >> 32;
+ uint256_t b = uint256_t(r) << 32;
+ Decimal256 b_dec =
+ Decimal256FromInt128(r) * Decimal256(static_cast<uint64_t>(1) << 32);
+ ASSERT_EQ(b.str(), b_dec.ToIntegerString()) << int256_t(r).str();
+ expected = int256_t(s) * b;
+ actual = Decimal256FromInt128(s) * b_dec;
+ ASSERT_EQ(expected.str(), actual.ToIntegerString())
+ << " " << int256_t(s).str() << " * " << int256_t(b).str();
+ }
+ }
+ }
+ }
+
+ // Test some edge cases
+ for (auto x : std::vector<int128_t>{-INT64_MAX, -INT32_MAX, 0, INT32_MAX, INT64_MAX}) {
+ for (auto y :
+ std::vector<int128_t>{-INT32_MAX, -32, -2, -1, 0, 1, 2, 32, INT32_MAX}) {
+ Decimal256 decimal_x = Decimal256FromInt128(x);
+ Decimal256 decimal_y = Decimal256FromInt128(y);
+ Decimal256 result = decimal_x * decimal_y;
+ EXPECT_EQ(Decimal256FromInt128(x * y), result)
+ << " x: " << decimal_x << " y: " << decimal_y;
+ }
+ }
+}
+
+TEST(Decimal256Test, Shift) {
+ {
+ // Values compared against python's implementation of shift.
+ Decimal256 v(967);
+ v <<= 16;
+ ASSERT_EQ(v, Decimal256("63373312"));
+ v <<= 66;
+ ASSERT_EQ(v, Decimal256("4676125070269385647763488768"));
+ v <<= 128;
+ ASSERT_EQ(v,
+ Decimal256(
+ "1591202906929606242763855199532957938318305582067671727858104926208"));
+ }
+ {
+ // Values compared against python's implementation of shift.
+ Decimal256 v(0xEFFACDA);
+ v <<= 17;
+ ASSERT_EQ(v, Decimal256("32982558834688"));
+ v <<= 67;
+ ASSERT_EQ(v, Decimal256("4867366573756459829801535578046464"));
+ v <<= 129;
+ ASSERT_EQ(
+ v,
+ Decimal256(
+ "3312558036779413504434176328500812891073739806516698535430241719490183168"));
+ v <<= 43;
+ ASSERT_EQ(v, Decimal256(0));
+ }
+
+ {
+ // Values compared against python's implementation of shift.
+ Decimal256 v("-12346789123456789123456789");
+ v <<= 15;
+ ASSERT_EQ(v, Decimal256("-404579585997432065997432061952"))
+ << std::hex << v.native_endian_array()[0] << " " << v.native_endian_array()[1]
+ << " " << v.native_endian_array()[2] << " " << v.native_endian_array()[3] << "\n"
+ << Decimal256("-404579585997432065997432061952").native_endian_array()[0] << " "
+ << Decimal256("-404579585997432065997432061952").native_endian_array()[1] << " "
+ << Decimal256("-404579585997432065997432061952").native_endian_array()[2] << " "
+ << Decimal256("-404579585997432065997432061952").native_endian_array()[3];
+ v <<= 30;
+ ASSERT_EQ(v, Decimal256("-434414022622047565860171081516421480448"));
+ v <<= 66;
+ ASSERT_EQ(v,
+ Decimal256("-32054097189358332105678889809255994470201895906771963215872"));
+ }
+}
+
+TEST(Decimal256Test, Add) {
+ EXPECT_EQ(Decimal256(103), Decimal256(100) + Decimal256(3));
+ EXPECT_EQ(Decimal256(203), Decimal256(200) + Decimal256(3));
+ EXPECT_EQ(Decimal256(20401), Decimal256(20100) + Decimal256(301));
+ EXPECT_EQ(Decimal256(-19799), Decimal256(-20100) + Decimal256(301));
+ EXPECT_EQ(Decimal256(19799), Decimal256(20100) + Decimal256(-301));
+ EXPECT_EQ(Decimal256(-20401), Decimal256(-20100) + Decimal256(-301));
+ EXPECT_EQ(Decimal256("100000000000000000000000000000000001"),
+ Decimal256("99999999999999999999999999999999999") + Decimal256("2"));
+ EXPECT_EQ(Decimal256("120200000000000000000000000000002019"),
+ Decimal256("99999999999999999999999999999999999") +
+ Decimal256("20200000000000000000000000000002020"));
+
+ // Test some random numbers.
+ for (auto x : GetRandomNumbers<Int32Type>(16)) {
+ for (auto y : GetRandomNumbers<Int32Type>(16)) {
+ if (y == 0) {
+ continue;
+ }
+
+ Decimal256 result = Decimal256(x) + Decimal256(y);
+ ASSERT_EQ(Decimal256(static_cast<int64_t>(x) + y), result)
+ << " x: " << x << " y: " << y;
+ }
+ }
+}
+
+TEST(Decimal256Test, Divide) {
+ ASSERT_EQ(Decimal256(33), Decimal256(100) / Decimal256(3));
+ ASSERT_EQ(Decimal256(66), Decimal256(200) / Decimal256(3));
+ ASSERT_EQ(Decimal256(66), Decimal256(20100) / Decimal256(301));
+ ASSERT_EQ(Decimal256(-66), Decimal256(-20100) / Decimal256(301));
+ ASSERT_EQ(Decimal256(-66), Decimal256(20100) / Decimal256(-301));
+ ASSERT_EQ(Decimal256(66), Decimal256(-20100) / Decimal256(-301));
+ ASSERT_EQ(Decimal256("-5192296858534827628530496329343552"),
+ Decimal256("-269599466671506397946670150910580797473777870509761363"
+ "24636208709184") /
+ Decimal256("5192296858534827628530496329874417"));
+ ASSERT_EQ(Decimal256("5192296858534827628530496329343552"),
+ Decimal256("-269599466671506397946670150910580797473777870509761363"
+ "24636208709184") /
+ Decimal256("-5192296858534827628530496329874417"));
+ ASSERT_EQ(Decimal256("5192296858534827628530496329343552"),
+ Decimal256("2695994666715063979466701509105807974737778705097613632"
+ "4636208709184") /
+ Decimal256("5192296858534827628530496329874417"));
+ ASSERT_EQ(Decimal256("-5192296858534827628530496329343552"),
+ Decimal256("2695994666715063979466701509105807974737778705097613632"
+ "4636208709184") /
+ Decimal256("-5192296858534827628530496329874417"));
+
+ // Test some random numbers.
+ for (auto x : GetRandomNumbers<Int32Type>(16)) {
+ for (auto y : GetRandomNumbers<Int32Type>(16)) {
+ if (y == 0) {
+ continue;
+ }
+
+ Decimal256 result = Decimal256(x) / Decimal256(y);
+ ASSERT_EQ(Decimal256(static_cast<int64_t>(x) / y), result)
+ << " x: " << x << " y: " << y;
+ }
+ }
+
+ // Test some edge cases
+ for (auto x :
+ std::vector<int128_t>{-kInt128Max, -INT64_MAX - 1, -INT64_MAX, -INT32_MAX - 1,
+ -INT32_MAX, 0, INT32_MAX, INT64_MAX, kInt128Max}) {
+ for (auto y : std::vector<int128_t>{-INT64_MAX - 1, -INT64_MAX, -INT32_MAX, -32, -2,
+ -1, 1, 2, 32, INT32_MAX, INT64_MAX}) {
+ Decimal256 decimal_x = Decimal256FromInt128(x);
+ Decimal256 decimal_y = Decimal256FromInt128(y);
+ Decimal256 result = decimal_x / decimal_y;
+ EXPECT_EQ(Decimal256FromInt128(x / y), result);
+ }
+ }
+}
+
+TEST(Decimal256Test, Rescale) {
+ ASSERT_OK_AND_EQ(Decimal256(11100), Decimal256(111).Rescale(0, 2));
+ ASSERT_OK_AND_EQ(Decimal256(111), Decimal256(11100).Rescale(2, 0));
+ ASSERT_OK_AND_EQ(Decimal256(5), Decimal256(500000).Rescale(6, 1));
+ ASSERT_OK_AND_EQ(Decimal256(500000), Decimal256(5).Rescale(1, 6));
+ ASSERT_RAISES(Invalid, Decimal256(555555).Rescale(6, 1));
+
+ // Test some random numbers.
+ for (auto original_scale : GetRandomNumbers<Int16Type>(16)) {
+ for (auto value : GetRandomNumbers<Int32Type>(16)) {
+ Decimal256 unscaled_value = Decimal256(value);
+ Decimal256 scaled_value = unscaled_value;
+ for (int32_t new_scale = original_scale; new_scale < original_scale + 68;
+ new_scale++, scaled_value *= Decimal256(10)) {
+ ASSERT_OK_AND_EQ(scaled_value, unscaled_value.Rescale(original_scale, new_scale));
+ ASSERT_OK_AND_EQ(unscaled_value, scaled_value.Rescale(new_scale, original_scale));
+ }
+ }
+ }
+
+ for (auto original_scale : GetRandomNumbers<Int16Type>(16)) {
+ Decimal256 value(1);
+ for (int32_t new_scale = original_scale; new_scale < original_scale + 77;
+ new_scale++, value *= Decimal256(10)) {
+ Decimal256 negative_value = value * -1;
+ ASSERT_OK_AND_EQ(value, Decimal256(1).Rescale(original_scale, new_scale));
+ ASSERT_OK_AND_EQ(negative_value, Decimal256(-1).Rescale(original_scale, new_scale));
+ ASSERT_OK_AND_EQ(Decimal256(1), value.Rescale(new_scale, original_scale));
+ ASSERT_OK_AND_EQ(Decimal256(-1), negative_value.Rescale(new_scale, original_scale));
+ }
+ }
+}
+
+TEST(Decimal256Test, IncreaseScale) {
+ Decimal256 result;
+
+ result = Decimal256("1234").IncreaseScaleBy(0);
+ ASSERT_EQ("1234", result.ToIntegerString());
+
+ result = Decimal256("1234").IncreaseScaleBy(3);
+ ASSERT_EQ("1234000", result.ToIntegerString());
+
+ result = Decimal256("-1234").IncreaseScaleBy(3);
+ ASSERT_EQ("-1234000", result.ToIntegerString());
+}
+
+TEST(Decimal256Test, ReduceScaleAndRound) {
+ Decimal256 result;
+
+ result = Decimal256("123456").ReduceScaleBy(0);
+ ASSERT_EQ("123456", result.ToIntegerString());
+
+ result = Decimal256("123456").ReduceScaleBy(1, false);
+ ASSERT_EQ("12345", result.ToIntegerString());
+
+ result = Decimal256("123456").ReduceScaleBy(1, true);
+ ASSERT_EQ("12346", result.ToIntegerString());
+
+ result = Decimal256("123451").ReduceScaleBy(1, true);
+ ASSERT_EQ("12345", result.ToIntegerString());
+
+ result = Decimal256("-123789").ReduceScaleBy(2, true);
+ ASSERT_EQ("-1238", result.ToIntegerString());
+
+ result = Decimal256("-123749").ReduceScaleBy(2, true);
+ ASSERT_EQ("-1237", result.ToIntegerString());
+
+ result = Decimal256("-123750").ReduceScaleBy(2, true);
+ ASSERT_EQ("-1238", result.ToIntegerString());
+}
+
+TEST(Decimal256, FromBigEndianTest) {
+ // We test out a variety of scenarios:
+ //
+ // * Positive values that are left shifted
+ // and filled in with the same bit pattern
+ // * Negated of the positive values
+ // * Complement of the positive values
+ //
+ // For the positive values, we can call FromBigEndian
+ // with a length that is less than 16, whereas we must
+ // pass all 32 bytes for the negative and complement.
+ //
+ // We use a number of bit patterns to increase the coverage
+ // of scenarios
+ for (int32_t start : {1, 1, 15, /* 00001111 */
+ 85, /* 01010101 */
+ 127 /* 01111111 */}) {
+ Decimal256 value(start);
+ for (int ii = 0; ii < 32; ++ii) {
+ auto native_endian = value.ToBytes();
+#if ARROW_LITTLE_ENDIAN
+ std::reverse(native_endian.begin(), native_endian.end());
+#endif
+ // Limit the number of bytes we are passing to make
+ // sure that it works correctly. That's why all of the
+ // 'start' values don't have a 1 in the most significant
+ // bit place
+ ASSERT_OK_AND_EQ(value,
+ Decimal256::FromBigEndian(native_endian.data() + 31 - ii, ii + 1));
+
+ // Negate it
+ auto negated = -value;
+ native_endian = negated.ToBytes();
+#if ARROW_LITTLE_ENDIAN
+ // convert to big endian
+ std::reverse(native_endian.begin(), native_endian.end());
+#endif
+ // The sign bit is looked up in the MSB
+ ASSERT_OK_AND_EQ(negated,
+ Decimal256::FromBigEndian(native_endian.data() + 31 - ii, ii + 1));
+
+ // Take the complement
+ auto complement = ~value;
+ native_endian = complement.ToBytes();
+#if ARROW_LITTLE_ENDIAN
+ // convert to big endian
+ std::reverse(native_endian.begin(), native_endian.end());
+#endif
+ ASSERT_OK_AND_EQ(complement, Decimal256::FromBigEndian(native_endian.data(), 32));
+
+ value <<= 8;
+ value += Decimal256(start);
+ }
+ }
+}
+
+TEST(Decimal256Test, TestFromBigEndianBadLength) {
+ ASSERT_RAISES(Invalid, Decimal128::FromBigEndian(nullptr, -1));
+ ASSERT_RAISES(Invalid, Decimal128::FromBigEndian(nullptr, 33));
+}
+
+class Decimal256ToStringTest : public ::testing::TestWithParam<ToStringTestParam> {};
+
+TEST_P(Decimal256ToStringTest, ToString) {
+ const ToStringTestParam& data = GetParam();
+ const Decimal256 value(data.test_value);
+ const std::string printed_value = value.ToString(data.scale);
+ ASSERT_EQ(data.expected_string, printed_value);
+}
+
+INSTANTIATE_TEST_SUITE_P(Decimal256ToStringTest, Decimal256ToStringTest,
+ ::testing::ValuesIn(kToStringTestData));
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/delimiting.cc b/src/arrow/cpp/src/arrow/util/delimiting.cc
new file mode 100644
index 000000000..fe1b6ea31
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/delimiting.cc
@@ -0,0 +1,193 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/delimiting.h"
+#include "arrow/buffer.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+BoundaryFinder::~BoundaryFinder() {}
+
+namespace {
+
+Status StraddlingTooLarge() {
+ return Status::Invalid(
+ "straddling object straddles two block boundaries (try to increase block size?)");
+}
+
+class NewlineBoundaryFinder : public BoundaryFinder {
+ public:
+ Status FindFirst(util::string_view partial, util::string_view block,
+ int64_t* out_pos) override {
+ auto pos = block.find_first_of(newline_delimiters);
+ if (pos == util::string_view::npos) {
+ *out_pos = kNoDelimiterFound;
+ } else {
+ auto end = block.find_first_not_of(newline_delimiters, pos);
+ if (end == util::string_view::npos) {
+ end = block.length();
+ }
+ *out_pos = static_cast<int64_t>(end);
+ }
+ return Status::OK();
+ }
+
+ Status FindLast(util::string_view block, int64_t* out_pos) override {
+ auto pos = block.find_last_of(newline_delimiters);
+ if (pos == util::string_view::npos) {
+ *out_pos = kNoDelimiterFound;
+ } else {
+ auto end = block.find_first_not_of(newline_delimiters, pos);
+ if (end == util::string_view::npos) {
+ end = block.length();
+ }
+ *out_pos = static_cast<int64_t>(end);
+ }
+ return Status::OK();
+ }
+
+ Status FindNth(util::string_view partial, util::string_view block, int64_t count,
+ int64_t* out_pos, int64_t* num_found) override {
+ DCHECK(partial.find_first_of(newline_delimiters) == util::string_view::npos);
+
+ int64_t found = 0;
+ int64_t pos = kNoDelimiterFound;
+
+ auto cur_pos = block.find_first_of(newline_delimiters);
+ while (cur_pos != util::string_view::npos) {
+ if (block[cur_pos] == '\r' && cur_pos + 1 < block.length() &&
+ block[cur_pos + 1] == '\n') {
+ cur_pos += 2;
+ } else {
+ ++cur_pos;
+ }
+
+ pos = static_cast<int64_t>(cur_pos);
+ if (++found >= count) {
+ break;
+ }
+
+ cur_pos = block.find_first_of(newline_delimiters, cur_pos);
+ }
+
+ *out_pos = pos;
+ *num_found = found;
+ return Status::OK();
+ }
+
+ protected:
+ static constexpr const char* newline_delimiters = "\r\n";
+};
+
+} // namespace
+
+std::shared_ptr<BoundaryFinder> MakeNewlineBoundaryFinder() {
+ return std::make_shared<NewlineBoundaryFinder>();
+}
+
+Chunker::~Chunker() {}
+
+Chunker::Chunker(std::shared_ptr<BoundaryFinder> delimiter)
+ : boundary_finder_(delimiter) {}
+
+Status Chunker::Process(std::shared_ptr<Buffer> block, std::shared_ptr<Buffer>* whole,
+ std::shared_ptr<Buffer>* partial) {
+ int64_t last_pos = -1;
+ RETURN_NOT_OK(boundary_finder_->FindLast(util::string_view(*block), &last_pos));
+ if (last_pos == BoundaryFinder::kNoDelimiterFound) {
+ // No delimiter found
+ *whole = SliceBuffer(block, 0, 0);
+ *partial = block;
+ return Status::OK();
+ } else {
+ *whole = SliceBuffer(block, 0, last_pos);
+ *partial = SliceBuffer(block, last_pos);
+ }
+ return Status::OK();
+}
+
+Status Chunker::ProcessWithPartial(std::shared_ptr<Buffer> partial,
+ std::shared_ptr<Buffer> block,
+ std::shared_ptr<Buffer>* completion,
+ std::shared_ptr<Buffer>* rest) {
+ if (partial->size() == 0) {
+ // If partial is empty, don't bother looking for completion
+ *completion = SliceBuffer(block, 0, 0);
+ *rest = block;
+ return Status::OK();
+ }
+ int64_t first_pos = -1;
+ RETURN_NOT_OK(boundary_finder_->FindFirst(util::string_view(*partial),
+ util::string_view(*block), &first_pos));
+ if (first_pos == BoundaryFinder::kNoDelimiterFound) {
+ // No delimiter in block => the current object is too large for block size
+ return StraddlingTooLarge();
+ } else {
+ *completion = SliceBuffer(block, 0, first_pos);
+ *rest = SliceBuffer(block, first_pos);
+ return Status::OK();
+ }
+}
+
+Status Chunker::ProcessFinal(std::shared_ptr<Buffer> partial,
+ std::shared_ptr<Buffer> block,
+ std::shared_ptr<Buffer>* completion,
+ std::shared_ptr<Buffer>* rest) {
+ if (partial->size() == 0) {
+ // If partial is empty, don't bother looking for completion
+ *completion = SliceBuffer(block, 0, 0);
+ *rest = block;
+ return Status::OK();
+ }
+ int64_t first_pos = -1;
+ RETURN_NOT_OK(boundary_finder_->FindFirst(util::string_view(*partial),
+ util::string_view(*block), &first_pos));
+ if (first_pos == BoundaryFinder::kNoDelimiterFound) {
+ // No delimiter in block => it's entirely a completion of partial
+ *completion = block;
+ *rest = SliceBuffer(block, 0, 0);
+ } else {
+ *completion = SliceBuffer(block, 0, first_pos);
+ *rest = SliceBuffer(block, first_pos);
+ }
+ return Status::OK();
+}
+
+Status Chunker::ProcessSkip(std::shared_ptr<Buffer> partial,
+ std::shared_ptr<Buffer> block, bool final, int64_t* count,
+ std::shared_ptr<Buffer>* rest) {
+ DCHECK_GT(*count, 0);
+ int64_t pos;
+ int64_t num_found;
+ ARROW_RETURN_NOT_OK(boundary_finder_->FindNth(
+ util::string_view(*partial), util::string_view(*block), *count, &pos, &num_found));
+ if (pos == BoundaryFinder::kNoDelimiterFound) {
+ return StraddlingTooLarge();
+ }
+ if (ARROW_PREDICT_FALSE(final && *count > num_found && block->size() != pos)) {
+ // Skip the last row in the final block which does not have a delimiter
+ ++num_found;
+ *rest = SliceBuffer(block, 0, 0);
+ } else {
+ *rest = SliceBuffer(block, pos);
+ }
+ *count -= num_found;
+ return Status::OK();
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/delimiting.h b/src/arrow/cpp/src/arrow/util/delimiting.h
new file mode 100644
index 000000000..b4b868340
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/delimiting.h
@@ -0,0 +1,181 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/status.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Buffer;
+
+class ARROW_EXPORT BoundaryFinder {
+ public:
+ BoundaryFinder() = default;
+
+ virtual ~BoundaryFinder();
+
+ /// \brief Find the position of the first delimiter inside block
+ ///
+ /// `partial` is taken to be the beginning of the block, and `block`
+ /// its continuation. Also, `partial` doesn't contain a delimiter.
+ ///
+ /// The returned `out_pos` is relative to `block`'s start and should point
+ /// to the first character after the first delimiter.
+ /// `out_pos` will be -1 if no delimiter is found.
+ virtual Status FindFirst(util::string_view partial, util::string_view block,
+ int64_t* out_pos) = 0;
+
+ /// \brief Find the position of the last delimiter inside block
+ ///
+ /// The returned `out_pos` is relative to `block`'s start and should point
+ /// to the first character after the last delimiter.
+ /// `out_pos` will be -1 if no delimiter is found.
+ virtual Status FindLast(util::string_view block, int64_t* out_pos) = 0;
+
+ /// \brief Find the position of the Nth delimiter inside the block
+ ///
+ /// `partial` is taken to be the beginning of the block, and `block`
+ /// its continuation. Also, `partial` doesn't contain a delimiter.
+ ///
+ /// The returned `out_pos` is relative to `block`'s start and should point
+ /// to the first character after the first delimiter.
+ /// `out_pos` will be -1 if no delimiter is found.
+ ///
+ /// The returned `num_found` is the number of delimiters actually found
+ virtual Status FindNth(util::string_view partial, util::string_view block,
+ int64_t count, int64_t* out_pos, int64_t* num_found) = 0;
+
+ static constexpr int64_t kNoDelimiterFound = -1;
+
+ protected:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(BoundaryFinder);
+};
+
+ARROW_EXPORT
+std::shared_ptr<BoundaryFinder> MakeNewlineBoundaryFinder();
+
+/// \brief A reusable block-based chunker for delimited data
+///
+/// The chunker takes a block of delimited data and helps carve a sub-block
+/// which begins and ends on delimiters (suitable for consumption by parsers
+/// which can only parse whole objects).
+class ARROW_EXPORT Chunker {
+ public:
+ explicit Chunker(std::shared_ptr<BoundaryFinder> delimiter);
+ ~Chunker();
+
+ /// \brief Carve up a chunk in a block of data to contain only whole objects
+ ///
+ /// Pre-conditions:
+ /// - `block` is the start of a valid block of delimited data
+ /// (i.e. starts just after a delimiter)
+ ///
+ /// Post-conditions:
+ /// - block == whole + partial
+ /// - `whole` is a valid block of delimited data
+ /// (i.e. starts just after a delimiter and ends with a delimiter)
+ /// - `partial` doesn't contain an entire delimited object
+ /// (IOW: `partial` is generally small)
+ ///
+ /// This method will look for the last delimiter in `block` and may
+ /// therefore be costly.
+ ///
+ /// \param[in] block data to be chunked
+ /// \param[out] whole subrange of block containing whole delimited objects
+ /// \param[out] partial subrange of block starting with a partial delimited object
+ Status Process(std::shared_ptr<Buffer> block, std::shared_ptr<Buffer>* whole,
+ std::shared_ptr<Buffer>* partial);
+
+ /// \brief Carve the completion of a partial object out of a block
+ ///
+ /// Pre-conditions:
+ /// - `partial` is the start of a valid block of delimited data
+ /// (i.e. starts just after a delimiter)
+ /// - `block` follows `partial` in file order
+ ///
+ /// Post-conditions:
+ /// - block == completion + rest
+ /// - `partial + completion` is a valid block of delimited data
+ /// (i.e. starts just after a delimiter and ends with a delimiter)
+ /// - `completion` doesn't contain an entire delimited object
+ /// (IOW: `completion` is generally small)
+ ///
+ /// This method will look for the first delimiter in `block` and should
+ /// therefore be reasonably cheap.
+ ///
+ /// \param[in] partial incomplete delimited data
+ /// \param[in] block delimited data following partial
+ /// \param[out] completion subrange of block containing the completion of partial
+ /// \param[out] rest subrange of block containing what completion does not cover
+ Status ProcessWithPartial(std::shared_ptr<Buffer> partial,
+ std::shared_ptr<Buffer> block,
+ std::shared_ptr<Buffer>* completion,
+ std::shared_ptr<Buffer>* rest);
+
+ /// \brief Like ProcessWithPartial, but for the last block of a file
+ ///
+ /// This method allows for a final delimited object without a trailing delimiter
+ /// (ProcessWithPartial would return an error in that case).
+ ///
+ /// Pre-conditions:
+ /// - `partial` is the start of a valid block of delimited data
+ /// - `block` follows `partial` in file order and is the last data block
+ ///
+ /// Post-conditions:
+ /// - block == completion + rest
+ /// - `partial + completion` is a valid block of delimited data
+ /// - `completion` doesn't contain an entire delimited object
+ /// (IOW: `completion` is generally small)
+ ///
+ Status ProcessFinal(std::shared_ptr<Buffer> partial, std::shared_ptr<Buffer> block,
+ std::shared_ptr<Buffer>* completion, std::shared_ptr<Buffer>* rest);
+
+ /// \brief Skip count number of rows
+ /// Pre-conditions:
+ /// - `partial` is the start of a valid block of delimited data
+ /// (i.e. starts just after a delimiter)
+ /// - `block` follows `partial` in file order
+ ///
+ /// Post-conditions:
+ /// - `count` is updated to indicate the number of rows that still need to be skipped
+ /// - If `count` is > 0 then `rest` is an incomplete block that should be a future
+ /// `partial`
+ /// - Else `rest` could be one or more valid blocks of delimited data which need to be
+ /// parsed
+ ///
+ /// \param[in] partial incomplete delimited data
+ /// \param[in] block delimited data following partial
+ /// \param[in] final whether this is the final chunk
+ /// \param[in,out] count number of rows that need to be skipped
+ /// \param[out] rest subrange of block containing what was not skipped
+ Status ProcessSkip(std::shared_ptr<Buffer> partial, std::shared_ptr<Buffer> block,
+ bool final, int64_t* count, std::shared_ptr<Buffer>* rest);
+
+ protected:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Chunker);
+
+ std::shared_ptr<BoundaryFinder> boundary_finder_;
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/dispatch.h b/src/arrow/cpp/src/arrow/util/dispatch.h
new file mode 100644
index 000000000..fae9293f9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/dispatch.h
@@ -0,0 +1,115 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <utility>
+#include <vector>
+
+#include "arrow/status.h"
+#include "arrow/util/cpu_info.h"
+
+namespace arrow {
+namespace internal {
+
+enum class DispatchLevel : int {
+ // These dispatch levels, corresponding to instruction set features,
+ // are sorted in increasing order of preference.
+ NONE = 0,
+ SSE4_2,
+ AVX2,
+ AVX512,
+ NEON,
+ MAX
+};
+
+/*
+ A facility for dynamic dispatch according to available DispatchLevel.
+
+ Typical use:
+
+ static void my_function_default(...);
+ static void my_function_avx2(...);
+
+ struct MyDynamicFunction {
+ using FunctionType = decltype(&my_function_default);
+
+ static std::vector<std::pair<DispatchLevel, FunctionType>> implementations() {
+ return {
+ { DispatchLevel::NONE, my_function_default }
+ #if defined(ARROW_HAVE_RUNTIME_AVX2)
+ , { DispatchLevel::AVX2, my_function_avx2 }
+ #endif
+ };
+ }
+ };
+
+ void my_function(...) {
+ static DynamicDispatch<MyDynamicFunction> dispatch;
+ return dispatch.func(...);
+ }
+*/
+template <typename DynamicFunction>
+class DynamicDispatch {
+ protected:
+ using FunctionType = typename DynamicFunction::FunctionType;
+ using Implementation = std::pair<DispatchLevel, FunctionType>;
+
+ public:
+ DynamicDispatch() { Resolve(DynamicFunction::implementations()); }
+
+ FunctionType func = {};
+
+ protected:
+ // Use the Implementation with the highest DispatchLevel
+ void Resolve(const std::vector<Implementation>& implementations) {
+ Implementation cur{DispatchLevel::NONE, {}};
+
+ for (const auto& impl : implementations) {
+ if (impl.first >= cur.first && IsSupported(impl.first)) {
+ // Higher (or same) level than current
+ cur = impl;
+ }
+ }
+
+ if (!cur.second) {
+ Status::Invalid("No appropriate implementation found").Abort();
+ }
+ func = cur.second;
+ }
+
+ private:
+ bool IsSupported(DispatchLevel level) const {
+ static const auto cpu_info = arrow::internal::CpuInfo::GetInstance();
+
+ switch (level) {
+ case DispatchLevel::NONE:
+ return true;
+ case DispatchLevel::SSE4_2:
+ return cpu_info->IsSupported(CpuInfo::SSE4_2);
+ case DispatchLevel::AVX2:
+ return cpu_info->IsSupported(CpuInfo::AVX2);
+ case DispatchLevel::AVX512:
+ return cpu_info->IsSupported(CpuInfo::AVX512);
+ default:
+ return false;
+ }
+ }
+};
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/double_conversion.h b/src/arrow/cpp/src/arrow/util/double_conversion.h
new file mode 100644
index 000000000..8edc6544b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/double_conversion.h
@@ -0,0 +1,32 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/vendored/double-conversion/double-conversion.h" // IWYU pragma: export
+
+namespace arrow {
+namespace util {
+namespace double_conversion {
+
+using ::double_conversion::DoubleToStringConverter;
+using ::double_conversion::StringBuilder;
+using ::double_conversion::StringToDoubleConverter;
+
+} // namespace double_conversion
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/endian.h b/src/arrow/cpp/src/arrow/util/endian.h
new file mode 100644
index 000000000..0fae454e0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/endian.h
@@ -0,0 +1,245 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#ifdef _WIN32
+#define ARROW_LITTLE_ENDIAN 1
+#else
+#if defined(__APPLE__) || defined(__FreeBSD__)
+#include <machine/endian.h> // IWYU pragma: keep
+#elif defined(sun) || defined(__sun)
+#include <sys/byteorder.h> // IWYU pragma: keep
+#else
+#include <endian.h> // IWYU pragma: keep
+#endif
+#
+#ifndef __BYTE_ORDER__
+#error "__BYTE_ORDER__ not defined"
+#endif
+#
+#ifndef __ORDER_LITTLE_ENDIAN__
+#error "__ORDER_LITTLE_ENDIAN__ not defined"
+#endif
+#
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+#define ARROW_LITTLE_ENDIAN 1
+#else
+#define ARROW_LITTLE_ENDIAN 0
+#endif
+#endif
+
+#if defined(_MSC_VER)
+#include <intrin.h> // IWYU pragma: keep
+#define ARROW_BYTE_SWAP64 _byteswap_uint64
+#define ARROW_BYTE_SWAP32 _byteswap_ulong
+#else
+#define ARROW_BYTE_SWAP64 __builtin_bswap64
+#define ARROW_BYTE_SWAP32 __builtin_bswap32
+#endif
+
+#include <algorithm>
+#include <array>
+
+#include "arrow/util/type_traits.h"
+#include "arrow/util/ubsan.h"
+
+namespace arrow {
+namespace BitUtil {
+
+//
+// Byte-swap 16-bit, 32-bit and 64-bit values
+//
+
+// Swap the byte order (i.e. endianness)
+static inline int64_t ByteSwap(int64_t value) { return ARROW_BYTE_SWAP64(value); }
+static inline uint64_t ByteSwap(uint64_t value) {
+ return static_cast<uint64_t>(ARROW_BYTE_SWAP64(value));
+}
+static inline int32_t ByteSwap(int32_t value) { return ARROW_BYTE_SWAP32(value); }
+static inline uint32_t ByteSwap(uint32_t value) {
+ return static_cast<uint32_t>(ARROW_BYTE_SWAP32(value));
+}
+static inline int16_t ByteSwap(int16_t value) {
+ constexpr auto m = static_cast<int16_t>(0xff);
+ return static_cast<int16_t>(((value >> 8) & m) | ((value & m) << 8));
+}
+static inline uint16_t ByteSwap(uint16_t value) {
+ return static_cast<uint16_t>(ByteSwap(static_cast<int16_t>(value)));
+}
+static inline uint8_t ByteSwap(uint8_t value) { return value; }
+static inline int8_t ByteSwap(int8_t value) { return value; }
+static inline double ByteSwap(double value) {
+ const uint64_t swapped = ARROW_BYTE_SWAP64(util::SafeCopy<uint64_t>(value));
+ return util::SafeCopy<double>(swapped);
+}
+static inline float ByteSwap(float value) {
+ const uint32_t swapped = ARROW_BYTE_SWAP32(util::SafeCopy<uint32_t>(value));
+ return util::SafeCopy<float>(swapped);
+}
+
+// Write the swapped bytes into dst. Src and dst cannot overlap.
+static inline void ByteSwap(void* dst, const void* src, int len) {
+ switch (len) {
+ case 1:
+ *reinterpret_cast<int8_t*>(dst) = *reinterpret_cast<const int8_t*>(src);
+ return;
+ case 2:
+ *reinterpret_cast<int16_t*>(dst) = ByteSwap(*reinterpret_cast<const int16_t*>(src));
+ return;
+ case 4:
+ *reinterpret_cast<int32_t*>(dst) = ByteSwap(*reinterpret_cast<const int32_t*>(src));
+ return;
+ case 8:
+ *reinterpret_cast<int64_t*>(dst) = ByteSwap(*reinterpret_cast<const int64_t*>(src));
+ return;
+ default:
+ break;
+ }
+
+ auto d = reinterpret_cast<uint8_t*>(dst);
+ auto s = reinterpret_cast<const uint8_t*>(src);
+ for (int i = 0; i < len; ++i) {
+ d[i] = s[len - i - 1];
+ }
+}
+
+// Convert to little/big endian format from the machine's native endian format.
+#if ARROW_LITTLE_ENDIAN
+template <typename T, typename = internal::EnableIfIsOneOf<
+ T, int64_t, uint64_t, int32_t, uint32_t, int16_t, uint16_t,
+ uint8_t, int8_t, float, double>>
+static inline T ToBigEndian(T value) {
+ return ByteSwap(value);
+}
+
+template <typename T, typename = internal::EnableIfIsOneOf<
+ T, int64_t, uint64_t, int32_t, uint32_t, int16_t, uint16_t,
+ uint8_t, int8_t, float, double>>
+static inline T ToLittleEndian(T value) {
+ return value;
+}
+#else
+template <typename T, typename = internal::EnableIfIsOneOf<
+ T, int64_t, uint64_t, int32_t, uint32_t, int16_t, uint16_t,
+ uint8_t, int8_t, float, double>>
+static inline T ToBigEndian(T value) {
+ return value;
+}
+
+template <typename T, typename = internal::EnableIfIsOneOf<
+ T, int64_t, uint64_t, int32_t, uint32_t, int16_t, uint16_t,
+ uint8_t, int8_t, float, double>>
+static inline T ToLittleEndian(T value) {
+ return ByteSwap(value);
+}
+#endif
+
+// Convert from big/little endian format to the machine's native endian format.
+#if ARROW_LITTLE_ENDIAN
+template <typename T, typename = internal::EnableIfIsOneOf<
+ T, int64_t, uint64_t, int32_t, uint32_t, int16_t, uint16_t,
+ uint8_t, int8_t, float, double>>
+static inline T FromBigEndian(T value) {
+ return ByteSwap(value);
+}
+
+template <typename T, typename = internal::EnableIfIsOneOf<
+ T, int64_t, uint64_t, int32_t, uint32_t, int16_t, uint16_t,
+ uint8_t, int8_t, float, double>>
+static inline T FromLittleEndian(T value) {
+ return value;
+}
+#else
+template <typename T, typename = internal::EnableIfIsOneOf<
+ T, int64_t, uint64_t, int32_t, uint32_t, int16_t, uint16_t,
+ uint8_t, int8_t, float, double>>
+static inline T FromBigEndian(T value) {
+ return value;
+}
+
+template <typename T, typename = internal::EnableIfIsOneOf<
+ T, int64_t, uint64_t, int32_t, uint32_t, int16_t, uint16_t,
+ uint8_t, int8_t, float, double>>
+static inline T FromLittleEndian(T value) {
+ return ByteSwap(value);
+}
+#endif
+
+// Handle endianness in *word* granuality (keep individual array element untouched)
+namespace LittleEndianArray {
+
+namespace detail {
+
+// Read a native endian array as little endian
+template <typename T, size_t N>
+struct Reader {
+ const std::array<T, N>& native_array;
+
+ explicit Reader(const std::array<T, N>& native_array) : native_array(native_array) {}
+
+ const T& operator[](size_t i) const {
+ return native_array[ARROW_LITTLE_ENDIAN ? i : N - 1 - i];
+ }
+};
+
+// Read/write a native endian array as little endian
+template <typename T, size_t N>
+struct Writer {
+ std::array<T, N>* native_array;
+
+ explicit Writer(std::array<T, N>* native_array) : native_array(native_array) {}
+
+ const T& operator[](size_t i) const {
+ return (*native_array)[ARROW_LITTLE_ENDIAN ? i : N - 1 - i];
+ }
+ T& operator[](size_t i) { return (*native_array)[ARROW_LITTLE_ENDIAN ? i : N - 1 - i]; }
+};
+
+} // namespace detail
+
+// Construct array reader and try to deduce template augments
+template <typename T, size_t N>
+static inline detail::Reader<T, N> Make(const std::array<T, N>& native_array) {
+ return detail::Reader<T, N>(native_array);
+}
+
+// Construct array writer and try to deduce template augments
+template <typename T, size_t N>
+static inline detail::Writer<T, N> Make(std::array<T, N>* native_array) {
+ return detail::Writer<T, N>(native_array);
+}
+
+// Convert little endian array to native endian
+template <typename T, size_t N>
+static inline std::array<T, N> ToNative(std::array<T, N> array) {
+ if (!ARROW_LITTLE_ENDIAN) {
+ std::reverse(array.begin(), array.end());
+ }
+ return array;
+}
+
+// Convert native endian array to little endian
+template <typename T, size_t N>
+static inline std::array<T, N> FromNative(std::array<T, N> array) {
+ return ToNative(array);
+}
+
+} // namespace LittleEndianArray
+
+} // namespace BitUtil
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/formatting.cc b/src/arrow/cpp/src/arrow/util/formatting.cc
new file mode 100644
index 000000000..c16d42ce5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/formatting.cc
@@ -0,0 +1,91 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/formatting.h"
+#include "arrow/util/config.h"
+#include "arrow/util/double_conversion.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using util::double_conversion::DoubleToStringConverter;
+
+static constexpr int kMinBufferSize = DoubleToStringConverter::kBase10MaximalLength + 1;
+
+namespace internal {
+namespace detail {
+
+const char digit_pairs[] =
+ "0001020304050607080910111213141516171819"
+ "2021222324252627282930313233343536373839"
+ "4041424344454647484950515253545556575859"
+ "6061626364656667686970717273747576777879"
+ "8081828384858687888990919293949596979899";
+
+} // namespace detail
+
+struct FloatToStringFormatter::Impl {
+ Impl()
+ : converter_(DoubleToStringConverter::EMIT_POSITIVE_EXPONENT_SIGN, "inf", "nan",
+ 'e', -6, 10, 6, 0) {}
+
+ Impl(int flags, const char* inf_symbol, const char* nan_symbol, char exp_character,
+ int decimal_in_shortest_low, int decimal_in_shortest_high,
+ int max_leading_padding_zeroes_in_precision_mode,
+ int max_trailing_padding_zeroes_in_precision_mode)
+ : converter_(flags, inf_symbol, nan_symbol, exp_character, decimal_in_shortest_low,
+ decimal_in_shortest_high, max_leading_padding_zeroes_in_precision_mode,
+ max_trailing_padding_zeroes_in_precision_mode) {}
+
+ DoubleToStringConverter converter_;
+};
+
+FloatToStringFormatter::FloatToStringFormatter() : impl_(new Impl()) {}
+
+FloatToStringFormatter::FloatToStringFormatter(
+ int flags, const char* inf_symbol, const char* nan_symbol, char exp_character,
+ int decimal_in_shortest_low, int decimal_in_shortest_high,
+ int max_leading_padding_zeroes_in_precision_mode,
+ int max_trailing_padding_zeroes_in_precision_mode)
+ : impl_(new Impl(flags, inf_symbol, nan_symbol, exp_character,
+ decimal_in_shortest_low, decimal_in_shortest_high,
+ max_leading_padding_zeroes_in_precision_mode,
+ max_trailing_padding_zeroes_in_precision_mode)) {}
+
+FloatToStringFormatter::~FloatToStringFormatter() {}
+
+int FloatToStringFormatter::FormatFloat(float v, char* out_buffer, int out_size) {
+ DCHECK_GE(out_size, kMinBufferSize);
+ // StringBuilder checks bounds in debug mode for us
+ util::double_conversion::StringBuilder builder(out_buffer, out_size);
+ bool result = impl_->converter_.ToShortestSingle(v, &builder);
+ DCHECK(result);
+ ARROW_UNUSED(result);
+ return builder.position();
+}
+
+int FloatToStringFormatter::FormatFloat(double v, char* out_buffer, int out_size) {
+ DCHECK_GE(out_size, kMinBufferSize);
+ util::double_conversion::StringBuilder builder(out_buffer, out_size);
+ bool result = impl_->converter_.ToShortest(v, &builder);
+ DCHECK(result);
+ ARROW_UNUSED(result);
+ return builder.position();
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/formatting.h b/src/arrow/cpp/src/arrow/util/formatting.h
new file mode 100644
index 000000000..09eb748e4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/formatting.h
@@ -0,0 +1,602 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This is a private header for number-to-string formatting utilities
+
+#pragma once
+
+#include <array>
+#include <cassert>
+#include <chrono>
+#include <limits>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/double_conversion.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/time.h"
+#include "arrow/util/visibility.h"
+#include "arrow/vendored/datetime.h"
+
+namespace arrow {
+namespace internal {
+
+/// \brief The entry point for conversion to strings.
+template <typename ARROW_TYPE, typename Enable = void>
+class StringFormatter;
+
+template <typename T>
+struct is_formattable {
+ template <typename U, typename = typename StringFormatter<U>::value_type>
+ static std::true_type Test(U*);
+
+ template <typename U>
+ static std::false_type Test(...);
+
+ static constexpr bool value = decltype(Test<T>(NULLPTR))::value;
+};
+
+template <typename T, typename R = void>
+using enable_if_formattable = enable_if_t<is_formattable<T>::value, R>;
+
+template <typename Appender>
+using Return = decltype(std::declval<Appender>()(util::string_view{}));
+
+/////////////////////////////////////////////////////////////////////////
+// Boolean formatting
+
+template <>
+class StringFormatter<BooleanType> {
+ public:
+ explicit StringFormatter(const std::shared_ptr<DataType>& = NULLPTR) {}
+
+ using value_type = bool;
+
+ template <typename Appender>
+ Return<Appender> operator()(bool value, Appender&& append) {
+ if (value) {
+ const char string[] = "true";
+ return append(util::string_view(string));
+ } else {
+ const char string[] = "false";
+ return append(util::string_view(string));
+ }
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////
+// Integer formatting
+
+namespace detail {
+
+// A 2x100 direct table mapping integers in [0..99] to their decimal representations.
+ARROW_EXPORT extern const char digit_pairs[];
+
+// Based on fmtlib's format_int class:
+// Write digits from right to left into a stack allocated buffer
+inline void FormatOneChar(char c, char** cursor) { *--*cursor = c; }
+
+template <typename Int>
+void FormatOneDigit(Int value, char** cursor) {
+ assert(value >= 0 && value <= 9);
+ FormatOneChar(static_cast<char>('0' + value), cursor);
+}
+
+template <typename Int>
+void FormatTwoDigits(Int value, char** cursor) {
+ assert(value >= 0 && value <= 99);
+ auto digit_pair = &digit_pairs[value * 2];
+ FormatOneChar(digit_pair[1], cursor);
+ FormatOneChar(digit_pair[0], cursor);
+}
+
+template <typename Int>
+void FormatAllDigits(Int value, char** cursor) {
+ assert(value >= 0);
+ while (value >= 100) {
+ FormatTwoDigits(value % 100, cursor);
+ value /= 100;
+ }
+
+ if (value >= 10) {
+ FormatTwoDigits(value, cursor);
+ } else {
+ FormatOneDigit(value, cursor);
+ }
+}
+
+template <typename Int>
+void FormatAllDigitsLeftPadded(Int value, size_t pad, char pad_char, char** cursor) {
+ auto end = *cursor - pad;
+ FormatAllDigits(value, cursor);
+ while (*cursor > end) {
+ FormatOneChar(pad_char, cursor);
+ }
+}
+
+template <size_t BUFFER_SIZE>
+util::string_view ViewDigitBuffer(const std::array<char, BUFFER_SIZE>& buffer,
+ char* cursor) {
+ auto buffer_end = buffer.data() + BUFFER_SIZE;
+ return {cursor, static_cast<size_t>(buffer_end - cursor)};
+}
+
+template <typename Int, typename UInt = typename std::make_unsigned<Int>::type>
+constexpr UInt Abs(Int value) {
+ return value < 0 ? ~static_cast<UInt>(value) + 1 : static_cast<UInt>(value);
+}
+
+template <typename Int>
+constexpr size_t Digits10(Int value) {
+ return value <= 9 ? 1 : Digits10(value / 10) + 1;
+}
+
+} // namespace detail
+
+template <typename ARROW_TYPE>
+class IntToStringFormatterMixin {
+ public:
+ explicit IntToStringFormatterMixin(const std::shared_ptr<DataType>& = NULLPTR) {}
+
+ using value_type = typename ARROW_TYPE::c_type;
+
+ template <typename Appender>
+ Return<Appender> operator()(value_type value, Appender&& append) {
+ constexpr size_t buffer_size =
+ detail::Digits10(std::numeric_limits<value_type>::max()) + 1;
+
+ std::array<char, buffer_size> buffer;
+ char* cursor = buffer.data() + buffer_size;
+ detail::FormatAllDigits(detail::Abs(value), &cursor);
+ if (value < 0) {
+ detail::FormatOneChar('-', &cursor);
+ }
+ return append(detail::ViewDigitBuffer(buffer, cursor));
+ }
+};
+
+template <>
+class StringFormatter<Int8Type> : public IntToStringFormatterMixin<Int8Type> {
+ using IntToStringFormatterMixin::IntToStringFormatterMixin;
+};
+
+template <>
+class StringFormatter<Int16Type> : public IntToStringFormatterMixin<Int16Type> {
+ using IntToStringFormatterMixin::IntToStringFormatterMixin;
+};
+
+template <>
+class StringFormatter<Int32Type> : public IntToStringFormatterMixin<Int32Type> {
+ using IntToStringFormatterMixin::IntToStringFormatterMixin;
+};
+
+template <>
+class StringFormatter<Int64Type> : public IntToStringFormatterMixin<Int64Type> {
+ using IntToStringFormatterMixin::IntToStringFormatterMixin;
+};
+
+template <>
+class StringFormatter<UInt8Type> : public IntToStringFormatterMixin<UInt8Type> {
+ using IntToStringFormatterMixin::IntToStringFormatterMixin;
+};
+
+template <>
+class StringFormatter<UInt16Type> : public IntToStringFormatterMixin<UInt16Type> {
+ using IntToStringFormatterMixin::IntToStringFormatterMixin;
+};
+
+template <>
+class StringFormatter<UInt32Type> : public IntToStringFormatterMixin<UInt32Type> {
+ using IntToStringFormatterMixin::IntToStringFormatterMixin;
+};
+
+template <>
+class StringFormatter<UInt64Type> : public IntToStringFormatterMixin<UInt64Type> {
+ using IntToStringFormatterMixin::IntToStringFormatterMixin;
+};
+
+/////////////////////////////////////////////////////////////////////////
+// Floating-point formatting
+
+class ARROW_EXPORT FloatToStringFormatter {
+ public:
+ FloatToStringFormatter();
+ FloatToStringFormatter(int flags, const char* inf_symbol, const char* nan_symbol,
+ char exp_character, int decimal_in_shortest_low,
+ int decimal_in_shortest_high,
+ int max_leading_padding_zeroes_in_precision_mode,
+ int max_trailing_padding_zeroes_in_precision_mode);
+ ~FloatToStringFormatter();
+
+ // Returns the number of characters written
+ int FormatFloat(float v, char* out_buffer, int out_size);
+ int FormatFloat(double v, char* out_buffer, int out_size);
+
+ protected:
+ struct Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+template <typename ARROW_TYPE>
+class FloatToStringFormatterMixin : public FloatToStringFormatter {
+ public:
+ using value_type = typename ARROW_TYPE::c_type;
+
+ static constexpr int buffer_size = 50;
+
+ explicit FloatToStringFormatterMixin(const std::shared_ptr<DataType>& = NULLPTR) {}
+
+ FloatToStringFormatterMixin(int flags, const char* inf_symbol, const char* nan_symbol,
+ char exp_character, int decimal_in_shortest_low,
+ int decimal_in_shortest_high,
+ int max_leading_padding_zeroes_in_precision_mode,
+ int max_trailing_padding_zeroes_in_precision_mode)
+ : FloatToStringFormatter(flags, inf_symbol, nan_symbol, exp_character,
+ decimal_in_shortest_low, decimal_in_shortest_high,
+ max_leading_padding_zeroes_in_precision_mode,
+ max_trailing_padding_zeroes_in_precision_mode) {}
+
+ template <typename Appender>
+ Return<Appender> operator()(value_type value, Appender&& append) {
+ char buffer[buffer_size];
+ int size = FormatFloat(value, buffer, buffer_size);
+ return append(util::string_view(buffer, size));
+ }
+};
+
+template <>
+class StringFormatter<FloatType> : public FloatToStringFormatterMixin<FloatType> {
+ public:
+ using FloatToStringFormatterMixin::FloatToStringFormatterMixin;
+};
+
+template <>
+class StringFormatter<DoubleType> : public FloatToStringFormatterMixin<DoubleType> {
+ public:
+ using FloatToStringFormatterMixin::FloatToStringFormatterMixin;
+};
+
+/////////////////////////////////////////////////////////////////////////
+// Temporal formatting
+
+namespace detail {
+
+constexpr size_t BufferSizeYYYY_MM_DD() {
+ return 1 + detail::Digits10(99999) + 1 + detail::Digits10(12) + 1 +
+ detail::Digits10(31);
+}
+
+inline void FormatYYYY_MM_DD(arrow_vendored::date::year_month_day ymd, char** cursor) {
+ FormatTwoDigits(static_cast<unsigned>(ymd.day()), cursor);
+ FormatOneChar('-', cursor);
+ FormatTwoDigits(static_cast<unsigned>(ymd.month()), cursor);
+ FormatOneChar('-', cursor);
+ auto year = static_cast<int>(ymd.year());
+ const auto is_neg_year = year < 0;
+ year = std::abs(year);
+ assert(year <= 99999);
+ FormatTwoDigits(year % 100, cursor);
+ year /= 100;
+ FormatTwoDigits(year % 100, cursor);
+ if (year >= 100) {
+ FormatOneDigit(year / 100, cursor);
+ }
+ if (is_neg_year) {
+ FormatOneChar('-', cursor);
+ }
+}
+
+template <typename Duration>
+constexpr size_t BufferSizeHH_MM_SS() {
+ return detail::Digits10(23) + 1 + detail::Digits10(59) + 1 + detail::Digits10(59) + 1 +
+ detail::Digits10(Duration::period::den) - 1;
+}
+
+template <typename Duration>
+void FormatHH_MM_SS(arrow_vendored::date::hh_mm_ss<Duration> hms, char** cursor) {
+ constexpr size_t subsecond_digits = Digits10(Duration::period::den) - 1;
+ if (subsecond_digits != 0) {
+ FormatAllDigitsLeftPadded(hms.subseconds().count(), subsecond_digits, '0', cursor);
+ FormatOneChar('.', cursor);
+ }
+ FormatTwoDigits(hms.seconds().count(), cursor);
+ FormatOneChar(':', cursor);
+ FormatTwoDigits(hms.minutes().count(), cursor);
+ FormatOneChar(':', cursor);
+ FormatTwoDigits(hms.hours().count(), cursor);
+}
+
+// Some out-of-bound datetime values would result in erroneous printing
+// because of silent integer wraparound in the `arrow_vendored::date` library.
+//
+// To avoid such misprinting, we must therefore check the bounds explicitly.
+// The bounds correspond to start of year -32767 and end of year 32767,
+// respectively (-32768 is an invalid year value in `arrow_vendored::date`).
+//
+// Note these values are the same as documented for C++20:
+// https://en.cppreference.com/w/cpp/chrono/year_month_day/operator_days
+template <typename Unit>
+bool IsDateTimeInRange(Unit duration) {
+ constexpr Unit kMinIncl =
+ std::chrono::duration_cast<Unit>(arrow_vendored::date::days{-12687428});
+ constexpr Unit kMaxExcl =
+ std::chrono::duration_cast<Unit>(arrow_vendored::date::days{11248738});
+ return duration >= kMinIncl && duration < kMaxExcl;
+}
+
+// IsDateTimeInRange() specialization for nanoseconds: a 64-bit number of
+// nanoseconds cannot represent years outside of the [-32767, 32767]
+// range, and the {kMinIncl, kMaxExcl} constants above would overflow.
+constexpr bool IsDateTimeInRange(std::chrono::nanoseconds duration) { return true; }
+
+template <typename Unit>
+bool IsTimeInRange(Unit duration) {
+ constexpr Unit kMinIncl = std::chrono::duration_cast<Unit>(std::chrono::seconds{0});
+ constexpr Unit kMaxExcl = std::chrono::duration_cast<Unit>(std::chrono::seconds{86400});
+ return duration >= kMinIncl && duration < kMaxExcl;
+}
+
+template <typename RawValue, typename Appender>
+Return<Appender> FormatOutOfRange(RawValue&& raw_value, Appender&& append) {
+ // XXX locale-sensitive but good enough for now
+ std::string formatted = "<value out of range: " + std::to_string(raw_value) + ">";
+ return append(std::move(formatted));
+}
+
+const auto kEpoch = arrow_vendored::date::sys_days{arrow_vendored::date::jan / 1 / 1970};
+
+} // namespace detail
+
+template <>
+class StringFormatter<DurationType> : public IntToStringFormatterMixin<DurationType> {
+ using IntToStringFormatterMixin::IntToStringFormatterMixin;
+};
+
+class DateToStringFormatterMixin {
+ public:
+ explicit DateToStringFormatterMixin(const std::shared_ptr<DataType>& = NULLPTR) {}
+
+ protected:
+ template <typename Appender>
+ Return<Appender> FormatDays(arrow_vendored::date::days since_epoch, Appender&& append) {
+ arrow_vendored::date::sys_days timepoint_days{since_epoch};
+
+ constexpr size_t buffer_size = detail::BufferSizeYYYY_MM_DD();
+
+ std::array<char, buffer_size> buffer;
+ char* cursor = buffer.data() + buffer_size;
+
+ detail::FormatYYYY_MM_DD(arrow_vendored::date::year_month_day{timepoint_days},
+ &cursor);
+ return append(detail::ViewDigitBuffer(buffer, cursor));
+ }
+};
+
+template <>
+class StringFormatter<Date32Type> : public DateToStringFormatterMixin {
+ public:
+ using value_type = typename Date32Type::c_type;
+
+ using DateToStringFormatterMixin::DateToStringFormatterMixin;
+
+ template <typename Appender>
+ Return<Appender> operator()(value_type value, Appender&& append) {
+ const auto since_epoch = arrow_vendored::date::days{value};
+ if (!ARROW_PREDICT_TRUE(detail::IsDateTimeInRange(since_epoch))) {
+ return detail::FormatOutOfRange(value, append);
+ }
+ return FormatDays(since_epoch, std::forward<Appender>(append));
+ }
+};
+
+template <>
+class StringFormatter<Date64Type> : public DateToStringFormatterMixin {
+ public:
+ using value_type = typename Date64Type::c_type;
+
+ using DateToStringFormatterMixin::DateToStringFormatterMixin;
+
+ template <typename Appender>
+ Return<Appender> operator()(value_type value, Appender&& append) {
+ const auto since_epoch = std::chrono::milliseconds{value};
+ if (!ARROW_PREDICT_TRUE(detail::IsDateTimeInRange(since_epoch))) {
+ return detail::FormatOutOfRange(value, append);
+ }
+ return FormatDays(std::chrono::duration_cast<arrow_vendored::date::days>(since_epoch),
+ std::forward<Appender>(append));
+ }
+};
+
+template <>
+class StringFormatter<TimestampType> {
+ public:
+ using value_type = int64_t;
+
+ explicit StringFormatter(const std::shared_ptr<DataType>& type)
+ : unit_(checked_cast<const TimestampType&>(*type).unit()) {}
+
+ template <typename Duration, typename Appender>
+ Return<Appender> operator()(Duration, value_type value, Appender&& append) {
+ using arrow_vendored::date::days;
+
+ const Duration since_epoch{value};
+ if (!ARROW_PREDICT_TRUE(detail::IsDateTimeInRange(since_epoch))) {
+ return detail::FormatOutOfRange(value, append);
+ }
+
+ const auto timepoint = detail::kEpoch + since_epoch;
+ // Round days towards zero
+ // (the naive approach of using arrow_vendored::date::floor() would
+ // result in UB for very large negative timestamps, similarly as
+ // https://github.com/HowardHinnant/date/issues/696)
+ auto timepoint_days = std::chrono::time_point_cast<days>(timepoint);
+ Duration since_midnight;
+ if (timepoint_days <= timepoint) {
+ // Year >= 1970
+ since_midnight = timepoint - timepoint_days;
+ } else {
+ // Year < 1970
+ since_midnight = days(1) - (timepoint_days - timepoint);
+ timepoint_days -= days(1);
+ }
+
+ constexpr size_t buffer_size =
+ detail::BufferSizeYYYY_MM_DD() + 1 + detail::BufferSizeHH_MM_SS<Duration>();
+
+ std::array<char, buffer_size> buffer;
+ char* cursor = buffer.data() + buffer_size;
+
+ detail::FormatHH_MM_SS(arrow_vendored::date::make_time(since_midnight), &cursor);
+ detail::FormatOneChar(' ', &cursor);
+ detail::FormatYYYY_MM_DD(timepoint_days, &cursor);
+ return append(detail::ViewDigitBuffer(buffer, cursor));
+ }
+
+ template <typename Appender>
+ Return<Appender> operator()(value_type value, Appender&& append) {
+ return util::VisitDuration(unit_, *this, value, std::forward<Appender>(append));
+ }
+
+ private:
+ TimeUnit::type unit_;
+};
+
+template <typename T>
+class StringFormatter<T, enable_if_time<T>> {
+ public:
+ using value_type = typename T::c_type;
+
+ explicit StringFormatter(const std::shared_ptr<DataType>& type)
+ : unit_(checked_cast<const T&>(*type).unit()) {}
+
+ template <typename Duration, typename Appender>
+ Return<Appender> operator()(Duration, value_type count, Appender&& append) {
+ const Duration since_midnight{count};
+ if (!ARROW_PREDICT_TRUE(detail::IsTimeInRange(since_midnight))) {
+ return detail::FormatOutOfRange(count, append);
+ }
+
+ constexpr size_t buffer_size = detail::BufferSizeHH_MM_SS<Duration>();
+
+ std::array<char, buffer_size> buffer;
+ char* cursor = buffer.data() + buffer_size;
+
+ detail::FormatHH_MM_SS(arrow_vendored::date::make_time(since_midnight), &cursor);
+ return append(detail::ViewDigitBuffer(buffer, cursor));
+ }
+
+ template <typename Appender>
+ Return<Appender> operator()(value_type value, Appender&& append) {
+ return util::VisitDuration(unit_, *this, value, std::forward<Appender>(append));
+ }
+
+ private:
+ TimeUnit::type unit_;
+};
+
+template <>
+class StringFormatter<MonthIntervalType> {
+ public:
+ using value_type = MonthIntervalType::c_type;
+
+ explicit StringFormatter(const std::shared_ptr<DataType>&) {}
+
+ template <typename Appender>
+ Return<Appender> operator()(value_type interval, Appender&& append) {
+ constexpr size_t buffer_size =
+ /*'m'*/ 3 + /*negative signs*/ 1 +
+ /*months*/ detail::Digits10(std::numeric_limits<value_type>::max());
+ std::array<char, buffer_size> buffer;
+ char* cursor = buffer.data() + buffer_size;
+
+ detail::FormatOneChar('M', &cursor);
+ detail::FormatAllDigits(detail::Abs(interval), &cursor);
+ if (interval < 0) detail::FormatOneChar('-', &cursor);
+
+ return append(detail::ViewDigitBuffer(buffer, cursor));
+ }
+};
+
+template <>
+class StringFormatter<DayTimeIntervalType> {
+ public:
+ using value_type = DayTimeIntervalType::DayMilliseconds;
+
+ explicit StringFormatter(const std::shared_ptr<DataType>&) {}
+
+ template <typename Appender>
+ Return<Appender> operator()(value_type interval, Appender&& append) {
+ constexpr size_t buffer_size =
+ /*d, ms*/ 3 + /*negative signs*/ 2 +
+ /*days/milliseconds*/ 2 * detail::Digits10(std::numeric_limits<int32_t>::max());
+ std::array<char, buffer_size> buffer;
+ char* cursor = buffer.data() + buffer_size;
+
+ detail::FormatOneChar('s', &cursor);
+ detail::FormatOneChar('m', &cursor);
+ detail::FormatAllDigits(detail::Abs(interval.milliseconds), &cursor);
+ if (interval.milliseconds < 0) detail::FormatOneChar('-', &cursor);
+
+ detail::FormatOneChar('d', &cursor);
+ detail::FormatAllDigits(detail::Abs(interval.days), &cursor);
+ if (interval.days < 0) detail::FormatOneChar('-', &cursor);
+
+ return append(detail::ViewDigitBuffer(buffer, cursor));
+ }
+};
+
+template <>
+class StringFormatter<MonthDayNanoIntervalType> {
+ public:
+ using value_type = MonthDayNanoIntervalType::MonthDayNanos;
+
+ explicit StringFormatter(const std::shared_ptr<DataType>&) {}
+
+ template <typename Appender>
+ Return<Appender> operator()(value_type interval, Appender&& append) {
+ constexpr size_t buffer_size =
+ /*m, d, ns*/ 4 + /*negative signs*/ 3 +
+ /*months/days*/ 2 * detail::Digits10(std::numeric_limits<int32_t>::max()) +
+ /*nanoseconds*/ detail::Digits10(std::numeric_limits<int64_t>::max());
+ std::array<char, buffer_size> buffer;
+ char* cursor = buffer.data() + buffer_size;
+
+ detail::FormatOneChar('s', &cursor);
+ detail::FormatOneChar('n', &cursor);
+ detail::FormatAllDigits(detail::Abs(interval.nanoseconds), &cursor);
+ if (interval.nanoseconds < 0) detail::FormatOneChar('-', &cursor);
+
+ detail::FormatOneChar('d', &cursor);
+ detail::FormatAllDigits(detail::Abs(interval.days), &cursor);
+ if (interval.days < 0) detail::FormatOneChar('-', &cursor);
+
+ detail::FormatOneChar('M', &cursor);
+ detail::FormatAllDigits(detail::Abs(interval.months), &cursor);
+ if (interval.months < 0) detail::FormatOneChar('-', &cursor);
+
+ return append(detail::ViewDigitBuffer(buffer, cursor));
+ }
+};
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/formatting_util_test.cc b/src/arrow/cpp/src/arrow/util/formatting_util_test.cc
new file mode 100644
index 000000000..3e7855187
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/formatting_util_test.cc
@@ -0,0 +1,468 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cmath>
+#include <locale>
+#include <stdexcept>
+#include <string>
+
+#include <gtest/gtest.h>
+
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/formatting.h"
+
+namespace arrow {
+
+using internal::StringFormatter;
+
+class StringAppender {
+ public:
+ Status operator()(util::string_view v) {
+ string_.append(v.data(), v.size());
+ return Status::OK();
+ }
+
+ std::string string() const { return string_; }
+
+ protected:
+ std::string string_;
+};
+
+template <typename FormatterType, typename C_TYPE = typename FormatterType::value_type>
+void AssertFormatting(FormatterType& formatter, C_TYPE value,
+ const std::string& expected) {
+ StringAppender appender;
+ ASSERT_OK(formatter(value, appender));
+ ASSERT_EQ(appender.string(), expected) << "Formatting failed (value = " << value << ")";
+}
+
+TEST(Formatting, Boolean) {
+ StringFormatter<BooleanType> formatter;
+
+ AssertFormatting(formatter, true, "true");
+ AssertFormatting(formatter, false, "false");
+}
+
+template <typename FormatterType>
+void TestAnyIntUpTo8(FormatterType& formatter) {
+ AssertFormatting(formatter, 0, "0");
+ AssertFormatting(formatter, 1, "1");
+ AssertFormatting(formatter, 9, "9");
+ AssertFormatting(formatter, 10, "10");
+ AssertFormatting(formatter, 99, "99");
+ AssertFormatting(formatter, 100, "100");
+ AssertFormatting(formatter, 127, "127");
+}
+
+template <typename FormatterType>
+void TestAnyIntUpTo16(FormatterType& formatter) {
+ TestAnyIntUpTo8(formatter);
+ AssertFormatting(formatter, 999, "999");
+ AssertFormatting(formatter, 1000, "1000");
+ AssertFormatting(formatter, 9999, "9999");
+ AssertFormatting(formatter, 10000, "10000");
+ AssertFormatting(formatter, 32767, "32767");
+}
+
+template <typename FormatterType>
+void TestAnyIntUpTo32(FormatterType& formatter) {
+ TestAnyIntUpTo16(formatter);
+ AssertFormatting(formatter, 99999, "99999");
+ AssertFormatting(formatter, 100000, "100000");
+ AssertFormatting(formatter, 999999, "999999");
+ AssertFormatting(formatter, 1000000, "1000000");
+ AssertFormatting(formatter, 9999999, "9999999");
+ AssertFormatting(formatter, 10000000, "10000000");
+ AssertFormatting(formatter, 99999999, "99999999");
+ AssertFormatting(formatter, 100000000, "100000000");
+ AssertFormatting(formatter, 999999999, "999999999");
+ AssertFormatting(formatter, 1000000000, "1000000000");
+ AssertFormatting(formatter, 1234567890, "1234567890");
+ AssertFormatting(formatter, 2147483647, "2147483647");
+}
+
+template <typename FormatterType>
+void TestAnyIntUpTo64(FormatterType& formatter) {
+ TestAnyIntUpTo32(formatter);
+ AssertFormatting(formatter, 9999999999ULL, "9999999999");
+ AssertFormatting(formatter, 10000000000ULL, "10000000000");
+ AssertFormatting(formatter, 99999999999ULL, "99999999999");
+ AssertFormatting(formatter, 100000000000ULL, "100000000000");
+ AssertFormatting(formatter, 999999999999ULL, "999999999999");
+ AssertFormatting(formatter, 1000000000000ULL, "1000000000000");
+ AssertFormatting(formatter, 9999999999999ULL, "9999999999999");
+ AssertFormatting(formatter, 10000000000000ULL, "10000000000000");
+ AssertFormatting(formatter, 99999999999999ULL, "99999999999999");
+ AssertFormatting(formatter, 1000000000000000000ULL, "1000000000000000000");
+ AssertFormatting(formatter, 9223372036854775807ULL, "9223372036854775807");
+}
+
+template <typename FormatterType>
+void TestUIntUpTo8(FormatterType& formatter) {
+ TestAnyIntUpTo8(formatter);
+ AssertFormatting(formatter, 128, "128");
+ AssertFormatting(formatter, 255, "255");
+}
+
+template <typename FormatterType>
+void TestUIntUpTo16(FormatterType& formatter) {
+ TestAnyIntUpTo16(formatter);
+ AssertFormatting(formatter, 32768, "32768");
+ AssertFormatting(formatter, 65535, "65535");
+}
+
+template <typename FormatterType>
+void TestUIntUpTo32(FormatterType& formatter) {
+ TestAnyIntUpTo32(formatter);
+ AssertFormatting(formatter, 2147483648U, "2147483648");
+ AssertFormatting(formatter, 4294967295U, "4294967295");
+}
+
+template <typename FormatterType>
+void TestUIntUpTo64(FormatterType& formatter) {
+ TestAnyIntUpTo64(formatter);
+ AssertFormatting(formatter, 9999999999999999999ULL, "9999999999999999999");
+ AssertFormatting(formatter, 10000000000000000000ULL, "10000000000000000000");
+ AssertFormatting(formatter, 12345678901234567890ULL, "12345678901234567890");
+ AssertFormatting(formatter, 18446744073709551615ULL, "18446744073709551615");
+}
+
+TEST(Formatting, UInt8) {
+ StringFormatter<UInt8Type> formatter;
+
+ TestUIntUpTo8(formatter);
+}
+
+TEST(Formatting, UInt16) {
+ StringFormatter<UInt16Type> formatter;
+
+ TestUIntUpTo16(formatter);
+}
+
+TEST(Formatting, UInt32) {
+ StringFormatter<UInt32Type> formatter;
+
+ TestUIntUpTo32(formatter);
+}
+
+TEST(Formatting, UInt64) {
+ StringFormatter<UInt64Type> formatter;
+
+ TestUIntUpTo64(formatter);
+}
+
+template <typename FormatterType>
+void TestIntUpTo8(FormatterType& formatter) {
+ TestAnyIntUpTo8(formatter);
+ AssertFormatting(formatter, -1, "-1");
+ AssertFormatting(formatter, -9, "-9");
+ AssertFormatting(formatter, -10, "-10");
+ AssertFormatting(formatter, -99, "-99");
+ AssertFormatting(formatter, -100, "-100");
+ AssertFormatting(formatter, -127, "-127");
+ AssertFormatting(formatter, -128, "-128");
+}
+
+template <typename FormatterType>
+void TestIntUpTo16(FormatterType& formatter) {
+ TestAnyIntUpTo16(formatter);
+ TestIntUpTo8(formatter);
+ AssertFormatting(formatter, -129, "-129");
+ AssertFormatting(formatter, -999, "-999");
+ AssertFormatting(formatter, -1000, "-1000");
+ AssertFormatting(formatter, -9999, "-9999");
+ AssertFormatting(formatter, -10000, "-10000");
+ AssertFormatting(formatter, -32768, "-32768");
+}
+
+template <typename FormatterType>
+void TestIntUpTo32(FormatterType& formatter) {
+ TestAnyIntUpTo32(formatter);
+ TestIntUpTo16(formatter);
+ AssertFormatting(formatter, -32769, "-32769");
+ AssertFormatting(formatter, -99999, "-99999");
+ AssertFormatting(formatter, -1000000000, "-1000000000");
+ AssertFormatting(formatter, -1234567890, "-1234567890");
+ AssertFormatting(formatter, -2147483647, "-2147483647");
+ AssertFormatting(formatter, -2147483647 - 1, "-2147483648");
+}
+
+template <typename FormatterType>
+void TestIntUpTo64(FormatterType& formatter) {
+ TestAnyIntUpTo64(formatter);
+ TestIntUpTo32(formatter);
+ AssertFormatting(formatter, -2147483649LL, "-2147483649");
+ AssertFormatting(formatter, -9999999999LL, "-9999999999");
+ AssertFormatting(formatter, -1000000000000000000LL, "-1000000000000000000");
+ AssertFormatting(formatter, -9012345678901234567LL, "-9012345678901234567");
+ AssertFormatting(formatter, -9223372036854775807LL, "-9223372036854775807");
+ AssertFormatting(formatter, -9223372036854775807LL - 1, "-9223372036854775808");
+}
+
+TEST(Formatting, Int8) {
+ StringFormatter<Int8Type> formatter;
+
+ TestIntUpTo8(formatter);
+}
+
+TEST(Formatting, Int16) {
+ StringFormatter<Int16Type> formatter;
+
+ TestIntUpTo16(formatter);
+}
+
+TEST(Formatting, Int32) {
+ StringFormatter<Int32Type> formatter;
+
+ TestIntUpTo32(formatter);
+}
+
+TEST(Formatting, Int64) {
+ StringFormatter<Int64Type> formatter;
+
+ TestIntUpTo64(formatter);
+}
+
+TEST(Formatting, Float) {
+ StringFormatter<FloatType> formatter;
+
+ AssertFormatting(formatter, 0.0f, "0");
+ AssertFormatting(formatter, -0.0f, "-0");
+ AssertFormatting(formatter, 1.5f, "1.5");
+ AssertFormatting(formatter, 0.0001f, "0.0001");
+ AssertFormatting(formatter, 1234.567f, "1234.567");
+ AssertFormatting(formatter, 1e9f, "1000000000");
+ AssertFormatting(formatter, 1e10f, "1e+10");
+ AssertFormatting(formatter, 1e20f, "1e+20");
+ AssertFormatting(formatter, 1e-6f, "0.000001");
+ AssertFormatting(formatter, 1e-7f, "1e-7");
+ AssertFormatting(formatter, 1e-20f, "1e-20");
+
+ AssertFormatting(formatter, std::nanf(""), "nan");
+ AssertFormatting(formatter, HUGE_VALF, "inf");
+ AssertFormatting(formatter, -HUGE_VALF, "-inf");
+}
+
+TEST(Formatting, Double) {
+ StringFormatter<DoubleType> formatter;
+
+ AssertFormatting(formatter, 0.0, "0");
+ AssertFormatting(formatter, -0.0, "-0");
+ AssertFormatting(formatter, 1.5, "1.5");
+ AssertFormatting(formatter, 0.0001, "0.0001");
+ AssertFormatting(formatter, 1234.567, "1234.567");
+ AssertFormatting(formatter, 1e9, "1000000000");
+ AssertFormatting(formatter, 1e10, "1e+10");
+ AssertFormatting(formatter, 1e20, "1e+20");
+ AssertFormatting(formatter, 1e-6, "0.000001");
+ AssertFormatting(formatter, 1e-7, "1e-7");
+ AssertFormatting(formatter, 1e-20, "1e-20");
+
+ AssertFormatting(formatter, std::nan(""), "nan");
+ AssertFormatting(formatter, HUGE_VAL, "inf");
+ AssertFormatting(formatter, -HUGE_VAL, "-inf");
+}
+
+TEST(Formatting, Date32) {
+ StringFormatter<Date32Type> formatter;
+
+ AssertFormatting(formatter, 0, "1970-01-01");
+ AssertFormatting(formatter, 1, "1970-01-02");
+ AssertFormatting(formatter, 30, "1970-01-31");
+ AssertFormatting(formatter, 30 + 1, "1970-02-01");
+ AssertFormatting(formatter, 30 + 28, "1970-02-28");
+ AssertFormatting(formatter, 30 + 28 + 1, "1970-03-01");
+ AssertFormatting(formatter, -1, "1969-12-31");
+ AssertFormatting(formatter, 365, "1971-01-01");
+ AssertFormatting(formatter, 2 * 365, "1972-01-01");
+ AssertFormatting(formatter, 2 * 365 + 30 + 28 + 1, "1972-02-29");
+}
+
+TEST(Formatting, Date64) {
+ StringFormatter<Date64Type> formatter;
+
+ constexpr int64_t kMillisInDay = 24 * 60 * 60 * 1000;
+ AssertFormatting(formatter, kMillisInDay * (0), "1970-01-01");
+ AssertFormatting(formatter, kMillisInDay * (1), "1970-01-02");
+ AssertFormatting(formatter, kMillisInDay * (30), "1970-01-31");
+ AssertFormatting(formatter, kMillisInDay * (30 + 1), "1970-02-01");
+ AssertFormatting(formatter, kMillisInDay * (30 + 28), "1970-02-28");
+ AssertFormatting(formatter, kMillisInDay * (30 + 28 + 1), "1970-03-01");
+ AssertFormatting(formatter, kMillisInDay * (-1), "1969-12-31");
+ AssertFormatting(formatter, kMillisInDay * (365), "1971-01-01");
+ AssertFormatting(formatter, kMillisInDay * (2 * 365), "1972-01-01");
+ AssertFormatting(formatter, kMillisInDay * (2 * 365 + 30 + 28 + 1), "1972-02-29");
+}
+
+TEST(Formatting, Time32) {
+ {
+ StringFormatter<Time32Type> formatter(time32(TimeUnit::SECOND));
+
+ AssertFormatting(formatter, 0, "00:00:00");
+ AssertFormatting(formatter, 1, "00:00:01");
+ AssertFormatting(formatter, ((12) * 60 + 34) * 60 + 56, "12:34:56");
+ AssertFormatting(formatter, 24 * 60 * 60 - 1, "23:59:59");
+ }
+
+ {
+ StringFormatter<Time32Type> formatter(time32(TimeUnit::MILLI));
+
+ AssertFormatting(formatter, 0, "00:00:00.000");
+ AssertFormatting(formatter, 1, "00:00:00.001");
+ AssertFormatting(formatter, 1000, "00:00:01.000");
+ AssertFormatting(formatter, (((12) * 60 + 34) * 60 + 56) * 1000 + 789,
+ "12:34:56.789");
+ AssertFormatting(formatter, 24 * 60 * 60 * 1000 - 1, "23:59:59.999");
+ }
+}
+
+TEST(Formatting, Time64) {
+ {
+ StringFormatter<Time64Type> formatter(time64(TimeUnit::MICRO));
+
+ AssertFormatting(formatter, 0, "00:00:00.000000");
+ AssertFormatting(formatter, 1, "00:00:00.000001");
+ AssertFormatting(formatter, 1000000, "00:00:01.000000");
+ AssertFormatting(formatter, (((12) * 60 + 34) * 60 + 56) * 1000000LL + 789000,
+ "12:34:56.789000");
+ AssertFormatting(formatter, (24 * 60 * 60) * 1000000LL - 1, "23:59:59.999999");
+ }
+
+ {
+ StringFormatter<Time64Type> formatter(time64(TimeUnit::NANO));
+
+ AssertFormatting(formatter, 0, "00:00:00.000000000");
+ AssertFormatting(formatter, 1, "00:00:00.000000001");
+ AssertFormatting(formatter, 1000000000LL, "00:00:01.000000000");
+ AssertFormatting(formatter, (((12) * 60 + 34) * 60 + 56) * 1000000000LL + 789000000LL,
+ "12:34:56.789000000");
+ AssertFormatting(formatter, (24 * 60 * 60) * 1000000000LL - 1, "23:59:59.999999999");
+ }
+}
+
+TEST(Formatting, Timestamp) {
+ {
+ StringFormatter<TimestampType> formatter(timestamp(TimeUnit::SECOND));
+
+ AssertFormatting(formatter, 0, "1970-01-01 00:00:00");
+ AssertFormatting(formatter, 1, "1970-01-01 00:00:01");
+ AssertFormatting(formatter, 24 * 60 * 60, "1970-01-02 00:00:00");
+ AssertFormatting(formatter, 616377600, "1989-07-14 00:00:00");
+ AssertFormatting(formatter, 951782400, "2000-02-29 00:00:00");
+ AssertFormatting(formatter, 63730281600LL, "3989-07-14 00:00:00");
+ AssertFormatting(formatter, -2203977600LL, "1900-02-28 00:00:00");
+
+ AssertFormatting(formatter, 1542129070, "2018-11-13 17:11:10");
+ AssertFormatting(formatter, -2203932304LL, "1900-02-28 12:34:56");
+ }
+
+ {
+ StringFormatter<TimestampType> formatter(timestamp(TimeUnit::MILLI));
+
+ AssertFormatting(formatter, 0, "1970-01-01 00:00:00.000");
+ AssertFormatting(formatter, 1000L + 1, "1970-01-01 00:00:01.001");
+ AssertFormatting(formatter, 24 * 60 * 60 * 1000LL + 2, "1970-01-02 00:00:00.002");
+ AssertFormatting(formatter, 616377600 * 1000LL + 3, "1989-07-14 00:00:00.003");
+ AssertFormatting(formatter, 951782400 * 1000LL + 4, "2000-02-29 00:00:00.004");
+ AssertFormatting(formatter, 63730281600LL * 1000LL + 5, "3989-07-14 00:00:00.005");
+ AssertFormatting(formatter, -2203977600LL * 1000LL + 6, "1900-02-28 00:00:00.006");
+
+ AssertFormatting(formatter, 1542129070LL * 1000LL + 7, "2018-11-13 17:11:10.007");
+ AssertFormatting(formatter, -2203932304LL * 1000LL + 8, "1900-02-28 12:34:56.008");
+ }
+
+ {
+ StringFormatter<TimestampType> formatter(timestamp(TimeUnit::MICRO));
+
+ AssertFormatting(formatter, 0, "1970-01-01 00:00:00.000000");
+ AssertFormatting(formatter, 1000000LL + 1, "1970-01-01 00:00:01.000001");
+ AssertFormatting(formatter, 24 * 60 * 60 * 1000000LL + 2,
+ "1970-01-02 00:00:00.000002");
+ AssertFormatting(formatter, 616377600 * 1000000LL + 3, "1989-07-14 00:00:00.000003");
+ AssertFormatting(formatter, 951782400 * 1000000LL + 4, "2000-02-29 00:00:00.000004");
+ AssertFormatting(formatter, 63730281600LL * 1000000LL + 5,
+ "3989-07-14 00:00:00.000005");
+ AssertFormatting(formatter, -2203977600LL * 1000000LL + 6,
+ "1900-02-28 00:00:00.000006");
+
+ AssertFormatting(formatter, 1542129070 * 1000000LL + 7, "2018-11-13 17:11:10.000007");
+ AssertFormatting(formatter, -2203932304LL * 1000000LL + 8,
+ "1900-02-28 12:34:56.000008");
+ }
+
+ {
+ StringFormatter<TimestampType> formatter(timestamp(TimeUnit::NANO));
+
+ AssertFormatting(formatter, 0, "1970-01-01 00:00:00.000000000");
+ AssertFormatting(formatter, 1000000000LL + 1, "1970-01-01 00:00:01.000000001");
+ AssertFormatting(formatter, 24 * 60 * 60 * 1000000000LL + 2,
+ "1970-01-02 00:00:00.000000002");
+ AssertFormatting(formatter, 616377600 * 1000000000LL + 3,
+ "1989-07-14 00:00:00.000000003");
+ AssertFormatting(formatter, 951782400 * 1000000000LL + 4,
+ "2000-02-29 00:00:00.000000004");
+ AssertFormatting(formatter, -2203977600LL * 1000000000LL + 6,
+ "1900-02-28 00:00:00.000000006");
+
+ AssertFormatting(formatter, 1542129070 * 1000000000LL + 7,
+ "2018-11-13 17:11:10.000000007");
+ AssertFormatting(formatter, -2203932304LL * 1000000000LL + 8,
+ "1900-02-28 12:34:56.000000008");
+ }
+}
+
+TEST(Formatting, Interval) {
+ using DayMilliseconds = DayTimeIntervalType::DayMilliseconds;
+ using MonthDayNanos = MonthDayNanoIntervalType::MonthDayNanos;
+
+ const int32_t max_int32 = std::numeric_limits<int32_t>::max();
+ const int32_t min_int32 = std::numeric_limits<int32_t>::min();
+ const int64_t max_int64 = std::numeric_limits<int64_t>::max();
+ const int64_t min_int64 = std::numeric_limits<int64_t>::min();
+ {
+ StringFormatter<MonthIntervalType> formatter(month_interval());
+
+ AssertFormatting(formatter, 0, "0M");
+ AssertFormatting(formatter, -1, "-1M");
+ AssertFormatting(formatter, min_int32, "-2147483648M");
+ AssertFormatting(formatter, max_int32, "2147483647M");
+ }
+ {
+ StringFormatter<DayTimeIntervalType> formatter(day_time_interval());
+
+ AssertFormatting(formatter, DayMilliseconds{0, 0}, "0d0ms");
+ AssertFormatting(formatter, DayMilliseconds{-1, -1}, "-1d-1ms");
+ AssertFormatting(formatter, DayMilliseconds{min_int32, min_int32},
+ "-2147483648d-2147483648ms");
+ AssertFormatting(formatter, DayMilliseconds{max_int32, max_int32},
+ "2147483647d2147483647ms");
+ }
+ {
+ StringFormatter<MonthDayNanoIntervalType> formatter(month_day_nano_interval());
+
+ AssertFormatting(formatter, MonthDayNanos{0, 0, 0}, "0M0d0ns");
+ AssertFormatting(formatter, MonthDayNanos{-1, -1, -1}, "-1M-1d-1ns");
+ AssertFormatting(formatter, MonthDayNanos{min_int32, min_int32, min_int64},
+ "-2147483648M-2147483648d-9223372036854775808ns");
+ AssertFormatting(formatter, MonthDayNanos{max_int32, max_int32, max_int64},
+ "2147483647M2147483647d9223372036854775807ns");
+ }
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/functional.h b/src/arrow/cpp/src/arrow/util/functional.h
new file mode 100644
index 000000000..41e268852
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/functional.h
@@ -0,0 +1,160 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <tuple>
+#include <type_traits>
+
+#include "arrow/result.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace internal {
+
+struct Empty {
+ static Result<Empty> ToResult(Status s) {
+ if (ARROW_PREDICT_TRUE(s.ok())) {
+ return Empty{};
+ }
+ return s;
+ }
+};
+
+/// Helper struct for examining lambdas and other callables.
+/// TODO(ARROW-12655) support function pointers
+struct call_traits {
+ public:
+ template <typename R, typename... A>
+ static std::false_type is_overloaded_impl(R(A...));
+
+ template <typename F>
+ static std::false_type is_overloaded_impl(decltype(&F::operator())*);
+
+ template <typename F>
+ static std::true_type is_overloaded_impl(...);
+
+ template <typename F, typename R, typename... A>
+ static R return_type_impl(R (F::*)(A...));
+
+ template <typename F, typename R, typename... A>
+ static R return_type_impl(R (F::*)(A...) const);
+
+ template <std::size_t I, typename F, typename R, typename... A>
+ static typename std::tuple_element<I, std::tuple<A...>>::type argument_type_impl(
+ R (F::*)(A...));
+
+ template <std::size_t I, typename F, typename R, typename... A>
+ static typename std::tuple_element<I, std::tuple<A...>>::type argument_type_impl(
+ R (F::*)(A...) const);
+
+ template <std::size_t I, typename F, typename R, typename... A>
+ static typename std::tuple_element<I, std::tuple<A...>>::type argument_type_impl(
+ R (F::*)(A...) &&);
+
+ template <typename F, typename R, typename... A>
+ static std::integral_constant<int, sizeof...(A)> argument_count_impl(R (F::*)(A...));
+
+ template <typename F, typename R, typename... A>
+ static std::integral_constant<int, sizeof...(A)> argument_count_impl(R (F::*)(A...)
+ const);
+
+ template <typename F, typename R, typename... A>
+ static std::integral_constant<int, sizeof...(A)> argument_count_impl(R (F::*)(A...) &&);
+
+ /// bool constant indicating whether F is a callable with more than one possible
+ /// signature. Will be true_type for objects which define multiple operator() or which
+ /// define a template operator()
+ template <typename F>
+ using is_overloaded =
+ decltype(is_overloaded_impl<typename std::decay<F>::type>(NULLPTR));
+
+ template <typename F, typename T = void>
+ using enable_if_overloaded = typename std::enable_if<is_overloaded<F>::value, T>::type;
+
+ template <typename F, typename T = void>
+ using disable_if_overloaded =
+ typename std::enable_if<!is_overloaded<F>::value, T>::type;
+
+ /// If F is not overloaded, the argument types of its call operator can be
+ /// extracted via call_traits::argument_type<Index, F>
+ template <std::size_t I, typename F>
+ using argument_type = decltype(argument_type_impl<I>(&std::decay<F>::type::operator()));
+
+ template <typename F>
+ using argument_count = decltype(argument_count_impl(&std::decay<F>::type::operator()));
+
+ template <typename F>
+ using return_type = decltype(return_type_impl(&std::decay<F>::type::operator()));
+
+ template <typename F, typename T, typename RT = T>
+ using enable_if_return =
+ typename std::enable_if<std::is_same<return_type<F>, T>::value, RT>;
+
+ template <typename T, typename R = void>
+ using enable_if_empty = typename std::enable_if<std::is_same<T, Empty>::value, R>::type;
+
+ template <typename T, typename R = void>
+ using enable_if_not_empty =
+ typename std::enable_if<!std::is_same<T, Empty>::value, R>::type;
+};
+
+/// A type erased callable object which may only be invoked once.
+/// It can be constructed from any lambda which matches the provided call signature.
+/// Invoking it results in destruction of the lambda, freeing any state/references
+/// immediately. Invoking a default constructed FnOnce or one which has already been
+/// invoked will segfault.
+template <typename Signature>
+class FnOnce;
+
+template <typename R, typename... A>
+class FnOnce<R(A...)> {
+ public:
+ FnOnce() = default;
+
+ template <typename Fn,
+ typename = typename std::enable_if<std::is_convertible<
+ decltype(std::declval<Fn&&>()(std::declval<A>()...)), R>::value>::type>
+ FnOnce(Fn fn) : impl_(new FnImpl<Fn>(std::move(fn))) { // NOLINT runtime/explicit
+ }
+
+ explicit operator bool() const { return impl_ != NULLPTR; }
+
+ R operator()(A... a) && {
+ auto bye = std::move(impl_);
+ return bye->invoke(std::forward<A&&>(a)...);
+ }
+
+ private:
+ struct Impl {
+ virtual ~Impl() = default;
+ virtual R invoke(A&&... a) = 0;
+ };
+
+ template <typename Fn>
+ struct FnImpl : Impl {
+ explicit FnImpl(Fn fn) : fn_(std::move(fn)) {}
+ R invoke(A&&... a) override { return std::move(fn_)(std::forward<A&&>(a)...); }
+ Fn fn_;
+ };
+
+ std::unique_ptr<Impl> impl_;
+};
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/future.cc b/src/arrow/cpp/src/arrow/util/future.cc
new file mode 100644
index 000000000..c398d9928
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/future.cc
@@ -0,0 +1,437 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/future.h"
+
+#include <algorithm>
+#include <atomic>
+#include <chrono>
+#include <condition_variable>
+#include <mutex>
+#include <numeric>
+
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+// Shared mutex for all FutureWaiter instances.
+// This simplifies lock management compared to a per-waiter mutex.
+// The locking order is: global waiter mutex, then per-future mutex.
+//
+// It is unlikely that many waiter instances are alive at once, so this
+// should ideally not limit scalability.
+static std::mutex global_waiter_mutex;
+
+const double FutureWaiter::kInfinity = HUGE_VAL;
+
+class FutureWaiterImpl : public FutureWaiter {
+ public:
+ FutureWaiterImpl(Kind kind, std::vector<FutureImpl*> futures)
+ : signalled_(false),
+ kind_(kind),
+ futures_(std::move(futures)),
+ one_failed_(-1),
+ fetch_pos_(0) {
+ finished_futures_.reserve(futures_.size());
+
+ // Observe the current state of futures and add waiters to receive future
+ // state changes, atomically per future.
+ // We need to lock ourselves, because as soon as SetWaiter() is called,
+ // a FutureImpl may call MarkFutureFinished() from another thread
+ // before this constructor finishes.
+ std::unique_lock<std::mutex> lock(global_waiter_mutex);
+
+ for (int i = 0; i < static_cast<int>(futures_.size()); ++i) {
+ const auto state = futures_[i]->SetWaiter(this, i);
+ if (IsFutureFinished(state)) {
+ finished_futures_.push_back(i);
+ }
+ if (state != FutureState::SUCCESS) {
+ one_failed_ = i;
+ }
+ }
+
+ // Maybe signal the waiter, if the ending condition is already satisfied
+ if (ShouldSignal()) {
+ // No need to notify non-existent Wait() calls
+ signalled_ = true;
+ }
+ }
+
+ ~FutureWaiterImpl() override {
+ for (auto future : futures_) {
+ future->RemoveWaiter(this);
+ }
+ }
+
+ // Is the ending condition satisfied?
+ bool ShouldSignal() {
+ bool do_signal = false;
+ switch (kind_) {
+ case ANY:
+ do_signal = (finished_futures_.size() > 0);
+ break;
+ case ALL:
+ do_signal = (finished_futures_.size() == futures_.size());
+ break;
+ case ALL_OR_FIRST_FAILED:
+ do_signal = (finished_futures_.size() == futures_.size()) || one_failed_ >= 0;
+ break;
+ case ITERATE:
+ do_signal = (finished_futures_.size() > static_cast<size_t>(fetch_pos_));
+ break;
+ }
+ return do_signal;
+ }
+
+ void Signal() {
+ signalled_ = true;
+ cv_.notify_one();
+ }
+
+ void DoWaitUnlocked(std::unique_lock<std::mutex>* lock) {
+ cv_.wait(*lock, [this] { return signalled_.load(); });
+ }
+
+ bool DoWait() {
+ if (signalled_) {
+ return true;
+ }
+ std::unique_lock<std::mutex> lock(global_waiter_mutex);
+ DoWaitUnlocked(&lock);
+ return true;
+ }
+
+ template <class Rep, class Period>
+ bool DoWait(const std::chrono::duration<Rep, Period>& duration) {
+ if (signalled_) {
+ return true;
+ }
+ std::unique_lock<std::mutex> lock(global_waiter_mutex);
+ cv_.wait_for(lock, duration, [this] { return signalled_.load(); });
+ return signalled_.load();
+ }
+
+ void DoMarkFutureFinishedUnlocked(int future_num, FutureState state) {
+ finished_futures_.push_back(future_num);
+ if (state != FutureState::SUCCESS) {
+ one_failed_ = future_num;
+ }
+ if (!signalled_ && ShouldSignal()) {
+ Signal();
+ }
+ }
+
+ int DoWaitAndFetchOne() {
+ std::unique_lock<std::mutex> lock(global_waiter_mutex);
+
+ DCHECK_EQ(kind_, ITERATE);
+ DoWaitUnlocked(&lock);
+ DCHECK_LT(static_cast<size_t>(fetch_pos_), finished_futures_.size());
+ if (static_cast<size_t>(fetch_pos_) == finished_futures_.size() - 1) {
+ signalled_ = false;
+ }
+ return finished_futures_[fetch_pos_++];
+ }
+
+ std::vector<int> DoMoveFinishedFutures() {
+ std::unique_lock<std::mutex> lock(global_waiter_mutex);
+
+ return std::move(finished_futures_);
+ }
+
+ protected:
+ std::condition_variable cv_;
+ std::atomic<bool> signalled_;
+
+ Kind kind_;
+ std::vector<FutureImpl*> futures_;
+ std::vector<int> finished_futures_;
+ int one_failed_;
+ int fetch_pos_;
+};
+
+namespace {
+
+FutureWaiterImpl* GetConcreteWaiter(FutureWaiter* waiter) {
+ return checked_cast<FutureWaiterImpl*>(waiter);
+}
+
+} // namespace
+
+FutureWaiter::FutureWaiter() = default;
+
+FutureWaiter::~FutureWaiter() = default;
+
+std::unique_ptr<FutureWaiter> FutureWaiter::Make(Kind kind,
+ std::vector<FutureImpl*> futures) {
+ return std::unique_ptr<FutureWaiter>(new FutureWaiterImpl(kind, std::move(futures)));
+}
+
+void FutureWaiter::MarkFutureFinishedUnlocked(int future_num, FutureState state) {
+ // Called by FutureImpl on state changes
+ GetConcreteWaiter(this)->DoMarkFutureFinishedUnlocked(future_num, state);
+}
+
+bool FutureWaiter::Wait(double seconds) {
+ if (seconds == kInfinity) {
+ return GetConcreteWaiter(this)->DoWait();
+ } else {
+ return GetConcreteWaiter(this)->DoWait(std::chrono::duration<double>(seconds));
+ }
+}
+
+int FutureWaiter::WaitAndFetchOne() {
+ return GetConcreteWaiter(this)->DoWaitAndFetchOne();
+}
+
+std::vector<int> FutureWaiter::MoveFinishedFutures() {
+ return GetConcreteWaiter(this)->DoMoveFinishedFutures();
+}
+
+class ConcreteFutureImpl : public FutureImpl {
+ public:
+ FutureState DoSetWaiter(FutureWaiter* w, int future_num) {
+ std::unique_lock<std::mutex> lock(mutex_);
+
+ // Atomically load state at the time of adding the waiter, to avoid
+ // missed or duplicate events in the caller
+ ARROW_CHECK_EQ(waiter_, nullptr)
+ << "Only one Waiter allowed per Future at any given time";
+ waiter_ = w;
+ waiter_arg_ = future_num;
+ return state_.load();
+ }
+
+ void DoRemoveWaiter(FutureWaiter* w) {
+ std::unique_lock<std::mutex> lock(mutex_);
+
+ ARROW_CHECK_EQ(waiter_, w);
+ waiter_ = nullptr;
+ }
+
+ void DoMarkFinished() { DoMarkFinishedOrFailed(FutureState::SUCCESS); }
+
+ void DoMarkFailed() { DoMarkFinishedOrFailed(FutureState::FAILURE); }
+
+ void CheckOptions(const CallbackOptions& opts) {
+ if (opts.should_schedule != ShouldSchedule::Never) {
+ DCHECK_NE(opts.executor, nullptr)
+ << "An executor must be specified when adding a callback that might schedule";
+ }
+ }
+
+ void AddCallback(Callback callback, CallbackOptions opts) {
+ CheckOptions(opts);
+ std::unique_lock<std::mutex> lock(mutex_);
+ CallbackRecord callback_record{std::move(callback), opts};
+ if (IsFutureFinished(state_)) {
+ lock.unlock();
+ RunOrScheduleCallback(shared_from_this(), std::move(callback_record),
+ /*in_add_callback=*/true);
+ } else {
+ callbacks_.push_back(std::move(callback_record));
+ }
+ }
+
+ bool TryAddCallback(const std::function<Callback()>& callback_factory,
+ CallbackOptions opts) {
+ CheckOptions(opts);
+ std::unique_lock<std::mutex> lock(mutex_);
+ if (IsFutureFinished(state_)) {
+ return false;
+ } else {
+ callbacks_.push_back({callback_factory(), opts});
+ return true;
+ }
+ }
+
+ static bool ShouldScheduleCallback(const CallbackRecord& callback_record,
+ bool in_add_callback) {
+ switch (callback_record.options.should_schedule) {
+ case ShouldSchedule::Never:
+ return false;
+ case ShouldSchedule::Always:
+ return true;
+ case ShouldSchedule::IfUnfinished:
+ return !in_add_callback;
+ case ShouldSchedule::IfDifferentExecutor:
+ return !callback_record.options.executor->OwnsThisThread();
+ default:
+ DCHECK(false) << "Unrecognized ShouldSchedule option";
+ return false;
+ }
+ }
+
+ static void RunOrScheduleCallback(const std::shared_ptr<FutureImpl>& self,
+ CallbackRecord&& callback_record,
+ bool in_add_callback) {
+ if (ShouldScheduleCallback(callback_record, in_add_callback)) {
+ struct CallbackTask {
+ void operator()() { std::move(callback)(*self); }
+
+ Callback callback;
+ std::shared_ptr<FutureImpl> self;
+ };
+ // Need to keep `this` alive until the callback has a chance to be scheduled.
+ CallbackTask task{std::move(callback_record.callback), self};
+ DCHECK_OK(callback_record.options.executor->Spawn(std::move(task)));
+ } else {
+ std::move(callback_record.callback)(*self);
+ }
+ }
+
+ void DoMarkFinishedOrFailed(FutureState state) {
+ {
+ // Lock the hypothetical waiter first, and the future after.
+ // This matches the locking order done in FutureWaiter constructor.
+ std::unique_lock<std::mutex> waiter_lock(global_waiter_mutex);
+ std::unique_lock<std::mutex> lock(mutex_);
+
+ DCHECK(!IsFutureFinished(state_)) << "Future already marked finished";
+ state_ = state;
+ if (waiter_ != nullptr) {
+ waiter_->MarkFutureFinishedUnlocked(waiter_arg_, state);
+ }
+ }
+ cv_.notify_all();
+
+ auto callbacks = std::move(callbacks_);
+ auto self = shared_from_this();
+
+ // run callbacks, lock not needed since the future is finished by this
+ // point so nothing else can modify the callbacks list and it is safe
+ // to iterate.
+ //
+ // In fact, it is important not to hold the locks because the callback
+ // may be slow or do its own locking on other resources
+ for (auto& callback_record : callbacks) {
+ RunOrScheduleCallback(self, std::move(callback_record), /*in_add_callback=*/false);
+ }
+ }
+
+ void DoWait() {
+ std::unique_lock<std::mutex> lock(mutex_);
+
+ cv_.wait(lock, [this] { return IsFutureFinished(state_); });
+ }
+
+ bool DoWait(double seconds) {
+ std::unique_lock<std::mutex> lock(mutex_);
+
+ cv_.wait_for(lock, std::chrono::duration<double>(seconds),
+ [this] { return IsFutureFinished(state_); });
+ return IsFutureFinished(state_);
+ }
+
+ std::mutex mutex_;
+ std::condition_variable cv_;
+ FutureWaiter* waiter_ = nullptr;
+ int waiter_arg_ = -1;
+};
+
+namespace {
+
+ConcreteFutureImpl* GetConcreteFuture(FutureImpl* future) {
+ return checked_cast<ConcreteFutureImpl*>(future);
+}
+
+} // namespace
+
+std::unique_ptr<FutureImpl> FutureImpl::Make() {
+ return std::unique_ptr<FutureImpl>(new ConcreteFutureImpl());
+}
+
+std::unique_ptr<FutureImpl> FutureImpl::MakeFinished(FutureState state) {
+ std::unique_ptr<ConcreteFutureImpl> ptr(new ConcreteFutureImpl());
+ ptr->state_ = state;
+ return std::move(ptr);
+}
+
+FutureImpl::FutureImpl() : state_(FutureState::PENDING) {}
+
+FutureState FutureImpl::SetWaiter(FutureWaiter* w, int future_num) {
+ return GetConcreteFuture(this)->DoSetWaiter(w, future_num);
+}
+
+void FutureImpl::RemoveWaiter(FutureWaiter* w) {
+ GetConcreteFuture(this)->DoRemoveWaiter(w);
+}
+
+void FutureImpl::Wait() { GetConcreteFuture(this)->DoWait(); }
+
+bool FutureImpl::Wait(double seconds) { return GetConcreteFuture(this)->DoWait(seconds); }
+
+void FutureImpl::MarkFinished() { GetConcreteFuture(this)->DoMarkFinished(); }
+
+void FutureImpl::MarkFailed() { GetConcreteFuture(this)->DoMarkFailed(); }
+
+void FutureImpl::AddCallback(Callback callback, CallbackOptions opts) {
+ GetConcreteFuture(this)->AddCallback(std::move(callback), opts);
+}
+
+bool FutureImpl::TryAddCallback(const std::function<Callback()>& callback_factory,
+ CallbackOptions opts) {
+ return GetConcreteFuture(this)->TryAddCallback(callback_factory, opts);
+}
+
+Future<> AllComplete(const std::vector<Future<>>& futures) {
+ struct State {
+ explicit State(int64_t n_futures) : mutex(), n_remaining(n_futures) {}
+
+ std::mutex mutex;
+ std::atomic<size_t> n_remaining;
+ };
+
+ if (futures.empty()) {
+ return Future<>::MakeFinished();
+ }
+
+ auto state = std::make_shared<State>(futures.size());
+ auto out = Future<>::Make();
+ for (const auto& future : futures) {
+ future.AddCallback([state, out](const Status& status) mutable {
+ if (!status.ok()) {
+ std::unique_lock<std::mutex> lock(state->mutex);
+ if (!out.is_finished()) {
+ out.MarkFinished(status);
+ }
+ return;
+ }
+ if (state->n_remaining.fetch_sub(1) != 1) return;
+ out.MarkFinished();
+ });
+ }
+ return out;
+}
+
+Future<> AllFinished(const std::vector<Future<>>& futures) {
+ return All(futures).Then([](const std::vector<Result<internal::Empty>>& results) {
+ for (const auto& res : results) {
+ if (!res.ok()) {
+ return res.status();
+ }
+ }
+ return Status::OK();
+ });
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/future.h b/src/arrow/cpp/src/arrow/util/future.h
new file mode 100644
index 000000000..695ee9ff3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/future.h
@@ -0,0 +1,978 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <atomic>
+#include <cmath>
+#include <functional>
+#include <memory>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/functional.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+template <typename>
+struct EnsureFuture;
+
+namespace detail {
+
+template <typename>
+struct is_future : std::false_type {};
+
+template <typename T>
+struct is_future<Future<T>> : std::true_type {};
+
+template <typename Signature, typename Enable = void>
+struct result_of;
+
+template <typename Fn, typename... A>
+struct result_of<Fn(A...),
+ internal::void_t<decltype(std::declval<Fn>()(std::declval<A>()...))>> {
+ using type = decltype(std::declval<Fn>()(std::declval<A>()...));
+};
+
+template <typename Signature>
+using result_of_t = typename result_of<Signature>::type;
+
+// Helper to find the synchronous counterpart for a Future
+template <typename T>
+struct SyncType {
+ using type = Result<T>;
+};
+
+template <>
+struct SyncType<internal::Empty> {
+ using type = Status;
+};
+
+template <typename Fn>
+using first_arg_is_status =
+ std::is_same<typename std::decay<internal::call_traits::argument_type<0, Fn>>::type,
+ Status>;
+
+template <typename Fn, typename Then, typename Else,
+ typename Count = internal::call_traits::argument_count<Fn>>
+using if_has_no_args = typename std::conditional<Count::value == 0, Then, Else>::type;
+
+/// Creates a callback that can be added to a future to mark a `dest` future finished
+template <typename Source, typename Dest, bool SourceEmpty = Source::is_empty,
+ bool DestEmpty = Dest::is_empty>
+struct MarkNextFinished {};
+
+/// If the source and dest are both empty we can pass on the status
+template <typename Source, typename Dest>
+struct MarkNextFinished<Source, Dest, true, true> {
+ void operator()(const Status& status) && { next.MarkFinished(status); }
+ Dest next;
+};
+
+/// If the source is not empty but the dest is then we can take the
+/// status out of the result
+template <typename Source, typename Dest>
+struct MarkNextFinished<Source, Dest, false, true> {
+ void operator()(const Result<typename Source::ValueType>& res) && {
+ next.MarkFinished(internal::Empty::ToResult(res.status()));
+ }
+ Dest next;
+};
+
+/// If neither are empty we pass on the result
+template <typename Source, typename Dest>
+struct MarkNextFinished<Source, Dest, false, false> {
+ void operator()(const Result<typename Source::ValueType>& res) && {
+ next.MarkFinished(res);
+ }
+ Dest next;
+};
+
+/// Helper that contains information about how to apply a continuation
+struct ContinueFuture {
+ template <typename Return>
+ struct ForReturnImpl;
+
+ template <typename Return>
+ using ForReturn = typename ForReturnImpl<Return>::type;
+
+ template <typename Signature>
+ using ForSignature = ForReturn<result_of_t<Signature>>;
+
+ // If the callback returns void then we return Future<> that always finishes OK.
+ template <typename ContinueFunc, typename... Args,
+ typename ContinueResult = result_of_t<ContinueFunc && (Args && ...)>,
+ typename NextFuture = ForReturn<ContinueResult>>
+ typename std::enable_if<std::is_void<ContinueResult>::value>::type operator()(
+ NextFuture next, ContinueFunc&& f, Args&&... a) const {
+ std::forward<ContinueFunc>(f)(std::forward<Args>(a)...);
+ next.MarkFinished();
+ }
+
+ /// If the callback returns a non-future then we return Future<T>
+ /// and mark the future finished with the callback result. It will get promoted
+ /// to Result<T> as part of MarkFinished if it isn't already.
+ ///
+ /// If the callback returns Status and we return Future<> then also send the callback
+ /// result as-is to the destination future.
+ template <typename ContinueFunc, typename... Args,
+ typename ContinueResult = result_of_t<ContinueFunc && (Args && ...)>,
+ typename NextFuture = ForReturn<ContinueResult>>
+ typename std::enable_if<
+ !std::is_void<ContinueResult>::value && !is_future<ContinueResult>::value &&
+ (!NextFuture::is_empty || std::is_same<ContinueResult, Status>::value)>::type
+ operator()(NextFuture next, ContinueFunc&& f, Args&&... a) const {
+ next.MarkFinished(std::forward<ContinueFunc>(f)(std::forward<Args>(a)...));
+ }
+
+ /// If the callback returns a Result and the next future is Future<> then we mark
+ /// the future finished with the callback result.
+ ///
+ /// It may seem odd that the next future is Future<> when the callback returns a
+ /// result but this can occur if the OnFailure callback returns a result while the
+ /// OnSuccess callback is void/Status (e.g. you would get this calling the one-arg
+ /// version of Then with an OnSuccess callback that returns void)
+ template <typename ContinueFunc, typename... Args,
+ typename ContinueResult = result_of_t<ContinueFunc && (Args && ...)>,
+ typename NextFuture = ForReturn<ContinueResult>>
+ typename std::enable_if<!std::is_void<ContinueResult>::value &&
+ !is_future<ContinueResult>::value && NextFuture::is_empty &&
+ !std::is_same<ContinueResult, Status>::value>::type
+ operator()(NextFuture next, ContinueFunc&& f, Args&&... a) const {
+ next.MarkFinished(std::forward<ContinueFunc>(f)(std::forward<Args>(a)...).status());
+ }
+
+ /// If the callback returns a Future<T> then we return Future<T>. We create a new
+ /// future and add a callback to the future given to us by the user that forwards the
+ /// result to the future we just created
+ template <typename ContinueFunc, typename... Args,
+ typename ContinueResult = result_of_t<ContinueFunc && (Args && ...)>,
+ typename NextFuture = ForReturn<ContinueResult>>
+ typename std::enable_if<is_future<ContinueResult>::value>::type operator()(
+ NextFuture next, ContinueFunc&& f, Args&&... a) const {
+ ContinueResult signal_to_complete_next =
+ std::forward<ContinueFunc>(f)(std::forward<Args>(a)...);
+ MarkNextFinished<ContinueResult, NextFuture> callback{std::move(next)};
+ signal_to_complete_next.AddCallback(std::move(callback));
+ }
+
+ /// Helpers to conditionally ignore arguments to ContinueFunc
+ template <typename ContinueFunc, typename NextFuture, typename... Args>
+ void IgnoringArgsIf(std::true_type, NextFuture&& next, ContinueFunc&& f,
+ Args&&...) const {
+ operator()(std::forward<NextFuture>(next), std::forward<ContinueFunc>(f));
+ }
+ template <typename ContinueFunc, typename NextFuture, typename... Args>
+ void IgnoringArgsIf(std::false_type, NextFuture&& next, ContinueFunc&& f,
+ Args&&... a) const {
+ operator()(std::forward<NextFuture>(next), std::forward<ContinueFunc>(f),
+ std::forward<Args>(a)...);
+ }
+};
+
+/// Helper struct which tells us what kind of Future gets returned from `Then` based on
+/// the return type of the OnSuccess callback
+template <>
+struct ContinueFuture::ForReturnImpl<void> {
+ using type = Future<>;
+};
+
+template <>
+struct ContinueFuture::ForReturnImpl<Status> {
+ using type = Future<>;
+};
+
+template <typename R>
+struct ContinueFuture::ForReturnImpl {
+ using type = Future<R>;
+};
+
+template <typename T>
+struct ContinueFuture::ForReturnImpl<Result<T>> {
+ using type = Future<T>;
+};
+
+template <typename T>
+struct ContinueFuture::ForReturnImpl<Future<T>> {
+ using type = Future<T>;
+};
+
+} // namespace detail
+
+/// A Future's execution or completion status
+enum class FutureState : int8_t { PENDING, SUCCESS, FAILURE };
+
+inline bool IsFutureFinished(FutureState state) { return state != FutureState::PENDING; }
+
+/// \brief Describe whether the callback should be scheduled or run synchronously
+enum class ShouldSchedule {
+ /// Always run the callback synchronously (the default)
+ Never = 0,
+ /// Schedule a new task only if the future is not finished when the
+ /// callback is added
+ IfUnfinished = 1,
+ /// Always schedule the callback as a new task
+ Always = 2,
+ /// Schedule a new task only if it would run on an executor other than
+ /// the specified executor.
+ IfDifferentExecutor = 3,
+};
+
+/// \brief Options that control how a continuation is run
+struct CallbackOptions {
+ /// Describe whether the callback should be run synchronously or scheduled
+ ShouldSchedule should_schedule = ShouldSchedule::Never;
+ /// If the callback is scheduled then this is the executor it should be scheduled
+ /// on. If this is NULL then should_schedule must be Never
+ internal::Executor* executor = NULLPTR;
+
+ static CallbackOptions Defaults() { return {}; }
+};
+
+// Untyped private implementation
+class ARROW_EXPORT FutureImpl : public std::enable_shared_from_this<FutureImpl> {
+ public:
+ FutureImpl();
+ virtual ~FutureImpl() = default;
+
+ FutureState state() { return state_.load(); }
+
+ static std::unique_ptr<FutureImpl> Make();
+ static std::unique_ptr<FutureImpl> MakeFinished(FutureState state);
+
+ // Future API
+ void MarkFinished();
+ void MarkFailed();
+ void Wait();
+ bool Wait(double seconds);
+ template <typename ValueType>
+ Result<ValueType>* CastResult() const {
+ return static_cast<Result<ValueType>*>(result_.get());
+ }
+
+ using Callback = internal::FnOnce<void(const FutureImpl& impl)>;
+ void AddCallback(Callback callback, CallbackOptions opts);
+ bool TryAddCallback(const std::function<Callback()>& callback_factory,
+ CallbackOptions opts);
+
+ // Waiter API
+ inline FutureState SetWaiter(FutureWaiter* w, int future_num);
+ inline void RemoveWaiter(FutureWaiter* w);
+
+ std::atomic<FutureState> state_{FutureState::PENDING};
+
+ // Type erased storage for arbitrary results
+ // XXX small objects could be stored inline instead of boxed in a pointer
+ using Storage = std::unique_ptr<void, void (*)(void*)>;
+ Storage result_{NULLPTR, NULLPTR};
+
+ struct CallbackRecord {
+ Callback callback;
+ CallbackOptions options;
+ };
+ std::vector<CallbackRecord> callbacks_;
+};
+
+// An object that waits on multiple futures at once. Only one waiter
+// can be registered for each future at any time.
+class ARROW_EXPORT FutureWaiter {
+ public:
+ enum Kind : int8_t { ANY, ALL, ALL_OR_FIRST_FAILED, ITERATE };
+
+ // HUGE_VAL isn't constexpr on Windows
+ // https://social.msdn.microsoft.com/Forums/vstudio/en-US/47e8b9ff-b205-4189-968e-ee3bc3e2719f/constexpr-compile-error?forum=vclanguage
+ static const double kInfinity;
+
+ static std::unique_ptr<FutureWaiter> Make(Kind kind, std::vector<FutureImpl*> futures);
+
+ template <typename FutureType>
+ static std::unique_ptr<FutureWaiter> Make(Kind kind,
+ const std::vector<FutureType>& futures) {
+ return Make(kind, ExtractFutures(futures));
+ }
+
+ virtual ~FutureWaiter();
+
+ bool Wait(double seconds = kInfinity);
+ int WaitAndFetchOne();
+
+ std::vector<int> MoveFinishedFutures();
+
+ protected:
+ // Extract FutureImpls from Futures
+ template <typename FutureType,
+ typename Enable = std::enable_if<!std::is_pointer<FutureType>::value>>
+ static std::vector<FutureImpl*> ExtractFutures(const std::vector<FutureType>& futures) {
+ std::vector<FutureImpl*> base_futures(futures.size());
+ for (int i = 0; i < static_cast<int>(futures.size()); ++i) {
+ base_futures[i] = futures[i].impl_.get();
+ }
+ return base_futures;
+ }
+
+ // Extract FutureImpls from Future pointers
+ template <typename FutureType>
+ static std::vector<FutureImpl*> ExtractFutures(
+ const std::vector<FutureType*>& futures) {
+ std::vector<FutureImpl*> base_futures(futures.size());
+ for (int i = 0; i < static_cast<int>(futures.size()); ++i) {
+ base_futures[i] = futures[i]->impl_.get();
+ }
+ return base_futures;
+ }
+
+ FutureWaiter();
+ ARROW_DISALLOW_COPY_AND_ASSIGN(FutureWaiter);
+
+ inline void MarkFutureFinishedUnlocked(int future_num, FutureState state);
+
+ friend class FutureImpl;
+ friend class ConcreteFutureImpl;
+};
+
+// ---------------------------------------------------------------------
+// Public API
+
+/// \brief EXPERIMENTAL A std::future-like class with more functionality.
+///
+/// A Future represents the results of a past or future computation.
+/// The Future API has two sides: a producer side and a consumer side.
+///
+/// The producer API allows creating a Future and setting its result or
+/// status, possibly after running a computation function.
+///
+/// The consumer API allows querying a Future's current state, wait for it
+/// to complete, or wait on multiple Futures at once (using WaitForAll,
+/// WaitForAny or AsCompletedIterator).
+template <typename T>
+class ARROW_MUST_USE_TYPE Future {
+ public:
+ using ValueType = T;
+ using SyncType = typename detail::SyncType<T>::type;
+ static constexpr bool is_empty = std::is_same<T, internal::Empty>::value;
+ // The default constructor creates an invalid Future. Use Future::Make()
+ // for a valid Future. This constructor is mostly for the convenience
+ // of being able to presize a vector of Futures.
+ Future() = default;
+
+ // Consumer API
+
+ bool is_valid() const { return impl_ != NULLPTR; }
+
+ /// \brief Return the Future's current state
+ ///
+ /// A return value of PENDING is only indicative, as the Future can complete
+ /// concurrently. A return value of FAILURE or SUCCESS is definitive, though.
+ FutureState state() const {
+ CheckValid();
+ return impl_->state();
+ }
+
+ /// \brief Whether the Future is finished
+ ///
+ /// A false return value is only indicative, as the Future can complete
+ /// concurrently. A true return value is definitive, though.
+ bool is_finished() const {
+ CheckValid();
+ return IsFutureFinished(impl_->state());
+ }
+
+ /// \brief Wait for the Future to complete and return its Result
+ const Result<ValueType>& result() const& {
+ Wait();
+ return *GetResult();
+ }
+
+ /// \brief Returns an rvalue to the result. This method is potentially unsafe
+ ///
+ /// The future is not the unique owner of the result, copies of a future will
+ /// also point to the same result. You must make sure that no other copies
+ /// of the future exist. Attempts to add callbacks after you move the result
+ /// will result in undefined behavior.
+ Result<ValueType>&& MoveResult() {
+ Wait();
+ return std::move(*GetResult());
+ }
+
+ /// \brief Wait for the Future to complete and return its Status
+ const Status& status() const { return result().status(); }
+
+ /// \brief Future<T> is convertible to Future<>, which views only the
+ /// Status of the original. Marking the returned Future Finished is not supported.
+ explicit operator Future<>() const {
+ Future<> status_future;
+ status_future.impl_ = impl_;
+ return status_future;
+ }
+
+ /// \brief Wait for the Future to complete
+ void Wait() const {
+ CheckValid();
+ if (!IsFutureFinished(impl_->state())) {
+ impl_->Wait();
+ }
+ }
+
+ /// \brief Wait for the Future to complete, or for the timeout to expire
+ ///
+ /// `true` is returned if the Future completed, `false` if the timeout expired.
+ /// Note a `false` value is only indicative, as the Future can complete
+ /// concurrently.
+ bool Wait(double seconds) const {
+ CheckValid();
+ if (IsFutureFinished(impl_->state())) {
+ return true;
+ }
+ return impl_->Wait(seconds);
+ }
+
+ // Producer API
+
+ /// \brief Producer API: mark Future finished
+ ///
+ /// The Future's result is set to `res`.
+ void MarkFinished(Result<ValueType> res) { DoMarkFinished(std::move(res)); }
+
+ /// \brief Mark a Future<> completed with the provided Status.
+ template <typename E = ValueType, typename = typename std::enable_if<
+ std::is_same<E, internal::Empty>::value>::type>
+ void MarkFinished(Status s = Status::OK()) {
+ return DoMarkFinished(E::ToResult(std::move(s)));
+ }
+
+ /// \brief Producer API: instantiate a valid Future
+ ///
+ /// The Future's state is initialized with PENDING. If you are creating a future with
+ /// this method you must ensure that future is eventually completed (with success or
+ /// failure). Creating a future, returning it, and never completing the future can lead
+ /// to memory leaks (for example, see Loop).
+ static Future Make() {
+ Future fut;
+ fut.impl_ = FutureImpl::Make();
+ return fut;
+ }
+
+ /// \brief Producer API: instantiate a finished Future
+ static Future<ValueType> MakeFinished(Result<ValueType> res) {
+ Future<ValueType> fut;
+ fut.InitializeFromResult(std::move(res));
+ return fut;
+ }
+
+ /// \brief Make a finished Future<> with the provided Status.
+ template <typename E = ValueType, typename = typename std::enable_if<
+ std::is_same<E, internal::Empty>::value>::type>
+ static Future<> MakeFinished(Status s = Status::OK()) {
+ return MakeFinished(E::ToResult(std::move(s)));
+ }
+
+ struct WrapResultyOnComplete {
+ template <typename OnComplete>
+ struct Callback {
+ void operator()(const FutureImpl& impl) && {
+ std::move(on_complete)(*impl.CastResult<ValueType>());
+ }
+ OnComplete on_complete;
+ };
+ };
+
+ struct WrapStatusyOnComplete {
+ template <typename OnComplete>
+ struct Callback {
+ static_assert(std::is_same<internal::Empty, ValueType>::value,
+ "Only callbacks for Future<> should accept Status and not Result");
+
+ void operator()(const FutureImpl& impl) && {
+ std::move(on_complete)(impl.CastResult<ValueType>()->status());
+ }
+ OnComplete on_complete;
+ };
+ };
+
+ template <typename OnComplete>
+ using WrapOnComplete = typename std::conditional<
+ detail::first_arg_is_status<OnComplete>::value, WrapStatusyOnComplete,
+ WrapResultyOnComplete>::type::template Callback<OnComplete>;
+
+ /// \brief Consumer API: Register a callback to run when this future completes
+ ///
+ /// The callback should receive the result of the future (const Result<T>&)
+ /// For a void or statusy future this should be (const Status&)
+ ///
+ /// There is no guarantee to the order in which callbacks will run. In
+ /// particular, callbacks added while the future is being marked complete
+ /// may be executed immediately, ahead of, or even the same time as, other
+ /// callbacks that have been previously added.
+ ///
+ /// WARNING: callbacks may hold arbitrary references, including cyclic references.
+ /// Since callbacks will only be destroyed after they are invoked, this can lead to
+ /// memory leaks if a Future is never marked finished (abandoned):
+ ///
+ /// {
+ /// auto fut = Future<>::Make();
+ /// fut.AddCallback([fut]() {});
+ /// }
+ ///
+ /// In this example `fut` falls out of scope but is not destroyed because it holds a
+ /// cyclic reference to itself through the callback.
+ template <typename OnComplete, typename Callback = WrapOnComplete<OnComplete>>
+ void AddCallback(OnComplete on_complete,
+ CallbackOptions opts = CallbackOptions::Defaults()) const {
+ // We know impl_ will not be dangling when invoking callbacks because at least one
+ // thread will be waiting for MarkFinished to return. Thus it's safe to keep a
+ // weak reference to impl_ here
+ impl_->AddCallback(Callback{std::move(on_complete)}, opts);
+ }
+
+ /// \brief Overload of AddCallback that will return false instead of running
+ /// synchronously
+ ///
+ /// This overload will guarantee the callback is never run synchronously. If the future
+ /// is already finished then it will simply return false. This can be useful to avoid
+ /// stack overflow in a situation where you have recursive Futures. For an example
+ /// see the Loop function
+ ///
+ /// Takes in a callback factory function to allow moving callbacks (the factory function
+ /// will only be called if the callback can successfully be added)
+ ///
+ /// Returns true if a callback was actually added and false if the callback failed
+ /// to add because the future was marked complete.
+ template <typename CallbackFactory,
+ typename OnComplete = detail::result_of_t<CallbackFactory()>,
+ typename Callback = WrapOnComplete<OnComplete>>
+ bool TryAddCallback(const CallbackFactory& callback_factory,
+ CallbackOptions opts = CallbackOptions::Defaults()) const {
+ return impl_->TryAddCallback([&]() { return Callback{callback_factory()}; }, opts);
+ }
+
+ template <typename OnSuccess, typename OnFailure>
+ struct ThenOnComplete {
+ static constexpr bool has_no_args =
+ internal::call_traits::argument_count<OnSuccess>::value == 0;
+
+ using ContinuedFuture = detail::ContinueFuture::ForSignature<
+ detail::if_has_no_args<OnSuccess, OnSuccess && (), OnSuccess && (const T&)>>;
+
+ static_assert(
+ std::is_same<detail::ContinueFuture::ForSignature<OnFailure && (const Status&)>,
+ ContinuedFuture>::value,
+ "OnSuccess and OnFailure must continue with the same future type");
+
+ struct DummyOnSuccess {
+ void operator()(const T&);
+ };
+ using OnSuccessArg = typename std::decay<internal::call_traits::argument_type<
+ 0, detail::if_has_no_args<OnSuccess, DummyOnSuccess, OnSuccess>>>::type;
+
+ static_assert(
+ !std::is_same<OnSuccessArg, typename EnsureResult<OnSuccessArg>::type>::value,
+ "OnSuccess' argument should not be a Result");
+
+ void operator()(const Result<T>& result) && {
+ detail::ContinueFuture continue_future;
+ if (ARROW_PREDICT_TRUE(result.ok())) {
+ // move on_failure to a(n immediately destroyed) temporary to free its resources
+ ARROW_UNUSED(OnFailure(std::move(on_failure)));
+ continue_future.IgnoringArgsIf(
+ detail::if_has_no_args<OnSuccess, std::true_type, std::false_type>{},
+ std::move(next), std::move(on_success), result.ValueOrDie());
+ } else {
+ ARROW_UNUSED(OnSuccess(std::move(on_success)));
+ continue_future(std::move(next), std::move(on_failure), result.status());
+ }
+ }
+
+ OnSuccess on_success;
+ OnFailure on_failure;
+ ContinuedFuture next;
+ };
+
+ template <typename OnSuccess>
+ struct PassthruOnFailure {
+ using ContinuedFuture = detail::ContinueFuture::ForSignature<
+ detail::if_has_no_args<OnSuccess, OnSuccess && (), OnSuccess && (const T&)>>;
+
+ Result<typename ContinuedFuture::ValueType> operator()(const Status& s) { return s; }
+ };
+
+ /// \brief Consumer API: Register a continuation to run when this future completes
+ ///
+ /// The continuation will run in the same thread that called MarkFinished (whatever
+ /// callback is registered with this function will run before MarkFinished returns).
+ /// Avoid long-running callbacks in favor of submitting a task to an Executor and
+ /// returning the future.
+ ///
+ /// Two callbacks are supported:
+ /// - OnSuccess, called with the result (const ValueType&) on successul completion.
+ /// for an empty future this will be called with nothing ()
+ /// - OnFailure, called with the error (const Status&) on failed completion.
+ /// This callback is optional and defaults to a passthru of any errors.
+ ///
+ /// Then() returns a Future whose ValueType is derived from the return type of the
+ /// callbacks. If a callback returns:
+ /// - void, a Future<> will be returned which will completes successully as soon
+ /// as the callback runs.
+ /// - Status, a Future<> will be returned which will complete with the returned Status
+ /// as soon as the callback runs.
+ /// - V or Result<V>, a Future<V> will be returned which will complete with the result
+ /// of invoking the callback as soon as the callback runs.
+ /// - Future<V>, a Future<V> will be returned which will be marked complete when the
+ /// future returned by the callback completes (and will complete with the same
+ /// result).
+ ///
+ /// The continued Future type must be the same for both callbacks.
+ ///
+ /// Note that OnFailure can swallow errors, allowing continued Futures to successully
+ /// complete even if this Future fails.
+ ///
+ /// If this future is already completed then the callback will be run immediately
+ /// and the returned future may already be marked complete.
+ ///
+ /// See AddCallback for general considerations when writing callbacks.
+ template <typename OnSuccess, typename OnFailure = PassthruOnFailure<OnSuccess>,
+ typename OnComplete = ThenOnComplete<OnSuccess, OnFailure>,
+ typename ContinuedFuture = typename OnComplete::ContinuedFuture>
+ ContinuedFuture Then(OnSuccess on_success, OnFailure on_failure = {},
+ CallbackOptions options = CallbackOptions::Defaults()) const {
+ auto next = ContinuedFuture::Make();
+ AddCallback(OnComplete{std::forward<OnSuccess>(on_success),
+ std::forward<OnFailure>(on_failure), next},
+ options);
+ return next;
+ }
+
+ /// \brief Implicit constructor to create a finished future from a value
+ Future(ValueType val) : Future() { // NOLINT runtime/explicit
+ impl_ = FutureImpl::MakeFinished(FutureState::SUCCESS);
+ SetResult(std::move(val));
+ }
+
+ /// \brief Implicit constructor to create a future from a Result, enabling use
+ /// of macros like ARROW_ASSIGN_OR_RAISE.
+ Future(Result<ValueType> res) : Future() { // NOLINT runtime/explicit
+ if (ARROW_PREDICT_TRUE(res.ok())) {
+ impl_ = FutureImpl::MakeFinished(FutureState::SUCCESS);
+ } else {
+ impl_ = FutureImpl::MakeFinished(FutureState::FAILURE);
+ }
+ SetResult(std::move(res));
+ }
+
+ /// \brief Implicit constructor to create a future from a Status, enabling use
+ /// of macros like ARROW_RETURN_NOT_OK.
+ Future(Status s) // NOLINT runtime/explicit
+ : Future(Result<ValueType>(std::move(s))) {}
+
+ protected:
+ void InitializeFromResult(Result<ValueType> res) {
+ if (ARROW_PREDICT_TRUE(res.ok())) {
+ impl_ = FutureImpl::MakeFinished(FutureState::SUCCESS);
+ } else {
+ impl_ = FutureImpl::MakeFinished(FutureState::FAILURE);
+ }
+ SetResult(std::move(res));
+ }
+
+ void Initialize() { impl_ = FutureImpl::Make(); }
+
+ Result<ValueType>* GetResult() const { return impl_->CastResult<ValueType>(); }
+
+ void SetResult(Result<ValueType> res) {
+ impl_->result_ = {new Result<ValueType>(std::move(res)),
+ [](void* p) { delete static_cast<Result<ValueType>*>(p); }};
+ }
+
+ void DoMarkFinished(Result<ValueType> res) {
+ SetResult(std::move(res));
+
+ if (ARROW_PREDICT_TRUE(GetResult()->ok())) {
+ impl_->MarkFinished();
+ } else {
+ impl_->MarkFailed();
+ }
+ }
+
+ void CheckValid() const {
+#ifndef NDEBUG
+ if (!is_valid()) {
+ Status::Invalid("Invalid Future (default-initialized?)").Abort();
+ }
+#endif
+ }
+
+ explicit Future(std::shared_ptr<FutureImpl> impl) : impl_(std::move(impl)) {}
+
+ std::shared_ptr<FutureImpl> impl_;
+
+ friend class FutureWaiter;
+ friend struct detail::ContinueFuture;
+
+ template <typename U>
+ friend class Future;
+ friend class WeakFuture<T>;
+
+ FRIEND_TEST(FutureRefTest, ChainRemoved);
+ FRIEND_TEST(FutureRefTest, TailRemoved);
+ FRIEND_TEST(FutureRefTest, HeadRemoved);
+};
+
+template <typename T>
+typename Future<T>::SyncType FutureToSync(const Future<T>& fut) {
+ return fut.result();
+}
+
+template <>
+inline typename Future<internal::Empty>::SyncType FutureToSync<internal::Empty>(
+ const Future<internal::Empty>& fut) {
+ return fut.status();
+}
+
+template <typename T>
+class WeakFuture {
+ public:
+ explicit WeakFuture(const Future<T>& future) : impl_(future.impl_) {}
+
+ Future<T> get() { return Future<T>{impl_.lock()}; }
+
+ private:
+ std::weak_ptr<FutureImpl> impl_;
+};
+
+/// If a Result<Future> holds an error instead of a Future, construct a finished Future
+/// holding that error.
+template <typename T>
+static Future<T> DeferNotOk(Result<Future<T>> maybe_future) {
+ if (ARROW_PREDICT_FALSE(!maybe_future.ok())) {
+ return Future<T>::MakeFinished(std::move(maybe_future).status());
+ }
+ return std::move(maybe_future).MoveValueUnsafe();
+}
+
+/// \brief Wait for all the futures to end, or for the given timeout to expire.
+///
+/// `true` is returned if all the futures completed before the timeout was reached,
+/// `false` otherwise.
+template <typename T>
+inline bool WaitForAll(const std::vector<Future<T>>& futures,
+ double seconds = FutureWaiter::kInfinity) {
+ auto waiter = FutureWaiter::Make(FutureWaiter::ALL, futures);
+ return waiter->Wait(seconds);
+}
+
+/// \brief Wait for all the futures to end, or for the given timeout to expire.
+///
+/// `true` is returned if all the futures completed before the timeout was reached,
+/// `false` otherwise.
+template <typename T>
+inline bool WaitForAll(const std::vector<Future<T>*>& futures,
+ double seconds = FutureWaiter::kInfinity) {
+ auto waiter = FutureWaiter::Make(FutureWaiter::ALL, futures);
+ return waiter->Wait(seconds);
+}
+
+/// \brief Create a Future which completes when all of `futures` complete.
+///
+/// The future's result is a vector of the results of `futures`.
+/// Note that this future will never be marked "failed"; failed results
+/// will be stored in the result vector alongside successful results.
+template <typename T>
+Future<std::vector<Result<T>>> All(std::vector<Future<T>> futures) {
+ struct State {
+ explicit State(std::vector<Future<T>> f)
+ : futures(std::move(f)), n_remaining(futures.size()) {}
+
+ std::vector<Future<T>> futures;
+ std::atomic<size_t> n_remaining;
+ };
+
+ if (futures.size() == 0) {
+ return {std::vector<Result<T>>{}};
+ }
+
+ auto state = std::make_shared<State>(std::move(futures));
+
+ auto out = Future<std::vector<Result<T>>>::Make();
+ for (const Future<T>& future : state->futures) {
+ future.AddCallback([state, out](const Result<T>&) mutable {
+ if (state->n_remaining.fetch_sub(1) != 1) return;
+
+ std::vector<Result<T>> results(state->futures.size());
+ for (size_t i = 0; i < results.size(); ++i) {
+ results[i] = state->futures[i].result();
+ }
+ out.MarkFinished(std::move(results));
+ });
+ }
+ return out;
+}
+
+template <>
+inline Future<>::Future(Status s) : Future(internal::Empty::ToResult(std::move(s))) {}
+
+/// \brief Create a Future which completes when all of `futures` complete.
+///
+/// The future will be marked complete if all `futures` complete
+/// successfully. Otherwise, it will be marked failed with the status of
+/// the first failing future.
+ARROW_EXPORT
+Future<> AllComplete(const std::vector<Future<>>& futures);
+
+/// \brief Create a Future which completes when all of `futures` complete.
+///
+/// The future will finish with an ok status if all `futures` finish with
+/// an ok status. Otherwise, it will be marked failed with the status of
+/// one of the failing futures.
+///
+/// Unlike AllComplete this Future will not complete immediately when a
+/// failure occurs. It will wait until all futures have finished.
+ARROW_EXPORT
+Future<> AllFinished(const std::vector<Future<>>& futures);
+
+/// \brief Wait for one of the futures to end, or for the given timeout to expire.
+///
+/// The indices of all completed futures are returned. Note that some futures
+/// may not be in the returned set, but still complete concurrently.
+template <typename T>
+inline std::vector<int> WaitForAny(const std::vector<Future<T>>& futures,
+ double seconds = FutureWaiter::kInfinity) {
+ auto waiter = FutureWaiter::Make(FutureWaiter::ANY, futures);
+ waiter->Wait(seconds);
+ return waiter->MoveFinishedFutures();
+}
+
+/// \brief Wait for one of the futures to end, or for the given timeout to expire.
+///
+/// The indices of all completed futures are returned. Note that some futures
+/// may not be in the returned set, but still complete concurrently.
+template <typename T>
+inline std::vector<int> WaitForAny(const std::vector<Future<T>*>& futures,
+ double seconds = FutureWaiter::kInfinity) {
+ auto waiter = FutureWaiter::Make(FutureWaiter::ANY, futures);
+ waiter->Wait(seconds);
+ return waiter->MoveFinishedFutures();
+}
+
+struct Continue {
+ template <typename T>
+ operator util::optional<T>() && { // NOLINT explicit
+ return {};
+ }
+};
+
+template <typename T = internal::Empty>
+util::optional<T> Break(T break_value = {}) {
+ return util::optional<T>{std::move(break_value)};
+}
+
+template <typename T = internal::Empty>
+using ControlFlow = util::optional<T>;
+
+/// \brief Loop through an asynchronous sequence
+///
+/// \param[in] iterate A generator of Future<ControlFlow<BreakValue>>. On completion
+/// of each yielded future the resulting ControlFlow will be examined. A Break will
+/// terminate the loop, while a Continue will re-invoke `iterate`.
+///
+/// \return A future which will complete when a Future returned by iterate completes with
+/// a Break
+template <typename Iterate,
+ typename Control = typename detail::result_of_t<Iterate()>::ValueType,
+ typename BreakValueType = typename Control::value_type>
+Future<BreakValueType> Loop(Iterate iterate) {
+ struct Callback {
+ bool CheckForTermination(const Result<Control>& control_res) {
+ if (!control_res.ok()) {
+ break_fut.MarkFinished(control_res.status());
+ return true;
+ }
+ if (control_res->has_value()) {
+ break_fut.MarkFinished(**control_res);
+ return true;
+ }
+ return false;
+ }
+
+ void operator()(const Result<Control>& maybe_control) && {
+ if (CheckForTermination(maybe_control)) return;
+
+ auto control_fut = iterate();
+ while (true) {
+ if (control_fut.TryAddCallback([this]() { return *this; })) {
+ // Adding a callback succeeded; control_fut was not finished
+ // and we must wait to CheckForTermination.
+ return;
+ }
+ // Adding a callback failed; control_fut was finished and we
+ // can CheckForTermination immediately. This also avoids recursion and potential
+ // stack overflow.
+ if (CheckForTermination(control_fut.result())) return;
+
+ control_fut = iterate();
+ }
+ }
+
+ Iterate iterate;
+
+ // If the future returned by control_fut is never completed then we will be hanging on
+ // to break_fut forever even if the listener has given up listening on it. Instead we
+ // rely on the fact that a producer (the caller of Future<>::Make) is always
+ // responsible for completing the futures they create.
+ // TODO: Could avoid this kind of situation with "future abandonment" similar to mesos
+ Future<BreakValueType> break_fut;
+ };
+
+ auto break_fut = Future<BreakValueType>::Make();
+ auto control_fut = iterate();
+ control_fut.AddCallback(Callback{std::move(iterate), break_fut});
+
+ return break_fut;
+}
+
+inline Future<> ToFuture(Status status) {
+ return Future<>::MakeFinished(std::move(status));
+}
+
+template <typename T>
+Future<T> ToFuture(T value) {
+ return Future<T>::MakeFinished(std::move(value));
+}
+
+template <typename T>
+Future<T> ToFuture(Result<T> maybe_value) {
+ return Future<T>::MakeFinished(std::move(maybe_value));
+}
+
+template <typename T>
+Future<T> ToFuture(Future<T> fut) {
+ return std::move(fut);
+}
+
+template <typename T>
+struct EnsureFuture {
+ using type = decltype(ToFuture(std::declval<T>()));
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/future_iterator.h b/src/arrow/cpp/src/arrow/util/future_iterator.h
new file mode 100644
index 000000000..9837ae853
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/future_iterator.h
@@ -0,0 +1,75 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cassert>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/util/future.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+/// An iterator that takes a set of futures, and yields their results as
+/// they are completed, in any order.
+template <typename T>
+class AsCompletedIterator {
+ public:
+ // Public default constructor creates an empty iterator
+ AsCompletedIterator();
+
+ explicit AsCompletedIterator(std::vector<Future<T>> futures)
+ : futures_(std::move(futures)),
+ waiter_(FutureWaiter::Make(FutureWaiter::ITERATE, futures_)) {}
+
+ ARROW_DEFAULT_MOVE_AND_ASSIGN(AsCompletedIterator);
+ ARROW_DISALLOW_COPY_AND_ASSIGN(AsCompletedIterator);
+
+ /// Return the results of the first completed, not-yet-returned Future.
+ ///
+ /// The result can be successful or not, depending on the Future's underlying
+ /// task's result. Even if a Future returns a failed Result, you can still
+ /// call Next() to get further results.
+ Result<T> Next() {
+ if (n_fetched_ == futures_.size()) {
+ return IterationTraits<T>::End();
+ }
+ auto index = waiter_->WaitAndFetchOne();
+ ++n_fetched_;
+ assert(index >= 0 && static_cast<size_t>(index) < futures_.size());
+ auto& fut = futures_[index];
+ assert(IsFutureFinished(fut.state()));
+ return std::move(fut).result();
+ }
+
+ private:
+ size_t n_fetched_ = 0;
+ std::vector<Future<T>> futures_;
+ std::unique_ptr<FutureWaiter> waiter_;
+};
+
+template <typename T>
+Iterator<T> MakeAsCompletedIterator(std::vector<Future<T>> futures) {
+ return Iterator<T>(AsCompletedIterator<T>(std::move(futures)));
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/future_test.cc b/src/arrow/cpp/src/arrow/util/future_test.cc
new file mode 100644
index 000000000..0db355433
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/future_test.cc
@@ -0,0 +1,1803 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/future.h"
+#include "arrow/util/future_iterator.h"
+
+#include <algorithm>
+#include <chrono>
+#include <condition_variable>
+#include <memory>
+#include <mutex>
+#include <ostream>
+#include <random>
+#include <string>
+#include <thread>
+#include <unordered_set>
+#include <vector>
+
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
+
+#include "arrow/testing/executor_util.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::ThreadPool;
+
+int ToInt(int x) { return x; }
+
+// A data type without a default constructor.
+struct Foo {
+ int bar;
+ std::string baz;
+
+ explicit Foo(int value) : bar(value), baz(std::to_string(value)) {}
+
+ int ToInt() const { return bar; }
+
+ bool operator==(int other) const { return bar == other; }
+ bool operator==(const Foo& other) const { return bar == other.bar; }
+ friend bool operator==(int left, const Foo& right) { return right == left; }
+
+ friend std::ostream& operator<<(std::ostream& os, const Foo& foo) {
+ return os << "Foo(" << foo.bar << ")";
+ }
+};
+
+template <>
+struct IterationTraits<Foo> {
+ static Foo End() { return Foo(-1); }
+};
+
+template <>
+struct IterationTraits<MoveOnlyDataType> {
+ static MoveOnlyDataType End() { return MoveOnlyDataType(-1); }
+};
+
+template <typename T>
+struct IteratorResults {
+ std::vector<T> values;
+ std::vector<Status> errors;
+};
+
+template <typename T>
+IteratorResults<T> IteratorToResults(Iterator<T> iterator) {
+ IteratorResults<T> results;
+
+ while (true) {
+ auto res = iterator.Next();
+ if (res == IterationTraits<T>::End()) {
+ break;
+ }
+ if (res.ok()) {
+ results.values.push_back(*std::move(res));
+ } else {
+ results.errors.push_back(res.status());
+ }
+ }
+ return results;
+}
+
+// So that main thread may wait a bit for a future to be finished
+constexpr auto kYieldDuration = std::chrono::microseconds(50);
+constexpr double kTinyWait = 1e-5; // seconds
+constexpr double kLargeWait = 5.0; // seconds
+
+template <typename T>
+class SimpleExecutor {
+ public:
+ explicit SimpleExecutor(int nfutures)
+ : pool_(ThreadPool::Make(/*threads=*/4).ValueOrDie()) {
+ for (int i = 0; i < nfutures; ++i) {
+ futures_.push_back(Future<T>::Make());
+ }
+ }
+
+ std::vector<Future<T>>& futures() { return futures_; }
+
+ void SetFinished(const std::vector<std::pair<int, bool>>& pairs) {
+ for (const auto& pair : pairs) {
+ const int fut_index = pair.first;
+ if (pair.second) {
+ futures_[fut_index].MarkFinished(T(fut_index));
+ } else {
+ futures_[fut_index].MarkFinished(Status::UnknownError("xxx"));
+ }
+ }
+ }
+
+ void SetFinishedDeferred(std::vector<std::pair<int, bool>> pairs) {
+ std::this_thread::sleep_for(kYieldDuration);
+ ABORT_NOT_OK(pool_->Spawn([=]() { SetFinished(pairs); }));
+ }
+
+ // Mark future successful
+ void SetFinished(int fut_index) { futures_[fut_index].MarkFinished(T(fut_index)); }
+
+ void SetFinishedDeferred(int fut_index) {
+ std::this_thread::sleep_for(kYieldDuration);
+ ABORT_NOT_OK(pool_->Spawn([=]() { SetFinished(fut_index); }));
+ }
+
+ // Mark all futures in [start, stop) successful
+ void SetFinished(int start, int stop) {
+ for (int fut_index = start; fut_index < stop; ++fut_index) {
+ futures_[fut_index].MarkFinished(T(fut_index));
+ }
+ }
+
+ void SetFinishedDeferred(int start, int stop) {
+ std::this_thread::sleep_for(kYieldDuration);
+ ABORT_NOT_OK(pool_->Spawn([=]() { SetFinished(start, stop); }));
+ }
+
+ protected:
+ std::vector<Future<T>> futures_;
+ std::shared_ptr<ThreadPool> pool_;
+};
+
+// --------------------------------------------------------------------
+// Simple in-thread tests
+
+TEST(FutureSyncTest, Int) {
+ {
+ // MarkFinished(int)
+ auto fut = Future<int>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished(42);
+ AssertSuccessful(fut);
+ auto res = fut.result();
+ ASSERT_OK(res);
+ ASSERT_EQ(*res, 42);
+ res = std::move(fut).result();
+ ASSERT_OK(res);
+ ASSERT_EQ(*res, 42);
+ }
+ {
+ // MakeFinished(int)
+ auto fut = Future<int>::MakeFinished(42);
+ AssertSuccessful(fut);
+ auto res = fut.result();
+ ASSERT_OK(res);
+ ASSERT_EQ(*res, 42);
+ res = std::move(fut.result());
+ ASSERT_OK(res);
+ ASSERT_EQ(*res, 42);
+ }
+ {
+ // MarkFinished(Result<int>)
+ auto fut = Future<int>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished(Result<int>(43));
+ AssertSuccessful(fut);
+ ASSERT_OK_AND_ASSIGN(auto value, fut.result());
+ ASSERT_EQ(value, 43);
+ }
+ {
+ // MarkFinished(failed Result<int>)
+ auto fut = Future<int>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished(Result<int>(Status::IOError("xxx")));
+ AssertFailed(fut);
+ ASSERT_RAISES(IOError, fut.result());
+ }
+ {
+ // MakeFinished(Status)
+ auto fut = Future<int>::MakeFinished(Status::IOError("xxx"));
+ AssertFailed(fut);
+ ASSERT_RAISES(IOError, fut.result());
+ }
+ {
+ // MarkFinished(Status)
+ auto fut = Future<int>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut);
+ ASSERT_RAISES(IOError, fut.result());
+ }
+}
+
+TEST(FutureSyncTest, Foo) {
+ {
+ auto fut = Future<Foo>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished(Foo(42));
+ AssertSuccessful(fut);
+ auto res = fut.result();
+ ASSERT_OK(res);
+ Foo value = *res;
+ ASSERT_EQ(value, 42);
+ ASSERT_OK(fut.status());
+ res = std::move(fut).result();
+ ASSERT_OK(res);
+ value = *res;
+ ASSERT_EQ(value, 42);
+ }
+ {
+ // MarkFinished(Result<Foo>)
+ auto fut = Future<Foo>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished(Result<Foo>(Foo(42)));
+ AssertSuccessful(fut);
+ ASSERT_OK_AND_ASSIGN(Foo value, fut.result());
+ ASSERT_EQ(value, 42);
+ }
+ {
+ // MarkFinished(failed Result<Foo>)
+ auto fut = Future<Foo>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished(Result<Foo>(Status::IOError("xxx")));
+ AssertFailed(fut);
+ ASSERT_RAISES(IOError, fut.result());
+ }
+}
+
+TEST(FutureSyncTest, Empty) {
+ {
+ // MarkFinished()
+ auto fut = Future<>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished();
+ AssertSuccessful(fut);
+ }
+ {
+ // MakeFinished()
+ auto fut = Future<>::MakeFinished();
+ AssertSuccessful(fut);
+ }
+ {
+ // MarkFinished(Status)
+ auto fut = Future<>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished();
+ AssertSuccessful(fut);
+ }
+ {
+ // MakeFinished(Status)
+ auto fut = Future<>::MakeFinished();
+ AssertSuccessful(fut);
+ fut = Future<>::MakeFinished(Status::IOError("xxx"));
+ AssertFailed(fut);
+ }
+ {
+ // MarkFinished(Status)
+ auto fut = Future<>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut);
+ ASSERT_RAISES(IOError, fut.status());
+ }
+}
+
+TEST(FutureSyncTest, GetStatusFuture) {
+ {
+ auto fut = Future<MoveOnlyDataType>::Make();
+ Future<> status_future(fut);
+
+ AssertNotFinished(fut);
+ AssertNotFinished(status_future);
+
+ fut.MarkFinished(MoveOnlyDataType(42));
+ AssertSuccessful(fut);
+ AssertSuccessful(status_future);
+ ASSERT_EQ(&fut.status(), &status_future.status());
+ }
+ {
+ auto fut = Future<MoveOnlyDataType>::Make();
+ Future<> status_future(fut);
+
+ AssertNotFinished(fut);
+ AssertNotFinished(status_future);
+
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut);
+ AssertFailed(status_future);
+ ASSERT_EQ(&fut.status(), &status_future.status());
+ }
+}
+
+// Ensure the implicit convenience constructors behave as desired.
+TEST(FutureSyncTest, ImplicitConstructors) {
+ {
+ auto fut = ([]() -> Future<MoveOnlyDataType> {
+ return arrow::Status::Invalid("Invalid");
+ })();
+ AssertFailed(fut);
+ ASSERT_RAISES(Invalid, fut.result());
+ }
+ {
+ auto fut = ([]() -> Future<MoveOnlyDataType> {
+ return arrow::Result<MoveOnlyDataType>(arrow::Status::Invalid("Invalid"));
+ })();
+ AssertFailed(fut);
+ ASSERT_RAISES(Invalid, fut.result());
+ }
+ {
+ auto fut = ([]() -> Future<MoveOnlyDataType> { return MoveOnlyDataType(42); })();
+ AssertSuccessful(fut);
+ }
+ {
+ auto fut = ([]() -> Future<MoveOnlyDataType> {
+ return arrow::Result<MoveOnlyDataType>(MoveOnlyDataType(42));
+ })();
+ AssertSuccessful(fut);
+ }
+}
+
+TEST(FutureRefTest, ChainRemoved) {
+ // Creating a future chain should not prevent the futures from being deleted if the
+ // entire chain is deleted
+ std::weak_ptr<FutureImpl> ref;
+ std::weak_ptr<FutureImpl> ref2;
+ {
+ auto fut = Future<>::Make();
+ auto fut2 = fut.Then([]() { return Status::OK(); });
+ ref = fut.impl_;
+ ref2 = fut2.impl_;
+ }
+ ASSERT_TRUE(ref.expired());
+ ASSERT_TRUE(ref2.expired());
+
+ {
+ auto fut = Future<>::Make();
+ auto fut2 = fut.Then([]() { return Future<>::Make(); });
+ ref = fut.impl_;
+ ref2 = fut2.impl_;
+ }
+ ASSERT_TRUE(ref.expired());
+ ASSERT_TRUE(ref2.expired());
+}
+
+TEST(FutureRefTest, TailRemoved) {
+ // Keeping the head of the future chain should keep the entire chain alive
+ std::shared_ptr<Future<>> ref;
+ std::weak_ptr<FutureImpl> ref2;
+ bool side_effect_run = false;
+ {
+ ref = std::make_shared<Future<>>(Future<>::Make());
+ auto fut2 = ref->Then([&side_effect_run]() {
+ side_effect_run = true;
+ return Status::OK();
+ });
+ ref2 = fut2.impl_;
+ }
+ ASSERT_FALSE(ref2.expired());
+
+ ref->MarkFinished();
+ ASSERT_TRUE(side_effect_run);
+ ASSERT_TRUE(ref2.expired());
+}
+
+TEST(FutureRefTest, HeadRemoved) {
+ // Keeping the tail of the future chain should not keep the entire chain alive. If no
+ // one has a reference to the head then the future is abandoned. TODO (ARROW-12207):
+ // detect abandonment.
+ std::weak_ptr<FutureImpl> ref;
+ std::shared_ptr<Future<>> ref2;
+ {
+ auto fut = std::make_shared<Future<>>(Future<>::Make());
+ ref = fut->impl_;
+ ref2 = std::make_shared<Future<>>(fut->Then([]() {}));
+ }
+ ASSERT_TRUE(ref.expired());
+
+ {
+ auto fut = Future<>::Make();
+ ref2 = std::make_shared<Future<>>(fut.Then([&]() {
+ auto intermediate = Future<>::Make();
+ ref = intermediate.impl_;
+ return intermediate;
+ }));
+ fut.MarkFinished();
+ }
+ ASSERT_TRUE(ref.expired());
+}
+
+TEST(FutureStressTest, Callback) {
+#ifdef ARROW_VALGRIND
+ const int NITERS = 2;
+#else
+ const int NITERS = 1000;
+#endif
+ for (unsigned int n = 0; n < NITERS; n++) {
+ auto fut = Future<>::Make();
+ std::atomic<unsigned int> count_finished_immediately(0);
+ std::atomic<unsigned int> count_finished_deferred(0);
+ std::atomic<unsigned int> callbacks_added(0);
+ std::atomic<bool> finished(false);
+
+ std::thread callback_adder([&] {
+ auto test_thread = std::this_thread::get_id();
+ while (!finished.load()) {
+ fut.AddCallback([&test_thread, &count_finished_immediately,
+ &count_finished_deferred](const Status& status) {
+ ARROW_EXPECT_OK(status);
+ if (std::this_thread::get_id() == test_thread) {
+ count_finished_immediately++;
+ } else {
+ count_finished_deferred++;
+ }
+ });
+ callbacks_added++;
+ if (callbacks_added.load() > 10000) {
+ // If we've added many callbacks already and the main thread hasn't noticed yet,
+ // help it a bit (this seems especially useful in Valgrind).
+ SleepABit();
+ }
+ }
+ });
+
+ while (callbacks_added.load() == 0) {
+ // Spin until the callback_adder has started running
+ }
+
+ ASSERT_EQ(0, count_finished_deferred.load());
+ ASSERT_EQ(0, count_finished_immediately.load());
+
+ fut.MarkFinished();
+
+ while (count_finished_immediately.load() == 0) {
+ // Spin until the callback_adder has added at least one post-future
+ }
+
+ finished.store(true);
+ callback_adder.join();
+ auto total_added = callbacks_added.load();
+ auto total_immediate = count_finished_immediately.load();
+ auto total_deferred = count_finished_deferred.load();
+ ASSERT_EQ(total_added, total_immediate + total_deferred);
+ }
+}
+
+TEST(FutureStressTest, TryAddCallback) {
+ for (unsigned int n = 0; n < 1; n++) {
+ auto fut = Future<>::Make();
+ std::atomic<unsigned int> callbacks_added(0);
+ std::atomic<bool> finished(false);
+ std::mutex mutex;
+ std::condition_variable cv;
+ std::thread::id callback_adder_thread_id;
+
+ std::thread callback_adder([&] {
+ callback_adder_thread_id = std::this_thread::get_id();
+ std::function<void(const Status&)> callback =
+ [&callback_adder_thread_id](const Status& st) {
+ ARROW_EXPECT_OK(st);
+ if (std::this_thread::get_id() == callback_adder_thread_id) {
+ FAIL() << "TryAddCallback allowed a callback to be run synchronously";
+ }
+ };
+ std::function<std::function<void(const Status&)>()> callback_factory =
+ [&callback]() { return callback; };
+ while (true) {
+ auto callback_added = fut.TryAddCallback(callback_factory);
+ if (callback_added) {
+ callbacks_added++;
+ if (callbacks_added.load() > 10000) {
+ // If we've added many callbacks already and the main thread hasn't
+ // noticed yet, help it a bit (this seems especially useful in Valgrind).
+ SleepABit();
+ }
+ } else {
+ break;
+ }
+ }
+ {
+ std::lock_guard<std::mutex> lg(mutex);
+ finished.store(true);
+ }
+ cv.notify_one();
+ });
+
+ while (callbacks_added.load() == 0) {
+ // Spin until the callback_adder has started running
+ }
+
+ fut.MarkFinished();
+
+ std::unique_lock<std::mutex> lk(mutex);
+ cv.wait_for(lk, std::chrono::duration<double>(0.5),
+ [&finished] { return finished.load(); });
+ lk.unlock();
+
+ ASSERT_TRUE(finished);
+ callback_adder.join();
+ }
+}
+
+TEST(FutureCompletionTest, Void) {
+ {
+ // Simple callback
+ auto fut = Future<int>::Make();
+ int passed_in_result = 0;
+ auto fut2 =
+ fut.Then([&passed_in_result](const int& result) { passed_in_result = result; });
+ fut.MarkFinished(42);
+ AssertSuccessful(fut2);
+ ASSERT_EQ(passed_in_result, 42);
+ }
+ {
+ // Propagate failure by returning it from on_failure
+ auto fut = Future<int>::Make();
+ auto fut2 = fut.Then([](const int&) {}, [](const Status& s) { return s; });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut2);
+ ASSERT_TRUE(fut2.status().IsIOError());
+ }
+ {
+ // From void
+ auto fut = Future<>::Make();
+ auto fut2 = fut.Then([]() {});
+ fut.MarkFinished();
+ AssertSuccessful(fut2);
+ }
+ {
+ // Propagate failure by not having on_failure
+ auto fut = Future<>::Make();
+ auto cb_was_run = false;
+ auto fut2 = fut.Then([&cb_was_run]() {
+ cb_was_run = true;
+ return Status::OK();
+ });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut2);
+ ASSERT_FALSE(cb_was_run);
+ }
+ {
+ // Swallow failure by catching in on_failure
+ auto fut = Future<>::Make();
+ Status status_seen = Status::OK();
+ auto fut2 = fut.Then([]() {},
+ [&status_seen](const Status& s) {
+ status_seen = s;
+ return Status::OK();
+ });
+ ASSERT_TRUE(status_seen.ok());
+ fut.MarkFinished(Status::IOError("xxx"));
+ ASSERT_TRUE(status_seen.IsIOError());
+ AssertSuccessful(fut2);
+ }
+}
+
+TEST(FutureCompletionTest, NonVoid) {
+ {
+ // Simple callback
+ auto fut = Future<int>::Make();
+ auto fut2 = fut.Then([](int result) {
+ auto passed_in_result = result;
+ return passed_in_result * passed_in_result;
+ });
+ fut.MarkFinished(42);
+ AssertSuccessful(fut2);
+ auto result = *fut2.result();
+ ASSERT_EQ(result, 42 * 42);
+ }
+ {
+ // Propagate failure by not having on_failure
+ auto fut = Future<int>::Make();
+ auto cb_was_run = false;
+ auto fut2 = fut.Then([&cb_was_run](int result) {
+ cb_was_run = true;
+ auto passed_in_result = result;
+ return passed_in_result * passed_in_result;
+ });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut2);
+ ASSERT_TRUE(fut2.status().IsIOError());
+ ASSERT_FALSE(cb_was_run);
+ }
+ {
+ // Swallow failure by catching in on_failure
+ auto fut = Future<int>::Make();
+ bool was_io_error = false;
+ auto fut2 = fut.Then([](int) { return 99; },
+ [&was_io_error](const Status& s) {
+ was_io_error = s.IsIOError();
+ return 100;
+ });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertSuccessful(fut2);
+ auto result = *fut2.result();
+ ASSERT_EQ(result, 100);
+ ASSERT_TRUE(was_io_error);
+ }
+ {
+ // From void
+ auto fut = Future<>::Make();
+ auto fut2 = fut.Then([]() { return 42; });
+ fut.MarkFinished();
+ AssertSuccessful(fut2);
+ auto result = *fut2.result();
+ ASSERT_EQ(result, 42);
+ }
+ {
+ // Propagate failure by returning failure
+
+ // Cannot do this. Must return Result<int> because
+ // both callbacks must return the same thing and you can't
+ // return an int from the second callback if you're trying
+ // to propagate a failure
+ }
+}
+
+TEST(FutureCompletionTest, FutureNonVoid) {
+ {
+ // Simple callback
+ auto fut = Future<int>::Make();
+ auto innerFut = Future<std::string>::Make();
+ int passed_in_result = 0;
+ auto fut2 = fut.Then([&passed_in_result, innerFut](int result) {
+ passed_in_result = result;
+ return innerFut;
+ });
+ fut.MarkFinished(42);
+ ASSERT_EQ(passed_in_result, 42);
+ AssertNotFinished(fut2);
+ innerFut.MarkFinished("hello");
+ AssertSuccessful(fut2);
+ auto result = *fut2.result();
+ ASSERT_EQ(result, "hello");
+ }
+ {
+ // Propagate failure by not having on_failure
+ auto fut = Future<int>::Make();
+ auto innerFut = Future<std::string>::Make();
+ auto was_cb_run = false;
+ auto fut2 = fut.Then([innerFut, &was_cb_run](int) {
+ was_cb_run = true;
+ return innerFut;
+ });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut2);
+ ASSERT_TRUE(fut2.status().IsIOError());
+ ASSERT_FALSE(was_cb_run);
+ }
+ {
+ // Swallow failure by catching in on_failure
+ auto fut = Future<int>::Make();
+ auto innerFut = Future<std::string>::Make();
+ bool was_io_error = false;
+ auto was_cb_run = false;
+ auto fut2 = fut.Then(
+ [innerFut, &was_cb_run](int) {
+ was_cb_run = true;
+ return innerFut;
+ },
+ [&was_io_error, innerFut](const Status& s) {
+ was_io_error = s.IsIOError();
+ return innerFut;
+ });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertNotFinished(fut2);
+ innerFut.MarkFinished("hello");
+ AssertSuccessful(fut2);
+ auto result = *fut2.result();
+ ASSERT_EQ(result, "hello");
+ ASSERT_TRUE(was_io_error);
+ ASSERT_FALSE(was_cb_run);
+ }
+ {
+ // From void
+ auto fut = Future<>::Make();
+ auto innerFut = Future<std::string>::Make();
+ auto fut2 = fut.Then([&innerFut]() { return innerFut; });
+ fut.MarkFinished();
+ AssertNotFinished(fut2);
+ innerFut.MarkFinished("hello");
+ AssertSuccessful(fut2);
+ auto result = *fut2.result();
+ ASSERT_EQ(result, "hello");
+ }
+ {
+ // Propagate failure by returning failure
+ auto fut = Future<>::Make();
+ auto innerFut = Future<std::string>::Make();
+ auto was_cb_run = false;
+ auto fut2 = fut.Then(
+ [&innerFut, &was_cb_run]() {
+ was_cb_run = true;
+ return Result<Future<std::string>>(innerFut);
+ },
+ [](const Status& status) { return Result<Future<std::string>>(status); });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut2);
+ ASSERT_FALSE(was_cb_run);
+ }
+}
+
+TEST(FutureCompletionTest, Status) {
+ {
+ // Simple callback
+ auto fut = Future<int>::Make();
+ int passed_in_result = 0;
+ Future<> fut2 = fut.Then([&passed_in_result](int result) {
+ passed_in_result = result;
+ return Status::OK();
+ });
+ fut.MarkFinished(42);
+ ASSERT_EQ(passed_in_result, 42);
+ AssertSuccessful(fut2);
+ }
+ {
+ // Propagate failure by not having on_failure
+ auto fut = Future<int>::Make();
+ auto was_cb_run = false;
+ auto fut2 = fut.Then([&was_cb_run](int) {
+ was_cb_run = true;
+ return Status::OK();
+ });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut2);
+ ASSERT_TRUE(fut2.status().IsIOError());
+ ASSERT_FALSE(was_cb_run);
+ }
+ {
+ // Swallow failure by catching in on_failure
+ auto fut = Future<int>::Make();
+ bool was_io_error = false;
+ auto was_cb_run = false;
+ auto fut2 = fut.Then(
+ [&was_cb_run](int i) {
+ was_cb_run = true;
+ return Status::OK();
+ },
+ [&was_io_error](const Status& s) {
+ was_io_error = s.IsIOError();
+ return Status::OK();
+ });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertSuccessful(fut2);
+ ASSERT_TRUE(was_io_error);
+ ASSERT_FALSE(was_cb_run);
+ }
+ {
+ // From void
+ auto fut = Future<>::Make();
+ auto fut2 = fut.Then([]() { return Status::OK(); });
+ fut.MarkFinished();
+ AssertSuccessful(fut2);
+ }
+ {
+ // Propagate failure by returning failure
+ auto fut = Future<>::Make();
+ auto was_cb_run = false;
+ auto fut2 = fut.Then(
+ [&was_cb_run]() {
+ was_cb_run = true;
+ return Status::OK();
+ },
+ [](const Status& s) { return s; });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut2);
+ ASSERT_FALSE(was_cb_run);
+ }
+}
+
+TEST(FutureCompletionTest, Result) {
+ {
+ // Simple callback
+ auto fut = Future<int>::Make();
+ Future<int> fut2 = fut.Then([](const int& i) {
+ auto passed_in_result = i;
+ return Result<int>(passed_in_result * passed_in_result);
+ });
+ fut.MarkFinished(42);
+ AssertSuccessful(fut2);
+ auto result = *fut2.result();
+ ASSERT_EQ(result, 42 * 42);
+ }
+ {
+ // Propagate failure by not having on_failure
+ auto fut = Future<int>::Make();
+ auto was_cb_run = false;
+ auto fut2 = fut.Then([&was_cb_run](const int& i) {
+ was_cb_run = true;
+ auto passed_in_result = i;
+ return Result<int>(passed_in_result * passed_in_result);
+ });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut2);
+ ASSERT_TRUE(fut2.status().IsIOError());
+ ASSERT_FALSE(was_cb_run);
+ }
+ {
+ // Swallow failure by catching in on_failure
+ auto fut = Future<int>::Make();
+ bool was_io_error = false;
+ bool was_cb_run = false;
+ auto fut2 = fut.Then(
+ [&was_cb_run](const int& i) {
+ was_cb_run = true;
+ return Result<int>(100);
+ },
+ [&was_io_error](const Status& s) {
+ was_io_error = s.IsIOError();
+ return Result<int>(100);
+ });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertSuccessful(fut2);
+ auto result = *fut2.result();
+ ASSERT_EQ(result, 100);
+ ASSERT_TRUE(was_io_error);
+ ASSERT_FALSE(was_cb_run);
+ }
+ {
+ // From void
+ auto fut = Future<>::Make();
+ auto fut2 = fut.Then([]() { return Result<int>(42); });
+ fut.MarkFinished();
+ AssertSuccessful(fut2);
+ auto result = *fut2.result();
+ ASSERT_EQ(result, 42);
+ }
+ {
+ // Propagate failure by returning failure
+ auto fut = Future<>::Make();
+ auto was_cb_run = false;
+ auto fut2 = fut.Then(
+ [&was_cb_run]() {
+ was_cb_run = true;
+ return Result<int>(42);
+ },
+ [](const Status& s) { return Result<int>(s); });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut2);
+ ASSERT_FALSE(was_cb_run);
+ }
+}
+
+TEST(FutureCompletionTest, FutureVoid) {
+ {
+ // Simple callback
+ auto fut = Future<int>::Make();
+ auto innerFut = Future<>::Make();
+ int passed_in_result = 0;
+ auto fut2 = fut.Then([&passed_in_result, innerFut](int i) {
+ passed_in_result = i;
+ return innerFut;
+ });
+ fut.MarkFinished(42);
+ AssertNotFinished(fut2);
+ innerFut.MarkFinished();
+ AssertSuccessful(fut2);
+ auto res = fut2.status();
+ ASSERT_OK(res);
+ ASSERT_EQ(passed_in_result, 42);
+ }
+ {
+ // Precompleted future
+ auto fut = Future<int>::Make();
+ auto innerFut = Future<>::Make();
+ innerFut.MarkFinished();
+ int passed_in_result = 0;
+ auto fut2 = fut.Then([&passed_in_result, innerFut](int i) {
+ passed_in_result = i;
+ return innerFut;
+ });
+ AssertNotFinished(fut2);
+ fut.MarkFinished(42);
+ AssertSuccessful(fut2);
+ ASSERT_EQ(passed_in_result, 42);
+ }
+ {
+ // Propagate failure by not having on_failure
+ auto fut = Future<int>::Make();
+ auto innerFut = Future<>::Make();
+ auto was_cb_run = false;
+ auto fut2 = fut.Then([innerFut, &was_cb_run](int) {
+ was_cb_run = true;
+ return innerFut;
+ });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut2);
+ if (IsFutureFinished(fut2.state())) {
+ ASSERT_TRUE(fut2.status().IsIOError());
+ }
+ ASSERT_FALSE(was_cb_run);
+ }
+ {
+ // Swallow failure by catching in on_failure
+ auto fut = Future<int>::Make();
+ auto innerFut = Future<>::Make();
+ auto was_cb_run = false;
+ auto fut2 = fut.Then(
+ [innerFut, &was_cb_run](int) {
+ was_cb_run = true;
+ return innerFut;
+ },
+ [innerFut](const Status& s) { return innerFut; });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertNotFinished(fut2);
+ innerFut.MarkFinished();
+ AssertSuccessful(fut2);
+ ASSERT_FALSE(was_cb_run);
+ }
+ {
+ // From void
+ auto fut = Future<>::Make();
+ auto innerFut = Future<>::Make();
+ auto fut2 = fut.Then([&innerFut]() { return innerFut; });
+ fut.MarkFinished();
+ AssertNotFinished(fut2);
+ innerFut.MarkFinished();
+ AssertSuccessful(fut2);
+ }
+ {
+ // Propagate failure by returning failure
+ auto fut = Future<>::Make();
+ auto innerFut = Future<>::Make();
+ auto fut2 = fut.Then([&innerFut]() { return innerFut; },
+ [](const Status& s) { return Future<>::MakeFinished(s); });
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut2);
+ }
+}
+
+class FutureSchedulingTest : public testing::Test {
+ public:
+ internal::Executor* executor() { return mock_executor.get(); }
+
+ int spawn_count() { return static_cast<int>(mock_executor->captured_tasks.size()); }
+
+ void AssertRunSynchronously(const std::vector<int>& ids) { AssertIds(ids, true); }
+
+ void AssertScheduled(const std::vector<int>& ids) { AssertIds(ids, false); }
+
+ void AssertIds(const std::vector<int>& ids, bool should_be_synchronous) {
+ for (auto id : ids) {
+ ASSERT_EQ(should_be_synchronous, callbacks_run_synchronously.find(id) !=
+ callbacks_run_synchronously.end());
+ }
+ }
+
+ std::function<void(const Status&)> callback(int id) {
+ return [this, id](const Status&) { callbacks_run_synchronously.insert(id); };
+ }
+
+ std::shared_ptr<DelayedExecutor> mock_executor = std::make_shared<DelayedExecutor>();
+ std::unordered_set<int> callbacks_run_synchronously;
+};
+
+TEST_F(FutureSchedulingTest, ScheduleNever) {
+ CallbackOptions options;
+ options.should_schedule = ShouldSchedule::Never;
+ options.executor = executor();
+ // Successful future
+ {
+ auto fut = Future<>::Make();
+ fut.AddCallback(callback(1), options);
+ fut.MarkFinished();
+ fut.AddCallback(callback(2), options);
+ ASSERT_EQ(0, spawn_count());
+ AssertRunSynchronously({1, 2});
+ }
+ // Failing future
+ {
+ auto fut = Future<>::Make();
+ fut.AddCallback(callback(3), options);
+ fut.MarkFinished(Status::Invalid("XYZ"));
+ fut.AddCallback(callback(4), options);
+ ASSERT_EQ(0, spawn_count());
+ AssertRunSynchronously({3, 4});
+ }
+}
+
+TEST_F(FutureSchedulingTest, ScheduleAlways) {
+ CallbackOptions options;
+ options.should_schedule = ShouldSchedule::Always;
+ options.executor = executor();
+ // Successful future
+ {
+ auto fut = Future<>::Make();
+ fut.AddCallback(callback(1), options);
+ fut.MarkFinished();
+ fut.AddCallback(callback(2), options);
+ ASSERT_EQ(2, spawn_count());
+ AssertScheduled({1, 2});
+ }
+ // Failing future
+ {
+ auto fut = Future<>::Make();
+ fut.AddCallback(callback(3), options);
+ fut.MarkFinished(Status::Invalid("XYZ"));
+ fut.AddCallback(callback(4), options);
+ ASSERT_EQ(4, spawn_count());
+ AssertScheduled({3, 4});
+ }
+}
+
+TEST_F(FutureSchedulingTest, ScheduleIfUnfinished) {
+ CallbackOptions options;
+ options.should_schedule = ShouldSchedule::IfUnfinished;
+ options.executor = executor();
+ // Successful future
+ {
+ auto fut = Future<>::Make();
+ fut.AddCallback(callback(1), options);
+ fut.MarkFinished();
+ fut.AddCallback(callback(2), options);
+ ASSERT_EQ(1, spawn_count());
+ AssertRunSynchronously({2});
+ AssertScheduled({1});
+ }
+ // Failing future
+ {
+ auto fut = Future<>::Make();
+ fut.AddCallback(callback(3), options);
+ fut.MarkFinished(Status::Invalid("XYZ"));
+ fut.AddCallback(callback(4), options);
+ ASSERT_EQ(2, spawn_count());
+ AssertRunSynchronously({4});
+ AssertScheduled({3});
+ }
+}
+
+TEST_F(FutureSchedulingTest, ScheduleIfDifferentExecutor) {
+ struct : internal::Executor {
+ int GetCapacity() override { return pool_->GetCapacity(); }
+
+ bool OwnsThisThread() override { return pool_->OwnsThisThread(); }
+
+ Status SpawnReal(internal::TaskHints hints, internal::FnOnce<void()> task,
+ StopToken stop_token, StopCallback&& stop_callback) override {
+ ++spawn_count;
+ return pool_->Spawn(hints, std::move(task), std::move(stop_token),
+ std::move(stop_callback));
+ }
+
+ std::atomic<int> spawn_count{0};
+ internal::Executor* pool_ = internal::GetCpuThreadPool();
+ } executor;
+
+ CallbackOptions options;
+ options.executor = &executor;
+ options.should_schedule = ShouldSchedule::IfDifferentExecutor;
+ auto pass_err = [](const Status& s) { return s; };
+
+ std::atomic<bool> fut0_on_executor{false};
+ std::atomic<bool> fut1_on_executor{false};
+
+ auto fut0 = Future<>::Make();
+ auto fut1 = Future<>::Make();
+
+ auto fut0_done = fut0.Then(
+ [&] {
+ // marked finished on main thread -> must be scheduled to executor
+ fut0_on_executor.store(executor.OwnsThisThread());
+
+ fut1.MarkFinished();
+ },
+ pass_err, options);
+
+ auto fut1_done = fut1.Then(
+ [&] {
+ // marked finished on executor -> no need to schedule
+ fut1_on_executor.store(executor.OwnsThisThread());
+ },
+ pass_err, options);
+
+ fut0.MarkFinished();
+
+ AllComplete({fut0_done, fut1_done}).Wait();
+
+ ASSERT_EQ(executor.spawn_count, 1);
+ ASSERT_TRUE(fut0_on_executor);
+ ASSERT_TRUE(fut1_on_executor);
+}
+
+TEST_F(FutureSchedulingTest, ScheduleAlwaysKeepsFutureAliveUntilCallback) {
+ CallbackOptions options;
+ options.should_schedule = ShouldSchedule::Always;
+ options.executor = executor();
+ {
+ auto fut = Future<int>::Make();
+ fut.AddCallback([](const Result<int> val) { ASSERT_EQ(7, *val); }, options);
+ fut.MarkFinished(7);
+ }
+ std::move(mock_executor->captured_tasks[0])();
+}
+
+TEST(FutureAllTest, Empty) {
+ auto combined = arrow::All(std::vector<Future<int>>{});
+ auto after_assert = combined.Then(
+ [](std::vector<Result<int>> results) { ASSERT_EQ(0, results.size()); });
+ AssertSuccessful(after_assert);
+}
+
+TEST(FutureAllTest, Simple) {
+ auto f1 = Future<int>::Make();
+ auto f2 = Future<int>::Make();
+ std::vector<Future<int>> futures = {f1, f2};
+ auto combined = arrow::All(futures);
+
+ auto after_assert = combined.Then([](std::vector<Result<int>> results) {
+ ASSERT_EQ(2, results.size());
+ ASSERT_EQ(1, *results[0]);
+ ASSERT_EQ(2, *results[1]);
+ });
+
+ // Finish in reverse order, results should still be delivered in proper order
+ AssertNotFinished(after_assert);
+ f2.MarkFinished(2);
+ AssertNotFinished(after_assert);
+ f1.MarkFinished(1);
+ AssertSuccessful(after_assert);
+}
+
+TEST(FutureAllTest, Failure) {
+ auto f1 = Future<int>::Make();
+ auto f2 = Future<int>::Make();
+ auto f3 = Future<int>::Make();
+ std::vector<Future<int>> futures = {f1, f2, f3};
+ auto combined = arrow::All(futures);
+
+ auto after_assert = combined.Then([](std::vector<Result<int>> results) {
+ ASSERT_EQ(3, results.size());
+ ASSERT_EQ(1, *results[0]);
+ ASSERT_EQ(Status::IOError("XYZ"), results[1].status());
+ ASSERT_EQ(3, *results[2]);
+ });
+
+ f1.MarkFinished(1);
+ f2.MarkFinished(Status::IOError("XYZ"));
+ f3.MarkFinished(3);
+
+ AssertFinished(after_assert);
+}
+
+TEST(FutureAllCompleteTest, Empty) {
+ Future<> combined = AllComplete(std::vector<Future<>>{});
+ AssertSuccessful(combined);
+}
+
+TEST(FutureAllCompleteTest, Simple) {
+ auto f1 = Future<int>::Make();
+ auto f2 = Future<int>::Make();
+ std::vector<Future<>> futures = {Future<>(f1), Future<>(f2)};
+ auto combined = AllComplete(futures);
+ AssertNotFinished(combined);
+ f2.MarkFinished(2);
+ AssertNotFinished(combined);
+ f1.MarkFinished(1);
+ AssertSuccessful(combined);
+}
+
+TEST(FutureAllCompleteTest, Failure) {
+ auto f1 = Future<int>::Make();
+ auto f2 = Future<int>::Make();
+ auto f3 = Future<int>::Make();
+ std::vector<Future<>> futures = {Future<>(f1), Future<>(f2), Future<>(f3)};
+ auto combined = AllComplete(futures);
+ AssertNotFinished(combined);
+ f1.MarkFinished(1);
+ AssertNotFinished(combined);
+ f2.MarkFinished(Status::IOError("XYZ"));
+ AssertFinished(combined);
+ f3.MarkFinished(3);
+ AssertFinished(combined);
+ ASSERT_EQ(Status::IOError("XYZ"), combined.status());
+}
+
+TEST(FutureLoopTest, Sync) {
+ struct {
+ int i = 0;
+ Future<int> Get() { return Future<int>::MakeFinished(i++); }
+ } IntSource;
+
+ bool do_fail = false;
+ std::vector<int> ints;
+ auto loop_body = [&] {
+ return IntSource.Get().Then([&](int i) -> Result<ControlFlow<int>> {
+ if (do_fail && i == 3) {
+ return Status::IOError("xxx");
+ }
+
+ if (i == 5) {
+ int sum = 0;
+ for (int i : ints) sum += i;
+ return Break(sum);
+ }
+
+ ints.push_back(i);
+ return Continue();
+ });
+ };
+
+ {
+ auto sum_fut = Loop(loop_body);
+ AssertSuccessful(sum_fut);
+
+ ASSERT_OK_AND_ASSIGN(auto sum, sum_fut.result());
+ ASSERT_EQ(sum, 0 + 1 + 2 + 3 + 4);
+ }
+
+ {
+ do_fail = true;
+ IntSource.i = 0;
+ auto sum_fut = Loop(loop_body);
+ AssertFailed(sum_fut);
+ ASSERT_RAISES(IOError, sum_fut.result());
+ }
+}
+
+TEST(FutureLoopTest, EmptyBreakValue) {
+ Future<> none_fut =
+ Loop([&] { return Future<>::MakeFinished().Then([&]() { return Break(); }); });
+ AssertSuccessful(none_fut);
+}
+
+TEST(FutureLoopTest, EmptyLoop) {
+ auto loop_body = []() -> Future<ControlFlow<int>> {
+ return Future<ControlFlow<int>>::MakeFinished(Break(0));
+ };
+ auto loop_fut = Loop(loop_body);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto loop_res, loop_fut);
+ ASSERT_EQ(loop_res, 0);
+}
+
+// TODO - Test provided by Ben but I don't understand how it can pass legitimately.
+// Any future result will be passed by reference to the callbacks (as there can be
+// multiple callbacks). In the Loop construct it takes the break and forwards it
+// on to the outer future. Since there is no way to move a reference this can only
+// be done by copying.
+//
+// In theory it should be safe since Loop is guaranteed to be the last callback added
+// to the control future and so the value can be safely moved at that point. However,
+// I'm unable to reproduce whatever trick you had in ControlFlow to make this work.
+// If we want to formalize this "last callback can steal" concept then we could add
+// a "last callback" to Future which gets called with an rvalue instead of an lvalue
+// reference but that seems overly complicated.
+//
+// Ben, can you recreate whatever trick you had in place before that allowed this to
+// pass? Perhaps some kind of cast. Worst case, I can move back to using
+// ControlFlow instead of std::optional
+//
+// TEST(FutureLoopTest, MoveOnlyBreakValue) {
+// Future<MoveOnlyDataType> one_fut = Loop([&] {
+// return Future<int>::MakeFinished(1).Then(
+// [&](int i) { return Break(MoveOnlyDataType(i)); });
+// });
+// AssertSuccessful(one_fut);
+// ASSERT_OK_AND_ASSIGN(auto one, std::move(one_fut).result());
+// ASSERT_EQ(one, 1);
+// }
+
+TEST(FutureLoopTest, StackOverflow) {
+ // Looping over futures is normally a rather recursive task. If the futures complete
+ // synchronously (because they are already finished) it could lead to a stack overflow
+ // if care is not taken.
+ int counter = 0;
+ auto loop_body = [&counter]() -> Future<ControlFlow<int>> {
+ while (counter < 1000000) {
+ counter++;
+ return Future<ControlFlow<int>>::MakeFinished(Continue());
+ }
+ return Future<ControlFlow<int>>::MakeFinished(Break(-1));
+ };
+ auto loop_fut = Loop(loop_body);
+ ASSERT_TRUE(loop_fut.Wait(0.1));
+}
+
+TEST(FutureLoopTest, AllowsBreakFutToBeDiscarded) {
+ int counter = 0;
+ auto loop_body = [&counter]() -> Future<ControlFlow<int>> {
+ while (counter < 10) {
+ counter++;
+ return Future<ControlFlow<int>>::MakeFinished(Continue());
+ }
+ return Future<ControlFlow<int>>::MakeFinished(Break(-1));
+ };
+ auto loop_fut = Loop(loop_body).Then([](const int&) { return Status::OK(); });
+ ASSERT_TRUE(loop_fut.Wait(0.1));
+}
+
+class MoveTrackingCallable {
+ public:
+ MoveTrackingCallable() {
+ // std::cout << "CONSTRUCT" << std::endl;
+ }
+ ~MoveTrackingCallable() {
+ valid_ = false;
+ // std::cout << "DESTRUCT" << std::endl;
+ }
+ MoveTrackingCallable(const MoveTrackingCallable& other) {
+ // std::cout << "COPY CONSTRUCT" << std::endl;
+ }
+ MoveTrackingCallable(MoveTrackingCallable&& other) {
+ other.valid_ = false;
+ // std::cout << "MOVE CONSTRUCT" << std::endl;
+ }
+ MoveTrackingCallable& operator=(const MoveTrackingCallable& other) {
+ // std::cout << "COPY ASSIGN" << std::endl;
+ return *this;
+ }
+ MoveTrackingCallable& operator=(MoveTrackingCallable&& other) {
+ other.valid_ = false;
+ // std::cout << "MOVE ASSIGN" << std::endl;
+ return *this;
+ }
+
+ Status operator()() {
+ // std::cout << "TRIGGER" << std::endl;
+ if (valid_) {
+ return Status::OK();
+ } else {
+ return Status::Invalid("Invalid callback triggered");
+ }
+ }
+
+ private:
+ bool valid_ = true;
+};
+
+TEST(FutureCompletionTest, ReuseCallback) {
+ auto fut = Future<>::Make();
+
+ Future<> continuation;
+ {
+ MoveTrackingCallable callback;
+ continuation = fut.Then(callback);
+ }
+
+ fut.MarkFinished();
+
+ ASSERT_TRUE(continuation.is_finished());
+ if (continuation.is_finished()) {
+ ASSERT_OK(continuation.status());
+ }
+}
+
+// --------------------------------------------------------------------
+// Tests with an executor
+
+template <typename T>
+class FutureTestBase : public ::testing::Test {
+ public:
+ using ExecutorType = SimpleExecutor<T>;
+
+ void MakeExecutor(int nfutures) { executor_.reset(new ExecutorType(nfutures)); }
+
+ void MakeExecutor(int nfutures, std::vector<std::pair<int, bool>> immediate) {
+ MakeExecutor(nfutures);
+ executor_->SetFinished(std::move(immediate));
+ }
+
+ template <typename U>
+ void RandomShuffle(std::vector<U>* values) {
+ std::default_random_engine gen(seed_++);
+ std::shuffle(values->begin(), values->end(), gen);
+ }
+
+ // Generate a sequence of randomly-sized ordered spans covering exactly [0, size).
+ // Returns a vector of (start, stop) pairs.
+ std::vector<std::pair<int, int>> RandomSequenceSpans(int size) {
+ std::default_random_engine gen(seed_++);
+ // The distribution of span sizes
+ std::poisson_distribution<int> dist(5);
+ std::vector<std::pair<int, int>> spans;
+ int start = 0;
+ while (start < size) {
+ int stop = std::min(start + dist(gen), size);
+ spans.emplace_back(start, stop);
+ start = stop;
+ }
+ return spans;
+ }
+
+ void AssertAllNotFinished(const std::vector<int>& future_indices) {
+ const auto& futures = executor_->futures();
+ for (const auto fut_index : future_indices) {
+ AssertNotFinished(futures[fut_index]);
+ }
+ }
+
+ // Assert the given futures are *eventually* successful
+ void AssertAllSuccessful(const std::vector<int>& future_indices) {
+ const auto& futures = executor_->futures();
+ for (const auto fut_index : future_indices) {
+ ASSERT_OK(futures[fut_index].status());
+ ASSERT_EQ(*futures[fut_index].result(), fut_index);
+ }
+ }
+
+ // Assert the given futures are *eventually* failed
+ void AssertAllFailed(const std::vector<int>& future_indices) {
+ const auto& futures = executor_->futures();
+ for (const auto fut_index : future_indices) {
+ ASSERT_RAISES(UnknownError, futures[fut_index].status());
+ }
+ }
+
+ // Assert the given futures are *eventually* successful
+ void AssertSpanSuccessful(int start, int stop) {
+ const auto& futures = executor_->futures();
+ for (int fut_index = start; fut_index < stop; ++fut_index) {
+ ASSERT_OK(futures[fut_index].status());
+ ASSERT_EQ(*futures[fut_index].result(), fut_index);
+ }
+ }
+
+ void AssertAllSuccessful() {
+ AssertSpanSuccessful(0, static_cast<int>(executor_->futures().size()));
+ }
+
+ // Assert the given futures are successful *now*
+ void AssertSpanSuccessfulNow(int start, int stop) {
+ const auto& futures = executor_->futures();
+ for (int fut_index = start; fut_index < stop; ++fut_index) {
+ ASSERT_TRUE(IsFutureFinished(futures[fut_index].state()));
+ }
+ }
+
+ void AssertAllSuccessfulNow() {
+ AssertSpanSuccessfulNow(0, static_cast<int>(executor_->futures().size()));
+ }
+
+ void TestBasicWait() {
+ MakeExecutor(4, {{1, true}, {2, false}});
+ AssertAllNotFinished({0, 3});
+ AssertAllSuccessful({1});
+ AssertAllFailed({2});
+ AssertAllNotFinished({0, 3});
+ executor_->SetFinishedDeferred({{0, true}, {3, true}});
+ AssertAllSuccessful({0, 1, 3});
+ }
+
+ void TestTimedWait() {
+ MakeExecutor(2);
+ const auto& futures = executor_->futures();
+ ASSERT_FALSE(futures[0].Wait(kTinyWait));
+ ASSERT_FALSE(futures[1].Wait(kTinyWait));
+ AssertAllNotFinished({0, 1});
+ executor_->SetFinishedDeferred({{0, true}, {1, true}});
+ ASSERT_TRUE(futures[0].Wait(kLargeWait));
+ ASSERT_TRUE(futures[1].Wait(kLargeWait));
+ AssertAllSuccessfulNow();
+ }
+
+ void TestStressWait() {
+#ifdef ARROW_VALGRIND
+ const int N = 20;
+#else
+ const int N = 2000;
+#endif
+ MakeExecutor(N);
+ const auto& futures = executor_->futures();
+ const auto spans = RandomSequenceSpans(N);
+ for (const auto& span : spans) {
+ int start = span.first, stop = span.second;
+ executor_->SetFinishedDeferred(start, stop);
+ AssertSpanSuccessful(start, stop);
+ if (stop < N) {
+ AssertNotFinished(futures[stop]);
+ }
+ }
+ AssertAllSuccessful();
+ }
+
+ void TestBasicWaitForAny() {
+ MakeExecutor(4, {{1, true}, {2, false}});
+ auto& futures = executor_->futures();
+
+ std::vector<Future<T>*> wait_on = {&futures[0], &futures[1]};
+ auto finished = WaitForAny(wait_on);
+ ASSERT_THAT(finished, testing::ElementsAre(1));
+
+ wait_on = {&futures[1], &futures[2], &futures[3]};
+ while (finished.size() < 2) {
+ finished = WaitForAny(wait_on);
+ }
+ ASSERT_THAT(finished, testing::UnorderedElementsAre(0, 1));
+
+ executor_->SetFinished(3);
+ finished = WaitForAny(futures);
+ ASSERT_THAT(finished, testing::UnorderedElementsAre(1, 2, 3));
+
+ executor_->SetFinishedDeferred(0);
+ // Busy wait until the state change is done
+ while (finished.size() < 4) {
+ finished = WaitForAny(futures);
+ }
+ ASSERT_THAT(finished, testing::UnorderedElementsAre(0, 1, 2, 3));
+ }
+
+ void TestTimedWaitForAny() {
+ MakeExecutor(4, {{1, true}, {2, false}});
+ auto& futures = executor_->futures();
+
+ std::vector<int> finished;
+ std::vector<Future<T>*> wait_on = {&futures[0], &futures[3]};
+ finished = WaitForAny(wait_on, kTinyWait);
+ ASSERT_EQ(finished.size(), 0);
+
+ executor_->SetFinished(3);
+ finished = WaitForAny(wait_on, kLargeWait);
+ ASSERT_THAT(finished, testing::ElementsAre(1));
+
+ executor_->SetFinished(0);
+ while (finished.size() < 2) {
+ finished = WaitForAny(wait_on, kTinyWait);
+ }
+ ASSERT_THAT(finished, testing::UnorderedElementsAre(0, 1));
+
+ while (finished.size() < 4) {
+ finished = WaitForAny(futures, kTinyWait);
+ }
+ ASSERT_THAT(finished, testing::UnorderedElementsAre(0, 1, 2, 3));
+ }
+
+ void TestBasicWaitForAll() {
+ MakeExecutor(4, {{1, true}, {2, false}});
+ auto& futures = executor_->futures();
+
+ std::vector<Future<T>*> wait_on = {&futures[1], &futures[2]};
+ WaitForAll(wait_on);
+ AssertSpanSuccessfulNow(1, 3);
+
+ executor_->SetFinishedDeferred({{0, true}, {3, false}});
+ WaitForAll(futures);
+ AssertAllSuccessfulNow();
+ WaitForAll(futures);
+ }
+
+ void TestTimedWaitForAll() {
+ MakeExecutor(4, {{1, true}, {2, false}});
+ auto& futures = executor_->futures();
+
+ ASSERT_FALSE(WaitForAll(futures, kTinyWait));
+
+ executor_->SetFinishedDeferred({{0, true}, {3, false}});
+ ASSERT_TRUE(WaitForAll(futures, kLargeWait));
+ AssertAllSuccessfulNow();
+ }
+
+ void TestStressWaitForAny() {
+#ifdef ARROW_VALGRIND
+ const int N = 5;
+#else
+ const int N = 300;
+#endif
+ MakeExecutor(N);
+ const auto& futures = executor_->futures();
+ const auto spans = RandomSequenceSpans(N);
+ std::vector<int> finished;
+ // Note this loop is potentially O(N**2), because we're copying
+ // O(N)-sized vector when waiting.
+ for (const auto& span : spans) {
+ int start = span.first, stop = span.second;
+ executor_->SetFinishedDeferred(start, stop);
+ size_t last_finished_size = finished.size();
+ finished = WaitForAny(futures);
+ ASSERT_GE(finished.size(), last_finished_size);
+ // The spans are contiguous and ordered, so `stop` is also the number
+ // of futures for which SetFinishedDeferred() was called.
+ ASSERT_LE(finished.size(), static_cast<size_t>(stop));
+ }
+ // Semi-busy wait for all futures to be finished
+ while (finished.size() < static_cast<size_t>(N)) {
+ finished = WaitForAny(futures);
+ }
+ AssertAllSuccessfulNow();
+ }
+
+ void TestStressWaitForAll() {
+#ifdef ARROW_VALGRIND
+ const int N = 5;
+#else
+ const int N = 300;
+#endif
+ MakeExecutor(N);
+ const auto& futures = executor_->futures();
+ const auto spans = RandomSequenceSpans(N);
+ // Note this loop is potentially O(N**2), because we're copying
+ // O(N)-sized vector when waiting.
+ for (const auto& span : spans) {
+ int start = span.first, stop = span.second;
+ executor_->SetFinishedDeferred(start, stop);
+ bool finished = WaitForAll(futures, kTinyWait);
+ if (stop < N) {
+ ASSERT_FALSE(finished);
+ }
+ }
+ ASSERT_TRUE(WaitForAll(futures, kLargeWait));
+ AssertAllSuccessfulNow();
+ }
+
+ void TestBasicAsCompleted() {
+ {
+ MakeExecutor(4, {{1, true}, {2, true}});
+ executor_->SetFinishedDeferred({{0, true}, {3, true}});
+ auto it = MakeAsCompletedIterator(executor_->futures());
+ std::vector<T> values = IteratorToVector(std::move(it));
+ ASSERT_THAT(values, testing::UnorderedElementsAre(0, 1, 2, 3));
+ }
+ {
+ // Check that AsCompleted is opportunistic, it yields elements in order
+ // of completion.
+ MakeExecutor(4, {{2, true}});
+ auto it = MakeAsCompletedIterator(executor_->futures());
+ ASSERT_OK_AND_EQ(2, it.Next());
+ executor_->SetFinishedDeferred({{3, true}});
+ ASSERT_OK_AND_EQ(3, it.Next());
+ executor_->SetFinishedDeferred({{0, true}});
+ ASSERT_OK_AND_EQ(0, it.Next());
+ executor_->SetFinishedDeferred({{1, true}});
+ ASSERT_OK_AND_EQ(1, it.Next());
+ ASSERT_OK_AND_EQ(IterationTraits<T>::End(), it.Next());
+ ASSERT_OK_AND_EQ(IterationTraits<T>::End(), it.Next()); // idempotent
+ }
+ }
+
+ void TestErrorsAsCompleted() {
+ MakeExecutor(4, {{1, true}, {2, false}});
+ executor_->SetFinishedDeferred({{0, true}, {3, false}});
+ auto it = MakeAsCompletedIterator(executor_->futures());
+ auto results = IteratorToResults(std::move(it));
+ ASSERT_THAT(results.values, testing::UnorderedElementsAre(0, 1));
+ ASSERT_EQ(results.errors.size(), 2);
+ ASSERT_RAISES(UnknownError, results.errors[0]);
+ ASSERT_RAISES(UnknownError, results.errors[1]);
+ }
+
+ void TestStressAsCompleted() {
+#ifdef ARROW_VALGRIND
+ const int N = 10;
+#else
+ const int N = 1000;
+#endif
+ MakeExecutor(N);
+
+ // Launch a worker thread that will finish random spans of futures,
+ // in random order.
+ auto spans = RandomSequenceSpans(N);
+ RandomShuffle(&spans);
+ auto feed_iterator = [&]() {
+ for (const auto& span : spans) {
+ int start = span.first, stop = span.second;
+ executor_->SetFinishedDeferred(start, stop); // will sleep a bit
+ }
+ };
+ auto worker = std::thread(std::move(feed_iterator));
+ auto it = MakeAsCompletedIterator(executor_->futures());
+ auto results = IteratorToResults(std::move(it));
+ worker.join();
+
+ ASSERT_EQ(results.values.size(), static_cast<size_t>(N));
+ ASSERT_EQ(results.errors.size(), 0);
+ std::vector<int> expected(N);
+ std::iota(expected.begin(), expected.end(), 0);
+ std::vector<int> actual(N);
+ std::transform(results.values.begin(), results.values.end(), actual.begin(),
+ [](const T& value) { return value.ToInt(); });
+ std::sort(actual.begin(), actual.end());
+ ASSERT_EQ(expected, actual);
+ }
+
+ protected:
+ std::unique_ptr<ExecutorType> executor_;
+ int seed_ = 42;
+};
+
+template <typename T>
+class FutureWaitTest : public FutureTestBase<T> {};
+
+using FutureWaitTestTypes = ::testing::Types<int, Foo, MoveOnlyDataType>;
+
+TYPED_TEST_SUITE(FutureWaitTest, FutureWaitTestTypes);
+
+TYPED_TEST(FutureWaitTest, BasicWait) { this->TestBasicWait(); }
+
+TYPED_TEST(FutureWaitTest, TimedWait) { this->TestTimedWait(); }
+
+TYPED_TEST(FutureWaitTest, StressWait) { this->TestStressWait(); }
+
+TYPED_TEST(FutureWaitTest, BasicWaitForAny) { this->TestBasicWaitForAny(); }
+
+TYPED_TEST(FutureWaitTest, TimedWaitForAny) { this->TestTimedWaitForAny(); }
+
+TYPED_TEST(FutureWaitTest, StressWaitForAny) { this->TestStressWaitForAny(); }
+
+TYPED_TEST(FutureWaitTest, BasicWaitForAll) { this->TestBasicWaitForAll(); }
+
+TYPED_TEST(FutureWaitTest, TimedWaitForAll) { this->TestTimedWaitForAll(); }
+
+TYPED_TEST(FutureWaitTest, StressWaitForAll) { this->TestStressWaitForAll(); }
+
+template <typename T>
+class FutureIteratorTest : public FutureTestBase<T> {};
+
+using FutureIteratorTestTypes = ::testing::Types<Foo>;
+
+TYPED_TEST_SUITE(FutureIteratorTest, FutureIteratorTestTypes);
+
+TYPED_TEST(FutureIteratorTest, BasicAsCompleted) { this->TestBasicAsCompleted(); }
+
+TYPED_TEST(FutureIteratorTest, ErrorsAsCompleted) { this->TestErrorsAsCompleted(); }
+
+TYPED_TEST(FutureIteratorTest, StressAsCompleted) { this->TestStressAsCompleted(); }
+
+namespace internal {
+TEST(FnOnceTest, MoveOnlyDataType) {
+ // ensuring this is valid guarantees we are making no unnecessary copies
+ FnOnce<int(const MoveOnlyDataType&, MoveOnlyDataType, std::string)> fn =
+ [](const MoveOnlyDataType& i0, MoveOnlyDataType i1, std::string copyable) {
+ return *i0.data + *i1.data + (i0.moves * 1000) + (i1.moves * 100);
+ };
+
+ using arg0 = call_traits::argument_type<0, decltype(fn)>;
+ using arg1 = call_traits::argument_type<1, decltype(fn)>;
+ using arg2 = call_traits::argument_type<2, decltype(fn)>;
+ static_assert(std::is_same<arg0, const MoveOnlyDataType&>::value, "");
+ static_assert(std::is_same<arg1, MoveOnlyDataType>::value, "");
+ static_assert(std::is_same<arg2, std::string>::value,
+ "should not add a && to the call type (demanding rvalue unnecessarily)");
+
+ MoveOnlyDataType i0{1}, i1{41};
+ std::string copyable = "";
+ ASSERT_EQ(std::move(fn)(i0, std::move(i1), copyable), 242);
+ ASSERT_EQ(i0.moves, 0);
+ ASSERT_EQ(i1.moves, 0);
+}
+
+TEST(FutureTest, MatcherExamples) {
+ EXPECT_THAT(Future<int>::MakeFinished(Status::Invalid("arbitrary error")),
+ Finishes(Raises(StatusCode::Invalid)));
+
+ EXPECT_THAT(Future<int>::MakeFinished(Status::Invalid("arbitrary error")),
+ Finishes(Raises(StatusCode::Invalid, testing::HasSubstr("arbitrary"))));
+
+ // message doesn't match, so no match
+ EXPECT_THAT(Future<int>::MakeFinished(Status::Invalid("arbitrary error")),
+ Finishes(testing::Not(
+ Raises(StatusCode::Invalid, testing::HasSubstr("reasonable")))));
+
+ // different error code, so no match
+ EXPECT_THAT(Future<int>::MakeFinished(Status::TypeError("arbitrary error")),
+ Finishes(testing::Not(Raises(StatusCode::Invalid))));
+
+ // not an error, so no match
+ EXPECT_THAT(Future<int>::MakeFinished(333),
+ Finishes(testing::Not(Raises(StatusCode::Invalid))));
+
+ EXPECT_THAT(Future<std::string>::MakeFinished("hello world"),
+ Finishes(ResultWith(testing::HasSubstr("hello"))));
+
+ // Matcher waits on Futures
+ auto string_fut = Future<std::string>::Make();
+ auto finisher = std::thread([&] {
+ SleepABit();
+ string_fut.MarkFinished("hello world");
+ });
+ EXPECT_THAT(string_fut, Finishes(ResultWith(testing::HasSubstr("hello"))));
+ finisher.join();
+
+ EXPECT_THAT(Future<std::string>::MakeFinished(Status::Invalid("XXX")),
+ Finishes(testing::Not(ResultWith(testing::HasSubstr("hello")))));
+
+ // holds a value, but that value doesn't match the given pattern
+ EXPECT_THAT(Future<std::string>::MakeFinished("foo bar"),
+ Finishes(testing::Not(ResultWith(testing::HasSubstr("hello")))));
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/hash_util.h b/src/arrow/cpp/src/arrow/util/hash_util.h
new file mode 100644
index 000000000..dd1c38a78
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/hash_util.h
@@ -0,0 +1,66 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+namespace arrow {
+namespace internal {
+
+// ----------------------------------------------------------------------
+// BEGIN Hash utilities from Boost
+
+namespace detail {
+
+#if defined(_MSC_VER)
+#define ARROW_HASH_ROTL32(x, r) _rotl(x, r)
+#else
+#define ARROW_HASH_ROTL32(x, r) (x << r) | (x >> (32 - r))
+#endif
+
+template <typename SizeT>
+inline void hash_combine_impl(SizeT& seed, SizeT value) {
+ seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+}
+
+inline void hash_combine_impl(uint32_t& h1, uint32_t k1) {
+ const uint32_t c1 = 0xcc9e2d51;
+ const uint32_t c2 = 0x1b873593;
+
+ k1 *= c1;
+ k1 = ARROW_HASH_ROTL32(k1, 15);
+ k1 *= c2;
+
+ h1 ^= k1;
+ h1 = ARROW_HASH_ROTL32(h1, 13);
+ h1 = h1 * 5 + 0xe6546b64;
+}
+
+#undef ARROW_HASH_ROTL32
+
+} // namespace detail
+
+template <class T>
+inline void hash_combine(std::size_t& seed, T const& v) {
+ std::hash<T> hasher;
+ return ::arrow::internal::detail::hash_combine_impl(seed, hasher(v));
+}
+
+// END Hash utilities from Boost
+// ----------------------------------------------------------------------
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/hashing.h b/src/arrow/cpp/src/arrow/util/hashing.h
new file mode 100644
index 000000000..09076c54d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/hashing.h
@@ -0,0 +1,886 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Private header, not to be exported
+
+#pragma once
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array/builder_binary.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_builders.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/ubsan.h"
+
+#define XXH_INLINE_ALL
+
+#include "arrow/vendored/xxhash.h" // IWYU pragma: keep
+
+namespace arrow {
+namespace internal {
+
+// XXX would it help to have a 32-bit hash value on large datasets?
+typedef uint64_t hash_t;
+
+// Notes about the choice of a hash function.
+// - XXH3 is extremely fast on most data sizes, from small to huge;
+// faster even than HW CRC-based hashing schemes
+// - our custom hash function for tiny values (< 16 bytes) is still
+// significantly faster (~30%), at least on this machine and compiler
+
+template <uint64_t AlgNum>
+inline hash_t ComputeStringHash(const void* data, int64_t length);
+
+template <typename Scalar, uint64_t AlgNum>
+struct ScalarHelperBase {
+ static bool CompareScalars(Scalar u, Scalar v) { return u == v; }
+
+ static hash_t ComputeHash(const Scalar& value) {
+ // Generic hash computation for scalars. Simply apply the string hash
+ // to the bit representation of the value.
+
+ // XXX in the case of FP values, we'd like equal values to have the same hash,
+ // even if they have different bit representations...
+ return ComputeStringHash<AlgNum>(&value, sizeof(value));
+ }
+};
+
+template <typename Scalar, uint64_t AlgNum = 0, typename Enable = void>
+struct ScalarHelper : public ScalarHelperBase<Scalar, AlgNum> {};
+
+template <typename Scalar, uint64_t AlgNum>
+struct ScalarHelper<Scalar, AlgNum, enable_if_t<std::is_integral<Scalar>::value>>
+ : public ScalarHelperBase<Scalar, AlgNum> {
+ // ScalarHelper specialization for integers
+
+ static hash_t ComputeHash(const Scalar& value) {
+ // Faster hash computation for integers.
+
+ // Two of xxhash's prime multipliers (which are chosen for their
+ // bit dispersion properties)
+ static constexpr uint64_t multipliers[] = {11400714785074694791ULL,
+ 14029467366897019727ULL};
+
+ // Multiplying by the prime number mixes the low bits into the high bits,
+ // then byte-swapping (which is a single CPU instruction) allows the
+ // combined high and low bits to participate in the initial hash table index.
+ auto h = static_cast<hash_t>(value);
+ return BitUtil::ByteSwap(multipliers[AlgNum] * h);
+ }
+};
+
+template <typename Scalar, uint64_t AlgNum>
+struct ScalarHelper<Scalar, AlgNum,
+ enable_if_t<std::is_same<util::string_view, Scalar>::value>>
+ : public ScalarHelperBase<Scalar, AlgNum> {
+ // ScalarHelper specialization for util::string_view
+
+ static hash_t ComputeHash(const util::string_view& value) {
+ return ComputeStringHash<AlgNum>(value.data(), static_cast<int64_t>(value.size()));
+ }
+};
+
+template <typename Scalar, uint64_t AlgNum>
+struct ScalarHelper<Scalar, AlgNum, enable_if_t<std::is_floating_point<Scalar>::value>>
+ : public ScalarHelperBase<Scalar, AlgNum> {
+ // ScalarHelper specialization for reals
+
+ static bool CompareScalars(Scalar u, Scalar v) {
+ if (std::isnan(u)) {
+ // XXX should we do a bit-precise comparison?
+ return std::isnan(v);
+ }
+ return u == v;
+ }
+};
+
+template <uint64_t AlgNum = 0>
+hash_t ComputeStringHash(const void* data, int64_t length) {
+ if (ARROW_PREDICT_TRUE(length <= 16)) {
+ // Specialize for small hash strings, as they are quite common as
+ // hash table keys. Even XXH3 isn't quite as fast.
+ auto p = reinterpret_cast<const uint8_t*>(data);
+ auto n = static_cast<uint32_t>(length);
+ if (n <= 8) {
+ if (n <= 3) {
+ if (n == 0) {
+ return 1U;
+ }
+ uint32_t x = (n << 24) ^ (p[0] << 16) ^ (p[n / 2] << 8) ^ p[n - 1];
+ return ScalarHelper<uint32_t, AlgNum>::ComputeHash(x);
+ }
+ // 4 <= length <= 8
+ // We can read the string as two overlapping 32-bit ints, apply
+ // different hash functions to each of them in parallel, then XOR
+ // the results
+ uint32_t x, y;
+ hash_t hx, hy;
+ x = util::SafeLoadAs<uint32_t>(p + n - 4);
+ y = util::SafeLoadAs<uint32_t>(p);
+ hx = ScalarHelper<uint32_t, AlgNum>::ComputeHash(x);
+ hy = ScalarHelper<uint32_t, AlgNum ^ 1>::ComputeHash(y);
+ return n ^ hx ^ hy;
+ }
+ // 8 <= length <= 16
+ // Apply the same principle as above
+ uint64_t x, y;
+ hash_t hx, hy;
+ x = util::SafeLoadAs<uint64_t>(p + n - 8);
+ y = util::SafeLoadAs<uint64_t>(p);
+ hx = ScalarHelper<uint64_t, AlgNum>::ComputeHash(x);
+ hy = ScalarHelper<uint64_t, AlgNum ^ 1>::ComputeHash(y);
+ return n ^ hx ^ hy;
+ }
+
+#if XXH3_SECRET_SIZE_MIN != 136
+#error XXH3_SECRET_SIZE_MIN changed, please fix kXxh3Secrets
+#endif
+
+ // XXH3_64bits_withSeed generates a secret based on the seed, which is too slow.
+ // Instead, we use hard-coded random secrets. To maximize cache efficiency,
+ // they reuse the same memory area.
+ static constexpr unsigned char kXxh3Secrets[XXH3_SECRET_SIZE_MIN + 1] = {
+ 0xe7, 0x8b, 0x13, 0xf9, 0xfc, 0xb5, 0x8e, 0xef, 0x81, 0x48, 0x2c, 0xbf, 0xf9, 0x9f,
+ 0xc1, 0x1e, 0x43, 0x6d, 0xbf, 0xa6, 0x6d, 0xb5, 0x72, 0xbc, 0x97, 0xd8, 0x61, 0x24,
+ 0x0f, 0x12, 0xe3, 0x05, 0x21, 0xf7, 0x5c, 0x66, 0x67, 0xa5, 0x65, 0x03, 0x96, 0x26,
+ 0x69, 0xd8, 0x29, 0x20, 0xf8, 0xc7, 0xb0, 0x3d, 0xdd, 0x7d, 0x18, 0xa0, 0x60, 0x75,
+ 0x92, 0xa4, 0xce, 0xba, 0xc0, 0x77, 0xf4, 0xac, 0xb7, 0x03, 0x53, 0xf0, 0x98, 0xce,
+ 0xe6, 0x2b, 0x20, 0xc7, 0x82, 0x91, 0xab, 0xbf, 0x68, 0x5c, 0x62, 0x4d, 0x33, 0xa3,
+ 0xe1, 0xb3, 0xff, 0x97, 0x54, 0x4c, 0x44, 0x34, 0xb5, 0xb9, 0x32, 0x4c, 0x75, 0x42,
+ 0x89, 0x53, 0x94, 0xd4, 0x9f, 0x2b, 0x76, 0x4d, 0x4e, 0xe6, 0xfa, 0x15, 0x3e, 0xc1,
+ 0xdb, 0x71, 0x4b, 0x2c, 0x94, 0xf5, 0xfc, 0x8c, 0x89, 0x4b, 0xfb, 0xc1, 0x82, 0xa5,
+ 0x6a, 0x53, 0xf9, 0x4a, 0xba, 0xce, 0x1f, 0xc0, 0x97, 0x1a, 0x87};
+
+ static_assert(AlgNum < 2, "AlgNum too large");
+ static constexpr auto secret = kXxh3Secrets + AlgNum;
+ return XXH3_64bits_withSecret(data, static_cast<size_t>(length), secret,
+ XXH3_SECRET_SIZE_MIN);
+}
+
+// XXX add a HashEq<ArrowType> struct with both hash and compare functions?
+
+// ----------------------------------------------------------------------
+// An open-addressing insert-only hash table (no deletes)
+
+template <typename Payload>
+class HashTable {
+ public:
+ static constexpr hash_t kSentinel = 0ULL;
+ static constexpr int64_t kLoadFactor = 2UL;
+
+ struct Entry {
+ hash_t h;
+ Payload payload;
+
+ // An entry is valid if the hash is different from the sentinel value
+ operator bool() const { return h != kSentinel; }
+ };
+
+ HashTable(MemoryPool* pool, uint64_t capacity) : entries_builder_(pool) {
+ DCHECK_NE(pool, nullptr);
+ // Minimum of 32 elements
+ capacity = std::max<uint64_t>(capacity, 32UL);
+ capacity_ = BitUtil::NextPower2(capacity);
+ capacity_mask_ = capacity_ - 1;
+ size_ = 0;
+
+ DCHECK_OK(UpsizeBuffer(capacity_));
+ }
+
+ // Lookup with non-linear probing
+ // cmp_func should have signature bool(const Payload*).
+ // Return a (Entry*, found) pair.
+ template <typename CmpFunc>
+ std::pair<Entry*, bool> Lookup(hash_t h, CmpFunc&& cmp_func) {
+ auto p = Lookup<DoCompare, CmpFunc>(h, entries_, capacity_mask_,
+ std::forward<CmpFunc>(cmp_func));
+ return {&entries_[p.first], p.second};
+ }
+
+ template <typename CmpFunc>
+ std::pair<const Entry*, bool> Lookup(hash_t h, CmpFunc&& cmp_func) const {
+ auto p = Lookup<DoCompare, CmpFunc>(h, entries_, capacity_mask_,
+ std::forward<CmpFunc>(cmp_func));
+ return {&entries_[p.first], p.second};
+ }
+
+ Status Insert(Entry* entry, hash_t h, const Payload& payload) {
+ // Ensure entry is empty before inserting
+ assert(!*entry);
+ entry->h = FixHash(h);
+ entry->payload = payload;
+ ++size_;
+
+ if (ARROW_PREDICT_FALSE(NeedUpsizing())) {
+ // Resize less frequently since it is expensive
+ return Upsize(capacity_ * kLoadFactor * 2);
+ }
+ return Status::OK();
+ }
+
+ uint64_t size() const { return size_; }
+
+ // Visit all non-empty entries in the table
+ // The visit_func should have signature void(const Entry*)
+ template <typename VisitFunc>
+ void VisitEntries(VisitFunc&& visit_func) const {
+ for (uint64_t i = 0; i < capacity_; i++) {
+ const auto& entry = entries_[i];
+ if (entry) {
+ visit_func(&entry);
+ }
+ }
+ }
+
+ protected:
+ // NoCompare is for when the value is known not to exist in the table
+ enum CompareKind { DoCompare, NoCompare };
+
+ // The workhorse lookup function
+ template <CompareKind CKind, typename CmpFunc>
+ std::pair<uint64_t, bool> Lookup(hash_t h, const Entry* entries, uint64_t size_mask,
+ CmpFunc&& cmp_func) const {
+ static constexpr uint8_t perturb_shift = 5;
+
+ uint64_t index, perturb;
+ const Entry* entry;
+
+ h = FixHash(h);
+ index = h & size_mask;
+ perturb = (h >> perturb_shift) + 1U;
+
+ while (true) {
+ entry = &entries[index];
+ if (CompareEntry<CKind, CmpFunc>(h, entry, std::forward<CmpFunc>(cmp_func))) {
+ // Found
+ return {index, true};
+ }
+ if (entry->h == kSentinel) {
+ // Empty slot
+ return {index, false};
+ }
+
+ // Perturbation logic inspired from CPython's set / dict object.
+ // The goal is that all 64 bits of the unmasked hash value eventually
+ // participate in the probing sequence, to minimize clustering.
+ index = (index + perturb) & size_mask;
+ perturb = (perturb >> perturb_shift) + 1U;
+ }
+ }
+
+ template <CompareKind CKind, typename CmpFunc>
+ bool CompareEntry(hash_t h, const Entry* entry, CmpFunc&& cmp_func) const {
+ if (CKind == NoCompare) {
+ return false;
+ } else {
+ return entry->h == h && cmp_func(&entry->payload);
+ }
+ }
+
+ bool NeedUpsizing() const {
+ // Keep the load factor <= 1/2
+ return size_ * kLoadFactor >= capacity_;
+ }
+
+ Status UpsizeBuffer(uint64_t capacity) {
+ RETURN_NOT_OK(entries_builder_.Resize(capacity));
+ entries_ = entries_builder_.mutable_data();
+ memset(static_cast<void*>(entries_), 0, capacity * sizeof(Entry));
+
+ return Status::OK();
+ }
+
+ Status Upsize(uint64_t new_capacity) {
+ assert(new_capacity > capacity_);
+ uint64_t new_mask = new_capacity - 1;
+ assert((new_capacity & new_mask) == 0); // it's a power of two
+
+ // Stash old entries and seal builder, effectively resetting the Buffer
+ const Entry* old_entries = entries_;
+ ARROW_ASSIGN_OR_RAISE(auto previous, entries_builder_.FinishWithLength(capacity_));
+ // Allocate new buffer
+ RETURN_NOT_OK(UpsizeBuffer(new_capacity));
+
+ for (uint64_t i = 0; i < capacity_; i++) {
+ const auto& entry = old_entries[i];
+ if (entry) {
+ // Dummy compare function will not be called
+ auto p = Lookup<NoCompare>(entry.h, entries_, new_mask,
+ [](const Payload*) { return false; });
+ // Lookup<NoCompare> (and CompareEntry<NoCompare>) ensure that an
+ // empty slots is always returned
+ assert(!p.second);
+ entries_[p.first] = entry;
+ }
+ }
+ capacity_ = new_capacity;
+ capacity_mask_ = new_mask;
+
+ return Status::OK();
+ }
+
+ hash_t FixHash(hash_t h) const { return (h == kSentinel) ? 42U : h; }
+
+ // The number of slots available in the hash table array.
+ uint64_t capacity_;
+ uint64_t capacity_mask_;
+ // The number of used slots in the hash table array.
+ uint64_t size_;
+
+ Entry* entries_;
+ TypedBufferBuilder<Entry> entries_builder_;
+};
+
+// XXX typedef memo_index_t int32_t ?
+
+constexpr int32_t kKeyNotFound = -1;
+
+// ----------------------------------------------------------------------
+// A base class for memoization table.
+
+class MemoTable {
+ public:
+ virtual ~MemoTable() = default;
+
+ virtual int32_t size() const = 0;
+};
+
+// ----------------------------------------------------------------------
+// A memoization table for memory-cheap scalar values.
+
+// The memoization table remembers and allows to look up the insertion
+// index for each key.
+
+template <typename Scalar, template <class> class HashTableTemplateType = HashTable>
+class ScalarMemoTable : public MemoTable {
+ public:
+ explicit ScalarMemoTable(MemoryPool* pool, int64_t entries = 0)
+ : hash_table_(pool, static_cast<uint64_t>(entries)) {}
+
+ int32_t Get(const Scalar& value) const {
+ auto cmp_func = [value](const Payload* payload) -> bool {
+ return ScalarHelper<Scalar, 0>::CompareScalars(payload->value, value);
+ };
+ hash_t h = ComputeHash(value);
+ auto p = hash_table_.Lookup(h, cmp_func);
+ if (p.second) {
+ return p.first->payload.memo_index;
+ } else {
+ return kKeyNotFound;
+ }
+ }
+
+ template <typename Func1, typename Func2>
+ Status GetOrInsert(const Scalar& value, Func1&& on_found, Func2&& on_not_found,
+ int32_t* out_memo_index) {
+ auto cmp_func = [value](const Payload* payload) -> bool {
+ return ScalarHelper<Scalar, 0>::CompareScalars(value, payload->value);
+ };
+ hash_t h = ComputeHash(value);
+ auto p = hash_table_.Lookup(h, cmp_func);
+ int32_t memo_index;
+ if (p.second) {
+ memo_index = p.first->payload.memo_index;
+ on_found(memo_index);
+ } else {
+ memo_index = size();
+ RETURN_NOT_OK(hash_table_.Insert(p.first, h, {value, memo_index}));
+ on_not_found(memo_index);
+ }
+ *out_memo_index = memo_index;
+ return Status::OK();
+ }
+
+ Status GetOrInsert(const Scalar& value, int32_t* out_memo_index) {
+ return GetOrInsert(
+ value, [](int32_t i) {}, [](int32_t i) {}, out_memo_index);
+ }
+
+ int32_t GetNull() const { return null_index_; }
+
+ template <typename Func1, typename Func2>
+ int32_t GetOrInsertNull(Func1&& on_found, Func2&& on_not_found) {
+ int32_t memo_index = GetNull();
+ if (memo_index != kKeyNotFound) {
+ on_found(memo_index);
+ } else {
+ null_index_ = memo_index = size();
+ on_not_found(memo_index);
+ }
+ return memo_index;
+ }
+
+ int32_t GetOrInsertNull() {
+ return GetOrInsertNull([](int32_t i) {}, [](int32_t i) {});
+ }
+
+ // The number of entries in the memo table +1 if null was added.
+ // (which is also 1 + the largest memo index)
+ int32_t size() const override {
+ return static_cast<int32_t>(hash_table_.size()) + (GetNull() != kKeyNotFound);
+ }
+
+ // Copy values starting from index `start` into `out_data`
+ void CopyValues(int32_t start, Scalar* out_data) const {
+ hash_table_.VisitEntries([=](const HashTableEntry* entry) {
+ int32_t index = entry->payload.memo_index - start;
+ if (index >= 0) {
+ out_data[index] = entry->payload.value;
+ }
+ });
+ // Zero-initialize the null entry
+ if (null_index_ != kKeyNotFound) {
+ int32_t index = null_index_ - start;
+ if (index >= 0) {
+ out_data[index] = Scalar{};
+ }
+ }
+ }
+
+ void CopyValues(Scalar* out_data) const { CopyValues(0, out_data); }
+
+ protected:
+ struct Payload {
+ Scalar value;
+ int32_t memo_index;
+ };
+
+ using HashTableType = HashTableTemplateType<Payload>;
+ using HashTableEntry = typename HashTableType::Entry;
+ HashTableType hash_table_;
+ int32_t null_index_ = kKeyNotFound;
+
+ hash_t ComputeHash(const Scalar& value) const {
+ return ScalarHelper<Scalar, 0>::ComputeHash(value);
+ }
+};
+
+// ----------------------------------------------------------------------
+// A memoization table for small scalar values, using direct indexing
+
+template <typename Scalar, typename Enable = void>
+struct SmallScalarTraits {};
+
+template <>
+struct SmallScalarTraits<bool> {
+ static constexpr int32_t cardinality = 2;
+
+ static uint32_t AsIndex(bool value) { return value ? 1 : 0; }
+};
+
+template <typename Scalar>
+struct SmallScalarTraits<Scalar, enable_if_t<std::is_integral<Scalar>::value>> {
+ using Unsigned = typename std::make_unsigned<Scalar>::type;
+
+ static constexpr int32_t cardinality = 1U + std::numeric_limits<Unsigned>::max();
+
+ static uint32_t AsIndex(Scalar value) { return static_cast<Unsigned>(value); }
+};
+
+template <typename Scalar, template <class> class HashTableTemplateType = HashTable>
+class SmallScalarMemoTable : public MemoTable {
+ public:
+ explicit SmallScalarMemoTable(MemoryPool* pool, int64_t entries = 0) {
+ std::fill(value_to_index_, value_to_index_ + cardinality + 1, kKeyNotFound);
+ index_to_value_.reserve(cardinality);
+ }
+
+ int32_t Get(const Scalar value) const {
+ auto value_index = AsIndex(value);
+ return value_to_index_[value_index];
+ }
+
+ template <typename Func1, typename Func2>
+ Status GetOrInsert(const Scalar value, Func1&& on_found, Func2&& on_not_found,
+ int32_t* out_memo_index) {
+ auto value_index = AsIndex(value);
+ auto memo_index = value_to_index_[value_index];
+ if (memo_index == kKeyNotFound) {
+ memo_index = static_cast<int32_t>(index_to_value_.size());
+ index_to_value_.push_back(value);
+ value_to_index_[value_index] = memo_index;
+ DCHECK_LT(memo_index, cardinality + 1);
+ on_not_found(memo_index);
+ } else {
+ on_found(memo_index);
+ }
+ *out_memo_index = memo_index;
+ return Status::OK();
+ }
+
+ Status GetOrInsert(const Scalar value, int32_t* out_memo_index) {
+ return GetOrInsert(
+ value, [](int32_t i) {}, [](int32_t i) {}, out_memo_index);
+ }
+
+ int32_t GetNull() const { return value_to_index_[cardinality]; }
+
+ template <typename Func1, typename Func2>
+ int32_t GetOrInsertNull(Func1&& on_found, Func2&& on_not_found) {
+ auto memo_index = GetNull();
+ if (memo_index == kKeyNotFound) {
+ memo_index = value_to_index_[cardinality] = size();
+ index_to_value_.push_back(0);
+ on_not_found(memo_index);
+ } else {
+ on_found(memo_index);
+ }
+ return memo_index;
+ }
+
+ int32_t GetOrInsertNull() {
+ return GetOrInsertNull([](int32_t i) {}, [](int32_t i) {});
+ }
+
+ // The number of entries in the memo table
+ // (which is also 1 + the largest memo index)
+ int32_t size() const override { return static_cast<int32_t>(index_to_value_.size()); }
+
+ // Copy values starting from index `start` into `out_data`
+ void CopyValues(int32_t start, Scalar* out_data) const {
+ DCHECK_GE(start, 0);
+ DCHECK_LE(static_cast<size_t>(start), index_to_value_.size());
+ int64_t offset = start * static_cast<int32_t>(sizeof(Scalar));
+ memcpy(out_data, index_to_value_.data() + offset, (size() - start) * sizeof(Scalar));
+ }
+
+ void CopyValues(Scalar* out_data) const { CopyValues(0, out_data); }
+
+ const std::vector<Scalar>& values() const { return index_to_value_; }
+
+ protected:
+ static constexpr auto cardinality = SmallScalarTraits<Scalar>::cardinality;
+ static_assert(cardinality <= 256, "cardinality too large for direct-addressed table");
+
+ uint32_t AsIndex(Scalar value) const {
+ return SmallScalarTraits<Scalar>::AsIndex(value);
+ }
+
+ // The last index is reserved for the null element.
+ int32_t value_to_index_[cardinality + 1];
+ std::vector<Scalar> index_to_value_;
+};
+
+// ----------------------------------------------------------------------
+// A memoization table for variable-sized binary data.
+
+template <typename BinaryBuilderT>
+class BinaryMemoTable : public MemoTable {
+ public:
+ using builder_offset_type = typename BinaryBuilderT::offset_type;
+ explicit BinaryMemoTable(MemoryPool* pool, int64_t entries = 0,
+ int64_t values_size = -1)
+ : hash_table_(pool, static_cast<uint64_t>(entries)), binary_builder_(pool) {
+ const int64_t data_size = (values_size < 0) ? entries * 4 : values_size;
+ DCHECK_OK(binary_builder_.Resize(entries));
+ DCHECK_OK(binary_builder_.ReserveData(data_size));
+ }
+
+ int32_t Get(const void* data, builder_offset_type length) const {
+ hash_t h = ComputeStringHash<0>(data, length);
+ auto p = Lookup(h, data, length);
+ if (p.second) {
+ return p.first->payload.memo_index;
+ } else {
+ return kKeyNotFound;
+ }
+ }
+
+ int32_t Get(const util::string_view& value) const {
+ return Get(value.data(), static_cast<builder_offset_type>(value.length()));
+ }
+
+ template <typename Func1, typename Func2>
+ Status GetOrInsert(const void* data, builder_offset_type length, Func1&& on_found,
+ Func2&& on_not_found, int32_t* out_memo_index) {
+ hash_t h = ComputeStringHash<0>(data, length);
+ auto p = Lookup(h, data, length);
+ int32_t memo_index;
+ if (p.second) {
+ memo_index = p.first->payload.memo_index;
+ on_found(memo_index);
+ } else {
+ memo_index = size();
+ // Insert string value
+ RETURN_NOT_OK(binary_builder_.Append(static_cast<const char*>(data), length));
+ // Insert hash entry
+ RETURN_NOT_OK(
+ hash_table_.Insert(const_cast<HashTableEntry*>(p.first), h, {memo_index}));
+
+ on_not_found(memo_index);
+ }
+ *out_memo_index = memo_index;
+ return Status::OK();
+ }
+
+ template <typename Func1, typename Func2>
+ Status GetOrInsert(const util::string_view& value, Func1&& on_found,
+ Func2&& on_not_found, int32_t* out_memo_index) {
+ return GetOrInsert(value.data(), static_cast<builder_offset_type>(value.length()),
+ std::forward<Func1>(on_found), std::forward<Func2>(on_not_found),
+ out_memo_index);
+ }
+
+ Status GetOrInsert(const void* data, builder_offset_type length,
+ int32_t* out_memo_index) {
+ return GetOrInsert(
+ data, length, [](int32_t i) {}, [](int32_t i) {}, out_memo_index);
+ }
+
+ Status GetOrInsert(const util::string_view& value, int32_t* out_memo_index) {
+ return GetOrInsert(value.data(), static_cast<builder_offset_type>(value.length()),
+ out_memo_index);
+ }
+
+ int32_t GetNull() const { return null_index_; }
+
+ template <typename Func1, typename Func2>
+ int32_t GetOrInsertNull(Func1&& on_found, Func2&& on_not_found) {
+ int32_t memo_index = GetNull();
+ if (memo_index == kKeyNotFound) {
+ memo_index = null_index_ = size();
+ DCHECK_OK(binary_builder_.AppendNull());
+ on_not_found(memo_index);
+ } else {
+ on_found(memo_index);
+ }
+ return memo_index;
+ }
+
+ int32_t GetOrInsertNull() {
+ return GetOrInsertNull([](int32_t i) {}, [](int32_t i) {});
+ }
+
+ // The number of entries in the memo table
+ // (which is also 1 + the largest memo index)
+ int32_t size() const override {
+ return static_cast<int32_t>(hash_table_.size() + (GetNull() != kKeyNotFound));
+ }
+
+ int64_t values_size() const { return binary_builder_.value_data_length(); }
+
+ // Copy (n + 1) offsets starting from index `start` into `out_data`
+ template <class Offset>
+ void CopyOffsets(int32_t start, Offset* out_data) const {
+ DCHECK_LE(start, size());
+
+ const builder_offset_type* offsets = binary_builder_.offsets_data();
+ const builder_offset_type delta =
+ start < binary_builder_.length() ? offsets[start] : 0;
+ for (int32_t i = start; i < size(); ++i) {
+ const builder_offset_type adjusted_offset = offsets[i] - delta;
+ Offset cast_offset = static_cast<Offset>(adjusted_offset);
+ assert(static_cast<builder_offset_type>(cast_offset) ==
+ adjusted_offset); // avoid truncation
+ *out_data++ = cast_offset;
+ }
+
+ // Copy last value since BinaryBuilder only materializes it on in Finish()
+ *out_data = static_cast<Offset>(binary_builder_.value_data_length() - delta);
+ }
+
+ template <class Offset>
+ void CopyOffsets(Offset* out_data) const {
+ CopyOffsets(0, out_data);
+ }
+
+ // Copy values starting from index `start` into `out_data`
+ void CopyValues(int32_t start, uint8_t* out_data) const {
+ CopyValues(start, -1, out_data);
+ }
+
+ // Same as above, but check output size in debug mode
+ void CopyValues(int32_t start, int64_t out_size, uint8_t* out_data) const {
+ DCHECK_LE(start, size());
+
+ // The absolute byte offset of `start` value in the binary buffer.
+ const builder_offset_type offset = binary_builder_.offset(start);
+ const auto length = binary_builder_.value_data_length() - static_cast<size_t>(offset);
+
+ if (out_size != -1) {
+ assert(static_cast<int64_t>(length) <= out_size);
+ }
+
+ auto view = binary_builder_.GetView(start);
+ memcpy(out_data, view.data(), length);
+ }
+
+ void CopyValues(uint8_t* out_data) const { CopyValues(0, -1, out_data); }
+
+ void CopyValues(int64_t out_size, uint8_t* out_data) const {
+ CopyValues(0, out_size, out_data);
+ }
+
+ void CopyFixedWidthValues(int32_t start, int32_t width_size, int64_t out_size,
+ uint8_t* out_data) const {
+ // This method exists to cope with the fact that the BinaryMemoTable does
+ // not know the fixed width when inserting the null value. The data
+ // buffer hold a zero length string for the null value (if found).
+ //
+ // Thus, the method will properly inject an empty value of the proper width
+ // in the output buffer.
+ //
+ if (start >= size()) {
+ return;
+ }
+
+ int32_t null_index = GetNull();
+ if (null_index < start) {
+ // Nothing to skip, proceed as usual.
+ CopyValues(start, out_size, out_data);
+ return;
+ }
+
+ builder_offset_type left_offset = binary_builder_.offset(start);
+
+ // Ensure that the data length is exactly missing width_size bytes to fit
+ // in the expected output (n_values * width_size).
+#ifndef NDEBUG
+ int64_t data_length = values_size() - static_cast<size_t>(left_offset);
+ assert(data_length + width_size == out_size);
+ ARROW_UNUSED(data_length);
+#endif
+
+ auto in_data = binary_builder_.value_data() + left_offset;
+ // The null use 0-length in the data, slice the data in 2 and skip by
+ // width_size in out_data. [part_1][width_size][part_2]
+ auto null_data_offset = binary_builder_.offset(null_index);
+ auto left_size = null_data_offset - left_offset;
+ if (left_size > 0) {
+ memcpy(out_data, in_data + left_offset, left_size);
+ }
+ // Zero-initialize the null entry
+ memset(out_data + left_size, 0, width_size);
+
+ auto right_size = values_size() - static_cast<size_t>(null_data_offset);
+ if (right_size > 0) {
+ // skip the null fixed size value.
+ auto out_offset = left_size + width_size;
+ assert(out_data + out_offset + right_size == out_data + out_size);
+ memcpy(out_data + out_offset, in_data + null_data_offset, right_size);
+ }
+ }
+
+ // Visit the stored values in insertion order.
+ // The visitor function should have the signature `void(util::string_view)`
+ // or `void(const util::string_view&)`.
+ template <typename VisitFunc>
+ void VisitValues(int32_t start, VisitFunc&& visit) const {
+ for (int32_t i = start; i < size(); ++i) {
+ visit(binary_builder_.GetView(i));
+ }
+ }
+
+ protected:
+ struct Payload {
+ int32_t memo_index;
+ };
+
+ using HashTableType = HashTable<Payload>;
+ using HashTableEntry = typename HashTable<Payload>::Entry;
+ HashTableType hash_table_;
+ BinaryBuilderT binary_builder_;
+
+ int32_t null_index_ = kKeyNotFound;
+
+ std::pair<const HashTableEntry*, bool> Lookup(hash_t h, const void* data,
+ builder_offset_type length) const {
+ auto cmp_func = [=](const Payload* payload) {
+ util::string_view lhs = binary_builder_.GetView(payload->memo_index);
+ util::string_view rhs(static_cast<const char*>(data), length);
+ return lhs == rhs;
+ };
+ return hash_table_.Lookup(h, cmp_func);
+ }
+};
+
+template <typename T, typename Enable = void>
+struct HashTraits {};
+
+template <>
+struct HashTraits<BooleanType> {
+ using MemoTableType = SmallScalarMemoTable<bool>;
+};
+
+template <typename T>
+struct HashTraits<T, enable_if_8bit_int<T>> {
+ using c_type = typename T::c_type;
+ using MemoTableType = SmallScalarMemoTable<typename T::c_type>;
+};
+
+template <typename T>
+struct HashTraits<T, enable_if_t<has_c_type<T>::value && !is_8bit_int<T>::value>> {
+ using c_type = typename T::c_type;
+ using MemoTableType = ScalarMemoTable<c_type, HashTable>;
+};
+
+template <typename T>
+struct HashTraits<T, enable_if_t<has_string_view<T>::value &&
+ !std::is_base_of<LargeBinaryType, T>::value>> {
+ using MemoTableType = BinaryMemoTable<BinaryBuilder>;
+};
+
+template <typename T>
+struct HashTraits<T, enable_if_decimal<T>> {
+ using MemoTableType = BinaryMemoTable<BinaryBuilder>;
+};
+
+template <typename T>
+struct HashTraits<T, enable_if_t<std::is_base_of<LargeBinaryType, T>::value>> {
+ using MemoTableType = BinaryMemoTable<LargeBinaryBuilder>;
+};
+
+template <typename MemoTableType>
+static inline Status ComputeNullBitmap(MemoryPool* pool, const MemoTableType& memo_table,
+ int64_t start_offset, int64_t* null_count,
+ std::shared_ptr<Buffer>* null_bitmap) {
+ int64_t dict_length = static_cast<int64_t>(memo_table.size()) - start_offset;
+ int64_t null_index = memo_table.GetNull();
+
+ *null_count = 0;
+ *null_bitmap = nullptr;
+
+ if (null_index != kKeyNotFound && null_index >= start_offset) {
+ null_index -= start_offset;
+ *null_count = 1;
+ ARROW_ASSIGN_OR_RAISE(*null_bitmap,
+ internal::BitmapAllButOne(pool, dict_length, null_index));
+ }
+
+ return Status::OK();
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/hashing_benchmark.cc b/src/arrow/cpp/src/arrow/util/hashing_benchmark.cc
new file mode 100644
index 000000000..c7051d1a3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/hashing_benchmark.cc
@@ -0,0 +1,123 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/hashing.h"
+
+namespace arrow {
+namespace internal {
+
+template <class Integer>
+static std::vector<Integer> MakeIntegers(int32_t n_values) {
+ std::vector<Integer> values(n_values);
+
+ std::default_random_engine gen(42);
+ std::uniform_int_distribution<Integer> values_dist(0,
+ std::numeric_limits<Integer>::max());
+ std::generate(values.begin(), values.end(),
+ [&]() { return static_cast<Integer>(values_dist(gen)); });
+ return values;
+}
+
+static std::vector<std::string> MakeStrings(int32_t n_values, int32_t min_length,
+ int32_t max_length) {
+ std::default_random_engine gen(42);
+ std::vector<std::string> values(n_values);
+
+ // Generate strings between 2 and 20 bytes
+ std::uniform_int_distribution<int32_t> length_dist(min_length, max_length);
+ std::independent_bits_engine<std::default_random_engine, 8, uint16_t> bytes_gen(42);
+
+ std::generate(values.begin(), values.end(), [&]() {
+ auto length = length_dist(gen);
+ std::string s(length, 'X');
+ for (int32_t i = 0; i < length; ++i) {
+ s[i] = static_cast<uint8_t>(bytes_gen());
+ }
+ return s;
+ });
+ return values;
+}
+
+static void HashIntegers(benchmark::State& state) { // NOLINT non-const reference
+ const std::vector<int64_t> values = MakeIntegers<int64_t>(10000);
+
+ while (state.KeepRunning()) {
+ hash_t total = 0;
+ for (const int64_t v : values) {
+ total += ScalarHelper<int64_t, 0>::ComputeHash(v);
+ total += ScalarHelper<int64_t, 1>::ComputeHash(v);
+ }
+ benchmark::DoNotOptimize(total);
+ }
+ state.SetBytesProcessed(2 * state.iterations() * values.size() * sizeof(int64_t));
+ state.SetItemsProcessed(2 * state.iterations() * values.size());
+}
+
+static void BenchmarkStringHashing(benchmark::State& state, // NOLINT non-const reference
+ const std::vector<std::string>& values) {
+ uint64_t total_size = 0;
+ for (const std::string& v : values) {
+ total_size += v.size();
+ }
+
+ while (state.KeepRunning()) {
+ hash_t total = 0;
+ for (const std::string& v : values) {
+ total += ComputeStringHash<0>(v.data(), static_cast<int64_t>(v.size()));
+ total += ComputeStringHash<1>(v.data(), static_cast<int64_t>(v.size()));
+ }
+ benchmark::DoNotOptimize(total);
+ }
+ state.SetBytesProcessed(2 * state.iterations() * total_size);
+ state.SetItemsProcessed(2 * state.iterations() * values.size());
+}
+
+static void HashSmallStrings(benchmark::State& state) { // NOLINT non-const reference
+ const std::vector<std::string> values = MakeStrings(10000, 2, 20);
+ BenchmarkStringHashing(state, values);
+}
+
+static void HashMediumStrings(benchmark::State& state) { // NOLINT non-const reference
+ const std::vector<std::string> values = MakeStrings(10000, 20, 120);
+ BenchmarkStringHashing(state, values);
+}
+
+static void HashLargeStrings(benchmark::State& state) { // NOLINT non-const reference
+ const std::vector<std::string> values = MakeStrings(1000, 120, 2000);
+ BenchmarkStringHashing(state, values);
+}
+
+// ----------------------------------------------------------------------
+// Benchmark declarations
+
+BENCHMARK(HashIntegers);
+BENCHMARK(HashSmallStrings);
+BENCHMARK(HashMediumStrings);
+BENCHMARK(HashLargeStrings);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/hashing_test.cc b/src/arrow/cpp/src/arrow/util/hashing_test.cc
new file mode 100644
index 000000000..116e305e5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/hashing_test.cc
@@ -0,0 +1,490 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cmath>
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <random>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/hashing.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace internal {
+
+template <typename Integer>
+static std::unordered_set<Integer> MakeDistinctIntegers(int32_t n_values) {
+ std::default_random_engine gen(42);
+ std::uniform_int_distribution<Integer> values_dist(0,
+ std::numeric_limits<Integer>::max());
+
+ std::unordered_set<Integer> values;
+ values.reserve(n_values);
+
+ while (values.size() < static_cast<uint32_t>(n_values)) {
+ values.insert(static_cast<Integer>(values_dist(gen)));
+ }
+ return values;
+}
+
+template <typename Integer>
+static std::unordered_set<Integer> MakeSequentialIntegers(int32_t n_values) {
+ std::unordered_set<Integer> values;
+ values.reserve(n_values);
+
+ for (int32_t i = 0; i < n_values; ++i) {
+ values.insert(static_cast<Integer>(i));
+ }
+ DCHECK_EQ(values.size(), static_cast<uint32_t>(n_values));
+ return values;
+}
+
+static std::unordered_set<std::string> MakeDistinctStrings(int32_t n_values) {
+ std::unordered_set<std::string> values;
+ values.reserve(n_values);
+
+ // Generate strings between 0 and 24 bytes, with ASCII characters
+ std::default_random_engine gen(42);
+ std::uniform_int_distribution<int32_t> length_dist(0, 24);
+ std::uniform_int_distribution<uint32_t> char_dist('0', 'z');
+
+ while (values.size() < static_cast<uint32_t>(n_values)) {
+ auto length = length_dist(gen);
+ std::string s(length, 'X');
+ for (int32_t i = 0; i < length; ++i) {
+ s[i] = static_cast<uint8_t>(char_dist(gen));
+ }
+ values.insert(std::move(s));
+ }
+ return values;
+}
+
+template <typename T>
+static void CheckScalarHashQuality(const std::unordered_set<T>& distinct_values) {
+ std::unordered_set<hash_t> hashes;
+ for (const auto v : distinct_values) {
+ hashes.insert(ScalarHelper<T, 0>::ComputeHash(v));
+ hashes.insert(ScalarHelper<T, 1>::ComputeHash(v));
+ }
+ ASSERT_GE(static_cast<double>(hashes.size()),
+ 0.96 * static_cast<double>(2 * distinct_values.size()));
+}
+
+TEST(HashingQuality, Int64) {
+#ifdef ARROW_VALGRIND
+ const int32_t n_values = 500;
+#else
+ const int32_t n_values = 10000;
+#endif
+ {
+ const auto values = MakeDistinctIntegers<int64_t>(n_values);
+ CheckScalarHashQuality<int64_t>(values);
+ }
+ {
+ const auto values = MakeSequentialIntegers<int64_t>(n_values);
+ CheckScalarHashQuality<int64_t>(values);
+ }
+}
+
+TEST(HashingQuality, Strings) {
+#ifdef ARROW_VALGRIND
+ const int32_t n_values = 500;
+#else
+ const int32_t n_values = 10000;
+#endif
+ const auto values = MakeDistinctStrings(n_values);
+
+ std::unordered_set<hash_t> hashes;
+ for (const auto& v : values) {
+ hashes.insert(ComputeStringHash<0>(v.data(), static_cast<int64_t>(v.size())));
+ hashes.insert(ComputeStringHash<1>(v.data(), static_cast<int64_t>(v.size())));
+ }
+ ASSERT_GE(static_cast<double>(hashes.size()),
+ 0.96 * static_cast<double>(2 * values.size()));
+}
+
+TEST(HashingBounds, Strings) {
+ std::vector<size_t> sizes({1, 2, 3, 4, 5, 7, 8, 9, 15, 16, 17, 18, 19, 20, 21});
+ for (const auto s : sizes) {
+ std::string str;
+ for (size_t i = 0; i < s; i++) {
+ str.push_back(static_cast<char>(i));
+ }
+ hash_t h = ComputeStringHash<1>(str.c_str(), str.size());
+ int different = 0;
+ for (char i = 0; i < 120; i++) {
+ str[str.size() - 1] = i;
+ if (ComputeStringHash<1>(str.c_str(), str.size()) != h) {
+ different++;
+ }
+ }
+ ASSERT_GE(different, 118);
+ }
+}
+
+template <typename MemoTable, typename Value>
+void AssertGet(MemoTable& table, const Value& v, int32_t expected) {
+ ASSERT_EQ(table.Get(v), expected);
+}
+
+template <typename MemoTable, typename Value>
+void AssertGetOrInsert(MemoTable& table, const Value& v, int32_t expected) {
+ int32_t memo_index;
+ ASSERT_OK(table.GetOrInsert(v, &memo_index));
+ ASSERT_EQ(memo_index, expected);
+}
+
+template <typename MemoTable>
+void AssertGetNull(MemoTable& table, int32_t expected) {
+ ASSERT_EQ(table.GetNull(), expected);
+}
+
+template <typename MemoTable>
+void AssertGetOrInsertNull(MemoTable& table, int32_t expected) {
+ ASSERT_EQ(table.GetOrInsertNull(), expected);
+}
+
+TEST(ScalarMemoTable, Int64) {
+ const int64_t A = 1234, B = 0, C = -98765321, D = 12345678901234LL, E = -1, F = 1,
+ G = 9223372036854775807LL, H = -9223372036854775807LL - 1;
+
+ ScalarMemoTable<int64_t> table(default_memory_pool(), 0);
+ ASSERT_EQ(table.size(), 0);
+ AssertGet(table, A, kKeyNotFound);
+ AssertGetNull(table, kKeyNotFound);
+ AssertGetOrInsert(table, A, 0);
+ AssertGet(table, B, kKeyNotFound);
+ AssertGetOrInsert(table, B, 1);
+ AssertGetOrInsert(table, C, 2);
+ AssertGetOrInsert(table, D, 3);
+ AssertGetOrInsert(table, E, 4);
+ AssertGetOrInsertNull(table, 5);
+
+ AssertGet(table, A, 0);
+ AssertGetOrInsert(table, A, 0);
+ AssertGet(table, E, 4);
+ AssertGetOrInsert(table, E, 4);
+
+ AssertGetOrInsert(table, F, 6);
+ AssertGetOrInsert(table, G, 7);
+ AssertGetOrInsert(table, H, 8);
+
+ AssertGetOrInsert(table, G, 7);
+ AssertGetOrInsert(table, F, 6);
+ AssertGetOrInsertNull(table, 5);
+ AssertGetOrInsert(table, E, 4);
+ AssertGetOrInsert(table, D, 3);
+ AssertGetOrInsert(table, C, 2);
+ AssertGetOrInsert(table, B, 1);
+ AssertGetOrInsert(table, A, 0);
+
+ const int64_t size = 9;
+ ASSERT_EQ(table.size(), size);
+ {
+ std::vector<int64_t> values(size);
+ table.CopyValues(values.data());
+ EXPECT_THAT(values, testing::ElementsAre(A, B, C, D, E, 0, F, G, H));
+ }
+ {
+ const int32_t start_offset = 3;
+ std::vector<int64_t> values(size - start_offset);
+ table.CopyValues(start_offset, values.data());
+ EXPECT_THAT(values, testing::ElementsAre(D, E, 0, F, G, H));
+ }
+}
+
+TEST(ScalarMemoTable, UInt16) {
+ const uint16_t A = 1236, B = 0, C = 65535, D = 32767, E = 1;
+
+ ScalarMemoTable<uint16_t> table(default_memory_pool(), 0);
+ ASSERT_EQ(table.size(), 0);
+ AssertGet(table, A, kKeyNotFound);
+ AssertGetNull(table, kKeyNotFound);
+ AssertGetOrInsert(table, A, 0);
+ AssertGet(table, B, kKeyNotFound);
+ AssertGetOrInsert(table, B, 1);
+ AssertGetOrInsert(table, C, 2);
+ AssertGetOrInsert(table, D, 3);
+
+ {
+ EXPECT_EQ(table.size(), 4);
+ std::vector<uint16_t> values(table.size());
+ table.CopyValues(values.data());
+ EXPECT_THAT(values, testing::ElementsAre(A, B, C, D));
+ }
+
+ AssertGetOrInsertNull(table, 4);
+ AssertGetOrInsert(table, E, 5);
+
+ AssertGet(table, A, 0);
+ AssertGetOrInsert(table, A, 0);
+ AssertGetOrInsert(table, B, 1);
+ AssertGetOrInsert(table, C, 2);
+ AssertGetOrInsert(table, D, 3);
+ AssertGetNull(table, 4);
+ AssertGet(table, E, 5);
+ AssertGetOrInsert(table, E, 5);
+
+ ASSERT_EQ(table.size(), 6);
+ std::vector<uint16_t> values(table.size());
+ table.CopyValues(values.data());
+ EXPECT_THAT(values, testing::ElementsAre(A, B, C, D, 0, E));
+}
+
+TEST(SmallScalarMemoTable, Int8) {
+ const int8_t A = 1, B = 0, C = -1, D = -128, E = 127;
+
+ SmallScalarMemoTable<int8_t> table(default_memory_pool(), 0);
+ AssertGet(table, A, kKeyNotFound);
+ AssertGetNull(table, kKeyNotFound);
+ AssertGetOrInsert(table, A, 0);
+ AssertGet(table, B, kKeyNotFound);
+ AssertGetOrInsert(table, B, 1);
+ AssertGetOrInsert(table, C, 2);
+ AssertGetOrInsert(table, D, 3);
+ AssertGetOrInsert(table, E, 4);
+ AssertGetOrInsertNull(table, 5);
+
+ AssertGet(table, A, 0);
+ AssertGetOrInsert(table, A, 0);
+ AssertGetOrInsert(table, B, 1);
+ AssertGetOrInsert(table, C, 2);
+ AssertGetOrInsert(table, D, 3);
+ AssertGet(table, E, 4);
+ AssertGetOrInsert(table, E, 4);
+ AssertGetNull(table, 5);
+ AssertGetOrInsertNull(table, 5);
+
+ ASSERT_EQ(table.size(), 6);
+ std::vector<int8_t> values(table.size());
+ table.CopyValues(values.data());
+ EXPECT_THAT(values, testing::ElementsAre(A, B, C, D, E, 0));
+}
+
+TEST(SmallScalarMemoTable, Bool) {
+ SmallScalarMemoTable<bool> table(default_memory_pool(), 0);
+ ASSERT_EQ(table.size(), 0);
+ AssertGet(table, true, kKeyNotFound);
+ AssertGetOrInsert(table, true, 0);
+ AssertGetOrInsertNull(table, 1);
+ AssertGetOrInsert(table, false, 2);
+
+ AssertGet(table, true, 0);
+ AssertGetOrInsert(table, true, 0);
+ AssertGetNull(table, 1);
+ AssertGetOrInsertNull(table, 1);
+ AssertGet(table, false, 2);
+ AssertGetOrInsert(table, false, 2);
+
+ ASSERT_EQ(table.size(), 3);
+ EXPECT_THAT(table.values(), testing::ElementsAre(true, 0, false));
+ // NOTE std::vector<bool> doesn't have a data() method
+}
+
+TEST(ScalarMemoTable, Float64) {
+ const double A = 0.0, B = 1.5, C = -0.0, D = std::numeric_limits<double>::infinity(),
+ E = -D, F = std::nan("");
+
+ ScalarMemoTable<double> table(default_memory_pool(), 0);
+ ASSERT_EQ(table.size(), 0);
+ AssertGet(table, A, kKeyNotFound);
+ AssertGetNull(table, kKeyNotFound);
+ AssertGetOrInsert(table, A, 0);
+ AssertGet(table, B, kKeyNotFound);
+ AssertGetOrInsert(table, B, 1);
+ AssertGetOrInsert(table, C, 2);
+ AssertGetOrInsert(table, D, 3);
+ AssertGetOrInsert(table, E, 4);
+ AssertGetOrInsert(table, F, 5);
+
+ AssertGet(table, A, 0);
+ AssertGetOrInsert(table, A, 0);
+ AssertGetOrInsert(table, B, 1);
+ AssertGetOrInsert(table, C, 2);
+ AssertGetOrInsert(table, D, 3);
+ AssertGet(table, E, 4);
+ AssertGetOrInsert(table, E, 4);
+ AssertGet(table, F, 5);
+ AssertGetOrInsert(table, F, 5);
+
+ ASSERT_EQ(table.size(), 6);
+ std::vector<double> expected({A, B, C, D, E, F});
+ std::vector<double> values(table.size());
+ table.CopyValues(values.data());
+ for (uint32_t i = 0; i < expected.size(); ++i) {
+ auto u = expected[i];
+ auto v = values[i];
+ if (std::isnan(u)) {
+ ASSERT_TRUE(std::isnan(v));
+ } else {
+ ASSERT_EQ(u, v);
+ }
+ }
+}
+
+TEST(ScalarMemoTable, StressInt64) {
+ std::default_random_engine gen(42);
+ std::uniform_int_distribution<int64_t> value_dist(-50, 50);
+#ifdef ARROW_VALGRIND
+ const int32_t n_repeats = 500;
+#else
+ const int32_t n_repeats = 10000;
+#endif
+
+ ScalarMemoTable<int64_t> table(default_memory_pool(), 0);
+ std::unordered_map<int64_t, int32_t> map;
+
+ for (int32_t i = 0; i < n_repeats; ++i) {
+ int64_t value = value_dist(gen);
+ int32_t expected, actual;
+ auto it = map.find(value);
+ if (it == map.end()) {
+ expected = static_cast<int32_t>(map.size());
+ map[value] = expected;
+ } else {
+ expected = it->second;
+ }
+ ASSERT_OK(table.GetOrInsert(value, &actual));
+ ASSERT_EQ(actual, expected);
+ }
+ ASSERT_EQ(table.size(), map.size());
+}
+
+TEST(BinaryMemoTable, Basics) {
+ std::string A = "", B = "a", C = "foo", D = "bar", E, F;
+ E += '\0';
+ F += '\0';
+ F += "trailing";
+
+ BinaryMemoTable<BinaryBuilder> table(default_memory_pool(), 0);
+ ASSERT_EQ(table.size(), 0);
+ AssertGet(table, A, kKeyNotFound);
+ AssertGetNull(table, kKeyNotFound);
+ AssertGetOrInsert(table, A, 0);
+ AssertGet(table, B, kKeyNotFound);
+ AssertGetOrInsert(table, B, 1);
+ AssertGetOrInsert(table, C, 2);
+ AssertGetOrInsert(table, D, 3);
+ AssertGetOrInsert(table, E, 4);
+ AssertGetOrInsert(table, F, 5);
+ AssertGetOrInsertNull(table, 6);
+
+ AssertGet(table, A, 0);
+ AssertGetOrInsert(table, A, 0);
+ AssertGet(table, B, 1);
+ AssertGetOrInsert(table, B, 1);
+ AssertGetOrInsert(table, C, 2);
+ AssertGetOrInsert(table, D, 3);
+ AssertGetOrInsert(table, E, 4);
+ AssertGet(table, F, 5);
+ AssertGetOrInsert(table, F, 5);
+ AssertGetNull(table, 6);
+ AssertGetOrInsertNull(table, 6);
+
+ ASSERT_EQ(table.size(), 7);
+ ASSERT_EQ(table.values_size(), 17);
+
+ const int32_t size = table.size();
+ {
+ std::vector<int8_t> offsets(size + 1);
+ table.CopyOffsets(offsets.data());
+ EXPECT_THAT(offsets, testing::ElementsAre(0, 0, 1, 4, 7, 8, 17, 17));
+
+ std::string expected_values;
+ expected_values += "afoobar";
+ expected_values += '\0';
+ expected_values += '\0';
+ expected_values += "trailing";
+ std::string values(17, 'X');
+ table.CopyValues(reinterpret_cast<uint8_t*>(&values[0]));
+ ASSERT_EQ(values, expected_values);
+ }
+ {
+ const int32_t start_offset = 4;
+ std::vector<int8_t> offsets(size + 1 - start_offset);
+ table.CopyOffsets(start_offset, offsets.data());
+ EXPECT_THAT(offsets, testing::ElementsAre(0, 1, 10, 10));
+
+ std::string expected_values;
+ expected_values += '\0';
+ expected_values += '\0';
+ expected_values += "trailing";
+ std::string values(10, 'X');
+ table.CopyValues(4 /* start offset */, reinterpret_cast<uint8_t*>(&values[0]));
+ ASSERT_EQ(values, expected_values);
+ }
+ {
+ const int32_t start_offset = 1;
+ std::vector<std::string> actual;
+ table.VisitValues(start_offset, [&](const util::string_view& v) {
+ actual.emplace_back(v.data(), v.length());
+ });
+ EXPECT_THAT(actual, testing::ElementsAre(B, C, D, E, F, ""));
+ }
+}
+
+TEST(BinaryMemoTable, Stress) {
+#ifdef ARROW_VALGRIND
+ const int32_t n_values = 20;
+ const int32_t n_repeats = 20;
+#else
+ const int32_t n_values = 100;
+ const int32_t n_repeats = 100;
+#endif
+
+ const auto values = MakeDistinctStrings(n_values);
+
+ BinaryMemoTable<BinaryBuilder> table(default_memory_pool(), 0);
+ std::unordered_map<std::string, int32_t> map;
+
+ for (int32_t i = 0; i < n_repeats; ++i) {
+ for (const auto& value : values) {
+ int32_t expected, actual;
+ auto it = map.find(value);
+ if (it == map.end()) {
+ expected = static_cast<int32_t>(map.size());
+ map[value] = expected;
+ } else {
+ expected = it->second;
+ }
+ ASSERT_OK(table.GetOrInsert(value, &actual));
+ ASSERT_EQ(actual, expected);
+ }
+ }
+ ASSERT_EQ(table.size(), map.size());
+}
+
+TEST(BinaryMemoTable, Empty) {
+ BinaryMemoTable<BinaryBuilder> table(default_memory_pool());
+ ASSERT_EQ(table.size(), 0);
+ BinaryMemoTable<BinaryBuilder>::builder_offset_type offsets[1];
+ table.CopyOffsets(0, offsets);
+ EXPECT_EQ(offsets[0], 0);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/int128_internal.h b/src/arrow/cpp/src/arrow/util/int128_internal.h
new file mode 100644
index 000000000..1d494671a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/int128_internal.h
@@ -0,0 +1,45 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#pragma once
+
+#include "arrow/util/config.h"
+#include "arrow/util/macros.h"
+
+#ifndef ARROW_USE_NATIVE_INT128
+#include <boost/multiprecision/cpp_int.hpp>
+#endif
+
+namespace arrow {
+namespace internal {
+
+// NOTE: __int128_t and boost::multiprecision::int128_t are not interchangeable.
+// For example, __int128_t does not have any member function, and does not have
+// operator<<(std::ostream, __int128_t). On the other hand, the behavior of
+// boost::multiprecision::int128_t might be surprising with some configs (e.g.,
+// static_cast<uint64_t>(boost::multiprecision::uint128_t) might return
+// ~uint64_t{0} instead of the lower 64 bits of the input).
+// Try to minimize the usage of int128_t and uint128_t.
+#ifdef ARROW_USE_NATIVE_INT128
+using int128_t = __int128_t;
+using uint128_t = __uint128_t;
+#else
+using boost::multiprecision::int128_t;
+using boost::multiprecision::uint128_t;
+#endif
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/int_util.cc b/src/arrow/cpp/src/arrow/util/int_util.cc
new file mode 100644
index 000000000..24c5fe56e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/int_util.cc
@@ -0,0 +1,952 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/int_util.h"
+
+#include <algorithm>
+#include <cstring>
+#include <limits>
+
+#include "arrow/array/data.h"
+#include "arrow/datum.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/ubsan.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+namespace internal {
+
+using internal::checked_cast;
+
+static constexpr uint64_t max_uint8 =
+ static_cast<uint64_t>(std::numeric_limits<uint8_t>::max());
+static constexpr uint64_t max_uint16 =
+ static_cast<uint64_t>(std::numeric_limits<uint16_t>::max());
+static constexpr uint64_t max_uint32 =
+ static_cast<uint64_t>(std::numeric_limits<uint32_t>::max());
+static constexpr uint64_t max_uint64 = std::numeric_limits<uint64_t>::max();
+
+static constexpr uint64_t mask_uint8 = ~0xffULL;
+static constexpr uint64_t mask_uint16 = ~0xffffULL;
+static constexpr uint64_t mask_uint32 = ~0xffffffffULL;
+
+//
+// Unsigned integer width detection
+//
+
+static const uint64_t max_uints[] = {0, max_uint8, max_uint16, 0, max_uint32,
+ 0, 0, 0, max_uint64};
+
+// Check if we would need to expand the underlying storage type
+static inline uint8_t ExpandedUIntWidth(uint64_t val, uint8_t current_width) {
+ // Optimize for the common case where width doesn't change
+ if (ARROW_PREDICT_TRUE(val <= max_uints[current_width])) {
+ return current_width;
+ }
+ if (current_width == 1 && val <= max_uint8) {
+ return 1;
+ } else if (current_width <= 2 && val <= max_uint16) {
+ return 2;
+ } else if (current_width <= 4 && val <= max_uint32) {
+ return 4;
+ } else {
+ return 8;
+ }
+}
+
+uint8_t DetectUIntWidth(const uint64_t* values, int64_t length, uint8_t min_width) {
+ uint8_t width = min_width;
+ if (min_width < 8) {
+ auto p = values;
+ const auto end = p + length;
+ while (p <= end - 16) {
+ // This is probably SIMD-izable
+ auto u = p[0];
+ auto v = p[1];
+ auto w = p[2];
+ auto x = p[3];
+ u |= p[4];
+ v |= p[5];
+ w |= p[6];
+ x |= p[7];
+ u |= p[8];
+ v |= p[9];
+ w |= p[10];
+ x |= p[11];
+ u |= p[12];
+ v |= p[13];
+ w |= p[14];
+ x |= p[15];
+ p += 16;
+ width = ExpandedUIntWidth(u | v | w | x, width);
+ if (ARROW_PREDICT_FALSE(width == 8)) {
+ break;
+ }
+ }
+ if (p <= end - 8) {
+ auto u = p[0];
+ auto v = p[1];
+ auto w = p[2];
+ auto x = p[3];
+ u |= p[4];
+ v |= p[5];
+ w |= p[6];
+ x |= p[7];
+ p += 8;
+ width = ExpandedUIntWidth(u | v | w | x, width);
+ }
+ while (p < end) {
+ width = ExpandedUIntWidth(*p++, width);
+ }
+ }
+ return width;
+}
+
+uint8_t DetectUIntWidth(const uint64_t* values, const uint8_t* valid_bytes,
+ int64_t length, uint8_t min_width) {
+ if (valid_bytes == nullptr) {
+ return DetectUIntWidth(values, length, min_width);
+ }
+ uint8_t width = min_width;
+ if (min_width < 8) {
+ auto p = values;
+ const auto end = p + length;
+ auto b = valid_bytes;
+
+#define MASK(p, b, i) p[i] * (b[i] != 0)
+
+ while (p <= end - 8) {
+ // This is probably be SIMD-izable
+ auto u = MASK(p, b, 0);
+ auto v = MASK(p, b, 1);
+ auto w = MASK(p, b, 2);
+ auto x = MASK(p, b, 3);
+ u |= MASK(p, b, 4);
+ v |= MASK(p, b, 5);
+ w |= MASK(p, b, 6);
+ x |= MASK(p, b, 7);
+ b += 8;
+ p += 8;
+ width = ExpandedUIntWidth(u | v | w | x, width);
+ if (ARROW_PREDICT_FALSE(width == 8)) {
+ break;
+ }
+ }
+ uint64_t mask = 0;
+ while (p < end) {
+ mask |= MASK(p, b, 0);
+ ++b;
+ ++p;
+ }
+ width = ExpandedUIntWidth(mask, width);
+
+#undef MASK
+ }
+ return width;
+}
+
+//
+// Signed integer width detection
+//
+
+uint8_t DetectIntWidth(const int64_t* values, int64_t length, uint8_t min_width) {
+ if (min_width == 8) {
+ return min_width;
+ }
+ uint8_t width = min_width;
+
+ auto p = values;
+ const auto end = p + length;
+ // Strategy: to determine whether `x` is between -0x80 and 0x7f,
+ // we determine whether `x + 0x80` is between 0x00 and 0xff. The
+ // latter can be done with a simple AND mask with ~0xff and, more
+ // importantly, can be computed in a single step over multiple ORed
+ // values (so we can branch once every N items instead of once every item).
+ // This strategy could probably lend itself to explicit SIMD-ization,
+ // if more performance is needed.
+ constexpr uint64_t addend8 = 0x80ULL;
+ constexpr uint64_t addend16 = 0x8000ULL;
+ constexpr uint64_t addend32 = 0x80000000ULL;
+
+ auto test_one_item = [&](uint64_t addend, uint64_t test_mask) -> bool {
+ auto v = *p++;
+ if (ARROW_PREDICT_FALSE(((v + addend) & test_mask) != 0)) {
+ --p;
+ return false;
+ } else {
+ return true;
+ }
+ };
+
+ auto test_four_items = [&](uint64_t addend, uint64_t test_mask) -> bool {
+ auto mask = (p[0] + addend) | (p[1] + addend) | (p[2] + addend) | (p[3] + addend);
+ p += 4;
+ if (ARROW_PREDICT_FALSE((mask & test_mask) != 0)) {
+ p -= 4;
+ return false;
+ } else {
+ return true;
+ }
+ };
+
+ if (width == 1) {
+ while (p <= end - 4) {
+ if (!test_four_items(addend8, mask_uint8)) {
+ width = 2;
+ goto width2;
+ }
+ }
+ while (p < end) {
+ if (!test_one_item(addend8, mask_uint8)) {
+ width = 2;
+ goto width2;
+ }
+ }
+ return 1;
+ }
+width2:
+ if (width == 2) {
+ while (p <= end - 4) {
+ if (!test_four_items(addend16, mask_uint16)) {
+ width = 4;
+ goto width4;
+ }
+ }
+ while (p < end) {
+ if (!test_one_item(addend16, mask_uint16)) {
+ width = 4;
+ goto width4;
+ }
+ }
+ return 2;
+ }
+width4:
+ if (width == 4) {
+ while (p <= end - 4) {
+ if (!test_four_items(addend32, mask_uint32)) {
+ width = 8;
+ goto width8;
+ }
+ }
+ while (p < end) {
+ if (!test_one_item(addend32, mask_uint32)) {
+ width = 8;
+ goto width8;
+ }
+ }
+ return 4;
+ }
+width8:
+ return 8;
+}
+
+uint8_t DetectIntWidth(const int64_t* values, const uint8_t* valid_bytes, int64_t length,
+ uint8_t min_width) {
+ if (valid_bytes == nullptr) {
+ return DetectIntWidth(values, length, min_width);
+ }
+
+ if (min_width == 8) {
+ return min_width;
+ }
+ uint8_t width = min_width;
+
+ auto p = values;
+ const auto end = p + length;
+ auto b = valid_bytes;
+ // Strategy is similar to the no-nulls case above, but we also
+ // have to zero any incoming items that have a zero validity byte.
+ constexpr uint64_t addend8 = 0x80ULL;
+ constexpr uint64_t addend16 = 0x8000ULL;
+ constexpr uint64_t addend32 = 0x80000000ULL;
+
+#define MASK(p, b, addend, i) (p[i] + addend) * (b[i] != 0)
+
+ auto test_one_item = [&](uint64_t addend, uint64_t test_mask) -> bool {
+ auto v = MASK(p, b, addend, 0);
+ ++b;
+ ++p;
+ if (ARROW_PREDICT_FALSE((v & test_mask) != 0)) {
+ --b;
+ --p;
+ return false;
+ } else {
+ return true;
+ }
+ };
+
+ auto test_eight_items = [&](uint64_t addend, uint64_t test_mask) -> bool {
+ auto mask1 = MASK(p, b, addend, 0) | MASK(p, b, addend, 1) | MASK(p, b, addend, 2) |
+ MASK(p, b, addend, 3);
+ auto mask2 = MASK(p, b, addend, 4) | MASK(p, b, addend, 5) | MASK(p, b, addend, 6) |
+ MASK(p, b, addend, 7);
+ b += 8;
+ p += 8;
+ if (ARROW_PREDICT_FALSE(((mask1 | mask2) & test_mask) != 0)) {
+ b -= 8;
+ p -= 8;
+ return false;
+ } else {
+ return true;
+ }
+ };
+
+#undef MASK
+
+ if (width == 1) {
+ while (p <= end - 8) {
+ if (!test_eight_items(addend8, mask_uint8)) {
+ width = 2;
+ goto width2;
+ }
+ }
+ while (p < end) {
+ if (!test_one_item(addend8, mask_uint8)) {
+ width = 2;
+ goto width2;
+ }
+ }
+ return 1;
+ }
+width2:
+ if (width == 2) {
+ while (p <= end - 8) {
+ if (!test_eight_items(addend16, mask_uint16)) {
+ width = 4;
+ goto width4;
+ }
+ }
+ while (p < end) {
+ if (!test_one_item(addend16, mask_uint16)) {
+ width = 4;
+ goto width4;
+ }
+ }
+ return 2;
+ }
+width4:
+ if (width == 4) {
+ while (p <= end - 8) {
+ if (!test_eight_items(addend32, mask_uint32)) {
+ width = 8;
+ goto width8;
+ }
+ }
+ while (p < end) {
+ if (!test_one_item(addend32, mask_uint32)) {
+ width = 8;
+ goto width8;
+ }
+ }
+ return 4;
+ }
+width8:
+ return 8;
+}
+
+template <typename Source, typename Dest>
+static inline void CastIntsInternal(const Source* src, Dest* dest, int64_t length) {
+ while (length >= 4) {
+ dest[0] = static_cast<Dest>(src[0]);
+ dest[1] = static_cast<Dest>(src[1]);
+ dest[2] = static_cast<Dest>(src[2]);
+ dest[3] = static_cast<Dest>(src[3]);
+ length -= 4;
+ src += 4;
+ dest += 4;
+ }
+ while (length > 0) {
+ *dest++ = static_cast<Dest>(*src++);
+ --length;
+ }
+}
+
+void DowncastInts(const int64_t* source, int8_t* dest, int64_t length) {
+ CastIntsInternal(source, dest, length);
+}
+
+void DowncastInts(const int64_t* source, int16_t* dest, int64_t length) {
+ CastIntsInternal(source, dest, length);
+}
+
+void DowncastInts(const int64_t* source, int32_t* dest, int64_t length) {
+ CastIntsInternal(source, dest, length);
+}
+
+void DowncastInts(const int64_t* source, int64_t* dest, int64_t length) {
+ memcpy(dest, source, length * sizeof(int64_t));
+}
+
+void DowncastUInts(const uint64_t* source, uint8_t* dest, int64_t length) {
+ CastIntsInternal(source, dest, length);
+}
+
+void DowncastUInts(const uint64_t* source, uint16_t* dest, int64_t length) {
+ CastIntsInternal(source, dest, length);
+}
+
+void DowncastUInts(const uint64_t* source, uint32_t* dest, int64_t length) {
+ CastIntsInternal(source, dest, length);
+}
+
+void DowncastUInts(const uint64_t* source, uint64_t* dest, int64_t length) {
+ memcpy(dest, source, length * sizeof(int64_t));
+}
+
+void UpcastInts(const int32_t* source, int64_t* dest, int64_t length) {
+ CastIntsInternal(source, dest, length);
+}
+
+template <typename InputInt, typename OutputInt>
+void TransposeInts(const InputInt* src, OutputInt* dest, int64_t length,
+ const int32_t* transpose_map) {
+ while (length >= 4) {
+ dest[0] = static_cast<OutputInt>(transpose_map[src[0]]);
+ dest[1] = static_cast<OutputInt>(transpose_map[src[1]]);
+ dest[2] = static_cast<OutputInt>(transpose_map[src[2]]);
+ dest[3] = static_cast<OutputInt>(transpose_map[src[3]]);
+ length -= 4;
+ src += 4;
+ dest += 4;
+ }
+ while (length > 0) {
+ *dest++ = static_cast<OutputInt>(transpose_map[*src++]);
+ --length;
+ }
+}
+
+#define INSTANTIATE(SRC, DEST) \
+ template ARROW_EXPORT void TransposeInts( \
+ const SRC* source, DEST* dest, int64_t length, const int32_t* transpose_map);
+
+#define INSTANTIATE_ALL_DEST(DEST) \
+ INSTANTIATE(uint8_t, DEST) \
+ INSTANTIATE(int8_t, DEST) \
+ INSTANTIATE(uint16_t, DEST) \
+ INSTANTIATE(int16_t, DEST) \
+ INSTANTIATE(uint32_t, DEST) \
+ INSTANTIATE(int32_t, DEST) \
+ INSTANTIATE(uint64_t, DEST) \
+ INSTANTIATE(int64_t, DEST)
+
+#define INSTANTIATE_ALL() \
+ INSTANTIATE_ALL_DEST(uint8_t) \
+ INSTANTIATE_ALL_DEST(int8_t) \
+ INSTANTIATE_ALL_DEST(uint16_t) \
+ INSTANTIATE_ALL_DEST(int16_t) \
+ INSTANTIATE_ALL_DEST(uint32_t) \
+ INSTANTIATE_ALL_DEST(int32_t) \
+ INSTANTIATE_ALL_DEST(uint64_t) \
+ INSTANTIATE_ALL_DEST(int64_t)
+
+INSTANTIATE_ALL()
+
+#undef INSTANTIATE
+#undef INSTANTIATE_ALL
+#undef INSTANTIATE_ALL_DEST
+
+namespace {
+
+template <typename SrcType>
+struct TransposeIntsDest {
+ const SrcType* src;
+ uint8_t* dest;
+ int64_t dest_offset;
+ int64_t length;
+ const int32_t* transpose_map;
+
+ template <typename T>
+ enable_if_integer<T, Status> Visit(const T&) {
+ using DestType = typename T::c_type;
+ TransposeInts(src, reinterpret_cast<DestType*>(dest) + dest_offset, length,
+ transpose_map);
+ return Status::OK();
+ }
+
+ Status Visit(const DataType& type) {
+ return Status::TypeError("TransposeInts received non-integer dest_type");
+ }
+
+ Status operator()(const DataType& type) { return VisitTypeInline(type, this); }
+};
+
+struct TransposeIntsSrc {
+ const uint8_t* src;
+ uint8_t* dest;
+ int64_t src_offset;
+ int64_t dest_offset;
+ int64_t length;
+ const int32_t* transpose_map;
+ const DataType& dest_type;
+
+ template <typename T>
+ enable_if_integer<T, Status> Visit(const T&) {
+ using SrcType = typename T::c_type;
+ return TransposeIntsDest<SrcType>{reinterpret_cast<const SrcType*>(src) + src_offset,
+ dest, dest_offset, length,
+ transpose_map}(dest_type);
+ }
+
+ Status Visit(const DataType& type) {
+ return Status::TypeError("TransposeInts received non-integer dest_type");
+ }
+
+ Status operator()(const DataType& type) { return VisitTypeInline(type, this); }
+};
+
+}; // namespace
+
+Status TransposeInts(const DataType& src_type, const DataType& dest_type,
+ const uint8_t* src, uint8_t* dest, int64_t src_offset,
+ int64_t dest_offset, int64_t length, const int32_t* transpose_map) {
+ TransposeIntsSrc transposer{src, dest, src_offset, dest_offset,
+ length, transpose_map, dest_type};
+ return transposer(src_type);
+}
+
+template <typename T>
+static std::string FormatInt(T val) {
+ return std::to_string(val);
+}
+
+template <typename IndexCType, bool IsSigned = std::is_signed<IndexCType>::value>
+static Status CheckIndexBoundsImpl(const ArrayData& indices, uint64_t upper_limit) {
+ // For unsigned integers, if the values array is larger than the maximum
+ // index value (e.g. especially for UINT8 / UINT16), then there is no need to
+ // boundscheck.
+ if (!IsSigned &&
+ upper_limit > static_cast<uint64_t>(std::numeric_limits<IndexCType>::max())) {
+ return Status::OK();
+ }
+
+ const IndexCType* indices_data = indices.GetValues<IndexCType>(1);
+ const uint8_t* bitmap = nullptr;
+ if (indices.buffers[0]) {
+ bitmap = indices.buffers[0]->data();
+ }
+ auto IsOutOfBounds = [&](IndexCType val) -> bool {
+ return ((IsSigned && val < 0) ||
+ (val >= 0 && static_cast<uint64_t>(val) >= upper_limit));
+ };
+ return VisitSetBitRuns(
+ bitmap, indices.offset, indices.length, [&](int64_t offset, int64_t length) {
+ bool block_out_of_bounds = false;
+ for (int64_t i = 0; i < length; ++i) {
+ block_out_of_bounds |= IsOutOfBounds(indices_data[offset + i]);
+ }
+ if (ARROW_PREDICT_FALSE(block_out_of_bounds)) {
+ for (int64_t i = 0; i < length; ++i) {
+ if (IsOutOfBounds(indices_data[offset + i])) {
+ return Status::IndexError("Index ", FormatInt(indices_data[offset + i]),
+ " out of bounds");
+ }
+ }
+ }
+ return Status::OK();
+ });
+}
+
+/// \brief Branchless boundschecking of the indices. Processes batches of
+/// indices at a time and shortcircuits when encountering an out-of-bounds
+/// index in a batch
+Status CheckIndexBounds(const ArrayData& indices, uint64_t upper_limit) {
+ switch (indices.type->id()) {
+ case Type::INT8:
+ return CheckIndexBoundsImpl<int8_t>(indices, upper_limit);
+ case Type::INT16:
+ return CheckIndexBoundsImpl<int16_t>(indices, upper_limit);
+ case Type::INT32:
+ return CheckIndexBoundsImpl<int32_t>(indices, upper_limit);
+ case Type::INT64:
+ return CheckIndexBoundsImpl<int64_t>(indices, upper_limit);
+ case Type::UINT8:
+ return CheckIndexBoundsImpl<uint8_t>(indices, upper_limit);
+ case Type::UINT16:
+ return CheckIndexBoundsImpl<uint16_t>(indices, upper_limit);
+ case Type::UINT32:
+ return CheckIndexBoundsImpl<uint32_t>(indices, upper_limit);
+ case Type::UINT64:
+ return CheckIndexBoundsImpl<uint64_t>(indices, upper_limit);
+ default:
+ return Status::Invalid("Invalid index type for boundschecking");
+ }
+}
+
+// ----------------------------------------------------------------------
+// Utilities for casting from one integer type to another
+
+namespace {
+
+template <typename InType, typename CType = typename InType::c_type>
+Status IntegersInRange(const Datum& datum, CType bound_lower, CType bound_upper) {
+ if (std::numeric_limits<CType>::lowest() >= bound_lower &&
+ std::numeric_limits<CType>::max() <= bound_upper) {
+ return Status::OK();
+ }
+
+ auto IsOutOfBounds = [&](CType val) -> bool {
+ return val < bound_lower || val > bound_upper;
+ };
+ auto IsOutOfBoundsMaybeNull = [&](CType val, bool is_valid) -> bool {
+ return is_valid && (val < bound_lower || val > bound_upper);
+ };
+ auto GetErrorMessage = [&](CType val) {
+ return Status::Invalid("Integer value ", FormatInt(val),
+ " not in range: ", FormatInt(bound_lower), " to ",
+ FormatInt(bound_upper));
+ };
+
+ if (datum.kind() == Datum::SCALAR) {
+ const auto& scalar = datum.scalar_as<typename TypeTraits<InType>::ScalarType>();
+ if (IsOutOfBoundsMaybeNull(scalar.value, scalar.is_valid)) {
+ return GetErrorMessage(scalar.value);
+ }
+ return Status::OK();
+ }
+
+ const ArrayData& indices = *datum.array();
+ const CType* indices_data = indices.GetValues<CType>(1);
+ const uint8_t* bitmap = nullptr;
+ if (indices.buffers[0]) {
+ bitmap = indices.buffers[0]->data();
+ }
+ OptionalBitBlockCounter indices_bit_counter(bitmap, indices.offset, indices.length);
+ int64_t position = 0;
+ int64_t offset_position = indices.offset;
+ while (position < indices.length) {
+ BitBlockCount block = indices_bit_counter.NextBlock();
+ bool block_out_of_bounds = false;
+ if (block.popcount == block.length) {
+ // Fast path: branchless
+ int64_t i = 0;
+ for (int64_t chunk = 0; chunk < block.length / 8; ++chunk) {
+ // Let the compiler unroll this
+ for (int j = 0; j < 8; ++j) {
+ block_out_of_bounds |= IsOutOfBounds(indices_data[i++]);
+ }
+ }
+ for (; i < block.length; ++i) {
+ block_out_of_bounds |= IsOutOfBounds(indices_data[i]);
+ }
+ } else if (block.popcount > 0) {
+ // Indices have nulls, must only boundscheck non-null values
+ int64_t i = 0;
+ for (int64_t chunk = 0; chunk < block.length / 8; ++chunk) {
+ // Let the compiler unroll this
+ for (int j = 0; j < 8; ++j) {
+ block_out_of_bounds |= IsOutOfBoundsMaybeNull(
+ indices_data[i], BitUtil::GetBit(bitmap, offset_position + i));
+ ++i;
+ }
+ }
+ for (; i < block.length; ++i) {
+ block_out_of_bounds |= IsOutOfBoundsMaybeNull(
+ indices_data[i], BitUtil::GetBit(bitmap, offset_position + i));
+ }
+ }
+ if (ARROW_PREDICT_FALSE(block_out_of_bounds)) {
+ if (indices.GetNullCount() > 0) {
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (IsOutOfBoundsMaybeNull(indices_data[i],
+ BitUtil::GetBit(bitmap, offset_position + i))) {
+ return GetErrorMessage(indices_data[i]);
+ }
+ }
+ } else {
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (IsOutOfBounds(indices_data[i])) {
+ return GetErrorMessage(indices_data[i]);
+ }
+ }
+ }
+ }
+ indices_data += block.length;
+ position += block.length;
+ offset_position += block.length;
+ }
+ return Status::OK();
+}
+
+template <typename Type>
+Status CheckIntegersInRangeImpl(const Datum& datum, const Scalar& bound_lower,
+ const Scalar& bound_upper) {
+ using ScalarType = typename TypeTraits<Type>::ScalarType;
+ return IntegersInRange<Type>(datum, checked_cast<const ScalarType&>(bound_lower).value,
+ checked_cast<const ScalarType&>(bound_upper).value);
+}
+
+} // namespace
+
+Status CheckIntegersInRange(const Datum& datum, const Scalar& bound_lower,
+ const Scalar& bound_upper) {
+ Type::type type_id = datum.type()->id();
+
+ if (bound_lower.type->id() != type_id || bound_upper.type->id() != type_id ||
+ !bound_lower.is_valid || !bound_upper.is_valid) {
+ return Status::Invalid("Scalar bound types must be non-null and same type as data");
+ }
+
+ switch (type_id) {
+ case Type::INT8:
+ return CheckIntegersInRangeImpl<Int8Type>(datum, bound_lower, bound_upper);
+ case Type::INT16:
+ return CheckIntegersInRangeImpl<Int16Type>(datum, bound_lower, bound_upper);
+ case Type::INT32:
+ return CheckIntegersInRangeImpl<Int32Type>(datum, bound_lower, bound_upper);
+ case Type::INT64:
+ return CheckIntegersInRangeImpl<Int64Type>(datum, bound_lower, bound_upper);
+ case Type::UINT8:
+ return CheckIntegersInRangeImpl<UInt8Type>(datum, bound_lower, bound_upper);
+ case Type::UINT16:
+ return CheckIntegersInRangeImpl<UInt16Type>(datum, bound_lower, bound_upper);
+ case Type::UINT32:
+ return CheckIntegersInRangeImpl<UInt32Type>(datum, bound_lower, bound_upper);
+ case Type::UINT64:
+ return CheckIntegersInRangeImpl<UInt64Type>(datum, bound_lower, bound_upper);
+ default:
+ return Status::TypeError("Invalid index type for boundschecking");
+ }
+}
+
+namespace {
+
+template <typename O, typename I, typename Enable = void>
+struct is_number_downcast {
+ static constexpr bool value = false;
+};
+
+template <typename O, typename I>
+struct is_number_downcast<
+ O, I, enable_if_t<is_number_type<O>::value && is_number_type<I>::value>> {
+ using O_T = typename O::c_type;
+ using I_T = typename I::c_type;
+
+ static constexpr bool value =
+ ((!std::is_same<O, I>::value) &&
+ // Both types are of the same sign-ness.
+ ((std::is_signed<O_T>::value == std::is_signed<I_T>::value) &&
+ // Both types are of the same integral-ness.
+ (std::is_floating_point<O_T>::value == std::is_floating_point<I_T>::value)) &&
+ // Smaller output size
+ (sizeof(O_T) < sizeof(I_T)));
+};
+
+template <typename O, typename I, typename Enable = void>
+struct is_number_upcast {
+ static constexpr bool value = false;
+};
+
+template <typename O, typename I>
+struct is_number_upcast<
+ O, I, enable_if_t<is_number_type<O>::value && is_number_type<I>::value>> {
+ using O_T = typename O::c_type;
+ using I_T = typename I::c_type;
+
+ static constexpr bool value =
+ ((!std::is_same<O, I>::value) &&
+ // Both types are of the same sign-ness.
+ ((std::is_signed<O_T>::value == std::is_signed<I_T>::value) &&
+ // Both types are of the same integral-ness.
+ (std::is_floating_point<O_T>::value == std::is_floating_point<I_T>::value)) &&
+ // Larger output size
+ (sizeof(O_T) > sizeof(I_T)));
+};
+
+template <typename O, typename I, typename Enable = void>
+struct is_integral_signed_to_unsigned {
+ static constexpr bool value = false;
+};
+
+template <typename O, typename I>
+struct is_integral_signed_to_unsigned<
+ O, I, enable_if_t<is_integer_type<O>::value && is_integer_type<I>::value>> {
+ using O_T = typename O::c_type;
+ using I_T = typename I::c_type;
+
+ static constexpr bool value =
+ ((!std::is_same<O, I>::value) &&
+ ((std::is_unsigned<O_T>::value && std::is_signed<I_T>::value)));
+};
+
+template <typename O, typename I, typename Enable = void>
+struct is_integral_unsigned_to_signed {
+ static constexpr bool value = false;
+};
+
+template <typename O, typename I>
+struct is_integral_unsigned_to_signed<
+ O, I, enable_if_t<is_integer_type<O>::value && is_integer_type<I>::value>> {
+ using O_T = typename O::c_type;
+ using I_T = typename I::c_type;
+
+ static constexpr bool value =
+ ((!std::is_same<O, I>::value) &&
+ ((std::is_signed<O_T>::value && std::is_unsigned<I_T>::value)));
+};
+
+// This set of functions SafeMinimum/SafeMaximum would be simplified with
+// C++17 and `if constexpr`.
+
+// clang-format doesn't handle this construct properly. Thus the macro, but it
+// also improves readability.
+//
+// The effective return type of the function is always `I::c_type`, this is
+// just how enable_if works with functions.
+#define RET_TYPE(TRAIT) enable_if_t<TRAIT<O, I>::value, typename I::c_type>
+
+template <typename O, typename I>
+constexpr RET_TYPE(std::is_same) SafeMinimum() {
+ using out_type = typename O::c_type;
+
+ return std::numeric_limits<out_type>::lowest();
+}
+
+template <typename O, typename I>
+constexpr RET_TYPE(std::is_same) SafeMaximum() {
+ using out_type = typename O::c_type;
+
+ return std::numeric_limits<out_type>::max();
+}
+
+template <typename O, typename I>
+constexpr RET_TYPE(is_number_downcast) SafeMinimum() {
+ using out_type = typename O::c_type;
+
+ return std::numeric_limits<out_type>::lowest();
+}
+
+template <typename O, typename I>
+constexpr RET_TYPE(is_number_downcast) SafeMaximum() {
+ using out_type = typename O::c_type;
+
+ return std::numeric_limits<out_type>::max();
+}
+
+template <typename O, typename I>
+constexpr RET_TYPE(is_number_upcast) SafeMinimum() {
+ using in_type = typename I::c_type;
+ return std::numeric_limits<in_type>::lowest();
+}
+
+template <typename O, typename I>
+constexpr RET_TYPE(is_number_upcast) SafeMaximum() {
+ using in_type = typename I::c_type;
+ return std::numeric_limits<in_type>::max();
+}
+
+template <typename O, typename I>
+constexpr RET_TYPE(is_integral_unsigned_to_signed) SafeMinimum() {
+ return 0;
+}
+
+template <typename O, typename I>
+constexpr RET_TYPE(is_integral_unsigned_to_signed) SafeMaximum() {
+ using in_type = typename I::c_type;
+ using out_type = typename O::c_type;
+
+ // Equality is missing because in_type::max() > out_type::max() when types
+ // are of the same width.
+ return static_cast<in_type>(sizeof(in_type) < sizeof(out_type)
+ ? std::numeric_limits<in_type>::max()
+ : std::numeric_limits<out_type>::max());
+}
+
+template <typename O, typename I>
+constexpr RET_TYPE(is_integral_signed_to_unsigned) SafeMinimum() {
+ return 0;
+}
+
+template <typename O, typename I>
+constexpr RET_TYPE(is_integral_signed_to_unsigned) SafeMaximum() {
+ using in_type = typename I::c_type;
+ using out_type = typename O::c_type;
+
+ return static_cast<in_type>(sizeof(in_type) <= sizeof(out_type)
+ ? std::numeric_limits<in_type>::max()
+ : std::numeric_limits<out_type>::max());
+}
+
+#undef RET_TYPE
+
+#define GET_MIN_MAX_CASE(TYPE, OUT_TYPE) \
+ case Type::TYPE: \
+ *min = SafeMinimum<OUT_TYPE, InType>(); \
+ *max = SafeMaximum<OUT_TYPE, InType>(); \
+ break
+
+template <typename InType, typename T = typename InType::c_type>
+void GetSafeMinMax(Type::type out_type, T* min, T* max) {
+ switch (out_type) {
+ GET_MIN_MAX_CASE(INT8, Int8Type);
+ GET_MIN_MAX_CASE(INT16, Int16Type);
+ GET_MIN_MAX_CASE(INT32, Int32Type);
+ GET_MIN_MAX_CASE(INT64, Int64Type);
+ GET_MIN_MAX_CASE(UINT8, UInt8Type);
+ GET_MIN_MAX_CASE(UINT16, UInt16Type);
+ GET_MIN_MAX_CASE(UINT32, UInt32Type);
+ GET_MIN_MAX_CASE(UINT64, UInt64Type);
+ default:
+ break;
+ }
+}
+
+template <typename Type, typename CType = typename Type::c_type,
+ typename ScalarType = typename TypeTraits<Type>::ScalarType>
+Status IntegersCanFitImpl(const Datum& datum, const DataType& target_type) {
+ CType bound_min{}, bound_max{};
+ GetSafeMinMax<Type>(target_type.id(), &bound_min, &bound_max);
+ return CheckIntegersInRange(datum, ScalarType(bound_min), ScalarType(bound_max));
+}
+
+} // namespace
+
+Status IntegersCanFit(const Datum& datum, const DataType& target_type) {
+ if (!is_integer(target_type.id())) {
+ return Status::Invalid("Target type is not an integer type: ", target_type);
+ }
+
+ switch (datum.type()->id()) {
+ case Type::INT8:
+ return IntegersCanFitImpl<Int8Type>(datum, target_type);
+ case Type::INT16:
+ return IntegersCanFitImpl<Int16Type>(datum, target_type);
+ case Type::INT32:
+ return IntegersCanFitImpl<Int32Type>(datum, target_type);
+ case Type::INT64:
+ return IntegersCanFitImpl<Int64Type>(datum, target_type);
+ case Type::UINT8:
+ return IntegersCanFitImpl<UInt8Type>(datum, target_type);
+ case Type::UINT16:
+ return IntegersCanFitImpl<UInt16Type>(datum, target_type);
+ case Type::UINT32:
+ return IntegersCanFitImpl<UInt32Type>(datum, target_type);
+ case Type::UINT64:
+ return IntegersCanFitImpl<UInt64Type>(datum, target_type);
+ default:
+ return Status::TypeError("Invalid index type for boundschecking");
+ }
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/int_util.h b/src/arrow/cpp/src/arrow/util/int_util.h
new file mode 100644
index 000000000..bf9226cdf
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/int_util.h
@@ -0,0 +1,117 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <type_traits>
+
+#include "arrow/status.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class DataType;
+struct ArrayData;
+struct Datum;
+struct Scalar;
+
+namespace internal {
+
+ARROW_EXPORT
+uint8_t DetectUIntWidth(const uint64_t* values, int64_t length, uint8_t min_width = 1);
+
+ARROW_EXPORT
+uint8_t DetectUIntWidth(const uint64_t* values, const uint8_t* valid_bytes,
+ int64_t length, uint8_t min_width = 1);
+
+ARROW_EXPORT
+uint8_t DetectIntWidth(const int64_t* values, int64_t length, uint8_t min_width = 1);
+
+ARROW_EXPORT
+uint8_t DetectIntWidth(const int64_t* values, const uint8_t* valid_bytes, int64_t length,
+ uint8_t min_width = 1);
+
+ARROW_EXPORT
+void DowncastInts(const int64_t* source, int8_t* dest, int64_t length);
+
+ARROW_EXPORT
+void DowncastInts(const int64_t* source, int16_t* dest, int64_t length);
+
+ARROW_EXPORT
+void DowncastInts(const int64_t* source, int32_t* dest, int64_t length);
+
+ARROW_EXPORT
+void DowncastInts(const int64_t* source, int64_t* dest, int64_t length);
+
+ARROW_EXPORT
+void DowncastUInts(const uint64_t* source, uint8_t* dest, int64_t length);
+
+ARROW_EXPORT
+void DowncastUInts(const uint64_t* source, uint16_t* dest, int64_t length);
+
+ARROW_EXPORT
+void DowncastUInts(const uint64_t* source, uint32_t* dest, int64_t length);
+
+ARROW_EXPORT
+void DowncastUInts(const uint64_t* source, uint64_t* dest, int64_t length);
+
+ARROW_EXPORT
+void UpcastInts(const int32_t* source, int64_t* dest, int64_t length);
+
+template <typename InputInt, typename OutputInt>
+inline typename std::enable_if<(sizeof(InputInt) >= sizeof(OutputInt))>::type CastInts(
+ const InputInt* source, OutputInt* dest, int64_t length) {
+ DowncastInts(source, dest, length);
+}
+
+template <typename InputInt, typename OutputInt>
+inline typename std::enable_if<(sizeof(InputInt) < sizeof(OutputInt))>::type CastInts(
+ const InputInt* source, OutputInt* dest, int64_t length) {
+ UpcastInts(source, dest, length);
+}
+
+template <typename InputInt, typename OutputInt>
+ARROW_EXPORT void TransposeInts(const InputInt* source, OutputInt* dest, int64_t length,
+ const int32_t* transpose_map);
+
+ARROW_EXPORT
+Status TransposeInts(const DataType& src_type, const DataType& dest_type,
+ const uint8_t* src, uint8_t* dest, int64_t src_offset,
+ int64_t dest_offset, int64_t length, const int32_t* transpose_map);
+
+/// \brief Do vectorized boundschecking of integer-type array indices. The
+/// indices must be non-nonnegative and strictly less than the passed upper
+/// limit (which is usually the length of an array that is being indexed-into).
+ARROW_EXPORT
+Status CheckIndexBounds(const ArrayData& indices, uint64_t upper_limit);
+
+/// \brief Boundscheck integer values to determine if they are all between the
+/// passed upper and lower limits (inclusive). Upper and lower bounds must be
+/// the same type as the data and are not currently casted.
+ARROW_EXPORT
+Status CheckIntegersInRange(const Datum& datum, const Scalar& bound_lower,
+ const Scalar& bound_upper);
+
+/// \brief Use CheckIntegersInRange to determine whether the passed integers
+/// can fit safely in the passed integer type. This helps quickly determine if
+/// integer narrowing (e.g. int64->int32) is safe to do.
+ARROW_EXPORT
+Status IntegersCanFit(const Datum& datum, const DataType& target_type);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/int_util_benchmark.cc b/src/arrow/cpp/src/arrow/util/int_util_benchmark.cc
new file mode 100644
index 000000000..1eae604a7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/int_util_benchmark.cc
@@ -0,0 +1,143 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "arrow/array/array_base.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/benchmark_util.h"
+#include "arrow/util/int_util.h"
+
+namespace arrow {
+namespace internal {
+
+constexpr auto kSeed = 0x94378165;
+
+std::vector<uint64_t> GetUIntSequence(int n_values, uint64_t addend = 0) {
+ std::vector<uint64_t> values(n_values);
+ for (int i = 0; i < n_values; ++i) {
+ values[i] = static_cast<uint64_t>(i) + addend;
+ }
+ return values;
+}
+
+std::vector<int64_t> GetIntSequence(int n_values, uint64_t addend = 0) {
+ std::vector<int64_t> values(n_values);
+ for (int i = 0; i < n_values; ++i) {
+ values[i] = static_cast<int64_t>(i) + addend;
+ }
+ return values;
+}
+
+std::vector<uint8_t> GetValidBytes(int n_values) {
+ std::vector<uint8_t> valid_bytes(n_values);
+ for (int i = 0; i < n_values; ++i) {
+ valid_bytes[i] = (i % 3 == 0) ? 1 : 0;
+ }
+ return valid_bytes;
+}
+
+static void DetectUIntWidthNoNulls(
+ benchmark::State& state) { // NOLINT non-const reference
+ const auto values = GetUIntSequence(0x12345);
+
+ while (state.KeepRunning()) {
+ auto result = DetectUIntWidth(values.data(), static_cast<int64_t>(values.size()));
+ benchmark::DoNotOptimize(result);
+ }
+ state.SetBytesProcessed(state.iterations() * values.size() * sizeof(uint64_t));
+}
+
+static void DetectUIntWidthNulls(benchmark::State& state) { // NOLINT non-const reference
+ const auto values = GetUIntSequence(0x12345);
+ const auto valid_bytes = GetValidBytes(0x12345);
+
+ while (state.KeepRunning()) {
+ auto result = DetectUIntWidth(values.data(), valid_bytes.data(),
+ static_cast<int64_t>(values.size()));
+ benchmark::DoNotOptimize(result);
+ }
+ state.SetBytesProcessed(state.iterations() * values.size() * sizeof(uint64_t));
+}
+
+static void DetectIntWidthNoNulls(
+ benchmark::State& state) { // NOLINT non-const reference
+ const auto values = GetIntSequence(0x12345, -0x1234);
+
+ while (state.KeepRunning()) {
+ auto result = DetectIntWidth(values.data(), static_cast<int64_t>(values.size()));
+ benchmark::DoNotOptimize(result);
+ }
+ state.SetBytesProcessed(state.iterations() * values.size() * sizeof(uint64_t));
+}
+
+static void DetectIntWidthNulls(benchmark::State& state) { // NOLINT non-const reference
+ const auto values = GetIntSequence(0x12345, -0x1234);
+ const auto valid_bytes = GetValidBytes(0x12345);
+
+ while (state.KeepRunning()) {
+ auto result = DetectIntWidth(values.data(), valid_bytes.data(),
+ static_cast<int64_t>(values.size()));
+ benchmark::DoNotOptimize(result);
+ }
+ state.SetBytesProcessed(state.iterations() * values.size() * sizeof(uint64_t));
+}
+
+static void CheckIndexBoundsInt32(
+ benchmark::State& state) { // NOLINT non-const reference
+ GenericItemsArgs args(state);
+ random::RandomArrayGenerator rand(kSeed);
+ auto arr = rand.Int32(args.size, 0, 100000, args.null_proportion);
+ for (auto _ : state) {
+ ABORT_NOT_OK(CheckIndexBounds(*arr->data(), 100001));
+ }
+}
+
+static void CheckIndexBoundsUInt32(
+ benchmark::State& state) { // NOLINT non-const reference
+ GenericItemsArgs args(state);
+ random::RandomArrayGenerator rand(kSeed);
+ auto arr = rand.UInt32(args.size, 0, 100000, args.null_proportion);
+ for (auto _ : state) {
+ ABORT_NOT_OK(CheckIndexBounds(*arr->data(), 100001));
+ }
+}
+
+BENCHMARK(DetectUIntWidthNoNulls);
+BENCHMARK(DetectUIntWidthNulls);
+BENCHMARK(DetectIntWidthNoNulls);
+BENCHMARK(DetectIntWidthNulls);
+
+std::vector<int64_t> g_data_sizes = {kL1Size, kL2Size};
+
+void BoundsCheckSetArgs(benchmark::internal::Benchmark* bench) {
+ for (int64_t size : g_data_sizes) {
+ for (auto nulls : std::vector<ArgsType>({1000, 10, 2, 1, 0})) {
+ bench->Args({static_cast<ArgsType>(size), nulls});
+ }
+ }
+}
+
+BENCHMARK(CheckIndexBoundsInt32)->Apply(BoundsCheckSetArgs);
+BENCHMARK(CheckIndexBoundsUInt32)->Apply(BoundsCheckSetArgs);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/int_util_internal.h b/src/arrow/cpp/src/arrow/util/int_util_internal.h
new file mode 100644
index 000000000..413670662
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/int_util_internal.h
@@ -0,0 +1,153 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <limits>
+#include <type_traits>
+
+#include "arrow/status.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+// "safe-math.h" includes <intsafe.h> from the Windows headers.
+#include "arrow/util/windows_compatibility.h"
+#include "arrow/vendored/portable-snippets/safe-math.h"
+// clang-format off (avoid include reordering)
+#include "arrow/util/windows_fixup.h"
+// clang-format on
+
+namespace arrow {
+namespace internal {
+
+// Define functions AddWithOverflow, SubtractWithOverflow, MultiplyWithOverflow
+// with the signature `bool(T u, T v, T* out)` where T is an integer type.
+// On overflow, these functions return true. Otherwise, false is returned
+// and `out` is updated with the result of the operation.
+
+#define OP_WITH_OVERFLOW(_func_name, _psnip_op, _type, _psnip_type) \
+ static inline bool _func_name(_type u, _type v, _type* out) { \
+ return !psnip_safe_##_psnip_type##_##_psnip_op(out, u, v); \
+ }
+
+#define OPS_WITH_OVERFLOW(_func_name, _psnip_op) \
+ OP_WITH_OVERFLOW(_func_name, _psnip_op, int8_t, int8) \
+ OP_WITH_OVERFLOW(_func_name, _psnip_op, int16_t, int16) \
+ OP_WITH_OVERFLOW(_func_name, _psnip_op, int32_t, int32) \
+ OP_WITH_OVERFLOW(_func_name, _psnip_op, int64_t, int64) \
+ OP_WITH_OVERFLOW(_func_name, _psnip_op, uint8_t, uint8) \
+ OP_WITH_OVERFLOW(_func_name, _psnip_op, uint16_t, uint16) \
+ OP_WITH_OVERFLOW(_func_name, _psnip_op, uint32_t, uint32) \
+ OP_WITH_OVERFLOW(_func_name, _psnip_op, uint64_t, uint64)
+
+OPS_WITH_OVERFLOW(AddWithOverflow, add)
+OPS_WITH_OVERFLOW(SubtractWithOverflow, sub)
+OPS_WITH_OVERFLOW(MultiplyWithOverflow, mul)
+OPS_WITH_OVERFLOW(DivideWithOverflow, div)
+
+#undef OP_WITH_OVERFLOW
+#undef OPS_WITH_OVERFLOW
+
+// Define function NegateWithOverflow with the signature `bool(T u, T* out)`
+// where T is a signed integer type. On overflow, these functions return true.
+// Otherwise, false is returned and `out` is updated with the result of the
+// operation.
+
+#define UNARY_OP_WITH_OVERFLOW(_func_name, _psnip_op, _type, _psnip_type) \
+ static inline bool _func_name(_type u, _type* out) { \
+ return !psnip_safe_##_psnip_type##_##_psnip_op(out, u); \
+ }
+
+#define SIGNED_UNARY_OPS_WITH_OVERFLOW(_func_name, _psnip_op) \
+ UNARY_OP_WITH_OVERFLOW(_func_name, _psnip_op, int8_t, int8) \
+ UNARY_OP_WITH_OVERFLOW(_func_name, _psnip_op, int16_t, int16) \
+ UNARY_OP_WITH_OVERFLOW(_func_name, _psnip_op, int32_t, int32) \
+ UNARY_OP_WITH_OVERFLOW(_func_name, _psnip_op, int64_t, int64)
+
+SIGNED_UNARY_OPS_WITH_OVERFLOW(NegateWithOverflow, neg)
+
+#undef UNARY_OP_WITH_OVERFLOW
+#undef SIGNED_UNARY_OPS_WITH_OVERFLOW
+
+/// Signed addition with well-defined behaviour on overflow (as unsigned)
+template <typename SignedInt>
+SignedInt SafeSignedAdd(SignedInt u, SignedInt v) {
+ using UnsignedInt = typename std::make_unsigned<SignedInt>::type;
+ return static_cast<SignedInt>(static_cast<UnsignedInt>(u) +
+ static_cast<UnsignedInt>(v));
+}
+
+/// Signed subtraction with well-defined behaviour on overflow (as unsigned)
+template <typename SignedInt>
+SignedInt SafeSignedSubtract(SignedInt u, SignedInt v) {
+ using UnsignedInt = typename std::make_unsigned<SignedInt>::type;
+ return static_cast<SignedInt>(static_cast<UnsignedInt>(u) -
+ static_cast<UnsignedInt>(v));
+}
+
+/// Signed negation with well-defined behaviour on overflow (as unsigned)
+template <typename SignedInt>
+SignedInt SafeSignedNegate(SignedInt u) {
+ using UnsignedInt = typename std::make_unsigned<SignedInt>::type;
+ return static_cast<SignedInt>(~static_cast<UnsignedInt>(u) + 1);
+}
+
+/// Signed left shift with well-defined behaviour on negative numbers or overflow
+template <typename SignedInt, typename Shift>
+SignedInt SafeLeftShift(SignedInt u, Shift shift) {
+ using UnsignedInt = typename std::make_unsigned<SignedInt>::type;
+ return static_cast<SignedInt>(static_cast<UnsignedInt>(u) << shift);
+}
+
+/// Upcast an integer to the largest possible width (currently 64 bits)
+
+template <typename Integer>
+typename std::enable_if<
+ std::is_integral<Integer>::value && std::is_signed<Integer>::value, int64_t>::type
+UpcastInt(Integer v) {
+ return v;
+}
+
+template <typename Integer>
+typename std::enable_if<
+ std::is_integral<Integer>::value && std::is_unsigned<Integer>::value, uint64_t>::type
+UpcastInt(Integer v) {
+ return v;
+}
+
+static inline Status CheckSliceParams(int64_t object_length, int64_t slice_offset,
+ int64_t slice_length, const char* object_name) {
+ if (ARROW_PREDICT_FALSE(slice_offset < 0)) {
+ return Status::Invalid("Negative ", object_name, " slice offset");
+ }
+ if (ARROW_PREDICT_FALSE(slice_length < 0)) {
+ return Status::Invalid("Negative ", object_name, " slice length");
+ }
+ int64_t offset_plus_length;
+ if (ARROW_PREDICT_FALSE(
+ internal::AddWithOverflow(slice_offset, slice_length, &offset_plus_length))) {
+ return Status::Invalid(object_name, " slice would overflow");
+ }
+ if (ARROW_PREDICT_FALSE(slice_offset + slice_length > object_length)) {
+ return Status::Invalid(object_name, " slice would exceed ", object_name, " length");
+ }
+ return Status::OK();
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/int_util_test.cc b/src/arrow/cpp/src/arrow/util/int_util_test.cc
new file mode 100644
index 000000000..333154c5c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/int_util_test.cc
@@ -0,0 +1,597 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/datum.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/type.h"
+#include "arrow/util/int_util.h"
+#include "arrow/util/int_util_internal.h"
+
+namespace arrow {
+namespace internal {
+
+static std::vector<uint8_t> all_widths = {1, 2, 4, 8};
+
+template <typename T>
+void CheckUIntWidth(const std::vector<T>& values, uint8_t expected_width) {
+ for (const uint8_t min_width : all_widths) {
+ uint8_t width =
+ DetectUIntWidth(values.data(), static_cast<int64_t>(values.size()), min_width);
+ ASSERT_EQ(width, std::max(min_width, expected_width));
+ width = DetectUIntWidth(values.data(), nullptr, static_cast<int64_t>(values.size()),
+ min_width);
+ ASSERT_EQ(width, std::max(min_width, expected_width));
+ }
+}
+
+template <typename T>
+void CheckUIntWidth(const std::vector<T>& values, const std::vector<uint8_t>& valid_bytes,
+ uint8_t expected_width) {
+ for (const uint8_t min_width : all_widths) {
+ uint8_t width = DetectUIntWidth(values.data(), valid_bytes.data(),
+ static_cast<int64_t>(values.size()), min_width);
+ ASSERT_EQ(width, std::max(min_width, expected_width));
+ }
+}
+
+template <typename T>
+void CheckIntWidth(const std::vector<T>& values, uint8_t expected_width) {
+ for (const uint8_t min_width : all_widths) {
+ uint8_t width =
+ DetectIntWidth(values.data(), static_cast<int64_t>(values.size()), min_width);
+ ASSERT_EQ(width, std::max(min_width, expected_width));
+ width = DetectIntWidth(values.data(), nullptr, static_cast<int64_t>(values.size()),
+ min_width);
+ ASSERT_EQ(width, std::max(min_width, expected_width));
+ }
+}
+
+template <typename T>
+void CheckIntWidth(const std::vector<T>& values, const std::vector<uint8_t>& valid_bytes,
+ uint8_t expected_width) {
+ for (const uint8_t min_width : all_widths) {
+ uint8_t width = DetectIntWidth(values.data(), valid_bytes.data(),
+ static_cast<int64_t>(values.size()), min_width);
+ ASSERT_EQ(width, std::max(min_width, expected_width));
+ }
+}
+
+template <typename T>
+std::vector<T> MakeRandomVector(const std::vector<T>& base_values, int n_values) {
+ std::default_random_engine gen(42);
+ std::uniform_int_distribution<int> index_dist(0,
+ static_cast<int>(base_values.size() - 1));
+
+ std::vector<T> values(n_values);
+ for (int i = 0; i < n_values; ++i) {
+ values[i] = base_values[index_dist(gen)];
+ }
+ return values;
+}
+
+template <typename T>
+std::vector<std::pair<std::vector<T>, std::vector<uint8_t>>> AlmostAllNullValues(
+ int n_values, T null_value, T non_null_value) {
+ std::vector<std::pair<std::vector<T>, std::vector<uint8_t>>> vectors;
+ vectors.reserve(n_values);
+ for (int i = 0; i < n_values; ++i) {
+ std::vector<T> values(n_values, null_value);
+ std::vector<uint8_t> valid_bytes(n_values, 0);
+ values[i] = non_null_value;
+ valid_bytes[i] = 1;
+ vectors.push_back({std::move(values), std::move(valid_bytes)});
+ }
+ return vectors;
+}
+
+template <typename T>
+std::vector<std::vector<T>> AlmostAllZeros(int n_values, T nonzero_value) {
+ std::vector<std::vector<T>> vectors;
+ vectors.reserve(n_values);
+ for (int i = 0; i < n_values; ++i) {
+ std::vector<T> values(n_values, 0);
+ values[i] = nonzero_value;
+ vectors.push_back(std::move(values));
+ }
+ return vectors;
+}
+
+std::vector<uint64_t> valid_uint8 = {0, 0x7f, 0xff};
+std::vector<uint64_t> valid_uint16 = {0, 0x7f, 0xff, 0x1000, 0xffff};
+std::vector<uint64_t> valid_uint32 = {0, 0x7f, 0xff, 0x10000, 0xffffffffULL};
+std::vector<uint64_t> valid_uint64 = {0, 0x100000000ULL, 0xffffffffffffffffULL};
+
+TEST(UIntWidth, NoNulls) {
+ std::vector<uint64_t> values{0, 0x7f, 0xff};
+ CheckUIntWidth(values, 1);
+
+ values = {0, 0x100};
+ CheckUIntWidth(values, 2);
+
+ values = {0, 0xffff};
+ CheckUIntWidth(values, 2);
+
+ values = {0, 0x10000};
+ CheckUIntWidth(values, 4);
+
+ values = {0, 0xffffffffULL};
+ CheckUIntWidth(values, 4);
+
+ values = {0, 0x100000000ULL};
+ CheckUIntWidth(values, 8);
+
+ values = {0, 0xffffffffffffffffULL};
+ CheckUIntWidth(values, 8);
+}
+
+TEST(UIntWidth, Nulls) {
+ std::vector<uint8_t> valid10{true, false};
+ std::vector<uint8_t> valid01{false, true};
+
+ std::vector<uint64_t> values{0, 0xff};
+ CheckUIntWidth(values, valid01, 1);
+ CheckUIntWidth(values, valid10, 1);
+
+ values = {0, 0x100};
+ CheckUIntWidth(values, valid01, 2);
+ CheckUIntWidth(values, valid10, 1);
+
+ values = {0, 0xffff};
+ CheckUIntWidth(values, valid01, 2);
+ CheckUIntWidth(values, valid10, 1);
+
+ values = {0, 0x10000};
+ CheckUIntWidth(values, valid01, 4);
+ CheckUIntWidth(values, valid10, 1);
+
+ values = {0, 0xffffffffULL};
+ CheckUIntWidth(values, valid01, 4);
+ CheckUIntWidth(values, valid10, 1);
+
+ values = {0, 0x100000000ULL};
+ CheckUIntWidth(values, valid01, 8);
+ CheckUIntWidth(values, valid10, 1);
+
+ values = {0, 0xffffffffffffffffULL};
+ CheckUIntWidth(values, valid01, 8);
+ CheckUIntWidth(values, valid10, 1);
+}
+
+TEST(UIntWidth, NoNullsMany) {
+ constexpr int N = 40;
+ for (const auto& values : AlmostAllZeros<uint64_t>(N, 0xff)) {
+ CheckUIntWidth(values, 1);
+ }
+ for (const auto& values : AlmostAllZeros<uint64_t>(N, 0xffff)) {
+ CheckUIntWidth(values, 2);
+ }
+ for (const auto& values : AlmostAllZeros<uint64_t>(N, 0xffffffffULL)) {
+ CheckUIntWidth(values, 4);
+ }
+ for (const auto& values : AlmostAllZeros<uint64_t>(N, 0xffffffffffffffffULL)) {
+ CheckUIntWidth(values, 8);
+ }
+ auto values = MakeRandomVector(valid_uint8, N);
+ CheckUIntWidth(values, 1);
+
+ values = MakeRandomVector(valid_uint16, N);
+ CheckUIntWidth(values, 2);
+
+ values = MakeRandomVector(valid_uint32, N);
+ CheckUIntWidth(values, 4);
+
+ values = MakeRandomVector(valid_uint64, N);
+ CheckUIntWidth(values, 8);
+}
+
+TEST(UIntWidth, NullsMany) {
+ constexpr uint64_t huge = 0x123456789abcdefULL;
+ constexpr int N = 40;
+ for (const auto& p : AlmostAllNullValues<uint64_t>(N, 0, 0xff)) {
+ CheckUIntWidth(p.first, p.second, 1);
+ }
+ for (const auto& p : AlmostAllNullValues<uint64_t>(N, huge, 0xff)) {
+ CheckUIntWidth(p.first, p.second, 1);
+ }
+ for (const auto& p : AlmostAllNullValues<uint64_t>(N, 0, 0xffff)) {
+ CheckUIntWidth(p.first, p.second, 2);
+ }
+ for (const auto& p : AlmostAllNullValues<uint64_t>(N, huge, 0xffff)) {
+ CheckUIntWidth(p.first, p.second, 2);
+ }
+ for (const auto& p : AlmostAllNullValues<uint64_t>(N, 0, 0xffffffffULL)) {
+ CheckUIntWidth(p.first, p.second, 4);
+ }
+ for (const auto& p : AlmostAllNullValues<uint64_t>(N, huge, 0xffffffffULL)) {
+ CheckUIntWidth(p.first, p.second, 4);
+ }
+ for (const auto& p : AlmostAllNullValues<uint64_t>(N, 0, 0xffffffffffffffffULL)) {
+ CheckUIntWidth(p.first, p.second, 8);
+ }
+ for (const auto& p : AlmostAllNullValues<uint64_t>(N, huge, 0xffffffffffffffffULL)) {
+ CheckUIntWidth(p.first, p.second, 8);
+ }
+}
+
+TEST(IntWidth, NoNulls) {
+ std::vector<int64_t> values{0, 0x7f, -0x80};
+ CheckIntWidth(values, 1);
+
+ values = {0, 0x80};
+ CheckIntWidth(values, 2);
+
+ values = {0, -0x81};
+ CheckIntWidth(values, 2);
+
+ values = {0, 0x7fff, -0x8000};
+ CheckIntWidth(values, 2);
+
+ values = {0, 0x8000};
+ CheckIntWidth(values, 4);
+
+ values = {0, -0x8001};
+ CheckIntWidth(values, 4);
+
+ values = {0, 0x7fffffffLL, -0x80000000LL};
+ CheckIntWidth(values, 4);
+
+ values = {0, 0x80000000LL};
+ CheckIntWidth(values, 8);
+
+ values = {0, -0x80000001LL};
+ CheckIntWidth(values, 8);
+
+ values = {0, 0x7fffffffffffffffLL, -0x7fffffffffffffffLL - 1};
+ CheckIntWidth(values, 8);
+}
+
+TEST(IntWidth, Nulls) {
+ std::vector<uint8_t> valid100{true, false, false};
+ std::vector<uint8_t> valid010{false, true, false};
+ std::vector<uint8_t> valid001{false, false, true};
+
+ std::vector<int64_t> values{0, 0x7f, -0x80};
+ CheckIntWidth(values, valid100, 1);
+ CheckIntWidth(values, valid010, 1);
+ CheckIntWidth(values, valid001, 1);
+
+ values = {0, 0x80, -0x81};
+ CheckIntWidth(values, valid100, 1);
+ CheckIntWidth(values, valid010, 2);
+ CheckIntWidth(values, valid001, 2);
+
+ values = {0, 0x7fff, -0x8000};
+ CheckIntWidth(values, valid100, 1);
+ CheckIntWidth(values, valid010, 2);
+ CheckIntWidth(values, valid001, 2);
+
+ values = {0, 0x8000, -0x8001};
+ CheckIntWidth(values, valid100, 1);
+ CheckIntWidth(values, valid010, 4);
+ CheckIntWidth(values, valid001, 4);
+
+ values = {0, 0x7fffffffLL, -0x80000000LL};
+ CheckIntWidth(values, valid100, 1);
+ CheckIntWidth(values, valid010, 4);
+ CheckIntWidth(values, valid001, 4);
+
+ values = {0, 0x80000000LL, -0x80000001LL};
+ CheckIntWidth(values, valid100, 1);
+ CheckIntWidth(values, valid010, 8);
+ CheckIntWidth(values, valid001, 8);
+
+ values = {0, 0x7fffffffffffffffLL, -0x7fffffffffffffffLL - 1};
+ CheckIntWidth(values, valid100, 1);
+ CheckIntWidth(values, valid010, 8);
+ CheckIntWidth(values, valid001, 8);
+}
+
+TEST(IntWidth, NoNullsMany) {
+ constexpr int N = 40;
+ // 1 byte wide
+ for (const int64_t value : {0x7f, -0x80}) {
+ for (const auto& values : AlmostAllZeros<int64_t>(N, value)) {
+ CheckIntWidth(values, 1);
+ }
+ }
+ // 2 bytes wide
+ for (const int64_t value : {0x80, -0x81, 0x7fff, -0x8000}) {
+ for (const auto& values : AlmostAllZeros<int64_t>(N, value)) {
+ CheckIntWidth(values, 2);
+ }
+ }
+ // 4 bytes wide
+ for (const int64_t value : {0x8000LL, -0x8001LL, 0x7fffffffLL, -0x80000000LL}) {
+ for (const auto& values : AlmostAllZeros<int64_t>(N, value)) {
+ CheckIntWidth(values, 4);
+ }
+ }
+ // 8 bytes wide
+ for (const int64_t value : {0x80000000LL, -0x80000001LL, 0x7fffffffffffffffLL}) {
+ for (const auto& values : AlmostAllZeros<int64_t>(N, value)) {
+ CheckIntWidth(values, 8);
+ }
+ }
+}
+
+TEST(IntWidth, NullsMany) {
+ constexpr int64_t huge = 0x123456789abcdefLL;
+ constexpr int N = 40;
+ // 1 byte wide
+ for (const int64_t value : {0x7f, -0x80}) {
+ for (const auto& p : AlmostAllNullValues<int64_t>(N, 0, value)) {
+ CheckIntWidth(p.first, p.second, 1);
+ }
+ for (const auto& p : AlmostAllNullValues<int64_t>(N, huge, value)) {
+ CheckIntWidth(p.first, p.second, 1);
+ }
+ }
+ // 2 bytes wide
+ for (const int64_t value : {0x80, -0x81, 0x7fff, -0x8000}) {
+ for (const auto& p : AlmostAllNullValues<int64_t>(N, 0, value)) {
+ CheckIntWidth(p.first, p.second, 2);
+ }
+ for (const auto& p : AlmostAllNullValues<int64_t>(N, huge, value)) {
+ CheckIntWidth(p.first, p.second, 2);
+ }
+ }
+ // 4 bytes wide
+ for (const int64_t value : {0x8000LL, -0x8001LL, 0x7fffffffLL, -0x80000000LL}) {
+ for (const auto& p : AlmostAllNullValues<int64_t>(N, 0, value)) {
+ CheckIntWidth(p.first, p.second, 4);
+ }
+ for (const auto& p : AlmostAllNullValues<int64_t>(N, huge, value)) {
+ CheckIntWidth(p.first, p.second, 4);
+ }
+ }
+ // 8 bytes wide
+ for (const int64_t value : {0x80000000LL, -0x80000001LL, 0x7fffffffffffffffLL}) {
+ for (const auto& p : AlmostAllNullValues<int64_t>(N, 0, value)) {
+ CheckIntWidth(p.first, p.second, 8);
+ }
+ for (const auto& p : AlmostAllNullValues<int64_t>(N, huge, value)) {
+ CheckIntWidth(p.first, p.second, 8);
+ }
+ }
+}
+
+TEST(TransposeInts, Int8ToInt64) {
+ std::vector<int8_t> src = {1, 3, 5, 0, 3, 2};
+ std::vector<int32_t> transpose_map = {1111, 2222, 3333, 4444, 5555, 6666, 7777};
+ std::vector<int64_t> dest(src.size());
+
+ TransposeInts(src.data(), dest.data(), 6, transpose_map.data());
+ ASSERT_EQ(dest, std::vector<int64_t>({2222, 4444, 6666, 1111, 4444, 3333}));
+}
+
+void BoundsCheckPasses(const std::shared_ptr<DataType>& type,
+ const std::string& indices_json, uint64_t upper_limit) {
+ auto indices = ArrayFromJSON(type, indices_json);
+ ASSERT_OK(CheckIndexBounds(*indices->data(), upper_limit));
+}
+
+void BoundsCheckFails(const std::shared_ptr<DataType>& type,
+ const std::string& indices_json, uint64_t upper_limit) {
+ auto indices = ArrayFromJSON(type, indices_json);
+ ASSERT_RAISES(IndexError, CheckIndexBounds(*indices->data(), upper_limit));
+}
+
+TEST(CheckIndexBounds, Batching) {
+ auto rand = random::RandomArrayGenerator(/*seed=*/0);
+
+ const int64_t length = 200;
+
+ auto indices = rand.Int16(length, 0, 0, /*null_probability=*/0);
+ ArrayData* index_data = indices->data().get();
+ index_data->buffers[0] = *AllocateBitmap(length);
+
+ int16_t* values = index_data->GetMutableValues<int16_t>(1);
+ uint8_t* bitmap = index_data->buffers[0]->mutable_data();
+ BitUtil::SetBitsTo(bitmap, 0, length, true);
+
+ ASSERT_OK(CheckIndexBounds(*index_data, 1));
+
+ // We'll place out of bounds indices at various locations
+ values[99] = 1;
+ ASSERT_RAISES(IndexError, CheckIndexBounds(*index_data, 1));
+
+ // Make that value null
+ BitUtil::ClearBit(bitmap, 99);
+ ASSERT_OK(CheckIndexBounds(*index_data, 1));
+
+ values[199] = 1;
+ ASSERT_RAISES(IndexError, CheckIndexBounds(*index_data, 1));
+
+ // Make that value null
+ BitUtil::ClearBit(bitmap, 199);
+ ASSERT_OK(CheckIndexBounds(*index_data, 1));
+}
+
+TEST(CheckIndexBounds, SignedInts) {
+ auto CheckCommon = [&](const std::shared_ptr<DataType>& ty) {
+ BoundsCheckPasses(ty, "[0, 0, 0]", 1);
+ BoundsCheckFails(ty, "[0, 0, 0]", 0);
+ BoundsCheckFails(ty, "[-1]", 1);
+ BoundsCheckFails(ty, "[-128]", 1);
+ BoundsCheckFails(ty, "[0, 100, 127]", 127);
+ BoundsCheckPasses(ty, "[0, 100, 127]", 128);
+ };
+
+ CheckCommon(int8());
+
+ CheckCommon(int16());
+ BoundsCheckPasses(int16(), "[0, 999, 999]", 1000);
+ BoundsCheckFails(int16(), "[0, 1000, 1000]", 1000);
+ BoundsCheckPasses(int16(), "[0, 32767]", 1 << 15);
+
+ CheckCommon(int32());
+ BoundsCheckPasses(int32(), "[0, 999999, 999999]", 1000000);
+ BoundsCheckFails(int32(), "[0, 1000000, 1000000]", 1000000);
+ BoundsCheckPasses(int32(), "[0, 2147483647]", 1LL << 31);
+
+ CheckCommon(int64());
+ BoundsCheckPasses(int64(), "[0, 9999999999, 9999999999]", 10000000000LL);
+ BoundsCheckFails(int64(), "[0, 10000000000, 10000000000]", 10000000000LL);
+}
+
+TEST(CheckIndexBounds, UnsignedInts) {
+ auto CheckCommon = [&](const std::shared_ptr<DataType>& ty) {
+ BoundsCheckPasses(ty, "[0, 0, 0]", 1);
+ BoundsCheckFails(ty, "[0, 0, 0]", 0);
+ BoundsCheckFails(ty, "[0, 100, 200]", 200);
+ BoundsCheckPasses(ty, "[0, 100, 200]", 201);
+ };
+
+ CheckCommon(uint8());
+ BoundsCheckPasses(uint8(), "[255, 255, 255]", 1000);
+ BoundsCheckFails(uint8(), "[255, 255, 255]", 255);
+
+ CheckCommon(uint16());
+ BoundsCheckPasses(uint16(), "[0, 999, 999]", 1000);
+ BoundsCheckFails(uint16(), "[0, 1000, 1000]", 1000);
+ BoundsCheckPasses(uint16(), "[0, 65535]", 1 << 16);
+
+ CheckCommon(uint32());
+ BoundsCheckPasses(uint32(), "[0, 999999, 999999]", 1000000);
+ BoundsCheckFails(uint32(), "[0, 1000000, 1000000]", 1000000);
+ BoundsCheckPasses(uint32(), "[0, 4294967295]", 1LL << 32);
+
+ CheckCommon(uint64());
+ BoundsCheckPasses(uint64(), "[0, 9999999999, 9999999999]", 10000000000LL);
+ BoundsCheckFails(uint64(), "[0, 10000000000, 10000000000]", 10000000000LL);
+}
+
+void CheckInRangePasses(const std::shared_ptr<DataType>& type,
+ const std::string& values_json, const std::string& limits_json) {
+ auto values = ArrayFromJSON(type, values_json);
+ auto limits = ArrayFromJSON(type, limits_json);
+ ASSERT_OK(CheckIntegersInRange(Datum(values->data()), **limits->GetScalar(0),
+ **limits->GetScalar(1)));
+}
+
+void CheckInRangeFails(const std::shared_ptr<DataType>& type,
+ const std::string& values_json, const std::string& limits_json) {
+ auto values = ArrayFromJSON(type, values_json);
+ auto limits = ArrayFromJSON(type, limits_json);
+ ASSERT_RAISES(Invalid,
+ CheckIntegersInRange(Datum(values->data()), **limits->GetScalar(0),
+ **limits->GetScalar(1)));
+}
+
+TEST(CheckIntegersInRange, Batching) {
+ auto rand = random::RandomArrayGenerator(/*seed=*/0);
+
+ const int64_t length = 200;
+
+ auto indices = rand.Int16(length, 0, 0, /*null_probability=*/0);
+ ArrayData* index_data = indices->data().get();
+ index_data->buffers[0] = *AllocateBitmap(length);
+
+ int16_t* values = index_data->GetMutableValues<int16_t>(1);
+ uint8_t* bitmap = index_data->buffers[0]->mutable_data();
+ BitUtil::SetBitsTo(bitmap, 0, length, true);
+
+ auto zero = std::make_shared<Int16Scalar>(0);
+ auto one = std::make_shared<Int16Scalar>(1);
+
+ ASSERT_OK(CheckIntegersInRange(*index_data, *zero, *one));
+
+ // 1 is included
+ values[99] = 1;
+ ASSERT_OK(CheckIntegersInRange(*index_data, *zero, *one));
+
+ // We'll place out of bounds indices at various locations
+ values[99] = 2;
+ ASSERT_RAISES(Invalid, CheckIntegersInRange(*index_data, *zero, *one));
+
+ // Make that value null
+ BitUtil::ClearBit(bitmap, 99);
+ ASSERT_OK(CheckIntegersInRange(*index_data, *zero, *one));
+
+ values[199] = 2;
+ ASSERT_RAISES(Invalid, CheckIntegersInRange(*index_data, *zero, *one));
+
+ // Make that value null
+ BitUtil::ClearBit(bitmap, 199);
+ ASSERT_OK(CheckIntegersInRange(*index_data, *zero, *one));
+}
+
+TEST(CheckIntegersInRange, SignedInts) {
+ auto CheckCommon = [&](const std::shared_ptr<DataType>& ty) {
+ CheckInRangePasses(ty, "[0, 0, 0]", "[0, 0]");
+ CheckInRangeFails(ty, "[0, 1, 0]", "[0, 0]");
+ CheckInRangeFails(ty, "[1, 1, 1]", "[2, 4]");
+ CheckInRangeFails(ty, "[-1]", "[0, 0]");
+ CheckInRangeFails(ty, "[-128]", "[-127, 0]");
+ CheckInRangeFails(ty, "[0, 100, 127]", "[0, 126]");
+ CheckInRangePasses(ty, "[0, 100, 127]", "[0, 127]");
+ };
+
+ CheckCommon(int8());
+
+ CheckCommon(int16());
+ CheckInRangePasses(int16(), "[0, 999, 999]", "[0, 999]");
+ CheckInRangeFails(int16(), "[0, 1000, 1000]", "[0, 999]");
+
+ CheckCommon(int32());
+ CheckInRangePasses(int32(), "[0, 999999, 999999]", "[0, 999999]");
+ CheckInRangeFails(int32(), "[0, 1000000, 1000000]", "[0, 999999]");
+
+ CheckCommon(int64());
+ CheckInRangePasses(int64(), "[0, 9999999999, 9999999999]", "[0, 9999999999]");
+ CheckInRangeFails(int64(), "[0, 10000000000, 10000000000]", "[0, 9999999999]");
+}
+
+TEST(CheckIntegersInRange, UnsignedInts) {
+ auto CheckCommon = [&](const std::shared_ptr<DataType>& ty) {
+ CheckInRangePasses(ty, "[0, 0, 0]", "[0, 0]");
+ CheckInRangeFails(ty, "[0, 1, 0]", "[0, 0]");
+ CheckInRangeFails(ty, "[1, 1, 1]", "[2, 4]");
+ CheckInRangeFails(ty, "[0, 100, 200]", "[0, 199]");
+ CheckInRangePasses(ty, "[0, 100, 200]", "[0, 200]");
+ };
+
+ CheckCommon(uint8());
+ CheckInRangePasses(uint8(), "[255, 255, 255]", "[0, 255]");
+
+ CheckCommon(uint16());
+ CheckInRangePasses(uint16(), "[0, 999, 999]", "[0, 999]");
+ CheckInRangeFails(uint16(), "[0, 1000, 1000]", "[0, 999]");
+ CheckInRangePasses(uint16(), "[0, 65535]", "[0, 65535]");
+
+ CheckCommon(uint32());
+ CheckInRangePasses(uint32(), "[0, 999999, 999999]", "[0, 999999]");
+ CheckInRangeFails(uint32(), "[0, 1000000, 1000000]", "[0, 999999]");
+ CheckInRangePasses(uint32(), "[0, 4294967295]", "[0, 4294967295]");
+
+ CheckCommon(uint64());
+ CheckInRangePasses(uint64(), "[0, 9999999999, 9999999999]", "[0, 9999999999]");
+ CheckInRangeFails(uint64(), "[0, 10000000000, 10000000000]", "[0, 9999999999]");
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/io_util.cc b/src/arrow/cpp/src/arrow/util/io_util.cc
new file mode 100644
index 000000000..f6566ea7e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/io_util.cc
@@ -0,0 +1,1685 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Ensure 64-bit off_t for platforms where it matters
+#ifdef _FILE_OFFSET_BITS
+#undef _FILE_OFFSET_BITS
+#endif
+
+#define _FILE_OFFSET_BITS 64
+
+#if defined(sun) || defined(__sun)
+// According to https://bugs.python.org/issue1759169#msg82201, __EXTENSIONS__
+// is the best way to enable modern POSIX APIs, such as posix_madvise(), on Solaris.
+// (see also
+// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/mman.h)
+#undef __EXTENSIONS__
+#define __EXTENSIONS__
+#endif
+
+#include "arrow/util/windows_compatibility.h" // IWYU pragma: keep
+
+#include <algorithm>
+#include <cerrno>
+#include <cstdint>
+#include <cstring>
+#include <iostream>
+#include <random>
+#include <sstream>
+#include <string>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include <fcntl.h>
+#include <signal.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <sys/types.h> // IWYU pragma: keep
+
+// ----------------------------------------------------------------------
+// file compatibility stuff
+
+#ifdef _WIN32
+#include <io.h>
+#include <share.h>
+#else // POSIX-like platforms
+#include <dirent.h>
+#endif
+
+#ifdef _WIN32
+#include "arrow/io/mman.h"
+#undef Realloc
+#undef Free
+#else // POSIX-like platforms
+#include <sys/mman.h>
+#include <unistd.h>
+#endif
+
+// define max read/write count
+#ifdef _WIN32
+#define ARROW_MAX_IO_CHUNKSIZE INT32_MAX
+#else
+
+#ifdef __APPLE__
+// due to macOS bug, we need to set read/write max
+#define ARROW_MAX_IO_CHUNKSIZE INT32_MAX
+#else
+// see notes on Linux read/write manpage
+#define ARROW_MAX_IO_CHUNKSIZE 0x7ffff000
+#endif
+
+#endif
+
+#include "arrow/buffer.h"
+#include "arrow/result.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+
+// For filename conversion
+#if defined(_WIN32)
+#include "arrow/util/utf8.h"
+#endif
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace internal {
+
+namespace {
+
+template <typename CharT>
+std::basic_string<CharT> ReplaceChars(std::basic_string<CharT> s, CharT find, CharT rep) {
+ if (find != rep) {
+ for (size_t i = 0; i < s.length(); ++i) {
+ if (s[i] == find) {
+ s[i] = rep;
+ }
+ }
+ }
+ return s;
+}
+
+Result<NativePathString> StringToNative(const std::string& s) {
+#if _WIN32
+ return ::arrow::util::UTF8ToWideString(s);
+#else
+ return s;
+#endif
+}
+
+#if _WIN32
+Result<std::string> NativeToString(const NativePathString& ws) {
+ return ::arrow::util::WideStringToUTF8(ws);
+}
+#endif
+
+#if _WIN32
+const wchar_t kNativeSep = L'\\';
+const wchar_t kGenericSep = L'/';
+const wchar_t* kAllSeps = L"\\/";
+#else
+const char kNativeSep = '/';
+const char kGenericSep = '/';
+const char* kAllSeps = "/";
+#endif
+
+NativePathString NativeSlashes(NativePathString s) {
+ return ReplaceChars(std::move(s), kGenericSep, kNativeSep);
+}
+
+NativePathString GenericSlashes(NativePathString s) {
+ return ReplaceChars(std::move(s), kNativeSep, kGenericSep);
+}
+
+NativePathString NativeParent(const NativePathString& s) {
+ auto last_sep = s.find_last_of(kAllSeps);
+ if (last_sep == s.length() - 1) {
+ // Last separator is a trailing separator, skip all trailing separators
+ // and try again
+ auto before_last_seps = s.find_last_not_of(kAllSeps);
+ if (before_last_seps == NativePathString::npos) {
+ // Only separators in path
+ return s;
+ }
+ last_sep = s.find_last_of(kAllSeps, before_last_seps);
+ }
+ if (last_sep == NativePathString::npos) {
+ // No (other) separator in path
+ return s;
+ }
+ // There may be multiple contiguous separators, skip all of them
+ auto before_last_seps = s.find_last_not_of(kAllSeps, last_sep);
+ if (before_last_seps == NativePathString::npos) {
+ // All separators are at start of string, keep them all
+ return s.substr(0, last_sep + 1);
+ } else {
+ return s.substr(0, before_last_seps + 1);
+ }
+}
+
+Status ValidatePath(const std::string& s) {
+ if (s.find_first_of('\0') != std::string::npos) {
+ return Status::Invalid("Embedded NUL char in path: '", s, "'");
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+std::string ErrnoMessage(int errnum) { return std::strerror(errnum); }
+
+#if _WIN32
+std::string WinErrorMessage(int errnum) {
+ char buf[1024];
+ auto nchars = FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
+ NULL, errnum, 0, buf, sizeof(buf), NULL);
+ if (nchars == 0) {
+ // Fallback
+ std::stringstream ss;
+ ss << "Windows error #" << errnum;
+ return ss.str();
+ }
+ return std::string(buf, nchars);
+}
+#endif
+
+namespace {
+
+const char kErrnoDetailTypeId[] = "arrow::ErrnoDetail";
+
+class ErrnoDetail : public StatusDetail {
+ public:
+ explicit ErrnoDetail(int errnum) : errnum_(errnum) {}
+
+ const char* type_id() const override { return kErrnoDetailTypeId; }
+
+ std::string ToString() const override {
+ std::stringstream ss;
+ ss << "[errno " << errnum_ << "] " << ErrnoMessage(errnum_);
+ return ss.str();
+ }
+
+ int errnum() const { return errnum_; }
+
+ protected:
+ int errnum_;
+};
+
+#if _WIN32
+const char kWinErrorDetailTypeId[] = "arrow::WinErrorDetail";
+
+class WinErrorDetail : public StatusDetail {
+ public:
+ explicit WinErrorDetail(int errnum) : errnum_(errnum) {}
+
+ const char* type_id() const override { return kWinErrorDetailTypeId; }
+
+ std::string ToString() const override {
+ std::stringstream ss;
+ ss << "[Windows error " << errnum_ << "] " << WinErrorMessage(errnum_);
+ return ss.str();
+ }
+
+ int errnum() const { return errnum_; }
+
+ protected:
+ int errnum_;
+};
+#endif
+
+const char kSignalDetailTypeId[] = "arrow::SignalDetail";
+
+class SignalDetail : public StatusDetail {
+ public:
+ explicit SignalDetail(int signum) : signum_(signum) {}
+
+ const char* type_id() const override { return kSignalDetailTypeId; }
+
+ std::string ToString() const override {
+ std::stringstream ss;
+ ss << "received signal " << signum_;
+ return ss.str();
+ }
+
+ int signum() const { return signum_; }
+
+ protected:
+ int signum_;
+};
+
+} // namespace
+
+std::shared_ptr<StatusDetail> StatusDetailFromErrno(int errnum) {
+ return std::make_shared<ErrnoDetail>(errnum);
+}
+
+#if _WIN32
+std::shared_ptr<StatusDetail> StatusDetailFromWinError(int errnum) {
+ return std::make_shared<WinErrorDetail>(errnum);
+}
+#endif
+
+std::shared_ptr<StatusDetail> StatusDetailFromSignal(int signum) {
+ return std::make_shared<SignalDetail>(signum);
+}
+
+int ErrnoFromStatus(const Status& status) {
+ const auto detail = status.detail();
+ if (detail != nullptr && detail->type_id() == kErrnoDetailTypeId) {
+ return checked_cast<const ErrnoDetail&>(*detail).errnum();
+ }
+ return 0;
+}
+
+int WinErrorFromStatus(const Status& status) {
+#if _WIN32
+ const auto detail = status.detail();
+ if (detail != nullptr && detail->type_id() == kWinErrorDetailTypeId) {
+ return checked_cast<const WinErrorDetail&>(*detail).errnum();
+ }
+#endif
+ return 0;
+}
+
+int SignalFromStatus(const Status& status) {
+ const auto detail = status.detail();
+ if (detail != nullptr && detail->type_id() == kSignalDetailTypeId) {
+ return checked_cast<const SignalDetail&>(*detail).signum();
+ }
+ return 0;
+}
+
+//
+// PlatformFilename implementation
+//
+
+struct PlatformFilename::Impl {
+ Impl() = default;
+ explicit Impl(NativePathString p) : native_(NativeSlashes(std::move(p))) {}
+
+ NativePathString native_;
+
+ // '/'-separated
+ NativePathString generic() const { return GenericSlashes(native_); }
+};
+
+PlatformFilename::PlatformFilename() : impl_(new Impl{}) {}
+
+PlatformFilename::~PlatformFilename() {}
+
+PlatformFilename::PlatformFilename(Impl impl) : impl_(new Impl(std::move(impl))) {}
+
+PlatformFilename::PlatformFilename(const PlatformFilename& other)
+ : PlatformFilename(Impl{other.impl_->native_}) {}
+
+PlatformFilename::PlatformFilename(PlatformFilename&& other)
+ : impl_(std::move(other.impl_)) {}
+
+PlatformFilename& PlatformFilename::operator=(const PlatformFilename& other) {
+ this->impl_.reset(new Impl{other.impl_->native_});
+ return *this;
+}
+
+PlatformFilename& PlatformFilename::operator=(PlatformFilename&& other) {
+ this->impl_ = std::move(other.impl_);
+ return *this;
+}
+
+PlatformFilename::PlatformFilename(const NativePathString& path)
+ : PlatformFilename(Impl{path}) {}
+
+PlatformFilename::PlatformFilename(const NativePathString::value_type* path)
+ : PlatformFilename(NativePathString(path)) {}
+
+bool PlatformFilename::operator==(const PlatformFilename& other) const {
+ return impl_->native_ == other.impl_->native_;
+}
+
+bool PlatformFilename::operator!=(const PlatformFilename& other) const {
+ return impl_->native_ != other.impl_->native_;
+}
+
+const NativePathString& PlatformFilename::ToNative() const { return impl_->native_; }
+
+std::string PlatformFilename::ToString() const {
+#if _WIN32
+ auto result = NativeToString(impl_->generic());
+ if (!result.ok()) {
+ std::stringstream ss;
+ ss << "<Unrepresentable filename: " << result.status().ToString() << ">";
+ return ss.str();
+ }
+ return *std::move(result);
+#else
+ return impl_->generic();
+#endif
+}
+
+PlatformFilename PlatformFilename::Parent() const {
+ return PlatformFilename(NativeParent(ToNative()));
+}
+
+Result<PlatformFilename> PlatformFilename::FromString(const std::string& file_name) {
+ RETURN_NOT_OK(ValidatePath(file_name));
+ ARROW_ASSIGN_OR_RAISE(auto ns, StringToNative(file_name));
+ return PlatformFilename(std::move(ns));
+}
+
+PlatformFilename PlatformFilename::Join(const PlatformFilename& child) const {
+ if (impl_->native_.empty() || impl_->native_.back() == kNativeSep) {
+ return PlatformFilename(Impl{impl_->native_ + child.impl_->native_});
+ } else {
+ return PlatformFilename(Impl{impl_->native_ + kNativeSep + child.impl_->native_});
+ }
+}
+
+Result<PlatformFilename> PlatformFilename::Join(const std::string& child_name) const {
+ ARROW_ASSIGN_OR_RAISE(auto child, PlatformFilename::FromString(child_name));
+ return Join(child);
+}
+
+//
+// Filesystem access routines
+//
+
+namespace {
+
+Result<bool> DoCreateDir(const PlatformFilename& dir_path, bool create_parents) {
+#ifdef _WIN32
+ const auto s = dir_path.ToNative().c_str();
+ if (CreateDirectoryW(s, nullptr)) {
+ return true;
+ }
+ int errnum = GetLastError();
+ if (errnum == ERROR_ALREADY_EXISTS) {
+ const auto attrs = GetFileAttributesW(s);
+ if (attrs == INVALID_FILE_ATTRIBUTES || !(attrs & FILE_ATTRIBUTE_DIRECTORY)) {
+ // Note we propagate the original error, not the GetFileAttributesW() error
+ return IOErrorFromWinError(ERROR_ALREADY_EXISTS, "Cannot create directory '",
+ dir_path.ToString(), "': non-directory entry exists");
+ }
+ return false;
+ }
+ if (create_parents && errnum == ERROR_PATH_NOT_FOUND) {
+ auto parent_path = dir_path.Parent();
+ if (parent_path != dir_path) {
+ RETURN_NOT_OK(DoCreateDir(parent_path, create_parents));
+ return DoCreateDir(dir_path, false); // Retry
+ }
+ }
+ return IOErrorFromWinError(GetLastError(), "Cannot create directory '",
+ dir_path.ToString(), "'");
+#else
+ const auto s = dir_path.ToNative().c_str();
+ if (mkdir(s, S_IRWXU | S_IRWXG | S_IRWXO) == 0) {
+ return true;
+ }
+ if (errno == EEXIST) {
+ struct stat st;
+ if (stat(s, &st) || !S_ISDIR(st.st_mode)) {
+ // Note we propagate the original errno, not the stat() errno
+ return IOErrorFromErrno(EEXIST, "Cannot create directory '", dir_path.ToString(),
+ "': non-directory entry exists");
+ }
+ return false;
+ }
+ if (create_parents && errno == ENOENT) {
+ auto parent_path = dir_path.Parent();
+ if (parent_path != dir_path) {
+ RETURN_NOT_OK(DoCreateDir(parent_path, create_parents));
+ return DoCreateDir(dir_path, false); // Retry
+ }
+ }
+ return IOErrorFromErrno(errno, "Cannot create directory '", dir_path.ToString(), "'");
+#endif
+}
+
+} // namespace
+
+Result<bool> CreateDir(const PlatformFilename& dir_path) {
+ return DoCreateDir(dir_path, false);
+}
+
+Result<bool> CreateDirTree(const PlatformFilename& dir_path) {
+ return DoCreateDir(dir_path, true);
+}
+
+#ifdef _WIN32
+
+namespace {
+
+void FindHandleDeleter(HANDLE* handle) {
+ if (!FindClose(*handle)) {
+ ARROW_LOG(WARNING) << "Cannot close directory handle: "
+ << WinErrorMessage(GetLastError());
+ }
+}
+
+std::wstring PathWithoutTrailingSlash(const PlatformFilename& fn) {
+ std::wstring path = fn.ToNative();
+ while (!path.empty() && path.back() == kNativeSep) {
+ path.pop_back();
+ }
+ return path;
+}
+
+Result<std::vector<WIN32_FIND_DATAW>> ListDirInternal(const PlatformFilename& dir_path) {
+ WIN32_FIND_DATAW find_data;
+ std::wstring pattern = PathWithoutTrailingSlash(dir_path) + L"\\*.*";
+ HANDLE handle = FindFirstFileW(pattern.c_str(), &find_data);
+ if (handle == INVALID_HANDLE_VALUE) {
+ return IOErrorFromWinError(GetLastError(), "Cannot list directory '",
+ dir_path.ToString(), "'");
+ }
+
+ std::unique_ptr<HANDLE, decltype(&FindHandleDeleter)> handle_guard(&handle,
+ FindHandleDeleter);
+
+ std::vector<WIN32_FIND_DATAW> results;
+ do {
+ // Skip "." and ".."
+ if (find_data.cFileName[0] == L'.') {
+ if (find_data.cFileName[1] == L'\0' ||
+ (find_data.cFileName[1] == L'.' && find_data.cFileName[2] == L'\0')) {
+ continue;
+ }
+ }
+ results.push_back(find_data);
+ } while (FindNextFileW(handle, &find_data));
+
+ int errnum = GetLastError();
+ if (errnum != ERROR_NO_MORE_FILES) {
+ return IOErrorFromWinError(GetLastError(), "Cannot list directory '",
+ dir_path.ToString(), "'");
+ }
+ return results;
+}
+
+Status FindOneFile(const PlatformFilename& fn, WIN32_FIND_DATAW* find_data,
+ bool* exists = nullptr) {
+ HANDLE handle = FindFirstFileW(PathWithoutTrailingSlash(fn).c_str(), find_data);
+ if (handle == INVALID_HANDLE_VALUE) {
+ int errnum = GetLastError();
+ if (exists == nullptr ||
+ (errnum != ERROR_PATH_NOT_FOUND && errnum != ERROR_FILE_NOT_FOUND)) {
+ return IOErrorFromWinError(GetLastError(), "Cannot get information for path '",
+ fn.ToString(), "'");
+ }
+ *exists = false;
+ } else {
+ if (exists != nullptr) {
+ *exists = true;
+ }
+ FindHandleDeleter(&handle);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Result<std::vector<PlatformFilename>> ListDir(const PlatformFilename& dir_path) {
+ ARROW_ASSIGN_OR_RAISE(auto entries, ListDirInternal(dir_path));
+
+ std::vector<PlatformFilename> results;
+ results.reserve(entries.size());
+ for (const auto& entry : entries) {
+ results.emplace_back(std::wstring(entry.cFileName));
+ }
+ return results;
+}
+
+#else
+
+Result<std::vector<PlatformFilename>> ListDir(const PlatformFilename& dir_path) {
+ DIR* dir = opendir(dir_path.ToNative().c_str());
+ if (dir == nullptr) {
+ return IOErrorFromErrno(errno, "Cannot list directory '", dir_path.ToString(), "'");
+ }
+
+ auto dir_deleter = [](DIR* dir) -> void {
+ if (closedir(dir) != 0) {
+ ARROW_LOG(WARNING) << "Cannot close directory handle: " << ErrnoMessage(errno);
+ }
+ };
+ std::unique_ptr<DIR, decltype(dir_deleter)> dir_guard(dir, dir_deleter);
+
+ std::vector<PlatformFilename> results;
+ errno = 0;
+ struct dirent* entry = readdir(dir);
+ while (entry != nullptr) {
+ std::string path = entry->d_name;
+ if (path != "." && path != "..") {
+ results.emplace_back(std::move(path));
+ }
+ entry = readdir(dir);
+ }
+ if (errno != 0) {
+ return IOErrorFromErrno(errno, "Cannot list directory '", dir_path.ToString(), "'");
+ }
+ return results;
+}
+
+#endif
+
+namespace {
+
+#ifdef _WIN32
+
+Status DeleteDirTreeInternal(const PlatformFilename& dir_path);
+
+// Remove a directory entry that's always a directory
+Status DeleteDirEntryDir(const PlatformFilename& path, const WIN32_FIND_DATAW& entry,
+ bool remove_top_dir = true) {
+ if ((entry.dwFileAttributes & FILE_ATTRIBUTE_REPARSE_POINT) == 0) {
+ // It's a directory that doesn't have a reparse point => recurse
+ RETURN_NOT_OK(DeleteDirTreeInternal(path));
+ }
+ if (remove_top_dir) {
+ // Remove now empty directory or reparse point (e.g. symlink to dir)
+ if (!RemoveDirectoryW(path.ToNative().c_str())) {
+ return IOErrorFromWinError(GetLastError(), "Cannot delete directory entry '",
+ path.ToString(), "': ");
+ }
+ }
+ return Status::OK();
+}
+
+Status DeleteDirEntry(const PlatformFilename& path, const WIN32_FIND_DATAW& entry) {
+ if ((entry.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) != 0) {
+ return DeleteDirEntryDir(path, entry);
+ }
+ // It's a non-directory entry, most likely a regular file
+ if (!DeleteFileW(path.ToNative().c_str())) {
+ return IOErrorFromWinError(GetLastError(), "Cannot delete file '", path.ToString(),
+ "': ");
+ }
+ return Status::OK();
+}
+
+Status DeleteDirTreeInternal(const PlatformFilename& dir_path) {
+ ARROW_ASSIGN_OR_RAISE(auto entries, ListDirInternal(dir_path));
+ for (const auto& entry : entries) {
+ PlatformFilename path = dir_path.Join(PlatformFilename(entry.cFileName));
+ RETURN_NOT_OK(DeleteDirEntry(path, entry));
+ }
+ return Status::OK();
+}
+
+Result<bool> DeleteDirContents(const PlatformFilename& dir_path, bool allow_not_found,
+ bool remove_top_dir) {
+ bool exists = true;
+ WIN32_FIND_DATAW entry;
+ if (allow_not_found) {
+ RETURN_NOT_OK(FindOneFile(dir_path, &entry, &exists));
+ } else {
+ // Will raise if dir_path does not exist
+ RETURN_NOT_OK(FindOneFile(dir_path, &entry));
+ }
+ if (exists) {
+ RETURN_NOT_OK(DeleteDirEntryDir(dir_path, entry, remove_top_dir));
+ }
+ return exists;
+}
+
+#else // POSIX
+
+Status LinkStat(const PlatformFilename& path, struct stat* lst, bool* exists = nullptr) {
+ if (lstat(path.ToNative().c_str(), lst) != 0) {
+ if (exists == nullptr || (errno != ENOENT && errno != ENOTDIR && errno != ELOOP)) {
+ return IOErrorFromErrno(errno, "Cannot get information for path '", path.ToString(),
+ "'");
+ }
+ *exists = false;
+ } else if (exists != nullptr) {
+ *exists = true;
+ }
+ return Status::OK();
+}
+
+Status DeleteDirTreeInternal(const PlatformFilename& dir_path);
+
+Status DeleteDirEntryDir(const PlatformFilename& path, const struct stat& lst,
+ bool remove_top_dir = true) {
+ if (!S_ISLNK(lst.st_mode)) {
+ // Not a symlink => delete contents recursively
+ DCHECK(S_ISDIR(lst.st_mode));
+ RETURN_NOT_OK(DeleteDirTreeInternal(path));
+ if (remove_top_dir && rmdir(path.ToNative().c_str()) != 0) {
+ return IOErrorFromErrno(errno, "Cannot delete directory entry '", path.ToString(),
+ "'");
+ }
+ } else {
+ // Remove symlink
+ if (remove_top_dir && unlink(path.ToNative().c_str()) != 0) {
+ return IOErrorFromErrno(errno, "Cannot delete directory entry '", path.ToString(),
+ "'");
+ }
+ }
+ return Status::OK();
+}
+
+Status DeleteDirEntry(const PlatformFilename& path, const struct stat& lst) {
+ if (S_ISDIR(lst.st_mode)) {
+ return DeleteDirEntryDir(path, lst);
+ }
+ if (unlink(path.ToNative().c_str()) != 0) {
+ return IOErrorFromErrno(errno, "Cannot delete directory entry '", path.ToString(),
+ "'");
+ }
+ return Status::OK();
+}
+
+Status DeleteDirTreeInternal(const PlatformFilename& dir_path) {
+ ARROW_ASSIGN_OR_RAISE(auto children, ListDir(dir_path));
+ for (const auto& child : children) {
+ struct stat lst;
+ PlatformFilename full_path = dir_path.Join(child);
+ RETURN_NOT_OK(LinkStat(full_path, &lst));
+ RETURN_NOT_OK(DeleteDirEntry(full_path, lst));
+ }
+ return Status::OK();
+}
+
+Result<bool> DeleteDirContents(const PlatformFilename& dir_path, bool allow_not_found,
+ bool remove_top_dir) {
+ bool exists = true;
+ struct stat lst;
+ if (allow_not_found) {
+ RETURN_NOT_OK(LinkStat(dir_path, &lst, &exists));
+ } else {
+ // Will raise if dir_path does not exist
+ RETURN_NOT_OK(LinkStat(dir_path, &lst));
+ }
+ if (exists) {
+ if (!S_ISDIR(lst.st_mode) && !S_ISLNK(lst.st_mode)) {
+ return Status::IOError("Cannot delete directory '", dir_path.ToString(),
+ "': not a directory");
+ }
+ RETURN_NOT_OK(DeleteDirEntryDir(dir_path, lst, remove_top_dir));
+ }
+ return exists;
+}
+
+#endif
+
+} // namespace
+
+Result<bool> DeleteDirContents(const PlatformFilename& dir_path, bool allow_not_found) {
+ return DeleteDirContents(dir_path, allow_not_found, /*remove_top_dir=*/false);
+}
+
+Result<bool> DeleteDirTree(const PlatformFilename& dir_path, bool allow_not_found) {
+ return DeleteDirContents(dir_path, allow_not_found, /*remove_top_dir=*/true);
+}
+
+Result<bool> DeleteFile(const PlatformFilename& file_path, bool allow_not_found) {
+#ifdef _WIN32
+ if (DeleteFileW(file_path.ToNative().c_str())) {
+ return true;
+ } else {
+ int errnum = GetLastError();
+ if (!allow_not_found || errnum != ERROR_FILE_NOT_FOUND) {
+ return IOErrorFromWinError(GetLastError(), "Cannot delete file '",
+ file_path.ToString(), "'");
+ }
+ }
+#else
+ if (unlink(file_path.ToNative().c_str()) == 0) {
+ return true;
+ } else {
+ if (!allow_not_found || errno != ENOENT) {
+ return IOErrorFromErrno(errno, "Cannot delete file '", file_path.ToString(), "'");
+ }
+ }
+#endif
+ return false;
+}
+
+Result<bool> FileExists(const PlatformFilename& path) {
+#ifdef _WIN32
+ if (GetFileAttributesW(path.ToNative().c_str()) != INVALID_FILE_ATTRIBUTES) {
+ return true;
+ } else {
+ int errnum = GetLastError();
+ if (errnum != ERROR_PATH_NOT_FOUND && errnum != ERROR_FILE_NOT_FOUND) {
+ return IOErrorFromWinError(GetLastError(), "Failed getting information for path '",
+ path.ToString(), "'");
+ }
+ return false;
+ }
+#else
+ struct stat st;
+ if (stat(path.ToNative().c_str(), &st) == 0) {
+ return true;
+ } else {
+ if (errno != ENOENT && errno != ENOTDIR) {
+ return IOErrorFromErrno(errno, "Failed getting information for path '",
+ path.ToString(), "'");
+ }
+ return false;
+ }
+#endif
+}
+
+//
+// Functions for creating file descriptors
+//
+
+#define CHECK_LSEEK(retval) \
+ if ((retval) == -1) return Status::IOError("lseek failed");
+
+static inline int64_t lseek64_compat(int fd, int64_t pos, int whence) {
+#if defined(_WIN32)
+ return _lseeki64(fd, pos, whence);
+#else
+ return lseek(fd, pos, whence);
+#endif
+}
+
+static inline Result<int> CheckFileOpResult(int fd_ret, int errno_actual,
+ const PlatformFilename& file_name,
+ const char* opname) {
+ if (fd_ret == -1) {
+#ifdef _WIN32
+ int winerr = GetLastError();
+ if (winerr != ERROR_SUCCESS) {
+ return IOErrorFromWinError(GetLastError(), "Failed to ", opname, " file '",
+ file_name.ToString(), "'");
+ }
+#endif
+ return IOErrorFromErrno(errno_actual, "Failed to ", opname, " file '",
+ file_name.ToString(), "'");
+ }
+ return fd_ret;
+}
+
+Result<int> FileOpenReadable(const PlatformFilename& file_name) {
+ int fd, errno_actual;
+#if defined(_WIN32)
+ SetLastError(0);
+ HANDLE file_handle = CreateFileW(file_name.ToNative().c_str(), GENERIC_READ,
+ FILE_SHARE_READ | FILE_SHARE_WRITE, NULL,
+ OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
+
+ DWORD last_error = GetLastError();
+ if (last_error == ERROR_SUCCESS) {
+ errno_actual = 0;
+ fd = _open_osfhandle(reinterpret_cast<intptr_t>(file_handle),
+ _O_RDONLY | _O_BINARY | _O_NOINHERIT);
+ } else {
+ return IOErrorFromWinError(last_error, "Failed to open local file '",
+ file_name.ToString(), "'");
+ }
+#else
+ fd = open(file_name.ToNative().c_str(), O_RDONLY);
+ errno_actual = errno;
+
+ if (fd >= 0) {
+ // open(O_RDONLY) succeeds on directories, check for it
+ struct stat st;
+ int ret = fstat(fd, &st);
+ if (ret == -1) {
+ ARROW_UNUSED(FileClose(fd));
+ // Will propagate error below
+ } else if (S_ISDIR(st.st_mode)) {
+ ARROW_UNUSED(FileClose(fd));
+ return Status::IOError("Cannot open for reading: path '", file_name.ToString(),
+ "' is a directory");
+ }
+ }
+#endif
+
+ return CheckFileOpResult(fd, errno_actual, file_name, "open local");
+}
+
+Result<int> FileOpenWritable(const PlatformFilename& file_name, bool write_only,
+ bool truncate, bool append) {
+ int fd, errno_actual;
+
+#if defined(_WIN32)
+ SetLastError(0);
+ int oflag = _O_CREAT | _O_BINARY | _O_NOINHERIT;
+ DWORD desired_access = GENERIC_WRITE;
+ DWORD share_mode = FILE_SHARE_READ | FILE_SHARE_WRITE;
+ DWORD creation_disposition = OPEN_ALWAYS;
+
+ if (append) {
+ oflag |= _O_APPEND;
+ }
+
+ if (truncate) {
+ oflag |= _O_TRUNC;
+ creation_disposition = CREATE_ALWAYS;
+ }
+
+ if (write_only) {
+ oflag |= _O_WRONLY;
+ } else {
+ oflag |= _O_RDWR;
+ desired_access |= GENERIC_READ;
+ }
+
+ HANDLE file_handle =
+ CreateFileW(file_name.ToNative().c_str(), desired_access, share_mode, NULL,
+ creation_disposition, FILE_ATTRIBUTE_NORMAL, NULL);
+
+ DWORD last_error = GetLastError();
+ if (last_error == ERROR_SUCCESS || last_error == ERROR_ALREADY_EXISTS) {
+ errno_actual = 0;
+ fd = _open_osfhandle(reinterpret_cast<intptr_t>(file_handle), oflag);
+ } else {
+ return IOErrorFromWinError(last_error, "Failed to open local file '",
+ file_name.ToString(), "'");
+ }
+#else
+ int oflag = O_CREAT;
+
+ if (truncate) {
+ oflag |= O_TRUNC;
+ }
+ if (append) {
+ oflag |= O_APPEND;
+ }
+
+ if (write_only) {
+ oflag |= O_WRONLY;
+ } else {
+ oflag |= O_RDWR;
+ }
+
+ fd = open(file_name.ToNative().c_str(), oflag, 0666);
+ errno_actual = errno;
+#endif
+
+ RETURN_NOT_OK(CheckFileOpResult(fd, errno_actual, file_name, "open local"));
+ if (append) {
+ // Seek to end, as O_APPEND does not necessarily do it
+ auto ret = lseek64_compat(fd, 0, SEEK_END);
+ if (ret == -1) {
+ ARROW_UNUSED(FileClose(fd));
+ return Status::IOError("lseek failed");
+ }
+ }
+ return fd;
+}
+
+Result<int64_t> FileTell(int fd) {
+ int64_t current_pos;
+#if defined(_WIN32)
+ current_pos = _telli64(fd);
+ if (current_pos == -1) {
+ return Status::IOError("_telli64 failed");
+ }
+#else
+ current_pos = lseek64_compat(fd, 0, SEEK_CUR);
+ CHECK_LSEEK(current_pos);
+#endif
+ return current_pos;
+}
+
+Result<Pipe> CreatePipe() {
+ int ret;
+ int fd[2];
+#if defined(_WIN32)
+ ret = _pipe(fd, 4096, _O_BINARY);
+#else
+ ret = pipe(fd);
+#endif
+
+ if (ret == -1) {
+ return IOErrorFromErrno(errno, "Error creating pipe");
+ }
+ return Pipe{fd[0], fd[1]};
+}
+
+static Status StatusFromMmapErrno(const char* prefix) {
+#ifdef _WIN32
+ errno = __map_mman_error(GetLastError(), EPERM);
+#endif
+ return IOErrorFromErrno(errno, prefix);
+}
+
+namespace {
+
+int64_t GetPageSizeInternal() {
+#if defined(__APPLE__)
+ return getpagesize();
+#elif defined(_WIN32)
+ SYSTEM_INFO si;
+ GetSystemInfo(&si);
+ return si.dwPageSize;
+#else
+ errno = 0;
+ const auto ret = sysconf(_SC_PAGESIZE);
+ if (ret == -1) {
+ ARROW_LOG(FATAL) << "sysconf(_SC_PAGESIZE) failed: " << ErrnoMessage(errno);
+ }
+ return static_cast<int64_t>(ret);
+#endif
+}
+
+} // namespace
+
+int64_t GetPageSize() {
+ static const int64_t kPageSize = GetPageSizeInternal(); // cache it
+ return kPageSize;
+}
+
+//
+// Compatible way to remap a memory map
+//
+
+Status MemoryMapRemap(void* addr, size_t old_size, size_t new_size, int fildes,
+ void** new_addr) {
+ // should only be called with writable files
+ *new_addr = MAP_FAILED;
+#ifdef _WIN32
+ // flags are ignored on windows
+ HANDLE fm, h;
+
+ if (!UnmapViewOfFile(addr)) {
+ return StatusFromMmapErrno("UnmapViewOfFile failed");
+ }
+
+ h = reinterpret_cast<HANDLE>(_get_osfhandle(fildes));
+ if (h == INVALID_HANDLE_VALUE) {
+ return StatusFromMmapErrno("Cannot get file handle");
+ }
+
+ uint64_t new_size64 = new_size;
+ LONG new_size_low = static_cast<LONG>(new_size64 & 0xFFFFFFFFUL);
+ LONG new_size_high = static_cast<LONG>((new_size64 >> 32) & 0xFFFFFFFFUL);
+
+ SetFilePointer(h, new_size_low, &new_size_high, FILE_BEGIN);
+ SetEndOfFile(h);
+ fm = CreateFileMapping(h, NULL, PAGE_READWRITE, 0, 0, "");
+ if (fm == NULL) {
+ return StatusFromMmapErrno("CreateFileMapping failed");
+ }
+ *new_addr = MapViewOfFile(fm, FILE_MAP_WRITE, 0, 0, new_size);
+ CloseHandle(fm);
+ if (new_addr == NULL) {
+ return StatusFromMmapErrno("MapViewOfFile failed");
+ }
+ return Status::OK();
+#elif defined(__linux__)
+ if (ftruncate(fildes, new_size) == -1) {
+ return StatusFromMmapErrno("ftruncate failed");
+ }
+ *new_addr = mremap(addr, old_size, new_size, MREMAP_MAYMOVE);
+ if (*new_addr == MAP_FAILED) {
+ return StatusFromMmapErrno("mremap failed");
+ }
+ return Status::OK();
+#else
+ // we have to close the mmap first, truncate the file to the new size
+ // and recreate the mmap
+ if (munmap(addr, old_size) == -1) {
+ return StatusFromMmapErrno("munmap failed");
+ }
+ if (ftruncate(fildes, new_size) == -1) {
+ return StatusFromMmapErrno("ftruncate failed");
+ }
+ // we set READ / WRITE flags on the new map, since we could only have
+ // unlarged a RW map in the first place
+ *new_addr = mmap(NULL, new_size, PROT_READ | PROT_WRITE, MAP_SHARED, fildes, 0);
+ if (*new_addr == MAP_FAILED) {
+ return StatusFromMmapErrno("mmap failed");
+ }
+ return Status::OK();
+#endif
+}
+
+Status MemoryAdviseWillNeed(const std::vector<MemoryRegion>& regions) {
+ const auto page_size = static_cast<size_t>(GetPageSize());
+ DCHECK_GT(page_size, 0);
+ const size_t page_mask = ~(page_size - 1);
+ DCHECK_EQ(page_mask & page_size, page_size);
+
+ auto align_region = [=](const MemoryRegion& region) -> MemoryRegion {
+ const auto addr = reinterpret_cast<uintptr_t>(region.addr);
+ const auto aligned_addr = addr & page_mask;
+ DCHECK_LT(addr - aligned_addr, page_size);
+ return {reinterpret_cast<void*>(aligned_addr),
+ region.size + static_cast<size_t>(addr - aligned_addr)};
+ };
+
+#ifdef _WIN32
+ // PrefetchVirtualMemory() is available on Windows 8 or later
+ struct PrefetchEntry { // Like WIN32_MEMORY_RANGE_ENTRY
+ void* VirtualAddress;
+ size_t NumberOfBytes;
+
+ PrefetchEntry(const MemoryRegion& region) // NOLINT runtime/explicit
+ : VirtualAddress(region.addr), NumberOfBytes(region.size) {}
+ };
+ using PrefetchVirtualMemoryFunc = BOOL (*)(HANDLE, ULONG_PTR, PrefetchEntry*, ULONG);
+ static const auto prefetch_virtual_memory = reinterpret_cast<PrefetchVirtualMemoryFunc>(
+ GetProcAddress(GetModuleHandleW(L"kernel32.dll"), "PrefetchVirtualMemory"));
+ if (prefetch_virtual_memory != nullptr) {
+ std::vector<PrefetchEntry> entries;
+ entries.reserve(regions.size());
+ for (const auto& region : regions) {
+ if (region.size != 0) {
+ entries.emplace_back(align_region(region));
+ }
+ }
+ if (!entries.empty() &&
+ !prefetch_virtual_memory(GetCurrentProcess(),
+ static_cast<ULONG_PTR>(entries.size()), entries.data(),
+ 0)) {
+ return IOErrorFromWinError(GetLastError(), "PrefetchVirtualMemory failed");
+ }
+ }
+ return Status::OK();
+#elif defined(POSIX_MADV_WILLNEED)
+ for (const auto& region : regions) {
+ if (region.size != 0) {
+ const auto aligned = align_region(region);
+ int err = posix_madvise(aligned.addr, aligned.size, POSIX_MADV_WILLNEED);
+ // EBADF can be returned on Linux in the following cases:
+ // - the kernel version is older than 3.9
+ // - the kernel was compiled with CONFIG_SWAP disabled (ARROW-9577)
+ if (err != 0 && err != EBADF) {
+ return IOErrorFromErrno(err, "posix_madvise failed");
+ }
+ }
+ }
+ return Status::OK();
+#else
+ return Status::OK();
+#endif
+}
+
+//
+// Closing files
+//
+
+Status FileClose(int fd) {
+ int ret;
+
+#if defined(_WIN32)
+ ret = static_cast<int>(_close(fd));
+#else
+ ret = static_cast<int>(close(fd));
+#endif
+
+ if (ret == -1) {
+ return Status::IOError("error closing file");
+ }
+ return Status::OK();
+}
+
+//
+// Seeking and telling
+//
+
+Status FileSeek(int fd, int64_t pos, int whence) {
+ int64_t ret = lseek64_compat(fd, pos, whence);
+ CHECK_LSEEK(ret);
+ return Status::OK();
+}
+
+Status FileSeek(int fd, int64_t pos) { return FileSeek(fd, pos, SEEK_SET); }
+
+Result<int64_t> FileGetSize(int fd) {
+#if defined(_WIN32)
+ struct __stat64 st;
+#else
+ struct stat st;
+#endif
+ st.st_size = -1;
+
+#if defined(_WIN32)
+ int ret = _fstat64(fd, &st);
+#else
+ int ret = fstat(fd, &st);
+#endif
+
+ if (ret == -1) {
+ return Status::IOError("error stat()ing file");
+ }
+ if (st.st_size == 0) {
+ // Maybe the file doesn't support getting its size, double-check by
+ // trying to tell() (seekable files usually have a size, while
+ // non-seekable files don't)
+ RETURN_NOT_OK(FileTell(fd));
+ } else if (st.st_size < 0) {
+ return Status::IOError("error getting file size");
+ }
+ return st.st_size;
+}
+
+//
+// Reading data
+//
+
+static inline int64_t pread_compat(int fd, void* buf, int64_t nbytes, int64_t pos) {
+#if defined(_WIN32)
+ HANDLE handle = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
+ DWORD dwBytesRead = 0;
+ OVERLAPPED overlapped = {0};
+ overlapped.Offset = static_cast<uint32_t>(pos);
+ overlapped.OffsetHigh = static_cast<uint32_t>(pos >> 32);
+
+ // Note: ReadFile() will update the file position
+ BOOL bRet =
+ ReadFile(handle, buf, static_cast<uint32_t>(nbytes), &dwBytesRead, &overlapped);
+ if (bRet || GetLastError() == ERROR_HANDLE_EOF) {
+ return dwBytesRead;
+ } else {
+ return -1;
+ }
+#else
+ return static_cast<int64_t>(
+ pread(fd, buf, static_cast<size_t>(nbytes), static_cast<off_t>(pos)));
+#endif
+}
+
+Result<int64_t> FileRead(int fd, uint8_t* buffer, int64_t nbytes) {
+ int64_t bytes_read = 0;
+
+ while (bytes_read < nbytes) {
+ int64_t chunksize =
+ std::min(static_cast<int64_t>(ARROW_MAX_IO_CHUNKSIZE), nbytes - bytes_read);
+#if defined(_WIN32)
+ int64_t ret =
+ static_cast<int64_t>(_read(fd, buffer, static_cast<uint32_t>(chunksize)));
+#else
+ int64_t ret = static_cast<int64_t>(read(fd, buffer, static_cast<size_t>(chunksize)));
+#endif
+
+ if (ret == -1) {
+ return IOErrorFromErrno(errno, "Error reading bytes from file");
+ }
+ if (ret == 0) {
+ // EOF
+ break;
+ }
+ buffer += ret;
+ bytes_read += ret;
+ }
+ return bytes_read;
+}
+
+Result<int64_t> FileReadAt(int fd, uint8_t* buffer, int64_t position, int64_t nbytes) {
+ int64_t bytes_read = 0;
+
+ while (bytes_read < nbytes) {
+ int64_t chunksize =
+ std::min(static_cast<int64_t>(ARROW_MAX_IO_CHUNKSIZE), nbytes - bytes_read);
+ int64_t ret = pread_compat(fd, buffer, chunksize, position);
+
+ if (ret == -1) {
+ return IOErrorFromErrno(errno, "Error reading bytes from file");
+ }
+ if (ret == 0) {
+ // EOF
+ break;
+ }
+ buffer += ret;
+ position += ret;
+ bytes_read += ret;
+ }
+ return bytes_read;
+}
+
+//
+// Writing data
+//
+
+Status FileWrite(int fd, const uint8_t* buffer, const int64_t nbytes) {
+ int ret = 0;
+ int64_t bytes_written = 0;
+
+ while (ret != -1 && bytes_written < nbytes) {
+ int64_t chunksize =
+ std::min(static_cast<int64_t>(ARROW_MAX_IO_CHUNKSIZE), nbytes - bytes_written);
+#if defined(_WIN32)
+ ret = static_cast<int>(
+ _write(fd, buffer + bytes_written, static_cast<uint32_t>(chunksize)));
+#else
+ ret = static_cast<int>(
+ write(fd, buffer + bytes_written, static_cast<size_t>(chunksize)));
+#endif
+
+ if (ret != -1) {
+ bytes_written += ret;
+ }
+ }
+
+ if (ret == -1) {
+ return IOErrorFromErrno(errno, "Error writing bytes to file");
+ }
+ return Status::OK();
+}
+
+Status FileTruncate(int fd, const int64_t size) {
+ int ret, errno_actual;
+
+#ifdef _WIN32
+ errno_actual = _chsize_s(fd, static_cast<size_t>(size));
+ ret = errno_actual == 0 ? 0 : -1;
+#else
+ ret = ftruncate(fd, static_cast<size_t>(size));
+ errno_actual = errno;
+#endif
+
+ if (ret == -1) {
+ return IOErrorFromErrno(errno_actual, "Error writing bytes to file");
+ }
+ return Status::OK();
+}
+
+//
+// Environment variables
+//
+
+Result<std::string> GetEnvVar(const char* name) {
+#ifdef _WIN32
+ // On Windows, getenv() reads an early copy of the process' environment
+ // which doesn't get updated when SetEnvironmentVariable() is called.
+ constexpr int32_t bufsize = 2000;
+ char c_str[bufsize];
+ auto res = GetEnvironmentVariableA(name, c_str, bufsize);
+ if (res >= bufsize) {
+ return Status::CapacityError("environment variable value too long");
+ } else if (res == 0) {
+ return Status::KeyError("environment variable undefined");
+ }
+ return std::string(c_str);
+#else
+ char* c_str = getenv(name);
+ if (c_str == nullptr) {
+ return Status::KeyError("environment variable undefined");
+ }
+ return std::string(c_str);
+#endif
+}
+
+Result<std::string> GetEnvVar(const std::string& name) { return GetEnvVar(name.c_str()); }
+
+#ifdef _WIN32
+Result<NativePathString> GetEnvVarNative(const std::string& name) {
+ NativePathString w_name;
+ constexpr int32_t bufsize = 2000;
+ wchar_t w_str[bufsize];
+
+ ARROW_ASSIGN_OR_RAISE(w_name, StringToNative(name));
+ auto res = GetEnvironmentVariableW(w_name.c_str(), w_str, bufsize);
+ if (res >= bufsize) {
+ return Status::CapacityError("environment variable value too long");
+ } else if (res == 0) {
+ return Status::KeyError("environment variable undefined");
+ }
+ return NativePathString(w_str);
+}
+
+Result<NativePathString> GetEnvVarNative(const char* name) {
+ return GetEnvVarNative(std::string(name));
+}
+
+#else
+
+Result<NativePathString> GetEnvVarNative(const std::string& name) {
+ return GetEnvVar(name);
+}
+
+Result<NativePathString> GetEnvVarNative(const char* name) { return GetEnvVar(name); }
+#endif
+
+Status SetEnvVar(const char* name, const char* value) {
+#ifdef _WIN32
+ if (SetEnvironmentVariableA(name, value)) {
+ return Status::OK();
+ } else {
+ return Status::Invalid("failed setting environment variable");
+ }
+#else
+ if (setenv(name, value, 1) == 0) {
+ return Status::OK();
+ } else {
+ return Status::Invalid("failed setting environment variable");
+ }
+#endif
+}
+
+Status SetEnvVar(const std::string& name, const std::string& value) {
+ return SetEnvVar(name.c_str(), value.c_str());
+}
+
+Status DelEnvVar(const char* name) {
+#ifdef _WIN32
+ if (SetEnvironmentVariableA(name, nullptr)) {
+ return Status::OK();
+ } else {
+ return Status::Invalid("failed deleting environment variable");
+ }
+#else
+ if (unsetenv(name) == 0) {
+ return Status::OK();
+ } else {
+ return Status::Invalid("failed deleting environment variable");
+ }
+#endif
+}
+
+Status DelEnvVar(const std::string& name) { return DelEnvVar(name.c_str()); }
+
+//
+// Temporary directories
+//
+
+namespace {
+
+#if _WIN32
+NativePathString GetWindowsDirectoryPath() {
+ auto size = GetWindowsDirectoryW(nullptr, 0);
+ ARROW_CHECK_GT(size, 0) << "GetWindowsDirectoryW failed";
+ std::vector<wchar_t> w_str(size);
+ size = GetWindowsDirectoryW(w_str.data(), size);
+ ARROW_CHECK_GT(size, 0) << "GetWindowsDirectoryW failed";
+ return {w_str.data(), size};
+}
+#endif
+
+// Return a list of preferred locations for temporary files
+std::vector<NativePathString> GetPlatformTemporaryDirs() {
+ struct TempDirSelector {
+ std::string env_var;
+ NativePathString path_append;
+ };
+
+ std::vector<TempDirSelector> selectors;
+ NativePathString fallback_tmp;
+
+#if _WIN32
+ selectors = {
+ {"TMP", L""}, {"TEMP", L""}, {"LOCALAPPDATA", L"Temp"}, {"USERPROFILE", L"Temp"}};
+ fallback_tmp = GetWindowsDirectoryPath();
+
+#else
+ selectors = {{"TMPDIR", ""}, {"TMP", ""}, {"TEMP", ""}, {"TEMPDIR", ""}};
+#ifdef __ANDROID__
+ fallback_tmp = "/data/local/tmp";
+#else
+ fallback_tmp = "/tmp";
+#endif
+#endif
+
+ std::vector<NativePathString> temp_dirs;
+ for (const auto& sel : selectors) {
+ auto result = GetEnvVarNative(sel.env_var);
+ if (result.status().IsKeyError()) {
+ // Environment variable absent, skip
+ continue;
+ }
+ if (!result.ok()) {
+ ARROW_LOG(WARNING) << "Failed getting env var '" << sel.env_var
+ << "': " << result.status().ToString();
+ continue;
+ }
+ NativePathString p = *std::move(result);
+ if (p.empty()) {
+ // Environment variable set to empty string, skip
+ continue;
+ }
+ if (sel.path_append.empty()) {
+ temp_dirs.push_back(p);
+ } else {
+ temp_dirs.push_back(p + kNativeSep + sel.path_append);
+ }
+ }
+ temp_dirs.push_back(fallback_tmp);
+ return temp_dirs;
+}
+
+std::string MakeRandomName(int num_chars) {
+ static const std::string chars = "0123456789abcdefghijklmnopqrstuvwxyz";
+ std::default_random_engine gen(
+ static_cast<std::default_random_engine::result_type>(GetRandomSeed()));
+ std::uniform_int_distribution<int> dist(0, static_cast<int>(chars.length() - 1));
+
+ std::string s;
+ s.reserve(num_chars);
+ for (int i = 0; i < num_chars; ++i) {
+ s += chars[dist(gen)];
+ }
+ return s;
+}
+
+} // namespace
+
+Result<std::unique_ptr<TemporaryDir>> TemporaryDir::Make(const std::string& prefix) {
+ const int kNumChars = 8;
+
+ NativePathString base_name;
+
+ auto MakeBaseName = [&]() {
+ std::string suffix = MakeRandomName(kNumChars);
+ return StringToNative(prefix + suffix);
+ };
+
+ auto TryCreatingDirectory =
+ [&](const NativePathString& base_dir) -> Result<std::unique_ptr<TemporaryDir>> {
+ Status st;
+ for (int attempt = 0; attempt < 3; ++attempt) {
+ PlatformFilename fn(base_dir + kNativeSep + base_name + kNativeSep);
+ auto result = CreateDir(fn);
+ if (!result.ok()) {
+ // Probably a permissions error or a non-existing base_dir
+ return nullptr;
+ }
+ if (*result) {
+ return std::unique_ptr<TemporaryDir>(new TemporaryDir(std::move(fn)));
+ }
+ // The random name already exists in base_dir, try with another name
+ st = Status::IOError("Path already exists: '", fn.ToString(), "'");
+ ARROW_ASSIGN_OR_RAISE(base_name, MakeBaseName());
+ }
+ return st;
+ };
+
+ ARROW_ASSIGN_OR_RAISE(base_name, MakeBaseName());
+
+ auto base_dirs = GetPlatformTemporaryDirs();
+ DCHECK_NE(base_dirs.size(), 0);
+
+ for (const auto& base_dir : base_dirs) {
+ ARROW_ASSIGN_OR_RAISE(auto ptr, TryCreatingDirectory(base_dir));
+ if (ptr) {
+ return std::move(ptr);
+ }
+ // Cannot create in this directory, try the next one
+ }
+
+ return Status::IOError(
+ "Cannot create temporary subdirectory in any "
+ "of the platform temporary directories");
+}
+
+TemporaryDir::TemporaryDir(PlatformFilename&& path) : path_(std::move(path)) {}
+
+TemporaryDir::~TemporaryDir() {
+ Status st = DeleteDirTree(path_).status();
+ if (!st.ok()) {
+ ARROW_LOG(WARNING) << "When trying to delete temporary directory: " << st;
+ }
+}
+
+SignalHandler::SignalHandler() : SignalHandler(static_cast<Callback>(nullptr)) {}
+
+SignalHandler::SignalHandler(Callback cb) {
+#if ARROW_HAVE_SIGACTION
+ sa_.sa_handler = cb;
+ sa_.sa_flags = 0;
+ sigemptyset(&sa_.sa_mask);
+#else
+ cb_ = cb;
+#endif
+}
+
+#if ARROW_HAVE_SIGACTION
+SignalHandler::SignalHandler(const struct sigaction& sa) {
+ memcpy(&sa_, &sa, sizeof(sa));
+}
+#endif
+
+SignalHandler::Callback SignalHandler::callback() const {
+#if ARROW_HAVE_SIGACTION
+ return sa_.sa_handler;
+#else
+ return cb_;
+#endif
+}
+
+#if ARROW_HAVE_SIGACTION
+const struct sigaction& SignalHandler::action() const { return sa_; }
+#endif
+
+Result<SignalHandler> GetSignalHandler(int signum) {
+#if ARROW_HAVE_SIGACTION
+ struct sigaction sa;
+ int ret = sigaction(signum, nullptr, &sa);
+ if (ret != 0) {
+ // TODO more detailed message using errno
+ return Status::IOError("sigaction call failed");
+ }
+ return SignalHandler(sa);
+#else
+ // To read the old handler, set the signal handler to something else temporarily
+ SignalHandler::Callback cb = signal(signum, SIG_IGN);
+ if (cb == SIG_ERR || signal(signum, cb) == SIG_ERR) {
+ // TODO more detailed message using errno
+ return Status::IOError("signal call failed");
+ }
+ return SignalHandler(cb);
+#endif
+}
+
+Result<SignalHandler> SetSignalHandler(int signum, const SignalHandler& handler) {
+#if ARROW_HAVE_SIGACTION
+ struct sigaction old_sa;
+ int ret = sigaction(signum, &handler.action(), &old_sa);
+ if (ret != 0) {
+ // TODO more detailed message using errno
+ return Status::IOError("sigaction call failed");
+ }
+ return SignalHandler(old_sa);
+#else
+ SignalHandler::Callback cb = signal(signum, handler.callback());
+ if (cb == SIG_ERR) {
+ // TODO more detailed message using errno
+ return Status::IOError("signal call failed");
+ }
+ return SignalHandler(cb);
+#endif
+ return Status::OK();
+}
+
+void ReinstateSignalHandler(int signum, SignalHandler::Callback handler) {
+#if !ARROW_HAVE_SIGACTION
+ // Cannot report any errors from signal() (but there shouldn't be any)
+ signal(signum, handler);
+#endif
+}
+
+Status SendSignal(int signum) {
+ if (raise(signum) == 0) {
+ return Status::OK();
+ }
+ if (errno == EINVAL) {
+ return Status::Invalid("Invalid signal number ", signum);
+ }
+ return IOErrorFromErrno(errno, "Failed to raise signal");
+}
+
+Status SendSignalToThread(int signum, uint64_t thread_id) {
+#ifdef _WIN32
+ return Status::NotImplemented("Cannot send signal to specific thread on Windows");
+#else
+ // Have to use a C-style cast because pthread_t can be a pointer *or* integer type
+ int r = pthread_kill((pthread_t)thread_id, signum); // NOLINT readability-casting
+ if (r == 0) {
+ return Status::OK();
+ }
+ if (r == EINVAL) {
+ return Status::Invalid("Invalid signal number ", signum);
+ }
+ return IOErrorFromErrno(r, "Failed to raise signal");
+#endif
+}
+
+namespace {
+
+int64_t GetPid() {
+#ifdef _WIN32
+ return GetCurrentProcessId();
+#else
+ return getpid();
+#endif
+}
+
+std::mt19937_64 GetSeedGenerator() {
+ // Initialize Mersenne Twister PRNG with a true random seed.
+ // Make sure to mix in process id to minimize risks of clashes when parallel testing.
+#ifdef ARROW_VALGRIND
+ // Valgrind can crash, hang or enter an infinite loop on std::random_device,
+ // use a crude initializer instead.
+ const uint8_t dummy = 0;
+ ARROW_UNUSED(dummy);
+ std::mt19937_64 seed_gen(reinterpret_cast<uintptr_t>(&dummy) ^
+ static_cast<uintptr_t>(GetPid()));
+#else
+ std::random_device true_random;
+ std::mt19937_64 seed_gen(static_cast<uint64_t>(true_random()) ^
+ (static_cast<uint64_t>(true_random()) << 32) ^
+ static_cast<uint64_t>(GetPid()));
+#endif
+ return seed_gen;
+}
+
+} // namespace
+
+int64_t GetRandomSeed() {
+ // The process-global seed generator to aims to avoid calling std::random_device
+ // unless truly necessary (it can block on some systems, see ARROW-10287).
+ static auto seed_gen = GetSeedGenerator();
+ return static_cast<int64_t>(seed_gen());
+}
+
+uint64_t GetThreadId() {
+ uint64_t equiv{0};
+ // std::thread::id is trivially copyable as per C++ spec,
+ // so type punning as a uint64_t should work
+ static_assert(sizeof(std::thread::id) <= sizeof(uint64_t),
+ "std::thread::id can't fit into uint64_t");
+ const auto tid = std::this_thread::get_id();
+ memcpy(&equiv, reinterpret_cast<const void*>(&tid), sizeof(tid));
+ return equiv;
+}
+
+uint64_t GetOptionalThreadId() {
+ auto tid = GetThreadId();
+ return (tid == 0) ? tid - 1 : tid;
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/io_util.h b/src/arrow/cpp/src/arrow/util/io_util.h
new file mode 100644
index 000000000..4255dd371
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/io_util.h
@@ -0,0 +1,349 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#ifndef _WIN32
+#define ARROW_HAVE_SIGACTION 1
+#endif
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#if ARROW_HAVE_SIGACTION
+#include <signal.h> // Needed for struct sigaction
+#endif
+
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/windows_fixup.h"
+
+namespace arrow {
+namespace internal {
+
+// NOTE: 8-bit path strings on Windows are encoded using UTF-8.
+// Using MBCS would fail encoding some paths.
+
+#if defined(_WIN32)
+using NativePathString = std::wstring;
+#else
+using NativePathString = std::string;
+#endif
+
+class ARROW_EXPORT PlatformFilename {
+ public:
+ struct Impl;
+
+ ~PlatformFilename();
+ PlatformFilename();
+ PlatformFilename(const PlatformFilename&);
+ PlatformFilename(PlatformFilename&&);
+ PlatformFilename& operator=(const PlatformFilename&);
+ PlatformFilename& operator=(PlatformFilename&&);
+ explicit PlatformFilename(const NativePathString& path);
+ explicit PlatformFilename(const NativePathString::value_type* path);
+
+ const NativePathString& ToNative() const;
+ std::string ToString() const;
+
+ PlatformFilename Parent() const;
+
+ // These functions can fail for character encoding reasons.
+ static Result<PlatformFilename> FromString(const std::string& file_name);
+ Result<PlatformFilename> Join(const std::string& child_name) const;
+
+ PlatformFilename Join(const PlatformFilename& child_name) const;
+
+ bool operator==(const PlatformFilename& other) const;
+ bool operator!=(const PlatformFilename& other) const;
+
+ // Made public to avoid the proliferation of friend declarations.
+ const Impl* impl() const { return impl_.get(); }
+
+ private:
+ std::unique_ptr<Impl> impl_;
+
+ explicit PlatformFilename(Impl impl);
+};
+
+/// Create a directory if it doesn't exist.
+///
+/// Return whether the directory was created.
+ARROW_EXPORT
+Result<bool> CreateDir(const PlatformFilename& dir_path);
+
+/// Create a directory and its parents if it doesn't exist.
+///
+/// Return whether the directory was created.
+ARROW_EXPORT
+Result<bool> CreateDirTree(const PlatformFilename& dir_path);
+
+/// Delete a directory's contents (but not the directory itself) if it exists.
+///
+/// Return whether the directory existed.
+ARROW_EXPORT
+Result<bool> DeleteDirContents(const PlatformFilename& dir_path,
+ bool allow_not_found = true);
+
+/// Delete a directory tree if it exists.
+///
+/// Return whether the directory existed.
+ARROW_EXPORT
+Result<bool> DeleteDirTree(const PlatformFilename& dir_path, bool allow_not_found = true);
+
+// Non-recursively list the contents of the given directory.
+// The returned names are the children's base names, not including dir_path.
+ARROW_EXPORT
+Result<std::vector<PlatformFilename>> ListDir(const PlatformFilename& dir_path);
+
+/// Delete a file if it exists.
+///
+/// Return whether the file existed.
+ARROW_EXPORT
+Result<bool> DeleteFile(const PlatformFilename& file_path, bool allow_not_found = true);
+
+/// Return whether a file exists.
+ARROW_EXPORT
+Result<bool> FileExists(const PlatformFilename& path);
+
+/// Open a file for reading and return a file descriptor.
+ARROW_EXPORT
+Result<int> FileOpenReadable(const PlatformFilename& file_name);
+
+/// Open a file for writing and return a file descriptor.
+ARROW_EXPORT
+Result<int> FileOpenWritable(const PlatformFilename& file_name, bool write_only = true,
+ bool truncate = true, bool append = false);
+
+/// Read from current file position. Return number of bytes read.
+ARROW_EXPORT
+Result<int64_t> FileRead(int fd, uint8_t* buffer, int64_t nbytes);
+/// Read from given file position. Return number of bytes read.
+ARROW_EXPORT
+Result<int64_t> FileReadAt(int fd, uint8_t* buffer, int64_t position, int64_t nbytes);
+
+ARROW_EXPORT
+Status FileWrite(int fd, const uint8_t* buffer, const int64_t nbytes);
+ARROW_EXPORT
+Status FileTruncate(int fd, const int64_t size);
+
+ARROW_EXPORT
+Status FileSeek(int fd, int64_t pos);
+ARROW_EXPORT
+Status FileSeek(int fd, int64_t pos, int whence);
+ARROW_EXPORT
+Result<int64_t> FileTell(int fd);
+ARROW_EXPORT
+Result<int64_t> FileGetSize(int fd);
+
+ARROW_EXPORT
+Status FileClose(int fd);
+
+struct Pipe {
+ int rfd;
+ int wfd;
+};
+
+ARROW_EXPORT
+Result<Pipe> CreatePipe();
+
+ARROW_EXPORT
+int64_t GetPageSize();
+
+struct MemoryRegion {
+ void* addr;
+ size_t size;
+};
+
+ARROW_EXPORT
+Status MemoryMapRemap(void* addr, size_t old_size, size_t new_size, int fildes,
+ void** new_addr);
+ARROW_EXPORT
+Status MemoryAdviseWillNeed(const std::vector<MemoryRegion>& regions);
+
+ARROW_EXPORT
+Result<std::string> GetEnvVar(const char* name);
+ARROW_EXPORT
+Result<std::string> GetEnvVar(const std::string& name);
+ARROW_EXPORT
+Result<NativePathString> GetEnvVarNative(const char* name);
+ARROW_EXPORT
+Result<NativePathString> GetEnvVarNative(const std::string& name);
+
+ARROW_EXPORT
+Status SetEnvVar(const char* name, const char* value);
+ARROW_EXPORT
+Status SetEnvVar(const std::string& name, const std::string& value);
+ARROW_EXPORT
+Status DelEnvVar(const char* name);
+ARROW_EXPORT
+Status DelEnvVar(const std::string& name);
+
+ARROW_EXPORT
+std::string ErrnoMessage(int errnum);
+#if _WIN32
+ARROW_EXPORT
+std::string WinErrorMessage(int errnum);
+#endif
+
+ARROW_EXPORT
+std::shared_ptr<StatusDetail> StatusDetailFromErrno(int errnum);
+#if _WIN32
+ARROW_EXPORT
+std::shared_ptr<StatusDetail> StatusDetailFromWinError(int errnum);
+#endif
+ARROW_EXPORT
+std::shared_ptr<StatusDetail> StatusDetailFromSignal(int signum);
+
+template <typename... Args>
+Status StatusFromErrno(int errnum, StatusCode code, Args&&... args) {
+ return Status::FromDetailAndArgs(code, StatusDetailFromErrno(errnum),
+ std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+Status IOErrorFromErrno(int errnum, Args&&... args) {
+ return StatusFromErrno(errnum, StatusCode::IOError, std::forward<Args>(args)...);
+}
+
+#if _WIN32
+template <typename... Args>
+Status StatusFromWinError(int errnum, StatusCode code, Args&&... args) {
+ return Status::FromDetailAndArgs(code, StatusDetailFromWinError(errnum),
+ std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+Status IOErrorFromWinError(int errnum, Args&&... args) {
+ return StatusFromWinError(errnum, StatusCode::IOError, std::forward<Args>(args)...);
+}
+#endif
+
+template <typename... Args>
+Status StatusFromSignal(int signum, StatusCode code, Args&&... args) {
+ return Status::FromDetailAndArgs(code, StatusDetailFromSignal(signum),
+ std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+Status CancelledFromSignal(int signum, Args&&... args) {
+ return StatusFromSignal(signum, StatusCode::Cancelled, std::forward<Args>(args)...);
+}
+
+ARROW_EXPORT
+int ErrnoFromStatus(const Status&);
+
+// Always returns 0 on non-Windows platforms (for Python).
+ARROW_EXPORT
+int WinErrorFromStatus(const Status&);
+
+ARROW_EXPORT
+int SignalFromStatus(const Status&);
+
+class ARROW_EXPORT TemporaryDir {
+ public:
+ ~TemporaryDir();
+
+ /// '/'-terminated path to the temporary dir
+ const PlatformFilename& path() { return path_; }
+
+ /// Create a temporary subdirectory in the system temporary dir,
+ /// named starting with `prefix`.
+ static Result<std::unique_ptr<TemporaryDir>> Make(const std::string& prefix);
+
+ private:
+ PlatformFilename path_;
+
+ explicit TemporaryDir(PlatformFilename&&);
+};
+
+class ARROW_EXPORT SignalHandler {
+ public:
+ typedef void (*Callback)(int);
+
+ SignalHandler();
+ explicit SignalHandler(Callback cb);
+#if ARROW_HAVE_SIGACTION
+ explicit SignalHandler(const struct sigaction& sa);
+#endif
+
+ Callback callback() const;
+#if ARROW_HAVE_SIGACTION
+ const struct sigaction& action() const;
+#endif
+
+ protected:
+#if ARROW_HAVE_SIGACTION
+ // Storing the full sigaction allows to restore the entire signal handling
+ // configuration.
+ struct sigaction sa_;
+#else
+ Callback cb_;
+#endif
+};
+
+/// \brief Return the current handler for the given signal number.
+ARROW_EXPORT
+Result<SignalHandler> GetSignalHandler(int signum);
+
+/// \brief Set a new handler for the given signal number.
+///
+/// The old signal handler is returned.
+ARROW_EXPORT
+Result<SignalHandler> SetSignalHandler(int signum, const SignalHandler& handler);
+
+/// \brief Reinstate the signal handler
+///
+/// For use in signal handlers. This is needed on platforms without sigaction()
+/// such as Windows, as the default signal handler is restored there as
+/// soon as a signal is raised.
+ARROW_EXPORT
+void ReinstateSignalHandler(int signum, SignalHandler::Callback handler);
+
+/// \brief Send a signal to the current process
+///
+/// The thread which will receive the signal is unspecified.
+ARROW_EXPORT
+Status SendSignal(int signum);
+
+/// \brief Send a signal to the given thread
+///
+/// This function isn't supported on Windows.
+ARROW_EXPORT
+Status SendSignalToThread(int signum, uint64_t thread_id);
+
+/// \brief Get an unpredictable random seed
+///
+/// This function may be slightly costly, so should only be used to initialize
+/// a PRNG, not to generate a large amount of random numbers.
+/// It is better to use this function rather than std::random_device, unless
+/// absolutely necessary (e.g. to generate a cryptographic secret).
+ARROW_EXPORT
+int64_t GetRandomSeed();
+
+/// \brief Get the current thread id
+///
+/// In addition to having the same properties as std::thread, the returned value
+/// is a regular integer value, which is more convenient than an opaque type.
+ARROW_EXPORT
+uint64_t GetThreadId();
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/io_util_test.cc b/src/arrow/cpp/src/arrow/util/io_util_test.cc
new file mode 100644
index 000000000..c09e4b974
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/io_util_test.cc
@@ -0,0 +1,713 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// std::unique_ptr<TemporaryDir> temp_dir;
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <atomic>
+#include <cerrno>
+#include <limits>
+#include <sstream>
+#include <vector>
+
+#include <signal.h>
+
+#ifndef _WIN32
+#include <pthread.h>
+#endif
+
+#include <gmock/gmock-matchers.h>
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/windows_compatibility.h"
+#include "arrow/util/windows_fixup.h"
+
+namespace arrow {
+namespace internal {
+
+void AssertExists(const PlatformFilename& path) {
+ bool exists = false;
+ ASSERT_OK_AND_ASSIGN(exists, FileExists(path));
+ ASSERT_TRUE(exists) << "Path '" << path.ToString() << "' doesn't exist";
+}
+
+void AssertNotExists(const PlatformFilename& path) {
+ bool exists = true;
+ ASSERT_OK_AND_ASSIGN(exists, FileExists(path));
+ ASSERT_FALSE(exists) << "Path '" << path.ToString() << "' exists";
+}
+
+void TouchFile(const PlatformFilename& path) {
+ int fd = -1;
+ ASSERT_OK_AND_ASSIGN(fd, FileOpenWritable(path));
+ ASSERT_OK(FileClose(fd));
+}
+
+TEST(ErrnoFromStatus, Basics) {
+ Status st;
+ st = Status::OK();
+ ASSERT_EQ(ErrnoFromStatus(st), 0);
+ st = Status::KeyError("foo");
+ ASSERT_EQ(ErrnoFromStatus(st), 0);
+ st = Status::IOError("foo");
+ ASSERT_EQ(ErrnoFromStatus(st), 0);
+ st = StatusFromErrno(EINVAL, StatusCode::KeyError, "foo");
+ ASSERT_EQ(ErrnoFromStatus(st), EINVAL);
+ st = IOErrorFromErrno(EPERM, "foo");
+ ASSERT_EQ(ErrnoFromStatus(st), EPERM);
+ st = IOErrorFromErrno(6789, "foo");
+ ASSERT_EQ(ErrnoFromStatus(st), 6789);
+
+ st = CancelledFromSignal(SIGINT, "foo");
+ ASSERT_EQ(ErrnoFromStatus(st), 0);
+}
+
+TEST(SignalFromStatus, Basics) {
+ Status st;
+ st = Status::OK();
+ ASSERT_EQ(SignalFromStatus(st), 0);
+ st = Status::KeyError("foo");
+ ASSERT_EQ(SignalFromStatus(st), 0);
+ st = Status::Cancelled("foo");
+ ASSERT_EQ(SignalFromStatus(st), 0);
+ st = StatusFromSignal(SIGINT, StatusCode::KeyError, "foo");
+ ASSERT_EQ(SignalFromStatus(st), SIGINT);
+ ASSERT_EQ(st.ToString(),
+ "Key error: foo. Detail: received signal " + std::to_string(SIGINT));
+ st = CancelledFromSignal(SIGINT, "bar");
+ ASSERT_EQ(SignalFromStatus(st), SIGINT);
+ ASSERT_EQ(st.ToString(),
+ "Cancelled: bar. Detail: received signal " + std::to_string(SIGINT));
+
+ st = IOErrorFromErrno(EINVAL, "foo");
+ ASSERT_EQ(SignalFromStatus(st), 0);
+}
+
+TEST(GetPageSize, Basics) {
+ const auto page_size = GetPageSize();
+ ASSERT_GE(page_size, 4096);
+ // It's a power of 2
+ ASSERT_EQ((page_size - 1) & page_size, 0);
+}
+
+TEST(MemoryAdviseWillNeed, Basics) {
+ ASSERT_OK_AND_ASSIGN(auto buf1, AllocateBuffer(8192));
+ ASSERT_OK_AND_ASSIGN(auto buf2, AllocateBuffer(1024 * 1024));
+
+ const auto addr1 = buf1->mutable_data();
+ const auto size1 = static_cast<size_t>(buf1->size());
+ const auto addr2 = buf2->mutable_data();
+ const auto size2 = static_cast<size_t>(buf2->size());
+
+ ASSERT_OK(MemoryAdviseWillNeed({}));
+ ASSERT_OK(MemoryAdviseWillNeed({{addr1, size1}, {addr2, size2}}));
+ ASSERT_OK(MemoryAdviseWillNeed({{addr1 + 1, size1 - 1}, {addr2 + 4095, size2 - 4095}}));
+ ASSERT_OK(MemoryAdviseWillNeed({{addr1, 13}, {addr2, 1}}));
+ ASSERT_OK(MemoryAdviseWillNeed({{addr1, 0}, {addr2 + 1, 0}}));
+
+ // Should probably fail
+ // (but on Windows, MemoryAdviseWillNeed can be a no-op)
+#ifndef _WIN32
+ ASSERT_RAISES(IOError,
+ MemoryAdviseWillNeed({{nullptr, std::numeric_limits<size_t>::max()}}));
+#endif
+}
+
+#if _WIN32
+TEST(WinErrorFromStatus, Basics) {
+ Status st;
+ st = Status::OK();
+ ASSERT_EQ(WinErrorFromStatus(st), 0);
+ st = Status::KeyError("foo");
+ ASSERT_EQ(WinErrorFromStatus(st), 0);
+ st = Status::IOError("foo");
+ ASSERT_EQ(WinErrorFromStatus(st), 0);
+ st = StatusFromWinError(ERROR_FILE_NOT_FOUND, StatusCode::KeyError, "foo");
+ ASSERT_EQ(WinErrorFromStatus(st), ERROR_FILE_NOT_FOUND);
+ st = IOErrorFromWinError(ERROR_ACCESS_DENIED, "foo");
+ ASSERT_EQ(WinErrorFromStatus(st), ERROR_ACCESS_DENIED);
+ st = IOErrorFromWinError(6789, "foo");
+ ASSERT_EQ(WinErrorFromStatus(st), 6789);
+}
+#endif
+
+TEST(PlatformFilename, RoundtripAscii) {
+ PlatformFilename fn;
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("a/b"));
+ ASSERT_EQ(fn.ToString(), "a/b");
+#if _WIN32
+ ASSERT_EQ(fn.ToNative(), L"a\\b");
+#else
+ ASSERT_EQ(fn.ToNative(), "a/b");
+#endif
+}
+
+TEST(PlatformFilename, RoundtripUtf8) {
+ PlatformFilename fn;
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("h\xc3\xa9h\xc3\xa9"));
+ ASSERT_EQ(fn.ToString(), "h\xc3\xa9h\xc3\xa9");
+#if _WIN32
+ ASSERT_EQ(fn.ToNative(), L"h\u00e9h\u00e9");
+#else
+ ASSERT_EQ(fn.ToNative(), "h\xc3\xa9h\xc3\xa9");
+#endif
+}
+
+#if _WIN32
+TEST(PlatformFilename, Separators) {
+ PlatformFilename fn;
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("C:/foo/bar"));
+ ASSERT_EQ(fn.ToString(), "C:/foo/bar");
+ ASSERT_EQ(fn.ToNative(), L"C:\\foo\\bar");
+
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("C:\\foo\\bar"));
+ ASSERT_EQ(fn.ToString(), "C:/foo/bar");
+ ASSERT_EQ(fn.ToNative(), L"C:\\foo\\bar");
+}
+#endif
+
+TEST(PlatformFilename, Invalid) {
+ std::string s = "foo";
+ s += '\x00';
+ ASSERT_RAISES(Invalid, PlatformFilename::FromString(s));
+}
+
+TEST(PlatformFilename, Join) {
+ PlatformFilename fn, joined;
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("a/b"));
+ ASSERT_OK_AND_ASSIGN(joined, fn.Join("c/d"));
+ ASSERT_EQ(joined.ToString(), "a/b/c/d");
+#if _WIN32
+ ASSERT_EQ(joined.ToNative(), L"a\\b\\c\\d");
+#else
+ ASSERT_EQ(joined.ToNative(), "a/b/c/d");
+#endif
+
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("a/b/"));
+ ASSERT_OK_AND_ASSIGN(joined, fn.Join("c/d"));
+ ASSERT_EQ(joined.ToString(), "a/b/c/d");
+#if _WIN32
+ ASSERT_EQ(joined.ToNative(), L"a\\b\\c\\d");
+#else
+ ASSERT_EQ(joined.ToNative(), "a/b/c/d");
+#endif
+
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString(""));
+ ASSERT_OK_AND_ASSIGN(joined, fn.Join("c/d"));
+ ASSERT_EQ(joined.ToString(), "c/d");
+#if _WIN32
+ ASSERT_EQ(joined.ToNative(), L"c\\d");
+#else
+ ASSERT_EQ(joined.ToNative(), "c/d");
+#endif
+
+#if _WIN32
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("a\\b"));
+ ASSERT_OK_AND_ASSIGN(joined, fn.Join("c\\d"));
+ ASSERT_EQ(joined.ToString(), "a/b/c/d");
+ ASSERT_EQ(joined.ToNative(), L"a\\b\\c\\d");
+
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("a\\b\\"));
+ ASSERT_OK_AND_ASSIGN(joined, fn.Join("c\\d"));
+ ASSERT_EQ(joined.ToString(), "a/b/c/d");
+ ASSERT_EQ(joined.ToNative(), L"a\\b\\c\\d");
+#endif
+}
+
+TEST(PlatformFilename, JoinInvalid) {
+ PlatformFilename fn;
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("a/b"));
+ std::string s = "foo";
+ s += '\x00';
+ ASSERT_RAISES(Invalid, fn.Join(s));
+}
+
+TEST(PlatformFilename, Parent) {
+ PlatformFilename fn;
+
+ // Relative
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("ab/cd"));
+ ASSERT_EQ(fn.ToString(), "ab/cd");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "ab");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "ab");
+#if _WIN32
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("ab/cd\\ef"));
+ ASSERT_EQ(fn.ToString(), "ab/cd/ef");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "ab/cd");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "ab");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "ab");
+#endif
+
+ // Absolute
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("/ab/cd/ef"));
+ ASSERT_EQ(fn.ToString(), "/ab/cd/ef");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "/ab/cd");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "/ab");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "/");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "/");
+#if _WIN32
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("\\ab\\cd/ef"));
+ ASSERT_EQ(fn.ToString(), "/ab/cd/ef");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "/ab/cd");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "/ab");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "/");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "/");
+#endif
+
+ // Empty
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString(""));
+ ASSERT_EQ(fn.ToString(), "");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "");
+
+ // Multiple separators, relative
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("ab//cd///ef"));
+ ASSERT_EQ(fn.ToString(), "ab//cd///ef");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "ab//cd");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "ab");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "ab");
+#if _WIN32
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("ab\\\\cd\\\\\\ef"));
+ ASSERT_EQ(fn.ToString(), "ab//cd///ef");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "ab//cd");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "ab");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "ab");
+#endif
+
+ // Multiple separators, absolute
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("//ab//cd///ef"));
+ ASSERT_EQ(fn.ToString(), "//ab//cd///ef");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "//ab//cd");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "//ab");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "//");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "//");
+#if _WIN32
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("\\\\ab\\cd\\ef"));
+ ASSERT_EQ(fn.ToString(), "//ab/cd/ef");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "//ab/cd");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "//ab");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "//");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "//");
+#endif
+
+ // Trailing slashes
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("/ab/cd/ef/"));
+ ASSERT_EQ(fn.ToString(), "/ab/cd/ef/");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "/ab/cd");
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("/ab/cd/ef//"));
+ ASSERT_EQ(fn.ToString(), "/ab/cd/ef//");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "/ab/cd");
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("ab/"));
+ ASSERT_EQ(fn.ToString(), "ab/");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "ab/");
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("ab//"));
+ ASSERT_EQ(fn.ToString(), "ab//");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "ab//");
+#if _WIN32
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("\\ab\\cd\\ef\\"));
+ ASSERT_EQ(fn.ToString(), "/ab/cd/ef/");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "/ab/cd");
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("\\ab\\cd\\ef\\\\"));
+ ASSERT_EQ(fn.ToString(), "/ab/cd/ef//");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "/ab/cd");
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("ab\\"));
+ ASSERT_EQ(fn.ToString(), "ab/");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "ab/");
+ ASSERT_OK_AND_ASSIGN(fn, PlatformFilename::FromString("ab\\\\"));
+ ASSERT_EQ(fn.ToString(), "ab//");
+ fn = fn.Parent();
+ ASSERT_EQ(fn.ToString(), "ab//");
+#endif
+}
+
+TEST(CreateDirDeleteDir, Basics) {
+ std::unique_ptr<TemporaryDir> temp_dir;
+ ASSERT_OK_AND_ASSIGN(temp_dir, TemporaryDir::Make("deletedirtest-"));
+ const std::string BASE =
+ temp_dir->path().Join("xxx-io-util-test-dir2").ValueOrDie().ToString();
+ bool created, deleted;
+ PlatformFilename parent, child, child_file;
+
+ ASSERT_OK_AND_ASSIGN(parent, PlatformFilename::FromString(BASE));
+ ASSERT_EQ(parent.ToString(), BASE);
+
+ // Make sure the directory doesn't exist already
+ ARROW_UNUSED(DeleteDirTree(parent));
+
+ AssertNotExists(parent);
+
+ ASSERT_OK_AND_ASSIGN(created, CreateDir(parent));
+ ASSERT_TRUE(created);
+ AssertExists(parent);
+ ASSERT_OK_AND_ASSIGN(created, CreateDir(parent));
+ ASSERT_FALSE(created); // already exists
+ AssertExists(parent);
+
+ ASSERT_OK_AND_ASSIGN(child, PlatformFilename::FromString(BASE + "/some-child"));
+ ASSERT_OK_AND_ASSIGN(created, CreateDir(child));
+ ASSERT_TRUE(created);
+ AssertExists(child);
+
+ ASSERT_OK_AND_ASSIGN(child_file, PlatformFilename::FromString(BASE + "/some-file"));
+ TouchFile(child_file);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ IOError, ::testing::HasSubstr("non-directory entry exists"), CreateDir(child_file));
+
+ ASSERT_OK_AND_ASSIGN(deleted, DeleteDirTree(parent));
+ ASSERT_TRUE(deleted);
+ AssertNotExists(parent);
+ AssertNotExists(child);
+
+ // Parent is deleted, cannot create child again
+ ASSERT_RAISES(IOError, CreateDir(child));
+
+ // It's not an error to call DeleteDirTree on a nonexistent path.
+ ASSERT_OK_AND_ASSIGN(deleted, DeleteDirTree(parent));
+ ASSERT_FALSE(deleted);
+ // ... unless asked so
+ auto status = DeleteDirTree(parent, /*allow_not_found=*/false).status();
+ ASSERT_RAISES(IOError, status);
+#ifdef _WIN32
+ ASSERT_EQ(WinErrorFromStatus(status), ERROR_FILE_NOT_FOUND);
+#else
+ ASSERT_EQ(ErrnoFromStatus(status), ENOENT);
+#endif
+}
+
+TEST(DeleteDirContents, Basics) {
+ std::unique_ptr<TemporaryDir> temp_dir;
+ ASSERT_OK_AND_ASSIGN(temp_dir, TemporaryDir::Make("deletedirtest-"));
+ const std::string BASE =
+ temp_dir->path().Join("xxx-io-util-test-dir2").ValueOrDie().ToString();
+ bool created, deleted;
+ PlatformFilename parent, child1, child2;
+
+ ASSERT_OK_AND_ASSIGN(parent, PlatformFilename::FromString(BASE));
+ ASSERT_EQ(parent.ToString(), BASE);
+
+ // Make sure the directory doesn't exist already
+ ARROW_UNUSED(DeleteDirTree(parent));
+
+ AssertNotExists(parent);
+
+ // Create the parent, a child dir and a child file
+ ASSERT_OK_AND_ASSIGN(created, CreateDir(parent));
+ ASSERT_TRUE(created);
+ ASSERT_OK_AND_ASSIGN(child1, PlatformFilename::FromString(BASE + "/child-dir"));
+ ASSERT_OK_AND_ASSIGN(child2, PlatformFilename::FromString(BASE + "/child-file"));
+ ASSERT_OK_AND_ASSIGN(created, CreateDir(child1));
+ ASSERT_TRUE(created);
+ TouchFile(child2);
+ AssertExists(child1);
+ AssertExists(child2);
+
+ // Cannot call DeleteDirContents on a file
+ ASSERT_RAISES(IOError, DeleteDirContents(child2));
+ AssertExists(child2);
+
+ ASSERT_OK_AND_ASSIGN(deleted, DeleteDirContents(parent));
+ ASSERT_TRUE(deleted);
+ AssertExists(parent);
+ AssertNotExists(child1);
+ AssertNotExists(child2);
+ ASSERT_OK_AND_ASSIGN(deleted, DeleteDirContents(parent));
+ ASSERT_TRUE(deleted);
+ AssertExists(parent);
+
+ // It's not an error to call DeleteDirContents on a nonexistent path.
+ ASSERT_OK_AND_ASSIGN(deleted, DeleteDirContents(child1));
+ ASSERT_FALSE(deleted);
+ // ... unless asked so
+ auto status = DeleteDirContents(child1, /*allow_not_found=*/false).status();
+ ASSERT_RAISES(IOError, status);
+#ifdef _WIN32
+ ASSERT_EQ(WinErrorFromStatus(status), ERROR_FILE_NOT_FOUND);
+#else
+ ASSERT_EQ(ErrnoFromStatus(status), ENOENT);
+#endif
+
+ // Now actually delete the test directory
+ ASSERT_OK_AND_ASSIGN(deleted, DeleteDirTree(parent));
+ ASSERT_TRUE(deleted);
+}
+
+TEST(TemporaryDir, Basics) {
+ std::unique_ptr<TemporaryDir> temp_dir;
+ PlatformFilename fn;
+
+ ASSERT_OK_AND_ASSIGN(temp_dir, TemporaryDir::Make("some-prefix-"));
+ fn = temp_dir->path();
+ // Path has a trailing separator, for convenience
+ ASSERT_EQ(fn.ToString().back(), '/');
+#if defined(_WIN32)
+ ASSERT_EQ(fn.ToNative().back(), L'\\');
+#else
+ ASSERT_EQ(fn.ToNative().back(), '/');
+#endif
+ AssertExists(fn);
+ ASSERT_NE(fn.ToString().find("some-prefix-"), std::string::npos);
+
+ // Create child contents to check that they're cleaned up at the end
+#if defined(_WIN32)
+ PlatformFilename child(fn.ToNative() + L"some-child");
+#else
+ PlatformFilename child(fn.ToNative() + "some-child");
+#endif
+ ASSERT_OK(CreateDir(child));
+ AssertExists(child);
+
+ temp_dir.reset();
+ AssertNotExists(fn);
+ AssertNotExists(child);
+}
+
+TEST(CreateDirTree, Basics) {
+ std::unique_ptr<TemporaryDir> temp_dir;
+ PlatformFilename fn;
+ bool created;
+
+ ASSERT_OK_AND_ASSIGN(temp_dir, TemporaryDir::Make("io-util-test-"));
+
+ ASSERT_OK_AND_ASSIGN(fn, temp_dir->path().Join("AB/CD"));
+ ASSERT_OK_AND_ASSIGN(created, CreateDirTree(fn));
+ ASSERT_TRUE(created);
+ ASSERT_OK_AND_ASSIGN(created, CreateDirTree(fn));
+ ASSERT_FALSE(created);
+
+ ASSERT_OK_AND_ASSIGN(fn, temp_dir->path().Join("AB"));
+ ASSERT_OK_AND_ASSIGN(created, CreateDirTree(fn));
+ ASSERT_FALSE(created);
+
+ ASSERT_OK_AND_ASSIGN(fn, temp_dir->path().Join("EF"));
+ ASSERT_OK_AND_ASSIGN(created, CreateDirTree(fn));
+ ASSERT_TRUE(created);
+
+ ASSERT_OK_AND_ASSIGN(fn, temp_dir->path().Join("AB/file"));
+ TouchFile(fn);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ IOError, ::testing::HasSubstr("non-directory entry exists"), CreateDirTree(fn));
+
+ ASSERT_OK_AND_ASSIGN(fn, temp_dir->path().Join("AB/file/sub"));
+ ASSERT_RAISES(IOError, CreateDirTree(fn));
+}
+
+TEST(ListDir, Basics) {
+ std::unique_ptr<TemporaryDir> temp_dir;
+ PlatformFilename fn;
+ std::vector<PlatformFilename> entries;
+
+ auto check_entries = [](const std::vector<PlatformFilename>& entries,
+ std::vector<std::string> expected) -> void {
+ std::vector<std::string> actual(entries.size());
+ std::transform(entries.begin(), entries.end(), actual.begin(),
+ [](const PlatformFilename& fn) { return fn.ToString(); });
+ // Sort results for deterministic testing
+ std::sort(actual.begin(), actual.end());
+ ASSERT_EQ(actual, expected);
+ };
+
+ ASSERT_OK_AND_ASSIGN(temp_dir, TemporaryDir::Make("io-util-test-"));
+
+ ASSERT_OK_AND_ASSIGN(fn, temp_dir->path().Join("AB/CD"));
+ ASSERT_OK(CreateDirTree(fn));
+ ASSERT_OK_AND_ASSIGN(fn, temp_dir->path().Join("AB/EF/GH"));
+ ASSERT_OK(CreateDirTree(fn));
+ ASSERT_OK_AND_ASSIGN(fn, temp_dir->path().Join("AB/ghi.txt"));
+ TouchFile(fn);
+
+ ASSERT_OK_AND_ASSIGN(fn, temp_dir->path().Join("AB"));
+ ASSERT_OK_AND_ASSIGN(entries, ListDir(fn));
+ ASSERT_EQ(entries.size(), 3);
+ check_entries(entries, {"CD", "EF", "ghi.txt"});
+ ASSERT_OK_AND_ASSIGN(fn, temp_dir->path().Join("AB/EF/GH"));
+ ASSERT_OK_AND_ASSIGN(entries, ListDir(fn));
+ check_entries(entries, {});
+
+ // Errors
+ ASSERT_OK_AND_ASSIGN(fn, temp_dir->path().Join("nonexistent"));
+ ASSERT_RAISES(IOError, ListDir(fn));
+ ASSERT_OK_AND_ASSIGN(fn, temp_dir->path().Join("AB/ghi.txt"));
+ ASSERT_RAISES(IOError, ListDir(fn));
+}
+
+TEST(DeleteFile, Basics) {
+ std::unique_ptr<TemporaryDir> temp_dir;
+ PlatformFilename fn;
+ bool deleted;
+
+ ASSERT_OK_AND_ASSIGN(temp_dir, TemporaryDir::Make("io-util-test-"));
+ ASSERT_OK_AND_ASSIGN(fn, temp_dir->path().Join("test-file"));
+
+ AssertNotExists(fn);
+ TouchFile(fn);
+ AssertExists(fn);
+ ASSERT_OK_AND_ASSIGN(deleted, DeleteFile(fn));
+ ASSERT_TRUE(deleted);
+ AssertNotExists(fn);
+ ASSERT_OK_AND_ASSIGN(deleted, DeleteFile(fn));
+ ASSERT_FALSE(deleted);
+ AssertNotExists(fn);
+ auto status = DeleteFile(fn, /*allow_not_found=*/false).status();
+ ASSERT_RAISES(IOError, status);
+#ifdef _WIN32
+ ASSERT_EQ(WinErrorFromStatus(status), ERROR_FILE_NOT_FOUND);
+#else
+ ASSERT_EQ(ErrnoFromStatus(status), ENOENT);
+#endif
+
+ // Cannot call DeleteFile on directory
+ ASSERT_OK_AND_ASSIGN(fn, temp_dir->path().Join("test-temp_dir"));
+ ASSERT_OK(CreateDir(fn));
+ AssertExists(fn);
+ ASSERT_RAISES(IOError, DeleteFile(fn));
+}
+
+#ifndef __APPLE__
+TEST(FileUtils, LongPaths) {
+ // ARROW-8477: check using long file paths under Windows (> 260 characters).
+ bool created, deleted;
+#ifdef _WIN32
+ const char* kRegKeyName = R"(SYSTEM\CurrentControlSet\Control\FileSystem)";
+ const char* kRegValueName = "LongPathsEnabled";
+ DWORD value = 0;
+ DWORD size = sizeof(value);
+ LSTATUS status = RegGetValueA(HKEY_LOCAL_MACHINE, kRegKeyName, kRegValueName,
+ RRF_RT_REG_DWORD, NULL, &value, &size);
+ bool test_long_paths = (status == ERROR_SUCCESS && value == 1);
+ if (!test_long_paths) {
+ ARROW_LOG(WARNING)
+ << "Tests for accessing files with long path names have been disabled. "
+ << "To enable these tests, set the value of " << kRegValueName
+ << " in registry key \\HKEY_LOCAL_MACHINE\\" << kRegKeyName
+ << " to 1 on the test host.";
+ return;
+ }
+#endif
+
+ const std::string BASE = "xxx-io-util-test-dir-long";
+ PlatformFilename base_path, long_path, long_filename;
+ int fd = -1;
+ std::stringstream fs;
+ fs << BASE;
+ for (int i = 0; i < 64; ++i) {
+ fs << "/123456789ABCDEF";
+ }
+ ASSERT_OK_AND_ASSIGN(base_path,
+ PlatformFilename::FromString(BASE)); // long_path length > 1024
+ ASSERT_OK_AND_ASSIGN(
+ long_path, PlatformFilename::FromString(fs.str())); // long_path length > 1024
+ ASSERT_OK_AND_ASSIGN(created, CreateDirTree(long_path));
+ ASSERT_TRUE(created);
+ AssertExists(long_path);
+ ASSERT_OK_AND_ASSIGN(long_filename,
+ PlatformFilename::FromString(fs.str() + "/file.txt"));
+ TouchFile(long_filename);
+ AssertExists(long_filename);
+ fd = -1;
+ ASSERT_OK_AND_ASSIGN(fd, FileOpenReadable(long_filename));
+ ASSERT_OK(FileClose(fd));
+ ASSERT_OK_AND_ASSIGN(deleted, DeleteDirContents(long_path));
+ ASSERT_TRUE(deleted);
+ ASSERT_OK_AND_ASSIGN(deleted, DeleteDirTree(long_path));
+ ASSERT_TRUE(deleted);
+
+ // Now delete the whole test directory tree
+ ASSERT_OK_AND_ASSIGN(deleted, DeleteDirTree(base_path));
+ ASSERT_TRUE(deleted);
+}
+#endif
+
+static std::atomic<int> signal_received;
+
+static void handle_signal(int signum) {
+ ReinstateSignalHandler(signum, &handle_signal);
+ signal_received.store(signum);
+}
+
+TEST(SendSignal, Generic) {
+ signal_received.store(0);
+ SignalHandlerGuard guard(SIGINT, &handle_signal);
+
+ ASSERT_EQ(signal_received.load(), 0);
+ ASSERT_OK(SendSignal(SIGINT));
+ BusyWait(1.0, [&]() { return signal_received.load() != 0; });
+ ASSERT_EQ(signal_received.load(), SIGINT);
+
+ // Re-try (exercise ReinstateSignalHandler)
+ signal_received.store(0);
+ ASSERT_OK(SendSignal(SIGINT));
+ BusyWait(1.0, [&]() { return signal_received.load() != 0; });
+ ASSERT_EQ(signal_received.load(), SIGINT);
+}
+
+TEST(SendSignal, ToThread) {
+#ifdef _WIN32
+ uint64_t dummy_thread_id = 42;
+ ASSERT_RAISES(NotImplemented, SendSignalToThread(SIGINT, dummy_thread_id));
+#else
+ // Have to use a C-style cast because pthread_t can be a pointer *or* integer type
+ uint64_t thread_id = (uint64_t)(pthread_self()); // NOLINT readability-casting
+ signal_received.store(0);
+ SignalHandlerGuard guard(SIGINT, &handle_signal);
+
+ ASSERT_EQ(signal_received.load(), 0);
+ ASSERT_OK(SendSignalToThread(SIGINT, thread_id));
+ BusyWait(1.0, [&]() { return signal_received.load() != 0; });
+
+ ASSERT_EQ(signal_received.load(), SIGINT);
+#endif
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/io_util_test.manifest b/src/arrow/cpp/src/arrow/util/io_util_test.manifest
new file mode 100644
index 000000000..de5c0d8a8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/io_util_test.manifest
@@ -0,0 +1,39 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+<!--
+ Enable long file paths on the target application
+ See https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file#enable-long-paths-in-windows-10-version-1607-and-later
+ -->
+<assembly xmlns="urn:schemas-microsoft-com:asm.v1" manifestVersion="1.0">
+ <assemblyIdentity type="win32" name="ArrowUtilityTest" version="1.1.1.1"/>
+ <application xmlns="urn:schemas-microsoft-com:asm.v3">
+ <windowsSettings xmlns:ws2="http://schemas.microsoft.com/SMI/2016/WindowsSettings">
+ <ws2:longPathAware>true</ws2:longPathAware>
+ </windowsSettings>
+ </application>
+ <trustInfo xmlns="urn:schemas-microsoft-com:asm.v3">
+ <security>
+ <requestedPrivileges>
+ <requestedExecutionLevel level="asInvoker" uiAccess="false"/>
+ </requestedPrivileges>
+ </security>
+ </trustInfo>
+</assembly>
diff --git a/src/arrow/cpp/src/arrow/util/io_util_test.rc b/src/arrow/cpp/src/arrow/util/io_util_test.rc
new file mode 100644
index 000000000..c3236cb2b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/io_util_test.rc
@@ -0,0 +1,44 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef CREATEPROCESS_MANIFEST_RESOURCE_ID
+#define CREATEPROCESS_MANIFEST_RESOURCE_ID 1
+#endif
+#ifndef RT_MANIFEST
+#define RT_MANIFEST 24
+#endif
+
+CREATEPROCESS_MANIFEST_RESOURCE_ID RT_MANIFEST
+BEGIN
+ "<?xml version=""1.0"" encoding=""UTF-8"" standalone=""yes""?>"
+ "<assembly xmlns=""urn:schemas-microsoft-com:asm.v1"" manifestVersion=""1.0"">"
+ "<assemblyIdentity type=""win32"" name=""ArrowUtilityTest"" version=""1.1.1.1""/>"
+ "<application xmlns=""urn:schemas-microsoft-com:asm.v3"">"
+ "<windowsSettings xmlns:ws2=""http://schemas.microsoft.com/SMI/2016/WindowsSettings"">"
+ "<ws2:longPathAware>true</ws2:longPathAware>"
+ "</windowsSettings>"
+ "</application>"
+ "<trustInfo xmlns=""urn:schemas-microsoft-com:asm.v3"">"
+ "<security>"
+ "<requestedPrivileges>"
+ "<requestedExecutionLevel level=""asInvoker"" uiAccess=""false""/>"
+ "</requestedPrivileges>"
+ "</security>"
+ "</trustInfo>"
+ "</assembly>"
+END
+
diff --git a/src/arrow/cpp/src/arrow/util/iterator.h b/src/arrow/cpp/src/arrow/util/iterator.h
new file mode 100644
index 000000000..2f42803d2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/iterator.h
@@ -0,0 +1,568 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cassert>
+#include <functional>
+#include <memory>
+#include <tuple>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/compare.h"
+#include "arrow/util/functional.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+template <typename T>
+class Iterator;
+
+template <typename T>
+struct IterationTraits {
+ /// \brief a reserved value which indicates the end of iteration. By
+ /// default this is NULLPTR since most iterators yield pointer types.
+ /// Specialize IterationTraits if different end semantics are required.
+ ///
+ /// Note: This should not be used to determine if a given value is a
+ /// terminal value. Use IsIterationEnd (which uses IsEnd) instead. This
+ /// is only for returning terminal values.
+ static T End() { return T(NULLPTR); }
+
+ /// \brief Checks to see if the value is a terminal value.
+ /// A method is used here since T is not neccesarily comparable in many
+ /// cases even though it has a distinct final value
+ static bool IsEnd(const T& val) { return val == End(); }
+};
+
+template <typename T>
+T IterationEnd() {
+ return IterationTraits<T>::End();
+}
+
+template <typename T>
+bool IsIterationEnd(const T& val) {
+ return IterationTraits<T>::IsEnd(val);
+}
+
+template <typename T>
+struct IterationTraits<util::optional<T>> {
+ /// \brief by default when iterating through a sequence of optional,
+ /// nullopt indicates the end of iteration.
+ /// Specialize IterationTraits if different end semantics are required.
+ static util::optional<T> End() { return util::nullopt; }
+
+ /// \brief by default when iterating through a sequence of optional,
+ /// nullopt (!has_value()) indicates the end of iteration.
+ /// Specialize IterationTraits if different end semantics are required.
+ static bool IsEnd(const util::optional<T>& val) { return !val.has_value(); }
+
+ // TODO(bkietz) The range-for loop over Iterator<optional<T>> yields
+ // Result<optional<T>> which is unnecessary (since only the unyielded end optional
+ // is nullopt. Add IterationTraits::GetRangeElement() to handle this case
+};
+
+/// \brief A generic Iterator that can return errors
+template <typename T>
+class Iterator : public util::EqualityComparable<Iterator<T>> {
+ public:
+ /// \brief Iterator may be constructed from any type which has a member function
+ /// with signature Result<T> Next();
+ /// End of iterator is signalled by returning IteratorTraits<T>::End();
+ ///
+ /// The argument is moved or copied to the heap and kept in a unique_ptr<void>. Only
+ /// its destructor and its Next method (which are stored in function pointers) are
+ /// referenced after construction.
+ ///
+ /// This approach is used to dodge MSVC linkage hell (ARROW-6244, ARROW-6558) when using
+ /// an abstract template base class: instead of being inlined as usual for a template
+ /// function the base's virtual destructor will be exported, leading to multiple
+ /// definition errors when linking to any other TU where the base is instantiated.
+ template <typename Wrapped>
+ explicit Iterator(Wrapped has_next)
+ : ptr_(new Wrapped(std::move(has_next)), Delete<Wrapped>), next_(Next<Wrapped>) {}
+
+ Iterator() : ptr_(NULLPTR, [](void*) {}) {}
+
+ /// \brief Return the next element of the sequence, IterationTraits<T>::End() when the
+ /// iteration is completed. Calling this on a default constructed Iterator
+ /// will result in undefined behavior.
+ Result<T> Next() { return next_(ptr_.get()); }
+
+ /// Pass each element of the sequence to a visitor. Will return any error status
+ /// returned by the visitor, terminating iteration.
+ template <typename Visitor>
+ Status Visit(Visitor&& visitor) {
+ for (;;) {
+ ARROW_ASSIGN_OR_RAISE(auto value, Next());
+
+ if (IsIterationEnd(value)) break;
+
+ ARROW_RETURN_NOT_OK(visitor(std::move(value)));
+ }
+
+ return Status::OK();
+ }
+
+ /// Iterators will only compare equal if they are both null.
+ /// Equality comparability is required to make an Iterator of Iterators
+ /// (to check for the end condition).
+ bool Equals(const Iterator& other) const { return ptr_ == other.ptr_; }
+
+ explicit operator bool() const { return ptr_ != NULLPTR; }
+
+ class RangeIterator {
+ public:
+ RangeIterator() : value_(IterationTraits<T>::End()) {}
+
+ explicit RangeIterator(Iterator i)
+ : value_(IterationTraits<T>::End()),
+ iterator_(std::make_shared<Iterator>(std::move(i))) {
+ Next();
+ }
+
+ bool operator!=(const RangeIterator& other) const { return value_ != other.value_; }
+
+ RangeIterator& operator++() {
+ Next();
+ return *this;
+ }
+
+ Result<T> operator*() {
+ ARROW_RETURN_NOT_OK(value_.status());
+
+ auto value = std::move(value_);
+ value_ = IterationTraits<T>::End();
+ return value;
+ }
+
+ private:
+ void Next() {
+ if (!value_.ok()) {
+ value_ = IterationTraits<T>::End();
+ return;
+ }
+ value_ = iterator_->Next();
+ }
+
+ Result<T> value_;
+ std::shared_ptr<Iterator> iterator_;
+ };
+
+ RangeIterator begin() { return RangeIterator(std::move(*this)); }
+
+ RangeIterator end() { return RangeIterator(); }
+
+ /// \brief Move every element of this iterator into a vector.
+ Result<std::vector<T>> ToVector() {
+ std::vector<T> out;
+ for (auto maybe_element : *this) {
+ ARROW_ASSIGN_OR_RAISE(auto element, maybe_element);
+ out.push_back(std::move(element));
+ }
+ // ARROW-8193: On gcc-4.8 without the explicit move it tries to use the
+ // copy constructor, which may be deleted on the elements of type T
+ return std::move(out);
+ }
+
+ private:
+ /// Implementation of deleter for ptr_: Casts from void* to the wrapped type and
+ /// deletes that.
+ template <typename HasNext>
+ static void Delete(void* ptr) {
+ delete static_cast<HasNext*>(ptr);
+ }
+
+ /// Implementation of Next: Casts from void* to the wrapped type and invokes that
+ /// type's Next member function.
+ template <typename HasNext>
+ static Result<T> Next(void* ptr) {
+ return static_cast<HasNext*>(ptr)->Next();
+ }
+
+ /// ptr_ is a unique_ptr to void with a custom deleter: a function pointer which first
+ /// casts from void* to a pointer to the wrapped type then deletes that.
+ std::unique_ptr<void, void (*)(void*)> ptr_;
+
+ /// next_ is a function pointer which first casts from void* to a pointer to the wrapped
+ /// type then invokes its Next member function.
+ Result<T> (*next_)(void*) = NULLPTR;
+};
+
+template <typename T>
+struct TransformFlow {
+ using YieldValueType = T;
+
+ TransformFlow(YieldValueType value, bool ready_for_next)
+ : finished_(false),
+ ready_for_next_(ready_for_next),
+ yield_value_(std::move(value)) {}
+ TransformFlow(bool finished, bool ready_for_next)
+ : finished_(finished), ready_for_next_(ready_for_next), yield_value_() {}
+
+ bool HasValue() const { return yield_value_.has_value(); }
+ bool Finished() const { return finished_; }
+ bool ReadyForNext() const { return ready_for_next_; }
+ T Value() const { return *yield_value_; }
+
+ bool finished_ = false;
+ bool ready_for_next_ = false;
+ util::optional<YieldValueType> yield_value_;
+};
+
+struct TransformFinish {
+ template <typename T>
+ operator TransformFlow<T>() && { // NOLINT explicit
+ return TransformFlow<T>(true, true);
+ }
+};
+
+struct TransformSkip {
+ template <typename T>
+ operator TransformFlow<T>() && { // NOLINT explicit
+ return TransformFlow<T>(false, true);
+ }
+};
+
+template <typename T>
+TransformFlow<T> TransformYield(T value = {}, bool ready_for_next = true) {
+ return TransformFlow<T>(std::move(value), ready_for_next);
+}
+
+template <typename T, typename V>
+using Transformer = std::function<Result<TransformFlow<V>>(T)>;
+
+template <typename T, typename V>
+class TransformIterator {
+ public:
+ explicit TransformIterator(Iterator<T> it, Transformer<T, V> transformer)
+ : it_(std::move(it)),
+ transformer_(std::move(transformer)),
+ last_value_(),
+ finished_() {}
+
+ Result<V> Next() {
+ while (!finished_) {
+ ARROW_ASSIGN_OR_RAISE(util::optional<V> next, Pump());
+ if (next.has_value()) {
+ return std::move(*next);
+ }
+ ARROW_ASSIGN_OR_RAISE(last_value_, it_.Next());
+ }
+ return IterationTraits<V>::End();
+ }
+
+ private:
+ // Calls the transform function on the current value. Can return in several ways
+ // * If the next value is requested (e.g. skip) it will return an empty optional
+ // * If an invalid status is encountered that will be returned
+ // * If finished it will return IterationTraits<V>::End()
+ // * If a value is returned by the transformer that will be returned
+ Result<util::optional<V>> Pump() {
+ if (!finished_ && last_value_.has_value()) {
+ auto next_res = transformer_(*last_value_);
+ if (!next_res.ok()) {
+ finished_ = true;
+ return next_res.status();
+ }
+ auto next = *next_res;
+ if (next.ReadyForNext()) {
+ if (IsIterationEnd(*last_value_)) {
+ finished_ = true;
+ }
+ last_value_.reset();
+ }
+ if (next.Finished()) {
+ finished_ = true;
+ }
+ if (next.HasValue()) {
+ return next.Value();
+ }
+ }
+ if (finished_) {
+ return IterationTraits<V>::End();
+ }
+ return util::nullopt;
+ }
+
+ Iterator<T> it_;
+ Transformer<T, V> transformer_;
+ util::optional<T> last_value_;
+ bool finished_ = false;
+};
+
+/// \brief Transforms an iterator according to a transformer, returning a new Iterator.
+///
+/// The transformer will be called on each element of the source iterator and for each
+/// call it can yield a value, skip, or finish the iteration. When yielding a value the
+/// transformer can choose to consume the source item (the default, ready_for_next = true)
+/// or to keep it and it will be called again on the same value.
+///
+/// This is essentially a more generic form of the map operation that can return 0, 1, or
+/// many values for each of the source items.
+///
+/// The transformer will be exposed to the end of the source sequence
+/// (IterationTraits::End) in case it needs to return some penultimate item(s).
+///
+/// Any invalid status returned by the transformer will be returned immediately.
+template <typename T, typename V>
+Iterator<V> MakeTransformedIterator(Iterator<T> it, Transformer<T, V> op) {
+ return Iterator<V>(TransformIterator<T, V>(std::move(it), std::move(op)));
+}
+
+template <typename T>
+struct IterationTraits<Iterator<T>> {
+ // The end condition for an Iterator of Iterators is a default constructed (null)
+ // Iterator.
+ static Iterator<T> End() { return Iterator<T>(); }
+ static bool IsEnd(const Iterator<T>& val) { return !val; }
+};
+
+template <typename Fn, typename T>
+class FunctionIterator {
+ public:
+ explicit FunctionIterator(Fn fn) : fn_(std::move(fn)) {}
+
+ Result<T> Next() { return fn_(); }
+
+ private:
+ Fn fn_;
+};
+
+/// \brief Construct an Iterator which invokes a callable on Next()
+template <typename Fn,
+ typename Ret = typename internal::call_traits::return_type<Fn>::ValueType>
+Iterator<Ret> MakeFunctionIterator(Fn fn) {
+ return Iterator<Ret>(FunctionIterator<Fn, Ret>(std::move(fn)));
+}
+
+template <typename T>
+Iterator<T> MakeEmptyIterator() {
+ return MakeFunctionIterator([]() -> Result<T> { return IterationTraits<T>::End(); });
+}
+
+template <typename T>
+Iterator<T> MakeErrorIterator(Status s) {
+ return MakeFunctionIterator([s]() -> Result<T> {
+ ARROW_RETURN_NOT_OK(s);
+ return IterationTraits<T>::End();
+ });
+}
+
+/// \brief Simple iterator which yields the elements of a std::vector
+template <typename T>
+class VectorIterator {
+ public:
+ explicit VectorIterator(std::vector<T> v) : elements_(std::move(v)) {}
+
+ Result<T> Next() {
+ if (i_ == elements_.size()) {
+ return IterationTraits<T>::End();
+ }
+ return std::move(elements_[i_++]);
+ }
+
+ private:
+ std::vector<T> elements_;
+ size_t i_ = 0;
+};
+
+template <typename T>
+Iterator<T> MakeVectorIterator(std::vector<T> v) {
+ return Iterator<T>(VectorIterator<T>(std::move(v)));
+}
+
+/// \brief Simple iterator which yields *pointers* to the elements of a std::vector<T>.
+/// This is provided to support T where IterationTraits<T>::End is not specialized
+template <typename T>
+class VectorPointingIterator {
+ public:
+ explicit VectorPointingIterator(std::vector<T> v) : elements_(std::move(v)) {}
+
+ Result<T*> Next() {
+ if (i_ == elements_.size()) {
+ return NULLPTR;
+ }
+ return &elements_[i_++];
+ }
+
+ private:
+ std::vector<T> elements_;
+ size_t i_ = 0;
+};
+
+template <typename T>
+Iterator<T*> MakeVectorPointingIterator(std::vector<T> v) {
+ return Iterator<T*>(VectorPointingIterator<T>(std::move(v)));
+}
+
+/// \brief MapIterator takes ownership of an iterator and a function to apply
+/// on every element. The mapped function is not allowed to fail.
+template <typename Fn, typename I, typename O>
+class MapIterator {
+ public:
+ explicit MapIterator(Fn map, Iterator<I> it)
+ : map_(std::move(map)), it_(std::move(it)) {}
+
+ Result<O> Next() {
+ ARROW_ASSIGN_OR_RAISE(I i, it_.Next());
+
+ if (IsIterationEnd(i)) {
+ return IterationTraits<O>::End();
+ }
+
+ return map_(std::move(i));
+ }
+
+ private:
+ Fn map_;
+ Iterator<I> it_;
+};
+
+/// \brief MapIterator takes ownership of an iterator and a function to apply
+/// on every element. The mapped function is not allowed to fail.
+template <typename Fn, typename From = internal::call_traits::argument_type<0, Fn>,
+ typename To = internal::call_traits::return_type<Fn>>
+Iterator<To> MakeMapIterator(Fn map, Iterator<From> it) {
+ return Iterator<To>(MapIterator<Fn, From, To>(std::move(map), std::move(it)));
+}
+
+/// \brief Like MapIterator, but where the function can fail.
+template <typename Fn, typename From = internal::call_traits::argument_type<0, Fn>,
+ typename To = typename internal::call_traits::return_type<Fn>::ValueType>
+Iterator<To> MakeMaybeMapIterator(Fn map, Iterator<From> it) {
+ return Iterator<To>(MapIterator<Fn, From, To>(std::move(map), std::move(it)));
+}
+
+struct FilterIterator {
+ enum Action { ACCEPT, REJECT };
+
+ template <typename To>
+ static Result<std::pair<To, Action>> Reject() {
+ return std::make_pair(IterationTraits<To>::End(), REJECT);
+ }
+
+ template <typename To>
+ static Result<std::pair<To, Action>> Accept(To out) {
+ return std::make_pair(std::move(out), ACCEPT);
+ }
+
+ template <typename To>
+ static Result<std::pair<To, Action>> MaybeAccept(Result<To> maybe_out) {
+ return std::move(maybe_out).Map(Accept<To>);
+ }
+
+ template <typename To>
+ static Result<std::pair<To, Action>> Error(Status s) {
+ return s;
+ }
+
+ template <typename Fn, typename From, typename To>
+ class Impl {
+ public:
+ explicit Impl(Fn filter, Iterator<From> it) : filter_(filter), it_(std::move(it)) {}
+
+ Result<To> Next() {
+ To out = IterationTraits<To>::End();
+ Action action;
+
+ for (;;) {
+ ARROW_ASSIGN_OR_RAISE(From i, it_.Next());
+
+ if (IsIterationEnd(i)) {
+ return IterationTraits<To>::End();
+ }
+
+ ARROW_ASSIGN_OR_RAISE(std::tie(out, action), filter_(std::move(i)));
+
+ if (action == ACCEPT) return out;
+ }
+ }
+
+ private:
+ Fn filter_;
+ Iterator<From> it_;
+ };
+};
+
+/// \brief Like MapIterator, but where the function can fail or reject elements.
+template <
+ typename Fn, typename From = typename internal::call_traits::argument_type<0, Fn>,
+ typename Ret = typename internal::call_traits::return_type<Fn>::ValueType,
+ typename To = typename std::tuple_element<0, Ret>::type,
+ typename Enable = typename std::enable_if<std::is_same<
+ typename std::tuple_element<1, Ret>::type, FilterIterator::Action>::value>::type>
+Iterator<To> MakeFilterIterator(Fn filter, Iterator<From> it) {
+ return Iterator<To>(
+ FilterIterator::Impl<Fn, From, To>(std::move(filter), std::move(it)));
+}
+
+/// \brief FlattenIterator takes an iterator generating iterators and yields a
+/// unified iterator that flattens/concatenates in a single stream.
+template <typename T>
+class FlattenIterator {
+ public:
+ explicit FlattenIterator(Iterator<Iterator<T>> it) : parent_(std::move(it)) {}
+
+ Result<T> Next() {
+ if (IsIterationEnd(child_)) {
+ // Pop from parent's iterator.
+ ARROW_ASSIGN_OR_RAISE(child_, parent_.Next());
+
+ // Check if final iteration reached.
+ if (IsIterationEnd(child_)) {
+ return IterationTraits<T>::End();
+ }
+
+ return Next();
+ }
+
+ // Pop from child_ and check for depletion.
+ ARROW_ASSIGN_OR_RAISE(T out, child_.Next());
+ if (IsIterationEnd(out)) {
+ // Reset state such that we pop from parent on the recursive call
+ child_ = IterationTraits<Iterator<T>>::End();
+
+ return Next();
+ }
+
+ return out;
+ }
+
+ private:
+ Iterator<Iterator<T>> parent_;
+ Iterator<T> child_ = IterationTraits<Iterator<T>>::End();
+};
+
+template <typename T>
+Iterator<T> MakeFlattenIterator(Iterator<Iterator<T>> it) {
+ return Iterator<T>(FlattenIterator<T>(std::move(it)));
+}
+
+template <typename Reader>
+Iterator<typename Reader::ValueType> MakeIteratorFromReader(
+ const std::shared_ptr<Reader>& reader) {
+ return MakeFunctionIterator([reader] { return reader->Next(); });
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/iterator_test.cc b/src/arrow/cpp/src/arrow/util/iterator_test.cc
new file mode 100644
index 000000000..ab62fcb70
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/iterator_test.cc
@@ -0,0 +1,465 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <atomic>
+#include <chrono>
+#include <condition_variable>
+#include <memory>
+#include <mutex>
+#include <ostream>
+#include <thread>
+#include <unordered_set>
+#include <vector>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/test_common.h"
+#include "arrow/util/vector.h"
+
+namespace arrow {
+
+template <typename T>
+class TracingIterator {
+ public:
+ explicit TracingIterator(Iterator<T> it) : it_(std::move(it)), state_(new State) {}
+
+ Result<T> Next() {
+ auto lock = state_->Lock();
+ state_->thread_ids_.insert(std::this_thread::get_id());
+
+ RETURN_NOT_OK(state_->GetNextStatus());
+
+ ARROW_ASSIGN_OR_RAISE(auto out, it_.Next());
+ state_->values_.push_back(out);
+
+ state_->cv_.notify_one();
+ return out;
+ }
+
+ class State {
+ public:
+ const std::vector<T>& values() { return values_; }
+
+ const std::unordered_set<std::thread::id>& thread_ids() { return thread_ids_; }
+
+ void InsertFailure(Status st) {
+ auto lock = Lock();
+ next_status_ = std::move(st);
+ }
+
+ // Wait until the iterator has emitted at least `size` values
+ void WaitForValues(int size) {
+ auto lock = Lock();
+ cv_.wait(lock, [&]() { return values_.size() >= static_cast<size_t>(size); });
+ }
+
+ void AssertValuesEqual(const std::vector<T>& expected) {
+ auto lock = Lock();
+ ASSERT_EQ(values_, expected);
+ }
+
+ void AssertValuesStartwith(const std::vector<T>& expected) {
+ auto lock = Lock();
+ ASSERT_TRUE(std::equal(expected.begin(), expected.end(), values_.begin()));
+ }
+
+ std::unique_lock<std::mutex> Lock() { return std::unique_lock<std::mutex>(mutex_); }
+
+ private:
+ friend TracingIterator;
+
+ Status GetNextStatus() {
+ if (next_status_.ok()) {
+ return Status::OK();
+ }
+
+ Status st = std::move(next_status_);
+ next_status_ = Status::OK();
+ return st;
+ }
+
+ Status next_status_;
+ std::vector<T> values_;
+ std::unordered_set<std::thread::id> thread_ids_;
+
+ std::mutex mutex_;
+ std::condition_variable cv_;
+ };
+
+ const std::shared_ptr<State>& state() const { return state_; }
+
+ private:
+ Iterator<T> it_;
+
+ std::shared_ptr<State> state_;
+};
+
+template <typename T>
+inline Iterator<T> EmptyIt() {
+ return MakeEmptyIterator<T>();
+}
+
+// Non-templated version of VectorIt<T> to allow better type deduction
+inline Iterator<TestInt> VectorIt(std::vector<TestInt> v) {
+ return MakeVectorIterator<TestInt>(std::move(v));
+}
+
+template <typename Fn, typename T>
+inline Iterator<T> FilterIt(Iterator<T> it, Fn&& fn) {
+ return MakeFilterIterator(std::forward<Fn>(fn), std::move(it));
+}
+
+template <typename T>
+inline Iterator<T> FlattenIt(Iterator<Iterator<T>> its) {
+ return MakeFlattenIterator(std::move(its));
+}
+
+template <typename T>
+void AssertIteratorMatch(std::vector<T> expected, Iterator<T> actual) {
+ EXPECT_EQ(expected, IteratorToVector(std::move(actual)));
+}
+
+template <typename T>
+void AssertIteratorNoMatch(std::vector<T> expected, Iterator<T> actual) {
+ EXPECT_NE(expected, IteratorToVector(std::move(actual)));
+}
+
+template <typename T>
+void AssertIteratorNext(T expected, Iterator<T>& it) {
+ ASSERT_OK_AND_ASSIGN(T actual, it.Next());
+ ASSERT_EQ(expected, actual);
+}
+
+// --------------------------------------------------------------------
+// Synchronous iterator tests
+
+TEST(TestEmptyIterator, Basic) { AssertIteratorMatch({}, EmptyIt<TestInt>()); }
+
+TEST(TestVectorIterator, Basic) {
+ AssertIteratorMatch({}, VectorIt({}));
+ AssertIteratorMatch({1, 2, 3}, VectorIt({1, 2, 3}));
+
+ AssertIteratorNoMatch({1}, VectorIt({}));
+ AssertIteratorNoMatch({}, VectorIt({1, 2, 3}));
+ AssertIteratorNoMatch({1, 2, 2}, VectorIt({1, 2, 3}));
+ AssertIteratorNoMatch({1, 2, 3, 1}, VectorIt({1, 2, 3}));
+
+ // int does not have specialized IterationTraits
+ std::vector<int> elements = {0, 1, 2, 3, 4, 5};
+ std::vector<int*> expected;
+ for (int& element : elements) {
+ expected.push_back(&element);
+ }
+ AssertIteratorMatch(expected, MakeVectorPointingIterator(std::move(elements)));
+}
+
+TEST(TestVectorIterator, RangeForLoop) {
+ std::vector<TestInt> ints = {1, 2, 3, 4};
+
+ auto ints_it = ints.begin();
+ for (auto maybe_i : VectorIt(ints)) {
+ ASSERT_OK_AND_ASSIGN(TestInt i, maybe_i);
+ ASSERT_EQ(i, *ints_it++);
+ }
+ ASSERT_EQ(ints_it, ints.end()) << *ints_it << "@" << (ints_it - ints.begin());
+
+ std::vector<std::unique_ptr<TestInt>> intptrs;
+ for (TestInt i : ints) {
+ intptrs.emplace_back(new TestInt(i));
+ }
+
+ // also works with move only types
+ ints_it = ints.begin();
+ for (auto maybe_i_ptr : MakeVectorIterator(std::move(intptrs))) {
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr<TestInt> i_ptr, maybe_i_ptr);
+ ASSERT_EQ(*i_ptr, *ints_it++);
+ }
+ ASSERT_EQ(ints_it, ints.end());
+}
+
+Transformer<TestInt, TestStr> MakeFirstN(int n) {
+ int remaining = n;
+ return [remaining](TestInt next) mutable -> Result<TransformFlow<TestStr>> {
+ if (remaining > 0) {
+ remaining--;
+ return TransformYield(TestStr(next));
+ }
+ return TransformFinish();
+ };
+}
+
+template <typename T>
+Transformer<T, T> MakeFirstNGeneric(int n) {
+ int remaining = n;
+ return [remaining](T next) mutable -> Result<TransformFlow<T>> {
+ if (remaining > 0) {
+ remaining--;
+ return TransformYield(next);
+ }
+ return TransformFinish();
+ };
+}
+
+TEST(TestIteratorTransform, Truncating) {
+ auto original = VectorIt({1, 2, 3});
+ auto truncated = MakeTransformedIterator(std::move(original), MakeFirstN(2));
+ AssertIteratorMatch({"1", "2"}, std::move(truncated));
+}
+
+TEST(TestIteratorTransform, TestPointer) {
+ auto original = VectorIt<std::shared_ptr<int>>(
+ {std::make_shared<int>(1), std::make_shared<int>(2), std::make_shared<int>(3)});
+ auto truncated = MakeTransformedIterator(std::move(original),
+ MakeFirstNGeneric<std::shared_ptr<int>>(2));
+ ASSERT_OK_AND_ASSIGN(auto result, truncated.ToVector());
+ ASSERT_EQ(2, result.size());
+}
+
+TEST(TestIteratorTransform, TruncatingShort) {
+ // Tests the failsafe case where we never call Finish
+ auto original = VectorIt({1});
+ auto truncated =
+ MakeTransformedIterator<TestInt, TestStr>(std::move(original), MakeFirstN(2));
+ AssertIteratorMatch({"1"}, std::move(truncated));
+}
+
+TEST(TestIteratorTransform, SkipSome) {
+ // Exercises TransformSkip
+ auto original = VectorIt({1, 2, 3});
+ auto filter = MakeFilter([](TestInt& t) { return t.value != 2; });
+ auto filtered = MakeTransformedIterator(std::move(original), filter);
+ AssertIteratorMatch({"1", "3"}, std::move(filtered));
+}
+
+TEST(TestIteratorTransform, SkipAll) {
+ // Exercises TransformSkip
+ auto original = VectorIt({1, 2, 3});
+ auto filter = MakeFilter([](TestInt& t) { return false; });
+ auto filtered = MakeTransformedIterator(std::move(original), filter);
+ AssertIteratorMatch({}, std::move(filtered));
+}
+
+Transformer<TestInt, TestStr> MakeAbortOnSecond() {
+ int counter = 0;
+ return [counter](TestInt next) mutable -> Result<TransformFlow<TestStr>> {
+ if (counter++ == 1) {
+ return Status::Invalid("X");
+ }
+ return TransformYield(TestStr(next));
+ };
+}
+
+TEST(TestIteratorTransform, Abort) {
+ auto original = VectorIt({1, 2, 3});
+ auto transformed = MakeTransformedIterator(std::move(original), MakeAbortOnSecond());
+ ASSERT_OK(transformed.Next());
+ ASSERT_RAISES(Invalid, transformed.Next());
+ ASSERT_OK_AND_ASSIGN(auto third, transformed.Next());
+ ASSERT_TRUE(IsIterationEnd(third));
+}
+
+template <typename T>
+Transformer<T, T> MakeRepeatN(int repeat_count) {
+ int current_repeat = 0;
+ return [repeat_count, current_repeat](T next) mutable -> Result<TransformFlow<T>> {
+ current_repeat++;
+ bool ready_for_next = false;
+ if (current_repeat == repeat_count) {
+ current_repeat = 0;
+ ready_for_next = true;
+ }
+ return TransformYield(next, ready_for_next);
+ };
+}
+
+TEST(TestIteratorTransform, Repeating) {
+ auto original = VectorIt({1, 2, 3});
+ auto repeated = MakeTransformedIterator<TestInt, TestInt>(std::move(original),
+ MakeRepeatN<TestInt>(2));
+ AssertIteratorMatch({1, 1, 2, 2, 3, 3}, std::move(repeated));
+}
+
+TEST(TestFunctionIterator, RangeForLoop) {
+ int i = 0;
+ auto fails_at_3 = MakeFunctionIterator([&]() -> Result<TestInt> {
+ if (i >= 3) {
+ return Status::IndexError("fails at 3");
+ }
+ return i++;
+ });
+
+ int expected_i = 0;
+ for (auto maybe_i : fails_at_3) {
+ if (expected_i < 3) {
+ ASSERT_OK(maybe_i.status());
+ ASSERT_EQ(*maybe_i, expected_i);
+ } else if (expected_i == 3) {
+ ASSERT_RAISES(IndexError, maybe_i.status());
+ }
+ ASSERT_LE(expected_i, 3) << "iteration stops after an error is encountered";
+ ++expected_i;
+ }
+}
+
+TEST(FilterIterator, Basic) {
+ AssertIteratorMatch({1, 2, 3, 4}, FilterIt(VectorIt({1, 2, 3, 4}), [](TestInt i) {
+ return FilterIterator::Accept(std::move(i));
+ }));
+
+ AssertIteratorMatch({}, FilterIt(VectorIt({1, 2, 3, 4}), [](TestInt i) {
+ return FilterIterator::Reject<TestInt>();
+ }));
+
+ AssertIteratorMatch({2, 4}, FilterIt(VectorIt({1, 2, 3, 4}), [](TestInt i) {
+ return i.value % 2 == 0 ? FilterIterator::Accept(std::move(i))
+ : FilterIterator::Reject<TestInt>();
+ }));
+}
+
+TEST(FlattenVectorIterator, Basic) {
+ // Flatten expects to consume Iterator<Iterator<T>>
+ AssertIteratorMatch({}, FlattenIt(EmptyIt<Iterator<TestInt>>()));
+
+ std::vector<Iterator<TestInt>> ok;
+ ok.push_back(VectorIt({1}));
+ ok.push_back(VectorIt({2}));
+ ok.push_back(VectorIt({3}));
+ AssertIteratorMatch({1, 2, 3}, FlattenIt(VectorIt(std::move(ok))));
+
+ std::vector<Iterator<TestInt>> not_enough;
+ not_enough.push_back(VectorIt({1}));
+ not_enough.push_back(VectorIt({2}));
+ AssertIteratorNoMatch({1, 2, 3}, FlattenIt(VectorIt(std::move(not_enough))));
+
+ std::vector<Iterator<TestInt>> too_much;
+ too_much.push_back(VectorIt({1}));
+ too_much.push_back(VectorIt({2}));
+ too_much.push_back(VectorIt({3}));
+ too_much.push_back(VectorIt({2}));
+ AssertIteratorNoMatch({1, 2, 3}, FlattenIt(VectorIt(std::move(too_much))));
+}
+
+Iterator<TestInt> Join(TestInt a, TestInt b) {
+ std::vector<Iterator<TestInt>> joined{2};
+ joined[0] = VectorIt({a});
+ joined[1] = VectorIt({b});
+
+ return FlattenIt(VectorIt(std::move(joined)));
+}
+
+Iterator<TestInt> Join(TestInt a, Iterator<TestInt> b) {
+ std::vector<Iterator<TestInt>> joined{2};
+ joined[0] = VectorIt(std::vector<TestInt>{a});
+ joined[1] = std::move(b);
+
+ return FlattenIt(VectorIt(std::move(joined)));
+}
+
+TEST(FlattenVectorIterator, Pyramid) {
+ auto it = Join(1, Join(2, Join(2, Join(3, Join(3, 3)))));
+ AssertIteratorMatch({1, 2, 2, 3, 3, 3}, std::move(it));
+}
+
+TEST(ReadaheadIterator, Empty) {
+ ASSERT_OK_AND_ASSIGN(auto it, MakeReadaheadIterator(VectorIt({}), 2));
+ AssertIteratorMatch({}, std::move(it));
+}
+
+TEST(ReadaheadIterator, Basic) {
+ ASSERT_OK_AND_ASSIGN(auto it, MakeReadaheadIterator(VectorIt({1, 2, 3, 4, 5}), 2));
+ AssertIteratorMatch({1, 2, 3, 4, 5}, std::move(it));
+}
+
+TEST(ReadaheadIterator, NotExhausted) {
+ ASSERT_OK_AND_ASSIGN(auto it, MakeReadaheadIterator(VectorIt({1, 2, 3, 4, 5}), 2));
+ AssertIteratorNext({1}, it);
+ AssertIteratorNext({2}, it);
+}
+
+TEST(ReadaheadIterator, Trace) {
+ TracingIterator<TestInt> tracing_it(VectorIt({1, 2, 3, 4, 5, 6, 7, 8}));
+ auto tracing = tracing_it.state();
+ ASSERT_EQ(tracing->values().size(), 0);
+
+ ASSERT_OK_AND_ASSIGN(
+ auto it, MakeReadaheadIterator(Iterator<TestInt>(std::move(tracing_it)), 2));
+ SleepABit(); // Background iterator won't start pumping until first request comes in
+ ASSERT_EQ(tracing->values().size(), 0);
+
+ AssertIteratorNext({1}, it); // Once we ask for one value we should get that one value
+ // as well as 2 read ahead
+
+ tracing->WaitForValues(3);
+ tracing->AssertValuesEqual({1, 2, 3});
+
+ SleepABit(); // No further values should be fetched
+ tracing->AssertValuesEqual({1, 2, 3});
+
+ AssertIteratorNext({2}, it);
+ AssertIteratorNext({3}, it);
+ AssertIteratorNext({4}, it);
+ tracing->WaitForValues(6);
+ SleepABit();
+ tracing->AssertValuesEqual({1, 2, 3, 4, 5, 6});
+
+ AssertIteratorNext({5}, it);
+ AssertIteratorNext({6}, it);
+ AssertIteratorNext({7}, it);
+ tracing->WaitForValues(9);
+ SleepABit();
+ tracing->AssertValuesEqual({1, 2, 3, 4, 5, 6, 7, 8, {}});
+
+ AssertIteratorNext({8}, it);
+ AssertIteratorExhausted(it);
+ AssertIteratorExhausted(it); // Again
+ tracing->WaitForValues(9);
+ SleepABit();
+ tracing->AssertValuesStartwith({1, 2, 3, 4, 5, 6, 7, 8, {}});
+ // A couple more EOF values may have been emitted
+ const auto& values = tracing->values();
+ ASSERT_LE(values.size(), 11);
+ for (size_t i = 9; i < values.size(); ++i) {
+ ASSERT_EQ(values[i], TestInt());
+ }
+
+ // Values were all emitted from the same thread, and it's not this thread
+ const auto& thread_ids = tracing->thread_ids();
+ ASSERT_EQ(thread_ids.size(), 1);
+ ASSERT_NE(*thread_ids.begin(), std::this_thread::get_id());
+}
+
+TEST(ReadaheadIterator, NextError) {
+ TracingIterator<TestInt> tracing_it((VectorIt({1, 2, 3})));
+ auto tracing = tracing_it.state();
+ ASSERT_EQ(tracing->values().size(), 0);
+
+ tracing->InsertFailure(Status::IOError("xxx"));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto it, MakeReadaheadIterator(Iterator<TestInt>(std::move(tracing_it)), 2));
+
+ ASSERT_RAISES(IOError, it.Next().status());
+
+ AssertIteratorExhausted(it);
+ SleepABit();
+ tracing->AssertValuesEqual({});
+ AssertIteratorExhausted(it);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/key_value_metadata.cc b/src/arrow/cpp/src/arrow/util/key_value_metadata.cc
new file mode 100644
index 000000000..bc48ae76c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/key_value_metadata.cc
@@ -0,0 +1,274 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/sort.h"
+
+using std::size_t;
+
+namespace arrow {
+
+static std::vector<std::string> UnorderedMapKeys(
+ const std::unordered_map<std::string, std::string>& map) {
+ std::vector<std::string> keys;
+ keys.reserve(map.size());
+ for (const auto& pair : map) {
+ keys.push_back(pair.first);
+ }
+ return keys;
+}
+
+static std::vector<std::string> UnorderedMapValues(
+ const std::unordered_map<std::string, std::string>& map) {
+ std::vector<std::string> values;
+ values.reserve(map.size());
+ for (const auto& pair : map) {
+ values.push_back(pair.second);
+ }
+ return values;
+}
+
+KeyValueMetadata::KeyValueMetadata() {}
+
+KeyValueMetadata::KeyValueMetadata(
+ const std::unordered_map<std::string, std::string>& map)
+ : keys_(UnorderedMapKeys(map)), values_(UnorderedMapValues(map)) {
+ ARROW_CHECK_EQ(keys_.size(), values_.size());
+}
+
+KeyValueMetadata::KeyValueMetadata(std::vector<std::string> keys,
+ std::vector<std::string> values)
+ : keys_(std::move(keys)), values_(std::move(values)) {
+ ARROW_CHECK_EQ(keys.size(), values.size());
+}
+
+std::shared_ptr<KeyValueMetadata> KeyValueMetadata::Make(
+ std::vector<std::string> keys, std::vector<std::string> values) {
+ return std::make_shared<KeyValueMetadata>(std::move(keys), std::move(values));
+}
+
+void KeyValueMetadata::ToUnorderedMap(
+ std::unordered_map<std::string, std::string>* out) const {
+ DCHECK_NE(out, nullptr);
+ const int64_t n = size();
+ out->reserve(n);
+ for (int64_t i = 0; i < n; ++i) {
+ out->insert(std::make_pair(key(i), value(i)));
+ }
+}
+
+void KeyValueMetadata::Append(std::string key, std::string value) {
+ keys_.push_back(std::move(key));
+ values_.push_back(std::move(value));
+}
+
+Result<std::string> KeyValueMetadata::Get(const std::string& key) const {
+ auto index = FindKey(key);
+ if (index < 0) {
+ return Status::KeyError(key);
+ } else {
+ return value(index);
+ }
+}
+
+Status KeyValueMetadata::Delete(int64_t index) {
+ keys_.erase(keys_.begin() + index);
+ values_.erase(values_.begin() + index);
+ return Status::OK();
+}
+
+Status KeyValueMetadata::DeleteMany(std::vector<int64_t> indices) {
+ std::sort(indices.begin(), indices.end());
+ const int64_t size = static_cast<int64_t>(keys_.size());
+ indices.push_back(size);
+
+ int64_t shift = 0;
+ for (int64_t i = 0; i < static_cast<int64_t>(indices.size() - 1); ++i) {
+ ++shift;
+ const auto start = indices[i] + 1;
+ const auto stop = indices[i + 1];
+ DCHECK_GE(start, 0);
+ DCHECK_LE(start, size);
+ DCHECK_GE(stop, 0);
+ DCHECK_LE(stop, size);
+ for (int64_t index = start; index < stop; ++index) {
+ keys_[index - shift] = std::move(keys_[index]);
+ values_[index - shift] = std::move(values_[index]);
+ }
+ }
+ keys_.resize(size - shift);
+ values_.resize(size - shift);
+ return Status::OK();
+}
+
+Status KeyValueMetadata::Delete(const std::string& key) {
+ auto index = FindKey(key);
+ if (index < 0) {
+ return Status::KeyError(key);
+ } else {
+ return Delete(index);
+ }
+}
+
+Status KeyValueMetadata::Set(const std::string& key, const std::string& value) {
+ auto index = FindKey(key);
+ if (index < 0) {
+ Append(key, value);
+ } else {
+ keys_[index] = key;
+ values_[index] = value;
+ }
+ return Status::OK();
+}
+
+bool KeyValueMetadata::Contains(const std::string& key) const {
+ return FindKey(key) >= 0;
+}
+
+void KeyValueMetadata::reserve(int64_t n) {
+ DCHECK_GE(n, 0);
+ const auto m = static_cast<size_t>(n);
+ keys_.reserve(m);
+ values_.reserve(m);
+}
+
+int64_t KeyValueMetadata::size() const {
+ DCHECK_EQ(keys_.size(), values_.size());
+ return static_cast<int64_t>(keys_.size());
+}
+
+const std::string& KeyValueMetadata::key(int64_t i) const {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(static_cast<size_t>(i), keys_.size());
+ return keys_[i];
+}
+
+const std::string& KeyValueMetadata::value(int64_t i) const {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(static_cast<size_t>(i), values_.size());
+ return values_[i];
+}
+
+std::vector<std::pair<std::string, std::string>> KeyValueMetadata::sorted_pairs() const {
+ std::vector<std::pair<std::string, std::string>> pairs;
+ pairs.reserve(size());
+
+ auto indices = internal::ArgSort(keys_);
+ for (const auto i : indices) {
+ pairs.emplace_back(keys_[i], values_[i]);
+ }
+ return pairs;
+}
+
+int KeyValueMetadata::FindKey(const std::string& key) const {
+ for (size_t i = 0; i < keys_.size(); ++i) {
+ if (keys_[i] == key) {
+ return static_cast<int>(i);
+ }
+ }
+ return -1;
+}
+
+std::shared_ptr<KeyValueMetadata> KeyValueMetadata::Copy() const {
+ return std::make_shared<KeyValueMetadata>(keys_, values_);
+}
+
+std::shared_ptr<KeyValueMetadata> KeyValueMetadata::Merge(
+ const KeyValueMetadata& other) const {
+ std::unordered_set<std::string> observed_keys;
+ std::vector<std::string> result_keys;
+ std::vector<std::string> result_values;
+
+ result_keys.reserve(keys_.size());
+ result_values.reserve(keys_.size());
+
+ for (int64_t i = 0; i < other.size(); ++i) {
+ const auto& key = other.key(i);
+ auto it = observed_keys.find(key);
+ if (it == observed_keys.end()) {
+ result_keys.push_back(key);
+ result_values.push_back(other.value(i));
+ observed_keys.insert(key);
+ }
+ }
+ for (size_t i = 0; i < keys_.size(); ++i) {
+ auto it = observed_keys.find(keys_[i]);
+ if (it == observed_keys.end()) {
+ result_keys.push_back(keys_[i]);
+ result_values.push_back(values_[i]);
+ observed_keys.insert(keys_[i]);
+ }
+ }
+
+ return std::make_shared<KeyValueMetadata>(std::move(result_keys),
+ std::move(result_values));
+}
+
+bool KeyValueMetadata::Equals(const KeyValueMetadata& other) const {
+ if (size() != other.size()) {
+ return false;
+ }
+
+ auto indices = internal::ArgSort(keys_);
+ auto other_indices = internal::ArgSort(other.keys_);
+
+ for (int64_t i = 0; i < size(); ++i) {
+ auto j = indices[i];
+ auto k = other_indices[i];
+ if (keys_[j] != other.keys_[k] || values_[j] != other.values_[k]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+std::string KeyValueMetadata::ToString() const {
+ std::stringstream buffer;
+
+ buffer << "\n-- metadata --";
+ for (int64_t i = 0; i < size(); ++i) {
+ buffer << "\n" << keys_[i] << ": " << values_[i];
+ }
+
+ return buffer.str();
+}
+
+std::shared_ptr<KeyValueMetadata> key_value_metadata(
+ const std::unordered_map<std::string, std::string>& pairs) {
+ return std::make_shared<KeyValueMetadata>(pairs);
+}
+
+std::shared_ptr<KeyValueMetadata> key_value_metadata(std::vector<std::string> keys,
+ std::vector<std::string> values) {
+ return std::make_shared<KeyValueMetadata>(std::move(keys), std::move(values));
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/key_value_metadata.h b/src/arrow/cpp/src/arrow/util/key_value_metadata.h
new file mode 100644
index 000000000..ba70ffe88
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/key_value_metadata.h
@@ -0,0 +1,98 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+/// \brief A container for key-value pair type metadata. Not thread-safe
+class ARROW_EXPORT KeyValueMetadata {
+ public:
+ KeyValueMetadata();
+ KeyValueMetadata(std::vector<std::string> keys, std::vector<std::string> values);
+ explicit KeyValueMetadata(const std::unordered_map<std::string, std::string>& map);
+
+ static std::shared_ptr<KeyValueMetadata> Make(std::vector<std::string> keys,
+ std::vector<std::string> values);
+
+ void ToUnorderedMap(std::unordered_map<std::string, std::string>* out) const;
+ void Append(std::string key, std::string value);
+
+ Result<std::string> Get(const std::string& key) const;
+ bool Contains(const std::string& key) const;
+ // Note that deleting may invalidate known indices
+ Status Delete(const std::string& key);
+ Status Delete(int64_t index);
+ Status DeleteMany(std::vector<int64_t> indices);
+ Status Set(const std::string& key, const std::string& value);
+
+ void reserve(int64_t n);
+
+ int64_t size() const;
+ const std::string& key(int64_t i) const;
+ const std::string& value(int64_t i) const;
+ const std::vector<std::string>& keys() const { return keys_; }
+ const std::vector<std::string>& values() const { return values_; }
+
+ std::vector<std::pair<std::string, std::string>> sorted_pairs() const;
+
+ /// \brief Perform linear search for key, returning -1 if not found
+ int FindKey(const std::string& key) const;
+
+ std::shared_ptr<KeyValueMetadata> Copy() const;
+
+ /// \brief Return a new KeyValueMetadata by combining the passed metadata
+ /// with this KeyValueMetadata. Colliding keys will be overridden by the
+ /// passed metadata. Assumes keys in both containers are unique
+ std::shared_ptr<KeyValueMetadata> Merge(const KeyValueMetadata& other) const;
+
+ bool Equals(const KeyValueMetadata& other) const;
+ std::string ToString() const;
+
+ private:
+ std::vector<std::string> keys_;
+ std::vector<std::string> values_;
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(KeyValueMetadata);
+};
+
+/// \brief Create a KeyValueMetadata instance
+///
+/// \param pairs key-value mapping
+std::shared_ptr<KeyValueMetadata> ARROW_EXPORT
+key_value_metadata(const std::unordered_map<std::string, std::string>& pairs);
+
+/// \brief Create a KeyValueMetadata instance
+///
+/// \param keys sequence of metadata keys
+/// \param values sequence of corresponding metadata values
+std::shared_ptr<KeyValueMetadata> ARROW_EXPORT
+key_value_metadata(std::vector<std::string> keys, std::vector<std::string> values);
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/key_value_metadata_test.cc b/src/arrow/cpp/src/arrow/util/key_value_metadata_test.cc
new file mode 100644
index 000000000..3cdcf9475
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/key_value_metadata_test.cc
@@ -0,0 +1,211 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+
+TEST(KeyValueMetadataTest, SimpleConstruction) {
+ KeyValueMetadata metadata;
+ ASSERT_EQ(0, metadata.size());
+}
+
+TEST(KeyValueMetadataTest, StringVectorConstruction) {
+ std::vector<std::string> keys = {"foo", "bar"};
+ std::vector<std::string> values = {"bizz", "buzz"};
+
+ KeyValueMetadata metadata(keys, values);
+ ASSERT_EQ("foo", metadata.key(0));
+ ASSERT_EQ("bar", metadata.key(1));
+ ASSERT_EQ("bizz", metadata.value(0));
+ ASSERT_EQ("buzz", metadata.value(1));
+ ASSERT_EQ(2, metadata.size());
+
+ std::shared_ptr<KeyValueMetadata> metadata2 =
+ key_value_metadata({"foo", "bar"}, {"bizz", "buzz"});
+ ASSERT_TRUE(metadata.Equals(*metadata2));
+}
+
+TEST(KeyValueMetadataTest, StringMapConstruction) {
+ std::unordered_map<std::string, std::string> pairs = {{"foo", "bizz"}, {"bar", "buzz"}};
+ std::unordered_map<std::string, std::string> result_map;
+ result_map.reserve(pairs.size());
+
+ KeyValueMetadata metadata(pairs);
+ metadata.ToUnorderedMap(&result_map);
+ ASSERT_EQ(pairs, result_map);
+ ASSERT_EQ(2, metadata.size());
+}
+
+TEST(KeyValueMetadataTest, StringAppend) {
+ std::vector<std::string> keys = {"foo", "bar"};
+ std::vector<std::string> values = {"bizz", "buzz"};
+
+ KeyValueMetadata metadata(keys, values);
+ ASSERT_EQ("foo", metadata.key(0));
+ ASSERT_EQ("bar", metadata.key(1));
+ ASSERT_EQ("bizz", metadata.value(0));
+ ASSERT_EQ("buzz", metadata.value(1));
+ ASSERT_EQ(2, metadata.size());
+
+ metadata.Append("purple", "orange");
+ metadata.Append("blue", "red");
+
+ ASSERT_EQ("purple", metadata.key(2));
+ ASSERT_EQ("blue", metadata.key(3));
+
+ ASSERT_EQ("orange", metadata.value(2));
+ ASSERT_EQ("red", metadata.value(3));
+}
+
+TEST(KeyValueMetadataTest, Copy) {
+ std::vector<std::string> keys = {"foo", "bar"};
+ std::vector<std::string> values = {"bizz", "buzz"};
+
+ KeyValueMetadata metadata(keys, values);
+ auto metadata2 = metadata.Copy();
+ ASSERT_TRUE(metadata.Equals(*metadata2));
+}
+
+TEST(KeyValueMetadataTest, Merge) {
+ std::vector<std::string> keys1 = {"foo", "bar"};
+ std::vector<std::string> values1 = {"bizz", "buzz"};
+ KeyValueMetadata metadata(keys1, values1);
+
+ std::vector<std::string> keys2 = {"bar", "baz"};
+ std::vector<std::string> values2 = {"bozz", "bezz"};
+ KeyValueMetadata metadata2(keys2, values2);
+
+ std::vector<std::string> keys3 = {"foo", "bar", "baz"};
+ std::vector<std::string> values3 = {"bizz", "bozz", "bezz"};
+ KeyValueMetadata expected(keys3, values3);
+
+ auto result = metadata.Merge(metadata2);
+ ASSERT_TRUE(result->Equals(expected));
+}
+
+TEST(KeyValueMetadataTest, FindKey) {
+ std::vector<std::string> keys = {"foo", "bar"};
+ std::vector<std::string> values = {"bizz", "buzz"};
+ KeyValueMetadata metadata(keys, values);
+
+ ASSERT_EQ(0, metadata.FindKey("foo"));
+ ASSERT_EQ(1, metadata.FindKey("bar"));
+ ASSERT_EQ(-1, metadata.FindKey("baz"));
+}
+
+TEST(KeyValueMetadataTest, Equals) {
+ std::vector<std::string> keys = {"foo", "bar"};
+ std::vector<std::string> values = {"bizz", "buzz"};
+
+ KeyValueMetadata metadata(keys, values);
+ KeyValueMetadata metadata2(keys, values);
+ KeyValueMetadata metadata3(keys, {"buzz", "bizz"});
+
+ ASSERT_TRUE(metadata.Equals(metadata2));
+ ASSERT_FALSE(metadata.Equals(metadata3));
+
+ // Key / value pairs are semantically unordered
+ std::reverse(keys.begin(), keys.end());
+ KeyValueMetadata metadata4(keys, values);
+ std::reverse(values.begin(), values.end());
+ KeyValueMetadata metadata5(keys, values);
+
+ ASSERT_FALSE(metadata.Equals(metadata4));
+ ASSERT_TRUE(metadata.Equals(metadata5));
+
+ KeyValueMetadata metadata6({"foo"}, {"bizz"});
+ ASSERT_FALSE(metadata.Equals(metadata6));
+}
+
+TEST(KeyValueMetadataTest, ToString) {
+ std::vector<std::string> keys = {"foo", "bar"};
+ std::vector<std::string> values = {"bizz", "buzz"};
+
+ KeyValueMetadata metadata(keys, values);
+
+ std::string result = metadata.ToString();
+ std::string expected = R"(
+-- metadata --
+foo: bizz
+bar: buzz)";
+
+ ASSERT_EQ(expected, result);
+}
+
+TEST(KeyValueMetadataTest, SortedPairs) {
+ std::vector<std::string> keys = {"foo", "bar"};
+ std::vector<std::string> values = {"bizz", "buzz"};
+
+ KeyValueMetadata metadata1(keys, values);
+ std::reverse(keys.begin(), keys.end());
+ KeyValueMetadata metadata2(keys, values);
+ std::reverse(values.begin(), values.end());
+ KeyValueMetadata metadata3(keys, values);
+
+ std::vector<std::pair<std::string, std::string>> expected = {{"bar", "buzz"},
+ {"foo", "bizz"}};
+ ASSERT_EQ(metadata1.sorted_pairs(), expected);
+ ASSERT_EQ(metadata3.sorted_pairs(), expected);
+ expected = {{"bar", "bizz"}, {"foo", "buzz"}};
+ ASSERT_EQ(metadata2.sorted_pairs(), expected);
+}
+
+TEST(KeyValueMetadataTest, Delete) {
+ std::vector<std::string> keys = {"aa", "bb", "cc", "dd", "ee", "ff", "gg"};
+ std::vector<std::string> values = {"1", "2", "3", "4", "5", "6", "7"};
+
+ {
+ KeyValueMetadata metadata(keys, values);
+ ASSERT_OK(metadata.Delete("cc"));
+ ASSERT_TRUE(metadata.Equals(KeyValueMetadata({"aa", "bb", "dd", "ee", "ff", "gg"},
+ {"1", "2", "4", "5", "6", "7"})));
+
+ ASSERT_OK(metadata.Delete(3));
+ ASSERT_TRUE(metadata.Equals(
+ KeyValueMetadata({"aa", "bb", "dd", "ff", "gg"}, {"1", "2", "4", "6", "7"})));
+ }
+ {
+ KeyValueMetadata metadata(keys, values);
+ ASSERT_OK(metadata.DeleteMany({2, 5}));
+ ASSERT_TRUE(metadata.Equals(
+ KeyValueMetadata({"aa", "bb", "dd", "ee", "gg"}, {"1", "2", "4", "5", "7"})));
+
+ ASSERT_OK(metadata.DeleteMany({}));
+ ASSERT_TRUE(metadata.Equals(
+ KeyValueMetadata({"aa", "bb", "dd", "ee", "gg"}, {"1", "2", "4", "5", "7"})));
+ }
+ {
+ KeyValueMetadata metadata(keys, values);
+ ASSERT_OK(metadata.DeleteMany({0, 6, 5, 2}));
+ ASSERT_TRUE(metadata.Equals(KeyValueMetadata({"bb", "dd", "ee"}, {"2", "4", "5"})));
+
+ ASSERT_OK(metadata.DeleteMany({}));
+ ASSERT_TRUE(metadata.Equals(KeyValueMetadata({"bb", "dd", "ee"}, {"2", "4", "5"})));
+ }
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/launder.h b/src/arrow/cpp/src/arrow/util/launder.h
new file mode 100644
index 000000000..37e2a7144
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/launder.h
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <new>
+
+namespace arrow {
+namespace internal {
+
+#if __cplusplus >= 201703L
+using std::launder;
+#else
+template <class T>
+constexpr T* launder(T* p) noexcept {
+ return p;
+}
+#endif
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/logging.cc b/src/arrow/cpp/src/arrow/util/logging.cc
new file mode 100644
index 000000000..04fcf3d3e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/logging.cc
@@ -0,0 +1,256 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/logging.h"
+
+#ifdef ARROW_WITH_BACKTRACE
+#include <execinfo.h>
+#endif
+#include <cstdlib>
+#include <iostream>
+
+#ifdef ARROW_USE_GLOG
+
+#include <signal.h>
+#include <vector>
+
+#include "glog/logging.h"
+
+// Restore our versions of DCHECK and friends, as GLog defines its own
+#undef DCHECK
+#undef DCHECK_OK
+#undef DCHECK_EQ
+#undef DCHECK_NE
+#undef DCHECK_LE
+#undef DCHECK_LT
+#undef DCHECK_GE
+#undef DCHECK_GT
+
+#define DCHECK ARROW_DCHECK
+#define DCHECK_OK ARROW_DCHECK_OK
+#define DCHECK_EQ ARROW_DCHECK_EQ
+#define DCHECK_NE ARROW_DCHECK_NE
+#define DCHECK_LE ARROW_DCHECK_LE
+#define DCHECK_LT ARROW_DCHECK_LT
+#define DCHECK_GE ARROW_DCHECK_GE
+#define DCHECK_GT ARROW_DCHECK_GT
+
+#endif
+
+namespace arrow {
+namespace util {
+
+// This code is adapted from
+// https://github.com/ray-project/ray/blob/master/src/ray/util/logging.cc.
+
+// This is the default implementation of arrow log,
+// which is independent of any libs.
+class CerrLog {
+ public:
+ explicit CerrLog(ArrowLogLevel severity) : severity_(severity), has_logged_(false) {}
+
+ virtual ~CerrLog() {
+ if (has_logged_) {
+ std::cerr << std::endl;
+ }
+ if (severity_ == ArrowLogLevel::ARROW_FATAL) {
+ PrintBackTrace();
+ std::abort();
+ }
+ }
+
+ std::ostream& Stream() {
+ has_logged_ = true;
+ return std::cerr;
+ }
+
+ template <class T>
+ CerrLog& operator<<(const T& t) {
+ if (severity_ != ArrowLogLevel::ARROW_DEBUG) {
+ has_logged_ = true;
+ std::cerr << t;
+ }
+ return *this;
+ }
+
+ protected:
+ const ArrowLogLevel severity_;
+ bool has_logged_;
+
+ void PrintBackTrace() {
+#ifdef ARROW_WITH_BACKTRACE
+ void* buffer[255];
+ const int calls = backtrace(buffer, static_cast<int>(sizeof(buffer) / sizeof(void*)));
+ backtrace_symbols_fd(buffer, calls, 1);
+#endif
+ }
+};
+
+#ifdef ARROW_USE_GLOG
+typedef google::LogMessage LoggingProvider;
+#else
+typedef CerrLog LoggingProvider;
+#endif
+
+ArrowLogLevel ArrowLog::severity_threshold_ = ArrowLogLevel::ARROW_INFO;
+// Keep the log directory.
+static std::unique_ptr<std::string> log_dir_;
+
+#ifdef ARROW_USE_GLOG
+
+// Glog's severity map.
+static int GetMappedSeverity(ArrowLogLevel severity) {
+ switch (severity) {
+ case ArrowLogLevel::ARROW_DEBUG:
+ return google::GLOG_INFO;
+ case ArrowLogLevel::ARROW_INFO:
+ return google::GLOG_INFO;
+ case ArrowLogLevel::ARROW_WARNING:
+ return google::GLOG_WARNING;
+ case ArrowLogLevel::ARROW_ERROR:
+ return google::GLOG_ERROR;
+ case ArrowLogLevel::ARROW_FATAL:
+ return google::GLOG_FATAL;
+ default:
+ ARROW_LOG(FATAL) << "Unsupported logging level: " << static_cast<int>(severity);
+ // This return won't be hit but compiler needs it.
+ return google::GLOG_FATAL;
+ }
+}
+
+#endif
+
+void ArrowLog::StartArrowLog(const std::string& app_name,
+ ArrowLogLevel severity_threshold,
+ const std::string& log_dir) {
+ severity_threshold_ = severity_threshold;
+ // In InitGoogleLogging, it simply keeps the pointer.
+ // We need to make sure the app name passed to InitGoogleLogging exist.
+ // We should avoid using static string is a dynamic lib.
+ static std::unique_ptr<std::string> app_name_;
+ app_name_.reset(new std::string(app_name));
+ log_dir_.reset(new std::string(log_dir));
+#ifdef ARROW_USE_GLOG
+ int mapped_severity_threshold = GetMappedSeverity(severity_threshold_);
+ google::SetStderrLogging(mapped_severity_threshold);
+ // Enble log file if log_dir is not empty.
+ if (!log_dir.empty()) {
+ auto dir_ends_with_slash = log_dir;
+ if (log_dir[log_dir.length() - 1] != '/') {
+ dir_ends_with_slash += "/";
+ }
+ auto app_name_without_path = app_name;
+ if (app_name.empty()) {
+ app_name_without_path = "DefaultApp";
+ } else {
+ // Find the app name without the path.
+ size_t pos = app_name.rfind('/');
+ if (pos != app_name.npos && pos + 1 < app_name.length()) {
+ app_name_without_path = app_name.substr(pos + 1);
+ }
+ }
+ // If InitGoogleLogging is called but SetLogDestination is not called,
+ // the log will be output to /tmp besides stderr. If log_dir is not
+ // provided, we'd better not call InitGoogleLogging.
+ google::InitGoogleLogging(app_name_->c_str());
+ google::SetLogFilenameExtension(app_name_without_path.c_str());
+ for (int i = static_cast<int>(severity_threshold_);
+ i <= static_cast<int>(ArrowLogLevel::ARROW_FATAL); ++i) {
+ int level = GetMappedSeverity(static_cast<ArrowLogLevel>(i));
+ google::SetLogDestination(level, dir_ends_with_slash.c_str());
+ }
+ }
+#endif
+}
+
+void ArrowLog::UninstallSignalAction() {
+#ifdef ARROW_USE_GLOG
+ ARROW_LOG(DEBUG) << "Uninstall signal handlers.";
+ // This signal list comes from glog's signalhandler.cc.
+ // https://github.com/google/glog/blob/master/src/signalhandler.cc#L58-L70
+ std::vector<int> installed_signals({SIGSEGV, SIGILL, SIGFPE, SIGABRT, SIGTERM});
+#ifdef WIN32
+ for (int signal_num : installed_signals) {
+ ARROW_CHECK(signal(signal_num, SIG_DFL) != SIG_ERR);
+ }
+#else
+ struct sigaction sig_action;
+ memset(&sig_action, 0, sizeof(sig_action));
+ sigemptyset(&sig_action.sa_mask);
+ sig_action.sa_handler = SIG_DFL;
+ for (int signal_num : installed_signals) {
+ ARROW_CHECK(sigaction(signal_num, &sig_action, NULL) == 0);
+ }
+#endif
+#endif
+}
+
+void ArrowLog::ShutDownArrowLog() {
+#ifdef ARROW_USE_GLOG
+ if (!log_dir_->empty()) {
+ google::ShutdownGoogleLogging();
+ }
+#endif
+}
+
+void ArrowLog::InstallFailureSignalHandler() {
+#ifdef ARROW_USE_GLOG
+ google::InstallFailureSignalHandler();
+#endif
+}
+
+bool ArrowLog::IsLevelEnabled(ArrowLogLevel log_level) {
+ return log_level >= severity_threshold_;
+}
+
+ArrowLog::ArrowLog(const char* file_name, int line_number, ArrowLogLevel severity)
+ // glog does not have DEBUG level, we can handle it using is_enabled_.
+ : logging_provider_(nullptr), is_enabled_(severity >= severity_threshold_) {
+#ifdef ARROW_USE_GLOG
+ if (is_enabled_) {
+ logging_provider_ =
+ new google::LogMessage(file_name, line_number, GetMappedSeverity(severity));
+ }
+#else
+ auto logging_provider = new CerrLog(severity);
+ *logging_provider << file_name << ":" << line_number << ": ";
+ logging_provider_ = logging_provider;
+#endif
+}
+
+std::ostream& ArrowLog::Stream() {
+ auto logging_provider = reinterpret_cast<LoggingProvider*>(logging_provider_);
+#ifdef ARROW_USE_GLOG
+ // Before calling this function, user should check IsEnabled.
+ // When IsEnabled == false, logging_provider_ will be empty.
+ return logging_provider->stream();
+#else
+ return logging_provider->Stream();
+#endif
+}
+
+bool ArrowLog::IsEnabled() const { return is_enabled_; }
+
+ArrowLog::~ArrowLog() {
+ if (logging_provider_ != nullptr) {
+ delete reinterpret_cast<LoggingProvider*>(logging_provider_);
+ logging_provider_ = nullptr;
+ }
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/logging.h b/src/arrow/cpp/src/arrow/util/logging.h
new file mode 100644
index 000000000..15a0188ab
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/logging.h
@@ -0,0 +1,259 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#ifdef GANDIVA_IR
+
+// The LLVM IR code doesn't have an NDEBUG mode. And, it shouldn't include references to
+// streams or stdc++. So, making the DCHECK calls void in that case.
+
+#define ARROW_IGNORE_EXPR(expr) ((void)(expr))
+
+#define DCHECK(condition) ARROW_IGNORE_EXPR(condition)
+#define DCHECK_OK(status) ARROW_IGNORE_EXPR(status)
+#define DCHECK_EQ(val1, val2) ARROW_IGNORE_EXPR(val1)
+#define DCHECK_NE(val1, val2) ARROW_IGNORE_EXPR(val1)
+#define DCHECK_LE(val1, val2) ARROW_IGNORE_EXPR(val1)
+#define DCHECK_LT(val1, val2) ARROW_IGNORE_EXPR(val1)
+#define DCHECK_GE(val1, val2) ARROW_IGNORE_EXPR(val1)
+#define DCHECK_GT(val1, val2) ARROW_IGNORE_EXPR(val1)
+
+#else // !GANDIVA_IR
+
+#include <memory>
+#include <ostream>
+#include <string>
+
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace util {
+
+enum class ArrowLogLevel : int {
+ ARROW_DEBUG = -1,
+ ARROW_INFO = 0,
+ ARROW_WARNING = 1,
+ ARROW_ERROR = 2,
+ ARROW_FATAL = 3
+};
+
+#define ARROW_LOG_INTERNAL(level) ::arrow::util::ArrowLog(__FILE__, __LINE__, level)
+#define ARROW_LOG(level) ARROW_LOG_INTERNAL(::arrow::util::ArrowLogLevel::ARROW_##level)
+
+#define ARROW_IGNORE_EXPR(expr) ((void)(expr))
+
+#define ARROW_CHECK(condition) \
+ ARROW_PREDICT_TRUE(condition) \
+ ? ARROW_IGNORE_EXPR(0) \
+ : ::arrow::util::Voidify() & \
+ ::arrow::util::ArrowLog(__FILE__, __LINE__, \
+ ::arrow::util::ArrowLogLevel::ARROW_FATAL) \
+ << " Check failed: " #condition " "
+
+// If 'to_call' returns a bad status, CHECK immediately with a logged message
+// of 'msg' followed by the status.
+#define ARROW_CHECK_OK_PREPEND(to_call, msg) \
+ do { \
+ ::arrow::Status _s = (to_call); \
+ ARROW_CHECK(_s.ok()) << "Operation failed: " << ARROW_STRINGIFY(to_call) << "\n" \
+ << (msg) << ": " << _s.ToString(); \
+ } while (false)
+
+// If the status is bad, CHECK immediately, appending the status to the
+// logged message.
+#define ARROW_CHECK_OK(s) ARROW_CHECK_OK_PREPEND(s, "Bad status")
+
+#define ARROW_CHECK_EQ(val1, val2) ARROW_CHECK((val1) == (val2))
+#define ARROW_CHECK_NE(val1, val2) ARROW_CHECK((val1) != (val2))
+#define ARROW_CHECK_LE(val1, val2) ARROW_CHECK((val1) <= (val2))
+#define ARROW_CHECK_LT(val1, val2) ARROW_CHECK((val1) < (val2))
+#define ARROW_CHECK_GE(val1, val2) ARROW_CHECK((val1) >= (val2))
+#define ARROW_CHECK_GT(val1, val2) ARROW_CHECK((val1) > (val2))
+
+#ifdef NDEBUG
+#define ARROW_DFATAL ::arrow::util::ArrowLogLevel::ARROW_WARNING
+
+// CAUTION: DCHECK_OK() always evaluates its argument, but other DCHECK*() macros
+// only do so in debug mode.
+
+#define ARROW_DCHECK(condition) \
+ while (false) ARROW_IGNORE_EXPR(condition); \
+ while (false) ::arrow::util::detail::NullLog()
+#define ARROW_DCHECK_OK(s) \
+ ARROW_IGNORE_EXPR(s); \
+ while (false) ::arrow::util::detail::NullLog()
+#define ARROW_DCHECK_EQ(val1, val2) \
+ while (false) ARROW_IGNORE_EXPR(val1); \
+ while (false) ARROW_IGNORE_EXPR(val2); \
+ while (false) ::arrow::util::detail::NullLog()
+#define ARROW_DCHECK_NE(val1, val2) \
+ while (false) ARROW_IGNORE_EXPR(val1); \
+ while (false) ARROW_IGNORE_EXPR(val2); \
+ while (false) ::arrow::util::detail::NullLog()
+#define ARROW_DCHECK_LE(val1, val2) \
+ while (false) ARROW_IGNORE_EXPR(val1); \
+ while (false) ARROW_IGNORE_EXPR(val2); \
+ while (false) ::arrow::util::detail::NullLog()
+#define ARROW_DCHECK_LT(val1, val2) \
+ while (false) ARROW_IGNORE_EXPR(val1); \
+ while (false) ARROW_IGNORE_EXPR(val2); \
+ while (false) ::arrow::util::detail::NullLog()
+#define ARROW_DCHECK_GE(val1, val2) \
+ while (false) ARROW_IGNORE_EXPR(val1); \
+ while (false) ARROW_IGNORE_EXPR(val2); \
+ while (false) ::arrow::util::detail::NullLog()
+#define ARROW_DCHECK_GT(val1, val2) \
+ while (false) ARROW_IGNORE_EXPR(val1); \
+ while (false) ARROW_IGNORE_EXPR(val2); \
+ while (false) ::arrow::util::detail::NullLog()
+
+#else
+#define ARROW_DFATAL ::arrow::util::ArrowLogLevel::ARROW_FATAL
+
+#define ARROW_DCHECK ARROW_CHECK
+#define ARROW_DCHECK_OK ARROW_CHECK_OK
+#define ARROW_DCHECK_EQ ARROW_CHECK_EQ
+#define ARROW_DCHECK_NE ARROW_CHECK_NE
+#define ARROW_DCHECK_LE ARROW_CHECK_LE
+#define ARROW_DCHECK_LT ARROW_CHECK_LT
+#define ARROW_DCHECK_GE ARROW_CHECK_GE
+#define ARROW_DCHECK_GT ARROW_CHECK_GT
+
+#endif // NDEBUG
+
+#define DCHECK ARROW_DCHECK
+#define DCHECK_OK ARROW_DCHECK_OK
+#define DCHECK_EQ ARROW_DCHECK_EQ
+#define DCHECK_NE ARROW_DCHECK_NE
+#define DCHECK_LE ARROW_DCHECK_LE
+#define DCHECK_LT ARROW_DCHECK_LT
+#define DCHECK_GE ARROW_DCHECK_GE
+#define DCHECK_GT ARROW_DCHECK_GT
+
+// This code is adapted from
+// https://github.com/ray-project/ray/blob/master/src/ray/util/logging.h.
+
+// To make the logging lib pluggable with other logging libs and make
+// the implementation unawared by the user, ArrowLog is only a declaration
+// which hide the implementation into logging.cc file.
+// In logging.cc, we can choose different log libs using different macros.
+
+// This is also a null log which does not output anything.
+class ARROW_EXPORT ArrowLogBase {
+ public:
+ virtual ~ArrowLogBase() {}
+
+ virtual bool IsEnabled() const { return false; }
+
+ template <typename T>
+ ArrowLogBase& operator<<(const T& t) {
+ if (IsEnabled()) {
+ Stream() << t;
+ }
+ return *this;
+ }
+
+ protected:
+ virtual std::ostream& Stream() = 0;
+};
+
+class ARROW_EXPORT ArrowLog : public ArrowLogBase {
+ public:
+ ArrowLog(const char* file_name, int line_number, ArrowLogLevel severity);
+ ~ArrowLog() override;
+
+ /// Return whether or not current logging instance is enabled.
+ ///
+ /// \return True if logging is enabled and false otherwise.
+ bool IsEnabled() const override;
+
+ /// The init function of arrow log for a program which should be called only once.
+ ///
+ /// \param appName The app name which starts the log.
+ /// \param severity_threshold Logging threshold for the program.
+ /// \param logDir Logging output file name. If empty, the log won't output to file.
+ static void StartArrowLog(const std::string& appName,
+ ArrowLogLevel severity_threshold = ArrowLogLevel::ARROW_INFO,
+ const std::string& logDir = "");
+
+ /// The shutdown function of arrow log, it should be used with StartArrowLog as a pair.
+ static void ShutDownArrowLog();
+
+ /// Install the failure signal handler to output call stack when crash.
+ /// If glog is not installed, this function won't do anything.
+ static void InstallFailureSignalHandler();
+
+ /// Uninstall the signal actions installed by InstallFailureSignalHandler.
+ static void UninstallSignalAction();
+
+ /// Return whether or not the log level is enabled in current setting.
+ ///
+ /// \param log_level The input log level to test.
+ /// \return True if input log level is not lower than the threshold.
+ static bool IsLevelEnabled(ArrowLogLevel log_level);
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ArrowLog);
+
+ // Hide the implementation of log provider by void *.
+ // Otherwise, lib user may define the same macro to use the correct header file.
+ void* logging_provider_;
+ /// True if log messages should be logged and false if they should be ignored.
+ bool is_enabled_;
+
+ static ArrowLogLevel severity_threshold_;
+
+ protected:
+ std::ostream& Stream() override;
+};
+
+// This class make ARROW_CHECK compilation pass to change the << operator to void.
+// This class is copied from glog.
+class ARROW_EXPORT Voidify {
+ public:
+ Voidify() {}
+ // This has to be an operator with a precedence lower than << but
+ // higher than ?:
+ void operator&(ArrowLogBase&) {}
+};
+
+namespace detail {
+
+/// @brief A helper for the nil log sink.
+///
+/// Using this helper is analogous to sending log messages to /dev/null:
+/// nothing gets logged.
+class NullLog {
+ public:
+ /// The no-op output operator.
+ ///
+ /// @param [in] t
+ /// The object to send into the nil sink.
+ /// @return Reference to the updated object.
+ template <class T>
+ NullLog& operator<<(const T& t) {
+ return *this;
+ }
+};
+
+} // namespace detail
+} // namespace util
+} // namespace arrow
+
+#endif // GANDIVA_IR
diff --git a/src/arrow/cpp/src/arrow/util/logging_test.cc b/src/arrow/cpp/src/arrow/util/logging_test.cc
new file mode 100644
index 000000000..547e0bba3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/logging_test.cc
@@ -0,0 +1,103 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <chrono>
+#include <cstdint>
+#include <iostream>
+
+#include <gtest/gtest.h>
+
+#include "arrow/util/logging.h"
+
+// This code is adapted from
+// https://github.com/ray-project/ray/blob/master/src/ray/util/logging_test.cc.
+
+namespace arrow {
+namespace util {
+
+int64_t current_time_ms() {
+ std::chrono::milliseconds ms_since_epoch =
+ std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::steady_clock::now().time_since_epoch());
+ return ms_since_epoch.count();
+}
+
+// This is not really test.
+// This file just print some information using the logging macro.
+
+void PrintLog() {
+ ARROW_LOG(DEBUG) << "This is the"
+ << " DEBUG"
+ << " message";
+ ARROW_LOG(INFO) << "This is the"
+ << " INFO message";
+ ARROW_LOG(WARNING) << "This is the"
+ << " WARNING message";
+ ARROW_LOG(ERROR) << "This is the"
+ << " ERROR message";
+ ARROW_CHECK(true) << "This is a ARROW_CHECK"
+ << " message but it won't show up";
+ // The following 2 lines should not run since it will cause program failure.
+ // ARROW_LOG(FATAL) << "This is the FATAL message";
+ // ARROW_CHECK(false) << "This is a ARROW_CHECK message but it won't show up";
+}
+
+TEST(PrintLogTest, LogTestWithoutInit) {
+ // Without ArrowLog::StartArrowLog, this should also work.
+ PrintLog();
+}
+
+TEST(PrintLogTest, LogTestWithInit) {
+ // Test empty app name.
+ ArrowLog::StartArrowLog("", ArrowLogLevel::ARROW_DEBUG);
+ PrintLog();
+ ArrowLog::ShutDownArrowLog();
+}
+
+} // namespace util
+
+TEST(DcheckMacros, DoNotEvaluateReleaseMode) {
+#ifdef NDEBUG
+ int i = 0;
+ auto f1 = [&]() {
+ ++i;
+ return true;
+ };
+ DCHECK(f1());
+ ASSERT_EQ(0, i);
+ auto f2 = [&]() {
+ ++i;
+ return i;
+ };
+ DCHECK_EQ(f2(), 0);
+ DCHECK_NE(f2(), 0);
+ DCHECK_LT(f2(), 0);
+ DCHECK_LE(f2(), 0);
+ DCHECK_GE(f2(), 0);
+ DCHECK_GT(f2(), 0);
+ ASSERT_EQ(0, i);
+ ARROW_UNUSED(f1);
+ ARROW_UNUSED(f2);
+#endif
+}
+
+} // namespace arrow
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/arrow/cpp/src/arrow/util/machine_benchmark.cc b/src/arrow/cpp/src/arrow/util/machine_benchmark.cc
new file mode 100644
index 000000000..67397444b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/machine_benchmark.cc
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Non-Arrow system benchmarks, provided for convenience.
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "benchmark/benchmark.h"
+
+namespace arrow {
+
+#ifdef ARROW_WITH_BENCHMARKS_REFERENCE
+
+// Generate a vector of indices such as following the indices describes
+// a path over the whole vector. The path is randomized to avoid triggering
+// automatic prefetching in the CPU.
+std::vector<int32_t> RandomPath(int32_t size) {
+ std::default_random_engine gen(42);
+ std::vector<int32_t> indices(size);
+
+ for (int32_t i = 0; i < size; ++i) {
+ indices[i] = i;
+ }
+ std::shuffle(indices.begin(), indices.end(), gen);
+ std::vector<int32_t> path(size, -999999);
+ int32_t prev;
+ prev = indices[size - 1];
+ for (int32_t i = 0; i < size; ++i) {
+ int32_t next = indices[i];
+ path[prev] = next;
+ prev = next;
+ }
+ return path;
+}
+
+// Cache / main memory latency, depending on the working set size
+static void memory_latency(benchmark::State& state) {
+ const auto niters = static_cast<int32_t>(state.range(0));
+ const std::vector<int32_t> path = RandomPath(niters / 4);
+
+ int32_t total = 0;
+ int32_t index = 0;
+ for (auto _ : state) {
+ total += index;
+ index = path[index];
+ }
+ benchmark::DoNotOptimize(total);
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(memory_latency)->Repetitions(1)->RangeMultiplier(2)->Range(2 << 10, 2 << 24);
+
+#endif // ARROW_WITH_BENCHMARKS_REFERENCE
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/macros.h b/src/arrow/cpp/src/arrow/util/macros.h
new file mode 100644
index 000000000..2fb383e1d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/macros.h
@@ -0,0 +1,225 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+
+#define ARROW_EXPAND(x) x
+#define ARROW_STRINGIFY(x) #x
+#define ARROW_CONCAT(x, y) x##y
+
+// From Google gutil
+#ifndef ARROW_DISALLOW_COPY_AND_ASSIGN
+#define ARROW_DISALLOW_COPY_AND_ASSIGN(TypeName) \
+ TypeName(const TypeName&) = delete; \
+ void operator=(const TypeName&) = delete
+#endif
+
+#ifndef ARROW_DEFAULT_MOVE_AND_ASSIGN
+#define ARROW_DEFAULT_MOVE_AND_ASSIGN(TypeName) \
+ TypeName(TypeName&&) = default; \
+ TypeName& operator=(TypeName&&) = default
+#endif
+
+#define ARROW_UNUSED(x) (void)(x)
+#define ARROW_ARG_UNUSED(x)
+//
+// GCC can be told that a certain branch is not likely to be taken (for
+// instance, a CHECK failure), and use that information in static analysis.
+// Giving it this information can help it optimize for the common case in
+// the absence of better information (ie. -fprofile-arcs).
+//
+#if defined(__GNUC__)
+#define ARROW_PREDICT_FALSE(x) (__builtin_expect(!!(x), 0))
+#define ARROW_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1))
+#define ARROW_NORETURN __attribute__((noreturn))
+#define ARROW_NOINLINE __attribute__((noinline))
+#define ARROW_PREFETCH(addr) __builtin_prefetch(addr)
+#elif defined(_MSC_VER)
+#define ARROW_NORETURN __declspec(noreturn)
+#define ARROW_NOINLINE __declspec(noinline)
+#define ARROW_PREDICT_FALSE(x) (x)
+#define ARROW_PREDICT_TRUE(x) (x)
+#define ARROW_PREFETCH(addr)
+#else
+#define ARROW_NORETURN
+#define ARROW_PREDICT_FALSE(x) (x)
+#define ARROW_PREDICT_TRUE(x) (x)
+#define ARROW_PREFETCH(addr)
+#endif
+
+#if (defined(__GNUC__) || defined(__APPLE__))
+#define ARROW_MUST_USE_RESULT __attribute__((warn_unused_result))
+#elif defined(_MSC_VER)
+#define ARROW_MUST_USE_RESULT
+#else
+#define ARROW_MUST_USE_RESULT
+#endif
+
+#if defined(__clang__)
+// Only clang supports warn_unused_result as a type annotation.
+#define ARROW_MUST_USE_TYPE ARROW_MUST_USE_RESULT
+#else
+#define ARROW_MUST_USE_TYPE
+#endif
+
+#if defined(__GNUC__) || defined(__clang__) || defined(_MSC_VER)
+#define ARROW_RESTRICT __restrict
+#else
+#define ARROW_RESTRICT
+#endif
+
+// ----------------------------------------------------------------------
+// C++/CLI support macros (see ARROW-1134)
+
+#ifndef NULLPTR
+
+#ifdef __cplusplus_cli
+#define NULLPTR __nullptr
+#else
+#define NULLPTR nullptr
+#endif
+
+#endif // ifndef NULLPTR
+
+// ----------------------------------------------------------------------
+
+// clang-format off
+// [[deprecated]] is only available in C++14, use this for the time being
+// This macro takes an optional deprecation message
+#ifdef __COVERITY__
+# define ARROW_DEPRECATED(...)
+# define ARROW_DEPRECATED_USING(...)
+#elif __cplusplus > 201103L
+# define ARROW_DEPRECATED(...) [[deprecated(__VA_ARGS__)]]
+# define ARROW_DEPRECATED_USING(...) ARROW_DEPRECATED(__VA_ARGS__)
+#else
+# ifdef __GNUC__
+# define ARROW_DEPRECATED(...) __attribute__((deprecated(__VA_ARGS__)))
+# define ARROW_DEPRECATED_USING(...) ARROW_DEPRECATED(__VA_ARGS__)
+# elif defined(_MSC_VER)
+# define ARROW_DEPRECATED(...) __declspec(deprecated(__VA_ARGS__))
+# define ARROW_DEPRECATED_USING(...)
+# else
+# define ARROW_DEPRECATED(...)
+# define ARROW_DEPRECATED_USING(...)
+# endif
+#endif
+
+#ifdef __COVERITY__
+# define ARROW_DEPRECATED_ENUM_VALUE(...)
+#elif __cplusplus > 201103L
+# define ARROW_DEPRECATED_ENUM_VALUE(...) [[deprecated(__VA_ARGS__)]]
+#else
+# if defined(__GNUC__) && __GNUC__ >= 6
+# define ARROW_DEPRECATED_ENUM_VALUE(...) __attribute__((deprecated(__VA_ARGS__)))
+# else
+# define ARROW_DEPRECATED_ENUM_VALUE(...)
+# endif
+#endif
+
+// clang-format on
+
+// Macros to disable deprecation warnings
+
+#ifdef __clang__
+#define ARROW_SUPPRESS_DEPRECATION_WARNING \
+ _Pragma("clang diagnostic push"); \
+ _Pragma("clang diagnostic ignored \"-Wdeprecated-declarations\"")
+#define ARROW_UNSUPPRESS_DEPRECATION_WARNING _Pragma("clang diagnostic pop")
+#elif defined(__GNUC__)
+#define ARROW_SUPPRESS_DEPRECATION_WARNING \
+ _Pragma("GCC diagnostic push"); \
+ _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"")
+#define ARROW_UNSUPPRESS_DEPRECATION_WARNING _Pragma("GCC diagnostic pop")
+#elif defined(_MSC_VER)
+#define ARROW_SUPPRESS_DEPRECATION_WARNING \
+ __pragma(warning(push)) __pragma(warning(disable : 4996))
+#define ARROW_UNSUPPRESS_DEPRECATION_WARNING __pragma(warning(pop))
+#else
+#define ARROW_SUPPRESS_DEPRECATION_WARNING
+#define ARROW_UNSUPPRESS_DEPRECATION_WARNING
+#endif
+
+// ----------------------------------------------------------------------
+
+// macros to disable padding
+// these macros are portable across different compilers and platforms
+//[https://github.com/google/flatbuffers/blob/master/include/flatbuffers/flatbuffers.h#L1355]
+#if !defined(MANUALLY_ALIGNED_STRUCT)
+#if defined(_MSC_VER)
+#define MANUALLY_ALIGNED_STRUCT(alignment) \
+ __pragma(pack(1)); \
+ struct __declspec(align(alignment))
+#define STRUCT_END(name, size) \
+ __pragma(pack()); \
+ static_assert(sizeof(name) == size, "compiler breaks packing rules")
+#elif defined(__GNUC__) || defined(__clang__)
+#define MANUALLY_ALIGNED_STRUCT(alignment) \
+ _Pragma("pack(1)") struct __attribute__((aligned(alignment)))
+#define STRUCT_END(name, size) \
+ _Pragma("pack()") static_assert(sizeof(name) == size, "compiler breaks packing rules")
+#else
+#error Unknown compiler, please define structure alignment macros
+#endif
+#endif // !defined(MANUALLY_ALIGNED_STRUCT)
+
+// ----------------------------------------------------------------------
+// Convenience macro disabling a particular UBSan check in a function
+
+#if defined(__clang__)
+#define ARROW_DISABLE_UBSAN(feature) __attribute__((no_sanitize(feature)))
+#else
+#define ARROW_DISABLE_UBSAN(feature)
+#endif
+
+// ----------------------------------------------------------------------
+// Machine information
+
+#if INTPTR_MAX == INT64_MAX
+#define ARROW_BITNESS 64
+#elif INTPTR_MAX == INT32_MAX
+#define ARROW_BITNESS 32
+#else
+#error Unexpected INTPTR_MAX
+#endif
+
+// ----------------------------------------------------------------------
+// From googletest
+// (also in parquet-cpp)
+
+// When you need to test the private or protected members of a class,
+// use the FRIEND_TEST macro to declare your tests as friends of the
+// class. For example:
+//
+// class MyClass {
+// private:
+// void MyMethod();
+// FRIEND_TEST(MyClassTest, MyMethod);
+// };
+//
+// class MyClassTest : public testing::Test {
+// // ...
+// };
+//
+// TEST_F(MyClassTest, MyMethod) {
+// // Can call MyClass::MyMethod() here.
+// }
+
+#define FRIEND_TEST(test_case_name, test_name) \
+ friend class test_case_name##_##test_name##_Test
diff --git a/src/arrow/cpp/src/arrow/util/make_unique.h b/src/arrow/cpp/src/arrow/util/make_unique.h
new file mode 100644
index 000000000..850e20409
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/make_unique.h
@@ -0,0 +1,42 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <type_traits>
+#include <utility>
+
+namespace arrow {
+namespace internal {
+
+template <typename T, typename... A>
+typename std::enable_if<!std::is_array<T>::value, std::unique_ptr<T>>::type make_unique(
+ A&&... args) {
+ return std::unique_ptr<T>(new T(std::forward<A>(args)...));
+}
+
+template <typename T>
+typename std::enable_if<std::is_array<T>::value && std::extent<T>::value == 0,
+ std::unique_ptr<T>>::type
+make_unique(std::size_t n) {
+ using value_type = typename std::remove_extent<T>::type;
+ return std::unique_ptr<value_type[]>(new value_type[n]);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/map.h b/src/arrow/cpp/src/arrow/util/map.h
new file mode 100644
index 000000000..552390906
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/map.h
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <utility>
+
+#include "arrow/result.h"
+
+namespace arrow {
+namespace internal {
+
+/// Helper providing single-lookup conditional insertion into std::map or
+/// std::unordered_map. If `key` exists in the container, an iterator to that pair
+/// will be returned. If `key` does not exist in the container, `gen(key)` will be
+/// invoked and its return value inserted.
+template <typename Map, typename Gen>
+auto GetOrInsertGenerated(Map* map, typename Map::key_type key, Gen&& gen)
+ -> decltype(map->begin()->second = gen(map->begin()->first), map->begin()) {
+ decltype(gen(map->begin()->first)) placeholder{};
+
+ auto it_success = map->emplace(std::move(key), std::move(placeholder));
+ if (it_success.second) {
+ // insertion of placeholder succeeded, overwrite it with gen()
+ const auto& inserted_key = it_success.first->first;
+ auto* value = &it_success.first->second;
+ *value = gen(inserted_key);
+ }
+ return it_success.first;
+}
+
+template <typename Map, typename Gen>
+auto GetOrInsertGenerated(Map* map, typename Map::key_type key, Gen&& gen)
+ -> Result<decltype(map->begin()->second = gen(map->begin()->first).ValueOrDie(),
+ map->begin())> {
+ decltype(gen(map->begin()->first).ValueOrDie()) placeholder{};
+
+ auto it_success = map->emplace(std::move(key), std::move(placeholder));
+ if (it_success.second) {
+ // insertion of placeholder succeeded, overwrite it with gen()
+ const auto& inserted_key = it_success.first->first;
+ auto* value = &it_success.first->second;
+ ARROW_ASSIGN_OR_RAISE(*value, gen(inserted_key));
+ }
+ return it_success.first;
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/math_constants.h b/src/arrow/cpp/src/arrow/util/math_constants.h
new file mode 100644
index 000000000..7ee87c5d6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/math_constants.h
@@ -0,0 +1,32 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cmath>
+
+// Not provided by default in MSVC,
+// and _USE_MATH_DEFINES is not reliable with unity builds
+#ifndef M_PI
+#define M_PI 3.14159265358979323846
+#endif
+#ifndef M_PI_2
+#define M_PI_2 1.57079632679489661923
+#endif
+#ifndef M_PI_4
+#define M_PI_4 0.785398163397448309616
+#endif
diff --git a/src/arrow/cpp/src/arrow/util/memory.cc b/src/arrow/cpp/src/arrow/util/memory.cc
new file mode 100644
index 000000000..e91009d58
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/memory.cc
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <vector>
+
+#include "arrow/util/logging.h"
+#include "arrow/util/memory.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+namespace internal {
+
+inline uint8_t* pointer_logical_and(const uint8_t* address, uintptr_t bits) {
+ uintptr_t value = reinterpret_cast<uintptr_t>(address);
+ return reinterpret_cast<uint8_t*>(value & bits);
+}
+
+// This function is just for avoiding MinGW-w64 32bit crash.
+// See also: https://sourceforge.net/p/mingw-w64/bugs/767/
+void* wrap_memcpy(void* dst, const void* src, size_t n) { return memcpy(dst, src, n); }
+
+void parallel_memcopy(uint8_t* dst, const uint8_t* src, int64_t nbytes,
+ uintptr_t block_size, int num_threads) {
+ // XXX This function is really using `num_threads + 1` threads.
+ auto pool = GetCpuThreadPool();
+
+ uint8_t* left = pointer_logical_and(src + block_size - 1, ~(block_size - 1));
+ uint8_t* right = pointer_logical_and(src + nbytes, ~(block_size - 1));
+ int64_t num_blocks = (right - left) / block_size;
+
+ // Update right address
+ right = right - (num_blocks % num_threads) * block_size;
+
+ // Now we divide these blocks between available threads. The remainder is
+ // handled separately.
+ size_t chunk_size = (right - left) / num_threads;
+ int64_t prefix = left - src;
+ int64_t suffix = src + nbytes - right;
+ // Now the data layout is | prefix | k * num_threads * block_size | suffix |.
+ // We have chunk_size = k * block_size, therefore the data layout is
+ // | prefix | num_threads * chunk_size | suffix |.
+ // Each thread gets a "chunk" of k blocks.
+
+ // Start all parallel memcpy tasks and handle leftovers while threads run.
+ std::vector<Future<void*>> futures;
+
+ for (int i = 0; i < num_threads; i++) {
+ futures.push_back(*pool->Submit(wrap_memcpy, dst + prefix + i * chunk_size,
+ left + i * chunk_size, chunk_size));
+ }
+ memcpy(dst, src, prefix);
+ memcpy(dst + prefix + num_threads * chunk_size, right, suffix);
+
+ for (auto& fut : futures) {
+ ARROW_CHECK_OK(fut.status());
+ }
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/memory.h b/src/arrow/cpp/src/arrow/util/memory.h
new file mode 100644
index 000000000..4250d0694
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/memory.h
@@ -0,0 +1,43 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace internal {
+
+// A helper function for doing memcpy with multiple threads. This is required
+// to saturate the memory bandwidth of modern cpus.
+void parallel_memcopy(uint8_t* dst, const uint8_t* src, int64_t nbytes,
+ uintptr_t block_size, int num_threads);
+
+// A helper function for checking if two wrapped objects implementing `Equals`
+// are equal.
+template <typename T>
+bool SharedPtrEquals(const std::shared_ptr<T>& left, const std::shared_ptr<T>& right) {
+ if (left == right) return true;
+ if (left == NULLPTR || right == NULLPTR) return false;
+ return left->Equals(*right);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/mutex.cc b/src/arrow/cpp/src/arrow/util/mutex.cc
new file mode 100644
index 000000000..7456d7889
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/mutex.cc
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/mutex.h"
+
+#include <mutex>
+
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace util {
+
+struct Mutex::Impl {
+ std::mutex mutex_;
+};
+
+Mutex::Guard::Guard(Mutex* locked)
+ : locked_(locked, [](Mutex* locked) {
+ DCHECK(!locked->impl_->mutex_.try_lock());
+ locked->impl_->mutex_.unlock();
+ }) {}
+
+Mutex::Guard Mutex::TryLock() {
+ DCHECK_NE(impl_, nullptr);
+ if (impl_->mutex_.try_lock()) {
+ return Guard{this};
+ }
+ return Guard{};
+}
+
+Mutex::Guard Mutex::Lock() {
+ DCHECK_NE(impl_, nullptr);
+ impl_->mutex_.lock();
+ return Guard{this};
+}
+
+Mutex::Mutex() : impl_(new Impl, [](Impl* impl) { delete impl; }) {}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/mutex.h b/src/arrow/cpp/src/arrow/util/mutex.h
new file mode 100644
index 000000000..f4fc64181
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/mutex.h
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace util {
+
+/// A wrapper around std::mutex since we can't use it directly in
+/// public headers due to C++/CLI.
+/// https://docs.microsoft.com/en-us/cpp/standard-library/mutex#remarks
+class ARROW_EXPORT Mutex {
+ public:
+ Mutex();
+ Mutex(Mutex&&) = default;
+ Mutex& operator=(Mutex&&) = default;
+
+ /// A Guard is falsy if a lock could not be acquired.
+ class ARROW_EXPORT Guard {
+ public:
+ Guard() : locked_(NULLPTR, [](Mutex* mutex) {}) {}
+ Guard(Guard&&) = default;
+ Guard& operator=(Guard&&) = default;
+
+ explicit operator bool() const { return bool(locked_); }
+
+ void Unlock() { locked_.reset(); }
+
+ private:
+ explicit Guard(Mutex* locked);
+
+ std::unique_ptr<Mutex, void (*)(Mutex*)> locked_;
+ friend Mutex;
+ };
+
+ Guard TryLock();
+ Guard Lock();
+
+ private:
+ struct Impl;
+ std::unique_ptr<Impl, void (*)(Impl*)> impl_;
+};
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/optional.h b/src/arrow/cpp/src/arrow/util/optional.h
new file mode 100644
index 000000000..e1c32e761
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/optional.h
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#define optional_CONFIG_SELECT_OPTIONAL optional_OPTIONAL_NONSTD
+
+#include "arrow/vendored/optional.hpp" // IWYU pragma: export
+
+namespace arrow {
+namespace util {
+
+template <typename T>
+using optional = nonstd::optional<T>;
+
+using nonstd::bad_optional_access;
+using nonstd::make_optional;
+using nonstd::nullopt;
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/parallel.h b/src/arrow/cpp/src/arrow/util/parallel.h
new file mode 100644
index 000000000..80f60fbdb
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/parallel.h
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <utility>
+#include <vector>
+
+#include "arrow/status.h"
+#include "arrow/util/functional.h"
+#include "arrow/util/thread_pool.h"
+#include "arrow/util/vector.h"
+
+namespace arrow {
+namespace internal {
+
+// A parallelizer that takes a `Status(int)` function and calls it with
+// arguments between 0 and `num_tasks - 1`, on an arbitrary number of threads.
+
+template <class FUNCTION>
+Status ParallelFor(int num_tasks, FUNCTION&& func,
+ Executor* executor = internal::GetCpuThreadPool()) {
+ std::vector<Future<>> futures(num_tasks);
+
+ for (int i = 0; i < num_tasks; ++i) {
+ ARROW_ASSIGN_OR_RAISE(futures[i], executor->Submit(func, i));
+ }
+ auto st = Status::OK();
+ for (auto& fut : futures) {
+ st &= fut.status();
+ }
+ return st;
+}
+
+template <class FUNCTION, typename T,
+ typename R = typename internal::call_traits::return_type<FUNCTION>::ValueType>
+Future<std::vector<R>> ParallelForAsync(
+ std::vector<T> inputs, FUNCTION&& func,
+ Executor* executor = internal::GetCpuThreadPool()) {
+ std::vector<Future<R>> futures(inputs.size());
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(futures[i], executor->Submit(func, i, std::move(inputs[i])));
+ }
+ return All(std::move(futures))
+ .Then([](const std::vector<Result<R>>& results) -> Result<std::vector<R>> {
+ return UnwrapOrRaise(results);
+ });
+}
+
+// A parallelizer that takes a `Status(int)` function and calls it with
+// arguments between 0 and `num_tasks - 1`, in sequence or in parallel,
+// depending on the input boolean.
+
+template <class FUNCTION>
+Status OptionalParallelFor(bool use_threads, int num_tasks, FUNCTION&& func,
+ Executor* executor = internal::GetCpuThreadPool()) {
+ if (use_threads) {
+ return ParallelFor(num_tasks, std::forward<FUNCTION>(func), executor);
+ } else {
+ for (int i = 0; i < num_tasks; ++i) {
+ RETURN_NOT_OK(func(i));
+ }
+ return Status::OK();
+ }
+}
+
+// A parallelizer that takes a `Result<R>(int index, T item)` function and
+// calls it with each item from the input array, in sequence or in parallel,
+// depending on the input boolean.
+
+template <class FUNCTION, typename T,
+ typename R = typename internal::call_traits::return_type<FUNCTION>::ValueType>
+Future<std::vector<R>> OptionalParallelForAsync(
+ bool use_threads, std::vector<T> inputs, FUNCTION&& func,
+ Executor* executor = internal::GetCpuThreadPool()) {
+ if (use_threads) {
+ return ParallelForAsync(std::move(inputs), std::forward<FUNCTION>(func), executor);
+ } else {
+ std::vector<R> result(inputs.size());
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(result[i], func(i, inputs[i]));
+ }
+ return result;
+ }
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/pcg_random.h b/src/arrow/cpp/src/arrow/util/pcg_random.h
new file mode 100644
index 000000000..a53e9ec31
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/pcg_random.h
@@ -0,0 +1,31 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/vendored/pcg/pcg_random.hpp" // IWYU pragma: export
+
+namespace arrow {
+namespace random {
+
+using pcg32 = ::arrow_vendored::pcg32;
+using pcg64 = ::arrow_vendored::pcg64;
+using pcg32_fast = ::arrow_vendored::pcg32_fast;
+using pcg64_fast = ::arrow_vendored::pcg64_fast;
+
+} // namespace random
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/print.h b/src/arrow/cpp/src/arrow/util/print.h
new file mode 100644
index 000000000..d11aa443a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/print.h
@@ -0,0 +1,51 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License. template <typename T>
+
+#pragma once
+
+#include <tuple>
+
+namespace arrow {
+namespace internal {
+
+namespace detail {
+
+template <typename OStream, typename Tuple, size_t N>
+struct TuplePrinter {
+ static void Print(OStream* os, const Tuple& t) {
+ TuplePrinter<OStream, Tuple, N - 1>::Print(os, t);
+ *os << std::get<N - 1>(t);
+ }
+};
+
+template <typename OStream, typename Tuple>
+struct TuplePrinter<OStream, Tuple, 0> {
+ static void Print(OStream* os, const Tuple& t) {}
+};
+
+} // namespace detail
+
+// Print elements from a tuple to a stream, in order.
+// Typical use is to pack a bunch of existing values with std::forward_as_tuple()
+// before passing it to this function.
+template <typename OStream, typename... Args>
+void PrintTuple(OStream* os, const std::tuple<Args&...>& tup) {
+ detail::TuplePrinter<OStream, std::tuple<Args&...>, sizeof...(Args)>::Print(os, tup);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/queue.h b/src/arrow/cpp/src/arrow/util/queue.h
new file mode 100644
index 000000000..6c71fa6e1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/queue.h
@@ -0,0 +1,29 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/vendored/ProducerConsumerQueue.h"
+
+namespace arrow {
+namespace util {
+
+template <typename T>
+using SpscQueue = arrow_vendored::folly::ProducerConsumerQueue<T>;
+
+}
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/queue_benchmark.cc b/src/arrow/cpp/src/arrow/util/queue_benchmark.cc
new file mode 100644
index 000000000..675bef831
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/queue_benchmark.cc
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <iterator>
+#include <thread>
+#include <vector>
+
+#include <benchmark/benchmark.h>
+
+#include "arrow/buffer.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/queue.h"
+
+namespace arrow {
+
+namespace util {
+
+static constexpr int64_t kSize = 100000;
+
+void SpscQueueThroughput(benchmark::State& state) {
+ SpscQueue<std::shared_ptr<Buffer>> queue(16);
+
+ std::vector<std::shared_ptr<Buffer>> source;
+ std::vector<std::shared_ptr<Buffer>> sink;
+ source.reserve(kSize);
+ sink.resize(kSize);
+ const uint8_t data[1] = {0};
+ for (int64_t i = 0; i < kSize; i++) {
+ source.push_back(std::make_shared<Buffer>(data, 1));
+ }
+
+ for (auto _ : state) {
+ std::thread producer([&] {
+ auto itr = std::make_move_iterator(source.begin());
+ auto end = std::make_move_iterator(source.end());
+ while (itr != end) {
+ while (!queue.Write(*itr)) {
+ }
+ itr++;
+ }
+ });
+
+ std::thread consumer([&] {
+ auto itr = sink.begin();
+ auto end = sink.end();
+ while (itr != end) {
+ auto next = queue.FrontPtr();
+ if (next != nullptr) {
+ (*itr).swap(*next);
+ queue.PopFront();
+ itr++;
+ }
+ }
+ });
+
+ producer.join();
+ consumer.join();
+ std::swap(source, sink);
+ }
+
+ for (const auto& buf : source) {
+ ARROW_CHECK(buf && buf->size() == 1);
+ }
+ state.SetItemsProcessed(state.iterations() * kSize);
+}
+
+BENCHMARK(SpscQueueThroughput)->UseRealTime();
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/queue_test.cc b/src/arrow/cpp/src/arrow/util/queue_test.cc
new file mode 100644
index 000000000..388e4f11b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/queue_test.cc
@@ -0,0 +1,55 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/queue.h"
+
+namespace arrow {
+namespace util {
+
+TEST(TestSpscQueue, TestMoveOnly) {
+ SpscQueue<MoveOnlyDataType> queue(3);
+ ASSERT_TRUE(queue.IsEmpty());
+ ASSERT_FALSE(queue.IsFull());
+ ASSERT_EQ(queue.SizeGuess(), 0);
+
+ MoveOnlyDataType in(42);
+ queue.Write(std::move(in));
+ ASSERT_FALSE(queue.IsEmpty());
+ ASSERT_FALSE(queue.IsFull());
+ ASSERT_EQ(queue.SizeGuess(), 1);
+
+ queue.Write(43);
+ ASSERT_FALSE(queue.IsEmpty());
+ ASSERT_TRUE(queue.IsFull());
+ ASSERT_EQ(queue.SizeGuess(), 2);
+
+ MoveOnlyDataType out = std::move(*queue.FrontPtr());
+ ASSERT_EQ(42, *out.data);
+ queue.PopFront();
+ ASSERT_TRUE(queue.Read(out));
+ ASSERT_EQ(43, *out.data);
+
+ ASSERT_TRUE(queue.IsEmpty());
+ ASSERT_FALSE(queue.IsFull());
+ ASSERT_EQ(queue.SizeGuess(), 0);
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/range.h b/src/arrow/cpp/src/arrow/util/range.h
new file mode 100644
index 000000000..ea0fb0eea
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/range.h
@@ -0,0 +1,155 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <numeric>
+#include <utility>
+#include <vector>
+
+namespace arrow {
+namespace internal {
+
+/// Create a vector containing the values from start up to stop
+template <typename T>
+std::vector<T> Iota(T start, T stop) {
+ if (start > stop) {
+ return {};
+ }
+ std::vector<T> result(static_cast<size_t>(stop - start));
+ std::iota(result.begin(), result.end(), start);
+ return result;
+}
+
+/// Create a vector containing the values from 0 up to length
+template <typename T>
+std::vector<T> Iota(T length) {
+ return Iota(static_cast<T>(0), length);
+}
+
+/// Create a range from a callable which takes a single index parameter
+/// and returns the value of iterator on each call and a length.
+/// Only iterators obtained from the same range should be compared, the
+/// behaviour generally similar to other STL containers.
+template <typename Generator>
+class LazyRange {
+ private:
+ // callable which generates the values
+ // has to be defined at the beginning of the class for type deduction
+ const Generator gen_;
+ // the length of the range
+ int64_t length_;
+#ifdef _MSC_VER
+ // workaround to VS2010 not supporting decltype properly
+ // see https://stackoverflow.com/questions/21782846/decltype-for-class-member-function
+ static Generator gen_static_;
+#endif
+
+ public:
+#ifdef _MSC_VER
+ using return_type = decltype(gen_static_(0));
+#else
+ using return_type = decltype(gen_(0));
+#endif
+
+ /// Construct a new range from a callable and length
+ LazyRange(Generator gen, int64_t length) : gen_(gen), length_(length) {}
+
+ // Class of the dependent iterator, created implicitly by begin and end
+ class RangeIter {
+ public:
+ using difference_type = int64_t;
+ using value_type = return_type;
+ using reference = const value_type&;
+ using pointer = const value_type*;
+ using iterator_category = std::forward_iterator_tag;
+
+#ifdef _MSC_VER
+ // msvc complains about unchecked iterators,
+ // see https://stackoverflow.com/questions/21655496/error-c4996-checked-iterators
+ using _Unchecked_type = typename LazyRange<Generator>::RangeIter;
+#endif
+
+ RangeIter() = delete;
+ RangeIter(const RangeIter& other) = default;
+ RangeIter& operator=(const RangeIter& other) = default;
+
+ RangeIter(const LazyRange<Generator>& range, int64_t index)
+ : range_(&range), index_(index) {}
+
+ const return_type operator*() const { return range_->gen_(index_); }
+
+ RangeIter operator+(difference_type length) const {
+ return RangeIter(*range_, index_ + length);
+ }
+
+ // pre-increment
+ RangeIter& operator++() {
+ ++index_;
+ return *this;
+ }
+
+ // post-increment
+ RangeIter operator++(int) {
+ auto copy = RangeIter(*this);
+ ++index_;
+ return copy;
+ }
+
+ bool operator==(const typename LazyRange<Generator>::RangeIter& other) const {
+ return this->index_ == other.index_ && this->range_ == other.range_;
+ }
+
+ bool operator!=(const typename LazyRange<Generator>::RangeIter& other) const {
+ return this->index_ != other.index_ || this->range_ != other.range_;
+ }
+
+ int64_t operator-(const typename LazyRange<Generator>::RangeIter& other) const {
+ return this->index_ - other.index_;
+ }
+
+ bool operator<(const typename LazyRange<Generator>::RangeIter& other) const {
+ return this->index_ < other.index_;
+ }
+
+ private:
+ // parent range reference
+ const LazyRange* range_;
+ // current index
+ int64_t index_;
+ };
+
+ friend class RangeIter;
+
+ // Create a new begin const iterator
+ RangeIter begin() { return RangeIter(*this, 0); }
+
+ // Create a new end const iterator
+ RangeIter end() { return RangeIter(*this, length_); }
+};
+
+/// Helper function to create a lazy range from a callable (e.g. lambda) and length
+template <typename Generator>
+LazyRange<Generator> MakeLazyRange(Generator&& gen, int64_t length) {
+ return LazyRange<Generator>(std::forward<Generator>(gen), length);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/range_benchmark.cc b/src/arrow/cpp/src/arrow/util/range_benchmark.cc
new file mode 100644
index 000000000..204fd24f7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/range_benchmark.cc
@@ -0,0 +1,128 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <iterator>
+#include <vector>
+
+#include <benchmark/benchmark.h>
+
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/range.h"
+
+namespace arrow {
+
+#ifdef ARROW_WITH_BENCHMARKS_REFERENCE
+
+static constexpr int64_t kSize = 100000000;
+
+template <typename T = int32_t>
+std::vector<T> generate_junk(int64_t size) {
+ std::vector<T> v(size);
+ randint(size, 0, 100000, &v);
+ return v;
+}
+
+// Baseline
+void for_loop(benchmark::State& state) {
+ auto source = generate_junk(kSize);
+ std::vector<int> target(kSize);
+
+ for (auto _ : state) {
+ for (int64_t index = 0; index < kSize; ++index) target[index] = source[index] + 1;
+ }
+}
+
+BENCHMARK(for_loop);
+
+// For comparison: pure copy without any changes
+void std_copy(benchmark::State& state) {
+ auto source = generate_junk(kSize);
+ std::vector<int> target(kSize);
+
+ for (auto _ : state) {
+ std::copy(source.begin(), source.end(), target.begin());
+ }
+}
+
+BENCHMARK(std_copy);
+
+// For comparison: pure copy with type conversion.
+void std_copy_converting(benchmark::State& state) {
+ auto source = generate_junk<int32_t>(kSize);
+ // bigger type to avoid warnings
+ std::vector<int64_t> target(kSize);
+
+ for (auto _ : state) {
+ std::copy(source.begin(), source.end(), target.begin());
+ }
+}
+
+BENCHMARK(std_copy_converting);
+
+// std::copy with a lazy range as a source
+void lazy_copy(benchmark::State& state) {
+ auto source = generate_junk(kSize);
+ std::vector<int> target(kSize);
+ auto lazy_range = internal::MakeLazyRange(
+ [&source](int64_t index) { return source[index]; }, source.size());
+
+ for (auto _ : state) {
+ std::copy(lazy_range.begin(), lazy_range.end(), target.begin());
+ }
+}
+
+BENCHMARK(lazy_copy);
+
+// std::copy with a lazy range which does static cast.
+// Should be the same performance as std::copy with differently typed iterators
+void lazy_copy_converting(benchmark::State& state) {
+ auto source = generate_junk<int64_t>(kSize);
+ std::vector<int32_t> target(kSize);
+ auto lazy_range = internal::MakeLazyRange(
+ [&source](int64_t index) { return static_cast<int32_t>(source[index]); },
+ source.size());
+
+ for (auto _ : state) {
+ std::copy(lazy_range.begin(), lazy_range.end(), target.begin());
+ }
+}
+
+BENCHMARK(lazy_copy_converting);
+
+// For loop with a post-increment of a lazy operator
+void lazy_postinc(benchmark::State& state) {
+ auto source = generate_junk(kSize);
+ std::vector<int> target(kSize);
+ auto lazy_range = internal::MakeLazyRange(
+ [&source](int64_t index) { return source[index]; }, source.size());
+
+ for (auto _ : state) {
+ auto lazy_iter = lazy_range.begin();
+ auto lazy_end = lazy_range.end();
+ auto target_iter = target.begin();
+
+ while (lazy_iter != lazy_end) *(target_iter++) = *(lazy_iter++);
+ }
+}
+
+BENCHMARK(lazy_postinc);
+
+#endif // ARROW_WITH_BENCHMARKS_REFERENCE
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/range_test.cc b/src/arrow/cpp/src/arrow/util/range_test.cc
new file mode 100644
index 000000000..7fedcde99
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/range_test.cc
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/range.h"
+
+namespace arrow {
+
+class TestLazyIter : public ::testing::Test {
+ public:
+ int64_t kSize = 1000;
+ void SetUp() {
+ randint(kSize, 0, 1000000, &source_);
+ target_.resize(kSize);
+ }
+
+ protected:
+ std::vector<int> source_;
+ std::vector<int> target_;
+};
+
+TEST_F(TestLazyIter, TestIncrementCopy) {
+ auto add_one = [this](int64_t index) { return source_[index] + 1; };
+ auto lazy_range = internal::MakeLazyRange(add_one, kSize);
+ std::copy(lazy_range.begin(), lazy_range.end(), target_.begin());
+
+ for (int64_t index = 0; index < kSize; ++index) {
+ ASSERT_EQ(source_[index] + 1, target_[index]);
+ }
+}
+
+TEST_F(TestLazyIter, TestPostIncrementCopy) {
+ auto add_one = [this](int64_t index) { return source_[index] + 1; };
+ auto lazy_range = internal::MakeLazyRange(add_one, kSize);
+ auto iter = lazy_range.begin();
+ auto end = lazy_range.end();
+ auto target_iter = target_.begin();
+
+ while (iter != end) {
+ *(target_iter++) = *(iter++);
+ }
+
+ for (size_t index = 0, limit = source_.size(); index != limit; ++index) {
+ ASSERT_EQ(source_[index] + 1, target_[index]);
+ }
+}
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/reflection_internal.h b/src/arrow/cpp/src/arrow/util/reflection_internal.h
new file mode 100644
index 000000000..0440a2eb5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/reflection_internal.h
@@ -0,0 +1,133 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+#include <tuple>
+#include <utility>
+
+#include "arrow/type_traits.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace internal {
+
+template <size_t...>
+struct index_sequence {};
+
+template <size_t N, size_t Head = N, size_t... Tail>
+struct make_index_sequence_impl;
+
+template <size_t N>
+using make_index_sequence = typename make_index_sequence_impl<N>::type;
+
+template <typename... T>
+using index_sequence_for = make_index_sequence<sizeof...(T)>;
+
+template <size_t N, size_t... I>
+struct make_index_sequence_impl<N, 0, I...> {
+ using type = index_sequence<I...>;
+};
+
+template <size_t N, size_t H, size_t... I>
+struct make_index_sequence_impl : make_index_sequence_impl<N, H - 1, H - 1, I...> {};
+
+static_assert(std::is_same<index_sequence<>, make_index_sequence<0>>::value, "");
+static_assert(std::is_same<index_sequence<0, 1, 2>, make_index_sequence<3>>::value, "");
+
+template <typename...>
+struct all_same : std::true_type {};
+
+template <typename One>
+struct all_same<One> : std::true_type {};
+
+template <typename Same, typename... Rest>
+struct all_same<Same, Same, Rest...> : all_same<Same, Rest...> {};
+
+template <typename One, typename Other, typename... Rest>
+struct all_same<One, Other, Rest...> : std::false_type {};
+
+template <size_t... I, typename... T, typename Fn>
+void ForEachTupleMemberImpl(const std::tuple<T...>& tup, Fn&& fn, index_sequence<I...>) {
+ (void)std::make_tuple((fn(std::get<I>(tup), I), std::ignore)...);
+}
+
+template <typename... T, typename Fn>
+void ForEachTupleMember(const std::tuple<T...>& tup, Fn&& fn) {
+ ForEachTupleMemberImpl(tup, fn, index_sequence_for<T...>());
+}
+
+template <typename C, typename T>
+struct DataMemberProperty {
+ using Class = C;
+ using Type = T;
+
+ constexpr const Type& get(const Class& obj) const { return obj.*ptr_; }
+
+ void set(Class* obj, Type value) const { (*obj).*ptr_ = std::move(value); }
+
+ constexpr util::string_view name() const { return name_; }
+
+ util::string_view name_;
+ Type Class::*ptr_;
+};
+
+template <typename Class, typename Type>
+constexpr DataMemberProperty<Class, Type> DataMember(util::string_view name,
+ Type Class::*ptr) {
+ return {name, ptr};
+}
+
+template <typename... Properties>
+struct PropertyTuple {
+ template <typename Fn>
+ void ForEach(Fn&& fn) const {
+ ForEachTupleMember(props_, fn);
+ }
+
+ static_assert(all_same<typename Properties::Class...>::value,
+ "All properties must be properties of the same class");
+
+ size_t size() const { return sizeof...(Properties); }
+
+ std::tuple<Properties...> props_;
+};
+
+template <typename... Properties>
+PropertyTuple<Properties...> MakeProperties(Properties... props) {
+ return {std::make_tuple(props...)};
+}
+
+template <typename Enum>
+struct EnumTraits {};
+
+template <typename Enum, Enum... Values>
+struct BasicEnumTraits {
+ using CType = typename std::underlying_type<Enum>::type;
+ using Type = typename CTypeTraits<CType>::ArrowType;
+ static std::array<Enum, sizeof...(Values)> values() { return {Values...}; }
+};
+
+template <typename T, typename Enable = void>
+struct has_enum_traits : std::false_type {};
+
+template <typename T>
+struct has_enum_traits<T, void_t<typename EnumTraits<T>::Type>> : std::true_type {};
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/reflection_test.cc b/src/arrow/cpp/src/arrow/util/reflection_test.cc
new file mode 100644
index 000000000..fb3d3b8fb
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/reflection_test.cc
@@ -0,0 +1,224 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <sstream>
+
+#include <gtest/gtest.h>
+
+#include "arrow/util/reflection_internal.h"
+#include "arrow/util/string.h"
+
+namespace arrow {
+namespace internal {
+
+// generic property-based equality comparison
+template <typename Class>
+struct EqualsImpl {
+ template <typename Properties>
+ EqualsImpl(const Class& l, const Class& r, const Properties& props)
+ : left_(l), right_(r) {
+ props.ForEach(*this);
+ }
+
+ template <typename Property>
+ void operator()(const Property& prop, size_t i) {
+ equal_ &= prop.get(left_) == prop.get(right_);
+ }
+
+ const Class& left_;
+ const Class& right_;
+ bool equal_ = true;
+};
+
+// generic property-based serialization
+template <typename Class>
+struct ToStringImpl {
+ template <typename Properties>
+ ToStringImpl(util::string_view class_name, const Class& obj, const Properties& props)
+ : class_name_(class_name), obj_(obj), members_(props.size()) {
+ props.ForEach(*this);
+ }
+
+ template <typename Property>
+ void operator()(const Property& prop, size_t i) {
+ std::stringstream ss;
+ ss << prop.name() << ":" << prop.get(obj_);
+ members_[i] = ss.str();
+ }
+
+ std::string Finish() {
+ return class_name_.to_string() + "{" + JoinStrings(members_, ",") + "}";
+ }
+
+ util::string_view class_name_;
+ const Class& obj_;
+ std::vector<std::string> members_;
+};
+
+// generic property-based deserialization
+template <typename Class>
+struct FromStringImpl {
+ template <typename Properties>
+ FromStringImpl(util::string_view class_name, util::string_view repr,
+ const Properties& props) {
+ Init(class_name, repr, props.size());
+ props.ForEach(*this);
+ }
+
+ void Fail() { obj_ = util::nullopt; }
+
+ void Init(util::string_view class_name, util::string_view repr, size_t num_properties) {
+ if (!repr.starts_with(class_name)) return Fail();
+
+ repr = repr.substr(class_name.size());
+ if (repr.empty()) return Fail();
+ if (repr.front() != '{') return Fail();
+ if (repr.back() != '}') return Fail();
+
+ repr = repr.substr(1, repr.size() - 2);
+ members_ = SplitString(repr, ',');
+ if (members_.size() != num_properties) return Fail();
+ }
+
+ template <typename Property>
+ void operator()(const Property& prop, size_t i) {
+ if (!obj_) return;
+
+ auto first_colon = members_[i].find_first_of(':');
+ if (first_colon == util::string_view::npos) return Fail();
+
+ auto name = members_[i].substr(0, first_colon);
+ if (name != prop.name()) return Fail();
+
+ auto value_repr = members_[i].substr(first_colon + 1);
+ typename Property::Type value;
+ try {
+ std::stringstream ss(value_repr.to_string());
+ ss >> value;
+ if (!ss.eof()) return Fail();
+ } catch (...) {
+ return Fail();
+ }
+ prop.set(&*obj_, std::move(value));
+ }
+
+ util::optional<Class> obj_ = Class{};
+ std::vector<util::string_view> members_;
+};
+
+// unmodified structure which we wish to reflect on:
+struct Person {
+ int age;
+ std::string name;
+};
+
+// enumeration of properties:
+// NB: no references to Person::age or Person::name after this
+// NB: ordering of properties follows this enum, regardless of
+// order of declaration in `struct Person`
+static auto kPersonProperties =
+ MakeProperties(DataMember("age", &Person::age), DataMember("name", &Person::name));
+
+// use generic facilities to define equality, serialization and deserialization
+bool operator==(const Person& l, const Person& r) {
+ return EqualsImpl<Person>{l, r, kPersonProperties}.equal_;
+}
+
+bool operator!=(const Person& l, const Person& r) { return !(l == r); }
+
+std::string ToString(const Person& obj) {
+ return ToStringImpl<Person>{"Person", obj, kPersonProperties}.Finish();
+}
+
+void PrintTo(const Person& obj, std::ostream* os) { *os << ToString(obj); }
+
+util::optional<Person> PersonFromString(util::string_view repr) {
+ return FromStringImpl<Person>("Person", repr, kPersonProperties).obj_;
+}
+
+TEST(Reflection, EqualityWithDataMembers) {
+ Person genos{19, "Genos"};
+ Person kuseno{45, "Kuseno"};
+
+ EXPECT_EQ(genos, genos);
+ EXPECT_EQ(kuseno, kuseno);
+
+ EXPECT_NE(genos, kuseno);
+ EXPECT_NE(kuseno, genos);
+}
+
+TEST(Reflection, ToStringFromDataMembers) {
+ Person genos{19, "Genos"};
+ Person kuseno{45, "Kuseno"};
+
+ EXPECT_EQ(ToString(genos), "Person{age:19,name:Genos}");
+ EXPECT_EQ(ToString(kuseno), "Person{age:45,name:Kuseno}");
+}
+
+TEST(Reflection, FromStringToDataMembers) {
+ Person genos{19, "Genos"};
+
+ EXPECT_EQ(PersonFromString(ToString(genos)), genos);
+
+ EXPECT_EQ(PersonFromString(""), util::nullopt);
+ EXPECT_EQ(PersonFromString("Per"), util::nullopt);
+ EXPECT_EQ(PersonFromString("Person{"), util::nullopt);
+ EXPECT_EQ(PersonFromString("Person{age:19,name:Genos"), util::nullopt);
+
+ EXPECT_EQ(PersonFromString("Person{name:Genos"), util::nullopt);
+ EXPECT_EQ(PersonFromString("Person{age:19,name:Genos,extra:Cyborg}"), util::nullopt);
+ EXPECT_EQ(PersonFromString("Person{name:Genos,age:19"), util::nullopt);
+
+ EXPECT_EQ(PersonFromString("Fake{age:19,name:Genos}"), util::nullopt);
+
+ EXPECT_EQ(PersonFromString("Person{age,name:Genos}"), util::nullopt);
+ EXPECT_EQ(PersonFromString("Person{age:nineteen,name:Genos}"), util::nullopt);
+ EXPECT_EQ(PersonFromString("Person{age:19 ,name:Genos}"), util::nullopt);
+ EXPECT_EQ(PersonFromString("Person{age:19,moniker:Genos}"), util::nullopt);
+
+ EXPECT_EQ(PersonFromString("Person{age: 19, name: Genos}"), util::nullopt);
+}
+
+enum class PersonType : int8_t {
+ EMPLOYEE,
+ CONTRACTOR,
+};
+
+template <>
+struct EnumTraits<PersonType>
+ : BasicEnumTraits<PersonType, PersonType::EMPLOYEE, PersonType::CONTRACTOR> {
+ static std::string name() { return "PersonType"; }
+ static std::string value_name(PersonType value) {
+ switch (value) {
+ case PersonType::EMPLOYEE:
+ return "EMPLOYEE";
+ case PersonType::CONTRACTOR:
+ return "CONTRACTOR";
+ }
+ return "<INVALID>";
+ }
+};
+
+TEST(Reflection, EnumTraits) {
+ static_assert(!has_enum_traits<Person>::value, "");
+ static_assert(has_enum_traits<PersonType>::value, "");
+ static_assert(std::is_same<EnumTraits<PersonType>::CType, int8_t>::value, "");
+ static_assert(std::is_same<EnumTraits<PersonType>::Type, Int8Type>::value, "");
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/rle_encoding.h b/src/arrow/cpp/src/arrow/util/rle_encoding.h
new file mode 100644
index 000000000..e9018fdf0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/rle_encoding.h
@@ -0,0 +1,826 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Imported from Apache Impala (incubating) on 2016-01-29 and modified for use
+// in parquet-cpp, Arrow
+
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <vector>
+
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bit_stream_utils.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace util {
+
+/// Utility classes to do run length encoding (RLE) for fixed bit width values. If runs
+/// are sufficiently long, RLE is used, otherwise, the values are just bit-packed
+/// (literal encoding).
+/// For both types of runs, there is a byte-aligned indicator which encodes the length
+/// of the run and the type of the run.
+/// This encoding has the benefit that when there aren't any long enough runs, values
+/// are always decoded at fixed (can be precomputed) bit offsets OR both the value and
+/// the run length are byte aligned. This allows for very efficient decoding
+/// implementations.
+/// The encoding is:
+/// encoded-block := run*
+/// run := literal-run | repeated-run
+/// literal-run := literal-indicator < literal bytes >
+/// repeated-run := repeated-indicator < repeated value. padded to byte boundary >
+/// literal-indicator := varint_encode( number_of_groups << 1 | 1)
+/// repeated-indicator := varint_encode( number_of_repetitions << 1 )
+//
+/// Each run is preceded by a varint. The varint's least significant bit is
+/// used to indicate whether the run is a literal run or a repeated run. The rest
+/// of the varint is used to determine the length of the run (eg how many times the
+/// value repeats).
+//
+/// In the case of literal runs, the run length is always a multiple of 8 (i.e. encode
+/// in groups of 8), so that no matter the bit-width of the value, the sequence will end
+/// on a byte boundary without padding.
+/// Given that we know it is a multiple of 8, we store the number of 8-groups rather than
+/// the actual number of encoded ints. (This means that the total number of encoded values
+/// can not be determined from the encoded data, since the number of values in the last
+/// group may not be a multiple of 8). For the last group of literal runs, we pad
+/// the group to 8 with zeros. This allows for 8 at a time decoding on the read side
+/// without the need for additional checks.
+//
+/// There is a break-even point when it is more storage efficient to do run length
+/// encoding. For 1 bit-width values, that point is 8 values. They require 2 bytes
+/// for both the repeated encoding or the literal encoding. This value can always
+/// be computed based on the bit-width.
+/// TODO: think about how to use this for strings. The bit packing isn't quite the same.
+//
+/// Examples with bit-width 1 (eg encoding booleans):
+/// ----------------------------------------
+/// 100 1s followed by 100 0s:
+/// <varint(100 << 1)> <1, padded to 1 byte> <varint(100 << 1)> <0, padded to 1 byte>
+/// - (total 4 bytes)
+//
+/// alternating 1s and 0s (200 total):
+/// 200 ints = 25 groups of 8
+/// <varint((25 << 1) | 1)> <25 bytes of values, bitpacked>
+/// (total 26 bytes, 1 byte overhead)
+//
+
+/// Decoder class for RLE encoded data.
+class RleDecoder {
+ public:
+ /// Create a decoder object. buffer/buffer_len is the decoded data.
+ /// bit_width is the width of each value (before encoding).
+ RleDecoder(const uint8_t* buffer, int buffer_len, int bit_width)
+ : bit_reader_(buffer, buffer_len),
+ bit_width_(bit_width),
+ current_value_(0),
+ repeat_count_(0),
+ literal_count_(0) {
+ DCHECK_GE(bit_width_, 0);
+ DCHECK_LE(bit_width_, 64);
+ }
+
+ RleDecoder() : bit_width_(-1) {}
+
+ void Reset(const uint8_t* buffer, int buffer_len, int bit_width) {
+ DCHECK_GE(bit_width, 0);
+ DCHECK_LE(bit_width, 64);
+ bit_reader_.Reset(buffer, buffer_len);
+ bit_width_ = bit_width;
+ current_value_ = 0;
+ repeat_count_ = 0;
+ literal_count_ = 0;
+ }
+
+ /// Gets the next value. Returns false if there are no more.
+ template <typename T>
+ bool Get(T* val);
+
+ /// Gets a batch of values. Returns the number of decoded elements.
+ template <typename T>
+ int GetBatch(T* values, int batch_size);
+
+ /// Like GetBatch but add spacing for null entries
+ template <typename T>
+ int GetBatchSpaced(int batch_size, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset, T* out);
+
+ /// Like GetBatch but the values are then decoded using the provided dictionary
+ template <typename T>
+ int GetBatchWithDict(const T* dictionary, int32_t dictionary_length, T* values,
+ int batch_size);
+
+ /// Like GetBatchWithDict but add spacing for null entries
+ ///
+ /// Null entries will be zero-initialized in `values` to avoid leaking
+ /// private data.
+ template <typename T>
+ int GetBatchWithDictSpaced(const T* dictionary, int32_t dictionary_length, T* values,
+ int batch_size, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset);
+
+ protected:
+ BitUtil::BitReader bit_reader_;
+ /// Number of bits needed to encode the value. Must be between 0 and 64.
+ int bit_width_;
+ uint64_t current_value_;
+ int32_t repeat_count_;
+ int32_t literal_count_;
+
+ private:
+ /// Fills literal_count_ and repeat_count_ with next values. Returns false if there
+ /// are no more.
+ template <typename T>
+ bool NextCounts();
+
+ /// Utility methods for retrieving spaced values.
+ template <typename T, typename RunType, typename Converter>
+ int GetSpaced(Converter converter, int batch_size, int null_count,
+ const uint8_t* valid_bits, int64_t valid_bits_offset, T* out);
+};
+
+/// Class to incrementally build the rle data. This class does not allocate any memory.
+/// The encoding has two modes: encoding repeated runs and literal runs.
+/// If the run is sufficiently short, it is more efficient to encode as a literal run.
+/// This class does so by buffering 8 values at a time. If they are not all the same
+/// they are added to the literal run. If they are the same, they are added to the
+/// repeated run. When we switch modes, the previous run is flushed out.
+class RleEncoder {
+ public:
+ /// buffer/buffer_len: preallocated output buffer.
+ /// bit_width: max number of bits for value.
+ /// TODO: consider adding a min_repeated_run_length so the caller can control
+ /// when values should be encoded as repeated runs. Currently this is derived
+ /// based on the bit_width, which can determine a storage optimal choice.
+ /// TODO: allow 0 bit_width (and have dict encoder use it)
+ RleEncoder(uint8_t* buffer, int buffer_len, int bit_width)
+ : bit_width_(bit_width), bit_writer_(buffer, buffer_len) {
+ DCHECK_GE(bit_width_, 0);
+ DCHECK_LE(bit_width_, 64);
+ max_run_byte_size_ = MinBufferSize(bit_width);
+ DCHECK_GE(buffer_len, max_run_byte_size_) << "Input buffer not big enough.";
+ Clear();
+ }
+
+ /// Returns the minimum buffer size needed to use the encoder for 'bit_width'
+ /// This is the maximum length of a single run for 'bit_width'.
+ /// It is not valid to pass a buffer less than this length.
+ static int MinBufferSize(int bit_width) {
+ /// 1 indicator byte and MAX_VALUES_PER_LITERAL_RUN 'bit_width' values.
+ int max_literal_run_size =
+ 1 +
+ static_cast<int>(BitUtil::BytesForBits(MAX_VALUES_PER_LITERAL_RUN * bit_width));
+ /// Up to kMaxVlqByteLength indicator and a single 'bit_width' value.
+ int max_repeated_run_size = BitUtil::BitReader::kMaxVlqByteLength +
+ static_cast<int>(BitUtil::BytesForBits(bit_width));
+ return std::max(max_literal_run_size, max_repeated_run_size);
+ }
+
+ /// Returns the maximum byte size it could take to encode 'num_values'.
+ static int MaxBufferSize(int bit_width, int num_values) {
+ // For a bit_width > 1, the worst case is the repetition of "literal run of length 8
+ // and then a repeated run of length 8".
+ // 8 values per smallest run, 8 bits per byte
+ int bytes_per_run = bit_width;
+ int num_runs = static_cast<int>(BitUtil::CeilDiv(num_values, 8));
+ int literal_max_size = num_runs + num_runs * bytes_per_run;
+
+ // In the very worst case scenario, the data is a concatenation of repeated
+ // runs of 8 values. Repeated run has a 1 byte varint followed by the
+ // bit-packed repeated value
+ int min_repeated_run_size = 1 + static_cast<int>(BitUtil::BytesForBits(bit_width));
+ int repeated_max_size =
+ static_cast<int>(BitUtil::CeilDiv(num_values, 8)) * min_repeated_run_size;
+
+ return std::max(literal_max_size, repeated_max_size);
+ }
+
+ /// Encode value. Returns true if the value fits in buffer, false otherwise.
+ /// This value must be representable with bit_width_ bits.
+ bool Put(uint64_t value);
+
+ /// Flushes any pending values to the underlying buffer.
+ /// Returns the total number of bytes written
+ int Flush();
+
+ /// Resets all the state in the encoder.
+ void Clear();
+
+ /// Returns pointer to underlying buffer
+ uint8_t* buffer() { return bit_writer_.buffer(); }
+ int32_t len() { return bit_writer_.bytes_written(); }
+
+ private:
+ /// Flushes any buffered values. If this is part of a repeated run, this is largely
+ /// a no-op.
+ /// If it is part of a literal run, this will call FlushLiteralRun, which writes
+ /// out the buffered literal values.
+ /// If 'done' is true, the current run would be written even if it would normally
+ /// have been buffered more. This should only be called at the end, when the
+ /// encoder has received all values even if it would normally continue to be
+ /// buffered.
+ void FlushBufferedValues(bool done);
+
+ /// Flushes literal values to the underlying buffer. If update_indicator_byte,
+ /// then the current literal run is complete and the indicator byte is updated.
+ void FlushLiteralRun(bool update_indicator_byte);
+
+ /// Flushes a repeated run to the underlying buffer.
+ void FlushRepeatedRun();
+
+ /// Checks and sets buffer_full_. This must be called after flushing a run to
+ /// make sure there are enough bytes remaining to encode the next run.
+ void CheckBufferFull();
+
+ /// The maximum number of values in a single literal run
+ /// (number of groups encodable by a 1-byte indicator * 8)
+ static const int MAX_VALUES_PER_LITERAL_RUN = (1 << 6) * 8;
+
+ /// Number of bits needed to encode the value. Must be between 0 and 64.
+ const int bit_width_;
+
+ /// Underlying buffer.
+ BitUtil::BitWriter bit_writer_;
+
+ /// If true, the buffer is full and subsequent Put()'s will fail.
+ bool buffer_full_;
+
+ /// The maximum byte size a single run can take.
+ int max_run_byte_size_;
+
+ /// We need to buffer at most 8 values for literals. This happens when the
+ /// bit_width is 1 (so 8 values fit in one byte).
+ /// TODO: generalize this to other bit widths
+ int64_t buffered_values_[8];
+
+ /// Number of values in buffered_values_
+ int num_buffered_values_;
+
+ /// The current (also last) value that was written and the count of how
+ /// many times in a row that value has been seen. This is maintained even
+ /// if we are in a literal run. If the repeat_count_ get high enough, we switch
+ /// to encoding repeated runs.
+ uint64_t current_value_;
+ int repeat_count_;
+
+ /// Number of literals in the current run. This does not include the literals
+ /// that might be in buffered_values_. Only after we've got a group big enough
+ /// can we decide if they should part of the literal_count_ or repeat_count_
+ int literal_count_;
+
+ /// Pointer to a byte in the underlying buffer that stores the indicator byte.
+ /// This is reserved as soon as we need a literal run but the value is written
+ /// when the literal run is complete.
+ uint8_t* literal_indicator_byte_;
+};
+
+template <typename T>
+inline bool RleDecoder::Get(T* val) {
+ return GetBatch(val, 1) == 1;
+}
+
+template <typename T>
+inline int RleDecoder::GetBatch(T* values, int batch_size) {
+ DCHECK_GE(bit_width_, 0);
+ int values_read = 0;
+
+ auto* out = values;
+
+ while (values_read < batch_size) {
+ int remaining = batch_size - values_read;
+
+ if (repeat_count_ > 0) { // Repeated value case.
+ int repeat_batch = std::min(remaining, repeat_count_);
+ std::fill(out, out + repeat_batch, static_cast<T>(current_value_));
+
+ repeat_count_ -= repeat_batch;
+ values_read += repeat_batch;
+ out += repeat_batch;
+ } else if (literal_count_ > 0) {
+ int literal_batch = std::min(remaining, literal_count_);
+ int actual_read = bit_reader_.GetBatch(bit_width_, out, literal_batch);
+ if (actual_read != literal_batch) {
+ return values_read;
+ }
+
+ literal_count_ -= literal_batch;
+ values_read += literal_batch;
+ out += literal_batch;
+ } else {
+ if (!NextCounts<T>()) return values_read;
+ }
+ }
+
+ return values_read;
+}
+
+template <typename T, typename RunType, typename Converter>
+inline int RleDecoder::GetSpaced(Converter converter, int batch_size, int null_count,
+ const uint8_t* valid_bits, int64_t valid_bits_offset,
+ T* out) {
+ if (ARROW_PREDICT_FALSE(null_count == batch_size)) {
+ converter.FillZero(out, out + batch_size);
+ return batch_size;
+ }
+
+ DCHECK_GE(bit_width_, 0);
+ int values_read = 0;
+ int values_remaining = batch_size - null_count;
+
+ // Assume no bits to start.
+ arrow::internal::BitRunReader bit_reader(valid_bits, valid_bits_offset,
+ /*length=*/batch_size);
+ arrow::internal::BitRun valid_run = bit_reader.NextRun();
+ while (values_read < batch_size) {
+ if (ARROW_PREDICT_FALSE(valid_run.length == 0)) {
+ valid_run = bit_reader.NextRun();
+ }
+
+ DCHECK_GT(batch_size, 0);
+ DCHECK_GT(valid_run.length, 0);
+
+ if (valid_run.set) {
+ if ((repeat_count_ == 0) && (literal_count_ == 0)) {
+ if (!NextCounts<RunType>()) return values_read;
+ DCHECK((repeat_count_ > 0) ^ (literal_count_ > 0));
+ }
+
+ if (repeat_count_ > 0) {
+ int repeat_batch = 0;
+ // Consume the entire repeat counts incrementing repeat_batch to
+ // be the total of nulls + values consumed, we only need to
+ // get the total count because we can fill in the same value for
+ // nulls and non-nulls. This proves to be a big efficiency win.
+ while (repeat_count_ > 0 && (values_read + repeat_batch) < batch_size) {
+ DCHECK_GT(valid_run.length, 0);
+ if (valid_run.set) {
+ int update_size = std::min(static_cast<int>(valid_run.length), repeat_count_);
+ repeat_count_ -= update_size;
+ repeat_batch += update_size;
+ valid_run.length -= update_size;
+ values_remaining -= update_size;
+ } else {
+ // We can consume all nulls here because we would do so on
+ // the next loop anyways.
+ repeat_batch += static_cast<int>(valid_run.length);
+ valid_run.length = 0;
+ }
+ if (valid_run.length == 0) {
+ valid_run = bit_reader.NextRun();
+ }
+ }
+ RunType current_value = static_cast<RunType>(current_value_);
+ if (ARROW_PREDICT_FALSE(!converter.IsValid(current_value))) {
+ return values_read;
+ }
+ converter.Fill(out, out + repeat_batch, current_value);
+ out += repeat_batch;
+ values_read += repeat_batch;
+ } else if (literal_count_ > 0) {
+ int literal_batch = std::min(values_remaining, literal_count_);
+ DCHECK_GT(literal_batch, 0);
+
+ // Decode the literals
+ constexpr int kBufferSize = 1024;
+ RunType indices[kBufferSize];
+ literal_batch = std::min(literal_batch, kBufferSize);
+ int actual_read = bit_reader_.GetBatch(bit_width_, indices, literal_batch);
+ if (ARROW_PREDICT_FALSE(actual_read != literal_batch)) {
+ return values_read;
+ }
+ if (!converter.IsValid(indices, /*length=*/actual_read)) {
+ return values_read;
+ }
+ int skipped = 0;
+ int literals_read = 0;
+ while (literals_read < literal_batch) {
+ if (valid_run.set) {
+ int update_size = std::min(literal_batch - literals_read,
+ static_cast<int>(valid_run.length));
+ converter.Copy(out, indices + literals_read, update_size);
+ literals_read += update_size;
+ out += update_size;
+ valid_run.length -= update_size;
+ } else {
+ converter.FillZero(out, out + valid_run.length);
+ out += valid_run.length;
+ skipped += static_cast<int>(valid_run.length);
+ valid_run.length = 0;
+ }
+ if (valid_run.length == 0) {
+ valid_run = bit_reader.NextRun();
+ }
+ }
+ literal_count_ -= literal_batch;
+ values_remaining -= literal_batch;
+ values_read += literal_batch + skipped;
+ }
+ } else {
+ converter.FillZero(out, out + valid_run.length);
+ out += valid_run.length;
+ values_read += static_cast<int>(valid_run.length);
+ valid_run.length = 0;
+ }
+ }
+ DCHECK_EQ(valid_run.length, 0);
+ DCHECK_EQ(values_remaining, 0);
+ return values_read;
+}
+
+// Converter for GetSpaced that handles runs that get returned
+// directly as output.
+template <typename T>
+struct PlainRleConverter {
+ T kZero = {};
+ inline bool IsValid(const T& values) const { return true; }
+ inline bool IsValid(const T* values, int32_t length) const { return true; }
+ inline void Fill(T* begin, T* end, const T& run_value) const {
+ std::fill(begin, end, run_value);
+ }
+ inline void FillZero(T* begin, T* end) { std::fill(begin, end, kZero); }
+ inline void Copy(T* out, const T* values, int length) const {
+ std::memcpy(out, values, length * sizeof(T));
+ }
+};
+
+template <typename T>
+inline int RleDecoder::GetBatchSpaced(int batch_size, int null_count,
+ const uint8_t* valid_bits,
+ int64_t valid_bits_offset, T* out) {
+ if (null_count == 0) {
+ return GetBatch<T>(out, batch_size);
+ }
+
+ PlainRleConverter<T> converter;
+ arrow::internal::BitBlockCounter block_counter(valid_bits, valid_bits_offset,
+ batch_size);
+
+ int total_processed = 0;
+ int processed = 0;
+ arrow::internal::BitBlockCount block;
+
+ do {
+ block = block_counter.NextFourWords();
+ if (block.length == 0) {
+ break;
+ }
+ if (block.AllSet()) {
+ processed = GetBatch<T>(out, block.length);
+ } else if (block.NoneSet()) {
+ converter.FillZero(out, out + block.length);
+ processed = block.length;
+ } else {
+ processed = GetSpaced<T, /*RunType=*/T, PlainRleConverter<T>>(
+ converter, block.length, block.length - block.popcount, valid_bits,
+ valid_bits_offset, out);
+ }
+ total_processed += processed;
+ out += block.length;
+ valid_bits_offset += block.length;
+ } while (processed == block.length);
+ return total_processed;
+}
+
+static inline bool IndexInRange(int32_t idx, int32_t dictionary_length) {
+ return idx >= 0 && idx < dictionary_length;
+}
+
+// Converter for GetSpaced that handles runs of returned dictionary
+// indices.
+template <typename T>
+struct DictionaryConverter {
+ T kZero = {};
+ const T* dictionary;
+ int32_t dictionary_length;
+
+ inline bool IsValid(int32_t value) { return IndexInRange(value, dictionary_length); }
+
+ inline bool IsValid(const int32_t* values, int32_t length) const {
+ using IndexType = int32_t;
+ IndexType min_index = std::numeric_limits<IndexType>::max();
+ IndexType max_index = std::numeric_limits<IndexType>::min();
+ for (int x = 0; x < length; x++) {
+ min_index = std::min(values[x], min_index);
+ max_index = std::max(values[x], max_index);
+ }
+
+ return IndexInRange(min_index, dictionary_length) &&
+ IndexInRange(max_index, dictionary_length);
+ }
+ inline void Fill(T* begin, T* end, const int32_t& run_value) const {
+ std::fill(begin, end, dictionary[run_value]);
+ }
+ inline void FillZero(T* begin, T* end) { std::fill(begin, end, kZero); }
+
+ inline void Copy(T* out, const int32_t* values, int length) const {
+ for (int x = 0; x < length; x++) {
+ out[x] = dictionary[values[x]];
+ }
+ }
+};
+
+template <typename T>
+inline int RleDecoder::GetBatchWithDict(const T* dictionary, int32_t dictionary_length,
+ T* values, int batch_size) {
+ // Per https://github.com/apache/parquet-format/blob/master/Encodings.md,
+ // the maximum dictionary index width in Parquet is 32 bits.
+ using IndexType = int32_t;
+ DictionaryConverter<T> converter;
+ converter.dictionary = dictionary;
+ converter.dictionary_length = dictionary_length;
+
+ DCHECK_GE(bit_width_, 0);
+ int values_read = 0;
+
+ auto* out = values;
+
+ while (values_read < batch_size) {
+ int remaining = batch_size - values_read;
+
+ if (repeat_count_ > 0) {
+ auto idx = static_cast<IndexType>(current_value_);
+ if (ARROW_PREDICT_FALSE(!IndexInRange(idx, dictionary_length))) {
+ return values_read;
+ }
+ T val = dictionary[idx];
+
+ int repeat_batch = std::min(remaining, repeat_count_);
+ std::fill(out, out + repeat_batch, val);
+
+ /* Upkeep counters */
+ repeat_count_ -= repeat_batch;
+ values_read += repeat_batch;
+ out += repeat_batch;
+ } else if (literal_count_ > 0) {
+ constexpr int kBufferSize = 1024;
+ IndexType indices[kBufferSize];
+
+ int literal_batch = std::min(remaining, literal_count_);
+ literal_batch = std::min(literal_batch, kBufferSize);
+
+ int actual_read = bit_reader_.GetBatch(bit_width_, indices, literal_batch);
+ if (ARROW_PREDICT_FALSE(actual_read != literal_batch)) {
+ return values_read;
+ }
+ if (ARROW_PREDICT_FALSE(!converter.IsValid(indices, /*length=*/literal_batch))) {
+ return values_read;
+ }
+ converter.Copy(out, indices, literal_batch);
+
+ /* Upkeep counters */
+ literal_count_ -= literal_batch;
+ values_read += literal_batch;
+ out += literal_batch;
+ } else {
+ if (!NextCounts<IndexType>()) return values_read;
+ }
+ }
+
+ return values_read;
+}
+
+template <typename T>
+inline int RleDecoder::GetBatchWithDictSpaced(const T* dictionary,
+ int32_t dictionary_length, T* out,
+ int batch_size, int null_count,
+ const uint8_t* valid_bits,
+ int64_t valid_bits_offset) {
+ if (null_count == 0) {
+ return GetBatchWithDict<T>(dictionary, dictionary_length, out, batch_size);
+ }
+ arrow::internal::BitBlockCounter block_counter(valid_bits, valid_bits_offset,
+ batch_size);
+ using IndexType = int32_t;
+ DictionaryConverter<T> converter;
+ converter.dictionary = dictionary;
+ converter.dictionary_length = dictionary_length;
+
+ int total_processed = 0;
+ int processed = 0;
+ arrow::internal::BitBlockCount block;
+ do {
+ block = block_counter.NextFourWords();
+ if (block.length == 0) {
+ break;
+ }
+ if (block.AllSet()) {
+ processed = GetBatchWithDict<T>(dictionary, dictionary_length, out, block.length);
+ } else if (block.NoneSet()) {
+ converter.FillZero(out, out + block.length);
+ processed = block.length;
+ } else {
+ processed = GetSpaced<T, /*RunType=*/IndexType, DictionaryConverter<T>>(
+ converter, block.length, block.length - block.popcount, valid_bits,
+ valid_bits_offset, out);
+ }
+ total_processed += processed;
+ out += block.length;
+ valid_bits_offset += block.length;
+ } while (processed == block.length);
+ return total_processed;
+}
+
+template <typename T>
+bool RleDecoder::NextCounts() {
+ // Read the next run's indicator int, it could be a literal or repeated run.
+ // The int is encoded as a vlq-encoded value.
+ uint32_t indicator_value = 0;
+ if (!bit_reader_.GetVlqInt(&indicator_value)) return false;
+
+ // lsb indicates if it is a literal run or repeated run
+ bool is_literal = indicator_value & 1;
+ uint32_t count = indicator_value >> 1;
+ if (is_literal) {
+ if (ARROW_PREDICT_FALSE(count == 0 || count > static_cast<uint32_t>(INT32_MAX) / 8)) {
+ return false;
+ }
+ literal_count_ = count * 8;
+ } else {
+ if (ARROW_PREDICT_FALSE(count == 0 || count > static_cast<uint32_t>(INT32_MAX))) {
+ return false;
+ }
+ repeat_count_ = count;
+ T value = {};
+ if (!bit_reader_.GetAligned<T>(static_cast<int>(BitUtil::CeilDiv(bit_width_, 8)),
+ &value)) {
+ return false;
+ }
+ current_value_ = static_cast<uint64_t>(value);
+ }
+ return true;
+}
+
+/// This function buffers input values 8 at a time. After seeing all 8 values,
+/// it decides whether they should be encoded as a literal or repeated run.
+inline bool RleEncoder::Put(uint64_t value) {
+ DCHECK(bit_width_ == 64 || value < (1ULL << bit_width_));
+ if (ARROW_PREDICT_FALSE(buffer_full_)) return false;
+
+ if (ARROW_PREDICT_TRUE(current_value_ == value)) {
+ ++repeat_count_;
+ if (repeat_count_ > 8) {
+ // This is just a continuation of the current run, no need to buffer the
+ // values.
+ // Note that this is the fast path for long repeated runs.
+ return true;
+ }
+ } else {
+ if (repeat_count_ >= 8) {
+ // We had a run that was long enough but it has ended. Flush the
+ // current repeated run.
+ DCHECK_EQ(literal_count_, 0);
+ FlushRepeatedRun();
+ }
+ repeat_count_ = 1;
+ current_value_ = value;
+ }
+
+ buffered_values_[num_buffered_values_] = value;
+ if (++num_buffered_values_ == 8) {
+ DCHECK_EQ(literal_count_ % 8, 0);
+ FlushBufferedValues(false);
+ }
+ return true;
+}
+
+inline void RleEncoder::FlushLiteralRun(bool update_indicator_byte) {
+ if (literal_indicator_byte_ == NULL) {
+ // The literal indicator byte has not been reserved yet, get one now.
+ literal_indicator_byte_ = bit_writer_.GetNextBytePtr();
+ DCHECK(literal_indicator_byte_ != NULL);
+ }
+
+ // Write all the buffered values as bit packed literals
+ for (int i = 0; i < num_buffered_values_; ++i) {
+ bool success = bit_writer_.PutValue(buffered_values_[i], bit_width_);
+ DCHECK(success) << "There is a bug in using CheckBufferFull()";
+ }
+ num_buffered_values_ = 0;
+
+ if (update_indicator_byte) {
+ // At this point we need to write the indicator byte for the literal run.
+ // We only reserve one byte, to allow for streaming writes of literal values.
+ // The logic makes sure we flush literal runs often enough to not overrun
+ // the 1 byte.
+ DCHECK_EQ(literal_count_ % 8, 0);
+ int num_groups = literal_count_ / 8;
+ int32_t indicator_value = (num_groups << 1) | 1;
+ DCHECK_EQ(indicator_value & 0xFFFFFF00, 0);
+ *literal_indicator_byte_ = static_cast<uint8_t>(indicator_value);
+ literal_indicator_byte_ = NULL;
+ literal_count_ = 0;
+ CheckBufferFull();
+ }
+}
+
+inline void RleEncoder::FlushRepeatedRun() {
+ DCHECK_GT(repeat_count_, 0);
+ bool result = true;
+ // The lsb of 0 indicates this is a repeated run
+ int32_t indicator_value = repeat_count_ << 1 | 0;
+ result &= bit_writer_.PutVlqInt(static_cast<uint32_t>(indicator_value));
+ result &= bit_writer_.PutAligned(current_value_,
+ static_cast<int>(BitUtil::CeilDiv(bit_width_, 8)));
+ DCHECK(result);
+ num_buffered_values_ = 0;
+ repeat_count_ = 0;
+ CheckBufferFull();
+}
+
+/// Flush the values that have been buffered. At this point we decide whether
+/// we need to switch between the run types or continue the current one.
+inline void RleEncoder::FlushBufferedValues(bool done) {
+ if (repeat_count_ >= 8) {
+ // Clear the buffered values. They are part of the repeated run now and we
+ // don't want to flush them out as literals.
+ num_buffered_values_ = 0;
+ if (literal_count_ != 0) {
+ // There was a current literal run. All the values in it have been flushed
+ // but we still need to update the indicator byte.
+ DCHECK_EQ(literal_count_ % 8, 0);
+ DCHECK_EQ(repeat_count_, 8);
+ FlushLiteralRun(true);
+ }
+ DCHECK_EQ(literal_count_, 0);
+ return;
+ }
+
+ literal_count_ += num_buffered_values_;
+ DCHECK_EQ(literal_count_ % 8, 0);
+ int num_groups = literal_count_ / 8;
+ if (num_groups + 1 >= (1 << 6)) {
+ // We need to start a new literal run because the indicator byte we've reserved
+ // cannot store more values.
+ DCHECK(literal_indicator_byte_ != NULL);
+ FlushLiteralRun(true);
+ } else {
+ FlushLiteralRun(done);
+ }
+ repeat_count_ = 0;
+}
+
+inline int RleEncoder::Flush() {
+ if (literal_count_ > 0 || repeat_count_ > 0 || num_buffered_values_ > 0) {
+ bool all_repeat = literal_count_ == 0 && (repeat_count_ == num_buffered_values_ ||
+ num_buffered_values_ == 0);
+ // There is something pending, figure out if it's a repeated or literal run
+ if (repeat_count_ > 0 && all_repeat) {
+ FlushRepeatedRun();
+ } else {
+ DCHECK_EQ(literal_count_ % 8, 0);
+ // Buffer the last group of literals to 8 by padding with 0s.
+ for (; num_buffered_values_ != 0 && num_buffered_values_ < 8;
+ ++num_buffered_values_) {
+ buffered_values_[num_buffered_values_] = 0;
+ }
+ literal_count_ += num_buffered_values_;
+ FlushLiteralRun(true);
+ repeat_count_ = 0;
+ }
+ }
+ bit_writer_.Flush();
+ DCHECK_EQ(num_buffered_values_, 0);
+ DCHECK_EQ(literal_count_, 0);
+ DCHECK_EQ(repeat_count_, 0);
+
+ return bit_writer_.bytes_written();
+}
+
+inline void RleEncoder::CheckBufferFull() {
+ int bytes_written = bit_writer_.bytes_written();
+ if (bytes_written + max_run_byte_size_ > bit_writer_.buffer_len()) {
+ buffer_full_ = true;
+ }
+}
+
+inline void RleEncoder::Clear() {
+ buffer_full_ = false;
+ current_value_ = 0;
+ repeat_count_ = 0;
+ num_buffered_values_ = 0;
+ literal_count_ = 0;
+ literal_indicator_byte_ = NULL;
+ bit_writer_.Clear();
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/rle_encoding_test.cc b/src/arrow/cpp/src/arrow/util/rle_encoding_test.cc
new file mode 100644
index 000000000..362f8253c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/rle_encoding_test.cc
@@ -0,0 +1,573 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// From Apache Impala (incubating) as of 2016-01-29
+
+#include <cstdint>
+#include <cstring>
+#include <random>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/testing/random.h"
+#include "arrow/type.h"
+#include "arrow/util/bit_stream_utils.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/rle_encoding.h"
+
+namespace arrow {
+namespace util {
+
+const int MAX_WIDTH = 32;
+
+TEST(BitArray, TestBool) {
+ const int len = 8;
+ uint8_t buffer[len];
+
+ BitUtil::BitWriter writer(buffer, len);
+
+ // Write alternating 0's and 1's
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_TRUE(writer.PutValue(i % 2, 1));
+ }
+ writer.Flush();
+
+ EXPECT_EQ(buffer[0], 0xAA /* 0b10101010 */);
+
+ // Write 00110011
+ for (int i = 0; i < 8; ++i) {
+ bool result = false;
+ switch (i) {
+ case 0:
+ case 1:
+ case 4:
+ case 5:
+ result = writer.PutValue(false, 1);
+ break;
+ default:
+ result = writer.PutValue(true, 1);
+ break;
+ }
+ EXPECT_TRUE(result);
+ }
+ writer.Flush();
+
+ // Validate the exact bit value
+ EXPECT_EQ(buffer[0], 0xAA /* 0b10101010 */);
+ EXPECT_EQ(buffer[1], 0xCC /* 0b11001100 */);
+
+ // Use the reader and validate
+ BitUtil::BitReader reader(buffer, len);
+ for (int i = 0; i < 8; ++i) {
+ bool val = false;
+ bool result = reader.GetValue(1, &val);
+ EXPECT_TRUE(result);
+ EXPECT_EQ(val, (i % 2) != 0);
+ }
+
+ for (int i = 0; i < 8; ++i) {
+ bool val = false;
+ bool result = reader.GetValue(1, &val);
+ EXPECT_TRUE(result);
+ switch (i) {
+ case 0:
+ case 1:
+ case 4:
+ case 5:
+ EXPECT_EQ(val, false);
+ break;
+ default:
+ EXPECT_EQ(val, true);
+ break;
+ }
+ }
+}
+
+// Writes 'num_vals' values with width 'bit_width' and reads them back.
+void TestBitArrayValues(int bit_width, int num_vals) {
+ int len = static_cast<int>(BitUtil::BytesForBits(bit_width * num_vals));
+ EXPECT_GT(len, 0);
+ const uint64_t mod = bit_width == 64 ? 1 : 1LL << bit_width;
+
+ std::vector<uint8_t> buffer(len);
+ BitUtil::BitWriter writer(buffer.data(), len);
+ for (int i = 0; i < num_vals; ++i) {
+ bool result = writer.PutValue(i % mod, bit_width);
+ EXPECT_TRUE(result);
+ }
+ writer.Flush();
+ EXPECT_EQ(writer.bytes_written(), len);
+
+ BitUtil::BitReader reader(buffer.data(), len);
+ for (int i = 0; i < num_vals; ++i) {
+ int64_t val = 0;
+ bool result = reader.GetValue(bit_width, &val);
+ EXPECT_TRUE(result);
+ EXPECT_EQ(val, i % mod);
+ }
+ EXPECT_EQ(reader.bytes_left(), 0);
+}
+
+TEST(BitArray, TestValues) {
+ for (int width = 1; width <= MAX_WIDTH; ++width) {
+ TestBitArrayValues(width, 1);
+ TestBitArrayValues(width, 2);
+ // Don't write too many values
+ TestBitArrayValues(width, (width < 12) ? (1 << width) : 4096);
+ TestBitArrayValues(width, 1024);
+ }
+}
+
+// Test some mixed values
+TEST(BitArray, TestMixed) {
+ const int len = 1024;
+ uint8_t buffer[len];
+ bool parity = true;
+
+ BitUtil::BitWriter writer(buffer, len);
+ for (int i = 0; i < len; ++i) {
+ bool result;
+ if (i % 2 == 0) {
+ result = writer.PutValue(parity, 1);
+ parity = !parity;
+ } else {
+ result = writer.PutValue(i, 10);
+ }
+ EXPECT_TRUE(result);
+ }
+ writer.Flush();
+
+ parity = true;
+ BitUtil::BitReader reader(buffer, len);
+ for (int i = 0; i < len; ++i) {
+ bool result;
+ if (i % 2 == 0) {
+ bool val;
+ result = reader.GetValue(1, &val);
+ EXPECT_EQ(val, parity);
+ parity = !parity;
+ } else {
+ int val;
+ result = reader.GetValue(10, &val);
+ EXPECT_EQ(val, i);
+ }
+ EXPECT_TRUE(result);
+ }
+}
+
+// Validates encoding of values by encoding and decoding them. If
+// expected_encoding != NULL, also validates that the encoded buffer is
+// exactly 'expected_encoding'.
+// if expected_len is not -1, it will validate the encoded size is correct.
+void ValidateRle(const std::vector<int>& values, int bit_width,
+ uint8_t* expected_encoding, int expected_len) {
+ const int len = 64 * 1024;
+ uint8_t buffer[len];
+ EXPECT_LE(expected_len, len);
+
+ RleEncoder encoder(buffer, len, bit_width);
+ for (size_t i = 0; i < values.size(); ++i) {
+ bool result = encoder.Put(values[i]);
+ EXPECT_TRUE(result);
+ }
+ int encoded_len = encoder.Flush();
+
+ if (expected_len != -1) {
+ EXPECT_EQ(encoded_len, expected_len);
+ }
+ if (expected_encoding != NULL) {
+ EXPECT_EQ(memcmp(buffer, expected_encoding, encoded_len), 0);
+ }
+
+ // Verify read
+ {
+ RleDecoder decoder(buffer, len, bit_width);
+ for (size_t i = 0; i < values.size(); ++i) {
+ uint64_t val;
+ bool result = decoder.Get(&val);
+ EXPECT_TRUE(result);
+ EXPECT_EQ(values[i], val);
+ }
+ }
+
+ // Verify batch read
+ {
+ RleDecoder decoder(buffer, len, bit_width);
+ std::vector<int> values_read(values.size());
+ ASSERT_EQ(values.size(),
+ decoder.GetBatch(values_read.data(), static_cast<int>(values.size())));
+ EXPECT_EQ(values, values_read);
+ }
+}
+
+// A version of ValidateRle that round-trips the values and returns false if
+// the returned values are not all the same
+bool CheckRoundTrip(const std::vector<int>& values, int bit_width) {
+ const int len = 64 * 1024;
+ uint8_t buffer[len];
+ RleEncoder encoder(buffer, len, bit_width);
+ for (size_t i = 0; i < values.size(); ++i) {
+ bool result = encoder.Put(values[i]);
+ if (!result) {
+ return false;
+ }
+ }
+ int encoded_len = encoder.Flush();
+ int out = 0;
+
+ {
+ RleDecoder decoder(buffer, encoded_len, bit_width);
+ for (size_t i = 0; i < values.size(); ++i) {
+ EXPECT_TRUE(decoder.Get(&out));
+ if (values[i] != out) {
+ return false;
+ }
+ }
+ }
+
+ // Verify batch read
+ {
+ RleDecoder decoder(buffer, encoded_len, bit_width);
+ std::vector<int> values_read(values.size());
+ if (static_cast<int>(values.size()) !=
+ decoder.GetBatch(values_read.data(), static_cast<int>(values.size()))) {
+ return false;
+ }
+
+ if (values != values_read) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+TEST(Rle, SpecificSequences) {
+ const int len = 1024;
+ uint8_t expected_buffer[len];
+ std::vector<int> values;
+
+ // Test 50 0' followed by 50 1's
+ values.resize(100);
+ for (int i = 0; i < 50; ++i) {
+ values[i] = 0;
+ }
+ for (int i = 50; i < 100; ++i) {
+ values[i] = 1;
+ }
+
+ // expected_buffer valid for bit width <= 1 byte
+ expected_buffer[0] = (50 << 1);
+ expected_buffer[1] = 0;
+ expected_buffer[2] = (50 << 1);
+ expected_buffer[3] = 1;
+ for (int width = 1; width <= 8; ++width) {
+ ValidateRle(values, width, expected_buffer, 4);
+ }
+
+ for (int width = 9; width <= MAX_WIDTH; ++width) {
+ ValidateRle(values, width, nullptr,
+ 2 * (1 + static_cast<int>(BitUtil::CeilDiv(width, 8))));
+ }
+
+ // Test 100 0's and 1's alternating
+ for (int i = 0; i < 100; ++i) {
+ values[i] = i % 2;
+ }
+ int num_groups = static_cast<int>(BitUtil::CeilDiv(100, 8));
+ expected_buffer[0] = static_cast<uint8_t>((num_groups << 1) | 1);
+ for (int i = 1; i <= 100 / 8; ++i) {
+ expected_buffer[i] = 0xAA /* 0b10101010 */;
+ }
+ // Values for the last 4 0 and 1's. The upper 4 bits should be padded to 0.
+ expected_buffer[100 / 8 + 1] = 0x0A /* 0b00001010 */;
+
+ // num_groups and expected_buffer only valid for bit width = 1
+ ValidateRle(values, 1, expected_buffer, 1 + num_groups);
+ for (int width = 2; width <= MAX_WIDTH; ++width) {
+ int num_values = static_cast<int>(BitUtil::CeilDiv(100, 8)) * 8;
+ ValidateRle(values, width, nullptr,
+ 1 + static_cast<int>(BitUtil::CeilDiv(width * num_values, 8)));
+ }
+
+ // Test 16-bit values to confirm encoded values are stored in little endian
+ values.resize(28);
+ for (int i = 0; i < 16; ++i) {
+ values[i] = 0x55aa;
+ }
+ for (int i = 16; i < 28; ++i) {
+ values[i] = 0xaa55;
+ }
+ expected_buffer[0] = (16 << 1);
+ expected_buffer[1] = 0xaa;
+ expected_buffer[2] = 0x55;
+ expected_buffer[3] = (12 << 1);
+ expected_buffer[4] = 0x55;
+ expected_buffer[5] = 0xaa;
+
+ ValidateRle(values, 16, expected_buffer, 6);
+
+ // Test 32-bit values to confirm encoded values are stored in little endian
+ values.resize(28);
+ for (int i = 0; i < 16; ++i) {
+ values[i] = 0x555aaaa5;
+ }
+ for (int i = 16; i < 28; ++i) {
+ values[i] = 0x5aaaa555;
+ }
+ expected_buffer[0] = (16 << 1);
+ expected_buffer[1] = 0xa5;
+ expected_buffer[2] = 0xaa;
+ expected_buffer[3] = 0x5a;
+ expected_buffer[4] = 0x55;
+ expected_buffer[5] = (12 << 1);
+ expected_buffer[6] = 0x55;
+ expected_buffer[7] = 0xa5;
+ expected_buffer[8] = 0xaa;
+ expected_buffer[9] = 0x5a;
+
+ ValidateRle(values, 32, expected_buffer, 10);
+}
+
+// ValidateRle on 'num_vals' values with width 'bit_width'. If 'value' != -1, that value
+// is used, otherwise alternating values are used.
+void TestRleValues(int bit_width, int num_vals, int value = -1) {
+ const uint64_t mod = (bit_width == 64) ? 1 : 1LL << bit_width;
+ std::vector<int> values;
+ for (int v = 0; v < num_vals; ++v) {
+ values.push_back((value != -1) ? value : static_cast<int>(v % mod));
+ }
+ ValidateRle(values, bit_width, NULL, -1);
+}
+
+TEST(Rle, TestValues) {
+ for (int width = 1; width <= MAX_WIDTH; ++width) {
+ TestRleValues(width, 1);
+ TestRleValues(width, 1024);
+ TestRleValues(width, 1024, 0);
+ TestRleValues(width, 1024, 1);
+ }
+}
+
+TEST(Rle, BitWidthZeroRepeated) {
+ uint8_t buffer[1];
+ const int num_values = 15;
+ buffer[0] = num_values << 1; // repeated indicator byte
+ RleDecoder decoder(buffer, sizeof(buffer), 0);
+ uint8_t val;
+ for (int i = 0; i < num_values; ++i) {
+ bool result = decoder.Get(&val);
+ EXPECT_TRUE(result);
+ EXPECT_EQ(val, 0); // can only encode 0s with bit width 0
+ }
+ EXPECT_FALSE(decoder.Get(&val));
+}
+
+TEST(Rle, BitWidthZeroLiteral) {
+ uint8_t buffer[1];
+ const int num_groups = 4;
+ buffer[0] = num_groups << 1 | 1; // literal indicator byte
+ RleDecoder decoder = RleDecoder(buffer, sizeof(buffer), 0);
+ const int num_values = num_groups * 8;
+ uint8_t val;
+ for (int i = 0; i < num_values; ++i) {
+ bool result = decoder.Get(&val);
+ EXPECT_TRUE(result);
+ EXPECT_EQ(val, 0); // can only encode 0s with bit width 0
+ }
+ EXPECT_FALSE(decoder.Get(&val));
+}
+
+// Test that writes out a repeated group and then a literal
+// group but flush before finishing.
+TEST(BitRle, Flush) {
+ std::vector<int> values;
+ for (int i = 0; i < 16; ++i) values.push_back(1);
+ values.push_back(0);
+ ValidateRle(values, 1, NULL, -1);
+ values.push_back(1);
+ ValidateRle(values, 1, NULL, -1);
+ values.push_back(1);
+ ValidateRle(values, 1, NULL, -1);
+ values.push_back(1);
+ ValidateRle(values, 1, NULL, -1);
+}
+
+// Test some random sequences.
+TEST(BitRle, Random) {
+ int niters = 50;
+ int ngroups = 1000;
+ int max_group_size = 16;
+ std::vector<int> values(ngroups + max_group_size);
+
+ // prng setup
+ const auto seed = ::arrow::internal::GetRandomSeed();
+ std::default_random_engine gen(
+ static_cast<std::default_random_engine::result_type>(seed));
+ std::uniform_int_distribution<int> dist(1, 20);
+
+ for (int iter = 0; iter < niters; ++iter) {
+ // generate a seed with device entropy
+ bool parity = 0;
+ values.resize(0);
+
+ for (int i = 0; i < ngroups; ++i) {
+ int group_size = dist(gen);
+ if (group_size > max_group_size) {
+ group_size = 1;
+ }
+ for (int i = 0; i < group_size; ++i) {
+ values.push_back(parity);
+ }
+ parity = !parity;
+ }
+ if (!CheckRoundTrip(values, BitUtil::NumRequiredBits(values.size()))) {
+ FAIL() << "failing seed: " << seed;
+ }
+ }
+}
+
+// Test a sequence of 1 0's, 2 1's, 3 0's. etc
+// e.g. 011000111100000
+TEST(BitRle, RepeatedPattern) {
+ std::vector<int> values;
+ const int min_run = 1;
+ const int max_run = 32;
+
+ for (int i = min_run; i <= max_run; ++i) {
+ int v = i % 2;
+ for (int j = 0; j < i; ++j) {
+ values.push_back(v);
+ }
+ }
+
+ // And go back down again
+ for (int i = max_run; i >= min_run; --i) {
+ int v = i % 2;
+ for (int j = 0; j < i; ++j) {
+ values.push_back(v);
+ }
+ }
+
+ ValidateRle(values, 1, NULL, -1);
+}
+
+TEST(BitRle, Overflow) {
+ for (int bit_width = 1; bit_width < 32; bit_width += 3) {
+ int len = RleEncoder::MinBufferSize(bit_width);
+ std::vector<uint8_t> buffer(len);
+ int num_added = 0;
+ bool parity = true;
+
+ RleEncoder encoder(buffer.data(), len, bit_width);
+ // Insert alternating true/false until there is no space left
+ while (true) {
+ bool result = encoder.Put(parity);
+ parity = !parity;
+ if (!result) break;
+ ++num_added;
+ }
+
+ int bytes_written = encoder.Flush();
+ EXPECT_LE(bytes_written, len);
+ EXPECT_GT(num_added, 0);
+
+ RleDecoder decoder(buffer.data(), bytes_written, bit_width);
+ parity = true;
+ uint32_t v;
+ for (int i = 0; i < num_added; ++i) {
+ bool result = decoder.Get(&v);
+ EXPECT_TRUE(result);
+ EXPECT_EQ(v != 0, parity);
+ parity = !parity;
+ }
+ // Make sure we get false when reading past end a couple times.
+ EXPECT_FALSE(decoder.Get(&v));
+ EXPECT_FALSE(decoder.Get(&v));
+ }
+}
+
+template <typename Type>
+void CheckRoundTripSpaced(const Array& data, int bit_width) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ using T = typename Type::c_type;
+
+ int num_values = static_cast<int>(data.length());
+ int buffer_size = RleEncoder::MaxBufferSize(bit_width, num_values);
+
+ const T* values = static_cast<const ArrayType&>(data).raw_values();
+
+ std::vector<uint8_t> buffer(buffer_size);
+ RleEncoder encoder(buffer.data(), buffer_size, bit_width);
+ for (int i = 0; i < num_values; ++i) {
+ if (data.IsValid(i)) {
+ if (!encoder.Put(static_cast<uint64_t>(values[i]))) {
+ FAIL() << "Encoding failed";
+ }
+ }
+ }
+ int encoded_size = encoder.Flush();
+
+ // Verify batch read
+ RleDecoder decoder(buffer.data(), encoded_size, bit_width);
+ std::vector<T> values_read(num_values);
+
+ if (num_values != decoder.GetBatchSpaced(
+ num_values, static_cast<int>(data.null_count()),
+ data.null_bitmap_data(), data.offset(), values_read.data())) {
+ FAIL();
+ }
+
+ for (int64_t i = 0; i < num_values; ++i) {
+ if (data.IsValid(i)) {
+ if (values_read[i] != values[i]) {
+ FAIL() << "Index " << i << " read " << values_read[i] << " but should be "
+ << values[i];
+ }
+ }
+ }
+}
+
+template <typename T>
+struct GetBatchSpacedTestCase {
+ T max_value;
+ int64_t size;
+ double null_probability;
+ int bit_width;
+};
+
+TEST(RleDecoder, GetBatchSpaced) {
+ uint32_t kSeed = 1337;
+ ::arrow::random::RandomArrayGenerator rand(kSeed);
+
+ std::vector<GetBatchSpacedTestCase<int32_t>> int32_cases{
+ {1, 100000, 0.01, 1}, {1, 100000, 0.1, 1}, {1, 100000, 0.5, 1},
+ {4, 100000, 0.05, 3}, {100, 100000, 0.05, 7},
+ };
+ for (auto case_ : int32_cases) {
+ auto arr = rand.Int32(case_.size, /*min=*/0, case_.max_value, case_.null_probability);
+ CheckRoundTripSpaced<Int32Type>(*arr, case_.bit_width);
+ CheckRoundTripSpaced<Int32Type>(*arr->Slice(1), case_.bit_width);
+ }
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/simd.h b/src/arrow/cpp/src/arrow/util/simd.h
new file mode 100644
index 000000000..259641dd4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/simd.h
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#ifdef _MSC_VER
+// MSVC x86_64/arm64
+
+#if defined(_M_AMD64) || defined(_M_X64)
+#include <intrin.h>
+#elif defined(_M_ARM64)
+#include <arm64_neon.h>
+#endif
+
+#else
+// gcc/clang (possibly others)
+
+#if defined(ARROW_HAVE_BMI2)
+#include <x86intrin.h>
+#endif
+
+#if defined(ARROW_HAVE_AVX2) || defined(ARROW_HAVE_AVX512)
+#include <immintrin.h>
+#elif defined(ARROW_HAVE_SSE4_2)
+#include <nmmintrin.h>
+#endif
+
+#ifdef ARROW_HAVE_NEON
+#include <arm_neon.h>
+#endif
+
+#ifdef ARROW_HAVE_ARMV8_CRC
+#include <arm_acle.h>
+#endif
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/util/small_vector.h b/src/arrow/cpp/src/arrow/util/small_vector.h
new file mode 100644
index 000000000..071295282
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/small_vector.h
@@ -0,0 +1,519 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <initializer_list>
+#include <iterator>
+#include <limits>
+#include <new>
+#include <type_traits>
+#include <utility>
+
+#include "arrow/util/aligned_storage.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace internal {
+
+template <typename T, size_t N, bool NonTrivialDestructor>
+struct StaticVectorStorageBase {
+ using storage_type = AlignedStorage<T>;
+
+ storage_type static_data_[N];
+ size_t size_ = 0;
+
+ void destroy() noexcept {}
+};
+
+template <typename T, size_t N>
+struct StaticVectorStorageBase<T, N, true> {
+ using storage_type = AlignedStorage<T>;
+
+ storage_type static_data_[N];
+ size_t size_ = 0;
+
+ ~StaticVectorStorageBase() noexcept { destroy(); }
+
+ void destroy() noexcept { storage_type::destroy_several(static_data_, size_); }
+};
+
+template <typename T, size_t N, bool D = !std::is_trivially_destructible<T>::value>
+struct StaticVectorStorage : public StaticVectorStorageBase<T, N, D> {
+ using Base = StaticVectorStorageBase<T, N, D>;
+ using typename Base::storage_type;
+
+ using Base::size_;
+ using Base::static_data_;
+
+ StaticVectorStorage() noexcept = default;
+
+#if __cpp_constexpr >= 201304L // non-const constexpr
+ constexpr storage_type* storage_ptr() { return static_data_; }
+#else
+ storage_type* storage_ptr() { return static_data_; }
+#endif
+
+ constexpr const storage_type* const_storage_ptr() const { return static_data_; }
+
+ // Adjust storage size, but don't initialize any objects
+ void bump_size(size_t addend) {
+ assert(size_ + addend <= N);
+ size_ += addend;
+ }
+
+ void ensure_capacity(size_t min_capacity) { assert(min_capacity <= N); }
+
+ // Adjust storage size, but don't destroy any objects
+ void reduce_size(size_t reduce_by) {
+ assert(reduce_by <= size_);
+ size_ -= reduce_by;
+ }
+
+ // Move objects from another storage, but don't destroy any objects currently
+ // stored in *this.
+ // You need to call destroy() first if necessary (e.g. in a
+ // move assignment operator).
+ void move_construct(StaticVectorStorage&& other) noexcept {
+ size_ = other.size_;
+ if (size_ != 0) {
+ // Use a compile-time memcpy size (N) for trivial types
+ storage_type::move_construct_several(other.static_data_, static_data_, size_, N);
+ }
+ }
+
+ constexpr size_t capacity() const { return N; }
+
+ constexpr size_t max_size() const { return N; }
+
+ void reserve(size_t n) {}
+
+ void clear() {
+ storage_type::destroy_several(static_data_, size_);
+ size_ = 0;
+ }
+};
+
+template <typename T, size_t N>
+struct SmallVectorStorage {
+ using storage_type = AlignedStorage<T>;
+
+ storage_type static_data_[N];
+ size_t size_ = 0;
+ storage_type* data_ = static_data_;
+ size_t dynamic_capacity_ = 0;
+
+ SmallVectorStorage() noexcept = default;
+
+ ~SmallVectorStorage() { destroy(); }
+
+#if __cpp_constexpr >= 201304L // non-const constexpr
+ constexpr storage_type* storage_ptr() { return data_; }
+#else
+ storage_type* storage_ptr() { return data_; }
+#endif
+
+ constexpr const storage_type* const_storage_ptr() const { return data_; }
+
+ void bump_size(size_t addend) {
+ const size_t new_size = size_ + addend;
+ ensure_capacity(new_size);
+ size_ = new_size;
+ }
+
+ void ensure_capacity(size_t min_capacity) {
+ if (dynamic_capacity_) {
+ // Grow dynamic storage if necessary
+ if (min_capacity > dynamic_capacity_) {
+ size_t new_capacity = std::max(dynamic_capacity_ * 2, min_capacity);
+ reallocate_dynamic(new_capacity);
+ }
+ } else if (min_capacity > N) {
+ switch_to_dynamic(min_capacity);
+ }
+ }
+
+ void reduce_size(size_t reduce_by) {
+ assert(reduce_by <= size_);
+ size_ -= reduce_by;
+ }
+
+ void destroy() noexcept {
+ storage_type::destroy_several(data_, size_);
+ if (dynamic_capacity_) {
+ delete[] data_;
+ }
+ }
+
+ void move_construct(SmallVectorStorage&& other) noexcept {
+ size_ = other.size_;
+ dynamic_capacity_ = other.dynamic_capacity_;
+ if (dynamic_capacity_) {
+ data_ = other.data_;
+ other.data_ = other.static_data_;
+ other.dynamic_capacity_ = 0;
+ other.size_ = 0;
+ } else if (size_ != 0) {
+ // Use a compile-time memcpy size (N) for trivial types
+ storage_type::move_construct_several(other.static_data_, static_data_, size_, N);
+ }
+ }
+
+ constexpr size_t capacity() const { return dynamic_capacity_ ? dynamic_capacity_ : N; }
+
+ constexpr size_t max_size() const { return std::numeric_limits<size_t>::max(); }
+
+ void reserve(size_t n) {
+ if (dynamic_capacity_) {
+ if (n > dynamic_capacity_) {
+ reallocate_dynamic(n);
+ }
+ } else if (n > N) {
+ switch_to_dynamic(n);
+ }
+ }
+
+ void clear() {
+ storage_type::destroy_several(data_, size_);
+ size_ = 0;
+ }
+
+ private:
+ void switch_to_dynamic(size_t new_capacity) {
+ dynamic_capacity_ = new_capacity;
+ data_ = new storage_type[new_capacity];
+ storage_type::move_construct_several_and_destroy_source(static_data_, data_, size_);
+ }
+
+ void reallocate_dynamic(size_t new_capacity) {
+ assert(new_capacity >= size_);
+ auto new_data = new storage_type[new_capacity];
+ storage_type::move_construct_several_and_destroy_source(data_, new_data, size_);
+ delete[] data_;
+ dynamic_capacity_ = new_capacity;
+ data_ = new_data;
+ }
+};
+
+template <typename T, size_t N, typename Storage>
+class StaticVectorImpl {
+ private:
+ Storage storage_;
+
+ T* data_ptr() { return storage_.storage_ptr()->get(); }
+
+ constexpr const T* const_data_ptr() const {
+ return storage_.const_storage_ptr()->get();
+ }
+
+ public:
+ using size_type = size_t;
+ using difference_type = ptrdiff_t;
+ using value_type = T;
+ using pointer = T*;
+ using const_pointer = const T*;
+ using reference = T&;
+ using const_reference = const T&;
+ using iterator = T*;
+ using const_iterator = const T*;
+ using reverse_iterator = std::reverse_iterator<iterator>;
+ using const_reverse_iterator = std::reverse_iterator<const_iterator>;
+
+ constexpr StaticVectorImpl() noexcept = default;
+
+ // Move and copy constructors
+ StaticVectorImpl(StaticVectorImpl&& other) noexcept {
+ storage_.move_construct(std::move(other.storage_));
+ }
+
+ StaticVectorImpl& operator=(StaticVectorImpl&& other) noexcept {
+ if (ARROW_PREDICT_TRUE(&other != this)) {
+ // TODO move_assign?
+ storage_.destroy();
+ storage_.move_construct(std::move(other.storage_));
+ }
+ return *this;
+ }
+
+ StaticVectorImpl(const StaticVectorImpl& other) {
+ init_by_copying(other.storage_.size_, other.const_data_ptr());
+ }
+
+ StaticVectorImpl& operator=(const StaticVectorImpl& other) noexcept {
+ if (ARROW_PREDICT_TRUE(&other != this)) {
+ assign_by_copying(other.storage_.size_, other.data());
+ }
+ return *this;
+ }
+
+ // Automatic conversion from std::vector<T>, for convenience
+ StaticVectorImpl(const std::vector<T>& other) { // NOLINT: explicit
+ init_by_copying(other.size(), other.data());
+ }
+
+ StaticVectorImpl(std::vector<T>&& other) noexcept { // NOLINT: explicit
+ init_by_moving(other.size(), other.data());
+ }
+
+ StaticVectorImpl& operator=(const std::vector<T>& other) {
+ assign_by_copying(other.size(), other.data());
+ return *this;
+ }
+
+ StaticVectorImpl& operator=(std::vector<T>&& other) noexcept {
+ assign_by_moving(other.size(), other.data());
+ return *this;
+ }
+
+ // Constructing from count and optional initialization value
+ explicit StaticVectorImpl(size_t count) {
+ storage_.bump_size(count);
+ auto* p = storage_.storage_ptr();
+ for (size_t i = 0; i < count; ++i) {
+ p[i].construct();
+ }
+ }
+
+ StaticVectorImpl(size_t count, const T& value) {
+ storage_.bump_size(count);
+ auto* p = storage_.storage_ptr();
+ for (size_t i = 0; i < count; ++i) {
+ p[i].construct(value);
+ }
+ }
+
+ StaticVectorImpl(std::initializer_list<T> values) {
+ storage_.bump_size(values.size());
+ auto* p = storage_.storage_ptr();
+ for (auto&& v : values) {
+ // Unfortunately, cannot move initializer values
+ p++->construct(v);
+ }
+ }
+
+ // Size inspection
+
+ constexpr bool empty() const { return storage_.size_ == 0; }
+
+ constexpr size_t size() const { return storage_.size_; }
+
+ constexpr size_t capacity() const { return storage_.capacity(); }
+
+ constexpr size_t max_size() const { return storage_.max_size(); }
+
+ // Data access
+
+ T& operator[](size_t i) { return data_ptr()[i]; }
+
+ constexpr const T& operator[](size_t i) const { return const_data_ptr()[i]; }
+
+ T& front() { return data_ptr()[0]; }
+
+ constexpr const T& front() const { return const_data_ptr()[0]; }
+
+ T& back() { return data_ptr()[storage_.size_ - 1]; }
+
+ constexpr const T& back() const { return const_data_ptr()[storage_.size_ - 1]; }
+
+ T* data() { return data_ptr(); }
+
+ constexpr const T* data() const { return const_data_ptr(); }
+
+ // Iterators
+
+ iterator begin() { return iterator(data_ptr()); }
+
+ constexpr const_iterator begin() const { return const_iterator(const_data_ptr()); }
+
+ constexpr const_iterator cbegin() const { return const_iterator(const_data_ptr()); }
+
+ iterator end() { return iterator(data_ptr() + storage_.size_); }
+
+ constexpr const_iterator end() const {
+ return const_iterator(const_data_ptr() + storage_.size_);
+ }
+
+ constexpr const_iterator cend() const {
+ return const_iterator(const_data_ptr() + storage_.size_);
+ }
+
+ reverse_iterator rbegin() { return reverse_iterator(end()); }
+
+ constexpr const_reverse_iterator rbegin() const {
+ return const_reverse_iterator(end());
+ }
+
+ constexpr const_reverse_iterator crbegin() const {
+ return const_reverse_iterator(end());
+ }
+
+ reverse_iterator rend() { return reverse_iterator(begin()); }
+
+ constexpr const_reverse_iterator rend() const {
+ return const_reverse_iterator(begin());
+ }
+
+ constexpr const_reverse_iterator crend() const {
+ return const_reverse_iterator(begin());
+ }
+
+ // Mutations
+
+ void reserve(size_t n) { storage_.reserve(n); }
+
+ void clear() { storage_.clear(); }
+
+ void push_back(const T& value) {
+ storage_.bump_size(1);
+ storage_.storage_ptr()[storage_.size_ - 1].construct(value);
+ }
+
+ void push_back(T&& value) {
+ storage_.bump_size(1);
+ storage_.storage_ptr()[storage_.size_ - 1].construct(std::move(value));
+ }
+
+ template <typename... Args>
+ void emplace_back(Args&&... args) {
+ storage_.bump_size(1);
+ storage_.storage_ptr()[storage_.size_ - 1].construct(std::forward<Args>(args)...);
+ }
+
+ template <typename InputIt>
+ iterator insert(const_iterator insert_at, InputIt first, InputIt last) {
+ const size_t n = storage_.size_;
+ const size_t it_size = static_cast<size_t>(last - first); // XXX might be O(n)?
+ const size_t pos = static_cast<size_t>(insert_at - const_data_ptr());
+ storage_.bump_size(it_size);
+ auto* p = storage_.storage_ptr();
+ if (it_size == 0) {
+ return p[pos].get();
+ }
+ const size_t end_pos = pos + it_size;
+
+ // Move [pos; n) to [end_pos; end_pos + n - pos)
+ size_t i = n;
+ size_t j = end_pos + n - pos;
+ while (j > std::max(n, end_pos)) {
+ p[--j].move_construct(&p[--i]);
+ }
+ while (j > end_pos) {
+ p[--j].move_assign(&p[--i]);
+ }
+ assert(j == end_pos);
+ // Copy [first; last) to [pos; end_pos)
+ j = pos;
+ while (j < std::min(n, end_pos)) {
+ p[j++].assign(*first++);
+ }
+ while (j < end_pos) {
+ p[j++].construct(*first++);
+ }
+ assert(first == last);
+ return p[pos].get();
+ }
+
+ void resize(size_t n) {
+ const size_t old_size = storage_.size_;
+ if (n > storage_.size_) {
+ storage_.bump_size(n - old_size);
+ auto* p = storage_.storage_ptr();
+ for (size_t i = old_size; i < n; ++i) {
+ p[i].construct(T{});
+ }
+ } else {
+ auto* p = storage_.storage_ptr();
+ for (size_t i = n; i < old_size; ++i) {
+ p[i].destroy();
+ }
+ storage_.reduce_size(old_size - n);
+ }
+ }
+
+ void resize(size_t n, const T& value) {
+ const size_t old_size = storage_.size_;
+ if (n > storage_.size_) {
+ storage_.bump_size(n - old_size);
+ auto* p = storage_.storage_ptr();
+ for (size_t i = old_size; i < n; ++i) {
+ p[i].construct(value);
+ }
+ } else {
+ auto* p = storage_.storage_ptr();
+ for (size_t i = n; i < old_size; ++i) {
+ p[i].destroy();
+ }
+ storage_.reduce_size(old_size - n);
+ }
+ }
+
+ private:
+ template <typename InputIt>
+ void init_by_copying(size_t n, InputIt src) {
+ storage_.bump_size(n);
+ auto* dest = storage_.storage_ptr();
+ for (size_t i = 0; i < n; ++i, ++src) {
+ dest[i].construct(*src);
+ }
+ }
+
+ template <typename InputIt>
+ void init_by_moving(size_t n, InputIt src) {
+ init_by_copying(n, std::make_move_iterator(src));
+ }
+
+ template <typename InputIt>
+ void assign_by_copying(size_t n, InputIt src) {
+ const size_t old_size = storage_.size_;
+ if (n > old_size) {
+ storage_.bump_size(n - old_size);
+ auto* dest = storage_.storage_ptr();
+ for (size_t i = 0; i < old_size; ++i, ++src) {
+ dest[i].assign(*src);
+ }
+ for (size_t i = old_size; i < n; ++i, ++src) {
+ dest[i].construct(*src);
+ }
+ } else {
+ auto* dest = storage_.storage_ptr();
+ for (size_t i = 0; i < n; ++i, ++src) {
+ dest[i].assign(*src);
+ }
+ for (size_t i = n; i < old_size; ++i) {
+ dest[i].destroy();
+ }
+ storage_.reduce_size(old_size - n);
+ }
+ }
+
+ template <typename InputIt>
+ void assign_by_moving(size_t n, InputIt src) {
+ assign_by_copying(n, std::make_move_iterator(src));
+ }
+};
+
+template <typename T, size_t N>
+using StaticVector = StaticVectorImpl<T, N, StaticVectorStorage<T, N>>;
+
+template <typename T, size_t N>
+using SmallVector = StaticVectorImpl<T, N, SmallVectorStorage<T, N>>;
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/small_vector_benchmark.cc b/src/arrow/cpp/src/arrow/util/small_vector_benchmark.cc
new file mode 100644
index 000000000..96f94c369
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/small_vector_benchmark.cc
@@ -0,0 +1,344 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <iterator>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <vector>
+
+#include <benchmark/benchmark.h>
+
+#include "arrow/testing/util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/small_vector.h"
+
+namespace arrow {
+namespace internal {
+
+template <typename T>
+T ValueInitializer();
+template <typename T>
+T ValueInitializer(int seed);
+
+template <>
+int ValueInitializer<int>() {
+ return 42;
+}
+template <>
+int ValueInitializer<int>(int seed) {
+ return 42;
+}
+
+template <>
+std::string ValueInitializer<std::string>() {
+ return "42";
+}
+template <>
+std::string ValueInitializer<std::string>(int seed) {
+ return std::string("x", seed & 0x3f); // avoid making string too long
+}
+
+template <>
+std::shared_ptr<int> ValueInitializer<std::shared_ptr<int>>() {
+ return std::make_shared<int>(42);
+}
+template <>
+std::shared_ptr<int> ValueInitializer<std::shared_ptr<int>>(int seed) {
+ return std::make_shared<int>(seed);
+}
+
+template <typename Vector>
+ARROW_NOINLINE int64_t ConsumeVector(Vector v) {
+ return reinterpret_cast<intptr_t>(v.data());
+}
+
+template <typename Vector>
+ARROW_NOINLINE int64_t IngestVector(const Vector& v) {
+ return reinterpret_cast<intptr_t>(v.data());
+}
+
+// With ARROW_NOINLINE, try to make sure the number of items is not constant-propagated
+template <typename Vector>
+ARROW_NOINLINE void BenchmarkMoveVector(benchmark::State& state, Vector vec) {
+ constexpr int kNumIters = 1000;
+
+ for (auto _ : state) {
+ int64_t dummy = 0;
+ for (int i = 0; i < kNumIters; ++i) {
+ Vector tmp(std::move(vec));
+ dummy += IngestVector(tmp);
+ vec = std::move(tmp);
+ }
+ benchmark::DoNotOptimize(dummy);
+ }
+
+ state.SetItemsProcessed(state.iterations() * kNumIters * 2);
+}
+
+template <typename Vector>
+void MoveEmptyVector(benchmark::State& state) {
+ BenchmarkMoveVector(state, Vector{});
+}
+
+template <typename Vector>
+void MoveShortVector(benchmark::State& state) {
+ using T = typename Vector::value_type;
+ constexpr int kSize = 3;
+ const auto initializer = ValueInitializer<T>();
+
+ BenchmarkMoveVector(state, Vector(kSize, initializer));
+}
+
+template <typename Vector>
+void CopyEmptyVector(benchmark::State& state) {
+ constexpr int kNumIters = 1000;
+
+ const Vector vec{};
+
+ for (auto _ : state) {
+ int64_t dummy = 0;
+ for (int i = 0; i < kNumIters; ++i) {
+ dummy += ConsumeVector(vec);
+ }
+ benchmark::DoNotOptimize(dummy);
+ }
+
+ state.SetItemsProcessed(state.iterations() * kNumIters);
+}
+
+template <typename Vector>
+void CopyShortVector(benchmark::State& state) {
+ constexpr int kSize = 3;
+ constexpr int kNumIters = 1000;
+
+ const Vector vec(kSize);
+
+ for (auto _ : state) {
+ int64_t dummy = 0;
+ for (int i = 0; i < kNumIters; ++i) {
+ dummy += ConsumeVector(vec);
+ }
+ benchmark::DoNotOptimize(dummy);
+ }
+
+ state.SetItemsProcessed(state.iterations() * kNumIters);
+}
+
+// With ARROW_NOINLINE, try to make sure the number of items is not constant-propagated
+template <typename Vector>
+ARROW_NOINLINE void BenchmarkConstructFromStdVector(benchmark::State& state,
+ const int nitems) {
+ using T = typename Vector::value_type;
+ constexpr int kNumIters = 1000;
+ const std::vector<T> src(nitems, ValueInitializer<T>());
+
+ for (auto _ : state) {
+ int64_t dummy = 0;
+ for (int i = 0; i < kNumIters; ++i) {
+ Vector vec(src);
+ dummy += IngestVector(vec);
+ }
+ benchmark::DoNotOptimize(dummy);
+ }
+
+ state.SetItemsProcessed(state.iterations() * kNumIters);
+}
+
+template <typename Vector>
+void ConstructFromEmptyStdVector(benchmark::State& state) {
+ BenchmarkConstructFromStdVector<Vector>(state, 0);
+}
+
+template <typename Vector>
+void ConstructFromShortStdVector(benchmark::State& state) {
+ BenchmarkConstructFromStdVector<Vector>(state, 3);
+}
+
+// With ARROW_NOINLINE, try to make sure the number of items is not constant-propagated
+template <typename Vector>
+ARROW_NOINLINE void BenchmarkVectorPushBack(benchmark::State& state, const int nitems) {
+ using T = typename Vector::value_type;
+ constexpr int kNumIters = 1000;
+
+ ARROW_CHECK_LE(static_cast<size_t>(nitems), Vector{}.max_size());
+
+ for (auto _ : state) {
+ int64_t dummy = 0;
+ for (int i = 0; i < kNumIters; ++i) {
+ Vector vec;
+ vec.reserve(nitems);
+ for (int j = 0; j < nitems; ++j) {
+ vec.push_back(ValueInitializer<T>(j));
+ }
+ dummy += reinterpret_cast<intptr_t>(vec.data());
+ benchmark::ClobberMemory();
+ }
+ benchmark::DoNotOptimize(dummy);
+ }
+
+ state.SetItemsProcessed(state.iterations() * kNumIters * nitems);
+}
+
+template <typename Vector>
+void ShortVectorPushBack(benchmark::State& state) {
+ BenchmarkVectorPushBack<Vector>(state, 3);
+}
+
+template <typename Vector>
+void LongVectorPushBack(benchmark::State& state) {
+ BenchmarkVectorPushBack<Vector>(state, 100);
+}
+
+// With ARROW_NOINLINE, try to make sure the source data is not constant-propagated
+// (we could also use random data)
+template <typename Vector, typename T = typename Vector::value_type>
+ARROW_NOINLINE void BenchmarkShortVectorInsert(benchmark::State& state,
+ const std::vector<T>& src) {
+ constexpr int kNumIters = 1000;
+
+ for (auto _ : state) {
+ int64_t dummy = 0;
+ for (int i = 0; i < kNumIters; ++i) {
+ Vector vec;
+ vec.reserve(4);
+ vec.insert(vec.begin(), src.begin(), src.begin() + 2);
+ vec.insert(vec.begin(), src.begin() + 2, src.end());
+ dummy += reinterpret_cast<intptr_t>(vec.data());
+ benchmark::ClobberMemory();
+ }
+ benchmark::DoNotOptimize(dummy);
+ }
+
+ state.SetItemsProcessed(state.iterations() * kNumIters * 4);
+}
+
+template <typename Vector>
+void ShortVectorInsert(benchmark::State& state) {
+ using T = typename Vector::value_type;
+ const std::vector<T> src(4, ValueInitializer<T>());
+ BenchmarkShortVectorInsert<Vector>(state, src);
+}
+
+template <typename Vector>
+ARROW_NOINLINE void BenchmarkVectorInsertAtEnd(benchmark::State& state,
+ const int nitems) {
+ using T = typename Vector::value_type;
+ constexpr int kNumIters = 1000;
+
+ ARROW_CHECK_LE(static_cast<size_t>(nitems), Vector{}.max_size());
+ ARROW_CHECK_EQ(nitems % 2, 0);
+
+ std::vector<T> src;
+ for (int j = 0; j < nitems / 2; ++j) {
+ src.push_back(ValueInitializer<T>(j));
+ }
+
+ for (auto _ : state) {
+ int64_t dummy = 0;
+ for (int i = 0; i < kNumIters; ++i) {
+ Vector vec;
+ vec.reserve(nitems);
+ vec.insert(vec.end(), src.begin(), src.end());
+ vec.insert(vec.end(), src.begin(), src.end());
+ dummy += reinterpret_cast<intptr_t>(vec.data());
+ benchmark::ClobberMemory();
+ }
+ benchmark::DoNotOptimize(dummy);
+ }
+
+ state.SetItemsProcessed(state.iterations() * kNumIters * nitems);
+}
+
+template <typename Vector>
+void ShortVectorInsertAtEnd(benchmark::State& state) {
+ BenchmarkVectorInsertAtEnd<Vector>(state, 4);
+}
+
+template <typename Vector>
+void LongVectorInsertAtEnd(benchmark::State& state) {
+ BenchmarkVectorInsertAtEnd<Vector>(state, 100);
+}
+
+#define SHORT_VECTOR_BENCHMARKS(VEC_TYPE_FACTORY) \
+ BENCHMARK_TEMPLATE(MoveEmptyVector, VEC_TYPE_FACTORY(int)); \
+ BENCHMARK_TEMPLATE(MoveEmptyVector, VEC_TYPE_FACTORY(std::string)); \
+ BENCHMARK_TEMPLATE(MoveEmptyVector, VEC_TYPE_FACTORY(std::shared_ptr<int>)); \
+ BENCHMARK_TEMPLATE(MoveShortVector, VEC_TYPE_FACTORY(int)); \
+ BENCHMARK_TEMPLATE(MoveShortVector, VEC_TYPE_FACTORY(std::string)); \
+ BENCHMARK_TEMPLATE(MoveShortVector, VEC_TYPE_FACTORY(std::shared_ptr<int>)); \
+ BENCHMARK_TEMPLATE(CopyEmptyVector, VEC_TYPE_FACTORY(int)); \
+ BENCHMARK_TEMPLATE(CopyEmptyVector, VEC_TYPE_FACTORY(std::string)); \
+ BENCHMARK_TEMPLATE(CopyEmptyVector, VEC_TYPE_FACTORY(std::shared_ptr<int>)); \
+ BENCHMARK_TEMPLATE(CopyShortVector, VEC_TYPE_FACTORY(int)); \
+ BENCHMARK_TEMPLATE(CopyShortVector, VEC_TYPE_FACTORY(std::string)); \
+ BENCHMARK_TEMPLATE(CopyShortVector, VEC_TYPE_FACTORY(std::shared_ptr<int>)); \
+ BENCHMARK_TEMPLATE(ConstructFromEmptyStdVector, VEC_TYPE_FACTORY(int)); \
+ BENCHMARK_TEMPLATE(ConstructFromEmptyStdVector, VEC_TYPE_FACTORY(std::string)); \
+ BENCHMARK_TEMPLATE(ConstructFromEmptyStdVector, \
+ VEC_TYPE_FACTORY(std::shared_ptr<int>)); \
+ BENCHMARK_TEMPLATE(ConstructFromShortStdVector, VEC_TYPE_FACTORY(int)); \
+ BENCHMARK_TEMPLATE(ConstructFromShortStdVector, VEC_TYPE_FACTORY(std::string)); \
+ BENCHMARK_TEMPLATE(ConstructFromShortStdVector, \
+ VEC_TYPE_FACTORY(std::shared_ptr<int>)); \
+ BENCHMARK_TEMPLATE(ShortVectorPushBack, VEC_TYPE_FACTORY(int)); \
+ BENCHMARK_TEMPLATE(ShortVectorPushBack, VEC_TYPE_FACTORY(std::string)); \
+ BENCHMARK_TEMPLATE(ShortVectorPushBack, VEC_TYPE_FACTORY(std::shared_ptr<int>)); \
+ BENCHMARK_TEMPLATE(ShortVectorInsert, VEC_TYPE_FACTORY(int)); \
+ BENCHMARK_TEMPLATE(ShortVectorInsert, VEC_TYPE_FACTORY(std::string)); \
+ BENCHMARK_TEMPLATE(ShortVectorInsert, VEC_TYPE_FACTORY(std::shared_ptr<int>)); \
+ BENCHMARK_TEMPLATE(ShortVectorInsertAtEnd, VEC_TYPE_FACTORY(int)); \
+ BENCHMARK_TEMPLATE(ShortVectorInsertAtEnd, VEC_TYPE_FACTORY(std::string)); \
+ BENCHMARK_TEMPLATE(ShortVectorInsertAtEnd, VEC_TYPE_FACTORY(std::shared_ptr<int>));
+
+#define LONG_VECTOR_BENCHMARKS(VEC_TYPE_FACTORY) \
+ BENCHMARK_TEMPLATE(LongVectorPushBack, VEC_TYPE_FACTORY(int)); \
+ BENCHMARK_TEMPLATE(LongVectorPushBack, VEC_TYPE_FACTORY(std::string)); \
+ BENCHMARK_TEMPLATE(LongVectorPushBack, VEC_TYPE_FACTORY(std::shared_ptr<int>)); \
+ BENCHMARK_TEMPLATE(LongVectorInsertAtEnd, VEC_TYPE_FACTORY(int)); \
+ BENCHMARK_TEMPLATE(LongVectorInsertAtEnd, VEC_TYPE_FACTORY(std::string)); \
+ BENCHMARK_TEMPLATE(LongVectorInsertAtEnd, VEC_TYPE_FACTORY(std::shared_ptr<int>));
+
+// NOTE: the macro name below (STD_VECTOR etc.) is reflected in the
+// benchmark name, so use descriptive names.
+
+#ifdef ARROW_WITH_BENCHMARKS_REFERENCE
+
+#define STD_VECTOR(T) std::vector<T>
+SHORT_VECTOR_BENCHMARKS(STD_VECTOR);
+LONG_VECTOR_BENCHMARKS(STD_VECTOR);
+#undef STD_VECTOR
+
+#endif
+
+#define STATIC_VECTOR(T) StaticVector<T, 4>
+SHORT_VECTOR_BENCHMARKS(STATIC_VECTOR);
+#undef STATIC_VECTOR
+
+#define SMALL_VECTOR(T) SmallVector<T, 4>
+SHORT_VECTOR_BENCHMARKS(SMALL_VECTOR);
+LONG_VECTOR_BENCHMARKS(SMALL_VECTOR);
+#undef SMALL_VECTOR
+
+#undef SHORT_VECTOR_BENCHMARKS
+#undef LONG_VECTOR_BENCHMARKS
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/small_vector_test.cc b/src/arrow/cpp/src/arrow/util/small_vector_test.cc
new file mode 100644
index 000000000..f9ec5fedf
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/small_vector_test.cc
@@ -0,0 +1,786 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstddef>
+#include <iterator>
+#include <limits>
+#include <memory>
+#include <string>
+#include <type_traits>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+#include "arrow/util/small_vector.h"
+
+using testing::ElementsAre;
+using testing::ElementsAreArray;
+
+namespace arrow {
+namespace internal {
+
+struct HeapInt {
+ HeapInt() : HeapInt(0) {}
+
+ explicit HeapInt(int x) : ptr(new int(x)) {}
+
+ HeapInt& operator=(int x) {
+ ptr.reset(new int(x));
+ return *this;
+ }
+
+ HeapInt(const HeapInt& other) : HeapInt(other.ToInt()) {}
+
+ HeapInt& operator=(const HeapInt& other) {
+ *this = other.ToInt();
+ return *this;
+ }
+
+ int ToInt() const { return ptr == nullptr ? -98 : *ptr; }
+
+ bool operator==(const HeapInt& other) const {
+ return ptr != nullptr && other.ptr != nullptr && *ptr == *other.ptr;
+ }
+ bool operator<(const HeapInt& other) const {
+ return ptr == nullptr || (other.ptr != nullptr && *ptr < *other.ptr);
+ }
+
+ bool operator==(int other) const { return ptr != nullptr && *ptr == other; }
+ friend bool operator==(int left, const HeapInt& right) { return right == left; }
+
+ std::unique_ptr<int> ptr;
+};
+
+template <typename Vector>
+bool UsesStaticStorage(const Vector& v) {
+ const uint8_t* p = reinterpret_cast<const uint8_t*>(v.data());
+ if (p == nullptr) {
+ return true;
+ }
+ const uint8_t* v_start = reinterpret_cast<const uint8_t*>(&v);
+ return (p >= v_start && p < v_start + sizeof(v));
+}
+
+struct StaticVectorTraits {
+ template <typename T, size_t N>
+ using VectorType = StaticVector<T, N>;
+
+ static bool CanOverflow() { return false; }
+
+ static constexpr size_t MaxSizeFor(size_t n) { return n; }
+
+ static constexpr size_t TestSizeFor(size_t max_size) { return max_size; }
+};
+
+struct SmallVectorTraits {
+ template <typename T, size_t N>
+ using VectorType = SmallVector<T, N>;
+
+ static bool CanOverflow() { return true; }
+
+ static constexpr size_t MaxSizeFor(size_t n) {
+ return std::numeric_limits<size_t>::max();
+ }
+
+ static constexpr size_t TestSizeFor(size_t max_size) {
+ return max_size > 6 ? max_size / 3 : 2;
+ }
+};
+
+using VectorTraits = ::testing::Types<StaticVectorTraits, SmallVectorTraits>;
+
+template <typename T, typename I>
+struct VectorIntLikeParam {
+ using Traits = T;
+ using IntLike = I;
+
+ static constexpr bool IsMoveOnly() { return !std::is_copy_constructible<I>::value; }
+};
+
+using VectorIntLikeParams =
+ ::testing::Types<VectorIntLikeParam<StaticVectorTraits, int>,
+ VectorIntLikeParam<SmallVectorTraits, int>,
+ VectorIntLikeParam<StaticVectorTraits, HeapInt>,
+ VectorIntLikeParam<SmallVectorTraits, HeapInt>,
+ VectorIntLikeParam<StaticVectorTraits, MoveOnlyDataType>,
+ VectorIntLikeParam<SmallVectorTraits, MoveOnlyDataType>>;
+
+template <typename Param>
+class TestSmallStaticVector : public ::testing::Test {
+ template <bool B, typename T = void>
+ using enable_if_t = typename std::enable_if<B, T>::type;
+
+ template <typename P>
+ using enable_if_move_only = enable_if_t<P::IsMoveOnly(), int>;
+
+ template <typename P>
+ using enable_if_not_move_only = enable_if_t<!P::IsMoveOnly(), int>;
+
+ public:
+ using Traits = typename Param::Traits;
+ using IntLike = typename Param::IntLike;
+
+ template <typename T, size_t N>
+ using VectorType = typename Traits::template VectorType<T, N>;
+
+ template <size_t N>
+ using IntVectorType = VectorType<IntLike, N>;
+
+ template <size_t N>
+ IntVectorType<N> MakeVector(const std::vector<int>& init_values) {
+ IntVectorType<N> ints;
+ for (auto v : init_values) {
+ ints.emplace_back(v);
+ }
+ return ints;
+ }
+
+ template <size_t N>
+ IntVectorType<N> CheckFourValues() {
+ IntVectorType<N> ints;
+ EXPECT_EQ(ints.size(), 0);
+ EXPECT_EQ(ints.capacity(), N);
+ EXPECT_EQ(ints.max_size(), Traits::MaxSizeFor(N));
+ EXPECT_TRUE(UsesStaticStorage(ints));
+
+ ints.emplace_back(3);
+ ints.emplace_back(42);
+ EXPECT_EQ(ints.size(), 2);
+ EXPECT_EQ(ints.capacity(), N);
+ EXPECT_EQ(ints[0], 3);
+ EXPECT_EQ(ints[1], 42);
+ EXPECT_TRUE(UsesStaticStorage(ints));
+
+ ints.push_back(IntLike(5));
+ ints.emplace_back(false);
+ EXPECT_EQ(ints.size(), 4);
+ EXPECT_EQ(ints[2], 5);
+ EXPECT_EQ(ints[3], 0);
+
+ ints[3] = IntLike(8);
+ EXPECT_EQ(ints[3], 8);
+ EXPECT_EQ(ints.back(), 8);
+ ints.front() = IntLike(-1);
+ EXPECT_EQ(ints[0], -1);
+ EXPECT_EQ(ints.front(), -1);
+
+ return ints;
+ }
+
+ void TestBasics() {
+ constexpr size_t N = Traits::TestSizeFor(4);
+ const auto ints = CheckFourValues<N>();
+ EXPECT_EQ(UsesStaticStorage(ints), !Traits::CanOverflow());
+ }
+
+ void TestAlwaysStatic() {
+ const auto ints = CheckFourValues<4>();
+ EXPECT_TRUE(UsesStaticStorage(ints));
+ }
+
+ template <size_t N>
+ void CheckReserve(size_t max_size, bool expect_overflow) {
+ IntVectorType<N> ints;
+ ints.emplace_back(123);
+
+ size_t orig_capacity = ints.capacity();
+
+ ints.reserve(max_size / 3);
+ ASSERT_EQ(ints.capacity(), std::max(max_size / 3, orig_capacity));
+ ASSERT_EQ(ints.size(), 1);
+ ASSERT_EQ(ints[0], 123);
+
+ ints.reserve(4 * max_size / 5);
+ ASSERT_EQ(ints.capacity(), std::max(4 * max_size / 5, orig_capacity));
+ ASSERT_EQ(ints.size(), 1);
+ ASSERT_EQ(ints[0], 123);
+ ASSERT_EQ(UsesStaticStorage(ints), !expect_overflow);
+
+ size_t old_capacity = ints.capacity();
+ ints.reserve(max_size / 5); // no-op
+ ASSERT_EQ(ints.capacity(), old_capacity);
+ ASSERT_EQ(ints.size(), 1);
+ ASSERT_EQ(ints[0], 123);
+
+ ints.reserve(1); // no-op
+ ASSERT_EQ(ints.capacity(), old_capacity);
+ ASSERT_EQ(ints.size(), 1);
+ ASSERT_EQ(ints[0], 123);
+ }
+
+ void TestReserve() {
+ CheckReserve<Traits::TestSizeFor(12)>(12, /*expect_overflow=*/Traits::CanOverflow());
+ CheckReserve<12>(12, /*expect_overflow=*/false);
+ }
+
+ template <size_t N>
+ void CheckClear(bool expect_overflow) {
+ IntVectorType<N> ints = MakeVector<N>({5, 6, 7, 8, 9});
+ ASSERT_EQ(ints.size(), 5);
+ size_t capacity = ints.capacity();
+
+ ints.clear();
+ ASSERT_EQ(ints.size(), 0);
+ ASSERT_EQ(ints.capacity(), capacity);
+ ASSERT_EQ(UsesStaticStorage(ints), !expect_overflow);
+ }
+
+ void TestClear() {
+ CheckClear<Traits::TestSizeFor(5)>(/*expect_overflow=*/Traits::CanOverflow());
+ CheckClear<6>(/*expect_overflow=*/false);
+ }
+
+ template <bool IsMoveOnly = Param::IsMoveOnly()>
+ void TestConstructFromCount(enable_if_t<!IsMoveOnly>* = 0) {
+ constexpr size_t N = Traits::TestSizeFor(4);
+ {
+ const IntVectorType<N> ints(3);
+ ASSERT_EQ(ints.size(), 3);
+ ASSERT_EQ(ints.capacity(), std::max<size_t>(N, 3));
+ for (int i = 0; i < 3; ++i) {
+ ASSERT_EQ(ints[i], 0);
+ }
+ EXPECT_THAT(ints, ElementsAre(0, 0, 0));
+ }
+ }
+
+ template <bool IsMoveOnly = Param::IsMoveOnly()>
+ void TestConstructFromCount(enable_if_t<IsMoveOnly>* = 0) {
+ GTEST_SKIP() << "Cannot construct vector of move-only type with value count";
+ }
+
+ template <size_t N>
+ void CheckConstructFromValues() {
+ {
+ const IntVectorType<N> ints{};
+ ASSERT_EQ(ints.size(), 0);
+ ASSERT_EQ(ints.capacity(), N);
+ }
+ {
+ const IntVectorType<N> ints{IntLike(4), IntLike(5), IntLike(6)};
+ ASSERT_EQ(ints.size(), 3);
+ ASSERT_EQ(ints.capacity(), std::max<size_t>(N, 3));
+ ASSERT_EQ(ints[0], 4);
+ ASSERT_EQ(ints[1], 5);
+ ASSERT_EQ(ints[2], 6);
+ ASSERT_EQ(ints.front(), 4);
+ ASSERT_EQ(ints.back(), 6);
+ }
+ }
+
+ template <bool IsMoveOnly = Param::IsMoveOnly()>
+ void TestConstructFromValues(enable_if_t<!IsMoveOnly>* = 0) {
+ CheckConstructFromValues<Traits::TestSizeFor(4)>();
+ CheckConstructFromValues<5>();
+ }
+
+ template <bool IsMoveOnly = Param::IsMoveOnly()>
+ void TestConstructFromValues(enable_if_t<IsMoveOnly>* = 0) {
+ GTEST_SKIP() << "Cannot construct vector of move-only type with explicit values";
+ }
+
+ void CheckConstructFromMovedStdVector() {
+ constexpr size_t N = Traits::TestSizeFor(6);
+ {
+ std::vector<IntLike> src;
+ const IntVectorType<N> ints(std::move(src));
+ ASSERT_EQ(ints.size(), 0);
+ ASSERT_EQ(ints.capacity(), N);
+ }
+ {
+ std::vector<IntLike> src;
+ for (int i = 0; i < 6; ++i) {
+ src.emplace_back(i + 4);
+ }
+ const IntVectorType<N> ints(std::move(src));
+ ASSERT_EQ(ints.size(), 6);
+ ASSERT_EQ(ints.capacity(), std::max<size_t>(N, 6));
+ EXPECT_THAT(ints, ElementsAre(4, 5, 6, 7, 8, 9));
+ }
+ }
+
+ template <bool IsMoveOnly = Param::IsMoveOnly()>
+ void CheckConstructFromCopiedStdVector(enable_if_t<!IsMoveOnly>* = 0) {
+ constexpr size_t N = Traits::TestSizeFor(6);
+ {
+ const std::vector<IntLike> src;
+ const IntVectorType<N> ints(src);
+ ASSERT_EQ(ints.size(), 0);
+ ASSERT_EQ(ints.capacity(), N);
+ }
+ {
+ std::vector<IntLike> values;
+ for (int i = 0; i < 6; ++i) {
+ values.emplace_back(i + 4);
+ }
+ const auto& src = values;
+ const IntVectorType<N> ints(src);
+ ASSERT_EQ(ints.size(), 6);
+ ASSERT_EQ(ints.capacity(), std::max<size_t>(N, 6));
+ EXPECT_THAT(ints, ElementsAre(4, 5, 6, 7, 8, 9));
+ }
+ }
+
+ template <bool IsMoveOnly = Param::IsMoveOnly()>
+ void CheckConstructFromCopiedStdVector(enable_if_t<IsMoveOnly>* = 0) {}
+
+ void TestConstructFromStdVector() {
+ CheckConstructFromMovedStdVector();
+ CheckConstructFromCopiedStdVector();
+ }
+
+ void CheckAssignFromMovedStdVector() {
+ constexpr size_t N = Traits::TestSizeFor(6);
+ {
+ std::vector<IntLike> src;
+ IntVectorType<N> ints = MakeVector<N>({42});
+ ints = std::move(src);
+ ASSERT_EQ(ints.size(), 0);
+ ASSERT_EQ(ints.capacity(), N);
+ }
+ {
+ std::vector<IntLike> src;
+ for (int i = 0; i < 6; ++i) {
+ src.emplace_back(i + 4);
+ }
+ IntVectorType<N> ints = MakeVector<N>({42});
+ ints = std::move(src);
+ ASSERT_EQ(ints.size(), 6);
+ ASSERT_EQ(ints.capacity(), std::max<size_t>(N, 6));
+ EXPECT_THAT(ints, ElementsAre(4, 5, 6, 7, 8, 9));
+ }
+ }
+
+ template <bool IsMoveOnly = Param::IsMoveOnly()>
+ void CheckAssignFromCopiedStdVector(enable_if_t<!IsMoveOnly>* = 0) {
+ constexpr size_t N = Traits::TestSizeFor(6);
+ {
+ const std::vector<IntLike> src;
+ IntVectorType<N> ints = MakeVector<N>({42});
+ ints = src;
+ ASSERT_EQ(ints.size(), 0);
+ ASSERT_EQ(ints.capacity(), N);
+ }
+ {
+ std::vector<IntLike> values;
+ for (int i = 0; i < 6; ++i) {
+ values.emplace_back(i + 4);
+ }
+ const auto& src = values;
+ IntVectorType<N> ints = MakeVector<N>({42});
+ ints = src;
+ ASSERT_EQ(ints.size(), 6);
+ ASSERT_EQ(ints.capacity(), std::max<size_t>(N, 6));
+ EXPECT_THAT(ints, ElementsAre(4, 5, 6, 7, 8, 9));
+ }
+ }
+
+ template <bool IsMoveOnly = Param::IsMoveOnly()>
+ void CheckAssignFromCopiedStdVector(enable_if_t<IsMoveOnly>* = 0) {}
+
+ void TestAssignFromStdVector() {
+ CheckAssignFromMovedStdVector();
+ CheckAssignFromCopiedStdVector();
+ }
+
+ template <size_t N>
+ void CheckMove(bool expect_overflow) {
+ IntVectorType<N> ints = MakeVector<N>({4, 5, 6, 7, 8});
+
+ IntVectorType<N> moved_ints(std::move(ints));
+ ASSERT_EQ(moved_ints.size(), 5);
+ EXPECT_THAT(moved_ints, ElementsAre(4, 5, 6, 7, 8));
+ ASSERT_EQ(UsesStaticStorage(moved_ints), !expect_overflow);
+ ASSERT_TRUE(UsesStaticStorage(ints));
+
+ IntVectorType<N> moved_moved_ints = std::move(moved_ints);
+ ASSERT_EQ(moved_moved_ints.size(), 5);
+ EXPECT_THAT(moved_moved_ints, ElementsAre(4, 5, 6, 7, 8));
+
+ // Move into itself
+ moved_moved_ints = std::move(moved_moved_ints);
+ ASSERT_EQ(moved_moved_ints.size(), 5);
+ EXPECT_THAT(moved_moved_ints, ElementsAre(4, 5, 6, 7, 8));
+ }
+
+ void TestMove() {
+ CheckMove<Traits::TestSizeFor(5)>(/*expect_overflow=*/Traits::CanOverflow());
+ CheckMove<5>(/*expect_overflow=*/false);
+ }
+
+ template <size_t N>
+ void CheckCopy(bool expect_overflow) {
+ IntVectorType<N> ints = MakeVector<N>({4, 5, 6, 7, 8});
+
+ IntVectorType<N> copied_ints(ints);
+ ASSERT_EQ(copied_ints.size(), 5);
+ ASSERT_EQ(ints.size(), 5);
+ EXPECT_THAT(copied_ints, ElementsAre(4, 5, 6, 7, 8));
+ EXPECT_THAT(ints, ElementsAre(4, 5, 6, 7, 8));
+ ASSERT_EQ(UsesStaticStorage(copied_ints), !expect_overflow);
+
+ IntVectorType<N> copied_copied_ints = copied_ints;
+ ASSERT_EQ(copied_copied_ints.size(), 5);
+ ASSERT_EQ(copied_ints.size(), 5);
+ EXPECT_THAT(copied_copied_ints, ElementsAre(4, 5, 6, 7, 8));
+ EXPECT_THAT(copied_ints, ElementsAre(4, 5, 6, 7, 8));
+
+ auto copy_into = [](const IntVectorType<N>& src, IntVectorType<N>* dest) {
+ *dest = src;
+ };
+
+ // Copy into itself
+ // (avoiding the trivial form `copied_copied_ints = copied_copied_ints`
+ // that would produce a clang warning)
+ copy_into(copied_copied_ints, &copied_copied_ints);
+ ASSERT_EQ(copied_copied_ints.size(), 5);
+ EXPECT_THAT(copied_copied_ints, ElementsAre(4, 5, 6, 7, 8));
+ }
+
+ template <bool IsMoveOnly = Param::IsMoveOnly()>
+ void TestCopy(enable_if_t<!IsMoveOnly>* = 0) {
+ CheckCopy<Traits::TestSizeFor(5)>(/*expect_overflow=*/Traits::CanOverflow());
+ CheckCopy<5>(/*expect_overflow=*/false);
+ }
+
+ template <bool IsMoveOnly = Param::IsMoveOnly()>
+ void TestCopy(enable_if_t<IsMoveOnly>* = 0) {
+ GTEST_SKIP() << "Cannot copy vector of move-only type";
+ }
+
+ template <bool IsMoveOnly = Param::IsMoveOnly()>
+ void TestResize(enable_if_t<!IsMoveOnly>* = 0) {
+ constexpr size_t N = Traits::TestSizeFor(8);
+ {
+ IntVectorType<N> ints;
+ ints.resize(2);
+ ASSERT_GE(ints.capacity(), 2);
+ EXPECT_THAT(ints, ElementsAreArray(std::vector<int>(2, 0)));
+ ints.resize(3);
+ ASSERT_GE(ints.capacity(), 3);
+ EXPECT_THAT(ints, ElementsAreArray(std::vector<int>(3, 0)));
+ ints.resize(8);
+ ASSERT_GE(ints.capacity(), 8);
+ EXPECT_THAT(ints, ElementsAreArray(std::vector<int>(8, 0)));
+ ints.resize(6);
+ ints.resize(6); // no-op
+ ASSERT_GE(ints.capacity(), 8);
+ EXPECT_THAT(ints, ElementsAreArray(std::vector<int>(6, 0)));
+ ints.resize(0);
+ ASSERT_GE(ints.capacity(), 8);
+ EXPECT_THAT(ints, ElementsAreArray(std::vector<int>(0, 0)));
+ ints.resize(5);
+ ASSERT_GE(ints.capacity(), 8);
+ EXPECT_THAT(ints, ElementsAreArray(std::vector<int>(5, 0)));
+ ints.resize(7);
+ ASSERT_GE(ints.capacity(), 8);
+ EXPECT_THAT(ints, ElementsAreArray(std::vector<int>(7, 0)));
+ }
+ {
+ IntVectorType<N> ints;
+ ints.resize(2, IntLike(2));
+ ASSERT_GE(ints.capacity(), 2);
+ EXPECT_THAT(ints, ElementsAre(2, 2));
+ ints.resize(3, IntLike(3));
+ ASSERT_GE(ints.capacity(), 3);
+ EXPECT_THAT(ints, ElementsAre(2, 2, 3));
+ ints.resize(8, IntLike(8));
+ ASSERT_GE(ints.capacity(), 8);
+ EXPECT_THAT(ints, ElementsAre(2, 2, 3, 8, 8, 8, 8, 8));
+ ints.resize(6, IntLike(6));
+ ints.resize(6, IntLike(6)); // no-op
+ ASSERT_GE(ints.capacity(), 8);
+ EXPECT_THAT(ints, ElementsAre(2, 2, 3, 8, 8, 8));
+ ints.resize(0, IntLike(0));
+ ASSERT_GE(ints.capacity(), 8);
+ EXPECT_THAT(ints, ElementsAre());
+ ints.resize(5, IntLike(5));
+ ASSERT_GE(ints.capacity(), 8);
+ EXPECT_THAT(ints, ElementsAre(5, 5, 5, 5, 5));
+ ints.resize(7, IntLike(7));
+ ASSERT_GE(ints.capacity(), 8);
+ EXPECT_THAT(ints, ElementsAre(5, 5, 5, 5, 5, 7, 7));
+ }
+ }
+
+ template <bool IsMoveOnly = Param::IsMoveOnly()>
+ void TestResize(enable_if_t<IsMoveOnly>* = 0) {
+ GTEST_SKIP() << "Cannot resize vector of move-only type";
+ }
+
+ template <size_t N>
+ void CheckSort() {
+ IntVectorType<N> ints;
+ for (int v : {42, 2, 123, -5, 6, 12, 8, 13}) {
+ ints.emplace_back(v);
+ }
+ std::sort(ints.begin(), ints.end());
+ EXPECT_THAT(ints, ElementsAre(-5, 2, 6, 8, 12, 13, 42, 123));
+ }
+
+ void TestSort() {
+ CheckSort<Traits::TestSizeFor(8)>();
+ CheckSort<8>();
+ }
+
+ void TestIterators() {
+ constexpr size_t N = Traits::TestSizeFor(5);
+ {
+ // Forward iterators
+ IntVectorType<N> ints;
+ ASSERT_EQ(ints.begin(), ints.end());
+
+ for (int v : {5, 6, 7, 8, 42}) {
+ ints.emplace_back(v);
+ }
+
+ auto it = ints.begin();
+ ASSERT_NE(it, ints.end());
+ ASSERT_EQ(*it++, 5);
+ ASSERT_EQ(ints.end() - it, 4);
+
+ auto it2 = ++it;
+ ASSERT_EQ(*it, 7);
+ ASSERT_EQ(*it2, 7);
+ ASSERT_EQ(it, it2);
+
+ ASSERT_EQ(ints.end() - it, 3);
+ ASSERT_EQ(*it--, 7);
+ ASSERT_NE(it, it2);
+
+ ASSERT_EQ(ints.end() - it, 4);
+ ASSERT_NE(it, ints.end());
+ ASSERT_EQ(*--it, 5);
+ ASSERT_EQ(*it, 5);
+ ASSERT_EQ(ints.end() - it, 5);
+ it += 4;
+ ASSERT_EQ(*it, 42);
+ ASSERT_EQ(ints.end() - it, 1);
+ ASSERT_NE(it, ints.end());
+ ASSERT_EQ(*(it - 3), 6);
+ ASSERT_EQ(++it, ints.end());
+ }
+ {
+ // Reverse iterators
+ IntVectorType<N> ints;
+ ASSERT_EQ(ints.rbegin(), ints.rend());
+
+ for (int v : {42, 8, 7, 6, 5}) {
+ ints.emplace_back(v);
+ }
+
+ auto it = ints.rbegin();
+ ASSERT_NE(it, ints.rend());
+ ASSERT_EQ(*it++, 5);
+ ASSERT_EQ(ints.rend() - it, 4);
+
+ auto it2 = ++it;
+ ASSERT_EQ(*it, 7);
+ ASSERT_EQ(*it2, 7);
+ ASSERT_EQ(it, it2);
+
+ ASSERT_EQ(ints.rend() - it, 3);
+ ASSERT_EQ(*it--, 7);
+ ASSERT_NE(it, it2);
+
+ ASSERT_EQ(ints.rend() - it, 4);
+ ASSERT_NE(it, ints.rend());
+ ASSERT_EQ(*--it, 5);
+ ASSERT_EQ(*it, 5);
+ ASSERT_EQ(ints.rend() - it, 5);
+ it += 4;
+ ASSERT_EQ(*it, 42);
+ ASSERT_EQ(ints.rend() - it, 1);
+ ASSERT_NE(it, ints.rend());
+ ASSERT_EQ(*(it - 3), 6);
+ ASSERT_EQ(++it, ints.rend());
+ }
+ }
+
+ void TestConstIterators() {
+ constexpr size_t N = Traits::TestSizeFor(5);
+ {
+ const IntVectorType<N> ints{};
+ ASSERT_EQ(ints.begin(), ints.end());
+ ASSERT_EQ(ints.rbegin(), ints.rend());
+ }
+ {
+ // Forward iterators
+ IntVectorType<N> underlying_ints = MakeVector<N>({5, 6, 7, 8, 42});
+ const IntVectorType<N>& ints = underlying_ints;
+
+ auto it = ints.begin();
+ ASSERT_NE(it, ints.end());
+ ASSERT_EQ(*it++, 5);
+ auto it2 = it++;
+ ASSERT_EQ(*it2, 6);
+ ASSERT_EQ(*it, 7);
+ ASSERT_NE(it, it2);
+ ASSERT_EQ(*++it2, 7);
+ ASSERT_EQ(it, it2);
+
+ // Conversion from non-const iterator
+ it = underlying_ints.begin() + 1;
+ ASSERT_NE(it, underlying_ints.end());
+ ASSERT_EQ(*it, 6);
+ it += underlying_ints.end() - it;
+ ASSERT_EQ(it, underlying_ints.end());
+ }
+ {
+ // Reverse iterators
+ IntVectorType<N> underlying_ints = MakeVector<N>({42, 8, 7, 6, 5});
+ const IntVectorType<N>& ints = underlying_ints;
+
+ auto it = ints.rbegin();
+ ASSERT_NE(it, ints.rend());
+ ASSERT_EQ(*it++, 5);
+ auto it2 = it++;
+ ASSERT_EQ(*it2, 6);
+ ASSERT_EQ(*it, 7);
+ ASSERT_NE(it, it2);
+ ASSERT_EQ(*++it2, 7);
+ ASSERT_EQ(it, it2);
+
+ // Conversion from non-const iterator
+ it = underlying_ints.rbegin() + 1;
+ ASSERT_NE(it, underlying_ints.rend());
+ ASSERT_EQ(*it, 6);
+ it += underlying_ints.rend() - it;
+ ASSERT_EQ(it, underlying_ints.rend());
+ }
+ }
+
+ void TestInsertIteratorPair() {
+ // insert(const_iterator, InputIt first, InputIt last)
+ constexpr size_t N = Traits::TestSizeFor(10);
+ {
+ // empty source and destination
+ const std::vector<int> src{};
+ IntVectorType<N> ints;
+ ints.insert(ints.begin(), src.begin(), src.end());
+ ASSERT_EQ(ints.size(), 0);
+
+ ints.emplace_back(42);
+ ints.insert(ints.begin(), src.begin(), src.end());
+ ints.insert(ints.end(), src.begin(), src.end());
+ EXPECT_THAT(ints, ElementsAre(42));
+ }
+ const std::vector<int> src{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ {
+ // insert at start
+ IntVectorType<N> ints;
+ ints.insert(ints.begin(), src.begin() + 4, src.begin() + 7);
+ EXPECT_THAT(ints, ElementsAre(4, 5, 6));
+ ints.insert(ints.begin(), src.begin() + 1, src.begin() + 4);
+ EXPECT_THAT(ints, ElementsAre(1, 2, 3, 4, 5, 6));
+ ints.insert(ints.begin(), src.begin(), src.begin() + 1);
+ EXPECT_THAT(ints, ElementsAre(0, 1, 2, 3, 4, 5, 6));
+ ints.insert(ints.begin(), src.begin() + 7, src.begin() + 10);
+ EXPECT_THAT(ints, ElementsAre(7, 8, 9, 0, 1, 2, 3, 4, 5, 6));
+ }
+ {
+ // insert at end
+ IntVectorType<N> ints;
+ ints.insert(ints.end(), src.begin() + 4, src.begin() + 7);
+ EXPECT_THAT(ints, ElementsAre(4, 5, 6));
+ ints.insert(ints.end(), src.begin() + 1, src.begin() + 4);
+ EXPECT_THAT(ints, ElementsAre(4, 5, 6, 1, 2, 3));
+ ints.insert(ints.end(), src.begin(), src.begin() + 1);
+ EXPECT_THAT(ints, ElementsAre(4, 5, 6, 1, 2, 3, 0));
+ ints.insert(ints.end(), src.begin() + 7, src.begin() + 10);
+ EXPECT_THAT(ints, ElementsAre(4, 5, 6, 1, 2, 3, 0, 7, 8, 9));
+ }
+ {
+ // insert at some point inside
+ IntVectorType<N> ints;
+ ints.insert(ints.begin(), src.begin() + 4, src.begin() + 7);
+ EXPECT_THAT(ints, ElementsAre(4, 5, 6));
+ ints.insert(ints.begin() + 2, src.begin() + 1, src.begin() + 4);
+ EXPECT_THAT(ints, ElementsAre(4, 5, 1, 2, 3, 6));
+ ints.insert(ints.begin() + 2, src.begin(), src.begin() + 1);
+ EXPECT_THAT(ints, ElementsAre(4, 5, 0, 1, 2, 3, 6));
+ ints.insert(ints.begin() + 2, src.begin() + 7, src.begin() + 10);
+ EXPECT_THAT(ints, ElementsAre(4, 5, 7, 8, 9, 0, 1, 2, 3, 6));
+ }
+ {
+ // insert from a std::move_iterator (potentially move-only)
+ IntVectorType<N> src = MakeVector<N>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
+ IntVectorType<N> ints;
+ auto move_it = [&](size_t i) { return std::make_move_iterator(src.begin() + i); };
+ ints.insert(ints.begin(), move_it(4), move_it(7));
+ EXPECT_THAT(ints, ElementsAre(4, 5, 6));
+ ints.insert(ints.begin() + 2, move_it(1), move_it(4));
+ EXPECT_THAT(ints, ElementsAre(4, 5, 1, 2, 3, 6));
+ ints.insert(ints.begin() + 2, move_it(0), move_it(1));
+ EXPECT_THAT(ints, ElementsAre(4, 5, 0, 1, 2, 3, 6));
+ ints.insert(ints.begin() + 2, move_it(7), move_it(10));
+ EXPECT_THAT(ints, ElementsAre(4, 5, 7, 8, 9, 0, 1, 2, 3, 6));
+ }
+ }
+};
+
+TYPED_TEST_SUITE(TestSmallStaticVector, VectorIntLikeParams);
+
+TYPED_TEST(TestSmallStaticVector, Basics) { this->TestBasics(); }
+
+TYPED_TEST(TestSmallStaticVector, AlwaysStatic) { this->TestAlwaysStatic(); }
+
+TYPED_TEST(TestSmallStaticVector, Reserve) { this->TestReserve(); }
+
+TYPED_TEST(TestSmallStaticVector, Clear) { this->TestClear(); }
+
+TYPED_TEST(TestSmallStaticVector, ConstructFromCount) { this->TestConstructFromCount(); }
+
+TYPED_TEST(TestSmallStaticVector, ConstructFromValues) {
+ this->TestConstructFromValues();
+}
+
+TYPED_TEST(TestSmallStaticVector, ConstructFromStdVector) {
+ this->TestConstructFromStdVector();
+}
+
+TYPED_TEST(TestSmallStaticVector, AssignFromStdVector) {
+ this->TestAssignFromStdVector();
+}
+
+TYPED_TEST(TestSmallStaticVector, Move) { this->TestMove(); }
+
+TYPED_TEST(TestSmallStaticVector, Copy) { this->TestCopy(); }
+
+TYPED_TEST(TestSmallStaticVector, Resize) { this->TestResize(); }
+
+TYPED_TEST(TestSmallStaticVector, Sort) { this->TestSort(); }
+
+TYPED_TEST(TestSmallStaticVector, Iterators) { this->TestIterators(); }
+
+TYPED_TEST(TestSmallStaticVector, ConstIterators) { this->TestConstIterators(); }
+
+TYPED_TEST(TestSmallStaticVector, InsertIteratorPair) { this->TestInsertIteratorPair(); }
+
+TEST(StaticVector, Traits) {
+ ASSERT_TRUE((std::is_trivially_destructible<StaticVector<int, 4>>::value));
+ ASSERT_FALSE((std::is_trivially_destructible<StaticVector<std::string, 4>>::value));
+}
+
+TEST(SmallVector, Traits) {
+ ASSERT_FALSE((std::is_trivially_destructible<SmallVector<int, 4>>::value));
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/sort.h b/src/arrow/cpp/src/arrow/util/sort.h
new file mode 100644
index 000000000..cdffe0b23
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/sort.h
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <numeric>
+#include <utility>
+#include <vector>
+
+namespace arrow {
+namespace internal {
+
+template <typename T, typename Cmp = std::less<T>>
+std::vector<int64_t> ArgSort(const std::vector<T>& values, Cmp&& cmp = {}) {
+ std::vector<int64_t> indices(values.size());
+ std::iota(indices.begin(), indices.end(), 0);
+ std::sort(indices.begin(), indices.end(),
+ [&](int64_t i, int64_t j) -> bool { return cmp(values[i], values[j]); });
+ return indices;
+}
+
+template <typename T>
+size_t Permute(const std::vector<int64_t>& indices, std::vector<T>* values) {
+ if (indices.size() <= 1) {
+ return indices.size();
+ }
+
+ // mask indicating which of values are in the correct location
+ std::vector<bool> sorted(indices.size(), false);
+
+ size_t cycle_count = 0;
+
+ for (auto cycle_start = sorted.begin(); cycle_start != sorted.end();
+ cycle_start = std::find(cycle_start, sorted.end(), false)) {
+ ++cycle_count;
+
+ // position in which an element belongs WRT sort
+ auto sort_into = static_cast<int64_t>(cycle_start - sorted.begin());
+
+ if (indices[sort_into] == sort_into) {
+ // trivial cycle
+ sorted[sort_into] = true;
+ continue;
+ }
+
+ // resolve this cycle
+ const auto end = sort_into;
+ for (int64_t take_from = indices[sort_into]; take_from != end;
+ take_from = indices[sort_into]) {
+ std::swap(values->at(sort_into), values->at(take_from));
+ sorted[sort_into] = true;
+ sort_into = take_from;
+ }
+ sorted[sort_into] = true;
+ }
+
+ return cycle_count;
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/spaced.h b/src/arrow/cpp/src/arrow/util/spaced.h
new file mode 100644
index 000000000..8265e1d22
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/spaced.h
@@ -0,0 +1,98 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+
+#include "arrow/util/bit_run_reader.h"
+
+namespace arrow {
+namespace util {
+namespace internal {
+
+/// \brief Compress the buffer to spaced, excluding the null entries.
+///
+/// \param[in] src the source buffer
+/// \param[in] num_values the size of source buffer
+/// \param[in] valid_bits bitmap data indicating position of valid slots
+/// \param[in] valid_bits_offset offset into valid_bits
+/// \param[out] output the output buffer spaced
+/// \return The size of spaced buffer.
+template <typename T>
+inline int SpacedCompress(const T* src, int num_values, const uint8_t* valid_bits,
+ int64_t valid_bits_offset, T* output) {
+ int num_valid_values = 0;
+
+ arrow::internal::SetBitRunReader reader(valid_bits, valid_bits_offset, num_values);
+ while (true) {
+ const auto run = reader.NextRun();
+ if (run.length == 0) {
+ break;
+ }
+ std::memcpy(output + num_valid_values, src + run.position, run.length * sizeof(T));
+ num_valid_values += static_cast<int32_t>(run.length);
+ }
+
+ return num_valid_values;
+}
+
+/// \brief Relocate values in buffer into positions of non-null values as indicated by
+/// a validity bitmap.
+///
+/// \param[in, out] buffer the in-place buffer
+/// \param[in] num_values total size of buffer including null slots
+/// \param[in] null_count number of null slots
+/// \param[in] valid_bits bitmap data indicating position of valid slots
+/// \param[in] valid_bits_offset offset into valid_bits
+/// \return The number of values expanded, including nulls.
+template <typename T>
+inline int SpacedExpand(T* buffer, int num_values, int null_count,
+ const uint8_t* valid_bits, int64_t valid_bits_offset) {
+ // Point to end as we add the spacing from the back.
+ int idx_decode = num_values - null_count;
+
+ // Depending on the number of nulls, some of the value slots in buffer may
+ // be uninitialized, and this will cause valgrind warnings / potentially UB
+ std::memset(static_cast<void*>(buffer + idx_decode), 0, null_count * sizeof(T));
+ if (idx_decode == 0) {
+ // All nulls, nothing more to do
+ return num_values;
+ }
+
+ arrow::internal::ReverseSetBitRunReader reader(valid_bits, valid_bits_offset,
+ num_values);
+ while (true) {
+ const auto run = reader.NextRun();
+ if (run.length == 0) {
+ break;
+ }
+ idx_decode -= static_cast<int32_t>(run.length);
+ assert(idx_decode >= 0);
+ std::memmove(buffer + run.position, buffer + idx_decode, run.length * sizeof(T));
+ }
+
+ // Otherwise caller gave an incorrect null_count
+ assert(idx_decode == 0);
+ return num_values;
+}
+
+} // namespace internal
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/stl_util_test.cc b/src/arrow/cpp/src/arrow/util/stl_util_test.cc
new file mode 100644
index 000000000..2a8784e13
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/stl_util_test.cc
@@ -0,0 +1,172 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/sort.h"
+#include "arrow/util/string.h"
+#include "arrow/util/vector.h"
+
+namespace arrow {
+namespace internal {
+
+TEST(StlUtilTest, VectorAddRemoveTest) {
+ std::vector<int> values;
+ std::vector<int> result = AddVectorElement(values, 0, 100);
+ EXPECT_EQ(values.size(), 0);
+ EXPECT_EQ(result.size(), 1);
+ EXPECT_EQ(result[0], 100);
+
+ // Add 200 at index 0 and 300 at the end.
+ std::vector<int> result2 = AddVectorElement(result, 0, 200);
+ result2 = AddVectorElement(result2, result2.size(), 300);
+ EXPECT_EQ(result.size(), 1);
+ EXPECT_EQ(result2.size(), 3);
+ EXPECT_EQ(result2[0], 200);
+ EXPECT_EQ(result2[1], 100);
+ EXPECT_EQ(result2[2], 300);
+
+ // Remove 100, 300, 200
+ std::vector<int> result3 = DeleteVectorElement(result2, 1);
+ EXPECT_EQ(result2.size(), 3);
+ EXPECT_EQ(result3.size(), 2);
+ EXPECT_EQ(result3[0], 200);
+ EXPECT_EQ(result3[1], 300);
+
+ result3 = DeleteVectorElement(result3, 1);
+ EXPECT_EQ(result3.size(), 1);
+ EXPECT_EQ(result3[0], 200);
+
+ result3 = DeleteVectorElement(result3, 0);
+ EXPECT_TRUE(result3.empty());
+}
+
+void ExpectSortPermutation(std::vector<std::string> unsorted,
+ std::vector<int64_t> expected_indices,
+ size_t expected_cycle_count) {
+ auto actual_indices = ArgSort(unsorted);
+ EXPECT_THAT(actual_indices, ::testing::ContainerEq(expected_indices));
+
+ auto sorted = unsorted;
+ std::sort(sorted.begin(), sorted.end());
+
+ auto permuted = unsorted;
+ EXPECT_EQ(Permute(expected_indices, &permuted), expected_cycle_count);
+
+ EXPECT_THAT(permuted, ::testing::ContainerEq(sorted));
+}
+
+TEST(StlUtilTest, ArgSortPermute) {
+ std::string f = "foxtrot", a = "alpha", b = "bravo", d = "delta", c = "charlie",
+ e = "echo";
+
+ ExpectSortPermutation({a, f}, {0, 1}, 2);
+ ExpectSortPermutation({f, a}, {1, 0}, 1);
+ ExpectSortPermutation({a, b, c}, {0, 1, 2}, 3);
+ ExpectSortPermutation({a, c, b}, {0, 2, 1}, 2);
+ ExpectSortPermutation({c, a, b}, {1, 2, 0}, 1);
+ ExpectSortPermutation({a, b, c, d, e, f}, {0, 1, 2, 3, 4, 5}, 6);
+ ExpectSortPermutation({f, e, d, c, b, a}, {5, 4, 3, 2, 1, 0}, 3);
+ ExpectSortPermutation({d, f, e, c, b, a}, {5, 4, 3, 0, 2, 1}, 1);
+ ExpectSortPermutation({b, a, c, d, f, e}, {1, 0, 2, 3, 5, 4}, 4);
+ ExpectSortPermutation({c, b, a, d, e, f}, {2, 1, 0, 3, 4, 5}, 5);
+ ExpectSortPermutation({b, c, a, f, d, e}, {2, 0, 1, 4, 5, 3}, 2);
+ ExpectSortPermutation({b, c, d, e, a, f}, {4, 0, 1, 2, 3, 5}, 2);
+}
+
+TEST(StlUtilTest, VectorFlatten) {
+ std::vector<int> a{1, 2, 3};
+ std::vector<int> b{4, 5, 6};
+ std::vector<int> c{7, 8, 9};
+ std::vector<std::vector<int>> vecs{a, b, c};
+ auto actual = FlattenVectors(vecs);
+ std::vector<int> expected{1, 2, 3, 4, 5, 6, 7, 8, 9};
+ ASSERT_EQ(expected, actual);
+}
+
+static std::string int_to_str(int val) { return std::to_string(val); }
+
+TEST(StlUtilTest, VectorMap) {
+ std::vector<int> input{1, 2, 3};
+ std::vector<std::string> expected{"1", "2", "3"};
+
+ auto actual = MapVector(int_to_str, input);
+ ASSERT_EQ(expected, actual);
+
+ auto bind_fn = std::bind(int_to_str, std::placeholders::_1);
+ actual = MapVector(bind_fn, input);
+ ASSERT_EQ(expected, actual);
+
+ std::function<std::string(int)> std_fn = int_to_str;
+ actual = MapVector(std_fn, input);
+ ASSERT_EQ(expected, actual);
+
+ actual = MapVector([](int val) { return std::to_string(val); }, input);
+ ASSERT_EQ(expected, actual);
+}
+
+TEST(StlUtilTest, VectorMaybeMapFails) {
+ std::vector<int> input{1, 2, 3};
+ auto mapper = [](int item) -> Result<std::string> {
+ if (item == 1) {
+ return Status::Invalid("XYZ");
+ }
+ return std::to_string(item);
+ };
+ ASSERT_RAISES(Invalid, MaybeMapVector(mapper, input));
+}
+
+TEST(StlUtilTest, VectorMaybeMap) {
+ std::vector<int> input{1, 2, 3};
+ std::vector<std::string> expected{"1", "2", "3"};
+ EXPECT_OK_AND_ASSIGN(
+ auto actual,
+ MaybeMapVector([](int item) -> Result<std::string> { return std::to_string(item); },
+ input));
+ ASSERT_EQ(expected, actual);
+}
+
+TEST(StlUtilTest, VectorUnwrapOrRaise) {
+ // TODO(ARROW-11998) There should be an easier way to construct these vectors
+ std::vector<Result<MoveOnlyDataType>> all_good;
+ all_good.push_back(Result<MoveOnlyDataType>(MoveOnlyDataType(1)));
+ all_good.push_back(Result<MoveOnlyDataType>(MoveOnlyDataType(2)));
+ all_good.push_back(Result<MoveOnlyDataType>(MoveOnlyDataType(3)));
+
+ std::vector<Result<MoveOnlyDataType>> some_bad;
+ some_bad.push_back(Result<MoveOnlyDataType>(MoveOnlyDataType(1)));
+ some_bad.push_back(Result<MoveOnlyDataType>(Status::Invalid("XYZ")));
+ some_bad.push_back(Result<MoveOnlyDataType>(Status::IOError("XYZ")));
+
+ EXPECT_OK_AND_ASSIGN(auto unwrapped, UnwrapOrRaise(std::move(all_good)));
+ std::vector<MoveOnlyDataType> expected;
+ expected.emplace_back(1);
+ expected.emplace_back(2);
+ expected.emplace_back(3);
+
+ ASSERT_EQ(expected, unwrapped);
+
+ ASSERT_RAISES(Invalid, UnwrapOrRaise(std::move(some_bad)));
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/stopwatch.h b/src/arrow/cpp/src/arrow/util/stopwatch.h
new file mode 100644
index 000000000..db4e67f59
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/stopwatch.h
@@ -0,0 +1,48 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cassert>
+#include <chrono>
+
+namespace arrow {
+namespace internal {
+
+class StopWatch {
+ // This clock should give us wall clock time
+ using ClockType = std::chrono::steady_clock;
+
+ public:
+ StopWatch() {}
+
+ void Start() { start_ = ClockType::now(); }
+
+ // Returns time in nanoseconds.
+ uint64_t Stop() {
+ auto stop = ClockType::now();
+ std::chrono::nanoseconds d = stop - start_;
+ assert(d.count() >= 0);
+ return static_cast<uint64_t>(d.count());
+ }
+
+ private:
+ std::chrono::time_point<ClockType> start_;
+};
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/string.cc b/src/arrow/cpp/src/arrow/util/string.cc
new file mode 100644
index 000000000..d922311df
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/string.cc
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/string.h"
+
+#include <algorithm>
+#include <cctype>
+#include <memory>
+
+#include "arrow/status.h"
+
+namespace arrow {
+
+static const char* kAsciiTable = "0123456789ABCDEF";
+
+std::string HexEncode(const uint8_t* data, size_t length) {
+ std::string hex_string;
+ hex_string.reserve(length * 2);
+ for (size_t j = 0; j < length; ++j) {
+ // Convert to 2 base16 digits
+ hex_string.push_back(kAsciiTable[data[j] >> 4]);
+ hex_string.push_back(kAsciiTable[data[j] & 15]);
+ }
+ return hex_string;
+}
+
+std::string Escape(const char* data, size_t length) {
+ std::string escaped_string;
+ escaped_string.reserve(length);
+ for (size_t j = 0; j < length; ++j) {
+ switch (data[j]) {
+ case '"':
+ escaped_string += R"(\")";
+ break;
+ case '\\':
+ escaped_string += R"(\\)";
+ break;
+ case '\t':
+ escaped_string += R"(\t)";
+ break;
+ case '\r':
+ escaped_string += R"(\r)";
+ break;
+ case '\n':
+ escaped_string += R"(\n)";
+ break;
+ default:
+ escaped_string.push_back(data[j]);
+ }
+ }
+ return escaped_string;
+}
+
+std::string HexEncode(const char* data, size_t length) {
+ return HexEncode(reinterpret_cast<const uint8_t*>(data), length);
+}
+
+std::string HexEncode(util::string_view str) { return HexEncode(str.data(), str.size()); }
+
+std::string Escape(util::string_view str) { return Escape(str.data(), str.size()); }
+
+Status ParseHexValue(const char* data, uint8_t* out) {
+ char c1 = data[0];
+ char c2 = data[1];
+
+ const char* kAsciiTableEnd = kAsciiTable + 16;
+ const char* pos1 = std::lower_bound(kAsciiTable, kAsciiTableEnd, c1);
+ const char* pos2 = std::lower_bound(kAsciiTable, kAsciiTableEnd, c2);
+
+ // Error checking
+ if (pos1 == kAsciiTableEnd || pos2 == kAsciiTableEnd || *pos1 != c1 || *pos2 != c2) {
+ return Status::Invalid("Encountered non-hex digit");
+ }
+
+ *out = static_cast<uint8_t>((pos1 - kAsciiTable) << 4 | (pos2 - kAsciiTable));
+ return Status::OK();
+}
+
+namespace internal {
+
+std::vector<util::string_view> SplitString(util::string_view v, char delimiter) {
+ std::vector<util::string_view> parts;
+ size_t start = 0, end;
+ while (true) {
+ end = v.find(delimiter, start);
+ parts.push_back(v.substr(start, end - start));
+ if (end == std::string::npos) {
+ break;
+ }
+ start = end + 1;
+ }
+ return parts;
+}
+
+template <typename StringLike>
+static std::string JoinStringLikes(const std::vector<StringLike>& strings,
+ util::string_view delimiter) {
+ if (strings.size() == 0) {
+ return "";
+ }
+ std::string out = std::string(strings.front());
+ for (size_t i = 1; i < strings.size(); ++i) {
+ out.append(delimiter.begin(), delimiter.end());
+ out.append(strings[i].begin(), strings[i].end());
+ }
+ return out;
+}
+
+std::string JoinStrings(const std::vector<util::string_view>& strings,
+ util::string_view delimiter) {
+ return JoinStringLikes(strings, delimiter);
+}
+
+std::string JoinStrings(const std::vector<std::string>& strings,
+ util::string_view delimiter) {
+ return JoinStringLikes(strings, delimiter);
+}
+
+static constexpr bool IsWhitespace(char c) { return c == ' ' || c == '\t'; }
+
+std::string TrimString(std::string value) {
+ size_t ltrim_chars = 0;
+ while (ltrim_chars < value.size() && IsWhitespace(value[ltrim_chars])) {
+ ++ltrim_chars;
+ }
+ value.erase(0, ltrim_chars);
+ size_t rtrim_chars = 0;
+ while (rtrim_chars < value.size() &&
+ IsWhitespace(value[value.size() - 1 - rtrim_chars])) {
+ ++rtrim_chars;
+ }
+ value.erase(value.size() - rtrim_chars, rtrim_chars);
+ return value;
+}
+
+bool AsciiEqualsCaseInsensitive(util::string_view left, util::string_view right) {
+ // TODO: ASCII validation
+ if (left.size() != right.size()) {
+ return false;
+ }
+ for (size_t i = 0; i < left.size(); ++i) {
+ if (std::tolower(static_cast<unsigned char>(left[i])) !=
+ std::tolower(static_cast<unsigned char>(right[i]))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+std::string AsciiToLower(util::string_view value) {
+ // TODO: ASCII validation
+ std::string result = std::string(value);
+ std::transform(result.begin(), result.end(), result.begin(),
+ [](unsigned char c) { return std::tolower(c); });
+ return result;
+}
+
+std::string AsciiToUpper(util::string_view value) {
+ // TODO: ASCII validation
+ std::string result = std::string(value);
+ std::transform(result.begin(), result.end(), result.begin(),
+ [](unsigned char c) { return std::toupper(c); });
+ return result;
+}
+
+util::optional<std::string> Replace(util::string_view s, util::string_view token,
+ util::string_view replacement) {
+ size_t token_start = s.find(token);
+ if (token_start == std::string::npos) {
+ return util::nullopt;
+ }
+ return s.substr(0, token_start).to_string() + replacement.to_string() +
+ s.substr(token_start + token.size()).to_string();
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/string.h b/src/arrow/cpp/src/arrow/util/string.h
new file mode 100644
index 000000000..68b8a54e3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/string.h
@@ -0,0 +1,79 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "arrow/util/optional.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Status;
+
+ARROW_EXPORT std::string HexEncode(const uint8_t* data, size_t length);
+
+ARROW_EXPORT std::string Escape(const char* data, size_t length);
+
+ARROW_EXPORT std::string HexEncode(const char* data, size_t length);
+
+ARROW_EXPORT std::string HexEncode(util::string_view str);
+
+ARROW_EXPORT std::string Escape(util::string_view str);
+
+ARROW_EXPORT Status ParseHexValue(const char* data, uint8_t* out);
+
+namespace internal {
+
+/// \brief Split a string with a delimiter
+ARROW_EXPORT
+std::vector<util::string_view> SplitString(util::string_view v, char delim);
+
+/// \brief Join strings with a delimiter
+ARROW_EXPORT
+std::string JoinStrings(const std::vector<util::string_view>& strings,
+ util::string_view delimiter);
+
+/// \brief Join strings with a delimiter
+ARROW_EXPORT
+std::string JoinStrings(const std::vector<std::string>& strings,
+ util::string_view delimiter);
+
+/// \brief Trim whitespace from left and right sides of string
+ARROW_EXPORT
+std::string TrimString(std::string value);
+
+ARROW_EXPORT
+bool AsciiEqualsCaseInsensitive(util::string_view left, util::string_view right);
+
+ARROW_EXPORT
+std::string AsciiToLower(util::string_view value);
+
+ARROW_EXPORT
+std::string AsciiToUpper(util::string_view value);
+
+/// \brief Search for the first instance of a token and replace it or return nullopt if
+/// the token is not found.
+ARROW_EXPORT
+util::optional<std::string> Replace(util::string_view s, util::string_view token,
+ util::string_view replacement);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/string_builder.cc b/src/arrow/cpp/src/arrow/util/string_builder.cc
new file mode 100644
index 000000000..625ae0075
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/string_builder.cc
@@ -0,0 +1,40 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/string_builder.h"
+
+#include <sstream>
+
+#include "arrow/util/make_unique.h"
+
+namespace arrow {
+
+using internal::make_unique;
+
+namespace util {
+namespace detail {
+
+StringStreamWrapper::StringStreamWrapper()
+ : sstream_(make_unique<std::ostringstream>()), ostream_(*sstream_) {}
+
+StringStreamWrapper::~StringStreamWrapper() {}
+
+std::string StringStreamWrapper::str() { return sstream_->str(); }
+
+} // namespace detail
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/string_builder.h b/src/arrow/cpp/src/arrow/util/string_builder.h
new file mode 100644
index 000000000..7c05ccd51
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/string_builder.h
@@ -0,0 +1,84 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License. template <typename T>
+
+#pragma once
+
+#include <memory>
+#include <ostream>
+#include <string>
+#include <utility>
+
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace util {
+
+namespace detail {
+
+class ARROW_EXPORT StringStreamWrapper {
+ public:
+ StringStreamWrapper();
+ ~StringStreamWrapper();
+
+ std::ostream& stream() { return ostream_; }
+ std::string str();
+
+ protected:
+ std::unique_ptr<std::ostringstream> sstream_;
+ std::ostream& ostream_;
+};
+
+} // namespace detail
+
+template <typename Head>
+void StringBuilderRecursive(std::ostream& stream, Head&& head) {
+ stream << head;
+}
+
+template <typename Head, typename... Tail>
+void StringBuilderRecursive(std::ostream& stream, Head&& head, Tail&&... tail) {
+ StringBuilderRecursive(stream, std::forward<Head>(head));
+ StringBuilderRecursive(stream, std::forward<Tail>(tail)...);
+}
+
+template <typename... Args>
+std::string StringBuilder(Args&&... args) {
+ detail::StringStreamWrapper ss;
+ StringBuilderRecursive(ss.stream(), std::forward<Args>(args)...);
+ return ss.str();
+}
+
+/// CRTP helper for declaring string representation. Defines operator<<
+template <typename T>
+class ToStringOstreamable {
+ public:
+ ~ToStringOstreamable() {
+ static_assert(
+ std::is_same<decltype(std::declval<const T>().ToString()), std::string>::value,
+ "ToStringOstreamable depends on the method T::ToString() const");
+ }
+
+ private:
+ const T& cast() const { return static_cast<const T&>(*this); }
+
+ friend inline std::ostream& operator<<(std::ostream& os, const ToStringOstreamable& t) {
+ return os << t.cast().ToString();
+ }
+};
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/string_test.cc b/src/arrow/cpp/src/arrow/util/string_test.cc
new file mode 100644
index 000000000..057d885fc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/string_test.cc
@@ -0,0 +1,144 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/string.h"
+
+namespace arrow {
+namespace internal {
+
+TEST(Trim, Basics) {
+ std::vector<std::pair<std::string, std::string>> test_cases = {
+ {"", ""}, {" ", ""}, {" ", ""}, {"\t ", ""},
+ {" \ta\t ", "a"}, {" \ta", "a"}, {"ab \t", "ab"}};
+ for (auto case_ : test_cases) {
+ EXPECT_EQ(case_.second, TrimString(case_.first));
+ }
+}
+
+TEST(AsciiEqualsCaseInsensitive, Basics) {
+ ASSERT_TRUE(AsciiEqualsCaseInsensitive("foo", "Foo"));
+ ASSERT_TRUE(AsciiEqualsCaseInsensitive("foo!", "FOO!"));
+ ASSERT_TRUE(AsciiEqualsCaseInsensitive("", ""));
+ ASSERT_TRUE(AsciiEqualsCaseInsensitive("fooo", "fooO"));
+
+ ASSERT_FALSE(AsciiEqualsCaseInsensitive("f", "G"));
+ ASSERT_FALSE(AsciiEqualsCaseInsensitive("foo!", "FOO"));
+}
+
+TEST(AsciiToLower, Basics) {
+ ASSERT_EQ("something", AsciiToLower("Something"));
+ ASSERT_EQ("something", AsciiToLower("SOMETHING"));
+ ASSERT_EQ("", AsciiToLower(""));
+}
+
+TEST(ParseHexValue, Valid) {
+ uint8_t output;
+
+ // evaluate valid letters
+ std::string input = "AB";
+ ASSERT_OK(ParseHexValue(input.c_str(), &output));
+ EXPECT_EQ(171, output);
+
+ // evaluate valid numbers
+ input = "12";
+ ASSERT_OK(ParseHexValue(input.c_str(), &output));
+ EXPECT_EQ(18, output);
+
+ // evaluate mixed hex numbers
+ input = "B1";
+ ASSERT_OK(ParseHexValue(input.c_str(), &output));
+ EXPECT_EQ(177, output);
+}
+
+TEST(ParseHexValue, Invalid) {
+ uint8_t output;
+
+ // evaluate invalid letters
+ std::string input = "XY";
+ ASSERT_RAISES(Invalid, ParseHexValue(input.c_str(), &output));
+
+ // evaluate invalid signs
+ input = "@?";
+ ASSERT_RAISES(Invalid, ParseHexValue(input.c_str(), &output));
+
+ // evaluate lower-case letters
+ input = "ab";
+ ASSERT_RAISES(Invalid, ParseHexValue(input.c_str(), &output));
+}
+
+TEST(Replace, Basics) {
+ auto s = Replace("dat_{i}.txt", "{i}", "23");
+ EXPECT_TRUE(s);
+ EXPECT_EQ(*s, "dat_23.txt");
+
+ // only replace the first occurrence of token
+ s = Replace("dat_{i}_{i}.txt", "{i}", "23");
+ EXPECT_TRUE(s);
+ EXPECT_EQ(*s, "dat_23_{i}.txt");
+
+ s = Replace("dat_.txt", "{nope}", "23");
+ EXPECT_FALSE(s);
+}
+
+TEST(SplitString, InnerDelimiter) {
+ std::string input = "a:b:c";
+ auto parts = SplitString(input, ':');
+ ASSERT_EQ(parts.size(), 3);
+ EXPECT_EQ(parts[0], "a");
+ EXPECT_EQ(parts[1], "b");
+ EXPECT_EQ(parts[2], "c");
+}
+
+TEST(SplitString, OuterRightDelimiter) {
+ std::string input = "a:b:c:";
+ auto parts = SplitString(input, ':');
+ ASSERT_EQ(parts.size(), 4);
+ EXPECT_EQ(parts[0], "a");
+ EXPECT_EQ(parts[1], "b");
+ EXPECT_EQ(parts[2], "c");
+ EXPECT_EQ(parts[3], "");
+}
+
+TEST(SplitString, OuterLeftAndOuterRightDelimiter) {
+ std::string input = ":a:b:c:";
+ auto parts = SplitString(input, ':');
+ ASSERT_EQ(parts.size(), 5);
+ EXPECT_EQ(parts[0], "");
+ EXPECT_EQ(parts[1], "a");
+ EXPECT_EQ(parts[2], "b");
+ EXPECT_EQ(parts[3], "c");
+ EXPECT_EQ(parts[4], "");
+}
+
+TEST(SplitString, OnlyDemiliter) {
+ std::string input = ":";
+ auto parts = SplitString(input, ':');
+ ASSERT_EQ(parts.size(), 2);
+ EXPECT_EQ(parts[0], "");
+ EXPECT_EQ(parts[1], "");
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/string_view.h b/src/arrow/cpp/src/arrow/util/string_view.h
new file mode 100644
index 000000000..4a51c2ebd
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/string_view.h
@@ -0,0 +1,38 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#define nssv_CONFIG_SELECT_STRING_VIEW nssv_STRING_VIEW_NONSTD
+
+#include <cstdint>
+#include <string>
+
+#include "arrow/vendored/string_view.hpp" // IWYU pragma: export
+
+namespace arrow {
+namespace util {
+
+using nonstd::string_view;
+
+template <class Char, class Traits = std::char_traits<Char>>
+using basic_string_view = nonstd::basic_string_view<Char, Traits>;
+
+using bytes_view = basic_string_view<uint8_t>;
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/task_group.cc b/src/arrow/cpp/src/arrow/util/task_group.cc
new file mode 100644
index 000000000..7e8ab64b7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/task_group.cc
@@ -0,0 +1,224 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/task_group.h"
+
+#include <atomic>
+#include <condition_variable>
+#include <cstdint>
+#include <mutex>
+#include <utility>
+
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+namespace internal {
+
+namespace {
+
+////////////////////////////////////////////////////////////////////////
+// Serial TaskGroup implementation
+
+class SerialTaskGroup : public TaskGroup {
+ public:
+ explicit SerialTaskGroup(StopToken stop_token) : stop_token_(std::move(stop_token)) {}
+
+ void AppendReal(FnOnce<Status()> task) override {
+ DCHECK(!finished_);
+ if (stop_token_.IsStopRequested()) {
+ status_ &= stop_token_.Poll();
+ return;
+ }
+ if (status_.ok()) {
+ status_ &= std::move(task)();
+ }
+ }
+
+ Status current_status() override { return status_; }
+
+ bool ok() const override { return status_.ok(); }
+
+ Status Finish() override {
+ if (!finished_) {
+ finished_ = true;
+ }
+ return status_;
+ }
+
+ Future<> FinishAsync() override { return Future<>::MakeFinished(Finish()); }
+
+ int parallelism() override { return 1; }
+
+ StopToken stop_token_;
+ Status status_;
+ bool finished_ = false;
+};
+
+////////////////////////////////////////////////////////////////////////
+// Threaded TaskGroup implementation
+
+class ThreadedTaskGroup : public TaskGroup {
+ public:
+ ThreadedTaskGroup(Executor* executor, StopToken stop_token)
+ : executor_(executor),
+ stop_token_(std::move(stop_token)),
+ nremaining_(0),
+ ok_(true) {}
+
+ ~ThreadedTaskGroup() override {
+ // Make sure all pending tasks are finished, so that dangling references
+ // to this don't persist.
+ ARROW_UNUSED(Finish());
+ }
+
+ void AppendReal(FnOnce<Status()> task) override {
+ DCHECK(!finished_);
+ if (stop_token_.IsStopRequested()) {
+ UpdateStatus(stop_token_.Poll());
+ return;
+ }
+
+ // The hot path is unlocked thanks to atomics
+ // Only if an error occurs is the lock taken
+ if (ok_.load(std::memory_order_acquire)) {
+ nremaining_.fetch_add(1, std::memory_order_acquire);
+
+ auto self = checked_pointer_cast<ThreadedTaskGroup>(shared_from_this());
+
+ struct Callable {
+ void operator()() {
+ if (self_->ok_.load(std::memory_order_acquire)) {
+ Status st;
+ if (stop_token_.IsStopRequested()) {
+ st = stop_token_.Poll();
+ } else {
+ // XXX what about exceptions?
+ st = std::move(task_)();
+ }
+ self_->UpdateStatus(std::move(st));
+ }
+ self_->OneTaskDone();
+ }
+
+ std::shared_ptr<ThreadedTaskGroup> self_;
+ FnOnce<Status()> task_;
+ StopToken stop_token_;
+ };
+
+ Status st =
+ executor_->Spawn(Callable{std::move(self), std::move(task), stop_token_});
+ UpdateStatus(std::move(st));
+ }
+ }
+
+ Status current_status() override {
+ std::lock_guard<std::mutex> lock(mutex_);
+ return status_;
+ }
+
+ bool ok() const override { return ok_.load(); }
+
+ Status Finish() override {
+ std::unique_lock<std::mutex> lock(mutex_);
+ if (!finished_) {
+ cv_.wait(lock, [&]() { return nremaining_.load() == 0; });
+ // Current tasks may start other tasks, so only set this when done
+ finished_ = true;
+ }
+ return status_;
+ }
+
+ Future<> FinishAsync() override {
+ std::lock_guard<std::mutex> lock(mutex_);
+ if (!completion_future_.has_value()) {
+ if (nremaining_.load() == 0) {
+ completion_future_ = Future<>::MakeFinished(status_);
+ } else {
+ completion_future_ = Future<>::Make();
+ }
+ }
+ return *completion_future_;
+ }
+
+ int parallelism() override { return executor_->GetCapacity(); }
+
+ protected:
+ void UpdateStatus(Status&& st) {
+ // Must be called unlocked, only locks on error
+ if (ARROW_PREDICT_FALSE(!st.ok())) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ ok_.store(false, std::memory_order_release);
+ status_ &= std::move(st);
+ }
+ }
+
+ void OneTaskDone() {
+ // Can be called unlocked thanks to atomics
+ auto nremaining = nremaining_.fetch_sub(1, std::memory_order_release) - 1;
+ DCHECK_GE(nremaining, 0);
+ if (nremaining == 0) {
+ // Take the lock so that ~ThreadedTaskGroup cannot destroy cv
+ // before cv.notify_one() has returned
+ std::unique_lock<std::mutex> lock(mutex_);
+ cv_.notify_one();
+ if (completion_future_.has_value()) {
+ // MarkFinished could be slow. We don't want to call it while we are holding
+ // the lock.
+ auto& future = *completion_future_;
+ const auto finished = completion_future_->is_finished();
+ const auto& status = status_;
+ // This will be redundant if the user calls Finish and not FinishAsync
+ if (!finished && !finished_) {
+ finished_ = true;
+ lock.unlock();
+ future.MarkFinished(status);
+ } else {
+ lock.unlock();
+ }
+ }
+ }
+ }
+
+ // These members are usable unlocked
+ Executor* executor_;
+ StopToken stop_token_;
+ std::atomic<int32_t> nremaining_;
+ std::atomic<bool> ok_;
+
+ // These members use locking
+ std::mutex mutex_;
+ std::condition_variable cv_;
+ Status status_;
+ bool finished_ = false;
+ util::optional<Future<>> completion_future_;
+};
+
+} // namespace
+
+std::shared_ptr<TaskGroup> TaskGroup::MakeSerial(StopToken stop_token) {
+ return std::shared_ptr<TaskGroup>(new SerialTaskGroup{stop_token});
+}
+
+std::shared_ptr<TaskGroup> TaskGroup::MakeThreaded(Executor* thread_pool,
+ StopToken stop_token) {
+ return std::shared_ptr<TaskGroup>(new ThreadedTaskGroup{thread_pool, stop_token});
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/task_group.h b/src/arrow/cpp/src/arrow/util/task_group.h
new file mode 100644
index 000000000..3bb72f0d9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/task_group.h
@@ -0,0 +1,106 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <utility>
+
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/cancel.h"
+#include "arrow/util/functional.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace internal {
+
+/// \brief A group of related tasks
+///
+/// A TaskGroup executes tasks with the signature `Status()`.
+/// Execution can be serial or parallel, depending on the TaskGroup
+/// implementation. When Finish() returns, it is guaranteed that all
+/// tasks have finished, or at least one has errored.
+///
+/// Once an error has occurred any tasks that are submitted to the task group
+/// will not run. The call to Append will simply return without scheduling the
+/// task.
+///
+/// If the task group is parallel it is possible that multiple tasks could be
+/// running at the same time and one of those tasks fails. This will put the
+/// task group in a failure state (so additional tasks cannot be run) however
+/// it will not interrupt running tasks. Finish will not complete
+/// until all running tasks have finished, even if one task fails.
+///
+/// Once a task group has finished new tasks may not be added to it. If you need to start
+/// a new batch of work then you should create a new task group.
+class ARROW_EXPORT TaskGroup : public std::enable_shared_from_this<TaskGroup> {
+ public:
+ /// Add a Status-returning function to execute. Execution order is
+ /// undefined. The function may be executed immediately or later.
+ template <typename Function>
+ void Append(Function&& func) {
+ return AppendReal(std::forward<Function>(func));
+ }
+
+ /// Wait for execution of all tasks (and subgroups) to be finished,
+ /// or for at least one task (or subgroup) to error out.
+ /// The returned Status propagates the error status of the first failing
+ /// task (or subgroup).
+ virtual Status Finish() = 0;
+
+ /// Returns a future that will complete the first time all tasks are finished.
+ /// This should be called only after all top level tasks
+ /// have been added to the task group.
+ ///
+ /// If you are using a TaskGroup asynchronously there are a few considerations to keep
+ /// in mind. The tasks should not block on I/O, etc (defeats the purpose of using
+ /// futures) and should not be doing any nested locking or you run the risk of the tasks
+ /// getting stuck in the thread pool waiting for tasks which cannot get scheduled.
+ ///
+ /// Primarily this call is intended to help migrate existing work written with TaskGroup
+ /// in mind to using futures without having to do a complete conversion on the first
+ /// pass.
+ virtual Future<> FinishAsync() = 0;
+
+ /// The current aggregate error Status. Non-blocking, useful for stopping early.
+ virtual Status current_status() = 0;
+
+ /// Whether some tasks have already failed. Non-blocking, useful for stopping early.
+ virtual bool ok() const = 0;
+
+ /// How many tasks can typically be executed in parallel.
+ /// This is only a hint, useful for testing or debugging.
+ virtual int parallelism() = 0;
+
+ static std::shared_ptr<TaskGroup> MakeSerial(StopToken = StopToken::Unstoppable());
+ static std::shared_ptr<TaskGroup> MakeThreaded(internal::Executor*,
+ StopToken = StopToken::Unstoppable());
+
+ virtual ~TaskGroup() = default;
+
+ protected:
+ TaskGroup() = default;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(TaskGroup);
+
+ virtual void AppendReal(FnOnce<Status()> task) = 0;
+};
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/task_group_test.cc b/src/arrow/cpp/src/arrow/util/task_group_test.cc
new file mode 100644
index 000000000..4913fb929
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/task_group_test.cc
@@ -0,0 +1,444 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <atomic>
+#include <chrono>
+#include <condition_variable>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <random>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/status.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/task_group.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+namespace internal {
+
+// Generate random sleep durations
+static std::vector<double> RandomSleepDurations(int nsleeps, double min_seconds,
+ double max_seconds) {
+ std::vector<double> sleeps;
+ std::default_random_engine engine;
+ std::uniform_real_distribution<> sleep_dist(min_seconds, max_seconds);
+ for (int i = 0; i < nsleeps; ++i) {
+ sleeps.push_back(sleep_dist(engine));
+ }
+ return sleeps;
+}
+
+// Check TaskGroup behaviour with a bunch of all-successful tasks
+void TestTaskGroupSuccess(std::shared_ptr<TaskGroup> task_group) {
+ const int NTASKS = 10;
+ auto sleeps = RandomSleepDurations(NTASKS, 1e-3, 4e-3);
+
+ // Add NTASKS sleeps
+ std::atomic<int> count(0);
+ for (int i = 0; i < NTASKS; ++i) {
+ task_group->Append([&, i]() {
+ SleepFor(sleeps[i]);
+ count += i;
+ return Status::OK();
+ });
+ }
+ ASSERT_TRUE(task_group->ok());
+
+ ASSERT_OK(task_group->Finish());
+ ASSERT_TRUE(task_group->ok());
+ ASSERT_EQ(count.load(), NTASKS * (NTASKS - 1) / 2);
+ // Finish() is idempotent
+ ASSERT_OK(task_group->Finish());
+}
+
+// Check TaskGroup behaviour with some successful and some failing tasks
+void TestTaskGroupErrors(std::shared_ptr<TaskGroup> task_group) {
+ const int NSUCCESSES = 2;
+ const int NERRORS = 20;
+
+ std::atomic<int> count(0);
+
+ auto task_group_was_ok = false;
+ task_group->Append([&]() -> Status {
+ for (int i = 0; i < NSUCCESSES; ++i) {
+ task_group->Append([&]() {
+ count++;
+ return Status::OK();
+ });
+ }
+ task_group_was_ok = task_group->ok();
+ for (int i = 0; i < NERRORS; ++i) {
+ task_group->Append([&]() {
+ SleepFor(1e-2);
+ count++;
+ return Status::Invalid("some message");
+ });
+ }
+
+ return Status::OK();
+ });
+
+ // Task error is propagated
+ ASSERT_RAISES(Invalid, task_group->Finish());
+ ASSERT_TRUE(task_group_was_ok);
+ ASSERT_FALSE(task_group->ok());
+ if (task_group->parallelism() == 1) {
+ // Serial: exactly two successes and an error
+ ASSERT_EQ(count.load(), 3);
+ } else {
+ // Parallel: at least two successes and an error
+ ASSERT_GE(count.load(), 3);
+ ASSERT_LE(count.load(), 2 * task_group->parallelism());
+ }
+ // Finish() is idempotent
+ ASSERT_RAISES(Invalid, task_group->Finish());
+}
+
+void TestTaskGroupCancel(std::shared_ptr<TaskGroup> task_group, StopSource* stop_source) {
+ const int NSUCCESSES = 2;
+ const int NCANCELS = 20;
+
+ std::atomic<int> count(0);
+
+ auto task_group_was_ok = false;
+ task_group->Append([&]() -> Status {
+ for (int i = 0; i < NSUCCESSES; ++i) {
+ task_group->Append([&]() {
+ count++;
+ return Status::OK();
+ });
+ }
+ task_group_was_ok = task_group->ok();
+ for (int i = 0; i < NCANCELS; ++i) {
+ task_group->Append([&]() {
+ SleepFor(1e-2);
+ stop_source->RequestStop();
+ count++;
+ return Status::OK();
+ });
+ }
+
+ return Status::OK();
+ });
+
+ // Cancellation is propagated
+ ASSERT_RAISES(Cancelled, task_group->Finish());
+ ASSERT_TRUE(task_group_was_ok);
+ ASSERT_FALSE(task_group->ok());
+ if (task_group->parallelism() == 1) {
+ // Serial: exactly three successes
+ ASSERT_EQ(count.load(), NSUCCESSES + 1);
+ } else {
+ // Parallel: at least three successes
+ ASSERT_GE(count.load(), NSUCCESSES + 1);
+ ASSERT_LE(count.load(), NSUCCESSES * task_group->parallelism());
+ }
+ // Finish() is idempotent
+ ASSERT_RAISES(Cancelled, task_group->Finish());
+}
+
+class CopyCountingTask {
+ public:
+ explicit CopyCountingTask(std::shared_ptr<uint8_t> target)
+ : counter(0), target(std::move(target)) {}
+
+ CopyCountingTask(const CopyCountingTask& other)
+ : counter(other.counter + 1), target(other.target) {}
+
+ CopyCountingTask& operator=(const CopyCountingTask& other) {
+ counter = other.counter + 1;
+ target = other.target;
+ return *this;
+ }
+
+ CopyCountingTask(CopyCountingTask&& other) = default;
+ CopyCountingTask& operator=(CopyCountingTask&& other) = default;
+
+ Status operator()() {
+ *target = counter;
+ return Status::OK();
+ }
+
+ private:
+ uint8_t counter;
+ std::shared_ptr<uint8_t> target;
+};
+
+// Check TaskGroup behaviour with tasks spawning other tasks
+void TestTasksSpawnTasks(std::shared_ptr<TaskGroup> task_group) {
+ const int N = 6;
+
+ std::atomic<int> count(0);
+ // Make a task that recursively spawns itself
+ std::function<std::function<Status()>(int)> make_task = [&](int i) {
+ return [&, i]() {
+ count++;
+ if (i > 0) {
+ // Exercise parallelism by spawning two tasks at once and then sleeping
+ task_group->Append(make_task(i - 1));
+ task_group->Append(make_task(i - 1));
+ SleepFor(1e-3);
+ }
+ return Status::OK();
+ };
+ };
+
+ task_group->Append(make_task(N));
+
+ ASSERT_OK(task_group->Finish());
+ ASSERT_TRUE(task_group->ok());
+ ASSERT_EQ(count.load(), (1 << (N + 1)) - 1);
+}
+
+// A task that keeps recursing until a barrier is set.
+// Using a lambda for this doesn't play well with Thread Sanitizer.
+struct BarrierTask {
+ std::atomic<bool>* barrier_;
+ std::weak_ptr<TaskGroup> weak_group_ptr_;
+ Status final_status_;
+
+ Status operator()() {
+ if (!barrier_->load()) {
+ SleepFor(1e-5);
+ // Note the TaskGroup should be kept alive by the fact this task
+ // is still running...
+ weak_group_ptr_.lock()->Append(*this);
+ }
+ return final_status_;
+ }
+};
+
+// Try to replicate subtle lifetime issues when destroying a TaskGroup
+// where all tasks may not have finished running.
+void StressTaskGroupLifetime(std::function<std::shared_ptr<TaskGroup>()> factory) {
+ const int NTASKS = 100;
+ auto task_group = factory();
+ auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);
+
+ std::atomic<bool> barrier(false);
+
+ BarrierTask task{&barrier, weak_group_ptr, Status::OK()};
+
+ for (int i = 0; i < NTASKS; ++i) {
+ task_group->Append(task);
+ }
+
+ // Lose strong reference
+ barrier.store(true);
+ task_group.reset();
+
+ // Wait for finish
+ while (!weak_group_ptr.expired()) {
+ SleepFor(1e-5);
+ }
+}
+
+// Same, but with also a failing task
+void StressFailingTaskGroupLifetime(std::function<std::shared_ptr<TaskGroup>()> factory) {
+ const int NTASKS = 100;
+ auto task_group = factory();
+ auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);
+
+ std::atomic<bool> barrier(false);
+
+ BarrierTask task{&barrier, weak_group_ptr, Status::OK()};
+ BarrierTask failing_task{&barrier, weak_group_ptr, Status::Invalid("XXX")};
+
+ for (int i = 0; i < NTASKS; ++i) {
+ task_group->Append(task);
+ }
+ task_group->Append(failing_task);
+
+ // Lose strong reference
+ barrier.store(true);
+ task_group.reset();
+
+ // Wait for finish
+ while (!weak_group_ptr.expired()) {
+ SleepFor(1e-5);
+ }
+}
+
+void TestNoCopyTask(std::shared_ptr<TaskGroup> task_group) {
+ auto counter = std::make_shared<uint8_t>(0);
+ CopyCountingTask task(counter);
+ task_group->Append(std::move(task));
+ ASSERT_OK(task_group->Finish());
+ ASSERT_EQ(0, *counter);
+}
+
+void TestFinishNotSticky(std::function<std::shared_ptr<TaskGroup>()> factory) {
+ // If a task is added that runs very quickly it might decrement the task counter back
+ // down to 0 and mark the completion future as complete before all tasks are added.
+ // The "finished future" of the task group could get stuck to complete.
+ //
+ // Instead the task group should not allow the finished future to be marked complete
+ // until after FinishAsync has been called.
+ const int NTASKS = 100;
+ for (int i = 0; i < NTASKS; ++i) {
+ auto task_group = factory();
+ // Add a task and let it complete
+ task_group->Append([] { return Status::OK(); });
+ // Wait a little bit, if the task group was going to lock the finish hopefully it
+ // would do so here while we wait
+ SleepFor(1e-2);
+
+ // Add a new task that will still be running
+ std::atomic<bool> ready(false);
+ std::mutex m;
+ std::condition_variable cv;
+ task_group->Append([&m, &cv, &ready] {
+ std::unique_lock<std::mutex> lk(m);
+ cv.wait(lk, [&ready] { return ready.load(); });
+ return Status::OK();
+ });
+
+ // Ensure task group not finished already
+ auto finished = task_group->FinishAsync();
+ ASSERT_FALSE(finished.is_finished());
+
+ std::unique_lock<std::mutex> lk(m);
+ ready = true;
+ lk.unlock();
+ cv.notify_one();
+
+ ASSERT_FINISHES_OK(finished);
+ }
+}
+
+void TestFinishNeverStarted(std::shared_ptr<TaskGroup> task_group) {
+ // If we call FinishAsync we are done adding tasks so if we never added any it should be
+ // completed
+ auto finished = task_group->FinishAsync();
+ ASSERT_TRUE(finished.Wait(1));
+}
+
+void TestFinishAlreadyCompleted(std::function<std::shared_ptr<TaskGroup>()> factory) {
+ // If we call FinishAsync we are done adding tasks so even if no tasks are running we
+ // should still be completed
+ const int NTASKS = 100;
+ for (int i = 0; i < NTASKS; ++i) {
+ auto task_group = factory();
+ // Add a task and let it complete
+ task_group->Append([] { return Status::OK(); });
+ // Wait a little bit, hopefully enough time for the task to finish on one of these
+ // iterations
+ SleepFor(1e-2);
+ auto finished = task_group->FinishAsync();
+ ASSERT_FINISHES_OK(finished);
+ }
+}
+
+TEST(SerialTaskGroup, Success) { TestTaskGroupSuccess(TaskGroup::MakeSerial()); }
+
+TEST(SerialTaskGroup, Errors) { TestTaskGroupErrors(TaskGroup::MakeSerial()); }
+
+TEST(SerialTaskGroup, Cancel) {
+ StopSource stop_source;
+ TestTaskGroupCancel(TaskGroup::MakeSerial(stop_source.token()), &stop_source);
+}
+
+TEST(SerialTaskGroup, TasksSpawnTasks) { TestTasksSpawnTasks(TaskGroup::MakeSerial()); }
+
+TEST(SerialTaskGroup, NoCopyTask) { TestNoCopyTask(TaskGroup::MakeSerial()); }
+
+TEST(SerialTaskGroup, FinishNeverStarted) {
+ TestFinishNeverStarted(TaskGroup::MakeSerial());
+}
+
+TEST(SerialTaskGroup, FinishAlreadyCompleted) {
+ TestFinishAlreadyCompleted([] { return TaskGroup::MakeSerial(); });
+}
+
+TEST(ThreadedTaskGroup, Success) {
+ auto task_group = TaskGroup::MakeThreaded(GetCpuThreadPool());
+ TestTaskGroupSuccess(task_group);
+}
+
+TEST(ThreadedTaskGroup, Errors) {
+ // Limit parallelism to ensure some tasks don't get started
+ // after the first failing ones
+ std::shared_ptr<ThreadPool> thread_pool;
+ ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(4));
+
+ TestTaskGroupErrors(TaskGroup::MakeThreaded(thread_pool.get()));
+}
+
+TEST(ThreadedTaskGroup, Cancel) {
+ std::shared_ptr<ThreadPool> thread_pool;
+ ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(4));
+
+ StopSource stop_source;
+ TestTaskGroupCancel(TaskGroup::MakeThreaded(thread_pool.get(), stop_source.token()),
+ &stop_source);
+}
+
+TEST(ThreadedTaskGroup, TasksSpawnTasks) {
+ auto task_group = TaskGroup::MakeThreaded(GetCpuThreadPool());
+ TestTasksSpawnTasks(task_group);
+}
+
+TEST(ThreadedTaskGroup, NoCopyTask) {
+ std::shared_ptr<ThreadPool> thread_pool;
+ ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(4));
+ TestNoCopyTask(TaskGroup::MakeThreaded(thread_pool.get()));
+}
+
+TEST(ThreadedTaskGroup, StressTaskGroupLifetime) {
+ std::shared_ptr<ThreadPool> thread_pool;
+ ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16));
+
+ StressTaskGroupLifetime([&] { return TaskGroup::MakeThreaded(thread_pool.get()); });
+}
+
+TEST(ThreadedTaskGroup, StressFailingTaskGroupLifetime) {
+ std::shared_ptr<ThreadPool> thread_pool;
+ ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16));
+
+ StressFailingTaskGroupLifetime(
+ [&] { return TaskGroup::MakeThreaded(thread_pool.get()); });
+}
+
+TEST(ThreadedTaskGroup, FinishNotSticky) {
+ std::shared_ptr<ThreadPool> thread_pool;
+ ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16));
+
+ TestFinishNotSticky([&] { return TaskGroup::MakeThreaded(thread_pool.get()); });
+}
+
+TEST(ThreadedTaskGroup, FinishNeverStarted) {
+ std::shared_ptr<ThreadPool> thread_pool;
+ ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(4));
+ TestFinishNeverStarted(TaskGroup::MakeThreaded(thread_pool.get()));
+}
+
+TEST(ThreadedTaskGroup, FinishAlreadyCompleted) {
+ std::shared_ptr<ThreadPool> thread_pool;
+ ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16));
+
+ TestFinishAlreadyCompleted([&] { return TaskGroup::MakeThreaded(thread_pool.get()); });
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/tdigest.cc b/src/arrow/cpp/src/arrow/util/tdigest.cc
new file mode 100644
index 000000000..ee84a5ef6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/tdigest.cc
@@ -0,0 +1,420 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/tdigest.h"
+
+#include <algorithm>
+#include <cmath>
+#include <iostream>
+#include <limits>
+#include <queue>
+#include <tuple>
+#include <vector>
+
+#include "arrow/status.h"
+#include "arrow/util/math_constants.h"
+
+namespace arrow {
+namespace internal {
+
+namespace {
+
+// a numerically stable lerp is unbelievably complex
+// but we are *approximating* the quantile, so let's keep it simple
+double Lerp(double a, double b, double t) { return a + t * (b - a); }
+
+// histogram bin
+struct Centroid {
+ double mean;
+ double weight; // # data points in this bin
+
+ // merge with another centroid
+ void Merge(const Centroid& centroid) {
+ weight += centroid.weight;
+ mean += (centroid.mean - mean) * centroid.weight / weight;
+ }
+};
+
+// scale function K0: linear function, as baseline
+struct ScalerK0 {
+ explicit ScalerK0(uint32_t delta) : delta_norm(delta / 2.0) {}
+
+ double K(double q) const { return delta_norm * q; }
+ double Q(double k) const { return k / delta_norm; }
+
+ const double delta_norm;
+};
+
+// scale function K1
+struct ScalerK1 {
+ explicit ScalerK1(uint32_t delta) : delta_norm(delta / (2.0 * M_PI)) {}
+
+ double K(double q) const { return delta_norm * std::asin(2 * q - 1); }
+ double Q(double k) const { return (std::sin(k / delta_norm) + 1) / 2; }
+
+ const double delta_norm;
+};
+
+// implements t-digest merging algorithm
+template <class T = ScalerK1>
+class TDigestMerger : private T {
+ public:
+ explicit TDigestMerger(uint32_t delta) : T(delta) { Reset(0, nullptr); }
+
+ void Reset(double total_weight, std::vector<Centroid>* tdigest) {
+ total_weight_ = total_weight;
+ tdigest_ = tdigest;
+ if (tdigest_) {
+ tdigest_->resize(0);
+ }
+ weight_so_far_ = 0;
+ weight_limit_ = -1; // trigger first centroid merge
+ }
+
+ // merge one centroid from a sorted centroid stream
+ void Add(const Centroid& centroid) {
+ auto& td = *tdigest_;
+ const double weight = weight_so_far_ + centroid.weight;
+ if (weight <= weight_limit_) {
+ td.back().Merge(centroid);
+ } else {
+ const double quantile = weight_so_far_ / total_weight_;
+ const double next_weight_limit = total_weight_ * this->Q(this->K(quantile) + 1);
+ // weight limit should be strictly increasing, until the last centroid
+ if (next_weight_limit <= weight_limit_) {
+ weight_limit_ = total_weight_;
+ } else {
+ weight_limit_ = next_weight_limit;
+ }
+ td.push_back(centroid); // should never exceed capacity and trigger reallocation
+ }
+ weight_so_far_ = weight;
+ }
+
+ // validate k-size of a tdigest
+ Status Validate(const std::vector<Centroid>& tdigest, double total_weight) const {
+ double q_prev = 0, k_prev = this->K(0);
+ for (size_t i = 0; i < tdigest.size(); ++i) {
+ const double q = q_prev + tdigest[i].weight / total_weight;
+ const double k = this->K(q);
+ if (tdigest[i].weight != 1 && (k - k_prev) > 1.001) {
+ return Status::Invalid("oversized centroid: ", k - k_prev);
+ }
+ k_prev = k;
+ q_prev = q;
+ }
+ return Status::OK();
+ }
+
+ private:
+ double total_weight_; // total weight of this tdigest
+ double weight_so_far_; // accumulated weight till current bin
+ double weight_limit_; // max accumulated weight to move to next bin
+ std::vector<Centroid>* tdigest_;
+};
+
+} // namespace
+
+class TDigest::TDigestImpl {
+ public:
+ explicit TDigestImpl(uint32_t delta)
+ : delta_(delta > 10 ? delta : 10), merger_(delta_) {
+ tdigests_[0].reserve(delta_);
+ tdigests_[1].reserve(delta_);
+ Reset();
+ }
+
+ void Reset() {
+ tdigests_[0].resize(0);
+ tdigests_[1].resize(0);
+ current_ = 0;
+ total_weight_ = 0;
+ min_ = std::numeric_limits<double>::max();
+ max_ = std::numeric_limits<double>::lowest();
+ merger_.Reset(0, nullptr);
+ }
+
+ Status Validate() const {
+ // check weight, centroid order
+ double total_weight = 0, prev_mean = std::numeric_limits<double>::lowest();
+ for (const auto& centroid : tdigests_[current_]) {
+ if (std::isnan(centroid.mean) || std::isnan(centroid.weight)) {
+ return Status::Invalid("NAN found in tdigest");
+ }
+ if (centroid.mean < prev_mean) {
+ return Status::Invalid("centroid mean decreases");
+ }
+ if (centroid.weight < 1) {
+ return Status::Invalid("invalid centroid weight");
+ }
+ prev_mean = centroid.mean;
+ total_weight += centroid.weight;
+ }
+ if (total_weight != total_weight_) {
+ return Status::Invalid("tdigest total weight mismatch");
+ }
+ // check if buffer expanded
+ if (tdigests_[0].capacity() > delta_ || tdigests_[1].capacity() > delta_) {
+ return Status::Invalid("oversized tdigest buffer");
+ }
+ // check k-size
+ return merger_.Validate(tdigests_[current_], total_weight_);
+ }
+
+ void Dump() const {
+ const auto& td = tdigests_[current_];
+ for (size_t i = 0; i < td.size(); ++i) {
+ std::cerr << i << ": mean = " << td[i].mean << ", weight = " << td[i].weight
+ << std::endl;
+ }
+ std::cerr << "min = " << min_ << ", max = " << max_ << std::endl;
+ }
+
+ // merge with other tdigests
+ void Merge(const std::vector<const TDigestImpl*>& tdigest_impls) {
+ // current and end iterator
+ using CentroidIter = std::vector<Centroid>::const_iterator;
+ using CentroidIterPair = std::pair<CentroidIter, CentroidIter>;
+ // use a min-heap to find next minimal centroid from all tdigests
+ auto centroid_gt = [](const CentroidIterPair& lhs, const CentroidIterPair& rhs) {
+ return lhs.first->mean > rhs.first->mean;
+ };
+ using CentroidQueue =
+ std::priority_queue<CentroidIterPair, std::vector<CentroidIterPair>,
+ decltype(centroid_gt)>;
+
+ // trivial dynamic memory allocated at runtime
+ std::vector<CentroidIterPair> queue_buffer;
+ queue_buffer.reserve(tdigest_impls.size() + 1);
+ CentroidQueue queue(std::move(centroid_gt), std::move(queue_buffer));
+
+ const auto& this_tdigest = tdigests_[current_];
+ if (this_tdigest.size() > 0) {
+ queue.emplace(this_tdigest.cbegin(), this_tdigest.cend());
+ }
+ for (const TDigestImpl* td : tdigest_impls) {
+ const auto& other_tdigest = td->tdigests_[td->current_];
+ if (other_tdigest.size() > 0) {
+ queue.emplace(other_tdigest.cbegin(), other_tdigest.cend());
+ total_weight_ += td->total_weight_;
+ min_ = std::min(min_, td->min_);
+ max_ = std::max(max_, td->max_);
+ }
+ }
+
+ merger_.Reset(total_weight_, &tdigests_[1 - current_]);
+ CentroidIter current_iter, end_iter;
+ // do k-way merge till one buffer left
+ while (queue.size() > 1) {
+ std::tie(current_iter, end_iter) = queue.top();
+ merger_.Add(*current_iter);
+ queue.pop();
+ if (++current_iter != end_iter) {
+ queue.emplace(current_iter, end_iter);
+ }
+ }
+ // merge last buffer
+ if (!queue.empty()) {
+ std::tie(current_iter, end_iter) = queue.top();
+ while (current_iter != end_iter) {
+ merger_.Add(*current_iter++);
+ }
+ }
+ merger_.Reset(0, nullptr);
+
+ current_ = 1 - current_;
+ }
+
+ // merge input data with current tdigest
+ void MergeInput(std::vector<double>& input) {
+ total_weight_ += input.size();
+
+ std::sort(input.begin(), input.end());
+ min_ = std::min(min_, input.front());
+ max_ = std::max(max_, input.back());
+
+ // pick next minimal centroid from input and tdigest, feed to merger
+ merger_.Reset(total_weight_, &tdigests_[1 - current_]);
+ const auto& td = tdigests_[current_];
+ uint32_t tdigest_index = 0, input_index = 0;
+ while (tdigest_index < td.size() && input_index < input.size()) {
+ if (td[tdigest_index].mean < input[input_index]) {
+ merger_.Add(td[tdigest_index++]);
+ } else {
+ merger_.Add(Centroid{input[input_index++], 1});
+ }
+ }
+ while (tdigest_index < td.size()) {
+ merger_.Add(td[tdigest_index++]);
+ }
+ while (input_index < input.size()) {
+ merger_.Add(Centroid{input[input_index++], 1});
+ }
+ merger_.Reset(0, nullptr);
+
+ input.resize(0);
+ current_ = 1 - current_;
+ }
+
+ double Quantile(double q) const {
+ const auto& td = tdigests_[current_];
+
+ if (q < 0 || q > 1 || td.size() == 0) {
+ return NAN;
+ }
+
+ const double index = q * total_weight_;
+ if (index <= 1) {
+ return min_;
+ } else if (index >= total_weight_ - 1) {
+ return max_;
+ }
+
+ // find centroid contains the index
+ uint32_t ci = 0;
+ double weight_sum = 0;
+ for (; ci < td.size(); ++ci) {
+ weight_sum += td[ci].weight;
+ if (index <= weight_sum) {
+ break;
+ }
+ }
+ DCHECK_LT(ci, td.size());
+
+ // deviation of index from the centroid center
+ double diff = index + td[ci].weight / 2 - weight_sum;
+
+ // index happen to be in a unit weight centroid
+ if (td[ci].weight == 1 && std::abs(diff) < 0.5) {
+ return td[ci].mean;
+ }
+
+ // find adjacent centroids for interpolation
+ uint32_t ci_left = ci, ci_right = ci;
+ if (diff > 0) {
+ if (ci_right == td.size() - 1) {
+ // index larger than center of last bin
+ DCHECK_EQ(weight_sum, total_weight_);
+ const Centroid* c = &td[ci_right];
+ DCHECK_GE(c->weight, 2);
+ return Lerp(c->mean, max_, diff / (c->weight / 2));
+ }
+ ++ci_right;
+ } else {
+ if (ci_left == 0) {
+ // index smaller than center of first bin
+ const Centroid* c = &td[0];
+ DCHECK_GE(c->weight, 2);
+ return Lerp(min_, c->mean, index / (c->weight / 2));
+ }
+ --ci_left;
+ diff += td[ci_left].weight / 2 + td[ci_right].weight / 2;
+ }
+
+ // interpolate from adjacent centroids
+ diff /= (td[ci_left].weight / 2 + td[ci_right].weight / 2);
+ return Lerp(td[ci_left].mean, td[ci_right].mean, diff);
+ }
+
+ double Mean() const {
+ double sum = 0;
+ for (const auto& centroid : tdigests_[current_]) {
+ sum += centroid.mean * centroid.weight;
+ }
+ return total_weight_ == 0 ? NAN : sum / total_weight_;
+ }
+
+ double total_weight() const { return total_weight_; }
+
+ private:
+ // must be delcared before merger_, see constructor initialization list
+ const uint32_t delta_;
+
+ TDigestMerger<> merger_;
+ double total_weight_;
+ double min_, max_;
+
+ // ping-pong buffer holds two tdigests, size = 2 * delta * sizeof(Centroid)
+ std::vector<Centroid> tdigests_[2];
+ // index of active tdigest buffer, 0 or 1
+ int current_;
+};
+
+TDigest::TDigest(uint32_t delta, uint32_t buffer_size) : impl_(new TDigestImpl(delta)) {
+ input_.reserve(buffer_size);
+ Reset();
+}
+
+TDigest::~TDigest() = default;
+TDigest::TDigest(TDigest&&) = default;
+TDigest& TDigest::operator=(TDigest&&) = default;
+
+void TDigest::Reset() {
+ input_.resize(0);
+ impl_->Reset();
+}
+
+Status TDigest::Validate() const {
+ MergeInput();
+ return impl_->Validate();
+}
+
+void TDigest::Dump() const {
+ MergeInput();
+ impl_->Dump();
+}
+
+void TDigest::Merge(const std::vector<TDigest>& others) {
+ MergeInput();
+
+ std::vector<const TDigestImpl*> other_impls;
+ other_impls.reserve(others.size());
+ for (auto& other : others) {
+ other.MergeInput();
+ other_impls.push_back(other.impl_.get());
+ }
+ impl_->Merge(other_impls);
+}
+
+void TDigest::Merge(const TDigest& other) {
+ MergeInput();
+ other.MergeInput();
+ impl_->Merge({other.impl_.get()});
+}
+
+double TDigest::Quantile(double q) const {
+ MergeInput();
+ return impl_->Quantile(q);
+}
+
+double TDigest::Mean() const {
+ MergeInput();
+ return impl_->Mean();
+}
+
+bool TDigest::is_empty() const {
+ return input_.size() == 0 && impl_->total_weight() == 0;
+}
+
+void TDigest::MergeInput() const {
+ if (input_.size() > 0) {
+ impl_->MergeInput(input_); // will mutate input_
+ }
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/tdigest.h b/src/arrow/cpp/src/arrow/util/tdigest.h
new file mode 100644
index 000000000..308df4688
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/tdigest.h
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// approximate quantiles from arbitrary length dataset with O(1) space
+// based on 'Computing Extremely Accurate Quantiles Using t-Digests' from Dunning & Ertl
+// - https://arxiv.org/abs/1902.04023
+// - https://github.com/tdunning/t-digest
+
+#pragma once
+
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Status;
+
+namespace internal {
+
+class ARROW_EXPORT TDigest {
+ public:
+ explicit TDigest(uint32_t delta = 100, uint32_t buffer_size = 500);
+ ~TDigest();
+ TDigest(TDigest&&);
+ TDigest& operator=(TDigest&&);
+
+ // reset and re-use this tdigest
+ void Reset();
+
+ // validate data integrity
+ Status Validate() const;
+
+ // dump internal data, only for debug
+ void Dump() const;
+
+ // buffer a single data point, consume internal buffer if full
+ // this function is intensively called and performance critical
+ // call it only if you are sure no NAN exists in input data
+ void Add(double value) {
+ DCHECK(!std::isnan(value)) << "cannot add NAN";
+ if (ARROW_PREDICT_FALSE(input_.size() == input_.capacity())) {
+ MergeInput();
+ }
+ input_.push_back(value);
+ }
+
+ // skip NAN on adding
+ template <typename T>
+ typename std::enable_if<std::is_floating_point<T>::value>::type NanAdd(T value) {
+ if (!std::isnan(value)) Add(value);
+ }
+
+ template <typename T>
+ typename std::enable_if<std::is_integral<T>::value>::type NanAdd(T value) {
+ Add(static_cast<double>(value));
+ }
+
+ // merge with other t-digests, called infrequently
+ void Merge(const std::vector<TDigest>& others);
+ void Merge(const TDigest& other);
+
+ // calculate quantile
+ double Quantile(double q) const;
+
+ double Min() const { return Quantile(0); }
+ double Max() const { return Quantile(1); }
+ double Mean() const;
+
+ // check if this tdigest contains no valid data points
+ bool is_empty() const;
+
+ private:
+ // merge input data with current tdigest
+ void MergeInput() const;
+
+ // input buffer, size = buffer_size * sizeof(double)
+ mutable std::vector<double> input_;
+
+ // hide other members with pimpl
+ class TDigestImpl;
+ std::unique_ptr<TDigestImpl> impl_;
+};
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/tdigest_benchmark.cc b/src/arrow/cpp/src/arrow/util/tdigest_benchmark.cc
new file mode 100644
index 000000000..0b9545090
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/tdigest_benchmark.cc
@@ -0,0 +1,48 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/tdigest.h"
+
+namespace arrow {
+namespace util {
+
+static constexpr uint32_t kDelta = 100;
+static constexpr uint32_t kBufferSize = 500;
+
+static void BenchmarkTDigest(benchmark::State& state) {
+ const size_t items = state.range(0);
+ std::vector<double> values;
+ random_real(items, 0x11223344, -12345678.0, 12345678.0, &values);
+
+ for (auto _ : state) {
+ arrow::internal::TDigest td(kDelta, kBufferSize);
+ for (double value : values) {
+ td.Add(value);
+ }
+ benchmark::DoNotOptimize(td.Quantile(0));
+ }
+ state.SetItemsProcessed(state.iterations() * items);
+}
+
+BENCHMARK(BenchmarkTDigest)->Arg(1 << 12)->Arg(1 << 16)->Arg(1 << 20);
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/tdigest_test.cc b/src/arrow/cpp/src/arrow/util/tdigest_test.cc
new file mode 100644
index 000000000..532046b20
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/tdigest_test.cc
@@ -0,0 +1,290 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// XXX: There's no rigid error bound available. The accuracy is to some degree
+// *random*, which depends on input data and quantiles to be calculated. I also
+// find small gaps among linux/windows/macos.
+// In below tests, most quantiles are within 1% deviation from exact values,
+// while the worst test case is about 10% drift.
+// To make test result stable, I relaxed error bound to be *good enough*.
+// #define _TDIGEST_STRICT_TEST // enable more strict tests
+
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/tdigest.h"
+
+namespace arrow {
+namespace internal {
+
+TEST(TDigestTest, SingleValue) {
+ const double value = 0.12345678;
+
+ TDigest td;
+ td.Add(value);
+ ASSERT_OK(td.Validate());
+ // all quantiles equal to same single vaue
+ for (double q = 0; q <= 1; q += 0.1) {
+ EXPECT_EQ(td.Quantile(q), value);
+ }
+}
+
+TEST(TDigestTest, FewValues) {
+ // exact quantile at 0.1 interval, test sorted and unsorted input
+ std::vector<std::vector<double>> values_vector = {
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
+ {4, 1, 9, 0, 3, 2, 5, 6, 8, 7, 10},
+ };
+
+ for (const auto& values : values_vector) {
+ TDigest td;
+ for (double v : values) {
+ td.Add(v);
+ }
+ ASSERT_OK(td.Validate());
+
+ double q = 0;
+ for (size_t i = 0; i < values.size(); ++i) {
+ double expected = static_cast<double>(i);
+ EXPECT_EQ(td.Quantile(q), expected);
+ q += 0.1;
+ }
+ }
+}
+
+// Calculate exact quantile as truth
+std::vector<double> ExactQuantile(std::vector<double> values,
+ const std::vector<double>& quantiles) {
+ std::sort(values.begin(), values.end());
+
+ std::vector<double> output;
+ for (double q : quantiles) {
+ const double index = (values.size() - 1) * q;
+ const int64_t lower_index = static_cast<int64_t>(index);
+ const double fraction = index - lower_index;
+ if (fraction == 0) {
+ output.push_back(values[lower_index]);
+ } else {
+ const double lerp =
+ fraction * values[lower_index + 1] + (1 - fraction) * values[lower_index];
+ output.push_back(lerp);
+ }
+ }
+ return output;
+}
+
+void TestRandom(size_t size) {
+ const std::vector<double> fixed_quantiles = {0, 0.01, 0.1, 0.2, 0.5, 0.8, 0.9, 0.99, 1};
+
+ // append random quantiles to test
+ std::vector<double> quantiles;
+ random_real(50, 0x11223344, 0.0, 1.0, &quantiles);
+ quantiles.insert(quantiles.end(), fixed_quantiles.cbegin(), fixed_quantiles.cend());
+
+ // generate random test values
+ const double min = 1e3, max = 1e10;
+ std::vector<double> values;
+ random_real(size, 0x11223344, min, max, &values);
+
+ TDigest td(200);
+ for (double value : values) {
+ td.Add(value);
+ }
+ ASSERT_OK(td.Validate());
+
+ const std::vector<double> expected = ExactQuantile(values, quantiles);
+ std::vector<double> approximated;
+ for (auto q : quantiles) {
+ approximated.push_back(td.Quantile(q));
+ }
+
+ // r-square of expected and approximated quantiles should be greater than 0.999
+ const double expected_mean =
+ std::accumulate(expected.begin(), expected.end(), 0.0) / expected.size();
+ double rss = 0, tss = 0;
+ for (size_t i = 0; i < quantiles.size(); ++i) {
+ rss += (expected[i] - approximated[i]) * (expected[i] - approximated[i]);
+ tss += (expected[i] - expected_mean) * (expected[i] - expected_mean);
+ }
+ const double r2 = 1 - rss / tss;
+ EXPECT_GT(r2, 0.999);
+
+ // make sure no quantile drifts too much from the truth
+#ifdef _TDIGEST_STRICT_TEST
+ const double error_ratio = 0.02;
+#else
+ const double error_ratio = 0.05;
+#endif
+ for (size_t i = 0; i < quantiles.size(); ++i) {
+ const double tolerance = std::fabs(expected[i]) * error_ratio;
+ EXPECT_NEAR(approximated[i], expected[i], tolerance) << quantiles[i];
+ }
+}
+
+TEST(TDigestTest, RandomValues) { TestRandom(100000); }
+
+// too heavy to run in ci
+TEST(TDigestTest, DISABLED_HugeVolume) { TestRandom(1U << 30); }
+
+void TestMerge(const std::vector<std::vector<double>>& values_vector, uint32_t delta,
+ double error_ratio) {
+ const std::vector<double> quantiles = {0, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5,
+ 0.6, 0.7, 0.8, 0.9, 0.99, 1};
+
+ std::vector<TDigest> tds;
+ for (const auto& values : values_vector) {
+ TDigest td(delta);
+ for (double value : values) {
+ td.Add(value);
+ }
+ ASSERT_OK(td.Validate());
+ tds.push_back(std::move(td));
+ }
+
+ std::vector<double> values_combined;
+ for (const auto& values : values_vector) {
+ values_combined.insert(values_combined.end(), values.begin(), values.end());
+ }
+ const std::vector<double> expected = ExactQuantile(values_combined, quantiles);
+
+ // merge into an empty tdigest
+ {
+ TDigest td(delta);
+ td.Merge(tds);
+ ASSERT_OK(td.Validate());
+ for (size_t i = 0; i < quantiles.size(); ++i) {
+ const double tolerance = std::max(std::fabs(expected[i]) * error_ratio, 0.1);
+ EXPECT_NEAR(td.Quantile(quantiles[i]), expected[i], tolerance) << quantiles[i];
+ }
+ }
+
+ // merge into a non empty tdigest
+ {
+ TDigest td = std::move(tds[0]);
+ tds.erase(tds.begin(), tds.begin() + 1);
+ td.Merge(tds);
+ ASSERT_OK(td.Validate());
+ for (size_t i = 0; i < quantiles.size(); ++i) {
+ const double tolerance = std::max(std::fabs(expected[i]) * error_ratio, 0.1);
+ EXPECT_NEAR(td.Quantile(quantiles[i]), expected[i], tolerance) << quantiles[i];
+ }
+ }
+}
+
+// merge tdigests with same distribution
+TEST(TDigestTest, MergeUniform) {
+ const std::vector<size_t> sizes = {20000, 3000, 1500, 18000, 9999, 6666};
+ std::vector<std::vector<double>> values_vector;
+ for (auto size : sizes) {
+ std::vector<double> values;
+ random_real(size, 0x11223344, -123456789.0, 987654321.0, &values);
+ values_vector.push_back(std::move(values));
+ }
+
+#ifdef _TDIGEST_STRICT_TEST
+ TestMerge(values_vector, /*delta=*/100, /*error_ratio=*/0.01);
+#else
+ TestMerge(values_vector, /*delta=*/200, /*error_ratio=*/0.05);
+#endif
+}
+
+// merge tdigests with different distributions
+TEST(TDigestTest, MergeNonUniform) {
+ struct {
+ size_t size;
+ double min;
+ double max;
+ } configs[] = {
+ {2000, 1e8, 1e9}, {0, 0, 0}, {3000, -1, 1}, {500, -1e6, -1e5}, {800, 100, 100},
+ };
+ std::vector<std::vector<double>> values_vector;
+ for (const auto& cfg : configs) {
+ std::vector<double> values;
+ random_real(cfg.size, 0x11223344, cfg.min, cfg.max, &values);
+ values_vector.push_back(std::move(values));
+ }
+
+#ifdef _TDIGEST_STRICT_TEST
+ TestMerge(values_vector, /*delta=*/200, /*error_ratio=*/0.01);
+#else
+ TestMerge(values_vector, /*delta=*/200, /*error_ratio=*/0.05);
+#endif
+}
+
+TEST(TDigestTest, Misc) {
+ const size_t size = 100000;
+ const double min = -1000, max = 1000;
+ const std::vector<double> quantiles = {0, 0.01, 0.1, 0.4, 0.7, 0.9, 0.99, 1};
+
+ std::vector<double> values;
+ random_real(size, 0x11223344, min, max, &values);
+ const std::vector<double> expected = ExactQuantile(values, quantiles);
+
+ // test small delta and buffer
+ {
+#ifdef _TDIGEST_STRICT_TEST
+ const double error_ratio = 0.06; // low accuracy for small delta
+#else
+ const double error_ratio = 0.15;
+#endif
+
+ TDigest td(10, 50);
+ for (double value : values) {
+ td.Add(value);
+ }
+ ASSERT_OK(td.Validate());
+
+ for (size_t i = 0; i < quantiles.size(); ++i) {
+ const double tolerance = std::max(std::fabs(expected[i]) * error_ratio, 0.1);
+ EXPECT_NEAR(td.Quantile(quantiles[i]), expected[i], tolerance) << quantiles[i];
+ }
+ }
+
+ // test many duplicated values
+ {
+#ifdef _TDIGEST_STRICT_TEST
+ const double error_ratio = 0.02;
+#else
+ const double error_ratio = 0.05;
+#endif
+
+ auto values_integer = values;
+ for (double& value : values_integer) {
+ value = std::ceil(value);
+ }
+
+ TDigest td(100);
+ for (double value : values_integer) {
+ td.Add(value);
+ }
+ ASSERT_OK(td.Validate());
+
+ for (size_t i = 0; i < quantiles.size(); ++i) {
+ const double tolerance = std::max(std::fabs(expected[i]) * error_ratio, 0.1);
+ EXPECT_NEAR(td.Quantile(quantiles[i]), expected[i], tolerance) << quantiles[i];
+ }
+ }
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/test_common.cc b/src/arrow/cpp/src/arrow/util/test_common.cc
new file mode 100644
index 000000000..ac187ba0c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/test_common.cc
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/test_common.h"
+
+namespace arrow {
+
+TestInt::TestInt() : value(-999) {}
+TestInt::TestInt(int i) : value(i) {} // NOLINT runtime/explicit
+bool TestInt::operator==(const TestInt& other) const { return value == other.value; }
+
+std::ostream& operator<<(std::ostream& os, const TestInt& v) {
+ os << "{" << v.value << "}";
+ return os;
+}
+
+TestStr::TestStr() : value("") {}
+TestStr::TestStr(const std::string& s) : value(s) {} // NOLINT runtime/explicit
+TestStr::TestStr(const char* s) : value(s) {} // NOLINT runtime/explicit
+TestStr::TestStr(const TestInt& test_int) {
+ if (IsIterationEnd(test_int)) {
+ value = "";
+ } else {
+ value = std::to_string(test_int.value);
+ }
+}
+
+bool TestStr::operator==(const TestStr& other) const { return value == other.value; }
+
+std::ostream& operator<<(std::ostream& os, const TestStr& v) {
+ os << "{\"" << v.value << "\"}";
+ return os;
+}
+
+std::vector<TestInt> RangeVector(unsigned int max, unsigned int step) {
+ auto count = max / step;
+ std::vector<TestInt> range(count);
+ for (unsigned int i = 0; i < count; i++) {
+ range[i] = i * step;
+ }
+ return range;
+}
+
+Transformer<TestInt, TestStr> MakeFilter(std::function<bool(TestInt&)> filter) {
+ return [filter](TestInt next) -> Result<TransformFlow<TestStr>> {
+ if (filter(next)) {
+ return TransformYield(TestStr(next));
+ } else {
+ return TransformSkip();
+ }
+ };
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/test_common.h b/src/arrow/cpp/src/arrow/util/test_common.h
new file mode 100644
index 000000000..511daed1e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/test_common.h
@@ -0,0 +1,90 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <iosfwd>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/iterator.h"
+
+namespace arrow {
+
+struct TestInt {
+ TestInt();
+ TestInt(int i); // NOLINT runtime/explicit
+ int value;
+
+ bool operator==(const TestInt& other) const;
+
+ friend std::ostream& operator<<(std::ostream& os, const TestInt& v);
+};
+
+template <>
+struct IterationTraits<TestInt> {
+ static TestInt End() { return TestInt(); }
+ static bool IsEnd(const TestInt& val) { return val == IterationTraits<TestInt>::End(); }
+};
+
+struct TestStr {
+ TestStr();
+ TestStr(const std::string& s); // NOLINT runtime/explicit
+ TestStr(const char* s); // NOLINT runtime/explicit
+ explicit TestStr(const TestInt& test_int);
+ std::string value;
+
+ bool operator==(const TestStr& other) const;
+
+ friend std::ostream& operator<<(std::ostream& os, const TestStr& v);
+};
+
+template <>
+struct IterationTraits<TestStr> {
+ static TestStr End() { return TestStr(); }
+ static bool IsEnd(const TestStr& val) { return val == IterationTraits<TestStr>::End(); }
+};
+
+std::vector<TestInt> RangeVector(unsigned int max, unsigned int step = 1);
+
+template <typename T>
+inline Iterator<T> VectorIt(std::vector<T> v) {
+ return MakeVectorIterator<T>(std::move(v));
+}
+
+template <typename T>
+inline Iterator<T> PossiblySlowVectorIt(std::vector<T> v, bool slow = false) {
+ auto iterator = MakeVectorIterator<T>(std::move(v));
+ if (slow) {
+ return MakeTransformedIterator<T, T>(std::move(iterator),
+ [](T item) -> Result<TransformFlow<T>> {
+ SleepABit();
+ return TransformYield(item);
+ });
+ } else {
+ return iterator;
+ }
+}
+
+template <typename T>
+inline void AssertIteratorExhausted(Iterator<T>& it) {
+ ASSERT_OK_AND_ASSIGN(T next, it.Next());
+ ASSERT_TRUE(IsIterationEnd(next));
+}
+
+Transformer<TestInt, TestStr> MakeFilter(std::function<bool(TestInt&)> filter);
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/thread_pool.cc b/src/arrow/cpp/src/arrow/util/thread_pool.cc
new file mode 100644
index 000000000..37132fe1a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/thread_pool.cc
@@ -0,0 +1,450 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/thread_pool.h"
+
+#include <algorithm>
+#include <condition_variable>
+#include <deque>
+#include <list>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace internal {
+
+Executor::~Executor() = default;
+
+namespace {
+
+struct Task {
+ FnOnce<void()> callable;
+ StopToken stop_token;
+ Executor::StopCallback stop_callback;
+};
+
+} // namespace
+
+struct SerialExecutor::State {
+ std::deque<Task> task_queue;
+ std::mutex mutex;
+ std::condition_variable wait_for_tasks;
+ bool finished{false};
+};
+
+SerialExecutor::SerialExecutor() : state_(std::make_shared<State>()) {}
+
+SerialExecutor::~SerialExecutor() = default;
+
+Status SerialExecutor::SpawnReal(TaskHints hints, FnOnce<void()> task,
+ StopToken stop_token, StopCallback&& stop_callback) {
+ // While the SerialExecutor runs tasks synchronously on its main thread,
+ // SpawnReal may be called from external threads (e.g. when transferring back
+ // from blocking I/O threads), so we need to keep the state alive *and* to
+ // lock its contents.
+ //
+ // Note that holding the lock while notifying the condition variable may
+ // not be sufficient, as some exit paths in the main thread are unlocked.
+ auto state = state_;
+ {
+ std::lock_guard<std::mutex> lk(state->mutex);
+ state->task_queue.push_back(
+ Task{std::move(task), std::move(stop_token), std::move(stop_callback)});
+ }
+ state->wait_for_tasks.notify_one();
+ return Status::OK();
+}
+
+void SerialExecutor::MarkFinished() {
+ // Same comment as SpawnReal above
+ auto state = state_;
+ {
+ std::lock_guard<std::mutex> lk(state->mutex);
+ state->finished = true;
+ }
+ state->wait_for_tasks.notify_one();
+}
+
+void SerialExecutor::RunLoop() {
+ // This is called from the SerialExecutor's main thread, so the
+ // state is guaranteed to be kept alive.
+ std::unique_lock<std::mutex> lk(state_->mutex);
+
+ while (!state_->finished) {
+ while (!state_->task_queue.empty()) {
+ Task task = std::move(state_->task_queue.front());
+ state_->task_queue.pop_front();
+ lk.unlock();
+ if (!task.stop_token.IsStopRequested()) {
+ std::move(task.callable)();
+ } else {
+ if (task.stop_callback) {
+ std::move(task.stop_callback)(task.stop_token.Poll());
+ }
+ // Can't break here because there may be cleanup tasks down the chain we still
+ // need to run.
+ }
+ lk.lock();
+ }
+ // In this case we must be waiting on work from external (e.g. I/O) executors. Wait
+ // for tasks to arrive (typically via transferred futures).
+ state_->wait_for_tasks.wait(
+ lk, [&] { return state_->finished || !state_->task_queue.empty(); });
+ }
+}
+
+struct ThreadPool::State {
+ State() = default;
+
+ // NOTE: in case locking becomes too expensive, we can investigate lock-free FIFOs
+ // such as https://github.com/cameron314/concurrentqueue
+
+ std::mutex mutex_;
+ std::condition_variable cv_;
+ std::condition_variable cv_shutdown_;
+ std::condition_variable cv_idle_;
+
+ std::list<std::thread> workers_;
+ // Trashcan for finished threads
+ std::vector<std::thread> finished_workers_;
+ std::deque<Task> pending_tasks_;
+
+ // Desired number of threads
+ int desired_capacity_ = 0;
+
+ // Total number of tasks that are either queued or running
+ int tasks_queued_or_running_ = 0;
+
+ // Are we shutting down?
+ bool please_shutdown_ = false;
+ bool quick_shutdown_ = false;
+};
+
+// The worker loop is an independent function so that it can keep running
+// after the ThreadPool is destroyed.
+static void WorkerLoop(std::shared_ptr<ThreadPool::State> state,
+ std::list<std::thread>::iterator it) {
+ std::unique_lock<std::mutex> lock(state->mutex_);
+
+ // Since we hold the lock, `it` now points to the correct thread object
+ // (LaunchWorkersUnlocked has exited)
+ DCHECK_EQ(std::this_thread::get_id(), it->get_id());
+
+ // If too many threads, we should secede from the pool
+ const auto should_secede = [&]() -> bool {
+ return state->workers_.size() > static_cast<size_t>(state->desired_capacity_);
+ };
+
+ while (true) {
+ // By the time this thread is started, some tasks may have been pushed
+ // or shutdown could even have been requested. So we only wait on the
+ // condition variable at the end of the loop.
+
+ // Execute pending tasks if any
+ while (!state->pending_tasks_.empty() && !state->quick_shutdown_) {
+ // We check this opportunistically at each loop iteration since
+ // it releases the lock below.
+ if (should_secede()) {
+ break;
+ }
+
+ DCHECK_GE(state->tasks_queued_or_running_, 0);
+ {
+ Task task = std::move(state->pending_tasks_.front());
+ state->pending_tasks_.pop_front();
+ StopToken* stop_token = &task.stop_token;
+ lock.unlock();
+ if (!stop_token->IsStopRequested()) {
+ std::move(task.callable)();
+ } else {
+ if (task.stop_callback) {
+ std::move(task.stop_callback)(stop_token->Poll());
+ }
+ }
+ ARROW_UNUSED(std::move(task)); // release resources before waiting for lock
+ lock.lock();
+ }
+ if (ARROW_PREDICT_FALSE(--state->tasks_queued_or_running_ == 0)) {
+ state->cv_idle_.notify_all();
+ }
+ }
+ // Now either the queue is empty *or* a quick shutdown was requested
+ if (state->please_shutdown_ || should_secede()) {
+ break;
+ }
+ // Wait for next wakeup
+ state->cv_.wait(lock);
+ }
+ DCHECK_GE(state->tasks_queued_or_running_, 0);
+
+ // We're done. Move our thread object to the trashcan of finished
+ // workers. This has two motivations:
+ // 1) the thread object doesn't get destroyed before this function finishes
+ // (but we could call thread::detach() instead)
+ // 2) we can explicitly join() the trashcan threads to make sure all OS threads
+ // are exited before the ThreadPool is destroyed. Otherwise subtle
+ // timing conditions can lead to false positives with Valgrind.
+ DCHECK_EQ(std::this_thread::get_id(), it->get_id());
+ state->finished_workers_.push_back(std::move(*it));
+ state->workers_.erase(it);
+ if (state->please_shutdown_) {
+ // Notify the function waiting in Shutdown().
+ state->cv_shutdown_.notify_one();
+ }
+}
+
+void ThreadPool::WaitForIdle() {
+ std::unique_lock<std::mutex> lk(state_->mutex_);
+ state_->cv_idle_.wait(lk, [this] { return state_->tasks_queued_or_running_ == 0; });
+}
+
+ThreadPool::ThreadPool()
+ : sp_state_(std::make_shared<ThreadPool::State>()),
+ state_(sp_state_.get()),
+ shutdown_on_destroy_(true) {
+#ifndef _WIN32
+ pid_ = getpid();
+#endif
+}
+
+ThreadPool::~ThreadPool() {
+ if (shutdown_on_destroy_) {
+ ARROW_UNUSED(Shutdown(false /* wait */));
+ }
+}
+
+void ThreadPool::ProtectAgainstFork() {
+#ifndef _WIN32
+ pid_t current_pid = getpid();
+ if (pid_ != current_pid) {
+ // Reinitialize internal state in child process after fork()
+ // Ideally we would use pthread_at_fork(), but that doesn't allow
+ // storing an argument, hence we'd need to maintain a list of all
+ // existing ThreadPools.
+ int capacity = state_->desired_capacity_;
+
+ auto new_state = std::make_shared<ThreadPool::State>();
+ new_state->please_shutdown_ = state_->please_shutdown_;
+ new_state->quick_shutdown_ = state_->quick_shutdown_;
+
+ pid_ = current_pid;
+ sp_state_ = new_state;
+ state_ = sp_state_.get();
+
+ // Launch worker threads anew
+ if (!state_->please_shutdown_) {
+ ARROW_UNUSED(SetCapacity(capacity));
+ }
+ }
+#endif
+}
+
+Status ThreadPool::SetCapacity(int threads) {
+ ProtectAgainstFork();
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ if (state_->please_shutdown_) {
+ return Status::Invalid("operation forbidden during or after shutdown");
+ }
+ if (threads <= 0) {
+ return Status::Invalid("ThreadPool capacity must be > 0");
+ }
+ CollectFinishedWorkersUnlocked();
+
+ state_->desired_capacity_ = threads;
+ // See if we need to increase or decrease the number of running threads
+ const int required = std::min(static_cast<int>(state_->pending_tasks_.size()),
+ threads - static_cast<int>(state_->workers_.size()));
+ if (required > 0) {
+ // Some tasks are pending, spawn the number of needed threads immediately
+ LaunchWorkersUnlocked(required);
+ } else if (required < 0) {
+ // Excess threads are running, wake them so that they stop
+ state_->cv_.notify_all();
+ }
+ return Status::OK();
+}
+
+int ThreadPool::GetCapacity() {
+ ProtectAgainstFork();
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ return state_->desired_capacity_;
+}
+
+int ThreadPool::GetNumTasks() {
+ ProtectAgainstFork();
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ return state_->tasks_queued_or_running_;
+}
+
+int ThreadPool::GetActualCapacity() {
+ ProtectAgainstFork();
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ return static_cast<int>(state_->workers_.size());
+}
+
+Status ThreadPool::Shutdown(bool wait) {
+ ProtectAgainstFork();
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+
+ if (state_->please_shutdown_) {
+ return Status::Invalid("Shutdown() already called");
+ }
+ state_->please_shutdown_ = true;
+ state_->quick_shutdown_ = !wait;
+ state_->cv_.notify_all();
+ state_->cv_shutdown_.wait(lock, [this] { return state_->workers_.empty(); });
+ if (!state_->quick_shutdown_) {
+ DCHECK_EQ(state_->pending_tasks_.size(), 0);
+ } else {
+ state_->pending_tasks_.clear();
+ }
+ CollectFinishedWorkersUnlocked();
+ return Status::OK();
+}
+
+void ThreadPool::CollectFinishedWorkersUnlocked() {
+ for (auto& thread : state_->finished_workers_) {
+ // Make sure OS thread has exited
+ thread.join();
+ }
+ state_->finished_workers_.clear();
+}
+
+thread_local ThreadPool* current_thread_pool_ = nullptr;
+
+bool ThreadPool::OwnsThisThread() { return current_thread_pool_ == this; }
+
+void ThreadPool::LaunchWorkersUnlocked(int threads) {
+ std::shared_ptr<State> state = sp_state_;
+
+ for (int i = 0; i < threads; i++) {
+ state_->workers_.emplace_back();
+ auto it = --(state_->workers_.end());
+ *it = std::thread([this, state, it] {
+ current_thread_pool_ = this;
+ WorkerLoop(state, it);
+ });
+ }
+}
+
+Status ThreadPool::SpawnReal(TaskHints hints, FnOnce<void()> task, StopToken stop_token,
+ StopCallback&& stop_callback) {
+ {
+ ProtectAgainstFork();
+ std::lock_guard<std::mutex> lock(state_->mutex_);
+ if (state_->please_shutdown_) {
+ return Status::Invalid("operation forbidden during or after shutdown");
+ }
+ CollectFinishedWorkersUnlocked();
+ state_->tasks_queued_or_running_++;
+ if (static_cast<int>(state_->workers_.size()) < state_->tasks_queued_or_running_ &&
+ state_->desired_capacity_ > static_cast<int>(state_->workers_.size())) {
+ // We can still spin up more workers so spin up a new worker
+ LaunchWorkersUnlocked(/*threads=*/1);
+ }
+ state_->pending_tasks_.push_back(
+ {std::move(task), std::move(stop_token), std::move(stop_callback)});
+ }
+ state_->cv_.notify_one();
+ return Status::OK();
+}
+
+Result<std::shared_ptr<ThreadPool>> ThreadPool::Make(int threads) {
+ auto pool = std::shared_ptr<ThreadPool>(new ThreadPool());
+ RETURN_NOT_OK(pool->SetCapacity(threads));
+ return pool;
+}
+
+Result<std::shared_ptr<ThreadPool>> ThreadPool::MakeEternal(int threads) {
+ ARROW_ASSIGN_OR_RAISE(auto pool, Make(threads));
+ // On Windows, the ThreadPool destructor may be called after non-main threads
+ // have been killed by the OS, and hang in a condition variable.
+ // On Unix, we want to avoid leak reports by Valgrind.
+#ifdef _WIN32
+ pool->shutdown_on_destroy_ = false;
+#endif
+ return pool;
+}
+
+// ----------------------------------------------------------------------
+// Global thread pool
+
+static int ParseOMPEnvVar(const char* name) {
+ // OMP_NUM_THREADS is a comma-separated list of positive integers.
+ // We are only interested in the first (top-level) number.
+ auto result = GetEnvVar(name);
+ if (!result.ok()) {
+ return 0;
+ }
+ auto str = *std::move(result);
+ auto first_comma = str.find_first_of(',');
+ if (first_comma != std::string::npos) {
+ str = str.substr(0, first_comma);
+ }
+ try {
+ return std::max(0, std::stoi(str));
+ } catch (...) {
+ return 0;
+ }
+}
+
+int ThreadPool::DefaultCapacity() {
+ int capacity, limit;
+ capacity = ParseOMPEnvVar("OMP_NUM_THREADS");
+ if (capacity == 0) {
+ capacity = std::thread::hardware_concurrency();
+ }
+ limit = ParseOMPEnvVar("OMP_THREAD_LIMIT");
+ if (limit > 0) {
+ capacity = std::min(limit, capacity);
+ }
+ if (capacity == 0) {
+ ARROW_LOG(WARNING) << "Failed to determine the number of available threads, "
+ "using a hardcoded arbitrary value";
+ capacity = 4;
+ }
+ return capacity;
+}
+
+// Helper for the singleton pattern
+std::shared_ptr<ThreadPool> ThreadPool::MakeCpuThreadPool() {
+ auto maybe_pool = ThreadPool::MakeEternal(ThreadPool::DefaultCapacity());
+ if (!maybe_pool.ok()) {
+ maybe_pool.status().Abort("Failed to create global CPU thread pool");
+ }
+ return *std::move(maybe_pool);
+}
+
+ThreadPool* GetCpuThreadPool() {
+ static std::shared_ptr<ThreadPool> singleton = ThreadPool::MakeCpuThreadPool();
+ return singleton.get();
+}
+
+} // namespace internal
+
+int GetCpuThreadPoolCapacity() { return internal::GetCpuThreadPool()->GetCapacity(); }
+
+Status SetCpuThreadPoolCapacity(int threads) {
+ return internal::GetCpuThreadPool()->SetCapacity(threads);
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/thread_pool.h b/src/arrow/cpp/src/arrow/util/thread_pool.h
new file mode 100644
index 000000000..4ed908d6f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/thread_pool.h
@@ -0,0 +1,403 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#ifndef _WIN32
+#include <unistd.h>
+#endif
+
+#include <cstdint>
+#include <memory>
+#include <queue>
+#include <type_traits>
+#include <utility>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/cancel.h"
+#include "arrow/util/functional.h"
+#include "arrow/util/future.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+#if defined(_MSC_VER)
+// Disable harmless warning for decorated name length limit
+#pragma warning(disable : 4503)
+#endif
+
+namespace arrow {
+
+/// \brief Get the capacity of the global thread pool
+///
+/// Return the number of worker threads in the thread pool to which
+/// Arrow dispatches various CPU-bound tasks. This is an ideal number,
+/// not necessarily the exact number of threads at a given point in time.
+///
+/// You can change this number using SetCpuThreadPoolCapacity().
+ARROW_EXPORT int GetCpuThreadPoolCapacity();
+
+/// \brief Set the capacity of the global thread pool
+///
+/// Set the number of worker threads int the thread pool to which
+/// Arrow dispatches various CPU-bound tasks.
+///
+/// The current number is returned by GetCpuThreadPoolCapacity().
+ARROW_EXPORT Status SetCpuThreadPoolCapacity(int threads);
+
+namespace internal {
+
+// Hints about a task that may be used by an Executor.
+// They are ignored by the provided ThreadPool implementation.
+struct TaskHints {
+ // The lower, the more urgent
+ int32_t priority = 0;
+ // The IO transfer size in bytes
+ int64_t io_size = -1;
+ // The approximate CPU cost in number of instructions
+ int64_t cpu_cost = -1;
+ // An application-specific ID
+ int64_t external_id = -1;
+};
+
+class ARROW_EXPORT Executor {
+ public:
+ using StopCallback = internal::FnOnce<void(const Status&)>;
+
+ virtual ~Executor();
+
+ // Spawn a fire-and-forget task.
+ template <typename Function>
+ Status Spawn(Function&& func) {
+ return SpawnReal(TaskHints{}, std::forward<Function>(func), StopToken::Unstoppable(),
+ StopCallback{});
+ }
+ template <typename Function>
+ Status Spawn(Function&& func, StopToken stop_token) {
+ return SpawnReal(TaskHints{}, std::forward<Function>(func), std::move(stop_token),
+ StopCallback{});
+ }
+ template <typename Function>
+ Status Spawn(TaskHints hints, Function&& func) {
+ return SpawnReal(hints, std::forward<Function>(func), StopToken::Unstoppable(),
+ StopCallback{});
+ }
+ template <typename Function>
+ Status Spawn(TaskHints hints, Function&& func, StopToken stop_token) {
+ return SpawnReal(hints, std::forward<Function>(func), std::move(stop_token),
+ StopCallback{});
+ }
+ template <typename Function>
+ Status Spawn(TaskHints hints, Function&& func, StopToken stop_token,
+ StopCallback stop_callback) {
+ return SpawnReal(hints, std::forward<Function>(func), std::move(stop_token),
+ std::move(stop_callback));
+ }
+
+ // Transfers a future to this executor. Any continuations added to the
+ // returned future will run in this executor. Otherwise they would run
+ // on the same thread that called MarkFinished.
+ //
+ // This is necessary when (for example) an I/O task is completing a future.
+ // The continuations of that future should run on the CPU thread pool keeping
+ // CPU heavy work off the I/O thread pool. So the I/O task should transfer
+ // the future to the CPU executor before returning.
+ //
+ // By default this method will only transfer if the future is not already completed. If
+ // the future is already completed then any callback would be run synchronously and so
+ // no transfer is typically necessary. However, in cases where you want to force a
+ // transfer (e.g. to help the scheduler break up units of work across multiple cores)
+ // then you can override this behavior with `always_transfer`.
+ template <typename T>
+ Future<T> Transfer(Future<T> future) {
+ return DoTransfer(std::move(future), false);
+ }
+
+ // Overload of Transfer which will always schedule callbacks on new threads even if the
+ // future is finished when the callback is added.
+ //
+ // This can be useful in cases where you want to ensure parallelism
+ template <typename T>
+ Future<T> TransferAlways(Future<T> future) {
+ return DoTransfer(std::move(future), true);
+ }
+
+ // Submit a callable and arguments for execution. Return a future that
+ // will return the callable's result value once.
+ // The callable's arguments are copied before execution.
+ template <typename Function, typename... Args,
+ typename FutureType = typename ::arrow::detail::ContinueFuture::ForSignature<
+ Function && (Args && ...)>>
+ Result<FutureType> Submit(TaskHints hints, StopToken stop_token, Function&& func,
+ Args&&... args) {
+ using ValueType = typename FutureType::ValueType;
+
+ auto future = FutureType::Make();
+ auto task = std::bind(::arrow::detail::ContinueFuture{}, future,
+ std::forward<Function>(func), std::forward<Args>(args)...);
+ struct {
+ WeakFuture<ValueType> weak_fut;
+
+ void operator()(const Status& st) {
+ auto fut = weak_fut.get();
+ if (fut.is_valid()) {
+ fut.MarkFinished(st);
+ }
+ }
+ } stop_callback{WeakFuture<ValueType>(future)};
+ ARROW_RETURN_NOT_OK(SpawnReal(hints, std::move(task), std::move(stop_token),
+ std::move(stop_callback)));
+
+ return future;
+ }
+
+ template <typename Function, typename... Args,
+ typename FutureType = typename ::arrow::detail::ContinueFuture::ForSignature<
+ Function && (Args && ...)>>
+ Result<FutureType> Submit(StopToken stop_token, Function&& func, Args&&... args) {
+ return Submit(TaskHints{}, stop_token, std::forward<Function>(func),
+ std::forward<Args>(args)...);
+ }
+
+ template <typename Function, typename... Args,
+ typename FutureType = typename ::arrow::detail::ContinueFuture::ForSignature<
+ Function && (Args && ...)>>
+ Result<FutureType> Submit(TaskHints hints, Function&& func, Args&&... args) {
+ return Submit(std::move(hints), StopToken::Unstoppable(),
+ std::forward<Function>(func), std::forward<Args>(args)...);
+ }
+
+ template <typename Function, typename... Args,
+ typename FutureType = typename ::arrow::detail::ContinueFuture::ForSignature<
+ Function && (Args && ...)>>
+ Result<FutureType> Submit(Function&& func, Args&&... args) {
+ return Submit(TaskHints{}, StopToken::Unstoppable(), std::forward<Function>(func),
+ std::forward<Args>(args)...);
+ }
+
+ // Return the level of parallelism (the number of tasks that may be executed
+ // concurrently). This may be an approximate number.
+ virtual int GetCapacity() = 0;
+
+ // Return true if the thread from which this function is called is owned by this
+ // Executor. Returns false if this Executor does not support this property.
+ virtual bool OwnsThisThread() { return false; }
+
+ protected:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Executor);
+
+ Executor() = default;
+
+ template <typename T, typename FT = Future<T>, typename FTSync = typename FT::SyncType>
+ Future<T> DoTransfer(Future<T> future, bool always_transfer = false) {
+ auto transferred = Future<T>::Make();
+ if (always_transfer) {
+ CallbackOptions callback_options = CallbackOptions::Defaults();
+ callback_options.should_schedule = ShouldSchedule::Always;
+ callback_options.executor = this;
+ auto sync_callback = [transferred](const FTSync& result) mutable {
+ transferred.MarkFinished(result);
+ };
+ future.AddCallback(sync_callback, callback_options);
+ return transferred;
+ }
+
+ // We could use AddCallback's ShouldSchedule::IfUnfinished but we can save a bit of
+ // work by doing the test here.
+ auto callback = [this, transferred](const FTSync& result) mutable {
+ auto spawn_status =
+ Spawn([transferred, result]() mutable { transferred.MarkFinished(result); });
+ if (!spawn_status.ok()) {
+ transferred.MarkFinished(spawn_status);
+ }
+ };
+ auto callback_factory = [&callback]() { return callback; };
+ if (future.TryAddCallback(callback_factory)) {
+ return transferred;
+ }
+ // If the future is already finished and we aren't going to force spawn a thread
+ // then we don't need to add another layer of callback and can return the original
+ // future
+ return future;
+ }
+
+ // Subclassing API
+ virtual Status SpawnReal(TaskHints hints, FnOnce<void()> task, StopToken,
+ StopCallback&&) = 0;
+};
+
+/// \brief An executor implementation that runs all tasks on a single thread using an
+/// event loop.
+///
+/// Note: Any sort of nested parallelism will deadlock this executor. Blocking waits are
+/// fine but if one task needs to wait for another task it must be expressed as an
+/// asynchronous continuation.
+class ARROW_EXPORT SerialExecutor : public Executor {
+ public:
+ template <typename T = ::arrow::internal::Empty>
+ using TopLevelTask = internal::FnOnce<Future<T>(Executor*)>;
+
+ ~SerialExecutor() override;
+
+ int GetCapacity() override { return 1; };
+ Status SpawnReal(TaskHints hints, FnOnce<void()> task, StopToken,
+ StopCallback&&) override;
+
+ /// \brief Runs the TopLevelTask and any scheduled tasks
+ ///
+ /// The TopLevelTask (or one of the tasks it schedules) must either return an invalid
+ /// status or call the finish signal. Failure to do this will result in a deadlock. For
+ /// this reason it is preferable (if possible) to use the helper methods (below)
+ /// RunSynchronously/RunSerially which delegates the responsiblity onto a Future
+ /// producer's existing responsibility to always mark a future finished (which can
+ /// someday be aided by ARROW-12207).
+ template <typename T = internal::Empty, typename FT = Future<T>,
+ typename FTSync = typename FT::SyncType>
+ static FTSync RunInSerialExecutor(TopLevelTask<T> initial_task) {
+ Future<T> fut = SerialExecutor().Run<T>(std::move(initial_task));
+ return FutureToSync(fut);
+ }
+
+ private:
+ SerialExecutor();
+
+ // State uses mutex
+ struct State;
+ std::shared_ptr<State> state_;
+
+ template <typename T, typename FTSync = typename Future<T>::SyncType>
+ Future<T> Run(TopLevelTask<T> initial_task) {
+ auto final_fut = std::move(initial_task)(this);
+ if (final_fut.is_finished()) {
+ return final_fut;
+ }
+ final_fut.AddCallback([this](const FTSync&) { MarkFinished(); });
+ RunLoop();
+ return final_fut;
+ }
+ void RunLoop();
+ void MarkFinished();
+};
+
+/// An Executor implementation spawning tasks in FIFO manner on a fixed-size
+/// pool of worker threads.
+///
+/// Note: Any sort of nested parallelism will deadlock this executor. Blocking waits are
+/// fine but if one task needs to wait for another task it must be expressed as an
+/// asynchronous continuation.
+class ARROW_EXPORT ThreadPool : public Executor {
+ public:
+ // Construct a thread pool with the given number of worker threads
+ static Result<std::shared_ptr<ThreadPool>> Make(int threads);
+
+ // Like Make(), but takes care that the returned ThreadPool is compatible
+ // with destruction late at process exit.
+ static Result<std::shared_ptr<ThreadPool>> MakeEternal(int threads);
+
+ // Destroy thread pool; the pool will first be shut down
+ ~ThreadPool() override;
+
+ // Return the desired number of worker threads.
+ // The actual number of workers may lag a bit before being adjusted to
+ // match this value.
+ int GetCapacity() override;
+
+ bool OwnsThisThread() override;
+
+ // Return the number of tasks either running or in the queue.
+ int GetNumTasks();
+
+ // Dynamically change the number of worker threads.
+ //
+ // This function always returns immediately.
+ // If fewer threads are running than this number, new threads are spawned
+ // on-demand when needed for task execution.
+ // If more threads are running than this number, excess threads are reaped
+ // as soon as possible.
+ Status SetCapacity(int threads);
+
+ // Heuristic for the default capacity of a thread pool for CPU-bound tasks.
+ // This is exposed as a static method to help with testing.
+ static int DefaultCapacity();
+
+ // Shutdown the pool. Once the pool starts shutting down, new tasks
+ // cannot be submitted anymore.
+ // If "wait" is true, shutdown waits for all pending tasks to be finished.
+ // If "wait" is false, workers are stopped as soon as currently executing
+ // tasks are finished.
+ Status Shutdown(bool wait = true);
+
+ // Wait for the thread pool to become idle
+ //
+ // This is useful for sequencing tests
+ void WaitForIdle();
+
+ struct State;
+
+ protected:
+ FRIEND_TEST(TestThreadPool, SetCapacity);
+ FRIEND_TEST(TestGlobalThreadPool, Capacity);
+ friend ARROW_EXPORT ThreadPool* GetCpuThreadPool();
+
+ ThreadPool();
+
+ Status SpawnReal(TaskHints hints, FnOnce<void()> task, StopToken,
+ StopCallback&&) override;
+
+ // Collect finished worker threads, making sure the OS threads have exited
+ void CollectFinishedWorkersUnlocked();
+ // Launch a given number of additional workers
+ void LaunchWorkersUnlocked(int threads);
+ // Get the current actual capacity
+ int GetActualCapacity();
+ // Reinitialize the thread pool if the pid changed
+ void ProtectAgainstFork();
+
+ static std::shared_ptr<ThreadPool> MakeCpuThreadPool();
+
+ std::shared_ptr<State> sp_state_;
+ State* state_;
+ bool shutdown_on_destroy_;
+#ifndef _WIN32
+ pid_t pid_;
+#endif
+};
+
+// Return the process-global thread pool for CPU-bound tasks.
+ARROW_EXPORT ThreadPool* GetCpuThreadPool();
+
+/// \brief Potentially run an async operation serially (if use_threads is false)
+/// \see RunSerially
+///
+/// If `use_threads` is true, the global CPU executor is used.
+/// If `use_threads` is false, a temporary SerialExecutor is used.
+/// `get_future` is called (from this thread) with the chosen executor and must
+/// return a future that will eventually finish. This function returns once the
+/// future has finished.
+template <typename Fut, typename ValueType = typename Fut::ValueType>
+typename Fut::SyncType RunSynchronously(FnOnce<Fut(Executor*)> get_future,
+ bool use_threads) {
+ if (use_threads) {
+ auto fut = std::move(get_future)(GetCpuThreadPool());
+ return FutureToSync(fut);
+ } else {
+ return SerialExecutor::RunInSerialExecutor<ValueType>(std::move(get_future));
+ }
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/thread_pool_benchmark.cc b/src/arrow/cpp/src/arrow/util/thread_pool_benchmark.cc
new file mode 100644
index 000000000..7c342c47f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/thread_pool_benchmark.cc
@@ -0,0 +1,248 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <random>
+#include <vector>
+
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/task_group.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+namespace internal {
+
+struct Workload {
+ explicit Workload(int32_t size) : size_(size), data_(kDataSize) {
+ std::default_random_engine gen(42);
+ std::uniform_int_distribution<uint64_t> dist(0, std::numeric_limits<uint64_t>::max());
+ std::generate(data_.begin(), data_.end(), [&]() { return dist(gen); });
+ }
+
+ void operator()();
+
+ private:
+ static constexpr int32_t kDataSize = 32;
+
+ int32_t size_;
+ std::vector<uint64_t> data_;
+};
+
+void Workload::operator()() {
+ uint64_t result = 0;
+ for (int32_t i = 0; i < size_ / kDataSize; ++i) {
+ for (const auto v : data_) {
+ result = (result << (v % 64)) - v;
+ }
+ }
+ benchmark::DoNotOptimize(result);
+}
+
+struct Task {
+ explicit Task(int32_t size) : workload_(size) {}
+
+ Status operator()() {
+ workload_();
+ return Status::OK();
+ }
+
+ private:
+ Workload workload_;
+};
+
+// Benchmark ThreadPool::Spawn
+static void ThreadPoolSpawn(benchmark::State& state) { // NOLINT non-const reference
+ const auto nthreads = static_cast<int>(state.range(0));
+ const auto workload_size = static_cast<int32_t>(state.range(1));
+
+ Workload workload(workload_size);
+
+ // Spawn enough tasks to make the pool start up overhead negligible
+ const int32_t nspawns = 200000000 / workload_size + 1;
+
+ for (auto _ : state) {
+ state.PauseTiming();
+ std::shared_ptr<ThreadPool> pool;
+ pool = *ThreadPool::Make(nthreads);
+ state.ResumeTiming();
+
+ for (int32_t i = 0; i < nspawns; ++i) {
+ // Pass the task by reference to avoid copying it around
+ ABORT_NOT_OK(pool->Spawn(std::ref(workload)));
+ }
+
+ // Wait for all tasks to finish
+ ABORT_NOT_OK(pool->Shutdown(true /* wait */));
+ state.PauseTiming();
+ pool.reset();
+ state.ResumeTiming();
+ }
+ state.SetItemsProcessed(state.iterations() * nspawns);
+}
+
+// Benchmark SerialExecutor::RunInSerialExecutor
+static void RunInSerialExecutor(benchmark::State& state) { // NOLINT non-const reference
+ const auto workload_size = static_cast<int32_t>(state.range(0));
+
+ Workload workload(workload_size);
+
+ for (auto _ : state) {
+ ABORT_NOT_OK(
+ SerialExecutor::RunInSerialExecutor<Future<>>([&](internal::Executor* executor) {
+ return DeferNotOk(executor->Submit(std::ref(workload)));
+ }));
+ }
+
+ state.SetItemsProcessed(state.iterations());
+}
+
+// Benchmark ThreadPool::Submit
+static void ThreadPoolSubmit(benchmark::State& state) { // NOLINT non-const reference
+ const auto nthreads = static_cast<int>(state.range(0));
+ const auto workload_size = static_cast<int32_t>(state.range(1));
+
+ Workload workload(workload_size);
+
+ const int32_t nspawns = 10000000 / workload_size + 1;
+
+ for (auto _ : state) {
+ state.PauseTiming();
+ auto pool = *ThreadPool::Make(nthreads);
+ std::atomic<int32_t> n_finished{0};
+ state.ResumeTiming();
+
+ for (int32_t i = 0; i < nspawns; ++i) {
+ // Pass the task by reference to avoid copying it around
+ (void)DeferNotOk(pool->Submit(std::ref(workload))).Then([&]() {
+ n_finished.fetch_add(1);
+ });
+ }
+
+ // Wait for all tasks to finish
+ ABORT_NOT_OK(pool->Shutdown(true /* wait */));
+ ASSERT_EQ(n_finished.load(), nspawns);
+ state.PauseTiming();
+ pool.reset();
+ state.ResumeTiming();
+ }
+ state.SetItemsProcessed(state.iterations() * nspawns);
+}
+
+// Benchmark serial TaskGroup
+static void SerialTaskGroup(benchmark::State& state) { // NOLINT non-const reference
+ const auto workload_size = static_cast<int32_t>(state.range(0));
+
+ Task task(workload_size);
+
+ const int32_t nspawns = 10000000 / workload_size + 1;
+
+ for (auto _ : state) {
+ auto task_group = TaskGroup::MakeSerial();
+ for (int32_t i = 0; i < nspawns; ++i) {
+ // Pass the task by reference to avoid copying it around
+ task_group->Append(std::ref(task));
+ }
+ ABORT_NOT_OK(task_group->Finish());
+ }
+ state.SetItemsProcessed(state.iterations() * nspawns);
+}
+
+// Benchmark threaded TaskGroup
+static void ThreadedTaskGroup(benchmark::State& state) { // NOLINT non-const reference
+ const auto nthreads = static_cast<int>(state.range(0));
+ const auto workload_size = static_cast<int32_t>(state.range(1));
+
+ std::shared_ptr<ThreadPool> pool;
+ pool = *ThreadPool::Make(nthreads);
+
+ Task task(workload_size);
+
+ const int32_t nspawns = 10000000 / workload_size + 1;
+
+ for (auto _ : state) {
+ auto task_group = TaskGroup::MakeThreaded(pool.get());
+ task_group->Append([&task, nspawns, task_group] {
+ for (int32_t i = 0; i < nspawns; ++i) {
+ // Pass the task by reference to avoid copying it around
+ task_group->Append(std::ref(task));
+ }
+ return Status::OK();
+ });
+ ABORT_NOT_OK(task_group->Finish());
+ }
+ ABORT_NOT_OK(pool->Shutdown(true /* wait */));
+
+ state.SetItemsProcessed(state.iterations() * nspawns);
+}
+
+static const std::vector<int32_t> kWorkloadSizes = {1000, 10000, 100000};
+
+static void WorkloadCost_Customize(benchmark::internal::Benchmark* b) {
+ for (const int32_t w : kWorkloadSizes) {
+ b->Args({w});
+ }
+ b->ArgNames({"task_cost"});
+ b->UseRealTime();
+}
+
+static void ThreadPoolSpawn_Customize(benchmark::internal::Benchmark* b) {
+ for (const int32_t w : kWorkloadSizes) {
+ for (const int nthreads : {1, 2, 4, 8}) {
+ b->Args({nthreads, w});
+ }
+ }
+ b->ArgNames({"threads", "task_cost"});
+ b->UseRealTime();
+}
+
+#ifdef ARROW_WITH_BENCHMARKS_REFERENCE
+
+// This benchmark simply provides a baseline indicating the raw cost of our workload
+// depending on the workload size. Number of items / second in this (serial)
+// benchmark can be compared to the numbers obtained in ThreadPoolSpawn.
+static void ReferenceWorkloadCost(benchmark::State& state) {
+ const auto workload_size = static_cast<int32_t>(state.range(0));
+
+ Workload workload(workload_size);
+ for (auto _ : state) {
+ workload();
+ }
+
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(ReferenceWorkloadCost)->Apply(WorkloadCost_Customize);
+
+#endif
+
+BENCHMARK(SerialTaskGroup)->Apply(WorkloadCost_Customize);
+BENCHMARK(RunInSerialExecutor)->Apply(WorkloadCost_Customize);
+BENCHMARK(ThreadPoolSpawn)->Apply(ThreadPoolSpawn_Customize);
+BENCHMARK(ThreadedTaskGroup)->Apply(ThreadPoolSpawn_Customize);
+BENCHMARK(ThreadPoolSubmit)->Apply(ThreadPoolSpawn_Customize);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/thread_pool_test.cc b/src/arrow/cpp/src/arrow/util/thread_pool_test.cc
new file mode 100644
index 000000000..399c755a8
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/thread_pool_test.cc
@@ -0,0 +1,718 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef _WIN32
+#include <sys/wait.h>
+#include <unistd.h>
+#endif
+
+#include <algorithm>
+#include <cstdio>
+#include <cstdlib>
+#include <functional>
+#include <memory>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/status.h"
+#include "arrow/testing/executor_util.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/test_common.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+namespace internal {
+
+template <typename T>
+static void task_add(T x, T y, T* out) {
+ *out = x + y;
+}
+
+template <typename T>
+struct task_slow_add {
+ void operator()(T x, T y, T* out) {
+ SleepFor(seconds_);
+ *out = x + y;
+ }
+
+ const double seconds_;
+};
+
+typedef std::function<void(int, int, int*)> AddTaskFunc;
+
+template <typename T>
+static T add(T x, T y) {
+ return x + y;
+}
+
+template <typename T>
+static T slow_add(double seconds, T x, T y) {
+ SleepFor(seconds);
+ return x + y;
+}
+
+template <typename T>
+static T inplace_add(T& x, T y) {
+ return x += y;
+}
+
+// A class to spawn "add" tasks to a pool and check the results when done
+
+class AddTester {
+ public:
+ explicit AddTester(int nadds, StopToken stop_token = StopToken::Unstoppable())
+ : nadds_(nadds), stop_token_(stop_token), xs_(nadds), ys_(nadds), outs_(nadds, -1) {
+ int x = 0, y = 0;
+ std::generate(xs_.begin(), xs_.end(), [&] {
+ ++x;
+ return x;
+ });
+ std::generate(ys_.begin(), ys_.end(), [&] {
+ y += 10;
+ return y;
+ });
+ }
+
+ AddTester(AddTester&&) = default;
+
+ void SpawnTasks(ThreadPool* pool, AddTaskFunc add_func) {
+ for (int i = 0; i < nadds_; ++i) {
+ ASSERT_OK(pool->Spawn([=] { add_func(xs_[i], ys_[i], &outs_[i]); }, stop_token_));
+ }
+ }
+
+ void CheckResults() {
+ for (int i = 0; i < nadds_; ++i) {
+ ASSERT_EQ(outs_[i], (i + 1) * 11);
+ }
+ }
+
+ void CheckNotAllComputed() {
+ for (int i = 0; i < nadds_; ++i) {
+ if (outs_[i] == -1) {
+ return;
+ }
+ }
+ ASSERT_TRUE(0) << "all values were computed";
+ }
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(AddTester);
+
+ int nadds_;
+ StopToken stop_token_;
+ std::vector<int> xs_;
+ std::vector<int> ys_;
+ std::vector<int> outs_;
+};
+
+class TestRunSynchronously : public testing::TestWithParam<bool> {
+ public:
+ bool UseThreads() { return GetParam(); }
+
+ template <typename T>
+ Result<T> Run(FnOnce<Future<T>(Executor*)> top_level_task) {
+ return RunSynchronously(std::move(top_level_task), UseThreads());
+ }
+
+ Status RunVoid(FnOnce<Future<>(Executor*)> top_level_task) {
+ return RunSynchronously(std::move(top_level_task), UseThreads());
+ }
+
+ void TestContinueAfterExternal(bool transfer_to_main_thread) {
+ bool continuation_ran = false;
+ EXPECT_OK_AND_ASSIGN(auto external_pool, ThreadPool::Make(1));
+ auto top_level_task = [&](Executor* executor) {
+ struct Callback {
+ Status operator()() {
+ *continuation_ran = true;
+ return Status::OK();
+ }
+ bool* continuation_ran;
+ };
+ auto fut = DeferNotOk(external_pool->Submit([&] {
+ SleepABit();
+ return Status::OK();
+ }));
+ if (transfer_to_main_thread) {
+ fut = executor->Transfer(fut);
+ }
+ return fut.Then(Callback{&continuation_ran});
+ };
+ ASSERT_OK(RunVoid(std::move(top_level_task)));
+ EXPECT_TRUE(continuation_ran);
+ }
+};
+
+TEST_P(TestRunSynchronously, SimpleRun) {
+ bool task_ran = false;
+ auto task = [&](Executor* executor) {
+ EXPECT_NE(executor, nullptr);
+ task_ran = true;
+ return Future<>::MakeFinished();
+ };
+ ASSERT_OK(RunVoid(std::move(task)));
+ EXPECT_TRUE(task_ran);
+}
+
+TEST_P(TestRunSynchronously, SpawnNested) {
+ bool nested_ran = false;
+ auto top_level_task = [&](Executor* executor) {
+ return DeferNotOk(executor->Submit([&] {
+ nested_ran = true;
+ return Status::OK();
+ }));
+ };
+ ASSERT_OK(RunVoid(std::move(top_level_task)));
+ EXPECT_TRUE(nested_ran);
+}
+
+TEST_P(TestRunSynchronously, SpawnMoreNested) {
+ std::atomic<int> nested_ran{0};
+ auto top_level_task = [&](Executor* executor) -> Future<> {
+ auto fut_a = DeferNotOk(executor->Submit([&] { nested_ran++; }));
+ auto fut_b = DeferNotOk(executor->Submit([&] { nested_ran++; }));
+ return AllComplete({fut_a, fut_b}).Then([&]() { nested_ran++; });
+ };
+ ASSERT_OK(RunVoid(std::move(top_level_task)));
+ EXPECT_EQ(nested_ran, 3);
+}
+
+TEST_P(TestRunSynchronously, WithResult) {
+ auto top_level_task = [&](Executor* executor) {
+ return DeferNotOk(executor->Submit([] { return 42; }));
+ };
+ ASSERT_OK_AND_EQ(42, Run<int>(std::move(top_level_task)));
+}
+
+TEST_P(TestRunSynchronously, StopTokenSpawn) {
+ bool nested_ran = false;
+ StopSource stop_source;
+ auto top_level_task = [&](Executor* executor) -> Future<> {
+ stop_source.RequestStop(Status::Invalid("XYZ"));
+ RETURN_NOT_OK(executor->Spawn([&] { nested_ran = true; }, stop_source.token()));
+ return Future<>::MakeFinished();
+ };
+ ASSERT_OK(RunVoid(std::move(top_level_task)));
+ EXPECT_FALSE(nested_ran);
+}
+
+TEST_P(TestRunSynchronously, StopTokenSubmit) {
+ bool nested_ran = false;
+ StopSource stop_source;
+ auto top_level_task = [&](Executor* executor) -> Future<> {
+ stop_source.RequestStop();
+ return DeferNotOk(executor->Submit(stop_source.token(), [&] {
+ nested_ran = true;
+ return Status::OK();
+ }));
+ };
+ ASSERT_RAISES(Cancelled, RunVoid(std::move(top_level_task)));
+ EXPECT_FALSE(nested_ran);
+}
+
+TEST_P(TestRunSynchronously, ContinueAfterExternal) {
+ // The future returned by the top-level task completes on another thread.
+ // This can trigger delicate race conditions in the SerialExecutor code,
+ // especially destruction.
+ this->TestContinueAfterExternal(/*transfer_to_main_thread=*/false);
+}
+
+TEST_P(TestRunSynchronously, ContinueAfterExternalTransferred) {
+ // Like above, but the future is transferred back to the serial executor
+ // after completion on an external thread.
+ this->TestContinueAfterExternal(/*transfer_to_main_thread=*/true);
+}
+
+TEST_P(TestRunSynchronously, SchedulerAbort) {
+ auto top_level_task = [&](Executor* executor) { return Status::Invalid("XYZ"); };
+ ASSERT_RAISES(Invalid, RunVoid(std::move(top_level_task)));
+}
+
+TEST_P(TestRunSynchronously, PropagatedError) {
+ auto top_level_task = [&](Executor* executor) {
+ return DeferNotOk(executor->Submit([] { return Status::Invalid("XYZ"); }));
+ };
+ ASSERT_RAISES(Invalid, RunVoid(std::move(top_level_task)));
+}
+
+INSTANTIATE_TEST_SUITE_P(TestRunSynchronously, TestRunSynchronously,
+ ::testing::Values(false, true));
+
+class TransferTest : public testing::Test {
+ public:
+ internal::Executor* executor() { return mock_executor.get(); }
+ int spawn_count() { return mock_executor->spawn_count; }
+
+ std::function<void(const Status&)> callback = [](const Status&) {};
+ std::shared_ptr<MockExecutor> mock_executor = std::make_shared<MockExecutor>();
+};
+
+TEST_F(TransferTest, DefaultTransferIfNotFinished) {
+ {
+ Future<> fut = Future<>::Make();
+ auto transferred = executor()->Transfer(fut);
+ fut.MarkFinished();
+ ASSERT_FINISHES_OK(transferred);
+ ASSERT_EQ(1, spawn_count());
+ }
+ {
+ Future<> fut = Future<>::Make();
+ fut.MarkFinished();
+ auto transferred = executor()->Transfer(fut);
+ ASSERT_FINISHES_OK(transferred);
+ ASSERT_EQ(1, spawn_count());
+ }
+}
+
+TEST_F(TransferTest, TransferAlways) {
+ {
+ Future<> fut = Future<>::Make();
+ fut.MarkFinished();
+ auto transferred = executor()->TransferAlways(fut);
+ ASSERT_FINISHES_OK(transferred);
+ ASSERT_EQ(1, spawn_count());
+ }
+}
+
+class TestThreadPool : public ::testing::Test {
+ public:
+ void TearDown() override {
+ fflush(stdout);
+ fflush(stderr);
+ }
+
+ std::shared_ptr<ThreadPool> MakeThreadPool() { return MakeThreadPool(4); }
+
+ std::shared_ptr<ThreadPool> MakeThreadPool(int threads) {
+ return *ThreadPool::Make(threads);
+ }
+
+ void DoSpawnAdds(ThreadPool* pool, int nadds, AddTaskFunc add_func,
+ StopToken stop_token = StopToken::Unstoppable(),
+ StopSource* stop_source = nullptr) {
+ AddTester add_tester(nadds, stop_token);
+ add_tester.SpawnTasks(pool, add_func);
+ if (stop_source) {
+ stop_source->RequestStop();
+ }
+ ASSERT_OK(pool->Shutdown());
+ if (stop_source) {
+ add_tester.CheckNotAllComputed();
+ } else {
+ add_tester.CheckResults();
+ }
+ }
+
+ void SpawnAdds(ThreadPool* pool, int nadds, AddTaskFunc add_func,
+ StopToken stop_token = StopToken::Unstoppable()) {
+ DoSpawnAdds(pool, nadds, std::move(add_func), std::move(stop_token));
+ }
+
+ void SpawnAddsAndCancel(ThreadPool* pool, int nadds, AddTaskFunc add_func,
+ StopSource* stop_source) {
+ DoSpawnAdds(pool, nadds, std::move(add_func), stop_source->token(), stop_source);
+ }
+
+ void DoSpawnAddsThreaded(ThreadPool* pool, int nthreads, int nadds,
+ AddTaskFunc add_func,
+ StopToken stop_token = StopToken::Unstoppable(),
+ StopSource* stop_source = nullptr) {
+ // Same as SpawnAdds, but do the task spawning from multiple threads
+ std::vector<AddTester> add_testers;
+ std::vector<std::thread> threads;
+ for (int i = 0; i < nthreads; ++i) {
+ add_testers.emplace_back(nadds, stop_token);
+ }
+ for (auto& add_tester : add_testers) {
+ threads.emplace_back([&] { add_tester.SpawnTasks(pool, add_func); });
+ }
+ if (stop_source) {
+ stop_source->RequestStop();
+ }
+ for (auto& thread : threads) {
+ thread.join();
+ }
+ ASSERT_OK(pool->Shutdown());
+ for (auto& add_tester : add_testers) {
+ if (stop_source) {
+ add_tester.CheckNotAllComputed();
+ } else {
+ add_tester.CheckResults();
+ }
+ }
+ }
+
+ void SpawnAddsThreaded(ThreadPool* pool, int nthreads, int nadds, AddTaskFunc add_func,
+ StopToken stop_token = StopToken::Unstoppable()) {
+ DoSpawnAddsThreaded(pool, nthreads, nadds, std::move(add_func),
+ std::move(stop_token));
+ }
+
+ void SpawnAddsThreadedAndCancel(ThreadPool* pool, int nthreads, int nadds,
+ AddTaskFunc add_func, StopSource* stop_source) {
+ DoSpawnAddsThreaded(pool, nthreads, nadds, std::move(add_func), stop_source->token(),
+ stop_source);
+ }
+};
+
+TEST_F(TestThreadPool, ConstructDestruct) {
+ // Stress shutdown-at-destruction logic
+ for (int threads : {1, 2, 3, 8, 32, 70}) {
+ auto pool = this->MakeThreadPool(threads);
+ }
+}
+
+// Correctness and stress tests using Spawn() and Shutdown()
+
+TEST_F(TestThreadPool, Spawn) {
+ auto pool = this->MakeThreadPool(3);
+ SpawnAdds(pool.get(), 7, task_add<int>);
+}
+
+TEST_F(TestThreadPool, StressSpawn) {
+ auto pool = this->MakeThreadPool(30);
+ SpawnAdds(pool.get(), 1000, task_add<int>);
+}
+
+TEST_F(TestThreadPool, OwnsCurrentThread) {
+ auto pool = this->MakeThreadPool(30);
+ std::atomic<bool> one_failed{false};
+
+ for (int i = 0; i < 1000; ++i) {
+ ASSERT_OK(pool->Spawn([&] {
+ if (pool->OwnsThisThread()) return;
+
+ one_failed = true;
+ }));
+ }
+
+ ASSERT_OK(pool->Shutdown());
+ ASSERT_FALSE(pool->OwnsThisThread());
+ ASSERT_FALSE(one_failed);
+}
+
+TEST_F(TestThreadPool, StressSpawnThreaded) {
+ auto pool = this->MakeThreadPool(30);
+ SpawnAddsThreaded(pool.get(), 20, 100, task_add<int>);
+}
+
+TEST_F(TestThreadPool, SpawnSlow) {
+ // This checks that Shutdown() waits for all tasks to finish
+ auto pool = this->MakeThreadPool(2);
+ SpawnAdds(pool.get(), 7, task_slow_add<int>{/*seconds=*/0.02});
+}
+
+TEST_F(TestThreadPool, StressSpawnSlow) {
+ auto pool = this->MakeThreadPool(30);
+ SpawnAdds(pool.get(), 1000, task_slow_add<int>{/*seconds=*/0.002});
+}
+
+TEST_F(TestThreadPool, StressSpawnSlowThreaded) {
+ auto pool = this->MakeThreadPool(30);
+ SpawnAddsThreaded(pool.get(), 20, 100, task_slow_add<int>{/*seconds=*/0.002});
+}
+
+TEST_F(TestThreadPool, SpawnWithStopToken) {
+ StopSource stop_source;
+ auto pool = this->MakeThreadPool(3);
+ SpawnAdds(pool.get(), 7, task_add<int>, stop_source.token());
+}
+
+TEST_F(TestThreadPool, StressSpawnThreadedWithStopToken) {
+ StopSource stop_source;
+ auto pool = this->MakeThreadPool(30);
+ SpawnAddsThreaded(pool.get(), 20, 100, task_add<int>, stop_source.token());
+}
+
+TEST_F(TestThreadPool, SpawnWithStopTokenCancelled) {
+ StopSource stop_source;
+ auto pool = this->MakeThreadPool(3);
+ SpawnAddsAndCancel(pool.get(), 100, task_slow_add<int>{/*seconds=*/0.02}, &stop_source);
+}
+
+TEST_F(TestThreadPool, StressSpawnThreadedWithStopTokenCancelled) {
+ StopSource stop_source;
+ auto pool = this->MakeThreadPool(30);
+ SpawnAddsThreadedAndCancel(pool.get(), 20, 100, task_slow_add<int>{/*seconds=*/0.02},
+ &stop_source);
+}
+
+TEST_F(TestThreadPool, QuickShutdown) {
+ AddTester add_tester(100);
+ {
+ auto pool = this->MakeThreadPool(3);
+ add_tester.SpawnTasks(pool.get(), task_slow_add<int>{/*seconds=*/0.02});
+ ASSERT_OK(pool->Shutdown(false /* wait */));
+ add_tester.CheckNotAllComputed();
+ }
+ add_tester.CheckNotAllComputed();
+}
+
+TEST_F(TestThreadPool, SetCapacity) {
+ auto pool = this->MakeThreadPool(5);
+
+ // Thread spawning is on-demand
+ ASSERT_EQ(pool->GetCapacity(), 5);
+ ASSERT_EQ(pool->GetActualCapacity(), 0);
+
+ ASSERT_OK(pool->SetCapacity(3));
+ ASSERT_EQ(pool->GetCapacity(), 3);
+ ASSERT_EQ(pool->GetActualCapacity(), 0);
+
+ auto gating_task = GatingTask::Make();
+
+ ASSERT_OK(pool->Spawn(gating_task->Task()));
+ ASSERT_OK(gating_task->WaitForRunning(1));
+ ASSERT_EQ(pool->GetActualCapacity(), 1);
+ ASSERT_OK(gating_task->Unlock());
+
+ gating_task = GatingTask::Make();
+ // Spawn more tasks than the pool capacity
+ for (int i = 0; i < 6; ++i) {
+ ASSERT_OK(pool->Spawn(gating_task->Task()));
+ }
+ ASSERT_OK(gating_task->WaitForRunning(3));
+ SleepFor(0.001); // Sleep a bit just to make sure it isn't making any threads
+ ASSERT_EQ(pool->GetActualCapacity(), 3); // maxxed out
+
+ // The tasks have not finished yet, increasing the desired capacity
+ // should spawn threads immediately.
+ ASSERT_OK(pool->SetCapacity(5));
+ ASSERT_EQ(pool->GetCapacity(), 5);
+ ASSERT_EQ(pool->GetActualCapacity(), 5);
+
+ // Thread reaping is eager (but asynchronous)
+ ASSERT_OK(pool->SetCapacity(2));
+ ASSERT_EQ(pool->GetCapacity(), 2);
+
+ // Wait for workers to wake up and secede
+ ASSERT_OK(gating_task->Unlock());
+ BusyWait(0.5, [&] { return pool->GetActualCapacity() == 2; });
+ ASSERT_EQ(pool->GetActualCapacity(), 2);
+
+ // Downsize while tasks are pending
+ ASSERT_OK(pool->SetCapacity(5));
+ ASSERT_EQ(pool->GetCapacity(), 5);
+ gating_task = GatingTask::Make();
+ for (int i = 0; i < 10; ++i) {
+ ASSERT_OK(pool->Spawn(gating_task->Task()));
+ }
+ ASSERT_OK(gating_task->WaitForRunning(5));
+ ASSERT_EQ(pool->GetActualCapacity(), 5);
+
+ ASSERT_OK(pool->SetCapacity(2));
+ ASSERT_EQ(pool->GetCapacity(), 2);
+ ASSERT_OK(gating_task->Unlock());
+ BusyWait(0.5, [&] { return pool->GetActualCapacity() == 2; });
+ ASSERT_EQ(pool->GetActualCapacity(), 2);
+
+ // Ensure nothing got stuck
+ ASSERT_OK(pool->Shutdown());
+}
+
+// Test Submit() functionality
+
+TEST_F(TestThreadPool, Submit) {
+ auto pool = this->MakeThreadPool(3);
+ {
+ ASSERT_OK_AND_ASSIGN(Future<int> fut, pool->Submit(add<int>, 4, 5));
+ Result<int> res = fut.result();
+ ASSERT_OK_AND_EQ(9, res);
+ }
+ {
+ ASSERT_OK_AND_ASSIGN(Future<std::string> fut,
+ pool->Submit(add<std::string>, "foo", "bar"));
+ ASSERT_OK_AND_EQ("foobar", fut.result());
+ }
+ {
+ ASSERT_OK_AND_ASSIGN(auto fut, pool->Submit(slow_add<int>, /*seconds=*/0.01, 4, 5));
+ ASSERT_OK_AND_EQ(9, fut.result());
+ }
+ {
+ // Reference passing
+ std::string s = "foo";
+ ASSERT_OK_AND_ASSIGN(auto fut,
+ pool->Submit(inplace_add<std::string>, std::ref(s), "bar"));
+ ASSERT_OK_AND_EQ("foobar", fut.result());
+ ASSERT_EQ(s, "foobar");
+ }
+ {
+ // `void` return type
+ ASSERT_OK_AND_ASSIGN(auto fut, pool->Submit(SleepFor, 0.001));
+ ASSERT_OK(fut.status());
+ }
+}
+
+TEST_F(TestThreadPool, SubmitWithStopToken) {
+ auto pool = this->MakeThreadPool(3);
+ {
+ StopSource stop_source;
+ ASSERT_OK_AND_ASSIGN(Future<int> fut,
+ pool->Submit(stop_source.token(), add<int>, 4, 5));
+ Result<int> res = fut.result();
+ ASSERT_OK_AND_EQ(9, res);
+ }
+}
+
+TEST_F(TestThreadPool, SubmitWithStopTokenCancelled) {
+ auto pool = this->MakeThreadPool(3);
+ {
+ const int n_futures = 100;
+ StopSource stop_source;
+ StopToken stop_token = stop_source.token();
+ std::vector<Future<int>> futures;
+ for (int i = 0; i < n_futures; ++i) {
+ ASSERT_OK_AND_ASSIGN(
+ auto fut, pool->Submit(stop_token, slow_add<int>, 0.01 /*seconds*/, i, 1));
+ futures.push_back(std::move(fut));
+ }
+ SleepFor(0.05); // Let some work finish
+ stop_source.RequestStop();
+ int n_success = 0;
+ int n_cancelled = 0;
+ for (int i = 0; i < n_futures; ++i) {
+ Result<int> res = futures[i].result();
+ if (res.ok()) {
+ ASSERT_EQ(i + 1, *res);
+ ++n_success;
+ } else {
+ ASSERT_RAISES(Cancelled, res);
+ ++n_cancelled;
+ }
+ }
+ ASSERT_GT(n_success, 0);
+ ASSERT_GT(n_cancelled, 0);
+ }
+}
+
+// Test fork safety on Unix
+
+#if !(defined(_WIN32) || defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER) || \
+ defined(THREAD_SANITIZER))
+TEST_F(TestThreadPool, ForkSafety) {
+ pid_t child_pid;
+ int child_status;
+
+ {
+ // Fork after task submission
+ auto pool = this->MakeThreadPool(3);
+ ASSERT_OK_AND_ASSIGN(auto fut, pool->Submit(add<int>, 4, 5));
+ ASSERT_OK_AND_EQ(9, fut.result());
+
+ child_pid = fork();
+ if (child_pid == 0) {
+ // Child: thread pool should be usable
+ ASSERT_OK_AND_ASSIGN(fut, pool->Submit(add<int>, 3, 4));
+ if (*fut.result() != 7) {
+ std::exit(1);
+ }
+ // Shutting down shouldn't hang or fail
+ Status st = pool->Shutdown();
+ std::exit(st.ok() ? 0 : 2);
+ } else {
+ // Parent
+ ASSERT_GT(child_pid, 0);
+ ASSERT_GT(waitpid(child_pid, &child_status, 0), 0);
+ ASSERT_TRUE(WIFEXITED(child_status));
+ ASSERT_EQ(WEXITSTATUS(child_status), 0);
+ ASSERT_OK(pool->Shutdown());
+ }
+ }
+ {
+ // Fork after shutdown
+ auto pool = this->MakeThreadPool(3);
+ ASSERT_OK(pool->Shutdown());
+
+ child_pid = fork();
+ if (child_pid == 0) {
+ // Child
+ // Spawning a task should return with error (pool was shutdown)
+ Status st = pool->Spawn([] {});
+ if (!st.IsInvalid()) {
+ std::exit(1);
+ }
+ // Trigger destructor
+ pool.reset();
+ std::exit(0);
+ } else {
+ // Parent
+ ASSERT_GT(child_pid, 0);
+ ASSERT_GT(waitpid(child_pid, &child_status, 0), 0);
+ ASSERT_TRUE(WIFEXITED(child_status));
+ ASSERT_EQ(WEXITSTATUS(child_status), 0);
+ }
+ }
+}
+#endif
+
+TEST(TestGlobalThreadPool, Capacity) {
+ // Sanity check
+ auto pool = GetCpuThreadPool();
+ int capacity = pool->GetCapacity();
+ ASSERT_GT(capacity, 0);
+ ASSERT_EQ(GetCpuThreadPoolCapacity(), capacity);
+
+ // This value depends on whether any tasks were launched previously
+ ASSERT_GE(pool->GetActualCapacity(), 0);
+ ASSERT_LE(pool->GetActualCapacity(), capacity);
+
+ // Exercise default capacity heuristic
+ ASSERT_OK(DelEnvVar("OMP_NUM_THREADS"));
+ ASSERT_OK(DelEnvVar("OMP_THREAD_LIMIT"));
+ int hw_capacity = std::thread::hardware_concurrency();
+ ASSERT_EQ(ThreadPool::DefaultCapacity(), hw_capacity);
+ ASSERT_OK(SetEnvVar("OMP_NUM_THREADS", "13"));
+ ASSERT_EQ(ThreadPool::DefaultCapacity(), 13);
+ ASSERT_OK(SetEnvVar("OMP_NUM_THREADS", "7,5,13"));
+ ASSERT_EQ(ThreadPool::DefaultCapacity(), 7);
+ ASSERT_OK(DelEnvVar("OMP_NUM_THREADS"));
+
+ ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT", "1"));
+ ASSERT_EQ(ThreadPool::DefaultCapacity(), 1);
+ ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT", "999"));
+ if (hw_capacity <= 999) {
+ ASSERT_EQ(ThreadPool::DefaultCapacity(), hw_capacity);
+ }
+ ASSERT_OK(SetEnvVar("OMP_NUM_THREADS", "6,5,13"));
+ ASSERT_EQ(ThreadPool::DefaultCapacity(), 6);
+ ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT", "2"));
+ ASSERT_EQ(ThreadPool::DefaultCapacity(), 2);
+
+ // Invalid env values
+ ASSERT_OK(SetEnvVar("OMP_NUM_THREADS", "0"));
+ ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT", "0"));
+ ASSERT_EQ(ThreadPool::DefaultCapacity(), hw_capacity);
+ ASSERT_OK(SetEnvVar("OMP_NUM_THREADS", "zzz"));
+ ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT", "x"));
+ ASSERT_EQ(ThreadPool::DefaultCapacity(), hw_capacity);
+ ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT", "-1"));
+ ASSERT_OK(SetEnvVar("OMP_NUM_THREADS", "99999999999999999999999999"));
+ ASSERT_EQ(ThreadPool::DefaultCapacity(), hw_capacity);
+
+ ASSERT_OK(DelEnvVar("OMP_NUM_THREADS"));
+ ASSERT_OK(DelEnvVar("OMP_THREAD_LIMIT"));
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/time.cc b/src/arrow/cpp/src/arrow/util/time.cc
new file mode 100644
index 000000000..c285f0750
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/time.cc
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/time.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace util {
+
+// TimestampType -> TimestampType
+static const std::pair<DivideOrMultiply, int64_t> kTimestampConversionTable[4][4] = {
+ // TimestampType::SECOND
+ {{MULTIPLY, 1}, {MULTIPLY, 1000}, {MULTIPLY, 1000000}, {MULTIPLY, 1000000000}},
+ // TimestampType::MILLI
+ {{DIVIDE, 1000}, {MULTIPLY, 1}, {MULTIPLY, 1000}, {MULTIPLY, 1000000}},
+ // TimestampType::MICRO
+ {{DIVIDE, 1000000}, {DIVIDE, 1000}, {MULTIPLY, 1}, {MULTIPLY, 1000}},
+ // TimestampType::NANO
+ {{DIVIDE, 1000000000}, {DIVIDE, 1000000}, {DIVIDE, 1000}, {MULTIPLY, 1}},
+};
+
+std::pair<DivideOrMultiply, int64_t> GetTimestampConversion(TimeUnit::type in_unit,
+ TimeUnit::type out_unit) {
+ return kTimestampConversionTable[static_cast<int>(in_unit)][static_cast<int>(out_unit)];
+}
+
+Result<int64_t> ConvertTimestampValue(const std::shared_ptr<DataType>& in,
+ const std::shared_ptr<DataType>& out,
+ int64_t value) {
+ auto op_factor =
+ GetTimestampConversion(checked_cast<const TimestampType&>(*in).unit(),
+ checked_cast<const TimestampType&>(*out).unit());
+
+ auto op = op_factor.first;
+ auto factor = op_factor.second;
+ switch (op) {
+ case MULTIPLY:
+ return value * factor;
+ case DIVIDE:
+ return value / factor;
+ }
+
+ // unreachable...
+ return 0;
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/time.h b/src/arrow/cpp/src/arrow/util/time.h
new file mode 100644
index 000000000..981eab596
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/time.h
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <chrono>
+#include <memory>
+#include <utility>
+
+#include "arrow/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace util {
+
+enum DivideOrMultiply {
+ MULTIPLY,
+ DIVIDE,
+};
+
+ARROW_EXPORT
+std::pair<DivideOrMultiply, int64_t> GetTimestampConversion(TimeUnit::type in_unit,
+ TimeUnit::type out_unit);
+
+// Converts a Timestamp value into another Timestamp value.
+//
+// This function takes care of properly transforming from one unit to another.
+//
+// \param[in] in the input type. Must be TimestampType.
+// \param[in] out the output type. Must be TimestampType.
+// \param[in] value the input value.
+//
+// \return The converted value, or an error.
+ARROW_EXPORT Result<int64_t> ConvertTimestampValue(const std::shared_ptr<DataType>& in,
+ const std::shared_ptr<DataType>& out,
+ int64_t value);
+
+template <typename Visitor, typename... Args>
+decltype(std::declval<Visitor>()(std::chrono::seconds{}, std::declval<Args&&>()...))
+VisitDuration(TimeUnit::type unit, Visitor&& visitor, Args&&... args) {
+ switch (unit) {
+ default:
+ case TimeUnit::SECOND:
+ break;
+ case TimeUnit::MILLI:
+ return visitor(std::chrono::milliseconds{}, std::forward<Args>(args)...);
+ case TimeUnit::MICRO:
+ return visitor(std::chrono::microseconds{}, std::forward<Args>(args)...);
+ case TimeUnit::NANO:
+ return visitor(std::chrono::nanoseconds{}, std::forward<Args>(args)...);
+ }
+ return visitor(std::chrono::seconds{}, std::forward<Args>(args)...);
+}
+
+/// Convert a count of seconds to the corresponding count in a different TimeUnit
+struct CastSecondsToUnitImpl {
+ template <typename Duration>
+ int64_t operator()(Duration, int64_t seconds) {
+ auto duration = std::chrono::duration_cast<Duration>(std::chrono::seconds{seconds});
+ return static_cast<int64_t>(duration.count());
+ }
+};
+
+inline int64_t CastSecondsToUnit(TimeUnit::type unit, int64_t seconds) {
+ return VisitDuration(unit, CastSecondsToUnitImpl{}, seconds);
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/time_test.cc b/src/arrow/cpp/src/arrow/util/time_test.cc
new file mode 100644
index 000000000..0224cca49
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/time_test.cc
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/time.h"
+
+namespace arrow {
+namespace util {
+
+TEST(TimeTest, ConvertTimestampValue) {
+ auto convert = [](TimeUnit::type in, TimeUnit::type out, int64_t value) {
+ return ConvertTimestampValue(timestamp(in), timestamp(out), value).ValueOrDie();
+ };
+
+ auto units = {
+ TimeUnit::SECOND,
+ TimeUnit::MILLI,
+ TimeUnit::MICRO,
+ TimeUnit::NANO,
+ };
+
+ // Test for identity
+ for (auto unit : units) {
+ EXPECT_EQ(convert(unit, unit, 0), 0);
+ EXPECT_EQ(convert(unit, unit, INT64_MAX), INT64_MAX);
+ EXPECT_EQ(convert(unit, unit, INT64_MIN), INT64_MIN);
+ }
+
+ EXPECT_EQ(convert(TimeUnit::SECOND, TimeUnit::MILLI, 2), 2000);
+ EXPECT_EQ(convert(TimeUnit::SECOND, TimeUnit::MICRO, 2), 2000000);
+ EXPECT_EQ(convert(TimeUnit::SECOND, TimeUnit::NANO, 2), 2000000000);
+
+ EXPECT_EQ(convert(TimeUnit::MILLI, TimeUnit::SECOND, 7000), 7);
+ EXPECT_EQ(convert(TimeUnit::MILLI, TimeUnit::MICRO, 7), 7000);
+ EXPECT_EQ(convert(TimeUnit::MILLI, TimeUnit::NANO, 7), 7000000);
+
+ EXPECT_EQ(convert(TimeUnit::MICRO, TimeUnit::SECOND, 4000000), 4);
+ EXPECT_EQ(convert(TimeUnit::MICRO, TimeUnit::MILLI, 4000), 4);
+ EXPECT_EQ(convert(TimeUnit::MICRO, TimeUnit::SECOND, 4000000), 4);
+
+ EXPECT_EQ(convert(TimeUnit::NANO, TimeUnit::SECOND, 6000000000), 6);
+ EXPECT_EQ(convert(TimeUnit::NANO, TimeUnit::MILLI, 6000000), 6);
+ EXPECT_EQ(convert(TimeUnit::NANO, TimeUnit::MICRO, 6000), 6);
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/trie.cc b/src/arrow/cpp/src/arrow/util/trie.cc
new file mode 100644
index 000000000..7fa7f852e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/trie.cc
@@ -0,0 +1,211 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/trie.h"
+
+#include <iostream>
+#include <utility>
+
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace internal {
+
+Status Trie::Validate() const {
+ const auto n_nodes = static_cast<fast_index_type>(nodes_.size());
+ if (size_ > n_nodes) {
+ return Status::Invalid("Number of entries larger than number of nodes");
+ }
+ for (const auto& node : nodes_) {
+ if (node.found_index_ >= size_) {
+ return Status::Invalid("Found index >= size");
+ }
+ if (node.child_lookup_ != -1 &&
+ node.child_lookup_ * 256 >
+ static_cast<fast_index_type>(lookup_table_.size() - 256)) {
+ return Status::Invalid("Child lookup base doesn't point to 256 valid indices");
+ }
+ }
+ for (const auto index : lookup_table_) {
+ if (index >= n_nodes) {
+ return Status::Invalid("Child lookup index out of bounds");
+ }
+ }
+ return Status::OK();
+}
+
+void Trie::Dump(const Node* node, const std::string& indent) const {
+ std::cerr << "[\"" << node->substring_ << "\"]";
+ if (node->found_index_ >= 0) {
+ std::cerr << " *";
+ }
+ std::cerr << "\n";
+ if (node->child_lookup_ >= 0) {
+ auto child_indent = indent + " ";
+ std::cerr << child_indent << "|\n";
+ for (fast_index_type i = 0; i < 256; ++i) {
+ auto child_index = lookup_table_[node->child_lookup_ * 256 + i];
+ if (child_index >= 0) {
+ const Node* child = &nodes_[child_index];
+ std::cerr << child_indent << "|-> '" << static_cast<char>(i) << "' (" << i
+ << ") -> ";
+ Dump(child, child_indent);
+ }
+ }
+ }
+}
+
+void Trie::Dump() const { Dump(&nodes_[0], ""); }
+
+TrieBuilder::TrieBuilder() { trie_.nodes_.push_back(Trie::Node{-1, -1, ""}); }
+
+Status TrieBuilder::AppendChildNode(Trie::Node* parent, uint8_t ch, Trie::Node&& node) {
+ if (parent->child_lookup_ == -1) {
+ RETURN_NOT_OK(ExtendLookupTable(&parent->child_lookup_));
+ }
+ auto parent_lookup = parent->child_lookup_ * 256 + ch;
+
+ DCHECK_EQ(trie_.lookup_table_[parent_lookup], -1);
+ if (trie_.nodes_.size() >= static_cast<size_t>(kMaxIndex)) {
+ auto max_capacity = kMaxIndex;
+ return Status::CapacityError("TrieBuilder cannot contain more than ", max_capacity,
+ " child nodes");
+ }
+ trie_.nodes_.push_back(std::move(node));
+ trie_.lookup_table_[parent_lookup] = static_cast<index_type>(trie_.nodes_.size() - 1);
+ return Status::OK();
+}
+
+Status TrieBuilder::CreateChildNode(Trie::Node* parent, uint8_t ch,
+ util::string_view substring) {
+ const auto kMaxSubstringLength = Trie::kMaxSubstringLength;
+
+ while (substring.length() > kMaxSubstringLength) {
+ // Substring doesn't fit in node => create intermediate node
+ auto mid_node = Trie::Node{-1, -1, substring.substr(0, kMaxSubstringLength)};
+ RETURN_NOT_OK(AppendChildNode(parent, ch, std::move(mid_node)));
+ // Recurse
+ parent = &trie_.nodes_.back();
+ ch = static_cast<uint8_t>(substring[kMaxSubstringLength]);
+ substring = substring.substr(kMaxSubstringLength + 1);
+ }
+
+ // Create final matching node
+ auto child_node = Trie::Node{trie_.size_, -1, substring};
+ RETURN_NOT_OK(AppendChildNode(parent, ch, std::move(child_node)));
+ ++trie_.size_;
+ return Status::OK();
+}
+
+Status TrieBuilder::CreateChildNode(Trie::Node* parent, char ch,
+ util::string_view substring) {
+ return CreateChildNode(parent, static_cast<uint8_t>(ch), substring);
+}
+
+Status TrieBuilder::ExtendLookupTable(index_type* out_index) {
+ auto cur_size = trie_.lookup_table_.size();
+ auto cur_index = cur_size / 256;
+ if (cur_index > static_cast<size_t>(kMaxIndex)) {
+ return Status::CapacityError("TrieBuilder cannot extend lookup table further");
+ }
+ trie_.lookup_table_.resize(cur_size + 256, -1);
+ *out_index = static_cast<index_type>(cur_index);
+ return Status::OK();
+}
+
+Status TrieBuilder::SplitNode(fast_index_type node_index, fast_index_type split_at) {
+ Trie::Node* node = &trie_.nodes_[node_index];
+
+ DCHECK_LT(split_at, node->substring_length());
+
+ // Before:
+ // {node} -> [...]
+ // After:
+ // {node} -> [c] -> {out_node} -> [...]
+ auto child_node = Trie::Node{node->found_index_, node->child_lookup_,
+ node->substring_.substr(split_at + 1)};
+ auto ch = node->substring_[split_at];
+ node->child_lookup_ = -1;
+ node->found_index_ = -1;
+ node->substring_ = node->substring_.substr(0, split_at);
+ RETURN_NOT_OK(AppendChildNode(node, ch, std::move(child_node)));
+
+ return Status::OK();
+}
+
+Status TrieBuilder::Append(util::string_view s, bool allow_duplicate) {
+ // Find or create node for string
+ fast_index_type node_index = 0;
+ fast_index_type pos = 0;
+ fast_index_type remaining = static_cast<fast_index_type>(s.length());
+
+ while (true) {
+ Trie::Node* node = &trie_.nodes_[node_index];
+ const auto substring_length = node->substring_length();
+ const auto substring_data = node->substring_data();
+
+ for (fast_index_type i = 0; i < substring_length; ++i) {
+ if (remaining == 0) {
+ // New string too short => need to split node
+ RETURN_NOT_OK(SplitNode(node_index, i));
+ // Current node matches exactly
+ node = &trie_.nodes_[node_index];
+ node->found_index_ = trie_.size_++;
+ return Status::OK();
+ }
+ if (s[pos] != substring_data[i]) {
+ // Mismatching substring => need to split node
+ RETURN_NOT_OK(SplitNode(node_index, i));
+ // Create new node for mismatching char
+ node = &trie_.nodes_[node_index];
+ return CreateChildNode(node, s[pos], s.substr(pos + 1));
+ }
+ ++pos;
+ --remaining;
+ }
+ if (remaining == 0) {
+ // Node matches exactly
+ if (node->found_index_ >= 0) {
+ if (allow_duplicate) {
+ return Status::OK();
+ } else {
+ return Status::Invalid("Duplicate entry in trie");
+ }
+ }
+ node->found_index_ = trie_.size_++;
+ return Status::OK();
+ }
+ // Lookup child using next input character
+ if (node->child_lookup_ == -1) {
+ // Need to extend lookup table for this node
+ RETURN_NOT_OK(ExtendLookupTable(&node->child_lookup_));
+ }
+ auto c = static_cast<uint8_t>(s[pos++]);
+ --remaining;
+ node_index = trie_.lookup_table_[node->child_lookup_ * 256 + c];
+ if (node_index == -1) {
+ // Child not found => need to create child node
+ return CreateChildNode(node, c, s.substr(pos));
+ }
+ node = &trie_.nodes_[node_index];
+ }
+}
+
+Trie TrieBuilder::Finish() { return std::move(trie_); }
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/trie.h b/src/arrow/cpp/src/arrow/util/trie.h
new file mode 100644
index 000000000..b250cca64
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/trie.h
@@ -0,0 +1,245 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+#include <iosfwd>
+#include <limits>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/status.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace internal {
+
+// A non-zero-terminated small string class.
+// std::string usually has a small string optimization
+// (see review at https://shaharmike.com/cpp/std-string/)
+// but this one allows tight control and optimization of memory layout.
+template <uint8_t N>
+class SmallString {
+ public:
+ SmallString() : length_(0) {}
+
+ template <typename T>
+ SmallString(const T& v) { // NOLINT implicit constructor
+ *this = util::string_view(v);
+ }
+
+ SmallString& operator=(const util::string_view s) {
+#ifndef NDEBUG
+ CheckSize(s.size());
+#endif
+ length_ = static_cast<uint8_t>(s.size());
+ std::memcpy(data_, s.data(), length_);
+ return *this;
+ }
+
+ SmallString& operator=(const std::string& s) {
+ *this = util::string_view(s);
+ return *this;
+ }
+
+ SmallString& operator=(const char* s) {
+ *this = util::string_view(s);
+ return *this;
+ }
+
+ explicit operator util::string_view() const {
+ return util::string_view(data_, length_);
+ }
+
+ const char* data() const { return data_; }
+ size_t length() const { return length_; }
+ bool empty() const { return length_ == 0; }
+ char operator[](size_t pos) const {
+#ifdef NDEBUG
+ assert(pos <= length_);
+#endif
+ return data_[pos];
+ }
+
+ SmallString substr(size_t pos) const {
+ return SmallString(util::string_view(*this).substr(pos));
+ }
+
+ SmallString substr(size_t pos, size_t count) const {
+ return SmallString(util::string_view(*this).substr(pos, count));
+ }
+
+ template <typename T>
+ bool operator==(T&& other) const {
+ return util::string_view(*this) == util::string_view(std::forward<T>(other));
+ }
+
+ template <typename T>
+ bool operator!=(T&& other) const {
+ return util::string_view(*this) != util::string_view(std::forward<T>(other));
+ }
+
+ protected:
+ uint8_t length_;
+ char data_[N];
+
+ void CheckSize(size_t n) { assert(n <= N); }
+};
+
+template <uint8_t N>
+std::ostream& operator<<(std::ostream& os, const SmallString<N>& str) {
+ return os << util::string_view(str);
+}
+
+// A trie class for byte strings, optimized for small sets of short strings.
+// This class is immutable by design, use a TrieBuilder to construct it.
+class ARROW_EXPORT Trie {
+ using index_type = int16_t;
+ using fast_index_type = int_fast16_t;
+ static constexpr auto kMaxIndex = std::numeric_limits<index_type>::max();
+
+ public:
+ Trie() : size_(0) {}
+ Trie(Trie&&) = default;
+ Trie& operator=(Trie&&) = default;
+
+ int32_t Find(util::string_view s) const {
+ const Node* node = &nodes_[0];
+ fast_index_type pos = 0;
+ if (s.length() > static_cast<size_t>(kMaxIndex)) {
+ return -1;
+ }
+ fast_index_type remaining = static_cast<fast_index_type>(s.length());
+
+ while (remaining > 0) {
+ auto substring_length = node->substring_length();
+ if (substring_length > 0) {
+ auto substring_data = node->substring_data();
+ if (remaining < substring_length) {
+ // Input too short
+ return -1;
+ }
+ for (fast_index_type i = 0; i < substring_length; ++i) {
+ if (s[pos++] != substring_data[i]) {
+ // Mismatching substring
+ return -1;
+ }
+ --remaining;
+ }
+ if (remaining == 0) {
+ // Matched node exactly
+ return node->found_index_;
+ }
+ }
+ // Lookup child using next input character
+ if (node->child_lookup_ == -1) {
+ // Input too long
+ return -1;
+ }
+ auto c = static_cast<uint8_t>(s[pos++]);
+ --remaining;
+ auto child_index = lookup_table_[node->child_lookup_ * 256 + c];
+ if (child_index == -1) {
+ // Child not found
+ return -1;
+ }
+ node = &nodes_[child_index];
+ }
+
+ // Input exhausted
+ if (node->substring_.empty()) {
+ // Matched node exactly
+ return node->found_index_;
+ } else {
+ return -1;
+ }
+ }
+
+ Status Validate() const;
+
+ void Dump() const;
+
+ protected:
+ static constexpr size_t kNodeSize = 16;
+ static constexpr auto kMaxSubstringLength =
+ kNodeSize - 2 * sizeof(index_type) - sizeof(int8_t);
+
+ struct Node {
+ // If this node is a valid end of string, index of found string, otherwise -1
+ index_type found_index_;
+ // Base index for child lookup in lookup_table_ (-1 if no child nodes)
+ index_type child_lookup_;
+ // The substring for this node.
+ SmallString<kMaxSubstringLength> substring_;
+
+ fast_index_type substring_length() const {
+ return static_cast<fast_index_type>(substring_.length());
+ }
+ const char* substring_data() const { return substring_.data(); }
+ };
+
+ static_assert(sizeof(Node) == kNodeSize, "Unexpected node size");
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(Trie);
+
+ void Dump(const Node* node, const std::string& indent) const;
+
+ // Node table: entry 0 is the root node
+ std::vector<Node> nodes_;
+
+ // Indexed lookup structure: gives index in node table, or -1 if not found
+ std::vector<index_type> lookup_table_;
+
+ // Number of entries
+ index_type size_;
+
+ friend class TrieBuilder;
+};
+
+class ARROW_EXPORT TrieBuilder {
+ using index_type = Trie::index_type;
+ using fast_index_type = Trie::fast_index_type;
+
+ public:
+ TrieBuilder();
+ Status Append(util::string_view s, bool allow_duplicate = false);
+ Trie Finish();
+
+ protected:
+ // Extend the lookup table by 256 entries, return the index of the new span
+ Status ExtendLookupTable(index_type* out_lookup_index);
+ // Split the node given by the index at the substring index `split_at`
+ Status SplitNode(fast_index_type node_index, fast_index_type split_at);
+ // Append an already constructed child node to the parent
+ Status AppendChildNode(Trie::Node* parent, uint8_t ch, Trie::Node&& node);
+ // Create a matching child node from this parent
+ Status CreateChildNode(Trie::Node* parent, uint8_t ch, util::string_view substring);
+ Status CreateChildNode(Trie::Node* parent, char ch, util::string_view substring);
+
+ Trie trie_;
+
+ static constexpr auto kMaxIndex = std::numeric_limits<index_type>::max();
+};
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/trie_benchmark.cc b/src/arrow/cpp/src/arrow/util/trie_benchmark.cc
new file mode 100644
index 000000000..868accc37
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/trie_benchmark.cc
@@ -0,0 +1,222 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/trie.h"
+
+namespace arrow {
+namespace internal {
+
+std::vector<std::string> AllNulls() {
+ return {"#N/A", "#N/A N/A", "#NA", "-1.#IND", "-1.#QNAN", "-NaN", "-nan", "1.#IND",
+ "1.#QNAN", "N/A", "NA", "NULL", "NaN", "n/a", "nan", "null"};
+}
+
+Trie MakeNullsTrie() {
+ auto nulls = AllNulls();
+
+ TrieBuilder builder;
+ for (const auto& str : AllNulls()) {
+ ABORT_NOT_OK(builder.Append(str));
+ }
+ return builder.Finish();
+}
+
+std::vector<std::string> Expand(const std::vector<std::string>& base, size_t n) {
+ std::vector<std::string> result;
+ result.reserve(n);
+
+ while (true) {
+ for (const auto& v : base) {
+ result.push_back(v);
+ if (result.size() == n) {
+ return result;
+ }
+ }
+ }
+}
+
+static void BenchmarkTrieLookups(benchmark::State& state, // NOLINT non-const reference
+ const std::vector<std::string>& strings) {
+ Trie trie = MakeNullsTrie();
+ int32_t total = 0;
+
+ auto lookups = Expand(strings, 100);
+
+ for (auto _ : state) {
+ for (const auto& s : lookups) {
+ total += trie.Find(s);
+ }
+ }
+ benchmark::DoNotOptimize(total);
+ state.SetItemsProcessed(state.iterations() * lookups.size());
+}
+
+static void TrieLookupFound(benchmark::State& state) { // NOLINT non-const reference
+ BenchmarkTrieLookups(state, {"N/A", "null", "-1.#IND", "N/A"});
+}
+
+static void TrieLookupNotFound(benchmark::State& state) { // NOLINT non-const reference
+ BenchmarkTrieLookups(state, {"None", "1.0", "", "abc"});
+}
+
+BENCHMARK(TrieLookupFound);
+BENCHMARK(TrieLookupNotFound);
+
+#ifdef ARROW_WITH_BENCHMARKS_REFERENCE
+
+static inline bool InlinedNullLookup(util::string_view s) {
+ // An inlined version of trie lookup for a specific set of strings
+ // (see AllNulls())
+ auto size = s.length();
+ auto data = s.data();
+ if (size == 0) {
+ return false;
+ }
+ if (size == 1) {
+ return false;
+ }
+
+ auto chars = reinterpret_cast<const char*>(data);
+ auto first = chars[0];
+ auto second = chars[1];
+ switch (first) {
+ case 'N': {
+ // "NA", "N/A", "NaN", "NULL"
+ if (size == 2) {
+ return second == 'A';
+ }
+ auto third = chars[2];
+ if (size == 3) {
+ return (second == '/' && third == 'A') || (second == 'a' && third == 'N');
+ }
+ if (size == 4) {
+ return (second == 'U' && third == 'L' && chars[3] == 'L');
+ }
+ return false;
+ }
+ case 'n': {
+ // "n/a", "nan", "null"
+ if (size == 2) {
+ return false;
+ }
+ auto third = chars[2];
+ if (size == 3) {
+ return (second == '/' && third == 'a') || (second == 'a' && third == 'n');
+ }
+ if (size == 4) {
+ return (second == 'u' && third == 'l' && chars[3] == 'l');
+ }
+ return false;
+ }
+ case '1': {
+ // '1.#IND', '1.#QNAN'
+ if (size == 6) {
+ // '#' is the most unlikely char here, check it first
+ return (chars[2] == '#' && chars[1] == '.' && chars[3] == 'I' &&
+ chars[4] == 'N' && chars[5] == 'D');
+ }
+ if (size == 7) {
+ return (chars[2] == '#' && chars[1] == '.' && chars[3] == 'Q' &&
+ chars[4] == 'N' && chars[5] == 'A' && chars[6] == 'N');
+ }
+ return false;
+ }
+ case '-': {
+ switch (second) {
+ case 'N':
+ // "-NaN"
+ return (size == 4 && chars[2] == 'a' && chars[3] == 'N');
+ case 'n':
+ // "-nan"
+ return (size == 4 && chars[2] == 'a' && chars[3] == 'n');
+ case '1':
+ // "-1.#IND", "-1.#QNAN"
+ if (size == 7) {
+ return (chars[3] == '#' && chars[2] == '.' && chars[4] == 'I' &&
+ chars[5] == 'N' && chars[6] == 'D');
+ }
+ if (size == 8) {
+ return (chars[3] == '#' && chars[2] == '.' && chars[4] == 'Q' &&
+ chars[5] == 'N' && chars[6] == 'A' && chars[7] == 'N');
+ }
+ return false;
+ default:
+ return false;
+ }
+ }
+ case '#': {
+ // "#N/A", "#N/A N/A", "#NA"
+ if (size < 3 || chars[1] != 'N') {
+ return false;
+ }
+ auto third = chars[2];
+ if (size == 3) {
+ return third == 'A';
+ }
+ if (size == 4) {
+ return third == '/' && chars[3] == 'A';
+ }
+ if (size == 8) {
+ return std::memcmp(data + 2, "/A N/A", 5) == 0;
+ }
+ return false;
+ }
+ default:
+ return false;
+ }
+}
+
+static void BenchmarkInlinedTrieLookups(
+ benchmark::State& state, // NOLINT non-const reference
+ const std::vector<std::string>& strings) {
+ int32_t total = 0;
+
+ auto lookups = Expand(strings, 100);
+
+ for (auto _ : state) {
+ for (const auto& s : lookups) {
+ total += InlinedNullLookup(s);
+ }
+ }
+ benchmark::DoNotOptimize(total);
+ state.SetItemsProcessed(state.iterations() * lookups.size());
+}
+static void InlinedTrieLookupFound(
+ benchmark::State& state) { // NOLINT non-const reference
+ BenchmarkInlinedTrieLookups(state, {"N/A", "null", "-1.#IND", "N/A"});
+}
+
+static void InlinedTrieLookupNotFound(
+ benchmark::State& state) { // NOLINT non-const reference
+ BenchmarkInlinedTrieLookups(state, {"None", "1.0", "", "abc"});
+}
+
+BENCHMARK(InlinedTrieLookupFound);
+BENCHMARK(InlinedTrieLookupNotFound);
+
+#endif
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/trie_test.cc b/src/arrow/cpp/src/arrow/util/trie_test.cc
new file mode 100644
index 000000000..cfe66689d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/trie_test.cc
@@ -0,0 +1,305 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/trie.h"
+
+namespace arrow {
+namespace internal {
+
+TEST(SmallString, Basics) {
+ using SS = SmallString<5>;
+ {
+ SS s;
+ ASSERT_EQ(s.length(), 0);
+ ASSERT_EQ(util::string_view(s), util::string_view(""));
+ ASSERT_EQ(s, "");
+ ASSERT_NE(s, "x");
+ ASSERT_EQ(sizeof(s), 6);
+ }
+ {
+ SS s("abc");
+ ASSERT_EQ(s.length(), 3);
+ ASSERT_EQ(util::string_view(s), util::string_view("abc"));
+ ASSERT_EQ(std::memcmp(s.data(), "abc", 3), 0);
+ ASSERT_EQ(s, "abc");
+ ASSERT_NE(s, "ab");
+ }
+}
+
+TEST(SmallString, Assign) {
+ using SS = SmallString<5>;
+ auto s = SS();
+
+ s = util::string_view("abc");
+ ASSERT_EQ(s.length(), 3);
+ ASSERT_EQ(util::string_view(s), util::string_view("abc"));
+ ASSERT_EQ(std::memcmp(s.data(), "abc", 3), 0);
+ ASSERT_EQ(s, "abc");
+ ASSERT_NE(s, "ab");
+
+ s = std::string("ghijk");
+ ASSERT_EQ(s.length(), 5);
+ ASSERT_EQ(util::string_view(s), util::string_view("ghijk"));
+ ASSERT_EQ(std::memcmp(s.data(), "ghijk", 5), 0);
+ ASSERT_EQ(s, "ghijk");
+ ASSERT_NE(s, "");
+
+ s = SS("xy");
+ ASSERT_EQ(s.length(), 2);
+ ASSERT_EQ(util::string_view(s), util::string_view("xy"));
+ ASSERT_EQ(std::memcmp(s.data(), "xy", 2), 0);
+ ASSERT_EQ(s, "xy");
+ ASSERT_NE(s, "xyz");
+}
+
+TEST(SmallString, Substr) {
+ using SS = SmallString<5>;
+ {
+ auto s = SS();
+ ASSERT_EQ(s.substr(0), "");
+ ASSERT_EQ(s.substr(0, 2), "");
+ }
+ {
+ auto s = SS("abcd");
+ ASSERT_EQ(s.substr(0), "abcd");
+ ASSERT_EQ(s.substr(1), "bcd");
+ ASSERT_EQ(s.substr(4), "");
+ ASSERT_EQ(s.substr(0, 0), "");
+ ASSERT_EQ(s.substr(0, 3), "abc");
+ ASSERT_EQ(s.substr(0, 4), "abcd");
+ ASSERT_EQ(s.substr(1, 0), "");
+ ASSERT_EQ(s.substr(1, 2), "bc");
+ ASSERT_EQ(s.substr(4, 0), "");
+ ASSERT_EQ(s.substr(4, 1), "");
+ }
+}
+
+static std::vector<std::string> AllNulls() {
+ return {"#N/A", "#N/A N/A", "#NA", "-1.#IND", "-1.#QNAN", "-NaN", "-nan", "1.#IND",
+ "1.#QNAN", "N/A", "NA", "NULL", "NaN", "n/a", "nan", "null"};
+}
+
+static void TestTrieContents(const Trie& trie, const std::vector<std::string>& entries) {
+ std::unordered_map<std::string, int32_t> control;
+ auto n_entries = static_cast<int32_t>(entries.size());
+
+ // Build control container
+ for (int32_t i = 0; i < n_entries; ++i) {
+ auto p = control.insert({entries[i], i});
+ ASSERT_TRUE(p.second);
+ }
+
+ // Check all existing entries in trie
+ for (int32_t i = 0; i < n_entries; ++i) {
+ ASSERT_EQ(i, trie.Find(entries[i])) << "for string '" << entries[i] << "'";
+ }
+
+ auto CheckNotExists = [&control, &trie](const std::string& s) {
+ auto p = control.find(s);
+ if (p == control.end()) {
+ ASSERT_EQ(-1, trie.Find(s)) << "for string '" << s << "'";
+ }
+ };
+
+ // Check potentially nonexistent strings
+ CheckNotExists("");
+ CheckNotExists("X");
+ CheckNotExists("abcdefxxxxxxxxxxxxxxx");
+
+ // Check potentially nonexistent variations of existing entries
+ for (const auto& e : entries) {
+ CheckNotExists(e + "X");
+ if (e.size() > 0) {
+ CheckNotExists(e.substr(0, 1));
+ auto prefix = e.substr(0, e.size() - 1);
+ CheckNotExists(prefix);
+ CheckNotExists(prefix + "X");
+ auto split_at = e.size() / 2;
+ CheckNotExists(e.substr(0, split_at) + 'x' + e.substr(split_at + 1));
+ }
+ }
+}
+
+static void TestTrieContents(const std::vector<std::string>& entries) {
+ TrieBuilder builder;
+ for (const auto& s : entries) {
+ ASSERT_OK(builder.Append(s));
+ }
+ const Trie trie = builder.Finish();
+ ASSERT_OK(trie.Validate());
+
+ TestTrieContents(trie, entries);
+}
+
+TEST(Trie, Empty) {
+ TrieBuilder builder;
+ const Trie trie = builder.Finish();
+ ASSERT_OK(trie.Validate());
+
+ ASSERT_EQ(-1, trie.Find(""));
+ ASSERT_EQ(-1, trie.Find("x"));
+}
+
+TEST(Trie, EmptyString) {
+ TrieBuilder builder;
+ ASSERT_OK(builder.Append(""));
+ const Trie trie = builder.Finish();
+ ASSERT_OK(trie.Validate());
+
+ ASSERT_EQ(0, trie.Find(""));
+ ASSERT_EQ(-1, trie.Find("x"));
+}
+
+TEST(Trie, LongString) {
+ auto maxlen = static_cast<size_t>(std::numeric_limits<int16_t>::max());
+ // Ensure we can insert strings with length up to maxlen
+ for (auto&& length : {maxlen, maxlen - 1, maxlen / 2}) {
+ TrieBuilder builder;
+ std::string long_string(length, 'x');
+ ASSERT_OK(builder.Append(""));
+ ASSERT_OK(builder.Append(long_string));
+ const Trie trie = builder.Finish();
+ ASSERT_EQ(1, trie.Find(long_string));
+ }
+
+ // Ensure that the trie always returns false for strings with length > maxlen
+ for (auto&& length : {maxlen, maxlen - 1, maxlen / 2, maxlen + 1, maxlen * 2}) {
+ TrieBuilder builder;
+ ASSERT_OK(builder.Append(""));
+ const Trie trie = builder.Finish();
+ std::string long_string(length, 'x');
+ ASSERT_EQ(-1, trie.Find(long_string));
+ }
+}
+
+TEST(Trie, Basics1) {
+ TestTrieContents({"abc", "de", "f"});
+ TestTrieContents({"abc", "de", "f", ""});
+}
+
+TEST(Trie, Basics2) {
+ TestTrieContents({"a", "abc", "abcd", "abcdef"});
+ TestTrieContents({"", "a", "abc", "abcd", "abcdef"});
+}
+
+TEST(Trie, Basics3) {
+ TestTrieContents({"abcd", "ab", "a"});
+ TestTrieContents({"abcd", "ab", "a", ""});
+}
+
+TEST(Trie, LongStrings) {
+ TestTrieContents({"abcdefghijklmnopqr", "abcdefghijklmnoprq", "defghijklmnopqrst"});
+ TestTrieContents({"abcdefghijklmnopqr", "abcdefghijklmnoprq", "abcde"});
+}
+
+TEST(Trie, NullChars) {
+ const std::string empty;
+ const std::string nul(1, '\x00');
+ std::string a, b, c, d;
+ a = "x" + nul + "y";
+ b = "x" + nul + "z";
+ c = nul + "y";
+ d = nul;
+ ASSERT_EQ(a.length(), 3);
+ ASSERT_EQ(d.length(), 1);
+
+ TestTrieContents({a, b, c, d});
+ TestTrieContents({a, b, c});
+ TestTrieContents({a, b, c, d, ""});
+ TestTrieContents({a, b, c, ""});
+ TestTrieContents({d, c, b, a});
+ TestTrieContents({c, b, a});
+ TestTrieContents({d, c, b, a, ""});
+ TestTrieContents({c, b, a, ""});
+}
+
+TEST(Trie, NegativeChars) {
+ // Test with characters >= 0x80 (to check the absence of sign issues)
+ TestTrieContents({"\x7f\x80\x81\xff", "\x7f\x80\x81", "\x7f\xff\x81", "\xff\x80\x81"});
+}
+
+TEST(Trie, CSVNulls) { TestTrieContents(AllNulls()); }
+
+TEST(Trie, Duplicates) {
+ {
+ TrieBuilder builder;
+ ASSERT_OK(builder.Append("ab"));
+ ASSERT_OK(builder.Append("abc"));
+ ASSERT_RAISES(Invalid, builder.Append("abc"));
+ ASSERT_OK(builder.Append("abcd"));
+ ASSERT_RAISES(Invalid, builder.Append("ab"));
+ ASSERT_OK(builder.Append("abcde"));
+ const Trie trie = builder.Finish();
+
+ TestTrieContents(trie, {"ab", "abc", "abcd", "abcde"});
+ }
+ {
+ // With allow_duplicates = true
+ TrieBuilder builder;
+ ASSERT_OK(builder.Append("ab", true));
+ ASSERT_OK(builder.Append("abc", true));
+ ASSERT_OK(builder.Append("abc", true));
+ ASSERT_OK(builder.Append("abcd", true));
+ ASSERT_OK(builder.Append("ab", true));
+ ASSERT_OK(builder.Append("abcde", true));
+ const Trie trie = builder.Finish();
+
+ TestTrieContents(trie, {"ab", "abc", "abcd", "abcde"});
+ }
+}
+
+TEST(Trie, CapacityError) {
+ // A trie uses 16-bit indices into various internal structures and
+ // therefore has limited size available.
+ TrieBuilder builder;
+ uint8_t first, second, third;
+ bool had_capacity_error = false;
+ uint8_t s[] = "\x00\x00\x00\x00";
+
+ for (first = 1; first < 125; ++first) {
+ s[0] = first;
+ for (second = 1; second < 125; ++second) {
+ s[1] = second;
+ for (third = 1; third < 125; ++third) {
+ s[2] = third;
+ auto st = builder.Append(reinterpret_cast<const char*>(s));
+ if (st.IsCapacityError()) {
+ ASSERT_GE(first, 2);
+ had_capacity_error = true;
+ break;
+ } else {
+ ASSERT_OK(st);
+ }
+ }
+ }
+ }
+ ASSERT_TRUE(had_capacity_error) << "Should have produced CapacityError";
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/type_fwd.h b/src/arrow/cpp/src/arrow/util/type_fwd.h
new file mode 100644
index 000000000..ca107c2c6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/type_fwd.h
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+namespace arrow {
+
+namespace internal {
+struct Empty;
+} // namespace internal
+
+template <typename T = internal::Empty>
+class WeakFuture;
+class FutureWaiter;
+
+class TimestampParser;
+
+namespace internal {
+
+class Executor;
+class TaskGroup;
+class ThreadPool;
+
+} // namespace internal
+
+struct Compression {
+ /// \brief Compression algorithm
+ enum type {
+ UNCOMPRESSED,
+ SNAPPY,
+ GZIP,
+ BROTLI,
+ ZSTD,
+ LZ4,
+ LZ4_FRAME,
+ LZO,
+ BZ2,
+ LZ4_HADOOP
+ };
+};
+
+namespace util {
+class Compressor;
+class Decompressor;
+class Codec;
+} // namespace util
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/type_traits.h b/src/arrow/cpp/src/arrow/util/type_traits.h
new file mode 100644
index 000000000..80cc6297e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/type_traits.h
@@ -0,0 +1,86 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <type_traits>
+
+namespace arrow {
+namespace internal {
+
+/// \brief Metafunction to allow checking if a type matches any of another set of types
+template <typename...>
+struct IsOneOf : std::false_type {}; /// Base case: nothing has matched
+
+template <typename T, typename U, typename... Args>
+struct IsOneOf<T, U, Args...> {
+ /// Recursive case: T == U or T matches any other types provided (not including U).
+ static constexpr bool value = std::is_same<T, U>::value || IsOneOf<T, Args...>::value;
+};
+
+/// \brief Shorthand for using IsOneOf + std::enable_if
+template <typename T, typename... Args>
+using EnableIfIsOneOf = typename std::enable_if<IsOneOf<T, Args...>::value, T>::type;
+
+/// \brief is_null_pointer from C++17
+template <typename T>
+struct is_null_pointer : std::is_same<std::nullptr_t, typename std::remove_cv<T>::type> {
+};
+
+#ifdef __GLIBCXX__
+
+// A aligned_union backport, because old libstdc++ versions don't include it.
+
+constexpr std::size_t max_size(std::size_t a, std::size_t b) { return (a > b) ? a : b; }
+
+template <typename...>
+struct max_size_traits;
+
+template <typename H, typename... T>
+struct max_size_traits<H, T...> {
+ static constexpr std::size_t max_sizeof() {
+ return max_size(sizeof(H), max_size_traits<T...>::max_sizeof());
+ }
+ static constexpr std::size_t max_alignof() {
+ return max_size(alignof(H), max_size_traits<T...>::max_alignof());
+ }
+};
+
+template <>
+struct max_size_traits<> {
+ static constexpr std::size_t max_sizeof() { return 0; }
+ static constexpr std::size_t max_alignof() { return 0; }
+};
+
+template <std::size_t Len, typename... T>
+struct aligned_union {
+ static constexpr std::size_t alignment_value = max_size_traits<T...>::max_alignof();
+ static constexpr std::size_t size_value =
+ max_size(Len, max_size_traits<T...>::max_sizeof());
+ using type = typename std::aligned_storage<size_value, alignment_value>::type;
+};
+
+#else
+
+template <std::size_t Len, typename... T>
+using aligned_union = std::aligned_union<Len, T...>;
+
+#endif
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/ubsan.h b/src/arrow/cpp/src/arrow/util/ubsan.h
new file mode 100644
index 000000000..77c3cb8e5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/ubsan.h
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Contains utilities for making UBSan happy.
+
+#pragma once
+
+#include <cstring>
+#include <memory>
+#include <type_traits>
+
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace util {
+
+namespace internal {
+
+constexpr uint8_t kNonNullFiller = 0;
+
+} // namespace internal
+
+/// \brief Returns maybe_null if not null or a non-null pointer to an arbitrary memory
+/// that shouldn't be dereferenced.
+///
+/// Memset/Memcpy are undefined when a nullptr is passed as an argument use this utility
+/// method to wrap locations where this could happen.
+///
+/// Note: Flatbuffers has UBSan warnings if a zero length vector is passed.
+/// https://github.com/google/flatbuffers/pull/5355 is trying to resolve
+/// them.
+template <typename T>
+inline T* MakeNonNull(T* maybe_null = NULLPTR) {
+ if (ARROW_PREDICT_TRUE(maybe_null != NULLPTR)) {
+ return maybe_null;
+ }
+
+ return const_cast<T*>(reinterpret_cast<const T*>(&internal::kNonNullFiller));
+}
+
+template <typename T>
+inline typename std::enable_if<std::is_trivial<T>::value, T>::type SafeLoadAs(
+ const uint8_t* unaligned) {
+ typename std::remove_const<T>::type ret;
+ std::memcpy(&ret, unaligned, sizeof(T));
+ return ret;
+}
+
+template <typename T>
+inline typename std::enable_if<std::is_trivial<T>::value, T>::type SafeLoad(
+ const T* unaligned) {
+ typename std::remove_const<T>::type ret;
+ std::memcpy(&ret, unaligned, sizeof(T));
+ return ret;
+}
+
+template <typename U, typename T>
+inline typename std::enable_if<std::is_trivial<T>::value && std::is_trivial<U>::value &&
+ sizeof(T) == sizeof(U),
+ U>::type
+SafeCopy(T value) {
+ typename std::remove_const<U>::type ret;
+ std::memcpy(&ret, &value, sizeof(T));
+ return ret;
+}
+
+template <typename T>
+inline typename std::enable_if<std::is_trivial<T>::value, void>::type SafeStore(
+ void* unaligned, T value) {
+ std::memcpy(unaligned, &value, sizeof(T));
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/unreachable.cc b/src/arrow/cpp/src/arrow/util/unreachable.cc
new file mode 100644
index 000000000..4ffe3a8f7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/unreachable.cc
@@ -0,0 +1,29 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/unreachable.h"
+
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+[[noreturn]] void Unreachable(const char* message) {
+ DCHECK(false) << message;
+ std::abort();
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/unreachable.h b/src/arrow/cpp/src/arrow/util/unreachable.h
new file mode 100644
index 000000000..552635981
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/unreachable.h
@@ -0,0 +1,24 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+namespace arrow {
+
+[[noreturn]] void Unreachable(const char* message = "Unreachable");
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/uri.cc b/src/arrow/cpp/src/arrow/util/uri.cc
new file mode 100644
index 000000000..35c6b8981
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/uri.cc
@@ -0,0 +1,292 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/uri.h"
+
+#include <cstring>
+#include <sstream>
+#include <vector>
+
+#include "arrow/util/string_view.h"
+#include "arrow/util/value_parsing.h"
+#include "arrow/vendored/uriparser/Uri.h"
+
+namespace arrow {
+namespace internal {
+
+namespace {
+
+util::string_view TextRangeToView(const UriTextRangeStructA& range) {
+ if (range.first == nullptr) {
+ return "";
+ } else {
+ return {range.first, static_cast<size_t>(range.afterLast - range.first)};
+ }
+}
+
+std::string TextRangeToString(const UriTextRangeStructA& range) {
+ return std::string(TextRangeToView(range));
+}
+
+// There can be a difference between an absent field and an empty field.
+// For example, in "unix:/tmp/foo", the host is absent, while in
+// "unix:///tmp/foo", the host is empty but present.
+// This function helps distinguish.
+bool IsTextRangeSet(const UriTextRangeStructA& range) { return range.first != nullptr; }
+
+#ifdef _WIN32
+bool IsDriveSpec(const util::string_view s) {
+ return (s.length() >= 2 && s[1] == ':' &&
+ ((s[0] >= 'A' && s[0] <= 'Z') || (s[0] >= 'a' && s[0] <= 'z')));
+}
+#endif
+
+} // namespace
+
+std::string UriEscape(const std::string& s) {
+ if (s.empty()) {
+ // Avoid passing null pointer to uriEscapeExA
+ return s;
+ }
+ std::string escaped;
+ escaped.resize(3 * s.length());
+
+ auto end = uriEscapeExA(s.data(), s.data() + s.length(), &escaped[0],
+ /*spaceToPlus=*/URI_FALSE, /*normalizeBreaks=*/URI_FALSE);
+ escaped.resize(end - &escaped[0]);
+ return escaped;
+}
+
+std::string UriUnescape(const util::string_view s) {
+ std::string result(s);
+ if (!result.empty()) {
+ auto end = uriUnescapeInPlaceA(&result[0]);
+ result.resize(end - &result[0]);
+ }
+ return result;
+}
+
+std::string UriEncodeHost(const std::string& host) {
+ // Fairly naive check: if it contains a ':', it's IPv6 and needs
+ // brackets, else it's OK
+ if (host.find(":") != std::string::npos) {
+ std::string result = "[";
+ result += host;
+ result += ']';
+ return result;
+ } else {
+ return host;
+ }
+}
+
+struct Uri::Impl {
+ Impl() : string_rep_(""), port_(-1) { memset(&uri_, 0, sizeof(uri_)); }
+
+ ~Impl() { uriFreeUriMembersA(&uri_); }
+
+ void Reset() {
+ uriFreeUriMembersA(&uri_);
+ memset(&uri_, 0, sizeof(uri_));
+ data_.clear();
+ string_rep_.clear();
+ path_segments_.clear();
+ port_ = -1;
+ }
+
+ const std::string& KeepString(const std::string& s) {
+ data_.push_back(s);
+ return data_.back();
+ }
+
+ UriUriA uri_;
+ // Keep alive strings that uriparser stores pointers to
+ std::vector<std::string> data_;
+ std::string string_rep_;
+ int32_t port_;
+ std::vector<util::string_view> path_segments_;
+ bool is_file_uri_;
+ bool is_absolute_path_;
+};
+
+Uri::Uri() : impl_(new Impl) {}
+
+Uri::~Uri() {}
+
+Uri::Uri(Uri&& u) : impl_(std::move(u.impl_)) {}
+
+Uri& Uri::operator=(Uri&& u) {
+ impl_ = std::move(u.impl_);
+ return *this;
+}
+
+std::string Uri::scheme() const { return TextRangeToString(impl_->uri_.scheme); }
+
+std::string Uri::host() const { return TextRangeToString(impl_->uri_.hostText); }
+
+bool Uri::has_host() const { return IsTextRangeSet(impl_->uri_.hostText); }
+
+std::string Uri::port_text() const { return TextRangeToString(impl_->uri_.portText); }
+
+int32_t Uri::port() const { return impl_->port_; }
+
+std::string Uri::username() const {
+ auto userpass = TextRangeToView(impl_->uri_.userInfo);
+ auto sep_pos = userpass.find_first_of(':');
+ if (sep_pos == util::string_view::npos) {
+ return UriUnescape(userpass);
+ } else {
+ return UriUnescape(userpass.substr(0, sep_pos));
+ }
+}
+
+std::string Uri::password() const {
+ auto userpass = TextRangeToView(impl_->uri_.userInfo);
+ auto sep_pos = userpass.find_first_of(':');
+ if (sep_pos == util::string_view::npos) {
+ return std::string();
+ } else {
+ return UriUnescape(userpass.substr(sep_pos + 1));
+ }
+}
+
+std::string Uri::path() const {
+ const auto& segments = impl_->path_segments_;
+
+ bool must_prepend_slash = impl_->is_absolute_path_;
+#ifdef _WIN32
+ // On Windows, "file:///C:/foo" should have path "C:/foo", not "/C:/foo",
+ // despite it being absolute.
+ // (see https://tools.ietf.org/html/rfc8089#page-13)
+ if (impl_->is_absolute_path_ && impl_->is_file_uri_ && segments.size() > 0 &&
+ IsDriveSpec(segments[0])) {
+ must_prepend_slash = false;
+ }
+#endif
+
+ std::stringstream ss;
+ if (must_prepend_slash) {
+ ss << "/";
+ }
+ bool first = true;
+ for (const auto& seg : segments) {
+ if (!first) {
+ ss << "/";
+ }
+ first = false;
+ ss << seg;
+ }
+ return std::move(ss).str();
+}
+
+std::string Uri::query_string() const { return TextRangeToString(impl_->uri_.query); }
+
+Result<std::vector<std::pair<std::string, std::string>>> Uri::query_items() const {
+ const auto& query = impl_->uri_.query;
+ UriQueryListA* query_list;
+ int item_count;
+ std::vector<std::pair<std::string, std::string>> items;
+
+ if (query.first == nullptr) {
+ return items;
+ }
+ if (uriDissectQueryMallocA(&query_list, &item_count, query.first, query.afterLast) !=
+ URI_SUCCESS) {
+ return Status::Invalid("Cannot parse query string: '", query_string(), "'");
+ }
+ std::unique_ptr<UriQueryListA, decltype(&uriFreeQueryListA)> query_guard(
+ query_list, uriFreeQueryListA);
+
+ items.reserve(item_count);
+ while (query_list != nullptr) {
+ if (query_list->value != nullptr) {
+ items.emplace_back(query_list->key, query_list->value);
+ } else {
+ items.emplace_back(query_list->key, "");
+ }
+ query_list = query_list->next;
+ }
+ return items;
+}
+
+const std::string& Uri::ToString() const { return impl_->string_rep_; }
+
+Status Uri::Parse(const std::string& uri_string) {
+ impl_->Reset();
+
+ const auto& s = impl_->KeepString(uri_string);
+ impl_->string_rep_ = s;
+ const char* error_pos;
+ if (uriParseSingleUriExA(&impl_->uri_, s.data(), s.data() + s.size(), &error_pos) !=
+ URI_SUCCESS) {
+ return Status::Invalid("Cannot parse URI: '", uri_string, "'");
+ }
+
+ const auto scheme = TextRangeToView(impl_->uri_.scheme);
+ if (scheme.empty()) {
+ return Status::Invalid("URI has empty scheme: '", uri_string, "'");
+ }
+ impl_->is_file_uri_ = (scheme == "file");
+
+ // Gather path segments
+ auto path_seg = impl_->uri_.pathHead;
+ while (path_seg != nullptr) {
+ impl_->path_segments_.push_back(TextRangeToView(path_seg->text));
+ path_seg = path_seg->next;
+ }
+
+ // Decide whether URI path is absolute
+ impl_->is_absolute_path_ = false;
+ if (impl_->uri_.absolutePath == URI_TRUE) {
+ impl_->is_absolute_path_ = true;
+ } else if (has_host() && impl_->path_segments_.size() > 0) {
+ // When there's a host (even empty), uriparser considers the path relative.
+ // Several URI parsers for Python all consider it absolute, though.
+ // For example, the path for "file:///tmp/foo" is "/tmp/foo", not "tmp/foo".
+ // Similarly, the path for "file://localhost/" is "/".
+ // However, the path for "file://localhost" is "".
+ impl_->is_absolute_path_ = true;
+ }
+#ifdef _WIN32
+ // There's an exception on Windows: "file:/C:foo/bar" is relative.
+ if (impl_->is_file_uri_ && impl_->path_segments_.size() > 0) {
+ const auto& first_seg = impl_->path_segments_[0];
+ if (IsDriveSpec(first_seg) && (first_seg.length() >= 3 && first_seg[2] != '/')) {
+ impl_->is_absolute_path_ = false;
+ }
+ }
+#endif
+
+ if (impl_->is_file_uri_ && !impl_->is_absolute_path_) {
+ return Status::Invalid("File URI cannot be relative: '", uri_string, "'");
+ }
+
+ // Parse port number
+ auto port_text = TextRangeToView(impl_->uri_.portText);
+ if (port_text.size()) {
+ uint16_t port_num;
+ if (!ParseValue<UInt16Type>(port_text.data(), port_text.size(), &port_num)) {
+ return Status::Invalid("Invalid port number '", port_text, "' in URI '", uri_string,
+ "'");
+ }
+ impl_->port_ = port_num;
+ }
+
+ return Status::OK();
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/uri.h b/src/arrow/cpp/src/arrow/util/uri.h
new file mode 100644
index 000000000..b4ffbb04d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/uri.h
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/type_fwd.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace internal {
+
+/// \brief A parsed URI
+class ARROW_EXPORT Uri {
+ public:
+ Uri();
+ ~Uri();
+ Uri(Uri&&);
+ Uri& operator=(Uri&&);
+
+ // XXX Should we use util::string_view instead? These functions are
+ // not performance-critical.
+
+ /// The URI scheme, such as "http", or the empty string if the URI has no
+ /// explicit scheme.
+ std::string scheme() const;
+
+ /// Whether the URI has an explicit host name. This may return true if
+ /// the URI has an empty host (e.g. "file:///tmp/foo"), while it returns
+ /// false is the URI has not host component at all (e.g. "file:/tmp/foo").
+ bool has_host() const;
+ /// The URI host name, such as "localhost", "127.0.0.1" or "::1", or the empty
+ /// string is the URI does not have a host component.
+ std::string host() const;
+
+ /// The URI port number, as a string such as "80", or the empty string is the URI
+ /// does not have a port number component.
+ std::string port_text() const;
+ /// The URI port parsed as an integer, or -1 if the URI does not have a port
+ /// number component.
+ int32_t port() const;
+
+ /// The username specified in the URI.
+ std::string username() const;
+ /// The password specified in the URI.
+ std::string password() const;
+
+ /// The URI path component.
+ std::string path() const;
+
+ /// The URI query string
+ std::string query_string() const;
+
+ /// The URI query items
+ ///
+ /// Note this API doesn't allow differentiating between an empty value
+ /// and a missing value, such in "a&b=1" vs. "a=&b=1".
+ Result<std::vector<std::pair<std::string, std::string>>> query_items() const;
+
+ /// Get the string representation of this URI.
+ const std::string& ToString() const;
+
+ /// Factory function to parse a URI from its string representation.
+ Status Parse(const std::string& uri_string);
+
+ private:
+ struct Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+/// Percent-encode the input string, for use e.g. as a URI query parameter.
+ARROW_EXPORT
+std::string UriEscape(const std::string& s);
+
+ARROW_EXPORT
+std::string UriUnescape(const arrow::util::string_view s);
+
+/// Encode a host for use within a URI, such as "localhost",
+/// "127.0.0.1", or "[::1]".
+ARROW_EXPORT
+std::string UriEncodeHost(const std::string& host);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/uri_test.cc b/src/arrow/cpp/src/arrow/util/uri_test.cc
new file mode 100644
index 000000000..169e9c81b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/uri_test.cc
@@ -0,0 +1,312 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+namespace internal {
+
+TEST(UriEscape, Basics) {
+ ASSERT_EQ(UriEscape(""), "");
+ ASSERT_EQ(UriEscape("foo123"), "foo123");
+ ASSERT_EQ(UriEscape("/El Niño/"), "%2FEl%20Ni%C3%B1o%2F");
+}
+
+TEST(UriEncodeHost, Basics) {
+ ASSERT_EQ(UriEncodeHost("::1"), "[::1]");
+ ASSERT_EQ(UriEscape("arrow.apache.org"), "arrow.apache.org");
+ ASSERT_EQ(UriEscape("192.168.1.1"), "192.168.1.1");
+}
+
+TEST(Uri, Empty) {
+ Uri uri;
+ ASSERT_EQ(uri.scheme(), "");
+}
+
+TEST(Uri, ParseSimple) {
+ Uri uri;
+ {
+ // An ephemeral string object shouldn't invalidate results
+ std::string s = "https://arrow.apache.org";
+ ASSERT_OK(uri.Parse(s));
+ s.replace(0, s.size(), s.size(), 'X'); // replace contents
+ }
+ ASSERT_EQ(uri.scheme(), "https");
+ ASSERT_EQ(uri.host(), "arrow.apache.org");
+ ASSERT_EQ(uri.port_text(), "");
+}
+
+TEST(Uri, ParsePath) {
+ // The various edge cases below (leading and trailing slashes) have been
+ // checked against several Python URI parsing modules: `uri`, `rfc3986`, `rfc3987`
+
+ Uri uri;
+
+ auto check_case = [&](std::string uri_string, std::string scheme, bool has_host,
+ std::string host, std::string path) -> void {
+ ASSERT_OK(uri.Parse(uri_string));
+ ASSERT_EQ(uri.scheme(), scheme);
+ ASSERT_EQ(uri.has_host(), has_host);
+ ASSERT_EQ(uri.host(), host);
+ ASSERT_EQ(uri.path(), path);
+ };
+
+ // Relative path
+ check_case("unix:tmp/flight.sock", "unix", false, "", "tmp/flight.sock");
+
+ // Absolute path
+ check_case("unix:/tmp/flight.sock", "unix", false, "", "/tmp/flight.sock");
+ check_case("unix://localhost/tmp/flight.sock", "unix", true, "localhost",
+ "/tmp/flight.sock");
+ check_case("unix:///tmp/flight.sock", "unix", true, "", "/tmp/flight.sock");
+
+ // Empty path
+ check_case("unix:", "unix", false, "", "");
+ check_case("unix://localhost", "unix", true, "localhost", "");
+
+ // With trailing slash
+ check_case("unix:/", "unix", false, "", "/");
+ check_case("unix:tmp/", "unix", false, "", "tmp/");
+ check_case("unix://localhost/", "unix", true, "localhost", "/");
+ check_case("unix:/tmp/flight/", "unix", false, "", "/tmp/flight/");
+ check_case("unix://localhost/tmp/flight/", "unix", true, "localhost", "/tmp/flight/");
+ check_case("unix:///tmp/flight/", "unix", true, "", "/tmp/flight/");
+
+ // With query string
+ check_case("unix:?", "unix", false, "", "");
+ check_case("unix:?foo", "unix", false, "", "");
+ check_case("unix:?foo=bar", "unix", false, "", "");
+ check_case("unix:/?", "unix", false, "", "/");
+ check_case("unix:/?foo", "unix", false, "", "/");
+ check_case("unix:/?foo=bar", "unix", false, "", "/");
+ check_case("unix://localhost/tmp?", "unix", true, "localhost", "/tmp");
+ check_case("unix://localhost/tmp?foo", "unix", true, "localhost", "/tmp");
+ check_case("unix://localhost/tmp?foo=bar", "unix", true, "localhost", "/tmp");
+}
+
+TEST(Uri, ParseQuery) {
+ Uri uri;
+
+ auto check_case = [&](std::string uri_string, std::string query_string,
+ std::vector<std::pair<std::string, std::string>> items) -> void {
+ ASSERT_OK(uri.Parse(uri_string));
+ ASSERT_EQ(uri.query_string(), query_string);
+ auto result = uri.query_items();
+ ASSERT_OK(result);
+ ASSERT_EQ(*result, items);
+ };
+
+ check_case("unix://localhost/tmp", "", {});
+ check_case("unix://localhost/tmp?", "", {});
+ check_case("unix://localhost/tmp?foo=bar", "foo=bar", {{"foo", "bar"}});
+ check_case("unix:?foo=bar", "foo=bar", {{"foo", "bar"}});
+ check_case("unix:?a=b&c=d", "a=b&c=d", {{"a", "b"}, {"c", "d"}});
+
+ // With escaped values
+ check_case("unix:?a=some+value&b=c", "a=some+value&b=c",
+ {{"a", "some value"}, {"b", "c"}});
+ check_case("unix:?a=some%20value%2Fanother&b=c", "a=some%20value%2Fanother&b=c",
+ {{"a", "some value/another"}, {"b", "c"}});
+}
+
+TEST(Uri, ParseHostPort) {
+ Uri uri;
+
+ ASSERT_OK(uri.Parse("http://localhost:80"));
+ ASSERT_EQ(uri.scheme(), "http");
+ ASSERT_EQ(uri.host(), "localhost");
+ ASSERT_EQ(uri.port_text(), "80");
+ ASSERT_EQ(uri.port(), 80);
+ ASSERT_EQ(uri.username(), "");
+ ASSERT_EQ(uri.password(), "");
+
+ ASSERT_OK(uri.Parse("http://1.2.3.4"));
+ ASSERT_EQ(uri.scheme(), "http");
+ ASSERT_EQ(uri.host(), "1.2.3.4");
+ ASSERT_EQ(uri.port_text(), "");
+ ASSERT_EQ(uri.port(), -1);
+ ASSERT_EQ(uri.username(), "");
+ ASSERT_EQ(uri.password(), "");
+
+ ASSERT_OK(uri.Parse("http://1.2.3.4:"));
+ ASSERT_EQ(uri.scheme(), "http");
+ ASSERT_EQ(uri.host(), "1.2.3.4");
+ ASSERT_EQ(uri.port_text(), "");
+ ASSERT_EQ(uri.port(), -1);
+ ASSERT_EQ(uri.username(), "");
+ ASSERT_EQ(uri.password(), "");
+
+ ASSERT_OK(uri.Parse("http://1.2.3.4:80"));
+ ASSERT_EQ(uri.scheme(), "http");
+ ASSERT_EQ(uri.host(), "1.2.3.4");
+ ASSERT_EQ(uri.port_text(), "80");
+ ASSERT_EQ(uri.port(), 80);
+ ASSERT_EQ(uri.username(), "");
+ ASSERT_EQ(uri.password(), "");
+
+ ASSERT_OK(uri.Parse("http://[::1]"));
+ ASSERT_EQ(uri.scheme(), "http");
+ ASSERT_EQ(uri.host(), "::1");
+ ASSERT_EQ(uri.port_text(), "");
+ ASSERT_EQ(uri.port(), -1);
+ ASSERT_EQ(uri.username(), "");
+ ASSERT_EQ(uri.password(), "");
+
+ ASSERT_OK(uri.Parse("http://[::1]:"));
+ ASSERT_EQ(uri.scheme(), "http");
+ ASSERT_EQ(uri.host(), "::1");
+ ASSERT_EQ(uri.port_text(), "");
+ ASSERT_EQ(uri.port(), -1);
+ ASSERT_EQ(uri.username(), "");
+ ASSERT_EQ(uri.password(), "");
+
+ ASSERT_OK(uri.Parse("http://[::1]:80"));
+ ASSERT_EQ(uri.scheme(), "http");
+ ASSERT_EQ(uri.host(), "::1");
+ ASSERT_EQ(uri.port_text(), "80");
+ ASSERT_EQ(uri.port(), 80);
+ ASSERT_EQ(uri.username(), "");
+ ASSERT_EQ(uri.password(), "");
+}
+
+TEST(Uri, ParseUserPass) {
+ Uri uri;
+
+ ASSERT_OK(uri.Parse("http://someuser@localhost:80"));
+ ASSERT_EQ(uri.scheme(), "http");
+ ASSERT_EQ(uri.host(), "localhost");
+ ASSERT_EQ(uri.username(), "someuser");
+ ASSERT_EQ(uri.password(), "");
+
+ ASSERT_OK(uri.Parse("http://someuser:@localhost:80"));
+ ASSERT_EQ(uri.scheme(), "http");
+ ASSERT_EQ(uri.host(), "localhost");
+ ASSERT_EQ(uri.username(), "someuser");
+ ASSERT_EQ(uri.password(), "");
+
+ ASSERT_OK(uri.Parse("http://someuser:somepass@localhost:80"));
+ ASSERT_EQ(uri.scheme(), "http");
+ ASSERT_EQ(uri.host(), "localhost");
+ ASSERT_EQ(uri.username(), "someuser");
+ ASSERT_EQ(uri.password(), "somepass");
+
+ ASSERT_OK(uri.Parse("http://someuser:somepass@localhost"));
+ ASSERT_EQ(uri.scheme(), "http");
+ ASSERT_EQ(uri.host(), "localhost");
+ ASSERT_EQ(uri.username(), "someuser");
+ ASSERT_EQ(uri.password(), "somepass");
+
+ // With %-encoding
+ ASSERT_OK(uri.Parse("http://some%20user%2Fname:somepass@localhost"));
+ ASSERT_EQ(uri.scheme(), "http");
+ ASSERT_EQ(uri.host(), "localhost");
+ ASSERT_EQ(uri.username(), "some user/name");
+ ASSERT_EQ(uri.password(), "somepass");
+
+ ASSERT_OK(uri.Parse("http://some%20user%2Fname:some%20pass%2Fword@localhost"));
+ ASSERT_EQ(uri.scheme(), "http");
+ ASSERT_EQ(uri.host(), "localhost");
+ ASSERT_EQ(uri.username(), "some user/name");
+ ASSERT_EQ(uri.password(), "some pass/word");
+}
+
+TEST(Uri, FileScheme) {
+ // "file" scheme URIs
+ // https://en.wikipedia.org/wiki/File_URI_scheme
+ // https://tools.ietf.org/html/rfc8089
+ Uri uri;
+
+ auto check_no_host = [&](std::string uri_string, std::string path) -> void {
+ ASSERT_OK(uri.Parse(uri_string));
+ ASSERT_EQ(uri.scheme(), "file");
+ ASSERT_EQ(uri.host(), "");
+ ASSERT_EQ(uri.path(), path);
+ ASSERT_EQ(uri.username(), "");
+ ASSERT_EQ(uri.password(), "");
+ };
+
+ auto check_with_host = [&](std::string uri_string, std::string host,
+ std::string path) -> void {
+ ASSERT_OK(uri.Parse(uri_string));
+ ASSERT_EQ(uri.scheme(), "file");
+ ASSERT_EQ(uri.host(), host);
+ ASSERT_EQ(uri.path(), path);
+ ASSERT_EQ(uri.username(), "");
+ ASSERT_EQ(uri.password(), "");
+ };
+
+ // Relative paths are not accepted in "file" URIs.
+ ASSERT_RAISES(Invalid, uri.Parse("file:"));
+ ASSERT_RAISES(Invalid, uri.Parse("file:foo/bar"));
+
+ // Absolute paths
+ // (no authority)
+ check_no_host("file:/", "/");
+ check_no_host("file:/foo/bar", "/foo/bar");
+ // (empty authority)
+ check_no_host("file:///", "/");
+ check_no_host("file:///foo/bar", "/foo/bar");
+ // (non-empty authority)
+ check_with_host("file://localhost/", "localhost", "/");
+ check_with_host("file://localhost/foo/bar", "localhost", "/foo/bar");
+ check_with_host("file://hostname.com/", "hostname.com", "/");
+ check_with_host("file://hostname.com/foo/bar", "hostname.com", "/foo/bar");
+
+#ifdef _WIN32
+ // Relative paths
+ ASSERT_RAISES(Invalid, uri.Parse("file:/C:foo/bar"));
+ // (NOTE: "file:/C:" is currently parsed as an absolute URI pointing to "C:/")
+
+ // Absolute paths
+ // (no authority)
+ check_no_host("file:/C:/", "C:/");
+ check_no_host("file:/C:/foo/bar", "C:/foo/bar");
+ // (empty authority)
+ check_no_host("file:///C:/", "C:/");
+ check_no_host("file:///C:/foo/bar", "C:/foo/bar");
+ // (non-empty authority)
+ check_with_host("file://server/share/", "server", "/share/");
+ check_with_host("file://server/share/foo/bar", "server", "/share/foo/bar");
+#endif
+}
+
+TEST(Uri, ParseError) {
+ Uri uri;
+
+ ASSERT_RAISES(Invalid, uri.Parse("http://a:b:c:d"));
+ ASSERT_RAISES(Invalid, uri.Parse("http://localhost:z"));
+ ASSERT_RAISES(Invalid, uri.Parse("http://localhost:-1"));
+ ASSERT_RAISES(Invalid, uri.Parse("http://localhost:99999"));
+
+ // Scheme-less URIs (forbidden by RFC 3986, and ambiguous to parse)
+ ASSERT_RAISES(Invalid, uri.Parse("localhost"));
+ ASSERT_RAISES(Invalid, uri.Parse("/foo/bar"));
+ ASSERT_RAISES(Invalid, uri.Parse("foo/bar"));
+ ASSERT_RAISES(Invalid, uri.Parse(""));
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/utf8.cc b/src/arrow/cpp/src/arrow/util/utf8.cc
new file mode 100644
index 000000000..11394d2e6
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/utf8.cc
@@ -0,0 +1,160 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <iterator>
+#include <mutex>
+#include <stdexcept>
+#include <utility>
+
+#include "arrow/result.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/utf8.h"
+#include "arrow/vendored/utfcpp/checked.h"
+
+// Can be defined by utfcpp
+#ifdef NOEXCEPT
+#undef NOEXCEPT
+#endif
+
+namespace arrow {
+namespace util {
+namespace internal {
+
+// Copyright (c) 2008-2010 Bjoern Hoehrmann <bjoern@hoehrmann.de>
+// See http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ for details.
+
+// clang-format off
+const uint8_t utf8_small_table[] = { // NOLINT
+ // The first part of the table maps bytes to character classes that
+ // to reduce the size of the transition table and create bitmasks.
+ 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // NOLINT
+ 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // NOLINT
+ 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // NOLINT
+ 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // NOLINT
+ 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, // NOLINT
+ 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, // NOLINT
+ 8,8,2,2,2,2,2,2,2,2,2,2,2,2,2,2, 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, // NOLINT
+ 10,3,3,3,3,3,3,3,3,3,3,3,3,4,3,3, 11,6,6,6,5,8,8,8,8,8,8,8,8,8,8,8, // NOLINT
+
+ // The second part is a transition table that maps a combination
+ // of a state of the automaton and a character class to a state.
+ // Character classes are between 0 and 11, states are multiples of 12.
+ 0,12,24,36,60,96,84,12,12,12,48,72, 12,12,12,12,12,12,12,12,12,12,12,12, // NOLINT
+ 12, 0,12,12,12,12,12, 0,12, 0,12,12, 12,24,12,12,12,12,12,24,12,24,12,12, // NOLINT
+ 12,12,12,12,12,12,12,24,12,12,12,12, 12,24,12,12,12,12,12,12,12,24,12,12, // NOLINT
+ 12,12,12,12,12,12,12,36,12,36,12,12, 12,36,12,12,12,12,12,36,12,36,12,12, // NOLINT
+ 12,36,12,12,12,12,12,12,12,12,12,12, // NOLINT
+};
+// clang-format on
+
+uint16_t utf8_large_table[9 * 256] = {0xffff};
+
+const uint8_t utf8_byte_size_table[16] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
+
+static void InitializeLargeTable() {
+ for (uint32_t state = 0; state < 9; ++state) {
+ for (uint32_t byte = 0; byte < 256; ++byte) {
+ uint32_t byte_class = utf8_small_table[byte];
+ uint8_t next_state = utf8_small_table[256 + state * 12 + byte_class] / 12;
+ DCHECK_LT(next_state, 9);
+ utf8_large_table[state * 256 + byte] = static_cast<uint16_t>(next_state * 256);
+ }
+ }
+}
+
+ARROW_EXPORT void CheckUTF8Initialized() {
+ DCHECK_EQ(utf8_large_table[0], 0)
+ << "InitializeUTF8() must be called before calling UTF8 routines";
+}
+
+} // namespace internal
+
+static std::once_flag utf8_initialized;
+
+void InitializeUTF8() {
+ std::call_once(utf8_initialized, internal::InitializeLargeTable);
+}
+
+static const uint8_t kBOM[] = {0xEF, 0xBB, 0xBF};
+
+Result<const uint8_t*> SkipUTF8BOM(const uint8_t* data, int64_t size) {
+ int64_t i;
+ for (i = 0; i < static_cast<int64_t>(sizeof(kBOM)); ++i) {
+ if (size == 0) {
+ if (i == 0) {
+ // Empty string
+ return data;
+ } else {
+ return Status::Invalid("UTF8 string too short (truncated byte order mark?)");
+ }
+ }
+ if (data[i] != kBOM[i]) {
+ // BOM not found
+ return data;
+ }
+ --size;
+ }
+ // BOM found
+ return data + i;
+}
+
+namespace {
+
+// Some platforms (such as old MinGWs) don't have the <codecvt> header,
+// so call into a vendored utf8 implementation instead.
+
+std::wstring UTF8ToWideStringInternal(const std::string& source) {
+ std::wstring ws;
+#if WCHAR_MAX > 0xFFFF
+ ::utf8::utf8to32(source.begin(), source.end(), std::back_inserter(ws));
+#else
+ ::utf8::utf8to16(source.begin(), source.end(), std::back_inserter(ws));
+#endif
+ return ws;
+}
+
+std::string WideStringToUTF8Internal(const std::wstring& source) {
+ std::string s;
+#if WCHAR_MAX > 0xFFFF
+ ::utf8::utf32to8(source.begin(), source.end(), std::back_inserter(s));
+#else
+ ::utf8::utf16to8(source.begin(), source.end(), std::back_inserter(s));
+#endif
+ return s;
+}
+
+} // namespace
+
+Result<std::wstring> UTF8ToWideString(const std::string& source) {
+ try {
+ return UTF8ToWideStringInternal(source);
+ } catch (std::exception& e) {
+ return Status::Invalid(e.what());
+ }
+}
+
+ARROW_EXPORT Result<std::string> WideStringToUTF8(const std::wstring& source) {
+ try {
+ return WideStringToUTF8Internal(source);
+ } catch (std::exception& e) {
+ return Status::Invalid(e.what());
+ }
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/utf8.h b/src/arrow/cpp/src/arrow/util/utf8.h
new file mode 100644
index 000000000..45cdcd833
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/utf8.h
@@ -0,0 +1,566 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+
+#if defined(ARROW_HAVE_NEON) || defined(ARROW_HAVE_SSE4_2)
+#include <xsimd/xsimd.hpp>
+#endif
+
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/simd.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/ubsan.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace util {
+
+// Convert a UTF8 string to a wstring (either UTF16 or UTF32, depending
+// on the wchar_t width).
+ARROW_EXPORT Result<std::wstring> UTF8ToWideString(const std::string& source);
+
+// Similarly, convert a wstring to a UTF8 string.
+ARROW_EXPORT Result<std::string> WideStringToUTF8(const std::wstring& source);
+
+namespace internal {
+
+// Copyright (c) 2008-2010 Bjoern Hoehrmann <bjoern@hoehrmann.de>
+// See http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ for details.
+
+// A compact state table allowing UTF8 decoding using two dependent
+// lookups per byte. The first lookup determines the character class
+// and the second lookup reads the next state.
+// In this table states are multiples of 12.
+ARROW_EXPORT extern const uint8_t utf8_small_table[256 + 9 * 12];
+
+// Success / reject states when looked up in the small table
+static constexpr uint8_t kUTF8DecodeAccept = 0;
+static constexpr uint8_t kUTF8DecodeReject = 12;
+
+// An expanded state table allowing transitions using a single lookup
+// at the expense of a larger memory footprint (but on non-random data,
+// not all the table will end up accessed and cached).
+// In this table states are multiples of 256.
+ARROW_EXPORT extern uint16_t utf8_large_table[9 * 256];
+
+ARROW_EXPORT extern const uint8_t utf8_byte_size_table[16];
+
+// Success / reject states when looked up in the large table
+static constexpr uint16_t kUTF8ValidateAccept = 0;
+static constexpr uint16_t kUTF8ValidateReject = 256;
+
+static inline uint8_t DecodeOneUTF8Byte(uint8_t byte, uint8_t state, uint32_t* codep) {
+ uint8_t type = utf8_small_table[byte];
+
+ *codep = (state != kUTF8DecodeAccept) ? (byte & 0x3fu) | (*codep << 6)
+ : (0xff >> type) & (byte);
+
+ state = utf8_small_table[256 + state + type];
+ return state;
+}
+
+static inline uint16_t ValidateOneUTF8Byte(uint8_t byte, uint16_t state) {
+ return utf8_large_table[state + byte];
+}
+
+ARROW_EXPORT void CheckUTF8Initialized();
+
+} // namespace internal
+
+// This function needs to be called before doing UTF8 validation.
+ARROW_EXPORT void InitializeUTF8();
+
+static inline bool ValidateUTF8(const uint8_t* data, int64_t size) {
+ static constexpr uint64_t high_bits_64 = 0x8080808080808080ULL;
+ static constexpr uint32_t high_bits_32 = 0x80808080UL;
+ static constexpr uint16_t high_bits_16 = 0x8080U;
+ static constexpr uint8_t high_bits_8 = 0x80U;
+
+#ifndef NDEBUG
+ internal::CheckUTF8Initialized();
+#endif
+
+ while (size >= 8) {
+ // XXX This is doing an unaligned access. Contemporary architectures
+ // (x86-64, AArch64, PPC64) support it natively and often have good
+ // performance nevertheless.
+ uint64_t mask64 = SafeLoadAs<uint64_t>(data);
+ if (ARROW_PREDICT_TRUE((mask64 & high_bits_64) == 0)) {
+ // 8 bytes of pure ASCII, move forward
+ size -= 8;
+ data += 8;
+ continue;
+ }
+ // Non-ASCII run detected.
+ // We process at least 4 bytes, to avoid too many spurious 64-bit reads
+ // in case the non-ASCII bytes are at the end of the tested 64-bit word.
+ // We also only check for rejection at the end since that state is stable
+ // (once in reject state, we always remain in reject state).
+ // It is guaranteed that size >= 8 when arriving here, which allows
+ // us to avoid size checks.
+ uint16_t state = internal::kUTF8ValidateAccept;
+ // Byte 0
+ state = internal::ValidateOneUTF8Byte(*data++, state);
+ --size;
+ // Byte 1
+ state = internal::ValidateOneUTF8Byte(*data++, state);
+ --size;
+ // Byte 2
+ state = internal::ValidateOneUTF8Byte(*data++, state);
+ --size;
+ // Byte 3
+ state = internal::ValidateOneUTF8Byte(*data++, state);
+ --size;
+ // Byte 4
+ state = internal::ValidateOneUTF8Byte(*data++, state);
+ --size;
+ if (state == internal::kUTF8ValidateAccept) {
+ continue; // Got full char, switch back to ASCII detection
+ }
+ // Byte 5
+ state = internal::ValidateOneUTF8Byte(*data++, state);
+ --size;
+ if (state == internal::kUTF8ValidateAccept) {
+ continue; // Got full char, switch back to ASCII detection
+ }
+ // Byte 6
+ state = internal::ValidateOneUTF8Byte(*data++, state);
+ --size;
+ if (state == internal::kUTF8ValidateAccept) {
+ continue; // Got full char, switch back to ASCII detection
+ }
+ // Byte 7
+ state = internal::ValidateOneUTF8Byte(*data++, state);
+ --size;
+ if (state == internal::kUTF8ValidateAccept) {
+ continue; // Got full char, switch back to ASCII detection
+ }
+ // kUTF8ValidateAccept not reached along 4 transitions has to mean a rejection
+ assert(state == internal::kUTF8ValidateReject);
+ return false;
+ }
+
+ // Check if string tail is full ASCII (common case, fast)
+ if (size >= 4) {
+ uint32_t tail_mask = SafeLoadAs<uint32_t>(data + size - 4);
+ uint32_t head_mask = SafeLoadAs<uint32_t>(data);
+ if (ARROW_PREDICT_TRUE(((head_mask | tail_mask) & high_bits_32) == 0)) {
+ return true;
+ }
+ } else if (size >= 2) {
+ uint16_t tail_mask = SafeLoadAs<uint16_t>(data + size - 2);
+ uint16_t head_mask = SafeLoadAs<uint16_t>(data);
+ if (ARROW_PREDICT_TRUE(((head_mask | tail_mask) & high_bits_16) == 0)) {
+ return true;
+ }
+ } else if (size == 1) {
+ if (ARROW_PREDICT_TRUE((*data & high_bits_8) == 0)) {
+ return true;
+ }
+ } else {
+ /* size == 0 */
+ return true;
+ }
+
+ // Fall back to UTF8 validation of tail string.
+ // Note the state table is designed so that, once in the reject state,
+ // we remain in that state until the end. So we needn't check for
+ // rejection at each char (we don't gain much by short-circuiting here).
+ uint16_t state = internal::kUTF8ValidateAccept;
+ switch (size) {
+ case 7:
+ state = internal::ValidateOneUTF8Byte(data[size - 7], state);
+ case 6:
+ state = internal::ValidateOneUTF8Byte(data[size - 6], state);
+ case 5:
+ state = internal::ValidateOneUTF8Byte(data[size - 5], state);
+ case 4:
+ state = internal::ValidateOneUTF8Byte(data[size - 4], state);
+ case 3:
+ state = internal::ValidateOneUTF8Byte(data[size - 3], state);
+ case 2:
+ state = internal::ValidateOneUTF8Byte(data[size - 2], state);
+ case 1:
+ state = internal::ValidateOneUTF8Byte(data[size - 1], state);
+ default:
+ break;
+ }
+ return ARROW_PREDICT_TRUE(state == internal::kUTF8ValidateAccept);
+}
+
+static inline bool ValidateUTF8(const util::string_view& str) {
+ const uint8_t* data = reinterpret_cast<const uint8_t*>(str.data());
+ const size_t length = str.size();
+
+ return ValidateUTF8(data, length);
+}
+
+static inline bool ValidateAsciiSw(const uint8_t* data, int64_t len) {
+ uint8_t orall = 0;
+
+ if (len >= 8) {
+ uint64_t or8 = 0;
+
+ do {
+ or8 |= SafeLoadAs<uint64_t>(data);
+ data += 8;
+ len -= 8;
+ } while (len >= 8);
+
+ orall = !(or8 & 0x8080808080808080ULL) - 1;
+ }
+
+ while (len--) {
+ orall |= *data++;
+ }
+
+ return orall < 0x80U;
+}
+
+#if defined(ARROW_HAVE_NEON) || defined(ARROW_HAVE_SSE4_2)
+static inline bool ValidateAsciiSimd(const uint8_t* data, int64_t len) {
+#ifdef ARROW_HAVE_NEON
+ using simd_batch = xsimd::batch<int8_t, xsimd::neon64>;
+#else
+ using simd_batch = xsimd::batch<int8_t, xsimd::sse4_2>;
+#endif
+
+ if (len >= 32) {
+ const simd_batch zero(static_cast<int8_t>(0));
+ const uint8_t* data2 = data + 16;
+ simd_batch or1 = zero, or2 = zero;
+
+ while (len >= 32) {
+ or1 |= simd_batch::load_unaligned(reinterpret_cast<const int8_t*>(data));
+ or2 |= simd_batch::load_unaligned(reinterpret_cast<const int8_t*>(data2));
+ data += 32;
+ data2 += 32;
+ len -= 32;
+ }
+
+ // To test for upper bit in all bytes, test whether any of them is negative
+ or1 |= or2;
+ if (xsimd::any(or1 < zero)) {
+ return false;
+ }
+ }
+
+ return ValidateAsciiSw(data, len);
+}
+#endif // ARROW_HAVE_SSE4_2
+
+static inline bool ValidateAscii(const uint8_t* data, int64_t len) {
+#if defined(ARROW_HAVE_NEON) || defined(ARROW_HAVE_SSE4_2)
+ return ValidateAsciiSimd(data, len);
+#else
+ return ValidateAsciiSw(data, len);
+#endif
+}
+
+static inline bool ValidateAscii(const util::string_view& str) {
+ const uint8_t* data = reinterpret_cast<const uint8_t*>(str.data());
+ const size_t length = str.size();
+
+ return ValidateAscii(data, length);
+}
+
+// Skip UTF8 byte order mark, if any.
+ARROW_EXPORT
+Result<const uint8_t*> SkipUTF8BOM(const uint8_t* data, int64_t size);
+
+static constexpr uint32_t kMaxUnicodeCodepoint = 0x110000;
+
+// size of a valid UTF8 can be determined by looking at leading 4 bits of BYTE1
+// utf8_byte_size_table[0..7] --> pure ascii chars --> 1B length
+// utf8_byte_size_table[8..11] --> internal bytes --> 1B length
+// utf8_byte_size_table[12,13] --> 2B long UTF8 chars
+// utf8_byte_size_table[14] --> 3B long UTF8 chars
+// utf8_byte_size_table[15] --> 4B long UTF8 chars
+// NOTE: Results for invalid/ malformed utf-8 sequences are undefined.
+// ex: \xFF... returns 4B
+static inline uint8_t ValidUtf8CodepointByteSize(const uint8_t* codeunit) {
+ return internal::utf8_byte_size_table[*codeunit >> 4];
+}
+
+static inline bool Utf8IsContinuation(const uint8_t codeunit) {
+ return (codeunit & 0xC0) == 0x80; // upper two bits should be 10
+}
+
+static inline bool Utf8Is2ByteStart(const uint8_t codeunit) {
+ return (codeunit & 0xE0) == 0xC0; // upper three bits should be 110
+}
+
+static inline bool Utf8Is3ByteStart(const uint8_t codeunit) {
+ return (codeunit & 0xF0) == 0xE0; // upper four bits should be 1110
+}
+
+static inline bool Utf8Is4ByteStart(const uint8_t codeunit) {
+ return (codeunit & 0xF8) == 0xF0; // upper five bits should be 11110
+}
+
+static inline uint8_t* UTF8Encode(uint8_t* str, uint32_t codepoint) {
+ if (codepoint < 0x80) {
+ *str++ = codepoint;
+ } else if (codepoint < 0x800) {
+ *str++ = 0xC0 + (codepoint >> 6);
+ *str++ = 0x80 + (codepoint & 0x3F);
+ } else if (codepoint < 0x10000) {
+ *str++ = 0xE0 + (codepoint >> 12);
+ *str++ = 0x80 + ((codepoint >> 6) & 0x3F);
+ *str++ = 0x80 + (codepoint & 0x3F);
+ } else {
+ // Assume proper codepoints are always passed
+ assert(codepoint < kMaxUnicodeCodepoint);
+ *str++ = 0xF0 + (codepoint >> 18);
+ *str++ = 0x80 + ((codepoint >> 12) & 0x3F);
+ *str++ = 0x80 + ((codepoint >> 6) & 0x3F);
+ *str++ = 0x80 + (codepoint & 0x3F);
+ }
+ return str;
+}
+
+static inline bool UTF8Decode(const uint8_t** data, uint32_t* codepoint) {
+ const uint8_t* str = *data;
+ if (*str < 0x80) { // ascii
+ *codepoint = *str++;
+ } else if (ARROW_PREDICT_FALSE(*str < 0xC0)) { // invalid non-ascii char
+ return false;
+ } else if (*str < 0xE0) {
+ uint8_t code_unit_1 = (*str++) & 0x1F; // take last 5 bits
+ if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) {
+ return false;
+ }
+ uint8_t code_unit_2 = (*str++) & 0x3F; // take last 6 bits
+ *codepoint = (code_unit_1 << 6) + code_unit_2;
+ } else if (*str < 0xF0) {
+ uint8_t code_unit_1 = (*str++) & 0x0F; // take last 4 bits
+ if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) {
+ return false;
+ }
+ uint8_t code_unit_2 = (*str++) & 0x3F; // take last 6 bits
+ if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) {
+ return false;
+ }
+ uint8_t code_unit_3 = (*str++) & 0x3F; // take last 6 bits
+ *codepoint = (code_unit_1 << 12) + (code_unit_2 << 6) + code_unit_3;
+ } else if (*str < 0xF8) {
+ uint8_t code_unit_1 = (*str++) & 0x07; // take last 3 bits
+ if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) {
+ return false;
+ }
+ uint8_t code_unit_2 = (*str++) & 0x3F; // take last 6 bits
+ if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) {
+ return false;
+ }
+ uint8_t code_unit_3 = (*str++) & 0x3F; // take last 6 bits
+ if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) {
+ return false;
+ }
+ uint8_t code_unit_4 = (*str++) & 0x3F; // take last 6 bits
+ *codepoint =
+ (code_unit_1 << 18) + (code_unit_2 << 12) + (code_unit_3 << 6) + code_unit_4;
+ } else { // invalid non-ascii char
+ return false;
+ }
+ *data = str;
+ return true;
+}
+
+static inline bool UTF8DecodeReverse(const uint8_t** data, uint32_t* codepoint) {
+ const uint8_t* str = *data;
+ if (*str < 0x80) { // ascii
+ *codepoint = *str--;
+ } else {
+ if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) {
+ return false;
+ }
+ uint8_t code_unit_N = (*str--) & 0x3F; // take last 6 bits
+ if (Utf8Is2ByteStart(*str)) {
+ uint8_t code_unit_1 = (*str--) & 0x1F; // take last 5 bits
+ *codepoint = (code_unit_1 << 6) + code_unit_N;
+ } else {
+ if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) {
+ return false;
+ }
+ uint8_t code_unit_Nmin1 = (*str--) & 0x3F; // take last 6 bits
+ if (Utf8Is3ByteStart(*str)) {
+ uint8_t code_unit_1 = (*str--) & 0x0F; // take last 4 bits
+ *codepoint = (code_unit_1 << 12) + (code_unit_Nmin1 << 6) + code_unit_N;
+ } else {
+ if (ARROW_PREDICT_FALSE(!Utf8IsContinuation(*str))) {
+ return false;
+ }
+ uint8_t code_unit_Nmin2 = (*str--) & 0x3F; // take last 6 bits
+ if (ARROW_PREDICT_TRUE(Utf8Is4ByteStart(*str))) {
+ uint8_t code_unit_1 = (*str--) & 0x07; // take last 3 bits
+ *codepoint = (code_unit_1 << 18) + (code_unit_Nmin2 << 12) +
+ (code_unit_Nmin1 << 6) + code_unit_N;
+ } else {
+ return false;
+ }
+ }
+ }
+ }
+ *data = str;
+ return true;
+}
+
+template <class UnaryOperation>
+static inline bool UTF8Transform(const uint8_t* first, const uint8_t* last,
+ uint8_t** destination, UnaryOperation&& unary_op) {
+ const uint8_t* i = first;
+ uint8_t* out = *destination;
+ while (i < last) {
+ uint32_t codepoint = 0;
+ if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) {
+ return false;
+ }
+ out = UTF8Encode(out, unary_op(codepoint));
+ }
+ *destination = out;
+ return true;
+}
+
+template <class Predicate>
+static inline bool UTF8FindIf(const uint8_t* first, const uint8_t* last,
+ Predicate&& predicate, const uint8_t** position) {
+ const uint8_t* i = first;
+ while (i < last) {
+ uint32_t codepoint = 0;
+ const uint8_t* current = i;
+ if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) {
+ return false;
+ }
+ if (predicate(codepoint)) {
+ *position = current;
+ return true;
+ }
+ }
+ *position = last;
+ return true;
+}
+
+// Same semantics as std::find_if using reverse iterators with the return value
+// having the same semantics as std::reverse_iterator<..>.base()
+// A reverse iterator physically points to the next address, e.g.:
+// &*reverse_iterator(i) == &*(i + 1)
+template <class Predicate>
+static inline bool UTF8FindIfReverse(const uint8_t* first, const uint8_t* last,
+ Predicate&& predicate, const uint8_t** position) {
+ // converts to a normal point
+ const uint8_t* i = last - 1;
+ while (i >= first) {
+ uint32_t codepoint = 0;
+ const uint8_t* current = i;
+ if (ARROW_PREDICT_FALSE(!UTF8DecodeReverse(&i, &codepoint))) {
+ return false;
+ }
+ if (predicate(codepoint)) {
+ // converts normal pointer to 'reverse iterator semantics'.
+ *position = current + 1;
+ return true;
+ }
+ }
+ // similar to how an end pointer point to 1 beyond the last, reverse iterators point
+ // to the 'first' pointer to indicate out of range.
+ *position = first;
+ return true;
+}
+
+static inline bool UTF8AdvanceCodepoints(const uint8_t* first, const uint8_t* last,
+ const uint8_t** destination, int64_t n) {
+ return UTF8FindIf(
+ first, last,
+ [&](uint32_t codepoint) {
+ bool done = n == 0;
+ n--;
+ return done;
+ },
+ destination);
+}
+
+static inline bool UTF8AdvanceCodepointsReverse(const uint8_t* first, const uint8_t* last,
+ const uint8_t** destination, int64_t n) {
+ return UTF8FindIfReverse(
+ first, last,
+ [&](uint32_t codepoint) {
+ bool done = n == 0;
+ n--;
+ return done;
+ },
+ destination);
+}
+
+template <class UnaryFunction>
+static inline bool UTF8ForEach(const uint8_t* first, const uint8_t* last,
+ UnaryFunction&& f) {
+ const uint8_t* i = first;
+ while (i < last) {
+ uint32_t codepoint = 0;
+ if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) {
+ return false;
+ }
+ f(codepoint);
+ }
+ return true;
+}
+
+template <class UnaryFunction>
+static inline bool UTF8ForEach(const std::string& s, UnaryFunction&& f) {
+ return UTF8ForEach(reinterpret_cast<const uint8_t*>(s.data()),
+ reinterpret_cast<const uint8_t*>(s.data() + s.length()),
+ std::forward<UnaryFunction>(f));
+}
+
+template <class UnaryPredicate>
+static inline bool UTF8AllOf(const uint8_t* first, const uint8_t* last, bool* result,
+ UnaryPredicate&& predicate) {
+ const uint8_t* i = first;
+ while (i < last) {
+ uint32_t codepoint = 0;
+ if (ARROW_PREDICT_FALSE(!UTF8Decode(&i, &codepoint))) {
+ return false;
+ }
+
+ if (!predicate(codepoint)) {
+ *result = false;
+ return true;
+ }
+ }
+ *result = true;
+ return true;
+}
+
+/// Count the number of codepoints in the given string (assuming it is valid UTF8).
+static inline int64_t UTF8Length(const uint8_t* first, const uint8_t* last) {
+ int64_t length = 0;
+ while (first != last) {
+ length += ((*first++ & 0xc0) != 0x80);
+ }
+ return length;
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/utf8_util_benchmark.cc b/src/arrow/cpp/src/arrow/util/utf8_util_benchmark.cc
new file mode 100644
index 000000000..2cbaa181d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/utf8_util_benchmark.cc
@@ -0,0 +1,150 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/utf8.h"
+
+// Do not benchmark inlined functions directly inside the benchmark loop
+static ARROW_NOINLINE bool ValidateUTF8NoInline(const uint8_t* data, int64_t size) {
+ return ::arrow::util::ValidateUTF8(data, size);
+}
+
+static ARROW_NOINLINE bool ValidateAsciiNoInline(const uint8_t* data, int64_t size) {
+ return ::arrow::util::ValidateAscii(data, size);
+}
+
+namespace arrow {
+namespace util {
+
+static const char* tiny_valid_ascii = "characters";
+static const char* tiny_valid_non_ascii = "caractères";
+
+static const char* valid_ascii =
+ "UTF-8 is a variable width character encoding capable of encoding all 1,112,064 "
+ "valid code points in Unicode using one to four 8-bit bytes";
+static const char* valid_almost_ascii =
+ "UTF-8 est un codage de caractères informatiques conçu pour coder l’ensemble des "
+ "caractères du « répertoire universel de caractères codés »";
+static const char* valid_non_ascii =
+ "UTF-8 はISO/IEC 10646 (UCS) "
+ "とUnicodeで使える8ビット符号単位の文字符号化形式及び文字符号化スキーム。 ";
+
+static std::string MakeLargeString(const std::string& base, int64_t nbytes) {
+ int64_t nrepeats = (nbytes + base.size() - 1) / base.size();
+ std::string s;
+ s.reserve(nrepeats * nbytes);
+ for (int64_t i = 0; i < nrepeats; ++i) {
+ s += base;
+ }
+ return s;
+}
+
+static void BenchmarkUTF8Validation(
+ benchmark::State& state, // NOLINT non-const reference
+ const std::string& s, bool expected) {
+ auto data = reinterpret_cast<const uint8_t*>(s.data());
+ auto data_size = static_cast<int64_t>(s.size());
+
+ InitializeUTF8();
+ bool b = ValidateUTF8NoInline(data, data_size);
+ if (b != expected) {
+ std::cerr << "Unexpected validation result" << std::endl;
+ std::abort();
+ }
+
+ while (state.KeepRunning()) {
+ bool b = ValidateUTF8NoInline(data, data_size);
+ benchmark::DoNotOptimize(b);
+ }
+ state.SetBytesProcessed(state.iterations() * s.size());
+}
+
+static void BenchmarkASCIIValidation(
+ benchmark::State& state, // NOLINT non-const reference
+ const std::string& s, bool expected) {
+ auto data = reinterpret_cast<const uint8_t*>(s.data());
+ auto data_size = static_cast<int64_t>(s.size());
+
+ bool b = ValidateAsciiNoInline(data, data_size);
+ if (b != expected) {
+ std::cerr << "Unexpected validation result" << std::endl;
+ std::abort();
+ }
+
+ while (state.KeepRunning()) {
+ bool b = ValidateAsciiNoInline(data, data_size);
+ benchmark::DoNotOptimize(b);
+ }
+ state.SetBytesProcessed(state.iterations() * s.size());
+}
+
+static void ValidateTinyAscii(benchmark::State& state) { // NOLINT non-const reference
+ BenchmarkASCIIValidation(state, tiny_valid_ascii, true);
+}
+
+static void ValidateTinyNonAscii(benchmark::State& state) { // NOLINT non-const reference
+ BenchmarkUTF8Validation(state, tiny_valid_non_ascii, true);
+}
+
+static void ValidateSmallAscii(benchmark::State& state) { // NOLINT non-const reference
+ BenchmarkASCIIValidation(state, valid_ascii, true);
+}
+
+static void ValidateSmallAlmostAscii(
+ benchmark::State& state) { // NOLINT non-const reference
+ BenchmarkUTF8Validation(state, valid_almost_ascii, true);
+}
+
+static void ValidateSmallNonAscii(
+ benchmark::State& state) { // NOLINT non-const reference
+ BenchmarkUTF8Validation(state, valid_non_ascii, true);
+}
+
+static void ValidateLargeAscii(benchmark::State& state) { // NOLINT non-const reference
+ auto s = MakeLargeString(valid_ascii, 100000);
+ BenchmarkASCIIValidation(state, s, true);
+}
+
+static void ValidateLargeAlmostAscii(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto s = MakeLargeString(valid_almost_ascii, 100000);
+ BenchmarkUTF8Validation(state, s, true);
+}
+
+static void ValidateLargeNonAscii(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto s = MakeLargeString(valid_non_ascii, 100000);
+ BenchmarkUTF8Validation(state, s, true);
+}
+
+BENCHMARK(ValidateTinyAscii);
+BENCHMARK(ValidateTinyNonAscii);
+BENCHMARK(ValidateSmallAscii);
+BENCHMARK(ValidateSmallAlmostAscii);
+BENCHMARK(ValidateSmallNonAscii);
+BENCHMARK(ValidateLargeAscii);
+BENCHMARK(ValidateLargeAlmostAscii);
+BENCHMARK(ValidateLargeNonAscii);
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/utf8_util_test.cc b/src/arrow/cpp/src/arrow/util/utf8_util_test.cc
new file mode 100644
index 000000000..878d924e4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/utf8_util_test.cc
@@ -0,0 +1,513 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <random>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/string.h"
+#include "arrow/util/utf8.h"
+
+namespace arrow {
+namespace util {
+
+class UTF8Test : public ::testing::Test {
+ protected:
+ static void SetUpTestCase() {
+ InitializeUTF8();
+
+ all_valid_sequences.clear();
+ for (const auto& v :
+ {valid_sequences_1, valid_sequences_2, valid_sequences_3, valid_sequences_4}) {
+ all_valid_sequences.insert(all_valid_sequences.end(), v.begin(), v.end());
+ }
+
+ all_invalid_sequences.clear();
+ for (const auto& v : {invalid_sequences_1, invalid_sequences_2, invalid_sequences_3,
+ invalid_sequences_4}) {
+ all_invalid_sequences.insert(all_invalid_sequences.end(), v.begin(), v.end());
+ }
+ }
+
+ static std::vector<std::string> valid_sequences_1;
+ static std::vector<std::string> valid_sequences_2;
+ static std::vector<std::string> valid_sequences_3;
+ static std::vector<std::string> valid_sequences_4;
+
+ static std::vector<std::string> all_valid_sequences;
+
+ static std::vector<std::string> invalid_sequences_1;
+ static std::vector<std::string> invalid_sequences_2;
+ static std::vector<std::string> invalid_sequences_3;
+ static std::vector<std::string> invalid_sequences_4;
+
+ static std::vector<std::string> all_invalid_sequences;
+
+ static std::vector<std::string> valid_sequences_ascii;
+ static std::vector<std::string> invalid_sequences_ascii;
+};
+
+std::vector<std::string> UTF8Test::valid_sequences_1 = {"a", "\x7f",
+ std::string("\0", 1)};
+std::vector<std::string> UTF8Test::valid_sequences_2 = {"\xc2\x80", "\xc3\xbf",
+ "\xdf\xbf"};
+std::vector<std::string> UTF8Test::valid_sequences_3 = {"\xe0\xa0\x80", "\xe8\x9d\xa5",
+ "\xef\xbf\xbf"};
+std::vector<std::string> UTF8Test::valid_sequences_4 = {
+ "\xf0\x90\x80\x80", "\xf0\x9f\xbf\xbf", "\xf4\x80\x80\x80", "\xf4\x8f\xbf\xbf"};
+
+std::vector<std::string> UTF8Test::all_valid_sequences;
+
+std::vector<std::string> UTF8Test::invalid_sequences_1 = {"\x80", "\xa0", "\xbf", "\xc0",
+ "\xc1"};
+std::vector<std::string> UTF8Test::invalid_sequences_2 = {
+ "\x80\x80", "\x80\xbf", "\xbf\x80", "\xbf\xbf",
+ "\xc1\x80", "\xc2\x7f", "\xc3\xff", "\xdf\xc0"};
+std::vector<std::string> UTF8Test::invalid_sequences_3 = {
+ "\xe0\x80\x80", "\xe0\x9f\x80", "\xef\xbf\xc0", "\xef\xc0\xbf", "\xef\xff\xff",
+ // Surrogates
+ "\xed\xa0\x80", "\xed\xbf\xbf"};
+std::vector<std::string> UTF8Test::invalid_sequences_4 = {
+ "\xf0\x80\x80\x80", "\xf0\x8f\x80\x80", "\xf4\x8f\xbf\xc0", "\xf4\x8f\xc0\xbf",
+ "\xf4\x90\x80\x80"};
+
+std::vector<std::string> UTF8Test::all_invalid_sequences;
+
+std::vector<std::string> UTF8Test::valid_sequences_ascii = {"a", "\x7f", "B", "&"};
+std::vector<std::string> UTF8Test::invalid_sequences_ascii = {
+ "\x80", "\xa0\x1e", "\xbf\xef\x6a", "\xc1\x9f\xc3\xd9"};
+
+class UTF8ValidationTest : public UTF8Test {};
+
+class ASCIIValidationTest : public UTF8Test {};
+
+::testing::AssertionResult IsValidUTF8(const std::string& s) {
+ if (ValidateUTF8(reinterpret_cast<const uint8_t*>(s.data()), s.size())) {
+ return ::testing::AssertionSuccess();
+ } else {
+ std::string h = HexEncode(reinterpret_cast<const uint8_t*>(s.data()),
+ static_cast<int32_t>(s.size()));
+ return ::testing::AssertionFailure()
+ << "string '" << h << "' didn't validate as UTF8";
+ }
+}
+
+::testing::AssertionResult IsInvalidUTF8(const std::string& s) {
+ if (!ValidateUTF8(reinterpret_cast<const uint8_t*>(s.data()), s.size())) {
+ return ::testing::AssertionSuccess();
+ } else {
+ std::string h = HexEncode(reinterpret_cast<const uint8_t*>(s.data()),
+ static_cast<int32_t>(s.size()));
+ return ::testing::AssertionFailure() << "string '" << h << "' validated as UTF8";
+ }
+}
+
+::testing::AssertionResult IsValidASCII(const std::string& s) {
+ if (ValidateAscii(reinterpret_cast<const uint8_t*>(s.data()), s.size())) {
+ return ::testing::AssertionSuccess();
+ } else {
+ std::string h = HexEncode(reinterpret_cast<const uint8_t*>(s.data()),
+ static_cast<int32_t>(s.size()));
+ return ::testing::AssertionFailure()
+ << "string '" << h << "' didn't validate as ASCII";
+ }
+}
+
+::testing::AssertionResult IsInvalidASCII(const std::string& s) {
+ if (!ValidateAscii(reinterpret_cast<const uint8_t*>(s.data()), s.size())) {
+ return ::testing::AssertionSuccess();
+ } else {
+ std::string h = HexEncode(reinterpret_cast<const uint8_t*>(s.data()),
+ static_cast<int32_t>(s.size()));
+ return ::testing::AssertionFailure() << "string '" << h << "' validated as ASCII";
+ }
+}
+
+template <typename ValidationFunc>
+void ValidateWithPrefixes(ValidationFunc&& validate, const std::string& s) {
+ // Exercise SIMD optimizations
+ for (int prefix_size = 1; prefix_size < 64; ++prefix_size) {
+ std::string longer(prefix_size, 'x');
+ longer.append(s);
+ ASSERT_TRUE(validate(longer));
+ longer.append(prefix_size, 'y');
+ ASSERT_TRUE(validate(longer));
+ }
+}
+
+void AssertValidUTF8(const std::string& s) { ASSERT_TRUE(IsValidUTF8(s)); }
+
+void AssertInvalidUTF8(const std::string& s) { ASSERT_TRUE(IsInvalidUTF8(s)); }
+
+void AssertValidASCII(const std::string& s) {
+ ASSERT_TRUE(IsValidASCII(s));
+ ValidateWithPrefixes(IsValidASCII, s);
+}
+
+void AssertInvalidASCII(const std::string& s) {
+ ASSERT_TRUE(IsInvalidASCII(s));
+ ValidateWithPrefixes(IsInvalidASCII, s);
+}
+
+TEST_F(ASCIIValidationTest, AsciiValid) {
+ for (const auto& s : valid_sequences_ascii) {
+ AssertValidASCII(s);
+ }
+}
+
+TEST_F(ASCIIValidationTest, AsciiInvalid) {
+ for (const auto& s : invalid_sequences_ascii) {
+ AssertInvalidASCII(s);
+ }
+}
+
+TEST_F(UTF8ValidationTest, EmptyString) { AssertValidUTF8(""); }
+
+TEST_F(UTF8ValidationTest, OneCharacterValid) {
+ for (const auto& s : all_valid_sequences) {
+ AssertValidUTF8(s);
+ }
+}
+
+TEST_F(UTF8ValidationTest, TwoCharacterValid) {
+ for (const auto& s1 : all_valid_sequences) {
+ for (const auto& s2 : all_valid_sequences) {
+ AssertValidUTF8(s1 + s2);
+ }
+ }
+}
+
+TEST_F(UTF8ValidationTest, RandomValid) {
+#ifdef ARROW_VALGRIND
+ const int niters = 50;
+#else
+ const int niters = 1000;
+#endif
+ const int nchars = 100;
+ std::default_random_engine gen(42);
+ std::uniform_int_distribution<size_t> valid_dist(0, all_valid_sequences.size() - 1);
+
+ for (int i = 0; i < niters; ++i) {
+ std::string s;
+ s.reserve(nchars * 4);
+ for (int j = 0; j < nchars; ++j) {
+ s += all_valid_sequences[valid_dist(gen)];
+ }
+ AssertValidUTF8(s);
+ }
+}
+
+TEST_F(UTF8ValidationTest, OneCharacterTruncated) {
+ for (const auto& s : all_valid_sequences) {
+ if (s.size() > 1) {
+ AssertInvalidUTF8(s.substr(0, s.size() - 1));
+ }
+ }
+}
+
+TEST_F(UTF8ValidationTest, TwoCharacterTruncated) {
+ for (const auto& s1 : all_valid_sequences) {
+ for (const auto& s2 : all_valid_sequences) {
+ if (s2.size() > 1) {
+ AssertInvalidUTF8(s1 + s2.substr(0, s2.size() - 1));
+ AssertInvalidUTF8(s2.substr(0, s2.size() - 1) + s1);
+ }
+ }
+ }
+}
+
+TEST_F(UTF8ValidationTest, OneCharacterInvalid) {
+ for (const auto& s : all_invalid_sequences) {
+ AssertInvalidUTF8(s);
+ }
+}
+
+TEST_F(UTF8ValidationTest, TwoCharacterInvalid) {
+ for (const auto& s1 : all_valid_sequences) {
+ for (const auto& s2 : all_invalid_sequences) {
+ AssertInvalidUTF8(s1 + s2);
+ AssertInvalidUTF8(s2 + s1);
+ }
+ }
+ for (const auto& s1 : all_invalid_sequences) {
+ for (const auto& s2 : all_invalid_sequences) {
+ AssertInvalidUTF8(s1 + s2);
+ }
+ }
+}
+
+TEST_F(UTF8ValidationTest, RandomInvalid) {
+#ifdef ARROW_VALGRIND
+ const int niters = 50;
+#else
+ const int niters = 1000;
+#endif
+ const int nchars = 100;
+ std::default_random_engine gen(42);
+ std::uniform_int_distribution<size_t> valid_dist(0, all_valid_sequences.size() - 1);
+ std::uniform_int_distribution<int> invalid_pos_dist(0, nchars - 1);
+ std::uniform_int_distribution<size_t> invalid_dist(0, all_invalid_sequences.size() - 1);
+
+ for (int i = 0; i < niters; ++i) {
+ std::string s;
+ s.reserve(nchars * 4);
+ // Stuff a single invalid sequence somewhere in a valid UTF8 stream
+ int invalid_pos = invalid_pos_dist(gen);
+ for (int j = 0; j < nchars; ++j) {
+ if (j == invalid_pos) {
+ s += all_invalid_sequences[invalid_dist(gen)];
+ } else {
+ s += all_valid_sequences[valid_dist(gen)];
+ }
+ }
+ AssertInvalidUTF8(s);
+ }
+}
+
+TEST_F(UTF8ValidationTest, RandomTruncated) {
+#ifdef ARROW_VALGRIND
+ const int niters = 50;
+#else
+ const int niters = 1000;
+#endif
+ const int nchars = 100;
+ std::default_random_engine gen(42);
+ std::uniform_int_distribution<size_t> valid_dist(0, all_valid_sequences.size() - 1);
+ std::uniform_int_distribution<int> invalid_pos_dist(0, nchars - 1);
+
+ for (int i = 0; i < niters; ++i) {
+ std::string s;
+ s.reserve(nchars * 4);
+ // Truncate a single sequence somewhere in a valid UTF8 stream
+ int invalid_pos = invalid_pos_dist(gen);
+ for (int j = 0; j < nchars; ++j) {
+ if (j == invalid_pos) {
+ while (true) {
+ // Ensure we truncate a 2-byte or more sequence
+ const std::string& t = all_valid_sequences[valid_dist(gen)];
+ if (t.size() > 1) {
+ s += t.substr(0, t.size() - 1);
+ break;
+ }
+ }
+ } else {
+ s += all_valid_sequences[valid_dist(gen)];
+ }
+ }
+ AssertInvalidUTF8(s);
+ }
+}
+
+TEST(SkipUTF8BOM, Basics) {
+ auto CheckOk = [](const std::string& s, size_t expected_offset) -> void {
+ const uint8_t* data = reinterpret_cast<const uint8_t*>(s.data());
+ const uint8_t* res;
+ ASSERT_OK_AND_ASSIGN(res, SkipUTF8BOM(data, static_cast<int64_t>(s.size())));
+ ASSERT_NE(res, nullptr);
+ ASSERT_EQ(res - data, expected_offset);
+ };
+
+ auto CheckTruncated = [](const std::string& s) -> void {
+ const uint8_t* data = reinterpret_cast<const uint8_t*>(s.data());
+ ASSERT_RAISES(Invalid, SkipUTF8BOM(data, static_cast<int64_t>(s.size())));
+ };
+
+ CheckOk("", 0);
+ CheckOk("a", 0);
+ CheckOk("ab", 0);
+ CheckOk("abc", 0);
+ CheckOk("abcd", 0);
+ CheckOk("\xc3\xa9", 0);
+ CheckOk("\xee", 0);
+ CheckOk("\xef\xbc", 0);
+ CheckOk("\xef\xbb\xbe", 0);
+ CheckOk("\xef\xbb\xbf", 3);
+ CheckOk("\xef\xbb\xbfx", 3);
+
+ CheckTruncated("\xef");
+ CheckTruncated("\xef\xbb");
+}
+
+TEST(UTF8ToWideString, Basics) {
+ auto CheckOk = [](const std::string& s, const std::wstring& expected) -> void {
+ ASSERT_OK_AND_ASSIGN(std::wstring ws, UTF8ToWideString(s));
+ ASSERT_EQ(ws, expected);
+ };
+
+ auto CheckInvalid = [](const std::string& s) -> void {
+ ASSERT_RAISES(Invalid, UTF8ToWideString(s));
+ };
+
+ CheckOk("", L"");
+ CheckOk("foo", L"foo");
+ CheckOk("h\xc3\xa9h\xc3\xa9", L"h\u00e9h\u00e9");
+ CheckOk("\xf0\x9f\x98\x80", L"\U0001F600");
+ CheckOk("\xf4\x8f\xbf\xbf", L"\U0010FFFF");
+ CheckOk({0, 'x'}, {0, L'x'});
+
+ CheckInvalid("\xff");
+ CheckInvalid("h\xc3");
+}
+
+TEST(WideStringToUTF8, Basics) {
+ auto CheckOk = [](const std::wstring& ws, const std::string& expected) -> void {
+ ASSERT_OK_AND_ASSIGN(std::string s, WideStringToUTF8(ws));
+ ASSERT_EQ(s, expected);
+ };
+
+ auto CheckInvalid = [](const std::wstring& ws) -> void {
+ ASSERT_RAISES(Invalid, WideStringToUTF8(ws));
+ };
+
+ CheckOk(L"", "");
+ CheckOk(L"foo", "foo");
+ CheckOk(L"h\u00e9h\u00e9", "h\xc3\xa9h\xc3\xa9");
+ CheckOk(L"\U0001F600", "\xf0\x9f\x98\x80");
+ CheckOk(L"\U0010FFFF", "\xf4\x8f\xbf\xbf");
+ CheckOk({0, L'x'}, {0, 'x'});
+
+ // Lone surrogate
+ CheckInvalid({0xD800});
+ CheckInvalid({0xDFFF});
+ // Invalid code point
+#if WCHAR_MAX > 0xFFFF
+ CheckInvalid({0x110000});
+#endif
+}
+
+TEST(UTF8DecodeReverse, Basics) {
+ auto CheckOk = [](const std::string& s) -> void {
+ const uint8_t* begin = reinterpret_cast<const uint8_t*>(s.c_str());
+ const uint8_t* end = begin + s.length();
+ const uint8_t* i = end - 1;
+ uint32_t codepoint;
+ EXPECT_TRUE(UTF8DecodeReverse(&i, &codepoint));
+ EXPECT_EQ(i, begin - 1);
+ };
+
+ // 0x80 == 0b10000000
+ // 0xC0 == 0b11000000
+ // 0xE0 == 0b11100000
+ // 0xF0 == 0b11110000
+ CheckOk("a");
+ CheckOk("\xC0\x80");
+ CheckOk("\xE0\x80\x80");
+ CheckOk("\xF0\x80\x80\x80");
+
+ auto CheckInvalid = [](const std::string& s) -> void {
+ const uint8_t* begin = reinterpret_cast<const uint8_t*>(s.c_str());
+ const uint8_t* end = begin + s.length();
+ const uint8_t* i = end - 1;
+ uint32_t codepoint;
+ EXPECT_FALSE(UTF8DecodeReverse(&i, &codepoint));
+ };
+
+ // too many continuation code units
+ CheckInvalid("a\x80");
+ CheckInvalid("\xC0\x80\x80");
+ CheckInvalid("\xE0\x80\x80\x80");
+ CheckInvalid("\xF0\x80\x80\x80\x80");
+ // not enough continuation code units
+ CheckInvalid("\xC0");
+ CheckInvalid("\xE0\x80");
+ CheckInvalid("\xF0\x80\x80");
+}
+
+TEST(UTF8FindIf, Basics) {
+ auto CheckOk = [](const std::string& s, unsigned char test, int64_t offset_left,
+ int64_t offset_right) -> void {
+ const uint8_t* begin = reinterpret_cast<const uint8_t*>(s.c_str());
+ const uint8_t* end = begin + s.length();
+ std::reverse_iterator<const uint8_t*> rbegin(end);
+ std::reverse_iterator<const uint8_t*> rend(begin);
+ const uint8_t* left = nullptr;
+ const uint8_t* right = nullptr;
+ auto predicate = [&](uint32_t c) { return c == test; };
+ EXPECT_TRUE(UTF8FindIf(begin, end, predicate, &left));
+ EXPECT_TRUE(UTF8FindIfReverse(begin, end, predicate, &right));
+ EXPECT_EQ(offset_left, left - begin);
+ EXPECT_EQ(offset_right, right - begin);
+ EXPECT_EQ(std::find_if(begin, end, predicate) - begin, left - begin);
+ EXPECT_EQ(std::find_if(rbegin, rend, predicate).base() - begin, right - begin);
+ };
+ auto CheckOkUTF8 = [](const std::string& s, uint32_t test, int64_t offset_left,
+ int64_t offset_right) -> void {
+ const uint8_t* begin = reinterpret_cast<const uint8_t*>(s.c_str());
+ const uint8_t* end = begin + s.length();
+ std::reverse_iterator<const uint8_t*> rbegin(end);
+ std::reverse_iterator<const uint8_t*> rend(begin);
+ const uint8_t* left = nullptr;
+ const uint8_t* right = nullptr;
+ auto predicate = [&](uint32_t c) { return c == test; };
+ EXPECT_TRUE(UTF8FindIf(begin, end, predicate, &left));
+ EXPECT_TRUE(UTF8FindIfReverse(begin, end, predicate, &right));
+ EXPECT_EQ(offset_left, left - begin);
+ EXPECT_EQ(offset_right, right - begin);
+ // we cannot check the unicode version with find_if semantics, because it's byte based
+ // EXPECT_EQ(std::find_if(begin, end, predicate) - begin, left - begin);
+ // EXPECT_EQ(std::find_if(rbegin, rend, predicate).base() - begin, right - begin);
+ };
+
+ CheckOk("aaaba", 'a', 0, 5);
+ CheckOkUTF8("aaaβa", 'a', 0, 6);
+
+ CheckOk("aaaba", 'b', 3, 4);
+ CheckOkUTF8("aaaβa", U'β', 3, 5);
+
+ CheckOk("aaababa", 'b', 3, 6);
+ CheckOkUTF8("aaaβaβa", U'β', 3, 8);
+
+ CheckOk("aaababa", 'c', 7, 0);
+ CheckOk("aaaβaβa", 'c', 9, 0);
+ CheckOkUTF8("aaaβaβa", U'ɑ', 9, 0);
+
+ CheckOk("a", 'a', 0, 1);
+ CheckOkUTF8("ɑ", U'ɑ', 0, 2);
+
+ CheckOk("a", 'b', 1, 0);
+ CheckOkUTF8("ɑ", 'b', 2, 0);
+
+ CheckOk("", 'b', 0, 0);
+ CheckOkUTF8("", U'β', 0, 0);
+}
+
+TEST(UTF8Length, Basics) {
+ auto length = [](const std::string& s) {
+ const auto* p = reinterpret_cast<const uint8_t*>(s.data());
+ return UTF8Length(p, p + s.length());
+ };
+ ASSERT_EQ(length("abcde"), 5);
+ // accented a encoded as a single codepoint
+ ASSERT_EQ(length("\xc3\x81"
+ "bcde"),
+ 5);
+ // accented a encoded as two codepoints via combining character
+ ASSERT_EQ(length("a\xcc\x81"
+ "bcde"),
+ 6);
+ // hiragana a (3 bytes)
+ ASSERT_EQ(length("\xe3\x81\x81"), 1);
+ // raised hands emoji (4 bytes)
+ ASSERT_EQ(length("\xf0\x9f\x99\x8c"), 1);
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/value_parsing.cc b/src/arrow/cpp/src/arrow/util/value_parsing.cc
new file mode 100644
index 000000000..adc333ecf
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/value_parsing.cc
@@ -0,0 +1,87 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/value_parsing.h"
+
+#include <string>
+#include <utility>
+
+#include "arrow/vendored/fast_float/fast_float.h"
+
+namespace arrow {
+namespace internal {
+
+bool StringToFloat(const char* s, size_t length, float* out) {
+ const auto res = ::arrow_vendored::fast_float::from_chars(s, s + length, *out);
+ return res.ec == std::errc() && res.ptr == s + length;
+}
+
+bool StringToFloat(const char* s, size_t length, double* out) {
+ const auto res = ::arrow_vendored::fast_float::from_chars(s, s + length, *out);
+ return res.ec == std::errc() && res.ptr == s + length;
+}
+
+// ----------------------------------------------------------------------
+// strptime-like parsing
+
+namespace {
+
+class StrptimeTimestampParser : public TimestampParser {
+ public:
+ explicit StrptimeTimestampParser(std::string format) : format_(std::move(format)) {}
+
+ bool operator()(const char* s, size_t length, TimeUnit::type out_unit,
+ int64_t* out) const override {
+ return ParseTimestampStrptime(s, length, format_.c_str(),
+ /*ignore_time_in_day=*/false,
+ /*allow_trailing_chars=*/false, out_unit, out);
+ }
+
+ const char* kind() const override { return "strptime"; }
+
+ const char* format() const override { return format_.c_str(); }
+
+ private:
+ std::string format_;
+};
+
+class ISO8601Parser : public TimestampParser {
+ public:
+ ISO8601Parser() {}
+
+ bool operator()(const char* s, size_t length, TimeUnit::type out_unit,
+ int64_t* out) const override {
+ return ParseTimestampISO8601(s, length, out_unit, out);
+ }
+
+ const char* kind() const override { return "iso8601"; }
+};
+
+} // namespace
+} // namespace internal
+
+const char* TimestampParser::format() const { return ""; }
+
+std::shared_ptr<TimestampParser> TimestampParser::MakeStrptime(std::string format) {
+ return std::make_shared<internal::StrptimeTimestampParser>(std::move(format));
+}
+
+std::shared_ptr<TimestampParser> TimestampParser::MakeISO8601() {
+ return std::make_shared<internal::ISO8601Parser>();
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/value_parsing.h b/src/arrow/cpp/src/arrow/util/value_parsing.h
new file mode 100644
index 000000000..d99634e12
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/value_parsing.h
@@ -0,0 +1,853 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This is a private header for string-to-number parsing utilities
+
+#pragma once
+
+#include <cassert>
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+#include <type_traits>
+
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/time.h"
+#include "arrow/util/visibility.h"
+#include "arrow/vendored/datetime.h"
+#include "arrow/vendored/strptime.h"
+
+namespace arrow {
+
+/// \brief A virtual string to timestamp parser
+class ARROW_EXPORT TimestampParser {
+ public:
+ virtual ~TimestampParser() = default;
+
+ virtual bool operator()(const char* s, size_t length, TimeUnit::type out_unit,
+ int64_t* out) const = 0;
+
+ virtual const char* kind() const = 0;
+
+ virtual const char* format() const;
+
+ /// \brief Create a TimestampParser that recognizes strptime-like format strings
+ static std::shared_ptr<TimestampParser> MakeStrptime(std::string format);
+
+ /// \brief Create a TimestampParser that recognizes (locale-agnostic) ISO8601
+ /// timestamps
+ static std::shared_ptr<TimestampParser> MakeISO8601();
+};
+
+namespace internal {
+
+/// \brief The entry point for conversion from strings.
+///
+/// Specializations of StringConverter for `ARROW_TYPE` must define:
+/// - A default constructible member type `value_type` which will be yielded on a
+/// successful parse.
+/// - The static member function `Convert`, callable with signature
+/// `(const ARROW_TYPE& t, const char* s, size_t length, value_type* out)`.
+/// `Convert` returns truthy for successful parses and assigns the parsed values to
+/// `*out`. Parameters required for parsing (for example a timestamp's TimeUnit)
+/// are acquired from the type parameter `t`.
+template <typename ARROW_TYPE, typename Enable = void>
+struct StringConverter;
+
+template <typename T>
+struct is_parseable {
+ template <typename U, typename = typename StringConverter<U>::value_type>
+ static std::true_type Test(U*);
+
+ template <typename U>
+ static std::false_type Test(...);
+
+ static constexpr bool value = decltype(Test<T>(NULLPTR))::value;
+};
+
+template <typename T, typename R = void>
+using enable_if_parseable = enable_if_t<is_parseable<T>::value, R>;
+
+template <>
+struct StringConverter<BooleanType> {
+ using value_type = bool;
+
+ static bool Convert(const BooleanType&, const char* s, size_t length, value_type* out) {
+ if (length == 1) {
+ // "0" or "1"?
+ if (s[0] == '0') {
+ *out = false;
+ return true;
+ }
+ if (s[0] == '1') {
+ *out = true;
+ return true;
+ }
+ return false;
+ }
+ if (length == 4) {
+ // "true"?
+ *out = true;
+ return ((s[0] == 't' || s[0] == 'T') && (s[1] == 'r' || s[1] == 'R') &&
+ (s[2] == 'u' || s[2] == 'U') && (s[3] == 'e' || s[3] == 'E'));
+ }
+ if (length == 5) {
+ // "false"?
+ *out = false;
+ return ((s[0] == 'f' || s[0] == 'F') && (s[1] == 'a' || s[1] == 'A') &&
+ (s[2] == 'l' || s[2] == 'L') && (s[3] == 's' || s[3] == 'S') &&
+ (s[4] == 'e' || s[4] == 'E'));
+ }
+ return false;
+ }
+};
+
+// Ideas for faster float parsing:
+// - http://rapidjson.org/md_doc_internals.html#ParsingDouble
+// - https://github.com/google/double-conversion [used here]
+// - https://github.com/achan001/dtoa-fast
+
+ARROW_EXPORT
+bool StringToFloat(const char* s, size_t length, float* out);
+
+ARROW_EXPORT
+bool StringToFloat(const char* s, size_t length, double* out);
+
+template <>
+struct StringConverter<FloatType> {
+ using value_type = float;
+
+ static bool Convert(const FloatType&, const char* s, size_t length, value_type* out) {
+ return ARROW_PREDICT_TRUE(StringToFloat(s, length, out));
+ }
+};
+
+template <>
+struct StringConverter<DoubleType> {
+ using value_type = double;
+
+ static bool Convert(const DoubleType&, const char* s, size_t length, value_type* out) {
+ return ARROW_PREDICT_TRUE(StringToFloat(s, length, out));
+ }
+};
+
+// NOTE: HalfFloatType would require a half<->float conversion library
+
+inline uint8_t ParseDecimalDigit(char c) { return static_cast<uint8_t>(c - '0'); }
+
+#define PARSE_UNSIGNED_ITERATION(C_TYPE) \
+ if (length > 0) { \
+ uint8_t digit = ParseDecimalDigit(*s++); \
+ result = static_cast<C_TYPE>(result * 10U); \
+ length--; \
+ if (ARROW_PREDICT_FALSE(digit > 9U)) { \
+ /* Non-digit */ \
+ return false; \
+ } \
+ result = static_cast<C_TYPE>(result + digit); \
+ } else { \
+ break; \
+ }
+
+#define PARSE_UNSIGNED_ITERATION_LAST(C_TYPE) \
+ if (length > 0) { \
+ if (ARROW_PREDICT_FALSE(result > std::numeric_limits<C_TYPE>::max() / 10U)) { \
+ /* Overflow */ \
+ return false; \
+ } \
+ uint8_t digit = ParseDecimalDigit(*s++); \
+ result = static_cast<C_TYPE>(result * 10U); \
+ C_TYPE new_result = static_cast<C_TYPE>(result + digit); \
+ if (ARROW_PREDICT_FALSE(--length > 0)) { \
+ /* Too many digits */ \
+ return false; \
+ } \
+ if (ARROW_PREDICT_FALSE(digit > 9U)) { \
+ /* Non-digit */ \
+ return false; \
+ } \
+ if (ARROW_PREDICT_FALSE(new_result < result)) { \
+ /* Overflow */ \
+ return false; \
+ } \
+ result = new_result; \
+ }
+
+inline bool ParseUnsigned(const char* s, size_t length, uint8_t* out) {
+ uint8_t result = 0;
+
+ do {
+ PARSE_UNSIGNED_ITERATION(uint8_t);
+ PARSE_UNSIGNED_ITERATION(uint8_t);
+ PARSE_UNSIGNED_ITERATION_LAST(uint8_t);
+ } while (false);
+ *out = result;
+ return true;
+}
+
+inline bool ParseUnsigned(const char* s, size_t length, uint16_t* out) {
+ uint16_t result = 0;
+ do {
+ PARSE_UNSIGNED_ITERATION(uint16_t);
+ PARSE_UNSIGNED_ITERATION(uint16_t);
+ PARSE_UNSIGNED_ITERATION(uint16_t);
+ PARSE_UNSIGNED_ITERATION(uint16_t);
+ PARSE_UNSIGNED_ITERATION_LAST(uint16_t);
+ } while (false);
+ *out = result;
+ return true;
+}
+
+inline bool ParseUnsigned(const char* s, size_t length, uint32_t* out) {
+ uint32_t result = 0;
+ do {
+ PARSE_UNSIGNED_ITERATION(uint32_t);
+ PARSE_UNSIGNED_ITERATION(uint32_t);
+ PARSE_UNSIGNED_ITERATION(uint32_t);
+ PARSE_UNSIGNED_ITERATION(uint32_t);
+ PARSE_UNSIGNED_ITERATION(uint32_t);
+
+ PARSE_UNSIGNED_ITERATION(uint32_t);
+ PARSE_UNSIGNED_ITERATION(uint32_t);
+ PARSE_UNSIGNED_ITERATION(uint32_t);
+ PARSE_UNSIGNED_ITERATION(uint32_t);
+
+ PARSE_UNSIGNED_ITERATION_LAST(uint32_t);
+ } while (false);
+ *out = result;
+ return true;
+}
+
+inline bool ParseUnsigned(const char* s, size_t length, uint64_t* out) {
+ uint64_t result = 0;
+ do {
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+ PARSE_UNSIGNED_ITERATION(uint64_t);
+
+ PARSE_UNSIGNED_ITERATION_LAST(uint64_t);
+ } while (false);
+ *out = result;
+ return true;
+}
+
+#undef PARSE_UNSIGNED_ITERATION
+#undef PARSE_UNSIGNED_ITERATION_LAST
+
+template <typename T>
+bool ParseHex(const char* s, size_t length, T* out) {
+ // lets make sure that the length of the string is not too big
+ if (!ARROW_PREDICT_TRUE(sizeof(T) * 2 >= length && length > 0)) {
+ return false;
+ }
+ T result = 0;
+ for (size_t i = 0; i < length; i++) {
+ result = static_cast<T>(result << 4);
+ if (s[i] >= '0' && s[i] <= '9') {
+ result = static_cast<T>(result | (s[i] - '0'));
+ } else if (s[i] >= 'A' && s[i] <= 'F') {
+ result = static_cast<T>(result | (s[i] - 'A' + 10));
+ } else if (s[i] >= 'a' && s[i] <= 'f') {
+ result = static_cast<T>(result | (s[i] - 'a' + 10));
+ } else {
+ /* Non-digit */
+ return false;
+ }
+ }
+ *out = result;
+ return true;
+}
+
+template <class ARROW_TYPE>
+struct StringToUnsignedIntConverterMixin {
+ using value_type = typename ARROW_TYPE::c_type;
+
+ static bool Convert(const ARROW_TYPE&, const char* s, size_t length, value_type* out) {
+ if (ARROW_PREDICT_FALSE(length == 0)) {
+ return false;
+ }
+ // If it starts with 0x then its hex
+ if (length > 2 && s[0] == '0' && ((s[1] == 'x') || (s[1] == 'X'))) {
+ length -= 2;
+ s += 2;
+
+ return ARROW_PREDICT_TRUE(ParseHex(s, length, out));
+ }
+ // Skip leading zeros
+ while (length > 0 && *s == '0') {
+ length--;
+ s++;
+ }
+ return ParseUnsigned(s, length, out);
+ }
+};
+
+template <>
+struct StringConverter<UInt8Type> : public StringToUnsignedIntConverterMixin<UInt8Type> {
+ using StringToUnsignedIntConverterMixin<UInt8Type>::StringToUnsignedIntConverterMixin;
+};
+
+template <>
+struct StringConverter<UInt16Type>
+ : public StringToUnsignedIntConverterMixin<UInt16Type> {
+ using StringToUnsignedIntConverterMixin<UInt16Type>::StringToUnsignedIntConverterMixin;
+};
+
+template <>
+struct StringConverter<UInt32Type>
+ : public StringToUnsignedIntConverterMixin<UInt32Type> {
+ using StringToUnsignedIntConverterMixin<UInt32Type>::StringToUnsignedIntConverterMixin;
+};
+
+template <>
+struct StringConverter<UInt64Type>
+ : public StringToUnsignedIntConverterMixin<UInt64Type> {
+ using StringToUnsignedIntConverterMixin<UInt64Type>::StringToUnsignedIntConverterMixin;
+};
+
+template <class ARROW_TYPE>
+struct StringToSignedIntConverterMixin {
+ using value_type = typename ARROW_TYPE::c_type;
+ using unsigned_type = typename std::make_unsigned<value_type>::type;
+
+ static bool Convert(const ARROW_TYPE&, const char* s, size_t length, value_type* out) {
+ static constexpr auto max_positive =
+ static_cast<unsigned_type>(std::numeric_limits<value_type>::max());
+ // Assuming two's complement
+ static constexpr unsigned_type max_negative = max_positive + 1;
+ bool negative = false;
+ unsigned_type unsigned_value = 0;
+
+ if (ARROW_PREDICT_FALSE(length == 0)) {
+ return false;
+ }
+ // If it starts with 0x then its hex
+ if (length > 2 && s[0] == '0' && ((s[1] == 'x') || (s[1] == 'X'))) {
+ length -= 2;
+ s += 2;
+
+ if (!ARROW_PREDICT_TRUE(ParseHex(s, length, &unsigned_value))) {
+ return false;
+ }
+ *out = static_cast<value_type>(unsigned_value);
+ return true;
+ }
+
+ if (*s == '-') {
+ negative = true;
+ s++;
+ if (--length == 0) {
+ return false;
+ }
+ }
+ // Skip leading zeros
+ while (length > 0 && *s == '0') {
+ length--;
+ s++;
+ }
+ if (!ARROW_PREDICT_TRUE(ParseUnsigned(s, length, &unsigned_value))) {
+ return false;
+ }
+ if (negative) {
+ if (ARROW_PREDICT_FALSE(unsigned_value > max_negative)) {
+ return false;
+ }
+ // To avoid both compiler warnings (with unsigned negation)
+ // and undefined behaviour (with signed negation overflow),
+ // use the expanded formula for 2's complement negation.
+ *out = static_cast<value_type>(~unsigned_value + 1);
+ } else {
+ if (ARROW_PREDICT_FALSE(unsigned_value > max_positive)) {
+ return false;
+ }
+ *out = static_cast<value_type>(unsigned_value);
+ }
+ return true;
+ }
+};
+
+template <>
+struct StringConverter<Int8Type> : public StringToSignedIntConverterMixin<Int8Type> {
+ using StringToSignedIntConverterMixin<Int8Type>::StringToSignedIntConverterMixin;
+};
+
+template <>
+struct StringConverter<Int16Type> : public StringToSignedIntConverterMixin<Int16Type> {
+ using StringToSignedIntConverterMixin<Int16Type>::StringToSignedIntConverterMixin;
+};
+
+template <>
+struct StringConverter<Int32Type> : public StringToSignedIntConverterMixin<Int32Type> {
+ using StringToSignedIntConverterMixin<Int32Type>::StringToSignedIntConverterMixin;
+};
+
+template <>
+struct StringConverter<Int64Type> : public StringToSignedIntConverterMixin<Int64Type> {
+ using StringToSignedIntConverterMixin<Int64Type>::StringToSignedIntConverterMixin;
+};
+
+namespace detail {
+
+// Inline-able ISO-8601 parser
+
+using ts_type = TimestampType::c_type;
+
+template <typename Duration>
+static inline bool ParseYYYY_MM_DD(const char* s, Duration* since_epoch) {
+ uint16_t year = 0;
+ uint8_t month = 0;
+ uint8_t day = 0;
+ if (ARROW_PREDICT_FALSE(s[4] != '-') || ARROW_PREDICT_FALSE(s[7] != '-')) {
+ return false;
+ }
+ if (ARROW_PREDICT_FALSE(!ParseUnsigned(s + 0, 4, &year))) {
+ return false;
+ }
+ if (ARROW_PREDICT_FALSE(!ParseUnsigned(s + 5, 2, &month))) {
+ return false;
+ }
+ if (ARROW_PREDICT_FALSE(!ParseUnsigned(s + 8, 2, &day))) {
+ return false;
+ }
+ arrow_vendored::date::year_month_day ymd{arrow_vendored::date::year{year},
+ arrow_vendored::date::month{month},
+ arrow_vendored::date::day{day}};
+ if (ARROW_PREDICT_FALSE(!ymd.ok())) return false;
+
+ *since_epoch = std::chrono::duration_cast<Duration>(
+ arrow_vendored::date::sys_days{ymd}.time_since_epoch());
+ return true;
+}
+
+template <typename Duration>
+static inline bool ParseHH(const char* s, Duration* out) {
+ uint8_t hours = 0;
+ if (ARROW_PREDICT_FALSE(!ParseUnsigned(s + 0, 2, &hours))) {
+ return false;
+ }
+ if (ARROW_PREDICT_FALSE(hours >= 24)) {
+ return false;
+ }
+ *out = std::chrono::duration_cast<Duration>(std::chrono::hours(hours));
+ return true;
+}
+
+template <typename Duration>
+static inline bool ParseHH_MM(const char* s, Duration* out) {
+ uint8_t hours = 0;
+ uint8_t minutes = 0;
+ if (ARROW_PREDICT_FALSE(s[2] != ':')) {
+ return false;
+ }
+ if (ARROW_PREDICT_FALSE(!ParseUnsigned(s + 0, 2, &hours))) {
+ return false;
+ }
+ if (ARROW_PREDICT_FALSE(!ParseUnsigned(s + 3, 2, &minutes))) {
+ return false;
+ }
+ if (ARROW_PREDICT_FALSE(hours >= 24)) {
+ return false;
+ }
+ if (ARROW_PREDICT_FALSE(minutes >= 60)) {
+ return false;
+ }
+ *out = std::chrono::duration_cast<Duration>(std::chrono::hours(hours) +
+ std::chrono::minutes(minutes));
+ return true;
+}
+
+template <typename Duration>
+static inline bool ParseHH_MM_SS(const char* s, Duration* out) {
+ uint8_t hours = 0;
+ uint8_t minutes = 0;
+ uint8_t seconds = 0;
+ if (ARROW_PREDICT_FALSE(s[2] != ':') || ARROW_PREDICT_FALSE(s[5] != ':')) {
+ return false;
+ }
+ if (ARROW_PREDICT_FALSE(!ParseUnsigned(s + 0, 2, &hours))) {
+ return false;
+ }
+ if (ARROW_PREDICT_FALSE(!ParseUnsigned(s + 3, 2, &minutes))) {
+ return false;
+ }
+ if (ARROW_PREDICT_FALSE(!ParseUnsigned(s + 6, 2, &seconds))) {
+ return false;
+ }
+ if (ARROW_PREDICT_FALSE(hours >= 24)) {
+ return false;
+ }
+ if (ARROW_PREDICT_FALSE(minutes >= 60)) {
+ return false;
+ }
+ if (ARROW_PREDICT_FALSE(seconds >= 60)) {
+ return false;
+ }
+ *out = std::chrono::duration_cast<Duration>(std::chrono::hours(hours) +
+ std::chrono::minutes(minutes) +
+ std::chrono::seconds(seconds));
+ return true;
+}
+
+static inline bool ParseSubSeconds(const char* s, size_t length, TimeUnit::type unit,
+ uint32_t* out) {
+ // The decimal point has been peeled off at this point
+
+ // Fail if number of decimal places provided exceeds what the unit can hold.
+ // Calculate how many trailing decimal places are omitted for the unit
+ // e.g. if 4 decimal places are provided and unit is MICRO, 2 are missing
+ size_t omitted = 0;
+ switch (unit) {
+ case TimeUnit::MILLI:
+ if (ARROW_PREDICT_FALSE(length > 3)) {
+ return false;
+ }
+ if (length < 3) {
+ omitted = 3 - length;
+ }
+ break;
+ case TimeUnit::MICRO:
+ if (ARROW_PREDICT_FALSE(length > 6)) {
+ return false;
+ }
+ if (length < 6) {
+ omitted = 6 - length;
+ }
+ break;
+ case TimeUnit::NANO:
+ if (ARROW_PREDICT_FALSE(length > 9)) {
+ return false;
+ }
+ if (length < 9) {
+ omitted = 9 - length;
+ }
+ break;
+ default:
+ return false;
+ }
+
+ if (ARROW_PREDICT_TRUE(omitted == 0)) {
+ return ParseUnsigned(s, length, out);
+ } else {
+ uint32_t subseconds = 0;
+ bool success = ParseUnsigned(s, length, &subseconds);
+ if (ARROW_PREDICT_TRUE(success)) {
+ switch (omitted) {
+ case 1:
+ *out = subseconds * 10;
+ break;
+ case 2:
+ *out = subseconds * 100;
+ break;
+ case 3:
+ *out = subseconds * 1000;
+ break;
+ case 4:
+ *out = subseconds * 10000;
+ break;
+ case 5:
+ *out = subseconds * 100000;
+ break;
+ case 6:
+ *out = subseconds * 1000000;
+ break;
+ case 7:
+ *out = subseconds * 10000000;
+ break;
+ case 8:
+ *out = subseconds * 100000000;
+ break;
+ default:
+ // Impossible case
+ break;
+ }
+ return true;
+ } else {
+ return false;
+ }
+ }
+}
+
+} // namespace detail
+
+static inline bool ParseTimestampISO8601(const char* s, size_t length,
+ TimeUnit::type unit,
+ TimestampType::c_type* out) {
+ using seconds_type = std::chrono::duration<TimestampType::c_type>;
+
+ // We allow the following formats for all units:
+ // - "YYYY-MM-DD"
+ // - "YYYY-MM-DD[ T]hhZ?"
+ // - "YYYY-MM-DD[ T]hh:mmZ?"
+ // - "YYYY-MM-DD[ T]hh:mm:ssZ?"
+ //
+ // We allow the following formats for unit == MILLI, MICRO, or NANO:
+ // - "YYYY-MM-DD[ T]hh:mm:ss.s{1,3}Z?"
+ //
+ // We allow the following formats for unit == MICRO, or NANO:
+ // - "YYYY-MM-DD[ T]hh:mm:ss.s{4,6}Z?"
+ //
+ // We allow the following formats for unit == NANO:
+ // - "YYYY-MM-DD[ T]hh:mm:ss.s{7,9}Z?"
+ //
+ // UTC is always assumed, and the DataType's timezone is ignored.
+ //
+
+ if (ARROW_PREDICT_FALSE(length < 10)) return false;
+
+ seconds_type seconds_since_epoch;
+ if (ARROW_PREDICT_FALSE(!detail::ParseYYYY_MM_DD(s, &seconds_since_epoch))) {
+ return false;
+ }
+
+ if (length == 10) {
+ *out = util::CastSecondsToUnit(unit, seconds_since_epoch.count());
+ return true;
+ }
+
+ if (ARROW_PREDICT_FALSE(s[10] != ' ') && ARROW_PREDICT_FALSE(s[10] != 'T')) {
+ return false;
+ }
+
+ if (s[length - 1] == 'Z') {
+ --length;
+ }
+
+ seconds_type seconds_since_midnight;
+ switch (length) {
+ case 13: // YYYY-MM-DD[ T]hh
+ if (ARROW_PREDICT_FALSE(!detail::ParseHH(s + 11, &seconds_since_midnight))) {
+ return false;
+ }
+ break;
+ case 16: // YYYY-MM-DD[ T]hh:mm
+ if (ARROW_PREDICT_FALSE(!detail::ParseHH_MM(s + 11, &seconds_since_midnight))) {
+ return false;
+ }
+ break;
+ case 19: // YYYY-MM-DD[ T]hh:mm:ss
+ case 21: // YYYY-MM-DD[ T]hh:mm:ss.s
+ case 22: // YYYY-MM-DD[ T]hh:mm:ss.ss
+ case 23: // YYYY-MM-DD[ T]hh:mm:ss.sss
+ case 24: // YYYY-MM-DD[ T]hh:mm:ss.ssss
+ case 25: // YYYY-MM-DD[ T]hh:mm:ss.sssss
+ case 26: // YYYY-MM-DD[ T]hh:mm:ss.ssssss
+ case 27: // YYYY-MM-DD[ T]hh:mm:ss.sssssss
+ case 28: // YYYY-MM-DD[ T]hh:mm:ss.ssssssss
+ case 29: // YYYY-MM-DD[ T]hh:mm:ss.sssssssss
+ if (ARROW_PREDICT_FALSE(!detail::ParseHH_MM_SS(s + 11, &seconds_since_midnight))) {
+ return false;
+ }
+ break;
+ default:
+ return false;
+ }
+
+ seconds_since_epoch += seconds_since_midnight;
+
+ if (length <= 19) {
+ *out = util::CastSecondsToUnit(unit, seconds_since_epoch.count());
+ return true;
+ }
+
+ if (ARROW_PREDICT_FALSE(s[19] != '.')) {
+ return false;
+ }
+
+ uint32_t subseconds = 0;
+ if (ARROW_PREDICT_FALSE(
+ !detail::ParseSubSeconds(s + 20, length - 20, unit, &subseconds))) {
+ return false;
+ }
+
+ *out = util::CastSecondsToUnit(unit, seconds_since_epoch.count()) + subseconds;
+ return true;
+}
+
+/// \brief Returns time since the UNIX epoch in the requested unit
+static inline bool ParseTimestampStrptime(const char* buf, size_t length,
+ const char* format, bool ignore_time_in_day,
+ bool allow_trailing_chars, TimeUnit::type unit,
+ int64_t* out) {
+ // NOTE: strptime() is more than 10x faster than arrow_vendored::date::parse().
+ // The buffer may not be nul-terminated
+ std::string clean_copy(buf, length);
+ struct tm result;
+ memset(&result, 0, sizeof(struct tm));
+#ifdef _WIN32
+ char* ret = arrow_strptime(clean_copy.c_str(), format, &result);
+#else
+ char* ret = strptime(clean_copy.c_str(), format, &result);
+#endif
+ if (ret == NULLPTR) {
+ return false;
+ }
+ if (!allow_trailing_chars && static_cast<size_t>(ret - clean_copy.c_str()) != length) {
+ return false;
+ }
+ // ignore the time part
+ arrow_vendored::date::sys_seconds secs =
+ arrow_vendored::date::sys_days(arrow_vendored::date::year(result.tm_year + 1900) /
+ (result.tm_mon + 1) / result.tm_mday);
+ if (!ignore_time_in_day) {
+ secs += (std::chrono::hours(result.tm_hour) + std::chrono::minutes(result.tm_min) +
+ std::chrono::seconds(result.tm_sec));
+ }
+ *out = util::CastSecondsToUnit(unit, secs.time_since_epoch().count());
+ return true;
+}
+
+template <>
+struct StringConverter<TimestampType> {
+ using value_type = int64_t;
+
+ static bool Convert(const TimestampType& type, const char* s, size_t length,
+ value_type* out) {
+ return ParseTimestampISO8601(s, length, type.unit(), out);
+ }
+};
+
+template <>
+struct StringConverter<DurationType>
+ : public StringToSignedIntConverterMixin<DurationType> {
+ using StringToSignedIntConverterMixin<DurationType>::StringToSignedIntConverterMixin;
+};
+
+template <typename DATE_TYPE>
+struct StringConverter<DATE_TYPE, enable_if_date<DATE_TYPE>> {
+ using value_type = typename DATE_TYPE::c_type;
+
+ using duration_type =
+ typename std::conditional<std::is_same<DATE_TYPE, Date32Type>::value,
+ arrow_vendored::date::days,
+ std::chrono::milliseconds>::type;
+
+ static bool Convert(const DATE_TYPE& type, const char* s, size_t length,
+ value_type* out) {
+ if (ARROW_PREDICT_FALSE(length != 10)) {
+ return false;
+ }
+
+ duration_type since_epoch;
+ if (ARROW_PREDICT_FALSE(!detail::ParseYYYY_MM_DD(s, &since_epoch))) {
+ return false;
+ }
+
+ *out = static_cast<value_type>(since_epoch.count());
+ return true;
+ }
+};
+
+template <typename TIME_TYPE>
+struct StringConverter<TIME_TYPE, enable_if_time<TIME_TYPE>> {
+ using value_type = typename TIME_TYPE::c_type;
+
+ // We allow the following formats for all units:
+ // - "hh:mm"
+ // - "hh:mm:ss"
+ //
+ // We allow the following formats for unit == MILLI, MICRO, or NANO:
+ // - "hh:mm:ss.s{1,3}"
+ //
+ // We allow the following formats for unit == MICRO, or NANO:
+ // - "hh:mm:ss.s{4,6}"
+ //
+ // We allow the following formats for unit == NANO:
+ // - "hh:mm:ss.s{7,9}"
+
+ static bool Convert(const TIME_TYPE& type, const char* s, size_t length,
+ value_type* out) {
+ const auto unit = type.unit();
+ std::chrono::seconds since_midnight;
+
+ if (length == 5) {
+ if (ARROW_PREDICT_FALSE(!detail::ParseHH_MM(s, &since_midnight))) {
+ return false;
+ }
+ *out =
+ static_cast<value_type>(util::CastSecondsToUnit(unit, since_midnight.count()));
+ return true;
+ }
+
+ if (ARROW_PREDICT_FALSE(length < 8)) {
+ return false;
+ }
+ if (ARROW_PREDICT_FALSE(!detail::ParseHH_MM_SS(s, &since_midnight))) {
+ return false;
+ }
+
+ *out = static_cast<value_type>(util::CastSecondsToUnit(unit, since_midnight.count()));
+
+ if (length == 8) {
+ return true;
+ }
+
+ if (ARROW_PREDICT_FALSE(s[8] != '.')) {
+ return false;
+ }
+
+ uint32_t subseconds_count = 0;
+ if (ARROW_PREDICT_FALSE(
+ !detail::ParseSubSeconds(s + 9, length - 9, unit, &subseconds_count))) {
+ return false;
+ }
+
+ *out += subseconds_count;
+ return true;
+ }
+};
+
+/// \brief Convenience wrappers around internal::StringConverter.
+template <typename T>
+bool ParseValue(const T& type, const char* s, size_t length,
+ typename StringConverter<T>::value_type* out) {
+ return StringConverter<T>::Convert(type, s, length, out);
+}
+
+template <typename T>
+enable_if_parameter_free<T, bool> ParseValue(
+ const char* s, size_t length, typename StringConverter<T>::value_type* out) {
+ static T type;
+ return StringConverter<T>::Convert(type, s, length, out);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/value_parsing_benchmark.cc b/src/arrow/cpp/src/arrow/util/value_parsing_benchmark.cc
new file mode 100644
index 000000000..40d139316
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/value_parsing_benchmark.cc
@@ -0,0 +1,303 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/type.h"
+#include "arrow/util/formatting.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/value_parsing.h"
+
+namespace arrow {
+namespace internal {
+
+template <typename c_int>
+static std::vector<std::string> MakeIntStrings(int32_t num_items) {
+ using c_int_limits = std::numeric_limits<c_int>;
+ std::vector<std::string> base_strings = {"0",
+ "5",
+ c_int_limits::is_signed ? "-12" : "12",
+ "34",
+ "99",
+ c_int_limits::is_signed ? "-111" : "111",
+ std::to_string(c_int_limits::min()),
+ std::to_string(c_int_limits::max())};
+ std::vector<std::string> strings;
+ for (int32_t i = 0; i < num_items; ++i) {
+ strings.push_back(base_strings[i % base_strings.size()]);
+ }
+ return strings;
+}
+
+template <typename c_int>
+static std::vector<std::string> MakeHexStrings(int32_t num_items) {
+ int32_t num_bytes = sizeof(c_int);
+ const char* kAsciiTable = "0123456789ABCDEF";
+ std::vector<char> large_hex_chars(num_bytes * 2 + 2);
+ large_hex_chars[0] = '0';
+ large_hex_chars[1] = 'x';
+ for (int32_t i = 0; i < num_bytes * 2; ++i) {
+ large_hex_chars[i + 2] = kAsciiTable[i];
+ }
+ std::string large_hex(&large_hex_chars[0], large_hex_chars.size());
+
+ std::vector<std::string> base_strings = {"0x0", "0xA5", "0x5E", large_hex};
+ std::vector<std::string> strings;
+ for (int32_t i = 0; i < num_items; ++i) {
+ strings.push_back(base_strings[i % base_strings.size()]);
+ }
+ return strings;
+}
+
+static std::vector<std::string> MakeFloatStrings(int32_t num_items) {
+ std::vector<std::string> base_strings = {"0.0", "5", "-12.3",
+ "98765430000", "3456.789", "0.0012345",
+ "2.34567e8", "-5.67e-8"};
+ std::vector<std::string> strings;
+ for (int32_t i = 0; i < num_items; ++i) {
+ strings.push_back(base_strings[i % base_strings.size()]);
+ }
+ return strings;
+}
+
+static std::vector<std::string> MakeTimestampStrings(int32_t num_items) {
+ std::vector<std::string> base_strings = {"2018-11-13 17:11:10", "2018-11-13 11:22:33",
+ "2016-02-29 11:22:33"};
+
+ std::vector<std::string> strings;
+ for (int32_t i = 0; i < num_items; ++i) {
+ strings.push_back(base_strings[i % base_strings.size()]);
+ }
+ return strings;
+}
+
+template <typename c_int, typename c_int_limits = std::numeric_limits<c_int>>
+static typename std::enable_if<c_int_limits::is_signed, std::vector<c_int>>::type
+MakeInts(int32_t num_items) {
+ std::vector<c_int> out;
+ // C++ doesn't guarantee that all integer types support std::uniform_int_distribution,
+ // so use a known type (int64_t)
+ randint<int64_t, c_int>(num_items, c_int_limits::min(), c_int_limits::max(), &out);
+ return out;
+}
+
+template <typename c_int, typename c_int_limits = std::numeric_limits<c_int>>
+static typename std::enable_if<!c_int_limits::is_signed, std::vector<c_int>>::type
+MakeInts(int32_t num_items) {
+ std::vector<c_int> out;
+ // See above.
+ randint<uint64_t, c_int>(num_items, c_int_limits::min(), c_int_limits::max(), &out);
+ return out;
+}
+
+template <typename c_float>
+static std::vector<c_float> MakeFloats(int32_t num_items) {
+ std::vector<c_float> out;
+ random_real<double, c_float>(num_items, /*seed =*/42, -1e10, 1e10, &out);
+ return out;
+}
+
+template <typename ARROW_TYPE, typename C_TYPE = typename ARROW_TYPE::c_type>
+static void IntegerParsing(benchmark::State& state) { // NOLINT non-const reference
+ auto strings = MakeIntStrings<C_TYPE>(1000);
+
+ while (state.KeepRunning()) {
+ C_TYPE total = 0;
+ for (const auto& s : strings) {
+ C_TYPE value;
+ if (!ParseValue<ARROW_TYPE>(s.data(), s.length(), &value)) {
+ std::cerr << "Conversion failed for '" << s << "'";
+ std::abort();
+ }
+ total = static_cast<C_TYPE>(total + value);
+ }
+ benchmark::DoNotOptimize(total);
+ }
+ state.SetItemsProcessed(state.iterations() * strings.size());
+}
+
+template <typename ARROW_TYPE, typename C_TYPE = typename ARROW_TYPE::c_type>
+static void HexParsing(benchmark::State& state) { // NOLINT non-const reference
+ auto strings = MakeHexStrings<C_TYPE>(1000);
+
+ while (state.KeepRunning()) {
+ C_TYPE total = 0;
+ for (const auto& s : strings) {
+ C_TYPE value;
+ if (!ParseValue<ARROW_TYPE>(s.data(), s.length(), &value)) {
+ std::cerr << "Conversion failed for '" << s << "'";
+ std::abort();
+ }
+ total = static_cast<C_TYPE>(total + value);
+ }
+ benchmark::DoNotOptimize(total);
+ }
+ state.SetItemsProcessed(state.iterations() * strings.size());
+}
+
+template <typename ARROW_TYPE, typename C_TYPE = typename ARROW_TYPE::c_type>
+static void FloatParsing(benchmark::State& state) { // NOLINT non-const reference
+ auto strings = MakeFloatStrings(1000);
+
+ while (state.KeepRunning()) {
+ C_TYPE total = 0;
+ for (const auto& s : strings) {
+ C_TYPE value;
+ if (!ParseValue<ARROW_TYPE>(s.data(), s.length(), &value)) {
+ std::cerr << "Conversion failed for '" << s << "'";
+ std::abort();
+ }
+ total += value;
+ }
+ benchmark::DoNotOptimize(total);
+ }
+ state.SetItemsProcessed(state.iterations() * strings.size());
+}
+
+static void BenchTimestampParsing(
+ benchmark::State& state, TimeUnit::type unit,
+ const TimestampParser& parser) { // NOLINT non-const reference
+ using c_type = TimestampType::c_type;
+
+ auto strings = MakeTimestampStrings(1000);
+
+ for (auto _ : state) {
+ c_type total = 0;
+ for (const auto& s : strings) {
+ c_type value;
+ if (!parser(s.data(), s.length(), unit, &value)) {
+ std::cerr << "Conversion failed for '" << s << "'";
+ std::abort();
+ }
+ total += value;
+ }
+ benchmark::DoNotOptimize(total);
+ }
+ state.SetItemsProcessed(state.iterations() * strings.size());
+}
+
+template <TimeUnit::type UNIT>
+static void TimestampParsingISO8601(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto parser = TimestampParser::MakeISO8601();
+ BenchTimestampParsing(state, UNIT, *parser);
+}
+
+template <TimeUnit::type UNIT>
+static void TimestampParsingStrptime(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto parser = TimestampParser::MakeStrptime("%Y-%m-%d %H:%M:%S");
+ BenchTimestampParsing(state, UNIT, *parser);
+}
+
+struct DummyAppender {
+ Status operator()(util::string_view v) {
+ if (pos_ >= static_cast<int32_t>(v.size())) {
+ pos_ = 0;
+ }
+ total_ += v[pos_++];
+ return Status::OK();
+ }
+
+ int64_t total_ = 0;
+ int32_t pos_ = 0;
+};
+
+template <typename ARROW_TYPE, typename C_TYPE = typename ARROW_TYPE::c_type>
+static void IntegerFormatting(benchmark::State& state) { // NOLINT non-const reference
+ std::vector<C_TYPE> values = MakeInts<C_TYPE>(1000);
+ StringFormatter<ARROW_TYPE> formatter;
+
+ while (state.KeepRunning()) {
+ DummyAppender appender;
+ for (const auto value : values) {
+ ABORT_NOT_OK(formatter(value, appender));
+ }
+ benchmark::DoNotOptimize(appender.total_);
+ }
+ state.SetItemsProcessed(state.iterations() * values.size());
+}
+
+template <typename ARROW_TYPE, typename C_TYPE = typename ARROW_TYPE::c_type>
+static void FloatFormatting(benchmark::State& state) { // NOLINT non-const reference
+ std::vector<C_TYPE> values = MakeFloats<C_TYPE>(1000);
+ StringFormatter<ARROW_TYPE> formatter;
+
+ while (state.KeepRunning()) {
+ DummyAppender appender;
+ for (const auto value : values) {
+ ABORT_NOT_OK(formatter(value, appender));
+ }
+ benchmark::DoNotOptimize(appender.total_);
+ }
+ state.SetItemsProcessed(state.iterations() * values.size());
+}
+
+BENCHMARK_TEMPLATE(IntegerParsing, Int8Type);
+BENCHMARK_TEMPLATE(IntegerParsing, Int16Type);
+BENCHMARK_TEMPLATE(IntegerParsing, Int32Type);
+BENCHMARK_TEMPLATE(IntegerParsing, Int64Type);
+BENCHMARK_TEMPLATE(IntegerParsing, UInt8Type);
+BENCHMARK_TEMPLATE(IntegerParsing, UInt16Type);
+BENCHMARK_TEMPLATE(IntegerParsing, UInt32Type);
+BENCHMARK_TEMPLATE(IntegerParsing, UInt64Type);
+
+BENCHMARK_TEMPLATE(HexParsing, Int8Type);
+BENCHMARK_TEMPLATE(HexParsing, Int16Type);
+BENCHMARK_TEMPLATE(HexParsing, Int32Type);
+BENCHMARK_TEMPLATE(HexParsing, Int64Type);
+BENCHMARK_TEMPLATE(HexParsing, UInt8Type);
+BENCHMARK_TEMPLATE(HexParsing, UInt16Type);
+BENCHMARK_TEMPLATE(HexParsing, UInt32Type);
+BENCHMARK_TEMPLATE(HexParsing, UInt64Type);
+
+BENCHMARK_TEMPLATE(FloatParsing, FloatType);
+BENCHMARK_TEMPLATE(FloatParsing, DoubleType);
+
+BENCHMARK_TEMPLATE(TimestampParsingISO8601, TimeUnit::SECOND);
+BENCHMARK_TEMPLATE(TimestampParsingISO8601, TimeUnit::MILLI);
+BENCHMARK_TEMPLATE(TimestampParsingISO8601, TimeUnit::MICRO);
+BENCHMARK_TEMPLATE(TimestampParsingISO8601, TimeUnit::NANO);
+BENCHMARK_TEMPLATE(TimestampParsingStrptime, TimeUnit::MILLI);
+
+BENCHMARK_TEMPLATE(IntegerFormatting, Int8Type);
+BENCHMARK_TEMPLATE(IntegerFormatting, Int16Type);
+BENCHMARK_TEMPLATE(IntegerFormatting, Int32Type);
+BENCHMARK_TEMPLATE(IntegerFormatting, Int64Type);
+BENCHMARK_TEMPLATE(IntegerFormatting, UInt8Type);
+BENCHMARK_TEMPLATE(IntegerFormatting, UInt16Type);
+BENCHMARK_TEMPLATE(IntegerFormatting, UInt32Type);
+BENCHMARK_TEMPLATE(IntegerFormatting, UInt64Type);
+
+BENCHMARK_TEMPLATE(FloatFormatting, FloatType);
+BENCHMARK_TEMPLATE(FloatFormatting, DoubleType);
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/value_parsing_test.cc b/src/arrow/cpp/src/arrow/util/value_parsing_test.cc
new file mode 100644
index 000000000..ebbb73339
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/value_parsing_test.cc
@@ -0,0 +1,643 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/value_parsing.h"
+
+namespace arrow {
+namespace internal {
+
+template <typename T>
+void AssertConversion(const T& type, const std::string& s, typename T::c_type expected) {
+ typename T::c_type out{};
+ ASSERT_TRUE(ParseValue(type, s.data(), s.length(), &out))
+ << "Conversion failed for '" << s << "' (expected to return " << expected << ")";
+ ASSERT_EQ(out, expected) << "Conversion failed for '" << s << "'";
+}
+
+template <typename T>
+void AssertConversion(const std::string& s, typename T::c_type expected) {
+ auto type = checked_pointer_cast<T>(TypeTraits<T>::type_singleton());
+ AssertConversion(*type, s, expected);
+}
+
+template <typename T>
+void AssertConversionFails(const T& type, const std::string& s) {
+ typename T::c_type out{};
+ ASSERT_FALSE(ParseValue(type, s.data(), s.length(), &out))
+ << "Conversion should have failed for '" << s << "' (returned " << out << ")";
+}
+
+template <typename T>
+void AssertConversionFails(const std::string& s) {
+ auto type = checked_pointer_cast<T>(TypeTraits<T>::type_singleton());
+ AssertConversionFails(*type, s);
+}
+
+TEST(StringConversion, ToBoolean) {
+ AssertConversion<BooleanType>("true", true);
+ AssertConversion<BooleanType>("tRuE", true);
+ AssertConversion<BooleanType>("FAlse", false);
+ AssertConversion<BooleanType>("false", false);
+ AssertConversion<BooleanType>("1", true);
+ AssertConversion<BooleanType>("0", false);
+
+ AssertConversionFails<BooleanType>("");
+}
+
+TEST(StringConversion, ToFloat) {
+ AssertConversion<FloatType>("1.5", 1.5f);
+ AssertConversion<FloatType>("0", 0.0f);
+ // XXX ASSERT_EQ doesn't distinguish signed zeros
+ AssertConversion<FloatType>("-0.0", -0.0f);
+ AssertConversion<FloatType>("-1e20", -1e20f);
+
+ AssertConversionFails<FloatType>("");
+ AssertConversionFails<FloatType>("e");
+}
+
+TEST(StringConversion, ToDouble) {
+ AssertConversion<DoubleType>("1.5", 1.5);
+ AssertConversion<DoubleType>("0", 0);
+ // XXX ASSERT_EQ doesn't distinguish signed zeros
+ AssertConversion<DoubleType>("-0.0", -0.0);
+ AssertConversion<DoubleType>("-1e100", -1e100);
+
+ AssertConversionFails<DoubleType>("");
+ AssertConversionFails<DoubleType>("e");
+}
+
+#if !defined(_WIN32) || defined(NDEBUG)
+
+TEST(StringConversion, ToFloatLocale) {
+ // French locale uses the comma as decimal point
+ LocaleGuard locale_guard("fr_FR.UTF-8");
+
+ AssertConversion<FloatType>("1.5", 1.5f);
+}
+
+TEST(StringConversion, ToDoubleLocale) {
+ // French locale uses the comma as decimal point
+ LocaleGuard locale_guard("fr_FR.UTF-8");
+
+ AssertConversion<DoubleType>("1.5", 1.5f);
+}
+
+#endif // _WIN32
+
+TEST(StringConversion, ToInt8) {
+ AssertConversion<Int8Type>("0", 0);
+ AssertConversion<Int8Type>("127", 127);
+ AssertConversion<Int8Type>("0127", 127);
+ AssertConversion<Int8Type>("-128", -128);
+ AssertConversion<Int8Type>("-00128", -128);
+
+ // Non-representable values
+ AssertConversionFails<Int8Type>("128");
+ AssertConversionFails<Int8Type>("-129");
+
+ AssertConversionFails<Int8Type>("");
+ AssertConversionFails<Int8Type>("-");
+ AssertConversionFails<Int8Type>("0.0");
+ AssertConversionFails<Int8Type>("e");
+
+ // Hex
+ AssertConversion<Int8Type>("0x0", 0);
+ AssertConversion<Int8Type>("0X1A", 26);
+ AssertConversion<Int8Type>("0xb", 11);
+ AssertConversion<Int8Type>("0x7F", 127);
+ AssertConversion<Int8Type>("0xFF", -1);
+ AssertConversionFails<Int8Type>("0x");
+ AssertConversionFails<Int8Type>("0x100");
+ AssertConversionFails<Int8Type>("0x1g");
+}
+
+TEST(StringConversion, ToUInt8) {
+ AssertConversion<UInt8Type>("0", 0);
+ AssertConversion<UInt8Type>("26", 26);
+ AssertConversion<UInt8Type>("255", 255);
+ AssertConversion<UInt8Type>("0255", 255);
+
+ // Non-representable values
+ AssertConversionFails<UInt8Type>("-1");
+ AssertConversionFails<UInt8Type>("256");
+ AssertConversionFails<UInt8Type>("260");
+ AssertConversionFails<UInt8Type>("1234");
+
+ AssertConversionFails<UInt8Type>("");
+ AssertConversionFails<UInt8Type>("-");
+ AssertConversionFails<UInt8Type>("0.0");
+ AssertConversionFails<UInt8Type>("e");
+
+ // Hex
+ AssertConversion<UInt8Type>("0x0", 0);
+ AssertConversion<UInt8Type>("0x1A", 26);
+ AssertConversion<UInt8Type>("0xb", 11);
+ AssertConversion<UInt8Type>("0x7F", 127);
+ AssertConversion<UInt8Type>("0xFF", 255);
+ AssertConversionFails<UInt8Type>("0x");
+ AssertConversionFails<UInt8Type>("0x100");
+ AssertConversionFails<UInt8Type>("0x1g");
+}
+
+TEST(StringConversion, ToInt16) {
+ AssertConversion<Int16Type>("0", 0);
+ AssertConversion<Int16Type>("32767", 32767);
+ AssertConversion<Int16Type>("032767", 32767);
+ AssertConversion<Int16Type>("-32768", -32768);
+ AssertConversion<Int16Type>("-0032768", -32768);
+
+ // Non-representable values
+ AssertConversionFails<Int16Type>("32768");
+ AssertConversionFails<Int16Type>("-32769");
+
+ AssertConversionFails<Int16Type>("");
+ AssertConversionFails<Int16Type>("-");
+ AssertConversionFails<Int16Type>("0.0");
+ AssertConversionFails<Int16Type>("e");
+
+ // Hex
+ AssertConversion<Int16Type>("0x0", 0);
+ AssertConversion<Int16Type>("0X1aA", 426);
+ AssertConversion<Int16Type>("0xb", 11);
+ AssertConversion<Int16Type>("0x7ffF", 32767);
+ AssertConversion<Int16Type>("0XfffF", -1);
+ AssertConversionFails<Int16Type>("0x");
+ AssertConversionFails<Int16Type>("0x10000");
+ AssertConversionFails<Int16Type>("0x1g");
+}
+
+TEST(StringConversion, ToUInt16) {
+ AssertConversion<UInt16Type>("0", 0);
+ AssertConversion<UInt16Type>("6660", 6660);
+ AssertConversion<UInt16Type>("65535", 65535);
+ AssertConversion<UInt16Type>("065535", 65535);
+
+ // Non-representable values
+ AssertConversionFails<UInt16Type>("-1");
+ AssertConversionFails<UInt16Type>("65536");
+ AssertConversionFails<UInt16Type>("123456");
+
+ AssertConversionFails<UInt16Type>("");
+ AssertConversionFails<UInt16Type>("-");
+ AssertConversionFails<UInt16Type>("0.0");
+ AssertConversionFails<UInt16Type>("e");
+
+ // Hex
+ AssertConversion<UInt16Type>("0x0", 0);
+ AssertConversion<UInt16Type>("0x1aA", 426);
+ AssertConversion<UInt16Type>("0xb", 11);
+ AssertConversion<UInt16Type>("0x7ffF", 32767);
+ AssertConversion<UInt16Type>("0xFffF", 65535);
+ AssertConversionFails<UInt16Type>("0x");
+ AssertConversionFails<UInt16Type>("0x10000");
+ AssertConversionFails<UInt16Type>("0x1g");
+}
+
+TEST(StringConversion, ToInt32) {
+ AssertConversion<Int32Type>("0", 0);
+ AssertConversion<Int32Type>("2147483647", 2147483647);
+ AssertConversion<Int32Type>("02147483647", 2147483647);
+ AssertConversion<Int32Type>("-2147483648", -2147483648LL);
+ AssertConversion<Int32Type>("-002147483648", -2147483648LL);
+
+ // Non-representable values
+ AssertConversionFails<Int32Type>("2147483648");
+ AssertConversionFails<Int32Type>("-2147483649");
+
+ AssertConversionFails<Int32Type>("");
+ AssertConversionFails<Int32Type>("-");
+ AssertConversionFails<Int32Type>("0.0");
+ AssertConversionFails<Int32Type>("e");
+
+ // Hex
+ AssertConversion<Int32Type>("0x0", 0);
+ AssertConversion<Int32Type>("0x123ABC", 1194684);
+ AssertConversion<Int32Type>("0xA4B35", 674613);
+ AssertConversion<Int32Type>("0x7FFFFFFF", 2147483647);
+ AssertConversion<Int32Type>("0x123abc", 1194684);
+ AssertConversion<Int32Type>("0xA4b35", 674613);
+ AssertConversion<Int32Type>("0x7FFFfFfF", 2147483647);
+ AssertConversion<Int32Type>("0XFFFFfFfF", -1);
+ AssertConversionFails<Int32Type>("0X");
+ AssertConversionFails<Int32Type>("0x23512ak");
+}
+
+TEST(StringConversion, ToUInt32) {
+ AssertConversion<UInt32Type>("0", 0);
+ AssertConversion<UInt32Type>("432198765", 432198765UL);
+ AssertConversion<UInt32Type>("4294967295", 4294967295UL);
+ AssertConversion<UInt32Type>("04294967295", 4294967295UL);
+
+ // Non-representable values
+ AssertConversionFails<UInt32Type>("-1");
+ AssertConversionFails<UInt32Type>("4294967296");
+ AssertConversionFails<UInt32Type>("12345678901");
+
+ AssertConversionFails<UInt32Type>("");
+ AssertConversionFails<UInt32Type>("-");
+ AssertConversionFails<UInt32Type>("0.0");
+ AssertConversionFails<UInt32Type>("e");
+
+ // Hex
+ AssertConversion<UInt32Type>("0x0", 0);
+ AssertConversion<UInt32Type>("0x123ABC", 1194684);
+ AssertConversion<UInt32Type>("0xA4B35", 674613);
+ AssertConversion<UInt32Type>("0x7FFFFFFF", 2147483647);
+ AssertConversion<UInt32Type>("0x123abc", 1194684);
+ AssertConversion<UInt32Type>("0xA4b35", 674613);
+ AssertConversion<UInt32Type>("0x7FFFfFfF", 2147483647);
+ AssertConversion<UInt32Type>("0XFFFFfFfF", 4294967295);
+ AssertConversionFails<UInt32Type>("0X");
+ AssertConversionFails<UInt32Type>("0x23512ak");
+}
+
+TEST(StringConversion, ToInt64) {
+ AssertConversion<Int64Type>("0", 0);
+ AssertConversion<Int64Type>("9223372036854775807", 9223372036854775807LL);
+ AssertConversion<Int64Type>("09223372036854775807", 9223372036854775807LL);
+ AssertConversion<Int64Type>("-9223372036854775808", -9223372036854775807LL - 1);
+ AssertConversion<Int64Type>("-009223372036854775808", -9223372036854775807LL - 1);
+
+ // Non-representable values
+ AssertConversionFails<Int64Type>("9223372036854775808");
+ AssertConversionFails<Int64Type>("-9223372036854775809");
+
+ AssertConversionFails<Int64Type>("");
+ AssertConversionFails<Int64Type>("-");
+ AssertConversionFails<Int64Type>("0.0");
+ AssertConversionFails<Int64Type>("e");
+
+ // Hex
+ AssertConversion<Int64Type>("0x0", 0);
+ AssertConversion<Int64Type>("0x5415a123ABC123cb", 6058926048274359243);
+ AssertConversion<Int64Type>("0xA4B35", 674613);
+ AssertConversion<Int64Type>("0x7FFFFFFFFFFFFFFf", 9223372036854775807);
+ AssertConversion<Int64Type>("0XF000000000000001", -1152921504606846975);
+ AssertConversion<Int64Type>("0xfFFFFFFFFFFFFFFf", -1);
+ AssertConversionFails<Int64Type>("0X");
+ AssertConversionFails<Int64Type>("0x12345678901234567");
+ AssertConversionFails<Int64Type>("0x23512ak");
+}
+
+TEST(StringConversion, ToUInt64) {
+ AssertConversion<UInt64Type>("0", 0);
+ AssertConversion<UInt64Type>("18446744073709551615", 18446744073709551615ULL);
+
+ // Non-representable values
+ AssertConversionFails<UInt64Type>("-1");
+ AssertConversionFails<UInt64Type>("18446744073709551616");
+
+ AssertConversionFails<UInt64Type>("");
+ AssertConversionFails<UInt64Type>("-");
+ AssertConversionFails<UInt64Type>("0.0");
+ AssertConversionFails<UInt64Type>("e");
+
+ // Hex
+ AssertConversion<UInt64Type>("0x0", 0);
+ AssertConversion<UInt64Type>("0x5415a123ABC123cb", 6058926048274359243);
+ AssertConversion<UInt64Type>("0xA4B35", 674613);
+ AssertConversion<UInt64Type>("0x7FFFFFFFFFFFFFFf", 9223372036854775807);
+ AssertConversion<UInt64Type>("0XF000000000000001", 17293822569102704641ULL);
+ AssertConversion<UInt64Type>("0xfFFFFFFFFFFFFFFf", 18446744073709551615ULL);
+ AssertConversionFails<UInt64Type>("0x");
+ AssertConversionFails<UInt64Type>("0x12345678901234567");
+ AssertConversionFails<UInt64Type>("0x23512ak");
+}
+
+TEST(StringConversion, ToDate32) {
+ AssertConversion<Date32Type>("1970-01-01", 0);
+ AssertConversion<Date32Type>("1970-01-02", 1);
+ AssertConversion<Date32Type>("2020-03-15", 18336);
+ AssertConversion<Date32Type>("1945-05-08", -9004);
+ AssertConversion<Date32Type>("4707-11-28", 999999);
+ AssertConversion<Date32Type>("0001-01-01", -719162);
+
+ // Invalid format
+ AssertConversionFails<Date32Type>("");
+ AssertConversionFails<Date32Type>("1970");
+ AssertConversionFails<Date32Type>("1970-01");
+ AssertConversionFails<Date32Type>("1970-01-01 00:00:00");
+ AssertConversionFails<Date32Type>("1970/01/01");
+}
+
+TEST(StringConversion, ToDate64) {
+ AssertConversion<Date64Type>("1970-01-01", 0);
+ AssertConversion<Date64Type>("1970-01-02", 86400000);
+ AssertConversion<Date64Type>("2020-03-15", 1584230400000LL);
+ AssertConversion<Date64Type>("1945-05-08", -777945600000LL);
+ AssertConversion<Date64Type>("4707-11-28", 86399913600000LL);
+ AssertConversion<Date64Type>("0001-01-01", -62135596800000LL);
+}
+
+template <typename T>
+void AssertInvalidTimes(const T& type) {
+ // Invalid time format
+ AssertConversionFails(type, "");
+ AssertConversionFails(type, "00");
+ AssertConversionFails(type, "00:");
+ AssertConversionFails(type, "00:00:");
+ AssertConversionFails(type, "00:00:00:");
+ AssertConversionFails(type, "000000");
+ AssertConversionFails(type, "000000.000");
+
+ // Invalid time value
+ AssertConversionFails(type, "24:00:00");
+ AssertConversionFails(type, "00:60:00");
+ AssertConversionFails(type, "00:00:60");
+}
+
+TEST(StringConversion, ToTime32) {
+ {
+ Time32Type type{TimeUnit::SECOND};
+
+ AssertConversion(type, "00:00", 0);
+ AssertConversion(type, "01:23", 4980);
+ AssertConversion(type, "23:59", 86340);
+
+ AssertConversion(type, "00:00:00", 0);
+ AssertConversion(type, "01:23:45", 5025);
+ AssertConversion(type, "23:45:43", 85543);
+ AssertConversion(type, "23:59:59", 86399);
+
+ AssertInvalidTimes(type);
+ // No subseconds allowed
+ AssertConversionFails(type, "00:00:00.123");
+ }
+ {
+ Time32Type type{TimeUnit::MILLI};
+
+ AssertConversion(type, "00:00", 0);
+ AssertConversion(type, "01:23", 4980000);
+ AssertConversion(type, "23:59", 86340000);
+
+ AssertConversion(type, "00:00:00", 0);
+ AssertConversion(type, "01:23:45", 5025000);
+ AssertConversion(type, "23:45:43", 85543000);
+ AssertConversion(type, "23:59:59", 86399000);
+
+ AssertConversion(type, "00:00:00.123", 123);
+ AssertConversion(type, "01:23:45.000", 5025000);
+ AssertConversion(type, "01:23:45.1", 5025100);
+ AssertConversion(type, "01:23:45.123", 5025123);
+ AssertConversion(type, "01:23:45.999", 5025999);
+
+ AssertInvalidTimes(type);
+ // Invalid subseconds
+ AssertConversionFails(type, "00:00:00.1234");
+ }
+}
+
+TEST(StringConversion, ToTime64) {
+ {
+ Time64Type type{TimeUnit::MICRO};
+
+ AssertConversion(type, "00:00:00", 0LL);
+ AssertConversion(type, "01:23:45", 5025000000LL);
+ AssertConversion(type, "23:45:43", 85543000000LL);
+ AssertConversion(type, "23:59:59", 86399000000LL);
+
+ AssertConversion(type, "00:00:00.123456", 123456LL);
+ AssertConversion(type, "01:23:45.000000", 5025000000LL);
+ AssertConversion(type, "01:23:45.1", 5025100000LL);
+ AssertConversion(type, "01:23:45.123", 5025123000LL);
+ AssertConversion(type, "01:23:45.999999", 5025999999LL);
+
+ AssertInvalidTimes(type);
+ // Invalid subseconds
+ AssertConversionFails(type, "00:00:00.1234567");
+ }
+ {
+ Time64Type type{TimeUnit::NANO};
+
+ AssertConversion(type, "00:00:00", 0LL);
+ AssertConversion(type, "01:23:45", 5025000000000LL);
+ AssertConversion(type, "23:45:43", 85543000000000LL);
+ AssertConversion(type, "23:59:59", 86399000000000LL);
+
+ AssertConversion(type, "00:00:00.123456789", 123456789LL);
+ AssertConversion(type, "01:23:45.000000000", 5025000000000LL);
+ AssertConversion(type, "01:23:45.1", 5025100000000LL);
+ AssertConversion(type, "01:23:45.1234", 5025123400000LL);
+ AssertConversion(type, "01:23:45.999999999", 5025999999999LL);
+
+ AssertInvalidTimes(type);
+ // Invalid subseconds
+ AssertConversionFails(type, "00:00:00.1234567891");
+ }
+}
+
+TEST(StringConversion, ToTimestampDate_ISO8601) {
+ {
+ TimestampType type{TimeUnit::SECOND};
+
+ AssertConversion(type, "1970-01-01", 0);
+ AssertConversion(type, "1989-07-14", 616377600);
+ AssertConversion(type, "2000-02-29", 951782400);
+ AssertConversion(type, "3989-07-14", 63730281600LL);
+ AssertConversion(type, "1900-02-28", -2203977600LL);
+
+ AssertConversionFails(type, "");
+ AssertConversionFails(type, "1970");
+ AssertConversionFails(type, "19700101");
+ AssertConversionFails(type, "1970/01/01");
+ AssertConversionFails(type, "1970-01-01 ");
+ AssertConversionFails(type, "1970-01-01Z");
+
+ // Invalid dates
+ AssertConversionFails(type, "1970-00-01");
+ AssertConversionFails(type, "1970-13-01");
+ AssertConversionFails(type, "1970-01-32");
+ AssertConversionFails(type, "1970-02-29");
+ AssertConversionFails(type, "2100-02-29");
+ }
+ {
+ TimestampType type{TimeUnit::MILLI};
+
+ AssertConversion(type, "1970-01-01", 0);
+ AssertConversion(type, "1989-07-14", 616377600000LL);
+ AssertConversion(type, "3989-07-14", 63730281600000LL);
+ AssertConversion(type, "1900-02-28", -2203977600000LL);
+ }
+ {
+ TimestampType type{TimeUnit::MICRO};
+
+ AssertConversion(type, "1970-01-01", 0);
+ AssertConversion(type, "1989-07-14", 616377600000000LL);
+ AssertConversion(type, "3989-07-14", 63730281600000000LL);
+ AssertConversion(type, "1900-02-28", -2203977600000000LL);
+ }
+ {
+ TimestampType type{TimeUnit::NANO};
+
+ AssertConversion(type, "1970-01-01", 0);
+ AssertConversion(type, "1989-07-14", 616377600000000000LL);
+ AssertConversion(type, "2018-11-13", 1542067200000000000LL);
+ AssertConversion(type, "1900-02-28", -2203977600000000000LL);
+ }
+}
+
+TEST(StringConversion, ToTimestampDateTime_ISO8601) {
+ {
+ TimestampType type{TimeUnit::SECOND};
+
+ AssertConversion(type, "1970-01-01 00:00:00", 0);
+ AssertConversion(type, "2018-11-13 17", 1542128400);
+ AssertConversion(type, "2018-11-13T17", 1542128400);
+ AssertConversion(type, "2018-11-13 17Z", 1542128400);
+ AssertConversion(type, "2018-11-13T17Z", 1542128400);
+ AssertConversion(type, "2018-11-13 17:11", 1542129060);
+ AssertConversion(type, "2018-11-13T17:11", 1542129060);
+ AssertConversion(type, "2018-11-13 17:11Z", 1542129060);
+ AssertConversion(type, "2018-11-13T17:11Z", 1542129060);
+ AssertConversion(type, "2018-11-13 17:11:10", 1542129070);
+ AssertConversion(type, "2018-11-13T17:11:10", 1542129070);
+ AssertConversion(type, "2018-11-13 17:11:10Z", 1542129070);
+ AssertConversion(type, "2018-11-13T17:11:10Z", 1542129070);
+ AssertConversion(type, "1900-02-28 12:34:56", -2203932304LL);
+
+ // No subseconds allowed
+ AssertConversionFails(type, "1900-02-28 12:34:56.001");
+ // Invalid dates
+ AssertConversionFails(type, "1970-02-29 00:00:00");
+ AssertConversionFails(type, "2100-02-29 00:00:00");
+ // Invalid times
+ AssertConversionFails(type, "1970-01-01 24");
+ AssertConversionFails(type, "1970-01-01 00:60");
+ AssertConversionFails(type, "1970-01-01 00,00");
+ AssertConversionFails(type, "1970-01-01 24:00:00");
+ AssertConversionFails(type, "1970-01-01 00:60:00");
+ AssertConversionFails(type, "1970-01-01 00:00:60");
+ AssertConversionFails(type, "1970-01-01 00:00,00");
+ AssertConversionFails(type, "1970-01-01 00,00:00");
+ }
+ {
+ TimestampType type{TimeUnit::MILLI};
+
+ AssertConversion(type, "2018-11-13 17:11:10", 1542129070000LL);
+ AssertConversion(type, "2018-11-13T17:11:10Z", 1542129070000LL);
+ AssertConversion(type, "3989-07-14T11:22:33Z", 63730322553000LL);
+ AssertConversion(type, "1900-02-28 12:34:56", -2203932304000LL);
+ AssertConversion(type, "2018-11-13T17:11:10.777Z", 1542129070777LL);
+
+ AssertConversion(type, "1900-02-28 12:34:56.1", -2203932304000LL + 100LL);
+ AssertConversion(type, "1900-02-28 12:34:56.12", -2203932304000LL + 120LL);
+ AssertConversion(type, "1900-02-28 12:34:56.123", -2203932304000LL + 123LL);
+
+ // Invalid subseconds
+ AssertConversionFails(type, "1900-02-28 12:34:56.1234");
+ AssertConversionFails(type, "1900-02-28 12:34:56.12345");
+ AssertConversionFails(type, "1900-02-28 12:34:56.123456");
+ AssertConversionFails(type, "1900-02-28 12:34:56.1234567");
+ AssertConversionFails(type, "1900-02-28 12:34:56.12345678");
+ AssertConversionFails(type, "1900-02-28 12:34:56.123456789");
+ }
+ {
+ TimestampType type{TimeUnit::MICRO};
+
+ AssertConversion(type, "2018-11-13 17:11:10", 1542129070000000LL);
+ AssertConversion(type, "2018-11-13T17:11:10Z", 1542129070000000LL);
+ AssertConversion(type, "3989-07-14T11:22:33Z", 63730322553000000LL);
+ AssertConversion(type, "1900-02-28 12:34:56", -2203932304000000LL);
+ AssertConversion(type, "2018-11-13T17:11:10.777000", 1542129070777000LL);
+ AssertConversion(type, "3989-07-14T11:22:33.000777Z", 63730322553000777LL);
+
+ AssertConversion(type, "1900-02-28 12:34:56.1", -2203932304000000LL + 100000LL);
+ AssertConversion(type, "1900-02-28 12:34:56.12", -2203932304000000LL + 120000LL);
+ AssertConversion(type, "1900-02-28 12:34:56.123", -2203932304000000LL + 123000LL);
+ AssertConversion(type, "1900-02-28 12:34:56.1234", -2203932304000000LL + 123400LL);
+ AssertConversion(type, "1900-02-28 12:34:56.12345", -2203932304000000LL + 123450LL);
+ AssertConversion(type, "1900-02-28 12:34:56.123456", -2203932304000000LL + 123456LL);
+
+ // Invalid subseconds
+ AssertConversionFails(type, "1900-02-28 12:34:56.1234567");
+ AssertConversionFails(type, "1900-02-28 12:34:56.12345678");
+ AssertConversionFails(type, "1900-02-28 12:34:56.123456789");
+ }
+ {
+ TimestampType type{TimeUnit::NANO};
+
+ AssertConversion(type, "2018-11-13 17:11:10", 1542129070000000000LL);
+ AssertConversion(type, "2018-11-13T17:11:10Z", 1542129070000000000LL);
+ AssertConversion(type, "1900-02-28 12:34:56", -2203932304000000000LL);
+ AssertConversion(type, "2018-11-13 17:11:10.777000000", 1542129070777000000LL);
+ AssertConversion(type, "2018-11-13T17:11:10.000777000Z", 1542129070000777000LL);
+ AssertConversion(type, "1969-12-31 23:59:59.999999999", -1);
+
+ AssertConversion(type, "1900-02-28 12:34:56.1", -2203932304000000000LL + 100000000LL);
+ AssertConversion(type, "1900-02-28 12:34:56.12",
+ -2203932304000000000LL + 120000000LL);
+ AssertConversion(type, "1900-02-28 12:34:56.123",
+ -2203932304000000000LL + 123000000LL);
+ AssertConversion(type, "1900-02-28 12:34:56.1234",
+ -2203932304000000000LL + 123400000LL);
+ AssertConversion(type, "1900-02-28 12:34:56.12345",
+ -2203932304000000000LL + 123450000LL);
+ AssertConversion(type, "1900-02-28 12:34:56.123456",
+ -2203932304000000000LL + 123456000LL);
+ AssertConversion(type, "1900-02-28 12:34:56.1234567",
+ -2203932304000000000LL + 123456700LL);
+ AssertConversion(type, "1900-02-28 12:34:56.12345678",
+ -2203932304000000000LL + 123456780LL);
+ AssertConversion(type, "1900-02-28 12:34:56.123456789",
+ -2203932304000000000LL + 123456789LL);
+
+ // Invalid subseconds
+ }
+}
+
+TEST(TimestampParser, StrptimeParser) {
+ std::string format = "%m/%d/%Y %H:%M:%S";
+ auto parser = TimestampParser::MakeStrptime(format);
+
+ struct Case {
+ std::string value;
+ std::string iso8601;
+ };
+
+ std::vector<Case> cases = {{"5/31/2000 12:34:56", "2000-05-31 12:34:56"},
+ {"5/31/2000 00:00:00", "2000-05-31 00:00:00"}};
+
+ std::vector<TimeUnit::type> units = {TimeUnit::SECOND, TimeUnit::MILLI, TimeUnit::MICRO,
+ TimeUnit::NANO};
+
+ for (auto unit : units) {
+ for (const auto& case_ : cases) {
+ int64_t converted, expected;
+ ASSERT_TRUE((*parser)(case_.value.c_str(), case_.value.size(), unit, &converted));
+ ASSERT_TRUE(ParseTimestampISO8601(case_.iso8601.c_str(), case_.iso8601.size(), unit,
+ &expected));
+ ASSERT_EQ(expected, converted);
+ }
+ }
+
+ // Unparseable strings
+ std::vector<std::string> unparseables = {"foo", "5/1/2000", "5/1/2000 12:34:56:6"};
+ for (auto& value : unparseables) {
+ int64_t dummy;
+ ASSERT_FALSE((*parser)(value.c_str(), value.size(), TimeUnit::SECOND, &dummy));
+ }
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/variant.h b/src/arrow/cpp/src/arrow/util/variant.h
new file mode 100644
index 000000000..8bbce5251
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/variant.h
@@ -0,0 +1,443 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstddef>
+#include <exception>
+#include <type_traits>
+#include <utility>
+
+#include "arrow/util/macros.h"
+#include "arrow/util/type_traits.h"
+
+namespace arrow {
+namespace util {
+
+/// \brief a std::variant-like discriminated union
+///
+/// Simplifications from std::variant:
+///
+/// - Strictly defaultable. The first type of T... should be nothrow default constructible
+/// and it will be used for default Variants.
+///
+/// - Never valueless_by_exception. std::variant supports a state outside those specified
+/// by T... to which it can return in the event that a constructor throws. If a Variant
+/// would become valueless_by_exception it will instead return to its default state.
+///
+/// - Strictly nothrow move constructible and assignable
+///
+/// - Less sophisticated type deduction. std::variant<bool, std::string>("hello") will
+/// intelligently construct std::string while Variant<bool, std::string>("hello") will
+/// construct bool.
+///
+/// - Either both copy constructible and assignable or neither (std::variant independently
+/// enables copy construction and copy assignment). Variant is copy constructible if
+/// each of T... is copy constructible and assignable.
+///
+/// - Slimmer interface; several members of std::variant are omitted.
+///
+/// - Throws no exceptions; if a bad_variant_access would be thrown Variant will instead
+/// segfault (nullptr dereference).
+///
+/// - Mutable visit takes a pointer instead of mutable reference or rvalue reference,
+/// which is more conformant with our code style.
+template <typename... T>
+class Variant;
+
+namespace detail {
+
+template <typename T, typename = void>
+struct is_equality_comparable : std::false_type {};
+
+template <typename T>
+struct is_equality_comparable<
+ T, typename std::enable_if<std::is_convertible<
+ decltype(std::declval<T>() == std::declval<T>()), bool>::value>::type>
+ : std::true_type {};
+
+template <bool C, typename T, typename E>
+using conditional_t = typename std::conditional<C, T, E>::type;
+
+template <typename T>
+struct type_constant {
+ using type = T;
+};
+
+template <typename...>
+struct first;
+
+template <typename H, typename... T>
+struct first<H, T...> {
+ using type = H;
+};
+
+template <typename T>
+using decay_t = typename std::decay<T>::type;
+
+template <bool...>
+struct all : std::true_type {};
+
+template <bool H, bool... T>
+struct all<H, T...> : conditional_t<H, all<T...>, std::false_type> {};
+
+struct delete_copy_constructor {
+ template <typename>
+ struct type {
+ type() = default;
+ type(const type& other) = delete;
+ type& operator=(const type& other) = delete;
+ };
+};
+
+struct explicit_copy_constructor {
+ template <typename Copyable>
+ struct type {
+ type() = default;
+ type(const type& other) { static_cast<const Copyable&>(other).copy_to(this); }
+ type& operator=(const type& other) {
+ static_cast<Copyable*>(this)->destroy();
+ static_cast<const Copyable&>(other).copy_to(this);
+ return *this;
+ }
+ };
+};
+
+template <typename... T>
+struct VariantStorage {
+ VariantStorage() = default;
+ VariantStorage(const VariantStorage&) {}
+ VariantStorage& operator=(const VariantStorage&) { return *this; }
+ VariantStorage(VariantStorage&&) noexcept {}
+ VariantStorage& operator=(VariantStorage&&) noexcept { return *this; }
+ ~VariantStorage() {
+ static_assert(offsetof(VariantStorage, data_) == 0,
+ "(void*)&VariantStorage::data_ == (void*)this");
+ }
+
+ typename arrow::internal::aligned_union<0, T...>::type data_;
+ uint8_t index_ = 0;
+};
+
+template <typename V, typename...>
+struct VariantImpl;
+
+template <typename... T>
+struct VariantImpl<Variant<T...>> : VariantStorage<T...> {
+ static void index_of() noexcept {}
+ void destroy() noexcept {}
+ void move_to(...) noexcept {}
+ void copy_to(...) const {}
+
+ template <typename R, typename Visitor>
+ [[noreturn]] R visit_const(Visitor&& visitor) const {
+ std::terminate();
+ }
+ template <typename R, typename Visitor>
+ [[noreturn]] R visit_mutable(Visitor&& visitor) {
+ std::terminate();
+ }
+};
+
+template <typename... M, typename H, typename... T>
+struct VariantImpl<Variant<M...>, H, T...> : VariantImpl<Variant<M...>, T...> {
+ using VariantType = Variant<M...>;
+ using Impl = VariantImpl<VariantType, T...>;
+
+ static constexpr uint8_t kIndex = sizeof...(M) - sizeof...(T) - 1;
+
+ VariantImpl() = default;
+
+ using VariantImpl<VariantType, T...>::VariantImpl;
+ using Impl::operator=;
+ using Impl::index_of;
+
+ explicit VariantImpl(H value) {
+ new (this) H(std::move(value));
+ this->index_ = kIndex;
+ }
+
+ VariantImpl& operator=(H value) {
+ static_cast<VariantType*>(this)->destroy();
+ new (this) H(std::move(value));
+ this->index_ = kIndex;
+ return *this;
+ }
+
+ H& cast_this() { return *reinterpret_cast<H*>(this); }
+ const H& cast_this() const { return *reinterpret_cast<const H*>(this); }
+
+ void move_to(VariantType* target) noexcept {
+ if (this->index_ == kIndex) {
+ new (target) H(std::move(cast_this()));
+ target->index_ = kIndex;
+ } else {
+ Impl::move_to(target);
+ }
+ }
+
+ // Templated to avoid instantiation in case H is not copy constructible
+ template <typename Void>
+ void copy_to(Void* generic_target) const {
+ const auto target = static_cast<VariantType*>(generic_target);
+ try {
+ if (this->index_ == kIndex) {
+ new (target) H(cast_this());
+ target->index_ = kIndex;
+ } else {
+ Impl::copy_to(target);
+ }
+ } catch (...) {
+ target->construct_default();
+ throw;
+ }
+ }
+
+ void destroy() noexcept {
+ if (this->index_ == kIndex) {
+ if (!std::is_trivially_destructible<H>::value) {
+ cast_this().~H();
+ }
+ } else {
+ Impl::destroy();
+ }
+ }
+
+ static constexpr std::integral_constant<uint8_t, kIndex> index_of(
+ const type_constant<H>&) {
+ return {};
+ }
+
+ template <typename R, typename Visitor>
+ R visit_const(Visitor&& visitor) const {
+ if (this->index_ == kIndex) {
+ return std::forward<Visitor>(visitor)(cast_this());
+ }
+ return Impl::template visit_const<R>(std::forward<Visitor>(visitor));
+ }
+
+ template <typename R, typename Visitor>
+ R visit_mutable(Visitor&& visitor) {
+ if (this->index_ == kIndex) {
+ return std::forward<Visitor>(visitor)(&cast_this());
+ }
+ return Impl::template visit_mutable<R>(std::forward<Visitor>(visitor));
+ }
+};
+
+} // namespace detail
+
+template <typename... T>
+class Variant : detail::VariantImpl<Variant<T...>, T...>,
+ detail::conditional_t<
+ detail::all<(std::is_copy_constructible<T>::value &&
+ std::is_copy_assignable<T>::value)...>::value,
+ detail::explicit_copy_constructor,
+ detail::delete_copy_constructor>::template type<Variant<T...>> {
+ template <typename U>
+ static constexpr uint8_t index_of() {
+ return Impl::index_of(detail::type_constant<U>{});
+ }
+
+ using Impl = detail::VariantImpl<Variant<T...>, T...>;
+
+ public:
+ using default_type = typename util::detail::first<T...>::type;
+
+ Variant() noexcept { construct_default(); }
+
+ Variant(const Variant& other) = default;
+ Variant& operator=(const Variant& other) = default;
+ Variant& operator=(Variant&& other) noexcept {
+ this->destroy();
+ other.move_to(this);
+ return *this;
+ }
+
+ using Impl::Impl;
+ using Impl::operator=;
+
+ Variant(Variant&& other) noexcept { other.move_to(this); }
+
+ ~Variant() {
+ static_assert(offsetof(Variant, data_) == 0, "(void*)&Variant::data_ == (void*)this");
+ this->destroy();
+ }
+
+ /// \brief Return the zero-based type index of the value held by the variant
+ uint8_t index() const noexcept { return this->index_; }
+
+ /// \brief Get a const pointer to the value held by the variant
+ ///
+ /// If the type given as template argument doesn't match, a null pointer is returned.
+ template <typename U, uint8_t I = index_of<U>()>
+ const U* get() const noexcept {
+ return index() == I ? reinterpret_cast<const U*>(this) : NULLPTR;
+ }
+
+ /// \brief Get a pointer to the value held by the variant
+ ///
+ /// If the type given as template argument doesn't match, a null pointer is returned.
+ template <typename U, uint8_t I = index_of<U>()>
+ U* get() noexcept {
+ return index() == I ? reinterpret_cast<U*>(this) : NULLPTR;
+ }
+
+ /// \brief Replace the value held by the variant
+ ///
+ /// The intended type must be given as a template argument.
+ /// The value is constructed in-place using the given function arguments.
+ template <typename U, typename... A, uint8_t I = index_of<U>()>
+ void emplace(A&&... args) {
+ try {
+ this->destroy();
+ new (this) U(std::forward<A>(args)...);
+ this->index_ = I;
+ } catch (...) {
+ construct_default();
+ throw;
+ }
+ }
+
+ template <typename U, typename E, typename... A, uint8_t I = index_of<U>()>
+ void emplace(std::initializer_list<E> il, A&&... args) {
+ try {
+ this->destroy();
+ new (this) U(il, std::forward<A>(args)...);
+ this->index_ = I;
+ } catch (...) {
+ construct_default();
+ throw;
+ }
+ }
+
+ /// \brief Swap with another variant's contents
+ void swap(Variant& other) noexcept { // NOLINT google-runtime-references
+ Variant tmp = std::move(other);
+ other = std::move(*this);
+ *this = std::move(tmp);
+ }
+
+ using Impl::visit_const;
+ using Impl::visit_mutable;
+
+ private:
+ void construct_default() noexcept {
+ new (this) default_type();
+ this->index_ = 0;
+ }
+
+ template <typename V>
+ friend struct detail::explicit_copy_constructor::type;
+
+ template <typename V, typename...>
+ friend struct detail::VariantImpl;
+};
+
+/// \brief Call polymorphic visitor on a const variant's value
+///
+/// The visitor will receive a const reference to the value held by the variant.
+/// It must define overloads for each possible variant type.
+/// The overloads should all return the same type (no attempt
+/// is made to find a generalized return type).
+template <typename Visitor, typename... T,
+ typename R = decltype(std::declval<Visitor&&>()(
+ std::declval<const typename Variant<T...>::default_type&>()))>
+R visit(Visitor&& visitor, const util::Variant<T...>& v) {
+ return v.template visit_const<R>(std::forward<Visitor>(visitor));
+}
+
+/// \brief Call polymorphic visitor on a non-const variant's value
+///
+/// The visitor will receive a pointer to the value held by the variant.
+/// It must define overloads for each possible variant type.
+/// The overloads should all return the same type (no attempt
+/// is made to find a generalized return type).
+template <typename Visitor, typename... T,
+ typename R = decltype(std::declval<Visitor&&>()(
+ std::declval<typename Variant<T...>::default_type*>()))>
+R visit(Visitor&& visitor, util::Variant<T...>* v) {
+ return v->template visit_mutable<R>(std::forward<Visitor>(visitor));
+}
+
+/// \brief Get a const reference to the value held by the variant
+///
+/// If the type given as template argument doesn't match, behavior is undefined
+/// (a null pointer will be dereferenced).
+template <typename U, typename... T>
+const U& get(const Variant<T...>& v) {
+ return *v.template get<U>();
+}
+
+/// \brief Get a reference to the value held by the variant
+///
+/// If the type given as template argument doesn't match, behavior is undefined
+/// (a null pointer will be dereferenced).
+template <typename U, typename... T>
+U& get(Variant<T...>& v) {
+ return *v.template get<U>();
+}
+
+/// \brief Get a const pointer to the value held by the variant
+///
+/// If the type given as template argument doesn't match, a nullptr is returned.
+template <typename U, typename... T>
+const U* get_if(const Variant<T...>* v) {
+ return v->template get<U>();
+}
+
+/// \brief Get a pointer to the value held by the variant
+///
+/// If the type given as template argument doesn't match, a nullptr is returned.
+template <typename U, typename... T>
+U* get_if(Variant<T...>* v) {
+ return v->template get<U>();
+}
+
+namespace detail {
+
+template <typename... T>
+struct VariantsEqual {
+ template <typename U>
+ bool operator()(const U& r) const {
+ return get<U>(l_) == r;
+ }
+ const Variant<T...>& l_;
+};
+
+} // namespace detail
+
+template <typename... T, typename = typename std::enable_if<detail::all<
+ detail::is_equality_comparable<T>::value...>::value>>
+bool operator==(const Variant<T...>& l, const Variant<T...>& r) {
+ if (l.index() != r.index()) return false;
+ return visit(detail::VariantsEqual<T...>{l}, r);
+}
+
+template <typename... T>
+auto operator!=(const Variant<T...>& l, const Variant<T...>& r) -> decltype(l == r) {
+ return !(l == r);
+}
+
+/// \brief Return whether the variant holds a value of the given type
+template <typename U, typename... T>
+bool holds_alternative(const Variant<T...>& v) {
+ return v.template get<U>();
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/variant_benchmark.cc b/src/arrow/cpp/src/arrow/util/variant_benchmark.cc
new file mode 100644
index 000000000..af3fafb8b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/variant_benchmark.cc
@@ -0,0 +1,248 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/chunked_array.h"
+#include "arrow/datum.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/variant.h"
+
+namespace arrow {
+
+using internal::checked_pointer_cast;
+
+namespace util {
+
+using TrivialVariant = arrow::util::Variant<int32_t, float>;
+
+using NonTrivialVariant = arrow::util::Variant<int32_t, std::string>;
+
+std::vector<int32_t> MakeInts(int64_t nitems) {
+ auto rng = arrow::random::RandomArrayGenerator(42);
+ auto array = checked_pointer_cast<Int32Array>(rng.Int32(nitems, 0, 1 << 30));
+ std::vector<int32_t> items(nitems);
+ for (int64_t i = 0; i < nitems; ++i) {
+ items[i] = array->Value(i);
+ }
+ return items;
+}
+
+std::vector<float> MakeFloats(int64_t nitems) {
+ auto rng = arrow::random::RandomArrayGenerator(42);
+ auto array = checked_pointer_cast<FloatArray>(rng.Float32(nitems, 0.0, 1.0));
+ std::vector<float> items(nitems);
+ for (int64_t i = 0; i < nitems; ++i) {
+ items[i] = array->Value(i);
+ }
+ return items;
+}
+
+std::vector<std::string> MakeStrings(int64_t nitems) {
+ auto rng = arrow::random::RandomArrayGenerator(42);
+ // Some std::string's will use short string optimization, but not all...
+ auto array = checked_pointer_cast<StringArray>(rng.String(nitems, 5, 40));
+ std::vector<std::string> items(nitems);
+ for (int64_t i = 0; i < nitems; ++i) {
+ items[i] = array->GetString(i);
+ }
+ return items;
+}
+
+static void ConstructTrivialVariant(benchmark::State& state) {
+ const int64_t N = 10000;
+ const auto ints = MakeInts(N);
+ const auto floats = MakeFloats(N);
+
+ for (auto _ : state) {
+ for (int64_t i = 0; i < N; ++i) {
+ // About type selection: we ensure 50% of each type, but try to avoid
+ // branch mispredictions by creating runs of the same type.
+ if (i & 0x10) {
+ TrivialVariant v{ints[i]};
+ const int32_t* val = &arrow::util::get<int32_t>(v);
+ benchmark::DoNotOptimize(val);
+ } else {
+ TrivialVariant v{floats[i]};
+ const float* val = &arrow::util::get<float>(v);
+ benchmark::DoNotOptimize(val);
+ }
+ }
+ }
+
+ state.SetItemsProcessed(state.iterations() * N);
+}
+
+static void ConstructNonTrivialVariant(benchmark::State& state) {
+ const int64_t N = 10000;
+ const auto ints = MakeInts(N);
+ const auto strings = MakeStrings(N);
+
+ for (auto _ : state) {
+ for (int64_t i = 0; i < N; ++i) {
+ if (i & 0x10) {
+ NonTrivialVariant v{ints[i]};
+ const int32_t* val = &arrow::util::get<int32_t>(v);
+ benchmark::DoNotOptimize(val);
+ } else {
+ NonTrivialVariant v{strings[i]};
+ const std::string* val = &arrow::util::get<std::string>(v);
+ benchmark::DoNotOptimize(val);
+ }
+ }
+ }
+
+ state.SetItemsProcessed(state.iterations() * N);
+}
+
+struct VariantVisitor {
+ int64_t total = 0;
+
+ void operator()(const int32_t& v) { total += v; }
+ void operator()(const float& v) {
+ // Avoid potentially costly float-to-int conversion
+ int32_t x;
+ memcpy(&x, &v, 4);
+ total += x;
+ }
+ void operator()(const std::string& v) { total += static_cast<int64_t>(v.length()); }
+};
+
+template <typename VariantType>
+static void VisitVariant(benchmark::State& state,
+ const std::vector<VariantType>& variants) {
+ for (auto _ : state) {
+ VariantVisitor visitor;
+ for (const auto& v : variants) {
+ visit(visitor, v);
+ }
+ benchmark::DoNotOptimize(visitor.total);
+ }
+
+ state.SetItemsProcessed(state.iterations() * variants.size());
+}
+
+static void VisitTrivialVariant(benchmark::State& state) {
+ const int64_t N = 10000;
+ const auto ints = MakeInts(N);
+ const auto floats = MakeFloats(N);
+
+ std::vector<TrivialVariant> variants;
+ variants.reserve(N);
+ for (int64_t i = 0; i < N; ++i) {
+ if (i & 0x10) {
+ variants.emplace_back(ints[i]);
+ } else {
+ variants.emplace_back(floats[i]);
+ }
+ }
+
+ VisitVariant(state, variants);
+}
+
+static void VisitNonTrivialVariant(benchmark::State& state) {
+ const int64_t N = 10000;
+ const auto ints = MakeInts(N);
+ const auto strings = MakeStrings(N);
+
+ std::vector<NonTrivialVariant> variants;
+ variants.reserve(N);
+ for (int64_t i = 0; i < N; ++i) {
+ if (i & 0x10) {
+ variants.emplace_back(ints[i]);
+ } else {
+ variants.emplace_back(strings[i]);
+ }
+ }
+
+ VisitVariant(state, variants);
+}
+
+static void ConstructDatum(benchmark::State& state) {
+ const int64_t N = 10000;
+ auto array = *MakeArrayOfNull(int8(), 100);
+ auto chunked_array = std::make_shared<ChunkedArray>(ArrayVector{array, array});
+
+ for (auto _ : state) {
+ for (int64_t i = 0; i < N; ++i) {
+ if (i & 0x10) {
+ Datum datum{array};
+ const ArrayData* val = datum.array().get();
+ benchmark::DoNotOptimize(val);
+ } else {
+ Datum datum{chunked_array};
+ const ChunkedArray* val = datum.chunked_array().get();
+ benchmark::DoNotOptimize(val);
+ }
+ }
+ }
+
+ state.SetItemsProcessed(state.iterations() * N);
+}
+
+static void VisitDatum(benchmark::State& state) {
+ const int64_t N = 10000;
+ auto array = *MakeArrayOfNull(int8(), 100);
+ auto chunked_array = std::make_shared<ChunkedArray>(ArrayVector{array, array});
+
+ std::vector<Datum> datums;
+ datums.reserve(N);
+ for (int64_t i = 0; i < N; ++i) {
+ if (i & 0x10) {
+ datums.emplace_back(array);
+ } else {
+ datums.emplace_back(chunked_array);
+ }
+ }
+
+ for (auto _ : state) {
+ int64_t total = 0;
+ for (const auto& datum : datums) {
+ // The .is_XXX() methods are the usual idiom when visiting a Datum,
+ // rather than the visit() function.
+ if (datum.is_array()) {
+ total += datum.array()->length;
+ } else {
+ total += datum.chunked_array()->length();
+ }
+ }
+ benchmark::DoNotOptimize(total);
+ }
+
+ state.SetItemsProcessed(state.iterations() * datums.size());
+}
+
+BENCHMARK(ConstructTrivialVariant);
+BENCHMARK(ConstructNonTrivialVariant);
+BENCHMARK(VisitTrivialVariant);
+BENCHMARK(VisitNonTrivialVariant);
+BENCHMARK(ConstructDatum);
+BENCHMARK(VisitDatum);
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/variant_test.cc b/src/arrow/cpp/src/arrow/util/variant_test.cc
new file mode 100644
index 000000000..f94d1b6cc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/variant_test.cc
@@ -0,0 +1,345 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/variant.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_compat.h"
+
+namespace arrow {
+
+namespace util {
+namespace {
+
+using ::testing::Eq;
+
+template <typename H, typename... T>
+void AssertDefaultConstruction() {
+ using variant_type = Variant<H, T...>;
+
+ static_assert(std::is_nothrow_default_constructible<variant_type>::value, "");
+
+ variant_type v;
+ EXPECT_EQ(v.index(), 0);
+ EXPECT_EQ(get<H>(v), H{});
+}
+
+TEST(Variant, DefaultConstruction) {
+ AssertDefaultConstruction<int>();
+ AssertDefaultConstruction<int, std::string>();
+ AssertDefaultConstruction<std::string, int>();
+ AssertDefaultConstruction<std::unique_ptr<int>>();
+ AssertDefaultConstruction<std::vector<int>, int>();
+ AssertDefaultConstruction<bool, std::string, std::unique_ptr<int>, void*,
+ std::true_type>();
+ AssertDefaultConstruction<std::nullptr_t, std::unique_ptr<int>, void*, bool,
+ std::string, std::true_type>();
+}
+
+template <typename V, typename T>
+struct AssertCopyConstructionOne {
+ void operator()(uint8_t index) {
+ V v{member_};
+ EXPECT_EQ(v.index(), index);
+ EXPECT_EQ(get<T>(v), member_);
+
+ V copy{v};
+ EXPECT_EQ(copy.index(), v.index());
+ EXPECT_EQ(get<T>(copy), get<T>(v));
+ EXPECT_EQ(copy, v);
+
+ V assigned;
+ assigned = member_;
+ EXPECT_EQ(assigned.index(), index);
+ EXPECT_EQ(get<T>(assigned), member_);
+
+ assigned = v;
+ EXPECT_EQ(assigned.index(), v.index());
+ EXPECT_EQ(get<T>(assigned), get<T>(v));
+ EXPECT_EQ(assigned, v);
+ }
+
+ const T& member_;
+};
+
+template <typename... T>
+void AssertCopyConstruction(T... member) {
+ uint8_t index = 0;
+ for (auto Assert : {std::function<void(uint8_t)>(
+ AssertCopyConstructionOne<Variant<T...>, T>{member})...}) {
+ Assert(index++);
+ }
+}
+
+template <typename... T>
+void AssertCopyConstructionDisabled() {
+ static_assert(!std::is_copy_constructible<Variant<T...>>::value,
+ "copy construction was not disabled");
+}
+
+TEST(Variant, CopyConstruction) {
+ // if any member is not copy constructible then Variant is not copy constructible
+ AssertCopyConstructionDisabled<std::unique_ptr<int>>();
+ AssertCopyConstructionDisabled<std::unique_ptr<int>, std::string>();
+ AssertCopyConstructionDisabled<std::string, int, bool, std::unique_ptr<int>>();
+
+ AssertCopyConstruction(32, std::string("hello"), true);
+ AssertCopyConstruction(std::string("world"), false, 53);
+ AssertCopyConstruction(nullptr, std::true_type{}, std::string("!"));
+ AssertCopyConstruction(std::vector<int>{1, 3, 3, 7}, "C string");
+
+ // copy assignment operator is not used
+ struct CopyAssignThrows {
+ CopyAssignThrows() = default;
+ CopyAssignThrows(const CopyAssignThrows&) = default;
+
+ CopyAssignThrows& operator=(const CopyAssignThrows&) { throw 42; }
+
+ CopyAssignThrows(CopyAssignThrows&&) = default;
+ CopyAssignThrows& operator=(CopyAssignThrows&&) = default;
+
+ bool operator==(const CopyAssignThrows&) const { return true; }
+ };
+ EXPECT_NO_THROW(AssertCopyConstruction(CopyAssignThrows{}));
+}
+
+TEST(Variant, Emplace) {
+ using variant_type = Variant<std::string, std::vector<int>, int>;
+ variant_type v;
+
+ v.emplace<int>();
+ EXPECT_EQ(v, variant_type{int{}});
+
+ v.emplace<std::string>("hello");
+ EXPECT_EQ(v, variant_type{std::string("hello")});
+
+ v.emplace<std::vector<int>>({1, 3, 3, 7});
+ EXPECT_EQ(v, variant_type{std::vector<int>({1, 3, 3, 7})});
+}
+
+TEST(Variant, MoveConstruction) {
+ struct noop_delete {
+ void operator()(...) const {}
+ };
+ using ptr = std::unique_ptr<int, noop_delete>;
+ static_assert(!std::is_copy_constructible<ptr>::value, "");
+
+ using variant_type = Variant<int, ptr>;
+
+ int tag = 42;
+ auto ExpectIsTag = [&](const variant_type& v) {
+ EXPECT_EQ(v.index(), 1);
+ EXPECT_EQ(get<ptr>(v).get(), &tag);
+ };
+
+ ptr p;
+
+ // move construction from member
+ p.reset(&tag);
+ variant_type v0{std::move(p)};
+ ExpectIsTag(v0);
+
+ // move assignment from member
+ p.reset(&tag);
+ v0 = std::move(p);
+ ExpectIsTag(v0);
+
+ // move construction from other variant
+ variant_type v1{std::move(v0)};
+ ExpectIsTag(v1);
+
+ // move assignment from other variant
+ p.reset(&tag);
+ variant_type v2{std::move(p)};
+ v1 = std::move(v2);
+ ExpectIsTag(v1);
+
+ // type changing move assignment from member
+ variant_type v3;
+ EXPECT_NE(v3.index(), 1);
+ p.reset(&tag);
+ v3 = std::move(p);
+ ExpectIsTag(v3);
+
+ // type changing move assignment from other variant
+ variant_type v4;
+ EXPECT_NE(v4.index(), 1);
+ v4 = std::move(v3);
+ ExpectIsTag(v4);
+}
+
+TEST(Variant, ExceptionSafety) {
+ struct {
+ } actually_throw;
+
+ struct {
+ } dont_throw;
+
+ struct ConstructorThrows {
+ explicit ConstructorThrows(decltype(actually_throw)) { throw 42; }
+ explicit ConstructorThrows(decltype(dont_throw)) {}
+
+ ConstructorThrows(const ConstructorThrows&) { throw 42; }
+
+ ConstructorThrows& operator=(const ConstructorThrows&) = default;
+ ConstructorThrows(ConstructorThrows&&) = default;
+ ConstructorThrows& operator=(ConstructorThrows&&) = default;
+ };
+
+ Variant<int, ConstructorThrows> v;
+
+ // constructor throws during emplacement
+ EXPECT_THROW(v.emplace<ConstructorThrows>(actually_throw), int);
+ // safely returned to the default state
+ EXPECT_EQ(v.index(), 0);
+
+ // constructor throws during copy assignment from member
+ EXPECT_THROW(
+ {
+ const ConstructorThrows throws(dont_throw);
+ v = throws;
+ },
+ int);
+ // safely returned to the default state
+ EXPECT_EQ(v.index(), 0);
+}
+
+// XXX GTest 1.11 exposes a `using std::visit` in its headers which
+// somehow gets preferred to `arrow::util::visit`, even if there is
+// a using clause (perhaps because of macros such as EXPECT_EQ).
+template <typename... Args>
+void DoVisit(Args&&... args) {
+ return ::arrow::util::visit(std::forward<Args>(args)...);
+}
+
+template <typename T, typename... Args>
+void AssertVisitedEquals(const T& expected, Args&&... args) {
+ const auto actual = ::arrow::util::visit(std::forward<Args>(args)...);
+ EXPECT_EQ(expected, actual);
+}
+
+template <typename V, typename T>
+struct AssertVisitOne {
+ void operator()(const T& actual) { EXPECT_EQ(&actual, expected_); }
+
+ void operator()(T* actual) { EXPECT_EQ(actual, expected_); }
+
+ template <typename U>
+ void operator()(const U&) {
+ FAIL() << "the expected type was not visited.";
+ }
+
+ template <typename U>
+ void operator()(U*) {
+ FAIL() << "the expected type was not visited.";
+ }
+
+ explicit AssertVisitOne(T member) : member_(std::move(member)) {}
+
+ void operator()() {
+ V v{member_};
+ expected_ = &get<T>(v);
+ DoVisit(*this, v);
+ DoVisit(*this, &v);
+ }
+
+ T member_;
+ const T* expected_;
+};
+
+// Try visiting all alternatives on a Variant<T...>
+template <typename... T>
+void AssertVisitAll(T... member) {
+ for (auto Assert :
+ {std::function<void()>(AssertVisitOne<Variant<T...>, T>{member})...}) {
+ Assert();
+ }
+}
+
+TEST(VariantTest, Visit) {
+ AssertVisitAll(32, std::string("hello"), true);
+ AssertVisitAll(std::string("world"), false, 53);
+ AssertVisitAll(nullptr, std::true_type{}, std::string("!"));
+ AssertVisitAll(std::vector<int>{1, 3, 3, 7}, "C string");
+
+ using int_or_string = Variant<int, std::string>;
+ int_or_string v;
+
+ // value returning visit:
+ struct {
+ int_or_string operator()(int i) { return int_or_string{i * 2}; }
+ int_or_string operator()(const std::string& s) { return int_or_string{s + s}; }
+ } Double;
+
+ v = 7;
+ AssertVisitedEquals(int_or_string{14}, Double, v);
+
+ v = "lolol";
+ AssertVisitedEquals(int_or_string{"lolollolol"}, Double, v);
+
+ // mutating visit:
+ struct {
+ void operator()(int* i) { *i *= 2; }
+ void operator()(std::string* s) { *s += *s; }
+ } DoubleInplace;
+
+ v = 7;
+ DoVisit(DoubleInplace, &v);
+ EXPECT_EQ(v, int_or_string{14});
+
+ v = "lolol";
+ DoVisit(DoubleInplace, &v);
+ EXPECT_EQ(v, int_or_string{"lolollolol"});
+}
+
+TEST(VariantTest, Equality) {
+ using int_or_double = Variant<int, double>;
+
+ auto eq = [](const int_or_double& a, const int_or_double& b) {
+ EXPECT_TRUE(a == b);
+ EXPECT_FALSE(a != b);
+ };
+ auto ne = [](const int_or_double& a, const int_or_double& b) {
+ EXPECT_TRUE(a != b);
+ EXPECT_FALSE(a == b);
+ };
+
+ int_or_double u, v;
+ u.emplace<int>(1);
+ v.emplace<int>(1);
+ eq(u, v);
+ v.emplace<int>(2);
+ ne(u, v);
+ v.emplace<double>(1.0);
+ ne(u, v);
+ u.emplace<double>(1.0);
+ eq(u, v);
+ u.emplace<double>(2.0);
+ ne(u, v);
+}
+
+} // namespace
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/vector.h b/src/arrow/cpp/src/arrow/util/vector.h
new file mode 100644
index 000000000..041bdb424
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/vector.h
@@ -0,0 +1,172 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "arrow/result.h"
+#include "arrow/util/algorithm.h"
+#include "arrow/util/functional.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace internal {
+
+template <typename T>
+std::vector<T> DeleteVectorElement(const std::vector<T>& values, size_t index) {
+ DCHECK(!values.empty());
+ DCHECK_LT(index, values.size());
+ std::vector<T> out;
+ out.reserve(values.size() - 1);
+ for (size_t i = 0; i < index; ++i) {
+ out.push_back(values[i]);
+ }
+ for (size_t i = index + 1; i < values.size(); ++i) {
+ out.push_back(values[i]);
+ }
+ return out;
+}
+
+template <typename T>
+std::vector<T> AddVectorElement(const std::vector<T>& values, size_t index,
+ T new_element) {
+ DCHECK_LE(index, values.size());
+ std::vector<T> out;
+ out.reserve(values.size() + 1);
+ for (size_t i = 0; i < index; ++i) {
+ out.push_back(values[i]);
+ }
+ out.emplace_back(std::move(new_element));
+ for (size_t i = index; i < values.size(); ++i) {
+ out.push_back(values[i]);
+ }
+ return out;
+}
+
+template <typename T>
+std::vector<T> ReplaceVectorElement(const std::vector<T>& values, size_t index,
+ T new_element) {
+ DCHECK_LE(index, values.size());
+ std::vector<T> out;
+ out.reserve(values.size());
+ for (size_t i = 0; i < index; ++i) {
+ out.push_back(values[i]);
+ }
+ out.emplace_back(std::move(new_element));
+ for (size_t i = index + 1; i < values.size(); ++i) {
+ out.push_back(values[i]);
+ }
+ return out;
+}
+
+template <typename T, typename Predicate>
+std::vector<T> FilterVector(std::vector<T> values, Predicate&& predicate) {
+ auto new_end =
+ std::remove_if(values.begin(), values.end(), std::forward<Predicate>(predicate));
+ values.erase(new_end, values.end());
+ return values;
+}
+
+template <typename Fn, typename From,
+ typename To = decltype(std::declval<Fn>()(std::declval<From>()))>
+std::vector<To> MapVector(Fn&& map, const std::vector<From>& source) {
+ std::vector<To> out;
+ out.reserve(source.size());
+ std::transform(source.begin(), source.end(), std::back_inserter(out),
+ std::forward<Fn>(map));
+ return out;
+}
+
+template <typename Fn, typename From,
+ typename To = decltype(std::declval<Fn>()(std::declval<From>()))>
+std::vector<To> MapVector(Fn&& map, std::vector<From>&& source) {
+ std::vector<To> out;
+ out.reserve(source.size());
+ std::transform(std::make_move_iterator(source.begin()),
+ std::make_move_iterator(source.end()), std::back_inserter(out),
+ std::forward<Fn>(map));
+ return out;
+}
+
+/// \brief Like MapVector, but where the function can fail.
+template <typename Fn, typename From = internal::call_traits::argument_type<0, Fn>,
+ typename To = typename internal::call_traits::return_type<Fn>::ValueType>
+Result<std::vector<To>> MaybeMapVector(Fn&& map, const std::vector<From>& source) {
+ std::vector<To> out;
+ out.reserve(source.size());
+ ARROW_RETURN_NOT_OK(MaybeTransform(source.begin(), source.end(),
+ std::back_inserter(out), std::forward<Fn>(map)));
+ return std::move(out);
+}
+
+template <typename Fn, typename From = internal::call_traits::argument_type<0, Fn>,
+ typename To = typename internal::call_traits::return_type<Fn>::ValueType>
+Result<std::vector<To>> MaybeMapVector(Fn&& map, std::vector<From>&& source) {
+ std::vector<To> out;
+ out.reserve(source.size());
+ ARROW_RETURN_NOT_OK(MaybeTransform(std::make_move_iterator(source.begin()),
+ std::make_move_iterator(source.end()),
+ std::back_inserter(out), std::forward<Fn>(map)));
+ return std::move(out);
+}
+
+template <typename T>
+std::vector<T> FlattenVectors(const std::vector<std::vector<T>>& vecs) {
+ std::size_t sum = 0;
+ for (const auto& vec : vecs) {
+ sum += vec.size();
+ }
+ std::vector<T> out;
+ out.reserve(sum);
+ for (const auto& vec : vecs) {
+ out.insert(out.end(), vec.begin(), vec.end());
+ }
+ return out;
+}
+
+template <typename T>
+Result<std::vector<T>> UnwrapOrRaise(std::vector<Result<T>>&& results) {
+ std::vector<T> out;
+ out.reserve(results.size());
+ auto end = std::make_move_iterator(results.end());
+ for (auto it = std::make_move_iterator(results.begin()); it != end; it++) {
+ if (!it->ok()) {
+ return it->status();
+ }
+ out.push_back(it->MoveValueUnsafe());
+ }
+ return std::move(out);
+}
+
+template <typename T>
+Result<std::vector<T>> UnwrapOrRaise(const std::vector<Result<T>>& results) {
+ std::vector<T> out;
+ out.reserve(results.size());
+ for (const auto& result : results) {
+ if (!result.ok()) {
+ return result.status();
+ }
+ out.push_back(result.ValueUnsafe());
+ }
+ return std::move(out);
+}
+
+} // namespace internal
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/util/visibility.h b/src/arrow/cpp/src/arrow/util/visibility.h
new file mode 100644
index 000000000..dd9ac45e9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/visibility.h
@@ -0,0 +1,45 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#if defined(_WIN32) || defined(__CYGWIN__)
+#if defined(_MSC_VER)
+#pragma warning(disable : 4251)
+#else
+#pragma GCC diagnostic ignored "-Wattributes"
+#endif
+
+#ifdef ARROW_STATIC
+#define ARROW_EXPORT
+#elif defined(ARROW_EXPORTING)
+#define ARROW_EXPORT __declspec(dllexport)
+#else
+#define ARROW_EXPORT __declspec(dllimport)
+#endif
+
+#define ARROW_NO_EXPORT
+#define ARROW_FORCE_INLINE __forceinline
+#else // Not Windows
+#ifndef ARROW_EXPORT
+#define ARROW_EXPORT __attribute__((visibility("default")))
+#endif
+#ifndef ARROW_NO_EXPORT
+#define ARROW_NO_EXPORT __attribute__((visibility("hidden")))
+#define ARROW_FORCE_INLINE
+#endif
+#endif // Non-Windows
diff --git a/src/arrow/cpp/src/arrow/util/windows_compatibility.h b/src/arrow/cpp/src/arrow/util/windows_compatibility.h
new file mode 100644
index 000000000..64a2772c4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/windows_compatibility.h
@@ -0,0 +1,42 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#ifdef _WIN32
+
+// Windows defines min and max macros that mess up std::min/max
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+
+#define WIN32_LEAN_AND_MEAN
+
+// Set Windows 7 as a conservative minimum for Apache Arrow
+#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x601
+#undef _WIN32_WINNT
+#endif
+#ifndef _WIN32_WINNT
+#define _WIN32_WINNT 0x601
+#endif
+
+#include <winsock2.h>
+#include <windows.h>
+
+#include "arrow/util/windows_fixup.h"
+
+#endif // _WIN32
diff --git a/src/arrow/cpp/src/arrow/util/windows_fixup.h b/src/arrow/cpp/src/arrow/util/windows_fixup.h
new file mode 100644
index 000000000..2949ac4ab
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/util/windows_fixup.h
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This header needs to be included multiple times.
+
+#ifdef _WIN32
+
+#ifdef max
+#undef max
+#endif
+#ifdef min
+#undef min
+#endif
+
+// The Windows API defines macros from *File resolving to either
+// *FileA or *FileW. Need to undo them.
+#ifdef CopyFile
+#undef CopyFile
+#endif
+#ifdef CreateFile
+#undef CreateFile
+#endif
+#ifdef DeleteFile
+#undef DeleteFile
+#endif
+
+// Other annoying Windows macro definitions...
+#ifdef IN
+#undef IN
+#endif
+#ifdef OUT
+#undef OUT
+#endif
+
+// Note that we can't undefine OPTIONAL, because it can be used in other
+// Windows headers...
+
+#endif // _WIN32
diff --git a/src/arrow/cpp/src/arrow/vendored/CMakeLists.txt b/src/arrow/cpp/src/arrow/vendored/CMakeLists.txt
new file mode 100644
index 000000000..8d4c323d2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/CMakeLists.txt
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+arrow_install_all_headers("arrow/vendored")
+
+add_subdirectory(datetime)
+add_subdirectory(double-conversion)
diff --git a/src/arrow/cpp/src/arrow/vendored/ProducerConsumerQueue.h b/src/arrow/cpp/src/arrow/vendored/ProducerConsumerQueue.h
new file mode 100644
index 000000000..0b7cfa1cb
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/ProducerConsumerQueue.h
@@ -0,0 +1,217 @@
+// Vendored from git tag v2021.02.15.00
+
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// @author Bo Hu (bhu@fb.com)
+// @author Jordan DeLong (delong.j@fb.com)
+
+// This file has been modified as part of Apache Arrow to conform to
+// Apache Arrow's coding conventions
+
+#pragma once
+
+#include <atomic>
+#include <cassert>
+#include <cstdlib>
+#include <memory>
+#include <stdexcept>
+#include <type_traits>
+#include <utility>
+
+namespace arrow_vendored {
+namespace folly {
+
+// Vendored from folly/Portability.h
+namespace {
+#if defined(__arm__)
+#define FOLLY_ARM 1
+#else
+#define FOLLY_ARM 0
+#endif
+
+#if defined(__s390x__)
+#define FOLLY_S390X 1
+#else
+#define FOLLY_S390X 0
+#endif
+
+constexpr bool kIsArchArm = FOLLY_ARM == 1;
+constexpr bool kIsArchS390X = FOLLY_S390X == 1;
+} // namespace
+
+// Vendored from folly/lang/Align.h
+namespace {
+
+constexpr std::size_t hardware_destructive_interference_size =
+ (kIsArchArm || kIsArchS390X) ? 64 : 128;
+
+} // namespace
+
+/*
+ * ProducerConsumerQueue is a one producer and one consumer queue
+ * without locks.
+ */
+template <class T>
+struct ProducerConsumerQueue {
+ typedef T value_type;
+
+ ProducerConsumerQueue(const ProducerConsumerQueue&) = delete;
+ ProducerConsumerQueue& operator=(const ProducerConsumerQueue&) = delete;
+
+ // size must be >= 2.
+ //
+ // Also, note that the number of usable slots in the queue at any
+ // given time is actually (size-1), so if you start with an empty queue,
+ // IsFull() will return true after size-1 insertions.
+ explicit ProducerConsumerQueue(uint32_t size)
+ : size_(size),
+ records_(static_cast<T*>(std::malloc(sizeof(T) * size))),
+ readIndex_(0),
+ writeIndex_(0) {
+ assert(size >= 2);
+ if (!records_) {
+ throw std::bad_alloc();
+ }
+ }
+
+ ~ProducerConsumerQueue() {
+ // We need to destruct anything that may still exist in our queue.
+ // (No real synchronization needed at destructor time: only one
+ // thread can be doing this.)
+ if (!std::is_trivially_destructible<T>::value) {
+ size_t readIndex = readIndex_;
+ size_t endIndex = writeIndex_;
+ while (readIndex != endIndex) {
+ records_[readIndex].~T();
+ if (++readIndex == size_) {
+ readIndex = 0;
+ }
+ }
+ }
+
+ std::free(records_);
+ }
+
+ template <class... Args>
+ bool Write(Args&&... recordArgs) {
+ auto const currentWrite = writeIndex_.load(std::memory_order_relaxed);
+ auto nextRecord = currentWrite + 1;
+ if (nextRecord == size_) {
+ nextRecord = 0;
+ }
+ if (nextRecord != readIndex_.load(std::memory_order_acquire)) {
+ new (&records_[currentWrite]) T(std::forward<Args>(recordArgs)...);
+ writeIndex_.store(nextRecord, std::memory_order_release);
+ return true;
+ }
+
+ // queue is full
+ return false;
+ }
+
+ // move the value at the front of the queue to given variable
+ bool Read(T& record) {
+ auto const currentRead = readIndex_.load(std::memory_order_relaxed);
+ if (currentRead == writeIndex_.load(std::memory_order_acquire)) {
+ // queue is empty
+ return false;
+ }
+
+ auto nextRecord = currentRead + 1;
+ if (nextRecord == size_) {
+ nextRecord = 0;
+ }
+ record = std::move(records_[currentRead]);
+ records_[currentRead].~T();
+ readIndex_.store(nextRecord, std::memory_order_release);
+ return true;
+ }
+
+ // pointer to the value at the front of the queue (for use in-place) or
+ // nullptr if empty.
+ T* FrontPtr() {
+ auto const currentRead = readIndex_.load(std::memory_order_relaxed);
+ if (currentRead == writeIndex_.load(std::memory_order_acquire)) {
+ // queue is empty
+ return nullptr;
+ }
+ return &records_[currentRead];
+ }
+
+ // queue must not be empty
+ void PopFront() {
+ auto const currentRead = readIndex_.load(std::memory_order_relaxed);
+ assert(currentRead != writeIndex_.load(std::memory_order_acquire));
+
+ auto nextRecord = currentRead + 1;
+ if (nextRecord == size_) {
+ nextRecord = 0;
+ }
+ records_[currentRead].~T();
+ readIndex_.store(nextRecord, std::memory_order_release);
+ }
+
+ bool IsEmpty() const {
+ return readIndex_.load(std::memory_order_acquire) ==
+ writeIndex_.load(std::memory_order_acquire);
+ }
+
+ bool IsFull() const {
+ auto nextRecord = writeIndex_.load(std::memory_order_acquire) + 1;
+ if (nextRecord == size_) {
+ nextRecord = 0;
+ }
+ if (nextRecord != readIndex_.load(std::memory_order_acquire)) {
+ return false;
+ }
+ // queue is full
+ return true;
+ }
+
+ // * If called by consumer, then true size may be more (because producer may
+ // be adding items concurrently).
+ // * If called by producer, then true size may be less (because consumer may
+ // be removing items concurrently).
+ // * It is undefined to call this from any other thread.
+ size_t SizeGuess() const {
+ int ret = writeIndex_.load(std::memory_order_acquire) -
+ readIndex_.load(std::memory_order_acquire);
+ if (ret < 0) {
+ ret += size_;
+ }
+ return ret;
+ }
+
+ // maximum number of items in the queue.
+ size_t capacity() const { return size_ - 1; }
+
+ private:
+ using AtomicIndex = std::atomic<unsigned int>;
+
+ char pad0_[hardware_destructive_interference_size];
+ const uint32_t size_;
+ T* const records_;
+
+ AtomicIndex readIndex_;
+ char pad1_[hardware_destructive_interference_size - sizeof(AtomicIndex)];
+ AtomicIndex writeIndex_;
+
+ char pad2_[hardware_destructive_interference_size - sizeof(AtomicIndex)];
+};
+
+} // namespace folly
+} // namespace arrow_vendored
diff --git a/src/arrow/cpp/src/arrow/vendored/base64.cpp b/src/arrow/cpp/src/arrow/vendored/base64.cpp
new file mode 100644
index 000000000..0de11955b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/base64.cpp
@@ -0,0 +1,134 @@
+/*
+ base64.cpp and base64.h
+
+ base64 encoding and decoding with C++.
+
+ Version: 1.01.00
+
+ Copyright (C) 2004-2017 René Nyffenegger
+
+ This source code is provided 'as-is', without any express or implied
+ warranty. In no event will the author be held liable for any damages
+ arising from the use of this software.
+
+ Permission is granted to anyone to use this software for any purpose,
+ including commercial applications, and to alter it and redistribute it
+ freely, subject to the following restrictions:
+
+ 1. The origin of this source code must not be misrepresented; you must not
+ claim that you wrote the original source code. If you use this source code
+ in a product, an acknowledgment in the product documentation would be
+ appreciated but is not required.
+
+ 2. Altered source versions must be plainly marked as such, and must not be
+ misrepresented as being the original source code.
+
+ 3. This notice may not be removed or altered from any source distribution.
+
+ René Nyffenegger rene.nyffenegger@adp-gmbh.ch
+
+*/
+
+#include "arrow/util/base64.h"
+#include <iostream>
+
+namespace arrow {
+namespace util {
+
+static const std::string base64_chars =
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "abcdefghijklmnopqrstuvwxyz"
+ "0123456789+/";
+
+
+static inline bool is_base64(unsigned char c) {
+ return (isalnum(c) || (c == '+') || (c == '/'));
+}
+
+static std::string base64_encode(unsigned char const* bytes_to_encode, unsigned int in_len) {
+ std::string ret;
+ int i = 0;
+ int j = 0;
+ unsigned char char_array_3[3];
+ unsigned char char_array_4[4];
+
+ while (in_len--) {
+ char_array_3[i++] = *(bytes_to_encode++);
+ if (i == 3) {
+ char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;
+ char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
+ char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
+ char_array_4[3] = char_array_3[2] & 0x3f;
+
+ for(i = 0; (i <4) ; i++)
+ ret += base64_chars[char_array_4[i]];
+ i = 0;
+ }
+ }
+
+ if (i)
+ {
+ for(j = i; j < 3; j++)
+ char_array_3[j] = '\0';
+
+ char_array_4[0] = ( char_array_3[0] & 0xfc) >> 2;
+ char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
+ char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
+
+ for (j = 0; (j < i + 1); j++)
+ ret += base64_chars[char_array_4[j]];
+
+ while((i++ < 3))
+ ret += '=';
+
+ }
+
+ return ret;
+
+}
+
+std::string base64_encode(string_view string_to_encode) {
+ auto bytes_to_encode = reinterpret_cast<const unsigned char*>(string_to_encode.data());
+ auto in_len = static_cast<unsigned int>(string_to_encode.size());
+ return base64_encode(bytes_to_encode, in_len);
+}
+
+std::string base64_decode(string_view encoded_string) {
+ size_t in_len = encoded_string.size();
+ int i = 0;
+ int j = 0;
+ int in_ = 0;
+ unsigned char char_array_4[4], char_array_3[3];
+ std::string ret;
+
+ while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
+ char_array_4[i++] = encoded_string[in_]; in_++;
+ if (i ==4) {
+ for (i = 0; i <4; i++)
+ char_array_4[i] = base64_chars.find(char_array_4[i]) & 0xff;
+
+ char_array_3[0] = ( char_array_4[0] << 2 ) + ((char_array_4[1] & 0x30) >> 4);
+ char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+ char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+
+ for (i = 0; (i < 3); i++)
+ ret += char_array_3[i];
+ i = 0;
+ }
+ }
+
+ if (i) {
+ for (j = 0; j < i; j++)
+ char_array_4[j] = base64_chars.find(char_array_4[j]) & 0xff;
+
+ char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
+ char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+
+ for (j = 0; (j < i - 1); j++) ret += char_array_3[j];
+ }
+
+ return ret;
+}
+
+} // namespace util
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/vendored/datetime.h b/src/arrow/cpp/src/arrow/vendored/datetime.h
new file mode 100644
index 000000000..e437cdcbc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/datetime.h
@@ -0,0 +1,26 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/vendored/datetime/date.h" // IWYU pragma: export
+#include "arrow/vendored/datetime/tz.h" // IWYU pragma: export
+
+// Can be defined by date.h.
+#ifdef NOEXCEPT
+#undef NOEXCEPT
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/datetime/CMakeLists.txt b/src/arrow/cpp/src/arrow/vendored/datetime/CMakeLists.txt
new file mode 100644
index 000000000..00366d0ee
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/datetime/CMakeLists.txt
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+arrow_install_all_headers("arrow/vendored/datetime")
diff --git a/src/arrow/cpp/src/arrow/vendored/datetime/README.md b/src/arrow/cpp/src/arrow/vendored/datetime/README.md
new file mode 100644
index 000000000..cff53e7e3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/datetime/README.md
@@ -0,0 +1,28 @@
+<!--
+The MIT License (MIT)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+-->
+
+# Utilities for supporting date time functions
+
+Sources for datetime are adapted from Howard Hinnant's date library
+(https://github.com/HowardHinnant/date).
+
+Sources are taken from changeset 2e19c006e2218447ee31f864191859517603f59f
+of the above project.
+
+The following changes are made:
+- fix internal inclusion paths (from "date/xxx.h" to simply "xxx.h")
+- enclose the `date` namespace inside the `arrow_vendored` namespace
+- include a custom "visibility.h" header from "tz.cpp" for proper DLL
+ exports on Windows
+- disable curl-based database downloading in "tz.h"
diff --git a/src/arrow/cpp/src/arrow/vendored/datetime/date.h b/src/arrow/cpp/src/arrow/vendored/datetime/date.h
new file mode 100644
index 000000000..3b38b263a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/datetime/date.h
@@ -0,0 +1,8237 @@
+#ifndef DATE_H
+#define DATE_H
+
+// The MIT License (MIT)
+//
+// Copyright (c) 2015, 2016, 2017 Howard Hinnant
+// Copyright (c) 2016 Adrian Colomitchi
+// Copyright (c) 2017 Florian Dang
+// Copyright (c) 2017 Paul Thompson
+// Copyright (c) 2018, 2019 Tomasz Kamiński
+// Copyright (c) 2019 Jiangang Zhuang
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+//
+// Our apologies. When the previous paragraph was written, lowercase had not yet
+// been invented (that would involve another several millennia of evolution).
+// We did not mean to shout.
+
+#ifndef HAS_STRING_VIEW
+# if __cplusplus >= 201703 || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
+# define HAS_STRING_VIEW 1
+# else
+# define HAS_STRING_VIEW 0
+# endif
+#endif // HAS_STRING_VIEW
+
+#include <cassert>
+#include <algorithm>
+#include <cctype>
+#include <chrono>
+#include <climits>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <ctime>
+#include <ios>
+#include <istream>
+#include <iterator>
+#include <limits>
+#include <locale>
+#include <memory>
+#include <ostream>
+#include <ratio>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#if HAS_STRING_VIEW
+# include <string_view>
+#endif
+#include <utility>
+#include <type_traits>
+
+#ifdef __GNUC__
+# pragma GCC diagnostic push
+# if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 7)
+# pragma GCC diagnostic ignored "-Wpedantic"
+# endif
+# if __GNUC__ < 5
+ // GCC 4.9 Bug 61489 Wrong warning with -Wmissing-field-initializers
+# pragma GCC diagnostic ignored "-Wmissing-field-initializers"
+# endif
+#endif
+
+#ifdef _MSC_VER
+# pragma warning(push)
+// warning C4127: conditional expression is constant
+# pragma warning(disable : 4127)
+#endif
+
+namespace arrow_vendored
+{
+namespace date
+{
+
+//---------------+
+// Configuration |
+//---------------+
+
+#ifndef ONLY_C_LOCALE
+# define ONLY_C_LOCALE 0
+#endif
+
+#if defined(_MSC_VER) && (!defined(__clang__) || (_MSC_VER < 1910))
+// MSVC
+# ifndef _SILENCE_CXX17_UNCAUGHT_EXCEPTION_DEPRECATION_WARNING
+# define _SILENCE_CXX17_UNCAUGHT_EXCEPTION_DEPRECATION_WARNING
+# endif
+# if _MSC_VER < 1910
+// before VS2017
+# define CONSTDATA const
+# define CONSTCD11
+# define CONSTCD14
+# define NOEXCEPT _NOEXCEPT
+# else
+// VS2017 and later
+# define CONSTDATA constexpr const
+# define CONSTCD11 constexpr
+# define CONSTCD14 constexpr
+# define NOEXCEPT noexcept
+# endif
+
+#elif defined(__SUNPRO_CC) && __SUNPRO_CC <= 0x5150
+// Oracle Developer Studio 12.6 and earlier
+# define CONSTDATA constexpr const
+# define CONSTCD11 constexpr
+# define CONSTCD14
+# define NOEXCEPT noexcept
+
+#elif __cplusplus >= 201402
+// C++14
+# define CONSTDATA constexpr const
+# define CONSTCD11 constexpr
+# define CONSTCD14 constexpr
+# define NOEXCEPT noexcept
+#else
+// C++11
+# define CONSTDATA constexpr const
+# define CONSTCD11 constexpr
+# define CONSTCD14
+# define NOEXCEPT noexcept
+#endif
+
+#ifndef HAS_UNCAUGHT_EXCEPTIONS
+# if __cplusplus >= 201703 || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
+# define HAS_UNCAUGHT_EXCEPTIONS 1
+# else
+# define HAS_UNCAUGHT_EXCEPTIONS 0
+# endif
+#endif // HAS_UNCAUGHT_EXCEPTIONS
+
+#ifndef HAS_VOID_T
+# if __cplusplus >= 201703 || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
+# define HAS_VOID_T 1
+# else
+# define HAS_VOID_T 0
+# endif
+#endif // HAS_VOID_T
+
+// Protect from Oracle sun macro
+#ifdef sun
+# undef sun
+#endif
+
+// Work around for a NVCC compiler bug which causes it to fail
+// to compile std::ratio_{multiply,divide} when used directly
+// in the std::chrono::duration template instantiations below
+namespace detail {
+template <typename R1, typename R2>
+using ratio_multiply = decltype(std::ratio_multiply<R1, R2>{});
+
+template <typename R1, typename R2>
+using ratio_divide = decltype(std::ratio_divide<R1, R2>{});
+} // namespace detail
+
+//-----------+
+// Interface |
+//-----------+
+
+// durations
+
+using days = std::chrono::duration
+ <int, detail::ratio_multiply<std::ratio<24>, std::chrono::hours::period>>;
+
+using weeks = std::chrono::duration
+ <int, detail::ratio_multiply<std::ratio<7>, days::period>>;
+
+using years = std::chrono::duration
+ <int, detail::ratio_multiply<std::ratio<146097, 400>, days::period>>;
+
+using months = std::chrono::duration
+ <int, detail::ratio_divide<years::period, std::ratio<12>>>;
+
+// time_point
+
+template <class Duration>
+ using sys_time = std::chrono::time_point<std::chrono::system_clock, Duration>;
+
+using sys_days = sys_time<days>;
+using sys_seconds = sys_time<std::chrono::seconds>;
+
+struct local_t {};
+
+template <class Duration>
+ using local_time = std::chrono::time_point<local_t, Duration>;
+
+using local_seconds = local_time<std::chrono::seconds>;
+using local_days = local_time<days>;
+
+// types
+
+struct last_spec
+{
+ explicit last_spec() = default;
+};
+
+class day;
+class month;
+class year;
+
+class weekday;
+class weekday_indexed;
+class weekday_last;
+
+class month_day;
+class month_day_last;
+class month_weekday;
+class month_weekday_last;
+
+class year_month;
+
+class year_month_day;
+class year_month_day_last;
+class year_month_weekday;
+class year_month_weekday_last;
+
+// date composition operators
+
+CONSTCD11 year_month operator/(const year& y, const month& m) NOEXCEPT;
+CONSTCD11 year_month operator/(const year& y, int m) NOEXCEPT;
+
+CONSTCD11 month_day operator/(const day& d, const month& m) NOEXCEPT;
+CONSTCD11 month_day operator/(const day& d, int m) NOEXCEPT;
+CONSTCD11 month_day operator/(const month& m, const day& d) NOEXCEPT;
+CONSTCD11 month_day operator/(const month& m, int d) NOEXCEPT;
+CONSTCD11 month_day operator/(int m, const day& d) NOEXCEPT;
+
+CONSTCD11 month_day_last operator/(const month& m, last_spec) NOEXCEPT;
+CONSTCD11 month_day_last operator/(int m, last_spec) NOEXCEPT;
+CONSTCD11 month_day_last operator/(last_spec, const month& m) NOEXCEPT;
+CONSTCD11 month_day_last operator/(last_spec, int m) NOEXCEPT;
+
+CONSTCD11 month_weekday operator/(const month& m, const weekday_indexed& wdi) NOEXCEPT;
+CONSTCD11 month_weekday operator/(int m, const weekday_indexed& wdi) NOEXCEPT;
+CONSTCD11 month_weekday operator/(const weekday_indexed& wdi, const month& m) NOEXCEPT;
+CONSTCD11 month_weekday operator/(const weekday_indexed& wdi, int m) NOEXCEPT;
+
+CONSTCD11 month_weekday_last operator/(const month& m, const weekday_last& wdl) NOEXCEPT;
+CONSTCD11 month_weekday_last operator/(int m, const weekday_last& wdl) NOEXCEPT;
+CONSTCD11 month_weekday_last operator/(const weekday_last& wdl, const month& m) NOEXCEPT;
+CONSTCD11 month_weekday_last operator/(const weekday_last& wdl, int m) NOEXCEPT;
+
+CONSTCD11 year_month_day operator/(const year_month& ym, const day& d) NOEXCEPT;
+CONSTCD11 year_month_day operator/(const year_month& ym, int d) NOEXCEPT;
+CONSTCD11 year_month_day operator/(const year& y, const month_day& md) NOEXCEPT;
+CONSTCD11 year_month_day operator/(int y, const month_day& md) NOEXCEPT;
+CONSTCD11 year_month_day operator/(const month_day& md, const year& y) NOEXCEPT;
+CONSTCD11 year_month_day operator/(const month_day& md, int y) NOEXCEPT;
+
+CONSTCD11
+ year_month_day_last operator/(const year_month& ym, last_spec) NOEXCEPT;
+CONSTCD11
+ year_month_day_last operator/(const year& y, const month_day_last& mdl) NOEXCEPT;
+CONSTCD11
+ year_month_day_last operator/(int y, const month_day_last& mdl) NOEXCEPT;
+CONSTCD11
+ year_month_day_last operator/(const month_day_last& mdl, const year& y) NOEXCEPT;
+CONSTCD11
+ year_month_day_last operator/(const month_day_last& mdl, int y) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday
+operator/(const year_month& ym, const weekday_indexed& wdi) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday
+operator/(const year& y, const month_weekday& mwd) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday
+operator/(int y, const month_weekday& mwd) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday
+operator/(const month_weekday& mwd, const year& y) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday
+operator/(const month_weekday& mwd, int y) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday_last
+operator/(const year_month& ym, const weekday_last& wdl) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday_last
+operator/(const year& y, const month_weekday_last& mwdl) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday_last
+operator/(int y, const month_weekday_last& mwdl) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday_last
+operator/(const month_weekday_last& mwdl, const year& y) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday_last
+operator/(const month_weekday_last& mwdl, int y) NOEXCEPT;
+
+// Detailed interface
+
+// day
+
+class day
+{
+ unsigned char d_;
+
+public:
+ day() = default;
+ explicit CONSTCD11 day(unsigned d) NOEXCEPT;
+
+ CONSTCD14 day& operator++() NOEXCEPT;
+ CONSTCD14 day operator++(int) NOEXCEPT;
+ CONSTCD14 day& operator--() NOEXCEPT;
+ CONSTCD14 day operator--(int) NOEXCEPT;
+
+ CONSTCD14 day& operator+=(const days& d) NOEXCEPT;
+ CONSTCD14 day& operator-=(const days& d) NOEXCEPT;
+
+ CONSTCD11 explicit operator unsigned() const NOEXCEPT;
+ CONSTCD11 bool ok() const NOEXCEPT;
+};
+
+CONSTCD11 bool operator==(const day& x, const day& y) NOEXCEPT;
+CONSTCD11 bool operator!=(const day& x, const day& y) NOEXCEPT;
+CONSTCD11 bool operator< (const day& x, const day& y) NOEXCEPT;
+CONSTCD11 bool operator> (const day& x, const day& y) NOEXCEPT;
+CONSTCD11 bool operator<=(const day& x, const day& y) NOEXCEPT;
+CONSTCD11 bool operator>=(const day& x, const day& y) NOEXCEPT;
+
+CONSTCD11 day operator+(const day& x, const days& y) NOEXCEPT;
+CONSTCD11 day operator+(const days& x, const day& y) NOEXCEPT;
+CONSTCD11 day operator-(const day& x, const days& y) NOEXCEPT;
+CONSTCD11 days operator-(const day& x, const day& y) NOEXCEPT;
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const day& d);
+
+// month
+
+class month
+{
+ unsigned char m_;
+
+public:
+ month() = default;
+ explicit CONSTCD11 month(unsigned m) NOEXCEPT;
+
+ CONSTCD14 month& operator++() NOEXCEPT;
+ CONSTCD14 month operator++(int) NOEXCEPT;
+ CONSTCD14 month& operator--() NOEXCEPT;
+ CONSTCD14 month operator--(int) NOEXCEPT;
+
+ CONSTCD14 month& operator+=(const months& m) NOEXCEPT;
+ CONSTCD14 month& operator-=(const months& m) NOEXCEPT;
+
+ CONSTCD11 explicit operator unsigned() const NOEXCEPT;
+ CONSTCD11 bool ok() const NOEXCEPT;
+};
+
+CONSTCD11 bool operator==(const month& x, const month& y) NOEXCEPT;
+CONSTCD11 bool operator!=(const month& x, const month& y) NOEXCEPT;
+CONSTCD11 bool operator< (const month& x, const month& y) NOEXCEPT;
+CONSTCD11 bool operator> (const month& x, const month& y) NOEXCEPT;
+CONSTCD11 bool operator<=(const month& x, const month& y) NOEXCEPT;
+CONSTCD11 bool operator>=(const month& x, const month& y) NOEXCEPT;
+
+CONSTCD14 month operator+(const month& x, const months& y) NOEXCEPT;
+CONSTCD14 month operator+(const months& x, const month& y) NOEXCEPT;
+CONSTCD14 month operator-(const month& x, const months& y) NOEXCEPT;
+CONSTCD14 months operator-(const month& x, const month& y) NOEXCEPT;
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const month& m);
+
+// year
+
+class year
+{
+ short y_;
+
+public:
+ year() = default;
+ explicit CONSTCD11 year(int y) NOEXCEPT;
+
+ CONSTCD14 year& operator++() NOEXCEPT;
+ CONSTCD14 year operator++(int) NOEXCEPT;
+ CONSTCD14 year& operator--() NOEXCEPT;
+ CONSTCD14 year operator--(int) NOEXCEPT;
+
+ CONSTCD14 year& operator+=(const years& y) NOEXCEPT;
+ CONSTCD14 year& operator-=(const years& y) NOEXCEPT;
+
+ CONSTCD11 year operator-() const NOEXCEPT;
+ CONSTCD11 year operator+() const NOEXCEPT;
+
+ CONSTCD11 bool is_leap() const NOEXCEPT;
+
+ CONSTCD11 explicit operator int() const NOEXCEPT;
+ CONSTCD11 bool ok() const NOEXCEPT;
+
+ static CONSTCD11 year min() NOEXCEPT { return year{-32767}; }
+ static CONSTCD11 year max() NOEXCEPT { return year{32767}; }
+};
+
+CONSTCD11 bool operator==(const year& x, const year& y) NOEXCEPT;
+CONSTCD11 bool operator!=(const year& x, const year& y) NOEXCEPT;
+CONSTCD11 bool operator< (const year& x, const year& y) NOEXCEPT;
+CONSTCD11 bool operator> (const year& x, const year& y) NOEXCEPT;
+CONSTCD11 bool operator<=(const year& x, const year& y) NOEXCEPT;
+CONSTCD11 bool operator>=(const year& x, const year& y) NOEXCEPT;
+
+CONSTCD11 year operator+(const year& x, const years& y) NOEXCEPT;
+CONSTCD11 year operator+(const years& x, const year& y) NOEXCEPT;
+CONSTCD11 year operator-(const year& x, const years& y) NOEXCEPT;
+CONSTCD11 years operator-(const year& x, const year& y) NOEXCEPT;
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const year& y);
+
+// weekday
+
+class weekday
+{
+ unsigned char wd_;
+public:
+ weekday() = default;
+ explicit CONSTCD11 weekday(unsigned wd) NOEXCEPT;
+ CONSTCD14 weekday(const sys_days& dp) NOEXCEPT;
+ CONSTCD14 explicit weekday(const local_days& dp) NOEXCEPT;
+
+ CONSTCD14 weekday& operator++() NOEXCEPT;
+ CONSTCD14 weekday operator++(int) NOEXCEPT;
+ CONSTCD14 weekday& operator--() NOEXCEPT;
+ CONSTCD14 weekday operator--(int) NOEXCEPT;
+
+ CONSTCD14 weekday& operator+=(const days& d) NOEXCEPT;
+ CONSTCD14 weekday& operator-=(const days& d) NOEXCEPT;
+
+ CONSTCD11 bool ok() const NOEXCEPT;
+
+ CONSTCD11 unsigned c_encoding() const NOEXCEPT;
+ CONSTCD11 unsigned iso_encoding() const NOEXCEPT;
+
+ CONSTCD11 weekday_indexed operator[](unsigned index) const NOEXCEPT;
+ CONSTCD11 weekday_last operator[](last_spec) const NOEXCEPT;
+
+private:
+ static CONSTCD14 unsigned char weekday_from_days(int z) NOEXCEPT;
+
+ friend CONSTCD11 bool operator==(const weekday& x, const weekday& y) NOEXCEPT;
+ friend CONSTCD14 days operator-(const weekday& x, const weekday& y) NOEXCEPT;
+ friend CONSTCD14 weekday operator+(const weekday& x, const days& y) NOEXCEPT;
+ template<class CharT, class Traits>
+ friend std::basic_ostream<CharT, Traits>&
+ operator<<(std::basic_ostream<CharT, Traits>& os, const weekday& wd);
+ friend class weekday_indexed;
+};
+
+CONSTCD11 bool operator==(const weekday& x, const weekday& y) NOEXCEPT;
+CONSTCD11 bool operator!=(const weekday& x, const weekday& y) NOEXCEPT;
+
+CONSTCD14 weekday operator+(const weekday& x, const days& y) NOEXCEPT;
+CONSTCD14 weekday operator+(const days& x, const weekday& y) NOEXCEPT;
+CONSTCD14 weekday operator-(const weekday& x, const days& y) NOEXCEPT;
+CONSTCD14 days operator-(const weekday& x, const weekday& y) NOEXCEPT;
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const weekday& wd);
+
+// weekday_indexed
+
+class weekday_indexed
+{
+ unsigned char wd_ : 4;
+ unsigned char index_ : 4;
+
+public:
+ weekday_indexed() = default;
+ CONSTCD11 weekday_indexed(const date::weekday& wd, unsigned index) NOEXCEPT;
+
+ CONSTCD11 date::weekday weekday() const NOEXCEPT;
+ CONSTCD11 unsigned index() const NOEXCEPT;
+ CONSTCD11 bool ok() const NOEXCEPT;
+};
+
+CONSTCD11 bool operator==(const weekday_indexed& x, const weekday_indexed& y) NOEXCEPT;
+CONSTCD11 bool operator!=(const weekday_indexed& x, const weekday_indexed& y) NOEXCEPT;
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const weekday_indexed& wdi);
+
+// weekday_last
+
+class weekday_last
+{
+ date::weekday wd_;
+
+public:
+ explicit CONSTCD11 weekday_last(const date::weekday& wd) NOEXCEPT;
+
+ CONSTCD11 date::weekday weekday() const NOEXCEPT;
+ CONSTCD11 bool ok() const NOEXCEPT;
+};
+
+CONSTCD11 bool operator==(const weekday_last& x, const weekday_last& y) NOEXCEPT;
+CONSTCD11 bool operator!=(const weekday_last& x, const weekday_last& y) NOEXCEPT;
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const weekday_last& wdl);
+
+namespace detail
+{
+
+struct unspecified_month_disambiguator {};
+
+} // namespace detail
+
+// year_month
+
+class year_month
+{
+ date::year y_;
+ date::month m_;
+
+public:
+ year_month() = default;
+ CONSTCD11 year_month(const date::year& y, const date::month& m) NOEXCEPT;
+
+ CONSTCD11 date::year year() const NOEXCEPT;
+ CONSTCD11 date::month month() const NOEXCEPT;
+
+ template<class = detail::unspecified_month_disambiguator>
+ CONSTCD14 year_month& operator+=(const months& dm) NOEXCEPT;
+ template<class = detail::unspecified_month_disambiguator>
+ CONSTCD14 year_month& operator-=(const months& dm) NOEXCEPT;
+ CONSTCD14 year_month& operator+=(const years& dy) NOEXCEPT;
+ CONSTCD14 year_month& operator-=(const years& dy) NOEXCEPT;
+
+ CONSTCD11 bool ok() const NOEXCEPT;
+};
+
+CONSTCD11 bool operator==(const year_month& x, const year_month& y) NOEXCEPT;
+CONSTCD11 bool operator!=(const year_month& x, const year_month& y) NOEXCEPT;
+CONSTCD11 bool operator< (const year_month& x, const year_month& y) NOEXCEPT;
+CONSTCD11 bool operator> (const year_month& x, const year_month& y) NOEXCEPT;
+CONSTCD11 bool operator<=(const year_month& x, const year_month& y) NOEXCEPT;
+CONSTCD11 bool operator>=(const year_month& x, const year_month& y) NOEXCEPT;
+
+template<class = detail::unspecified_month_disambiguator>
+CONSTCD14 year_month operator+(const year_month& ym, const months& dm) NOEXCEPT;
+template<class = detail::unspecified_month_disambiguator>
+CONSTCD14 year_month operator+(const months& dm, const year_month& ym) NOEXCEPT;
+template<class = detail::unspecified_month_disambiguator>
+CONSTCD14 year_month operator-(const year_month& ym, const months& dm) NOEXCEPT;
+
+CONSTCD11 months operator-(const year_month& x, const year_month& y) NOEXCEPT;
+CONSTCD11 year_month operator+(const year_month& ym, const years& dy) NOEXCEPT;
+CONSTCD11 year_month operator+(const years& dy, const year_month& ym) NOEXCEPT;
+CONSTCD11 year_month operator-(const year_month& ym, const years& dy) NOEXCEPT;
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const year_month& ym);
+
+// month_day
+
+class month_day
+{
+ date::month m_;
+ date::day d_;
+
+public:
+ month_day() = default;
+ CONSTCD11 month_day(const date::month& m, const date::day& d) NOEXCEPT;
+
+ CONSTCD11 date::month month() const NOEXCEPT;
+ CONSTCD11 date::day day() const NOEXCEPT;
+
+ CONSTCD14 bool ok() const NOEXCEPT;
+};
+
+CONSTCD11 bool operator==(const month_day& x, const month_day& y) NOEXCEPT;
+CONSTCD11 bool operator!=(const month_day& x, const month_day& y) NOEXCEPT;
+CONSTCD11 bool operator< (const month_day& x, const month_day& y) NOEXCEPT;
+CONSTCD11 bool operator> (const month_day& x, const month_day& y) NOEXCEPT;
+CONSTCD11 bool operator<=(const month_day& x, const month_day& y) NOEXCEPT;
+CONSTCD11 bool operator>=(const month_day& x, const month_day& y) NOEXCEPT;
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const month_day& md);
+
+// month_day_last
+
+class month_day_last
+{
+ date::month m_;
+
+public:
+ CONSTCD11 explicit month_day_last(const date::month& m) NOEXCEPT;
+
+ CONSTCD11 date::month month() const NOEXCEPT;
+ CONSTCD11 bool ok() const NOEXCEPT;
+};
+
+CONSTCD11 bool operator==(const month_day_last& x, const month_day_last& y) NOEXCEPT;
+CONSTCD11 bool operator!=(const month_day_last& x, const month_day_last& y) NOEXCEPT;
+CONSTCD11 bool operator< (const month_day_last& x, const month_day_last& y) NOEXCEPT;
+CONSTCD11 bool operator> (const month_day_last& x, const month_day_last& y) NOEXCEPT;
+CONSTCD11 bool operator<=(const month_day_last& x, const month_day_last& y) NOEXCEPT;
+CONSTCD11 bool operator>=(const month_day_last& x, const month_day_last& y) NOEXCEPT;
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const month_day_last& mdl);
+
+// month_weekday
+
+class month_weekday
+{
+ date::month m_;
+ date::weekday_indexed wdi_;
+public:
+ CONSTCD11 month_weekday(const date::month& m,
+ const date::weekday_indexed& wdi) NOEXCEPT;
+
+ CONSTCD11 date::month month() const NOEXCEPT;
+ CONSTCD11 date::weekday_indexed weekday_indexed() const NOEXCEPT;
+
+ CONSTCD11 bool ok() const NOEXCEPT;
+};
+
+CONSTCD11 bool operator==(const month_weekday& x, const month_weekday& y) NOEXCEPT;
+CONSTCD11 bool operator!=(const month_weekday& x, const month_weekday& y) NOEXCEPT;
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const month_weekday& mwd);
+
+// month_weekday_last
+
+class month_weekday_last
+{
+ date::month m_;
+ date::weekday_last wdl_;
+
+public:
+ CONSTCD11 month_weekday_last(const date::month& m,
+ const date::weekday_last& wd) NOEXCEPT;
+
+ CONSTCD11 date::month month() const NOEXCEPT;
+ CONSTCD11 date::weekday_last weekday_last() const NOEXCEPT;
+
+ CONSTCD11 bool ok() const NOEXCEPT;
+};
+
+CONSTCD11
+ bool operator==(const month_weekday_last& x, const month_weekday_last& y) NOEXCEPT;
+CONSTCD11
+ bool operator!=(const month_weekday_last& x, const month_weekday_last& y) NOEXCEPT;
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const month_weekday_last& mwdl);
+
+// class year_month_day
+
+class year_month_day
+{
+ date::year y_;
+ date::month m_;
+ date::day d_;
+
+public:
+ year_month_day() = default;
+ CONSTCD11 year_month_day(const date::year& y, const date::month& m,
+ const date::day& d) NOEXCEPT;
+ CONSTCD14 year_month_day(const year_month_day_last& ymdl) NOEXCEPT;
+
+ CONSTCD14 year_month_day(sys_days dp) NOEXCEPT;
+ CONSTCD14 explicit year_month_day(local_days dp) NOEXCEPT;
+
+ template<class = detail::unspecified_month_disambiguator>
+ CONSTCD14 year_month_day& operator+=(const months& m) NOEXCEPT;
+ template<class = detail::unspecified_month_disambiguator>
+ CONSTCD14 year_month_day& operator-=(const months& m) NOEXCEPT;
+ CONSTCD14 year_month_day& operator+=(const years& y) NOEXCEPT;
+ CONSTCD14 year_month_day& operator-=(const years& y) NOEXCEPT;
+
+ CONSTCD11 date::year year() const NOEXCEPT;
+ CONSTCD11 date::month month() const NOEXCEPT;
+ CONSTCD11 date::day day() const NOEXCEPT;
+
+ CONSTCD14 operator sys_days() const NOEXCEPT;
+ CONSTCD14 explicit operator local_days() const NOEXCEPT;
+ CONSTCD14 bool ok() const NOEXCEPT;
+
+private:
+ static CONSTCD14 year_month_day from_days(days dp) NOEXCEPT;
+ CONSTCD14 days to_days() const NOEXCEPT;
+};
+
+CONSTCD11 bool operator==(const year_month_day& x, const year_month_day& y) NOEXCEPT;
+CONSTCD11 bool operator!=(const year_month_day& x, const year_month_day& y) NOEXCEPT;
+CONSTCD11 bool operator< (const year_month_day& x, const year_month_day& y) NOEXCEPT;
+CONSTCD11 bool operator> (const year_month_day& x, const year_month_day& y) NOEXCEPT;
+CONSTCD11 bool operator<=(const year_month_day& x, const year_month_day& y) NOEXCEPT;
+CONSTCD11 bool operator>=(const year_month_day& x, const year_month_day& y) NOEXCEPT;
+
+template<class = detail::unspecified_month_disambiguator>
+CONSTCD14 year_month_day operator+(const year_month_day& ymd, const months& dm) NOEXCEPT;
+template<class = detail::unspecified_month_disambiguator>
+CONSTCD14 year_month_day operator+(const months& dm, const year_month_day& ymd) NOEXCEPT;
+template<class = detail::unspecified_month_disambiguator>
+CONSTCD14 year_month_day operator-(const year_month_day& ymd, const months& dm) NOEXCEPT;
+CONSTCD11 year_month_day operator+(const year_month_day& ymd, const years& dy) NOEXCEPT;
+CONSTCD11 year_month_day operator+(const years& dy, const year_month_day& ymd) NOEXCEPT;
+CONSTCD11 year_month_day operator-(const year_month_day& ymd, const years& dy) NOEXCEPT;
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const year_month_day& ymd);
+
+// year_month_day_last
+
+class year_month_day_last
+{
+ date::year y_;
+ date::month_day_last mdl_;
+
+public:
+ CONSTCD11 year_month_day_last(const date::year& y,
+ const date::month_day_last& mdl) NOEXCEPT;
+
+ template<class = detail::unspecified_month_disambiguator>
+ CONSTCD14 year_month_day_last& operator+=(const months& m) NOEXCEPT;
+ template<class = detail::unspecified_month_disambiguator>
+ CONSTCD14 year_month_day_last& operator-=(const months& m) NOEXCEPT;
+ CONSTCD14 year_month_day_last& operator+=(const years& y) NOEXCEPT;
+ CONSTCD14 year_month_day_last& operator-=(const years& y) NOEXCEPT;
+
+ CONSTCD11 date::year year() const NOEXCEPT;
+ CONSTCD11 date::month month() const NOEXCEPT;
+ CONSTCD11 date::month_day_last month_day_last() const NOEXCEPT;
+ CONSTCD14 date::day day() const NOEXCEPT;
+
+ CONSTCD14 operator sys_days() const NOEXCEPT;
+ CONSTCD14 explicit operator local_days() const NOEXCEPT;
+ CONSTCD11 bool ok() const NOEXCEPT;
+};
+
+CONSTCD11
+ bool operator==(const year_month_day_last& x, const year_month_day_last& y) NOEXCEPT;
+CONSTCD11
+ bool operator!=(const year_month_day_last& x, const year_month_day_last& y) NOEXCEPT;
+CONSTCD11
+ bool operator< (const year_month_day_last& x, const year_month_day_last& y) NOEXCEPT;
+CONSTCD11
+ bool operator> (const year_month_day_last& x, const year_month_day_last& y) NOEXCEPT;
+CONSTCD11
+ bool operator<=(const year_month_day_last& x, const year_month_day_last& y) NOEXCEPT;
+CONSTCD11
+ bool operator>=(const year_month_day_last& x, const year_month_day_last& y) NOEXCEPT;
+
+template<class = detail::unspecified_month_disambiguator>
+CONSTCD14
+year_month_day_last
+operator+(const year_month_day_last& ymdl, const months& dm) NOEXCEPT;
+
+template<class = detail::unspecified_month_disambiguator>
+CONSTCD14
+year_month_day_last
+operator+(const months& dm, const year_month_day_last& ymdl) NOEXCEPT;
+
+CONSTCD11
+year_month_day_last
+operator+(const year_month_day_last& ymdl, const years& dy) NOEXCEPT;
+
+CONSTCD11
+year_month_day_last
+operator+(const years& dy, const year_month_day_last& ymdl) NOEXCEPT;
+
+template<class = detail::unspecified_month_disambiguator>
+CONSTCD14
+year_month_day_last
+operator-(const year_month_day_last& ymdl, const months& dm) NOEXCEPT;
+
+CONSTCD11
+year_month_day_last
+operator-(const year_month_day_last& ymdl, const years& dy) NOEXCEPT;
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const year_month_day_last& ymdl);
+
+// year_month_weekday
+
+class year_month_weekday
+{
+ date::year y_;
+ date::month m_;
+ date::weekday_indexed wdi_;
+
+public:
+ year_month_weekday() = default;
+ CONSTCD11 year_month_weekday(const date::year& y, const date::month& m,
+ const date::weekday_indexed& wdi) NOEXCEPT;
+ CONSTCD14 year_month_weekday(const sys_days& dp) NOEXCEPT;
+ CONSTCD14 explicit year_month_weekday(const local_days& dp) NOEXCEPT;
+
+ template<class = detail::unspecified_month_disambiguator>
+ CONSTCD14 year_month_weekday& operator+=(const months& m) NOEXCEPT;
+ template<class = detail::unspecified_month_disambiguator>
+ CONSTCD14 year_month_weekday& operator-=(const months& m) NOEXCEPT;
+ CONSTCD14 year_month_weekday& operator+=(const years& y) NOEXCEPT;
+ CONSTCD14 year_month_weekday& operator-=(const years& y) NOEXCEPT;
+
+ CONSTCD11 date::year year() const NOEXCEPT;
+ CONSTCD11 date::month month() const NOEXCEPT;
+ CONSTCD11 date::weekday weekday() const NOEXCEPT;
+ CONSTCD11 unsigned index() const NOEXCEPT;
+ CONSTCD11 date::weekday_indexed weekday_indexed() const NOEXCEPT;
+
+ CONSTCD14 operator sys_days() const NOEXCEPT;
+ CONSTCD14 explicit operator local_days() const NOEXCEPT;
+ CONSTCD14 bool ok() const NOEXCEPT;
+
+private:
+ static CONSTCD14 year_month_weekday from_days(days dp) NOEXCEPT;
+ CONSTCD14 days to_days() const NOEXCEPT;
+};
+
+CONSTCD11
+ bool operator==(const year_month_weekday& x, const year_month_weekday& y) NOEXCEPT;
+CONSTCD11
+ bool operator!=(const year_month_weekday& x, const year_month_weekday& y) NOEXCEPT;
+
+template<class = detail::unspecified_month_disambiguator>
+CONSTCD14
+year_month_weekday
+operator+(const year_month_weekday& ymwd, const months& dm) NOEXCEPT;
+
+template<class = detail::unspecified_month_disambiguator>
+CONSTCD14
+year_month_weekday
+operator+(const months& dm, const year_month_weekday& ymwd) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday
+operator+(const year_month_weekday& ymwd, const years& dy) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday
+operator+(const years& dy, const year_month_weekday& ymwd) NOEXCEPT;
+
+template<class = detail::unspecified_month_disambiguator>
+CONSTCD14
+year_month_weekday
+operator-(const year_month_weekday& ymwd, const months& dm) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday
+operator-(const year_month_weekday& ymwd, const years& dy) NOEXCEPT;
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const year_month_weekday& ymwdi);
+
+// year_month_weekday_last
+
+class year_month_weekday_last
+{
+ date::year y_;
+ date::month m_;
+ date::weekday_last wdl_;
+
+public:
+ CONSTCD11 year_month_weekday_last(const date::year& y, const date::month& m,
+ const date::weekday_last& wdl) NOEXCEPT;
+
+ template<class = detail::unspecified_month_disambiguator>
+ CONSTCD14 year_month_weekday_last& operator+=(const months& m) NOEXCEPT;
+ template<class = detail::unspecified_month_disambiguator>
+ CONSTCD14 year_month_weekday_last& operator-=(const months& m) NOEXCEPT;
+ CONSTCD14 year_month_weekday_last& operator+=(const years& y) NOEXCEPT;
+ CONSTCD14 year_month_weekday_last& operator-=(const years& y) NOEXCEPT;
+
+ CONSTCD11 date::year year() const NOEXCEPT;
+ CONSTCD11 date::month month() const NOEXCEPT;
+ CONSTCD11 date::weekday weekday() const NOEXCEPT;
+ CONSTCD11 date::weekday_last weekday_last() const NOEXCEPT;
+
+ CONSTCD14 operator sys_days() const NOEXCEPT;
+ CONSTCD14 explicit operator local_days() const NOEXCEPT;
+ CONSTCD11 bool ok() const NOEXCEPT;
+
+private:
+ CONSTCD14 days to_days() const NOEXCEPT;
+};
+
+CONSTCD11
+bool
+operator==(const year_month_weekday_last& x, const year_month_weekday_last& y) NOEXCEPT;
+
+CONSTCD11
+bool
+operator!=(const year_month_weekday_last& x, const year_month_weekday_last& y) NOEXCEPT;
+
+template<class = detail::unspecified_month_disambiguator>
+CONSTCD14
+year_month_weekday_last
+operator+(const year_month_weekday_last& ymwdl, const months& dm) NOEXCEPT;
+
+template<class = detail::unspecified_month_disambiguator>
+CONSTCD14
+year_month_weekday_last
+operator+(const months& dm, const year_month_weekday_last& ymwdl) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday_last
+operator+(const year_month_weekday_last& ymwdl, const years& dy) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday_last
+operator+(const years& dy, const year_month_weekday_last& ymwdl) NOEXCEPT;
+
+template<class = detail::unspecified_month_disambiguator>
+CONSTCD14
+year_month_weekday_last
+operator-(const year_month_weekday_last& ymwdl, const months& dm) NOEXCEPT;
+
+CONSTCD11
+year_month_weekday_last
+operator-(const year_month_weekday_last& ymwdl, const years& dy) NOEXCEPT;
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const year_month_weekday_last& ymwdl);
+
+#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
+inline namespace literals
+{
+
+CONSTCD11 date::day operator "" _d(unsigned long long d) NOEXCEPT;
+CONSTCD11 date::year operator "" _y(unsigned long long y) NOEXCEPT;
+
+} // inline namespace literals
+#endif // !defined(_MSC_VER) || (_MSC_VER >= 1900)
+
+// CONSTDATA date::month January{1};
+// CONSTDATA date::month February{2};
+// CONSTDATA date::month March{3};
+// CONSTDATA date::month April{4};
+// CONSTDATA date::month May{5};
+// CONSTDATA date::month June{6};
+// CONSTDATA date::month July{7};
+// CONSTDATA date::month August{8};
+// CONSTDATA date::month September{9};
+// CONSTDATA date::month October{10};
+// CONSTDATA date::month November{11};
+// CONSTDATA date::month December{12};
+//
+// CONSTDATA date::weekday Sunday{0u};
+// CONSTDATA date::weekday Monday{1u};
+// CONSTDATA date::weekday Tuesday{2u};
+// CONSTDATA date::weekday Wednesday{3u};
+// CONSTDATA date::weekday Thursday{4u};
+// CONSTDATA date::weekday Friday{5u};
+// CONSTDATA date::weekday Saturday{6u};
+
+#if HAS_VOID_T
+
+template <class T, class = std::void_t<>>
+struct is_clock
+ : std::false_type
+{};
+
+template <class T>
+struct is_clock<T, std::void_t<decltype(T::now()), typename T::rep, typename T::period,
+ typename T::duration, typename T::time_point,
+ decltype(T::is_steady)>>
+ : std::true_type
+{};
+
+template<class T> inline constexpr bool is_clock_v = is_clock<T>::value;
+
+#endif // HAS_VOID_T
+
+//----------------+
+// Implementation |
+//----------------+
+
+// utilities
+namespace detail {
+
+template<class CharT, class Traits = std::char_traits<CharT>>
+class save_istream
+{
+protected:
+ std::basic_ios<CharT, Traits>& is_;
+ CharT fill_;
+ std::ios::fmtflags flags_;
+ std::streamsize precision_;
+ std::streamsize width_;
+ std::basic_ostream<CharT, Traits>* tie_;
+ std::locale loc_;
+
+public:
+ ~save_istream()
+ {
+ is_.fill(fill_);
+ is_.flags(flags_);
+ is_.precision(precision_);
+ is_.width(width_);
+ is_.imbue(loc_);
+ is_.tie(tie_);
+ }
+
+ save_istream(const save_istream&) = delete;
+ save_istream& operator=(const save_istream&) = delete;
+
+ explicit save_istream(std::basic_ios<CharT, Traits>& is)
+ : is_(is)
+ , fill_(is.fill())
+ , flags_(is.flags())
+ , precision_(is.precision())
+ , width_(is.width(0))
+ , tie_(is.tie(nullptr))
+ , loc_(is.getloc())
+ {
+ if (tie_ != nullptr)
+ tie_->flush();
+ }
+};
+
+template<class CharT, class Traits = std::char_traits<CharT>>
+class save_ostream
+ : private save_istream<CharT, Traits>
+{
+public:
+ ~save_ostream()
+ {
+ if ((this->flags_ & std::ios::unitbuf) &&
+#if HAS_UNCAUGHT_EXCEPTIONS
+ std::uncaught_exceptions() == 0 &&
+#else
+ !std::uncaught_exception() &&
+#endif
+ this->is_.good())
+ this->is_.rdbuf()->pubsync();
+ }
+
+ save_ostream(const save_ostream&) = delete;
+ save_ostream& operator=(const save_ostream&) = delete;
+
+ explicit save_ostream(std::basic_ios<CharT, Traits>& os)
+ : save_istream<CharT, Traits>(os)
+ {
+ }
+};
+
+template <class T>
+struct choose_trunc_type
+{
+ static const int digits = std::numeric_limits<T>::digits;
+ using type = typename std::conditional
+ <
+ digits < 32,
+ std::int32_t,
+ typename std::conditional
+ <
+ digits < 64,
+ std::int64_t,
+#ifdef __SIZEOF_INT128__
+ __int128
+#else
+ std::int64_t
+#endif
+ >::type
+ >::type;
+};
+
+template <class T>
+CONSTCD11
+inline
+typename std::enable_if
+<
+ !std::chrono::treat_as_floating_point<T>::value,
+ T
+>::type
+trunc(T t) NOEXCEPT
+{
+ return t;
+}
+
+template <class T>
+CONSTCD14
+inline
+typename std::enable_if
+<
+ std::chrono::treat_as_floating_point<T>::value,
+ T
+>::type
+trunc(T t) NOEXCEPT
+{
+ using std::numeric_limits;
+ using I = typename choose_trunc_type<T>::type;
+ CONSTDATA auto digits = numeric_limits<T>::digits;
+ static_assert(digits < numeric_limits<I>::digits, "");
+ CONSTDATA auto max = I{1} << (digits-1);
+ CONSTDATA auto min = -max;
+ const auto negative = t < T{0};
+ if (min <= t && t <= max && t != 0 && t == t)
+ {
+ t = static_cast<T>(static_cast<I>(t));
+ if (t == 0 && negative)
+ t = -t;
+ }
+ return t;
+}
+
+template <std::intmax_t Xp, std::intmax_t Yp>
+struct static_gcd
+{
+ static const std::intmax_t value = static_gcd<Yp, Xp % Yp>::value;
+};
+
+template <std::intmax_t Xp>
+struct static_gcd<Xp, 0>
+{
+ static const std::intmax_t value = Xp;
+};
+
+template <>
+struct static_gcd<0, 0>
+{
+ static const std::intmax_t value = 1;
+};
+
+template <class R1, class R2>
+struct no_overflow
+{
+private:
+ static const std::intmax_t gcd_n1_n2 = static_gcd<R1::num, R2::num>::value;
+ static const std::intmax_t gcd_d1_d2 = static_gcd<R1::den, R2::den>::value;
+ static const std::intmax_t n1 = R1::num / gcd_n1_n2;
+ static const std::intmax_t d1 = R1::den / gcd_d1_d2;
+ static const std::intmax_t n2 = R2::num / gcd_n1_n2;
+ static const std::intmax_t d2 = R2::den / gcd_d1_d2;
+#ifdef __cpp_constexpr
+ static const std::intmax_t max = std::numeric_limits<std::intmax_t>::max();
+#else
+ static const std::intmax_t max = LLONG_MAX;
+#endif
+
+ template <std::intmax_t Xp, std::intmax_t Yp, bool overflow>
+ struct mul // overflow == false
+ {
+ static const std::intmax_t value = Xp * Yp;
+ };
+
+ template <std::intmax_t Xp, std::intmax_t Yp>
+ struct mul<Xp, Yp, true>
+ {
+ static const std::intmax_t value = 1;
+ };
+
+public:
+ static const bool value = (n1 <= max / d2) && (n2 <= max / d1);
+ typedef std::ratio<mul<n1, d2, !value>::value,
+ mul<n2, d1, !value>::value> type;
+};
+
+} // detail
+
+// trunc towards zero
+template <class To, class Rep, class Period>
+CONSTCD11
+inline
+typename std::enable_if
+<
+ detail::no_overflow<Period, typename To::period>::value,
+ To
+>::type
+trunc(const std::chrono::duration<Rep, Period>& d)
+{
+ return To{detail::trunc(std::chrono::duration_cast<To>(d).count())};
+}
+
+template <class To, class Rep, class Period>
+CONSTCD11
+inline
+typename std::enable_if
+<
+ !detail::no_overflow<Period, typename To::period>::value,
+ To
+>::type
+trunc(const std::chrono::duration<Rep, Period>& d)
+{
+ using std::chrono::duration_cast;
+ using std::chrono::duration;
+ using rep = typename std::common_type<Rep, typename To::rep>::type;
+ return To{detail::trunc(duration_cast<To>(duration_cast<duration<rep>>(d)).count())};
+}
+
+#ifndef HAS_CHRONO_ROUNDING
+# if defined(_MSC_FULL_VER) && (_MSC_FULL_VER >= 190023918 || (_MSC_FULL_VER >= 190000000 && defined (__clang__)))
+# define HAS_CHRONO_ROUNDING 1
+# elif defined(__cpp_lib_chrono) && __cplusplus > 201402 && __cpp_lib_chrono >= 201510
+# define HAS_CHRONO_ROUNDING 1
+# elif defined(_LIBCPP_VERSION) && __cplusplus > 201402 && _LIBCPP_VERSION >= 3800
+# define HAS_CHRONO_ROUNDING 1
+# else
+# define HAS_CHRONO_ROUNDING 0
+# endif
+#endif // HAS_CHRONO_ROUNDING
+
+#if HAS_CHRONO_ROUNDING == 0
+
+// round down
+template <class To, class Rep, class Period>
+CONSTCD14
+inline
+typename std::enable_if
+<
+ detail::no_overflow<Period, typename To::period>::value,
+ To
+>::type
+floor(const std::chrono::duration<Rep, Period>& d)
+{
+ auto t = trunc<To>(d);
+ if (t > d)
+ return t - To{1};
+ return t;
+}
+
+template <class To, class Rep, class Period>
+CONSTCD14
+inline
+typename std::enable_if
+<
+ !detail::no_overflow<Period, typename To::period>::value,
+ To
+>::type
+floor(const std::chrono::duration<Rep, Period>& d)
+{
+ using rep = typename std::common_type<Rep, typename To::rep>::type;
+ return floor<To>(floor<std::chrono::duration<rep>>(d));
+}
+
+// round to nearest, to even on tie
+template <class To, class Rep, class Period>
+CONSTCD14
+inline
+To
+round(const std::chrono::duration<Rep, Period>& d)
+{
+ auto t0 = floor<To>(d);
+ auto t1 = t0 + To{1};
+ if (t1 == To{0} && t0 < To{0})
+ t1 = -t1;
+ auto diff0 = d - t0;
+ auto diff1 = t1 - d;
+ if (diff0 == diff1)
+ {
+ if (t0 - trunc<To>(t0/2)*2 == To{0})
+ return t0;
+ return t1;
+ }
+ if (diff0 < diff1)
+ return t0;
+ return t1;
+}
+
+// round up
+template <class To, class Rep, class Period>
+CONSTCD14
+inline
+To
+ceil(const std::chrono::duration<Rep, Period>& d)
+{
+ auto t = trunc<To>(d);
+ if (t < d)
+ return t + To{1};
+ return t;
+}
+
+template <class Rep, class Period,
+ class = typename std::enable_if
+ <
+ std::numeric_limits<Rep>::is_signed
+ >::type>
+CONSTCD11
+std::chrono::duration<Rep, Period>
+abs(std::chrono::duration<Rep, Period> d)
+{
+ return d >= d.zero() ? d : -d;
+}
+
+// round down
+template <class To, class Clock, class FromDuration>
+CONSTCD11
+inline
+std::chrono::time_point<Clock, To>
+floor(const std::chrono::time_point<Clock, FromDuration>& tp)
+{
+ using std::chrono::time_point;
+ return time_point<Clock, To>{date::floor<To>(tp.time_since_epoch())};
+}
+
+// round to nearest, to even on tie
+template <class To, class Clock, class FromDuration>
+CONSTCD11
+inline
+std::chrono::time_point<Clock, To>
+round(const std::chrono::time_point<Clock, FromDuration>& tp)
+{
+ using std::chrono::time_point;
+ return time_point<Clock, To>{round<To>(tp.time_since_epoch())};
+}
+
+// round up
+template <class To, class Clock, class FromDuration>
+CONSTCD11
+inline
+std::chrono::time_point<Clock, To>
+ceil(const std::chrono::time_point<Clock, FromDuration>& tp)
+{
+ using std::chrono::time_point;
+ return time_point<Clock, To>{ceil<To>(tp.time_since_epoch())};
+}
+
+#else // HAS_CHRONO_ROUNDING == 1
+
+using std::chrono::floor;
+using std::chrono::ceil;
+using std::chrono::round;
+using std::chrono::abs;
+
+#endif // HAS_CHRONO_ROUNDING
+
+namespace detail
+{
+
+template <class To, class Rep, class Period>
+CONSTCD14
+inline
+typename std::enable_if
+<
+ !std::chrono::treat_as_floating_point<typename To::rep>::value,
+ To
+>::type
+round_i(const std::chrono::duration<Rep, Period>& d)
+{
+ return round<To>(d);
+}
+
+template <class To, class Rep, class Period>
+CONSTCD14
+inline
+typename std::enable_if
+<
+ std::chrono::treat_as_floating_point<typename To::rep>::value,
+ To
+>::type
+round_i(const std::chrono::duration<Rep, Period>& d)
+{
+ return d;
+}
+
+template <class To, class Clock, class FromDuration>
+CONSTCD11
+inline
+std::chrono::time_point<Clock, To>
+round_i(const std::chrono::time_point<Clock, FromDuration>& tp)
+{
+ using std::chrono::time_point;
+ return time_point<Clock, To>{round_i<To>(tp.time_since_epoch())};
+}
+
+} // detail
+
+// trunc towards zero
+template <class To, class Clock, class FromDuration>
+CONSTCD11
+inline
+std::chrono::time_point<Clock, To>
+trunc(const std::chrono::time_point<Clock, FromDuration>& tp)
+{
+ using std::chrono::time_point;
+ return time_point<Clock, To>{trunc<To>(tp.time_since_epoch())};
+}
+
+// day
+
+CONSTCD11 inline day::day(unsigned d) NOEXCEPT : d_(static_cast<decltype(d_)>(d)) {}
+CONSTCD14 inline day& day::operator++() NOEXCEPT {++d_; return *this;}
+CONSTCD14 inline day day::operator++(int) NOEXCEPT {auto tmp(*this); ++(*this); return tmp;}
+CONSTCD14 inline day& day::operator--() NOEXCEPT {--d_; return *this;}
+CONSTCD14 inline day day::operator--(int) NOEXCEPT {auto tmp(*this); --(*this); return tmp;}
+CONSTCD14 inline day& day::operator+=(const days& d) NOEXCEPT {*this = *this + d; return *this;}
+CONSTCD14 inline day& day::operator-=(const days& d) NOEXCEPT {*this = *this - d; return *this;}
+CONSTCD11 inline day::operator unsigned() const NOEXCEPT {return d_;}
+CONSTCD11 inline bool day::ok() const NOEXCEPT {return 1 <= d_ && d_ <= 31;}
+
+CONSTCD11
+inline
+bool
+operator==(const day& x, const day& y) NOEXCEPT
+{
+ return static_cast<unsigned>(x) == static_cast<unsigned>(y);
+}
+
+CONSTCD11
+inline
+bool
+operator!=(const day& x, const day& y) NOEXCEPT
+{
+ return !(x == y);
+}
+
+CONSTCD11
+inline
+bool
+operator<(const day& x, const day& y) NOEXCEPT
+{
+ return static_cast<unsigned>(x) < static_cast<unsigned>(y);
+}
+
+CONSTCD11
+inline
+bool
+operator>(const day& x, const day& y) NOEXCEPT
+{
+ return y < x;
+}
+
+CONSTCD11
+inline
+bool
+operator<=(const day& x, const day& y) NOEXCEPT
+{
+ return !(y < x);
+}
+
+CONSTCD11
+inline
+bool
+operator>=(const day& x, const day& y) NOEXCEPT
+{
+ return !(x < y);
+}
+
+CONSTCD11
+inline
+days
+operator-(const day& x, const day& y) NOEXCEPT
+{
+ return days{static_cast<days::rep>(static_cast<unsigned>(x)
+ - static_cast<unsigned>(y))};
+}
+
+CONSTCD11
+inline
+day
+operator+(const day& x, const days& y) NOEXCEPT
+{
+ return day{static_cast<unsigned>(x) + static_cast<unsigned>(y.count())};
+}
+
+CONSTCD11
+inline
+day
+operator+(const days& x, const day& y) NOEXCEPT
+{
+ return y + x;
+}
+
+CONSTCD11
+inline
+day
+operator-(const day& x, const days& y) NOEXCEPT
+{
+ return x + -y;
+}
+
+namespace detail
+{
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+low_level_fmt(std::basic_ostream<CharT, Traits>& os, const day& d)
+{
+ detail::save_ostream<CharT, Traits> _(os);
+ os.fill('0');
+ os.flags(std::ios::dec | std::ios::right);
+ os.width(2);
+ os << static_cast<unsigned>(d);
+ return os;
+}
+
+} // namespace detail
+
+template<class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const day& d)
+{
+ detail::low_level_fmt(os, d);
+ if (!d.ok())
+ os << " is not a valid day";
+ return os;
+}
+
+// month
+
+CONSTCD11 inline month::month(unsigned m) NOEXCEPT : m_(static_cast<decltype(m_)>(m)) {}
+CONSTCD14 inline month& month::operator++() NOEXCEPT {*this += months{1}; return *this;}
+CONSTCD14 inline month month::operator++(int) NOEXCEPT {auto tmp(*this); ++(*this); return tmp;}
+CONSTCD14 inline month& month::operator--() NOEXCEPT {*this -= months{1}; return *this;}
+CONSTCD14 inline month month::operator--(int) NOEXCEPT {auto tmp(*this); --(*this); return tmp;}
+
+CONSTCD14
+inline
+month&
+month::operator+=(const months& m) NOEXCEPT
+{
+ *this = *this + m;
+ return *this;
+}
+
+CONSTCD14
+inline
+month&
+month::operator-=(const months& m) NOEXCEPT
+{
+ *this = *this - m;
+ return *this;
+}
+
+CONSTCD11 inline month::operator unsigned() const NOEXCEPT {return m_;}
+CONSTCD11 inline bool month::ok() const NOEXCEPT {return 1 <= m_ && m_ <= 12;}
+
+CONSTCD11
+inline
+bool
+operator==(const month& x, const month& y) NOEXCEPT
+{
+ return static_cast<unsigned>(x) == static_cast<unsigned>(y);
+}
+
+CONSTCD11
+inline
+bool
+operator!=(const month& x, const month& y) NOEXCEPT
+{
+ return !(x == y);
+}
+
+CONSTCD11
+inline
+bool
+operator<(const month& x, const month& y) NOEXCEPT
+{
+ return static_cast<unsigned>(x) < static_cast<unsigned>(y);
+}
+
+CONSTCD11
+inline
+bool
+operator>(const month& x, const month& y) NOEXCEPT
+{
+ return y < x;
+}
+
+CONSTCD11
+inline
+bool
+operator<=(const month& x, const month& y) NOEXCEPT
+{
+ return !(y < x);
+}
+
+CONSTCD11
+inline
+bool
+operator>=(const month& x, const month& y) NOEXCEPT
+{
+ return !(x < y);
+}
+
+CONSTCD14
+inline
+months
+operator-(const month& x, const month& y) NOEXCEPT
+{
+ auto const d = static_cast<unsigned>(x) - static_cast<unsigned>(y);
+ return months(d <= 11 ? d : d + 12);
+}
+
+CONSTCD14
+inline
+month
+operator+(const month& x, const months& y) NOEXCEPT
+{
+ auto const mu = static_cast<long long>(static_cast<unsigned>(x)) + y.count() - 1;
+ auto const yr = (mu >= 0 ? mu : mu-11) / 12;
+ return month{static_cast<unsigned>(mu - yr * 12 + 1)};
+}
+
+CONSTCD14
+inline
+month
+operator+(const months& x, const month& y) NOEXCEPT
+{
+ return y + x;
+}
+
+CONSTCD14
+inline
+month
+operator-(const month& x, const months& y) NOEXCEPT
+{
+ return x + -y;
+}
+
+namespace detail
+{
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+low_level_fmt(std::basic_ostream<CharT, Traits>& os, const month& m)
+{
+ if (m.ok())
+ {
+ CharT fmt[] = {'%', 'b', 0};
+ os << format(os.getloc(), fmt, m);
+ }
+ else
+ os << static_cast<unsigned>(m);
+ return os;
+}
+
+} // namespace detail
+
+template<class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const month& m)
+{
+ detail::low_level_fmt(os, m);
+ if (!m.ok())
+ os << " is not a valid month";
+ return os;
+}
+
+// year
+
+CONSTCD11 inline year::year(int y) NOEXCEPT : y_(static_cast<decltype(y_)>(y)) {}
+CONSTCD14 inline year& year::operator++() NOEXCEPT {++y_; return *this;}
+CONSTCD14 inline year year::operator++(int) NOEXCEPT {auto tmp(*this); ++(*this); return tmp;}
+CONSTCD14 inline year& year::operator--() NOEXCEPT {--y_; return *this;}
+CONSTCD14 inline year year::operator--(int) NOEXCEPT {auto tmp(*this); --(*this); return tmp;}
+CONSTCD14 inline year& year::operator+=(const years& y) NOEXCEPT {*this = *this + y; return *this;}
+CONSTCD14 inline year& year::operator-=(const years& y) NOEXCEPT {*this = *this - y; return *this;}
+CONSTCD11 inline year year::operator-() const NOEXCEPT {return year{-y_};}
+CONSTCD11 inline year year::operator+() const NOEXCEPT {return *this;}
+
+CONSTCD11
+inline
+bool
+year::is_leap() const NOEXCEPT
+{
+ return y_ % 4 == 0 && (y_ % 100 != 0 || y_ % 400 == 0);
+}
+
+CONSTCD11 inline year::operator int() const NOEXCEPT {return y_;}
+
+CONSTCD11
+inline
+bool
+year::ok() const NOEXCEPT
+{
+ return y_ != std::numeric_limits<short>::min();
+}
+
+CONSTCD11
+inline
+bool
+operator==(const year& x, const year& y) NOEXCEPT
+{
+ return static_cast<int>(x) == static_cast<int>(y);
+}
+
+CONSTCD11
+inline
+bool
+operator!=(const year& x, const year& y) NOEXCEPT
+{
+ return !(x == y);
+}
+
+CONSTCD11
+inline
+bool
+operator<(const year& x, const year& y) NOEXCEPT
+{
+ return static_cast<int>(x) < static_cast<int>(y);
+}
+
+CONSTCD11
+inline
+bool
+operator>(const year& x, const year& y) NOEXCEPT
+{
+ return y < x;
+}
+
+CONSTCD11
+inline
+bool
+operator<=(const year& x, const year& y) NOEXCEPT
+{
+ return !(y < x);
+}
+
+CONSTCD11
+inline
+bool
+operator>=(const year& x, const year& y) NOEXCEPT
+{
+ return !(x < y);
+}
+
+CONSTCD11
+inline
+years
+operator-(const year& x, const year& y) NOEXCEPT
+{
+ return years{static_cast<int>(x) - static_cast<int>(y)};
+}
+
+CONSTCD11
+inline
+year
+operator+(const year& x, const years& y) NOEXCEPT
+{
+ return year{static_cast<int>(x) + y.count()};
+}
+
+CONSTCD11
+inline
+year
+operator+(const years& x, const year& y) NOEXCEPT
+{
+ return y + x;
+}
+
+CONSTCD11
+inline
+year
+operator-(const year& x, const years& y) NOEXCEPT
+{
+ return year{static_cast<int>(x) - y.count()};
+}
+
+namespace detail
+{
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+low_level_fmt(std::basic_ostream<CharT, Traits>& os, const year& y)
+{
+ detail::save_ostream<CharT, Traits> _(os);
+ os.fill('0');
+ os.flags(std::ios::dec | std::ios::internal);
+ os.width(4 + (y < year{0}));
+ os.imbue(std::locale::classic());
+ os << static_cast<int>(y);
+ return os;
+}
+
+} // namespace detail
+
+template<class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const year& y)
+{
+ detail::low_level_fmt(os, y);
+ if (!y.ok())
+ os << " is not a valid year";
+ return os;
+}
+
+// weekday
+
+CONSTCD14
+inline
+unsigned char
+weekday::weekday_from_days(int z) NOEXCEPT
+{
+ auto u = static_cast<unsigned>(z);
+ return static_cast<unsigned char>(z >= -4 ? (u+4) % 7 : u % 7);
+}
+
+CONSTCD11
+inline
+weekday::weekday(unsigned wd) NOEXCEPT
+ : wd_(static_cast<decltype(wd_)>(wd != 7 ? wd : 0))
+ {}
+
+CONSTCD14
+inline
+weekday::weekday(const sys_days& dp) NOEXCEPT
+ : wd_(weekday_from_days(dp.time_since_epoch().count()))
+ {}
+
+CONSTCD14
+inline
+weekday::weekday(const local_days& dp) NOEXCEPT
+ : wd_(weekday_from_days(dp.time_since_epoch().count()))
+ {}
+
+CONSTCD14 inline weekday& weekday::operator++() NOEXCEPT {*this += days{1}; return *this;}
+CONSTCD14 inline weekday weekday::operator++(int) NOEXCEPT {auto tmp(*this); ++(*this); return tmp;}
+CONSTCD14 inline weekday& weekday::operator--() NOEXCEPT {*this -= days{1}; return *this;}
+CONSTCD14 inline weekday weekday::operator--(int) NOEXCEPT {auto tmp(*this); --(*this); return tmp;}
+
+CONSTCD14
+inline
+weekday&
+weekday::operator+=(const days& d) NOEXCEPT
+{
+ *this = *this + d;
+ return *this;
+}
+
+CONSTCD14
+inline
+weekday&
+weekday::operator-=(const days& d) NOEXCEPT
+{
+ *this = *this - d;
+ return *this;
+}
+
+CONSTCD11 inline bool weekday::ok() const NOEXCEPT {return wd_ <= 6;}
+
+CONSTCD11
+inline
+unsigned weekday::c_encoding() const NOEXCEPT
+{
+ return unsigned{wd_};
+}
+
+CONSTCD11
+inline
+unsigned weekday::iso_encoding() const NOEXCEPT
+{
+ return unsigned{((wd_ == 0u) ? 7u : wd_)};
+}
+
+CONSTCD11
+inline
+bool
+operator==(const weekday& x, const weekday& y) NOEXCEPT
+{
+ return x.wd_ == y.wd_;
+}
+
+CONSTCD11
+inline
+bool
+operator!=(const weekday& x, const weekday& y) NOEXCEPT
+{
+ return !(x == y);
+}
+
+CONSTCD14
+inline
+days
+operator-(const weekday& x, const weekday& y) NOEXCEPT
+{
+ auto const wdu = x.wd_ - y.wd_;
+ auto const wk = (wdu >= 0 ? wdu : wdu-6) / 7;
+ return days{wdu - wk * 7};
+}
+
+CONSTCD14
+inline
+weekday
+operator+(const weekday& x, const days& y) NOEXCEPT
+{
+ auto const wdu = static_cast<long long>(static_cast<unsigned>(x.wd_)) + y.count();
+ auto const wk = (wdu >= 0 ? wdu : wdu-6) / 7;
+ return weekday{static_cast<unsigned>(wdu - wk * 7)};
+}
+
+CONSTCD14
+inline
+weekday
+operator+(const days& x, const weekday& y) NOEXCEPT
+{
+ return y + x;
+}
+
+CONSTCD14
+inline
+weekday
+operator-(const weekday& x, const days& y) NOEXCEPT
+{
+ return x + -y;
+}
+
+namespace detail
+{
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+low_level_fmt(std::basic_ostream<CharT, Traits>& os, const weekday& wd)
+{
+ if (wd.ok())
+ {
+ CharT fmt[] = {'%', 'a', 0};
+ os << format(fmt, wd);
+ }
+ else
+ os << wd.c_encoding();
+ return os;
+}
+
+} // namespace detail
+
+template<class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const weekday& wd)
+{
+ detail::low_level_fmt(os, wd);
+ if (!wd.ok())
+ os << " is not a valid weekday";
+ return os;
+}
+
+#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
+inline namespace literals
+{
+
+CONSTCD11
+inline
+date::day
+operator "" _d(unsigned long long d) NOEXCEPT
+{
+ return date::day{static_cast<unsigned>(d)};
+}
+
+CONSTCD11
+inline
+date::year
+operator "" _y(unsigned long long y) NOEXCEPT
+{
+ return date::year(static_cast<int>(y));
+}
+#endif // !defined(_MSC_VER) || (_MSC_VER >= 1900)
+
+CONSTDATA date::last_spec last{};
+
+CONSTDATA date::month jan{1};
+CONSTDATA date::month feb{2};
+CONSTDATA date::month mar{3};
+CONSTDATA date::month apr{4};
+CONSTDATA date::month may{5};
+CONSTDATA date::month jun{6};
+CONSTDATA date::month jul{7};
+CONSTDATA date::month aug{8};
+CONSTDATA date::month sep{9};
+CONSTDATA date::month oct{10};
+CONSTDATA date::month nov{11};
+CONSTDATA date::month dec{12};
+
+CONSTDATA date::weekday sun{0u};
+CONSTDATA date::weekday mon{1u};
+CONSTDATA date::weekday tue{2u};
+CONSTDATA date::weekday wed{3u};
+CONSTDATA date::weekday thu{4u};
+CONSTDATA date::weekday fri{5u};
+CONSTDATA date::weekday sat{6u};
+
+#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
+} // inline namespace literals
+#endif
+
+CONSTDATA date::month January{1};
+CONSTDATA date::month February{2};
+CONSTDATA date::month March{3};
+CONSTDATA date::month April{4};
+CONSTDATA date::month May{5};
+CONSTDATA date::month June{6};
+CONSTDATA date::month July{7};
+CONSTDATA date::month August{8};
+CONSTDATA date::month September{9};
+CONSTDATA date::month October{10};
+CONSTDATA date::month November{11};
+CONSTDATA date::month December{12};
+
+CONSTDATA date::weekday Monday{1};
+CONSTDATA date::weekday Tuesday{2};
+CONSTDATA date::weekday Wednesday{3};
+CONSTDATA date::weekday Thursday{4};
+CONSTDATA date::weekday Friday{5};
+CONSTDATA date::weekday Saturday{6};
+CONSTDATA date::weekday Sunday{7};
+
+// weekday_indexed
+
+CONSTCD11
+inline
+weekday
+weekday_indexed::weekday() const NOEXCEPT
+{
+ return date::weekday{static_cast<unsigned>(wd_)};
+}
+
+CONSTCD11 inline unsigned weekday_indexed::index() const NOEXCEPT {return index_;}
+
+CONSTCD11
+inline
+bool
+weekday_indexed::ok() const NOEXCEPT
+{
+ return weekday().ok() && 1 <= index_ && index_ <= 5;
+}
+
+#ifdef __GNUC__
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wconversion"
+#endif // __GNUC__
+
+CONSTCD11
+inline
+weekday_indexed::weekday_indexed(const date::weekday& wd, unsigned index) NOEXCEPT
+ : wd_(static_cast<decltype(wd_)>(static_cast<unsigned>(wd.wd_)))
+ , index_(static_cast<decltype(index_)>(index))
+ {}
+
+#ifdef __GNUC__
+# pragma GCC diagnostic pop
+#endif // __GNUC__
+
+namespace detail
+{
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+low_level_fmt(std::basic_ostream<CharT, Traits>& os, const weekday_indexed& wdi)
+{
+ return low_level_fmt(os, wdi.weekday()) << '[' << wdi.index() << ']';
+}
+
+} // namespace detail
+
+template<class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const weekday_indexed& wdi)
+{
+ detail::low_level_fmt(os, wdi);
+ if (!wdi.ok())
+ os << " is not a valid weekday_indexed";
+ return os;
+}
+
+CONSTCD11
+inline
+weekday_indexed
+weekday::operator[](unsigned index) const NOEXCEPT
+{
+ return {*this, index};
+}
+
+CONSTCD11
+inline
+bool
+operator==(const weekday_indexed& x, const weekday_indexed& y) NOEXCEPT
+{
+ return x.weekday() == y.weekday() && x.index() == y.index();
+}
+
+CONSTCD11
+inline
+bool
+operator!=(const weekday_indexed& x, const weekday_indexed& y) NOEXCEPT
+{
+ return !(x == y);
+}
+
+// weekday_last
+
+CONSTCD11 inline date::weekday weekday_last::weekday() const NOEXCEPT {return wd_;}
+CONSTCD11 inline bool weekday_last::ok() const NOEXCEPT {return wd_.ok();}
+CONSTCD11 inline weekday_last::weekday_last(const date::weekday& wd) NOEXCEPT : wd_(wd) {}
+
+CONSTCD11
+inline
+bool
+operator==(const weekday_last& x, const weekday_last& y) NOEXCEPT
+{
+ return x.weekday() == y.weekday();
+}
+
+CONSTCD11
+inline
+bool
+operator!=(const weekday_last& x, const weekday_last& y) NOEXCEPT
+{
+ return !(x == y);
+}
+
+namespace detail
+{
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+low_level_fmt(std::basic_ostream<CharT, Traits>& os, const weekday_last& wdl)
+{
+ return low_level_fmt(os, wdl.weekday()) << "[last]";
+}
+
+} // namespace detail
+
+template<class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const weekday_last& wdl)
+{
+ detail::low_level_fmt(os, wdl);
+ if (!wdl.ok())
+ os << " is not a valid weekday_last";
+ return os;
+}
+
+CONSTCD11
+inline
+weekday_last
+weekday::operator[](last_spec) const NOEXCEPT
+{
+ return weekday_last{*this};
+}
+
+// year_month
+
+CONSTCD11
+inline
+year_month::year_month(const date::year& y, const date::month& m) NOEXCEPT
+ : y_(y)
+ , m_(m)
+ {}
+
+CONSTCD11 inline year year_month::year() const NOEXCEPT {return y_;}
+CONSTCD11 inline month year_month::month() const NOEXCEPT {return m_;}
+CONSTCD11 inline bool year_month::ok() const NOEXCEPT {return y_.ok() && m_.ok();}
+
+template<class>
+CONSTCD14
+inline
+year_month&
+year_month::operator+=(const months& dm) NOEXCEPT
+{
+ *this = *this + dm;
+ return *this;
+}
+
+template<class>
+CONSTCD14
+inline
+year_month&
+year_month::operator-=(const months& dm) NOEXCEPT
+{
+ *this = *this - dm;
+ return *this;
+}
+
+CONSTCD14
+inline
+year_month&
+year_month::operator+=(const years& dy) NOEXCEPT
+{
+ *this = *this + dy;
+ return *this;
+}
+
+CONSTCD14
+inline
+year_month&
+year_month::operator-=(const years& dy) NOEXCEPT
+{
+ *this = *this - dy;
+ return *this;
+}
+
+CONSTCD11
+inline
+bool
+operator==(const year_month& x, const year_month& y) NOEXCEPT
+{
+ return x.year() == y.year() && x.month() == y.month();
+}
+
+CONSTCD11
+inline
+bool
+operator!=(const year_month& x, const year_month& y) NOEXCEPT
+{
+ return !(x == y);
+}
+
+CONSTCD11
+inline
+bool
+operator<(const year_month& x, const year_month& y) NOEXCEPT
+{
+ return x.year() < y.year() ? true
+ : (x.year() > y.year() ? false
+ : (x.month() < y.month()));
+}
+
+CONSTCD11
+inline
+bool
+operator>(const year_month& x, const year_month& y) NOEXCEPT
+{
+ return y < x;
+}
+
+CONSTCD11
+inline
+bool
+operator<=(const year_month& x, const year_month& y) NOEXCEPT
+{
+ return !(y < x);
+}
+
+CONSTCD11
+inline
+bool
+operator>=(const year_month& x, const year_month& y) NOEXCEPT
+{
+ return !(x < y);
+}
+
+template<class>
+CONSTCD14
+inline
+year_month
+operator+(const year_month& ym, const months& dm) NOEXCEPT
+{
+ auto dmi = static_cast<int>(static_cast<unsigned>(ym.month())) - 1 + dm.count();
+ auto dy = (dmi >= 0 ? dmi : dmi-11) / 12;
+ dmi = dmi - dy * 12 + 1;
+ return (ym.year() + years(dy)) / month(static_cast<unsigned>(dmi));
+}
+
+template<class>
+CONSTCD14
+inline
+year_month
+operator+(const months& dm, const year_month& ym) NOEXCEPT
+{
+ return ym + dm;
+}
+
+template<class>
+CONSTCD14
+inline
+year_month
+operator-(const year_month& ym, const months& dm) NOEXCEPT
+{
+ return ym + -dm;
+}
+
+CONSTCD11
+inline
+months
+operator-(const year_month& x, const year_month& y) NOEXCEPT
+{
+ return (x.year() - y.year()) +
+ months(static_cast<unsigned>(x.month()) - static_cast<unsigned>(y.month()));
+}
+
+CONSTCD11
+inline
+year_month
+operator+(const year_month& ym, const years& dy) NOEXCEPT
+{
+ return (ym.year() + dy) / ym.month();
+}
+
+CONSTCD11
+inline
+year_month
+operator+(const years& dy, const year_month& ym) NOEXCEPT
+{
+ return ym + dy;
+}
+
+CONSTCD11
+inline
+year_month
+operator-(const year_month& ym, const years& dy) NOEXCEPT
+{
+ return ym + -dy;
+}
+
+namespace detail
+{
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+low_level_fmt(std::basic_ostream<CharT, Traits>& os, const year_month& ym)
+{
+ low_level_fmt(os, ym.year()) << '/';
+ return low_level_fmt(os, ym.month());
+}
+
+} // namespace detail
+
+template<class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const year_month& ym)
+{
+ detail::low_level_fmt(os, ym);
+ if (!ym.ok())
+ os << " is not a valid year_month";
+ return os;
+}
+
+// month_day
+
+CONSTCD11
+inline
+month_day::month_day(const date::month& m, const date::day& d) NOEXCEPT
+ : m_(m)
+ , d_(d)
+ {}
+
+CONSTCD11 inline date::month month_day::month() const NOEXCEPT {return m_;}
+CONSTCD11 inline date::day month_day::day() const NOEXCEPT {return d_;}
+
+CONSTCD14
+inline
+bool
+month_day::ok() const NOEXCEPT
+{
+ CONSTDATA date::day d[] =
+ {
+ date::day(31), date::day(29), date::day(31),
+ date::day(30), date::day(31), date::day(30),
+ date::day(31), date::day(31), date::day(30),
+ date::day(31), date::day(30), date::day(31)
+ };
+ return m_.ok() && date::day{1} <= d_ && d_ <= d[static_cast<unsigned>(m_)-1];
+}
+
+CONSTCD11
+inline
+bool
+operator==(const month_day& x, const month_day& y) NOEXCEPT
+{
+ return x.month() == y.month() && x.day() == y.day();
+}
+
+CONSTCD11
+inline
+bool
+operator!=(const month_day& x, const month_day& y) NOEXCEPT
+{
+ return !(x == y);
+}
+
+CONSTCD11
+inline
+bool
+operator<(const month_day& x, const month_day& y) NOEXCEPT
+{
+ return x.month() < y.month() ? true
+ : (x.month() > y.month() ? false
+ : (x.day() < y.day()));
+}
+
+CONSTCD11
+inline
+bool
+operator>(const month_day& x, const month_day& y) NOEXCEPT
+{
+ return y < x;
+}
+
+CONSTCD11
+inline
+bool
+operator<=(const month_day& x, const month_day& y) NOEXCEPT
+{
+ return !(y < x);
+}
+
+CONSTCD11
+inline
+bool
+operator>=(const month_day& x, const month_day& y) NOEXCEPT
+{
+ return !(x < y);
+}
+
+namespace detail
+{
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+low_level_fmt(std::basic_ostream<CharT, Traits>& os, const month_day& md)
+{
+ low_level_fmt(os, md.month()) << '/';
+ return low_level_fmt(os, md.day());
+}
+
+} // namespace detail
+
+template<class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const month_day& md)
+{
+ detail::low_level_fmt(os, md);
+ if (!md.ok())
+ os << " is not a valid month_day";
+ return os;
+}
+
+// month_day_last
+
+CONSTCD11 inline month month_day_last::month() const NOEXCEPT {return m_;}
+CONSTCD11 inline bool month_day_last::ok() const NOEXCEPT {return m_.ok();}
+CONSTCD11 inline month_day_last::month_day_last(const date::month& m) NOEXCEPT : m_(m) {}
+
+CONSTCD11
+inline
+bool
+operator==(const month_day_last& x, const month_day_last& y) NOEXCEPT
+{
+ return x.month() == y.month();
+}
+
+CONSTCD11
+inline
+bool
+operator!=(const month_day_last& x, const month_day_last& y) NOEXCEPT
+{
+ return !(x == y);
+}
+
+CONSTCD11
+inline
+bool
+operator<(const month_day_last& x, const month_day_last& y) NOEXCEPT
+{
+ return x.month() < y.month();
+}
+
+CONSTCD11
+inline
+bool
+operator>(const month_day_last& x, const month_day_last& y) NOEXCEPT
+{
+ return y < x;
+}
+
+CONSTCD11
+inline
+bool
+operator<=(const month_day_last& x, const month_day_last& y) NOEXCEPT
+{
+ return !(y < x);
+}
+
+CONSTCD11
+inline
+bool
+operator>=(const month_day_last& x, const month_day_last& y) NOEXCEPT
+{
+ return !(x < y);
+}
+
+namespace detail
+{
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+low_level_fmt(std::basic_ostream<CharT, Traits>& os, const month_day_last& mdl)
+{
+ return low_level_fmt(os, mdl.month()) << "/last";
+}
+
+} // namespace detail
+
+template<class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const month_day_last& mdl)
+{
+ detail::low_level_fmt(os, mdl);
+ if (!mdl.ok())
+ os << " is not a valid month_day_last";
+ return os;
+}
+
+// month_weekday
+
+CONSTCD11
+inline
+month_weekday::month_weekday(const date::month& m,
+ const date::weekday_indexed& wdi) NOEXCEPT
+ : m_(m)
+ , wdi_(wdi)
+ {}
+
+CONSTCD11 inline month month_weekday::month() const NOEXCEPT {return m_;}
+
+CONSTCD11
+inline
+weekday_indexed
+month_weekday::weekday_indexed() const NOEXCEPT
+{
+ return wdi_;
+}
+
+CONSTCD11
+inline
+bool
+month_weekday::ok() const NOEXCEPT
+{
+ return m_.ok() && wdi_.ok();
+}
+
+CONSTCD11
+inline
+bool
+operator==(const month_weekday& x, const month_weekday& y) NOEXCEPT
+{
+ return x.month() == y.month() && x.weekday_indexed() == y.weekday_indexed();
+}
+
+CONSTCD11
+inline
+bool
+operator!=(const month_weekday& x, const month_weekday& y) NOEXCEPT
+{
+ return !(x == y);
+}
+
+namespace detail
+{
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+low_level_fmt(std::basic_ostream<CharT, Traits>& os, const month_weekday& mwd)
+{
+ low_level_fmt(os, mwd.month()) << '/';
+ return low_level_fmt(os, mwd.weekday_indexed());
+}
+
+} // namespace detail
+
+template<class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const month_weekday& mwd)
+{
+ detail::low_level_fmt(os, mwd);
+ if (!mwd.ok())
+ os << " is not a valid month_weekday";
+ return os;
+}
+
+// month_weekday_last
+
+CONSTCD11
+inline
+month_weekday_last::month_weekday_last(const date::month& m,
+ const date::weekday_last& wdl) NOEXCEPT
+ : m_(m)
+ , wdl_(wdl)
+ {}
+
+CONSTCD11 inline month month_weekday_last::month() const NOEXCEPT {return m_;}
+
+CONSTCD11
+inline
+weekday_last
+month_weekday_last::weekday_last() const NOEXCEPT
+{
+ return wdl_;
+}
+
+CONSTCD11
+inline
+bool
+month_weekday_last::ok() const NOEXCEPT
+{
+ return m_.ok() && wdl_.ok();
+}
+
+CONSTCD11
+inline
+bool
+operator==(const month_weekday_last& x, const month_weekday_last& y) NOEXCEPT
+{
+ return x.month() == y.month() && x.weekday_last() == y.weekday_last();
+}
+
+CONSTCD11
+inline
+bool
+operator!=(const month_weekday_last& x, const month_weekday_last& y) NOEXCEPT
+{
+ return !(x == y);
+}
+
+namespace detail
+{
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+low_level_fmt(std::basic_ostream<CharT, Traits>& os, const month_weekday_last& mwdl)
+{
+ low_level_fmt(os, mwdl.month()) << '/';
+ return low_level_fmt(os, mwdl.weekday_last());
+}
+
+} // namespace detail
+
+template<class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const month_weekday_last& mwdl)
+{
+ detail::low_level_fmt(os, mwdl);
+ if (!mwdl.ok())
+ os << " is not a valid month_weekday_last";
+ return os;
+}
+
+// year_month_day_last
+
+CONSTCD11
+inline
+year_month_day_last::year_month_day_last(const date::year& y,
+ const date::month_day_last& mdl) NOEXCEPT
+ : y_(y)
+ , mdl_(mdl)
+ {}
+
+template<class>
+CONSTCD14
+inline
+year_month_day_last&
+year_month_day_last::operator+=(const months& m) NOEXCEPT
+{
+ *this = *this + m;
+ return *this;
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_day_last&
+year_month_day_last::operator-=(const months& m) NOEXCEPT
+{
+ *this = *this - m;
+ return *this;
+}
+
+CONSTCD14
+inline
+year_month_day_last&
+year_month_day_last::operator+=(const years& y) NOEXCEPT
+{
+ *this = *this + y;
+ return *this;
+}
+
+CONSTCD14
+inline
+year_month_day_last&
+year_month_day_last::operator-=(const years& y) NOEXCEPT
+{
+ *this = *this - y;
+ return *this;
+}
+
+CONSTCD11 inline year year_month_day_last::year() const NOEXCEPT {return y_;}
+CONSTCD11 inline month year_month_day_last::month() const NOEXCEPT {return mdl_.month();}
+
+CONSTCD11
+inline
+month_day_last
+year_month_day_last::month_day_last() const NOEXCEPT
+{
+ return mdl_;
+}
+
+CONSTCD14
+inline
+day
+year_month_day_last::day() const NOEXCEPT
+{
+ CONSTDATA date::day d[] =
+ {
+ date::day(31), date::day(28), date::day(31),
+ date::day(30), date::day(31), date::day(30),
+ date::day(31), date::day(31), date::day(30),
+ date::day(31), date::day(30), date::day(31)
+ };
+ return (month() != February || !y_.is_leap()) && mdl_.ok() ?
+ d[static_cast<unsigned>(month()) - 1] : date::day{29};
+}
+
+CONSTCD14
+inline
+year_month_day_last::operator sys_days() const NOEXCEPT
+{
+ return sys_days(year()/month()/day());
+}
+
+CONSTCD14
+inline
+year_month_day_last::operator local_days() const NOEXCEPT
+{
+ return local_days(year()/month()/day());
+}
+
+CONSTCD11
+inline
+bool
+year_month_day_last::ok() const NOEXCEPT
+{
+ return y_.ok() && mdl_.ok();
+}
+
+CONSTCD11
+inline
+bool
+operator==(const year_month_day_last& x, const year_month_day_last& y) NOEXCEPT
+{
+ return x.year() == y.year() && x.month_day_last() == y.month_day_last();
+}
+
+CONSTCD11
+inline
+bool
+operator!=(const year_month_day_last& x, const year_month_day_last& y) NOEXCEPT
+{
+ return !(x == y);
+}
+
+CONSTCD11
+inline
+bool
+operator<(const year_month_day_last& x, const year_month_day_last& y) NOEXCEPT
+{
+ return x.year() < y.year() ? true
+ : (x.year() > y.year() ? false
+ : (x.month_day_last() < y.month_day_last()));
+}
+
+CONSTCD11
+inline
+bool
+operator>(const year_month_day_last& x, const year_month_day_last& y) NOEXCEPT
+{
+ return y < x;
+}
+
+CONSTCD11
+inline
+bool
+operator<=(const year_month_day_last& x, const year_month_day_last& y) NOEXCEPT
+{
+ return !(y < x);
+}
+
+CONSTCD11
+inline
+bool
+operator>=(const year_month_day_last& x, const year_month_day_last& y) NOEXCEPT
+{
+ return !(x < y);
+}
+
+namespace detail
+{
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+low_level_fmt(std::basic_ostream<CharT, Traits>& os, const year_month_day_last& ymdl)
+{
+ low_level_fmt(os, ymdl.year()) << '/';
+ return low_level_fmt(os, ymdl.month_day_last());
+}
+
+} // namespace detail
+
+template<class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const year_month_day_last& ymdl)
+{
+ detail::low_level_fmt(os, ymdl);
+ if (!ymdl.ok())
+ os << " is not a valid year_month_day_last";
+ return os;
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_day_last
+operator+(const year_month_day_last& ymdl, const months& dm) NOEXCEPT
+{
+ return (ymdl.year() / ymdl.month() + dm) / last;
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_day_last
+operator+(const months& dm, const year_month_day_last& ymdl) NOEXCEPT
+{
+ return ymdl + dm;
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_day_last
+operator-(const year_month_day_last& ymdl, const months& dm) NOEXCEPT
+{
+ return ymdl + (-dm);
+}
+
+CONSTCD11
+inline
+year_month_day_last
+operator+(const year_month_day_last& ymdl, const years& dy) NOEXCEPT
+{
+ return {ymdl.year()+dy, ymdl.month_day_last()};
+}
+
+CONSTCD11
+inline
+year_month_day_last
+operator+(const years& dy, const year_month_day_last& ymdl) NOEXCEPT
+{
+ return ymdl + dy;
+}
+
+CONSTCD11
+inline
+year_month_day_last
+operator-(const year_month_day_last& ymdl, const years& dy) NOEXCEPT
+{
+ return ymdl + (-dy);
+}
+
+// year_month_day
+
+CONSTCD11
+inline
+year_month_day::year_month_day(const date::year& y, const date::month& m,
+ const date::day& d) NOEXCEPT
+ : y_(y)
+ , m_(m)
+ , d_(d)
+ {}
+
+CONSTCD14
+inline
+year_month_day::year_month_day(const year_month_day_last& ymdl) NOEXCEPT
+ : y_(ymdl.year())
+ , m_(ymdl.month())
+ , d_(ymdl.day())
+ {}
+
+CONSTCD14
+inline
+year_month_day::year_month_day(sys_days dp) NOEXCEPT
+ : year_month_day(from_days(dp.time_since_epoch()))
+ {}
+
+CONSTCD14
+inline
+year_month_day::year_month_day(local_days dp) NOEXCEPT
+ : year_month_day(from_days(dp.time_since_epoch()))
+ {}
+
+CONSTCD11 inline year year_month_day::year() const NOEXCEPT {return y_;}
+CONSTCD11 inline month year_month_day::month() const NOEXCEPT {return m_;}
+CONSTCD11 inline day year_month_day::day() const NOEXCEPT {return d_;}
+
+template<class>
+CONSTCD14
+inline
+year_month_day&
+year_month_day::operator+=(const months& m) NOEXCEPT
+{
+ *this = *this + m;
+ return *this;
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_day&
+year_month_day::operator-=(const months& m) NOEXCEPT
+{
+ *this = *this - m;
+ return *this;
+}
+
+CONSTCD14
+inline
+year_month_day&
+year_month_day::operator+=(const years& y) NOEXCEPT
+{
+ *this = *this + y;
+ return *this;
+}
+
+CONSTCD14
+inline
+year_month_day&
+year_month_day::operator-=(const years& y) NOEXCEPT
+{
+ *this = *this - y;
+ return *this;
+}
+
+CONSTCD14
+inline
+days
+year_month_day::to_days() const NOEXCEPT
+{
+ static_assert(std::numeric_limits<unsigned>::digits >= 18,
+ "This algorithm has not been ported to a 16 bit unsigned integer");
+ static_assert(std::numeric_limits<int>::digits >= 20,
+ "This algorithm has not been ported to a 16 bit signed integer");
+ auto const y = static_cast<int>(y_) - (m_ <= February);
+ auto const m = static_cast<unsigned>(m_);
+ auto const d = static_cast<unsigned>(d_);
+ auto const era = (y >= 0 ? y : y-399) / 400;
+ auto const yoe = static_cast<unsigned>(y - era * 400); // [0, 399]
+ auto const doy = (153*(m > 2 ? m-3 : m+9) + 2)/5 + d-1; // [0, 365]
+ auto const doe = yoe * 365 + yoe/4 - yoe/100 + doy; // [0, 146096]
+ return days{era * 146097 + static_cast<int>(doe) - 719468};
+}
+
+CONSTCD14
+inline
+year_month_day::operator sys_days() const NOEXCEPT
+{
+ return sys_days{to_days()};
+}
+
+CONSTCD14
+inline
+year_month_day::operator local_days() const NOEXCEPT
+{
+ return local_days{to_days()};
+}
+
+CONSTCD14
+inline
+bool
+year_month_day::ok() const NOEXCEPT
+{
+ if (!(y_.ok() && m_.ok()))
+ return false;
+ return date::day{1} <= d_ && d_ <= (y_ / m_ / last).day();
+}
+
+CONSTCD11
+inline
+bool
+operator==(const year_month_day& x, const year_month_day& y) NOEXCEPT
+{
+ return x.year() == y.year() && x.month() == y.month() && x.day() == y.day();
+}
+
+CONSTCD11
+inline
+bool
+operator!=(const year_month_day& x, const year_month_day& y) NOEXCEPT
+{
+ return !(x == y);
+}
+
+CONSTCD11
+inline
+bool
+operator<(const year_month_day& x, const year_month_day& y) NOEXCEPT
+{
+ return x.year() < y.year() ? true
+ : (x.year() > y.year() ? false
+ : (x.month() < y.month() ? true
+ : (x.month() > y.month() ? false
+ : (x.day() < y.day()))));
+}
+
+CONSTCD11
+inline
+bool
+operator>(const year_month_day& x, const year_month_day& y) NOEXCEPT
+{
+ return y < x;
+}
+
+CONSTCD11
+inline
+bool
+operator<=(const year_month_day& x, const year_month_day& y) NOEXCEPT
+{
+ return !(y < x);
+}
+
+CONSTCD11
+inline
+bool
+operator>=(const year_month_day& x, const year_month_day& y) NOEXCEPT
+{
+ return !(x < y);
+}
+
+template<class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const year_month_day& ymd)
+{
+ detail::save_ostream<CharT, Traits> _(os);
+ os.fill('0');
+ os.flags(std::ios::dec | std::ios::right);
+ os.imbue(std::locale::classic());
+ os << static_cast<int>(ymd.year()) << '-';
+ os.width(2);
+ os << static_cast<unsigned>(ymd.month()) << '-';
+ os.width(2);
+ os << static_cast<unsigned>(ymd.day());
+ if (!ymd.ok())
+ os << " is not a valid year_month_day";
+ return os;
+}
+
+CONSTCD14
+inline
+year_month_day
+year_month_day::from_days(days dp) NOEXCEPT
+{
+ static_assert(std::numeric_limits<unsigned>::digits >= 18,
+ "This algorithm has not been ported to a 16 bit unsigned integer");
+ static_assert(std::numeric_limits<int>::digits >= 20,
+ "This algorithm has not been ported to a 16 bit signed integer");
+ auto const z = dp.count() + 719468;
+ auto const era = (z >= 0 ? z : z - 146096) / 146097;
+ auto const doe = static_cast<unsigned>(z - era * 146097); // [0, 146096]
+ auto const yoe = (doe - doe/1460 + doe/36524 - doe/146096) / 365; // [0, 399]
+ auto const y = static_cast<days::rep>(yoe) + era * 400;
+ auto const doy = doe - (365*yoe + yoe/4 - yoe/100); // [0, 365]
+ auto const mp = (5*doy + 2)/153; // [0, 11]
+ auto const d = doy - (153*mp+2)/5 + 1; // [1, 31]
+ auto const m = mp < 10 ? mp+3 : mp-9; // [1, 12]
+ return year_month_day{date::year{y + (m <= 2)}, date::month(m), date::day(d)};
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_day
+operator+(const year_month_day& ymd, const months& dm) NOEXCEPT
+{
+ return (ymd.year() / ymd.month() + dm) / ymd.day();
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_day
+operator+(const months& dm, const year_month_day& ymd) NOEXCEPT
+{
+ return ymd + dm;
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_day
+operator-(const year_month_day& ymd, const months& dm) NOEXCEPT
+{
+ return ymd + (-dm);
+}
+
+CONSTCD11
+inline
+year_month_day
+operator+(const year_month_day& ymd, const years& dy) NOEXCEPT
+{
+ return (ymd.year() + dy) / ymd.month() / ymd.day();
+}
+
+CONSTCD11
+inline
+year_month_day
+operator+(const years& dy, const year_month_day& ymd) NOEXCEPT
+{
+ return ymd + dy;
+}
+
+CONSTCD11
+inline
+year_month_day
+operator-(const year_month_day& ymd, const years& dy) NOEXCEPT
+{
+ return ymd + (-dy);
+}
+
+// year_month_weekday
+
+CONSTCD11
+inline
+year_month_weekday::year_month_weekday(const date::year& y, const date::month& m,
+ const date::weekday_indexed& wdi)
+ NOEXCEPT
+ : y_(y)
+ , m_(m)
+ , wdi_(wdi)
+ {}
+
+CONSTCD14
+inline
+year_month_weekday::year_month_weekday(const sys_days& dp) NOEXCEPT
+ : year_month_weekday(from_days(dp.time_since_epoch()))
+ {}
+
+CONSTCD14
+inline
+year_month_weekday::year_month_weekday(const local_days& dp) NOEXCEPT
+ : year_month_weekday(from_days(dp.time_since_epoch()))
+ {}
+
+template<class>
+CONSTCD14
+inline
+year_month_weekday&
+year_month_weekday::operator+=(const months& m) NOEXCEPT
+{
+ *this = *this + m;
+ return *this;
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_weekday&
+year_month_weekday::operator-=(const months& m) NOEXCEPT
+{
+ *this = *this - m;
+ return *this;
+}
+
+CONSTCD14
+inline
+year_month_weekday&
+year_month_weekday::operator+=(const years& y) NOEXCEPT
+{
+ *this = *this + y;
+ return *this;
+}
+
+CONSTCD14
+inline
+year_month_weekday&
+year_month_weekday::operator-=(const years& y) NOEXCEPT
+{
+ *this = *this - y;
+ return *this;
+}
+
+CONSTCD11 inline year year_month_weekday::year() const NOEXCEPT {return y_;}
+CONSTCD11 inline month year_month_weekday::month() const NOEXCEPT {return m_;}
+
+CONSTCD11
+inline
+weekday
+year_month_weekday::weekday() const NOEXCEPT
+{
+ return wdi_.weekday();
+}
+
+CONSTCD11
+inline
+unsigned
+year_month_weekday::index() const NOEXCEPT
+{
+ return wdi_.index();
+}
+
+CONSTCD11
+inline
+weekday_indexed
+year_month_weekday::weekday_indexed() const NOEXCEPT
+{
+ return wdi_;
+}
+
+CONSTCD14
+inline
+year_month_weekday::operator sys_days() const NOEXCEPT
+{
+ return sys_days{to_days()};
+}
+
+CONSTCD14
+inline
+year_month_weekday::operator local_days() const NOEXCEPT
+{
+ return local_days{to_days()};
+}
+
+CONSTCD14
+inline
+bool
+year_month_weekday::ok() const NOEXCEPT
+{
+ if (!y_.ok() || !m_.ok() || !wdi_.weekday().ok() || wdi_.index() < 1)
+ return false;
+ if (wdi_.index() <= 4)
+ return true;
+ auto d2 = wdi_.weekday() - date::weekday(static_cast<sys_days>(y_/m_/1)) +
+ days((wdi_.index()-1)*7 + 1);
+ return static_cast<unsigned>(d2.count()) <= static_cast<unsigned>((y_/m_/last).day());
+}
+
+CONSTCD14
+inline
+year_month_weekday
+year_month_weekday::from_days(days d) NOEXCEPT
+{
+ sys_days dp{d};
+ auto const wd = date::weekday(dp);
+ auto const ymd = year_month_day(dp);
+ return {ymd.year(), ymd.month(), wd[(static_cast<unsigned>(ymd.day())-1)/7+1]};
+}
+
+CONSTCD14
+inline
+days
+year_month_weekday::to_days() const NOEXCEPT
+{
+ auto d = sys_days(y_/m_/1);
+ return (d + (wdi_.weekday() - date::weekday(d) + days{(wdi_.index()-1)*7})
+ ).time_since_epoch();
+}
+
+CONSTCD11
+inline
+bool
+operator==(const year_month_weekday& x, const year_month_weekday& y) NOEXCEPT
+{
+ return x.year() == y.year() && x.month() == y.month() &&
+ x.weekday_indexed() == y.weekday_indexed();
+}
+
+CONSTCD11
+inline
+bool
+operator!=(const year_month_weekday& x, const year_month_weekday& y) NOEXCEPT
+{
+ return !(x == y);
+}
+
+template<class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const year_month_weekday& ymwdi)
+{
+ detail::low_level_fmt(os, ymwdi.year()) << '/';
+ detail::low_level_fmt(os, ymwdi.month()) << '/';
+ detail::low_level_fmt(os, ymwdi.weekday_indexed());
+ if (!ymwdi.ok())
+ os << " is not a valid year_month_weekday";
+ return os;
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_weekday
+operator+(const year_month_weekday& ymwd, const months& dm) NOEXCEPT
+{
+ return (ymwd.year() / ymwd.month() + dm) / ymwd.weekday_indexed();
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_weekday
+operator+(const months& dm, const year_month_weekday& ymwd) NOEXCEPT
+{
+ return ymwd + dm;
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_weekday
+operator-(const year_month_weekday& ymwd, const months& dm) NOEXCEPT
+{
+ return ymwd + (-dm);
+}
+
+CONSTCD11
+inline
+year_month_weekday
+operator+(const year_month_weekday& ymwd, const years& dy) NOEXCEPT
+{
+ return {ymwd.year()+dy, ymwd.month(), ymwd.weekday_indexed()};
+}
+
+CONSTCD11
+inline
+year_month_weekday
+operator+(const years& dy, const year_month_weekday& ymwd) NOEXCEPT
+{
+ return ymwd + dy;
+}
+
+CONSTCD11
+inline
+year_month_weekday
+operator-(const year_month_weekday& ymwd, const years& dy) NOEXCEPT
+{
+ return ymwd + (-dy);
+}
+
+// year_month_weekday_last
+
+CONSTCD11
+inline
+year_month_weekday_last::year_month_weekday_last(const date::year& y,
+ const date::month& m,
+ const date::weekday_last& wdl) NOEXCEPT
+ : y_(y)
+ , m_(m)
+ , wdl_(wdl)
+ {}
+
+template<class>
+CONSTCD14
+inline
+year_month_weekday_last&
+year_month_weekday_last::operator+=(const months& m) NOEXCEPT
+{
+ *this = *this + m;
+ return *this;
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_weekday_last&
+year_month_weekday_last::operator-=(const months& m) NOEXCEPT
+{
+ *this = *this - m;
+ return *this;
+}
+
+CONSTCD14
+inline
+year_month_weekday_last&
+year_month_weekday_last::operator+=(const years& y) NOEXCEPT
+{
+ *this = *this + y;
+ return *this;
+}
+
+CONSTCD14
+inline
+year_month_weekday_last&
+year_month_weekday_last::operator-=(const years& y) NOEXCEPT
+{
+ *this = *this - y;
+ return *this;
+}
+
+CONSTCD11 inline year year_month_weekday_last::year() const NOEXCEPT {return y_;}
+CONSTCD11 inline month year_month_weekday_last::month() const NOEXCEPT {return m_;}
+
+CONSTCD11
+inline
+weekday
+year_month_weekday_last::weekday() const NOEXCEPT
+{
+ return wdl_.weekday();
+}
+
+CONSTCD11
+inline
+weekday_last
+year_month_weekday_last::weekday_last() const NOEXCEPT
+{
+ return wdl_;
+}
+
+CONSTCD14
+inline
+year_month_weekday_last::operator sys_days() const NOEXCEPT
+{
+ return sys_days{to_days()};
+}
+
+CONSTCD14
+inline
+year_month_weekday_last::operator local_days() const NOEXCEPT
+{
+ return local_days{to_days()};
+}
+
+CONSTCD11
+inline
+bool
+year_month_weekday_last::ok() const NOEXCEPT
+{
+ return y_.ok() && m_.ok() && wdl_.ok();
+}
+
+CONSTCD14
+inline
+days
+year_month_weekday_last::to_days() const NOEXCEPT
+{
+ auto const d = sys_days(y_/m_/last);
+ return (d - (date::weekday{d} - wdl_.weekday())).time_since_epoch();
+}
+
+CONSTCD11
+inline
+bool
+operator==(const year_month_weekday_last& x, const year_month_weekday_last& y) NOEXCEPT
+{
+ return x.year() == y.year() && x.month() == y.month() &&
+ x.weekday_last() == y.weekday_last();
+}
+
+CONSTCD11
+inline
+bool
+operator!=(const year_month_weekday_last& x, const year_month_weekday_last& y) NOEXCEPT
+{
+ return !(x == y);
+}
+
+template<class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const year_month_weekday_last& ymwdl)
+{
+ detail::low_level_fmt(os, ymwdl.year()) << '/';
+ detail::low_level_fmt(os, ymwdl.month()) << '/';
+ detail::low_level_fmt(os, ymwdl.weekday_last());
+ if (!ymwdl.ok())
+ os << " is not a valid year_month_weekday_last";
+ return os;
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_weekday_last
+operator+(const year_month_weekday_last& ymwdl, const months& dm) NOEXCEPT
+{
+ return (ymwdl.year() / ymwdl.month() + dm) / ymwdl.weekday_last();
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_weekday_last
+operator+(const months& dm, const year_month_weekday_last& ymwdl) NOEXCEPT
+{
+ return ymwdl + dm;
+}
+
+template<class>
+CONSTCD14
+inline
+year_month_weekday_last
+operator-(const year_month_weekday_last& ymwdl, const months& dm) NOEXCEPT
+{
+ return ymwdl + (-dm);
+}
+
+CONSTCD11
+inline
+year_month_weekday_last
+operator+(const year_month_weekday_last& ymwdl, const years& dy) NOEXCEPT
+{
+ return {ymwdl.year()+dy, ymwdl.month(), ymwdl.weekday_last()};
+}
+
+CONSTCD11
+inline
+year_month_weekday_last
+operator+(const years& dy, const year_month_weekday_last& ymwdl) NOEXCEPT
+{
+ return ymwdl + dy;
+}
+
+CONSTCD11
+inline
+year_month_weekday_last
+operator-(const year_month_weekday_last& ymwdl, const years& dy) NOEXCEPT
+{
+ return ymwdl + (-dy);
+}
+
+// year_month from operator/()
+
+CONSTCD11
+inline
+year_month
+operator/(const year& y, const month& m) NOEXCEPT
+{
+ return {y, m};
+}
+
+CONSTCD11
+inline
+year_month
+operator/(const year& y, int m) NOEXCEPT
+{
+ return y / month(static_cast<unsigned>(m));
+}
+
+// month_day from operator/()
+
+CONSTCD11
+inline
+month_day
+operator/(const month& m, const day& d) NOEXCEPT
+{
+ return {m, d};
+}
+
+CONSTCD11
+inline
+month_day
+operator/(const day& d, const month& m) NOEXCEPT
+{
+ return m / d;
+}
+
+CONSTCD11
+inline
+month_day
+operator/(const month& m, int d) NOEXCEPT
+{
+ return m / day(static_cast<unsigned>(d));
+}
+
+CONSTCD11
+inline
+month_day
+operator/(int m, const day& d) NOEXCEPT
+{
+ return month(static_cast<unsigned>(m)) / d;
+}
+
+CONSTCD11 inline month_day operator/(const day& d, int m) NOEXCEPT {return m / d;}
+
+// month_day_last from operator/()
+
+CONSTCD11
+inline
+month_day_last
+operator/(const month& m, last_spec) NOEXCEPT
+{
+ return month_day_last{m};
+}
+
+CONSTCD11
+inline
+month_day_last
+operator/(last_spec, const month& m) NOEXCEPT
+{
+ return m/last;
+}
+
+CONSTCD11
+inline
+month_day_last
+operator/(int m, last_spec) NOEXCEPT
+{
+ return month(static_cast<unsigned>(m))/last;
+}
+
+CONSTCD11
+inline
+month_day_last
+operator/(last_spec, int m) NOEXCEPT
+{
+ return m/last;
+}
+
+// month_weekday from operator/()
+
+CONSTCD11
+inline
+month_weekday
+operator/(const month& m, const weekday_indexed& wdi) NOEXCEPT
+{
+ return {m, wdi};
+}
+
+CONSTCD11
+inline
+month_weekday
+operator/(const weekday_indexed& wdi, const month& m) NOEXCEPT
+{
+ return m / wdi;
+}
+
+CONSTCD11
+inline
+month_weekday
+operator/(int m, const weekday_indexed& wdi) NOEXCEPT
+{
+ return month(static_cast<unsigned>(m)) / wdi;
+}
+
+CONSTCD11
+inline
+month_weekday
+operator/(const weekday_indexed& wdi, int m) NOEXCEPT
+{
+ return m / wdi;
+}
+
+// month_weekday_last from operator/()
+
+CONSTCD11
+inline
+month_weekday_last
+operator/(const month& m, const weekday_last& wdl) NOEXCEPT
+{
+ return {m, wdl};
+}
+
+CONSTCD11
+inline
+month_weekday_last
+operator/(const weekday_last& wdl, const month& m) NOEXCEPT
+{
+ return m / wdl;
+}
+
+CONSTCD11
+inline
+month_weekday_last
+operator/(int m, const weekday_last& wdl) NOEXCEPT
+{
+ return month(static_cast<unsigned>(m)) / wdl;
+}
+
+CONSTCD11
+inline
+month_weekday_last
+operator/(const weekday_last& wdl, int m) NOEXCEPT
+{
+ return m / wdl;
+}
+
+// year_month_day from operator/()
+
+CONSTCD11
+inline
+year_month_day
+operator/(const year_month& ym, const day& d) NOEXCEPT
+{
+ return {ym.year(), ym.month(), d};
+}
+
+CONSTCD11
+inline
+year_month_day
+operator/(const year_month& ym, int d) NOEXCEPT
+{
+ return ym / day(static_cast<unsigned>(d));
+}
+
+CONSTCD11
+inline
+year_month_day
+operator/(const year& y, const month_day& md) NOEXCEPT
+{
+ return y / md.month() / md.day();
+}
+
+CONSTCD11
+inline
+year_month_day
+operator/(int y, const month_day& md) NOEXCEPT
+{
+ return year(y) / md;
+}
+
+CONSTCD11
+inline
+year_month_day
+operator/(const month_day& md, const year& y) NOEXCEPT
+{
+ return y / md;
+}
+
+CONSTCD11
+inline
+year_month_day
+operator/(const month_day& md, int y) NOEXCEPT
+{
+ return year(y) / md;
+}
+
+// year_month_day_last from operator/()
+
+CONSTCD11
+inline
+year_month_day_last
+operator/(const year_month& ym, last_spec) NOEXCEPT
+{
+ return {ym.year(), month_day_last{ym.month()}};
+}
+
+CONSTCD11
+inline
+year_month_day_last
+operator/(const year& y, const month_day_last& mdl) NOEXCEPT
+{
+ return {y, mdl};
+}
+
+CONSTCD11
+inline
+year_month_day_last
+operator/(int y, const month_day_last& mdl) NOEXCEPT
+{
+ return year(y) / mdl;
+}
+
+CONSTCD11
+inline
+year_month_day_last
+operator/(const month_day_last& mdl, const year& y) NOEXCEPT
+{
+ return y / mdl;
+}
+
+CONSTCD11
+inline
+year_month_day_last
+operator/(const month_day_last& mdl, int y) NOEXCEPT
+{
+ return year(y) / mdl;
+}
+
+// year_month_weekday from operator/()
+
+CONSTCD11
+inline
+year_month_weekday
+operator/(const year_month& ym, const weekday_indexed& wdi) NOEXCEPT
+{
+ return {ym.year(), ym.month(), wdi};
+}
+
+CONSTCD11
+inline
+year_month_weekday
+operator/(const year& y, const month_weekday& mwd) NOEXCEPT
+{
+ return {y, mwd.month(), mwd.weekday_indexed()};
+}
+
+CONSTCD11
+inline
+year_month_weekday
+operator/(int y, const month_weekday& mwd) NOEXCEPT
+{
+ return year(y) / mwd;
+}
+
+CONSTCD11
+inline
+year_month_weekday
+operator/(const month_weekday& mwd, const year& y) NOEXCEPT
+{
+ return y / mwd;
+}
+
+CONSTCD11
+inline
+year_month_weekday
+operator/(const month_weekday& mwd, int y) NOEXCEPT
+{
+ return year(y) / mwd;
+}
+
+// year_month_weekday_last from operator/()
+
+CONSTCD11
+inline
+year_month_weekday_last
+operator/(const year_month& ym, const weekday_last& wdl) NOEXCEPT
+{
+ return {ym.year(), ym.month(), wdl};
+}
+
+CONSTCD11
+inline
+year_month_weekday_last
+operator/(const year& y, const month_weekday_last& mwdl) NOEXCEPT
+{
+ return {y, mwdl.month(), mwdl.weekday_last()};
+}
+
+CONSTCD11
+inline
+year_month_weekday_last
+operator/(int y, const month_weekday_last& mwdl) NOEXCEPT
+{
+ return year(y) / mwdl;
+}
+
+CONSTCD11
+inline
+year_month_weekday_last
+operator/(const month_weekday_last& mwdl, const year& y) NOEXCEPT
+{
+ return y / mwdl;
+}
+
+CONSTCD11
+inline
+year_month_weekday_last
+operator/(const month_weekday_last& mwdl, int y) NOEXCEPT
+{
+ return year(y) / mwdl;
+}
+
+template <class Duration>
+struct fields;
+
+template <class CharT, class Traits, class Duration>
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt,
+ const fields<Duration>& fds, const std::string* abbrev = nullptr,
+ const std::chrono::seconds* offset_sec = nullptr);
+
+template <class CharT, class Traits, class Duration, class Alloc>
+std::basic_istream<CharT, Traits>&
+from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt,
+ fields<Duration>& fds, std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr);
+
+// hh_mm_ss
+
+namespace detail
+{
+
+struct undocumented {explicit undocumented() = default;};
+
+// width<n>::value is the number of fractional decimal digits in 1/n
+// width<0>::value and width<1>::value are defined to be 0
+// If 1/n takes more than 18 fractional decimal digits,
+// the result is truncated to 19.
+// Example: width<2>::value == 1
+// Example: width<3>::value == 19
+// Example: width<4>::value == 2
+// Example: width<10>::value == 1
+// Example: width<1000>::value == 3
+template <std::uint64_t n, std::uint64_t d, unsigned w = 0,
+ bool should_continue = n%d != 0 && (w < 19)>
+struct width
+{
+ static_assert(d > 0, "width called with zero denominator");
+ static CONSTDATA unsigned value = 1 + width<n%d*10, d, w+1>::value;
+};
+
+template <std::uint64_t n, std::uint64_t d, unsigned w>
+struct width<n, d, w, false>
+{
+ static CONSTDATA unsigned value = 0;
+};
+
+template <unsigned exp>
+struct static_pow10
+{
+private:
+ static CONSTDATA std::uint64_t h = static_pow10<exp/2>::value;
+public:
+ static CONSTDATA std::uint64_t value = h * h * (exp % 2 ? 10 : 1);
+};
+
+template <>
+struct static_pow10<0>
+{
+ static CONSTDATA std::uint64_t value = 1;
+};
+
+template <class Duration>
+class decimal_format_seconds
+{
+ using CT = typename std::common_type<Duration, std::chrono::seconds>::type;
+ using rep = typename CT::rep;
+ static unsigned CONSTDATA trial_width =
+ detail::width<CT::period::num, CT::period::den>::value;
+public:
+ static unsigned CONSTDATA width = trial_width < 19 ? trial_width : 6u;
+ using precision = std::chrono::duration<rep,
+ std::ratio<1, static_pow10<width>::value>>;
+
+private:
+ std::chrono::seconds s_;
+ precision sub_s_;
+
+public:
+ CONSTCD11 decimal_format_seconds()
+ : s_()
+ , sub_s_()
+ {}
+
+ CONSTCD11 explicit decimal_format_seconds(const Duration& d) NOEXCEPT
+ : s_(std::chrono::duration_cast<std::chrono::seconds>(d))
+ , sub_s_(std::chrono::duration_cast<precision>(d - s_))
+ {}
+
+ CONSTCD14 std::chrono::seconds& seconds() NOEXCEPT {return s_;}
+ CONSTCD11 std::chrono::seconds seconds() const NOEXCEPT {return s_;}
+ CONSTCD11 precision subseconds() const NOEXCEPT {return sub_s_;}
+
+ CONSTCD14 precision to_duration() const NOEXCEPT
+ {
+ return s_ + sub_s_;
+ }
+
+ CONSTCD11 bool in_conventional_range() const NOEXCEPT
+ {
+ return sub_s_ < std::chrono::seconds{1} && s_ < std::chrono::minutes{1};
+ }
+
+ template <class CharT, class Traits>
+ friend
+ std::basic_ostream<CharT, Traits>&
+ operator<<(std::basic_ostream<CharT, Traits>& os, const decimal_format_seconds& x)
+ {
+ return x.print(os, std::chrono::treat_as_floating_point<rep>{});
+ }
+
+ template <class CharT, class Traits>
+ std::basic_ostream<CharT, Traits>&
+ print(std::basic_ostream<CharT, Traits>& os, std::true_type) const
+ {
+ date::detail::save_ostream<CharT, Traits> _(os);
+ std::chrono::duration<rep> d = s_ + sub_s_;
+ if (d < std::chrono::seconds{10})
+ os << '0';
+ os.precision(width+6);
+ os << std::fixed << d.count();
+ return os;
+ }
+
+ template <class CharT, class Traits>
+ std::basic_ostream<CharT, Traits>&
+ print(std::basic_ostream<CharT, Traits>& os, std::false_type) const
+ {
+ date::detail::save_ostream<CharT, Traits> _(os);
+ os.fill('0');
+ os.flags(std::ios::dec | std::ios::right);
+ os.width(2);
+ os << s_.count();
+ if (width > 0)
+ {
+#if !ONLY_C_LOCALE
+ os << std::use_facet<std::numpunct<CharT>>(os.getloc()).decimal_point();
+#else
+ os << '.';
+#endif
+ date::detail::save_ostream<CharT, Traits> _s(os);
+ os.imbue(std::locale::classic());
+ os.width(width);
+ os << sub_s_.count();
+ }
+ return os;
+ }
+};
+
+template <class Rep, class Period>
+inline
+CONSTCD11
+typename std::enable_if
+ <
+ std::numeric_limits<Rep>::is_signed,
+ std::chrono::duration<Rep, Period>
+ >::type
+abs(std::chrono::duration<Rep, Period> d)
+{
+ return d >= d.zero() ? +d : -d;
+}
+
+template <class Rep, class Period>
+inline
+CONSTCD11
+typename std::enable_if
+ <
+ !std::numeric_limits<Rep>::is_signed,
+ std::chrono::duration<Rep, Period>
+ >::type
+abs(std::chrono::duration<Rep, Period> d)
+{
+ return d;
+}
+
+} // namespace detail
+
+template <class Duration>
+class hh_mm_ss
+{
+ using dfs = detail::decimal_format_seconds<typename std::common_type<Duration,
+ std::chrono::seconds>::type>;
+
+ std::chrono::hours h_;
+ std::chrono::minutes m_;
+ dfs s_;
+ bool neg_;
+
+public:
+ static unsigned CONSTDATA fractional_width = dfs::width;
+ using precision = typename dfs::precision;
+
+ CONSTCD11 hh_mm_ss() NOEXCEPT
+ : hh_mm_ss(Duration::zero())
+ {}
+
+ CONSTCD11 explicit hh_mm_ss(Duration d) NOEXCEPT
+ : h_(std::chrono::duration_cast<std::chrono::hours>(detail::abs(d)))
+ , m_(std::chrono::duration_cast<std::chrono::minutes>(detail::abs(d)) - h_)
+ , s_(detail::abs(d) - h_ - m_)
+ , neg_(d < Duration::zero())
+ {}
+
+ CONSTCD11 std::chrono::hours hours() const NOEXCEPT {return h_;}
+ CONSTCD11 std::chrono::minutes minutes() const NOEXCEPT {return m_;}
+ CONSTCD11 std::chrono::seconds seconds() const NOEXCEPT {return s_.seconds();}
+ CONSTCD14 std::chrono::seconds&
+ seconds(detail::undocumented) NOEXCEPT {return s_.seconds();}
+ CONSTCD11 precision subseconds() const NOEXCEPT {return s_.subseconds();}
+ CONSTCD11 bool is_negative() const NOEXCEPT {return neg_;}
+
+ CONSTCD11 explicit operator precision() const NOEXCEPT {return to_duration();}
+ CONSTCD11 precision to_duration() const NOEXCEPT
+ {return (s_.to_duration() + m_ + h_) * (1-2*neg_);}
+
+ CONSTCD11 bool in_conventional_range() const NOEXCEPT
+ {
+ return !neg_ && h_ < days{1} && m_ < std::chrono::hours{1} &&
+ s_.in_conventional_range();
+ }
+
+private:
+
+ template <class charT, class traits>
+ friend
+ std::basic_ostream<charT, traits>&
+ operator<<(std::basic_ostream<charT, traits>& os, hh_mm_ss const& tod)
+ {
+ if (tod.is_negative())
+ os << '-';
+ if (tod.h_ < std::chrono::hours{10})
+ os << '0';
+ os << tod.h_.count() << ':';
+ if (tod.m_ < std::chrono::minutes{10})
+ os << '0';
+ os << tod.m_.count() << ':' << tod.s_;
+ return os;
+ }
+
+ template <class CharT, class Traits, class Duration2>
+ friend
+ std::basic_ostream<CharT, Traits>&
+ date::to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt,
+ const fields<Duration2>& fds, const std::string* abbrev,
+ const std::chrono::seconds* offset_sec);
+
+ template <class CharT, class Traits, class Duration2, class Alloc>
+ friend
+ std::basic_istream<CharT, Traits>&
+ date::from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt,
+ fields<Duration2>& fds,
+ std::basic_string<CharT, Traits, Alloc>* abbrev, std::chrono::minutes* offset);
+};
+
+inline
+CONSTCD14
+bool
+is_am(std::chrono::hours const& h) NOEXCEPT
+{
+ using std::chrono::hours;
+ return hours{0} <= h && h < hours{12};
+}
+
+inline
+CONSTCD14
+bool
+is_pm(std::chrono::hours const& h) NOEXCEPT
+{
+ using std::chrono::hours;
+ return hours{12} <= h && h < hours{24};
+}
+
+inline
+CONSTCD14
+std::chrono::hours
+make12(std::chrono::hours h) NOEXCEPT
+{
+ using std::chrono::hours;
+ if (h < hours{12})
+ {
+ if (h == hours{0})
+ h = hours{12};
+ }
+ else
+ {
+ if (h != hours{12})
+ h = h - hours{12};
+ }
+ return h;
+}
+
+inline
+CONSTCD14
+std::chrono::hours
+make24(std::chrono::hours h, bool is_pm) NOEXCEPT
+{
+ using std::chrono::hours;
+ if (is_pm)
+ {
+ if (h != hours{12})
+ h = h + hours{12};
+ }
+ else if (h == hours{12})
+ h = hours{0};
+ return h;
+}
+
+template <class Duration>
+using time_of_day = hh_mm_ss<Duration>;
+
+template <class Rep, class Period>
+CONSTCD11
+inline
+hh_mm_ss<std::chrono::duration<Rep, Period>>
+make_time(const std::chrono::duration<Rep, Period>& d)
+{
+ return hh_mm_ss<std::chrono::duration<Rep, Period>>(d);
+}
+
+template <class CharT, class Traits, class Duration>
+inline
+typename std::enable_if
+<
+ std::ratio_less<typename Duration::period, days::period>::value
+ , std::basic_ostream<CharT, Traits>&
+>::type
+operator<<(std::basic_ostream<CharT, Traits>& os, const sys_time<Duration>& tp)
+{
+ auto const dp = date::floor<days>(tp);
+ return os << year_month_day(dp) << ' ' << make_time(tp-dp);
+}
+
+template <class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const sys_days& dp)
+{
+ return os << year_month_day(dp);
+}
+
+template <class CharT, class Traits, class Duration>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const local_time<Duration>& ut)
+{
+ return (os << sys_time<Duration>{ut.time_since_epoch()});
+}
+
+namespace detail
+{
+
+template <class CharT, std::size_t N>
+class string_literal;
+
+template <class CharT1, class CharT2, std::size_t N1, std::size_t N2>
+inline
+CONSTCD14
+string_literal<typename std::conditional<sizeof(CharT2) <= sizeof(CharT1), CharT1, CharT2>::type,
+ N1 + N2 - 1>
+operator+(const string_literal<CharT1, N1>& x, const string_literal<CharT2, N2>& y) NOEXCEPT;
+
+template <class CharT, std::size_t N>
+class string_literal
+{
+ CharT p_[N];
+
+ CONSTCD11 string_literal() NOEXCEPT
+ : p_{}
+ {}
+
+public:
+ using const_iterator = const CharT*;
+
+ string_literal(string_literal const&) = default;
+ string_literal& operator=(string_literal const&) = delete;
+
+ template <std::size_t N1 = 2,
+ class = typename std::enable_if<N1 == N>::type>
+ CONSTCD11 string_literal(CharT c) NOEXCEPT
+ : p_{c}
+ {
+ }
+
+ template <std::size_t N1 = 3,
+ class = typename std::enable_if<N1 == N>::type>
+ CONSTCD11 string_literal(CharT c1, CharT c2) NOEXCEPT
+ : p_{c1, c2}
+ {
+ }
+
+ template <std::size_t N1 = 4,
+ class = typename std::enable_if<N1 == N>::type>
+ CONSTCD11 string_literal(CharT c1, CharT c2, CharT c3) NOEXCEPT
+ : p_{c1, c2, c3}
+ {
+ }
+
+ CONSTCD14 string_literal(const CharT(&a)[N]) NOEXCEPT
+ : p_{}
+ {
+ for (std::size_t i = 0; i < N; ++i)
+ p_[i] = a[i];
+ }
+
+ template <class U = CharT,
+ class = typename std::enable_if<(1 < sizeof(U))>::type>
+ CONSTCD14 string_literal(const char(&a)[N]) NOEXCEPT
+ : p_{}
+ {
+ for (std::size_t i = 0; i < N; ++i)
+ p_[i] = a[i];
+ }
+
+ template <class CharT2,
+ class = typename std::enable_if<!std::is_same<CharT2, CharT>::value>::type>
+ CONSTCD14 string_literal(string_literal<CharT2, N> const& a) NOEXCEPT
+ : p_{}
+ {
+ for (std::size_t i = 0; i < N; ++i)
+ p_[i] = a[i];
+ }
+
+ CONSTCD11 const CharT* data() const NOEXCEPT {return p_;}
+ CONSTCD11 std::size_t size() const NOEXCEPT {return N-1;}
+
+ CONSTCD11 const_iterator begin() const NOEXCEPT {return p_;}
+ CONSTCD11 const_iterator end() const NOEXCEPT {return p_ + N-1;}
+
+ CONSTCD11 CharT const& operator[](std::size_t n) const NOEXCEPT
+ {
+ return p_[n];
+ }
+
+ template <class Traits>
+ friend
+ std::basic_ostream<CharT, Traits>&
+ operator<<(std::basic_ostream<CharT, Traits>& os, const string_literal& s)
+ {
+ return os << s.p_;
+ }
+
+ template <class CharT1, class CharT2, std::size_t N1, std::size_t N2>
+ friend
+ CONSTCD14
+ string_literal<typename std::conditional<sizeof(CharT2) <= sizeof(CharT1), CharT1, CharT2>::type,
+ N1 + N2 - 1>
+ operator+(const string_literal<CharT1, N1>& x, const string_literal<CharT2, N2>& y) NOEXCEPT;
+};
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 3>
+operator+(const string_literal<CharT, 2>& x, const string_literal<CharT, 2>& y) NOEXCEPT
+{
+ return string_literal<CharT, 3>(x[0], y[0]);
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 4>
+operator+(const string_literal<CharT, 3>& x, const string_literal<CharT, 2>& y) NOEXCEPT
+{
+ return string_literal<CharT, 4>(x[0], x[1], y[0]);
+}
+
+template <class CharT1, class CharT2, std::size_t N1, std::size_t N2>
+CONSTCD14
+inline
+string_literal<typename std::conditional<sizeof(CharT2) <= sizeof(CharT1), CharT1, CharT2>::type,
+ N1 + N2 - 1>
+operator+(const string_literal<CharT1, N1>& x, const string_literal<CharT2, N2>& y) NOEXCEPT
+{
+ using CT = typename std::conditional<sizeof(CharT2) <= sizeof(CharT1), CharT1, CharT2>::type;
+
+ string_literal<CT, N1 + N2 - 1> r;
+ std::size_t i = 0;
+ for (; i < N1-1; ++i)
+ r.p_[i] = CT(x.p_[i]);
+ for (std::size_t j = 0; j < N2; ++j, ++i)
+ r.p_[i] = CT(y.p_[j]);
+
+ return r;
+}
+
+
+template <class CharT, class Traits, class Alloc, std::size_t N>
+inline
+std::basic_string<CharT, Traits, Alloc>
+operator+(std::basic_string<CharT, Traits, Alloc> x, const string_literal<CharT, N>& y)
+{
+ x.append(y.data(), y.size());
+ return x;
+}
+
+#if __cplusplus >= 201402 && (!defined(__EDG_VERSION__) || __EDG_VERSION__ > 411) \
+ && (!defined(__SUNPRO_CC) || __SUNPRO_CC > 0x5150)
+
+template <class CharT,
+ class = std::enable_if_t<std::is_same<CharT, char>::value ||
+ std::is_same<CharT, wchar_t>::value ||
+ std::is_same<CharT, char16_t>::value ||
+ std::is_same<CharT, char32_t>::value>>
+CONSTCD14
+inline
+string_literal<CharT, 2>
+msl(CharT c) NOEXCEPT
+{
+ return string_literal<CharT, 2>{c};
+}
+
+CONSTCD14
+inline
+std::size_t
+to_string_len(std::intmax_t i)
+{
+ std::size_t r = 0;
+ do
+ {
+ i /= 10;
+ ++r;
+ } while (i > 0);
+ return r;
+}
+
+template <std::intmax_t N>
+CONSTCD14
+inline
+std::enable_if_t
+<
+ N < 10,
+ string_literal<char, to_string_len(N)+1>
+>
+msl() NOEXCEPT
+{
+ return msl(char(N % 10 + '0'));
+}
+
+template <std::intmax_t N>
+CONSTCD14
+inline
+std::enable_if_t
+<
+ 10 <= N,
+ string_literal<char, to_string_len(N)+1>
+>
+msl() NOEXCEPT
+{
+ return msl<N/10>() + msl(char(N % 10 + '0'));
+}
+
+template <class CharT, std::intmax_t N, std::intmax_t D>
+CONSTCD14
+inline
+std::enable_if_t
+<
+ std::ratio<N, D>::type::den != 1,
+ string_literal<CharT, to_string_len(std::ratio<N, D>::type::num) +
+ to_string_len(std::ratio<N, D>::type::den) + 4>
+>
+msl(std::ratio<N, D>) NOEXCEPT
+{
+ using R = typename std::ratio<N, D>::type;
+ return msl(CharT{'['}) + msl<R::num>() + msl(CharT{'/'}) +
+ msl<R::den>() + msl(CharT{']'});
+}
+
+template <class CharT, std::intmax_t N, std::intmax_t D>
+CONSTCD14
+inline
+std::enable_if_t
+<
+ std::ratio<N, D>::type::den == 1,
+ string_literal<CharT, to_string_len(std::ratio<N, D>::type::num) + 3>
+>
+msl(std::ratio<N, D>) NOEXCEPT
+{
+ using R = typename std::ratio<N, D>::type;
+ return msl(CharT{'['}) + msl<R::num>() + msl(CharT{']'});
+}
+
+
+#else // __cplusplus < 201402 || (defined(__EDG_VERSION__) && __EDG_VERSION__ <= 411)
+
+inline
+std::string
+to_string(std::uint64_t x)
+{
+ return std::to_string(x);
+}
+
+template <class CharT>
+inline
+std::basic_string<CharT>
+to_string(std::uint64_t x)
+{
+ auto y = std::to_string(x);
+ return std::basic_string<CharT>(y.begin(), y.end());
+}
+
+template <class CharT, std::intmax_t N, std::intmax_t D>
+inline
+typename std::enable_if
+<
+ std::ratio<N, D>::type::den != 1,
+ std::basic_string<CharT>
+>::type
+msl(std::ratio<N, D>)
+{
+ using R = typename std::ratio<N, D>::type;
+ return std::basic_string<CharT>(1, '[') + to_string<CharT>(R::num) + CharT{'/'} +
+ to_string<CharT>(R::den) + CharT{']'};
+}
+
+template <class CharT, std::intmax_t N, std::intmax_t D>
+inline
+typename std::enable_if
+<
+ std::ratio<N, D>::type::den == 1,
+ std::basic_string<CharT>
+>::type
+msl(std::ratio<N, D>)
+{
+ using R = typename std::ratio<N, D>::type;
+ return std::basic_string<CharT>(1, '[') + to_string<CharT>(R::num) + CharT{']'};
+}
+
+#endif // __cplusplus < 201402 || (defined(__EDG_VERSION__) && __EDG_VERSION__ <= 411)
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+msl(std::atto) NOEXCEPT
+{
+ return string_literal<CharT, 2>{'a'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+msl(std::femto) NOEXCEPT
+{
+ return string_literal<CharT, 2>{'f'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+msl(std::pico) NOEXCEPT
+{
+ return string_literal<CharT, 2>{'p'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+msl(std::nano) NOEXCEPT
+{
+ return string_literal<CharT, 2>{'n'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+typename std::enable_if
+<
+ std::is_same<CharT, char>::value,
+ string_literal<char, 3>
+>::type
+msl(std::micro) NOEXCEPT
+{
+ return string_literal<char, 3>{'\xC2', '\xB5'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+typename std::enable_if
+<
+ !std::is_same<CharT, char>::value,
+ string_literal<CharT, 2>
+>::type
+msl(std::micro) NOEXCEPT
+{
+ return string_literal<CharT, 2>{CharT{static_cast<unsigned char>('\xB5')}};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+msl(std::milli) NOEXCEPT
+{
+ return string_literal<CharT, 2>{'m'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+msl(std::centi) NOEXCEPT
+{
+ return string_literal<CharT, 2>{'c'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 3>
+msl(std::deca) NOEXCEPT
+{
+ return string_literal<CharT, 3>{'d', 'a'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+msl(std::deci) NOEXCEPT
+{
+ return string_literal<CharT, 2>{'d'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+msl(std::hecto) NOEXCEPT
+{
+ return string_literal<CharT, 2>{'h'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+msl(std::kilo) NOEXCEPT
+{
+ return string_literal<CharT, 2>{'k'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+msl(std::mega) NOEXCEPT
+{
+ return string_literal<CharT, 2>{'M'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+msl(std::giga) NOEXCEPT
+{
+ return string_literal<CharT, 2>{'G'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+msl(std::tera) NOEXCEPT
+{
+ return string_literal<CharT, 2>{'T'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+msl(std::peta) NOEXCEPT
+{
+ return string_literal<CharT, 2>{'P'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+msl(std::exa) NOEXCEPT
+{
+ return string_literal<CharT, 2>{'E'};
+}
+
+template <class CharT, class Period>
+CONSTCD11
+inline
+auto
+get_units(Period p)
+ -> decltype(msl<CharT>(p) + string_literal<CharT, 2>{'s'})
+{
+ return msl<CharT>(p) + string_literal<CharT, 2>{'s'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+get_units(std::ratio<1>)
+{
+ return string_literal<CharT, 2>{'s'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+get_units(std::ratio<3600>)
+{
+ return string_literal<CharT, 2>{'h'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 4>
+get_units(std::ratio<60>)
+{
+ return string_literal<CharT, 4>{'m', 'i', 'n'};
+}
+
+template <class CharT>
+CONSTCD11
+inline
+string_literal<CharT, 2>
+get_units(std::ratio<86400>)
+{
+ return string_literal<CharT, 2>{'d'};
+}
+
+template <class CharT, class Traits = std::char_traits<CharT>>
+struct make_string;
+
+template <>
+struct make_string<char>
+{
+ template <class Rep>
+ static
+ std::string
+ from(Rep n)
+ {
+ return std::to_string(n);
+ }
+};
+
+template <class Traits>
+struct make_string<char, Traits>
+{
+ template <class Rep>
+ static
+ std::basic_string<char, Traits>
+ from(Rep n)
+ {
+ auto s = std::to_string(n);
+ return std::basic_string<char, Traits>(s.begin(), s.end());
+ }
+};
+
+template <>
+struct make_string<wchar_t>
+{
+ template <class Rep>
+ static
+ std::wstring
+ from(Rep n)
+ {
+ return std::to_wstring(n);
+ }
+};
+
+template <class Traits>
+struct make_string<wchar_t, Traits>
+{
+ template <class Rep>
+ static
+ std::basic_string<wchar_t, Traits>
+ from(Rep n)
+ {
+ auto s = std::to_wstring(n);
+ return std::basic_string<wchar_t, Traits>(s.begin(), s.end());
+ }
+};
+
+} // namespace detail
+
+// to_stream
+
+CONSTDATA year nanyear{-32768};
+
+template <class Duration>
+struct fields
+{
+ year_month_day ymd{nanyear/0/0};
+ weekday wd{8u};
+ hh_mm_ss<Duration> tod{};
+ bool has_tod = false;
+
+#if !defined(__clang__) && defined(__GNUC__) && (__GNUC__ * 100 + __GNUC_MINOR__ <= 409)
+ fields() : ymd{nanyear/0/0}, wd{8u}, tod{}, has_tod{false} {}
+#else
+ fields() = default;
+#endif
+
+ fields(year_month_day ymd_) : ymd(ymd_) {}
+ fields(weekday wd_) : wd(wd_) {}
+ fields(hh_mm_ss<Duration> tod_) : tod(tod_), has_tod(true) {}
+
+ fields(year_month_day ymd_, weekday wd_) : ymd(ymd_), wd(wd_) {}
+ fields(year_month_day ymd_, hh_mm_ss<Duration> tod_) : ymd(ymd_), tod(tod_),
+ has_tod(true) {}
+
+ fields(weekday wd_, hh_mm_ss<Duration> tod_) : wd(wd_), tod(tod_), has_tod(true) {}
+
+ fields(year_month_day ymd_, weekday wd_, hh_mm_ss<Duration> tod_)
+ : ymd(ymd_)
+ , wd(wd_)
+ , tod(tod_)
+ , has_tod(true)
+ {}
+};
+
+namespace detail
+{
+
+template <class CharT, class Traits, class Duration>
+unsigned
+extract_weekday(std::basic_ostream<CharT, Traits>& os, const fields<Duration>& fds)
+{
+ if (!fds.ymd.ok() && !fds.wd.ok())
+ {
+ // fds does not contain a valid weekday
+ os.setstate(std::ios::failbit);
+ return 8;
+ }
+ weekday wd;
+ if (fds.ymd.ok())
+ {
+ wd = weekday{sys_days(fds.ymd)};
+ if (fds.wd.ok() && wd != fds.wd)
+ {
+ // fds.ymd and fds.wd are inconsistent
+ os.setstate(std::ios::failbit);
+ return 8;
+ }
+ }
+ else
+ wd = fds.wd;
+ return static_cast<unsigned>((wd - Sunday).count());
+}
+
+template <class CharT, class Traits, class Duration>
+unsigned
+extract_month(std::basic_ostream<CharT, Traits>& os, const fields<Duration>& fds)
+{
+ if (!fds.ymd.month().ok())
+ {
+ // fds does not contain a valid month
+ os.setstate(std::ios::failbit);
+ return 0;
+ }
+ return static_cast<unsigned>(fds.ymd.month());
+}
+
+} // namespace detail
+
+#if ONLY_C_LOCALE
+
+namespace detail
+{
+
+inline
+std::pair<const std::string*, const std::string*>
+weekday_names()
+{
+ static const std::string nm[] =
+ {
+ "Sunday",
+ "Monday",
+ "Tuesday",
+ "Wednesday",
+ "Thursday",
+ "Friday",
+ "Saturday",
+ "Sun",
+ "Mon",
+ "Tue",
+ "Wed",
+ "Thu",
+ "Fri",
+ "Sat"
+ };
+ return std::make_pair(nm, nm+sizeof(nm)/sizeof(nm[0]));
+}
+
+inline
+std::pair<const std::string*, const std::string*>
+month_names()
+{
+ static const std::string nm[] =
+ {
+ "January",
+ "February",
+ "March",
+ "April",
+ "May",
+ "June",
+ "July",
+ "August",
+ "September",
+ "October",
+ "November",
+ "December",
+ "Jan",
+ "Feb",
+ "Mar",
+ "Apr",
+ "May",
+ "Jun",
+ "Jul",
+ "Aug",
+ "Sep",
+ "Oct",
+ "Nov",
+ "Dec"
+ };
+ return std::make_pair(nm, nm+sizeof(nm)/sizeof(nm[0]));
+}
+
+inline
+std::pair<const std::string*, const std::string*>
+ampm_names()
+{
+ static const std::string nm[] =
+ {
+ "AM",
+ "PM"
+ };
+ return std::make_pair(nm, nm+sizeof(nm)/sizeof(nm[0]));
+}
+
+template <class CharT, class Traits, class FwdIter>
+FwdIter
+scan_keyword(std::basic_istream<CharT, Traits>& is, FwdIter kb, FwdIter ke)
+{
+ size_t nkw = static_cast<size_t>(std::distance(kb, ke));
+ const unsigned char doesnt_match = '\0';
+ const unsigned char might_match = '\1';
+ const unsigned char does_match = '\2';
+ unsigned char statbuf[100];
+ unsigned char* status = statbuf;
+ std::unique_ptr<unsigned char, void(*)(void*)> stat_hold(0, free);
+ if (nkw > sizeof(statbuf))
+ {
+ status = (unsigned char*)std::malloc(nkw);
+ if (status == nullptr)
+ throw std::bad_alloc();
+ stat_hold.reset(status);
+ }
+ size_t n_might_match = nkw; // At this point, any keyword might match
+ size_t n_does_match = 0; // but none of them definitely do
+ // Initialize all statuses to might_match, except for "" keywords are does_match
+ unsigned char* st = status;
+ for (auto ky = kb; ky != ke; ++ky, ++st)
+ {
+ if (!ky->empty())
+ *st = might_match;
+ else
+ {
+ *st = does_match;
+ --n_might_match;
+ ++n_does_match;
+ }
+ }
+ // While there might be a match, test keywords against the next CharT
+ for (size_t indx = 0; is && n_might_match > 0; ++indx)
+ {
+ // Peek at the next CharT but don't consume it
+ auto ic = is.peek();
+ if (ic == EOF)
+ {
+ is.setstate(std::ios::eofbit);
+ break;
+ }
+ auto c = static_cast<char>(toupper(static_cast<unsigned char>(ic)));
+ bool consume = false;
+ // For each keyword which might match, see if the indx character is c
+ // If a match if found, consume c
+ // If a match is found, and that is the last character in the keyword,
+ // then that keyword matches.
+ // If the keyword doesn't match this character, then change the keyword
+ // to doesn't match
+ st = status;
+ for (auto ky = kb; ky != ke; ++ky, ++st)
+ {
+ if (*st == might_match)
+ {
+ if (c == static_cast<char>(toupper(static_cast<unsigned char>((*ky)[indx]))))
+ {
+ consume = true;
+ if (ky->size() == indx+1)
+ {
+ *st = does_match;
+ --n_might_match;
+ ++n_does_match;
+ }
+ }
+ else
+ {
+ *st = doesnt_match;
+ --n_might_match;
+ }
+ }
+ }
+ // consume if we matched a character
+ if (consume)
+ {
+ (void)is.get();
+ // If we consumed a character and there might be a matched keyword that
+ // was marked matched on a previous iteration, then such keywords
+ // are now marked as not matching.
+ if (n_might_match + n_does_match > 1)
+ {
+ st = status;
+ for (auto ky = kb; ky != ke; ++ky, ++st)
+ {
+ if (*st == does_match && ky->size() != indx+1)
+ {
+ *st = doesnt_match;
+ --n_does_match;
+ }
+ }
+ }
+ }
+ }
+ // We've exited the loop because we hit eof and/or we have no more "might matches".
+ // Return the first matching result
+ for (st = status; kb != ke; ++kb, ++st)
+ if (*st == does_match)
+ break;
+ if (kb == ke)
+ is.setstate(std::ios::failbit);
+ return kb;
+}
+
+} // namespace detail
+
+#endif // ONLY_C_LOCALE
+
+template <class CharT, class Traits, class Duration>
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt,
+ const fields<Duration>& fds, const std::string* abbrev,
+ const std::chrono::seconds* offset_sec)
+{
+#if ONLY_C_LOCALE
+ using detail::weekday_names;
+ using detail::month_names;
+ using detail::ampm_names;
+#endif
+ using detail::save_ostream;
+ using detail::get_units;
+ using detail::extract_weekday;
+ using detail::extract_month;
+ using std::ios;
+ using std::chrono::duration_cast;
+ using std::chrono::seconds;
+ using std::chrono::minutes;
+ using std::chrono::hours;
+ date::detail::save_ostream<CharT, Traits> ss(os);
+ os.fill(' ');
+ os.flags(std::ios::skipws | std::ios::dec);
+ os.width(0);
+ tm tm{};
+ bool insert_negative = fds.has_tod && fds.tod.to_duration() < Duration::zero();
+#if !ONLY_C_LOCALE
+ auto& facet = std::use_facet<std::time_put<CharT>>(os.getloc());
+#endif
+ const CharT* command = nullptr;
+ CharT modified = CharT{};
+ for (; *fmt; ++fmt)
+ {
+ switch (*fmt)
+ {
+ case 'a':
+ case 'A':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ tm.tm_wday = static_cast<int>(extract_weekday(os, fds));
+ if (os.fail())
+ return os;
+#if !ONLY_C_LOCALE
+ const CharT f[] = {'%', *fmt};
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+#else // ONLY_C_LOCALE
+ os << weekday_names().first[tm.tm_wday+7*(*fmt == 'a')];
+#endif // ONLY_C_LOCALE
+ }
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ modified = CharT{};
+ }
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'b':
+ case 'B':
+ case 'h':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ tm.tm_mon = static_cast<int>(extract_month(os, fds)) - 1;
+#if !ONLY_C_LOCALE
+ const CharT f[] = {'%', *fmt};
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+#else // ONLY_C_LOCALE
+ os << month_names().first[tm.tm_mon+12*(*fmt != 'B')];
+#endif // ONLY_C_LOCALE
+ }
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ modified = CharT{};
+ }
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'c':
+ case 'x':
+ if (command)
+ {
+ if (modified == CharT{'O'})
+ os << CharT{'%'} << modified << *fmt;
+ else
+ {
+ if (!fds.ymd.ok())
+ os.setstate(std::ios::failbit);
+ if (*fmt == 'c' && !fds.has_tod)
+ os.setstate(std::ios::failbit);
+#if !ONLY_C_LOCALE
+ tm = std::tm{};
+ auto const& ymd = fds.ymd;
+ auto ld = local_days(ymd);
+ if (*fmt == 'c')
+ {
+ tm.tm_sec = static_cast<int>(fds.tod.seconds().count());
+ tm.tm_min = static_cast<int>(fds.tod.minutes().count());
+ tm.tm_hour = static_cast<int>(fds.tod.hours().count());
+ }
+ tm.tm_mday = static_cast<int>(static_cast<unsigned>(ymd.day()));
+ tm.tm_mon = static_cast<int>(extract_month(os, fds) - 1);
+ tm.tm_year = static_cast<int>(ymd.year()) - 1900;
+ tm.tm_wday = static_cast<int>(extract_weekday(os, fds));
+ if (os.fail())
+ return os;
+ tm.tm_yday = static_cast<int>((ld - local_days(ymd.year()/1/1)).count());
+ CharT f[3] = {'%'};
+ auto fe = std::begin(f) + 1;
+ if (modified == CharT{'E'})
+ *fe++ = modified;
+ *fe++ = *fmt;
+ facet.put(os, os, os.fill(), &tm, std::begin(f), fe);
+#else // ONLY_C_LOCALE
+ if (*fmt == 'c')
+ {
+ auto wd = static_cast<int>(extract_weekday(os, fds));
+ os << weekday_names().first[static_cast<unsigned>(wd)+7]
+ << ' ';
+ os << month_names().first[extract_month(os, fds)-1+12] << ' ';
+ auto d = static_cast<int>(static_cast<unsigned>(fds.ymd.day()));
+ if (d < 10)
+ os << ' ';
+ os << d << ' '
+ << make_time(duration_cast<seconds>(fds.tod.to_duration()))
+ << ' ' << fds.ymd.year();
+
+ }
+ else // *fmt == 'x'
+ {
+ auto const& ymd = fds.ymd;
+ save_ostream<CharT, Traits> _(os);
+ os.fill('0');
+ os.flags(std::ios::dec | std::ios::right);
+ os.width(2);
+ os << static_cast<unsigned>(ymd.month()) << CharT{'/'};
+ os.width(2);
+ os << static_cast<unsigned>(ymd.day()) << CharT{'/'};
+ os.width(2);
+ os << static_cast<int>(ymd.year()) % 100;
+ }
+#endif // ONLY_C_LOCALE
+ }
+ command = nullptr;
+ modified = CharT{};
+ }
+ else
+ os << *fmt;
+ break;
+ case 'C':
+ if (command)
+ {
+ if (modified == CharT{'O'})
+ os << CharT{'%'} << modified << *fmt;
+ else
+ {
+ if (!fds.ymd.year().ok())
+ os.setstate(std::ios::failbit);
+ auto y = static_cast<int>(fds.ymd.year());
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#endif
+ {
+ save_ostream<CharT, Traits> _(os);
+ os.fill('0');
+ os.flags(std::ios::dec | std::ios::right);
+ if (y >= 0)
+ {
+ os.width(2);
+ os << y/100;
+ }
+ else
+ {
+ os << CharT{'-'};
+ os.width(2);
+ os << -(y-99)/100;
+ }
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'E'})
+ {
+ tm.tm_year = y - 1900;
+ CharT f[3] = {'%', 'E', 'C'};
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+ }
+#endif
+ }
+ command = nullptr;
+ modified = CharT{};
+ }
+ else
+ os << *fmt;
+ break;
+ case 'd':
+ case 'e':
+ if (command)
+ {
+ if (modified == CharT{'E'})
+ os << CharT{'%'} << modified << *fmt;
+ else
+ {
+ if (!fds.ymd.day().ok())
+ os.setstate(std::ios::failbit);
+ auto d = static_cast<int>(static_cast<unsigned>(fds.ymd.day()));
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#endif
+ {
+ save_ostream<CharT, Traits> _(os);
+ if (*fmt == CharT{'d'})
+ os.fill('0');
+ else
+ os.fill(' ');
+ os.flags(std::ios::dec | std::ios::right);
+ os.width(2);
+ os << d;
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ tm.tm_mday = d;
+ CharT f[3] = {'%', 'O', *fmt};
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+ }
+#endif
+ }
+ command = nullptr;
+ modified = CharT{};
+ }
+ else
+ os << *fmt;
+ break;
+ case 'D':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ if (!fds.ymd.ok())
+ os.setstate(std::ios::failbit);
+ auto const& ymd = fds.ymd;
+ save_ostream<CharT, Traits> _(os);
+ os.fill('0');
+ os.flags(std::ios::dec | std::ios::right);
+ os.width(2);
+ os << static_cast<unsigned>(ymd.month()) << CharT{'/'};
+ os.width(2);
+ os << static_cast<unsigned>(ymd.day()) << CharT{'/'};
+ os.width(2);
+ os << static_cast<int>(ymd.year()) % 100;
+ }
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ modified = CharT{};
+ }
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'F':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ if (!fds.ymd.ok())
+ os.setstate(std::ios::failbit);
+ auto const& ymd = fds.ymd;
+ save_ostream<CharT, Traits> _(os);
+ os.imbue(std::locale::classic());
+ os.fill('0');
+ os.flags(std::ios::dec | std::ios::right);
+ os.width(4);
+ os << static_cast<int>(ymd.year()) << CharT{'-'};
+ os.width(2);
+ os << static_cast<unsigned>(ymd.month()) << CharT{'-'};
+ os.width(2);
+ os << static_cast<unsigned>(ymd.day());
+ }
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ modified = CharT{};
+ }
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'g':
+ case 'G':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ if (!fds.ymd.ok())
+ os.setstate(std::ios::failbit);
+ auto ld = local_days(fds.ymd);
+ auto y = year_month_day{ld + days{3}}.year();
+ auto start = local_days((y-years{1})/December/Thursday[last]) +
+ (Monday-Thursday);
+ if (ld < start)
+ --y;
+ if (*fmt == CharT{'G'})
+ os << y;
+ else
+ {
+ save_ostream<CharT, Traits> _(os);
+ os.fill('0');
+ os.flags(std::ios::dec | std::ios::right);
+ os.width(2);
+ os << std::abs(static_cast<int>(y)) % 100;
+ }
+ }
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ modified = CharT{};
+ }
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'H':
+ case 'I':
+ if (command)
+ {
+ if (modified == CharT{'E'})
+ os << CharT{'%'} << modified << *fmt;
+ else
+ {
+ if (!fds.has_tod)
+ os.setstate(std::ios::failbit);
+ if (insert_negative)
+ {
+ os << '-';
+ insert_negative = false;
+ }
+ auto hms = fds.tod;
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#endif
+ {
+ auto h = *fmt == CharT{'I'} ? date::make12(hms.hours()) : hms.hours();
+ if (h < hours{10})
+ os << CharT{'0'};
+ os << h.count();
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ const CharT f[] = {'%', modified, *fmt};
+ tm.tm_hour = static_cast<int>(hms.hours().count());
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+ }
+#endif
+ }
+ modified = CharT{};
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'j':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ if (fds.ymd.ok() || fds.has_tod)
+ {
+ days doy;
+ if (fds.ymd.ok())
+ {
+ auto ld = local_days(fds.ymd);
+ auto y = fds.ymd.year();
+ doy = ld - local_days(y/January/1) + days{1};
+ }
+ else
+ {
+ doy = duration_cast<days>(fds.tod.to_duration());
+ }
+ save_ostream<CharT, Traits> _(os);
+ os.fill('0');
+ os.flags(std::ios::dec | std::ios::right);
+ os.width(3);
+ os << doy.count();
+ }
+ else
+ {
+ os.setstate(std::ios::failbit);
+ }
+ }
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ modified = CharT{};
+ }
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'm':
+ if (command)
+ {
+ if (modified == CharT{'E'})
+ os << CharT{'%'} << modified << *fmt;
+ else
+ {
+ if (!fds.ymd.month().ok())
+ os.setstate(std::ios::failbit);
+ auto m = static_cast<unsigned>(fds.ymd.month());
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#endif
+ {
+ if (m < 10)
+ os << CharT{'0'};
+ os << m;
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ const CharT f[] = {'%', modified, *fmt};
+ tm.tm_mon = static_cast<int>(m-1);
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+ }
+#endif
+ }
+ modified = CharT{};
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'M':
+ if (command)
+ {
+ if (modified == CharT{'E'})
+ os << CharT{'%'} << modified << *fmt;
+ else
+ {
+ if (!fds.has_tod)
+ os.setstate(std::ios::failbit);
+ if (insert_negative)
+ {
+ os << '-';
+ insert_negative = false;
+ }
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#endif
+ {
+ if (fds.tod.minutes() < minutes{10})
+ os << CharT{'0'};
+ os << fds.tod.minutes().count();
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ const CharT f[] = {'%', modified, *fmt};
+ tm.tm_min = static_cast<int>(fds.tod.minutes().count());
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+ }
+#endif
+ }
+ modified = CharT{};
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'n':
+ if (command)
+ {
+ if (modified == CharT{})
+ os << CharT{'\n'};
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ modified = CharT{};
+ }
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'p':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ if (!fds.has_tod)
+ os.setstate(std::ios::failbit);
+#if !ONLY_C_LOCALE
+ const CharT f[] = {'%', *fmt};
+ tm.tm_hour = static_cast<int>(fds.tod.hours().count());
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+#else
+ if (date::is_am(fds.tod.hours()))
+ os << ampm_names().first[0];
+ else
+ os << ampm_names().first[1];
+#endif
+ }
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ }
+ modified = CharT{};
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'Q':
+ case 'q':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ if (!fds.has_tod)
+ os.setstate(std::ios::failbit);
+ auto d = fds.tod.to_duration();
+ if (*fmt == 'q')
+ os << get_units<CharT>(typename decltype(d)::period::type{});
+ else
+ os << d.count();
+ }
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ }
+ modified = CharT{};
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'r':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ if (!fds.has_tod)
+ os.setstate(std::ios::failbit);
+#if !ONLY_C_LOCALE
+ const CharT f[] = {'%', *fmt};
+ tm.tm_hour = static_cast<int>(fds.tod.hours().count());
+ tm.tm_min = static_cast<int>(fds.tod.minutes().count());
+ tm.tm_sec = static_cast<int>(fds.tod.seconds().count());
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+#else
+ hh_mm_ss<seconds> tod(duration_cast<seconds>(fds.tod.to_duration()));
+ save_ostream<CharT, Traits> _(os);
+ os.fill('0');
+ os.width(2);
+ os << date::make12(tod.hours()).count() << CharT{':'};
+ os.width(2);
+ os << tod.minutes().count() << CharT{':'};
+ os.width(2);
+ os << tod.seconds().count() << CharT{' '};
+ if (date::is_am(tod.hours()))
+ os << ampm_names().first[0];
+ else
+ os << ampm_names().first[1];
+#endif
+ }
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ }
+ modified = CharT{};
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'R':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ if (!fds.has_tod)
+ os.setstate(std::ios::failbit);
+ if (fds.tod.hours() < hours{10})
+ os << CharT{'0'};
+ os << fds.tod.hours().count() << CharT{':'};
+ if (fds.tod.minutes() < minutes{10})
+ os << CharT{'0'};
+ os << fds.tod.minutes().count();
+ }
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ modified = CharT{};
+ }
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'S':
+ if (command)
+ {
+ if (modified == CharT{'E'})
+ os << CharT{'%'} << modified << *fmt;
+ else
+ {
+ if (!fds.has_tod)
+ os.setstate(std::ios::failbit);
+ if (insert_negative)
+ {
+ os << '-';
+ insert_negative = false;
+ }
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#endif
+ {
+ os << fds.tod.s_;
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ const CharT f[] = {'%', modified, *fmt};
+ tm.tm_sec = static_cast<int>(fds.tod.s_.seconds().count());
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+ }
+#endif
+ }
+ modified = CharT{};
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 't':
+ if (command)
+ {
+ if (modified == CharT{})
+ os << CharT{'\t'};
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ modified = CharT{};
+ }
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'T':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ if (!fds.has_tod)
+ os.setstate(std::ios::failbit);
+ os << fds.tod;
+ }
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ modified = CharT{};
+ }
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'u':
+ if (command)
+ {
+ if (modified == CharT{'E'})
+ os << CharT{'%'} << modified << *fmt;
+ else
+ {
+ auto wd = extract_weekday(os, fds);
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#endif
+ {
+ os << (wd != 0 ? wd : 7u);
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ const CharT f[] = {'%', modified, *fmt};
+ tm.tm_wday = static_cast<int>(wd);
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+ }
+#endif
+ }
+ modified = CharT{};
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'U':
+ if (command)
+ {
+ if (modified == CharT{'E'})
+ os << CharT{'%'} << modified << *fmt;
+ else
+ {
+ auto const& ymd = fds.ymd;
+ if (!ymd.ok())
+ os.setstate(std::ios::failbit);
+ auto ld = local_days(ymd);
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#endif
+ {
+ auto st = local_days(Sunday[1]/January/ymd.year());
+ if (ld < st)
+ os << CharT{'0'} << CharT{'0'};
+ else
+ {
+ auto wn = duration_cast<weeks>(ld - st).count() + 1;
+ if (wn < 10)
+ os << CharT{'0'};
+ os << wn;
+ }
+ }
+ #if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ const CharT f[] = {'%', modified, *fmt};
+ tm.tm_year = static_cast<int>(ymd.year()) - 1900;
+ tm.tm_wday = static_cast<int>(extract_weekday(os, fds));
+ if (os.fail())
+ return os;
+ tm.tm_yday = static_cast<int>((ld - local_days(ymd.year()/1/1)).count());
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+ }
+#endif
+ }
+ modified = CharT{};
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'V':
+ if (command)
+ {
+ if (modified == CharT{'E'})
+ os << CharT{'%'} << modified << *fmt;
+ else
+ {
+ if (!fds.ymd.ok())
+ os.setstate(std::ios::failbit);
+ auto ld = local_days(fds.ymd);
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#endif
+ {
+ auto y = year_month_day{ld + days{3}}.year();
+ auto st = local_days((y-years{1})/12/Thursday[last]) +
+ (Monday-Thursday);
+ if (ld < st)
+ {
+ --y;
+ st = local_days((y - years{1})/12/Thursday[last]) +
+ (Monday-Thursday);
+ }
+ auto wn = duration_cast<weeks>(ld - st).count() + 1;
+ if (wn < 10)
+ os << CharT{'0'};
+ os << wn;
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ const CharT f[] = {'%', modified, *fmt};
+ auto const& ymd = fds.ymd;
+ tm.tm_year = static_cast<int>(ymd.year()) - 1900;
+ tm.tm_wday = static_cast<int>(extract_weekday(os, fds));
+ if (os.fail())
+ return os;
+ tm.tm_yday = static_cast<int>((ld - local_days(ymd.year()/1/1)).count());
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+ }
+#endif
+ }
+ modified = CharT{};
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'w':
+ if (command)
+ {
+ auto wd = extract_weekday(os, fds);
+ if (os.fail())
+ return os;
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#else
+ if (modified != CharT{'E'})
+#endif
+ {
+ os << wd;
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ const CharT f[] = {'%', modified, *fmt};
+ tm.tm_wday = static_cast<int>(wd);
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+ }
+#endif
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ }
+ modified = CharT{};
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'W':
+ if (command)
+ {
+ if (modified == CharT{'E'})
+ os << CharT{'%'} << modified << *fmt;
+ else
+ {
+ auto const& ymd = fds.ymd;
+ if (!ymd.ok())
+ os.setstate(std::ios::failbit);
+ auto ld = local_days(ymd);
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#endif
+ {
+ auto st = local_days(Monday[1]/January/ymd.year());
+ if (ld < st)
+ os << CharT{'0'} << CharT{'0'};
+ else
+ {
+ auto wn = duration_cast<weeks>(ld - st).count() + 1;
+ if (wn < 10)
+ os << CharT{'0'};
+ os << wn;
+ }
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ const CharT f[] = {'%', modified, *fmt};
+ tm.tm_year = static_cast<int>(ymd.year()) - 1900;
+ tm.tm_wday = static_cast<int>(extract_weekday(os, fds));
+ if (os.fail())
+ return os;
+ tm.tm_yday = static_cast<int>((ld - local_days(ymd.year()/1/1)).count());
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+ }
+#endif
+ }
+ modified = CharT{};
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'X':
+ if (command)
+ {
+ if (modified == CharT{'O'})
+ os << CharT{'%'} << modified << *fmt;
+ else
+ {
+ if (!fds.has_tod)
+ os.setstate(std::ios::failbit);
+#if !ONLY_C_LOCALE
+ tm = std::tm{};
+ tm.tm_sec = static_cast<int>(fds.tod.seconds().count());
+ tm.tm_min = static_cast<int>(fds.tod.minutes().count());
+ tm.tm_hour = static_cast<int>(fds.tod.hours().count());
+ CharT f[3] = {'%'};
+ auto fe = std::begin(f) + 1;
+ if (modified == CharT{'E'})
+ *fe++ = modified;
+ *fe++ = *fmt;
+ facet.put(os, os, os.fill(), &tm, std::begin(f), fe);
+#else
+ os << fds.tod;
+#endif
+ }
+ command = nullptr;
+ modified = CharT{};
+ }
+ else
+ os << *fmt;
+ break;
+ case 'y':
+ if (command)
+ {
+ if (!fds.ymd.year().ok())
+ os.setstate(std::ios::failbit);
+ auto y = static_cast<int>(fds.ymd.year());
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+ {
+#endif
+ y = std::abs(y) % 100;
+ if (y < 10)
+ os << CharT{'0'};
+ os << y;
+#if !ONLY_C_LOCALE
+ }
+ else
+ {
+ const CharT f[] = {'%', modified, *fmt};
+ tm.tm_year = y - 1900;
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+ }
+#endif
+ modified = CharT{};
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'Y':
+ if (command)
+ {
+ if (modified == CharT{'O'})
+ os << CharT{'%'} << modified << *fmt;
+ else
+ {
+ if (!fds.ymd.year().ok())
+ os.setstate(std::ios::failbit);
+ auto y = fds.ymd.year();
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#endif
+ {
+ save_ostream<CharT, Traits> _(os);
+ os.imbue(std::locale::classic());
+ os << y;
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'E'})
+ {
+ const CharT f[] = {'%', modified, *fmt};
+ tm.tm_year = static_cast<int>(y) - 1900;
+ facet.put(os, os, os.fill(), &tm, std::begin(f), std::end(f));
+ }
+#endif
+ }
+ modified = CharT{};
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'z':
+ if (command)
+ {
+ if (offset_sec == nullptr)
+ {
+ // Can not format %z with unknown offset
+ os.setstate(ios::failbit);
+ return os;
+ }
+ auto m = duration_cast<minutes>(*offset_sec);
+ auto neg = m < minutes{0};
+ m = date::abs(m);
+ auto h = duration_cast<hours>(m);
+ m -= h;
+ if (neg)
+ os << CharT{'-'};
+ else
+ os << CharT{'+'};
+ if (h < hours{10})
+ os << CharT{'0'};
+ os << h.count();
+ if (modified != CharT{})
+ os << CharT{':'};
+ if (m < minutes{10})
+ os << CharT{'0'};
+ os << m.count();
+ command = nullptr;
+ modified = CharT{};
+ }
+ else
+ os << *fmt;
+ break;
+ case 'Z':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ if (abbrev == nullptr)
+ {
+ // Can not format %Z with unknown time_zone
+ os.setstate(ios::failbit);
+ return os;
+ }
+ for (auto c : *abbrev)
+ os << CharT(c);
+ }
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ modified = CharT{};
+ }
+ command = nullptr;
+ }
+ else
+ os << *fmt;
+ break;
+ case 'E':
+ case 'O':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ modified = *fmt;
+ }
+ else
+ {
+ os << CharT{'%'} << modified << *fmt;
+ command = nullptr;
+ modified = CharT{};
+ }
+ }
+ else
+ os << *fmt;
+ break;
+ case '%':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ os << CharT{'%'};
+ command = nullptr;
+ }
+ else
+ {
+ os << CharT{'%'} << modified << CharT{'%'};
+ command = nullptr;
+ modified = CharT{};
+ }
+ }
+ else
+ command = fmt;
+ break;
+ default:
+ if (command)
+ {
+ os << CharT{'%'};
+ command = nullptr;
+ }
+ if (modified != CharT{})
+ {
+ os << modified;
+ modified = CharT{};
+ }
+ os << *fmt;
+ break;
+ }
+ }
+ if (command)
+ os << CharT{'%'};
+ if (modified != CharT{})
+ os << modified;
+ return os;
+}
+
+template <class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt, const year& y)
+{
+ using CT = std::chrono::seconds;
+ fields<CT> fds{y/0/0};
+ return to_stream(os, fmt, fds);
+}
+
+template <class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt, const month& m)
+{
+ using CT = std::chrono::seconds;
+ fields<CT> fds{m/0/nanyear};
+ return to_stream(os, fmt, fds);
+}
+
+template <class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt, const day& d)
+{
+ using CT = std::chrono::seconds;
+ fields<CT> fds{d/0/nanyear};
+ return to_stream(os, fmt, fds);
+}
+
+template <class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt, const weekday& wd)
+{
+ using CT = std::chrono::seconds;
+ fields<CT> fds{wd};
+ return to_stream(os, fmt, fds);
+}
+
+template <class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt, const year_month& ym)
+{
+ using CT = std::chrono::seconds;
+ fields<CT> fds{ym/0};
+ return to_stream(os, fmt, fds);
+}
+
+template <class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt, const month_day& md)
+{
+ using CT = std::chrono::seconds;
+ fields<CT> fds{md/nanyear};
+ return to_stream(os, fmt, fds);
+}
+
+template <class CharT, class Traits>
+inline
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt,
+ const year_month_day& ymd)
+{
+ using CT = std::chrono::seconds;
+ fields<CT> fds{ymd};
+ return to_stream(os, fmt, fds);
+}
+
+template <class CharT, class Traits, class Rep, class Period>
+inline
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt,
+ const std::chrono::duration<Rep, Period>& d)
+{
+ using Duration = std::chrono::duration<Rep, Period>;
+ using CT = typename std::common_type<Duration, std::chrono::seconds>::type;
+ fields<CT> fds{hh_mm_ss<CT>{d}};
+ return to_stream(os, fmt, fds);
+}
+
+template <class CharT, class Traits, class Duration>
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt,
+ const local_time<Duration>& tp, const std::string* abbrev = nullptr,
+ const std::chrono::seconds* offset_sec = nullptr)
+{
+ using CT = typename std::common_type<Duration, std::chrono::seconds>::type;
+ auto ld = std::chrono::time_point_cast<days>(tp);
+ fields<CT> fds;
+ if (ld <= tp)
+ fds = fields<CT>{year_month_day{ld}, hh_mm_ss<CT>{tp-local_seconds{ld}}};
+ else
+ fds = fields<CT>{year_month_day{ld - days{1}},
+ hh_mm_ss<CT>{days{1} - (local_seconds{ld} - tp)}};
+ return to_stream(os, fmt, fds, abbrev, offset_sec);
+}
+
+template <class CharT, class Traits, class Duration>
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt,
+ const sys_time<Duration>& tp)
+{
+ using std::chrono::seconds;
+ using CT = typename std::common_type<Duration, seconds>::type;
+ const std::string abbrev("UTC");
+ CONSTDATA seconds offset{0};
+ auto sd = std::chrono::time_point_cast<days>(tp);
+ fields<CT> fds;
+ if (sd <= tp)
+ fds = fields<CT>{year_month_day{sd}, hh_mm_ss<CT>{tp-sys_seconds{sd}}};
+ else
+ fds = fields<CT>{year_month_day{sd - days{1}},
+ hh_mm_ss<CT>{days{1} - (sys_seconds{sd} - tp)}};
+ return to_stream(os, fmt, fds, &abbrev, &offset);
+}
+
+// format
+
+template <class CharT, class Streamable>
+auto
+format(const std::locale& loc, const CharT* fmt, const Streamable& tp)
+ -> decltype(to_stream(std::declval<std::basic_ostream<CharT>&>(), fmt, tp),
+ std::basic_string<CharT>{})
+{
+ std::basic_ostringstream<CharT> os;
+ os.exceptions(std::ios::failbit | std::ios::badbit);
+ os.imbue(loc);
+ to_stream(os, fmt, tp);
+ return os.str();
+}
+
+template <class CharT, class Streamable>
+auto
+format(const CharT* fmt, const Streamable& tp)
+ -> decltype(to_stream(std::declval<std::basic_ostream<CharT>&>(), fmt, tp),
+ std::basic_string<CharT>{})
+{
+ std::basic_ostringstream<CharT> os;
+ os.exceptions(std::ios::failbit | std::ios::badbit);
+ to_stream(os, fmt, tp);
+ return os.str();
+}
+
+template <class CharT, class Traits, class Alloc, class Streamable>
+auto
+format(const std::locale& loc, const std::basic_string<CharT, Traits, Alloc>& fmt,
+ const Streamable& tp)
+ -> decltype(to_stream(std::declval<std::basic_ostream<CharT, Traits>&>(), fmt.c_str(), tp),
+ std::basic_string<CharT, Traits, Alloc>{})
+{
+ std::basic_ostringstream<CharT, Traits, Alloc> os;
+ os.exceptions(std::ios::failbit | std::ios::badbit);
+ os.imbue(loc);
+ to_stream(os, fmt.c_str(), tp);
+ return os.str();
+}
+
+template <class CharT, class Traits, class Alloc, class Streamable>
+auto
+format(const std::basic_string<CharT, Traits, Alloc>& fmt, const Streamable& tp)
+ -> decltype(to_stream(std::declval<std::basic_ostream<CharT, Traits>&>(), fmt.c_str(), tp),
+ std::basic_string<CharT, Traits, Alloc>{})
+{
+ std::basic_ostringstream<CharT, Traits, Alloc> os;
+ os.exceptions(std::ios::failbit | std::ios::badbit);
+ to_stream(os, fmt.c_str(), tp);
+ return os.str();
+}
+
+// parse
+
+namespace detail
+{
+
+template <class CharT, class Traits>
+bool
+read_char(std::basic_istream<CharT, Traits>& is, CharT fmt, std::ios::iostate& err)
+{
+ auto ic = is.get();
+ if (Traits::eq_int_type(ic, Traits::eof()) ||
+ !Traits::eq(Traits::to_char_type(ic), fmt))
+ {
+ err |= std::ios::failbit;
+ is.setstate(std::ios::failbit);
+ return false;
+ }
+ return true;
+}
+
+template <class CharT, class Traits>
+unsigned
+read_unsigned(std::basic_istream<CharT, Traits>& is, unsigned m = 1, unsigned M = 10)
+{
+ unsigned x = 0;
+ unsigned count = 0;
+ while (true)
+ {
+ auto ic = is.peek();
+ if (Traits::eq_int_type(ic, Traits::eof()))
+ break;
+ auto c = static_cast<char>(Traits::to_char_type(ic));
+ if (!('0' <= c && c <= '9'))
+ break;
+ (void)is.get();
+ ++count;
+ x = 10*x + static_cast<unsigned>(c - '0');
+ if (count == M)
+ break;
+ }
+ if (count < m)
+ is.setstate(std::ios::failbit);
+ return x;
+}
+
+template <class CharT, class Traits>
+int
+read_signed(std::basic_istream<CharT, Traits>& is, unsigned m = 1, unsigned M = 10)
+{
+ auto ic = is.peek();
+ if (!Traits::eq_int_type(ic, Traits::eof()))
+ {
+ auto c = static_cast<char>(Traits::to_char_type(ic));
+ if (('0' <= c && c <= '9') || c == '-' || c == '+')
+ {
+ if (c == '-' || c == '+')
+ (void)is.get();
+ auto x = static_cast<int>(read_unsigned(is, std::max(m, 1u), M));
+ if (!is.fail())
+ {
+ if (c == '-')
+ x = -x;
+ return x;
+ }
+ }
+ }
+ if (m > 0)
+ is.setstate(std::ios::failbit);
+ return 0;
+}
+
+template <class CharT, class Traits>
+long double
+read_long_double(std::basic_istream<CharT, Traits>& is, unsigned m = 1, unsigned M = 10)
+{
+ unsigned count = 0;
+ unsigned fcount = 0;
+ unsigned long long i = 0;
+ unsigned long long f = 0;
+ bool parsing_fraction = false;
+#if ONLY_C_LOCALE
+ typename Traits::int_type decimal_point = '.';
+#else
+ auto decimal_point = Traits::to_int_type(
+ std::use_facet<std::numpunct<CharT>>(is.getloc()).decimal_point());
+#endif
+ while (true)
+ {
+ auto ic = is.peek();
+ if (Traits::eq_int_type(ic, Traits::eof()))
+ break;
+ if (Traits::eq_int_type(ic, decimal_point))
+ {
+ decimal_point = Traits::eof();
+ parsing_fraction = true;
+ }
+ else
+ {
+ auto c = static_cast<char>(Traits::to_char_type(ic));
+ if (!('0' <= c && c <= '9'))
+ break;
+ if (!parsing_fraction)
+ {
+ i = 10*i + static_cast<unsigned>(c - '0');
+ }
+ else
+ {
+ f = 10*f + static_cast<unsigned>(c - '0');
+ ++fcount;
+ }
+ }
+ (void)is.get();
+ if (++count == M)
+ break;
+ }
+ if (count < m)
+ {
+ is.setstate(std::ios::failbit);
+ return 0;
+ }
+ return static_cast<long double>(i) + static_cast<long double>(f)/std::pow(10.L, fcount);
+}
+
+struct rs
+{
+ int& i;
+ unsigned m;
+ unsigned M;
+};
+
+struct ru
+{
+ int& i;
+ unsigned m;
+ unsigned M;
+};
+
+struct rld
+{
+ long double& i;
+ unsigned m;
+ unsigned M;
+};
+
+template <class CharT, class Traits>
+void
+read(std::basic_istream<CharT, Traits>&)
+{
+}
+
+template <class CharT, class Traits, class ...Args>
+void
+read(std::basic_istream<CharT, Traits>& is, CharT a0, Args&& ...args);
+
+template <class CharT, class Traits, class ...Args>
+void
+read(std::basic_istream<CharT, Traits>& is, rs a0, Args&& ...args);
+
+template <class CharT, class Traits, class ...Args>
+void
+read(std::basic_istream<CharT, Traits>& is, ru a0, Args&& ...args);
+
+template <class CharT, class Traits, class ...Args>
+void
+read(std::basic_istream<CharT, Traits>& is, int a0, Args&& ...args);
+
+template <class CharT, class Traits, class ...Args>
+void
+read(std::basic_istream<CharT, Traits>& is, rld a0, Args&& ...args);
+
+template <class CharT, class Traits, class ...Args>
+void
+read(std::basic_istream<CharT, Traits>& is, CharT a0, Args&& ...args)
+{
+ // No-op if a0 == CharT{}
+ if (a0 != CharT{})
+ {
+ auto ic = is.peek();
+ if (Traits::eq_int_type(ic, Traits::eof()))
+ {
+ is.setstate(std::ios::failbit | std::ios::eofbit);
+ return;
+ }
+ if (!Traits::eq(Traits::to_char_type(ic), a0))
+ {
+ is.setstate(std::ios::failbit);
+ return;
+ }
+ (void)is.get();
+ }
+ read(is, std::forward<Args>(args)...);
+}
+
+template <class CharT, class Traits, class ...Args>
+void
+read(std::basic_istream<CharT, Traits>& is, rs a0, Args&& ...args)
+{
+ auto x = read_signed(is, a0.m, a0.M);
+ if (is.fail())
+ return;
+ a0.i = x;
+ read(is, std::forward<Args>(args)...);
+}
+
+template <class CharT, class Traits, class ...Args>
+void
+read(std::basic_istream<CharT, Traits>& is, ru a0, Args&& ...args)
+{
+ auto x = read_unsigned(is, a0.m, a0.M);
+ if (is.fail())
+ return;
+ a0.i = static_cast<int>(x);
+ read(is, std::forward<Args>(args)...);
+}
+
+template <class CharT, class Traits, class ...Args>
+void
+read(std::basic_istream<CharT, Traits>& is, int a0, Args&& ...args)
+{
+ if (a0 != -1)
+ {
+ auto u = static_cast<unsigned>(a0);
+ CharT buf[std::numeric_limits<unsigned>::digits10+2u] = {};
+ auto e = buf;
+ do
+ {
+ *e++ = static_cast<CharT>(CharT(u % 10) + CharT{'0'});
+ u /= 10;
+ } while (u > 0);
+ std::reverse(buf, e);
+ for (auto p = buf; p != e && is.rdstate() == std::ios::goodbit; ++p)
+ read(is, *p);
+ }
+ if (is.rdstate() == std::ios::goodbit)
+ read(is, std::forward<Args>(args)...);
+}
+
+template <class CharT, class Traits, class ...Args>
+void
+read(std::basic_istream<CharT, Traits>& is, rld a0, Args&& ...args)
+{
+ auto x = read_long_double(is, a0.m, a0.M);
+ if (is.fail())
+ return;
+ a0.i = x;
+ read(is, std::forward<Args>(args)...);
+}
+
+template <class T, class CharT, class Traits>
+inline
+void
+checked_set(T& value, T from, T not_a_value, std::basic_ios<CharT, Traits>& is)
+{
+ if (!is.fail())
+ {
+ if (value == not_a_value)
+ value = std::move(from);
+ else if (value != from)
+ is.setstate(std::ios::failbit);
+ }
+}
+
+} // namespace detail;
+
+template <class CharT, class Traits, class Duration, class Alloc = std::allocator<CharT>>
+std::basic_istream<CharT, Traits>&
+from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt,
+ fields<Duration>& fds, std::basic_string<CharT, Traits, Alloc>* abbrev,
+ std::chrono::minutes* offset)
+{
+ using std::numeric_limits;
+ using std::ios;
+ using std::chrono::duration;
+ using std::chrono::duration_cast;
+ using std::chrono::seconds;
+ using std::chrono::minutes;
+ using std::chrono::hours;
+ using detail::round_i;
+ typename std::basic_istream<CharT, Traits>::sentry ok{is, true};
+ if (ok)
+ {
+ date::detail::save_istream<CharT, Traits> ss(is);
+ is.fill(' ');
+ is.flags(std::ios::skipws | std::ios::dec);
+ is.width(0);
+#if !ONLY_C_LOCALE
+ auto& f = std::use_facet<std::time_get<CharT>>(is.getloc());
+ std::tm tm{};
+#endif
+ const CharT* command = nullptr;
+ auto modified = CharT{};
+ auto width = -1;
+
+ CONSTDATA int not_a_year = numeric_limits<short>::min();
+ CONSTDATA int not_a_2digit_year = 100;
+ CONSTDATA int not_a_century = not_a_year / 100;
+ CONSTDATA int not_a_month = 0;
+ CONSTDATA int not_a_day = 0;
+ CONSTDATA int not_a_hour = numeric_limits<int>::min();
+ CONSTDATA int not_a_hour_12_value = 0;
+ CONSTDATA int not_a_minute = not_a_hour;
+ CONSTDATA Duration not_a_second = Duration::min();
+ CONSTDATA int not_a_doy = -1;
+ CONSTDATA int not_a_weekday = 8;
+ CONSTDATA int not_a_week_num = 100;
+ CONSTDATA int not_a_ampm = -1;
+ CONSTDATA minutes not_a_offset = minutes::min();
+
+ int Y = not_a_year; // c, F, Y *
+ int y = not_a_2digit_year; // D, x, y *
+ int g = not_a_2digit_year; // g *
+ int G = not_a_year; // G *
+ int C = not_a_century; // C *
+ int m = not_a_month; // b, B, h, m, c, D, F, x *
+ int d = not_a_day; // c, d, D, e, F, x *
+ int j = not_a_doy; // j *
+ int wd = not_a_weekday; // a, A, u, w *
+ int H = not_a_hour; // c, H, R, T, X *
+ int I = not_a_hour_12_value; // I, r *
+ int p = not_a_ampm; // p, r *
+ int M = not_a_minute; // c, M, r, R, T, X *
+ Duration s = not_a_second; // c, r, S, T, X *
+ int U = not_a_week_num; // U *
+ int V = not_a_week_num; // V *
+ int W = not_a_week_num; // W *
+ std::basic_string<CharT, Traits, Alloc> temp_abbrev; // Z *
+ minutes temp_offset = not_a_offset; // z *
+
+ using detail::read;
+ using detail::rs;
+ using detail::ru;
+ using detail::rld;
+ using detail::checked_set;
+ for (; *fmt != CharT{} && !is.fail(); ++fmt)
+ {
+ switch (*fmt)
+ {
+ case 'a':
+ case 'A':
+ case 'u':
+ case 'w': // wd: a, A, u, w
+ if (command)
+ {
+ int trial_wd = not_a_weekday;
+ if (*fmt == 'a' || *fmt == 'A')
+ {
+ if (modified == CharT{})
+ {
+#if !ONLY_C_LOCALE
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ is.setstate(err);
+ if (!is.fail())
+ trial_wd = tm.tm_wday;
+#else
+ auto nm = detail::weekday_names();
+ auto i = detail::scan_keyword(is, nm.first, nm.second) - nm.first;
+ if (!is.fail())
+ trial_wd = i % 7;
+#endif
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ }
+ else // *fmt == 'u' || *fmt == 'w'
+ {
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#else
+ if (modified != CharT{'E'})
+#endif
+ {
+ read(is, ru{trial_wd, 1, width == -1 ?
+ 1u : static_cast<unsigned>(width)});
+ if (!is.fail())
+ {
+ if (*fmt == 'u')
+ {
+ if (!(1 <= trial_wd && trial_wd <= 7))
+ {
+ trial_wd = not_a_weekday;
+ is.setstate(ios::failbit);
+ }
+ else if (trial_wd == 7)
+ trial_wd = 0;
+ }
+ else // *fmt == 'w'
+ {
+ if (!(0 <= trial_wd && trial_wd <= 6))
+ {
+ trial_wd = not_a_weekday;
+ is.setstate(ios::failbit);
+ }
+ }
+ }
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ is.setstate(err);
+ if (!is.fail())
+ trial_wd = tm.tm_wday;
+ }
+#endif
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ }
+ if (trial_wd != not_a_weekday)
+ checked_set(wd, trial_wd, not_a_weekday, is);
+ }
+ else // !command
+ read(is, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ break;
+ case 'b':
+ case 'B':
+ case 'h':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ int ttm = not_a_month;
+#if !ONLY_C_LOCALE
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ if ((err & ios::failbit) == 0)
+ ttm = tm.tm_mon + 1;
+ is.setstate(err);
+#else
+ auto nm = detail::month_names();
+ auto i = detail::scan_keyword(is, nm.first, nm.second) - nm.first;
+ if (!is.fail())
+ ttm = i % 12 + 1;
+#endif
+ checked_set(m, ttm, not_a_month, is);
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'c':
+ if (command)
+ {
+ if (modified != CharT{'O'})
+ {
+#if !ONLY_C_LOCALE
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ if ((err & ios::failbit) == 0)
+ {
+ checked_set(Y, tm.tm_year + 1900, not_a_year, is);
+ checked_set(m, tm.tm_mon + 1, not_a_month, is);
+ checked_set(d, tm.tm_mday, not_a_day, is);
+ checked_set(H, tm.tm_hour, not_a_hour, is);
+ checked_set(M, tm.tm_min, not_a_minute, is);
+ checked_set(s, duration_cast<Duration>(seconds{tm.tm_sec}),
+ not_a_second, is);
+ }
+ is.setstate(err);
+#else
+ // "%a %b %e %T %Y"
+ auto nm = detail::weekday_names();
+ auto i = detail::scan_keyword(is, nm.first, nm.second) - nm.first;
+ checked_set(wd, static_cast<int>(i % 7), not_a_weekday, is);
+ ws(is);
+ nm = detail::month_names();
+ i = detail::scan_keyword(is, nm.first, nm.second) - nm.first;
+ checked_set(m, static_cast<int>(i % 12 + 1), not_a_month, is);
+ ws(is);
+ int td = not_a_day;
+ read(is, rs{td, 1, 2});
+ checked_set(d, td, not_a_day, is);
+ ws(is);
+ using dfs = detail::decimal_format_seconds<Duration>;
+ CONSTDATA auto w = Duration::period::den == 1 ? 2 : 3 + dfs::width;
+ int tH;
+ int tM;
+ long double S{};
+ read(is, ru{tH, 1, 2}, CharT{':'}, ru{tM, 1, 2},
+ CharT{':'}, rld{S, 1, w});
+ checked_set(H, tH, not_a_hour, is);
+ checked_set(M, tM, not_a_minute, is);
+ checked_set(s, round_i<Duration>(duration<long double>{S}),
+ not_a_second, is);
+ ws(is);
+ int tY = not_a_year;
+ read(is, rs{tY, 1, 4u});
+ checked_set(Y, tY, not_a_year, is);
+#endif
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'x':
+ if (command)
+ {
+ if (modified != CharT{'O'})
+ {
+#if !ONLY_C_LOCALE
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ if ((err & ios::failbit) == 0)
+ {
+ checked_set(Y, tm.tm_year + 1900, not_a_year, is);
+ checked_set(m, tm.tm_mon + 1, not_a_month, is);
+ checked_set(d, tm.tm_mday, not_a_day, is);
+ }
+ is.setstate(err);
+#else
+ // "%m/%d/%y"
+ int ty = not_a_2digit_year;
+ int tm = not_a_month;
+ int td = not_a_day;
+ read(is, ru{tm, 1, 2}, CharT{'/'}, ru{td, 1, 2}, CharT{'/'},
+ rs{ty, 1, 2});
+ checked_set(y, ty, not_a_2digit_year, is);
+ checked_set(m, tm, not_a_month, is);
+ checked_set(d, td, not_a_day, is);
+#endif
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'X':
+ if (command)
+ {
+ if (modified != CharT{'O'})
+ {
+#if !ONLY_C_LOCALE
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ if ((err & ios::failbit) == 0)
+ {
+ checked_set(H, tm.tm_hour, not_a_hour, is);
+ checked_set(M, tm.tm_min, not_a_minute, is);
+ checked_set(s, duration_cast<Duration>(seconds{tm.tm_sec}),
+ not_a_second, is);
+ }
+ is.setstate(err);
+#else
+ // "%T"
+ using dfs = detail::decimal_format_seconds<Duration>;
+ CONSTDATA auto w = Duration::period::den == 1 ? 2 : 3 + dfs::width;
+ int tH = not_a_hour;
+ int tM = not_a_minute;
+ long double S{};
+ read(is, ru{tH, 1, 2}, CharT{':'}, ru{tM, 1, 2},
+ CharT{':'}, rld{S, 1, w});
+ checked_set(H, tH, not_a_hour, is);
+ checked_set(M, tM, not_a_minute, is);
+ checked_set(s, round_i<Duration>(duration<long double>{S}),
+ not_a_second, is);
+#endif
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'C':
+ if (command)
+ {
+ int tC = not_a_century;
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+ {
+#endif
+ read(is, rs{tC, 1, width == -1 ? 2u : static_cast<unsigned>(width)});
+#if !ONLY_C_LOCALE
+ }
+ else
+ {
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ if ((err & ios::failbit) == 0)
+ {
+ auto tY = tm.tm_year + 1900;
+ tC = (tY >= 0 ? tY : tY-99) / 100;
+ }
+ is.setstate(err);
+ }
+#endif
+ checked_set(C, tC, not_a_century, is);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'D':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ int tn = not_a_month;
+ int td = not_a_day;
+ int ty = not_a_2digit_year;
+ read(is, ru{tn, 1, 2}, CharT{'\0'}, CharT{'/'}, CharT{'\0'},
+ ru{td, 1, 2}, CharT{'\0'}, CharT{'/'}, CharT{'\0'},
+ rs{ty, 1, 2});
+ checked_set(y, ty, not_a_2digit_year, is);
+ checked_set(m, tn, not_a_month, is);
+ checked_set(d, td, not_a_day, is);
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'F':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ int tY = not_a_year;
+ int tn = not_a_month;
+ int td = not_a_day;
+ read(is, rs{tY, 1, width == -1 ? 4u : static_cast<unsigned>(width)},
+ CharT{'-'}, ru{tn, 1, 2}, CharT{'-'}, ru{td, 1, 2});
+ checked_set(Y, tY, not_a_year, is);
+ checked_set(m, tn, not_a_month, is);
+ checked_set(d, td, not_a_day, is);
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'd':
+ case 'e':
+ if (command)
+ {
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#else
+ if (modified != CharT{'E'})
+#endif
+ {
+ int td = not_a_day;
+ read(is, rs{td, 1, width == -1 ? 2u : static_cast<unsigned>(width)});
+ checked_set(d, td, not_a_day, is);
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ if ((err & ios::failbit) == 0)
+ checked_set(d, tm.tm_mday, not_a_day, is);
+ is.setstate(err);
+ }
+#endif
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'H':
+ if (command)
+ {
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#else
+ if (modified != CharT{'E'})
+#endif
+ {
+ int tH = not_a_hour;
+ read(is, ru{tH, 1, width == -1 ? 2u : static_cast<unsigned>(width)});
+ checked_set(H, tH, not_a_hour, is);
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ if ((err & ios::failbit) == 0)
+ checked_set(H, tm.tm_hour, not_a_hour, is);
+ is.setstate(err);
+ }
+#endif
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'I':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ int tI = not_a_hour_12_value;
+ // reads in an hour into I, but most be in [1, 12]
+ read(is, rs{tI, 1, width == -1 ? 2u : static_cast<unsigned>(width)});
+ if (!(1 <= tI && tI <= 12))
+ is.setstate(ios::failbit);
+ checked_set(I, tI, not_a_hour_12_value, is);
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'j':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ int tj = not_a_doy;
+ read(is, ru{tj, 1, width == -1 ? 3u : static_cast<unsigned>(width)});
+ checked_set(j, tj, not_a_doy, is);
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'M':
+ if (command)
+ {
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#else
+ if (modified != CharT{'E'})
+#endif
+ {
+ int tM = not_a_minute;
+ read(is, ru{tM, 1, width == -1 ? 2u : static_cast<unsigned>(width)});
+ checked_set(M, tM, not_a_minute, is);
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ if ((err & ios::failbit) == 0)
+ checked_set(M, tm.tm_min, not_a_minute, is);
+ is.setstate(err);
+ }
+#endif
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'm':
+ if (command)
+ {
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#else
+ if (modified != CharT{'E'})
+#endif
+ {
+ int tn = not_a_month;
+ read(is, rs{tn, 1, width == -1 ? 2u : static_cast<unsigned>(width)});
+ checked_set(m, tn, not_a_month, is);
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ if ((err & ios::failbit) == 0)
+ checked_set(m, tm.tm_mon + 1, not_a_month, is);
+ is.setstate(err);
+ }
+#endif
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'n':
+ case 't':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ // %n matches a single white space character
+ // %t matches 0 or 1 white space characters
+ auto ic = is.peek();
+ if (Traits::eq_int_type(ic, Traits::eof()))
+ {
+ ios::iostate err = ios::eofbit;
+ if (*fmt == 'n')
+ err |= ios::failbit;
+ is.setstate(err);
+ break;
+ }
+ if (isspace(ic))
+ {
+ (void)is.get();
+ }
+ else if (*fmt == 'n')
+ is.setstate(ios::failbit);
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'p':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ int tp = not_a_ampm;
+#if !ONLY_C_LOCALE
+ tm = std::tm{};
+ tm.tm_hour = 1;
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ is.setstate(err);
+ if (tm.tm_hour == 1)
+ tp = 0;
+ else if (tm.tm_hour == 13)
+ tp = 1;
+ else
+ is.setstate(err);
+#else
+ auto nm = detail::ampm_names();
+ auto i = detail::scan_keyword(is, nm.first, nm.second) - nm.first;
+ tp = static_cast<decltype(tp)>(i);
+#endif
+ checked_set(p, tp, not_a_ampm, is);
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+
+ break;
+ case 'r':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+#if !ONLY_C_LOCALE
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ if ((err & ios::failbit) == 0)
+ {
+ checked_set(H, tm.tm_hour, not_a_hour, is);
+ checked_set(M, tm.tm_min, not_a_hour, is);
+ checked_set(s, duration_cast<Duration>(seconds{tm.tm_sec}),
+ not_a_second, is);
+ }
+ is.setstate(err);
+#else
+ // "%I:%M:%S %p"
+ using dfs = detail::decimal_format_seconds<Duration>;
+ CONSTDATA auto w = Duration::period::den == 1 ? 2 : 3 + dfs::width;
+ long double S{};
+ int tI = not_a_hour_12_value;
+ int tM = not_a_minute;
+ read(is, ru{tI, 1, 2}, CharT{':'}, ru{tM, 1, 2},
+ CharT{':'}, rld{S, 1, w});
+ checked_set(I, tI, not_a_hour_12_value, is);
+ checked_set(M, tM, not_a_minute, is);
+ checked_set(s, round_i<Duration>(duration<long double>{S}),
+ not_a_second, is);
+ ws(is);
+ auto nm = detail::ampm_names();
+ auto i = detail::scan_keyword(is, nm.first, nm.second) - nm.first;
+ checked_set(p, static_cast<int>(i), not_a_ampm, is);
+#endif
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'R':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ int tH = not_a_hour;
+ int tM = not_a_minute;
+ read(is, ru{tH, 1, 2}, CharT{'\0'}, CharT{':'}, CharT{'\0'},
+ ru{tM, 1, 2}, CharT{'\0'});
+ checked_set(H, tH, not_a_hour, is);
+ checked_set(M, tM, not_a_minute, is);
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'S':
+ if (command)
+ {
+ #if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#else
+ if (modified != CharT{'E'})
+#endif
+ {
+ using dfs = detail::decimal_format_seconds<Duration>;
+ CONSTDATA auto w = Duration::period::den == 1 ? 2 : 3 + dfs::width;
+ long double S{};
+ read(is, rld{S, 1, width == -1 ? w : static_cast<unsigned>(width)});
+ checked_set(s, round_i<Duration>(duration<long double>{S}),
+ not_a_second, is);
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'O'})
+ {
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ if ((err & ios::failbit) == 0)
+ checked_set(s, duration_cast<Duration>(seconds{tm.tm_sec}),
+ not_a_second, is);
+ is.setstate(err);
+ }
+#endif
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'T':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ using dfs = detail::decimal_format_seconds<Duration>;
+ CONSTDATA auto w = Duration::period::den == 1 ? 2 : 3 + dfs::width;
+ int tH = not_a_hour;
+ int tM = not_a_minute;
+ long double S{};
+ read(is, ru{tH, 1, 2}, CharT{':'}, ru{tM, 1, 2},
+ CharT{':'}, rld{S, 1, w});
+ checked_set(H, tH, not_a_hour, is);
+ checked_set(M, tM, not_a_minute, is);
+ checked_set(s, round_i<Duration>(duration<long double>{S}),
+ not_a_second, is);
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'Y':
+ if (command)
+ {
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#else
+ if (modified != CharT{'O'})
+#endif
+ {
+ int tY = not_a_year;
+ read(is, rs{tY, 1, width == -1 ? 4u : static_cast<unsigned>(width)});
+ checked_set(Y, tY, not_a_year, is);
+ }
+#if !ONLY_C_LOCALE
+ else if (modified == CharT{'E'})
+ {
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ if ((err & ios::failbit) == 0)
+ checked_set(Y, tm.tm_year + 1900, not_a_year, is);
+ is.setstate(err);
+ }
+#endif
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'y':
+ if (command)
+ {
+#if !ONLY_C_LOCALE
+ if (modified == CharT{})
+#endif
+ {
+ int ty = not_a_2digit_year;
+ read(is, ru{ty, 1, width == -1 ? 2u : static_cast<unsigned>(width)});
+ checked_set(y, ty, not_a_2digit_year, is);
+ }
+#if !ONLY_C_LOCALE
+ else
+ {
+ ios::iostate err = ios::goodbit;
+ f.get(is, nullptr, is, err, &tm, command, fmt+1);
+ if ((err & ios::failbit) == 0)
+ checked_set(Y, tm.tm_year + 1900, not_a_year, is);
+ is.setstate(err);
+ }
+#endif
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'g':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ int tg = not_a_2digit_year;
+ read(is, ru{tg, 1, width == -1 ? 2u : static_cast<unsigned>(width)});
+ checked_set(g, tg, not_a_2digit_year, is);
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'G':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ int tG = not_a_year;
+ read(is, rs{tG, 1, width == -1 ? 4u : static_cast<unsigned>(width)});
+ checked_set(G, tG, not_a_year, is);
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'U':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ int tU = not_a_week_num;
+ read(is, ru{tU, 1, width == -1 ? 2u : static_cast<unsigned>(width)});
+ checked_set(U, tU, not_a_week_num, is);
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'V':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ int tV = not_a_week_num;
+ read(is, ru{tV, 1, width == -1 ? 2u : static_cast<unsigned>(width)});
+ checked_set(V, tV, not_a_week_num, is);
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'W':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ int tW = not_a_week_num;
+ read(is, ru{tW, 1, width == -1 ? 2u : static_cast<unsigned>(width)});
+ checked_set(W, tW, not_a_week_num, is);
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'E':
+ case 'O':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ modified = *fmt;
+ }
+ else
+ {
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ }
+ else
+ read(is, *fmt);
+ break;
+ case '%':
+ if (command)
+ {
+ if (modified == CharT{})
+ read(is, *fmt);
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ command = fmt;
+ break;
+ case 'z':
+ if (command)
+ {
+ int tH, tM;
+ minutes toff = not_a_offset;
+ bool neg = false;
+ auto ic = is.peek();
+ if (!Traits::eq_int_type(ic, Traits::eof()))
+ {
+ auto c = static_cast<char>(Traits::to_char_type(ic));
+ if (c == '-')
+ neg = true;
+ }
+ if (modified == CharT{})
+ {
+ read(is, rs{tH, 2, 2});
+ if (!is.fail())
+ toff = hours{std::abs(tH)};
+ if (is.good())
+ {
+ ic = is.peek();
+ if (!Traits::eq_int_type(ic, Traits::eof()))
+ {
+ auto c = static_cast<char>(Traits::to_char_type(ic));
+ if ('0' <= c && c <= '9')
+ {
+ read(is, ru{tM, 2, 2});
+ if (!is.fail())
+ toff += minutes{tM};
+ }
+ }
+ }
+ }
+ else
+ {
+ read(is, rs{tH, 1, 2});
+ if (!is.fail())
+ toff = hours{std::abs(tH)};
+ if (is.good())
+ {
+ ic = is.peek();
+ if (!Traits::eq_int_type(ic, Traits::eof()))
+ {
+ auto c = static_cast<char>(Traits::to_char_type(ic));
+ if (c == ':')
+ {
+ (void)is.get();
+ read(is, ru{tM, 2, 2});
+ if (!is.fail())
+ toff += minutes{tM};
+ }
+ }
+ }
+ }
+ if (neg)
+ toff = -toff;
+ checked_set(temp_offset, toff, not_a_offset, is);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ case 'Z':
+ if (command)
+ {
+ if (modified == CharT{})
+ {
+ std::basic_string<CharT, Traits, Alloc> buf;
+ while (is.rdstate() == std::ios::goodbit)
+ {
+ auto i = is.rdbuf()->sgetc();
+ if (Traits::eq_int_type(i, Traits::eof()))
+ {
+ is.setstate(ios::eofbit);
+ break;
+ }
+ auto wc = Traits::to_char_type(i);
+ auto c = static_cast<char>(wc);
+ // is c a valid time zone name or abbreviation character?
+ if (!(CharT{1} < wc && wc < CharT{127}) || !(isalnum(c) ||
+ c == '_' || c == '/' || c == '-' || c == '+'))
+ break;
+ buf.push_back(c);
+ is.rdbuf()->sbumpc();
+ }
+ if (buf.empty())
+ is.setstate(ios::failbit);
+ checked_set(temp_abbrev, buf, {}, is);
+ }
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ else
+ read(is, *fmt);
+ break;
+ default:
+ if (command)
+ {
+ if (width == -1 && modified == CharT{} && '0' <= *fmt && *fmt <= '9')
+ {
+ width = static_cast<char>(*fmt) - '0';
+ while ('0' <= fmt[1] && fmt[1] <= '9')
+ width = 10*width + static_cast<char>(*++fmt) - '0';
+ }
+ else
+ {
+ if (modified == CharT{})
+ read(is, CharT{'%'}, width, *fmt);
+ else
+ read(is, CharT{'%'}, width, modified, *fmt);
+ command = nullptr;
+ width = -1;
+ modified = CharT{};
+ }
+ }
+ else // !command
+ {
+ if (isspace(static_cast<unsigned char>(*fmt)))
+ {
+ // space matches 0 or more white space characters
+ if (is.good())
+ ws(is);
+ }
+ else
+ read(is, *fmt);
+ }
+ break;
+ }
+ }
+ // is.fail() || *fmt == CharT{}
+ if (is.rdstate() == ios::goodbit && command)
+ {
+ if (modified == CharT{})
+ read(is, CharT{'%'}, width);
+ else
+ read(is, CharT{'%'}, width, modified);
+ }
+ if (!is.fail())
+ {
+ if (y != not_a_2digit_year)
+ {
+ // Convert y and an optional C to Y
+ if (!(0 <= y && y <= 99))
+ goto broken;
+ if (C == not_a_century)
+ {
+ if (Y == not_a_year)
+ {
+ if (y >= 69)
+ C = 19;
+ else
+ C = 20;
+ }
+ else
+ {
+ C = (Y >= 0 ? Y : Y-100) / 100;
+ }
+ }
+ int tY;
+ if (C >= 0)
+ tY = 100*C + y;
+ else
+ tY = 100*(C+1) - (y == 0 ? 100 : y);
+ if (Y != not_a_year && Y != tY)
+ goto broken;
+ Y = tY;
+ }
+ if (g != not_a_2digit_year)
+ {
+ // Convert g and an optional C to G
+ if (!(0 <= g && g <= 99))
+ goto broken;
+ if (C == not_a_century)
+ {
+ if (G == not_a_year)
+ {
+ if (g >= 69)
+ C = 19;
+ else
+ C = 20;
+ }
+ else
+ {
+ C = (G >= 0 ? G : G-100) / 100;
+ }
+ }
+ int tG;
+ if (C >= 0)
+ tG = 100*C + g;
+ else
+ tG = 100*(C+1) - (g == 0 ? 100 : g);
+ if (G != not_a_year && G != tG)
+ goto broken;
+ G = tG;
+ }
+ if (Y < static_cast<int>(year::min()) || Y > static_cast<int>(year::max()))
+ Y = not_a_year;
+ bool computed = false;
+ if (G != not_a_year && V != not_a_week_num && wd != not_a_weekday)
+ {
+ year_month_day ymd_trial = sys_days(year{G-1}/December/Thursday[last]) +
+ (Monday-Thursday) + weeks{V-1} +
+ (weekday{static_cast<unsigned>(wd)}-Monday);
+ if (Y == not_a_year)
+ Y = static_cast<int>(ymd_trial.year());
+ else if (year{Y} != ymd_trial.year())
+ goto broken;
+ if (m == not_a_month)
+ m = static_cast<int>(static_cast<unsigned>(ymd_trial.month()));
+ else if (month(static_cast<unsigned>(m)) != ymd_trial.month())
+ goto broken;
+ if (d == not_a_day)
+ d = static_cast<int>(static_cast<unsigned>(ymd_trial.day()));
+ else if (day(static_cast<unsigned>(d)) != ymd_trial.day())
+ goto broken;
+ computed = true;
+ }
+ if (Y != not_a_year && U != not_a_week_num && wd != not_a_weekday)
+ {
+ year_month_day ymd_trial = sys_days(year{Y}/January/Sunday[1]) +
+ weeks{U-1} +
+ (weekday{static_cast<unsigned>(wd)} - Sunday);
+ if (Y == not_a_year)
+ Y = static_cast<int>(ymd_trial.year());
+ else if (year{Y} != ymd_trial.year())
+ goto broken;
+ if (m == not_a_month)
+ m = static_cast<int>(static_cast<unsigned>(ymd_trial.month()));
+ else if (month(static_cast<unsigned>(m)) != ymd_trial.month())
+ goto broken;
+ if (d == not_a_day)
+ d = static_cast<int>(static_cast<unsigned>(ymd_trial.day()));
+ else if (day(static_cast<unsigned>(d)) != ymd_trial.day())
+ goto broken;
+ computed = true;
+ }
+ if (Y != not_a_year && W != not_a_week_num && wd != not_a_weekday)
+ {
+ year_month_day ymd_trial = sys_days(year{Y}/January/Monday[1]) +
+ weeks{W-1} +
+ (weekday{static_cast<unsigned>(wd)} - Monday);
+ if (Y == not_a_year)
+ Y = static_cast<int>(ymd_trial.year());
+ else if (year{Y} != ymd_trial.year())
+ goto broken;
+ if (m == not_a_month)
+ m = static_cast<int>(static_cast<unsigned>(ymd_trial.month()));
+ else if (month(static_cast<unsigned>(m)) != ymd_trial.month())
+ goto broken;
+ if (d == not_a_day)
+ d = static_cast<int>(static_cast<unsigned>(ymd_trial.day()));
+ else if (day(static_cast<unsigned>(d)) != ymd_trial.day())
+ goto broken;
+ computed = true;
+ }
+ if (j != not_a_doy && Y != not_a_year)
+ {
+ auto ymd_trial = year_month_day{local_days(year{Y}/1/1) + days{j-1}};
+ if (m == not_a_month)
+ m = static_cast<int>(static_cast<unsigned>(ymd_trial.month()));
+ else if (month(static_cast<unsigned>(m)) != ymd_trial.month())
+ goto broken;
+ if (d == not_a_day)
+ d = static_cast<int>(static_cast<unsigned>(ymd_trial.day()));
+ else if (day(static_cast<unsigned>(d)) != ymd_trial.day())
+ goto broken;
+ j = not_a_doy;
+ }
+ auto ymd = year{Y}/m/d;
+ if (ymd.ok())
+ {
+ if (wd == not_a_weekday)
+ wd = static_cast<int>((weekday(sys_days(ymd)) - Sunday).count());
+ else if (wd != static_cast<int>((weekday(sys_days(ymd)) - Sunday).count()))
+ goto broken;
+ if (!computed)
+ {
+ if (G != not_a_year || V != not_a_week_num)
+ {
+ sys_days sd = ymd;
+ auto G_trial = year_month_day{sd + days{3}}.year();
+ auto start = sys_days((G_trial - years{1})/December/Thursday[last]) +
+ (Monday - Thursday);
+ if (sd < start)
+ {
+ --G_trial;
+ if (V != not_a_week_num)
+ start = sys_days((G_trial - years{1})/December/Thursday[last])
+ + (Monday - Thursday);
+ }
+ if (G != not_a_year && G != static_cast<int>(G_trial))
+ goto broken;
+ if (V != not_a_week_num)
+ {
+ auto V_trial = duration_cast<weeks>(sd - start).count() + 1;
+ if (V != V_trial)
+ goto broken;
+ }
+ }
+ if (U != not_a_week_num)
+ {
+ auto start = sys_days(Sunday[1]/January/ymd.year());
+ auto U_trial = floor<weeks>(sys_days(ymd) - start).count() + 1;
+ if (U != U_trial)
+ goto broken;
+ }
+ if (W != not_a_week_num)
+ {
+ auto start = sys_days(Monday[1]/January/ymd.year());
+ auto W_trial = floor<weeks>(sys_days(ymd) - start).count() + 1;
+ if (W != W_trial)
+ goto broken;
+ }
+ }
+ }
+ fds.ymd = ymd;
+ if (I != not_a_hour_12_value)
+ {
+ if (!(1 <= I && I <= 12))
+ goto broken;
+ if (p != not_a_ampm)
+ {
+ // p is in [0, 1] == [AM, PM]
+ // Store trial H in I
+ if (I == 12)
+ --p;
+ I += p*12;
+ // Either set H from I or make sure H and I are consistent
+ if (H == not_a_hour)
+ H = I;
+ else if (I != H)
+ goto broken;
+ }
+ else // p == not_a_ampm
+ {
+ // if H, make sure H and I could be consistent
+ if (H != not_a_hour)
+ {
+ if (I == 12)
+ {
+ if (H != 0 && H != 12)
+ goto broken;
+ }
+ else if (!(I == H || I == H+12))
+ {
+ goto broken;
+ }
+ }
+ else // I is ambiguous, AM or PM?
+ goto broken;
+ }
+ }
+ if (H != not_a_hour)
+ {
+ fds.has_tod = true;
+ fds.tod = hh_mm_ss<Duration>{hours{H}};
+ }
+ if (M != not_a_minute)
+ {
+ fds.has_tod = true;
+ fds.tod.m_ = minutes{M};
+ }
+ if (s != not_a_second)
+ {
+ fds.has_tod = true;
+ fds.tod.s_ = detail::decimal_format_seconds<Duration>{s};
+ }
+ if (j != not_a_doy)
+ {
+ fds.has_tod = true;
+ fds.tod.h_ += hours{days{j}};
+ }
+ if (wd != not_a_weekday)
+ fds.wd = weekday{static_cast<unsigned>(wd)};
+ if (abbrev != nullptr)
+ *abbrev = std::move(temp_abbrev);
+ if (offset != nullptr && temp_offset != not_a_offset)
+ *offset = temp_offset;
+ }
+ return is;
+ }
+broken:
+ is.setstate(ios::failbit);
+ return is;
+}
+
+template <class CharT, class Traits, class Alloc = std::allocator<CharT>>
+std::basic_istream<CharT, Traits>&
+from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt, year& y,
+ std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+{
+ using CT = std::chrono::seconds;
+ fields<CT> fds{};
+ date::from_stream(is, fmt, fds, abbrev, offset);
+ if (!fds.ymd.year().ok())
+ is.setstate(std::ios::failbit);
+ if (!is.fail())
+ y = fds.ymd.year();
+ return is;
+}
+
+template <class CharT, class Traits, class Alloc = std::allocator<CharT>>
+std::basic_istream<CharT, Traits>&
+from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt, month& m,
+ std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+{
+ using CT = std::chrono::seconds;
+ fields<CT> fds{};
+ date::from_stream(is, fmt, fds, abbrev, offset);
+ if (!fds.ymd.month().ok())
+ is.setstate(std::ios::failbit);
+ if (!is.fail())
+ m = fds.ymd.month();
+ return is;
+}
+
+template <class CharT, class Traits, class Alloc = std::allocator<CharT>>
+std::basic_istream<CharT, Traits>&
+from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt, day& d,
+ std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+{
+ using CT = std::chrono::seconds;
+ fields<CT> fds{};
+ date::from_stream(is, fmt, fds, abbrev, offset);
+ if (!fds.ymd.day().ok())
+ is.setstate(std::ios::failbit);
+ if (!is.fail())
+ d = fds.ymd.day();
+ return is;
+}
+
+template <class CharT, class Traits, class Alloc = std::allocator<CharT>>
+std::basic_istream<CharT, Traits>&
+from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt, weekday& wd,
+ std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+{
+ using CT = std::chrono::seconds;
+ fields<CT> fds{};
+ date::from_stream(is, fmt, fds, abbrev, offset);
+ if (!fds.wd.ok())
+ is.setstate(std::ios::failbit);
+ if (!is.fail())
+ wd = fds.wd;
+ return is;
+}
+
+template <class CharT, class Traits, class Alloc = std::allocator<CharT>>
+std::basic_istream<CharT, Traits>&
+from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt, year_month& ym,
+ std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+{
+ using CT = std::chrono::seconds;
+ fields<CT> fds{};
+ date::from_stream(is, fmt, fds, abbrev, offset);
+ if (!fds.ymd.month().ok())
+ is.setstate(std::ios::failbit);
+ if (!is.fail())
+ ym = fds.ymd.year()/fds.ymd.month();
+ return is;
+}
+
+template <class CharT, class Traits, class Alloc = std::allocator<CharT>>
+std::basic_istream<CharT, Traits>&
+from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt, month_day& md,
+ std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+{
+ using CT = std::chrono::seconds;
+ fields<CT> fds{};
+ date::from_stream(is, fmt, fds, abbrev, offset);
+ if (!fds.ymd.month().ok() || !fds.ymd.day().ok())
+ is.setstate(std::ios::failbit);
+ if (!is.fail())
+ md = fds.ymd.month()/fds.ymd.day();
+ return is;
+}
+
+template <class CharT, class Traits, class Alloc = std::allocator<CharT>>
+std::basic_istream<CharT, Traits>&
+from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt,
+ year_month_day& ymd, std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+{
+ using CT = std::chrono::seconds;
+ fields<CT> fds{};
+ date::from_stream(is, fmt, fds, abbrev, offset);
+ if (!fds.ymd.ok())
+ is.setstate(std::ios::failbit);
+ if (!is.fail())
+ ymd = fds.ymd;
+ return is;
+}
+
+template <class Duration, class CharT, class Traits, class Alloc = std::allocator<CharT>>
+std::basic_istream<CharT, Traits>&
+from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt,
+ sys_time<Duration>& tp, std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+{
+ using CT = typename std::common_type<Duration, std::chrono::seconds>::type;
+ using detail::round_i;
+ std::chrono::minutes offset_local{};
+ auto offptr = offset ? offset : &offset_local;
+ fields<CT> fds{};
+ fds.has_tod = true;
+ date::from_stream(is, fmt, fds, abbrev, offptr);
+ if (!fds.ymd.ok() || !fds.tod.in_conventional_range())
+ is.setstate(std::ios::failbit);
+ if (!is.fail())
+ tp = round_i<Duration>(sys_days(fds.ymd) - *offptr + fds.tod.to_duration());
+ return is;
+}
+
+template <class Duration, class CharT, class Traits, class Alloc = std::allocator<CharT>>
+std::basic_istream<CharT, Traits>&
+from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt,
+ local_time<Duration>& tp, std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+{
+ using CT = typename std::common_type<Duration, std::chrono::seconds>::type;
+ using detail::round_i;
+ fields<CT> fds{};
+ fds.has_tod = true;
+ date::from_stream(is, fmt, fds, abbrev, offset);
+ if (!fds.ymd.ok() || !fds.tod.in_conventional_range())
+ is.setstate(std::ios::failbit);
+ if (!is.fail())
+ tp = round_i<Duration>(local_seconds{local_days(fds.ymd)} + fds.tod.to_duration());
+ return is;
+}
+
+template <class Rep, class Period, class CharT, class Traits, class Alloc = std::allocator<CharT>>
+std::basic_istream<CharT, Traits>&
+from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt,
+ std::chrono::duration<Rep, Period>& d,
+ std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+{
+ using Duration = std::chrono::duration<Rep, Period>;
+ using CT = typename std::common_type<Duration, std::chrono::seconds>::type;
+ using detail::round_i;
+ fields<CT> fds{};
+ date::from_stream(is, fmt, fds, abbrev, offset);
+ if (!fds.has_tod)
+ is.setstate(std::ios::failbit);
+ if (!is.fail())
+ d = round_i<Duration>(fds.tod.to_duration());
+ return is;
+}
+
+template <class Parsable, class CharT, class Traits = std::char_traits<CharT>,
+ class Alloc = std::allocator<CharT>>
+struct parse_manip
+{
+ const std::basic_string<CharT, Traits, Alloc> format_;
+ Parsable& tp_;
+ std::basic_string<CharT, Traits, Alloc>* abbrev_;
+ std::chrono::minutes* offset_;
+
+public:
+ parse_manip(std::basic_string<CharT, Traits, Alloc> format, Parsable& tp,
+ std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+ : format_(std::move(format))
+ , tp_(tp)
+ , abbrev_(abbrev)
+ , offset_(offset)
+ {}
+
+#if HAS_STRING_VIEW
+ parse_manip(const CharT* format, Parsable& tp,
+ std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+ : format_(format)
+ , tp_(tp)
+ , abbrev_(abbrev)
+ , offset_(offset)
+ {}
+
+ parse_manip(std::basic_string_view<CharT, Traits> format, Parsable& tp,
+ std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+ : format_(format)
+ , tp_(tp)
+ , abbrev_(abbrev)
+ , offset_(offset)
+ {}
+#endif // HAS_STRING_VIEW
+};
+
+template <class Parsable, class CharT, class Traits, class Alloc>
+std::basic_istream<CharT, Traits>&
+operator>>(std::basic_istream<CharT, Traits>& is,
+ const parse_manip<Parsable, CharT, Traits, Alloc>& x)
+{
+ return date::from_stream(is, x.format_.c_str(), x.tp_, x.abbrev_, x.offset_);
+}
+
+template <class Parsable, class CharT, class Traits, class Alloc>
+inline
+auto
+parse(const std::basic_string<CharT, Traits, Alloc>& format, Parsable& tp)
+ -> decltype(date::from_stream(std::declval<std::basic_istream<CharT, Traits>&>(),
+ format.c_str(), tp),
+ parse_manip<Parsable, CharT, Traits, Alloc>{format, tp})
+{
+ return {format, tp};
+}
+
+template <class Parsable, class CharT, class Traits, class Alloc>
+inline
+auto
+parse(const std::basic_string<CharT, Traits, Alloc>& format, Parsable& tp,
+ std::basic_string<CharT, Traits, Alloc>& abbrev)
+ -> decltype(date::from_stream(std::declval<std::basic_istream<CharT, Traits>&>(),
+ format.c_str(), tp, &abbrev),
+ parse_manip<Parsable, CharT, Traits, Alloc>{format, tp, &abbrev})
+{
+ return {format, tp, &abbrev};
+}
+
+template <class Parsable, class CharT, class Traits, class Alloc>
+inline
+auto
+parse(const std::basic_string<CharT, Traits, Alloc>& format, Parsable& tp,
+ std::chrono::minutes& offset)
+ -> decltype(date::from_stream(std::declval<std::basic_istream<CharT, Traits>&>(),
+ format.c_str(), tp,
+ std::declval<std::basic_string<CharT, Traits, Alloc>*>(),
+ &offset),
+ parse_manip<Parsable, CharT, Traits, Alloc>{format, tp, nullptr, &offset})
+{
+ return {format, tp, nullptr, &offset};
+}
+
+template <class Parsable, class CharT, class Traits, class Alloc>
+inline
+auto
+parse(const std::basic_string<CharT, Traits, Alloc>& format, Parsable& tp,
+ std::basic_string<CharT, Traits, Alloc>& abbrev, std::chrono::minutes& offset)
+ -> decltype(date::from_stream(std::declval<std::basic_istream<CharT, Traits>&>(),
+ format.c_str(), tp, &abbrev, &offset),
+ parse_manip<Parsable, CharT, Traits, Alloc>{format, tp, &abbrev, &offset})
+{
+ return {format, tp, &abbrev, &offset};
+}
+
+// const CharT* formats
+
+template <class Parsable, class CharT>
+inline
+auto
+parse(const CharT* format, Parsable& tp)
+ -> decltype(date::from_stream(std::declval<std::basic_istream<CharT>&>(), format, tp),
+ parse_manip<Parsable, CharT>{format, tp})
+{
+ return {format, tp};
+}
+
+template <class Parsable, class CharT, class Traits, class Alloc>
+inline
+auto
+parse(const CharT* format, Parsable& tp, std::basic_string<CharT, Traits, Alloc>& abbrev)
+ -> decltype(date::from_stream(std::declval<std::basic_istream<CharT, Traits>&>(), format,
+ tp, &abbrev),
+ parse_manip<Parsable, CharT, Traits, Alloc>{format, tp, &abbrev})
+{
+ return {format, tp, &abbrev};
+}
+
+template <class Parsable, class CharT>
+inline
+auto
+parse(const CharT* format, Parsable& tp, std::chrono::minutes& offset)
+ -> decltype(date::from_stream(std::declval<std::basic_istream<CharT>&>(), format,
+ tp, std::declval<std::basic_string<CharT>*>(), &offset),
+ parse_manip<Parsable, CharT>{format, tp, nullptr, &offset})
+{
+ return {format, tp, nullptr, &offset};
+}
+
+template <class Parsable, class CharT, class Traits, class Alloc>
+inline
+auto
+parse(const CharT* format, Parsable& tp,
+ std::basic_string<CharT, Traits, Alloc>& abbrev, std::chrono::minutes& offset)
+ -> decltype(date::from_stream(std::declval<std::basic_istream<CharT, Traits>&>(), format,
+ tp, &abbrev, &offset),
+ parse_manip<Parsable, CharT, Traits, Alloc>{format, tp, &abbrev, &offset})
+{
+ return {format, tp, &abbrev, &offset};
+}
+
+// duration streaming
+
+template <class CharT, class Traits, class Rep, class Period>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os,
+ const std::chrono::duration<Rep, Period>& d)
+{
+ return os << detail::make_string<CharT, Traits>::from(d.count()) +
+ detail::get_units<CharT>(typename Period::type{});
+}
+
+} // namespace date
+} // namespace arrow_vendored
+
+#ifdef _MSC_VER
+# pragma warning(pop)
+#endif
+
+#ifdef __GNUC__
+# pragma GCC diagnostic pop
+#endif
+
+#endif // DATE_H
diff --git a/src/arrow/cpp/src/arrow/vendored/datetime/ios.h b/src/arrow/cpp/src/arrow/vendored/datetime/ios.h
new file mode 100644
index 000000000..acad28d13
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/datetime/ios.h
@@ -0,0 +1,53 @@
+//
+// ios.h
+// DateTimeLib
+//
+// The MIT License (MIT)
+//
+// Copyright (c) 2016 Alexander Kormanovsky
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+#ifndef ios_hpp
+#define ios_hpp
+
+#if __APPLE__
+# include <TargetConditionals.h>
+# if TARGET_OS_IPHONE
+# include <string>
+
+ namespace arrow_vendored
+ {
+ namespace date
+ {
+ namespace iOSUtils
+ {
+
+ std::string get_tzdata_path();
+ std::string get_current_timezone();
+
+ } // namespace iOSUtils
+ } // namespace date
+ } // namespace arrow_vendored
+
+# endif // TARGET_OS_IPHONE
+#else // !__APPLE__
+# define TARGET_OS_IPHONE 0
+#endif // !__APPLE__
+#endif // ios_hpp
diff --git a/src/arrow/cpp/src/arrow/vendored/datetime/ios.mm b/src/arrow/cpp/src/arrow/vendored/datetime/ios.mm
new file mode 100644
index 000000000..22b7ce6c3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/datetime/ios.mm
@@ -0,0 +1,340 @@
+//
+// The MIT License (MIT)
+//
+// Copyright (c) 2016 Alexander Kormanovsky
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+//
+
+#include "ios.h"
+
+#if TARGET_OS_IPHONE
+
+#include <Foundation/Foundation.h>
+
+#include <fstream>
+#include <zlib.h>
+#include <sys/stat.h>
+
+#ifndef TAR_DEBUG
+# define TAR_DEBUG 0
+#endif
+
+#define INTERNAL_DIR "Library"
+#define TZDATA_DIR "tzdata"
+#define TARGZ_EXTENSION "tar.gz"
+
+#define TAR_BLOCK_SIZE 512
+#define TAR_TYPE_POSITION 156
+#define TAR_NAME_POSITION 0
+#define TAR_NAME_SIZE 100
+#define TAR_SIZE_POSITION 124
+#define TAR_SIZE_SIZE 12
+
+namespace arrow_vendored
+{
+namespace date
+{
+ namespace iOSUtils
+ {
+
+ struct TarInfo
+ {
+ char objType;
+ std::string objName;
+ size_t realContentSize; // writable size without padding zeroes
+ size_t blocksContentSize; // adjusted size to 512 bytes blocks
+ bool success;
+ };
+
+ std::string convertCFStringRefPathToCStringPath(CFStringRef ref);
+ bool extractTzdata(CFURLRef homeUrl, CFURLRef archiveUrl, std::string destPath);
+ TarInfo getTarObjectInfo(std::ifstream &readStream);
+ std::string getTarObject(std::ifstream &readStream, int64_t size);
+ bool writeFile(const std::string &tzdataPath, const std::string &fileName,
+ const std::string &data, size_t realContentSize);
+
+ std::string
+ get_current_timezone()
+ {
+ CFTimeZoneRef tzRef = CFTimeZoneCopySystem();
+ CFStringRef tzNameRef = CFTimeZoneGetName(tzRef);
+ CFIndex bufferSize = CFStringGetLength(tzNameRef) + 1;
+ char buffer[bufferSize];
+
+ if (CFStringGetCString(tzNameRef, buffer, bufferSize, kCFStringEncodingUTF8))
+ {
+ CFRelease(tzRef);
+ return std::string(buffer);
+ }
+
+ CFRelease(tzRef);
+
+ return "";
+ }
+
+ std::string
+ get_tzdata_path()
+ {
+ CFURLRef homeUrlRef = CFCopyHomeDirectoryURL();
+ CFStringRef homePath = CFURLCopyPath(homeUrlRef);
+ std::string path(std::string(convertCFStringRefPathToCStringPath(homePath)) +
+ INTERNAL_DIR + "/" + TZDATA_DIR);
+ std::string result_path(std::string(convertCFStringRefPathToCStringPath(homePath)) +
+ INTERNAL_DIR);
+
+ if (access(path.c_str(), F_OK) == 0)
+ {
+#if TAR_DEBUG
+ printf("tzdata dir exists\n");
+#endif
+ CFRelease(homeUrlRef);
+ CFRelease(homePath);
+
+ return result_path;
+ }
+
+ CFBundleRef mainBundle = CFBundleGetMainBundle();
+ CFArrayRef paths = CFBundleCopyResourceURLsOfType(mainBundle, CFSTR(TARGZ_EXTENSION),
+ NULL);
+
+ if (CFArrayGetCount(paths) != 0)
+ {
+ // get archive path, assume there is no other tar.gz in bundle
+ CFURLRef archiveUrl = static_cast<CFURLRef>(CFArrayGetValueAtIndex(paths, 0));
+ CFStringRef archiveName = CFURLCopyPath(archiveUrl);
+ archiveUrl = CFBundleCopyResourceURL(mainBundle, archiveName, NULL, NULL);
+
+ extractTzdata(homeUrlRef, archiveUrl, path);
+
+ CFRelease(archiveUrl);
+ CFRelease(archiveName);
+ }
+
+ CFRelease(homeUrlRef);
+ CFRelease(homePath);
+ CFRelease(paths);
+
+ return result_path;
+ }
+
+ std::string
+ convertCFStringRefPathToCStringPath(CFStringRef ref)
+ {
+ CFIndex bufferSize = CFStringGetMaximumSizeOfFileSystemRepresentation(ref);
+ char *buffer = new char[bufferSize];
+ CFStringGetFileSystemRepresentation(ref, buffer, bufferSize);
+ auto result = std::string(buffer);
+ delete[] buffer;
+ return result;
+ }
+
+ bool
+ extractTzdata(CFURLRef homeUrl, CFURLRef archiveUrl, std::string destPath)
+ {
+ std::string TAR_TMP_PATH = "/tmp.tar";
+
+ CFStringRef homeStringRef = CFURLCopyPath(homeUrl);
+ auto homePath = convertCFStringRefPathToCStringPath(homeStringRef);
+ CFRelease(homeStringRef);
+
+ CFStringRef archiveStringRef = CFURLCopyPath(archiveUrl);
+ auto archivePath = convertCFStringRefPathToCStringPath(archiveStringRef);
+ CFRelease(archiveStringRef);
+
+ // create Library path
+ auto libraryPath = homePath + INTERNAL_DIR;
+
+ // create tzdata path
+ auto tzdataPath = libraryPath + "/" + TZDATA_DIR;
+
+ // -- replace %20 with " "
+ const std::string search = "%20";
+ const std::string replacement = " ";
+ size_t pos = 0;
+
+ while ((pos = archivePath.find(search, pos)) != std::string::npos) {
+ archivePath.replace(pos, search.length(), replacement);
+ pos += replacement.length();
+ }
+
+ gzFile tarFile = gzopen(archivePath.c_str(), "rb");
+
+ // create tar unpacking path
+ auto tarPath = libraryPath + TAR_TMP_PATH;
+
+ // create tzdata directory
+ mkdir(destPath.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH);
+
+ // ======= extract tar ========
+
+ std::ofstream os(tarPath.c_str(), std::ofstream::out | std::ofstream::app);
+ unsigned int bufferLength = 1024 * 256; // 256Kb
+ unsigned char *buffer = (unsigned char *)malloc(bufferLength);
+ bool success = true;
+
+ while (true)
+ {
+ int readBytes = gzread(tarFile, buffer, bufferLength);
+
+ if (readBytes > 0)
+ {
+ os.write((char *) &buffer[0], readBytes);
+ }
+ else
+ if (readBytes == 0)
+ {
+ break;
+ }
+ else
+ if (readBytes == -1)
+ {
+ printf("decompression failed\n");
+ success = false;
+ break;
+ }
+ else
+ {
+ printf("unexpected zlib state\n");
+ success = false;
+ break;
+ }
+ }
+
+ os.close();
+ free(buffer);
+ gzclose(tarFile);
+
+ if (!success)
+ {
+ remove(tarPath.c_str());
+ return false;
+ }
+
+ // ======== extract files =========
+
+ uint64_t location = 0; // Position in the file
+
+ // get file size
+ struct stat stat_buf;
+ int res = stat(tarPath.c_str(), &stat_buf);
+ if (res != 0)
+ {
+ printf("error file size\n");
+ remove(tarPath.c_str());
+ return false;
+ }
+ int64_t tarSize = stat_buf.st_size;
+
+ // create read stream
+ std::ifstream is(tarPath.c_str(), std::ifstream::in | std::ifstream::binary);
+
+ // process files
+ while (location < tarSize)
+ {
+ TarInfo info = getTarObjectInfo(is);
+
+ if (!info.success || info.realContentSize == 0)
+ {
+ break; // something wrong or all files are read
+ }
+
+ switch (info.objType)
+ {
+ case '0': // file
+ case '\0': //
+ {
+ std::string obj = getTarObject(is, info.blocksContentSize);
+#if TAR_DEBUG
+ size += info.realContentSize;
+ printf("#%i %s file size %lld written total %ld from %lld\n", ++count,
+ info.objName.c_str(), info.realContentSize, size, tarSize);
+#endif
+ writeFile(tzdataPath, info.objName, obj, info.realContentSize);
+ location += info.blocksContentSize;
+
+ break;
+ }
+ }
+ }
+
+ remove(tarPath.c_str());
+
+ return true;
+ }
+
+ TarInfo
+ getTarObjectInfo(std::ifstream &readStream)
+ {
+ int64_t length = TAR_BLOCK_SIZE;
+ char buffer[length];
+ char type;
+ char name[TAR_NAME_SIZE + 1];
+ char sizeBuf[TAR_SIZE_SIZE + 1];
+
+ readStream.read(buffer, length);
+
+ memcpy(&type, &buffer[TAR_TYPE_POSITION], 1);
+
+ memset(&name, '\0', TAR_NAME_SIZE + 1);
+ memcpy(&name, &buffer[TAR_NAME_POSITION], TAR_NAME_SIZE);
+
+ memset(&sizeBuf, '\0', TAR_SIZE_SIZE + 1);
+ memcpy(&sizeBuf, &buffer[TAR_SIZE_POSITION], TAR_SIZE_SIZE);
+ size_t realSize = strtol(sizeBuf, NULL, 8);
+ size_t blocksSize = realSize + (TAR_BLOCK_SIZE - (realSize % TAR_BLOCK_SIZE));
+
+ return {type, std::string(name), realSize, blocksSize, true};
+ }
+
+ std::string
+ getTarObject(std::ifstream &readStream, int64_t size)
+ {
+ char buffer[size];
+ readStream.read(buffer, size);
+ return std::string(buffer);
+ }
+
+ bool
+ writeFile(const std::string &tzdataPath, const std::string &fileName, const std::string &data,
+ size_t realContentSize)
+ {
+ std::ofstream os(tzdataPath + "/" + fileName, std::ofstream::out | std::ofstream::binary);
+
+ if (!os) {
+ return false;
+ }
+
+ // trim empty space
+ char trimmedData[realContentSize + 1];
+ memset(&trimmedData, '\0', realContentSize);
+ memcpy(&trimmedData, data.c_str(), realContentSize);
+
+ // write
+ os.write(trimmedData, realContentSize);
+ os.close();
+
+ return true;
+ }
+
+ } // namespace iOSUtils
+} // namespace date
+} // namespace arrow_vendored
+
+#endif // TARGET_OS_IPHONE
diff --git a/src/arrow/cpp/src/arrow/vendored/datetime/tz.cpp b/src/arrow/cpp/src/arrow/vendored/datetime/tz.cpp
new file mode 100644
index 000000000..9047a31c7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/datetime/tz.cpp
@@ -0,0 +1,3951 @@
+// The MIT License (MIT)
+//
+// Copyright (c) 2015, 2016, 2017 Howard Hinnant
+// Copyright (c) 2015 Ville Voutilainen
+// Copyright (c) 2016 Alexander Kormanovsky
+// Copyright (c) 2016, 2017 Jiangang Zhuang
+// Copyright (c) 2017 Nicolas Veloz Savino
+// Copyright (c) 2017 Florian Dang
+// Copyright (c) 2017 Aaron Bishop
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+//
+// Our apologies. When the previous paragraph was written, lowercase had not yet
+// been invented (that would involve another several millennia of evolution).
+// We did not mean to shout.
+
+// NOTE(ARROW): This is required so that symbols are properly exported from the DLL
+#include "visibility.h"
+
+
+#ifdef _WIN32
+ // windows.h will be included directly and indirectly (e.g. by curl).
+ // We need to define these macros to prevent windows.h bringing in
+ // more than we need and do it early so windows.h doesn't get included
+ // without these macros having been defined.
+ // min/max macros interfere with the C++ versions.
+# ifndef NOMINMAX
+# define NOMINMAX
+# endif
+ // We don't need all that Windows has to offer.
+# ifndef WIN32_LEAN_AND_MEAN
+# define WIN32_LEAN_AND_MEAN
+# endif
+
+ // for wcstombs
+# ifndef _CRT_SECURE_NO_WARNINGS
+# define _CRT_SECURE_NO_WARNINGS
+# endif
+
+ // None of this happens with the MS SDK (at least VS14 which I tested), but:
+ // Compiling with mingw, we get "error: 'KF_FLAG_DEFAULT' was not declared in this scope."
+ // and error: 'SHGetKnownFolderPath' was not declared in this scope.".
+ // It seems when using mingw NTDDI_VERSION is undefined and that
+ // causes KNOWN_FOLDER_FLAG and the KF_ flags to not get defined.
+ // So we must define NTDDI_VERSION to get those flags on mingw.
+ // The docs say though here:
+ // https://msdn.microsoft.com/en-nz/library/windows/desktop/aa383745(v=vs.85).aspx
+ // that "If you define NTDDI_VERSION, you must also define _WIN32_WINNT."
+ // So we declare we require Vista or greater.
+# ifdef __MINGW32__
+
+# ifndef NTDDI_VERSION
+# define NTDDI_VERSION 0x06000000
+# define _WIN32_WINNT _WIN32_WINNT_VISTA
+# elif NTDDI_VERSION < 0x06000000
+# warning "If this fails to compile NTDDI_VERSION may be to low. See comments above."
+# endif
+ // But once we define the values above we then get this linker error:
+ // "tz.cpp:(.rdata$.refptr.FOLDERID_Downloads[.refptr.FOLDERID_Downloads]+0x0): "
+ // "undefined reference to `FOLDERID_Downloads'"
+ // which #include <initguid.h> cures see:
+ // https://support.microsoft.com/en-us/kb/130869
+# include <initguid.h>
+ // But with <initguid.h> included, the error moves on to:
+ // error: 'FOLDERID_Downloads' was not declared in this scope
+ // Which #include <knownfolders.h> cures.
+# include <knownfolders.h>
+
+# endif // __MINGW32__
+
+# include <windows.h>
+#endif // _WIN32
+
+#include "tz_private.h"
+
+#ifdef __APPLE__
+# include "ios.h"
+#else
+# define TARGET_OS_IPHONE 0
+# define TARGET_OS_SIMULATOR 0
+#endif
+
+#if USE_OS_TZDB
+# include <dirent.h>
+#endif
+#include <algorithm>
+#include <cctype>
+#include <cstdlib>
+#include <cstring>
+#include <cwchar>
+#include <exception>
+#include <fstream>
+#include <iostream>
+#include <iterator>
+#include <memory>
+#if USE_OS_TZDB
+# include <queue>
+#endif
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <vector>
+#include <sys/stat.h>
+
+// unistd.h is used on some platforms as part of the the means to get
+// the current time zone. On Win32 windows.h provides a means to do it.
+// gcc/mingw supports unistd.h on Win32 but MSVC does not.
+
+#ifdef _WIN32
+# ifdef WINAPI_FAMILY
+# include <winapifamily.h>
+# if WINAPI_FAMILY != WINAPI_FAMILY_DESKTOP_APP
+# define WINRT
+# define INSTALL .
+# endif
+# endif
+
+# include <io.h> // _unlink etc.
+
+# if defined(__clang__)
+ struct IUnknown; // fix for issue with static_cast<> in objbase.h
+ // (see https://github.com/philsquared/Catch/issues/690)
+# endif
+
+# include <shlobj.h> // CoTaskFree, ShGetKnownFolderPath etc.
+# if HAS_REMOTE_API
+# include <direct.h> // _mkdir
+# include <shellapi.h> // ShFileOperation etc.
+# endif // HAS_REMOTE_API
+#else // !_WIN32
+# include <unistd.h>
+# if !USE_OS_TZDB && !defined(INSTALL)
+# include <wordexp.h>
+# endif
+# include <limits.h>
+# include <string.h>
+# if !USE_SHELL_API
+# include <sys/stat.h>
+# include <sys/fcntl.h>
+# include <dirent.h>
+# include <cstring>
+# include <sys/wait.h>
+# include <sys/types.h>
+# endif //!USE_SHELL_API
+#endif // !_WIN32
+
+
+#if HAS_REMOTE_API
+ // Note curl includes windows.h so we must include curl AFTER definitions of things
+ // that affect windows.h such as NOMINMAX.
+#if defined(_MSC_VER) && defined(SHORTENED_CURL_INCLUDE)
+ // For rmt_curl nuget package
+# include <curl.h>
+#else
+# include <curl/curl.h>
+#endif
+#endif
+
+#ifdef _WIN32
+static CONSTDATA char folder_delimiter = '\\';
+#else // !_WIN32
+static CONSTDATA char folder_delimiter = '/';
+#endif // !_WIN32
+
+#if defined(__GNUC__) && __GNUC__ < 5
+ // GCC 4.9 Bug 61489 Wrong warning with -Wmissing-field-initializers
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wmissing-field-initializers"
+#endif // defined(__GNUC__) && __GNUC__ < 5
+
+#if !USE_OS_TZDB
+
+# ifdef _WIN32
+# ifndef WINRT
+
+namespace
+{
+ struct task_mem_deleter
+ {
+ void operator()(wchar_t buf[])
+ {
+ if (buf != nullptr)
+ CoTaskMemFree(buf);
+ }
+ };
+ using co_task_mem_ptr = std::unique_ptr<wchar_t[], task_mem_deleter>;
+}
+
+// We might need to know certain locations even if not using the remote API,
+// so keep these routines out of that block for now.
+static
+std::string
+get_known_folder(const GUID& folderid)
+{
+ std::string folder;
+ PWSTR pfolder = nullptr;
+ HRESULT hr = SHGetKnownFolderPath(folderid, KF_FLAG_DEFAULT, nullptr, &pfolder);
+ if (SUCCEEDED(hr))
+ {
+ co_task_mem_ptr folder_ptr(pfolder);
+ const wchar_t* fptr = folder_ptr.get();
+ auto state = std::mbstate_t();
+ const auto required = std::wcsrtombs(nullptr, &fptr, 0, &state);
+ if (required != 0 && required != std::size_t(-1))
+ {
+ folder.resize(required);
+ std::wcsrtombs(&folder[0], &fptr, folder.size(), &state);
+ }
+ }
+ return folder;
+}
+
+# ifndef INSTALL
+
+// Usually something like "c:\Users\username\Downloads".
+static
+std::string
+get_download_folder()
+{
+ return get_known_folder(FOLDERID_Downloads);
+}
+
+# endif // !INSTALL
+
+# endif // WINRT
+# else // !_WIN32
+
+# if !defined(INSTALL)
+
+static
+std::string
+expand_path(std::string path)
+{
+# if TARGET_OS_IPHONE
+ return date::iOSUtils::get_tzdata_path();
+# else // !TARGET_OS_IPHONE
+ ::wordexp_t w{};
+ std::unique_ptr<::wordexp_t, void(*)(::wordexp_t*)> hold{&w, ::wordfree};
+ ::wordexp(path.c_str(), &w, 0);
+ if (w.we_wordc != 1)
+ throw std::runtime_error("Cannot expand path: " + path);
+ path = w.we_wordv[0];
+ return path;
+# endif // !TARGET_OS_IPHONE
+}
+
+static
+std::string
+get_download_folder()
+{
+ return expand_path("~/Downloads");
+}
+
+# endif // !defined(INSTALL)
+
+# endif // !_WIN32
+
+#endif // !USE_OS_TZDB
+
+namespace arrow_vendored
+{
+namespace date
+{
+// +---------------------+
+// | Begin Configuration |
+// +---------------------+
+
+using namespace detail;
+
+#if !USE_OS_TZDB
+
+static
+std::string&
+access_install()
+{
+ static std::string install
+#ifndef INSTALL
+
+ = get_download_folder() + folder_delimiter + "tzdata";
+
+#else // !INSTALL
+
+# define STRINGIZEIMP(x) #x
+# define STRINGIZE(x) STRINGIZEIMP(x)
+
+ = STRINGIZE(INSTALL) + std::string(1, folder_delimiter) + "tzdata";
+
+ #undef STRINGIZEIMP
+ #undef STRINGIZE
+#endif // !INSTALL
+
+ return install;
+}
+
+void
+set_install(const std::string& s)
+{
+ access_install() = s;
+}
+
+static
+const std::string&
+get_install()
+{
+ static const std::string& ref = access_install();
+ return ref;
+}
+
+#if HAS_REMOTE_API
+static
+std::string
+get_download_gz_file(const std::string& version)
+{
+ auto file = get_install() + version + ".tar.gz";
+ return file;
+}
+#endif // HAS_REMOTE_API
+
+#endif // !USE_OS_TZDB
+
+// These can be used to reduce the range of the database to save memory
+CONSTDATA auto min_year = date::year::min();
+CONSTDATA auto max_year = date::year::max();
+
+CONSTDATA auto min_day = date::January/1;
+CONSTDATA auto max_day = date::December/31;
+
+#if USE_OS_TZDB
+
+CONSTCD14 const sys_seconds min_seconds = sys_days(min_year/min_day);
+
+#endif // USE_OS_TZDB
+
+#ifndef _WIN32
+
+static
+std::string
+discover_tz_dir()
+{
+ struct stat sb;
+ using namespace std;
+# ifndef __APPLE__
+ CONSTDATA auto tz_dir_default = "/usr/share/zoneinfo";
+ CONSTDATA auto tz_dir_buildroot = "/usr/share/zoneinfo/uclibc";
+
+ // Check special path which is valid for buildroot with uclibc builds
+ if(stat(tz_dir_buildroot, &sb) == 0 && S_ISDIR(sb.st_mode))
+ return tz_dir_buildroot;
+ else if(stat(tz_dir_default, &sb) == 0 && S_ISDIR(sb.st_mode))
+ return tz_dir_default;
+ else
+ throw runtime_error("discover_tz_dir failed to find zoneinfo\n");
+# else // __APPLE__
+# if TARGET_OS_IPHONE
+# if TARGET_OS_SIMULATOR
+ return "/usr/share/zoneinfo";
+# else
+ return "/var/db/timezone/zoneinfo";
+# endif
+# else
+ CONSTDATA auto timezone = "/etc/localtime";
+ if (!(lstat(timezone, &sb) == 0 && S_ISLNK(sb.st_mode) && sb.st_size > 0))
+ throw runtime_error("discover_tz_dir failed\n");
+ string result;
+ char rp[PATH_MAX+1] = {};
+ if (readlink(timezone, rp, sizeof(rp)-1) > 0)
+ result = string(rp);
+ else
+ throw system_error(errno, system_category(), "readlink() failed");
+ auto i = result.find("zoneinfo");
+ if (i == string::npos)
+ throw runtime_error("discover_tz_dir failed to find zoneinfo\n");
+ i = result.find('/', i);
+ if (i == string::npos)
+ throw runtime_error("discover_tz_dir failed to find '/'\n");
+ return result.substr(0, i);
+# endif
+# endif // __APPLE__
+}
+
+static
+const std::string&
+get_tz_dir()
+{
+ static const std::string tz_dir = discover_tz_dir();
+ return tz_dir;
+}
+
+#endif
+
+// +-------------------+
+// | End Configuration |
+// +-------------------+
+
+#ifndef _MSC_VER
+static_assert(min_year <= max_year, "Configuration error");
+#endif
+
+static std::unique_ptr<tzdb> init_tzdb();
+
+tzdb_list::~tzdb_list()
+{
+ const tzdb* ptr = head_;
+ head_ = nullptr;
+ while (ptr != nullptr)
+ {
+ auto next = ptr->next;
+ delete ptr;
+ ptr = next;
+ }
+}
+
+tzdb_list::tzdb_list(tzdb_list&& x) NOEXCEPT
+ : head_{x.head_.exchange(nullptr)}
+{
+}
+
+void
+tzdb_list::push_front(tzdb* tzdb) NOEXCEPT
+{
+ tzdb->next = head_;
+ head_ = tzdb;
+}
+
+tzdb_list::const_iterator
+tzdb_list::erase_after(const_iterator p) NOEXCEPT
+{
+ auto t = p.p_->next;
+ p.p_->next = p.p_->next->next;
+ delete t;
+ return ++p;
+}
+
+struct tzdb_list::undocumented_helper
+{
+ static void push_front(tzdb_list& db_list, tzdb* tzdb) NOEXCEPT
+ {
+ db_list.push_front(tzdb);
+ }
+};
+
+static
+tzdb_list
+create_tzdb()
+{
+ tzdb_list tz_db;
+ tzdb_list::undocumented_helper::push_front(tz_db, init_tzdb().release());
+ return tz_db;
+}
+
+tzdb_list&
+get_tzdb_list()
+{
+ static tzdb_list tz_db = create_tzdb();
+ return tz_db;
+}
+
+static
+std::string
+parse3(std::istream& in)
+{
+ std::string r(3, ' ');
+ ws(in);
+ r[0] = static_cast<char>(in.get());
+ r[1] = static_cast<char>(in.get());
+ r[2] = static_cast<char>(in.get());
+ return r;
+}
+
+static
+unsigned
+parse_month(std::istream& in)
+{
+ CONSTDATA char*const month_names[] =
+ {"Jan", "Feb", "Mar", "Apr", "May", "Jun",
+ "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"};
+ auto s = parse3(in);
+ auto m = std::find(std::begin(month_names), std::end(month_names), s) - month_names;
+ if (m >= std::end(month_names) - std::begin(month_names))
+ throw std::runtime_error("oops: bad month name: " + s);
+ return static_cast<unsigned>(++m);
+}
+
+#if !USE_OS_TZDB
+
+#ifdef _WIN32
+
+static
+void
+sort_zone_mappings(std::vector<date::detail::timezone_mapping>& mappings)
+{
+ std::sort(mappings.begin(), mappings.end(),
+ [](const date::detail::timezone_mapping& lhs,
+ const date::detail::timezone_mapping& rhs)->bool
+ {
+ auto other_result = lhs.other.compare(rhs.other);
+ if (other_result < 0)
+ return true;
+ else if (other_result == 0)
+ {
+ auto territory_result = lhs.territory.compare(rhs.territory);
+ if (territory_result < 0)
+ return true;
+ else if (territory_result == 0)
+ {
+ if (lhs.type < rhs.type)
+ return true;
+ }
+ }
+ return false;
+ });
+}
+
+static
+bool
+native_to_standard_timezone_name(const std::string& native_tz_name,
+ std::string& standard_tz_name)
+{
+ // TOOD! Need be a case insensitive compare?
+ if (native_tz_name == "UTC")
+ {
+ standard_tz_name = "Etc/UTC";
+ return true;
+ }
+ standard_tz_name.clear();
+ // TODO! we can improve on linear search.
+ const auto& mappings = date::get_tzdb().mappings;
+ for (const auto& tzm : mappings)
+ {
+ if (tzm.other == native_tz_name)
+ {
+ standard_tz_name = tzm.type;
+ return true;
+ }
+ }
+ return false;
+}
+
+// Parse this XML file:
+// https://raw.githubusercontent.com/unicode-org/cldr/master/common/supplemental/windowsZones.xml
+// The parsing method is designed to be simple and quick. It is not overly
+// forgiving of change but it should diagnose basic format issues.
+// See timezone_mapping structure for more info.
+static
+std::vector<detail::timezone_mapping>
+load_timezone_mappings_from_xml_file(const std::string& input_path)
+{
+ std::size_t line_num = 0;
+ std::vector<detail::timezone_mapping> mappings;
+ std::string line;
+
+ std::ifstream is(input_path);
+ if (!is.is_open())
+ {
+ // We don't emit file exceptions because that's an implementation detail.
+ std::string msg = "Error opening time zone mapping file \"";
+ msg += input_path;
+ msg += "\".";
+ throw std::runtime_error(msg);
+ }
+
+ auto error = [&input_path, &line_num](const char* info)
+ {
+ std::string msg = "Error loading time zone mapping file \"";
+ msg += input_path;
+ msg += "\" at line ";
+ msg += std::to_string(line_num);
+ msg += ": ";
+ msg += info;
+ throw std::runtime_error(msg);
+ };
+ // [optional space]a="b"
+ auto read_attribute = [&line, &error]
+ (const char* name, std::string& value, std::size_t startPos)
+ ->std::size_t
+ {
+ value.clear();
+ // Skip leading space before attribute name.
+ std::size_t spos = line.find_first_not_of(' ', startPos);
+ if (spos == std::string::npos)
+ spos = startPos;
+ // Assume everything up to next = is the attribute name
+ // and that an = will always delimit that.
+ std::size_t epos = line.find('=', spos);
+ if (epos == std::string::npos)
+ error("Expected \'=\' right after attribute name.");
+ std::size_t name_len = epos - spos;
+ // Expect the name we find matches the name we expect.
+ if (line.compare(spos, name_len, name) != 0)
+ {
+ std::string msg;
+ msg = "Expected attribute name \'";
+ msg += name;
+ msg += "\' around position ";
+ msg += std::to_string(spos);
+ msg += " but found something else.";
+ error(msg.c_str());
+ }
+ ++epos; // Skip the '=' that is after the attribute name.
+ spos = epos;
+ if (spos < line.length() && line[spos] == '\"')
+ ++spos; // Skip the quote that is before the attribute value.
+ else
+ {
+ std::string msg = "Expected '\"' to begin value of attribute \'";
+ msg += name;
+ msg += "\'.";
+ error(msg.c_str());
+ }
+ epos = line.find('\"', spos);
+ if (epos == std::string::npos)
+ {
+ std::string msg = "Expected '\"' to end value of attribute \'";
+ msg += name;
+ msg += "\'.";
+ error(msg.c_str());
+ }
+ // Extract everything in between the quotes. Note no escaping is done.
+ std::size_t value_len = epos - spos;
+ value.assign(line, spos, value_len);
+ ++epos; // Skip the quote that is after the attribute value;
+ return epos;
+ };
+
+ // Quick but not overly forgiving XML mapping file processing.
+ bool mapTimezonesOpenTagFound = false;
+ bool mapTimezonesCloseTagFound = false;
+ std::size_t mapZonePos = std::string::npos;
+ std::size_t mapTimezonesPos = std::string::npos;
+ CONSTDATA char mapTimeZonesOpeningTag[] = { "<mapTimezones " };
+ CONSTDATA char mapZoneOpeningTag[] = { "<mapZone " };
+ CONSTDATA std::size_t mapZoneOpeningTagLen = sizeof(mapZoneOpeningTag) /
+ sizeof(mapZoneOpeningTag[0]) - 1;
+ while (!mapTimezonesOpenTagFound)
+ {
+ std::getline(is, line);
+ ++line_num;
+ if (is.eof())
+ {
+ // If there is no mapTimezones tag is it an error?
+ // Perhaps if there are no mapZone mappings it might be ok for
+ // its parent mapTimezones element to be missing?
+ // We treat this as an error though on the assumption that if there
+ // really are no mappings we should still get a mapTimezones parent
+ // element but no mapZone elements inside. Assuming we must
+ // find something will hopefully at least catch more drastic formatting
+ // changes or errors than if we don't do this and assume nothing found.
+ error("Expected a mapTimezones opening tag.");
+ }
+ mapTimezonesPos = line.find(mapTimeZonesOpeningTag);
+ mapTimezonesOpenTagFound = (mapTimezonesPos != std::string::npos);
+ }
+
+ // NOTE: We could extract the version info that follows the opening
+ // mapTimezones tag and compare that to the version of other data we have.
+ // I would have expected them to be kept in synch but testing has shown
+ // it typically does not match anyway. So what's the point?
+ while (!mapTimezonesCloseTagFound)
+ {
+ std::ws(is);
+ std::getline(is, line);
+ ++line_num;
+ if (is.eof())
+ error("Expected a mapTimezones closing tag.");
+ if (line.empty())
+ continue;
+ mapZonePos = line.find(mapZoneOpeningTag);
+ if (mapZonePos != std::string::npos)
+ {
+ mapZonePos += mapZoneOpeningTagLen;
+ detail::timezone_mapping zm{};
+ std::size_t pos = read_attribute("other", zm.other, mapZonePos);
+ pos = read_attribute("territory", zm.territory, pos);
+ read_attribute("type", zm.type, pos);
+ mappings.push_back(std::move(zm));
+
+ continue;
+ }
+ mapTimezonesPos = line.find("</mapTimezones>");
+ mapTimezonesCloseTagFound = (mapTimezonesPos != std::string::npos);
+ if (!mapTimezonesCloseTagFound)
+ {
+ std::size_t commentPos = line.find("<!--");
+ if (commentPos == std::string::npos)
+ error("Unexpected mapping record found. A xml mapZone or comment "
+ "attribute or mapTimezones closing tag was expected.");
+ }
+ }
+
+ is.close();
+ return mappings;
+}
+
+#endif // _WIN32
+
+// Parsing helpers
+
+static
+unsigned
+parse_dow(std::istream& in)
+{
+ CONSTDATA char*const dow_names[] =
+ {"Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"};
+ auto s = parse3(in);
+ auto dow = std::find(std::begin(dow_names), std::end(dow_names), s) - dow_names;
+ if (dow >= std::end(dow_names) - std::begin(dow_names))
+ throw std::runtime_error("oops: bad dow name: " + s);
+ return static_cast<unsigned>(dow);
+}
+
+static
+std::chrono::seconds
+parse_unsigned_time(std::istream& in)
+{
+ using namespace std::chrono;
+ int x;
+ in >> x;
+ auto r = seconds{hours{x}};
+ if (!in.eof() && in.peek() == ':')
+ {
+ in.get();
+ in >> x;
+ r += minutes{x};
+ if (!in.eof() && in.peek() == ':')
+ {
+ in.get();
+ in >> x;
+ r += seconds{x};
+ }
+ }
+ return r;
+}
+
+static
+std::chrono::seconds
+parse_signed_time(std::istream& in)
+{
+ ws(in);
+ auto sign = 1;
+ if (in.peek() == '-')
+ {
+ sign = -1;
+ in.get();
+ }
+ else if (in.peek() == '+')
+ in.get();
+ return sign * parse_unsigned_time(in);
+}
+
+// MonthDayTime
+
+detail::MonthDayTime::MonthDayTime(local_seconds tp, tz timezone)
+ : zone_(timezone)
+{
+ using namespace date;
+ const auto dp = date::floor<days>(tp);
+ const auto hms = make_time(tp - dp);
+ const auto ymd = year_month_day(dp);
+ u = ymd.month() / ymd.day();
+ h_ = hms.hours();
+ m_ = hms.minutes();
+ s_ = hms.seconds();
+}
+
+detail::MonthDayTime::MonthDayTime(const date::month_day& md, tz timezone)
+ : zone_(timezone)
+{
+ u = md;
+}
+
+date::day
+detail::MonthDayTime::day() const
+{
+ switch (type_)
+ {
+ case month_day:
+ return u.month_day_.day();
+ case month_last_dow:
+ return date::day{31};
+ case lteq:
+ case gteq:
+ break;
+ }
+ return u.month_day_weekday_.month_day_.day();
+}
+
+date::month
+detail::MonthDayTime::month() const
+{
+ switch (type_)
+ {
+ case month_day:
+ return u.month_day_.month();
+ case month_last_dow:
+ return u.month_weekday_last_.month();
+ case lteq:
+ case gteq:
+ break;
+ }
+ return u.month_day_weekday_.month_day_.month();
+}
+
+int
+detail::MonthDayTime::compare(date::year y, const MonthDayTime& x, date::year yx,
+ std::chrono::seconds offset, std::chrono::minutes prev_save) const
+{
+ if (zone_ != x.zone_)
+ {
+ auto dp0 = to_sys_days(y);
+ auto dp1 = x.to_sys_days(yx);
+ if (std::abs((dp0-dp1).count()) > 1)
+ return dp0 < dp1 ? -1 : 1;
+ if (zone_ == tz::local)
+ {
+ auto tp0 = to_time_point(y) - prev_save;
+ if (x.zone_ == tz::utc)
+ tp0 -= offset;
+ auto tp1 = x.to_time_point(yx);
+ return tp0 < tp1 ? -1 : tp0 == tp1 ? 0 : 1;
+ }
+ else if (zone_ == tz::standard)
+ {
+ auto tp0 = to_time_point(y);
+ auto tp1 = x.to_time_point(yx);
+ if (x.zone_ == tz::local)
+ tp1 -= prev_save;
+ else
+ tp0 -= offset;
+ return tp0 < tp1 ? -1 : tp0 == tp1 ? 0 : 1;
+ }
+ // zone_ == tz::utc
+ auto tp0 = to_time_point(y);
+ auto tp1 = x.to_time_point(yx);
+ if (x.zone_ == tz::local)
+ tp1 -= offset + prev_save;
+ else
+ tp1 -= offset;
+ return tp0 < tp1 ? -1 : tp0 == tp1 ? 0 : 1;
+ }
+ auto const t0 = to_time_point(y);
+ auto const t1 = x.to_time_point(yx);
+ return t0 < t1 ? -1 : t0 == t1 ? 0 : 1;
+}
+
+sys_seconds
+detail::MonthDayTime::to_sys(date::year y, std::chrono::seconds offset,
+ std::chrono::seconds save) const
+{
+ using namespace date;
+ using namespace std::chrono;
+ auto until_utc = to_time_point(y);
+ if (zone_ == tz::standard)
+ until_utc -= offset;
+ else if (zone_ == tz::local)
+ until_utc -= offset + save;
+ return until_utc;
+}
+
+detail::MonthDayTime::U&
+detail::MonthDayTime::U::operator=(const date::month_day& x)
+{
+ month_day_ = x;
+ return *this;
+}
+
+detail::MonthDayTime::U&
+detail::MonthDayTime::U::operator=(const date::month_weekday_last& x)
+{
+ month_weekday_last_ = x;
+ return *this;
+}
+
+detail::MonthDayTime::U&
+detail::MonthDayTime::U::operator=(const pair& x)
+{
+ month_day_weekday_ = x;
+ return *this;
+}
+
+date::sys_days
+detail::MonthDayTime::to_sys_days(date::year y) const
+{
+ using namespace std::chrono;
+ using namespace date;
+ switch (type_)
+ {
+ case month_day:
+ return sys_days(y/u.month_day_);
+ case month_last_dow:
+ return sys_days(y/u.month_weekday_last_);
+ case lteq:
+ {
+ auto const x = y/u.month_day_weekday_.month_day_;
+ auto const wd1 = weekday(static_cast<sys_days>(x));
+ auto const wd0 = u.month_day_weekday_.weekday_;
+ return sys_days(x) - (wd1-wd0);
+ }
+ case gteq:
+ break;
+ }
+ auto const x = y/u.month_day_weekday_.month_day_;
+ auto const wd1 = u.month_day_weekday_.weekday_;
+ auto const wd0 = weekday(static_cast<sys_days>(x));
+ return sys_days(x) + (wd1-wd0);
+}
+
+sys_seconds
+detail::MonthDayTime::to_time_point(date::year y) const
+{
+ // Add seconds first to promote to largest rep early to prevent overflow
+ return to_sys_days(y) + s_ + h_ + m_;
+}
+
+void
+detail::MonthDayTime::canonicalize(date::year y)
+{
+ using namespace std::chrono;
+ using namespace date;
+ switch (type_)
+ {
+ case month_day:
+ return;
+ case month_last_dow:
+ {
+ auto const ymd = year_month_day(sys_days(y/u.month_weekday_last_));
+ u.month_day_ = ymd.month()/ymd.day();
+ type_ = month_day;
+ return;
+ }
+ case lteq:
+ {
+ auto const x = y/u.month_day_weekday_.month_day_;
+ auto const wd1 = weekday(static_cast<sys_days>(x));
+ auto const wd0 = u.month_day_weekday_.weekday_;
+ auto const ymd = year_month_day(sys_days(x) - (wd1-wd0));
+ u.month_day_ = ymd.month()/ymd.day();
+ type_ = month_day;
+ return;
+ }
+ case gteq:
+ {
+ auto const x = y/u.month_day_weekday_.month_day_;
+ auto const wd1 = u.month_day_weekday_.weekday_;
+ auto const wd0 = weekday(static_cast<sys_days>(x));
+ auto const ymd = year_month_day(sys_days(x) + (wd1-wd0));
+ u.month_day_ = ymd.month()/ymd.day();
+ type_ = month_day;
+ return;
+ }
+ }
+}
+
+std::istream&
+detail::operator>>(std::istream& is, MonthDayTime& x)
+{
+ using namespace date;
+ using namespace std::chrono;
+ assert(((std::ios::failbit | std::ios::badbit) & is.exceptions()) ==
+ (std::ios::failbit | std::ios::badbit));
+ x = MonthDayTime{};
+ if (!is.eof() && ws(is) && !is.eof() && is.peek() != '#')
+ {
+ auto m = parse_month(is);
+ if (!is.eof() && ws(is) && !is.eof() && is.peek() != '#')
+ {
+ if (is.peek() == 'l')
+ {
+ for (int i = 0; i < 4; ++i)
+ is.get();
+ auto dow = parse_dow(is);
+ x.type_ = MonthDayTime::month_last_dow;
+ x.u = date::month(m)/weekday(dow)[last];
+ }
+ else if (std::isalpha(is.peek()))
+ {
+ auto dow = parse_dow(is);
+ char c{};
+ is >> c;
+ if (c == '<' || c == '>')
+ {
+ char c2{};
+ is >> c2;
+ if (c2 != '=')
+ throw std::runtime_error(std::string("bad operator: ") + c + c2);
+ int d;
+ is >> d;
+ if (d < 1 || d > 31)
+ throw std::runtime_error(std::string("bad operator: ") + c + c2
+ + std::to_string(d));
+ x.type_ = c == '<' ? MonthDayTime::lteq : MonthDayTime::gteq;
+ x.u = MonthDayTime::pair{ date::month(m) / d, date::weekday(dow) };
+ }
+ else
+ throw std::runtime_error(std::string("bad operator: ") + c);
+ }
+ else // if (std::isdigit(is.peek())
+ {
+ int d;
+ is >> d;
+ if (d < 1 || d > 31)
+ throw std::runtime_error(std::string("day of month: ")
+ + std::to_string(d));
+ x.type_ = MonthDayTime::month_day;
+ x.u = date::month(m)/d;
+ }
+ if (!is.eof() && ws(is) && !is.eof() && is.peek() != '#')
+ {
+ int t;
+ is >> t;
+ x.h_ = hours{t};
+ if (!is.eof() && is.peek() == ':')
+ {
+ is.get();
+ is >> t;
+ x.m_ = minutes{t};
+ if (!is.eof() && is.peek() == ':')
+ {
+ is.get();
+ is >> t;
+ x.s_ = seconds{t};
+ }
+ }
+ if (!is.eof() && std::isalpha(is.peek()))
+ {
+ char c;
+ is >> c;
+ switch (c)
+ {
+ case 's':
+ x.zone_ = tz::standard;
+ break;
+ case 'u':
+ x.zone_ = tz::utc;
+ break;
+ }
+ }
+ }
+ }
+ else
+ {
+ x.u = month{m}/1;
+ }
+ }
+ return is;
+}
+
+std::ostream&
+detail::operator<<(std::ostream& os, const MonthDayTime& x)
+{
+ switch (x.type_)
+ {
+ case MonthDayTime::month_day:
+ os << x.u.month_day_ << " ";
+ break;
+ case MonthDayTime::month_last_dow:
+ os << x.u.month_weekday_last_ << " ";
+ break;
+ case MonthDayTime::lteq:
+ os << x.u.month_day_weekday_.weekday_ << " on or before "
+ << x.u.month_day_weekday_.month_day_ << " ";
+ break;
+ case MonthDayTime::gteq:
+ if ((static_cast<unsigned>(x.day()) - 1) % 7 == 0)
+ {
+ os << (x.u.month_day_weekday_.month_day_.month() /
+ x.u.month_day_weekday_.weekday_[
+ (static_cast<unsigned>(x.day()) - 1)/7+1]) << " ";
+ }
+ else
+ {
+ os << x.u.month_day_weekday_.weekday_ << " on or after "
+ << x.u.month_day_weekday_.month_day_ << " ";
+ }
+ break;
+ }
+ os << date::make_time(x.s_ + x.h_ + x.m_);
+ if (x.zone_ == tz::utc)
+ os << "UTC ";
+ else if (x.zone_ == tz::standard)
+ os << "STD ";
+ else
+ os << " ";
+ return os;
+}
+
+// Rule
+
+detail::Rule::Rule(const std::string& s)
+{
+ try
+ {
+ using namespace date;
+ using namespace std::chrono;
+ std::istringstream in(s);
+ in.exceptions(std::ios::failbit | std::ios::badbit);
+ std::string word;
+ in >> word >> name_;
+ int x;
+ ws(in);
+ if (std::isalpha(in.peek()))
+ {
+ in >> word;
+ if (word == "min")
+ {
+ starting_year_ = year::min();
+ }
+ else
+ throw std::runtime_error("Didn't find expected word: " + word);
+ }
+ else
+ {
+ in >> x;
+ starting_year_ = year{x};
+ }
+ std::ws(in);
+ if (std::isalpha(in.peek()))
+ {
+ in >> word;
+ if (word == "only")
+ {
+ ending_year_ = starting_year_;
+ }
+ else if (word == "max")
+ {
+ ending_year_ = year::max();
+ }
+ else
+ throw std::runtime_error("Didn't find expected word: " + word);
+ }
+ else
+ {
+ in >> x;
+ ending_year_ = year{x};
+ }
+ in >> word; // TYPE (always "-")
+ assert(word == "-");
+ in >> starting_at_;
+ save_ = duration_cast<minutes>(parse_signed_time(in));
+ in >> abbrev_;
+ if (abbrev_ == "-")
+ abbrev_.clear();
+ assert(hours{-1} <= save_ && save_ <= hours{2});
+ }
+ catch (...)
+ {
+ std::cerr << s << '\n';
+ std::cerr << *this << '\n';
+ throw;
+ }
+}
+
+detail::Rule::Rule(const Rule& r, date::year starting_year, date::year ending_year)
+ : name_(r.name_)
+ , starting_year_(starting_year)
+ , ending_year_(ending_year)
+ , starting_at_(r.starting_at_)
+ , save_(r.save_)
+ , abbrev_(r.abbrev_)
+{
+}
+
+bool
+detail::operator==(const Rule& x, const Rule& y)
+{
+ if (std::tie(x.name_, x.save_, x.starting_year_, x.ending_year_) ==
+ std::tie(y.name_, y.save_, y.starting_year_, y.ending_year_))
+ return x.month() == y.month() && x.day() == y.day();
+ return false;
+}
+
+bool
+detail::operator<(const Rule& x, const Rule& y)
+{
+ using namespace std::chrono;
+ auto const xm = x.month();
+ auto const ym = y.month();
+ if (std::tie(x.name_, x.starting_year_, xm, x.ending_year_) <
+ std::tie(y.name_, y.starting_year_, ym, y.ending_year_))
+ return true;
+ if (std::tie(x.name_, x.starting_year_, xm, x.ending_year_) >
+ std::tie(y.name_, y.starting_year_, ym, y.ending_year_))
+ return false;
+ return x.day() < y.day();
+}
+
+bool
+detail::operator==(const Rule& x, const date::year& y)
+{
+ return x.starting_year_ <= y && y <= x.ending_year_;
+}
+
+bool
+detail::operator<(const Rule& x, const date::year& y)
+{
+ return x.ending_year_ < y;
+}
+
+bool
+detail::operator==(const date::year& x, const Rule& y)
+{
+ return y.starting_year_ <= x && x <= y.ending_year_;
+}
+
+bool
+detail::operator<(const date::year& x, const Rule& y)
+{
+ return x < y.starting_year_;
+}
+
+bool
+detail::operator==(const Rule& x, const std::string& y)
+{
+ return x.name() == y;
+}
+
+bool
+detail::operator<(const Rule& x, const std::string& y)
+{
+ return x.name() < y;
+}
+
+bool
+detail::operator==(const std::string& x, const Rule& y)
+{
+ return y.name() == x;
+}
+
+bool
+detail::operator<(const std::string& x, const Rule& y)
+{
+ return x < y.name();
+}
+
+std::ostream&
+detail::operator<<(std::ostream& os, const Rule& r)
+{
+ using namespace date;
+ using namespace std::chrono;
+ detail::save_ostream<char> _(os);
+ os.fill(' ');
+ os.flags(std::ios::dec | std::ios::left);
+ os.width(15);
+ os << r.name_;
+ os << r.starting_year_ << " " << r.ending_year_ << " ";
+ os << r.starting_at_;
+ if (r.save_ >= minutes{0})
+ os << ' ';
+ os << date::make_time(r.save_) << " ";
+ os << r.abbrev_;
+ return os;
+}
+
+date::day
+detail::Rule::day() const
+{
+ return starting_at_.day();
+}
+
+date::month
+detail::Rule::month() const
+{
+ return starting_at_.month();
+}
+
+struct find_rule_by_name
+{
+ bool operator()(const Rule& x, const std::string& nm) const
+ {
+ return x.name() < nm;
+ }
+
+ bool operator()(const std::string& nm, const Rule& x) const
+ {
+ return nm < x.name();
+ }
+};
+
+bool
+detail::Rule::overlaps(const Rule& x, const Rule& y)
+{
+ // assume x.starting_year_ <= y.starting_year_;
+ if (!(x.starting_year_ <= y.starting_year_))
+ {
+ std::cerr << x << '\n';
+ std::cerr << y << '\n';
+ assert(x.starting_year_ <= y.starting_year_);
+ }
+ if (y.starting_year_ > x.ending_year_)
+ return false;
+ return !(x.starting_year_ == y.starting_year_ && x.ending_year_ == y.ending_year_);
+}
+
+void
+detail::Rule::split(std::vector<Rule>& rules, std::size_t i, std::size_t k, std::size_t& e)
+{
+ using namespace date;
+ using difference_type = std::vector<Rule>::iterator::difference_type;
+ // rules[i].starting_year_ <= rules[k].starting_year_ &&
+ // rules[i].ending_year_ >= rules[k].starting_year_ &&
+ // (rules[i].starting_year_ != rules[k].starting_year_ ||
+ // rules[i].ending_year_ != rules[k].ending_year_)
+ assert(rules[i].starting_year_ <= rules[k].starting_year_ &&
+ rules[i].ending_year_ >= rules[k].starting_year_ &&
+ (rules[i].starting_year_ != rules[k].starting_year_ ||
+ rules[i].ending_year_ != rules[k].ending_year_));
+ if (rules[i].starting_year_ == rules[k].starting_year_)
+ {
+ if (rules[k].ending_year_ < rules[i].ending_year_)
+ {
+ rules.insert(rules.begin() + static_cast<difference_type>(k+1),
+ Rule(rules[i], rules[k].ending_year_ + years{1},
+ std::move(rules[i].ending_year_)));
+ ++e;
+ rules[i].ending_year_ = rules[k].ending_year_;
+ }
+ else // rules[k].ending_year_ > rules[i].ending_year_
+ {
+ rules.insert(rules.begin() + static_cast<difference_type>(k+1),
+ Rule(rules[k], rules[i].ending_year_ + years{1},
+ std::move(rules[k].ending_year_)));
+ ++e;
+ rules[k].ending_year_ = rules[i].ending_year_;
+ }
+ }
+ else // rules[i].starting_year_ < rules[k].starting_year_
+ {
+ if (rules[k].ending_year_ < rules[i].ending_year_)
+ {
+ rules.insert(rules.begin() + static_cast<difference_type>(k),
+ Rule(rules[i], rules[k].starting_year_, rules[k].ending_year_));
+ ++k;
+ rules.insert(rules.begin() + static_cast<difference_type>(k+1),
+ Rule(rules[i], rules[k].ending_year_ + years{1},
+ std::move(rules[i].ending_year_)));
+ rules[i].ending_year_ = rules[k].starting_year_ - years{1};
+ e += 2;
+ }
+ else if (rules[k].ending_year_ > rules[i].ending_year_)
+ {
+ rules.insert(rules.begin() + static_cast<difference_type>(k),
+ Rule(rules[i], rules[k].starting_year_, rules[i].ending_year_));
+ ++k;
+ rules.insert(rules.begin() + static_cast<difference_type>(k+1),
+ Rule(rules[k], rules[i].ending_year_ + years{1},
+ std::move(rules[k].ending_year_)));
+ e += 2;
+ rules[k].ending_year_ = std::move(rules[i].ending_year_);
+ rules[i].ending_year_ = rules[k].starting_year_ - years{1};
+ }
+ else // rules[k].ending_year_ == rules[i].ending_year_
+ {
+ rules.insert(rules.begin() + static_cast<difference_type>(k),
+ Rule(rules[i], rules[k].starting_year_,
+ std::move(rules[i].ending_year_)));
+ ++k;
+ ++e;
+ rules[i].ending_year_ = rules[k].starting_year_ - years{1};
+ }
+ }
+}
+
+void
+detail::Rule::split_overlaps(std::vector<Rule>& rules, std::size_t i, std::size_t& e)
+{
+ using difference_type = std::vector<Rule>::iterator::difference_type;
+ auto j = i;
+ for (; i + 1 < e; ++i)
+ {
+ for (auto k = i + 1; k < e; ++k)
+ {
+ if (overlaps(rules[i], rules[k]))
+ {
+ split(rules, i, k, e);
+ std::sort(rules.begin() + static_cast<difference_type>(i),
+ rules.begin() + static_cast<difference_type>(e));
+ }
+ }
+ }
+ for (; j < e; ++j)
+ {
+ if (rules[j].starting_year() == rules[j].ending_year())
+ rules[j].starting_at_.canonicalize(rules[j].starting_year());
+ }
+}
+
+void
+detail::Rule::split_overlaps(std::vector<Rule>& rules)
+{
+ using difference_type = std::vector<Rule>::iterator::difference_type;
+ for (std::size_t i = 0; i < rules.size();)
+ {
+ auto e = static_cast<std::size_t>(std::upper_bound(
+ rules.cbegin()+static_cast<difference_type>(i), rules.cend(), rules[i].name(),
+ [](const std::string& nm, const Rule& x)
+ {
+ return nm < x.name();
+ }) - rules.cbegin());
+ split_overlaps(rules, i, e);
+ auto first_rule = rules.begin() + static_cast<difference_type>(i);
+ auto last_rule = rules.begin() + static_cast<difference_type>(e);
+ auto t = std::lower_bound(first_rule, last_rule, min_year);
+ if (t > first_rule+1)
+ {
+ if (t == last_rule || t->starting_year() >= min_year)
+ --t;
+ auto d = static_cast<std::size_t>(t - first_rule);
+ rules.erase(first_rule, t);
+ e -= d;
+ }
+ first_rule = rules.begin() + static_cast<difference_type>(i);
+ last_rule = rules.begin() + static_cast<difference_type>(e);
+ t = std::upper_bound(first_rule, last_rule, max_year);
+ if (t != last_rule)
+ {
+ auto d = static_cast<std::size_t>(last_rule - t);
+ rules.erase(t, last_rule);
+ e -= d;
+ }
+ i = e;
+ }
+ rules.shrink_to_fit();
+}
+
+// Find the rule that comes chronologically before Rule r. For multi-year rules,
+// y specifies which rules in r. For single year rules, y is assumed to be equal
+// to the year specified by r.
+// Returns a pointer to the chronologically previous rule, and the year within
+// that rule. If there is no previous rule, returns nullptr and year::min().
+// Preconditions:
+// r->starting_year() <= y && y <= r->ending_year()
+static
+std::pair<const Rule*, date::year>
+find_previous_rule(const Rule* r, date::year y)
+{
+ using namespace date;
+ auto const& rules = get_tzdb().rules;
+ if (y == r->starting_year())
+ {
+ if (r == &rules.front() || r->name() != r[-1].name())
+ std::terminate(); // never called with first rule
+ --r;
+ if (y == r->starting_year())
+ return {r, y};
+ return {r, r->ending_year()};
+ }
+ if (r == &rules.front() || r->name() != r[-1].name() ||
+ r[-1].starting_year() < r->starting_year())
+ {
+ while (r < &rules.back() && r->name() == r[1].name() &&
+ r->starting_year() == r[1].starting_year())
+ ++r;
+ return {r, --y};
+ }
+ --r;
+ return {r, y};
+}
+
+// Find the rule that comes chronologically after Rule r. For multi-year rules,
+// y specifies which rules in r. For single year rules, y is assumed to be equal
+// to the year specified by r.
+// Returns a pointer to the chronologically next rule, and the year within
+// that rule. If there is no next rule, return a pointer to a defaulted rule
+// and y+1.
+// Preconditions:
+// first <= r && r < last && r->starting_year() <= y && y <= r->ending_year()
+// [first, last) all have the same name
+static
+std::pair<const Rule*, date::year>
+find_next_rule(const Rule* first_rule, const Rule* last_rule, const Rule* r, date::year y)
+{
+ using namespace date;
+ if (y == r->ending_year())
+ {
+ if (r == last_rule-1)
+ return {nullptr, year::max()};
+ ++r;
+ if (y == r->ending_year())
+ return {r, y};
+ return {r, r->starting_year()};
+ }
+ if (r == last_rule-1 || r->ending_year() < r[1].ending_year())
+ {
+ while (r > first_rule && r->starting_year() == r[-1].starting_year())
+ --r;
+ return {r, ++y};
+ }
+ ++r;
+ return {r, y};
+}
+
+// Find the rule that comes chronologically after Rule r. For multi-year rules,
+// y specifies which rules in r. For single year rules, y is assumed to be equal
+// to the year specified by r.
+// Returns a pointer to the chronologically next rule, and the year within
+// that rule. If there is no next rule, return nullptr and year::max().
+// Preconditions:
+// r->starting_year() <= y && y <= r->ending_year()
+static
+std::pair<const Rule*, date::year>
+find_next_rule(const Rule* r, date::year y)
+{
+ using namespace date;
+ auto const& rules = get_tzdb().rules;
+ if (y == r->ending_year())
+ {
+ if (r == &rules.back() || r->name() != r[1].name())
+ return {nullptr, year::max()};
+ ++r;
+ if (y == r->ending_year())
+ return {r, y};
+ return {r, r->starting_year()};
+ }
+ if (r == &rules.back() || r->name() != r[1].name() ||
+ r->ending_year() < r[1].ending_year())
+ {
+ while (r > &rules.front() && r->name() == r[-1].name() &&
+ r->starting_year() == r[-1].starting_year())
+ --r;
+ return {r, ++y};
+ }
+ ++r;
+ return {r, y};
+}
+
+static
+const Rule*
+find_first_std_rule(const std::pair<const Rule*, const Rule*>& eqr)
+{
+ auto r = eqr.first;
+ auto ry = r->starting_year();
+ while (r->save() != std::chrono::minutes{0})
+ {
+ std::tie(r, ry) = find_next_rule(eqr.first, eqr.second, r, ry);
+ if (r == nullptr)
+ throw std::runtime_error("Could not find standard offset in rule "
+ + eqr.first->name());
+ }
+ return r;
+}
+
+static
+std::pair<const Rule*, date::year>
+find_rule_for_zone(const std::pair<const Rule*, const Rule*>& eqr,
+ const date::year& y, const std::chrono::seconds& offset,
+ const MonthDayTime& mdt)
+{
+ assert(eqr.first != nullptr);
+ assert(eqr.second != nullptr);
+
+ using namespace std::chrono;
+ using namespace date;
+ auto r = eqr.first;
+ auto ry = r->starting_year();
+ auto prev_save = minutes{0};
+ auto prev_year = year::min();
+ const Rule* prev_rule = nullptr;
+ while (r != nullptr)
+ {
+ if (mdt.compare(y, r->mdt(), ry, offset, prev_save) <= 0)
+ break;
+ prev_rule = r;
+ prev_year = ry;
+ prev_save = prev_rule->save();
+ std::tie(r, ry) = find_next_rule(eqr.first, eqr.second, r, ry);
+ }
+ return {prev_rule, prev_year};
+}
+
+static
+std::pair<const Rule*, date::year>
+find_rule_for_zone(const std::pair<const Rule*, const Rule*>& eqr,
+ const sys_seconds& tp_utc,
+ const local_seconds& tp_std,
+ const local_seconds& tp_loc)
+{
+ using namespace std::chrono;
+ using namespace date;
+ auto r = eqr.first;
+ auto ry = r->starting_year();
+ auto prev_save = minutes{0};
+ auto prev_year = year::min();
+ const Rule* prev_rule = nullptr;
+ while (r != nullptr)
+ {
+ bool found = false;
+ switch (r->mdt().zone())
+ {
+ case tz::utc:
+ found = tp_utc < r->mdt().to_time_point(ry);
+ break;
+ case tz::standard:
+ found = sys_seconds{tp_std.time_since_epoch()} < r->mdt().to_time_point(ry);
+ break;
+ case tz::local:
+ found = sys_seconds{tp_loc.time_since_epoch()} < r->mdt().to_time_point(ry);
+ break;
+ }
+ if (found)
+ break;
+ prev_rule = r;
+ prev_year = ry;
+ prev_save = prev_rule->save();
+ std::tie(r, ry) = find_next_rule(eqr.first, eqr.second, r, ry);
+ }
+ return {prev_rule, prev_year};
+}
+
+static
+sys_info
+find_rule(const std::pair<const Rule*, date::year>& first_rule,
+ const std::pair<const Rule*, date::year>& last_rule,
+ const date::year& y, const std::chrono::seconds& offset,
+ const MonthDayTime& mdt, const std::chrono::minutes& initial_save,
+ const std::string& initial_abbrev)
+{
+ using namespace std::chrono;
+ using namespace date;
+ auto r = first_rule.first;
+ auto ry = first_rule.second;
+ sys_info x{sys_days(year::min()/min_day), sys_days(year::max()/max_day),
+ seconds{0}, initial_save, initial_abbrev};
+ while (r != nullptr)
+ {
+ auto tr = r->mdt().to_sys(ry, offset, x.save);
+ auto tx = mdt.to_sys(y, offset, x.save);
+ // Find last rule where tx >= tr
+ if (tx <= tr || (r == last_rule.first && ry == last_rule.second))
+ {
+ if (tx < tr && r == first_rule.first && ry == first_rule.second)
+ {
+ x.end = r->mdt().to_sys(ry, offset, x.save);
+ break;
+ }
+ if (tx < tr)
+ {
+ std::tie(r, ry) = find_previous_rule(r, ry); // can't return nullptr for r
+ assert(r != nullptr);
+ }
+ // r != nullptr && tx >= tr (if tr were to be recomputed)
+ auto prev_save = initial_save;
+ if (!(r == first_rule.first && ry == first_rule.second))
+ prev_save = find_previous_rule(r, ry).first->save();
+ x.begin = r->mdt().to_sys(ry, offset, prev_save);
+ x.save = r->save();
+ x.abbrev = r->abbrev();
+ if (!(r == last_rule.first && ry == last_rule.second))
+ {
+ std::tie(r, ry) = find_next_rule(r, ry); // can't return nullptr for r
+ assert(r != nullptr);
+ x.end = r->mdt().to_sys(ry, offset, x.save);
+ }
+ else
+ x.end = sys_days(year::max()/max_day);
+ break;
+ }
+ x.save = r->save();
+ std::tie(r, ry) = find_next_rule(r, ry); // Can't return nullptr for r
+ assert(r != nullptr);
+ }
+ return x;
+}
+
+// zonelet
+
+detail::zonelet::~zonelet()
+{
+#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
+ using minutes = std::chrono::minutes;
+ using string = std::string;
+ if (tag_ == has_save)
+ u.save_.~minutes();
+ else
+ u.rule_.~string();
+#endif
+}
+
+detail::zonelet::zonelet()
+{
+#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
+ ::new(&u.rule_) std::string();
+#endif
+}
+
+detail::zonelet::zonelet(const zonelet& i)
+ : gmtoff_(i.gmtoff_)
+ , tag_(i.tag_)
+ , format_(i.format_)
+ , until_year_(i.until_year_)
+ , until_date_(i.until_date_)
+ , until_utc_(i.until_utc_)
+ , until_std_(i.until_std_)
+ , until_loc_(i.until_loc_)
+ , initial_save_(i.initial_save_)
+ , initial_abbrev_(i.initial_abbrev_)
+ , first_rule_(i.first_rule_)
+ , last_rule_(i.last_rule_)
+{
+#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
+ if (tag_ == has_save)
+ ::new(&u.save_) std::chrono::minutes(i.u.save_);
+ else
+ ::new(&u.rule_) std::string(i.u.rule_);
+#else
+ if (tag_ == has_save)
+ u.save_ = i.u.save_;
+ else
+ u.rule_ = i.u.rule_;
+#endif
+}
+
+#endif // !USE_OS_TZDB
+
+// time_zone
+
+#if USE_OS_TZDB
+
+time_zone::time_zone(const std::string& s, detail::undocumented)
+ : name_(s)
+ , adjusted_(new std::once_flag{})
+{
+}
+
+enum class endian
+{
+ native = __BYTE_ORDER__,
+ little = __ORDER_LITTLE_ENDIAN__,
+ big = __ORDER_BIG_ENDIAN__
+};
+
+static
+inline
+std::uint32_t
+reverse_bytes(std::uint32_t i)
+{
+ return
+ (i & 0xff000000u) >> 24 |
+ (i & 0x00ff0000u) >> 8 |
+ (i & 0x0000ff00u) << 8 |
+ (i & 0x000000ffu) << 24;
+}
+
+static
+inline
+std::uint64_t
+reverse_bytes(std::uint64_t i)
+{
+ return
+ (i & 0xff00000000000000ull) >> 56 |
+ (i & 0x00ff000000000000ull) >> 40 |
+ (i & 0x0000ff0000000000ull) >> 24 |
+ (i & 0x000000ff00000000ull) >> 8 |
+ (i & 0x00000000ff000000ull) << 8 |
+ (i & 0x0000000000ff0000ull) << 24 |
+ (i & 0x000000000000ff00ull) << 40 |
+ (i & 0x00000000000000ffull) << 56;
+}
+
+template <class T>
+static
+inline
+void
+maybe_reverse_bytes(T&, std::false_type)
+{
+}
+
+static
+inline
+void
+maybe_reverse_bytes(std::int32_t& t, std::true_type)
+{
+ t = static_cast<std::int32_t>(reverse_bytes(static_cast<std::uint32_t>(t)));
+}
+
+static
+inline
+void
+maybe_reverse_bytes(std::int64_t& t, std::true_type)
+{
+ t = static_cast<std::int64_t>(reverse_bytes(static_cast<std::uint64_t>(t)));
+}
+
+template <class T>
+static
+inline
+void
+maybe_reverse_bytes(T& t)
+{
+ maybe_reverse_bytes(t, std::integral_constant<bool,
+ endian::native == endian::little>{});
+}
+
+static
+void
+load_header(std::istream& inf)
+{
+ // Read TZif
+ auto t = inf.get();
+ auto z = inf.get();
+ auto i = inf.get();
+ auto f = inf.get();
+#ifndef NDEBUG
+ assert(t == 'T');
+ assert(z == 'Z');
+ assert(i == 'i');
+ assert(f == 'f');
+#else
+ (void)t;
+ (void)z;
+ (void)i;
+ (void)f;
+#endif
+}
+
+static
+unsigned char
+load_version(std::istream& inf)
+{
+ // Read version
+ auto v = inf.get();
+ assert(v != EOF);
+ return static_cast<unsigned char>(v);
+}
+
+static
+void
+skip_reserve(std::istream& inf)
+{
+ inf.ignore(15);
+}
+
+static
+void
+load_counts(std::istream& inf,
+ std::int32_t& tzh_ttisgmtcnt, std::int32_t& tzh_ttisstdcnt,
+ std::int32_t& tzh_leapcnt, std::int32_t& tzh_timecnt,
+ std::int32_t& tzh_typecnt, std::int32_t& tzh_charcnt)
+{
+ // Read counts;
+ inf.read(reinterpret_cast<char*>(&tzh_ttisgmtcnt), 4);
+ maybe_reverse_bytes(tzh_ttisgmtcnt);
+ inf.read(reinterpret_cast<char*>(&tzh_ttisstdcnt), 4);
+ maybe_reverse_bytes(tzh_ttisstdcnt);
+ inf.read(reinterpret_cast<char*>(&tzh_leapcnt), 4);
+ maybe_reverse_bytes(tzh_leapcnt);
+ inf.read(reinterpret_cast<char*>(&tzh_timecnt), 4);
+ maybe_reverse_bytes(tzh_timecnt);
+ inf.read(reinterpret_cast<char*>(&tzh_typecnt), 4);
+ maybe_reverse_bytes(tzh_typecnt);
+ inf.read(reinterpret_cast<char*>(&tzh_charcnt), 4);
+ maybe_reverse_bytes(tzh_charcnt);
+}
+
+template <class TimeType>
+static
+std::vector<detail::transition>
+load_transitions(std::istream& inf, std::int32_t tzh_timecnt)
+{
+ // Read transitions
+ using namespace std::chrono;
+ std::vector<detail::transition> transitions;
+ transitions.reserve(static_cast<unsigned>(tzh_timecnt));
+ for (std::int32_t i = 0; i < tzh_timecnt; ++i)
+ {
+ TimeType t;
+ inf.read(reinterpret_cast<char*>(&t), sizeof(t));
+ maybe_reverse_bytes(t);
+ transitions.emplace_back(sys_seconds{seconds{t}});
+ if (transitions.back().timepoint < min_seconds)
+ transitions.back().timepoint = min_seconds;
+ }
+ return transitions;
+}
+
+static
+std::vector<std::uint8_t>
+load_indices(std::istream& inf, std::int32_t tzh_timecnt)
+{
+ // Read indices
+ std::vector<std::uint8_t> indices;
+ indices.reserve(static_cast<unsigned>(tzh_timecnt));
+ for (std::int32_t i = 0; i < tzh_timecnt; ++i)
+ {
+ std::uint8_t t;
+ inf.read(reinterpret_cast<char*>(&t), sizeof(t));
+ indices.emplace_back(t);
+ }
+ return indices;
+}
+
+static
+std::vector<ttinfo>
+load_ttinfo(std::istream& inf, std::int32_t tzh_typecnt)
+{
+ // Read ttinfo
+ std::vector<ttinfo> ttinfos;
+ ttinfos.reserve(static_cast<unsigned>(tzh_typecnt));
+ for (std::int32_t i = 0; i < tzh_typecnt; ++i)
+ {
+ ttinfo t;
+ inf.read(reinterpret_cast<char*>(&t), 6);
+ maybe_reverse_bytes(t.tt_gmtoff);
+ ttinfos.emplace_back(t);
+ }
+ return ttinfos;
+}
+
+static
+std::string
+load_abbreviations(std::istream& inf, std::int32_t tzh_charcnt)
+{
+ // Read abbreviations
+ std::string abbrev;
+ abbrev.resize(static_cast<unsigned>(tzh_charcnt), '\0');
+ inf.read(&abbrev[0], tzh_charcnt);
+ return abbrev;
+}
+
+#if !MISSING_LEAP_SECONDS
+
+template <class TimeType>
+static
+std::vector<leap_second>
+load_leaps(std::istream& inf, std::int32_t tzh_leapcnt)
+{
+ // Read tzh_leapcnt pairs
+ using namespace std::chrono;
+ std::vector<leap_second> leap_seconds;
+ leap_seconds.reserve(static_cast<std::size_t>(tzh_leapcnt));
+ for (std::int32_t i = 0; i < tzh_leapcnt; ++i)
+ {
+ TimeType t0;
+ std::int32_t t1;
+ inf.read(reinterpret_cast<char*>(&t0), sizeof(t0));
+ inf.read(reinterpret_cast<char*>(&t1), sizeof(t1));
+ maybe_reverse_bytes(t0);
+ maybe_reverse_bytes(t1);
+ leap_seconds.emplace_back(sys_seconds{seconds{t0 - (t1-1)}},
+ detail::undocumented{});
+ }
+ return leap_seconds;
+}
+
+template <class TimeType>
+static
+std::vector<leap_second>
+load_leap_data(std::istream& inf,
+ std::int32_t tzh_leapcnt, std::int32_t tzh_timecnt,
+ std::int32_t tzh_typecnt, std::int32_t tzh_charcnt)
+{
+ inf.ignore(tzh_timecnt*static_cast<std::int32_t>(sizeof(TimeType)) + tzh_timecnt +
+ tzh_typecnt*6 + tzh_charcnt);
+ return load_leaps<TimeType>(inf, tzh_leapcnt);
+}
+
+static
+std::vector<leap_second>
+load_just_leaps(std::istream& inf)
+{
+ // Read tzh_leapcnt pairs
+ using namespace std::chrono;
+ load_header(inf);
+ auto v = load_version(inf);
+ std::int32_t tzh_ttisgmtcnt, tzh_ttisstdcnt, tzh_leapcnt,
+ tzh_timecnt, tzh_typecnt, tzh_charcnt;
+ skip_reserve(inf);
+ load_counts(inf, tzh_ttisgmtcnt, tzh_ttisstdcnt, tzh_leapcnt,
+ tzh_timecnt, tzh_typecnt, tzh_charcnt);
+ if (v == 0)
+ return load_leap_data<int32_t>(inf, tzh_leapcnt, tzh_timecnt, tzh_typecnt,
+ tzh_charcnt);
+#if !defined(NDEBUG)
+ inf.ignore((4+1)*tzh_timecnt + 6*tzh_typecnt + tzh_charcnt + 8*tzh_leapcnt +
+ tzh_ttisstdcnt + tzh_ttisgmtcnt);
+ load_header(inf);
+ auto v2 = load_version(inf);
+ assert(v == v2);
+ skip_reserve(inf);
+#else // defined(NDEBUG)
+ inf.ignore((4+1)*tzh_timecnt + 6*tzh_typecnt + tzh_charcnt + 8*tzh_leapcnt +
+ tzh_ttisstdcnt + tzh_ttisgmtcnt + (4+1+15));
+#endif // defined(NDEBUG)
+ load_counts(inf, tzh_ttisgmtcnt, tzh_ttisstdcnt, tzh_leapcnt,
+ tzh_timecnt, tzh_typecnt, tzh_charcnt);
+ return load_leap_data<int64_t>(inf, tzh_leapcnt, tzh_timecnt, tzh_typecnt,
+ tzh_charcnt);
+}
+
+#endif // !MISSING_LEAP_SECONDS
+
+template <class TimeType>
+void
+time_zone::load_data(std::istream& inf,
+ std::int32_t tzh_leapcnt, std::int32_t tzh_timecnt,
+ std::int32_t tzh_typecnt, std::int32_t tzh_charcnt)
+{
+ using namespace std::chrono;
+ transitions_ = load_transitions<TimeType>(inf, tzh_timecnt);
+ auto indices = load_indices(inf, tzh_timecnt);
+ auto infos = load_ttinfo(inf, tzh_typecnt);
+ auto abbrev = load_abbreviations(inf, tzh_charcnt);
+#if !MISSING_LEAP_SECONDS
+ auto& leap_seconds = get_tzdb_list().front().leap_seconds;
+ if (leap_seconds.empty() && tzh_leapcnt > 0)
+ leap_seconds = load_leaps<TimeType>(inf, tzh_leapcnt);
+#endif
+ ttinfos_.reserve(infos.size());
+ for (auto& info : infos)
+ {
+ ttinfos_.push_back({seconds{info.tt_gmtoff},
+ abbrev.c_str() + info.tt_abbrind,
+ info.tt_isdst != 0});
+ }
+ auto i = 0u;
+ if (transitions_.empty() || transitions_.front().timepoint != min_seconds)
+ {
+ transitions_.emplace(transitions_.begin(), min_seconds);
+ auto tf = std::find_if(ttinfos_.begin(), ttinfos_.end(),
+ [](const expanded_ttinfo& ti)
+ {return ti.is_dst == 0;});
+ if (tf == ttinfos_.end())
+ tf = ttinfos_.begin();
+ transitions_[i].info = &*tf;
+ ++i;
+ }
+ for (auto j = 0u; i < transitions_.size(); ++i, ++j)
+ transitions_[i].info = ttinfos_.data() + indices[j];
+}
+
+void
+time_zone::init_impl()
+{
+ using namespace std;
+ using namespace std::chrono;
+ auto name = get_tz_dir() + ('/' + name_);
+ std::ifstream inf(name);
+ if (!inf.is_open())
+ throw std::runtime_error{"Unable to open " + name};
+ inf.exceptions(std::ios::failbit | std::ios::badbit);
+ load_header(inf);
+ auto v = load_version(inf);
+ std::int32_t tzh_ttisgmtcnt, tzh_ttisstdcnt, tzh_leapcnt,
+ tzh_timecnt, tzh_typecnt, tzh_charcnt;
+ skip_reserve(inf);
+ load_counts(inf, tzh_ttisgmtcnt, tzh_ttisstdcnt, tzh_leapcnt,
+ tzh_timecnt, tzh_typecnt, tzh_charcnt);
+ if (v == 0)
+ {
+ load_data<int32_t>(inf, tzh_leapcnt, tzh_timecnt, tzh_typecnt, tzh_charcnt);
+ }
+ else
+ {
+#if !defined(NDEBUG)
+ inf.ignore((4+1)*tzh_timecnt + 6*tzh_typecnt + tzh_charcnt + 8*tzh_leapcnt +
+ tzh_ttisstdcnt + tzh_ttisgmtcnt);
+ load_header(inf);
+ auto v2 = load_version(inf);
+ assert(v == v2);
+ skip_reserve(inf);
+#else // defined(NDEBUG)
+ inf.ignore((4+1)*tzh_timecnt + 6*tzh_typecnt + tzh_charcnt + 8*tzh_leapcnt +
+ tzh_ttisstdcnt + tzh_ttisgmtcnt + (4+1+15));
+#endif // defined(NDEBUG)
+ load_counts(inf, tzh_ttisgmtcnt, tzh_ttisstdcnt, tzh_leapcnt,
+ tzh_timecnt, tzh_typecnt, tzh_charcnt);
+ load_data<int64_t>(inf, tzh_leapcnt, tzh_timecnt, tzh_typecnt, tzh_charcnt);
+ }
+#if !MISSING_LEAP_SECONDS
+ if (tzh_leapcnt > 0)
+ {
+ auto& leap_seconds = get_tzdb_list().front().leap_seconds;
+ auto itr = leap_seconds.begin();
+ auto l = itr->date();
+ seconds leap_count{0};
+ for (auto t = std::upper_bound(transitions_.begin(), transitions_.end(), l,
+ [](const sys_seconds& x, const transition& ct)
+ {
+ return x < ct.timepoint;
+ });
+ t != transitions_.end(); ++t)
+ {
+ while (t->timepoint >= l)
+ {
+ ++leap_count;
+ if (++itr == leap_seconds.end())
+ l = sys_days(max_year/max_day);
+ else
+ l = itr->date() + leap_count;
+ }
+ t->timepoint -= leap_count;
+ }
+ }
+#endif // !MISSING_LEAP_SECONDS
+ auto b = transitions_.begin();
+ auto i = transitions_.end();
+ if (i != b)
+ {
+ for (--i; i != b; --i)
+ {
+ if (i->info->offset == i[-1].info->offset &&
+ i->info->abbrev == i[-1].info->abbrev &&
+ i->info->is_dst == i[-1].info->is_dst)
+ i = transitions_.erase(i);
+ }
+ }
+}
+
+void
+time_zone::init() const
+{
+ std::call_once(*adjusted_, [this]() {const_cast<time_zone*>(this)->init_impl();});
+}
+
+sys_info
+time_zone::load_sys_info(std::vector<detail::transition>::const_iterator i) const
+{
+ using namespace std::chrono;
+ assert(!transitions_.empty());
+ sys_info r;
+ if (i != transitions_.begin())
+ {
+ r.begin = i[-1].timepoint;
+ r.end = i != transitions_.end() ? i->timepoint :
+ sys_seconds(sys_days(year::max()/max_day));
+ r.offset = i[-1].info->offset;
+ r.save = i[-1].info->is_dst ? minutes{1} : minutes{0};
+ r.abbrev = i[-1].info->abbrev;
+ }
+ else
+ {
+ r.begin = sys_days(year::min()/min_day);
+ r.end = i+1 != transitions_.end() ? i[1].timepoint :
+ sys_seconds(sys_days(year::max()/max_day));
+ r.offset = i[0].info->offset;
+ r.save = i[0].info->is_dst ? minutes{1} : minutes{0};
+ r.abbrev = i[0].info->abbrev;
+ }
+ return r;
+}
+
+sys_info
+time_zone::get_info_impl(sys_seconds tp) const
+{
+ using namespace std;
+ init();
+ return load_sys_info(upper_bound(transitions_.begin(), transitions_.end(), tp,
+ [](const sys_seconds& x, const transition& t)
+ {
+ return x < t.timepoint;
+ }));
+}
+
+local_info
+time_zone::get_info_impl(local_seconds tp) const
+{
+ using namespace std::chrono;
+ init();
+ local_info i{};
+ i.result = local_info::unique;
+ auto tr = upper_bound(transitions_.begin(), transitions_.end(), tp,
+ [](const local_seconds& x, const transition& t)
+ {
+ return sys_seconds{x.time_since_epoch()} -
+ t.info->offset < t.timepoint;
+ });
+ i.first = load_sys_info(tr);
+ auto tps = sys_seconds{(tp - i.first.offset).time_since_epoch()};
+ if (tps < i.first.begin + days{1} && tr != transitions_.begin())
+ {
+ i.second = load_sys_info(--tr);
+ tps = sys_seconds{(tp - i.second.offset).time_since_epoch()};
+ if (tps < i.second.end && i.first.end != i.second.end)
+ {
+ i.result = local_info::ambiguous;
+ std::swap(i.first, i.second);
+ }
+ else
+ {
+ i.second = {};
+ }
+ }
+ else if (tps >= i.first.end && tr != transitions_.end())
+ {
+ i.second = load_sys_info(++tr);
+ tps = sys_seconds{(tp - i.second.offset).time_since_epoch()};
+ if (tps < i.second.begin)
+ i.result = local_info::nonexistent;
+ else
+ i.second = {};
+ }
+ return i;
+}
+
+std::ostream&
+operator<<(std::ostream& os, const time_zone& z)
+{
+ using namespace std::chrono;
+ z.init();
+ os << z.name_ << '\n';
+ os << "Initially: ";
+ auto const& t = z.transitions_.front();
+ if (t.info->offset >= seconds{0})
+ os << '+';
+ os << make_time(t.info->offset);
+ if (t.info->is_dst > 0)
+ os << " daylight ";
+ else
+ os << " standard ";
+ os << t.info->abbrev << '\n';
+ for (auto i = std::next(z.transitions_.cbegin()); i < z.transitions_.cend(); ++i)
+ os << *i << '\n';
+ return os;
+}
+
+leap_second::leap_second(const sys_seconds& s, detail::undocumented)
+ : date_(s)
+{
+}
+
+#else // !USE_OS_TZDB
+
+time_zone::time_zone(const std::string& s, detail::undocumented)
+ : adjusted_(new std::once_flag{})
+{
+ try
+ {
+ using namespace date;
+ std::istringstream in(s);
+ in.exceptions(std::ios::failbit | std::ios::badbit);
+ std::string word;
+ in >> word >> name_;
+ parse_info(in);
+ }
+ catch (...)
+ {
+ std::cerr << s << '\n';
+ std::cerr << *this << '\n';
+ zonelets_.pop_back();
+ throw;
+ }
+}
+
+sys_info
+time_zone::get_info_impl(sys_seconds tp) const
+{
+ return get_info_impl(tp, static_cast<int>(tz::utc));
+}
+
+local_info
+time_zone::get_info_impl(local_seconds tp) const
+{
+ using namespace std::chrono;
+ local_info i{};
+ i.first = get_info_impl(sys_seconds{tp.time_since_epoch()}, static_cast<int>(tz::local));
+ auto tps = sys_seconds{(tp - i.first.offset).time_since_epoch()};
+ if (tps < i.first.begin)
+ {
+ i.second = std::move(i.first);
+ i.first = get_info_impl(i.second.begin - seconds{1}, static_cast<int>(tz::utc));
+ i.result = local_info::nonexistent;
+ }
+ else if (i.first.end - tps <= days{1})
+ {
+ i.second = get_info_impl(i.first.end, static_cast<int>(tz::utc));
+ tps = sys_seconds{(tp - i.second.offset).time_since_epoch()};
+ if (tps >= i.second.begin)
+ i.result = local_info::ambiguous;
+ else
+ i.second = {};
+ }
+ return i;
+}
+
+void
+time_zone::add(const std::string& s)
+{
+ try
+ {
+ std::istringstream in(s);
+ in.exceptions(std::ios::failbit | std::ios::badbit);
+ ws(in);
+ if (!in.eof() && in.peek() != '#')
+ parse_info(in);
+ }
+ catch (...)
+ {
+ std::cerr << s << '\n';
+ std::cerr << *this << '\n';
+ zonelets_.pop_back();
+ throw;
+ }
+}
+
+void
+time_zone::parse_info(std::istream& in)
+{
+ using namespace date;
+ using namespace std::chrono;
+ zonelets_.emplace_back();
+ auto& zonelet = zonelets_.back();
+ zonelet.gmtoff_ = parse_signed_time(in);
+ in >> zonelet.u.rule_;
+ if (zonelet.u.rule_ == "-")
+ zonelet.u.rule_.clear();
+ in >> zonelet.format_;
+ if (!in.eof())
+ ws(in);
+ if (in.eof() || in.peek() == '#')
+ {
+ zonelet.until_year_ = year::max();
+ zonelet.until_date_ = MonthDayTime(max_day, tz::utc);
+ }
+ else
+ {
+ int y;
+ in >> y;
+ zonelet.until_year_ = year{y};
+ in >> zonelet.until_date_;
+ zonelet.until_date_.canonicalize(zonelet.until_year_);
+ }
+ if ((zonelet.until_year_ < min_year) ||
+ (zonelets_.size() > 1 && zonelets_.end()[-2].until_year_ > max_year))
+ zonelets_.pop_back();
+}
+
+void
+time_zone::adjust_infos(const std::vector<Rule>& rules)
+{
+ using namespace std::chrono;
+ using namespace date;
+ const zonelet* prev_zonelet = nullptr;
+ for (auto& z : zonelets_)
+ {
+ std::pair<const Rule*, const Rule*> eqr{};
+ std::istringstream in;
+ in.exceptions(std::ios::failbit | std::ios::badbit);
+ // Classify info as rule-based, has save, or neither
+ if (!z.u.rule_.empty())
+ {
+ // Find out if this zonelet has a rule or a save
+ eqr = std::equal_range(rules.data(), rules.data() + rules.size(), z.u.rule_);
+ if (eqr.first == eqr.second)
+ {
+ // The rule doesn't exist. Assume this is a save
+ try
+ {
+ using namespace std::chrono;
+ using string = std::string;
+ in.str(z.u.rule_);
+ auto tmp = duration_cast<minutes>(parse_signed_time(in));
+#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
+ z.u.rule_.~string();
+ z.tag_ = zonelet::has_save;
+ ::new(&z.u.save_) minutes(tmp);
+#else
+ z.u.rule_.clear();
+ z.tag_ = zonelet::has_save;
+ z.u.save_ = tmp;
+#endif
+ }
+ catch (...)
+ {
+ std::cerr << name_ << " : " << z.u.rule_ << '\n';
+ throw;
+ }
+ }
+ }
+ else
+ {
+ // This zone::zonelet has no rule and no save
+ z.tag_ = zonelet::is_empty;
+ }
+
+ minutes final_save{0};
+ if (z.tag_ == zonelet::has_save)
+ {
+ final_save = z.u.save_;
+ }
+ else if (z.tag_ == zonelet::has_rule)
+ {
+ z.last_rule_ = find_rule_for_zone(eqr, z.until_year_, z.gmtoff_,
+ z.until_date_);
+ if (z.last_rule_.first != nullptr)
+ final_save = z.last_rule_.first->save();
+ }
+ z.until_utc_ = z.until_date_.to_sys(z.until_year_, z.gmtoff_, final_save);
+ z.until_std_ = local_seconds{z.until_utc_.time_since_epoch()} + z.gmtoff_;
+ z.until_loc_ = z.until_std_ + final_save;
+
+ if (z.tag_ == zonelet::has_rule)
+ {
+ if (prev_zonelet != nullptr)
+ {
+ z.first_rule_ = find_rule_for_zone(eqr, prev_zonelet->until_utc_,
+ prev_zonelet->until_std_,
+ prev_zonelet->until_loc_);
+ if (z.first_rule_.first != nullptr)
+ {
+ z.initial_save_ = z.first_rule_.first->save();
+ z.initial_abbrev_ = z.first_rule_.first->abbrev();
+ if (z.first_rule_ != z.last_rule_)
+ {
+ z.first_rule_ = find_next_rule(eqr.first, eqr.second,
+ z.first_rule_.first,
+ z.first_rule_.second);
+ }
+ else
+ {
+ z.first_rule_ = std::make_pair(nullptr, year::min());
+ z.last_rule_ = std::make_pair(nullptr, year::max());
+ }
+ }
+ }
+ if (z.first_rule_.first == nullptr && z.last_rule_.first != nullptr)
+ {
+ z.first_rule_ = std::make_pair(eqr.first, eqr.first->starting_year());
+ z.initial_abbrev_ = find_first_std_rule(eqr)->abbrev();
+ }
+ }
+
+#ifndef NDEBUG
+ if (z.first_rule_.first == nullptr)
+ {
+ assert(z.first_rule_.second == year::min());
+ assert(z.last_rule_.first == nullptr);
+ assert(z.last_rule_.second == year::max());
+ }
+ else
+ {
+ assert(z.last_rule_.first != nullptr);
+ }
+#endif
+ prev_zonelet = &z;
+ }
+}
+
+static
+std::string
+format_abbrev(std::string format, const std::string& variable, std::chrono::seconds off,
+ std::chrono::minutes save)
+{
+ using namespace std::chrono;
+ auto k = format.find("%s");
+ if (k != std::string::npos)
+ {
+ format.replace(k, 2, variable);
+ }
+ else
+ {
+ k = format.find('/');
+ if (k != std::string::npos)
+ {
+ if (save == minutes{0})
+ format.erase(k);
+ else
+ format.erase(0, k+1);
+ }
+ else
+ {
+ k = format.find("%z");
+ if (k != std::string::npos)
+ {
+ std::string temp;
+ if (off < seconds{0})
+ {
+ temp = '-';
+ off = -off;
+ }
+ else
+ temp = '+';
+ auto h = date::floor<hours>(off);
+ off -= h;
+ if (h < hours{10})
+ temp += '0';
+ temp += std::to_string(h.count());
+ if (off > seconds{0})
+ {
+ auto m = date::floor<minutes>(off);
+ off -= m;
+ if (m < minutes{10})
+ temp += '0';
+ temp += std::to_string(m.count());
+ if (off > seconds{0})
+ {
+ if (off < seconds{10})
+ temp += '0';
+ temp += std::to_string(off.count());
+ }
+ }
+ format.replace(k, 2, temp);
+ }
+ }
+ }
+ return format;
+}
+
+sys_info
+time_zone::get_info_impl(sys_seconds tp, int tz_int) const
+{
+ using namespace std::chrono;
+ using namespace date;
+ tz timezone = static_cast<tz>(tz_int);
+ assert(timezone != tz::standard);
+ auto y = year_month_day(floor<days>(tp)).year();
+ if (y < min_year || y > max_year)
+ throw std::runtime_error("The year " + std::to_string(static_cast<int>(y)) +
+ " is out of range:[" + std::to_string(static_cast<int>(min_year)) + ", "
+ + std::to_string(static_cast<int>(max_year)) + "]");
+ std::call_once(*adjusted_,
+ [this]()
+ {
+ const_cast<time_zone*>(this)->adjust_infos(get_tzdb().rules);
+ });
+ auto i = std::upper_bound(zonelets_.begin(), zonelets_.end(), tp,
+ [timezone](sys_seconds t, const zonelet& zl)
+ {
+ return timezone == tz::utc ? t < zl.until_utc_ :
+ t < sys_seconds{zl.until_loc_.time_since_epoch()};
+ });
+
+ sys_info r{};
+ if (i != zonelets_.end())
+ {
+ if (i->tag_ == zonelet::has_save)
+ {
+ if (i != zonelets_.begin())
+ r.begin = i[-1].until_utc_;
+ else
+ r.begin = sys_days(year::min()/min_day);
+ r.end = i->until_utc_;
+ r.offset = i->gmtoff_ + i->u.save_;
+ r.save = i->u.save_;
+ }
+ else if (i->u.rule_.empty())
+ {
+ if (i != zonelets_.begin())
+ r.begin = i[-1].until_utc_;
+ else
+ r.begin = sys_days(year::min()/min_day);
+ r.end = i->until_utc_;
+ r.offset = i->gmtoff_;
+ }
+ else
+ {
+ r = find_rule(i->first_rule_, i->last_rule_, y, i->gmtoff_,
+ MonthDayTime(local_seconds{tp.time_since_epoch()}, timezone),
+ i->initial_save_, i->initial_abbrev_);
+ r.offset = i->gmtoff_ + r.save;
+ if (i != zonelets_.begin() && r.begin < i[-1].until_utc_)
+ r.begin = i[-1].until_utc_;
+ if (r.end > i->until_utc_)
+ r.end = i->until_utc_;
+ }
+ r.abbrev = format_abbrev(i->format_, r.abbrev, r.offset, r.save);
+ assert(r.begin < r.end);
+ }
+ return r;
+}
+
+std::ostream&
+operator<<(std::ostream& os, const time_zone& z)
+{
+ using namespace date;
+ using namespace std::chrono;
+ detail::save_ostream<char> _(os);
+ os.fill(' ');
+ os.flags(std::ios::dec | std::ios::left);
+ std::call_once(*z.adjusted_,
+ [&z]()
+ {
+ const_cast<time_zone&>(z).adjust_infos(get_tzdb().rules);
+ });
+ os.width(35);
+ os << z.name_;
+ std::string indent;
+ for (auto const& s : z.zonelets_)
+ {
+ os << indent;
+ if (s.gmtoff_ >= seconds{0})
+ os << ' ';
+ os << make_time(s.gmtoff_) << " ";
+ os.width(15);
+ if (s.tag_ != zonelet::has_save)
+ os << s.u.rule_;
+ else
+ {
+ std::ostringstream tmp;
+ tmp << make_time(s.u.save_);
+ os << tmp.str();
+ }
+ os.width(8);
+ os << s.format_ << " ";
+ os << s.until_year_ << ' ' << s.until_date_;
+ os << " " << s.until_utc_ << " UTC";
+ os << " " << s.until_std_ << " STD";
+ os << " " << s.until_loc_;
+ os << " " << make_time(s.initial_save_);
+ os << " " << s.initial_abbrev_;
+ if (s.first_rule_.first != nullptr)
+ os << " {" << *s.first_rule_.first << ", " << s.first_rule_.second << '}';
+ else
+ os << " {" << "nullptr" << ", " << s.first_rule_.second << '}';
+ if (s.last_rule_.first != nullptr)
+ os << " {" << *s.last_rule_.first << ", " << s.last_rule_.second << '}';
+ else
+ os << " {" << "nullptr" << ", " << s.last_rule_.second << '}';
+ os << '\n';
+ if (indent.empty())
+ indent = std::string(35, ' ');
+ }
+ return os;
+}
+
+#endif // !USE_OS_TZDB
+
+std::ostream&
+operator<<(std::ostream& os, const leap_second& x)
+{
+ using namespace date;
+ return os << x.date_ << " +";
+}
+
+#if USE_OS_TZDB
+
+static
+std::string
+get_version()
+{
+ using namespace std;
+ auto path = get_tz_dir() + string("/+VERSION");
+ ifstream in{path};
+ string version;
+ if (in)
+ {
+ in >> version;
+ return version;
+ }
+ in.clear();
+ in.open(get_tz_dir() + std::string(1, folder_delimiter) + "version");
+ if (in)
+ {
+ in >> version;
+ return version;
+ }
+ return "unknown";
+}
+
+static
+std::vector<leap_second>
+find_read_and_leap_seconds()
+{
+ std::ifstream in(get_tz_dir() + std::string(1, folder_delimiter) + "leapseconds",
+ std::ios_base::binary);
+ if (in)
+ {
+ std::vector<leap_second> leap_seconds;
+ std::string line;
+ while (in)
+ {
+ std::getline(in, line);
+ if (!line.empty() && line[0] != '#')
+ {
+ std::istringstream in(line);
+ in.exceptions(std::ios::failbit | std::ios::badbit);
+ std::string word;
+ in >> word;
+ if (word == "Leap")
+ {
+ int y, m, d;
+ in >> y;
+ m = static_cast<int>(parse_month(in));
+ in >> d;
+ leap_seconds.push_back(leap_second(sys_days{year{y}/m/d} + days{1},
+ detail::undocumented{}));
+ }
+ else
+ {
+ std::cerr << line << '\n';
+ }
+ }
+ }
+ return leap_seconds;
+ }
+ in.clear();
+ in.open(get_tz_dir() + std::string(1, folder_delimiter) + "leap-seconds.list",
+ std::ios_base::binary);
+ if (in)
+ {
+ std::vector<leap_second> leap_seconds;
+ std::string line;
+ const auto offset = sys_days{1970_y/1/1}-sys_days{1900_y/1/1};
+ while (in)
+ {
+ std::getline(in, line);
+ if (!line.empty() && line[0] != '#')
+ {
+ std::istringstream in(line);
+ in.exceptions(std::ios::failbit | std::ios::badbit);
+ using seconds = std::chrono::seconds;
+ seconds::rep s;
+ in >> s;
+ if (s == 2272060800)
+ continue;
+ leap_seconds.push_back(leap_second(sys_seconds{seconds{s}} - offset,
+ detail::undocumented{}));
+ }
+ }
+ return leap_seconds;
+ }
+ in.clear();
+ in.open(get_tz_dir() + std::string(1, folder_delimiter) + "right/UTC",
+ std::ios_base::binary);
+ if (in)
+ {
+ return load_just_leaps(in);
+ }
+ in.clear();
+ in.open(get_tz_dir() + std::string(1, folder_delimiter) + "UTC",
+ std::ios_base::binary);
+ if (in)
+ {
+ return load_just_leaps(in);
+ }
+ return {};
+}
+
+static
+std::unique_ptr<tzdb>
+init_tzdb()
+{
+ std::unique_ptr<tzdb> db(new tzdb);
+
+ //Iterate through folders
+ std::queue<std::string> subfolders;
+ subfolders.emplace(get_tz_dir());
+ struct dirent* d;
+ struct stat s;
+ while (!subfolders.empty())
+ {
+ auto dirname = std::move(subfolders.front());
+ subfolders.pop();
+ auto dir = opendir(dirname.c_str());
+ if (!dir)
+ continue;
+ while ((d = readdir(dir)) != nullptr)
+ {
+ // Ignore these files:
+ if (d->d_name[0] == '.' || // curdir, prevdir, hidden
+ memcmp(d->d_name, "posix", 5) == 0 || // starts with posix
+ strcmp(d->d_name, "Factory") == 0 ||
+ strcmp(d->d_name, "iso3166.tab") == 0 ||
+ strcmp(d->d_name, "right") == 0 ||
+ strcmp(d->d_name, "+VERSION") == 0 ||
+ strcmp(d->d_name, "version") == 0 ||
+ strcmp(d->d_name, "zone.tab") == 0 ||
+ strcmp(d->d_name, "zone1970.tab") == 0 ||
+ strcmp(d->d_name, "tzdata.zi") == 0 ||
+ strcmp(d->d_name, "leapseconds") == 0 ||
+ strcmp(d->d_name, "leap-seconds.list") == 0 )
+ continue;
+ auto subname = dirname + folder_delimiter + d->d_name;
+ if(stat(subname.c_str(), &s) == 0)
+ {
+ if(S_ISDIR(s.st_mode))
+ {
+ if(!S_ISLNK(s.st_mode))
+ {
+ subfolders.push(subname);
+ }
+ }
+ else
+ {
+ db->zones.emplace_back(subname.substr(get_tz_dir().size()+1),
+ detail::undocumented{});
+ }
+ }
+ }
+ closedir(dir);
+ }
+ db->zones.shrink_to_fit();
+ std::sort(db->zones.begin(), db->zones.end());
+ db->leap_seconds = find_read_and_leap_seconds();
+ db->version = get_version();
+ return db;
+}
+
+#else // !USE_OS_TZDB
+
+// time_zone_link
+
+time_zone_link::time_zone_link(const std::string& s)
+{
+ using namespace date;
+ std::istringstream in(s);
+ in.exceptions(std::ios::failbit | std::ios::badbit);
+ std::string word;
+ in >> word >> target_ >> name_;
+}
+
+std::ostream&
+operator<<(std::ostream& os, const time_zone_link& x)
+{
+ using namespace date;
+ detail::save_ostream<char> _(os);
+ os.fill(' ');
+ os.flags(std::ios::dec | std::ios::left);
+ os.width(35);
+ return os << x.name_ << " --> " << x.target_;
+}
+
+// leap_second
+
+leap_second::leap_second(const std::string& s, detail::undocumented)
+{
+ using namespace date;
+ std::istringstream in(s);
+ in.exceptions(std::ios::failbit | std::ios::badbit);
+ std::string word;
+ int y;
+ MonthDayTime date;
+ in >> word >> y >> date;
+ date_ = date.to_time_point(year(y));
+}
+
+static
+bool
+file_exists(const std::string& filename)
+{
+#ifdef _WIN32
+ return ::_access(filename.c_str(), 0) == 0;
+#else
+ return ::access(filename.c_str(), F_OK) == 0;
+#endif
+}
+
+#if HAS_REMOTE_API
+
+// CURL tools
+
+namespace
+{
+
+struct curl_global_init_and_cleanup
+{
+ ~curl_global_init_and_cleanup()
+ {
+ ::curl_global_cleanup();
+ }
+ curl_global_init_and_cleanup()
+ {
+ if (::curl_global_init(CURL_GLOBAL_DEFAULT) != 0)
+ throw std::runtime_error("CURL global initialization failed");
+ }
+ curl_global_init_and_cleanup(curl_global_init_and_cleanup const&) = delete;
+ curl_global_init_and_cleanup& operator=(curl_global_init_and_cleanup const&) = delete;
+};
+
+struct curl_deleter
+{
+ void operator()(CURL* p) const
+ {
+ ::curl_easy_cleanup(p);
+ }
+};
+
+} // unnamed namespace
+
+static
+std::unique_ptr<CURL, curl_deleter>
+curl_init()
+{
+ static const curl_global_init_and_cleanup _{};
+ return std::unique_ptr<CURL, curl_deleter>{::curl_easy_init()};
+}
+
+static
+bool
+download_to_string(const std::string& url, std::string& str)
+{
+ str.clear();
+ auto curl = curl_init();
+ if (!curl)
+ return false;
+ std::string version;
+ curl_easy_setopt(curl.get(), CURLOPT_USERAGENT, "curl");
+ curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
+ curl_write_callback write_cb = [](char* contents, std::size_t size, std::size_t nmemb,
+ void* userp) -> std::size_t
+ {
+ auto& userstr = *static_cast<std::string*>(userp);
+ auto realsize = size * nmemb;
+ userstr.append(contents, realsize);
+ return realsize;
+ };
+ curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, write_cb);
+ curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &str);
+ curl_easy_setopt(curl.get(), CURLOPT_SSL_VERIFYPEER, false);
+ auto res = curl_easy_perform(curl.get());
+ return (res == CURLE_OK);
+}
+
+namespace
+{
+ enum class download_file_options { binary, text };
+}
+
+static
+bool
+download_to_file(const std::string& url, const std::string& local_filename,
+ download_file_options opts, char* error_buffer)
+{
+ auto curl = curl_init();
+ if (!curl)
+ return false;
+ curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
+ curl_easy_setopt(curl.get(), CURLOPT_SSL_VERIFYPEER, false);
+ if (error_buffer)
+ curl_easy_setopt(curl.get(), CURLOPT_ERRORBUFFER, error_buffer);
+ curl_write_callback write_cb = [](char* contents, std::size_t size, std::size_t nmemb,
+ void* userp) -> std::size_t
+ {
+ auto& of = *static_cast<std::ofstream*>(userp);
+ auto realsize = size * nmemb;
+ of.write(contents, static_cast<std::streamsize>(realsize));
+ return realsize;
+ };
+ curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, write_cb);
+ decltype(curl_easy_perform(curl.get())) res;
+ {
+ std::ofstream of(local_filename,
+ opts == download_file_options::binary ?
+ std::ofstream::out | std::ofstream::binary :
+ std::ofstream::out);
+ of.exceptions(std::ios::badbit);
+ curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &of);
+ res = curl_easy_perform(curl.get());
+ }
+ return res == CURLE_OK;
+}
+
+std::string
+remote_version()
+{
+ std::string version;
+ std::string str;
+ if (download_to_string("https://www.iana.org/time-zones", str))
+ {
+ CONSTDATA char db[] = "/time-zones/releases/tzdata";
+ CONSTDATA auto db_size = sizeof(db) - 1;
+ auto p = str.find(db, 0, db_size);
+ const int ver_str_len = 5;
+ if (p != std::string::npos && p + (db_size + ver_str_len) <= str.size())
+ version = str.substr(p + db_size, ver_str_len);
+ }
+ return version;
+}
+
+
+// TODO! Using system() create a process and a console window.
+// This is useful to see what errors may occur but is slow and distracting.
+// Consider implementing this functionality more directly, such as
+// using _mkdir and CreateProcess etc.
+// But use the current means now as matches Unix implementations and while
+// in proof of concept / testing phase.
+// TODO! Use <filesystem> eventually.
+static
+bool
+remove_folder_and_subfolders(const std::string& folder)
+{
+# ifdef _WIN32
+# if USE_SHELL_API
+ // Delete the folder contents by deleting the folder.
+ std::string cmd = "rd /s /q \"";
+ cmd += folder;
+ cmd += '\"';
+ return std::system(cmd.c_str()) == EXIT_SUCCESS;
+# else // !USE_SHELL_API
+ // Create a buffer containing the path to delete. It must be terminated
+ // by two nuls. Who designs these API's...
+ std::vector<char> from;
+ from.assign(folder.begin(), folder.end());
+ from.push_back('\0');
+ from.push_back('\0');
+ SHFILEOPSTRUCT fo{}; // Zero initialize.
+ fo.wFunc = FO_DELETE;
+ fo.pFrom = from.data();
+ fo.fFlags = FOF_NO_UI;
+ int ret = SHFileOperation(&fo);
+ if (ret == 0 && !fo.fAnyOperationsAborted)
+ return true;
+ return false;
+# endif // !USE_SHELL_API
+# else // !_WIN32
+# if USE_SHELL_API
+ return std::system(("rm -R " + folder).c_str()) == EXIT_SUCCESS;
+# else // !USE_SHELL_API
+ struct dir_deleter {
+ dir_deleter() {}
+ void operator()(DIR* d) const
+ {
+ if (d != nullptr)
+ {
+ int result = closedir(d);
+ assert(result == 0);
+ }
+ }
+ };
+ using closedir_ptr = std::unique_ptr<DIR, dir_deleter>;
+
+ std::string filename;
+ struct stat statbuf;
+ std::size_t folder_len = folder.length();
+ struct dirent* p = nullptr;
+
+ closedir_ptr d(opendir(folder.c_str()));
+ bool r = d.get() != nullptr;
+ while (r && (p=readdir(d.get())) != nullptr)
+ {
+ if (strcmp(p->d_name, ".") == 0 || strcmp(p->d_name, "..") == 0)
+ continue;
+
+ // + 2 for path delimiter and nul terminator.
+ std::size_t buf_len = folder_len + strlen(p->d_name) + 2;
+ filename.resize(buf_len);
+ std::size_t path_len = static_cast<std::size_t>(
+ snprintf(&filename[0], buf_len, "%s/%s", folder.c_str(), p->d_name));
+ assert(path_len == buf_len - 1);
+ filename.resize(path_len);
+
+ if (stat(filename.c_str(), &statbuf) == 0)
+ r = S_ISDIR(statbuf.st_mode)
+ ? remove_folder_and_subfolders(filename)
+ : unlink(filename.c_str()) == 0;
+ }
+ d.reset();
+
+ if (r)
+ r = rmdir(folder.c_str()) == 0;
+
+ return r;
+# endif // !USE_SHELL_API
+# endif // !_WIN32
+}
+
+static
+bool
+make_directory(const std::string& folder)
+{
+# ifdef _WIN32
+# if USE_SHELL_API
+ // Re-create the folder.
+ std::string cmd = "mkdir \"";
+ cmd += folder;
+ cmd += '\"';
+ return std::system(cmd.c_str()) == EXIT_SUCCESS;
+# else // !USE_SHELL_API
+ return _mkdir(folder.c_str()) == 0;
+# endif // !USE_SHELL_API
+# else // !_WIN32
+# if USE_SHELL_API
+ return std::system(("mkdir -p " + folder).c_str()) == EXIT_SUCCESS;
+# else // !USE_SHELL_API
+ return mkdir(folder.c_str(), 0777) == 0;
+# endif // !USE_SHELL_API
+# endif // !_WIN32
+}
+
+static
+bool
+delete_file(const std::string& file)
+{
+# ifdef _WIN32
+# if USE_SHELL_API
+ std::string cmd = "del \"";
+ cmd += file;
+ cmd += '\"';
+ return std::system(cmd.c_str()) == 0;
+# else // !USE_SHELL_API
+ return _unlink(file.c_str()) == 0;
+# endif // !USE_SHELL_API
+# else // !_WIN32
+# if USE_SHELL_API
+ return std::system(("rm " + file).c_str()) == EXIT_SUCCESS;
+# else // !USE_SHELL_API
+ return unlink(file.c_str()) == 0;
+# endif // !USE_SHELL_API
+# endif // !_WIN32
+}
+
+# ifdef _WIN32
+
+static
+bool
+move_file(const std::string& from, const std::string& to)
+{
+# if USE_SHELL_API
+ std::string cmd = "move \"";
+ cmd += from;
+ cmd += "\" \"";
+ cmd += to;
+ cmd += '\"';
+ return std::system(cmd.c_str()) == EXIT_SUCCESS;
+# else // !USE_SHELL_API
+ return !!::MoveFile(from.c_str(), to.c_str());
+# endif // !USE_SHELL_API
+}
+
+// Usually something like "c:\Program Files".
+static
+std::string
+get_program_folder()
+{
+ return get_known_folder(FOLDERID_ProgramFiles);
+}
+
+// Note folder can and usually does contain spaces.
+static
+std::string
+get_unzip_program()
+{
+ std::string path;
+
+ // 7-Zip appears to note its location in the registry.
+ // If that doesn't work, fall through and take a guess, but it will likely be wrong.
+ HKEY hKey = nullptr;
+ if (RegOpenKeyExA(HKEY_LOCAL_MACHINE, "SOFTWARE\\7-Zip", 0, KEY_READ, &hKey) == ERROR_SUCCESS)
+ {
+ char value_buffer[MAX_PATH + 1]; // fyi 260 at time of writing.
+ // in/out parameter. Documentation say that size is a count of bytes not chars.
+ DWORD size = sizeof(value_buffer) - sizeof(value_buffer[0]);
+ DWORD tzi_type = REG_SZ;
+ // Testing shows Path key value is "C:\Program Files\7-Zip\" i.e. always with trailing \.
+ bool got_value = (RegQueryValueExA(hKey, "Path", nullptr, &tzi_type,
+ reinterpret_cast<LPBYTE>(value_buffer), &size) == ERROR_SUCCESS);
+ RegCloseKey(hKey); // Close now incase of throw later.
+ if (got_value)
+ {
+ // Function does not guarantee to null terminate.
+ value_buffer[size / sizeof(value_buffer[0])] = '\0';
+ path = value_buffer;
+ if (!path.empty())
+ {
+ path += "7z.exe";
+ return path;
+ }
+ }
+ }
+ path += get_program_folder();
+ path += folder_delimiter;
+ path += "7-Zip\\7z.exe";
+ return path;
+}
+
+# if !USE_SHELL_API
+static
+int
+run_program(const std::string& command)
+{
+ STARTUPINFO si{};
+ si.cb = sizeof(si);
+ PROCESS_INFORMATION pi{};
+
+ // Allegedly CreateProcess overwrites the command line. Ugh.
+ std::string mutable_command(command);
+ if (CreateProcess(nullptr, &mutable_command[0],
+ nullptr, nullptr, FALSE, CREATE_NO_WINDOW, nullptr, nullptr, &si, &pi))
+ {
+ WaitForSingleObject(pi.hProcess, INFINITE);
+ DWORD exit_code;
+ bool got_exit_code = !!GetExitCodeProcess(pi.hProcess, &exit_code);
+ CloseHandle(pi.hProcess);
+ CloseHandle(pi.hThread);
+ // Not 100% sure about this still active thing is correct,
+ // but I'm going with it because I *think* WaitForSingleObject might
+ // return in some cases without INFINITE-ly waiting.
+ // But why/wouldn't GetExitCodeProcess return false in that case?
+ if (got_exit_code && exit_code != STILL_ACTIVE)
+ return static_cast<int>(exit_code);
+ }
+ return EXIT_FAILURE;
+}
+# endif // !USE_SHELL_API
+
+static
+std::string
+get_download_tar_file(const std::string& version)
+{
+ auto file = get_install();
+ file += folder_delimiter;
+ file += "tzdata";
+ file += version;
+ file += ".tar";
+ return file;
+}
+
+static
+bool
+extract_gz_file(const std::string& version, const std::string& gz_file,
+ const std::string& dest_folder)
+{
+ auto unzip_prog = get_unzip_program();
+ bool unzip_result = false;
+ // Use the unzip program to extract the tar file from the archive.
+
+ // Aim to create a string like:
+ // "C:\Program Files\7-Zip\7z.exe" x "C:\Users\SomeUser\Downloads\tzdata2016d.tar.gz"
+ // -o"C:\Users\SomeUser\Downloads\tzdata"
+ std::string cmd;
+ cmd = '\"';
+ cmd += unzip_prog;
+ cmd += "\" x \"";
+ cmd += gz_file;
+ cmd += "\" -o\"";
+ cmd += dest_folder;
+ cmd += '\"';
+
+# if USE_SHELL_API
+ // When using shelling out with std::system() extra quotes are required around the
+ // whole command. It's weird but necessary it seems, see:
+ // http://stackoverflow.com/q/27975969/576911
+
+ cmd = "\"" + cmd + "\"";
+ if (std::system(cmd.c_str()) == EXIT_SUCCESS)
+ unzip_result = true;
+# else // !USE_SHELL_API
+ if (run_program(cmd) == EXIT_SUCCESS)
+ unzip_result = true;
+# endif // !USE_SHELL_API
+ if (unzip_result)
+ delete_file(gz_file);
+
+ // Use the unzip program extract the data from the tar file that was
+ // just extracted from the archive.
+ auto tar_file = get_download_tar_file(version);
+ cmd = '\"';
+ cmd += unzip_prog;
+ cmd += "\" x \"";
+ cmd += tar_file;
+ cmd += "\" -o\"";
+ cmd += get_install();
+ cmd += '\"';
+# if USE_SHELL_API
+ cmd = "\"" + cmd + "\"";
+ if (std::system(cmd.c_str()) == EXIT_SUCCESS)
+ unzip_result = true;
+# else // !USE_SHELL_API
+ if (run_program(cmd) == EXIT_SUCCESS)
+ unzip_result = true;
+# endif // !USE_SHELL_API
+
+ if (unzip_result)
+ delete_file(tar_file);
+
+ return unzip_result;
+}
+
+static
+std::string
+get_download_mapping_file(const std::string& version)
+{
+ auto file = get_install() + version + "windowsZones.xml";
+ return file;
+}
+
+# else // !_WIN32
+
+# if !USE_SHELL_API
+static
+int
+run_program(const char* prog, const char*const args[])
+{
+ pid_t pid = fork();
+ if (pid == -1) // Child failed to start.
+ return EXIT_FAILURE;
+
+ if (pid != 0)
+ {
+ // We are in the parent. Child started. Wait for it.
+ pid_t ret;
+ int status;
+ while ((ret = waitpid(pid, &status, 0)) == -1)
+ {
+ if (errno != EINTR)
+ break;
+ }
+ if (ret != -1)
+ {
+ if (WIFEXITED(status))
+ return WEXITSTATUS(status);
+ }
+ printf("Child issues!\n");
+
+ return EXIT_FAILURE; // Not sure what status of child is.
+ }
+ else // We are in the child process. Start the program the parent wants to run.
+ {
+
+ if (execv(prog, const_cast<char**>(args)) == -1) // Does not return.
+ {
+ perror("unreachable 0\n");
+ _Exit(127);
+ }
+ printf("unreachable 2\n");
+ }
+ printf("unreachable 2\n");
+ // Unreachable.
+ assert(false);
+ exit(EXIT_FAILURE);
+ return EXIT_FAILURE;
+}
+# endif // !USE_SHELL_API
+
+static
+bool
+extract_gz_file(const std::string&, const std::string& gz_file, const std::string&)
+{
+# if USE_SHELL_API
+ bool unzipped = std::system(("tar -xzf " + gz_file + " -C " + get_install()).c_str()) == EXIT_SUCCESS;
+# else // !USE_SHELL_API
+ const char prog[] = {"/usr/bin/tar"};
+ const char*const args[] =
+ {
+ prog, "-xzf", gz_file.c_str(), "-C", get_install().c_str(), nullptr
+ };
+ bool unzipped = (run_program(prog, args) == EXIT_SUCCESS);
+# endif // !USE_SHELL_API
+ if (unzipped)
+ {
+ delete_file(gz_file);
+ return true;
+ }
+ return false;
+}
+
+# endif // !_WIN32
+
+bool
+remote_download(const std::string& version, char* error_buffer)
+{
+ assert(!version.empty());
+
+# ifdef _WIN32
+ // Download folder should be always available for Windows
+# else // !_WIN32
+ // Create download folder if it does not exist on UNIX system
+ auto download_folder = get_install();
+ if (!file_exists(download_folder))
+ {
+ if (!make_directory(download_folder))
+ return false;
+ }
+# endif // _WIN32
+
+ auto url = "https://data.iana.org/time-zones/releases/tzdata" + version +
+ ".tar.gz";
+ bool result = download_to_file(url, get_download_gz_file(version),
+ download_file_options::binary, error_buffer);
+# ifdef _WIN32
+ if (result)
+ {
+ auto mapping_file = get_download_mapping_file(version);
+ result = download_to_file(
+ "https://raw.githubusercontent.com/unicode-org/cldr/master/"
+ "common/supplemental/windowsZones.xml",
+ mapping_file, download_file_options::text, error_buffer);
+ }
+# endif // _WIN32
+ return result;
+}
+
+bool
+remote_install(const std::string& version)
+{
+ auto success = false;
+ assert(!version.empty());
+
+ std::string install = get_install();
+ auto gz_file = get_download_gz_file(version);
+ if (file_exists(gz_file))
+ {
+ if (file_exists(install))
+ remove_folder_and_subfolders(install);
+ if (make_directory(install))
+ {
+ if (extract_gz_file(version, gz_file, install))
+ success = true;
+# ifdef _WIN32
+ auto mapping_file_source = get_download_mapping_file(version);
+ auto mapping_file_dest = get_install();
+ mapping_file_dest += folder_delimiter;
+ mapping_file_dest += "windowsZones.xml";
+ if (!move_file(mapping_file_source, mapping_file_dest))
+ success = false;
+# endif // _WIN32
+ }
+ }
+ return success;
+}
+
+#endif // HAS_REMOTE_API
+
+static
+std::string
+get_version(const std::string& path)
+{
+ std::string version;
+ std::ifstream infile(path + "version");
+ if (infile.is_open())
+ {
+ infile >> version;
+ if (!infile.fail())
+ return version;
+ }
+ else
+ {
+ infile.open(path + "NEWS");
+ while (infile)
+ {
+ infile >> version;
+ if (version == "Release")
+ {
+ infile >> version;
+ return version;
+ }
+ }
+ }
+ throw std::runtime_error("Unable to get Timezone database version from " + path);
+}
+
+static
+std::unique_ptr<tzdb>
+init_tzdb()
+{
+ using namespace date;
+ const std::string install = get_install();
+ const std::string path = install + folder_delimiter;
+ std::string line;
+ bool continue_zone = false;
+ std::unique_ptr<tzdb> db(new tzdb);
+
+#if AUTO_DOWNLOAD
+ if (!file_exists(install))
+ {
+ auto rv = remote_version();
+ if (!rv.empty() && remote_download(rv))
+ {
+ if (!remote_install(rv))
+ {
+ std::string msg = "Timezone database version \"";
+ msg += rv;
+ msg += "\" did not install correctly to \"";
+ msg += install;
+ msg += "\"";
+ throw std::runtime_error(msg);
+ }
+ }
+ if (!file_exists(install))
+ {
+ std::string msg = "Timezone database not found at \"";
+ msg += install;
+ msg += "\"";
+ throw std::runtime_error(msg);
+ }
+ db->version = get_version(path);
+ }
+ else
+ {
+ db->version = get_version(path);
+ auto rv = remote_version();
+ if (!rv.empty() && db->version != rv)
+ {
+ if (remote_download(rv))
+ {
+ remote_install(rv);
+ db->version = get_version(path);
+ }
+ }
+ }
+#else // !AUTO_DOWNLOAD
+ if (!file_exists(install))
+ {
+ std::string msg = "Timezone database not found at \"";
+ msg += install;
+ msg += "\"";
+ throw std::runtime_error(msg);
+ }
+ db->version = get_version(path);
+#endif // !AUTO_DOWNLOAD
+
+ CONSTDATA char*const files[] =
+ {
+ "africa", "antarctica", "asia", "australasia", "backward", "etcetera", "europe",
+ "pacificnew", "northamerica", "southamerica", "systemv", "leapseconds"
+ };
+
+ for (const auto& filename : files)
+ {
+ std::ifstream infile(path + filename);
+ while (infile)
+ {
+ std::getline(infile, line);
+ if (!line.empty() && line[0] != '#')
+ {
+ std::istringstream in(line);
+ std::string word;
+ in >> word;
+ if (word == "Rule")
+ {
+ db->rules.push_back(Rule(line));
+ continue_zone = false;
+ }
+ else if (word == "Link")
+ {
+ db->links.push_back(time_zone_link(line));
+ continue_zone = false;
+ }
+ else if (word == "Leap")
+ {
+ db->leap_seconds.push_back(leap_second(line, detail::undocumented{}));
+ continue_zone = false;
+ }
+ else if (word == "Zone")
+ {
+ db->zones.push_back(time_zone(line, detail::undocumented{}));
+ continue_zone = true;
+ }
+ else if (line[0] == '\t' && continue_zone)
+ {
+ db->zones.back().add(line);
+ }
+ else
+ {
+ std::cerr << line << '\n';
+ }
+ }
+ }
+ }
+ std::sort(db->rules.begin(), db->rules.end());
+ Rule::split_overlaps(db->rules);
+ std::sort(db->zones.begin(), db->zones.end());
+ db->zones.shrink_to_fit();
+ std::sort(db->links.begin(), db->links.end());
+ db->links.shrink_to_fit();
+ std::sort(db->leap_seconds.begin(), db->leap_seconds.end());
+ db->leap_seconds.shrink_to_fit();
+
+#ifdef _WIN32
+ std::string mapping_file = get_install() + folder_delimiter + "windowsZones.xml";
+ db->mappings = load_timezone_mappings_from_xml_file(mapping_file);
+ sort_zone_mappings(db->mappings);
+#endif // _WIN32
+
+ return db;
+}
+
+const tzdb&
+reload_tzdb()
+{
+#if AUTO_DOWNLOAD
+ auto const& v = get_tzdb_list().front().version;
+ if (!v.empty() && v == remote_version())
+ return get_tzdb_list().front();
+#endif // AUTO_DOWNLOAD
+ tzdb_list::undocumented_helper::push_front(get_tzdb_list(), init_tzdb().release());
+ return get_tzdb_list().front();
+}
+
+#endif // !USE_OS_TZDB
+
+const tzdb&
+get_tzdb()
+{
+ return get_tzdb_list().front();
+}
+
+const time_zone*
+#if HAS_STRING_VIEW
+tzdb::locate_zone(std::string_view tz_name) const
+#else
+tzdb::locate_zone(const std::string& tz_name) const
+#endif
+{
+ auto zi = std::lower_bound(zones.begin(), zones.end(), tz_name,
+#if HAS_STRING_VIEW
+ [](const time_zone& z, const std::string_view& nm)
+#else
+ [](const time_zone& z, const std::string& nm)
+#endif
+ {
+ return z.name() < nm;
+ });
+ if (zi == zones.end() || zi->name() != tz_name)
+ {
+#if !USE_OS_TZDB
+ auto li = std::lower_bound(links.begin(), links.end(), tz_name,
+#if HAS_STRING_VIEW
+ [](const time_zone_link& z, const std::string_view& nm)
+#else
+ [](const time_zone_link& z, const std::string& nm)
+#endif
+ {
+ return z.name() < nm;
+ });
+ if (li != links.end() && li->name() == tz_name)
+ {
+ zi = std::lower_bound(zones.begin(), zones.end(), li->target(),
+ [](const time_zone& z, const std::string& nm)
+ {
+ return z.name() < nm;
+ });
+ if (zi != zones.end() && zi->name() == li->target())
+ return &*zi;
+ }
+#endif // !USE_OS_TZDB
+ throw std::runtime_error(std::string(tz_name) + " not found in timezone database");
+ }
+ return &*zi;
+}
+
+const time_zone*
+#if HAS_STRING_VIEW
+locate_zone(std::string_view tz_name)
+#else
+locate_zone(const std::string& tz_name)
+#endif
+{
+ return get_tzdb().locate_zone(tz_name);
+}
+
+#if USE_OS_TZDB
+
+std::ostream&
+operator<<(std::ostream& os, const tzdb& db)
+{
+ os << "Version: " << db.version << "\n\n";
+ for (const auto& x : db.zones)
+ os << x << '\n';
+ os << '\n';
+ for (const auto& x : db.leap_seconds)
+ os << x << '\n';
+ return os;
+}
+
+#else // !USE_OS_TZDB
+
+std::ostream&
+operator<<(std::ostream& os, const tzdb& db)
+{
+ os << "Version: " << db.version << '\n';
+ std::string title("--------------------------------------------"
+ "--------------------------------------------\n"
+ "Name ""Start Y ""End Y "
+ "Beginning ""Offset "
+ "Designator\n"
+ "--------------------------------------------"
+ "--------------------------------------------\n");
+ int count = 0;
+ for (const auto& x : db.rules)
+ {
+ if (count++ % 50 == 0)
+ os << title;
+ os << x << '\n';
+ }
+ os << '\n';
+ title = std::string("---------------------------------------------------------"
+ "--------------------------------------------------------\n"
+ "Name ""Offset "
+ "Rule ""Abrev ""Until\n"
+ "---------------------------------------------------------"
+ "--------------------------------------------------------\n");
+ count = 0;
+ for (const auto& x : db.zones)
+ {
+ if (count++ % 10 == 0)
+ os << title;
+ os << x << '\n';
+ }
+ os << '\n';
+ title = std::string("---------------------------------------------------------"
+ "--------------------------------------------------------\n"
+ "Alias ""To\n"
+ "---------------------------------------------------------"
+ "--------------------------------------------------------\n");
+ count = 0;
+ for (const auto& x : db.links)
+ {
+ if (count++ % 45 == 0)
+ os << title;
+ os << x << '\n';
+ }
+ os << '\n';
+ title = std::string("---------------------------------------------------------"
+ "--------------------------------------------------------\n"
+ "Leap second on\n"
+ "---------------------------------------------------------"
+ "--------------------------------------------------------\n");
+ os << title;
+ for (const auto& x : db.leap_seconds)
+ os << x << '\n';
+ return os;
+}
+
+#endif // !USE_OS_TZDB
+
+// -----------------------
+
+#ifdef _WIN32
+
+static
+std::string
+getTimeZoneKeyName()
+{
+ DYNAMIC_TIME_ZONE_INFORMATION dtzi{};
+ auto result = GetDynamicTimeZoneInformation(&dtzi);
+ if (result == TIME_ZONE_ID_INVALID)
+ throw std::runtime_error("current_zone(): GetDynamicTimeZoneInformation()"
+ " reported TIME_ZONE_ID_INVALID.");
+ auto wlen = wcslen(dtzi.TimeZoneKeyName);
+ char buf[128] = {};
+ assert(sizeof(buf) >= wlen+1);
+ wcstombs(buf, dtzi.TimeZoneKeyName, wlen);
+ if (strcmp(buf, "Coordinated Universal Time") == 0)
+ return "UTC";
+ return buf;
+}
+
+const time_zone*
+tzdb::current_zone() const
+{
+ std::string win_tzid = getTimeZoneKeyName();
+ std::string standard_tzid;
+ if (!native_to_standard_timezone_name(win_tzid, standard_tzid))
+ {
+ std::string msg;
+ msg = "current_zone() failed: A mapping from the Windows Time Zone id \"";
+ msg += win_tzid;
+ msg += "\" was not found in the time zone mapping database.";
+ throw std::runtime_error(msg);
+ }
+ return locate_zone(standard_tzid);
+}
+
+#else // !_WIN32
+
+#if HAS_STRING_VIEW
+
+static
+std::string_view
+extract_tz_name(char const* rp)
+{
+ using namespace std;
+ string_view result = rp;
+ CONSTDATA string_view zoneinfo = "zoneinfo";
+ size_t pos = result.rfind(zoneinfo);
+ if (pos == result.npos)
+ throw runtime_error(
+ "current_zone() failed to find \"zoneinfo\" in " + string(result));
+ pos = result.find('/', pos);
+ result.remove_prefix(pos + 1);
+ return result;
+}
+
+#else // !HAS_STRING_VIEW
+
+static
+std::string
+extract_tz_name(char const* rp)
+{
+ using namespace std;
+ string result = rp;
+ CONSTDATA char zoneinfo[] = "zoneinfo";
+ size_t pos = result.rfind(zoneinfo);
+ if (pos == result.npos)
+ throw runtime_error(
+ "current_zone() failed to find \"zoneinfo\" in " + result);
+ pos = result.find('/', pos);
+ result.erase(0, pos + 1);
+ return result;
+}
+
+#endif // HAS_STRING_VIEW
+
+static
+bool
+sniff_realpath(const char* timezone)
+{
+ using namespace std;
+ char rp[PATH_MAX+1] = {};
+ if (realpath(timezone, rp) == nullptr)
+ throw system_error(errno, system_category(), "realpath() failed");
+ auto result = extract_tz_name(rp);
+ return result != "posixrules";
+}
+
+const time_zone*
+tzdb::current_zone() const
+{
+ // On some OS's a file called /etc/localtime may
+ // exist and it may be either a real file
+ // containing time zone details or a symlink to such a file.
+ // On MacOS and BSD Unix if this file is a symlink it
+ // might resolve to a path like this:
+ // "/usr/share/zoneinfo/America/Los_Angeles"
+ // If it does, we try to determine the current
+ // timezone from the remainder of the path by removing the prefix
+ // and hoping the rest resolves to a valid timezone.
+ // It may not always work though. If it doesn't then an
+ // exception will be thrown by local_timezone.
+ // The path may also take a relative form:
+ // "../usr/share/zoneinfo/America/Los_Angeles".
+ {
+ struct stat sb;
+ CONSTDATA auto timezone = "/etc/localtime";
+ if (lstat(timezone, &sb) == 0 && S_ISLNK(sb.st_mode) && sb.st_size > 0)
+ {
+ using namespace std;
+ static const bool use_realpath = sniff_realpath(timezone);
+ char rp[PATH_MAX+1] = {};
+ if (use_realpath)
+ {
+ if (realpath(timezone, rp) == nullptr)
+ throw system_error(errno, system_category(), "realpath() failed");
+ }
+ else
+ {
+ if (readlink(timezone, rp, sizeof(rp)-1) <= 0)
+ throw system_error(errno, system_category(), "readlink() failed");
+ }
+ return locate_zone(extract_tz_name(rp));
+ }
+ }
+ // On embedded systems e.g. buildroot with uclibc the timezone is linked
+ // into /etc/TZ which is a symlink to path like this:
+ // "/usr/share/zoneinfo/uclibc/America/Los_Angeles"
+ // If it does, we try to determine the current
+ // timezone from the remainder of the path by removing the prefix
+ // and hoping the rest resolves to valid timezone.
+ // It may not always work though. If it doesn't then an
+ // exception will be thrown by local_timezone.
+ // The path may also take a relative form:
+ // "../usr/share/zoneinfo/uclibc/America/Los_Angeles".
+ {
+ struct stat sb;
+ CONSTDATA auto timezone = "/etc/TZ";
+ if (lstat(timezone, &sb) == 0 && S_ISLNK(sb.st_mode) && sb.st_size > 0) {
+ using namespace std;
+ string result;
+ char rp[PATH_MAX+1] = {};
+ if (readlink(timezone, rp, sizeof(rp)-1) > 0)
+ result = string(rp);
+ else
+ throw system_error(errno, system_category(), "readlink() failed");
+
+ const size_t pos = result.find(get_tz_dir());
+ if (pos != result.npos)
+ result.erase(0, get_tz_dir().size() + 1 + pos);
+ return locate_zone(result);
+ }
+ }
+ {
+ // On some versions of some linux distro's (e.g. Ubuntu),
+ // the current timezone might be in the first line of
+ // the /etc/timezone file.
+ std::ifstream timezone_file("/etc/timezone");
+ if (timezone_file.is_open())
+ {
+ std::string result;
+ std::getline(timezone_file, result);
+ if (!result.empty())
+ return locate_zone(result);
+ }
+ // Fall through to try other means.
+ }
+ {
+ // On some versions of some bsd distro's (e.g. FreeBSD),
+ // the current timezone might be in the first line of
+ // the /var/db/zoneinfo file.
+ std::ifstream timezone_file("/var/db/zoneinfo");
+ if (timezone_file.is_open())
+ {
+ std::string result;
+ std::getline(timezone_file, result);
+ if (!result.empty())
+ return locate_zone(result);
+ }
+ // Fall through to try other means.
+ }
+ {
+ // On some versions of some bsd distro's (e.g. iOS),
+ // it is not possible to use file based approach,
+ // we switch to system API, calling functions in
+ // CoreFoundation framework.
+#if TARGET_OS_IPHONE
+ std::string result = date::iOSUtils::get_current_timezone();
+ if (!result.empty())
+ return locate_zone(result);
+#endif
+ // Fall through to try other means.
+ }
+ {
+ // On some versions of some linux distro's (e.g. Red Hat),
+ // the current timezone might be in the first line of
+ // the /etc/sysconfig/clock file as:
+ // ZONE="US/Eastern"
+ std::ifstream timezone_file("/etc/sysconfig/clock");
+ std::string result;
+ while (timezone_file)
+ {
+ std::getline(timezone_file, result);
+ auto p = result.find("ZONE=\"");
+ if (p != std::string::npos)
+ {
+ result.erase(p, p+6);
+ result.erase(result.rfind('"'));
+ return locate_zone(result);
+ }
+ }
+ // Fall through to try other means.
+ }
+ throw std::runtime_error("Could not get current timezone");
+}
+
+#endif // !_WIN32
+
+const time_zone*
+current_zone()
+{
+ return get_tzdb().current_zone();
+}
+
+} // namespace date
+} // namespace arrow_vendored
+
+#if defined(__GNUC__) && __GNUC__ < 5
+# pragma GCC diagnostic pop
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/datetime/tz.h b/src/arrow/cpp/src/arrow/vendored/datetime/tz.h
new file mode 100644
index 000000000..6d54e49ea
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/datetime/tz.h
@@ -0,0 +1,2801 @@
+#ifndef TZ_H
+#define TZ_H
+
+// The MIT License (MIT)
+//
+// Copyright (c) 2015, 2016, 2017 Howard Hinnant
+// Copyright (c) 2017 Jiangang Zhuang
+// Copyright (c) 2017 Aaron Bishop
+// Copyright (c) 2017 Tomasz Kamiński
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+//
+// Our apologies. When the previous paragraph was written, lowercase had not yet
+// been invented (that would involve another several millennia of evolution).
+// We did not mean to shout.
+
+// Get more recent database at http://www.iana.org/time-zones
+
+// The notion of "current timezone" is something the operating system is expected to "just
+// know". How it knows this is system specific. It's often a value set by the user at OS
+// installation time and recorded by the OS somewhere. On Linux and Mac systems the current
+// timezone name is obtained by looking at the name or contents of a particular file on
+// disk. On Windows the current timezone name comes from the registry. In either method,
+// there is no guarantee that the "native" current timezone name obtained will match any
+// of the "Standard" names in this library's "database". On Linux, the names usually do
+// seem to match so mapping functions to map from native to "Standard" are typically not
+// required. On Windows, the names are never "Standard" so mapping is always required.
+// Technically any OS may use the mapping process but currently only Windows does use it.
+
+// NOTE(ARROW): If this is not set, then the library will attempt to
+// use libcurl to obtain a timezone database, and we probably do not want this.
+#ifndef _WIN32
+#define USE_OS_TZDB 1
+#endif
+
+#ifndef USE_OS_TZDB
+# define USE_OS_TZDB 0
+#endif
+
+#ifndef HAS_REMOTE_API
+# if USE_OS_TZDB == 0
+# ifdef _WIN32
+# define HAS_REMOTE_API 0
+# else
+# define HAS_REMOTE_API 1
+# endif
+# else // HAS_REMOTE_API makes no since when using the OS timezone database
+# define HAS_REMOTE_API 0
+# endif
+#endif
+
+#ifdef __clang__
+# pragma clang diagnostic push
+# pragma clang diagnostic ignored "-Wconstant-logical-operand"
+#endif
+
+static_assert(!(USE_OS_TZDB && HAS_REMOTE_API),
+ "USE_OS_TZDB and HAS_REMOTE_API can not be used together");
+
+#ifdef __clang__
+# pragma clang diagnostic pop
+#endif
+
+#ifndef AUTO_DOWNLOAD
+# define AUTO_DOWNLOAD HAS_REMOTE_API
+#endif
+
+static_assert(HAS_REMOTE_API == 0 ? AUTO_DOWNLOAD == 0 : true,
+ "AUTO_DOWNLOAD can not be turned on without HAS_REMOTE_API");
+
+#ifndef USE_SHELL_API
+# define USE_SHELL_API 1
+#endif
+
+#if USE_OS_TZDB
+# ifdef _WIN32
+# error "USE_OS_TZDB can not be used on Windows"
+# endif
+#endif
+
+#ifndef HAS_DEDUCTION_GUIDES
+# if __cplusplus >= 201703
+# define HAS_DEDUCTION_GUIDES 1
+# else
+# define HAS_DEDUCTION_GUIDES 0
+# endif
+#endif // HAS_DEDUCTION_GUIDES
+
+#include "date.h"
+
+#if defined(_MSC_VER) && (_MSC_VER < 1900)
+#include "tz_private.h"
+#endif
+
+#include <algorithm>
+#include <atomic>
+#include <cassert>
+#include <chrono>
+#include <istream>
+#include <locale>
+#include <memory>
+#include <mutex>
+#include <ostream>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#ifdef _WIN32
+# ifdef DATE_BUILD_DLL
+# define DATE_API __declspec(dllexport)
+# elif defined(DATE_USE_DLL)
+# define DATE_API __declspec(dllimport)
+# else
+# define DATE_API
+# endif
+#else
+# ifdef DATE_BUILD_DLL
+# define DATE_API __attribute__ ((visibility ("default")))
+# else
+# define DATE_API
+# endif
+#endif
+
+namespace arrow_vendored
+{
+namespace date
+{
+
+enum class choose {earliest, latest};
+
+namespace detail
+{
+ struct undocumented;
+
+ template<typename T>
+ struct nodeduct
+ {
+ using type = T;
+ };
+
+ template<typename T>
+ using nodeduct_t = typename nodeduct<T>::type;
+}
+
+struct sys_info
+{
+ sys_seconds begin;
+ sys_seconds end;
+ std::chrono::seconds offset;
+ std::chrono::minutes save;
+ std::string abbrev;
+};
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const sys_info& r)
+{
+ os << r.begin << '\n';
+ os << r.end << '\n';
+ os << make_time(r.offset) << "\n";
+ os << make_time(r.save) << "\n";
+ os << r.abbrev << '\n';
+ return os;
+}
+
+struct local_info
+{
+ enum {unique, nonexistent, ambiguous} result;
+ sys_info first;
+ sys_info second;
+};
+
+template<class CharT, class Traits>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const local_info& r)
+{
+ if (r.result == local_info::nonexistent)
+ os << "nonexistent between\n";
+ else if (r.result == local_info::ambiguous)
+ os << "ambiguous between\n";
+ os << r.first;
+ if (r.result != local_info::unique)
+ {
+ os << "and\n";
+ os << r.second;
+ }
+ return os;
+}
+
+class nonexistent_local_time
+ : public std::runtime_error
+{
+public:
+ template <class Duration>
+ nonexistent_local_time(local_time<Duration> tp, const local_info& i);
+
+private:
+ template <class Duration>
+ static
+ std::string
+ make_msg(local_time<Duration> tp, const local_info& i);
+};
+
+template <class Duration>
+inline
+nonexistent_local_time::nonexistent_local_time(local_time<Duration> tp,
+ const local_info& i)
+ : std::runtime_error(make_msg(tp, i))
+{
+}
+
+template <class Duration>
+std::string
+nonexistent_local_time::make_msg(local_time<Duration> tp, const local_info& i)
+{
+ assert(i.result == local_info::nonexistent);
+ std::ostringstream os;
+ os << tp << " is in a gap between\n"
+ << local_seconds{i.first.end.time_since_epoch()} + i.first.offset << ' '
+ << i.first.abbrev << " and\n"
+ << local_seconds{i.second.begin.time_since_epoch()} + i.second.offset << ' '
+ << i.second.abbrev
+ << " which are both equivalent to\n"
+ << i.first.end << " UTC";
+ return os.str();
+}
+
+class ambiguous_local_time
+ : public std::runtime_error
+{
+public:
+ template <class Duration>
+ ambiguous_local_time(local_time<Duration> tp, const local_info& i);
+
+private:
+ template <class Duration>
+ static
+ std::string
+ make_msg(local_time<Duration> tp, const local_info& i);
+};
+
+template <class Duration>
+inline
+ambiguous_local_time::ambiguous_local_time(local_time<Duration> tp, const local_info& i)
+ : std::runtime_error(make_msg(tp, i))
+{
+}
+
+template <class Duration>
+std::string
+ambiguous_local_time::make_msg(local_time<Duration> tp, const local_info& i)
+{
+ assert(i.result == local_info::ambiguous);
+ std::ostringstream os;
+ os << tp << " is ambiguous. It could be\n"
+ << tp << ' ' << i.first.abbrev << " == "
+ << tp - i.first.offset << " UTC or\n"
+ << tp << ' ' << i.second.abbrev << " == "
+ << tp - i.second.offset << " UTC";
+ return os.str();
+}
+
+class time_zone;
+
+#if HAS_STRING_VIEW
+DATE_API const time_zone* locate_zone(std::string_view tz_name);
+#else
+DATE_API const time_zone* locate_zone(const std::string& tz_name);
+#endif
+
+DATE_API const time_zone* current_zone();
+
+template <class T>
+struct zoned_traits
+{
+};
+
+template <>
+struct zoned_traits<const time_zone*>
+{
+ static
+ const time_zone*
+ default_zone()
+ {
+ return date::locate_zone("Etc/UTC");
+ }
+
+#if HAS_STRING_VIEW
+
+ static
+ const time_zone*
+ locate_zone(std::string_view name)
+ {
+ return date::locate_zone(name);
+ }
+
+#else // !HAS_STRING_VIEW
+
+ static
+ const time_zone*
+ locate_zone(const std::string& name)
+ {
+ return date::locate_zone(name);
+ }
+
+ static
+ const time_zone*
+ locate_zone(const char* name)
+ {
+ return date::locate_zone(name);
+ }
+
+#endif // !HAS_STRING_VIEW
+};
+
+template <class Duration, class TimeZonePtr>
+class zoned_time;
+
+template <class Duration1, class Duration2, class TimeZonePtr>
+bool
+operator==(const zoned_time<Duration1, TimeZonePtr>& x,
+ const zoned_time<Duration2, TimeZonePtr>& y);
+
+template <class Duration, class TimeZonePtr = const time_zone*>
+class zoned_time
+{
+public:
+ using duration = typename std::common_type<Duration, std::chrono::seconds>::type;
+
+private:
+ TimeZonePtr zone_;
+ sys_time<duration> tp_;
+
+public:
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+ template <class T = TimeZonePtr,
+ class = decltype(zoned_traits<T>::default_zone())>
+#endif
+ zoned_time();
+
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+ template <class T = TimeZonePtr,
+ class = decltype(zoned_traits<T>::default_zone())>
+#endif
+ zoned_time(const sys_time<Duration>& st);
+ explicit zoned_time(TimeZonePtr z);
+
+#if HAS_STRING_VIEW
+ template <class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string_view()))
+ >::value
+ >::type>
+ explicit zoned_time(std::string_view name);
+#else
+# if !defined(_MSC_VER) || (_MSC_VER > 1916)
+ template <class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string()))
+ >::value
+ >::type>
+# endif
+ explicit zoned_time(const std::string& name);
+#endif
+
+ template <class Duration2,
+ class = typename std::enable_if
+ <
+ std::is_convertible<sys_time<Duration2>,
+ sys_time<Duration>>::value
+ >::type>
+ zoned_time(const zoned_time<Duration2, TimeZonePtr>& zt) NOEXCEPT;
+
+ zoned_time(TimeZonePtr z, const sys_time<Duration>& st);
+
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+ template <class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_convertible
+ <
+ decltype(std::declval<T&>()->to_sys(local_time<Duration>{})),
+ sys_time<duration>
+ >::value
+ >::type>
+#endif
+ zoned_time(TimeZonePtr z, const local_time<Duration>& tp);
+
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+ template <class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_convertible
+ <
+ decltype(std::declval<T&>()->to_sys(local_time<Duration>{},
+ choose::earliest)),
+ sys_time<duration>
+ >::value
+ >::type>
+#endif
+ zoned_time(TimeZonePtr z, const local_time<Duration>& tp, choose c);
+
+ template <class Duration2, class TimeZonePtr2,
+ class = typename std::enable_if
+ <
+ std::is_convertible<sys_time<Duration2>,
+ sys_time<Duration>>::value
+ >::type>
+ zoned_time(TimeZonePtr z, const zoned_time<Duration2, TimeZonePtr2>& zt);
+
+ template <class Duration2, class TimeZonePtr2,
+ class = typename std::enable_if
+ <
+ std::is_convertible<sys_time<Duration2>,
+ sys_time<Duration>>::value
+ >::type>
+ zoned_time(TimeZonePtr z, const zoned_time<Duration2, TimeZonePtr2>& zt, choose);
+
+#if HAS_STRING_VIEW
+
+ template <class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string_view())),
+ sys_time<Duration>
+ >::value
+ >::type>
+ zoned_time(std::string_view name, detail::nodeduct_t<const sys_time<Duration>&> st);
+
+ template <class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string_view())),
+ local_time<Duration>
+ >::value
+ >::type>
+ zoned_time(std::string_view name, detail::nodeduct_t<const local_time<Duration>&> tp);
+
+ template <class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string_view())),
+ local_time<Duration>,
+ choose
+ >::value
+ >::type>
+ zoned_time(std::string_view name, detail::nodeduct_t<const local_time<Duration>&> tp, choose c);
+
+ template <class Duration2, class TimeZonePtr2, class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_convertible<sys_time<Duration2>,
+ sys_time<Duration>>::value &&
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string_view())),
+ zoned_time
+ >::value
+ >::type>
+ zoned_time(std::string_view name, const zoned_time<Duration2, TimeZonePtr2>& zt);
+
+ template <class Duration2, class TimeZonePtr2, class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_convertible<sys_time<Duration2>,
+ sys_time<Duration>>::value &&
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string_view())),
+ zoned_time,
+ choose
+ >::value
+ >::type>
+ zoned_time(std::string_view name, const zoned_time<Duration2, TimeZonePtr2>& zt, choose);
+
+#else // !HAS_STRING_VIEW
+
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+ template <class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string())),
+ sys_time<Duration>
+ >::value
+ >::type>
+#endif
+ zoned_time(const std::string& name, const sys_time<Duration>& st);
+
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+ template <class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string())),
+ sys_time<Duration>
+ >::value
+ >::type>
+#endif
+ zoned_time(const char* name, const sys_time<Duration>& st);
+
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+ template <class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string())),
+ local_time<Duration>
+ >::value
+ >::type>
+#endif
+ zoned_time(const std::string& name, const local_time<Duration>& tp);
+
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+ template <class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string())),
+ local_time<Duration>
+ >::value
+ >::type>
+#endif
+ zoned_time(const char* name, const local_time<Duration>& tp);
+
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+ template <class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string())),
+ local_time<Duration>,
+ choose
+ >::value
+ >::type>
+#endif
+ zoned_time(const std::string& name, const local_time<Duration>& tp, choose c);
+
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+ template <class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string())),
+ local_time<Duration>,
+ choose
+ >::value
+ >::type>
+#endif
+ zoned_time(const char* name, const local_time<Duration>& tp, choose c);
+
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+ template <class Duration2, class TimeZonePtr2, class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_convertible<sys_time<Duration2>,
+ sys_time<Duration>>::value &&
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string())),
+ zoned_time
+ >::value
+ >::type>
+#else
+ template <class Duration2, class TimeZonePtr2>
+#endif
+ zoned_time(const std::string& name, const zoned_time<Duration2, TimeZonePtr2>& zt);
+
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+ template <class Duration2, class TimeZonePtr2, class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_convertible<sys_time<Duration2>,
+ sys_time<Duration>>::value &&
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string())),
+ zoned_time
+ >::value
+ >::type>
+#else
+ template <class Duration2, class TimeZonePtr2>
+#endif
+ zoned_time(const char* name, const zoned_time<Duration2, TimeZonePtr2>& zt);
+
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+ template <class Duration2, class TimeZonePtr2, class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_convertible<sys_time<Duration2>,
+ sys_time<Duration>>::value &&
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string())),
+ zoned_time,
+ choose
+ >::value
+ >::type>
+#else
+ template <class Duration2, class TimeZonePtr2>
+#endif
+ zoned_time(const std::string& name, const zoned_time<Duration2, TimeZonePtr2>& zt,
+ choose);
+
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+ template <class Duration2, class TimeZonePtr2, class T = TimeZonePtr,
+ class = typename std::enable_if
+ <
+ std::is_convertible<sys_time<Duration2>,
+ sys_time<Duration>>::value &&
+ std::is_constructible
+ <
+ zoned_time,
+ decltype(zoned_traits<T>::locate_zone(std::string())),
+ zoned_time,
+ choose
+ >::value
+ >::type>
+#else
+ template <class Duration2, class TimeZonePtr2>
+#endif
+ zoned_time(const char* name, const zoned_time<Duration2, TimeZonePtr2>& zt,
+ choose);
+
+#endif // !HAS_STRING_VIEW
+
+ zoned_time& operator=(const sys_time<Duration>& st);
+ zoned_time& operator=(const local_time<Duration>& ut);
+
+ explicit operator sys_time<duration>() const;
+ explicit operator local_time<duration>() const;
+
+ TimeZonePtr get_time_zone() const;
+ local_time<duration> get_local_time() const;
+ sys_time<duration> get_sys_time() const;
+ sys_info get_info() const;
+
+ template <class Duration1, class Duration2, class TimeZonePtr1>
+ friend
+ bool
+ operator==(const zoned_time<Duration1, TimeZonePtr1>& x,
+ const zoned_time<Duration2, TimeZonePtr1>& y);
+
+ template <class CharT, class Traits, class Duration1, class TimeZonePtr1>
+ friend
+ std::basic_ostream<CharT, Traits>&
+ operator<<(std::basic_ostream<CharT, Traits>& os,
+ const zoned_time<Duration1, TimeZonePtr1>& t);
+
+private:
+ template <class D, class T> friend class zoned_time;
+
+ template <class TimeZonePtr2>
+ static
+ TimeZonePtr2&&
+ check(TimeZonePtr2&& p);
+};
+
+using zoned_seconds = zoned_time<std::chrono::seconds>;
+
+#if HAS_DEDUCTION_GUIDES
+
+namespace detail
+{
+ template<typename TimeZonePtrOrName>
+ using time_zone_representation =
+ std::conditional_t
+ <
+ std::is_convertible<TimeZonePtrOrName, std::string_view>::value,
+ time_zone const*,
+ std::remove_cv_t<std::remove_reference_t<TimeZonePtrOrName>>
+ >;
+}
+
+zoned_time()
+ -> zoned_time<std::chrono::seconds>;
+
+template <class Duration>
+zoned_time(sys_time<Duration>)
+ -> zoned_time<std::common_type_t<Duration, std::chrono::seconds>>;
+
+template <class TimeZonePtrOrName>
+zoned_time(TimeZonePtrOrName&&)
+ -> zoned_time<std::chrono::seconds, detail::time_zone_representation<TimeZonePtrOrName>>;
+
+template <class TimeZonePtrOrName, class Duration>
+zoned_time(TimeZonePtrOrName&&, sys_time<Duration>)
+ -> zoned_time<std::common_type_t<Duration, std::chrono::seconds>, detail::time_zone_representation<TimeZonePtrOrName>>;
+
+template <class TimeZonePtrOrName, class Duration>
+zoned_time(TimeZonePtrOrName&&, local_time<Duration>, choose = choose::earliest)
+ -> zoned_time<std::common_type_t<Duration, std::chrono::seconds>, detail::time_zone_representation<TimeZonePtrOrName>>;
+
+template <class Duration, class TimeZonePtrOrName, class TimeZonePtr2>
+zoned_time(TimeZonePtrOrName&&, zoned_time<Duration, TimeZonePtr2>, choose = choose::earliest)
+ -> zoned_time<std::common_type_t<Duration, std::chrono::seconds>, detail::time_zone_representation<TimeZonePtrOrName>>;
+
+#endif // HAS_DEDUCTION_GUIDES
+
+template <class Duration1, class Duration2, class TimeZonePtr>
+inline
+bool
+operator==(const zoned_time<Duration1, TimeZonePtr>& x,
+ const zoned_time<Duration2, TimeZonePtr>& y)
+{
+ return x.zone_ == y.zone_ && x.tp_ == y.tp_;
+}
+
+template <class Duration1, class Duration2, class TimeZonePtr>
+inline
+bool
+operator!=(const zoned_time<Duration1, TimeZonePtr>& x,
+ const zoned_time<Duration2, TimeZonePtr>& y)
+{
+ return !(x == y);
+}
+
+#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
+
+namespace detail
+{
+# if USE_OS_TZDB
+ struct transition;
+ struct expanded_ttinfo;
+# else // !USE_OS_TZDB
+ struct zonelet;
+ class Rule;
+# endif // !USE_OS_TZDB
+}
+
+#endif // !defined(_MSC_VER) || (_MSC_VER >= 1900)
+
+class time_zone
+{
+private:
+ std::string name_;
+#if USE_OS_TZDB
+ std::vector<detail::transition> transitions_;
+ std::vector<detail::expanded_ttinfo> ttinfos_;
+#else // !USE_OS_TZDB
+ std::vector<detail::zonelet> zonelets_;
+#endif // !USE_OS_TZDB
+ std::unique_ptr<std::once_flag> adjusted_;
+
+public:
+#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
+ time_zone(time_zone&&) = default;
+ time_zone& operator=(time_zone&&) = default;
+#else // defined(_MSC_VER) && (_MSC_VER < 1900)
+ time_zone(time_zone&& src);
+ time_zone& operator=(time_zone&& src);
+#endif // defined(_MSC_VER) && (_MSC_VER < 1900)
+
+ DATE_API explicit time_zone(const std::string& s, detail::undocumented);
+
+ const std::string& name() const NOEXCEPT;
+
+ template <class Duration> sys_info get_info(sys_time<Duration> st) const;
+ template <class Duration> local_info get_info(local_time<Duration> tp) const;
+
+ template <class Duration>
+ sys_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+ to_sys(local_time<Duration> tp) const;
+
+ template <class Duration>
+ sys_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+ to_sys(local_time<Duration> tp, choose z) const;
+
+ template <class Duration>
+ local_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+ to_local(sys_time<Duration> tp) const;
+
+ friend bool operator==(const time_zone& x, const time_zone& y) NOEXCEPT;
+ friend bool operator< (const time_zone& x, const time_zone& y) NOEXCEPT;
+ friend DATE_API std::ostream& operator<<(std::ostream& os, const time_zone& z);
+
+#if !USE_OS_TZDB
+ DATE_API void add(const std::string& s);
+#endif // !USE_OS_TZDB
+
+private:
+ DATE_API sys_info get_info_impl(sys_seconds tp) const;
+ DATE_API local_info get_info_impl(local_seconds tp) const;
+
+ template <class Duration>
+ sys_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+ to_sys_impl(local_time<Duration> tp, choose z, std::false_type) const;
+ template <class Duration>
+ sys_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+ to_sys_impl(local_time<Duration> tp, choose, std::true_type) const;
+
+#if USE_OS_TZDB
+ DATE_API void init() const;
+ DATE_API void init_impl();
+ DATE_API sys_info
+ load_sys_info(std::vector<detail::transition>::const_iterator i) const;
+
+ template <class TimeType>
+ DATE_API void
+ load_data(std::istream& inf, std::int32_t tzh_leapcnt, std::int32_t tzh_timecnt,
+ std::int32_t tzh_typecnt, std::int32_t tzh_charcnt);
+#else // !USE_OS_TZDB
+ DATE_API sys_info get_info_impl(sys_seconds tp, int timezone) const;
+ DATE_API void adjust_infos(const std::vector<detail::Rule>& rules);
+ DATE_API void parse_info(std::istream& in);
+#endif // !USE_OS_TZDB
+};
+
+#if defined(_MSC_VER) && (_MSC_VER < 1900)
+
+inline
+time_zone::time_zone(time_zone&& src)
+ : name_(std::move(src.name_))
+ , zonelets_(std::move(src.zonelets_))
+ , adjusted_(std::move(src.adjusted_))
+ {}
+
+inline
+time_zone&
+time_zone::operator=(time_zone&& src)
+{
+ name_ = std::move(src.name_);
+ zonelets_ = std::move(src.zonelets_);
+ adjusted_ = std::move(src.adjusted_);
+ return *this;
+}
+
+#endif // defined(_MSC_VER) && (_MSC_VER < 1900)
+
+inline
+const std::string&
+time_zone::name() const NOEXCEPT
+{
+ return name_;
+}
+
+template <class Duration>
+inline
+sys_info
+time_zone::get_info(sys_time<Duration> st) const
+{
+ return get_info_impl(date::floor<std::chrono::seconds>(st));
+}
+
+template <class Duration>
+inline
+local_info
+time_zone::get_info(local_time<Duration> tp) const
+{
+ return get_info_impl(date::floor<std::chrono::seconds>(tp));
+}
+
+template <class Duration>
+inline
+sys_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+time_zone::to_sys(local_time<Duration> tp) const
+{
+ return to_sys_impl(tp, choose{}, std::true_type{});
+}
+
+template <class Duration>
+inline
+sys_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+time_zone::to_sys(local_time<Duration> tp, choose z) const
+{
+ return to_sys_impl(tp, z, std::false_type{});
+}
+
+template <class Duration>
+inline
+local_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+time_zone::to_local(sys_time<Duration> tp) const
+{
+ using LT = local_time<typename std::common_type<Duration, std::chrono::seconds>::type>;
+ auto i = get_info(tp);
+ return LT{(tp + i.offset).time_since_epoch()};
+}
+
+inline bool operator==(const time_zone& x, const time_zone& y) NOEXCEPT {return x.name_ == y.name_;}
+inline bool operator< (const time_zone& x, const time_zone& y) NOEXCEPT {return x.name_ < y.name_;}
+
+inline bool operator!=(const time_zone& x, const time_zone& y) NOEXCEPT {return !(x == y);}
+inline bool operator> (const time_zone& x, const time_zone& y) NOEXCEPT {return y < x;}
+inline bool operator<=(const time_zone& x, const time_zone& y) NOEXCEPT {return !(y < x);}
+inline bool operator>=(const time_zone& x, const time_zone& y) NOEXCEPT {return !(x < y);}
+
+template <class Duration>
+sys_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+time_zone::to_sys_impl(local_time<Duration> tp, choose z, std::false_type) const
+{
+ auto i = get_info(tp);
+ if (i.result == local_info::nonexistent)
+ {
+ return i.first.end;
+ }
+ else if (i.result == local_info::ambiguous)
+ {
+ if (z == choose::latest)
+ return sys_time<Duration>{tp.time_since_epoch()} - i.second.offset;
+ }
+ return sys_time<Duration>{tp.time_since_epoch()} - i.first.offset;
+}
+
+template <class Duration>
+sys_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+time_zone::to_sys_impl(local_time<Duration> tp, choose, std::true_type) const
+{
+ auto i = get_info(tp);
+ if (i.result == local_info::nonexistent)
+ throw nonexistent_local_time(tp, i);
+ else if (i.result == local_info::ambiguous)
+ throw ambiguous_local_time(tp, i);
+ return sys_time<Duration>{tp.time_since_epoch()} - i.first.offset;
+}
+
+#if !USE_OS_TZDB
+
+class time_zone_link
+{
+private:
+ std::string name_;
+ std::string target_;
+public:
+ DATE_API explicit time_zone_link(const std::string& s);
+
+ const std::string& name() const {return name_;}
+ const std::string& target() const {return target_;}
+
+ friend bool operator==(const time_zone_link& x, const time_zone_link& y) {return x.name_ == y.name_;}
+ friend bool operator< (const time_zone_link& x, const time_zone_link& y) {return x.name_ < y.name_;}
+
+ friend DATE_API std::ostream& operator<<(std::ostream& os, const time_zone_link& x);
+};
+
+using link = time_zone_link;
+
+inline bool operator!=(const time_zone_link& x, const time_zone_link& y) {return !(x == y);}
+inline bool operator> (const time_zone_link& x, const time_zone_link& y) {return y < x;}
+inline bool operator<=(const time_zone_link& x, const time_zone_link& y) {return !(y < x);}
+inline bool operator>=(const time_zone_link& x, const time_zone_link& y) {return !(x < y);}
+
+#endif // !USE_OS_TZDB
+
+class leap_second
+{
+private:
+ sys_seconds date_;
+
+public:
+#if USE_OS_TZDB
+ DATE_API explicit leap_second(const sys_seconds& s, detail::undocumented);
+#else
+ DATE_API explicit leap_second(const std::string& s, detail::undocumented);
+#endif
+
+ sys_seconds date() const {return date_;}
+
+ friend bool operator==(const leap_second& x, const leap_second& y) {return x.date_ == y.date_;}
+ friend bool operator< (const leap_second& x, const leap_second& y) {return x.date_ < y.date_;}
+
+ template <class Duration>
+ friend
+ bool
+ operator==(const leap_second& x, const sys_time<Duration>& y)
+ {
+ return x.date_ == y;
+ }
+
+ template <class Duration>
+ friend
+ bool
+ operator< (const leap_second& x, const sys_time<Duration>& y)
+ {
+ return x.date_ < y;
+ }
+
+ template <class Duration>
+ friend
+ bool
+ operator< (const sys_time<Duration>& x, const leap_second& y)
+ {
+ return x < y.date_;
+ }
+
+ friend DATE_API std::ostream& operator<<(std::ostream& os, const leap_second& x);
+};
+
+inline bool operator!=(const leap_second& x, const leap_second& y) {return !(x == y);}
+inline bool operator> (const leap_second& x, const leap_second& y) {return y < x;}
+inline bool operator<=(const leap_second& x, const leap_second& y) {return !(y < x);}
+inline bool operator>=(const leap_second& x, const leap_second& y) {return !(x < y);}
+
+template <class Duration>
+inline
+bool
+operator==(const sys_time<Duration>& x, const leap_second& y)
+{
+ return y == x;
+}
+
+template <class Duration>
+inline
+bool
+operator!=(const leap_second& x, const sys_time<Duration>& y)
+{
+ return !(x == y);
+}
+
+template <class Duration>
+inline
+bool
+operator!=(const sys_time<Duration>& x, const leap_second& y)
+{
+ return !(x == y);
+}
+
+template <class Duration>
+inline
+bool
+operator> (const leap_second& x, const sys_time<Duration>& y)
+{
+ return y < x;
+}
+
+template <class Duration>
+inline
+bool
+operator> (const sys_time<Duration>& x, const leap_second& y)
+{
+ return y < x;
+}
+
+template <class Duration>
+inline
+bool
+operator<=(const leap_second& x, const sys_time<Duration>& y)
+{
+ return !(y < x);
+}
+
+template <class Duration>
+inline
+bool
+operator<=(const sys_time<Duration>& x, const leap_second& y)
+{
+ return !(y < x);
+}
+
+template <class Duration>
+inline
+bool
+operator>=(const leap_second& x, const sys_time<Duration>& y)
+{
+ return !(x < y);
+}
+
+template <class Duration>
+inline
+bool
+operator>=(const sys_time<Duration>& x, const leap_second& y)
+{
+ return !(x < y);
+}
+
+using leap = leap_second;
+
+#ifdef _WIN32
+
+namespace detail
+{
+
+// The time zone mapping is modelled after this data file:
+// http://unicode.org/repos/cldr/trunk/common/supplemental/windowsZones.xml
+// and the field names match the element names from the mapZone element
+// of windowsZones.xml.
+// The website displays this file here:
+// http://www.unicode.org/cldr/charts/latest/supplemental/zone_tzid.html
+// The html view is sorted before being displayed but is otherwise the same
+// There is a mapping between the os centric view (in this case windows)
+// the html displays uses and the generic view the xml file.
+// That mapping is this:
+// display column "windows" -> xml field "other".
+// display column "region" -> xml field "territory".
+// display column "tzid" -> xml field "type".
+// This structure uses the generic terminology because it could be
+// used to to support other os/native name conversions, not just windows,
+// and using the same generic names helps retain the connection to the
+// origin of the data that we are using.
+struct timezone_mapping
+{
+ timezone_mapping(const char* other, const char* territory, const char* type)
+ : other(other), territory(territory), type(type)
+ {
+ }
+ timezone_mapping() = default;
+ std::string other;
+ std::string territory;
+ std::string type;
+};
+
+} // detail
+
+#endif // _WIN32
+
+struct tzdb
+{
+ std::string version = "unknown";
+ std::vector<time_zone> zones;
+#if !USE_OS_TZDB
+ std::vector<time_zone_link> links;
+#endif
+ std::vector<leap_second> leap_seconds;
+#if !USE_OS_TZDB
+ std::vector<detail::Rule> rules;
+#endif
+#ifdef _WIN32
+ std::vector<detail::timezone_mapping> mappings;
+#endif
+ tzdb* next = nullptr;
+
+ tzdb() = default;
+#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
+ tzdb(tzdb&&) = default;
+ tzdb& operator=(tzdb&&) = default;
+#else // defined(_MSC_VER) && (_MSC_VER < 1900)
+ tzdb(tzdb&& src)
+ : version(std::move(src.version))
+ , zones(std::move(src.zones))
+ , links(std::move(src.links))
+ , leap_seconds(std::move(src.leap_seconds))
+ , rules(std::move(src.rules))
+ , mappings(std::move(src.mappings))
+ {}
+
+ tzdb& operator=(tzdb&& src)
+ {
+ version = std::move(src.version);
+ zones = std::move(src.zones);
+ links = std::move(src.links);
+ leap_seconds = std::move(src.leap_seconds);
+ rules = std::move(src.rules);
+ mappings = std::move(src.mappings);
+ return *this;
+ }
+#endif // defined(_MSC_VER) && (_MSC_VER < 1900)
+
+#if HAS_STRING_VIEW
+ const time_zone* locate_zone(std::string_view tz_name) const;
+#else
+ const time_zone* locate_zone(const std::string& tz_name) const;
+#endif
+ const time_zone* current_zone() const;
+};
+
+using TZ_DB = tzdb;
+
+DATE_API std::ostream&
+operator<<(std::ostream& os, const tzdb& db);
+
+DATE_API const tzdb& get_tzdb();
+
+class tzdb_list
+{
+ std::atomic<tzdb*> head_{nullptr};
+
+public:
+ ~tzdb_list();
+ tzdb_list() = default;
+ tzdb_list(tzdb_list&& x) NOEXCEPT;
+
+ const tzdb& front() const NOEXCEPT {return *head_;}
+ tzdb& front() NOEXCEPT {return *head_;}
+
+ class const_iterator;
+
+ const_iterator begin() const NOEXCEPT;
+ const_iterator end() const NOEXCEPT;
+
+ const_iterator cbegin() const NOEXCEPT;
+ const_iterator cend() const NOEXCEPT;
+
+ const_iterator erase_after(const_iterator p) NOEXCEPT;
+
+ struct undocumented_helper;
+private:
+ void push_front(tzdb* tzdb) NOEXCEPT;
+};
+
+class tzdb_list::const_iterator
+{
+ tzdb* p_ = nullptr;
+
+ explicit const_iterator(tzdb* p) NOEXCEPT : p_{p} {}
+public:
+ const_iterator() = default;
+
+ using iterator_category = std::forward_iterator_tag;
+ using value_type = tzdb;
+ using reference = const value_type&;
+ using pointer = const value_type*;
+ using difference_type = std::ptrdiff_t;
+
+ reference operator*() const NOEXCEPT {return *p_;}
+ pointer operator->() const NOEXCEPT {return p_;}
+
+ const_iterator& operator++() NOEXCEPT {p_ = p_->next; return *this;}
+ const_iterator operator++(int) NOEXCEPT {auto t = *this; ++(*this); return t;}
+
+ friend
+ bool
+ operator==(const const_iterator& x, const const_iterator& y) NOEXCEPT
+ {return x.p_ == y.p_;}
+
+ friend
+ bool
+ operator!=(const const_iterator& x, const const_iterator& y) NOEXCEPT
+ {return !(x == y);}
+
+ friend class tzdb_list;
+};
+
+inline
+tzdb_list::const_iterator
+tzdb_list::begin() const NOEXCEPT
+{
+ return const_iterator{head_};
+}
+
+inline
+tzdb_list::const_iterator
+tzdb_list::end() const NOEXCEPT
+{
+ return const_iterator{nullptr};
+}
+
+inline
+tzdb_list::const_iterator
+tzdb_list::cbegin() const NOEXCEPT
+{
+ return begin();
+}
+
+inline
+tzdb_list::const_iterator
+tzdb_list::cend() const NOEXCEPT
+{
+ return end();
+}
+
+DATE_API tzdb_list& get_tzdb_list();
+
+#if !USE_OS_TZDB
+
+DATE_API const tzdb& reload_tzdb();
+DATE_API void set_install(const std::string& install);
+
+#endif // !USE_OS_TZDB
+
+#if HAS_REMOTE_API
+
+DATE_API std::string remote_version();
+// if provided error_buffer size should be at least CURL_ERROR_SIZE
+DATE_API bool remote_download(const std::string& version, char* error_buffer = nullptr);
+DATE_API bool remote_install(const std::string& version);
+
+#endif
+
+// zoned_time
+
+namespace detail
+{
+
+template <class T>
+inline
+T*
+to_raw_pointer(T* p) NOEXCEPT
+{
+ return p;
+}
+
+template <class Pointer>
+inline
+auto
+to_raw_pointer(Pointer p) NOEXCEPT
+ -> decltype(detail::to_raw_pointer(p.operator->()))
+{
+ return detail::to_raw_pointer(p.operator->());
+}
+
+} // namespace detail
+
+template <class Duration, class TimeZonePtr>
+template <class TimeZonePtr2>
+inline
+TimeZonePtr2&&
+zoned_time<Duration, TimeZonePtr>::check(TimeZonePtr2&& p)
+{
+ if (detail::to_raw_pointer(p) == nullptr)
+ throw std::runtime_error(
+ "zoned_time constructed with a time zone pointer == nullptr");
+ return std::forward<TimeZonePtr2>(p);
+}
+
+template <class Duration, class TimeZonePtr>
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+template <class T, class>
+#endif
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time()
+ : zone_(check(zoned_traits<TimeZonePtr>::default_zone()))
+ {}
+
+template <class Duration, class TimeZonePtr>
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+template <class T, class>
+#endif
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(const sys_time<Duration>& st)
+ : zone_(check(zoned_traits<TimeZonePtr>::default_zone()))
+ , tp_(st)
+ {}
+
+template <class Duration, class TimeZonePtr>
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(TimeZonePtr z)
+ : zone_(check(std::move(z)))
+ {}
+
+#if HAS_STRING_VIEW
+
+template <class Duration, class TimeZonePtr>
+template <class T, class>
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(std::string_view name)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name))
+ {}
+
+#else // !HAS_STRING_VIEW
+
+template <class Duration, class TimeZonePtr>
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+template <class T, class>
+#endif
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(const std::string& name)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name))
+ {}
+
+#endif // !HAS_STRING_VIEW
+
+template <class Duration, class TimeZonePtr>
+template <class Duration2, class>
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(const zoned_time<Duration2, TimeZonePtr>& zt) NOEXCEPT
+ : zone_(zt.zone_)
+ , tp_(zt.tp_)
+ {}
+
+template <class Duration, class TimeZonePtr>
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(TimeZonePtr z, const sys_time<Duration>& st)
+ : zone_(check(std::move(z)))
+ , tp_(st)
+ {}
+
+template <class Duration, class TimeZonePtr>
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+template <class T, class>
+#endif
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(TimeZonePtr z, const local_time<Duration>& t)
+ : zone_(check(std::move(z)))
+ , tp_(zone_->to_sys(t))
+ {}
+
+template <class Duration, class TimeZonePtr>
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+template <class T, class>
+#endif
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(TimeZonePtr z, const local_time<Duration>& t,
+ choose c)
+ : zone_(check(std::move(z)))
+ , tp_(zone_->to_sys(t, c))
+ {}
+
+template <class Duration, class TimeZonePtr>
+template <class Duration2, class TimeZonePtr2, class>
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(TimeZonePtr z,
+ const zoned_time<Duration2, TimeZonePtr2>& zt)
+ : zone_(check(std::move(z)))
+ , tp_(zt.tp_)
+ {}
+
+template <class Duration, class TimeZonePtr>
+template <class Duration2, class TimeZonePtr2, class>
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(TimeZonePtr z,
+ const zoned_time<Duration2, TimeZonePtr2>& zt, choose)
+ : zoned_time(std::move(z), zt)
+ {}
+
+#if HAS_STRING_VIEW
+
+template <class Duration, class TimeZonePtr>
+template <class T, class>
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(std::string_view name,
+ detail::nodeduct_t<const sys_time<Duration>&> st)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name), st)
+ {}
+
+template <class Duration, class TimeZonePtr>
+template <class T, class>
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(std::string_view name,
+ detail::nodeduct_t<const local_time<Duration>&> t)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name), t)
+ {}
+
+template <class Duration, class TimeZonePtr>
+template <class T, class>
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(std::string_view name,
+ detail::nodeduct_t<const local_time<Duration>&> t, choose c)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name), t, c)
+ {}
+
+template <class Duration, class TimeZonePtr>
+template <class Duration2, class TimeZonePtr2, class, class>
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(std::string_view name,
+ const zoned_time<Duration2, TimeZonePtr2>& zt)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name), zt)
+ {}
+
+template <class Duration, class TimeZonePtr>
+template <class Duration2, class TimeZonePtr2, class, class>
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(std::string_view name,
+ const zoned_time<Duration2, TimeZonePtr2>& zt,
+ choose c)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name), zt, c)
+ {}
+
+#else // !HAS_STRING_VIEW
+
+template <class Duration, class TimeZonePtr>
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+template <class T, class>
+#endif
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(const std::string& name,
+ const sys_time<Duration>& st)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name), st)
+ {}
+
+template <class Duration, class TimeZonePtr>
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+template <class T, class>
+#endif
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(const char* name,
+ const sys_time<Duration>& st)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name), st)
+ {}
+
+template <class Duration, class TimeZonePtr>
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+template <class T, class>
+#endif
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(const std::string& name,
+ const local_time<Duration>& t)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name), t)
+ {}
+
+template <class Duration, class TimeZonePtr>
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+template <class T, class>
+#endif
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(const char* name,
+ const local_time<Duration>& t)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name), t)
+ {}
+
+template <class Duration, class TimeZonePtr>
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+template <class T, class>
+#endif
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(const std::string& name,
+ const local_time<Duration>& t, choose c)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name), t, c)
+ {}
+
+template <class Duration, class TimeZonePtr>
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+template <class T, class>
+#endif
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(const char* name,
+ const local_time<Duration>& t, choose c)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name), t, c)
+ {}
+
+template <class Duration, class TimeZonePtr>
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+template <class Duration2, class TimeZonePtr2, class, class>
+#else
+template <class Duration2, class TimeZonePtr2>
+#endif
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(const std::string& name,
+ const zoned_time<Duration2, TimeZonePtr2>& zt)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name), zt)
+ {}
+
+template <class Duration, class TimeZonePtr>
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+template <class Duration2, class TimeZonePtr2, class, class>
+#else
+template <class Duration2, class TimeZonePtr2>
+#endif
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(const char* name,
+ const zoned_time<Duration2, TimeZonePtr2>& zt)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name), zt)
+ {}
+
+template <class Duration, class TimeZonePtr>
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+template <class Duration2, class TimeZonePtr2, class, class>
+#else
+template <class Duration2, class TimeZonePtr2>
+#endif
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(const std::string& name,
+ const zoned_time<Duration2, TimeZonePtr2>& zt,
+ choose c)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name), zt, c)
+ {}
+
+template <class Duration, class TimeZonePtr>
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+template <class Duration2, class TimeZonePtr2, class, class>
+#else
+template <class Duration2, class TimeZonePtr2>
+#endif
+inline
+zoned_time<Duration, TimeZonePtr>::zoned_time(const char* name,
+ const zoned_time<Duration2, TimeZonePtr2>& zt,
+ choose c)
+ : zoned_time(zoned_traits<TimeZonePtr>::locate_zone(name), zt, c)
+ {}
+
+#endif // HAS_STRING_VIEW
+
+template <class Duration, class TimeZonePtr>
+inline
+zoned_time<Duration, TimeZonePtr>&
+zoned_time<Duration, TimeZonePtr>::operator=(const sys_time<Duration>& st)
+{
+ tp_ = st;
+ return *this;
+}
+
+template <class Duration, class TimeZonePtr>
+inline
+zoned_time<Duration, TimeZonePtr>&
+zoned_time<Duration, TimeZonePtr>::operator=(const local_time<Duration>& ut)
+{
+ tp_ = zone_->to_sys(ut);
+ return *this;
+}
+
+template <class Duration, class TimeZonePtr>
+inline
+zoned_time<Duration, TimeZonePtr>::operator local_time<typename zoned_time<Duration, TimeZonePtr>::duration>() const
+{
+ return get_local_time();
+}
+
+template <class Duration, class TimeZonePtr>
+inline
+zoned_time<Duration, TimeZonePtr>::operator sys_time<typename zoned_time<Duration, TimeZonePtr>::duration>() const
+{
+ return get_sys_time();
+}
+
+template <class Duration, class TimeZonePtr>
+inline
+TimeZonePtr
+zoned_time<Duration, TimeZonePtr>::get_time_zone() const
+{
+ return zone_;
+}
+
+template <class Duration, class TimeZonePtr>
+inline
+local_time<typename zoned_time<Duration, TimeZonePtr>::duration>
+zoned_time<Duration, TimeZonePtr>::get_local_time() const
+{
+ return zone_->to_local(tp_);
+}
+
+template <class Duration, class TimeZonePtr>
+inline
+sys_time<typename zoned_time<Duration, TimeZonePtr>::duration>
+zoned_time<Duration, TimeZonePtr>::get_sys_time() const
+{
+ return tp_;
+}
+
+template <class Duration, class TimeZonePtr>
+inline
+sys_info
+zoned_time<Duration, TimeZonePtr>::get_info() const
+{
+ return zone_->get_info(tp_);
+}
+
+// make_zoned_time
+
+inline
+zoned_time<std::chrono::seconds>
+make_zoned()
+{
+ return zoned_time<std::chrono::seconds>();
+}
+
+template <class Duration>
+inline
+zoned_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+make_zoned(const sys_time<Duration>& tp)
+{
+ return zoned_time<typename std::common_type<Duration, std::chrono::seconds>::type>(tp);
+}
+
+template <class TimeZonePtr
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+#if !defined(__INTEL_COMPILER) || (__INTEL_COMPILER > 1600)
+ , class = typename std::enable_if
+ <
+ std::is_class
+ <
+ typename std::decay
+ <
+ decltype(*detail::to_raw_pointer(std::declval<TimeZonePtr&>()))
+ >::type
+ >{}
+ >::type
+#endif
+#endif
+ >
+inline
+zoned_time<std::chrono::seconds, TimeZonePtr>
+make_zoned(TimeZonePtr z)
+{
+ return zoned_time<std::chrono::seconds, TimeZonePtr>(std::move(z));
+}
+
+inline
+zoned_seconds
+make_zoned(const std::string& name)
+{
+ return zoned_seconds(name);
+}
+
+template <class Duration, class TimeZonePtr
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+#if !defined(__INTEL_COMPILER) || (__INTEL_COMPILER > 1600)
+ , class = typename std::enable_if
+ <
+ std::is_class<typename std::decay<decltype(*std::declval<TimeZonePtr&>())>::type>{}
+ >::type
+#endif
+#endif
+ >
+inline
+zoned_time<typename std::common_type<Duration, std::chrono::seconds>::type, TimeZonePtr>
+make_zoned(TimeZonePtr zone, const local_time<Duration>& tp)
+{
+ return zoned_time<typename std::common_type<Duration, std::chrono::seconds>::type,
+ TimeZonePtr>(std::move(zone), tp);
+}
+
+template <class Duration, class TimeZonePtr
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+#if !defined(__INTEL_COMPILER) || (__INTEL_COMPILER > 1600)
+ , class = typename std::enable_if
+ <
+ std::is_class<typename std::decay<decltype(*std::declval<TimeZonePtr&>())>::type>{}
+ >::type
+#endif
+#endif
+ >
+inline
+zoned_time<typename std::common_type<Duration, std::chrono::seconds>::type, TimeZonePtr>
+make_zoned(TimeZonePtr zone, const local_time<Duration>& tp, choose c)
+{
+ return zoned_time<typename std::common_type<Duration, std::chrono::seconds>::type,
+ TimeZonePtr>(std::move(zone), tp, c);
+}
+
+template <class Duration>
+inline
+zoned_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+make_zoned(const std::string& name, const local_time<Duration>& tp)
+{
+ return zoned_time<typename std::common_type<Duration,
+ std::chrono::seconds>::type>(name, tp);
+}
+
+template <class Duration>
+inline
+zoned_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+make_zoned(const std::string& name, const local_time<Duration>& tp, choose c)
+{
+ return zoned_time<typename std::common_type<Duration,
+ std::chrono::seconds>::type>(name, tp, c);
+}
+
+template <class Duration, class TimeZonePtr>
+inline
+zoned_time<Duration, TimeZonePtr>
+make_zoned(TimeZonePtr zone, const zoned_time<Duration, TimeZonePtr>& zt)
+{
+ return zoned_time<Duration, TimeZonePtr>(std::move(zone), zt);
+}
+
+template <class Duration, class TimeZonePtr>
+inline
+zoned_time<Duration, TimeZonePtr>
+make_zoned(const std::string& name, const zoned_time<Duration, TimeZonePtr>& zt)
+{
+ return zoned_time<Duration, TimeZonePtr>(name, zt);
+}
+
+template <class Duration, class TimeZonePtr>
+inline
+zoned_time<Duration, TimeZonePtr>
+make_zoned(TimeZonePtr zone, const zoned_time<Duration, TimeZonePtr>& zt, choose c)
+{
+ return zoned_time<Duration, TimeZonePtr>(std::move(zone), zt, c);
+}
+
+template <class Duration, class TimeZonePtr>
+inline
+zoned_time<Duration, TimeZonePtr>
+make_zoned(const std::string& name, const zoned_time<Duration, TimeZonePtr>& zt, choose c)
+{
+ return zoned_time<Duration, TimeZonePtr>(name, zt, c);
+}
+
+template <class Duration, class TimeZonePtr
+#if !defined(_MSC_VER) || (_MSC_VER > 1916)
+#if !defined(__INTEL_COMPILER) || (__INTEL_COMPILER > 1600)
+ , class = typename std::enable_if
+ <
+ std::is_class<typename std::decay<decltype(*std::declval<TimeZonePtr&>())>::type>{}
+ >::type
+#endif
+#endif
+ >
+inline
+zoned_time<typename std::common_type<Duration, std::chrono::seconds>::type, TimeZonePtr>
+make_zoned(TimeZonePtr zone, const sys_time<Duration>& st)
+{
+ return zoned_time<typename std::common_type<Duration, std::chrono::seconds>::type,
+ TimeZonePtr>(std::move(zone), st);
+}
+
+template <class Duration>
+inline
+zoned_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+make_zoned(const std::string& name, const sys_time<Duration>& st)
+{
+ return zoned_time<typename std::common_type<Duration,
+ std::chrono::seconds>::type>(name, st);
+}
+
+template <class CharT, class Traits, class Duration, class TimeZonePtr>
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt,
+ const zoned_time<Duration, TimeZonePtr>& tp)
+{
+ using duration = typename zoned_time<Duration, TimeZonePtr>::duration;
+ using LT = local_time<duration>;
+ auto const st = tp.get_sys_time();
+ auto const info = tp.get_time_zone()->get_info(st);
+ return to_stream(os, fmt, LT{(st+info.offset).time_since_epoch()},
+ &info.abbrev, &info.offset);
+}
+
+template <class CharT, class Traits, class Duration, class TimeZonePtr>
+inline
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const zoned_time<Duration, TimeZonePtr>& t)
+{
+ const CharT fmt[] = {'%', 'F', ' ', '%', 'T', ' ', '%', 'Z', CharT{}};
+ return to_stream(os, fmt, t);
+}
+
+class utc_clock
+{
+public:
+ using duration = std::chrono::system_clock::duration;
+ using rep = duration::rep;
+ using period = duration::period;
+ using time_point = std::chrono::time_point<utc_clock>;
+ static CONSTDATA bool is_steady = false;
+
+ static time_point now();
+
+ template<typename Duration>
+ static
+ std::chrono::time_point<std::chrono::system_clock, typename std::common_type<Duration, std::chrono::seconds>::type>
+ to_sys(const std::chrono::time_point<utc_clock, Duration>&);
+
+ template<typename Duration>
+ static
+ std::chrono::time_point<utc_clock, typename std::common_type<Duration, std::chrono::seconds>::type>
+ from_sys(const std::chrono::time_point<std::chrono::system_clock, Duration>&);
+
+ template<typename Duration>
+ static
+ std::chrono::time_point<local_t, typename std::common_type<Duration, std::chrono::seconds>::type>
+ to_local(const std::chrono::time_point<utc_clock, Duration>&);
+
+ template<typename Duration>
+ static
+ std::chrono::time_point<utc_clock, typename std::common_type<Duration, std::chrono::seconds>::type>
+ from_local(const std::chrono::time_point<local_t, Duration>&);
+};
+
+template <class Duration>
+ using utc_time = std::chrono::time_point<utc_clock, Duration>;
+
+using utc_seconds = utc_time<std::chrono::seconds>;
+
+template <class Duration>
+utc_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+utc_clock::from_sys(const sys_time<Duration>& st)
+{
+ using std::chrono::seconds;
+ using CD = typename std::common_type<Duration, seconds>::type;
+ auto const& leaps = get_tzdb().leap_seconds;
+ auto const lt = std::upper_bound(leaps.begin(), leaps.end(), st);
+ return utc_time<CD>{st.time_since_epoch() + seconds{lt-leaps.begin()}};
+}
+
+// Return pair<is_leap_second, seconds{number_of_leap_seconds_since_1970}>
+// first is true if ut is during a leap second insertion, otherwise false.
+// If ut is during a leap second insertion, that leap second is included in the count
+template <class Duration>
+std::pair<bool, std::chrono::seconds>
+is_leap_second(date::utc_time<Duration> const& ut)
+{
+ using std::chrono::seconds;
+ using duration = typename std::common_type<Duration, seconds>::type;
+ auto const& leaps = get_tzdb().leap_seconds;
+ auto tp = sys_time<duration>{ut.time_since_epoch()};
+ auto const lt = std::upper_bound(leaps.begin(), leaps.end(), tp);
+ auto ds = seconds{lt-leaps.begin()};
+ tp -= ds;
+ auto ls = false;
+ if (lt > leaps.begin())
+ {
+ if (tp < lt[-1])
+ {
+ if (tp >= lt[-1].date() - seconds{1})
+ ls = true;
+ else
+ --ds;
+ }
+ }
+ return {ls, ds};
+}
+
+struct leap_second_info
+{
+ bool is_leap_second;
+ std::chrono::seconds elapsed;
+};
+
+template <class Duration>
+leap_second_info
+get_leap_second_info(date::utc_time<Duration> const& ut)
+{
+ auto p = is_leap_second(ut);
+ return {p.first, p.second};
+}
+
+template <class Duration>
+sys_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+utc_clock::to_sys(const utc_time<Duration>& ut)
+{
+ using std::chrono::seconds;
+ using CD = typename std::common_type<Duration, seconds>::type;
+ auto ls = is_leap_second(ut);
+ auto tp = sys_time<CD>{ut.time_since_epoch() - ls.second};
+ if (ls.first)
+ tp = floor<seconds>(tp) + seconds{1} - CD{1};
+ return tp;
+}
+
+inline
+utc_clock::time_point
+utc_clock::now()
+{
+ return from_sys(std::chrono::system_clock::now());
+}
+
+template <class Duration>
+utc_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+utc_clock::from_local(const local_time<Duration>& st)
+{
+ return from_sys(sys_time<Duration>{st.time_since_epoch()});
+}
+
+template <class Duration>
+local_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+utc_clock::to_local(const utc_time<Duration>& ut)
+{
+ using CD = typename std::common_type<Duration, std::chrono::seconds>::type;
+ return local_time<CD>{to_sys(ut).time_since_epoch()};
+}
+
+template <class CharT, class Traits, class Duration>
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt,
+ const utc_time<Duration>& t)
+{
+ using std::chrono::seconds;
+ using CT = typename std::common_type<Duration, seconds>::type;
+ const std::string abbrev("UTC");
+ CONSTDATA seconds offset{0};
+ auto ls = is_leap_second(t);
+ auto tp = sys_time<CT>{t.time_since_epoch() - ls.second};
+ auto const sd = floor<days>(tp);
+ year_month_day ymd = sd;
+ auto time = make_time(tp - sys_seconds{sd});
+ time.seconds(detail::undocumented{}) += seconds{ls.first};
+ fields<CT> fds{ymd, time};
+ return to_stream(os, fmt, fds, &abbrev, &offset);
+}
+
+template <class CharT, class Traits, class Duration>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const utc_time<Duration>& t)
+{
+ const CharT fmt[] = {'%', 'F', ' ', '%', 'T', CharT{}};
+ return to_stream(os, fmt, t);
+}
+
+template <class Duration, class CharT, class Traits, class Alloc = std::allocator<CharT>>
+std::basic_istream<CharT, Traits>&
+from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt,
+ utc_time<Duration>& tp, std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+{
+ using std::chrono::seconds;
+ using std::chrono::minutes;
+ using CT = typename std::common_type<Duration, seconds>::type;
+ minutes offset_local{};
+ auto offptr = offset ? offset : &offset_local;
+ fields<CT> fds{};
+ fds.has_tod = true;
+ from_stream(is, fmt, fds, abbrev, offptr);
+ if (!fds.ymd.ok())
+ is.setstate(std::ios::failbit);
+ if (!is.fail())
+ {
+ bool is_60_sec = fds.tod.seconds() == seconds{60};
+ if (is_60_sec)
+ fds.tod.seconds(detail::undocumented{}) -= seconds{1};
+ auto tmp = utc_clock::from_sys(sys_days(fds.ymd) - *offptr + fds.tod.to_duration());
+ if (is_60_sec)
+ tmp += seconds{1};
+ if (is_60_sec != is_leap_second(tmp).first || !fds.tod.in_conventional_range())
+ {
+ is.setstate(std::ios::failbit);
+ return is;
+ }
+ tp = std::chrono::time_point_cast<Duration>(tmp);
+ }
+ return is;
+}
+
+// tai_clock
+
+class tai_clock
+{
+public:
+ using duration = std::chrono::system_clock::duration;
+ using rep = duration::rep;
+ using period = duration::period;
+ using time_point = std::chrono::time_point<tai_clock>;
+ static const bool is_steady = false;
+
+ static time_point now();
+
+ template<typename Duration>
+ static
+ std::chrono::time_point<utc_clock, typename std::common_type<Duration, std::chrono::seconds>::type>
+ to_utc(const std::chrono::time_point<tai_clock, Duration>&) NOEXCEPT;
+
+ template<typename Duration>
+ static
+ std::chrono::time_point<tai_clock, typename std::common_type<Duration, std::chrono::seconds>::type>
+ from_utc(const std::chrono::time_point<utc_clock, Duration>&) NOEXCEPT;
+
+ template<typename Duration>
+ static
+ std::chrono::time_point<local_t, typename std::common_type<Duration, date::days>::type>
+ to_local(const std::chrono::time_point<tai_clock, Duration>&) NOEXCEPT;
+
+ template<typename Duration>
+ static
+ std::chrono::time_point<tai_clock, typename std::common_type<Duration, date::days>::type>
+ from_local(const std::chrono::time_point<local_t, Duration>&) NOEXCEPT;
+};
+
+template <class Duration>
+ using tai_time = std::chrono::time_point<tai_clock, Duration>;
+
+using tai_seconds = tai_time<std::chrono::seconds>;
+
+template <class Duration>
+inline
+utc_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+tai_clock::to_utc(const tai_time<Duration>& t) NOEXCEPT
+{
+ using std::chrono::seconds;
+ using CD = typename std::common_type<Duration, seconds>::type;
+ return utc_time<CD>{t.time_since_epoch()} -
+ (sys_days(year{1970}/January/1) - sys_days(year{1958}/January/1) + seconds{10});
+}
+
+template <class Duration>
+inline
+tai_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+tai_clock::from_utc(const utc_time<Duration>& t) NOEXCEPT
+{
+ using std::chrono::seconds;
+ using CD = typename std::common_type<Duration, seconds>::type;
+ return tai_time<CD>{t.time_since_epoch()} +
+ (sys_days(year{1970}/January/1) - sys_days(year{1958}/January/1) + seconds{10});
+}
+
+inline
+tai_clock::time_point
+tai_clock::now()
+{
+ return from_utc(utc_clock::now());
+}
+
+template <class Duration>
+inline
+local_time<typename std::common_type<Duration, date::days>::type>
+tai_clock::to_local(const tai_time<Duration>& t) NOEXCEPT
+{
+ using CD = typename std::common_type<Duration, date::days>::type;
+ return local_time<CD>{t.time_since_epoch()} -
+ (local_days(year{1970}/January/1) - local_days(year{1958}/January/1));
+}
+
+template <class Duration>
+inline
+tai_time<typename std::common_type<Duration, date::days>::type>
+tai_clock::from_local(const local_time<Duration>& t) NOEXCEPT
+{
+ using CD = typename std::common_type<Duration, date::days>::type;
+ return tai_time<CD>{t.time_since_epoch()} +
+ (local_days(year{1970}/January/1) - local_days(year{1958}/January/1));
+}
+
+template <class CharT, class Traits, class Duration>
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt,
+ const tai_time<Duration>& t)
+{
+ const std::string abbrev("TAI");
+ CONSTDATA std::chrono::seconds offset{0};
+ return to_stream(os, fmt, tai_clock::to_local(t), &abbrev, &offset);
+}
+
+template <class CharT, class Traits, class Duration>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const tai_time<Duration>& t)
+{
+ const CharT fmt[] = {'%', 'F', ' ', '%', 'T', CharT{}};
+ return to_stream(os, fmt, t);
+}
+
+template <class Duration, class CharT, class Traits, class Alloc = std::allocator<CharT>>
+std::basic_istream<CharT, Traits>&
+from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt,
+ tai_time<Duration>& tp,
+ std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+{
+ local_time<Duration> lp;
+ from_stream(is, fmt, lp, abbrev, offset);
+ if (!is.fail())
+ tp = tai_clock::from_local(lp);
+ return is;
+}
+
+// gps_clock
+
+class gps_clock
+{
+public:
+ using duration = std::chrono::system_clock::duration;
+ using rep = duration::rep;
+ using period = duration::period;
+ using time_point = std::chrono::time_point<gps_clock>;
+ static const bool is_steady = false;
+
+ static time_point now();
+
+ template<typename Duration>
+ static
+ std::chrono::time_point<utc_clock, typename std::common_type<Duration, std::chrono::seconds>::type>
+ to_utc(const std::chrono::time_point<gps_clock, Duration>&) NOEXCEPT;
+
+ template<typename Duration>
+ static
+ std::chrono::time_point<gps_clock, typename std::common_type<Duration, std::chrono::seconds>::type>
+ from_utc(const std::chrono::time_point<utc_clock, Duration>&) NOEXCEPT;
+
+ template<typename Duration>
+ static
+ std::chrono::time_point<local_t, typename std::common_type<Duration, date::days>::type>
+ to_local(const std::chrono::time_point<gps_clock, Duration>&) NOEXCEPT;
+
+ template<typename Duration>
+ static
+ std::chrono::time_point<gps_clock, typename std::common_type<Duration, date::days>::type>
+ from_local(const std::chrono::time_point<local_t, Duration>&) NOEXCEPT;
+};
+
+template <class Duration>
+ using gps_time = std::chrono::time_point<gps_clock, Duration>;
+
+using gps_seconds = gps_time<std::chrono::seconds>;
+
+template <class Duration>
+inline
+utc_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+gps_clock::to_utc(const gps_time<Duration>& t) NOEXCEPT
+{
+ using std::chrono::seconds;
+ using CD = typename std::common_type<Duration, seconds>::type;
+ return utc_time<CD>{t.time_since_epoch()} +
+ (sys_days(year{1980}/January/Sunday[1]) - sys_days(year{1970}/January/1) +
+ seconds{9});
+}
+
+template <class Duration>
+inline
+gps_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+gps_clock::from_utc(const utc_time<Duration>& t) NOEXCEPT
+{
+ using std::chrono::seconds;
+ using CD = typename std::common_type<Duration, seconds>::type;
+ return gps_time<CD>{t.time_since_epoch()} -
+ (sys_days(year{1980}/January/Sunday[1]) - sys_days(year{1970}/January/1) +
+ seconds{9});
+}
+
+inline
+gps_clock::time_point
+gps_clock::now()
+{
+ return from_utc(utc_clock::now());
+}
+
+template <class Duration>
+inline
+local_time<typename std::common_type<Duration, date::days>::type>
+gps_clock::to_local(const gps_time<Duration>& t) NOEXCEPT
+{
+ using CD = typename std::common_type<Duration, date::days>::type;
+ return local_time<CD>{t.time_since_epoch()} +
+ (local_days(year{1980}/January/Sunday[1]) - local_days(year{1970}/January/1));
+}
+
+template <class Duration>
+inline
+gps_time<typename std::common_type<Duration, date::days>::type>
+gps_clock::from_local(const local_time<Duration>& t) NOEXCEPT
+{
+ using CD = typename std::common_type<Duration, date::days>::type;
+ return gps_time<CD>{t.time_since_epoch()} -
+ (local_days(year{1980}/January/Sunday[1]) - local_days(year{1970}/January/1));
+}
+
+
+template <class CharT, class Traits, class Duration>
+std::basic_ostream<CharT, Traits>&
+to_stream(std::basic_ostream<CharT, Traits>& os, const CharT* fmt,
+ const gps_time<Duration>& t)
+{
+ const std::string abbrev("GPS");
+ CONSTDATA std::chrono::seconds offset{0};
+ return to_stream(os, fmt, gps_clock::to_local(t), &abbrev, &offset);
+}
+
+template <class CharT, class Traits, class Duration>
+std::basic_ostream<CharT, Traits>&
+operator<<(std::basic_ostream<CharT, Traits>& os, const gps_time<Duration>& t)
+{
+ const CharT fmt[] = {'%', 'F', ' ', '%', 'T', CharT{}};
+ return to_stream(os, fmt, t);
+}
+
+template <class Duration, class CharT, class Traits, class Alloc = std::allocator<CharT>>
+std::basic_istream<CharT, Traits>&
+from_stream(std::basic_istream<CharT, Traits>& is, const CharT* fmt,
+ gps_time<Duration>& tp,
+ std::basic_string<CharT, Traits, Alloc>* abbrev = nullptr,
+ std::chrono::minutes* offset = nullptr)
+{
+ local_time<Duration> lp;
+ from_stream(is, fmt, lp, abbrev, offset);
+ if (!is.fail())
+ tp = gps_clock::from_local(lp);
+ return is;
+}
+
+// clock_time_conversion
+
+template <class DstClock, class SrcClock>
+struct clock_time_conversion
+{};
+
+template <>
+struct clock_time_conversion<std::chrono::system_clock, std::chrono::system_clock>
+{
+ template <class Duration>
+ CONSTCD14
+ sys_time<Duration>
+ operator()(const sys_time<Duration>& st) const
+ {
+ return st;
+ }
+};
+
+template <>
+struct clock_time_conversion<utc_clock, utc_clock>
+{
+ template <class Duration>
+ CONSTCD14
+ utc_time<Duration>
+ operator()(const utc_time<Duration>& ut) const
+ {
+ return ut;
+ }
+};
+
+template<>
+struct clock_time_conversion<local_t, local_t>
+{
+ template <class Duration>
+ CONSTCD14
+ local_time<Duration>
+ operator()(const local_time<Duration>& lt) const
+ {
+ return lt;
+ }
+};
+
+template <>
+struct clock_time_conversion<utc_clock, std::chrono::system_clock>
+{
+ template <class Duration>
+ utc_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+ operator()(const sys_time<Duration>& st) const
+ {
+ return utc_clock::from_sys(st);
+ }
+};
+
+template <>
+struct clock_time_conversion<std::chrono::system_clock, utc_clock>
+{
+ template <class Duration>
+ sys_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+ operator()(const utc_time<Duration>& ut) const
+ {
+ return utc_clock::to_sys(ut);
+ }
+};
+
+template<>
+struct clock_time_conversion<local_t, std::chrono::system_clock>
+{
+ template <class Duration>
+ CONSTCD14
+ local_time<Duration>
+ operator()(const sys_time<Duration>& st) const
+ {
+ return local_time<Duration>{st.time_since_epoch()};
+ }
+};
+
+template<>
+struct clock_time_conversion<std::chrono::system_clock, local_t>
+{
+ template <class Duration>
+ CONSTCD14
+ sys_time<Duration>
+ operator()(const local_time<Duration>& lt) const
+ {
+ return sys_time<Duration>{lt.time_since_epoch()};
+ }
+};
+
+template<>
+struct clock_time_conversion<utc_clock, local_t>
+{
+ template <class Duration>
+ utc_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+ operator()(const local_time<Duration>& lt) const
+ {
+ return utc_clock::from_local(lt);
+ }
+};
+
+template<>
+struct clock_time_conversion<local_t, utc_clock>
+{
+ template <class Duration>
+ local_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+ operator()(const utc_time<Duration>& ut) const
+ {
+ return utc_clock::to_local(ut);
+ }
+};
+
+template<typename Clock>
+struct clock_time_conversion<Clock, Clock>
+{
+ template <class Duration>
+ CONSTCD14
+ std::chrono::time_point<Clock, Duration>
+ operator()(const std::chrono::time_point<Clock, Duration>& tp) const
+ {
+ return tp;
+ }
+};
+
+namespace ctc_detail
+{
+
+template <class Clock, class Duration>
+ using time_point = std::chrono::time_point<Clock, Duration>;
+
+using std::declval;
+using std::chrono::system_clock;
+
+//Check if TimePoint is time for given clock,
+//if not emits hard error
+template <class Clock, class TimePoint>
+struct return_clock_time
+{
+ using clock_time_point = time_point<Clock, typename TimePoint::duration>;
+ using type = TimePoint;
+
+ static_assert(std::is_same<TimePoint, clock_time_point>::value,
+ "time point with appropariate clock shall be returned");
+};
+
+// Check if Clock has to_sys method accepting TimePoint with given duration const& and
+// returning sys_time. If so has nested type member equal to return type to_sys.
+template <class Clock, class Duration, class = void>
+struct return_to_sys
+{};
+
+template <class Clock, class Duration>
+struct return_to_sys
+ <
+ Clock, Duration,
+ decltype(Clock::to_sys(declval<time_point<Clock, Duration> const&>()), void())
+ >
+ : return_clock_time
+ <
+ system_clock,
+ decltype(Clock::to_sys(declval<time_point<Clock, Duration> const&>()))
+ >
+{};
+
+// Similiar to above
+template <class Clock, class Duration, class = void>
+struct return_from_sys
+{};
+
+template <class Clock, class Duration>
+struct return_from_sys
+ <
+ Clock, Duration,
+ decltype(Clock::from_sys(declval<time_point<system_clock, Duration> const&>()),
+ void())
+ >
+ : return_clock_time
+ <
+ Clock,
+ decltype(Clock::from_sys(declval<time_point<system_clock, Duration> const&>()))
+ >
+{};
+
+// Similiar to above
+template <class Clock, class Duration, class = void>
+struct return_to_utc
+{};
+
+template <class Clock, class Duration>
+struct return_to_utc
+ <
+ Clock, Duration,
+ decltype(Clock::to_utc(declval<time_point<Clock, Duration> const&>()), void())
+ >
+ : return_clock_time
+ <
+ utc_clock,
+ decltype(Clock::to_utc(declval<time_point<Clock, Duration> const&>()))>
+{};
+
+// Similiar to above
+template <class Clock, class Duration, class = void>
+struct return_from_utc
+{};
+
+template <class Clock, class Duration>
+struct return_from_utc
+ <
+ Clock, Duration,
+ decltype(Clock::from_utc(declval<time_point<utc_clock, Duration> const&>()),
+ void())
+ >
+ : return_clock_time
+ <
+ Clock,
+ decltype(Clock::from_utc(declval<time_point<utc_clock, Duration> const&>()))
+ >
+{};
+
+// Similiar to above
+template<typename Clock, typename Duration, typename = void>
+struct return_to_local
+{};
+
+template<typename Clock, typename Duration>
+struct return_to_local
+ <
+ Clock, Duration,
+ decltype(Clock::to_local(declval<time_point<Clock, Duration> const&>()),
+ void())
+ >
+ : return_clock_time
+ <
+ local_t,
+ decltype(Clock::to_local(declval<time_point<Clock, Duration> const&>()))
+ >
+{};
+
+// Similiar to above
+template<typename Clock, typename Duration, typename = void>
+struct return_from_local
+{};
+
+template<typename Clock, typename Duration>
+struct return_from_local
+ <
+ Clock, Duration,
+ decltype(Clock::from_local(declval<time_point<local_t, Duration> const&>()),
+ void())
+ >
+ : return_clock_time
+ <
+ Clock,
+ decltype(Clock::from_local(declval<time_point<local_t, Duration> const&>()))
+ >
+{};
+
+} // namespace ctc_detail
+
+template <class SrcClock>
+struct clock_time_conversion<std::chrono::system_clock, SrcClock>
+{
+ template <class Duration>
+ CONSTCD14
+ typename ctc_detail::return_to_sys<SrcClock, Duration>::type
+ operator()(const std::chrono::time_point<SrcClock, Duration>& tp) const
+ {
+ return SrcClock::to_sys(tp);
+ }
+};
+
+template <class DstClock>
+struct clock_time_conversion<DstClock, std::chrono::system_clock>
+{
+ template <class Duration>
+ CONSTCD14
+ typename ctc_detail::return_from_sys<DstClock, Duration>::type
+ operator()(const sys_time<Duration>& st) const
+ {
+ return DstClock::from_sys(st);
+ }
+};
+
+template <class SrcClock>
+struct clock_time_conversion<utc_clock, SrcClock>
+{
+ template <class Duration>
+ CONSTCD14
+ typename ctc_detail::return_to_utc<SrcClock, Duration>::type
+ operator()(const std::chrono::time_point<SrcClock, Duration>& tp) const
+ {
+ return SrcClock::to_utc(tp);
+ }
+};
+
+template <class DstClock>
+struct clock_time_conversion<DstClock, utc_clock>
+{
+ template <class Duration>
+ CONSTCD14
+ typename ctc_detail::return_from_utc<DstClock, Duration>::type
+ operator()(const utc_time<Duration>& ut) const
+ {
+ return DstClock::from_utc(ut);
+ }
+};
+
+template<typename SrcClock>
+struct clock_time_conversion<local_t, SrcClock>
+{
+ template <class Duration>
+ CONSTCD14
+ typename ctc_detail::return_to_local<SrcClock, Duration>::type
+ operator()(const std::chrono::time_point<SrcClock, Duration>& tp) const
+ {
+ return SrcClock::to_local(tp);
+ }
+};
+
+template<typename DstClock>
+struct clock_time_conversion<DstClock, local_t>
+{
+ template <class Duration>
+ CONSTCD14
+ typename ctc_detail::return_from_local<DstClock, Duration>::type
+ operator()(const local_time<Duration>& lt) const
+ {
+ return DstClock::from_local(lt);
+ }
+};
+
+namespace clock_cast_detail
+{
+
+template <class Clock, class Duration>
+ using time_point = std::chrono::time_point<Clock, Duration>;
+using std::chrono::system_clock;
+
+template <class DstClock, class SrcClock, class Duration>
+CONSTCD14
+auto
+conv_clock(const time_point<SrcClock, Duration>& t)
+ -> decltype(std::declval<clock_time_conversion<DstClock, SrcClock>>()(t))
+{
+ return clock_time_conversion<DstClock, SrcClock>{}(t);
+}
+
+//direct trait conversion, 1st candidate
+template <class DstClock, class SrcClock, class Duration>
+CONSTCD14
+auto
+cc_impl(const time_point<SrcClock, Duration>& t, const time_point<SrcClock, Duration>*)
+ -> decltype(conv_clock<DstClock>(t))
+{
+ return conv_clock<DstClock>(t);
+}
+
+//conversion through sys, 2nd candidate
+template <class DstClock, class SrcClock, class Duration>
+CONSTCD14
+auto
+cc_impl(const time_point<SrcClock, Duration>& t, const void*)
+ -> decltype(conv_clock<DstClock>(conv_clock<system_clock>(t)))
+{
+ return conv_clock<DstClock>(conv_clock<system_clock>(t));
+}
+
+//conversion through utc, 2nd candidate
+template <class DstClock, class SrcClock, class Duration>
+CONSTCD14
+auto
+cc_impl(const time_point<SrcClock, Duration>& t, const void*)
+ -> decltype(0, // MSVC_WORKAROUND
+ conv_clock<DstClock>(conv_clock<utc_clock>(t)))
+{
+ return conv_clock<DstClock>(conv_clock<utc_clock>(t));
+}
+
+//conversion through sys and utc, 3rd candidate
+template <class DstClock, class SrcClock, class Duration>
+CONSTCD14
+auto
+cc_impl(const time_point<SrcClock, Duration>& t, ...)
+ -> decltype(conv_clock<DstClock>(conv_clock<utc_clock>(conv_clock<system_clock>(t))))
+{
+ return conv_clock<DstClock>(conv_clock<utc_clock>(conv_clock<system_clock>(t)));
+}
+
+//conversion through utc and sys, 3rd candidate
+template <class DstClock, class SrcClock, class Duration>
+CONSTCD14
+auto
+cc_impl(const time_point<SrcClock, Duration>& t, ...)
+ -> decltype(0, // MSVC_WORKAROUND
+ conv_clock<DstClock>(conv_clock<system_clock>(conv_clock<utc_clock>(t))))
+{
+ return conv_clock<DstClock>(conv_clock<system_clock>(conv_clock<utc_clock>(t)));
+}
+
+} // namespace clock_cast_detail
+
+template <class DstClock, class SrcClock, class Duration>
+CONSTCD14
+auto
+clock_cast(const std::chrono::time_point<SrcClock, Duration>& tp)
+ -> decltype(clock_cast_detail::cc_impl<DstClock>(tp, &tp))
+{
+ return clock_cast_detail::cc_impl<DstClock>(tp, &tp);
+}
+
+// Deprecated API
+
+template <class Duration>
+inline
+sys_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+to_sys_time(const utc_time<Duration>& t)
+{
+ return utc_clock::to_sys(t);
+}
+
+template <class Duration>
+inline
+sys_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+to_sys_time(const tai_time<Duration>& t)
+{
+ return utc_clock::to_sys(tai_clock::to_utc(t));
+}
+
+template <class Duration>
+inline
+sys_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+to_sys_time(const gps_time<Duration>& t)
+{
+ return utc_clock::to_sys(gps_clock::to_utc(t));
+}
+
+
+template <class Duration>
+inline
+utc_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+to_utc_time(const sys_time<Duration>& t)
+{
+ return utc_clock::from_sys(t);
+}
+
+template <class Duration>
+inline
+utc_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+to_utc_time(const tai_time<Duration>& t)
+{
+ return tai_clock::to_utc(t);
+}
+
+template <class Duration>
+inline
+utc_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+to_utc_time(const gps_time<Duration>& t)
+{
+ return gps_clock::to_utc(t);
+}
+
+
+template <class Duration>
+inline
+tai_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+to_tai_time(const sys_time<Duration>& t)
+{
+ return tai_clock::from_utc(utc_clock::from_sys(t));
+}
+
+template <class Duration>
+inline
+tai_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+to_tai_time(const utc_time<Duration>& t)
+{
+ return tai_clock::from_utc(t);
+}
+
+template <class Duration>
+inline
+tai_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+to_tai_time(const gps_time<Duration>& t)
+{
+ return tai_clock::from_utc(gps_clock::to_utc(t));
+}
+
+
+template <class Duration>
+inline
+gps_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+to_gps_time(const sys_time<Duration>& t)
+{
+ return gps_clock::from_utc(utc_clock::from_sys(t));
+}
+
+template <class Duration>
+inline
+gps_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+to_gps_time(const utc_time<Duration>& t)
+{
+ return gps_clock::from_utc(t);
+}
+
+template <class Duration>
+inline
+gps_time<typename std::common_type<Duration, std::chrono::seconds>::type>
+to_gps_time(const tai_time<Duration>& t)
+{
+ return gps_clock::from_utc(tai_clock::to_utc(t));
+}
+
+} // namespace date
+} // namespace arrow_vendored
+
+#endif // TZ_H
diff --git a/src/arrow/cpp/src/arrow/vendored/datetime/tz_private.h b/src/arrow/cpp/src/arrow/vendored/datetime/tz_private.h
new file mode 100644
index 000000000..6b7a91493
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/datetime/tz_private.h
@@ -0,0 +1,319 @@
+#ifndef TZ_PRIVATE_H
+#define TZ_PRIVATE_H
+
+// The MIT License (MIT)
+//
+// Copyright (c) 2015, 2016 Howard Hinnant
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+//
+// Our apologies. When the previous paragraph was written, lowercase had not yet
+// been invented (that would involve another several millennia of evolution).
+// We did not mean to shout.
+
+#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
+#include "tz.h"
+#else
+#include "date.h"
+#include <vector>
+#endif
+
+namespace arrow_vendored
+{
+namespace date
+{
+
+namespace detail
+{
+
+#if !USE_OS_TZDB
+
+enum class tz {utc, local, standard};
+
+//forward declare to avoid warnings in gcc 6.2
+class MonthDayTime;
+std::istream& operator>>(std::istream& is, MonthDayTime& x);
+std::ostream& operator<<(std::ostream& os, const MonthDayTime& x);
+
+
+class MonthDayTime
+{
+private:
+ struct pair
+ {
+#if defined(_MSC_VER) && (_MSC_VER < 1900)
+ pair() : month_day_(date::jan / 1), weekday_(0U) {}
+
+ pair(const date::month_day& month_day, const date::weekday& weekday)
+ : month_day_(month_day), weekday_(weekday) {}
+#endif
+
+ date::month_day month_day_;
+ date::weekday weekday_;
+ };
+
+ enum Type {month_day, month_last_dow, lteq, gteq};
+
+ Type type_{month_day};
+
+#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
+ union U
+#else
+ struct U
+#endif
+ {
+ date::month_day month_day_;
+ date::month_weekday_last month_weekday_last_;
+ pair month_day_weekday_;
+
+#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
+ U() : month_day_{date::jan/1} {}
+#else
+ U() :
+ month_day_(date::jan/1),
+ month_weekday_last_(date::month(0U), date::weekday_last(date::weekday(0U)))
+ {}
+
+#endif // !defined(_MSC_VER) || (_MSC_VER >= 1900)
+
+ U& operator=(const date::month_day& x);
+ U& operator=(const date::month_weekday_last& x);
+ U& operator=(const pair& x);
+ } u;
+
+ std::chrono::hours h_{0};
+ std::chrono::minutes m_{0};
+ std::chrono::seconds s_{0};
+ tz zone_{tz::local};
+
+public:
+ MonthDayTime() = default;
+ MonthDayTime(local_seconds tp, tz timezone);
+ MonthDayTime(const date::month_day& md, tz timezone);
+
+ date::day day() const;
+ date::month month() const;
+ tz zone() const {return zone_;}
+
+ void canonicalize(date::year y);
+
+ sys_seconds
+ to_sys(date::year y, std::chrono::seconds offset, std::chrono::seconds save) const;
+ sys_days to_sys_days(date::year y) const;
+
+ sys_seconds to_time_point(date::year y) const;
+ int compare(date::year y, const MonthDayTime& x, date::year yx,
+ std::chrono::seconds offset, std::chrono::minutes prev_save) const;
+
+ friend std::istream& operator>>(std::istream& is, MonthDayTime& x);
+ friend std::ostream& operator<<(std::ostream& os, const MonthDayTime& x);
+};
+
+// A Rule specifies one or more set of datetimes without using an offset.
+// Multiple dates are specified with multiple years. The years in effect
+// go from starting_year_ to ending_year_, inclusive. starting_year_ <=
+// ending_year_. save_ is in effect for times from the specified time
+// onward, including the specified time. When the specified time is
+// local, it uses the save_ from the chronologically previous Rule, or if
+// there is none, 0.
+
+//forward declare to avoid warnings in gcc 6.2
+class Rule;
+bool operator==(const Rule& x, const Rule& y);
+bool operator<(const Rule& x, const Rule& y);
+bool operator==(const Rule& x, const date::year& y);
+bool operator<(const Rule& x, const date::year& y);
+bool operator==(const date::year& x, const Rule& y);
+bool operator<(const date::year& x, const Rule& y);
+bool operator==(const Rule& x, const std::string& y);
+bool operator<(const Rule& x, const std::string& y);
+bool operator==(const std::string& x, const Rule& y);
+bool operator<(const std::string& x, const Rule& y);
+std::ostream& operator<<(std::ostream& os, const Rule& r);
+
+class Rule
+{
+private:
+ std::string name_;
+ date::year starting_year_{0};
+ date::year ending_year_{0};
+ MonthDayTime starting_at_;
+ std::chrono::minutes save_{0};
+ std::string abbrev_;
+
+public:
+ Rule() = default;
+ explicit Rule(const std::string& s);
+ Rule(const Rule& r, date::year starting_year, date::year ending_year);
+
+ const std::string& name() const {return name_;}
+ const std::string& abbrev() const {return abbrev_;}
+
+ const MonthDayTime& mdt() const {return starting_at_;}
+ const date::year& starting_year() const {return starting_year_;}
+ const date::year& ending_year() const {return ending_year_;}
+ const std::chrono::minutes& save() const {return save_;}
+
+ static void split_overlaps(std::vector<Rule>& rules);
+
+ friend bool operator==(const Rule& x, const Rule& y);
+ friend bool operator<(const Rule& x, const Rule& y);
+ friend bool operator==(const Rule& x, const date::year& y);
+ friend bool operator<(const Rule& x, const date::year& y);
+ friend bool operator==(const date::year& x, const Rule& y);
+ friend bool operator<(const date::year& x, const Rule& y);
+ friend bool operator==(const Rule& x, const std::string& y);
+ friend bool operator<(const Rule& x, const std::string& y);
+ friend bool operator==(const std::string& x, const Rule& y);
+ friend bool operator<(const std::string& x, const Rule& y);
+
+ friend std::ostream& operator<<(std::ostream& os, const Rule& r);
+
+private:
+ date::day day() const;
+ date::month month() const;
+ static void split_overlaps(std::vector<Rule>& rules, std::size_t i, std::size_t& e);
+ static bool overlaps(const Rule& x, const Rule& y);
+ static void split(std::vector<Rule>& rules, std::size_t i, std::size_t k,
+ std::size_t& e);
+};
+
+inline bool operator!=(const Rule& x, const Rule& y) {return !(x == y);}
+inline bool operator> (const Rule& x, const Rule& y) {return y < x;}
+inline bool operator<=(const Rule& x, const Rule& y) {return !(y < x);}
+inline bool operator>=(const Rule& x, const Rule& y) {return !(x < y);}
+
+inline bool operator!=(const Rule& x, const date::year& y) {return !(x == y);}
+inline bool operator> (const Rule& x, const date::year& y) {return y < x;}
+inline bool operator<=(const Rule& x, const date::year& y) {return !(y < x);}
+inline bool operator>=(const Rule& x, const date::year& y) {return !(x < y);}
+
+inline bool operator!=(const date::year& x, const Rule& y) {return !(x == y);}
+inline bool operator> (const date::year& x, const Rule& y) {return y < x;}
+inline bool operator<=(const date::year& x, const Rule& y) {return !(y < x);}
+inline bool operator>=(const date::year& x, const Rule& y) {return !(x < y);}
+
+inline bool operator!=(const Rule& x, const std::string& y) {return !(x == y);}
+inline bool operator> (const Rule& x, const std::string& y) {return y < x;}
+inline bool operator<=(const Rule& x, const std::string& y) {return !(y < x);}
+inline bool operator>=(const Rule& x, const std::string& y) {return !(x < y);}
+
+inline bool operator!=(const std::string& x, const Rule& y) {return !(x == y);}
+inline bool operator> (const std::string& x, const Rule& y) {return y < x;}
+inline bool operator<=(const std::string& x, const Rule& y) {return !(y < x);}
+inline bool operator>=(const std::string& x, const Rule& y) {return !(x < y);}
+
+struct zonelet
+{
+ enum tag {has_rule, has_save, is_empty};
+
+ std::chrono::seconds gmtoff_;
+ tag tag_ = has_rule;
+
+#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
+ union U
+#else
+ struct U
+#endif
+ {
+ std::string rule_;
+ std::chrono::minutes save_;
+
+ ~U() {}
+ U() {}
+ U(const U&) {}
+ U& operator=(const U&) = delete;
+ } u;
+
+ std::string format_;
+ date::year until_year_{0};
+ MonthDayTime until_date_;
+ sys_seconds until_utc_;
+ local_seconds until_std_;
+ local_seconds until_loc_;
+ std::chrono::minutes initial_save_{0};
+ std::string initial_abbrev_;
+ std::pair<const Rule*, date::year> first_rule_{nullptr, date::year::min()};
+ std::pair<const Rule*, date::year> last_rule_{nullptr, date::year::max()};
+
+ ~zonelet();
+ zonelet();
+ zonelet(const zonelet& i);
+ zonelet& operator=(const zonelet&) = delete;
+};
+
+#else // USE_OS_TZDB
+
+struct ttinfo
+{
+ std::int32_t tt_gmtoff;
+ unsigned char tt_isdst;
+ unsigned char tt_abbrind;
+ unsigned char pad[2];
+};
+
+static_assert(sizeof(ttinfo) == 8, "");
+
+struct expanded_ttinfo
+{
+ std::chrono::seconds offset;
+ std::string abbrev;
+ bool is_dst;
+};
+
+struct transition
+{
+ sys_seconds timepoint;
+ const expanded_ttinfo* info;
+
+ transition(sys_seconds tp, const expanded_ttinfo* i = nullptr)
+ : timepoint(tp)
+ , info(i)
+ {}
+
+ friend
+ std::ostream&
+ operator<<(std::ostream& os, const transition& t)
+ {
+ using date::operator<<;
+ os << t.timepoint << "Z ";
+ if (t.info->offset >= std::chrono::seconds{0})
+ os << '+';
+ os << make_time(t.info->offset);
+ if (t.info->is_dst > 0)
+ os << " daylight ";
+ else
+ os << " standard ";
+ os << t.info->abbrev;
+ return os;
+ }
+};
+
+#endif // USE_OS_TZDB
+
+} // namespace detail
+
+} // namespace date
+} // namespace arrow_vendored
+
+#if defined(_MSC_VER) && (_MSC_VER < 1900)
+#include "tz.h"
+#endif
+
+#endif // TZ_PRIVATE_H
diff --git a/src/arrow/cpp/src/arrow/vendored/datetime/visibility.h b/src/arrow/cpp/src/arrow/vendored/datetime/visibility.h
new file mode 100644
index 000000000..ae031238d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/datetime/visibility.h
@@ -0,0 +1,26 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#if defined(ARROW_STATIC)
+// intentially empty
+#elif defined(ARROW_EXPORTING)
+#define DATE_BUILD_DLL
+#else
+#define DATE_USE_DLL
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/.gitignore b/src/arrow/cpp/src/arrow/vendored/double-conversion/.gitignore
new file mode 100644
index 000000000..1edeb79fd
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/.gitignore
@@ -0,0 +1 @@
+*.os
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/CMakeLists.txt b/src/arrow/cpp/src/arrow/vendored/double-conversion/CMakeLists.txt
new file mode 100644
index 000000000..6de8801ee
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/CMakeLists.txt
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+arrow_install_all_headers("arrow/vendored/double-conversion")
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/README.md b/src/arrow/cpp/src/arrow/vendored/double-conversion/README.md
new file mode 100644
index 000000000..af03f7322
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/README.md
@@ -0,0 +1,20 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+The files in this directory are vendored from double-conversion git tag v3.1.5.
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/bignum-dtoa.cc b/src/arrow/cpp/src/arrow/vendored/double-conversion/bignum-dtoa.cc
new file mode 100644
index 000000000..d99ac2aaf
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/bignum-dtoa.cc
@@ -0,0 +1,641 @@
+// Copyright 2010 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <cmath>
+
+#include "bignum-dtoa.h"
+
+#include "bignum.h"
+#include "ieee.h"
+
+namespace double_conversion {
+
+static int NormalizedExponent(uint64_t significand, int exponent) {
+ ASSERT(significand != 0);
+ while ((significand & Double::kHiddenBit) == 0) {
+ significand = significand << 1;
+ exponent = exponent - 1;
+ }
+ return exponent;
+}
+
+
+// Forward declarations:
+// Returns an estimation of k such that 10^(k-1) <= v < 10^k.
+static int EstimatePower(int exponent);
+// Computes v / 10^estimated_power exactly, as a ratio of two bignums, numerator
+// and denominator.
+static void InitialScaledStartValues(uint64_t significand,
+ int exponent,
+ bool lower_boundary_is_closer,
+ int estimated_power,
+ bool need_boundary_deltas,
+ Bignum* numerator,
+ Bignum* denominator,
+ Bignum* delta_minus,
+ Bignum* delta_plus);
+// Multiplies numerator/denominator so that its values lies in the range 1-10.
+// Returns decimal_point s.t.
+// v = numerator'/denominator' * 10^(decimal_point-1)
+// where numerator' and denominator' are the values of numerator and
+// denominator after the call to this function.
+static void FixupMultiply10(int estimated_power, bool is_even,
+ int* decimal_point,
+ Bignum* numerator, Bignum* denominator,
+ Bignum* delta_minus, Bignum* delta_plus);
+// Generates digits from the left to the right and stops when the generated
+// digits yield the shortest decimal representation of v.
+static void GenerateShortestDigits(Bignum* numerator, Bignum* denominator,
+ Bignum* delta_minus, Bignum* delta_plus,
+ bool is_even,
+ Vector<char> buffer, int* length);
+// Generates 'requested_digits' after the decimal point.
+static void BignumToFixed(int requested_digits, int* decimal_point,
+ Bignum* numerator, Bignum* denominator,
+ Vector<char>(buffer), int* length);
+// Generates 'count' digits of numerator/denominator.
+// Once 'count' digits have been produced rounds the result depending on the
+// remainder (remainders of exactly .5 round upwards). Might update the
+// decimal_point when rounding up (for example for 0.9999).
+static void GenerateCountedDigits(int count, int* decimal_point,
+ Bignum* numerator, Bignum* denominator,
+ Vector<char>(buffer), int* length);
+
+
+void BignumDtoa(double v, BignumDtoaMode mode, int requested_digits,
+ Vector<char> buffer, int* length, int* decimal_point) {
+ ASSERT(v > 0);
+ ASSERT(!Double(v).IsSpecial());
+ uint64_t significand;
+ int exponent;
+ bool lower_boundary_is_closer;
+ if (mode == BIGNUM_DTOA_SHORTEST_SINGLE) {
+ float f = static_cast<float>(v);
+ ASSERT(f == v);
+ significand = Single(f).Significand();
+ exponent = Single(f).Exponent();
+ lower_boundary_is_closer = Single(f).LowerBoundaryIsCloser();
+ } else {
+ significand = Double(v).Significand();
+ exponent = Double(v).Exponent();
+ lower_boundary_is_closer = Double(v).LowerBoundaryIsCloser();
+ }
+ bool need_boundary_deltas =
+ (mode == BIGNUM_DTOA_SHORTEST || mode == BIGNUM_DTOA_SHORTEST_SINGLE);
+
+ bool is_even = (significand & 1) == 0;
+ int normalized_exponent = NormalizedExponent(significand, exponent);
+ // estimated_power might be too low by 1.
+ int estimated_power = EstimatePower(normalized_exponent);
+
+ // Shortcut for Fixed.
+ // The requested digits correspond to the digits after the point. If the
+ // number is much too small, then there is no need in trying to get any
+ // digits.
+ if (mode == BIGNUM_DTOA_FIXED && -estimated_power - 1 > requested_digits) {
+ buffer[0] = '\0';
+ *length = 0;
+ // Set decimal-point to -requested_digits. This is what Gay does.
+ // Note that it should not have any effect anyways since the string is
+ // empty.
+ *decimal_point = -requested_digits;
+ return;
+ }
+
+ Bignum numerator;
+ Bignum denominator;
+ Bignum delta_minus;
+ Bignum delta_plus;
+ // Make sure the bignum can grow large enough. The smallest double equals
+ // 4e-324. In this case the denominator needs fewer than 324*4 binary digits.
+ // The maximum double is 1.7976931348623157e308 which needs fewer than
+ // 308*4 binary digits.
+ ASSERT(Bignum::kMaxSignificantBits >= 324*4);
+ InitialScaledStartValues(significand, exponent, lower_boundary_is_closer,
+ estimated_power, need_boundary_deltas,
+ &numerator, &denominator,
+ &delta_minus, &delta_plus);
+ // We now have v = (numerator / denominator) * 10^estimated_power.
+ FixupMultiply10(estimated_power, is_even, decimal_point,
+ &numerator, &denominator,
+ &delta_minus, &delta_plus);
+ // We now have v = (numerator / denominator) * 10^(decimal_point-1), and
+ // 1 <= (numerator + delta_plus) / denominator < 10
+ switch (mode) {
+ case BIGNUM_DTOA_SHORTEST:
+ case BIGNUM_DTOA_SHORTEST_SINGLE:
+ GenerateShortestDigits(&numerator, &denominator,
+ &delta_minus, &delta_plus,
+ is_even, buffer, length);
+ break;
+ case BIGNUM_DTOA_FIXED:
+ BignumToFixed(requested_digits, decimal_point,
+ &numerator, &denominator,
+ buffer, length);
+ break;
+ case BIGNUM_DTOA_PRECISION:
+ GenerateCountedDigits(requested_digits, decimal_point,
+ &numerator, &denominator,
+ buffer, length);
+ break;
+ default:
+ UNREACHABLE();
+ }
+ buffer[*length] = '\0';
+}
+
+
+// The procedure starts generating digits from the left to the right and stops
+// when the generated digits yield the shortest decimal representation of v. A
+// decimal representation of v is a number lying closer to v than to any other
+// double, so it converts to v when read.
+//
+// This is true if d, the decimal representation, is between m- and m+, the
+// upper and lower boundaries. d must be strictly between them if !is_even.
+// m- := (numerator - delta_minus) / denominator
+// m+ := (numerator + delta_plus) / denominator
+//
+// Precondition: 0 <= (numerator+delta_plus) / denominator < 10.
+// If 1 <= (numerator+delta_plus) / denominator < 10 then no leading 0 digit
+// will be produced. This should be the standard precondition.
+static void GenerateShortestDigits(Bignum* numerator, Bignum* denominator,
+ Bignum* delta_minus, Bignum* delta_plus,
+ bool is_even,
+ Vector<char> buffer, int* length) {
+ // Small optimization: if delta_minus and delta_plus are the same just reuse
+ // one of the two bignums.
+ if (Bignum::Equal(*delta_minus, *delta_plus)) {
+ delta_plus = delta_minus;
+ }
+ *length = 0;
+ for (;;) {
+ uint16_t digit;
+ digit = numerator->DivideModuloIntBignum(*denominator);
+ ASSERT(digit <= 9); // digit is a uint16_t and therefore always positive.
+ // digit = numerator / denominator (integer division).
+ // numerator = numerator % denominator.
+ buffer[(*length)++] = static_cast<char>(digit + '0');
+
+ // Can we stop already?
+ // If the remainder of the division is less than the distance to the lower
+ // boundary we can stop. In this case we simply round down (discarding the
+ // remainder).
+ // Similarly we test if we can round up (using the upper boundary).
+ bool in_delta_room_minus;
+ bool in_delta_room_plus;
+ if (is_even) {
+ in_delta_room_minus = Bignum::LessEqual(*numerator, *delta_minus);
+ } else {
+ in_delta_room_minus = Bignum::Less(*numerator, *delta_minus);
+ }
+ if (is_even) {
+ in_delta_room_plus =
+ Bignum::PlusCompare(*numerator, *delta_plus, *denominator) >= 0;
+ } else {
+ in_delta_room_plus =
+ Bignum::PlusCompare(*numerator, *delta_plus, *denominator) > 0;
+ }
+ if (!in_delta_room_minus && !in_delta_room_plus) {
+ // Prepare for next iteration.
+ numerator->Times10();
+ delta_minus->Times10();
+ // We optimized delta_plus to be equal to delta_minus (if they share the
+ // same value). So don't multiply delta_plus if they point to the same
+ // object.
+ if (delta_minus != delta_plus) {
+ delta_plus->Times10();
+ }
+ } else if (in_delta_room_minus && in_delta_room_plus) {
+ // Let's see if 2*numerator < denominator.
+ // If yes, then the next digit would be < 5 and we can round down.
+ int compare = Bignum::PlusCompare(*numerator, *numerator, *denominator);
+ if (compare < 0) {
+ // Remaining digits are less than .5. -> Round down (== do nothing).
+ } else if (compare > 0) {
+ // Remaining digits are more than .5 of denominator. -> Round up.
+ // Note that the last digit could not be a '9' as otherwise the whole
+ // loop would have stopped earlier.
+ // We still have an assert here in case the preconditions were not
+ // satisfied.
+ ASSERT(buffer[(*length) - 1] != '9');
+ buffer[(*length) - 1]++;
+ } else {
+ // Halfway case.
+ // TODO(floitsch): need a way to solve half-way cases.
+ // For now let's round towards even (since this is what Gay seems to
+ // do).
+
+ if ((buffer[(*length) - 1] - '0') % 2 == 0) {
+ // Round down => Do nothing.
+ } else {
+ ASSERT(buffer[(*length) - 1] != '9');
+ buffer[(*length) - 1]++;
+ }
+ }
+ return;
+ } else if (in_delta_room_minus) {
+ // Round down (== do nothing).
+ return;
+ } else { // in_delta_room_plus
+ // Round up.
+ // Note again that the last digit could not be '9' since this would have
+ // stopped the loop earlier.
+ // We still have an ASSERT here, in case the preconditions were not
+ // satisfied.
+ ASSERT(buffer[(*length) -1] != '9');
+ buffer[(*length) - 1]++;
+ return;
+ }
+ }
+}
+
+
+// Let v = numerator / denominator < 10.
+// Then we generate 'count' digits of d = x.xxxxx... (without the decimal point)
+// from left to right. Once 'count' digits have been produced we decide wether
+// to round up or down. Remainders of exactly .5 round upwards. Numbers such
+// as 9.999999 propagate a carry all the way, and change the
+// exponent (decimal_point), when rounding upwards.
+static void GenerateCountedDigits(int count, int* decimal_point,
+ Bignum* numerator, Bignum* denominator,
+ Vector<char> buffer, int* length) {
+ ASSERT(count >= 0);
+ for (int i = 0; i < count - 1; ++i) {
+ uint16_t digit;
+ digit = numerator->DivideModuloIntBignum(*denominator);
+ ASSERT(digit <= 9); // digit is a uint16_t and therefore always positive.
+ // digit = numerator / denominator (integer division).
+ // numerator = numerator % denominator.
+ buffer[i] = static_cast<char>(digit + '0');
+ // Prepare for next iteration.
+ numerator->Times10();
+ }
+ // Generate the last digit.
+ uint16_t digit;
+ digit = numerator->DivideModuloIntBignum(*denominator);
+ if (Bignum::PlusCompare(*numerator, *numerator, *denominator) >= 0) {
+ digit++;
+ }
+ ASSERT(digit <= 10);
+ buffer[count - 1] = static_cast<char>(digit + '0');
+ // Correct bad digits (in case we had a sequence of '9's). Propagate the
+ // carry until we hat a non-'9' or til we reach the first digit.
+ for (int i = count - 1; i > 0; --i) {
+ if (buffer[i] != '0' + 10) break;
+ buffer[i] = '0';
+ buffer[i - 1]++;
+ }
+ if (buffer[0] == '0' + 10) {
+ // Propagate a carry past the top place.
+ buffer[0] = '1';
+ (*decimal_point)++;
+ }
+ *length = count;
+}
+
+
+// Generates 'requested_digits' after the decimal point. It might omit
+// trailing '0's. If the input number is too small then no digits at all are
+// generated (ex.: 2 fixed digits for 0.00001).
+//
+// Input verifies: 1 <= (numerator + delta) / denominator < 10.
+static void BignumToFixed(int requested_digits, int* decimal_point,
+ Bignum* numerator, Bignum* denominator,
+ Vector<char>(buffer), int* length) {
+ // Note that we have to look at more than just the requested_digits, since
+ // a number could be rounded up. Example: v=0.5 with requested_digits=0.
+ // Even though the power of v equals 0 we can't just stop here.
+ if (-(*decimal_point) > requested_digits) {
+ // The number is definitively too small.
+ // Ex: 0.001 with requested_digits == 1.
+ // Set decimal-point to -requested_digits. This is what Gay does.
+ // Note that it should not have any effect anyways since the string is
+ // empty.
+ *decimal_point = -requested_digits;
+ *length = 0;
+ return;
+ } else if (-(*decimal_point) == requested_digits) {
+ // We only need to verify if the number rounds down or up.
+ // Ex: 0.04 and 0.06 with requested_digits == 1.
+ ASSERT(*decimal_point == -requested_digits);
+ // Initially the fraction lies in range (1, 10]. Multiply the denominator
+ // by 10 so that we can compare more easily.
+ denominator->Times10();
+ if (Bignum::PlusCompare(*numerator, *numerator, *denominator) >= 0) {
+ // If the fraction is >= 0.5 then we have to include the rounded
+ // digit.
+ buffer[0] = '1';
+ *length = 1;
+ (*decimal_point)++;
+ } else {
+ // Note that we caught most of similar cases earlier.
+ *length = 0;
+ }
+ return;
+ } else {
+ // The requested digits correspond to the digits after the point.
+ // The variable 'needed_digits' includes the digits before the point.
+ int needed_digits = (*decimal_point) + requested_digits;
+ GenerateCountedDigits(needed_digits, decimal_point,
+ numerator, denominator,
+ buffer, length);
+ }
+}
+
+
+// Returns an estimation of k such that 10^(k-1) <= v < 10^k where
+// v = f * 2^exponent and 2^52 <= f < 2^53.
+// v is hence a normalized double with the given exponent. The output is an
+// approximation for the exponent of the decimal approimation .digits * 10^k.
+//
+// The result might undershoot by 1 in which case 10^k <= v < 10^k+1.
+// Note: this property holds for v's upper boundary m+ too.
+// 10^k <= m+ < 10^k+1.
+// (see explanation below).
+//
+// Examples:
+// EstimatePower(0) => 16
+// EstimatePower(-52) => 0
+//
+// Note: e >= 0 => EstimatedPower(e) > 0. No similar claim can be made for e<0.
+static int EstimatePower(int exponent) {
+ // This function estimates log10 of v where v = f*2^e (with e == exponent).
+ // Note that 10^floor(log10(v)) <= v, but v <= 10^ceil(log10(v)).
+ // Note that f is bounded by its container size. Let p = 53 (the double's
+ // significand size). Then 2^(p-1) <= f < 2^p.
+ //
+ // Given that log10(v) == log2(v)/log2(10) and e+(len(f)-1) is quite close
+ // to log2(v) the function is simplified to (e+(len(f)-1)/log2(10)).
+ // The computed number undershoots by less than 0.631 (when we compute log3
+ // and not log10).
+ //
+ // Optimization: since we only need an approximated result this computation
+ // can be performed on 64 bit integers. On x86/x64 architecture the speedup is
+ // not really measurable, though.
+ //
+ // Since we want to avoid overshooting we decrement by 1e10 so that
+ // floating-point imprecisions don't affect us.
+ //
+ // Explanation for v's boundary m+: the computation takes advantage of
+ // the fact that 2^(p-1) <= f < 2^p. Boundaries still satisfy this requirement
+ // (even for denormals where the delta can be much more important).
+
+ const double k1Log10 = 0.30102999566398114; // 1/lg(10)
+
+ // For doubles len(f) == 53 (don't forget the hidden bit).
+ const int kSignificandSize = Double::kSignificandSize;
+ double estimate = ceil((exponent + kSignificandSize - 1) * k1Log10 - 1e-10);
+ return static_cast<int>(estimate);
+}
+
+
+// See comments for InitialScaledStartValues.
+static void InitialScaledStartValuesPositiveExponent(
+ uint64_t significand, int exponent,
+ int estimated_power, bool need_boundary_deltas,
+ Bignum* numerator, Bignum* denominator,
+ Bignum* delta_minus, Bignum* delta_plus) {
+ // A positive exponent implies a positive power.
+ ASSERT(estimated_power >= 0);
+ // Since the estimated_power is positive we simply multiply the denominator
+ // by 10^estimated_power.
+
+ // numerator = v.
+ numerator->AssignUInt64(significand);
+ numerator->ShiftLeft(exponent);
+ // denominator = 10^estimated_power.
+ denominator->AssignPowerUInt16(10, estimated_power);
+
+ if (need_boundary_deltas) {
+ // Introduce a common denominator so that the deltas to the boundaries are
+ // integers.
+ denominator->ShiftLeft(1);
+ numerator->ShiftLeft(1);
+ // Let v = f * 2^e, then m+ - v = 1/2 * 2^e; With the common
+ // denominator (of 2) delta_plus equals 2^e.
+ delta_plus->AssignUInt16(1);
+ delta_plus->ShiftLeft(exponent);
+ // Same for delta_minus. The adjustments if f == 2^p-1 are done later.
+ delta_minus->AssignUInt16(1);
+ delta_minus->ShiftLeft(exponent);
+ }
+}
+
+
+// See comments for InitialScaledStartValues
+static void InitialScaledStartValuesNegativeExponentPositivePower(
+ uint64_t significand, int exponent,
+ int estimated_power, bool need_boundary_deltas,
+ Bignum* numerator, Bignum* denominator,
+ Bignum* delta_minus, Bignum* delta_plus) {
+ // v = f * 2^e with e < 0, and with estimated_power >= 0.
+ // This means that e is close to 0 (have a look at how estimated_power is
+ // computed).
+
+ // numerator = significand
+ // since v = significand * 2^exponent this is equivalent to
+ // numerator = v * / 2^-exponent
+ numerator->AssignUInt64(significand);
+ // denominator = 10^estimated_power * 2^-exponent (with exponent < 0)
+ denominator->AssignPowerUInt16(10, estimated_power);
+ denominator->ShiftLeft(-exponent);
+
+ if (need_boundary_deltas) {
+ // Introduce a common denominator so that the deltas to the boundaries are
+ // integers.
+ denominator->ShiftLeft(1);
+ numerator->ShiftLeft(1);
+ // Let v = f * 2^e, then m+ - v = 1/2 * 2^e; With the common
+ // denominator (of 2) delta_plus equals 2^e.
+ // Given that the denominator already includes v's exponent the distance
+ // to the boundaries is simply 1.
+ delta_plus->AssignUInt16(1);
+ // Same for delta_minus. The adjustments if f == 2^p-1 are done later.
+ delta_minus->AssignUInt16(1);
+ }
+}
+
+
+// See comments for InitialScaledStartValues
+static void InitialScaledStartValuesNegativeExponentNegativePower(
+ uint64_t significand, int exponent,
+ int estimated_power, bool need_boundary_deltas,
+ Bignum* numerator, Bignum* denominator,
+ Bignum* delta_minus, Bignum* delta_plus) {
+ // Instead of multiplying the denominator with 10^estimated_power we
+ // multiply all values (numerator and deltas) by 10^-estimated_power.
+
+ // Use numerator as temporary container for power_ten.
+ Bignum* power_ten = numerator;
+ power_ten->AssignPowerUInt16(10, -estimated_power);
+
+ if (need_boundary_deltas) {
+ // Since power_ten == numerator we must make a copy of 10^estimated_power
+ // before we complete the computation of the numerator.
+ // delta_plus = delta_minus = 10^estimated_power
+ delta_plus->AssignBignum(*power_ten);
+ delta_minus->AssignBignum(*power_ten);
+ }
+
+ // numerator = significand * 2 * 10^-estimated_power
+ // since v = significand * 2^exponent this is equivalent to
+ // numerator = v * 10^-estimated_power * 2 * 2^-exponent.
+ // Remember: numerator has been abused as power_ten. So no need to assign it
+ // to itself.
+ ASSERT(numerator == power_ten);
+ numerator->MultiplyByUInt64(significand);
+
+ // denominator = 2 * 2^-exponent with exponent < 0.
+ denominator->AssignUInt16(1);
+ denominator->ShiftLeft(-exponent);
+
+ if (need_boundary_deltas) {
+ // Introduce a common denominator so that the deltas to the boundaries are
+ // integers.
+ numerator->ShiftLeft(1);
+ denominator->ShiftLeft(1);
+ // With this shift the boundaries have their correct value, since
+ // delta_plus = 10^-estimated_power, and
+ // delta_minus = 10^-estimated_power.
+ // These assignments have been done earlier.
+ // The adjustments if f == 2^p-1 (lower boundary is closer) are done later.
+ }
+}
+
+
+// Let v = significand * 2^exponent.
+// Computes v / 10^estimated_power exactly, as a ratio of two bignums, numerator
+// and denominator. The functions GenerateShortestDigits and
+// GenerateCountedDigits will then convert this ratio to its decimal
+// representation d, with the required accuracy.
+// Then d * 10^estimated_power is the representation of v.
+// (Note: the fraction and the estimated_power might get adjusted before
+// generating the decimal representation.)
+//
+// The initial start values consist of:
+// - a scaled numerator: s.t. numerator/denominator == v / 10^estimated_power.
+// - a scaled (common) denominator.
+// optionally (used by GenerateShortestDigits to decide if it has the shortest
+// decimal converting back to v):
+// - v - m-: the distance to the lower boundary.
+// - m+ - v: the distance to the upper boundary.
+//
+// v, m+, m-, and therefore v - m- and m+ - v all share the same denominator.
+//
+// Let ep == estimated_power, then the returned values will satisfy:
+// v / 10^ep = numerator / denominator.
+// v's boundarys m- and m+:
+// m- / 10^ep == v / 10^ep - delta_minus / denominator
+// m+ / 10^ep == v / 10^ep + delta_plus / denominator
+// Or in other words:
+// m- == v - delta_minus * 10^ep / denominator;
+// m+ == v + delta_plus * 10^ep / denominator;
+//
+// Since 10^(k-1) <= v < 10^k (with k == estimated_power)
+// or 10^k <= v < 10^(k+1)
+// we then have 0.1 <= numerator/denominator < 1
+// or 1 <= numerator/denominator < 10
+//
+// It is then easy to kickstart the digit-generation routine.
+//
+// The boundary-deltas are only filled if the mode equals BIGNUM_DTOA_SHORTEST
+// or BIGNUM_DTOA_SHORTEST_SINGLE.
+
+static void InitialScaledStartValues(uint64_t significand,
+ int exponent,
+ bool lower_boundary_is_closer,
+ int estimated_power,
+ bool need_boundary_deltas,
+ Bignum* numerator,
+ Bignum* denominator,
+ Bignum* delta_minus,
+ Bignum* delta_plus) {
+ if (exponent >= 0) {
+ InitialScaledStartValuesPositiveExponent(
+ significand, exponent, estimated_power, need_boundary_deltas,
+ numerator, denominator, delta_minus, delta_plus);
+ } else if (estimated_power >= 0) {
+ InitialScaledStartValuesNegativeExponentPositivePower(
+ significand, exponent, estimated_power, need_boundary_deltas,
+ numerator, denominator, delta_minus, delta_plus);
+ } else {
+ InitialScaledStartValuesNegativeExponentNegativePower(
+ significand, exponent, estimated_power, need_boundary_deltas,
+ numerator, denominator, delta_minus, delta_plus);
+ }
+
+ if (need_boundary_deltas && lower_boundary_is_closer) {
+ // The lower boundary is closer at half the distance of "normal" numbers.
+ // Increase the common denominator and adapt all but the delta_minus.
+ denominator->ShiftLeft(1); // *2
+ numerator->ShiftLeft(1); // *2
+ delta_plus->ShiftLeft(1); // *2
+ }
+}
+
+
+// This routine multiplies numerator/denominator so that its values lies in the
+// range 1-10. That is after a call to this function we have:
+// 1 <= (numerator + delta_plus) /denominator < 10.
+// Let numerator the input before modification and numerator' the argument
+// after modification, then the output-parameter decimal_point is such that
+// numerator / denominator * 10^estimated_power ==
+// numerator' / denominator' * 10^(decimal_point - 1)
+// In some cases estimated_power was too low, and this is already the case. We
+// then simply adjust the power so that 10^(k-1) <= v < 10^k (with k ==
+// estimated_power) but do not touch the numerator or denominator.
+// Otherwise the routine multiplies the numerator and the deltas by 10.
+static void FixupMultiply10(int estimated_power, bool is_even,
+ int* decimal_point,
+ Bignum* numerator, Bignum* denominator,
+ Bignum* delta_minus, Bignum* delta_plus) {
+ bool in_range;
+ if (is_even) {
+ // For IEEE doubles half-way cases (in decimal system numbers ending with 5)
+ // are rounded to the closest floating-point number with even significand.
+ in_range = Bignum::PlusCompare(*numerator, *delta_plus, *denominator) >= 0;
+ } else {
+ in_range = Bignum::PlusCompare(*numerator, *delta_plus, *denominator) > 0;
+ }
+ if (in_range) {
+ // Since numerator + delta_plus >= denominator we already have
+ // 1 <= numerator/denominator < 10. Simply update the estimated_power.
+ *decimal_point = estimated_power + 1;
+ } else {
+ *decimal_point = estimated_power;
+ numerator->Times10();
+ if (Bignum::Equal(*delta_minus, *delta_plus)) {
+ delta_minus->Times10();
+ delta_plus->AssignBignum(*delta_minus);
+ } else {
+ delta_minus->Times10();
+ delta_plus->Times10();
+ }
+ }
+}
+
+} // namespace double_conversion
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/bignum-dtoa.h b/src/arrow/cpp/src/arrow/vendored/double-conversion/bignum-dtoa.h
new file mode 100644
index 000000000..34b961992
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/bignum-dtoa.h
@@ -0,0 +1,84 @@
+// Copyright 2010 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef DOUBLE_CONVERSION_BIGNUM_DTOA_H_
+#define DOUBLE_CONVERSION_BIGNUM_DTOA_H_
+
+#include "utils.h"
+
+namespace double_conversion {
+
+enum BignumDtoaMode {
+ // Return the shortest correct representation.
+ // For example the output of 0.299999999999999988897 is (the less accurate but
+ // correct) 0.3.
+ BIGNUM_DTOA_SHORTEST,
+ // Same as BIGNUM_DTOA_SHORTEST but for single-precision floats.
+ BIGNUM_DTOA_SHORTEST_SINGLE,
+ // Return a fixed number of digits after the decimal point.
+ // For instance fixed(0.1, 4) becomes 0.1000
+ // If the input number is big, the output will be big.
+ BIGNUM_DTOA_FIXED,
+ // Return a fixed number of digits, no matter what the exponent is.
+ BIGNUM_DTOA_PRECISION
+};
+
+// Converts the given double 'v' to ascii.
+// The result should be interpreted as buffer * 10^(point-length).
+// The buffer will be null-terminated.
+//
+// The input v must be > 0 and different from NaN, and Infinity.
+//
+// The output depends on the given mode:
+// - SHORTEST: produce the least amount of digits for which the internal
+// identity requirement is still satisfied. If the digits are printed
+// (together with the correct exponent) then reading this number will give
+// 'v' again. The buffer will choose the representation that is closest to
+// 'v'. If there are two at the same distance, than the number is round up.
+// In this mode the 'requested_digits' parameter is ignored.
+// - FIXED: produces digits necessary to print a given number with
+// 'requested_digits' digits after the decimal point. The produced digits
+// might be too short in which case the caller has to fill the gaps with '0's.
+// Example: toFixed(0.001, 5) is allowed to return buffer="1", point=-2.
+// Halfway cases are rounded up. The call toFixed(0.15, 2) thus returns
+// buffer="2", point=0.
+// Note: the length of the returned buffer has no meaning wrt the significance
+// of its digits. That is, just because it contains '0's does not mean that
+// any other digit would not satisfy the internal identity requirement.
+// - PRECISION: produces 'requested_digits' where the first digit is not '0'.
+// Even though the length of produced digits usually equals
+// 'requested_digits', the function is allowed to return fewer digits, in
+// which case the caller has to fill the missing digits with '0's.
+// Halfway cases are again rounded up.
+// 'BignumDtoa' expects the given buffer to be big enough to hold all digits
+// and a terminating null-character.
+void BignumDtoa(double v, BignumDtoaMode mode, int requested_digits,
+ Vector<char> buffer, int* length, int* point);
+
+} // namespace double_conversion
+
+#endif // DOUBLE_CONVERSION_BIGNUM_DTOA_H_
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/bignum.cc b/src/arrow/cpp/src/arrow/vendored/double-conversion/bignum.cc
new file mode 100644
index 000000000..d077eef3f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/bignum.cc
@@ -0,0 +1,767 @@
+// Copyright 2010 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "bignum.h"
+#include "utils.h"
+
+namespace double_conversion {
+
+Bignum::Bignum()
+ : bigits_buffer_(), bigits_(bigits_buffer_, kBigitCapacity), used_digits_(0), exponent_(0) {
+ for (int i = 0; i < kBigitCapacity; ++i) {
+ bigits_[i] = 0;
+ }
+}
+
+
+template<typename S>
+static int BitSize(S value) {
+ (void) value; // Mark variable as used.
+ return 8 * sizeof(value);
+}
+
+// Guaranteed to lie in one Bigit.
+void Bignum::AssignUInt16(uint16_t value) {
+ ASSERT(kBigitSize >= BitSize(value));
+ Zero();
+ if (value == 0) return;
+
+ EnsureCapacity(1);
+ bigits_[0] = value;
+ used_digits_ = 1;
+}
+
+
+void Bignum::AssignUInt64(uint64_t value) {
+ const int kUInt64Size = 64;
+
+ Zero();
+ if (value == 0) return;
+
+ int needed_bigits = kUInt64Size / kBigitSize + 1;
+ EnsureCapacity(needed_bigits);
+ for (int i = 0; i < needed_bigits; ++i) {
+ bigits_[i] = value & kBigitMask;
+ value = value >> kBigitSize;
+ }
+ used_digits_ = needed_bigits;
+ Clamp();
+}
+
+
+void Bignum::AssignBignum(const Bignum& other) {
+ exponent_ = other.exponent_;
+ for (int i = 0; i < other.used_digits_; ++i) {
+ bigits_[i] = other.bigits_[i];
+ }
+ // Clear the excess digits (if there were any).
+ for (int i = other.used_digits_; i < used_digits_; ++i) {
+ bigits_[i] = 0;
+ }
+ used_digits_ = other.used_digits_;
+}
+
+
+static uint64_t ReadUInt64(Vector<const char> buffer,
+ int from,
+ int digits_to_read) {
+ uint64_t result = 0;
+ for (int i = from; i < from + digits_to_read; ++i) {
+ int digit = buffer[i] - '0';
+ ASSERT(0 <= digit && digit <= 9);
+ result = result * 10 + digit;
+ }
+ return result;
+}
+
+
+void Bignum::AssignDecimalString(Vector<const char> value) {
+ // 2^64 = 18446744073709551616 > 10^19
+ const int kMaxUint64DecimalDigits = 19;
+ Zero();
+ int length = value.length();
+ unsigned int pos = 0;
+ // Let's just say that each digit needs 4 bits.
+ while (length >= kMaxUint64DecimalDigits) {
+ uint64_t digits = ReadUInt64(value, pos, kMaxUint64DecimalDigits);
+ pos += kMaxUint64DecimalDigits;
+ length -= kMaxUint64DecimalDigits;
+ MultiplyByPowerOfTen(kMaxUint64DecimalDigits);
+ AddUInt64(digits);
+ }
+ uint64_t digits = ReadUInt64(value, pos, length);
+ MultiplyByPowerOfTen(length);
+ AddUInt64(digits);
+ Clamp();
+}
+
+
+static int HexCharValue(char c) {
+ if ('0' <= c && c <= '9') return c - '0';
+ if ('a' <= c && c <= 'f') return 10 + c - 'a';
+ ASSERT('A' <= c && c <= 'F');
+ return 10 + c - 'A';
+}
+
+
+void Bignum::AssignHexString(Vector<const char> value) {
+ Zero();
+ int length = value.length();
+
+ int needed_bigits = length * 4 / kBigitSize + 1;
+ EnsureCapacity(needed_bigits);
+ int string_index = length - 1;
+ for (int i = 0; i < needed_bigits - 1; ++i) {
+ // These bigits are guaranteed to be "full".
+ Chunk current_bigit = 0;
+ for (int j = 0; j < kBigitSize / 4; j++) {
+ current_bigit += HexCharValue(value[string_index--]) << (j * 4);
+ }
+ bigits_[i] = current_bigit;
+ }
+ used_digits_ = needed_bigits - 1;
+
+ Chunk most_significant_bigit = 0; // Could be = 0;
+ for (int j = 0; j <= string_index; ++j) {
+ most_significant_bigit <<= 4;
+ most_significant_bigit += HexCharValue(value[j]);
+ }
+ if (most_significant_bigit != 0) {
+ bigits_[used_digits_] = most_significant_bigit;
+ used_digits_++;
+ }
+ Clamp();
+}
+
+
+void Bignum::AddUInt64(uint64_t operand) {
+ if (operand == 0) return;
+ Bignum other;
+ other.AssignUInt64(operand);
+ AddBignum(other);
+}
+
+
+void Bignum::AddBignum(const Bignum& other) {
+ ASSERT(IsClamped());
+ ASSERT(other.IsClamped());
+
+ // If this has a greater exponent than other append zero-bigits to this.
+ // After this call exponent_ <= other.exponent_.
+ Align(other);
+
+ // There are two possibilities:
+ // aaaaaaaaaaa 0000 (where the 0s represent a's exponent)
+ // bbbbb 00000000
+ // ----------------
+ // ccccccccccc 0000
+ // or
+ // aaaaaaaaaa 0000
+ // bbbbbbbbb 0000000
+ // -----------------
+ // cccccccccccc 0000
+ // In both cases we might need a carry bigit.
+
+ EnsureCapacity(1 + Max(BigitLength(), other.BigitLength()) - exponent_);
+ Chunk carry = 0;
+ int bigit_pos = other.exponent_ - exponent_;
+ ASSERT(bigit_pos >= 0);
+ for (int i = 0; i < other.used_digits_; ++i) {
+ Chunk sum = bigits_[bigit_pos] + other.bigits_[i] + carry;
+ bigits_[bigit_pos] = sum & kBigitMask;
+ carry = sum >> kBigitSize;
+ bigit_pos++;
+ }
+
+ while (carry != 0) {
+ Chunk sum = bigits_[bigit_pos] + carry;
+ bigits_[bigit_pos] = sum & kBigitMask;
+ carry = sum >> kBigitSize;
+ bigit_pos++;
+ }
+ used_digits_ = Max(bigit_pos, used_digits_);
+ ASSERT(IsClamped());
+}
+
+
+void Bignum::SubtractBignum(const Bignum& other) {
+ ASSERT(IsClamped());
+ ASSERT(other.IsClamped());
+ // We require this to be bigger than other.
+ ASSERT(LessEqual(other, *this));
+
+ Align(other);
+
+ int offset = other.exponent_ - exponent_;
+ Chunk borrow = 0;
+ int i;
+ for (i = 0; i < other.used_digits_; ++i) {
+ ASSERT((borrow == 0) || (borrow == 1));
+ Chunk difference = bigits_[i + offset] - other.bigits_[i] - borrow;
+ bigits_[i + offset] = difference & kBigitMask;
+ borrow = difference >> (kChunkSize - 1);
+ }
+ while (borrow != 0) {
+ Chunk difference = bigits_[i + offset] - borrow;
+ bigits_[i + offset] = difference & kBigitMask;
+ borrow = difference >> (kChunkSize - 1);
+ ++i;
+ }
+ Clamp();
+}
+
+
+void Bignum::ShiftLeft(int shift_amount) {
+ if (used_digits_ == 0) return;
+ exponent_ += shift_amount / kBigitSize;
+ int local_shift = shift_amount % kBigitSize;
+ EnsureCapacity(used_digits_ + 1);
+ BigitsShiftLeft(local_shift);
+}
+
+
+void Bignum::MultiplyByUInt32(uint32_t factor) {
+ if (factor == 1) return;
+ if (factor == 0) {
+ Zero();
+ return;
+ }
+ if (used_digits_ == 0) return;
+
+ // The product of a bigit with the factor is of size kBigitSize + 32.
+ // Assert that this number + 1 (for the carry) fits into double chunk.
+ ASSERT(kDoubleChunkSize >= kBigitSize + 32 + 1);
+ DoubleChunk carry = 0;
+ for (int i = 0; i < used_digits_; ++i) {
+ DoubleChunk product = static_cast<DoubleChunk>(factor) * bigits_[i] + carry;
+ bigits_[i] = static_cast<Chunk>(product & kBigitMask);
+ carry = (product >> kBigitSize);
+ }
+ while (carry != 0) {
+ EnsureCapacity(used_digits_ + 1);
+ bigits_[used_digits_] = carry & kBigitMask;
+ used_digits_++;
+ carry >>= kBigitSize;
+ }
+}
+
+
+void Bignum::MultiplyByUInt64(uint64_t factor) {
+ if (factor == 1) return;
+ if (factor == 0) {
+ Zero();
+ return;
+ }
+ ASSERT(kBigitSize < 32);
+ uint64_t carry = 0;
+ uint64_t low = factor & 0xFFFFFFFF;
+ uint64_t high = factor >> 32;
+ for (int i = 0; i < used_digits_; ++i) {
+ uint64_t product_low = low * bigits_[i];
+ uint64_t product_high = high * bigits_[i];
+ uint64_t tmp = (carry & kBigitMask) + product_low;
+ bigits_[i] = tmp & kBigitMask;
+ carry = (carry >> kBigitSize) + (tmp >> kBigitSize) +
+ (product_high << (32 - kBigitSize));
+ }
+ while (carry != 0) {
+ EnsureCapacity(used_digits_ + 1);
+ bigits_[used_digits_] = carry & kBigitMask;
+ used_digits_++;
+ carry >>= kBigitSize;
+ }
+}
+
+
+void Bignum::MultiplyByPowerOfTen(int exponent) {
+ const uint64_t kFive27 = UINT64_2PART_C(0x6765c793, fa10079d);
+ const uint16_t kFive1 = 5;
+ const uint16_t kFive2 = kFive1 * 5;
+ const uint16_t kFive3 = kFive2 * 5;
+ const uint16_t kFive4 = kFive3 * 5;
+ const uint16_t kFive5 = kFive4 * 5;
+ const uint16_t kFive6 = kFive5 * 5;
+ const uint32_t kFive7 = kFive6 * 5;
+ const uint32_t kFive8 = kFive7 * 5;
+ const uint32_t kFive9 = kFive8 * 5;
+ const uint32_t kFive10 = kFive9 * 5;
+ const uint32_t kFive11 = kFive10 * 5;
+ const uint32_t kFive12 = kFive11 * 5;
+ const uint32_t kFive13 = kFive12 * 5;
+ const uint32_t kFive1_to_12[] =
+ { kFive1, kFive2, kFive3, kFive4, kFive5, kFive6,
+ kFive7, kFive8, kFive9, kFive10, kFive11, kFive12 };
+
+ ASSERT(exponent >= 0);
+ if (exponent == 0) return;
+ if (used_digits_ == 0) return;
+
+ // We shift by exponent at the end just before returning.
+ int remaining_exponent = exponent;
+ while (remaining_exponent >= 27) {
+ MultiplyByUInt64(kFive27);
+ remaining_exponent -= 27;
+ }
+ while (remaining_exponent >= 13) {
+ MultiplyByUInt32(kFive13);
+ remaining_exponent -= 13;
+ }
+ if (remaining_exponent > 0) {
+ MultiplyByUInt32(kFive1_to_12[remaining_exponent - 1]);
+ }
+ ShiftLeft(exponent);
+}
+
+
+void Bignum::Square() {
+ ASSERT(IsClamped());
+ int product_length = 2 * used_digits_;
+ EnsureCapacity(product_length);
+
+ // Comba multiplication: compute each column separately.
+ // Example: r = a2a1a0 * b2b1b0.
+ // r = 1 * a0b0 +
+ // 10 * (a1b0 + a0b1) +
+ // 100 * (a2b0 + a1b1 + a0b2) +
+ // 1000 * (a2b1 + a1b2) +
+ // 10000 * a2b2
+ //
+ // In the worst case we have to accumulate nb-digits products of digit*digit.
+ //
+ // Assert that the additional number of bits in a DoubleChunk are enough to
+ // sum up used_digits of Bigit*Bigit.
+ if ((1 << (2 * (kChunkSize - kBigitSize))) <= used_digits_) {
+ UNIMPLEMENTED();
+ }
+ DoubleChunk accumulator = 0;
+ // First shift the digits so we don't overwrite them.
+ int copy_offset = used_digits_;
+ for (int i = 0; i < used_digits_; ++i) {
+ bigits_[copy_offset + i] = bigits_[i];
+ }
+ // We have two loops to avoid some 'if's in the loop.
+ for (int i = 0; i < used_digits_; ++i) {
+ // Process temporary digit i with power i.
+ // The sum of the two indices must be equal to i.
+ int bigit_index1 = i;
+ int bigit_index2 = 0;
+ // Sum all of the sub-products.
+ while (bigit_index1 >= 0) {
+ Chunk chunk1 = bigits_[copy_offset + bigit_index1];
+ Chunk chunk2 = bigits_[copy_offset + bigit_index2];
+ accumulator += static_cast<DoubleChunk>(chunk1) * chunk2;
+ bigit_index1--;
+ bigit_index2++;
+ }
+ bigits_[i] = static_cast<Chunk>(accumulator) & kBigitMask;
+ accumulator >>= kBigitSize;
+ }
+ for (int i = used_digits_; i < product_length; ++i) {
+ int bigit_index1 = used_digits_ - 1;
+ int bigit_index2 = i - bigit_index1;
+ // Invariant: sum of both indices is again equal to i.
+ // Inner loop runs 0 times on last iteration, emptying accumulator.
+ while (bigit_index2 < used_digits_) {
+ Chunk chunk1 = bigits_[copy_offset + bigit_index1];
+ Chunk chunk2 = bigits_[copy_offset + bigit_index2];
+ accumulator += static_cast<DoubleChunk>(chunk1) * chunk2;
+ bigit_index1--;
+ bigit_index2++;
+ }
+ // The overwritten bigits_[i] will never be read in further loop iterations,
+ // because bigit_index1 and bigit_index2 are always greater
+ // than i - used_digits_.
+ bigits_[i] = static_cast<Chunk>(accumulator) & kBigitMask;
+ accumulator >>= kBigitSize;
+ }
+ // Since the result was guaranteed to lie inside the number the
+ // accumulator must be 0 now.
+ ASSERT(accumulator == 0);
+
+ // Don't forget to update the used_digits and the exponent.
+ used_digits_ = product_length;
+ exponent_ *= 2;
+ Clamp();
+}
+
+
+void Bignum::AssignPowerUInt16(uint16_t base, int power_exponent) {
+ ASSERT(base != 0);
+ ASSERT(power_exponent >= 0);
+ if (power_exponent == 0) {
+ AssignUInt16(1);
+ return;
+ }
+ Zero();
+ int shifts = 0;
+ // We expect base to be in range 2-32, and most often to be 10.
+ // It does not make much sense to implement different algorithms for counting
+ // the bits.
+ while ((base & 1) == 0) {
+ base >>= 1;
+ shifts++;
+ }
+ int bit_size = 0;
+ int tmp_base = base;
+ while (tmp_base != 0) {
+ tmp_base >>= 1;
+ bit_size++;
+ }
+ int final_size = bit_size * power_exponent;
+ // 1 extra bigit for the shifting, and one for rounded final_size.
+ EnsureCapacity(final_size / kBigitSize + 2);
+
+ // Left to Right exponentiation.
+ int mask = 1;
+ while (power_exponent >= mask) mask <<= 1;
+
+ // The mask is now pointing to the bit above the most significant 1-bit of
+ // power_exponent.
+ // Get rid of first 1-bit;
+ mask >>= 2;
+ uint64_t this_value = base;
+
+ bool delayed_multiplication = false;
+ const uint64_t max_32bits = 0xFFFFFFFF;
+ while (mask != 0 && this_value <= max_32bits) {
+ this_value = this_value * this_value;
+ // Verify that there is enough space in this_value to perform the
+ // multiplication. The first bit_size bits must be 0.
+ if ((power_exponent & mask) != 0) {
+ ASSERT(bit_size > 0);
+ uint64_t base_bits_mask =
+ ~((static_cast<uint64_t>(1) << (64 - bit_size)) - 1);
+ bool high_bits_zero = (this_value & base_bits_mask) == 0;
+ if (high_bits_zero) {
+ this_value *= base;
+ } else {
+ delayed_multiplication = true;
+ }
+ }
+ mask >>= 1;
+ }
+ AssignUInt64(this_value);
+ if (delayed_multiplication) {
+ MultiplyByUInt32(base);
+ }
+
+ // Now do the same thing as a bignum.
+ while (mask != 0) {
+ Square();
+ if ((power_exponent & mask) != 0) {
+ MultiplyByUInt32(base);
+ }
+ mask >>= 1;
+ }
+
+ // And finally add the saved shifts.
+ ShiftLeft(shifts * power_exponent);
+}
+
+
+// Precondition: this/other < 16bit.
+uint16_t Bignum::DivideModuloIntBignum(const Bignum& other) {
+ ASSERT(IsClamped());
+ ASSERT(other.IsClamped());
+ ASSERT(other.used_digits_ > 0);
+
+ // Easy case: if we have less digits than the divisor than the result is 0.
+ // Note: this handles the case where this == 0, too.
+ if (BigitLength() < other.BigitLength()) {
+ return 0;
+ }
+
+ Align(other);
+
+ uint16_t result = 0;
+
+ // Start by removing multiples of 'other' until both numbers have the same
+ // number of digits.
+ while (BigitLength() > other.BigitLength()) {
+ // This naive approach is extremely inefficient if `this` divided by other
+ // is big. This function is implemented for doubleToString where
+ // the result should be small (less than 10).
+ ASSERT(other.bigits_[other.used_digits_ - 1] >= ((1 << kBigitSize) / 16));
+ ASSERT(bigits_[used_digits_ - 1] < 0x10000);
+ // Remove the multiples of the first digit.
+ // Example this = 23 and other equals 9. -> Remove 2 multiples.
+ result += static_cast<uint16_t>(bigits_[used_digits_ - 1]);
+ SubtractTimes(other, bigits_[used_digits_ - 1]);
+ }
+
+ ASSERT(BigitLength() == other.BigitLength());
+
+ // Both bignums are at the same length now.
+ // Since other has more than 0 digits we know that the access to
+ // bigits_[used_digits_ - 1] is safe.
+ Chunk this_bigit = bigits_[used_digits_ - 1];
+ Chunk other_bigit = other.bigits_[other.used_digits_ - 1];
+
+ if (other.used_digits_ == 1) {
+ // Shortcut for easy (and common) case.
+ int quotient = this_bigit / other_bigit;
+ bigits_[used_digits_ - 1] = this_bigit - other_bigit * quotient;
+ ASSERT(quotient < 0x10000);
+ result += static_cast<uint16_t>(quotient);
+ Clamp();
+ return result;
+ }
+
+ int division_estimate = this_bigit / (other_bigit + 1);
+ ASSERT(division_estimate < 0x10000);
+ result += static_cast<uint16_t>(division_estimate);
+ SubtractTimes(other, division_estimate);
+
+ if (other_bigit * (division_estimate + 1) > this_bigit) {
+ // No need to even try to subtract. Even if other's remaining digits were 0
+ // another subtraction would be too much.
+ return result;
+ }
+
+ while (LessEqual(other, *this)) {
+ SubtractBignum(other);
+ result++;
+ }
+ return result;
+}
+
+
+template<typename S>
+static int SizeInHexChars(S number) {
+ ASSERT(number > 0);
+ int result = 0;
+ while (number != 0) {
+ number >>= 4;
+ result++;
+ }
+ return result;
+}
+
+
+static char HexCharOfValue(int value) {
+ ASSERT(0 <= value && value <= 16);
+ if (value < 10) return static_cast<char>(value + '0');
+ return static_cast<char>(value - 10 + 'A');
+}
+
+
+bool Bignum::ToHexString(char* buffer, int buffer_size) const {
+ ASSERT(IsClamped());
+ // Each bigit must be printable as separate hex-character.
+ ASSERT(kBigitSize % 4 == 0);
+ const int kHexCharsPerBigit = kBigitSize / 4;
+
+ if (used_digits_ == 0) {
+ if (buffer_size < 2) return false;
+ buffer[0] = '0';
+ buffer[1] = '\0';
+ return true;
+ }
+ // We add 1 for the terminating '\0' character.
+ int needed_chars = (BigitLength() - 1) * kHexCharsPerBigit +
+ SizeInHexChars(bigits_[used_digits_ - 1]) + 1;
+ if (needed_chars > buffer_size) return false;
+ int string_index = needed_chars - 1;
+ buffer[string_index--] = '\0';
+ for (int i = 0; i < exponent_; ++i) {
+ for (int j = 0; j < kHexCharsPerBigit; ++j) {
+ buffer[string_index--] = '0';
+ }
+ }
+ for (int i = 0; i < used_digits_ - 1; ++i) {
+ Chunk current_bigit = bigits_[i];
+ for (int j = 0; j < kHexCharsPerBigit; ++j) {
+ buffer[string_index--] = HexCharOfValue(current_bigit & 0xF);
+ current_bigit >>= 4;
+ }
+ }
+ // And finally the last bigit.
+ Chunk most_significant_bigit = bigits_[used_digits_ - 1];
+ while (most_significant_bigit != 0) {
+ buffer[string_index--] = HexCharOfValue(most_significant_bigit & 0xF);
+ most_significant_bigit >>= 4;
+ }
+ return true;
+}
+
+
+Bignum::Chunk Bignum::BigitAt(int index) const {
+ if (index >= BigitLength()) return 0;
+ if (index < exponent_) return 0;
+ return bigits_[index - exponent_];
+}
+
+
+int Bignum::Compare(const Bignum& a, const Bignum& b) {
+ ASSERT(a.IsClamped());
+ ASSERT(b.IsClamped());
+ int bigit_length_a = a.BigitLength();
+ int bigit_length_b = b.BigitLength();
+ if (bigit_length_a < bigit_length_b) return -1;
+ if (bigit_length_a > bigit_length_b) return +1;
+ for (int i = bigit_length_a - 1; i >= Min(a.exponent_, b.exponent_); --i) {
+ Chunk bigit_a = a.BigitAt(i);
+ Chunk bigit_b = b.BigitAt(i);
+ if (bigit_a < bigit_b) return -1;
+ if (bigit_a > bigit_b) return +1;
+ // Otherwise they are equal up to this digit. Try the next digit.
+ }
+ return 0;
+}
+
+
+int Bignum::PlusCompare(const Bignum& a, const Bignum& b, const Bignum& c) {
+ ASSERT(a.IsClamped());
+ ASSERT(b.IsClamped());
+ ASSERT(c.IsClamped());
+ if (a.BigitLength() < b.BigitLength()) {
+ return PlusCompare(b, a, c);
+ }
+ if (a.BigitLength() + 1 < c.BigitLength()) return -1;
+ if (a.BigitLength() > c.BigitLength()) return +1;
+ // The exponent encodes 0-bigits. So if there are more 0-digits in 'a' than
+ // 'b' has digits, then the bigit-length of 'a'+'b' must be equal to the one
+ // of 'a'.
+ if (a.exponent_ >= b.BigitLength() && a.BigitLength() < c.BigitLength()) {
+ return -1;
+ }
+
+ Chunk borrow = 0;
+ // Starting at min_exponent all digits are == 0. So no need to compare them.
+ int min_exponent = Min(Min(a.exponent_, b.exponent_), c.exponent_);
+ for (int i = c.BigitLength() - 1; i >= min_exponent; --i) {
+ Chunk chunk_a = a.BigitAt(i);
+ Chunk chunk_b = b.BigitAt(i);
+ Chunk chunk_c = c.BigitAt(i);
+ Chunk sum = chunk_a + chunk_b;
+ if (sum > chunk_c + borrow) {
+ return +1;
+ } else {
+ borrow = chunk_c + borrow - sum;
+ if (borrow > 1) return -1;
+ borrow <<= kBigitSize;
+ }
+ }
+ if (borrow == 0) return 0;
+ return -1;
+}
+
+
+void Bignum::Clamp() {
+ while (used_digits_ > 0 && bigits_[used_digits_ - 1] == 0) {
+ used_digits_--;
+ }
+ if (used_digits_ == 0) {
+ // Zero.
+ exponent_ = 0;
+ }
+}
+
+
+bool Bignum::IsClamped() const {
+ return used_digits_ == 0 || bigits_[used_digits_ - 1] != 0;
+}
+
+
+void Bignum::Zero() {
+ for (int i = 0; i < used_digits_; ++i) {
+ bigits_[i] = 0;
+ }
+ used_digits_ = 0;
+ exponent_ = 0;
+}
+
+
+void Bignum::Align(const Bignum& other) {
+ if (exponent_ > other.exponent_) {
+ // If "X" represents a "hidden" digit (by the exponent) then we are in the
+ // following case (a == this, b == other):
+ // a: aaaaaaXXXX or a: aaaaaXXX
+ // b: bbbbbbX b: bbbbbbbbXX
+ // We replace some of the hidden digits (X) of a with 0 digits.
+ // a: aaaaaa000X or a: aaaaa0XX
+ int zero_digits = exponent_ - other.exponent_;
+ EnsureCapacity(used_digits_ + zero_digits);
+ for (int i = used_digits_ - 1; i >= 0; --i) {
+ bigits_[i + zero_digits] = bigits_[i];
+ }
+ for (int i = 0; i < zero_digits; ++i) {
+ bigits_[i] = 0;
+ }
+ used_digits_ += zero_digits;
+ exponent_ -= zero_digits;
+ ASSERT(used_digits_ >= 0);
+ ASSERT(exponent_ >= 0);
+ }
+}
+
+
+void Bignum::BigitsShiftLeft(int shift_amount) {
+ ASSERT(shift_amount < kBigitSize);
+ ASSERT(shift_amount >= 0);
+ Chunk carry = 0;
+ for (int i = 0; i < used_digits_; ++i) {
+ Chunk new_carry = bigits_[i] >> (kBigitSize - shift_amount);
+ bigits_[i] = ((bigits_[i] << shift_amount) + carry) & kBigitMask;
+ carry = new_carry;
+ }
+ if (carry != 0) {
+ bigits_[used_digits_] = carry;
+ used_digits_++;
+ }
+}
+
+
+void Bignum::SubtractTimes(const Bignum& other, int factor) {
+ ASSERT(exponent_ <= other.exponent_);
+ if (factor < 3) {
+ for (int i = 0; i < factor; ++i) {
+ SubtractBignum(other);
+ }
+ return;
+ }
+ Chunk borrow = 0;
+ int exponent_diff = other.exponent_ - exponent_;
+ for (int i = 0; i < other.used_digits_; ++i) {
+ DoubleChunk product = static_cast<DoubleChunk>(factor) * other.bigits_[i];
+ DoubleChunk remove = borrow + product;
+ Chunk difference = bigits_[i + exponent_diff] - (remove & kBigitMask);
+ bigits_[i + exponent_diff] = difference & kBigitMask;
+ borrow = static_cast<Chunk>((difference >> (kChunkSize - 1)) +
+ (remove >> kBigitSize));
+ }
+ for (int i = other.used_digits_ + exponent_diff; i < used_digits_; ++i) {
+ if (borrow == 0) return;
+ Chunk difference = bigits_[i] - borrow;
+ bigits_[i] = difference & kBigitMask;
+ borrow = difference >> (kChunkSize - 1);
+ }
+ Clamp();
+}
+
+
+} // namespace double_conversion
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/bignum.h b/src/arrow/cpp/src/arrow/vendored/double-conversion/bignum.h
new file mode 100644
index 000000000..7c289fa2f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/bignum.h
@@ -0,0 +1,144 @@
+// Copyright 2010 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef DOUBLE_CONVERSION_BIGNUM_H_
+#define DOUBLE_CONVERSION_BIGNUM_H_
+
+#include "utils.h"
+
+namespace double_conversion {
+
+class Bignum {
+ public:
+ // 3584 = 128 * 28. We can represent 2^3584 > 10^1000 accurately.
+ // This bignum can encode much bigger numbers, since it contains an
+ // exponent.
+ static const int kMaxSignificantBits = 3584;
+
+ Bignum();
+ void AssignUInt16(uint16_t value);
+ void AssignUInt64(uint64_t value);
+ void AssignBignum(const Bignum& other);
+
+ void AssignDecimalString(Vector<const char> value);
+ void AssignHexString(Vector<const char> value);
+
+ void AssignPowerUInt16(uint16_t base, int exponent);
+
+ void AddUInt64(uint64_t operand);
+ void AddBignum(const Bignum& other);
+ // Precondition: this >= other.
+ void SubtractBignum(const Bignum& other);
+
+ void Square();
+ void ShiftLeft(int shift_amount);
+ void MultiplyByUInt32(uint32_t factor);
+ void MultiplyByUInt64(uint64_t factor);
+ void MultiplyByPowerOfTen(int exponent);
+ void Times10() { return MultiplyByUInt32(10); }
+ // Pseudocode:
+ // int result = this / other;
+ // this = this % other;
+ // In the worst case this function is in O(this/other).
+ uint16_t DivideModuloIntBignum(const Bignum& other);
+
+ bool ToHexString(char* buffer, int buffer_size) const;
+
+ // Returns
+ // -1 if a < b,
+ // 0 if a == b, and
+ // +1 if a > b.
+ static int Compare(const Bignum& a, const Bignum& b);
+ static bool Equal(const Bignum& a, const Bignum& b) {
+ return Compare(a, b) == 0;
+ }
+ static bool LessEqual(const Bignum& a, const Bignum& b) {
+ return Compare(a, b) <= 0;
+ }
+ static bool Less(const Bignum& a, const Bignum& b) {
+ return Compare(a, b) < 0;
+ }
+ // Returns Compare(a + b, c);
+ static int PlusCompare(const Bignum& a, const Bignum& b, const Bignum& c);
+ // Returns a + b == c
+ static bool PlusEqual(const Bignum& a, const Bignum& b, const Bignum& c) {
+ return PlusCompare(a, b, c) == 0;
+ }
+ // Returns a + b <= c
+ static bool PlusLessEqual(const Bignum& a, const Bignum& b, const Bignum& c) {
+ return PlusCompare(a, b, c) <= 0;
+ }
+ // Returns a + b < c
+ static bool PlusLess(const Bignum& a, const Bignum& b, const Bignum& c) {
+ return PlusCompare(a, b, c) < 0;
+ }
+ private:
+ typedef uint32_t Chunk;
+ typedef uint64_t DoubleChunk;
+
+ static const int kChunkSize = sizeof(Chunk) * 8;
+ static const int kDoubleChunkSize = sizeof(DoubleChunk) * 8;
+ // With bigit size of 28 we loose some bits, but a double still fits easily
+ // into two chunks, and more importantly we can use the Comba multiplication.
+ static const int kBigitSize = 28;
+ static const Chunk kBigitMask = (1 << kBigitSize) - 1;
+ // Every instance allocates kBigitLength chunks on the stack. Bignums cannot
+ // grow. There are no checks if the stack-allocated space is sufficient.
+ static const int kBigitCapacity = kMaxSignificantBits / kBigitSize;
+
+ void EnsureCapacity(int size) {
+ if (size > kBigitCapacity) {
+ UNREACHABLE();
+ }
+ }
+ void Align(const Bignum& other);
+ void Clamp();
+ bool IsClamped() const;
+ void Zero();
+ // Requires this to have enough capacity (no tests done).
+ // Updates used_digits_ if necessary.
+ // shift_amount must be < kBigitSize.
+ void BigitsShiftLeft(int shift_amount);
+ // BigitLength includes the "hidden" digits encoded in the exponent.
+ int BigitLength() const { return used_digits_ + exponent_; }
+ Chunk BigitAt(int index) const;
+ void SubtractTimes(const Bignum& other, int factor);
+
+ Chunk bigits_buffer_[kBigitCapacity];
+ // A vector backed by bigits_buffer_. This way accesses to the array are
+ // checked for out-of-bounds errors.
+ Vector<Chunk> bigits_;
+ int used_digits_;
+ // The Bignum's value equals value(bigits_) * 2^(exponent_ * kBigitSize).
+ int exponent_;
+
+ DC_DISALLOW_COPY_AND_ASSIGN(Bignum);
+};
+
+} // namespace double_conversion
+
+#endif // DOUBLE_CONVERSION_BIGNUM_H_
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/cached-powers.cc b/src/arrow/cpp/src/arrow/vendored/double-conversion/cached-powers.cc
new file mode 100644
index 000000000..8ab281a1b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/cached-powers.cc
@@ -0,0 +1,175 @@
+// Copyright 2006-2008 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <climits>
+#include <cmath>
+#include <cstdarg>
+
+#include "utils.h"
+
+#include "cached-powers.h"
+
+namespace double_conversion {
+
+struct CachedPower {
+ uint64_t significand;
+ int16_t binary_exponent;
+ int16_t decimal_exponent;
+};
+
+static const CachedPower kCachedPowers[] = {
+ {UINT64_2PART_C(0xfa8fd5a0, 081c0288), -1220, -348},
+ {UINT64_2PART_C(0xbaaee17f, a23ebf76), -1193, -340},
+ {UINT64_2PART_C(0x8b16fb20, 3055ac76), -1166, -332},
+ {UINT64_2PART_C(0xcf42894a, 5dce35ea), -1140, -324},
+ {UINT64_2PART_C(0x9a6bb0aa, 55653b2d), -1113, -316},
+ {UINT64_2PART_C(0xe61acf03, 3d1a45df), -1087, -308},
+ {UINT64_2PART_C(0xab70fe17, c79ac6ca), -1060, -300},
+ {UINT64_2PART_C(0xff77b1fc, bebcdc4f), -1034, -292},
+ {UINT64_2PART_C(0xbe5691ef, 416bd60c), -1007, -284},
+ {UINT64_2PART_C(0x8dd01fad, 907ffc3c), -980, -276},
+ {UINT64_2PART_C(0xd3515c28, 31559a83), -954, -268},
+ {UINT64_2PART_C(0x9d71ac8f, ada6c9b5), -927, -260},
+ {UINT64_2PART_C(0xea9c2277, 23ee8bcb), -901, -252},
+ {UINT64_2PART_C(0xaecc4991, 4078536d), -874, -244},
+ {UINT64_2PART_C(0x823c1279, 5db6ce57), -847, -236},
+ {UINT64_2PART_C(0xc2109436, 4dfb5637), -821, -228},
+ {UINT64_2PART_C(0x9096ea6f, 3848984f), -794, -220},
+ {UINT64_2PART_C(0xd77485cb, 25823ac7), -768, -212},
+ {UINT64_2PART_C(0xa086cfcd, 97bf97f4), -741, -204},
+ {UINT64_2PART_C(0xef340a98, 172aace5), -715, -196},
+ {UINT64_2PART_C(0xb23867fb, 2a35b28e), -688, -188},
+ {UINT64_2PART_C(0x84c8d4df, d2c63f3b), -661, -180},
+ {UINT64_2PART_C(0xc5dd4427, 1ad3cdba), -635, -172},
+ {UINT64_2PART_C(0x936b9fce, bb25c996), -608, -164},
+ {UINT64_2PART_C(0xdbac6c24, 7d62a584), -582, -156},
+ {UINT64_2PART_C(0xa3ab6658, 0d5fdaf6), -555, -148},
+ {UINT64_2PART_C(0xf3e2f893, dec3f126), -529, -140},
+ {UINT64_2PART_C(0xb5b5ada8, aaff80b8), -502, -132},
+ {UINT64_2PART_C(0x87625f05, 6c7c4a8b), -475, -124},
+ {UINT64_2PART_C(0xc9bcff60, 34c13053), -449, -116},
+ {UINT64_2PART_C(0x964e858c, 91ba2655), -422, -108},
+ {UINT64_2PART_C(0xdff97724, 70297ebd), -396, -100},
+ {UINT64_2PART_C(0xa6dfbd9f, b8e5b88f), -369, -92},
+ {UINT64_2PART_C(0xf8a95fcf, 88747d94), -343, -84},
+ {UINT64_2PART_C(0xb9447093, 8fa89bcf), -316, -76},
+ {UINT64_2PART_C(0x8a08f0f8, bf0f156b), -289, -68},
+ {UINT64_2PART_C(0xcdb02555, 653131b6), -263, -60},
+ {UINT64_2PART_C(0x993fe2c6, d07b7fac), -236, -52},
+ {UINT64_2PART_C(0xe45c10c4, 2a2b3b06), -210, -44},
+ {UINT64_2PART_C(0xaa242499, 697392d3), -183, -36},
+ {UINT64_2PART_C(0xfd87b5f2, 8300ca0e), -157, -28},
+ {UINT64_2PART_C(0xbce50864, 92111aeb), -130, -20},
+ {UINT64_2PART_C(0x8cbccc09, 6f5088cc), -103, -12},
+ {UINT64_2PART_C(0xd1b71758, e219652c), -77, -4},
+ {UINT64_2PART_C(0x9c400000, 00000000), -50, 4},
+ {UINT64_2PART_C(0xe8d4a510, 00000000), -24, 12},
+ {UINT64_2PART_C(0xad78ebc5, ac620000), 3, 20},
+ {UINT64_2PART_C(0x813f3978, f8940984), 30, 28},
+ {UINT64_2PART_C(0xc097ce7b, c90715b3), 56, 36},
+ {UINT64_2PART_C(0x8f7e32ce, 7bea5c70), 83, 44},
+ {UINT64_2PART_C(0xd5d238a4, abe98068), 109, 52},
+ {UINT64_2PART_C(0x9f4f2726, 179a2245), 136, 60},
+ {UINT64_2PART_C(0xed63a231, d4c4fb27), 162, 68},
+ {UINT64_2PART_C(0xb0de6538, 8cc8ada8), 189, 76},
+ {UINT64_2PART_C(0x83c7088e, 1aab65db), 216, 84},
+ {UINT64_2PART_C(0xc45d1df9, 42711d9a), 242, 92},
+ {UINT64_2PART_C(0x924d692c, a61be758), 269, 100},
+ {UINT64_2PART_C(0xda01ee64, 1a708dea), 295, 108},
+ {UINT64_2PART_C(0xa26da399, 9aef774a), 322, 116},
+ {UINT64_2PART_C(0xf209787b, b47d6b85), 348, 124},
+ {UINT64_2PART_C(0xb454e4a1, 79dd1877), 375, 132},
+ {UINT64_2PART_C(0x865b8692, 5b9bc5c2), 402, 140},
+ {UINT64_2PART_C(0xc83553c5, c8965d3d), 428, 148},
+ {UINT64_2PART_C(0x952ab45c, fa97a0b3), 455, 156},
+ {UINT64_2PART_C(0xde469fbd, 99a05fe3), 481, 164},
+ {UINT64_2PART_C(0xa59bc234, db398c25), 508, 172},
+ {UINT64_2PART_C(0xf6c69a72, a3989f5c), 534, 180},
+ {UINT64_2PART_C(0xb7dcbf53, 54e9bece), 561, 188},
+ {UINT64_2PART_C(0x88fcf317, f22241e2), 588, 196},
+ {UINT64_2PART_C(0xcc20ce9b, d35c78a5), 614, 204},
+ {UINT64_2PART_C(0x98165af3, 7b2153df), 641, 212},
+ {UINT64_2PART_C(0xe2a0b5dc, 971f303a), 667, 220},
+ {UINT64_2PART_C(0xa8d9d153, 5ce3b396), 694, 228},
+ {UINT64_2PART_C(0xfb9b7cd9, a4a7443c), 720, 236},
+ {UINT64_2PART_C(0xbb764c4c, a7a44410), 747, 244},
+ {UINT64_2PART_C(0x8bab8eef, b6409c1a), 774, 252},
+ {UINT64_2PART_C(0xd01fef10, a657842c), 800, 260},
+ {UINT64_2PART_C(0x9b10a4e5, e9913129), 827, 268},
+ {UINT64_2PART_C(0xe7109bfb, a19c0c9d), 853, 276},
+ {UINT64_2PART_C(0xac2820d9, 623bf429), 880, 284},
+ {UINT64_2PART_C(0x80444b5e, 7aa7cf85), 907, 292},
+ {UINT64_2PART_C(0xbf21e440, 03acdd2d), 933, 300},
+ {UINT64_2PART_C(0x8e679c2f, 5e44ff8f), 960, 308},
+ {UINT64_2PART_C(0xd433179d, 9c8cb841), 986, 316},
+ {UINT64_2PART_C(0x9e19db92, b4e31ba9), 1013, 324},
+ {UINT64_2PART_C(0xeb96bf6e, badf77d9), 1039, 332},
+ {UINT64_2PART_C(0xaf87023b, 9bf0ee6b), 1066, 340},
+};
+
+static const int kCachedPowersOffset = 348; // -1 * the first decimal_exponent.
+static const double kD_1_LOG2_10 = 0.30102999566398114; // 1 / lg(10)
+// Difference between the decimal exponents in the table above.
+const int PowersOfTenCache::kDecimalExponentDistance = 8;
+const int PowersOfTenCache::kMinDecimalExponent = -348;
+const int PowersOfTenCache::kMaxDecimalExponent = 340;
+
+void PowersOfTenCache::GetCachedPowerForBinaryExponentRange(
+ int min_exponent,
+ int max_exponent,
+ DiyFp* power,
+ int* decimal_exponent) {
+ int kQ = DiyFp::kSignificandSize;
+ double k = ceil((min_exponent + kQ - 1) * kD_1_LOG2_10);
+ int foo = kCachedPowersOffset;
+ int index =
+ (foo + static_cast<int>(k) - 1) / kDecimalExponentDistance + 1;
+ ASSERT(0 <= index && index < static_cast<int>(ARRAY_SIZE(kCachedPowers)));
+ CachedPower cached_power = kCachedPowers[index];
+ ASSERT(min_exponent <= cached_power.binary_exponent);
+ (void) max_exponent; // Mark variable as used.
+ ASSERT(cached_power.binary_exponent <= max_exponent);
+ *decimal_exponent = cached_power.decimal_exponent;
+ *power = DiyFp(cached_power.significand, cached_power.binary_exponent);
+}
+
+
+void PowersOfTenCache::GetCachedPowerForDecimalExponent(int requested_exponent,
+ DiyFp* power,
+ int* found_exponent) {
+ ASSERT(kMinDecimalExponent <= requested_exponent);
+ ASSERT(requested_exponent < kMaxDecimalExponent + kDecimalExponentDistance);
+ int index =
+ (requested_exponent + kCachedPowersOffset) / kDecimalExponentDistance;
+ CachedPower cached_power = kCachedPowers[index];
+ *power = DiyFp(cached_power.significand, cached_power.binary_exponent);
+ *found_exponent = cached_power.decimal_exponent;
+ ASSERT(*found_exponent <= requested_exponent);
+ ASSERT(requested_exponent < *found_exponent + kDecimalExponentDistance);
+}
+
+} // namespace double_conversion
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/cached-powers.h b/src/arrow/cpp/src/arrow/vendored/double-conversion/cached-powers.h
new file mode 100644
index 000000000..61a50614c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/cached-powers.h
@@ -0,0 +1,64 @@
+// Copyright 2010 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef DOUBLE_CONVERSION_CACHED_POWERS_H_
+#define DOUBLE_CONVERSION_CACHED_POWERS_H_
+
+#include "diy-fp.h"
+
+namespace double_conversion {
+
+class PowersOfTenCache {
+ public:
+
+ // Not all powers of ten are cached. The decimal exponent of two neighboring
+ // cached numbers will differ by kDecimalExponentDistance.
+ static const int kDecimalExponentDistance;
+
+ static const int kMinDecimalExponent;
+ static const int kMaxDecimalExponent;
+
+ // Returns a cached power-of-ten with a binary exponent in the range
+ // [min_exponent; max_exponent] (boundaries included).
+ static void GetCachedPowerForBinaryExponentRange(int min_exponent,
+ int max_exponent,
+ DiyFp* power,
+ int* decimal_exponent);
+
+ // Returns a cached power of ten x ~= 10^k such that
+ // k <= decimal_exponent < k + kCachedPowersDecimalDistance.
+ // The given decimal_exponent must satisfy
+ // kMinDecimalExponent <= requested_exponent, and
+ // requested_exponent < kMaxDecimalExponent + kDecimalExponentDistance.
+ static void GetCachedPowerForDecimalExponent(int requested_exponent,
+ DiyFp* power,
+ int* found_exponent);
+};
+
+} // namespace double_conversion
+
+#endif // DOUBLE_CONVERSION_CACHED_POWERS_H_
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/diy-fp.cc b/src/arrow/cpp/src/arrow/vendored/double-conversion/diy-fp.cc
new file mode 100644
index 000000000..ddd1891b1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/diy-fp.cc
@@ -0,0 +1,57 @@
+// Copyright 2010 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+#include "diy-fp.h"
+#include "utils.h"
+
+namespace double_conversion {
+
+void DiyFp::Multiply(const DiyFp& other) {
+ // Simply "emulates" a 128 bit multiplication.
+ // However: the resulting number only contains 64 bits. The least
+ // significant 64 bits are only used for rounding the most significant 64
+ // bits.
+ const uint64_t kM32 = 0xFFFFFFFFU;
+ uint64_t a = f_ >> 32;
+ uint64_t b = f_ & kM32;
+ uint64_t c = other.f_ >> 32;
+ uint64_t d = other.f_ & kM32;
+ uint64_t ac = a * c;
+ uint64_t bc = b * c;
+ uint64_t ad = a * d;
+ uint64_t bd = b * d;
+ uint64_t tmp = (bd >> 32) + (ad & kM32) + (bc & kM32);
+ // By adding 1U << 31 to tmp we round the final result.
+ // Halfway cases will be round up.
+ tmp += 1U << 31;
+ uint64_t result_f = ac + (ad >> 32) + (bc >> 32) + (tmp >> 32);
+ e_ += other.e_ + 64;
+ f_ = result_f;
+}
+
+} // namespace double_conversion
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/diy-fp.h b/src/arrow/cpp/src/arrow/vendored/double-conversion/diy-fp.h
new file mode 100644
index 000000000..2edf34674
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/diy-fp.h
@@ -0,0 +1,118 @@
+// Copyright 2010 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef DOUBLE_CONVERSION_DIY_FP_H_
+#define DOUBLE_CONVERSION_DIY_FP_H_
+
+#include "utils.h"
+
+namespace double_conversion {
+
+// This "Do It Yourself Floating Point" class implements a floating-point number
+// with a uint64 significand and an int exponent. Normalized DiyFp numbers will
+// have the most significant bit of the significand set.
+// Multiplication and Subtraction do not normalize their results.
+// DiyFp are not designed to contain special doubles (NaN and Infinity).
+class DiyFp {
+ public:
+ static const int kSignificandSize = 64;
+
+ DiyFp() : f_(0), e_(0) {}
+ DiyFp(uint64_t significand, int exponent) : f_(significand), e_(exponent) {}
+
+ // this = this - other.
+ // The exponents of both numbers must be the same and the significand of this
+ // must be bigger than the significand of other.
+ // The result will not be normalized.
+ void Subtract(const DiyFp& other) {
+ ASSERT(e_ == other.e_);
+ ASSERT(f_ >= other.f_);
+ f_ -= other.f_;
+ }
+
+ // Returns a - b.
+ // The exponents of both numbers must be the same and this must be bigger
+ // than other. The result will not be normalized.
+ static DiyFp Minus(const DiyFp& a, const DiyFp& b) {
+ DiyFp result = a;
+ result.Subtract(b);
+ return result;
+ }
+
+
+ // this = this * other.
+ void Multiply(const DiyFp& other);
+
+ // returns a * b;
+ static DiyFp Times(const DiyFp& a, const DiyFp& b) {
+ DiyFp result = a;
+ result.Multiply(b);
+ return result;
+ }
+
+ void Normalize() {
+ ASSERT(f_ != 0);
+ uint64_t significand = f_;
+ int exponent = e_;
+
+ // This method is mainly called for normalizing boundaries. In general
+ // boundaries need to be shifted by 10 bits. We thus optimize for this case.
+ const uint64_t k10MSBits = UINT64_2PART_C(0xFFC00000, 00000000);
+ while ((significand & k10MSBits) == 0) {
+ significand <<= 10;
+ exponent -= 10;
+ }
+ while ((significand & kUint64MSB) == 0) {
+ significand <<= 1;
+ exponent--;
+ }
+ f_ = significand;
+ e_ = exponent;
+ }
+
+ static DiyFp Normalize(const DiyFp& a) {
+ DiyFp result = a;
+ result.Normalize();
+ return result;
+ }
+
+ uint64_t f() const { return f_; }
+ int e() const { return e_; }
+
+ void set_f(uint64_t new_value) { f_ = new_value; }
+ void set_e(int new_value) { e_ = new_value; }
+
+ private:
+ static const uint64_t kUint64MSB = UINT64_2PART_C(0x80000000, 00000000);
+
+ uint64_t f_;
+ int e_;
+};
+
+} // namespace double_conversion
+
+#endif // DOUBLE_CONVERSION_DIY_FP_H_
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/double-conversion.cc b/src/arrow/cpp/src/arrow/vendored/double-conversion/double-conversion.cc
new file mode 100644
index 000000000..27e70b4c9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/double-conversion.cc
@@ -0,0 +1,1171 @@
+// Copyright 2010 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <climits>
+#include <locale>
+#include <cmath>
+
+#include "double-conversion.h"
+
+#include "bignum-dtoa.h"
+#include "fast-dtoa.h"
+#include "fixed-dtoa.h"
+#include "ieee.h"
+#include "strtod.h"
+#include "utils.h"
+
+#if defined(_MSC_VER)
+#pragma warning(disable : 4244)
+#endif
+
+namespace double_conversion {
+
+const DoubleToStringConverter& DoubleToStringConverter::EcmaScriptConverter() {
+ int flags = UNIQUE_ZERO | EMIT_POSITIVE_EXPONENT_SIGN;
+ static DoubleToStringConverter converter(flags,
+ "Infinity",
+ "NaN",
+ 'e',
+ -6, 21,
+ 6, 0);
+ return converter;
+}
+
+
+bool DoubleToStringConverter::HandleSpecialValues(
+ double value,
+ StringBuilder* result_builder) const {
+ Double double_inspect(value);
+ if (double_inspect.IsInfinite()) {
+ if (infinity_symbol_ == NULL) return false;
+ if (value < 0) {
+ result_builder->AddCharacter('-');
+ }
+ result_builder->AddString(infinity_symbol_);
+ return true;
+ }
+ if (double_inspect.IsNan()) {
+ if (nan_symbol_ == NULL) return false;
+ result_builder->AddString(nan_symbol_);
+ return true;
+ }
+ return false;
+}
+
+
+void DoubleToStringConverter::CreateExponentialRepresentation(
+ const char* decimal_digits,
+ int length,
+ int exponent,
+ StringBuilder* result_builder) const {
+ ASSERT(length != 0);
+ result_builder->AddCharacter(decimal_digits[0]);
+
+ /* If the mantissa of the scientific notation representation is an integer number,
+ * the EMIT_TRAILING_DECIMAL_POINT flag will add a '.' character at the end of the
+ * representation:
+ * - With EMIT_TRAILING_DECIMAL_POINT enabled -> 0.0009 => 9.E-4
+ * - With EMIT_TRAILING_DECIMAL_POINT disabled -> 0.0009 => 9E-4
+ *
+ * If the mantissa is an integer and the EMIT_TRAILING_ZERO_AFTER_POINT flag is enabled
+ * it will add a '0' character at the end of the mantissa representation. Note that that
+ * flag depends on EMIT_TRAILING_DECIMAL_POINT flag be enabled.*/
+ if(length == 1){
+ if ((flags_ & EMIT_TRAILING_DECIMAL_POINT) != 0) {
+ result_builder->AddCharacter('.');
+
+ if ((flags_ & EMIT_TRAILING_ZERO_AFTER_POINT) != 0) {
+ result_builder->AddCharacter('0');
+ }
+ }
+ } else {
+ result_builder->AddCharacter('.');
+ result_builder->AddSubstring(&decimal_digits[1], length-1);
+ }
+ result_builder->AddCharacter(exponent_character_);
+ if (exponent < 0) {
+ result_builder->AddCharacter('-');
+ exponent = -exponent;
+ } else {
+ if ((flags_ & EMIT_POSITIVE_EXPONENT_SIGN) != 0) {
+ result_builder->AddCharacter('+');
+ }
+ }
+ if (exponent == 0) {
+ result_builder->AddCharacter('0');
+ return;
+ }
+ ASSERT(exponent < 1e4);
+ const int kMaxExponentLength = 5;
+ char buffer[kMaxExponentLength + 1];
+ buffer[kMaxExponentLength] = '\0';
+ int first_char_pos = kMaxExponentLength;
+ while (exponent > 0) {
+ buffer[--first_char_pos] = '0' + (exponent % 10);
+ exponent /= 10;
+ }
+ result_builder->AddSubstring(&buffer[first_char_pos],
+ kMaxExponentLength - first_char_pos);
+}
+
+
+void DoubleToStringConverter::CreateDecimalRepresentation(
+ const char* decimal_digits,
+ int length,
+ int decimal_point,
+ int digits_after_point,
+ StringBuilder* result_builder) const {
+ // Create a representation that is padded with zeros if needed.
+ if (decimal_point <= 0) {
+ // "0.00000decimal_rep" or "0.000decimal_rep00".
+ result_builder->AddCharacter('0');
+ if (digits_after_point > 0) {
+ result_builder->AddCharacter('.');
+ result_builder->AddPadding('0', -decimal_point);
+ ASSERT(length <= digits_after_point - (-decimal_point));
+ result_builder->AddSubstring(decimal_digits, length);
+ int remaining_digits = digits_after_point - (-decimal_point) - length;
+ result_builder->AddPadding('0', remaining_digits);
+ }
+ } else if (decimal_point >= length) {
+ // "decimal_rep0000.00000" or "decimal_rep.0000".
+ result_builder->AddSubstring(decimal_digits, length);
+ result_builder->AddPadding('0', decimal_point - length);
+ if (digits_after_point > 0) {
+ result_builder->AddCharacter('.');
+ result_builder->AddPadding('0', digits_after_point);
+ }
+ } else {
+ // "decima.l_rep000".
+ ASSERT(digits_after_point > 0);
+ result_builder->AddSubstring(decimal_digits, decimal_point);
+ result_builder->AddCharacter('.');
+ ASSERT(length - decimal_point <= digits_after_point);
+ result_builder->AddSubstring(&decimal_digits[decimal_point],
+ length - decimal_point);
+ int remaining_digits = digits_after_point - (length - decimal_point);
+ result_builder->AddPadding('0', remaining_digits);
+ }
+ if (digits_after_point == 0) {
+ if ((flags_ & EMIT_TRAILING_DECIMAL_POINT) != 0) {
+ result_builder->AddCharacter('.');
+ }
+ if ((flags_ & EMIT_TRAILING_ZERO_AFTER_POINT) != 0) {
+ result_builder->AddCharacter('0');
+ }
+ }
+}
+
+
+bool DoubleToStringConverter::ToShortestIeeeNumber(
+ double value,
+ StringBuilder* result_builder,
+ DoubleToStringConverter::DtoaMode mode) const {
+ ASSERT(mode == SHORTEST || mode == SHORTEST_SINGLE);
+ if (Double(value).IsSpecial()) {
+ return HandleSpecialValues(value, result_builder);
+ }
+
+ int decimal_point;
+ bool sign;
+ const int kDecimalRepCapacity = kBase10MaximalLength + 1;
+ char decimal_rep[kDecimalRepCapacity];
+ int decimal_rep_length;
+
+ DoubleToAscii(value, mode, 0, decimal_rep, kDecimalRepCapacity,
+ &sign, &decimal_rep_length, &decimal_point);
+
+ bool unique_zero = (flags_ & UNIQUE_ZERO) != 0;
+ if (sign && (value != 0.0 || !unique_zero)) {
+ result_builder->AddCharacter('-');
+ }
+
+ int exponent = decimal_point - 1;
+ if ((decimal_in_shortest_low_ <= exponent) &&
+ (exponent < decimal_in_shortest_high_)) {
+ CreateDecimalRepresentation(decimal_rep, decimal_rep_length,
+ decimal_point,
+ Max(0, decimal_rep_length - decimal_point),
+ result_builder);
+ } else {
+ CreateExponentialRepresentation(decimal_rep, decimal_rep_length, exponent,
+ result_builder);
+ }
+ return true;
+}
+
+
+bool DoubleToStringConverter::ToFixed(double value,
+ int requested_digits,
+ StringBuilder* result_builder) const {
+ ASSERT(kMaxFixedDigitsBeforePoint == 60);
+ const double kFirstNonFixed = 1e60;
+
+ if (Double(value).IsSpecial()) {
+ return HandleSpecialValues(value, result_builder);
+ }
+
+ if (requested_digits > kMaxFixedDigitsAfterPoint) return false;
+ if (value >= kFirstNonFixed || value <= -kFirstNonFixed) return false;
+
+ // Find a sufficiently precise decimal representation of n.
+ int decimal_point;
+ bool sign;
+ // Add space for the '\0' byte.
+ const int kDecimalRepCapacity =
+ kMaxFixedDigitsBeforePoint + kMaxFixedDigitsAfterPoint + 1;
+ char decimal_rep[kDecimalRepCapacity];
+ int decimal_rep_length;
+ DoubleToAscii(value, FIXED, requested_digits,
+ decimal_rep, kDecimalRepCapacity,
+ &sign, &decimal_rep_length, &decimal_point);
+
+ bool unique_zero = ((flags_ & UNIQUE_ZERO) != 0);
+ if (sign && (value != 0.0 || !unique_zero)) {
+ result_builder->AddCharacter('-');
+ }
+
+ CreateDecimalRepresentation(decimal_rep, decimal_rep_length, decimal_point,
+ requested_digits, result_builder);
+ return true;
+}
+
+
+bool DoubleToStringConverter::ToExponential(
+ double value,
+ int requested_digits,
+ StringBuilder* result_builder) const {
+ if (Double(value).IsSpecial()) {
+ return HandleSpecialValues(value, result_builder);
+ }
+
+ if (requested_digits < -1) return false;
+ if (requested_digits > kMaxExponentialDigits) return false;
+
+ int decimal_point;
+ bool sign;
+ // Add space for digit before the decimal point and the '\0' character.
+ const int kDecimalRepCapacity = kMaxExponentialDigits + 2;
+ ASSERT(kDecimalRepCapacity > kBase10MaximalLength);
+ char decimal_rep[kDecimalRepCapacity];
+#ifndef NDEBUG
+ // Problem: there is an assert in StringBuilder::AddSubstring() that
+ // will pass this buffer to strlen(), and this buffer is not generally
+ // null-terminated.
+ memset(decimal_rep, 0, sizeof(decimal_rep));
+#endif
+ int decimal_rep_length;
+
+ if (requested_digits == -1) {
+ DoubleToAscii(value, SHORTEST, 0,
+ decimal_rep, kDecimalRepCapacity,
+ &sign, &decimal_rep_length, &decimal_point);
+ } else {
+ DoubleToAscii(value, PRECISION, requested_digits + 1,
+ decimal_rep, kDecimalRepCapacity,
+ &sign, &decimal_rep_length, &decimal_point);
+ ASSERT(decimal_rep_length <= requested_digits + 1);
+
+ for (int i = decimal_rep_length; i < requested_digits + 1; ++i) {
+ decimal_rep[i] = '0';
+ }
+ decimal_rep_length = requested_digits + 1;
+ }
+
+ bool unique_zero = ((flags_ & UNIQUE_ZERO) != 0);
+ if (sign && (value != 0.0 || !unique_zero)) {
+ result_builder->AddCharacter('-');
+ }
+
+ int exponent = decimal_point - 1;
+ CreateExponentialRepresentation(decimal_rep,
+ decimal_rep_length,
+ exponent,
+ result_builder);
+ return true;
+}
+
+
+bool DoubleToStringConverter::ToPrecision(double value,
+ int precision,
+ StringBuilder* result_builder) const {
+ if (Double(value).IsSpecial()) {
+ return HandleSpecialValues(value, result_builder);
+ }
+
+ if (precision < kMinPrecisionDigits || precision > kMaxPrecisionDigits) {
+ return false;
+ }
+
+ // Find a sufficiently precise decimal representation of n.
+ int decimal_point;
+ bool sign;
+ // Add one for the terminating null character.
+ const int kDecimalRepCapacity = kMaxPrecisionDigits + 1;
+ char decimal_rep[kDecimalRepCapacity];
+ int decimal_rep_length;
+
+ DoubleToAscii(value, PRECISION, precision,
+ decimal_rep, kDecimalRepCapacity,
+ &sign, &decimal_rep_length, &decimal_point);
+ ASSERT(decimal_rep_length <= precision);
+
+ bool unique_zero = ((flags_ & UNIQUE_ZERO) != 0);
+ if (sign && (value != 0.0 || !unique_zero)) {
+ result_builder->AddCharacter('-');
+ }
+
+ // The exponent if we print the number as x.xxeyyy. That is with the
+ // decimal point after the first digit.
+ int exponent = decimal_point - 1;
+
+ int extra_zero = ((flags_ & EMIT_TRAILING_ZERO_AFTER_POINT) != 0) ? 1 : 0;
+ if ((-decimal_point + 1 > max_leading_padding_zeroes_in_precision_mode_) ||
+ (decimal_point - precision + extra_zero >
+ max_trailing_padding_zeroes_in_precision_mode_)) {
+ // Fill buffer to contain 'precision' digits.
+ // Usually the buffer is already at the correct length, but 'DoubleToAscii'
+ // is allowed to return less characters.
+ for (int i = decimal_rep_length; i < precision; ++i) {
+ decimal_rep[i] = '0';
+ }
+
+ CreateExponentialRepresentation(decimal_rep,
+ precision,
+ exponent,
+ result_builder);
+ } else {
+ CreateDecimalRepresentation(decimal_rep, decimal_rep_length, decimal_point,
+ Max(0, precision - decimal_point),
+ result_builder);
+ }
+ return true;
+}
+
+
+static BignumDtoaMode DtoaToBignumDtoaMode(
+ DoubleToStringConverter::DtoaMode dtoa_mode) {
+ switch (dtoa_mode) {
+ case DoubleToStringConverter::SHORTEST: return BIGNUM_DTOA_SHORTEST;
+ case DoubleToStringConverter::SHORTEST_SINGLE:
+ return BIGNUM_DTOA_SHORTEST_SINGLE;
+ case DoubleToStringConverter::FIXED: return BIGNUM_DTOA_FIXED;
+ case DoubleToStringConverter::PRECISION: return BIGNUM_DTOA_PRECISION;
+ default:
+ UNREACHABLE();
+ }
+}
+
+
+void DoubleToStringConverter::DoubleToAscii(double v,
+ DtoaMode mode,
+ int requested_digits,
+ char* buffer,
+ int buffer_length,
+ bool* sign,
+ int* length,
+ int* point) {
+ Vector<char> vector(buffer, buffer_length);
+ ASSERT(!Double(v).IsSpecial());
+ ASSERT(mode == SHORTEST || mode == SHORTEST_SINGLE || requested_digits >= 0);
+
+ if (Double(v).Sign() < 0) {
+ *sign = true;
+ v = -v;
+ } else {
+ *sign = false;
+ }
+
+ if (mode == PRECISION && requested_digits == 0) {
+ vector[0] = '\0';
+ *length = 0;
+ return;
+ }
+
+ if (v == 0) {
+ vector[0] = '0';
+ vector[1] = '\0';
+ *length = 1;
+ *point = 1;
+ return;
+ }
+
+ bool fast_worked;
+ switch (mode) {
+ case SHORTEST:
+ fast_worked = FastDtoa(v, FAST_DTOA_SHORTEST, 0, vector, length, point);
+ break;
+ case SHORTEST_SINGLE:
+ fast_worked = FastDtoa(v, FAST_DTOA_SHORTEST_SINGLE, 0,
+ vector, length, point);
+ break;
+ case FIXED:
+ fast_worked = FastFixedDtoa(v, requested_digits, vector, length, point);
+ break;
+ case PRECISION:
+ fast_worked = FastDtoa(v, FAST_DTOA_PRECISION, requested_digits,
+ vector, length, point);
+ break;
+ default:
+ fast_worked = false;
+ UNREACHABLE();
+ }
+ if (fast_worked) return;
+
+ // If the fast dtoa didn't succeed use the slower bignum version.
+ BignumDtoaMode bignum_mode = DtoaToBignumDtoaMode(mode);
+ BignumDtoa(v, bignum_mode, requested_digits, vector, length, point);
+ vector[*length] = '\0';
+}
+
+
+namespace {
+
+inline char ToLower(char ch) {
+ static const std::ctype<char>& cType =
+ std::use_facet<std::ctype<char> >(std::locale::classic());
+ return cType.tolower(ch);
+}
+
+inline char Pass(char ch) {
+ return ch;
+}
+
+template <class Iterator, class Converter>
+static inline bool ConsumeSubStringImpl(Iterator* current,
+ Iterator end,
+ const char* substring,
+ Converter converter) {
+ ASSERT(converter(**current) == *substring);
+ for (substring++; *substring != '\0'; substring++) {
+ ++*current;
+ if (*current == end || converter(**current) != *substring) {
+ return false;
+ }
+ }
+ ++*current;
+ return true;
+}
+
+// Consumes the given substring from the iterator.
+// Returns false, if the substring does not match.
+template <class Iterator>
+static bool ConsumeSubString(Iterator* current,
+ Iterator end,
+ const char* substring,
+ bool allow_case_insensibility) {
+ if (allow_case_insensibility) {
+ return ConsumeSubStringImpl(current, end, substring, ToLower);
+ } else {
+ return ConsumeSubStringImpl(current, end, substring, Pass);
+ }
+}
+
+// Consumes first character of the str is equal to ch
+inline bool ConsumeFirstCharacter(char ch,
+ const char* str,
+ bool case_insensibility) {
+ return case_insensibility ? ToLower(ch) == str[0] : ch == str[0];
+}
+} // namespace
+
+// Maximum number of significant digits in decimal representation.
+// The longest possible double in decimal representation is
+// (2^53 - 1) * 2 ^ -1074 that is (2 ^ 53 - 1) * 5 ^ 1074 / 10 ^ 1074
+// (768 digits). If we parse a number whose first digits are equal to a
+// mean of 2 adjacent doubles (that could have up to 769 digits) the result
+// must be rounded to the bigger one unless the tail consists of zeros, so
+// we don't need to preserve all the digits.
+const int kMaxSignificantDigits = 772;
+
+
+static const char kWhitespaceTable7[] = { 32, 13, 10, 9, 11, 12 };
+static const int kWhitespaceTable7Length = ARRAY_SIZE(kWhitespaceTable7);
+
+
+static const uc16 kWhitespaceTable16[] = {
+ 160, 8232, 8233, 5760, 6158, 8192, 8193, 8194, 8195,
+ 8196, 8197, 8198, 8199, 8200, 8201, 8202, 8239, 8287, 12288, 65279
+};
+static const int kWhitespaceTable16Length = ARRAY_SIZE(kWhitespaceTable16);
+
+
+static bool isWhitespace(int x) {
+ if (x < 128) {
+ for (int i = 0; i < kWhitespaceTable7Length; i++) {
+ if (kWhitespaceTable7[i] == x) return true;
+ }
+ } else {
+ for (int i = 0; i < kWhitespaceTable16Length; i++) {
+ if (kWhitespaceTable16[i] == x) return true;
+ }
+ }
+ return false;
+}
+
+
+// Returns true if a nonspace found and false if the end has reached.
+template <class Iterator>
+static inline bool AdvanceToNonspace(Iterator* current, Iterator end) {
+ while (*current != end) {
+ if (!isWhitespace(**current)) return true;
+ ++*current;
+ }
+ return false;
+}
+
+
+static bool isDigit(int x, int radix) {
+ return (x >= '0' && x <= '9' && x < '0' + radix)
+ || (radix > 10 && x >= 'a' && x < 'a' + radix - 10)
+ || (radix > 10 && x >= 'A' && x < 'A' + radix - 10);
+}
+
+
+static double SignedZero(bool sign) {
+ return sign ? -0.0 : 0.0;
+}
+
+
+// Returns true if 'c' is a decimal digit that is valid for the given radix.
+//
+// The function is small and could be inlined, but VS2012 emitted a warning
+// because it constant-propagated the radix and concluded that the last
+// condition was always true. By moving it into a separate function the
+// compiler wouldn't warn anymore.
+#ifdef _MSC_VER
+#pragma optimize("",off)
+static bool IsDecimalDigitForRadix(int c, int radix) {
+ return '0' <= c && c <= '9' && (c - '0') < radix;
+}
+#pragma optimize("",on)
+#else
+static bool inline IsDecimalDigitForRadix(int c, int radix) {
+ return '0' <= c && c <= '9' && (c - '0') < radix;
+}
+#endif
+// Returns true if 'c' is a character digit that is valid for the given radix.
+// The 'a_character' should be 'a' or 'A'.
+//
+// The function is small and could be inlined, but VS2012 emitted a warning
+// because it constant-propagated the radix and concluded that the first
+// condition was always false. By moving it into a separate function the
+// compiler wouldn't warn anymore.
+static bool IsCharacterDigitForRadix(int c, int radix, char a_character) {
+ return radix > 10 && c >= a_character && c < a_character + radix - 10;
+}
+
+// Returns true, when the iterator is equal to end.
+template<class Iterator>
+static bool Advance (Iterator* it, uc16 separator, int base, Iterator& end) {
+ if (separator == StringToDoubleConverter::kNoSeparator) {
+ ++(*it);
+ return *it == end;
+ }
+ if (!isDigit(**it, base)) {
+ ++(*it);
+ return *it == end;
+ }
+ ++(*it);
+ if (*it == end) return true;
+ if (*it + 1 == end) return false;
+ if (**it == separator && isDigit(*(*it + 1), base)) {
+ ++(*it);
+ }
+ return *it == end;
+}
+
+// Checks whether the string in the range start-end is a hex-float string.
+// This function assumes that the leading '0x'/'0X' is already consumed.
+//
+// Hex float strings are of one of the following forms:
+// - hex_digits+ 'p' ('+'|'-')? exponent_digits+
+// - hex_digits* '.' hex_digits+ 'p' ('+'|'-')? exponent_digits+
+// - hex_digits+ '.' 'p' ('+'|'-')? exponent_digits+
+template<class Iterator>
+static bool IsHexFloatString(Iterator start,
+ Iterator end,
+ uc16 separator,
+ bool allow_trailing_junk) {
+ ASSERT(start != end);
+
+ Iterator current = start;
+
+ bool saw_digit = false;
+ while (isDigit(*current, 16)) {
+ saw_digit = true;
+ if (Advance(&current, separator, 16, end)) return false;
+ }
+ if (*current == '.') {
+ if (Advance(&current, separator, 16, end)) return false;
+ while (isDigit(*current, 16)) {
+ saw_digit = true;
+ if (Advance(&current, separator, 16, end)) return false;
+ }
+ }
+ if (!saw_digit) return false;
+ if (*current != 'p' && *current != 'P') return false;
+ if (Advance(&current, separator, 16, end)) return false;
+ if (*current == '+' || *current == '-') {
+ if (Advance(&current, separator, 16, end)) return false;
+ }
+ if (!isDigit(*current, 10)) return false;
+ if (Advance(&current, separator, 16, end)) return true;
+ while (isDigit(*current, 10)) {
+ if (Advance(&current, separator, 16, end)) return true;
+ }
+ return allow_trailing_junk || !AdvanceToNonspace(&current, end);
+}
+
+
+// Parsing integers with radix 2, 4, 8, 16, 32. Assumes current != end.
+//
+// If parse_as_hex_float is true, then the string must be a valid
+// hex-float.
+template <int radix_log_2, class Iterator>
+static double RadixStringToIeee(Iterator* current,
+ Iterator end,
+ bool sign,
+ uc16 separator,
+ bool parse_as_hex_float,
+ bool allow_trailing_junk,
+ double junk_string_value,
+ bool read_as_double,
+ bool* result_is_junk) {
+ ASSERT(*current != end);
+ ASSERT(!parse_as_hex_float ||
+ IsHexFloatString(*current, end, separator, allow_trailing_junk));
+
+ const int kDoubleSize = Double::kSignificandSize;
+ const int kSingleSize = Single::kSignificandSize;
+ const int kSignificandSize = read_as_double? kDoubleSize: kSingleSize;
+
+ *result_is_junk = true;
+
+ int64_t number = 0;
+ int exponent = 0;
+ const int radix = (1 << radix_log_2);
+ // Whether we have encountered a '.' and are parsing the decimal digits.
+ // Only relevant if parse_as_hex_float is true.
+ bool post_decimal = false;
+
+ // Skip leading 0s.
+ while (**current == '0') {
+ if (Advance(current, separator, radix, end)) {
+ *result_is_junk = false;
+ return SignedZero(sign);
+ }
+ }
+
+ while (true) {
+ int digit;
+ if (IsDecimalDigitForRadix(**current, radix)) {
+ digit = static_cast<char>(**current) - '0';
+ if (post_decimal) exponent -= radix_log_2;
+ } else if (IsCharacterDigitForRadix(**current, radix, 'a')) {
+ digit = static_cast<char>(**current) - 'a' + 10;
+ if (post_decimal) exponent -= radix_log_2;
+ } else if (IsCharacterDigitForRadix(**current, radix, 'A')) {
+ digit = static_cast<char>(**current) - 'A' + 10;
+ if (post_decimal) exponent -= radix_log_2;
+ } else if (parse_as_hex_float && **current == '.') {
+ post_decimal = true;
+ Advance(current, separator, radix, end);
+ ASSERT(*current != end);
+ continue;
+ } else if (parse_as_hex_float && (**current == 'p' || **current == 'P')) {
+ break;
+ } else {
+ if (allow_trailing_junk || !AdvanceToNonspace(current, end)) {
+ break;
+ } else {
+ return junk_string_value;
+ }
+ }
+
+ number = number * radix + digit;
+ int overflow = static_cast<int>(number >> kSignificandSize);
+ if (overflow != 0) {
+ // Overflow occurred. Need to determine which direction to round the
+ // result.
+ int overflow_bits_count = 1;
+ while (overflow > 1) {
+ overflow_bits_count++;
+ overflow >>= 1;
+ }
+
+ int dropped_bits_mask = ((1 << overflow_bits_count) - 1);
+ int dropped_bits = static_cast<int>(number) & dropped_bits_mask;
+ number >>= overflow_bits_count;
+ exponent += overflow_bits_count;
+
+ bool zero_tail = true;
+ for (;;) {
+ if (Advance(current, separator, radix, end)) break;
+ if (parse_as_hex_float && **current == '.') {
+ // Just run over the '.'. We are just trying to see whether there is
+ // a non-zero digit somewhere.
+ Advance(current, separator, radix, end);
+ ASSERT(*current != end);
+ post_decimal = true;
+ }
+ if (!isDigit(**current, radix)) break;
+ zero_tail = zero_tail && **current == '0';
+ if (!post_decimal) exponent += radix_log_2;
+ }
+
+ if (!parse_as_hex_float &&
+ !allow_trailing_junk &&
+ AdvanceToNonspace(current, end)) {
+ return junk_string_value;
+ }
+
+ int middle_value = (1 << (overflow_bits_count - 1));
+ if (dropped_bits > middle_value) {
+ number++; // Rounding up.
+ } else if (dropped_bits == middle_value) {
+ // Rounding to even to consistency with decimals: half-way case rounds
+ // up if significant part is odd and down otherwise.
+ if ((number & 1) != 0 || !zero_tail) {
+ number++; // Rounding up.
+ }
+ }
+
+ // Rounding up may cause overflow.
+ if ((number & ((int64_t)1 << kSignificandSize)) != 0) {
+ exponent++;
+ number >>= 1;
+ }
+ break;
+ }
+ if (Advance(current, separator, radix, end)) break;
+ }
+
+ ASSERT(number < ((int64_t)1 << kSignificandSize));
+ ASSERT(static_cast<int64_t>(static_cast<double>(number)) == number);
+
+ *result_is_junk = false;
+
+ if (parse_as_hex_float) {
+ ASSERT(**current == 'p' || **current == 'P');
+ Advance(current, separator, radix, end);
+ ASSERT(*current != end);
+ bool is_negative = false;
+ if (**current == '+') {
+ Advance(current, separator, radix, end);
+ ASSERT(*current != end);
+ } else if (**current == '-') {
+ is_negative = true;
+ Advance(current, separator, radix, end);
+ ASSERT(*current != end);
+ }
+ int written_exponent = 0;
+ while (IsDecimalDigitForRadix(**current, 10)) {
+ // No need to read exponents if they are too big. That could potentially overflow
+ // the `written_exponent` variable.
+ if (abs(written_exponent) <= 100 * Double::kMaxExponent) {
+ written_exponent = 10 * written_exponent + **current - '0';
+ }
+ if (Advance(current, separator, radix, end)) break;
+ }
+ if (is_negative) written_exponent = -written_exponent;
+ exponent += written_exponent;
+ }
+
+ if (exponent == 0 || number == 0) {
+ if (sign) {
+ if (number == 0) return -0.0;
+ number = -number;
+ }
+ return static_cast<double>(number);
+ }
+
+ ASSERT(number != 0);
+ double result = Double(DiyFp(number, exponent)).value();
+ return sign ? -result : result;
+}
+
+template <class Iterator>
+double StringToDoubleConverter::StringToIeee(
+ Iterator input,
+ int length,
+ bool read_as_double,
+ int* processed_characters_count) const {
+ Iterator current = input;
+ Iterator end = input + length;
+
+ *processed_characters_count = 0;
+
+ const bool allow_trailing_junk = (flags_ & ALLOW_TRAILING_JUNK) != 0;
+ const bool allow_leading_spaces = (flags_ & ALLOW_LEADING_SPACES) != 0;
+ const bool allow_trailing_spaces = (flags_ & ALLOW_TRAILING_SPACES) != 0;
+ const bool allow_spaces_after_sign = (flags_ & ALLOW_SPACES_AFTER_SIGN) != 0;
+ const bool allow_case_insensibility = (flags_ & ALLOW_CASE_INSENSIBILITY) != 0;
+
+ // To make sure that iterator dereferencing is valid the following
+ // convention is used:
+ // 1. Each '++current' statement is followed by check for equality to 'end'.
+ // 2. If AdvanceToNonspace returned false then current == end.
+ // 3. If 'current' becomes equal to 'end' the function returns or goes to
+ // 'parsing_done'.
+ // 4. 'current' is not dereferenced after the 'parsing_done' label.
+ // 5. Code before 'parsing_done' may rely on 'current != end'.
+ if (current == end) return empty_string_value_;
+
+ if (allow_leading_spaces || allow_trailing_spaces) {
+ if (!AdvanceToNonspace(&current, end)) {
+ *processed_characters_count = static_cast<int>(current - input);
+ return empty_string_value_;
+ }
+ if (!allow_leading_spaces && (input != current)) {
+ // No leading spaces allowed, but AdvanceToNonspace moved forward.
+ return junk_string_value_;
+ }
+ }
+
+ // The longest form of simplified number is: "-<significant digits>.1eXXX\0".
+ const int kBufferSize = kMaxSignificantDigits + 10;
+ char buffer[kBufferSize]; // NOLINT: size is known at compile time.
+ int buffer_pos = 0;
+
+ // Exponent will be adjusted if insignificant digits of the integer part
+ // or insignificant leading zeros of the fractional part are dropped.
+ int exponent = 0;
+ int significant_digits = 0;
+ int insignificant_digits = 0;
+ bool nonzero_digit_dropped = false;
+
+ bool sign = false;
+
+ if (*current == '+' || *current == '-') {
+ sign = (*current == '-');
+ ++current;
+ Iterator next_non_space = current;
+ // Skip following spaces (if allowed).
+ if (!AdvanceToNonspace(&next_non_space, end)) return junk_string_value_;
+ if (!allow_spaces_after_sign && (current != next_non_space)) {
+ return junk_string_value_;
+ }
+ current = next_non_space;
+ }
+
+ if (infinity_symbol_ != NULL) {
+ if (ConsumeFirstCharacter(*current, infinity_symbol_, allow_case_insensibility)) {
+ if (!ConsumeSubString(&current, end, infinity_symbol_, allow_case_insensibility)) {
+ return junk_string_value_;
+ }
+
+ if (!(allow_trailing_spaces || allow_trailing_junk) && (current != end)) {
+ return junk_string_value_;
+ }
+ if (!allow_trailing_junk && AdvanceToNonspace(&current, end)) {
+ return junk_string_value_;
+ }
+
+ ASSERT(buffer_pos == 0);
+ *processed_characters_count = static_cast<int>(current - input);
+ return sign ? -Double::Infinity() : Double::Infinity();
+ }
+ }
+
+ if (nan_symbol_ != NULL) {
+ if (ConsumeFirstCharacter(*current, nan_symbol_, allow_case_insensibility)) {
+ if (!ConsumeSubString(&current, end, nan_symbol_, allow_case_insensibility)) {
+ return junk_string_value_;
+ }
+
+ if (!(allow_trailing_spaces || allow_trailing_junk) && (current != end)) {
+ return junk_string_value_;
+ }
+ if (!allow_trailing_junk && AdvanceToNonspace(&current, end)) {
+ return junk_string_value_;
+ }
+
+ ASSERT(buffer_pos == 0);
+ *processed_characters_count = static_cast<int>(current - input);
+ return sign ? -Double::NaN() : Double::NaN();
+ }
+ }
+
+ bool leading_zero = false;
+ if (*current == '0') {
+ if (Advance(&current, separator_, 10, end)) {
+ *processed_characters_count = static_cast<int>(current - input);
+ return SignedZero(sign);
+ }
+
+ leading_zero = true;
+
+ // It could be hexadecimal value.
+ if (((flags_ & ALLOW_HEX) || (flags_ & ALLOW_HEX_FLOATS)) &&
+ (*current == 'x' || *current == 'X')) {
+ ++current;
+
+ if (current == end) return junk_string_value_; // "0x"
+
+ bool parse_as_hex_float = (flags_ & ALLOW_HEX_FLOATS) &&
+ IsHexFloatString(current, end, separator_, allow_trailing_junk);
+
+ if (!parse_as_hex_float && !isDigit(*current, 16)) {
+ return junk_string_value_;
+ }
+
+ bool result_is_junk;
+ double result = RadixStringToIeee<4>(&current,
+ end,
+ sign,
+ separator_,
+ parse_as_hex_float,
+ allow_trailing_junk,
+ junk_string_value_,
+ read_as_double,
+ &result_is_junk);
+ if (!result_is_junk) {
+ if (allow_trailing_spaces) AdvanceToNonspace(&current, end);
+ *processed_characters_count = static_cast<int>(current - input);
+ }
+ return result;
+ }
+
+ // Ignore leading zeros in the integer part.
+ while (*current == '0') {
+ if (Advance(&current, separator_, 10, end)) {
+ *processed_characters_count = static_cast<int>(current - input);
+ return SignedZero(sign);
+ }
+ }
+ }
+
+ bool octal = leading_zero && (flags_ & ALLOW_OCTALS) != 0;
+
+ // Copy significant digits of the integer part (if any) to the buffer.
+ while (*current >= '0' && *current <= '9') {
+ if (significant_digits < kMaxSignificantDigits) {
+ ASSERT(buffer_pos < kBufferSize);
+ buffer[buffer_pos++] = static_cast<char>(*current);
+ significant_digits++;
+ // Will later check if it's an octal in the buffer.
+ } else {
+ insignificant_digits++; // Move the digit into the exponential part.
+ nonzero_digit_dropped = nonzero_digit_dropped || *current != '0';
+ }
+ octal = octal && *current < '8';
+ if (Advance(&current, separator_, 10, end)) goto parsing_done;
+ }
+
+ if (significant_digits == 0) {
+ octal = false;
+ }
+
+ if (*current == '.') {
+ if (octal && !allow_trailing_junk) return junk_string_value_;
+ if (octal) goto parsing_done;
+
+ if (Advance(&current, separator_, 10, end)) {
+ if (significant_digits == 0 && !leading_zero) {
+ return junk_string_value_;
+ } else {
+ goto parsing_done;
+ }
+ }
+
+ if (significant_digits == 0) {
+ // octal = false;
+ // Integer part consists of 0 or is absent. Significant digits start after
+ // leading zeros (if any).
+ while (*current == '0') {
+ if (Advance(&current, separator_, 10, end)) {
+ *processed_characters_count = static_cast<int>(current - input);
+ return SignedZero(sign);
+ }
+ exponent--; // Move this 0 into the exponent.
+ }
+ }
+
+ // There is a fractional part.
+ // We don't emit a '.', but adjust the exponent instead.
+ while (*current >= '0' && *current <= '9') {
+ if (significant_digits < kMaxSignificantDigits) {
+ ASSERT(buffer_pos < kBufferSize);
+ buffer[buffer_pos++] = static_cast<char>(*current);
+ significant_digits++;
+ exponent--;
+ } else {
+ // Ignore insignificant digits in the fractional part.
+ nonzero_digit_dropped = nonzero_digit_dropped || *current != '0';
+ }
+ if (Advance(&current, separator_, 10, end)) goto parsing_done;
+ }
+ }
+
+ if (!leading_zero && exponent == 0 && significant_digits == 0) {
+ // If leading_zeros is true then the string contains zeros.
+ // If exponent < 0 then string was [+-]\.0*...
+ // If significant_digits != 0 the string is not equal to 0.
+ // Otherwise there are no digits in the string.
+ return junk_string_value_;
+ }
+
+ // Parse exponential part.
+ if (*current == 'e' || *current == 'E') {
+ if (octal && !allow_trailing_junk) return junk_string_value_;
+ if (octal) goto parsing_done;
+ Iterator junk_begin = current;
+ ++current;
+ if (current == end) {
+ if (allow_trailing_junk) {
+ current = junk_begin;
+ goto parsing_done;
+ } else {
+ return junk_string_value_;
+ }
+ }
+ char exponen_sign = '+';
+ if (*current == '+' || *current == '-') {
+ exponen_sign = static_cast<char>(*current);
+ ++current;
+ if (current == end) {
+ if (allow_trailing_junk) {
+ current = junk_begin;
+ goto parsing_done;
+ } else {
+ return junk_string_value_;
+ }
+ }
+ }
+
+ if (current == end || *current < '0' || *current > '9') {
+ if (allow_trailing_junk) {
+ current = junk_begin;
+ goto parsing_done;
+ } else {
+ return junk_string_value_;
+ }
+ }
+
+ const int max_exponent = INT_MAX / 2;
+ ASSERT(-max_exponent / 2 <= exponent && exponent <= max_exponent / 2);
+ int num = 0;
+ do {
+ // Check overflow.
+ int digit = *current - '0';
+ if (num >= max_exponent / 10
+ && !(num == max_exponent / 10 && digit <= max_exponent % 10)) {
+ num = max_exponent;
+ } else {
+ num = num * 10 + digit;
+ }
+ ++current;
+ } while (current != end && *current >= '0' && *current <= '9');
+
+ exponent += (exponen_sign == '-' ? -num : num);
+ }
+
+ if (!(allow_trailing_spaces || allow_trailing_junk) && (current != end)) {
+ return junk_string_value_;
+ }
+ if (!allow_trailing_junk && AdvanceToNonspace(&current, end)) {
+ return junk_string_value_;
+ }
+ if (allow_trailing_spaces) {
+ AdvanceToNonspace(&current, end);
+ }
+
+ parsing_done:
+ exponent += insignificant_digits;
+
+ if (octal) {
+ double result;
+ bool result_is_junk;
+ char* start = buffer;
+ result = RadixStringToIeee<3>(&start,
+ buffer + buffer_pos,
+ sign,
+ separator_,
+ false, // Don't parse as hex_float.
+ allow_trailing_junk,
+ junk_string_value_,
+ read_as_double,
+ &result_is_junk);
+ ASSERT(!result_is_junk);
+ *processed_characters_count = static_cast<int>(current - input);
+ return result;
+ }
+
+ if (nonzero_digit_dropped) {
+ buffer[buffer_pos++] = '1';
+ exponent--;
+ }
+
+ ASSERT(buffer_pos < kBufferSize);
+ buffer[buffer_pos] = '\0';
+
+ double converted;
+ if (read_as_double) {
+ converted = Strtod(Vector<const char>(buffer, buffer_pos), exponent);
+ } else {
+ converted = Strtof(Vector<const char>(buffer, buffer_pos), exponent);
+ }
+ *processed_characters_count = static_cast<int>(current - input);
+ return sign? -converted: converted;
+}
+
+
+double StringToDoubleConverter::StringToDouble(
+ const char* buffer,
+ int length,
+ int* processed_characters_count) const {
+ return StringToIeee(buffer, length, true, processed_characters_count);
+}
+
+
+double StringToDoubleConverter::StringToDouble(
+ const uc16* buffer,
+ int length,
+ int* processed_characters_count) const {
+ return StringToIeee(buffer, length, true, processed_characters_count);
+}
+
+
+float StringToDoubleConverter::StringToFloat(
+ const char* buffer,
+ int length,
+ int* processed_characters_count) const {
+ return static_cast<float>(StringToIeee(buffer, length, false,
+ processed_characters_count));
+}
+
+
+float StringToDoubleConverter::StringToFloat(
+ const uc16* buffer,
+ int length,
+ int* processed_characters_count) const {
+ return static_cast<float>(StringToIeee(buffer, length, false,
+ processed_characters_count));
+}
+
+} // namespace double_conversion
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/double-conversion.h b/src/arrow/cpp/src/arrow/vendored/double-conversion/double-conversion.h
new file mode 100644
index 000000000..9dc3ebd8d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/double-conversion.h
@@ -0,0 +1,587 @@
+// Copyright 2012 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef DOUBLE_CONVERSION_DOUBLE_CONVERSION_H_
+#define DOUBLE_CONVERSION_DOUBLE_CONVERSION_H_
+
+#include "utils.h"
+
+namespace double_conversion {
+
+class DoubleToStringConverter {
+ public:
+ // When calling ToFixed with a double > 10^kMaxFixedDigitsBeforePoint
+ // or a requested_digits parameter > kMaxFixedDigitsAfterPoint then the
+ // function returns false.
+ static const int kMaxFixedDigitsBeforePoint = 60;
+ static const int kMaxFixedDigitsAfterPoint = 60;
+
+ // When calling ToExponential with a requested_digits
+ // parameter > kMaxExponentialDigits then the function returns false.
+ static const int kMaxExponentialDigits = 120;
+
+ // When calling ToPrecision with a requested_digits
+ // parameter < kMinPrecisionDigits or requested_digits > kMaxPrecisionDigits
+ // then the function returns false.
+ static const int kMinPrecisionDigits = 1;
+ static const int kMaxPrecisionDigits = 120;
+
+ enum Flags {
+ NO_FLAGS = 0,
+ EMIT_POSITIVE_EXPONENT_SIGN = 1,
+ EMIT_TRAILING_DECIMAL_POINT = 2,
+ EMIT_TRAILING_ZERO_AFTER_POINT = 4,
+ UNIQUE_ZERO = 8
+ };
+
+ // Flags should be a bit-or combination of the possible Flags-enum.
+ // - NO_FLAGS: no special flags.
+ // - EMIT_POSITIVE_EXPONENT_SIGN: when the number is converted into exponent
+ // form, emits a '+' for positive exponents. Example: 1.2e+2.
+ // - EMIT_TRAILING_DECIMAL_POINT: when the input number is an integer and is
+ // converted into decimal format then a trailing decimal point is appended.
+ // Example: 2345.0 is converted to "2345.".
+ // - EMIT_TRAILING_ZERO_AFTER_POINT: in addition to a trailing decimal point
+ // emits a trailing '0'-character. This flag requires the
+ // EXMIT_TRAILING_DECIMAL_POINT flag.
+ // Example: 2345.0 is converted to "2345.0".
+ // - UNIQUE_ZERO: "-0.0" is converted to "0.0".
+ //
+ // Infinity symbol and nan_symbol provide the string representation for these
+ // special values. If the string is NULL and the special value is encountered
+ // then the conversion functions return false.
+ //
+ // The exponent_character is used in exponential representations. It is
+ // usually 'e' or 'E'.
+ //
+ // When converting to the shortest representation the converter will
+ // represent input numbers in decimal format if they are in the interval
+ // [10^decimal_in_shortest_low; 10^decimal_in_shortest_high[
+ // (lower boundary included, greater boundary excluded).
+ // Example: with decimal_in_shortest_low = -6 and
+ // decimal_in_shortest_high = 21:
+ // ToShortest(0.000001) -> "0.000001"
+ // ToShortest(0.0000001) -> "1e-7"
+ // ToShortest(111111111111111111111.0) -> "111111111111111110000"
+ // ToShortest(100000000000000000000.0) -> "100000000000000000000"
+ // ToShortest(1111111111111111111111.0) -> "1.1111111111111111e+21"
+ //
+ // When converting to precision mode the converter may add
+ // max_leading_padding_zeroes before returning the number in exponential
+ // format.
+ // Example with max_leading_padding_zeroes_in_precision_mode = 6.
+ // ToPrecision(0.0000012345, 2) -> "0.0000012"
+ // ToPrecision(0.00000012345, 2) -> "1.2e-7"
+ // Similarily the converter may add up to
+ // max_trailing_padding_zeroes_in_precision_mode in precision mode to avoid
+ // returning an exponential representation. A zero added by the
+ // EMIT_TRAILING_ZERO_AFTER_POINT flag is counted for this limit.
+ // Examples for max_trailing_padding_zeroes_in_precision_mode = 1:
+ // ToPrecision(230.0, 2) -> "230"
+ // ToPrecision(230.0, 2) -> "230." with EMIT_TRAILING_DECIMAL_POINT.
+ // ToPrecision(230.0, 2) -> "2.3e2" with EMIT_TRAILING_ZERO_AFTER_POINT.
+ //
+ // When converting numbers to scientific notation representation, if the mantissa of
+ // the representation is an integer number, the EMIT_TRAILING_DECIMAL_POINT flag will
+ // add a '.' character at the end of the representation:
+ // - With EMIT_TRAILING_DECIMAL_POINT enabled -> 0.0009 => 9.E-4
+ // - With EMIT_TRAILING_DECIMAL_POINT disabled -> 0.0009 => 9E-4
+ //
+ // If the mantissa is an integer and the EMIT_TRAILING_ZERO_AFTER_POINT flag is enabled
+ // it will add a '0' character at the end of the mantissa representation. Note that that
+ // flag depends on EMIT_TRAILING_DECIMAL_POINT flag be enabled.
+ // - With EMIT_TRAILING_ZERO_AFTER_POINT enabled -> 0.0009 => 9.0E-4
+ DoubleToStringConverter(int flags,
+ const char* infinity_symbol,
+ const char* nan_symbol,
+ char exponent_character,
+ int decimal_in_shortest_low,
+ int decimal_in_shortest_high,
+ int max_leading_padding_zeroes_in_precision_mode,
+ int max_trailing_padding_zeroes_in_precision_mode)
+ : flags_(flags),
+ infinity_symbol_(infinity_symbol),
+ nan_symbol_(nan_symbol),
+ exponent_character_(exponent_character),
+ decimal_in_shortest_low_(decimal_in_shortest_low),
+ decimal_in_shortest_high_(decimal_in_shortest_high),
+ max_leading_padding_zeroes_in_precision_mode_(
+ max_leading_padding_zeroes_in_precision_mode),
+ max_trailing_padding_zeroes_in_precision_mode_(
+ max_trailing_padding_zeroes_in_precision_mode) {
+ // When 'trailing zero after the point' is set, then 'trailing point'
+ // must be set too.
+ ASSERT(((flags & EMIT_TRAILING_DECIMAL_POINT) != 0) ||
+ !((flags & EMIT_TRAILING_ZERO_AFTER_POINT) != 0));
+ }
+
+ // Returns a converter following the EcmaScript specification.
+ static const DoubleToStringConverter& EcmaScriptConverter();
+
+ // Computes the shortest string of digits that correctly represent the input
+ // number. Depending on decimal_in_shortest_low and decimal_in_shortest_high
+ // (see constructor) it then either returns a decimal representation, or an
+ // exponential representation.
+ // Example with decimal_in_shortest_low = -6,
+ // decimal_in_shortest_high = 21,
+ // EMIT_POSITIVE_EXPONENT_SIGN activated, and
+ // EMIT_TRAILING_DECIMAL_POINT deactived:
+ // ToShortest(0.000001) -> "0.000001"
+ // ToShortest(0.0000001) -> "1e-7"
+ // ToShortest(111111111111111111111.0) -> "111111111111111110000"
+ // ToShortest(100000000000000000000.0) -> "100000000000000000000"
+ // ToShortest(1111111111111111111111.0) -> "1.1111111111111111e+21"
+ //
+ // Note: the conversion may round the output if the returned string
+ // is accurate enough to uniquely identify the input-number.
+ // For example the most precise representation of the double 9e59 equals
+ // "899999999999999918767229449717619953810131273674690656206848", but
+ // the converter will return the shorter (but still correct) "9e59".
+ //
+ // Returns true if the conversion succeeds. The conversion always succeeds
+ // except when the input value is special and no infinity_symbol or
+ // nan_symbol has been given to the constructor.
+ bool ToShortest(double value, StringBuilder* result_builder) const {
+ return ToShortestIeeeNumber(value, result_builder, SHORTEST);
+ }
+
+ // Same as ToShortest, but for single-precision floats.
+ bool ToShortestSingle(float value, StringBuilder* result_builder) const {
+ return ToShortestIeeeNumber(value, result_builder, SHORTEST_SINGLE);
+ }
+
+
+ // Computes a decimal representation with a fixed number of digits after the
+ // decimal point. The last emitted digit is rounded.
+ //
+ // Examples:
+ // ToFixed(3.12, 1) -> "3.1"
+ // ToFixed(3.1415, 3) -> "3.142"
+ // ToFixed(1234.56789, 4) -> "1234.5679"
+ // ToFixed(1.23, 5) -> "1.23000"
+ // ToFixed(0.1, 4) -> "0.1000"
+ // ToFixed(1e30, 2) -> "1000000000000000019884624838656.00"
+ // ToFixed(0.1, 30) -> "0.100000000000000005551115123126"
+ // ToFixed(0.1, 17) -> "0.10000000000000001"
+ //
+ // If requested_digits equals 0, then the tail of the result depends on
+ // the EMIT_TRAILING_DECIMAL_POINT and EMIT_TRAILING_ZERO_AFTER_POINT.
+ // Examples, for requested_digits == 0,
+ // let EMIT_TRAILING_DECIMAL_POINT and EMIT_TRAILING_ZERO_AFTER_POINT be
+ // - false and false: then 123.45 -> 123
+ // 0.678 -> 1
+ // - true and false: then 123.45 -> 123.
+ // 0.678 -> 1.
+ // - true and true: then 123.45 -> 123.0
+ // 0.678 -> 1.0
+ //
+ // Returns true if the conversion succeeds. The conversion always succeeds
+ // except for the following cases:
+ // - the input value is special and no infinity_symbol or nan_symbol has
+ // been provided to the constructor,
+ // - 'value' > 10^kMaxFixedDigitsBeforePoint, or
+ // - 'requested_digits' > kMaxFixedDigitsAfterPoint.
+ // The last two conditions imply that the result will never contain more than
+ // 1 + kMaxFixedDigitsBeforePoint + 1 + kMaxFixedDigitsAfterPoint characters
+ // (one additional character for the sign, and one for the decimal point).
+ bool ToFixed(double value,
+ int requested_digits,
+ StringBuilder* result_builder) const;
+
+ // Computes a representation in exponential format with requested_digits
+ // after the decimal point. The last emitted digit is rounded.
+ // If requested_digits equals -1, then the shortest exponential representation
+ // is computed.
+ //
+ // Examples with EMIT_POSITIVE_EXPONENT_SIGN deactivated, and
+ // exponent_character set to 'e'.
+ // ToExponential(3.12, 1) -> "3.1e0"
+ // ToExponential(5.0, 3) -> "5.000e0"
+ // ToExponential(0.001, 2) -> "1.00e-3"
+ // ToExponential(3.1415, -1) -> "3.1415e0"
+ // ToExponential(3.1415, 4) -> "3.1415e0"
+ // ToExponential(3.1415, 3) -> "3.142e0"
+ // ToExponential(123456789000000, 3) -> "1.235e14"
+ // ToExponential(1000000000000000019884624838656.0, -1) -> "1e30"
+ // ToExponential(1000000000000000019884624838656.0, 32) ->
+ // "1.00000000000000001988462483865600e30"
+ // ToExponential(1234, 0) -> "1e3"
+ //
+ // Returns true if the conversion succeeds. The conversion always succeeds
+ // except for the following cases:
+ // - the input value is special and no infinity_symbol or nan_symbol has
+ // been provided to the constructor,
+ // - 'requested_digits' > kMaxExponentialDigits.
+ // The last condition implies that the result will never contain more than
+ // kMaxExponentialDigits + 8 characters (the sign, the digit before the
+ // decimal point, the decimal point, the exponent character, the
+ // exponent's sign, and at most 3 exponent digits).
+ bool ToExponential(double value,
+ int requested_digits,
+ StringBuilder* result_builder) const;
+
+ // Computes 'precision' leading digits of the given 'value' and returns them
+ // either in exponential or decimal format, depending on
+ // max_{leading|trailing}_padding_zeroes_in_precision_mode (given to the
+ // constructor).
+ // The last computed digit is rounded.
+ //
+ // Example with max_leading_padding_zeroes_in_precision_mode = 6.
+ // ToPrecision(0.0000012345, 2) -> "0.0000012"
+ // ToPrecision(0.00000012345, 2) -> "1.2e-7"
+ // Similarily the converter may add up to
+ // max_trailing_padding_zeroes_in_precision_mode in precision mode to avoid
+ // returning an exponential representation. A zero added by the
+ // EMIT_TRAILING_ZERO_AFTER_POINT flag is counted for this limit.
+ // Examples for max_trailing_padding_zeroes_in_precision_mode = 1:
+ // ToPrecision(230.0, 2) -> "230"
+ // ToPrecision(230.0, 2) -> "230." with EMIT_TRAILING_DECIMAL_POINT.
+ // ToPrecision(230.0, 2) -> "2.3e2" with EMIT_TRAILING_ZERO_AFTER_POINT.
+ // Examples for max_trailing_padding_zeroes_in_precision_mode = 3, and no
+ // EMIT_TRAILING_ZERO_AFTER_POINT:
+ // ToPrecision(123450.0, 6) -> "123450"
+ // ToPrecision(123450.0, 5) -> "123450"
+ // ToPrecision(123450.0, 4) -> "123500"
+ // ToPrecision(123450.0, 3) -> "123000"
+ // ToPrecision(123450.0, 2) -> "1.2e5"
+ //
+ // Returns true if the conversion succeeds. The conversion always succeeds
+ // except for the following cases:
+ // - the input value is special and no infinity_symbol or nan_symbol has
+ // been provided to the constructor,
+ // - precision < kMinPericisionDigits
+ // - precision > kMaxPrecisionDigits
+ // The last condition implies that the result will never contain more than
+ // kMaxPrecisionDigits + 7 characters (the sign, the decimal point, the
+ // exponent character, the exponent's sign, and at most 3 exponent digits).
+ bool ToPrecision(double value,
+ int precision,
+ StringBuilder* result_builder) const;
+
+ enum DtoaMode {
+ // Produce the shortest correct representation.
+ // For example the output of 0.299999999999999988897 is (the less accurate
+ // but correct) 0.3.
+ SHORTEST,
+ // Same as SHORTEST, but for single-precision floats.
+ SHORTEST_SINGLE,
+ // Produce a fixed number of digits after the decimal point.
+ // For instance fixed(0.1, 4) becomes 0.1000
+ // If the input number is big, the output will be big.
+ FIXED,
+ // Fixed number of digits (independent of the decimal point).
+ PRECISION
+ };
+
+ // The maximal number of digits that are needed to emit a double in base 10.
+ // A higher precision can be achieved by using more digits, but the shortest
+ // accurate representation of any double will never use more digits than
+ // kBase10MaximalLength.
+ // Note that DoubleToAscii null-terminates its input. So the given buffer
+ // should be at least kBase10MaximalLength + 1 characters long.
+ static const int kBase10MaximalLength = 17;
+
+ // Converts the given double 'v' to digit characters. 'v' must not be NaN,
+ // +Infinity, or -Infinity. In SHORTEST_SINGLE-mode this restriction also
+ // applies to 'v' after it has been casted to a single-precision float. That
+ // is, in this mode static_cast<float>(v) must not be NaN, +Infinity or
+ // -Infinity.
+ //
+ // The result should be interpreted as buffer * 10^(point-length).
+ //
+ // The digits are written to the buffer in the platform's charset, which is
+ // often UTF-8 (with ASCII-range digits) but may be another charset, such
+ // as EBCDIC.
+ //
+ // The output depends on the given mode:
+ // - SHORTEST: produce the least amount of digits for which the internal
+ // identity requirement is still satisfied. If the digits are printed
+ // (together with the correct exponent) then reading this number will give
+ // 'v' again. The buffer will choose the representation that is closest to
+ // 'v'. If there are two at the same distance, than the one farther away
+ // from 0 is chosen (halfway cases - ending with 5 - are rounded up).
+ // In this mode the 'requested_digits' parameter is ignored.
+ // - SHORTEST_SINGLE: same as SHORTEST but with single-precision.
+ // - FIXED: produces digits necessary to print a given number with
+ // 'requested_digits' digits after the decimal point. The produced digits
+ // might be too short in which case the caller has to fill the remainder
+ // with '0's.
+ // Example: toFixed(0.001, 5) is allowed to return buffer="1", point=-2.
+ // Halfway cases are rounded towards +/-Infinity (away from 0). The call
+ // toFixed(0.15, 2) thus returns buffer="2", point=0.
+ // The returned buffer may contain digits that would be truncated from the
+ // shortest representation of the input.
+ // - PRECISION: produces 'requested_digits' where the first digit is not '0'.
+ // Even though the length of produced digits usually equals
+ // 'requested_digits', the function is allowed to return fewer digits, in
+ // which case the caller has to fill the missing digits with '0's.
+ // Halfway cases are again rounded away from 0.
+ // DoubleToAscii expects the given buffer to be big enough to hold all
+ // digits and a terminating null-character. In SHORTEST-mode it expects a
+ // buffer of at least kBase10MaximalLength + 1. In all other modes the
+ // requested_digits parameter and the padding-zeroes limit the size of the
+ // output. Don't forget the decimal point, the exponent character and the
+ // terminating null-character when computing the maximal output size.
+ // The given length is only used in debug mode to ensure the buffer is big
+ // enough.
+ static void DoubleToAscii(double v,
+ DtoaMode mode,
+ int requested_digits,
+ char* buffer,
+ int buffer_length,
+ bool* sign,
+ int* length,
+ int* point);
+
+ private:
+ // Implementation for ToShortest and ToShortestSingle.
+ bool ToShortestIeeeNumber(double value,
+ StringBuilder* result_builder,
+ DtoaMode mode) const;
+
+ // If the value is a special value (NaN or Infinity) constructs the
+ // corresponding string using the configured infinity/nan-symbol.
+ // If either of them is NULL or the value is not special then the
+ // function returns false.
+ bool HandleSpecialValues(double value, StringBuilder* result_builder) const;
+ // Constructs an exponential representation (i.e. 1.234e56).
+ // The given exponent assumes a decimal point after the first decimal digit.
+ void CreateExponentialRepresentation(const char* decimal_digits,
+ int length,
+ int exponent,
+ StringBuilder* result_builder) const;
+ // Creates a decimal representation (i.e 1234.5678).
+ void CreateDecimalRepresentation(const char* decimal_digits,
+ int length,
+ int decimal_point,
+ int digits_after_point,
+ StringBuilder* result_builder) const;
+
+ const int flags_;
+ const char* const infinity_symbol_;
+ const char* const nan_symbol_;
+ const char exponent_character_;
+ const int decimal_in_shortest_low_;
+ const int decimal_in_shortest_high_;
+ const int max_leading_padding_zeroes_in_precision_mode_;
+ const int max_trailing_padding_zeroes_in_precision_mode_;
+
+ DC_DISALLOW_IMPLICIT_CONSTRUCTORS(DoubleToStringConverter);
+};
+
+
+class StringToDoubleConverter {
+ public:
+ // Enumeration for allowing octals and ignoring junk when converting
+ // strings to numbers.
+ enum Flags {
+ NO_FLAGS = 0,
+ ALLOW_HEX = 1,
+ ALLOW_OCTALS = 2,
+ ALLOW_TRAILING_JUNK = 4,
+ ALLOW_LEADING_SPACES = 8,
+ ALLOW_TRAILING_SPACES = 16,
+ ALLOW_SPACES_AFTER_SIGN = 32,
+ ALLOW_CASE_INSENSIBILITY = 64,
+ ALLOW_HEX_FLOATS = 128,
+ };
+
+ static const uc16 kNoSeparator = '\0';
+
+ // Flags should be a bit-or combination of the possible Flags-enum.
+ // - NO_FLAGS: no special flags.
+ // - ALLOW_HEX: recognizes the prefix "0x". Hex numbers may only be integers.
+ // Ex: StringToDouble("0x1234") -> 4660.0
+ // In StringToDouble("0x1234.56") the characters ".56" are trailing
+ // junk. The result of the call is hence dependent on
+ // the ALLOW_TRAILING_JUNK flag and/or the junk value.
+ // With this flag "0x" is a junk-string. Even with ALLOW_TRAILING_JUNK,
+ // the string will not be parsed as "0" followed by junk.
+ //
+ // - ALLOW_OCTALS: recognizes the prefix "0" for octals:
+ // If a sequence of octal digits starts with '0', then the number is
+ // read as octal integer. Octal numbers may only be integers.
+ // Ex: StringToDouble("01234") -> 668.0
+ // StringToDouble("012349") -> 12349.0 // Not a sequence of octal
+ // // digits.
+ // In StringToDouble("01234.56") the characters ".56" are trailing
+ // junk. The result of the call is hence dependent on
+ // the ALLOW_TRAILING_JUNK flag and/or the junk value.
+ // In StringToDouble("01234e56") the characters "e56" are trailing
+ // junk, too.
+ // - ALLOW_TRAILING_JUNK: ignore trailing characters that are not part of
+ // a double literal.
+ // - ALLOW_LEADING_SPACES: skip over leading whitespace, including spaces,
+ // new-lines, and tabs.
+ // - ALLOW_TRAILING_SPACES: ignore trailing whitespace.
+ // - ALLOW_SPACES_AFTER_SIGN: ignore whitespace after the sign.
+ // Ex: StringToDouble("- 123.2") -> -123.2.
+ // StringToDouble("+ 123.2") -> 123.2
+ // - ALLOW_CASE_INSENSIBILITY: ignore case of characters for special values:
+ // infinity and nan.
+ // - ALLOW_HEX_FLOATS: allows hexadecimal float literals.
+ // This *must* start with "0x" and separate the exponent with "p".
+ // Examples: 0x1.2p3 == 9.0
+ // 0x10.1p0 == 16.0625
+ // ALLOW_HEX and ALLOW_HEX_FLOATS are indendent.
+ //
+ // empty_string_value is returned when an empty string is given as input.
+ // If ALLOW_LEADING_SPACES or ALLOW_TRAILING_SPACES are set, then a string
+ // containing only spaces is converted to the 'empty_string_value', too.
+ //
+ // junk_string_value is returned when
+ // a) ALLOW_TRAILING_JUNK is not set, and a junk character (a character not
+ // part of a double-literal) is found.
+ // b) ALLOW_TRAILING_JUNK is set, but the string does not start with a
+ // double literal.
+ //
+ // infinity_symbol and nan_symbol are strings that are used to detect
+ // inputs that represent infinity and NaN. They can be null, in which case
+ // they are ignored.
+ // The conversion routine first reads any possible signs. Then it compares the
+ // following character of the input-string with the first character of
+ // the infinity, and nan-symbol. If either matches, the function assumes, that
+ // a match has been found, and expects the following input characters to match
+ // the remaining characters of the special-value symbol.
+ // This means that the following restrictions apply to special-value symbols:
+ // - they must not start with signs ('+', or '-'),
+ // - they must not have the same first character.
+ // - they must not start with digits.
+ //
+ // If the separator character is not kNoSeparator, then that specific
+ // character is ignored when in between two valid digits of the significant.
+ // It is not allowed to appear in the exponent.
+ // It is not allowed to lead or trail the number.
+ // It is not allowed to appear twice next to each other.
+ //
+ // Examples:
+ // flags = ALLOW_HEX | ALLOW_TRAILING_JUNK,
+ // empty_string_value = 0.0,
+ // junk_string_value = NaN,
+ // infinity_symbol = "infinity",
+ // nan_symbol = "nan":
+ // StringToDouble("0x1234") -> 4660.0.
+ // StringToDouble("0x1234K") -> 4660.0.
+ // StringToDouble("") -> 0.0 // empty_string_value.
+ // StringToDouble(" ") -> NaN // junk_string_value.
+ // StringToDouble(" 1") -> NaN // junk_string_value.
+ // StringToDouble("0x") -> NaN // junk_string_value.
+ // StringToDouble("-123.45") -> -123.45.
+ // StringToDouble("--123.45") -> NaN // junk_string_value.
+ // StringToDouble("123e45") -> 123e45.
+ // StringToDouble("123E45") -> 123e45.
+ // StringToDouble("123e+45") -> 123e45.
+ // StringToDouble("123E-45") -> 123e-45.
+ // StringToDouble("123e") -> 123.0 // trailing junk ignored.
+ // StringToDouble("123e-") -> 123.0 // trailing junk ignored.
+ // StringToDouble("+NaN") -> NaN // NaN string literal.
+ // StringToDouble("-infinity") -> -inf. // infinity literal.
+ // StringToDouble("Infinity") -> NaN // junk_string_value.
+ //
+ // flags = ALLOW_OCTAL | ALLOW_LEADING_SPACES,
+ // empty_string_value = 0.0,
+ // junk_string_value = NaN,
+ // infinity_symbol = NULL,
+ // nan_symbol = NULL:
+ // StringToDouble("0x1234") -> NaN // junk_string_value.
+ // StringToDouble("01234") -> 668.0.
+ // StringToDouble("") -> 0.0 // empty_string_value.
+ // StringToDouble(" ") -> 0.0 // empty_string_value.
+ // StringToDouble(" 1") -> 1.0
+ // StringToDouble("0x") -> NaN // junk_string_value.
+ // StringToDouble("0123e45") -> NaN // junk_string_value.
+ // StringToDouble("01239E45") -> 1239e45.
+ // StringToDouble("-infinity") -> NaN // junk_string_value.
+ // StringToDouble("NaN") -> NaN // junk_string_value.
+ //
+ // flags = NO_FLAGS,
+ // separator = ' ':
+ // StringToDouble("1 2 3 4") -> 1234.0
+ // StringToDouble("1 2") -> NaN // junk_string_value
+ // StringToDouble("1 000 000.0") -> 1000000.0
+ // StringToDouble("1.000 000") -> 1.0
+ // StringToDouble("1.0e1 000") -> NaN // junk_string_value
+ StringToDoubleConverter(int flags,
+ double empty_string_value,
+ double junk_string_value,
+ const char* infinity_symbol,
+ const char* nan_symbol,
+ uc16 separator = kNoSeparator)
+ : flags_(flags),
+ empty_string_value_(empty_string_value),
+ junk_string_value_(junk_string_value),
+ infinity_symbol_(infinity_symbol),
+ nan_symbol_(nan_symbol),
+ separator_(separator) {
+ }
+
+ // Performs the conversion.
+ // The output parameter 'processed_characters_count' is set to the number
+ // of characters that have been processed to read the number.
+ // Spaces than are processed with ALLOW_{LEADING|TRAILING}_SPACES are included
+ // in the 'processed_characters_count'. Trailing junk is never included.
+ double StringToDouble(const char* buffer,
+ int length,
+ int* processed_characters_count) const;
+
+ // Same as StringToDouble above but for 16 bit characters.
+ double StringToDouble(const uc16* buffer,
+ int length,
+ int* processed_characters_count) const;
+
+ // Same as StringToDouble but reads a float.
+ // Note that this is not equivalent to static_cast<float>(StringToDouble(...))
+ // due to potential double-rounding.
+ float StringToFloat(const char* buffer,
+ int length,
+ int* processed_characters_count) const;
+
+ // Same as StringToFloat above but for 16 bit characters.
+ float StringToFloat(const uc16* buffer,
+ int length,
+ int* processed_characters_count) const;
+
+ private:
+ const int flags_;
+ const double empty_string_value_;
+ const double junk_string_value_;
+ const char* const infinity_symbol_;
+ const char* const nan_symbol_;
+ const uc16 separator_;
+
+ template <class Iterator>
+ double StringToIeee(Iterator start_pointer,
+ int length,
+ bool read_as_double,
+ int* processed_characters_count) const;
+
+ DC_DISALLOW_IMPLICIT_CONSTRUCTORS(StringToDoubleConverter);
+};
+
+} // namespace double_conversion
+
+#endif // DOUBLE_CONVERSION_DOUBLE_CONVERSION_H_
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/fast-dtoa.cc b/src/arrow/cpp/src/arrow/vendored/double-conversion/fast-dtoa.cc
new file mode 100644
index 000000000..61350383a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/fast-dtoa.cc
@@ -0,0 +1,665 @@
+// Copyright 2012 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "fast-dtoa.h"
+
+#include "cached-powers.h"
+#include "diy-fp.h"
+#include "ieee.h"
+
+namespace double_conversion {
+
+// The minimal and maximal target exponent define the range of w's binary
+// exponent, where 'w' is the result of multiplying the input by a cached power
+// of ten.
+//
+// A different range might be chosen on a different platform, to optimize digit
+// generation, but a smaller range requires more powers of ten to be cached.
+static const int kMinimalTargetExponent = -60;
+static const int kMaximalTargetExponent = -32;
+
+
+// Adjusts the last digit of the generated number, and screens out generated
+// solutions that may be inaccurate. A solution may be inaccurate if it is
+// outside the safe interval, or if we cannot prove that it is closer to the
+// input than a neighboring representation of the same length.
+//
+// Input: * buffer containing the digits of too_high / 10^kappa
+// * the buffer's length
+// * distance_too_high_w == (too_high - w).f() * unit
+// * unsafe_interval == (too_high - too_low).f() * unit
+// * rest = (too_high - buffer * 10^kappa).f() * unit
+// * ten_kappa = 10^kappa * unit
+// * unit = the common multiplier
+// Output: returns true if the buffer is guaranteed to contain the closest
+// representable number to the input.
+// Modifies the generated digits in the buffer to approach (round towards) w.
+static bool RoundWeed(Vector<char> buffer,
+ int length,
+ uint64_t distance_too_high_w,
+ uint64_t unsafe_interval,
+ uint64_t rest,
+ uint64_t ten_kappa,
+ uint64_t unit) {
+ uint64_t small_distance = distance_too_high_w - unit;
+ uint64_t big_distance = distance_too_high_w + unit;
+ // Let w_low = too_high - big_distance, and
+ // w_high = too_high - small_distance.
+ // Note: w_low < w < w_high
+ //
+ // The real w (* unit) must lie somewhere inside the interval
+ // ]w_low; w_high[ (often written as "(w_low; w_high)")
+
+ // Basically the buffer currently contains a number in the unsafe interval
+ // ]too_low; too_high[ with too_low < w < too_high
+ //
+ // too_high - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+ // ^v 1 unit ^ ^ ^ ^
+ // boundary_high --------------------- . . . .
+ // ^v 1 unit . . . .
+ // - - - - - - - - - - - - - - - - - - - + - - + - - - - - - . .
+ // . . ^ . .
+ // . big_distance . . .
+ // . . . . rest
+ // small_distance . . . .
+ // v . . . .
+ // w_high - - - - - - - - - - - - - - - - - - . . . .
+ // ^v 1 unit . . . .
+ // w ---------------------------------------- . . . .
+ // ^v 1 unit v . . .
+ // w_low - - - - - - - - - - - - - - - - - - - - - . . .
+ // . . v
+ // buffer --------------------------------------------------+-------+--------
+ // . .
+ // safe_interval .
+ // v .
+ // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - .
+ // ^v 1 unit .
+ // boundary_low ------------------------- unsafe_interval
+ // ^v 1 unit v
+ // too_low - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+ //
+ //
+ // Note that the value of buffer could lie anywhere inside the range too_low
+ // to too_high.
+ //
+ // boundary_low, boundary_high and w are approximations of the real boundaries
+ // and v (the input number). They are guaranteed to be precise up to one unit.
+ // In fact the error is guaranteed to be strictly less than one unit.
+ //
+ // Anything that lies outside the unsafe interval is guaranteed not to round
+ // to v when read again.
+ // Anything that lies inside the safe interval is guaranteed to round to v
+ // when read again.
+ // If the number inside the buffer lies inside the unsafe interval but not
+ // inside the safe interval then we simply do not know and bail out (returning
+ // false).
+ //
+ // Similarly we have to take into account the imprecision of 'w' when finding
+ // the closest representation of 'w'. If we have two potential
+ // representations, and one is closer to both w_low and w_high, then we know
+ // it is closer to the actual value v.
+ //
+ // By generating the digits of too_high we got the largest (closest to
+ // too_high) buffer that is still in the unsafe interval. In the case where
+ // w_high < buffer < too_high we try to decrement the buffer.
+ // This way the buffer approaches (rounds towards) w.
+ // There are 3 conditions that stop the decrementation process:
+ // 1) the buffer is already below w_high
+ // 2) decrementing the buffer would make it leave the unsafe interval
+ // 3) decrementing the buffer would yield a number below w_high and farther
+ // away than the current number. In other words:
+ // (buffer{-1} < w_high) && w_high - buffer{-1} > buffer - w_high
+ // Instead of using the buffer directly we use its distance to too_high.
+ // Conceptually rest ~= too_high - buffer
+ // We need to do the following tests in this order to avoid over- and
+ // underflows.
+ ASSERT(rest <= unsafe_interval);
+ while (rest < small_distance && // Negated condition 1
+ unsafe_interval - rest >= ten_kappa && // Negated condition 2
+ (rest + ten_kappa < small_distance || // buffer{-1} > w_high
+ small_distance - rest >= rest + ten_kappa - small_distance)) {
+ buffer[length - 1]--;
+ rest += ten_kappa;
+ }
+
+ // We have approached w+ as much as possible. We now test if approaching w-
+ // would require changing the buffer. If yes, then we have two possible
+ // representations close to w, but we cannot decide which one is closer.
+ if (rest < big_distance &&
+ unsafe_interval - rest >= ten_kappa &&
+ (rest + ten_kappa < big_distance ||
+ big_distance - rest > rest + ten_kappa - big_distance)) {
+ return false;
+ }
+
+ // Weeding test.
+ // The safe interval is [too_low + 2 ulp; too_high - 2 ulp]
+ // Since too_low = too_high - unsafe_interval this is equivalent to
+ // [too_high - unsafe_interval + 4 ulp; too_high - 2 ulp]
+ // Conceptually we have: rest ~= too_high - buffer
+ return (2 * unit <= rest) && (rest <= unsafe_interval - 4 * unit);
+}
+
+
+// Rounds the buffer upwards if the result is closer to v by possibly adding
+// 1 to the buffer. If the precision of the calculation is not sufficient to
+// round correctly, return false.
+// The rounding might shift the whole buffer in which case the kappa is
+// adjusted. For example "99", kappa = 3 might become "10", kappa = 4.
+//
+// If 2*rest > ten_kappa then the buffer needs to be round up.
+// rest can have an error of +/- 1 unit. This function accounts for the
+// imprecision and returns false, if the rounding direction cannot be
+// unambiguously determined.
+//
+// Precondition: rest < ten_kappa.
+static bool RoundWeedCounted(Vector<char> buffer,
+ int length,
+ uint64_t rest,
+ uint64_t ten_kappa,
+ uint64_t unit,
+ int* kappa) {
+ ASSERT(rest < ten_kappa);
+ // The following tests are done in a specific order to avoid overflows. They
+ // will work correctly with any uint64 values of rest < ten_kappa and unit.
+ //
+ // If the unit is too big, then we don't know which way to round. For example
+ // a unit of 50 means that the real number lies within rest +/- 50. If
+ // 10^kappa == 40 then there is no way to tell which way to round.
+ if (unit >= ten_kappa) return false;
+ // Even if unit is just half the size of 10^kappa we are already completely
+ // lost. (And after the previous test we know that the expression will not
+ // over/underflow.)
+ if (ten_kappa - unit <= unit) return false;
+ // If 2 * (rest + unit) <= 10^kappa we can safely round down.
+ if ((ten_kappa - rest > rest) && (ten_kappa - 2 * rest >= 2 * unit)) {
+ return true;
+ }
+ // If 2 * (rest - unit) >= 10^kappa, then we can safely round up.
+ if ((rest > unit) && (ten_kappa - (rest - unit) <= (rest - unit))) {
+ // Increment the last digit recursively until we find a non '9' digit.
+ buffer[length - 1]++;
+ for (int i = length - 1; i > 0; --i) {
+ if (buffer[i] != '0' + 10) break;
+ buffer[i] = '0';
+ buffer[i - 1]++;
+ }
+ // If the first digit is now '0'+ 10 we had a buffer with all '9's. With the
+ // exception of the first digit all digits are now '0'. Simply switch the
+ // first digit to '1' and adjust the kappa. Example: "99" becomes "10" and
+ // the power (the kappa) is increased.
+ if (buffer[0] == '0' + 10) {
+ buffer[0] = '1';
+ (*kappa) += 1;
+ }
+ return true;
+ }
+ return false;
+}
+
+// Returns the biggest power of ten that is less than or equal to the given
+// number. We furthermore receive the maximum number of bits 'number' has.
+//
+// Returns power == 10^(exponent_plus_one-1) such that
+// power <= number < power * 10.
+// If number_bits == 0 then 0^(0-1) is returned.
+// The number of bits must be <= 32.
+// Precondition: number < (1 << (number_bits + 1)).
+
+// Inspired by the method for finding an integer log base 10 from here:
+// http://graphics.stanford.edu/~seander/bithacks.html#IntegerLog10
+static unsigned int const kSmallPowersOfTen[] =
+ {0, 1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000,
+ 1000000000};
+
+static void BiggestPowerTen(uint32_t number,
+ int number_bits,
+ uint32_t* power,
+ int* exponent_plus_one) {
+ ASSERT(number < (1u << (number_bits + 1)));
+ // 1233/4096 is approximately 1/lg(10).
+ int exponent_plus_one_guess = ((number_bits + 1) * 1233 >> 12);
+ // We increment to skip over the first entry in the kPowersOf10 table.
+ // Note: kPowersOf10[i] == 10^(i-1).
+ exponent_plus_one_guess++;
+ // We don't have any guarantees that 2^number_bits <= number.
+ if (number < kSmallPowersOfTen[exponent_plus_one_guess]) {
+ exponent_plus_one_guess--;
+ }
+ *power = kSmallPowersOfTen[exponent_plus_one_guess];
+ *exponent_plus_one = exponent_plus_one_guess;
+}
+
+// Generates the digits of input number w.
+// w is a floating-point number (DiyFp), consisting of a significand and an
+// exponent. Its exponent is bounded by kMinimalTargetExponent and
+// kMaximalTargetExponent.
+// Hence -60 <= w.e() <= -32.
+//
+// Returns false if it fails, in which case the generated digits in the buffer
+// should not be used.
+// Preconditions:
+// * low, w and high are correct up to 1 ulp (unit in the last place). That
+// is, their error must be less than a unit of their last digits.
+// * low.e() == w.e() == high.e()
+// * low < w < high, and taking into account their error: low~ <= high~
+// * kMinimalTargetExponent <= w.e() <= kMaximalTargetExponent
+// Postconditions: returns false if procedure fails.
+// otherwise:
+// * buffer is not null-terminated, but len contains the number of digits.
+// * buffer contains the shortest possible decimal digit-sequence
+// such that LOW < buffer * 10^kappa < HIGH, where LOW and HIGH are the
+// correct values of low and high (without their error).
+// * if more than one decimal representation gives the minimal number of
+// decimal digits then the one closest to W (where W is the correct value
+// of w) is chosen.
+// Remark: this procedure takes into account the imprecision of its input
+// numbers. If the precision is not enough to guarantee all the postconditions
+// then false is returned. This usually happens rarely (~0.5%).
+//
+// Say, for the sake of example, that
+// w.e() == -48, and w.f() == 0x1234567890abcdef
+// w's value can be computed by w.f() * 2^w.e()
+// We can obtain w's integral digits by simply shifting w.f() by -w.e().
+// -> w's integral part is 0x1234
+// w's fractional part is therefore 0x567890abcdef.
+// Printing w's integral part is easy (simply print 0x1234 in decimal).
+// In order to print its fraction we repeatedly multiply the fraction by 10 and
+// get each digit. Example the first digit after the point would be computed by
+// (0x567890abcdef * 10) >> 48. -> 3
+// The whole thing becomes slightly more complicated because we want to stop
+// once we have enough digits. That is, once the digits inside the buffer
+// represent 'w' we can stop. Everything inside the interval low - high
+// represents w. However we have to pay attention to low, high and w's
+// imprecision.
+static bool DigitGen(DiyFp low,
+ DiyFp w,
+ DiyFp high,
+ Vector<char> buffer,
+ int* length,
+ int* kappa) {
+ ASSERT(low.e() == w.e() && w.e() == high.e());
+ ASSERT(low.f() + 1 <= high.f() - 1);
+ ASSERT(kMinimalTargetExponent <= w.e() && w.e() <= kMaximalTargetExponent);
+ // low, w and high are imprecise, but by less than one ulp (unit in the last
+ // place).
+ // If we remove (resp. add) 1 ulp from low (resp. high) we are certain that
+ // the new numbers are outside of the interval we want the final
+ // representation to lie in.
+ // Inversely adding (resp. removing) 1 ulp from low (resp. high) would yield
+ // numbers that are certain to lie in the interval. We will use this fact
+ // later on.
+ // We will now start by generating the digits within the uncertain
+ // interval. Later we will weed out representations that lie outside the safe
+ // interval and thus _might_ lie outside the correct interval.
+ uint64_t unit = 1;
+ DiyFp too_low = DiyFp(low.f() - unit, low.e());
+ DiyFp too_high = DiyFp(high.f() + unit, high.e());
+ // too_low and too_high are guaranteed to lie outside the interval we want the
+ // generated number in.
+ DiyFp unsafe_interval = DiyFp::Minus(too_high, too_low);
+ // We now cut the input number into two parts: the integral digits and the
+ // fractionals. We will not write any decimal separator though, but adapt
+ // kappa instead.
+ // Reminder: we are currently computing the digits (stored inside the buffer)
+ // such that: too_low < buffer * 10^kappa < too_high
+ // We use too_high for the digit_generation and stop as soon as possible.
+ // If we stop early we effectively round down.
+ DiyFp one = DiyFp(static_cast<uint64_t>(1) << -w.e(), w.e());
+ // Division by one is a shift.
+ uint32_t integrals = static_cast<uint32_t>(too_high.f() >> -one.e());
+ // Modulo by one is an and.
+ uint64_t fractionals = too_high.f() & (one.f() - 1);
+ uint32_t divisor;
+ int divisor_exponent_plus_one;
+ BiggestPowerTen(integrals, DiyFp::kSignificandSize - (-one.e()),
+ &divisor, &divisor_exponent_plus_one);
+ *kappa = divisor_exponent_plus_one;
+ *length = 0;
+ // Loop invariant: buffer = too_high / 10^kappa (integer division)
+ // The invariant holds for the first iteration: kappa has been initialized
+ // with the divisor exponent + 1. And the divisor is the biggest power of ten
+ // that is smaller than integrals.
+ while (*kappa > 0) {
+ int digit = integrals / divisor;
+ ASSERT(digit <= 9);
+ buffer[*length] = static_cast<char>('0' + digit);
+ (*length)++;
+ integrals %= divisor;
+ (*kappa)--;
+ // Note that kappa now equals the exponent of the divisor and that the
+ // invariant thus holds again.
+ uint64_t rest =
+ (static_cast<uint64_t>(integrals) << -one.e()) + fractionals;
+ // Invariant: too_high = buffer * 10^kappa + DiyFp(rest, one.e())
+ // Reminder: unsafe_interval.e() == one.e()
+ if (rest < unsafe_interval.f()) {
+ // Rounding down (by not emitting the remaining digits) yields a number
+ // that lies within the unsafe interval.
+ return RoundWeed(buffer, *length, DiyFp::Minus(too_high, w).f(),
+ unsafe_interval.f(), rest,
+ static_cast<uint64_t>(divisor) << -one.e(), unit);
+ }
+ divisor /= 10;
+ }
+
+ // The integrals have been generated. We are at the point of the decimal
+ // separator. In the following loop we simply multiply the remaining digits by
+ // 10 and divide by one. We just need to pay attention to multiply associated
+ // data (like the interval or 'unit'), too.
+ // Note that the multiplication by 10 does not overflow, because w.e >= -60
+ // and thus one.e >= -60.
+ ASSERT(one.e() >= -60);
+ ASSERT(fractionals < one.f());
+ ASSERT(UINT64_2PART_C(0xFFFFFFFF, FFFFFFFF) / 10 >= one.f());
+ for (;;) {
+ fractionals *= 10;
+ unit *= 10;
+ unsafe_interval.set_f(unsafe_interval.f() * 10);
+ // Integer division by one.
+ int digit = static_cast<int>(fractionals >> -one.e());
+ ASSERT(digit <= 9);
+ buffer[*length] = static_cast<char>('0' + digit);
+ (*length)++;
+ fractionals &= one.f() - 1; // Modulo by one.
+ (*kappa)--;
+ if (fractionals < unsafe_interval.f()) {
+ return RoundWeed(buffer, *length, DiyFp::Minus(too_high, w).f() * unit,
+ unsafe_interval.f(), fractionals, one.f(), unit);
+ }
+ }
+}
+
+
+
+// Generates (at most) requested_digits digits of input number w.
+// w is a floating-point number (DiyFp), consisting of a significand and an
+// exponent. Its exponent is bounded by kMinimalTargetExponent and
+// kMaximalTargetExponent.
+// Hence -60 <= w.e() <= -32.
+//
+// Returns false if it fails, in which case the generated digits in the buffer
+// should not be used.
+// Preconditions:
+// * w is correct up to 1 ulp (unit in the last place). That
+// is, its error must be strictly less than a unit of its last digit.
+// * kMinimalTargetExponent <= w.e() <= kMaximalTargetExponent
+//
+// Postconditions: returns false if procedure fails.
+// otherwise:
+// * buffer is not null-terminated, but length contains the number of
+// digits.
+// * the representation in buffer is the most precise representation of
+// requested_digits digits.
+// * buffer contains at most requested_digits digits of w. If there are less
+// than requested_digits digits then some trailing '0's have been removed.
+// * kappa is such that
+// w = buffer * 10^kappa + eps with |eps| < 10^kappa / 2.
+//
+// Remark: This procedure takes into account the imprecision of its input
+// numbers. If the precision is not enough to guarantee all the postconditions
+// then false is returned. This usually happens rarely, but the failure-rate
+// increases with higher requested_digits.
+static bool DigitGenCounted(DiyFp w,
+ int requested_digits,
+ Vector<char> buffer,
+ int* length,
+ int* kappa) {
+ ASSERT(kMinimalTargetExponent <= w.e() && w.e() <= kMaximalTargetExponent);
+ ASSERT(kMinimalTargetExponent >= -60);
+ ASSERT(kMaximalTargetExponent <= -32);
+ // w is assumed to have an error less than 1 unit. Whenever w is scaled we
+ // also scale its error.
+ uint64_t w_error = 1;
+ // We cut the input number into two parts: the integral digits and the
+ // fractional digits. We don't emit any decimal separator, but adapt kappa
+ // instead. Example: instead of writing "1.2" we put "12" into the buffer and
+ // increase kappa by 1.
+ DiyFp one = DiyFp(static_cast<uint64_t>(1) << -w.e(), w.e());
+ // Division by one is a shift.
+ uint32_t integrals = static_cast<uint32_t>(w.f() >> -one.e());
+ // Modulo by one is an and.
+ uint64_t fractionals = w.f() & (one.f() - 1);
+ uint32_t divisor;
+ int divisor_exponent_plus_one;
+ BiggestPowerTen(integrals, DiyFp::kSignificandSize - (-one.e()),
+ &divisor, &divisor_exponent_plus_one);
+ *kappa = divisor_exponent_plus_one;
+ *length = 0;
+
+ // Loop invariant: buffer = w / 10^kappa (integer division)
+ // The invariant holds for the first iteration: kappa has been initialized
+ // with the divisor exponent + 1. And the divisor is the biggest power of ten
+ // that is smaller than 'integrals'.
+ while (*kappa > 0) {
+ int digit = integrals / divisor;
+ ASSERT(digit <= 9);
+ buffer[*length] = static_cast<char>('0' + digit);
+ (*length)++;
+ requested_digits--;
+ integrals %= divisor;
+ (*kappa)--;
+ // Note that kappa now equals the exponent of the divisor and that the
+ // invariant thus holds again.
+ if (requested_digits == 0) break;
+ divisor /= 10;
+ }
+
+ if (requested_digits == 0) {
+ uint64_t rest =
+ (static_cast<uint64_t>(integrals) << -one.e()) + fractionals;
+ return RoundWeedCounted(buffer, *length, rest,
+ static_cast<uint64_t>(divisor) << -one.e(), w_error,
+ kappa);
+ }
+
+ // The integrals have been generated. We are at the point of the decimal
+ // separator. In the following loop we simply multiply the remaining digits by
+ // 10 and divide by one. We just need to pay attention to multiply associated
+ // data (the 'unit'), too.
+ // Note that the multiplication by 10 does not overflow, because w.e >= -60
+ // and thus one.e >= -60.
+ ASSERT(one.e() >= -60);
+ ASSERT(fractionals < one.f());
+ ASSERT(UINT64_2PART_C(0xFFFFFFFF, FFFFFFFF) / 10 >= one.f());
+ while (requested_digits > 0 && fractionals > w_error) {
+ fractionals *= 10;
+ w_error *= 10;
+ // Integer division by one.
+ int digit = static_cast<int>(fractionals >> -one.e());
+ ASSERT(digit <= 9);
+ buffer[*length] = static_cast<char>('0' + digit);
+ (*length)++;
+ requested_digits--;
+ fractionals &= one.f() - 1; // Modulo by one.
+ (*kappa)--;
+ }
+ if (requested_digits != 0) return false;
+ return RoundWeedCounted(buffer, *length, fractionals, one.f(), w_error,
+ kappa);
+}
+
+
+// Provides a decimal representation of v.
+// Returns true if it succeeds, otherwise the result cannot be trusted.
+// There will be *length digits inside the buffer (not null-terminated).
+// If the function returns true then
+// v == (double) (buffer * 10^decimal_exponent).
+// The digits in the buffer are the shortest representation possible: no
+// 0.09999999999999999 instead of 0.1. The shorter representation will even be
+// chosen even if the longer one would be closer to v.
+// The last digit will be closest to the actual v. That is, even if several
+// digits might correctly yield 'v' when read again, the closest will be
+// computed.
+static bool Grisu3(double v,
+ FastDtoaMode mode,
+ Vector<char> buffer,
+ int* length,
+ int* decimal_exponent) {
+ DiyFp w = Double(v).AsNormalizedDiyFp();
+ // boundary_minus and boundary_plus are the boundaries between v and its
+ // closest floating-point neighbors. Any number strictly between
+ // boundary_minus and boundary_plus will round to v when convert to a double.
+ // Grisu3 will never output representations that lie exactly on a boundary.
+ DiyFp boundary_minus, boundary_plus;
+ if (mode == FAST_DTOA_SHORTEST) {
+ Double(v).NormalizedBoundaries(&boundary_minus, &boundary_plus);
+ } else {
+ ASSERT(mode == FAST_DTOA_SHORTEST_SINGLE);
+ float single_v = static_cast<float>(v);
+ Single(single_v).NormalizedBoundaries(&boundary_minus, &boundary_plus);
+ }
+ ASSERT(boundary_plus.e() == w.e());
+ DiyFp ten_mk; // Cached power of ten: 10^-k
+ int mk; // -k
+ int ten_mk_minimal_binary_exponent =
+ kMinimalTargetExponent - (w.e() + DiyFp::kSignificandSize);
+ int ten_mk_maximal_binary_exponent =
+ kMaximalTargetExponent - (w.e() + DiyFp::kSignificandSize);
+ PowersOfTenCache::GetCachedPowerForBinaryExponentRange(
+ ten_mk_minimal_binary_exponent,
+ ten_mk_maximal_binary_exponent,
+ &ten_mk, &mk);
+ ASSERT((kMinimalTargetExponent <= w.e() + ten_mk.e() +
+ DiyFp::kSignificandSize) &&
+ (kMaximalTargetExponent >= w.e() + ten_mk.e() +
+ DiyFp::kSignificandSize));
+ // Note that ten_mk is only an approximation of 10^-k. A DiyFp only contains a
+ // 64 bit significand and ten_mk is thus only precise up to 64 bits.
+
+ // The DiyFp::Times procedure rounds its result, and ten_mk is approximated
+ // too. The variable scaled_w (as well as scaled_boundary_minus/plus) are now
+ // off by a small amount.
+ // In fact: scaled_w - w*10^k < 1ulp (unit in the last place) of scaled_w.
+ // In other words: let f = scaled_w.f() and e = scaled_w.e(), then
+ // (f-1) * 2^e < w*10^k < (f+1) * 2^e
+ DiyFp scaled_w = DiyFp::Times(w, ten_mk);
+ ASSERT(scaled_w.e() ==
+ boundary_plus.e() + ten_mk.e() + DiyFp::kSignificandSize);
+ // In theory it would be possible to avoid some recomputations by computing
+ // the difference between w and boundary_minus/plus (a power of 2) and to
+ // compute scaled_boundary_minus/plus by subtracting/adding from
+ // scaled_w. However the code becomes much less readable and the speed
+ // enhancements are not terriffic.
+ DiyFp scaled_boundary_minus = DiyFp::Times(boundary_minus, ten_mk);
+ DiyFp scaled_boundary_plus = DiyFp::Times(boundary_plus, ten_mk);
+
+ // DigitGen will generate the digits of scaled_w. Therefore we have
+ // v == (double) (scaled_w * 10^-mk).
+ // Set decimal_exponent == -mk and pass it to DigitGen. If scaled_w is not an
+ // integer than it will be updated. For instance if scaled_w == 1.23 then
+ // the buffer will be filled with "123" und the decimal_exponent will be
+ // decreased by 2.
+ int kappa;
+ bool result = DigitGen(scaled_boundary_minus, scaled_w, scaled_boundary_plus,
+ buffer, length, &kappa);
+ *decimal_exponent = -mk + kappa;
+ return result;
+}
+
+
+// The "counted" version of grisu3 (see above) only generates requested_digits
+// number of digits. This version does not generate the shortest representation,
+// and with enough requested digits 0.1 will at some point print as 0.9999999...
+// Grisu3 is too imprecise for real halfway cases (1.5 will not work) and
+// therefore the rounding strategy for halfway cases is irrelevant.
+static bool Grisu3Counted(double v,
+ int requested_digits,
+ Vector<char> buffer,
+ int* length,
+ int* decimal_exponent) {
+ DiyFp w = Double(v).AsNormalizedDiyFp();
+ DiyFp ten_mk; // Cached power of ten: 10^-k
+ int mk; // -k
+ int ten_mk_minimal_binary_exponent =
+ kMinimalTargetExponent - (w.e() + DiyFp::kSignificandSize);
+ int ten_mk_maximal_binary_exponent =
+ kMaximalTargetExponent - (w.e() + DiyFp::kSignificandSize);
+ PowersOfTenCache::GetCachedPowerForBinaryExponentRange(
+ ten_mk_minimal_binary_exponent,
+ ten_mk_maximal_binary_exponent,
+ &ten_mk, &mk);
+ ASSERT((kMinimalTargetExponent <= w.e() + ten_mk.e() +
+ DiyFp::kSignificandSize) &&
+ (kMaximalTargetExponent >= w.e() + ten_mk.e() +
+ DiyFp::kSignificandSize));
+ // Note that ten_mk is only an approximation of 10^-k. A DiyFp only contains a
+ // 64 bit significand and ten_mk is thus only precise up to 64 bits.
+
+ // The DiyFp::Times procedure rounds its result, and ten_mk is approximated
+ // too. The variable scaled_w (as well as scaled_boundary_minus/plus) are now
+ // off by a small amount.
+ // In fact: scaled_w - w*10^k < 1ulp (unit in the last place) of scaled_w.
+ // In other words: let f = scaled_w.f() and e = scaled_w.e(), then
+ // (f-1) * 2^e < w*10^k < (f+1) * 2^e
+ DiyFp scaled_w = DiyFp::Times(w, ten_mk);
+
+ // We now have (double) (scaled_w * 10^-mk).
+ // DigitGen will generate the first requested_digits digits of scaled_w and
+ // return together with a kappa such that scaled_w ~= buffer * 10^kappa. (It
+ // will not always be exactly the same since DigitGenCounted only produces a
+ // limited number of digits.)
+ int kappa;
+ bool result = DigitGenCounted(scaled_w, requested_digits,
+ buffer, length, &kappa);
+ *decimal_exponent = -mk + kappa;
+ return result;
+}
+
+
+bool FastDtoa(double v,
+ FastDtoaMode mode,
+ int requested_digits,
+ Vector<char> buffer,
+ int* length,
+ int* decimal_point) {
+ ASSERT(v > 0);
+ ASSERT(!Double(v).IsSpecial());
+
+ bool result = false;
+ int decimal_exponent = 0;
+ switch (mode) {
+ case FAST_DTOA_SHORTEST:
+ case FAST_DTOA_SHORTEST_SINGLE:
+ result = Grisu3(v, mode, buffer, length, &decimal_exponent);
+ break;
+ case FAST_DTOA_PRECISION:
+ result = Grisu3Counted(v, requested_digits,
+ buffer, length, &decimal_exponent);
+ break;
+ default:
+ UNREACHABLE();
+ }
+ if (result) {
+ *decimal_point = *length + decimal_exponent;
+ buffer[*length] = '\0';
+ }
+ return result;
+}
+
+} // namespace double_conversion
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/fast-dtoa.h b/src/arrow/cpp/src/arrow/vendored/double-conversion/fast-dtoa.h
new file mode 100644
index 000000000..5f1e8eee5
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/fast-dtoa.h
@@ -0,0 +1,88 @@
+// Copyright 2010 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef DOUBLE_CONVERSION_FAST_DTOA_H_
+#define DOUBLE_CONVERSION_FAST_DTOA_H_
+
+#include "utils.h"
+
+namespace double_conversion {
+
+enum FastDtoaMode {
+ // Computes the shortest representation of the given input. The returned
+ // result will be the most accurate number of this length. Longer
+ // representations might be more accurate.
+ FAST_DTOA_SHORTEST,
+ // Same as FAST_DTOA_SHORTEST but for single-precision floats.
+ FAST_DTOA_SHORTEST_SINGLE,
+ // Computes a representation where the precision (number of digits) is
+ // given as input. The precision is independent of the decimal point.
+ FAST_DTOA_PRECISION
+};
+
+// FastDtoa will produce at most kFastDtoaMaximalLength digits. This does not
+// include the terminating '\0' character.
+static const int kFastDtoaMaximalLength = 17;
+// Same for single-precision numbers.
+static const int kFastDtoaMaximalSingleLength = 9;
+
+// Provides a decimal representation of v.
+// The result should be interpreted as buffer * 10^(point - length).
+//
+// Precondition:
+// * v must be a strictly positive finite double.
+//
+// Returns true if it succeeds, otherwise the result can not be trusted.
+// There will be *length digits inside the buffer followed by a null terminator.
+// If the function returns true and mode equals
+// - FAST_DTOA_SHORTEST, then
+// the parameter requested_digits is ignored.
+// The result satisfies
+// v == (double) (buffer * 10^(point - length)).
+// The digits in the buffer are the shortest representation possible. E.g.
+// if 0.099999999999 and 0.1 represent the same double then "1" is returned
+// with point = 0.
+// The last digit will be closest to the actual v. That is, even if several
+// digits might correctly yield 'v' when read again, the buffer will contain
+// the one closest to v.
+// - FAST_DTOA_PRECISION, then
+// the buffer contains requested_digits digits.
+// the difference v - (buffer * 10^(point-length)) is closest to zero for
+// all possible representations of requested_digits digits.
+// If there are two values that are equally close, then FastDtoa returns
+// false.
+// For both modes the buffer must be large enough to hold the result.
+bool FastDtoa(double d,
+ FastDtoaMode mode,
+ int requested_digits,
+ Vector<char> buffer,
+ int* length,
+ int* decimal_point);
+
+} // namespace double_conversion
+
+#endif // DOUBLE_CONVERSION_FAST_DTOA_H_
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/fixed-dtoa.cc b/src/arrow/cpp/src/arrow/vendored/double-conversion/fixed-dtoa.cc
new file mode 100644
index 000000000..0f989bcea
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/fixed-dtoa.cc
@@ -0,0 +1,405 @@
+// Copyright 2010 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <cmath>
+
+#include "fixed-dtoa.h"
+#include "ieee.h"
+
+namespace double_conversion {
+
+// Represents a 128bit type. This class should be replaced by a native type on
+// platforms that support 128bit integers.
+class UInt128 {
+ public:
+ UInt128() : high_bits_(0), low_bits_(0) { }
+ UInt128(uint64_t high, uint64_t low) : high_bits_(high), low_bits_(low) { }
+
+ void Multiply(uint32_t multiplicand) {
+ uint64_t accumulator;
+
+ accumulator = (low_bits_ & kMask32) * multiplicand;
+ uint32_t part = static_cast<uint32_t>(accumulator & kMask32);
+ accumulator >>= 32;
+ accumulator = accumulator + (low_bits_ >> 32) * multiplicand;
+ low_bits_ = (accumulator << 32) + part;
+ accumulator >>= 32;
+ accumulator = accumulator + (high_bits_ & kMask32) * multiplicand;
+ part = static_cast<uint32_t>(accumulator & kMask32);
+ accumulator >>= 32;
+ accumulator = accumulator + (high_bits_ >> 32) * multiplicand;
+ high_bits_ = (accumulator << 32) + part;
+ ASSERT((accumulator >> 32) == 0);
+ }
+
+ void Shift(int shift_amount) {
+ ASSERT(-64 <= shift_amount && shift_amount <= 64);
+ if (shift_amount == 0) {
+ return;
+ } else if (shift_amount == -64) {
+ high_bits_ = low_bits_;
+ low_bits_ = 0;
+ } else if (shift_amount == 64) {
+ low_bits_ = high_bits_;
+ high_bits_ = 0;
+ } else if (shift_amount <= 0) {
+ high_bits_ <<= -shift_amount;
+ high_bits_ += low_bits_ >> (64 + shift_amount);
+ low_bits_ <<= -shift_amount;
+ } else {
+ low_bits_ >>= shift_amount;
+ low_bits_ += high_bits_ << (64 - shift_amount);
+ high_bits_ >>= shift_amount;
+ }
+ }
+
+ // Modifies *this to *this MOD (2^power).
+ // Returns *this DIV (2^power).
+ int DivModPowerOf2(int power) {
+ if (power >= 64) {
+ int result = static_cast<int>(high_bits_ >> (power - 64));
+ high_bits_ -= static_cast<uint64_t>(result) << (power - 64);
+ return result;
+ } else {
+ uint64_t part_low = low_bits_ >> power;
+ uint64_t part_high = high_bits_ << (64 - power);
+ int result = static_cast<int>(part_low + part_high);
+ high_bits_ = 0;
+ low_bits_ -= part_low << power;
+ return result;
+ }
+ }
+
+ bool IsZero() const {
+ return high_bits_ == 0 && low_bits_ == 0;
+ }
+
+ int BitAt(int position) const {
+ if (position >= 64) {
+ return static_cast<int>(high_bits_ >> (position - 64)) & 1;
+ } else {
+ return static_cast<int>(low_bits_ >> position) & 1;
+ }
+ }
+
+ private:
+ static const uint64_t kMask32 = 0xFFFFFFFF;
+ // Value == (high_bits_ << 64) + low_bits_
+ uint64_t high_bits_;
+ uint64_t low_bits_;
+};
+
+
+static const int kDoubleSignificandSize = 53; // Includes the hidden bit.
+
+
+static void FillDigits32FixedLength(uint32_t number, int requested_length,
+ Vector<char> buffer, int* length) {
+ for (int i = requested_length - 1; i >= 0; --i) {
+ buffer[(*length) + i] = '0' + number % 10;
+ number /= 10;
+ }
+ *length += requested_length;
+}
+
+
+static void FillDigits32(uint32_t number, Vector<char> buffer, int* length) {
+ int number_length = 0;
+ // We fill the digits in reverse order and exchange them afterwards.
+ while (number != 0) {
+ int digit = number % 10;
+ number /= 10;
+ buffer[(*length) + number_length] = static_cast<char>('0' + digit);
+ number_length++;
+ }
+ // Exchange the digits.
+ int i = *length;
+ int j = *length + number_length - 1;
+ while (i < j) {
+ char tmp = buffer[i];
+ buffer[i] = buffer[j];
+ buffer[j] = tmp;
+ i++;
+ j--;
+ }
+ *length += number_length;
+}
+
+
+static void FillDigits64FixedLength(uint64_t number,
+ Vector<char> buffer, int* length) {
+ const uint32_t kTen7 = 10000000;
+ // For efficiency cut the number into 3 uint32_t parts, and print those.
+ uint32_t part2 = static_cast<uint32_t>(number % kTen7);
+ number /= kTen7;
+ uint32_t part1 = static_cast<uint32_t>(number % kTen7);
+ uint32_t part0 = static_cast<uint32_t>(number / kTen7);
+
+ FillDigits32FixedLength(part0, 3, buffer, length);
+ FillDigits32FixedLength(part1, 7, buffer, length);
+ FillDigits32FixedLength(part2, 7, buffer, length);
+}
+
+
+static void FillDigits64(uint64_t number, Vector<char> buffer, int* length) {
+ const uint32_t kTen7 = 10000000;
+ // For efficiency cut the number into 3 uint32_t parts, and print those.
+ uint32_t part2 = static_cast<uint32_t>(number % kTen7);
+ number /= kTen7;
+ uint32_t part1 = static_cast<uint32_t>(number % kTen7);
+ uint32_t part0 = static_cast<uint32_t>(number / kTen7);
+
+ if (part0 != 0) {
+ FillDigits32(part0, buffer, length);
+ FillDigits32FixedLength(part1, 7, buffer, length);
+ FillDigits32FixedLength(part2, 7, buffer, length);
+ } else if (part1 != 0) {
+ FillDigits32(part1, buffer, length);
+ FillDigits32FixedLength(part2, 7, buffer, length);
+ } else {
+ FillDigits32(part2, buffer, length);
+ }
+}
+
+
+static void RoundUp(Vector<char> buffer, int* length, int* decimal_point) {
+ // An empty buffer represents 0.
+ if (*length == 0) {
+ buffer[0] = '1';
+ *decimal_point = 1;
+ *length = 1;
+ return;
+ }
+ // Round the last digit until we either have a digit that was not '9' or until
+ // we reached the first digit.
+ buffer[(*length) - 1]++;
+ for (int i = (*length) - 1; i > 0; --i) {
+ if (buffer[i] != '0' + 10) {
+ return;
+ }
+ buffer[i] = '0';
+ buffer[i - 1]++;
+ }
+ // If the first digit is now '0' + 10, we would need to set it to '0' and add
+ // a '1' in front. However we reach the first digit only if all following
+ // digits had been '9' before rounding up. Now all trailing digits are '0' and
+ // we simply switch the first digit to '1' and update the decimal-point
+ // (indicating that the point is now one digit to the right).
+ if (buffer[0] == '0' + 10) {
+ buffer[0] = '1';
+ (*decimal_point)++;
+ }
+}
+
+
+// The given fractionals number represents a fixed-point number with binary
+// point at bit (-exponent).
+// Preconditions:
+// -128 <= exponent <= 0.
+// 0 <= fractionals * 2^exponent < 1
+// The buffer holds the result.
+// The function will round its result. During the rounding-process digits not
+// generated by this function might be updated, and the decimal-point variable
+// might be updated. If this function generates the digits 99 and the buffer
+// already contained "199" (thus yielding a buffer of "19999") then a
+// rounding-up will change the contents of the buffer to "20000".
+static void FillFractionals(uint64_t fractionals, int exponent,
+ int fractional_count, Vector<char> buffer,
+ int* length, int* decimal_point) {
+ ASSERT(-128 <= exponent && exponent <= 0);
+ // 'fractionals' is a fixed-point number, with binary point at bit
+ // (-exponent). Inside the function the non-converted remainder of fractionals
+ // is a fixed-point number, with binary point at bit 'point'.
+ if (-exponent <= 64) {
+ // One 64 bit number is sufficient.
+ ASSERT(fractionals >> 56 == 0);
+ int point = -exponent;
+ for (int i = 0; i < fractional_count; ++i) {
+ if (fractionals == 0) break;
+ // Instead of multiplying by 10 we multiply by 5 and adjust the point
+ // location. This way the fractionals variable will not overflow.
+ // Invariant at the beginning of the loop: fractionals < 2^point.
+ // Initially we have: point <= 64 and fractionals < 2^56
+ // After each iteration the point is decremented by one.
+ // Note that 5^3 = 125 < 128 = 2^7.
+ // Therefore three iterations of this loop will not overflow fractionals
+ // (even without the subtraction at the end of the loop body). At this
+ // time point will satisfy point <= 61 and therefore fractionals < 2^point
+ // and any further multiplication of fractionals by 5 will not overflow.
+ fractionals *= 5;
+ point--;
+ int digit = static_cast<int>(fractionals >> point);
+ ASSERT(digit <= 9);
+ buffer[*length] = static_cast<char>('0' + digit);
+ (*length)++;
+ fractionals -= static_cast<uint64_t>(digit) << point;
+ }
+ // If the first bit after the point is set we have to round up.
+ ASSERT(fractionals == 0 || point - 1 >= 0);
+ if ((fractionals != 0) && ((fractionals >> (point - 1)) & 1) == 1) {
+ RoundUp(buffer, length, decimal_point);
+ }
+ } else { // We need 128 bits.
+ ASSERT(64 < -exponent && -exponent <= 128);
+ UInt128 fractionals128 = UInt128(fractionals, 0);
+ fractionals128.Shift(-exponent - 64);
+ int point = 128;
+ for (int i = 0; i < fractional_count; ++i) {
+ if (fractionals128.IsZero()) break;
+ // As before: instead of multiplying by 10 we multiply by 5 and adjust the
+ // point location.
+ // This multiplication will not overflow for the same reasons as before.
+ fractionals128.Multiply(5);
+ point--;
+ int digit = fractionals128.DivModPowerOf2(point);
+ ASSERT(digit <= 9);
+ buffer[*length] = static_cast<char>('0' + digit);
+ (*length)++;
+ }
+ if (fractionals128.BitAt(point - 1) == 1) {
+ RoundUp(buffer, length, decimal_point);
+ }
+ }
+}
+
+
+// Removes leading and trailing zeros.
+// If leading zeros are removed then the decimal point position is adjusted.
+static void TrimZeros(Vector<char> buffer, int* length, int* decimal_point) {
+ while (*length > 0 && buffer[(*length) - 1] == '0') {
+ (*length)--;
+ }
+ int first_non_zero = 0;
+ while (first_non_zero < *length && buffer[first_non_zero] == '0') {
+ first_non_zero++;
+ }
+ if (first_non_zero != 0) {
+ for (int i = first_non_zero; i < *length; ++i) {
+ buffer[i - first_non_zero] = buffer[i];
+ }
+ *length -= first_non_zero;
+ *decimal_point -= first_non_zero;
+ }
+}
+
+
+bool FastFixedDtoa(double v,
+ int fractional_count,
+ Vector<char> buffer,
+ int* length,
+ int* decimal_point) {
+ const uint32_t kMaxUInt32 = 0xFFFFFFFF;
+ uint64_t significand = Double(v).Significand();
+ int exponent = Double(v).Exponent();
+ // v = significand * 2^exponent (with significand a 53bit integer).
+ // If the exponent is larger than 20 (i.e. we may have a 73bit number) then we
+ // don't know how to compute the representation. 2^73 ~= 9.5*10^21.
+ // If necessary this limit could probably be increased, but we don't need
+ // more.
+ if (exponent > 20) return false;
+ if (fractional_count > 20) return false;
+ *length = 0;
+ // At most kDoubleSignificandSize bits of the significand are non-zero.
+ // Given a 64 bit integer we have 11 0s followed by 53 potentially non-zero
+ // bits: 0..11*..0xxx..53*..xx
+ if (exponent + kDoubleSignificandSize > 64) {
+ // The exponent must be > 11.
+ //
+ // We know that v = significand * 2^exponent.
+ // And the exponent > 11.
+ // We simplify the task by dividing v by 10^17.
+ // The quotient delivers the first digits, and the remainder fits into a 64
+ // bit number.
+ // Dividing by 10^17 is equivalent to dividing by 5^17*2^17.
+ const uint64_t kFive17 = UINT64_2PART_C(0xB1, A2BC2EC5); // 5^17
+ uint64_t divisor = kFive17;
+ int divisor_power = 17;
+ uint64_t dividend = significand;
+ uint32_t quotient;
+ uint64_t remainder;
+ // Let v = f * 2^e with f == significand and e == exponent.
+ // Then need q (quotient) and r (remainder) as follows:
+ // v = q * 10^17 + r
+ // f * 2^e = q * 10^17 + r
+ // f * 2^e = q * 5^17 * 2^17 + r
+ // If e > 17 then
+ // f * 2^(e-17) = q * 5^17 + r/2^17
+ // else
+ // f = q * 5^17 * 2^(17-e) + r/2^e
+ if (exponent > divisor_power) {
+ // We only allow exponents of up to 20 and therefore (17 - e) <= 3
+ dividend <<= exponent - divisor_power;
+ quotient = static_cast<uint32_t>(dividend / divisor);
+ remainder = (dividend % divisor) << divisor_power;
+ } else {
+ divisor <<= divisor_power - exponent;
+ quotient = static_cast<uint32_t>(dividend / divisor);
+ remainder = (dividend % divisor) << exponent;
+ }
+ FillDigits32(quotient, buffer, length);
+ FillDigits64FixedLength(remainder, buffer, length);
+ *decimal_point = *length;
+ } else if (exponent >= 0) {
+ // 0 <= exponent <= 11
+ significand <<= exponent;
+ FillDigits64(significand, buffer, length);
+ *decimal_point = *length;
+ } else if (exponent > -kDoubleSignificandSize) {
+ // We have to cut the number.
+ uint64_t integrals = significand >> -exponent;
+ uint64_t fractionals = significand - (integrals << -exponent);
+ if (integrals > kMaxUInt32) {
+ FillDigits64(integrals, buffer, length);
+ } else {
+ FillDigits32(static_cast<uint32_t>(integrals), buffer, length);
+ }
+ *decimal_point = *length;
+ FillFractionals(fractionals, exponent, fractional_count,
+ buffer, length, decimal_point);
+ } else if (exponent < -128) {
+ // This configuration (with at most 20 digits) means that all digits must be
+ // 0.
+ ASSERT(fractional_count <= 20);
+ buffer[0] = '\0';
+ *length = 0;
+ *decimal_point = -fractional_count;
+ } else {
+ *decimal_point = 0;
+ FillFractionals(significand, exponent, fractional_count,
+ buffer, length, decimal_point);
+ }
+ TrimZeros(buffer, length, decimal_point);
+ buffer[*length] = '\0';
+ if ((*length) == 0) {
+ // The string is empty and the decimal_point thus has no importance. Mimick
+ // Gay's dtoa and and set it to -fractional_count.
+ *decimal_point = -fractional_count;
+ }
+ return true;
+}
+
+} // namespace double_conversion
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/fixed-dtoa.h b/src/arrow/cpp/src/arrow/vendored/double-conversion/fixed-dtoa.h
new file mode 100644
index 000000000..3bdd08e21
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/fixed-dtoa.h
@@ -0,0 +1,56 @@
+// Copyright 2010 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef DOUBLE_CONVERSION_FIXED_DTOA_H_
+#define DOUBLE_CONVERSION_FIXED_DTOA_H_
+
+#include "utils.h"
+
+namespace double_conversion {
+
+// Produces digits necessary to print a given number with
+// 'fractional_count' digits after the decimal point.
+// The buffer must be big enough to hold the result plus one terminating null
+// character.
+//
+// The produced digits might be too short in which case the caller has to fill
+// the gaps with '0's.
+// Example: FastFixedDtoa(0.001, 5, ...) is allowed to return buffer = "1", and
+// decimal_point = -2.
+// Halfway cases are rounded towards +/-Infinity (away from 0). The call
+// FastFixedDtoa(0.15, 2, ...) thus returns buffer = "2", decimal_point = 0.
+// The returned buffer may contain digits that would be truncated from the
+// shortest representation of the input.
+//
+// This method only works for some parameters. If it can't handle the input it
+// returns false. The output is null-terminated when the function succeeds.
+bool FastFixedDtoa(double v, int fractional_count,
+ Vector<char> buffer, int* length, int* decimal_point);
+
+} // namespace double_conversion
+
+#endif // DOUBLE_CONVERSION_FIXED_DTOA_H_
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/ieee.h b/src/arrow/cpp/src/arrow/vendored/double-conversion/ieee.h
new file mode 100644
index 000000000..83274849b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/ieee.h
@@ -0,0 +1,402 @@
+// Copyright 2012 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef DOUBLE_CONVERSION_DOUBLE_H_
+#define DOUBLE_CONVERSION_DOUBLE_H_
+
+#include "diy-fp.h"
+
+namespace double_conversion {
+
+// We assume that doubles and uint64_t have the same endianness.
+static uint64_t double_to_uint64(double d) { return BitCast<uint64_t>(d); }
+static double uint64_to_double(uint64_t d64) { return BitCast<double>(d64); }
+static uint32_t float_to_uint32(float f) { return BitCast<uint32_t>(f); }
+static float uint32_to_float(uint32_t d32) { return BitCast<float>(d32); }
+
+// Helper functions for doubles.
+class Double {
+ public:
+ static const uint64_t kSignMask = UINT64_2PART_C(0x80000000, 00000000);
+ static const uint64_t kExponentMask = UINT64_2PART_C(0x7FF00000, 00000000);
+ static const uint64_t kSignificandMask = UINT64_2PART_C(0x000FFFFF, FFFFFFFF);
+ static const uint64_t kHiddenBit = UINT64_2PART_C(0x00100000, 00000000);
+ static const int kPhysicalSignificandSize = 52; // Excludes the hidden bit.
+ static const int kSignificandSize = 53;
+ static const int kExponentBias = 0x3FF + kPhysicalSignificandSize;
+ static const int kMaxExponent = 0x7FF - kExponentBias;
+
+ Double() : d64_(0) {}
+ explicit Double(double d) : d64_(double_to_uint64(d)) {}
+ explicit Double(uint64_t d64) : d64_(d64) {}
+ explicit Double(DiyFp diy_fp)
+ : d64_(DiyFpToUint64(diy_fp)) {}
+
+ // The value encoded by this Double must be greater or equal to +0.0.
+ // It must not be special (infinity, or NaN).
+ DiyFp AsDiyFp() const {
+ ASSERT(Sign() > 0);
+ ASSERT(!IsSpecial());
+ return DiyFp(Significand(), Exponent());
+ }
+
+ // The value encoded by this Double must be strictly greater than 0.
+ DiyFp AsNormalizedDiyFp() const {
+ ASSERT(value() > 0.0);
+ uint64_t f = Significand();
+ int e = Exponent();
+
+ // The current double could be a denormal.
+ while ((f & kHiddenBit) == 0) {
+ f <<= 1;
+ e--;
+ }
+ // Do the final shifts in one go.
+ f <<= DiyFp::kSignificandSize - kSignificandSize;
+ e -= DiyFp::kSignificandSize - kSignificandSize;
+ return DiyFp(f, e);
+ }
+
+ // Returns the double's bit as uint64.
+ uint64_t AsUint64() const {
+ return d64_;
+ }
+
+ // Returns the next greater double. Returns +infinity on input +infinity.
+ double NextDouble() const {
+ if (d64_ == kInfinity) return Double(kInfinity).value();
+ if (Sign() < 0 && Significand() == 0) {
+ // -0.0
+ return 0.0;
+ }
+ if (Sign() < 0) {
+ return Double(d64_ - 1).value();
+ } else {
+ return Double(d64_ + 1).value();
+ }
+ }
+
+ double PreviousDouble() const {
+ if (d64_ == (kInfinity | kSignMask)) return -Infinity();
+ if (Sign() < 0) {
+ return Double(d64_ + 1).value();
+ } else {
+ if (Significand() == 0) return -0.0;
+ return Double(d64_ - 1).value();
+ }
+ }
+
+ int Exponent() const {
+ if (IsDenormal()) return kDenormalExponent;
+
+ uint64_t d64 = AsUint64();
+ int biased_e =
+ static_cast<int>((d64 & kExponentMask) >> kPhysicalSignificandSize);
+ return biased_e - kExponentBias;
+ }
+
+ uint64_t Significand() const {
+ uint64_t d64 = AsUint64();
+ uint64_t significand = d64 & kSignificandMask;
+ if (!IsDenormal()) {
+ return significand + kHiddenBit;
+ } else {
+ return significand;
+ }
+ }
+
+ // Returns true if the double is a denormal.
+ bool IsDenormal() const {
+ uint64_t d64 = AsUint64();
+ return (d64 & kExponentMask) == 0;
+ }
+
+ // We consider denormals not to be special.
+ // Hence only Infinity and NaN are special.
+ bool IsSpecial() const {
+ uint64_t d64 = AsUint64();
+ return (d64 & kExponentMask) == kExponentMask;
+ }
+
+ bool IsNan() const {
+ uint64_t d64 = AsUint64();
+ return ((d64 & kExponentMask) == kExponentMask) &&
+ ((d64 & kSignificandMask) != 0);
+ }
+
+ bool IsInfinite() const {
+ uint64_t d64 = AsUint64();
+ return ((d64 & kExponentMask) == kExponentMask) &&
+ ((d64 & kSignificandMask) == 0);
+ }
+
+ int Sign() const {
+ uint64_t d64 = AsUint64();
+ return (d64 & kSignMask) == 0? 1: -1;
+ }
+
+ // Precondition: the value encoded by this Double must be greater or equal
+ // than +0.0.
+ DiyFp UpperBoundary() const {
+ ASSERT(Sign() > 0);
+ return DiyFp(Significand() * 2 + 1, Exponent() - 1);
+ }
+
+ // Computes the two boundaries of this.
+ // The bigger boundary (m_plus) is normalized. The lower boundary has the same
+ // exponent as m_plus.
+ // Precondition: the value encoded by this Double must be greater than 0.
+ void NormalizedBoundaries(DiyFp* out_m_minus, DiyFp* out_m_plus) const {
+ ASSERT(value() > 0.0);
+ DiyFp v = this->AsDiyFp();
+ DiyFp m_plus = DiyFp::Normalize(DiyFp((v.f() << 1) + 1, v.e() - 1));
+ DiyFp m_minus;
+ if (LowerBoundaryIsCloser()) {
+ m_minus = DiyFp((v.f() << 2) - 1, v.e() - 2);
+ } else {
+ m_minus = DiyFp((v.f() << 1) - 1, v.e() - 1);
+ }
+ m_minus.set_f(m_minus.f() << (m_minus.e() - m_plus.e()));
+ m_minus.set_e(m_plus.e());
+ *out_m_plus = m_plus;
+ *out_m_minus = m_minus;
+ }
+
+ bool LowerBoundaryIsCloser() const {
+ // The boundary is closer if the significand is of the form f == 2^p-1 then
+ // the lower boundary is closer.
+ // Think of v = 1000e10 and v- = 9999e9.
+ // Then the boundary (== (v - v-)/2) is not just at a distance of 1e9 but
+ // at a distance of 1e8.
+ // The only exception is for the smallest normal: the largest denormal is
+ // at the same distance as its successor.
+ // Note: denormals have the same exponent as the smallest normals.
+ bool physical_significand_is_zero = ((AsUint64() & kSignificandMask) == 0);
+ return physical_significand_is_zero && (Exponent() != kDenormalExponent);
+ }
+
+ double value() const { return uint64_to_double(d64_); }
+
+ // Returns the significand size for a given order of magnitude.
+ // If v = f*2^e with 2^p-1 <= f <= 2^p then p+e is v's order of magnitude.
+ // This function returns the number of significant binary digits v will have
+ // once it's encoded into a double. In almost all cases this is equal to
+ // kSignificandSize. The only exceptions are denormals. They start with
+ // leading zeroes and their effective significand-size is hence smaller.
+ static int SignificandSizeForOrderOfMagnitude(int order) {
+ if (order >= (kDenormalExponent + kSignificandSize)) {
+ return kSignificandSize;
+ }
+ if (order <= kDenormalExponent) return 0;
+ return order - kDenormalExponent;
+ }
+
+ static double Infinity() {
+ return Double(kInfinity).value();
+ }
+
+ static double NaN() {
+ return Double(kNaN).value();
+ }
+
+ private:
+ static const int kDenormalExponent = -kExponentBias + 1;
+ static const uint64_t kInfinity = UINT64_2PART_C(0x7FF00000, 00000000);
+ static const uint64_t kNaN = UINT64_2PART_C(0x7FF80000, 00000000);
+
+ const uint64_t d64_;
+
+ static uint64_t DiyFpToUint64(DiyFp diy_fp) {
+ uint64_t significand = diy_fp.f();
+ int exponent = diy_fp.e();
+ while (significand > kHiddenBit + kSignificandMask) {
+ significand >>= 1;
+ exponent++;
+ }
+ if (exponent >= kMaxExponent) {
+ return kInfinity;
+ }
+ if (exponent < kDenormalExponent) {
+ return 0;
+ }
+ while (exponent > kDenormalExponent && (significand & kHiddenBit) == 0) {
+ significand <<= 1;
+ exponent--;
+ }
+ uint64_t biased_exponent;
+ if (exponent == kDenormalExponent && (significand & kHiddenBit) == 0) {
+ biased_exponent = 0;
+ } else {
+ biased_exponent = static_cast<uint64_t>(exponent + kExponentBias);
+ }
+ return (significand & kSignificandMask) |
+ (biased_exponent << kPhysicalSignificandSize);
+ }
+
+ DC_DISALLOW_COPY_AND_ASSIGN(Double);
+};
+
+class Single {
+ public:
+ static const uint32_t kSignMask = 0x80000000;
+ static const uint32_t kExponentMask = 0x7F800000;
+ static const uint32_t kSignificandMask = 0x007FFFFF;
+ static const uint32_t kHiddenBit = 0x00800000;
+ static const int kPhysicalSignificandSize = 23; // Excludes the hidden bit.
+ static const int kSignificandSize = 24;
+
+ Single() : d32_(0) {}
+ explicit Single(float f) : d32_(float_to_uint32(f)) {}
+ explicit Single(uint32_t d32) : d32_(d32) {}
+
+ // The value encoded by this Single must be greater or equal to +0.0.
+ // It must not be special (infinity, or NaN).
+ DiyFp AsDiyFp() const {
+ ASSERT(Sign() > 0);
+ ASSERT(!IsSpecial());
+ return DiyFp(Significand(), Exponent());
+ }
+
+ // Returns the single's bit as uint64.
+ uint32_t AsUint32() const {
+ return d32_;
+ }
+
+ int Exponent() const {
+ if (IsDenormal()) return kDenormalExponent;
+
+ uint32_t d32 = AsUint32();
+ int biased_e =
+ static_cast<int>((d32 & kExponentMask) >> kPhysicalSignificandSize);
+ return biased_e - kExponentBias;
+ }
+
+ uint32_t Significand() const {
+ uint32_t d32 = AsUint32();
+ uint32_t significand = d32 & kSignificandMask;
+ if (!IsDenormal()) {
+ return significand + kHiddenBit;
+ } else {
+ return significand;
+ }
+ }
+
+ // Returns true if the single is a denormal.
+ bool IsDenormal() const {
+ uint32_t d32 = AsUint32();
+ return (d32 & kExponentMask) == 0;
+ }
+
+ // We consider denormals not to be special.
+ // Hence only Infinity and NaN are special.
+ bool IsSpecial() const {
+ uint32_t d32 = AsUint32();
+ return (d32 & kExponentMask) == kExponentMask;
+ }
+
+ bool IsNan() const {
+ uint32_t d32 = AsUint32();
+ return ((d32 & kExponentMask) == kExponentMask) &&
+ ((d32 & kSignificandMask) != 0);
+ }
+
+ bool IsInfinite() const {
+ uint32_t d32 = AsUint32();
+ return ((d32 & kExponentMask) == kExponentMask) &&
+ ((d32 & kSignificandMask) == 0);
+ }
+
+ int Sign() const {
+ uint32_t d32 = AsUint32();
+ return (d32 & kSignMask) == 0? 1: -1;
+ }
+
+ // Computes the two boundaries of this.
+ // The bigger boundary (m_plus) is normalized. The lower boundary has the same
+ // exponent as m_plus.
+ // Precondition: the value encoded by this Single must be greater than 0.
+ void NormalizedBoundaries(DiyFp* out_m_minus, DiyFp* out_m_plus) const {
+ ASSERT(value() > 0.0);
+ DiyFp v = this->AsDiyFp();
+ DiyFp m_plus = DiyFp::Normalize(DiyFp((v.f() << 1) + 1, v.e() - 1));
+ DiyFp m_minus;
+ if (LowerBoundaryIsCloser()) {
+ m_minus = DiyFp((v.f() << 2) - 1, v.e() - 2);
+ } else {
+ m_minus = DiyFp((v.f() << 1) - 1, v.e() - 1);
+ }
+ m_minus.set_f(m_minus.f() << (m_minus.e() - m_plus.e()));
+ m_minus.set_e(m_plus.e());
+ *out_m_plus = m_plus;
+ *out_m_minus = m_minus;
+ }
+
+ // Precondition: the value encoded by this Single must be greater or equal
+ // than +0.0.
+ DiyFp UpperBoundary() const {
+ ASSERT(Sign() > 0);
+ return DiyFp(Significand() * 2 + 1, Exponent() - 1);
+ }
+
+ bool LowerBoundaryIsCloser() const {
+ // The boundary is closer if the significand is of the form f == 2^p-1 then
+ // the lower boundary is closer.
+ // Think of v = 1000e10 and v- = 9999e9.
+ // Then the boundary (== (v - v-)/2) is not just at a distance of 1e9 but
+ // at a distance of 1e8.
+ // The only exception is for the smallest normal: the largest denormal is
+ // at the same distance as its successor.
+ // Note: denormals have the same exponent as the smallest normals.
+ bool physical_significand_is_zero = ((AsUint32() & kSignificandMask) == 0);
+ return physical_significand_is_zero && (Exponent() != kDenormalExponent);
+ }
+
+ float value() const { return uint32_to_float(d32_); }
+
+ static float Infinity() {
+ return Single(kInfinity).value();
+ }
+
+ static float NaN() {
+ return Single(kNaN).value();
+ }
+
+ private:
+ static const int kExponentBias = 0x7F + kPhysicalSignificandSize;
+ static const int kDenormalExponent = -kExponentBias + 1;
+ static const int kMaxExponent = 0xFF - kExponentBias;
+ static const uint32_t kInfinity = 0x7F800000;
+ static const uint32_t kNaN = 0x7FC00000;
+
+ const uint32_t d32_;
+
+ DC_DISALLOW_COPY_AND_ASSIGN(Single);
+};
+
+} // namespace double_conversion
+
+#endif // DOUBLE_CONVERSION_DOUBLE_H_
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/strtod.cc b/src/arrow/cpp/src/arrow/vendored/double-conversion/strtod.cc
new file mode 100644
index 000000000..a75cf5d9f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/strtod.cc
@@ -0,0 +1,580 @@
+// Copyright 2010 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <climits>
+#include <cstdarg>
+
+#include "bignum.h"
+#include "cached-powers.h"
+#include "ieee.h"
+#include "strtod.h"
+
+namespace double_conversion {
+
+// 2^53 = 9007199254740992.
+// Any integer with at most 15 decimal digits will hence fit into a double
+// (which has a 53bit significand) without loss of precision.
+static const int kMaxExactDoubleIntegerDecimalDigits = 15;
+// 2^64 = 18446744073709551616 > 10^19
+static const int kMaxUint64DecimalDigits = 19;
+
+// Max double: 1.7976931348623157 x 10^308
+// Min non-zero double: 4.9406564584124654 x 10^-324
+// Any x >= 10^309 is interpreted as +infinity.
+// Any x <= 10^-324 is interpreted as 0.
+// Note that 2.5e-324 (despite being smaller than the min double) will be read
+// as non-zero (equal to the min non-zero double).
+static const int kMaxDecimalPower = 309;
+static const int kMinDecimalPower = -324;
+
+// 2^64 = 18446744073709551616
+static const uint64_t kMaxUint64 = UINT64_2PART_C(0xFFFFFFFF, FFFFFFFF);
+
+
+static const double exact_powers_of_ten[] = {
+ 1.0, // 10^0
+ 10.0,
+ 100.0,
+ 1000.0,
+ 10000.0,
+ 100000.0,
+ 1000000.0,
+ 10000000.0,
+ 100000000.0,
+ 1000000000.0,
+ 10000000000.0, // 10^10
+ 100000000000.0,
+ 1000000000000.0,
+ 10000000000000.0,
+ 100000000000000.0,
+ 1000000000000000.0,
+ 10000000000000000.0,
+ 100000000000000000.0,
+ 1000000000000000000.0,
+ 10000000000000000000.0,
+ 100000000000000000000.0, // 10^20
+ 1000000000000000000000.0,
+ // 10^22 = 0x21e19e0c9bab2400000 = 0x878678326eac9 * 2^22
+ 10000000000000000000000.0
+};
+static const int kExactPowersOfTenSize = ARRAY_SIZE(exact_powers_of_ten);
+
+// Maximum number of significant digits in the decimal representation.
+// In fact the value is 772 (see conversions.cc), but to give us some margin
+// we round up to 780.
+static const int kMaxSignificantDecimalDigits = 780;
+
+static Vector<const char> TrimLeadingZeros(Vector<const char> buffer) {
+ for (int i = 0; i < buffer.length(); i++) {
+ if (buffer[i] != '0') {
+ return buffer.SubVector(i, buffer.length());
+ }
+ }
+ return Vector<const char>(buffer.start(), 0);
+}
+
+
+static Vector<const char> TrimTrailingZeros(Vector<const char> buffer) {
+ for (int i = buffer.length() - 1; i >= 0; --i) {
+ if (buffer[i] != '0') {
+ return buffer.SubVector(0, i + 1);
+ }
+ }
+ return Vector<const char>(buffer.start(), 0);
+}
+
+
+static void CutToMaxSignificantDigits(Vector<const char> buffer,
+ int exponent,
+ char* significant_buffer,
+ int* significant_exponent) {
+ for (int i = 0; i < kMaxSignificantDecimalDigits - 1; ++i) {
+ significant_buffer[i] = buffer[i];
+ }
+ // The input buffer has been trimmed. Therefore the last digit must be
+ // different from '0'.
+ ASSERT(buffer[buffer.length() - 1] != '0');
+ // Set the last digit to be non-zero. This is sufficient to guarantee
+ // correct rounding.
+ significant_buffer[kMaxSignificantDecimalDigits - 1] = '1';
+ *significant_exponent =
+ exponent + (buffer.length() - kMaxSignificantDecimalDigits);
+}
+
+
+// Trims the buffer and cuts it to at most kMaxSignificantDecimalDigits.
+// If possible the input-buffer is reused, but if the buffer needs to be
+// modified (due to cutting), then the input needs to be copied into the
+// buffer_copy_space.
+static void TrimAndCut(Vector<const char> buffer, int exponent,
+ char* buffer_copy_space, int space_size,
+ Vector<const char>* trimmed, int* updated_exponent) {
+ Vector<const char> left_trimmed = TrimLeadingZeros(buffer);
+ Vector<const char> right_trimmed = TrimTrailingZeros(left_trimmed);
+ exponent += left_trimmed.length() - right_trimmed.length();
+ if (right_trimmed.length() > kMaxSignificantDecimalDigits) {
+ (void) space_size; // Mark variable as used.
+ ASSERT(space_size >= kMaxSignificantDecimalDigits);
+ CutToMaxSignificantDigits(right_trimmed, exponent,
+ buffer_copy_space, updated_exponent);
+ *trimmed = Vector<const char>(buffer_copy_space,
+ kMaxSignificantDecimalDigits);
+ } else {
+ *trimmed = right_trimmed;
+ *updated_exponent = exponent;
+ }
+}
+
+
+// Reads digits from the buffer and converts them to a uint64.
+// Reads in as many digits as fit into a uint64.
+// When the string starts with "1844674407370955161" no further digit is read.
+// Since 2^64 = 18446744073709551616 it would still be possible read another
+// digit if it was less or equal than 6, but this would complicate the code.
+static uint64_t ReadUint64(Vector<const char> buffer,
+ int* number_of_read_digits) {
+ uint64_t result = 0;
+ int i = 0;
+ while (i < buffer.length() && result <= (kMaxUint64 / 10 - 1)) {
+ int digit = buffer[i++] - '0';
+ ASSERT(0 <= digit && digit <= 9);
+ result = 10 * result + digit;
+ }
+ *number_of_read_digits = i;
+ return result;
+}
+
+
+// Reads a DiyFp from the buffer.
+// The returned DiyFp is not necessarily normalized.
+// If remaining_decimals is zero then the returned DiyFp is accurate.
+// Otherwise it has been rounded and has error of at most 1/2 ulp.
+static void ReadDiyFp(Vector<const char> buffer,
+ DiyFp* result,
+ int* remaining_decimals) {
+ int read_digits;
+ uint64_t significand = ReadUint64(buffer, &read_digits);
+ if (buffer.length() == read_digits) {
+ *result = DiyFp(significand, 0);
+ *remaining_decimals = 0;
+ } else {
+ // Round the significand.
+ if (buffer[read_digits] >= '5') {
+ significand++;
+ }
+ // Compute the binary exponent.
+ int exponent = 0;
+ *result = DiyFp(significand, exponent);
+ *remaining_decimals = buffer.length() - read_digits;
+ }
+}
+
+
+static bool DoubleStrtod(Vector<const char> trimmed,
+ int exponent,
+ double* result) {
+#if !defined(DOUBLE_CONVERSION_CORRECT_DOUBLE_OPERATIONS)
+ // On x86 the floating-point stack can be 64 or 80 bits wide. If it is
+ // 80 bits wide (as is the case on Linux) then double-rounding occurs and the
+ // result is not accurate.
+ // We know that Windows32 uses 64 bits and is therefore accurate.
+ // Note that the ARM simulator is compiled for 32bits. It therefore exhibits
+ // the same problem.
+ return false;
+#else
+ if (trimmed.length() <= kMaxExactDoubleIntegerDecimalDigits) {
+ int read_digits;
+ // The trimmed input fits into a double.
+ // If the 10^exponent (resp. 10^-exponent) fits into a double too then we
+ // can compute the result-double simply by multiplying (resp. dividing) the
+ // two numbers.
+ // This is possible because IEEE guarantees that floating-point operations
+ // return the best possible approximation.
+ if (exponent < 0 && -exponent < kExactPowersOfTenSize) {
+ // 10^-exponent fits into a double.
+ *result = static_cast<double>(ReadUint64(trimmed, &read_digits));
+ ASSERT(read_digits == trimmed.length());
+ *result /= exact_powers_of_ten[-exponent];
+ return true;
+ }
+ if (0 <= exponent && exponent < kExactPowersOfTenSize) {
+ // 10^exponent fits into a double.
+ *result = static_cast<double>(ReadUint64(trimmed, &read_digits));
+ ASSERT(read_digits == trimmed.length());
+ *result *= exact_powers_of_ten[exponent];
+ return true;
+ }
+ int remaining_digits =
+ kMaxExactDoubleIntegerDecimalDigits - trimmed.length();
+ if ((0 <= exponent) &&
+ (exponent - remaining_digits < kExactPowersOfTenSize)) {
+ // The trimmed string was short and we can multiply it with
+ // 10^remaining_digits. As a result the remaining exponent now fits
+ // into a double too.
+ *result = static_cast<double>(ReadUint64(trimmed, &read_digits));
+ ASSERT(read_digits == trimmed.length());
+ *result *= exact_powers_of_ten[remaining_digits];
+ *result *= exact_powers_of_ten[exponent - remaining_digits];
+ return true;
+ }
+ }
+ return false;
+#endif
+}
+
+
+// Returns 10^exponent as an exact DiyFp.
+// The given exponent must be in the range [1; kDecimalExponentDistance[.
+static DiyFp AdjustmentPowerOfTen(int exponent) {
+ ASSERT(0 < exponent);
+ ASSERT(exponent < PowersOfTenCache::kDecimalExponentDistance);
+ // Simply hardcode the remaining powers for the given decimal exponent
+ // distance.
+ ASSERT(PowersOfTenCache::kDecimalExponentDistance == 8);
+ switch (exponent) {
+ case 1: return DiyFp(UINT64_2PART_C(0xa0000000, 00000000), -60);
+ case 2: return DiyFp(UINT64_2PART_C(0xc8000000, 00000000), -57);
+ case 3: return DiyFp(UINT64_2PART_C(0xfa000000, 00000000), -54);
+ case 4: return DiyFp(UINT64_2PART_C(0x9c400000, 00000000), -50);
+ case 5: return DiyFp(UINT64_2PART_C(0xc3500000, 00000000), -47);
+ case 6: return DiyFp(UINT64_2PART_C(0xf4240000, 00000000), -44);
+ case 7: return DiyFp(UINT64_2PART_C(0x98968000, 00000000), -40);
+ default:
+ UNREACHABLE();
+ }
+}
+
+
+// If the function returns true then the result is the correct double.
+// Otherwise it is either the correct double or the double that is just below
+// the correct double.
+static bool DiyFpStrtod(Vector<const char> buffer,
+ int exponent,
+ double* result) {
+ DiyFp input;
+ int remaining_decimals;
+ ReadDiyFp(buffer, &input, &remaining_decimals);
+ // Since we may have dropped some digits the input is not accurate.
+ // If remaining_decimals is different than 0 than the error is at most
+ // .5 ulp (unit in the last place).
+ // We don't want to deal with fractions and therefore keep a common
+ // denominator.
+ const int kDenominatorLog = 3;
+ const int kDenominator = 1 << kDenominatorLog;
+ // Move the remaining decimals into the exponent.
+ exponent += remaining_decimals;
+ uint64_t error = (remaining_decimals == 0 ? 0 : kDenominator / 2);
+
+ int old_e = input.e();
+ input.Normalize();
+ error <<= old_e - input.e();
+
+ ASSERT(exponent <= PowersOfTenCache::kMaxDecimalExponent);
+ if (exponent < PowersOfTenCache::kMinDecimalExponent) {
+ *result = 0.0;
+ return true;
+ }
+ DiyFp cached_power;
+ int cached_decimal_exponent;
+ PowersOfTenCache::GetCachedPowerForDecimalExponent(exponent,
+ &cached_power,
+ &cached_decimal_exponent);
+
+ if (cached_decimal_exponent != exponent) {
+ int adjustment_exponent = exponent - cached_decimal_exponent;
+ DiyFp adjustment_power = AdjustmentPowerOfTen(adjustment_exponent);
+ input.Multiply(adjustment_power);
+ if (kMaxUint64DecimalDigits - buffer.length() >= adjustment_exponent) {
+ // The product of input with the adjustment power fits into a 64 bit
+ // integer.
+ ASSERT(DiyFp::kSignificandSize == 64);
+ } else {
+ // The adjustment power is exact. There is hence only an error of 0.5.
+ error += kDenominator / 2;
+ }
+ }
+
+ input.Multiply(cached_power);
+ // The error introduced by a multiplication of a*b equals
+ // error_a + error_b + error_a*error_b/2^64 + 0.5
+ // Substituting a with 'input' and b with 'cached_power' we have
+ // error_b = 0.5 (all cached powers have an error of less than 0.5 ulp),
+ // error_ab = 0 or 1 / kDenominator > error_a*error_b/ 2^64
+ int error_b = kDenominator / 2;
+ int error_ab = (error == 0 ? 0 : 1); // We round up to 1.
+ int fixed_error = kDenominator / 2;
+ error += error_b + error_ab + fixed_error;
+
+ old_e = input.e();
+ input.Normalize();
+ error <<= old_e - input.e();
+
+ // See if the double's significand changes if we add/subtract the error.
+ int order_of_magnitude = DiyFp::kSignificandSize + input.e();
+ int effective_significand_size =
+ Double::SignificandSizeForOrderOfMagnitude(order_of_magnitude);
+ int precision_digits_count =
+ DiyFp::kSignificandSize - effective_significand_size;
+ if (precision_digits_count + kDenominatorLog >= DiyFp::kSignificandSize) {
+ // This can only happen for very small denormals. In this case the
+ // half-way multiplied by the denominator exceeds the range of an uint64.
+ // Simply shift everything to the right.
+ int shift_amount = (precision_digits_count + kDenominatorLog) -
+ DiyFp::kSignificandSize + 1;
+ input.set_f(input.f() >> shift_amount);
+ input.set_e(input.e() + shift_amount);
+ // We add 1 for the lost precision of error, and kDenominator for
+ // the lost precision of input.f().
+ error = (error >> shift_amount) + 1 + kDenominator;
+ precision_digits_count -= shift_amount;
+ }
+ // We use uint64_ts now. This only works if the DiyFp uses uint64_ts too.
+ ASSERT(DiyFp::kSignificandSize == 64);
+ ASSERT(precision_digits_count < 64);
+ uint64_t one64 = 1;
+ uint64_t precision_bits_mask = (one64 << precision_digits_count) - 1;
+ uint64_t precision_bits = input.f() & precision_bits_mask;
+ uint64_t half_way = one64 << (precision_digits_count - 1);
+ precision_bits *= kDenominator;
+ half_way *= kDenominator;
+ DiyFp rounded_input(input.f() >> precision_digits_count,
+ input.e() + precision_digits_count);
+ if (precision_bits >= half_way + error) {
+ rounded_input.set_f(rounded_input.f() + 1);
+ }
+ // If the last_bits are too close to the half-way case than we are too
+ // inaccurate and round down. In this case we return false so that we can
+ // fall back to a more precise algorithm.
+
+ *result = Double(rounded_input).value();
+ if (half_way - error < precision_bits && precision_bits < half_way + error) {
+ // Too imprecise. The caller will have to fall back to a slower version.
+ // However the returned number is guaranteed to be either the correct
+ // double, or the next-lower double.
+ return false;
+ } else {
+ return true;
+ }
+}
+
+
+// Returns
+// - -1 if buffer*10^exponent < diy_fp.
+// - 0 if buffer*10^exponent == diy_fp.
+// - +1 if buffer*10^exponent > diy_fp.
+// Preconditions:
+// buffer.length() + exponent <= kMaxDecimalPower + 1
+// buffer.length() + exponent > kMinDecimalPower
+// buffer.length() <= kMaxDecimalSignificantDigits
+static int CompareBufferWithDiyFp(Vector<const char> buffer,
+ int exponent,
+ DiyFp diy_fp) {
+ ASSERT(buffer.length() + exponent <= kMaxDecimalPower + 1);
+ ASSERT(buffer.length() + exponent > kMinDecimalPower);
+ ASSERT(buffer.length() <= kMaxSignificantDecimalDigits);
+ // Make sure that the Bignum will be able to hold all our numbers.
+ // Our Bignum implementation has a separate field for exponents. Shifts will
+ // consume at most one bigit (< 64 bits).
+ // ln(10) == 3.3219...
+ ASSERT(((kMaxDecimalPower + 1) * 333 / 100) < Bignum::kMaxSignificantBits);
+ Bignum buffer_bignum;
+ Bignum diy_fp_bignum;
+ buffer_bignum.AssignDecimalString(buffer);
+ diy_fp_bignum.AssignUInt64(diy_fp.f());
+ if (exponent >= 0) {
+ buffer_bignum.MultiplyByPowerOfTen(exponent);
+ } else {
+ diy_fp_bignum.MultiplyByPowerOfTen(-exponent);
+ }
+ if (diy_fp.e() > 0) {
+ diy_fp_bignum.ShiftLeft(diy_fp.e());
+ } else {
+ buffer_bignum.ShiftLeft(-diy_fp.e());
+ }
+ return Bignum::Compare(buffer_bignum, diy_fp_bignum);
+}
+
+
+// Returns true if the guess is the correct double.
+// Returns false, when guess is either correct or the next-lower double.
+static bool ComputeGuess(Vector<const char> trimmed, int exponent,
+ double* guess) {
+ if (trimmed.length() == 0) {
+ *guess = 0.0;
+ return true;
+ }
+ if (exponent + trimmed.length() - 1 >= kMaxDecimalPower) {
+ *guess = Double::Infinity();
+ return true;
+ }
+ if (exponent + trimmed.length() <= kMinDecimalPower) {
+ *guess = 0.0;
+ return true;
+ }
+
+ if (DoubleStrtod(trimmed, exponent, guess) ||
+ DiyFpStrtod(trimmed, exponent, guess)) {
+ return true;
+ }
+ if (*guess == Double::Infinity()) {
+ return true;
+ }
+ return false;
+}
+
+double Strtod(Vector<const char> buffer, int exponent) {
+ char copy_buffer[kMaxSignificantDecimalDigits];
+ Vector<const char> trimmed;
+ int updated_exponent;
+ TrimAndCut(buffer, exponent, copy_buffer, kMaxSignificantDecimalDigits,
+ &trimmed, &updated_exponent);
+ exponent = updated_exponent;
+
+ double guess;
+ bool is_correct = ComputeGuess(trimmed, exponent, &guess);
+ if (is_correct) return guess;
+
+ DiyFp upper_boundary = Double(guess).UpperBoundary();
+ int comparison = CompareBufferWithDiyFp(trimmed, exponent, upper_boundary);
+ if (comparison < 0) {
+ return guess;
+ } else if (comparison > 0) {
+ return Double(guess).NextDouble();
+ } else if ((Double(guess).Significand() & 1) == 0) {
+ // Round towards even.
+ return guess;
+ } else {
+ return Double(guess).NextDouble();
+ }
+}
+
+static float SanitizedDoubletof(double d) {
+ ASSERT(d >= 0.0);
+ // ASAN has a sanitize check that disallows casting doubles to floats if
+ // they are too big.
+ // https://clang.llvm.org/docs/UndefinedBehaviorSanitizer.html#available-checks
+ // The behavior should be covered by IEEE 754, but some projects use this
+ // flag, so work around it.
+ float max_finite = 3.4028234663852885981170418348451692544e+38;
+ // The half-way point between the max-finite and infinity value.
+ // Since infinity has an even significand everything equal or greater than
+ // this value should become infinity.
+ double half_max_finite_infinity =
+ 3.40282356779733661637539395458142568448e+38;
+ if (d >= max_finite) {
+ if (d >= half_max_finite_infinity) {
+ return Single::Infinity();
+ } else {
+ return max_finite;
+ }
+ } else {
+ return static_cast<float>(d);
+ }
+}
+
+float Strtof(Vector<const char> buffer, int exponent) {
+ char copy_buffer[kMaxSignificantDecimalDigits];
+ Vector<const char> trimmed;
+ int updated_exponent;
+ TrimAndCut(buffer, exponent, copy_buffer, kMaxSignificantDecimalDigits,
+ &trimmed, &updated_exponent);
+ exponent = updated_exponent;
+
+ double double_guess;
+ bool is_correct = ComputeGuess(trimmed, exponent, &double_guess);
+
+ float float_guess = SanitizedDoubletof(double_guess);
+ if (float_guess == double_guess) {
+ // This shortcut triggers for integer values.
+ return float_guess;
+ }
+
+ // We must catch double-rounding. Say the double has been rounded up, and is
+ // now a boundary of a float, and rounds up again. This is why we have to
+ // look at previous too.
+ // Example (in decimal numbers):
+ // input: 12349
+ // high-precision (4 digits): 1235
+ // low-precision (3 digits):
+ // when read from input: 123
+ // when rounded from high precision: 124.
+ // To do this we simply look at the neigbors of the correct result and see
+ // if they would round to the same float. If the guess is not correct we have
+ // to look at four values (since two different doubles could be the correct
+ // double).
+
+ double double_next = Double(double_guess).NextDouble();
+ double double_previous = Double(double_guess).PreviousDouble();
+
+ float f1 = SanitizedDoubletof(double_previous);
+ float f2 = float_guess;
+ float f3 = SanitizedDoubletof(double_next);
+ float f4;
+ if (is_correct) {
+ f4 = f3;
+ } else {
+ double double_next2 = Double(double_next).NextDouble();
+ f4 = SanitizedDoubletof(double_next2);
+ }
+ (void) f2; // Mark variable as used.
+ ASSERT(f1 <= f2 && f2 <= f3 && f3 <= f4);
+
+ // If the guess doesn't lie near a single-precision boundary we can simply
+ // return its float-value.
+ if (f1 == f4) {
+ return float_guess;
+ }
+
+ ASSERT((f1 != f2 && f2 == f3 && f3 == f4) ||
+ (f1 == f2 && f2 != f3 && f3 == f4) ||
+ (f1 == f2 && f2 == f3 && f3 != f4));
+
+ // guess and next are the two possible candidates (in the same way that
+ // double_guess was the lower candidate for a double-precision guess).
+ float guess = f1;
+ float next = f4;
+ DiyFp upper_boundary;
+ if (guess == 0.0f) {
+ float min_float = 1e-45f;
+ upper_boundary = Double(static_cast<double>(min_float) / 2).AsDiyFp();
+ } else {
+ upper_boundary = Single(guess).UpperBoundary();
+ }
+ int comparison = CompareBufferWithDiyFp(trimmed, exponent, upper_boundary);
+ if (comparison < 0) {
+ return guess;
+ } else if (comparison > 0) {
+ return next;
+ } else if ((Single(guess).Significand() & 1) == 0) {
+ // Round towards even.
+ return guess;
+ } else {
+ return next;
+ }
+}
+
+} // namespace double_conversion
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/strtod.h b/src/arrow/cpp/src/arrow/vendored/double-conversion/strtod.h
new file mode 100644
index 000000000..ed0293b8f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/strtod.h
@@ -0,0 +1,45 @@
+// Copyright 2010 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef DOUBLE_CONVERSION_STRTOD_H_
+#define DOUBLE_CONVERSION_STRTOD_H_
+
+#include "utils.h"
+
+namespace double_conversion {
+
+// The buffer must only contain digits in the range [0-9]. It must not
+// contain a dot or a sign. It must not start with '0', and must not be empty.
+double Strtod(Vector<const char> buffer, int exponent);
+
+// The buffer must only contain digits in the range [0-9]. It must not
+// contain a dot or a sign. It must not start with '0', and must not be empty.
+float Strtof(Vector<const char> buffer, int exponent);
+
+} // namespace double_conversion
+
+#endif // DOUBLE_CONVERSION_STRTOD_H_
diff --git a/src/arrow/cpp/src/arrow/vendored/double-conversion/utils.h b/src/arrow/cpp/src/arrow/vendored/double-conversion/utils.h
new file mode 100644
index 000000000..4328344d7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/double-conversion/utils.h
@@ -0,0 +1,367 @@
+// Copyright 2010 the V8 project authors. All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following
+// disclaimer in the documentation and/or other materials provided
+// with the distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived
+// from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef DOUBLE_CONVERSION_UTILS_H_
+#define DOUBLE_CONVERSION_UTILS_H_
+
+#include <cstdlib>
+#include <cstring>
+
+#include <cassert>
+#ifndef ASSERT
+#define ASSERT(condition) \
+ assert(condition);
+#endif
+#ifndef UNIMPLEMENTED
+#define UNIMPLEMENTED() (abort())
+#endif
+#ifndef DOUBLE_CONVERSION_NO_RETURN
+#ifdef _MSC_VER
+#define DOUBLE_CONVERSION_NO_RETURN __declspec(noreturn)
+#else
+#define DOUBLE_CONVERSION_NO_RETURN __attribute__((noreturn))
+#endif
+#endif
+#ifndef UNREACHABLE
+#ifdef _MSC_VER
+void DOUBLE_CONVERSION_NO_RETURN abort_noreturn();
+inline void abort_noreturn() { abort(); }
+#define UNREACHABLE() (abort_noreturn())
+#else
+#define UNREACHABLE() (abort())
+#endif
+#endif
+
+#ifndef DOUBLE_CONVERSION_UNUSED
+#ifdef __GNUC__
+#define DOUBLE_CONVERSION_UNUSED __attribute__((unused))
+#else
+#define DOUBLE_CONVERSION_UNUSED
+#endif
+#endif
+
+// Double operations detection based on target architecture.
+// Linux uses a 80bit wide floating point stack on x86. This induces double
+// rounding, which in turn leads to wrong results.
+// An easy way to test if the floating-point operations are correct is to
+// evaluate: 89255.0/1e22. If the floating-point stack is 64 bits wide then
+// the result is equal to 89255e-22.
+// The best way to test this, is to create a division-function and to compare
+// the output of the division with the expected result. (Inlining must be
+// disabled.)
+// On Linux,x86 89255e-22 != Div_double(89255.0/1e22)
+//
+// For example:
+/*
+// -- in div.c
+double Div_double(double x, double y) { return x / y; }
+
+// -- in main.c
+double Div_double(double x, double y); // Forward declaration.
+
+int main(int argc, char** argv) {
+ return Div_double(89255.0, 1e22) == 89255e-22;
+}
+*/
+// Run as follows ./main || echo "correct"
+//
+// If it prints "correct" then the architecture should be here, in the "correct" section.
+#if defined(_M_X64) || defined(__x86_64__) || \
+ defined(__ARMEL__) || defined(__avr32__) || defined(_M_ARM) || defined(_M_ARM64) || \
+ defined(__hppa__) || defined(__ia64__) || \
+ defined(__mips__) || \
+ defined(__powerpc__) || defined(__ppc__) || defined(__ppc64__) || \
+ defined(_POWER) || defined(_ARCH_PPC) || defined(_ARCH_PPC64) || \
+ defined(__sparc__) || defined(__sparc) || defined(__s390__) || \
+ defined(__SH4__) || defined(__alpha__) || \
+ defined(_MIPS_ARCH_MIPS32R2) || defined(__ARMEB__) ||\
+ defined(__AARCH64EL__) || defined(__aarch64__) || defined(__AARCH64EB__) || \
+ defined(__riscv) || \
+ defined(__or1k__) || defined(__arc__) || \
+ defined(__EMSCRIPTEN__)
+#define DOUBLE_CONVERSION_CORRECT_DOUBLE_OPERATIONS 1
+#elif defined(__mc68000__) || \
+ defined(__pnacl__) || defined(__native_client__)
+#undef DOUBLE_CONVERSION_CORRECT_DOUBLE_OPERATIONS
+#elif defined(_M_IX86) || defined(__i386__) || defined(__i386)
+#if defined(_WIN32)
+// Windows uses a 64bit wide floating point stack.
+#define DOUBLE_CONVERSION_CORRECT_DOUBLE_OPERATIONS 1
+#else
+#undef DOUBLE_CONVERSION_CORRECT_DOUBLE_OPERATIONS
+#endif // _WIN32
+#else
+#error Target architecture was not detected as supported by Double-Conversion.
+#endif
+
+#if defined(_WIN32) && !defined(__MINGW32__)
+
+typedef signed char int8_t;
+typedef unsigned char uint8_t;
+typedef short int16_t; // NOLINT
+typedef unsigned short uint16_t; // NOLINT
+typedef int int32_t;
+typedef unsigned int uint32_t;
+typedef __int64 int64_t;
+typedef unsigned __int64 uint64_t;
+// intptr_t and friends are defined in crtdefs.h through stdio.h.
+
+#else
+
+#include <stdint.h>
+
+#endif
+
+typedef uint16_t uc16;
+
+// The following macro works on both 32 and 64-bit platforms.
+// Usage: instead of writing 0x1234567890123456
+// write UINT64_2PART_C(0x12345678,90123456);
+#define UINT64_2PART_C(a, b) (((static_cast<uint64_t>(a) << 32) + 0x##b##u))
+
+
+// The expression ARRAY_SIZE(a) is a compile-time constant of type
+// size_t which represents the number of elements of the given
+// array. You should only use ARRAY_SIZE on statically allocated
+// arrays.
+#ifndef ARRAY_SIZE
+#define ARRAY_SIZE(a) \
+ ((sizeof(a) / sizeof(*(a))) / \
+ static_cast<size_t>(!(sizeof(a) % sizeof(*(a)))))
+#endif
+
+// A macro to disallow the evil copy constructor and operator= functions
+// This should be used in the private: declarations for a class
+#ifndef DC_DISALLOW_COPY_AND_ASSIGN
+#define DC_DISALLOW_COPY_AND_ASSIGN(TypeName) \
+ TypeName(const TypeName&); \
+ void operator=(const TypeName&)
+#endif
+
+// A macro to disallow all the implicit constructors, namely the
+// default constructor, copy constructor and operator= functions.
+//
+// This should be used in the private: declarations for a class
+// that wants to prevent anyone from instantiating it. This is
+// especially useful for classes containing only static methods.
+#ifndef DC_DISALLOW_IMPLICIT_CONSTRUCTORS
+#define DC_DISALLOW_IMPLICIT_CONSTRUCTORS(TypeName) \
+ TypeName(); \
+ DC_DISALLOW_COPY_AND_ASSIGN(TypeName)
+#endif
+
+namespace double_conversion {
+
+static const int kCharSize = sizeof(char);
+
+// Returns the maximum of the two parameters.
+template <typename T>
+static T Max(T a, T b) {
+ return a < b ? b : a;
+}
+
+
+// Returns the minimum of the two parameters.
+template <typename T>
+static T Min(T a, T b) {
+ return a < b ? a : b;
+}
+
+
+inline int StrLength(const char* string) {
+ size_t length = strlen(string);
+ ASSERT(length == static_cast<size_t>(static_cast<int>(length)));
+ return static_cast<int>(length);
+}
+
+// This is a simplified version of V8's Vector class.
+template <typename T>
+class Vector {
+ public:
+ Vector() : start_(NULL), length_(0) {}
+ Vector(T* data, int len) : start_(data), length_(len) {
+ ASSERT(len == 0 || (len > 0 && data != NULL));
+ }
+
+ // Returns a vector using the same backing storage as this one,
+ // spanning from and including 'from', to but not including 'to'.
+ Vector<T> SubVector(int from, int to) {
+ ASSERT(to <= length_);
+ ASSERT(from < to);
+ ASSERT(0 <= from);
+ return Vector<T>(start() + from, to - from);
+ }
+
+ // Returns the length of the vector.
+ int length() const { return length_; }
+
+ // Returns whether or not the vector is empty.
+ bool is_empty() const { return length_ == 0; }
+
+ // Returns the pointer to the start of the data in the vector.
+ T* start() const { return start_; }
+
+ // Access individual vector elements - checks bounds in debug mode.
+ T& operator[](int index) const {
+ ASSERT(0 <= index && index < length_);
+ return start_[index];
+ }
+
+ T& first() { return start_[0]; }
+
+ T& last() { return start_[length_ - 1]; }
+
+ private:
+ T* start_;
+ int length_;
+};
+
+
+// Helper class for building result strings in a character buffer. The
+// purpose of the class is to use safe operations that checks the
+// buffer bounds on all operations in debug mode.
+class StringBuilder {
+ public:
+ StringBuilder(char* buffer, int buffer_size)
+ : buffer_(buffer, buffer_size), position_(0) { }
+
+ ~StringBuilder() { if (!is_finalized()) Finalize(); }
+
+ int size() const { return buffer_.length(); }
+
+ // Get the current position in the builder.
+ int position() const {
+ ASSERT(!is_finalized());
+ return position_;
+ }
+
+ // Reset the position.
+ void Reset() { position_ = 0; }
+
+ // Add a single character to the builder. It is not allowed to add
+ // 0-characters; use the Finalize() method to terminate the string
+ // instead.
+ void AddCharacter(char c) {
+ ASSERT(c != '\0');
+ ASSERT(!is_finalized() && position_ < buffer_.length());
+ buffer_[position_++] = c;
+ }
+
+ // Add an entire string to the builder. Uses strlen() internally to
+ // compute the length of the input string.
+ void AddString(const char* s) {
+ AddSubstring(s, StrLength(s));
+ }
+
+ // Add the first 'n' characters of the given string 's' to the
+ // builder. The input string must have enough characters.
+ void AddSubstring(const char* s, int n) {
+ ASSERT(!is_finalized() && position_ + n < buffer_.length());
+ ASSERT(static_cast<size_t>(n) <= strlen(s));
+ memmove(&buffer_[position_], s, n * kCharSize);
+ position_ += n;
+ }
+
+
+ // Add character padding to the builder. If count is non-positive,
+ // nothing is added to the builder.
+ void AddPadding(char c, int count) {
+ for (int i = 0; i < count; i++) {
+ AddCharacter(c);
+ }
+ }
+
+ // Finalize the string by 0-terminating it and returning the buffer.
+ char* Finalize() {
+ ASSERT(!is_finalized() && position_ < buffer_.length());
+ buffer_[position_] = '\0';
+ // Make sure nobody managed to add a 0-character to the
+ // buffer while building the string.
+ ASSERT(strlen(buffer_.start()) == static_cast<size_t>(position_));
+ position_ = -1;
+ ASSERT(is_finalized());
+ return buffer_.start();
+ }
+
+ private:
+ Vector<char> buffer_;
+ int position_;
+
+ bool is_finalized() const { return position_ < 0; }
+
+ DC_DISALLOW_IMPLICIT_CONSTRUCTORS(StringBuilder);
+};
+
+// The type-based aliasing rule allows the compiler to assume that pointers of
+// different types (for some definition of different) never alias each other.
+// Thus the following code does not work:
+//
+// float f = foo();
+// int fbits = *(int*)(&f);
+//
+// The compiler 'knows' that the int pointer can't refer to f since the types
+// don't match, so the compiler may cache f in a register, leaving random data
+// in fbits. Using C++ style casts makes no difference, however a pointer to
+// char data is assumed to alias any other pointer. This is the 'memcpy
+// exception'.
+//
+// Bit_cast uses the memcpy exception to move the bits from a variable of one
+// type of a variable of another type. Of course the end result is likely to
+// be implementation dependent. Most compilers (gcc-4.2 and MSVC 2005)
+// will completely optimize BitCast away.
+//
+// There is an additional use for BitCast.
+// Recent gccs will warn when they see casts that may result in breakage due to
+// the type-based aliasing rule. If you have checked that there is no breakage
+// you can use BitCast to cast one pointer type to another. This confuses gcc
+// enough that it can no longer see that you have cast one pointer type to
+// another thus avoiding the warning.
+template <class Dest, class Source>
+inline Dest BitCast(const Source& source) {
+ // Compile time assertion: sizeof(Dest) == sizeof(Source)
+ // A compile error here means your Dest and Source have different sizes.
+#if __cplusplus >= 201103L
+ static_assert(sizeof(Dest) == sizeof(Source),
+ "source and destination size mismatch");
+#else
+ DOUBLE_CONVERSION_UNUSED
+ typedef char VerifySizesAreEqual[sizeof(Dest) == sizeof(Source) ? 1 : -1];
+#endif
+
+ Dest dest;
+ memmove(&dest, &source, sizeof(dest));
+ return dest;
+}
+
+template <class Dest, class Source>
+inline Dest BitCast(Source* source) {
+ return BitCast<Dest>(reinterpret_cast<uintptr_t>(source));
+}
+
+} // namespace double_conversion
+
+#endif // DOUBLE_CONVERSION_UTILS_H_
diff --git a/src/arrow/cpp/src/arrow/vendored/fast_float/README.md b/src/arrow/cpp/src/arrow/vendored/fast_float/README.md
new file mode 100644
index 000000000..7d2e70541
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/fast_float/README.md
@@ -0,0 +1,7 @@
+The files in this directory are vendored from fast_float
+git changeset `d4bc0f28a276ac05f8663826eadef324de3a3399`.
+
+See https://github.com/fastfloat/fast_float
+
+Changes:
+- enclosed in `arrow_vendored` namespace.
diff --git a/src/arrow/cpp/src/arrow/vendored/fast_float/ascii_number.h b/src/arrow/cpp/src/arrow/vendored/fast_float/ascii_number.h
new file mode 100644
index 000000000..8a5afdee9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/fast_float/ascii_number.h
@@ -0,0 +1,301 @@
+#ifndef FASTFLOAT_ASCII_NUMBER_H
+#define FASTFLOAT_ASCII_NUMBER_H
+
+#include <cstdio>
+#include <cctype>
+#include <cstdint>
+#include <cstring>
+
+#include "float_common.h"
+
+namespace arrow_vendored {
+namespace fast_float {
+
+// Next function can be micro-optimized, but compilers are entirely
+// able to optimize it well.
+fastfloat_really_inline bool is_integer(char c) noexcept { return c >= '0' && c <= '9'; }
+
+
+// credit @aqrit
+fastfloat_really_inline uint32_t parse_eight_digits_unrolled(uint64_t val) {
+ const uint64_t mask = 0x000000FF000000FF;
+ const uint64_t mul1 = 0x000F424000000064; // 100 + (1000000ULL << 32)
+ const uint64_t mul2 = 0x0000271000000001; // 1 + (10000ULL << 32)
+ val -= 0x3030303030303030;
+ val = (val * 10) + (val >> 8); // val = (val * 2561) >> 8;
+ val = (((val & mask) * mul1) + (((val >> 16) & mask) * mul2)) >> 32;
+ return uint32_t(val);
+}
+
+fastfloat_really_inline uint32_t parse_eight_digits_unrolled(const char *chars) noexcept {
+ uint64_t val;
+ ::memcpy(&val, chars, sizeof(uint64_t));
+ return parse_eight_digits_unrolled(val);
+}
+
+// credit @aqrit
+fastfloat_really_inline bool is_made_of_eight_digits_fast(uint64_t val) noexcept {
+ return !((((val + 0x4646464646464646) | (val - 0x3030303030303030)) &
+ 0x8080808080808080));
+}
+
+fastfloat_really_inline bool is_made_of_eight_digits_fast(const char *chars) noexcept {
+ uint64_t val;
+ ::memcpy(&val, chars, 8);
+ return is_made_of_eight_digits_fast(val);
+}
+
+struct parsed_number_string {
+ int64_t exponent;
+ uint64_t mantissa;
+ const char *lastmatch;
+ bool negative;
+ bool valid;
+ bool too_many_digits;
+};
+
+
+// Assuming that you use no more than 19 digits, this will
+// parse an ASCII string.
+fastfloat_really_inline
+parsed_number_string parse_number_string(const char *p, const char *pend, chars_format fmt) noexcept {
+ parsed_number_string answer;
+ answer.valid = false;
+ answer.too_many_digits = false;
+ answer.negative = (*p == '-');
+ if ((*p == '-') || (*p == '+')) {
+ ++p;
+ if (p == pend) {
+ return answer;
+ }
+ if (!is_integer(*p) && (*p != '.')) { // a sign must be followed by an integer or the dot
+ return answer;
+ }
+ }
+ const char *const start_digits = p;
+
+ uint64_t i = 0; // an unsigned int avoids signed overflows (which are bad)
+
+ while ((p != pend) && is_integer(*p)) {
+ // a multiplication by 10 is cheaper than an arbitrary integer
+ // multiplication
+ i = 10 * i +
+ uint64_t(*p - '0'); // might overflow, we will handle the overflow later
+ ++p;
+ }
+ const char *const end_of_integer_part = p;
+ int64_t digit_count = int64_t(end_of_integer_part - start_digits);
+ int64_t exponent = 0;
+ if ((p != pend) && (*p == '.')) {
+ ++p;
+#if FASTFLOAT_IS_BIG_ENDIAN == 0
+ // Fast approach only tested under little endian systems
+ if ((p + 8 <= pend) && is_made_of_eight_digits_fast(p)) {
+ i = i * 100000000 + parse_eight_digits_unrolled(p); // in rare cases, this will overflow, but that's ok
+ p += 8;
+ if ((p + 8 <= pend) && is_made_of_eight_digits_fast(p)) {
+ i = i * 100000000 + parse_eight_digits_unrolled(p); // in rare cases, this will overflow, but that's ok
+ p += 8;
+ }
+ }
+#endif
+ while ((p != pend) && is_integer(*p)) {
+ uint8_t digit = uint8_t(*p - '0');
+ ++p;
+ i = i * 10 + digit; // in rare cases, this will overflow, but that's ok
+ }
+ exponent = end_of_integer_part + 1 - p;
+ digit_count -= exponent;
+ }
+ // we must have encountered at least one integer!
+ if (digit_count == 0) {
+ return answer;
+ }
+ int64_t exp_number = 0; // explicit exponential part
+ if ((fmt & chars_format::scientific) && (p != pend) && (('e' == *p) || ('E' == *p))) {
+ const char * location_of_e = p;
+ ++p;
+ bool neg_exp = false;
+ if ((p != pend) && ('-' == *p)) {
+ neg_exp = true;
+ ++p;
+ } else if ((p != pend) && ('+' == *p)) {
+ ++p;
+ }
+ if ((p == pend) || !is_integer(*p)) {
+ if(!(fmt & chars_format::fixed)) {
+ // We are in error.
+ return answer;
+ }
+ // Otherwise, we will be ignoring the 'e'.
+ p = location_of_e;
+ } else {
+ while ((p != pend) && is_integer(*p)) {
+ uint8_t digit = uint8_t(*p - '0');
+ if (exp_number < 0x10000) {
+ exp_number = 10 * exp_number + digit;
+ }
+ ++p;
+ }
+ if(neg_exp) { exp_number = - exp_number; }
+ exponent += exp_number;
+ }
+ } else {
+ // If it scientific and not fixed, we have to bail out.
+ if((fmt & chars_format::scientific) && !(fmt & chars_format::fixed)) { return answer; }
+ }
+ answer.lastmatch = p;
+ answer.valid = true;
+
+ // If we frequently had to deal with long strings of digits,
+ // we could extend our code by using a 128-bit integer instead
+ // of a 64-bit integer. However, this is uncommon.
+ //
+ // We can deal with up to 19 digits.
+ if (digit_count > 19) { // this is uncommon
+ // It is possible that the integer had an overflow.
+ // We have to handle the case where we have 0.0000somenumber.
+ // We need to be mindful of the case where we only have zeroes...
+ // E.g., 0.000000000...000.
+ const char *start = start_digits;
+ while ((start != pend) && (*start == '0' || *start == '.')) {
+ if(*start == '0') { digit_count --; }
+ start++;
+ }
+ if (digit_count > 19) {
+ answer.too_many_digits = true;
+ // Let us start again, this time, avoiding overflows.
+ i = 0;
+ p = start_digits;
+ const uint64_t minimal_nineteen_digit_integer{1000000000000000000};
+ while((i < minimal_nineteen_digit_integer) && (p != pend) && is_integer(*p)) {
+ i = i * 10 + uint64_t(*p - '0');
+ ++p;
+ }
+ if (i >= minimal_nineteen_digit_integer) { // We have a big integers
+ exponent = end_of_integer_part - p + exp_number;
+ } else { // We have a value with a fractional component.
+ p++; // skip the '.'
+ const char *first_after_period = p;
+ while((i < minimal_nineteen_digit_integer) && (p != pend) && is_integer(*p)) {
+ i = i * 10 + uint64_t(*p - '0');
+ ++p;
+ }
+ exponent = first_after_period - p + exp_number;
+ }
+ // We have now corrected both exponent and i, to a truncated value
+ }
+ }
+ answer.exponent = exponent;
+ answer.mantissa = i;
+ return answer;
+}
+
+
+// This should always succeed since it follows a call to parse_number_string
+// This function could be optimized. In particular, we could stop after 19 digits
+// and try to bail out. Furthermore, we should be able to recover the computed
+// exponent from the pass in parse_number_string.
+fastfloat_really_inline decimal parse_decimal(const char *p, const char *pend) noexcept {
+ decimal answer;
+ answer.num_digits = 0;
+ answer.decimal_point = 0;
+ answer.truncated = false;
+ // any whitespace has been skipped.
+ answer.negative = (*p == '-');
+ if ((*p == '-') || (*p == '+')) {
+ ++p;
+ }
+ // skip leading zeroes
+ while ((p != pend) && (*p == '0')) {
+ ++p;
+ }
+ while ((p != pend) && is_integer(*p)) {
+ if (answer.num_digits < max_digits) {
+ answer.digits[answer.num_digits] = uint8_t(*p - '0');
+ }
+ answer.num_digits++;
+ ++p;
+ }
+ if ((p != pend) && (*p == '.')) {
+ ++p;
+ const char *first_after_period = p;
+ // if we have not yet encountered a zero, we have to skip it as well
+ if(answer.num_digits == 0) {
+ // skip zeros
+ while ((p != pend) && (*p == '0')) {
+ ++p;
+ }
+ }
+#if FASTFLOAT_IS_BIG_ENDIAN == 0
+ // We expect that this loop will often take the bulk of the running time
+ // because when a value has lots of digits, these digits often
+ while ((p + 8 <= pend) && (answer.num_digits + 8 < max_digits)) {
+ uint64_t val;
+ ::memcpy(&val, p, sizeof(uint64_t));
+ if(! is_made_of_eight_digits_fast(val)) { break; }
+ // We have eight digits, process them in one go!
+ val -= 0x3030303030303030;
+ ::memcpy(answer.digits + answer.num_digits, &val, sizeof(uint64_t));
+ answer.num_digits += 8;
+ p += 8;
+ }
+#endif
+ while ((p != pend) && is_integer(*p)) {
+ if (answer.num_digits < max_digits) {
+ answer.digits[answer.num_digits] = uint8_t(*p - '0');
+ }
+ answer.num_digits++;
+ ++p;
+ }
+ answer.decimal_point = int32_t(first_after_period - p);
+ }
+ // We want num_digits to be the number of significant digits, excluding
+ // leading *and* trailing zeros! Otherwise the truncated flag later is
+ // going to be misleading.
+ if(answer.num_digits > 0) {
+ // We potentially need the answer.num_digits > 0 guard because we
+ // prune leading zeros. So with answer.num_digits > 0, we know that
+ // we have at least one non-zero digit.
+ const char *preverse = p - 1;
+ int32_t trailing_zeros = 0;
+ while ((*preverse == '0') || (*preverse == '.')) {
+ if(*preverse == '0') { trailing_zeros++; };
+ --preverse;
+ }
+ answer.decimal_point += int32_t(answer.num_digits);
+ answer.num_digits -= uint32_t(trailing_zeros);
+ }
+ if(answer.num_digits > max_digits) {
+ answer.truncated = true;
+ answer.num_digits = max_digits;
+ }
+ if ((p != pend) && (('e' == *p) || ('E' == *p))) {
+ ++p;
+ bool neg_exp = false;
+ if ((p != pend) && ('-' == *p)) {
+ neg_exp = true;
+ ++p;
+ } else if ((p != pend) && ('+' == *p)) {
+ ++p;
+ }
+ int32_t exp_number = 0; // exponential part
+ while ((p != pend) && is_integer(*p)) {
+ uint8_t digit = uint8_t(*p - '0');
+ if (exp_number < 0x10000) {
+ exp_number = 10 * exp_number + digit;
+ }
+ ++p;
+ }
+ answer.decimal_point += (neg_exp ? -exp_number : exp_number);
+ }
+ // In very rare cases, we may have fewer than 19 digits, we want to be able to reliably
+ // assume that all digits up to max_digit_without_overflow have been initialized.
+ for(uint32_t i = answer.num_digits; i < max_digit_without_overflow; i++) { answer.digits[i] = 0; }
+
+ return answer;
+}
+} // namespace fast_float
+} // namespace arrow_vendored
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/fast_float/decimal_to_binary.h b/src/arrow/cpp/src/arrow/vendored/fast_float/decimal_to_binary.h
new file mode 100644
index 000000000..2419de7c0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/fast_float/decimal_to_binary.h
@@ -0,0 +1,176 @@
+#ifndef FASTFLOAT_DECIMAL_TO_BINARY_H
+#define FASTFLOAT_DECIMAL_TO_BINARY_H
+
+#include "float_common.h"
+#include "fast_table.h"
+#include <cfloat>
+#include <cinttypes>
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+
+namespace arrow_vendored {
+namespace fast_float {
+
+// This will compute or rather approximate w * 5**q and return a pair of 64-bit words approximating
+// the result, with the "high" part corresponding to the most significant bits and the
+// low part corresponding to the least significant bits.
+//
+template <int bit_precision>
+fastfloat_really_inline
+value128 compute_product_approximation(int64_t q, uint64_t w) {
+ const int index = 2 * int(q - smallest_power_of_five);
+ // For small values of q, e.g., q in [0,27], the answer is always exact because
+ // The line value128 firstproduct = full_multiplication(w, power_of_five_128[index]);
+ // gives the exact answer.
+ value128 firstproduct = full_multiplication(w, power_of_five_128[index]);
+ static_assert((bit_precision >= 0) && (bit_precision <= 64), " precision should be in (0,64]");
+ constexpr uint64_t precision_mask = (bit_precision < 64) ?
+ (uint64_t(0xFFFFFFFFFFFFFFFF) >> bit_precision)
+ : uint64_t(0xFFFFFFFFFFFFFFFF);
+ if((firstproduct.high & precision_mask) == precision_mask) { // could further guard with (lower + w < lower)
+ // regarding the second product, we only need secondproduct.high, but our expectation is that the compiler will optimize this extra work away if needed.
+ value128 secondproduct = full_multiplication(w, power_of_five_128[index + 1]);
+ firstproduct.low += secondproduct.high;
+ if(secondproduct.high > firstproduct.low) {
+ firstproduct.high++;
+ }
+ }
+ return firstproduct;
+}
+
+namespace {
+/**
+ * For q in (0,350), we have that
+ * f = (((152170 + 65536) * q ) >> 16);
+ * is equal to
+ * floor(p) + q
+ * where
+ * p = log(5**q)/log(2) = q * log(5)/log(2)
+ *
+ * For negative values of q in (-400,0), we have that
+ * f = (((152170 + 65536) * q ) >> 16);
+ * is equal to
+ * -ceil(p) + q
+ * where
+ * p = log(5**-q)/log(2) = -q * log(5)/log(2)
+ */
+ fastfloat_really_inline int power(int q) noexcept {
+ return (((152170 + 65536) * q) >> 16) + 63;
+ }
+} // namespace
+
+
+// w * 10 ** q
+// The returned value should be a valid ieee64 number that simply need to be packed.
+// However, in some very rare cases, the computation will fail. In such cases, we
+// return an adjusted_mantissa with a negative power of 2: the caller should recompute
+// in such cases.
+template <typename binary>
+fastfloat_really_inline
+adjusted_mantissa compute_float(int64_t q, uint64_t w) noexcept {
+ adjusted_mantissa answer;
+ if ((w == 0) || (q < binary::smallest_power_of_ten())) {
+ answer.power2 = 0;
+ answer.mantissa = 0;
+ // result should be zero
+ return answer;
+ }
+ if (q > binary::largest_power_of_ten()) {
+ // we want to get infinity:
+ answer.power2 = binary::infinite_power();
+ answer.mantissa = 0;
+ return answer;
+ }
+ // At this point in time q is in [smallest_power_of_five, largest_power_of_five].
+
+ // We want the most significant bit of i to be 1. Shift if needed.
+ int lz = leading_zeroes(w);
+ w <<= lz;
+
+ // The required precision is binary::mantissa_explicit_bits() + 3 because
+ // 1. We need the implicit bit
+ // 2. We need an extra bit for rounding purposes
+ // 3. We might lose a bit due to the "upperbit" routine (result too small, requiring a shift)
+
+ value128 product = compute_product_approximation<binary::mantissa_explicit_bits() + 3>(q, w);
+ if(product.low == 0xFFFFFFFFFFFFFFFF) { // could guard it further
+ // In some very rare cases, this could happen, in which case we might need a more accurate
+ // computation that what we can provide cheaply. This is very, very unlikely.
+ //
+ const bool inside_safe_exponent = (q >= -27) && (q <= 55); // always good because 5**q <2**128 when q>=0,
+ // and otherwise, for q<0, we have 5**-q<2**64 and the 128-bit reciprocal allows for exact computation.
+ if(!inside_safe_exponent) {
+ answer.power2 = -1; // This (a negative value) indicates an error condition.
+ return answer;
+ }
+ }
+ // The "compute_product_approximation" function can be slightly slower than a branchless approach:
+ // value128 product = compute_product(q, w);
+ // but in practice, we can win big with the compute_product_approximation if its additional branch
+ // is easily predicted. Which is best is data specific.
+ int upperbit = int(product.high >> 63);
+
+ answer.mantissa = product.high >> (upperbit + 64 - binary::mantissa_explicit_bits() - 3);
+
+ answer.power2 = int(power(int(q)) + upperbit - lz - binary::minimum_exponent());
+ if (answer.power2 <= 0) { // we have a subnormal?
+ // Here have that answer.power2 <= 0 so -answer.power2 >= 0
+ if(-answer.power2 + 1 >= 64) { // if we have more than 64 bits below the minimum exponent, you have a zero for sure.
+ answer.power2 = 0;
+ answer.mantissa = 0;
+ // result should be zero
+ return answer;
+ }
+ // next line is safe because -answer.power2 + 1 < 64
+ answer.mantissa >>= -answer.power2 + 1;
+ // Thankfully, we can't have both "round-to-even" and subnormals because
+ // "round-to-even" only occurs for powers close to 0.
+ answer.mantissa += (answer.mantissa & 1); // round up
+ answer.mantissa >>= 1;
+ // There is a weird scenario where we don't have a subnormal but just.
+ // Suppose we start with 2.2250738585072013e-308, we end up
+ // with 0x3fffffffffffff x 2^-1023-53 which is technically subnormal
+ // whereas 0x40000000000000 x 2^-1023-53 is normal. Now, we need to round
+ // up 0x3fffffffffffff x 2^-1023-53 and once we do, we are no longer
+ // subnormal, but we can only know this after rounding.
+ // So we only declare a subnormal if we are smaller than the threshold.
+ answer.power2 = (answer.mantissa < (uint64_t(1) << binary::mantissa_explicit_bits())) ? 0 : 1;
+ return answer;
+ }
+
+ // usually, we round *up*, but if we fall right in between and and we have an
+ // even basis, we need to round down
+ // We are only concerned with the cases where 5**q fits in single 64-bit word.
+ if ((product.low <= 1) && (q >= binary::min_exponent_round_to_even()) && (q <= binary::max_exponent_round_to_even()) &&
+ ((answer.mantissa & 3) == 1) ) { // we may fall between two floats!
+ // To be in-between two floats we need that in doing
+ // answer.mantissa = product.high >> (upperbit + 64 - binary::mantissa_explicit_bits() - 3);
+ // ... we dropped out only zeroes. But if this happened, then we can go back!!!
+ if((answer.mantissa << (upperbit + 64 - binary::mantissa_explicit_bits() - 3)) == product.high) {
+ answer.mantissa &= ~uint64_t(1); // flip it so that we do not round up
+ }
+ }
+
+ answer.mantissa += (answer.mantissa & 1); // round up
+ answer.mantissa >>= 1;
+ if (answer.mantissa >= (uint64_t(2) << binary::mantissa_explicit_bits())) {
+ answer.mantissa = (uint64_t(1) << binary::mantissa_explicit_bits());
+ answer.power2++; // undo previous addition
+ }
+
+ answer.mantissa &= ~(uint64_t(1) << binary::mantissa_explicit_bits());
+ if (answer.power2 >= binary::infinite_power()) { // infinity
+ answer.power2 = binary::infinite_power();
+ answer.mantissa = 0;
+ }
+ return answer;
+}
+
+
+} // namespace fast_float
+} // namespace arrow_vendored
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/fast_float/fast_float.h b/src/arrow/cpp/src/arrow/vendored/fast_float/fast_float.h
new file mode 100644
index 000000000..3e39cac90
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/fast_float/fast_float.h
@@ -0,0 +1,48 @@
+#ifndef FASTFLOAT_FAST_FLOAT_H
+#define FASTFLOAT_FAST_FLOAT_H
+
+#include <system_error>
+
+namespace arrow_vendored {
+namespace fast_float {
+enum chars_format {
+ scientific = 1<<0,
+ fixed = 1<<2,
+ hex = 1<<3,
+ general = fixed | scientific
+};
+
+
+struct from_chars_result {
+ const char *ptr;
+ std::errc ec;
+};
+
+/**
+ * This function parses the character sequence [first,last) for a number. It parses floating-point numbers expecting
+ * a locale-indepent format equivalent to what is used by std::strtod in the default ("C") locale.
+ * The resulting floating-point value is the closest floating-point values (using either float or double),
+ * using the "round to even" convention for values that would otherwise fall right in-between two values.
+ * That is, we provide exact parsing according to the IEEE standard.
+ *
+ * Given a successful parse, the pointer (`ptr`) in the returned value is set to point right after the
+ * parsed number, and the `value` referenced is set to the parsed value. In case of error, the returned
+ * `ec` contains a representative error, otherwise the default (`std::errc()`) value is stored.
+ *
+ * The implementation does not throw and does not allocate memory (e.g., with `new` or `malloc`).
+ *
+ * Like the C++17 standard, the `fast_float::from_chars` functions take an optional last argument of
+ * the type `fast_float::chars_format`. It is a bitset value: we check whether
+ * `fmt & fast_float::chars_format::fixed` and `fmt & fast_float::chars_format::scientific` are set
+ * to determine whether we allowe the fixed point and scientific notation respectively.
+ * The default is `fast_float::chars_format::general` which allows both `fixed` and `scientific`.
+ */
+template<typename T>
+from_chars_result from_chars(const char *first, const char *last,
+ T &value, chars_format fmt = chars_format::general) noexcept;
+
+}
+} // namespace arrow_vendored
+
+#include "parse_number.h"
+#endif // FASTFLOAT_FAST_FLOAT_H
diff --git a/src/arrow/cpp/src/arrow/vendored/fast_float/fast_table.h b/src/arrow/cpp/src/arrow/vendored/fast_float/fast_table.h
new file mode 100644
index 000000000..c1ca1755a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/fast_float/fast_table.h
@@ -0,0 +1,691 @@
+#ifndef FASTFLOAT_FAST_TABLE_H
+#define FASTFLOAT_FAST_TABLE_H
+#include <cstdint>
+
+namespace arrow_vendored {
+namespace fast_float {
+
+/**
+ * When mapping numbers from decimal to binary,
+ * we go from w * 10^q to m * 2^p but we have
+ * 10^q = 5^q * 2^q, so effectively
+ * we are trying to match
+ * w * 2^q * 5^q to m * 2^p. Thus the powers of two
+ * are not a concern since they can be represented
+ * exactly using the binary notation, only the powers of five
+ * affect the binary significand.
+ */
+
+/**
+ * The smallest non-zero float (binary64) is 2^−1074.
+ * We take as input numbers of the form w x 10^q where w < 2^64.
+ * We have that w * 10^-343 < 2^(64-344) 5^-343 < 2^-1076.
+ * However, we have that
+ * (2^64-1) * 10^-342 = (2^64-1) * 2^-342 * 5^-342 > 2^−1074.
+ * Thus it is possible for a number of the form w * 10^-342 where
+ * w is a 64-bit value to be a non-zero floating-point number.
+ *********
+ * Any number of form w * 10^309 where w>= 1 is going to be
+ * infinite in binary64 so we never need to worry about powers
+ * of 5 greater than 308.
+ */
+constexpr int smallest_power_of_five = -342;
+constexpr int largest_power_of_five = 308;
+// Powers of five from 5^-342 all the way to 5^308 rounded toward one.
+const uint64_t power_of_five_128[]= {
+ 0xeef453d6923bd65a,0x113faa2906a13b3f,
+ 0x9558b4661b6565f8,0x4ac7ca59a424c507,
+ 0xbaaee17fa23ebf76,0x5d79bcf00d2df649,
+ 0xe95a99df8ace6f53,0xf4d82c2c107973dc,
+ 0x91d8a02bb6c10594,0x79071b9b8a4be869,
+ 0xb64ec836a47146f9,0x9748e2826cdee284,
+ 0xe3e27a444d8d98b7,0xfd1b1b2308169b25,
+ 0x8e6d8c6ab0787f72,0xfe30f0f5e50e20f7,
+ 0xb208ef855c969f4f,0xbdbd2d335e51a935,
+ 0xde8b2b66b3bc4723,0xad2c788035e61382,
+ 0x8b16fb203055ac76,0x4c3bcb5021afcc31,
+ 0xaddcb9e83c6b1793,0xdf4abe242a1bbf3d,
+ 0xd953e8624b85dd78,0xd71d6dad34a2af0d,
+ 0x87d4713d6f33aa6b,0x8672648c40e5ad68,
+ 0xa9c98d8ccb009506,0x680efdaf511f18c2,
+ 0xd43bf0effdc0ba48,0x212bd1b2566def2,
+ 0x84a57695fe98746d,0x14bb630f7604b57,
+ 0xa5ced43b7e3e9188,0x419ea3bd35385e2d,
+ 0xcf42894a5dce35ea,0x52064cac828675b9,
+ 0x818995ce7aa0e1b2,0x7343efebd1940993,
+ 0xa1ebfb4219491a1f,0x1014ebe6c5f90bf8,
+ 0xca66fa129f9b60a6,0xd41a26e077774ef6,
+ 0xfd00b897478238d0,0x8920b098955522b4,
+ 0x9e20735e8cb16382,0x55b46e5f5d5535b0,
+ 0xc5a890362fddbc62,0xeb2189f734aa831d,
+ 0xf712b443bbd52b7b,0xa5e9ec7501d523e4,
+ 0x9a6bb0aa55653b2d,0x47b233c92125366e,
+ 0xc1069cd4eabe89f8,0x999ec0bb696e840a,
+ 0xf148440a256e2c76,0xc00670ea43ca250d,
+ 0x96cd2a865764dbca,0x380406926a5e5728,
+ 0xbc807527ed3e12bc,0xc605083704f5ecf2,
+ 0xeba09271e88d976b,0xf7864a44c633682e,
+ 0x93445b8731587ea3,0x7ab3ee6afbe0211d,
+ 0xb8157268fdae9e4c,0x5960ea05bad82964,
+ 0xe61acf033d1a45df,0x6fb92487298e33bd,
+ 0x8fd0c16206306bab,0xa5d3b6d479f8e056,
+ 0xb3c4f1ba87bc8696,0x8f48a4899877186c,
+ 0xe0b62e2929aba83c,0x331acdabfe94de87,
+ 0x8c71dcd9ba0b4925,0x9ff0c08b7f1d0b14,
+ 0xaf8e5410288e1b6f,0x7ecf0ae5ee44dd9,
+ 0xdb71e91432b1a24a,0xc9e82cd9f69d6150,
+ 0x892731ac9faf056e,0xbe311c083a225cd2,
+ 0xab70fe17c79ac6ca,0x6dbd630a48aaf406,
+ 0xd64d3d9db981787d,0x92cbbccdad5b108,
+ 0x85f0468293f0eb4e,0x25bbf56008c58ea5,
+ 0xa76c582338ed2621,0xaf2af2b80af6f24e,
+ 0xd1476e2c07286faa,0x1af5af660db4aee1,
+ 0x82cca4db847945ca,0x50d98d9fc890ed4d,
+ 0xa37fce126597973c,0xe50ff107bab528a0,
+ 0xcc5fc196fefd7d0c,0x1e53ed49a96272c8,
+ 0xff77b1fcbebcdc4f,0x25e8e89c13bb0f7a,
+ 0x9faacf3df73609b1,0x77b191618c54e9ac,
+ 0xc795830d75038c1d,0xd59df5b9ef6a2417,
+ 0xf97ae3d0d2446f25,0x4b0573286b44ad1d,
+ 0x9becce62836ac577,0x4ee367f9430aec32,
+ 0xc2e801fb244576d5,0x229c41f793cda73f,
+ 0xf3a20279ed56d48a,0x6b43527578c1110f,
+ 0x9845418c345644d6,0x830a13896b78aaa9,
+ 0xbe5691ef416bd60c,0x23cc986bc656d553,
+ 0xedec366b11c6cb8f,0x2cbfbe86b7ec8aa8,
+ 0x94b3a202eb1c3f39,0x7bf7d71432f3d6a9,
+ 0xb9e08a83a5e34f07,0xdaf5ccd93fb0cc53,
+ 0xe858ad248f5c22c9,0xd1b3400f8f9cff68,
+ 0x91376c36d99995be,0x23100809b9c21fa1,
+ 0xb58547448ffffb2d,0xabd40a0c2832a78a,
+ 0xe2e69915b3fff9f9,0x16c90c8f323f516c,
+ 0x8dd01fad907ffc3b,0xae3da7d97f6792e3,
+ 0xb1442798f49ffb4a,0x99cd11cfdf41779c,
+ 0xdd95317f31c7fa1d,0x40405643d711d583,
+ 0x8a7d3eef7f1cfc52,0x482835ea666b2572,
+ 0xad1c8eab5ee43b66,0xda3243650005eecf,
+ 0xd863b256369d4a40,0x90bed43e40076a82,
+ 0x873e4f75e2224e68,0x5a7744a6e804a291,
+ 0xa90de3535aaae202,0x711515d0a205cb36,
+ 0xd3515c2831559a83,0xd5a5b44ca873e03,
+ 0x8412d9991ed58091,0xe858790afe9486c2,
+ 0xa5178fff668ae0b6,0x626e974dbe39a872,
+ 0xce5d73ff402d98e3,0xfb0a3d212dc8128f,
+ 0x80fa687f881c7f8e,0x7ce66634bc9d0b99,
+ 0xa139029f6a239f72,0x1c1fffc1ebc44e80,
+ 0xc987434744ac874e,0xa327ffb266b56220,
+ 0xfbe9141915d7a922,0x4bf1ff9f0062baa8,
+ 0x9d71ac8fada6c9b5,0x6f773fc3603db4a9,
+ 0xc4ce17b399107c22,0xcb550fb4384d21d3,
+ 0xf6019da07f549b2b,0x7e2a53a146606a48,
+ 0x99c102844f94e0fb,0x2eda7444cbfc426d,
+ 0xc0314325637a1939,0xfa911155fefb5308,
+ 0xf03d93eebc589f88,0x793555ab7eba27ca,
+ 0x96267c7535b763b5,0x4bc1558b2f3458de,
+ 0xbbb01b9283253ca2,0x9eb1aaedfb016f16,
+ 0xea9c227723ee8bcb,0x465e15a979c1cadc,
+ 0x92a1958a7675175f,0xbfacd89ec191ec9,
+ 0xb749faed14125d36,0xcef980ec671f667b,
+ 0xe51c79a85916f484,0x82b7e12780e7401a,
+ 0x8f31cc0937ae58d2,0xd1b2ecb8b0908810,
+ 0xb2fe3f0b8599ef07,0x861fa7e6dcb4aa15,
+ 0xdfbdcece67006ac9,0x67a791e093e1d49a,
+ 0x8bd6a141006042bd,0xe0c8bb2c5c6d24e0,
+ 0xaecc49914078536d,0x58fae9f773886e18,
+ 0xda7f5bf590966848,0xaf39a475506a899e,
+ 0x888f99797a5e012d,0x6d8406c952429603,
+ 0xaab37fd7d8f58178,0xc8e5087ba6d33b83,
+ 0xd5605fcdcf32e1d6,0xfb1e4a9a90880a64,
+ 0x855c3be0a17fcd26,0x5cf2eea09a55067f,
+ 0xa6b34ad8c9dfc06f,0xf42faa48c0ea481e,
+ 0xd0601d8efc57b08b,0xf13b94daf124da26,
+ 0x823c12795db6ce57,0x76c53d08d6b70858,
+ 0xa2cb1717b52481ed,0x54768c4b0c64ca6e,
+ 0xcb7ddcdda26da268,0xa9942f5dcf7dfd09,
+ 0xfe5d54150b090b02,0xd3f93b35435d7c4c,
+ 0x9efa548d26e5a6e1,0xc47bc5014a1a6daf,
+ 0xc6b8e9b0709f109a,0x359ab6419ca1091b,
+ 0xf867241c8cc6d4c0,0xc30163d203c94b62,
+ 0x9b407691d7fc44f8,0x79e0de63425dcf1d,
+ 0xc21094364dfb5636,0x985915fc12f542e4,
+ 0xf294b943e17a2bc4,0x3e6f5b7b17b2939d,
+ 0x979cf3ca6cec5b5a,0xa705992ceecf9c42,
+ 0xbd8430bd08277231,0x50c6ff782a838353,
+ 0xece53cec4a314ebd,0xa4f8bf5635246428,
+ 0x940f4613ae5ed136,0x871b7795e136be99,
+ 0xb913179899f68584,0x28e2557b59846e3f,
+ 0xe757dd7ec07426e5,0x331aeada2fe589cf,
+ 0x9096ea6f3848984f,0x3ff0d2c85def7621,
+ 0xb4bca50b065abe63,0xfed077a756b53a9,
+ 0xe1ebce4dc7f16dfb,0xd3e8495912c62894,
+ 0x8d3360f09cf6e4bd,0x64712dd7abbbd95c,
+ 0xb080392cc4349dec,0xbd8d794d96aacfb3,
+ 0xdca04777f541c567,0xecf0d7a0fc5583a0,
+ 0x89e42caaf9491b60,0xf41686c49db57244,
+ 0xac5d37d5b79b6239,0x311c2875c522ced5,
+ 0xd77485cb25823ac7,0x7d633293366b828b,
+ 0x86a8d39ef77164bc,0xae5dff9c02033197,
+ 0xa8530886b54dbdeb,0xd9f57f830283fdfc,
+ 0xd267caa862a12d66,0xd072df63c324fd7b,
+ 0x8380dea93da4bc60,0x4247cb9e59f71e6d,
+ 0xa46116538d0deb78,0x52d9be85f074e608,
+ 0xcd795be870516656,0x67902e276c921f8b,
+ 0x806bd9714632dff6,0xba1cd8a3db53b6,
+ 0xa086cfcd97bf97f3,0x80e8a40eccd228a4,
+ 0xc8a883c0fdaf7df0,0x6122cd128006b2cd,
+ 0xfad2a4b13d1b5d6c,0x796b805720085f81,
+ 0x9cc3a6eec6311a63,0xcbe3303674053bb0,
+ 0xc3f490aa77bd60fc,0xbedbfc4411068a9c,
+ 0xf4f1b4d515acb93b,0xee92fb5515482d44,
+ 0x991711052d8bf3c5,0x751bdd152d4d1c4a,
+ 0xbf5cd54678eef0b6,0xd262d45a78a0635d,
+ 0xef340a98172aace4,0x86fb897116c87c34,
+ 0x9580869f0e7aac0e,0xd45d35e6ae3d4da0,
+ 0xbae0a846d2195712,0x8974836059cca109,
+ 0xe998d258869facd7,0x2bd1a438703fc94b,
+ 0x91ff83775423cc06,0x7b6306a34627ddcf,
+ 0xb67f6455292cbf08,0x1a3bc84c17b1d542,
+ 0xe41f3d6a7377eeca,0x20caba5f1d9e4a93,
+ 0x8e938662882af53e,0x547eb47b7282ee9c,
+ 0xb23867fb2a35b28d,0xe99e619a4f23aa43,
+ 0xdec681f9f4c31f31,0x6405fa00e2ec94d4,
+ 0x8b3c113c38f9f37e,0xde83bc408dd3dd04,
+ 0xae0b158b4738705e,0x9624ab50b148d445,
+ 0xd98ddaee19068c76,0x3badd624dd9b0957,
+ 0x87f8a8d4cfa417c9,0xe54ca5d70a80e5d6,
+ 0xa9f6d30a038d1dbc,0x5e9fcf4ccd211f4c,
+ 0xd47487cc8470652b,0x7647c3200069671f,
+ 0x84c8d4dfd2c63f3b,0x29ecd9f40041e073,
+ 0xa5fb0a17c777cf09,0xf468107100525890,
+ 0xcf79cc9db955c2cc,0x7182148d4066eeb4,
+ 0x81ac1fe293d599bf,0xc6f14cd848405530,
+ 0xa21727db38cb002f,0xb8ada00e5a506a7c,
+ 0xca9cf1d206fdc03b,0xa6d90811f0e4851c,
+ 0xfd442e4688bd304a,0x908f4a166d1da663,
+ 0x9e4a9cec15763e2e,0x9a598e4e043287fe,
+ 0xc5dd44271ad3cdba,0x40eff1e1853f29fd,
+ 0xf7549530e188c128,0xd12bee59e68ef47c,
+ 0x9a94dd3e8cf578b9,0x82bb74f8301958ce,
+ 0xc13a148e3032d6e7,0xe36a52363c1faf01,
+ 0xf18899b1bc3f8ca1,0xdc44e6c3cb279ac1,
+ 0x96f5600f15a7b7e5,0x29ab103a5ef8c0b9,
+ 0xbcb2b812db11a5de,0x7415d448f6b6f0e7,
+ 0xebdf661791d60f56,0x111b495b3464ad21,
+ 0x936b9fcebb25c995,0xcab10dd900beec34,
+ 0xb84687c269ef3bfb,0x3d5d514f40eea742,
+ 0xe65829b3046b0afa,0xcb4a5a3112a5112,
+ 0x8ff71a0fe2c2e6dc,0x47f0e785eaba72ab,
+ 0xb3f4e093db73a093,0x59ed216765690f56,
+ 0xe0f218b8d25088b8,0x306869c13ec3532c,
+ 0x8c974f7383725573,0x1e414218c73a13fb,
+ 0xafbd2350644eeacf,0xe5d1929ef90898fa,
+ 0xdbac6c247d62a583,0xdf45f746b74abf39,
+ 0x894bc396ce5da772,0x6b8bba8c328eb783,
+ 0xab9eb47c81f5114f,0x66ea92f3f326564,
+ 0xd686619ba27255a2,0xc80a537b0efefebd,
+ 0x8613fd0145877585,0xbd06742ce95f5f36,
+ 0xa798fc4196e952e7,0x2c48113823b73704,
+ 0xd17f3b51fca3a7a0,0xf75a15862ca504c5,
+ 0x82ef85133de648c4,0x9a984d73dbe722fb,
+ 0xa3ab66580d5fdaf5,0xc13e60d0d2e0ebba,
+ 0xcc963fee10b7d1b3,0x318df905079926a8,
+ 0xffbbcfe994e5c61f,0xfdf17746497f7052,
+ 0x9fd561f1fd0f9bd3,0xfeb6ea8bedefa633,
+ 0xc7caba6e7c5382c8,0xfe64a52ee96b8fc0,
+ 0xf9bd690a1b68637b,0x3dfdce7aa3c673b0,
+ 0x9c1661a651213e2d,0x6bea10ca65c084e,
+ 0xc31bfa0fe5698db8,0x486e494fcff30a62,
+ 0xf3e2f893dec3f126,0x5a89dba3c3efccfa,
+ 0x986ddb5c6b3a76b7,0xf89629465a75e01c,
+ 0xbe89523386091465,0xf6bbb397f1135823,
+ 0xee2ba6c0678b597f,0x746aa07ded582e2c,
+ 0x94db483840b717ef,0xa8c2a44eb4571cdc,
+ 0xba121a4650e4ddeb,0x92f34d62616ce413,
+ 0xe896a0d7e51e1566,0x77b020baf9c81d17,
+ 0x915e2486ef32cd60,0xace1474dc1d122e,
+ 0xb5b5ada8aaff80b8,0xd819992132456ba,
+ 0xe3231912d5bf60e6,0x10e1fff697ed6c69,
+ 0x8df5efabc5979c8f,0xca8d3ffa1ef463c1,
+ 0xb1736b96b6fd83b3,0xbd308ff8a6b17cb2,
+ 0xddd0467c64bce4a0,0xac7cb3f6d05ddbde,
+ 0x8aa22c0dbef60ee4,0x6bcdf07a423aa96b,
+ 0xad4ab7112eb3929d,0x86c16c98d2c953c6,
+ 0xd89d64d57a607744,0xe871c7bf077ba8b7,
+ 0x87625f056c7c4a8b,0x11471cd764ad4972,
+ 0xa93af6c6c79b5d2d,0xd598e40d3dd89bcf,
+ 0xd389b47879823479,0x4aff1d108d4ec2c3,
+ 0x843610cb4bf160cb,0xcedf722a585139ba,
+ 0xa54394fe1eedb8fe,0xc2974eb4ee658828,
+ 0xce947a3da6a9273e,0x733d226229feea32,
+ 0x811ccc668829b887,0x806357d5a3f525f,
+ 0xa163ff802a3426a8,0xca07c2dcb0cf26f7,
+ 0xc9bcff6034c13052,0xfc89b393dd02f0b5,
+ 0xfc2c3f3841f17c67,0xbbac2078d443ace2,
+ 0x9d9ba7832936edc0,0xd54b944b84aa4c0d,
+ 0xc5029163f384a931,0xa9e795e65d4df11,
+ 0xf64335bcf065d37d,0x4d4617b5ff4a16d5,
+ 0x99ea0196163fa42e,0x504bced1bf8e4e45,
+ 0xc06481fb9bcf8d39,0xe45ec2862f71e1d6,
+ 0xf07da27a82c37088,0x5d767327bb4e5a4c,
+ 0x964e858c91ba2655,0x3a6a07f8d510f86f,
+ 0xbbe226efb628afea,0x890489f70a55368b,
+ 0xeadab0aba3b2dbe5,0x2b45ac74ccea842e,
+ 0x92c8ae6b464fc96f,0x3b0b8bc90012929d,
+ 0xb77ada0617e3bbcb,0x9ce6ebb40173744,
+ 0xe55990879ddcaabd,0xcc420a6a101d0515,
+ 0x8f57fa54c2a9eab6,0x9fa946824a12232d,
+ 0xb32df8e9f3546564,0x47939822dc96abf9,
+ 0xdff9772470297ebd,0x59787e2b93bc56f7,
+ 0x8bfbea76c619ef36,0x57eb4edb3c55b65a,
+ 0xaefae51477a06b03,0xede622920b6b23f1,
+ 0xdab99e59958885c4,0xe95fab368e45eced,
+ 0x88b402f7fd75539b,0x11dbcb0218ebb414,
+ 0xaae103b5fcd2a881,0xd652bdc29f26a119,
+ 0xd59944a37c0752a2,0x4be76d3346f0495f,
+ 0x857fcae62d8493a5,0x6f70a4400c562ddb,
+ 0xa6dfbd9fb8e5b88e,0xcb4ccd500f6bb952,
+ 0xd097ad07a71f26b2,0x7e2000a41346a7a7,
+ 0x825ecc24c873782f,0x8ed400668c0c28c8,
+ 0xa2f67f2dfa90563b,0x728900802f0f32fa,
+ 0xcbb41ef979346bca,0x4f2b40a03ad2ffb9,
+ 0xfea126b7d78186bc,0xe2f610c84987bfa8,
+ 0x9f24b832e6b0f436,0xdd9ca7d2df4d7c9,
+ 0xc6ede63fa05d3143,0x91503d1c79720dbb,
+ 0xf8a95fcf88747d94,0x75a44c6397ce912a,
+ 0x9b69dbe1b548ce7c,0xc986afbe3ee11aba,
+ 0xc24452da229b021b,0xfbe85badce996168,
+ 0xf2d56790ab41c2a2,0xfae27299423fb9c3,
+ 0x97c560ba6b0919a5,0xdccd879fc967d41a,
+ 0xbdb6b8e905cb600f,0x5400e987bbc1c920,
+ 0xed246723473e3813,0x290123e9aab23b68,
+ 0x9436c0760c86e30b,0xf9a0b6720aaf6521,
+ 0xb94470938fa89bce,0xf808e40e8d5b3e69,
+ 0xe7958cb87392c2c2,0xb60b1d1230b20e04,
+ 0x90bd77f3483bb9b9,0xb1c6f22b5e6f48c2,
+ 0xb4ecd5f01a4aa828,0x1e38aeb6360b1af3,
+ 0xe2280b6c20dd5232,0x25c6da63c38de1b0,
+ 0x8d590723948a535f,0x579c487e5a38ad0e,
+ 0xb0af48ec79ace837,0x2d835a9df0c6d851,
+ 0xdcdb1b2798182244,0xf8e431456cf88e65,
+ 0x8a08f0f8bf0f156b,0x1b8e9ecb641b58ff,
+ 0xac8b2d36eed2dac5,0xe272467e3d222f3f,
+ 0xd7adf884aa879177,0x5b0ed81dcc6abb0f,
+ 0x86ccbb52ea94baea,0x98e947129fc2b4e9,
+ 0xa87fea27a539e9a5,0x3f2398d747b36224,
+ 0xd29fe4b18e88640e,0x8eec7f0d19a03aad,
+ 0x83a3eeeef9153e89,0x1953cf68300424ac,
+ 0xa48ceaaab75a8e2b,0x5fa8c3423c052dd7,
+ 0xcdb02555653131b6,0x3792f412cb06794d,
+ 0x808e17555f3ebf11,0xe2bbd88bbee40bd0,
+ 0xa0b19d2ab70e6ed6,0x5b6aceaeae9d0ec4,
+ 0xc8de047564d20a8b,0xf245825a5a445275,
+ 0xfb158592be068d2e,0xeed6e2f0f0d56712,
+ 0x9ced737bb6c4183d,0x55464dd69685606b,
+ 0xc428d05aa4751e4c,0xaa97e14c3c26b886,
+ 0xf53304714d9265df,0xd53dd99f4b3066a8,
+ 0x993fe2c6d07b7fab,0xe546a8038efe4029,
+ 0xbf8fdb78849a5f96,0xde98520472bdd033,
+ 0xef73d256a5c0f77c,0x963e66858f6d4440,
+ 0x95a8637627989aad,0xdde7001379a44aa8,
+ 0xbb127c53b17ec159,0x5560c018580d5d52,
+ 0xe9d71b689dde71af,0xaab8f01e6e10b4a6,
+ 0x9226712162ab070d,0xcab3961304ca70e8,
+ 0xb6b00d69bb55c8d1,0x3d607b97c5fd0d22,
+ 0xe45c10c42a2b3b05,0x8cb89a7db77c506a,
+ 0x8eb98a7a9a5b04e3,0x77f3608e92adb242,
+ 0xb267ed1940f1c61c,0x55f038b237591ed3,
+ 0xdf01e85f912e37a3,0x6b6c46dec52f6688,
+ 0x8b61313bbabce2c6,0x2323ac4b3b3da015,
+ 0xae397d8aa96c1b77,0xabec975e0a0d081a,
+ 0xd9c7dced53c72255,0x96e7bd358c904a21,
+ 0x881cea14545c7575,0x7e50d64177da2e54,
+ 0xaa242499697392d2,0xdde50bd1d5d0b9e9,
+ 0xd4ad2dbfc3d07787,0x955e4ec64b44e864,
+ 0x84ec3c97da624ab4,0xbd5af13bef0b113e,
+ 0xa6274bbdd0fadd61,0xecb1ad8aeacdd58e,
+ 0xcfb11ead453994ba,0x67de18eda5814af2,
+ 0x81ceb32c4b43fcf4,0x80eacf948770ced7,
+ 0xa2425ff75e14fc31,0xa1258379a94d028d,
+ 0xcad2f7f5359a3b3e,0x96ee45813a04330,
+ 0xfd87b5f28300ca0d,0x8bca9d6e188853fc,
+ 0x9e74d1b791e07e48,0x775ea264cf55347e,
+ 0xc612062576589dda,0x95364afe032a819e,
+ 0xf79687aed3eec551,0x3a83ddbd83f52205,
+ 0x9abe14cd44753b52,0xc4926a9672793543,
+ 0xc16d9a0095928a27,0x75b7053c0f178294,
+ 0xf1c90080baf72cb1,0x5324c68b12dd6339,
+ 0x971da05074da7bee,0xd3f6fc16ebca5e04,
+ 0xbce5086492111aea,0x88f4bb1ca6bcf585,
+ 0xec1e4a7db69561a5,0x2b31e9e3d06c32e6,
+ 0x9392ee8e921d5d07,0x3aff322e62439fd0,
+ 0xb877aa3236a4b449,0x9befeb9fad487c3,
+ 0xe69594bec44de15b,0x4c2ebe687989a9b4,
+ 0x901d7cf73ab0acd9,0xf9d37014bf60a11,
+ 0xb424dc35095cd80f,0x538484c19ef38c95,
+ 0xe12e13424bb40e13,0x2865a5f206b06fba,
+ 0x8cbccc096f5088cb,0xf93f87b7442e45d4,
+ 0xafebff0bcb24aafe,0xf78f69a51539d749,
+ 0xdbe6fecebdedd5be,0xb573440e5a884d1c,
+ 0x89705f4136b4a597,0x31680a88f8953031,
+ 0xabcc77118461cefc,0xfdc20d2b36ba7c3e,
+ 0xd6bf94d5e57a42bc,0x3d32907604691b4d,
+ 0x8637bd05af6c69b5,0xa63f9a49c2c1b110,
+ 0xa7c5ac471b478423,0xfcf80dc33721d54,
+ 0xd1b71758e219652b,0xd3c36113404ea4a9,
+ 0x83126e978d4fdf3b,0x645a1cac083126ea,
+ 0xa3d70a3d70a3d70a,0x3d70a3d70a3d70a4,
+ 0xcccccccccccccccc,0xcccccccccccccccd,
+ 0x8000000000000000,0x0,
+ 0xa000000000000000,0x0,
+ 0xc800000000000000,0x0,
+ 0xfa00000000000000,0x0,
+ 0x9c40000000000000,0x0,
+ 0xc350000000000000,0x0,
+ 0xf424000000000000,0x0,
+ 0x9896800000000000,0x0,
+ 0xbebc200000000000,0x0,
+ 0xee6b280000000000,0x0,
+ 0x9502f90000000000,0x0,
+ 0xba43b74000000000,0x0,
+ 0xe8d4a51000000000,0x0,
+ 0x9184e72a00000000,0x0,
+ 0xb5e620f480000000,0x0,
+ 0xe35fa931a0000000,0x0,
+ 0x8e1bc9bf04000000,0x0,
+ 0xb1a2bc2ec5000000,0x0,
+ 0xde0b6b3a76400000,0x0,
+ 0x8ac7230489e80000,0x0,
+ 0xad78ebc5ac620000,0x0,
+ 0xd8d726b7177a8000,0x0,
+ 0x878678326eac9000,0x0,
+ 0xa968163f0a57b400,0x0,
+ 0xd3c21bcecceda100,0x0,
+ 0x84595161401484a0,0x0,
+ 0xa56fa5b99019a5c8,0x0,
+ 0xcecb8f27f4200f3a,0x0,
+ 0x813f3978f8940984,0x4000000000000000,
+ 0xa18f07d736b90be5,0x5000000000000000,
+ 0xc9f2c9cd04674ede,0xa400000000000000,
+ 0xfc6f7c4045812296,0x4d00000000000000,
+ 0x9dc5ada82b70b59d,0xf020000000000000,
+ 0xc5371912364ce305,0x6c28000000000000,
+ 0xf684df56c3e01bc6,0xc732000000000000,
+ 0x9a130b963a6c115c,0x3c7f400000000000,
+ 0xc097ce7bc90715b3,0x4b9f100000000000,
+ 0xf0bdc21abb48db20,0x1e86d40000000000,
+ 0x96769950b50d88f4,0x1314448000000000,
+ 0xbc143fa4e250eb31,0x17d955a000000000,
+ 0xeb194f8e1ae525fd,0x5dcfab0800000000,
+ 0x92efd1b8d0cf37be,0x5aa1cae500000000,
+ 0xb7abc627050305ad,0xf14a3d9e40000000,
+ 0xe596b7b0c643c719,0x6d9ccd05d0000000,
+ 0x8f7e32ce7bea5c6f,0xe4820023a2000000,
+ 0xb35dbf821ae4f38b,0xdda2802c8a800000,
+ 0xe0352f62a19e306e,0xd50b2037ad200000,
+ 0x8c213d9da502de45,0x4526f422cc340000,
+ 0xaf298d050e4395d6,0x9670b12b7f410000,
+ 0xdaf3f04651d47b4c,0x3c0cdd765f114000,
+ 0x88d8762bf324cd0f,0xa5880a69fb6ac800,
+ 0xab0e93b6efee0053,0x8eea0d047a457a00,
+ 0xd5d238a4abe98068,0x72a4904598d6d880,
+ 0x85a36366eb71f041,0x47a6da2b7f864750,
+ 0xa70c3c40a64e6c51,0x999090b65f67d924,
+ 0xd0cf4b50cfe20765,0xfff4b4e3f741cf6d,
+ 0x82818f1281ed449f,0xbff8f10e7a8921a4,
+ 0xa321f2d7226895c7,0xaff72d52192b6a0d,
+ 0xcbea6f8ceb02bb39,0x9bf4f8a69f764490,
+ 0xfee50b7025c36a08,0x2f236d04753d5b4,
+ 0x9f4f2726179a2245,0x1d762422c946590,
+ 0xc722f0ef9d80aad6,0x424d3ad2b7b97ef5,
+ 0xf8ebad2b84e0d58b,0xd2e0898765a7deb2,
+ 0x9b934c3b330c8577,0x63cc55f49f88eb2f,
+ 0xc2781f49ffcfa6d5,0x3cbf6b71c76b25fb,
+ 0xf316271c7fc3908a,0x8bef464e3945ef7a,
+ 0x97edd871cfda3a56,0x97758bf0e3cbb5ac,
+ 0xbde94e8e43d0c8ec,0x3d52eeed1cbea317,
+ 0xed63a231d4c4fb27,0x4ca7aaa863ee4bdd,
+ 0x945e455f24fb1cf8,0x8fe8caa93e74ef6a,
+ 0xb975d6b6ee39e436,0xb3e2fd538e122b44,
+ 0xe7d34c64a9c85d44,0x60dbbca87196b616,
+ 0x90e40fbeea1d3a4a,0xbc8955e946fe31cd,
+ 0xb51d13aea4a488dd,0x6babab6398bdbe41,
+ 0xe264589a4dcdab14,0xc696963c7eed2dd1,
+ 0x8d7eb76070a08aec,0xfc1e1de5cf543ca2,
+ 0xb0de65388cc8ada8,0x3b25a55f43294bcb,
+ 0xdd15fe86affad912,0x49ef0eb713f39ebe,
+ 0x8a2dbf142dfcc7ab,0x6e3569326c784337,
+ 0xacb92ed9397bf996,0x49c2c37f07965404,
+ 0xd7e77a8f87daf7fb,0xdc33745ec97be906,
+ 0x86f0ac99b4e8dafd,0x69a028bb3ded71a3,
+ 0xa8acd7c0222311bc,0xc40832ea0d68ce0c,
+ 0xd2d80db02aabd62b,0xf50a3fa490c30190,
+ 0x83c7088e1aab65db,0x792667c6da79e0fa,
+ 0xa4b8cab1a1563f52,0x577001b891185938,
+ 0xcde6fd5e09abcf26,0xed4c0226b55e6f86,
+ 0x80b05e5ac60b6178,0x544f8158315b05b4,
+ 0xa0dc75f1778e39d6,0x696361ae3db1c721,
+ 0xc913936dd571c84c,0x3bc3a19cd1e38e9,
+ 0xfb5878494ace3a5f,0x4ab48a04065c723,
+ 0x9d174b2dcec0e47b,0x62eb0d64283f9c76,
+ 0xc45d1df942711d9a,0x3ba5d0bd324f8394,
+ 0xf5746577930d6500,0xca8f44ec7ee36479,
+ 0x9968bf6abbe85f20,0x7e998b13cf4e1ecb,
+ 0xbfc2ef456ae276e8,0x9e3fedd8c321a67e,
+ 0xefb3ab16c59b14a2,0xc5cfe94ef3ea101e,
+ 0x95d04aee3b80ece5,0xbba1f1d158724a12,
+ 0xbb445da9ca61281f,0x2a8a6e45ae8edc97,
+ 0xea1575143cf97226,0xf52d09d71a3293bd,
+ 0x924d692ca61be758,0x593c2626705f9c56,
+ 0xb6e0c377cfa2e12e,0x6f8b2fb00c77836c,
+ 0xe498f455c38b997a,0xb6dfb9c0f956447,
+ 0x8edf98b59a373fec,0x4724bd4189bd5eac,
+ 0xb2977ee300c50fe7,0x58edec91ec2cb657,
+ 0xdf3d5e9bc0f653e1,0x2f2967b66737e3ed,
+ 0x8b865b215899f46c,0xbd79e0d20082ee74,
+ 0xae67f1e9aec07187,0xecd8590680a3aa11,
+ 0xda01ee641a708de9,0xe80e6f4820cc9495,
+ 0x884134fe908658b2,0x3109058d147fdcdd,
+ 0xaa51823e34a7eede,0xbd4b46f0599fd415,
+ 0xd4e5e2cdc1d1ea96,0x6c9e18ac7007c91a,
+ 0x850fadc09923329e,0x3e2cf6bc604ddb0,
+ 0xa6539930bf6bff45,0x84db8346b786151c,
+ 0xcfe87f7cef46ff16,0xe612641865679a63,
+ 0x81f14fae158c5f6e,0x4fcb7e8f3f60c07e,
+ 0xa26da3999aef7749,0xe3be5e330f38f09d,
+ 0xcb090c8001ab551c,0x5cadf5bfd3072cc5,
+ 0xfdcb4fa002162a63,0x73d9732fc7c8f7f6,
+ 0x9e9f11c4014dda7e,0x2867e7fddcdd9afa,
+ 0xc646d63501a1511d,0xb281e1fd541501b8,
+ 0xf7d88bc24209a565,0x1f225a7ca91a4226,
+ 0x9ae757596946075f,0x3375788de9b06958,
+ 0xc1a12d2fc3978937,0x52d6b1641c83ae,
+ 0xf209787bb47d6b84,0xc0678c5dbd23a49a,
+ 0x9745eb4d50ce6332,0xf840b7ba963646e0,
+ 0xbd176620a501fbff,0xb650e5a93bc3d898,
+ 0xec5d3fa8ce427aff,0xa3e51f138ab4cebe,
+ 0x93ba47c980e98cdf,0xc66f336c36b10137,
+ 0xb8a8d9bbe123f017,0xb80b0047445d4184,
+ 0xe6d3102ad96cec1d,0xa60dc059157491e5,
+ 0x9043ea1ac7e41392,0x87c89837ad68db2f,
+ 0xb454e4a179dd1877,0x29babe4598c311fb,
+ 0xe16a1dc9d8545e94,0xf4296dd6fef3d67a,
+ 0x8ce2529e2734bb1d,0x1899e4a65f58660c,
+ 0xb01ae745b101e9e4,0x5ec05dcff72e7f8f,
+ 0xdc21a1171d42645d,0x76707543f4fa1f73,
+ 0x899504ae72497eba,0x6a06494a791c53a8,
+ 0xabfa45da0edbde69,0x487db9d17636892,
+ 0xd6f8d7509292d603,0x45a9d2845d3c42b6,
+ 0x865b86925b9bc5c2,0xb8a2392ba45a9b2,
+ 0xa7f26836f282b732,0x8e6cac7768d7141e,
+ 0xd1ef0244af2364ff,0x3207d795430cd926,
+ 0x8335616aed761f1f,0x7f44e6bd49e807b8,
+ 0xa402b9c5a8d3a6e7,0x5f16206c9c6209a6,
+ 0xcd036837130890a1,0x36dba887c37a8c0f,
+ 0x802221226be55a64,0xc2494954da2c9789,
+ 0xa02aa96b06deb0fd,0xf2db9baa10b7bd6c,
+ 0xc83553c5c8965d3d,0x6f92829494e5acc7,
+ 0xfa42a8b73abbf48c,0xcb772339ba1f17f9,
+ 0x9c69a97284b578d7,0xff2a760414536efb,
+ 0xc38413cf25e2d70d,0xfef5138519684aba,
+ 0xf46518c2ef5b8cd1,0x7eb258665fc25d69,
+ 0x98bf2f79d5993802,0xef2f773ffbd97a61,
+ 0xbeeefb584aff8603,0xaafb550ffacfd8fa,
+ 0xeeaaba2e5dbf6784,0x95ba2a53f983cf38,
+ 0x952ab45cfa97a0b2,0xdd945a747bf26183,
+ 0xba756174393d88df,0x94f971119aeef9e4,
+ 0xe912b9d1478ceb17,0x7a37cd5601aab85d,
+ 0x91abb422ccb812ee,0xac62e055c10ab33a,
+ 0xb616a12b7fe617aa,0x577b986b314d6009,
+ 0xe39c49765fdf9d94,0xed5a7e85fda0b80b,
+ 0x8e41ade9fbebc27d,0x14588f13be847307,
+ 0xb1d219647ae6b31c,0x596eb2d8ae258fc8,
+ 0xde469fbd99a05fe3,0x6fca5f8ed9aef3bb,
+ 0x8aec23d680043bee,0x25de7bb9480d5854,
+ 0xada72ccc20054ae9,0xaf561aa79a10ae6a,
+ 0xd910f7ff28069da4,0x1b2ba1518094da04,
+ 0x87aa9aff79042286,0x90fb44d2f05d0842,
+ 0xa99541bf57452b28,0x353a1607ac744a53,
+ 0xd3fa922f2d1675f2,0x42889b8997915ce8,
+ 0x847c9b5d7c2e09b7,0x69956135febada11,
+ 0xa59bc234db398c25,0x43fab9837e699095,
+ 0xcf02b2c21207ef2e,0x94f967e45e03f4bb,
+ 0x8161afb94b44f57d,0x1d1be0eebac278f5,
+ 0xa1ba1ba79e1632dc,0x6462d92a69731732,
+ 0xca28a291859bbf93,0x7d7b8f7503cfdcfe,
+ 0xfcb2cb35e702af78,0x5cda735244c3d43e,
+ 0x9defbf01b061adab,0x3a0888136afa64a7,
+ 0xc56baec21c7a1916,0x88aaa1845b8fdd0,
+ 0xf6c69a72a3989f5b,0x8aad549e57273d45,
+ 0x9a3c2087a63f6399,0x36ac54e2f678864b,
+ 0xc0cb28a98fcf3c7f,0x84576a1bb416a7dd,
+ 0xf0fdf2d3f3c30b9f,0x656d44a2a11c51d5,
+ 0x969eb7c47859e743,0x9f644ae5a4b1b325,
+ 0xbc4665b596706114,0x873d5d9f0dde1fee,
+ 0xeb57ff22fc0c7959,0xa90cb506d155a7ea,
+ 0x9316ff75dd87cbd8,0x9a7f12442d588f2,
+ 0xb7dcbf5354e9bece,0xc11ed6d538aeb2f,
+ 0xe5d3ef282a242e81,0x8f1668c8a86da5fa,
+ 0x8fa475791a569d10,0xf96e017d694487bc,
+ 0xb38d92d760ec4455,0x37c981dcc395a9ac,
+ 0xe070f78d3927556a,0x85bbe253f47b1417,
+ 0x8c469ab843b89562,0x93956d7478ccec8e,
+ 0xaf58416654a6babb,0x387ac8d1970027b2,
+ 0xdb2e51bfe9d0696a,0x6997b05fcc0319e,
+ 0x88fcf317f22241e2,0x441fece3bdf81f03,
+ 0xab3c2fddeeaad25a,0xd527e81cad7626c3,
+ 0xd60b3bd56a5586f1,0x8a71e223d8d3b074,
+ 0x85c7056562757456,0xf6872d5667844e49,
+ 0xa738c6bebb12d16c,0xb428f8ac016561db,
+ 0xd106f86e69d785c7,0xe13336d701beba52,
+ 0x82a45b450226b39c,0xecc0024661173473,
+ 0xa34d721642b06084,0x27f002d7f95d0190,
+ 0xcc20ce9bd35c78a5,0x31ec038df7b441f4,
+ 0xff290242c83396ce,0x7e67047175a15271,
+ 0x9f79a169bd203e41,0xf0062c6e984d386,
+ 0xc75809c42c684dd1,0x52c07b78a3e60868,
+ 0xf92e0c3537826145,0xa7709a56ccdf8a82,
+ 0x9bbcc7a142b17ccb,0x88a66076400bb691,
+ 0xc2abf989935ddbfe,0x6acff893d00ea435,
+ 0xf356f7ebf83552fe,0x583f6b8c4124d43,
+ 0x98165af37b2153de,0xc3727a337a8b704a,
+ 0xbe1bf1b059e9a8d6,0x744f18c0592e4c5c,
+ 0xeda2ee1c7064130c,0x1162def06f79df73,
+ 0x9485d4d1c63e8be7,0x8addcb5645ac2ba8,
+ 0xb9a74a0637ce2ee1,0x6d953e2bd7173692,
+ 0xe8111c87c5c1ba99,0xc8fa8db6ccdd0437,
+ 0x910ab1d4db9914a0,0x1d9c9892400a22a2,
+ 0xb54d5e4a127f59c8,0x2503beb6d00cab4b,
+ 0xe2a0b5dc971f303a,0x2e44ae64840fd61d,
+ 0x8da471a9de737e24,0x5ceaecfed289e5d2,
+ 0xb10d8e1456105dad,0x7425a83e872c5f47,
+ 0xdd50f1996b947518,0xd12f124e28f77719,
+ 0x8a5296ffe33cc92f,0x82bd6b70d99aaa6f,
+ 0xace73cbfdc0bfb7b,0x636cc64d1001550b,
+ 0xd8210befd30efa5a,0x3c47f7e05401aa4e,
+ 0x8714a775e3e95c78,0x65acfaec34810a71,
+ 0xa8d9d1535ce3b396,0x7f1839a741a14d0d,
+ 0xd31045a8341ca07c,0x1ede48111209a050,
+ 0x83ea2b892091e44d,0x934aed0aab460432,
+ 0xa4e4b66b68b65d60,0xf81da84d5617853f,
+ 0xce1de40642e3f4b9,0x36251260ab9d668e,
+ 0x80d2ae83e9ce78f3,0xc1d72b7c6b426019,
+ 0xa1075a24e4421730,0xb24cf65b8612f81f,
+ 0xc94930ae1d529cfc,0xdee033f26797b627,
+ 0xfb9b7cd9a4a7443c,0x169840ef017da3b1,
+ 0x9d412e0806e88aa5,0x8e1f289560ee864e,
+ 0xc491798a08a2ad4e,0xf1a6f2bab92a27e2,
+ 0xf5b5d7ec8acb58a2,0xae10af696774b1db,
+ 0x9991a6f3d6bf1765,0xacca6da1e0a8ef29,
+ 0xbff610b0cc6edd3f,0x17fd090a58d32af3,
+ 0xeff394dcff8a948e,0xddfc4b4cef07f5b0,
+ 0x95f83d0a1fb69cd9,0x4abdaf101564f98e,
+ 0xbb764c4ca7a4440f,0x9d6d1ad41abe37f1,
+ 0xea53df5fd18d5513,0x84c86189216dc5ed,
+ 0x92746b9be2f8552c,0x32fd3cf5b4e49bb4,
+ 0xb7118682dbb66a77,0x3fbc8c33221dc2a1,
+ 0xe4d5e82392a40515,0xfabaf3feaa5334a,
+ 0x8f05b1163ba6832d,0x29cb4d87f2a7400e,
+ 0xb2c71d5bca9023f8,0x743e20e9ef511012,
+ 0xdf78e4b2bd342cf6,0x914da9246b255416,
+ 0x8bab8eefb6409c1a,0x1ad089b6c2f7548e,
+ 0xae9672aba3d0c320,0xa184ac2473b529b1,
+ 0xda3c0f568cc4f3e8,0xc9e5d72d90a2741e,
+ 0x8865899617fb1871,0x7e2fa67c7a658892,
+ 0xaa7eebfb9df9de8d,0xddbb901b98feeab7,
+ 0xd51ea6fa85785631,0x552a74227f3ea565,
+ 0x8533285c936b35de,0xd53a88958f87275f,
+ 0xa67ff273b8460356,0x8a892abaf368f137,
+ 0xd01fef10a657842c,0x2d2b7569b0432d85,
+ 0x8213f56a67f6b29b,0x9c3b29620e29fc73,
+ 0xa298f2c501f45f42,0x8349f3ba91b47b8f,
+ 0xcb3f2f7642717713,0x241c70a936219a73,
+ 0xfe0efb53d30dd4d7,0xed238cd383aa0110,
+ 0x9ec95d1463e8a506,0xf4363804324a40aa,
+ 0xc67bb4597ce2ce48,0xb143c6053edcd0d5,
+ 0xf81aa16fdc1b81da,0xdd94b7868e94050a,
+ 0x9b10a4e5e9913128,0xca7cf2b4191c8326,
+ 0xc1d4ce1f63f57d72,0xfd1c2f611f63a3f0,
+ 0xf24a01a73cf2dccf,0xbc633b39673c8cec,
+ 0x976e41088617ca01,0xd5be0503e085d813,
+ 0xbd49d14aa79dbc82,0x4b2d8644d8a74e18,
+ 0xec9c459d51852ba2,0xddf8e7d60ed1219e,
+ 0x93e1ab8252f33b45,0xcabb90e5c942b503,
+ 0xb8da1662e7b00a17,0x3d6a751f3b936243,
+ 0xe7109bfba19c0c9d,0xcc512670a783ad4,
+ 0x906a617d450187e2,0x27fb2b80668b24c5,
+ 0xb484f9dc9641e9da,0xb1f9f660802dedf6,
+ 0xe1a63853bbd26451,0x5e7873f8a0396973,
+ 0x8d07e33455637eb2,0xdb0b487b6423e1e8,
+ 0xb049dc016abc5e5f,0x91ce1a9a3d2cda62,
+ 0xdc5c5301c56b75f7,0x7641a140cc7810fb,
+ 0x89b9b3e11b6329ba,0xa9e904c87fcb0a9d,
+ 0xac2820d9623bf429,0x546345fa9fbdcd44,
+ 0xd732290fbacaf133,0xa97c177947ad4095,
+ 0x867f59a9d4bed6c0,0x49ed8eabcccc485d,
+ 0xa81f301449ee8c70,0x5c68f256bfff5a74,
+ 0xd226fc195c6a2f8c,0x73832eec6fff3111,
+ 0x83585d8fd9c25db7,0xc831fd53c5ff7eab,
+ 0xa42e74f3d032f525,0xba3e7ca8b77f5e55,
+ 0xcd3a1230c43fb26f,0x28ce1bd2e55f35eb,
+ 0x80444b5e7aa7cf85,0x7980d163cf5b81b3,
+ 0xa0555e361951c366,0xd7e105bcc332621f,
+ 0xc86ab5c39fa63440,0x8dd9472bf3fefaa7,
+ 0xfa856334878fc150,0xb14f98f6f0feb951,
+ 0x9c935e00d4b9d8d2,0x6ed1bf9a569f33d3,
+ 0xc3b8358109e84f07,0xa862f80ec4700c8,
+ 0xf4a642e14c6262c8,0xcd27bb612758c0fa,
+ 0x98e7e9cccfbd7dbd,0x8038d51cb897789c,
+ 0xbf21e44003acdd2c,0xe0470a63e6bd56c3,
+ 0xeeea5d5004981478,0x1858ccfce06cac74,
+ 0x95527a5202df0ccb,0xf37801e0c43ebc8,
+ 0xbaa718e68396cffd,0xd30560258f54e6ba,
+ 0xe950df20247c83fd,0x47c6b82ef32a2069,
+ 0x91d28b7416cdd27e,0x4cdc331d57fa5441,
+ 0xb6472e511c81471d,0xe0133fe4adf8e952,
+ 0xe3d8f9e563a198e5,0x58180fddd97723a6,
+ 0x8e679c2f5e44ff8f,0x570f09eaa7ea7648,};
+
+}
+} // namespace arrow_vendored
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/fast_float/float_common.h b/src/arrow/cpp/src/arrow/vendored/fast_float/float_common.h
new file mode 100644
index 000000000..f7b7662b9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/fast_float/float_common.h
@@ -0,0 +1,345 @@
+#ifndef FASTFLOAT_FLOAT_COMMON_H
+#define FASTFLOAT_FLOAT_COMMON_H
+
+#include <cfloat>
+#include <cstdint>
+#include <cassert>
+
+#if (defined(__x86_64) || defined(__x86_64__) || defined(_M_X64) \
+ || defined(__amd64) || defined(__aarch64__) || defined(_M_ARM64) \
+ || defined(__MINGW64__) \
+ || defined(__s390x__) \
+ || (defined(__ppc64__) || defined(__PPC64__) || defined(__ppc64le__) || defined(__PPC64LE__)) \
+ || defined(__EMSCRIPTEN__))
+#define FASTFLOAT_64BIT
+#elif (defined(__i386) || defined(__i386__) || defined(_M_IX86) \
+ || defined(__arm__) \
+ || defined(__MINGW32__))
+#define FASTFLOAT_32BIT
+#else
+#error Unknown platform (not 32-bit, not 64-bit?)
+#endif
+
+#if ((defined(_WIN32) || defined(_WIN64)) && !defined(__clang__))
+#include <intrin.h>
+#endif
+
+#if defined(_MSC_VER) && !defined(__clang__)
+#define FASTFLOAT_VISUAL_STUDIO 1
+#endif
+
+#ifdef _WIN32
+#define FASTFLOAT_IS_BIG_ENDIAN 0
+#else
+#if defined(__APPLE__) || defined(__FreeBSD__)
+#include <machine/endian.h>
+#elif defined(sun) || defined(__sun)
+#include <sys/byteorder.h>
+#else
+#include <endian.h>
+#endif
+#
+#ifndef __BYTE_ORDER__
+// safe choice
+#define FASTFLOAT_IS_BIG_ENDIAN 0
+#endif
+#
+#ifndef __ORDER_LITTLE_ENDIAN__
+// safe choice
+#define FASTFLOAT_IS_BIG_ENDIAN 0
+#endif
+#
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+#define FASTFLOAT_IS_BIG_ENDIAN 0
+#else
+#define FASTFLOAT_IS_BIG_ENDIAN 1
+#endif
+#endif
+
+#ifdef FASTFLOAT_VISUAL_STUDIO
+#define fastfloat_really_inline __forceinline
+#else
+#define fastfloat_really_inline inline __attribute__((always_inline))
+#endif
+
+namespace arrow_vendored {
+namespace fast_float {
+
+// Compares two ASCII strings in a case insensitive manner.
+inline bool fastfloat_strncasecmp(const char *input1, const char *input2,
+ size_t length) {
+ char running_diff{0};
+ for (size_t i = 0; i < length; i++) {
+ running_diff |= (input1[i] ^ input2[i]);
+ }
+ return (running_diff == 0) || (running_diff == 32);
+}
+
+#ifndef FLT_EVAL_METHOD
+#error "FLT_EVAL_METHOD should be defined, please include cfloat."
+#endif
+
+inline bool is_space(uint8_t c) {
+ static const bool table[] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+ return table[c];
+}
+
+namespace {
+constexpr uint32_t max_digits = 768;
+constexpr uint32_t max_digit_without_overflow = 19;
+constexpr int32_t decimal_point_range = 2047;
+} // namespace
+
+struct value128 {
+ uint64_t low;
+ uint64_t high;
+ value128(uint64_t _low, uint64_t _high) : low(_low), high(_high) {}
+ value128() : low(0), high(0) {}
+};
+
+/* result might be undefined when input_num is zero */
+fastfloat_really_inline int leading_zeroes(uint64_t input_num) {
+ assert(input_num > 0);
+#ifdef FASTFLOAT_VISUAL_STUDIO
+ #if defined(_M_X64) || defined(_M_ARM64)
+ unsigned long leading_zero = 0;
+ // Search the mask data from most significant bit (MSB)
+ // to least significant bit (LSB) for a set bit (1).
+ _BitScanReverse64(&leading_zero, input_num);
+ return (int)(63 - leading_zero);
+ #else
+ int last_bit = 0;
+ if(input_num & uint64_t(0xffffffff00000000)) input_num >>= 32, last_bit |= 32;
+ if(input_num & uint64_t( 0xffff0000)) input_num >>= 16, last_bit |= 16;
+ if(input_num & uint64_t( 0xff00)) input_num >>= 8, last_bit |= 8;
+ if(input_num & uint64_t( 0xf0)) input_num >>= 4, last_bit |= 4;
+ if(input_num & uint64_t( 0xc)) input_num >>= 2, last_bit |= 2;
+ if(input_num & uint64_t( 0x2)) input_num >>= 1, last_bit |= 1;
+ return 63 - last_bit;
+ #endif
+#else
+ return __builtin_clzll(input_num);
+#endif
+}
+
+#ifdef FASTFLOAT_32BIT
+
+#if (!defined(_WIN32)) || defined(__MINGW32__)
+// slow emulation routine for 32-bit
+fastfloat_really_inline uint64_t __emulu(uint32_t x, uint32_t y) {
+ return x * (uint64_t)y;
+}
+#endif
+
+// slow emulation routine for 32-bit
+#if !defined(__MINGW64__)
+fastfloat_really_inline uint64_t _umul128(uint64_t ab, uint64_t cd,
+ uint64_t *hi) {
+ uint64_t ad = __emulu((uint32_t)(ab >> 32), (uint32_t)cd);
+ uint64_t bd = __emulu((uint32_t)ab, (uint32_t)cd);
+ uint64_t adbc = ad + __emulu((uint32_t)ab, (uint32_t)(cd >> 32));
+ uint64_t adbc_carry = !!(adbc < ad);
+ uint64_t lo = bd + (adbc << 32);
+ *hi = __emulu((uint32_t)(ab >> 32), (uint32_t)(cd >> 32)) + (adbc >> 32) +
+ (adbc_carry << 32) + !!(lo < bd);
+ return lo;
+}
+#endif // !__MINGW64__
+
+#endif // FASTFLOAT_32BIT
+
+
+// compute 64-bit a*b
+fastfloat_really_inline value128 full_multiplication(uint64_t a,
+ uint64_t b) {
+ value128 answer;
+#ifdef _M_ARM64
+ // ARM64 has native support for 64-bit multiplications, no need to emulate
+ answer.high = __umulh(a, b);
+ answer.low = a * b;
+#elif defined(FASTFLOAT_32BIT) || (defined(_WIN64) && !defined(__clang__))
+ answer.low = _umul128(a, b, &answer.high); // _umul128 not available on ARM64
+#elif defined(FASTFLOAT_64BIT)
+ __uint128_t r = ((__uint128_t)a) * b;
+ answer.low = uint64_t(r);
+ answer.high = uint64_t(r >> 64);
+#else
+ #error Not implemented
+#endif
+ return answer;
+}
+
+
+struct adjusted_mantissa {
+ uint64_t mantissa{0};
+ int power2{0}; // a negative value indicates an invalid result
+ adjusted_mantissa() = default;
+ bool operator==(const adjusted_mantissa &o) const {
+ return mantissa == o.mantissa && power2 == o.power2;
+ }
+ bool operator!=(const adjusted_mantissa &o) const {
+ return mantissa != o.mantissa || power2 != o.power2;
+ }
+};
+
+struct decimal {
+ uint32_t num_digits{0};
+ int32_t decimal_point{0};
+ bool negative{false};
+ bool truncated{false};
+ uint8_t digits[max_digits];
+ decimal() = default;
+ // Copies are not allowed since this is a fat object.
+ decimal(const decimal &) = delete;
+ // Copies are not allowed since this is a fat object.
+ decimal &operator=(const decimal &) = delete;
+ // Moves are allowed:
+ decimal(decimal &&) = default;
+ decimal &operator=(decimal &&other) = default;
+};
+
+constexpr static double powers_of_ten_double[] = {
+ 1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11,
+ 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19, 1e20, 1e21, 1e22};
+constexpr static float powers_of_ten_float[] = {1e0, 1e1, 1e2, 1e3, 1e4, 1e5,
+ 1e6, 1e7, 1e8, 1e9, 1e10};
+
+template <typename T> struct binary_format {
+ static constexpr int mantissa_explicit_bits();
+ static constexpr int minimum_exponent();
+ static constexpr int infinite_power();
+ static constexpr int sign_index();
+ static constexpr int min_exponent_fast_path();
+ static constexpr int max_exponent_fast_path();
+ static constexpr int max_exponent_round_to_even();
+ static constexpr int min_exponent_round_to_even();
+ static constexpr uint64_t max_mantissa_fast_path();
+ static constexpr int largest_power_of_ten();
+ static constexpr int smallest_power_of_ten();
+ static constexpr T exact_power_of_ten(int64_t power);
+};
+
+template <> constexpr int binary_format<double>::mantissa_explicit_bits() {
+ return 52;
+}
+template <> constexpr int binary_format<float>::mantissa_explicit_bits() {
+ return 23;
+}
+
+template <> constexpr int binary_format<double>::max_exponent_round_to_even() {
+ return 23;
+}
+
+template <> constexpr int binary_format<float>::max_exponent_round_to_even() {
+ return 10;
+}
+
+template <> constexpr int binary_format<double>::min_exponent_round_to_even() {
+ return -4;
+}
+
+template <> constexpr int binary_format<float>::min_exponent_round_to_even() {
+ return -17;
+}
+
+template <> constexpr int binary_format<double>::minimum_exponent() {
+ return -1023;
+}
+template <> constexpr int binary_format<float>::minimum_exponent() {
+ return -127;
+}
+
+template <> constexpr int binary_format<double>::infinite_power() {
+ return 0x7FF;
+}
+template <> constexpr int binary_format<float>::infinite_power() {
+ return 0xFF;
+}
+
+template <> constexpr int binary_format<double>::sign_index() { return 63; }
+template <> constexpr int binary_format<float>::sign_index() { return 31; }
+
+template <> constexpr int binary_format<double>::min_exponent_fast_path() {
+#if (FLT_EVAL_METHOD != 1) && (FLT_EVAL_METHOD != 0)
+ return 0;
+#else
+ return -22;
+#endif
+}
+template <> constexpr int binary_format<float>::min_exponent_fast_path() {
+#if (FLT_EVAL_METHOD != 1) && (FLT_EVAL_METHOD != 0)
+ return 0;
+#else
+ return -10;
+#endif
+}
+
+template <> constexpr int binary_format<double>::max_exponent_fast_path() {
+ return 22;
+}
+template <> constexpr int binary_format<float>::max_exponent_fast_path() {
+ return 10;
+}
+
+template <> constexpr uint64_t binary_format<double>::max_mantissa_fast_path() {
+ return uint64_t(2) << mantissa_explicit_bits();
+}
+template <> constexpr uint64_t binary_format<float>::max_mantissa_fast_path() {
+ return uint64_t(2) << mantissa_explicit_bits();
+}
+
+template <>
+constexpr double binary_format<double>::exact_power_of_ten(int64_t power) {
+ return powers_of_ten_double[power];
+}
+template <>
+constexpr float binary_format<float>::exact_power_of_ten(int64_t power) {
+
+ return powers_of_ten_float[power];
+}
+
+
+template <>
+constexpr int binary_format<double>::largest_power_of_ten() {
+ return 308;
+}
+template <>
+constexpr int binary_format<float>::largest_power_of_ten() {
+ return 38;
+}
+
+template <>
+constexpr int binary_format<double>::smallest_power_of_ten() {
+ return -342;
+}
+template <>
+constexpr int binary_format<float>::smallest_power_of_ten() {
+ return -65;
+}
+
+} // namespace fast_float
+} // namespace arrow_vendored
+
+// for convenience:
+template<class OStream>
+inline OStream& operator<<(OStream &out, const arrow_vendored::fast_float::decimal &d) {
+ out << "0.";
+ for (size_t i = 0; i < d.num_digits; i++) {
+ out << int32_t(d.digits[i]);
+ }
+ out << " * 10 ** " << d.decimal_point;
+ return out;
+}
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/fast_float/parse_number.h b/src/arrow/cpp/src/arrow/vendored/fast_float/parse_number.h
new file mode 100644
index 000000000..d530f765c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/fast_float/parse_number.h
@@ -0,0 +1,133 @@
+#ifndef FASTFLOAT_PARSE_NUMBER_H
+#define FASTFLOAT_PARSE_NUMBER_H
+#include "ascii_number.h"
+#include "decimal_to_binary.h"
+#include "simple_decimal_conversion.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstring>
+#include <limits>
+#include <system_error>
+
+namespace arrow_vendored {
+namespace fast_float {
+
+
+namespace {
+/**
+ * Special case +inf, -inf, nan, infinity, -infinity.
+ * The case comparisons could be made much faster given that we know that the
+ * strings a null-free and fixed.
+ **/
+template <typename T>
+from_chars_result parse_infnan(const char *first, const char *last, T &value) noexcept {
+ from_chars_result answer;
+ answer.ptr = first;
+ answer.ec = std::errc(); // be optimistic
+ bool minusSign = false;
+ if (*first == '-') { // assume first < last, so dereference without checks
+ minusSign = true;
+ ++first;
+ } else if( *first == '+' ) { // C++17 20.19.3.7 explicitly forbids '+' here, but anyway
+ ++first;
+ }
+ if (last - first >= 3) {
+ if (fastfloat_strncasecmp(first, "nan", 3)) {
+ answer.ptr = (first += 3);
+ value = minusSign ? -std::numeric_limits<T>::quiet_NaN() : std::numeric_limits<T>::quiet_NaN();
+ // Check for possible nan(n-char-seq-opt), C++17 20.19.3.7, C11 7.20.1.3.3. At least MSVC produces nan(ind) and nan(snan).
+ if(first != last && *first == '(') {
+ for(const char* ptr = first + 1; ptr != last; ++ptr) {
+ if (*ptr == ')') {
+ answer.ptr = ptr + 1; // valid nan(n-char-seq-opt)
+ break;
+ }
+ else if(!(('a' <= *ptr && *ptr <= 'z') || ('A' <= *ptr && *ptr <= 'Z') || ('0' <= *ptr && *ptr <= '9') || *ptr == '_'))
+ break; // forbidden char, not nan(n-char-seq-opt)
+ }
+ }
+ return answer;
+ }
+ if (fastfloat_strncasecmp(first, "inf", 3)) {
+ if ((last - first >= 8) && fastfloat_strncasecmp(first + 3, "inity", 5)) {
+ answer.ptr = first + 8;
+ } else {
+ answer.ptr = first + 3;
+ }
+ value = minusSign ? -std::numeric_limits<T>::infinity() : std::numeric_limits<T>::infinity();
+ return answer;
+ }
+ }
+ answer.ec = std::errc::invalid_argument;
+ return answer;
+}
+
+template<typename T>
+fastfloat_really_inline void to_float(bool negative, adjusted_mantissa am, T &value) {
+ uint64_t word = am.mantissa;
+ word |= uint64_t(am.power2) << binary_format<T>::mantissa_explicit_bits();
+ word = negative
+ ? word | (uint64_t(1) << binary_format<T>::sign_index()) : word;
+#if FASTFLOAT_IS_BIG_ENDIAN == 1
+ if (std::is_same<T, float>::value) {
+ ::memcpy(&value, (char *)&word + 4, sizeof(T)); // extract value at offset 4-7 if float on big-endian
+ } else {
+ ::memcpy(&value, &word, sizeof(T));
+ }
+#else
+ // For little-endian systems:
+ ::memcpy(&value, &word, sizeof(T));
+#endif
+}
+
+} // namespace
+
+
+
+template<typename T>
+from_chars_result from_chars(const char *first, const char *last,
+ T &value, chars_format fmt /*= chars_format::general*/) noexcept {
+ static_assert (std::is_same<T, double>::value || std::is_same<T, float>::value, "only float and double are supported");
+
+
+ from_chars_result answer;
+ while ((first != last) && fast_float::is_space(uint8_t(*first))) {
+ first++;
+ }
+ if (first == last) {
+ answer.ec = std::errc::invalid_argument;
+ answer.ptr = first;
+ return answer;
+ }
+ parsed_number_string pns = parse_number_string(first, last, fmt);
+ if (!pns.valid) {
+ return parse_infnan(first, last, value);
+ }
+ answer.ec = std::errc(); // be optimistic
+ answer.ptr = pns.lastmatch;
+ // Next is Clinger's fast path.
+ if (binary_format<T>::min_exponent_fast_path() <= pns.exponent && pns.exponent <= binary_format<T>::max_exponent_fast_path() && pns.mantissa <=binary_format<T>::max_mantissa_fast_path() && !pns.too_many_digits) {
+ value = T(pns.mantissa);
+ if (pns.exponent < 0) { value = value / binary_format<T>::exact_power_of_ten(-pns.exponent); }
+ else { value = value * binary_format<T>::exact_power_of_ten(pns.exponent); }
+ if (pns.negative) { value = -value; }
+ return answer;
+ }
+ adjusted_mantissa am = compute_float<binary_format<T>>(pns.exponent, pns.mantissa);
+ if(pns.too_many_digits) {
+ if(am != compute_float<binary_format<T>>(pns.exponent, pns.mantissa + 1)) {
+ am.power2 = -1; // value is invalid.
+ }
+ }
+ // If we called compute_float<binary_format<T>>(pns.exponent, pns.mantissa) and we have an invalid power (am.power2 < 0),
+ // then we need to go the long way around again. This is very uncommon.
+ if(am.power2 < 0) { am = parse_long_mantissa<binary_format<T>>(first,last); }
+ to_float(pns.negative, am, value);
+ return answer;
+}
+
+} // namespace fast_float
+} // namespace arrow_vendored
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/fast_float/simple_decimal_conversion.h b/src/arrow/cpp/src/arrow/vendored/fast_float/simple_decimal_conversion.h
new file mode 100644
index 000000000..486724cd9
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/fast_float/simple_decimal_conversion.h
@@ -0,0 +1,362 @@
+#ifndef FASTFLOAT_GENERIC_DECIMAL_TO_BINARY_H
+#define FASTFLOAT_GENERIC_DECIMAL_TO_BINARY_H
+
+/**
+ * This code is meant to handle the case where we have more than 19 digits.
+ *
+ * It is based on work by Nigel Tao (at https://github.com/google/wuffs/)
+ * who credits Ken Thompson for the design (via a reference to the Go source
+ * code).
+ *
+ * Rob Pike suggested that this algorithm be called "Simple Decimal Conversion".
+ *
+ * It is probably not very fast but it is a fallback that should almost never
+ * be used in real life. Though it is not fast, it is "easily" understood and debugged.
+ **/
+#include "ascii_number.h"
+#include "decimal_to_binary.h"
+#include <cstdint>
+
+namespace arrow_vendored {
+namespace fast_float {
+
+namespace {
+
+// remove all final zeroes
+inline void trim(decimal &h) {
+ while ((h.num_digits > 0) && (h.digits[h.num_digits - 1] == 0)) {
+ h.num_digits--;
+ }
+}
+
+
+
+uint32_t number_of_digits_decimal_left_shift(const decimal &h, uint32_t shift) {
+ shift &= 63;
+ const static uint16_t number_of_digits_decimal_left_shift_table[65] = {
+ 0x0000, 0x0800, 0x0801, 0x0803, 0x1006, 0x1009, 0x100D, 0x1812, 0x1817,
+ 0x181D, 0x2024, 0x202B, 0x2033, 0x203C, 0x2846, 0x2850, 0x285B, 0x3067,
+ 0x3073, 0x3080, 0x388E, 0x389C, 0x38AB, 0x38BB, 0x40CC, 0x40DD, 0x40EF,
+ 0x4902, 0x4915, 0x4929, 0x513E, 0x5153, 0x5169, 0x5180, 0x5998, 0x59B0,
+ 0x59C9, 0x61E3, 0x61FD, 0x6218, 0x6A34, 0x6A50, 0x6A6D, 0x6A8B, 0x72AA,
+ 0x72C9, 0x72E9, 0x7B0A, 0x7B2B, 0x7B4D, 0x8370, 0x8393, 0x83B7, 0x83DC,
+ 0x8C02, 0x8C28, 0x8C4F, 0x9477, 0x949F, 0x94C8, 0x9CF2, 0x051C, 0x051C,
+ 0x051C, 0x051C,
+ };
+ uint32_t x_a = number_of_digits_decimal_left_shift_table[shift];
+ uint32_t x_b = number_of_digits_decimal_left_shift_table[shift + 1];
+ uint32_t num_new_digits = x_a >> 11;
+ uint32_t pow5_a = 0x7FF & x_a;
+ uint32_t pow5_b = 0x7FF & x_b;
+ const static uint8_t
+ number_of_digits_decimal_left_shift_table_powers_of_5[0x051C] = {
+ 5, 2, 5, 1, 2, 5, 6, 2, 5, 3, 1, 2, 5, 1, 5, 6, 2, 5, 7, 8, 1, 2, 5, 3,
+ 9, 0, 6, 2, 5, 1, 9, 5, 3, 1, 2, 5, 9, 7, 6, 5, 6, 2, 5, 4, 8, 8, 2, 8,
+ 1, 2, 5, 2, 4, 4, 1, 4, 0, 6, 2, 5, 1, 2, 2, 0, 7, 0, 3, 1, 2, 5, 6, 1,
+ 0, 3, 5, 1, 5, 6, 2, 5, 3, 0, 5, 1, 7, 5, 7, 8, 1, 2, 5, 1, 5, 2, 5, 8,
+ 7, 8, 9, 0, 6, 2, 5, 7, 6, 2, 9, 3, 9, 4, 5, 3, 1, 2, 5, 3, 8, 1, 4, 6,
+ 9, 7, 2, 6, 5, 6, 2, 5, 1, 9, 0, 7, 3, 4, 8, 6, 3, 2, 8, 1, 2, 5, 9, 5,
+ 3, 6, 7, 4, 3, 1, 6, 4, 0, 6, 2, 5, 4, 7, 6, 8, 3, 7, 1, 5, 8, 2, 0, 3,
+ 1, 2, 5, 2, 3, 8, 4, 1, 8, 5, 7, 9, 1, 0, 1, 5, 6, 2, 5, 1, 1, 9, 2, 0,
+ 9, 2, 8, 9, 5, 5, 0, 7, 8, 1, 2, 5, 5, 9, 6, 0, 4, 6, 4, 4, 7, 7, 5, 3,
+ 9, 0, 6, 2, 5, 2, 9, 8, 0, 2, 3, 2, 2, 3, 8, 7, 6, 9, 5, 3, 1, 2, 5, 1,
+ 4, 9, 0, 1, 1, 6, 1, 1, 9, 3, 8, 4, 7, 6, 5, 6, 2, 5, 7, 4, 5, 0, 5, 8,
+ 0, 5, 9, 6, 9, 2, 3, 8, 2, 8, 1, 2, 5, 3, 7, 2, 5, 2, 9, 0, 2, 9, 8, 4,
+ 6, 1, 9, 1, 4, 0, 6, 2, 5, 1, 8, 6, 2, 6, 4, 5, 1, 4, 9, 2, 3, 0, 9, 5,
+ 7, 0, 3, 1, 2, 5, 9, 3, 1, 3, 2, 2, 5, 7, 4, 6, 1, 5, 4, 7, 8, 5, 1, 5,
+ 6, 2, 5, 4, 6, 5, 6, 6, 1, 2, 8, 7, 3, 0, 7, 7, 3, 9, 2, 5, 7, 8, 1, 2,
+ 5, 2, 3, 2, 8, 3, 0, 6, 4, 3, 6, 5, 3, 8, 6, 9, 6, 2, 8, 9, 0, 6, 2, 5,
+ 1, 1, 6, 4, 1, 5, 3, 2, 1, 8, 2, 6, 9, 3, 4, 8, 1, 4, 4, 5, 3, 1, 2, 5,
+ 5, 8, 2, 0, 7, 6, 6, 0, 9, 1, 3, 4, 6, 7, 4, 0, 7, 2, 2, 6, 5, 6, 2, 5,
+ 2, 9, 1, 0, 3, 8, 3, 0, 4, 5, 6, 7, 3, 3, 7, 0, 3, 6, 1, 3, 2, 8, 1, 2,
+ 5, 1, 4, 5, 5, 1, 9, 1, 5, 2, 2, 8, 3, 6, 6, 8, 5, 1, 8, 0, 6, 6, 4, 0,
+ 6, 2, 5, 7, 2, 7, 5, 9, 5, 7, 6, 1, 4, 1, 8, 3, 4, 2, 5, 9, 0, 3, 3, 2,
+ 0, 3, 1, 2, 5, 3, 6, 3, 7, 9, 7, 8, 8, 0, 7, 0, 9, 1, 7, 1, 2, 9, 5, 1,
+ 6, 6, 0, 1, 5, 6, 2, 5, 1, 8, 1, 8, 9, 8, 9, 4, 0, 3, 5, 4, 5, 8, 5, 6,
+ 4, 7, 5, 8, 3, 0, 0, 7, 8, 1, 2, 5, 9, 0, 9, 4, 9, 4, 7, 0, 1, 7, 7, 2,
+ 9, 2, 8, 2, 3, 7, 9, 1, 5, 0, 3, 9, 0, 6, 2, 5, 4, 5, 4, 7, 4, 7, 3, 5,
+ 0, 8, 8, 6, 4, 6, 4, 1, 1, 8, 9, 5, 7, 5, 1, 9, 5, 3, 1, 2, 5, 2, 2, 7,
+ 3, 7, 3, 6, 7, 5, 4, 4, 3, 2, 3, 2, 0, 5, 9, 4, 7, 8, 7, 5, 9, 7, 6, 5,
+ 6, 2, 5, 1, 1, 3, 6, 8, 6, 8, 3, 7, 7, 2, 1, 6, 1, 6, 0, 2, 9, 7, 3, 9,
+ 3, 7, 9, 8, 8, 2, 8, 1, 2, 5, 5, 6, 8, 4, 3, 4, 1, 8, 8, 6, 0, 8, 0, 8,
+ 0, 1, 4, 8, 6, 9, 6, 8, 9, 9, 4, 1, 4, 0, 6, 2, 5, 2, 8, 4, 2, 1, 7, 0,
+ 9, 4, 3, 0, 4, 0, 4, 0, 0, 7, 4, 3, 4, 8, 4, 4, 9, 7, 0, 7, 0, 3, 1, 2,
+ 5, 1, 4, 2, 1, 0, 8, 5, 4, 7, 1, 5, 2, 0, 2, 0, 0, 3, 7, 1, 7, 4, 2, 2,
+ 4, 8, 5, 3, 5, 1, 5, 6, 2, 5, 7, 1, 0, 5, 4, 2, 7, 3, 5, 7, 6, 0, 1, 0,
+ 0, 1, 8, 5, 8, 7, 1, 1, 2, 4, 2, 6, 7, 5, 7, 8, 1, 2, 5, 3, 5, 5, 2, 7,
+ 1, 3, 6, 7, 8, 8, 0, 0, 5, 0, 0, 9, 2, 9, 3, 5, 5, 6, 2, 1, 3, 3, 7, 8,
+ 9, 0, 6, 2, 5, 1, 7, 7, 6, 3, 5, 6, 8, 3, 9, 4, 0, 0, 2, 5, 0, 4, 6, 4,
+ 6, 7, 7, 8, 1, 0, 6, 6, 8, 9, 4, 5, 3, 1, 2, 5, 8, 8, 8, 1, 7, 8, 4, 1,
+ 9, 7, 0, 0, 1, 2, 5, 2, 3, 2, 3, 3, 8, 9, 0, 5, 3, 3, 4, 4, 7, 2, 6, 5,
+ 6, 2, 5, 4, 4, 4, 0, 8, 9, 2, 0, 9, 8, 5, 0, 0, 6, 2, 6, 1, 6, 1, 6, 9,
+ 4, 5, 2, 6, 6, 7, 2, 3, 6, 3, 2, 8, 1, 2, 5, 2, 2, 2, 0, 4, 4, 6, 0, 4,
+ 9, 2, 5, 0, 3, 1, 3, 0, 8, 0, 8, 4, 7, 2, 6, 3, 3, 3, 6, 1, 8, 1, 6, 4,
+ 0, 6, 2, 5, 1, 1, 1, 0, 2, 2, 3, 0, 2, 4, 6, 2, 5, 1, 5, 6, 5, 4, 0, 4,
+ 2, 3, 6, 3, 1, 6, 6, 8, 0, 9, 0, 8, 2, 0, 3, 1, 2, 5, 5, 5, 5, 1, 1, 1,
+ 5, 1, 2, 3, 1, 2, 5, 7, 8, 2, 7, 0, 2, 1, 1, 8, 1, 5, 8, 3, 4, 0, 4, 5,
+ 4, 1, 0, 1, 5, 6, 2, 5, 2, 7, 7, 5, 5, 5, 7, 5, 6, 1, 5, 6, 2, 8, 9, 1,
+ 3, 5, 1, 0, 5, 9, 0, 7, 9, 1, 7, 0, 2, 2, 7, 0, 5, 0, 7, 8, 1, 2, 5, 1,
+ 3, 8, 7, 7, 7, 8, 7, 8, 0, 7, 8, 1, 4, 4, 5, 6, 7, 5, 5, 2, 9, 5, 3, 9,
+ 5, 8, 5, 1, 1, 3, 5, 2, 5, 3, 9, 0, 6, 2, 5, 6, 9, 3, 8, 8, 9, 3, 9, 0,
+ 3, 9, 0, 7, 2, 2, 8, 3, 7, 7, 6, 4, 7, 6, 9, 7, 9, 2, 5, 5, 6, 7, 6, 2,
+ 6, 9, 5, 3, 1, 2, 5, 3, 4, 6, 9, 4, 4, 6, 9, 5, 1, 9, 5, 3, 6, 1, 4, 1,
+ 8, 8, 8, 2, 3, 8, 4, 8, 9, 6, 2, 7, 8, 3, 8, 1, 3, 4, 7, 6, 5, 6, 2, 5,
+ 1, 7, 3, 4, 7, 2, 3, 4, 7, 5, 9, 7, 6, 8, 0, 7, 0, 9, 4, 4, 1, 1, 9, 2,
+ 4, 4, 8, 1, 3, 9, 1, 9, 0, 6, 7, 3, 8, 2, 8, 1, 2, 5, 8, 6, 7, 3, 6, 1,
+ 7, 3, 7, 9, 8, 8, 4, 0, 3, 5, 4, 7, 2, 0, 5, 9, 6, 2, 2, 4, 0, 6, 9, 5,
+ 9, 5, 3, 3, 6, 9, 1, 4, 0, 6, 2, 5,
+ };
+ const uint8_t *pow5 =
+ &number_of_digits_decimal_left_shift_table_powers_of_5[pow5_a];
+ uint32_t i = 0;
+ uint32_t n = pow5_b - pow5_a;
+ for (; i < n; i++) {
+ if (i >= h.num_digits) {
+ return num_new_digits - 1;
+ } else if (h.digits[i] == pow5[i]) {
+ continue;
+ } else if (h.digits[i] < pow5[i]) {
+ return num_new_digits - 1;
+ } else {
+ return num_new_digits;
+ }
+ }
+ return num_new_digits;
+}
+
+uint64_t round(decimal &h) {
+ if ((h.num_digits == 0) || (h.decimal_point < 0)) {
+ return 0;
+ } else if (h.decimal_point > 18) {
+ return UINT64_MAX;
+ }
+ // at this point, we know that h.decimal_point >= 0
+ uint32_t dp = uint32_t(h.decimal_point);
+ uint64_t n = 0;
+ for (uint32_t i = 0; i < dp; i++) {
+ n = (10 * n) + ((i < h.num_digits) ? h.digits[i] : 0);
+ }
+ bool round_up = false;
+ if (dp < h.num_digits) {
+ round_up = h.digits[dp] >= 5; // normally, we round up
+ // but we may need to round to even!
+ if ((h.digits[dp] == 5) && (dp + 1 == h.num_digits)) {
+ round_up = h.truncated || ((dp > 0) && (1 & h.digits[dp - 1]));
+ }
+ }
+ if (round_up) {
+ n++;
+ }
+ return n;
+}
+
+// computes h * 2^-shift
+void decimal_left_shift(decimal &h, uint32_t shift) {
+ if (h.num_digits == 0) {
+ return;
+ }
+ uint32_t num_new_digits = number_of_digits_decimal_left_shift(h, shift);
+ int32_t read_index = int32_t(h.num_digits - 1);
+ uint32_t write_index = h.num_digits - 1 + num_new_digits;
+ uint64_t n = 0;
+
+ while (read_index >= 0) {
+ n += uint64_t(h.digits[read_index]) << shift;
+ uint64_t quotient = n / 10;
+ uint64_t remainder = n - (10 * quotient);
+ if (write_index < max_digits) {
+ h.digits[write_index] = uint8_t(remainder);
+ } else if (remainder > 0) {
+ h.truncated = true;
+ }
+ n = quotient;
+ write_index--;
+ read_index--;
+ }
+ while (n > 0) {
+ uint64_t quotient = n / 10;
+ uint64_t remainder = n - (10 * quotient);
+ if (write_index < max_digits) {
+ h.digits[write_index] = uint8_t(remainder);
+ } else if (remainder > 0) {
+ h.truncated = true;
+ }
+ n = quotient;
+ write_index--;
+ }
+ h.num_digits += num_new_digits;
+ if (h.num_digits > max_digits) {
+ h.num_digits = max_digits;
+ }
+ h.decimal_point += int32_t(num_new_digits);
+ trim(h);
+}
+
+// computes h * 2^shift
+void decimal_right_shift(decimal &h, uint32_t shift) {
+ uint32_t read_index = 0;
+ uint32_t write_index = 0;
+
+ uint64_t n = 0;
+
+ while ((n >> shift) == 0) {
+ if (read_index < h.num_digits) {
+ n = (10 * n) + h.digits[read_index++];
+ } else if (n == 0) {
+ return;
+ } else {
+ while ((n >> shift) == 0) {
+ n = 10 * n;
+ read_index++;
+ }
+ break;
+ }
+ }
+ h.decimal_point -= int32_t(read_index - 1);
+ if (h.decimal_point < -decimal_point_range) { // it is zero
+ h.num_digits = 0;
+ h.decimal_point = 0;
+ h.negative = false;
+ h.truncated = false;
+ return;
+ }
+ uint64_t mask = (uint64_t(1) << shift) - 1;
+ while (read_index < h.num_digits) {
+ uint8_t new_digit = uint8_t(n >> shift);
+ n = (10 * (n & mask)) + h.digits[read_index++];
+ h.digits[write_index++] = new_digit;
+ }
+ while (n > 0) {
+ uint8_t new_digit = uint8_t(n >> shift);
+ n = 10 * (n & mask);
+ if (write_index < max_digits) {
+ h.digits[write_index++] = new_digit;
+ } else if (new_digit > 0) {
+ h.truncated = true;
+ }
+ }
+ h.num_digits = write_index;
+ trim(h);
+}
+
+} // end of anonymous namespace
+
+template <typename binary>
+adjusted_mantissa compute_float(decimal &d) {
+ adjusted_mantissa answer;
+ if (d.num_digits == 0) {
+ // should be zero
+ answer.power2 = 0;
+ answer.mantissa = 0;
+ return answer;
+ }
+ // At this point, going further, we can assume that d.num_digits > 0.
+ //
+ // We want to guard against excessive decimal point values because
+ // they can result in long running times. Indeed, we do
+ // shifts by at most 60 bits. We have that log(10**400)/log(2**60) ~= 22
+ // which is fine, but log(10**299995)/log(2**60) ~= 16609 which is not
+ // fine (runs for a long time).
+ //
+ if(d.decimal_point < -324) {
+ // We have something smaller than 1e-324 which is always zero
+ // in binary64 and binary32.
+ // It should be zero.
+ answer.power2 = 0;
+ answer.mantissa = 0;
+ return answer;
+ } else if(d.decimal_point >= 310) {
+ // We have something at least as large as 0.1e310 which is
+ // always infinite.
+ answer.power2 = binary::infinite_power();
+ answer.mantissa = 0;
+ return answer;
+ }
+ static const uint32_t max_shift = 60;
+ static const uint32_t num_powers = 19;
+ static const uint8_t powers[19] = {
+ 0, 3, 6, 9, 13, 16, 19, 23, 26, 29, //
+ 33, 36, 39, 43, 46, 49, 53, 56, 59, //
+ };
+ int32_t exp2 = 0;
+ while (d.decimal_point > 0) {
+ uint32_t n = uint32_t(d.decimal_point);
+ uint32_t shift = (n < num_powers) ? powers[n] : max_shift;
+ decimal_right_shift(d, shift);
+ if (d.decimal_point < -decimal_point_range) {
+ // should be zero
+ answer.power2 = 0;
+ answer.mantissa = 0;
+ return answer;
+ }
+ exp2 += int32_t(shift);
+ }
+ // We shift left toward [1/2 ... 1].
+ while (d.decimal_point <= 0) {
+ uint32_t shift;
+ if (d.decimal_point == 0) {
+ if (d.digits[0] >= 5) {
+ break;
+ }
+ shift = (d.digits[0] < 2) ? 2 : 1;
+ } else {
+ uint32_t n = uint32_t(-d.decimal_point);
+ shift = (n < num_powers) ? powers[n] : max_shift;
+ }
+ decimal_left_shift(d, shift);
+ if (d.decimal_point > decimal_point_range) {
+ // we want to get infinity:
+ answer.power2 = binary::infinite_power();
+ answer.mantissa = 0;
+ return answer;
+ }
+ exp2 -= int32_t(shift);
+ }
+ // We are now in the range [1/2 ... 1] but the binary format uses [1 ... 2].
+ exp2--;
+ constexpr int32_t minimum_exponent = binary::minimum_exponent();
+ while ((minimum_exponent + 1) > exp2) {
+ uint32_t n = uint32_t((minimum_exponent + 1) - exp2);
+ if (n > max_shift) {
+ n = max_shift;
+ }
+ decimal_right_shift(d, n);
+ exp2 += int32_t(n);
+ }
+ if ((exp2 - minimum_exponent) >= binary::infinite_power()) {
+ answer.power2 = binary::infinite_power();
+ answer.mantissa = 0;
+ return answer;
+ }
+
+ const int mantissa_size_in_bits = binary::mantissa_explicit_bits() + 1;
+ decimal_left_shift(d, mantissa_size_in_bits);
+
+ uint64_t mantissa = round(d);
+ // It is possible that we have an overflow, in which case we need
+ // to shift back.
+ if(mantissa >= (uint64_t(1) << mantissa_size_in_bits)) {
+ decimal_right_shift(d, 1);
+ exp2 += 1;
+ mantissa = round(d);
+ if ((exp2 - minimum_exponent) >= binary::infinite_power()) {
+ answer.power2 = binary::infinite_power();
+ answer.mantissa = 0;
+ return answer;
+ }
+ }
+ answer.power2 = exp2 - binary::minimum_exponent();
+ if(mantissa < (uint64_t(1) << binary::mantissa_explicit_bits())) { answer.power2--; }
+ answer.mantissa = mantissa & ((uint64_t(1) << binary::mantissa_explicit_bits()) - 1);
+ return answer;
+}
+
+template <typename binary>
+adjusted_mantissa parse_long_mantissa(const char *first, const char* last) {
+ decimal d = parse_decimal(first, last);
+ return compute_float<binary>(d);
+}
+
+} // namespace fast_float
+} // namespace arrow_vendored
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/musl/README.md b/src/arrow/cpp/src/arrow/vendored/musl/README.md
new file mode 100644
index 000000000..40962a14c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/musl/README.md
@@ -0,0 +1,25 @@
+<!--
+Copyright © 2005-2020 Rich Felker, et al.
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+"Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+-->
+
+Assorted utility functions are adapted from the musl libc project
+(https://musl.libc.org/).
diff --git a/src/arrow/cpp/src/arrow/vendored/musl/strptime.c b/src/arrow/cpp/src/arrow/vendored/musl/strptime.c
new file mode 100644
index 000000000..e8111f576
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/musl/strptime.c
@@ -0,0 +1,237 @@
+// Vendored from musl git commit 593caa456309714402ca4cb77c3770f4c24da9da
+// + adaptations
+
+#include "arrow/vendored/strptime.h"
+
+#include <ctype.h>
+#include <stddef.h>
+#include <stdlib.h>
+#include <string.h>
+
+#ifdef _WIN32
+#define strncasecmp _strnicmp
+#define strcasecmp _stricmp
+#else
+#include <strings.h>
+#endif
+
+#undef HAVE_LANGINFO
+
+#ifndef _WIN32
+#define HAVE_LANGINFO 1
+#endif
+
+#ifdef HAVE_LANGINFO
+#include <langinfo.h>
+#endif
+
+#define strptime arrow_strptime
+
+char *strptime(const char *__restrict s, const char *__restrict f, struct tm *__restrict tm)
+{
+ int i, w, neg, adj, min, range, *dest, dummy;
+#ifdef HAVE_LANGINFO
+ const char *ex;
+ size_t len;
+#endif
+ int want_century = 0, century = 0, relyear = 0;
+ while (*f) {
+ if (*f != '%') {
+ if (isspace(*f)) for (; *s && isspace(*s); s++);
+ else if (*s != *f) return 0;
+ else s++;
+ f++;
+ continue;
+ }
+ f++;
+ if (*f == '+') f++;
+ if (isdigit(*f)) {
+ char *new_f;
+ w=strtoul(f, &new_f, 10);
+ f = new_f;
+ } else {
+ w=-1;
+ }
+ adj=0;
+ switch (*f++) {
+#ifdef HAVE_LANGINFO
+ case 'a': case 'A':
+ dest = &tm->tm_wday;
+ min = ABDAY_1;
+ range = 7;
+ goto symbolic_range;
+ case 'b': case 'B': case 'h':
+ dest = &tm->tm_mon;
+ min = ABMON_1;
+ range = 12;
+ goto symbolic_range;
+ case 'c':
+ s = strptime(s, nl_langinfo(D_T_FMT), tm);
+ if (!s) return 0;
+ break;
+#endif
+ case 'C':
+ dest = &century;
+ if (w<0) w=2;
+ want_century |= 2;
+ goto numeric_digits;
+ case 'd': case 'e':
+ dest = &tm->tm_mday;
+ min = 1;
+ range = 31;
+ goto numeric_range;
+ case 'D':
+ s = strptime(s, "%m/%d/%y", tm);
+ if (!s) return 0;
+ break;
+ case 'H':
+ dest = &tm->tm_hour;
+ min = 0;
+ range = 24;
+ goto numeric_range;
+ case 'I':
+ dest = &tm->tm_hour;
+ min = 1;
+ range = 12;
+ goto numeric_range;
+ case 'j':
+ dest = &tm->tm_yday;
+ min = 1;
+ range = 366;
+ adj = 1;
+ goto numeric_range;
+ case 'm':
+ dest = &tm->tm_mon;
+ min = 1;
+ range = 12;
+ adj = 1;
+ goto numeric_range;
+ case 'M':
+ dest = &tm->tm_min;
+ min = 0;
+ range = 60;
+ goto numeric_range;
+ case 'n': case 't':
+ for (; *s && isspace(*s); s++);
+ break;
+#ifdef HAVE_LANGINFO
+ case 'p':
+ ex = nl_langinfo(AM_STR);
+ len = strlen(ex);
+ if (!strncasecmp(s, ex, len)) {
+ tm->tm_hour %= 12;
+ s += len;
+ break;
+ }
+ ex = nl_langinfo(PM_STR);
+ len = strlen(ex);
+ if (!strncasecmp(s, ex, len)) {
+ tm->tm_hour %= 12;
+ tm->tm_hour += 12;
+ s += len;
+ break;
+ }
+ return 0;
+ case 'r':
+ s = strptime(s, nl_langinfo(T_FMT_AMPM), tm);
+ if (!s) return 0;
+ break;
+#endif
+ case 'R':
+ s = strptime(s, "%H:%M", tm);
+ if (!s) return 0;
+ break;
+ case 'S':
+ dest = &tm->tm_sec;
+ min = 0;
+ range = 61;
+ goto numeric_range;
+ case 'T':
+ s = strptime(s, "%H:%M:%S", tm);
+ if (!s) return 0;
+ break;
+ case 'U':
+ case 'W':
+ /* Throw away result, for now. (FIXME?) */
+ dest = &dummy;
+ min = 0;
+ range = 54;
+ goto numeric_range;
+ case 'w':
+ dest = &tm->tm_wday;
+ min = 0;
+ range = 7;
+ goto numeric_range;
+#ifdef HAVE_LANGINFO
+ case 'x':
+ s = strptime(s, nl_langinfo(D_FMT), tm);
+ if (!s) return 0;
+ break;
+ case 'X':
+ s = strptime(s, nl_langinfo(T_FMT), tm);
+ if (!s) return 0;
+ break;
+#endif
+ case 'y':
+ dest = &relyear;
+ w = 2;
+ want_century |= 1;
+ goto numeric_digits;
+ case 'Y':
+ dest = &tm->tm_year;
+ if (w<0) w=4;
+ adj = 1900;
+ want_century = 0;
+ goto numeric_digits;
+ case '%':
+ if (*s++ != '%') return 0;
+ break;
+ default:
+ return 0;
+ numeric_range:
+ if (!isdigit(*s)) return 0;
+ *dest = 0;
+ for (i=1; i<=min+range && isdigit(*s); i*=10)
+ *dest = *dest * 10 + *s++ - '0';
+ if (*dest - min >= range) return 0;
+ *dest -= adj;
+ switch((char *)dest - (char *)tm) {
+ case offsetof(struct tm, tm_yday):
+ ;
+ }
+ goto update;
+ numeric_digits:
+ neg = 0;
+ if (*s == '+') s++;
+ else if (*s == '-') neg=1, s++;
+ if (!isdigit(*s)) return 0;
+ for (*dest=i=0; i<w && isdigit(*s); i++)
+ *dest = *dest * 10 + *s++ - '0';
+ if (neg) *dest = -*dest;
+ *dest -= adj;
+ goto update;
+#ifdef HAVE_LANGINFO
+ symbolic_range:
+ for (i=2*range-1; i>=0; i--) {
+ ex = nl_langinfo(min+i);
+ len = strlen(ex);
+ if (strncasecmp(s, ex, len)) continue;
+ s += len;
+ *dest = i % range;
+ break;
+ }
+ if (i<0) return 0;
+ goto update;
+#endif
+ update:
+ //FIXME
+ ;
+ }
+ }
+ if (want_century) {
+ tm->tm_year = relyear;
+ if (want_century & 2) tm->tm_year += century * 100 - 1900;
+ else if (tm->tm_year <= 68) tm->tm_year += 100;
+ }
+ return (char *)s;
+}
diff --git a/src/arrow/cpp/src/arrow/vendored/optional.hpp b/src/arrow/cpp/src/arrow/vendored/optional.hpp
new file mode 100644
index 000000000..e266bb20b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/optional.hpp
@@ -0,0 +1,1553 @@
+// Vendored from git tag v3.2.0
+
+// Copyright (c) 2014-2018 Martin Moene
+//
+// https://github.com/martinmoene/optional-lite
+//
+// Distributed under the Boost Software License, Version 1.0.
+// (See accompanying file LICENSE.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
+
+#pragma once
+
+#ifndef NONSTD_OPTIONAL_LITE_HPP
+#define NONSTD_OPTIONAL_LITE_HPP
+
+#define optional_lite_MAJOR 3
+#define optional_lite_MINOR 2
+#define optional_lite_PATCH 0
+
+#define optional_lite_VERSION optional_STRINGIFY(optional_lite_MAJOR) "." optional_STRINGIFY(optional_lite_MINOR) "." optional_STRINGIFY(optional_lite_PATCH)
+
+#define optional_STRINGIFY( x ) optional_STRINGIFY_( x )
+#define optional_STRINGIFY_( x ) #x
+
+// optional-lite configuration:
+
+#define optional_OPTIONAL_DEFAULT 0
+#define optional_OPTIONAL_NONSTD 1
+#define optional_OPTIONAL_STD 2
+
+#if !defined( optional_CONFIG_SELECT_OPTIONAL )
+# define optional_CONFIG_SELECT_OPTIONAL ( optional_HAVE_STD_OPTIONAL ? optional_OPTIONAL_STD : optional_OPTIONAL_NONSTD )
+#endif
+
+// Control presence of exception handling (try and auto discover):
+
+#ifndef optional_CONFIG_NO_EXCEPTIONS
+# if defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)
+# define optional_CONFIG_NO_EXCEPTIONS 0
+# else
+# define optional_CONFIG_NO_EXCEPTIONS 1
+# endif
+#endif
+
+// C++ language version detection (C++20 is speculative):
+// Note: VC14.0/1900 (VS2015) lacks too much from C++14.
+
+#ifndef optional_CPLUSPLUS
+# if defined(_MSVC_LANG ) && !defined(__clang__)
+# define optional_CPLUSPLUS (_MSC_VER == 1900 ? 201103L : _MSVC_LANG )
+# else
+# define optional_CPLUSPLUS __cplusplus
+# endif
+#endif
+
+#define optional_CPP98_OR_GREATER ( optional_CPLUSPLUS >= 199711L )
+#define optional_CPP11_OR_GREATER ( optional_CPLUSPLUS >= 201103L )
+#define optional_CPP11_OR_GREATER_ ( optional_CPLUSPLUS >= 201103L )
+#define optional_CPP14_OR_GREATER ( optional_CPLUSPLUS >= 201402L )
+#define optional_CPP17_OR_GREATER ( optional_CPLUSPLUS >= 201703L )
+#define optional_CPP20_OR_GREATER ( optional_CPLUSPLUS >= 202000L )
+
+// C++ language version (represent 98 as 3):
+
+#define optional_CPLUSPLUS_V ( optional_CPLUSPLUS / 100 - (optional_CPLUSPLUS > 200000 ? 2000 : 1994) )
+
+// Use C++17 std::optional if available and requested:
+
+#if optional_CPP17_OR_GREATER && defined(__has_include )
+# if __has_include( <optional> )
+# define optional_HAVE_STD_OPTIONAL 1
+# else
+# define optional_HAVE_STD_OPTIONAL 0
+# endif
+#else
+# define optional_HAVE_STD_OPTIONAL 0
+#endif
+
+#define optional_USES_STD_OPTIONAL ( (optional_CONFIG_SELECT_OPTIONAL == optional_OPTIONAL_STD) || ((optional_CONFIG_SELECT_OPTIONAL == optional_OPTIONAL_DEFAULT) && optional_HAVE_STD_OPTIONAL) )
+
+//
+// in_place: code duplicated in any-lite, expected-lite, optional-lite, value-ptr-lite, variant-lite:
+//
+
+#ifndef nonstd_lite_HAVE_IN_PLACE_TYPES
+#define nonstd_lite_HAVE_IN_PLACE_TYPES 1
+
+// C++17 std::in_place in <utility>:
+
+#if optional_CPP17_OR_GREATER
+
+#include <utility>
+
+namespace nonstd {
+
+using std::in_place;
+using std::in_place_type;
+using std::in_place_index;
+using std::in_place_t;
+using std::in_place_type_t;
+using std::in_place_index_t;
+
+#define nonstd_lite_in_place_t( T) std::in_place_t
+#define nonstd_lite_in_place_type_t( T) std::in_place_type_t<T>
+#define nonstd_lite_in_place_index_t(K) std::in_place_index_t<K>
+
+#define nonstd_lite_in_place( T) std::in_place_t{}
+#define nonstd_lite_in_place_type( T) std::in_place_type_t<T>{}
+#define nonstd_lite_in_place_index(K) std::in_place_index_t<K>{}
+
+} // namespace nonstd
+
+#else // optional_CPP17_OR_GREATER
+
+#include <cstddef>
+
+namespace nonstd {
+namespace detail {
+
+template< class T >
+struct in_place_type_tag {};
+
+template< std::size_t K >
+struct in_place_index_tag {};
+
+} // namespace detail
+
+struct in_place_t {};
+
+template< class T >
+inline in_place_t in_place( detail::in_place_type_tag<T> /*unused*/ = detail::in_place_type_tag<T>() )
+{
+ return in_place_t();
+}
+
+template< std::size_t K >
+inline in_place_t in_place( detail::in_place_index_tag<K> /*unused*/ = detail::in_place_index_tag<K>() )
+{
+ return in_place_t();
+}
+
+template< class T >
+inline in_place_t in_place_type( detail::in_place_type_tag<T> /*unused*/ = detail::in_place_type_tag<T>() )
+{
+ return in_place_t();
+}
+
+template< std::size_t K >
+inline in_place_t in_place_index( detail::in_place_index_tag<K> /*unused*/ = detail::in_place_index_tag<K>() )
+{
+ return in_place_t();
+}
+
+// mimic templated typedef:
+
+#define nonstd_lite_in_place_t( T) nonstd::in_place_t(&)( nonstd::detail::in_place_type_tag<T> )
+#define nonstd_lite_in_place_type_t( T) nonstd::in_place_t(&)( nonstd::detail::in_place_type_tag<T> )
+#define nonstd_lite_in_place_index_t(K) nonstd::in_place_t(&)( nonstd::detail::in_place_index_tag<K> )
+
+#define nonstd_lite_in_place( T) nonstd::in_place_type<T>
+#define nonstd_lite_in_place_type( T) nonstd::in_place_type<T>
+#define nonstd_lite_in_place_index(K) nonstd::in_place_index<K>
+
+} // namespace nonstd
+
+#endif // optional_CPP17_OR_GREATER
+#endif // nonstd_lite_HAVE_IN_PLACE_TYPES
+
+//
+// Using std::optional:
+//
+
+#if optional_USES_STD_OPTIONAL
+
+#include <optional>
+
+namespace nonstd {
+
+ using std::optional;
+ using std::bad_optional_access;
+ using std::hash;
+
+ using std::nullopt;
+ using std::nullopt_t;
+
+ using std::operator==;
+ using std::operator!=;
+ using std::operator<;
+ using std::operator<=;
+ using std::operator>;
+ using std::operator>=;
+ using std::make_optional;
+ using std::swap;
+}
+
+#else // optional_USES_STD_OPTIONAL
+
+#include <cassert>
+#include <utility>
+
+// optional-lite alignment configuration:
+
+#ifndef optional_CONFIG_MAX_ALIGN_HACK
+# define optional_CONFIG_MAX_ALIGN_HACK 0
+#endif
+
+#ifndef optional_CONFIG_ALIGN_AS
+// no default, used in #if defined()
+#endif
+
+#ifndef optional_CONFIG_ALIGN_AS_FALLBACK
+# define optional_CONFIG_ALIGN_AS_FALLBACK double
+#endif
+
+// Compiler warning suppression:
+
+#if defined(__clang__)
+# pragma clang diagnostic push
+# pragma clang diagnostic ignored "-Wundef"
+#elif defined(__GNUC__)
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wundef"
+#elif defined(_MSC_VER )
+# pragma warning( push )
+#endif
+
+// half-open range [lo..hi):
+#define optional_BETWEEN( v, lo, hi ) ( (lo) <= (v) && (v) < (hi) )
+
+// Compiler versions:
+//
+// MSVC++ 6.0 _MSC_VER == 1200 (Visual Studio 6.0)
+// MSVC++ 7.0 _MSC_VER == 1300 (Visual Studio .NET 2002)
+// MSVC++ 7.1 _MSC_VER == 1310 (Visual Studio .NET 2003)
+// MSVC++ 8.0 _MSC_VER == 1400 (Visual Studio 2005)
+// MSVC++ 9.0 _MSC_VER == 1500 (Visual Studio 2008)
+// MSVC++ 10.0 _MSC_VER == 1600 (Visual Studio 2010)
+// MSVC++ 11.0 _MSC_VER == 1700 (Visual Studio 2012)
+// MSVC++ 12.0 _MSC_VER == 1800 (Visual Studio 2013)
+// MSVC++ 14.0 _MSC_VER == 1900 (Visual Studio 2015)
+// MSVC++ 14.1 _MSC_VER >= 1910 (Visual Studio 2017)
+
+#if defined(_MSC_VER ) && !defined(__clang__)
+# define optional_COMPILER_MSVC_VER (_MSC_VER )
+# define optional_COMPILER_MSVC_VERSION (_MSC_VER / 10 - 10 * ( 5 + (_MSC_VER < 1900 ) ) )
+#else
+# define optional_COMPILER_MSVC_VER 0
+# define optional_COMPILER_MSVC_VERSION 0
+#endif
+
+#define optional_COMPILER_VERSION( major, minor, patch ) ( 10 * (10 * (major) + (minor) ) + (patch) )
+
+#if defined(__GNUC__) && !defined(__clang__)
+# define optional_COMPILER_GNUC_VERSION optional_COMPILER_VERSION(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__)
+#else
+# define optional_COMPILER_GNUC_VERSION 0
+#endif
+
+#if defined(__clang__)
+# define optional_COMPILER_CLANG_VERSION optional_COMPILER_VERSION(__clang_major__, __clang_minor__, __clang_patchlevel__)
+#else
+# define optional_COMPILER_CLANG_VERSION 0
+#endif
+
+#if optional_BETWEEN(optional_COMPILER_MSVC_VERSION, 70, 140 )
+# pragma warning( disable: 4345 ) // initialization behavior changed
+#endif
+
+#if optional_BETWEEN(optional_COMPILER_MSVC_VERSION, 70, 150 )
+# pragma warning( disable: 4814 ) // in C++14 'constexpr' will not imply 'const'
+#endif
+
+// Presence of language and library features:
+
+#define optional_HAVE(FEATURE) ( optional_HAVE_##FEATURE )
+
+#ifdef _HAS_CPP0X
+# define optional_HAS_CPP0X _HAS_CPP0X
+#else
+# define optional_HAS_CPP0X 0
+#endif
+
+// Unless defined otherwise below, consider VC14 as C++11 for optional-lite:
+
+#if optional_COMPILER_MSVC_VER >= 1900
+# undef optional_CPP11_OR_GREATER
+# define optional_CPP11_OR_GREATER 1
+#endif
+
+#define optional_CPP11_90 (optional_CPP11_OR_GREATER_ || optional_COMPILER_MSVC_VER >= 1500)
+#define optional_CPP11_100 (optional_CPP11_OR_GREATER_ || optional_COMPILER_MSVC_VER >= 1600)
+#define optional_CPP11_110 (optional_CPP11_OR_GREATER_ || optional_COMPILER_MSVC_VER >= 1700)
+#define optional_CPP11_120 (optional_CPP11_OR_GREATER_ || optional_COMPILER_MSVC_VER >= 1800)
+#define optional_CPP11_140 (optional_CPP11_OR_GREATER_ || optional_COMPILER_MSVC_VER >= 1900)
+#define optional_CPP11_141 (optional_CPP11_OR_GREATER_ || optional_COMPILER_MSVC_VER >= 1910)
+
+#define optional_CPP14_000 (optional_CPP14_OR_GREATER)
+#define optional_CPP17_000 (optional_CPP17_OR_GREATER)
+
+// Presence of C++11 language features:
+
+#define optional_HAVE_CONSTEXPR_11 optional_CPP11_140
+#define optional_HAVE_IS_DEFAULT optional_CPP11_140
+#define optional_HAVE_NOEXCEPT optional_CPP11_140
+#define optional_HAVE_NULLPTR optional_CPP11_100
+#define optional_HAVE_REF_QUALIFIER optional_CPP11_140
+
+// Presence of C++14 language features:
+
+#define optional_HAVE_CONSTEXPR_14 optional_CPP14_000
+
+// Presence of C++17 language features:
+
+#define optional_HAVE_NODISCARD optional_CPP17_000
+
+// Presence of C++ library features:
+
+#define optional_HAVE_CONDITIONAL optional_CPP11_120
+#define optional_HAVE_REMOVE_CV optional_CPP11_120
+#define optional_HAVE_TYPE_TRAITS optional_CPP11_90
+
+#define optional_HAVE_TR1_TYPE_TRAITS (!! optional_COMPILER_GNUC_VERSION )
+#define optional_HAVE_TR1_ADD_POINTER (!! optional_COMPILER_GNUC_VERSION )
+
+// C++ feature usage:
+
+#if optional_HAVE( CONSTEXPR_11 )
+# define optional_constexpr constexpr
+#else
+# define optional_constexpr /*constexpr*/
+#endif
+
+#if optional_HAVE( IS_DEFAULT )
+# define optional_is_default = default;
+#else
+# define optional_is_default {}
+#endif
+
+#if optional_HAVE( CONSTEXPR_14 )
+# define optional_constexpr14 constexpr
+#else
+# define optional_constexpr14 /*constexpr*/
+#endif
+
+#if optional_HAVE( NODISCARD )
+# define optional_nodiscard [[nodiscard]]
+#else
+# define optional_nodiscard /*[[nodiscard]]*/
+#endif
+
+#if optional_HAVE( NOEXCEPT )
+# define optional_noexcept noexcept
+#else
+# define optional_noexcept /*noexcept*/
+#endif
+
+#if optional_HAVE( NULLPTR )
+# define optional_nullptr nullptr
+#else
+# define optional_nullptr NULL
+#endif
+
+#if optional_HAVE( REF_QUALIFIER )
+// NOLINTNEXTLINE( bugprone-macro-parentheses )
+# define optional_ref_qual &
+# define optional_refref_qual &&
+#else
+# define optional_ref_qual /*&*/
+# define optional_refref_qual /*&&*/
+#endif
+
+// additional includes:
+
+#if optional_CONFIG_NO_EXCEPTIONS
+// already included: <cassert>
+#else
+# include <stdexcept>
+#endif
+
+#if optional_CPP11_OR_GREATER
+# include <functional>
+#endif
+
+#if optional_HAVE( INITIALIZER_LIST )
+# include <initializer_list>
+#endif
+
+#if optional_HAVE( TYPE_TRAITS )
+# include <type_traits>
+#elif optional_HAVE( TR1_TYPE_TRAITS )
+# include <tr1/type_traits>
+#endif
+
+// Method enabling
+
+#if optional_CPP11_OR_GREATER
+
+#define optional_REQUIRES_0(...) \
+ template< bool B = (__VA_ARGS__), typename std::enable_if<B, int>::type = 0 >
+
+#define optional_REQUIRES_T(...) \
+ , typename = typename std::enable_if< (__VA_ARGS__), nonstd::optional_lite::detail::enabler >::type
+
+#define optional_REQUIRES_R(R, ...) \
+ typename std::enable_if< (__VA_ARGS__), R>::type
+
+#define optional_REQUIRES_A(...) \
+ , typename std::enable_if< (__VA_ARGS__), void*>::type = nullptr
+
+#endif
+
+//
+// optional:
+//
+
+namespace nonstd { namespace optional_lite {
+
+namespace std11 {
+
+#if optional_CPP11_OR_GREATER
+ using std::move;
+#else
+ template< typename T > T & move( T & t ) { return t; }
+#endif
+
+#if optional_HAVE( CONDITIONAL )
+ using std::conditional;
+#else
+ template< bool B, typename T, typename F > struct conditional { typedef T type; };
+ template< typename T, typename F > struct conditional<false, T, F> { typedef F type; };
+#endif // optional_HAVE_CONDITIONAL
+
+} // namespace std11
+
+#if optional_CPP11_OR_GREATER
+
+/// type traits C++17:
+
+namespace std17 {
+
+#if optional_CPP17_OR_GREATER
+
+using std::is_swappable;
+using std::is_nothrow_swappable;
+
+#elif optional_CPP11_OR_GREATER
+
+namespace detail {
+
+using std::swap;
+
+struct is_swappable
+{
+ template< typename T, typename = decltype( swap( std::declval<T&>(), std::declval<T&>() ) ) >
+ static std::true_type test( int /*unused*/ );
+
+ template< typename >
+ static std::false_type test(...);
+};
+
+struct is_nothrow_swappable
+{
+ // wrap noexcept(expr) in separate function as work-around for VC140 (VS2015):
+
+ template< typename T >
+ static constexpr bool satisfies()
+ {
+ return noexcept( swap( std::declval<T&>(), std::declval<T&>() ) );
+ }
+
+ template< typename T >
+ static auto test( int /*unused*/ ) -> std::integral_constant<bool, satisfies<T>()>{}
+
+ template< typename >
+ static auto test(...) -> std::false_type;
+};
+
+} // namespace detail
+
+// is [nothow] swappable:
+
+template< typename T >
+struct is_swappable : decltype( detail::is_swappable::test<T>(0) ){};
+
+template< typename T >
+struct is_nothrow_swappable : decltype( detail::is_nothrow_swappable::test<T>(0) ){};
+
+#endif // optional_CPP17_OR_GREATER
+
+} // namespace std17
+
+/// type traits C++20:
+
+namespace std20 {
+
+template< typename T >
+struct remove_cvref
+{
+ typedef typename std::remove_cv< typename std::remove_reference<T>::type >::type type;
+};
+
+} // namespace std20
+
+#endif // optional_CPP11_OR_GREATER
+
+/// class optional
+
+template< typename T >
+class optional;
+
+namespace detail {
+
+// for optional_REQUIRES_T
+
+#if optional_CPP11_OR_GREATER
+enum class enabler{};
+#endif
+
+// C++11 emulation:
+
+struct nulltype{};
+
+template< typename Head, typename Tail >
+struct typelist
+{
+ typedef Head head;
+ typedef Tail tail;
+};
+
+#if optional_CONFIG_MAX_ALIGN_HACK
+
+// Max align, use most restricted type for alignment:
+
+#define optional_UNIQUE( name ) optional_UNIQUE2( name, __LINE__ )
+#define optional_UNIQUE2( name, line ) optional_UNIQUE3( name, line )
+#define optional_UNIQUE3( name, line ) name ## line
+
+#define optional_ALIGN_TYPE( type ) \
+ type optional_UNIQUE( _t ); struct_t< type > optional_UNIQUE( _st )
+
+template< typename T >
+struct struct_t { T _; };
+
+union max_align_t
+{
+ optional_ALIGN_TYPE( char );
+ optional_ALIGN_TYPE( short int );
+ optional_ALIGN_TYPE( int );
+ optional_ALIGN_TYPE( long int );
+ optional_ALIGN_TYPE( float );
+ optional_ALIGN_TYPE( double );
+ optional_ALIGN_TYPE( long double );
+ optional_ALIGN_TYPE( char * );
+ optional_ALIGN_TYPE( short int * );
+ optional_ALIGN_TYPE( int * );
+ optional_ALIGN_TYPE( long int * );
+ optional_ALIGN_TYPE( float * );
+ optional_ALIGN_TYPE( double * );
+ optional_ALIGN_TYPE( long double * );
+ optional_ALIGN_TYPE( void * );
+
+#ifdef HAVE_LONG_LONG
+ optional_ALIGN_TYPE( long long );
+#endif
+
+ struct Unknown;
+
+ Unknown ( * optional_UNIQUE(_) )( Unknown );
+ Unknown * Unknown::* optional_UNIQUE(_);
+ Unknown ( Unknown::* optional_UNIQUE(_) )( Unknown );
+
+ struct_t< Unknown ( * )( Unknown) > optional_UNIQUE(_);
+ struct_t< Unknown * Unknown::* > optional_UNIQUE(_);
+ struct_t< Unknown ( Unknown::* )(Unknown) > optional_UNIQUE(_);
+};
+
+#undef optional_UNIQUE
+#undef optional_UNIQUE2
+#undef optional_UNIQUE3
+
+#undef optional_ALIGN_TYPE
+
+#elif defined( optional_CONFIG_ALIGN_AS ) // optional_CONFIG_MAX_ALIGN_HACK
+
+// Use user-specified type for alignment:
+
+#define optional_ALIGN_AS( unused ) \
+ optional_CONFIG_ALIGN_AS
+
+#else // optional_CONFIG_MAX_ALIGN_HACK
+
+// Determine POD type to use for alignment:
+
+#define optional_ALIGN_AS( to_align ) \
+ typename type_of_size< alignment_types, alignment_of< to_align >::value >::type
+
+template< typename T >
+struct alignment_of;
+
+template< typename T >
+struct alignment_of_hack
+{
+ char c;
+ T t;
+ alignment_of_hack();
+};
+
+template< size_t A, size_t S >
+struct alignment_logic
+{
+ enum { value = A < S ? A : S };
+};
+
+template< typename T >
+struct alignment_of
+{
+ enum { value = alignment_logic<
+ sizeof( alignment_of_hack<T> ) - sizeof(T), sizeof(T) >::value };
+};
+
+template< typename List, size_t N >
+struct type_of_size
+{
+ typedef typename std11::conditional<
+ N == sizeof( typename List::head ),
+ typename List::head,
+ typename type_of_size<typename List::tail, N >::type >::type type;
+};
+
+template< size_t N >
+struct type_of_size< nulltype, N >
+{
+ typedef optional_CONFIG_ALIGN_AS_FALLBACK type;
+};
+
+template< typename T>
+struct struct_t { T _; };
+
+#define optional_ALIGN_TYPE( type ) \
+ typelist< type , typelist< struct_t< type >
+
+struct Unknown;
+
+typedef
+ optional_ALIGN_TYPE( char ),
+ optional_ALIGN_TYPE( short ),
+ optional_ALIGN_TYPE( int ),
+ optional_ALIGN_TYPE( long), optional_ALIGN_TYPE(float), optional_ALIGN_TYPE(double),
+ optional_ALIGN_TYPE(long double),
+
+ optional_ALIGN_TYPE(char*), optional_ALIGN_TYPE(short*), optional_ALIGN_TYPE(int*),
+ optional_ALIGN_TYPE(long*), optional_ALIGN_TYPE(float*), optional_ALIGN_TYPE(double*),
+ optional_ALIGN_TYPE(long double*),
+
+ optional_ALIGN_TYPE(Unknown (*)(Unknown)), optional_ALIGN_TYPE(Unknown* Unknown::*),
+ optional_ALIGN_TYPE(Unknown (Unknown::*)(Unknown)),
+
+ nulltype >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> alignment_types;
+
+#undef optional_ALIGN_TYPE
+
+#endif // optional_CONFIG_MAX_ALIGN_HACK
+
+/// C++03 constructed union to hold value.
+
+template <typename T>
+union storage_t {
+ // private:
+ // template< typename > friend class optional;
+
+ typedef T value_type;
+
+ storage_t() optional_is_default
+
+ explicit storage_t(value_type const& v) {
+ construct_value(v);
+ }
+
+ void construct_value(value_type const& v) { ::new (value_ptr()) value_type(v); }
+
+#if optional_CPP11_OR_GREATER
+
+ explicit storage_t(value_type&& v) { construct_value(std::move(v)); }
+
+ void construct_value(value_type&& v) { ::new (value_ptr()) value_type(std::move(v)); }
+
+ template <class... Args>
+ void emplace(Args&&... args) {
+ ::new (value_ptr()) value_type(std::forward<Args>(args)...);
+ }
+
+ template <class U, class... Args>
+ void emplace(std::initializer_list<U> il, Args&&... args) {
+ ::new (value_ptr()) value_type(il, std::forward<Args>(args)...);
+ }
+
+#endif
+
+ void destruct_value() { value_ptr()->~T(); }
+
+ optional_nodiscard value_type const* value_ptr() const { return as<value_type>(); }
+
+ value_type* value_ptr() { return as<value_type>(); }
+
+ optional_nodiscard value_type const& value() const optional_ref_qual {
+ return *value_ptr();
+ }
+
+ value_type& value() optional_ref_qual { return *value_ptr(); }
+
+#if optional_CPP11_OR_GREATER
+
+ optional_nodiscard value_type const&& value() const optional_refref_qual {
+ return std::move(value());
+ }
+
+ value_type&& value() optional_refref_qual { return std::move(value()); }
+
+#endif
+
+#if optional_CPP11_OR_GREATER
+
+ using aligned_storage_t =
+ typename std::aligned_storage<sizeof(value_type), alignof(value_type)>::type;
+ aligned_storage_t data;
+
+#elif optional_CONFIG_MAX_ALIGN_HACK
+
+ typedef struct {
+ unsigned char data[sizeof(value_type)];
+ } aligned_storage_t;
+
+ max_align_t hack;
+ aligned_storage_t data;
+
+#else
+ typedef optional_ALIGN_AS(value_type) align_as_type;
+
+ typedef struct {
+ align_as_type data[1 + (sizeof(value_type) - 1) / sizeof(align_as_type)];
+ } aligned_storage_t;
+ aligned_storage_t data;
+
+#undef optional_ALIGN_AS
+
+#endif // optional_CONFIG_MAX_ALIGN_HACK
+
+ optional_nodiscard void* ptr() optional_noexcept { return &data; }
+
+ optional_nodiscard void const* ptr() const optional_noexcept { return &data; }
+
+ template <typename U>
+ optional_nodiscard U* as() {
+ return reinterpret_cast<U*>(ptr());
+ }
+
+ template <typename U>
+ optional_nodiscard U const* as() const {
+ return reinterpret_cast<U const*>(ptr());
+ }
+};
+
+} // namespace detail
+
+/// disengaged state tag
+
+struct nullopt_t {
+ struct init {};
+ explicit optional_constexpr nullopt_t(init /*unused*/) optional_noexcept {}
+};
+
+#if optional_HAVE(CONSTEXPR_11)
+constexpr nullopt_t nullopt{nullopt_t::init{}};
+#else
+// extra parenthesis to prevent the most vexing parse:
+const nullopt_t nullopt((nullopt_t::init()));
+#endif
+
+/// optional access error
+
+#if !optional_CONFIG_NO_EXCEPTIONS
+
+class bad_optional_access : public std::logic_error {
+ public:
+ explicit bad_optional_access() : logic_error("bad optional access") {}
+};
+
+#endif // optional_CONFIG_NO_EXCEPTIONS
+
+/// optional
+
+template <typename T>
+class optional {
+ private:
+ template <typename>
+ friend class optional;
+
+ typedef void (optional::*safe_bool)() const;
+
+ public:
+ typedef T value_type;
+
+ // x.x.3.1, constructors
+
+ // 1a - default construct
+ optional_constexpr optional() optional_noexcept : has_value_(false), contained() {}
+
+ // 1b - construct explicitly empty
+ // NOLINTNEXTLINE( google-explicit-constructor, hicpp-explicit-conversions )
+ optional_constexpr optional(nullopt_t /*unused*/) optional_noexcept : has_value_(false),
+ contained() {}
+
+ // 2 - copy-construct
+ optional_constexpr14 optional(
+ optional const& other
+#if optional_CPP11_OR_GREATER
+ optional_REQUIRES_A(true || std::is_copy_constructible<T>::value)
+#endif
+ )
+ : has_value_(other.has_value()) {
+ if (other.has_value()) {
+ contained.construct_value(other.contained.value());
+ }
+ }
+
+#if optional_CPP11_OR_GREATER
+
+ // 3 (C++11) - move-construct from optional
+ optional_constexpr14 optional(
+ optional&& other optional_REQUIRES_A(true || std::is_move_constructible<T>::value)
+ // NOLINTNEXTLINE( performance-noexcept-move-constructor )
+ ) noexcept(std::is_nothrow_move_constructible<T>::value)
+ : has_value_(other.has_value()) {
+ if (other.has_value()) {
+ contained.construct_value(std::move(other.contained.value()));
+ }
+ }
+
+ // 4a (C++11) - explicit converting copy-construct from optional
+ template <typename U>
+ explicit optional(optional<U> const& other optional_REQUIRES_A(
+ std::is_constructible<T, U const&>::value &&
+ !std::is_constructible<T, optional<U>&>::value &&
+ !std::is_constructible<T, optional<U>&&>::value &&
+ !std::is_constructible<T, optional<U> const&>::value &&
+ !std::is_constructible<T, optional<U> const&&>::value &&
+ !std::is_convertible<optional<U>&, T>::value &&
+ !std::is_convertible<optional<U>&&, T>::value &&
+ !std::is_convertible<optional<U> const&, T>::value &&
+ !std::is_convertible<optional<U> const&&, T>::value &&
+ !std::is_convertible<U const&, T>::value /*=> explicit */
+ ))
+ : has_value_(other.has_value()) {
+ if (other.has_value()) {
+ contained.construct_value(T{other.contained.value()});
+ }
+ }
+#endif // optional_CPP11_OR_GREATER
+
+ // 4b (C++98 and later) - non-explicit converting copy-construct from optional
+ template <typename U>
+ // NOLINTNEXTLINE( google-explicit-constructor, hicpp-explicit-conversions )
+ optional(
+ optional<U> const& other
+#if optional_CPP11_OR_GREATER
+ optional_REQUIRES_A(std::is_constructible<T, U const&>::value &&
+ !std::is_constructible<T, optional<U>&>::value &&
+ !std::is_constructible<T, optional<U>&&>::value &&
+ !std::is_constructible<T, optional<U> const&>::value &&
+ !std::is_constructible<T, optional<U> const&&>::value &&
+ !std::is_convertible<optional<U>&, T>::value &&
+ !std::is_convertible<optional<U>&&, T>::value &&
+ !std::is_convertible<optional<U> const&, T>::value &&
+ !std::is_convertible<optional<U> const&&, T>::value &&
+ std::is_convertible<U const&, T>::value /*=> non-explicit */
+ )
+#endif // optional_CPP11_OR_GREATER
+ )
+ : has_value_(other.has_value()) {
+ if (other.has_value()) {
+ contained.construct_value(other.contained.value());
+ }
+ }
+
+#if optional_CPP11_OR_GREATER
+
+ // 5a (C++11) - explicit converting move-construct from optional
+ template <typename U>
+ explicit optional(optional<U>&& other optional_REQUIRES_A(
+ std::is_constructible<T, U&&>::value &&
+ !std::is_constructible<T, optional<U>&>::value &&
+ !std::is_constructible<T, optional<U>&&>::value &&
+ !std::is_constructible<T, optional<U> const&>::value &&
+ !std::is_constructible<T, optional<U> const&&>::value &&
+ !std::is_convertible<optional<U>&, T>::value &&
+ !std::is_convertible<optional<U>&&, T>::value &&
+ !std::is_convertible<optional<U> const&, T>::value &&
+ !std::is_convertible<optional<U> const&&, T>::value &&
+ !std::is_convertible<U&&, T>::value /*=> explicit */
+ ))
+ : has_value_(other.has_value()) {
+ if (other.has_value()) {
+ contained.construct_value(T{std::move(other.contained.value())});
+ }
+ }
+
+ // 5a (C++11) - non-explicit converting move-construct from optional
+ template <typename U>
+ // NOLINTNEXTLINE( google-explicit-constructor, hicpp-explicit-conversions )
+ optional(optional<U>&& other optional_REQUIRES_A(
+ std::is_constructible<T, U&&>::value &&
+ !std::is_constructible<T, optional<U>&>::value &&
+ !std::is_constructible<T, optional<U>&&>::value &&
+ !std::is_constructible<T, optional<U> const&>::value &&
+ !std::is_constructible<T, optional<U> const&&>::value &&
+ !std::is_convertible<optional<U>&, T>::value &&
+ !std::is_convertible<optional<U>&&, T>::value &&
+ !std::is_convertible<optional<U> const&, T>::value &&
+ !std::is_convertible<optional<U> const&&, T>::value &&
+ std::is_convertible<U&&, T>::value /*=> non-explicit */
+ ))
+ : has_value_(other.has_value()) {
+ if (other.has_value()) {
+ contained.construct_value(std::move(other.contained.value()));
+ }
+ }
+
+ // 6 (C++11) - in-place construct
+ template <
+ typename... Args optional_REQUIRES_T(std::is_constructible<T, Args&&...>::value)>
+ optional_constexpr explicit optional(nonstd_lite_in_place_t(T), Args&&... args)
+ : has_value_(true), contained(T(std::forward<Args>(args)...)) {}
+
+ // 7 (C++11) - in-place construct, initializer-list
+ template <typename U,
+ typename... Args optional_REQUIRES_T(
+ std::is_constructible<T, std::initializer_list<U>&, Args&&...>::value)>
+ optional_constexpr explicit optional(nonstd_lite_in_place_t(T),
+ std::initializer_list<U> il, Args&&... args)
+ : has_value_(true), contained(T(il, std::forward<Args>(args)...)) {}
+
+ // 8a (C++11) - explicit move construct from value
+ template <typename U = value_type>
+ optional_constexpr explicit optional(U&& value optional_REQUIRES_A(
+ std::is_constructible<T, U&&>::value &&
+ !std::is_same<typename std20::remove_cvref<U>::type,
+ nonstd_lite_in_place_t(U)>::value &&
+ !std::is_same<typename std20::remove_cvref<U>::type, optional<T> >::value &&
+ !std::is_convertible<U&&, T>::value /*=> explicit */
+ ))
+ : has_value_(true), contained(T{std::forward<U>(value)}) {}
+
+ // 8b (C++11) - non-explicit move construct from value
+ template <typename U = value_type>
+ // NOLINTNEXTLINE( google-explicit-constructor, hicpp-explicit-conversions )
+ optional_constexpr optional(U&& value optional_REQUIRES_A(
+ std::is_constructible<T, U&&>::value &&
+ !std::is_same<typename std20::remove_cvref<U>::type,
+ nonstd_lite_in_place_t(U)>::value &&
+ !std::is_same<typename std20::remove_cvref<U>::type, optional<T> >::value &&
+ std::is_convertible<U&&, T>::value /*=> non-explicit */
+ ))
+ : has_value_(true), contained(std::forward<U>(value)) {}
+
+#else // optional_CPP11_OR_GREATER
+
+ // 8 (C++98)
+ optional(value_type const& value) : has_value_(true), contained(value) {}
+
+#endif // optional_CPP11_OR_GREATER
+
+ // x.x.3.2, destructor
+
+ ~optional() {
+ if (has_value()) {
+ contained.destruct_value();
+ }
+ }
+
+ // x.x.3.3, assignment
+
+ // 1 (C++98and later) - assign explicitly empty
+ optional& operator=(nullopt_t /*unused*/) optional_noexcept {
+ reset();
+ return *this;
+ }
+
+ // 2 (C++98and later) - copy-assign from optional
+#if optional_CPP11_OR_GREATER
+ // NOLINTNEXTLINE( cppcoreguidelines-c-copy-assignment-signature,
+ // misc-unconventional-assign-operator )
+ optional_REQUIRES_R(optional&, true
+ // std::is_copy_constructible<T>::value
+ // && std::is_copy_assignable<T>::value
+ )
+ operator=(optional const& other) noexcept(
+ std::is_nothrow_move_assignable<T>::value&&
+ std::is_nothrow_move_constructible<T>::value)
+#else
+ optional& operator=(optional const& other)
+#endif
+ {
+ if ((has_value() == true) && (other.has_value() == false)) {
+ reset();
+ } else if ((has_value() == false) && (other.has_value() == true)) {
+ initialize(*other);
+ } else if ((has_value() == true) && (other.has_value() == true)) {
+ contained.value() = *other;
+ }
+ return *this;
+ }
+
+#if optional_CPP11_OR_GREATER
+
+ // 3 (C++11) - move-assign from optional
+ // NOLINTNEXTLINE( cppcoreguidelines-c-copy-assignment-signature,
+ // misc-unconventional-assign-operator )
+ optional_REQUIRES_R(optional&, true
+ // std::is_move_constructible<T>::value
+ // && std::is_move_assignable<T>::value
+ )
+ operator=(optional&& other) noexcept {
+ if ((has_value() == true) && (other.has_value() == false)) {
+ reset();
+ } else if ((has_value() == false) && (other.has_value() == true)) {
+ initialize(std::move(*other));
+ } else if ((has_value() == true) && (other.has_value() == true)) {
+ contained.value() = std::move(*other);
+ }
+ return *this;
+ }
+
+ // 4 (C++11) - move-assign from value
+ template <typename U = T>
+ // NOLINTNEXTLINE( cppcoreguidelines-c-copy-assignment-signature,
+ // misc-unconventional-assign-operator )
+ optional_REQUIRES_R(
+ optional&,
+ std::is_constructible<T, U>::value&& std::is_assignable<T&, U>::value &&
+ !std::is_same<typename std20::remove_cvref<U>::type,
+ nonstd_lite_in_place_t(U)>::value &&
+ !std::is_same<typename std20::remove_cvref<U>::type, optional<T> >::value &&
+ !(std::is_scalar<T>::value &&
+ std::is_same<T, typename std::decay<U>::type>::value))
+ operator=(U&& value) {
+ if (has_value()) {
+ contained.value() = std::forward<U>(value);
+ } else {
+ initialize(T(std::forward<U>(value)));
+ }
+ return *this;
+ }
+
+#else // optional_CPP11_OR_GREATER
+
+ // 4 (C++98) - copy-assign from value
+ template <typename U /*= T*/>
+ optional& operator=(U const& value) {
+ if (has_value())
+ contained.value() = value;
+ else
+ initialize(T(value));
+ return *this;
+ }
+
+#endif // optional_CPP11_OR_GREATER
+
+ // 5 (C++98 and later) - converting copy-assign from optional
+ template <typename U>
+#if optional_CPP11_OR_GREATER
+ // NOLINTNEXTLINE( cppcoreguidelines-c-copy-assignment-signature,
+ // misc-unconventional-assign-operator )
+ optional_REQUIRES_R(optional&,
+ std::is_constructible<T, U const&>::value&&
+ std::is_assignable<T&, U const&>::value &&
+ !std::is_constructible<T, optional<U>&>::value &&
+ !std::is_constructible<T, optional<U>&&>::value &&
+ !std::is_constructible<T, optional<U> const&>::value &&
+ !std::is_constructible<T, optional<U> const&&>::value &&
+ !std::is_convertible<optional<U>&, T>::value &&
+ !std::is_convertible<optional<U>&&, T>::value &&
+ !std::is_convertible<optional<U> const&, T>::value &&
+ !std::is_convertible<optional<U> const&&, T>::value &&
+ !std::is_assignable<T&, optional<U>&>::value &&
+ !std::is_assignable<T&, optional<U>&&>::value &&
+ !std::is_assignable<T&, optional<U> const&>::value &&
+ !std::is_assignable<T&, optional<U> const&&>::value)
+#else
+ optional&
+#endif // optional_CPP11_OR_GREATER
+ operator=(optional<U> const& other) {
+ return *this = optional(other);
+ }
+
+#if optional_CPP11_OR_GREATER
+
+ // 6 (C++11) - converting move-assign from optional
+ template <typename U>
+ // NOLINTNEXTLINE( cppcoreguidelines-c-copy-assignment-signature,
+ // misc-unconventional-assign-operator )
+ optional_REQUIRES_R(
+ optional&, std::is_constructible<T, U>::value&& std::is_assignable<T&, U>::value &&
+ !std::is_constructible<T, optional<U>&>::value &&
+ !std::is_constructible<T, optional<U>&&>::value &&
+ !std::is_constructible<T, optional<U> const&>::value &&
+ !std::is_constructible<T, optional<U> const&&>::value &&
+ !std::is_convertible<optional<U>&, T>::value &&
+ !std::is_convertible<optional<U>&&, T>::value &&
+ !std::is_convertible<optional<U> const&, T>::value &&
+ !std::is_convertible<optional<U> const&&, T>::value &&
+ !std::is_assignable<T&, optional<U>&>::value &&
+ !std::is_assignable<T&, optional<U>&&>::value &&
+ !std::is_assignable<T&, optional<U> const&>::value &&
+ !std::is_assignable<T&, optional<U> const&&>::value)
+ operator=(optional<U>&& other) {
+ return *this = optional(std::move(other));
+ }
+
+ // 7 (C++11) - emplace
+ template <
+ typename... Args optional_REQUIRES_T(std::is_constructible<T, Args&&...>::value)>
+ T& emplace(Args&&... args) {
+ *this = nullopt;
+ contained.emplace(std::forward<Args>(args)...);
+ has_value_ = true;
+ return contained.value();
+ }
+
+ // 8 (C++11) - emplace, initializer-list
+ template <typename U,
+ typename... Args optional_REQUIRES_T(
+ std::is_constructible<T, std::initializer_list<U>&, Args&&...>::value)>
+ T& emplace(std::initializer_list<U> il, Args&&... args) {
+ *this = nullopt;
+ contained.emplace(il, std::forward<Args>(args)...);
+ has_value_ = true;
+ return contained.value();
+ }
+
+#endif // optional_CPP11_OR_GREATER
+
+ // x.x.3.4, swap
+
+ void swap(optional& other)
+#if optional_CPP11_OR_GREATER
+ noexcept(std::is_nothrow_move_constructible<T>::value&&
+ std17::is_nothrow_swappable<T>::value)
+#endif
+ {
+ using std::swap;
+ if ((has_value() == true) && (other.has_value() == true)) {
+ swap(**this, *other);
+ } else if ((has_value() == false) && (other.has_value() == true)) {
+ initialize(std11::move(*other));
+ other.reset();
+ } else if ((has_value() == true) && (other.has_value() == false)) {
+ other.initialize(std11::move(**this));
+ reset();
+ }
+ }
+
+ // x.x.3.5, observers
+
+ optional_constexpr value_type const* operator->() const {
+ return assert(has_value()), contained.value_ptr();
+ }
+
+ optional_constexpr14 value_type* operator->() {
+ return assert(has_value()), contained.value_ptr();
+ }
+
+ optional_constexpr value_type const& operator*() const optional_ref_qual {
+ return assert(has_value()), contained.value();
+ }
+
+ optional_constexpr14 value_type& operator*() optional_ref_qual {
+ return assert(has_value()), contained.value();
+ }
+
+#if optional_HAVE(REF_QUALIFIER) && \
+ (!optional_COMPILER_GNUC_VERSION || optional_COMPILER_GNUC_VERSION >= 490)
+
+ optional_constexpr value_type const&& operator*() const optional_refref_qual {
+ return std::move(**this);
+ }
+
+ optional_constexpr14 value_type&& operator*() optional_refref_qual {
+ return std::move(**this);
+ }
+
+#endif
+
+#if optional_CPP11_OR_GREATER
+ optional_constexpr explicit operator bool() const optional_noexcept {
+ return has_value();
+ }
+#else
+ optional_constexpr operator safe_bool() const optional_noexcept {
+ return has_value() ? &optional::this_type_does_not_support_comparisons : 0;
+ }
+#endif
+
+ // NOLINTNEXTLINE( modernize-use-nodiscard )
+ /*optional_nodiscard*/ optional_constexpr bool has_value() const optional_noexcept {
+ return has_value_;
+ }
+
+ // NOLINTNEXTLINE( modernize-use-nodiscard )
+ /*optional_nodiscard*/ optional_constexpr14 value_type const& value() const
+ optional_ref_qual {
+#if optional_CONFIG_NO_EXCEPTIONS
+ assert(has_value());
+#else
+ if (!has_value()) {
+ throw bad_optional_access();
+ }
+#endif
+ return contained.value();
+ }
+
+ optional_constexpr14 value_type& value() optional_ref_qual {
+#if optional_CONFIG_NO_EXCEPTIONS
+ assert(has_value());
+#else
+ if (!has_value()) {
+ throw bad_optional_access();
+ }
+#endif
+ return contained.value();
+ }
+
+#if optional_HAVE(REF_QUALIFIER) && \
+ (!optional_COMPILER_GNUC_VERSION || optional_COMPILER_GNUC_VERSION >= 490)
+
+ // NOLINTNEXTLINE( modernize-use-nodiscard )
+ /*optional_nodiscard*/ optional_constexpr value_type const&& value() const
+ optional_refref_qual {
+ return std::move(value());
+ }
+
+ optional_constexpr14 value_type&& value() optional_refref_qual {
+ return std::move(value());
+ }
+
+#endif
+
+#if optional_CPP11_OR_GREATER
+
+ template <typename U>
+ optional_constexpr value_type value_or(U&& v) const optional_ref_qual {
+ return has_value() ? contained.value() : static_cast<T>(std::forward<U>(v));
+ }
+
+ template <typename U>
+ optional_constexpr14 value_type value_or(U&& v) optional_refref_qual {
+ return has_value() ? std::move(contained.value())
+ : static_cast<T>(std::forward<U>(v));
+ }
+
+#else
+
+ template <typename U>
+ optional_constexpr value_type value_or(U const& v) const {
+ return has_value() ? contained.value() : static_cast<value_type>(v);
+ }
+
+#endif // optional_CPP11_OR_GREATER
+
+ // x.x.3.6, modifiers
+
+ void reset() optional_noexcept {
+ if (has_value()) {
+ contained.destruct_value();
+ }
+
+ has_value_ = false;
+ }
+
+ private:
+ void this_type_does_not_support_comparisons() const {}
+
+ template <typename V>
+ void initialize(V const& value) {
+ assert(!has_value());
+ contained.construct_value(value);
+ has_value_ = true;
+ }
+
+#if optional_CPP11_OR_GREATER
+ template <typename V>
+ void initialize(V&& value) {
+ assert(!has_value());
+ contained.construct_value(std::move(value));
+ has_value_ = true;
+ }
+
+#endif
+
+ private:
+ bool has_value_;
+ detail::storage_t<value_type> contained;
+};
+
+// Relational operators
+
+template <typename T, typename U>
+inline optional_constexpr bool operator==(optional<T> const& x, optional<U> const& y) {
+ return bool(x) != bool(y) ? false : !bool(x) ? true : *x == *y;
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator!=(optional<T> const& x, optional<U> const& y) {
+ return !(x == y);
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator<(optional<T> const& x, optional<U> const& y) {
+ return (!y) ? false : (!x) ? true : *x < *y;
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator>(optional<T> const& x, optional<U> const& y) {
+ return (y < x);
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator<=(optional<T> const& x, optional<U> const& y) {
+ return !(y < x);
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator>=(optional<T> const& x, optional<U> const& y) {
+ return !(x < y);
+}
+
+// Comparison with nullopt
+
+template <typename T>
+inline optional_constexpr bool operator==(optional<T> const& x,
+ nullopt_t /*unused*/) optional_noexcept {
+ return (!x);
+}
+
+template <typename T>
+inline optional_constexpr bool operator==(nullopt_t /*unused*/,
+ optional<T> const& x) optional_noexcept {
+ return (!x);
+}
+
+template <typename T>
+inline optional_constexpr bool operator!=(optional<T> const& x,
+ nullopt_t /*unused*/) optional_noexcept {
+ return bool(x);
+}
+
+template <typename T>
+inline optional_constexpr bool operator!=(nullopt_t /*unused*/,
+ optional<T> const& x) optional_noexcept {
+ return bool(x);
+}
+
+template <typename T>
+inline optional_constexpr bool operator<(optional<T> const& /*unused*/,
+ nullopt_t /*unused*/) optional_noexcept {
+ return false;
+}
+
+template <typename T>
+inline optional_constexpr bool operator<(nullopt_t /*unused*/,
+ optional<T> const& x) optional_noexcept {
+ return bool(x);
+}
+
+template <typename T>
+inline optional_constexpr bool operator<=(optional<T> const& x,
+ nullopt_t /*unused*/) optional_noexcept {
+ return (!x);
+}
+
+template <typename T>
+inline optional_constexpr bool operator<=(
+ nullopt_t /*unused*/, optional<T> const& /*unused*/) optional_noexcept {
+ return true;
+}
+
+template <typename T>
+inline optional_constexpr bool operator>(optional<T> const& x,
+ nullopt_t /*unused*/) optional_noexcept {
+ return bool(x);
+}
+
+template <typename T>
+inline optional_constexpr bool operator>(
+ nullopt_t /*unused*/, optional<T> const& /*unused*/) optional_noexcept {
+ return false;
+}
+
+template <typename T>
+inline optional_constexpr bool operator>=(optional<T> const& /*unused*/,
+ nullopt_t /*unused*/) optional_noexcept {
+ return true;
+}
+
+template <typename T>
+inline optional_constexpr bool operator>=(nullopt_t /*unused*/,
+ optional<T> const& x) optional_noexcept {
+ return (!x);
+}
+
+// Comparison with T
+
+template <typename T, typename U>
+inline optional_constexpr bool operator==(optional<T> const& x, U const& v) {
+ return bool(x) ? *x == v : false;
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator==(U const& v, optional<T> const& x) {
+ return bool(x) ? v == *x : false;
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator!=(optional<T> const& x, U const& v) {
+ return bool(x) ? *x != v : true;
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator!=(U const& v, optional<T> const& x) {
+ return bool(x) ? v != *x : true;
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator<(optional<T> const& x, U const& v) {
+ return bool(x) ? *x < v : true;
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator<(U const& v, optional<T> const& x) {
+ return bool(x) ? v < *x : false;
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator<=(optional<T> const& x, U const& v) {
+ return bool(x) ? *x <= v : true;
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator<=(U const& v, optional<T> const& x) {
+ return bool(x) ? v <= *x : false;
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator>(optional<T> const& x, U const& v) {
+ return bool(x) ? *x > v : false;
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator>(U const& v, optional<T> const& x) {
+ return bool(x) ? v > *x : true;
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator>=(optional<T> const& x, U const& v) {
+ return bool(x) ? *x >= v : false;
+}
+
+template <typename T, typename U>
+inline optional_constexpr bool operator>=(U const& v, optional<T> const& x) {
+ return bool(x) ? v >= *x : true;
+}
+
+// Specialized algorithms
+
+template <typename T
+#if optional_CPP11_OR_GREATER
+ optional_REQUIRES_T(
+ std::is_move_constructible<T>::value&& std17::is_swappable<T>::value)
+#endif
+ >
+void swap(optional<T>& x, optional<T>& y)
+#if optional_CPP11_OR_GREATER
+ noexcept(noexcept(x.swap(y)))
+#endif
+{
+ x.swap(y);
+}
+
+#if optional_CPP11_OR_GREATER
+
+template <typename T>
+optional_constexpr optional<typename std::decay<T>::type> make_optional(T&& value) {
+ return optional<typename std::decay<T>::type>(std::forward<T>(value));
+}
+
+template <typename T, typename... Args>
+optional_constexpr optional<T> make_optional(Args&&... args) {
+ return optional<T>(nonstd_lite_in_place(T), std::forward<Args>(args)...);
+}
+
+template <typename T, typename U, typename... Args>
+optional_constexpr optional<T> make_optional(std::initializer_list<U> il,
+ Args&&... args) {
+ return optional<T>(nonstd_lite_in_place(T), il, std::forward<Args>(args)...);
+}
+
+#else
+
+template <typename T>
+optional<T> make_optional(T const& value) {
+ return optional<T>(value);
+}
+
+#endif // optional_CPP11_OR_GREATER
+
+} // namespace optional_lite
+
+using optional_lite::bad_optional_access;
+using optional_lite::nullopt;
+using optional_lite::nullopt_t;
+using optional_lite::optional;
+
+using optional_lite::make_optional;
+
+} // namespace nonstd
+
+#if optional_CPP11_OR_GREATER
+
+// specialize the std::hash algorithm:
+
+namespace std {
+
+template <class T>
+struct hash<nonstd::optional<T> > {
+ public:
+ std::size_t operator()(nonstd::optional<T> const& v) const optional_noexcept {
+ return bool(v) ? std::hash<T>{}(*v) : 0;
+ }
+};
+
+} // namespace std
+
+#endif // optional_CPP11_OR_GREATER
+
+#if defined(__clang__)
+#pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#pragma GCC diagnostic pop
+#elif defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+#endif // optional_USES_STD_OPTIONAL
+
+#endif // NONSTD_OPTIONAL_LITE_HPP
diff --git a/src/arrow/cpp/src/arrow/vendored/pcg/README.md b/src/arrow/cpp/src/arrow/vendored/pcg/README.md
new file mode 100644
index 000000000..bf72ea897
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/pcg/README.md
@@ -0,0 +1,26 @@
+<!--
+PCG Random Number Generation for C++
+
+Copyright 2014-2019 Melissa O'Neill <oneill@pcg-random.org>,
+ and the PCG Project contributors.
+
+SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+Licensed under the Apache License, Version 2.0 (provided in
+LICENSE-APACHE.txt and at http://www.apache.org/licenses/LICENSE-2.0)
+or under the MIT license (provided in LICENSE-MIT.txt and at
+http://opensource.org/licenses/MIT), at your option. This file may not
+be copied, modified, or distributed except according to those terms.
+
+Distributed on an "AS IS" BASIS, WITHOUT WARRANTY OF ANY KIND, either
+express or implied. See your chosen license for details.
+
+For additional information about the PCG random number generation scheme,
+visit http://www.pcg-random.org/.
+-->
+
+Sources are taken from git changeset ffd522e7188bef30a00c74dc7eb9de5faff90092
+(https://github.com/imneme/pcg-cpp).
+
+Changes:
+- enclosed in `arrow_vendored` namespace
diff --git a/src/arrow/cpp/src/arrow/vendored/pcg/pcg_extras.hpp b/src/arrow/cpp/src/arrow/vendored/pcg/pcg_extras.hpp
new file mode 100644
index 000000000..760867e1e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/pcg/pcg_extras.hpp
@@ -0,0 +1,670 @@
+/*
+ * PCG Random Number Generation for C++
+ *
+ * Copyright 2014-2017 Melissa O'Neill <oneill@pcg-random.org>,
+ * and the PCG Project contributors.
+ *
+ * SPDX-License-Identifier: (Apache-2.0 OR MIT)
+ *
+ * Licensed under the Apache License, Version 2.0 (provided in
+ * LICENSE-APACHE.txt and at http://www.apache.org/licenses/LICENSE-2.0)
+ * or under the MIT license (provided in LICENSE-MIT.txt and at
+ * http://opensource.org/licenses/MIT), at your option. This file may not
+ * be copied, modified, or distributed except according to those terms.
+ *
+ * Distributed on an "AS IS" BASIS, WITHOUT WARRANTY OF ANY KIND, either
+ * express or implied. See your chosen license for details.
+ *
+ * For additional information about the PCG random number generation scheme,
+ * visit http://www.pcg-random.org/.
+ */
+
+/*
+ * This file provides support code that is useful for random-number generation
+ * but not specific to the PCG generation scheme, including:
+ * - 128-bit int support for platforms where it isn't available natively
+ * - bit twiddling operations
+ * - I/O of 128-bit and 8-bit integers
+ * - Handling the evilness of SeedSeq
+ * - Support for efficiently producing random numbers less than a given
+ * bound
+ */
+
+#ifndef PCG_EXTRAS_HPP_INCLUDED
+#define PCG_EXTRAS_HPP_INCLUDED 1
+
+#include <cinttypes>
+#include <cstddef>
+#include <cstdlib>
+#include <cstring>
+#include <cassert>
+#include <limits>
+#include <iostream>
+#include <type_traits>
+#include <utility>
+#include <locale>
+#include <iterator>
+
+#ifdef __GNUC__
+ #include <cxxabi.h>
+#endif
+
+/*
+ * Abstractions for compiler-specific directives
+ */
+
+#ifdef __GNUC__
+ #define PCG_NOINLINE __attribute__((noinline))
+#else
+ #define PCG_NOINLINE
+#endif
+
+/*
+ * Some members of the PCG library use 128-bit math. When compiling on 64-bit
+ * platforms, both GCC and Clang provide 128-bit integer types that are ideal
+ * for the job.
+ *
+ * On 32-bit platforms (or with other compilers), we fall back to a C++
+ * class that provides 128-bit unsigned integers instead. It may seem
+ * like we're reinventing the wheel here, because libraries already exist
+ * that support large integers, but most existing libraries provide a very
+ * generic multiprecision code, but here we're operating at a fixed size.
+ * Also, most other libraries are fairly heavyweight. So we use a direct
+ * implementation. Sadly, it's much slower than hand-coded assembly or
+ * direct CPU support.
+ *
+ */
+#if __SIZEOF_INT128__ && !PCG_FORCE_EMULATED_128BIT_MATH
+ namespace arrow_vendored {
+ namespace pcg_extras {
+ typedef __uint128_t pcg128_t;
+ }
+ }
+ #define PCG_128BIT_CONSTANT(high,low) \
+ ((pcg_extras::pcg128_t(high) << 64) + low)
+#else
+ #include "pcg_uint128.hpp"
+ namespace arrow_vendored {
+ namespace pcg_extras {
+ typedef pcg_extras::uint_x4<uint32_t,uint64_t> pcg128_t;
+ }
+ }
+ #define PCG_128BIT_CONSTANT(high,low) \
+ pcg_extras::pcg128_t(high,low)
+ #define PCG_EMULATED_128BIT_MATH 1
+#endif
+
+
+namespace arrow_vendored {
+namespace pcg_extras {
+
+/*
+ * We often need to represent a "number of bits". When used normally, these
+ * numbers are never greater than 128, so an unsigned char is plenty.
+ * If you're using a nonstandard generator of a larger size, you can set
+ * PCG_BITCOUNT_T to have it define it as a larger size. (Some compilers
+ * might produce faster code if you set it to an unsigned int.)
+ */
+
+#ifndef PCG_BITCOUNT_T
+ typedef uint8_t bitcount_t;
+#else
+ typedef PCG_BITCOUNT_T bitcount_t;
+#endif
+
+/*
+ * C++ requires us to be able to serialize RNG state by printing or reading
+ * it from a stream. Because we use 128-bit ints, we also need to be able
+ * ot print them, so here is code to do so.
+ *
+ * This code provides enough functionality to print 128-bit ints in decimal
+ * and zero-padded in hex. It's not a full-featured implementation.
+ */
+
+template <typename CharT, typename Traits>
+std::basic_ostream<CharT,Traits>&
+operator<<(std::basic_ostream<CharT,Traits>& out, pcg128_t value)
+{
+ auto desired_base = out.flags() & out.basefield;
+ bool want_hex = desired_base == out.hex;
+
+ if (want_hex) {
+ uint64_t highpart = uint64_t(value >> 64);
+ uint64_t lowpart = uint64_t(value);
+ auto desired_width = out.width();
+ if (desired_width > 16) {
+ out.width(desired_width - 16);
+ }
+ if (highpart != 0 || desired_width > 16)
+ out << highpart;
+ CharT oldfill = '\0';
+ if (highpart != 0) {
+ out.width(16);
+ oldfill = out.fill('0');
+ }
+ auto oldflags = out.setf(decltype(desired_base){}, out.showbase);
+ out << lowpart;
+ out.setf(oldflags);
+ if (highpart != 0) {
+ out.fill(oldfill);
+ }
+ return out;
+ }
+ constexpr size_t MAX_CHARS_128BIT = 40;
+
+ char buffer[MAX_CHARS_128BIT];
+ char* pos = buffer+sizeof(buffer);
+ *(--pos) = '\0';
+ constexpr auto BASE = pcg128_t(10ULL);
+ do {
+ auto div = value / BASE;
+ auto mod = uint32_t(value - (div * BASE));
+ *(--pos) = '0' + char(mod);
+ value = div;
+ } while(value != pcg128_t(0ULL));
+ return out << pos;
+}
+
+template <typename CharT, typename Traits>
+std::basic_istream<CharT,Traits>&
+operator>>(std::basic_istream<CharT,Traits>& in, pcg128_t& value)
+{
+ typename std::basic_istream<CharT,Traits>::sentry s(in);
+
+ if (!s)
+ return in;
+
+ constexpr auto BASE = pcg128_t(10ULL);
+ pcg128_t current(0ULL);
+ bool did_nothing = true;
+ bool overflow = false;
+ for(;;) {
+ CharT wide_ch = in.get();
+ if (!in.good())
+ break;
+ auto ch = in.narrow(wide_ch, '\0');
+ if (ch < '0' || ch > '9') {
+ in.unget();
+ break;
+ }
+ did_nothing = false;
+ pcg128_t digit(uint32_t(ch - '0'));
+ pcg128_t timesbase = current*BASE;
+ overflow = overflow || timesbase < current;
+ current = timesbase + digit;
+ overflow = overflow || current < digit;
+ }
+
+ if (did_nothing || overflow) {
+ in.setstate(std::ios::failbit);
+ if (overflow)
+ current = ~pcg128_t(0ULL);
+ }
+
+ value = current;
+
+ return in;
+}
+
+/*
+ * Likewise, if people use tiny rngs, we'll be serializing uint8_t.
+ * If we just used the provided IO operators, they'd read/write chars,
+ * not ints, so we need to define our own. We *can* redefine this operator
+ * here because we're in our own namespace.
+ */
+
+template <typename CharT, typename Traits>
+std::basic_ostream<CharT,Traits>&
+operator<<(std::basic_ostream<CharT,Traits>&out, uint8_t value)
+{
+ return out << uint32_t(value);
+}
+
+template <typename CharT, typename Traits>
+std::basic_istream<CharT,Traits>&
+operator>>(std::basic_istream<CharT,Traits>& in, uint8_t& target)
+{
+ uint32_t value = 0xdecea5edU;
+ in >> value;
+ if (!in && value == 0xdecea5edU)
+ return in;
+ if (value > uint8_t(~0)) {
+ in.setstate(std::ios::failbit);
+ value = ~0U;
+ }
+ target = uint8_t(value);
+ return in;
+}
+
+/* Unfortunately, the above functions don't get found in preference to the
+ * built in ones, so we create some more specific overloads that will.
+ * Ugh.
+ */
+
+inline std::ostream& operator<<(std::ostream& out, uint8_t value)
+{
+ return pcg_extras::operator<< <char>(out, value);
+}
+
+inline std::istream& operator>>(std::istream& in, uint8_t& value)
+{
+ return pcg_extras::operator>> <char>(in, value);
+}
+
+
+
+/*
+ * Useful bitwise operations.
+ */
+
+/*
+ * XorShifts are invertable, but they are someting of a pain to invert.
+ * This function backs them out. It's used by the whacky "inside out"
+ * generator defined later.
+ */
+
+template <typename itype>
+inline itype unxorshift(itype x, bitcount_t bits, bitcount_t shift)
+{
+ if (2*shift >= bits) {
+ return x ^ (x >> shift);
+ }
+ itype lowmask1 = (itype(1U) << (bits - shift*2)) - 1;
+ itype highmask1 = ~lowmask1;
+ itype top1 = x;
+ itype bottom1 = x & lowmask1;
+ top1 ^= top1 >> shift;
+ top1 &= highmask1;
+ x = top1 | bottom1;
+ itype lowmask2 = (itype(1U) << (bits - shift)) - 1;
+ itype bottom2 = x & lowmask2;
+ bottom2 = unxorshift(bottom2, bits - shift, shift);
+ bottom2 &= lowmask1;
+ return top1 | bottom2;
+}
+
+/*
+ * Rotate left and right.
+ *
+ * In ideal world, compilers would spot idiomatic rotate code and convert it
+ * to a rotate instruction. Of course, opinions vary on what the correct
+ * idiom is and how to spot it. For clang, sometimes it generates better
+ * (but still crappy) code if you define PCG_USE_ZEROCHECK_ROTATE_IDIOM.
+ */
+
+template <typename itype>
+inline itype rotl(itype value, bitcount_t rot)
+{
+ constexpr bitcount_t bits = sizeof(itype) * 8;
+ constexpr bitcount_t mask = bits - 1;
+#if PCG_USE_ZEROCHECK_ROTATE_IDIOM
+ return rot ? (value << rot) | (value >> (bits - rot)) : value;
+#else
+ return (value << rot) | (value >> ((- rot) & mask));
+#endif
+}
+
+template <typename itype>
+inline itype rotr(itype value, bitcount_t rot)
+{
+ constexpr bitcount_t bits = sizeof(itype) * 8;
+ constexpr bitcount_t mask = bits - 1;
+#if PCG_USE_ZEROCHECK_ROTATE_IDIOM
+ return rot ? (value >> rot) | (value << (bits - rot)) : value;
+#else
+ return (value >> rot) | (value << ((- rot) & mask));
+#endif
+}
+
+/* Unfortunately, both Clang and GCC sometimes perform poorly when it comes
+ * to properly recognizing idiomatic rotate code, so for we also provide
+ * assembler directives (enabled with PCG_USE_INLINE_ASM). Boo, hiss.
+ * (I hope that these compilers get better so that this code can die.)
+ *
+ * These overloads will be preferred over the general template code above.
+ */
+#if PCG_USE_INLINE_ASM && __GNUC__ && (__x86_64__ || __i386__)
+
+inline uint8_t rotr(uint8_t value, bitcount_t rot)
+{
+ asm ("rorb %%cl, %0" : "=r" (value) : "0" (value), "c" (rot));
+ return value;
+}
+
+inline uint16_t rotr(uint16_t value, bitcount_t rot)
+{
+ asm ("rorw %%cl, %0" : "=r" (value) : "0" (value), "c" (rot));
+ return value;
+}
+
+inline uint32_t rotr(uint32_t value, bitcount_t rot)
+{
+ asm ("rorl %%cl, %0" : "=r" (value) : "0" (value), "c" (rot));
+ return value;
+}
+
+#if __x86_64__
+inline uint64_t rotr(uint64_t value, bitcount_t rot)
+{
+ asm ("rorq %%cl, %0" : "=r" (value) : "0" (value), "c" (rot));
+ return value;
+}
+#endif // __x86_64__
+
+#elif defined(_MSC_VER)
+ // Use MSVC++ bit rotation intrinsics
+
+#pragma intrinsic(_rotr, _rotr64, _rotr8, _rotr16)
+
+inline uint8_t rotr(uint8_t value, bitcount_t rot)
+{
+ return _rotr8(value, rot);
+}
+
+inline uint16_t rotr(uint16_t value, bitcount_t rot)
+{
+ return _rotr16(value, rot);
+}
+
+inline uint32_t rotr(uint32_t value, bitcount_t rot)
+{
+ return _rotr(value, rot);
+}
+
+inline uint64_t rotr(uint64_t value, bitcount_t rot)
+{
+ return _rotr64(value, rot);
+}
+
+#endif // PCG_USE_INLINE_ASM
+
+
+/*
+ * The C++ SeedSeq concept (modelled by seed_seq) can fill an array of
+ * 32-bit integers with seed data, but sometimes we want to produce
+ * larger or smaller integers.
+ *
+ * The following code handles this annoyance.
+ *
+ * uneven_copy will copy an array of 32-bit ints to an array of larger or
+ * smaller ints (actually, the code is general it only needing forward
+ * iterators). The copy is identical to the one that would be performed if
+ * we just did memcpy on a standard little-endian machine, but works
+ * regardless of the endian of the machine (or the weirdness of the ints
+ * involved).
+ *
+ * generate_to initializes an array of integers using a SeedSeq
+ * object. It is given the size as a static constant at compile time and
+ * tries to avoid memory allocation. If we're filling in 32-bit constants
+ * we just do it directly. If we need a separate buffer and it's small,
+ * we allocate it on the stack. Otherwise, we fall back to heap allocation.
+ * Ugh.
+ *
+ * generate_one produces a single value of some integral type using a
+ * SeedSeq object.
+ */
+
+ /* uneven_copy helper, case where destination ints are less than 32 bit. */
+
+template<class SrcIter, class DestIter>
+SrcIter uneven_copy_impl(
+ SrcIter src_first, DestIter dest_first, DestIter dest_last,
+ std::true_type)
+{
+ typedef typename std::iterator_traits<SrcIter>::value_type src_t;
+ typedef typename std::iterator_traits<DestIter>::value_type dest_t;
+
+ constexpr bitcount_t SRC_SIZE = sizeof(src_t);
+ constexpr bitcount_t DEST_SIZE = sizeof(dest_t);
+ constexpr bitcount_t DEST_BITS = DEST_SIZE * 8;
+ constexpr bitcount_t SCALE = SRC_SIZE / DEST_SIZE;
+
+ size_t count = 0;
+ src_t value = 0;
+
+ while (dest_first != dest_last) {
+ if ((count++ % SCALE) == 0)
+ value = *src_first++; // Get more bits
+ else
+ value >>= DEST_BITS; // Move down bits
+
+ *dest_first++ = dest_t(value); // Truncates, ignores high bits.
+ }
+ return src_first;
+}
+
+ /* uneven_copy helper, case where destination ints are more than 32 bit. */
+
+template<class SrcIter, class DestIter>
+SrcIter uneven_copy_impl(
+ SrcIter src_first, DestIter dest_first, DestIter dest_last,
+ std::false_type)
+{
+ typedef typename std::iterator_traits<SrcIter>::value_type src_t;
+ typedef typename std::iterator_traits<DestIter>::value_type dest_t;
+
+ constexpr auto SRC_SIZE = sizeof(src_t);
+ constexpr auto SRC_BITS = SRC_SIZE * 8;
+ constexpr auto DEST_SIZE = sizeof(dest_t);
+ constexpr auto SCALE = (DEST_SIZE+SRC_SIZE-1) / SRC_SIZE;
+
+ while (dest_first != dest_last) {
+ dest_t value(0UL);
+ unsigned int shift = 0;
+
+ for (size_t i = 0; i < SCALE; ++i) {
+ value |= dest_t(*src_first++) << shift;
+ shift += SRC_BITS;
+ }
+
+ *dest_first++ = value;
+ }
+ return src_first;
+}
+
+/* uneven_copy, call the right code for larger vs. smaller */
+
+template<class SrcIter, class DestIter>
+inline SrcIter uneven_copy(SrcIter src_first,
+ DestIter dest_first, DestIter dest_last)
+{
+ typedef typename std::iterator_traits<SrcIter>::value_type src_t;
+ typedef typename std::iterator_traits<DestIter>::value_type dest_t;
+
+ constexpr bool DEST_IS_SMALLER = sizeof(dest_t) < sizeof(src_t);
+
+ return uneven_copy_impl(src_first, dest_first, dest_last,
+ std::integral_constant<bool, DEST_IS_SMALLER>{});
+}
+
+/* generate_to, fill in a fixed-size array of integral type using a SeedSeq
+ * (actually works for any random-access iterator)
+ */
+
+template <size_t size, typename SeedSeq, typename DestIter>
+inline void generate_to_impl(SeedSeq&& generator, DestIter dest,
+ std::true_type)
+{
+ generator.generate(dest, dest+size);
+}
+
+template <size_t size, typename SeedSeq, typename DestIter>
+void generate_to_impl(SeedSeq&& generator, DestIter dest,
+ std::false_type)
+{
+ typedef typename std::iterator_traits<DestIter>::value_type dest_t;
+ constexpr auto DEST_SIZE = sizeof(dest_t);
+ constexpr auto GEN_SIZE = sizeof(uint32_t);
+
+ constexpr bool GEN_IS_SMALLER = GEN_SIZE < DEST_SIZE;
+ constexpr size_t FROM_ELEMS =
+ GEN_IS_SMALLER
+ ? size * ((DEST_SIZE+GEN_SIZE-1) / GEN_SIZE)
+ : (size + (GEN_SIZE / DEST_SIZE) - 1)
+ / ((GEN_SIZE / DEST_SIZE) + GEN_IS_SMALLER);
+ // this odd code ^^^^^^^^^^^^^^^^^ is work-around for
+ // a bug: http://llvm.org/bugs/show_bug.cgi?id=21287
+
+ if (FROM_ELEMS <= 1024) {
+ uint32_t buffer[FROM_ELEMS];
+ generator.generate(buffer, buffer+FROM_ELEMS);
+ uneven_copy(buffer, dest, dest+size);
+ } else {
+ uint32_t* buffer = static_cast<uint32_t*>(malloc(GEN_SIZE * FROM_ELEMS));
+ generator.generate(buffer, buffer+FROM_ELEMS);
+ uneven_copy(buffer, dest, dest+size);
+ free(static_cast<void*>(buffer));
+ }
+}
+
+template <size_t size, typename SeedSeq, typename DestIter>
+inline void generate_to(SeedSeq&& generator, DestIter dest)
+{
+ typedef typename std::iterator_traits<DestIter>::value_type dest_t;
+ constexpr bool IS_32BIT = sizeof(dest_t) == sizeof(uint32_t);
+
+ generate_to_impl<size>(std::forward<SeedSeq>(generator), dest,
+ std::integral_constant<bool, IS_32BIT>{});
+}
+
+/* generate_one, produce a value of integral type using a SeedSeq
+ * (optionally, we can have it produce more than one and pick which one
+ * we want)
+ */
+
+template <typename UInt, size_t i = 0UL, size_t N = i+1UL, typename SeedSeq>
+inline UInt generate_one(SeedSeq&& generator)
+{
+ UInt result[N];
+ generate_to<N>(std::forward<SeedSeq>(generator), result);
+ return result[i];
+}
+
+template <typename RngType>
+auto bounded_rand(RngType& rng, typename RngType::result_type upper_bound)
+ -> typename RngType::result_type
+{
+ typedef typename RngType::result_type rtype;
+ rtype threshold = (RngType::max() - RngType::min() + rtype(1) - upper_bound)
+ % upper_bound;
+ for (;;) {
+ rtype r = rng() - RngType::min();
+ if (r >= threshold)
+ return r % upper_bound;
+ }
+}
+
+template <typename Iter, typename RandType>
+void shuffle(Iter from, Iter to, RandType&& rng)
+{
+ typedef typename std::iterator_traits<Iter>::difference_type delta_t;
+ typedef typename std::remove_reference<RandType>::type::result_type result_t;
+ auto count = to - from;
+ while (count > 1) {
+ delta_t chosen = delta_t(bounded_rand(rng, result_t(count)));
+ --count;
+ --to;
+ using std::swap;
+ swap(*(from + chosen), *to);
+ }
+}
+
+/*
+ * Although std::seed_seq is useful, it isn't everything. Often we want to
+ * initialize a random-number generator some other way, such as from a random
+ * device.
+ *
+ * Technically, it does not meet the requirements of a SeedSequence because
+ * it lacks some of the rarely-used member functions (some of which would
+ * be impossible to provide). However the C++ standard is quite specific
+ * that actual engines only called the generate method, so it ought not to be
+ * a problem in practice.
+ */
+
+template <typename RngType>
+class seed_seq_from {
+private:
+ RngType rng_;
+
+ typedef uint_least32_t result_type;
+
+public:
+ template<typename... Args>
+ seed_seq_from(Args&&... args) :
+ rng_(std::forward<Args>(args)...)
+ {
+ // Nothing (else) to do...
+ }
+
+ template<typename Iter>
+ void generate(Iter start, Iter finish)
+ {
+ for (auto i = start; i != finish; ++i)
+ *i = result_type(rng_());
+ }
+
+ constexpr size_t size() const
+ {
+ return (sizeof(typename RngType::result_type) > sizeof(result_type)
+ && RngType::max() > ~size_t(0UL))
+ ? ~size_t(0UL)
+ : size_t(RngType::max());
+ }
+};
+
+/*
+ * Sometimes you might want a distinct seed based on when the program
+ * was compiled. That way, a particular instance of the program will
+ * behave the same way, but when recompiled it'll produce a different
+ * value.
+ */
+
+template <typename IntType>
+struct static_arbitrary_seed {
+private:
+ static constexpr IntType fnv(IntType hash, const char* pos) {
+ return *pos == '\0'
+ ? hash
+ : fnv((hash * IntType(16777619U)) ^ *pos, (pos+1));
+ }
+
+public:
+ static constexpr IntType value = fnv(IntType(2166136261U ^ sizeof(IntType)),
+ __DATE__ __TIME__ __FILE__);
+};
+
+// Sometimes, when debugging or testing, it's handy to be able print the name
+// of a (in human-readable form). This code allows the idiom:
+//
+// cout << printable_typename<my_foo_type_t>()
+//
+// to print out my_foo_type_t (or its concrete type if it is a synonym)
+
+#if __cpp_rtti || __GXX_RTTI
+
+template <typename T>
+struct printable_typename {};
+
+template <typename T>
+std::ostream& operator<<(std::ostream& out, printable_typename<T>) {
+ const char *implementation_typename = typeid(T).name();
+#ifdef __GNUC__
+ int status;
+ char* pretty_name =
+ abi::__cxa_demangle(implementation_typename, nullptr, nullptr, &status);
+ if (status == 0)
+ out << pretty_name;
+ free(static_cast<void*>(pretty_name));
+ if (status == 0)
+ return out;
+#endif
+ out << implementation_typename;
+ return out;
+}
+
+#endif // __cpp_rtti || __GXX_RTTI
+
+} // namespace pcg_extras
+} // namespace arrow_vendored
+
+#endif // PCG_EXTRAS_HPP_INCLUDED
diff --git a/src/arrow/cpp/src/arrow/vendored/pcg/pcg_random.hpp b/src/arrow/cpp/src/arrow/vendored/pcg/pcg_random.hpp
new file mode 100644
index 000000000..a864ba0a2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/pcg/pcg_random.hpp
@@ -0,0 +1,1954 @@
+/*
+ * PCG Random Number Generation for C++
+ *
+ * Copyright 2014-2019 Melissa O'Neill <oneill@pcg-random.org>,
+ * and the PCG Project contributors.
+ *
+ * SPDX-License-Identifier: (Apache-2.0 OR MIT)
+ *
+ * Licensed under the Apache License, Version 2.0 (provided in
+ * LICENSE-APACHE.txt and at http://www.apache.org/licenses/LICENSE-2.0)
+ * or under the MIT license (provided in LICENSE-MIT.txt and at
+ * http://opensource.org/licenses/MIT), at your option. This file may not
+ * be copied, modified, or distributed except according to those terms.
+ *
+ * Distributed on an "AS IS" BASIS, WITHOUT WARRANTY OF ANY KIND, either
+ * express or implied. See your chosen license for details.
+ *
+ * For additional information about the PCG random number generation scheme,
+ * visit http://www.pcg-random.org/.
+ */
+
+/*
+ * This code provides the reference implementation of the PCG family of
+ * random number generators. The code is complex because it implements
+ *
+ * - several members of the PCG family, specifically members corresponding
+ * to the output functions:
+ * - XSH RR (good for 64-bit state, 32-bit output)
+ * - XSH RS (good for 64-bit state, 32-bit output)
+ * - XSL RR (good for 128-bit state, 64-bit output)
+ * - RXS M XS (statistically most powerful generator)
+ * - XSL RR RR (good for 128-bit state, 128-bit output)
+ * - and RXS, RXS M, XSH, XSL (mostly for testing)
+ * - at potentially *arbitrary* bit sizes
+ * - with four different techniques for random streams (MCG, one-stream
+ * LCG, settable-stream LCG, unique-stream LCG)
+ * - and the extended generation schemes allowing arbitrary periods
+ * - with all features of C++11 random number generation (and more),
+ * some of which are somewhat painful, including
+ * - initializing with a SeedSequence which writes 32-bit values
+ * to memory, even though the state of the generator may not
+ * use 32-bit values (it might use smaller or larger integers)
+ * - I/O for RNGs and a prescribed format, which needs to handle
+ * the issue that 8-bit and 128-bit integers don't have working
+ * I/O routines (e.g., normally 8-bit = char, not integer)
+ * - equality and inequality for RNGs
+ * - and a number of convenience typedefs to mask all the complexity
+ *
+ * The code employes a fairly heavy level of abstraction, and has to deal
+ * with various C++ minutia. If you're looking to learn about how the PCG
+ * scheme works, you're probably best of starting with one of the other
+ * codebases (see www.pcg-random.org). But if you're curious about the
+ * constants for the various output functions used in those other, simpler,
+ * codebases, this code shows how they are calculated.
+ *
+ * On the positive side, at least there are convenience typedefs so that you
+ * can say
+ *
+ * pcg32 myRNG;
+ *
+ * rather than:
+ *
+ * pcg_detail::engine<
+ * uint32_t, // Output Type
+ * uint64_t, // State Type
+ * pcg_detail::xsh_rr_mixin<uint32_t, uint64_t>, true, // Output Func
+ * pcg_detail::specific_stream<uint64_t>, // Stream Kind
+ * pcg_detail::default_multiplier<uint64_t> // LCG Mult
+ * > myRNG;
+ *
+ */
+
+#ifndef PCG_RAND_HPP_INCLUDED
+#define PCG_RAND_HPP_INCLUDED 1
+
+#include <algorithm>
+#include <cinttypes>
+#include <cstddef>
+#include <cstdlib>
+#include <cstring>
+#include <cassert>
+#include <limits>
+#include <iostream>
+#include <iterator>
+#include <type_traits>
+#include <utility>
+#include <locale>
+#include <new>
+#include <stdexcept>
+
+#ifdef _MSC_VER
+ #pragma warning(disable:4146)
+#endif
+
+#ifdef _MSC_VER
+ #define PCG_ALWAYS_INLINE __forceinline
+#elif __GNUC__
+ #define PCG_ALWAYS_INLINE __attribute__((always_inline))
+#else
+ #define PCG_ALWAYS_INLINE inline
+#endif
+
+/*
+ * The pcg_extras namespace contains some support code that is likley to
+ * be useful for a variety of RNGs, including:
+ * - 128-bit int support for platforms where it isn't available natively
+ * - bit twiddling operations
+ * - I/O of 128-bit and 8-bit integers
+ * - Handling the evilness of SeedSeq
+ * - Support for efficiently producing random numbers less than a given
+ * bound
+ */
+
+#include "pcg_extras.hpp"
+
+namespace arrow_vendored {
+namespace pcg_detail {
+
+using namespace pcg_extras;
+
+/*
+ * The LCG generators need some constants to function. This code lets you
+ * look up the constant by *type*. For example
+ *
+ * default_multiplier<uint32_t>::multiplier()
+ *
+ * gives you the default multipler for 32-bit integers. We use the name
+ * of the constant and not a generic word like value to allow these classes
+ * to be used as mixins.
+ */
+
+template <typename T>
+struct default_multiplier {
+ // Not defined for an arbitrary type
+};
+
+template <typename T>
+struct default_increment {
+ // Not defined for an arbitrary type
+};
+
+#define PCG_DEFINE_CONSTANT(type, what, kind, constant) \
+ template <> \
+ struct what ## _ ## kind<type> { \
+ static constexpr type kind() { \
+ return constant; \
+ } \
+ };
+
+PCG_DEFINE_CONSTANT(uint8_t, default, multiplier, 141U)
+PCG_DEFINE_CONSTANT(uint8_t, default, increment, 77U)
+
+PCG_DEFINE_CONSTANT(uint16_t, default, multiplier, 12829U)
+PCG_DEFINE_CONSTANT(uint16_t, default, increment, 47989U)
+
+PCG_DEFINE_CONSTANT(uint32_t, default, multiplier, 747796405U)
+PCG_DEFINE_CONSTANT(uint32_t, default, increment, 2891336453U)
+
+PCG_DEFINE_CONSTANT(uint64_t, default, multiplier, 6364136223846793005ULL)
+PCG_DEFINE_CONSTANT(uint64_t, default, increment, 1442695040888963407ULL)
+
+PCG_DEFINE_CONSTANT(pcg128_t, default, multiplier,
+ PCG_128BIT_CONSTANT(2549297995355413924ULL,4865540595714422341ULL))
+PCG_DEFINE_CONSTANT(pcg128_t, default, increment,
+ PCG_128BIT_CONSTANT(6364136223846793005ULL,1442695040888963407ULL))
+
+/* Alternative (cheaper) multipliers for 128-bit */
+
+template <typename T>
+struct cheap_multiplier : public default_multiplier<T> {
+ // For most types just use the default.
+};
+
+template <>
+struct cheap_multiplier<pcg128_t> {
+ static constexpr uint64_t multiplier() {
+ return 0xda942042e4dd58b5ULL;
+ }
+};
+
+
+/*
+ * Each PCG generator is available in four variants, based on how it applies
+ * the additive constant for its underlying LCG; the variations are:
+ *
+ * single stream - all instances use the same fixed constant, thus
+ * the RNG always somewhere in same sequence
+ * mcg - adds zero, resulting in a single stream and reduced
+ * period
+ * specific stream - the constant can be changed at any time, selecting
+ * a different random sequence
+ * unique stream - the constant is based on the memory address of the
+ * object, thus every RNG has its own unique sequence
+ *
+ * This variation is provided though mixin classes which define a function
+ * value called increment() that returns the nesessary additive constant.
+ */
+
+
+
+/*
+ * unique stream
+ */
+
+
+template <typename itype>
+class unique_stream {
+protected:
+ static constexpr bool is_mcg = false;
+
+ // Is never called, but is provided for symmetry with specific_stream
+ void set_stream(...)
+ {
+ abort();
+ }
+
+public:
+ typedef itype state_type;
+
+ constexpr itype increment() const {
+ return itype(reinterpret_cast<uintptr_t>(this) | 1);
+ }
+
+ constexpr itype stream() const
+ {
+ return increment() >> 1;
+ }
+
+ static constexpr bool can_specify_stream = false;
+
+ static constexpr size_t streams_pow2()
+ {
+ return (sizeof(itype) < sizeof(size_t) ? sizeof(itype)
+ : sizeof(size_t))*8 - 1u;
+ }
+
+protected:
+ constexpr unique_stream() = default;
+};
+
+
+/*
+ * no stream (mcg)
+ */
+
+template <typename itype>
+class no_stream {
+protected:
+ static constexpr bool is_mcg = true;
+
+ // Is never called, but is provided for symmetry with specific_stream
+ void set_stream(...)
+ {
+ abort();
+ }
+
+public:
+ typedef itype state_type;
+
+ static constexpr itype increment() {
+ return 0;
+ }
+
+ static constexpr bool can_specify_stream = false;
+
+ static constexpr size_t streams_pow2()
+ {
+ return 0u;
+ }
+
+protected:
+ constexpr no_stream() = default;
+};
+
+
+/*
+ * single stream/sequence (oneseq)
+ */
+
+template <typename itype>
+class oneseq_stream : public default_increment<itype> {
+protected:
+ static constexpr bool is_mcg = false;
+
+ // Is never called, but is provided for symmetry with specific_stream
+ void set_stream(...)
+ {
+ abort();
+ }
+
+public:
+ typedef itype state_type;
+
+ static constexpr itype stream()
+ {
+ return default_increment<itype>::increment() >> 1;
+ }
+
+ static constexpr bool can_specify_stream = false;
+
+ static constexpr size_t streams_pow2()
+ {
+ return 0u;
+ }
+
+protected:
+ constexpr oneseq_stream() = default;
+};
+
+
+/*
+ * specific stream
+ */
+
+template <typename itype>
+class specific_stream {
+protected:
+ static constexpr bool is_mcg = false;
+
+ itype inc_ = default_increment<itype>::increment();
+
+public:
+ typedef itype state_type;
+ typedef itype stream_state;
+
+ constexpr itype increment() const {
+ return inc_;
+ }
+
+ itype stream()
+ {
+ return inc_ >> 1;
+ }
+
+ void set_stream(itype specific_seq)
+ {
+ inc_ = (specific_seq << 1) | 1;
+ }
+
+ static constexpr bool can_specify_stream = true;
+
+ static constexpr size_t streams_pow2()
+ {
+ return (sizeof(itype)*8) - 1u;
+ }
+
+protected:
+ specific_stream() = default;
+
+ specific_stream(itype specific_seq)
+ : inc_(itype(specific_seq << 1) | itype(1U))
+ {
+ // Nothing (else) to do.
+ }
+};
+
+
+/*
+ * This is where it all comes together. This function joins together three
+ * mixin classes which define
+ * - the LCG additive constant (the stream)
+ * - the LCG multiplier
+ * - the output function
+ * in addition, we specify the type of the LCG state, and the result type,
+ * and whether to use the pre-advance version of the state for the output
+ * (increasing instruction-level parallelism) or the post-advance version
+ * (reducing register pressure).
+ *
+ * Given the high level of parameterization, the code has to use some
+ * template-metaprogramming tricks to handle some of the suble variations
+ * involved.
+ */
+
+template <typename xtype, typename itype,
+ typename output_mixin,
+ bool output_previous = true,
+ typename stream_mixin = oneseq_stream<itype>,
+ typename multiplier_mixin = default_multiplier<itype> >
+class engine : protected output_mixin,
+ public stream_mixin,
+ protected multiplier_mixin {
+protected:
+ itype state_;
+
+ struct can_specify_stream_tag {};
+ struct no_specifiable_stream_tag {};
+
+ using stream_mixin::increment;
+ using multiplier_mixin::multiplier;
+
+public:
+ typedef xtype result_type;
+ typedef itype state_type;
+
+ static constexpr size_t period_pow2()
+ {
+ return sizeof(state_type)*8 - 2*stream_mixin::is_mcg;
+ }
+
+ // It would be nice to use std::numeric_limits for these, but
+ // we can't be sure that it'd be defined for the 128-bit types.
+
+ static constexpr result_type min()
+ {
+ return result_type(0UL);
+ }
+
+ static constexpr result_type max()
+ {
+ return result_type(~result_type(0UL));
+ }
+
+protected:
+ itype bump(itype state)
+ {
+ return state * multiplier() + increment();
+ }
+
+ itype base_generate()
+ {
+ return state_ = bump(state_);
+ }
+
+ itype base_generate0()
+ {
+ itype old_state = state_;
+ state_ = bump(state_);
+ return old_state;
+ }
+
+public:
+ result_type operator()()
+ {
+ if (output_previous)
+ return this->output(base_generate0());
+ else
+ return this->output(base_generate());
+ }
+
+ result_type operator()(result_type upper_bound)
+ {
+ return bounded_rand(*this, upper_bound);
+ }
+
+protected:
+ static itype advance(itype state, itype delta,
+ itype cur_mult, itype cur_plus);
+
+ static itype distance(itype cur_state, itype newstate, itype cur_mult,
+ itype cur_plus, itype mask = ~itype(0U));
+
+ itype distance(itype newstate, itype mask = itype(~itype(0U))) const
+ {
+ return distance(state_, newstate, multiplier(), increment(), mask);
+ }
+
+public:
+ void advance(itype delta)
+ {
+ state_ = advance(state_, delta, this->multiplier(), this->increment());
+ }
+
+ void backstep(itype delta)
+ {
+ advance(-delta);
+ }
+
+ void discard(itype delta)
+ {
+ advance(delta);
+ }
+
+ bool wrapped()
+ {
+ if (stream_mixin::is_mcg) {
+ // For MCGs, the low order two bits never change. In this
+ // implementation, we keep them fixed at 3 to make this test
+ // easier.
+ return state_ == 3;
+ } else {
+ return state_ == 0;
+ }
+ }
+
+ engine(itype state = itype(0xcafef00dd15ea5e5ULL))
+ : state_(this->is_mcg ? state|state_type(3U)
+ : bump(state + this->increment()))
+ {
+ // Nothing else to do.
+ }
+
+ // This function may or may not exist. It thus has to be a template
+ // to use SFINAE; users don't have to worry about its template-ness.
+
+ template <typename sm = stream_mixin>
+ engine(itype state, typename sm::stream_state stream_seed)
+ : stream_mixin(stream_seed),
+ state_(this->is_mcg ? state|state_type(3U)
+ : bump(state + this->increment()))
+ {
+ // Nothing else to do.
+ }
+
+ template<typename SeedSeq>
+ engine(SeedSeq&& seedSeq, typename std::enable_if<
+ !stream_mixin::can_specify_stream
+ && !std::is_convertible<SeedSeq, itype>::value
+ && !std::is_convertible<SeedSeq, engine>::value,
+ no_specifiable_stream_tag>::type = {})
+ : engine(generate_one<itype>(std::forward<SeedSeq>(seedSeq)))
+ {
+ // Nothing else to do.
+ }
+
+ template<typename SeedSeq>
+ engine(SeedSeq&& seedSeq, typename std::enable_if<
+ stream_mixin::can_specify_stream
+ && !std::is_convertible<SeedSeq, itype>::value
+ && !std::is_convertible<SeedSeq, engine>::value,
+ can_specify_stream_tag>::type = {})
+ {
+ itype seeddata[2];
+ generate_to<2>(std::forward<SeedSeq>(seedSeq), seeddata);
+ seed(seeddata[1], seeddata[0]);
+ }
+
+
+ template<typename... Args>
+ void seed(Args&&... args)
+ {
+ new (this) engine(std::forward<Args>(args)...);
+ }
+
+ template <typename xtype1, typename itype1,
+ typename output_mixin1, bool output_previous1,
+ typename stream_mixin_lhs, typename multiplier_mixin_lhs,
+ typename stream_mixin_rhs, typename multiplier_mixin_rhs>
+ friend bool operator==(const engine<xtype1,itype1,
+ output_mixin1,output_previous1,
+ stream_mixin_lhs, multiplier_mixin_lhs>&,
+ const engine<xtype1,itype1,
+ output_mixin1,output_previous1,
+ stream_mixin_rhs, multiplier_mixin_rhs>&);
+
+ template <typename xtype1, typename itype1,
+ typename output_mixin1, bool output_previous1,
+ typename stream_mixin_lhs, typename multiplier_mixin_lhs,
+ typename stream_mixin_rhs, typename multiplier_mixin_rhs>
+ friend itype1 operator-(const engine<xtype1,itype1,
+ output_mixin1,output_previous1,
+ stream_mixin_lhs, multiplier_mixin_lhs>&,
+ const engine<xtype1,itype1,
+ output_mixin1,output_previous1,
+ stream_mixin_rhs, multiplier_mixin_rhs>&);
+
+ template <typename CharT, typename Traits,
+ typename xtype1, typename itype1,
+ typename output_mixin1, bool output_previous1,
+ typename stream_mixin1, typename multiplier_mixin1>
+ friend std::basic_ostream<CharT,Traits>&
+ operator<<(std::basic_ostream<CharT,Traits>& out,
+ const engine<xtype1,itype1,
+ output_mixin1,output_previous1,
+ stream_mixin1, multiplier_mixin1>&);
+
+ template <typename CharT, typename Traits,
+ typename xtype1, typename itype1,
+ typename output_mixin1, bool output_previous1,
+ typename stream_mixin1, typename multiplier_mixin1>
+ friend std::basic_istream<CharT,Traits>&
+ operator>>(std::basic_istream<CharT,Traits>& in,
+ engine<xtype1, itype1,
+ output_mixin1, output_previous1,
+ stream_mixin1, multiplier_mixin1>& rng);
+};
+
+template <typename CharT, typename Traits,
+ typename xtype, typename itype,
+ typename output_mixin, bool output_previous,
+ typename stream_mixin, typename multiplier_mixin>
+std::basic_ostream<CharT,Traits>&
+operator<<(std::basic_ostream<CharT,Traits>& out,
+ const engine<xtype,itype,
+ output_mixin,output_previous,
+ stream_mixin, multiplier_mixin>& rng)
+{
+ using pcg_extras::operator<<;
+
+ auto orig_flags = out.flags(std::ios_base::dec | std::ios_base::left);
+ auto space = out.widen(' ');
+ auto orig_fill = out.fill();
+
+ out << rng.multiplier() << space
+ << rng.increment() << space
+ << rng.state_;
+
+ out.flags(orig_flags);
+ out.fill(orig_fill);
+ return out;
+}
+
+
+template <typename CharT, typename Traits,
+ typename xtype, typename itype,
+ typename output_mixin, bool output_previous,
+ typename stream_mixin, typename multiplier_mixin>
+std::basic_istream<CharT,Traits>&
+operator>>(std::basic_istream<CharT,Traits>& in,
+ engine<xtype,itype,
+ output_mixin,output_previous,
+ stream_mixin, multiplier_mixin>& rng)
+{
+ using pcg_extras::operator>>;
+
+ auto orig_flags = in.flags(std::ios_base::dec | std::ios_base::skipws);
+
+ itype multiplier, increment, state;
+ in >> multiplier >> increment >> state;
+
+ if (!in.fail()) {
+ bool good = true;
+ if (multiplier != rng.multiplier()) {
+ good = false;
+ } else if (rng.can_specify_stream) {
+ rng.set_stream(increment >> 1);
+ } else if (increment != rng.increment()) {
+ good = false;
+ }
+ if (good) {
+ rng.state_ = state;
+ } else {
+ in.clear(std::ios::failbit);
+ }
+ }
+
+ in.flags(orig_flags);
+ return in;
+}
+
+
+template <typename xtype, typename itype,
+ typename output_mixin, bool output_previous,
+ typename stream_mixin, typename multiplier_mixin>
+itype engine<xtype,itype,output_mixin,output_previous,stream_mixin,
+ multiplier_mixin>::advance(
+ itype state, itype delta, itype cur_mult, itype cur_plus)
+{
+ // The method used here is based on Brown, "Random Number Generation
+ // with Arbitrary Stride,", Transactions of the American Nuclear
+ // Society (Nov. 1994). The algorithm is very similar to fast
+ // exponentiation.
+ //
+ // Even though delta is an unsigned integer, we can pass a
+ // signed integer to go backwards, it just goes "the long way round".
+
+ constexpr itype ZERO = 0u; // itype may be a non-trivial types, so
+ constexpr itype ONE = 1u; // we define some ugly constants.
+ itype acc_mult = 1;
+ itype acc_plus = 0;
+ while (delta > ZERO) {
+ if (delta & ONE) {
+ acc_mult *= cur_mult;
+ acc_plus = acc_plus*cur_mult + cur_plus;
+ }
+ cur_plus = (cur_mult+ONE)*cur_plus;
+ cur_mult *= cur_mult;
+ delta >>= 1;
+ }
+ return acc_mult * state + acc_plus;
+}
+
+template <typename xtype, typename itype,
+ typename output_mixin, bool output_previous,
+ typename stream_mixin, typename multiplier_mixin>
+itype engine<xtype,itype,output_mixin,output_previous,stream_mixin,
+ multiplier_mixin>::distance(
+ itype cur_state, itype newstate, itype cur_mult, itype cur_plus, itype mask)
+{
+ constexpr itype ONE = 1u; // itype could be weird, so use constant
+ bool is_mcg = cur_plus == itype(0);
+ itype the_bit = is_mcg ? itype(4u) : itype(1u);
+ itype distance = 0u;
+ while ((cur_state & mask) != (newstate & mask)) {
+ if ((cur_state & the_bit) != (newstate & the_bit)) {
+ cur_state = cur_state * cur_mult + cur_plus;
+ distance |= the_bit;
+ }
+ assert((cur_state & the_bit) == (newstate & the_bit));
+ the_bit <<= 1;
+ cur_plus = (cur_mult+ONE)*cur_plus;
+ cur_mult *= cur_mult;
+ }
+ return is_mcg ? distance >> 2 : distance;
+}
+
+template <typename xtype, typename itype,
+ typename output_mixin, bool output_previous,
+ typename stream_mixin_lhs, typename multiplier_mixin_lhs,
+ typename stream_mixin_rhs, typename multiplier_mixin_rhs>
+itype operator-(const engine<xtype,itype,
+ output_mixin,output_previous,
+ stream_mixin_lhs, multiplier_mixin_lhs>& lhs,
+ const engine<xtype,itype,
+ output_mixin,output_previous,
+ stream_mixin_rhs, multiplier_mixin_rhs>& rhs)
+{
+ static_assert(
+ std::is_same<stream_mixin_lhs, stream_mixin_rhs>::value &&
+ std::is_same<multiplier_mixin_lhs, multiplier_mixin_rhs>::value,
+ "Incomparable generators");
+ if (lhs.increment() == rhs.increment()) {
+ return rhs.distance(lhs.state_);
+ } else {
+ constexpr itype ONE = 1u;
+ itype lhs_diff = lhs.increment() + (lhs.multiplier()-ONE) * lhs.state_;
+ itype rhs_diff = rhs.increment() + (rhs.multiplier()-ONE) * rhs.state_;
+ if ((lhs_diff & itype(3u)) != (rhs_diff & itype(3u))) {
+ rhs_diff = -rhs_diff;
+ }
+ return rhs.distance(rhs_diff, lhs_diff, rhs.multiplier(), itype(0u));
+ }
+}
+
+
+template <typename xtype, typename itype,
+ typename output_mixin, bool output_previous,
+ typename stream_mixin_lhs, typename multiplier_mixin_lhs,
+ typename stream_mixin_rhs, typename multiplier_mixin_rhs>
+bool operator==(const engine<xtype,itype,
+ output_mixin,output_previous,
+ stream_mixin_lhs, multiplier_mixin_lhs>& lhs,
+ const engine<xtype,itype,
+ output_mixin,output_previous,
+ stream_mixin_rhs, multiplier_mixin_rhs>& rhs)
+{
+ return (lhs.multiplier() == rhs.multiplier())
+ && (lhs.increment() == rhs.increment())
+ && (lhs.state_ == rhs.state_);
+}
+
+template <typename xtype, typename itype,
+ typename output_mixin, bool output_previous,
+ typename stream_mixin_lhs, typename multiplier_mixin_lhs,
+ typename stream_mixin_rhs, typename multiplier_mixin_rhs>
+inline bool operator!=(const engine<xtype,itype,
+ output_mixin,output_previous,
+ stream_mixin_lhs, multiplier_mixin_lhs>& lhs,
+ const engine<xtype,itype,
+ output_mixin,output_previous,
+ stream_mixin_rhs, multiplier_mixin_rhs>& rhs)
+{
+ return !operator==(lhs,rhs);
+}
+
+
+template <typename xtype, typename itype,
+ template<typename XT,typename IT> class output_mixin,
+ bool output_previous = (sizeof(itype) <= 8),
+ template<typename IT> class multiplier_mixin = default_multiplier>
+using oneseq_base = engine<xtype, itype,
+ output_mixin<xtype, itype>, output_previous,
+ oneseq_stream<itype>,
+ multiplier_mixin<itype> >;
+
+template <typename xtype, typename itype,
+ template<typename XT,typename IT> class output_mixin,
+ bool output_previous = (sizeof(itype) <= 8),
+ template<typename IT> class multiplier_mixin = default_multiplier>
+using unique_base = engine<xtype, itype,
+ output_mixin<xtype, itype>, output_previous,
+ unique_stream<itype>,
+ multiplier_mixin<itype> >;
+
+template <typename xtype, typename itype,
+ template<typename XT,typename IT> class output_mixin,
+ bool output_previous = (sizeof(itype) <= 8),
+ template<typename IT> class multiplier_mixin = default_multiplier>
+using setseq_base = engine<xtype, itype,
+ output_mixin<xtype, itype>, output_previous,
+ specific_stream<itype>,
+ multiplier_mixin<itype> >;
+
+template <typename xtype, typename itype,
+ template<typename XT,typename IT> class output_mixin,
+ bool output_previous = (sizeof(itype) <= 8),
+ template<typename IT> class multiplier_mixin = default_multiplier>
+using mcg_base = engine<xtype, itype,
+ output_mixin<xtype, itype>, output_previous,
+ no_stream<itype>,
+ multiplier_mixin<itype> >;
+
+/*
+ * OUTPUT FUNCTIONS.
+ *
+ * These are the core of the PCG generation scheme. They specify how to
+ * turn the base LCG's internal state into the output value of the final
+ * generator.
+ *
+ * They're implemented as mixin classes.
+ *
+ * All of the classes have code that is written to allow it to be applied
+ * at *arbitrary* bit sizes, although in practice they'll only be used at
+ * standard sizes supported by C++.
+ */
+
+/*
+ * XSH RS -- high xorshift, followed by a random shift
+ *
+ * Fast. A good performer.
+ */
+
+template <typename xtype, typename itype>
+struct xsh_rs_mixin {
+ static xtype output(itype internal)
+ {
+ constexpr bitcount_t bits = bitcount_t(sizeof(itype) * 8);
+ constexpr bitcount_t xtypebits = bitcount_t(sizeof(xtype) * 8);
+ constexpr bitcount_t sparebits = bits - xtypebits;
+ constexpr bitcount_t opbits =
+ sparebits-5 >= 64 ? 5
+ : sparebits-4 >= 32 ? 4
+ : sparebits-3 >= 16 ? 3
+ : sparebits-2 >= 4 ? 2
+ : sparebits-1 >= 1 ? 1
+ : 0;
+ constexpr bitcount_t mask = (1 << opbits) - 1;
+ constexpr bitcount_t maxrandshift = mask;
+ constexpr bitcount_t topspare = opbits;
+ constexpr bitcount_t bottomspare = sparebits - topspare;
+ constexpr bitcount_t xshift = topspare + (xtypebits+maxrandshift)/2;
+ bitcount_t rshift =
+ opbits ? bitcount_t(internal >> (bits - opbits)) & mask : 0;
+ internal ^= internal >> xshift;
+ xtype result = xtype(internal >> (bottomspare - maxrandshift + rshift));
+ return result;
+ }
+};
+
+/*
+ * XSH RR -- high xorshift, followed by a random rotate
+ *
+ * Fast. A good performer. Slightly better statistically than XSH RS.
+ */
+
+template <typename xtype, typename itype>
+struct xsh_rr_mixin {
+ static xtype output(itype internal)
+ {
+ constexpr bitcount_t bits = bitcount_t(sizeof(itype) * 8);
+ constexpr bitcount_t xtypebits = bitcount_t(sizeof(xtype)*8);
+ constexpr bitcount_t sparebits = bits - xtypebits;
+ constexpr bitcount_t wantedopbits =
+ xtypebits >= 128 ? 7
+ : xtypebits >= 64 ? 6
+ : xtypebits >= 32 ? 5
+ : xtypebits >= 16 ? 4
+ : 3;
+ constexpr bitcount_t opbits =
+ sparebits >= wantedopbits ? wantedopbits
+ : sparebits;
+ constexpr bitcount_t amplifier = wantedopbits - opbits;
+ constexpr bitcount_t mask = (1 << opbits) - 1;
+ constexpr bitcount_t topspare = opbits;
+ constexpr bitcount_t bottomspare = sparebits - topspare;
+ constexpr bitcount_t xshift = (topspare + xtypebits)/2;
+ bitcount_t rot = opbits ? bitcount_t(internal >> (bits - opbits)) & mask
+ : 0;
+ bitcount_t amprot = (rot << amplifier) & mask;
+ internal ^= internal >> xshift;
+ xtype result = xtype(internal >> bottomspare);
+ result = rotr(result, amprot);
+ return result;
+ }
+};
+
+/*
+ * RXS -- random xorshift
+ */
+
+template <typename xtype, typename itype>
+struct rxs_mixin {
+static xtype output_rxs(itype internal)
+ {
+ constexpr bitcount_t bits = bitcount_t(sizeof(itype) * 8);
+ constexpr bitcount_t xtypebits = bitcount_t(sizeof(xtype)*8);
+ constexpr bitcount_t shift = bits - xtypebits;
+ constexpr bitcount_t extrashift = (xtypebits - shift)/2;
+ bitcount_t rshift = shift > 64+8 ? (internal >> (bits - 6)) & 63
+ : shift > 32+4 ? (internal >> (bits - 5)) & 31
+ : shift > 16+2 ? (internal >> (bits - 4)) & 15
+ : shift > 8+1 ? (internal >> (bits - 3)) & 7
+ : shift > 4+1 ? (internal >> (bits - 2)) & 3
+ : shift > 2+1 ? (internal >> (bits - 1)) & 1
+ : 0;
+ internal ^= internal >> (shift + extrashift - rshift);
+ xtype result = internal >> rshift;
+ return result;
+ }
+};
+
+/*
+ * RXS M XS -- random xorshift, mcg multiply, fixed xorshift
+ *
+ * The most statistically powerful generator, but all those steps
+ * make it slower than some of the others. We give it the rottenest jobs.
+ *
+ * Because it's usually used in contexts where the state type and the
+ * result type are the same, it is a permutation and is thus invertable.
+ * We thus provide a function to invert it. This function is used to
+ * for the "inside out" generator used by the extended generator.
+ */
+
+/* Defined type-based concepts for the multiplication step. They're actually
+ * all derived by truncating the 128-bit, which was computed to be a good
+ * "universal" constant.
+ */
+
+template <typename T>
+struct mcg_multiplier {
+ // Not defined for an arbitrary type
+};
+
+template <typename T>
+struct mcg_unmultiplier {
+ // Not defined for an arbitrary type
+};
+
+PCG_DEFINE_CONSTANT(uint8_t, mcg, multiplier, 217U)
+PCG_DEFINE_CONSTANT(uint8_t, mcg, unmultiplier, 105U)
+
+PCG_DEFINE_CONSTANT(uint16_t, mcg, multiplier, 62169U)
+PCG_DEFINE_CONSTANT(uint16_t, mcg, unmultiplier, 28009U)
+
+PCG_DEFINE_CONSTANT(uint32_t, mcg, multiplier, 277803737U)
+PCG_DEFINE_CONSTANT(uint32_t, mcg, unmultiplier, 2897767785U)
+
+PCG_DEFINE_CONSTANT(uint64_t, mcg, multiplier, 12605985483714917081ULL)
+PCG_DEFINE_CONSTANT(uint64_t, mcg, unmultiplier, 15009553638781119849ULL)
+
+PCG_DEFINE_CONSTANT(pcg128_t, mcg, multiplier,
+ PCG_128BIT_CONSTANT(17766728186571221404ULL, 12605985483714917081ULL))
+PCG_DEFINE_CONSTANT(pcg128_t, mcg, unmultiplier,
+ PCG_128BIT_CONSTANT(14422606686972528997ULL, 15009553638781119849ULL))
+
+
+template <typename xtype, typename itype>
+struct rxs_m_xs_mixin {
+ static xtype output(itype internal)
+ {
+ constexpr bitcount_t xtypebits = bitcount_t(sizeof(xtype) * 8);
+ constexpr bitcount_t bits = bitcount_t(sizeof(itype) * 8);
+ constexpr bitcount_t opbits = xtypebits >= 128 ? 6
+ : xtypebits >= 64 ? 5
+ : xtypebits >= 32 ? 4
+ : xtypebits >= 16 ? 3
+ : 2;
+ constexpr bitcount_t shift = bits - xtypebits;
+ constexpr bitcount_t mask = (1 << opbits) - 1;
+ bitcount_t rshift =
+ opbits ? bitcount_t(internal >> (bits - opbits)) & mask : 0;
+ internal ^= internal >> (opbits + rshift);
+ internal *= mcg_multiplier<itype>::multiplier();
+ xtype result = internal >> shift;
+ result ^= result >> ((2U*xtypebits+2U)/3U);
+ return result;
+ }
+
+ static itype unoutput(itype internal)
+ {
+ constexpr bitcount_t bits = bitcount_t(sizeof(itype) * 8);
+ constexpr bitcount_t opbits = bits >= 128 ? 6
+ : bits >= 64 ? 5
+ : bits >= 32 ? 4
+ : bits >= 16 ? 3
+ : 2;
+ constexpr bitcount_t mask = (1 << opbits) - 1;
+
+ internal = unxorshift(internal, bits, (2U*bits+2U)/3U);
+
+ internal *= mcg_unmultiplier<itype>::unmultiplier();
+
+ bitcount_t rshift = opbits ? (internal >> (bits - opbits)) & mask : 0;
+ internal = unxorshift(internal, bits, opbits + rshift);
+
+ return internal;
+ }
+};
+
+
+/*
+ * RXS M -- random xorshift, mcg multiply
+ */
+
+template <typename xtype, typename itype>
+struct rxs_m_mixin {
+ static xtype output(itype internal)
+ {
+ constexpr bitcount_t xtypebits = bitcount_t(sizeof(xtype) * 8);
+ constexpr bitcount_t bits = bitcount_t(sizeof(itype) * 8);
+ constexpr bitcount_t opbits = xtypebits >= 128 ? 6
+ : xtypebits >= 64 ? 5
+ : xtypebits >= 32 ? 4
+ : xtypebits >= 16 ? 3
+ : 2;
+ constexpr bitcount_t shift = bits - xtypebits;
+ constexpr bitcount_t mask = (1 << opbits) - 1;
+ bitcount_t rshift = opbits ? (internal >> (bits - opbits)) & mask : 0;
+ internal ^= internal >> (opbits + rshift);
+ internal *= mcg_multiplier<itype>::multiplier();
+ xtype result = internal >> shift;
+ return result;
+ }
+};
+
+
+/*
+ * DXSM -- double xorshift multiply
+ *
+ * This is a new, more powerful output permutation (added in 2019). It's
+ * a more comprehensive scrambling than RXS M, but runs faster on 128-bit
+ * types. Although primarily intended for use at large sizes, also works
+ * at smaller sizes as well.
+ *
+ * This permutation is similar to xorshift multiply hash functions, except
+ * that one of the multipliers is the LCG multiplier (to avoid needing to
+ * have a second constant) and the other is based on the low-order bits.
+ * This latter aspect means that the scrambling applied to the high bits
+ * depends on the low bits, and makes it (to my eye) impractical to back
+ * out the permutation without having the low-order bits.
+ */
+
+template <typename xtype, typename itype>
+struct dxsm_mixin {
+ inline xtype output(itype internal)
+ {
+ constexpr bitcount_t xtypebits = bitcount_t(sizeof(xtype) * 8);
+ constexpr bitcount_t itypebits = bitcount_t(sizeof(itype) * 8);
+ static_assert(xtypebits <= itypebits/2,
+ "Output type must be half the size of the state type.");
+
+ xtype hi = xtype(internal >> (itypebits - xtypebits));
+ xtype lo = xtype(internal);
+
+ lo |= 1;
+ hi ^= hi >> (xtypebits/2);
+ hi *= xtype(cheap_multiplier<itype>::multiplier());
+ hi ^= hi >> (3*(xtypebits/4));
+ hi *= lo;
+ return hi;
+ }
+};
+
+
+/*
+ * XSL RR -- fixed xorshift (to low bits), random rotate
+ *
+ * Useful for 128-bit types that are split across two CPU registers.
+ */
+
+template <typename xtype, typename itype>
+struct xsl_rr_mixin {
+ static xtype output(itype internal)
+ {
+ constexpr bitcount_t xtypebits = bitcount_t(sizeof(xtype) * 8);
+ constexpr bitcount_t bits = bitcount_t(sizeof(itype) * 8);
+ constexpr bitcount_t sparebits = bits - xtypebits;
+ constexpr bitcount_t wantedopbits = xtypebits >= 128 ? 7
+ : xtypebits >= 64 ? 6
+ : xtypebits >= 32 ? 5
+ : xtypebits >= 16 ? 4
+ : 3;
+ constexpr bitcount_t opbits = sparebits >= wantedopbits ? wantedopbits
+ : sparebits;
+ constexpr bitcount_t amplifier = wantedopbits - opbits;
+ constexpr bitcount_t mask = (1 << opbits) - 1;
+ constexpr bitcount_t topspare = sparebits;
+ constexpr bitcount_t bottomspare = sparebits - topspare;
+ constexpr bitcount_t xshift = (topspare + xtypebits) / 2;
+
+ bitcount_t rot =
+ opbits ? bitcount_t(internal >> (bits - opbits)) & mask : 0;
+ bitcount_t amprot = (rot << amplifier) & mask;
+ internal ^= internal >> xshift;
+ xtype result = xtype(internal >> bottomspare);
+ result = rotr(result, amprot);
+ return result;
+ }
+};
+
+
+/*
+ * XSL RR RR -- fixed xorshift (to low bits), random rotate (both parts)
+ *
+ * Useful for 128-bit types that are split across two CPU registers.
+ * If you really want an invertable 128-bit RNG, I guess this is the one.
+ */
+
+template <typename T> struct halfsize_trait {};
+template <> struct halfsize_trait<pcg128_t> { typedef uint64_t type; };
+template <> struct halfsize_trait<uint64_t> { typedef uint32_t type; };
+template <> struct halfsize_trait<uint32_t> { typedef uint16_t type; };
+template <> struct halfsize_trait<uint16_t> { typedef uint8_t type; };
+
+template <typename xtype, typename itype>
+struct xsl_rr_rr_mixin {
+ typedef typename halfsize_trait<itype>::type htype;
+
+ static itype output(itype internal)
+ {
+ constexpr bitcount_t htypebits = bitcount_t(sizeof(htype) * 8);
+ constexpr bitcount_t bits = bitcount_t(sizeof(itype) * 8);
+ constexpr bitcount_t sparebits = bits - htypebits;
+ constexpr bitcount_t wantedopbits = htypebits >= 128 ? 7
+ : htypebits >= 64 ? 6
+ : htypebits >= 32 ? 5
+ : htypebits >= 16 ? 4
+ : 3;
+ constexpr bitcount_t opbits = sparebits >= wantedopbits ? wantedopbits
+ : sparebits;
+ constexpr bitcount_t amplifier = wantedopbits - opbits;
+ constexpr bitcount_t mask = (1 << opbits) - 1;
+ constexpr bitcount_t topspare = sparebits;
+ constexpr bitcount_t xshift = (topspare + htypebits) / 2;
+
+ bitcount_t rot =
+ opbits ? bitcount_t(internal >> (bits - opbits)) & mask : 0;
+ bitcount_t amprot = (rot << amplifier) & mask;
+ internal ^= internal >> xshift;
+ htype lowbits = htype(internal);
+ lowbits = rotr(lowbits, amprot);
+ htype highbits = htype(internal >> topspare);
+ bitcount_t rot2 = lowbits & mask;
+ bitcount_t amprot2 = (rot2 << amplifier) & mask;
+ highbits = rotr(highbits, amprot2);
+ return (itype(highbits) << topspare) ^ itype(lowbits);
+ }
+};
+
+
+/*
+ * XSH -- fixed xorshift (to high bits)
+ *
+ * You shouldn't use this at 64-bits or less.
+ */
+
+template <typename xtype, typename itype>
+struct xsh_mixin {
+ static xtype output(itype internal)
+ {
+ constexpr bitcount_t xtypebits = bitcount_t(sizeof(xtype) * 8);
+ constexpr bitcount_t bits = bitcount_t(sizeof(itype) * 8);
+ constexpr bitcount_t sparebits = bits - xtypebits;
+ constexpr bitcount_t topspare = 0;
+ constexpr bitcount_t bottomspare = sparebits - topspare;
+ constexpr bitcount_t xshift = (topspare + xtypebits) / 2;
+
+ internal ^= internal >> xshift;
+ xtype result = internal >> bottomspare;
+ return result;
+ }
+};
+
+/*
+ * XSL -- fixed xorshift (to low bits)
+ *
+ * You shouldn't use this at 64-bits or less.
+ */
+
+template <typename xtype, typename itype>
+struct xsl_mixin {
+ inline xtype output(itype internal)
+ {
+ constexpr bitcount_t xtypebits = bitcount_t(sizeof(xtype) * 8);
+ constexpr bitcount_t bits = bitcount_t(sizeof(itype) * 8);
+ constexpr bitcount_t sparebits = bits - xtypebits;
+ constexpr bitcount_t topspare = sparebits;
+ constexpr bitcount_t bottomspare = sparebits - topspare;
+ constexpr bitcount_t xshift = (topspare + xtypebits) / 2;
+
+ internal ^= internal >> xshift;
+ xtype result = internal >> bottomspare;
+ return result;
+ }
+};
+
+
+/* ---- End of Output Functions ---- */
+
+
+template <typename baseclass>
+struct inside_out : private baseclass {
+ inside_out() = delete;
+
+ typedef typename baseclass::result_type result_type;
+ typedef typename baseclass::state_type state_type;
+ static_assert(sizeof(result_type) == sizeof(state_type),
+ "Require a RNG whose output function is a permutation");
+
+ static bool external_step(result_type& randval, size_t i)
+ {
+ state_type state = baseclass::unoutput(randval);
+ state = state * baseclass::multiplier() + baseclass::increment()
+ + state_type(i*2);
+ result_type result = baseclass::output(state);
+ randval = result;
+ state_type zero =
+ baseclass::is_mcg ? state & state_type(3U) : state_type(0U);
+ return result == zero;
+ }
+
+ static bool external_advance(result_type& randval, size_t i,
+ result_type delta, bool forwards = true)
+ {
+ state_type state = baseclass::unoutput(randval);
+ state_type mult = baseclass::multiplier();
+ state_type inc = baseclass::increment() + state_type(i*2);
+ state_type zero =
+ baseclass::is_mcg ? state & state_type(3U) : state_type(0U);
+ state_type dist_to_zero = baseclass::distance(state, zero, mult, inc);
+ bool crosses_zero =
+ forwards ? dist_to_zero <= delta
+ : (-dist_to_zero) <= delta;
+ if (!forwards)
+ delta = -delta;
+ state = baseclass::advance(state, delta, mult, inc);
+ randval = baseclass::output(state);
+ return crosses_zero;
+ }
+};
+
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2, typename baseclass, typename extvalclass, bool kdd = true>
+class extended : public baseclass {
+public:
+ typedef typename baseclass::state_type state_type;
+ typedef typename baseclass::result_type result_type;
+ typedef inside_out<extvalclass> insideout;
+
+private:
+ static constexpr bitcount_t rtypebits = sizeof(result_type)*8;
+ static constexpr bitcount_t stypebits = sizeof(state_type)*8;
+
+ static constexpr bitcount_t tick_limit_pow2 = 64U;
+
+ static constexpr size_t table_size = 1UL << table_pow2;
+ static constexpr size_t table_shift = stypebits - table_pow2;
+ static constexpr state_type table_mask =
+ (state_type(1U) << table_pow2) - state_type(1U);
+
+ static constexpr bool may_tick =
+ (advance_pow2 < stypebits) && (advance_pow2 < tick_limit_pow2);
+ static constexpr size_t tick_shift = stypebits - advance_pow2;
+ static constexpr state_type tick_mask =
+ may_tick ? state_type(
+ (uint64_t(1) << (advance_pow2*may_tick)) - 1)
+ // ^-- stupidity to appease GCC warnings
+ : ~state_type(0U);
+
+ static constexpr bool may_tock = stypebits < tick_limit_pow2;
+
+ result_type data_[table_size];
+
+ PCG_NOINLINE void advance_table();
+
+ PCG_NOINLINE void advance_table(state_type delta, bool isForwards = true);
+
+ result_type& get_extended_value()
+ {
+ state_type state = this->state_;
+ if (kdd && baseclass::is_mcg) {
+ // The low order bits of an MCG are constant, so drop them.
+ state >>= 2;
+ }
+ size_t index = kdd ? state & table_mask
+ : state >> table_shift;
+
+ if (may_tick) {
+ bool tick = kdd ? (state & tick_mask) == state_type(0u)
+ : (state >> tick_shift) == state_type(0u);
+ if (tick)
+ advance_table();
+ }
+ if (may_tock) {
+ bool tock = state == state_type(0u);
+ if (tock)
+ advance_table();
+ }
+ return data_[index];
+ }
+
+public:
+ static constexpr size_t period_pow2()
+ {
+ return baseclass::period_pow2() + table_size*extvalclass::period_pow2();
+ }
+
+ PCG_ALWAYS_INLINE result_type operator()()
+ {
+ result_type rhs = get_extended_value();
+ result_type lhs = this->baseclass::operator()();
+ return lhs ^ rhs;
+ }
+
+ result_type operator()(result_type upper_bound)
+ {
+ return bounded_rand(*this, upper_bound);
+ }
+
+ void set(result_type wanted)
+ {
+ result_type& rhs = get_extended_value();
+ result_type lhs = this->baseclass::operator()();
+ rhs = lhs ^ wanted;
+ }
+
+ void advance(state_type distance, bool forwards = true);
+
+ void backstep(state_type distance)
+ {
+ advance(distance, false);
+ }
+
+ extended(const result_type* data)
+ : baseclass()
+ {
+ datainit(data);
+ }
+
+ extended(const result_type* data, state_type seed)
+ : baseclass(seed)
+ {
+ datainit(data);
+ }
+
+ // This function may or may not exist. It thus has to be a template
+ // to use SFINAE; users don't have to worry about its template-ness.
+
+ template <typename bc = baseclass>
+ extended(const result_type* data, state_type seed,
+ typename bc::stream_state stream_seed)
+ : baseclass(seed, stream_seed)
+ {
+ datainit(data);
+ }
+
+ extended()
+ : baseclass()
+ {
+ selfinit();
+ }
+
+ extended(state_type seed)
+ : baseclass(seed)
+ {
+ selfinit();
+ }
+
+ // This function may or may not exist. It thus has to be a template
+ // to use SFINAE; users don't have to worry about its template-ness.
+
+ template <typename bc = baseclass>
+ extended(state_type seed, typename bc::stream_state stream_seed)
+ : baseclass(seed, stream_seed)
+ {
+ selfinit();
+ }
+
+private:
+ void selfinit();
+ void datainit(const result_type* data);
+
+public:
+
+ template<typename SeedSeq, typename = typename std::enable_if<
+ !std::is_convertible<SeedSeq, result_type>::value
+ && !std::is_convertible<SeedSeq, extended>::value>::type>
+ extended(SeedSeq&& seedSeq)
+ : baseclass(seedSeq)
+ {
+ generate_to<table_size>(seedSeq, data_);
+ }
+
+ template<typename... Args>
+ void seed(Args&&... args)
+ {
+ new (this) extended(std::forward<Args>(args)...);
+ }
+
+ template <bitcount_t table_pow2_, bitcount_t advance_pow2_,
+ typename baseclass_, typename extvalclass_, bool kdd_>
+ friend bool operator==(const extended<table_pow2_, advance_pow2_,
+ baseclass_, extvalclass_, kdd_>&,
+ const extended<table_pow2_, advance_pow2_,
+ baseclass_, extvalclass_, kdd_>&);
+
+ template <typename CharT, typename Traits,
+ bitcount_t table_pow2_, bitcount_t advance_pow2_,
+ typename baseclass_, typename extvalclass_, bool kdd_>
+ friend std::basic_ostream<CharT,Traits>&
+ operator<<(std::basic_ostream<CharT,Traits>& out,
+ const extended<table_pow2_, advance_pow2_,
+ baseclass_, extvalclass_, kdd_>&);
+
+ template <typename CharT, typename Traits,
+ bitcount_t table_pow2_, bitcount_t advance_pow2_,
+ typename baseclass_, typename extvalclass_, bool kdd_>
+ friend std::basic_istream<CharT,Traits>&
+ operator>>(std::basic_istream<CharT,Traits>& in,
+ extended<table_pow2_, advance_pow2_,
+ baseclass_, extvalclass_, kdd_>&);
+
+};
+
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2,
+ typename baseclass, typename extvalclass, bool kdd>
+void extended<table_pow2,advance_pow2,baseclass,extvalclass,kdd>::datainit(
+ const result_type* data)
+{
+ for (size_t i = 0; i < table_size; ++i)
+ data_[i] = data[i];
+}
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2,
+ typename baseclass, typename extvalclass, bool kdd>
+void extended<table_pow2,advance_pow2,baseclass,extvalclass,kdd>::selfinit()
+{
+ // We need to fill the extended table with something, and we have
+ // very little provided data, so we use the base generator to
+ // produce values. Although not ideal (use a seed sequence, folks!),
+ // unexpected correlations are mitigated by
+ // - using XOR differences rather than the number directly
+ // - the way the table is accessed, its values *won't* be accessed
+ // in the same order the were written.
+ // - any strange correlations would only be apparent if we
+ // were to backstep the generator so that the base generator
+ // was generating the same values again
+ result_type lhs = baseclass::operator()();
+ result_type rhs = baseclass::operator()();
+ result_type xdiff = lhs - rhs;
+ for (size_t i = 0; i < table_size; ++i) {
+ data_[i] = baseclass::operator()() ^ xdiff;
+ }
+}
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2,
+ typename baseclass, typename extvalclass, bool kdd>
+bool operator==(const extended<table_pow2, advance_pow2,
+ baseclass, extvalclass, kdd>& lhs,
+ const extended<table_pow2, advance_pow2,
+ baseclass, extvalclass, kdd>& rhs)
+{
+ auto& base_lhs = static_cast<const baseclass&>(lhs);
+ auto& base_rhs = static_cast<const baseclass&>(rhs);
+ return base_lhs == base_rhs
+ && std::equal(
+ std::begin(lhs.data_), std::end(lhs.data_),
+ std::begin(rhs.data_)
+ );
+}
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2,
+ typename baseclass, typename extvalclass, bool kdd>
+inline bool operator!=(const extended<table_pow2, advance_pow2,
+ baseclass, extvalclass, kdd>& lhs,
+ const extended<table_pow2, advance_pow2,
+ baseclass, extvalclass, kdd>& rhs)
+{
+ return !operator==(lhs, rhs);
+}
+
+template <typename CharT, typename Traits,
+ bitcount_t table_pow2, bitcount_t advance_pow2,
+ typename baseclass, typename extvalclass, bool kdd>
+std::basic_ostream<CharT,Traits>&
+operator<<(std::basic_ostream<CharT,Traits>& out,
+ const extended<table_pow2, advance_pow2,
+ baseclass, extvalclass, kdd>& rng)
+{
+ using pcg_extras::operator<<;
+
+ auto orig_flags = out.flags(std::ios_base::dec | std::ios_base::left);
+ auto space = out.widen(' ');
+ auto orig_fill = out.fill();
+
+ out << rng.multiplier() << space
+ << rng.increment() << space
+ << rng.state_;
+
+ for (const auto& datum : rng.data_)
+ out << space << datum;
+
+ out.flags(orig_flags);
+ out.fill(orig_fill);
+ return out;
+}
+
+template <typename CharT, typename Traits,
+ bitcount_t table_pow2, bitcount_t advance_pow2,
+ typename baseclass, typename extvalclass, bool kdd>
+std::basic_istream<CharT,Traits>&
+operator>>(std::basic_istream<CharT,Traits>& in,
+ extended<table_pow2, advance_pow2,
+ baseclass, extvalclass, kdd>& rng)
+{
+ extended<table_pow2, advance_pow2, baseclass, extvalclass> new_rng;
+ auto& base_rng = static_cast<baseclass&>(new_rng);
+ in >> base_rng;
+
+ if (in.fail())
+ return in;
+
+ using pcg_extras::operator>>;
+
+ auto orig_flags = in.flags(std::ios_base::dec | std::ios_base::skipws);
+
+ for (auto& datum : new_rng.data_) {
+ in >> datum;
+ if (in.fail())
+ goto bail;
+ }
+
+ rng = new_rng;
+
+bail:
+ in.flags(orig_flags);
+ return in;
+}
+
+
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2,
+ typename baseclass, typename extvalclass, bool kdd>
+void
+extended<table_pow2,advance_pow2,baseclass,extvalclass,kdd>::advance_table()
+{
+ bool carry = false;
+ for (size_t i = 0; i < table_size; ++i) {
+ if (carry) {
+ carry = insideout::external_step(data_[i],i+1);
+ }
+ bool carry2 = insideout::external_step(data_[i],i+1);
+ carry = carry || carry2;
+ }
+}
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2,
+ typename baseclass, typename extvalclass, bool kdd>
+void
+extended<table_pow2,advance_pow2,baseclass,extvalclass,kdd>::advance_table(
+ state_type delta, bool isForwards)
+{
+ typedef typename baseclass::state_type base_state_t;
+ typedef typename extvalclass::state_type ext_state_t;
+ constexpr bitcount_t basebits = sizeof(base_state_t)*8;
+ constexpr bitcount_t extbits = sizeof(ext_state_t)*8;
+ static_assert(basebits <= extbits || advance_pow2 > 0,
+ "Current implementation might overflow its carry");
+
+ base_state_t carry = 0;
+ for (size_t i = 0; i < table_size; ++i) {
+ base_state_t total_delta = carry + delta;
+ ext_state_t trunc_delta = ext_state_t(total_delta);
+ if (basebits > extbits) {
+ carry = total_delta >> extbits;
+ } else {
+ carry = 0;
+ }
+ carry +=
+ insideout::external_advance(data_[i],i+1, trunc_delta, isForwards);
+ }
+}
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2,
+ typename baseclass, typename extvalclass, bool kdd>
+void extended<table_pow2,advance_pow2,baseclass,extvalclass,kdd>::advance(
+ state_type distance, bool forwards)
+{
+ static_assert(kdd,
+ "Efficient advance is too hard for non-kdd extension. "
+ "For a weak advance, cast to base class");
+ state_type zero =
+ baseclass::is_mcg ? this->state_ & state_type(3U) : state_type(0U);
+ if (may_tick) {
+ state_type ticks = distance >> (advance_pow2*may_tick);
+ // ^-- stupidity to appease GCC
+ // warnings
+ state_type adv_mask =
+ baseclass::is_mcg ? tick_mask << 2 : tick_mask;
+ state_type next_advance_distance = this->distance(zero, adv_mask);
+ if (!forwards)
+ next_advance_distance = (-next_advance_distance) & tick_mask;
+ if (next_advance_distance < (distance & tick_mask)) {
+ ++ticks;
+ }
+ if (ticks)
+ advance_table(ticks, forwards);
+ }
+ if (forwards) {
+ if (may_tock && this->distance(zero) <= distance)
+ advance_table();
+ baseclass::advance(distance);
+ } else {
+ if (may_tock && -(this->distance(zero)) <= distance)
+ advance_table(state_type(1U), false);
+ baseclass::advance(-distance);
+ }
+}
+
+} // namespace pcg_detail
+
+namespace pcg_engines {
+
+using namespace pcg_detail;
+
+/* Predefined types for XSH RS */
+
+typedef oneseq_base<uint8_t, uint16_t, xsh_rs_mixin> oneseq_xsh_rs_16_8;
+typedef oneseq_base<uint16_t, uint32_t, xsh_rs_mixin> oneseq_xsh_rs_32_16;
+typedef oneseq_base<uint32_t, uint64_t, xsh_rs_mixin> oneseq_xsh_rs_64_32;
+typedef oneseq_base<uint64_t, pcg128_t, xsh_rs_mixin> oneseq_xsh_rs_128_64;
+typedef oneseq_base<uint64_t, pcg128_t, xsh_rs_mixin, true, cheap_multiplier>
+ cm_oneseq_xsh_rs_128_64;
+
+typedef unique_base<uint8_t, uint16_t, xsh_rs_mixin> unique_xsh_rs_16_8;
+typedef unique_base<uint16_t, uint32_t, xsh_rs_mixin> unique_xsh_rs_32_16;
+typedef unique_base<uint32_t, uint64_t, xsh_rs_mixin> unique_xsh_rs_64_32;
+typedef unique_base<uint64_t, pcg128_t, xsh_rs_mixin> unique_xsh_rs_128_64;
+typedef unique_base<uint64_t, pcg128_t, xsh_rs_mixin, true, cheap_multiplier>
+ cm_unique_xsh_rs_128_64;
+
+typedef setseq_base<uint8_t, uint16_t, xsh_rs_mixin> setseq_xsh_rs_16_8;
+typedef setseq_base<uint16_t, uint32_t, xsh_rs_mixin> setseq_xsh_rs_32_16;
+typedef setseq_base<uint32_t, uint64_t, xsh_rs_mixin> setseq_xsh_rs_64_32;
+typedef setseq_base<uint64_t, pcg128_t, xsh_rs_mixin> setseq_xsh_rs_128_64;
+typedef setseq_base<uint64_t, pcg128_t, xsh_rs_mixin, true, cheap_multiplier>
+ cm_setseq_xsh_rs_128_64;
+
+typedef mcg_base<uint8_t, uint16_t, xsh_rs_mixin> mcg_xsh_rs_16_8;
+typedef mcg_base<uint16_t, uint32_t, xsh_rs_mixin> mcg_xsh_rs_32_16;
+typedef mcg_base<uint32_t, uint64_t, xsh_rs_mixin> mcg_xsh_rs_64_32;
+typedef mcg_base<uint64_t, pcg128_t, xsh_rs_mixin> mcg_xsh_rs_128_64;
+typedef mcg_base<uint64_t, pcg128_t, xsh_rs_mixin, true, cheap_multiplier>
+ cm_mcg_xsh_rs_128_64;
+
+/* Predefined types for XSH RR */
+
+typedef oneseq_base<uint8_t, uint16_t, xsh_rr_mixin> oneseq_xsh_rr_16_8;
+typedef oneseq_base<uint16_t, uint32_t, xsh_rr_mixin> oneseq_xsh_rr_32_16;
+typedef oneseq_base<uint32_t, uint64_t, xsh_rr_mixin> oneseq_xsh_rr_64_32;
+typedef oneseq_base<uint64_t, pcg128_t, xsh_rr_mixin> oneseq_xsh_rr_128_64;
+typedef oneseq_base<uint64_t, pcg128_t, xsh_rr_mixin, true, cheap_multiplier>
+ cm_oneseq_xsh_rr_128_64;
+
+typedef unique_base<uint8_t, uint16_t, xsh_rr_mixin> unique_xsh_rr_16_8;
+typedef unique_base<uint16_t, uint32_t, xsh_rr_mixin> unique_xsh_rr_32_16;
+typedef unique_base<uint32_t, uint64_t, xsh_rr_mixin> unique_xsh_rr_64_32;
+typedef unique_base<uint64_t, pcg128_t, xsh_rr_mixin> unique_xsh_rr_128_64;
+typedef unique_base<uint64_t, pcg128_t, xsh_rr_mixin, true, cheap_multiplier>
+ cm_unique_xsh_rr_128_64;
+
+typedef setseq_base<uint8_t, uint16_t, xsh_rr_mixin> setseq_xsh_rr_16_8;
+typedef setseq_base<uint16_t, uint32_t, xsh_rr_mixin> setseq_xsh_rr_32_16;
+typedef setseq_base<uint32_t, uint64_t, xsh_rr_mixin> setseq_xsh_rr_64_32;
+typedef setseq_base<uint64_t, pcg128_t, xsh_rr_mixin> setseq_xsh_rr_128_64;
+typedef setseq_base<uint64_t, pcg128_t, xsh_rr_mixin, true, cheap_multiplier>
+ cm_setseq_xsh_rr_128_64;
+
+typedef mcg_base<uint8_t, uint16_t, xsh_rr_mixin> mcg_xsh_rr_16_8;
+typedef mcg_base<uint16_t, uint32_t, xsh_rr_mixin> mcg_xsh_rr_32_16;
+typedef mcg_base<uint32_t, uint64_t, xsh_rr_mixin> mcg_xsh_rr_64_32;
+typedef mcg_base<uint64_t, pcg128_t, xsh_rr_mixin> mcg_xsh_rr_128_64;
+typedef mcg_base<uint64_t, pcg128_t, xsh_rr_mixin, true, cheap_multiplier>
+ cm_mcg_xsh_rr_128_64;
+
+
+/* Predefined types for RXS M XS */
+
+typedef oneseq_base<uint8_t, uint8_t, rxs_m_xs_mixin> oneseq_rxs_m_xs_8_8;
+typedef oneseq_base<uint16_t, uint16_t, rxs_m_xs_mixin> oneseq_rxs_m_xs_16_16;
+typedef oneseq_base<uint32_t, uint32_t, rxs_m_xs_mixin> oneseq_rxs_m_xs_32_32;
+typedef oneseq_base<uint64_t, uint64_t, rxs_m_xs_mixin> oneseq_rxs_m_xs_64_64;
+typedef oneseq_base<pcg128_t, pcg128_t, rxs_m_xs_mixin>
+ oneseq_rxs_m_xs_128_128;
+typedef oneseq_base<pcg128_t, pcg128_t, rxs_m_xs_mixin, true, cheap_multiplier>
+ cm_oneseq_rxs_m_xs_128_128;
+
+typedef unique_base<uint8_t, uint8_t, rxs_m_xs_mixin> unique_rxs_m_xs_8_8;
+typedef unique_base<uint16_t, uint16_t, rxs_m_xs_mixin> unique_rxs_m_xs_16_16;
+typedef unique_base<uint32_t, uint32_t, rxs_m_xs_mixin> unique_rxs_m_xs_32_32;
+typedef unique_base<uint64_t, uint64_t, rxs_m_xs_mixin> unique_rxs_m_xs_64_64;
+typedef unique_base<pcg128_t, pcg128_t, rxs_m_xs_mixin> unique_rxs_m_xs_128_128;
+typedef unique_base<pcg128_t, pcg128_t, rxs_m_xs_mixin, true, cheap_multiplier>
+ cm_unique_rxs_m_xs_128_128;
+
+typedef setseq_base<uint8_t, uint8_t, rxs_m_xs_mixin> setseq_rxs_m_xs_8_8;
+typedef setseq_base<uint16_t, uint16_t, rxs_m_xs_mixin> setseq_rxs_m_xs_16_16;
+typedef setseq_base<uint32_t, uint32_t, rxs_m_xs_mixin> setseq_rxs_m_xs_32_32;
+typedef setseq_base<uint64_t, uint64_t, rxs_m_xs_mixin> setseq_rxs_m_xs_64_64;
+typedef setseq_base<pcg128_t, pcg128_t, rxs_m_xs_mixin> setseq_rxs_m_xs_128_128;
+typedef setseq_base<pcg128_t, pcg128_t, rxs_m_xs_mixin, true, cheap_multiplier>
+ cm_setseq_rxs_m_xs_128_128;
+
+ // MCG versions don't make sense here, so aren't defined.
+
+/* Predefined types for RXS M */
+
+typedef oneseq_base<uint8_t, uint16_t, rxs_m_mixin> oneseq_rxs_m_16_8;
+typedef oneseq_base<uint16_t, uint32_t, rxs_m_mixin> oneseq_rxs_m_32_16;
+typedef oneseq_base<uint32_t, uint64_t, rxs_m_mixin> oneseq_rxs_m_64_32;
+typedef oneseq_base<uint64_t, pcg128_t, rxs_m_mixin> oneseq_rxs_m_128_64;
+typedef oneseq_base<uint64_t, pcg128_t, rxs_m_mixin, true, cheap_multiplier>
+ cm_oneseq_rxs_m_128_64;
+
+typedef unique_base<uint8_t, uint16_t, rxs_m_mixin> unique_rxs_m_16_8;
+typedef unique_base<uint16_t, uint32_t, rxs_m_mixin> unique_rxs_m_32_16;
+typedef unique_base<uint32_t, uint64_t, rxs_m_mixin> unique_rxs_m_64_32;
+typedef unique_base<uint64_t, pcg128_t, rxs_m_mixin> unique_rxs_m_128_64;
+typedef unique_base<uint64_t, pcg128_t, rxs_m_mixin, true, cheap_multiplier>
+ cm_unique_rxs_m_128_64;
+
+typedef setseq_base<uint8_t, uint16_t, rxs_m_mixin> setseq_rxs_m_16_8;
+typedef setseq_base<uint16_t, uint32_t, rxs_m_mixin> setseq_rxs_m_32_16;
+typedef setseq_base<uint32_t, uint64_t, rxs_m_mixin> setseq_rxs_m_64_32;
+typedef setseq_base<uint64_t, pcg128_t, rxs_m_mixin> setseq_rxs_m_128_64;
+typedef setseq_base<uint64_t, pcg128_t, rxs_m_mixin, true, cheap_multiplier>
+ cm_setseq_rxs_m_128_64;
+
+typedef mcg_base<uint8_t, uint16_t, rxs_m_mixin> mcg_rxs_m_16_8;
+typedef mcg_base<uint16_t, uint32_t, rxs_m_mixin> mcg_rxs_m_32_16;
+typedef mcg_base<uint32_t, uint64_t, rxs_m_mixin> mcg_rxs_m_64_32;
+typedef mcg_base<uint64_t, pcg128_t, rxs_m_mixin> mcg_rxs_m_128_64;
+typedef mcg_base<uint64_t, pcg128_t, rxs_m_mixin, true, cheap_multiplier>
+ cm_mcg_rxs_m_128_64;
+
+/* Predefined types for DXSM */
+
+typedef oneseq_base<uint8_t, uint16_t, dxsm_mixin> oneseq_dxsm_16_8;
+typedef oneseq_base<uint16_t, uint32_t, dxsm_mixin> oneseq_dxsm_32_16;
+typedef oneseq_base<uint32_t, uint64_t, dxsm_mixin> oneseq_dxsm_64_32;
+typedef oneseq_base<uint64_t, pcg128_t, dxsm_mixin> oneseq_dxsm_128_64;
+typedef oneseq_base<uint64_t, pcg128_t, dxsm_mixin, true, cheap_multiplier>
+ cm_oneseq_dxsm_128_64;
+
+typedef unique_base<uint8_t, uint16_t, dxsm_mixin> unique_dxsm_16_8;
+typedef unique_base<uint16_t, uint32_t, dxsm_mixin> unique_dxsm_32_16;
+typedef unique_base<uint32_t, uint64_t, dxsm_mixin> unique_dxsm_64_32;
+typedef unique_base<uint64_t, pcg128_t, dxsm_mixin> unique_dxsm_128_64;
+typedef unique_base<uint64_t, pcg128_t, dxsm_mixin, true, cheap_multiplier>
+ cm_unique_dxsm_128_64;
+
+typedef setseq_base<uint8_t, uint16_t, dxsm_mixin> setseq_dxsm_16_8;
+typedef setseq_base<uint16_t, uint32_t, dxsm_mixin> setseq_dxsm_32_16;
+typedef setseq_base<uint32_t, uint64_t, dxsm_mixin> setseq_dxsm_64_32;
+typedef setseq_base<uint64_t, pcg128_t, dxsm_mixin> setseq_dxsm_128_64;
+typedef setseq_base<uint64_t, pcg128_t, dxsm_mixin, true, cheap_multiplier>
+ cm_setseq_dxsm_128_64;
+
+typedef mcg_base<uint8_t, uint16_t, dxsm_mixin> mcg_dxsm_16_8;
+typedef mcg_base<uint16_t, uint32_t, dxsm_mixin> mcg_dxsm_32_16;
+typedef mcg_base<uint32_t, uint64_t, dxsm_mixin> mcg_dxsm_64_32;
+typedef mcg_base<uint64_t, pcg128_t, dxsm_mixin> mcg_dxsm_128_64;
+typedef mcg_base<uint64_t, pcg128_t, dxsm_mixin, true, cheap_multiplier>
+ cm_mcg_dxsm_128_64;
+
+/* Predefined types for XSL RR (only defined for "large" types) */
+
+typedef oneseq_base<uint32_t, uint64_t, xsl_rr_mixin> oneseq_xsl_rr_64_32;
+typedef oneseq_base<uint64_t, pcg128_t, xsl_rr_mixin> oneseq_xsl_rr_128_64;
+typedef oneseq_base<uint64_t, pcg128_t, xsl_rr_mixin, true, cheap_multiplier>
+ cm_oneseq_xsl_rr_128_64;
+
+typedef unique_base<uint32_t, uint64_t, xsl_rr_mixin> unique_xsl_rr_64_32;
+typedef unique_base<uint64_t, pcg128_t, xsl_rr_mixin> unique_xsl_rr_128_64;
+typedef unique_base<uint64_t, pcg128_t, xsl_rr_mixin, true, cheap_multiplier>
+ cm_unique_xsl_rr_128_64;
+
+typedef setseq_base<uint32_t, uint64_t, xsl_rr_mixin> setseq_xsl_rr_64_32;
+typedef setseq_base<uint64_t, pcg128_t, xsl_rr_mixin> setseq_xsl_rr_128_64;
+typedef setseq_base<uint64_t, pcg128_t, xsl_rr_mixin, true, cheap_multiplier>
+ cm_setseq_xsl_rr_128_64;
+
+typedef mcg_base<uint32_t, uint64_t, xsl_rr_mixin> mcg_xsl_rr_64_32;
+typedef mcg_base<uint64_t, pcg128_t, xsl_rr_mixin> mcg_xsl_rr_128_64;
+typedef mcg_base<uint64_t, pcg128_t, xsl_rr_mixin, true, cheap_multiplier>
+ cm_mcg_xsl_rr_128_64;
+
+
+/* Predefined types for XSL RR RR (only defined for "large" types) */
+
+typedef oneseq_base<uint64_t, uint64_t, xsl_rr_rr_mixin>
+ oneseq_xsl_rr_rr_64_64;
+typedef oneseq_base<pcg128_t, pcg128_t, xsl_rr_rr_mixin>
+ oneseq_xsl_rr_rr_128_128;
+typedef oneseq_base<pcg128_t, pcg128_t, xsl_rr_rr_mixin, true, cheap_multiplier>
+ cm_oneseq_xsl_rr_rr_128_128;
+
+typedef unique_base<uint64_t, uint64_t, xsl_rr_rr_mixin>
+ unique_xsl_rr_rr_64_64;
+typedef unique_base<pcg128_t, pcg128_t, xsl_rr_rr_mixin>
+ unique_xsl_rr_rr_128_128;
+typedef unique_base<pcg128_t, pcg128_t, xsl_rr_rr_mixin, true, cheap_multiplier>
+ cm_unique_xsl_rr_rr_128_128;
+
+typedef setseq_base<uint64_t, uint64_t, xsl_rr_rr_mixin>
+ setseq_xsl_rr_rr_64_64;
+typedef setseq_base<pcg128_t, pcg128_t, xsl_rr_rr_mixin>
+ setseq_xsl_rr_rr_128_128;
+typedef setseq_base<pcg128_t, pcg128_t, xsl_rr_rr_mixin, true, cheap_multiplier>
+ cm_setseq_xsl_rr_rr_128_128;
+
+ // MCG versions don't make sense here, so aren't defined.
+
+/* Extended generators */
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2,
+ typename BaseRNG, bool kdd = true>
+using ext_std8 = extended<table_pow2, advance_pow2, BaseRNG,
+ oneseq_rxs_m_xs_8_8, kdd>;
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2,
+ typename BaseRNG, bool kdd = true>
+using ext_std16 = extended<table_pow2, advance_pow2, BaseRNG,
+ oneseq_rxs_m_xs_16_16, kdd>;
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2,
+ typename BaseRNG, bool kdd = true>
+using ext_std32 = extended<table_pow2, advance_pow2, BaseRNG,
+ oneseq_rxs_m_xs_32_32, kdd>;
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2,
+ typename BaseRNG, bool kdd = true>
+using ext_std64 = extended<table_pow2, advance_pow2, BaseRNG,
+ oneseq_rxs_m_xs_64_64, kdd>;
+
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2, bool kdd = true>
+using ext_oneseq_rxs_m_xs_32_32 =
+ ext_std32<table_pow2, advance_pow2, oneseq_rxs_m_xs_32_32, kdd>;
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2, bool kdd = true>
+using ext_mcg_xsh_rs_64_32 =
+ ext_std32<table_pow2, advance_pow2, mcg_xsh_rs_64_32, kdd>;
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2, bool kdd = true>
+using ext_oneseq_xsh_rs_64_32 =
+ ext_std32<table_pow2, advance_pow2, oneseq_xsh_rs_64_32, kdd>;
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2, bool kdd = true>
+using ext_setseq_xsh_rr_64_32 =
+ ext_std32<table_pow2, advance_pow2, setseq_xsh_rr_64_32, kdd>;
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2, bool kdd = true>
+using ext_mcg_xsl_rr_128_64 =
+ ext_std64<table_pow2, advance_pow2, mcg_xsl_rr_128_64, kdd>;
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2, bool kdd = true>
+using ext_oneseq_xsl_rr_128_64 =
+ ext_std64<table_pow2, advance_pow2, oneseq_xsl_rr_128_64, kdd>;
+
+template <bitcount_t table_pow2, bitcount_t advance_pow2, bool kdd = true>
+using ext_setseq_xsl_rr_128_64 =
+ ext_std64<table_pow2, advance_pow2, setseq_xsl_rr_128_64, kdd>;
+
+} // namespace pcg_engines
+
+typedef pcg_engines::setseq_xsh_rr_64_32 pcg32;
+typedef pcg_engines::oneseq_xsh_rr_64_32 pcg32_oneseq;
+typedef pcg_engines::unique_xsh_rr_64_32 pcg32_unique;
+typedef pcg_engines::mcg_xsh_rs_64_32 pcg32_fast;
+
+typedef pcg_engines::setseq_xsl_rr_128_64 pcg64;
+typedef pcg_engines::oneseq_xsl_rr_128_64 pcg64_oneseq;
+typedef pcg_engines::unique_xsl_rr_128_64 pcg64_unique;
+typedef pcg_engines::mcg_xsl_rr_128_64 pcg64_fast;
+
+typedef pcg_engines::setseq_rxs_m_xs_8_8 pcg8_once_insecure;
+typedef pcg_engines::setseq_rxs_m_xs_16_16 pcg16_once_insecure;
+typedef pcg_engines::setseq_rxs_m_xs_32_32 pcg32_once_insecure;
+typedef pcg_engines::setseq_rxs_m_xs_64_64 pcg64_once_insecure;
+typedef pcg_engines::setseq_xsl_rr_rr_128_128 pcg128_once_insecure;
+
+typedef pcg_engines::oneseq_rxs_m_xs_8_8 pcg8_oneseq_once_insecure;
+typedef pcg_engines::oneseq_rxs_m_xs_16_16 pcg16_oneseq_once_insecure;
+typedef pcg_engines::oneseq_rxs_m_xs_32_32 pcg32_oneseq_once_insecure;
+typedef pcg_engines::oneseq_rxs_m_xs_64_64 pcg64_oneseq_once_insecure;
+typedef pcg_engines::oneseq_xsl_rr_rr_128_128 pcg128_oneseq_once_insecure;
+
+
+// These two extended RNGs provide two-dimensionally equidistributed
+// 32-bit generators. pcg32_k2_fast occupies the same space as pcg64,
+// and can be called twice to generate 64 bits, but does not required
+// 128-bit math; on 32-bit systems, it's faster than pcg64 as well.
+
+typedef pcg_engines::ext_setseq_xsh_rr_64_32<1,16,true> pcg32_k2;
+typedef pcg_engines::ext_oneseq_xsh_rs_64_32<1,32,true> pcg32_k2_fast;
+
+// These eight extended RNGs have about as much state as arc4random
+//
+// - the k variants are k-dimensionally equidistributed
+// - the c variants offer better crypographic security
+//
+// (just how good the cryptographic security is is an open question)
+
+typedef pcg_engines::ext_setseq_xsh_rr_64_32<6,16,true> pcg32_k64;
+typedef pcg_engines::ext_mcg_xsh_rs_64_32<6,32,true> pcg32_k64_oneseq;
+typedef pcg_engines::ext_oneseq_xsh_rs_64_32<6,32,true> pcg32_k64_fast;
+
+typedef pcg_engines::ext_setseq_xsh_rr_64_32<6,16,false> pcg32_c64;
+typedef pcg_engines::ext_oneseq_xsh_rs_64_32<6,32,false> pcg32_c64_oneseq;
+typedef pcg_engines::ext_mcg_xsh_rs_64_32<6,32,false> pcg32_c64_fast;
+
+typedef pcg_engines::ext_setseq_xsl_rr_128_64<5,16,true> pcg64_k32;
+typedef pcg_engines::ext_oneseq_xsl_rr_128_64<5,128,true> pcg64_k32_oneseq;
+typedef pcg_engines::ext_mcg_xsl_rr_128_64<5,128,true> pcg64_k32_fast;
+
+typedef pcg_engines::ext_setseq_xsl_rr_128_64<5,16,false> pcg64_c32;
+typedef pcg_engines::ext_oneseq_xsl_rr_128_64<5,128,false> pcg64_c32_oneseq;
+typedef pcg_engines::ext_mcg_xsl_rr_128_64<5,128,false> pcg64_c32_fast;
+
+// These eight extended RNGs have more state than the Mersenne twister
+//
+// - the k variants are k-dimensionally equidistributed
+// - the c variants offer better crypographic security
+//
+// (just how good the cryptographic security is is an open question)
+
+typedef pcg_engines::ext_setseq_xsh_rr_64_32<10,16,true> pcg32_k1024;
+typedef pcg_engines::ext_oneseq_xsh_rs_64_32<10,32,true> pcg32_k1024_fast;
+
+typedef pcg_engines::ext_setseq_xsh_rr_64_32<10,16,false> pcg32_c1024;
+typedef pcg_engines::ext_oneseq_xsh_rs_64_32<10,32,false> pcg32_c1024_fast;
+
+typedef pcg_engines::ext_setseq_xsl_rr_128_64<10,16,true> pcg64_k1024;
+typedef pcg_engines::ext_oneseq_xsl_rr_128_64<10,128,true> pcg64_k1024_fast;
+
+typedef pcg_engines::ext_setseq_xsl_rr_128_64<10,16,false> pcg64_c1024;
+typedef pcg_engines::ext_oneseq_xsl_rr_128_64<10,128,false> pcg64_c1024_fast;
+
+// These generators have an insanely huge period (2^524352), and is suitable
+// for silly party tricks, such as dumping out 64 KB ZIP files at an arbitrary
+// point in the future. [Actually, over the full period of the generator, it
+// will produce every 64 KB ZIP file 2^64 times!]
+
+typedef pcg_engines::ext_setseq_xsh_rr_64_32<14,16,true> pcg32_k16384;
+typedef pcg_engines::ext_oneseq_xsh_rs_64_32<14,32,true> pcg32_k16384_fast;
+
+} // namespace arrow_vendored
+
+#ifdef _MSC_VER
+ #pragma warning(default:4146)
+#endif
+
+#endif // PCG_RAND_HPP_INCLUDED
diff --git a/src/arrow/cpp/src/arrow/vendored/pcg/pcg_uint128.hpp b/src/arrow/cpp/src/arrow/vendored/pcg/pcg_uint128.hpp
new file mode 100644
index 000000000..0181e69e4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/pcg/pcg_uint128.hpp
@@ -0,0 +1,1008 @@
+/*
+ * PCG Random Number Generation for C++
+ *
+ * Copyright 2014-2021 Melissa O'Neill <oneill@pcg-random.org>,
+ * and the PCG Project contributors.
+ *
+ * SPDX-License-Identifier: (Apache-2.0 OR MIT)
+ *
+ * Licensed under the Apache License, Version 2.0 (provided in
+ * LICENSE-APACHE.txt and at http://www.apache.org/licenses/LICENSE-2.0)
+ * or under the MIT license (provided in LICENSE-MIT.txt and at
+ * http://opensource.org/licenses/MIT), at your option. This file may not
+ * be copied, modified, or distributed except according to those terms.
+ *
+ * Distributed on an "AS IS" BASIS, WITHOUT WARRANTY OF ANY KIND, either
+ * express or implied. See your chosen license for details.
+ *
+ * For additional information about the PCG random number generation scheme,
+ * visit http://www.pcg-random.org/.
+ */
+
+/*
+ * This code provides a a C++ class that can provide 128-bit (or higher)
+ * integers. To produce 2K-bit integers, it uses two K-bit integers,
+ * placed in a union that allowes the code to also see them as four K/2 bit
+ * integers (and access them either directly name, or by index).
+ *
+ * It may seem like we're reinventing the wheel here, because several
+ * libraries already exist that support large integers, but most existing
+ * libraries provide a very generic multiprecision code, but here we're
+ * operating at a fixed size. Also, most other libraries are fairly
+ * heavyweight. So we use a direct implementation. Sadly, it's much slower
+ * than hand-coded assembly or direct CPU support.
+ */
+
+#ifndef PCG_UINT128_HPP_INCLUDED
+#define PCG_UINT128_HPP_INCLUDED 1
+
+#include <cstdint>
+#include <cstdio>
+#include <cassert>
+#include <climits>
+#include <utility>
+#include <initializer_list>
+#include <type_traits>
+
+#if defined(_MSC_VER) // Use MSVC++ intrinsics
+#include <intrin.h>
+#endif
+
+/*
+ * We want to lay the type out the same way that a native type would be laid
+ * out, which means we must know the machine's endian, at compile time.
+ * This ugliness attempts to do so.
+ */
+
+#ifndef PCG_LITTLE_ENDIAN
+ #if defined(__BYTE_ORDER__)
+ #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+ #define PCG_LITTLE_ENDIAN 1
+ #elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+ #define PCG_LITTLE_ENDIAN 0
+ #else
+ #error __BYTE_ORDER__ does not match a standard endian, pick a side
+ #endif
+ #elif __LITTLE_ENDIAN__ || _LITTLE_ENDIAN
+ #define PCG_LITTLE_ENDIAN 1
+ #elif __BIG_ENDIAN__ || _BIG_ENDIAN
+ #define PCG_LITTLE_ENDIAN 0
+ #elif __x86_64 || __x86_64__ || _M_X64 || __i386 || __i386__ || _M_IX86
+ #define PCG_LITTLE_ENDIAN 1
+ #elif __powerpc__ || __POWERPC__ || __ppc__ || __PPC__ \
+ || __m68k__ || __mc68000__
+ #define PCG_LITTLE_ENDIAN 0
+ #else
+ #error Unable to determine target endianness
+ #endif
+#endif
+
+#if INTPTR_MAX == INT64_MAX && !defined(PCG_64BIT_SPECIALIZATIONS)
+ #define PCG_64BIT_SPECIALIZATIONS 1
+#endif
+
+namespace arrow_vendored {
+namespace pcg_extras {
+
+// Recent versions of GCC have intrinsics we can use to quickly calculate
+// the number of leading and trailing zeros in a number. If possible, we
+// use them, otherwise we fall back to old-fashioned bit twiddling to figure
+// them out.
+
+#ifndef PCG_BITCOUNT_T
+ typedef uint8_t bitcount_t;
+#else
+ typedef PCG_BITCOUNT_T bitcount_t;
+#endif
+
+/*
+ * Provide some useful helper functions
+ * * flog2 floor(log2(x))
+ * * trailingzeros number of trailing zero bits
+ */
+
+#if defined(__GNUC__) // Any GNU-compatible compiler supporting C++11 has
+ // some useful intrinsics we can use.
+
+inline bitcount_t flog2(uint32_t v)
+{
+ return 31 - __builtin_clz(v);
+}
+
+inline bitcount_t trailingzeros(uint32_t v)
+{
+ return __builtin_ctz(v);
+}
+
+inline bitcount_t flog2(uint64_t v)
+{
+#if UINT64_MAX == ULONG_MAX
+ return 63 - __builtin_clzl(v);
+#elif UINT64_MAX == ULLONG_MAX
+ return 63 - __builtin_clzll(v);
+#else
+ #error Cannot find a function for uint64_t
+#endif
+}
+
+inline bitcount_t trailingzeros(uint64_t v)
+{
+#if UINT64_MAX == ULONG_MAX
+ return __builtin_ctzl(v);
+#elif UINT64_MAX == ULLONG_MAX
+ return __builtin_ctzll(v);
+#else
+ #error Cannot find a function for uint64_t
+#endif
+}
+
+#elif defined(_MSC_VER) // Use MSVC++ intrinsics
+
+#pragma intrinsic(_BitScanReverse, _BitScanForward)
+#if defined(_M_X64) || defined(_M_ARM) || defined(_M_ARM64)
+#pragma intrinsic(_BitScanReverse64, _BitScanForward64)
+#endif
+
+inline bitcount_t flog2(uint32_t v)
+{
+ unsigned long i;
+ _BitScanReverse(&i, v);
+ return bitcount_t(i);
+}
+
+inline bitcount_t trailingzeros(uint32_t v)
+{
+ unsigned long i;
+ _BitScanForward(&i, v);
+ return bitcount_t(i);
+}
+
+inline bitcount_t flog2(uint64_t v)
+{
+#if defined(_M_X64) || defined(_M_ARM) || defined(_M_ARM64)
+ unsigned long i;
+ _BitScanReverse64(&i, v);
+ return bitcount_t(i);
+#else
+ // 32-bit x86
+ uint32_t high = v >> 32;
+ uint32_t low = uint32_t(v);
+ return high ? 32+flog2(high) : flog2(low);
+#endif
+}
+
+inline bitcount_t trailingzeros(uint64_t v)
+{
+#if defined(_M_X64) || defined(_M_ARM) || defined(_M_ARM64)
+ unsigned long i;
+ _BitScanForward64(&i, v);
+ return bitcount_t(i);
+#else
+ // 32-bit x86
+ uint32_t high = v >> 32;
+ uint32_t low = uint32_t(v);
+ return low ? trailingzeros(low) : trailingzeros(high)+32;
+#endif
+}
+
+#else // Otherwise, we fall back to bit twiddling
+ // implementations
+
+inline bitcount_t flog2(uint32_t v)
+{
+ // Based on code by Eric Cole and Mark Dickinson, which appears at
+ // https://graphics.stanford.edu/~seander/bithacks.html#IntegerLogDeBruijn
+
+ static const uint8_t multiplyDeBruijnBitPos[32] = {
+ 0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30,
+ 8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31
+ };
+
+ v |= v >> 1; // first round down to one less than a power of 2
+ v |= v >> 2;
+ v |= v >> 4;
+ v |= v >> 8;
+ v |= v >> 16;
+
+ return multiplyDeBruijnBitPos[(uint32_t)(v * 0x07C4ACDDU) >> 27];
+}
+
+inline bitcount_t trailingzeros(uint32_t v)
+{
+ static const uint8_t multiplyDeBruijnBitPos[32] = {
+ 0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
+ 31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9
+ };
+
+ return multiplyDeBruijnBitPos[((uint32_t)((v & -v) * 0x077CB531U)) >> 27];
+}
+
+inline bitcount_t flog2(uint64_t v)
+{
+ uint32_t high = v >> 32;
+ uint32_t low = uint32_t(v);
+
+ return high ? 32+flog2(high) : flog2(low);
+}
+
+inline bitcount_t trailingzeros(uint64_t v)
+{
+ uint32_t high = v >> 32;
+ uint32_t low = uint32_t(v);
+
+ return low ? trailingzeros(low) : trailingzeros(high)+32;
+}
+
+#endif
+
+inline bitcount_t flog2(uint8_t v)
+{
+ return flog2(uint32_t(v));
+}
+
+inline bitcount_t flog2(uint16_t v)
+{
+ return flog2(uint32_t(v));
+}
+
+#if __SIZEOF_INT128__
+inline bitcount_t flog2(__uint128_t v)
+{
+ uint64_t high = uint64_t(v >> 64);
+ uint64_t low = uint64_t(v);
+
+ return high ? 64+flog2(high) : flog2(low);
+}
+#endif
+
+inline bitcount_t trailingzeros(uint8_t v)
+{
+ return trailingzeros(uint32_t(v));
+}
+
+inline bitcount_t trailingzeros(uint16_t v)
+{
+ return trailingzeros(uint32_t(v));
+}
+
+#if __SIZEOF_INT128__
+inline bitcount_t trailingzeros(__uint128_t v)
+{
+ uint64_t high = uint64_t(v >> 64);
+ uint64_t low = uint64_t(v);
+ return low ? trailingzeros(low) : trailingzeros(high)+64;
+}
+#endif
+
+template <typename UInt>
+inline bitcount_t clog2(UInt v)
+{
+ return flog2(v) + ((v & (-v)) != v);
+}
+
+template <typename UInt>
+inline UInt addwithcarry(UInt x, UInt y, bool carryin, bool* carryout)
+{
+ UInt half_result = y + carryin;
+ UInt result = x + half_result;
+ *carryout = (half_result < y) || (result < x);
+ return result;
+}
+
+template <typename UInt>
+inline UInt subwithcarry(UInt x, UInt y, bool carryin, bool* carryout)
+{
+ UInt half_result = y + carryin;
+ UInt result = x - half_result;
+ *carryout = (half_result < y) || (result > x);
+ return result;
+}
+
+
+template <typename UInt, typename UIntX2>
+class uint_x4 {
+// private:
+ static constexpr unsigned int UINT_BITS = sizeof(UInt) * CHAR_BIT;
+public:
+ union {
+#if PCG_LITTLE_ENDIAN
+ struct {
+ UInt v0, v1, v2, v3;
+ } w;
+ struct {
+ UIntX2 v01, v23;
+ } d;
+#else
+ struct {
+ UInt v3, v2, v1, v0;
+ } w;
+ struct {
+ UIntX2 v23, v01;
+ } d;
+#endif
+ // For the array access versions, the code that uses the array
+ // must handle endian itself. Yuck.
+ UInt wa[4];
+ };
+
+public:
+ uint_x4() = default;
+
+ constexpr uint_x4(UInt v3, UInt v2, UInt v1, UInt v0)
+#if PCG_LITTLE_ENDIAN
+ : w{v0, v1, v2, v3}
+#else
+ : w{v3, v2, v1, v0}
+#endif
+ {
+ // Nothing (else) to do
+ }
+
+ constexpr uint_x4(UIntX2 v23, UIntX2 v01)
+#if PCG_LITTLE_ENDIAN
+ : d{v01,v23}
+#else
+ : d{v23,v01}
+#endif
+ {
+ // Nothing (else) to do
+ }
+
+ constexpr uint_x4(UIntX2 v01)
+#if PCG_LITTLE_ENDIAN
+ : d{v01, UIntX2(0)}
+#else
+ : d{UIntX2(0),v01}
+#endif
+ {
+ // Nothing (else) to do
+ }
+
+ template<class Integral,
+ typename std::enable_if<(std::is_integral<Integral>::value
+ && sizeof(Integral) <= sizeof(UIntX2))
+ >::type* = nullptr>
+ constexpr uint_x4(Integral v01)
+#if PCG_LITTLE_ENDIAN
+ : d{UIntX2(v01), UIntX2(0)}
+#else
+ : d{UIntX2(0), UIntX2(v01)}
+#endif
+ {
+ // Nothing (else) to do
+ }
+
+ explicit constexpr operator UIntX2() const
+ {
+ return d.v01;
+ }
+
+ template<class Integral,
+ typename std::enable_if<(std::is_integral<Integral>::value
+ && sizeof(Integral) <= sizeof(UIntX2))
+ >::type* = nullptr>
+ explicit constexpr operator Integral() const
+ {
+ return Integral(d.v01);
+ }
+
+ explicit constexpr operator bool() const
+ {
+ return d.v01 || d.v23;
+ }
+
+ template<typename U, typename V>
+ friend uint_x4<U,V> operator*(const uint_x4<U,V>&, const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend uint_x4<U,V> operator*(const uint_x4<U,V>&, V);
+
+ template<typename U, typename V>
+ friend std::pair< uint_x4<U,V>,uint_x4<U,V> >
+ divmod(const uint_x4<U,V>&, const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend uint_x4<U,V> operator+(const uint_x4<U,V>&, const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend uint_x4<U,V> operator-(const uint_x4<U,V>&, const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend uint_x4<U,V> operator<<(const uint_x4<U,V>&, const bitcount_t shift);
+
+ template<typename U, typename V>
+ friend uint_x4<U,V> operator>>(const uint_x4<U,V>&, const bitcount_t shift);
+
+#if PCG_64BIT_SPECIALIZATIONS
+ template<typename U>
+ friend uint_x4<U,uint64_t> operator<<(const uint_x4<U,uint64_t>&, const bitcount_t shift);
+
+ template<typename U>
+ friend uint_x4<U,uint64_t> operator>>(const uint_x4<U,uint64_t>&, const bitcount_t shift);
+#endif
+
+ template<typename U, typename V>
+ friend uint_x4<U,V> operator&(const uint_x4<U,V>&, const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend uint_x4<U,V> operator|(const uint_x4<U,V>&, const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend uint_x4<U,V> operator^(const uint_x4<U,V>&, const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend bool operator==(const uint_x4<U,V>&, const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend bool operator!=(const uint_x4<U,V>&, const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend bool operator<(const uint_x4<U,V>&, const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend bool operator<=(const uint_x4<U,V>&, const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend bool operator>(const uint_x4<U,V>&, const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend bool operator>=(const uint_x4<U,V>&, const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend uint_x4<U,V> operator~(const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend uint_x4<U,V> operator-(const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend bitcount_t flog2(const uint_x4<U,V>&);
+
+ template<typename U, typename V>
+ friend bitcount_t trailingzeros(const uint_x4<U,V>&);
+
+#if PCG_64BIT_SPECIALIZATIONS
+ template<typename U>
+ friend bitcount_t flog2(const uint_x4<U,uint64_t>&);
+
+ template<typename U>
+ friend bitcount_t trailingzeros(const uint_x4<U,uint64_t>&);
+#endif
+
+ uint_x4& operator*=(const uint_x4& rhs)
+ {
+ uint_x4 result = *this * rhs;
+ return *this = result;
+ }
+
+ uint_x4& operator*=(UIntX2 rhs)
+ {
+ uint_x4 result = *this * rhs;
+ return *this = result;
+ }
+
+ uint_x4& operator/=(const uint_x4& rhs)
+ {
+ uint_x4 result = *this / rhs;
+ return *this = result;
+ }
+
+ uint_x4& operator%=(const uint_x4& rhs)
+ {
+ uint_x4 result = *this % rhs;
+ return *this = result;
+ }
+
+ uint_x4& operator+=(const uint_x4& rhs)
+ {
+ uint_x4 result = *this + rhs;
+ return *this = result;
+ }
+
+ uint_x4& operator-=(const uint_x4& rhs)
+ {
+ uint_x4 result = *this - rhs;
+ return *this = result;
+ }
+
+ uint_x4& operator&=(const uint_x4& rhs)
+ {
+ uint_x4 result = *this & rhs;
+ return *this = result;
+ }
+
+ uint_x4& operator|=(const uint_x4& rhs)
+ {
+ uint_x4 result = *this | rhs;
+ return *this = result;
+ }
+
+ uint_x4& operator^=(const uint_x4& rhs)
+ {
+ uint_x4 result = *this ^ rhs;
+ return *this = result;
+ }
+
+ uint_x4& operator>>=(bitcount_t shift)
+ {
+ uint_x4 result = *this >> shift;
+ return *this = result;
+ }
+
+ uint_x4& operator<<=(bitcount_t shift)
+ {
+ uint_x4 result = *this << shift;
+ return *this = result;
+ }
+
+};
+
+template<typename U, typename V>
+bitcount_t flog2(const uint_x4<U,V>& v)
+{
+#if PCG_LITTLE_ENDIAN
+ for (uint8_t i = 4; i !=0; /* dec in loop */) {
+ --i;
+#else
+ for (uint8_t i = 0; i < 4; ++i) {
+#endif
+ if (v.wa[i] == 0)
+ continue;
+ return flog2(v.wa[i]) + uint_x4<U,V>::UINT_BITS*i;
+ }
+ abort();
+}
+
+template<typename U, typename V>
+bitcount_t trailingzeros(const uint_x4<U,V>& v)
+{
+#if PCG_LITTLE_ENDIAN
+ for (uint8_t i = 0; i < 4; ++i) {
+#else
+ for (uint8_t i = 4; i !=0; /* dec in loop */) {
+ --i;
+#endif
+ if (v.wa[i] != 0)
+ return trailingzeros(v.wa[i]) + uint_x4<U,V>::UINT_BITS*i;
+ }
+ return uint_x4<U,V>::UINT_BITS*4;
+}
+
+#if PCG_64BIT_SPECIALIZATIONS
+template<typename UInt32>
+bitcount_t flog2(const uint_x4<UInt32,uint64_t>& v)
+{
+ return v.d.v23 > 0 ? flog2(v.d.v23) + uint_x4<UInt32,uint64_t>::UINT_BITS*2
+ : flog2(v.d.v01);
+}
+
+template<typename UInt32>
+bitcount_t trailingzeros(const uint_x4<UInt32,uint64_t>& v)
+{
+ return v.d.v01 == 0 ? trailingzeros(v.d.v23) + uint_x4<UInt32,uint64_t>::UINT_BITS*2
+ : trailingzeros(v.d.v01);
+}
+#endif
+
+template <typename UInt, typename UIntX2>
+std::pair< uint_x4<UInt,UIntX2>, uint_x4<UInt,UIntX2> >
+ divmod(const uint_x4<UInt,UIntX2>& orig_dividend,
+ const uint_x4<UInt,UIntX2>& divisor)
+{
+ // If the dividend is less than the divisor, the answer is always zero.
+ // This takes care of boundary cases like 0/x (which would otherwise be
+ // problematic because we can't take the log of zero. (The boundary case
+ // of division by zero is undefined.)
+ if (orig_dividend < divisor)
+ return { uint_x4<UInt,UIntX2>(UIntX2(0)), orig_dividend };
+
+ auto dividend = orig_dividend;
+
+ auto log2_divisor = flog2(divisor);
+ auto log2_dividend = flog2(dividend);
+ // assert(log2_dividend >= log2_divisor);
+ bitcount_t logdiff = log2_dividend - log2_divisor;
+
+ constexpr uint_x4<UInt,UIntX2> ONE(UIntX2(1));
+ if (logdiff == 0)
+ return { ONE, dividend - divisor };
+
+ // Now we change the log difference to
+ // floor(log2(divisor)) - ceil(log2(dividend))
+ // to ensure that we *underestimate* the result.
+ logdiff -= 1;
+
+ uint_x4<UInt,UIntX2> quotient(UIntX2(0));
+
+ auto qfactor = ONE << logdiff;
+ auto factor = divisor << logdiff;
+
+ do {
+ dividend -= factor;
+ quotient += qfactor;
+ while (dividend < factor) {
+ factor >>= 1;
+ qfactor >>= 1;
+ }
+ } while (dividend >= divisor);
+
+ return { quotient, dividend };
+}
+
+template <typename UInt, typename UIntX2>
+uint_x4<UInt,UIntX2> operator/(const uint_x4<UInt,UIntX2>& dividend,
+ const uint_x4<UInt,UIntX2>& divisor)
+{
+ return divmod(dividend, divisor).first;
+}
+
+template <typename UInt, typename UIntX2>
+uint_x4<UInt,UIntX2> operator%(const uint_x4<UInt,UIntX2>& dividend,
+ const uint_x4<UInt,UIntX2>& divisor)
+{
+ return divmod(dividend, divisor).second;
+}
+
+
+template <typename UInt, typename UIntX2>
+uint_x4<UInt,UIntX2> operator*(const uint_x4<UInt,UIntX2>& a,
+ const uint_x4<UInt,UIntX2>& b)
+{
+ constexpr auto UINT_BITS = uint_x4<UInt,UIntX2>::UINT_BITS;
+ uint_x4<UInt,UIntX2> r = {0U, 0U, 0U, 0U};
+ bool carryin = false;
+ bool carryout;
+ UIntX2 a0b0 = UIntX2(a.w.v0) * UIntX2(b.w.v0);
+ r.w.v0 = UInt(a0b0);
+ r.w.v1 = UInt(a0b0 >> UINT_BITS);
+
+ UIntX2 a1b0 = UIntX2(a.w.v1) * UIntX2(b.w.v0);
+ r.w.v2 = UInt(a1b0 >> UINT_BITS);
+ r.w.v1 = addwithcarry(r.w.v1, UInt(a1b0), carryin, &carryout);
+ carryin = carryout;
+ r.w.v2 = addwithcarry(r.w.v2, UInt(0U), carryin, &carryout);
+ carryin = carryout;
+ r.w.v3 = addwithcarry(r.w.v3, UInt(0U), carryin, &carryout);
+
+ UIntX2 a0b1 = UIntX2(a.w.v0) * UIntX2(b.w.v1);
+ carryin = false;
+ r.w.v2 = addwithcarry(r.w.v2, UInt(a0b1 >> UINT_BITS), carryin, &carryout);
+ carryin = carryout;
+ r.w.v3 = addwithcarry(r.w.v3, UInt(0U), carryin, &carryout);
+
+ carryin = false;
+ r.w.v1 = addwithcarry(r.w.v1, UInt(a0b1), carryin, &carryout);
+ carryin = carryout;
+ r.w.v2 = addwithcarry(r.w.v2, UInt(0U), carryin, &carryout);
+ carryin = carryout;
+ r.w.v3 = addwithcarry(r.w.v3, UInt(0U), carryin, &carryout);
+
+ UIntX2 a1b1 = UIntX2(a.w.v1) * UIntX2(b.w.v1);
+ carryin = false;
+ r.w.v2 = addwithcarry(r.w.v2, UInt(a1b1), carryin, &carryout);
+ carryin = carryout;
+ r.w.v3 = addwithcarry(r.w.v3, UInt(a1b1 >> UINT_BITS), carryin, &carryout);
+
+ r.d.v23 += a.d.v01 * b.d.v23 + a.d.v23 * b.d.v01;
+
+ return r;
+}
+
+
+template <typename UInt, typename UIntX2>
+uint_x4<UInt,UIntX2> operator*(const uint_x4<UInt,UIntX2>& a,
+ UIntX2 b01)
+{
+ constexpr auto UINT_BITS = uint_x4<UInt,UIntX2>::UINT_BITS;
+ uint_x4<UInt,UIntX2> r = {0U, 0U, 0U, 0U};
+ bool carryin = false;
+ bool carryout;
+ UIntX2 a0b0 = UIntX2(a.w.v0) * UIntX2(UInt(b01));
+ r.w.v0 = UInt(a0b0);
+ r.w.v1 = UInt(a0b0 >> UINT_BITS);
+
+ UIntX2 a1b0 = UIntX2(a.w.v1) * UIntX2(UInt(b01));
+ r.w.v2 = UInt(a1b0 >> UINT_BITS);
+ r.w.v1 = addwithcarry(r.w.v1, UInt(a1b0), carryin, &carryout);
+ carryin = carryout;
+ r.w.v2 = addwithcarry(r.w.v2, UInt(0U), carryin, &carryout);
+ carryin = carryout;
+ r.w.v3 = addwithcarry(r.w.v3, UInt(0U), carryin, &carryout);
+
+ UIntX2 a0b1 = UIntX2(a.w.v0) * UIntX2(b01 >> UINT_BITS);
+ carryin = false;
+ r.w.v2 = addwithcarry(r.w.v2, UInt(a0b1 >> UINT_BITS), carryin, &carryout);
+ carryin = carryout;
+ r.w.v3 = addwithcarry(r.w.v3, UInt(0U), carryin, &carryout);
+
+ carryin = false;
+ r.w.v1 = addwithcarry(r.w.v1, UInt(a0b1), carryin, &carryout);
+ carryin = carryout;
+ r.w.v2 = addwithcarry(r.w.v2, UInt(0U), carryin, &carryout);
+ carryin = carryout;
+ r.w.v3 = addwithcarry(r.w.v3, UInt(0U), carryin, &carryout);
+
+ UIntX2 a1b1 = UIntX2(a.w.v1) * UIntX2(b01 >> UINT_BITS);
+ carryin = false;
+ r.w.v2 = addwithcarry(r.w.v2, UInt(a1b1), carryin, &carryout);
+ carryin = carryout;
+ r.w.v3 = addwithcarry(r.w.v3, UInt(a1b1 >> UINT_BITS), carryin, &carryout);
+
+ r.d.v23 += a.d.v23 * b01;
+
+ return r;
+}
+
+#if PCG_64BIT_SPECIALIZATIONS
+#if defined(_MSC_VER)
+#pragma intrinsic(_umul128)
+#endif
+
+#if defined(_MSC_VER) || __SIZEOF_INT128__
+template <typename UInt32>
+uint_x4<UInt32,uint64_t> operator*(const uint_x4<UInt32,uint64_t>& a,
+ const uint_x4<UInt32,uint64_t>& b)
+{
+#if defined(_MSC_VER)
+ uint64_t hi;
+ uint64_t lo = _umul128(a.d.v01, b.d.v01, &hi);
+#else
+ __uint128_t r = __uint128_t(a.d.v01) * __uint128_t(b.d.v01);
+ uint64_t lo = uint64_t(r);
+ uint64_t hi = r >> 64;
+#endif
+ hi += a.d.v23 * b.d.v01 + a.d.v01 * b.d.v23;
+ return {hi, lo};
+}
+#endif
+#endif
+
+
+template <typename UInt, typename UIntX2>
+uint_x4<UInt,UIntX2> operator+(const uint_x4<UInt,UIntX2>& a,
+ const uint_x4<UInt,UIntX2>& b)
+{
+ uint_x4<UInt,UIntX2> r = {0U, 0U, 0U, 0U};
+
+ bool carryin = false;
+ bool carryout;
+ r.w.v0 = addwithcarry(a.w.v0, b.w.v0, carryin, &carryout);
+ carryin = carryout;
+ r.w.v1 = addwithcarry(a.w.v1, b.w.v1, carryin, &carryout);
+ carryin = carryout;
+ r.w.v2 = addwithcarry(a.w.v2, b.w.v2, carryin, &carryout);
+ carryin = carryout;
+ r.w.v3 = addwithcarry(a.w.v3, b.w.v3, carryin, &carryout);
+
+ return r;
+}
+
+template <typename UInt, typename UIntX2>
+uint_x4<UInt,UIntX2> operator-(const uint_x4<UInt,UIntX2>& a,
+ const uint_x4<UInt,UIntX2>& b)
+{
+ uint_x4<UInt,UIntX2> r = {0U, 0U, 0U, 0U};
+
+ bool carryin = false;
+ bool carryout;
+ r.w.v0 = subwithcarry(a.w.v0, b.w.v0, carryin, &carryout);
+ carryin = carryout;
+ r.w.v1 = subwithcarry(a.w.v1, b.w.v1, carryin, &carryout);
+ carryin = carryout;
+ r.w.v2 = subwithcarry(a.w.v2, b.w.v2, carryin, &carryout);
+ carryin = carryout;
+ r.w.v3 = subwithcarry(a.w.v3, b.w.v3, carryin, &carryout);
+
+ return r;
+}
+
+#if PCG_64BIT_SPECIALIZATIONS
+template <typename UInt32>
+uint_x4<UInt32,uint64_t> operator+(const uint_x4<UInt32,uint64_t>& a,
+ const uint_x4<UInt32,uint64_t>& b)
+{
+ uint_x4<UInt32,uint64_t> r = {uint64_t(0u), uint64_t(0u)};
+
+ bool carryin = false;
+ bool carryout;
+ r.d.v01 = addwithcarry(a.d.v01, b.d.v01, carryin, &carryout);
+ carryin = carryout;
+ r.d.v23 = addwithcarry(a.d.v23, b.d.v23, carryin, &carryout);
+
+ return r;
+}
+
+template <typename UInt32>
+uint_x4<UInt32,uint64_t> operator-(const uint_x4<UInt32,uint64_t>& a,
+ const uint_x4<UInt32,uint64_t>& b)
+{
+ uint_x4<UInt32,uint64_t> r = {uint64_t(0u), uint64_t(0u)};
+
+ bool carryin = false;
+ bool carryout;
+ r.d.v01 = subwithcarry(a.d.v01, b.d.v01, carryin, &carryout);
+ carryin = carryout;
+ r.d.v23 = subwithcarry(a.d.v23, b.d.v23, carryin, &carryout);
+
+ return r;
+}
+#endif
+
+template <typename UInt, typename UIntX2>
+uint_x4<UInt,UIntX2> operator&(const uint_x4<UInt,UIntX2>& a,
+ const uint_x4<UInt,UIntX2>& b)
+{
+ return uint_x4<UInt,UIntX2>(a.d.v23 & b.d.v23, a.d.v01 & b.d.v01);
+}
+
+template <typename UInt, typename UIntX2>
+uint_x4<UInt,UIntX2> operator|(const uint_x4<UInt,UIntX2>& a,
+ const uint_x4<UInt,UIntX2>& b)
+{
+ return uint_x4<UInt,UIntX2>(a.d.v23 | b.d.v23, a.d.v01 | b.d.v01);
+}
+
+template <typename UInt, typename UIntX2>
+uint_x4<UInt,UIntX2> operator^(const uint_x4<UInt,UIntX2>& a,
+ const uint_x4<UInt,UIntX2>& b)
+{
+ return uint_x4<UInt,UIntX2>(a.d.v23 ^ b.d.v23, a.d.v01 ^ b.d.v01);
+}
+
+template <typename UInt, typename UIntX2>
+uint_x4<UInt,UIntX2> operator~(const uint_x4<UInt,UIntX2>& v)
+{
+ return uint_x4<UInt,UIntX2>(~v.d.v23, ~v.d.v01);
+}
+
+template <typename UInt, typename UIntX2>
+uint_x4<UInt,UIntX2> operator-(const uint_x4<UInt,UIntX2>& v)
+{
+ return uint_x4<UInt,UIntX2>(0UL,0UL) - v;
+}
+
+template <typename UInt, typename UIntX2>
+bool operator==(const uint_x4<UInt,UIntX2>& a, const uint_x4<UInt,UIntX2>& b)
+{
+ return (a.d.v01 == b.d.v01) && (a.d.v23 == b.d.v23);
+}
+
+template <typename UInt, typename UIntX2>
+bool operator!=(const uint_x4<UInt,UIntX2>& a, const uint_x4<UInt,UIntX2>& b)
+{
+ return !operator==(a,b);
+}
+
+
+template <typename UInt, typename UIntX2>
+bool operator<(const uint_x4<UInt,UIntX2>& a, const uint_x4<UInt,UIntX2>& b)
+{
+ return (a.d.v23 < b.d.v23)
+ || ((a.d.v23 == b.d.v23) && (a.d.v01 < b.d.v01));
+}
+
+template <typename UInt, typename UIntX2>
+bool operator>(const uint_x4<UInt,UIntX2>& a, const uint_x4<UInt,UIntX2>& b)
+{
+ return operator<(b,a);
+}
+
+template <typename UInt, typename UIntX2>
+bool operator<=(const uint_x4<UInt,UIntX2>& a, const uint_x4<UInt,UIntX2>& b)
+{
+ return !(operator<(b,a));
+}
+
+template <typename UInt, typename UIntX2>
+bool operator>=(const uint_x4<UInt,UIntX2>& a, const uint_x4<UInt,UIntX2>& b)
+{
+ return !(operator<(a,b));
+}
+
+
+
+template <typename UInt, typename UIntX2>
+uint_x4<UInt,UIntX2> operator<<(const uint_x4<UInt,UIntX2>& v,
+ const bitcount_t shift)
+{
+ uint_x4<UInt,UIntX2> r = {0U, 0U, 0U, 0U};
+ const bitcount_t bits = uint_x4<UInt,UIntX2>::UINT_BITS;
+ const bitcount_t bitmask = bits - 1;
+ const bitcount_t shiftdiv = shift / bits;
+ const bitcount_t shiftmod = shift & bitmask;
+
+ if (shiftmod) {
+ UInt carryover = 0;
+#if PCG_LITTLE_ENDIAN
+ for (uint8_t out = shiftdiv, in = 0; out < 4; ++out, ++in) {
+#else
+ for (uint8_t out = 4-shiftdiv, in = 4; out != 0; /* dec in loop */) {
+ --out, --in;
+#endif
+ r.wa[out] = (v.wa[in] << shiftmod) | carryover;
+ carryover = (v.wa[in] >> (bits - shiftmod));
+ }
+ } else {
+#if PCG_LITTLE_ENDIAN
+ for (uint8_t out = shiftdiv, in = 0; out < 4; ++out, ++in) {
+#else
+ for (uint8_t out = 4-shiftdiv, in = 4; out != 0; /* dec in loop */) {
+ --out, --in;
+#endif
+ r.wa[out] = v.wa[in];
+ }
+ }
+
+ return r;
+}
+
+template <typename UInt, typename UIntX2>
+uint_x4<UInt,UIntX2> operator>>(const uint_x4<UInt,UIntX2>& v,
+ const bitcount_t shift)
+{
+ uint_x4<UInt,UIntX2> r = {0U, 0U, 0U, 0U};
+ const bitcount_t bits = uint_x4<UInt,UIntX2>::UINT_BITS;
+ const bitcount_t bitmask = bits - 1;
+ const bitcount_t shiftdiv = shift / bits;
+ const bitcount_t shiftmod = shift & bitmask;
+
+ if (shiftmod) {
+ UInt carryover = 0;
+#if PCG_LITTLE_ENDIAN
+ for (uint8_t out = 4-shiftdiv, in = 4; out != 0; /* dec in loop */) {
+ --out, --in;
+#else
+ for (uint8_t out = shiftdiv, in = 0; out < 4; ++out, ++in) {
+#endif
+ r.wa[out] = (v.wa[in] >> shiftmod) | carryover;
+ carryover = (v.wa[in] << (bits - shiftmod));
+ }
+ } else {
+#if PCG_LITTLE_ENDIAN
+ for (uint8_t out = 4-shiftdiv, in = 4; out != 0; /* dec in loop */) {
+ --out, --in;
+#else
+ for (uint8_t out = shiftdiv, in = 0; out < 4; ++out, ++in) {
+#endif
+ r.wa[out] = v.wa[in];
+ }
+ }
+
+ return r;
+}
+
+#if PCG_64BIT_SPECIALIZATIONS
+template <typename UInt32>
+uint_x4<UInt32,uint64_t> operator<<(const uint_x4<UInt32,uint64_t>& v,
+ const bitcount_t shift)
+{
+ constexpr bitcount_t bits2 = uint_x4<UInt32,uint64_t>::UINT_BITS * 2;
+
+ if (shift >= bits2) {
+ return {v.d.v01 << (shift-bits2), uint64_t(0u)};
+ } else {
+ return {shift ? (v.d.v23 << shift) | (v.d.v01 >> (bits2-shift))
+ : v.d.v23,
+ v.d.v01 << shift};
+ }
+}
+
+template <typename UInt32>
+uint_x4<UInt32,uint64_t> operator>>(const uint_x4<UInt32,uint64_t>& v,
+ const bitcount_t shift)
+{
+ constexpr bitcount_t bits2 = uint_x4<UInt32,uint64_t>::UINT_BITS * 2;
+
+ if (shift >= bits2) {
+ return {uint64_t(0u), v.d.v23 >> (shift-bits2)};
+ } else {
+ return {v.d.v23 >> shift,
+ shift ? (v.d.v01 >> shift) | (v.d.v23 << (bits2-shift))
+ : v.d.v01};
+ }
+}
+#endif
+
+} // namespace pcg_extras
+} // namespace arrow_vendored
+
+#endif // PCG_UINT128_HPP_INCLUDED
diff --git a/src/arrow/cpp/src/arrow/vendored/portable-snippets/README.md b/src/arrow/cpp/src/arrow/vendored/portable-snippets/README.md
new file mode 100644
index 000000000..9c67b7baa
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/portable-snippets/README.md
@@ -0,0 +1,10 @@
+<!---
+Each source file contains a preamble explaining the license situation
+for that file, which takes priority over this file. With the
+exception of some code pulled in from other repositories (such as
+µnit, an MIT-licensed project which is used for testing), the code is
+public domain, released using the CC0 1.0 Universal dedication.
+-->
+
+The files in this directory are vendored from portable-snippets
+git changeset f596f8b0a4b8a6ea1166c2361a5cb7e6f802c5ea.
diff --git a/src/arrow/cpp/src/arrow/vendored/portable-snippets/safe-math.h b/src/arrow/cpp/src/arrow/vendored/portable-snippets/safe-math.h
new file mode 100644
index 000000000..7f6426ac7
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/portable-snippets/safe-math.h
@@ -0,0 +1,1072 @@
+/* Overflow-safe math functions
+ * Portable Snippets - https://github.com/nemequ/portable-snippets
+ * Created by Evan Nemerson <evan@nemerson.com>
+ *
+ * To the extent possible under law, the authors have waived all
+ * copyright and related or neighboring rights to this code. For
+ * details, see the Creative Commons Zero 1.0 Universal license at
+ * https://creativecommons.org/publicdomain/zero/1.0/
+ */
+
+#if !defined(PSNIP_SAFE_H)
+#define PSNIP_SAFE_H
+
+#if !defined(PSNIP_SAFE_FORCE_PORTABLE)
+# if defined(__has_builtin)
+# if __has_builtin(__builtin_add_overflow) && !defined(__ibmxl__)
+# define PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW
+# endif
+# elif defined(__GNUC__) && (__GNUC__ >= 5) && !defined(__INTEL_COMPILER)
+# define PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW
+# endif
+# if defined(__has_include)
+# if __has_include(<intsafe.h>)
+# define PSNIP_SAFE_HAVE_INTSAFE_H
+# endif
+# elif defined(_WIN32)
+# define PSNIP_SAFE_HAVE_INTSAFE_H
+# endif
+#endif /* !defined(PSNIP_SAFE_FORCE_PORTABLE) */
+
+#if defined(__GNUC__)
+# define PSNIP_SAFE_LIKELY(expr) __builtin_expect(!!(expr), 1)
+# define PSNIP_SAFE_UNLIKELY(expr) __builtin_expect(!!(expr), 0)
+#else
+# define PSNIP_SAFE_LIKELY(expr) !!(expr)
+# define PSNIP_SAFE_UNLIKELY(expr) !!(expr)
+#endif /* defined(__GNUC__) */
+
+#if !defined(PSNIP_SAFE_STATIC_INLINE)
+# if defined(__GNUC__)
+# define PSNIP_SAFE__COMPILER_ATTRIBUTES __attribute__((__unused__))
+# else
+# define PSNIP_SAFE__COMPILER_ATTRIBUTES
+# endif
+
+# if defined(HEDLEY_INLINE)
+# define PSNIP_SAFE__INLINE HEDLEY_INLINE
+# elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L
+# define PSNIP_SAFE__INLINE inline
+# elif defined(__GNUC_STDC_INLINE__)
+# define PSNIP_SAFE__INLINE __inline__
+# elif defined(_MSC_VER) && _MSC_VER >= 1200
+# define PSNIP_SAFE__INLINE __inline
+# else
+# define PSNIP_SAFE__INLINE
+# endif
+
+# define PSNIP_SAFE__FUNCTION PSNIP_SAFE__COMPILER_ATTRIBUTES static PSNIP_SAFE__INLINE
+#endif
+
+// !defined(__cplusplus) added for Solaris support
+#if !defined(__cplusplus) && defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L
+# define psnip_safe_bool _Bool
+#else
+# define psnip_safe_bool int
+#endif
+
+#if !defined(PSNIP_SAFE_NO_FIXED)
+/* For maximum portability include the exact-int module from
+ portable snippets. */
+# if \
+ !defined(psnip_int64_t) || !defined(psnip_uint64_t) || \
+ !defined(psnip_int32_t) || !defined(psnip_uint32_t) || \
+ !defined(psnip_int16_t) || !defined(psnip_uint16_t) || \
+ !defined(psnip_int8_t) || !defined(psnip_uint8_t)
+# include <stdint.h>
+# if !defined(psnip_int64_t)
+# define psnip_int64_t int64_t
+# endif
+# if !defined(psnip_uint64_t)
+# define psnip_uint64_t uint64_t
+# endif
+# if !defined(psnip_int32_t)
+# define psnip_int32_t int32_t
+# endif
+# if !defined(psnip_uint32_t)
+# define psnip_uint32_t uint32_t
+# endif
+# if !defined(psnip_int16_t)
+# define psnip_int16_t int16_t
+# endif
+# if !defined(psnip_uint16_t)
+# define psnip_uint16_t uint16_t
+# endif
+# if !defined(psnip_int8_t)
+# define psnip_int8_t int8_t
+# endif
+# if !defined(psnip_uint8_t)
+# define psnip_uint8_t uint8_t
+# endif
+# endif
+#endif /* !defined(PSNIP_SAFE_NO_FIXED) */
+#include <limits.h>
+#include <stdlib.h>
+
+#if !defined(PSNIP_SAFE_SIZE_MAX)
+# if defined(__SIZE_MAX__)
+# define PSNIP_SAFE_SIZE_MAX __SIZE_MAX__
+# elif defined(PSNIP_EXACT_INT_HAVE_STDINT)
+# include <stdint.h>
+# endif
+#endif
+
+#if defined(PSNIP_SAFE_SIZE_MAX)
+# define PSNIP_SAFE__SIZE_MAX_RT PSNIP_SAFE_SIZE_MAX
+#else
+# define PSNIP_SAFE__SIZE_MAX_RT (~((size_t) 0))
+#endif
+
+#if defined(PSNIP_SAFE_HAVE_INTSAFE_H)
+/* In VS 10, stdint.h and intsafe.h both define (U)INTN_MIN/MAX, which
+ triggers warning C4005 (level 1). */
+# if defined(_MSC_VER) && (_MSC_VER == 1600)
+# pragma warning(push)
+# pragma warning(disable:4005)
+# endif
+# include <intsafe.h>
+# if defined(_MSC_VER) && (_MSC_VER == 1600)
+# pragma warning(pop)
+# endif
+#endif /* defined(PSNIP_SAFE_HAVE_INTSAFE_H) */
+
+/* If there is a type larger than the one we're concerned with it's
+ * likely much faster to simply promote the operands, perform the
+ * requested operation, verify that the result falls within the
+ * original type, then cast the result back to the original type. */
+
+#if !defined(PSNIP_SAFE_NO_PROMOTIONS)
+
+#define PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, op_name, op) \
+ PSNIP_SAFE__FUNCTION psnip_safe_##name##_larger \
+ psnip_safe_larger_##name##_##op_name (T a, T b) { \
+ return ((psnip_safe_##name##_larger) a) op ((psnip_safe_##name##_larger) b); \
+ }
+
+#define PSNIP_SAFE_DEFINE_LARGER_UNARY_OP(T, name, op_name, op) \
+ PSNIP_SAFE__FUNCTION psnip_safe_##name##_larger \
+ psnip_safe_larger_##name##_##op_name (T value) { \
+ return (op ((psnip_safe_##name##_larger) value)); \
+ }
+
+#define PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(T, name) \
+ PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, add, +) \
+ PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, sub, -) \
+ PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, mul, *) \
+ PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, div, /) \
+ PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, mod, %) \
+ PSNIP_SAFE_DEFINE_LARGER_UNARY_OP (T, name, neg, -)
+
+#define PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(T, name) \
+ PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, add, +) \
+ PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, sub, -) \
+ PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, mul, *) \
+ PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, div, /) \
+ PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, mod, %)
+
+#define PSNIP_SAFE_IS_LARGER(ORIG_MAX, DEST_MAX) ((DEST_MAX / ORIG_MAX) >= ORIG_MAX)
+
+#if defined(__GNUC__) && ((__GNUC__ >= 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) && defined(__SIZEOF_INT128__) && !defined(__ibmxl__)
+#define PSNIP_SAFE_HAVE_128
+typedef __int128 psnip_safe_int128_t;
+typedef unsigned __int128 psnip_safe_uint128_t;
+#endif /* defined(__GNUC__) */
+
+#if !defined(PSNIP_SAFE_NO_FIXED)
+#define PSNIP_SAFE_HAVE_INT8_LARGER
+#define PSNIP_SAFE_HAVE_UINT8_LARGER
+typedef psnip_int16_t psnip_safe_int8_larger;
+typedef psnip_uint16_t psnip_safe_uint8_larger;
+
+#define PSNIP_SAFE_HAVE_INT16_LARGER
+typedef psnip_int32_t psnip_safe_int16_larger;
+typedef psnip_uint32_t psnip_safe_uint16_larger;
+
+#define PSNIP_SAFE_HAVE_INT32_LARGER
+typedef psnip_int64_t psnip_safe_int32_larger;
+typedef psnip_uint64_t psnip_safe_uint32_larger;
+
+#if defined(PSNIP_SAFE_HAVE_128)
+#define PSNIP_SAFE_HAVE_INT64_LARGER
+typedef psnip_safe_int128_t psnip_safe_int64_larger;
+typedef psnip_safe_uint128_t psnip_safe_uint64_larger;
+#endif /* defined(PSNIP_SAFE_HAVE_128) */
+#endif /* !defined(PSNIP_SAFE_NO_FIXED) */
+
+#define PSNIP_SAFE_HAVE_LARGER_SCHAR
+#if PSNIP_SAFE_IS_LARGER(SCHAR_MAX, SHRT_MAX)
+typedef short psnip_safe_schar_larger;
+#elif PSNIP_SAFE_IS_LARGER(SCHAR_MAX, INT_MAX)
+typedef int psnip_safe_schar_larger;
+#elif PSNIP_SAFE_IS_LARGER(SCHAR_MAX, LONG_MAX)
+typedef long psnip_safe_schar_larger;
+#elif PSNIP_SAFE_IS_LARGER(SCHAR_MAX, LLONG_MAX)
+typedef long long psnip_safe_schar_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(SCHAR_MAX, 0x7fff)
+typedef psnip_int16_t psnip_safe_schar_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(SCHAR_MAX, 0x7fffffffLL)
+typedef psnip_int32_t psnip_safe_schar_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(SCHAR_MAX, 0x7fffffffffffffffLL)
+typedef psnip_int64_t psnip_safe_schar_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (SCHAR_MAX <= 0x7fffffffffffffffLL)
+typedef psnip_safe_int128_t psnip_safe_schar_larger;
+#else
+#undef PSNIP_SAFE_HAVE_LARGER_SCHAR
+#endif
+
+#define PSNIP_SAFE_HAVE_LARGER_UCHAR
+#if PSNIP_SAFE_IS_LARGER(UCHAR_MAX, USHRT_MAX)
+typedef unsigned short psnip_safe_uchar_larger;
+#elif PSNIP_SAFE_IS_LARGER(UCHAR_MAX, UINT_MAX)
+typedef unsigned int psnip_safe_uchar_larger;
+#elif PSNIP_SAFE_IS_LARGER(UCHAR_MAX, ULONG_MAX)
+typedef unsigned long psnip_safe_uchar_larger;
+#elif PSNIP_SAFE_IS_LARGER(UCHAR_MAX, ULLONG_MAX)
+typedef unsigned long long psnip_safe_uchar_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(UCHAR_MAX, 0xffffU)
+typedef psnip_uint16_t psnip_safe_uchar_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(UCHAR_MAX, 0xffffffffUL)
+typedef psnip_uint32_t psnip_safe_uchar_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(UCHAR_MAX, 0xffffffffffffffffULL)
+typedef psnip_uint64_t psnip_safe_uchar_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (UCHAR_MAX <= 0xffffffffffffffffULL)
+typedef psnip_safe_uint128_t psnip_safe_uchar_larger;
+#else
+#undef PSNIP_SAFE_HAVE_LARGER_UCHAR
+#endif
+
+#if CHAR_MIN == 0 && defined(PSNIP_SAFE_HAVE_LARGER_UCHAR)
+#define PSNIP_SAFE_HAVE_LARGER_CHAR
+typedef psnip_safe_uchar_larger psnip_safe_char_larger;
+#elif CHAR_MIN < 0 && defined(PSNIP_SAFE_HAVE_LARGER_SCHAR)
+#define PSNIP_SAFE_HAVE_LARGER_CHAR
+typedef psnip_safe_schar_larger psnip_safe_char_larger;
+#endif
+
+#define PSNIP_SAFE_HAVE_LARGER_SHRT
+#if PSNIP_SAFE_IS_LARGER(SHRT_MAX, INT_MAX)
+typedef int psnip_safe_short_larger;
+#elif PSNIP_SAFE_IS_LARGER(SHRT_MAX, LONG_MAX)
+typedef long psnip_safe_short_larger;
+#elif PSNIP_SAFE_IS_LARGER(SHRT_MAX, LLONG_MAX)
+typedef long long psnip_safe_short_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(SHRT_MAX, 0x7fff)
+typedef psnip_int16_t psnip_safe_short_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(SHRT_MAX, 0x7fffffffLL)
+typedef psnip_int32_t psnip_safe_short_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(SHRT_MAX, 0x7fffffffffffffffLL)
+typedef psnip_int64_t psnip_safe_short_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (SHRT_MAX <= 0x7fffffffffffffffLL)
+typedef psnip_safe_int128_t psnip_safe_short_larger;
+#else
+#undef PSNIP_SAFE_HAVE_LARGER_SHRT
+#endif
+
+#define PSNIP_SAFE_HAVE_LARGER_USHRT
+#if PSNIP_SAFE_IS_LARGER(USHRT_MAX, UINT_MAX)
+typedef unsigned int psnip_safe_ushort_larger;
+#elif PSNIP_SAFE_IS_LARGER(USHRT_MAX, ULONG_MAX)
+typedef unsigned long psnip_safe_ushort_larger;
+#elif PSNIP_SAFE_IS_LARGER(USHRT_MAX, ULLONG_MAX)
+typedef unsigned long long psnip_safe_ushort_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(USHRT_MAX, 0xffff)
+typedef psnip_uint16_t psnip_safe_ushort_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(USHRT_MAX, 0xffffffffUL)
+typedef psnip_uint32_t psnip_safe_ushort_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(USHRT_MAX, 0xffffffffffffffffULL)
+typedef psnip_uint64_t psnip_safe_ushort_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (USHRT_MAX <= 0xffffffffffffffffULL)
+typedef psnip_safe_uint128_t psnip_safe_ushort_larger;
+#else
+#undef PSNIP_SAFE_HAVE_LARGER_USHRT
+#endif
+
+#define PSNIP_SAFE_HAVE_LARGER_INT
+#if PSNIP_SAFE_IS_LARGER(INT_MAX, LONG_MAX)
+typedef long psnip_safe_int_larger;
+#elif PSNIP_SAFE_IS_LARGER(INT_MAX, LLONG_MAX)
+typedef long long psnip_safe_int_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(INT_MAX, 0x7fff)
+typedef psnip_int16_t psnip_safe_int_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(INT_MAX, 0x7fffffffLL)
+typedef psnip_int32_t psnip_safe_int_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(INT_MAX, 0x7fffffffffffffffLL)
+typedef psnip_int64_t psnip_safe_int_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (INT_MAX <= 0x7fffffffffffffffLL)
+typedef psnip_safe_int128_t psnip_safe_int_larger;
+#else
+#undef PSNIP_SAFE_HAVE_LARGER_INT
+#endif
+
+#define PSNIP_SAFE_HAVE_LARGER_UINT
+#if PSNIP_SAFE_IS_LARGER(UINT_MAX, ULONG_MAX)
+typedef unsigned long psnip_safe_uint_larger;
+#elif PSNIP_SAFE_IS_LARGER(UINT_MAX, ULLONG_MAX)
+typedef unsigned long long psnip_safe_uint_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(UINT_MAX, 0xffff)
+typedef psnip_uint16_t psnip_safe_uint_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(UINT_MAX, 0xffffffffUL)
+typedef psnip_uint32_t psnip_safe_uint_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(UINT_MAX, 0xffffffffffffffffULL)
+typedef psnip_uint64_t psnip_safe_uint_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (UINT_MAX <= 0xffffffffffffffffULL)
+typedef psnip_safe_uint128_t psnip_safe_uint_larger;
+#else
+#undef PSNIP_SAFE_HAVE_LARGER_UINT
+#endif
+
+#define PSNIP_SAFE_HAVE_LARGER_LONG
+#if PSNIP_SAFE_IS_LARGER(LONG_MAX, LLONG_MAX)
+typedef long long psnip_safe_long_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(LONG_MAX, 0x7fff)
+typedef psnip_int16_t psnip_safe_long_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(LONG_MAX, 0x7fffffffLL)
+typedef psnip_int32_t psnip_safe_long_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(LONG_MAX, 0x7fffffffffffffffLL)
+typedef psnip_int64_t psnip_safe_long_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (LONG_MAX <= 0x7fffffffffffffffLL)
+typedef psnip_safe_int128_t psnip_safe_long_larger;
+#else
+#undef PSNIP_SAFE_HAVE_LARGER_LONG
+#endif
+
+#define PSNIP_SAFE_HAVE_LARGER_ULONG
+#if PSNIP_SAFE_IS_LARGER(ULONG_MAX, ULLONG_MAX)
+typedef unsigned long long psnip_safe_ulong_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(ULONG_MAX, 0xffff)
+typedef psnip_uint16_t psnip_safe_ulong_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(ULONG_MAX, 0xffffffffUL)
+typedef psnip_uint32_t psnip_safe_ulong_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(ULONG_MAX, 0xffffffffffffffffULL)
+typedef psnip_uint64_t psnip_safe_ulong_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (ULONG_MAX <= 0xffffffffffffffffULL)
+typedef psnip_safe_uint128_t psnip_safe_ulong_larger;
+#else
+#undef PSNIP_SAFE_HAVE_LARGER_ULONG
+#endif
+
+#define PSNIP_SAFE_HAVE_LARGER_LLONG
+#if !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(LLONG_MAX, 0x7fff)
+typedef psnip_int16_t psnip_safe_llong_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(LLONG_MAX, 0x7fffffffLL)
+typedef psnip_int32_t psnip_safe_llong_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(LLONG_MAX, 0x7fffffffffffffffLL)
+typedef psnip_int64_t psnip_safe_llong_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (LLONG_MAX <= 0x7fffffffffffffffLL)
+typedef psnip_safe_int128_t psnip_safe_llong_larger;
+#else
+#undef PSNIP_SAFE_HAVE_LARGER_LLONG
+#endif
+
+#define PSNIP_SAFE_HAVE_LARGER_ULLONG
+#if !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(ULLONG_MAX, 0xffff)
+typedef psnip_uint16_t psnip_safe_ullong_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(ULLONG_MAX, 0xffffffffUL)
+typedef psnip_uint32_t psnip_safe_ullong_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(ULLONG_MAX, 0xffffffffffffffffULL)
+typedef psnip_uint64_t psnip_safe_ullong_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (ULLONG_MAX <= 0xffffffffffffffffULL)
+typedef psnip_safe_uint128_t psnip_safe_ullong_larger;
+#else
+#undef PSNIP_SAFE_HAVE_LARGER_ULLONG
+#endif
+
+#if defined(PSNIP_SAFE_SIZE_MAX)
+#define PSNIP_SAFE_HAVE_LARGER_SIZE
+#if PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, USHRT_MAX)
+typedef unsigned short psnip_safe_size_larger;
+#elif PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, UINT_MAX)
+typedef unsigned int psnip_safe_size_larger;
+#elif PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, ULONG_MAX)
+typedef unsigned long psnip_safe_size_larger;
+#elif PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, ULLONG_MAX)
+typedef unsigned long long psnip_safe_size_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, 0xffff)
+typedef psnip_uint16_t psnip_safe_size_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, 0xffffffffUL)
+typedef psnip_uint32_t psnip_safe_size_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, 0xffffffffffffffffULL)
+typedef psnip_uint64_t psnip_safe_size_larger;
+#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (PSNIP_SAFE_SIZE_MAX <= 0xffffffffffffffffULL)
+typedef psnip_safe_uint128_t psnip_safe_size_larger;
+#else
+#undef PSNIP_SAFE_HAVE_LARGER_SIZE
+#endif
+#endif
+
+#if defined(PSNIP_SAFE_HAVE_LARGER_SCHAR)
+PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(signed char, schar)
+#endif
+
+#if defined(PSNIP_SAFE_HAVE_LARGER_UCHAR)
+PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(unsigned char, uchar)
+#endif
+
+#if defined(PSNIP_SAFE_HAVE_LARGER_CHAR)
+#if CHAR_MIN == 0
+PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(char, char)
+#else
+PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(char, char)
+#endif
+#endif
+
+#if defined(PSNIP_SAFE_HAVE_LARGER_SHORT)
+PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(short, short)
+#endif
+
+#if defined(PSNIP_SAFE_HAVE_LARGER_USHORT)
+PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(unsigned short, ushort)
+#endif
+
+#if defined(PSNIP_SAFE_HAVE_LARGER_INT)
+PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(int, int)
+#endif
+
+#if defined(PSNIP_SAFE_HAVE_LARGER_UINT)
+PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(unsigned int, uint)
+#endif
+
+#if defined(PSNIP_SAFE_HAVE_LARGER_LONG)
+PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(long, long)
+#endif
+
+#if defined(PSNIP_SAFE_HAVE_LARGER_ULONG)
+PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(unsigned long, ulong)
+#endif
+
+#if defined(PSNIP_SAFE_HAVE_LARGER_LLONG)
+PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(long long, llong)
+#endif
+
+#if defined(PSNIP_SAFE_HAVE_LARGER_ULLONG)
+PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(unsigned long long, ullong)
+#endif
+
+#if defined(PSNIP_SAFE_HAVE_LARGER_SIZE)
+PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(size_t, size)
+#endif
+
+#if !defined(PSNIP_SAFE_NO_FIXED)
+PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(psnip_int8_t, int8)
+PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(psnip_uint8_t, uint8)
+PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(psnip_int16_t, int16)
+PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(psnip_uint16_t, uint16)
+PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(psnip_int32_t, int32)
+PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(psnip_uint32_t, uint32)
+#if defined(PSNIP_SAFE_HAVE_128)
+PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(psnip_int64_t, int64)
+PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(psnip_uint64_t, uint64)
+#endif
+#endif
+
+#endif /* !defined(PSNIP_SAFE_NO_PROMOTIONS) */
+
+#define PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(T, name, op_name) \
+ PSNIP_SAFE__FUNCTION psnip_safe_bool \
+ psnip_safe_##name##_##op_name(T* res, T a, T b) { \
+ return !__builtin_##op_name##_overflow(a, b, res); \
+ }
+
+#define PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(T, name, op_name, min, max) \
+ PSNIP_SAFE__FUNCTION psnip_safe_bool \
+ psnip_safe_##name##_##op_name(T* res, T a, T b) { \
+ const psnip_safe_##name##_larger r = psnip_safe_larger_##name##_##op_name(a, b); \
+ *res = (T) r; \
+ return (r >= min) && (r <= max); \
+ }
+
+#define PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(T, name, op_name, max) \
+ PSNIP_SAFE__FUNCTION psnip_safe_bool \
+ psnip_safe_##name##_##op_name(T* res, T a, T b) { \
+ const psnip_safe_##name##_larger r = psnip_safe_larger_##name##_##op_name(a, b); \
+ *res = (T) r; \
+ return (r <= max); \
+ }
+
+#define PSNIP_SAFE_DEFINE_SIGNED_ADD(T, name, min, max) \
+ PSNIP_SAFE__FUNCTION psnip_safe_bool \
+ psnip_safe_##name##_add (T* res, T a, T b) { \
+ psnip_safe_bool r = !( ((b > 0) && (a > (max - b))) || \
+ ((b < 0) && (a < (min - b))) ); \
+ if(PSNIP_SAFE_LIKELY(r)) \
+ *res = a + b; \
+ return r; \
+ }
+
+#define PSNIP_SAFE_DEFINE_UNSIGNED_ADD(T, name, max) \
+ PSNIP_SAFE__FUNCTION psnip_safe_bool \
+ psnip_safe_##name##_add (T* res, T a, T b) { \
+ *res = (T) (a + b); \
+ return !PSNIP_SAFE_UNLIKELY((b > 0) && (a > (max - b))); \
+ }
+
+#define PSNIP_SAFE_DEFINE_SIGNED_SUB(T, name, min, max) \
+ PSNIP_SAFE__FUNCTION psnip_safe_bool \
+ psnip_safe_##name##_sub (T* res, T a, T b) { \
+ psnip_safe_bool r = !((b > 0 && a < (min + b)) || \
+ (b < 0 && a > (max + b))); \
+ if(PSNIP_SAFE_LIKELY(r)) \
+ *res = a - b; \
+ return r; \
+ }
+
+#define PSNIP_SAFE_DEFINE_UNSIGNED_SUB(T, name, max) \
+ PSNIP_SAFE__FUNCTION psnip_safe_bool \
+ psnip_safe_##name##_sub (T* res, T a, T b) { \
+ *res = a - b; \
+ return !PSNIP_SAFE_UNLIKELY(b > a); \
+ }
+
+#define PSNIP_SAFE_DEFINE_SIGNED_MUL(T, name, min, max) \
+ PSNIP_SAFE__FUNCTION psnip_safe_bool \
+ psnip_safe_##name##_mul (T* res, T a, T b) { \
+ psnip_safe_bool r = 1; \
+ if (a > 0) { \
+ if (b > 0) { \
+ if (a > (max / b)) { \
+ r = 0; \
+ } \
+ } else { \
+ if (b < (min / a)) { \
+ r = 0; \
+ } \
+ } \
+ } else { \
+ if (b > 0) { \
+ if (a < (min / b)) { \
+ r = 0; \
+ } \
+ } else { \
+ if ( (a != 0) && (b < (max / a))) { \
+ r = 0; \
+ } \
+ } \
+ } \
+ if(PSNIP_SAFE_LIKELY(r)) \
+ *res = a * b; \
+ return r; \
+ }
+
+#define PSNIP_SAFE_DEFINE_UNSIGNED_MUL(T, name, max) \
+ PSNIP_SAFE__FUNCTION psnip_safe_bool \
+ psnip_safe_##name##_mul (T* res, T a, T b) { \
+ *res = (T) (a * b); \
+ return !PSNIP_SAFE_UNLIKELY((a > 0) && (b > 0) && (a > (max / b))); \
+ }
+
+#define PSNIP_SAFE_DEFINE_SIGNED_DIV(T, name, min, max) \
+ PSNIP_SAFE__FUNCTION psnip_safe_bool \
+ psnip_safe_##name##_div (T* res, T a, T b) { \
+ if (PSNIP_SAFE_UNLIKELY(b == 0)) { \
+ *res = 0; \
+ return 0; \
+ } else if (PSNIP_SAFE_UNLIKELY(a == min && b == -1)) { \
+ *res = min; \
+ return 0; \
+ } else { \
+ *res = (T) (a / b); \
+ return 1; \
+ } \
+ }
+
+#define PSNIP_SAFE_DEFINE_UNSIGNED_DIV(T, name, max) \
+ PSNIP_SAFE__FUNCTION psnip_safe_bool \
+ psnip_safe_##name##_div (T* res, T a, T b) { \
+ if (PSNIP_SAFE_UNLIKELY(b == 0)) { \
+ *res = 0; \
+ return 0; \
+ } else { \
+ *res = a / b; \
+ return 1; \
+ } \
+ }
+
+#define PSNIP_SAFE_DEFINE_SIGNED_MOD(T, name, min, max) \
+ PSNIP_SAFE__FUNCTION psnip_safe_bool \
+ psnip_safe_##name##_mod (T* res, T a, T b) { \
+ if (PSNIP_SAFE_UNLIKELY(b == 0)) { \
+ *res = 0; \
+ return 0; \
+ } else if (PSNIP_SAFE_UNLIKELY(a == min && b == -1)) { \
+ *res = min; \
+ return 0; \
+ } else { \
+ *res = (T) (a % b); \
+ return 1; \
+ } \
+ }
+
+#define PSNIP_SAFE_DEFINE_UNSIGNED_MOD(T, name, max) \
+ PSNIP_SAFE__FUNCTION psnip_safe_bool \
+ psnip_safe_##name##_mod (T* res, T a, T b) { \
+ if (PSNIP_SAFE_UNLIKELY(b == 0)) { \
+ *res = 0; \
+ return 0; \
+ } else { \
+ *res = a % b; \
+ return 1; \
+ } \
+ }
+
+#define PSNIP_SAFE_DEFINE_SIGNED_NEG(T, name, min, max) \
+ PSNIP_SAFE__FUNCTION psnip_safe_bool \
+ psnip_safe_##name##_neg (T* res, T value) { \
+ psnip_safe_bool r = value != min; \
+ *res = PSNIP_SAFE_LIKELY(r) ? -value : max; \
+ return r; \
+ }
+
+#define PSNIP_SAFE_DEFINE_INTSAFE(T, name, op, isf) \
+ PSNIP_SAFE__FUNCTION psnip_safe_bool \
+ psnip_safe_##name##_##op (T* res, T a, T b) { \
+ return isf(a, b, res) == S_OK; \
+ }
+
+#if CHAR_MIN == 0
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(char, char, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(char, char, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(char, char, mul)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_CHAR)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(char, char, add, CHAR_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(char, char, sub, CHAR_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(char, char, mul, CHAR_MAX)
+#else
+PSNIP_SAFE_DEFINE_UNSIGNED_ADD(char, char, CHAR_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_SUB(char, char, CHAR_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_MUL(char, char, CHAR_MAX)
+#endif
+PSNIP_SAFE_DEFINE_UNSIGNED_DIV(char, char, CHAR_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_MOD(char, char, CHAR_MAX)
+#else /* CHAR_MIN != 0 */
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(char, char, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(char, char, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(char, char, mul)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_CHAR)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(char, char, add, CHAR_MIN, CHAR_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(char, char, sub, CHAR_MIN, CHAR_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(char, char, mul, CHAR_MIN, CHAR_MAX)
+#else
+PSNIP_SAFE_DEFINE_SIGNED_ADD(char, char, CHAR_MIN, CHAR_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_SUB(char, char, CHAR_MIN, CHAR_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_MUL(char, char, CHAR_MIN, CHAR_MAX)
+#endif
+PSNIP_SAFE_DEFINE_SIGNED_DIV(char, char, CHAR_MIN, CHAR_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_MOD(char, char, CHAR_MIN, CHAR_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_NEG(char, char, CHAR_MIN, CHAR_MAX)
+#endif
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(signed char, schar, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(signed char, schar, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(signed char, schar, mul)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_SCHAR)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(signed char, schar, add, SCHAR_MIN, SCHAR_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(signed char, schar, sub, SCHAR_MIN, SCHAR_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(signed char, schar, mul, SCHAR_MIN, SCHAR_MAX)
+#else
+PSNIP_SAFE_DEFINE_SIGNED_ADD(signed char, schar, SCHAR_MIN, SCHAR_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_SUB(signed char, schar, SCHAR_MIN, SCHAR_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_MUL(signed char, schar, SCHAR_MIN, SCHAR_MAX)
+#endif
+PSNIP_SAFE_DEFINE_SIGNED_DIV(signed char, schar, SCHAR_MIN, SCHAR_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_MOD(signed char, schar, SCHAR_MIN, SCHAR_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_NEG(signed char, schar, SCHAR_MIN, SCHAR_MAX)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned char, uchar, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned char, uchar, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned char, uchar, mul)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_UCHAR)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned char, uchar, add, UCHAR_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned char, uchar, sub, UCHAR_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned char, uchar, mul, UCHAR_MAX)
+#else
+PSNIP_SAFE_DEFINE_UNSIGNED_ADD(unsigned char, uchar, UCHAR_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_SUB(unsigned char, uchar, UCHAR_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_MUL(unsigned char, uchar, UCHAR_MAX)
+#endif
+PSNIP_SAFE_DEFINE_UNSIGNED_DIV(unsigned char, uchar, UCHAR_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_MOD(unsigned char, uchar, UCHAR_MAX)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(short, short, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(short, short, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(short, short, mul)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_SHORT)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(short, short, add, SHRT_MIN, SHRT_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(short, short, sub, SHRT_MIN, SHRT_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(short, short, mul, SHRT_MIN, SHRT_MAX)
+#else
+PSNIP_SAFE_DEFINE_SIGNED_ADD(short, short, SHRT_MIN, SHRT_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_SUB(short, short, SHRT_MIN, SHRT_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_MUL(short, short, SHRT_MIN, SHRT_MAX)
+#endif
+PSNIP_SAFE_DEFINE_SIGNED_DIV(short, short, SHRT_MIN, SHRT_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_MOD(short, short, SHRT_MIN, SHRT_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_NEG(short, short, SHRT_MIN, SHRT_MAX)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned short, ushort, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned short, ushort, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned short, ushort, mul)
+#elif defined(PSNIP_SAFE_HAVE_INTSAFE_H)
+PSNIP_SAFE_DEFINE_INTSAFE(unsigned short, ushort, add, UShortAdd)
+PSNIP_SAFE_DEFINE_INTSAFE(unsigned short, ushort, sub, UShortSub)
+PSNIP_SAFE_DEFINE_INTSAFE(unsigned short, ushort, mul, UShortMult)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_USHORT)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned short, ushort, add, USHRT_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned short, ushort, sub, USHRT_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned short, ushort, mul, USHRT_MAX)
+#else
+PSNIP_SAFE_DEFINE_UNSIGNED_ADD(unsigned short, ushort, USHRT_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_SUB(unsigned short, ushort, USHRT_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_MUL(unsigned short, ushort, USHRT_MAX)
+#endif
+PSNIP_SAFE_DEFINE_UNSIGNED_DIV(unsigned short, ushort, USHRT_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_MOD(unsigned short, ushort, USHRT_MAX)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(int, int, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(int, int, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(int, int, mul)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_INT)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(int, int, add, INT_MIN, INT_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(int, int, sub, INT_MIN, INT_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(int, int, mul, INT_MIN, INT_MAX)
+#else
+PSNIP_SAFE_DEFINE_SIGNED_ADD(int, int, INT_MIN, INT_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_SUB(int, int, INT_MIN, INT_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_MUL(int, int, INT_MIN, INT_MAX)
+#endif
+PSNIP_SAFE_DEFINE_SIGNED_DIV(int, int, INT_MIN, INT_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_MOD(int, int, INT_MIN, INT_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_NEG(int, int, INT_MIN, INT_MAX)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned int, uint, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned int, uint, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned int, uint, mul)
+#elif defined(PSNIP_SAFE_HAVE_INTSAFE_H)
+PSNIP_SAFE_DEFINE_INTSAFE(unsigned int, uint, add, UIntAdd)
+PSNIP_SAFE_DEFINE_INTSAFE(unsigned int, uint, sub, UIntSub)
+PSNIP_SAFE_DEFINE_INTSAFE(unsigned int, uint, mul, UIntMult)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_UINT)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned int, uint, add, UINT_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned int, uint, sub, UINT_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned int, uint, mul, UINT_MAX)
+#else
+PSNIP_SAFE_DEFINE_UNSIGNED_ADD(unsigned int, uint, UINT_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_SUB(unsigned int, uint, UINT_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_MUL(unsigned int, uint, UINT_MAX)
+#endif
+PSNIP_SAFE_DEFINE_UNSIGNED_DIV(unsigned int, uint, UINT_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_MOD(unsigned int, uint, UINT_MAX)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(long, long, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(long, long, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(long, long, mul)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_LONG)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(long, long, add, LONG_MIN, LONG_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(long, long, sub, LONG_MIN, LONG_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(long, long, mul, LONG_MIN, LONG_MAX)
+#else
+PSNIP_SAFE_DEFINE_SIGNED_ADD(long, long, LONG_MIN, LONG_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_SUB(long, long, LONG_MIN, LONG_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_MUL(long, long, LONG_MIN, LONG_MAX)
+#endif
+PSNIP_SAFE_DEFINE_SIGNED_DIV(long, long, LONG_MIN, LONG_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_MOD(long, long, LONG_MIN, LONG_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_NEG(long, long, LONG_MIN, LONG_MAX)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned long, ulong, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned long, ulong, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned long, ulong, mul)
+#elif defined(PSNIP_SAFE_HAVE_INTSAFE_H)
+PSNIP_SAFE_DEFINE_INTSAFE(unsigned long, ulong, add, ULongAdd)
+PSNIP_SAFE_DEFINE_INTSAFE(unsigned long, ulong, sub, ULongSub)
+PSNIP_SAFE_DEFINE_INTSAFE(unsigned long, ulong, mul, ULongMult)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_ULONG)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned long, ulong, add, ULONG_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned long, ulong, sub, ULONG_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned long, ulong, mul, ULONG_MAX)
+#else
+PSNIP_SAFE_DEFINE_UNSIGNED_ADD(unsigned long, ulong, ULONG_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_SUB(unsigned long, ulong, ULONG_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_MUL(unsigned long, ulong, ULONG_MAX)
+#endif
+PSNIP_SAFE_DEFINE_UNSIGNED_DIV(unsigned long, ulong, ULONG_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_MOD(unsigned long, ulong, ULONG_MAX)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(long long, llong, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(long long, llong, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(long long, llong, mul)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_LLONG)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(long long, llong, add, LLONG_MIN, LLONG_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(long long, llong, sub, LLONG_MIN, LLONG_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(long long, llong, mul, LLONG_MIN, LLONG_MAX)
+#else
+PSNIP_SAFE_DEFINE_SIGNED_ADD(long long, llong, LLONG_MIN, LLONG_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_SUB(long long, llong, LLONG_MIN, LLONG_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_MUL(long long, llong, LLONG_MIN, LLONG_MAX)
+#endif
+PSNIP_SAFE_DEFINE_SIGNED_DIV(long long, llong, LLONG_MIN, LLONG_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_MOD(long long, llong, LLONG_MIN, LLONG_MAX)
+PSNIP_SAFE_DEFINE_SIGNED_NEG(long long, llong, LLONG_MIN, LLONG_MAX)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned long long, ullong, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned long long, ullong, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned long long, ullong, mul)
+#elif defined(PSNIP_SAFE_HAVE_INTSAFE_H)
+PSNIP_SAFE_DEFINE_INTSAFE(unsigned long long, ullong, add, ULongLongAdd)
+PSNIP_SAFE_DEFINE_INTSAFE(unsigned long long, ullong, sub, ULongLongSub)
+PSNIP_SAFE_DEFINE_INTSAFE(unsigned long long, ullong, mul, ULongLongMult)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_ULLONG)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned long long, ullong, add, ULLONG_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned long long, ullong, sub, ULLONG_MAX)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned long long, ullong, mul, ULLONG_MAX)
+#else
+PSNIP_SAFE_DEFINE_UNSIGNED_ADD(unsigned long long, ullong, ULLONG_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_SUB(unsigned long long, ullong, ULLONG_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_MUL(unsigned long long, ullong, ULLONG_MAX)
+#endif
+PSNIP_SAFE_DEFINE_UNSIGNED_DIV(unsigned long long, ullong, ULLONG_MAX)
+PSNIP_SAFE_DEFINE_UNSIGNED_MOD(unsigned long long, ullong, ULLONG_MAX)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(size_t, size, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(size_t, size, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(size_t, size, mul)
+#elif defined(PSNIP_SAFE_HAVE_INTSAFE_H)
+PSNIP_SAFE_DEFINE_INTSAFE(size_t, size, add, SizeTAdd)
+PSNIP_SAFE_DEFINE_INTSAFE(size_t, size, sub, SizeTSub)
+PSNIP_SAFE_DEFINE_INTSAFE(size_t, size, mul, SizeTMult)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_SIZE)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(size_t, size, add, PSNIP_SAFE__SIZE_MAX_RT)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(size_t, size, sub, PSNIP_SAFE__SIZE_MAX_RT)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(size_t, size, mul, PSNIP_SAFE__SIZE_MAX_RT)
+#else
+PSNIP_SAFE_DEFINE_UNSIGNED_ADD(size_t, size, PSNIP_SAFE__SIZE_MAX_RT)
+PSNIP_SAFE_DEFINE_UNSIGNED_SUB(size_t, size, PSNIP_SAFE__SIZE_MAX_RT)
+PSNIP_SAFE_DEFINE_UNSIGNED_MUL(size_t, size, PSNIP_SAFE__SIZE_MAX_RT)
+#endif
+PSNIP_SAFE_DEFINE_UNSIGNED_DIV(size_t, size, PSNIP_SAFE__SIZE_MAX_RT)
+PSNIP_SAFE_DEFINE_UNSIGNED_MOD(size_t, size, PSNIP_SAFE__SIZE_MAX_RT)
+
+#if !defined(PSNIP_SAFE_NO_FIXED)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int8_t, int8, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int8_t, int8, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int8_t, int8, mul)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_INT8)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int8_t, int8, add, (-0x7fLL-1), 0x7f)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int8_t, int8, sub, (-0x7fLL-1), 0x7f)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int8_t, int8, mul, (-0x7fLL-1), 0x7f)
+#else
+PSNIP_SAFE_DEFINE_SIGNED_ADD(psnip_int8_t, int8, (-0x7fLL-1), 0x7f)
+PSNIP_SAFE_DEFINE_SIGNED_SUB(psnip_int8_t, int8, (-0x7fLL-1), 0x7f)
+PSNIP_SAFE_DEFINE_SIGNED_MUL(psnip_int8_t, int8, (-0x7fLL-1), 0x7f)
+#endif
+PSNIP_SAFE_DEFINE_SIGNED_DIV(psnip_int8_t, int8, (-0x7fLL-1), 0x7f)
+PSNIP_SAFE_DEFINE_SIGNED_MOD(psnip_int8_t, int8, (-0x7fLL-1), 0x7f)
+PSNIP_SAFE_DEFINE_SIGNED_NEG(psnip_int8_t, int8, (-0x7fLL-1), 0x7f)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint8_t, uint8, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint8_t, uint8, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint8_t, uint8, mul)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_UINT8)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint8_t, uint8, add, 0xff)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint8_t, uint8, sub, 0xff)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint8_t, uint8, mul, 0xff)
+#else
+PSNIP_SAFE_DEFINE_UNSIGNED_ADD(psnip_uint8_t, uint8, 0xff)
+PSNIP_SAFE_DEFINE_UNSIGNED_SUB(psnip_uint8_t, uint8, 0xff)
+PSNIP_SAFE_DEFINE_UNSIGNED_MUL(psnip_uint8_t, uint8, 0xff)
+#endif
+PSNIP_SAFE_DEFINE_UNSIGNED_DIV(psnip_uint8_t, uint8, 0xff)
+PSNIP_SAFE_DEFINE_UNSIGNED_MOD(psnip_uint8_t, uint8, 0xff)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int16_t, int16, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int16_t, int16, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int16_t, int16, mul)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_INT16)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int16_t, int16, add, (-32767-1), 0x7fff)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int16_t, int16, sub, (-32767-1), 0x7fff)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int16_t, int16, mul, (-32767-1), 0x7fff)
+#else
+PSNIP_SAFE_DEFINE_SIGNED_ADD(psnip_int16_t, int16, (-32767-1), 0x7fff)
+PSNIP_SAFE_DEFINE_SIGNED_SUB(psnip_int16_t, int16, (-32767-1), 0x7fff)
+PSNIP_SAFE_DEFINE_SIGNED_MUL(psnip_int16_t, int16, (-32767-1), 0x7fff)
+#endif
+PSNIP_SAFE_DEFINE_SIGNED_DIV(psnip_int16_t, int16, (-32767-1), 0x7fff)
+PSNIP_SAFE_DEFINE_SIGNED_MOD(psnip_int16_t, int16, (-32767-1), 0x7fff)
+PSNIP_SAFE_DEFINE_SIGNED_NEG(psnip_int16_t, int16, (-32767-1), 0x7fff)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint16_t, uint16, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint16_t, uint16, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint16_t, uint16, mul)
+#elif defined(PSNIP_SAFE_HAVE_INTSAFE_H) && defined(_WIN32)
+PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint16_t, uint16, add, UShortAdd)
+PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint16_t, uint16, sub, UShortSub)
+PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint16_t, uint16, mul, UShortMult)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_UINT16)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint16_t, uint16, add, 0xffff)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint16_t, uint16, sub, 0xffff)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint16_t, uint16, mul, 0xffff)
+#else
+PSNIP_SAFE_DEFINE_UNSIGNED_ADD(psnip_uint16_t, uint16, 0xffff)
+PSNIP_SAFE_DEFINE_UNSIGNED_SUB(psnip_uint16_t, uint16, 0xffff)
+PSNIP_SAFE_DEFINE_UNSIGNED_MUL(psnip_uint16_t, uint16, 0xffff)
+#endif
+PSNIP_SAFE_DEFINE_UNSIGNED_DIV(psnip_uint16_t, uint16, 0xffff)
+PSNIP_SAFE_DEFINE_UNSIGNED_MOD(psnip_uint16_t, uint16, 0xffff)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int32_t, int32, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int32_t, int32, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int32_t, int32, mul)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_INT32)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int32_t, int32, add, (-0x7fffffffLL-1), 0x7fffffffLL)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int32_t, int32, sub, (-0x7fffffffLL-1), 0x7fffffffLL)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int32_t, int32, mul, (-0x7fffffffLL-1), 0x7fffffffLL)
+#else
+PSNIP_SAFE_DEFINE_SIGNED_ADD(psnip_int32_t, int32, (-0x7fffffffLL-1), 0x7fffffffLL)
+PSNIP_SAFE_DEFINE_SIGNED_SUB(psnip_int32_t, int32, (-0x7fffffffLL-1), 0x7fffffffLL)
+PSNIP_SAFE_DEFINE_SIGNED_MUL(psnip_int32_t, int32, (-0x7fffffffLL-1), 0x7fffffffLL)
+#endif
+PSNIP_SAFE_DEFINE_SIGNED_DIV(psnip_int32_t, int32, (-0x7fffffffLL-1), 0x7fffffffLL)
+PSNIP_SAFE_DEFINE_SIGNED_MOD(psnip_int32_t, int32, (-0x7fffffffLL-1), 0x7fffffffLL)
+PSNIP_SAFE_DEFINE_SIGNED_NEG(psnip_int32_t, int32, (-0x7fffffffLL-1), 0x7fffffffLL)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint32_t, uint32, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint32_t, uint32, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint32_t, uint32, mul)
+#elif defined(PSNIP_SAFE_HAVE_INTSAFE_H) && defined(_WIN32)
+PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint32_t, uint32, add, UIntAdd)
+PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint32_t, uint32, sub, UIntSub)
+PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint32_t, uint32, mul, UIntMult)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_UINT32)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint32_t, uint32, add, 0xffffffffUL)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint32_t, uint32, sub, 0xffffffffUL)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint32_t, uint32, mul, 0xffffffffUL)
+#else
+PSNIP_SAFE_DEFINE_UNSIGNED_ADD(psnip_uint32_t, uint32, 0xffffffffUL)
+PSNIP_SAFE_DEFINE_UNSIGNED_SUB(psnip_uint32_t, uint32, 0xffffffffUL)
+PSNIP_SAFE_DEFINE_UNSIGNED_MUL(psnip_uint32_t, uint32, 0xffffffffUL)
+#endif
+PSNIP_SAFE_DEFINE_UNSIGNED_DIV(psnip_uint32_t, uint32, 0xffffffffUL)
+PSNIP_SAFE_DEFINE_UNSIGNED_MOD(psnip_uint32_t, uint32, 0xffffffffUL)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int64_t, int64, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int64_t, int64, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int64_t, int64, mul)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_INT64)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int64_t, int64, add, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int64_t, int64, sub, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL)
+PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int64_t, int64, mul, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL)
+#else
+PSNIP_SAFE_DEFINE_SIGNED_ADD(psnip_int64_t, int64, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL)
+PSNIP_SAFE_DEFINE_SIGNED_SUB(psnip_int64_t, int64, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL)
+PSNIP_SAFE_DEFINE_SIGNED_MUL(psnip_int64_t, int64, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL)
+#endif
+PSNIP_SAFE_DEFINE_SIGNED_DIV(psnip_int64_t, int64, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL)
+PSNIP_SAFE_DEFINE_SIGNED_MOD(psnip_int64_t, int64, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL)
+PSNIP_SAFE_DEFINE_SIGNED_NEG(psnip_int64_t, int64, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint64_t, uint64, add)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint64_t, uint64, sub)
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint64_t, uint64, mul)
+#elif defined(PSNIP_SAFE_HAVE_INTSAFE_H) && defined(_WIN32)
+PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint64_t, uint64, add, ULongLongAdd)
+PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint64_t, uint64, sub, ULongLongSub)
+PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint64_t, uint64, mul, ULongLongMult)
+#elif defined(PSNIP_SAFE_HAVE_LARGER_UINT64)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint64_t, uint64, add, 0xffffffffffffffffULL)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint64_t, uint64, sub, 0xffffffffffffffffULL)
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint64_t, uint64, mul, 0xffffffffffffffffULL)
+#else
+PSNIP_SAFE_DEFINE_UNSIGNED_ADD(psnip_uint64_t, uint64, 0xffffffffffffffffULL)
+PSNIP_SAFE_DEFINE_UNSIGNED_SUB(psnip_uint64_t, uint64, 0xffffffffffffffffULL)
+PSNIP_SAFE_DEFINE_UNSIGNED_MUL(psnip_uint64_t, uint64, 0xffffffffffffffffULL)
+#endif
+PSNIP_SAFE_DEFINE_UNSIGNED_DIV(psnip_uint64_t, uint64, 0xffffffffffffffffULL)
+PSNIP_SAFE_DEFINE_UNSIGNED_MOD(psnip_uint64_t, uint64, 0xffffffffffffffffULL)
+
+#endif /* !defined(PSNIP_SAFE_NO_FIXED) */
+
+#define PSNIP_SAFE_C11_GENERIC_SELECTION(res, op) \
+ _Generic((*res), \
+ char: psnip_safe_char_##op, \
+ unsigned char: psnip_safe_uchar_##op, \
+ short: psnip_safe_short_##op, \
+ unsigned short: psnip_safe_ushort_##op, \
+ int: psnip_safe_int_##op, \
+ unsigned int: psnip_safe_uint_##op, \
+ long: psnip_safe_long_##op, \
+ unsigned long: psnip_safe_ulong_##op, \
+ long long: psnip_safe_llong_##op, \
+ unsigned long long: psnip_safe_ullong_##op)
+
+#define PSNIP_SAFE_C11_GENERIC_BINARY_OP(op, res, a, b) \
+ PSNIP_SAFE_C11_GENERIC_SELECTION(res, op)(res, a, b)
+#define PSNIP_SAFE_C11_GENERIC_UNARY_OP(op, res, v) \
+ PSNIP_SAFE_C11_GENERIC_SELECTION(res, op)(res, v)
+
+#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW)
+#define psnip_safe_add(res, a, b) !__builtin_add_overflow(a, b, res)
+#define psnip_safe_sub(res, a, b) !__builtin_sub_overflow(a, b, res)
+#define psnip_safe_mul(res, a, b) !__builtin_mul_overflow(a, b, res)
+#define psnip_safe_div(res, a, b) !__builtin_div_overflow(a, b, res)
+#define psnip_safe_mod(res, a, b) !__builtin_mod_overflow(a, b, res)
+#define psnip_safe_neg(res, v) PSNIP_SAFE_C11_GENERIC_UNARY_OP (neg, res, v)
+
+#elif defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)
+/* The are no fixed-length or size selections because they cause an
+ * error about _Generic specifying two compatible types. Hopefully
+ * this doesn't cause problems on exotic platforms, but if it does
+ * please let me know and I'll try to figure something out. */
+
+#define psnip_safe_add(res, a, b) PSNIP_SAFE_C11_GENERIC_BINARY_OP(add, res, a, b)
+#define psnip_safe_sub(res, a, b) PSNIP_SAFE_C11_GENERIC_BINARY_OP(sub, res, a, b)
+#define psnip_safe_mul(res, a, b) PSNIP_SAFE_C11_GENERIC_BINARY_OP(mul, res, a, b)
+#define psnip_safe_div(res, a, b) PSNIP_SAFE_C11_GENERIC_BINARY_OP(div, res, a, b)
+#define psnip_safe_mod(res, a, b) PSNIP_SAFE_C11_GENERIC_BINARY_OP(mod, res, a, b)
+#define psnip_safe_neg(res, v) PSNIP_SAFE_C11_GENERIC_UNARY_OP (neg, res, v)
+#endif
+
+#if !defined(PSNIP_SAFE_HAVE_BUILTINS) && (defined(PSNIP_SAFE_EMULATE_NATIVE) || defined(PSNIP_BUILTIN_EMULATE_NATIVE))
+# define __builtin_sadd_overflow(a, b, res) (!psnip_safe_int_add(res, a, b))
+# define __builtin_saddl_overflow(a, b, res) (!psnip_safe_long_add(res, a, b))
+# define __builtin_saddll_overflow(a, b, res) (!psnip_safe_llong_add(res, a, b))
+# define __builtin_uadd_overflow(a, b, res) (!psnip_safe_uint_add(res, a, b))
+# define __builtin_uaddl_overflow(a, b, res) (!psnip_safe_ulong_add(res, a, b))
+# define __builtin_uaddll_overflow(a, b, res) (!psnip_safe_ullong_add(res, a, b))
+
+# define __builtin_ssub_overflow(a, b, res) (!psnip_safe_int_sub(res, a, b))
+# define __builtin_ssubl_overflow(a, b, res) (!psnip_safe_long_sub(res, a, b))
+# define __builtin_ssubll_overflow(a, b, res) (!psnip_safe_llong_sub(res, a, b))
+# define __builtin_usub_overflow(a, b, res) (!psnip_safe_uint_sub(res, a, b))
+# define __builtin_usubl_overflow(a, b, res) (!psnip_safe_ulong_sub(res, a, b))
+# define __builtin_usubll_overflow(a, b, res) (!psnip_safe_ullong_sub(res, a, b))
+
+# define __builtin_smul_overflow(a, b, res) (!psnip_safe_int_mul(res, a, b))
+# define __builtin_smull_overflow(a, b, res) (!psnip_safe_long_mul(res, a, b))
+# define __builtin_smulll_overflow(a, b, res) (!psnip_safe_llong_mul(res, a, b))
+# define __builtin_umul_overflow(a, b, res) (!psnip_safe_uint_mul(res, a, b))
+# define __builtin_umull_overflow(a, b, res) (!psnip_safe_ulong_mul(res, a, b))
+# define __builtin_umulll_overflow(a, b, res) (!psnip_safe_ullong_mul(res, a, b))
+#endif
+
+#endif /* !defined(PSNIP_SAFE_H) */
diff --git a/src/arrow/cpp/src/arrow/vendored/string_view.hpp b/src/arrow/cpp/src/arrow/vendored/string_view.hpp
new file mode 100644
index 000000000..a2d556785
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/string_view.hpp
@@ -0,0 +1,1531 @@
+// Vendored from git changeset v1.4.0
+
+// Copyright 2017-2020 by Martin Moene
+//
+// string-view lite, a C++17-like string_view for C++98 and later.
+// For more information see https://github.com/martinmoene/string-view-lite
+//
+// Distributed under the Boost Software License, Version 1.0.
+// (See accompanying file LICENSE.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
+
+#pragma once
+
+#ifndef NONSTD_SV_LITE_H_INCLUDED
+#define NONSTD_SV_LITE_H_INCLUDED
+
+#define string_view_lite_MAJOR 1
+#define string_view_lite_MINOR 4
+#define string_view_lite_PATCH 0
+
+#define string_view_lite_VERSION nssv_STRINGIFY(string_view_lite_MAJOR) "." nssv_STRINGIFY(string_view_lite_MINOR) "." nssv_STRINGIFY(string_view_lite_PATCH)
+
+#define nssv_STRINGIFY( x ) nssv_STRINGIFY_( x )
+#define nssv_STRINGIFY_( x ) #x
+
+// string-view lite configuration:
+
+#define nssv_STRING_VIEW_DEFAULT 0
+#define nssv_STRING_VIEW_NONSTD 1
+#define nssv_STRING_VIEW_STD 2
+
+#if !defined( nssv_CONFIG_SELECT_STRING_VIEW )
+# define nssv_CONFIG_SELECT_STRING_VIEW ( nssv_HAVE_STD_STRING_VIEW ? nssv_STRING_VIEW_STD : nssv_STRING_VIEW_NONSTD )
+#endif
+
+#if defined( nssv_CONFIG_SELECT_STD_STRING_VIEW ) || defined( nssv_CONFIG_SELECT_NONSTD_STRING_VIEW )
+# error nssv_CONFIG_SELECT_STD_STRING_VIEW and nssv_CONFIG_SELECT_NONSTD_STRING_VIEW are deprecated and removed, please use nssv_CONFIG_SELECT_STRING_VIEW=nssv_STRING_VIEW_...
+#endif
+
+#ifndef nssv_CONFIG_STD_SV_OPERATOR
+# define nssv_CONFIG_STD_SV_OPERATOR 0
+#endif
+
+#ifndef nssv_CONFIG_USR_SV_OPERATOR
+# define nssv_CONFIG_USR_SV_OPERATOR 1
+#endif
+
+#ifdef nssv_CONFIG_CONVERSION_STD_STRING
+# define nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS nssv_CONFIG_CONVERSION_STD_STRING
+# define nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS nssv_CONFIG_CONVERSION_STD_STRING
+#endif
+
+#ifndef nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS
+# define nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS 1
+#endif
+
+#ifndef nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS
+# define nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS 1
+#endif
+
+// Control presence of exception handling (try and auto discover):
+
+#ifndef nssv_CONFIG_NO_EXCEPTIONS
+# if defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)
+# define nssv_CONFIG_NO_EXCEPTIONS 0
+# else
+# define nssv_CONFIG_NO_EXCEPTIONS 1
+# endif
+#endif
+
+// C++ language version detection (C++20 is speculative):
+// Note: VC14.0/1900 (VS2015) lacks too much from C++14.
+
+#ifndef nssv_CPLUSPLUS
+# if defined(_MSVC_LANG ) && !defined(__clang__)
+# define nssv_CPLUSPLUS (_MSC_VER == 1900 ? 201103L : _MSVC_LANG )
+# else
+# define nssv_CPLUSPLUS __cplusplus
+# endif
+#endif
+
+#define nssv_CPP98_OR_GREATER ( nssv_CPLUSPLUS >= 199711L )
+#define nssv_CPP11_OR_GREATER ( nssv_CPLUSPLUS >= 201103L )
+#define nssv_CPP11_OR_GREATER_ ( nssv_CPLUSPLUS >= 201103L )
+#define nssv_CPP14_OR_GREATER ( nssv_CPLUSPLUS >= 201402L )
+#define nssv_CPP17_OR_GREATER ( nssv_CPLUSPLUS >= 201703L )
+#define nssv_CPP20_OR_GREATER ( nssv_CPLUSPLUS >= 202000L )
+
+// use C++17 std::string_view if available and requested:
+
+#if nssv_CPP17_OR_GREATER && defined(__has_include )
+# if __has_include( <string_view> )
+# define nssv_HAVE_STD_STRING_VIEW 1
+# else
+# define nssv_HAVE_STD_STRING_VIEW 0
+# endif
+#else
+# define nssv_HAVE_STD_STRING_VIEW 0
+#endif
+
+#define nssv_USES_STD_STRING_VIEW ( (nssv_CONFIG_SELECT_STRING_VIEW == nssv_STRING_VIEW_STD) || ((nssv_CONFIG_SELECT_STRING_VIEW == nssv_STRING_VIEW_DEFAULT) && nssv_HAVE_STD_STRING_VIEW) )
+
+#define nssv_HAVE_STARTS_WITH ( nssv_CPP20_OR_GREATER || !nssv_USES_STD_STRING_VIEW )
+#define nssv_HAVE_ENDS_WITH nssv_HAVE_STARTS_WITH
+
+//
+// Use C++17 std::string_view:
+//
+
+#if nssv_USES_STD_STRING_VIEW
+
+#include <string_view>
+
+// Extensions for std::string:
+
+#if nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS
+
+namespace nonstd {
+
+template< class CharT, class Traits, class Allocator = std::allocator<CharT> >
+std::basic_string<CharT, Traits, Allocator>
+to_string( std::basic_string_view<CharT, Traits> v, Allocator const & a = Allocator() )
+{
+ return std::basic_string<CharT,Traits, Allocator>( v.begin(), v.end(), a );
+}
+
+template< class CharT, class Traits, class Allocator >
+std::basic_string_view<CharT, Traits>
+to_string_view( std::basic_string<CharT, Traits, Allocator> const & s )
+{
+ return std::basic_string_view<CharT, Traits>( s.data(), s.size() );
+}
+
+// Literal operators sv and _sv:
+
+#if nssv_CONFIG_STD_SV_OPERATOR
+
+using namespace std::literals::string_view_literals;
+
+#endif
+
+#if nssv_CONFIG_USR_SV_OPERATOR
+
+inline namespace literals {
+inline namespace string_view_literals {
+
+
+constexpr std::string_view operator "" _sv( const char* str, size_t len ) noexcept // (1)
+{
+ return std::string_view{ str, len };
+}
+
+constexpr std::u16string_view operator "" _sv( const char16_t* str, size_t len ) noexcept // (2)
+{
+ return std::u16string_view{ str, len };
+}
+
+constexpr std::u32string_view operator "" _sv( const char32_t* str, size_t len ) noexcept // (3)
+{
+ return std::u32string_view{ str, len };
+}
+
+constexpr std::wstring_view operator "" _sv( const wchar_t* str, size_t len ) noexcept // (4)
+{
+ return std::wstring_view{ str, len };
+}
+
+}} // namespace literals::string_view_literals
+
+#endif // nssv_CONFIG_USR_SV_OPERATOR
+
+} // namespace nonstd
+
+#endif // nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS
+
+namespace nonstd {
+
+using std::string_view;
+using std::wstring_view;
+using std::u16string_view;
+using std::u32string_view;
+using std::basic_string_view;
+
+// literal "sv" and "_sv", see above
+
+using std::operator==;
+using std::operator!=;
+using std::operator<;
+using std::operator<=;
+using std::operator>;
+using std::operator>=;
+
+using std::operator<<;
+
+} // namespace nonstd
+
+#else // nssv_HAVE_STD_STRING_VIEW
+
+//
+// Before C++17: use string_view lite:
+//
+
+// Compiler versions:
+//
+// MSVC++ 6.0 _MSC_VER == 1200 nssv_COMPILER_MSVC_VERSION == 60 (Visual Studio 6.0)
+// MSVC++ 7.0 _MSC_VER == 1300 nssv_COMPILER_MSVC_VERSION == 70 (Visual Studio .NET 2002)
+// MSVC++ 7.1 _MSC_VER == 1310 nssv_COMPILER_MSVC_VERSION == 71 (Visual Studio .NET 2003)
+// MSVC++ 8.0 _MSC_VER == 1400 nssv_COMPILER_MSVC_VERSION == 80 (Visual Studio 2005)
+// MSVC++ 9.0 _MSC_VER == 1500 nssv_COMPILER_MSVC_VERSION == 90 (Visual Studio 2008)
+// MSVC++ 10.0 _MSC_VER == 1600 nssv_COMPILER_MSVC_VERSION == 100 (Visual Studio 2010)
+// MSVC++ 11.0 _MSC_VER == 1700 nssv_COMPILER_MSVC_VERSION == 110 (Visual Studio 2012)
+// MSVC++ 12.0 _MSC_VER == 1800 nssv_COMPILER_MSVC_VERSION == 120 (Visual Studio 2013)
+// MSVC++ 14.0 _MSC_VER == 1900 nssv_COMPILER_MSVC_VERSION == 140 (Visual Studio 2015)
+// MSVC++ 14.1 _MSC_VER >= 1910 nssv_COMPILER_MSVC_VERSION == 141 (Visual Studio 2017)
+// MSVC++ 14.2 _MSC_VER >= 1920 nssv_COMPILER_MSVC_VERSION == 142 (Visual Studio 2019)
+
+#if defined(_MSC_VER ) && !defined(__clang__)
+# define nssv_COMPILER_MSVC_VER (_MSC_VER )
+# define nssv_COMPILER_MSVC_VERSION (_MSC_VER / 10 - 10 * ( 5 + (_MSC_VER < 1900 ) ) )
+#else
+# define nssv_COMPILER_MSVC_VER 0
+# define nssv_COMPILER_MSVC_VERSION 0
+#endif
+
+#define nssv_COMPILER_VERSION( major, minor, patch ) ( 10 * ( 10 * (major) + (minor) ) + (patch) )
+
+#if defined(__clang__)
+# define nssv_COMPILER_CLANG_VERSION nssv_COMPILER_VERSION(__clang_major__, __clang_minor__, __clang_patchlevel__)
+#else
+# define nssv_COMPILER_CLANG_VERSION 0
+#endif
+
+#if defined(__GNUC__) && !defined(__clang__)
+# define nssv_COMPILER_GNUC_VERSION nssv_COMPILER_VERSION(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__)
+#else
+# define nssv_COMPILER_GNUC_VERSION 0
+#endif
+
+// half-open range [lo..hi):
+#define nssv_BETWEEN( v, lo, hi ) ( (lo) <= (v) && (v) < (hi) )
+
+// Presence of language and library features:
+
+#ifdef _HAS_CPP0X
+# define nssv_HAS_CPP0X _HAS_CPP0X
+#else
+# define nssv_HAS_CPP0X 0
+#endif
+
+// Unless defined otherwise below, consider VC14 as C++11 for variant-lite:
+
+#if nssv_COMPILER_MSVC_VER >= 1900
+# undef nssv_CPP11_OR_GREATER
+# define nssv_CPP11_OR_GREATER 1
+#endif
+
+#define nssv_CPP11_90 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1500)
+#define nssv_CPP11_100 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1600)
+#define nssv_CPP11_110 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1700)
+#define nssv_CPP11_120 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1800)
+#define nssv_CPP11_140 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1900)
+#define nssv_CPP11_141 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1910)
+
+#define nssv_CPP14_000 (nssv_CPP14_OR_GREATER)
+#define nssv_CPP17_000 (nssv_CPP17_OR_GREATER)
+
+// Presence of C++11 language features:
+
+#define nssv_HAVE_CONSTEXPR_11 nssv_CPP11_140
+#define nssv_HAVE_EXPLICIT_CONVERSION nssv_CPP11_140
+#define nssv_HAVE_INLINE_NAMESPACE nssv_CPP11_140
+#define nssv_HAVE_NOEXCEPT nssv_CPP11_140
+#define nssv_HAVE_NULLPTR nssv_CPP11_100
+#define nssv_HAVE_REF_QUALIFIER nssv_CPP11_140
+#define nssv_HAVE_UNICODE_LITERALS nssv_CPP11_140
+#define nssv_HAVE_USER_DEFINED_LITERALS nssv_CPP11_140
+#define nssv_HAVE_WCHAR16_T nssv_CPP11_100
+#define nssv_HAVE_WCHAR32_T nssv_CPP11_100
+
+#if ! ( ( nssv_CPP11_OR_GREATER && nssv_COMPILER_CLANG_VERSION ) || nssv_BETWEEN( nssv_COMPILER_CLANG_VERSION, 300, 400 ) )
+# define nssv_HAVE_STD_DEFINED_LITERALS nssv_CPP11_140
+#else
+# define nssv_HAVE_STD_DEFINED_LITERALS 0
+#endif
+
+// Presence of C++14 language features:
+
+#define nssv_HAVE_CONSTEXPR_14 nssv_CPP14_000
+
+// Presence of C++17 language features:
+
+#define nssv_HAVE_NODISCARD nssv_CPP17_000
+
+// Presence of C++ library features:
+
+#define nssv_HAVE_STD_HASH nssv_CPP11_120
+
+// C++ feature usage:
+
+#if nssv_HAVE_CONSTEXPR_11
+# define nssv_constexpr constexpr
+#else
+# define nssv_constexpr /*constexpr*/
+#endif
+
+#if nssv_HAVE_CONSTEXPR_14
+# define nssv_constexpr14 constexpr
+#else
+# define nssv_constexpr14 /*constexpr*/
+#endif
+
+#if nssv_HAVE_EXPLICIT_CONVERSION
+# define nssv_explicit explicit
+#else
+# define nssv_explicit /*explicit*/
+#endif
+
+#if nssv_HAVE_INLINE_NAMESPACE
+# define nssv_inline_ns inline
+#else
+# define nssv_inline_ns /*inline*/
+#endif
+
+#if nssv_HAVE_NOEXCEPT
+# define nssv_noexcept noexcept
+#else
+# define nssv_noexcept /*noexcept*/
+#endif
+
+//#if nssv_HAVE_REF_QUALIFIER
+//# define nssv_ref_qual &
+//# define nssv_refref_qual &&
+//#else
+//# define nssv_ref_qual /*&*/
+//# define nssv_refref_qual /*&&*/
+//#endif
+
+#if nssv_HAVE_NULLPTR
+# define nssv_nullptr nullptr
+#else
+# define nssv_nullptr NULL
+#endif
+
+#if nssv_HAVE_NODISCARD
+# define nssv_nodiscard [[nodiscard]]
+#else
+# define nssv_nodiscard /*[[nodiscard]]*/
+#endif
+
+// Additional includes:
+
+#include <algorithm>
+#include <cassert>
+#include <iterator>
+#include <limits>
+#include <ostream>
+#include <string> // std::char_traits<>
+
+#if ! nssv_CONFIG_NO_EXCEPTIONS
+# include <stdexcept>
+#endif
+
+#if nssv_CPP11_OR_GREATER
+# include <type_traits>
+#endif
+
+// Clang, GNUC, MSVC warning suppression macros:
+
+#if defined(__clang__)
+# pragma clang diagnostic ignored "-Wreserved-user-defined-literal"
+# pragma clang diagnostic push
+# pragma clang diagnostic ignored "-Wuser-defined-literals"
+#elif defined(__GNUC__)
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wliteral-suffix"
+#endif // __clang__
+
+#if nssv_COMPILER_MSVC_VERSION >= 140
+# define nssv_SUPPRESS_MSGSL_WARNING(expr) [[gsl::suppress(expr)]]
+# define nssv_SUPPRESS_MSVC_WARNING(code, descr) __pragma(warning(suppress: code) )
+# define nssv_DISABLE_MSVC_WARNINGS(codes) __pragma(warning(push)) __pragma(warning(disable: codes))
+#else
+# define nssv_SUPPRESS_MSGSL_WARNING(expr)
+# define nssv_SUPPRESS_MSVC_WARNING(code, descr)
+# define nssv_DISABLE_MSVC_WARNINGS(codes)
+#endif
+
+#if defined(__clang__)
+# define nssv_RESTORE_WARNINGS() _Pragma("clang diagnostic pop")
+#elif defined(__GNUC__)
+# define nssv_RESTORE_WARNINGS() _Pragma("GCC diagnostic pop")
+#elif nssv_COMPILER_MSVC_VERSION >= 140
+# define nssv_RESTORE_WARNINGS() __pragma(warning(pop ))
+#else
+# define nssv_RESTORE_WARNINGS()
+#endif
+
+// Suppress the following MSVC (GSL) warnings:
+// - C4455, non-gsl : 'operator ""sv': literal suffix identifiers that do not
+// start with an underscore are reserved
+// - C26472, gsl::t.1 : don't use a static_cast for arithmetic conversions;
+// use brace initialization, gsl::narrow_cast or gsl::narow
+// - C26481: gsl::b.1 : don't use pointer arithmetic. Use span instead
+
+nssv_DISABLE_MSVC_WARNINGS( 4455 26481 26472 )
+//nssv_DISABLE_CLANG_WARNINGS( "-Wuser-defined-literals" )
+//nssv_DISABLE_GNUC_WARNINGS( -Wliteral-suffix )
+
+namespace nonstd { namespace sv_lite {
+
+#if nssv_CPP11_OR_GREATER
+
+namespace detail {
+
+#if nssv_CPP14_OR_GREATER
+
+template< typename CharT >
+inline constexpr std::size_t length( CharT * s, std::size_t result = 0 )
+{
+ CharT * v = s;
+ std::size_t r = result;
+ while ( *v != '\0' ) {
+ ++v;
+ ++r;
+ }
+ return r;
+}
+
+#else // nssv_CPP14_OR_GREATER
+
+// Expect tail call optimization to make length() non-recursive:
+
+template< typename CharT >
+inline constexpr std::size_t length( CharT * s, std::size_t result = 0 )
+{
+ return *s == '\0' ? result : length( s + 1, result + 1 );
+}
+
+#endif // nssv_CPP14_OR_GREATER
+
+} // namespace detail
+
+#endif // nssv_CPP11_OR_GREATER
+
+template
+<
+ class CharT,
+ class Traits = std::char_traits<CharT>
+>
+class basic_string_view;
+
+//
+// basic_string_view:
+//
+
+template
+<
+ class CharT,
+ class Traits /* = std::char_traits<CharT> */
+>
+class basic_string_view
+{
+public:
+ // Member types:
+
+ typedef Traits traits_type;
+ typedef CharT value_type;
+
+ typedef CharT * pointer;
+ typedef CharT const * const_pointer;
+ typedef CharT & reference;
+ typedef CharT const & const_reference;
+
+ typedef const_pointer iterator;
+ typedef const_pointer const_iterator;
+ typedef std::reverse_iterator< const_iterator > reverse_iterator;
+ typedef std::reverse_iterator< const_iterator > const_reverse_iterator;
+
+ typedef std::size_t size_type;
+ typedef std::ptrdiff_t difference_type;
+
+ // 24.4.2.1 Construction and assignment:
+
+ nssv_constexpr basic_string_view() nssv_noexcept
+ : data_( nssv_nullptr )
+ , size_( 0 )
+ {}
+
+#if nssv_CPP11_OR_GREATER
+ nssv_constexpr basic_string_view( basic_string_view const & other ) nssv_noexcept = default;
+#else
+ nssv_constexpr basic_string_view( basic_string_view const & other ) nssv_noexcept
+ : data_( other.data_)
+ , size_( other.size_)
+ {}
+#endif
+
+ nssv_constexpr basic_string_view( CharT const * s, size_type count ) nssv_noexcept // non-standard noexcept
+ : data_( s )
+ , size_( count )
+ {}
+
+ nssv_constexpr basic_string_view( CharT const * s) nssv_noexcept // non-standard noexcept
+ : data_( s )
+#if nssv_CPP17_OR_GREATER
+ , size_( Traits::length(s) )
+#elif nssv_CPP11_OR_GREATER
+ , size_( detail::length(s) )
+#else
+ , size_( Traits::length(s) )
+#endif
+ {}
+
+ // Assignment:
+
+#if nssv_CPP11_OR_GREATER
+ nssv_constexpr14 basic_string_view & operator=( basic_string_view const & other ) nssv_noexcept = default;
+#else
+ nssv_constexpr14 basic_string_view & operator=( basic_string_view const & other ) nssv_noexcept
+ {
+ data_ = other.data_;
+ size_ = other.size_;
+ return *this;
+ }
+#endif
+
+ // 24.4.2.2 Iterator support:
+
+ nssv_constexpr const_iterator begin() const nssv_noexcept { return data_; }
+ nssv_constexpr const_iterator end() const nssv_noexcept { return data_ + size_; }
+
+ nssv_constexpr const_iterator cbegin() const nssv_noexcept { return begin(); }
+ nssv_constexpr const_iterator cend() const nssv_noexcept { return end(); }
+
+ nssv_constexpr const_reverse_iterator rbegin() const nssv_noexcept { return const_reverse_iterator( end() ); }
+ nssv_constexpr const_reverse_iterator rend() const nssv_noexcept { return const_reverse_iterator( begin() ); }
+
+ nssv_constexpr const_reverse_iterator crbegin() const nssv_noexcept { return rbegin(); }
+ nssv_constexpr const_reverse_iterator crend() const nssv_noexcept { return rend(); }
+
+ // 24.4.2.3 Capacity:
+
+ nssv_constexpr size_type size() const nssv_noexcept { return size_; }
+ nssv_constexpr size_type length() const nssv_noexcept { return size_; }
+ nssv_constexpr size_type max_size() const nssv_noexcept { return (std::numeric_limits< size_type >::max)(); }
+
+ // since C++20
+ nssv_nodiscard nssv_constexpr bool empty() const nssv_noexcept
+ {
+ return 0 == size_;
+ }
+
+ // 24.4.2.4 Element access:
+
+ nssv_constexpr const_reference operator[]( size_type pos ) const
+ {
+ return data_at( pos );
+ }
+
+ nssv_constexpr14 const_reference at( size_type pos ) const
+ {
+#if nssv_CONFIG_NO_EXCEPTIONS
+ assert( pos < size() );
+#else
+ if ( pos >= size() )
+ {
+ throw std::out_of_range("nonstd::string_view::at()");
+ }
+#endif
+ return data_at( pos );
+ }
+
+ nssv_constexpr const_reference front() const { return data_at( 0 ); }
+ nssv_constexpr const_reference back() const { return data_at( size() - 1 ); }
+
+ nssv_constexpr const_pointer data() const nssv_noexcept { return data_; }
+
+ // 24.4.2.5 Modifiers:
+
+ nssv_constexpr14 void remove_prefix( size_type n )
+ {
+ assert( n <= size() );
+ data_ += n;
+ size_ -= n;
+ }
+
+ nssv_constexpr14 void remove_suffix( size_type n )
+ {
+ assert( n <= size() );
+ size_ -= n;
+ }
+
+ nssv_constexpr14 void swap( basic_string_view & other ) nssv_noexcept
+ {
+ using std::swap;
+ swap( data_, other.data_ );
+ swap( size_, other.size_ );
+ }
+
+ // 24.4.2.6 String operations:
+
+ size_type copy( CharT * dest, size_type n, size_type pos = 0 ) const
+ {
+#if nssv_CONFIG_NO_EXCEPTIONS
+ assert( pos <= size() );
+#else
+ if ( pos > size() )
+ {
+ throw std::out_of_range("nonstd::string_view::copy()");
+ }
+#endif
+ const size_type rlen = (std::min)( n, size() - pos );
+
+ (void) Traits::copy( dest, data() + pos, rlen );
+
+ return rlen;
+ }
+
+ nssv_constexpr14 basic_string_view substr( size_type pos = 0, size_type n = npos ) const
+ {
+#if nssv_CONFIG_NO_EXCEPTIONS
+ assert( pos <= size() );
+#else
+ if ( pos > size() )
+ {
+ throw std::out_of_range("nonstd::string_view::substr()");
+ }
+#endif
+ return basic_string_view( data() + pos, (std::min)( n, size() - pos ) );
+ }
+
+ // compare(), 6x:
+
+ nssv_constexpr14 int compare( basic_string_view other ) const nssv_noexcept // (1)
+ {
+ if ( const int result = Traits::compare( data(), other.data(), (std::min)( size(), other.size() ) ) )
+ {
+ return result;
+ }
+
+ return size() == other.size() ? 0 : size() < other.size() ? -1 : 1;
+ }
+
+ nssv_constexpr int compare( size_type pos1, size_type n1, basic_string_view other ) const // (2)
+ {
+ return substr( pos1, n1 ).compare( other );
+ }
+
+ nssv_constexpr int compare( size_type pos1, size_type n1, basic_string_view other, size_type pos2, size_type n2 ) const // (3)
+ {
+ return substr( pos1, n1 ).compare( other.substr( pos2, n2 ) );
+ }
+
+ nssv_constexpr int compare( CharT const * s ) const // (4)
+ {
+ return compare( basic_string_view( s ) );
+ }
+
+ nssv_constexpr int compare( size_type pos1, size_type n1, CharT const * s ) const // (5)
+ {
+ return substr( pos1, n1 ).compare( basic_string_view( s ) );
+ }
+
+ nssv_constexpr int compare( size_type pos1, size_type n1, CharT const * s, size_type n2 ) const // (6)
+ {
+ return substr( pos1, n1 ).compare( basic_string_view( s, n2 ) );
+ }
+
+ // 24.4.2.7 Searching:
+
+ // starts_with(), 3x, since C++20:
+
+ nssv_constexpr bool starts_with( basic_string_view v ) const nssv_noexcept // (1)
+ {
+ return size() >= v.size() && compare( 0, v.size(), v ) == 0;
+ }
+
+ nssv_constexpr bool starts_with( CharT c ) const nssv_noexcept // (2)
+ {
+ return starts_with( basic_string_view( &c, 1 ) );
+ }
+
+ nssv_constexpr bool starts_with( CharT const * s ) const // (3)
+ {
+ return starts_with( basic_string_view( s ) );
+ }
+
+ // ends_with(), 3x, since C++20:
+
+ nssv_constexpr bool ends_with( basic_string_view v ) const nssv_noexcept // (1)
+ {
+ return size() >= v.size() && compare( size() - v.size(), npos, v ) == 0;
+ }
+
+ nssv_constexpr bool ends_with( CharT c ) const nssv_noexcept // (2)
+ {
+ return ends_with( basic_string_view( &c, 1 ) );
+ }
+
+ nssv_constexpr bool ends_with( CharT const * s ) const // (3)
+ {
+ return ends_with( basic_string_view( s ) );
+ }
+
+ // find(), 4x:
+
+ nssv_constexpr14 size_type find( basic_string_view v, size_type pos = 0 ) const nssv_noexcept // (1)
+ {
+ return assert( v.size() == 0 || v.data() != nssv_nullptr )
+ , pos >= size()
+ ? npos
+ : to_pos( std::search( cbegin() + pos, cend(), v.cbegin(), v.cend(), Traits::eq ) );
+ }
+
+ nssv_constexpr14 size_type find( CharT c, size_type pos = 0 ) const nssv_noexcept // (2)
+ {
+ return find( basic_string_view( &c, 1 ), pos );
+ }
+
+ nssv_constexpr14 size_type find( CharT const * s, size_type pos, size_type n ) const // (3)
+ {
+ return find( basic_string_view( s, n ), pos );
+ }
+
+ nssv_constexpr14 size_type find( CharT const * s, size_type pos = 0 ) const // (4)
+ {
+ return find( basic_string_view( s ), pos );
+ }
+
+ // rfind(), 4x:
+
+ nssv_constexpr14 size_type rfind( basic_string_view v, size_type pos = npos ) const nssv_noexcept // (1)
+ {
+ if ( size() < v.size() )
+ {
+ return npos;
+ }
+
+ if ( v.empty() )
+ {
+ return (std::min)( size(), pos );
+ }
+
+ const_iterator last = cbegin() + (std::min)( size() - v.size(), pos ) + v.size();
+ const_iterator result = std::find_end( cbegin(), last, v.cbegin(), v.cend(), Traits::eq );
+
+ return result != last ? size_type( result - cbegin() ) : npos;
+ }
+
+ nssv_constexpr14 size_type rfind( CharT c, size_type pos = npos ) const nssv_noexcept // (2)
+ {
+ return rfind( basic_string_view( &c, 1 ), pos );
+ }
+
+ nssv_constexpr14 size_type rfind( CharT const * s, size_type pos, size_type n ) const // (3)
+ {
+ return rfind( basic_string_view( s, n ), pos );
+ }
+
+ nssv_constexpr14 size_type rfind( CharT const * s, size_type pos = npos ) const // (4)
+ {
+ return rfind( basic_string_view( s ), pos );
+ }
+
+ // find_first_of(), 4x:
+
+ nssv_constexpr size_type find_first_of( basic_string_view v, size_type pos = 0 ) const nssv_noexcept // (1)
+ {
+ return pos >= size()
+ ? npos
+ : to_pos( std::find_first_of( cbegin() + pos, cend(), v.cbegin(), v.cend(), Traits::eq ) );
+ }
+
+ nssv_constexpr size_type find_first_of( CharT c, size_type pos = 0 ) const nssv_noexcept // (2)
+ {
+ return find_first_of( basic_string_view( &c, 1 ), pos );
+ }
+
+ nssv_constexpr size_type find_first_of( CharT const * s, size_type pos, size_type n ) const // (3)
+ {
+ return find_first_of( basic_string_view( s, n ), pos );
+ }
+
+ nssv_constexpr size_type find_first_of( CharT const * s, size_type pos = 0 ) const // (4)
+ {
+ return find_first_of( basic_string_view( s ), pos );
+ }
+
+ // find_last_of(), 4x:
+
+ nssv_constexpr size_type find_last_of( basic_string_view v, size_type pos = npos ) const nssv_noexcept // (1)
+ {
+ return empty()
+ ? npos
+ : pos >= size()
+ ? find_last_of( v, size() - 1 )
+ : to_pos( std::find_first_of( const_reverse_iterator( cbegin() + pos + 1 ), crend(), v.cbegin(), v.cend(), Traits::eq ) );
+ }
+
+ nssv_constexpr size_type find_last_of( CharT c, size_type pos = npos ) const nssv_noexcept // (2)
+ {
+ return find_last_of( basic_string_view( &c, 1 ), pos );
+ }
+
+ nssv_constexpr size_type find_last_of( CharT const * s, size_type pos, size_type count ) const // (3)
+ {
+ return find_last_of( basic_string_view( s, count ), pos );
+ }
+
+ nssv_constexpr size_type find_last_of( CharT const * s, size_type pos = npos ) const // (4)
+ {
+ return find_last_of( basic_string_view( s ), pos );
+ }
+
+ // find_first_not_of(), 4x:
+
+ nssv_constexpr size_type find_first_not_of( basic_string_view v, size_type pos = 0 ) const nssv_noexcept // (1)
+ {
+ return pos >= size()
+ ? npos
+ : to_pos( std::find_if( cbegin() + pos, cend(), not_in_view( v ) ) );
+ }
+
+ nssv_constexpr size_type find_first_not_of( CharT c, size_type pos = 0 ) const nssv_noexcept // (2)
+ {
+ return find_first_not_of( basic_string_view( &c, 1 ), pos );
+ }
+
+ nssv_constexpr size_type find_first_not_of( CharT const * s, size_type pos, size_type count ) const // (3)
+ {
+ return find_first_not_of( basic_string_view( s, count ), pos );
+ }
+
+ nssv_constexpr size_type find_first_not_of( CharT const * s, size_type pos = 0 ) const // (4)
+ {
+ return find_first_not_of( basic_string_view( s ), pos );
+ }
+
+ // find_last_not_of(), 4x:
+
+ nssv_constexpr size_type find_last_not_of( basic_string_view v, size_type pos = npos ) const nssv_noexcept // (1)
+ {
+ return empty()
+ ? npos
+ : pos >= size()
+ ? find_last_not_of( v, size() - 1 )
+ : to_pos( std::find_if( const_reverse_iterator( cbegin() + pos + 1 ), crend(), not_in_view( v ) ) );
+ }
+
+ nssv_constexpr size_type find_last_not_of( CharT c, size_type pos = npos ) const nssv_noexcept // (2)
+ {
+ return find_last_not_of( basic_string_view( &c, 1 ), pos );
+ }
+
+ nssv_constexpr size_type find_last_not_of( CharT const * s, size_type pos, size_type count ) const // (3)
+ {
+ return find_last_not_of( basic_string_view( s, count ), pos );
+ }
+
+ nssv_constexpr size_type find_last_not_of( CharT const * s, size_type pos = npos ) const // (4)
+ {
+ return find_last_not_of( basic_string_view( s ), pos );
+ }
+
+ // Constants:
+
+#if nssv_CPP17_OR_GREATER
+ static nssv_constexpr size_type npos = size_type(-1);
+#elif nssv_CPP11_OR_GREATER
+ enum : size_type { npos = size_type(-1) };
+#else
+ enum { npos = size_type(-1) };
+#endif
+
+private:
+ struct not_in_view
+ {
+ const basic_string_view v;
+
+ nssv_constexpr explicit not_in_view( basic_string_view v_ ) : v( v_ ) {}
+
+ nssv_constexpr bool operator()( CharT c ) const
+ {
+ return npos == v.find_first_of( c );
+ }
+ };
+
+ nssv_constexpr size_type to_pos( const_iterator it ) const
+ {
+ return it == cend() ? npos : size_type( it - cbegin() );
+ }
+
+ nssv_constexpr size_type to_pos( const_reverse_iterator it ) const
+ {
+ return it == crend() ? npos : size_type( crend() - it - 1 );
+ }
+
+ nssv_constexpr const_reference data_at( size_type pos ) const
+ {
+#if nssv_BETWEEN( nssv_COMPILER_GNUC_VERSION, 1, 500 )
+ return data_[pos];
+#else
+ return assert( pos < size() ), data_[pos];
+#endif
+ }
+
+private:
+ const_pointer data_;
+ size_type size_;
+
+public:
+#if nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS
+
+ template< class Allocator >
+ basic_string_view( std::basic_string<CharT, Traits, Allocator> const & s ) nssv_noexcept
+ : data_( s.data() )
+ , size_( s.size() )
+ {}
+
+#if nssv_HAVE_EXPLICIT_CONVERSION
+
+ template< class Allocator >
+ explicit operator std::basic_string<CharT, Traits, Allocator>() const
+ {
+ return to_string( Allocator() );
+ }
+
+#endif // nssv_HAVE_EXPLICIT_CONVERSION
+
+#if nssv_CPP11_OR_GREATER
+
+ template< class Allocator = std::allocator<CharT> >
+ std::basic_string<CharT, Traits, Allocator>
+ to_string( Allocator const & a = Allocator() ) const
+ {
+ return std::basic_string<CharT, Traits, Allocator>( begin(), end(), a );
+ }
+
+#else
+
+ std::basic_string<CharT, Traits>
+ to_string() const
+ {
+ return std::basic_string<CharT, Traits>( begin(), end() );
+ }
+
+ template< class Allocator >
+ std::basic_string<CharT, Traits, Allocator>
+ to_string( Allocator const & a ) const
+ {
+ return std::basic_string<CharT, Traits, Allocator>( begin(), end(), a );
+ }
+
+#endif // nssv_CPP11_OR_GREATER
+
+#endif // nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS
+};
+
+//
+// Non-member functions:
+//
+
+// 24.4.3 Non-member comparison functions:
+// lexicographically compare two string views (function template):
+
+template< class CharT, class Traits >
+nssv_constexpr bool operator== (
+ basic_string_view <CharT, Traits> lhs,
+ basic_string_view <CharT, Traits> rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) == 0 ; }
+
+template< class CharT, class Traits >
+nssv_constexpr bool operator!= (
+ basic_string_view <CharT, Traits> lhs,
+ basic_string_view <CharT, Traits> rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) != 0 ; }
+
+template< class CharT, class Traits >
+nssv_constexpr bool operator< (
+ basic_string_view <CharT, Traits> lhs,
+ basic_string_view <CharT, Traits> rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) < 0 ; }
+
+template< class CharT, class Traits >
+nssv_constexpr bool operator<= (
+ basic_string_view <CharT, Traits> lhs,
+ basic_string_view <CharT, Traits> rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) <= 0 ; }
+
+template< class CharT, class Traits >
+nssv_constexpr bool operator> (
+ basic_string_view <CharT, Traits> lhs,
+ basic_string_view <CharT, Traits> rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) > 0 ; }
+
+template< class CharT, class Traits >
+nssv_constexpr bool operator>= (
+ basic_string_view <CharT, Traits> lhs,
+ basic_string_view <CharT, Traits> rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) >= 0 ; }
+
+// Let S be basic_string_view<CharT, Traits>, and sv be an instance of S.
+// Implementations shall provide sufficient additional overloads marked
+// constexpr and noexcept so that an object t with an implicit conversion
+// to S can be compared according to Table 67.
+
+#if ! nssv_CPP11_OR_GREATER || nssv_BETWEEN( nssv_COMPILER_MSVC_VERSION, 100, 141 )
+
+// accomodate for older compilers:
+
+// ==
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator==(
+ basic_string_view<CharT, Traits> lhs,
+ CharT const * rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) == 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator==(
+ CharT const * lhs,
+ basic_string_view<CharT, Traits> rhs ) nssv_noexcept
+{ return rhs.compare( lhs ) == 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator==(
+ basic_string_view<CharT, Traits> lhs,
+ std::basic_string<CharT, Traits> rhs ) nssv_noexcept
+{ return lhs.size() == rhs.size() && lhs.compare( rhs ) == 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator==(
+ std::basic_string<CharT, Traits> rhs,
+ basic_string_view<CharT, Traits> lhs ) nssv_noexcept
+{ return lhs.size() == rhs.size() && lhs.compare( rhs ) == 0; }
+
+// !=
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator!=(
+ basic_string_view<CharT, Traits> lhs,
+ char const * rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) != 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator!=(
+ char const * lhs,
+ basic_string_view<CharT, Traits> rhs ) nssv_noexcept
+{ return rhs.compare( lhs ) != 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator!=(
+ basic_string_view<CharT, Traits> lhs,
+ std::basic_string<CharT, Traits> rhs ) nssv_noexcept
+{ return lhs.size() != rhs.size() && lhs.compare( rhs ) != 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator!=(
+ std::basic_string<CharT, Traits> rhs,
+ basic_string_view<CharT, Traits> lhs ) nssv_noexcept
+{ return lhs.size() != rhs.size() || rhs.compare( lhs ) != 0; }
+
+// <
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator<(
+ basic_string_view<CharT, Traits> lhs,
+ char const * rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) < 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator<(
+ char const * lhs,
+ basic_string_view<CharT, Traits> rhs ) nssv_noexcept
+{ return rhs.compare( lhs ) > 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator<(
+ basic_string_view<CharT, Traits> lhs,
+ std::basic_string<CharT, Traits> rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) < 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator<(
+ std::basic_string<CharT, Traits> rhs,
+ basic_string_view<CharT, Traits> lhs ) nssv_noexcept
+{ return rhs.compare( lhs ) > 0; }
+
+// <=
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator<=(
+ basic_string_view<CharT, Traits> lhs,
+ char const * rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) <= 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator<=(
+ char const * lhs,
+ basic_string_view<CharT, Traits> rhs ) nssv_noexcept
+{ return rhs.compare( lhs ) >= 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator<=(
+ basic_string_view<CharT, Traits> lhs,
+ std::basic_string<CharT, Traits> rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) <= 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator<=(
+ std::basic_string<CharT, Traits> rhs,
+ basic_string_view<CharT, Traits> lhs ) nssv_noexcept
+{ return rhs.compare( lhs ) >= 0; }
+
+// >
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator>(
+ basic_string_view<CharT, Traits> lhs,
+ char const * rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) > 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator>(
+ char const * lhs,
+ basic_string_view<CharT, Traits> rhs ) nssv_noexcept
+{ return rhs.compare( lhs ) < 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator>(
+ basic_string_view<CharT, Traits> lhs,
+ std::basic_string<CharT, Traits> rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) > 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator>(
+ std::basic_string<CharT, Traits> rhs,
+ basic_string_view<CharT, Traits> lhs ) nssv_noexcept
+{ return rhs.compare( lhs ) < 0; }
+
+// >=
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator>=(
+ basic_string_view<CharT, Traits> lhs,
+ char const * rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) >= 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator>=(
+ char const * lhs,
+ basic_string_view<CharT, Traits> rhs ) nssv_noexcept
+{ return rhs.compare( lhs ) <= 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator>=(
+ basic_string_view<CharT, Traits> lhs,
+ std::basic_string<CharT, Traits> rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) >= 0; }
+
+template< class CharT, class Traits>
+nssv_constexpr bool operator>=(
+ std::basic_string<CharT, Traits> rhs,
+ basic_string_view<CharT, Traits> lhs ) nssv_noexcept
+{ return rhs.compare( lhs ) <= 0; }
+
+#else // newer compilers:
+
+#define nssv_BASIC_STRING_VIEW_I(T,U) typename std::decay< basic_string_view<T,U> >::type
+
+#if nssv_BETWEEN( nssv_COMPILER_MSVC_VERSION, 140, 150 )
+# define nssv_MSVC_ORDER(x) , int=x
+#else
+# define nssv_MSVC_ORDER(x) /*, int=x*/
+#endif
+
+// ==
+
+template< class CharT, class Traits nssv_MSVC_ORDER(1) >
+nssv_constexpr bool operator==(
+ basic_string_view <CharT, Traits> lhs,
+ nssv_BASIC_STRING_VIEW_I(CharT, Traits) rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) == 0; }
+
+template< class CharT, class Traits nssv_MSVC_ORDER(2) >
+nssv_constexpr bool operator==(
+ nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs,
+ basic_string_view <CharT, Traits> rhs ) nssv_noexcept
+{ return lhs.size() == rhs.size() && lhs.compare( rhs ) == 0; }
+
+// !=
+
+template< class CharT, class Traits nssv_MSVC_ORDER(1) >
+nssv_constexpr bool operator!= (
+ basic_string_view < CharT, Traits > lhs,
+ nssv_BASIC_STRING_VIEW_I( CharT, Traits ) rhs ) nssv_noexcept
+{ return lhs.size() != rhs.size() || lhs.compare( rhs ) != 0 ; }
+
+template< class CharT, class Traits nssv_MSVC_ORDER(2) >
+nssv_constexpr bool operator!= (
+ nssv_BASIC_STRING_VIEW_I( CharT, Traits ) lhs,
+ basic_string_view < CharT, Traits > rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) != 0 ; }
+
+// <
+
+template< class CharT, class Traits nssv_MSVC_ORDER(1) >
+nssv_constexpr bool operator< (
+ basic_string_view < CharT, Traits > lhs,
+ nssv_BASIC_STRING_VIEW_I( CharT, Traits ) rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) < 0 ; }
+
+template< class CharT, class Traits nssv_MSVC_ORDER(2) >
+nssv_constexpr bool operator< (
+ nssv_BASIC_STRING_VIEW_I( CharT, Traits ) lhs,
+ basic_string_view < CharT, Traits > rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) < 0 ; }
+
+// <=
+
+template< class CharT, class Traits nssv_MSVC_ORDER(1) >
+nssv_constexpr bool operator<= (
+ basic_string_view < CharT, Traits > lhs,
+ nssv_BASIC_STRING_VIEW_I( CharT, Traits ) rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) <= 0 ; }
+
+template< class CharT, class Traits nssv_MSVC_ORDER(2) >
+nssv_constexpr bool operator<= (
+ nssv_BASIC_STRING_VIEW_I( CharT, Traits ) lhs,
+ basic_string_view < CharT, Traits > rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) <= 0 ; }
+
+// >
+
+template< class CharT, class Traits nssv_MSVC_ORDER(1) >
+nssv_constexpr bool operator> (
+ basic_string_view < CharT, Traits > lhs,
+ nssv_BASIC_STRING_VIEW_I( CharT, Traits ) rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) > 0 ; }
+
+template< class CharT, class Traits nssv_MSVC_ORDER(2) >
+nssv_constexpr bool operator> (
+ nssv_BASIC_STRING_VIEW_I( CharT, Traits ) lhs,
+ basic_string_view < CharT, Traits > rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) > 0 ; }
+
+// >=
+
+template< class CharT, class Traits nssv_MSVC_ORDER(1) >
+nssv_constexpr bool operator>= (
+ basic_string_view < CharT, Traits > lhs,
+ nssv_BASIC_STRING_VIEW_I( CharT, Traits ) rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) >= 0 ; }
+
+template< class CharT, class Traits nssv_MSVC_ORDER(2) >
+nssv_constexpr bool operator>= (
+ nssv_BASIC_STRING_VIEW_I( CharT, Traits ) lhs,
+ basic_string_view < CharT, Traits > rhs ) nssv_noexcept
+{ return lhs.compare( rhs ) >= 0 ; }
+
+#undef nssv_MSVC_ORDER
+#undef nssv_BASIC_STRING_VIEW_I
+
+#endif // compiler-dependent approach to comparisons
+
+// 24.4.4 Inserters and extractors:
+
+namespace detail {
+
+template< class Stream >
+void write_padding( Stream & os, std::streamsize n )
+{
+ for ( std::streamsize i = 0; i < n; ++i )
+ os.rdbuf()->sputc( os.fill() );
+}
+
+template< class Stream, class View >
+Stream & write_to_stream( Stream & os, View const & sv )
+{
+ typename Stream::sentry sentry( os );
+
+ if ( !os )
+ return os;
+
+ const std::streamsize length = static_cast<std::streamsize>( sv.length() );
+
+ // Whether, and how, to pad:
+ const bool pad = ( length < os.width() );
+ const bool left_pad = pad && ( os.flags() & std::ios_base::adjustfield ) == std::ios_base::right;
+
+ if ( left_pad )
+ write_padding( os, os.width() - length );
+
+ // Write span characters:
+ os.rdbuf()->sputn( sv.begin(), length );
+
+ if ( pad && !left_pad )
+ write_padding( os, os.width() - length );
+
+ // Reset output stream width:
+ os.width( 0 );
+
+ return os;
+}
+
+} // namespace detail
+
+template< class CharT, class Traits >
+std::basic_ostream<CharT, Traits> &
+operator<<(
+ std::basic_ostream<CharT, Traits>& os,
+ basic_string_view <CharT, Traits> sv )
+{
+ return detail::write_to_stream( os, sv );
+}
+
+// Several typedefs for common character types are provided:
+
+typedef basic_string_view<char> string_view;
+typedef basic_string_view<wchar_t> wstring_view;
+#if nssv_HAVE_WCHAR16_T
+typedef basic_string_view<char16_t> u16string_view;
+typedef basic_string_view<char32_t> u32string_view;
+#endif
+
+}} // namespace nonstd::sv_lite
+
+//
+// 24.4.6 Suffix for basic_string_view literals:
+//
+
+#if nssv_HAVE_USER_DEFINED_LITERALS
+
+namespace nonstd {
+nssv_inline_ns namespace literals {
+nssv_inline_ns namespace string_view_literals {
+
+#if nssv_CONFIG_STD_SV_OPERATOR && nssv_HAVE_STD_DEFINED_LITERALS
+
+nssv_constexpr nonstd::sv_lite::string_view operator "" sv( const char* str, size_t len ) nssv_noexcept // (1)
+{
+ return nonstd::sv_lite::string_view{ str, len };
+}
+
+nssv_constexpr nonstd::sv_lite::u16string_view operator "" sv( const char16_t* str, size_t len ) nssv_noexcept // (2)
+{
+ return nonstd::sv_lite::u16string_view{ str, len };
+}
+
+nssv_constexpr nonstd::sv_lite::u32string_view operator "" sv( const char32_t* str, size_t len ) nssv_noexcept // (3)
+{
+ return nonstd::sv_lite::u32string_view{ str, len };
+}
+
+nssv_constexpr nonstd::sv_lite::wstring_view operator "" sv( const wchar_t* str, size_t len ) nssv_noexcept // (4)
+{
+ return nonstd::sv_lite::wstring_view{ str, len };
+}
+
+#endif // nssv_CONFIG_STD_SV_OPERATOR && nssv_HAVE_STD_DEFINED_LITERALS
+
+#if nssv_CONFIG_USR_SV_OPERATOR
+
+nssv_constexpr nonstd::sv_lite::string_view operator "" _sv( const char* str, size_t len ) nssv_noexcept // (1)
+{
+ return nonstd::sv_lite::string_view{ str, len };
+}
+
+nssv_constexpr nonstd::sv_lite::u16string_view operator "" _sv( const char16_t* str, size_t len ) nssv_noexcept // (2)
+{
+ return nonstd::sv_lite::u16string_view{ str, len };
+}
+
+nssv_constexpr nonstd::sv_lite::u32string_view operator "" _sv( const char32_t* str, size_t len ) nssv_noexcept // (3)
+{
+ return nonstd::sv_lite::u32string_view{ str, len };
+}
+
+nssv_constexpr nonstd::sv_lite::wstring_view operator "" _sv( const wchar_t* str, size_t len ) nssv_noexcept // (4)
+{
+ return nonstd::sv_lite::wstring_view{ str, len };
+}
+
+#endif // nssv_CONFIG_USR_SV_OPERATOR
+
+}}} // namespace nonstd::literals::string_view_literals
+
+#endif
+
+//
+// Extensions for std::string:
+//
+
+#if nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS
+
+namespace nonstd {
+namespace sv_lite {
+
+// Exclude MSVC 14 (19.00): it yields ambiguous to_string():
+
+#if nssv_CPP11_OR_GREATER && nssv_COMPILER_MSVC_VERSION != 140
+
+template< class CharT, class Traits, class Allocator = std::allocator<CharT> >
+std::basic_string<CharT, Traits, Allocator>
+to_string( basic_string_view<CharT, Traits> v, Allocator const & a = Allocator() )
+{
+ return std::basic_string<CharT,Traits, Allocator>( v.begin(), v.end(), a );
+}
+
+#else
+
+template< class CharT, class Traits >
+std::basic_string<CharT, Traits>
+to_string( basic_string_view<CharT, Traits> v )
+{
+ return std::basic_string<CharT, Traits>( v.begin(), v.end() );
+}
+
+template< class CharT, class Traits, class Allocator >
+std::basic_string<CharT, Traits, Allocator>
+to_string( basic_string_view<CharT, Traits> v, Allocator const & a )
+{
+ return std::basic_string<CharT, Traits, Allocator>( v.begin(), v.end(), a );
+}
+
+#endif // nssv_CPP11_OR_GREATER
+
+template< class CharT, class Traits, class Allocator >
+basic_string_view<CharT, Traits>
+to_string_view( std::basic_string<CharT, Traits, Allocator> const & s )
+{
+ return basic_string_view<CharT, Traits>( s.data(), s.size() );
+}
+
+}} // namespace nonstd::sv_lite
+
+#endif // nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS
+
+//
+// make types and algorithms available in namespace nonstd:
+//
+
+namespace nonstd {
+
+using sv_lite::basic_string_view;
+using sv_lite::string_view;
+using sv_lite::wstring_view;
+
+#if nssv_HAVE_WCHAR16_T
+using sv_lite::u16string_view;
+#endif
+#if nssv_HAVE_WCHAR32_T
+using sv_lite::u32string_view;
+#endif
+
+// literal "sv"
+
+using sv_lite::operator==;
+using sv_lite::operator!=;
+using sv_lite::operator<;
+using sv_lite::operator<=;
+using sv_lite::operator>;
+using sv_lite::operator>=;
+
+using sv_lite::operator<<;
+
+#if nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS
+using sv_lite::to_string;
+using sv_lite::to_string_view;
+#endif
+
+} // namespace nonstd
+
+// 24.4.5 Hash support (C++11):
+
+// Note: The hash value of a string view object is equal to the hash value of
+// the corresponding string object.
+
+#if nssv_HAVE_STD_HASH
+
+#include <functional>
+
+namespace std {
+
+template<>
+struct hash< nonstd::string_view >
+{
+public:
+ std::size_t operator()( nonstd::string_view v ) const nssv_noexcept
+ {
+ return std::hash<std::string>()( std::string( v.data(), v.size() ) );
+ }
+};
+
+template<>
+struct hash< nonstd::wstring_view >
+{
+public:
+ std::size_t operator()( nonstd::wstring_view v ) const nssv_noexcept
+ {
+ return std::hash<std::wstring>()( std::wstring( v.data(), v.size() ) );
+ }
+};
+
+template<>
+struct hash< nonstd::u16string_view >
+{
+public:
+ std::size_t operator()( nonstd::u16string_view v ) const nssv_noexcept
+ {
+ return std::hash<std::u16string>()( std::u16string( v.data(), v.size() ) );
+ }
+};
+
+template<>
+struct hash< nonstd::u32string_view >
+{
+public:
+ std::size_t operator()( nonstd::u32string_view v ) const nssv_noexcept
+ {
+ return std::hash<std::u32string>()( std::u32string( v.data(), v.size() ) );
+ }
+};
+
+} // namespace std
+
+#endif // nssv_HAVE_STD_HASH
+
+nssv_RESTORE_WARNINGS()
+
+#endif // nssv_HAVE_STD_STRING_VIEW
+#endif // NONSTD_SV_LITE_H_INCLUDED
diff --git a/src/arrow/cpp/src/arrow/vendored/strptime.h b/src/arrow/cpp/src/arrow/vendored/strptime.h
new file mode 100644
index 000000000..764a4440e
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/strptime.h
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <time.h>
+
+#include "arrow/util/visibility.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// A less featureful implementation of strptime() for platforms lacking
+// a standard implementation (e.g. Windows).
+ARROW_EXPORT char* arrow_strptime(const char* __restrict, const char* __restrict,
+ struct tm* __restrict);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/README.md b/src/arrow/cpp/src/arrow/vendored/uriparser/README.md
new file mode 100644
index 000000000..b3a219c7f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/README.md
@@ -0,0 +1,25 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+The files in this directory are vendored from uriparser git tag "uriparser-0.9.3".
+
+Include paths fixed using:
+```
+sed -E -i 's:include <uriparser/(.*).h>:include "\1.h":g' src/arrow/vendored/uriparser/*
+```
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/Uri.h b/src/arrow/cpp/src/arrow/vendored/uriparser/Uri.h
new file mode 100644
index 000000000..9315a12c3
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/Uri.h
@@ -0,0 +1,1090 @@
+/* 8cd64b75589a7efa22989ae85f71e620a99e3c44776338f7d3114eacf73d27a3 (0.9.3+)
+ *
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/**
+ * @file Uri.h
+ * Holds the RFC 3986 %URI parser interface.
+ * NOTE: This header includes itself twice.
+ */
+
+#if (defined(URI_PASS_ANSI) && !defined(URI_H_ANSI)) \
+ || (defined(URI_PASS_UNICODE) && !defined(URI_H_UNICODE)) \
+ || (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* What encodings are enabled? */
+#include "UriDefsConfig.h"
+#if (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* Include SELF twice */
+# ifdef URI_ENABLE_ANSI
+# define URI_PASS_ANSI 1
+# include "Uri.h"
+# undef URI_PASS_ANSI
+# endif
+# ifdef URI_ENABLE_UNICODE
+# define URI_PASS_UNICODE 1
+# include "Uri.h"
+# undef URI_PASS_UNICODE
+# endif
+/* Only one pass for each encoding */
+#elif (defined(URI_PASS_ANSI) && !defined(URI_H_ANSI) \
+ && defined(URI_ENABLE_ANSI)) || (defined(URI_PASS_UNICODE) \
+ && !defined(URI_H_UNICODE) && defined(URI_ENABLE_UNICODE))
+# ifdef URI_PASS_ANSI
+# define URI_H_ANSI 1
+# include "UriDefsAnsi.h"
+# else
+# define URI_H_UNICODE 1
+# include "UriDefsUnicode.h"
+# endif
+
+
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+
+#ifndef URI_DOXYGEN
+# include "UriBase.h"
+#endif
+
+
+
+/**
+ * Specifies a range of characters within a string.
+ * The range includes all characters from <c>first</c>
+ * to one before <c>afterLast</c>. So if both are
+ * non-NULL the difference is the length of the text range.
+ *
+ * @see UriUriA
+ * @see UriPathSegmentA
+ * @see UriHostDataA
+ * @since 0.3.0
+ */
+typedef struct URI_TYPE(TextRangeStruct) {
+ const URI_CHAR * first; /**< Pointer to first character */
+ const URI_CHAR * afterLast; /**< Pointer to character after the last one still in */
+} URI_TYPE(TextRange); /**< @copydoc UriTextRangeStructA */
+
+
+
+/**
+ * Represents a path segment within a %URI path.
+ * More precisely it is a node in a linked
+ * list of path segments.
+ *
+ * @see UriUriA
+ * @since 0.3.0
+ */
+typedef struct URI_TYPE(PathSegmentStruct) {
+ URI_TYPE(TextRange) text; /**< Path segment name */
+ struct URI_TYPE(PathSegmentStruct) * next; /**< Pointer to the next path segment in the list, can be NULL if last already */
+
+ void * reserved; /**< Reserved to the parser */
+} URI_TYPE(PathSegment); /**< @copydoc UriPathSegmentStructA */
+
+
+
+/**
+ * Holds structured host information.
+ * This is either a IPv4, IPv6, plain
+ * text for IPvFuture or all zero for
+ * a registered name.
+ *
+ * @see UriUriA
+ * @since 0.3.0
+ */
+typedef struct URI_TYPE(HostDataStruct) {
+ UriIp4 * ip4; /**< IPv4 address */
+ UriIp6 * ip6; /**< IPv6 address */
+ URI_TYPE(TextRange) ipFuture; /**< IPvFuture address */
+} URI_TYPE(HostData); /**< @copydoc UriHostDataStructA */
+
+
+
+/**
+ * Represents an RFC 3986 %URI.
+ * Missing components can be {NULL, NULL} ranges.
+ *
+ * @see uriFreeUriMembersA
+ * @see uriFreeUriMembersMmA
+ * @see UriParserStateA
+ * @since 0.3.0
+ */
+typedef struct URI_TYPE(UriStruct) {
+ URI_TYPE(TextRange) scheme; /**< Scheme (e.g. "http") */
+ URI_TYPE(TextRange) userInfo; /**< User info (e.g. "user:pass") */
+ URI_TYPE(TextRange) hostText; /**< Host text (set for all hosts, excluding square brackets) */
+ URI_TYPE(HostData) hostData; /**< Structured host type specific data */
+ URI_TYPE(TextRange) portText; /**< Port (e.g. "80") */
+ URI_TYPE(PathSegment) * pathHead; /**< Head of a linked list of path segments */
+ URI_TYPE(PathSegment) * pathTail; /**< Tail of the list behind pathHead */
+ URI_TYPE(TextRange) query; /**< Query without leading "?" */
+ URI_TYPE(TextRange) fragment; /**< Query without leading "#" */
+ UriBool absolutePath; /**< Absolute path flag, distincting "a" and "/a";
+ always <c>URI_FALSE</c> for URIs with host */
+ UriBool owner; /**< Memory owner flag */
+
+ void * reserved; /**< Reserved to the parser */
+} URI_TYPE(Uri); /**< @copydoc UriUriStructA */
+
+
+
+/**
+ * Represents a state of the %URI parser.
+ * Missing components can be NULL to reflect
+ * a components absence.
+ *
+ * @see uriFreeUriMembersA
+ * @see uriFreeUriMembersMmA
+ * @since 0.3.0
+ */
+typedef struct URI_TYPE(ParserStateStruct) {
+ URI_TYPE(Uri) * uri; /**< Plug in the %URI structure to be filled while parsing here */
+ int errorCode; /**< Code identifying the error which occurred */
+ const URI_CHAR * errorPos; /**< Pointer to position in case of a syntax error */
+
+ void * reserved; /**< Reserved to the parser */
+} URI_TYPE(ParserState); /**< @copydoc UriParserStateStructA */
+
+
+
+/**
+ * Represents a query element.
+ * More precisely it is a node in a linked
+ * list of query elements.
+ *
+ * @since 0.7.0
+ */
+typedef struct URI_TYPE(QueryListStruct) {
+ const URI_CHAR * key; /**< Key of the query element */
+ const URI_CHAR * value; /**< Value of the query element, can be NULL */
+
+ struct URI_TYPE(QueryListStruct) * next; /**< Pointer to the next key/value pair in the list, can be NULL if last already */
+} URI_TYPE(QueryList); /**< @copydoc UriQueryListStructA */
+
+
+
+/**
+ * Parses a RFC 3986 %URI.
+ * Uses default libc-based memory manager.
+ *
+ * @param state <b>INOUT</b>: Parser state with set output %URI, must not be NULL
+ * @param first <b>IN</b>: Pointer to the first character to parse, must not be NULL
+ * @param afterLast <b>IN</b>: Pointer to the character after the last to parse, must not be NULL
+ * @return 0 on success, error code otherwise
+ *
+ * @see uriParseUriA
+ * @see uriParseSingleUriA
+ * @see uriParseSingleUriExA
+ * @see uriToStringA
+ * @since 0.3.0
+ * @deprecated Deprecated since 0.9.0, please migrate to uriParseSingleUriExA (with "Single").
+ */
+URI_PUBLIC int URI_FUNC(ParseUriEx)(URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast);
+
+
+
+/**
+ * Parses a RFC 3986 %URI.
+ * Uses default libc-based memory manager.
+ *
+ * @param state <b>INOUT</b>: Parser state with set output %URI, must not be NULL
+ * @param text <b>IN</b>: Text to parse, must not be NULL
+ * @return 0 on success, error code otherwise
+ *
+ * @see uriParseUriExA
+ * @see uriParseSingleUriA
+ * @see uriParseSingleUriExA
+ * @see uriToStringA
+ * @since 0.3.0
+ * @deprecated Deprecated since 0.9.0, please migrate to uriParseSingleUriA (with "Single").
+ */
+URI_PUBLIC int URI_FUNC(ParseUri)(URI_TYPE(ParserState) * state,
+ const URI_CHAR * text);
+
+
+
+/**
+ * Parses a single RFC 3986 %URI.
+ * Uses default libc-based memory manager.
+ *
+ * @param uri <b>OUT</b>: Output %URI, must not be NULL
+ * @param text <b>IN</b>: Pointer to the first character to parse,
+ * must not be NULL
+ * @param errorPos <b>OUT</b>: Pointer to a pointer to the first character
+ * causing a syntax error, can be NULL;
+ * only set when URI_ERROR_SYNTAX was returned
+ * @return 0 on success, error code otherwise
+ *
+ * @see uriParseSingleUriExA
+ * @see uriParseSingleUriExMmA
+ * @see uriToStringA
+ * @since 0.9.0
+ */
+URI_PUBLIC int URI_FUNC(ParseSingleUri)(URI_TYPE(Uri) * uri,
+ const URI_CHAR * text, const URI_CHAR ** errorPos);
+
+
+
+/**
+ * Parses a single RFC 3986 %URI.
+ * Uses default libc-based memory manager.
+ *
+ * @param uri <b>OUT</b>: Output %URI, must not be NULL
+ * @param first <b>IN</b>: Pointer to the first character to parse,
+ * must not be NULL
+ * @param afterLast <b>IN</b>: Pointer to the character after the last to
+ * parse, can be NULL
+ * (to use first + strlen(first))
+ * @param errorPos <b>OUT</b>: Pointer to a pointer to the first character
+ * causing a syntax error, can be NULL;
+ * only set when URI_ERROR_SYNTAX was returned
+ * @return 0 on success, error code otherwise
+ *
+ * @see uriParseSingleUriA
+ * @see uriParseSingleUriExMmA
+ * @see uriToStringA
+ * @since 0.9.0
+ */
+URI_PUBLIC int URI_FUNC(ParseSingleUriEx)(URI_TYPE(Uri) * uri,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ const URI_CHAR ** errorPos);
+
+
+
+/**
+ * Parses a single RFC 3986 %URI.
+ *
+ * @param uri <b>OUT</b>: Output %URI, must not be NULL
+ * @param first <b>IN</b>: Pointer to the first character to parse,
+ * must not be NULL
+ * @param afterLast <b>IN</b>: Pointer to the character after the last to
+ * parse, can be NULL
+ * (to use first + strlen(first))
+ * @param errorPos <b>OUT</b>: Pointer to a pointer to the first character
+ * causing a syntax error, can be NULL;
+ * only set when URI_ERROR_SYNTAX was returned
+ * @param memory <b>IN</b>: Memory manager to use, NULL for default libc
+ * @return 0 on success, error code otherwise
+ *
+ * @see uriParseSingleUriA
+ * @see uriParseSingleUriExA
+ * @see uriToStringA
+ * @since 0.9.0
+ */
+URI_PUBLIC int URI_FUNC(ParseSingleUriExMm)(URI_TYPE(Uri) * uri,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ const URI_CHAR ** errorPos, UriMemoryManager * memory);
+
+
+
+/**
+ * Frees all memory associated with the members
+ * of the %URI structure. Note that the structure
+ * itself is not freed, only its members.
+ * Uses default libc-based memory manager.
+ *
+ * @param uri <b>INOUT</b>: %URI structure whose members should be freed
+ *
+ * @see uriFreeUriMembersMmA
+ * @since 0.3.0
+ */
+URI_PUBLIC void URI_FUNC(FreeUriMembers)(URI_TYPE(Uri) * uri);
+
+
+
+/**
+ * Frees all memory associated with the members
+ * of the %URI structure. Note that the structure
+ * itself is not freed, only its members.
+ *
+ * @param uri <b>INOUT</b>: %URI structure whose members should be freed
+ * @param memory <b>IN</b>: Memory manager to use, NULL for default libc
+ * @return 0 on success, error code otherwise
+ *
+ * @see uriFreeUriMembersA
+ * @since 0.9.0
+ */
+URI_PUBLIC int URI_FUNC(FreeUriMembersMm)(URI_TYPE(Uri) * uri,
+ UriMemoryManager * memory);
+
+
+
+/**
+ * Percent-encodes all unreserved characters from the input string and
+ * writes the encoded version to the output string.
+ * Be sure to allocate <b>3 times</b> the space of the input buffer for
+ * the output buffer for <c>normalizeBreaks == URI_FALSE</c> and <b>6 times</b>
+ * the space for <c>normalizeBreaks == URI_TRUE</c>
+ * (since e.g. "\x0d" becomes "%0D%0A" in that case)
+ *
+ * @param inFirst <b>IN</b>: Pointer to first character of the input text
+ * @param inAfterLast <b>IN</b>: Pointer after the last character of the input text
+ * @param out <b>OUT</b>: Encoded text destination
+ * @param spaceToPlus <b>IN</b>: Whether to convert ' ' to '+' or not
+ * @param normalizeBreaks <b>IN</b>: Whether to convert CR and LF to CR-LF or not.
+ * @return Position of terminator in output string
+ *
+ * @see uriEscapeA
+ * @see uriUnescapeInPlaceExA
+ * @since 0.5.2
+ */
+URI_PUBLIC URI_CHAR * URI_FUNC(EscapeEx)(const URI_CHAR * inFirst,
+ const URI_CHAR * inAfterLast, URI_CHAR * out,
+ UriBool spaceToPlus, UriBool normalizeBreaks);
+
+
+
+/**
+ * Percent-encodes all unreserved characters from the input string and
+ * writes the encoded version to the output string.
+ * Be sure to allocate <b>3 times</b> the space of the input buffer for
+ * the output buffer for <c>normalizeBreaks == URI_FALSE</c> and <b>6 times</b>
+ * the space for <c>normalizeBreaks == URI_TRUE</c>
+ * (since e.g. "\x0d" becomes "%0D%0A" in that case)
+ *
+ * @param in <b>IN</b>: Text source
+ * @param out <b>OUT</b>: Encoded text destination
+ * @param spaceToPlus <b>IN</b>: Whether to convert ' ' to '+' or not
+ * @param normalizeBreaks <b>IN</b>: Whether to convert CR and LF to CR-LF or not.
+ * @return Position of terminator in output string
+ *
+ * @see uriEscapeExA
+ * @see uriUnescapeInPlaceA
+ * @since 0.5.0
+ */
+URI_PUBLIC URI_CHAR * URI_FUNC(Escape)(const URI_CHAR * in, URI_CHAR * out,
+ UriBool spaceToPlus, UriBool normalizeBreaks);
+
+
+
+/**
+ * Unescapes percent-encoded groups in a given string.
+ * E.g. "%20" will become " ". Unescaping is done in place.
+ * The return value will be point to the new position
+ * of the terminating zero. Use this value to get the new
+ * length of the string. NULL is only returned if <c>inout</c>
+ * is NULL.
+ *
+ * @param inout <b>INOUT</b>: Text to unescape/decode
+ * @param plusToSpace <b>IN</b>: Whether to convert '+' to ' ' or not
+ * @param breakConversion <b>IN</b>: Line break conversion mode
+ * @return Pointer to new position of the terminating zero
+ *
+ * @see uriUnescapeInPlaceA
+ * @see uriEscapeExA
+ * @since 0.5.0
+ */
+URI_PUBLIC const URI_CHAR * URI_FUNC(UnescapeInPlaceEx)(URI_CHAR * inout,
+ UriBool plusToSpace, UriBreakConversion breakConversion);
+
+
+
+/**
+ * Unescapes percent-encoded groups in a given string.
+ * E.g. "%20" will become " ". Unescaping is done in place.
+ * The return value will be point to the new position
+ * of the terminating zero. Use this value to get the new
+ * length of the string. NULL is only returned if <c>inout</c>
+ * is NULL.
+ *
+ * NOTE: '+' is not decoded to ' ' and line breaks are not converted.
+ * Use the more advanced UnescapeInPlaceEx for that features instead.
+ *
+ * @param inout <b>INOUT</b>: Text to unescape/decode
+ * @return Pointer to new position of the terminating zero
+ *
+ * @see uriUnescapeInPlaceExA
+ * @see uriEscapeA
+ * @since 0.3.0
+ */
+URI_PUBLIC const URI_CHAR * URI_FUNC(UnescapeInPlace)(URI_CHAR * inout);
+
+
+
+/**
+ * Performs reference resolution as described in
+ * <a href="http://tools.ietf.org/html/rfc3986#section-5.2.2">section 5.2.2 of RFC 3986</a>.
+ * Uses default libc-based memory manager.
+ * NOTE: On success you have to call uriFreeUriMembersA on \p absoluteDest manually later.
+ *
+ * @param absoluteDest <b>OUT</b>: Result %URI
+ * @param relativeSource <b>IN</b>: Reference to resolve
+ * @param absoluteBase <b>IN</b>: Base %URI to apply
+ * @return Error code or 0 on success
+ *
+ * @see uriRemoveBaseUriA
+ * @see uriRemoveBaseUriMmA
+ * @see uriAddBaseUriExA
+ * @see uriAddBaseUriExMmA
+ * @since 0.4.0
+ */
+URI_PUBLIC int URI_FUNC(AddBaseUri)(URI_TYPE(Uri) * absoluteDest,
+ const URI_TYPE(Uri) * relativeSource,
+ const URI_TYPE(Uri) * absoluteBase);
+
+
+
+/**
+ * Performs reference resolution as described in
+ * <a href="http://tools.ietf.org/html/rfc3986#section-5.2.2">section 5.2.2 of RFC 3986</a>.
+ * Uses default libc-based memory manager.
+ * NOTE: On success you have to call uriFreeUriMembersA on \p absoluteDest manually later.
+ *
+ * @param absoluteDest <b>OUT</b>: Result %URI
+ * @param relativeSource <b>IN</b>: Reference to resolve
+ * @param absoluteBase <b>IN</b>: Base %URI to apply
+ * @param options <b>IN</b>: Configuration to apply
+ * @return Error code or 0 on success
+ *
+ * @see uriRemoveBaseUriA
+ * @see uriAddBaseUriA
+ * @see uriAddBaseUriExMmA
+ * @since 0.8.1
+ */
+URI_PUBLIC int URI_FUNC(AddBaseUriEx)(URI_TYPE(Uri) * absoluteDest,
+ const URI_TYPE(Uri) * relativeSource,
+ const URI_TYPE(Uri) * absoluteBase,
+ UriResolutionOptions options);
+
+
+
+/**
+ * Performs reference resolution as described in
+ * <a href="http://tools.ietf.org/html/rfc3986#section-5.2.2">section 5.2.2 of RFC 3986</a>.
+ * NOTE: On success you have to call uriFreeUriMembersMmA on \p absoluteDest manually later.
+ *
+ * @param absoluteDest <b>OUT</b>: Result %URI
+ * @param relativeSource <b>IN</b>: Reference to resolve
+ * @param absoluteBase <b>IN</b>: Base %URI to apply
+ * @param options <b>IN</b>: Configuration to apply
+ * @param memory <b>IN</b>: Memory manager to use, NULL for default libc
+ * @return Error code or 0 on success
+ *
+ * @see uriRemoveBaseUriA
+ * @see uriRemoveBaseUriMmA
+ * @see uriAddBaseUriA
+ * @see uriAddBaseUriExA
+ * @since 0.9.0
+ */
+URI_PUBLIC int URI_FUNC(AddBaseUriExMm)(URI_TYPE(Uri) * absoluteDest,
+ const URI_TYPE(Uri) * relativeSource,
+ const URI_TYPE(Uri) * absoluteBase,
+ UriResolutionOptions options, UriMemoryManager * memory);
+
+
+
+/**
+ * Tries to make a relative %URI (a reference) from an
+ * absolute %URI and a given base %URI. The resulting %URI is going to be
+ * relative if the absolute %URI and base %UI share both scheme and authority.
+ * If that is not the case, the result will still be
+ * an absolute URI (with scheme part if necessary).
+ * Uses default libc-based memory manager.
+ * NOTE: On success you have to call uriFreeUriMembersA on
+ * \p dest manually later.
+ *
+ * @param dest <b>OUT</b>: Result %URI
+ * @param absoluteSource <b>IN</b>: Absolute %URI to make relative
+ * @param absoluteBase <b>IN</b>: Base %URI
+ * @param domainRootMode <b>IN</b>: Create %URI with path relative to domain root
+ * @return Error code or 0 on success
+ *
+ * @see uriRemoveBaseUriMmA
+ * @see uriAddBaseUriA
+ * @see uriAddBaseUriExA
+ * @see uriAddBaseUriExMmA
+ * @since 0.5.2
+ */
+URI_PUBLIC int URI_FUNC(RemoveBaseUri)(URI_TYPE(Uri) * dest,
+ const URI_TYPE(Uri) * absoluteSource,
+ const URI_TYPE(Uri) * absoluteBase,
+ UriBool domainRootMode);
+
+
+
+/**
+ * Tries to make a relative %URI (a reference) from an
+ * absolute %URI and a given base %URI. The resulting %URI is going to be
+ * relative if the absolute %URI and base %UI share both scheme and authority.
+ * If that is not the case, the result will still be
+ * an absolute URI (with scheme part if necessary).
+ * NOTE: On success you have to call uriFreeUriMembersMmA on
+ * \p dest manually later.
+ *
+ * @param dest <b>OUT</b>: Result %URI
+ * @param absoluteSource <b>IN</b>: Absolute %URI to make relative
+ * @param absoluteBase <b>IN</b>: Base %URI
+ * @param domainRootMode <b>IN</b>: Create %URI with path relative to domain root
+ * @param memory <b>IN</b>: Memory manager to use, NULL for default libc
+ * @return Error code or 0 on success
+ *
+ * @see uriRemoveBaseUriA
+ * @see uriAddBaseUriA
+ * @see uriAddBaseUriExA
+ * @see uriAddBaseUriExMmA
+ * @since 0.9.0
+ */
+URI_PUBLIC int URI_FUNC(RemoveBaseUriMm)(URI_TYPE(Uri) * dest,
+ const URI_TYPE(Uri) * absoluteSource,
+ const URI_TYPE(Uri) * absoluteBase,
+ UriBool domainRootMode, UriMemoryManager * memory);
+
+
+
+/**
+ * Checks two URIs for equivalence. Comparison is done
+ * the naive way, without prior normalization.
+ * NOTE: Two <c>NULL</c> URIs are equal as well.
+ *
+ * @param a <b>IN</b>: First %URI
+ * @param b <b>IN</b>: Second %URI
+ * @return <c>URI_TRUE</c> when equal, <c>URI_FAlSE</c> else
+ *
+ * @since 0.4.0
+ */
+URI_PUBLIC UriBool URI_FUNC(EqualsUri)(const URI_TYPE(Uri) * a,
+ const URI_TYPE(Uri) * b);
+
+
+
+/**
+ * Calculates the number of characters needed to store the
+ * string representation of the given %URI excluding the
+ * terminator.
+ *
+ * @param uri <b>IN</b>: %URI to measure
+ * @param charsRequired <b>OUT</b>: Length of the string representation in characters <b>excluding</b> terminator
+ * @return Error code or 0 on success
+ *
+ * @see uriToStringA
+ * @since 0.5.0
+ */
+URI_PUBLIC int URI_FUNC(ToStringCharsRequired)(const URI_TYPE(Uri) * uri,
+ int * charsRequired);
+
+
+
+/**
+ * Converts a %URI structure back to text as described in
+ * <a href="http://tools.ietf.org/html/rfc3986#section-5.3">section 5.3 of RFC 3986</a>.
+ *
+ * @param dest <b>OUT</b>: Output destination
+ * @param uri <b>IN</b>: %URI to convert
+ * @param maxChars <b>IN</b>: Maximum number of characters to copy <b>including</b> terminator
+ * @param charsWritten <b>OUT</b>: Number of characters written, can be lower than maxChars even if the %URI is too long!
+ * @return Error code or 0 on success
+ *
+ * @see uriToStringCharsRequiredA
+ * @since 0.4.0
+ */
+URI_PUBLIC int URI_FUNC(ToString)(URI_CHAR * dest, const URI_TYPE(Uri) * uri,
+ int maxChars, int * charsWritten);
+
+
+
+/**
+ * Determines the components of a %URI that are not normalized.
+ *
+ * @param uri <b>IN</b>: %URI to check
+ * @return Normalization job mask
+ *
+ * @see uriNormalizeSyntaxA
+ * @see uriNormalizeSyntaxExA
+ * @see uriNormalizeSyntaxExMmA
+ * @see uriNormalizeSyntaxMaskRequiredExA
+ * @since 0.5.0
+ * @deprecated Deprecated since 0.9.0, please migrate to uriNormalizeSyntaxMaskRequiredExA (with "Ex").
+ */
+URI_PUBLIC unsigned int URI_FUNC(NormalizeSyntaxMaskRequired)(
+ const URI_TYPE(Uri) * uri);
+
+
+
+/**
+ * Determines the components of a %URI that are not normalized.
+ *
+ * @param uri <b>IN</b>: %URI to check
+ * @param outMask <b>OUT</b>: Normalization job mask
+ * @return Error code or 0 on success
+ *
+ * @see uriNormalizeSyntaxA
+ * @see uriNormalizeSyntaxExA
+ * @see uriNormalizeSyntaxExMmA
+ * @see uriNormalizeSyntaxMaskRequiredA
+ * @since 0.9.0
+ */
+URI_PUBLIC int URI_FUNC(NormalizeSyntaxMaskRequiredEx)(
+ const URI_TYPE(Uri) * uri, unsigned int * outMask);
+
+
+
+/**
+ * Normalizes a %URI using a normalization mask.
+ * The normalization mask decides what components are normalized.
+ *
+ * NOTE: If necessary the %URI becomes owner of all memory
+ * behind the text pointed to. Text is duplicated in that case.
+ * Uses default libc-based memory manager.
+ *
+ * @param uri <b>INOUT</b>: %URI to normalize
+ * @param mask <b>IN</b>: Normalization mask
+ * @return Error code or 0 on success
+ *
+ * @see uriNormalizeSyntaxA
+ * @see uriNormalizeSyntaxExMmA
+ * @see uriNormalizeSyntaxMaskRequiredA
+ * @since 0.5.0
+ */
+URI_PUBLIC int URI_FUNC(NormalizeSyntaxEx)(URI_TYPE(Uri) * uri,
+ unsigned int mask);
+
+
+
+/**
+ * Normalizes a %URI using a normalization mask.
+ * The normalization mask decides what components are normalized.
+ *
+ * NOTE: If necessary the %URI becomes owner of all memory
+ * behind the text pointed to. Text is duplicated in that case.
+ *
+ * @param uri <b>INOUT</b>: %URI to normalize
+ * @param mask <b>IN</b>: Normalization mask
+ * @param memory <b>IN</b>: Memory manager to use, NULL for default libc
+ * @return Error code or 0 on success
+ *
+ * @see uriNormalizeSyntaxA
+ * @see uriNormalizeSyntaxExA
+ * @see uriNormalizeSyntaxMaskRequiredA
+ * @since 0.9.0
+ */
+URI_PUBLIC int URI_FUNC(NormalizeSyntaxExMm)(URI_TYPE(Uri) * uri,
+ unsigned int mask, UriMemoryManager * memory);
+
+
+
+/**
+ * Normalizes all components of a %URI.
+ *
+ * NOTE: If necessary the %URI becomes owner of all memory
+ * behind the text pointed to. Text is duplicated in that case.
+ * Uses default libc-based memory manager.
+ *
+ * @param uri <b>INOUT</b>: %URI to normalize
+ * @return Error code or 0 on success
+ *
+ * @see uriNormalizeSyntaxExA
+ * @see uriNormalizeSyntaxExMmA
+ * @see uriNormalizeSyntaxMaskRequiredA
+ * @since 0.5.0
+ */
+URI_PUBLIC int URI_FUNC(NormalizeSyntax)(URI_TYPE(Uri) * uri);
+
+
+
+/**
+ * Converts a Unix filename to a %URI string.
+ * The destination buffer must be large enough to hold 7 + 3 * len(filename) + 1
+ * characters in case of an absolute filename or 3 * len(filename) + 1 in case
+ * of a relative filename.
+ *
+ * EXAMPLE
+ * Input: "/bin/bash"
+ * Output: "file:///bin/bash"
+ *
+ * @param filename <b>IN</b>: Unix filename to convert
+ * @param uriString <b>OUT</b>: Destination to write %URI string to
+ * @return Error code or 0 on success
+ *
+ * @see uriUriStringToUnixFilenameA
+ * @see uriWindowsFilenameToUriStringA
+ * @since 0.5.2
+ */
+URI_PUBLIC int URI_FUNC(UnixFilenameToUriString)(const URI_CHAR * filename,
+ URI_CHAR * uriString);
+
+
+
+/**
+ * Converts a Windows filename to a %URI string.
+ * The destination buffer must be large enough to hold 8 + 3 * len(filename) + 1
+ * characters in case of an absolute filename or 3 * len(filename) + 1 in case
+ * of a relative filename.
+ *
+ * EXAMPLE
+ * Input: "E:\\Documents and Settings"
+ * Output: "file:///E:/Documents%20and%20Settings"
+ *
+ * @param filename <b>IN</b>: Windows filename to convert
+ * @param uriString <b>OUT</b>: Destination to write %URI string to
+ * @return Error code or 0 on success
+ *
+ * @see uriUriStringToWindowsFilenameA
+ * @see uriUnixFilenameToUriStringA
+ * @since 0.5.2
+ */
+URI_PUBLIC int URI_FUNC(WindowsFilenameToUriString)(const URI_CHAR * filename,
+ URI_CHAR * uriString);
+
+
+
+/**
+ * Extracts a Unix filename from a %URI string.
+ * The destination buffer must be large enough to hold len(uriString) + 1 - 7
+ * characters in case of an absolute %URI or len(uriString) + 1 in case
+ * of a relative %URI.
+ *
+ * @param uriString <b>IN</b>: %URI string to convert
+ * @param filename <b>OUT</b>: Destination to write filename to
+ * @return Error code or 0 on success
+ *
+ * @see uriUnixFilenameToUriStringA
+ * @see uriUriStringToWindowsFilenameA
+ * @since 0.5.2
+ */
+URI_PUBLIC int URI_FUNC(UriStringToUnixFilename)(const URI_CHAR * uriString,
+ URI_CHAR * filename);
+
+
+
+/**
+ * Extracts a Windows filename from a %URI string.
+ * The destination buffer must be large enough to hold len(uriString) + 1 - 5
+ * characters in case of an absolute %URI or len(uriString) + 1 in case
+ * of a relative %URI.
+ *
+ * @param uriString <b>IN</b>: %URI string to convert
+ * @param filename <b>OUT</b>: Destination to write filename to
+ * @return Error code or 0 on success
+ *
+ * @see uriWindowsFilenameToUriStringA
+ * @see uriUriStringToUnixFilenameA
+ * @since 0.5.2
+ */
+URI_PUBLIC int URI_FUNC(UriStringToWindowsFilename)(const URI_CHAR * uriString,
+ URI_CHAR * filename);
+
+
+
+/**
+ * Calculates the number of characters needed to store the
+ * string representation of the given query list excluding the
+ * terminator. It is assumed that line breaks are will be
+ * normalized to "%0D%0A".
+ *
+ * @param queryList <b>IN</b>: Query list to measure
+ * @param charsRequired <b>OUT</b>: Length of the string representation in characters <b>excluding</b> terminator
+ * @return Error code or 0 on success
+ *
+ * @see uriComposeQueryCharsRequiredExA
+ * @see uriComposeQueryA
+ * @since 0.7.0
+ */
+URI_PUBLIC int URI_FUNC(ComposeQueryCharsRequired)(
+ const URI_TYPE(QueryList) * queryList, int * charsRequired);
+
+
+
+/**
+ * Calculates the number of characters needed to store the
+ * string representation of the given query list excluding the
+ * terminator.
+ *
+ * @param queryList <b>IN</b>: Query list to measure
+ * @param charsRequired <b>OUT</b>: Length of the string representation in characters <b>excluding</b> terminator
+ * @param spaceToPlus <b>IN</b>: Whether to convert ' ' to '+' or not
+ * @param normalizeBreaks <b>IN</b>: Whether to convert CR and LF to CR-LF or not.
+ * @return Error code or 0 on success
+ *
+ * @see uriComposeQueryCharsRequiredA
+ * @see uriComposeQueryExA
+ * @since 0.7.0
+ */
+URI_PUBLIC int URI_FUNC(ComposeQueryCharsRequiredEx)(
+ const URI_TYPE(QueryList) * queryList,
+ int * charsRequired, UriBool spaceToPlus, UriBool normalizeBreaks);
+
+
+
+/**
+ * Converts a query list structure back to a query string.
+ * The composed string does not start with '?',
+ * on the way ' ' is converted to '+' and line breaks are
+ * normalized to "%0D%0A".
+ *
+ * @param dest <b>OUT</b>: Output destination
+ * @param queryList <b>IN</b>: Query list to convert
+ * @param maxChars <b>IN</b>: Maximum number of characters to copy <b>including</b> terminator
+ * @param charsWritten <b>OUT</b>: Number of characters written, can be lower than maxChars even if the query list is too long!
+ * @return Error code or 0 on success
+ *
+ * @see uriComposeQueryExA
+ * @see uriComposeQueryMallocA
+ * @see uriComposeQueryMallocExA
+ * @see uriComposeQueryMallocExMmA
+ * @see uriComposeQueryCharsRequiredA
+ * @see uriDissectQueryMallocA
+ * @see uriDissectQueryMallocExA
+ * @see uriDissectQueryMallocExMmA
+ * @since 0.7.0
+ */
+URI_PUBLIC int URI_FUNC(ComposeQuery)(URI_CHAR * dest,
+ const URI_TYPE(QueryList) * queryList, int maxChars, int * charsWritten);
+
+
+
+/**
+ * Converts a query list structure back to a query string.
+ * The composed string does not start with '?'.
+ *
+ * @param dest <b>OUT</b>: Output destination
+ * @param queryList <b>IN</b>: Query list to convert
+ * @param maxChars <b>IN</b>: Maximum number of characters to copy <b>including</b> terminator
+ * @param charsWritten <b>OUT</b>: Number of characters written, can be lower than maxChars even if the query list is too long!
+ * @param spaceToPlus <b>IN</b>: Whether to convert ' ' to '+' or not
+ * @param normalizeBreaks <b>IN</b>: Whether to convert CR and LF to CR-LF or not.
+ * @return Error code or 0 on success
+ *
+ * @see uriComposeQueryA
+ * @see uriComposeQueryMallocA
+ * @see uriComposeQueryMallocExA
+ * @see uriComposeQueryMallocExMmA
+ * @see uriComposeQueryCharsRequiredExA
+ * @see uriDissectQueryMallocA
+ * @see uriDissectQueryMallocExA
+ * @see uriDissectQueryMallocExMmA
+ * @since 0.7.0
+ */
+URI_PUBLIC int URI_FUNC(ComposeQueryEx)(URI_CHAR * dest,
+ const URI_TYPE(QueryList) * queryList, int maxChars, int * charsWritten,
+ UriBool spaceToPlus, UriBool normalizeBreaks);
+
+
+
+/**
+ * Converts a query list structure back to a query string.
+ * Memory for this string is allocated internally.
+ * The composed string does not start with '?',
+ * on the way ' ' is converted to '+' and line breaks are
+ * normalized to "%0D%0A".
+ * Uses default libc-based memory manager.
+ *
+ * @param dest <b>OUT</b>: Output destination
+ * @param queryList <b>IN</b>: Query list to convert
+ * @return Error code or 0 on success
+ *
+ * @see uriComposeQueryMallocExA
+ * @see uriComposeQueryMallocExMmA
+ * @see uriComposeQueryA
+ * @see uriDissectQueryMallocA
+ * @see uriDissectQueryMallocExA
+ * @see uriDissectQueryMallocExMmA
+ * @since 0.7.0
+ */
+URI_PUBLIC int URI_FUNC(ComposeQueryMalloc)(URI_CHAR ** dest,
+ const URI_TYPE(QueryList) * queryList);
+
+
+
+/**
+ * Converts a query list structure back to a query string.
+ * Memory for this string is allocated internally.
+ * The composed string does not start with '?'.
+ * Uses default libc-based memory manager.
+ *
+ * @param dest <b>OUT</b>: Output destination
+ * @param queryList <b>IN</b>: Query list to convert
+ * @param spaceToPlus <b>IN</b>: Whether to convert ' ' to '+' or not
+ * @param normalizeBreaks <b>IN</b>: Whether to convert CR and LF to CR-LF or not.
+ * @return Error code or 0 on success
+ *
+ * @see uriComposeQueryMallocA
+ * @see uriComposeQueryMallocExMmA
+ * @see uriComposeQueryExA
+ * @see uriDissectQueryMallocA
+ * @see uriDissectQueryMallocExA
+ * @see uriDissectQueryMallocExMmA
+ * @since 0.7.0
+ */
+URI_PUBLIC int URI_FUNC(ComposeQueryMallocEx)(URI_CHAR ** dest,
+ const URI_TYPE(QueryList) * queryList,
+ UriBool spaceToPlus, UriBool normalizeBreaks);
+
+
+
+/**
+ * Converts a query list structure back to a query string.
+ * Memory for this string is allocated internally.
+ * The composed string does not start with '?'.
+ *
+ * @param dest <b>OUT</b>: Output destination
+ * @param queryList <b>IN</b>: Query list to convert
+ * @param spaceToPlus <b>IN</b>: Whether to convert ' ' to '+' or not
+ * @param normalizeBreaks <b>IN</b>: Whether to convert CR and LF to CR-LF or not.
+ * @param memory <b>IN</b>: Memory manager to use, NULL for default libc
+ * @return Error code or 0 on success
+ *
+ * @see uriComposeQueryMallocA
+ * @see uriComposeQueryMallocExA
+ * @see uriComposeQueryExA
+ * @see uriDissectQueryMallocA
+ * @see uriDissectQueryMallocExA
+ * @see uriDissectQueryMallocExMmA
+ * @since 0.9.0
+ */
+URI_PUBLIC int URI_FUNC(ComposeQueryMallocExMm)(URI_CHAR ** dest,
+ const URI_TYPE(QueryList) * queryList,
+ UriBool spaceToPlus, UriBool normalizeBreaks,
+ UriMemoryManager * memory);
+
+
+
+/**
+ * Constructs a query list from the raw query string of a given URI.
+ * On the way '+' is converted back to ' ', line breaks are not modified.
+ * Uses default libc-based memory manager.
+ *
+ * @param dest <b>OUT</b>: Output destination
+ * @param itemCount <b>OUT</b>: Number of items found, can be NULL
+ * @param first <b>IN</b>: Pointer to first character <b>after</b> '?'
+ * @param afterLast <b>IN</b>: Pointer to character after the last one still in
+ * @return Error code or 0 on success
+ *
+ * @see uriDissectQueryMallocExA
+ * @see uriDissectQueryMallocExMmA
+ * @see uriComposeQueryA
+ * @see uriFreeQueryListA
+ * @see uriFreeQueryListMmA
+ * @since 0.7.0
+ */
+URI_PUBLIC int URI_FUNC(DissectQueryMalloc)(URI_TYPE(QueryList) ** dest,
+ int * itemCount, const URI_CHAR * first, const URI_CHAR * afterLast);
+
+
+
+/**
+ * Constructs a query list from the raw query string of a given URI.
+ * Uses default libc-based memory manager.
+ *
+ * @param dest <b>OUT</b>: Output destination
+ * @param itemCount <b>OUT</b>: Number of items found, can be NULL
+ * @param first <b>IN</b>: Pointer to first character <b>after</b> '?'
+ * @param afterLast <b>IN</b>: Pointer to character after the last one still in
+ * @param plusToSpace <b>IN</b>: Whether to convert '+' to ' ' or not
+ * @param breakConversion <b>IN</b>: Line break conversion mode
+ * @return Error code or 0 on success
+ *
+ * @see uriDissectQueryMallocA
+ * @see uriDissectQueryMallocExMmA
+ * @see uriComposeQueryExA
+ * @see uriFreeQueryListA
+ * @since 0.7.0
+ */
+URI_PUBLIC int URI_FUNC(DissectQueryMallocEx)(URI_TYPE(QueryList) ** dest,
+ int * itemCount, const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriBool plusToSpace, UriBreakConversion breakConversion);
+
+
+
+/**
+ * Constructs a query list from the raw query string of a given URI.
+ *
+ * @param dest <b>OUT</b>: Output destination
+ * @param itemCount <b>OUT</b>: Number of items found, can be NULL
+ * @param first <b>IN</b>: Pointer to first character <b>after</b> '?'
+ * @param afterLast <b>IN</b>: Pointer to character after the last one still in
+ * @param plusToSpace <b>IN</b>: Whether to convert '+' to ' ' or not
+ * @param breakConversion <b>IN</b>: Line break conversion mode
+ * @param memory <b>IN</b>: Memory manager to use, NULL for default libc
+ * @return Error code or 0 on success
+ *
+ * @see uriDissectQueryMallocA
+ * @see uriDissectQueryMallocExA
+ * @see uriComposeQueryExA
+ * @see uriFreeQueryListA
+ * @see uriFreeQueryListMmA
+ * @since 0.9.0
+ */
+URI_PUBLIC int URI_FUNC(DissectQueryMallocExMm)(URI_TYPE(QueryList) ** dest,
+ int * itemCount, const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriBool plusToSpace, UriBreakConversion breakConversion,
+ UriMemoryManager * memory);
+
+
+
+/**
+ * Frees all memory associated with the given query list.
+ * The structure itself is freed as well.
+ *
+ * @param queryList <b>INOUT</b>: Query list to free
+ *
+ * @see uriFreeQueryListMmA
+ * @since 0.7.0
+ */
+URI_PUBLIC void URI_FUNC(FreeQueryList)(URI_TYPE(QueryList) * queryList);
+
+
+
+/**
+ * Frees all memory associated with the given query list.
+ * The structure itself is freed as well.
+ *
+ * @param queryList <b>INOUT</b>: Query list to free
+ * @param memory <b>IN</b>: Memory manager to use, NULL for default libc
+ * @return Error code or 0 on success
+ *
+ * @see uriFreeQueryListA
+ * @since 0.9.0
+ */
+URI_PUBLIC int URI_FUNC(FreeQueryListMm)(URI_TYPE(QueryList) * queryList,
+ UriMemoryManager * memory);
+
+
+
+#ifdef __cplusplus
+}
+#endif
+
+
+
+#endif
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriBase.h b/src/arrow/cpp/src/arrow/vendored/uriparser/UriBase.h
new file mode 100644
index 000000000..166915950
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriBase.h
@@ -0,0 +1,377 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/**
+ * @file UriBase.h
+ * Holds definitions independent of the encoding pass.
+ */
+
+#ifndef URI_BASE_H
+#define URI_BASE_H 1
+
+
+
+/* Version helper macro */
+#define URI_ANSI_TO_UNICODE(x) L##x
+
+
+
+/* Version */
+#define URI_VER_MAJOR 0
+#define URI_VER_MINOR 9
+#define URI_VER_RELEASE 3
+#define URI_VER_SUFFIX_ANSI ""
+#define URI_VER_SUFFIX_UNICODE URI_ANSI_TO_UNICODE(URI_VER_SUFFIX_ANSI)
+
+
+
+/* More version helper macros */
+#define URI_INT_TO_ANSI_HELPER(x) #x
+#define URI_INT_TO_ANSI(x) URI_INT_TO_ANSI_HELPER(x)
+
+#define URI_INT_TO_UNICODE_HELPER(x) URI_ANSI_TO_UNICODE(#x)
+#define URI_INT_TO_UNICODE(x) URI_INT_TO_UNICODE_HELPER(x)
+
+#define URI_VER_ANSI_HELPER(ma, mi, r, s) \
+ URI_INT_TO_ANSI(ma) "." \
+ URI_INT_TO_ANSI(mi) "." \
+ URI_INT_TO_ANSI(r) \
+ s
+
+#define URI_VER_UNICODE_HELPER(ma, mi, r, s) \
+ URI_INT_TO_UNICODE(ma) L"." \
+ URI_INT_TO_UNICODE(mi) L"." \
+ URI_INT_TO_UNICODE(r) \
+ s
+
+
+
+/* Full version strings */
+#define URI_VER_ANSI URI_VER_ANSI_HELPER(URI_VER_MAJOR, URI_VER_MINOR, URI_VER_RELEASE, URI_VER_SUFFIX_ANSI)
+#define URI_VER_UNICODE URI_VER_UNICODE_HELPER(URI_VER_MAJOR, URI_VER_MINOR, URI_VER_RELEASE, URI_VER_SUFFIX_UNICODE)
+
+
+
+/* Unused parameter macro */
+#ifdef __GNUC__
+# define URI_UNUSED(x) unused_##x __attribute__((unused))
+#else
+# define URI_UNUSED(x) x
+#endif
+
+
+
+/* Import/export decorator */
+#if defined(URI_STATIC_BUILD)
+# define URI_PUBLIC
+#else
+# if defined(URI_LIBRARY_BUILD)
+# if defined(_MSC_VER)
+# define URI_PUBLIC __declspec(dllexport)
+# elif defined(URI_VISIBILITY)
+# define URI_PUBLIC __attribute__ ((visibility("default")))
+# else
+# define URI_PUBLIC
+# endif
+# else
+# if defined(_MSC_VER)
+# define URI_PUBLIC __declspec(dllimport)
+# else
+# define URI_PUBLIC
+# endif
+# endif
+#endif
+
+
+
+typedef int UriBool; /**< Boolean type */
+
+#define URI_TRUE 1
+#define URI_FALSE 0
+
+
+
+/* Shared errors */
+#define URI_SUCCESS 0
+#define URI_ERROR_SYNTAX 1 /* Parsed text violates expected format */
+#define URI_ERROR_NULL 2 /* One of the params passed was NULL
+ although it mustn't be */
+#define URI_ERROR_MALLOC 3 /* Requested memory could not be allocated */
+#define URI_ERROR_OUTPUT_TOO_LARGE 4 /* Some output is to large for the receiving buffer */
+#define URI_ERROR_NOT_IMPLEMENTED 8 /* The called function is not implemented yet */
+#define URI_ERROR_RANGE_INVALID 9 /* The parameters passed contained invalid ranges */
+#define URI_ERROR_MEMORY_MANAGER_INCOMPLETE 10 /* [>=0.9.0] The UriMemoryManager passed does not implement all needed functions */
+
+
+/* Errors specific to ToString */
+#define URI_ERROR_TOSTRING_TOO_LONG URI_ERROR_OUTPUT_TOO_LARGE /* Deprecated, test for URI_ERROR_OUTPUT_TOO_LARGE instead */
+
+/* Errors specific to AddBaseUri */
+#define URI_ERROR_ADDBASE_REL_BASE 5 /* Given base is not absolute */
+
+/* Errors specific to RemoveBaseUri */
+#define URI_ERROR_REMOVEBASE_REL_BASE 6 /* Given base is not absolute */
+#define URI_ERROR_REMOVEBASE_REL_SOURCE 7 /* Given base is not absolute */
+
+/* Error specific to uriTestMemoryManager */
+#define URI_ERROR_MEMORY_MANAGER_FAULTY 11 /* [>=0.9.0] The UriMemoryManager given did not pass the test suite */
+
+
+#ifndef URI_DOXYGEN
+# include <stdio.h> /* For NULL, snprintf */
+# include <ctype.h> /* For wchar_t */
+# include <string.h> /* For strlen, memset, memcpy */
+# include <stdlib.h> /* For malloc */
+#endif /* URI_DOXYGEN */
+
+
+
+/**
+ * Holds an IPv4 address.
+ */
+typedef struct UriIp4Struct {
+ unsigned char data[4]; /**< Each octet in one byte */
+} UriIp4; /**< @copydoc UriIp4Struct */
+
+
+
+/**
+ * Holds an IPv6 address.
+ */
+typedef struct UriIp6Struct {
+ unsigned char data[16]; /**< Each quad in two bytes */
+} UriIp6; /**< @copydoc UriIp6Struct */
+
+
+struct UriMemoryManagerStruct; /* foward declaration to break loop */
+
+
+/**
+ * Function signature that custom malloc(3) functions must conform to
+ *
+ * @since 0.9.0
+ */
+typedef void * (*UriFuncMalloc)(struct UriMemoryManagerStruct *, size_t);
+
+/**
+ * Function signature that custom calloc(3) functions must conform to
+ *
+ * @since 0.9.0
+ */
+typedef void * (*UriFuncCalloc)(struct UriMemoryManagerStruct *, size_t, size_t);
+
+/**
+ * Function signature that custom realloc(3) functions must conform to
+ *
+ * @since 0.9.0
+ */
+typedef void * (*UriFuncRealloc)(struct UriMemoryManagerStruct *, void *, size_t);
+
+/**
+ * Function signature that custom reallocarray(3) functions must conform to
+ *
+ * @since 0.9.0
+ */
+typedef void * (*UriFuncReallocarray)(struct UriMemoryManagerStruct *, void *, size_t, size_t);
+
+/**
+ * Function signature that custom free(3) functions must conform to
+ *
+ * @since 0.9.0
+ */
+typedef void (*UriFuncFree)(struct UriMemoryManagerStruct *, void *);
+
+
+/**
+ * Class-like interface of custom memory managers
+ *
+ * @see uriCompleteMemoryManager
+ * @see uriEmulateCalloc
+ * @see uriEmulateReallocarray
+ * @see uriTestMemoryManager
+ * @since 0.9.0
+ */
+typedef struct UriMemoryManagerStruct {
+ UriFuncMalloc malloc; /**< Pointer to custom malloc(3) */
+ UriFuncCalloc calloc; /**< Pointer to custom calloc(3); to emulate using malloc and memset see uriEmulateCalloc */
+ UriFuncRealloc realloc; /**< Pointer to custom realloc(3) */
+ UriFuncReallocarray reallocarray; /**< Pointer to custom reallocarray(3); to emulate using realloc see uriEmulateReallocarray */
+ UriFuncFree free; /**< Pointer to custom free(3) */
+ void * userData; /**< Pointer to data that the other function members need access to */
+} UriMemoryManager; /**< @copydoc UriMemoryManagerStruct */
+
+
+/**
+ * Specifies a line break conversion mode.
+ */
+typedef enum UriBreakConversionEnum {
+ URI_BR_TO_LF, /**< Convert to Unix line breaks ("\\x0a") */
+ URI_BR_TO_CRLF, /**< Convert to Windows line breaks ("\\x0d\\x0a") */
+ URI_BR_TO_CR, /**< Convert to Macintosh line breaks ("\\x0d") */
+ URI_BR_TO_UNIX = URI_BR_TO_LF, /**< @copydoc UriBreakConversionEnum::URI_BR_TO_LF */
+ URI_BR_TO_WINDOWS = URI_BR_TO_CRLF, /**< @copydoc UriBreakConversionEnum::URI_BR_TO_CRLF */
+ URI_BR_TO_MAC = URI_BR_TO_CR, /**< @copydoc UriBreakConversionEnum::URI_BR_TO_CR */
+ URI_BR_DONT_TOUCH /**< Copy line breaks unmodified */
+} UriBreakConversion; /**< @copydoc UriBreakConversionEnum */
+
+
+
+/**
+ * Specifies which component of a %URI has to be normalized.
+ */
+typedef enum UriNormalizationMaskEnum {
+ URI_NORMALIZED = 0, /**< Do not normalize anything */
+ URI_NORMALIZE_SCHEME = 1 << 0, /**< Normalize scheme (fix uppercase letters) */
+ URI_NORMALIZE_USER_INFO = 1 << 1, /**< Normalize user info (fix uppercase percent-encodings) */
+ URI_NORMALIZE_HOST = 1 << 2, /**< Normalize host (fix uppercase letters) */
+ URI_NORMALIZE_PATH = 1 << 3, /**< Normalize path (fix uppercase percent-encodings and redundant dot segments) */
+ URI_NORMALIZE_QUERY = 1 << 4, /**< Normalize query (fix uppercase percent-encodings) */
+ URI_NORMALIZE_FRAGMENT = 1 << 5 /**< Normalize fragment (fix uppercase percent-encodings) */
+} UriNormalizationMask; /**< @copydoc UriNormalizationMaskEnum */
+
+
+
+/**
+ * Specifies how to resolve %URI references.
+ */
+typedef enum UriResolutionOptionsEnum {
+ URI_RESOLVE_STRICTLY = 0, /**< Full RFC conformance */
+ URI_RESOLVE_IDENTICAL_SCHEME_COMPAT = 1 << 0 /**< Treat %URI to resolve with identical scheme as having no scheme */
+} UriResolutionOptions; /**< @copydoc UriResolutionOptionsEnum */
+
+
+
+/**
+ * Wraps a memory manager backend that only provides malloc and free
+ * to make a complete memory manager ready to be used.
+ *
+ * The core feature of this wrapper is that you don't need to implement
+ * realloc if you don't want to. The wrapped memory manager uses
+ * backend->malloc, memcpy, and backend->free and soieof(size_t) extra
+ * bytes per allocation to emulate fallback realloc for you.
+ *
+ * memory->calloc is uriEmulateCalloc.
+ * memory->free uses backend->free and handles the size header.
+ * memory->malloc uses backend->malloc and adds a size header.
+ * memory->realloc uses memory->malloc, memcpy, and memory->free and reads
+ * the size header.
+ * memory->reallocarray is uriEmulateReallocarray.
+ *
+ * The internal workings behind memory->free, memory->malloc, and
+ * memory->realloc may change so the functions exposed by these function
+ * pointer sshould be consided internal and not public API.
+ *
+ * @param memory <b>OUT</b>: Where to write the wrapped memory manager to
+ * @param backend <b>IN</b>: Memory manager to use as a backend
+ * @return Error code or 0 on success
+ *
+ * @see uriEmulateCalloc
+ * @see uriEmulateReallocarray
+ * @see UriMemoryManager
+ * @since 0.9.0
+ */
+URI_PUBLIC int uriCompleteMemoryManager(UriMemoryManager * memory,
+ UriMemoryManager * backend);
+
+
+
+/**
+ * Offers emulation of calloc(3) based on memory->malloc and memset.
+ * See "man 3 calloc" as well.
+ *
+ * @param memory <b>IN</b>: Memory manager to use, should not be NULL
+ * @param nmemb <b>IN</b>: Number of elements to allocate
+ * @param size <b>IN</b>: Size in bytes per element
+ * @return Pointer to allocated memory or NULL
+ *
+ * @see uriCompleteMemoryManager
+ * @see uriEmulateReallocarray
+ * @see UriMemoryManager
+ * @since 0.9.0
+ */
+URI_PUBLIC void * uriEmulateCalloc(UriMemoryManager * memory,
+ size_t nmemb, size_t size);
+
+
+
+/**
+ * Offers emulation of reallocarray(3) based on memory->realloc.
+ * See "man 3 reallocarray" as well.
+ *
+ * @param memory <b>IN</b>: Memory manager to use, should not be NULL
+ * @param ptr <b>IN</b>: Pointer allocated using memory->malloc/... or NULL
+ * @param nmemb <b>IN</b>: Number of elements to allocate
+ * @param size <b>IN</b>: Size in bytes per element
+ * @return Pointer to allocated memory or NULL
+ *
+ * @see uriCompleteMemoryManager
+ * @see uriEmulateCalloc
+ * @see UriMemoryManager
+ * @since 0.9.0
+ */
+URI_PUBLIC void * uriEmulateReallocarray(UriMemoryManager * memory,
+ void * ptr, size_t nmemb, size_t size);
+
+
+
+/**
+ * Run multiple tests against a given memory manager.
+ * For example, one test
+ * 1. allocates a small amount of memory,
+ * 2. writes some magic bytes to it,
+ * 3. reallocates it,
+ * 4. checks that previous values are still present,
+ * 5. and frees that memory.
+ *
+ * It is recommended to compile with AddressSanitizer enabled
+ * to take full advantage of uriTestMemoryManager.
+ *
+ * @param memory <b>IN</b>: Memory manager to use, should not be NULL
+ * @return Error code or 0 on success
+ *
+ * @see uriEmulateCalloc
+ * @see uriEmulateReallocarray
+ * @see UriMemoryManager
+ * @since 0.9.0
+ */
+URI_PUBLIC int uriTestMemoryManager(UriMemoryManager * memory);
+
+
+
+#endif /* URI_BASE_H */
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriCommon.c b/src/arrow/cpp/src/arrow/vendored/uriparser/UriCommon.c
new file mode 100644
index 000000000..824893d7c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriCommon.c
@@ -0,0 +1,572 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/* What encodings are enabled? */
+#include "UriDefsConfig.h"
+#if (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* Include SELF twice */
+# ifdef URI_ENABLE_ANSI
+# define URI_PASS_ANSI 1
+# include "UriCommon.c"
+# undef URI_PASS_ANSI
+# endif
+# ifdef URI_ENABLE_UNICODE
+# define URI_PASS_UNICODE 1
+# include "UriCommon.c"
+# undef URI_PASS_UNICODE
+# endif
+#else
+# ifdef URI_PASS_ANSI
+# include "UriDefsAnsi.h"
+# else
+# include "UriDefsUnicode.h"
+# include <wchar.h>
+# endif
+
+
+
+#ifndef URI_DOXYGEN
+# include "Uri.h"
+# include "UriCommon.h"
+#endif
+
+
+
+/*extern*/ const URI_CHAR * const URI_FUNC(SafeToPointTo) = _UT("X");
+/*extern*/ const URI_CHAR * const URI_FUNC(ConstPwd) = _UT(".");
+/*extern*/ const URI_CHAR * const URI_FUNC(ConstParent) = _UT("..");
+
+
+
+void URI_FUNC(ResetUri)(URI_TYPE(Uri) * uri) {
+ if (uri == NULL) {
+ return;
+ }
+ memset(uri, 0, sizeof(URI_TYPE(Uri)));
+}
+
+
+
+/* Compares two text ranges for equal text content */
+int URI_FUNC(CompareRange)(
+ const URI_TYPE(TextRange) * a,
+ const URI_TYPE(TextRange) * b) {
+ int diff;
+
+ /* NOTE: Both NULL means equal! */
+ if ((a == NULL) || (b == NULL)) {
+ return ((a == NULL) ? 0 : 1) - ((b == NULL) ? 0 : 1);
+ }
+
+ /* NOTE: Both NULL means equal! */
+ if ((a->first == NULL) || (b->first == NULL)) {
+ return ((a->first == NULL) ? 0 : 1) - ((b->first == NULL) ? 0 : 1);
+ }
+
+ diff = ((int)(a->afterLast - a->first) - (int)(b->afterLast - b->first));
+ if (diff > 0) {
+ return 1;
+ } else if (diff < 0) {
+ return -1;
+ }
+
+ diff = URI_STRNCMP(a->first, b->first, (a->afterLast - a->first));
+
+ if (diff > 0) {
+ return 1;
+ } else if (diff < 0) {
+ return -1;
+ }
+
+ return diff;
+}
+
+
+
+/* Properly removes "." and ".." path segments */
+UriBool URI_FUNC(RemoveDotSegments)(URI_TYPE(Uri) * uri,
+ UriBool relative, UriMemoryManager * memory) {
+ if (uri == NULL) {
+ return URI_TRUE;
+ }
+ return URI_FUNC(RemoveDotSegmentsEx)(uri, relative, uri->owner, memory);
+}
+
+
+
+UriBool URI_FUNC(RemoveDotSegmentsEx)(URI_TYPE(Uri) * uri,
+ UriBool relative, UriBool pathOwned, UriMemoryManager * memory) {
+ URI_TYPE(PathSegment) * walker;
+ if ((uri == NULL) || (uri->pathHead == NULL)) {
+ return URI_TRUE;
+ }
+
+ walker = uri->pathHead;
+ walker->reserved = NULL; /* Prev pointer */
+ do {
+ UriBool removeSegment = URI_FALSE;
+ int len = (int)(walker->text.afterLast - walker->text.first);
+ switch (len) {
+ case 1:
+ if ((walker->text.first)[0] == _UT('.')) {
+ /* "." segment -> remove if not essential */
+ URI_TYPE(PathSegment) * const prev = walker->reserved;
+ URI_TYPE(PathSegment) * const nextBackup = walker->next;
+
+ /* Is this dot segment essential? */
+ removeSegment = URI_TRUE;
+ if (relative && (walker == uri->pathHead) && (walker->next != NULL)) {
+ const URI_CHAR * ch = walker->next->text.first;
+ for (; ch < walker->next->text.afterLast; ch++) {
+ if (*ch == _UT(':')) {
+ removeSegment = URI_FALSE;
+ break;
+ }
+ }
+ }
+
+ if (removeSegment) {
+ /* Last segment? */
+ if (walker->next != NULL) {
+ /* Not last segment */
+ walker->next->reserved = prev;
+
+ if (prev == NULL) {
+ /* First but not last segment */
+ uri->pathHead = walker->next;
+ } else {
+ /* Middle segment */
+ prev->next = walker->next;
+ }
+
+ if (pathOwned && (walker->text.first != walker->text.afterLast)) {
+ memory->free(memory, (URI_CHAR *)walker->text.first);
+ }
+ memory->free(memory, walker);
+ } else {
+ /* Last segment */
+ if (pathOwned && (walker->text.first != walker->text.afterLast)) {
+ memory->free(memory, (URI_CHAR *)walker->text.first);
+ }
+
+ if (prev == NULL) {
+ /* Last and first */
+ if (URI_FUNC(IsHostSet)(uri)) {
+ /* Replace "." with empty segment to represent trailing slash */
+ walker->text.first = URI_FUNC(SafeToPointTo);
+ walker->text.afterLast = URI_FUNC(SafeToPointTo);
+ } else {
+ memory->free(memory, walker);
+
+ uri->pathHead = NULL;
+ uri->pathTail = NULL;
+ }
+ } else {
+ /* Last but not first, replace "." with empty segment to represent trailing slash */
+ walker->text.first = URI_FUNC(SafeToPointTo);
+ walker->text.afterLast = URI_FUNC(SafeToPointTo);
+ }
+ }
+
+ walker = nextBackup;
+ }
+ }
+ break;
+
+ case 2:
+ if (((walker->text.first)[0] == _UT('.'))
+ && ((walker->text.first)[1] == _UT('.'))) {
+ /* Path ".." -> remove this and the previous segment */
+ URI_TYPE(PathSegment) * const prev = walker->reserved;
+ URI_TYPE(PathSegment) * prevPrev;
+ URI_TYPE(PathSegment) * const nextBackup = walker->next;
+
+ removeSegment = URI_TRUE;
+ if (relative) {
+ if (prev == NULL) {
+ removeSegment = URI_FALSE;
+ } else if ((prev != NULL)
+ && ((prev->text.afterLast - prev->text.first) == 2)
+ && ((prev->text.first)[0] == _UT('.'))
+ && ((prev->text.first)[1] == _UT('.'))) {
+ removeSegment = URI_FALSE;
+ }
+ }
+
+ if (removeSegment) {
+ if (prev != NULL) {
+ /* Not first segment */
+ prevPrev = prev->reserved;
+ if (prevPrev != NULL) {
+ /* Not even prev is the first one */
+ prevPrev->next = walker->next;
+ if (walker->next != NULL) {
+ walker->next->reserved = prevPrev;
+ } else {
+ /* Last segment -> insert "" segment to represent trailing slash, update tail */
+ URI_TYPE(PathSegment) * const segment = memory->calloc(memory, 1, sizeof(URI_TYPE(PathSegment)));
+ if (segment == NULL) {
+ if (pathOwned && (walker->text.first != walker->text.afterLast)) {
+ memory->free(memory, (URI_CHAR *)walker->text.first);
+ }
+ memory->free(memory, walker);
+
+ if (pathOwned && (prev->text.first != prev->text.afterLast)) {
+ memory->free(memory, (URI_CHAR *)prev->text.first);
+ }
+ memory->free(memory, prev);
+
+ return URI_FALSE; /* Raises malloc error */
+ }
+ segment->text.first = URI_FUNC(SafeToPointTo);
+ segment->text.afterLast = URI_FUNC(SafeToPointTo);
+ prevPrev->next = segment;
+ uri->pathTail = segment;
+ }
+
+ if (pathOwned && (walker->text.first != walker->text.afterLast)) {
+ memory->free(memory, (URI_CHAR *)walker->text.first);
+ }
+ memory->free(memory, walker);
+
+ if (pathOwned && (prev->text.first != prev->text.afterLast)) {
+ memory->free(memory, (URI_CHAR *)prev->text.first);
+ }
+ memory->free(memory, prev);
+
+ walker = nextBackup;
+ } else {
+ /* Prev is the first segment */
+ if (walker->next != NULL) {
+ uri->pathHead = walker->next;
+ walker->next->reserved = NULL;
+
+ if (pathOwned && (walker->text.first != walker->text.afterLast)) {
+ memory->free(memory, (URI_CHAR *)walker->text.first);
+ }
+ memory->free(memory, walker);
+ } else {
+ /* Re-use segment for "" path segment to represent trailing slash, update tail */
+ URI_TYPE(PathSegment) * const segment = walker;
+ if (pathOwned && (segment->text.first != segment->text.afterLast)) {
+ memory->free(memory, (URI_CHAR *)segment->text.first);
+ }
+ segment->text.first = URI_FUNC(SafeToPointTo);
+ segment->text.afterLast = URI_FUNC(SafeToPointTo);
+ uri->pathHead = segment;
+ uri->pathTail = segment;
+ }
+
+ if (pathOwned && (prev->text.first != prev->text.afterLast)) {
+ memory->free(memory, (URI_CHAR *)prev->text.first);
+ }
+ memory->free(memory, prev);
+
+ walker = nextBackup;
+ }
+ } else {
+ URI_TYPE(PathSegment) * const anotherNextBackup = walker->next;
+ /* First segment -> update head pointer */
+ uri->pathHead = walker->next;
+ if (walker->next != NULL) {
+ walker->next->reserved = NULL;
+ } else {
+ /* Last segment -> update tail */
+ uri->pathTail = NULL;
+ }
+
+ if (pathOwned && (walker->text.first != walker->text.afterLast)) {
+ memory->free(memory, (URI_CHAR *)walker->text.first);
+ }
+ memory->free(memory, walker);
+
+ walker = anotherNextBackup;
+ }
+ }
+ }
+ break;
+
+ }
+
+ if (!removeSegment) {
+ if (walker->next != NULL) {
+ walker->next->reserved = walker;
+ } else {
+ /* Last segment -> update tail */
+ uri->pathTail = walker;
+ }
+ walker = walker->next;
+ }
+ } while (walker != NULL);
+
+ return URI_TRUE;
+}
+
+
+
+/* Properly removes "." and ".." path segments */
+UriBool URI_FUNC(RemoveDotSegmentsAbsolute)(URI_TYPE(Uri) * uri,
+ UriMemoryManager * memory) {
+ const UriBool ABSOLUTE = URI_FALSE;
+ return URI_FUNC(RemoveDotSegments)(uri, ABSOLUTE, memory);
+}
+
+
+
+unsigned char URI_FUNC(HexdigToInt)(URI_CHAR hexdig) {
+ switch (hexdig) {
+ case _UT('0'):
+ case _UT('1'):
+ case _UT('2'):
+ case _UT('3'):
+ case _UT('4'):
+ case _UT('5'):
+ case _UT('6'):
+ case _UT('7'):
+ case _UT('8'):
+ case _UT('9'):
+ return (unsigned char)(9 + hexdig - _UT('9'));
+
+ case _UT('a'):
+ case _UT('b'):
+ case _UT('c'):
+ case _UT('d'):
+ case _UT('e'):
+ case _UT('f'):
+ return (unsigned char)(15 + hexdig - _UT('f'));
+
+ case _UT('A'):
+ case _UT('B'):
+ case _UT('C'):
+ case _UT('D'):
+ case _UT('E'):
+ case _UT('F'):
+ return (unsigned char)(15 + hexdig - _UT('F'));
+
+ default:
+ return 0;
+ }
+}
+
+
+
+URI_CHAR URI_FUNC(HexToLetter)(unsigned int value) {
+ /* Uppercase recommended in section 2.1. of RFC 3986 *
+ * http://tools.ietf.org/html/rfc3986#section-2.1 */
+ return URI_FUNC(HexToLetterEx)(value, URI_TRUE);
+}
+
+
+
+URI_CHAR URI_FUNC(HexToLetterEx)(unsigned int value, UriBool uppercase) {
+ switch (value) {
+ case 0: return _UT('0');
+ case 1: return _UT('1');
+ case 2: return _UT('2');
+ case 3: return _UT('3');
+ case 4: return _UT('4');
+ case 5: return _UT('5');
+ case 6: return _UT('6');
+ case 7: return _UT('7');
+ case 8: return _UT('8');
+ case 9: return _UT('9');
+
+ case 10: return (uppercase == URI_TRUE) ? _UT('A') : _UT('a');
+ case 11: return (uppercase == URI_TRUE) ? _UT('B') : _UT('b');
+ case 12: return (uppercase == URI_TRUE) ? _UT('C') : _UT('c');
+ case 13: return (uppercase == URI_TRUE) ? _UT('D') : _UT('d');
+ case 14: return (uppercase == URI_TRUE) ? _UT('E') : _UT('e');
+ default: return (uppercase == URI_TRUE) ? _UT('F') : _UT('f');
+ }
+}
+
+
+
+/* Checks if a URI has the host component set. */
+UriBool URI_FUNC(IsHostSet)(const URI_TYPE(Uri) * uri) {
+ return (uri != NULL)
+ && ((uri->hostText.first != NULL)
+ || (uri->hostData.ip4 != NULL)
+ || (uri->hostData.ip6 != NULL)
+ || (uri->hostData.ipFuture.first != NULL)
+ );
+}
+
+
+
+/* Copies the path segment list from one URI to another. */
+UriBool URI_FUNC(CopyPath)(URI_TYPE(Uri) * dest,
+ const URI_TYPE(Uri) * source, UriMemoryManager * memory) {
+ if (source->pathHead == NULL) {
+ /* No path component */
+ dest->pathHead = NULL;
+ dest->pathTail = NULL;
+ } else {
+ /* Copy list but not the text contained */
+ URI_TYPE(PathSegment) * sourceWalker = source->pathHead;
+ URI_TYPE(PathSegment) * destPrev = NULL;
+ do {
+ URI_TYPE(PathSegment) * cur = memory->malloc(memory, sizeof(URI_TYPE(PathSegment)));
+ if (cur == NULL) {
+ /* Fix broken list */
+ if (destPrev != NULL) {
+ destPrev->next = NULL;
+ }
+ return URI_FALSE; /* Raises malloc error */
+ }
+
+ /* From this functions usage we know that *
+ * the dest URI cannot be uri->owner */
+ cur->text = sourceWalker->text;
+ if (destPrev == NULL) {
+ /* First segment ever */
+ dest->pathHead = cur;
+ } else {
+ destPrev->next = cur;
+ }
+ destPrev = cur;
+ sourceWalker = sourceWalker->next;
+ } while (sourceWalker != NULL);
+ dest->pathTail = destPrev;
+ dest->pathTail->next = NULL;
+ }
+
+ dest->absolutePath = source->absolutePath;
+ return URI_TRUE;
+}
+
+
+
+/* Copies the authority part of an URI over to another. */
+UriBool URI_FUNC(CopyAuthority)(URI_TYPE(Uri) * dest,
+ const URI_TYPE(Uri) * source, UriMemoryManager * memory) {
+ /* From this functions usage we know that *
+ * the dest URI cannot be uri->owner */
+
+ /* Copy userInfo */
+ dest->userInfo = source->userInfo;
+
+ /* Copy hostText */
+ dest->hostText = source->hostText;
+
+ /* Copy hostData */
+ if (source->hostData.ip4 != NULL) {
+ dest->hostData.ip4 = memory->malloc(memory, sizeof(UriIp4));
+ if (dest->hostData.ip4 == NULL) {
+ return URI_FALSE; /* Raises malloc error */
+ }
+ *(dest->hostData.ip4) = *(source->hostData.ip4);
+ dest->hostData.ip6 = NULL;
+ dest->hostData.ipFuture.first = NULL;
+ dest->hostData.ipFuture.afterLast = NULL;
+ } else if (source->hostData.ip6 != NULL) {
+ dest->hostData.ip4 = NULL;
+ dest->hostData.ip6 = memory->malloc(memory, sizeof(UriIp6));
+ if (dest->hostData.ip6 == NULL) {
+ return URI_FALSE; /* Raises malloc error */
+ }
+ *(dest->hostData.ip6) = *(source->hostData.ip6);
+ dest->hostData.ipFuture.first = NULL;
+ dest->hostData.ipFuture.afterLast = NULL;
+ } else {
+ dest->hostData.ip4 = NULL;
+ dest->hostData.ip6 = NULL;
+ dest->hostData.ipFuture = source->hostData.ipFuture;
+ }
+
+ /* Copy portText */
+ dest->portText = source->portText;
+
+ return URI_TRUE;
+}
+
+
+
+UriBool URI_FUNC(FixAmbiguity)(URI_TYPE(Uri) * uri,
+ UriMemoryManager * memory) {
+ URI_TYPE(PathSegment) * segment;
+
+ if ( /* Case 1: absolute path, empty first segment */
+ (uri->absolutePath
+ && (uri->pathHead != NULL)
+ && (uri->pathHead->text.afterLast == uri->pathHead->text.first))
+
+ /* Case 2: relative path, empty first and second segment */
+ || (!uri->absolutePath
+ && (uri->pathHead != NULL)
+ && (uri->pathHead->next != NULL)
+ && (uri->pathHead->text.afterLast == uri->pathHead->text.first)
+ && (uri->pathHead->next->text.afterLast == uri->pathHead->next->text.first))) {
+ /* NOOP */
+ } else {
+ return URI_TRUE;
+ }
+
+ segment = memory->malloc(memory, 1 * sizeof(URI_TYPE(PathSegment)));
+ if (segment == NULL) {
+ return URI_FALSE; /* Raises malloc error */
+ }
+
+ /* Insert "." segment in front */
+ segment->next = uri->pathHead;
+ segment->text.first = URI_FUNC(ConstPwd);
+ segment->text.afterLast = URI_FUNC(ConstPwd) + 1;
+ uri->pathHead = segment;
+ return URI_TRUE;
+}
+
+
+
+void URI_FUNC(FixEmptyTrailSegment)(URI_TYPE(Uri) * uri,
+ UriMemoryManager * memory) {
+ /* Fix path if only one empty segment */
+ if (!uri->absolutePath
+ && !URI_FUNC(IsHostSet)(uri)
+ && (uri->pathHead != NULL)
+ && (uri->pathHead->next == NULL)
+ && (uri->pathHead->text.first == uri->pathHead->text.afterLast)) {
+ memory->free(memory, uri->pathHead);
+ uri->pathHead = NULL;
+ uri->pathTail = NULL;
+ }
+}
+
+
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriCommon.h b/src/arrow/cpp/src/arrow/vendored/uriparser/UriCommon.h
new file mode 100644
index 000000000..97b997a9a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriCommon.h
@@ -0,0 +1,109 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#if (defined(URI_PASS_ANSI) && !defined(URI_COMMON_H_ANSI)) \
+ || (defined(URI_PASS_UNICODE) && !defined(URI_COMMON_H_UNICODE)) \
+ || (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* What encodings are enabled? */
+#include "UriDefsConfig.h"
+#if (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* Include SELF twice */
+# ifdef URI_ENABLE_ANSI
+# define URI_PASS_ANSI 1
+# include "UriCommon.h"
+# undef URI_PASS_ANSI
+# endif
+# ifdef URI_ENABLE_UNICODE
+# define URI_PASS_UNICODE 1
+# include "UriCommon.h"
+# undef URI_PASS_UNICODE
+# endif
+/* Only one pass for each encoding */
+#elif (defined(URI_PASS_ANSI) && !defined(URI_COMMON_H_ANSI) \
+ && defined(URI_ENABLE_ANSI)) || (defined(URI_PASS_UNICODE) \
+ && !defined(URI_COMMON_H_UNICODE) && defined(URI_ENABLE_UNICODE))
+# ifdef URI_PASS_ANSI
+# define URI_COMMON_H_ANSI 1
+# include "UriDefsAnsi.h"
+# else
+# define URI_COMMON_H_UNICODE 1
+# include "UriDefsUnicode.h"
+# endif
+
+
+
+/* Used to point to from empty path segments.
+ * X.first and X.afterLast must be the same non-NULL value then. */
+extern const URI_CHAR * const URI_FUNC(SafeToPointTo);
+extern const URI_CHAR * const URI_FUNC(ConstPwd);
+extern const URI_CHAR * const URI_FUNC(ConstParent);
+
+
+
+void URI_FUNC(ResetUri)(URI_TYPE(Uri) * uri);
+
+int URI_FUNC(CompareRange)(
+ const URI_TYPE(TextRange) * a,
+ const URI_TYPE(TextRange) * b);
+
+UriBool URI_FUNC(RemoveDotSegmentsAbsolute)(URI_TYPE(Uri) * uri,
+ UriMemoryManager * memory);
+UriBool URI_FUNC(RemoveDotSegments)(URI_TYPE(Uri) * uri, UriBool relative,
+ UriMemoryManager * memory);
+UriBool URI_FUNC(RemoveDotSegmentsEx)(URI_TYPE(Uri) * uri,
+ UriBool relative, UriBool pathOwned, UriMemoryManager * memory);
+
+unsigned char URI_FUNC(HexdigToInt)(URI_CHAR hexdig);
+URI_CHAR URI_FUNC(HexToLetter)(unsigned int value);
+URI_CHAR URI_FUNC(HexToLetterEx)(unsigned int value, UriBool uppercase);
+
+UriBool URI_FUNC(IsHostSet)(const URI_TYPE(Uri) * uri);
+
+UriBool URI_FUNC(CopyPath)(URI_TYPE(Uri) * dest, const URI_TYPE(Uri) * source,
+ UriMemoryManager * memory);
+UriBool URI_FUNC(CopyAuthority)(URI_TYPE(Uri) * dest,
+ const URI_TYPE(Uri) * source, UriMemoryManager * memory);
+
+UriBool URI_FUNC(FixAmbiguity)(URI_TYPE(Uri) * uri, UriMemoryManager * memory);
+void URI_FUNC(FixEmptyTrailSegment)(URI_TYPE(Uri) * uri,
+ UriMemoryManager * memory);
+
+
+#endif
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriCompare.c b/src/arrow/cpp/src/arrow/vendored/uriparser/UriCompare.c
new file mode 100644
index 000000000..e63573a1f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriCompare.c
@@ -0,0 +1,168 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/* What encodings are enabled? */
+#include "UriDefsConfig.h"
+#if (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* Include SELF twice */
+# ifdef URI_ENABLE_ANSI
+# define URI_PASS_ANSI 1
+# include "UriCompare.c"
+# undef URI_PASS_ANSI
+# endif
+# ifdef URI_ENABLE_UNICODE
+# define URI_PASS_UNICODE 1
+# include "UriCompare.c"
+# undef URI_PASS_UNICODE
+# endif
+#else
+# ifdef URI_PASS_ANSI
+# include "UriDefsAnsi.h"
+# else
+# include "UriDefsUnicode.h"
+# include <wchar.h>
+# endif
+
+
+
+#ifndef URI_DOXYGEN
+# include "Uri.h"
+# include "UriIp4.h"
+# include "UriCommon.h"
+#endif
+
+
+
+UriBool URI_FUNC(EqualsUri)(const URI_TYPE(Uri) * a,
+ const URI_TYPE(Uri) * b) {
+ /* NOTE: Both NULL means equal! */
+ if ((a == NULL) || (b == NULL)) {
+ return ((a == NULL) && (b == NULL)) ? URI_TRUE : URI_FALSE;
+ }
+
+ /* scheme */
+ if (URI_FUNC(CompareRange)(&(a->scheme), &(b->scheme))) {
+ return URI_FALSE;
+ }
+
+ /* absolutePath */
+ if ((a->scheme.first == NULL)&& (a->absolutePath != b->absolutePath)) {
+ return URI_FALSE;
+ }
+
+ /* userInfo */
+ if (URI_FUNC(CompareRange)(&(a->userInfo), &(b->userInfo))) {
+ return URI_FALSE;
+ }
+
+ /* Host */
+ if (((a->hostData.ip4 == NULL) != (b->hostData.ip4 == NULL))
+ || ((a->hostData.ip6 == NULL) != (b->hostData.ip6 == NULL))
+ || ((a->hostData.ipFuture.first == NULL)
+ != (b->hostData.ipFuture.first == NULL))) {
+ return URI_FALSE;
+ }
+
+ if (a->hostData.ip4 != NULL) {
+ if (memcmp(a->hostData.ip4->data, b->hostData.ip4->data, 4)) {
+ return URI_FALSE;
+ }
+ }
+
+ if (a->hostData.ip6 != NULL) {
+ if (memcmp(a->hostData.ip6->data, b->hostData.ip6->data, 16)) {
+ return URI_FALSE;
+ }
+ }
+
+ if (a->hostData.ipFuture.first != NULL) {
+ if (URI_FUNC(CompareRange)(&(a->hostData.ipFuture), &(b->hostData.ipFuture))) {
+ return URI_FALSE;
+ }
+ }
+
+ if ((a->hostData.ip4 == NULL)
+ && (a->hostData.ip6 == NULL)
+ && (a->hostData.ipFuture.first == NULL)) {
+ if (URI_FUNC(CompareRange)(&(a->hostText), &(b->hostText))) {
+ return URI_FALSE;
+ }
+ }
+
+ /* portText */
+ if (URI_FUNC(CompareRange)(&(a->portText), &(b->portText))) {
+ return URI_FALSE;
+ }
+
+ /* Path */
+ if ((a->pathHead == NULL) != (b->pathHead == NULL)) {
+ return URI_FALSE;
+ }
+
+ if (a->pathHead != NULL) {
+ URI_TYPE(PathSegment) * walkA = a->pathHead;
+ URI_TYPE(PathSegment) * walkB = b->pathHead;
+ do {
+ if (URI_FUNC(CompareRange)(&(walkA->text), &(walkB->text))) {
+ return URI_FALSE;
+ }
+ if ((walkA->next == NULL) != (walkB->next == NULL)) {
+ return URI_FALSE;
+ }
+ walkA = walkA->next;
+ walkB = walkB->next;
+ } while (walkA != NULL);
+ }
+
+ /* query */
+ if (URI_FUNC(CompareRange)(&(a->query), &(b->query))) {
+ return URI_FALSE;
+ }
+
+ /* fragment */
+ if (URI_FUNC(CompareRange)(&(a->fragment), &(b->fragment))) {
+ return URI_FALSE;
+ }
+
+ return URI_TRUE; /* Equal*/
+}
+
+
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriDefsAnsi.h b/src/arrow/cpp/src/arrow/vendored/uriparser/UriDefsAnsi.h
new file mode 100644
index 000000000..d6dbcad1b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriDefsAnsi.h
@@ -0,0 +1,82 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/**
+ * @file UriDefsAnsi.h
+ * Holds definitions for the ANSI pass.
+ * NOTE: This header is included N times, not once.
+ */
+
+/* Allow multi inclusion */
+#include "UriDefsConfig.h"
+
+
+
+#undef URI_CHAR
+#define URI_CHAR char
+
+#undef _UT
+#define _UT(x) x
+
+
+
+#undef URI_FUNC
+#define URI_FUNC(x) uri##x##A
+
+#undef URI_TYPE
+#define URI_TYPE(x) Uri##x##A
+
+
+
+#undef URI_STRLEN
+#define URI_STRLEN strlen
+#undef URI_STRCPY
+#define URI_STRCPY strcpy
+#undef URI_STRCMP
+#define URI_STRCMP strcmp
+#undef URI_STRNCMP
+#define URI_STRNCMP strncmp
+
+/* TODO Remove on next source-compatibility break */
+#undef URI_SNPRINTF
+#if (defined(__WIN32__) || defined(_WIN32) || defined(WIN32))
+# define URI_SNPRINTF _snprintf
+#else
+# define URI_SNPRINTF snprintf
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriDefsConfig.h b/src/arrow/cpp/src/arrow/vendored/uriparser/UriDefsConfig.h
new file mode 100644
index 000000000..d604494b0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriDefsConfig.h
@@ -0,0 +1,102 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/**
+ * @file UriDefsConfig.h
+ * Adjusts the internal configuration after processing external definitions.
+ */
+
+#ifndef URI_DEFS_CONFIG_H
+#define URI_DEFS_CONFIG_H 1
+
+
+
+/* Deny external overriding */
+#undef URI_ENABLE_ANSI /* Internal for !URI_NO_ANSI */
+#undef URI_ENABLE_UNICODE /* Internal for !URI_NO_UNICODE */
+
+
+
+/* Encoding */
+#ifdef URI_NO_ANSI
+# ifdef URI_NO_UNICODE
+/* No encoding at all */
+# error URI_NO_ANSI and URI_NO_UNICODE cannot go together.
+# else
+/* Unicode only */
+# define URI_ENABLE_UNICODE 1
+# endif
+#else
+# ifdef URI_NO_UNICODE
+/* ANSI only */
+# define URI_ENABLE_ANSI 1
+# else
+/* Both ANSI and Unicode */
+# define URI_ENABLE_ANSI 1
+# define URI_ENABLE_UNICODE 1
+# endif
+#endif
+
+
+
+/* Function inlining, not ANSI/ISO C! */
+#if defined(URI_DOXYGEN)
+# define URI_INLINE
+#elif defined(__INTEL_COMPILER)
+/* Intel C/C++ */
+/* http://predef.sourceforge.net/precomp.html#sec20 */
+/* http://www.intel.com/support/performancetools/c/windows/sb/CS-007751.htm#2 */
+/* EDIT 11/5/20. Intel changed __force_inline to __forceinline */
+# define URI_INLINE __forceinline
+#elif defined(_MSC_VER)
+/* Microsoft Visual C++ */
+/* http://predef.sourceforge.net/precomp.html#sec32 */
+/* http://msdn2.microsoft.com/en-us/library/ms882281.aspx */
+# define URI_INLINE __forceinline
+#elif (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L))
+/* C99, "inline" is a keyword */
+# define URI_INLINE inline
+#else
+/* No inlining */
+# define URI_INLINE
+#endif
+
+
+
+#endif /* URI_DEFS_CONFIG_H */
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriDefsUnicode.h b/src/arrow/cpp/src/arrow/vendored/uriparser/UriDefsUnicode.h
new file mode 100644
index 000000000..8bb8bc2be
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriDefsUnicode.h
@@ -0,0 +1,82 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/**
+ * @file UriDefsUnicode.h
+ * Holds definitions for the Unicode pass.
+ * NOTE: This header is included N times, not once.
+ */
+
+/* Allow multi inclusion */
+#include "UriDefsConfig.h"
+
+
+
+#undef URI_CHAR
+#define URI_CHAR wchar_t
+
+#undef _UT
+#define _UT(x) L##x
+
+
+
+#undef URI_FUNC
+#define URI_FUNC(x) uri##x##W
+
+#undef URI_TYPE
+#define URI_TYPE(x) Uri##x##W
+
+
+
+#undef URI_STRLEN
+#define URI_STRLEN wcslen
+#undef URI_STRCPY
+#define URI_STRCPY wcscpy
+#undef URI_STRCMP
+#define URI_STRCMP wcscmp
+#undef URI_STRNCMP
+#define URI_STRNCMP wcsncmp
+
+/* TODO Remove on next source-compatibility break */
+#undef URI_SNPRINTF
+#if (defined(__WIN32__) || defined(_WIN32) || defined(WIN32))
+# define URI_SNPRINTF _snwprintf
+#else
+# define URI_SNPRINTF swprintf
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriEscape.c b/src/arrow/cpp/src/arrow/vendored/uriparser/UriEscape.c
new file mode 100644
index 000000000..46f3fd4ac
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriEscape.c
@@ -0,0 +1,453 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/* What encodings are enabled? */
+#include "UriDefsConfig.h"
+#if (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* Include SELF twice */
+# ifdef URI_ENABLE_ANSI
+# define URI_PASS_ANSI 1
+# include "UriEscape.c"
+# undef URI_PASS_ANSI
+# endif
+# ifdef URI_ENABLE_UNICODE
+# define URI_PASS_UNICODE 1
+# include "UriEscape.c"
+# undef URI_PASS_UNICODE
+# endif
+#else
+# ifdef URI_PASS_ANSI
+# include "UriDefsAnsi.h"
+# else
+# include "UriDefsUnicode.h"
+# include <wchar.h>
+# endif
+
+
+
+#ifndef URI_DOXYGEN
+# include "Uri.h"
+# include "UriCommon.h"
+#endif
+
+
+
+URI_CHAR * URI_FUNC(Escape)(const URI_CHAR * in, URI_CHAR * out,
+ UriBool spaceToPlus, UriBool normalizeBreaks) {
+ return URI_FUNC(EscapeEx)(in, NULL, out, spaceToPlus, normalizeBreaks);
+}
+
+
+
+URI_CHAR * URI_FUNC(EscapeEx)(const URI_CHAR * inFirst,
+ const URI_CHAR * inAfterLast, URI_CHAR * out,
+ UriBool spaceToPlus, UriBool normalizeBreaks) {
+ const URI_CHAR * read = inFirst;
+ URI_CHAR * write = out;
+ UriBool prevWasCr = URI_FALSE;
+ if ((out == NULL) || (inFirst == out)) {
+ return NULL;
+ } else if (inFirst == NULL) {
+ if (out != NULL) {
+ out[0] = _UT('\0');
+ }
+ return out;
+ }
+
+ for (;;) {
+ if ((inAfterLast != NULL) && (read >= inAfterLast)) {
+ write[0] = _UT('\0');
+ return write;
+ }
+
+ switch (read[0]) {
+ case _UT('\0'):
+ write[0] = _UT('\0');
+ return write;
+
+ case _UT(' '):
+ if (spaceToPlus) {
+ write[0] = _UT('+');
+ write++;
+ } else {
+ write[0] = _UT('%');
+ write[1] = _UT('2');
+ write[2] = _UT('0');
+ write += 3;
+ }
+ prevWasCr = URI_FALSE;
+ break;
+
+ case _UT('a'): /* ALPHA */
+ case _UT('A'):
+ case _UT('b'):
+ case _UT('B'):
+ case _UT('c'):
+ case _UT('C'):
+ case _UT('d'):
+ case _UT('D'):
+ case _UT('e'):
+ case _UT('E'):
+ case _UT('f'):
+ case _UT('F'):
+ case _UT('g'):
+ case _UT('G'):
+ case _UT('h'):
+ case _UT('H'):
+ case _UT('i'):
+ case _UT('I'):
+ case _UT('j'):
+ case _UT('J'):
+ case _UT('k'):
+ case _UT('K'):
+ case _UT('l'):
+ case _UT('L'):
+ case _UT('m'):
+ case _UT('M'):
+ case _UT('n'):
+ case _UT('N'):
+ case _UT('o'):
+ case _UT('O'):
+ case _UT('p'):
+ case _UT('P'):
+ case _UT('q'):
+ case _UT('Q'):
+ case _UT('r'):
+ case _UT('R'):
+ case _UT('s'):
+ case _UT('S'):
+ case _UT('t'):
+ case _UT('T'):
+ case _UT('u'):
+ case _UT('U'):
+ case _UT('v'):
+ case _UT('V'):
+ case _UT('w'):
+ case _UT('W'):
+ case _UT('x'):
+ case _UT('X'):
+ case _UT('y'):
+ case _UT('Y'):
+ case _UT('z'):
+ case _UT('Z'):
+ case _UT('0'): /* DIGIT */
+ case _UT('1'):
+ case _UT('2'):
+ case _UT('3'):
+ case _UT('4'):
+ case _UT('5'):
+ case _UT('6'):
+ case _UT('7'):
+ case _UT('8'):
+ case _UT('9'):
+ case _UT('-'): /* "-" / "." / "_" / "~" */
+ case _UT('.'):
+ case _UT('_'):
+ case _UT('~'):
+ /* Copy unmodified */
+ write[0] = read[0];
+ write++;
+
+ prevWasCr = URI_FALSE;
+ break;
+
+ case _UT('\x0a'):
+ if (normalizeBreaks) {
+ if (!prevWasCr) {
+ write[0] = _UT('%');
+ write[1] = _UT('0');
+ write[2] = _UT('D');
+ write[3] = _UT('%');
+ write[4] = _UT('0');
+ write[5] = _UT('A');
+ write += 6;
+ }
+ } else {
+ write[0] = _UT('%');
+ write[1] = _UT('0');
+ write[2] = _UT('A');
+ write += 3;
+ }
+ prevWasCr = URI_FALSE;
+ break;
+
+ case _UT('\x0d'):
+ if (normalizeBreaks) {
+ write[0] = _UT('%');
+ write[1] = _UT('0');
+ write[2] = _UT('D');
+ write[3] = _UT('%');
+ write[4] = _UT('0');
+ write[5] = _UT('A');
+ write += 6;
+ } else {
+ write[0] = _UT('%');
+ write[1] = _UT('0');
+ write[2] = _UT('D');
+ write += 3;
+ }
+ prevWasCr = URI_TRUE;
+ break;
+
+ default:
+ /* Percent encode */
+ {
+ const unsigned char code = (unsigned char)read[0];
+ write[0] = _UT('%');
+ write[1] = URI_FUNC(HexToLetter)(code >> 4);
+ write[2] = URI_FUNC(HexToLetter)(code & 0x0f);
+ write += 3;
+ }
+ prevWasCr = URI_FALSE;
+ break;
+ }
+
+ read++;
+ }
+}
+
+
+
+const URI_CHAR * URI_FUNC(UnescapeInPlace)(URI_CHAR * inout) {
+ return URI_FUNC(UnescapeInPlaceEx)(inout, URI_FALSE, URI_BR_DONT_TOUCH);
+}
+
+
+
+const URI_CHAR * URI_FUNC(UnescapeInPlaceEx)(URI_CHAR * inout,
+ UriBool plusToSpace, UriBreakConversion breakConversion) {
+ URI_CHAR * read = inout;
+ URI_CHAR * write = inout;
+ UriBool prevWasCr = URI_FALSE;
+
+ if (inout == NULL) {
+ return NULL;
+ }
+
+ for (;;) {
+ switch (read[0]) {
+ case _UT('\0'):
+ if (read > write) {
+ write[0] = _UT('\0');
+ }
+ return write;
+
+ case _UT('%'):
+ switch (read[1]) {
+ case _UT('0'):
+ case _UT('1'):
+ case _UT('2'):
+ case _UT('3'):
+ case _UT('4'):
+ case _UT('5'):
+ case _UT('6'):
+ case _UT('7'):
+ case _UT('8'):
+ case _UT('9'):
+ case _UT('a'):
+ case _UT('b'):
+ case _UT('c'):
+ case _UT('d'):
+ case _UT('e'):
+ case _UT('f'):
+ case _UT('A'):
+ case _UT('B'):
+ case _UT('C'):
+ case _UT('D'):
+ case _UT('E'):
+ case _UT('F'):
+ switch (read[2]) {
+ case _UT('0'):
+ case _UT('1'):
+ case _UT('2'):
+ case _UT('3'):
+ case _UT('4'):
+ case _UT('5'):
+ case _UT('6'):
+ case _UT('7'):
+ case _UT('8'):
+ case _UT('9'):
+ case _UT('a'):
+ case _UT('b'):
+ case _UT('c'):
+ case _UT('d'):
+ case _UT('e'):
+ case _UT('f'):
+ case _UT('A'):
+ case _UT('B'):
+ case _UT('C'):
+ case _UT('D'):
+ case _UT('E'):
+ case _UT('F'):
+ {
+ /* Percent group found */
+ const unsigned char left = URI_FUNC(HexdigToInt)(read[1]);
+ const unsigned char right = URI_FUNC(HexdigToInt)(read[2]);
+ const int code = 16 * left + right;
+ switch (code) {
+ case 10:
+ switch (breakConversion) {
+ case URI_BR_TO_LF:
+ if (!prevWasCr) {
+ write[0] = (URI_CHAR)10;
+ write++;
+ }
+ break;
+
+ case URI_BR_TO_CRLF:
+ if (!prevWasCr) {
+ write[0] = (URI_CHAR)13;
+ write[1] = (URI_CHAR)10;
+ write += 2;
+ }
+ break;
+
+ case URI_BR_TO_CR:
+ if (!prevWasCr) {
+ write[0] = (URI_CHAR)13;
+ write++;
+ }
+ break;
+
+ case URI_BR_DONT_TOUCH:
+ default:
+ write[0] = (URI_CHAR)10;
+ write++;
+
+ }
+ prevWasCr = URI_FALSE;
+ break;
+
+ case 13:
+ switch (breakConversion) {
+ case URI_BR_TO_LF:
+ write[0] = (URI_CHAR)10;
+ write++;
+ break;
+
+ case URI_BR_TO_CRLF:
+ write[0] = (URI_CHAR)13;
+ write[1] = (URI_CHAR)10;
+ write += 2;
+ break;
+
+ case URI_BR_TO_CR:
+ write[0] = (URI_CHAR)13;
+ write++;
+ break;
+
+ case URI_BR_DONT_TOUCH:
+ default:
+ write[0] = (URI_CHAR)13;
+ write++;
+
+ }
+ prevWasCr = URI_TRUE;
+ break;
+
+ default:
+ write[0] = (URI_CHAR)(code);
+ write++;
+
+ prevWasCr = URI_FALSE;
+
+ }
+ read += 3;
+ }
+ break;
+
+ default:
+ /* Copy two chars unmodified and */
+ /* look at this char again */
+ if (read > write) {
+ write[0] = read[0];
+ write[1] = read[1];
+ }
+ read += 2;
+ write += 2;
+
+ prevWasCr = URI_FALSE;
+ }
+ break;
+
+ default:
+ /* Copy one char unmodified and */
+ /* look at this char again */
+ if (read > write) {
+ write[0] = read[0];
+ }
+ read++;
+ write++;
+
+ prevWasCr = URI_FALSE;
+ }
+ break;
+
+ case _UT('+'):
+ if (plusToSpace) {
+ /* Convert '+' to ' ' */
+ write[0] = _UT(' ');
+ } else {
+ /* Copy one char unmodified */
+ if (read > write) {
+ write[0] = read[0];
+ }
+ }
+ read++;
+ write++;
+
+ prevWasCr = URI_FALSE;
+ break;
+
+ default:
+ /* Copy one char unmodified */
+ if (read > write) {
+ write[0] = read[0];
+ }
+ read++;
+ write++;
+
+ prevWasCr = URI_FALSE;
+ }
+ }
+}
+
+
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriFile.c b/src/arrow/cpp/src/arrow/vendored/uriparser/UriFile.c
new file mode 100644
index 000000000..abaa84f9a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriFile.c
@@ -0,0 +1,242 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/* What encodings are enabled? */
+#include "UriDefsConfig.h"
+#if (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* Include SELF twice */
+# ifdef URI_ENABLE_ANSI
+# define URI_PASS_ANSI 1
+# include "UriFile.c"
+# undef URI_PASS_ANSI
+# endif
+# ifdef URI_ENABLE_UNICODE
+# define URI_PASS_UNICODE 1
+# include "UriFile.c"
+# undef URI_PASS_UNICODE
+# endif
+#else
+# ifdef URI_PASS_ANSI
+# include "UriDefsAnsi.h"
+# else
+# include "UriDefsUnicode.h"
+# include <wchar.h>
+# endif
+
+
+
+#ifndef URI_DOXYGEN
+# include "Uri.h"
+#endif
+
+
+
+#include <stdlib.h> /* for size_t, avoiding stddef.h for older MSVCs */
+
+
+
+static URI_INLINE int URI_FUNC(FilenameToUriString)(const URI_CHAR * filename,
+ URI_CHAR * uriString, UriBool fromUnix) {
+ const URI_CHAR * input = filename;
+ const URI_CHAR * lastSep = input - 1;
+ UriBool firstSegment = URI_TRUE;
+ URI_CHAR * output = uriString;
+ UriBool absolute;
+ UriBool is_windows_network;
+
+ if ((filename == NULL) || (uriString == NULL)) {
+ return URI_ERROR_NULL;
+ }
+
+ is_windows_network = (filename[0] == _UT('\\')) && (filename[1] == _UT('\\'));
+ absolute = fromUnix
+ ? (filename[0] == _UT('/'))
+ : (((filename[0] != _UT('\0')) && (filename[1] == _UT(':')))
+ || is_windows_network);
+
+ if (absolute) {
+ const URI_CHAR * const prefix = fromUnix
+ ? _UT("file://")
+ : is_windows_network
+ ? _UT("file:")
+ : _UT("file:///");
+ const size_t prefixLen = URI_STRLEN(prefix);
+
+ /* Copy prefix */
+ memcpy(uriString, prefix, prefixLen * sizeof(URI_CHAR));
+ output += prefixLen;
+ }
+
+ /* Copy and escape on the fly */
+ for (;;) {
+ if ((input[0] == _UT('\0'))
+ || (fromUnix && input[0] == _UT('/'))
+ || (!fromUnix && input[0] == _UT('\\'))) {
+ /* Copy text after last separator */
+ if (lastSep + 1 < input) {
+ if (!fromUnix && absolute && (firstSegment == URI_TRUE)) {
+ /* Quick hack to not convert "C:" to "C%3A" */
+ const int charsToCopy = (int)(input - (lastSep + 1));
+ memcpy(output, lastSep + 1, charsToCopy * sizeof(URI_CHAR));
+ output += charsToCopy;
+ } else {
+ output = URI_FUNC(EscapeEx)(lastSep + 1, input, output,
+ URI_FALSE, URI_FALSE);
+ }
+ }
+ firstSegment = URI_FALSE;
+ }
+
+ if (input[0] == _UT('\0')) {
+ output[0] = _UT('\0');
+ break;
+ } else if (fromUnix && (input[0] == _UT('/'))) {
+ /* Copy separators unmodified */
+ output[0] = _UT('/');
+ output++;
+ lastSep = input;
+ } else if (!fromUnix && (input[0] == _UT('\\'))) {
+ /* Convert backslashes to forward slashes */
+ output[0] = _UT('/');
+ output++;
+ lastSep = input;
+ }
+ input++;
+ }
+
+ return URI_SUCCESS;
+}
+
+
+
+static URI_INLINE int URI_FUNC(UriStringToFilename)(const URI_CHAR * uriString,
+ URI_CHAR * filename, UriBool toUnix) {
+ if ((uriString == NULL) || (filename == NULL)) {
+ return URI_ERROR_NULL;
+ }
+
+ {
+ const UriBool file_unknown_slashes =
+ URI_STRNCMP(uriString, _UT("file:"), URI_STRLEN(_UT("file:"))) == 0;
+ const UriBool file_one_or_more_slashes = file_unknown_slashes
+ && (URI_STRNCMP(uriString, _UT("file:/"), URI_STRLEN(_UT("file:/"))) == 0);
+ const UriBool file_two_or_more_slashes = file_one_or_more_slashes
+ && (URI_STRNCMP(uriString, _UT("file://"), URI_STRLEN(_UT("file://"))) == 0);
+ const UriBool file_three_or_more_slashes = file_two_or_more_slashes
+ && (URI_STRNCMP(uriString, _UT("file:///"), URI_STRLEN(_UT("file:///"))) == 0);
+
+ const size_t charsToSkip = file_two_or_more_slashes
+ ? file_three_or_more_slashes
+ ? toUnix
+ /* file:///bin/bash */
+ ? URI_STRLEN(_UT("file://"))
+ /* file:///E:/Documents%20and%20Settings */
+ : URI_STRLEN(_UT("file:///"))
+ /* file://Server01/Letter.txt */
+ : URI_STRLEN(_UT("file://"))
+ : ((file_one_or_more_slashes && toUnix)
+ /* file:/bin/bash */
+ /* https://tools.ietf.org/html/rfc8089#appendix-B */
+ ? URI_STRLEN(_UT("file:"))
+ : ((! toUnix && file_unknown_slashes && ! file_one_or_more_slashes)
+ /* file:c:/path/to/file */
+ /* https://tools.ietf.org/html/rfc8089#appendix-E.2 */
+ ? URI_STRLEN(_UT("file:"))
+ : 0));
+ const size_t charsToCopy = URI_STRLEN(uriString + charsToSkip) + 1;
+
+ const UriBool is_windows_network_with_authority =
+ (toUnix == URI_FALSE)
+ && file_two_or_more_slashes
+ && ! file_three_or_more_slashes;
+
+ URI_CHAR * const unescape_target = is_windows_network_with_authority
+ ? (filename + 2)
+ : filename;
+
+ if (is_windows_network_with_authority) {
+ filename[0] = '\\';
+ filename[1] = '\\';
+ }
+
+ memcpy(unescape_target, uriString + charsToSkip, charsToCopy * sizeof(URI_CHAR));
+ URI_FUNC(UnescapeInPlaceEx)(filename, URI_FALSE, URI_BR_DONT_TOUCH);
+ }
+
+ /* Convert forward slashes to backslashes */
+ if (!toUnix) {
+ URI_CHAR * walker = filename;
+ while (walker[0] != _UT('\0')) {
+ if (walker[0] == _UT('/')) {
+ walker[0] = _UT('\\');
+ }
+ walker++;
+ }
+ }
+
+ return URI_SUCCESS;
+}
+
+
+
+int URI_FUNC(UnixFilenameToUriString)(const URI_CHAR * filename, URI_CHAR * uriString) {
+ return URI_FUNC(FilenameToUriString)(filename, uriString, URI_TRUE);
+}
+
+
+
+int URI_FUNC(WindowsFilenameToUriString)(const URI_CHAR * filename, URI_CHAR * uriString) {
+ return URI_FUNC(FilenameToUriString)(filename, uriString, URI_FALSE);
+}
+
+
+
+int URI_FUNC(UriStringToUnixFilename)(const URI_CHAR * uriString, URI_CHAR * filename) {
+ return URI_FUNC(UriStringToFilename)(uriString, filename, URI_TRUE);
+}
+
+
+
+int URI_FUNC(UriStringToWindowsFilename)(const URI_CHAR * uriString, URI_CHAR * filename) {
+ return URI_FUNC(UriStringToFilename)(uriString, filename, URI_FALSE);
+}
+
+
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4.c b/src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4.c
new file mode 100644
index 000000000..1d4ab39bc
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4.c
@@ -0,0 +1,329 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/**
+ * @file UriIp4.c
+ * Holds the IPv4 parser implementation.
+ * NOTE: This source file includes itself twice.
+ */
+
+/* What encodings are enabled? */
+#include "UriDefsConfig.h"
+#if (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* Include SELF twice */
+# ifdef URI_ENABLE_ANSI
+# define URI_PASS_ANSI 1
+# include "UriIp4.c"
+# undef URI_PASS_ANSI
+# endif
+# ifdef URI_ENABLE_UNICODE
+# define URI_PASS_UNICODE 1
+# include "UriIp4.c"
+# undef URI_PASS_UNICODE
+# endif
+#else
+# ifdef URI_PASS_ANSI
+# include "UriDefsAnsi.h"
+# else
+# include "UriDefsUnicode.h"
+# endif
+
+
+
+#ifndef URI_DOXYGEN
+# include "UriIp4.h"
+# include "UriIp4Base.h"
+# include "UriBase.h"
+#endif
+
+
+
+/* Prototypes */
+static const URI_CHAR * URI_FUNC(ParseDecOctet)(UriIp4Parser * parser,
+ const URI_CHAR * first, const URI_CHAR * afterLast);
+static const URI_CHAR * URI_FUNC(ParseDecOctetOne)(UriIp4Parser * parser,
+ const URI_CHAR * first, const URI_CHAR * afterLast);
+static const URI_CHAR * URI_FUNC(ParseDecOctetTwo)(UriIp4Parser * parser,
+ const URI_CHAR * first, const URI_CHAR * afterLast);
+static const URI_CHAR * URI_FUNC(ParseDecOctetThree)(UriIp4Parser * parser,
+ const URI_CHAR * first, const URI_CHAR * afterLast);
+static const URI_CHAR * URI_FUNC(ParseDecOctetFour)(UriIp4Parser * parser,
+ const URI_CHAR * first, const URI_CHAR * afterLast);
+
+
+
+/*
+ * [ipFourAddress]->[decOctet]<.>[decOctet]<.>[decOctet]<.>[decOctet]
+ */
+int URI_FUNC(ParseIpFourAddress)(unsigned char * octetOutput,
+ const URI_CHAR * first, const URI_CHAR * afterLast) {
+ const URI_CHAR * after;
+ UriIp4Parser parser;
+
+ /* Essential checks */
+ if ((octetOutput == NULL) || (first == NULL)
+ || (afterLast <= first)) {
+ return URI_ERROR_SYNTAX;
+ }
+
+ /* Reset parser */
+ parser.stackCount = 0;
+
+ /* Octet #1 */
+ after = URI_FUNC(ParseDecOctet)(&parser, first, afterLast);
+ if ((after == NULL) || (after >= afterLast) || (*after != _UT('.'))) {
+ return URI_ERROR_SYNTAX;
+ }
+ uriStackToOctet(&parser, octetOutput);
+
+ /* Octet #2 */
+ after = URI_FUNC(ParseDecOctet)(&parser, after + 1, afterLast);
+ if ((after == NULL) || (after >= afterLast) || (*after != _UT('.'))) {
+ return URI_ERROR_SYNTAX;
+ }
+ uriStackToOctet(&parser, octetOutput + 1);
+
+ /* Octet #3 */
+ after = URI_FUNC(ParseDecOctet)(&parser, after + 1, afterLast);
+ if ((after == NULL) || (after >= afterLast) || (*after != _UT('.'))) {
+ return URI_ERROR_SYNTAX;
+ }
+ uriStackToOctet(&parser, octetOutput + 2);
+
+ /* Octet #4 */
+ after = URI_FUNC(ParseDecOctet)(&parser, after + 1, afterLast);
+ if (after != afterLast) {
+ return URI_ERROR_SYNTAX;
+ }
+ uriStackToOctet(&parser, octetOutput + 3);
+
+ return URI_SUCCESS;
+}
+
+
+
+/*
+ * [decOctet]-><0>
+ * [decOctet]-><1>[decOctetOne]
+ * [decOctet]-><2>[decOctetTwo]
+ * [decOctet]-><3>[decOctetThree]
+ * [decOctet]-><4>[decOctetThree]
+ * [decOctet]-><5>[decOctetThree]
+ * [decOctet]-><6>[decOctetThree]
+ * [decOctet]-><7>[decOctetThree]
+ * [decOctet]-><8>[decOctetThree]
+ * [decOctet]-><9>[decOctetThree]
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParseDecOctet)(UriIp4Parser * parser,
+ const URI_CHAR * first, const URI_CHAR * afterLast) {
+ if (first >= afterLast) {
+ return NULL;
+ }
+
+ switch (*first) {
+ case _UT('0'):
+ uriPushToStack(parser, 0);
+ return first + 1;
+
+ case _UT('1'):
+ uriPushToStack(parser, 1);
+ return (const URI_CHAR *)URI_FUNC(ParseDecOctetOne)(parser, first + 1, afterLast);
+
+ case _UT('2'):
+ uriPushToStack(parser, 2);
+ return (const URI_CHAR *)URI_FUNC(ParseDecOctetTwo)(parser, first + 1, afterLast);
+
+ case _UT('3'):
+ case _UT('4'):
+ case _UT('5'):
+ case _UT('6'):
+ case _UT('7'):
+ case _UT('8'):
+ case _UT('9'):
+ uriPushToStack(parser, (unsigned char)(9 + *first - _UT('9')));
+ return (const URI_CHAR *)URI_FUNC(ParseDecOctetThree)(parser, first + 1, afterLast);
+
+ default:
+ return NULL;
+ }
+}
+
+
+
+/*
+ * [decOctetOne]-><NULL>
+ * [decOctetOne]->[DIGIT][decOctetThree]
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParseDecOctetOne)(UriIp4Parser * parser,
+ const URI_CHAR * first, const URI_CHAR * afterLast) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('0'):
+ case _UT('1'):
+ case _UT('2'):
+ case _UT('3'):
+ case _UT('4'):
+ case _UT('5'):
+ case _UT('6'):
+ case _UT('7'):
+ case _UT('8'):
+ case _UT('9'):
+ uriPushToStack(parser, (unsigned char)(9 + *first - _UT('9')));
+ return (const URI_CHAR *)URI_FUNC(ParseDecOctetThree)(parser, first + 1, afterLast);
+
+ default:
+ return first;
+ }
+}
+
+
+
+/*
+ * [decOctetTwo]-><NULL>
+ * [decOctetTwo]-><0>[decOctetThree]
+ * [decOctetTwo]-><1>[decOctetThree]
+ * [decOctetTwo]-><2>[decOctetThree]
+ * [decOctetTwo]-><3>[decOctetThree]
+ * [decOctetTwo]-><4>[decOctetThree]
+ * [decOctetTwo]-><5>[decOctetFour]
+ * [decOctetTwo]-><6>
+ * [decOctetTwo]-><7>
+ * [decOctetTwo]-><8>
+ * [decOctetTwo]-><9>
+*/
+static URI_INLINE const URI_CHAR * URI_FUNC(ParseDecOctetTwo)(UriIp4Parser * parser,
+ const URI_CHAR * first, const URI_CHAR * afterLast) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('0'):
+ case _UT('1'):
+ case _UT('2'):
+ case _UT('3'):
+ case _UT('4'):
+ uriPushToStack(parser, (unsigned char)(9 + *first - _UT('9')));
+ return (const URI_CHAR *)URI_FUNC(ParseDecOctetThree)(parser, first + 1, afterLast);
+
+ case _UT('5'):
+ uriPushToStack(parser, 5);
+ return (const URI_CHAR *)URI_FUNC(ParseDecOctetFour)(parser, first + 1, afterLast);
+
+ case _UT('6'):
+ case _UT('7'):
+ case _UT('8'):
+ case _UT('9'):
+ uriPushToStack(parser, (unsigned char)(9 + *first - _UT('9')));
+ return first + 1;
+
+ default:
+ return first;
+ }
+}
+
+
+
+/*
+ * [decOctetThree]-><NULL>
+ * [decOctetThree]->[DIGIT]
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParseDecOctetThree)(UriIp4Parser * parser,
+ const URI_CHAR * first, const URI_CHAR * afterLast) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('0'):
+ case _UT('1'):
+ case _UT('2'):
+ case _UT('3'):
+ case _UT('4'):
+ case _UT('5'):
+ case _UT('6'):
+ case _UT('7'):
+ case _UT('8'):
+ case _UT('9'):
+ uriPushToStack(parser, (unsigned char)(9 + *first - _UT('9')));
+ return first + 1;
+
+ default:
+ return first;
+ }
+}
+
+
+
+/*
+ * [decOctetFour]-><NULL>
+ * [decOctetFour]-><0>
+ * [decOctetFour]-><1>
+ * [decOctetFour]-><2>
+ * [decOctetFour]-><3>
+ * [decOctetFour]-><4>
+ * [decOctetFour]-><5>
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParseDecOctetFour)(UriIp4Parser * parser,
+ const URI_CHAR * first, const URI_CHAR * afterLast) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('0'):
+ case _UT('1'):
+ case _UT('2'):
+ case _UT('3'):
+ case _UT('4'):
+ case _UT('5'):
+ uriPushToStack(parser, (unsigned char)(9 + *first - _UT('9')));
+ return first + 1;
+
+ default:
+ return first;
+ }
+}
+
+
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4.h b/src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4.h
new file mode 100644
index 000000000..e1d7f1e4a
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4.h
@@ -0,0 +1,110 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/**
+ * @file UriIp4.h
+ * Holds the IPv4 parser interface.
+ * NOTE: This header includes itself twice.
+ */
+
+#if (defined(URI_PASS_ANSI) && !defined(URI_IP4_TWICE_H_ANSI)) \
+ || (defined(URI_PASS_UNICODE) && !defined(URI_IP4_TWICE_H_UNICODE)) \
+ || (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* What encodings are enabled? */
+#include "UriDefsConfig.h"
+#if (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* Include SELF twice */
+# ifdef URI_ENABLE_ANSI
+# define URI_PASS_ANSI 1
+# include "UriIp4.h"
+# undef URI_PASS_ANSI
+# endif
+# ifdef URI_ENABLE_UNICODE
+# define URI_PASS_UNICODE 1
+# include "UriIp4.h"
+# undef URI_PASS_UNICODE
+# endif
+/* Only one pass for each encoding */
+#elif (defined(URI_PASS_ANSI) && !defined(URI_IP4_TWICE_H_ANSI) \
+ && defined(URI_ENABLE_ANSI)) || (defined(URI_PASS_UNICODE) \
+ && !defined(URI_IP4_TWICE_H_UNICODE) && defined(URI_ENABLE_UNICODE))
+# ifdef URI_PASS_ANSI
+# define URI_IP4_TWICE_H_ANSI 1
+# include "UriDefsAnsi.h"
+# else
+# define URI_IP4_TWICE_H_UNICODE 1
+# include "UriDefsUnicode.h"
+# include <wchar.h>
+# endif
+
+
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+
+#ifndef URI_DOXYGEN
+# include "UriBase.h"
+#endif
+
+
+
+/**
+ * Converts a IPv4 text representation into four bytes.
+ *
+ * @param octetOutput Output destination
+ * @param first First character of IPv4 text to parse
+ * @param afterLast Position to stop parsing at
+ * @return Error code or 0 on success
+ */
+URI_PUBLIC int URI_FUNC(ParseIpFourAddress)(unsigned char * octetOutput,
+ const URI_CHAR * first, const URI_CHAR * afterLast);
+
+
+
+#ifdef __cplusplus
+}
+#endif
+
+
+
+#endif
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4Base.c b/src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4Base.c
new file mode 100644
index 000000000..f0662465d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4Base.c
@@ -0,0 +1,96 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/**
+ * @file UriIp4Base.c
+ * Holds code independent of the encoding pass.
+ */
+
+#ifndef URI_DOXYGEN
+# include "UriIp4Base.h"
+#endif
+
+
+
+void uriStackToOctet(UriIp4Parser * parser, unsigned char * octet) {
+ switch (parser->stackCount) {
+ case 1:
+ *octet = parser->stackOne;
+ break;
+
+ case 2:
+ *octet = parser->stackOne * 10
+ + parser->stackTwo;
+ break;
+
+ case 3:
+ *octet = parser->stackOne * 100
+ + parser->stackTwo * 10
+ + parser->stackThree;
+ break;
+
+ default:
+ ;
+ }
+ parser->stackCount = 0;
+}
+
+
+
+void uriPushToStack(UriIp4Parser * parser, unsigned char digit) {
+ switch (parser->stackCount) {
+ case 0:
+ parser->stackOne = digit;
+ parser->stackCount = 1;
+ break;
+
+ case 1:
+ parser->stackTwo = digit;
+ parser->stackCount = 2;
+ break;
+
+ case 2:
+ parser->stackThree = digit;
+ parser->stackCount = 3;
+ break;
+
+ default:
+ ;
+ }
+}
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4Base.h b/src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4Base.h
new file mode 100644
index 000000000..bef028f7c
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriIp4Base.h
@@ -0,0 +1,59 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#ifndef URI_IP4_BASE_H
+#define URI_IP4_BASE_H 1
+
+
+
+typedef struct UriIp4ParserStruct {
+ unsigned char stackCount;
+ unsigned char stackOne;
+ unsigned char stackTwo;
+ unsigned char stackThree;
+} UriIp4Parser;
+
+
+
+void uriPushToStack(UriIp4Parser * parser, unsigned char digit);
+void uriStackToOctet(UriIp4Parser * parser, unsigned char * octet);
+
+
+
+#endif /* URI_IP4_BASE_H */
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriMemory.c b/src/arrow/cpp/src/arrow/vendored/uriparser/UriMemory.c
new file mode 100644
index 000000000..2d23f1b6f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriMemory.c
@@ -0,0 +1,468 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2018, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2018, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/**
+ * @file UriMemory.c
+ * Holds memory manager implementation.
+ */
+
+#include "config.h"
+
+#ifdef HAVE_REALLOCARRAY
+# ifndef _GNU_SOURCE
+# define _GNU_SOURCE 1
+# endif
+#endif
+
+#include <errno.h>
+#include <stdlib.h>
+
+
+
+#ifndef URI_DOXYGEN
+# include "UriMemory.h"
+#endif
+
+
+
+#define URI_CHECK_ALLOC_OVERFLOW(total_size, nmemb, size) \
+ do { \
+ /* check for unsigned overflow */ \
+ if ((nmemb != 0) && (total_size / nmemb != size)) { \
+ errno = ENOMEM; \
+ return NULL; \
+ } \
+ } while (0)
+
+
+
+static void * uriDefaultMalloc(UriMemoryManager * URI_UNUSED(memory),
+ size_t size) {
+ return malloc(size);
+}
+
+
+
+static void * uriDefaultCalloc(UriMemoryManager * URI_UNUSED(memory),
+ size_t nmemb, size_t size) {
+ return calloc(nmemb, size);
+}
+
+
+
+static void * uriDefaultRealloc(UriMemoryManager * URI_UNUSED(memory),
+ void * ptr, size_t size) {
+ return realloc(ptr, size);
+}
+
+
+
+static void * uriDefaultReallocarray(UriMemoryManager * URI_UNUSED(memory),
+ void * ptr, size_t nmemb, size_t size) {
+#ifdef HAVE_REALLOCARRAY
+ return reallocarray(ptr, nmemb, size);
+#else
+ const size_t total_size = nmemb * size;
+
+ URI_CHECK_ALLOC_OVERFLOW(total_size, nmemb, size); /* may return */
+
+ return realloc(ptr, total_size);
+#endif
+}
+
+
+
+static void uriDefaultFree(UriMemoryManager * URI_UNUSED(memory),
+ void * ptr) {
+ free(ptr);
+}
+
+
+
+UriBool uriMemoryManagerIsComplete(const UriMemoryManager * memory) {
+ return (memory
+ && memory->malloc
+ && memory->calloc
+ && memory->realloc
+ && memory->reallocarray
+ && memory->free) ? URI_TRUE : URI_FALSE;
+}
+
+
+
+void * uriEmulateCalloc(UriMemoryManager * memory, size_t nmemb, size_t size) {
+ void * buffer;
+ const size_t total_size = nmemb * size;
+
+ if (memory == NULL) {
+ errno = EINVAL;
+ return NULL;
+ }
+
+ URI_CHECK_ALLOC_OVERFLOW(total_size, nmemb, size); /* may return */
+
+ buffer = memory->malloc(memory, total_size);
+ if (buffer == NULL) {
+ /* errno set by malloc */
+ return NULL;
+ }
+ memset(buffer, 0, total_size);
+ return buffer;
+}
+
+
+
+void * uriEmulateReallocarray(UriMemoryManager * memory,
+ void * ptr, size_t nmemb, size_t size) {
+ const size_t total_size = nmemb * size;
+
+ if (memory == NULL) {
+ errno = EINVAL;
+ return NULL;
+ }
+
+ URI_CHECK_ALLOC_OVERFLOW(total_size, nmemb, size); /* may return */
+
+ return memory->realloc(memory, ptr, total_size);
+}
+
+
+
+static void * uriDecorateMalloc(UriMemoryManager * memory,
+ size_t size) {
+ UriMemoryManager * backend;
+ const size_t extraBytes = sizeof(size_t);
+ void * buffer;
+
+ if (memory == NULL) {
+ errno = EINVAL;
+ return NULL;
+ }
+
+ /* check for unsigned overflow */
+ if (size > ((size_t)-1) - extraBytes) {
+ errno = ENOMEM;
+ return NULL;
+ }
+
+ backend = (UriMemoryManager *)memory->userData;
+ if (backend == NULL) {
+ errno = EINVAL;
+ return NULL;
+ }
+
+ buffer = backend->malloc(backend, extraBytes + size);
+ if (buffer == NULL) {
+ return NULL;
+ }
+
+ *(size_t *)buffer = size;
+
+ return (char *)buffer + extraBytes;
+}
+
+
+
+static void * uriDecorateRealloc(UriMemoryManager * memory,
+ void * ptr, size_t size) {
+ void * newBuffer;
+ size_t prevSize;
+
+ if (memory == NULL) {
+ errno = EINVAL;
+ return NULL;
+ }
+
+ /* man realloc: "If ptr is NULL, then the call is equivalent to
+ * malloc(size), for *all* values of size" */
+ if (ptr == NULL) {
+ return memory->malloc(memory, size);
+ }
+
+ /* man realloc: "If size is equal to zero, and ptr is *not* NULL,
+ * then the call is equivalent to free(ptr)." */
+ if (size == 0) {
+ memory->free(memory, ptr);
+ return NULL;
+ }
+
+ prevSize = *((size_t *)((char *)ptr - sizeof(size_t)));
+
+ /* Anything to do? */
+ if (size <= prevSize) {
+ return ptr;
+ }
+
+ newBuffer = memory->malloc(memory, size);
+ if (newBuffer == NULL) {
+ /* errno set by malloc */
+ return NULL;
+ }
+
+ memcpy(newBuffer, ptr, prevSize);
+
+ memory->free(memory, ptr);
+
+ return newBuffer;
+}
+
+
+
+static void uriDecorateFree(UriMemoryManager * memory, void * ptr) {
+ UriMemoryManager * backend;
+
+ if ((ptr == NULL) || (memory == NULL)) {
+ return;
+ }
+
+ backend = (UriMemoryManager *)memory->userData;
+ if (backend == NULL) {
+ return;
+ }
+
+ backend->free(backend, (char *)ptr - sizeof(size_t));
+}
+
+
+
+int uriCompleteMemoryManager(UriMemoryManager * memory,
+ UriMemoryManager * backend) {
+ if ((memory == NULL) || (backend == NULL)) {
+ return URI_ERROR_NULL;
+ }
+
+ if ((backend->malloc == NULL) || (backend->free == NULL)) {
+ return URI_ERROR_MEMORY_MANAGER_INCOMPLETE;
+ }
+
+ memory->calloc = uriEmulateCalloc;
+ memory->reallocarray = uriEmulateReallocarray;
+
+ memory->malloc = uriDecorateMalloc;
+ memory->realloc = uriDecorateRealloc;
+ memory->free = uriDecorateFree;
+
+ memory->userData = backend;
+
+ return URI_SUCCESS;
+}
+
+
+
+int uriTestMemoryManager(UriMemoryManager * memory) {
+ const size_t mallocSize = 7;
+ const size_t callocNmemb = 3;
+ const size_t callocSize = 5;
+ const size_t callocTotalSize = callocNmemb * callocSize;
+ const size_t reallocSize = 11;
+ const size_t reallocarrayNmemb = 5;
+ const size_t reallocarraySize = 7;
+ const size_t reallocarrayTotal = reallocarrayNmemb * reallocarraySize;
+ size_t index;
+ char * buffer;
+
+ if (memory == NULL) {
+ return URI_ERROR_NULL;
+ }
+
+ if (uriMemoryManagerIsComplete(memory) != URI_TRUE) {
+ return URI_ERROR_MEMORY_MANAGER_INCOMPLETE;
+ }
+
+ /* malloc + free*/
+ buffer = memory->malloc(memory, mallocSize);
+ if (buffer == NULL) {
+ return URI_ERROR_MEMORY_MANAGER_FAULTY;
+ }
+ buffer[mallocSize - 1] = '\xF1';
+ memory->free(memory, buffer);
+ buffer = NULL;
+
+ /* calloc + free */
+ buffer = memory->calloc(memory, callocNmemb, callocSize);
+ if (buffer == NULL) {
+ return URI_ERROR_MEMORY_MANAGER_FAULTY;
+ }
+ for (index = 0; index < callocTotalSize; index++) { /* all zeros? */
+ if (buffer[index] != '\0') {
+ return URI_ERROR_MEMORY_MANAGER_FAULTY;
+ }
+ }
+ buffer[callocTotalSize - 1] = '\xF2';
+ memory->free(memory, buffer);
+ buffer = NULL;
+
+ /* malloc + realloc + free */
+ buffer = memory->malloc(memory, mallocSize);
+ if (buffer == NULL) {
+ return URI_ERROR_MEMORY_MANAGER_FAULTY;
+ }
+ for (index = 0; index < mallocSize; index++) {
+ buffer[index] = '\xF3';
+ }
+ buffer = memory->realloc(memory, buffer, reallocSize);
+ if (buffer == NULL) {
+ return URI_ERROR_MEMORY_MANAGER_FAULTY;
+ }
+ for (index = 0; index < mallocSize; index++) { /* previous content? */
+ if (buffer[index] != '\xF3') {
+ return URI_ERROR_MEMORY_MANAGER_FAULTY;
+ }
+ }
+ buffer[reallocSize - 1] = '\xF4';
+ memory->free(memory, buffer);
+ buffer = NULL;
+
+ /* malloc + realloc ptr!=NULL size==0 (equals free) */
+ buffer = memory->malloc(memory, mallocSize);
+ if (buffer == NULL) {
+ return URI_ERROR_MEMORY_MANAGER_FAULTY;
+ }
+ buffer[mallocSize - 1] = '\xF5';
+ memory->realloc(memory, buffer, 0);
+ buffer = NULL;
+
+ /* realloc ptr==NULL size!=0 (equals malloc) + free */
+ buffer = memory->realloc(memory, NULL, mallocSize);
+ if (buffer == NULL) {
+ return URI_ERROR_MEMORY_MANAGER_FAULTY;
+ }
+ buffer[mallocSize - 1] = '\xF6';
+ memory->free(memory, buffer);
+ buffer = NULL;
+
+ /* realloc ptr==NULL size==0 (equals malloc) + free */
+ buffer = memory->realloc(memory, NULL, 0);
+ if (buffer != NULL) {
+ memory->free(memory, buffer);
+ buffer = NULL;
+ }
+
+ /* malloc + reallocarray + free */
+ buffer = memory->malloc(memory, mallocSize);
+ if (buffer == NULL) {
+ return URI_ERROR_MEMORY_MANAGER_FAULTY;
+ }
+ for (index = 0; index < mallocSize; index++) {
+ buffer[index] = '\xF7';
+ }
+ buffer = memory->reallocarray(memory, buffer, reallocarrayNmemb,
+ reallocarraySize);
+ if (buffer == NULL) {
+ return URI_ERROR_MEMORY_MANAGER_FAULTY;
+ }
+ for (index = 0; index < mallocSize; index++) { /* previous content? */
+ if (buffer[index] != '\xF7') {
+ return URI_ERROR_MEMORY_MANAGER_FAULTY;
+ }
+ }
+ buffer[reallocarrayTotal - 1] = '\xF8';
+ memory->free(memory, buffer);
+ buffer = NULL;
+
+ /* malloc + reallocarray ptr!=NULL nmemb==0 size!=0 (equals free) */
+ buffer = memory->malloc(memory, mallocSize);
+ if (buffer == NULL) {
+ return URI_ERROR_MEMORY_MANAGER_FAULTY;
+ }
+ buffer[mallocSize - 1] = '\xF9';
+ memory->reallocarray(memory, buffer, 0, reallocarraySize);
+ buffer = NULL;
+
+ /* malloc + reallocarray ptr!=NULL nmemb!=0 size==0 (equals free) */
+ buffer = memory->malloc(memory, mallocSize);
+ if (buffer == NULL) {
+ return URI_ERROR_MEMORY_MANAGER_FAULTY;
+ }
+ buffer[mallocSize - 1] = '\xFA';
+ memory->reallocarray(memory, buffer, reallocarrayNmemb, 0);
+ buffer = NULL;
+
+ /* malloc + reallocarray ptr!=NULL nmemb==0 size==0 (equals free) */
+ buffer = memory->malloc(memory, mallocSize);
+ if (buffer == NULL) {
+ return URI_ERROR_MEMORY_MANAGER_FAULTY;
+ }
+ buffer[mallocSize - 1] = '\xFB';
+ memory->reallocarray(memory, buffer, 0, 0);
+ buffer = NULL;
+
+ /* reallocarray ptr==NULL nmemb!=0 size!=0 (equals malloc) + free */
+ buffer = memory->reallocarray(memory, NULL, callocNmemb, callocSize);
+ if (buffer == NULL) {
+ return URI_ERROR_MEMORY_MANAGER_FAULTY;
+ }
+ buffer[callocTotalSize - 1] = '\xFC';
+ memory->free(memory, buffer);
+ buffer = NULL;
+
+ /* reallocarray ptr==NULL nmemb==0 size!=0 (equals malloc) + free */
+ buffer = memory->reallocarray(memory, NULL, 0, callocSize);
+ if (buffer != NULL) {
+ memory->free(memory, buffer);
+ buffer = NULL;
+ }
+
+ /* reallocarray ptr==NULL nmemb!=0 size==0 (equals malloc) + free */
+ buffer = memory->reallocarray(memory, NULL, callocNmemb, 0);
+ if (buffer != NULL) {
+ memory->free(memory, buffer);
+ buffer = NULL;
+ }
+
+ /* reallocarray ptr==NULL nmemb==0 size==0 (equals malloc) + free */
+ buffer = memory->reallocarray(memory, NULL, 0, 0);
+ if (buffer != NULL) {
+ memory->free(memory, buffer);
+ buffer = NULL;
+ }
+
+ return URI_SUCCESS;
+}
+
+
+
+/*extern*/ UriMemoryManager defaultMemoryManager = {
+ uriDefaultMalloc,
+ uriDefaultCalloc,
+ uriDefaultRealloc,
+ uriDefaultReallocarray,
+ uriDefaultFree,
+ NULL /* userData */
+};
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriMemory.h b/src/arrow/cpp/src/arrow/vendored/uriparser/UriMemory.h
new file mode 100644
index 000000000..5d6bf6784
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriMemory.h
@@ -0,0 +1,78 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2018, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2018, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#ifndef URI_MEMORY_H
+#define URI_MEMORY_H 1
+
+
+
+#ifndef URI_DOXYGEN
+# include "Uri.h"
+#endif
+
+
+
+#define URI_CHECK_MEMORY_MANAGER(memory) \
+ do { \
+ if (memory == NULL) { \
+ memory = &defaultMemoryManager; \
+ } else if (uriMemoryManagerIsComplete(memory) != URI_TRUE) { \
+ return URI_ERROR_MEMORY_MANAGER_INCOMPLETE; \
+ } \
+ } while (0)
+
+
+
+#ifdef __cplusplus
+# define URIPARSER_EXTERN extern "C"
+#else
+# define URIPARSER_EXTERN extern
+#endif
+
+URIPARSER_EXTERN UriMemoryManager defaultMemoryManager;
+
+#undef URIPARSER_EXTERN
+
+
+
+UriBool uriMemoryManagerIsComplete(const UriMemoryManager * memory);
+
+
+
+#endif /* URI_MEMORY_H */
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriNormalize.c b/src/arrow/cpp/src/arrow/vendored/uriparser/UriNormalize.c
new file mode 100644
index 000000000..3fc749203
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriNormalize.c
@@ -0,0 +1,771 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/**
+ * @file UriNormalize.c
+ * Holds the RFC 3986 %URI normalization implementation.
+ * NOTE: This source file includes itself twice.
+ */
+
+/* What encodings are enabled? */
+#include "UriDefsConfig.h"
+#if (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* Include SELF twice */
+# ifdef URI_ENABLE_ANSI
+# define URI_PASS_ANSI 1
+# include "UriNormalize.c"
+# undef URI_PASS_ANSI
+# endif
+# ifdef URI_ENABLE_UNICODE
+# define URI_PASS_UNICODE 1
+# include "UriNormalize.c"
+# undef URI_PASS_UNICODE
+# endif
+#else
+# ifdef URI_PASS_ANSI
+# include "UriDefsAnsi.h"
+# else
+# include "UriDefsUnicode.h"
+# include <wchar.h>
+# endif
+
+
+
+#ifndef URI_DOXYGEN
+# include "Uri.h"
+# include "UriNormalizeBase.h"
+# include "UriCommon.h"
+# include "UriMemory.h"
+#endif
+
+
+
+#include <assert.h>
+
+
+
+static int URI_FUNC(NormalizeSyntaxEngine)(URI_TYPE(Uri) * uri, unsigned int inMask,
+ unsigned int * outMask, UriMemoryManager * memory);
+
+static UriBool URI_FUNC(MakeRangeOwner)(unsigned int * doneMask,
+ unsigned int maskTest, URI_TYPE(TextRange) * range,
+ UriMemoryManager * memory);
+static UriBool URI_FUNC(MakeOwner)(URI_TYPE(Uri) * uri,
+ unsigned int * doneMask, UriMemoryManager * memory);
+
+static void URI_FUNC(FixPercentEncodingInplace)(const URI_CHAR * first,
+ const URI_CHAR ** afterLast);
+static UriBool URI_FUNC(FixPercentEncodingMalloc)(const URI_CHAR ** first,
+ const URI_CHAR ** afterLast, UriMemoryManager * memory);
+static void URI_FUNC(FixPercentEncodingEngine)(
+ const URI_CHAR * inFirst, const URI_CHAR * inAfterLast,
+ const URI_CHAR * outFirst, const URI_CHAR ** outAfterLast);
+
+static UriBool URI_FUNC(ContainsUppercaseLetters)(const URI_CHAR * first,
+ const URI_CHAR * afterLast);
+static UriBool URI_FUNC(ContainsUglyPercentEncoding)(const URI_CHAR * first,
+ const URI_CHAR * afterLast);
+
+static void URI_FUNC(LowercaseInplace)(const URI_CHAR * first,
+ const URI_CHAR * afterLast);
+static UriBool URI_FUNC(LowercaseMalloc)(const URI_CHAR ** first,
+ const URI_CHAR ** afterLast, UriMemoryManager * memory);
+
+static void URI_FUNC(PreventLeakage)(URI_TYPE(Uri) * uri,
+ unsigned int revertMask, UriMemoryManager * memory);
+
+
+
+static URI_INLINE void URI_FUNC(PreventLeakage)(URI_TYPE(Uri) * uri,
+ unsigned int revertMask, UriMemoryManager * memory) {
+ if (revertMask & URI_NORMALIZE_SCHEME) {
+ memory->free(memory, (URI_CHAR *)uri->scheme.first);
+ uri->scheme.first = NULL;
+ uri->scheme.afterLast = NULL;
+ }
+
+ if (revertMask & URI_NORMALIZE_USER_INFO) {
+ memory->free(memory, (URI_CHAR *)uri->userInfo.first);
+ uri->userInfo.first = NULL;
+ uri->userInfo.afterLast = NULL;
+ }
+
+ if (revertMask & URI_NORMALIZE_HOST) {
+ if (uri->hostData.ipFuture.first != NULL) {
+ /* IPvFuture */
+ memory->free(memory, (URI_CHAR *)uri->hostData.ipFuture.first);
+ uri->hostData.ipFuture.first = NULL;
+ uri->hostData.ipFuture.afterLast = NULL;
+ uri->hostText.first = NULL;
+ uri->hostText.afterLast = NULL;
+ } else if ((uri->hostText.first != NULL)
+ && (uri->hostData.ip4 == NULL)
+ && (uri->hostData.ip6 == NULL)) {
+ /* Regname */
+ memory->free(memory, (URI_CHAR *)uri->hostText.first);
+ uri->hostText.first = NULL;
+ uri->hostText.afterLast = NULL;
+ }
+ }
+
+ /* NOTE: Port cannot happen! */
+
+ if (revertMask & URI_NORMALIZE_PATH) {
+ URI_TYPE(PathSegment) * walker = uri->pathHead;
+ while (walker != NULL) {
+ URI_TYPE(PathSegment) * const next = walker->next;
+ if (walker->text.afterLast > walker->text.first) {
+ memory->free(memory, (URI_CHAR *)walker->text.first);
+ }
+ memory->free(memory, walker);
+ walker = next;
+ }
+ uri->pathHead = NULL;
+ uri->pathTail = NULL;
+ }
+
+ if (revertMask & URI_NORMALIZE_QUERY) {
+ memory->free(memory, (URI_CHAR *)uri->query.first);
+ uri->query.first = NULL;
+ uri->query.afterLast = NULL;
+ }
+
+ if (revertMask & URI_NORMALIZE_FRAGMENT) {
+ memory->free(memory, (URI_CHAR *)uri->fragment.first);
+ uri->fragment.first = NULL;
+ uri->fragment.afterLast = NULL;
+ }
+}
+
+
+
+static URI_INLINE UriBool URI_FUNC(ContainsUppercaseLetters)(const URI_CHAR * first,
+ const URI_CHAR * afterLast) {
+ if ((first != NULL) && (afterLast != NULL) && (afterLast > first)) {
+ const URI_CHAR * i = first;
+ for (; i < afterLast; i++) {
+ /* 6.2.2.1 Case Normalization: uppercase letters in scheme or host */
+ if ((*i >= _UT('A')) && (*i <= _UT('Z'))) {
+ return URI_TRUE;
+ }
+ }
+ }
+ return URI_FALSE;
+}
+
+
+
+static URI_INLINE UriBool URI_FUNC(ContainsUglyPercentEncoding)(const URI_CHAR * first,
+ const URI_CHAR * afterLast) {
+ if ((first != NULL) && (afterLast != NULL) && (afterLast > first)) {
+ const URI_CHAR * i = first;
+ for (; i + 2 < afterLast; i++) {
+ if (i[0] == _UT('%')) {
+ /* 6.2.2.1 Case Normalization: *
+ * lowercase percent-encodings */
+ if (((i[1] >= _UT('a')) && (i[1] <= _UT('f')))
+ || ((i[2] >= _UT('a')) && (i[2] <= _UT('f')))) {
+ return URI_TRUE;
+ } else {
+ /* 6.2.2.2 Percent-Encoding Normalization: *
+ * percent-encoded unreserved characters */
+ const unsigned char left = URI_FUNC(HexdigToInt)(i[1]);
+ const unsigned char right = URI_FUNC(HexdigToInt)(i[2]);
+ const int code = 16 * left + right;
+ if (uriIsUnreserved(code)) {
+ return URI_TRUE;
+ }
+ }
+ }
+ }
+ }
+ return URI_FALSE;
+}
+
+
+
+static URI_INLINE void URI_FUNC(LowercaseInplace)(const URI_CHAR * first,
+ const URI_CHAR * afterLast) {
+ if ((first != NULL) && (afterLast != NULL) && (afterLast > first)) {
+ URI_CHAR * i = (URI_CHAR *)first;
+ const int lowerUpperDiff = (_UT('a') - _UT('A'));
+ for (; i < afterLast; i++) {
+ if ((*i >= _UT('A')) && (*i <=_UT('Z'))) {
+ *i = (URI_CHAR)(*i + lowerUpperDiff);
+ }
+ }
+ }
+}
+
+
+
+static URI_INLINE UriBool URI_FUNC(LowercaseMalloc)(const URI_CHAR ** first,
+ const URI_CHAR ** afterLast, UriMemoryManager * memory) {
+ int lenInChars;
+ const int lowerUpperDiff = (_UT('a') - _UT('A'));
+ URI_CHAR * buffer;
+ int i = 0;
+
+ if ((first == NULL) || (afterLast == NULL) || (*first == NULL)
+ || (*afterLast == NULL)) {
+ return URI_FALSE;
+ }
+
+ lenInChars = (int)(*afterLast - *first);
+ if (lenInChars == 0) {
+ return URI_TRUE;
+ } else if (lenInChars < 0) {
+ return URI_FALSE;
+ }
+
+ buffer = memory->malloc(memory, lenInChars * sizeof(URI_CHAR));
+ if (buffer == NULL) {
+ return URI_FALSE;
+ }
+
+ for (; i < lenInChars; i++) {
+ if (((*first)[i] >= _UT('A')) && ((*first)[i] <=_UT('Z'))) {
+ buffer[i] = (URI_CHAR)((*first)[i] + lowerUpperDiff);
+ } else {
+ buffer[i] = (*first)[i];
+ }
+ }
+
+ *first = buffer;
+ *afterLast = buffer + lenInChars;
+ return URI_TRUE;
+}
+
+
+
+/* NOTE: Implementation must stay inplace-compatible */
+static URI_INLINE void URI_FUNC(FixPercentEncodingEngine)(
+ const URI_CHAR * inFirst, const URI_CHAR * inAfterLast,
+ const URI_CHAR * outFirst, const URI_CHAR ** outAfterLast) {
+ URI_CHAR * write = (URI_CHAR *)outFirst;
+ const int lenInChars = (int)(inAfterLast - inFirst);
+ int i = 0;
+
+ /* All but last two */
+ for (; i + 2 < lenInChars; i++) {
+ if (inFirst[i] != _UT('%')) {
+ write[0] = inFirst[i];
+ write++;
+ } else {
+ /* 6.2.2.2 Percent-Encoding Normalization: *
+ * percent-encoded unreserved characters */
+ const URI_CHAR one = inFirst[i + 1];
+ const URI_CHAR two = inFirst[i + 2];
+ const unsigned char left = URI_FUNC(HexdigToInt)(one);
+ const unsigned char right = URI_FUNC(HexdigToInt)(two);
+ const int code = 16 * left + right;
+ if (uriIsUnreserved(code)) {
+ write[0] = (URI_CHAR)(code);
+ write++;
+ } else {
+ /* 6.2.2.1 Case Normalization: *
+ * lowercase percent-encodings */
+ write[0] = _UT('%');
+ write[1] = URI_FUNC(HexToLetter)(left);
+ write[2] = URI_FUNC(HexToLetter)(right);
+ write += 3;
+ }
+
+ i += 2; /* For the two chars of the percent group we just ate */
+ }
+ }
+
+ /* Last two */
+ for (; i < lenInChars; i++) {
+ write[0] = inFirst[i];
+ write++;
+ }
+
+ *outAfterLast = write;
+}
+
+
+
+static URI_INLINE void URI_FUNC(FixPercentEncodingInplace)(const URI_CHAR * first,
+ const URI_CHAR ** afterLast) {
+ /* Death checks */
+ if ((first == NULL) || (afterLast == NULL) || (*afterLast == NULL)) {
+ return;
+ }
+
+ /* Fix inplace */
+ URI_FUNC(FixPercentEncodingEngine)(first, *afterLast, first, afterLast);
+}
+
+
+
+static URI_INLINE UriBool URI_FUNC(FixPercentEncodingMalloc)(const URI_CHAR ** first,
+ const URI_CHAR ** afterLast, UriMemoryManager * memory) {
+ int lenInChars;
+ URI_CHAR * buffer;
+
+ /* Death checks */
+ if ((first == NULL) || (afterLast == NULL)
+ || (*first == NULL) || (*afterLast == NULL)) {
+ return URI_FALSE;
+ }
+
+ /* Old text length */
+ lenInChars = (int)(*afterLast - *first);
+ if (lenInChars == 0) {
+ return URI_TRUE;
+ } else if (lenInChars < 0) {
+ return URI_FALSE;
+ }
+
+ /* New buffer */
+ buffer = memory->malloc(memory, lenInChars * sizeof(URI_CHAR));
+ if (buffer == NULL) {
+ return URI_FALSE;
+ }
+
+ /* Fix on copy */
+ URI_FUNC(FixPercentEncodingEngine)(*first, *afterLast, buffer, afterLast);
+ *first = buffer;
+ return URI_TRUE;
+}
+
+
+
+static URI_INLINE UriBool URI_FUNC(MakeRangeOwner)(unsigned int * doneMask,
+ unsigned int maskTest, URI_TYPE(TextRange) * range,
+ UriMemoryManager * memory) {
+ if (((*doneMask & maskTest) == 0)
+ && (range->first != NULL)
+ && (range->afterLast != NULL)
+ && (range->afterLast > range->first)) {
+ const int lenInChars = (int)(range->afterLast - range->first);
+ const int lenInBytes = lenInChars * sizeof(URI_CHAR);
+ URI_CHAR * dup = memory->malloc(memory, lenInBytes);
+ if (dup == NULL) {
+ return URI_FALSE; /* Raises malloc error */
+ }
+ memcpy(dup, range->first, lenInBytes);
+ range->first = dup;
+ range->afterLast = dup + lenInChars;
+ *doneMask |= maskTest;
+ }
+ return URI_TRUE;
+}
+
+
+
+static URI_INLINE UriBool URI_FUNC(MakeOwner)(URI_TYPE(Uri) * uri,
+ unsigned int * doneMask, UriMemoryManager * memory) {
+ URI_TYPE(PathSegment) * walker = uri->pathHead;
+ if (!URI_FUNC(MakeRangeOwner)(doneMask, URI_NORMALIZE_SCHEME,
+ &(uri->scheme), memory)
+ || !URI_FUNC(MakeRangeOwner)(doneMask, URI_NORMALIZE_USER_INFO,
+ &(uri->userInfo), memory)
+ || !URI_FUNC(MakeRangeOwner)(doneMask, URI_NORMALIZE_QUERY,
+ &(uri->query), memory)
+ || !URI_FUNC(MakeRangeOwner)(doneMask, URI_NORMALIZE_FRAGMENT,
+ &(uri->fragment), memory)) {
+ return URI_FALSE; /* Raises malloc error */
+ }
+
+ /* Host */
+ if ((*doneMask & URI_NORMALIZE_HOST) == 0) {
+ if ((uri->hostData.ip4 == NULL)
+ && (uri->hostData.ip6 == NULL)) {
+ if (uri->hostData.ipFuture.first != NULL) {
+ /* IPvFuture */
+ if (!URI_FUNC(MakeRangeOwner)(doneMask, URI_NORMALIZE_HOST,
+ &(uri->hostData.ipFuture), memory)) {
+ return URI_FALSE; /* Raises malloc error */
+ }
+ uri->hostText.first = uri->hostData.ipFuture.first;
+ uri->hostText.afterLast = uri->hostData.ipFuture.afterLast;
+ } else if (uri->hostText.first != NULL) {
+ /* Regname */
+ if (!URI_FUNC(MakeRangeOwner)(doneMask, URI_NORMALIZE_HOST,
+ &(uri->hostText), memory)) {
+ return URI_FALSE; /* Raises malloc error */
+ }
+ }
+ }
+ }
+
+ /* Path */
+ if ((*doneMask & URI_NORMALIZE_PATH) == 0) {
+ while (walker != NULL) {
+ if (!URI_FUNC(MakeRangeOwner)(doneMask, 0, &(walker->text), memory)) {
+ /* Free allocations done so far and kill path */
+
+ /* Kill path to one before walker (if any) */
+ URI_TYPE(PathSegment) * ranger = uri->pathHead;
+ while (ranger != walker) {
+ URI_TYPE(PathSegment) * const next = ranger->next;
+ if ((ranger->text.first != NULL)
+ && (ranger->text.afterLast != NULL)
+ && (ranger->text.afterLast > ranger->text.first)) {
+ memory->free(memory, (URI_CHAR *)ranger->text.first);
+ }
+ memory->free(memory, ranger);
+ ranger = next;
+ }
+
+ /* Kill path from walker */
+ while (walker != NULL) {
+ URI_TYPE(PathSegment) * const next = walker->next;
+ memory->free(memory, walker);
+ walker = next;
+ }
+
+ uri->pathHead = NULL;
+ uri->pathTail = NULL;
+ return URI_FALSE; /* Raises malloc error */
+ }
+ walker = walker->next;
+ }
+ *doneMask |= URI_NORMALIZE_PATH;
+ }
+
+ /* Port text, must come last so we don't have to undo that one if it fails. *
+ * Otherwise we would need and extra enum flag for it although the port *
+ * cannot go unnormalized... */
+ if (!URI_FUNC(MakeRangeOwner)(doneMask, 0, &(uri->portText), memory)) {
+ return URI_FALSE; /* Raises malloc error */
+ }
+
+ return URI_TRUE;
+}
+
+
+
+unsigned int URI_FUNC(NormalizeSyntaxMaskRequired)(const URI_TYPE(Uri) * uri) {
+ unsigned int outMask = URI_NORMALIZED; /* for NULL uri */
+ URI_FUNC(NormalizeSyntaxMaskRequiredEx)(uri, &outMask);
+ return outMask;
+}
+
+
+
+int URI_FUNC(NormalizeSyntaxMaskRequiredEx)(const URI_TYPE(Uri) * uri,
+ unsigned int * outMask) {
+ UriMemoryManager * const memory = NULL; /* no use of memory manager */
+
+#if defined(__GNUC__) && ((__GNUC__ > 4) \
+ || ((__GNUC__ == 4) && defined(__GNUC_MINOR__) && (__GNUC_MINOR__ >= 2)))
+ /* Slower code that fixes a warning, not sure if this is a smart idea */
+ URI_TYPE(Uri) writeableClone;
+#endif
+
+ if ((uri == NULL) || (outMask == NULL)) {
+ return URI_ERROR_NULL;
+ }
+
+#if defined(__GNUC__) && ((__GNUC__ > 4) \
+ || ((__GNUC__ == 4) && defined(__GNUC_MINOR__) && (__GNUC_MINOR__ >= 2)))
+ /* Slower code that fixes a warning, not sure if this is a smart idea */
+ memcpy(&writeableClone, uri, 1 * sizeof(URI_TYPE(Uri)));
+ URI_FUNC(NormalizeSyntaxEngine)(&writeableClone, 0, outMask, memory);
+#else
+ URI_FUNC(NormalizeSyntaxEngine)((URI_TYPE(Uri) *)uri, 0, outMask, memory);
+#endif
+ return URI_SUCCESS;
+}
+
+
+
+int URI_FUNC(NormalizeSyntaxEx)(URI_TYPE(Uri) * uri, unsigned int mask) {
+ return URI_FUNC(NormalizeSyntaxExMm)(uri, mask, NULL);
+}
+
+
+
+int URI_FUNC(NormalizeSyntaxExMm)(URI_TYPE(Uri) * uri, unsigned int mask,
+ UriMemoryManager * memory) {
+ URI_CHECK_MEMORY_MANAGER(memory); /* may return */
+ return URI_FUNC(NormalizeSyntaxEngine)(uri, mask, NULL, memory);
+}
+
+
+
+int URI_FUNC(NormalizeSyntax)(URI_TYPE(Uri) * uri) {
+ return URI_FUNC(NormalizeSyntaxEx)(uri, (unsigned int)-1);
+}
+
+
+
+static URI_INLINE int URI_FUNC(NormalizeSyntaxEngine)(URI_TYPE(Uri) * uri,
+ unsigned int inMask, unsigned int * outMask,
+ UriMemoryManager * memory) {
+ unsigned int doneMask = URI_NORMALIZED;
+
+ /* Not just doing inspection? -> memory manager required! */
+ if (outMask == NULL) {
+ assert(memory != NULL);
+ }
+
+ if (uri == NULL) {
+ if (outMask != NULL) {
+ *outMask = URI_NORMALIZED;
+ return URI_SUCCESS;
+ } else {
+ return URI_ERROR_NULL;
+ }
+ }
+
+ if (outMask != NULL) {
+ /* Reset mask */
+ *outMask = URI_NORMALIZED;
+ } else if (inMask == URI_NORMALIZED) {
+ /* Nothing to do */
+ return URI_SUCCESS;
+ }
+
+ /* Scheme, host */
+ if (outMask != NULL) {
+ const UriBool normalizeScheme = URI_FUNC(ContainsUppercaseLetters)(
+ uri->scheme.first, uri->scheme.afterLast);
+ const UriBool normalizeHostCase = URI_FUNC(ContainsUppercaseLetters)(
+ uri->hostText.first, uri->hostText.afterLast);
+ if (normalizeScheme) {
+ *outMask |= URI_NORMALIZE_SCHEME;
+ }
+
+ if (normalizeHostCase) {
+ *outMask |= URI_NORMALIZE_HOST;
+ } else {
+ const UriBool normalizeHostPrecent = URI_FUNC(ContainsUglyPercentEncoding)(
+ uri->hostText.first, uri->hostText.afterLast);
+ if (normalizeHostPrecent) {
+ *outMask |= URI_NORMALIZE_HOST;
+ }
+ }
+ } else {
+ /* Scheme */
+ if ((inMask & URI_NORMALIZE_SCHEME) && (uri->scheme.first != NULL)) {
+ if (uri->owner) {
+ URI_FUNC(LowercaseInplace)(uri->scheme.first, uri->scheme.afterLast);
+ } else {
+ if (!URI_FUNC(LowercaseMalloc)(&(uri->scheme.first), &(uri->scheme.afterLast), memory)) {
+ URI_FUNC(PreventLeakage)(uri, doneMask, memory);
+ return URI_ERROR_MALLOC;
+ }
+ doneMask |= URI_NORMALIZE_SCHEME;
+ }
+ }
+
+ /* Host */
+ if (inMask & URI_NORMALIZE_HOST) {
+ if (uri->hostData.ipFuture.first != NULL) {
+ /* IPvFuture */
+ if (uri->owner) {
+ URI_FUNC(LowercaseInplace)(uri->hostData.ipFuture.first,
+ uri->hostData.ipFuture.afterLast);
+ } else {
+ if (!URI_FUNC(LowercaseMalloc)(&(uri->hostData.ipFuture.first),
+ &(uri->hostData.ipFuture.afterLast), memory)) {
+ URI_FUNC(PreventLeakage)(uri, doneMask, memory);
+ return URI_ERROR_MALLOC;
+ }
+ doneMask |= URI_NORMALIZE_HOST;
+ }
+ uri->hostText.first = uri->hostData.ipFuture.first;
+ uri->hostText.afterLast = uri->hostData.ipFuture.afterLast;
+ } else if ((uri->hostText.first != NULL)
+ && (uri->hostData.ip4 == NULL)
+ && (uri->hostData.ip6 == NULL)) {
+ /* Regname */
+ if (uri->owner) {
+ URI_FUNC(FixPercentEncodingInplace)(uri->hostText.first,
+ &(uri->hostText.afterLast));
+ } else {
+ if (!URI_FUNC(FixPercentEncodingMalloc)(
+ &(uri->hostText.first),
+ &(uri->hostText.afterLast),
+ memory)) {
+ URI_FUNC(PreventLeakage)(uri, doneMask, memory);
+ return URI_ERROR_MALLOC;
+ }
+ doneMask |= URI_NORMALIZE_HOST;
+ }
+
+ URI_FUNC(LowercaseInplace)(uri->hostText.first,
+ uri->hostText.afterLast);
+ }
+ }
+ }
+
+ /* User info */
+ if (outMask != NULL) {
+ const UriBool normalizeUserInfo = URI_FUNC(ContainsUglyPercentEncoding)(
+ uri->userInfo.first, uri->userInfo.afterLast);
+ if (normalizeUserInfo) {
+ *outMask |= URI_NORMALIZE_USER_INFO;
+ }
+ } else {
+ if ((inMask & URI_NORMALIZE_USER_INFO) && (uri->userInfo.first != NULL)) {
+ if (uri->owner) {
+ URI_FUNC(FixPercentEncodingInplace)(uri->userInfo.first, &(uri->userInfo.afterLast));
+ } else {
+ if (!URI_FUNC(FixPercentEncodingMalloc)(&(uri->userInfo.first),
+ &(uri->userInfo.afterLast), memory)) {
+ URI_FUNC(PreventLeakage)(uri, doneMask, memory);
+ return URI_ERROR_MALLOC;
+ }
+ doneMask |= URI_NORMALIZE_USER_INFO;
+ }
+ }
+ }
+
+ /* Path */
+ if (outMask != NULL) {
+ const URI_TYPE(PathSegment) * walker = uri->pathHead;
+ while (walker != NULL) {
+ const URI_CHAR * const first = walker->text.first;
+ const URI_CHAR * const afterLast = walker->text.afterLast;
+ if ((first != NULL)
+ && (afterLast != NULL)
+ && (afterLast > first)
+ && (
+ (((afterLast - first) == 1)
+ && (first[0] == _UT('.')))
+ ||
+ (((afterLast - first) == 2)
+ && (first[0] == _UT('.'))
+ && (first[1] == _UT('.')))
+ ||
+ URI_FUNC(ContainsUglyPercentEncoding)(first, afterLast)
+ )) {
+ *outMask |= URI_NORMALIZE_PATH;
+ break;
+ }
+ walker = walker->next;
+ }
+ } else if (inMask & URI_NORMALIZE_PATH) {
+ URI_TYPE(PathSegment) * walker;
+ const UriBool relative = ((uri->scheme.first == NULL)
+ && !uri->absolutePath) ? URI_TRUE : URI_FALSE;
+
+ /* Fix percent-encoding for each segment */
+ walker = uri->pathHead;
+ if (uri->owner) {
+ while (walker != NULL) {
+ URI_FUNC(FixPercentEncodingInplace)(walker->text.first, &(walker->text.afterLast));
+ walker = walker->next;
+ }
+ } else {
+ while (walker != NULL) {
+ if (!URI_FUNC(FixPercentEncodingMalloc)(&(walker->text.first),
+ &(walker->text.afterLast), memory)) {
+ URI_FUNC(PreventLeakage)(uri, doneMask, memory);
+ return URI_ERROR_MALLOC;
+ }
+ walker = walker->next;
+ }
+ doneMask |= URI_NORMALIZE_PATH;
+ }
+
+ /* 6.2.2.3 Path Segment Normalization */
+ if (!URI_FUNC(RemoveDotSegmentsEx)(uri, relative,
+ (uri->owner == URI_TRUE)
+ || ((doneMask & URI_NORMALIZE_PATH) != 0),
+ memory)) {
+ URI_FUNC(PreventLeakage)(uri, doneMask, memory);
+ return URI_ERROR_MALLOC;
+ }
+ URI_FUNC(FixEmptyTrailSegment)(uri, memory);
+ }
+
+ /* Query, fragment */
+ if (outMask != NULL) {
+ const UriBool normalizeQuery = URI_FUNC(ContainsUglyPercentEncoding)(
+ uri->query.first, uri->query.afterLast);
+ const UriBool normalizeFragment = URI_FUNC(ContainsUglyPercentEncoding)(
+ uri->fragment.first, uri->fragment.afterLast);
+ if (normalizeQuery) {
+ *outMask |= URI_NORMALIZE_QUERY;
+ }
+
+ if (normalizeFragment) {
+ *outMask |= URI_NORMALIZE_FRAGMENT;
+ }
+ } else {
+ /* Query */
+ if ((inMask & URI_NORMALIZE_QUERY) && (uri->query.first != NULL)) {
+ if (uri->owner) {
+ URI_FUNC(FixPercentEncodingInplace)(uri->query.first, &(uri->query.afterLast));
+ } else {
+ if (!URI_FUNC(FixPercentEncodingMalloc)(&(uri->query.first),
+ &(uri->query.afterLast), memory)) {
+ URI_FUNC(PreventLeakage)(uri, doneMask, memory);
+ return URI_ERROR_MALLOC;
+ }
+ doneMask |= URI_NORMALIZE_QUERY;
+ }
+ }
+
+ /* Fragment */
+ if ((inMask & URI_NORMALIZE_FRAGMENT) && (uri->fragment.first != NULL)) {
+ if (uri->owner) {
+ URI_FUNC(FixPercentEncodingInplace)(uri->fragment.first, &(uri->fragment.afterLast));
+ } else {
+ if (!URI_FUNC(FixPercentEncodingMalloc)(&(uri->fragment.first),
+ &(uri->fragment.afterLast), memory)) {
+ URI_FUNC(PreventLeakage)(uri, doneMask, memory);
+ return URI_ERROR_MALLOC;
+ }
+ doneMask |= URI_NORMALIZE_FRAGMENT;
+ }
+ }
+ }
+
+ /* Dup all not duped yet */
+ if ((outMask == NULL) && !uri->owner) {
+ if (!URI_FUNC(MakeOwner)(uri, &doneMask, memory)) {
+ URI_FUNC(PreventLeakage)(uri, doneMask, memory);
+ return URI_ERROR_MALLOC;
+ }
+ uri->owner = URI_TRUE;
+ }
+
+ return URI_SUCCESS;
+}
+
+
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriNormalizeBase.c b/src/arrow/cpp/src/arrow/vendored/uriparser/UriNormalizeBase.c
new file mode 100644
index 000000000..8dbe7fd19
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriNormalizeBase.c
@@ -0,0 +1,119 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#ifndef URI_DOXYGEN
+# include "UriNormalizeBase.h"
+#endif
+
+
+
+UriBool uriIsUnreserved(int code) {
+ switch (code) {
+ case L'a': /* ALPHA */
+ case L'A':
+ case L'b':
+ case L'B':
+ case L'c':
+ case L'C':
+ case L'd':
+ case L'D':
+ case L'e':
+ case L'E':
+ case L'f':
+ case L'F':
+ case L'g':
+ case L'G':
+ case L'h':
+ case L'H':
+ case L'i':
+ case L'I':
+ case L'j':
+ case L'J':
+ case L'k':
+ case L'K':
+ case L'l':
+ case L'L':
+ case L'm':
+ case L'M':
+ case L'n':
+ case L'N':
+ case L'o':
+ case L'O':
+ case L'p':
+ case L'P':
+ case L'q':
+ case L'Q':
+ case L'r':
+ case L'R':
+ case L's':
+ case L'S':
+ case L't':
+ case L'T':
+ case L'u':
+ case L'U':
+ case L'v':
+ case L'V':
+ case L'w':
+ case L'W':
+ case L'x':
+ case L'X':
+ case L'y':
+ case L'Y':
+ case L'z':
+ case L'Z':
+ case L'0': /* DIGIT */
+ case L'1':
+ case L'2':
+ case L'3':
+ case L'4':
+ case L'5':
+ case L'6':
+ case L'7':
+ case L'8':
+ case L'9':
+ case L'-': /* "-" / "." / "_" / "~" */
+ case L'.':
+ case L'_':
+ case L'~':
+ return URI_TRUE;
+
+ default:
+ return URI_FALSE;
+ }
+}
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriNormalizeBase.h b/src/arrow/cpp/src/arrow/vendored/uriparser/UriNormalizeBase.h
new file mode 100644
index 000000000..a12907564
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriNormalizeBase.h
@@ -0,0 +1,53 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#ifndef URI_NORMALIZE_BASE_H
+#define URI_NORMALIZE_BASE_H 1
+
+
+
+#include "UriBase.h"
+
+
+
+UriBool uriIsUnreserved(int code);
+
+
+
+#endif /* URI_NORMALIZE_BASE_H */
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriParse.c b/src/arrow/cpp/src/arrow/vendored/uriparser/UriParse.c
new file mode 100644
index 000000000..f5972d8f4
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriParse.c
@@ -0,0 +1,2410 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/**
+ * @file UriParse.c
+ * Holds the RFC 3986 %URI parsing implementation.
+ * NOTE: This source file includes itself twice.
+ */
+
+/* What encodings are enabled? */
+#include "UriDefsConfig.h"
+#if (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* Include SELF twice */
+# ifdef URI_ENABLE_ANSI
+# define URI_PASS_ANSI 1
+# include "UriParse.c"
+# undef URI_PASS_ANSI
+# endif
+# ifdef URI_ENABLE_UNICODE
+# define URI_PASS_UNICODE 1
+# include "UriParse.c"
+# undef URI_PASS_UNICODE
+# endif
+#else
+# ifdef URI_PASS_ANSI
+# include "UriDefsAnsi.h"
+# else
+# include "UriDefsUnicode.h"
+# include <wchar.h>
+# endif
+
+
+
+#ifndef URI_DOXYGEN
+# include "Uri.h"
+# include "UriIp4.h"
+# include "UriCommon.h"
+# include "UriMemory.h"
+# include "UriParseBase.h"
+#endif
+
+
+
+#define URI_SET_DIGIT \
+ _UT('0'): \
+ case _UT('1'): \
+ case _UT('2'): \
+ case _UT('3'): \
+ case _UT('4'): \
+ case _UT('5'): \
+ case _UT('6'): \
+ case _UT('7'): \
+ case _UT('8'): \
+ case _UT('9')
+
+#define URI_SET_HEX_LETTER_UPPER \
+ _UT('A'): \
+ case _UT('B'): \
+ case _UT('C'): \
+ case _UT('D'): \
+ case _UT('E'): \
+ case _UT('F')
+
+#define URI_SET_HEX_LETTER_LOWER \
+ _UT('a'): \
+ case _UT('b'): \
+ case _UT('c'): \
+ case _UT('d'): \
+ case _UT('e'): \
+ case _UT('f')
+
+#define URI_SET_HEXDIG \
+ URI_SET_DIGIT: \
+ case URI_SET_HEX_LETTER_UPPER: \
+ case URI_SET_HEX_LETTER_LOWER
+
+#define URI_SET_ALPHA \
+ URI_SET_HEX_LETTER_UPPER: \
+ case URI_SET_HEX_LETTER_LOWER: \
+ case _UT('g'): \
+ case _UT('G'): \
+ case _UT('h'): \
+ case _UT('H'): \
+ case _UT('i'): \
+ case _UT('I'): \
+ case _UT('j'): \
+ case _UT('J'): \
+ case _UT('k'): \
+ case _UT('K'): \
+ case _UT('l'): \
+ case _UT('L'): \
+ case _UT('m'): \
+ case _UT('M'): \
+ case _UT('n'): \
+ case _UT('N'): \
+ case _UT('o'): \
+ case _UT('O'): \
+ case _UT('p'): \
+ case _UT('P'): \
+ case _UT('q'): \
+ case _UT('Q'): \
+ case _UT('r'): \
+ case _UT('R'): \
+ case _UT('s'): \
+ case _UT('S'): \
+ case _UT('t'): \
+ case _UT('T'): \
+ case _UT('u'): \
+ case _UT('U'): \
+ case _UT('v'): \
+ case _UT('V'): \
+ case _UT('w'): \
+ case _UT('W'): \
+ case _UT('x'): \
+ case _UT('X'): \
+ case _UT('y'): \
+ case _UT('Y'): \
+ case _UT('z'): \
+ case _UT('Z')
+
+
+
+static const URI_CHAR * URI_FUNC(ParseAuthority)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseAuthorityTwo)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast);
+static const URI_CHAR * URI_FUNC(ParseHexZero)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast);
+static const URI_CHAR * URI_FUNC(ParseHierPart)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseIpFutLoop)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseIpFutStopGo)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseIpLit2)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseIPv6address2)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseMustBeSegmentNzNc)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseOwnHost)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseOwnHost2)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseOwnHostUserInfo)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseOwnHostUserInfoNz)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseOwnPortUserInfo)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseOwnUserInfo)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParsePartHelperTwo)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParsePathAbsEmpty)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParsePathAbsNoLeadSlash)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParsePathRootless)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParsePchar)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParsePctEncoded)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParsePctSubUnres)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParsePort)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast);
+static const URI_CHAR * URI_FUNC(ParseQueryFrag)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseSegment)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseSegmentNz)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseSegmentNzNcOrScheme2)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseUriReference)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseUriTail)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseUriTailTwo)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+static const URI_CHAR * URI_FUNC(ParseZeroMoreSlashSegs)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast, UriMemoryManager * memory);
+
+static UriBool URI_FUNC(OnExitOwnHost2)(URI_TYPE(ParserState) * state, const URI_CHAR * first, UriMemoryManager * memory);
+static UriBool URI_FUNC(OnExitOwnHostUserInfo)(URI_TYPE(ParserState) * state, const URI_CHAR * first, UriMemoryManager * memory);
+static UriBool URI_FUNC(OnExitOwnPortUserInfo)(URI_TYPE(ParserState) * state, const URI_CHAR * first, UriMemoryManager * memory);
+static UriBool URI_FUNC(OnExitSegmentNzNcOrScheme2)(URI_TYPE(ParserState) * state, const URI_CHAR * first, UriMemoryManager * memory);
+static void URI_FUNC(OnExitPartHelperTwo)(URI_TYPE(ParserState) * state);
+
+static void URI_FUNC(ResetParserStateExceptUri)(URI_TYPE(ParserState) * state);
+
+static UriBool URI_FUNC(PushPathSegment)(URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory);
+
+static void URI_FUNC(StopSyntax)(URI_TYPE(ParserState) * state, const URI_CHAR * errorPos, UriMemoryManager * memory);
+static void URI_FUNC(StopMalloc)(URI_TYPE(ParserState) * state, UriMemoryManager * memory);
+
+static int URI_FUNC(ParseUriExMm)(URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory);
+
+
+
+static URI_INLINE void URI_FUNC(StopSyntax)(URI_TYPE(ParserState) * state,
+ const URI_CHAR * errorPos, UriMemoryManager * memory) {
+ URI_FUNC(FreeUriMembersMm)(state->uri, memory);
+ state->errorPos = errorPos;
+ state->errorCode = URI_ERROR_SYNTAX;
+}
+
+
+
+static URI_INLINE void URI_FUNC(StopMalloc)(URI_TYPE(ParserState) * state, UriMemoryManager * memory) {
+ URI_FUNC(FreeUriMembersMm)(state->uri, memory);
+ state->errorPos = NULL;
+ state->errorCode = URI_ERROR_MALLOC;
+}
+
+
+
+/*
+ * [authority]-><[>[ipLit2][authorityTwo]
+ * [authority]->[ownHostUserInfoNz]
+ * [authority]-><NULL>
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParseAuthority)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ /* "" regname host */
+ state->uri->hostText.first = URI_FUNC(SafeToPointTo);
+ state->uri->hostText.afterLast = URI_FUNC(SafeToPointTo);
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('['):
+ {
+ const URI_CHAR * const afterIpLit2
+ = URI_FUNC(ParseIpLit2)(state, first + 1, afterLast, memory);
+ if (afterIpLit2 == NULL) {
+ return NULL;
+ }
+ state->uri->hostText.first = first + 1; /* HOST BEGIN */
+ return URI_FUNC(ParseAuthorityTwo)(state, afterIpLit2, afterLast);
+ }
+
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('%'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('-'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT('.'):
+ case _UT(':'):
+ case _UT(';'):
+ case _UT('@'):
+ case _UT('\''):
+ case _UT('_'):
+ case _UT('~'):
+ case _UT('+'):
+ case _UT('='):
+ case URI_SET_DIGIT:
+ case URI_SET_ALPHA:
+ state->uri->userInfo.first = first; /* USERINFO BEGIN */
+ return URI_FUNC(ParseOwnHostUserInfoNz)(state, first, afterLast, memory);
+
+ default:
+ /* "" regname host */
+ state->uri->hostText.first = URI_FUNC(SafeToPointTo);
+ state->uri->hostText.afterLast = URI_FUNC(SafeToPointTo);
+ return first;
+ }
+}
+
+
+
+/*
+ * [authorityTwo]-><:>[port]
+ * [authorityTwo]-><NULL>
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParseAuthorityTwo)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT(':'):
+ {
+ const URI_CHAR * const afterPort = URI_FUNC(ParsePort)(state, first + 1, afterLast);
+ if (afterPort == NULL) {
+ return NULL;
+ }
+ state->uri->portText.first = first + 1; /* PORT BEGIN */
+ state->uri->portText.afterLast = afterPort; /* PORT END */
+ return afterPort;
+ }
+
+ default:
+ return first;
+ }
+}
+
+
+
+/*
+ * [hexZero]->[HEXDIG][hexZero]
+ * [hexZero]-><NULL>
+ */
+static const URI_CHAR * URI_FUNC(ParseHexZero)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case URI_SET_HEXDIG:
+ return URI_FUNC(ParseHexZero)(state, first + 1, afterLast);
+
+ default:
+ return first;
+ }
+}
+
+
+
+/*
+ * [hierPart]->[pathRootless]
+ * [hierPart]-></>[partHelperTwo]
+ * [hierPart]-><NULL>
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParseHierPart)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('%'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('-'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT('.'):
+ case _UT(':'):
+ case _UT(';'):
+ case _UT('@'):
+ case _UT('\''):
+ case _UT('_'):
+ case _UT('~'):
+ case _UT('+'):
+ case _UT('='):
+ case URI_SET_DIGIT:
+ case URI_SET_ALPHA:
+ return URI_FUNC(ParsePathRootless)(state, first, afterLast, memory);
+
+ case _UT('/'):
+ return URI_FUNC(ParsePartHelperTwo)(state, first + 1, afterLast, memory);
+
+ default:
+ return first;
+ }
+}
+
+
+
+/*
+ * [ipFutLoop]->[subDelims][ipFutStopGo]
+ * [ipFutLoop]->[unreserved][ipFutStopGo]
+ * [ipFutLoop]-><:>[ipFutStopGo]
+ */
+static const URI_CHAR * URI_FUNC(ParseIpFutLoop)(URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+
+ switch (*first) {
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('-'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT('.'):
+ case _UT(':'):
+ case _UT(';'):
+ case _UT('\''):
+ case _UT('_'):
+ case _UT('~'):
+ case _UT('+'):
+ case _UT('='):
+ case URI_SET_DIGIT:
+ case URI_SET_ALPHA:
+ return URI_FUNC(ParseIpFutStopGo)(state, first + 1, afterLast, memory);
+
+ default:
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+}
+
+
+
+/*
+ * [ipFutStopGo]->[ipFutLoop]
+ * [ipFutStopGo]-><NULL>
+ */
+static const URI_CHAR * URI_FUNC(ParseIpFutStopGo)(
+ URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('-'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT('.'):
+ case _UT(':'):
+ case _UT(';'):
+ case _UT('\''):
+ case _UT('_'):
+ case _UT('~'):
+ case _UT('+'):
+ case _UT('='):
+ case URI_SET_DIGIT:
+ case URI_SET_ALPHA:
+ return URI_FUNC(ParseIpFutLoop)(state, first, afterLast, memory);
+
+ default:
+ return first;
+ }
+}
+
+
+
+/*
+ * [ipFuture]-><v>[HEXDIG][hexZero]<.>[ipFutLoop]
+ */
+static const URI_CHAR * URI_FUNC(ParseIpFuture)(URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+
+ /*
+ First character has already been
+ checked before entering this rule.
+
+ switch (*first) {
+ case _UT('v'):
+ */
+ if (first + 1 >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+
+ switch (first[1]) {
+ case URI_SET_HEXDIG:
+ {
+ const URI_CHAR * afterIpFutLoop;
+ const URI_CHAR * const afterHexZero
+ = URI_FUNC(ParseHexZero)(state, first + 2, afterLast);
+ if (afterHexZero == NULL) {
+ return NULL;
+ }
+ if (afterHexZero >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+ if (*afterHexZero != _UT('.')) {
+ URI_FUNC(StopSyntax)(state, afterHexZero, memory);
+ return NULL;
+ }
+ state->uri->hostText.first = first; /* HOST BEGIN */
+ state->uri->hostData.ipFuture.first = first; /* IPFUTURE BEGIN */
+ afterIpFutLoop = URI_FUNC(ParseIpFutLoop)(state, afterHexZero + 1, afterLast, memory);
+ if (afterIpFutLoop == NULL) {
+ return NULL;
+ }
+ state->uri->hostText.afterLast = afterIpFutLoop; /* HOST END */
+ state->uri->hostData.ipFuture.afterLast = afterIpFutLoop; /* IPFUTURE END */
+ return afterIpFutLoop;
+ }
+
+ default:
+ URI_FUNC(StopSyntax)(state, first + 1, memory);
+ return NULL;
+ }
+
+ /*
+ default:
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+ */
+}
+
+
+
+/*
+ * [ipLit2]->[ipFuture]<]>
+ * [ipLit2]->[IPv6address2]
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParseIpLit2)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+
+ switch (*first) {
+ case _UT('v'):
+ {
+ const URI_CHAR * const afterIpFuture
+ = URI_FUNC(ParseIpFuture)(state, first, afterLast, memory);
+ if (afterIpFuture == NULL) {
+ return NULL;
+ }
+ if (afterIpFuture >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+ if (*afterIpFuture != _UT(']')) {
+ URI_FUNC(StopSyntax)(state, afterIpFuture, memory);
+ return NULL;
+ }
+ return afterIpFuture + 1;
+ }
+
+ case _UT(':'):
+ case _UT(']'):
+ case URI_SET_HEXDIG:
+ state->uri->hostData.ip6 = memory->malloc(memory, 1 * sizeof(UriIp6)); /* Freed when stopping on parse error */
+ if (state->uri->hostData.ip6 == NULL) {
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ return URI_FUNC(ParseIPv6address2)(state, first, afterLast, memory);
+
+ default:
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+}
+
+
+
+/*
+ * [IPv6address2]->..<]>
+ */
+static const URI_CHAR * URI_FUNC(ParseIPv6address2)(
+ URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory) {
+ int zipperEver = 0;
+ int quadsDone = 0;
+ int digitCount = 0;
+ unsigned char digitHistory[4];
+ int ip4OctetsDone = 0;
+
+ unsigned char quadsAfterZipper[14];
+ int quadsAfterZipperCount = 0;
+
+
+ for (;;) {
+ if (first >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+
+ /* Inside IPv4 part? */
+ if (ip4OctetsDone > 0) {
+ /* Eat rest of IPv4 address */
+ for (;;) {
+ switch (*first) {
+ case URI_SET_DIGIT:
+ if (digitCount == 4) {
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+ digitHistory[digitCount++] = (unsigned char)(9 + *first - _UT('9'));
+ break;
+
+ case _UT('.'):
+ if ((ip4OctetsDone == 4) /* NOTE! */
+ || (digitCount == 0)
+ || (digitCount == 4)) {
+ /* Invalid digit or octet count */
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ } else if ((digitCount > 1)
+ && (digitHistory[0] == 0)) {
+ /* Leading zero */
+ URI_FUNC(StopSyntax)(state, first - digitCount, memory);
+ return NULL;
+ } else if ((digitCount > 2)
+ && (digitHistory[1] == 0)) {
+ /* Leading zero */
+ URI_FUNC(StopSyntax)(state, first - digitCount + 1, memory);
+ return NULL;
+ } else if ((digitCount == 3)
+ && (100 * digitHistory[0]
+ + 10 * digitHistory[1]
+ + digitHistory[2] > 255)) {
+ /* Octet value too large */
+ if (digitHistory[0] > 2) {
+ URI_FUNC(StopSyntax)(state, first - 3, memory);
+ } else if (digitHistory[1] > 5) {
+ URI_FUNC(StopSyntax)(state, first - 2, memory);
+ } else {
+ URI_FUNC(StopSyntax)(state, first - 1, memory);
+ }
+ return NULL;
+ }
+
+ /* Copy IPv4 octet */
+ state->uri->hostData.ip6->data[16 - 4 + ip4OctetsDone] = uriGetOctetValue(digitHistory, digitCount);
+ digitCount = 0;
+ ip4OctetsDone++;
+ break;
+
+ case _UT(']'):
+ if ((ip4OctetsDone != 3) /* NOTE! */
+ || (digitCount == 0)
+ || (digitCount == 4)) {
+ /* Invalid digit or octet count */
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ } else if ((digitCount > 1)
+ && (digitHistory[0] == 0)) {
+ /* Leading zero */
+ URI_FUNC(StopSyntax)(state, first - digitCount, memory);
+ return NULL;
+ } else if ((digitCount > 2)
+ && (digitHistory[1] == 0)) {
+ /* Leading zero */
+ URI_FUNC(StopSyntax)(state, first - digitCount + 1, memory);
+ return NULL;
+ } else if ((digitCount == 3)
+ && (100 * digitHistory[0]
+ + 10 * digitHistory[1]
+ + digitHistory[2] > 255)) {
+ /* Octet value too large */
+ if (digitHistory[0] > 2) {
+ URI_FUNC(StopSyntax)(state, first - 3, memory);
+ } else if (digitHistory[1] > 5) {
+ URI_FUNC(StopSyntax)(state, first - 2, memory);
+ } else {
+ URI_FUNC(StopSyntax)(state, first - 1, memory);
+ }
+ return NULL;
+ }
+
+ state->uri->hostText.afterLast = first; /* HOST END */
+
+ /* Copy missing quads right before IPv4 */
+ memcpy(state->uri->hostData.ip6->data + 16 - 4 - 2 * quadsAfterZipperCount,
+ quadsAfterZipper, 2 * quadsAfterZipperCount);
+
+ /* Copy last IPv4 octet */
+ state->uri->hostData.ip6->data[16 - 4 + 3] = uriGetOctetValue(digitHistory, digitCount);
+
+ return first + 1;
+
+ default:
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+ first++;
+
+ if (first >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+ }
+ } else {
+ /* Eat while no dot in sight */
+ int letterAmong = 0;
+ int walking = 1;
+ do {
+ switch (*first) {
+ case URI_SET_HEX_LETTER_LOWER:
+ letterAmong = 1;
+ if (digitCount == 4) {
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+ digitHistory[digitCount] = (unsigned char)(15 + *first - _UT('f'));
+ digitCount++;
+ break;
+
+ case URI_SET_HEX_LETTER_UPPER:
+ letterAmong = 1;
+ if (digitCount == 4) {
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+ digitHistory[digitCount] = (unsigned char)(15 + *first - _UT('F'));
+ digitCount++;
+ break;
+
+ case URI_SET_DIGIT:
+ if (digitCount == 4) {
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+ digitHistory[digitCount] = (unsigned char)(9 + *first - _UT('9'));
+ digitCount++;
+ break;
+
+ case _UT(':'):
+ {
+ int setZipper = 0;
+
+ if (digitCount > 0) {
+ if (zipperEver) {
+ uriWriteQuadToDoubleByte(digitHistory, digitCount, quadsAfterZipper + 2 * quadsAfterZipperCount);
+ quadsAfterZipperCount++;
+ } else {
+ uriWriteQuadToDoubleByte(digitHistory, digitCount, state->uri->hostData.ip6->data + 2 * quadsDone);
+ }
+ quadsDone++;
+ digitCount = 0;
+ }
+ letterAmong = 0;
+
+ /* Too many quads? */
+ if (quadsDone >= 8 - zipperEver) {
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+
+ /* "::"? */
+ if (first + 1 >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+ if (first[1] == _UT(':')) {
+ const int resetOffset = 2 * (quadsDone + (digitCount > 0));
+
+ first++;
+ if (zipperEver) {
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL; /* "::.+::" */
+ }
+
+ /* Zero everything after zipper */
+ memset(state->uri->hostData.ip6->data + resetOffset, 0, 16 - resetOffset);
+ setZipper = 1;
+
+ /* ":::+"? */
+ if (first + 1 >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL; /* No ']' yet */
+ }
+ if (first[1] == _UT(':')) {
+ URI_FUNC(StopSyntax)(state, first + 1, memory);
+ return NULL; /* ":::+ "*/
+ }
+ }
+
+ if (setZipper) {
+ zipperEver = 1;
+ }
+ }
+ break;
+
+ case _UT('.'):
+ if ((quadsDone > 6) /* NOTE */
+ || (!zipperEver && (quadsDone < 6))
+ || letterAmong
+ || (digitCount == 0)
+ || (digitCount == 4)) {
+ /* Invalid octet before */
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ } else if ((digitCount > 1)
+ && (digitHistory[0] == 0)) {
+ /* Leading zero */
+ URI_FUNC(StopSyntax)(state, first - digitCount, memory);
+ return NULL;
+ } else if ((digitCount > 2)
+ && (digitHistory[1] == 0)) {
+ /* Leading zero */
+ URI_FUNC(StopSyntax)(state, first - digitCount + 1, memory);
+ return NULL;
+ } else if ((digitCount == 3)
+ && (100 * digitHistory[0]
+ + 10 * digitHistory[1]
+ + digitHistory[2] > 255)) {
+ /* Octet value too large */
+ if (digitHistory[0] > 2) {
+ URI_FUNC(StopSyntax)(state, first - 3, memory);
+ } else if (digitHistory[1] > 5) {
+ URI_FUNC(StopSyntax)(state, first - 2, memory);
+ } else {
+ URI_FUNC(StopSyntax)(state, first - 1, memory);
+ }
+ return NULL;
+ }
+
+ /* Copy first IPv4 octet */
+ state->uri->hostData.ip6->data[16 - 4] = uriGetOctetValue(digitHistory, digitCount);
+ digitCount = 0;
+
+ /* Switch over to IPv4 loop */
+ ip4OctetsDone = 1;
+ walking = 0;
+ break;
+
+ case _UT(']'):
+ /* Too little quads? */
+ if (!zipperEver && !((quadsDone == 7) && (digitCount > 0))) {
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+
+ if (digitCount > 0) {
+ if (zipperEver) {
+ uriWriteQuadToDoubleByte(digitHistory, digitCount, quadsAfterZipper + 2 * quadsAfterZipperCount);
+ quadsAfterZipperCount++;
+ } else {
+ uriWriteQuadToDoubleByte(digitHistory, digitCount, state->uri->hostData.ip6->data + 2 * quadsDone);
+ }
+ /*
+ quadsDone++;
+ digitCount = 0;
+ */
+ }
+
+ /* Copy missing quads to the end */
+ memcpy(state->uri->hostData.ip6->data + 16 - 2 * quadsAfterZipperCount,
+ quadsAfterZipper, 2 * quadsAfterZipperCount);
+
+ state->uri->hostText.afterLast = first; /* HOST END */
+ return first + 1; /* Fine */
+
+ default:
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+ first++;
+
+ if (first >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL; /* No ']' yet */
+ }
+ } while (walking);
+ }
+ }
+}
+
+
+
+/*
+ * [mustBeSegmentNzNc]->[pctEncoded][mustBeSegmentNzNc]
+ * [mustBeSegmentNzNc]->[subDelims][mustBeSegmentNzNc]
+ * [mustBeSegmentNzNc]->[unreserved][mustBeSegmentNzNc]
+ * [mustBeSegmentNzNc]->[uriTail] // can take <NULL>
+ * [mustBeSegmentNzNc]-></>[segment][zeroMoreSlashSegs][uriTail]
+ * [mustBeSegmentNzNc]-><@>[mustBeSegmentNzNc]
+ */
+static const URI_CHAR * URI_FUNC(ParseMustBeSegmentNzNc)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ if (!URI_FUNC(PushPathSegment)(state, state->uri->scheme.first, first, memory)) { /* SEGMENT BOTH */
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ state->uri->scheme.first = NULL; /* Not a scheme, reset */
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('%'):
+ {
+ const URI_CHAR * const afterPctEncoded
+ = URI_FUNC(ParsePctEncoded)(state, first, afterLast, memory);
+ if (afterPctEncoded == NULL) {
+ return NULL;
+ }
+ return URI_FUNC(ParseMustBeSegmentNzNc)(state, afterPctEncoded, afterLast, memory);
+ }
+
+ case _UT('@'):
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT(';'):
+ case _UT('\''):
+ case _UT('+'):
+ case _UT('='):
+ case _UT('-'):
+ case _UT('.'):
+ case _UT('_'):
+ case _UT('~'):
+ case URI_SET_DIGIT:
+ case URI_SET_ALPHA:
+ return URI_FUNC(ParseMustBeSegmentNzNc)(state, first + 1, afterLast, memory);
+
+ case _UT('/'):
+ {
+ const URI_CHAR * afterZeroMoreSlashSegs;
+ const URI_CHAR * afterSegment;
+ if (!URI_FUNC(PushPathSegment)(state, state->uri->scheme.first, first, memory)) { /* SEGMENT BOTH */
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ state->uri->scheme.first = NULL; /* Not a scheme, reset */
+ afterSegment = URI_FUNC(ParseSegment)(state, first + 1, afterLast, memory);
+ if (afterSegment == NULL) {
+ return NULL;
+ }
+ if (!URI_FUNC(PushPathSegment)(state, first + 1, afterSegment, memory)) { /* SEGMENT BOTH */
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ afterZeroMoreSlashSegs
+ = URI_FUNC(ParseZeroMoreSlashSegs)(state, afterSegment, afterLast, memory);
+ if (afterZeroMoreSlashSegs == NULL) {
+ return NULL;
+ }
+ return URI_FUNC(ParseUriTail)(state, afterZeroMoreSlashSegs, afterLast, memory);
+ }
+
+ default:
+ if (!URI_FUNC(PushPathSegment)(state, state->uri->scheme.first, first, memory)) { /* SEGMENT BOTH */
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ state->uri->scheme.first = NULL; /* Not a scheme, reset */
+ return URI_FUNC(ParseUriTail)(state, first, afterLast, memory);
+ }
+}
+
+
+
+/*
+ * [ownHost]-><[>[ipLit2][authorityTwo]
+ * [ownHost]->[ownHost2] // can take <NULL>
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParseOwnHost)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ state->uri->hostText.afterLast = afterLast; /* HOST END */
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('['):
+ {
+ const URI_CHAR * const afterIpLit2
+ = URI_FUNC(ParseIpLit2)(state, first + 1, afterLast, memory);
+ if (afterIpLit2 == NULL) {
+ return NULL;
+ }
+ state->uri->hostText.first = first + 1; /* HOST BEGIN */
+ return URI_FUNC(ParseAuthorityTwo)(state, afterIpLit2, afterLast);
+ }
+
+ default:
+ return URI_FUNC(ParseOwnHost2)(state, first, afterLast, memory);
+ }
+}
+
+
+
+static URI_INLINE UriBool URI_FUNC(OnExitOwnHost2)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ UriMemoryManager * memory) {
+ state->uri->hostText.afterLast = first; /* HOST END */
+
+ /* Valid IPv4 or just a regname? */
+ state->uri->hostData.ip4 = memory->malloc(memory, 1 * sizeof(UriIp4)); /* Freed when stopping on parse error */
+ if (state->uri->hostData.ip4 == NULL) {
+ return URI_FALSE; /* Raises malloc error */
+ }
+ if (URI_FUNC(ParseIpFourAddress)(state->uri->hostData.ip4->data,
+ state->uri->hostText.first, state->uri->hostText.afterLast)) {
+ /* Not IPv4 */
+ memory->free(memory, state->uri->hostData.ip4);
+ state->uri->hostData.ip4 = NULL;
+ }
+ return URI_TRUE; /* Success */
+}
+
+
+
+/*
+ * [ownHost2]->[authorityTwo] // can take <NULL>
+ * [ownHost2]->[pctSubUnres][ownHost2]
+ */
+static const URI_CHAR * URI_FUNC(ParseOwnHost2)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ if (!URI_FUNC(OnExitOwnHost2)(state, first, memory)) {
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('%'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('-'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT('.'):
+ case _UT(';'):
+ case _UT('\''):
+ case _UT('_'):
+ case _UT('~'):
+ case _UT('+'):
+ case _UT('='):
+ case URI_SET_DIGIT:
+ case URI_SET_ALPHA:
+ {
+ const URI_CHAR * const afterPctSubUnres
+ = URI_FUNC(ParsePctSubUnres)(state, first, afterLast, memory);
+ if (afterPctSubUnres == NULL) {
+ return NULL;
+ }
+ return URI_FUNC(ParseOwnHost2)(state, afterPctSubUnres, afterLast, memory);
+ }
+
+ default:
+ if (!URI_FUNC(OnExitOwnHost2)(state, first, memory)) {
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ return URI_FUNC(ParseAuthorityTwo)(state, first, afterLast);
+ }
+}
+
+
+
+static URI_INLINE UriBool URI_FUNC(OnExitOwnHostUserInfo)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ UriMemoryManager * memory) {
+ state->uri->hostText.first = state->uri->userInfo.first; /* Host instead of userInfo, update */
+ state->uri->userInfo.first = NULL; /* Not a userInfo, reset */
+ state->uri->hostText.afterLast = first; /* HOST END */
+
+ /* Valid IPv4 or just a regname? */
+ state->uri->hostData.ip4 = memory->malloc(memory, 1 * sizeof(UriIp4)); /* Freed when stopping on parse error */
+ if (state->uri->hostData.ip4 == NULL) {
+ return URI_FALSE; /* Raises malloc error */
+ }
+ if (URI_FUNC(ParseIpFourAddress)(state->uri->hostData.ip4->data,
+ state->uri->hostText.first, state->uri->hostText.afterLast)) {
+ /* Not IPv4 */
+ memory->free(memory, state->uri->hostData.ip4);
+ state->uri->hostData.ip4 = NULL;
+ }
+ return URI_TRUE; /* Success */
+}
+
+
+
+/*
+ * [ownHostUserInfo]->[ownHostUserInfoNz]
+ * [ownHostUserInfo]-><NULL>
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParseOwnHostUserInfo)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ if (!URI_FUNC(OnExitOwnHostUserInfo)(state, first, memory)) {
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('%'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('-'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT('.'):
+ case _UT(':'):
+ case _UT(';'):
+ case _UT('@'):
+ case _UT('\''):
+ case _UT('_'):
+ case _UT('~'):
+ case _UT('+'):
+ case _UT('='):
+ case URI_SET_DIGIT:
+ case URI_SET_ALPHA:
+ return URI_FUNC(ParseOwnHostUserInfoNz)(state, first, afterLast, memory);
+
+ default:
+ if (!URI_FUNC(OnExitOwnHostUserInfo)(state, first, memory)) {
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ return first;
+ }
+}
+
+
+
+/*
+ * [ownHostUserInfoNz]->[pctSubUnres][ownHostUserInfo]
+ * [ownHostUserInfoNz]-><:>[ownPortUserInfo]
+ * [ownHostUserInfoNz]-><@>[ownHost]
+ */
+static const URI_CHAR * URI_FUNC(ParseOwnHostUserInfoNz)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+
+ switch (*first) {
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('%'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('-'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT('.'):
+ case _UT(';'):
+ case _UT('\''):
+ case _UT('_'):
+ case _UT('~'):
+ case _UT('+'):
+ case _UT('='):
+ case URI_SET_DIGIT:
+ case URI_SET_ALPHA:
+ {
+ const URI_CHAR * const afterPctSubUnres
+ = URI_FUNC(ParsePctSubUnres)(state, first, afterLast, memory);
+ if (afterPctSubUnres == NULL) {
+ return NULL;
+ }
+ return URI_FUNC(ParseOwnHostUserInfo)(state, afterPctSubUnres, afterLast, memory);
+ }
+
+ case _UT(':'):
+ state->uri->hostText.afterLast = first; /* HOST END */
+ state->uri->portText.first = first + 1; /* PORT BEGIN */
+ return URI_FUNC(ParseOwnPortUserInfo)(state, first + 1, afterLast, memory);
+
+ case _UT('@'):
+ state->uri->userInfo.afterLast = first; /* USERINFO END */
+ state->uri->hostText.first = first + 1; /* HOST BEGIN */
+ return URI_FUNC(ParseOwnHost)(state, first + 1, afterLast, memory);
+
+ default:
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+}
+
+
+
+static URI_INLINE UriBool URI_FUNC(OnExitOwnPortUserInfo)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ UriMemoryManager * memory) {
+ state->uri->hostText.first = state->uri->userInfo.first; /* Host instead of userInfo, update */
+ state->uri->userInfo.first = NULL; /* Not a userInfo, reset */
+ state->uri->portText.afterLast = first; /* PORT END */
+
+ /* Valid IPv4 or just a regname? */
+ state->uri->hostData.ip4 = memory->malloc(memory, 1 * sizeof(UriIp4)); /* Freed when stopping on parse error */
+ if (state->uri->hostData.ip4 == NULL) {
+ return URI_FALSE; /* Raises malloc error */
+ }
+ if (URI_FUNC(ParseIpFourAddress)(state->uri->hostData.ip4->data,
+ state->uri->hostText.first, state->uri->hostText.afterLast)) {
+ /* Not IPv4 */
+ memory->free(memory, state->uri->hostData.ip4);
+ state->uri->hostData.ip4 = NULL;
+ }
+ return URI_TRUE; /* Success */
+}
+
+
+
+/*
+ * [ownPortUserInfo]->[ALPHA][ownUserInfo]
+ * [ownPortUserInfo]->[DIGIT][ownPortUserInfo]
+ * [ownPortUserInfo]-><.>[ownUserInfo]
+ * [ownPortUserInfo]-><_>[ownUserInfo]
+ * [ownPortUserInfo]-><~>[ownUserInfo]
+ * [ownPortUserInfo]-><->[ownUserInfo]
+ * [ownPortUserInfo]->[subDelims][ownUserInfo]
+ * [ownPortUserInfo]->[pctEncoded][ownUserInfo]
+ * [ownPortUserInfo]-><:>[ownUserInfo]
+ * [ownPortUserInfo]-><@>[ownHost]
+ * [ownPortUserInfo]-><NULL>
+ */
+static const URI_CHAR * URI_FUNC(ParseOwnPortUserInfo)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ if (!URI_FUNC(OnExitOwnPortUserInfo)(state, first, memory)) {
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ return afterLast;
+ }
+
+ switch (*first) {
+ /* begin sub-delims */
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('&'):
+ case _UT('\''):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('*'):
+ case _UT('+'):
+ case _UT(','):
+ case _UT(';'):
+ case _UT('='):
+ /* end sub-delims */
+ /* begin unreserved (except alpha and digit) */
+ case _UT('-'):
+ case _UT('.'):
+ case _UT('_'):
+ case _UT('~'):
+ /* end unreserved (except alpha and digit) */
+ case _UT(':'):
+ case URI_SET_ALPHA:
+ state->uri->hostText.afterLast = NULL; /* Not a host, reset */
+ state->uri->portText.first = NULL; /* Not a port, reset */
+ return URI_FUNC(ParseOwnUserInfo)(state, first + 1, afterLast, memory);
+
+ case URI_SET_DIGIT:
+ return URI_FUNC(ParseOwnPortUserInfo)(state, first + 1, afterLast, memory);
+
+ case _UT('%'):
+ state->uri->portText.first = NULL; /* Not a port, reset */
+ {
+ const URI_CHAR * const afterPct
+ = URI_FUNC(ParsePctEncoded)(state, first, afterLast, memory);
+ if (afterPct == NULL) {
+ return NULL;
+ }
+ return URI_FUNC(ParseOwnUserInfo)(state, afterPct, afterLast, memory);
+ }
+
+ case _UT('@'):
+ state->uri->hostText.afterLast = NULL; /* Not a host, reset */
+ state->uri->portText.first = NULL; /* Not a port, reset */
+ state->uri->userInfo.afterLast = first; /* USERINFO END */
+ state->uri->hostText.first = first + 1; /* HOST BEGIN */
+ return URI_FUNC(ParseOwnHost)(state, first + 1, afterLast, memory);
+
+ default:
+ if (!URI_FUNC(OnExitOwnPortUserInfo)(state, first, memory)) {
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ return first;
+ }
+}
+
+
+
+/*
+ * [ownUserInfo]->[pctSubUnres][ownUserInfo]
+ * [ownUserInfo]-><:>[ownUserInfo]
+ * [ownUserInfo]-><@>[ownHost]
+ */
+static const URI_CHAR * URI_FUNC(ParseOwnUserInfo)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+
+ switch (*first) {
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('%'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('-'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT('.'):
+ case _UT(';'):
+ case _UT('\''):
+ case _UT('_'):
+ case _UT('~'):
+ case _UT('+'):
+ case _UT('='):
+ case URI_SET_DIGIT:
+ case URI_SET_ALPHA:
+ {
+ const URI_CHAR * const afterPctSubUnres
+ = URI_FUNC(ParsePctSubUnres)(state, first, afterLast, memory);
+ if (afterPctSubUnres == NULL) {
+ return NULL;
+ }
+ return URI_FUNC(ParseOwnUserInfo)(state, afterPctSubUnres, afterLast, memory);
+ }
+
+ case _UT(':'):
+ return URI_FUNC(ParseOwnUserInfo)(state, first + 1, afterLast, memory);
+
+ case _UT('@'):
+ /* SURE */
+ state->uri->userInfo.afterLast = first; /* USERINFO END */
+ state->uri->hostText.first = first + 1; /* HOST BEGIN */
+ return URI_FUNC(ParseOwnHost)(state, first + 1, afterLast, memory);
+
+ default:
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+}
+
+
+
+static URI_INLINE void URI_FUNC(OnExitPartHelperTwo)(URI_TYPE(ParserState) * state) {
+ state->uri->absolutePath = URI_TRUE;
+}
+
+
+
+/*
+ * [partHelperTwo]->[pathAbsNoLeadSlash] // can take <NULL>
+ * [partHelperTwo]-></>[authority][pathAbsEmpty]
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParsePartHelperTwo)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ URI_FUNC(OnExitPartHelperTwo)(state);
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('/'):
+ {
+ const URI_CHAR * const afterAuthority
+ = URI_FUNC(ParseAuthority)(state, first + 1, afterLast, memory);
+ const URI_CHAR * afterPathAbsEmpty;
+ if (afterAuthority == NULL) {
+ return NULL;
+ }
+ afterPathAbsEmpty = URI_FUNC(ParsePathAbsEmpty)(state, afterAuthority, afterLast, memory);
+
+ URI_FUNC(FixEmptyTrailSegment)(state->uri, memory);
+
+ return afterPathAbsEmpty;
+ }
+
+ default:
+ URI_FUNC(OnExitPartHelperTwo)(state);
+ return URI_FUNC(ParsePathAbsNoLeadSlash)(state, first, afterLast, memory);
+ }
+}
+
+
+
+/*
+ * [pathAbsEmpty]-></>[segment][pathAbsEmpty]
+ * [pathAbsEmpty]-><NULL>
+ */
+static const URI_CHAR * URI_FUNC(ParsePathAbsEmpty)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('/'):
+ {
+ const URI_CHAR * const afterSegment
+ = URI_FUNC(ParseSegment)(state, first + 1, afterLast, memory);
+ if (afterSegment == NULL) {
+ return NULL;
+ }
+ if (!URI_FUNC(PushPathSegment)(state, first + 1, afterSegment, memory)) { /* SEGMENT BOTH */
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ return URI_FUNC(ParsePathAbsEmpty)(state, afterSegment, afterLast, memory);
+ }
+
+ default:
+ return first;
+ }
+}
+
+
+
+/*
+ * [pathAbsNoLeadSlash]->[segmentNz][zeroMoreSlashSegs]
+ * [pathAbsNoLeadSlash]-><NULL>
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParsePathAbsNoLeadSlash)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('%'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('-'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT('.'):
+ case _UT(':'):
+ case _UT(';'):
+ case _UT('@'):
+ case _UT('\''):
+ case _UT('_'):
+ case _UT('~'):
+ case _UT('+'):
+ case _UT('='):
+ case URI_SET_DIGIT:
+ case URI_SET_ALPHA:
+ {
+ const URI_CHAR * const afterSegmentNz
+ = URI_FUNC(ParseSegmentNz)(state, first, afterLast, memory);
+ if (afterSegmentNz == NULL) {
+ return NULL;
+ }
+ if (!URI_FUNC(PushPathSegment)(state, first, afterSegmentNz, memory)) { /* SEGMENT BOTH */
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ return URI_FUNC(ParseZeroMoreSlashSegs)(state, afterSegmentNz, afterLast, memory);
+ }
+
+ default:
+ return first;
+ }
+}
+
+
+
+/*
+ * [pathRootless]->[segmentNz][zeroMoreSlashSegs]
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParsePathRootless)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ const URI_CHAR * const afterSegmentNz
+ = URI_FUNC(ParseSegmentNz)(state, first, afterLast, memory);
+ if (afterSegmentNz == NULL) {
+ return NULL;
+ } else {
+ if (!URI_FUNC(PushPathSegment)(state, first, afterSegmentNz, memory)) { /* SEGMENT BOTH */
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ }
+ return URI_FUNC(ParseZeroMoreSlashSegs)(state, afterSegmentNz, afterLast, memory);
+}
+
+
+
+/*
+ * [pchar]->[pctEncoded]
+ * [pchar]->[subDelims]
+ * [pchar]->[unreserved]
+ * [pchar]-><:>
+ * [pchar]-><@>
+ */
+static const URI_CHAR * URI_FUNC(ParsePchar)(URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+
+ switch (*first) {
+ case _UT('%'):
+ return URI_FUNC(ParsePctEncoded)(state, first, afterLast, memory);
+
+ case _UT(':'):
+ case _UT('@'):
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT(';'):
+ case _UT('\''):
+ case _UT('+'):
+ case _UT('='):
+ case _UT('-'):
+ case _UT('.'):
+ case _UT('_'):
+ case _UT('~'):
+ case URI_SET_DIGIT:
+ case URI_SET_ALPHA:
+ return first + 1;
+
+ default:
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+}
+
+
+
+/*
+ * [pctEncoded]-><%>[HEXDIG][HEXDIG]
+ */
+static const URI_CHAR * URI_FUNC(ParsePctEncoded)(
+ URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+
+ /*
+ First character has already been
+ checked before entering this rule.
+
+ switch (*first) {
+ case _UT('%'):
+ */
+ if (first + 1 >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+
+ switch (first[1]) {
+ case URI_SET_HEXDIG:
+ if (first + 2 >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+
+ switch (first[2]) {
+ case URI_SET_HEXDIG:
+ return first + 3;
+
+ default:
+ URI_FUNC(StopSyntax)(state, first + 2, memory);
+ return NULL;
+ }
+
+ default:
+ URI_FUNC(StopSyntax)(state, first + 1, memory);
+ return NULL;
+ }
+
+ /*
+ default:
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+ */
+}
+
+
+
+/*
+ * [pctSubUnres]->[pctEncoded]
+ * [pctSubUnres]->[subDelims]
+ * [pctSubUnres]->[unreserved]
+ */
+static const URI_CHAR * URI_FUNC(ParsePctSubUnres)(
+ URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ return NULL;
+ }
+
+ switch (*first) {
+ case _UT('%'):
+ return URI_FUNC(ParsePctEncoded)(state, first, afterLast, memory);
+
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT(';'):
+ case _UT('\''):
+ case _UT('+'):
+ case _UT('='):
+ case _UT('-'):
+ case _UT('.'):
+ case _UT('_'):
+ case _UT('~'):
+ case URI_SET_DIGIT:
+ case URI_SET_ALPHA:
+ return first + 1;
+
+ default:
+ URI_FUNC(StopSyntax)(state, first, memory);
+ return NULL;
+ }
+}
+
+
+
+/*
+ * [port]->[DIGIT][port]
+ * [port]-><NULL>
+ */
+static const URI_CHAR * URI_FUNC(ParsePort)(URI_TYPE(ParserState) * state, const URI_CHAR * first, const URI_CHAR * afterLast) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case URI_SET_DIGIT:
+ return URI_FUNC(ParsePort)(state, first + 1, afterLast);
+
+ default:
+ return first;
+ }
+}
+
+
+
+/*
+ * [queryFrag]->[pchar][queryFrag]
+ * [queryFrag]-></>[queryFrag]
+ * [queryFrag]-><?>[queryFrag]
+ * [queryFrag]-><NULL>
+ */
+static const URI_CHAR * URI_FUNC(ParseQueryFrag)(URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('%'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('-'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT('.'):
+ case _UT(':'):
+ case _UT(';'):
+ case _UT('@'):
+ case _UT('\''):
+ case _UT('_'):
+ case _UT('~'):
+ case _UT('+'):
+ case _UT('='):
+ case URI_SET_DIGIT:
+ case URI_SET_ALPHA:
+ {
+ const URI_CHAR * const afterPchar
+ = URI_FUNC(ParsePchar)(state, first, afterLast, memory);
+ if (afterPchar == NULL) {
+ return NULL;
+ }
+ return URI_FUNC(ParseQueryFrag)(state, afterPchar, afterLast, memory);
+ }
+
+ case _UT('/'):
+ case _UT('?'):
+ return URI_FUNC(ParseQueryFrag)(state, first + 1, afterLast, memory);
+
+ default:
+ return first;
+ }
+}
+
+
+
+/*
+ * [segment]->[pchar][segment]
+ * [segment]-><NULL>
+ */
+static const URI_CHAR * URI_FUNC(ParseSegment)(URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('%'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('-'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT('.'):
+ case _UT(':'):
+ case _UT(';'):
+ case _UT('@'):
+ case _UT('\''):
+ case _UT('_'):
+ case _UT('~'):
+ case _UT('+'):
+ case _UT('='):
+ case URI_SET_DIGIT:
+ case URI_SET_ALPHA:
+ {
+ const URI_CHAR * const afterPchar
+ = URI_FUNC(ParsePchar)(state, first, afterLast, memory);
+ if (afterPchar == NULL) {
+ return NULL;
+ }
+ return URI_FUNC(ParseSegment)(state, afterPchar, afterLast, memory);
+ }
+
+ default:
+ return first;
+ }
+}
+
+
+
+/*
+ * [segmentNz]->[pchar][segment]
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParseSegmentNz)(
+ URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory) {
+ const URI_CHAR * const afterPchar
+ = URI_FUNC(ParsePchar)(state, first, afterLast, memory);
+ if (afterPchar == NULL) {
+ return NULL;
+ }
+ return URI_FUNC(ParseSegment)(state, afterPchar, afterLast, memory);
+}
+
+
+
+static URI_INLINE UriBool URI_FUNC(OnExitSegmentNzNcOrScheme2)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ UriMemoryManager * memory) {
+ if (!URI_FUNC(PushPathSegment)(state, state->uri->scheme.first, first, memory)) { /* SEGMENT BOTH */
+ return URI_FALSE; /* Raises malloc error*/
+ }
+ state->uri->scheme.first = NULL; /* Not a scheme, reset */
+ return URI_TRUE; /* Success */
+}
+
+
+
+/*
+ * [segmentNzNcOrScheme2]->[ALPHA][segmentNzNcOrScheme2]
+ * [segmentNzNcOrScheme2]->[DIGIT][segmentNzNcOrScheme2]
+ * [segmentNzNcOrScheme2]->[pctEncoded][mustBeSegmentNzNc]
+ * [segmentNzNcOrScheme2]->[uriTail] // can take <NULL>
+ * [segmentNzNcOrScheme2]-><!>[mustBeSegmentNzNc]
+ * [segmentNzNcOrScheme2]-><$>[mustBeSegmentNzNc]
+ * [segmentNzNcOrScheme2]-><&>[mustBeSegmentNzNc]
+ * [segmentNzNcOrScheme2]-><(>[mustBeSegmentNzNc]
+ * [segmentNzNcOrScheme2]-><)>[mustBeSegmentNzNc]
+ * [segmentNzNcOrScheme2]-><*>[mustBeSegmentNzNc]
+ * [segmentNzNcOrScheme2]-><,>[mustBeSegmentNzNc]
+ * [segmentNzNcOrScheme2]-><.>[segmentNzNcOrScheme2]
+ * [segmentNzNcOrScheme2]-></>[segment][zeroMoreSlashSegs][uriTail]
+ * [segmentNzNcOrScheme2]-><:>[hierPart][uriTail]
+ * [segmentNzNcOrScheme2]-><;>[mustBeSegmentNzNc]
+ * [segmentNzNcOrScheme2]-><@>[mustBeSegmentNzNc]
+ * [segmentNzNcOrScheme2]-><_>[mustBeSegmentNzNc]
+ * [segmentNzNcOrScheme2]-><~>[mustBeSegmentNzNc]
+ * [segmentNzNcOrScheme2]-><+>[segmentNzNcOrScheme2]
+ * [segmentNzNcOrScheme2]-><=>[mustBeSegmentNzNc]
+ * [segmentNzNcOrScheme2]-><'>[mustBeSegmentNzNc]
+ * [segmentNzNcOrScheme2]-><->[segmentNzNcOrScheme2]
+ */
+static const URI_CHAR * URI_FUNC(ParseSegmentNzNcOrScheme2)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ if (!URI_FUNC(OnExitSegmentNzNcOrScheme2)(state, first, memory)) {
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('.'):
+ case _UT('+'):
+ case _UT('-'):
+ case URI_SET_ALPHA:
+ case URI_SET_DIGIT:
+ return URI_FUNC(ParseSegmentNzNcOrScheme2)(state, first + 1, afterLast, memory);
+
+ case _UT('%'):
+ {
+ const URI_CHAR * const afterPctEncoded
+ = URI_FUNC(ParsePctEncoded)(state, first, afterLast, memory);
+ if (afterPctEncoded == NULL) {
+ return NULL;
+ }
+ return URI_FUNC(ParseMustBeSegmentNzNc)(state, afterPctEncoded, afterLast, memory);
+ }
+
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT(';'):
+ case _UT('@'):
+ case _UT('_'):
+ case _UT('~'):
+ case _UT('='):
+ case _UT('\''):
+ return URI_FUNC(ParseMustBeSegmentNzNc)(state, first + 1, afterLast, memory);
+
+ case _UT('/'):
+ {
+ const URI_CHAR * afterZeroMoreSlashSegs;
+ const URI_CHAR * const afterSegment
+ = URI_FUNC(ParseSegment)(state, first + 1, afterLast, memory);
+ if (afterSegment == NULL) {
+ return NULL;
+ }
+ if (!URI_FUNC(PushPathSegment)(state, state->uri->scheme.first, first, memory)) { /* SEGMENT BOTH */
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ state->uri->scheme.first = NULL; /* Not a scheme, reset */
+ if (!URI_FUNC(PushPathSegment)(state, first + 1, afterSegment, memory)) { /* SEGMENT BOTH */
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ afterZeroMoreSlashSegs
+ = URI_FUNC(ParseZeroMoreSlashSegs)(state, afterSegment, afterLast, memory);
+ if (afterZeroMoreSlashSegs == NULL) {
+ return NULL;
+ }
+ return URI_FUNC(ParseUriTail)(state, afterZeroMoreSlashSegs, afterLast, memory);
+ }
+
+ case _UT(':'):
+ {
+ const URI_CHAR * const afterHierPart
+ = URI_FUNC(ParseHierPart)(state, first + 1, afterLast, memory);
+ state->uri->scheme.afterLast = first; /* SCHEME END */
+ if (afterHierPart == NULL) {
+ return NULL;
+ }
+ return URI_FUNC(ParseUriTail)(state, afterHierPart, afterLast, memory);
+ }
+
+ default:
+ if (!URI_FUNC(OnExitSegmentNzNcOrScheme2)(state, first, memory)) {
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ return URI_FUNC(ParseUriTail)(state, first, afterLast, memory);
+ }
+}
+
+
+
+/*
+ * [uriReference]->[ALPHA][segmentNzNcOrScheme2]
+ * [uriReference]->[DIGIT][mustBeSegmentNzNc]
+ * [uriReference]->[pctEncoded][mustBeSegmentNzNc]
+ * [uriReference]->[subDelims][mustBeSegmentNzNc]
+ * [uriReference]->[uriTail] // can take <NULL>
+ * [uriReference]-><.>[mustBeSegmentNzNc]
+ * [uriReference]-></>[partHelperTwo][uriTail]
+ * [uriReference]-><@>[mustBeSegmentNzNc]
+ * [uriReference]-><_>[mustBeSegmentNzNc]
+ * [uriReference]-><~>[mustBeSegmentNzNc]
+ * [uriReference]-><->[mustBeSegmentNzNc]
+ */
+static const URI_CHAR * URI_FUNC(ParseUriReference)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case URI_SET_ALPHA:
+ state->uri->scheme.first = first; /* SCHEME BEGIN */
+ return URI_FUNC(ParseSegmentNzNcOrScheme2)(state, first + 1, afterLast, memory);
+
+ case URI_SET_DIGIT:
+ case _UT('!'):
+ case _UT('$'):
+ case _UT('&'):
+ case _UT('('):
+ case _UT(')'):
+ case _UT('*'):
+ case _UT(','):
+ case _UT(';'):
+ case _UT('\''):
+ case _UT('+'):
+ case _UT('='):
+ case _UT('.'):
+ case _UT('_'):
+ case _UT('~'):
+ case _UT('-'):
+ case _UT('@'):
+ state->uri->scheme.first = first; /* SEGMENT BEGIN, ABUSE SCHEME POINTER */
+ return URI_FUNC(ParseMustBeSegmentNzNc)(state, first + 1, afterLast, memory);
+
+ case _UT('%'):
+ {
+ const URI_CHAR * const afterPctEncoded
+ = URI_FUNC(ParsePctEncoded)(state, first, afterLast, memory);
+ if (afterPctEncoded == NULL) {
+ return NULL;
+ }
+ state->uri->scheme.first = first; /* SEGMENT BEGIN, ABUSE SCHEME POINTER */
+ return URI_FUNC(ParseMustBeSegmentNzNc)(state, afterPctEncoded, afterLast, memory);
+ }
+
+ case _UT('/'):
+ {
+ const URI_CHAR * const afterPartHelperTwo
+ = URI_FUNC(ParsePartHelperTwo)(state, first + 1, afterLast, memory);
+ if (afterPartHelperTwo == NULL) {
+ return NULL;
+ }
+ return URI_FUNC(ParseUriTail)(state, afterPartHelperTwo, afterLast, memory);
+ }
+
+ default:
+ return URI_FUNC(ParseUriTail)(state, first, afterLast, memory);
+ }
+}
+
+
+
+/*
+ * [uriTail]-><#>[queryFrag]
+ * [uriTail]-><?>[queryFrag][uriTailTwo]
+ * [uriTail]-><NULL>
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParseUriTail)(
+ URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('#'):
+ {
+ const URI_CHAR * const afterQueryFrag = URI_FUNC(ParseQueryFrag)(state, first + 1, afterLast, memory);
+ if (afterQueryFrag == NULL) {
+ return NULL;
+ }
+ state->uri->fragment.first = first + 1; /* FRAGMENT BEGIN */
+ state->uri->fragment.afterLast = afterQueryFrag; /* FRAGMENT END */
+ return afterQueryFrag;
+ }
+
+ case _UT('?'):
+ {
+ const URI_CHAR * const afterQueryFrag
+ = URI_FUNC(ParseQueryFrag)(state, first + 1, afterLast, memory);
+ if (afterQueryFrag == NULL) {
+ return NULL;
+ }
+ state->uri->query.first = first + 1; /* QUERY BEGIN */
+ state->uri->query.afterLast = afterQueryFrag; /* QUERY END */
+ return URI_FUNC(ParseUriTailTwo)(state, afterQueryFrag, afterLast, memory);
+ }
+
+ default:
+ return first;
+ }
+}
+
+
+
+/*
+ * [uriTailTwo]-><#>[queryFrag]
+ * [uriTailTwo]-><NULL>
+ */
+static URI_INLINE const URI_CHAR * URI_FUNC(ParseUriTailTwo)(
+ URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('#'):
+ {
+ const URI_CHAR * const afterQueryFrag = URI_FUNC(ParseQueryFrag)(state, first + 1, afterLast, memory);
+ if (afterQueryFrag == NULL) {
+ return NULL;
+ }
+ state->uri->fragment.first = first + 1; /* FRAGMENT BEGIN */
+ state->uri->fragment.afterLast = afterQueryFrag; /* FRAGMENT END */
+ return afterQueryFrag;
+ }
+
+ default:
+ return first;
+ }
+}
+
+
+
+/*
+ * [zeroMoreSlashSegs]-></>[segment][zeroMoreSlashSegs]
+ * [zeroMoreSlashSegs]-><NULL>
+ */
+static const URI_CHAR * URI_FUNC(ParseZeroMoreSlashSegs)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ if (first >= afterLast) {
+ return afterLast;
+ }
+
+ switch (*first) {
+ case _UT('/'):
+ {
+ const URI_CHAR * const afterSegment
+ = URI_FUNC(ParseSegment)(state, first + 1, afterLast, memory);
+ if (afterSegment == NULL) {
+ return NULL;
+ }
+ if (!URI_FUNC(PushPathSegment)(state, first + 1, afterSegment, memory)) { /* SEGMENT BOTH */
+ URI_FUNC(StopMalloc)(state, memory);
+ return NULL;
+ }
+ return URI_FUNC(ParseZeroMoreSlashSegs)(state, afterSegment, afterLast, memory);
+ }
+
+ default:
+ return first;
+ }
+}
+
+
+
+static URI_INLINE void URI_FUNC(ResetParserStateExceptUri)(URI_TYPE(ParserState) * state) {
+ URI_TYPE(Uri) * const uriBackup = state->uri;
+ memset(state, 0, sizeof(URI_TYPE(ParserState)));
+ state->uri = uriBackup;
+}
+
+
+
+static URI_INLINE UriBool URI_FUNC(PushPathSegment)(
+ URI_TYPE(ParserState) * state, const URI_CHAR * first,
+ const URI_CHAR * afterLast, UriMemoryManager * memory) {
+ URI_TYPE(PathSegment) * segment = memory->calloc(memory, 1, sizeof(URI_TYPE(PathSegment)));
+ if (segment == NULL) {
+ return URI_FALSE; /* Raises malloc error */
+ }
+ if (first == afterLast) {
+ segment->text.first = URI_FUNC(SafeToPointTo);
+ segment->text.afterLast = URI_FUNC(SafeToPointTo);
+ } else {
+ segment->text.first = first;
+ segment->text.afterLast = afterLast;
+ }
+
+ /* First segment ever? */
+ if (state->uri->pathHead == NULL) {
+ /* First segment ever, set head and tail */
+ state->uri->pathHead = segment;
+ state->uri->pathTail = segment;
+ } else {
+ /* Append, update tail */
+ state->uri->pathTail->next = segment;
+ state->uri->pathTail = segment;
+ }
+
+ return URI_TRUE; /* Success */
+}
+
+
+
+int URI_FUNC(ParseUriEx)(URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast) {
+ return URI_FUNC(ParseUriExMm)(state, first, afterLast, NULL);
+}
+
+
+
+static int URI_FUNC(ParseUriExMm)(URI_TYPE(ParserState) * state,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory) {
+ const URI_CHAR * afterUriReference;
+ URI_TYPE(Uri) * uri;
+
+ /* Check params */
+ if ((state == NULL) || (first == NULL) || (afterLast == NULL)) {
+ return URI_ERROR_NULL;
+ }
+ URI_CHECK_MEMORY_MANAGER(memory); /* may return */
+
+ uri = state->uri;
+
+ /* Init parser */
+ URI_FUNC(ResetParserStateExceptUri)(state);
+ URI_FUNC(ResetUri)(uri);
+
+ /* Parse */
+ afterUriReference = URI_FUNC(ParseUriReference)(state, first, afterLast, memory);
+ if (afterUriReference == NULL) {
+ /* Waterproof errorPos <= afterLast */
+ if (state->errorPos && (state->errorPos > afterLast)) {
+ state->errorPos = afterLast;
+ }
+ return state->errorCode;
+ }
+ if (afterUriReference != afterLast) {
+ if (afterUriReference < afterLast) {
+ URI_FUNC(StopSyntax)(state, afterUriReference, memory);
+ } else {
+ URI_FUNC(StopSyntax)(state, afterLast, memory);
+ }
+ return state->errorCode;
+ }
+ return URI_SUCCESS;
+}
+
+
+
+int URI_FUNC(ParseUri)(URI_TYPE(ParserState) * state, const URI_CHAR * text) {
+ if ((state == NULL) || (text == NULL)) {
+ return URI_ERROR_NULL;
+ }
+ return URI_FUNC(ParseUriEx)(state, text, text + URI_STRLEN(text));
+}
+
+
+
+int URI_FUNC(ParseSingleUri)(URI_TYPE(Uri) * uri, const URI_CHAR * text,
+ const URI_CHAR ** errorPos) {
+ return URI_FUNC(ParseSingleUriEx)(uri, text, NULL, errorPos);
+}
+
+
+
+int URI_FUNC(ParseSingleUriEx)(URI_TYPE(Uri) * uri,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ const URI_CHAR ** errorPos) {
+ if ((afterLast == NULL) && (first != NULL)) {
+ afterLast = first + URI_STRLEN(first);
+ }
+ return URI_FUNC(ParseSingleUriExMm)(uri, first, afterLast, errorPos, NULL);
+}
+
+
+
+int URI_FUNC(ParseSingleUriExMm)(URI_TYPE(Uri) * uri,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ const URI_CHAR ** errorPos, UriMemoryManager * memory) {
+ URI_TYPE(ParserState) state;
+ int res;
+
+ /* Check params */
+ if ((uri == NULL) || (first == NULL) || (afterLast == NULL)) {
+ return URI_ERROR_NULL;
+ }
+ URI_CHECK_MEMORY_MANAGER(memory); /* may return */
+
+ state.uri = uri;
+
+ res = URI_FUNC(ParseUriExMm)(&state, first, afterLast, memory);
+
+ if (res != URI_SUCCESS) {
+ if (errorPos != NULL) {
+ *errorPos = state.errorPos;
+ }
+ URI_FUNC(FreeUriMembersMm)(uri, memory);
+ }
+
+ return res;
+}
+
+
+
+void URI_FUNC(FreeUriMembers)(URI_TYPE(Uri) * uri) {
+ URI_FUNC(FreeUriMembersMm)(uri, NULL);
+}
+
+
+
+int URI_FUNC(FreeUriMembersMm)(URI_TYPE(Uri) * uri, UriMemoryManager * memory) {
+ if (uri == NULL) {
+ return URI_ERROR_NULL;
+ }
+
+ URI_CHECK_MEMORY_MANAGER(memory); /* may return */
+
+ if (uri->owner) {
+ /* Scheme */
+ if (uri->scheme.first != NULL) {
+ if (uri->scheme.first != uri->scheme.afterLast) {
+ memory->free(memory, (URI_CHAR *)uri->scheme.first);
+ }
+ uri->scheme.first = NULL;
+ uri->scheme.afterLast = NULL;
+ }
+
+ /* User info */
+ if (uri->userInfo.first != NULL) {
+ if (uri->userInfo.first != uri->userInfo.afterLast) {
+ memory->free(memory, (URI_CHAR *)uri->userInfo.first);
+ }
+ uri->userInfo.first = NULL;
+ uri->userInfo.afterLast = NULL;
+ }
+
+ /* Host data - IPvFuture */
+ if (uri->hostData.ipFuture.first != NULL) {
+ if (uri->hostData.ipFuture.first != uri->hostData.ipFuture.afterLast) {
+ memory->free(memory, (URI_CHAR *)uri->hostData.ipFuture.first);
+ }
+ uri->hostData.ipFuture.first = NULL;
+ uri->hostData.ipFuture.afterLast = NULL;
+ uri->hostText.first = NULL;
+ uri->hostText.afterLast = NULL;
+ }
+
+ /* Host text (if regname, after IPvFuture!) */
+ if ((uri->hostText.first != NULL)
+ && (uri->hostData.ip4 == NULL)
+ && (uri->hostData.ip6 == NULL)) {
+ /* Real regname */
+ if (uri->hostText.first != uri->hostText.afterLast) {
+ memory->free(memory, (URI_CHAR *)uri->hostText.first);
+ }
+ uri->hostText.first = NULL;
+ uri->hostText.afterLast = NULL;
+ }
+ }
+
+ /* Host data - IPv4 */
+ if (uri->hostData.ip4 != NULL) {
+ memory->free(memory, uri->hostData.ip4);
+ uri->hostData.ip4 = NULL;
+ }
+
+ /* Host data - IPv6 */
+ if (uri->hostData.ip6 != NULL) {
+ memory->free(memory, uri->hostData.ip6);
+ uri->hostData.ip6 = NULL;
+ }
+
+ /* Port text */
+ if (uri->owner && (uri->portText.first != NULL)) {
+ if (uri->portText.first != uri->portText.afterLast) {
+ memory->free(memory, (URI_CHAR *)uri->portText.first);
+ }
+ uri->portText.first = NULL;
+ uri->portText.afterLast = NULL;
+ }
+
+ /* Path */
+ if (uri->pathHead != NULL) {
+ URI_TYPE(PathSegment) * segWalk = uri->pathHead;
+ while (segWalk != NULL) {
+ URI_TYPE(PathSegment) * const next = segWalk->next;
+ if (uri->owner && (segWalk->text.first != NULL)
+ && (segWalk->text.first < segWalk->text.afterLast)) {
+ memory->free(memory, (URI_CHAR *)segWalk->text.first);
+ }
+ memory->free(memory, segWalk);
+ segWalk = next;
+ }
+ uri->pathHead = NULL;
+ uri->pathTail = NULL;
+ }
+
+ if (uri->owner) {
+ /* Query */
+ if (uri->query.first != NULL) {
+ if (uri->query.first != uri->query.afterLast) {
+ memory->free(memory, (URI_CHAR *)uri->query.first);
+ }
+ uri->query.first = NULL;
+ uri->query.afterLast = NULL;
+ }
+
+ /* Fragment */
+ if (uri->fragment.first != NULL) {
+ if (uri->fragment.first != uri->fragment.afterLast) {
+ memory->free(memory, (URI_CHAR *)uri->fragment.first);
+ }
+ uri->fragment.first = NULL;
+ uri->fragment.afterLast = NULL;
+ }
+ }
+
+ return URI_SUCCESS;
+}
+
+
+
+UriBool URI_FUNC(_TESTING_ONLY_ParseIpSix)(const URI_CHAR * text) {
+ UriMemoryManager * const memory = &defaultMemoryManager;
+ URI_TYPE(Uri) uri;
+ URI_TYPE(ParserState) parser;
+ const URI_CHAR * const afterIpSix = text + URI_STRLEN(text);
+ const URI_CHAR * res;
+
+ URI_FUNC(ResetUri)(&uri);
+ parser.uri = &uri;
+ URI_FUNC(ResetParserStateExceptUri)(&parser);
+ parser.uri->hostData.ip6 = memory->malloc(memory, 1 * sizeof(UriIp6));
+ res = URI_FUNC(ParseIPv6address2)(&parser, text, afterIpSix, memory);
+ URI_FUNC(FreeUriMembersMm)(&uri, memory);
+ return res == afterIpSix ? URI_TRUE : URI_FALSE;
+}
+
+
+
+UriBool URI_FUNC(_TESTING_ONLY_ParseIpFour)(const URI_CHAR * text) {
+ unsigned char octets[4];
+ int res = URI_FUNC(ParseIpFourAddress)(octets, text, text + URI_STRLEN(text));
+ return (res == URI_SUCCESS) ? URI_TRUE : URI_FALSE;
+}
+
+
+
+#undef URI_SET_DIGIT
+#undef URI_SET_HEX_LETTER_UPPER
+#undef URI_SET_HEX_LETTER_LOWER
+#undef URI_SET_HEXDIG
+#undef URI_SET_ALPHA
+
+
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriParseBase.c b/src/arrow/cpp/src/arrow/vendored/uriparser/UriParseBase.c
new file mode 100644
index 000000000..1d4ef6e2b
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriParseBase.c
@@ -0,0 +1,90 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#ifndef URI_DOXYGEN
+# include "UriParseBase.h"
+#endif
+
+
+
+void uriWriteQuadToDoubleByte(const unsigned char * hexDigits, int digitCount, unsigned char * output) {
+ switch (digitCount) {
+ case 1:
+ /* 0x___? -> \x00 \x0? */
+ output[0] = 0;
+ output[1] = hexDigits[0];
+ break;
+
+ case 2:
+ /* 0x__?? -> \0xx \x?? */
+ output[0] = 0;
+ output[1] = 16 * hexDigits[0] + hexDigits[1];
+ break;
+
+ case 3:
+ /* 0x_??? -> \0x? \x?? */
+ output[0] = hexDigits[0];
+ output[1] = 16 * hexDigits[1] + hexDigits[2];
+ break;
+
+ case 4:
+ /* 0x???? -> \0?? \x?? */
+ output[0] = 16 * hexDigits[0] + hexDigits[1];
+ output[1] = 16 * hexDigits[2] + hexDigits[3];
+ break;
+
+ }
+}
+
+
+
+unsigned char uriGetOctetValue(const unsigned char * digits, int digitCount) {
+ switch (digitCount) {
+ case 1:
+ return digits[0];
+
+ case 2:
+ return 10 * digits[0] + digits[1];
+
+ case 3:
+ default:
+ return 100 * digits[0] + 10 * digits[1] + digits[2];
+
+ }
+}
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriParseBase.h b/src/arrow/cpp/src/arrow/vendored/uriparser/UriParseBase.h
new file mode 100644
index 000000000..eea9ffa29
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriParseBase.h
@@ -0,0 +1,55 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#ifndef URI_PARSE_BASE_H
+#define URI_PARSE_BASE_H 1
+
+
+
+#include "UriBase.h"
+
+
+
+void uriWriteQuadToDoubleByte(const unsigned char * hexDigits, int digitCount,
+ unsigned char * output);
+unsigned char uriGetOctetValue(const unsigned char * digits, int digitCount);
+
+
+
+#endif /* URI_PARSE_BASE_H */
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriQuery.c b/src/arrow/cpp/src/arrow/vendored/uriparser/UriQuery.c
new file mode 100644
index 000000000..42a74aeec
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriQuery.c
@@ -0,0 +1,501 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/* What encodings are enabled? */
+#include "UriDefsConfig.h"
+#if (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* Include SELF twice */
+# ifdef URI_ENABLE_ANSI
+# define URI_PASS_ANSI 1
+# include "UriQuery.c"
+# undef URI_PASS_ANSI
+# endif
+# ifdef URI_ENABLE_UNICODE
+# define URI_PASS_UNICODE 1
+# include "UriQuery.c"
+# undef URI_PASS_UNICODE
+# endif
+#else
+# ifdef URI_PASS_ANSI
+# include "UriDefsAnsi.h"
+# else
+# include "UriDefsUnicode.h"
+# include <wchar.h>
+# endif
+
+
+
+#ifndef URI_DOXYGEN
+# include "Uri.h"
+# include "UriCommon.h"
+# include "UriMemory.h"
+#endif
+
+
+
+#include <limits.h>
+
+
+
+static int URI_FUNC(ComposeQueryEngine)(URI_CHAR * dest,
+ const URI_TYPE(QueryList) * queryList,
+ int maxChars, int * charsWritten, int * charsRequired,
+ UriBool spaceToPlus, UriBool normalizeBreaks);
+
+static UriBool URI_FUNC(AppendQueryItem)(URI_TYPE(QueryList) ** prevNext,
+ int * itemCount, const URI_CHAR * keyFirst, const URI_CHAR * keyAfter,
+ const URI_CHAR * valueFirst, const URI_CHAR * valueAfter,
+ UriBool plusToSpace, UriBreakConversion breakConversion,
+ UriMemoryManager * memory);
+
+
+
+int URI_FUNC(ComposeQueryCharsRequired)(const URI_TYPE(QueryList) * queryList,
+ int * charsRequired) {
+ const UriBool spaceToPlus = URI_TRUE;
+ const UriBool normalizeBreaks = URI_TRUE;
+
+ return URI_FUNC(ComposeQueryCharsRequiredEx)(queryList, charsRequired,
+ spaceToPlus, normalizeBreaks);
+}
+
+
+
+int URI_FUNC(ComposeQueryCharsRequiredEx)(const URI_TYPE(QueryList) * queryList,
+ int * charsRequired, UriBool spaceToPlus, UriBool normalizeBreaks) {
+ if ((queryList == NULL) || (charsRequired == NULL)) {
+ return URI_ERROR_NULL;
+ }
+
+ return URI_FUNC(ComposeQueryEngine)(NULL, queryList, 0, NULL,
+ charsRequired, spaceToPlus, normalizeBreaks);
+}
+
+
+
+int URI_FUNC(ComposeQuery)(URI_CHAR * dest,
+ const URI_TYPE(QueryList) * queryList, int maxChars, int * charsWritten) {
+ const UriBool spaceToPlus = URI_TRUE;
+ const UriBool normalizeBreaks = URI_TRUE;
+
+ return URI_FUNC(ComposeQueryEx)(dest, queryList, maxChars, charsWritten,
+ spaceToPlus, normalizeBreaks);
+}
+
+
+
+int URI_FUNC(ComposeQueryEx)(URI_CHAR * dest,
+ const URI_TYPE(QueryList) * queryList, int maxChars, int * charsWritten,
+ UriBool spaceToPlus, UriBool normalizeBreaks) {
+ if ((dest == NULL) || (queryList == NULL)) {
+ return URI_ERROR_NULL;
+ }
+
+ if (maxChars < 1) {
+ return URI_ERROR_OUTPUT_TOO_LARGE;
+ }
+
+ return URI_FUNC(ComposeQueryEngine)(dest, queryList, maxChars,
+ charsWritten, NULL, spaceToPlus, normalizeBreaks);
+}
+
+
+
+int URI_FUNC(ComposeQueryMalloc)(URI_CHAR ** dest,
+ const URI_TYPE(QueryList) * queryList) {
+ const UriBool spaceToPlus = URI_TRUE;
+ const UriBool normalizeBreaks = URI_TRUE;
+
+ return URI_FUNC(ComposeQueryMallocEx)(dest, queryList,
+ spaceToPlus, normalizeBreaks);
+}
+
+
+
+int URI_FUNC(ComposeQueryMallocEx)(URI_CHAR ** dest,
+ const URI_TYPE(QueryList) * queryList,
+ UriBool spaceToPlus, UriBool normalizeBreaks) {
+ return URI_FUNC(ComposeQueryMallocExMm)(dest, queryList, spaceToPlus,
+ normalizeBreaks, NULL);
+}
+
+
+
+int URI_FUNC(ComposeQueryMallocExMm)(URI_CHAR ** dest,
+ const URI_TYPE(QueryList) * queryList,
+ UriBool spaceToPlus, UriBool normalizeBreaks,
+ UriMemoryManager * memory) {
+ int charsRequired;
+ int res;
+ URI_CHAR * queryString;
+
+ if (dest == NULL) {
+ return URI_ERROR_NULL;
+ }
+
+ URI_CHECK_MEMORY_MANAGER(memory); /* may return */
+
+ /* Calculate space */
+ res = URI_FUNC(ComposeQueryCharsRequiredEx)(queryList, &charsRequired,
+ spaceToPlus, normalizeBreaks);
+ if (res != URI_SUCCESS) {
+ return res;
+ }
+ charsRequired++;
+
+ /* Allocate space */
+ queryString = memory->malloc(memory, charsRequired * sizeof(URI_CHAR));
+ if (queryString == NULL) {
+ return URI_ERROR_MALLOC;
+ }
+
+ /* Put query in */
+ res = URI_FUNC(ComposeQueryEx)(queryString, queryList, charsRequired,
+ NULL, spaceToPlus, normalizeBreaks);
+ if (res != URI_SUCCESS) {
+ memory->free(memory, queryString);
+ return res;
+ }
+
+ *dest = queryString;
+ return URI_SUCCESS;
+}
+
+
+
+int URI_FUNC(ComposeQueryEngine)(URI_CHAR * dest,
+ const URI_TYPE(QueryList) * queryList,
+ int maxChars, int * charsWritten, int * charsRequired,
+ UriBool spaceToPlus, UriBool normalizeBreaks) {
+ UriBool firstItem = URI_TRUE;
+ int ampersandLen = 0; /* increased to 1 from second item on */
+ URI_CHAR * write = dest;
+
+ /* Subtract terminator */
+ if (dest == NULL) {
+ *charsRequired = 0;
+ } else {
+ maxChars--;
+ }
+
+ while (queryList != NULL) {
+ const URI_CHAR * const key = queryList->key;
+ const URI_CHAR * const value = queryList->value;
+ const int worstCase = (normalizeBreaks == URI_TRUE ? 6 : 3);
+ const int keyLen = (key == NULL) ? 0 : (int)URI_STRLEN(key);
+ int keyRequiredChars;
+ const int valueLen = (value == NULL) ? 0 : (int)URI_STRLEN(value);
+ int valueRequiredChars;
+
+ if ((keyLen >= INT_MAX / worstCase) || (valueLen >= INT_MAX / worstCase)) {
+ return URI_ERROR_OUTPUT_TOO_LARGE;
+ }
+ keyRequiredChars = worstCase * keyLen;
+ valueRequiredChars = worstCase * valueLen;
+
+ if (dest == NULL) {
+ (*charsRequired) += ampersandLen + keyRequiredChars + ((value == NULL)
+ ? 0
+ : 1 + valueRequiredChars);
+
+ if (firstItem == URI_TRUE) {
+ ampersandLen = 1;
+ firstItem = URI_FALSE;
+ }
+ } else {
+ if ((write - dest) + ampersandLen + keyRequiredChars > maxChars) {
+ return URI_ERROR_OUTPUT_TOO_LARGE;
+ }
+
+ /* Copy key */
+ if (firstItem == URI_TRUE) {
+ ampersandLen = 1;
+ firstItem = URI_FALSE;
+ } else {
+ write[0] = _UT('&');
+ write++;
+ }
+ write = URI_FUNC(EscapeEx)(key, key + keyLen,
+ write, spaceToPlus, normalizeBreaks);
+
+ if (value != NULL) {
+ if ((write - dest) + 1 + valueRequiredChars > maxChars) {
+ return URI_ERROR_OUTPUT_TOO_LARGE;
+ }
+
+ /* Copy value */
+ write[0] = _UT('=');
+ write++;
+ write = URI_FUNC(EscapeEx)(value, value + valueLen,
+ write, spaceToPlus, normalizeBreaks);
+ }
+ }
+
+ queryList = queryList->next;
+ }
+
+ if (dest != NULL) {
+ write[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = (int)(write - dest) + 1; /* .. for terminator */
+ }
+ }
+
+ return URI_SUCCESS;
+}
+
+
+
+UriBool URI_FUNC(AppendQueryItem)(URI_TYPE(QueryList) ** prevNext,
+ int * itemCount, const URI_CHAR * keyFirst, const URI_CHAR * keyAfter,
+ const URI_CHAR * valueFirst, const URI_CHAR * valueAfter,
+ UriBool plusToSpace, UriBreakConversion breakConversion,
+ UriMemoryManager * memory) {
+ const int keyLen = (int)(keyAfter - keyFirst);
+ const int valueLen = (int)(valueAfter - valueFirst);
+ URI_CHAR * key;
+ URI_CHAR * value;
+
+ if ((prevNext == NULL) || (itemCount == NULL)
+ || (keyFirst == NULL) || (keyAfter == NULL)
+ || (keyFirst > keyAfter) || (valueFirst > valueAfter)
+ || ((keyFirst == keyAfter)
+ && (valueFirst == NULL) && (valueAfter == NULL))) {
+ return URI_TRUE;
+ }
+
+ /* Append new empty item */
+ *prevNext = memory->malloc(memory, 1 * sizeof(URI_TYPE(QueryList)));
+ if (*prevNext == NULL) {
+ return URI_FALSE; /* Raises malloc error */
+ }
+ (*prevNext)->next = NULL;
+
+
+ /* Fill key */
+ key = memory->malloc(memory, (keyLen + 1) * sizeof(URI_CHAR));
+ if (key == NULL) {
+ memory->free(memory, *prevNext);
+ *prevNext = NULL;
+ return URI_FALSE; /* Raises malloc error */
+ }
+
+ key[keyLen] = _UT('\0');
+ if (keyLen > 0) {
+ /* Copy 1:1 */
+ memcpy(key, keyFirst, keyLen * sizeof(URI_CHAR));
+
+ /* Unescape */
+ URI_FUNC(UnescapeInPlaceEx)(key, plusToSpace, breakConversion);
+ }
+ (*prevNext)->key = key;
+
+
+ /* Fill value */
+ if (valueFirst != NULL) {
+ value = memory->malloc(memory, (valueLen + 1) * sizeof(URI_CHAR));
+ if (value == NULL) {
+ memory->free(memory, key);
+ memory->free(memory, *prevNext);
+ *prevNext = NULL;
+ return URI_FALSE; /* Raises malloc error */
+ }
+
+ value[valueLen] = _UT('\0');
+ if (valueLen > 0) {
+ /* Copy 1:1 */
+ memcpy(value, valueFirst, valueLen * sizeof(URI_CHAR));
+
+ /* Unescape */
+ URI_FUNC(UnescapeInPlaceEx)(value, plusToSpace, breakConversion);
+ }
+ (*prevNext)->value = value;
+ } else {
+ value = NULL;
+ }
+ (*prevNext)->value = value;
+
+ (*itemCount)++;
+ return URI_TRUE;
+}
+
+
+
+void URI_FUNC(FreeQueryList)(URI_TYPE(QueryList) * queryList) {
+ URI_FUNC(FreeQueryListMm)(queryList, NULL);
+}
+
+
+
+int URI_FUNC(FreeQueryListMm)(URI_TYPE(QueryList) * queryList,
+ UriMemoryManager * memory) {
+ URI_CHECK_MEMORY_MANAGER(memory); /* may return */
+ while (queryList != NULL) {
+ URI_TYPE(QueryList) * nextBackup = queryList->next;
+ memory->free(memory, (URI_CHAR *)queryList->key); /* const cast */
+ memory->free(memory, (URI_CHAR *)queryList->value); /* const cast */
+ memory->free(memory, queryList);
+ queryList = nextBackup;
+ }
+ return URI_SUCCESS;
+}
+
+
+
+int URI_FUNC(DissectQueryMalloc)(URI_TYPE(QueryList) ** dest, int * itemCount,
+ const URI_CHAR * first, const URI_CHAR * afterLast) {
+ const UriBool plusToSpace = URI_TRUE;
+ const UriBreakConversion breakConversion = URI_BR_DONT_TOUCH;
+
+ return URI_FUNC(DissectQueryMallocEx)(dest, itemCount, first, afterLast,
+ plusToSpace, breakConversion);
+}
+
+
+
+int URI_FUNC(DissectQueryMallocEx)(URI_TYPE(QueryList) ** dest, int * itemCount,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriBool plusToSpace, UriBreakConversion breakConversion) {
+ return URI_FUNC(DissectQueryMallocExMm)(dest, itemCount, first, afterLast,
+ plusToSpace, breakConversion, NULL);
+}
+
+
+
+int URI_FUNC(DissectQueryMallocExMm)(URI_TYPE(QueryList) ** dest, int * itemCount,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriBool plusToSpace, UriBreakConversion breakConversion,
+ UriMemoryManager * memory) {
+ const URI_CHAR * walk = first;
+ const URI_CHAR * keyFirst = first;
+ const URI_CHAR * keyAfter = NULL;
+ const URI_CHAR * valueFirst = NULL;
+ const URI_CHAR * valueAfter = NULL;
+ URI_TYPE(QueryList) ** prevNext = dest;
+ int nullCounter;
+ int * itemsAppended = (itemCount == NULL) ? &nullCounter : itemCount;
+
+ if ((dest == NULL) || (first == NULL) || (afterLast == NULL)) {
+ return URI_ERROR_NULL;
+ }
+
+ if (first > afterLast) {
+ return URI_ERROR_RANGE_INVALID;
+ }
+
+ URI_CHECK_MEMORY_MANAGER(memory); /* may return */
+
+ *dest = NULL;
+ *itemsAppended = 0;
+
+ /* Parse query string */
+ for (; walk < afterLast; walk++) {
+ switch (*walk) {
+ case _UT('&'):
+ if (valueFirst != NULL) {
+ valueAfter = walk;
+ } else {
+ keyAfter = walk;
+ }
+
+ if (URI_FUNC(AppendQueryItem)(prevNext, itemsAppended,
+ keyFirst, keyAfter, valueFirst, valueAfter,
+ plusToSpace, breakConversion, memory)
+ == URI_FALSE) {
+ /* Free list we built */
+ *itemsAppended = 0;
+ URI_FUNC(FreeQueryListMm)(*dest, memory);
+ return URI_ERROR_MALLOC;
+ }
+
+ /* Make future items children of the current */
+ if ((prevNext != NULL) && (*prevNext != NULL)) {
+ prevNext = &((*prevNext)->next);
+ }
+
+ if (walk + 1 < afterLast) {
+ keyFirst = walk + 1;
+ } else {
+ keyFirst = NULL;
+ }
+ keyAfter = NULL;
+ valueFirst = NULL;
+ valueAfter = NULL;
+ break;
+
+ case _UT('='):
+ /* NOTE: WE treat the first '=' as a separator, */
+ /* all following go into the value part */
+ if (keyAfter == NULL) {
+ keyAfter = walk;
+ if (walk + 1 <= afterLast) {
+ valueFirst = walk + 1;
+ valueAfter = walk + 1;
+ }
+ }
+ break;
+
+ default:
+ break;
+ }
+ }
+
+ if (valueFirst != NULL) {
+ /* Must be key/value pair */
+ valueAfter = walk;
+ } else {
+ /* Must be key only */
+ keyAfter = walk;
+ }
+
+ if (URI_FUNC(AppendQueryItem)(prevNext, itemsAppended, keyFirst, keyAfter,
+ valueFirst, valueAfter, plusToSpace, breakConversion, memory)
+ == URI_FALSE) {
+ /* Free list we built */
+ *itemsAppended = 0;
+ URI_FUNC(FreeQueryListMm)(*dest, memory);
+ return URI_ERROR_MALLOC;
+ }
+
+ return URI_SUCCESS;
+}
+
+
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriRecompose.c b/src/arrow/cpp/src/arrow/vendored/uriparser/UriRecompose.c
new file mode 100644
index 000000000..2f987215f
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriRecompose.c
@@ -0,0 +1,577 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/* What encodings are enabled? */
+#include "UriDefsConfig.h"
+#if (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* Include SELF twice */
+# ifdef URI_ENABLE_ANSI
+# define URI_PASS_ANSI 1
+# include "UriRecompose.c"
+# undef URI_PASS_ANSI
+# endif
+# ifdef URI_ENABLE_UNICODE
+# define URI_PASS_UNICODE 1
+# include "UriRecompose.c"
+# undef URI_PASS_UNICODE
+# endif
+#else
+# ifdef URI_PASS_ANSI
+# include "UriDefsAnsi.h"
+# else
+# include "UriDefsUnicode.h"
+# include <wchar.h>
+# endif
+
+
+
+#ifndef URI_DOXYGEN
+# include "Uri.h"
+# include "UriCommon.h"
+#endif
+
+
+
+static int URI_FUNC(ToStringEngine)(URI_CHAR * dest, const URI_TYPE(Uri) * uri,
+ int maxChars, int * charsWritten, int * charsRequired);
+
+
+
+int URI_FUNC(ToStringCharsRequired)(const URI_TYPE(Uri) * uri,
+ int * charsRequired) {
+ const int MAX_CHARS = ((unsigned int)-1) >> 1;
+ return URI_FUNC(ToStringEngine)(NULL, uri, MAX_CHARS, NULL, charsRequired);
+}
+
+
+
+int URI_FUNC(ToString)(URI_CHAR * dest, const URI_TYPE(Uri) * uri,
+ int maxChars, int * charsWritten) {
+ return URI_FUNC(ToStringEngine)(dest, uri, maxChars, charsWritten, NULL);
+}
+
+
+
+static URI_INLINE int URI_FUNC(ToStringEngine)(URI_CHAR * dest,
+ const URI_TYPE(Uri) * uri, int maxChars, int * charsWritten,
+ int * charsRequired) {
+ int written = 0;
+ if ((uri == NULL) || ((dest == NULL) && (charsRequired == NULL))) {
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_NULL;
+ }
+
+ if (maxChars < 1) {
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ maxChars--; /* So we don't have to subtract 1 for '\0' all the time */
+
+ /* [01/19] result = "" */
+ if (dest != NULL) {
+ dest[0] = _UT('\0');
+ } else {
+ (*charsRequired) = 0;
+ }
+ /* [02/19] if defined(scheme) then */
+ if (uri->scheme.first != NULL) {
+ /* [03/19] append scheme to result; */
+ const int charsToWrite
+ = (int)(uri->scheme.afterLast - uri->scheme.first);
+ if (dest != NULL) {
+ if (written + charsToWrite <= maxChars) {
+ memcpy(dest + written, uri->scheme.first,
+ charsToWrite * sizeof(URI_CHAR));
+ written += charsToWrite;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += charsToWrite;
+ }
+ /* [04/19] append ":" to result; */
+ if (dest != NULL) {
+ if (written + 1 <= maxChars) {
+ memcpy(dest + written, _UT(":"),
+ 1 * sizeof(URI_CHAR));
+ written += 1;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += 1;
+ }
+ /* [05/19] endif; */
+ }
+ /* [06/19] if defined(authority) then */
+ if (URI_FUNC(IsHostSet)(uri)) {
+ /* [07/19] append "//" to result; */
+ if (dest != NULL) {
+ if (written + 2 <= maxChars) {
+ memcpy(dest + written, _UT("//"),
+ 2 * sizeof(URI_CHAR));
+ written += 2;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += 2;
+ }
+ /* [08/19] append authority to result; */
+ /* UserInfo */
+ if (uri->userInfo.first != NULL) {
+ const int charsToWrite = (int)(uri->userInfo.afterLast - uri->userInfo.first);
+ if (dest != NULL) {
+ if (written + charsToWrite <= maxChars) {
+ memcpy(dest + written, uri->userInfo.first,
+ charsToWrite * sizeof(URI_CHAR));
+ written += charsToWrite;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+
+ if (written + 1 <= maxChars) {
+ memcpy(dest + written, _UT("@"),
+ 1 * sizeof(URI_CHAR));
+ written += 1;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += charsToWrite + 1;
+ }
+ }
+
+ /* Host */
+ if (uri->hostData.ip4 != NULL) {
+ /* IPv4 */
+ int i = 0;
+ for (; i < 4; i++) {
+ const unsigned char value = uri->hostData.ip4->data[i];
+ const int charsToWrite = (value > 99) ? 3 : ((value > 9) ? 2 : 1);
+ if (dest != NULL) {
+ if (written + charsToWrite <= maxChars) {
+ URI_CHAR text[4];
+ if (value > 99) {
+ text[0] = _UT('0') + (value / 100);
+ text[1] = _UT('0') + ((value % 100) / 10);
+ text[2] = _UT('0') + (value % 10);
+ } else if (value > 9) {
+ text[0] = _UT('0') + (value / 10);
+ text[1] = _UT('0') + (value % 10);
+ } else {
+ text[0] = _UT('0') + value;
+ }
+ text[charsToWrite] = _UT('\0');
+ memcpy(dest + written, text, charsToWrite * sizeof(URI_CHAR));
+ written += charsToWrite;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ if (i < 3) {
+ if (written + 1 <= maxChars) {
+ memcpy(dest + written, _UT("."),
+ 1 * sizeof(URI_CHAR));
+ written += 1;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ }
+ } else {
+ (*charsRequired) += charsToWrite + ((i == 3) ? 0 : 1);
+ }
+ }
+ } else if (uri->hostData.ip6 != NULL) {
+ /* IPv6 */
+ int i = 0;
+ if (dest != NULL) {
+ if (written + 1 <= maxChars) {
+ memcpy(dest + written, _UT("["),
+ 1 * sizeof(URI_CHAR));
+ written += 1;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += 1;
+ }
+
+ for (; i < 16; i++) {
+ const unsigned char value = uri->hostData.ip6->data[i];
+ if (dest != NULL) {
+ if (written + 2 <= maxChars) {
+ URI_CHAR text[3];
+ text[0] = URI_FUNC(HexToLetterEx)(value / 16, URI_FALSE);
+ text[1] = URI_FUNC(HexToLetterEx)(value % 16, URI_FALSE);
+ text[2] = _UT('\0');
+ memcpy(dest + written, text, 2 * sizeof(URI_CHAR));
+ written += 2;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += 2;
+ }
+ if (((i & 1) == 1) && (i < 15)) {
+ if (dest != NULL) {
+ if (written + 1 <= maxChars) {
+ memcpy(dest + written, _UT(":"),
+ 1 * sizeof(URI_CHAR));
+ written += 1;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += 1;
+ }
+ }
+ }
+
+ if (dest != NULL) {
+ if (written + 1 <= maxChars) {
+ memcpy(dest + written, _UT("]"),
+ 1 * sizeof(URI_CHAR));
+ written += 1;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += 1;
+ }
+ } else if (uri->hostData.ipFuture.first != NULL) {
+ /* IPvFuture */
+ const int charsToWrite = (int)(uri->hostData.ipFuture.afterLast
+ - uri->hostData.ipFuture.first);
+ if (dest != NULL) {
+ if (written + 1 <= maxChars) {
+ memcpy(dest + written, _UT("["),
+ 1 * sizeof(URI_CHAR));
+ written += 1;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+
+ if (written + charsToWrite <= maxChars) {
+ memcpy(dest + written, uri->hostData.ipFuture.first, charsToWrite * sizeof(URI_CHAR));
+ written += charsToWrite;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+
+ if (written + 1 <= maxChars) {
+ memcpy(dest + written, _UT("]"),
+ 1 * sizeof(URI_CHAR));
+ written += 1;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += 1 + charsToWrite + 1;
+ }
+ } else if (uri->hostText.first != NULL) {
+ /* Regname */
+ const int charsToWrite = (int)(uri->hostText.afterLast - uri->hostText.first);
+ if (dest != NULL) {
+ if (written + charsToWrite <= maxChars) {
+ memcpy(dest + written, uri->hostText.first,
+ charsToWrite * sizeof(URI_CHAR));
+ written += charsToWrite;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += charsToWrite;
+ }
+ }
+
+ /* Port */
+ if (uri->portText.first != NULL) {
+ const int charsToWrite = (int)(uri->portText.afterLast - uri->portText.first);
+ if (dest != NULL) {
+ /* Leading ':' */
+ if (written + 1 <= maxChars) {
+ memcpy(dest + written, _UT(":"),
+ 1 * sizeof(URI_CHAR));
+ written += 1;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+
+ /* Port number */
+ if (written + charsToWrite <= maxChars) {
+ memcpy(dest + written, uri->portText.first,
+ charsToWrite * sizeof(URI_CHAR));
+ written += charsToWrite;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += 1 + charsToWrite;
+ }
+ }
+ /* [09/19] endif; */
+ }
+ /* [10/19] append path to result; */
+ /* Slash needed here? */
+ if (uri->absolutePath || ((uri->pathHead != NULL)
+ && URI_FUNC(IsHostSet)(uri))) {
+ if (dest != NULL) {
+ if (written + 1 <= maxChars) {
+ memcpy(dest + written, _UT("/"),
+ 1 * sizeof(URI_CHAR));
+ written += 1;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += 1;
+ }
+ }
+
+ if (uri->pathHead != NULL) {
+ URI_TYPE(PathSegment) * walker = uri->pathHead;
+ do {
+ const int charsToWrite = (int)(walker->text.afterLast - walker->text.first);
+ if (dest != NULL) {
+ if (written + charsToWrite <= maxChars) {
+ memcpy(dest + written, walker->text.first,
+ charsToWrite * sizeof(URI_CHAR));
+ written += charsToWrite;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += charsToWrite;
+ }
+
+ /* Not last segment -> append slash */
+ if (walker->next != NULL) {
+ if (dest != NULL) {
+ if (written + 1 <= maxChars) {
+ memcpy(dest + written, _UT("/"),
+ 1 * sizeof(URI_CHAR));
+ written += 1;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += 1;
+ }
+ }
+
+ walker = walker->next;
+ } while (walker != NULL);
+ }
+ /* [11/19] if defined(query) then */
+ if (uri->query.first != NULL) {
+ /* [12/19] append "?" to result; */
+ if (dest != NULL) {
+ if (written + 1 <= maxChars) {
+ memcpy(dest + written, _UT("?"),
+ 1 * sizeof(URI_CHAR));
+ written += 1;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += 1;
+ }
+ /* [13/19] append query to result; */
+ {
+ const int charsToWrite
+ = (int)(uri->query.afterLast - uri->query.first);
+ if (dest != NULL) {
+ if (written + charsToWrite <= maxChars) {
+ memcpy(dest + written, uri->query.first,
+ charsToWrite * sizeof(URI_CHAR));
+ written += charsToWrite;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += charsToWrite;
+ }
+ }
+ /* [14/19] endif; */
+ }
+ /* [15/19] if defined(fragment) then */
+ if (uri->fragment.first != NULL) {
+ /* [16/19] append "#" to result; */
+ if (dest != NULL) {
+ if (written + 1 <= maxChars) {
+ memcpy(dest + written, _UT("#"),
+ 1 * sizeof(URI_CHAR));
+ written += 1;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += 1;
+ }
+ /* [17/19] append fragment to result; */
+ {
+ const int charsToWrite
+ = (int)(uri->fragment.afterLast - uri->fragment.first);
+ if (dest != NULL) {
+ if (written + charsToWrite <= maxChars) {
+ memcpy(dest + written, uri->fragment.first,
+ charsToWrite * sizeof(URI_CHAR));
+ written += charsToWrite;
+ } else {
+ dest[0] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = 0;
+ }
+ return URI_ERROR_TOSTRING_TOO_LONG;
+ }
+ } else {
+ (*charsRequired) += charsToWrite;
+ }
+ }
+ /* [18/19] endif; */
+ }
+ /* [19/19] return result; */
+ if (dest != NULL) {
+ dest[written++] = _UT('\0');
+ if (charsWritten != NULL) {
+ *charsWritten = written;
+ }
+ }
+ return URI_SUCCESS;
+}
+
+
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriResolve.c b/src/arrow/cpp/src/arrow/vendored/uriparser/UriResolve.c
new file mode 100644
index 000000000..ee4cba818
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriResolve.c
@@ -0,0 +1,329 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/* What encodings are enabled? */
+#include "UriDefsConfig.h"
+#if (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* Include SELF twice */
+# ifdef URI_ENABLE_ANSI
+# define URI_PASS_ANSI 1
+# include "UriResolve.c"
+# undef URI_PASS_ANSI
+# endif
+# ifdef URI_ENABLE_UNICODE
+# define URI_PASS_UNICODE 1
+# include "UriResolve.c"
+# undef URI_PASS_UNICODE
+# endif
+#else
+# ifdef URI_PASS_ANSI
+# include "UriDefsAnsi.h"
+# else
+# include "UriDefsUnicode.h"
+# include <wchar.h>
+# endif
+
+
+
+#ifndef URI_DOXYGEN
+# include "Uri.h"
+# include "UriCommon.h"
+# include "UriMemory.h"
+#endif
+
+
+
+/* Appends a relative URI to an absolute. The last path segment of
+ * the absolute URI is replaced. */
+static URI_INLINE UriBool URI_FUNC(MergePath)(URI_TYPE(Uri) * absWork,
+ const URI_TYPE(Uri) * relAppend, UriMemoryManager * memory) {
+ URI_TYPE(PathSegment) * sourceWalker;
+ URI_TYPE(PathSegment) * destPrev;
+ if (relAppend->pathHead == NULL) {
+ return URI_TRUE;
+ }
+
+ /* Replace last segment ("" if trailing slash) with first of append chain */
+ if (absWork->pathHead == NULL) {
+ URI_TYPE(PathSegment) * const dup = memory->malloc(memory, sizeof(URI_TYPE(PathSegment)));
+ if (dup == NULL) {
+ return URI_FALSE; /* Raises malloc error */
+ }
+ dup->next = NULL;
+ absWork->pathHead = dup;
+ absWork->pathTail = dup;
+ }
+ absWork->pathTail->text.first = relAppend->pathHead->text.first;
+ absWork->pathTail->text.afterLast = relAppend->pathHead->text.afterLast;
+
+ /* Append all the others */
+ sourceWalker = relAppend->pathHead->next;
+ if (sourceWalker == NULL) {
+ return URI_TRUE;
+ }
+ destPrev = absWork->pathTail;
+
+ for (;;) {
+ URI_TYPE(PathSegment) * const dup = memory->malloc(memory, sizeof(URI_TYPE(PathSegment)));
+ if (dup == NULL) {
+ destPrev->next = NULL;
+ absWork->pathTail = destPrev;
+ return URI_FALSE; /* Raises malloc error */
+ }
+ dup->text = sourceWalker->text;
+ destPrev->next = dup;
+
+ if (sourceWalker->next == NULL) {
+ absWork->pathTail = dup;
+ absWork->pathTail->next = NULL;
+ break;
+ }
+ destPrev = dup;
+ sourceWalker = sourceWalker->next;
+ }
+
+ return URI_TRUE;
+}
+
+
+static int URI_FUNC(ResolveAbsolutePathFlag)(URI_TYPE(Uri) * absWork,
+ UriMemoryManager * memory) {
+ if (absWork == NULL) {
+ return URI_ERROR_NULL;
+ }
+
+ if (URI_FUNC(IsHostSet)(absWork) && absWork->absolutePath) {
+ /* Empty segment needed, instead? */
+ if (absWork->pathHead == NULL) {
+ URI_TYPE(PathSegment) * const segment = memory->malloc(memory, sizeof(URI_TYPE(PathSegment)));
+ if (segment == NULL) {
+ return URI_ERROR_MALLOC;
+ }
+ segment->text.first = URI_FUNC(SafeToPointTo);
+ segment->text.afterLast = URI_FUNC(SafeToPointTo);
+ segment->next = NULL;
+
+ absWork->pathHead = segment;
+ absWork->pathTail = segment;
+ }
+
+ absWork->absolutePath = URI_FALSE;
+ }
+
+ return URI_SUCCESS;
+}
+
+
+static int URI_FUNC(AddBaseUriImpl)(URI_TYPE(Uri) * absDest,
+ const URI_TYPE(Uri) * relSource,
+ const URI_TYPE(Uri) * absBase,
+ UriResolutionOptions options, UriMemoryManager * memory) {
+ UriBool relSourceHasScheme;
+
+ if (absDest == NULL) {
+ return URI_ERROR_NULL;
+ }
+ URI_FUNC(ResetUri)(absDest);
+
+ if ((relSource == NULL) || (absBase == NULL)) {
+ return URI_ERROR_NULL;
+ }
+
+ /* absBase absolute? */
+ if (absBase->scheme.first == NULL) {
+ return URI_ERROR_ADDBASE_REL_BASE;
+ }
+
+ /* [00/32] -- A non-strict parser may ignore a scheme in the reference */
+ /* [00/32] -- if it is identical to the base URI's scheme. */
+ /* [00/32] if ((not strict) and (R.scheme == Base.scheme)) then */
+ relSourceHasScheme = (relSource->scheme.first != NULL) ? URI_TRUE : URI_FALSE;
+ if ((options & URI_RESOLVE_IDENTICAL_SCHEME_COMPAT)
+ && (absBase->scheme.first != NULL)
+ && (relSource->scheme.first != NULL)
+ && (0 == URI_FUNC(CompareRange)(&(absBase->scheme), &(relSource->scheme)))) {
+ /* [00/32] undefine(R.scheme); */
+ relSourceHasScheme = URI_FALSE;
+ /* [00/32] endif; */
+ }
+
+ /* [01/32] if defined(R.scheme) then */
+ if (relSourceHasScheme) {
+ /* [02/32] T.scheme = R.scheme; */
+ absDest->scheme = relSource->scheme;
+ /* [03/32] T.authority = R.authority; */
+ if (!URI_FUNC(CopyAuthority)(absDest, relSource, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [04/32] T.path = remove_dot_segments(R.path); */
+ if (!URI_FUNC(CopyPath)(absDest, relSource, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ if (!URI_FUNC(RemoveDotSegmentsAbsolute)(absDest, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [05/32] T.query = R.query; */
+ absDest->query = relSource->query;
+ /* [06/32] else */
+ } else {
+ /* [07/32] if defined(R.authority) then */
+ if (URI_FUNC(IsHostSet)(relSource)) {
+ /* [08/32] T.authority = R.authority; */
+ if (!URI_FUNC(CopyAuthority)(absDest, relSource, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [09/32] T.path = remove_dot_segments(R.path); */
+ if (!URI_FUNC(CopyPath)(absDest, relSource, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ if (!URI_FUNC(RemoveDotSegmentsAbsolute)(absDest, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [10/32] T.query = R.query; */
+ absDest->query = relSource->query;
+ /* [11/32] else */
+ } else {
+ /* [28/32] T.authority = Base.authority; */
+ if (!URI_FUNC(CopyAuthority)(absDest, absBase, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [12/32] if (R.path == "") then */
+ if (relSource->pathHead == NULL && !relSource->absolutePath) {
+ /* [13/32] T.path = Base.path; */
+ if (!URI_FUNC(CopyPath)(absDest, absBase, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [14/32] if defined(R.query) then */
+ if (relSource->query.first != NULL) {
+ /* [15/32] T.query = R.query; */
+ absDest->query = relSource->query;
+ /* [16/32] else */
+ } else {
+ /* [17/32] T.query = Base.query; */
+ absDest->query = absBase->query;
+ /* [18/32] endif; */
+ }
+ /* [19/32] else */
+ } else {
+ /* [20/32] if (R.path starts-with "/") then */
+ if (relSource->absolutePath) {
+ int res;
+ /* [21/32] T.path = remove_dot_segments(R.path); */
+ if (!URI_FUNC(CopyPath)(absDest, relSource, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ res = URI_FUNC(ResolveAbsolutePathFlag)(absDest, memory);
+ if (res != URI_SUCCESS) {
+ return res;
+ }
+ if (!URI_FUNC(RemoveDotSegmentsAbsolute)(absDest, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [22/32] else */
+ } else {
+ /* [23/32] T.path = merge(Base.path, R.path); */
+ if (!URI_FUNC(CopyPath)(absDest, absBase, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ if (!URI_FUNC(MergePath)(absDest, relSource, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [24/32] T.path = remove_dot_segments(T.path); */
+ if (!URI_FUNC(RemoveDotSegmentsAbsolute)(absDest, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+
+ if (!URI_FUNC(FixAmbiguity)(absDest, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [25/32] endif; */
+ }
+ /* [26/32] T.query = R.query; */
+ absDest->query = relSource->query;
+ /* [27/32] endif; */
+ }
+ URI_FUNC(FixEmptyTrailSegment)(absDest, memory);
+ /* [29/32] endif; */
+ }
+ /* [30/32] T.scheme = Base.scheme; */
+ absDest->scheme = absBase->scheme;
+ /* [31/32] endif; */
+ }
+ /* [32/32] T.fragment = R.fragment; */
+ absDest->fragment = relSource->fragment;
+
+ return URI_SUCCESS;
+
+}
+
+
+
+int URI_FUNC(AddBaseUri)(URI_TYPE(Uri) * absDest,
+ const URI_TYPE(Uri) * relSource, const URI_TYPE(Uri) * absBase) {
+ const UriResolutionOptions options = URI_RESOLVE_STRICTLY;
+ return URI_FUNC(AddBaseUriEx)(absDest, relSource, absBase, options);
+}
+
+
+
+int URI_FUNC(AddBaseUriEx)(URI_TYPE(Uri) * absDest,
+ const URI_TYPE(Uri) * relSource, const URI_TYPE(Uri) * absBase,
+ UriResolutionOptions options) {
+ return URI_FUNC(AddBaseUriExMm)(absDest, relSource, absBase, options, NULL);
+}
+
+
+
+int URI_FUNC(AddBaseUriExMm)(URI_TYPE(Uri) * absDest,
+ const URI_TYPE(Uri) * relSource, const URI_TYPE(Uri) * absBase,
+ UriResolutionOptions options, UriMemoryManager * memory) {
+ int res;
+
+ URI_CHECK_MEMORY_MANAGER(memory); /* may return */
+
+ res = URI_FUNC(AddBaseUriImpl)(absDest, relSource, absBase, options, memory);
+ if ((res != URI_SUCCESS) && (absDest != NULL)) {
+ URI_FUNC(FreeUriMembersMm)(absDest, memory);
+ }
+ return res;
+}
+
+
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/UriShorten.c b/src/arrow/cpp/src/arrow/vendored/uriparser/UriShorten.c
new file mode 100644
index 000000000..f00b05f59
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/UriShorten.c
@@ -0,0 +1,324 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2007, Weijia Song <songweijia@gmail.com>
+ * Copyright (C) 2007, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/* What encodings are enabled? */
+#include "UriDefsConfig.h"
+#if (!defined(URI_PASS_ANSI) && !defined(URI_PASS_UNICODE))
+/* Include SELF twice */
+# ifdef URI_ENABLE_ANSI
+# define URI_PASS_ANSI 1
+# include "UriShorten.c"
+# undef URI_PASS_ANSI
+# endif
+# ifdef URI_ENABLE_UNICODE
+# define URI_PASS_UNICODE 1
+# include "UriShorten.c"
+# undef URI_PASS_UNICODE
+# endif
+#else
+# ifdef URI_PASS_ANSI
+# include "UriDefsAnsi.h"
+# else
+# include "UriDefsUnicode.h"
+# include <wchar.h>
+# endif
+
+
+
+#ifndef URI_DOXYGEN
+# include "Uri.h"
+# include "UriCommon.h"
+# include "UriMemory.h"
+#endif
+
+
+
+static URI_INLINE UriBool URI_FUNC(AppendSegment)(URI_TYPE(Uri) * uri,
+ const URI_CHAR * first, const URI_CHAR * afterLast,
+ UriMemoryManager * memory) {
+ /* Create segment */
+ URI_TYPE(PathSegment) * segment = memory->malloc(memory, 1 * sizeof(URI_TYPE(PathSegment)));
+ if (segment == NULL) {
+ return URI_FALSE; /* Raises malloc error */
+ }
+ segment->next = NULL;
+ segment->text.first = first;
+ segment->text.afterLast = afterLast;
+
+ /* Put into chain */
+ if (uri->pathTail == NULL) {
+ uri->pathHead = segment;
+ } else {
+ uri->pathTail->next = segment;
+ }
+ uri->pathTail = segment;
+
+ return URI_TRUE;
+}
+
+
+
+static URI_INLINE UriBool URI_FUNC(EqualsAuthority)(const URI_TYPE(Uri) * first,
+ const URI_TYPE(Uri) * second) {
+ /* IPv4 */
+ if (first->hostData.ip4 != NULL) {
+ return ((second->hostData.ip4 != NULL)
+ && !memcmp(first->hostData.ip4->data,
+ second->hostData.ip4->data, 4)) ? URI_TRUE : URI_FALSE;
+ }
+
+ /* IPv6 */
+ if (first->hostData.ip6 != NULL) {
+ return ((second->hostData.ip6 != NULL)
+ && !memcmp(first->hostData.ip6->data,
+ second->hostData.ip6->data, 16)) ? URI_TRUE : URI_FALSE;
+ }
+
+ /* IPvFuture */
+ if (first->hostData.ipFuture.first != NULL) {
+ return ((second->hostData.ipFuture.first != NULL)
+ && !URI_FUNC(CompareRange)(&first->hostData.ipFuture,
+ &second->hostData.ipFuture)) ? URI_TRUE : URI_FALSE;
+ }
+
+ return !URI_FUNC(CompareRange)(&first->hostText, &second->hostText)
+ ? URI_TRUE : URI_FALSE;
+}
+
+
+
+static int URI_FUNC(RemoveBaseUriImpl)(URI_TYPE(Uri) * dest,
+ const URI_TYPE(Uri) * absSource,
+ const URI_TYPE(Uri) * absBase,
+ UriBool domainRootMode, UriMemoryManager * memory) {
+ if (dest == NULL) {
+ return URI_ERROR_NULL;
+ }
+ URI_FUNC(ResetUri)(dest);
+
+ if ((absSource == NULL) || (absBase == NULL)) {
+ return URI_ERROR_NULL;
+ }
+
+ /* absBase absolute? */
+ if (absBase->scheme.first == NULL) {
+ return URI_ERROR_REMOVEBASE_REL_BASE;
+ }
+
+ /* absSource absolute? */
+ if (absSource->scheme.first == NULL) {
+ return URI_ERROR_REMOVEBASE_REL_SOURCE;
+ }
+
+ /* [01/50] if (A.scheme != Base.scheme) then */
+ if (URI_FUNC(CompareRange)(&absSource->scheme, &absBase->scheme)) {
+ /* [02/50] T.scheme = A.scheme; */
+ dest->scheme = absSource->scheme;
+ /* [03/50] T.authority = A.authority; */
+ if (!URI_FUNC(CopyAuthority)(dest, absSource, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [04/50] T.path = A.path; */
+ if (!URI_FUNC(CopyPath)(dest, absSource, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [05/50] else */
+ } else {
+ /* [06/50] undef(T.scheme); */
+ /* NOOP */
+ /* [07/50] if (A.authority != Base.authority) then */
+ if (!URI_FUNC(EqualsAuthority)(absSource, absBase)) {
+ /* [08/50] T.authority = A.authority; */
+ if (!URI_FUNC(CopyAuthority)(dest, absSource, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [09/50] T.path = A.path; */
+ if (!URI_FUNC(CopyPath)(dest, absSource, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [10/50] else */
+ } else {
+ /* [11/50] if domainRootMode then */
+ if (domainRootMode == URI_TRUE) {
+ /* [12/50] undef(T.authority); */
+ /* NOOP */
+ /* [13/50] if (first(A.path) == "") then */
+ /* GROUPED */
+ /* [14/50] T.path = "/." + A.path; */
+ /* GROUPED */
+ /* [15/50] else */
+ /* GROUPED */
+ /* [16/50] T.path = A.path; */
+ /* GROUPED */
+ /* [17/50] endif; */
+ if (!URI_FUNC(CopyPath)(dest, absSource, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ dest->absolutePath = URI_TRUE;
+
+ if (!URI_FUNC(FixAmbiguity)(dest, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [18/50] else */
+ } else {
+ const URI_TYPE(PathSegment) * sourceSeg = absSource->pathHead;
+ const URI_TYPE(PathSegment) * baseSeg = absBase->pathHead;
+ /* [19/50] bool pathNaked = true; */
+ UriBool pathNaked = URI_TRUE;
+ /* [20/50] undef(last(Base.path)); */
+ /* NOOP */
+ /* [21/50] T.path = ""; */
+ dest->absolutePath = URI_FALSE;
+ /* [22/50] while (first(A.path) == first(Base.path)) do */
+ while ((sourceSeg != NULL) && (baseSeg != NULL)
+ && !URI_FUNC(CompareRange)(&sourceSeg->text, &baseSeg->text)
+ && !((sourceSeg->text.first == sourceSeg->text.afterLast)
+ && ((sourceSeg->next == NULL) != (baseSeg->next == NULL)))) {
+ /* [23/50] A.path++; */
+ sourceSeg = sourceSeg->next;
+ /* [24/50] Base.path++; */
+ baseSeg = baseSeg->next;
+ /* [25/50] endwhile; */
+ }
+ /* [26/50] while defined(first(Base.path)) do */
+ while ((baseSeg != NULL) && (baseSeg->next != NULL)) {
+ /* [27/50] Base.path++; */
+ baseSeg = baseSeg->next;
+ /* [28/50] T.path += "../"; */
+ if (!URI_FUNC(AppendSegment)(dest, URI_FUNC(ConstParent),
+ URI_FUNC(ConstParent) + 2, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [29/50] pathNaked = false; */
+ pathNaked = URI_FALSE;
+ /* [30/50] endwhile; */
+ }
+ /* [31/50] while defined(first(A.path)) do */
+ while (sourceSeg != NULL) {
+ /* [32/50] if pathNaked then */
+ if (pathNaked == URI_TRUE) {
+ /* [33/50] if (first(A.path) contains ":") then */
+ UriBool containsColon = URI_FALSE;
+ const URI_CHAR * ch = sourceSeg->text.first;
+ for (; ch < sourceSeg->text.afterLast; ch++) {
+ if (*ch == _UT(':')) {
+ containsColon = URI_TRUE;
+ break;
+ }
+ }
+
+ if (containsColon) {
+ /* [34/50] T.path += "./"; */
+ if (!URI_FUNC(AppendSegment)(dest, URI_FUNC(ConstPwd),
+ URI_FUNC(ConstPwd) + 1, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [35/50] elseif (first(A.path) == "") then */
+ } else if (sourceSeg->text.first == sourceSeg->text.afterLast) {
+ /* [36/50] T.path += "/."; */
+ if (!URI_FUNC(AppendSegment)(dest, URI_FUNC(ConstPwd),
+ URI_FUNC(ConstPwd) + 1, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [37/50] endif; */
+ }
+ /* [38/50] endif; */
+ }
+ /* [39/50] T.path += first(A.path); */
+ if (!URI_FUNC(AppendSegment)(dest, sourceSeg->text.first,
+ sourceSeg->text.afterLast, memory)) {
+ return URI_ERROR_MALLOC;
+ }
+ /* [40/50] pathNaked = false; */
+ pathNaked = URI_FALSE;
+ /* [41/50] A.path++; */
+ sourceSeg = sourceSeg->next;
+ /* [42/50] if defined(first(A.path)) then */
+ /* NOOP */
+ /* [43/50] T.path += + "/"; */
+ /* NOOP */
+ /* [44/50] endif; */
+ /* NOOP */
+ /* [45/50] endwhile; */
+ }
+ /* [46/50] endif; */
+ }
+ /* [47/50] endif; */
+ }
+ /* [48/50] endif; */
+ }
+ /* [49/50] T.query = A.query; */
+ dest->query = absSource->query;
+ /* [50/50] T.fragment = A.fragment; */
+ dest->fragment = absSource->fragment;
+
+ return URI_SUCCESS;
+}
+
+
+
+int URI_FUNC(RemoveBaseUri)(URI_TYPE(Uri) * dest,
+ const URI_TYPE(Uri) * absSource,
+ const URI_TYPE(Uri) * absBase,
+ UriBool domainRootMode) {
+ return URI_FUNC(RemoveBaseUriMm)(dest, absSource, absBase,
+ domainRootMode, NULL);
+}
+
+
+
+int URI_FUNC(RemoveBaseUriMm)(URI_TYPE(Uri) * dest,
+ const URI_TYPE(Uri) * absSource,
+ const URI_TYPE(Uri) * absBase,
+ UriBool domainRootMode, UriMemoryManager * memory) {
+ int res;
+
+ URI_CHECK_MEMORY_MANAGER(memory); /* may return */
+
+ res = URI_FUNC(RemoveBaseUriImpl)(dest, absSource,
+ absBase, domainRootMode, memory);
+ if ((res != URI_SUCCESS) && (dest != NULL)) {
+ URI_FUNC(FreeUriMembersMm)(dest, memory);
+ }
+ return res;
+}
+
+
+
+#endif
diff --git a/src/arrow/cpp/src/arrow/vendored/uriparser/config.h b/src/arrow/cpp/src/arrow/vendored/uriparser/config.h
new file mode 100644
index 000000000..97f0605da
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/uriparser/config.h
@@ -0,0 +1,47 @@
+/*
+ * uriparser - RFC 3986 URI parsing library
+ *
+ * Copyright (C) 2018, Sebastian Pipping <sebastian@pipping.org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * * Neither the name of the <ORGANIZATION> nor the names of its
+ * contributors may be used to endorse or promote products
+ * derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
+ * OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#if !defined(URI_CONFIG_H)
+# define URI_CONFIG_H 1
+
+#define PACKAGE_VERSION "0.9.3"
+
+// #define HAVE_WPRINTF
+// #define HAVE_REALLOCARRAY
+
+#endif /* !defined(URI_CONFIG_H) */
diff --git a/src/arrow/cpp/src/arrow/vendored/utfcpp/README.md b/src/arrow/cpp/src/arrow/vendored/utfcpp/README.md
new file mode 100644
index 000000000..c0abfd7d1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/utfcpp/README.md
@@ -0,0 +1,28 @@
+<!---
+ Boost Software License - Version 1.0 - August 17th, 2003
+
+ Permission is hereby granted, free of charge, to any person or organization
+ obtaining a copy of the software and accompanying documentation covered by
+ this license (the "Software") to use, reproduce, display, distribute,
+ execute, and transmit the Software, and to prepare derivative works of the
+ Software, and to permit third-parties to whom the Software is furnished to
+ do so, all subject to the following:
+
+ The copyright notices in the Software and this entire statement, including
+ the above license grant, this restriction and the following disclaimer,
+ must be included in all copies of the Software, in whole or in part, and
+ all derivative works of the Software, unless such copies or derivative
+ works are solely in the form of machine-executable object code generated by
+ a source language processor.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
+ SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
+ FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
+ ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ DEALINGS IN THE SOFTWARE.
+-->
+
+The files in this directory are vendored from utfcpp git tag v3.1.1
+(https://github.com/nemtrif/utfcpp).
diff --git a/src/arrow/cpp/src/arrow/vendored/utfcpp/checked.h b/src/arrow/cpp/src/arrow/vendored/utfcpp/checked.h
new file mode 100644
index 000000000..648636e46
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/utfcpp/checked.h
@@ -0,0 +1,333 @@
+// Copyright 2006-2016 Nemanja Trifunovic
+
+/*
+Permission is hereby granted, free of charge, to any person or organization
+obtaining a copy of the software and accompanying documentation covered by
+this license (the "Software") to use, reproduce, display, distribute,
+execute, and transmit the Software, and to prepare derivative works of the
+Software, and to permit third-parties to whom the Software is furnished to
+do so, all subject to the following:
+
+The copyright notices in the Software and this entire statement, including
+the above license grant, this restriction and the following disclaimer,
+must be included in all copies of the Software, in whole or in part, and
+all derivative works of the Software, unless such copies or derivative
+works are solely in the form of machine-executable object code generated by
+a source language processor.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
+SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
+FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
+ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+*/
+
+
+#ifndef UTF8_FOR_CPP_CHECKED_H_2675DCD0_9480_4c0c_B92A_CC14C027B731
+#define UTF8_FOR_CPP_CHECKED_H_2675DCD0_9480_4c0c_B92A_CC14C027B731
+
+#include "core.h"
+#include <stdexcept>
+
+namespace utf8
+{
+ // Base for the exceptions that may be thrown from the library
+ class exception : public ::std::exception {
+ };
+
+ // Exceptions that may be thrown from the library functions.
+ class invalid_code_point : public exception {
+ uint32_t cp;
+ public:
+ invalid_code_point(uint32_t codepoint) : cp(codepoint) {}
+ virtual const char* what() const NOEXCEPT OVERRIDE { return "Invalid code point"; }
+ uint32_t code_point() const {return cp;}
+ };
+
+ class invalid_utf8 : public exception {
+ uint8_t u8;
+ public:
+ invalid_utf8 (uint8_t u) : u8(u) {}
+ virtual const char* what() const NOEXCEPT OVERRIDE { return "Invalid UTF-8"; }
+ uint8_t utf8_octet() const {return u8;}
+ };
+
+ class invalid_utf16 : public exception {
+ uint16_t u16;
+ public:
+ invalid_utf16 (uint16_t u) : u16(u) {}
+ virtual const char* what() const NOEXCEPT OVERRIDE { return "Invalid UTF-16"; }
+ uint16_t utf16_word() const {return u16;}
+ };
+
+ class not_enough_room : public exception {
+ public:
+ virtual const char* what() const NOEXCEPT OVERRIDE { return "Not enough space"; }
+ };
+
+ /// The library API - functions intended to be called by the users
+
+ template <typename octet_iterator>
+ octet_iterator append(uint32_t cp, octet_iterator result)
+ {
+ if (!utf8::internal::is_code_point_valid(cp))
+ throw invalid_code_point(cp);
+
+ if (cp < 0x80) // one octet
+ *(result++) = static_cast<uint8_t>(cp);
+ else if (cp < 0x800) { // two octets
+ *(result++) = static_cast<uint8_t>((cp >> 6) | 0xc0);
+ *(result++) = static_cast<uint8_t>((cp & 0x3f) | 0x80);
+ }
+ else if (cp < 0x10000) { // three octets
+ *(result++) = static_cast<uint8_t>((cp >> 12) | 0xe0);
+ *(result++) = static_cast<uint8_t>(((cp >> 6) & 0x3f) | 0x80);
+ *(result++) = static_cast<uint8_t>((cp & 0x3f) | 0x80);
+ }
+ else { // four octets
+ *(result++) = static_cast<uint8_t>((cp >> 18) | 0xf0);
+ *(result++) = static_cast<uint8_t>(((cp >> 12) & 0x3f) | 0x80);
+ *(result++) = static_cast<uint8_t>(((cp >> 6) & 0x3f) | 0x80);
+ *(result++) = static_cast<uint8_t>((cp & 0x3f) | 0x80);
+ }
+ return result;
+ }
+
+ template <typename octet_iterator, typename output_iterator>
+ output_iterator replace_invalid(octet_iterator start, octet_iterator end, output_iterator out, uint32_t replacement)
+ {
+ while (start != end) {
+ octet_iterator sequence_start = start;
+ internal::utf_error err_code = utf8::internal::validate_next(start, end);
+ switch (err_code) {
+ case internal::UTF8_OK :
+ for (octet_iterator it = sequence_start; it != start; ++it)
+ *out++ = *it;
+ break;
+ case internal::NOT_ENOUGH_ROOM:
+ out = utf8::append (replacement, out);
+ start = end;
+ break;
+ case internal::INVALID_LEAD:
+ out = utf8::append (replacement, out);
+ ++start;
+ break;
+ case internal::INCOMPLETE_SEQUENCE:
+ case internal::OVERLONG_SEQUENCE:
+ case internal::INVALID_CODE_POINT:
+ out = utf8::append (replacement, out);
+ ++start;
+ // just one replacement mark for the sequence
+ while (start != end && utf8::internal::is_trail(*start))
+ ++start;
+ break;
+ }
+ }
+ return out;
+ }
+
+ template <typename octet_iterator, typename output_iterator>
+ inline output_iterator replace_invalid(octet_iterator start, octet_iterator end, output_iterator out)
+ {
+ static const uint32_t replacement_marker = utf8::internal::mask16(0xfffd);
+ return utf8::replace_invalid(start, end, out, replacement_marker);
+ }
+
+ template <typename octet_iterator>
+ uint32_t next(octet_iterator& it, octet_iterator end)
+ {
+ uint32_t cp = 0;
+ internal::utf_error err_code = utf8::internal::validate_next(it, end, cp);
+ switch (err_code) {
+ case internal::UTF8_OK :
+ break;
+ case internal::NOT_ENOUGH_ROOM :
+ throw not_enough_room();
+ case internal::INVALID_LEAD :
+ case internal::INCOMPLETE_SEQUENCE :
+ case internal::OVERLONG_SEQUENCE :
+ throw invalid_utf8(*it);
+ case internal::INVALID_CODE_POINT :
+ throw invalid_code_point(cp);
+ }
+ return cp;
+ }
+
+ template <typename octet_iterator>
+ uint32_t peek_next(octet_iterator it, octet_iterator end)
+ {
+ return utf8::next(it, end);
+ }
+
+ template <typename octet_iterator>
+ uint32_t prior(octet_iterator& it, octet_iterator start)
+ {
+ // can't do much if it == start
+ if (it == start)
+ throw not_enough_room();
+
+ octet_iterator end = it;
+ // Go back until we hit either a lead octet or start
+ while (utf8::internal::is_trail(*(--it)))
+ if (it == start)
+ throw invalid_utf8(*it); // error - no lead byte in the sequence
+ return utf8::peek_next(it, end);
+ }
+
+ template <typename octet_iterator, typename distance_type>
+ void advance (octet_iterator& it, distance_type n, octet_iterator end)
+ {
+ const distance_type zero(0);
+ if (n < zero) {
+ // backward
+ for (distance_type i = n; i < zero; ++i)
+ utf8::prior(it, end);
+ } else {
+ // forward
+ for (distance_type i = zero; i < n; ++i)
+ utf8::next(it, end);
+ }
+ }
+
+ template <typename octet_iterator>
+ typename std::iterator_traits<octet_iterator>::difference_type
+ distance (octet_iterator first, octet_iterator last)
+ {
+ typename std::iterator_traits<octet_iterator>::difference_type dist;
+ for (dist = 0; first < last; ++dist)
+ utf8::next(first, last);
+ return dist;
+ }
+
+ template <typename u16bit_iterator, typename octet_iterator>
+ octet_iterator utf16to8 (u16bit_iterator start, u16bit_iterator end, octet_iterator result)
+ {
+ while (start != end) {
+ uint32_t cp = utf8::internal::mask16(*start++);
+ // Take care of surrogate pairs first
+ if (utf8::internal::is_lead_surrogate(cp)) {
+ if (start != end) {
+ uint32_t trail_surrogate = utf8::internal::mask16(*start++);
+ if (utf8::internal::is_trail_surrogate(trail_surrogate))
+ cp = (cp << 10) + trail_surrogate + internal::SURROGATE_OFFSET;
+ else
+ throw invalid_utf16(static_cast<uint16_t>(trail_surrogate));
+ }
+ else
+ throw invalid_utf16(static_cast<uint16_t>(cp));
+
+ }
+ // Lone trail surrogate
+ else if (utf8::internal::is_trail_surrogate(cp))
+ throw invalid_utf16(static_cast<uint16_t>(cp));
+
+ result = utf8::append(cp, result);
+ }
+ return result;
+ }
+
+ template <typename u16bit_iterator, typename octet_iterator>
+ u16bit_iterator utf8to16 (octet_iterator start, octet_iterator end, u16bit_iterator result)
+ {
+ while (start < end) {
+ uint32_t cp = utf8::next(start, end);
+ if (cp > 0xffff) { //make a surrogate pair
+ *result++ = static_cast<uint16_t>((cp >> 10) + internal::LEAD_OFFSET);
+ *result++ = static_cast<uint16_t>((cp & 0x3ff) + internal::TRAIL_SURROGATE_MIN);
+ }
+ else
+ *result++ = static_cast<uint16_t>(cp);
+ }
+ return result;
+ }
+
+ template <typename octet_iterator, typename u32bit_iterator>
+ octet_iterator utf32to8 (u32bit_iterator start, u32bit_iterator end, octet_iterator result)
+ {
+ while (start != end)
+ result = utf8::append(*(start++), result);
+
+ return result;
+ }
+
+ template <typename octet_iterator, typename u32bit_iterator>
+ u32bit_iterator utf8to32 (octet_iterator start, octet_iterator end, u32bit_iterator result)
+ {
+ while (start < end)
+ (*result++) = utf8::next(start, end);
+
+ return result;
+ }
+
+ // The iterator class
+ template <typename octet_iterator>
+ class iterator {
+ octet_iterator it;
+ octet_iterator range_start;
+ octet_iterator range_end;
+ public:
+ typedef uint32_t value_type;
+ typedef uint32_t* pointer;
+ typedef uint32_t& reference;
+ typedef std::ptrdiff_t difference_type;
+ typedef std::bidirectional_iterator_tag iterator_category;
+ iterator () {}
+ explicit iterator (const octet_iterator& octet_it,
+ const octet_iterator& rangestart,
+ const octet_iterator& rangeend) :
+ it(octet_it), range_start(rangestart), range_end(rangeend)
+ {
+ if (it < range_start || it > range_end)
+ throw std::out_of_range("Invalid utf-8 iterator position");
+ }
+ // the default "big three" are OK
+ octet_iterator base () const { return it; }
+ uint32_t operator * () const
+ {
+ octet_iterator temp = it;
+ return utf8::next(temp, range_end);
+ }
+ bool operator == (const iterator& rhs) const
+ {
+ if (range_start != rhs.range_start || range_end != rhs.range_end)
+ throw std::logic_error("Comparing utf-8 iterators defined with different ranges");
+ return (it == rhs.it);
+ }
+ bool operator != (const iterator& rhs) const
+ {
+ return !(operator == (rhs));
+ }
+ iterator& operator ++ ()
+ {
+ utf8::next(it, range_end);
+ return *this;
+ }
+ iterator operator ++ (int)
+ {
+ iterator temp = *this;
+ utf8::next(it, range_end);
+ return temp;
+ }
+ iterator& operator -- ()
+ {
+ utf8::prior(it, range_start);
+ return *this;
+ }
+ iterator operator -- (int)
+ {
+ iterator temp = *this;
+ utf8::prior(it, range_start);
+ return temp;
+ }
+ }; // class iterator
+
+} // namespace utf8
+
+#if UTF_CPP_CPLUSPLUS >= 201103L // C++ 11 or later
+#include "cpp11.h"
+#endif // C++ 11 or later
+
+#endif //header guard
+
diff --git a/src/arrow/cpp/src/arrow/vendored/utfcpp/core.h b/src/arrow/cpp/src/arrow/vendored/utfcpp/core.h
new file mode 100644
index 000000000..244e89231
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/utfcpp/core.h
@@ -0,0 +1,338 @@
+// Copyright 2006 Nemanja Trifunovic
+
+/*
+Permission is hereby granted, free of charge, to any person or organization
+obtaining a copy of the software and accompanying documentation covered by
+this license (the "Software") to use, reproduce, display, distribute,
+execute, and transmit the Software, and to prepare derivative works of the
+Software, and to permit third-parties to whom the Software is furnished to
+do so, all subject to the following:
+
+The copyright notices in the Software and this entire statement, including
+the above license grant, this restriction and the following disclaimer,
+must be included in all copies of the Software, in whole or in part, and
+all derivative works of the Software, unless such copies or derivative
+works are solely in the form of machine-executable object code generated by
+a source language processor.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
+SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
+FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
+ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+*/
+
+
+#ifndef UTF8_FOR_CPP_CORE_H_2675DCD0_9480_4c0c_B92A_CC14C027B731
+#define UTF8_FOR_CPP_CORE_H_2675DCD0_9480_4c0c_B92A_CC14C027B731
+
+#include <iterator>
+
+// Determine the C++ standard version.
+// If the user defines UTF_CPP_CPLUSPLUS, use that.
+// Otherwise, trust the unreliable predefined macro __cplusplus
+
+#if !defined UTF_CPP_CPLUSPLUS
+ #define UTF_CPP_CPLUSPLUS __cplusplus
+#endif
+
+#if UTF_CPP_CPLUSPLUS >= 201103L // C++ 11 or later
+ #define OVERRIDE override
+ #define NOEXCEPT noexcept
+#else // C++ 98/03
+ #define OVERRIDE
+ #define NOEXCEPT throw()
+#endif // C++ 11 or later
+
+
+namespace utf8
+{
+ // The typedefs for 8-bit, 16-bit and 32-bit unsigned integers
+ // You may need to change them to match your system.
+ // These typedefs have the same names as ones from cstdint, or boost/cstdint
+ typedef unsigned char uint8_t;
+ typedef unsigned short uint16_t;
+ typedef unsigned int uint32_t;
+
+// Helper code - not intended to be directly called by the library users. May be changed at any time
+namespace internal
+{
+ // Unicode constants
+ // Leading (high) surrogates: 0xd800 - 0xdbff
+ // Trailing (low) surrogates: 0xdc00 - 0xdfff
+ const uint16_t LEAD_SURROGATE_MIN = 0xd800u;
+ const uint16_t LEAD_SURROGATE_MAX = 0xdbffu;
+ const uint16_t TRAIL_SURROGATE_MIN = 0xdc00u;
+ const uint16_t TRAIL_SURROGATE_MAX = 0xdfffu;
+ const uint16_t LEAD_OFFSET = 0xd7c0u; // LEAD_SURROGATE_MIN - (0x10000 >> 10)
+ const uint32_t SURROGATE_OFFSET = 0xfca02400u; // 0x10000u - (LEAD_SURROGATE_MIN << 10) - TRAIL_SURROGATE_MIN
+
+ // Maximum valid value for a Unicode code point
+ const uint32_t CODE_POINT_MAX = 0x0010ffffu;
+
+ template<typename octet_type>
+ inline uint8_t mask8(octet_type oc)
+ {
+ return static_cast<uint8_t>(0xff & oc);
+ }
+ template<typename u16_type>
+ inline uint16_t mask16(u16_type oc)
+ {
+ return static_cast<uint16_t>(0xffff & oc);
+ }
+ template<typename octet_type>
+ inline bool is_trail(octet_type oc)
+ {
+ return ((utf8::internal::mask8(oc) >> 6) == 0x2);
+ }
+
+ template <typename u16>
+ inline bool is_lead_surrogate(u16 cp)
+ {
+ return (cp >= LEAD_SURROGATE_MIN && cp <= LEAD_SURROGATE_MAX);
+ }
+
+ template <typename u16>
+ inline bool is_trail_surrogate(u16 cp)
+ {
+ return (cp >= TRAIL_SURROGATE_MIN && cp <= TRAIL_SURROGATE_MAX);
+ }
+
+ template <typename u16>
+ inline bool is_surrogate(u16 cp)
+ {
+ return (cp >= LEAD_SURROGATE_MIN && cp <= TRAIL_SURROGATE_MAX);
+ }
+
+ template <typename u32>
+ inline bool is_code_point_valid(u32 cp)
+ {
+ return (cp <= CODE_POINT_MAX && !utf8::internal::is_surrogate(cp));
+ }
+
+ template <typename octet_iterator>
+ inline typename std::iterator_traits<octet_iterator>::difference_type
+ sequence_length(octet_iterator lead_it)
+ {
+ uint8_t lead = utf8::internal::mask8(*lead_it);
+ if (lead < 0x80)
+ return 1;
+ else if ((lead >> 5) == 0x6)
+ return 2;
+ else if ((lead >> 4) == 0xe)
+ return 3;
+ else if ((lead >> 3) == 0x1e)
+ return 4;
+ else
+ return 0;
+ }
+
+ template <typename octet_difference_type>
+ inline bool is_overlong_sequence(uint32_t cp, octet_difference_type length)
+ {
+ if (cp < 0x80) {
+ if (length != 1)
+ return true;
+ }
+ else if (cp < 0x800) {
+ if (length != 2)
+ return true;
+ }
+ else if (cp < 0x10000) {
+ if (length != 3)
+ return true;
+ }
+
+ return false;
+ }
+
+ enum utf_error {UTF8_OK, NOT_ENOUGH_ROOM, INVALID_LEAD, INCOMPLETE_SEQUENCE, OVERLONG_SEQUENCE, INVALID_CODE_POINT};
+
+ /// Helper for get_sequence_x
+ template <typename octet_iterator>
+ utf_error increase_safely(octet_iterator& it, octet_iterator end)
+ {
+ if (++it == end)
+ return NOT_ENOUGH_ROOM;
+
+ if (!utf8::internal::is_trail(*it))
+ return INCOMPLETE_SEQUENCE;
+
+ return UTF8_OK;
+ }
+
+ #define UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR(IT, END) {utf_error ret = increase_safely(IT, END); if (ret != UTF8_OK) return ret;}
+
+ /// get_sequence_x functions decode utf-8 sequences of the length x
+ template <typename octet_iterator>
+ utf_error get_sequence_1(octet_iterator& it, octet_iterator end, uint32_t& code_point)
+ {
+ if (it == end)
+ return NOT_ENOUGH_ROOM;
+
+ code_point = utf8::internal::mask8(*it);
+
+ return UTF8_OK;
+ }
+
+ template <typename octet_iterator>
+ utf_error get_sequence_2(octet_iterator& it, octet_iterator end, uint32_t& code_point)
+ {
+ if (it == end)
+ return NOT_ENOUGH_ROOM;
+
+ code_point = utf8::internal::mask8(*it);
+
+ UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR(it, end)
+
+ code_point = ((code_point << 6) & 0x7ff) + ((*it) & 0x3f);
+
+ return UTF8_OK;
+ }
+
+ template <typename octet_iterator>
+ utf_error get_sequence_3(octet_iterator& it, octet_iterator end, uint32_t& code_point)
+ {
+ if (it == end)
+ return NOT_ENOUGH_ROOM;
+
+ code_point = utf8::internal::mask8(*it);
+
+ UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR(it, end)
+
+ code_point = ((code_point << 12) & 0xffff) + ((utf8::internal::mask8(*it) << 6) & 0xfff);
+
+ UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR(it, end)
+
+ code_point += (*it) & 0x3f;
+
+ return UTF8_OK;
+ }
+
+ template <typename octet_iterator>
+ utf_error get_sequence_4(octet_iterator& it, octet_iterator end, uint32_t& code_point)
+ {
+ if (it == end)
+ return NOT_ENOUGH_ROOM;
+
+ code_point = utf8::internal::mask8(*it);
+
+ UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR(it, end)
+
+ code_point = ((code_point << 18) & 0x1fffff) + ((utf8::internal::mask8(*it) << 12) & 0x3ffff);
+
+ UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR(it, end)
+
+ code_point += (utf8::internal::mask8(*it) << 6) & 0xfff;
+
+ UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR(it, end)
+
+ code_point += (*it) & 0x3f;
+
+ return UTF8_OK;
+ }
+
+ #undef UTF8_CPP_INCREASE_AND_RETURN_ON_ERROR
+
+ template <typename octet_iterator>
+ utf_error validate_next(octet_iterator& it, octet_iterator end, uint32_t& code_point)
+ {
+ if (it == end)
+ return NOT_ENOUGH_ROOM;
+
+ // Save the original value of it so we can go back in case of failure
+ // Of course, it does not make much sense with i.e. stream iterators
+ octet_iterator original_it = it;
+
+ uint32_t cp = 0;
+ // Determine the sequence length based on the lead octet
+ typedef typename std::iterator_traits<octet_iterator>::difference_type octet_difference_type;
+ const octet_difference_type length = utf8::internal::sequence_length(it);
+
+ // Get trail octets and calculate the code point
+ utf_error err = UTF8_OK;
+ switch (length) {
+ case 0:
+ return INVALID_LEAD;
+ case 1:
+ err = utf8::internal::get_sequence_1(it, end, cp);
+ break;
+ case 2:
+ err = utf8::internal::get_sequence_2(it, end, cp);
+ break;
+ case 3:
+ err = utf8::internal::get_sequence_3(it, end, cp);
+ break;
+ case 4:
+ err = utf8::internal::get_sequence_4(it, end, cp);
+ break;
+ }
+
+ if (err == UTF8_OK) {
+ // Decoding succeeded. Now, security checks...
+ if (utf8::internal::is_code_point_valid(cp)) {
+ if (!utf8::internal::is_overlong_sequence(cp, length)){
+ // Passed! Return here.
+ code_point = cp;
+ ++it;
+ return UTF8_OK;
+ }
+ else
+ err = OVERLONG_SEQUENCE;
+ }
+ else
+ err = INVALID_CODE_POINT;
+ }
+
+ // Failure branch - restore the original value of the iterator
+ it = original_it;
+ return err;
+ }
+
+ template <typename octet_iterator>
+ inline utf_error validate_next(octet_iterator& it, octet_iterator end) {
+ uint32_t ignored;
+ return utf8::internal::validate_next(it, end, ignored);
+ }
+
+} // namespace internal
+
+ /// The library API - functions intended to be called by the users
+
+ // Byte order mark
+ const uint8_t bom[] = {0xef, 0xbb, 0xbf};
+
+ template <typename octet_iterator>
+ octet_iterator find_invalid(octet_iterator start, octet_iterator end)
+ {
+ octet_iterator result = start;
+ while (result != end) {
+ utf8::internal::utf_error err_code = utf8::internal::validate_next(result, end);
+ if (err_code != internal::UTF8_OK)
+ return result;
+ }
+ return result;
+ }
+
+ template <typename octet_iterator>
+ inline bool is_valid(octet_iterator start, octet_iterator end)
+ {
+ return (utf8::find_invalid(start, end) == end);
+ }
+
+ template <typename octet_iterator>
+ inline bool starts_with_bom (octet_iterator it, octet_iterator end)
+ {
+ return (
+ ((it != end) && (utf8::internal::mask8(*it++)) == bom[0]) &&
+ ((it != end) && (utf8::internal::mask8(*it++)) == bom[1]) &&
+ ((it != end) && (utf8::internal::mask8(*it)) == bom[2])
+ );
+ }
+} // namespace utf8
+
+#endif // header guard
+
+
diff --git a/src/arrow/cpp/src/arrow/vendored/utfcpp/cpp11.h b/src/arrow/cpp/src/arrow/vendored/utfcpp/cpp11.h
new file mode 100644
index 000000000..d93961b04
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/utfcpp/cpp11.h
@@ -0,0 +1,103 @@
+// Copyright 2018 Nemanja Trifunovic
+
+/*
+Permission is hereby granted, free of charge, to any person or organization
+obtaining a copy of the software and accompanying documentation covered by
+this license (the "Software") to use, reproduce, display, distribute,
+execute, and transmit the Software, and to prepare derivative works of the
+Software, and to permit third-parties to whom the Software is furnished to
+do so, all subject to the following:
+
+The copyright notices in the Software and this entire statement, including
+the above license grant, this restriction and the following disclaimer,
+must be included in all copies of the Software, in whole or in part, and
+all derivative works of the Software, unless such copies or derivative
+works are solely in the form of machine-executable object code generated by
+a source language processor.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
+SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
+FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
+ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+*/
+
+
+#ifndef UTF8_FOR_CPP_a184c22c_d012_11e8_a8d5_f2801f1b9fd1
+#define UTF8_FOR_CPP_a184c22c_d012_11e8_a8d5_f2801f1b9fd1
+
+#include "checked.h"
+#include <string>
+
+namespace utf8
+{
+
+ inline void append(char32_t cp, std::string& s)
+ {
+ append(uint32_t(cp), std::back_inserter(s));
+ }
+
+ inline std::string utf16to8(const std::u16string& s)
+ {
+ std::string result;
+ utf16to8(s.begin(), s.end(), std::back_inserter(result));
+ return result;
+ }
+
+ inline std::u16string utf8to16(const std::string& s)
+ {
+ std::u16string result;
+ utf8to16(s.begin(), s.end(), std::back_inserter(result));
+ return result;
+ }
+
+ inline std::string utf32to8(const std::u32string& s)
+ {
+ std::string result;
+ utf32to8(s.begin(), s.end(), std::back_inserter(result));
+ return result;
+ }
+
+ inline std::u32string utf8to32(const std::string& s)
+ {
+ std::u32string result;
+ utf8to32(s.begin(), s.end(), std::back_inserter(result));
+ return result;
+ }
+
+ inline std::size_t find_invalid(const std::string& s)
+ {
+ std::string::const_iterator invalid = find_invalid(s.begin(), s.end());
+ return (invalid == s.end()) ? std::string::npos : (invalid - s.begin());
+ }
+
+ inline bool is_valid(const std::string& s)
+ {
+ return is_valid(s.begin(), s.end());
+ }
+
+ inline std::string replace_invalid(const std::string& s, char32_t replacement)
+ {
+ std::string result;
+ replace_invalid(s.begin(), s.end(), std::back_inserter(result), replacement);
+ return result;
+ }
+
+ inline std::string replace_invalid(const std::string& s)
+ {
+ std::string result;
+ replace_invalid(s.begin(), s.end(), std::back_inserter(result));
+ return result;
+ }
+
+ inline bool starts_with_bom(const std::string& s)
+ {
+ return starts_with_bom(s.begin(), s.end());
+ }
+
+} // namespace utf8
+
+#endif // header guard
+
diff --git a/src/arrow/cpp/src/arrow/vendored/xxhash.h b/src/arrow/cpp/src/arrow/vendored/xxhash.h
new file mode 100644
index 000000000..a33cdf861
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/xxhash.h
@@ -0,0 +1,18 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/vendored/xxhash/xxhash.h"
diff --git a/src/arrow/cpp/src/arrow/vendored/xxhash/README.md b/src/arrow/cpp/src/arrow/vendored/xxhash/README.md
new file mode 100644
index 000000000..6f942ede1
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/xxhash/README.md
@@ -0,0 +1,22 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+The files in this directory are vendored from xxHash git tag v0.8.0
+(https://github.com/Cyan4973/xxHash).
+Includes https://github.com/Cyan4973/xxHash/pull/502 for Solaris compatibility \ No newline at end of file
diff --git a/src/arrow/cpp/src/arrow/vendored/xxhash/xxhash.c b/src/arrow/cpp/src/arrow/vendored/xxhash/xxhash.c
new file mode 100644
index 000000000..0fae88c5d
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/xxhash/xxhash.c
@@ -0,0 +1,43 @@
+/*
+ * xxHash - Extremely Fast Hash algorithm
+ * Copyright (C) 2012-2020 Yann Collet
+ *
+ * BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php)
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * You can contact the author at:
+ * - xxHash homepage: https://www.xxhash.com
+ * - xxHash source repository: https://github.com/Cyan4973/xxHash
+ */
+
+
+/*
+ * xxhash.c instantiates functions defined in xxhash.h
+ */
+
+#define XXH_STATIC_LINKING_ONLY /* access advanced declarations */
+#define XXH_IMPLEMENTATION /* access definitions */
+
+#include "xxhash.h"
diff --git a/src/arrow/cpp/src/arrow/vendored/xxhash/xxhash.h b/src/arrow/cpp/src/arrow/vendored/xxhash/xxhash.h
new file mode 100644
index 000000000..99b2b4b38
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/vendored/xxhash/xxhash.h
@@ -0,0 +1,4769 @@
+/*
+ * xxHash - Extremely Fast Hash algorithm
+ * Header File
+ * Copyright (C) 2012-2020 Yann Collet
+ *
+ * BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php)
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * You can contact the author at:
+ * - xxHash homepage: https://www.xxhash.com
+ * - xxHash source repository: https://github.com/Cyan4973/xxHash
+ */
+
+/* TODO: update */
+/* Notice extracted from xxHash homepage:
+
+xxHash is an extremely fast hash algorithm, running at RAM speed limits.
+It also successfully passes all tests from the SMHasher suite.
+
+Comparison (single thread, Windows Seven 32 bits, using SMHasher on a Core 2 Duo @3GHz)
+
+Name Speed Q.Score Author
+xxHash 5.4 GB/s 10
+CrapWow 3.2 GB/s 2 Andrew
+MumurHash 3a 2.7 GB/s 10 Austin Appleby
+SpookyHash 2.0 GB/s 10 Bob Jenkins
+SBox 1.4 GB/s 9 Bret Mulvey
+Lookup3 1.2 GB/s 9 Bob Jenkins
+SuperFastHash 1.2 GB/s 1 Paul Hsieh
+CityHash64 1.05 GB/s 10 Pike & Alakuijala
+FNV 0.55 GB/s 5 Fowler, Noll, Vo
+CRC32 0.43 GB/s 9
+MD5-32 0.33 GB/s 10 Ronald L. Rivest
+SHA1-32 0.28 GB/s 10
+
+Q.Score is a measure of quality of the hash function.
+It depends on successfully passing SMHasher test set.
+10 is a perfect score.
+
+Note: SMHasher's CRC32 implementation is not the fastest one.
+Other speed-oriented implementations can be faster,
+especially in combination with PCLMUL instruction:
+https://fastcompression.blogspot.com/2019/03/presenting-xxh3.html?showComment=1552696407071#c3490092340461170735
+
+A 64-bit version, named XXH64, is available since r35.
+It offers much better speed, but for 64-bit applications only.
+Name Speed on 64 bits Speed on 32 bits
+XXH64 13.8 GB/s 1.9 GB/s
+XXH32 6.8 GB/s 6.0 GB/s
+*/
+
+#if defined (__cplusplus)
+extern "C" {
+#endif
+
+/* ****************************
+ * INLINE mode
+ ******************************/
+/*!
+ * XXH_INLINE_ALL (and XXH_PRIVATE_API)
+ * Use these build macros to inline xxhash into the target unit.
+ * Inlining improves performance on small inputs, especially when the length is
+ * expressed as a compile-time constant:
+ *
+ * https://fastcompression.blogspot.com/2018/03/xxhash-for-small-keys-impressive-power.html
+ *
+ * It also keeps xxHash symbols private to the unit, so they are not exported.
+ *
+ * Usage:
+ * #define XXH_INLINE_ALL
+ * #include "xxhash.h"
+ *
+ * Do not compile and link xxhash.o as a separate object, as it is not useful.
+ */
+#if (defined(XXH_INLINE_ALL) || defined(XXH_PRIVATE_API)) \
+ && !defined(XXH_INLINE_ALL_31684351384)
+ /* this section should be traversed only once */
+# define XXH_INLINE_ALL_31684351384
+ /* give access to the advanced API, required to compile implementations */
+# undef XXH_STATIC_LINKING_ONLY /* avoid macro redef */
+# define XXH_STATIC_LINKING_ONLY
+ /* make all functions private */
+# undef XXH_PUBLIC_API
+# if defined(__GNUC__)
+# define XXH_PUBLIC_API static __inline __attribute__((unused))
+# elif defined (__cplusplus) || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */)
+# define XXH_PUBLIC_API static inline
+# elif defined(_MSC_VER)
+# define XXH_PUBLIC_API static __inline
+# else
+ /* note: this version may generate warnings for unused static functions */
+# define XXH_PUBLIC_API static
+# endif
+
+ /*
+ * This part deals with the special case where a unit wants to inline xxHash,
+ * but "xxhash.h" has previously been included without XXH_INLINE_ALL, such
+ * as part of some previously included *.h header file.
+ * Without further action, the new include would just be ignored,
+ * and functions would effectively _not_ be inlined (silent failure).
+ * The following macros solve this situation by prefixing all inlined names,
+ * avoiding naming collision with previous inclusions.
+ */
+# ifdef XXH_NAMESPACE
+# error "XXH_INLINE_ALL with XXH_NAMESPACE is not supported"
+ /*
+ * Note: Alternative: #undef all symbols (it's a pretty large list).
+ * Without #error: it compiles, but functions are actually not inlined.
+ */
+# endif
+# define XXH_NAMESPACE XXH_INLINE_
+ /*
+ * Some identifiers (enums, type names) are not symbols, but they must
+ * still be renamed to avoid redeclaration.
+ * Alternative solution: do not redeclare them.
+ * However, this requires some #ifdefs, and is a more dispersed action.
+ * Meanwhile, renaming can be achieved in a single block
+ */
+# define XXH_IPREF(Id) XXH_INLINE_ ## Id
+# define XXH_OK XXH_IPREF(XXH_OK)
+# define XXH_ERROR XXH_IPREF(XXH_ERROR)
+# define XXH_errorcode XXH_IPREF(XXH_errorcode)
+# define XXH32_canonical_t XXH_IPREF(XXH32_canonical_t)
+# define XXH64_canonical_t XXH_IPREF(XXH64_canonical_t)
+# define XXH128_canonical_t XXH_IPREF(XXH128_canonical_t)
+# define XXH32_state_s XXH_IPREF(XXH32_state_s)
+# define XXH32_state_t XXH_IPREF(XXH32_state_t)
+# define XXH64_state_s XXH_IPREF(XXH64_state_s)
+# define XXH64_state_t XXH_IPREF(XXH64_state_t)
+# define XXH3_state_s XXH_IPREF(XXH3_state_s)
+# define XXH3_state_t XXH_IPREF(XXH3_state_t)
+# define XXH128_hash_t XXH_IPREF(XXH128_hash_t)
+ /* Ensure the header is parsed again, even if it was previously included */
+# undef XXHASH_H_5627135585666179
+# undef XXHASH_H_STATIC_13879238742
+#endif /* XXH_INLINE_ALL || XXH_PRIVATE_API */
+
+
+
+/* ****************************************************************
+ * Stable API
+ *****************************************************************/
+#ifndef XXHASH_H_5627135585666179
+#define XXHASH_H_5627135585666179 1
+
+/* specific declaration modes for Windows */
+#if !defined(XXH_INLINE_ALL) && !defined(XXH_PRIVATE_API)
+# if defined(WIN32) && defined(_MSC_VER) && (defined(XXH_IMPORT) || defined(XXH_EXPORT))
+# ifdef XXH_EXPORT
+# define XXH_PUBLIC_API __declspec(dllexport)
+# elif XXH_IMPORT
+# define XXH_PUBLIC_API __declspec(dllimport)
+# endif
+# else
+# define XXH_PUBLIC_API /* do nothing */
+# endif
+#endif
+
+/*!
+ * XXH_NAMESPACE, aka Namespace Emulation:
+ *
+ * If you want to include _and expose_ xxHash functions from within your own
+ * library, but also want to avoid symbol collisions with other libraries which
+ * may also include xxHash, you can use XXH_NAMESPACE to automatically prefix
+ * any public symbol from xxhash library with the value of XXH_NAMESPACE
+ * (therefore, avoid empty or numeric values).
+ *
+ * Note that no change is required within the calling program as long as it
+ * includes `xxhash.h`: Regular symbol names will be automatically translated
+ * by this header.
+ */
+#ifdef XXH_NAMESPACE
+# define XXH_CAT(A,B) A##B
+# define XXH_NAME2(A,B) XXH_CAT(A,B)
+# define XXH_versionNumber XXH_NAME2(XXH_NAMESPACE, XXH_versionNumber)
+/* XXH32 */
+# define XXH32 XXH_NAME2(XXH_NAMESPACE, XXH32)
+# define XXH32_createState XXH_NAME2(XXH_NAMESPACE, XXH32_createState)
+# define XXH32_freeState XXH_NAME2(XXH_NAMESPACE, XXH32_freeState)
+# define XXH32_reset XXH_NAME2(XXH_NAMESPACE, XXH32_reset)
+# define XXH32_update XXH_NAME2(XXH_NAMESPACE, XXH32_update)
+# define XXH32_digest XXH_NAME2(XXH_NAMESPACE, XXH32_digest)
+# define XXH32_copyState XXH_NAME2(XXH_NAMESPACE, XXH32_copyState)
+# define XXH32_canonicalFromHash XXH_NAME2(XXH_NAMESPACE, XXH32_canonicalFromHash)
+# define XXH32_hashFromCanonical XXH_NAME2(XXH_NAMESPACE, XXH32_hashFromCanonical)
+/* XXH64 */
+# define XXH64 XXH_NAME2(XXH_NAMESPACE, XXH64)
+# define XXH64_createState XXH_NAME2(XXH_NAMESPACE, XXH64_createState)
+# define XXH64_freeState XXH_NAME2(XXH_NAMESPACE, XXH64_freeState)
+# define XXH64_reset XXH_NAME2(XXH_NAMESPACE, XXH64_reset)
+# define XXH64_update XXH_NAME2(XXH_NAMESPACE, XXH64_update)
+# define XXH64_digest XXH_NAME2(XXH_NAMESPACE, XXH64_digest)
+# define XXH64_copyState XXH_NAME2(XXH_NAMESPACE, XXH64_copyState)
+# define XXH64_canonicalFromHash XXH_NAME2(XXH_NAMESPACE, XXH64_canonicalFromHash)
+# define XXH64_hashFromCanonical XXH_NAME2(XXH_NAMESPACE, XXH64_hashFromCanonical)
+/* XXH3_64bits */
+# define XXH3_64bits XXH_NAME2(XXH_NAMESPACE, XXH3_64bits)
+# define XXH3_64bits_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_withSecret)
+# define XXH3_64bits_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_withSeed)
+# define XXH3_createState XXH_NAME2(XXH_NAMESPACE, XXH3_createState)
+# define XXH3_freeState XXH_NAME2(XXH_NAMESPACE, XXH3_freeState)
+# define XXH3_copyState XXH_NAME2(XXH_NAMESPACE, XXH3_copyState)
+# define XXH3_64bits_reset XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset)
+# define XXH3_64bits_reset_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset_withSeed)
+# define XXH3_64bits_reset_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset_withSecret)
+# define XXH3_64bits_update XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_update)
+# define XXH3_64bits_digest XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_digest)
+# define XXH3_generateSecret XXH_NAME2(XXH_NAMESPACE, XXH3_generateSecret)
+/* XXH3_128bits */
+# define XXH128 XXH_NAME2(XXH_NAMESPACE, XXH128)
+# define XXH3_128bits XXH_NAME2(XXH_NAMESPACE, XXH3_128bits)
+# define XXH3_128bits_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_withSeed)
+# define XXH3_128bits_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_withSecret)
+# define XXH3_128bits_reset XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset)
+# define XXH3_128bits_reset_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset_withSeed)
+# define XXH3_128bits_reset_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset_withSecret)
+# define XXH3_128bits_update XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_update)
+# define XXH3_128bits_digest XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_digest)
+# define XXH128_isEqual XXH_NAME2(XXH_NAMESPACE, XXH128_isEqual)
+# define XXH128_cmp XXH_NAME2(XXH_NAMESPACE, XXH128_cmp)
+# define XXH128_canonicalFromHash XXH_NAME2(XXH_NAMESPACE, XXH128_canonicalFromHash)
+# define XXH128_hashFromCanonical XXH_NAME2(XXH_NAMESPACE, XXH128_hashFromCanonical)
+#endif
+
+
+/* *************************************
+* Version
+***************************************/
+#define XXH_VERSION_MAJOR 0
+#define XXH_VERSION_MINOR 8
+#define XXH_VERSION_RELEASE 0
+#define XXH_VERSION_NUMBER (XXH_VERSION_MAJOR *100*100 + XXH_VERSION_MINOR *100 + XXH_VERSION_RELEASE)
+XXH_PUBLIC_API unsigned XXH_versionNumber (void);
+
+
+/* ****************************
+* Definitions
+******************************/
+#include <stddef.h> /* size_t */
+typedef enum { XXH_OK=0, XXH_ERROR } XXH_errorcode;
+
+
+/*-**********************************************************************
+* 32-bit hash
+************************************************************************/
+#if !defined (__VMS) \
+ && (defined (__cplusplus) \
+ || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) )
+# include <stdint.h>
+ typedef uint32_t XXH32_hash_t;
+#else
+# include <limits.h>
+# if UINT_MAX == 0xFFFFFFFFUL
+ typedef unsigned int XXH32_hash_t;
+# else
+# if ULONG_MAX == 0xFFFFFFFFUL
+ typedef unsigned long XXH32_hash_t;
+# else
+# error "unsupported platform: need a 32-bit type"
+# endif
+# endif
+#endif
+
+/*!
+ * XXH32():
+ * Calculate the 32-bit hash of sequence "length" bytes stored at memory address "input".
+ * The memory between input & input+length must be valid (allocated and read-accessible).
+ * "seed" can be used to alter the result predictably.
+ * Speed on Core 2 Duo @ 3 GHz (single thread, SMHasher benchmark): 5.4 GB/s
+ *
+ * Note: XXH3 provides competitive speed for both 32-bit and 64-bit systems,
+ * and offers true 64/128 bit hash results. It provides a superior level of
+ * dispersion, and greatly reduces the risks of collisions.
+ */
+XXH_PUBLIC_API XXH32_hash_t XXH32 (const void* input, size_t length, XXH32_hash_t seed);
+
+/******* Streaming *******/
+
+/*
+ * Streaming functions generate the xxHash value from an incrememtal input.
+ * This method is slower than single-call functions, due to state management.
+ * For small inputs, prefer `XXH32()` and `XXH64()`, which are better optimized.
+ *
+ * An XXH state must first be allocated using `XXH*_createState()`.
+ *
+ * Start a new hash by initializing the state with a seed using `XXH*_reset()`.
+ *
+ * Then, feed the hash state by calling `XXH*_update()` as many times as necessary.
+ *
+ * The function returns an error code, with 0 meaning OK, and any other value
+ * meaning there is an error.
+ *
+ * Finally, a hash value can be produced anytime, by using `XXH*_digest()`.
+ * This function returns the nn-bits hash as an int or long long.
+ *
+ * It's still possible to continue inserting input into the hash state after a
+ * digest, and generate new hash values later on by invoking `XXH*_digest()`.
+ *
+ * When done, release the state using `XXH*_freeState()`.
+ */
+
+typedef struct XXH32_state_s XXH32_state_t; /* incomplete type */
+XXH_PUBLIC_API XXH32_state_t* XXH32_createState(void);
+XXH_PUBLIC_API XXH_errorcode XXH32_freeState(XXH32_state_t* statePtr);
+XXH_PUBLIC_API void XXH32_copyState(XXH32_state_t* dst_state, const XXH32_state_t* src_state);
+
+XXH_PUBLIC_API XXH_errorcode XXH32_reset (XXH32_state_t* statePtr, XXH32_hash_t seed);
+XXH_PUBLIC_API XXH_errorcode XXH32_update (XXH32_state_t* statePtr, const void* input, size_t length);
+XXH_PUBLIC_API XXH32_hash_t XXH32_digest (const XXH32_state_t* statePtr);
+
+/******* Canonical representation *******/
+
+/*
+ * The default return values from XXH functions are unsigned 32 and 64 bit
+ * integers.
+ * This the simplest and fastest format for further post-processing.
+ *
+ * However, this leaves open the question of what is the order on the byte level,
+ * since little and big endian conventions will store the same number differently.
+ *
+ * The canonical representation settles this issue by mandating big-endian
+ * convention, the same convention as human-readable numbers (large digits first).
+ *
+ * When writing hash values to storage, sending them over a network, or printing
+ * them, it's highly recommended to use the canonical representation to ensure
+ * portability across a wider range of systems, present and future.
+ *
+ * The following functions allow transformation of hash values to and from
+ * canonical format.
+ */
+
+typedef struct { unsigned char digest[4]; } XXH32_canonical_t;
+XXH_PUBLIC_API void XXH32_canonicalFromHash(XXH32_canonical_t* dst, XXH32_hash_t hash);
+XXH_PUBLIC_API XXH32_hash_t XXH32_hashFromCanonical(const XXH32_canonical_t* src);
+
+
+#ifndef XXH_NO_LONG_LONG
+/*-**********************************************************************
+* 64-bit hash
+************************************************************************/
+#if !defined (__VMS) \
+ && (defined (__cplusplus) \
+ || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) )
+# include <stdint.h>
+ typedef uint64_t XXH64_hash_t;
+#else
+ /* the following type must have a width of 64-bit */
+ typedef unsigned long long XXH64_hash_t;
+#endif
+
+/*!
+ * XXH64():
+ * Returns the 64-bit hash of sequence of length @length stored at memory
+ * address @input.
+ * @seed can be used to alter the result predictably.
+ *
+ * This function usually runs faster on 64-bit systems, but slower on 32-bit
+ * systems (see benchmark).
+ *
+ * Note: XXH3 provides competitive speed for both 32-bit and 64-bit systems,
+ * and offers true 64/128 bit hash results. It provides a superior level of
+ * dispersion, and greatly reduces the risks of collisions.
+ */
+XXH_PUBLIC_API XXH64_hash_t XXH64 (const void* input, size_t length, XXH64_hash_t seed);
+
+/******* Streaming *******/
+typedef struct XXH64_state_s XXH64_state_t; /* incomplete type */
+XXH_PUBLIC_API XXH64_state_t* XXH64_createState(void);
+XXH_PUBLIC_API XXH_errorcode XXH64_freeState(XXH64_state_t* statePtr);
+XXH_PUBLIC_API void XXH64_copyState(XXH64_state_t* dst_state, const XXH64_state_t* src_state);
+
+XXH_PUBLIC_API XXH_errorcode XXH64_reset (XXH64_state_t* statePtr, XXH64_hash_t seed);
+XXH_PUBLIC_API XXH_errorcode XXH64_update (XXH64_state_t* statePtr, const void* input, size_t length);
+XXH_PUBLIC_API XXH64_hash_t XXH64_digest (const XXH64_state_t* statePtr);
+
+/******* Canonical representation *******/
+typedef struct { unsigned char digest[sizeof(XXH64_hash_t)]; } XXH64_canonical_t;
+XXH_PUBLIC_API void XXH64_canonicalFromHash(XXH64_canonical_t* dst, XXH64_hash_t hash);
+XXH_PUBLIC_API XXH64_hash_t XXH64_hashFromCanonical(const XXH64_canonical_t* src);
+
+
+/*-**********************************************************************
+* XXH3 64-bit variant
+************************************************************************/
+
+/* ************************************************************************
+ * XXH3 is a new hash algorithm featuring:
+ * - Improved speed for both small and large inputs
+ * - True 64-bit and 128-bit outputs
+ * - SIMD acceleration
+ * - Improved 32-bit viability
+ *
+ * Speed analysis methodology is explained here:
+ *
+ * https://fastcompression.blogspot.com/2019/03/presenting-xxh3.html
+ *
+ * In general, expect XXH3 to run about ~2x faster on large inputs and >3x
+ * faster on small ones compared to XXH64, though exact differences depend on
+ * the platform.
+ *
+ * The algorithm is portable: Like XXH32 and XXH64, it generates the same hash
+ * on all platforms.
+ *
+ * It benefits greatly from SIMD and 64-bit arithmetic, but does not require it.
+ *
+ * Almost all 32-bit and 64-bit targets that can run XXH32 smoothly can run
+ * XXH3 at competitive speeds, even if XXH64 runs slowly. Further details are
+ * explained in the implementation.
+ *
+ * Optimized implementations are provided for AVX512, AVX2, SSE2, NEON, POWER8,
+ * ZVector and scalar targets. This can be controlled with the XXH_VECTOR macro.
+ *
+ * XXH3 offers 2 variants, _64bits and _128bits.
+ * When only 64 bits are needed, prefer calling the _64bits variant, as it
+ * reduces the amount of mixing, resulting in faster speed on small inputs.
+ *
+ * It's also generally simpler to manipulate a scalar return type than a struct.
+ *
+ * The 128-bit version adds additional strength, but it is slightly slower.
+ *
+ * The XXH3 algorithm is still in development.
+ * The results it produces may still change in future versions.
+ *
+ * Results produced by v0.7.x are not comparable with results from v0.7.y.
+ * However, the API is completely stable, and it can safely be used for
+ * ephemeral data (local sessions).
+ *
+ * Avoid storing values in long-term storage until the algorithm is finalized.
+ * XXH3's return values will be officially finalized upon reaching v0.8.0.
+ *
+ * After which, return values of XXH3 and XXH128 will no longer change in
+ * future versions.
+ *
+ * The API supports one-shot hashing, streaming mode, and custom secrets.
+ */
+
+/* XXH3_64bits():
+ * default 64-bit variant, using default secret and default seed of 0.
+ * It's the fastest variant. */
+XXH_PUBLIC_API XXH64_hash_t XXH3_64bits(const void* data, size_t len);
+
+/*
+ * XXH3_64bits_withSeed():
+ * This variant generates a custom secret on the fly
+ * based on default secret altered using the `seed` value.
+ * While this operation is decently fast, note that it's not completely free.
+ * Note: seed==0 produces the same results as XXH3_64bits().
+ */
+XXH_PUBLIC_API XXH64_hash_t XXH3_64bits_withSeed(const void* data, size_t len, XXH64_hash_t seed);
+
+/*
+ * XXH3_64bits_withSecret():
+ * It's possible to provide any blob of bytes as a "secret" to generate the hash.
+ * This makes it more difficult for an external actor to prepare an intentional collision.
+ * The main condition is that secretSize *must* be large enough (>= XXH3_SECRET_SIZE_MIN).
+ * However, the quality of produced hash values depends on secret's entropy.
+ * Technically, the secret must look like a bunch of random bytes.
+ * Avoid "trivial" or structured data such as repeated sequences or a text document.
+ * Whenever unsure about the "randomness" of the blob of bytes,
+ * consider relabelling it as a "custom seed" instead,
+ * and employ "XXH3_generateSecret()" (see below)
+ * to generate a high entropy secret derived from the custom seed.
+ */
+#define XXH3_SECRET_SIZE_MIN 136
+XXH_PUBLIC_API XXH64_hash_t XXH3_64bits_withSecret(const void* data, size_t len, const void* secret, size_t secretSize);
+
+
+/******* Streaming *******/
+/*
+ * Streaming requires state maintenance.
+ * This operation costs memory and CPU.
+ * As a consequence, streaming is slower than one-shot hashing.
+ * For better performance, prefer one-shot functions whenever applicable.
+ */
+typedef struct XXH3_state_s XXH3_state_t;
+XXH_PUBLIC_API XXH3_state_t* XXH3_createState(void);
+XXH_PUBLIC_API XXH_errorcode XXH3_freeState(XXH3_state_t* statePtr);
+XXH_PUBLIC_API void XXH3_copyState(XXH3_state_t* dst_state, const XXH3_state_t* src_state);
+
+/*
+ * XXH3_64bits_reset():
+ * Initialize with default parameters.
+ * digest will be equivalent to `XXH3_64bits()`.
+ */
+XXH_PUBLIC_API XXH_errorcode XXH3_64bits_reset(XXH3_state_t* statePtr);
+/*
+ * XXH3_64bits_reset_withSeed():
+ * Generate a custom secret from `seed`, and store it into `statePtr`.
+ * digest will be equivalent to `XXH3_64bits_withSeed()`.
+ */
+XXH_PUBLIC_API XXH_errorcode XXH3_64bits_reset_withSeed(XXH3_state_t* statePtr, XXH64_hash_t seed);
+/*
+ * XXH3_64bits_reset_withSecret():
+ * `secret` is referenced, it _must outlive_ the hash streaming session.
+ * Similar to one-shot API, `secretSize` must be >= `XXH3_SECRET_SIZE_MIN`,
+ * and the quality of produced hash values depends on secret's entropy
+ * (secret's content should look like a bunch of random bytes).
+ * When in doubt about the randomness of a candidate `secret`,
+ * consider employing `XXH3_generateSecret()` instead (see below).
+ */
+XXH_PUBLIC_API XXH_errorcode XXH3_64bits_reset_withSecret(XXH3_state_t* statePtr, const void* secret, size_t secretSize);
+
+XXH_PUBLIC_API XXH_errorcode XXH3_64bits_update (XXH3_state_t* statePtr, const void* input, size_t length);
+XXH_PUBLIC_API XXH64_hash_t XXH3_64bits_digest (const XXH3_state_t* statePtr);
+
+/* note : canonical representation of XXH3 is the same as XXH64
+ * since they both produce XXH64_hash_t values */
+
+
+/*-**********************************************************************
+* XXH3 128-bit variant
+************************************************************************/
+
+typedef struct {
+ XXH64_hash_t low64;
+ XXH64_hash_t high64;
+} XXH128_hash_t;
+
+XXH_PUBLIC_API XXH128_hash_t XXH3_128bits(const void* data, size_t len);
+XXH_PUBLIC_API XXH128_hash_t XXH3_128bits_withSeed(const void* data, size_t len, XXH64_hash_t seed);
+XXH_PUBLIC_API XXH128_hash_t XXH3_128bits_withSecret(const void* data, size_t len, const void* secret, size_t secretSize);
+
+/******* Streaming *******/
+/*
+ * Streaming requires state maintenance.
+ * This operation costs memory and CPU.
+ * As a consequence, streaming is slower than one-shot hashing.
+ * For better performance, prefer one-shot functions whenever applicable.
+ *
+ * XXH3_128bits uses the same XXH3_state_t as XXH3_64bits().
+ * Use already declared XXH3_createState() and XXH3_freeState().
+ *
+ * All reset and streaming functions have same meaning as their 64-bit counterpart.
+ */
+
+XXH_PUBLIC_API XXH_errorcode XXH3_128bits_reset(XXH3_state_t* statePtr);
+XXH_PUBLIC_API XXH_errorcode XXH3_128bits_reset_withSeed(XXH3_state_t* statePtr, XXH64_hash_t seed);
+XXH_PUBLIC_API XXH_errorcode XXH3_128bits_reset_withSecret(XXH3_state_t* statePtr, const void* secret, size_t secretSize);
+
+XXH_PUBLIC_API XXH_errorcode XXH3_128bits_update (XXH3_state_t* statePtr, const void* input, size_t length);
+XXH_PUBLIC_API XXH128_hash_t XXH3_128bits_digest (const XXH3_state_t* statePtr);
+
+/* Following helper functions make it possible to compare XXH128_hast_t values.
+ * Since XXH128_hash_t is a structure, this capability is not offered by the language.
+ * Note: For better performance, these functions can be inlined using XXH_INLINE_ALL */
+
+/*!
+ * XXH128_isEqual():
+ * Return: 1 if `h1` and `h2` are equal, 0 if they are not.
+ */
+XXH_PUBLIC_API int XXH128_isEqual(XXH128_hash_t h1, XXH128_hash_t h2);
+
+/*!
+ * XXH128_cmp():
+ *
+ * This comparator is compatible with stdlib's `qsort()`/`bsearch()`.
+ *
+ * return: >0 if *h128_1 > *h128_2
+ * =0 if *h128_1 == *h128_2
+ * <0 if *h128_1 < *h128_2
+ */
+XXH_PUBLIC_API int XXH128_cmp(const void* h128_1, const void* h128_2);
+
+
+/******* Canonical representation *******/
+typedef struct { unsigned char digest[sizeof(XXH128_hash_t)]; } XXH128_canonical_t;
+XXH_PUBLIC_API void XXH128_canonicalFromHash(XXH128_canonical_t* dst, XXH128_hash_t hash);
+XXH_PUBLIC_API XXH128_hash_t XXH128_hashFromCanonical(const XXH128_canonical_t* src);
+
+
+#endif /* XXH_NO_LONG_LONG */
+
+#endif /* XXHASH_H_5627135585666179 */
+
+
+
+#if defined(XXH_STATIC_LINKING_ONLY) && !defined(XXHASH_H_STATIC_13879238742)
+#define XXHASH_H_STATIC_13879238742
+/* ****************************************************************************
+ * This section contains declarations which are not guaranteed to remain stable.
+ * They may change in future versions, becoming incompatible with a different
+ * version of the library.
+ * These declarations should only be used with static linking.
+ * Never use them in association with dynamic linking!
+ ***************************************************************************** */
+
+/*
+ * These definitions are only present to allow static allocation
+ * of XXH states, on stack or in a struct, for example.
+ * Never **ever** access their members directly.
+ */
+
+struct XXH32_state_s {
+ XXH32_hash_t total_len_32;
+ XXH32_hash_t large_len;
+ XXH32_hash_t v1;
+ XXH32_hash_t v2;
+ XXH32_hash_t v3;
+ XXH32_hash_t v4;
+ XXH32_hash_t mem32[4];
+ XXH32_hash_t memsize;
+ XXH32_hash_t reserved; /* never read nor write, might be removed in a future version */
+}; /* typedef'd to XXH32_state_t */
+
+
+#ifndef XXH_NO_LONG_LONG /* defined when there is no 64-bit support */
+
+struct XXH64_state_s {
+ XXH64_hash_t total_len;
+ XXH64_hash_t v1;
+ XXH64_hash_t v2;
+ XXH64_hash_t v3;
+ XXH64_hash_t v4;
+ XXH64_hash_t mem64[4];
+ XXH32_hash_t memsize;
+ XXH32_hash_t reserved32; /* required for padding anyway */
+ XXH64_hash_t reserved64; /* never read nor write, might be removed in a future version */
+}; /* typedef'd to XXH64_state_t */
+
+#if defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) /* C11+ */
+# include <stdalign.h>
+# define XXH_ALIGN(n) alignas(n)
+#elif defined(__GNUC__)
+# define XXH_ALIGN(n) __attribute__ ((aligned(n)))
+#elif defined(_MSC_VER)
+# define XXH_ALIGN(n) __declspec(align(n))
+#else
+# define XXH_ALIGN(n) /* disabled */
+#endif
+
+/* Old GCC versions only accept the attribute after the type in structures. */
+#if !(defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)) /* C11+ */ \
+ && defined(__GNUC__)
+# define XXH_ALIGN_MEMBER(align, type) type XXH_ALIGN(align)
+#else
+# define XXH_ALIGN_MEMBER(align, type) XXH_ALIGN(align) type
+#endif
+
+#define XXH3_INTERNALBUFFER_SIZE 256
+#define XXH3_SECRET_DEFAULT_SIZE 192
+struct XXH3_state_s {
+ XXH_ALIGN_MEMBER(64, XXH64_hash_t acc[8]);
+ /* used to store a custom secret generated from a seed */
+ XXH_ALIGN_MEMBER(64, unsigned char customSecret[XXH3_SECRET_DEFAULT_SIZE]);
+ XXH_ALIGN_MEMBER(64, unsigned char buffer[XXH3_INTERNALBUFFER_SIZE]);
+ XXH32_hash_t bufferedSize;
+ XXH32_hash_t reserved32;
+ size_t nbStripesSoFar;
+ XXH64_hash_t totalLen;
+ size_t nbStripesPerBlock;
+ size_t secretLimit;
+ XXH64_hash_t seed;
+ XXH64_hash_t reserved64;
+ const unsigned char* extSecret; /* reference to external secret;
+ * if == NULL, use .customSecret instead */
+ /* note: there may be some padding at the end due to alignment on 64 bytes */
+}; /* typedef'd to XXH3_state_t */
+
+#undef XXH_ALIGN_MEMBER
+
+/* When the XXH3_state_t structure is merely emplaced on stack,
+ * it should be initialized with XXH3_INITSTATE() or a memset()
+ * in case its first reset uses XXH3_NNbits_reset_withSeed().
+ * This init can be omitted if the first reset uses default or _withSecret mode.
+ * This operation isn't necessary when the state is created with XXH3_createState().
+ * Note that this doesn't prepare the state for a streaming operation,
+ * it's still necessary to use XXH3_NNbits_reset*() afterwards.
+ */
+#define XXH3_INITSTATE(XXH3_state_ptr) { (XXH3_state_ptr)->seed = 0; }
+
+
+/* === Experimental API === */
+/* Symbols defined below must be considered tied to a specific library version. */
+
+/*
+ * XXH3_generateSecret():
+ *
+ * Derive a high-entropy secret from any user-defined content, named customSeed.
+ * The generated secret can be used in combination with `*_withSecret()` functions.
+ * The `_withSecret()` variants are useful to provide a higher level of protection than 64-bit seed,
+ * as it becomes much more difficult for an external actor to guess how to impact the calculation logic.
+ *
+ * The function accepts as input a custom seed of any length and any content,
+ * and derives from it a high-entropy secret of length XXH3_SECRET_DEFAULT_SIZE
+ * into an already allocated buffer secretBuffer.
+ * The generated secret is _always_ XXH_SECRET_DEFAULT_SIZE bytes long.
+ *
+ * The generated secret can then be used with any `*_withSecret()` variant.
+ * Functions `XXH3_128bits_withSecret()`, `XXH3_64bits_withSecret()`,
+ * `XXH3_128bits_reset_withSecret()` and `XXH3_64bits_reset_withSecret()`
+ * are part of this list. They all accept a `secret` parameter
+ * which must be very long for implementation reasons (>= XXH3_SECRET_SIZE_MIN)
+ * _and_ feature very high entropy (consist of random-looking bytes).
+ * These conditions can be a high bar to meet, so
+ * this function can be used to generate a secret of proper quality.
+ *
+ * customSeed can be anything. It can have any size, even small ones,
+ * and its content can be anything, even stupidly "low entropy" source such as a bunch of zeroes.
+ * The resulting `secret` will nonetheless provide all expected qualities.
+ *
+ * Supplying NULL as the customSeed copies the default secret into `secretBuffer`.
+ * When customSeedSize > 0, supplying NULL as customSeed is undefined behavior.
+ */
+XXH_PUBLIC_API void XXH3_generateSecret(void* secretBuffer, const void* customSeed, size_t customSeedSize);
+
+
+/* simple short-cut to pre-selected XXH3_128bits variant */
+XXH_PUBLIC_API XXH128_hash_t XXH128(const void* data, size_t len, XXH64_hash_t seed);
+
+
+#endif /* XXH_NO_LONG_LONG */
+
+
+#if defined(XXH_INLINE_ALL) || defined(XXH_PRIVATE_API)
+# define XXH_IMPLEMENTATION
+#endif
+
+#endif /* defined(XXH_STATIC_LINKING_ONLY) && !defined(XXHASH_H_STATIC_13879238742) */
+
+
+/* ======================================================================== */
+/* ======================================================================== */
+/* ======================================================================== */
+
+
+/*-**********************************************************************
+ * xxHash implementation
+ *-**********************************************************************
+ * xxHash's implementation used to be hosted inside xxhash.c.
+ *
+ * However, inlining requires implementation to be visible to the compiler,
+ * hence be included alongside the header.
+ * Previously, implementation was hosted inside xxhash.c,
+ * which was then #included when inlining was activated.
+ * This construction created issues with a few build and install systems,
+ * as it required xxhash.c to be stored in /include directory.
+ *
+ * xxHash implementation is now directly integrated within xxhash.h.
+ * As a consequence, xxhash.c is no longer needed in /include.
+ *
+ * xxhash.c is still available and is still useful.
+ * In a "normal" setup, when xxhash is not inlined,
+ * xxhash.h only exposes the prototypes and public symbols,
+ * while xxhash.c can be built into an object file xxhash.o
+ * which can then be linked into the final binary.
+ ************************************************************************/
+
+#if ( defined(XXH_INLINE_ALL) || defined(XXH_PRIVATE_API) \
+ || defined(XXH_IMPLEMENTATION) ) && !defined(XXH_IMPLEM_13a8737387)
+# define XXH_IMPLEM_13a8737387
+
+/* *************************************
+* Tuning parameters
+***************************************/
+/*!
+ * XXH_FORCE_MEMORY_ACCESS:
+ * By default, access to unaligned memory is controlled by `memcpy()`, which is
+ * safe and portable.
+ *
+ * Unfortunately, on some target/compiler combinations, the generated assembly
+ * is sub-optimal.
+ *
+ * The below switch allow selection of a different access method
+ * in the search for improved performance.
+ * Method 0 (default):
+ * Use `memcpy()`. Safe and portable. Default.
+ * Method 1:
+ * `__attribute__((packed))` statement. It depends on compiler extensions
+ * and is therefore not portable.
+ * This method is safe if your compiler supports it, and *generally* as
+ * fast or faster than `memcpy`.
+ * Method 2:
+ * Direct access via cast. This method doesn't depend on the compiler but
+ * violates the C standard.
+ * It can generate buggy code on targets which do not support unaligned
+ * memory accesses.
+ * But in some circumstances, it's the only known way to get the most
+ * performance (example: GCC + ARMv6)
+ * Method 3:
+ * Byteshift. This can generate the best code on old compilers which don't
+ * inline small `memcpy()` calls, and it might also be faster on big-endian
+ * systems which lack a native byteswap instruction.
+ * See https://stackoverflow.com/a/32095106/646947 for details.
+ * Prefer these methods in priority order (0 > 1 > 2 > 3)
+ */
+#ifndef XXH_FORCE_MEMORY_ACCESS /* can be defined externally, on command line for example */
+# if !defined(__clang__) && defined(__GNUC__) && defined(__ARM_FEATURE_UNALIGNED) && defined(__ARM_ARCH) && (__ARM_ARCH == 6)
+# define XXH_FORCE_MEMORY_ACCESS 2
+# elif !defined(__clang__) && ((defined(__INTEL_COMPILER) && !defined(_WIN32)) || \
+ (defined(__GNUC__) && (defined(__ARM_ARCH) && __ARM_ARCH >= 7)))
+# define XXH_FORCE_MEMORY_ACCESS 1
+# endif
+#endif
+
+/*!
+ * XXH_ACCEPT_NULL_INPUT_POINTER:
+ * If the input pointer is NULL, xxHash's default behavior is to dereference it,
+ * triggering a segfault.
+ * When this macro is enabled, xxHash actively checks the input for a null pointer.
+ * If it is, the result for null input pointers is the same as a zero-length input.
+ */
+#ifndef XXH_ACCEPT_NULL_INPUT_POINTER /* can be defined externally */
+# define XXH_ACCEPT_NULL_INPUT_POINTER 0
+#endif
+
+/*!
+ * XXH_FORCE_ALIGN_CHECK:
+ * This is an important performance trick
+ * for architectures without decent unaligned memory access performance.
+ * It checks for input alignment, and when conditions are met,
+ * uses a "fast path" employing direct 32-bit/64-bit read,
+ * resulting in _dramatically faster_ read speed.
+ *
+ * The check costs one initial branch per hash, which is generally negligible, but not zero.
+ * Moreover, it's not useful to generate binary for an additional code path
+ * if memory access uses same instruction for both aligned and unaligned adresses.
+ *
+ * In these cases, the alignment check can be removed by setting this macro to 0.
+ * Then the code will always use unaligned memory access.
+ * Align check is automatically disabled on x86, x64 & arm64,
+ * which are platforms known to offer good unaligned memory accesses performance.
+ *
+ * This option does not affect XXH3 (only XXH32 and XXH64).
+ */
+#ifndef XXH_FORCE_ALIGN_CHECK /* can be defined externally */
+# if defined(__i386) || defined(__x86_64__) || defined(__aarch64__) \
+ || defined(_M_IX86) || defined(_M_X64) || defined(_M_ARM64) /* visual */
+# define XXH_FORCE_ALIGN_CHECK 0
+# else
+# define XXH_FORCE_ALIGN_CHECK 1
+# endif
+#endif
+
+/*!
+ * XXH_NO_INLINE_HINTS:
+ *
+ * By default, xxHash tries to force the compiler to inline almost all internal
+ * functions.
+ *
+ * This can usually improve performance due to reduced jumping and improved
+ * constant folding, but significantly increases the size of the binary which
+ * might not be favorable.
+ *
+ * Additionally, sometimes the forced inlining can be detrimental to performance,
+ * depending on the architecture.
+ *
+ * XXH_NO_INLINE_HINTS marks all internal functions as static, giving the
+ * compiler full control on whether to inline or not.
+ *
+ * When not optimizing (-O0), optimizing for size (-Os, -Oz), or using
+ * -fno-inline with GCC or Clang, this will automatically be defined.
+ */
+#ifndef XXH_NO_INLINE_HINTS
+# if defined(__OPTIMIZE_SIZE__) /* -Os, -Oz */ \
+ || defined(__NO_INLINE__) /* -O0, -fno-inline */
+# define XXH_NO_INLINE_HINTS 1
+# else
+# define XXH_NO_INLINE_HINTS 0
+# endif
+#endif
+
+/*!
+ * XXH_REROLL:
+ * Whether to reroll XXH32_finalize, and XXH64_finalize,
+ * instead of using an unrolled jump table/if statement loop.
+ *
+ * This is automatically defined on -Os/-Oz on GCC and Clang.
+ */
+#ifndef XXH_REROLL
+# if defined(__OPTIMIZE_SIZE__)
+# define XXH_REROLL 1
+# else
+# define XXH_REROLL 0
+# endif
+#endif
+
+
+/* *************************************
+* Includes & Memory related functions
+***************************************/
+/*!
+ * Modify the local functions below should you wish to use
+ * different memory routines for malloc() and free()
+ */
+#include <stdlib.h>
+
+static void* XXH_malloc(size_t s) { return malloc(s); }
+static void XXH_free(void* p) { free(p); }
+
+/*! and for memcpy() */
+#include <string.h>
+static void* XXH_memcpy(void* dest, const void* src, size_t size)
+{
+ return memcpy(dest,src,size);
+}
+
+#include <limits.h> /* ULLONG_MAX */
+
+
+/* *************************************
+* Compiler Specific Options
+***************************************/
+#ifdef _MSC_VER /* Visual Studio warning fix */
+# pragma warning(disable : 4127) /* disable: C4127: conditional expression is constant */
+#endif
+
+#if XXH_NO_INLINE_HINTS /* disable inlining hints */
+# if defined(__GNUC__)
+# define XXH_FORCE_INLINE static __attribute__((unused))
+# else
+# define XXH_FORCE_INLINE static
+# endif
+# define XXH_NO_INLINE static
+/* enable inlining hints */
+#elif defined(_MSC_VER) /* Visual Studio */
+# define XXH_FORCE_INLINE static __forceinline
+# define XXH_NO_INLINE static __declspec(noinline)
+#elif defined(__GNUC__)
+# define XXH_FORCE_INLINE static __inline__ __attribute__((always_inline, unused))
+# define XXH_NO_INLINE static __attribute__((noinline))
+#elif defined (__cplusplus) \
+ || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) /* C99 */
+# define XXH_FORCE_INLINE static inline
+# define XXH_NO_INLINE static
+#else
+# define XXH_FORCE_INLINE static
+# define XXH_NO_INLINE static
+#endif
+
+
+
+/* *************************************
+* Debug
+***************************************/
+/*
+ * XXH_DEBUGLEVEL is expected to be defined externally, typically via the
+ * compiler's command line options. The value must be a number.
+ */
+#ifndef XXH_DEBUGLEVEL
+# ifdef DEBUGLEVEL /* backwards compat */
+# define XXH_DEBUGLEVEL DEBUGLEVEL
+# else
+# define XXH_DEBUGLEVEL 0
+# endif
+#endif
+
+#if (XXH_DEBUGLEVEL>=1)
+# include <assert.h> /* note: can still be disabled with NDEBUG */
+# define XXH_ASSERT(c) assert(c)
+#else
+# define XXH_ASSERT(c) ((void)0)
+#endif
+
+/* note: use after variable declarations */
+#define XXH_STATIC_ASSERT(c) do { enum { XXH_sa = 1/(int)(!!(c)) }; } while (0)
+
+
+/* *************************************
+* Basic Types
+***************************************/
+#if !defined (__VMS) \
+ && (defined (__cplusplus) \
+ || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) )
+# include <stdint.h>
+ typedef uint8_t xxh_u8;
+#else
+ typedef unsigned char xxh_u8;
+#endif
+typedef XXH32_hash_t xxh_u32;
+
+#ifdef XXH_OLD_NAMES
+# define BYTE xxh_u8
+# define U8 xxh_u8
+# define U32 xxh_u32
+#endif
+
+/* *** Memory access *** */
+
+#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3))
+/*
+ * Manual byteshift. Best for old compilers which don't inline memcpy.
+ * We actually directly use XXH_readLE32 and XXH_readBE32.
+ */
+#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==2))
+
+/*
+ * Force direct memory access. Only works on CPU which support unaligned memory
+ * access in hardware.
+ */
+static xxh_u32 XXH_read32(const void* memPtr) { return *(const xxh_u32*) memPtr; }
+
+#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==1))
+
+/*
+ * __pack instructions are safer but compiler specific, hence potentially
+ * problematic for some compilers.
+ *
+ * Currently only defined for GCC and ICC.
+ */
+#ifdef XXH_OLD_NAMES
+typedef union { xxh_u32 u32; } __attribute__((packed)) unalign;
+#endif
+static xxh_u32 XXH_read32(const void* ptr)
+{
+ typedef union { xxh_u32 u32; } __attribute__((packed)) xxh_unalign;
+ return ((const xxh_unalign*)ptr)->u32;
+}
+
+#else
+
+/*
+ * Portable and safe solution. Generally efficient.
+ * see: https://stackoverflow.com/a/32095106/646947
+ */
+static xxh_u32 XXH_read32(const void* memPtr)
+{
+ xxh_u32 val;
+ memcpy(&val, memPtr, sizeof(val));
+ return val;
+}
+
+#endif /* XXH_FORCE_DIRECT_MEMORY_ACCESS */
+
+
+/* *** Endianess *** */
+typedef enum { XXH_bigEndian=0, XXH_littleEndian=1 } XXH_endianess;
+
+/*!
+ * XXH_CPU_LITTLE_ENDIAN:
+ * Defined to 1 if the target is little endian, or 0 if it is big endian.
+ * It can be defined externally, for example on the compiler command line.
+ *
+ * If it is not defined, a runtime check (which is usually constant folded)
+ * is used instead.
+ */
+#ifndef XXH_CPU_LITTLE_ENDIAN
+/*
+ * Try to detect endianness automatically, to avoid the nonstandard behavior
+ * in `XXH_isLittleEndian()`
+ */
+# if defined(_WIN32) /* Windows is always little endian */ \
+ || defined(__LITTLE_ENDIAN__) \
+ || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
+# define XXH_CPU_LITTLE_ENDIAN 1
+# elif defined(__BIG_ENDIAN__) \
+ || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+# define XXH_CPU_LITTLE_ENDIAN 0
+# else
+/*
+ * runtime test, presumed to simplify to a constant by compiler
+ */
+static int XXH_isLittleEndian(void)
+{
+ /*
+ * Portable and well-defined behavior.
+ * Don't use static: it is detrimental to performance.
+ */
+ const union { xxh_u32 u; xxh_u8 c[4]; } one = { 1 };
+ return one.c[0];
+}
+# define XXH_CPU_LITTLE_ENDIAN XXH_isLittleEndian()
+# endif
+#endif
+
+
+
+
+/* ****************************************
+* Compiler-specific Functions and Macros
+******************************************/
+#define XXH_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__)
+
+#ifdef __has_builtin
+# define XXH_HAS_BUILTIN(x) __has_builtin(x)
+#else
+# define XXH_HAS_BUILTIN(x) 0
+#endif
+
+#if !defined(NO_CLANG_BUILTIN) && XXH_HAS_BUILTIN(__builtin_rotateleft32) \
+ && XXH_HAS_BUILTIN(__builtin_rotateleft64)
+# define XXH_rotl32 __builtin_rotateleft32
+# define XXH_rotl64 __builtin_rotateleft64
+/* Note: although _rotl exists for minGW (GCC under windows), performance seems poor */
+#elif defined(_MSC_VER)
+# define XXH_rotl32(x,r) _rotl(x,r)
+# define XXH_rotl64(x,r) _rotl64(x,r)
+#else
+# define XXH_rotl32(x,r) (((x) << (r)) | ((x) >> (32 - (r))))
+# define XXH_rotl64(x,r) (((x) << (r)) | ((x) >> (64 - (r))))
+#endif
+
+#if defined(_MSC_VER) /* Visual Studio */
+# define XXH_swap32 _byteswap_ulong
+#elif XXH_GCC_VERSION >= 403
+# define XXH_swap32 __builtin_bswap32
+#else
+static xxh_u32 XXH_swap32 (xxh_u32 x)
+{
+ return ((x << 24) & 0xff000000 ) |
+ ((x << 8) & 0x00ff0000 ) |
+ ((x >> 8) & 0x0000ff00 ) |
+ ((x >> 24) & 0x000000ff );
+}
+#endif
+
+
+/* ***************************
+* Memory reads
+*****************************/
+typedef enum { XXH_aligned, XXH_unaligned } XXH_alignment;
+
+/*
+ * XXH_FORCE_MEMORY_ACCESS==3 is an endian-independent byteshift load.
+ *
+ * This is ideal for older compilers which don't inline memcpy.
+ */
+#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3))
+
+XXH_FORCE_INLINE xxh_u32 XXH_readLE32(const void* memPtr)
+{
+ const xxh_u8* bytePtr = (const xxh_u8 *)memPtr;
+ return bytePtr[0]
+ | ((xxh_u32)bytePtr[1] << 8)
+ | ((xxh_u32)bytePtr[2] << 16)
+ | ((xxh_u32)bytePtr[3] << 24);
+}
+
+XXH_FORCE_INLINE xxh_u32 XXH_readBE32(const void* memPtr)
+{
+ const xxh_u8* bytePtr = (const xxh_u8 *)memPtr;
+ return bytePtr[3]
+ | ((xxh_u32)bytePtr[2] << 8)
+ | ((xxh_u32)bytePtr[1] << 16)
+ | ((xxh_u32)bytePtr[0] << 24);
+}
+
+#else
+XXH_FORCE_INLINE xxh_u32 XXH_readLE32(const void* ptr)
+{
+ return XXH_CPU_LITTLE_ENDIAN ? XXH_read32(ptr) : XXH_swap32(XXH_read32(ptr));
+}
+
+static xxh_u32 XXH_readBE32(const void* ptr)
+{
+ return XXH_CPU_LITTLE_ENDIAN ? XXH_swap32(XXH_read32(ptr)) : XXH_read32(ptr);
+}
+#endif
+
+XXH_FORCE_INLINE xxh_u32
+XXH_readLE32_align(const void* ptr, XXH_alignment align)
+{
+ if (align==XXH_unaligned) {
+ return XXH_readLE32(ptr);
+ } else {
+ return XXH_CPU_LITTLE_ENDIAN ? *(const xxh_u32*)ptr : XXH_swap32(*(const xxh_u32*)ptr);
+ }
+}
+
+
+/* *************************************
+* Misc
+***************************************/
+XXH_PUBLIC_API unsigned XXH_versionNumber (void) { return XXH_VERSION_NUMBER; }
+
+
+/* *******************************************************************
+* 32-bit hash functions
+*********************************************************************/
+static const xxh_u32 XXH_PRIME32_1 = 0x9E3779B1U; /* 0b10011110001101110111100110110001 */
+static const xxh_u32 XXH_PRIME32_2 = 0x85EBCA77U; /* 0b10000101111010111100101001110111 */
+static const xxh_u32 XXH_PRIME32_3 = 0xC2B2AE3DU; /* 0b11000010101100101010111000111101 */
+static const xxh_u32 XXH_PRIME32_4 = 0x27D4EB2FU; /* 0b00100111110101001110101100101111 */
+static const xxh_u32 XXH_PRIME32_5 = 0x165667B1U; /* 0b00010110010101100110011110110001 */
+
+#ifdef XXH_OLD_NAMES
+# define PRIME32_1 XXH_PRIME32_1
+# define PRIME32_2 XXH_PRIME32_2
+# define PRIME32_3 XXH_PRIME32_3
+# define PRIME32_4 XXH_PRIME32_4
+# define PRIME32_5 XXH_PRIME32_5
+#endif
+
+static xxh_u32 XXH32_round(xxh_u32 acc, xxh_u32 input)
+{
+ acc += input * XXH_PRIME32_2;
+ acc = XXH_rotl32(acc, 13);
+ acc *= XXH_PRIME32_1;
+#if defined(__GNUC__) && defined(__SSE4_1__) && !defined(XXH_ENABLE_AUTOVECTORIZE)
+ /*
+ * UGLY HACK:
+ * This inline assembly hack forces acc into a normal register. This is the
+ * only thing that prevents GCC and Clang from autovectorizing the XXH32
+ * loop (pragmas and attributes don't work for some resason) without globally
+ * disabling SSE4.1.
+ *
+ * The reason we want to avoid vectorization is because despite working on
+ * 4 integers at a time, there are multiple factors slowing XXH32 down on
+ * SSE4:
+ * - There's a ridiculous amount of lag from pmulld (10 cycles of latency on
+ * newer chips!) making it slightly slower to multiply four integers at
+ * once compared to four integers independently. Even when pmulld was
+ * fastest, Sandy/Ivy Bridge, it is still not worth it to go into SSE
+ * just to multiply unless doing a long operation.
+ *
+ * - Four instructions are required to rotate,
+ * movqda tmp, v // not required with VEX encoding
+ * pslld tmp, 13 // tmp <<= 13
+ * psrld v, 19 // x >>= 19
+ * por v, tmp // x |= tmp
+ * compared to one for scalar:
+ * roll v, 13 // reliably fast across the board
+ * shldl v, v, 13 // Sandy Bridge and later prefer this for some reason
+ *
+ * - Instruction level parallelism is actually more beneficial here because
+ * the SIMD actually serializes this operation: While v1 is rotating, v2
+ * can load data, while v3 can multiply. SSE forces them to operate
+ * together.
+ *
+ * How this hack works:
+ * __asm__("" // Declare an assembly block but don't declare any instructions
+ * : // However, as an Input/Output Operand,
+ * "+r" // constrain a read/write operand (+) as a general purpose register (r).
+ * (acc) // and set acc as the operand
+ * );
+ *
+ * Because of the 'r', the compiler has promised that seed will be in a
+ * general purpose register and the '+' says that it will be 'read/write',
+ * so it has to assume it has changed. It is like volatile without all the
+ * loads and stores.
+ *
+ * Since the argument has to be in a normal register (not an SSE register),
+ * each time XXH32_round is called, it is impossible to vectorize.
+ */
+ __asm__("" : "+r" (acc));
+#endif
+ return acc;
+}
+
+/* mix all bits */
+static xxh_u32 XXH32_avalanche(xxh_u32 h32)
+{
+ h32 ^= h32 >> 15;
+ h32 *= XXH_PRIME32_2;
+ h32 ^= h32 >> 13;
+ h32 *= XXH_PRIME32_3;
+ h32 ^= h32 >> 16;
+ return(h32);
+}
+
+#define XXH_get32bits(p) XXH_readLE32_align(p, align)
+
+static xxh_u32
+XXH32_finalize(xxh_u32 h32, const xxh_u8* ptr, size_t len, XXH_alignment align)
+{
+#define XXH_PROCESS1 do { \
+ h32 += (*ptr++) * XXH_PRIME32_5; \
+ h32 = XXH_rotl32(h32, 11) * XXH_PRIME32_1; \
+} while (0)
+
+#define XXH_PROCESS4 do { \
+ h32 += XXH_get32bits(ptr) * XXH_PRIME32_3; \
+ ptr += 4; \
+ h32 = XXH_rotl32(h32, 17) * XXH_PRIME32_4; \
+} while (0)
+
+ /* Compact rerolled version */
+ if (XXH_REROLL) {
+ len &= 15;
+ while (len >= 4) {
+ XXH_PROCESS4;
+ len -= 4;
+ }
+ while (len > 0) {
+ XXH_PROCESS1;
+ --len;
+ }
+ return XXH32_avalanche(h32);
+ } else {
+ switch(len&15) /* or switch(bEnd - p) */ {
+ case 12: XXH_PROCESS4;
+ /* fallthrough */
+ case 8: XXH_PROCESS4;
+ /* fallthrough */
+ case 4: XXH_PROCESS4;
+ return XXH32_avalanche(h32);
+
+ case 13: XXH_PROCESS4;
+ /* fallthrough */
+ case 9: XXH_PROCESS4;
+ /* fallthrough */
+ case 5: XXH_PROCESS4;
+ XXH_PROCESS1;
+ return XXH32_avalanche(h32);
+
+ case 14: XXH_PROCESS4;
+ /* fallthrough */
+ case 10: XXH_PROCESS4;
+ /* fallthrough */
+ case 6: XXH_PROCESS4;
+ XXH_PROCESS1;
+ XXH_PROCESS1;
+ return XXH32_avalanche(h32);
+
+ case 15: XXH_PROCESS4;
+ /* fallthrough */
+ case 11: XXH_PROCESS4;
+ /* fallthrough */
+ case 7: XXH_PROCESS4;
+ /* fallthrough */
+ case 3: XXH_PROCESS1;
+ /* fallthrough */
+ case 2: XXH_PROCESS1;
+ /* fallthrough */
+ case 1: XXH_PROCESS1;
+ /* fallthrough */
+ case 0: return XXH32_avalanche(h32);
+ }
+ XXH_ASSERT(0);
+ return h32; /* reaching this point is deemed impossible */
+ }
+}
+
+#ifdef XXH_OLD_NAMES
+# define PROCESS1 XXH_PROCESS1
+# define PROCESS4 XXH_PROCESS4
+#else
+# undef XXH_PROCESS1
+# undef XXH_PROCESS4
+#endif
+
+XXH_FORCE_INLINE xxh_u32
+XXH32_endian_align(const xxh_u8* input, size_t len, xxh_u32 seed, XXH_alignment align)
+{
+ const xxh_u8* bEnd = input + len;
+ xxh_u32 h32;
+
+#if defined(XXH_ACCEPT_NULL_INPUT_POINTER) && (XXH_ACCEPT_NULL_INPUT_POINTER>=1)
+ if (input==NULL) {
+ len=0;
+ bEnd=input=(const xxh_u8*)(size_t)16;
+ }
+#endif
+
+ if (len>=16) {
+ const xxh_u8* const limit = bEnd - 15;
+ xxh_u32 v1 = seed + XXH_PRIME32_1 + XXH_PRIME32_2;
+ xxh_u32 v2 = seed + XXH_PRIME32_2;
+ xxh_u32 v3 = seed + 0;
+ xxh_u32 v4 = seed - XXH_PRIME32_1;
+
+ do {
+ v1 = XXH32_round(v1, XXH_get32bits(input)); input += 4;
+ v2 = XXH32_round(v2, XXH_get32bits(input)); input += 4;
+ v3 = XXH32_round(v3, XXH_get32bits(input)); input += 4;
+ v4 = XXH32_round(v4, XXH_get32bits(input)); input += 4;
+ } while (input < limit);
+
+ h32 = XXH_rotl32(v1, 1) + XXH_rotl32(v2, 7)
+ + XXH_rotl32(v3, 12) + XXH_rotl32(v4, 18);
+ } else {
+ h32 = seed + XXH_PRIME32_5;
+ }
+
+ h32 += (xxh_u32)len;
+
+ return XXH32_finalize(h32, input, len&15, align);
+}
+
+
+XXH_PUBLIC_API XXH32_hash_t XXH32 (const void* input, size_t len, XXH32_hash_t seed)
+{
+#if 0
+ /* Simple version, good for code maintenance, but unfortunately slow for small inputs */
+ XXH32_state_t state;
+ XXH32_reset(&state, seed);
+ XXH32_update(&state, (const xxh_u8*)input, len);
+ return XXH32_digest(&state);
+
+#else
+
+ if (XXH_FORCE_ALIGN_CHECK) {
+ if ((((size_t)input) & 3) == 0) { /* Input is 4-bytes aligned, leverage the speed benefit */
+ return XXH32_endian_align((const xxh_u8*)input, len, seed, XXH_aligned);
+ } }
+
+ return XXH32_endian_align((const xxh_u8*)input, len, seed, XXH_unaligned);
+#endif
+}
+
+
+
+/******* Hash streaming *******/
+
+XXH_PUBLIC_API XXH32_state_t* XXH32_createState(void)
+{
+ return (XXH32_state_t*)XXH_malloc(sizeof(XXH32_state_t));
+}
+XXH_PUBLIC_API XXH_errorcode XXH32_freeState(XXH32_state_t* statePtr)
+{
+ XXH_free(statePtr);
+ return XXH_OK;
+}
+
+XXH_PUBLIC_API void XXH32_copyState(XXH32_state_t* dstState, const XXH32_state_t* srcState)
+{
+ memcpy(dstState, srcState, sizeof(*dstState));
+}
+
+XXH_PUBLIC_API XXH_errorcode XXH32_reset(XXH32_state_t* statePtr, XXH32_hash_t seed)
+{
+ XXH32_state_t state; /* using a local state to memcpy() in order to avoid strict-aliasing warnings */
+ memset(&state, 0, sizeof(state));
+ state.v1 = seed + XXH_PRIME32_1 + XXH_PRIME32_2;
+ state.v2 = seed + XXH_PRIME32_2;
+ state.v3 = seed + 0;
+ state.v4 = seed - XXH_PRIME32_1;
+ /* do not write into reserved, planned to be removed in a future version */
+ memcpy(statePtr, &state, sizeof(state) - sizeof(state.reserved));
+ return XXH_OK;
+}
+
+
+XXH_PUBLIC_API XXH_errorcode
+XXH32_update(XXH32_state_t* state, const void* input, size_t len)
+{
+ if (input==NULL)
+#if defined(XXH_ACCEPT_NULL_INPUT_POINTER) && (XXH_ACCEPT_NULL_INPUT_POINTER>=1)
+ return XXH_OK;
+#else
+ return XXH_ERROR;
+#endif
+
+ { const xxh_u8* p = (const xxh_u8*)input;
+ const xxh_u8* const bEnd = p + len;
+
+ state->total_len_32 += (XXH32_hash_t)len;
+ state->large_len |= (XXH32_hash_t)((len>=16) | (state->total_len_32>=16));
+
+ if (state->memsize + len < 16) { /* fill in tmp buffer */
+ XXH_memcpy((xxh_u8*)(state->mem32) + state->memsize, input, len);
+ state->memsize += (XXH32_hash_t)len;
+ return XXH_OK;
+ }
+
+ if (state->memsize) { /* some data left from previous update */
+ XXH_memcpy((xxh_u8*)(state->mem32) + state->memsize, input, 16-state->memsize);
+ { const xxh_u32* p32 = state->mem32;
+ state->v1 = XXH32_round(state->v1, XXH_readLE32(p32)); p32++;
+ state->v2 = XXH32_round(state->v2, XXH_readLE32(p32)); p32++;
+ state->v3 = XXH32_round(state->v3, XXH_readLE32(p32)); p32++;
+ state->v4 = XXH32_round(state->v4, XXH_readLE32(p32));
+ }
+ p += 16-state->memsize;
+ state->memsize = 0;
+ }
+
+ if (p <= bEnd-16) {
+ const xxh_u8* const limit = bEnd - 16;
+ xxh_u32 v1 = state->v1;
+ xxh_u32 v2 = state->v2;
+ xxh_u32 v3 = state->v3;
+ xxh_u32 v4 = state->v4;
+
+ do {
+ v1 = XXH32_round(v1, XXH_readLE32(p)); p+=4;
+ v2 = XXH32_round(v2, XXH_readLE32(p)); p+=4;
+ v3 = XXH32_round(v3, XXH_readLE32(p)); p+=4;
+ v4 = XXH32_round(v4, XXH_readLE32(p)); p+=4;
+ } while (p<=limit);
+
+ state->v1 = v1;
+ state->v2 = v2;
+ state->v3 = v3;
+ state->v4 = v4;
+ }
+
+ if (p < bEnd) {
+ XXH_memcpy(state->mem32, p, (size_t)(bEnd-p));
+ state->memsize = (unsigned)(bEnd-p);
+ }
+ }
+
+ return XXH_OK;
+}
+
+
+XXH_PUBLIC_API XXH32_hash_t XXH32_digest (const XXH32_state_t* state)
+{
+ xxh_u32 h32;
+
+ if (state->large_len) {
+ h32 = XXH_rotl32(state->v1, 1)
+ + XXH_rotl32(state->v2, 7)
+ + XXH_rotl32(state->v3, 12)
+ + XXH_rotl32(state->v4, 18);
+ } else {
+ h32 = state->v3 /* == seed */ + XXH_PRIME32_5;
+ }
+
+ h32 += state->total_len_32;
+
+ return XXH32_finalize(h32, (const xxh_u8*)state->mem32, state->memsize, XXH_aligned);
+}
+
+
+/******* Canonical representation *******/
+
+/*
+ * The default return values from XXH functions are unsigned 32 and 64 bit
+ * integers.
+ *
+ * The canonical representation uses big endian convention, the same convention
+ * as human-readable numbers (large digits first).
+ *
+ * This way, hash values can be written into a file or buffer, remaining
+ * comparable across different systems.
+ *
+ * The following functions allow transformation of hash values to and from their
+ * canonical format.
+ */
+XXH_PUBLIC_API void XXH32_canonicalFromHash(XXH32_canonical_t* dst, XXH32_hash_t hash)
+{
+ XXH_STATIC_ASSERT(sizeof(XXH32_canonical_t) == sizeof(XXH32_hash_t));
+ if (XXH_CPU_LITTLE_ENDIAN) hash = XXH_swap32(hash);
+ memcpy(dst, &hash, sizeof(*dst));
+}
+
+XXH_PUBLIC_API XXH32_hash_t XXH32_hashFromCanonical(const XXH32_canonical_t* src)
+{
+ return XXH_readBE32(src);
+}
+
+
+#ifndef XXH_NO_LONG_LONG
+
+/* *******************************************************************
+* 64-bit hash functions
+*********************************************************************/
+
+/******* Memory access *******/
+
+typedef XXH64_hash_t xxh_u64;
+
+#ifdef XXH_OLD_NAMES
+# define U64 xxh_u64
+#endif
+
+/*!
+ * XXH_REROLL_XXH64:
+ * Whether to reroll the XXH64_finalize() loop.
+ *
+ * Just like XXH32, we can unroll the XXH64_finalize() loop. This can be a
+ * performance gain on 64-bit hosts, as only one jump is required.
+ *
+ * However, on 32-bit hosts, because arithmetic needs to be done with two 32-bit
+ * registers, and 64-bit arithmetic needs to be simulated, it isn't beneficial
+ * to unroll. The code becomes ridiculously large (the largest function in the
+ * binary on i386!), and rerolling it saves anywhere from 3kB to 20kB. It is
+ * also slightly faster because it fits into cache better and is more likely
+ * to be inlined by the compiler.
+ *
+ * If XXH_REROLL is defined, this is ignored and the loop is always rerolled.
+ */
+#ifndef XXH_REROLL_XXH64
+# if (defined(__ILP32__) || defined(_ILP32)) /* ILP32 is often defined on 32-bit GCC family */ \
+ || !(defined(__x86_64__) || defined(_M_X64) || defined(_M_AMD64) /* x86-64 */ \
+ || defined(_M_ARM64) || defined(__aarch64__) || defined(__arm64__) /* aarch64 */ \
+ || defined(__PPC64__) || defined(__PPC64LE__) || defined(__ppc64__) || defined(__powerpc64__) /* ppc64 */ \
+ || defined(__mips64__) || defined(__mips64)) /* mips64 */ \
+ || (!defined(SIZE_MAX) || SIZE_MAX < ULLONG_MAX) /* check limits */
+# define XXH_REROLL_XXH64 1
+# else
+# define XXH_REROLL_XXH64 0
+# endif
+#endif /* !defined(XXH_REROLL_XXH64) */
+
+#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3))
+/*
+ * Manual byteshift. Best for old compilers which don't inline memcpy.
+ * We actually directly use XXH_readLE64 and XXH_readBE64.
+ */
+#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==2))
+
+/* Force direct memory access. Only works on CPU which support unaligned memory access in hardware */
+static xxh_u64 XXH_read64(const void* memPtr) { return *(const xxh_u64*) memPtr; }
+
+#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==1))
+
+/*
+ * __pack instructions are safer, but compiler specific, hence potentially
+ * problematic for some compilers.
+ *
+ * Currently only defined for GCC and ICC.
+ */
+#ifdef XXH_OLD_NAMES
+typedef union { xxh_u32 u32; xxh_u64 u64; } __attribute__((packed)) unalign64;
+#endif
+static xxh_u64 XXH_read64(const void* ptr)
+{
+ typedef union { xxh_u32 u32; xxh_u64 u64; } __attribute__((packed)) xxh_unalign64;
+ return ((const xxh_unalign64*)ptr)->u64;
+}
+
+#else
+
+/*
+ * Portable and safe solution. Generally efficient.
+ * see: https://stackoverflow.com/a/32095106/646947
+ */
+static xxh_u64 XXH_read64(const void* memPtr)
+{
+ xxh_u64 val;
+ memcpy(&val, memPtr, sizeof(val));
+ return val;
+}
+
+#endif /* XXH_FORCE_DIRECT_MEMORY_ACCESS */
+
+#if defined(_MSC_VER) /* Visual Studio */
+# define XXH_swap64 _byteswap_uint64
+#elif XXH_GCC_VERSION >= 403
+# define XXH_swap64 __builtin_bswap64
+#else
+static xxh_u64 XXH_swap64 (xxh_u64 x)
+{
+ return ((x << 56) & 0xff00000000000000ULL) |
+ ((x << 40) & 0x00ff000000000000ULL) |
+ ((x << 24) & 0x0000ff0000000000ULL) |
+ ((x << 8) & 0x000000ff00000000ULL) |
+ ((x >> 8) & 0x00000000ff000000ULL) |
+ ((x >> 24) & 0x0000000000ff0000ULL) |
+ ((x >> 40) & 0x000000000000ff00ULL) |
+ ((x >> 56) & 0x00000000000000ffULL);
+}
+#endif
+
+
+/* XXH_FORCE_MEMORY_ACCESS==3 is an endian-independent byteshift load. */
+#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3))
+
+XXH_FORCE_INLINE xxh_u64 XXH_readLE64(const void* memPtr)
+{
+ const xxh_u8* bytePtr = (const xxh_u8 *)memPtr;
+ return bytePtr[0]
+ | ((xxh_u64)bytePtr[1] << 8)
+ | ((xxh_u64)bytePtr[2] << 16)
+ | ((xxh_u64)bytePtr[3] << 24)
+ | ((xxh_u64)bytePtr[4] << 32)
+ | ((xxh_u64)bytePtr[5] << 40)
+ | ((xxh_u64)bytePtr[6] << 48)
+ | ((xxh_u64)bytePtr[7] << 56);
+}
+
+XXH_FORCE_INLINE xxh_u64 XXH_readBE64(const void* memPtr)
+{
+ const xxh_u8* bytePtr = (const xxh_u8 *)memPtr;
+ return bytePtr[7]
+ | ((xxh_u64)bytePtr[6] << 8)
+ | ((xxh_u64)bytePtr[5] << 16)
+ | ((xxh_u64)bytePtr[4] << 24)
+ | ((xxh_u64)bytePtr[3] << 32)
+ | ((xxh_u64)bytePtr[2] << 40)
+ | ((xxh_u64)bytePtr[1] << 48)
+ | ((xxh_u64)bytePtr[0] << 56);
+}
+
+#else
+XXH_FORCE_INLINE xxh_u64 XXH_readLE64(const void* ptr)
+{
+ return XXH_CPU_LITTLE_ENDIAN ? XXH_read64(ptr) : XXH_swap64(XXH_read64(ptr));
+}
+
+static xxh_u64 XXH_readBE64(const void* ptr)
+{
+ return XXH_CPU_LITTLE_ENDIAN ? XXH_swap64(XXH_read64(ptr)) : XXH_read64(ptr);
+}
+#endif
+
+XXH_FORCE_INLINE xxh_u64
+XXH_readLE64_align(const void* ptr, XXH_alignment align)
+{
+ if (align==XXH_unaligned)
+ return XXH_readLE64(ptr);
+ else
+ return XXH_CPU_LITTLE_ENDIAN ? *(const xxh_u64*)ptr : XXH_swap64(*(const xxh_u64*)ptr);
+}
+
+
+/******* xxh64 *******/
+
+static const xxh_u64 XXH_PRIME64_1 = 0x9E3779B185EBCA87ULL; /* 0b1001111000110111011110011011000110000101111010111100101010000111 */
+static const xxh_u64 XXH_PRIME64_2 = 0xC2B2AE3D27D4EB4FULL; /* 0b1100001010110010101011100011110100100111110101001110101101001111 */
+static const xxh_u64 XXH_PRIME64_3 = 0x165667B19E3779F9ULL; /* 0b0001011001010110011001111011000110011110001101110111100111111001 */
+static const xxh_u64 XXH_PRIME64_4 = 0x85EBCA77C2B2AE63ULL; /* 0b1000010111101011110010100111011111000010101100101010111001100011 */
+static const xxh_u64 XXH_PRIME64_5 = 0x27D4EB2F165667C5ULL; /* 0b0010011111010100111010110010111100010110010101100110011111000101 */
+
+#ifdef XXH_OLD_NAMES
+# define PRIME64_1 XXH_PRIME64_1
+# define PRIME64_2 XXH_PRIME64_2
+# define PRIME64_3 XXH_PRIME64_3
+# define PRIME64_4 XXH_PRIME64_4
+# define PRIME64_5 XXH_PRIME64_5
+#endif
+
+static xxh_u64 XXH64_round(xxh_u64 acc, xxh_u64 input)
+{
+ acc += input * XXH_PRIME64_2;
+ acc = XXH_rotl64(acc, 31);
+ acc *= XXH_PRIME64_1;
+ return acc;
+}
+
+static xxh_u64 XXH64_mergeRound(xxh_u64 acc, xxh_u64 val)
+{
+ val = XXH64_round(0, val);
+ acc ^= val;
+ acc = acc * XXH_PRIME64_1 + XXH_PRIME64_4;
+ return acc;
+}
+
+static xxh_u64 XXH64_avalanche(xxh_u64 h64)
+{
+ h64 ^= h64 >> 33;
+ h64 *= XXH_PRIME64_2;
+ h64 ^= h64 >> 29;
+ h64 *= XXH_PRIME64_3;
+ h64 ^= h64 >> 32;
+ return h64;
+}
+
+
+#define XXH_get64bits(p) XXH_readLE64_align(p, align)
+
+static xxh_u64
+XXH64_finalize(xxh_u64 h64, const xxh_u8* ptr, size_t len, XXH_alignment align)
+{
+#define XXH_PROCESS1_64 do { \
+ h64 ^= (*ptr++) * XXH_PRIME64_5; \
+ h64 = XXH_rotl64(h64, 11) * XXH_PRIME64_1; \
+} while (0)
+
+#define XXH_PROCESS4_64 do { \
+ h64 ^= (xxh_u64)(XXH_get32bits(ptr)) * XXH_PRIME64_1; \
+ ptr += 4; \
+ h64 = XXH_rotl64(h64, 23) * XXH_PRIME64_2 + XXH_PRIME64_3; \
+} while (0)
+
+#define XXH_PROCESS8_64 do { \
+ xxh_u64 const k1 = XXH64_round(0, XXH_get64bits(ptr)); \
+ ptr += 8; \
+ h64 ^= k1; \
+ h64 = XXH_rotl64(h64,27) * XXH_PRIME64_1 + XXH_PRIME64_4; \
+} while (0)
+
+ /* Rerolled version for 32-bit targets is faster and much smaller. */
+ if (XXH_REROLL || XXH_REROLL_XXH64) {
+ len &= 31;
+ while (len >= 8) {
+ XXH_PROCESS8_64;
+ len -= 8;
+ }
+ if (len >= 4) {
+ XXH_PROCESS4_64;
+ len -= 4;
+ }
+ while (len > 0) {
+ XXH_PROCESS1_64;
+ --len;
+ }
+ return XXH64_avalanche(h64);
+ } else {
+ switch(len & 31) {
+ case 24: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 16: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 8: XXH_PROCESS8_64;
+ return XXH64_avalanche(h64);
+
+ case 28: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 20: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 12: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 4: XXH_PROCESS4_64;
+ return XXH64_avalanche(h64);
+
+ case 25: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 17: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 9: XXH_PROCESS8_64;
+ XXH_PROCESS1_64;
+ return XXH64_avalanche(h64);
+
+ case 29: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 21: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 13: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 5: XXH_PROCESS4_64;
+ XXH_PROCESS1_64;
+ return XXH64_avalanche(h64);
+
+ case 26: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 18: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 10: XXH_PROCESS8_64;
+ XXH_PROCESS1_64;
+ XXH_PROCESS1_64;
+ return XXH64_avalanche(h64);
+
+ case 30: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 22: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 14: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 6: XXH_PROCESS4_64;
+ XXH_PROCESS1_64;
+ XXH_PROCESS1_64;
+ return XXH64_avalanche(h64);
+
+ case 27: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 19: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 11: XXH_PROCESS8_64;
+ XXH_PROCESS1_64;
+ XXH_PROCESS1_64;
+ XXH_PROCESS1_64;
+ return XXH64_avalanche(h64);
+
+ case 31: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 23: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 15: XXH_PROCESS8_64;
+ /* fallthrough */
+ case 7: XXH_PROCESS4_64;
+ /* fallthrough */
+ case 3: XXH_PROCESS1_64;
+ /* fallthrough */
+ case 2: XXH_PROCESS1_64;
+ /* fallthrough */
+ case 1: XXH_PROCESS1_64;
+ /* fallthrough */
+ case 0: return XXH64_avalanche(h64);
+ }
+ }
+ /* impossible to reach */
+ XXH_ASSERT(0);
+ return 0; /* unreachable, but some compilers complain without it */
+}
+
+#ifdef XXH_OLD_NAMES
+# define PROCESS1_64 XXH_PROCESS1_64
+# define PROCESS4_64 XXH_PROCESS4_64
+# define PROCESS8_64 XXH_PROCESS8_64
+#else
+# undef XXH_PROCESS1_64
+# undef XXH_PROCESS4_64
+# undef XXH_PROCESS8_64
+#endif
+
+XXH_FORCE_INLINE xxh_u64
+XXH64_endian_align(const xxh_u8* input, size_t len, xxh_u64 seed, XXH_alignment align)
+{
+ const xxh_u8* bEnd = input + len;
+ xxh_u64 h64;
+
+#if defined(XXH_ACCEPT_NULL_INPUT_POINTER) && (XXH_ACCEPT_NULL_INPUT_POINTER>=1)
+ if (input==NULL) {
+ len=0;
+ bEnd=input=(const xxh_u8*)(size_t)32;
+ }
+#endif
+
+ if (len>=32) {
+ const xxh_u8* const limit = bEnd - 32;
+ xxh_u64 v1 = seed + XXH_PRIME64_1 + XXH_PRIME64_2;
+ xxh_u64 v2 = seed + XXH_PRIME64_2;
+ xxh_u64 v3 = seed + 0;
+ xxh_u64 v4 = seed - XXH_PRIME64_1;
+
+ do {
+ v1 = XXH64_round(v1, XXH_get64bits(input)); input+=8;
+ v2 = XXH64_round(v2, XXH_get64bits(input)); input+=8;
+ v3 = XXH64_round(v3, XXH_get64bits(input)); input+=8;
+ v4 = XXH64_round(v4, XXH_get64bits(input)); input+=8;
+ } while (input<=limit);
+
+ h64 = XXH_rotl64(v1, 1) + XXH_rotl64(v2, 7) + XXH_rotl64(v3, 12) + XXH_rotl64(v4, 18);
+ h64 = XXH64_mergeRound(h64, v1);
+ h64 = XXH64_mergeRound(h64, v2);
+ h64 = XXH64_mergeRound(h64, v3);
+ h64 = XXH64_mergeRound(h64, v4);
+
+ } else {
+ h64 = seed + XXH_PRIME64_5;
+ }
+
+ h64 += (xxh_u64) len;
+
+ return XXH64_finalize(h64, input, len, align);
+}
+
+
+XXH_PUBLIC_API XXH64_hash_t XXH64 (const void* input, size_t len, XXH64_hash_t seed)
+{
+#if 0
+ /* Simple version, good for code maintenance, but unfortunately slow for small inputs */
+ XXH64_state_t state;
+ XXH64_reset(&state, seed);
+ XXH64_update(&state, (const xxh_u8*)input, len);
+ return XXH64_digest(&state);
+
+#else
+
+ if (XXH_FORCE_ALIGN_CHECK) {
+ if ((((size_t)input) & 7)==0) { /* Input is aligned, let's leverage the speed advantage */
+ return XXH64_endian_align((const xxh_u8*)input, len, seed, XXH_aligned);
+ } }
+
+ return XXH64_endian_align((const xxh_u8*)input, len, seed, XXH_unaligned);
+
+#endif
+}
+
+/******* Hash Streaming *******/
+
+XXH_PUBLIC_API XXH64_state_t* XXH64_createState(void)
+{
+ return (XXH64_state_t*)XXH_malloc(sizeof(XXH64_state_t));
+}
+XXH_PUBLIC_API XXH_errorcode XXH64_freeState(XXH64_state_t* statePtr)
+{
+ XXH_free(statePtr);
+ return XXH_OK;
+}
+
+XXH_PUBLIC_API void XXH64_copyState(XXH64_state_t* dstState, const XXH64_state_t* srcState)
+{
+ memcpy(dstState, srcState, sizeof(*dstState));
+}
+
+XXH_PUBLIC_API XXH_errorcode XXH64_reset(XXH64_state_t* statePtr, XXH64_hash_t seed)
+{
+ XXH64_state_t state; /* use a local state to memcpy() in order to avoid strict-aliasing warnings */
+ memset(&state, 0, sizeof(state));
+ state.v1 = seed + XXH_PRIME64_1 + XXH_PRIME64_2;
+ state.v2 = seed + XXH_PRIME64_2;
+ state.v3 = seed + 0;
+ state.v4 = seed - XXH_PRIME64_1;
+ /* do not write into reserved64, might be removed in a future version */
+ memcpy(statePtr, &state, sizeof(state) - sizeof(state.reserved64));
+ return XXH_OK;
+}
+
+XXH_PUBLIC_API XXH_errorcode
+XXH64_update (XXH64_state_t* state, const void* input, size_t len)
+{
+ if (input==NULL)
+#if defined(XXH_ACCEPT_NULL_INPUT_POINTER) && (XXH_ACCEPT_NULL_INPUT_POINTER>=1)
+ return XXH_OK;
+#else
+ return XXH_ERROR;
+#endif
+
+ { const xxh_u8* p = (const xxh_u8*)input;
+ const xxh_u8* const bEnd = p + len;
+
+ state->total_len += len;
+
+ if (state->memsize + len < 32) { /* fill in tmp buffer */
+ XXH_memcpy(((xxh_u8*)state->mem64) + state->memsize, input, len);
+ state->memsize += (xxh_u32)len;
+ return XXH_OK;
+ }
+
+ if (state->memsize) { /* tmp buffer is full */
+ XXH_memcpy(((xxh_u8*)state->mem64) + state->memsize, input, 32-state->memsize);
+ state->v1 = XXH64_round(state->v1, XXH_readLE64(state->mem64+0));
+ state->v2 = XXH64_round(state->v2, XXH_readLE64(state->mem64+1));
+ state->v3 = XXH64_round(state->v3, XXH_readLE64(state->mem64+2));
+ state->v4 = XXH64_round(state->v4, XXH_readLE64(state->mem64+3));
+ p += 32-state->memsize;
+ state->memsize = 0;
+ }
+
+ if (p+32 <= bEnd) {
+ const xxh_u8* const limit = bEnd - 32;
+ xxh_u64 v1 = state->v1;
+ xxh_u64 v2 = state->v2;
+ xxh_u64 v3 = state->v3;
+ xxh_u64 v4 = state->v4;
+
+ do {
+ v1 = XXH64_round(v1, XXH_readLE64(p)); p+=8;
+ v2 = XXH64_round(v2, XXH_readLE64(p)); p+=8;
+ v3 = XXH64_round(v3, XXH_readLE64(p)); p+=8;
+ v4 = XXH64_round(v4, XXH_readLE64(p)); p+=8;
+ } while (p<=limit);
+
+ state->v1 = v1;
+ state->v2 = v2;
+ state->v3 = v3;
+ state->v4 = v4;
+ }
+
+ if (p < bEnd) {
+ XXH_memcpy(state->mem64, p, (size_t)(bEnd-p));
+ state->memsize = (unsigned)(bEnd-p);
+ }
+ }
+
+ return XXH_OK;
+}
+
+
+XXH_PUBLIC_API XXH64_hash_t XXH64_digest (const XXH64_state_t* state)
+{
+ xxh_u64 h64;
+
+ if (state->total_len >= 32) {
+ xxh_u64 const v1 = state->v1;
+ xxh_u64 const v2 = state->v2;
+ xxh_u64 const v3 = state->v3;
+ xxh_u64 const v4 = state->v4;
+
+ h64 = XXH_rotl64(v1, 1) + XXH_rotl64(v2, 7) + XXH_rotl64(v3, 12) + XXH_rotl64(v4, 18);
+ h64 = XXH64_mergeRound(h64, v1);
+ h64 = XXH64_mergeRound(h64, v2);
+ h64 = XXH64_mergeRound(h64, v3);
+ h64 = XXH64_mergeRound(h64, v4);
+ } else {
+ h64 = state->v3 /*seed*/ + XXH_PRIME64_5;
+ }
+
+ h64 += (xxh_u64) state->total_len;
+
+ return XXH64_finalize(h64, (const xxh_u8*)state->mem64, (size_t)state->total_len, XXH_aligned);
+}
+
+
+/******* Canonical representation *******/
+
+XXH_PUBLIC_API void XXH64_canonicalFromHash(XXH64_canonical_t* dst, XXH64_hash_t hash)
+{
+ XXH_STATIC_ASSERT(sizeof(XXH64_canonical_t) == sizeof(XXH64_hash_t));
+ if (XXH_CPU_LITTLE_ENDIAN) hash = XXH_swap64(hash);
+ memcpy(dst, &hash, sizeof(*dst));
+}
+
+XXH_PUBLIC_API XXH64_hash_t XXH64_hashFromCanonical(const XXH64_canonical_t* src)
+{
+ return XXH_readBE64(src);
+}
+
+
+
+/* *********************************************************************
+* XXH3
+* New generation hash designed for speed on small keys and vectorization
+************************************************************************ */
+
+/* === Compiler specifics === */
+
+/* Patch from https://github.com/Cyan4973/xxHash/pull/498 */
+#if ((defined(sun) || defined(__sun)) && __cplusplus) /* Solaris includes __STDC_VERSION__ with C++. Tested with GCC 5.5 */
+# define XXH_RESTRICT /* disable */
+#elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L /* >= C99 */
+# define XXH_RESTRICT restrict
+#else
+/* Note: it might be useful to define __restrict or __restrict__ for some C++ compilers */
+# define XXH_RESTRICT /* disable */
+#endif
+
+#if (defined(__GNUC__) && (__GNUC__ >= 3)) \
+ || (defined(__INTEL_COMPILER) && (__INTEL_COMPILER >= 800)) \
+ || defined(__clang__)
+# define XXH_likely(x) __builtin_expect(x, 1)
+# define XXH_unlikely(x) __builtin_expect(x, 0)
+#else
+# define XXH_likely(x) (x)
+# define XXH_unlikely(x) (x)
+#endif
+
+#if defined(__GNUC__)
+# if defined(__AVX2__)
+# include <immintrin.h>
+# elif defined(__SSE2__)
+# include <emmintrin.h>
+# elif defined(__ARM_NEON__) || defined(__ARM_NEON)
+# define inline __inline__ /* circumvent a clang bug */
+# include <arm_neon.h>
+# undef inline
+# endif
+#elif defined(_MSC_VER)
+# include <intrin.h>
+#endif
+
+/*
+ * One goal of XXH3 is to make it fast on both 32-bit and 64-bit, while
+ * remaining a true 64-bit/128-bit hash function.
+ *
+ * This is done by prioritizing a subset of 64-bit operations that can be
+ * emulated without too many steps on the average 32-bit machine.
+ *
+ * For example, these two lines seem similar, and run equally fast on 64-bit:
+ *
+ * xxh_u64 x;
+ * x ^= (x >> 47); // good
+ * x ^= (x >> 13); // bad
+ *
+ * However, to a 32-bit machine, there is a major difference.
+ *
+ * x ^= (x >> 47) looks like this:
+ *
+ * x.lo ^= (x.hi >> (47 - 32));
+ *
+ * while x ^= (x >> 13) looks like this:
+ *
+ * // note: funnel shifts are not usually cheap.
+ * x.lo ^= (x.lo >> 13) | (x.hi << (32 - 13));
+ * x.hi ^= (x.hi >> 13);
+ *
+ * The first one is significantly faster than the second, simply because the
+ * shift is larger than 32. This means:
+ * - All the bits we need are in the upper 32 bits, so we can ignore the lower
+ * 32 bits in the shift.
+ * - The shift result will always fit in the lower 32 bits, and therefore,
+ * we can ignore the upper 32 bits in the xor.
+ *
+ * Thanks to this optimization, XXH3 only requires these features to be efficient:
+ *
+ * - Usable unaligned access
+ * - A 32-bit or 64-bit ALU
+ * - If 32-bit, a decent ADC instruction
+ * - A 32 or 64-bit multiply with a 64-bit result
+ * - For the 128-bit variant, a decent byteswap helps short inputs.
+ *
+ * The first two are already required by XXH32, and almost all 32-bit and 64-bit
+ * platforms which can run XXH32 can run XXH3 efficiently.
+ *
+ * Thumb-1, the classic 16-bit only subset of ARM's instruction set, is one
+ * notable exception.
+ *
+ * First of all, Thumb-1 lacks support for the UMULL instruction which
+ * performs the important long multiply. This means numerous __aeabi_lmul
+ * calls.
+ *
+ * Second of all, the 8 functional registers are just not enough.
+ * Setup for __aeabi_lmul, byteshift loads, pointers, and all arithmetic need
+ * Lo registers, and this shuffling results in thousands more MOVs than A32.
+ *
+ * A32 and T32 don't have this limitation. They can access all 14 registers,
+ * do a 32->64 multiply with UMULL, and the flexible operand allowing free
+ * shifts is helpful, too.
+ *
+ * Therefore, we do a quick sanity check.
+ *
+ * If compiling Thumb-1 for a target which supports ARM instructions, we will
+ * emit a warning, as it is not a "sane" platform to compile for.
+ *
+ * Usually, if this happens, it is because of an accident and you probably need
+ * to specify -march, as you likely meant to compile for a newer architecture.
+ *
+ * Credit: large sections of the vectorial and asm source code paths
+ * have been contributed by @easyaspi314
+ */
+#if defined(__thumb__) && !defined(__thumb2__) && defined(__ARM_ARCH_ISA_ARM)
+# warning "XXH3 is highly inefficient without ARM or Thumb-2."
+#endif
+
+/* ==========================================
+ * Vectorization detection
+ * ========================================== */
+#define XXH_SCALAR 0 /* Portable scalar version */
+#define XXH_SSE2 1 /* SSE2 for Pentium 4 and all x86_64 */
+#define XXH_AVX2 2 /* AVX2 for Haswell and Bulldozer */
+#define XXH_AVX512 3 /* AVX512 for Skylake and Icelake */
+#define XXH_NEON 4 /* NEON for most ARMv7-A and all AArch64 */
+#define XXH_VSX 5 /* VSX and ZVector for POWER8/z13 */
+
+#ifndef XXH_VECTOR /* can be defined on command line */
+# if defined(__AVX512F__)
+# define XXH_VECTOR XXH_AVX512
+# elif defined(__AVX2__)
+# define XXH_VECTOR XXH_AVX2
+# elif defined(__SSE2__) || defined(_M_AMD64) || defined(_M_X64) || (defined(_M_IX86_FP) && (_M_IX86_FP == 2))
+# define XXH_VECTOR XXH_SSE2
+# elif defined(__GNUC__) /* msvc support maybe later */ \
+ && (defined(__ARM_NEON__) || defined(__ARM_NEON)) \
+ && (defined(__LITTLE_ENDIAN__) /* We only support little endian NEON */ \
+ || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__))
+# define XXH_VECTOR XXH_NEON
+# elif (defined(__PPC64__) && defined(__POWER8_VECTOR__)) \
+ || (defined(__s390x__) && defined(__VEC__)) \
+ && defined(__GNUC__) /* TODO: IBM XL */
+# define XXH_VECTOR XXH_VSX
+# else
+# define XXH_VECTOR XXH_SCALAR
+# endif
+#endif
+
+/*
+ * Controls the alignment of the accumulator,
+ * for compatibility with aligned vector loads, which are usually faster.
+ */
+#ifndef XXH_ACC_ALIGN
+# if defined(XXH_X86DISPATCH)
+# define XXH_ACC_ALIGN 64 /* for compatibility with avx512 */
+# elif XXH_VECTOR == XXH_SCALAR /* scalar */
+# define XXH_ACC_ALIGN 8
+# elif XXH_VECTOR == XXH_SSE2 /* sse2 */
+# define XXH_ACC_ALIGN 16
+# elif XXH_VECTOR == XXH_AVX2 /* avx2 */
+# define XXH_ACC_ALIGN 32
+# elif XXH_VECTOR == XXH_NEON /* neon */
+# define XXH_ACC_ALIGN 16
+# elif XXH_VECTOR == XXH_VSX /* vsx */
+# define XXH_ACC_ALIGN 16
+# elif XXH_VECTOR == XXH_AVX512 /* avx512 */
+# define XXH_ACC_ALIGN 64
+# endif
+#endif
+
+#if defined(XXH_X86DISPATCH) || XXH_VECTOR == XXH_SSE2 \
+ || XXH_VECTOR == XXH_AVX2 || XXH_VECTOR == XXH_AVX512
+# define XXH_SEC_ALIGN XXH_ACC_ALIGN
+#else
+# define XXH_SEC_ALIGN 8
+#endif
+
+/*
+ * UGLY HACK:
+ * GCC usually generates the best code with -O3 for xxHash.
+ *
+ * However, when targeting AVX2, it is overzealous in its unrolling resulting
+ * in code roughly 3/4 the speed of Clang.
+ *
+ * There are other issues, such as GCC splitting _mm256_loadu_si256 into
+ * _mm_loadu_si128 + _mm256_inserti128_si256. This is an optimization which
+ * only applies to Sandy and Ivy Bridge... which don't even support AVX2.
+ *
+ * That is why when compiling the AVX2 version, it is recommended to use either
+ * -O2 -mavx2 -march=haswell
+ * or
+ * -O2 -mavx2 -mno-avx256-split-unaligned-load
+ * for decent performance, or to use Clang instead.
+ *
+ * Fortunately, we can control the first one with a pragma that forces GCC into
+ * -O2, but the other one we can't control without "failed to inline always
+ * inline function due to target mismatch" warnings.
+ */
+#if XXH_VECTOR == XXH_AVX2 /* AVX2 */ \
+ && defined(__GNUC__) && !defined(__clang__) /* GCC, not Clang */ \
+ && defined(__OPTIMIZE__) && !defined(__OPTIMIZE_SIZE__) /* respect -O0 and -Os */
+# pragma GCC push_options
+# pragma GCC optimize("-O2")
+#endif
+
+
+#if XXH_VECTOR == XXH_NEON
+/*
+ * NEON's setup for vmlal_u32 is a little more complicated than it is on
+ * SSE2, AVX2, and VSX.
+ *
+ * While PMULUDQ and VMULEUW both perform a mask, VMLAL.U32 performs an upcast.
+ *
+ * To do the same operation, the 128-bit 'Q' register needs to be split into
+ * two 64-bit 'D' registers, performing this operation::
+ *
+ * [ a | b ]
+ * | '---------. .--------' |
+ * | x |
+ * | .---------' '--------. |
+ * [ a & 0xFFFFFFFF | b & 0xFFFFFFFF ],[ a >> 32 | b >> 32 ]
+ *
+ * Due to significant changes in aarch64, the fastest method for aarch64 is
+ * completely different than the fastest method for ARMv7-A.
+ *
+ * ARMv7-A treats D registers as unions overlaying Q registers, so modifying
+ * D11 will modify the high half of Q5. This is similar to how modifying AH
+ * will only affect bits 8-15 of AX on x86.
+ *
+ * VZIP takes two registers, and puts even lanes in one register and odd lanes
+ * in the other.
+ *
+ * On ARMv7-A, this strangely modifies both parameters in place instead of
+ * taking the usual 3-operand form.
+ *
+ * Therefore, if we want to do this, we can simply use a D-form VZIP.32 on the
+ * lower and upper halves of the Q register to end up with the high and low
+ * halves where we want - all in one instruction.
+ *
+ * vzip.32 d10, d11 @ d10 = { d10[0], d11[0] }; d11 = { d10[1], d11[1] }
+ *
+ * Unfortunately we need inline assembly for this: Instructions modifying two
+ * registers at once is not possible in GCC or Clang's IR, and they have to
+ * create a copy.
+ *
+ * aarch64 requires a different approach.
+ *
+ * In order to make it easier to write a decent compiler for aarch64, many
+ * quirks were removed, such as conditional execution.
+ *
+ * NEON was also affected by this.
+ *
+ * aarch64 cannot access the high bits of a Q-form register, and writes to a
+ * D-form register zero the high bits, similar to how writes to W-form scalar
+ * registers (or DWORD registers on x86_64) work.
+ *
+ * The formerly free vget_high intrinsics now require a vext (with a few
+ * exceptions)
+ *
+ * Additionally, VZIP was replaced by ZIP1 and ZIP2, which are the equivalent
+ * of PUNPCKL* and PUNPCKH* in SSE, respectively, in order to only modify one
+ * operand.
+ *
+ * The equivalent of the VZIP.32 on the lower and upper halves would be this
+ * mess:
+ *
+ * ext v2.4s, v0.4s, v0.4s, #2 // v2 = { v0[2], v0[3], v0[0], v0[1] }
+ * zip1 v1.2s, v0.2s, v2.2s // v1 = { v0[0], v2[0] }
+ * zip2 v0.2s, v0.2s, v1.2s // v0 = { v0[1], v2[1] }
+ *
+ * Instead, we use a literal downcast, vmovn_u64 (XTN), and vshrn_n_u64 (SHRN):
+ *
+ * shrn v1.2s, v0.2d, #32 // v1 = (uint32x2_t)(v0 >> 32);
+ * xtn v0.2s, v0.2d // v0 = (uint32x2_t)(v0 & 0xFFFFFFFF);
+ *
+ * This is available on ARMv7-A, but is less efficient than a single VZIP.32.
+ */
+
+/*
+ * Function-like macro:
+ * void XXH_SPLIT_IN_PLACE(uint64x2_t &in, uint32x2_t &outLo, uint32x2_t &outHi)
+ * {
+ * outLo = (uint32x2_t)(in & 0xFFFFFFFF);
+ * outHi = (uint32x2_t)(in >> 32);
+ * in = UNDEFINED;
+ * }
+ */
+# if !defined(XXH_NO_VZIP_HACK) /* define to disable */ \
+ && defined(__GNUC__) \
+ && !defined(__aarch64__) && !defined(__arm64__)
+# define XXH_SPLIT_IN_PLACE(in, outLo, outHi) \
+ do { \
+ /* Undocumented GCC/Clang operand modifier: %e0 = lower D half, %f0 = upper D half */ \
+ /* https://github.com/gcc-mirror/gcc/blob/38cf91e5/gcc/config/arm/arm.c#L22486 */ \
+ /* https://github.com/llvm-mirror/llvm/blob/2c4ca683/lib/Target/ARM/ARMAsmPrinter.cpp#L399 */ \
+ __asm__("vzip.32 %e0, %f0" : "+w" (in)); \
+ (outLo) = vget_low_u32 (vreinterpretq_u32_u64(in)); \
+ (outHi) = vget_high_u32(vreinterpretq_u32_u64(in)); \
+ } while (0)
+# else
+# define XXH_SPLIT_IN_PLACE(in, outLo, outHi) \
+ do { \
+ (outLo) = vmovn_u64 (in); \
+ (outHi) = vshrn_n_u64 ((in), 32); \
+ } while (0)
+# endif
+#endif /* XXH_VECTOR == XXH_NEON */
+
+/*
+ * VSX and Z Vector helpers.
+ *
+ * This is very messy, and any pull requests to clean this up are welcome.
+ *
+ * There are a lot of problems with supporting VSX and s390x, due to
+ * inconsistent intrinsics, spotty coverage, and multiple endiannesses.
+ */
+#if XXH_VECTOR == XXH_VSX
+# if defined(__s390x__)
+# include <s390intrin.h>
+# else
+/* gcc's altivec.h can have the unwanted consequence to unconditionally
+ * #define bool, vector, and pixel keywords,
+ * with bad consequences for programs already using these keywords for other purposes.
+ * The paragraph defining these macros is skipped when __APPLE_ALTIVEC__ is defined.
+ * __APPLE_ALTIVEC__ is _generally_ defined automatically by the compiler,
+ * but it seems that, in some cases, it isn't.
+ * Force the build macro to be defined, so that keywords are not altered.
+ */
+# if defined(__GNUC__) && !defined(__APPLE_ALTIVEC__)
+# define __APPLE_ALTIVEC__
+# endif
+# include <altivec.h>
+# endif
+
+typedef __vector unsigned long long xxh_u64x2;
+typedef __vector unsigned char xxh_u8x16;
+typedef __vector unsigned xxh_u32x4;
+
+# ifndef XXH_VSX_BE
+# if defined(__BIG_ENDIAN__) \
+ || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+# define XXH_VSX_BE 1
+# elif defined(__VEC_ELEMENT_REG_ORDER__) && __VEC_ELEMENT_REG_ORDER__ == __ORDER_BIG_ENDIAN__
+# warning "-maltivec=be is not recommended. Please use native endianness."
+# define XXH_VSX_BE 1
+# else
+# define XXH_VSX_BE 0
+# endif
+# endif /* !defined(XXH_VSX_BE) */
+
+# if XXH_VSX_BE
+/* A wrapper for POWER9's vec_revb. */
+# if defined(__POWER9_VECTOR__) || (defined(__clang__) && defined(__s390x__))
+# define XXH_vec_revb vec_revb
+# else
+XXH_FORCE_INLINE xxh_u64x2 XXH_vec_revb(xxh_u64x2 val)
+{
+ xxh_u8x16 const vByteSwap = { 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, 0x00,
+ 0x0F, 0x0E, 0x0D, 0x0C, 0x0B, 0x0A, 0x09, 0x08 };
+ return vec_perm(val, val, vByteSwap);
+}
+# endif
+# endif /* XXH_VSX_BE */
+
+/*
+ * Performs an unaligned load and byte swaps it on big endian.
+ */
+XXH_FORCE_INLINE xxh_u64x2 XXH_vec_loadu(const void *ptr)
+{
+ xxh_u64x2 ret;
+ memcpy(&ret, ptr, sizeof(xxh_u64x2));
+# if XXH_VSX_BE
+ ret = XXH_vec_revb(ret);
+# endif
+ return ret;
+}
+
+/*
+ * vec_mulo and vec_mule are very problematic intrinsics on PowerPC
+ *
+ * These intrinsics weren't added until GCC 8, despite existing for a while,
+ * and they are endian dependent. Also, their meaning swap depending on version.
+ * */
+# if defined(__s390x__)
+ /* s390x is always big endian, no issue on this platform */
+# define XXH_vec_mulo vec_mulo
+# define XXH_vec_mule vec_mule
+# elif defined(__clang__) && XXH_HAS_BUILTIN(__builtin_altivec_vmuleuw)
+/* Clang has a better way to control this, we can just use the builtin which doesn't swap. */
+# define XXH_vec_mulo __builtin_altivec_vmulouw
+# define XXH_vec_mule __builtin_altivec_vmuleuw
+# else
+/* gcc needs inline assembly */
+/* Adapted from https://github.com/google/highwayhash/blob/master/highwayhash/hh_vsx.h. */
+XXH_FORCE_INLINE xxh_u64x2 XXH_vec_mulo(xxh_u32x4 a, xxh_u32x4 b)
+{
+ xxh_u64x2 result;
+ __asm__("vmulouw %0, %1, %2" : "=v" (result) : "v" (a), "v" (b));
+ return result;
+}
+XXH_FORCE_INLINE xxh_u64x2 XXH_vec_mule(xxh_u32x4 a, xxh_u32x4 b)
+{
+ xxh_u64x2 result;
+ __asm__("vmuleuw %0, %1, %2" : "=v" (result) : "v" (a), "v" (b));
+ return result;
+}
+# endif /* XXH_vec_mulo, XXH_vec_mule */
+#endif /* XXH_VECTOR == XXH_VSX */
+
+
+/* prefetch
+ * can be disabled, by declaring XXH_NO_PREFETCH build macro */
+#if defined(XXH_NO_PREFETCH)
+# define XXH_PREFETCH(ptr) (void)(ptr) /* disabled */
+#else
+# if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_I86)) /* _mm_prefetch() is not defined outside of x86/x64 */
+# include <mmintrin.h> /* https://msdn.microsoft.com/fr-fr/library/84szxsww(v=vs.90).aspx */
+# define XXH_PREFETCH(ptr) _mm_prefetch((const char*)(ptr), _MM_HINT_T0)
+# elif defined(__GNUC__) && ( (__GNUC__ >= 4) || ( (__GNUC__ == 3) && (__GNUC_MINOR__ >= 1) ) )
+# define XXH_PREFETCH(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */)
+# else
+# define XXH_PREFETCH(ptr) (void)(ptr) /* disabled */
+# endif
+#endif /* XXH_NO_PREFETCH */
+
+
+/* ==========================================
+ * XXH3 default settings
+ * ========================================== */
+
+#define XXH_SECRET_DEFAULT_SIZE 192 /* minimum XXH3_SECRET_SIZE_MIN */
+
+#if (XXH_SECRET_DEFAULT_SIZE < XXH3_SECRET_SIZE_MIN)
+# error "default keyset is not large enough"
+#endif
+
+/* Pseudorandom secret taken directly from FARSH */
+XXH_ALIGN(64) static const xxh_u8 XXH3_kSecret[XXH_SECRET_DEFAULT_SIZE] = {
+ 0xb8, 0xfe, 0x6c, 0x39, 0x23, 0xa4, 0x4b, 0xbe, 0x7c, 0x01, 0x81, 0x2c, 0xf7, 0x21, 0xad, 0x1c,
+ 0xde, 0xd4, 0x6d, 0xe9, 0x83, 0x90, 0x97, 0xdb, 0x72, 0x40, 0xa4, 0xa4, 0xb7, 0xb3, 0x67, 0x1f,
+ 0xcb, 0x79, 0xe6, 0x4e, 0xcc, 0xc0, 0xe5, 0x78, 0x82, 0x5a, 0xd0, 0x7d, 0xcc, 0xff, 0x72, 0x21,
+ 0xb8, 0x08, 0x46, 0x74, 0xf7, 0x43, 0x24, 0x8e, 0xe0, 0x35, 0x90, 0xe6, 0x81, 0x3a, 0x26, 0x4c,
+ 0x3c, 0x28, 0x52, 0xbb, 0x91, 0xc3, 0x00, 0xcb, 0x88, 0xd0, 0x65, 0x8b, 0x1b, 0x53, 0x2e, 0xa3,
+ 0x71, 0x64, 0x48, 0x97, 0xa2, 0x0d, 0xf9, 0x4e, 0x38, 0x19, 0xef, 0x46, 0xa9, 0xde, 0xac, 0xd8,
+ 0xa8, 0xfa, 0x76, 0x3f, 0xe3, 0x9c, 0x34, 0x3f, 0xf9, 0xdc, 0xbb, 0xc7, 0xc7, 0x0b, 0x4f, 0x1d,
+ 0x8a, 0x51, 0xe0, 0x4b, 0xcd, 0xb4, 0x59, 0x31, 0xc8, 0x9f, 0x7e, 0xc9, 0xd9, 0x78, 0x73, 0x64,
+ 0xea, 0xc5, 0xac, 0x83, 0x34, 0xd3, 0xeb, 0xc3, 0xc5, 0x81, 0xa0, 0xff, 0xfa, 0x13, 0x63, 0xeb,
+ 0x17, 0x0d, 0xdd, 0x51, 0xb7, 0xf0, 0xda, 0x49, 0xd3, 0x16, 0x55, 0x26, 0x29, 0xd4, 0x68, 0x9e,
+ 0x2b, 0x16, 0xbe, 0x58, 0x7d, 0x47, 0xa1, 0xfc, 0x8f, 0xf8, 0xb8, 0xd1, 0x7a, 0xd0, 0x31, 0xce,
+ 0x45, 0xcb, 0x3a, 0x8f, 0x95, 0x16, 0x04, 0x28, 0xaf, 0xd7, 0xfb, 0xca, 0xbb, 0x4b, 0x40, 0x7e,
+};
+
+
+#ifdef XXH_OLD_NAMES
+# define kSecret XXH3_kSecret
+#endif
+
+/*
+ * Calculates a 32-bit to 64-bit long multiply.
+ *
+ * Wraps __emulu on MSVC x86 because it tends to call __allmul when it doesn't
+ * need to (but it shouldn't need to anyways, it is about 7 instructions to do
+ * a 64x64 multiply...). Since we know that this will _always_ emit MULL, we
+ * use that instead of the normal method.
+ *
+ * If you are compiling for platforms like Thumb-1 and don't have a better option,
+ * you may also want to write your own long multiply routine here.
+ *
+ * XXH_FORCE_INLINE xxh_u64 XXH_mult32to64(xxh_u64 x, xxh_u64 y)
+ * {
+ * return (x & 0xFFFFFFFF) * (y & 0xFFFFFFFF);
+ * }
+ */
+#if defined(_MSC_VER) && defined(_M_IX86)
+# include <intrin.h>
+# define XXH_mult32to64(x, y) __emulu((unsigned)(x), (unsigned)(y))
+#else
+/*
+ * Downcast + upcast is usually better than masking on older compilers like
+ * GCC 4.2 (especially 32-bit ones), all without affecting newer compilers.
+ *
+ * The other method, (x & 0xFFFFFFFF) * (y & 0xFFFFFFFF), will AND both operands
+ * and perform a full 64x64 multiply -- entirely redundant on 32-bit.
+ */
+# define XXH_mult32to64(x, y) ((xxh_u64)(xxh_u32)(x) * (xxh_u64)(xxh_u32)(y))
+#endif
+
+/*
+ * Calculates a 64->128-bit long multiply.
+ *
+ * Uses __uint128_t and _umul128 if available, otherwise uses a scalar version.
+ */
+static XXH128_hash_t
+XXH_mult64to128(xxh_u64 lhs, xxh_u64 rhs)
+{
+ /*
+ * GCC/Clang __uint128_t method.
+ *
+ * On most 64-bit targets, GCC and Clang define a __uint128_t type.
+ * This is usually the best way as it usually uses a native long 64-bit
+ * multiply, such as MULQ on x86_64 or MUL + UMULH on aarch64.
+ *
+ * Usually.
+ *
+ * Despite being a 32-bit platform, Clang (and emscripten) define this type
+ * despite not having the arithmetic for it. This results in a laggy
+ * compiler builtin call which calculates a full 128-bit multiply.
+ * In that case it is best to use the portable one.
+ * https://github.com/Cyan4973/xxHash/issues/211#issuecomment-515575677
+ */
+#if defined(__GNUC__) && !defined(__wasm__) \
+ && defined(__SIZEOF_INT128__) \
+ || (defined(_INTEGRAL_MAX_BITS) && _INTEGRAL_MAX_BITS >= 128)
+
+ __uint128_t const product = (__uint128_t)lhs * (__uint128_t)rhs;
+ XXH128_hash_t r128;
+ r128.low64 = (xxh_u64)(product);
+ r128.high64 = (xxh_u64)(product >> 64);
+ return r128;
+
+ /*
+ * MSVC for x64's _umul128 method.
+ *
+ * xxh_u64 _umul128(xxh_u64 Multiplier, xxh_u64 Multiplicand, xxh_u64 *HighProduct);
+ *
+ * This compiles to single operand MUL on x64.
+ */
+#elif defined(_M_X64) || defined(_M_IA64)
+
+#ifndef _MSC_VER
+# pragma intrinsic(_umul128)
+#endif
+ xxh_u64 product_high;
+ xxh_u64 const product_low = _umul128(lhs, rhs, &product_high);
+ XXH128_hash_t r128;
+ r128.low64 = product_low;
+ r128.high64 = product_high;
+ return r128;
+
+#else
+ /*
+ * Portable scalar method. Optimized for 32-bit and 64-bit ALUs.
+ *
+ * This is a fast and simple grade school multiply, which is shown below
+ * with base 10 arithmetic instead of base 0x100000000.
+ *
+ * 9 3 // D2 lhs = 93
+ * x 7 5 // D2 rhs = 75
+ * ----------
+ * 1 5 // D2 lo_lo = (93 % 10) * (75 % 10) = 15
+ * 4 5 | // D2 hi_lo = (93 / 10) * (75 % 10) = 45
+ * 2 1 | // D2 lo_hi = (93 % 10) * (75 / 10) = 21
+ * + 6 3 | | // D2 hi_hi = (93 / 10) * (75 / 10) = 63
+ * ---------
+ * 2 7 | // D2 cross = (15 / 10) + (45 % 10) + 21 = 27
+ * + 6 7 | | // D2 upper = (27 / 10) + (45 / 10) + 63 = 67
+ * ---------
+ * 6 9 7 5 // D4 res = (27 * 10) + (15 % 10) + (67 * 100) = 6975
+ *
+ * The reasons for adding the products like this are:
+ * 1. It avoids manual carry tracking. Just like how
+ * (9 * 9) + 9 + 9 = 99, the same applies with this for UINT64_MAX.
+ * This avoids a lot of complexity.
+ *
+ * 2. It hints for, and on Clang, compiles to, the powerful UMAAL
+ * instruction available in ARM's Digital Signal Processing extension
+ * in 32-bit ARMv6 and later, which is shown below:
+ *
+ * void UMAAL(xxh_u32 *RdLo, xxh_u32 *RdHi, xxh_u32 Rn, xxh_u32 Rm)
+ * {
+ * xxh_u64 product = (xxh_u64)*RdLo * (xxh_u64)*RdHi + Rn + Rm;
+ * *RdLo = (xxh_u32)(product & 0xFFFFFFFF);
+ * *RdHi = (xxh_u32)(product >> 32);
+ * }
+ *
+ * This instruction was designed for efficient long multiplication, and
+ * allows this to be calculated in only 4 instructions at speeds
+ * comparable to some 64-bit ALUs.
+ *
+ * 3. It isn't terrible on other platforms. Usually this will be a couple
+ * of 32-bit ADD/ADCs.
+ */
+
+ /* First calculate all of the cross products. */
+ xxh_u64 const lo_lo = XXH_mult32to64(lhs & 0xFFFFFFFF, rhs & 0xFFFFFFFF);
+ xxh_u64 const hi_lo = XXH_mult32to64(lhs >> 32, rhs & 0xFFFFFFFF);
+ xxh_u64 const lo_hi = XXH_mult32to64(lhs & 0xFFFFFFFF, rhs >> 32);
+ xxh_u64 const hi_hi = XXH_mult32to64(lhs >> 32, rhs >> 32);
+
+ /* Now add the products together. These will never overflow. */
+ xxh_u64 const cross = (lo_lo >> 32) + (hi_lo & 0xFFFFFFFF) + lo_hi;
+ xxh_u64 const upper = (hi_lo >> 32) + (cross >> 32) + hi_hi;
+ xxh_u64 const lower = (cross << 32) | (lo_lo & 0xFFFFFFFF);
+
+ XXH128_hash_t r128;
+ r128.low64 = lower;
+ r128.high64 = upper;
+ return r128;
+#endif
+}
+
+/*
+ * Does a 64-bit to 128-bit multiply, then XOR folds it.
+ *
+ * The reason for the separate function is to prevent passing too many structs
+ * around by value. This will hopefully inline the multiply, but we don't force it.
+ */
+static xxh_u64
+XXH3_mul128_fold64(xxh_u64 lhs, xxh_u64 rhs)
+{
+ XXH128_hash_t product = XXH_mult64to128(lhs, rhs);
+ return product.low64 ^ product.high64;
+}
+
+/* Seems to produce slightly better code on GCC for some reason. */
+XXH_FORCE_INLINE xxh_u64 XXH_xorshift64(xxh_u64 v64, int shift)
+{
+ XXH_ASSERT(0 <= shift && shift < 64);
+ return v64 ^ (v64 >> shift);
+}
+
+/*
+ * This is a fast avalanche stage,
+ * suitable when input bits are already partially mixed
+ */
+static XXH64_hash_t XXH3_avalanche(xxh_u64 h64)
+{
+ h64 = XXH_xorshift64(h64, 37);
+ h64 *= 0x165667919E3779F9ULL;
+ h64 = XXH_xorshift64(h64, 32);
+ return h64;
+}
+
+/*
+ * This is a stronger avalanche,
+ * inspired by Pelle Evensen's rrmxmx
+ * preferable when input has not been previously mixed
+ */
+static XXH64_hash_t XXH3_rrmxmx(xxh_u64 h64, xxh_u64 len)
+{
+ /* this mix is inspired by Pelle Evensen's rrmxmx */
+ h64 ^= XXH_rotl64(h64, 49) ^ XXH_rotl64(h64, 24);
+ h64 *= 0x9FB21C651E98DF25ULL;
+ h64 ^= (h64 >> 35) + len ;
+ h64 *= 0x9FB21C651E98DF25ULL;
+ return XXH_xorshift64(h64, 28);
+}
+
+
+/* ==========================================
+ * Short keys
+ * ==========================================
+ * One of the shortcomings of XXH32 and XXH64 was that their performance was
+ * sub-optimal on short lengths. It used an iterative algorithm which strongly
+ * favored lengths that were a multiple of 4 or 8.
+ *
+ * Instead of iterating over individual inputs, we use a set of single shot
+ * functions which piece together a range of lengths and operate in constant time.
+ *
+ * Additionally, the number of multiplies has been significantly reduced. This
+ * reduces latency, especially when emulating 64-bit multiplies on 32-bit.
+ *
+ * Depending on the platform, this may or may not be faster than XXH32, but it
+ * is almost guaranteed to be faster than XXH64.
+ */
+
+/*
+ * At very short lengths, there isn't enough input to fully hide secrets, or use
+ * the entire secret.
+ *
+ * There is also only a limited amount of mixing we can do before significantly
+ * impacting performance.
+ *
+ * Therefore, we use different sections of the secret and always mix two secret
+ * samples with an XOR. This should have no effect on performance on the
+ * seedless or withSeed variants because everything _should_ be constant folded
+ * by modern compilers.
+ *
+ * The XOR mixing hides individual parts of the secret and increases entropy.
+ *
+ * This adds an extra layer of strength for custom secrets.
+ */
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_len_1to3_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ XXH_ASSERT(input != NULL);
+ XXH_ASSERT(1 <= len && len <= 3);
+ XXH_ASSERT(secret != NULL);
+ /*
+ * len = 1: combined = { input[0], 0x01, input[0], input[0] }
+ * len = 2: combined = { input[1], 0x02, input[0], input[1] }
+ * len = 3: combined = { input[2], 0x03, input[0], input[1] }
+ */
+ { xxh_u8 const c1 = input[0];
+ xxh_u8 const c2 = input[len >> 1];
+ xxh_u8 const c3 = input[len - 1];
+ xxh_u32 const combined = ((xxh_u32)c1 << 16) | ((xxh_u32)c2 << 24)
+ | ((xxh_u32)c3 << 0) | ((xxh_u32)len << 8);
+ xxh_u64 const bitflip = (XXH_readLE32(secret) ^ XXH_readLE32(secret+4)) + seed;
+ xxh_u64 const keyed = (xxh_u64)combined ^ bitflip;
+ return XXH64_avalanche(keyed);
+ }
+}
+
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_len_4to8_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ XXH_ASSERT(input != NULL);
+ XXH_ASSERT(secret != NULL);
+ XXH_ASSERT(4 <= len && len < 8);
+ seed ^= (xxh_u64)XXH_swap32((xxh_u32)seed) << 32;
+ { xxh_u32 const input1 = XXH_readLE32(input);
+ xxh_u32 const input2 = XXH_readLE32(input + len - 4);
+ xxh_u64 const bitflip = (XXH_readLE64(secret+8) ^ XXH_readLE64(secret+16)) - seed;
+ xxh_u64 const input64 = input2 + (((xxh_u64)input1) << 32);
+ xxh_u64 const keyed = input64 ^ bitflip;
+ return XXH3_rrmxmx(keyed, len);
+ }
+}
+
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_len_9to16_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ XXH_ASSERT(input != NULL);
+ XXH_ASSERT(secret != NULL);
+ XXH_ASSERT(8 <= len && len <= 16);
+ { xxh_u64 const bitflip1 = (XXH_readLE64(secret+24) ^ XXH_readLE64(secret+32)) + seed;
+ xxh_u64 const bitflip2 = (XXH_readLE64(secret+40) ^ XXH_readLE64(secret+48)) - seed;
+ xxh_u64 const input_lo = XXH_readLE64(input) ^ bitflip1;
+ xxh_u64 const input_hi = XXH_readLE64(input + len - 8) ^ bitflip2;
+ xxh_u64 const acc = len
+ + XXH_swap64(input_lo) + input_hi
+ + XXH3_mul128_fold64(input_lo, input_hi);
+ return XXH3_avalanche(acc);
+ }
+}
+
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_len_0to16_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ XXH_ASSERT(len <= 16);
+ { if (XXH_likely(len > 8)) return XXH3_len_9to16_64b(input, len, secret, seed);
+ if (XXH_likely(len >= 4)) return XXH3_len_4to8_64b(input, len, secret, seed);
+ if (len) return XXH3_len_1to3_64b(input, len, secret, seed);
+ return XXH64_avalanche(seed ^ (XXH_readLE64(secret+56) ^ XXH_readLE64(secret+64)));
+ }
+}
+
+/*
+ * DISCLAIMER: There are known *seed-dependent* multicollisions here due to
+ * multiplication by zero, affecting hashes of lengths 17 to 240.
+ *
+ * However, they are very unlikely.
+ *
+ * Keep this in mind when using the unseeded XXH3_64bits() variant: As with all
+ * unseeded non-cryptographic hashes, it does not attempt to defend itself
+ * against specially crafted inputs, only random inputs.
+ *
+ * Compared to classic UMAC where a 1 in 2^31 chance of 4 consecutive bytes
+ * cancelling out the secret is taken an arbitrary number of times (addressed
+ * in XXH3_accumulate_512), this collision is very unlikely with random inputs
+ * and/or proper seeding:
+ *
+ * This only has a 1 in 2^63 chance of 8 consecutive bytes cancelling out, in a
+ * function that is only called up to 16 times per hash with up to 240 bytes of
+ * input.
+ *
+ * This is not too bad for a non-cryptographic hash function, especially with
+ * only 64 bit outputs.
+ *
+ * The 128-bit variant (which trades some speed for strength) is NOT affected
+ * by this, although it is always a good idea to use a proper seed if you care
+ * about strength.
+ */
+XXH_FORCE_INLINE xxh_u64 XXH3_mix16B(const xxh_u8* XXH_RESTRICT input,
+ const xxh_u8* XXH_RESTRICT secret, xxh_u64 seed64)
+{
+#if defined(__GNUC__) && !defined(__clang__) /* GCC, not Clang */ \
+ && defined(__i386__) && defined(__SSE2__) /* x86 + SSE2 */ \
+ && !defined(XXH_ENABLE_AUTOVECTORIZE) /* Define to disable like XXH32 hack */
+ /*
+ * UGLY HACK:
+ * GCC for x86 tends to autovectorize the 128-bit multiply, resulting in
+ * slower code.
+ *
+ * By forcing seed64 into a register, we disrupt the cost model and
+ * cause it to scalarize. See `XXH32_round()`
+ *
+ * FIXME: Clang's output is still _much_ faster -- On an AMD Ryzen 3600,
+ * XXH3_64bits @ len=240 runs at 4.6 GB/s with Clang 9, but 3.3 GB/s on
+ * GCC 9.2, despite both emitting scalar code.
+ *
+ * GCC generates much better scalar code than Clang for the rest of XXH3,
+ * which is why finding a more optimal codepath is an interest.
+ */
+ __asm__ ("" : "+r" (seed64));
+#endif
+ { xxh_u64 const input_lo = XXH_readLE64(input);
+ xxh_u64 const input_hi = XXH_readLE64(input+8);
+ return XXH3_mul128_fold64(
+ input_lo ^ (XXH_readLE64(secret) + seed64),
+ input_hi ^ (XXH_readLE64(secret+8) - seed64)
+ );
+ }
+}
+
+/* For mid range keys, XXH3 uses a Mum-hash variant. */
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_len_17to128_64b(const xxh_u8* XXH_RESTRICT input, size_t len,
+ const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+ XXH64_hash_t seed)
+{
+ XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize;
+ XXH_ASSERT(16 < len && len <= 128);
+
+ { xxh_u64 acc = len * XXH_PRIME64_1;
+ if (len > 32) {
+ if (len > 64) {
+ if (len > 96) {
+ acc += XXH3_mix16B(input+48, secret+96, seed);
+ acc += XXH3_mix16B(input+len-64, secret+112, seed);
+ }
+ acc += XXH3_mix16B(input+32, secret+64, seed);
+ acc += XXH3_mix16B(input+len-48, secret+80, seed);
+ }
+ acc += XXH3_mix16B(input+16, secret+32, seed);
+ acc += XXH3_mix16B(input+len-32, secret+48, seed);
+ }
+ acc += XXH3_mix16B(input+0, secret+0, seed);
+ acc += XXH3_mix16B(input+len-16, secret+16, seed);
+
+ return XXH3_avalanche(acc);
+ }
+}
+
+#define XXH3_MIDSIZE_MAX 240
+
+XXH_NO_INLINE XXH64_hash_t
+XXH3_len_129to240_64b(const xxh_u8* XXH_RESTRICT input, size_t len,
+ const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+ XXH64_hash_t seed)
+{
+ XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize;
+ XXH_ASSERT(128 < len && len <= XXH3_MIDSIZE_MAX);
+
+ #define XXH3_MIDSIZE_STARTOFFSET 3
+ #define XXH3_MIDSIZE_LASTOFFSET 17
+
+ { xxh_u64 acc = len * XXH_PRIME64_1;
+ int const nbRounds = (int)len / 16;
+ int i;
+ for (i=0; i<8; i++) {
+ acc += XXH3_mix16B(input+(16*i), secret+(16*i), seed);
+ }
+ acc = XXH3_avalanche(acc);
+ XXH_ASSERT(nbRounds >= 8);
+#if defined(__clang__) /* Clang */ \
+ && (defined(__ARM_NEON) || defined(__ARM_NEON__)) /* NEON */ \
+ && !defined(XXH_ENABLE_AUTOVECTORIZE) /* Define to disable */
+ /*
+ * UGLY HACK:
+ * Clang for ARMv7-A tries to vectorize this loop, similar to GCC x86.
+ * In everywhere else, it uses scalar code.
+ *
+ * For 64->128-bit multiplies, even if the NEON was 100% optimal, it
+ * would still be slower than UMAAL (see XXH_mult64to128).
+ *
+ * Unfortunately, Clang doesn't handle the long multiplies properly and
+ * converts them to the nonexistent "vmulq_u64" intrinsic, which is then
+ * scalarized into an ugly mess of VMOV.32 instructions.
+ *
+ * This mess is difficult to avoid without turning autovectorization
+ * off completely, but they are usually relatively minor and/or not
+ * worth it to fix.
+ *
+ * This loop is the easiest to fix, as unlike XXH32, this pragma
+ * _actually works_ because it is a loop vectorization instead of an
+ * SLP vectorization.
+ */
+ #pragma clang loop vectorize(disable)
+#endif
+ for (i=8 ; i < nbRounds; i++) {
+ acc += XXH3_mix16B(input+(16*i), secret+(16*(i-8)) + XXH3_MIDSIZE_STARTOFFSET, seed);
+ }
+ /* last bytes */
+ acc += XXH3_mix16B(input + len - 16, secret + XXH3_SECRET_SIZE_MIN - XXH3_MIDSIZE_LASTOFFSET, seed);
+ return XXH3_avalanche(acc);
+ }
+}
+
+
+/* ======= Long Keys ======= */
+
+#define XXH_STRIPE_LEN 64
+#define XXH_SECRET_CONSUME_RATE 8 /* nb of secret bytes consumed at each accumulation */
+#define XXH_ACC_NB (XXH_STRIPE_LEN / sizeof(xxh_u64))
+
+#ifdef XXH_OLD_NAMES
+# define STRIPE_LEN XXH_STRIPE_LEN
+# define ACC_NB XXH_ACC_NB
+#endif
+
+XXH_FORCE_INLINE void XXH_writeLE64(void* dst, xxh_u64 v64)
+{
+ if (!XXH_CPU_LITTLE_ENDIAN) v64 = XXH_swap64(v64);
+ memcpy(dst, &v64, sizeof(v64));
+}
+
+/* Several intrinsic functions below are supposed to accept __int64 as argument,
+ * as documented in https://software.intel.com/sites/landingpage/IntrinsicsGuide/ .
+ * However, several environments do not define __int64 type,
+ * requiring a workaround.
+ */
+#if !defined (__VMS) \
+ && (defined (__cplusplus) \
+ || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) )
+ typedef int64_t xxh_i64;
+#else
+ /* the following type must have a width of 64-bit */
+ typedef long long xxh_i64;
+#endif
+
+/*
+ * XXH3_accumulate_512 is the tightest loop for long inputs, and it is the most optimized.
+ *
+ * It is a hardened version of UMAC, based off of FARSH's implementation.
+ *
+ * This was chosen because it adapts quite well to 32-bit, 64-bit, and SIMD
+ * implementations, and it is ridiculously fast.
+ *
+ * We harden it by mixing the original input to the accumulators as well as the product.
+ *
+ * This means that in the (relatively likely) case of a multiply by zero, the
+ * original input is preserved.
+ *
+ * On 128-bit inputs, we swap 64-bit pairs when we add the input to improve
+ * cross-pollination, as otherwise the upper and lower halves would be
+ * essentially independent.
+ *
+ * This doesn't matter on 64-bit hashes since they all get merged together in
+ * the end, so we skip the extra step.
+ *
+ * Both XXH3_64bits and XXH3_128bits use this subroutine.
+ */
+
+#if (XXH_VECTOR == XXH_AVX512) || defined(XXH_X86DISPATCH)
+
+#ifndef XXH_TARGET_AVX512
+# define XXH_TARGET_AVX512 /* disable attribute target */
+#endif
+
+XXH_FORCE_INLINE XXH_TARGET_AVX512 void
+XXH3_accumulate_512_avx512(void* XXH_RESTRICT acc,
+ const void* XXH_RESTRICT input,
+ const void* XXH_RESTRICT secret)
+{
+ XXH_ALIGN(64) __m512i* const xacc = (__m512i *) acc;
+ XXH_ASSERT((((size_t)acc) & 63) == 0);
+ XXH_STATIC_ASSERT(XXH_STRIPE_LEN == sizeof(__m512i));
+
+ {
+ /* data_vec = input[0]; */
+ __m512i const data_vec = _mm512_loadu_si512 (input);
+ /* key_vec = secret[0]; */
+ __m512i const key_vec = _mm512_loadu_si512 (secret);
+ /* data_key = data_vec ^ key_vec; */
+ __m512i const data_key = _mm512_xor_si512 (data_vec, key_vec);
+ /* data_key_lo = data_key >> 32; */
+ __m512i const data_key_lo = _mm512_shuffle_epi32 (data_key, (_MM_PERM_ENUM)_MM_SHUFFLE(0, 3, 0, 1));
+ /* product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff); */
+ __m512i const product = _mm512_mul_epu32 (data_key, data_key_lo);
+ /* xacc[0] += swap(data_vec); */
+ __m512i const data_swap = _mm512_shuffle_epi32(data_vec, (_MM_PERM_ENUM)_MM_SHUFFLE(1, 0, 3, 2));
+ __m512i const sum = _mm512_add_epi64(*xacc, data_swap);
+ /* xacc[0] += product; */
+ *xacc = _mm512_add_epi64(product, sum);
+ }
+}
+
+/*
+ * XXH3_scrambleAcc: Scrambles the accumulators to improve mixing.
+ *
+ * Multiplication isn't perfect, as explained by Google in HighwayHash:
+ *
+ * // Multiplication mixes/scrambles bytes 0-7 of the 64-bit result to
+ * // varying degrees. In descending order of goodness, bytes
+ * // 3 4 2 5 1 6 0 7 have quality 228 224 164 160 100 96 36 32.
+ * // As expected, the upper and lower bytes are much worse.
+ *
+ * Source: https://github.com/google/highwayhash/blob/0aaf66b/highwayhash/hh_avx2.h#L291
+ *
+ * Since our algorithm uses a pseudorandom secret to add some variance into the
+ * mix, we don't need to (or want to) mix as often or as much as HighwayHash does.
+ *
+ * This isn't as tight as XXH3_accumulate, but still written in SIMD to avoid
+ * extraction.
+ *
+ * Both XXH3_64bits and XXH3_128bits use this subroutine.
+ */
+
+XXH_FORCE_INLINE XXH_TARGET_AVX512 void
+XXH3_scrambleAcc_avx512(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret)
+{
+ XXH_ASSERT((((size_t)acc) & 63) == 0);
+ XXH_STATIC_ASSERT(XXH_STRIPE_LEN == sizeof(__m512i));
+ { XXH_ALIGN(64) __m512i* const xacc = (__m512i*) acc;
+ const __m512i prime32 = _mm512_set1_epi32((int)XXH_PRIME32_1);
+
+ /* xacc[0] ^= (xacc[0] >> 47) */
+ __m512i const acc_vec = *xacc;
+ __m512i const shifted = _mm512_srli_epi64 (acc_vec, 47);
+ __m512i const data_vec = _mm512_xor_si512 (acc_vec, shifted);
+ /* xacc[0] ^= secret; */
+ __m512i const key_vec = _mm512_loadu_si512 (secret);
+ __m512i const data_key = _mm512_xor_si512 (data_vec, key_vec);
+
+ /* xacc[0] *= XXH_PRIME32_1; */
+ __m512i const data_key_hi = _mm512_shuffle_epi32 (data_key, (_MM_PERM_ENUM)_MM_SHUFFLE(0, 3, 0, 1));
+ __m512i const prod_lo = _mm512_mul_epu32 (data_key, prime32);
+ __m512i const prod_hi = _mm512_mul_epu32 (data_key_hi, prime32);
+ *xacc = _mm512_add_epi64(prod_lo, _mm512_slli_epi64(prod_hi, 32));
+ }
+}
+
+XXH_FORCE_INLINE XXH_TARGET_AVX512 void
+XXH3_initCustomSecret_avx512(void* XXH_RESTRICT customSecret, xxh_u64 seed64)
+{
+ XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 63) == 0);
+ XXH_STATIC_ASSERT(XXH_SEC_ALIGN == 64);
+ XXH_ASSERT(((size_t)customSecret & 63) == 0);
+ (void)(&XXH_writeLE64);
+ { int const nbRounds = XXH_SECRET_DEFAULT_SIZE / sizeof(__m512i);
+ __m512i const seed = _mm512_mask_set1_epi64(_mm512_set1_epi64((xxh_i64)seed64), 0xAA, -(xxh_i64)seed64);
+
+ XXH_ALIGN(64) const __m512i* const src = (const __m512i*) XXH3_kSecret;
+ XXH_ALIGN(64) __m512i* const dest = ( __m512i*) customSecret;
+ int i;
+ for (i=0; i < nbRounds; ++i) {
+ /* GCC has a bug, _mm512_stream_load_si512 accepts 'void*', not 'void const*',
+ * this will warn "discards ‘const’ qualifier". */
+ union {
+ XXH_ALIGN(64) const __m512i* cp;
+ XXH_ALIGN(64) void* p;
+ } remote_const_void;
+ remote_const_void.cp = src + i;
+ dest[i] = _mm512_add_epi64(_mm512_stream_load_si512(remote_const_void.p), seed);
+ } }
+}
+
+#endif
+
+#if (XXH_VECTOR == XXH_AVX2) || defined(XXH_X86DISPATCH)
+
+#ifndef XXH_TARGET_AVX2
+# define XXH_TARGET_AVX2 /* disable attribute target */
+#endif
+
+XXH_FORCE_INLINE XXH_TARGET_AVX2 void
+XXH3_accumulate_512_avx2( void* XXH_RESTRICT acc,
+ const void* XXH_RESTRICT input,
+ const void* XXH_RESTRICT secret)
+{
+ XXH_ASSERT((((size_t)acc) & 31) == 0);
+ { XXH_ALIGN(32) __m256i* const xacc = (__m256i *) acc;
+ /* Unaligned. This is mainly for pointer arithmetic, and because
+ * _mm256_loadu_si256 requires a const __m256i * pointer for some reason. */
+ const __m256i* const xinput = (const __m256i *) input;
+ /* Unaligned. This is mainly for pointer arithmetic, and because
+ * _mm256_loadu_si256 requires a const __m256i * pointer for some reason. */
+ const __m256i* const xsecret = (const __m256i *) secret;
+
+ size_t i;
+ for (i=0; i < XXH_STRIPE_LEN/sizeof(__m256i); i++) {
+ /* data_vec = xinput[i]; */
+ __m256i const data_vec = _mm256_loadu_si256 (xinput+i);
+ /* key_vec = xsecret[i]; */
+ __m256i const key_vec = _mm256_loadu_si256 (xsecret+i);
+ /* data_key = data_vec ^ key_vec; */
+ __m256i const data_key = _mm256_xor_si256 (data_vec, key_vec);
+ /* data_key_lo = data_key >> 32; */
+ __m256i const data_key_lo = _mm256_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1));
+ /* product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff); */
+ __m256i const product = _mm256_mul_epu32 (data_key, data_key_lo);
+ /* xacc[i] += swap(data_vec); */
+ __m256i const data_swap = _mm256_shuffle_epi32(data_vec, _MM_SHUFFLE(1, 0, 3, 2));
+ __m256i const sum = _mm256_add_epi64(xacc[i], data_swap);
+ /* xacc[i] += product; */
+ xacc[i] = _mm256_add_epi64(product, sum);
+ } }
+}
+
+XXH_FORCE_INLINE XXH_TARGET_AVX2 void
+XXH3_scrambleAcc_avx2(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret)
+{
+ XXH_ASSERT((((size_t)acc) & 31) == 0);
+ { XXH_ALIGN(32) __m256i* const xacc = (__m256i*) acc;
+ /* Unaligned. This is mainly for pointer arithmetic, and because
+ * _mm256_loadu_si256 requires a const __m256i * pointer for some reason. */
+ const __m256i* const xsecret = (const __m256i *) secret;
+ const __m256i prime32 = _mm256_set1_epi32((int)XXH_PRIME32_1);
+
+ size_t i;
+ for (i=0; i < XXH_STRIPE_LEN/sizeof(__m256i); i++) {
+ /* xacc[i] ^= (xacc[i] >> 47) */
+ __m256i const acc_vec = xacc[i];
+ __m256i const shifted = _mm256_srli_epi64 (acc_vec, 47);
+ __m256i const data_vec = _mm256_xor_si256 (acc_vec, shifted);
+ /* xacc[i] ^= xsecret; */
+ __m256i const key_vec = _mm256_loadu_si256 (xsecret+i);
+ __m256i const data_key = _mm256_xor_si256 (data_vec, key_vec);
+
+ /* xacc[i] *= XXH_PRIME32_1; */
+ __m256i const data_key_hi = _mm256_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1));
+ __m256i const prod_lo = _mm256_mul_epu32 (data_key, prime32);
+ __m256i const prod_hi = _mm256_mul_epu32 (data_key_hi, prime32);
+ xacc[i] = _mm256_add_epi64(prod_lo, _mm256_slli_epi64(prod_hi, 32));
+ }
+ }
+}
+
+XXH_FORCE_INLINE XXH_TARGET_AVX2 void XXH3_initCustomSecret_avx2(void* XXH_RESTRICT customSecret, xxh_u64 seed64)
+{
+ XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 31) == 0);
+ XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE / sizeof(__m256i)) == 6);
+ XXH_STATIC_ASSERT(XXH_SEC_ALIGN <= 64);
+ (void)(&XXH_writeLE64);
+ XXH_PREFETCH(customSecret);
+ { __m256i const seed = _mm256_set_epi64x(-(xxh_i64)seed64, (xxh_i64)seed64, -(xxh_i64)seed64, (xxh_i64)seed64);
+
+ XXH_ALIGN(64) const __m256i* const src = (const __m256i*) XXH3_kSecret;
+ XXH_ALIGN(64) __m256i* dest = ( __m256i*) customSecret;
+
+# if defined(__GNUC__) || defined(__clang__)
+ /*
+ * On GCC & Clang, marking 'dest' as modified will cause the compiler:
+ * - do not extract the secret from sse registers in the internal loop
+ * - use less common registers, and avoid pushing these reg into stack
+ * The asm hack causes Clang to assume that XXH3_kSecretPtr aliases with
+ * customSecret, and on aarch64, this prevented LDP from merging two
+ * loads together for free. Putting the loads together before the stores
+ * properly generates LDP.
+ */
+ __asm__("" : "+r" (dest));
+# endif
+
+ /* GCC -O2 need unroll loop manually */
+ dest[0] = _mm256_add_epi64(_mm256_stream_load_si256(src+0), seed);
+ dest[1] = _mm256_add_epi64(_mm256_stream_load_si256(src+1), seed);
+ dest[2] = _mm256_add_epi64(_mm256_stream_load_si256(src+2), seed);
+ dest[3] = _mm256_add_epi64(_mm256_stream_load_si256(src+3), seed);
+ dest[4] = _mm256_add_epi64(_mm256_stream_load_si256(src+4), seed);
+ dest[5] = _mm256_add_epi64(_mm256_stream_load_si256(src+5), seed);
+ }
+}
+
+#endif
+
+#if (XXH_VECTOR == XXH_SSE2) || defined(XXH_X86DISPATCH)
+
+#ifndef XXH_TARGET_SSE2
+# define XXH_TARGET_SSE2 /* disable attribute target */
+#endif
+
+XXH_FORCE_INLINE XXH_TARGET_SSE2 void
+XXH3_accumulate_512_sse2( void* XXH_RESTRICT acc,
+ const void* XXH_RESTRICT input,
+ const void* XXH_RESTRICT secret)
+{
+ /* SSE2 is just a half-scale version of the AVX2 version. */
+ XXH_ASSERT((((size_t)acc) & 15) == 0);
+ { XXH_ALIGN(16) __m128i* const xacc = (__m128i *) acc;
+ /* Unaligned. This is mainly for pointer arithmetic, and because
+ * _mm_loadu_si128 requires a const __m128i * pointer for some reason. */
+ const __m128i* const xinput = (const __m128i *) input;
+ /* Unaligned. This is mainly for pointer arithmetic, and because
+ * _mm_loadu_si128 requires a const __m128i * pointer for some reason. */
+ const __m128i* const xsecret = (const __m128i *) secret;
+
+ size_t i;
+ for (i=0; i < XXH_STRIPE_LEN/sizeof(__m128i); i++) {
+ /* data_vec = xinput[i]; */
+ __m128i const data_vec = _mm_loadu_si128 (xinput+i);
+ /* key_vec = xsecret[i]; */
+ __m128i const key_vec = _mm_loadu_si128 (xsecret+i);
+ /* data_key = data_vec ^ key_vec; */
+ __m128i const data_key = _mm_xor_si128 (data_vec, key_vec);
+ /* data_key_lo = data_key >> 32; */
+ __m128i const data_key_lo = _mm_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1));
+ /* product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff); */
+ __m128i const product = _mm_mul_epu32 (data_key, data_key_lo);
+ /* xacc[i] += swap(data_vec); */
+ __m128i const data_swap = _mm_shuffle_epi32(data_vec, _MM_SHUFFLE(1,0,3,2));
+ __m128i const sum = _mm_add_epi64(xacc[i], data_swap);
+ /* xacc[i] += product; */
+ xacc[i] = _mm_add_epi64(product, sum);
+ } }
+}
+
+XXH_FORCE_INLINE XXH_TARGET_SSE2 void
+XXH3_scrambleAcc_sse2(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret)
+{
+ XXH_ASSERT((((size_t)acc) & 15) == 0);
+ { XXH_ALIGN(16) __m128i* const xacc = (__m128i*) acc;
+ /* Unaligned. This is mainly for pointer arithmetic, and because
+ * _mm_loadu_si128 requires a const __m128i * pointer for some reason. */
+ const __m128i* const xsecret = (const __m128i *) secret;
+ const __m128i prime32 = _mm_set1_epi32((int)XXH_PRIME32_1);
+
+ size_t i;
+ for (i=0; i < XXH_STRIPE_LEN/sizeof(__m128i); i++) {
+ /* xacc[i] ^= (xacc[i] >> 47) */
+ __m128i const acc_vec = xacc[i];
+ __m128i const shifted = _mm_srli_epi64 (acc_vec, 47);
+ __m128i const data_vec = _mm_xor_si128 (acc_vec, shifted);
+ /* xacc[i] ^= xsecret[i]; */
+ __m128i const key_vec = _mm_loadu_si128 (xsecret+i);
+ __m128i const data_key = _mm_xor_si128 (data_vec, key_vec);
+
+ /* xacc[i] *= XXH_PRIME32_1; */
+ __m128i const data_key_hi = _mm_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1));
+ __m128i const prod_lo = _mm_mul_epu32 (data_key, prime32);
+ __m128i const prod_hi = _mm_mul_epu32 (data_key_hi, prime32);
+ xacc[i] = _mm_add_epi64(prod_lo, _mm_slli_epi64(prod_hi, 32));
+ }
+ }
+}
+
+XXH_FORCE_INLINE XXH_TARGET_SSE2 void XXH3_initCustomSecret_sse2(void* XXH_RESTRICT customSecret, xxh_u64 seed64)
+{
+ XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 15) == 0);
+ (void)(&XXH_writeLE64);
+ { int const nbRounds = XXH_SECRET_DEFAULT_SIZE / sizeof(__m128i);
+
+# if defined(_MSC_VER) && defined(_M_IX86) && _MSC_VER < 1900
+ // MSVC 32bit mode does not support _mm_set_epi64x before 2015
+ XXH_ALIGN(16) const xxh_i64 seed64x2[2] = { (xxh_i64)seed64, -(xxh_i64)seed64 };
+ __m128i const seed = _mm_load_si128((__m128i const*)seed64x2);
+# else
+ __m128i const seed = _mm_set_epi64x(-(xxh_i64)seed64, (xxh_i64)seed64);
+# endif
+ int i;
+
+ XXH_ALIGN(64) const float* const src = (float const*) XXH3_kSecret;
+ XXH_ALIGN(XXH_SEC_ALIGN) __m128i* dest = (__m128i*) customSecret;
+# if defined(__GNUC__) || defined(__clang__)
+ /*
+ * On GCC & Clang, marking 'dest' as modified will cause the compiler:
+ * - do not extract the secret from sse registers in the internal loop
+ * - use less common registers, and avoid pushing these reg into stack
+ */
+ __asm__("" : "+r" (dest));
+# endif
+
+ for (i=0; i < nbRounds; ++i) {
+ dest[i] = _mm_add_epi64(_mm_castps_si128(_mm_load_ps(src+i*4)), seed);
+ } }
+}
+
+#endif
+
+#if (XXH_VECTOR == XXH_NEON)
+
+XXH_FORCE_INLINE void
+XXH3_accumulate_512_neon( void* XXH_RESTRICT acc,
+ const void* XXH_RESTRICT input,
+ const void* XXH_RESTRICT secret)
+{
+ XXH_ASSERT((((size_t)acc) & 15) == 0);
+ {
+ XXH_ALIGN(16) uint64x2_t* const xacc = (uint64x2_t *) acc;
+ /* We don't use a uint32x4_t pointer because it causes bus errors on ARMv7. */
+ uint8_t const* const xinput = (const uint8_t *) input;
+ uint8_t const* const xsecret = (const uint8_t *) secret;
+
+ size_t i;
+ for (i=0; i < XXH_STRIPE_LEN / sizeof(uint64x2_t); i++) {
+ /* data_vec = xinput[i]; */
+ uint8x16_t data_vec = vld1q_u8(xinput + (i * 16));
+ /* key_vec = xsecret[i]; */
+ uint8x16_t key_vec = vld1q_u8(xsecret + (i * 16));
+ uint64x2_t data_key;
+ uint32x2_t data_key_lo, data_key_hi;
+ /* xacc[i] += swap(data_vec); */
+ uint64x2_t const data64 = vreinterpretq_u64_u8(data_vec);
+ uint64x2_t const swapped = vextq_u64(data64, data64, 1);
+ xacc[i] = vaddq_u64 (xacc[i], swapped);
+ /* data_key = data_vec ^ key_vec; */
+ data_key = vreinterpretq_u64_u8(veorq_u8(data_vec, key_vec));
+ /* data_key_lo = (uint32x2_t) (data_key & 0xFFFFFFFF);
+ * data_key_hi = (uint32x2_t) (data_key >> 32);
+ * data_key = UNDEFINED; */
+ XXH_SPLIT_IN_PLACE(data_key, data_key_lo, data_key_hi);
+ /* xacc[i] += (uint64x2_t) data_key_lo * (uint64x2_t) data_key_hi; */
+ xacc[i] = vmlal_u32 (xacc[i], data_key_lo, data_key_hi);
+
+ }
+ }
+}
+
+XXH_FORCE_INLINE void
+XXH3_scrambleAcc_neon(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret)
+{
+ XXH_ASSERT((((size_t)acc) & 15) == 0);
+
+ { uint64x2_t* xacc = (uint64x2_t*) acc;
+ uint8_t const* xsecret = (uint8_t const*) secret;
+ uint32x2_t prime = vdup_n_u32 (XXH_PRIME32_1);
+
+ size_t i;
+ for (i=0; i < XXH_STRIPE_LEN/sizeof(uint64x2_t); i++) {
+ /* xacc[i] ^= (xacc[i] >> 47); */
+ uint64x2_t acc_vec = xacc[i];
+ uint64x2_t shifted = vshrq_n_u64 (acc_vec, 47);
+ uint64x2_t data_vec = veorq_u64 (acc_vec, shifted);
+
+ /* xacc[i] ^= xsecret[i]; */
+ uint8x16_t key_vec = vld1q_u8(xsecret + (i * 16));
+ uint64x2_t data_key = veorq_u64(data_vec, vreinterpretq_u64_u8(key_vec));
+
+ /* xacc[i] *= XXH_PRIME32_1 */
+ uint32x2_t data_key_lo, data_key_hi;
+ /* data_key_lo = (uint32x2_t) (xacc[i] & 0xFFFFFFFF);
+ * data_key_hi = (uint32x2_t) (xacc[i] >> 32);
+ * xacc[i] = UNDEFINED; */
+ XXH_SPLIT_IN_PLACE(data_key, data_key_lo, data_key_hi);
+ { /*
+ * prod_hi = (data_key >> 32) * XXH_PRIME32_1;
+ *
+ * Avoid vmul_u32 + vshll_n_u32 since Clang 6 and 7 will
+ * incorrectly "optimize" this:
+ * tmp = vmul_u32(vmovn_u64(a), vmovn_u64(b));
+ * shifted = vshll_n_u32(tmp, 32);
+ * to this:
+ * tmp = "vmulq_u64"(a, b); // no such thing!
+ * shifted = vshlq_n_u64(tmp, 32);
+ *
+ * However, unlike SSE, Clang lacks a 64-bit multiply routine
+ * for NEON, and it scalarizes two 64-bit multiplies instead.
+ *
+ * vmull_u32 has the same timing as vmul_u32, and it avoids
+ * this bug completely.
+ * See https://bugs.llvm.org/show_bug.cgi?id=39967
+ */
+ uint64x2_t prod_hi = vmull_u32 (data_key_hi, prime);
+ /* xacc[i] = prod_hi << 32; */
+ xacc[i] = vshlq_n_u64(prod_hi, 32);
+ /* xacc[i] += (prod_hi & 0xFFFFFFFF) * XXH_PRIME32_1; */
+ xacc[i] = vmlal_u32(xacc[i], data_key_lo, prime);
+ }
+ } }
+}
+
+#endif
+
+#if (XXH_VECTOR == XXH_VSX)
+
+XXH_FORCE_INLINE void
+XXH3_accumulate_512_vsx( void* XXH_RESTRICT acc,
+ const void* XXH_RESTRICT input,
+ const void* XXH_RESTRICT secret)
+{
+ xxh_u64x2* const xacc = (xxh_u64x2*) acc; /* presumed aligned */
+ xxh_u64x2 const* const xinput = (xxh_u64x2 const*) input; /* no alignment restriction */
+ xxh_u64x2 const* const xsecret = (xxh_u64x2 const*) secret; /* no alignment restriction */
+ xxh_u64x2 const v32 = { 32, 32 };
+ size_t i;
+ for (i = 0; i < XXH_STRIPE_LEN / sizeof(xxh_u64x2); i++) {
+ /* data_vec = xinput[i]; */
+ xxh_u64x2 const data_vec = XXH_vec_loadu(xinput + i);
+ /* key_vec = xsecret[i]; */
+ xxh_u64x2 const key_vec = XXH_vec_loadu(xsecret + i);
+ xxh_u64x2 const data_key = data_vec ^ key_vec;
+ /* shuffled = (data_key << 32) | (data_key >> 32); */
+ xxh_u32x4 const shuffled = (xxh_u32x4)vec_rl(data_key, v32);
+ /* product = ((xxh_u64x2)data_key & 0xFFFFFFFF) * ((xxh_u64x2)shuffled & 0xFFFFFFFF); */
+ xxh_u64x2 const product = XXH_vec_mulo((xxh_u32x4)data_key, shuffled);
+ xacc[i] += product;
+
+ /* swap high and low halves */
+#ifdef __s390x__
+ xacc[i] += vec_permi(data_vec, data_vec, 2);
+#else
+ xacc[i] += vec_xxpermdi(data_vec, data_vec, 2);
+#endif
+ }
+}
+
+XXH_FORCE_INLINE void
+XXH3_scrambleAcc_vsx(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret)
+{
+ XXH_ASSERT((((size_t)acc) & 15) == 0);
+
+ { xxh_u64x2* const xacc = (xxh_u64x2*) acc;
+ const xxh_u64x2* const xsecret = (const xxh_u64x2*) secret;
+ /* constants */
+ xxh_u64x2 const v32 = { 32, 32 };
+ xxh_u64x2 const v47 = { 47, 47 };
+ xxh_u32x4 const prime = { XXH_PRIME32_1, XXH_PRIME32_1, XXH_PRIME32_1, XXH_PRIME32_1 };
+ size_t i;
+ for (i = 0; i < XXH_STRIPE_LEN / sizeof(xxh_u64x2); i++) {
+ /* xacc[i] ^= (xacc[i] >> 47); */
+ xxh_u64x2 const acc_vec = xacc[i];
+ xxh_u64x2 const data_vec = acc_vec ^ (acc_vec >> v47);
+
+ /* xacc[i] ^= xsecret[i]; */
+ xxh_u64x2 const key_vec = XXH_vec_loadu(xsecret + i);
+ xxh_u64x2 const data_key = data_vec ^ key_vec;
+
+ /* xacc[i] *= XXH_PRIME32_1 */
+ /* prod_lo = ((xxh_u64x2)data_key & 0xFFFFFFFF) * ((xxh_u64x2)prime & 0xFFFFFFFF); */
+ xxh_u64x2 const prod_even = XXH_vec_mule((xxh_u32x4)data_key, prime);
+ /* prod_hi = ((xxh_u64x2)data_key >> 32) * ((xxh_u64x2)prime >> 32); */
+ xxh_u64x2 const prod_odd = XXH_vec_mulo((xxh_u32x4)data_key, prime);
+ xacc[i] = prod_odd + (prod_even << v32);
+ } }
+}
+
+#endif
+
+/* scalar variants - universal */
+
+XXH_FORCE_INLINE void
+XXH3_accumulate_512_scalar(void* XXH_RESTRICT acc,
+ const void* XXH_RESTRICT input,
+ const void* XXH_RESTRICT secret)
+{
+ XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64* const xacc = (xxh_u64*) acc; /* presumed aligned */
+ const xxh_u8* const xinput = (const xxh_u8*) input; /* no alignment restriction */
+ const xxh_u8* const xsecret = (const xxh_u8*) secret; /* no alignment restriction */
+ size_t i;
+ XXH_ASSERT(((size_t)acc & (XXH_ACC_ALIGN-1)) == 0);
+ for (i=0; i < XXH_ACC_NB; i++) {
+ xxh_u64 const data_val = XXH_readLE64(xinput + 8*i);
+ xxh_u64 const data_key = data_val ^ XXH_readLE64(xsecret + i*8);
+ xacc[i ^ 1] += data_val; /* swap adjacent lanes */
+ xacc[i] += XXH_mult32to64(data_key & 0xFFFFFFFF, data_key >> 32);
+ }
+}
+
+XXH_FORCE_INLINE void
+XXH3_scrambleAcc_scalar(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret)
+{
+ XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64* const xacc = (xxh_u64*) acc; /* presumed aligned */
+ const xxh_u8* const xsecret = (const xxh_u8*) secret; /* no alignment restriction */
+ size_t i;
+ XXH_ASSERT((((size_t)acc) & (XXH_ACC_ALIGN-1)) == 0);
+ for (i=0; i < XXH_ACC_NB; i++) {
+ xxh_u64 const key64 = XXH_readLE64(xsecret + 8*i);
+ xxh_u64 acc64 = xacc[i];
+ acc64 = XXH_xorshift64(acc64, 47);
+ acc64 ^= key64;
+ acc64 *= XXH_PRIME32_1;
+ xacc[i] = acc64;
+ }
+}
+
+XXH_FORCE_INLINE void
+XXH3_initCustomSecret_scalar(void* XXH_RESTRICT customSecret, xxh_u64 seed64)
+{
+ /*
+ * We need a separate pointer for the hack below,
+ * which requires a non-const pointer.
+ * Any decent compiler will optimize this out otherwise.
+ */
+ const xxh_u8* kSecretPtr = XXH3_kSecret;
+ XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 15) == 0);
+
+#if defined(__clang__) && defined(__aarch64__)
+ /*
+ * UGLY HACK:
+ * Clang generates a bunch of MOV/MOVK pairs for aarch64, and they are
+ * placed sequentially, in order, at the top of the unrolled loop.
+ *
+ * While MOVK is great for generating constants (2 cycles for a 64-bit
+ * constant compared to 4 cycles for LDR), long MOVK chains stall the
+ * integer pipelines:
+ * I L S
+ * MOVK
+ * MOVK
+ * MOVK
+ * MOVK
+ * ADD
+ * SUB STR
+ * STR
+ * By forcing loads from memory (as the asm line causes Clang to assume
+ * that XXH3_kSecretPtr has been changed), the pipelines are used more
+ * efficiently:
+ * I L S
+ * LDR
+ * ADD LDR
+ * SUB STR
+ * STR
+ * XXH3_64bits_withSeed, len == 256, Snapdragon 835
+ * without hack: 2654.4 MB/s
+ * with hack: 3202.9 MB/s
+ */
+ __asm__("" : "+r" (kSecretPtr));
+#endif
+ /*
+ * Note: in debug mode, this overrides the asm optimization
+ * and Clang will emit MOVK chains again.
+ */
+ XXH_ASSERT(kSecretPtr == XXH3_kSecret);
+
+ { int const nbRounds = XXH_SECRET_DEFAULT_SIZE / 16;
+ int i;
+ for (i=0; i < nbRounds; i++) {
+ /*
+ * The asm hack causes Clang to assume that kSecretPtr aliases with
+ * customSecret, and on aarch64, this prevented LDP from merging two
+ * loads together for free. Putting the loads together before the stores
+ * properly generates LDP.
+ */
+ xxh_u64 lo = XXH_readLE64(kSecretPtr + 16*i) + seed64;
+ xxh_u64 hi = XXH_readLE64(kSecretPtr + 16*i + 8) - seed64;
+ XXH_writeLE64((xxh_u8*)customSecret + 16*i, lo);
+ XXH_writeLE64((xxh_u8*)customSecret + 16*i + 8, hi);
+ } }
+}
+
+
+typedef void (*XXH3_f_accumulate_512)(void* XXH_RESTRICT, const void*, const void*);
+typedef void (*XXH3_f_scrambleAcc)(void* XXH_RESTRICT, const void*);
+typedef void (*XXH3_f_initCustomSecret)(void* XXH_RESTRICT, xxh_u64);
+
+
+#if (XXH_VECTOR == XXH_AVX512)
+
+#define XXH3_accumulate_512 XXH3_accumulate_512_avx512
+#define XXH3_scrambleAcc XXH3_scrambleAcc_avx512
+#define XXH3_initCustomSecret XXH3_initCustomSecret_avx512
+
+#elif (XXH_VECTOR == XXH_AVX2)
+
+#define XXH3_accumulate_512 XXH3_accumulate_512_avx2
+#define XXH3_scrambleAcc XXH3_scrambleAcc_avx2
+#define XXH3_initCustomSecret XXH3_initCustomSecret_avx2
+
+#elif (XXH_VECTOR == XXH_SSE2)
+
+#define XXH3_accumulate_512 XXH3_accumulate_512_sse2
+#define XXH3_scrambleAcc XXH3_scrambleAcc_sse2
+#define XXH3_initCustomSecret XXH3_initCustomSecret_sse2
+
+#elif (XXH_VECTOR == XXH_NEON)
+
+#define XXH3_accumulate_512 XXH3_accumulate_512_neon
+#define XXH3_scrambleAcc XXH3_scrambleAcc_neon
+#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar
+
+#elif (XXH_VECTOR == XXH_VSX)
+
+#define XXH3_accumulate_512 XXH3_accumulate_512_vsx
+#define XXH3_scrambleAcc XXH3_scrambleAcc_vsx
+#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar
+
+#else /* scalar */
+
+#define XXH3_accumulate_512 XXH3_accumulate_512_scalar
+#define XXH3_scrambleAcc XXH3_scrambleAcc_scalar
+#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar
+
+#endif
+
+
+
+#ifndef XXH_PREFETCH_DIST
+# ifdef __clang__
+# define XXH_PREFETCH_DIST 320
+# else
+# if (XXH_VECTOR == XXH_AVX512)
+# define XXH_PREFETCH_DIST 512
+# else
+# define XXH_PREFETCH_DIST 384
+# endif
+# endif /* __clang__ */
+#endif /* XXH_PREFETCH_DIST */
+
+/*
+ * XXH3_accumulate()
+ * Loops over XXH3_accumulate_512().
+ * Assumption: nbStripes will not overflow the secret size
+ */
+XXH_FORCE_INLINE void
+XXH3_accumulate( xxh_u64* XXH_RESTRICT acc,
+ const xxh_u8* XXH_RESTRICT input,
+ const xxh_u8* XXH_RESTRICT secret,
+ size_t nbStripes,
+ XXH3_f_accumulate_512 f_acc512)
+{
+ size_t n;
+ for (n = 0; n < nbStripes; n++ ) {
+ const xxh_u8* const in = input + n*XXH_STRIPE_LEN;
+ XXH_PREFETCH(in + XXH_PREFETCH_DIST);
+ f_acc512(acc,
+ in,
+ secret + n*XXH_SECRET_CONSUME_RATE);
+ }
+}
+
+XXH_FORCE_INLINE void
+XXH3_hashLong_internal_loop(xxh_u64* XXH_RESTRICT acc,
+ const xxh_u8* XXH_RESTRICT input, size_t len,
+ const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+ XXH3_f_accumulate_512 f_acc512,
+ XXH3_f_scrambleAcc f_scramble)
+{
+ size_t const nbStripesPerBlock = (secretSize - XXH_STRIPE_LEN) / XXH_SECRET_CONSUME_RATE;
+ size_t const block_len = XXH_STRIPE_LEN * nbStripesPerBlock;
+ size_t const nb_blocks = (len - 1) / block_len;
+
+ size_t n;
+
+ XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN);
+
+ for (n = 0; n < nb_blocks; n++) {
+ XXH3_accumulate(acc, input + n*block_len, secret, nbStripesPerBlock, f_acc512);
+ f_scramble(acc, secret + secretSize - XXH_STRIPE_LEN);
+ }
+
+ /* last partial block */
+ XXH_ASSERT(len > XXH_STRIPE_LEN);
+ { size_t const nbStripes = ((len - 1) - (block_len * nb_blocks)) / XXH_STRIPE_LEN;
+ XXH_ASSERT(nbStripes <= (secretSize / XXH_SECRET_CONSUME_RATE));
+ XXH3_accumulate(acc, input + nb_blocks*block_len, secret, nbStripes, f_acc512);
+
+ /* last stripe */
+ { const xxh_u8* const p = input + len - XXH_STRIPE_LEN;
+#define XXH_SECRET_LASTACC_START 7 /* not aligned on 8, last secret is different from acc & scrambler */
+ f_acc512(acc, p, secret + secretSize - XXH_STRIPE_LEN - XXH_SECRET_LASTACC_START);
+ } }
+}
+
+XXH_FORCE_INLINE xxh_u64
+XXH3_mix2Accs(const xxh_u64* XXH_RESTRICT acc, const xxh_u8* XXH_RESTRICT secret)
+{
+ return XXH3_mul128_fold64(
+ acc[0] ^ XXH_readLE64(secret),
+ acc[1] ^ XXH_readLE64(secret+8) );
+}
+
+static XXH64_hash_t
+XXH3_mergeAccs(const xxh_u64* XXH_RESTRICT acc, const xxh_u8* XXH_RESTRICT secret, xxh_u64 start)
+{
+ xxh_u64 result64 = start;
+ size_t i = 0;
+
+ for (i = 0; i < 4; i++) {
+ result64 += XXH3_mix2Accs(acc+2*i, secret + 16*i);
+#if defined(__clang__) /* Clang */ \
+ && (defined(__arm__) || defined(__thumb__)) /* ARMv7 */ \
+ && (defined(__ARM_NEON) || defined(__ARM_NEON__)) /* NEON */ \
+ && !defined(XXH_ENABLE_AUTOVECTORIZE) /* Define to disable */
+ /*
+ * UGLY HACK:
+ * Prevent autovectorization on Clang ARMv7-a. Exact same problem as
+ * the one in XXH3_len_129to240_64b. Speeds up shorter keys > 240b.
+ * XXH3_64bits, len == 256, Snapdragon 835:
+ * without hack: 2063.7 MB/s
+ * with hack: 2560.7 MB/s
+ */
+ __asm__("" : "+r" (result64));
+#endif
+ }
+
+ return XXH3_avalanche(result64);
+}
+
+#define XXH3_INIT_ACC { XXH_PRIME32_3, XXH_PRIME64_1, XXH_PRIME64_2, XXH_PRIME64_3, \
+ XXH_PRIME64_4, XXH_PRIME32_2, XXH_PRIME64_5, XXH_PRIME32_1 }
+
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_hashLong_64b_internal(const void* XXH_RESTRICT input, size_t len,
+ const void* XXH_RESTRICT secret, size_t secretSize,
+ XXH3_f_accumulate_512 f_acc512,
+ XXH3_f_scrambleAcc f_scramble)
+{
+ XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64 acc[XXH_ACC_NB] = XXH3_INIT_ACC;
+
+ XXH3_hashLong_internal_loop(acc, (const xxh_u8*)input, len, (const xxh_u8*)secret, secretSize, f_acc512, f_scramble);
+
+ /* converge into final hash */
+ XXH_STATIC_ASSERT(sizeof(acc) == 64);
+ /* do not align on 8, so that the secret is different from the accumulator */
+#define XXH_SECRET_MERGEACCS_START 11
+ XXH_ASSERT(secretSize >= sizeof(acc) + XXH_SECRET_MERGEACCS_START);
+ return XXH3_mergeAccs(acc, (const xxh_u8*)secret + XXH_SECRET_MERGEACCS_START, (xxh_u64)len * XXH_PRIME64_1);
+}
+
+/*
+ * It's important for performance that XXH3_hashLong is not inlined.
+ */
+XXH_NO_INLINE XXH64_hash_t
+XXH3_hashLong_64b_withSecret(const void* XXH_RESTRICT input, size_t len,
+ XXH64_hash_t seed64, const xxh_u8* XXH_RESTRICT secret, size_t secretLen)
+{
+ (void)seed64;
+ return XXH3_hashLong_64b_internal(input, len, secret, secretLen, XXH3_accumulate_512, XXH3_scrambleAcc);
+}
+
+/*
+ * It's important for performance that XXH3_hashLong is not inlined.
+ * Since the function is not inlined, the compiler may not be able to understand that,
+ * in some scenarios, its `secret` argument is actually a compile time constant.
+ * This variant enforces that the compiler can detect that,
+ * and uses this opportunity to streamline the generated code for better performance.
+ */
+XXH_NO_INLINE XXH64_hash_t
+XXH3_hashLong_64b_default(const void* XXH_RESTRICT input, size_t len,
+ XXH64_hash_t seed64, const xxh_u8* XXH_RESTRICT secret, size_t secretLen)
+{
+ (void)seed64; (void)secret; (void)secretLen;
+ return XXH3_hashLong_64b_internal(input, len, XXH3_kSecret, sizeof(XXH3_kSecret), XXH3_accumulate_512, XXH3_scrambleAcc);
+}
+
+/*
+ * XXH3_hashLong_64b_withSeed():
+ * Generate a custom key based on alteration of default XXH3_kSecret with the seed,
+ * and then use this key for long mode hashing.
+ *
+ * This operation is decently fast but nonetheless costs a little bit of time.
+ * Try to avoid it whenever possible (typically when seed==0).
+ *
+ * It's important for performance that XXH3_hashLong is not inlined. Not sure
+ * why (uop cache maybe?), but the difference is large and easily measurable.
+ */
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_hashLong_64b_withSeed_internal(const void* input, size_t len,
+ XXH64_hash_t seed,
+ XXH3_f_accumulate_512 f_acc512,
+ XXH3_f_scrambleAcc f_scramble,
+ XXH3_f_initCustomSecret f_initSec)
+{
+ if (seed == 0)
+ return XXH3_hashLong_64b_internal(input, len,
+ XXH3_kSecret, sizeof(XXH3_kSecret),
+ f_acc512, f_scramble);
+ { XXH_ALIGN(XXH_SEC_ALIGN) xxh_u8 secret[XXH_SECRET_DEFAULT_SIZE];
+ f_initSec(secret, seed);
+ return XXH3_hashLong_64b_internal(input, len, secret, sizeof(secret),
+ f_acc512, f_scramble);
+ }
+}
+
+/*
+ * It's important for performance that XXH3_hashLong is not inlined.
+ */
+XXH_NO_INLINE XXH64_hash_t
+XXH3_hashLong_64b_withSeed(const void* input, size_t len,
+ XXH64_hash_t seed, const xxh_u8* secret, size_t secretLen)
+{
+ (void)secret; (void)secretLen;
+ return XXH3_hashLong_64b_withSeed_internal(input, len, seed,
+ XXH3_accumulate_512, XXH3_scrambleAcc, XXH3_initCustomSecret);
+}
+
+
+typedef XXH64_hash_t (*XXH3_hashLong64_f)(const void* XXH_RESTRICT, size_t,
+ XXH64_hash_t, const xxh_u8* XXH_RESTRICT, size_t);
+
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_64bits_internal(const void* XXH_RESTRICT input, size_t len,
+ XXH64_hash_t seed64, const void* XXH_RESTRICT secret, size_t secretLen,
+ XXH3_hashLong64_f f_hashLong)
+{
+ XXH_ASSERT(secretLen >= XXH3_SECRET_SIZE_MIN);
+ /*
+ * If an action is to be taken if `secretLen` condition is not respected,
+ * it should be done here.
+ * For now, it's a contract pre-condition.
+ * Adding a check and a branch here would cost performance at every hash.
+ * Also, note that function signature doesn't offer room to return an error.
+ */
+ if (len <= 16)
+ return XXH3_len_0to16_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, seed64);
+ if (len <= 128)
+ return XXH3_len_17to128_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64);
+ if (len <= XXH3_MIDSIZE_MAX)
+ return XXH3_len_129to240_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64);
+ return f_hashLong(input, len, seed64, (const xxh_u8*)secret, secretLen);
+}
+
+
+/* === Public entry point === */
+
+XXH_PUBLIC_API XXH64_hash_t XXH3_64bits(const void* input, size_t len)
+{
+ return XXH3_64bits_internal(input, len, 0, XXH3_kSecret, sizeof(XXH3_kSecret), XXH3_hashLong_64b_default);
+}
+
+XXH_PUBLIC_API XXH64_hash_t
+XXH3_64bits_withSecret(const void* input, size_t len, const void* secret, size_t secretSize)
+{
+ return XXH3_64bits_internal(input, len, 0, secret, secretSize, XXH3_hashLong_64b_withSecret);
+}
+
+XXH_PUBLIC_API XXH64_hash_t
+XXH3_64bits_withSeed(const void* input, size_t len, XXH64_hash_t seed)
+{
+ return XXH3_64bits_internal(input, len, seed, XXH3_kSecret, sizeof(XXH3_kSecret), XXH3_hashLong_64b_withSeed);
+}
+
+
+/* === XXH3 streaming === */
+
+/*
+ * Malloc's a pointer that is always aligned to align.
+ *
+ * This must be freed with `XXH_alignedFree()`.
+ *
+ * malloc typically guarantees 16 byte alignment on 64-bit systems and 8 byte
+ * alignment on 32-bit. This isn't enough for the 32 byte aligned loads in AVX2
+ * or on 32-bit, the 16 byte aligned loads in SSE2 and NEON.
+ *
+ * This underalignment previously caused a rather obvious crash which went
+ * completely unnoticed due to XXH3_createState() not actually being tested.
+ * Credit to RedSpah for noticing this bug.
+ *
+ * The alignment is done manually: Functions like posix_memalign or _mm_malloc
+ * are avoided: To maintain portability, we would have to write a fallback
+ * like this anyways, and besides, testing for the existence of library
+ * functions without relying on external build tools is impossible.
+ *
+ * The method is simple: Overallocate, manually align, and store the offset
+ * to the original behind the returned pointer.
+ *
+ * Align must be a power of 2 and 8 <= align <= 128.
+ */
+static void* XXH_alignedMalloc(size_t s, size_t align)
+{
+ XXH_ASSERT(align <= 128 && align >= 8); /* range check */
+ XXH_ASSERT((align & (align-1)) == 0); /* power of 2 */
+ XXH_ASSERT(s != 0 && s < (s + align)); /* empty/overflow */
+ { /* Overallocate to make room for manual realignment and an offset byte */
+ xxh_u8* base = (xxh_u8*)XXH_malloc(s + align);
+ if (base != NULL) {
+ /*
+ * Get the offset needed to align this pointer.
+ *
+ * Even if the returned pointer is aligned, there will always be
+ * at least one byte to store the offset to the original pointer.
+ */
+ size_t offset = align - ((size_t)base & (align - 1)); /* base % align */
+ /* Add the offset for the now-aligned pointer */
+ xxh_u8* ptr = base + offset;
+
+ XXH_ASSERT((size_t)ptr % align == 0);
+
+ /* Store the offset immediately before the returned pointer. */
+ ptr[-1] = (xxh_u8)offset;
+ return ptr;
+ }
+ return NULL;
+ }
+}
+/*
+ * Frees an aligned pointer allocated by XXH_alignedMalloc(). Don't pass
+ * normal malloc'd pointers, XXH_alignedMalloc has a specific data layout.
+ */
+static void XXH_alignedFree(void* p)
+{
+ if (p != NULL) {
+ xxh_u8* ptr = (xxh_u8*)p;
+ /* Get the offset byte we added in XXH_malloc. */
+ xxh_u8 offset = ptr[-1];
+ /* Free the original malloc'd pointer */
+ xxh_u8* base = ptr - offset;
+ XXH_free(base);
+ }
+}
+XXH_PUBLIC_API XXH3_state_t* XXH3_createState(void)
+{
+ XXH3_state_t* const state = (XXH3_state_t*)XXH_alignedMalloc(sizeof(XXH3_state_t), 64);
+ if (state==NULL) return NULL;
+ XXH3_INITSTATE(state);
+ return state;
+}
+
+XXH_PUBLIC_API XXH_errorcode XXH3_freeState(XXH3_state_t* statePtr)
+{
+ XXH_alignedFree(statePtr);
+ return XXH_OK;
+}
+
+XXH_PUBLIC_API void
+XXH3_copyState(XXH3_state_t* dst_state, const XXH3_state_t* src_state)
+{
+ memcpy(dst_state, src_state, sizeof(*dst_state));
+}
+
+static void
+XXH3_64bits_reset_internal(XXH3_state_t* statePtr,
+ XXH64_hash_t seed,
+ const void* secret, size_t secretSize)
+{
+ size_t const initStart = offsetof(XXH3_state_t, bufferedSize);
+ size_t const initLength = offsetof(XXH3_state_t, nbStripesPerBlock) - initStart;
+ XXH_ASSERT(offsetof(XXH3_state_t, nbStripesPerBlock) > initStart);
+ XXH_ASSERT(statePtr != NULL);
+ /* set members from bufferedSize to nbStripesPerBlock (excluded) to 0 */
+ memset((char*)statePtr + initStart, 0, initLength);
+ statePtr->acc[0] = XXH_PRIME32_3;
+ statePtr->acc[1] = XXH_PRIME64_1;
+ statePtr->acc[2] = XXH_PRIME64_2;
+ statePtr->acc[3] = XXH_PRIME64_3;
+ statePtr->acc[4] = XXH_PRIME64_4;
+ statePtr->acc[5] = XXH_PRIME32_2;
+ statePtr->acc[6] = XXH_PRIME64_5;
+ statePtr->acc[7] = XXH_PRIME32_1;
+ statePtr->seed = seed;
+ statePtr->extSecret = (const unsigned char*)secret;
+ XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN);
+ statePtr->secretLimit = secretSize - XXH_STRIPE_LEN;
+ statePtr->nbStripesPerBlock = statePtr->secretLimit / XXH_SECRET_CONSUME_RATE;
+}
+
+XXH_PUBLIC_API XXH_errorcode
+XXH3_64bits_reset(XXH3_state_t* statePtr)
+{
+ if (statePtr == NULL) return XXH_ERROR;
+ XXH3_64bits_reset_internal(statePtr, 0, XXH3_kSecret, XXH_SECRET_DEFAULT_SIZE);
+ return XXH_OK;
+}
+
+XXH_PUBLIC_API XXH_errorcode
+XXH3_64bits_reset_withSecret(XXH3_state_t* statePtr, const void* secret, size_t secretSize)
+{
+ if (statePtr == NULL) return XXH_ERROR;
+ XXH3_64bits_reset_internal(statePtr, 0, secret, secretSize);
+ if (secret == NULL) return XXH_ERROR;
+ if (secretSize < XXH3_SECRET_SIZE_MIN) return XXH_ERROR;
+ return XXH_OK;
+}
+
+XXH_PUBLIC_API XXH_errorcode
+XXH3_64bits_reset_withSeed(XXH3_state_t* statePtr, XXH64_hash_t seed)
+{
+ if (statePtr == NULL) return XXH_ERROR;
+ if (seed==0) return XXH3_64bits_reset(statePtr);
+ if (seed != statePtr->seed) XXH3_initCustomSecret(statePtr->customSecret, seed);
+ XXH3_64bits_reset_internal(statePtr, seed, NULL, XXH_SECRET_DEFAULT_SIZE);
+ return XXH_OK;
+}
+
+/* Note : when XXH3_consumeStripes() is invoked,
+ * there must be a guarantee that at least one more byte must be consumed from input
+ * so that the function can blindly consume all stripes using the "normal" secret segment */
+XXH_FORCE_INLINE void
+XXH3_consumeStripes(xxh_u64* XXH_RESTRICT acc,
+ size_t* XXH_RESTRICT nbStripesSoFarPtr, size_t nbStripesPerBlock,
+ const xxh_u8* XXH_RESTRICT input, size_t nbStripes,
+ const xxh_u8* XXH_RESTRICT secret, size_t secretLimit,
+ XXH3_f_accumulate_512 f_acc512,
+ XXH3_f_scrambleAcc f_scramble)
+{
+ XXH_ASSERT(nbStripes <= nbStripesPerBlock); /* can handle max 1 scramble per invocation */
+ XXH_ASSERT(*nbStripesSoFarPtr < nbStripesPerBlock);
+ if (nbStripesPerBlock - *nbStripesSoFarPtr <= nbStripes) {
+ /* need a scrambling operation */
+ size_t const nbStripesToEndofBlock = nbStripesPerBlock - *nbStripesSoFarPtr;
+ size_t const nbStripesAfterBlock = nbStripes - nbStripesToEndofBlock;
+ XXH3_accumulate(acc, input, secret + nbStripesSoFarPtr[0] * XXH_SECRET_CONSUME_RATE, nbStripesToEndofBlock, f_acc512);
+ f_scramble(acc, secret + secretLimit);
+ XXH3_accumulate(acc, input + nbStripesToEndofBlock * XXH_STRIPE_LEN, secret, nbStripesAfterBlock, f_acc512);
+ *nbStripesSoFarPtr = nbStripesAfterBlock;
+ } else {
+ XXH3_accumulate(acc, input, secret + nbStripesSoFarPtr[0] * XXH_SECRET_CONSUME_RATE, nbStripes, f_acc512);
+ *nbStripesSoFarPtr += nbStripes;
+ }
+}
+
+/*
+ * Both XXH3_64bits_update and XXH3_128bits_update use this routine.
+ */
+XXH_FORCE_INLINE XXH_errorcode
+XXH3_update(XXH3_state_t* state,
+ const xxh_u8* input, size_t len,
+ XXH3_f_accumulate_512 f_acc512,
+ XXH3_f_scrambleAcc f_scramble)
+{
+ if (input==NULL)
+#if defined(XXH_ACCEPT_NULL_INPUT_POINTER) && (XXH_ACCEPT_NULL_INPUT_POINTER>=1)
+ return XXH_OK;
+#else
+ return XXH_ERROR;
+#endif
+
+ { const xxh_u8* const bEnd = input + len;
+ const unsigned char* const secret = (state->extSecret == NULL) ? state->customSecret : state->extSecret;
+
+ state->totalLen += len;
+
+ if (state->bufferedSize + len <= XXH3_INTERNALBUFFER_SIZE) { /* fill in tmp buffer */
+ XXH_memcpy(state->buffer + state->bufferedSize, input, len);
+ state->bufferedSize += (XXH32_hash_t)len;
+ return XXH_OK;
+ }
+ /* total input is now > XXH3_INTERNALBUFFER_SIZE */
+
+ #define XXH3_INTERNALBUFFER_STRIPES (XXH3_INTERNALBUFFER_SIZE / XXH_STRIPE_LEN)
+ XXH_STATIC_ASSERT(XXH3_INTERNALBUFFER_SIZE % XXH_STRIPE_LEN == 0); /* clean multiple */
+
+ /*
+ * Internal buffer is partially filled (always, except at beginning)
+ * Complete it, then consume it.
+ */
+ if (state->bufferedSize) {
+ size_t const loadSize = XXH3_INTERNALBUFFER_SIZE - state->bufferedSize;
+ XXH_memcpy(state->buffer + state->bufferedSize, input, loadSize);
+ input += loadSize;
+ XXH3_consumeStripes(state->acc,
+ &state->nbStripesSoFar, state->nbStripesPerBlock,
+ state->buffer, XXH3_INTERNALBUFFER_STRIPES,
+ secret, state->secretLimit,
+ f_acc512, f_scramble);
+ state->bufferedSize = 0;
+ }
+ XXH_ASSERT(input < bEnd);
+
+ /* Consume input by a multiple of internal buffer size */
+ if (input+XXH3_INTERNALBUFFER_SIZE < bEnd) {
+ const xxh_u8* const limit = bEnd - XXH3_INTERNALBUFFER_SIZE;
+ do {
+ XXH3_consumeStripes(state->acc,
+ &state->nbStripesSoFar, state->nbStripesPerBlock,
+ input, XXH3_INTERNALBUFFER_STRIPES,
+ secret, state->secretLimit,
+ f_acc512, f_scramble);
+ input += XXH3_INTERNALBUFFER_SIZE;
+ } while (input<limit);
+ /* for last partial stripe */
+ memcpy(state->buffer + sizeof(state->buffer) - XXH_STRIPE_LEN, input - XXH_STRIPE_LEN, XXH_STRIPE_LEN);
+ }
+ XXH_ASSERT(input < bEnd);
+
+ /* Some remaining input (always) : buffer it */
+ XXH_memcpy(state->buffer, input, (size_t)(bEnd-input));
+ state->bufferedSize = (XXH32_hash_t)(bEnd-input);
+ }
+
+ return XXH_OK;
+}
+
+XXH_PUBLIC_API XXH_errorcode
+XXH3_64bits_update(XXH3_state_t* state, const void* input, size_t len)
+{
+ return XXH3_update(state, (const xxh_u8*)input, len,
+ XXH3_accumulate_512, XXH3_scrambleAcc);
+}
+
+
+XXH_FORCE_INLINE void
+XXH3_digest_long (XXH64_hash_t* acc,
+ const XXH3_state_t* state,
+ const unsigned char* secret)
+{
+ /*
+ * Digest on a local copy. This way, the state remains unaltered, and it can
+ * continue ingesting more input afterwards.
+ */
+ memcpy(acc, state->acc, sizeof(state->acc));
+ if (state->bufferedSize >= XXH_STRIPE_LEN) {
+ size_t const nbStripes = (state->bufferedSize - 1) / XXH_STRIPE_LEN;
+ size_t nbStripesSoFar = state->nbStripesSoFar;
+ XXH3_consumeStripes(acc,
+ &nbStripesSoFar, state->nbStripesPerBlock,
+ state->buffer, nbStripes,
+ secret, state->secretLimit,
+ XXH3_accumulate_512, XXH3_scrambleAcc);
+ /* last stripe */
+ XXH3_accumulate_512(acc,
+ state->buffer + state->bufferedSize - XXH_STRIPE_LEN,
+ secret + state->secretLimit - XXH_SECRET_LASTACC_START);
+ } else { /* bufferedSize < XXH_STRIPE_LEN */
+ xxh_u8 lastStripe[XXH_STRIPE_LEN];
+ size_t const catchupSize = XXH_STRIPE_LEN - state->bufferedSize;
+ XXH_ASSERT(state->bufferedSize > 0); /* there is always some input buffered */
+ memcpy(lastStripe, state->buffer + sizeof(state->buffer) - catchupSize, catchupSize);
+ memcpy(lastStripe + catchupSize, state->buffer, state->bufferedSize);
+ XXH3_accumulate_512(acc,
+ lastStripe,
+ secret + state->secretLimit - XXH_SECRET_LASTACC_START);
+ }
+}
+
+XXH_PUBLIC_API XXH64_hash_t XXH3_64bits_digest (const XXH3_state_t* state)
+{
+ const unsigned char* const secret = (state->extSecret == NULL) ? state->customSecret : state->extSecret;
+ if (state->totalLen > XXH3_MIDSIZE_MAX) {
+ XXH_ALIGN(XXH_ACC_ALIGN) XXH64_hash_t acc[XXH_ACC_NB];
+ XXH3_digest_long(acc, state, secret);
+ return XXH3_mergeAccs(acc,
+ secret + XXH_SECRET_MERGEACCS_START,
+ (xxh_u64)state->totalLen * XXH_PRIME64_1);
+ }
+ /* totalLen <= XXH3_MIDSIZE_MAX: digesting a short input */
+ if (state->seed)
+ return XXH3_64bits_withSeed(state->buffer, (size_t)state->totalLen, state->seed);
+ return XXH3_64bits_withSecret(state->buffer, (size_t)(state->totalLen),
+ secret, state->secretLimit + XXH_STRIPE_LEN);
+}
+
+
+#define XXH_MIN(x, y) (((x) > (y)) ? (y) : (x))
+
+XXH_PUBLIC_API void
+XXH3_generateSecret(void* secretBuffer, const void* customSeed, size_t customSeedSize)
+{
+ XXH_ASSERT(secretBuffer != NULL);
+ if (customSeedSize == 0) {
+ memcpy(secretBuffer, XXH3_kSecret, XXH_SECRET_DEFAULT_SIZE);
+ return;
+ }
+ XXH_ASSERT(customSeed != NULL);
+
+ { size_t const segmentSize = sizeof(XXH128_hash_t);
+ size_t const nbSegments = XXH_SECRET_DEFAULT_SIZE / segmentSize;
+ XXH128_canonical_t scrambler;
+ XXH64_hash_t seeds[12];
+ size_t segnb;
+ XXH_ASSERT(nbSegments == 12);
+ XXH_ASSERT(segmentSize * nbSegments == XXH_SECRET_DEFAULT_SIZE); /* exact multiple */
+ XXH128_canonicalFromHash(&scrambler, XXH128(customSeed, customSeedSize, 0));
+
+ /*
+ * Copy customSeed to seeds[], truncating or repeating as necessary.
+ */
+ { size_t toFill = XXH_MIN(customSeedSize, sizeof(seeds));
+ size_t filled = toFill;
+ memcpy(seeds, customSeed, toFill);
+ while (filled < sizeof(seeds)) {
+ toFill = XXH_MIN(filled, sizeof(seeds) - filled);
+ memcpy((char*)seeds + filled, seeds, toFill);
+ filled += toFill;
+ } }
+
+ /* generate secret */
+ memcpy(secretBuffer, &scrambler, sizeof(scrambler));
+ for (segnb=1; segnb < nbSegments; segnb++) {
+ size_t const segmentStart = segnb * segmentSize;
+ XXH128_canonical_t segment;
+ XXH128_canonicalFromHash(&segment,
+ XXH128(&scrambler, sizeof(scrambler), XXH_readLE64(seeds + segnb) + segnb) );
+ memcpy((char*)secretBuffer + segmentStart, &segment, sizeof(segment));
+ } }
+}
+
+
+/* ==========================================
+ * XXH3 128 bits (a.k.a XXH128)
+ * ==========================================
+ * XXH3's 128-bit variant has better mixing and strength than the 64-bit variant,
+ * even without counting the significantly larger output size.
+ *
+ * For example, extra steps are taken to avoid the seed-dependent collisions
+ * in 17-240 byte inputs (See XXH3_mix16B and XXH128_mix32B).
+ *
+ * This strength naturally comes at the cost of some speed, especially on short
+ * lengths. Note that longer hashes are about as fast as the 64-bit version
+ * due to it using only a slight modification of the 64-bit loop.
+ *
+ * XXH128 is also more oriented towards 64-bit machines. It is still extremely
+ * fast for a _128-bit_ hash on 32-bit (it usually clears XXH64).
+ */
+
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_len_1to3_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ /* A doubled version of 1to3_64b with different constants. */
+ XXH_ASSERT(input != NULL);
+ XXH_ASSERT(1 <= len && len <= 3);
+ XXH_ASSERT(secret != NULL);
+ /*
+ * len = 1: combinedl = { input[0], 0x01, input[0], input[0] }
+ * len = 2: combinedl = { input[1], 0x02, input[0], input[1] }
+ * len = 3: combinedl = { input[2], 0x03, input[0], input[1] }
+ */
+ { xxh_u8 const c1 = input[0];
+ xxh_u8 const c2 = input[len >> 1];
+ xxh_u8 const c3 = input[len - 1];
+ xxh_u32 const combinedl = ((xxh_u32)c1 <<16) | ((xxh_u32)c2 << 24)
+ | ((xxh_u32)c3 << 0) | ((xxh_u32)len << 8);
+ xxh_u32 const combinedh = XXH_rotl32(XXH_swap32(combinedl), 13);
+ xxh_u64 const bitflipl = (XXH_readLE32(secret) ^ XXH_readLE32(secret+4)) + seed;
+ xxh_u64 const bitfliph = (XXH_readLE32(secret+8) ^ XXH_readLE32(secret+12)) - seed;
+ xxh_u64 const keyed_lo = (xxh_u64)combinedl ^ bitflipl;
+ xxh_u64 const keyed_hi = (xxh_u64)combinedh ^ bitfliph;
+ XXH128_hash_t h128;
+ h128.low64 = XXH64_avalanche(keyed_lo);
+ h128.high64 = XXH64_avalanche(keyed_hi);
+ return h128;
+ }
+}
+
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_len_4to8_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ XXH_ASSERT(input != NULL);
+ XXH_ASSERT(secret != NULL);
+ XXH_ASSERT(4 <= len && len <= 8);
+ seed ^= (xxh_u64)XXH_swap32((xxh_u32)seed) << 32;
+ { xxh_u32 const input_lo = XXH_readLE32(input);
+ xxh_u32 const input_hi = XXH_readLE32(input + len - 4);
+ xxh_u64 const input_64 = input_lo + ((xxh_u64)input_hi << 32);
+ xxh_u64 const bitflip = (XXH_readLE64(secret+16) ^ XXH_readLE64(secret+24)) + seed;
+ xxh_u64 const keyed = input_64 ^ bitflip;
+
+ /* Shift len to the left to ensure it is even, this avoids even multiplies. */
+ XXH128_hash_t m128 = XXH_mult64to128(keyed, XXH_PRIME64_1 + (len << 2));
+
+ m128.high64 += (m128.low64 << 1);
+ m128.low64 ^= (m128.high64 >> 3);
+
+ m128.low64 = XXH_xorshift64(m128.low64, 35);
+ m128.low64 *= 0x9FB21C651E98DF25ULL;
+ m128.low64 = XXH_xorshift64(m128.low64, 28);
+ m128.high64 = XXH3_avalanche(m128.high64);
+ return m128;
+ }
+}
+
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_len_9to16_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ XXH_ASSERT(input != NULL);
+ XXH_ASSERT(secret != NULL);
+ XXH_ASSERT(9 <= len && len <= 16);
+ { xxh_u64 const bitflipl = (XXH_readLE64(secret+32) ^ XXH_readLE64(secret+40)) - seed;
+ xxh_u64 const bitfliph = (XXH_readLE64(secret+48) ^ XXH_readLE64(secret+56)) + seed;
+ xxh_u64 const input_lo = XXH_readLE64(input);
+ xxh_u64 input_hi = XXH_readLE64(input + len - 8);
+ XXH128_hash_t m128 = XXH_mult64to128(input_lo ^ input_hi ^ bitflipl, XXH_PRIME64_1);
+ /*
+ * Put len in the middle of m128 to ensure that the length gets mixed to
+ * both the low and high bits in the 128x64 multiply below.
+ */
+ m128.low64 += (xxh_u64)(len - 1) << 54;
+ input_hi ^= bitfliph;
+ /*
+ * Add the high 32 bits of input_hi to the high 32 bits of m128, then
+ * add the long product of the low 32 bits of input_hi and XXH_PRIME32_2 to
+ * the high 64 bits of m128.
+ *
+ * The best approach to this operation is different on 32-bit and 64-bit.
+ */
+ if (sizeof(void *) < sizeof(xxh_u64)) { /* 32-bit */
+ /*
+ * 32-bit optimized version, which is more readable.
+ *
+ * On 32-bit, it removes an ADC and delays a dependency between the two
+ * halves of m128.high64, but it generates an extra mask on 64-bit.
+ */
+ m128.high64 += (input_hi & 0xFFFFFFFF00000000ULL) + XXH_mult32to64((xxh_u32)input_hi, XXH_PRIME32_2);
+ } else {
+ /*
+ * 64-bit optimized (albeit more confusing) version.
+ *
+ * Uses some properties of addition and multiplication to remove the mask:
+ *
+ * Let:
+ * a = input_hi.lo = (input_hi & 0x00000000FFFFFFFF)
+ * b = input_hi.hi = (input_hi & 0xFFFFFFFF00000000)
+ * c = XXH_PRIME32_2
+ *
+ * a + (b * c)
+ * Inverse Property: x + y - x == y
+ * a + (b * (1 + c - 1))
+ * Distributive Property: x * (y + z) == (x * y) + (x * z)
+ * a + (b * 1) + (b * (c - 1))
+ * Identity Property: x * 1 == x
+ * a + b + (b * (c - 1))
+ *
+ * Substitute a, b, and c:
+ * input_hi.hi + input_hi.lo + ((xxh_u64)input_hi.lo * (XXH_PRIME32_2 - 1))
+ *
+ * Since input_hi.hi + input_hi.lo == input_hi, we get this:
+ * input_hi + ((xxh_u64)input_hi.lo * (XXH_PRIME32_2 - 1))
+ */
+ m128.high64 += input_hi + XXH_mult32to64((xxh_u32)input_hi, XXH_PRIME32_2 - 1);
+ }
+ /* m128 ^= XXH_swap64(m128 >> 64); */
+ m128.low64 ^= XXH_swap64(m128.high64);
+
+ { /* 128x64 multiply: h128 = m128 * XXH_PRIME64_2; */
+ XXH128_hash_t h128 = XXH_mult64to128(m128.low64, XXH_PRIME64_2);
+ h128.high64 += m128.high64 * XXH_PRIME64_2;
+
+ h128.low64 = XXH3_avalanche(h128.low64);
+ h128.high64 = XXH3_avalanche(h128.high64);
+ return h128;
+ } }
+}
+
+/*
+ * Assumption: `secret` size is >= XXH3_SECRET_SIZE_MIN
+ */
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_len_0to16_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ XXH_ASSERT(len <= 16);
+ { if (len > 8) return XXH3_len_9to16_128b(input, len, secret, seed);
+ if (len >= 4) return XXH3_len_4to8_128b(input, len, secret, seed);
+ if (len) return XXH3_len_1to3_128b(input, len, secret, seed);
+ { XXH128_hash_t h128;
+ xxh_u64 const bitflipl = XXH_readLE64(secret+64) ^ XXH_readLE64(secret+72);
+ xxh_u64 const bitfliph = XXH_readLE64(secret+80) ^ XXH_readLE64(secret+88);
+ h128.low64 = XXH64_avalanche(seed ^ bitflipl);
+ h128.high64 = XXH64_avalanche( seed ^ bitfliph);
+ return h128;
+ } }
+}
+
+/*
+ * A bit slower than XXH3_mix16B, but handles multiply by zero better.
+ */
+XXH_FORCE_INLINE XXH128_hash_t
+XXH128_mix32B(XXH128_hash_t acc, const xxh_u8* input_1, const xxh_u8* input_2,
+ const xxh_u8* secret, XXH64_hash_t seed)
+{
+ acc.low64 += XXH3_mix16B (input_1, secret+0, seed);
+ acc.low64 ^= XXH_readLE64(input_2) + XXH_readLE64(input_2 + 8);
+ acc.high64 += XXH3_mix16B (input_2, secret+16, seed);
+ acc.high64 ^= XXH_readLE64(input_1) + XXH_readLE64(input_1 + 8);
+ return acc;
+}
+
+
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_len_17to128_128b(const xxh_u8* XXH_RESTRICT input, size_t len,
+ const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+ XXH64_hash_t seed)
+{
+ XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize;
+ XXH_ASSERT(16 < len && len <= 128);
+
+ { XXH128_hash_t acc;
+ acc.low64 = len * XXH_PRIME64_1;
+ acc.high64 = 0;
+ if (len > 32) {
+ if (len > 64) {
+ if (len > 96) {
+ acc = XXH128_mix32B(acc, input+48, input+len-64, secret+96, seed);
+ }
+ acc = XXH128_mix32B(acc, input+32, input+len-48, secret+64, seed);
+ }
+ acc = XXH128_mix32B(acc, input+16, input+len-32, secret+32, seed);
+ }
+ acc = XXH128_mix32B(acc, input, input+len-16, secret, seed);
+ { XXH128_hash_t h128;
+ h128.low64 = acc.low64 + acc.high64;
+ h128.high64 = (acc.low64 * XXH_PRIME64_1)
+ + (acc.high64 * XXH_PRIME64_4)
+ + ((len - seed) * XXH_PRIME64_2);
+ h128.low64 = XXH3_avalanche(h128.low64);
+ h128.high64 = (XXH64_hash_t)0 - XXH3_avalanche(h128.high64);
+ return h128;
+ }
+ }
+}
+
+XXH_NO_INLINE XXH128_hash_t
+XXH3_len_129to240_128b(const xxh_u8* XXH_RESTRICT input, size_t len,
+ const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+ XXH64_hash_t seed)
+{
+ XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize;
+ XXH_ASSERT(128 < len && len <= XXH3_MIDSIZE_MAX);
+
+ { XXH128_hash_t acc;
+ int const nbRounds = (int)len / 32;
+ int i;
+ acc.low64 = len * XXH_PRIME64_1;
+ acc.high64 = 0;
+ for (i=0; i<4; i++) {
+ acc = XXH128_mix32B(acc,
+ input + (32 * i),
+ input + (32 * i) + 16,
+ secret + (32 * i),
+ seed);
+ }
+ acc.low64 = XXH3_avalanche(acc.low64);
+ acc.high64 = XXH3_avalanche(acc.high64);
+ XXH_ASSERT(nbRounds >= 4);
+ for (i=4 ; i < nbRounds; i++) {
+ acc = XXH128_mix32B(acc,
+ input + (32 * i),
+ input + (32 * i) + 16,
+ secret + XXH3_MIDSIZE_STARTOFFSET + (32 * (i - 4)),
+ seed);
+ }
+ /* last bytes */
+ acc = XXH128_mix32B(acc,
+ input + len - 16,
+ input + len - 32,
+ secret + XXH3_SECRET_SIZE_MIN - XXH3_MIDSIZE_LASTOFFSET - 16,
+ 0ULL - seed);
+
+ { XXH128_hash_t h128;
+ h128.low64 = acc.low64 + acc.high64;
+ h128.high64 = (acc.low64 * XXH_PRIME64_1)
+ + (acc.high64 * XXH_PRIME64_4)
+ + ((len - seed) * XXH_PRIME64_2);
+ h128.low64 = XXH3_avalanche(h128.low64);
+ h128.high64 = (XXH64_hash_t)0 - XXH3_avalanche(h128.high64);
+ return h128;
+ }
+ }
+}
+
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_hashLong_128b_internal(const void* XXH_RESTRICT input, size_t len,
+ const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+ XXH3_f_accumulate_512 f_acc512,
+ XXH3_f_scrambleAcc f_scramble)
+{
+ XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64 acc[XXH_ACC_NB] = XXH3_INIT_ACC;
+
+ XXH3_hashLong_internal_loop(acc, (const xxh_u8*)input, len, secret, secretSize, f_acc512, f_scramble);
+
+ /* converge into final hash */
+ XXH_STATIC_ASSERT(sizeof(acc) == 64);
+ XXH_ASSERT(secretSize >= sizeof(acc) + XXH_SECRET_MERGEACCS_START);
+ { XXH128_hash_t h128;
+ h128.low64 = XXH3_mergeAccs(acc,
+ secret + XXH_SECRET_MERGEACCS_START,
+ (xxh_u64)len * XXH_PRIME64_1);
+ h128.high64 = XXH3_mergeAccs(acc,
+ secret + secretSize
+ - sizeof(acc) - XXH_SECRET_MERGEACCS_START,
+ ~((xxh_u64)len * XXH_PRIME64_2));
+ return h128;
+ }
+}
+
+/*
+ * It's important for performance that XXH3_hashLong is not inlined.
+ */
+XXH_NO_INLINE XXH128_hash_t
+XXH3_hashLong_128b_default(const void* XXH_RESTRICT input, size_t len,
+ XXH64_hash_t seed64,
+ const void* XXH_RESTRICT secret, size_t secretLen)
+{
+ (void)seed64; (void)secret; (void)secretLen;
+ return XXH3_hashLong_128b_internal(input, len, XXH3_kSecret, sizeof(XXH3_kSecret),
+ XXH3_accumulate_512, XXH3_scrambleAcc);
+}
+
+/*
+ * It's important for performance that XXH3_hashLong is not inlined.
+ */
+XXH_NO_INLINE XXH128_hash_t
+XXH3_hashLong_128b_withSecret(const void* XXH_RESTRICT input, size_t len,
+ XXH64_hash_t seed64,
+ const void* XXH_RESTRICT secret, size_t secretLen)
+{
+ (void)seed64;
+ return XXH3_hashLong_128b_internal(input, len, (const xxh_u8*)secret, secretLen,
+ XXH3_accumulate_512, XXH3_scrambleAcc);
+}
+
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_hashLong_128b_withSeed_internal(const void* XXH_RESTRICT input, size_t len,
+ XXH64_hash_t seed64,
+ XXH3_f_accumulate_512 f_acc512,
+ XXH3_f_scrambleAcc f_scramble,
+ XXH3_f_initCustomSecret f_initSec)
+{
+ if (seed64 == 0)
+ return XXH3_hashLong_128b_internal(input, len,
+ XXH3_kSecret, sizeof(XXH3_kSecret),
+ f_acc512, f_scramble);
+ { XXH_ALIGN(XXH_SEC_ALIGN) xxh_u8 secret[XXH_SECRET_DEFAULT_SIZE];
+ f_initSec(secret, seed64);
+ return XXH3_hashLong_128b_internal(input, len, (const xxh_u8*)secret, sizeof(secret),
+ f_acc512, f_scramble);
+ }
+}
+
+/*
+ * It's important for performance that XXH3_hashLong is not inlined.
+ */
+XXH_NO_INLINE XXH128_hash_t
+XXH3_hashLong_128b_withSeed(const void* input, size_t len,
+ XXH64_hash_t seed64, const void* XXH_RESTRICT secret, size_t secretLen)
+{
+ (void)secret; (void)secretLen;
+ return XXH3_hashLong_128b_withSeed_internal(input, len, seed64,
+ XXH3_accumulate_512, XXH3_scrambleAcc, XXH3_initCustomSecret);
+}
+
+typedef XXH128_hash_t (*XXH3_hashLong128_f)(const void* XXH_RESTRICT, size_t,
+ XXH64_hash_t, const void* XXH_RESTRICT, size_t);
+
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_128bits_internal(const void* input, size_t len,
+ XXH64_hash_t seed64, const void* XXH_RESTRICT secret, size_t secretLen,
+ XXH3_hashLong128_f f_hl128)
+{
+ XXH_ASSERT(secretLen >= XXH3_SECRET_SIZE_MIN);
+ /*
+ * If an action is to be taken if `secret` conditions are not respected,
+ * it should be done here.
+ * For now, it's a contract pre-condition.
+ * Adding a check and a branch here would cost performance at every hash.
+ */
+ if (len <= 16)
+ return XXH3_len_0to16_128b((const xxh_u8*)input, len, (const xxh_u8*)secret, seed64);
+ if (len <= 128)
+ return XXH3_len_17to128_128b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64);
+ if (len <= XXH3_MIDSIZE_MAX)
+ return XXH3_len_129to240_128b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64);
+ return f_hl128(input, len, seed64, secret, secretLen);
+}
+
+
+/* === Public XXH128 API === */
+
+XXH_PUBLIC_API XXH128_hash_t XXH3_128bits(const void* input, size_t len)
+{
+ return XXH3_128bits_internal(input, len, 0,
+ XXH3_kSecret, sizeof(XXH3_kSecret),
+ XXH3_hashLong_128b_default);
+}
+
+XXH_PUBLIC_API XXH128_hash_t
+XXH3_128bits_withSecret(const void* input, size_t len, const void* secret, size_t secretSize)
+{
+ return XXH3_128bits_internal(input, len, 0,
+ (const xxh_u8*)secret, secretSize,
+ XXH3_hashLong_128b_withSecret);
+}
+
+XXH_PUBLIC_API XXH128_hash_t
+XXH3_128bits_withSeed(const void* input, size_t len, XXH64_hash_t seed)
+{
+ return XXH3_128bits_internal(input, len, seed,
+ XXH3_kSecret, sizeof(XXH3_kSecret),
+ XXH3_hashLong_128b_withSeed);
+}
+
+XXH_PUBLIC_API XXH128_hash_t
+XXH128(const void* input, size_t len, XXH64_hash_t seed)
+{
+ return XXH3_128bits_withSeed(input, len, seed);
+}
+
+
+/* === XXH3 128-bit streaming === */
+
+/*
+ * All the functions are actually the same as for 64-bit streaming variant.
+ * The only difference is the finalizatiom routine.
+ */
+
+static void
+XXH3_128bits_reset_internal(XXH3_state_t* statePtr,
+ XXH64_hash_t seed,
+ const void* secret, size_t secretSize)
+{
+ XXH3_64bits_reset_internal(statePtr, seed, secret, secretSize);
+}
+
+XXH_PUBLIC_API XXH_errorcode
+XXH3_128bits_reset(XXH3_state_t* statePtr)
+{
+ if (statePtr == NULL) return XXH_ERROR;
+ XXH3_128bits_reset_internal(statePtr, 0, XXH3_kSecret, XXH_SECRET_DEFAULT_SIZE);
+ return XXH_OK;
+}
+
+XXH_PUBLIC_API XXH_errorcode
+XXH3_128bits_reset_withSecret(XXH3_state_t* statePtr, const void* secret, size_t secretSize)
+{
+ if (statePtr == NULL) return XXH_ERROR;
+ XXH3_128bits_reset_internal(statePtr, 0, secret, secretSize);
+ if (secret == NULL) return XXH_ERROR;
+ if (secretSize < XXH3_SECRET_SIZE_MIN) return XXH_ERROR;
+ return XXH_OK;
+}
+
+XXH_PUBLIC_API XXH_errorcode
+XXH3_128bits_reset_withSeed(XXH3_state_t* statePtr, XXH64_hash_t seed)
+{
+ if (statePtr == NULL) return XXH_ERROR;
+ if (seed==0) return XXH3_128bits_reset(statePtr);
+ if (seed != statePtr->seed) XXH3_initCustomSecret(statePtr->customSecret, seed);
+ XXH3_128bits_reset_internal(statePtr, seed, NULL, XXH_SECRET_DEFAULT_SIZE);
+ return XXH_OK;
+}
+
+XXH_PUBLIC_API XXH_errorcode
+XXH3_128bits_update(XXH3_state_t* state, const void* input, size_t len)
+{
+ return XXH3_update(state, (const xxh_u8*)input, len,
+ XXH3_accumulate_512, XXH3_scrambleAcc);
+}
+
+XXH_PUBLIC_API XXH128_hash_t XXH3_128bits_digest (const XXH3_state_t* state)
+{
+ const unsigned char* const secret = (state->extSecret == NULL) ? state->customSecret : state->extSecret;
+ if (state->totalLen > XXH3_MIDSIZE_MAX) {
+ XXH_ALIGN(XXH_ACC_ALIGN) XXH64_hash_t acc[XXH_ACC_NB];
+ XXH3_digest_long(acc, state, secret);
+ XXH_ASSERT(state->secretLimit + XXH_STRIPE_LEN >= sizeof(acc) + XXH_SECRET_MERGEACCS_START);
+ { XXH128_hash_t h128;
+ h128.low64 = XXH3_mergeAccs(acc,
+ secret + XXH_SECRET_MERGEACCS_START,
+ (xxh_u64)state->totalLen * XXH_PRIME64_1);
+ h128.high64 = XXH3_mergeAccs(acc,
+ secret + state->secretLimit + XXH_STRIPE_LEN
+ - sizeof(acc) - XXH_SECRET_MERGEACCS_START,
+ ~((xxh_u64)state->totalLen * XXH_PRIME64_2));
+ return h128;
+ }
+ }
+ /* len <= XXH3_MIDSIZE_MAX : short code */
+ if (state->seed)
+ return XXH3_128bits_withSeed(state->buffer, (size_t)state->totalLen, state->seed);
+ return XXH3_128bits_withSecret(state->buffer, (size_t)(state->totalLen),
+ secret, state->secretLimit + XXH_STRIPE_LEN);
+}
+
+/* 128-bit utility functions */
+
+#include <string.h> /* memcmp, memcpy */
+
+/* return : 1 is equal, 0 if different */
+XXH_PUBLIC_API int XXH128_isEqual(XXH128_hash_t h1, XXH128_hash_t h2)
+{
+ /* note : XXH128_hash_t is compact, it has no padding byte */
+ return !(memcmp(&h1, &h2, sizeof(h1)));
+}
+
+/* This prototype is compatible with stdlib's qsort().
+ * return : >0 if *h128_1 > *h128_2
+ * <0 if *h128_1 < *h128_2
+ * =0 if *h128_1 == *h128_2 */
+XXH_PUBLIC_API int XXH128_cmp(const void* h128_1, const void* h128_2)
+{
+ XXH128_hash_t const h1 = *(const XXH128_hash_t*)h128_1;
+ XXH128_hash_t const h2 = *(const XXH128_hash_t*)h128_2;
+ int const hcmp = (h1.high64 > h2.high64) - (h2.high64 > h1.high64);
+ /* note : bets that, in most cases, hash values are different */
+ if (hcmp) return hcmp;
+ return (h1.low64 > h2.low64) - (h2.low64 > h1.low64);
+}
+
+
+/*====== Canonical representation ======*/
+XXH_PUBLIC_API void
+XXH128_canonicalFromHash(XXH128_canonical_t* dst, XXH128_hash_t hash)
+{
+ XXH_STATIC_ASSERT(sizeof(XXH128_canonical_t) == sizeof(XXH128_hash_t));
+ if (XXH_CPU_LITTLE_ENDIAN) {
+ hash.high64 = XXH_swap64(hash.high64);
+ hash.low64 = XXH_swap64(hash.low64);
+ }
+ memcpy(dst, &hash.high64, sizeof(hash.high64));
+ memcpy((char*)dst + sizeof(hash.high64), &hash.low64, sizeof(hash.low64));
+}
+
+XXH_PUBLIC_API XXH128_hash_t
+XXH128_hashFromCanonical(const XXH128_canonical_t* src)
+{
+ XXH128_hash_t h;
+ h.high64 = XXH_readBE64(src);
+ h.low64 = XXH_readBE64(src->digest + 8);
+ return h;
+}
+
+/* Pop our optimization override from above */
+#if XXH_VECTOR == XXH_AVX2 /* AVX2 */ \
+ && defined(__GNUC__) && !defined(__clang__) /* GCC, not Clang */ \
+ && defined(__OPTIMIZE__) && !defined(__OPTIMIZE_SIZE__) /* respect -O0 and -Os */
+# pragma GCC pop_options
+#endif
+
+#endif /* XXH_NO_LONG_LONG */
+
+
+#endif /* XXH_IMPLEMENTATION */
+
+
+#if defined (__cplusplus)
+}
+#endif
diff --git a/src/arrow/cpp/src/arrow/visitor.cc b/src/arrow/cpp/src/arrow/visitor.cc
new file mode 100644
index 000000000..1f2771bc2
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/visitor.cc
@@ -0,0 +1,172 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/visitor.h"
+
+#include <memory>
+
+#include "arrow/array.h" // IWYU pragma: keep
+#include "arrow/extension_type.h"
+#include "arrow/scalar.h" // IWYU pragma: keep
+#include "arrow/status.h"
+#include "arrow/type.h"
+
+namespace arrow {
+
+#define ARRAY_VISITOR_DEFAULT(ARRAY_CLASS) \
+ Status ArrayVisitor::Visit(const ARRAY_CLASS& array) { \
+ return Status::NotImplemented(array.type()->ToString()); \
+ }
+
+ARRAY_VISITOR_DEFAULT(NullArray)
+ARRAY_VISITOR_DEFAULT(BooleanArray)
+ARRAY_VISITOR_DEFAULT(Int8Array)
+ARRAY_VISITOR_DEFAULT(Int16Array)
+ARRAY_VISITOR_DEFAULT(Int32Array)
+ARRAY_VISITOR_DEFAULT(Int64Array)
+ARRAY_VISITOR_DEFAULT(UInt8Array)
+ARRAY_VISITOR_DEFAULT(UInt16Array)
+ARRAY_VISITOR_DEFAULT(UInt32Array)
+ARRAY_VISITOR_DEFAULT(UInt64Array)
+ARRAY_VISITOR_DEFAULT(HalfFloatArray)
+ARRAY_VISITOR_DEFAULT(FloatArray)
+ARRAY_VISITOR_DEFAULT(DoubleArray)
+ARRAY_VISITOR_DEFAULT(BinaryArray)
+ARRAY_VISITOR_DEFAULT(StringArray)
+ARRAY_VISITOR_DEFAULT(LargeBinaryArray)
+ARRAY_VISITOR_DEFAULT(LargeStringArray)
+ARRAY_VISITOR_DEFAULT(FixedSizeBinaryArray)
+ARRAY_VISITOR_DEFAULT(Date32Array)
+ARRAY_VISITOR_DEFAULT(Date64Array)
+ARRAY_VISITOR_DEFAULT(Time32Array)
+ARRAY_VISITOR_DEFAULT(Time64Array)
+ARRAY_VISITOR_DEFAULT(TimestampArray)
+ARRAY_VISITOR_DEFAULT(DayTimeIntervalArray)
+ARRAY_VISITOR_DEFAULT(MonthDayNanoIntervalArray)
+ARRAY_VISITOR_DEFAULT(MonthIntervalArray)
+ARRAY_VISITOR_DEFAULT(DurationArray)
+ARRAY_VISITOR_DEFAULT(ListArray)
+ARRAY_VISITOR_DEFAULT(LargeListArray)
+ARRAY_VISITOR_DEFAULT(MapArray)
+ARRAY_VISITOR_DEFAULT(FixedSizeListArray)
+ARRAY_VISITOR_DEFAULT(StructArray)
+ARRAY_VISITOR_DEFAULT(SparseUnionArray)
+ARRAY_VISITOR_DEFAULT(DenseUnionArray)
+ARRAY_VISITOR_DEFAULT(DictionaryArray)
+ARRAY_VISITOR_DEFAULT(Decimal128Array)
+ARRAY_VISITOR_DEFAULT(Decimal256Array)
+ARRAY_VISITOR_DEFAULT(ExtensionArray)
+
+#undef ARRAY_VISITOR_DEFAULT
+
+// ----------------------------------------------------------------------
+// Default implementations of TypeVisitor methods
+
+#define TYPE_VISITOR_DEFAULT(TYPE_CLASS) \
+ Status TypeVisitor::Visit(const TYPE_CLASS& type) { \
+ return Status::NotImplemented(type.ToString()); \
+ }
+
+TYPE_VISITOR_DEFAULT(NullType)
+TYPE_VISITOR_DEFAULT(BooleanType)
+TYPE_VISITOR_DEFAULT(Int8Type)
+TYPE_VISITOR_DEFAULT(Int16Type)
+TYPE_VISITOR_DEFAULT(Int32Type)
+TYPE_VISITOR_DEFAULT(Int64Type)
+TYPE_VISITOR_DEFAULT(UInt8Type)
+TYPE_VISITOR_DEFAULT(UInt16Type)
+TYPE_VISITOR_DEFAULT(UInt32Type)
+TYPE_VISITOR_DEFAULT(UInt64Type)
+TYPE_VISITOR_DEFAULT(HalfFloatType)
+TYPE_VISITOR_DEFAULT(FloatType)
+TYPE_VISITOR_DEFAULT(DoubleType)
+TYPE_VISITOR_DEFAULT(StringType)
+TYPE_VISITOR_DEFAULT(BinaryType)
+TYPE_VISITOR_DEFAULT(LargeStringType)
+TYPE_VISITOR_DEFAULT(LargeBinaryType)
+TYPE_VISITOR_DEFAULT(FixedSizeBinaryType)
+TYPE_VISITOR_DEFAULT(Date64Type)
+TYPE_VISITOR_DEFAULT(Date32Type)
+TYPE_VISITOR_DEFAULT(Time32Type)
+TYPE_VISITOR_DEFAULT(Time64Type)
+TYPE_VISITOR_DEFAULT(TimestampType)
+TYPE_VISITOR_DEFAULT(DayTimeIntervalType)
+TYPE_VISITOR_DEFAULT(MonthDayNanoIntervalType)
+TYPE_VISITOR_DEFAULT(MonthIntervalType)
+TYPE_VISITOR_DEFAULT(DurationType)
+TYPE_VISITOR_DEFAULT(Decimal128Type)
+TYPE_VISITOR_DEFAULT(Decimal256Type)
+TYPE_VISITOR_DEFAULT(ListType)
+TYPE_VISITOR_DEFAULT(LargeListType)
+TYPE_VISITOR_DEFAULT(MapType)
+TYPE_VISITOR_DEFAULT(FixedSizeListType)
+TYPE_VISITOR_DEFAULT(StructType)
+TYPE_VISITOR_DEFAULT(SparseUnionType)
+TYPE_VISITOR_DEFAULT(DenseUnionType)
+TYPE_VISITOR_DEFAULT(DictionaryType)
+TYPE_VISITOR_DEFAULT(ExtensionType)
+
+#undef TYPE_VISITOR_DEFAULT
+
+// ----------------------------------------------------------------------
+// Default implementations of ScalarVisitor methods
+
+#define SCALAR_VISITOR_DEFAULT(SCALAR_CLASS) \
+ Status ScalarVisitor::Visit(const SCALAR_CLASS& scalar) { \
+ return Status::NotImplemented( \
+ "ScalarVisitor not implemented for " ARROW_STRINGIFY(SCALAR_CLASS)); \
+ }
+
+SCALAR_VISITOR_DEFAULT(NullScalar)
+SCALAR_VISITOR_DEFAULT(BooleanScalar)
+SCALAR_VISITOR_DEFAULT(Int8Scalar)
+SCALAR_VISITOR_DEFAULT(Int16Scalar)
+SCALAR_VISITOR_DEFAULT(Int32Scalar)
+SCALAR_VISITOR_DEFAULT(Int64Scalar)
+SCALAR_VISITOR_DEFAULT(UInt8Scalar)
+SCALAR_VISITOR_DEFAULT(UInt16Scalar)
+SCALAR_VISITOR_DEFAULT(UInt32Scalar)
+SCALAR_VISITOR_DEFAULT(UInt64Scalar)
+SCALAR_VISITOR_DEFAULT(HalfFloatScalar)
+SCALAR_VISITOR_DEFAULT(FloatScalar)
+SCALAR_VISITOR_DEFAULT(DoubleScalar)
+SCALAR_VISITOR_DEFAULT(StringScalar)
+SCALAR_VISITOR_DEFAULT(BinaryScalar)
+SCALAR_VISITOR_DEFAULT(LargeStringScalar)
+SCALAR_VISITOR_DEFAULT(LargeBinaryScalar)
+SCALAR_VISITOR_DEFAULT(FixedSizeBinaryScalar)
+SCALAR_VISITOR_DEFAULT(Date64Scalar)
+SCALAR_VISITOR_DEFAULT(Date32Scalar)
+SCALAR_VISITOR_DEFAULT(Time32Scalar)
+SCALAR_VISITOR_DEFAULT(Time64Scalar)
+SCALAR_VISITOR_DEFAULT(TimestampScalar)
+SCALAR_VISITOR_DEFAULT(DayTimeIntervalScalar)
+SCALAR_VISITOR_DEFAULT(MonthDayNanoIntervalScalar)
+SCALAR_VISITOR_DEFAULT(MonthIntervalScalar)
+SCALAR_VISITOR_DEFAULT(DurationScalar)
+SCALAR_VISITOR_DEFAULT(Decimal128Scalar)
+SCALAR_VISITOR_DEFAULT(Decimal256Scalar)
+SCALAR_VISITOR_DEFAULT(ListScalar)
+SCALAR_VISITOR_DEFAULT(LargeListScalar)
+SCALAR_VISITOR_DEFAULT(MapScalar)
+SCALAR_VISITOR_DEFAULT(FixedSizeListScalar)
+SCALAR_VISITOR_DEFAULT(StructScalar)
+SCALAR_VISITOR_DEFAULT(DictionaryScalar)
+
+#undef SCALAR_VISITOR_DEFAULT
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/visitor.h b/src/arrow/cpp/src/arrow/visitor.h
new file mode 100644
index 000000000..18a5c7db0
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/visitor.h
@@ -0,0 +1,155 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class ARROW_EXPORT ArrayVisitor {
+ public:
+ virtual ~ArrayVisitor() = default;
+
+ virtual Status Visit(const NullArray& array);
+ virtual Status Visit(const BooleanArray& array);
+ virtual Status Visit(const Int8Array& array);
+ virtual Status Visit(const Int16Array& array);
+ virtual Status Visit(const Int32Array& array);
+ virtual Status Visit(const Int64Array& array);
+ virtual Status Visit(const UInt8Array& array);
+ virtual Status Visit(const UInt16Array& array);
+ virtual Status Visit(const UInt32Array& array);
+ virtual Status Visit(const UInt64Array& array);
+ virtual Status Visit(const HalfFloatArray& array);
+ virtual Status Visit(const FloatArray& array);
+ virtual Status Visit(const DoubleArray& array);
+ virtual Status Visit(const StringArray& array);
+ virtual Status Visit(const BinaryArray& array);
+ virtual Status Visit(const LargeStringArray& array);
+ virtual Status Visit(const LargeBinaryArray& array);
+ virtual Status Visit(const FixedSizeBinaryArray& array);
+ virtual Status Visit(const Date32Array& array);
+ virtual Status Visit(const Date64Array& array);
+ virtual Status Visit(const Time32Array& array);
+ virtual Status Visit(const Time64Array& array);
+ virtual Status Visit(const TimestampArray& array);
+ virtual Status Visit(const DayTimeIntervalArray& array);
+ virtual Status Visit(const MonthDayNanoIntervalArray& array);
+ virtual Status Visit(const MonthIntervalArray& array);
+ virtual Status Visit(const DurationArray& array);
+ virtual Status Visit(const Decimal128Array& array);
+ virtual Status Visit(const Decimal256Array& array);
+ virtual Status Visit(const ListArray& array);
+ virtual Status Visit(const LargeListArray& array);
+ virtual Status Visit(const MapArray& array);
+ virtual Status Visit(const FixedSizeListArray& array);
+ virtual Status Visit(const StructArray& array);
+ virtual Status Visit(const SparseUnionArray& array);
+ virtual Status Visit(const DenseUnionArray& array);
+ virtual Status Visit(const DictionaryArray& array);
+ virtual Status Visit(const ExtensionArray& array);
+};
+
+class ARROW_EXPORT TypeVisitor {
+ public:
+ virtual ~TypeVisitor() = default;
+
+ virtual Status Visit(const NullType& type);
+ virtual Status Visit(const BooleanType& type);
+ virtual Status Visit(const Int8Type& type);
+ virtual Status Visit(const Int16Type& type);
+ virtual Status Visit(const Int32Type& type);
+ virtual Status Visit(const Int64Type& type);
+ virtual Status Visit(const UInt8Type& type);
+ virtual Status Visit(const UInt16Type& type);
+ virtual Status Visit(const UInt32Type& type);
+ virtual Status Visit(const UInt64Type& type);
+ virtual Status Visit(const HalfFloatType& type);
+ virtual Status Visit(const FloatType& type);
+ virtual Status Visit(const DoubleType& type);
+ virtual Status Visit(const StringType& type);
+ virtual Status Visit(const BinaryType& type);
+ virtual Status Visit(const LargeStringType& type);
+ virtual Status Visit(const LargeBinaryType& type);
+ virtual Status Visit(const FixedSizeBinaryType& type);
+ virtual Status Visit(const Date64Type& type);
+ virtual Status Visit(const Date32Type& type);
+ virtual Status Visit(const Time32Type& type);
+ virtual Status Visit(const Time64Type& type);
+ virtual Status Visit(const TimestampType& type);
+ virtual Status Visit(const MonthDayNanoIntervalType& type);
+ virtual Status Visit(const MonthIntervalType& type);
+ virtual Status Visit(const DayTimeIntervalType& type);
+ virtual Status Visit(const DurationType& type);
+ virtual Status Visit(const Decimal128Type& type);
+ virtual Status Visit(const Decimal256Type& type);
+ virtual Status Visit(const ListType& type);
+ virtual Status Visit(const LargeListType& type);
+ virtual Status Visit(const MapType& type);
+ virtual Status Visit(const FixedSizeListType& type);
+ virtual Status Visit(const StructType& type);
+ virtual Status Visit(const SparseUnionType& type);
+ virtual Status Visit(const DenseUnionType& type);
+ virtual Status Visit(const DictionaryType& type);
+ virtual Status Visit(const ExtensionType& type);
+};
+
+class ARROW_EXPORT ScalarVisitor {
+ public:
+ virtual ~ScalarVisitor() = default;
+
+ virtual Status Visit(const NullScalar& scalar);
+ virtual Status Visit(const BooleanScalar& scalar);
+ virtual Status Visit(const Int8Scalar& scalar);
+ virtual Status Visit(const Int16Scalar& scalar);
+ virtual Status Visit(const Int32Scalar& scalar);
+ virtual Status Visit(const Int64Scalar& scalar);
+ virtual Status Visit(const UInt8Scalar& scalar);
+ virtual Status Visit(const UInt16Scalar& scalar);
+ virtual Status Visit(const UInt32Scalar& scalar);
+ virtual Status Visit(const UInt64Scalar& scalar);
+ virtual Status Visit(const HalfFloatScalar& scalar);
+ virtual Status Visit(const FloatScalar& scalar);
+ virtual Status Visit(const DoubleScalar& scalar);
+ virtual Status Visit(const StringScalar& scalar);
+ virtual Status Visit(const BinaryScalar& scalar);
+ virtual Status Visit(const LargeStringScalar& scalar);
+ virtual Status Visit(const LargeBinaryScalar& scalar);
+ virtual Status Visit(const FixedSizeBinaryScalar& scalar);
+ virtual Status Visit(const Date64Scalar& scalar);
+ virtual Status Visit(const Date32Scalar& scalar);
+ virtual Status Visit(const Time32Scalar& scalar);
+ virtual Status Visit(const Time64Scalar& scalar);
+ virtual Status Visit(const TimestampScalar& scalar);
+ virtual Status Visit(const DayTimeIntervalScalar& scalar);
+ virtual Status Visit(const MonthDayNanoIntervalScalar& type);
+ virtual Status Visit(const MonthIntervalScalar& scalar);
+ virtual Status Visit(const DurationScalar& scalar);
+ virtual Status Visit(const Decimal128Scalar& scalar);
+ virtual Status Visit(const Decimal256Scalar& scalar);
+ virtual Status Visit(const ListScalar& scalar);
+ virtual Status Visit(const LargeListScalar& scalar);
+ virtual Status Visit(const MapScalar& scalar);
+ virtual Status Visit(const FixedSizeListScalar& scalar);
+ virtual Status Visit(const StructScalar& scalar);
+ virtual Status Visit(const DictionaryScalar& scalar);
+};
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/arrow/visitor_inline.h b/src/arrow/cpp/src/arrow/visitor_inline.h
new file mode 100644
index 000000000..3321605ae
--- /dev/null
+++ b/src/arrow/cpp/src/arrow/visitor_inline.h
@@ -0,0 +1,450 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Private header, not to be exported
+
+#pragma once
+
+#include <utility>
+
+#include "arrow/array.h"
+#include "arrow/extension_type.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/functional.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+
+#define ARROW_GENERATE_FOR_ALL_INTEGER_TYPES(ACTION) \
+ ACTION(Int8); \
+ ACTION(UInt8); \
+ ACTION(Int16); \
+ ACTION(UInt16); \
+ ACTION(Int32); \
+ ACTION(UInt32); \
+ ACTION(Int64); \
+ ACTION(UInt64)
+
+#define ARROW_GENERATE_FOR_ALL_NUMERIC_TYPES(ACTION) \
+ ARROW_GENERATE_FOR_ALL_INTEGER_TYPES(ACTION); \
+ ACTION(HalfFloat); \
+ ACTION(Float); \
+ ACTION(Double)
+
+#define ARROW_GENERATE_FOR_ALL_TYPES(ACTION) \
+ ACTION(Null); \
+ ACTION(Boolean); \
+ ARROW_GENERATE_FOR_ALL_NUMERIC_TYPES(ACTION); \
+ ACTION(String); \
+ ACTION(Binary); \
+ ACTION(LargeString); \
+ ACTION(LargeBinary); \
+ ACTION(FixedSizeBinary); \
+ ACTION(Duration); \
+ ACTION(Date32); \
+ ACTION(Date64); \
+ ACTION(Timestamp); \
+ ACTION(Time32); \
+ ACTION(Time64); \
+ ACTION(MonthDayNanoInterval); \
+ ACTION(MonthInterval); \
+ ACTION(DayTimeInterval); \
+ ACTION(Decimal128); \
+ ACTION(Decimal256); \
+ ACTION(List); \
+ ACTION(LargeList); \
+ ACTION(Map); \
+ ACTION(FixedSizeList); \
+ ACTION(Struct); \
+ ACTION(SparseUnion); \
+ ACTION(DenseUnion); \
+ ACTION(Dictionary); \
+ ACTION(Extension)
+
+#define TYPE_VISIT_INLINE(TYPE_CLASS) \
+ case TYPE_CLASS##Type::type_id: \
+ return visitor->Visit(internal::checked_cast<const TYPE_CLASS##Type&>(type));
+
+template <typename VISITOR>
+inline Status VisitTypeInline(const DataType& type, VISITOR* visitor) {
+ switch (type.id()) {
+ ARROW_GENERATE_FOR_ALL_TYPES(TYPE_VISIT_INLINE);
+ default:
+ break;
+ }
+ return Status::NotImplemented("Type not implemented");
+}
+
+#undef TYPE_VISIT_INLINE
+
+#define TYPE_ID_VISIT_INLINE(TYPE_CLASS) \
+ case TYPE_CLASS##Type::type_id: { \
+ const TYPE_CLASS##Type* concrete_ptr = nullptr; \
+ return visitor->Visit(concrete_ptr); \
+ }
+
+// Calls `visitor` with a nullptr of the corresponding concrete type class
+template <typename VISITOR>
+inline Status VisitTypeIdInline(Type::type id, VISITOR* visitor) {
+ switch (id) {
+ ARROW_GENERATE_FOR_ALL_TYPES(TYPE_ID_VISIT_INLINE);
+ default:
+ break;
+ }
+ return Status::NotImplemented("Type not implemented");
+}
+
+#undef TYPE_ID_VISIT_INLINE
+
+#define ARRAY_VISIT_INLINE(TYPE_CLASS) \
+ case TYPE_CLASS##Type::type_id: \
+ return visitor->Visit( \
+ internal::checked_cast<const typename TypeTraits<TYPE_CLASS##Type>::ArrayType&>( \
+ array));
+
+template <typename VISITOR>
+inline Status VisitArrayInline(const Array& array, VISITOR* visitor) {
+ switch (array.type_id()) {
+ ARROW_GENERATE_FOR_ALL_TYPES(ARRAY_VISIT_INLINE);
+ default:
+ break;
+ }
+ return Status::NotImplemented("Type not implemented");
+}
+
+namespace internal {
+
+template <typename T, typename Enable = void>
+struct ArrayDataInlineVisitor {};
+
+// Numeric and primitive C-compatible types
+template <typename T>
+struct ArrayDataInlineVisitor<T, enable_if_has_c_type<T>> {
+ using c_type = typename T::c_type;
+
+ template <typename ValidFunc, typename NullFunc>
+ static Status VisitStatus(const ArrayData& arr, ValidFunc&& valid_func,
+ NullFunc&& null_func) {
+ const c_type* data = arr.GetValues<c_type>(1);
+ auto visit_valid = [&](int64_t i) { return valid_func(data[i]); };
+ return VisitBitBlocks(arr.buffers[0], arr.offset, arr.length, std::move(visit_valid),
+ std::forward<NullFunc>(null_func));
+ }
+
+ template <typename ValidFunc, typename NullFunc>
+ static void VisitVoid(const ArrayData& arr, ValidFunc&& valid_func,
+ NullFunc&& null_func) {
+ using c_type = typename T::c_type;
+ const c_type* data = arr.GetValues<c_type>(1);
+ auto visit_valid = [&](int64_t i) { valid_func(data[i]); };
+ VisitBitBlocksVoid(arr.buffers[0], arr.offset, arr.length, std::move(visit_valid),
+ std::forward<NullFunc>(null_func));
+ }
+};
+
+// Boolean
+template <>
+struct ArrayDataInlineVisitor<BooleanType> {
+ using c_type = bool;
+
+ template <typename ValidFunc, typename NullFunc>
+ static Status VisitStatus(const ArrayData& arr, ValidFunc&& valid_func,
+ NullFunc&& null_func) {
+ int64_t offset = arr.offset;
+ const uint8_t* data = arr.buffers[1]->data();
+ return VisitBitBlocks(
+ arr.buffers[0], offset, arr.length,
+ [&](int64_t i) { return valid_func(BitUtil::GetBit(data, offset + i)); },
+ std::forward<NullFunc>(null_func));
+ }
+
+ template <typename ValidFunc, typename NullFunc>
+ static void VisitVoid(const ArrayData& arr, ValidFunc&& valid_func,
+ NullFunc&& null_func) {
+ int64_t offset = arr.offset;
+ const uint8_t* data = arr.buffers[1]->data();
+ VisitBitBlocksVoid(
+ arr.buffers[0], offset, arr.length,
+ [&](int64_t i) { valid_func(BitUtil::GetBit(data, offset + i)); },
+ std::forward<NullFunc>(null_func));
+ }
+};
+
+// Binary, String...
+template <typename T>
+struct ArrayDataInlineVisitor<T, enable_if_base_binary<T>> {
+ using c_type = util::string_view;
+
+ template <typename ValidFunc, typename NullFunc>
+ static Status VisitStatus(const ArrayData& arr, ValidFunc&& valid_func,
+ NullFunc&& null_func) {
+ using offset_type = typename T::offset_type;
+ constexpr char empty_value = 0;
+
+ if (arr.length == 0) {
+ return Status::OK();
+ }
+ const offset_type* offsets = arr.GetValues<offset_type>(1);
+ const char* data;
+ if (!arr.buffers[2]) {
+ data = &empty_value;
+ } else {
+ // Do not apply the array offset to the values array; the value_offsets
+ // index the non-sliced values array.
+ data = arr.GetValues<char>(2, /*absolute_offset=*/0);
+ }
+ offset_type cur_offset = *offsets++;
+ return VisitBitBlocks(
+ arr.buffers[0], arr.offset, arr.length,
+ [&](int64_t i) {
+ ARROW_UNUSED(i);
+ auto value = util::string_view(data + cur_offset, *offsets - cur_offset);
+ cur_offset = *offsets++;
+ return valid_func(value);
+ },
+ [&]() {
+ cur_offset = *offsets++;
+ return null_func();
+ });
+ }
+
+ template <typename ValidFunc, typename NullFunc>
+ static void VisitVoid(const ArrayData& arr, ValidFunc&& valid_func,
+ NullFunc&& null_func) {
+ using offset_type = typename T::offset_type;
+ constexpr uint8_t empty_value = 0;
+
+ if (arr.length == 0) {
+ return;
+ }
+ const offset_type* offsets = arr.GetValues<offset_type>(1);
+ const uint8_t* data;
+ if (!arr.buffers[2]) {
+ data = &empty_value;
+ } else {
+ // Do not apply the array offset to the values array; the value_offsets
+ // index the non-sliced values array.
+ data = arr.GetValues<uint8_t>(2, /*absolute_offset=*/0);
+ }
+
+ VisitBitBlocksVoid(
+ arr.buffers[0], arr.offset, arr.length,
+ [&](int64_t i) {
+ auto value = util::string_view(reinterpret_cast<const char*>(data + offsets[i]),
+ offsets[i + 1] - offsets[i]);
+ valid_func(value);
+ },
+ std::forward<NullFunc>(null_func));
+ }
+};
+
+// FixedSizeBinary, Decimal128
+template <typename T>
+struct ArrayDataInlineVisitor<T, enable_if_fixed_size_binary<T>> {
+ using c_type = util::string_view;
+
+ template <typename ValidFunc, typename NullFunc>
+ static Status VisitStatus(const ArrayData& arr, ValidFunc&& valid_func,
+ NullFunc&& null_func) {
+ const auto& fw_type = internal::checked_cast<const FixedSizeBinaryType&>(*arr.type);
+
+ const int32_t byte_width = fw_type.byte_width();
+ const char* data = arr.GetValues<char>(1,
+ /*absolute_offset=*/arr.offset * byte_width);
+
+ return VisitBitBlocks(
+ arr.buffers[0], arr.offset, arr.length,
+ [&](int64_t i) {
+ auto value = util::string_view(data, byte_width);
+ data += byte_width;
+ return valid_func(value);
+ },
+ [&]() {
+ data += byte_width;
+ return null_func();
+ });
+ }
+
+ template <typename ValidFunc, typename NullFunc>
+ static void VisitVoid(const ArrayData& arr, ValidFunc&& valid_func,
+ NullFunc&& null_func) {
+ const auto& fw_type = internal::checked_cast<const FixedSizeBinaryType&>(*arr.type);
+
+ const int32_t byte_width = fw_type.byte_width();
+ const char* data = arr.GetValues<char>(1,
+ /*absolute_offset=*/arr.offset * byte_width);
+
+ VisitBitBlocksVoid(
+ arr.buffers[0], arr.offset, arr.length,
+ [&](int64_t i) {
+ valid_func(util::string_view(data, byte_width));
+ data += byte_width;
+ },
+ [&]() {
+ data += byte_width;
+ null_func();
+ });
+ }
+};
+
+} // namespace internal
+
+// Visit an array's data values, in order, without overhead.
+//
+// The given `ValidFunc` should be a callable with either of these signatures:
+// - void(scalar_type)
+// - Status(scalar_type)
+//
+// The `NullFunc` should have the same return type as `ValidFunc`.
+//
+// ... where `scalar_type` depends on the array data type:
+// - the type's `c_type`, if any
+// - for boolean arrays, a `bool`
+// - for binary, string and fixed-size binary arrays, a `util::string_view`
+
+template <typename T, typename ValidFunc, typename NullFunc>
+typename internal::call_traits::enable_if_return<ValidFunc, Status>::type
+VisitArrayDataInline(const ArrayData& arr, ValidFunc&& valid_func, NullFunc&& null_func) {
+ return internal::ArrayDataInlineVisitor<T>::VisitStatus(
+ arr, std::forward<ValidFunc>(valid_func), std::forward<NullFunc>(null_func));
+}
+
+template <typename T, typename ValidFunc, typename NullFunc>
+typename internal::call_traits::enable_if_return<ValidFunc, void>::type
+VisitArrayDataInline(const ArrayData& arr, ValidFunc&& valid_func, NullFunc&& null_func) {
+ return internal::ArrayDataInlineVisitor<T>::VisitVoid(
+ arr, std::forward<ValidFunc>(valid_func), std::forward<NullFunc>(null_func));
+}
+
+// Visit an array's data values, in order, without overhead.
+//
+// The Visit method's `visitor` argument should be an object with two public methods:
+// - Status VisitNull()
+// - Status VisitValue(<scalar>)
+//
+// The scalar value's type depends on the array data type:
+// - the type's `c_type`, if any
+// - for boolean arrays, a `bool`
+// - for binary, string and fixed-size binary arrays, a `util::string_view`
+
+template <typename T>
+struct ArrayDataVisitor {
+ using InlineVisitorType = internal::ArrayDataInlineVisitor<T>;
+ using c_type = typename InlineVisitorType::c_type;
+
+ template <typename Visitor>
+ static Status Visit(const ArrayData& arr, Visitor* visitor) {
+ return InlineVisitorType::VisitStatus(
+ arr, [visitor](c_type v) { return visitor->VisitValue(v); },
+ [visitor]() { return visitor->VisitNull(); });
+ }
+};
+
+#define SCALAR_VISIT_INLINE(TYPE_CLASS) \
+ case TYPE_CLASS##Type::type_id: \
+ return visitor->Visit(internal::checked_cast<const TYPE_CLASS##Scalar&>(scalar));
+
+template <typename VISITOR>
+inline Status VisitScalarInline(const Scalar& scalar, VISITOR* visitor) {
+ switch (scalar.type->id()) {
+ ARROW_GENERATE_FOR_ALL_TYPES(SCALAR_VISIT_INLINE);
+ default:
+ break;
+ }
+ return Status::NotImplemented("Scalar visitor for type not implemented ",
+ scalar.type->ToString());
+}
+
+#undef TYPE_VISIT_INLINE
+
+// Visit a null bitmap, in order, without overhead.
+//
+// The given `ValidFunc` should be a callable with either of these signatures:
+// - void()
+// - Status()
+//
+// The `NullFunc` should have the same return type as `ValidFunc`.
+
+template <typename ValidFunc, typename NullFunc>
+typename internal::call_traits::enable_if_return<ValidFunc, Status>::type
+VisitNullBitmapInline(const uint8_t* valid_bits, int64_t valid_bits_offset,
+ int64_t num_values, int64_t null_count, ValidFunc&& valid_func,
+ NullFunc&& null_func) {
+ ARROW_UNUSED(null_count);
+ internal::OptionalBitBlockCounter bit_counter(valid_bits, valid_bits_offset,
+ num_values);
+ int64_t position = 0;
+ int64_t offset_position = valid_bits_offset;
+ while (position < num_values) {
+ internal::BitBlockCount block = bit_counter.NextBlock();
+ if (block.AllSet()) {
+ for (int64_t i = 0; i < block.length; ++i) {
+ ARROW_RETURN_NOT_OK(valid_func());
+ }
+ } else if (block.NoneSet()) {
+ for (int64_t i = 0; i < block.length; ++i) {
+ ARROW_RETURN_NOT_OK(null_func());
+ }
+ } else {
+ for (int64_t i = 0; i < block.length; ++i) {
+ ARROW_RETURN_NOT_OK(BitUtil::GetBit(valid_bits, offset_position + i)
+ ? valid_func()
+ : null_func());
+ }
+ }
+ position += block.length;
+ offset_position += block.length;
+ }
+ return Status::OK();
+}
+
+template <typename ValidFunc, typename NullFunc>
+typename internal::call_traits::enable_if_return<ValidFunc, void>::type
+VisitNullBitmapInline(const uint8_t* valid_bits, int64_t valid_bits_offset,
+ int64_t num_values, int64_t null_count, ValidFunc&& valid_func,
+ NullFunc&& null_func) {
+ ARROW_UNUSED(null_count);
+ internal::OptionalBitBlockCounter bit_counter(valid_bits, valid_bits_offset,
+ num_values);
+ int64_t position = 0;
+ int64_t offset_position = valid_bits_offset;
+ while (position < num_values) {
+ internal::BitBlockCount block = bit_counter.NextBlock();
+ if (block.AllSet()) {
+ for (int64_t i = 0; i < block.length; ++i) {
+ valid_func();
+ }
+ } else if (block.NoneSet()) {
+ for (int64_t i = 0; i < block.length; ++i) {
+ null_func();
+ }
+ } else {
+ for (int64_t i = 0; i < block.length; ++i) {
+ BitUtil::GetBit(valid_bits, offset_position + i) ? valid_func() : null_func();
+ }
+ }
+ position += block.length;
+ offset_position += block.length;
+ }
+}
+
+} // namespace arrow
diff --git a/src/arrow/cpp/src/gandiva/CMakeLists.txt b/src/arrow/cpp/src/gandiva/CMakeLists.txt
new file mode 100644
index 000000000..654a4a40b
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/CMakeLists.txt
@@ -0,0 +1,253 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set(GANDIVA_VERSION "${ARROW_VERSION}")
+
+# For "make gandiva" to build everything Gandiva-related
+add_custom_target(gandiva-all)
+add_custom_target(gandiva)
+add_custom_target(gandiva-tests)
+add_custom_target(gandiva-benchmarks)
+
+add_dependencies(gandiva-all gandiva gandiva-tests gandiva-benchmarks)
+
+find_package(LLVMAlt REQUIRED)
+
+if(LLVM_VERSION_MAJOR LESS "10")
+ set(GANDIVA_CXX_STANDARD ${CMAKE_CXX_STANDARD})
+else()
+ # LLVM 10 or later requires C++ 14
+ if(CMAKE_CXX_STANDARD LESS 14)
+ set(GANDIVA_CXX_STANDARD 14)
+ else()
+ set(GANDIVA_CXX_STANDARD ${CMAKE_CXX_STANDARD})
+ endif()
+endif()
+
+add_definitions(-DGANDIVA_LLVM_VERSION=${LLVM_VERSION_MAJOR})
+
+find_package(OpenSSLAlt REQUIRED)
+
+# Set the path where the bitcode file generated, see precompiled/CMakeLists.txt
+set(GANDIVA_PRECOMPILED_BC_PATH "${CMAKE_CURRENT_BINARY_DIR}/irhelpers.bc")
+set(GANDIVA_PRECOMPILED_CC_PATH "${CMAKE_CURRENT_BINARY_DIR}/precompiled_bitcode.cc")
+set(GANDIVA_PRECOMPILED_CC_IN_PATH
+ "${CMAKE_CURRENT_SOURCE_DIR}/precompiled_bitcode.cc.in")
+
+# add_arrow_lib will look for this not yet existing file, so flag as generated
+set_source_files_properties(${GANDIVA_PRECOMPILED_CC_PATH} PROPERTIES GENERATED TRUE)
+
+set(SRC_FILES
+ annotator.cc
+ bitmap_accumulator.cc
+ cache.cc
+ cast_time.cc
+ configuration.cc
+ context_helper.cc
+ decimal_ir.cc
+ decimal_type_util.cc
+ decimal_xlarge.cc
+ engine.cc
+ date_utils.cc
+ expr_decomposer.cc
+ expr_validator.cc
+ expression.cc
+ expression_registry.cc
+ exported_funcs_registry.cc
+ filter.cc
+ function_ir_builder.cc
+ function_registry.cc
+ function_registry_arithmetic.cc
+ function_registry_datetime.cc
+ function_registry_hash.cc
+ function_registry_math_ops.cc
+ function_registry_string.cc
+ function_registry_timestamp_arithmetic.cc
+ function_signature.cc
+ gdv_function_stubs.cc
+ hash_utils.cc
+ llvm_generator.cc
+ llvm_types.cc
+ like_holder.cc
+ literal_holder.cc
+ projector.cc
+ regex_util.cc
+ replace_holder.cc
+ selection_vector.cc
+ tree_expr_builder.cc
+ to_date_holder.cc
+ random_generator_holder.cc
+ ${GANDIVA_PRECOMPILED_CC_PATH})
+
+set(GANDIVA_SHARED_PRIVATE_LINK_LIBS arrow_shared LLVM::LLVM_INTERFACE
+ ${GANDIVA_OPENSSL_LIBS})
+
+set(GANDIVA_STATIC_LINK_LIBS arrow_static LLVM::LLVM_INTERFACE ${GANDIVA_OPENSSL_LIBS})
+
+if(ARROW_GANDIVA_STATIC_LIBSTDCPP AND (CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX
+ ))
+ set(GANDIVA_STATIC_LINK_LIBS ${GANDIVA_STATIC_LINK_LIBS} -static-libstdc++
+ -static-libgcc)
+endif()
+
+# if (MSVC)
+# # Symbols that need to be made public in gandiva.dll for LLVM IR
+# # compilation
+# set(MSVC_SYMBOL_EXPORTS _Init_thread_header)
+# foreach(SYMBOL ${MSVC_SYMBOL_EXPORTS})
+# set(GANDIVA_SHARED_LINK_FLAGS "${GANDIVA_SHARED_LINK_FLAGS} /EXPORT:${SYMBOL}")
+# endforeach()
+# endif()
+if(CXX_LINKER_SUPPORTS_VERSION_SCRIPT)
+ set(GANDIVA_VERSION_SCRIPT_FLAGS
+ "-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/symbols.map")
+ set(GANDIVA_SHARED_LINK_FLAGS
+ "${GANDIVA_SHARED_LINK_FLAGS} ${GANDIVA_VERSION_SCRIPT_FLAGS}")
+endif()
+
+add_arrow_lib(gandiva
+ CMAKE_PACKAGE_NAME
+ Gandiva
+ PKG_CONFIG_NAME
+ gandiva
+ SOURCES
+ ${SRC_FILES}
+ PRECOMPILED_HEADERS
+ "$<$<COMPILE_LANGUAGE:CXX>:gandiva/pch.h>"
+ OUTPUTS
+ GANDIVA_LIBRARIES
+ DEPENDENCIES
+ arrow_dependencies
+ precompiled
+ EXTRA_INCLUDES
+ $<TARGET_PROPERTY:LLVM::LLVM_INTERFACE,INTERFACE_INCLUDE_DIRECTORIES>
+ ${GANDIVA_OPENSSL_INCLUDE_DIR}
+ ${UTF8PROC_INCLUDE_DIR}
+ SHARED_LINK_FLAGS
+ ${GANDIVA_SHARED_LINK_FLAGS}
+ SHARED_LINK_LIBS
+ arrow_shared
+ SHARED_PRIVATE_LINK_LIBS
+ ${GANDIVA_SHARED_PRIVATE_LINK_LIBS}
+ STATIC_LINK_LIBS
+ ${GANDIVA_STATIC_LINK_LIBS})
+
+foreach(LIB_TARGET ${GANDIVA_LIBRARIES})
+ target_compile_definitions(${LIB_TARGET} PRIVATE GANDIVA_EXPORTING)
+ set_target_properties(${LIB_TARGET} PROPERTIES CXX_STANDARD ${GANDIVA_CXX_STANDARD})
+endforeach()
+
+if(ARROW_BUILD_STATIC AND WIN32)
+ target_compile_definitions(gandiva_static PUBLIC GANDIVA_STATIC)
+endif()
+
+add_dependencies(gandiva ${GANDIVA_LIBRARIES})
+
+arrow_install_all_headers("gandiva")
+
+set(GANDIVA_STATIC_TEST_LINK_LIBS gandiva_static ${ARROW_TEST_LINK_LIBS})
+
+set(GANDIVA_SHARED_TEST_LINK_LIBS gandiva_shared ${ARROW_TEST_LINK_LIBS})
+
+function(ADD_GANDIVA_TEST REL_TEST_NAME)
+ set(options USE_STATIC_LINKING)
+ set(one_value_args)
+ set(multi_value_args)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+
+ if(NO_TESTS)
+ return()
+ endif()
+
+ set(TEST_ARGUMENTS
+ ENABLED
+ PREFIX
+ "gandiva"
+ LABELS
+ "gandiva-tests"
+ ${ARG_UNPARSED_ARGUMENTS})
+
+ # and uses less disk space, but in some cases we need to force static
+ # linking (see rationale below).
+ if(ARG_USE_STATIC_LINKING OR ARROW_TEST_LINKAGE STREQUAL "static")
+ add_test_case(${REL_TEST_NAME}
+ ${TEST_ARGUMENTS}
+ STATIC_LINK_LIBS
+ ${GANDIVA_STATIC_TEST_LINK_LIBS}
+ ${ARG_UNPARSED_ARGUMENTS})
+ else()
+ add_test_case(${REL_TEST_NAME}
+ ${TEST_ARGUMENTS}
+ STATIC_LINK_LIBS
+ ${GANDIVA_SHARED_TEST_LINK_LIBS}
+ ${ARG_UNPARSED_ARGUMENTS})
+ endif()
+
+ set(TEST_NAME gandiva-${REL_TEST_NAME})
+ string(REPLACE "_" "-" TEST_NAME ${TEST_NAME})
+ set_target_properties(${TEST_NAME} PROPERTIES CXX_STANDARD ${GANDIVA_CXX_STANDARD})
+endfunction()
+
+set(GANDIVA_INTERNALS_TEST_ARGUMENTS)
+if(WIN32)
+ list(APPEND
+ GANDIVA_INTERNALS_TEST_ARGUMENTS
+ EXTRA_LINK_LIBS
+ LLVM::LLVM_INTERFACE
+ ${GANDIVA_OPENSSL_LIBS})
+endif()
+add_gandiva_test(internals-test
+ SOURCES
+ bitmap_accumulator_test.cc
+ engine_llvm_test.cc
+ function_registry_test.cc
+ function_signature_test.cc
+ llvm_types_test.cc
+ llvm_generator_test.cc
+ annotator_test.cc
+ tree_expr_test.cc
+ expr_decomposer_test.cc
+ expression_registry_test.cc
+ selection_vector_test.cc
+ greedy_dual_size_cache_test.cc
+ to_date_holder_test.cc
+ simple_arena_test.cc
+ like_holder_test.cc
+ replace_holder_test.cc
+ decimal_type_util_test.cc
+ random_generator_holder_test.cc
+ hash_utils_test.cc
+ gdv_function_stubs_test.cc
+ EXTRA_DEPENDENCIES
+ LLVM::LLVM_INTERFACE
+ ${GANDIVA_OPENSSL_LIBS}
+ EXTRA_INCLUDES
+ $<TARGET_PROPERTY:LLVM::LLVM_INTERFACE,INTERFACE_INCLUDE_DIRECTORIES>
+ ${GANDIVA_INTERNALS_TEST_ARGUMENTS}
+ ${GANDIVA_OPENSSL_INCLUDE_DIR}
+ ${UTF8PROC_INCLUDE_DIR})
+
+if(ARROW_GANDIVA_JAVA)
+ add_subdirectory(jni)
+endif()
+
+add_subdirectory(precompiled)
+add_subdirectory(tests)
diff --git a/src/arrow/cpp/src/gandiva/GandivaConfig.cmake.in b/src/arrow/cpp/src/gandiva/GandivaConfig.cmake.in
new file mode 100644
index 000000000..09bc33901
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/GandivaConfig.cmake.in
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This config sets the following variables in your project::
+#
+# Gandiva_FOUND - true if Gandiva found on the system
+#
+# This config sets the following targets in your project::
+#
+# gandiva_shared - for linked as shared library if shared library is built
+# gandiva_static - for linked as static library if static library is built
+
+@PACKAGE_INIT@
+
+include(CMakeFindDependencyMacro)
+find_dependency(Arrow)
+
+# Load targets only once. If we load targets multiple times, CMake reports
+# already existent target error.
+if(NOT (TARGET gandiva_shared OR TARGET gandiva_static))
+ include("${CMAKE_CURRENT_LIST_DIR}/GandivaTargets.cmake")
+endif()
diff --git a/src/arrow/cpp/src/gandiva/annotator.cc b/src/arrow/cpp/src/gandiva/annotator.cc
new file mode 100644
index 000000000..f6acaff18
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/annotator.cc
@@ -0,0 +1,118 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/annotator.h"
+
+#include <memory>
+#include <string>
+
+#include "gandiva/field_descriptor.h"
+
+namespace gandiva {
+
+FieldDescriptorPtr Annotator::CheckAndAddInputFieldDescriptor(FieldPtr field) {
+ // If the field is already in the map, return the entry.
+ auto found = in_name_to_desc_.find(field->name());
+ if (found != in_name_to_desc_.end()) {
+ return found->second;
+ }
+
+ auto desc = MakeDesc(field, false /*is_output*/);
+ in_name_to_desc_[field->name()] = desc;
+ return desc;
+}
+
+FieldDescriptorPtr Annotator::AddOutputFieldDescriptor(FieldPtr field) {
+ auto desc = MakeDesc(field, true /*is_output*/);
+ out_descs_.push_back(desc);
+ return desc;
+}
+
+FieldDescriptorPtr Annotator::MakeDesc(FieldPtr field, bool is_output) {
+ int data_idx = buffer_count_++;
+ int validity_idx = buffer_count_++;
+ int offsets_idx = FieldDescriptor::kInvalidIdx;
+ if (arrow::is_binary_like(field->type()->id())) {
+ offsets_idx = buffer_count_++;
+ }
+ int data_buffer_ptr_idx = FieldDescriptor::kInvalidIdx;
+ if (is_output) {
+ data_buffer_ptr_idx = buffer_count_++;
+ }
+ return std::make_shared<FieldDescriptor>(field, data_idx, validity_idx, offsets_idx,
+ data_buffer_ptr_idx);
+}
+
+void Annotator::PrepareBuffersForField(const FieldDescriptor& desc,
+ const arrow::ArrayData& array_data,
+ EvalBatch* eval_batch, bool is_output) {
+ int buffer_idx = 0;
+
+ // The validity buffer is optional. Use nullptr if it does not have one.
+ if (array_data.buffers[buffer_idx]) {
+ uint8_t* validity_buf = const_cast<uint8_t*>(array_data.buffers[buffer_idx]->data());
+ eval_batch->SetBuffer(desc.validity_idx(), validity_buf, array_data.offset);
+ } else {
+ eval_batch->SetBuffer(desc.validity_idx(), nullptr, array_data.offset);
+ }
+ ++buffer_idx;
+
+ if (desc.HasOffsetsIdx()) {
+ uint8_t* offsets_buf = const_cast<uint8_t*>(array_data.buffers[buffer_idx]->data());
+ eval_batch->SetBuffer(desc.offsets_idx(), offsets_buf, array_data.offset);
+ ++buffer_idx;
+ }
+
+ uint8_t* data_buf = const_cast<uint8_t*>(array_data.buffers[buffer_idx]->data());
+ eval_batch->SetBuffer(desc.data_idx(), data_buf, array_data.offset);
+ if (is_output) {
+ // pass in the Buffer object for output data buffers. Can be used for resizing.
+ uint8_t* data_buf_ptr =
+ reinterpret_cast<uint8_t*>(array_data.buffers[buffer_idx].get());
+ eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, array_data.offset);
+ }
+}
+
+EvalBatchPtr Annotator::PrepareEvalBatch(const arrow::RecordBatch& record_batch,
+ const ArrayDataVector& out_vector) {
+ EvalBatchPtr eval_batch = std::make_shared<EvalBatch>(
+ record_batch.num_rows(), buffer_count_, local_bitmap_count_);
+
+ // Fill in the entries for the input fields.
+ for (int i = 0; i < record_batch.num_columns(); ++i) {
+ const std::string& name = record_batch.column_name(i);
+ auto found = in_name_to_desc_.find(name);
+ if (found == in_name_to_desc_.end()) {
+ // skip columns not involved in the expression.
+ continue;
+ }
+
+ PrepareBuffersForField(*(found->second), *(record_batch.column(i))->data(),
+ eval_batch.get(), false /*is_output*/);
+ }
+
+ // Fill in the entries for the output fields.
+ int idx = 0;
+ for (auto& arraydata : out_vector) {
+ const FieldDescriptorPtr& desc = out_descs_.at(idx);
+ PrepareBuffersForField(*desc, *arraydata, eval_batch.get(), true /*is_output*/);
+ ++idx;
+ }
+ return eval_batch;
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/annotator.h b/src/arrow/cpp/src/gandiva/annotator.h
new file mode 100644
index 000000000..5f185d183
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/annotator.h
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <list>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "arrow/util/logging.h"
+#include "gandiva/arrow.h"
+#include "gandiva/eval_batch.h"
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// \brief annotate the arrow fields in an expression, and use that
+/// to convert the incoming arrow-format row batch to an EvalBatch.
+class GANDIVA_EXPORT Annotator {
+ public:
+ Annotator() : buffer_count_(0), local_bitmap_count_(0) {}
+
+ /// Add an annotated field descriptor for a field in an input schema.
+ /// If the field is already annotated, returns that instead.
+ FieldDescriptorPtr CheckAndAddInputFieldDescriptor(FieldPtr field);
+
+ /// Add an annotated field descriptor for an output field.
+ FieldDescriptorPtr AddOutputFieldDescriptor(FieldPtr field);
+
+ /// Add a local bitmap (for saving validity bits of an intermediate node).
+ /// Returns the index of the bitmap in the list of local bitmaps.
+ int AddLocalBitMap() { return local_bitmap_count_++; }
+
+ /// Prepare an eval batch for the incoming record batch.
+ EvalBatchPtr PrepareEvalBatch(const arrow::RecordBatch& record_batch,
+ const ArrayDataVector& out_vector);
+
+ int buffer_count() { return buffer_count_; }
+
+ private:
+ /// Annotate a field and return the descriptor.
+ FieldDescriptorPtr MakeDesc(FieldPtr field, bool is_output);
+
+ /// Populate eval_batch by extracting the raw buffers from the arrow array, whose
+ /// contents are represent by the annotated descriptor 'desc'.
+ void PrepareBuffersForField(const FieldDescriptor& desc,
+ const arrow::ArrayData& array_data, EvalBatch* eval_batch,
+ bool is_output);
+
+ /// The list of input/output buffers (includes bitmap buffers, value buffers and
+ /// offset buffers).
+ int buffer_count_;
+
+ /// The number of local bitmaps. These are used to save the validity bits for
+ /// intermediate nodes in the expression tree.
+ int local_bitmap_count_;
+
+ /// map between field name and annotated input field descriptor.
+ std::unordered_map<std::string, FieldDescriptorPtr> in_name_to_desc_;
+
+ /// vector of annotated output field descriptors.
+ std::vector<FieldDescriptorPtr> out_descs_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/annotator_test.cc b/src/arrow/cpp/src/gandiva/annotator_test.cc
new file mode 100644
index 000000000..e537943d9
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/annotator_test.cc
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/annotator.h"
+
+#include <memory>
+#include <utility>
+
+#include <arrow/memory_pool.h>
+#include <gtest/gtest.h>
+#include "gandiva/field_descriptor.h"
+
+namespace gandiva {
+
+class TestAnnotator : public ::testing::Test {
+ protected:
+ ArrayPtr MakeInt32Array(int length);
+};
+
+ArrayPtr TestAnnotator::MakeInt32Array(int length) {
+ arrow::Status status;
+
+ auto validity = *arrow::AllocateBuffer((length + 63) / 8);
+
+ auto values = *arrow::AllocateBuffer(length * sizeof(int32_t));
+
+ auto array_data = arrow::ArrayData::Make(arrow::int32(), length,
+ {std::move(validity), std::move(values)});
+ return arrow::MakeArray(array_data);
+}
+
+TEST_F(TestAnnotator, TestAdd) {
+ Annotator annotator;
+
+ auto field_a = arrow::field("a", arrow::int32());
+ auto field_b = arrow::field("b", arrow::int32());
+ auto in_schema = arrow::schema({field_a, field_b});
+ auto field_sum = arrow::field("sum", arrow::int32());
+
+ FieldDescriptorPtr desc_a = annotator.CheckAndAddInputFieldDescriptor(field_a);
+ EXPECT_EQ(desc_a->field(), field_a);
+ EXPECT_EQ(desc_a->data_idx(), 0);
+ EXPECT_EQ(desc_a->validity_idx(), 1);
+
+ // duplicate add shouldn't cause a new descriptor.
+ FieldDescriptorPtr dup = annotator.CheckAndAddInputFieldDescriptor(field_a);
+ EXPECT_EQ(dup, desc_a);
+ EXPECT_EQ(dup->validity_idx(), desc_a->validity_idx());
+
+ FieldDescriptorPtr desc_b = annotator.CheckAndAddInputFieldDescriptor(field_b);
+ EXPECT_EQ(desc_b->field(), field_b);
+ EXPECT_EQ(desc_b->data_idx(), 2);
+ EXPECT_EQ(desc_b->validity_idx(), 3);
+
+ FieldDescriptorPtr desc_sum = annotator.AddOutputFieldDescriptor(field_sum);
+ EXPECT_EQ(desc_sum->field(), field_sum);
+ EXPECT_EQ(desc_sum->data_idx(), 4);
+ EXPECT_EQ(desc_sum->validity_idx(), 5);
+ EXPECT_EQ(desc_sum->data_buffer_ptr_idx(), 6);
+
+ // prepare record batch
+ int num_records = 100;
+ auto arrow_v0 = MakeInt32Array(num_records);
+ auto arrow_v1 = MakeInt32Array(num_records);
+
+ // prepare input record batch
+ auto record_batch =
+ arrow::RecordBatch::Make(in_schema, num_records, {arrow_v0, arrow_v1});
+
+ auto arrow_sum = MakeInt32Array(num_records);
+ EvalBatchPtr batch = annotator.PrepareEvalBatch(*record_batch, {arrow_sum->data()});
+ EXPECT_EQ(batch->GetNumBuffers(), 7);
+
+ auto buffers = batch->GetBufferArray();
+ EXPECT_EQ(buffers[desc_a->validity_idx()], arrow_v0->data()->buffers.at(0)->data());
+ EXPECT_EQ(buffers[desc_a->data_idx()], arrow_v0->data()->buffers.at(1)->data());
+ EXPECT_EQ(buffers[desc_b->validity_idx()], arrow_v1->data()->buffers.at(0)->data());
+ EXPECT_EQ(buffers[desc_b->data_idx()], arrow_v1->data()->buffers.at(1)->data());
+ EXPECT_EQ(buffers[desc_sum->validity_idx()], arrow_sum->data()->buffers.at(0)->data());
+ EXPECT_EQ(buffers[desc_sum->data_idx()], arrow_sum->data()->buffers.at(1)->data());
+ EXPECT_EQ(buffers[desc_sum->data_buffer_ptr_idx()],
+ reinterpret_cast<uint8_t*>(arrow_sum->data()->buffers.at(1).get()));
+
+ auto bitmaps = batch->GetLocalBitMapArray();
+ EXPECT_EQ(bitmaps, nullptr);
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/arrow.h b/src/arrow/cpp/src/gandiva/arrow.h
new file mode 100644
index 000000000..e6d40cb18
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/arrow.h
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "arrow/array.h" // IWYU pragma: export
+#include "arrow/builder.h" // IWYU pragma: export
+#include "arrow/pretty_print.h" // IWYU pragma: export
+#include "arrow/record_batch.h" // IWYU pragma: export
+#include "arrow/status.h" // IWYU pragma: export
+#include "arrow/type.h" // IWYU pragma: export
+
+namespace gandiva {
+
+using arrow::ArrayDataVector;
+using arrow::DataTypeVector;
+using arrow::FieldVector;
+using arrow::Result;
+using arrow::Status;
+using arrow::StatusCode;
+
+using ArrayPtr = std::shared_ptr<arrow::Array>;
+using ArrayDataPtr = std::shared_ptr<arrow::ArrayData>;
+using DataTypePtr = std::shared_ptr<arrow::DataType>;
+using FieldPtr = std::shared_ptr<arrow::Field>;
+using RecordBatchPtr = std::shared_ptr<arrow::RecordBatch>;
+using SchemaPtr = std::shared_ptr<arrow::Schema>;
+
+using Decimal128TypePtr = std::shared_ptr<arrow::Decimal128Type>;
+using Decimal128TypeVector = std::vector<Decimal128TypePtr>;
+
+static inline bool is_decimal_128(DataTypePtr type) {
+ if (type->id() == arrow::Type::DECIMAL) {
+ auto decimal_type = arrow::internal::checked_cast<arrow::DecimalType*>(type.get());
+ return decimal_type->byte_width() == 16;
+ } else {
+ return false;
+ }
+}
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/basic_decimal_scalar.h b/src/arrow/cpp/src/gandiva/basic_decimal_scalar.h
new file mode 100644
index 000000000..b2f0da506
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/basic_decimal_scalar.h
@@ -0,0 +1,65 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+
+#include "arrow/util/basic_decimal.h"
+#include "arrow/util/decimal.h"
+
+namespace gandiva {
+
+using arrow::BasicDecimal128;
+
+/// Represents a 128-bit decimal value along with its precision and scale.
+class BasicDecimalScalar128 {
+ public:
+ constexpr BasicDecimalScalar128(int64_t high_bits, uint64_t low_bits, int32_t precision,
+ int32_t scale)
+ : value_(high_bits, low_bits), precision_(precision), scale_(scale) {}
+
+ constexpr BasicDecimalScalar128(const BasicDecimal128& value, int32_t precision,
+ int32_t scale)
+ : value_(value), precision_(precision), scale_(scale) {}
+
+ constexpr BasicDecimalScalar128(int32_t precision, int32_t scale)
+ : precision_(precision), scale_(scale) {}
+
+ int32_t scale() const { return scale_; }
+
+ int32_t precision() const { return precision_; }
+
+ const BasicDecimal128& value() const { return value_; }
+
+ private:
+ BasicDecimal128 value_;
+ int32_t precision_;
+ int32_t scale_;
+};
+
+inline bool operator==(const BasicDecimalScalar128& left,
+ const BasicDecimalScalar128& right) {
+ return left.value() == right.value() && left.precision() == right.precision() &&
+ left.scale() == right.scale();
+}
+
+inline BasicDecimalScalar128 operator-(const BasicDecimalScalar128& operand) {
+ return BasicDecimalScalar128{-operand.value(), operand.precision(), operand.scale()};
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/bitmap_accumulator.cc b/src/arrow/cpp/src/gandiva/bitmap_accumulator.cc
new file mode 100644
index 000000000..8fc66b389
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/bitmap_accumulator.cc
@@ -0,0 +1,75 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/bitmap_accumulator.h"
+
+#include <vector>
+
+#include "arrow/util/bitmap_ops.h"
+
+namespace gandiva {
+
+void BitMapAccumulator::ComputeResult(uint8_t* dst_bitmap) {
+ int64_t num_records = eval_batch_.num_records();
+
+ if (all_invalid_) {
+ // set all bits to 0.
+ memset(dst_bitmap, 0, arrow::BitUtil::BytesForBits(num_records));
+ } else {
+ IntersectBitMaps(dst_bitmap, src_maps_, src_map_offsets_, num_records);
+ }
+}
+
+/// Compute the intersection of multiple bitmaps.
+void BitMapAccumulator::IntersectBitMaps(uint8_t* dst_map,
+ const std::vector<uint8_t*>& src_maps,
+ const std::vector<int64_t>& src_map_offsets,
+ int64_t num_records) {
+ int64_t num_words = (num_records + 63) / 64; // aligned to 8-byte.
+ int64_t num_bytes = num_words * 8;
+ int64_t nmaps = src_maps.size();
+
+ switch (nmaps) {
+ case 0: {
+ // no src_maps_ bitmap. simply set all bits
+ memset(dst_map, 0xff, num_bytes);
+ break;
+ }
+
+ case 1: {
+ // one src_maps_ bitmap. copy to dst_map
+ arrow::internal::CopyBitmap(src_maps[0], src_map_offsets[0], num_records, dst_map,
+ 0);
+ break;
+ }
+
+ default: {
+ // src_maps bitmaps ANDs
+ arrow::internal::BitmapAnd(src_maps[0], src_map_offsets[0], src_maps[1],
+ src_map_offsets[1], num_records, /*offset=*/0, dst_map);
+ for (int64_t m = 2; m < nmaps; ++m) {
+ arrow::internal::BitmapAnd(dst_map, 0, src_maps[m], src_map_offsets[m],
+ num_records,
+ /*offset=*/0, dst_map);
+ }
+
+ break;
+ }
+ }
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/bitmap_accumulator.h b/src/arrow/cpp/src/gandiva/bitmap_accumulator.h
new file mode 100644
index 000000000..0b297a98f
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/bitmap_accumulator.h
@@ -0,0 +1,79 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <vector>
+
+#include "arrow/util/macros.h"
+#include "gandiva/dex.h"
+#include "gandiva/dex_visitor.h"
+#include "gandiva/eval_batch.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// \brief Extract bitmap buffer from either the input/buffer vectors or the
+/// local validity bitmap, and accumulates them to do the final computation.
+class GANDIVA_EXPORT BitMapAccumulator : public DexDefaultVisitor {
+ public:
+ explicit BitMapAccumulator(const EvalBatch& eval_batch)
+ : eval_batch_(eval_batch), all_invalid_(false) {}
+
+ void Visit(const VectorReadValidityDex& dex) {
+ int idx = dex.ValidityIdx();
+ auto bitmap = eval_batch_.GetBuffer(idx);
+ // The bitmap could be null. Ignore it in this case.
+ if (bitmap != NULLPTR) {
+ src_maps_.push_back(bitmap);
+ src_map_offsets_.push_back(eval_batch_.GetBufferOffset(idx));
+ }
+ }
+
+ void Visit(const LocalBitMapValidityDex& dex) {
+ int idx = dex.local_bitmap_idx();
+ auto bitmap = eval_batch_.GetLocalBitMap(idx);
+ src_maps_.push_back(bitmap);
+ src_map_offsets_.push_back(0); // local bitmap has offset 0
+ }
+
+ void Visit(const TrueDex& dex) {
+ // bitwise-and with 1 is always 1. so, ignore.
+ }
+
+ void Visit(const FalseDex& dex) {
+ // The final result is "all 0s".
+ all_invalid_ = true;
+ }
+
+ /// Compute the dst_bmap based on the contents and type of the accumulated bitmap dex.
+ void ComputeResult(uint8_t* dst_bitmap);
+
+ /// Compute the intersection of the accumulated bitmaps (with offsets) and save the
+ /// result in dst_bmap.
+ static void IntersectBitMaps(uint8_t* dst_map, const std::vector<uint8_t*>& src_maps,
+ const std::vector<int64_t>& src_maps_offsets,
+ int64_t num_records);
+
+ private:
+ const EvalBatch& eval_batch_;
+ std::vector<uint8_t*> src_maps_;
+ std::vector<int64_t> src_map_offsets_;
+ bool all_invalid_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/bitmap_accumulator_test.cc b/src/arrow/cpp/src/gandiva/bitmap_accumulator_test.cc
new file mode 100644
index 000000000..ccffab3e9
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/bitmap_accumulator_test.cc
@@ -0,0 +1,112 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/bitmap_accumulator.h"
+
+#include <memory>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/util.h"
+#include "arrow/util/bitmap_ops.h"
+
+#include "gandiva/dex.h"
+
+namespace gandiva {
+
+class TestBitMapAccumulator : public ::testing::Test {
+ protected:
+ void FillBitMap(uint8_t* bmap, uint32_t seed, int nrecords);
+ void ByteWiseIntersectBitMaps(uint8_t* dst, const std::vector<uint8_t*>& srcs,
+ const std::vector<int64_t>& srcOffsets, int nrecords);
+};
+
+void TestBitMapAccumulator::FillBitMap(uint8_t* bmap, uint32_t seed, int nbytes) {
+ ::arrow::random_bytes(nbytes, seed, bmap);
+}
+
+void TestBitMapAccumulator::ByteWiseIntersectBitMaps(
+ uint8_t* dst, const std::vector<uint8_t*>& srcs,
+ const std::vector<int64_t>& srcOffsets, int nrecords) {
+ if (srcs.empty()) {
+ arrow::BitUtil::SetBitsTo(dst, 0, nrecords, true);
+ return;
+ }
+
+ arrow::internal::CopyBitmap(srcs[0], srcOffsets[0], nrecords, dst, 0);
+ for (uint32_t j = 1; j < srcs.size(); ++j) {
+ arrow::internal::BitmapAnd(dst, 0, srcs[j], srcOffsets[j], nrecords, 0, dst);
+ }
+}
+
+TEST_F(TestBitMapAccumulator, TestIntersectBitMaps) {
+ const int length = 128;
+ const int nrecords = length * 8;
+ uint8_t src_bitmaps[4][length];
+ uint8_t dst_bitmap[length];
+ uint8_t expected_bitmap[length];
+
+ for (int i = 0; i < 4; i++) {
+ FillBitMap(src_bitmaps[i], i, length);
+ }
+
+ for (int i = 0; i < 4; i++) {
+ std::vector<uint8_t*> src_bitmap_ptrs;
+ std::vector<int64_t> src_bitmap_offsets(i, 0);
+ for (int j = 0; j < i; ++j) {
+ src_bitmap_ptrs.push_back(src_bitmaps[j]);
+ }
+
+ BitMapAccumulator::IntersectBitMaps(dst_bitmap, src_bitmap_ptrs, src_bitmap_offsets,
+ nrecords);
+ ByteWiseIntersectBitMaps(expected_bitmap, src_bitmap_ptrs, src_bitmap_offsets,
+ nrecords);
+ EXPECT_EQ(memcmp(dst_bitmap, expected_bitmap, length), 0);
+ }
+}
+
+TEST_F(TestBitMapAccumulator, TestIntersectBitMapsWithOffset) {
+ const int length = 128;
+ uint8_t src_bitmaps[4][length];
+ uint8_t dst_bitmap[length];
+ uint8_t expected_bitmap[length];
+
+ for (int i = 0; i < 4; i++) {
+ FillBitMap(src_bitmaps[i], i, length);
+ }
+
+ for (int i = 0; i < 4; i++) {
+ std::vector<uint8_t*> src_bitmap_ptrs;
+ std::vector<int64_t> src_bitmap_offsets;
+ for (int j = 0; j < i; ++j) {
+ src_bitmap_ptrs.push_back(src_bitmaps[j]);
+ src_bitmap_offsets.push_back(j); // offset j
+ }
+ const int nrecords = (i == 0) ? length * 8 : length * 8 - i + 1;
+
+ BitMapAccumulator::IntersectBitMaps(dst_bitmap, src_bitmap_ptrs, src_bitmap_offsets,
+ nrecords);
+ ByteWiseIntersectBitMaps(expected_bitmap, src_bitmap_ptrs, src_bitmap_offsets,
+ nrecords);
+ EXPECT_TRUE(
+ arrow::internal::BitmapEquals(dst_bitmap, 0, expected_bitmap, 0, nrecords));
+ }
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/cache.cc b/src/arrow/cpp/src/gandiva/cache.cc
new file mode 100644
index 000000000..d823a676b
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/cache.cc
@@ -0,0 +1,45 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/cache.h"
+#include "arrow/util/logging.h"
+
+namespace gandiva {
+
+static const int DEFAULT_CACHE_SIZE = 500;
+
+int GetCapacity() {
+ int capacity;
+ const char* env_cache_size = std::getenv("GANDIVA_CACHE_SIZE");
+ if (env_cache_size != nullptr) {
+ capacity = std::atoi(env_cache_size);
+ if (capacity <= 0) {
+ ARROW_LOG(WARNING) << "Invalid cache size provided. Using default cache size: "
+ << DEFAULT_CACHE_SIZE;
+ capacity = DEFAULT_CACHE_SIZE;
+ }
+ } else {
+ capacity = DEFAULT_CACHE_SIZE;
+ }
+ return capacity;
+}
+
+void LogCacheSize(size_t capacity) {
+ ARROW_LOG(INFO) << "Creating gandiva cache with capacity: " << capacity;
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/cache.h b/src/arrow/cpp/src/gandiva/cache.h
new file mode 100644
index 000000000..8d0f75ce3
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/cache.h
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdlib>
+#include <memory>
+#include <mutex>
+
+#include "gandiva/greedy_dual_size_cache.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+GANDIVA_EXPORT
+int GetCapacity();
+
+GANDIVA_EXPORT
+void LogCacheSize(size_t capacity);
+
+template <class KeyType, typename ValueType>
+class Cache {
+ public:
+ explicit Cache(size_t capacity) : cache_(capacity) { LogCacheSize(capacity); }
+
+ Cache() : Cache(GetCapacity()) {}
+
+ ValueType GetModule(KeyType cache_key) {
+ arrow::util::optional<ValueCacheObject<ValueType>> result;
+ mtx_.lock();
+ result = cache_.get(cache_key);
+ mtx_.unlock();
+ return result != arrow::util::nullopt ? (*result).module : nullptr;
+ }
+
+ void PutModule(KeyType cache_key, ValueCacheObject<ValueType> valueCacheObject) {
+ mtx_.lock();
+ cache_.insert(cache_key, valueCacheObject);
+ mtx_.unlock();
+ }
+
+ private:
+ GreedyDualSizeCache<KeyType, ValueType> cache_;
+ std::mutex mtx_;
+};
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/cast_time.cc b/src/arrow/cpp/src/gandiva/cast_time.cc
new file mode 100644
index 000000000..843ce01f8
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/cast_time.cc
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+
+#include "arrow/vendored/datetime.h"
+
+#include "gandiva/precompiled/time_fields.h"
+
+#ifndef GANDIVA_UNIT_TEST
+#include "gandiva/exported_funcs.h"
+#include "gandiva/gdv_function_stubs.h"
+
+#include "gandiva/engine.h"
+
+namespace gandiva {
+
+void ExportedTimeFunctions::AddMappings(Engine* engine) const {
+ std::vector<llvm::Type*> args;
+ auto types = engine->types();
+
+ // gdv_fn_time_with_zone
+ args = {types->ptr_type(types->i32_type()), // time fields
+ types->i8_ptr_type(), // const char* zone
+ types->i32_type(), // int data_len
+ types->i64_type()}; // timestamp *ret_time
+
+ engine->AddGlobalMappingForFunc("gdv_fn_time_with_zone",
+ types->i32_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_time_with_zone));
+}
+
+} // namespace gandiva
+#endif // !GANDIVA_UNIT_TEST
+
+extern "C" {
+
+// TODO : Do input validation or make sure the callers do that ?
+int gdv_fn_time_with_zone(int* time_fields, const char* zone, int zone_len,
+ int64_t* ret_time) {
+ using arrow_vendored::date::day;
+ using arrow_vendored::date::local_days;
+ using arrow_vendored::date::locate_zone;
+ using arrow_vendored::date::month;
+ using arrow_vendored::date::time_zone;
+ using arrow_vendored::date::year;
+ using std::chrono::hours;
+ using std::chrono::milliseconds;
+ using std::chrono::minutes;
+ using std::chrono::seconds;
+
+ using gandiva::TimeFields;
+ try {
+ const time_zone* tz = locate_zone(std::string(zone, zone_len));
+ *ret_time = tz->to_sys(local_days(year(time_fields[TimeFields::kYear]) /
+ month(time_fields[TimeFields::kMonth]) /
+ day(time_fields[TimeFields::kDay])) +
+ hours(time_fields[TimeFields::kHours]) +
+ minutes(time_fields[TimeFields::kMinutes]) +
+ seconds(time_fields[TimeFields::kSeconds]) +
+ milliseconds(time_fields[TimeFields::kSubSeconds]))
+ .time_since_epoch()
+ .count();
+ } catch (...) {
+ return EINVAL;
+ }
+
+ return 0;
+}
+
+} // extern "C"
diff --git a/src/arrow/cpp/src/gandiva/compiled_expr.h b/src/arrow/cpp/src/gandiva/compiled_expr.h
new file mode 100644
index 000000000..ba0ca3437
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/compiled_expr.h
@@ -0,0 +1,71 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <vector>
+#include "gandiva/llvm_includes.h"
+#include "gandiva/selection_vector.h"
+#include "gandiva/value_validity_pair.h"
+
+namespace gandiva {
+
+using EvalFunc = int (*)(uint8_t** buffers, int64_t* offsets, uint8_t** local_bitmaps,
+ const uint8_t* selection_buffer, int64_t execution_ctx_ptr,
+ int64_t record_count);
+
+/// \brief Tracks the compiled state for one expression.
+class CompiledExpr {
+ public:
+ CompiledExpr(ValueValidityPairPtr value_validity, FieldDescriptorPtr output)
+ : value_validity_(value_validity), output_(output) {}
+
+ ValueValidityPairPtr value_validity() const { return value_validity_; }
+
+ FieldDescriptorPtr output() const { return output_; }
+
+ void SetIRFunction(SelectionVector::Mode mode, llvm::Function* ir_function) {
+ ir_functions_[static_cast<int>(mode)] = ir_function;
+ }
+
+ llvm::Function* GetIRFunction(SelectionVector::Mode mode) const {
+ return ir_functions_[static_cast<int>(mode)];
+ }
+
+ void SetJITFunction(SelectionVector::Mode mode, EvalFunc jit_function) {
+ jit_functions_[static_cast<int>(mode)] = jit_function;
+ }
+
+ EvalFunc GetJITFunction(SelectionVector::Mode mode) const {
+ return jit_functions_[static_cast<int>(mode)];
+ }
+
+ private:
+ // value & validities for the expression tree (root)
+ ValueValidityPairPtr value_validity_;
+
+ // output field
+ FieldDescriptorPtr output_;
+
+ // IR functions for various modes in the generated code
+ std::array<llvm::Function*, SelectionVector::kNumModes> ir_functions_;
+
+ // JIT functions in the generated code (set after the module is optimised and finalized)
+ std::array<EvalFunc, SelectionVector::kNumModes> jit_functions_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/condition.h b/src/arrow/cpp/src/gandiva/condition.h
new file mode 100644
index 000000000..a3e8f9d1f
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/condition.h
@@ -0,0 +1,37 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "gandiva/arrow.h"
+#include "gandiva/expression.h"
+#include "gandiva/gandiva_aliases.h"
+
+namespace gandiva {
+
+/// \brief A condition expression.
+class Condition : public Expression {
+ public:
+ explicit Condition(const NodePtr root)
+ : Expression(root, std::make_shared<arrow::Field>("cond", arrow::boolean())) {}
+
+ virtual ~Condition() = default;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/configuration.cc b/src/arrow/cpp/src/gandiva/configuration.cc
new file mode 100644
index 000000000..1e26c5c70
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/configuration.cc
@@ -0,0 +1,43 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/configuration.h"
+
+#include "arrow/util/hash_util.h"
+
+namespace gandiva {
+
+const std::shared_ptr<Configuration> ConfigurationBuilder::default_configuration_ =
+ InitDefaultConfig();
+
+std::size_t Configuration::Hash() const {
+ static constexpr size_t kHashSeed = 0;
+ size_t result = kHashSeed;
+ arrow::internal::hash_combine(result, static_cast<size_t>(optimize_));
+ arrow::internal::hash_combine(result, static_cast<size_t>(target_host_cpu_));
+ return result;
+}
+
+bool Configuration::operator==(const Configuration& other) const {
+ return optimize_ == other.optimize_ && target_host_cpu_ == other.target_host_cpu_;
+}
+
+bool Configuration::operator!=(const Configuration& other) const {
+ return !(*this == other);
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/configuration.h b/src/arrow/cpp/src/gandiva/configuration.h
new file mode 100644
index 000000000..9cd301524
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/configuration.h
@@ -0,0 +1,84 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/status.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+class ConfigurationBuilder;
+/// \brief runtime config for gandiva
+///
+/// It contains elements to customize gandiva execution
+/// at run time.
+class GANDIVA_EXPORT Configuration {
+ public:
+ friend class ConfigurationBuilder;
+
+ Configuration() : optimize_(true), target_host_cpu_(true) {}
+ explicit Configuration(bool optimize) : optimize_(optimize), target_host_cpu_(true) {}
+
+ std::size_t Hash() const;
+ bool operator==(const Configuration& other) const;
+ bool operator!=(const Configuration& other) const;
+
+ bool optimize() const { return optimize_; }
+ bool target_host_cpu() const { return target_host_cpu_; }
+
+ void set_optimize(bool optimize) { optimize_ = optimize; }
+ void target_host_cpu(bool target_host_cpu) { target_host_cpu_ = target_host_cpu; }
+
+ private:
+ bool optimize_; /* optimise the generated llvm IR */
+ bool target_host_cpu_; /* set the mcpu flag to host cpu while compiling llvm ir */
+};
+
+/// \brief configuration builder for gandiva
+///
+/// Provides a default configuration and convenience methods
+/// to override specific values and build a custom instance
+class GANDIVA_EXPORT ConfigurationBuilder {
+ public:
+ std::shared_ptr<Configuration> build() {
+ std::shared_ptr<Configuration> configuration(new Configuration());
+ return configuration;
+ }
+
+ std::shared_ptr<Configuration> build(bool optimize) {
+ std::shared_ptr<Configuration> configuration(new Configuration(optimize));
+ return configuration;
+ }
+
+ static std::shared_ptr<Configuration> DefaultConfiguration() {
+ return default_configuration_;
+ }
+
+ private:
+ static std::shared_ptr<Configuration> InitDefaultConfig() {
+ std::shared_ptr<Configuration> configuration(new Configuration());
+ return configuration;
+ }
+
+ static const std::shared_ptr<Configuration> default_configuration_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/context_helper.cc b/src/arrow/cpp/src/gandiva/context_helper.cc
new file mode 100644
index 000000000..224bfd8f5
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/context_helper.cc
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This file is also used in the pre-compiled unit tests, which do include
+// llvm/engine/..
+#ifndef GANDIVA_UNIT_TEST
+#include "gandiva/exported_funcs.h"
+#include "gandiva/gdv_function_stubs.h"
+
+#include "gandiva/engine.h"
+
+namespace gandiva {
+
+void ExportedContextFunctions::AddMappings(Engine* engine) const {
+ std::vector<llvm::Type*> args;
+ auto types = engine->types();
+
+ // gdv_fn_context_set_error_msg
+ args = {types->i64_type(), // int64_t context_ptr
+ types->i8_ptr_type()}; // char const* err_msg
+
+ engine->AddGlobalMappingForFunc("gdv_fn_context_set_error_msg", types->void_type(),
+ args,
+ reinterpret_cast<void*>(gdv_fn_context_set_error_msg));
+
+ // gdv_fn_context_arena_malloc
+ args = {types->i64_type(), // int64_t context_ptr
+ types->i32_type()}; // int32_t size
+
+ engine->AddGlobalMappingForFunc("gdv_fn_context_arena_malloc", types->i8_ptr_type(),
+ args,
+ reinterpret_cast<void*>(gdv_fn_context_arena_malloc));
+
+ // gdv_fn_context_arena_reset
+ args = {types->i64_type()}; // int64_t context_ptr
+
+ engine->AddGlobalMappingForFunc("gdv_fn_context_arena_reset", types->void_type(), args,
+ reinterpret_cast<void*>(gdv_fn_context_arena_reset));
+}
+
+} // namespace gandiva
+#endif // !GANDIVA_UNIT_TEST
+
+#include "gandiva/execution_context.h"
+
+extern "C" {
+
+void gdv_fn_context_set_error_msg(int64_t context_ptr, char const* err_msg) {
+ auto context = reinterpret_cast<gandiva::ExecutionContext*>(context_ptr);
+ context->set_error_msg(err_msg);
+}
+
+uint8_t* gdv_fn_context_arena_malloc(int64_t context_ptr, int32_t size) {
+ auto context = reinterpret_cast<gandiva::ExecutionContext*>(context_ptr);
+ return context->arena()->Allocate(size);
+}
+
+void gdv_fn_context_arena_reset(int64_t context_ptr) {
+ auto context = reinterpret_cast<gandiva::ExecutionContext*>(context_ptr);
+ return context->arena()->Reset();
+}
+}
diff --git a/src/arrow/cpp/src/gandiva/date_utils.cc b/src/arrow/cpp/src/gandiva/date_utils.cc
new file mode 100644
index 000000000..f0a80d3c9
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/date_utils.cc
@@ -0,0 +1,232 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <sstream>
+#include <vector>
+
+#include "gandiva/date_utils.h"
+
+namespace gandiva {
+
+std::vector<std::string> DateUtils::GetMatches(std::string pattern, bool exactMatch) {
+ // we are case insensitive
+ std::transform(pattern.begin(), pattern.end(), pattern.begin(), ::tolower);
+ std::vector<std::string> matches;
+
+ for (const auto& it : sql_date_format_to_boost_map_) {
+ if (it.first.find(pattern) != std::string::npos &&
+ (!exactMatch || (it.first.length() == pattern.length()))) {
+ matches.push_back(it.first);
+ }
+ }
+
+ return matches;
+}
+
+std::vector<std::string> DateUtils::GetPotentialMatches(const std::string& pattern) {
+ return GetMatches(pattern, false);
+}
+
+std::vector<std::string> DateUtils::GetExactMatches(const std::string& pattern) {
+ return GetMatches(pattern, true);
+}
+
+/**
+ * Validates and converts format to the strptime equivalent
+ *
+ */
+Status DateUtils::ToInternalFormat(const std::string& format,
+ std::shared_ptr<std::string>* internal_format) {
+ std::stringstream builder;
+ std::stringstream buffer;
+ bool is_in_quoted_text = false;
+
+ for (size_t i = 0; i < format.size(); i++) {
+ char currentChar = format[i];
+
+ // logic before we append to the buffer
+ if (currentChar == '"') {
+ if (is_in_quoted_text) {
+ // we are done with a quoted block
+ is_in_quoted_text = false;
+
+ // use ' for quoting
+ builder << '\'';
+ builder << buffer.str();
+ builder << '\'';
+
+ // clear buffer
+ buffer.str("");
+ continue;
+ } else {
+ ARROW_RETURN_IF(buffer.str().length() > 0,
+ Status::Invalid("Invalid date format string '", format, "'"));
+
+ is_in_quoted_text = true;
+ continue;
+ }
+ }
+
+ // handle special characters we want to simply pass through, but only if not in quoted
+ // and the buffer is empty
+ std::string special_characters = "*-/,.;: ";
+ if (!is_in_quoted_text && buffer.str().length() == 0 &&
+ (special_characters.find_first_of(currentChar) != std::string::npos)) {
+ builder << currentChar;
+ continue;
+ }
+
+ // append to the buffer
+ buffer << currentChar;
+
+ // nothing else to do if we are in quoted text
+ if (is_in_quoted_text) {
+ continue;
+ }
+
+ // check how many matches we have for our buffer
+ std::vector<std::string> potentialList = GetPotentialMatches(buffer.str());
+ int64_t potentialCount = potentialList.size();
+
+ if (potentialCount >= 1) {
+ // one potential and the length match
+ if (potentialCount == 1 && potentialList[0].length() == buffer.str().length()) {
+ // we have a match!
+ builder << sql_date_format_to_boost_map_[potentialList[0]];
+ buffer.str("");
+ } else {
+ // Some patterns (like MON, MONTH) can cause ambiguity, such as "MON:". "MON"
+ // will have two potential matches, but "MON:" will match nothing, so we want to
+ // look ahead when we match "MON" and check if adding the next char leads to 0
+ // potentials. If it does, we go ahead and treat the buffer as matched (if a
+ // potential match exists that matches the buffer)
+ if (format.length() - 1 > i) {
+ std::string lookAheadPattern = (buffer.str() + format.at(i + 1));
+ std::transform(lookAheadPattern.begin(), lookAheadPattern.end(),
+ lookAheadPattern.begin(), ::tolower);
+ bool lookAheadMatched = false;
+
+ // we can query potentialList to see if it has anything that matches the
+ // lookahead pattern
+ for (std::string potential : potentialList) {
+ if (potential.find(lookAheadPattern) != std::string::npos) {
+ lookAheadMatched = true;
+ break;
+ }
+ }
+
+ if (!lookAheadMatched) {
+ // check if any of the potential matches are the same length as our buffer, we
+ // do not want to match "MO:"
+ bool matched = false;
+ for (std::string potential : potentialList) {
+ if (potential.length() == buffer.str().length()) {
+ matched = true;
+ break;
+ }
+ }
+
+ if (matched) {
+ std::string match = buffer.str();
+ std::transform(match.begin(), match.end(), match.begin(), ::tolower);
+ builder << sql_date_format_to_boost_map_[match];
+ buffer.str("");
+ continue;
+ }
+ }
+ }
+ }
+ } else {
+ return Status::Invalid("Invalid date format string '", format, "'");
+ }
+ }
+
+ if (buffer.str().length() > 0) {
+ // Some patterns (like MON, MONTH) can cause us to reach this point with a valid
+ // buffer value as MON has 2 valid potential matches, so double check here
+ std::vector<std::string> exactMatches = GetExactMatches(buffer.str());
+ if (exactMatches.size() == 1 && exactMatches[0].length() == buffer.str().length()) {
+ builder << sql_date_format_to_boost_map_[exactMatches[0]];
+ } else {
+ // Format partially parsed
+ int64_t pos = format.length() - buffer.str().length();
+ return Status::Invalid("Invalid date format string '", format, "' at position ",
+ pos);
+ }
+ }
+ std::string final_pattern = builder.str();
+ internal_format->reset(new std::string(final_pattern));
+ return Status::OK();
+}
+
+DateUtils::date_format_converter DateUtils::sql_date_format_to_boost_map_ = InitMap();
+
+DateUtils::date_format_converter DateUtils::InitMap() {
+ date_format_converter map;
+
+ // Era
+ map["ad"] = "%EC";
+ map["bc"] = "%EC";
+ // Meridian
+ map["am"] = "%p";
+ map["pm"] = "%p";
+ // Century
+ map["cc"] = "%C";
+ // Week of year
+ map["ww"] = "%W";
+ // Day of week
+ map["d"] = "%u";
+ // Day name of week
+ map["dy"] = "%a";
+ map["day"] = "%a";
+ // Year
+ map["yyyy"] = "%Y";
+ map["yy"] = "%y";
+ // Day of year
+ map["ddd"] = "%j";
+ // Month
+ map["mm"] = "%m";
+ map["mon"] = "%b";
+ map["month"] = "%b";
+ // Day of month
+ map["dd"] = "%d";
+ // Hour of day
+ map["hh"] = "%I";
+ map["hh12"] = "%I";
+ map["hh24"] = "%H";
+ // Minutes
+ map["mi"] = "%M";
+ // Seconds
+ map["ss"] = "%S";
+ // Milliseconds
+ map["f"] = "S";
+ map["ff"] = "SS";
+ map["fff"] = "SSS";
+ /*
+ // Timezone not tested/supported yet fully.
+ map["tzd"] = "%Z";
+ map["tzo"] = "%z";
+ map["tzh:tzm"] = "%z";
+ */
+
+ return map;
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/date_utils.h b/src/arrow/cpp/src/gandiva/date_utils.h
new file mode 100644
index 000000000..0d39a5f29
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/date_utils.h
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "arrow/util/macros.h"
+
+#include "gandiva/arrow.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// \brief Utility class for converting sql date patterns to internal date patterns.
+class GANDIVA_EXPORT DateUtils {
+ public:
+ static Status ToInternalFormat(const std::string& format,
+ std::shared_ptr<std::string>* internal_format);
+
+ private:
+ using date_format_converter = std::unordered_map<std::string, std::string>;
+
+ static date_format_converter sql_date_format_to_boost_map_;
+
+ static date_format_converter InitMap();
+
+ static std::vector<std::string> GetMatches(std::string pattern, bool exactMatch);
+
+ static std::vector<std::string> GetPotentialMatches(const std::string& pattern);
+
+ static std::vector<std::string> GetExactMatches(const std::string& pattern);
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/decimal_ir.cc b/src/arrow/cpp/src/gandiva/decimal_ir.cc
new file mode 100644
index 000000000..5d5d30b4a
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/decimal_ir.cc
@@ -0,0 +1,559 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <sstream>
+#include <unordered_set>
+#include <utility>
+
+#include "arrow/status.h"
+#include "gandiva/decimal_ir.h"
+#include "gandiva/decimal_type_util.h"
+
+// Algorithms adapted from Apache Impala
+
+namespace gandiva {
+
+#define ADD_TRACE_32(msg, value) \
+ if (enable_ir_traces_) { \
+ AddTrace32(msg, value); \
+ }
+#define ADD_TRACE_128(msg, value) \
+ if (enable_ir_traces_) { \
+ AddTrace128(msg, value); \
+ }
+
+// These are the functions defined in this file. The rest are in precompiled folder,
+// and the i128 needs to be dis-assembled for those.
+static const char* kAddFunction = "add_decimal128_decimal128";
+static const char* kSubtractFunction = "subtract_decimal128_decimal128";
+static const char* kEQFunction = "equal_decimal128_decimal128";
+static const char* kNEFunction = "not_equal_decimal128_decimal128";
+static const char* kLTFunction = "less_than_decimal128_decimal128";
+static const char* kLEFunction = "less_than_or_equal_to_decimal128_decimal128";
+static const char* kGTFunction = "greater_than_decimal128_decimal128";
+static const char* kGEFunction = "greater_than_or_equal_to_decimal128_decimal128";
+
+static const std::unordered_set<std::string> kDecimalIRBuilderFunctions{
+ kAddFunction, kSubtractFunction, kEQFunction, kNEFunction,
+ kLTFunction, kLEFunction, kGTFunction, kGEFunction};
+
+const char* DecimalIR::kScaleMultipliersName = "gandivaScaleMultipliers";
+
+/// Populate globals required by decimal IR.
+/// TODO: can this be done just once ?
+void DecimalIR::AddGlobals(Engine* engine) {
+ auto types = engine->types();
+
+ // populate vector : [ 1, 10, 100, 1000, ..]
+ std::string value = "1";
+ std::vector<llvm::Constant*> scale_multipliers;
+ for (int i = 0; i < DecimalTypeUtil::kMaxPrecision + 1; ++i) {
+ auto multiplier =
+ llvm::ConstantInt::get(llvm::Type::getInt128Ty(*engine->context()), value, 10);
+ scale_multipliers.push_back(multiplier);
+ value.append("0");
+ }
+
+ auto array_type =
+ llvm::ArrayType::get(types->i128_type(), DecimalTypeUtil::kMaxPrecision + 1);
+ auto initializer = llvm::ConstantArray::get(
+ array_type, llvm::ArrayRef<llvm::Constant*>(scale_multipliers));
+
+ auto globalScaleMultipliers = new llvm::GlobalVariable(
+ *engine->module(), array_type, true /*constant*/,
+ llvm::GlobalValue::LinkOnceAnyLinkage, initializer, kScaleMultipliersName);
+ globalScaleMultipliers->setAlignment(LLVM_ALIGN(16));
+}
+
+// Lookup intrinsic functions
+void DecimalIR::InitializeIntrinsics() {
+ sadd_with_overflow_fn_ = llvm::Intrinsic::getDeclaration(
+ module(), llvm::Intrinsic::sadd_with_overflow, types()->i128_type());
+ DCHECK_NE(sadd_with_overflow_fn_, nullptr);
+
+ smul_with_overflow_fn_ = llvm::Intrinsic::getDeclaration(
+ module(), llvm::Intrinsic::smul_with_overflow, types()->i128_type());
+ DCHECK_NE(smul_with_overflow_fn_, nullptr);
+
+ i128_with_overflow_struct_type_ =
+ sadd_with_overflow_fn_->getFunctionType()->getReturnType();
+}
+
+// CPP: return kScaleMultipliers[scale]
+llvm::Value* DecimalIR::GetScaleMultiplier(llvm::Value* scale) {
+ auto const_array = module()->getGlobalVariable(kScaleMultipliersName);
+ auto ptr = CreateGEP(ir_builder(), const_array, {types()->i32_constant(0), scale});
+ return CreateLoad(ir_builder(), ptr);
+}
+
+// CPP: x <= y ? y : x
+llvm::Value* DecimalIR::GetHigherScale(llvm::Value* x_scale, llvm::Value* y_scale) {
+ llvm::Value* le = ir_builder()->CreateICmpSLE(x_scale, y_scale);
+ return ir_builder()->CreateSelect(le, y_scale, x_scale);
+}
+
+// CPP: return (increase_scale_by <= 0) ?
+// in_value : in_value * GetScaleMultiplier(increase_scale_by)
+llvm::Value* DecimalIR::IncreaseScale(llvm::Value* in_value,
+ llvm::Value* increase_scale_by) {
+ llvm::Value* le_zero =
+ ir_builder()->CreateICmpSLE(increase_scale_by, types()->i32_constant(0));
+ // then block
+ auto then_lambda = [&] { return in_value; };
+
+ // else block
+ auto else_lambda = [&] {
+ llvm::Value* multiplier = GetScaleMultiplier(increase_scale_by);
+ return ir_builder()->CreateMul(in_value, multiplier);
+ };
+
+ return BuildIfElse(le_zero, types()->i128_type(), then_lambda, else_lambda);
+}
+
+// CPP: return (increase_scale_by <= 0) ?
+// {in_value,false} : {in_value * GetScaleMultiplier(increase_scale_by),true}
+//
+// The return value also indicates if there was an overflow while increasing the scale.
+DecimalIR::ValueWithOverflow DecimalIR::IncreaseScaleWithOverflowCheck(
+ llvm::Value* in_value, llvm::Value* increase_scale_by) {
+ llvm::Value* le_zero =
+ ir_builder()->CreateICmpSLE(increase_scale_by, types()->i32_constant(0));
+
+ // then block
+ auto then_lambda = [&] {
+ ValueWithOverflow ret{in_value, types()->false_constant()};
+ return ret.AsStruct(this);
+ };
+
+ // else block
+ auto else_lambda = [&] {
+ llvm::Value* multiplier = GetScaleMultiplier(increase_scale_by);
+ return ir_builder()->CreateCall(smul_with_overflow_fn_, {in_value, multiplier});
+ };
+
+ auto ir_struct =
+ BuildIfElse(le_zero, i128_with_overflow_struct_type_, then_lambda, else_lambda);
+ return ValueWithOverflow::MakeFromStruct(this, ir_struct);
+}
+
+// CPP: return (reduce_scale_by <= 0) ?
+// in_value : in_value / GetScaleMultiplier(reduce_scale_by)
+//
+// ReduceScale cannot cause an overflow.
+llvm::Value* DecimalIR::ReduceScale(llvm::Value* in_value, llvm::Value* reduce_scale_by) {
+ auto le_zero = ir_builder()->CreateICmpSLE(reduce_scale_by, types()->i32_constant(0));
+ // then block
+ auto then_lambda = [&] { return in_value; };
+
+ // else block
+ auto else_lambda = [&] {
+ // TODO : handle rounding.
+ llvm::Value* multiplier = GetScaleMultiplier(reduce_scale_by);
+ return ir_builder()->CreateSDiv(in_value, multiplier);
+ };
+
+ return BuildIfElse(le_zero, types()->i128_type(), then_lambda, else_lambda);
+}
+
+/// @brief Fast-path for add
+/// Adjust x and y to the same scale, and add them.
+llvm::Value* DecimalIR::AddFastPath(const ValueFull& x, const ValueFull& y) {
+ auto higher_scale = GetHigherScale(x.scale(), y.scale());
+ ADD_TRACE_32("AddFastPath : higher_scale", higher_scale);
+
+ // CPP : x_scaled = IncreaseScale(x_value, higher_scale - x_scale)
+ auto x_delta = ir_builder()->CreateSub(higher_scale, x.scale());
+ auto x_scaled = IncreaseScale(x.value(), x_delta);
+ ADD_TRACE_128("AddFastPath : x_scaled", x_scaled);
+
+ // CPP : y_scaled = IncreaseScale(y_value, higher_scale - y_scale)
+ auto y_delta = ir_builder()->CreateSub(higher_scale, y.scale());
+ auto y_scaled = IncreaseScale(y.value(), y_delta);
+ ADD_TRACE_128("AddFastPath : y_scaled", y_scaled);
+
+ auto sum = ir_builder()->CreateAdd(x_scaled, y_scaled);
+ ADD_TRACE_128("AddFastPath : sum", sum);
+ return sum;
+}
+
+// @brief Add with overflow check.
+/// Adjust x and y to the same scale, add them, and reduce sum to output scale.
+/// If there is an overflow, the sum is set to 0.
+DecimalIR::ValueWithOverflow DecimalIR::AddWithOverflowCheck(const ValueFull& x,
+ const ValueFull& y,
+ const ValueFull& out) {
+ auto higher_scale = GetHigherScale(x.scale(), y.scale());
+ ADD_TRACE_32("AddWithOverflowCheck : higher_scale", higher_scale);
+
+ // CPP : x_scaled = IncreaseScale(x_value, higher_scale - x.scale())
+ auto x_delta = ir_builder()->CreateSub(higher_scale, x.scale());
+ auto x_scaled = IncreaseScaleWithOverflowCheck(x.value(), x_delta);
+ ADD_TRACE_128("AddWithOverflowCheck : x_scaled", x_scaled.value());
+
+ // CPP : y_scaled = IncreaseScale(y_value, higher_scale - y_scale)
+ auto y_delta = ir_builder()->CreateSub(higher_scale, y.scale());
+ auto y_scaled = IncreaseScaleWithOverflowCheck(y.value(), y_delta);
+ ADD_TRACE_128("AddWithOverflowCheck : y_scaled", y_scaled.value());
+
+ // CPP : sum = x_scaled + y_scaled
+ auto sum_ir_struct = ir_builder()->CreateCall(sadd_with_overflow_fn_,
+ {x_scaled.value(), y_scaled.value()});
+ auto sum = ValueWithOverflow::MakeFromStruct(this, sum_ir_struct);
+ ADD_TRACE_128("AddWithOverflowCheck : sum", sum.value());
+
+ // CPP : overflow ? 0 : sum / GetScaleMultiplier(max_scale - out_scale)
+ auto overflow = GetCombinedOverflow({x_scaled, y_scaled, sum});
+ ADD_TRACE_32("AddWithOverflowCheck : overflow", overflow);
+ auto then_lambda = [&] {
+ // if there is an overflow, the value returned won't be used. so, save the division.
+ return types()->i128_constant(0);
+ };
+ auto else_lambda = [&] {
+ auto reduce_scale_by = ir_builder()->CreateSub(higher_scale, out.scale());
+ return ReduceScale(sum.value(), reduce_scale_by);
+ };
+ auto sum_descaled =
+ BuildIfElse(overflow, types()->i128_type(), then_lambda, else_lambda);
+ return ValueWithOverflow(sum_descaled, overflow);
+}
+
+// This is pretty complex, so use CPP fns.
+llvm::Value* DecimalIR::AddLarge(const ValueFull& x, const ValueFull& y,
+ const ValueFull& out) {
+ auto block = ir_builder()->GetInsertBlock();
+ auto out_high_ptr = new llvm::AllocaInst(types()->i64_type(), 0, "out_hi", block);
+ auto out_low_ptr = new llvm::AllocaInst(types()->i64_type(), 0, "out_low", block);
+ auto x_split = ValueSplit::MakeFromInt128(this, x.value());
+ auto y_split = ValueSplit::MakeFromInt128(this, y.value());
+
+ std::vector<llvm::Value*> args = {
+ x_split.high(), x_split.low(), x.precision(), x.scale(),
+ y_split.high(), y_split.low(), y.precision(), y.scale(),
+ out.precision(), out.scale(), out_high_ptr, out_low_ptr,
+ };
+ ir_builder()->CreateCall(module()->getFunction("add_large_decimal128_decimal128"),
+ args);
+
+ auto out_high = CreateLoad(ir_builder(), out_high_ptr);
+ auto out_low = CreateLoad(ir_builder(), out_low_ptr);
+ auto sum = ValueSplit(out_high, out_low).AsInt128(this);
+ ADD_TRACE_128("AddLarge : sum", sum);
+ return sum;
+}
+
+/// The output scale/precision cannot be arbitrary values. The algo here depends on them
+/// to be the same as computed in DecimalTypeSql.
+/// TODO: enforce this.
+Status DecimalIR::BuildAdd() {
+ // Create fn prototype :
+ // int128_t
+ // add_decimal128_decimal128(int128_t x_value, int32_t x_precision, int32_t x_scale,
+ // int128_t y_value, int32_t y_precision, int32_t y_scale
+ // int32_t out_precision, int32_t out_scale)
+ auto i32 = types()->i32_type();
+ auto i128 = types()->i128_type();
+ auto function = BuildFunction(kAddFunction, i128,
+ {
+ {"x_value", i128},
+ {"x_precision", i32},
+ {"x_scale", i32},
+ {"y_value", i128},
+ {"y_precision", i32},
+ {"y_scale", i32},
+ {"out_precision", i32},
+ {"out_scale", i32},
+ });
+
+ auto arg_iter = function->arg_begin();
+ ValueFull x(&arg_iter[0], &arg_iter[1], &arg_iter[2]);
+ ValueFull y(&arg_iter[3], &arg_iter[4], &arg_iter[5]);
+ ValueFull out(nullptr, &arg_iter[6], &arg_iter[7]);
+
+ auto entry = llvm::BasicBlock::Create(*context(), "entry", function);
+ ir_builder()->SetInsertPoint(entry);
+
+ // CPP :
+ // if (out_precision < 38) {
+ // return AddFastPath(x, y)
+ // } else {
+ // ret = AddWithOverflowCheck(x, y)
+ // if (ret.overflow)
+ // return AddLarge(x, y)
+ // else
+ // return ret.value;
+ // }
+ llvm::Value* lt_max_precision = ir_builder()->CreateICmpSLT(
+ out.precision(), types()->i32_constant(DecimalTypeUtil::kMaxPrecision));
+ auto then_lambda = [&] {
+ // fast-path add
+ return AddFastPath(x, y);
+ };
+ auto else_lambda = [&] {
+ if (kUseOverflowIntrinsics) {
+ // do the add and check if there was overflow
+ auto ret = AddWithOverflowCheck(x, y, out);
+
+ // if there is an overflow, switch to the AddLarge codepath.
+ return BuildIfElse(
+ ret.overflow(), types()->i128_type(), [&] { return AddLarge(x, y, out); },
+ [&] { return ret.value(); });
+ } else {
+ return AddLarge(x, y, out);
+ }
+ };
+ auto value =
+ BuildIfElse(lt_max_precision, types()->i128_type(), then_lambda, else_lambda);
+
+ // store result to out
+ ir_builder()->CreateRet(value);
+ return Status::OK();
+}
+
+Status DecimalIR::BuildSubtract() {
+ // Create fn prototype :
+ // int128_t
+ // subtract_decimal128_decimal128(int128_t x_value, int32_t x_precision, int32_t
+ // x_scale,
+ // int128_t y_value, int32_t y_precision, int32_t y_scale
+ // int32_t out_precision, int32_t out_scale)
+ auto i32 = types()->i32_type();
+ auto i128 = types()->i128_type();
+ auto function = BuildFunction(kSubtractFunction, i128,
+ {
+ {"x_value", i128},
+ {"x_precision", i32},
+ {"x_scale", i32},
+ {"y_value", i128},
+ {"y_precision", i32},
+ {"y_scale", i32},
+ {"out_precision", i32},
+ {"out_scale", i32},
+ });
+
+ auto entry = llvm::BasicBlock::Create(*context(), "entry", function);
+ ir_builder()->SetInsertPoint(entry);
+
+ // reuse add function after negating y_value. i.e
+ // add(x_value, x_precision, x_scale, -y_value, y_precision, y_scale,
+ // out_precision, out_scale)
+ std::vector<llvm::Value*> args;
+ int i = 0;
+ for (auto& in_arg : function->args()) {
+ if (i == 3) {
+ auto y_neg_value = ir_builder()->CreateNeg(&in_arg);
+ args.push_back(y_neg_value);
+ } else {
+ args.push_back(&in_arg);
+ }
+ ++i;
+ }
+ auto value = ir_builder()->CreateCall(module()->getFunction(kAddFunction), args);
+
+ // store result to out
+ ir_builder()->CreateRet(value);
+ return Status::OK();
+}
+
+Status DecimalIR::BuildCompare(const std::string& function_name,
+ llvm::ICmpInst::Predicate cmp_instruction) {
+ // Create fn prototype :
+ // bool
+ // function_name(int128_t x_value, int32_t x_precision, int32_t x_scale,
+ // int128_t y_value, int32_t y_precision, int32_t y_scale)
+
+ auto i32 = types()->i32_type();
+ auto i128 = types()->i128_type();
+ auto function = BuildFunction(function_name, types()->i1_type(),
+ {
+ {"x_value", i128},
+ {"x_precision", i32},
+ {"x_scale", i32},
+ {"y_value", i128},
+ {"y_precision", i32},
+ {"y_scale", i32},
+ });
+
+ auto arg_iter = function->arg_begin();
+ ValueFull x(&arg_iter[0], &arg_iter[1], &arg_iter[2]);
+ ValueFull y(&arg_iter[3], &arg_iter[4], &arg_iter[5]);
+
+ auto entry = llvm::BasicBlock::Create(*context(), "entry", function);
+ ir_builder()->SetInsertPoint(entry);
+
+ // Make call to pre-compiled IR function.
+ auto x_split = ValueSplit::MakeFromInt128(this, x.value());
+ auto y_split = ValueSplit::MakeFromInt128(this, y.value());
+
+ std::vector<llvm::Value*> args = {
+ x_split.high(), x_split.low(), x.precision(), x.scale(),
+ y_split.high(), y_split.low(), y.precision(), y.scale(),
+ };
+ auto cmp_value = ir_builder()->CreateCall(
+ module()->getFunction("compare_decimal128_decimal128_internal"), args);
+ auto result =
+ ir_builder()->CreateICmp(cmp_instruction, cmp_value, types()->i32_constant(0));
+ ir_builder()->CreateRet(result);
+ return Status::OK();
+}
+
+llvm::Value* DecimalIR::CallDecimalFunction(const std::string& function_name,
+ llvm::Type* return_type,
+ const std::vector<llvm::Value*>& params) {
+ if (kDecimalIRBuilderFunctions.count(function_name) != 0) {
+ // this is fn built with the irbuilder.
+ return ir_builder()->CreateCall(module()->getFunction(function_name), params);
+ }
+
+ // ppre-compiler fn : disassemble i128 to two i64s and re-assemble.
+ auto i128 = types()->i128_type();
+ auto i64 = types()->i64_type();
+ std::vector<llvm::Value*> dis_assembled_args;
+ for (auto& arg : params) {
+ if (arg->getType() == i128) {
+ // split i128 arg into two int64s.
+ auto split = ValueSplit::MakeFromInt128(this, arg);
+ dis_assembled_args.push_back(split.high());
+ dis_assembled_args.push_back(split.low());
+ } else {
+ dis_assembled_args.push_back(arg);
+ }
+ }
+
+ llvm::Value* result = nullptr;
+ if (return_type == i128) {
+ // for i128 ret, replace with two int64* args, and join them.
+ auto block = ir_builder()->GetInsertBlock();
+ auto out_high_ptr = new llvm::AllocaInst(i64, 0, "out_hi", block);
+ auto out_low_ptr = new llvm::AllocaInst(i64, 0, "out_low", block);
+ dis_assembled_args.push_back(out_high_ptr);
+ dis_assembled_args.push_back(out_low_ptr);
+
+ // Make call to pre-compiled IR function.
+ ir_builder()->CreateCall(module()->getFunction(function_name), dis_assembled_args);
+
+ auto out_high = CreateLoad(ir_builder(), out_high_ptr);
+ auto out_low = CreateLoad(ir_builder(), out_low_ptr);
+ result = ValueSplit(out_high, out_low).AsInt128(this);
+ } else {
+ DCHECK_NE(return_type, types()->void_type());
+
+ // Make call to pre-compiled IR function.
+ result = ir_builder()->CreateCall(module()->getFunction(function_name),
+ dis_assembled_args);
+ }
+ return result;
+}
+
+Status DecimalIR::AddFunctions(Engine* engine) {
+ auto decimal_ir = std::make_shared<DecimalIR>(engine);
+
+ // Populate global variables used by decimal operations.
+ decimal_ir->AddGlobals(engine);
+
+ // Lookup intrinsic functions
+ decimal_ir->InitializeIntrinsics();
+
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildAdd());
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildSubtract());
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(kEQFunction, llvm::ICmpInst::ICMP_EQ));
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(kNEFunction, llvm::ICmpInst::ICMP_NE));
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(kLTFunction, llvm::ICmpInst::ICMP_SLT));
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(kLEFunction, llvm::ICmpInst::ICMP_SLE));
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(kGTFunction, llvm::ICmpInst::ICMP_SGT));
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(kGEFunction, llvm::ICmpInst::ICMP_SGE));
+ return Status::OK();
+}
+
+// Do an bitwise-or of all the overflow bits.
+llvm::Value* DecimalIR::GetCombinedOverflow(
+ std::vector<DecimalIR::ValueWithOverflow> vec) {
+ llvm::Value* res = types()->false_constant();
+ for (auto& val : vec) {
+ res = ir_builder()->CreateOr(res, val.overflow());
+ }
+ return res;
+}
+
+DecimalIR::ValueSplit DecimalIR::ValueSplit::MakeFromInt128(DecimalIR* decimal_ir,
+ llvm::Value* in) {
+ auto builder = decimal_ir->ir_builder();
+ auto types = decimal_ir->types();
+
+ auto high = builder->CreateLShr(in, types->i128_constant(64));
+ high = builder->CreateTrunc(high, types->i64_type());
+ auto low = builder->CreateTrunc(in, types->i64_type());
+ return ValueSplit(high, low);
+}
+
+/// Convert IR struct {%i64, %i64} to cpp class ValueSplit
+DecimalIR::ValueSplit DecimalIR::ValueSplit::MakeFromStruct(DecimalIR* decimal_ir,
+ llvm::Value* dstruct) {
+ auto builder = decimal_ir->ir_builder();
+ auto high = builder->CreateExtractValue(dstruct, 0);
+ auto low = builder->CreateExtractValue(dstruct, 1);
+ return DecimalIR::ValueSplit(high, low);
+}
+
+llvm::Value* DecimalIR::ValueSplit::AsInt128(DecimalIR* decimal_ir) const {
+ auto builder = decimal_ir->ir_builder();
+ auto types = decimal_ir->types();
+
+ auto value = builder->CreateSExt(high_, types->i128_type());
+ value = builder->CreateShl(value, types->i128_constant(64));
+ value = builder->CreateAdd(value, builder->CreateZExt(low_, types->i128_type()));
+ return value;
+}
+
+/// Convert IR struct {%i128, %i1} to cpp class ValueWithOverflow
+DecimalIR::ValueWithOverflow DecimalIR::ValueWithOverflow::MakeFromStruct(
+ DecimalIR* decimal_ir, llvm::Value* dstruct) {
+ auto builder = decimal_ir->ir_builder();
+ auto value = builder->CreateExtractValue(dstruct, 0);
+ auto overflow = builder->CreateExtractValue(dstruct, 1);
+ return DecimalIR::ValueWithOverflow(value, overflow);
+}
+
+/// Convert to IR struct {%i128, %i1}
+llvm::Value* DecimalIR::ValueWithOverflow::AsStruct(DecimalIR* decimal_ir) const {
+ auto builder = decimal_ir->ir_builder();
+
+ auto undef = llvm::UndefValue::get(decimal_ir->i128_with_overflow_struct_type_);
+ auto struct_val = builder->CreateInsertValue(undef, value(), 0);
+ return builder->CreateInsertValue(struct_val, overflow(), 1);
+}
+
+/// debug traces
+void DecimalIR::AddTrace(const std::string& fmt, std::vector<llvm::Value*> args) {
+ DCHECK(enable_ir_traces_);
+
+ auto ir_str = ir_builder()->CreateGlobalStringPtr(fmt);
+ args.insert(args.begin(), ir_str);
+ ir_builder()->CreateCall(module()->getFunction("printf"), args, "trace");
+}
+
+void DecimalIR::AddTrace32(const std::string& msg, llvm::Value* value) {
+ AddTrace("DECIMAL_IR_TRACE:: " + msg + " %d\n", {value});
+}
+
+void DecimalIR::AddTrace128(const std::string& msg, llvm::Value* value) {
+ // convert i128 into two i64s for printing
+ auto split = ValueSplit::MakeFromInt128(this, value);
+ AddTrace("DECIMAL_IR_TRACE:: " + msg + " %llx:%llx (%lld:%llu)\n",
+ {split.high(), split.low(), split.high(), split.low()});
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/decimal_ir.h b/src/arrow/cpp/src/gandiva/decimal_ir.h
new file mode 100644
index 000000000..b11730f1e
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/decimal_ir.h
@@ -0,0 +1,188 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gandiva/function_ir_builder.h"
+
+namespace gandiva {
+
+/// @brief Decimal IR functions
+class DecimalIR : public FunctionIRBuilder {
+ public:
+ explicit DecimalIR(Engine* engine)
+ : FunctionIRBuilder(engine), enable_ir_traces_(false) {}
+
+ /// Build decimal IR functions and add them to the engine.
+ static Status AddFunctions(Engine* engine);
+
+ void EnableTraces() { enable_ir_traces_ = true; }
+
+ llvm::Value* CallDecimalFunction(const std::string& function_name,
+ llvm::Type* return_type,
+ const std::vector<llvm::Value*>& args);
+
+ private:
+ /// The intrinsic fn for divide with small divisors is about 10x slower, so not
+ /// using these.
+ static const bool kUseOverflowIntrinsics = false;
+
+ // Holder for an i128 value, along with its with scale and precision.
+ class ValueFull {
+ public:
+ ValueFull(llvm::Value* value, llvm::Value* precision, llvm::Value* scale)
+ : value_(value), precision_(precision), scale_(scale) {}
+
+ llvm::Value* value() const { return value_; }
+ llvm::Value* precision() const { return precision_; }
+ llvm::Value* scale() const { return scale_; }
+
+ private:
+ llvm::Value* value_;
+ llvm::Value* precision_;
+ llvm::Value* scale_;
+ };
+
+ // Holder for an i128 value, and a boolean indicating overflow.
+ class ValueWithOverflow {
+ public:
+ ValueWithOverflow(llvm::Value* value, llvm::Value* overflow)
+ : value_(value), overflow_(overflow) {}
+
+ // Make from IR struct
+ static ValueWithOverflow MakeFromStruct(DecimalIR* decimal_ir, llvm::Value* dstruct);
+
+ // Build a corresponding IR struct
+ llvm::Value* AsStruct(DecimalIR* decimal_ir) const;
+
+ llvm::Value* value() const { return value_; }
+ llvm::Value* overflow() const { return overflow_; }
+
+ private:
+ llvm::Value* value_;
+ llvm::Value* overflow_;
+ };
+
+ // Holder for an i128 value that is split into two i64s
+ class ValueSplit {
+ public:
+ ValueSplit(llvm::Value* high, llvm::Value* low) : high_(high), low_(low) {}
+
+ // Make from i128 value
+ static ValueSplit MakeFromInt128(DecimalIR* decimal_ir, llvm::Value* in);
+
+ // Make from IR struct
+ static ValueSplit MakeFromStruct(DecimalIR* decimal_ir, llvm::Value* dstruct);
+
+ // Combine the two parts into an i128
+ llvm::Value* AsInt128(DecimalIR* decimal_ir) const;
+
+ llvm::Value* high() const { return high_; }
+ llvm::Value* low() const { return low_; }
+
+ private:
+ llvm::Value* high_;
+ llvm::Value* low_;
+ };
+
+ // Add global variables to the module.
+ static void AddGlobals(Engine* engine);
+
+ // Initialize intrinsic functions that are used by decimal operations.
+ void InitializeIntrinsics();
+
+ // Create IR builder for decimal add function.
+ static Status MakeAdd(Engine* engine, std::shared_ptr<FunctionIRBuilder>* out);
+
+ // Get the multiplier for specified scale (i.e 10^scale)
+ llvm::Value* GetScaleMultiplier(llvm::Value* scale);
+
+ // Get the higher of the two scales
+ llvm::Value* GetHigherScale(llvm::Value* x_scale, llvm::Value* y_scale);
+
+ // Increase scale of 'in_value' by 'increase_scale_by'.
+ // - If 'increase_scale_by' is <= 0, does nothing.
+ llvm::Value* IncreaseScale(llvm::Value* in_value, llvm::Value* increase_scale_by);
+
+ // Similar to IncreaseScale. but, also check if there is overflow.
+ ValueWithOverflow IncreaseScaleWithOverflowCheck(llvm::Value* in_value,
+ llvm::Value* increase_scale_by);
+
+ // Reduce scale of 'in_value' by 'reduce_scale_by'.
+ // - If 'reduce_scale_by' is <= 0, does nothing.
+ llvm::Value* ReduceScale(llvm::Value* in_value, llvm::Value* reduce_scale_by);
+
+ // Fast path of add: guaranteed no overflow
+ llvm::Value* AddFastPath(const ValueFull& x, const ValueFull& y);
+
+ // Similar to AddFastPath, but check if there's an overflow.
+ ValueWithOverflow AddWithOverflowCheck(const ValueFull& x, const ValueFull& y,
+ const ValueFull& out);
+
+ // Do addition of large integers (both positive and negative).
+ llvm::Value* AddLarge(const ValueFull& x, const ValueFull& y, const ValueFull& out);
+
+ // Get the combined overflow (logical or).
+ llvm::Value* GetCombinedOverflow(std::vector<ValueWithOverflow> values);
+
+ // Build the function for adding decimals.
+ Status BuildAdd();
+
+ // Build the function for decimal subtraction.
+ Status BuildSubtract();
+
+ // Build the function for decimal multiplication.
+ Status BuildMultiply();
+
+ // Build the function for decimal division/mod.
+ Status BuildDivideOrMod(const std::string& function_name,
+ const std::string& internal_name);
+
+ Status BuildCompare(const std::string& function_name,
+ llvm::ICmpInst::Predicate cmp_instruction);
+
+ Status BuildDecimalFunction(const std::string& function_name, llvm::Type* return_type,
+ std::vector<NamedArg> in_types);
+
+ // Add a trace in IR code.
+ void AddTrace(const std::string& fmt, std::vector<llvm::Value*> args);
+
+ // Add a trace msg along with a 32-bit integer.
+ void AddTrace32(const std::string& msg, llvm::Value* value);
+
+ // Add a trace msg along with a 128-bit integer.
+ void AddTrace128(const std::string& msg, llvm::Value* value);
+
+ // name of the global variable having the array of scale multipliers.
+ static const char* kScaleMultipliersName;
+
+ // Intrinsic functions
+ llvm::Function* sadd_with_overflow_fn_;
+ llvm::Function* smul_with_overflow_fn_;
+
+ // struct { i128: value, i1: overflow}
+ llvm::Type* i128_with_overflow_struct_type_;
+
+ // if set to true, ir traces are enabled. Useful for debugging.
+ bool enable_ir_traces_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/decimal_scalar.h b/src/arrow/cpp/src/gandiva/decimal_scalar.h
new file mode 100644
index 000000000..a03807b35
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/decimal_scalar.h
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+#include "arrow/util/decimal.h"
+#include "arrow/util/hash_util.h"
+#include "gandiva/basic_decimal_scalar.h"
+
+namespace gandiva {
+
+using Decimal128 = arrow::Decimal128;
+
+/// Represents a 128-bit decimal value along with its precision and scale.
+///
+/// BasicDecimalScalar128 can be safely compiled to IR without references to libstdc++.
+/// This class has additional functionality on top of BasicDecimalScalar128 to deal with
+/// strings and streams.
+class DecimalScalar128 : public BasicDecimalScalar128 {
+ public:
+ using BasicDecimalScalar128::BasicDecimalScalar128;
+
+ DecimalScalar128(const std::string& value, int32_t precision, int32_t scale)
+ : BasicDecimalScalar128(Decimal128(value), precision, scale) {}
+
+ /// \brief constructor creates a DecimalScalar128 from a BasicDecimalScalar128.
+ constexpr DecimalScalar128(const BasicDecimalScalar128& scalar) noexcept
+ : BasicDecimalScalar128(scalar) {}
+
+ inline std::string ToString() const {
+ Decimal128 dvalue(value());
+ return dvalue.ToString(0) + "," + std::to_string(precision()) + "," +
+ std::to_string(scale());
+ }
+
+ friend std::ostream& operator<<(std::ostream& os, const DecimalScalar128& dec) {
+ os << dec.ToString();
+ return os;
+ }
+};
+
+} // namespace gandiva
+
+namespace std {
+template <>
+struct hash<gandiva::DecimalScalar128> {
+ std::size_t operator()(gandiva::DecimalScalar128 const& s) const noexcept {
+ arrow::BasicDecimal128 dvalue(s.value());
+
+ static const int kSeedValue = 4;
+ size_t result = kSeedValue;
+
+ arrow::internal::hash_combine(result, dvalue.high_bits());
+ arrow::internal::hash_combine(result, dvalue.low_bits());
+ arrow::internal::hash_combine(result, s.precision());
+ arrow::internal::hash_combine(result, s.scale());
+ return result;
+ }
+};
+} // namespace std
diff --git a/src/arrow/cpp/src/gandiva/decimal_type_util.cc b/src/arrow/cpp/src/gandiva/decimal_type_util.cc
new file mode 100644
index 000000000..2abc5a21e
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/decimal_type_util.cc
@@ -0,0 +1,75 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/decimal_type_util.h"
+#include "arrow/util/logging.h"
+
+namespace gandiva {
+
+constexpr int32_t DecimalTypeUtil::kMinAdjustedScale;
+
+#define DCHECK_TYPE(type) \
+ { \
+ DCHECK_GE(type->scale(), 0); \
+ DCHECK_LE(type->precision(), kMaxPrecision); \
+ }
+
+// Implementation of decimal rules.
+Status DecimalTypeUtil::GetResultType(Op op, const Decimal128TypeVector& in_types,
+ Decimal128TypePtr* out_type) {
+ DCHECK_EQ(in_types.size(), 2);
+
+ *out_type = nullptr;
+ auto t1 = in_types[0];
+ auto t2 = in_types[1];
+ DCHECK_TYPE(t1);
+ DCHECK_TYPE(t2);
+
+ int32_t s1 = t1->scale();
+ int32_t s2 = t2->scale();
+ int32_t p1 = t1->precision();
+ int32_t p2 = t2->precision();
+ int32_t result_scale = 0;
+ int32_t result_precision = 0;
+
+ switch (op) {
+ case kOpAdd:
+ case kOpSubtract:
+ result_scale = std::max(s1, s2);
+ result_precision = std::max(p1 - s1, p2 - s2) + result_scale + 1;
+ break;
+
+ case kOpMultiply:
+ result_scale = s1 + s2;
+ result_precision = p1 + p2 + 1;
+ break;
+
+ case kOpDivide:
+ result_scale = std::max(kMinAdjustedScale, s1 + p2 + 1);
+ result_precision = p1 - s1 + s2 + result_scale;
+ break;
+
+ case kOpMod:
+ result_scale = std::max(s1, s2);
+ result_precision = std::min(p1 - s1, p2 - s2) + result_scale;
+ break;
+ }
+ *out_type = MakeAdjustedType(result_precision, result_scale);
+ return Status::OK();
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/decimal_type_util.h b/src/arrow/cpp/src/gandiva/decimal_type_util.h
new file mode 100644
index 000000000..2b496f6cb
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/decimal_type_util.h
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Adapted from Apache Impala
+
+#pragma once
+
+#include <algorithm>
+#include <memory>
+
+#include "gandiva/arrow.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// @brief Handles conversion of scale/precision for operations on decimal types.
+/// TODO : do validations for all of these.
+class GANDIVA_EXPORT DecimalTypeUtil {
+ public:
+ enum Op {
+ kOpAdd,
+ kOpSubtract,
+ kOpMultiply,
+ kOpDivide,
+ kOpMod,
+ };
+
+ /// The maximum precision representable by a 4-byte decimal
+ static constexpr int32_t kMaxDecimal32Precision = 9;
+
+ /// The maximum precision representable by a 8-byte decimal
+ static constexpr int32_t kMaxDecimal64Precision = 18;
+
+ /// The maximum precision representable by a 16-byte decimal
+ static constexpr int32_t kMaxPrecision = 38;
+
+ // The maximum scale representable.
+ static constexpr int32_t kMaxScale = kMaxPrecision;
+
+ // When operating on decimal inputs, the integer part of the output can exceed the
+ // max precision. In such cases, the scale can be reduced, up to a minimum of
+ // kMinAdjustedScale.
+ // * There is no strong reason for 6, but both SQLServer and Impala use 6 too.
+ static constexpr int32_t kMinAdjustedScale = 6;
+
+ // For specified operation and input scale/precision, determine the output
+ // scale/precision.
+ static Status GetResultType(Op op, const Decimal128TypeVector& in_types,
+ Decimal128TypePtr* out_type);
+
+ static Decimal128TypePtr MakeType(int32_t precision, int32_t scale) {
+ return std::dynamic_pointer_cast<arrow::Decimal128Type>(
+ arrow::decimal(precision, scale));
+ }
+
+ private:
+ // Reduce the scale if possible so that precision stays <= kMaxPrecision
+ static Decimal128TypePtr MakeAdjustedType(int32_t precision, int32_t scale) {
+ if (precision > kMaxPrecision) {
+ int32_t min_scale = std::min(scale, kMinAdjustedScale);
+ int32_t delta = precision - kMaxPrecision;
+ precision = kMaxPrecision;
+ scale = std::max(scale - delta, min_scale);
+ }
+ return MakeType(precision, scale);
+ }
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/decimal_type_util_test.cc b/src/arrow/cpp/src/gandiva/decimal_type_util_test.cc
new file mode 100644
index 000000000..98ea0bb16
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/decimal_type_util_test.cc
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Adapted from Apache Impala
+
+#include <gtest/gtest.h>
+
+#include "gandiva/decimal_type_util.h"
+#include "tests/test_util.h"
+
+namespace gandiva {
+
+#define DECIMAL_TYPE(p, s) DecimalTypeUtil::MakeType(p, s)
+
+Decimal128TypePtr DoOp(DecimalTypeUtil::Op op, Decimal128TypePtr d1,
+ Decimal128TypePtr d2) {
+ Decimal128TypePtr ret_type;
+ ARROW_EXPECT_OK(DecimalTypeUtil::GetResultType(op, {d1, d2}, &ret_type));
+ return ret_type;
+}
+
+TEST(DecimalResultTypes, Basic) {
+ EXPECT_ARROW_TYPE_EQUALS(
+ DECIMAL_TYPE(31, 10),
+ DoOp(DecimalTypeUtil::kOpAdd, DECIMAL_TYPE(30, 10), DECIMAL_TYPE(30, 10)));
+
+ EXPECT_ARROW_TYPE_EQUALS(
+ DECIMAL_TYPE(32, 6),
+ DoOp(DecimalTypeUtil::kOpAdd, DECIMAL_TYPE(30, 6), DECIMAL_TYPE(30, 5)));
+
+ EXPECT_ARROW_TYPE_EQUALS(
+ DECIMAL_TYPE(38, 9),
+ DoOp(DecimalTypeUtil::kOpAdd, DECIMAL_TYPE(30, 10), DECIMAL_TYPE(38, 10)));
+
+ EXPECT_ARROW_TYPE_EQUALS(
+ DECIMAL_TYPE(38, 9),
+ DoOp(DecimalTypeUtil::kOpAdd, DECIMAL_TYPE(38, 10), DECIMAL_TYPE(38, 38)));
+
+ EXPECT_ARROW_TYPE_EQUALS(
+ DECIMAL_TYPE(38, 6),
+ DoOp(DecimalTypeUtil::kOpAdd, DECIMAL_TYPE(38, 10), DECIMAL_TYPE(38, 2)));
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/decimal_xlarge.cc b/src/arrow/cpp/src/gandiva/decimal_xlarge.cc
new file mode 100644
index 000000000..caebd8b09
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/decimal_xlarge.cc
@@ -0,0 +1,284 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Operations that can deal with very large values (256-bit).
+//
+// The intermediate results with decimal can be larger than what can fit into 128-bit,
+// but the final results can fit in 128-bit after scaling down. These functions deal
+// with operations on the intermediate values.
+//
+
+#include "gandiva/decimal_xlarge.h"
+
+#include <boost/multiprecision/cpp_int.hpp>
+#include <limits>
+#include <vector>
+
+#include "arrow/util/basic_decimal.h"
+#include "arrow/util/logging.h"
+#include "gandiva/decimal_type_util.h"
+
+#ifndef GANDIVA_UNIT_TEST
+#include "gandiva/engine.h"
+#include "gandiva/exported_funcs.h"
+
+namespace gandiva {
+
+void ExportedDecimalFunctions::AddMappings(Engine* engine) const {
+ std::vector<llvm::Type*> args;
+ auto types = engine->types();
+
+ // gdv_multiply_and_scale_down
+ args = {types->i64_type(), // int64_t x_high
+ types->i64_type(), // uint64_t x_low
+ types->i64_type(), // int64_t y_high
+ types->i64_type(), // uint64_t x_low
+ types->i32_type(), // int32_t reduce_scale_by
+ types->i64_ptr_type(), // int64_t* out_high
+ types->i64_ptr_type(), // uint64_t* out_low
+ types->i8_ptr_type()}; // bool* overflow
+
+ engine->AddGlobalMappingForFunc(
+ "gdv_xlarge_multiply_and_scale_down", types->void_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_xlarge_multiply_and_scale_down));
+
+ // gdv_xlarge_scale_up_and_divide
+ args = {types->i64_type(), // int64_t x_high
+ types->i64_type(), // uint64_t x_low
+ types->i64_type(), // int64_t y_high
+ types->i64_type(), // uint64_t y_low
+ types->i32_type(), // int32_t increase_scale_by
+ types->i64_ptr_type(), // int64_t* out_high
+ types->i64_ptr_type(), // uint64_t* out_low
+ types->i8_ptr_type()}; // bool* overflow
+
+ engine->AddGlobalMappingForFunc(
+ "gdv_xlarge_scale_up_and_divide", types->void_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_xlarge_scale_up_and_divide));
+
+ // gdv_xlarge_mod
+ args = {types->i64_type(), // int64_t x_high
+ types->i64_type(), // uint64_t x_low
+ types->i32_type(), // int32_t x_scale
+ types->i64_type(), // int64_t y_high
+ types->i64_type(), // uint64_t y_low
+ types->i32_type(), // int32_t y_scale
+ types->i64_ptr_type(), // int64_t* out_high
+ types->i64_ptr_type()}; // uint64_t* out_low
+
+ engine->AddGlobalMappingForFunc("gdv_xlarge_mod", types->void_type() /*return_type*/,
+ args, reinterpret_cast<void*>(gdv_xlarge_mod));
+
+ // gdv_xlarge_compare
+ args = {types->i64_type(), // int64_t x_high
+ types->i64_type(), // uint64_t x_low
+ types->i32_type(), // int32_t x_scale
+ types->i64_type(), // int64_t y_high
+ types->i64_type(), // uint64_t y_low
+ types->i32_type()}; // int32_t y_scale
+
+ engine->AddGlobalMappingForFunc("gdv_xlarge_compare", types->i32_type() /*return_type*/,
+ args, reinterpret_cast<void*>(gdv_xlarge_compare));
+}
+
+} // namespace gandiva
+
+#endif // !GANDIVA_UNIT_TEST
+
+using arrow::BasicDecimal128;
+using boost::multiprecision::int256_t;
+
+namespace gandiva {
+namespace internal {
+
+// Convert to 256-bit integer from 128-bit decimal.
+static int256_t ConvertToInt256(BasicDecimal128 in) {
+ int256_t v = in.high_bits();
+ v <<= 64;
+ v |= in.low_bits();
+ return v;
+}
+
+// Convert to 128-bit decimal from 256-bit integer.
+// If there is an overflow, the output is undefined.
+static BasicDecimal128 ConvertToDecimal128(int256_t in, bool* overflow) {
+ BasicDecimal128 result;
+ constexpr int256_t UINT64_MASK = std::numeric_limits<uint64_t>::max();
+
+ int256_t in_abs = abs(in);
+ bool is_negative = in < 0;
+
+ uint64_t low = (in_abs & UINT64_MASK).convert_to<uint64_t>();
+ in_abs >>= 64;
+ uint64_t high = (in_abs & UINT64_MASK).convert_to<uint64_t>();
+ in_abs >>= 64;
+
+ if (in_abs > 0) {
+ // we've shifted in by 128-bit, so nothing should be left.
+ *overflow = true;
+ } else if (high > INT64_MAX) {
+ // the high-bit must not be set (signed 128-bit).
+ *overflow = true;
+ } else {
+ result = BasicDecimal128(static_cast<int64_t>(high), low);
+ if (result > BasicDecimal128::GetMaxValue()) {
+ *overflow = true;
+ }
+ }
+ return is_negative ? -result : result;
+}
+
+static constexpr int32_t kMaxLargeScale = 2 * DecimalTypeUtil::kMaxPrecision;
+
+// Compute the scale multipliers once.
+static std::array<int256_t, kMaxLargeScale + 1> kLargeScaleMultipliers =
+ ([]() -> std::array<int256_t, kMaxLargeScale + 1> {
+ std::array<int256_t, kMaxLargeScale + 1> values;
+ values[0] = 1;
+ for (int32_t idx = 1; idx <= kMaxLargeScale; idx++) {
+ values[idx] = values[idx - 1] * 10;
+ }
+ return values;
+ })();
+
+static int256_t GetScaleMultiplier(int scale) {
+ DCHECK_GE(scale, 0);
+ DCHECK_LE(scale, kMaxLargeScale);
+
+ return kLargeScaleMultipliers[scale];
+}
+
+// divide input by 10^reduce_by, and round up the fractional part.
+static int256_t ReduceScaleBy(int256_t in, int32_t reduce_by) {
+ if (reduce_by == 0) {
+ // nothing to do.
+ return in;
+ }
+
+ int256_t divisor = GetScaleMultiplier(reduce_by);
+ DCHECK_GT(divisor, 0);
+ DCHECK_EQ(divisor % 2, 0); // multiple of 10.
+ auto result = in / divisor;
+ auto remainder = in % divisor;
+ // round up (same as BasicDecimal128::ReduceScaleBy)
+ if (abs(remainder) >= (divisor >> 1)) {
+ result += (in > 0 ? 1 : -1);
+ }
+ return result;
+}
+
+// multiply input by 10^increase_by.
+static int256_t IncreaseScaleBy(int256_t in, int32_t increase_by) {
+ DCHECK_GE(increase_by, 0);
+ DCHECK_LE(increase_by, 2 * DecimalTypeUtil::kMaxPrecision);
+
+ return in * GetScaleMultiplier(increase_by);
+}
+
+} // namespace internal
+} // namespace gandiva
+
+extern "C" {
+
+void gdv_xlarge_multiply_and_scale_down(int64_t x_high, uint64_t x_low, int64_t y_high,
+ uint64_t y_low, int32_t reduce_scale_by,
+ int64_t* out_high, uint64_t* out_low,
+ bool* overflow) {
+ BasicDecimal128 x{x_high, x_low};
+ BasicDecimal128 y{y_high, y_low};
+ auto intermediate_result =
+ gandiva::internal::ConvertToInt256(x) * gandiva::internal::ConvertToInt256(y);
+ intermediate_result =
+ gandiva::internal::ReduceScaleBy(intermediate_result, reduce_scale_by);
+ auto result = gandiva::internal::ConvertToDecimal128(intermediate_result, overflow);
+ *out_high = result.high_bits();
+ *out_low = result.low_bits();
+}
+
+void gdv_xlarge_scale_up_and_divide(int64_t x_high, uint64_t x_low, int64_t y_high,
+ uint64_t y_low, int32_t increase_scale_by,
+ int64_t* out_high, uint64_t* out_low,
+ bool* overflow) {
+ BasicDecimal128 x{x_high, x_low};
+ BasicDecimal128 y{y_high, y_low};
+
+ int256_t x_large = gandiva::internal::ConvertToInt256(x);
+ int256_t x_large_scaled_up =
+ gandiva::internal::IncreaseScaleBy(x_large, increase_scale_by);
+ int256_t y_large = gandiva::internal::ConvertToInt256(y);
+ int256_t result_large = x_large_scaled_up / y_large;
+ int256_t remainder_large = x_large_scaled_up % y_large;
+
+ // Since we are scaling up and then, scaling down, round-up the result (+1 for +ve,
+ // -1 for -ve), if the remainder is >= 2 * divisor.
+ if (abs(2 * remainder_large) >= abs(y_large)) {
+ // x +ve and y +ve, result is +ve => (1 ^ 1) + 1 = 0 + 1 = +1
+ // x +ve and y -ve, result is -ve => (-1 ^ 1) + 1 = -2 + 1 = -1
+ // x +ve and y -ve, result is -ve => (1 ^ -1) + 1 = -2 + 1 = -1
+ // x -ve and y -ve, result is +ve => (-1 ^ -1) + 1 = 0 + 1 = +1
+ result_large += (x.Sign() ^ y.Sign()) + 1;
+ }
+ auto result = gandiva::internal::ConvertToDecimal128(result_large, overflow);
+ *out_high = result.high_bits();
+ *out_low = result.low_bits();
+}
+
+void gdv_xlarge_mod(int64_t x_high, uint64_t x_low, int32_t x_scale, int64_t y_high,
+ uint64_t y_low, int32_t y_scale, int64_t* out_high,
+ uint64_t* out_low) {
+ BasicDecimal128 x{x_high, x_low};
+ BasicDecimal128 y{y_high, y_low};
+
+ int256_t x_large = gandiva::internal::ConvertToInt256(x);
+ int256_t y_large = gandiva::internal::ConvertToInt256(y);
+ if (x_scale < y_scale) {
+ x_large = gandiva::internal::IncreaseScaleBy(x_large, y_scale - x_scale);
+ } else {
+ y_large = gandiva::internal::IncreaseScaleBy(y_large, x_scale - y_scale);
+ }
+ auto intermediate_result = x_large % y_large;
+ bool overflow = false;
+ auto result = gandiva::internal::ConvertToDecimal128(intermediate_result, &overflow);
+ DCHECK_EQ(overflow, false);
+
+ *out_high = result.high_bits();
+ *out_low = result.low_bits();
+}
+
+int32_t gdv_xlarge_compare(int64_t x_high, uint64_t x_low, int32_t x_scale,
+ int64_t y_high, uint64_t y_low, int32_t y_scale) {
+ BasicDecimal128 x{x_high, x_low};
+ BasicDecimal128 y{y_high, y_low};
+
+ int256_t x_large = gandiva::internal::ConvertToInt256(x);
+ int256_t y_large = gandiva::internal::ConvertToInt256(y);
+ if (x_scale < y_scale) {
+ x_large = gandiva::internal::IncreaseScaleBy(x_large, y_scale - x_scale);
+ } else {
+ y_large = gandiva::internal::IncreaseScaleBy(y_large, x_scale - y_scale);
+ }
+
+ if (x_large == y_large) {
+ return 0;
+ } else if (x_large < y_large) {
+ return -1;
+ } else {
+ return 1;
+ }
+}
+
+} // extern "C"
diff --git a/src/arrow/cpp/src/gandiva/decimal_xlarge.h b/src/arrow/cpp/src/gandiva/decimal_xlarge.h
new file mode 100644
index 000000000..264329775
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/decimal_xlarge.h
@@ -0,0 +1,41 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+
+/// Stub functions to deal with extra large decimals that can be accessed from LLVM-IR
+/// code.
+extern "C" {
+
+void gdv_xlarge_multiply_and_scale_down(int64_t x_high, uint64_t x_low, int64_t y_high,
+ uint64_t y_low, int32_t reduce_scale_by,
+ int64_t* out_high, uint64_t* out_low,
+ bool* overflow);
+
+void gdv_xlarge_scale_up_and_divide(int64_t x_high, uint64_t x_low, int64_t y_high,
+ uint64_t y_low, int32_t increase_scale_by,
+ int64_t* out_high, uint64_t* out_low, bool* overflow);
+
+void gdv_xlarge_mod(int64_t x_high, uint64_t x_low, int32_t x_scale, int64_t y_high,
+ uint64_t y_low, int32_t y_scale, int64_t* out_high,
+ uint64_t* out_low);
+
+int32_t gdv_xlarge_compare(int64_t x_high, uint64_t x_low, int32_t x_scale,
+ int64_t y_high, uint64_t y_low, int32_t y_scale);
+}
diff --git a/src/arrow/cpp/src/gandiva/dex.h b/src/arrow/cpp/src/gandiva/dex.h
new file mode 100644
index 000000000..d1115c051
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/dex.h
@@ -0,0 +1,396 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "gandiva/dex_visitor.h"
+#include "gandiva/field_descriptor.h"
+#include "gandiva/func_descriptor.h"
+#include "gandiva/function_holder.h"
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/in_holder.h"
+#include "gandiva/literal_holder.h"
+#include "gandiva/native_function.h"
+#include "gandiva/value_validity_pair.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// \brief Decomposed expression : the validity and value are separated.
+class GANDIVA_EXPORT Dex {
+ public:
+ /// Derived classes should simply invoke the Visit api of the visitor.
+ virtual void Accept(DexVisitor& visitor) = 0;
+ virtual ~Dex() = default;
+};
+
+/// Base class for other Vector related Dex.
+class GANDIVA_EXPORT VectorReadBaseDex : public Dex {
+ public:
+ explicit VectorReadBaseDex(FieldDescriptorPtr field_desc) : field_desc_(field_desc) {}
+
+ const std::string& FieldName() const { return field_desc_->Name(); }
+
+ DataTypePtr FieldType() const { return field_desc_->Type(); }
+
+ FieldPtr Field() const { return field_desc_->field(); }
+
+ protected:
+ FieldDescriptorPtr field_desc_;
+};
+
+/// validity component of a ValueVector
+class GANDIVA_EXPORT VectorReadValidityDex : public VectorReadBaseDex {
+ public:
+ explicit VectorReadValidityDex(FieldDescriptorPtr field_desc)
+ : VectorReadBaseDex(field_desc) {}
+
+ int ValidityIdx() const { return field_desc_->validity_idx(); }
+
+ void Accept(DexVisitor& visitor) override { visitor.Visit(*this); }
+};
+
+/// value component of a fixed-len ValueVector
+class GANDIVA_EXPORT VectorReadFixedLenValueDex : public VectorReadBaseDex {
+ public:
+ explicit VectorReadFixedLenValueDex(FieldDescriptorPtr field_desc)
+ : VectorReadBaseDex(field_desc) {}
+
+ int DataIdx() const { return field_desc_->data_idx(); }
+
+ void Accept(DexVisitor& visitor) override { visitor.Visit(*this); }
+};
+
+/// value component of a variable-len ValueVector
+class GANDIVA_EXPORT VectorReadVarLenValueDex : public VectorReadBaseDex {
+ public:
+ explicit VectorReadVarLenValueDex(FieldDescriptorPtr field_desc)
+ : VectorReadBaseDex(field_desc) {}
+
+ int DataIdx() const { return field_desc_->data_idx(); }
+
+ int OffsetsIdx() const { return field_desc_->offsets_idx(); }
+
+ void Accept(DexVisitor& visitor) override { visitor.Visit(*this); }
+};
+
+/// validity based on a local bitmap.
+class GANDIVA_EXPORT LocalBitMapValidityDex : public Dex {
+ public:
+ explicit LocalBitMapValidityDex(int local_bitmap_idx)
+ : local_bitmap_idx_(local_bitmap_idx) {}
+
+ int local_bitmap_idx() const { return local_bitmap_idx_; }
+
+ void Accept(DexVisitor& visitor) override { visitor.Visit(*this); }
+
+ private:
+ int local_bitmap_idx_;
+};
+
+/// base function expression
+class GANDIVA_EXPORT FuncDex : public Dex {
+ public:
+ FuncDex(FuncDescriptorPtr func_descriptor, const NativeFunction* native_function,
+ FunctionHolderPtr function_holder, const ValueValidityPairVector& args)
+ : func_descriptor_(func_descriptor),
+ native_function_(native_function),
+ function_holder_(function_holder),
+ args_(args) {}
+
+ FuncDescriptorPtr func_descriptor() const { return func_descriptor_; }
+
+ const NativeFunction* native_function() const { return native_function_; }
+
+ FunctionHolderPtr function_holder() const { return function_holder_; }
+
+ const ValueValidityPairVector& args() const { return args_; }
+
+ private:
+ FuncDescriptorPtr func_descriptor_;
+ const NativeFunction* native_function_;
+ FunctionHolderPtr function_holder_;
+ ValueValidityPairVector args_;
+};
+
+/// A function expression that only deals with non-null inputs, and generates non-null
+/// outputs.
+class GANDIVA_EXPORT NonNullableFuncDex : public FuncDex {
+ public:
+ NonNullableFuncDex(FuncDescriptorPtr func_descriptor,
+ const NativeFunction* native_function,
+ FunctionHolderPtr function_holder,
+ const ValueValidityPairVector& args)
+ : FuncDex(func_descriptor, native_function, function_holder, args) {}
+
+ void Accept(DexVisitor& visitor) override { visitor.Visit(*this); }
+};
+
+/// A function expression that deals with nullable inputs, but generates non-null
+/// outputs.
+class GANDIVA_EXPORT NullableNeverFuncDex : public FuncDex {
+ public:
+ NullableNeverFuncDex(FuncDescriptorPtr func_descriptor,
+ const NativeFunction* native_function,
+ FunctionHolderPtr function_holder,
+ const ValueValidityPairVector& args)
+ : FuncDex(func_descriptor, native_function, function_holder, args) {}
+
+ void Accept(DexVisitor& visitor) override { visitor.Visit(*this); }
+};
+
+/// A function expression that deals with nullable inputs, and
+/// nullable outputs.
+class GANDIVA_EXPORT NullableInternalFuncDex : public FuncDex {
+ public:
+ NullableInternalFuncDex(FuncDescriptorPtr func_descriptor,
+ const NativeFunction* native_function,
+ FunctionHolderPtr function_holder,
+ const ValueValidityPairVector& args, int local_bitmap_idx)
+ : FuncDex(func_descriptor, native_function, function_holder, args),
+ local_bitmap_idx_(local_bitmap_idx) {}
+
+ void Accept(DexVisitor& visitor) override { visitor.Visit(*this); }
+
+ /// The validity of the function result is saved in this bitmap.
+ int local_bitmap_idx() const { return local_bitmap_idx_; }
+
+ private:
+ int local_bitmap_idx_;
+};
+
+/// special validity type that always returns true.
+class GANDIVA_EXPORT TrueDex : public Dex {
+ void Accept(DexVisitor& visitor) override { visitor.Visit(*this); }
+};
+
+/// special validity type that always returns false.
+class GANDIVA_EXPORT FalseDex : public Dex {
+ void Accept(DexVisitor& visitor) override { visitor.Visit(*this); }
+};
+
+/// decomposed expression for a literal.
+class GANDIVA_EXPORT LiteralDex : public Dex {
+ public:
+ LiteralDex(DataTypePtr type, const LiteralHolder& holder)
+ : type_(type), holder_(holder) {}
+
+ const DataTypePtr& type() const { return type_; }
+
+ const LiteralHolder& holder() const { return holder_; }
+
+ void Accept(DexVisitor& visitor) override { visitor.Visit(*this); }
+
+ private:
+ DataTypePtr type_;
+ LiteralHolder holder_;
+};
+
+/// decomposed if-else expression.
+class GANDIVA_EXPORT IfDex : public Dex {
+ public:
+ IfDex(ValueValidityPairPtr condition_vv, ValueValidityPairPtr then_vv,
+ ValueValidityPairPtr else_vv, DataTypePtr result_type, int local_bitmap_idx,
+ bool is_terminal_else)
+ : condition_vv_(condition_vv),
+ then_vv_(then_vv),
+ else_vv_(else_vv),
+ result_type_(result_type),
+ local_bitmap_idx_(local_bitmap_idx),
+ is_terminal_else_(is_terminal_else) {}
+
+ void Accept(DexVisitor& visitor) override { visitor.Visit(*this); }
+
+ const ValueValidityPair& condition_vv() const { return *condition_vv_; }
+ const ValueValidityPair& then_vv() const { return *then_vv_; }
+ const ValueValidityPair& else_vv() const { return *else_vv_; }
+
+ /// The validity of the result is saved in this bitmap.
+ int local_bitmap_idx() const { return local_bitmap_idx_; }
+
+ /// is this a terminal else ? i.e no nested if-else underneath.
+ bool is_terminal_else() const { return is_terminal_else_; }
+
+ const DataTypePtr& result_type() const { return result_type_; }
+
+ private:
+ ValueValidityPairPtr condition_vv_;
+ ValueValidityPairPtr then_vv_;
+ ValueValidityPairPtr else_vv_;
+ DataTypePtr result_type_;
+ int local_bitmap_idx_;
+ bool is_terminal_else_;
+};
+
+// decomposed boolean expression.
+class GANDIVA_EXPORT BooleanDex : public Dex {
+ public:
+ BooleanDex(const ValueValidityPairVector& args, int local_bitmap_idx)
+ : args_(args), local_bitmap_idx_(local_bitmap_idx) {}
+
+ const ValueValidityPairVector& args() const { return args_; }
+
+ /// The validity of the result is saved in this bitmap.
+ int local_bitmap_idx() const { return local_bitmap_idx_; }
+
+ private:
+ ValueValidityPairVector args_;
+ int local_bitmap_idx_;
+};
+
+/// Boolean-AND expression
+class GANDIVA_EXPORT BooleanAndDex : public BooleanDex {
+ public:
+ BooleanAndDex(const ValueValidityPairVector& args, int local_bitmap_idx)
+ : BooleanDex(args, local_bitmap_idx) {}
+
+ void Accept(DexVisitor& visitor) override { visitor.Visit(*this); }
+};
+
+/// Boolean-OR expression
+class GANDIVA_EXPORT BooleanOrDex : public BooleanDex {
+ public:
+ BooleanOrDex(const ValueValidityPairVector& args, int local_bitmap_idx)
+ : BooleanDex(args, local_bitmap_idx) {}
+
+ void Accept(DexVisitor& visitor) override { visitor.Visit(*this); }
+};
+
+// decomposed in expression.
+template <typename Type>
+class InExprDex;
+
+template <typename Type>
+class InExprDexBase : public Dex {
+ public:
+ InExprDexBase(const ValueValidityPairVector& args,
+ const std::unordered_set<Type>& values)
+ : args_(args) {
+ in_holder_.reset(new InHolder<Type>(values));
+ }
+
+ const ValueValidityPairVector& args() const { return args_; }
+
+ void Accept(DexVisitor& visitor) override { visitor.Visit(*this); }
+
+ const std::string& runtime_function() const { return runtime_function_; }
+
+ const std::shared_ptr<InHolder<Type>>& in_holder() const { return in_holder_; }
+
+ protected:
+ ValueValidityPairVector args_;
+ std::string runtime_function_;
+ std::shared_ptr<InHolder<Type>> in_holder_;
+};
+
+template <>
+class InExprDexBase<gandiva::DecimalScalar128> : public Dex {
+ public:
+ InExprDexBase(const ValueValidityPairVector& args,
+ const std::unordered_set<gandiva::DecimalScalar128>& values,
+ int32_t precision, int32_t scale)
+ : args_(args), precision_(precision), scale_(scale) {
+ in_holder_.reset(new InHolder<gandiva::DecimalScalar128>(values));
+ }
+
+ int32_t get_precision() const { return precision_; }
+
+ int32_t get_scale() const { return scale_; }
+
+ const ValueValidityPairVector& args() const { return args_; }
+
+ void Accept(DexVisitor& visitor) override { visitor.Visit(*this); }
+
+ const std::string& runtime_function() const { return runtime_function_; }
+
+ const std::shared_ptr<InHolder<gandiva::DecimalScalar128>>& in_holder() const {
+ return in_holder_;
+ }
+
+ protected:
+ ValueValidityPairVector args_;
+ std::string runtime_function_;
+ std::shared_ptr<InHolder<gandiva::DecimalScalar128>> in_holder_;
+ int32_t precision_, scale_;
+};
+
+template <>
+class InExprDex<int32_t> : public InExprDexBase<int32_t> {
+ public:
+ InExprDex(const ValueValidityPairVector& args,
+ const std::unordered_set<int32_t>& values)
+ : InExprDexBase(args, values) {
+ runtime_function_ = "gdv_fn_in_expr_lookup_int32";
+ }
+};
+
+template <>
+class InExprDex<int64_t> : public InExprDexBase<int64_t> {
+ public:
+ InExprDex(const ValueValidityPairVector& args,
+ const std::unordered_set<int64_t>& values)
+ : InExprDexBase(args, values) {
+ runtime_function_ = "gdv_fn_in_expr_lookup_int64";
+ }
+};
+
+template <>
+class InExprDex<float> : public InExprDexBase<float> {
+ public:
+ InExprDex(const ValueValidityPairVector& args, const std::unordered_set<float>& values)
+ : InExprDexBase(args, values) {
+ runtime_function_ = "gdv_fn_in_expr_lookup_float";
+ }
+};
+
+template <>
+class InExprDex<double> : public InExprDexBase<double> {
+ public:
+ InExprDex(const ValueValidityPairVector& args, const std::unordered_set<double>& values)
+ : InExprDexBase(args, values) {
+ runtime_function_ = "gdv_fn_in_expr_lookup_double";
+ }
+};
+
+template <>
+class InExprDex<gandiva::DecimalScalar128>
+ : public InExprDexBase<gandiva::DecimalScalar128> {
+ public:
+ InExprDex(const ValueValidityPairVector& args,
+ const std::unordered_set<gandiva::DecimalScalar128>& values,
+ int32_t precision, int32_t scale)
+ : InExprDexBase<gandiva::DecimalScalar128>(args, values, precision, scale) {
+ runtime_function_ = "gdv_fn_in_expr_lookup_decimal";
+ }
+};
+
+template <>
+class InExprDex<std::string> : public InExprDexBase<std::string> {
+ public:
+ InExprDex(const ValueValidityPairVector& args,
+ const std::unordered_set<std::string>& values)
+ : InExprDexBase(args, values) {
+ runtime_function_ = "gdv_fn_in_expr_lookup_utf8";
+ }
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/dex_visitor.h b/src/arrow/cpp/src/gandiva/dex_visitor.h
new file mode 100644
index 000000000..5d160bb22
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/dex_visitor.h
@@ -0,0 +1,97 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cmath>
+#include <string>
+
+#include "arrow/util/logging.h"
+#include "gandiva/decimal_scalar.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+class VectorReadValidityDex;
+class VectorReadFixedLenValueDex;
+class VectorReadVarLenValueDex;
+class LocalBitMapValidityDex;
+class LiteralDex;
+class TrueDex;
+class FalseDex;
+class NonNullableFuncDex;
+class NullableNeverFuncDex;
+class NullableInternalFuncDex;
+class IfDex;
+class BooleanAndDex;
+class BooleanOrDex;
+template <typename Type>
+class InExprDexBase;
+
+/// \brief Visitor for decomposed expression.
+class GANDIVA_EXPORT DexVisitor {
+ public:
+ virtual ~DexVisitor() = default;
+
+ virtual void Visit(const VectorReadValidityDex& dex) = 0;
+ virtual void Visit(const VectorReadFixedLenValueDex& dex) = 0;
+ virtual void Visit(const VectorReadVarLenValueDex& dex) = 0;
+ virtual void Visit(const LocalBitMapValidityDex& dex) = 0;
+ virtual void Visit(const TrueDex& dex) = 0;
+ virtual void Visit(const FalseDex& dex) = 0;
+ virtual void Visit(const LiteralDex& dex) = 0;
+ virtual void Visit(const NonNullableFuncDex& dex) = 0;
+ virtual void Visit(const NullableNeverFuncDex& dex) = 0;
+ virtual void Visit(const NullableInternalFuncDex& dex) = 0;
+ virtual void Visit(const IfDex& dex) = 0;
+ virtual void Visit(const BooleanAndDex& dex) = 0;
+ virtual void Visit(const BooleanOrDex& dex) = 0;
+ virtual void Visit(const InExprDexBase<int32_t>& dex) = 0;
+ virtual void Visit(const InExprDexBase<int64_t>& dex) = 0;
+ virtual void Visit(const InExprDexBase<float>& dex) = 0;
+ virtual void Visit(const InExprDexBase<double>& dex) = 0;
+ virtual void Visit(const InExprDexBase<gandiva::DecimalScalar128>& dex) = 0;
+ virtual void Visit(const InExprDexBase<std::string>& dex) = 0;
+};
+
+/// Default implementation with only DCHECK().
+#define VISIT_DCHECK(DEX_CLASS) \
+ void Visit(const DEX_CLASS& dex) override { DCHECK(0); }
+
+class GANDIVA_EXPORT DexDefaultVisitor : public DexVisitor {
+ VISIT_DCHECK(VectorReadValidityDex)
+ VISIT_DCHECK(VectorReadFixedLenValueDex)
+ VISIT_DCHECK(VectorReadVarLenValueDex)
+ VISIT_DCHECK(LocalBitMapValidityDex)
+ VISIT_DCHECK(TrueDex)
+ VISIT_DCHECK(FalseDex)
+ VISIT_DCHECK(LiteralDex)
+ VISIT_DCHECK(NonNullableFuncDex)
+ VISIT_DCHECK(NullableNeverFuncDex)
+ VISIT_DCHECK(NullableInternalFuncDex)
+ VISIT_DCHECK(IfDex)
+ VISIT_DCHECK(BooleanAndDex)
+ VISIT_DCHECK(BooleanOrDex)
+ VISIT_DCHECK(InExprDexBase<int32_t>)
+ VISIT_DCHECK(InExprDexBase<int64_t>)
+ VISIT_DCHECK(InExprDexBase<float>)
+ VISIT_DCHECK(InExprDexBase<double>)
+ VISIT_DCHECK(InExprDexBase<gandiva::DecimalScalar128>)
+ VISIT_DCHECK(InExprDexBase<std::string>)
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/engine.cc b/src/arrow/cpp/src/gandiva/engine.cc
new file mode 100644
index 000000000..f0b768f5f
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/engine.cc
@@ -0,0 +1,338 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// TODO(wesm): LLVM 7 produces pesky C4244 that disable pragmas around the LLVM
+// includes seem to not fix as with LLVM 6
+#if defined(_MSC_VER)
+#pragma warning(disable : 4244)
+#endif
+
+#include "gandiva/engine.h"
+
+#include <iostream>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <unordered_set>
+#include <utility>
+
+#include "arrow/util/logging.h"
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4141)
+#pragma warning(disable : 4146)
+#pragma warning(disable : 4244)
+#pragma warning(disable : 4267)
+#pragma warning(disable : 4624)
+#endif
+
+#include <llvm/Analysis/Passes.h>
+#include <llvm/Analysis/TargetTransformInfo.h>
+#include <llvm/Bitcode/BitcodeReader.h>
+#include <llvm/ExecutionEngine/ExecutionEngine.h>
+#include <llvm/ExecutionEngine/MCJIT.h>
+#include <llvm/IR/DataLayout.h>
+#include <llvm/IR/IRBuilder.h>
+#include <llvm/IR/LLVMContext.h>
+#include <llvm/IR/LegacyPassManager.h>
+#include <llvm/IR/Verifier.h>
+#include <llvm/Linker/Linker.h>
+#include <llvm/MC/SubtargetFeature.h>
+#include <llvm/Support/DynamicLibrary.h>
+#include <llvm/Support/Host.h>
+#include <llvm/Support/TargetRegistry.h>
+#include <llvm/Support/TargetSelect.h>
+#include <llvm/Support/raw_ostream.h>
+#include <llvm/Transforms/IPO.h>
+#include <llvm/Transforms/IPO/PassManagerBuilder.h>
+#include <llvm/Transforms/InstCombine/InstCombine.h>
+#include <llvm/Transforms/Scalar.h>
+#include <llvm/Transforms/Scalar/GVN.h>
+#include <llvm/Transforms/Utils.h>
+#include <llvm/Transforms/Vectorize.h>
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+#include "arrow/util/make_unique.h"
+#include "gandiva/configuration.h"
+#include "gandiva/decimal_ir.h"
+#include "gandiva/exported_funcs_registry.h"
+
+namespace gandiva {
+
+extern const unsigned char kPrecompiledBitcode[];
+extern const size_t kPrecompiledBitcodeSize;
+
+std::once_flag llvm_init_once_flag;
+static bool llvm_init = false;
+static llvm::StringRef cpu_name;
+static llvm::SmallVector<std::string, 10> cpu_attrs;
+
+void Engine::InitOnce() {
+ DCHECK_EQ(llvm_init, false);
+
+ llvm::InitializeNativeTarget();
+ llvm::InitializeNativeTargetAsmPrinter();
+ llvm::InitializeNativeTargetAsmParser();
+ llvm::InitializeNativeTargetDisassembler();
+ llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
+
+ cpu_name = llvm::sys::getHostCPUName();
+ llvm::StringMap<bool> host_features;
+ std::string cpu_attrs_str;
+ if (llvm::sys::getHostCPUFeatures(host_features)) {
+ for (auto& f : host_features) {
+ std::string attr = f.second ? std::string("+") + f.first().str()
+ : std::string("-") + f.first().str();
+ cpu_attrs.push_back(attr);
+ cpu_attrs_str += " " + attr;
+ }
+ }
+ ARROW_LOG(INFO) << "Detected CPU Name : " << cpu_name.str();
+ ARROW_LOG(INFO) << "Detected CPU Features:" << cpu_attrs_str;
+ llvm_init = true;
+}
+
+Engine::Engine(const std::shared_ptr<Configuration>& conf,
+ std::unique_ptr<llvm::LLVMContext> ctx,
+ std::unique_ptr<llvm::ExecutionEngine> engine, llvm::Module* module)
+ : context_(std::move(ctx)),
+ execution_engine_(std::move(engine)),
+ ir_builder_(arrow::internal::make_unique<llvm::IRBuilder<>>(*context_)),
+ module_(module),
+ types_(*context_),
+ optimize_(conf->optimize()) {}
+
+Status Engine::Init() {
+ // Add mappings for functions that can be accessed from LLVM/IR module.
+ AddGlobalMappings();
+
+ ARROW_RETURN_NOT_OK(LoadPreCompiledIR());
+ ARROW_RETURN_NOT_OK(DecimalIR::AddFunctions(this));
+
+ return Status::OK();
+}
+
+/// factory method to construct the engine.
+Status Engine::Make(const std::shared_ptr<Configuration>& conf,
+ std::unique_ptr<Engine>* out) {
+ std::call_once(llvm_init_once_flag, InitOnce);
+
+ auto ctx = arrow::internal::make_unique<llvm::LLVMContext>();
+ auto module = arrow::internal::make_unique<llvm::Module>("codegen", *ctx);
+
+ // Capture before moving, ExecutionEngine does not allow retrieving the
+ // original Module.
+ auto module_ptr = module.get();
+
+ auto opt_level =
+ conf->optimize() ? llvm::CodeGenOpt::Aggressive : llvm::CodeGenOpt::None;
+
+ // Note that the lifetime of the error string is not captured by the
+ // ExecutionEngine but only for the lifetime of the builder. Found by
+ // inspecting LLVM sources.
+ std::string builder_error;
+
+ llvm::EngineBuilder engine_builder(std::move(module));
+
+ engine_builder.setEngineKind(llvm::EngineKind::JIT)
+ .setOptLevel(opt_level)
+ .setErrorStr(&builder_error);
+
+ if (conf->target_host_cpu()) {
+ engine_builder.setMCPU(cpu_name);
+ engine_builder.setMAttrs(cpu_attrs);
+ }
+ std::unique_ptr<llvm::ExecutionEngine> exec_engine{engine_builder.create()};
+
+ if (exec_engine == nullptr) {
+ return Status::CodeGenError("Could not instantiate llvm::ExecutionEngine: ",
+ builder_error);
+ }
+
+ std::unique_ptr<Engine> engine{
+ new Engine(conf, std::move(ctx), std::move(exec_engine), module_ptr)};
+ ARROW_RETURN_NOT_OK(engine->Init());
+ *out = std::move(engine);
+ return Status::OK();
+}
+
+// This method was modified from its original version for a part of MLIR
+// Original source from
+// https://github.com/llvm/llvm-project/blob/9f2ce5b915a505a5488a5cf91bb0a8efa9ddfff7/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
+// The original copyright notice follows.
+
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+static void SetDataLayout(llvm::Module* module) {
+ auto target_triple = llvm::sys::getDefaultTargetTriple();
+ std::string error_message;
+ auto target = llvm::TargetRegistry::lookupTarget(target_triple, error_message);
+ if (!target) {
+ return;
+ }
+
+ std::string cpu(llvm::sys::getHostCPUName());
+ llvm::SubtargetFeatures features;
+ llvm::StringMap<bool> host_features;
+
+ if (llvm::sys::getHostCPUFeatures(host_features)) {
+ for (auto& f : host_features) {
+ features.AddFeature(f.first(), f.second);
+ }
+ }
+
+ std::unique_ptr<llvm::TargetMachine> machine(
+ target->createTargetMachine(target_triple, cpu, features.getString(), {}, {}));
+
+ module->setDataLayout(machine->createDataLayout());
+}
+// end of the mofified method from MLIR
+
+// Handling for pre-compiled IR libraries.
+Status Engine::LoadPreCompiledIR() {
+ auto bitcode = llvm::StringRef(reinterpret_cast<const char*>(kPrecompiledBitcode),
+ kPrecompiledBitcodeSize);
+
+ /// Read from file into memory buffer.
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> buffer_or_error =
+ llvm::MemoryBuffer::getMemBuffer(bitcode, "precompiled", false);
+
+ ARROW_RETURN_IF(!buffer_or_error,
+ Status::CodeGenError("Could not load module from IR: ",
+ buffer_or_error.getError().message()));
+
+ std::unique_ptr<llvm::MemoryBuffer> buffer = move(buffer_or_error.get());
+
+ /// Parse the IR module.
+ llvm::Expected<std::unique_ptr<llvm::Module>> module_or_error =
+ llvm::getOwningLazyBitcodeModule(move(buffer), *context());
+ if (!module_or_error) {
+ // NOTE: llvm::handleAllErrors() fails linking with RTTI-disabled LLVM builds
+ // (ARROW-5148)
+ std::string str;
+ llvm::raw_string_ostream stream(str);
+ stream << module_or_error.takeError();
+ return Status::CodeGenError(stream.str());
+ }
+ std::unique_ptr<llvm::Module> ir_module = move(module_or_error.get());
+
+ // set dataLayout
+ SetDataLayout(ir_module.get());
+
+ ARROW_RETURN_IF(llvm::verifyModule(*ir_module, &llvm::errs()),
+ Status::CodeGenError("verify of IR Module failed"));
+ ARROW_RETURN_IF(llvm::Linker::linkModules(*module_, move(ir_module)),
+ Status::CodeGenError("failed to link IR Modules"));
+
+ return Status::OK();
+}
+
+// Get rid of all functions that don't need to be compiled.
+// This helps in reducing the overall compilation time. This pass is trivial,
+// and is always done since the number of functions in gandiva is very high.
+// (Adapted from Apache Impala)
+//
+// Done by marking all the unused functions as internal, and then, running
+// a pass for dead code elimination.
+Status Engine::RemoveUnusedFunctions() {
+ // Setup an optimiser pipeline
+ std::unique_ptr<llvm::legacy::PassManager> pass_manager(
+ new llvm::legacy::PassManager());
+
+ std::unordered_set<std::string> used_functions;
+ used_functions.insert(functions_to_compile_.begin(), functions_to_compile_.end());
+
+ pass_manager->add(
+ llvm::createInternalizePass([&used_functions](const llvm::GlobalValue& func) {
+ return (used_functions.find(func.getName().str()) != used_functions.end());
+ }));
+ pass_manager->add(llvm::createGlobalDCEPass());
+ pass_manager->run(*module_);
+ return Status::OK();
+}
+
+// Optimise and compile the module.
+Status Engine::FinalizeModule() {
+ ARROW_RETURN_NOT_OK(RemoveUnusedFunctions());
+
+ if (optimize_) {
+ // misc passes to allow for inlining, vectorization, ..
+ std::unique_ptr<llvm::legacy::PassManager> pass_manager(
+ new llvm::legacy::PassManager());
+
+ llvm::TargetIRAnalysis target_analysis =
+ execution_engine_->getTargetMachine()->getTargetIRAnalysis();
+ pass_manager->add(llvm::createTargetTransformInfoWrapperPass(target_analysis));
+ pass_manager->add(llvm::createFunctionInliningPass());
+ pass_manager->add(llvm::createInstructionCombiningPass());
+ pass_manager->add(llvm::createPromoteMemoryToRegisterPass());
+ pass_manager->add(llvm::createGVNPass());
+ pass_manager->add(llvm::createNewGVNPass());
+ pass_manager->add(llvm::createCFGSimplificationPass());
+ pass_manager->add(llvm::createLoopVectorizePass());
+ pass_manager->add(llvm::createSLPVectorizerPass());
+ pass_manager->add(llvm::createGlobalOptimizerPass());
+
+ // run the optimiser
+ llvm::PassManagerBuilder pass_builder;
+ pass_builder.OptLevel = 3;
+ pass_builder.populateModulePassManager(*pass_manager);
+ pass_manager->run(*module_);
+ }
+
+ ARROW_RETURN_IF(llvm::verifyModule(*module_, &llvm::errs()),
+ Status::CodeGenError("Module verification failed after optimizer"));
+
+ // do the compilation
+ execution_engine_->finalizeObject();
+ module_finalized_ = true;
+
+ return Status::OK();
+}
+
+void* Engine::CompiledFunction(llvm::Function* irFunction) {
+ DCHECK(module_finalized_);
+ return execution_engine_->getPointerToFunction(irFunction);
+}
+
+void Engine::AddGlobalMappingForFunc(const std::string& name, llvm::Type* ret_type,
+ const std::vector<llvm::Type*>& args,
+ void* function_ptr) {
+ constexpr bool is_var_arg = false;
+ auto prototype = llvm::FunctionType::get(ret_type, args, is_var_arg);
+ constexpr auto linkage = llvm::GlobalValue::ExternalLinkage;
+ auto fn = llvm::Function::Create(prototype, linkage, name, module());
+ execution_engine_->addGlobalMapping(fn, function_ptr);
+}
+
+void Engine::AddGlobalMappings() { ExportedFuncsRegistry::AddMappings(this); }
+
+std::string Engine::DumpIR() {
+ std::string ir;
+ llvm::raw_string_ostream stream(ir);
+ module_->print(stream, nullptr);
+ return ir;
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/engine.h b/src/arrow/cpp/src/gandiva/engine.h
new file mode 100644
index 000000000..d26b8aa0e
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/engine.h
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "arrow/util/macros.h"
+
+#include "arrow/util/logging.h"
+#include "gandiva/configuration.h"
+#include "gandiva/llvm_includes.h"
+#include "gandiva/llvm_types.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// \brief LLVM Execution engine wrapper.
+class GANDIVA_EXPORT Engine {
+ public:
+ llvm::LLVMContext* context() { return context_.get(); }
+ llvm::IRBuilder<>* ir_builder() { return ir_builder_.get(); }
+ LLVMTypes* types() { return &types_; }
+ llvm::Module* module() { return module_; }
+
+ /// Factory method to create and initialize the engine object.
+ ///
+ /// \param[in] config the engine configuration
+ /// \param[out] engine the created engine
+ static Status Make(const std::shared_ptr<Configuration>& config,
+ std::unique_ptr<Engine>* engine);
+
+ /// Add the function to the list of IR functions that need to be compiled.
+ /// Compiling only the functions that are used by the module saves time.
+ void AddFunctionToCompile(const std::string& fname) {
+ DCHECK(!module_finalized_);
+ functions_to_compile_.push_back(fname);
+ }
+
+ /// Optimise and compile the module.
+ Status FinalizeModule();
+
+ /// Get the compiled function corresponding to the irfunction.
+ void* CompiledFunction(llvm::Function* irFunction);
+
+ // Create and add a mapping for the cpp function to make it accessible from LLVM.
+ void AddGlobalMappingForFunc(const std::string& name, llvm::Type* ret_type,
+ const std::vector<llvm::Type*>& args, void* func);
+
+ /// Return the generated IR for the module.
+ std::string DumpIR();
+
+ private:
+ Engine(const std::shared_ptr<Configuration>& conf,
+ std::unique_ptr<llvm::LLVMContext> ctx,
+ std::unique_ptr<llvm::ExecutionEngine> engine, llvm::Module* module);
+
+ // Post construction init. This _must_ be called after the constructor.
+ Status Init();
+
+ static void InitOnce();
+
+ llvm::ExecutionEngine& execution_engine() { return *execution_engine_; }
+
+ /// load pre-compiled IR modules from precompiled_bitcode.cc and merge them into
+ /// the main module.
+ Status LoadPreCompiledIR();
+
+ // Create and add mappings for cpp functions that can be accessed from LLVM.
+ void AddGlobalMappings();
+
+ // Remove unused functions to reduce compile time.
+ Status RemoveUnusedFunctions();
+
+ std::unique_ptr<llvm::LLVMContext> context_;
+ std::unique_ptr<llvm::ExecutionEngine> execution_engine_;
+ std::unique_ptr<llvm::IRBuilder<>> ir_builder_;
+ llvm::Module* module_;
+ LLVMTypes types_;
+
+ std::vector<std::string> functions_to_compile_;
+
+ bool optimize_ = true;
+ bool module_finalized_ = false;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/engine_llvm_test.cc b/src/arrow/cpp/src/gandiva/engine_llvm_test.cc
new file mode 100644
index 000000000..ef2275b34
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/engine_llvm_test.cc
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/engine.h"
+
+#include <gtest/gtest.h>
+#include <functional>
+#include "gandiva/llvm_types.h"
+#include "gandiva/tests/test_util.h"
+
+namespace gandiva {
+
+typedef int64_t (*add_vector_func_t)(int64_t* data, int n);
+
+class TestEngine : public ::testing::Test {
+ protected:
+ llvm::Function* BuildVecAdd(Engine* engine) {
+ auto types = engine->types();
+ llvm::IRBuilder<>* builder = engine->ir_builder();
+ llvm::LLVMContext* context = engine->context();
+
+ // Create fn prototype :
+ // int64_t add_longs(int64_t *elements, int32_t nelements)
+ std::vector<llvm::Type*> arguments;
+ arguments.push_back(types->i64_ptr_type());
+ arguments.push_back(types->i32_type());
+ llvm::FunctionType* prototype =
+ llvm::FunctionType::get(types->i64_type(), arguments, false /*isVarArg*/);
+
+ // Create fn
+ std::string func_name = "add_longs";
+ engine->AddFunctionToCompile(func_name);
+ llvm::Function* fn = llvm::Function::Create(
+ prototype, llvm::GlobalValue::ExternalLinkage, func_name, engine->module());
+ assert(fn != nullptr);
+
+ // Name the arguments
+ llvm::Function::arg_iterator args = fn->arg_begin();
+ llvm::Value* arg_elements = &*args;
+ arg_elements->setName("elements");
+ ++args;
+ llvm::Value* arg_nelements = &*args;
+ arg_nelements->setName("nelements");
+ ++args;
+
+ llvm::BasicBlock* loop_entry = llvm::BasicBlock::Create(*context, "entry", fn);
+ llvm::BasicBlock* loop_body = llvm::BasicBlock::Create(*context, "loop", fn);
+ llvm::BasicBlock* loop_exit = llvm::BasicBlock::Create(*context, "exit", fn);
+
+ // Loop entry
+ builder->SetInsertPoint(loop_entry);
+ builder->CreateBr(loop_body);
+
+ // Loop body
+ builder->SetInsertPoint(loop_body);
+
+ llvm::PHINode* loop_var = builder->CreatePHI(types->i32_type(), 2, "loop_var");
+ llvm::PHINode* sum = builder->CreatePHI(types->i64_type(), 2, "sum");
+
+ loop_var->addIncoming(types->i32_constant(0), loop_entry);
+ sum->addIncoming(types->i64_constant(0), loop_entry);
+
+ // setup loop PHI
+ llvm::Value* loop_update =
+ builder->CreateAdd(loop_var, types->i32_constant(1), "loop_var+1");
+ loop_var->addIncoming(loop_update, loop_body);
+
+ // get the current value
+ llvm::Value* offset = CreateGEP(builder, arg_elements, loop_var, "offset");
+ llvm::Value* current_value = CreateLoad(builder, offset, "value");
+
+ // setup sum PHI
+ llvm::Value* sum_update = builder->CreateAdd(sum, current_value, "sum+ith");
+ sum->addIncoming(sum_update, loop_body);
+
+ // check loop_var
+ llvm::Value* loop_var_check =
+ builder->CreateICmpSLT(loop_update, arg_nelements, "loop_var < nrec");
+ builder->CreateCondBr(loop_var_check, loop_body, loop_exit);
+
+ // Loop exit
+ builder->SetInsertPoint(loop_exit);
+ builder->CreateRet(sum_update);
+ return fn;
+ }
+
+ void BuildEngine() { ASSERT_OK(Engine::Make(TestConfiguration(), &engine)); }
+
+ std::unique_ptr<Engine> engine;
+ std::shared_ptr<Configuration> configuration = TestConfiguration();
+};
+
+TEST_F(TestEngine, TestAddUnoptimised) {
+ configuration->set_optimize(false);
+ BuildEngine();
+
+ llvm::Function* ir_func = BuildVecAdd(engine.get());
+ ASSERT_OK(engine->FinalizeModule());
+ auto add_func = reinterpret_cast<add_vector_func_t>(engine->CompiledFunction(ir_func));
+
+ int64_t my_array[] = {1, 3, -5, 8, 10};
+ EXPECT_EQ(add_func(my_array, 5), 17);
+}
+
+TEST_F(TestEngine, TestAddOptimised) {
+ configuration->set_optimize(true);
+ BuildEngine();
+
+ llvm::Function* ir_func = BuildVecAdd(engine.get());
+ ASSERT_OK(engine->FinalizeModule());
+ auto add_func = reinterpret_cast<add_vector_func_t>(engine->CompiledFunction(ir_func));
+
+ int64_t my_array[] = {1, 3, -5, 8, 10};
+ EXPECT_EQ(add_func(my_array, 5), 17);
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/eval_batch.h b/src/arrow/cpp/src/gandiva/eval_batch.h
new file mode 100644
index 000000000..25d9ab1d9
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/eval_batch.h
@@ -0,0 +1,107 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/util/logging.h"
+
+#include "gandiva/arrow.h"
+#include "gandiva/execution_context.h"
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/local_bitmaps_holder.h"
+
+namespace gandiva {
+
+/// \brief The buffers corresponding to one batch of records, used for
+/// expression evaluation.
+class EvalBatch {
+ public:
+ explicit EvalBatch(int64_t num_records, int num_buffers, int num_local_bitmaps)
+ : num_records_(num_records), num_buffers_(num_buffers) {
+ if (num_buffers > 0) {
+ buffers_array_.reset(new uint8_t*[num_buffers]);
+ buffer_offsets_array_.reset(new int64_t[num_buffers]);
+ }
+ local_bitmaps_holder_.reset(new LocalBitMapsHolder(num_records, num_local_bitmaps));
+ execution_context_.reset(new ExecutionContext());
+ }
+
+ int64_t num_records() const { return num_records_; }
+
+ uint8_t** GetBufferArray() const { return buffers_array_.get(); }
+
+ int64_t* GetBufferOffsetArray() const { return buffer_offsets_array_.get(); }
+
+ int GetNumBuffers() const { return num_buffers_; }
+
+ uint8_t* GetBuffer(int idx) const {
+ DCHECK(idx <= num_buffers_);
+ return (buffers_array_.get())[idx];
+ }
+
+ int64_t GetBufferOffset(int idx) const {
+ DCHECK(idx <= num_buffers_);
+ return (buffer_offsets_array_.get())[idx];
+ }
+
+ void SetBuffer(int idx, uint8_t* buffer, int64_t offset) {
+ DCHECK(idx <= num_buffers_);
+ (buffers_array_.get())[idx] = buffer;
+ (buffer_offsets_array_.get())[idx] = offset;
+ }
+
+ int GetNumLocalBitMaps() const { return local_bitmaps_holder_->GetNumLocalBitMaps(); }
+
+ int64_t GetLocalBitmapSize() const {
+ return local_bitmaps_holder_->GetLocalBitMapSize();
+ }
+
+ uint8_t* GetLocalBitMap(int idx) const {
+ DCHECK(idx <= GetNumLocalBitMaps());
+ return local_bitmaps_holder_->GetLocalBitMap(idx);
+ }
+
+ uint8_t** GetLocalBitMapArray() const {
+ return local_bitmaps_holder_->GetLocalBitMapArray();
+ }
+
+ ExecutionContext* GetExecutionContext() const { return execution_context_.get(); }
+
+ private:
+ /// number of records in the current batch.
+ int64_t num_records_;
+
+ // number of buffers.
+ int num_buffers_;
+
+ /// An array of 'num_buffers_', each containing a buffer. The buffer
+ /// sizes depends on the data type, but all of them have the same
+ /// number of slots (equal to num_records_).
+ std::unique_ptr<uint8_t*[]> buffers_array_;
+
+ /// An array of 'num_buffers_', each containing the offset for
+ /// corresponding buffer.
+ std::unique_ptr<int64_t[]> buffer_offsets_array_;
+
+ std::unique_ptr<LocalBitMapsHolder> local_bitmaps_holder_;
+
+ std::unique_ptr<ExecutionContext> execution_context_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/execution_context.h b/src/arrow/cpp/src/gandiva/execution_context.h
new file mode 100644
index 000000000..efa546874
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/execution_context.h
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "gandiva/simple_arena.h"
+
+namespace gandiva {
+
+/// Execution context during llvm evaluation
+class ExecutionContext {
+ public:
+ explicit ExecutionContext(arrow::MemoryPool* pool = arrow::default_memory_pool())
+ : arena_(pool) {}
+ std::string get_error() const { return error_msg_; }
+
+ void set_error_msg(const char* error_msg) {
+ // Remember the first error only.
+ if (error_msg_.empty()) {
+ error_msg_ = std::string(error_msg);
+ }
+ }
+
+ bool has_error() const { return !error_msg_.empty(); }
+
+ SimpleArena* arena() { return &arena_; }
+
+ void Reset() {
+ error_msg_.clear();
+ arena_.Reset();
+ }
+
+ private:
+ std::string error_msg_;
+ SimpleArena arena_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/exported_funcs.h b/src/arrow/cpp/src/gandiva/exported_funcs.h
new file mode 100644
index 000000000..582052660
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/exported_funcs.h
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <gandiva/exported_funcs_registry.h>
+#include <vector>
+
+namespace gandiva {
+
+class Engine;
+
+// Base-class type for exporting functions that can be accessed from LLVM/IR.
+class ExportedFuncsBase {
+ public:
+ virtual ~ExportedFuncsBase() = default;
+
+ virtual void AddMappings(Engine* engine) const = 0;
+};
+
+// Class for exporting Stub functions
+class ExportedStubFunctions : public ExportedFuncsBase {
+ void AddMappings(Engine* engine) const override;
+};
+REGISTER_EXPORTED_FUNCS(ExportedStubFunctions);
+
+// Class for exporting Context functions
+class ExportedContextFunctions : public ExportedFuncsBase {
+ void AddMappings(Engine* engine) const override;
+};
+REGISTER_EXPORTED_FUNCS(ExportedContextFunctions);
+
+// Class for exporting Time functions
+class ExportedTimeFunctions : public ExportedFuncsBase {
+ void AddMappings(Engine* engine) const override;
+};
+REGISTER_EXPORTED_FUNCS(ExportedTimeFunctions);
+
+// Class for exporting Decimal functions
+class ExportedDecimalFunctions : public ExportedFuncsBase {
+ void AddMappings(Engine* engine) const override;
+};
+REGISTER_EXPORTED_FUNCS(ExportedDecimalFunctions);
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/exported_funcs_registry.cc b/src/arrow/cpp/src/gandiva/exported_funcs_registry.cc
new file mode 100644
index 000000000..4c87c4d40
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/exported_funcs_registry.cc
@@ -0,0 +1,30 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/exported_funcs_registry.h"
+
+#include "gandiva/exported_funcs.h"
+
+namespace gandiva {
+
+void ExportedFuncsRegistry::AddMappings(Engine* engine) {
+ for (auto entry : registered()) {
+ entry->AddMappings(engine);
+ }
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/exported_funcs_registry.h b/src/arrow/cpp/src/gandiva/exported_funcs_registry.h
new file mode 100644
index 000000000..1504f2130
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/exported_funcs_registry.h
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include <gandiva/engine.h>
+
+namespace gandiva {
+
+class ExportedFuncsBase;
+
+/// Registry for classes that export functions which can be accessed by
+/// LLVM/IR code.
+class ExportedFuncsRegistry {
+ public:
+ using list_type = std::vector<std::shared_ptr<ExportedFuncsBase>>;
+
+ // Add functions from all the registered classes to the engine.
+ static void AddMappings(Engine* engine);
+
+ static bool Register(std::shared_ptr<ExportedFuncsBase> entry) {
+ registered().push_back(entry);
+ return true;
+ }
+
+ private:
+ static list_type& registered() {
+ static list_type registered_list;
+ return registered_list;
+ }
+};
+
+#define REGISTER_EXPORTED_FUNCS(classname) \
+ static bool _registered_##classname = \
+ ExportedFuncsRegistry::Register(std::make_shared<classname>())
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/expr_decomposer.cc b/src/arrow/cpp/src/gandiva/expr_decomposer.cc
new file mode 100644
index 000000000..1c09d28f5
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/expr_decomposer.cc
@@ -0,0 +1,310 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/expr_decomposer.h"
+
+#include <memory>
+#include <stack>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "gandiva/annotator.h"
+#include "gandiva/dex.h"
+#include "gandiva/function_holder_registry.h"
+#include "gandiva/function_registry.h"
+#include "gandiva/function_signature.h"
+#include "gandiva/in_holder.h"
+#include "gandiva/node.h"
+
+namespace gandiva {
+
+// Decompose a field node - simply separate out validity & value arrays.
+Status ExprDecomposer::Visit(const FieldNode& node) {
+ auto desc = annotator_.CheckAndAddInputFieldDescriptor(node.field());
+
+ DexPtr validity_dex = std::make_shared<VectorReadValidityDex>(desc);
+ DexPtr value_dex;
+ if (desc->HasOffsetsIdx()) {
+ value_dex = std::make_shared<VectorReadVarLenValueDex>(desc);
+ } else {
+ value_dex = std::make_shared<VectorReadFixedLenValueDex>(desc);
+ }
+ result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex);
+ return Status::OK();
+}
+
+// Try and optimize a function node, by substituting with cheaper alternatives.
+// eg. replacing 'like' with 'starts_with' can save function calls at evaluation
+// time.
+const FunctionNode ExprDecomposer::TryOptimize(const FunctionNode& node) {
+ if (node.descriptor()->name() == "like") {
+ return LikeHolder::TryOptimize(node);
+ } else {
+ return node;
+ }
+}
+
+// Decompose a field node - wherever possible, merge the validity vectors of the
+// child nodes.
+Status ExprDecomposer::Visit(const FunctionNode& in_node) {
+ auto node = TryOptimize(in_node);
+ auto desc = node.descriptor();
+ FunctionSignature signature(desc->name(), desc->params(), desc->return_type());
+ const NativeFunction* native_function = registry_.LookupSignature(signature);
+ DCHECK(native_function) << "Missing Signature " << signature.ToString();
+
+ // decompose the children.
+ std::vector<ValueValidityPairPtr> args;
+ for (auto& child : node.children()) {
+ auto status = child->Accept(*this);
+ ARROW_RETURN_NOT_OK(status);
+
+ args.push_back(result());
+ }
+
+ // Make a function holder, if required.
+ std::shared_ptr<FunctionHolder> holder;
+ if (native_function->NeedsFunctionHolder()) {
+ auto status = FunctionHolderRegistry::Make(desc->name(), node, &holder);
+ ARROW_RETURN_NOT_OK(status);
+ }
+
+ if (native_function->result_nullable_type() == kResultNullIfNull) {
+ // These functions are decomposable, merge the validity bits of the children.
+
+ std::vector<DexPtr> merged_validity;
+ for (auto& decomposed : args) {
+ // Merge the validity_expressions of the children to build a combined validity
+ // expression.
+ merged_validity.insert(merged_validity.end(), decomposed->validity_exprs().begin(),
+ decomposed->validity_exprs().end());
+ }
+
+ auto value_dex =
+ std::make_shared<NonNullableFuncDex>(desc, native_function, holder, args);
+ result_ = std::make_shared<ValueValidityPair>(merged_validity, value_dex);
+ } else if (native_function->result_nullable_type() == kResultNullNever) {
+ // These functions always output valid results. So, no validity dex.
+ auto value_dex =
+ std::make_shared<NullableNeverFuncDex>(desc, native_function, holder, args);
+ result_ = std::make_shared<ValueValidityPair>(value_dex);
+ } else {
+ DCHECK(native_function->result_nullable_type() == kResultNullInternal);
+
+ // Add a local bitmap to track the output validity.
+ int local_bitmap_idx = annotator_.AddLocalBitMap();
+ auto validity_dex = std::make_shared<LocalBitMapValidityDex>(local_bitmap_idx);
+
+ auto value_dex = std::make_shared<NullableInternalFuncDex>(
+ desc, native_function, holder, args, local_bitmap_idx);
+ result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex);
+ }
+ return Status::OK();
+}
+
+// Decompose an IfNode
+Status ExprDecomposer::Visit(const IfNode& node) {
+ // nested_if_else_ might get overwritten when visiting the condition-node, so
+ // saving the value to a local variable and resetting nested_if_else_ to false
+ bool svd_nested_if_else = nested_if_else_;
+ nested_if_else_ = false;
+
+ PushConditionEntry(node);
+ auto status = node.condition()->Accept(*this);
+ ARROW_RETURN_NOT_OK(status);
+ auto condition_vv = result();
+ PopConditionEntry(node);
+
+ // Add a local bitmap to track the output validity.
+ int local_bitmap_idx = PushThenEntry(node, svd_nested_if_else);
+ status = node.then_node()->Accept(*this);
+ ARROW_RETURN_NOT_OK(status);
+ auto then_vv = result();
+ PopThenEntry(node);
+
+ PushElseEntry(node, local_bitmap_idx);
+ nested_if_else_ = (dynamic_cast<IfNode*>(node.else_node().get()) != nullptr);
+
+ status = node.else_node()->Accept(*this);
+ ARROW_RETURN_NOT_OK(status);
+ auto else_vv = result();
+ bool is_terminal_else = PopElseEntry(node);
+
+ auto validity_dex = std::make_shared<LocalBitMapValidityDex>(local_bitmap_idx);
+ auto value_dex =
+ std::make_shared<IfDex>(condition_vv, then_vv, else_vv, node.return_type(),
+ local_bitmap_idx, is_terminal_else);
+
+ result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex);
+ return Status::OK();
+}
+
+// Decompose a BooleanNode
+Status ExprDecomposer::Visit(const BooleanNode& node) {
+ // decompose the children.
+ std::vector<ValueValidityPairPtr> args;
+ for (auto& child : node.children()) {
+ auto status = child->Accept(*this);
+ ARROW_RETURN_NOT_OK(status);
+
+ args.push_back(result());
+ }
+
+ // Add a local bitmap to track the output validity.
+ int local_bitmap_idx = annotator_.AddLocalBitMap();
+ auto validity_dex = std::make_shared<LocalBitMapValidityDex>(local_bitmap_idx);
+
+ std::shared_ptr<BooleanDex> value_dex;
+ switch (node.expr_type()) {
+ case BooleanNode::AND:
+ value_dex = std::make_shared<BooleanAndDex>(args, local_bitmap_idx);
+ break;
+ case BooleanNode::OR:
+ value_dex = std::make_shared<BooleanOrDex>(args, local_bitmap_idx);
+ break;
+ }
+ result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex);
+ return Status::OK();
+}
+Status ExprDecomposer::Visit(const InExpressionNode<gandiva::DecimalScalar128>& node) {
+ /* decompose the children. */
+ std::vector<ValueValidityPairPtr> args;
+ auto status = node.eval_expr()->Accept(*this);
+ ARROW_RETURN_NOT_OK(status);
+ args.push_back(result());
+ /* In always outputs valid results, so no validity dex */
+ auto value_dex = std::make_shared<InExprDex<gandiva::DecimalScalar128>>(
+ args, node.values(), node.get_precision(), node.get_scale());
+ result_ = std::make_shared<ValueValidityPair>(value_dex);
+ return Status::OK();
+}
+
+#define MAKE_VISIT_IN(ctype) \
+ Status ExprDecomposer::Visit(const InExpressionNode<ctype>& node) { \
+ /* decompose the children. */ \
+ std::vector<ValueValidityPairPtr> args; \
+ auto status = node.eval_expr()->Accept(*this); \
+ ARROW_RETURN_NOT_OK(status); \
+ args.push_back(result()); \
+ /* In always outputs valid results, so no validity dex */ \
+ auto value_dex = std::make_shared<InExprDex<ctype>>(args, node.values()); \
+ result_ = std::make_shared<ValueValidityPair>(value_dex); \
+ return Status::OK(); \
+ }
+
+MAKE_VISIT_IN(int32_t);
+MAKE_VISIT_IN(int64_t);
+MAKE_VISIT_IN(float);
+MAKE_VISIT_IN(double);
+MAKE_VISIT_IN(std::string);
+
+Status ExprDecomposer::Visit(const LiteralNode& node) {
+ auto value_dex = std::make_shared<LiteralDex>(node.return_type(), node.holder());
+ DexPtr validity_dex;
+ if (node.is_null()) {
+ validity_dex = std::make_shared<FalseDex>();
+ } else {
+ validity_dex = std::make_shared<TrueDex>();
+ }
+ result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex);
+ return Status::OK();
+}
+
+// The bolow functions use a stack to detect :
+// a. nested if-else expressions.
+// In such cases, the local bitmap can be re-used.
+// b. detect terminal else expressions
+// The non-terminal else expressions do not need to track validity (the if statement
+// that has a match will do it).
+// Both of the above optimisations save CPU cycles during expression evaluation.
+
+int ExprDecomposer::PushThenEntry(const IfNode& node, bool reuse_bitmap) {
+ int local_bitmap_idx;
+
+ if (reuse_bitmap) {
+ // we also need stack in addition to reuse_bitmap flag since we
+ // can also enter other if-else nodes when we visit the condition-node
+ // (which themselves might be nested) before we visit then-node
+ DCHECK_EQ(if_entries_stack_.empty(), false) << "PushThenEntry: stack is empty";
+ DCHECK_EQ(if_entries_stack_.top()->entry_type_, kStackEntryElse)
+ << "PushThenEntry: top of stack is not of type entry_else";
+ auto top = if_entries_stack_.top().get();
+
+ // inside a nested else statement (i.e if-else-if). use the parent's bitmap.
+ local_bitmap_idx = top->local_bitmap_idx_;
+
+ // clear the is_terminal bit in the current top entry (else).
+ top->is_terminal_else_ = false;
+ } else {
+ // alloc a new bitmap.
+ local_bitmap_idx = annotator_.AddLocalBitMap();
+ }
+
+ // push new entry to the stack.
+ std::unique_ptr<IfStackEntry> entry(new IfStackEntry(
+ node, kStackEntryThen, false /*is_terminal_else*/, local_bitmap_idx));
+ if_entries_stack_.emplace(std::move(entry));
+ return local_bitmap_idx;
+}
+
+void ExprDecomposer::PopThenEntry(const IfNode& node) {
+ DCHECK_EQ(if_entries_stack_.empty(), false) << "PopThenEntry: found empty stack";
+
+ auto top = if_entries_stack_.top().get();
+ DCHECK_EQ(top->entry_type_, kStackEntryThen)
+ << "PopThenEntry: found " << top->entry_type_ << " expected then";
+ DCHECK_EQ(&top->if_node_, &node) << "PopThenEntry: found mismatched node";
+
+ if_entries_stack_.pop();
+}
+
+void ExprDecomposer::PushElseEntry(const IfNode& node, int local_bitmap_idx) {
+ std::unique_ptr<IfStackEntry> entry(new IfStackEntry(
+ node, kStackEntryElse, true /*is_terminal_else*/, local_bitmap_idx));
+ if_entries_stack_.emplace(std::move(entry));
+}
+
+bool ExprDecomposer::PopElseEntry(const IfNode& node) {
+ DCHECK_EQ(if_entries_stack_.empty(), false) << "PopElseEntry: found empty stack";
+
+ auto top = if_entries_stack_.top().get();
+ DCHECK_EQ(top->entry_type_, kStackEntryElse)
+ << "PopElseEntry: found " << top->entry_type_ << " expected else";
+ DCHECK_EQ(&top->if_node_, &node) << "PopElseEntry: found mismatched node";
+ bool is_terminal_else = top->is_terminal_else_;
+
+ if_entries_stack_.pop();
+ return is_terminal_else;
+}
+
+void ExprDecomposer::PushConditionEntry(const IfNode& node) {
+ std::unique_ptr<IfStackEntry> entry(new IfStackEntry(node, kStackEntryCondition));
+ if_entries_stack_.emplace(std::move(entry));
+}
+
+void ExprDecomposer::PopConditionEntry(const IfNode& node) {
+ DCHECK_EQ(if_entries_stack_.empty(), false) << "PopConditionEntry: found empty stack";
+
+ auto top = if_entries_stack_.top().get();
+ DCHECK_EQ(top->entry_type_, kStackEntryCondition)
+ << "PopConditionEntry: found " << top->entry_type_ << " expected condition";
+ DCHECK_EQ(&top->if_node_, &node) << "PopConditionEntry: found mismatched node";
+ if_entries_stack_.pop();
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/expr_decomposer.h b/src/arrow/cpp/src/gandiva/expr_decomposer.h
new file mode 100644
index 000000000..f68b8a8fc
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/expr_decomposer.h
@@ -0,0 +1,128 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cmath>
+#include <memory>
+#include <stack>
+#include <string>
+#include <utility>
+
+#include "gandiva/arrow.h"
+#include "gandiva/expression.h"
+#include "gandiva/node.h"
+#include "gandiva/node_visitor.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+class FunctionRegistry;
+class Annotator;
+
+/// \brief Decomposes an expression tree to separate out the validity and
+/// value expressions.
+class GANDIVA_EXPORT ExprDecomposer : public NodeVisitor {
+ public:
+ explicit ExprDecomposer(const FunctionRegistry& registry, Annotator& annotator)
+ : registry_(registry), annotator_(annotator), nested_if_else_(false) {}
+
+ Status Decompose(const Node& root, ValueValidityPairPtr* out) {
+ auto status = root.Accept(*this);
+ if (status.ok()) {
+ *out = std::move(result_);
+ }
+ return status;
+ }
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ExprDecomposer);
+
+ FRIEND_TEST(TestExprDecomposer, TestStackSimple);
+ FRIEND_TEST(TestExprDecomposer, TestNested);
+ FRIEND_TEST(TestExprDecomposer, TestInternalIf);
+ FRIEND_TEST(TestExprDecomposer, TestParallelIf);
+ FRIEND_TEST(TestExprDecomposer, TestIfInCondition);
+ FRIEND_TEST(TestExprDecomposer, TestFunctionBetweenNestedIf);
+ FRIEND_TEST(TestExprDecomposer, TestComplexIfCondition);
+
+ Status Visit(const FieldNode& node) override;
+ Status Visit(const FunctionNode& node) override;
+ Status Visit(const IfNode& node) override;
+ Status Visit(const LiteralNode& node) override;
+ Status Visit(const BooleanNode& node) override;
+ Status Visit(const InExpressionNode<int32_t>& node) override;
+ Status Visit(const InExpressionNode<int64_t>& node) override;
+ Status Visit(const InExpressionNode<float>& node) override;
+ Status Visit(const InExpressionNode<double>& node) override;
+ Status Visit(const InExpressionNode<gandiva::DecimalScalar128>& node) override;
+ Status Visit(const InExpressionNode<std::string>& node) override;
+
+ // Optimize a function node, if possible.
+ const FunctionNode TryOptimize(const FunctionNode& node);
+
+ enum StackEntryType { kStackEntryCondition, kStackEntryThen, kStackEntryElse };
+
+ // stack of if nodes.
+ class IfStackEntry {
+ public:
+ IfStackEntry(const IfNode& if_node, StackEntryType entry_type,
+ bool is_terminal_else = false, int local_bitmap_idx = 0)
+ : if_node_(if_node),
+ entry_type_(entry_type),
+ is_terminal_else_(is_terminal_else),
+ local_bitmap_idx_(local_bitmap_idx) {}
+
+ const IfNode& if_node_;
+ StackEntryType entry_type_;
+ bool is_terminal_else_;
+ int local_bitmap_idx_;
+
+ private:
+ ARROW_DISALLOW_COPY_AND_ASSIGN(IfStackEntry);
+ };
+
+ // pop 'condition entry' into stack.
+ void PushConditionEntry(const IfNode& node);
+
+ // pop 'condition entry' from stack.
+ void PopConditionEntry(const IfNode& node);
+
+ // push 'then entry' to stack. returns either a new local bitmap or the parent's
+ // bitmap (in case of nested if-else).
+ int PushThenEntry(const IfNode& node, bool reuse_bitmap);
+
+ // pop 'then entry' from stack.
+ void PopThenEntry(const IfNode& node);
+
+ // push 'else entry' into stack.
+ void PushElseEntry(const IfNode& node, int local_bitmap_idx);
+
+ // pop 'else entry' from stack. returns 'true' if this is a terminal else condition
+ // i.e no nested if condition below this node.
+ bool PopElseEntry(const IfNode& node);
+
+ ValueValidityPairPtr result() { return std::move(result_); }
+
+ const FunctionRegistry& registry_;
+ Annotator& annotator_;
+ std::stack<std::unique_ptr<IfStackEntry>> if_entries_stack_;
+ ValueValidityPairPtr result_;
+ bool nested_if_else_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/expr_decomposer_test.cc b/src/arrow/cpp/src/gandiva/expr_decomposer_test.cc
new file mode 100644
index 000000000..638ceebcb
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/expr_decomposer_test.cc
@@ -0,0 +1,409 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/expr_decomposer.h"
+
+#include <gtest/gtest.h>
+
+#include "gandiva/annotator.h"
+#include "gandiva/dex.h"
+#include "gandiva/function_registry.h"
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/node.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::int32;
+
+class TestExprDecomposer : public ::testing::Test {
+ protected:
+ FunctionRegistry registry_;
+};
+
+TEST_F(TestExprDecomposer, TestStackSimple) {
+ Annotator annotator;
+ ExprDecomposer decomposer(registry_, annotator);
+
+ // if (a) _
+ // else _
+ IfNode node_a(nullptr, nullptr, nullptr, int32());
+
+ decomposer.PushConditionEntry(node_a);
+ decomposer.PopConditionEntry(node_a);
+
+ int idx_a = decomposer.PushThenEntry(node_a, false);
+ EXPECT_EQ(idx_a, 0);
+ decomposer.PopThenEntry(node_a);
+
+ decomposer.PushElseEntry(node_a, idx_a);
+ bool is_terminal_a = decomposer.PopElseEntry(node_a);
+ EXPECT_EQ(is_terminal_a, true);
+ EXPECT_EQ(decomposer.if_entries_stack_.empty(), true);
+}
+
+TEST_F(TestExprDecomposer, TestNested) {
+ Annotator annotator;
+ ExprDecomposer decomposer(registry_, annotator);
+
+ // if (a) _
+ // else _
+ // if (b) _
+ // else _
+ IfNode node_a(nullptr, nullptr, nullptr, int32());
+ IfNode node_b(nullptr, nullptr, nullptr, int32());
+
+ decomposer.PushConditionEntry(node_a);
+ decomposer.PopConditionEntry(node_a);
+
+ int idx_a = decomposer.PushThenEntry(node_a, false);
+ EXPECT_EQ(idx_a, 0);
+ decomposer.PopThenEntry(node_a);
+
+ decomposer.PushElseEntry(node_a, idx_a);
+
+ { // start b
+ decomposer.PushConditionEntry(node_b);
+ decomposer.PopConditionEntry(node_b);
+
+ int idx_b = decomposer.PushThenEntry(node_b, true);
+ EXPECT_EQ(idx_b, 0); // must reuse bitmap.
+ decomposer.PopThenEntry(node_b);
+
+ decomposer.PushElseEntry(node_b, idx_b);
+ bool is_terminal_b = decomposer.PopElseEntry(node_b);
+ EXPECT_EQ(is_terminal_b, true);
+ } // end b
+
+ bool is_terminal_a = decomposer.PopElseEntry(node_a);
+ EXPECT_EQ(is_terminal_a, false); // there was a nested if.
+
+ EXPECT_EQ(decomposer.if_entries_stack_.empty(), true);
+}
+
+TEST_F(TestExprDecomposer, TestInternalIf) {
+ Annotator annotator;
+ ExprDecomposer decomposer(registry_, annotator);
+
+ // if (a) _
+ // if (b) _
+ // else _
+ // else _
+ IfNode node_a(nullptr, nullptr, nullptr, int32());
+ IfNode node_b(nullptr, nullptr, nullptr, int32());
+
+ decomposer.PushConditionEntry(node_a);
+ decomposer.PopConditionEntry(node_a);
+
+ int idx_a = decomposer.PushThenEntry(node_a, false);
+ EXPECT_EQ(idx_a, 0);
+
+ { // start b
+ decomposer.PushConditionEntry(node_b);
+ decomposer.PopConditionEntry(node_b);
+
+ int idx_b = decomposer.PushThenEntry(node_b, false);
+ EXPECT_EQ(idx_b, 1); // must not reuse bitmap.
+ decomposer.PopThenEntry(node_b);
+
+ decomposer.PushElseEntry(node_b, idx_b);
+ bool is_terminal_b = decomposer.PopElseEntry(node_b);
+ EXPECT_EQ(is_terminal_b, true);
+ } // end b
+
+ decomposer.PopThenEntry(node_a);
+ decomposer.PushElseEntry(node_a, idx_a);
+
+ bool is_terminal_a = decomposer.PopElseEntry(node_a);
+ EXPECT_EQ(is_terminal_a, true); // there was no nested if.
+
+ EXPECT_EQ(decomposer.if_entries_stack_.empty(), true);
+}
+
+TEST_F(TestExprDecomposer, TestParallelIf) {
+ Annotator annotator;
+ ExprDecomposer decomposer(registry_, annotator);
+
+ // if (a) _
+ // else _
+ // if (b) _
+ // else _
+ IfNode node_a(nullptr, nullptr, nullptr, int32());
+ IfNode node_b(nullptr, nullptr, nullptr, int32());
+
+ decomposer.PushConditionEntry(node_a);
+ decomposer.PopConditionEntry(node_a);
+
+ int idx_a = decomposer.PushThenEntry(node_a, false);
+ EXPECT_EQ(idx_a, 0);
+
+ decomposer.PopThenEntry(node_a);
+ decomposer.PushElseEntry(node_a, idx_a);
+
+ bool is_terminal_a = decomposer.PopElseEntry(node_a);
+ EXPECT_EQ(is_terminal_a, true); // there was no nested if.
+
+ // start b
+ decomposer.PushConditionEntry(node_b);
+ decomposer.PopConditionEntry(node_b);
+
+ int idx_b = decomposer.PushThenEntry(node_b, false);
+ EXPECT_EQ(idx_b, 1); // must not reuse bitmap.
+ decomposer.PopThenEntry(node_b);
+
+ decomposer.PushElseEntry(node_b, idx_b);
+ bool is_terminal_b = decomposer.PopElseEntry(node_b);
+ EXPECT_EQ(is_terminal_b, true);
+
+ EXPECT_EQ(decomposer.if_entries_stack_.empty(), true);
+}
+
+TEST_F(TestExprDecomposer, TestIfInCondition) {
+ Annotator annotator;
+ ExprDecomposer decomposer(registry_, annotator);
+
+ // if (if _ else _) : a
+ // -
+ // else
+ // if (if _ else _) : b
+ // -
+ // else
+ // -
+ IfNode node_a(nullptr, nullptr, nullptr, int32());
+ IfNode node_b(nullptr, nullptr, nullptr, int32());
+ IfNode cond_node_a(nullptr, nullptr, nullptr, int32());
+ IfNode cond_node_b(nullptr, nullptr, nullptr, int32());
+
+ // start a
+ decomposer.PushConditionEntry(node_a);
+ {
+ // start cond_node_a
+ decomposer.PushConditionEntry(cond_node_a);
+ decomposer.PopConditionEntry(cond_node_a);
+
+ int idx_cond_a = decomposer.PushThenEntry(cond_node_a, false);
+ EXPECT_EQ(idx_cond_a, 0);
+ decomposer.PopThenEntry(cond_node_a);
+
+ decomposer.PushElseEntry(cond_node_a, idx_cond_a);
+ bool is_terminal = decomposer.PopElseEntry(cond_node_a);
+ EXPECT_EQ(is_terminal, true); // there was no nested if.
+ }
+ decomposer.PopConditionEntry(node_a);
+
+ int idx_a = decomposer.PushThenEntry(node_a, false);
+ EXPECT_EQ(idx_a, 1); // no re-use
+ decomposer.PopThenEntry(node_a);
+
+ decomposer.PushElseEntry(node_a, idx_a);
+
+ { // start b
+ decomposer.PushConditionEntry(node_b);
+ {
+ // start cond_node_b
+ decomposer.PushConditionEntry(cond_node_b);
+ decomposer.PopConditionEntry(cond_node_b);
+
+ int idx_cond_b = decomposer.PushThenEntry(cond_node_b, false);
+ EXPECT_EQ(idx_cond_b, 2); // no re-use
+ decomposer.PopThenEntry(cond_node_b);
+
+ decomposer.PushElseEntry(cond_node_b, idx_cond_b);
+ bool is_terminal = decomposer.PopElseEntry(cond_node_b);
+ EXPECT_EQ(is_terminal, true); // there was no nested if.
+ }
+ decomposer.PopConditionEntry(node_b);
+
+ int idx_b = decomposer.PushThenEntry(node_b, true);
+ EXPECT_EQ(idx_b, 1); // must reuse bitmap.
+ decomposer.PopThenEntry(node_b);
+
+ decomposer.PushElseEntry(node_b, idx_b);
+ bool is_terminal = decomposer.PopElseEntry(node_b);
+ EXPECT_EQ(is_terminal, true);
+ } // end b
+
+ bool is_terminal_a = decomposer.PopElseEntry(node_a);
+ EXPECT_EQ(is_terminal_a, false); // there was a nested if.
+
+ EXPECT_EQ(decomposer.if_entries_stack_.empty(), true);
+}
+
+TEST_F(TestExprDecomposer, TestFunctionBetweenNestedIf) {
+ Annotator annotator;
+ ExprDecomposer decomposer(registry_, annotator);
+
+ // if (a) _
+ // else
+ // function(
+ // if (b) _
+ // else _
+ // )
+
+ IfNode node_a(nullptr, nullptr, nullptr, int32());
+ IfNode node_b(nullptr, nullptr, nullptr, int32());
+
+ // start outer if
+ decomposer.PushConditionEntry(node_a);
+ decomposer.PopConditionEntry(node_a);
+
+ int idx_a = decomposer.PushThenEntry(node_a, false);
+ EXPECT_EQ(idx_a, 0);
+ decomposer.PopThenEntry(node_a);
+
+ decomposer.PushElseEntry(node_a, idx_a);
+ { // start b
+ decomposer.PushConditionEntry(node_b);
+ decomposer.PopConditionEntry(node_b);
+
+ int idx_b = decomposer.PushThenEntry(node_b, false); // not else node of parent if
+ EXPECT_EQ(idx_b, 1); // can't reuse bitmap.
+ decomposer.PopThenEntry(node_b);
+
+ decomposer.PushElseEntry(node_b, idx_b);
+ bool is_terminal_b = decomposer.PopElseEntry(node_b);
+ EXPECT_EQ(is_terminal_b, true);
+ }
+ bool is_terminal_a = decomposer.PopElseEntry(node_a);
+ EXPECT_EQ(is_terminal_a, true); // a else is also terminal
+
+ EXPECT_TRUE(decomposer.if_entries_stack_.empty());
+}
+
+TEST_F(TestExprDecomposer, TestComplexIfCondition) {
+ Annotator annotator;
+ ExprDecomposer decomposer(registry_, annotator);
+
+ // if (if _
+ // else
+ // if _
+ // else _
+ // )
+ // then
+ // if _
+ // else
+ // if _
+ // else _
+ //
+ // else
+ // if _
+ // else
+ // if _
+ // else _
+
+ IfNode node_a(nullptr, nullptr, nullptr, int32());
+
+ IfNode cond_node_a(nullptr, nullptr, nullptr, int32());
+ IfNode cond_node_a_inner_if(nullptr, nullptr, nullptr, int32());
+
+ IfNode then_node_a(nullptr, nullptr, nullptr, int32());
+ IfNode then_node_a_inner_if(nullptr, nullptr, nullptr, int32());
+
+ IfNode else_node_a(nullptr, nullptr, nullptr, int32());
+ IfNode else_node_a_inner_if(nullptr, nullptr, nullptr, int32());
+
+ // start outer if
+ decomposer.PushConditionEntry(node_a);
+ {
+ // start the nested if inside the condition of a
+ decomposer.PushConditionEntry(cond_node_a);
+ decomposer.PopConditionEntry(cond_node_a);
+
+ int idx_cond_a = decomposer.PushThenEntry(cond_node_a, false);
+ EXPECT_EQ(idx_cond_a, 0);
+ decomposer.PopThenEntry(cond_node_a);
+
+ decomposer.PushElseEntry(cond_node_a, idx_cond_a);
+ {
+ decomposer.PushConditionEntry(cond_node_a_inner_if);
+ decomposer.PopConditionEntry(cond_node_a_inner_if);
+
+ int idx_cond_a_inner_if = decomposer.PushThenEntry(cond_node_a_inner_if, true);
+ EXPECT_EQ(idx_cond_a_inner_if,
+ 0); // expect bitmap to be resused since nested if else
+ decomposer.PopThenEntry(cond_node_a_inner_if);
+
+ decomposer.PushElseEntry(cond_node_a_inner_if, idx_cond_a_inner_if);
+ bool is_terminal = decomposer.PopElseEntry(cond_node_a_inner_if);
+ EXPECT_TRUE(is_terminal);
+ }
+ EXPECT_FALSE(decomposer.PopElseEntry(cond_node_a));
+ }
+ decomposer.PopConditionEntry(node_a);
+
+ int idx_a = decomposer.PushThenEntry(node_a, false);
+ EXPECT_EQ(idx_a, 1);
+
+ {
+ // start the nested if inside the then node of a
+ decomposer.PushConditionEntry(then_node_a);
+ decomposer.PopConditionEntry(then_node_a);
+
+ int idx_then_a = decomposer.PushThenEntry(then_node_a, false);
+ EXPECT_EQ(idx_then_a, 2);
+ decomposer.PopThenEntry(then_node_a);
+
+ decomposer.PushElseEntry(then_node_a, idx_then_a);
+ {
+ decomposer.PushConditionEntry(then_node_a_inner_if);
+ decomposer.PopConditionEntry(then_node_a_inner_if);
+
+ int idx_then_a_inner_if = decomposer.PushThenEntry(then_node_a_inner_if, true);
+ EXPECT_EQ(idx_then_a_inner_if,
+ 2); // expect bitmap to be resused since nested if else
+ decomposer.PopThenEntry(then_node_a_inner_if);
+
+ decomposer.PushElseEntry(then_node_a_inner_if, idx_then_a_inner_if);
+ bool is_terminal = decomposer.PopElseEntry(then_node_a_inner_if);
+ EXPECT_TRUE(is_terminal);
+ }
+ EXPECT_FALSE(decomposer.PopElseEntry(then_node_a));
+ }
+ decomposer.PopThenEntry(node_a);
+
+ decomposer.PushElseEntry(node_a, idx_a);
+ {
+ // start the nested if inside the else node of a
+ decomposer.PushConditionEntry(else_node_a);
+ decomposer.PopConditionEntry(else_node_a);
+
+ int idx_else_a =
+ decomposer.PushThenEntry(else_node_a, true); // else node is another if-node
+ EXPECT_EQ(idx_else_a, 1); // reuse the outer if node bitmap since nested if-else
+ decomposer.PopThenEntry(else_node_a);
+
+ decomposer.PushElseEntry(else_node_a, idx_else_a);
+ {
+ decomposer.PushConditionEntry(else_node_a_inner_if);
+ decomposer.PopConditionEntry(else_node_a_inner_if);
+
+ int idx_else_a_inner_if = decomposer.PushThenEntry(else_node_a_inner_if, true);
+ EXPECT_EQ(idx_else_a_inner_if,
+ 1); // expect bitmap to be resused since nested if else
+ decomposer.PopThenEntry(else_node_a_inner_if);
+
+ decomposer.PushElseEntry(else_node_a_inner_if, idx_else_a_inner_if);
+ bool is_terminal = decomposer.PopElseEntry(else_node_a_inner_if);
+ EXPECT_TRUE(is_terminal);
+ }
+ EXPECT_FALSE(decomposer.PopElseEntry(else_node_a));
+ }
+ EXPECT_FALSE(decomposer.PopElseEntry(node_a));
+ EXPECT_TRUE(decomposer.if_entries_stack_.empty());
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/expr_validator.cc b/src/arrow/cpp/src/gandiva/expr_validator.cc
new file mode 100644
index 000000000..c3c784c95
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/expr_validator.cc
@@ -0,0 +1,193 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "gandiva/expr_validator.h"
+
+namespace gandiva {
+
+Status ExprValidator::Validate(const ExpressionPtr& expr) {
+ ARROW_RETURN_IF(expr == nullptr,
+ Status::ExpressionValidationError("Expression cannot be null"));
+
+ Node& root = *expr->root();
+ ARROW_RETURN_NOT_OK(root.Accept(*this));
+
+ // Ensure root's return type match the expression return type. Type
+ // support validation is not required because root type is already supported.
+ ARROW_RETURN_IF(!root.return_type()->Equals(*expr->result()->type()),
+ Status::ExpressionValidationError("Return type of root node ",
+ root.return_type()->ToString(),
+ " does not match that of expression ",
+ expr->result()->type()->ToString()));
+
+ return Status::OK();
+}
+
+Status ExprValidator::Visit(const FieldNode& node) {
+ auto llvm_type = types_->IRType(node.return_type()->id());
+ ARROW_RETURN_IF(llvm_type == nullptr,
+ Status::ExpressionValidationError("Field ", node.field()->name(),
+ " has unsupported data type ",
+ node.return_type()->name()));
+
+ // Ensure that field is found in schema
+ auto field_in_schema_entry = field_map_.find(node.field()->name());
+ ARROW_RETURN_IF(field_in_schema_entry == field_map_.end(),
+ Status::ExpressionValidationError("Field ", node.field()->name(),
+ " not in schema."));
+
+ // Ensure that that the found field match.
+ FieldPtr field_in_schema = field_in_schema_entry->second;
+ ARROW_RETURN_IF(!field_in_schema->Equals(node.field()),
+ Status::ExpressionValidationError(
+ "Field definition in schema ", field_in_schema->ToString(),
+ " different from field in expression ", node.field()->ToString()));
+
+ return Status::OK();
+}
+
+Status ExprValidator::Visit(const FunctionNode& node) {
+ auto desc = node.descriptor();
+ FunctionSignature signature(desc->name(), desc->params(), desc->return_type());
+
+ const NativeFunction* native_function = registry_.LookupSignature(signature);
+ ARROW_RETURN_IF(native_function == nullptr,
+ Status::ExpressionValidationError("Function ", signature.ToString(),
+ " not supported yet. "));
+
+ for (auto& child : node.children()) {
+ ARROW_RETURN_NOT_OK(child->Accept(*this));
+ }
+
+ return Status::OK();
+}
+
+Status ExprValidator::Visit(const IfNode& node) {
+ ARROW_RETURN_NOT_OK(node.condition()->Accept(*this));
+ ARROW_RETURN_NOT_OK(node.then_node()->Accept(*this));
+ ARROW_RETURN_NOT_OK(node.else_node()->Accept(*this));
+
+ auto if_node_ret_type = node.return_type();
+ auto then_node_ret_type = node.then_node()->return_type();
+ auto else_node_ret_type = node.else_node()->return_type();
+
+ // condition must be of boolean type.
+ ARROW_RETURN_IF(
+ !node.condition()->return_type()->Equals(arrow::boolean()),
+ Status::ExpressionValidationError("condition must be of boolean type, found type ",
+ node.condition()->return_type()->ToString()));
+
+ // Then-branch return type must match.
+ ARROW_RETURN_IF(!if_node_ret_type->Equals(*then_node_ret_type),
+ Status::ExpressionValidationError(
+ "Return type of if ", if_node_ret_type->ToString(), " and then ",
+ then_node_ret_type->ToString(), " not matching."));
+
+ // Else-branch return type must match.
+ ARROW_RETURN_IF(!if_node_ret_type->Equals(*else_node_ret_type),
+ Status::ExpressionValidationError(
+ "Return type of if ", if_node_ret_type->ToString(), " and else ",
+ else_node_ret_type->ToString(), " not matching."));
+
+ return Status::OK();
+}
+
+Status ExprValidator::Visit(const LiteralNode& node) {
+ auto llvm_type = types_->IRType(node.return_type()->id());
+ ARROW_RETURN_IF(llvm_type == nullptr,
+ Status::ExpressionValidationError("Value ", ToString(node.holder()),
+ " has unsupported data type ",
+ node.return_type()->name()));
+
+ return Status::OK();
+}
+
+Status ExprValidator::Visit(const BooleanNode& node) {
+ ARROW_RETURN_IF(
+ node.children().size() < 2,
+ Status::ExpressionValidationError("Boolean expression has ", node.children().size(),
+ " children, expected at least two"));
+
+ for (auto& child : node.children()) {
+ const auto bool_type = arrow::boolean();
+ const auto ret_type = child->return_type();
+
+ ARROW_RETURN_IF(!ret_type->Equals(bool_type),
+ Status::ExpressionValidationError(
+ "Boolean expression has a child with return type ",
+ ret_type->ToString(), ", expected return type boolean"));
+
+ ARROW_RETURN_NOT_OK(child->Accept(*this));
+ }
+
+ return Status::OK();
+}
+
+/*
+ * Validate the following
+ *
+ * 1. Non empty list of constants to search in.
+ * 2. Expression returns of the same type as the constants.
+ */
+Status ExprValidator::Visit(const InExpressionNode<int32_t>& node) {
+ return ValidateInExpression(node.values().size(), node.eval_expr()->return_type(),
+ arrow::int32());
+}
+
+Status ExprValidator::Visit(const InExpressionNode<int64_t>& node) {
+ return ValidateInExpression(node.values().size(), node.eval_expr()->return_type(),
+ arrow::int64());
+}
+Status ExprValidator::Visit(const InExpressionNode<float>& node) {
+ return ValidateInExpression(node.values().size(), node.eval_expr()->return_type(),
+ arrow::float32());
+}
+Status ExprValidator::Visit(const InExpressionNode<double>& node) {
+ return ValidateInExpression(node.values().size(), node.eval_expr()->return_type(),
+ arrow::float64());
+}
+
+Status ExprValidator::Visit(const InExpressionNode<gandiva::DecimalScalar128>& node) {
+ return ValidateInExpression(node.values().size(), node.eval_expr()->return_type(),
+ arrow::decimal(node.get_precision(), node.get_scale()));
+}
+
+Status ExprValidator::Visit(const InExpressionNode<std::string>& node) {
+ return ValidateInExpression(node.values().size(), node.eval_expr()->return_type(),
+ arrow::utf8());
+}
+
+Status ExprValidator::ValidateInExpression(size_t number_of_values,
+ DataTypePtr in_expr_return_type,
+ DataTypePtr type_of_values) {
+ ARROW_RETURN_IF(number_of_values == 0,
+ Status::ExpressionValidationError(
+ "IN Expression needs a non-empty constant list to match."));
+ ARROW_RETURN_IF(
+ !in_expr_return_type->Equals(type_of_values),
+ Status::ExpressionValidationError(
+ "Evaluation expression for IN clause returns ", in_expr_return_type->ToString(),
+ " values are of type", type_of_values->ToString()));
+
+ return Status::OK();
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/expr_validator.h b/src/arrow/cpp/src/gandiva/expr_validator.h
new file mode 100644
index 000000000..daaf50897
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/expr_validator.h
@@ -0,0 +1,80 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+#include <unordered_map>
+
+#include "arrow/status.h"
+
+#include "gandiva/arrow.h"
+#include "gandiva/expression.h"
+#include "gandiva/function_registry.h"
+#include "gandiva/llvm_types.h"
+#include "gandiva/node.h"
+#include "gandiva/node_visitor.h"
+
+namespace gandiva {
+
+class FunctionRegistry;
+
+/// \brief Validates the entire expression tree including
+/// data types, signatures and return types
+class ExprValidator : public NodeVisitor {
+ public:
+ explicit ExprValidator(LLVMTypes* types, SchemaPtr schema)
+ : types_(types), schema_(schema) {
+ for (auto& field : schema_->fields()) {
+ field_map_[field->name()] = field;
+ }
+ }
+
+ /// \brief Validates the root node
+ /// of an expression.
+ /// 1. Data type of fields and literals.
+ /// 2. Function signature is supported.
+ /// 3. For if nodes that return types match
+ /// for if, then and else nodes.
+ Status Validate(const ExpressionPtr& expr);
+
+ private:
+ Status Visit(const FieldNode& node) override;
+ Status Visit(const FunctionNode& node) override;
+ Status Visit(const IfNode& node) override;
+ Status Visit(const LiteralNode& node) override;
+ Status Visit(const BooleanNode& node) override;
+ Status Visit(const InExpressionNode<int32_t>& node) override;
+ Status Visit(const InExpressionNode<int64_t>& node) override;
+ Status Visit(const InExpressionNode<float>& node) override;
+ Status Visit(const InExpressionNode<double>& node) override;
+ Status Visit(const InExpressionNode<gandiva::DecimalScalar128>& node) override;
+ Status Visit(const InExpressionNode<std::string>& node) override;
+ Status ValidateInExpression(size_t number_of_values, DataTypePtr in_expr_return_type,
+ DataTypePtr type_of_values);
+
+ FunctionRegistry registry_;
+
+ LLVMTypes* types_;
+
+ SchemaPtr schema_;
+
+ using FieldMap = std::unordered_map<std::string, FieldPtr>;
+ FieldMap field_map_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/expression.cc b/src/arrow/cpp/src/gandiva/expression.cc
new file mode 100644
index 000000000..06aada27b
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/expression.cc
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/expression.h"
+#include "gandiva/node.h"
+
+namespace gandiva {
+
+std::string Expression::ToString() { return root()->ToString(); }
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/expression.h b/src/arrow/cpp/src/gandiva/expression.h
new file mode 100644
index 000000000..cdda2512b
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/expression.h
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+
+#include "gandiva/arrow.h"
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// \brief An expression tree with a root node, and a result field.
+class GANDIVA_EXPORT Expression {
+ public:
+ Expression(const NodePtr root, const FieldPtr result) : root_(root), result_(result) {}
+
+ virtual ~Expression() = default;
+
+ const NodePtr& root() const { return root_; }
+
+ const FieldPtr& result() const { return result_; }
+
+ std::string ToString();
+
+ private:
+ const NodePtr root_;
+ const FieldPtr result_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/expression_registry.cc b/src/arrow/cpp/src/gandiva/expression_registry.cc
new file mode 100644
index 000000000..c3a08fd3a
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/expression_registry.cc
@@ -0,0 +1,187 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/expression_registry.h"
+
+#include "gandiva/function_registry.h"
+#include "gandiva/llvm_types.h"
+
+namespace gandiva {
+
+ExpressionRegistry::ExpressionRegistry() {
+ function_registry_.reset(new FunctionRegistry());
+}
+
+ExpressionRegistry::~ExpressionRegistry() {}
+
+// to be used only to create function_signature_start
+ExpressionRegistry::FunctionSignatureIterator::FunctionSignatureIterator(
+ native_func_iterator_type nf_it, native_func_iterator_type nf_it_end)
+ : native_func_it_{nf_it},
+ native_func_it_end_{nf_it_end},
+ func_sig_it_{&(nf_it->signatures().front())} {}
+
+// to be used only to create function_signature_end
+ExpressionRegistry::FunctionSignatureIterator::FunctionSignatureIterator(
+ func_sig_iterator_type fs_it)
+ : native_func_it_{nullptr}, native_func_it_end_{nullptr}, func_sig_it_{fs_it} {}
+
+const ExpressionRegistry::FunctionSignatureIterator
+ExpressionRegistry::function_signature_begin() {
+ return FunctionSignatureIterator(function_registry_->begin(),
+ function_registry_->end());
+}
+
+const ExpressionRegistry::FunctionSignatureIterator
+ExpressionRegistry::function_signature_end() const {
+ return FunctionSignatureIterator(&(*(function_registry_->back()->signatures().end())));
+}
+
+bool ExpressionRegistry::FunctionSignatureIterator::operator!=(
+ const FunctionSignatureIterator& func_sign_it) {
+ return func_sign_it.func_sig_it_ != this->func_sig_it_;
+}
+
+FunctionSignature ExpressionRegistry::FunctionSignatureIterator::operator*() {
+ return *func_sig_it_;
+}
+
+ExpressionRegistry::func_sig_iterator_type ExpressionRegistry::FunctionSignatureIterator::
+operator++(int increment) {
+ ++func_sig_it_;
+ // point func_sig_it_ to first signature of next nativefunction if func_sig_it_ is
+ // pointing to end
+ if (func_sig_it_ == &(*native_func_it_->signatures().end())) {
+ ++native_func_it_;
+ if (native_func_it_ == native_func_it_end_) { // last native function
+ return func_sig_it_;
+ }
+ func_sig_it_ = &(native_func_it_->signatures().front());
+ }
+ return func_sig_it_;
+}
+
+static void AddArrowTypesToVector(arrow::Type::type type, DataTypeVector& vector);
+
+static DataTypeVector InitSupportedTypes() {
+ DataTypeVector data_type_vector;
+ llvm::LLVMContext llvm_context;
+ LLVMTypes llvm_types(llvm_context);
+ auto supported_arrow_types = llvm_types.GetSupportedArrowTypes();
+ for (auto& type_id : supported_arrow_types) {
+ AddArrowTypesToVector(type_id, data_type_vector);
+ }
+ return data_type_vector;
+}
+
+DataTypeVector ExpressionRegistry::supported_types_ = InitSupportedTypes();
+
+static void AddArrowTypesToVector(arrow::Type::type type, DataTypeVector& vector) {
+ switch (type) {
+ case arrow::Type::type::BOOL:
+ vector.push_back(arrow::boolean());
+ break;
+ case arrow::Type::type::UINT8:
+ vector.push_back(arrow::uint8());
+ break;
+ case arrow::Type::type::INT8:
+ vector.push_back(arrow::int8());
+ break;
+ case arrow::Type::type::UINT16:
+ vector.push_back(arrow::uint16());
+ break;
+ case arrow::Type::type::INT16:
+ vector.push_back(arrow::int16());
+ break;
+ case arrow::Type::type::UINT32:
+ vector.push_back(arrow::uint32());
+ break;
+ case arrow::Type::type::INT32:
+ vector.push_back(arrow::int32());
+ break;
+ case arrow::Type::type::UINT64:
+ vector.push_back(arrow::uint64());
+ break;
+ case arrow::Type::type::INT64:
+ vector.push_back(arrow::int64());
+ break;
+ case arrow::Type::type::HALF_FLOAT:
+ vector.push_back(arrow::float16());
+ break;
+ case arrow::Type::type::FLOAT:
+ vector.push_back(arrow::float32());
+ break;
+ case arrow::Type::type::DOUBLE:
+ vector.push_back(arrow::float64());
+ break;
+ case arrow::Type::type::STRING:
+ vector.push_back(arrow::utf8());
+ break;
+ case arrow::Type::type::BINARY:
+ vector.push_back(arrow::binary());
+ break;
+ case arrow::Type::type::DATE32:
+ vector.push_back(arrow::date32());
+ break;
+ case arrow::Type::type::DATE64:
+ vector.push_back(arrow::date64());
+ break;
+ case arrow::Type::type::TIMESTAMP:
+ vector.push_back(arrow::timestamp(arrow::TimeUnit::SECOND));
+ vector.push_back(arrow::timestamp(arrow::TimeUnit::MILLI));
+ vector.push_back(arrow::timestamp(arrow::TimeUnit::NANO));
+ vector.push_back(arrow::timestamp(arrow::TimeUnit::MICRO));
+ break;
+ case arrow::Type::type::TIME32:
+ vector.push_back(arrow::time32(arrow::TimeUnit::SECOND));
+ vector.push_back(arrow::time32(arrow::TimeUnit::MILLI));
+ break;
+ case arrow::Type::type::TIME64:
+ vector.push_back(arrow::time64(arrow::TimeUnit::MICRO));
+ vector.push_back(arrow::time64(arrow::TimeUnit::NANO));
+ break;
+ case arrow::Type::type::NA:
+ vector.push_back(arrow::null());
+ break;
+ case arrow::Type::type::DECIMAL:
+ vector.push_back(arrow::decimal(38, 0));
+ break;
+ case arrow::Type::type::INTERVAL_MONTHS:
+ vector.push_back(arrow::month_interval());
+ break;
+ case arrow::Type::type::INTERVAL_DAY_TIME:
+ vector.push_back(arrow::day_time_interval());
+ break;
+ default:
+ // Unsupported types. test ensures that
+ // when one of these are added build breaks.
+ DCHECK(false);
+ }
+}
+
+std::vector<std::shared_ptr<FunctionSignature>> GetRegisteredFunctionSignatures() {
+ ExpressionRegistry registry;
+ std::vector<std::shared_ptr<FunctionSignature>> signatures;
+ for (auto iter = registry.function_signature_begin();
+ iter != registry.function_signature_end(); iter++) {
+ signatures.push_back(std::make_shared<FunctionSignature>(
+ (*iter).base_name(), (*iter).param_types(), (*iter).ret_type()));
+ }
+ return signatures;
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/expression_registry.h b/src/arrow/cpp/src/gandiva/expression_registry.h
new file mode 100644
index 000000000..fb4f177ba
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/expression_registry.h
@@ -0,0 +1,71 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "gandiva/arrow.h"
+#include "gandiva/function_signature.h"
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+class NativeFunction;
+class FunctionRegistry;
+/// \brief Exports types supported by Gandiva for processing.
+///
+/// Has helper methods for clients to programmatically discover
+/// data types and functions supported by Gandiva.
+class GANDIVA_EXPORT ExpressionRegistry {
+ public:
+ using native_func_iterator_type = const NativeFunction*;
+ using func_sig_iterator_type = const FunctionSignature*;
+ ExpressionRegistry();
+ ~ExpressionRegistry();
+ static DataTypeVector supported_types() { return supported_types_; }
+ class GANDIVA_EXPORT FunctionSignatureIterator {
+ public:
+ explicit FunctionSignatureIterator(native_func_iterator_type nf_it,
+ native_func_iterator_type nf_it_end_);
+ explicit FunctionSignatureIterator(func_sig_iterator_type fs_it);
+
+ bool operator!=(const FunctionSignatureIterator& func_sign_it);
+
+ FunctionSignature operator*();
+
+ func_sig_iterator_type operator++(int);
+
+ private:
+ native_func_iterator_type native_func_it_;
+ const native_func_iterator_type native_func_it_end_;
+ func_sig_iterator_type func_sig_it_;
+ };
+ const FunctionSignatureIterator function_signature_begin();
+ const FunctionSignatureIterator function_signature_end() const;
+
+ private:
+ static DataTypeVector supported_types_;
+ std::unique_ptr<FunctionRegistry> function_registry_;
+};
+
+GANDIVA_EXPORT
+std::vector<std::shared_ptr<FunctionSignature>> GetRegisteredFunctionSignatures();
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/expression_registry_test.cc b/src/arrow/cpp/src/gandiva/expression_registry_test.cc
new file mode 100644
index 000000000..c254ff4f3
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/expression_registry_test.cc
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/expression_registry.h"
+
+#include <algorithm>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "gandiva/function_registry.h"
+#include "gandiva/function_signature.h"
+#include "gandiva/llvm_types.h"
+
+namespace gandiva {
+
+typedef int64_t (*add_vector_func_t)(int64_t* elements, int nelements);
+
+class TestExpressionRegistry : public ::testing::Test {
+ protected:
+ FunctionRegistry registry_;
+};
+
+// Verify all functions in registry are exported.
+TEST_F(TestExpressionRegistry, VerifySupportedFunctions) {
+ std::vector<FunctionSignature> functions;
+ ExpressionRegistry expr_registry;
+ for (auto iter = expr_registry.function_signature_begin();
+ iter != expr_registry.function_signature_end(); iter++) {
+ functions.push_back((*iter));
+ }
+ for (auto& iter : registry_) {
+ for (auto& func_iter : iter.signatures()) {
+ auto element = std::find(functions.begin(), functions.end(), func_iter);
+ EXPECT_NE(element, functions.end()) << "function signature " << func_iter.ToString()
+ << " missing in supported functions.\n";
+ }
+ }
+}
+
+// Verify all types are supported.
+TEST_F(TestExpressionRegistry, VerifyDataTypes) {
+ DataTypeVector data_types = ExpressionRegistry::supported_types();
+ llvm::LLVMContext llvm_context;
+ LLVMTypes llvm_types(llvm_context);
+ auto supported_arrow_types = llvm_types.GetSupportedArrowTypes();
+ for (auto& type_id : supported_arrow_types) {
+ auto element =
+ std::find(supported_arrow_types.begin(), supported_arrow_types.end(), type_id);
+ EXPECT_NE(element, supported_arrow_types.end())
+ << "data type " << type_id << " missing in supported data types.\n";
+ }
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/field_descriptor.h b/src/arrow/cpp/src/gandiva/field_descriptor.h
new file mode 100644
index 000000000..0fe6fe37f
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/field_descriptor.h
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+
+#include "gandiva/arrow.h"
+
+namespace gandiva {
+
+/// \brief Descriptor for an arrow field. Holds indexes into the flattened array of
+/// buffers that is passed to LLVM generated functions.
+class FieldDescriptor {
+ public:
+ static const int kInvalidIdx = -1;
+
+ FieldDescriptor(FieldPtr field, int data_idx, int validity_idx = kInvalidIdx,
+ int offsets_idx = kInvalidIdx, int data_buffer_ptr_idx = kInvalidIdx)
+ : field_(field),
+ data_idx_(data_idx),
+ validity_idx_(validity_idx),
+ offsets_idx_(offsets_idx),
+ data_buffer_ptr_idx_(data_buffer_ptr_idx) {}
+
+ /// Index of validity array in the array-of-buffers
+ int validity_idx() const { return validity_idx_; }
+
+ /// Index of data array in the array-of-buffers
+ int data_idx() const { return data_idx_; }
+
+ /// Index of offsets array in the array-of-buffers
+ int offsets_idx() const { return offsets_idx_; }
+
+ /// Index of data buffer pointer in the array-of-buffers
+ int data_buffer_ptr_idx() const { return data_buffer_ptr_idx_; }
+
+ FieldPtr field() const { return field_; }
+
+ const std::string& Name() const { return field_->name(); }
+ DataTypePtr Type() const { return field_->type(); }
+
+ bool HasOffsetsIdx() const { return offsets_idx_ != kInvalidIdx; }
+
+ bool HasDataBufferPtrIdx() const { return data_buffer_ptr_idx_ != kInvalidIdx; }
+
+ private:
+ FieldPtr field_;
+ int data_idx_;
+ int validity_idx_;
+ int offsets_idx_;
+ int data_buffer_ptr_idx_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/filter.cc b/src/arrow/cpp/src/gandiva/filter.cc
new file mode 100644
index 000000000..875cc5447
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/filter.cc
@@ -0,0 +1,171 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/filter.h"
+
+#include <memory>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include "arrow/util/hash_util.h"
+
+#include "gandiva/bitmap_accumulator.h"
+#include "gandiva/cache.h"
+#include "gandiva/condition.h"
+#include "gandiva/expr_validator.h"
+#include "gandiva/llvm_generator.h"
+#include "gandiva/selection_vector_impl.h"
+
+namespace gandiva {
+
+FilterCacheKey::FilterCacheKey(SchemaPtr schema,
+ std::shared_ptr<Configuration> configuration,
+ Expression& expression)
+ : schema_(schema), configuration_(configuration), uniqifier_(0) {
+ static const int kSeedValue = 4;
+ size_t result = kSeedValue;
+ expression_as_string_ = expression.ToString();
+ UpdateUniqifier(expression_as_string_);
+ arrow::internal::hash_combine(result, expression_as_string_);
+ arrow::internal::hash_combine(result, configuration);
+ arrow::internal::hash_combine(result, schema_->ToString());
+ arrow::internal::hash_combine(result, uniqifier_);
+ hash_code_ = result;
+}
+
+bool FilterCacheKey::operator==(const FilterCacheKey& other) const {
+ // arrow schema does not overload equality operators.
+ if (!(schema_->Equals(*other.schema().get(), true))) {
+ return false;
+ }
+
+ if (configuration_ != other.configuration_) {
+ return false;
+ }
+
+ if (expression_as_string_ != other.expression_as_string_) {
+ return false;
+ }
+
+ if (uniqifier_ != other.uniqifier_) {
+ return false;
+ }
+ return true;
+}
+
+std::string FilterCacheKey::ToString() const {
+ std::stringstream ss;
+ // indent, window, indent_size, null_rep and skip new lines.
+ arrow::PrettyPrintOptions options{0, 10, 2, "null", true};
+ DCHECK_OK(PrettyPrint(*schema_.get(), options, &ss));
+
+ ss << "Condition: [" << expression_as_string_ << "]";
+ return ss.str();
+}
+
+void FilterCacheKey::UpdateUniqifier(const std::string& expr) {
+ // caching of expressions with re2 patterns causes lock contention. So, use
+ // multiple instances to reduce contention.
+ if (expr.find(" like(") != std::string::npos) {
+ uniqifier_ = std::hash<std::thread::id>()(std::this_thread::get_id()) % 16;
+ }
+}
+
+Filter::Filter(std::unique_ptr<LLVMGenerator> llvm_generator, SchemaPtr schema,
+ std::shared_ptr<Configuration> configuration)
+ : llvm_generator_(std::move(llvm_generator)),
+ schema_(schema),
+ configuration_(configuration) {}
+
+Filter::~Filter() {}
+
+Status Filter::Make(SchemaPtr schema, ConditionPtr condition,
+ std::shared_ptr<Configuration> configuration,
+ std::shared_ptr<Filter>* filter) {
+ ARROW_RETURN_IF(schema == nullptr, Status::Invalid("Schema cannot be null"));
+ ARROW_RETURN_IF(condition == nullptr, Status::Invalid("Condition cannot be null"));
+ ARROW_RETURN_IF(configuration == nullptr,
+ Status::Invalid("Configuration cannot be null"));
+
+ static Cache<FilterCacheKey, std::shared_ptr<Filter>> cache;
+ FilterCacheKey cache_key(schema, configuration, *(condition.get()));
+ auto cachedFilter = cache.GetModule(cache_key);
+ if (cachedFilter != nullptr) {
+ *filter = cachedFilter;
+ return Status::OK();
+ }
+
+ // Build LLVM generator, and generate code for the specified expression
+ std::unique_ptr<LLVMGenerator> llvm_gen;
+ ARROW_RETURN_NOT_OK(LLVMGenerator::Make(configuration, &llvm_gen));
+
+ // Run the validation on the expression.
+ // Return if the expression is invalid since we will not be able to process further.
+ ExprValidator expr_validator(llvm_gen->types(), schema);
+ ARROW_RETURN_NOT_OK(expr_validator.Validate(condition));
+
+ // Start measuring build time
+ auto begin = std::chrono::high_resolution_clock::now();
+ ARROW_RETURN_NOT_OK(llvm_gen->Build({condition}, SelectionVector::Mode::MODE_NONE));
+ // Stop measuring time and calculate the elapsed time
+ auto end = std::chrono::high_resolution_clock::now();
+ auto elapsed =
+ std::chrono::duration_cast<std::chrono::milliseconds>(end - begin).count();
+
+ // Instantiate the filter with the completely built llvm generator
+ *filter = std::make_shared<Filter>(std::move(llvm_gen), schema, configuration);
+ ValueCacheObject<std::shared_ptr<Filter>> value_cache(*filter, elapsed);
+ cache.PutModule(cache_key, value_cache);
+
+ return Status::OK();
+}
+
+Status Filter::Evaluate(const arrow::RecordBatch& batch,
+ std::shared_ptr<SelectionVector> out_selection) {
+ const auto num_rows = batch.num_rows();
+ ARROW_RETURN_IF(!batch.schema()->Equals(*schema_),
+ Status::Invalid("RecordBatch schema must expected filter schema"));
+ ARROW_RETURN_IF(num_rows == 0, Status::Invalid("RecordBatch must be non-empty."));
+ ARROW_RETURN_IF(out_selection == nullptr,
+ Status::Invalid("out_selection must be non-null."));
+ ARROW_RETURN_IF(out_selection->GetMaxSlots() < num_rows,
+ Status::Invalid("Output selection vector capacity too small"));
+
+ // Allocate three local_bitmaps (one for output, one for validity, one to compute the
+ // intersection).
+ LocalBitMapsHolder bitmaps(num_rows, 3 /*local_bitmaps*/);
+ int64_t bitmap_size = bitmaps.GetLocalBitMapSize();
+
+ auto validity = std::make_shared<arrow::Buffer>(bitmaps.GetLocalBitMap(0), bitmap_size);
+ auto value = std::make_shared<arrow::Buffer>(bitmaps.GetLocalBitMap(1), bitmap_size);
+ auto array_data = arrow::ArrayData::Make(arrow::boolean(), num_rows, {validity, value});
+
+ // Execute the expression(s).
+ ARROW_RETURN_NOT_OK(llvm_generator_->Execute(batch, {array_data}));
+
+ // Compute the intersection of the value and validity.
+ auto result = bitmaps.GetLocalBitMap(2);
+ BitMapAccumulator::IntersectBitMaps(
+ result, {bitmaps.GetLocalBitMap(0), bitmaps.GetLocalBitMap((1))}, {0, 0}, num_rows);
+
+ return out_selection->PopulateFromBitMap(result, bitmap_size, num_rows - 1);
+}
+
+std::string Filter::DumpIR() { return llvm_generator_->DumpIR(); }
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/filter.h b/src/arrow/cpp/src/gandiva/filter.h
new file mode 100644
index 000000000..70ccd7cf0
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/filter.h
@@ -0,0 +1,112 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/status.h"
+
+#include "gandiva/arrow.h"
+#include "gandiva/condition.h"
+#include "gandiva/configuration.h"
+#include "gandiva/selection_vector.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+class LLVMGenerator;
+
+class FilterCacheKey {
+ public:
+ FilterCacheKey(SchemaPtr schema, std::shared_ptr<Configuration> configuration,
+ Expression& expression);
+
+ std::size_t Hash() const { return hash_code_; }
+
+ bool operator==(const FilterCacheKey& other) const;
+
+ bool operator!=(const FilterCacheKey& other) const { return !(*this == other); }
+
+ SchemaPtr schema() const { return schema_; }
+
+ std::string ToString() const;
+
+ private:
+ void UpdateUniqifier(const std::string& expr);
+
+ const SchemaPtr schema_;
+ const std::shared_ptr<Configuration> configuration_;
+ std::string expression_as_string_;
+ size_t hash_code_;
+ uint32_t uniqifier_;
+};
+
+/// \brief filter records based on a condition.
+///
+/// A filter is built for a specific schema and condition. Once the filter is built, it
+/// can be used to evaluate many row batches.
+class GANDIVA_EXPORT Filter {
+ public:
+ Filter(std::unique_ptr<LLVMGenerator> llvm_generator, SchemaPtr schema,
+ std::shared_ptr<Configuration> config);
+
+ // Inline dtor will attempt to resolve the destructor for
+ // LLVMGenerator on MSVC, so we compile the dtor in the object code
+ ~Filter();
+
+ /// Build a filter for the given schema and condition, with the default configuration.
+ ///
+ /// \param[in] schema schema for the record batches, and the condition.
+ /// \param[in] condition filter condition.
+ /// \param[out] filter the returned filter object
+ static Status Make(SchemaPtr schema, ConditionPtr condition,
+ std::shared_ptr<Filter>* filter) {
+ return Make(schema, condition, ConfigurationBuilder::DefaultConfiguration(), filter);
+ }
+
+ /// \brief Build a filter for the given schema and condition.
+ /// Customize the filter with runtime configuration.
+ ///
+ /// \param[in] schema schema for the record batches, and the condition.
+ /// \param[in] condition filter conditions.
+ /// \param[in] config run time configuration.
+ /// \param[out] filter the returned filter object
+ static Status Make(SchemaPtr schema, ConditionPtr condition,
+ std::shared_ptr<Configuration> config,
+ std::shared_ptr<Filter>* filter);
+
+ /// Evaluate the specified record batch, and populate output selection vector.
+ ///
+ /// \param[in] batch the record batch. schema should be the same as the one in 'Make'
+ /// \param[in,out] out_selection the selection array with indices of rows that match
+ /// the condition.
+ Status Evaluate(const arrow::RecordBatch& batch,
+ std::shared_ptr<SelectionVector> out_selection);
+
+ std::string DumpIR();
+
+ private:
+ std::unique_ptr<LLVMGenerator> llvm_generator_;
+ SchemaPtr schema_;
+ std::shared_ptr<Configuration> configuration_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/formatting_utils.h b/src/arrow/cpp/src/gandiva/formatting_utils.h
new file mode 100644
index 000000000..7bc6a4969
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/formatting_utils.h
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/type.h"
+#include "arrow/util/formatting.h"
+#include "arrow/vendored/double-conversion/double-conversion.h"
+
+namespace gandiva {
+
+/// \brief The entry point for conversion to strings.
+template <typename ARROW_TYPE, typename Enable = void>
+class GdvStringFormatter;
+
+using double_conversion::DoubleToStringConverter;
+
+template <typename ARROW_TYPE>
+class FloatToStringGdvMixin
+ : public arrow::internal::FloatToStringFormatterMixin<ARROW_TYPE> {
+ public:
+ using arrow::internal::FloatToStringFormatterMixin<
+ ARROW_TYPE>::FloatToStringFormatterMixin;
+
+ // The mixin is a modified version of the existent FloatToStringFormatterMixin, but
+ // it defines some specific parameters in the FloatToStringFormatterMixin to cast
+ // the float numbers to string using the same patterns like Java.
+ //
+ // The Java real numbers are represented in two ways following these rules:
+ //- If the number is greater or equals than 10^7 and less than 10^(-3)
+ // it will be represented using scientific notation, e.g:
+ // - 0.000012 -> 1.2E-5
+ // - 10000002.3 -> 1.00000023E7
+ //- If the numbers are between that interval above, they are showed as is.
+ explicit FloatToStringGdvMixin(const std::shared_ptr<arrow::DataType>& = NULLPTR)
+ : arrow::internal::FloatToStringFormatterMixin<ARROW_TYPE>(
+ DoubleToStringConverter::EMIT_TRAILING_ZERO_AFTER_POINT |
+ DoubleToStringConverter::EMIT_TRAILING_DECIMAL_POINT,
+ "Infinity", "NaN", 'E', -3, 7, 3, 1) {}
+};
+
+template <>
+class GdvStringFormatter<arrow::FloatType>
+ : public FloatToStringGdvMixin<arrow::FloatType> {
+ public:
+ using FloatToStringGdvMixin::FloatToStringGdvMixin;
+};
+
+template <>
+class GdvStringFormatter<arrow::DoubleType>
+ : public FloatToStringGdvMixin<arrow::DoubleType> {
+ public:
+ using FloatToStringGdvMixin::FloatToStringGdvMixin;
+};
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/func_descriptor.h b/src/arrow/cpp/src/gandiva/func_descriptor.h
new file mode 100644
index 000000000..a2bf3a16b
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/func_descriptor.h
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "gandiva/arrow.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// Descriptor for a function in the expression.
+class GANDIVA_EXPORT FuncDescriptor {
+ public:
+ FuncDescriptor(const std::string& name, const DataTypeVector& params,
+ DataTypePtr return_type)
+ : name_(name), params_(params), return_type_(return_type) {}
+
+ /// base function name.
+ const std::string& name() const { return name_; }
+
+ /// Data types of the input params.
+ const DataTypeVector& params() const { return params_; }
+
+ /// Data type of the return parameter.
+ DataTypePtr return_type() const { return return_type_; }
+
+ private:
+ std::string name_;
+ DataTypeVector params_;
+ DataTypePtr return_type_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_holder.h b/src/arrow/cpp/src/gandiva/function_holder.h
new file mode 100644
index 000000000..e3576f09c
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_holder.h
@@ -0,0 +1,34 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// Holder for a function that can be invoked from LLVM.
+class GANDIVA_EXPORT FunctionHolder {
+ public:
+ virtual ~FunctionHolder() = default;
+};
+
+using FunctionHolderPtr = std::shared_ptr<FunctionHolder>;
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_holder_registry.h b/src/arrow/cpp/src/gandiva/function_holder_registry.h
new file mode 100644
index 000000000..ced153891
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_holder_registry.h
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "arrow/status.h"
+
+#include "gandiva/function_holder.h"
+#include "gandiva/like_holder.h"
+#include "gandiva/node.h"
+#include "gandiva/random_generator_holder.h"
+#include "gandiva/replace_holder.h"
+#include "gandiva/to_date_holder.h"
+
+namespace gandiva {
+
+#define LAMBDA_MAKER(derived) \
+ [](const FunctionNode& node, FunctionHolderPtr* holder) { \
+ std::shared_ptr<derived> derived_instance; \
+ auto status = derived::Make(node, &derived_instance); \
+ if (status.ok()) { \
+ *holder = derived_instance; \
+ } \
+ return status; \
+ }
+
+/// Static registry of function holders.
+class FunctionHolderRegistry {
+ public:
+ using maker_type = std::function<Status(const FunctionNode&, FunctionHolderPtr*)>;
+ using map_type = std::unordered_map<std::string, maker_type>;
+
+ static Status Make(const std::string& name, const FunctionNode& node,
+ FunctionHolderPtr* holder) {
+ auto found = makers().find(name);
+ if (found == makers().end()) {
+ return Status::Invalid("function holder not registered for function " + name);
+ }
+
+ return found->second(node, holder);
+ }
+
+ private:
+ static map_type& makers() {
+ static map_type maker_map = {
+ {"like", LAMBDA_MAKER(LikeHolder)},
+ {"ilike", LAMBDA_MAKER(LikeHolder)},
+ {"to_date", LAMBDA_MAKER(ToDateHolder)},
+ {"random", LAMBDA_MAKER(RandomGeneratorHolder)},
+ {"rand", LAMBDA_MAKER(RandomGeneratorHolder)},
+ {"regexp_replace", LAMBDA_MAKER(ReplaceHolder)},
+ };
+ return maker_map;
+ }
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_ir_builder.cc b/src/arrow/cpp/src/gandiva/function_ir_builder.cc
new file mode 100644
index 000000000..194273933
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_ir_builder.cc
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/function_ir_builder.h"
+
+namespace gandiva {
+
+llvm::Value* FunctionIRBuilder::BuildIfElse(llvm::Value* condition,
+ llvm::Type* return_type,
+ std::function<llvm::Value*()> then_func,
+ std::function<llvm::Value*()> else_func) {
+ llvm::IRBuilder<>* builder = ir_builder();
+ llvm::Function* function = builder->GetInsertBlock()->getParent();
+ DCHECK_NE(function, nullptr);
+
+ // Create blocks for the then, else and merge cases.
+ llvm::BasicBlock* then_bb = llvm::BasicBlock::Create(*context(), "then", function);
+ llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context(), "else", function);
+ llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context(), "merge", function);
+
+ builder->CreateCondBr(condition, then_bb, else_bb);
+
+ // Emit the then block.
+ builder->SetInsertPoint(then_bb);
+ auto then_value = then_func();
+ builder->CreateBr(merge_bb);
+
+ // refresh then_bb for phi (could have changed due to code generation of then_value).
+ then_bb = builder->GetInsertBlock();
+
+ // Emit the else block.
+ builder->SetInsertPoint(else_bb);
+ auto else_value = else_func();
+ builder->CreateBr(merge_bb);
+
+ // refresh else_bb for phi (could have changed due to code generation of else_value).
+ else_bb = builder->GetInsertBlock();
+
+ // Emit the merge block.
+ builder->SetInsertPoint(merge_bb);
+ llvm::PHINode* result_value = builder->CreatePHI(return_type, 2, "res_value");
+ result_value->addIncoming(then_value, then_bb);
+ result_value->addIncoming(else_value, else_bb);
+ return result_value;
+}
+
+llvm::Function* FunctionIRBuilder::BuildFunction(const std::string& function_name,
+ llvm::Type* return_type,
+ std::vector<NamedArg> in_args) {
+ std::vector<llvm::Type*> arg_types;
+ for (auto& arg : in_args) {
+ arg_types.push_back(arg.type);
+ }
+ auto prototype = llvm::FunctionType::get(return_type, arg_types, false /*isVarArg*/);
+ auto function = llvm::Function::Create(prototype, llvm::GlobalValue::ExternalLinkage,
+ function_name, module());
+
+ uint32_t i = 0;
+ for (auto& fn_arg : function->args()) {
+ DCHECK_LT(i, in_args.size());
+ fn_arg.setName(in_args[i].name);
+ ++i;
+ }
+ return function;
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_ir_builder.h b/src/arrow/cpp/src/gandiva/function_ir_builder.h
new file mode 100644
index 000000000..388f55840
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_ir_builder.h
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gandiva/engine.h"
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/llvm_types.h"
+
+namespace gandiva {
+
+/// @brief Base class for building IR functions.
+class FunctionIRBuilder {
+ public:
+ explicit FunctionIRBuilder(Engine* engine) : engine_(engine) {}
+ virtual ~FunctionIRBuilder() = default;
+
+ protected:
+ LLVMTypes* types() { return engine_->types(); }
+ llvm::Module* module() { return engine_->module(); }
+ llvm::LLVMContext* context() { return engine_->context(); }
+ llvm::IRBuilder<>* ir_builder() { return engine_->ir_builder(); }
+
+ /// Build an if-else block.
+ llvm::Value* BuildIfElse(llvm::Value* condition, llvm::Type* return_type,
+ std::function<llvm::Value*()> then_func,
+ std::function<llvm::Value*()> else_func);
+
+ struct NamedArg {
+ std::string name;
+ llvm::Type* type;
+ };
+
+ /// Build llvm fn.
+ llvm::Function* BuildFunction(const std::string& function_name, llvm::Type* return_type,
+ std::vector<NamedArg> in_args);
+
+ private:
+ Engine* engine_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry.cc b/src/arrow/cpp/src/gandiva/function_registry.cc
new file mode 100644
index 000000000..d5d015c10
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry.cc
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/function_registry.h"
+#include "gandiva/function_registry_arithmetic.h"
+#include "gandiva/function_registry_datetime.h"
+#include "gandiva/function_registry_hash.h"
+#include "gandiva/function_registry_math_ops.h"
+#include "gandiva/function_registry_string.h"
+#include "gandiva/function_registry_timestamp_arithmetic.h"
+
+#include <iterator>
+#include <utility>
+#include <vector>
+
+namespace gandiva {
+
+FunctionRegistry::iterator FunctionRegistry::begin() const {
+ return &(*pc_registry_.begin());
+}
+
+FunctionRegistry::iterator FunctionRegistry::end() const {
+ return &(*pc_registry_.end());
+}
+
+FunctionRegistry::iterator FunctionRegistry::back() const {
+ return &(pc_registry_.back());
+}
+
+std::vector<NativeFunction> FunctionRegistry::pc_registry_;
+
+SignatureMap FunctionRegistry::pc_registry_map_ = InitPCMap();
+
+SignatureMap FunctionRegistry::InitPCMap() {
+ SignatureMap map;
+
+ auto v1 = GetArithmeticFunctionRegistry();
+ pc_registry_.insert(std::end(pc_registry_), v1.begin(), v1.end());
+ auto v2 = GetDateTimeFunctionRegistry();
+ pc_registry_.insert(std::end(pc_registry_), v2.begin(), v2.end());
+
+ auto v3 = GetHashFunctionRegistry();
+ pc_registry_.insert(std::end(pc_registry_), v3.begin(), v3.end());
+
+ auto v4 = GetMathOpsFunctionRegistry();
+ pc_registry_.insert(std::end(pc_registry_), v4.begin(), v4.end());
+
+ auto v5 = GetStringFunctionRegistry();
+ pc_registry_.insert(std::end(pc_registry_), v5.begin(), v5.end());
+
+ auto v6 = GetDateTimeArithmeticFunctionRegistry();
+ pc_registry_.insert(std::end(pc_registry_), v6.begin(), v6.end());
+
+ for (auto& elem : pc_registry_) {
+ for (auto& func_signature : elem.signatures()) {
+ map.insert(std::make_pair(&(func_signature), &elem));
+ }
+ }
+
+ return map;
+}
+
+const NativeFunction* FunctionRegistry::LookupSignature(
+ const FunctionSignature& signature) const {
+ auto got = pc_registry_map_.find(&signature);
+ return got == pc_registry_map_.end() ? nullptr : got->second;
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry.h b/src/arrow/cpp/src/gandiva/function_registry.h
new file mode 100644
index 000000000..d92563260
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry.h
@@ -0,0 +1,47 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <vector>
+#include "gandiva/function_registry_common.h"
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/native_function.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+///\brief Registry of pre-compiled IR functions.
+class GANDIVA_EXPORT FunctionRegistry {
+ public:
+ using iterator = const NativeFunction*;
+
+ /// Lookup a pre-compiled function by its signature.
+ const NativeFunction* LookupSignature(const FunctionSignature& signature) const;
+
+ iterator begin() const;
+ iterator end() const;
+ iterator back() const;
+
+ private:
+ static SignatureMap InitPCMap();
+
+ static std::vector<NativeFunction> pc_registry_;
+ static SignatureMap pc_registry_map_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_arithmetic.cc b/src/arrow/cpp/src/gandiva/function_registry_arithmetic.cc
new file mode 100644
index 000000000..f34289f37
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_arithmetic.cc
@@ -0,0 +1,125 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/function_registry_arithmetic.h"
+#include "gandiva/function_registry_common.h"
+
+namespace gandiva {
+
+#define BINARY_SYMMETRIC_FN(name, ALIASES) \
+ NUMERIC_TYPES(BINARY_SYMMETRIC_SAFE_NULL_IF_NULL, name, ALIASES)
+
+#define BINARY_RELATIONAL_BOOL_FN(name, ALIASES) \
+ NUMERIC_BOOL_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, name, ALIASES)
+
+#define BINARY_RELATIONAL_BOOL_DATE_FN(name, ALIASES) \
+ NUMERIC_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, name, ALIASES)
+
+#define UNARY_CAST_TO_FLOAT64(type) UNARY_SAFE_NULL_IF_NULL(castFLOAT8, {}, type, float64)
+
+#define UNARY_CAST_TO_FLOAT32(type) UNARY_SAFE_NULL_IF_NULL(castFLOAT4, {}, type, float32)
+
+#define UNARY_CAST_TO_INT32(type) UNARY_SAFE_NULL_IF_NULL(castINT, {}, type, int32)
+
+#define UNARY_CAST_TO_INT64(type) UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, type, int64)
+
+std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
+ static std::vector<NativeFunction> arithmetic_fn_registry_ = {
+ UNARY_SAFE_NULL_IF_NULL(not, {}, boolean, boolean),
+ UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, int32, int64),
+ UNARY_SAFE_NULL_IF_NULL(castINT, {}, int64, int32),
+ UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, decimal128, int64),
+
+ // cast to float32
+ UNARY_CAST_TO_FLOAT32(int32), UNARY_CAST_TO_FLOAT32(int64),
+ UNARY_CAST_TO_FLOAT32(float64),
+
+ // cast to int32
+ UNARY_CAST_TO_INT32(float32), UNARY_CAST_TO_INT32(float64),
+
+ // cast to int64
+ UNARY_CAST_TO_INT64(float32), UNARY_CAST_TO_INT64(float64),
+
+ // cast to float64
+ UNARY_CAST_TO_FLOAT64(int32), UNARY_CAST_TO_FLOAT64(int64),
+ UNARY_CAST_TO_FLOAT64(float32), UNARY_CAST_TO_FLOAT64(decimal128),
+
+ // cast to decimal
+ UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, int32, decimal128),
+ UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, int64, decimal128),
+ UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, float32, decimal128),
+ UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, float64, decimal128),
+ UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, decimal128, decimal128),
+ UNARY_UNSAFE_NULL_IF_NULL(castDECIMAL, {}, utf8, decimal128),
+
+ NativeFunction("castDECIMALNullOnOverflow", {}, DataTypeVector{decimal128()},
+ decimal128(), kResultNullInternal,
+ "castDECIMALNullOnOverflow_decimal128"),
+
+ UNARY_SAFE_NULL_IF_NULL(castDATE, {}, int64, date64),
+ UNARY_SAFE_NULL_IF_NULL(castDATE, {}, int32, date32),
+ UNARY_SAFE_NULL_IF_NULL(castDATE, {}, date32, date64),
+
+ // add/sub/multiply/divide/mod
+ BINARY_SYMMETRIC_FN(add, {}), BINARY_SYMMETRIC_FN(subtract, {}),
+ BINARY_SYMMETRIC_FN(multiply, {}),
+ NUMERIC_TYPES(BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL, divide, {}),
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, {"modulo"}, int64, int32, int32),
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, {"modulo"}, int64, int64, int64),
+ BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, {"modulo"}, decimal128),
+ BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, {"modulo"}, float64),
+ BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, int32),
+ BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, int64),
+ BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, float32),
+ BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, float64),
+
+ // bitwise operators
+ BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_and, {}, int32),
+ BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_and, {}, int64),
+ BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_or, {}, int32),
+ BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_or, {}, int64),
+ BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_xor, {}, int32),
+ BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_xor, {}, int64),
+ UNARY_SAFE_NULL_IF_NULL(bitwise_not, {}, int32, int32),
+ UNARY_SAFE_NULL_IF_NULL(bitwise_not, {}, int64, int64),
+
+ // round functions
+ UNARY_SAFE_NULL_IF_NULL(round, {}, float32, float32),
+ UNARY_SAFE_NULL_IF_NULL(round, {}, float64, float64),
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(round, {}, float32, int32, float32),
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(round, {}, float64, int32, float64),
+ UNARY_SAFE_NULL_IF_NULL(round, {}, int32, int32),
+ UNARY_SAFE_NULL_IF_NULL(round, {}, int64, int64),
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(round, {}, int32, int32, int32),
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(round, {}, int64, int32, int64),
+
+ // compare functions
+ BINARY_RELATIONAL_BOOL_FN(equal, ({"eq", "same"})),
+ BINARY_RELATIONAL_BOOL_FN(not_equal, {}),
+ BINARY_RELATIONAL_BOOL_DATE_FN(less_than, {}),
+ BINARY_RELATIONAL_BOOL_DATE_FN(less_than_or_equal_to, {}),
+ BINARY_RELATIONAL_BOOL_DATE_FN(greater_than, {}),
+ BINARY_RELATIONAL_BOOL_DATE_FN(greater_than_or_equal_to, {}),
+
+ // binary representation of integer values
+ UNARY_UNSAFE_NULL_IF_NULL(bin, {}, int32, utf8),
+ UNARY_UNSAFE_NULL_IF_NULL(bin, {}, int64, utf8)};
+
+ return arithmetic_fn_registry_;
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_arithmetic.h b/src/arrow/cpp/src/gandiva/function_registry_arithmetic.h
new file mode 100644
index 000000000..693d3b95e
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_arithmetic.h
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <vector>
+#include "gandiva/native_function.h"
+
+namespace gandiva {
+
+std::vector<NativeFunction> GetArithmeticFunctionRegistry();
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_common.h b/src/arrow/cpp/src/gandiva/function_registry_common.h
new file mode 100644
index 000000000..66f945150
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_common.h
@@ -0,0 +1,268 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "gandiva/arrow.h"
+#include "gandiva/function_signature.h"
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/native_function.h"
+
+/* This is a private file, intended for internal use by gandiva & must not be included
+ * directly.
+ */
+namespace gandiva {
+
+using arrow::binary;
+using arrow::boolean;
+using arrow::date32;
+using arrow::date64;
+using arrow::day_time_interval;
+using arrow::float32;
+using arrow::float64;
+using arrow::int16;
+using arrow::int32;
+using arrow::int64;
+using arrow::int8;
+using arrow::month_interval;
+using arrow::uint16;
+using arrow::uint32;
+using arrow::uint64;
+using arrow::uint8;
+using arrow::utf8;
+
+inline DataTypePtr time32() { return arrow::time32(arrow::TimeUnit::MILLI); }
+
+inline DataTypePtr time64() { return arrow::time64(arrow::TimeUnit::MICRO); }
+
+inline DataTypePtr timestamp() { return arrow::timestamp(arrow::TimeUnit::MILLI); }
+inline DataTypePtr decimal128() { return arrow::decimal(38, 0); }
+
+struct KeyHash {
+ std::size_t operator()(const FunctionSignature* k) const { return k->Hash(); }
+};
+
+struct KeyEquals {
+ bool operator()(const FunctionSignature* s1, const FunctionSignature* s2) const {
+ return *s1 == *s2;
+ }
+};
+
+typedef std::unordered_map<const FunctionSignature*, const NativeFunction*, KeyHash,
+ KeyEquals>
+ SignatureMap;
+
+// Binary functions that :
+// - have the same input type for both params
+// - output type is same as the input type
+// - NULL handling is of type NULL_IF_NULL
+//
+// The pre-compiled fn name includes the base name & input type names. eg. add_int32_int32
+#define BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(NAME, ALIASES, TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, \
+ DataTypeVector{TYPE(), TYPE()}, TYPE(), kResultNullIfNull, \
+ ARROW_STRINGIFY(NAME##_##TYPE##_##TYPE))
+
+// Binary functions that :
+// - have the same input type for both params
+// - NULL handling is of type NULL_IINTERNAL
+// - can return error.
+//
+// The pre-compiled fn name includes the base name & input type names. eg. add_int32_int32
+#define BINARY_UNSAFE_NULL_IF_NULL(NAME, ALIASES, IN_TYPE, OUT_TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, \
+ DataTypeVector{IN_TYPE(), IN_TYPE()}, OUT_TYPE(), kResultNullIfNull, \
+ ARROW_STRINGIFY(NAME##_##IN_TYPE##_##IN_TYPE), \
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)
+
+#define BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(NAME, ALIASES, TYPE) \
+ BINARY_UNSAFE_NULL_IF_NULL(NAME, ALIASES, TYPE, TYPE)
+
+// Binary functions that :
+// - have different input types, or output type
+// - NULL handling is of type NULL_IF_NULL
+//
+// The pre-compiled fn name includes the base name & input type names. eg. mod_int64_int32
+#define BINARY_GENERIC_SAFE_NULL_IF_NULL(NAME, ALIASES, IN_TYPE1, IN_TYPE2, OUT_TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, \
+ DataTypeVector{IN_TYPE1(), IN_TYPE2()}, OUT_TYPE(), kResultNullIfNull, \
+ ARROW_STRINGIFY(NAME##_##IN_TYPE1##_##IN_TYPE2))
+
+// Binary functions that :
+// - have the same input type
+// - output type is boolean
+// - NULL handling is of type NULL_IF_NULL
+//
+// The pre-compiled fn name includes the base name & input type names.
+// eg. equal_int32_int32
+#define BINARY_RELATIONAL_SAFE_NULL_IF_NULL(NAME, ALIASES, TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, \
+ DataTypeVector{TYPE(), TYPE()}, boolean(), kResultNullIfNull, \
+ ARROW_STRINGIFY(NAME##_##TYPE##_##TYPE))
+
+// Unary functions that :
+// - NULL handling is of type NULL_IF_NULL
+//
+// The pre-compiled fn name includes the base name & input type name. eg. castFloat_int32
+#define UNARY_SAFE_NULL_IF_NULL(NAME, ALIASES, IN_TYPE, OUT_TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{IN_TYPE()}, \
+ OUT_TYPE(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_##IN_TYPE))
+
+// Unary functions that :
+// - NULL handling is of type NULL_NEVER
+//
+// The pre-compiled fn name includes the base name & input type name. eg. isnull_int32
+#define UNARY_SAFE_NULL_NEVER_BOOL(NAME, ALIASES, TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \
+ boolean(), kResultNullNever, ARROW_STRINGIFY(NAME##_##TYPE))
+
+// Unary functions that :
+// - NULL handling is of type NULL_INTERNAL
+//
+// The pre-compiled fn name includes the base name & input type name. eg. castFloat_int32
+#define UNARY_UNSAFE_NULL_IF_NULL(NAME, ALIASES, IN_TYPE, OUT_TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{IN_TYPE()}, \
+ OUT_TYPE(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_##IN_TYPE), \
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)
+
+// Binary functions that :
+// - NULL handling is of type NULL_NEVER
+//
+// The pre-compiled fn name includes the base name & input type names,
+// eg. is_distinct_from_int32_int32
+#define BINARY_SAFE_NULL_NEVER_BOOL(NAME, ALIASES, TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, \
+ DataTypeVector{TYPE(), TYPE()}, boolean(), kResultNullNever, \
+ ARROW_STRINGIFY(NAME##_##TYPE##_##TYPE))
+
+// Extract functions (used with data/time types) that :
+// - NULL handling is of type NULL_IF_NULL
+//
+// The pre-compiled fn name includes the base name & input type name. eg. extractYear_date
+#define EXTRACT_SAFE_NULL_IF_NULL(NAME, ALIASES, TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \
+ int64(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_##TYPE))
+
+#define TRUNCATE_SAFE_NULL_IF_NULL(NAME, ALIASES, TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \
+ TYPE(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_##TYPE))
+
+// Last day functions (used with data/time types) that :
+// - NULL handling is of type NULL_IF_NULL
+//
+// The pre-compiled fn name includes the base name & input type name. eg:
+// - last_day_from_date64
+#define LAST_DAY_SAFE_NULL_IF_NULL(NAME, ALIASES, TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \
+ date64(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_from_##TYPE))
+
+// Hash32 functions that :
+// - NULL handling is of type NULL_NEVER
+//
+// The pre-compiled fn name includes the base name & input type name. hash32_int8
+#define HASH32_SAFE_NULL_NEVER(NAME, ALIASES, TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \
+ int32(), kResultNullNever, ARROW_STRINGIFY(NAME##_##TYPE))
+
+// Hash32 functions that :
+// - NULL handling is of type NULL_NEVER
+//
+// The pre-compiled fn name includes the base name & input type name. hash32_int8
+#define HASH64_SAFE_NULL_NEVER(NAME, ALIASES, TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \
+ int64(), kResultNullNever, ARROW_STRINGIFY(NAME##_##TYPE))
+
+// Hash32 functions with seed that :
+// - NULL handling is of type NULL_NEVER
+//
+// The pre-compiled fn name includes the base name & input type name. hash32WithSeed_int8
+#define HASH32_SEED_SAFE_NULL_NEVER(NAME, ALIASES, TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, \
+ DataTypeVector{TYPE(), int32()}, int32(), kResultNullNever, \
+ ARROW_STRINGIFY(NAME##WithSeed_##TYPE))
+
+// Hash64 functions with seed that :
+// - NULL handling is of type NULL_NEVER
+//
+// The pre-compiled fn name includes the base name & input type name. hash32WithSeed_int8
+#define HASH64_SEED_SAFE_NULL_NEVER(NAME, ALIASES, TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, \
+ DataTypeVector{TYPE(), int64()}, int64(), kResultNullNever, \
+ ARROW_STRINGIFY(NAME##WithSeed_##TYPE))
+
+// HashSHA1 functions that :
+// - NULL handling is of type NULL_NEVER
+// - can return errors
+//
+// The function name includes the base name & input type name. gdv_fn_sha1_float64
+#define HASH_SHA1_NULL_NEVER(NAME, ALIASES, TYPE) \
+ NativeFunction(#NAME, {"sha", "sha1"}, DataTypeVector{TYPE()}, utf8(), \
+ kResultNullNever, ARROW_STRINGIFY(gdv_fn_sha1_##TYPE), \
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)
+
+// HashSHA256 functions that :
+// - NULL handling is of type NULL_NEVER
+// - can return errors
+//
+// The function name includes the base name & input type name. gdv_fn_sha256_float64
+#define HASH_SHA256_NULL_NEVER(NAME, ALIASES, TYPE) \
+ NativeFunction(#NAME, {"sha256"}, DataTypeVector{TYPE()}, utf8(), kResultNullNever, \
+ ARROW_STRINGIFY(gdv_fn_sha256_##TYPE), \
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)
+
+// Iterate the inner macro over all numeric types
+#define BASE_NUMERIC_TYPES(INNER, NAME, ALIASES) \
+ INNER(NAME, ALIASES, int8), INNER(NAME, ALIASES, int16), INNER(NAME, ALIASES, int32), \
+ INNER(NAME, ALIASES, int64), INNER(NAME, ALIASES, uint8), \
+ INNER(NAME, ALIASES, uint16), INNER(NAME, ALIASES, uint32), \
+ INNER(NAME, ALIASES, uint64), INNER(NAME, ALIASES, float32), \
+ INNER(NAME, ALIASES, float64)
+
+// Iterate the inner macro over all base numeric types
+#define NUMERIC_TYPES(INNER, NAME, ALIASES) \
+ BASE_NUMERIC_TYPES(INNER, NAME, ALIASES), INNER(NAME, ALIASES, decimal128)
+
+// Iterate the inner macro over numeric and date/time types
+#define NUMERIC_DATE_TYPES(INNER, NAME, ALIASES) \
+ NUMERIC_TYPES(INNER, NAME, ALIASES), DATE_TYPES(INNER, NAME, ALIASES), \
+ TIME_TYPES(INNER, NAME, ALIASES), INNER(NAME, ALIASES, date32)
+
+// Iterate the inner macro over all date types
+#define DATE_TYPES(INNER, NAME, ALIASES) \
+ INNER(NAME, ALIASES, date64), INNER(NAME, ALIASES, timestamp)
+
+// Iterate the inner macro over all time types
+#define TIME_TYPES(INNER, NAME, ALIASES) INNER(NAME, ALIASES, time32)
+
+// Iterate the inner macro over all data types
+#define VAR_LEN_TYPES(INNER, NAME, ALIASES) \
+ INNER(NAME, ALIASES, utf8), INNER(NAME, ALIASES, binary)
+
+// Iterate the inner macro over all numeric types, date types and bool type
+#define NUMERIC_BOOL_DATE_TYPES(INNER, NAME, ALIASES) \
+ NUMERIC_DATE_TYPES(INNER, NAME, ALIASES), INNER(NAME, ALIASES, boolean)
+
+// Iterate the inner macro over all numeric types, date types, bool and varlen types
+#define NUMERIC_BOOL_DATE_VAR_LEN_TYPES(INNER, NAME, ALIASES) \
+ NUMERIC_BOOL_DATE_TYPES(INNER, NAME, ALIASES), VAR_LEN_TYPES(INNER, NAME, ALIASES)
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_datetime.cc b/src/arrow/cpp/src/gandiva/function_registry_datetime.cc
new file mode 100644
index 000000000..b8d2e7b6c
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_datetime.cc
@@ -0,0 +1,132 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/function_registry_datetime.h"
+
+#include "gandiva/function_registry_common.h"
+
+namespace gandiva {
+
+#define DATE_EXTRACTION_TRUNCATION_FNS(INNER, name) \
+ DATE_TYPES(INNER, name##Millennium, {}), DATE_TYPES(INNER, name##Century, {}), \
+ DATE_TYPES(INNER, name##Decade, {}), DATE_TYPES(INNER, name##Year, {"year"}), \
+ DATE_TYPES(INNER, name##Quarter, {}), DATE_TYPES(INNER, name##Month, {"month"}), \
+ DATE_TYPES(INNER, name##Week, ({"weekofyear", "yearweek"})), \
+ DATE_TYPES(INNER, name##Day, ({"day", "dayofmonth"})), \
+ DATE_TYPES(INNER, name##Hour, {"hour"}), \
+ DATE_TYPES(INNER, name##Minute, {"minute"}), \
+ DATE_TYPES(INNER, name##Second, {"second"})
+
+#define TO_TIMESTAMP_SAFE_NULL_IF_NULL(NAME, ALIASES, TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \
+ timestamp(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_##TYPE))
+
+#define TO_TIME_SAFE_NULL_IF_NULL(NAME, ALIASES, TYPE) \
+ NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \
+ time32(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_##TYPE))
+
+#define TIME_EXTRACTION_FNS(name) \
+ TIME_TYPES(EXTRACT_SAFE_NULL_IF_NULL, name##Hour, {"hour"}), \
+ TIME_TYPES(EXTRACT_SAFE_NULL_IF_NULL, name##Minute, {"minute"}), \
+ TIME_TYPES(EXTRACT_SAFE_NULL_IF_NULL, name##Second, {"second"})
+
+std::vector<NativeFunction> GetDateTimeFunctionRegistry() {
+ static std::vector<NativeFunction> date_time_fn_registry_ = {
+ DATE_EXTRACTION_TRUNCATION_FNS(EXTRACT_SAFE_NULL_IF_NULL, extract),
+ DATE_EXTRACTION_TRUNCATION_FNS(TRUNCATE_SAFE_NULL_IF_NULL, date_trunc_),
+
+ DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractDoy, {}),
+ DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractDow, {}),
+ DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractEpoch, {}),
+
+ TIME_EXTRACTION_FNS(extract),
+
+ NativeFunction("castDATE", {}, DataTypeVector{utf8()}, date64(), kResultNullIfNull,
+ "castDATE_utf8",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castTIMESTAMP", {}, DataTypeVector{utf8()}, timestamp(),
+ kResultNullIfNull, "castTIMESTAMP_utf8",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castVARCHAR", {}, DataTypeVector{timestamp(), int64()}, utf8(),
+ kResultNullIfNull, "castVARCHAR_timestamp_int64",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("to_date", {}, DataTypeVector{utf8(), utf8()}, date64(),
+ kResultNullInternal, "gdv_fn_to_date_utf8_utf8",
+ NativeFunction::kNeedsContext |
+ NativeFunction::kNeedsFunctionHolder |
+ NativeFunction::kCanReturnErrors),
+
+ NativeFunction("to_date", {}, DataTypeVector{utf8(), utf8(), int32()}, date64(),
+ kResultNullInternal, "gdv_fn_to_date_utf8_utf8_int32",
+ NativeFunction::kNeedsContext |
+ NativeFunction::kNeedsFunctionHolder |
+ NativeFunction::kCanReturnErrors),
+ NativeFunction("castTIMESTAMP", {}, DataTypeVector{date64()}, timestamp(),
+ kResultNullIfNull, "castTIMESTAMP_date64"),
+
+ NativeFunction("castTIMESTAMP", {}, DataTypeVector{int64()}, timestamp(),
+ kResultNullIfNull, "castTIMESTAMP_int64"),
+
+ NativeFunction("castDATE", {"to_date"}, DataTypeVector{timestamp()}, date64(),
+ kResultNullIfNull, "castDATE_timestamp"),
+
+ NativeFunction("castTIME", {}, DataTypeVector{timestamp()}, time32(),
+ kResultNullIfNull, "castTIME_timestamp"),
+
+ NativeFunction("castBIGINT", {}, DataTypeVector{day_time_interval()}, int64(),
+ kResultNullIfNull, "castBIGINT_daytimeinterval"),
+
+ NativeFunction("castINT", {"castNULLABLEINT"}, DataTypeVector{month_interval()},
+ int32(), kResultNullIfNull, "castINT_year_interval",
+ NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castBIGINT", {"castNULLABLEBIGINT"},
+ DataTypeVector{month_interval()}, int64(), kResultNullIfNull,
+ "castBIGINT_year_interval", NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castNULLABLEINTERVALYEAR", {"castINTERVALYEAR"},
+ DataTypeVector{int32()}, month_interval(), kResultNullIfNull,
+ "castNULLABLEINTERVALYEAR_int32",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castNULLABLEINTERVALYEAR", {"castINTERVALYEAR"},
+ DataTypeVector{int64()}, month_interval(), kResultNullIfNull,
+ "castNULLABLEINTERVALYEAR_int64",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castNULLABLEINTERVALDAY", {"castINTERVALDAY"},
+ DataTypeVector{int32()}, day_time_interval(), kResultNullIfNull,
+ "castNULLABLEINTERVALDAY_int32"),
+
+ NativeFunction("castNULLABLEINTERVALDAY", {"castINTERVALDAY"},
+ DataTypeVector{int64()}, day_time_interval(), kResultNullIfNull,
+ "castNULLABLEINTERVALDAY_int64"),
+
+ NativeFunction("extractDay", {}, DataTypeVector{day_time_interval()}, int64(),
+ kResultNullIfNull, "extractDay_daytimeinterval"),
+
+ DATE_TYPES(LAST_DAY_SAFE_NULL_IF_NULL, last_day, {}),
+ BASE_NUMERIC_TYPES(TO_TIME_SAFE_NULL_IF_NULL, to_time, {}),
+ BASE_NUMERIC_TYPES(TO_TIMESTAMP_SAFE_NULL_IF_NULL, to_timestamp, {})};
+
+ return date_time_fn_registry_;
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_datetime.h b/src/arrow/cpp/src/gandiva/function_registry_datetime.h
new file mode 100644
index 000000000..46172ec62
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_datetime.h
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <vector>
+#include "gandiva/native_function.h"
+
+namespace gandiva {
+
+std::vector<NativeFunction> GetDateTimeFunctionRegistry();
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_hash.cc b/src/arrow/cpp/src/gandiva/function_registry_hash.cc
new file mode 100644
index 000000000..7fad9321e
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_hash.cc
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/function_registry_hash.h"
+#include "gandiva/function_registry_common.h"
+
+namespace gandiva {
+
+#define HASH32_SAFE_NULL_NEVER_FN(name, ALIASES) \
+ NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH32_SAFE_NULL_NEVER, name, ALIASES)
+
+#define HASH32_SEED_SAFE_NULL_NEVER_FN(name, ALIASES) \
+ NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH32_SEED_SAFE_NULL_NEVER, name, ALIASES)
+
+#define HASH64_SAFE_NULL_NEVER_FN(name, ALIASES) \
+ NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH64_SAFE_NULL_NEVER, name, ALIASES)
+
+#define HASH64_SEED_SAFE_NULL_NEVER_FN(name, ALIASES) \
+ NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH64_SEED_SAFE_NULL_NEVER, name, ALIASES)
+
+#define HASH_SHA1_NULL_NEVER_FN(name, ALIASES) \
+ NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH_SHA1_NULL_NEVER, name, ALIASES)
+
+#define HASH_SHA256_NULL_NEVER_FN(name, ALIASES) \
+ NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH_SHA256_NULL_NEVER, name, ALIASES)
+
+std::vector<NativeFunction> GetHashFunctionRegistry() {
+ static std::vector<NativeFunction> hash_fn_registry_ = {
+ HASH32_SAFE_NULL_NEVER_FN(hash, {}),
+ HASH32_SAFE_NULL_NEVER_FN(hash32, {}),
+ HASH32_SAFE_NULL_NEVER_FN(hash32AsDouble, {}),
+
+ HASH32_SEED_SAFE_NULL_NEVER_FN(hash32, {}),
+ HASH32_SEED_SAFE_NULL_NEVER_FN(hash32AsDouble, {}),
+
+ HASH64_SAFE_NULL_NEVER_FN(hash64, {}),
+ HASH64_SAFE_NULL_NEVER_FN(hash64AsDouble, {}),
+
+ HASH64_SEED_SAFE_NULL_NEVER_FN(hash64, {}),
+ HASH64_SEED_SAFE_NULL_NEVER_FN(hash64AsDouble, {}),
+
+ HASH_SHA1_NULL_NEVER_FN(hashSHA1, {}),
+
+ HASH_SHA256_NULL_NEVER_FN(hashSHA256, {})};
+
+ return hash_fn_registry_;
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_hash.h b/src/arrow/cpp/src/gandiva/function_registry_hash.h
new file mode 100644
index 000000000..4f96d30cf
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_hash.h
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <vector>
+#include "gandiva/native_function.h"
+
+namespace gandiva {
+
+std::vector<NativeFunction> GetHashFunctionRegistry();
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_math_ops.cc b/src/arrow/cpp/src/gandiva/function_registry_math_ops.cc
new file mode 100644
index 000000000..49afd4003
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_math_ops.cc
@@ -0,0 +1,106 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/function_registry_math_ops.h"
+#include "gandiva/function_registry_common.h"
+
+namespace gandiva {
+
+#define MATH_UNARY_OPS(name, ALIASES) \
+ UNARY_SAFE_NULL_IF_NULL(name, ALIASES, int32, float64), \
+ UNARY_SAFE_NULL_IF_NULL(name, ALIASES, int64, float64), \
+ UNARY_SAFE_NULL_IF_NULL(name, ALIASES, uint32, float64), \
+ UNARY_SAFE_NULL_IF_NULL(name, ALIASES, uint64, float64), \
+ UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float32, float64), \
+ UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float64, float64)
+
+#define MATH_BINARY_UNSAFE(name, ALIASES) \
+ BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, int32, float64), \
+ BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, int64, float64), \
+ BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, uint32, float64), \
+ BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, uint64, float64), \
+ BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, float32, float64), \
+ BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, float64, float64)
+
+#define MATH_BINARY_SAFE(name, ALIASES) \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int32, int32, float64), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int64, int64, float64), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, uint32, uint32, float64), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, uint64, uint64, float64), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, float32, float32, float64), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, float64, float64, float64)
+
+#define UNARY_SAFE_NULL_NEVER_BOOL_FN(name, ALIASES) \
+ NUMERIC_BOOL_DATE_TYPES(UNARY_SAFE_NULL_NEVER_BOOL, name, ALIASES)
+
+#define BINARY_SAFE_NULL_NEVER_BOOL_FN(name, ALIASES) \
+ NUMERIC_BOOL_DATE_TYPES(BINARY_SAFE_NULL_NEVER_BOOL, name, ALIASES)
+
+std::vector<NativeFunction> GetMathOpsFunctionRegistry() {
+ static std::vector<NativeFunction> math_fn_registry_ = {
+ MATH_UNARY_OPS(cbrt, {}), MATH_UNARY_OPS(exp, {}), MATH_UNARY_OPS(log, {}),
+ MATH_UNARY_OPS(log10, {}),
+
+ MATH_BINARY_UNSAFE(log, {}),
+
+ BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(power, {"pow"}, float64),
+
+ UNARY_SAFE_NULL_NEVER_BOOL_FN(isnull, {}),
+ UNARY_SAFE_NULL_NEVER_BOOL_FN(isnotnull, {}),
+
+ NUMERIC_TYPES(UNARY_SAFE_NULL_NEVER_BOOL, isnumeric, {}),
+
+ BINARY_SAFE_NULL_NEVER_BOOL_FN(is_distinct_from, {}),
+ BINARY_SAFE_NULL_NEVER_BOOL_FN(is_not_distinct_from, {}),
+
+ // trigonometry functions
+ MATH_UNARY_OPS(sin, {}), MATH_UNARY_OPS(cos, {}), MATH_UNARY_OPS(asin, {}),
+ MATH_UNARY_OPS(acos, {}), MATH_UNARY_OPS(tan, {}), MATH_UNARY_OPS(atan, {}),
+ MATH_UNARY_OPS(sinh, {}), MATH_UNARY_OPS(cosh, {}), MATH_UNARY_OPS(tanh, {}),
+ MATH_UNARY_OPS(cot, {}), MATH_UNARY_OPS(radians, {}), MATH_UNARY_OPS(degrees, {}),
+ MATH_BINARY_SAFE(atan2, {}),
+
+ // decimal functions
+ UNARY_SAFE_NULL_IF_NULL(abs, {}, decimal128, decimal128),
+ UNARY_SAFE_NULL_IF_NULL(ceil, {}, decimal128, decimal128),
+ UNARY_SAFE_NULL_IF_NULL(floor, {}, decimal128, decimal128),
+ UNARY_SAFE_NULL_IF_NULL(round, {}, decimal128, decimal128),
+ UNARY_SAFE_NULL_IF_NULL(truncate, {"trunc"}, decimal128, decimal128),
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(round, {}, decimal128, int32, decimal128),
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(truncate, {"trunc"}, decimal128, int32,
+ decimal128),
+
+ NativeFunction("truncate", {"trunc"}, DataTypeVector{int64(), int32()}, int64(),
+ kResultNullIfNull, "truncate_int64_int32"),
+ NativeFunction("random", {"rand"}, DataTypeVector{}, float64(), kResultNullNever,
+ "gdv_fn_random", NativeFunction::kNeedsFunctionHolder),
+ NativeFunction("random", {"rand"}, DataTypeVector{int32()}, float64(),
+ kResultNullNever, "gdv_fn_random_with_seed",
+ NativeFunction::kNeedsFunctionHolder)};
+
+ return math_fn_registry_;
+}
+
+#undef MATH_UNARY_OPS
+
+#undef MATH_BINARY_UNSAFE
+
+#undef UNARY_SAFE_NULL_NEVER_BOOL_FN
+
+#undef BINARY_SAFE_NULL_NEVER_BOOL_FN
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_math_ops.h b/src/arrow/cpp/src/gandiva/function_registry_math_ops.h
new file mode 100644
index 000000000..2c8a40d53
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_math_ops.h
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <vector>
+#include "gandiva/native_function.h"
+
+namespace gandiva {
+
+std::vector<NativeFunction> GetMathOpsFunctionRegistry();
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_string.cc b/src/arrow/cpp/src/gandiva/function_registry_string.cc
new file mode 100644
index 000000000..3ea426c85
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_string.cc
@@ -0,0 +1,422 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/function_registry_string.h"
+
+#include "gandiva/function_registry_common.h"
+
+namespace gandiva {
+
+#define BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN(name, ALIASES) \
+ VAR_LEN_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, name, ALIASES)
+
+#define BINARY_RELATIONAL_SAFE_NULL_IF_NULL_UTF8_FN(name, ALIASES) \
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL(name, ALIASES, utf8)
+
+#define UNARY_OCTET_LEN_FN(name, ALIASES) \
+ UNARY_SAFE_NULL_IF_NULL(name, ALIASES, utf8, int32), \
+ UNARY_SAFE_NULL_IF_NULL(name, ALIASES, binary, int32)
+
+#define UNARY_SAFE_NULL_NEVER_BOOL_FN(name, ALIASES) \
+ VAR_LEN_TYPES(UNARY_SAFE_NULL_NEVER_BOOL, name, ALIASES)
+
+std::vector<NativeFunction> GetStringFunctionRegistry() {
+ static std::vector<NativeFunction> string_fn_registry_ = {
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN(equal, {}),
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN(not_equal, {}),
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN(less_than, {}),
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN(less_than_or_equal_to, {}),
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN(greater_than, {}),
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN(greater_than_or_equal_to, {}),
+
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL_UTF8_FN(starts_with, {}),
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL_UTF8_FN(ends_with, {}),
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL_UTF8_FN(is_substr, {}),
+
+ BINARY_UNSAFE_NULL_IF_NULL(locate, {"position"}, utf8, int32),
+ BINARY_UNSAFE_NULL_IF_NULL(strpos, {}, utf8, int32),
+
+ UNARY_OCTET_LEN_FN(octet_length, {}), UNARY_OCTET_LEN_FN(bit_length, {}),
+
+ UNARY_UNSAFE_NULL_IF_NULL(char_length, {}, utf8, int32),
+ UNARY_UNSAFE_NULL_IF_NULL(length, {}, utf8, int32),
+ UNARY_UNSAFE_NULL_IF_NULL(lengthUtf8, {}, binary, int32),
+ UNARY_UNSAFE_NULL_IF_NULL(reverse, {}, utf8, utf8),
+ UNARY_UNSAFE_NULL_IF_NULL(ltrim, {}, utf8, utf8),
+ UNARY_UNSAFE_NULL_IF_NULL(rtrim, {}, utf8, utf8),
+ UNARY_UNSAFE_NULL_IF_NULL(btrim, {}, utf8, utf8),
+ UNARY_UNSAFE_NULL_IF_NULL(space, {}, int32, utf8),
+ UNARY_UNSAFE_NULL_IF_NULL(space, {}, int64, utf8),
+
+ UNARY_SAFE_NULL_NEVER_BOOL_FN(isnull, {}),
+ UNARY_SAFE_NULL_NEVER_BOOL_FN(isnotnull, {}),
+
+ NativeFunction("ascii", {}, DataTypeVector{utf8()}, int32(), kResultNullIfNull,
+ "ascii_utf8"),
+
+ NativeFunction("base64", {}, DataTypeVector{binary()}, utf8(), kResultNullIfNull,
+ "gdv_fn_base64_encode_binary", NativeFunction::kNeedsContext),
+
+ NativeFunction("unbase64", {}, DataTypeVector{utf8()}, binary(), kResultNullIfNull,
+ "gdv_fn_base64_decode_utf8", NativeFunction::kNeedsContext),
+
+ NativeFunction("repeat", {}, DataTypeVector{utf8(), int32()}, utf8(),
+ kResultNullIfNull, "repeat_utf8_int32",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("upper", {}, DataTypeVector{utf8()}, utf8(), kResultNullIfNull,
+ "gdv_fn_upper_utf8", NativeFunction::kNeedsContext),
+
+ NativeFunction("lower", {}, DataTypeVector{utf8()}, utf8(), kResultNullIfNull,
+ "gdv_fn_lower_utf8", NativeFunction::kNeedsContext),
+
+ NativeFunction("initcap", {}, DataTypeVector{utf8()}, utf8(), kResultNullIfNull,
+ "gdv_fn_initcap_utf8",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castBIT", {"castBOOLEAN"}, DataTypeVector{utf8()}, boolean(),
+ kResultNullIfNull, "castBIT_utf8", NativeFunction::kNeedsContext),
+
+ NativeFunction("castINT", {}, DataTypeVector{utf8()}, int32(), kResultNullIfNull,
+ "gdv_fn_castINT_utf8",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castBIGINT", {}, DataTypeVector{utf8()}, int64(), kResultNullIfNull,
+ "gdv_fn_castBIGINT_utf8",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castFLOAT4", {}, DataTypeVector{utf8()}, float32(),
+ kResultNullIfNull, "gdv_fn_castFLOAT4_utf8",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castFLOAT8", {}, DataTypeVector{utf8()}, float64(),
+ kResultNullIfNull, "gdv_fn_castFLOAT8_utf8",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castINT", {}, DataTypeVector{binary()}, int32(), kResultNullIfNull,
+ "gdv_fn_castINT_varbinary",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castBIGINT", {}, DataTypeVector{binary()}, int64(),
+ kResultNullIfNull, "gdv_fn_castBIGINT_varbinary",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castFLOAT4", {}, DataTypeVector{binary()}, float32(),
+ kResultNullIfNull, "gdv_fn_castFLOAT4_varbinary",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castFLOAT8", {}, DataTypeVector{binary()}, float64(),
+ kResultNullIfNull, "gdv_fn_castFLOAT8_varbinary",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castVARCHAR", {}, DataTypeVector{boolean(), int64()}, utf8(),
+ kResultNullIfNull, "castVARCHAR_bool_int64",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("castVARCHAR", {}, DataTypeVector{utf8(), int64()}, utf8(),
+ kResultNullIfNull, "castVARCHAR_utf8_int64",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("castVARCHAR", {}, DataTypeVector{binary(), int64()}, utf8(),
+ kResultNullIfNull, "castVARCHAR_binary_int64",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("castVARCHAR", {}, DataTypeVector{int32(), int64()}, utf8(),
+ kResultNullIfNull, "gdv_fn_castVARCHAR_int32_int64",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castVARCHAR", {}, DataTypeVector{int64(), int64()}, utf8(),
+ kResultNullIfNull, "gdv_fn_castVARCHAR_int64_int64",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castVARCHAR", {}, DataTypeVector{float32(), int64()}, utf8(),
+ kResultNullIfNull, "gdv_fn_castVARCHAR_float32_int64",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castVARCHAR", {}, DataTypeVector{float64(), int64()}, utf8(),
+ kResultNullIfNull, "gdv_fn_castVARCHAR_float64_int64",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castVARCHAR", {}, DataTypeVector{decimal128(), int64()}, utf8(),
+ kResultNullIfNull, "castVARCHAR_decimal128_int64",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("like", {}, DataTypeVector{utf8(), utf8()}, boolean(),
+ kResultNullIfNull, "gdv_fn_like_utf8_utf8",
+ NativeFunction::kNeedsFunctionHolder),
+
+ NativeFunction("like", {}, DataTypeVector{utf8(), utf8(), utf8()}, boolean(),
+ kResultNullIfNull, "gdv_fn_like_utf8_utf8_utf8",
+ NativeFunction::kNeedsFunctionHolder),
+
+ NativeFunction("ilike", {}, DataTypeVector{utf8(), utf8()}, boolean(),
+ kResultNullIfNull, "gdv_fn_ilike_utf8_utf8",
+ NativeFunction::kNeedsFunctionHolder),
+
+ NativeFunction("ltrim", {}, DataTypeVector{utf8(), utf8()}, utf8(),
+ kResultNullIfNull, "ltrim_utf8_utf8", NativeFunction::kNeedsContext),
+
+ NativeFunction("rtrim", {}, DataTypeVector{utf8(), utf8()}, utf8(),
+ kResultNullIfNull, "rtrim_utf8_utf8", NativeFunction::kNeedsContext),
+
+ NativeFunction("btrim", {}, DataTypeVector{utf8(), utf8()}, utf8(),
+ kResultNullIfNull, "btrim_utf8_utf8", NativeFunction::kNeedsContext),
+
+ NativeFunction("substr", {"substring"},
+ DataTypeVector{utf8(), int64() /*offset*/, int64() /*length*/},
+ utf8(), kResultNullIfNull, "substr_utf8_int64_int64",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("substr", {"substring"}, DataTypeVector{utf8(), int64() /*offset*/},
+ utf8(), kResultNullIfNull, "substr_utf8_int64",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("lpad", {}, DataTypeVector{utf8(), int32(), utf8()}, utf8(),
+ kResultNullIfNull, "lpad_utf8_int32_utf8",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("lpad", {}, DataTypeVector{utf8(), int32()}, utf8(),
+ kResultNullIfNull, "lpad_utf8_int32", NativeFunction::kNeedsContext),
+
+ NativeFunction("rpad", {}, DataTypeVector{utf8(), int32(), utf8()}, utf8(),
+ kResultNullIfNull, "rpad_utf8_int32_utf8",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("rpad", {}, DataTypeVector{utf8(), int32()}, utf8(),
+ kResultNullIfNull, "rpad_utf8_int32", NativeFunction::kNeedsContext),
+
+ NativeFunction("regexp_replace", {}, DataTypeVector{utf8(), utf8(), utf8()}, utf8(),
+ kResultNullIfNull, "gdv_fn_regexp_replace_utf8_utf8",
+ NativeFunction::kNeedsContext |
+ NativeFunction::kNeedsFunctionHolder |
+ NativeFunction::kCanReturnErrors),
+
+ NativeFunction("concatOperator", {}, DataTypeVector{utf8(), utf8()}, utf8(),
+ kResultNullIfNull, "concatOperator_utf8_utf8",
+ NativeFunction::kNeedsContext),
+ NativeFunction("concatOperator", {}, DataTypeVector{utf8(), utf8(), utf8()}, utf8(),
+ kResultNullIfNull, "concatOperator_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+ NativeFunction("concatOperator", {}, DataTypeVector{utf8(), utf8(), utf8(), utf8()},
+ utf8(), kResultNullIfNull, "concatOperator_utf8_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+ NativeFunction("concatOperator", {},
+ DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8()}, utf8(),
+ kResultNullIfNull, "concatOperator_utf8_utf8_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+ NativeFunction("concatOperator", {},
+ DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8()},
+ utf8(), kResultNullIfNull,
+ "concatOperator_utf8_utf8_utf8_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+ NativeFunction(
+ "concatOperator", {},
+ DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), utf8()}, utf8(),
+ kResultNullIfNull, "concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+ NativeFunction(
+ "concatOperator", {},
+ DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), utf8()},
+ utf8(), kResultNullIfNull,
+ "concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+ NativeFunction("concatOperator", {},
+ DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(),
+ utf8(), utf8(), utf8()},
+ utf8(), kResultNullIfNull,
+ "concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+ NativeFunction("concatOperator", {},
+ DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(),
+ utf8(), utf8(), utf8(), utf8()},
+ utf8(), kResultNullIfNull,
+ "concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+
+ // concat treats null inputs as empty strings whereas concatOperator returns null if
+ // one of the inputs is null
+ NativeFunction("concat", {}, DataTypeVector{utf8(), utf8()}, utf8(),
+ kResultNullNever, "concat_utf8_utf8", NativeFunction::kNeedsContext),
+ NativeFunction("concat", {}, DataTypeVector{utf8(), utf8(), utf8()}, utf8(),
+ kResultNullNever, "concat_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+ NativeFunction("concat", {}, DataTypeVector{utf8(), utf8(), utf8(), utf8()}, utf8(),
+ kResultNullNever, "concat_utf8_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+ NativeFunction("concat", {}, DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8()},
+ utf8(), kResultNullNever, "concat_utf8_utf8_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+ NativeFunction("concat", {},
+ DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8()},
+ utf8(), kResultNullNever, "concat_utf8_utf8_utf8_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+ NativeFunction(
+ "concat", {},
+ DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), utf8()}, utf8(),
+ kResultNullNever, "concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+ NativeFunction(
+ "concat", {},
+ DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), utf8()},
+ utf8(), kResultNullNever, "concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+ NativeFunction("concat", {},
+ DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(),
+ utf8(), utf8(), utf8()},
+ utf8(), kResultNullNever,
+ "concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+ NativeFunction("concat", {},
+ DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(),
+ utf8(), utf8(), utf8(), utf8()},
+ utf8(), kResultNullNever,
+ "concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("byte_substr", {"bytesubstring"},
+ DataTypeVector{binary(), int32(), int32()}, binary(),
+ kResultNullIfNull, "byte_substr_binary_int32_int32",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_fromUTF8", {"convert_fromutf8"}, DataTypeVector{binary()},
+ utf8(), kResultNullIfNull, "convert_fromUTF8_binary",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_replaceUTF8", {"convert_replaceutf8"},
+ DataTypeVector{binary(), utf8()}, utf8(), kResultNullIfNull,
+ "convert_replace_invalid_fromUTF8_binary",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toDOUBLE", {}, DataTypeVector{float64()}, binary(),
+ kResultNullIfNull, "convert_toDOUBLE",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toDOUBLE_be", {}, DataTypeVector{float64()}, binary(),
+ kResultNullIfNull, "convert_toDOUBLE_be",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toFLOAT", {}, DataTypeVector{float32()}, binary(),
+ kResultNullIfNull, "convert_toFLOAT", NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toFLOAT_be", {}, DataTypeVector{float32()}, binary(),
+ kResultNullIfNull, "convert_toFLOAT_be",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toINT", {}, DataTypeVector{int32()}, binary(),
+ kResultNullIfNull, "convert_toINT", NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toINT_be", {}, DataTypeVector{int32()}, binary(),
+ kResultNullIfNull, "convert_toINT_be",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toBIGINT", {}, DataTypeVector{int64()}, binary(),
+ kResultNullIfNull, "convert_toBIGINT",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toBIGINT_be", {}, DataTypeVector{int64()}, binary(),
+ kResultNullIfNull, "convert_toBIGINT_be",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toBOOLEAN_BYTE", {}, DataTypeVector{boolean()}, binary(),
+ kResultNullIfNull, "convert_toBOOLEAN",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toTIME_EPOCH", {}, DataTypeVector{time32()}, binary(),
+ kResultNullIfNull, "convert_toTIME_EPOCH",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toTIME_EPOCH_be", {}, DataTypeVector{time32()}, binary(),
+ kResultNullIfNull, "convert_toTIME_EPOCH_be",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toTIMESTAMP_EPOCH", {}, DataTypeVector{timestamp()},
+ binary(), kResultNullIfNull, "convert_toTIMESTAMP_EPOCH",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toTIMESTAMP_EPOCH_be", {}, DataTypeVector{timestamp()},
+ binary(), kResultNullIfNull, "convert_toTIMESTAMP_EPOCH_be",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toDATE_EPOCH", {}, DataTypeVector{date64()}, binary(),
+ kResultNullIfNull, "convert_toDATE_EPOCH",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toDATE_EPOCH_be", {}, DataTypeVector{date64()}, binary(),
+ kResultNullIfNull, "convert_toDATE_EPOCH_be",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("convert_toUTF8", {}, DataTypeVector{utf8()}, binary(),
+ kResultNullIfNull, "convert_toUTF8", NativeFunction::kNeedsContext),
+
+ NativeFunction("locate", {"position"}, DataTypeVector{utf8(), utf8(), int32()},
+ int32(), kResultNullIfNull, "locate_utf8_utf8_int32",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("replace", {}, DataTypeVector{utf8(), utf8(), utf8()}, utf8(),
+ kResultNullIfNull, "replace_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("binary_string", {}, DataTypeVector{utf8()}, binary(),
+ kResultNullIfNull, "binary_string", NativeFunction::kNeedsContext),
+
+ NativeFunction("left", {}, DataTypeVector{utf8(), int32()}, utf8(),
+ kResultNullIfNull, "left_utf8_int32", NativeFunction::kNeedsContext),
+
+ NativeFunction("right", {}, DataTypeVector{utf8(), int32()}, utf8(),
+ kResultNullIfNull, "right_utf8_int32",
+ NativeFunction::kNeedsContext),
+
+ NativeFunction("castVARBINARY", {}, DataTypeVector{binary(), int64()}, binary(),
+ kResultNullIfNull, "castVARBINARY_binary_int64",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castVARBINARY", {}, DataTypeVector{utf8(), int64()}, binary(),
+ kResultNullIfNull, "castVARBINARY_utf8_int64",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castVARBINARY", {}, DataTypeVector{int32(), int64()}, binary(),
+ kResultNullIfNull, "gdv_fn_castVARBINARY_int32_int64",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castVARBINARY", {}, DataTypeVector{int64(), int64()}, binary(),
+ kResultNullIfNull, "gdv_fn_castVARBINARY_int64_int64",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castVARBINARY", {}, DataTypeVector{float32(), int64()}, binary(),
+ kResultNullIfNull, "gdv_fn_castVARBINARY_float32_int64",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castVARBINARY", {}, DataTypeVector{float64(), int64()}, binary(),
+ kResultNullIfNull, "gdv_fn_castVARBINARY_float64_int64",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("split_part", {}, DataTypeVector{utf8(), utf8(), int32()}, utf8(),
+ kResultNullIfNull, "split_part",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)};
+
+ return string_fn_registry_;
+}
+
+#undef BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN
+
+#undef BINARY_RELATIONAL_SAFE_NULL_IF_NULL_UTF8_FN
+
+#undef UNARY_OCTET_LEN_FN
+
+#undef UNARY_SAFE_NULL_NEVER_BOOL_FN
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_string.h b/src/arrow/cpp/src/gandiva/function_registry_string.h
new file mode 100644
index 000000000..f14c95a81
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_string.h
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <vector>
+#include "gandiva/native_function.h"
+
+namespace gandiva {
+
+std::vector<NativeFunction> GetStringFunctionRegistry();
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_test.cc b/src/arrow/cpp/src/gandiva/function_registry_test.cc
new file mode 100644
index 000000000..e3c1e85f7
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_test.cc
@@ -0,0 +1,96 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/function_registry.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <string>
+#include <unordered_set>
+
+namespace gandiva {
+
+class TestFunctionRegistry : public ::testing::Test {
+ protected:
+ FunctionRegistry registry_;
+};
+
+TEST_F(TestFunctionRegistry, TestFound) {
+ FunctionSignature add_i32_i32("add", {arrow::int32(), arrow::int32()}, arrow::int32());
+
+ const NativeFunction* function = registry_.LookupSignature(add_i32_i32);
+ EXPECT_NE(function, nullptr);
+ EXPECT_THAT(function->signatures(), testing::Contains(add_i32_i32));
+ EXPECT_EQ(function->pc_name(), "add_int32_int32");
+}
+
+TEST_F(TestFunctionRegistry, TestNotFound) {
+ FunctionSignature addX_i32_i32("addX", {arrow::int32(), arrow::int32()},
+ arrow::int32());
+ EXPECT_EQ(registry_.LookupSignature(addX_i32_i32), nullptr);
+
+ FunctionSignature add_i32_i32_ret64("add", {arrow::int32(), arrow::int32()},
+ arrow::int64());
+ EXPECT_EQ(registry_.LookupSignature(add_i32_i32_ret64), nullptr);
+}
+
+// one nativefunction object per precompiled function
+TEST_F(TestFunctionRegistry, TestNoDuplicates) {
+ std::unordered_set<std::string> pc_func_sigs;
+ std::unordered_set<std::string> native_func_duplicates;
+ std::unordered_set<std::string> func_sigs;
+ std::unordered_set<std::string> func_sig_duplicates;
+ for (auto native_func_it = registry_.begin(); native_func_it != registry_.end();
+ ++native_func_it) {
+ auto& first_sig = native_func_it->signatures().front();
+ auto pc_func_sig = FunctionSignature(native_func_it->pc_name(),
+ first_sig.param_types(), first_sig.ret_type())
+ .ToString();
+ if (pc_func_sigs.count(pc_func_sig) == 0) {
+ pc_func_sigs.insert(pc_func_sig);
+ } else {
+ native_func_duplicates.insert(pc_func_sig);
+ }
+
+ for (auto& sig : native_func_it->signatures()) {
+ auto sig_str = sig.ToString();
+ if (func_sigs.count(sig_str) == 0) {
+ func_sigs.insert(sig_str);
+ } else {
+ func_sig_duplicates.insert(sig_str);
+ }
+ }
+ }
+ std::ostringstream stream;
+ std::copy(native_func_duplicates.begin(), native_func_duplicates.end(),
+ std::ostream_iterator<std::string>(stream, "\n"));
+ std::string result = stream.str();
+ EXPECT_TRUE(native_func_duplicates.empty())
+ << "Registry has duplicates.\nMultiple NativeFunction objects refer to the "
+ "following precompiled functions:\n"
+ << result;
+
+ stream.clear();
+ std::copy(func_sig_duplicates.begin(), func_sig_duplicates.end(),
+ std::ostream_iterator<std::string>(stream, "\n"));
+ EXPECT_TRUE(func_sig_duplicates.empty())
+ << "The following signatures are defined more than once possibly pointing to "
+ "different precompiled functions:\n"
+ << stream.str();
+}
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_timestamp_arithmetic.cc b/src/arrow/cpp/src/gandiva/function_registry_timestamp_arithmetic.cc
new file mode 100644
index 000000000..c277dab72
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_timestamp_arithmetic.cc
@@ -0,0 +1,89 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/function_registry_timestamp_arithmetic.h"
+
+#include "gandiva/function_registry_common.h"
+
+namespace gandiva {
+
+#define TIMESTAMP_ADD_FNS(name, ALIASES) \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int32, timestamp, timestamp), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int32, date64, date64), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int64, timestamp, timestamp), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int64, date64, date64), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, timestamp, int32, timestamp), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, date64, int32, date64), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, timestamp, int64, timestamp), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, date64, int64, date64)
+
+#define TIMESTAMP_DIFF_FN(name, ALIASES) \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, timestamp, timestamp, int32)
+
+#define DATE_ADD_FNS(name, ALIASES) \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, date64, int32, date64), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, timestamp, int32, timestamp), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, date64, int64, date64), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, timestamp, int64, timestamp), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int32, date64, date64), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int32, timestamp, timestamp), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int64, date64, date64), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int64, timestamp, timestamp)
+
+#define DATE_DIFF_FNS(name, ALIASES) \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, date64, int32, date64), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, timestamp, int32, timestamp), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, date64, int64, date64), \
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, timestamp, int64, timestamp)
+
+std::vector<NativeFunction> GetDateTimeArithmeticFunctionRegistry() {
+ static std::vector<NativeFunction> datetime_fn_registry_ = {
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(months_between, {}, date64, date64, float64),
+ BINARY_GENERIC_SAFE_NULL_IF_NULL(months_between, {}, timestamp, timestamp, float64),
+
+ TIMESTAMP_DIFF_FN(timestampdiffSecond, {}),
+ TIMESTAMP_DIFF_FN(timestampdiffMinute, {}),
+ TIMESTAMP_DIFF_FN(timestampdiffHour, {}),
+ TIMESTAMP_DIFF_FN(timestampdiffDay, {"datediff"}),
+ TIMESTAMP_DIFF_FN(timestampdiffWeek, {}),
+ TIMESTAMP_DIFF_FN(timestampdiffMonth, {}),
+ TIMESTAMP_DIFF_FN(timestampdiffQuarter, {}),
+ TIMESTAMP_DIFF_FN(timestampdiffYear, {}),
+
+ TIMESTAMP_ADD_FNS(timestampaddSecond, {}),
+ TIMESTAMP_ADD_FNS(timestampaddMinute, {}),
+ TIMESTAMP_ADD_FNS(timestampaddHour, {}),
+ TIMESTAMP_ADD_FNS(timestampaddDay, {}),
+ TIMESTAMP_ADD_FNS(timestampaddWeek, {}),
+ TIMESTAMP_ADD_FNS(timestampaddMonth, {"add_months"}),
+ TIMESTAMP_ADD_FNS(timestampaddQuarter, {}),
+ TIMESTAMP_ADD_FNS(timestampaddYear, {}),
+
+ DATE_ADD_FNS(date_add, {}),
+ DATE_ADD_FNS(add, {}),
+
+ NativeFunction("add", {}, DataTypeVector{date64(), int64()}, timestamp(),
+ kResultNullIfNull, "add_date64_int64"),
+
+ DATE_DIFF_FNS(date_sub, {}),
+ DATE_DIFF_FNS(subtract, {}),
+ DATE_DIFF_FNS(date_diff, {})};
+
+ return datetime_fn_registry_;
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_timestamp_arithmetic.h b/src/arrow/cpp/src/gandiva/function_registry_timestamp_arithmetic.h
new file mode 100644
index 000000000..9ac3ab2ec
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_timestamp_arithmetic.h
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <vector>
+#include "gandiva/native_function.h"
+
+namespace gandiva {
+
+std::vector<NativeFunction> GetDateTimeArithmeticFunctionRegistry();
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_signature.cc b/src/arrow/cpp/src/gandiva/function_signature.cc
new file mode 100644
index 000000000..6dc641617
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_signature.cc
@@ -0,0 +1,113 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/function_signature.h"
+
+#include <cstddef>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/hash_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string.h"
+
+using arrow::internal::AsciiEqualsCaseInsensitive;
+using arrow::internal::AsciiToLower;
+using arrow::internal::checked_cast;
+using arrow::internal::hash_combine;
+
+namespace gandiva {
+
+bool DataTypeEquals(const DataTypePtr& left, const DataTypePtr& right) {
+ if (left->id() == right->id()) {
+ switch (left->id()) {
+ case arrow::Type::DECIMAL: {
+ // For decimal types, the precision/scale isn't part of the signature.
+ auto dleft = checked_cast<arrow::DecimalType*>(left.get());
+ auto dright = checked_cast<arrow::DecimalType*>(right.get());
+ return (dleft != NULL) && (dright != NULL) &&
+ (dleft->byte_width() == dright->byte_width());
+ }
+ default:
+ return left->Equals(right);
+ }
+ } else {
+ return false;
+ }
+}
+
+FunctionSignature::FunctionSignature(std::string base_name, DataTypeVector param_types,
+ DataTypePtr ret_type)
+ : base_name_(std::move(base_name)),
+ param_types_(std::move(param_types)),
+ ret_type_(std::move(ret_type)) {
+ DCHECK_GT(base_name_.length(), 0);
+ for (auto it = param_types_.begin(); it != param_types_.end(); it++) {
+ DCHECK(*it);
+ }
+ DCHECK(ret_type_);
+}
+
+bool FunctionSignature::operator==(const FunctionSignature& other) const {
+ if (param_types_.size() != other.param_types_.size() ||
+ !DataTypeEquals(ret_type_, other.ret_type_) ||
+ !AsciiEqualsCaseInsensitive(base_name_, other.base_name_)) {
+ return false;
+ }
+
+ for (size_t idx = 0; idx < param_types_.size(); idx++) {
+ if (!DataTypeEquals(param_types_[idx], other.param_types_[idx])) {
+ return false;
+ }
+ }
+ return true;
+}
+
+/// calculated based on name, datatype id of parameters and datatype id
+/// of return type.
+std::size_t FunctionSignature::Hash() const {
+ static const size_t kSeedValue = 17;
+ size_t result = kSeedValue;
+ hash_combine(result, AsciiToLower(base_name_));
+ hash_combine(result, static_cast<size_t>(ret_type_->id()));
+ // not using hash_range since we only want to include the id from the data type
+ for (auto& param_type : param_types_) {
+ hash_combine(result, static_cast<size_t>(param_type->id()));
+ }
+ return result;
+}
+
+std::string FunctionSignature::ToString() const {
+ std::stringstream s;
+
+ s << ret_type_->ToString() << " " << base_name_ << "(";
+ for (uint32_t i = 0; i < param_types_.size(); i++) {
+ if (i > 0) {
+ s << ", ";
+ }
+
+ s << param_types_[i]->ToString();
+ }
+
+ s << ")";
+ return s.str();
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_signature.h b/src/arrow/cpp/src/gandiva/function_signature.h
new file mode 100644
index 000000000..c3e363949
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_signature.h
@@ -0,0 +1,55 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "gandiva/arrow.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// \brief Signature for a function : includes the base name, input param types and
+/// output types.
+class GANDIVA_EXPORT FunctionSignature {
+ public:
+ FunctionSignature(std::string base_name, DataTypeVector param_types,
+ DataTypePtr ret_type);
+
+ bool operator==(const FunctionSignature& other) const;
+
+ /// calculated based on name, datatype id of parameters and datatype id
+ /// of return type.
+ std::size_t Hash() const;
+
+ DataTypePtr ret_type() const { return ret_type_; }
+
+ const std::string& base_name() const { return base_name_; }
+
+ DataTypeVector param_types() const { return param_types_; }
+
+ std::string ToString() const;
+
+ private:
+ std::string base_name_;
+ DataTypeVector param_types_;
+ DataTypePtr ret_type_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_signature_test.cc b/src/arrow/cpp/src/gandiva/function_signature_test.cc
new file mode 100644
index 000000000..0eb62d4e7
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_signature_test.cc
@@ -0,0 +1,113 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/function_signature.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+
+namespace gandiva {
+
+class TestFunctionSignature : public ::testing::Test {
+ protected:
+ virtual void SetUp() {
+ // Use make_shared so these are distinct from the static instances returned
+ // by e.g. arrow::int32()
+ local_i32_type_ = std::make_shared<arrow::Int32Type>();
+ local_i64_type_ = std::make_shared<arrow::Int64Type>();
+ local_date32_type_ = std::make_shared<arrow::Date32Type>();
+ }
+
+ virtual void TearDown() {
+ local_i32_type_.reset();
+ local_i64_type_.reset();
+ local_date32_type_.reset();
+ }
+
+ // virtual void TearDown() {}
+ DataTypePtr local_i32_type_;
+ DataTypePtr local_i64_type_;
+ DataTypePtr local_date32_type_;
+};
+
+TEST_F(TestFunctionSignature, TestToString) {
+ EXPECT_EQ(
+ FunctionSignature("myfunc", {arrow::int32(), arrow::float32()}, arrow::float64())
+ .ToString(),
+ "double myfunc(int32, float)");
+}
+
+TEST_F(TestFunctionSignature, TestEqualsName) {
+ EXPECT_EQ(FunctionSignature("add", {arrow::int32()}, arrow::int32()),
+ FunctionSignature("add", {arrow::int32()}, arrow::int32()));
+
+ EXPECT_EQ(FunctionSignature("add", {arrow::int32()}, arrow::int64()),
+ FunctionSignature("add", {local_i32_type_}, local_i64_type_));
+
+ EXPECT_FALSE(FunctionSignature("add", {arrow::int32()}, arrow::int32()) ==
+ FunctionSignature("sub", {arrow::int32()}, arrow::int32()));
+
+ EXPECT_EQ(FunctionSignature("extractDay", {arrow::int64()}, arrow::int64()),
+ FunctionSignature("extractday", {arrow::int64()}, arrow::int64()));
+
+ EXPECT_EQ(
+ FunctionSignature("castVARCHAR", {arrow::utf8(), arrow::int64()}, arrow::utf8()),
+ FunctionSignature("castvarchar", {arrow::utf8(), arrow::int64()}, arrow::utf8()));
+}
+
+TEST_F(TestFunctionSignature, TestEqualsParamCount) {
+ EXPECT_FALSE(
+ FunctionSignature("add", {arrow::int32(), arrow::int32()}, arrow::int32()) ==
+ FunctionSignature("add", {arrow::int32()}, arrow::int32()));
+}
+
+TEST_F(TestFunctionSignature, TestEqualsParamValue) {
+ EXPECT_FALSE(FunctionSignature("add", {arrow::int32()}, arrow::int32()) ==
+ FunctionSignature("add", {arrow::int64()}, arrow::int32()));
+
+ EXPECT_FALSE(
+ FunctionSignature("add", {arrow::int32()}, arrow::int32()) ==
+ FunctionSignature("add", {arrow::float32(), arrow::float32()}, arrow::int32()));
+
+ EXPECT_FALSE(
+ FunctionSignature("add", {arrow::int32(), arrow::int64()}, arrow::int32()) ==
+ FunctionSignature("add", {arrow::int64(), arrow::int32()}, arrow::int32()));
+
+ EXPECT_EQ(FunctionSignature("extract_month", {arrow::date32()}, arrow::int64()),
+ FunctionSignature("extract_month", {local_date32_type_}, local_i64_type_));
+
+ EXPECT_FALSE(FunctionSignature("extract_month", {arrow::date32()}, arrow::int64()) ==
+ FunctionSignature("extract_month", {arrow::date64()}, arrow::date32()));
+}
+
+TEST_F(TestFunctionSignature, TestEqualsReturn) {
+ EXPECT_FALSE(FunctionSignature("add", {arrow::int32()}, arrow::int64()) ==
+ FunctionSignature("add", {arrow::int32()}, arrow::int32()));
+}
+
+TEST_F(TestFunctionSignature, TestHash) {
+ FunctionSignature f1("add", {arrow::int32(), arrow::int32()}, arrow::int64());
+ FunctionSignature f2("add", {local_i32_type_, local_i32_type_}, local_i64_type_);
+ EXPECT_EQ(f1.Hash(), f2.Hash());
+
+ FunctionSignature f3("extractDay", {arrow::int64()}, arrow::int64());
+ FunctionSignature f4("extractday", {arrow::int64()}, arrow::int64());
+ EXPECT_EQ(f3.Hash(), f4.Hash());
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/gandiva.pc.in b/src/arrow/cpp/src/gandiva/gandiva.pc.in
new file mode 100644
index 000000000..22ff11a4f
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/gandiva.pc.in
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+prefix=@CMAKE_INSTALL_PREFIX@
+libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@
+includedir=${prefix}/@CMAKE_INSTALL_INCLUDEDIR@
+
+Name: Gandiva
+Description: Gandiva is a toolset for compiling and evaluating expressions on Arrow data.
+Version: @GANDIVA_VERSION@
+Requires: arrow
+Libs: -L${libdir} -lgandiva
+Cflags: -I${includedir}
diff --git a/src/arrow/cpp/src/gandiva/gandiva_aliases.h b/src/arrow/cpp/src/gandiva/gandiva_aliases.h
new file mode 100644
index 000000000..6cbb671ff
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/gandiva_aliases.h
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+namespace gandiva {
+
+class Dex;
+using DexPtr = std::shared_ptr<Dex>;
+using DexVector = std::vector<std::shared_ptr<Dex>>;
+
+class ValueValidityPair;
+using ValueValidityPairPtr = std::shared_ptr<ValueValidityPair>;
+using ValueValidityPairVector = std::vector<ValueValidityPairPtr>;
+
+class FieldDescriptor;
+using FieldDescriptorPtr = std::shared_ptr<FieldDescriptor>;
+
+class FuncDescriptor;
+using FuncDescriptorPtr = std::shared_ptr<FuncDescriptor>;
+
+class LValue;
+using LValuePtr = std::shared_ptr<LValue>;
+
+class Expression;
+using ExpressionPtr = std::shared_ptr<Expression>;
+using ExpressionVector = std::vector<ExpressionPtr>;
+
+class Condition;
+using ConditionPtr = std::shared_ptr<Condition>;
+
+class Node;
+using NodePtr = std::shared_ptr<Node>;
+using NodeVector = std::vector<std::shared_ptr<Node>>;
+
+class EvalBatch;
+using EvalBatchPtr = std::shared_ptr<EvalBatch>;
+
+class FunctionSignature;
+using FuncSignaturePtr = std::shared_ptr<FunctionSignature>;
+using FuncSignatureVector = std::vector<FuncSignaturePtr>;
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/gdv_function_stubs.cc b/src/arrow/cpp/src/gandiva/gdv_function_stubs.cc
new file mode 100644
index 000000000..ed34eef4a
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/gdv_function_stubs.cc
@@ -0,0 +1,1603 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/gdv_function_stubs.h"
+
+#include <utf8proc.h>
+
+#include <string>
+#include <vector>
+
+#include "arrow/util/base64.h"
+#include "arrow/util/double_conversion.h"
+#include "arrow/util/formatting.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/utf8.h"
+#include "arrow/util/value_parsing.h"
+#include "gandiva/engine.h"
+#include "gandiva/exported_funcs.h"
+#include "gandiva/formatting_utils.h"
+#include "gandiva/hash_utils.h"
+#include "gandiva/in_holder.h"
+#include "gandiva/like_holder.h"
+#include "gandiva/precompiled/types.h"
+#include "gandiva/random_generator_holder.h"
+#include "gandiva/replace_holder.h"
+#include "gandiva/to_date_holder.h"
+
+/// Stub functions that can be accessed from LLVM or the pre-compiled library.
+
+extern "C" {
+
+bool gdv_fn_like_utf8_utf8(int64_t ptr, const char* data, int data_len,
+ const char* pattern, int pattern_len) {
+ gandiva::LikeHolder* holder = reinterpret_cast<gandiva::LikeHolder*>(ptr);
+ return (*holder)(std::string(data, data_len));
+}
+
+bool gdv_fn_like_utf8_utf8_utf8(int64_t ptr, const char* data, int data_len,
+ const char* pattern, int pattern_len,
+ const char* escape_char, int escape_char_len) {
+ gandiva::LikeHolder* holder = reinterpret_cast<gandiva::LikeHolder*>(ptr);
+ return (*holder)(std::string(data, data_len));
+}
+
+bool gdv_fn_ilike_utf8_utf8(int64_t ptr, const char* data, int data_len,
+ const char* pattern, int pattern_len) {
+ gandiva::LikeHolder* holder = reinterpret_cast<gandiva::LikeHolder*>(ptr);
+ return (*holder)(std::string(data, data_len));
+}
+
+const char* gdv_fn_regexp_replace_utf8_utf8(
+ int64_t ptr, int64_t holder_ptr, const char* data, int32_t data_len,
+ const char* /*pattern*/, int32_t /*pattern_len*/, const char* replace_string,
+ int32_t replace_string_len, int32_t* out_length) {
+ gandiva::ExecutionContext* context = reinterpret_cast<gandiva::ExecutionContext*>(ptr);
+
+ gandiva::ReplaceHolder* holder = reinterpret_cast<gandiva::ReplaceHolder*>(holder_ptr);
+
+ return (*holder)(context, data, data_len, replace_string, replace_string_len,
+ out_length);
+}
+
+double gdv_fn_random(int64_t ptr) {
+ gandiva::RandomGeneratorHolder* holder =
+ reinterpret_cast<gandiva::RandomGeneratorHolder*>(ptr);
+ return (*holder)();
+}
+
+double gdv_fn_random_with_seed(int64_t ptr, int32_t seed, bool seed_validity) {
+ gandiva::RandomGeneratorHolder* holder =
+ reinterpret_cast<gandiva::RandomGeneratorHolder*>(ptr);
+ return (*holder)();
+}
+
+int64_t gdv_fn_to_date_utf8_utf8(int64_t context_ptr, int64_t holder_ptr,
+ const char* data, int data_len, bool in1_validity,
+ const char* pattern, int pattern_len, bool in2_validity,
+ bool* out_valid) {
+ gandiva::ExecutionContext* context =
+ reinterpret_cast<gandiva::ExecutionContext*>(context_ptr);
+ gandiva::ToDateHolder* holder = reinterpret_cast<gandiva::ToDateHolder*>(holder_ptr);
+ return (*holder)(context, data, data_len, in1_validity, out_valid);
+}
+
+int64_t gdv_fn_to_date_utf8_utf8_int32(int64_t context_ptr, int64_t holder_ptr,
+ const char* data, int data_len, bool in1_validity,
+ const char* pattern, int pattern_len,
+ bool in2_validity, int32_t suppress_errors,
+ bool in3_validity, bool* out_valid) {
+ gandiva::ExecutionContext* context =
+ reinterpret_cast<gandiva::ExecutionContext*>(context_ptr);
+ gandiva::ToDateHolder* holder = reinterpret_cast<gandiva::ToDateHolder*>(holder_ptr);
+ return (*holder)(context, data, data_len, in1_validity, out_valid);
+}
+
+bool gdv_fn_in_expr_lookup_int32(int64_t ptr, int32_t value, bool in_validity) {
+ if (!in_validity) {
+ return false;
+ }
+ gandiva::InHolder<int32_t>* holder = reinterpret_cast<gandiva::InHolder<int32_t>*>(ptr);
+ return holder->HasValue(value);
+}
+
+bool gdv_fn_in_expr_lookup_int64(int64_t ptr, int64_t value, bool in_validity) {
+ if (!in_validity) {
+ return false;
+ }
+ gandiva::InHolder<int64_t>* holder = reinterpret_cast<gandiva::InHolder<int64_t>*>(ptr);
+ return holder->HasValue(value);
+}
+
+bool gdv_fn_in_expr_lookup_decimal(int64_t ptr, int64_t value_high, int64_t value_low,
+ int32_t precision, int32_t scale, bool in_validity) {
+ if (!in_validity) {
+ return false;
+ }
+ gandiva::DecimalScalar128 value(value_high, value_low, precision, scale);
+ gandiva::InHolder<gandiva::DecimalScalar128>* holder =
+ reinterpret_cast<gandiva::InHolder<gandiva::DecimalScalar128>*>(ptr);
+ return holder->HasValue(value);
+}
+
+bool gdv_fn_in_expr_lookup_float(int64_t ptr, float value, bool in_validity) {
+ if (!in_validity) {
+ return false;
+ }
+ gandiva::InHolder<float>* holder = reinterpret_cast<gandiva::InHolder<float>*>(ptr);
+ return holder->HasValue(value);
+}
+
+bool gdv_fn_in_expr_lookup_double(int64_t ptr, double value, bool in_validity) {
+ if (!in_validity) {
+ return false;
+ }
+ gandiva::InHolder<double>* holder = reinterpret_cast<gandiva::InHolder<double>*>(ptr);
+ return holder->HasValue(value);
+}
+
+bool gdv_fn_in_expr_lookup_utf8(int64_t ptr, const char* data, int data_len,
+ bool in_validity) {
+ if (!in_validity) {
+ return false;
+ }
+ gandiva::InHolder<std::string>* holder =
+ reinterpret_cast<gandiva::InHolder<std::string>*>(ptr);
+ return holder->HasValue(arrow::util::string_view(data, data_len));
+}
+
+int32_t gdv_fn_populate_varlen_vector(int64_t context_ptr, int8_t* data_ptr,
+ int32_t* offsets, int64_t slot,
+ const char* entry_buf, int32_t entry_len) {
+ auto buffer = reinterpret_cast<arrow::ResizableBuffer*>(data_ptr);
+ int32_t offset = static_cast<int32_t>(buffer->size());
+
+ // This also sets the size in the buffer.
+ auto status = buffer->Resize(offset + entry_len, false /*shrink*/);
+ if (!status.ok()) {
+ gandiva::ExecutionContext* context =
+ reinterpret_cast<gandiva::ExecutionContext*>(context_ptr);
+
+ context->set_error_msg(status.message().c_str());
+ return -1;
+ }
+
+ // append the new entry.
+ memcpy(buffer->mutable_data() + offset, entry_buf, entry_len);
+
+ // update offsets buffer.
+ offsets[slot] = offset;
+ offsets[slot + 1] = offset + entry_len;
+ return 0;
+}
+
+#define SHA1_HASH_FUNCTION(TYPE) \
+ GANDIVA_EXPORT \
+ const char* gdv_fn_sha1_##TYPE(int64_t context, gdv_##TYPE value, bool validity, \
+ int32_t* out_length) { \
+ if (!validity) { \
+ return gandiva::gdv_hash_using_sha1(context, NULLPTR, 0, out_length); \
+ } \
+ auto value_as_long = gandiva::gdv_double_to_long((double)value); \
+ const char* result = gandiva::gdv_hash_using_sha1( \
+ context, &value_as_long, sizeof(value_as_long), out_length); \
+ \
+ return result; \
+ }
+
+#define SHA1_HASH_FUNCTION_BUF(TYPE) \
+ GANDIVA_EXPORT \
+ const char* gdv_fn_sha1_##TYPE(int64_t context, gdv_##TYPE value, \
+ int32_t value_length, bool value_validity, \
+ int32_t* out_length) { \
+ if (!value_validity) { \
+ return gandiva::gdv_hash_using_sha1(context, NULLPTR, 0, out_length); \
+ } \
+ return gandiva::gdv_hash_using_sha1(context, value, value_length, out_length); \
+ }
+
+#define SHA256_HASH_FUNCTION(TYPE) \
+ GANDIVA_EXPORT \
+ const char* gdv_fn_sha256_##TYPE(int64_t context, gdv_##TYPE value, bool validity, \
+ int32_t* out_length) { \
+ if (!validity) { \
+ return gandiva::gdv_hash_using_sha256(context, NULLPTR, 0, out_length); \
+ } \
+ auto value_as_long = gandiva::gdv_double_to_long((double)value); \
+ const char* result = gandiva::gdv_hash_using_sha256( \
+ context, &value_as_long, sizeof(value_as_long), out_length); \
+ return result; \
+ }
+
+#define SHA256_HASH_FUNCTION_BUF(TYPE) \
+ GANDIVA_EXPORT \
+ const char* gdv_fn_sha256_##TYPE(int64_t context, gdv_##TYPE value, \
+ int32_t value_length, bool value_validity, \
+ int32_t* out_length) { \
+ if (!value_validity) { \
+ return gandiva::gdv_hash_using_sha256(context, NULLPTR, 0, out_length); \
+ } \
+ \
+ return gandiva::gdv_hash_using_sha256(context, value, value_length, out_length); \
+ }
+
+// Expand inner macro for all numeric types.
+#define SHA_NUMERIC_BOOL_DATE_PARAMS(INNER) \
+ INNER(int8) \
+ INNER(int16) \
+ INNER(int32) \
+ INNER(int64) \
+ INNER(uint8) \
+ INNER(uint16) \
+ INNER(uint32) \
+ INNER(uint64) \
+ INNER(float32) \
+ INNER(float64) \
+ INNER(boolean) \
+ INNER(date64) \
+ INNER(date32) \
+ INNER(time32) \
+ INNER(timestamp)
+
+// Expand inner macro for all numeric types.
+#define SHA_VAR_LEN_PARAMS(INNER) \
+ INNER(utf8) \
+ INNER(binary)
+
+SHA_NUMERIC_BOOL_DATE_PARAMS(SHA256_HASH_FUNCTION)
+SHA_VAR_LEN_PARAMS(SHA256_HASH_FUNCTION_BUF)
+
+SHA_NUMERIC_BOOL_DATE_PARAMS(SHA1_HASH_FUNCTION)
+SHA_VAR_LEN_PARAMS(SHA1_HASH_FUNCTION_BUF)
+
+#undef SHA_NUMERIC_BOOL_DATE_PARAMS
+#undef SHA_VAR_LEN_PARAMS
+
+// Add functions for decimal128
+GANDIVA_EXPORT
+const char* gdv_fn_sha256_decimal128(int64_t context, int64_t x_high, uint64_t x_low,
+ int32_t /*x_precision*/, int32_t /*x_scale*/,
+ gdv_boolean x_isvalid, int32_t* out_length) {
+ if (!x_isvalid) {
+ return gandiva::gdv_hash_using_sha256(context, NULLPTR, 0, out_length);
+ }
+
+ const gandiva::BasicDecimal128 decimal_128(x_high, x_low);
+ return gandiva::gdv_hash_using_sha256(context, decimal_128.ToBytes().data(), 16,
+ out_length);
+}
+
+GANDIVA_EXPORT
+const char* gdv_fn_sha1_decimal128(int64_t context, int64_t x_high, uint64_t x_low,
+ int32_t /*x_precision*/, int32_t /*x_scale*/,
+ gdv_boolean x_isvalid, int32_t* out_length) {
+ if (!x_isvalid) {
+ return gandiva::gdv_hash_using_sha1(context, NULLPTR, 0, out_length);
+ }
+
+ const gandiva::BasicDecimal128 decimal_128(x_high, x_low);
+ return gandiva::gdv_hash_using_sha1(context, decimal_128.ToBytes().data(), 16,
+ out_length);
+}
+
+int32_t gdv_fn_dec_from_string(int64_t context, const char* in, int32_t in_length,
+ int32_t* precision_from_str, int32_t* scale_from_str,
+ int64_t* dec_high_from_str, uint64_t* dec_low_from_str) {
+ arrow::Decimal128 dec;
+ auto status = arrow::Decimal128::FromString(std::string(in, in_length), &dec,
+ precision_from_str, scale_from_str);
+ if (!status.ok()) {
+ gdv_fn_context_set_error_msg(context, status.message().data());
+ return -1;
+ }
+ *dec_high_from_str = dec.high_bits();
+ *dec_low_from_str = dec.low_bits();
+ return 0;
+}
+
+char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low,
+ int32_t x_scale, int32_t* dec_str_len) {
+ arrow::Decimal128 dec(arrow::BasicDecimal128(x_high, x_low));
+ std::string dec_str = dec.ToString(x_scale);
+ *dec_str_len = static_cast<int32_t>(dec_str.length());
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *dec_str_len));
+ if (ret == nullptr) {
+ std::string err_msg = "Could not allocate memory for string: " + dec_str;
+ gdv_fn_context_set_error_msg(context, err_msg.data());
+ return nullptr;
+ }
+ memcpy(ret, dec_str.data(), *dec_str_len);
+ return ret;
+}
+
+GANDIVA_EXPORT
+const char* gdv_fn_base64_encode_binary(int64_t context, const char* in, int32_t in_len,
+ int32_t* out_len) {
+ if (in_len < 0) {
+ gdv_fn_context_set_error_msg(context, "Buffer length can not be negative");
+ *out_len = 0;
+ return "";
+ }
+ if (in_len == 0) {
+ *out_len = 0;
+ return "";
+ }
+ // use arrow method to encode base64 string
+ std::string encoded_str =
+ arrow::util::base64_encode(arrow::util::string_view(in, in_len));
+ *out_len = static_cast<int32_t>(encoded_str.length());
+ // allocate memory for response
+ char* ret = reinterpret_cast<char*>(
+ gdv_fn_context_arena_malloc(context, static_cast<int32_t>(*out_len)));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory");
+ *out_len = 0;
+ return "";
+ }
+ memcpy(ret, encoded_str.data(), *out_len);
+ return ret;
+}
+
+GANDIVA_EXPORT
+const char* gdv_fn_base64_decode_utf8(int64_t context, const char* in, int32_t in_len,
+ int32_t* out_len) {
+ if (in_len < 0) {
+ gdv_fn_context_set_error_msg(context, "Buffer length can not be negative");
+ *out_len = 0;
+ return "";
+ }
+ if (in_len == 0) {
+ *out_len = 0;
+ return "";
+ }
+ // use arrow method to decode base64 string
+ std::string decoded_str =
+ arrow::util::base64_decode(arrow::util::string_view(in, in_len));
+ *out_len = static_cast<int32_t>(decoded_str.length());
+ // allocate memory for response
+ char* ret = reinterpret_cast<char*>(
+ gdv_fn_context_arena_malloc(context, static_cast<int32_t>(*out_len)));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory");
+ *out_len = 0;
+ return "";
+ }
+ memcpy(ret, decoded_str.data(), *out_len);
+ return ret;
+}
+
+#define CAST_NUMERIC_FROM_VARLEN_TYPES(OUT_TYPE, ARROW_TYPE, TYPE_NAME, INNER_TYPE) \
+ GANDIVA_EXPORT \
+ OUT_TYPE gdv_fn_cast##TYPE_NAME##_##INNER_TYPE(int64_t context, const char* data, \
+ int32_t len) { \
+ OUT_TYPE val = 0; \
+ /* trim leading and trailing spaces */ \
+ int32_t trimmed_len; \
+ int32_t start = 0, end = len - 1; \
+ while (start <= end && data[start] == ' ') { \
+ ++start; \
+ } \
+ while (end >= start && data[end] == ' ') { \
+ --end; \
+ } \
+ trimmed_len = end - start + 1; \
+ const char* trimmed_data = data + start; \
+ if (!arrow::internal::ParseValue<ARROW_TYPE>(trimmed_data, trimmed_len, &val)) { \
+ std::string err = \
+ "Failed to cast the string " + std::string(data, len) + " to " #OUT_TYPE; \
+ gdv_fn_context_set_error_msg(context, err.c_str()); \
+ } \
+ return val; \
+ }
+
+#define CAST_NUMERIC_FROM_STRING(OUT_TYPE, ARROW_TYPE, TYPE_NAME) \
+ CAST_NUMERIC_FROM_VARLEN_TYPES(OUT_TYPE, ARROW_TYPE, TYPE_NAME, utf8)
+
+CAST_NUMERIC_FROM_STRING(int32_t, arrow::Int32Type, INT)
+CAST_NUMERIC_FROM_STRING(int64_t, arrow::Int64Type, BIGINT)
+CAST_NUMERIC_FROM_STRING(float, arrow::FloatType, FLOAT4)
+CAST_NUMERIC_FROM_STRING(double, arrow::DoubleType, FLOAT8)
+
+#undef CAST_NUMERIC_FROM_STRING
+
+#define CAST_NUMERIC_FROM_VARBINARY(OUT_TYPE, ARROW_TYPE, TYPE_NAME) \
+ CAST_NUMERIC_FROM_VARLEN_TYPES(OUT_TYPE, ARROW_TYPE, TYPE_NAME, varbinary)
+
+CAST_NUMERIC_FROM_VARBINARY(int32_t, arrow::Int32Type, INT)
+CAST_NUMERIC_FROM_VARBINARY(int64_t, arrow::Int64Type, BIGINT)
+CAST_NUMERIC_FROM_VARBINARY(float, arrow::FloatType, FLOAT4)
+CAST_NUMERIC_FROM_VARBINARY(double, arrow::DoubleType, FLOAT8)
+
+#undef CAST_NUMERIC_STRING
+
+#define GDV_FN_CAST_VARLEN_TYPE_FROM_INTEGER(IN_TYPE, CAST_NAME, ARROW_TYPE) \
+ GANDIVA_EXPORT \
+ const char* gdv_fn_cast##CAST_NAME##_##IN_TYPE##_int64( \
+ int64_t context, gdv_##IN_TYPE value, int64_t len, int32_t * out_len) { \
+ if (len < 0) { \
+ gdv_fn_context_set_error_msg(context, "Buffer length can not be negative"); \
+ *out_len = 0; \
+ return ""; \
+ } \
+ if (len == 0) { \
+ *out_len = 0; \
+ return ""; \
+ } \
+ arrow::internal::StringFormatter<arrow::ARROW_TYPE> formatter; \
+ char* ret = reinterpret_cast<char*>( \
+ gdv_fn_context_arena_malloc(context, static_cast<int32_t>(len))); \
+ if (ret == nullptr) { \
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory"); \
+ *out_len = 0; \
+ return ""; \
+ } \
+ arrow::Status status = formatter(value, [&](arrow::util::string_view v) { \
+ int64_t size = static_cast<int64_t>(v.size()); \
+ *out_len = static_cast<int32_t>(len < size ? len : size); \
+ memcpy(ret, v.data(), *out_len); \
+ return arrow::Status::OK(); \
+ }); \
+ if (!status.ok()) { \
+ std::string err = "Could not cast " + std::to_string(value) + " to string"; \
+ gdv_fn_context_set_error_msg(context, err.c_str()); \
+ *out_len = 0; \
+ return ""; \
+ } \
+ return ret; \
+ }
+
+#define GDV_FN_CAST_VARLEN_TYPE_FROM_REAL(IN_TYPE, CAST_NAME, ARROW_TYPE) \
+ GANDIVA_EXPORT \
+ const char* gdv_fn_cast##CAST_NAME##_##IN_TYPE##_int64( \
+ int64_t context, gdv_##IN_TYPE value, int64_t len, int32_t * out_len) { \
+ if (len < 0) { \
+ gdv_fn_context_set_error_msg(context, "Buffer length can not be negative"); \
+ *out_len = 0; \
+ return ""; \
+ } \
+ if (len == 0) { \
+ *out_len = 0; \
+ return ""; \
+ } \
+ gandiva::GdvStringFormatter<arrow::ARROW_TYPE> formatter; \
+ char* ret = reinterpret_cast<char*>( \
+ gdv_fn_context_arena_malloc(context, static_cast<int32_t>(len))); \
+ if (ret == nullptr) { \
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory"); \
+ *out_len = 0; \
+ return ""; \
+ } \
+ arrow::Status status = formatter(value, [&](arrow::util::string_view v) { \
+ int64_t size = static_cast<int64_t>(v.size()); \
+ *out_len = static_cast<int32_t>(len < size ? len : size); \
+ memcpy(ret, v.data(), *out_len); \
+ return arrow::Status::OK(); \
+ }); \
+ if (!status.ok()) { \
+ std::string err = "Could not cast " + std::to_string(value) + " to string"; \
+ gdv_fn_context_set_error_msg(context, err.c_str()); \
+ *out_len = 0; \
+ return ""; \
+ } \
+ return ret; \
+ }
+
+#define CAST_VARLEN_TYPE_FROM_NUMERIC(VARLEN_TYPE) \
+ GDV_FN_CAST_VARLEN_TYPE_FROM_INTEGER(int32, VARLEN_TYPE, Int32Type) \
+ GDV_FN_CAST_VARLEN_TYPE_FROM_INTEGER(int64, VARLEN_TYPE, Int64Type) \
+ GDV_FN_CAST_VARLEN_TYPE_FROM_REAL(float32, VARLEN_TYPE, FloatType) \
+ GDV_FN_CAST_VARLEN_TYPE_FROM_REAL(float64, VARLEN_TYPE, DoubleType)
+
+CAST_VARLEN_TYPE_FROM_NUMERIC(VARCHAR)
+CAST_VARLEN_TYPE_FROM_NUMERIC(VARBINARY)
+
+#undef CAST_VARLEN_TYPE_FROM_NUMERIC
+#undef GDV_FN_CAST_VARLEN_TYPE_FROM_INTEGER
+#undef GDV_FN_CAST_VARLEN_TYPE_FROM_REAL
+#undef GDV_FN_CAST_VARCHAR_INTEGER
+#undef GDV_FN_CAST_VARCHAR_REAL
+
+GDV_FORCE_INLINE
+int32_t gdv_fn_utf8_char_length(char c) {
+ if ((signed char)c >= 0) { // 1-byte char (0x00 ~ 0x7F)
+ return 1;
+ } else if ((c & 0xE0) == 0xC0) { // 2-byte char
+ return 2;
+ } else if ((c & 0xF0) == 0xE0) { // 3-byte char
+ return 3;
+ } else if ((c & 0xF8) == 0xF0) { // 4-byte char
+ return 4;
+ }
+ // invalid char
+ return 0;
+}
+
+GDV_FORCE_INLINE
+void gdv_fn_set_error_for_invalid_utf8(int64_t execution_context, char val) {
+ char const* fmt = "unexpected byte \\%02hhx encountered while decoding utf8 string";
+ int size = static_cast<int>(strlen(fmt)) + 64;
+ char* error = reinterpret_cast<char*>(malloc(size));
+ snprintf(error, size, fmt, (unsigned char)val);
+ gdv_fn_context_set_error_msg(execution_context, error);
+ free(error);
+}
+
+// Convert an utf8 string to its corresponding uppercase string
+GANDIVA_EXPORT
+const char* gdv_fn_upper_utf8(int64_t context, const char* data, int32_t data_len,
+ int32_t* out_len) {
+ if (data_len == 0) {
+ *out_len = 0;
+ return "";
+ }
+
+ // If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte
+ // long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of
+ // the output can be at most twice the length of the input
+ char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
+ if (out == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+
+ int32_t char_len, out_char_len, out_idx = 0;
+ uint32_t char_codepoint;
+
+ for (int32_t i = 0; i < data_len; i += char_len) {
+ char_len = gdv_fn_utf8_char_length(data[i]);
+ // For single byte characters:
+ // If it is a lowercase ASCII character, set the output to its corresponding uppercase
+ // character; else, set the output to the read character
+ if (char_len == 1) {
+ char cur = data[i];
+ // 'A' - 'Z' : 0x41 - 0x5a
+ // 'a' - 'z' : 0x61 - 0x7a
+ if (cur >= 0x61 && cur <= 0x7a) {
+ out[out_idx++] = static_cast<char>(cur - 0x20);
+ } else {
+ out[out_idx++] = cur;
+ }
+ continue;
+ }
+
+ // Control reaches here when we encounter a multibyte character
+ const auto* in_char = (const uint8_t*)(data + i);
+
+ // Decode the multibyte character
+ bool is_valid_utf8_char =
+ arrow::util::UTF8Decode((const uint8_t**)&in_char, &char_codepoint);
+
+ // If it is an invalid utf8 character, UTF8Decode evaluates to false
+ if (!is_valid_utf8_char) {
+ gdv_fn_set_error_for_invalid_utf8(context, data[i]);
+ *out_len = 0;
+ return "";
+ }
+
+ // Convert the encoded codepoint to its uppercase codepoint
+ int32_t upper_codepoint = utf8proc_toupper(char_codepoint);
+
+ // UTF8Encode advances the pointer by the number of bytes present in the uppercase
+ // character
+ auto* out_char = (uint8_t*)(out + out_idx);
+ uint8_t* out_char_start = out_char;
+
+ // Encode the uppercase character
+ out_char = arrow::util::UTF8Encode(out_char, upper_codepoint);
+
+ out_char_len = static_cast<int32_t>(out_char - out_char_start);
+ out_idx += out_char_len;
+ }
+
+ *out_len = out_idx;
+ return out;
+}
+
+// Convert an utf8 string to its corresponding lowercase string
+GANDIVA_EXPORT
+const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_len,
+ int32_t* out_len) {
+ if (data_len == 0) {
+ *out_len = 0;
+ return "";
+ }
+
+ // If it is a single-byte character (ASCII), corresponding lowercase is always 1-byte
+ // long; if it is >= 2 bytes long, lowercase can be at most 4 bytes long, so length of
+ // the output can be at most twice the length of the input
+ char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
+ if (out == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+
+ int32_t char_len, out_char_len, out_idx = 0;
+ uint32_t char_codepoint;
+
+ for (int32_t i = 0; i < data_len; i += char_len) {
+ char_len = gdv_fn_utf8_char_length(data[i]);
+ // For single byte characters:
+ // If it is an uppercase ASCII character, set the output to its corresponding
+ // lowercase character; else, set the output to the read character
+ if (char_len == 1) {
+ char cur = data[i];
+ // 'A' - 'Z' : 0x41 - 0x5a
+ // 'a' - 'z' : 0x61 - 0x7a
+ if (cur >= 0x41 && cur <= 0x5a) {
+ out[out_idx++] = static_cast<char>(cur + 0x20);
+ } else {
+ out[out_idx++] = cur;
+ }
+ continue;
+ }
+
+ // Control reaches here when we encounter a multibyte character
+ const auto* in_char = (const uint8_t*)(data + i);
+
+ // Decode the multibyte character
+ bool is_valid_utf8_char =
+ arrow::util::UTF8Decode((const uint8_t**)&in_char, &char_codepoint);
+
+ // If it is an invalid utf8 character, UTF8Decode evaluates to false
+ if (!is_valid_utf8_char) {
+ gdv_fn_set_error_for_invalid_utf8(context, data[i]);
+ *out_len = 0;
+ return "";
+ }
+
+ // Convert the encoded codepoint to its lowercase codepoint
+ int32_t lower_codepoint = utf8proc_tolower(char_codepoint);
+
+ // UTF8Encode advances the pointer by the number of bytes present in the lowercase
+ // character
+ auto* out_char = (uint8_t*)(out + out_idx);
+ uint8_t* out_char_start = out_char;
+
+ // Encode the lowercase character
+ out_char = arrow::util::UTF8Encode(out_char, lower_codepoint);
+
+ out_char_len = static_cast<int32_t>(out_char - out_char_start);
+ out_idx += out_char_len;
+ }
+
+ *out_len = out_idx;
+ return out;
+}
+
+// Any codepoint, except the ones for lowercase letters, uppercase letters,
+// titlecase letters, decimal digits and letter numbers categories will be
+// considered as word separators.
+//
+// The Unicode characters also are divided between categories. This link
+// https://www.compart.com/en/unicode/category shows
+// more information about characters categories.
+GDV_FORCE_INLINE
+bool gdv_fn_is_codepoint_for_space(uint32_t val) {
+ auto category = utf8proc_category(val);
+
+ return category != utf8proc_category_t::UTF8PROC_CATEGORY_LU &&
+ category != utf8proc_category_t::UTF8PROC_CATEGORY_LL &&
+ category != utf8proc_category_t::UTF8PROC_CATEGORY_LT &&
+ category != utf8proc_category_t::UTF8PROC_CATEGORY_NL &&
+ category != utf8proc_category_t ::UTF8PROC_CATEGORY_ND;
+}
+
+// For a given text, initialize the first letter after a word-separator and lowercase
+// the others e.g:
+// - "IT is a tEXt str" -> "It Is A Text Str"
+GANDIVA_EXPORT
+const char* gdv_fn_initcap_utf8(int64_t context, const char* data, int32_t data_len,
+ int32_t* out_len) {
+ if (data_len == 0) {
+ *out_len = data_len;
+ return "";
+ }
+
+ // If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte
+ // long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of
+ // the output can be at most twice the length of the input
+ char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
+ if (out == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+
+ int32_t char_len = 0;
+ int32_t out_char_len = 0;
+ int32_t out_idx = 0;
+ uint32_t char_codepoint;
+
+ // Any character is considered as space, except if it is alphanumeric
+ bool last_char_was_space = true;
+
+ for (int32_t i = 0; i < data_len; i += char_len) {
+ // An optimization for single byte characters:
+ if (static_cast<signed char>(data[i]) >= 0) { // 1-byte char (0x00 ~ 0x7F)
+ char_len = 1;
+ char cur = data[i];
+
+ if (cur >= 0x61 && cur <= 0x7a && last_char_was_space) {
+ // Check if the character is the first one of the word and it is
+ // lowercase -> 'a' - 'z' : 0x61 - 0x7a.
+ // Then turn it into uppercase -> 'A' - 'Z' : 0x41 - 0x5a
+ out[out_idx++] = static_cast<char>(cur - 0x20);
+ last_char_was_space = false;
+ } else if (cur >= 0x41 && cur <= 0x5a && !last_char_was_space) {
+ out[out_idx++] = static_cast<char>(cur + 0x20);
+ } else {
+ // Check if the ASCII character is not an alphanumeric character:
+ // '0' - '9': 0x30 - 0x39
+ // 'a' - 'z' : 0x61 - 0x7a
+ // 'A' - 'Z' : 0x41 - 0x5a
+ last_char_was_space = (cur < 0x30) || (cur > 0x39 && cur < 0x41) ||
+ (cur > 0x5a && cur < 0x61) || (cur > 0x7a);
+ out[out_idx++] = cur;
+ }
+ continue;
+ }
+
+ char_len = gdv_fn_utf8_char_length(data[i]);
+
+ // Control reaches here when we encounter a multibyte character
+ const auto* in_char = (const uint8_t*)(data + i);
+
+ // Decode the multibyte character
+ bool is_valid_utf8_char =
+ arrow::util::UTF8Decode((const uint8_t**)&in_char, &char_codepoint);
+
+ // If it is an invalid utf8 character, UTF8Decode evaluates to false
+ if (!is_valid_utf8_char) {
+ gdv_fn_set_error_for_invalid_utf8(context, data[i]);
+ *out_len = 0;
+ return "";
+ }
+
+ bool is_char_space = gdv_fn_is_codepoint_for_space(char_codepoint);
+
+ int32_t formatted_codepoint;
+ if (last_char_was_space && !is_char_space) {
+ formatted_codepoint = utf8proc_toupper(char_codepoint);
+ } else {
+ formatted_codepoint = utf8proc_tolower(char_codepoint);
+ }
+
+ // UTF8Encode advances the pointer by the number of bytes present in the character
+ auto* out_char = (uint8_t*)(out + out_idx);
+ uint8_t* out_char_start = out_char;
+
+ // Encode the character
+ out_char = arrow::util::UTF8Encode(out_char, formatted_codepoint);
+
+ out_char_len = static_cast<int32_t>(out_char - out_char_start);
+ out_idx += out_char_len;
+
+ last_char_was_space = is_char_space;
+ }
+
+ *out_len = out_idx;
+ return out;
+}
+}
+
+namespace gandiva {
+
+void ExportedStubFunctions::AddMappings(Engine* engine) const {
+ std::vector<llvm::Type*> args;
+ auto types = engine->types();
+
+ // gdv_fn_castVARBINARY_int32
+ args = {
+ types->i64_type(), // context
+ types->i32_type(), // int32_t value
+ types->i64_type(), // int64_t out value length
+ types->i32_ptr_type() // int32_t out_length
+ };
+
+ engine->AddGlobalMappingForFunc(
+ "gdv_fn_castVARBINARY_int32_int64", types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_castVARBINARY_int32_int64));
+
+ // gdv_fn_castVARBINARY_int64
+ args = {
+ types->i64_type(), // context
+ types->i64_type(), // int64_t value
+ types->i64_type(), // int64_t out value length
+ types->i32_ptr_type() // int32_t out_length
+ };
+
+ engine->AddGlobalMappingForFunc(
+ "gdv_fn_castVARBINARY_int64_int64", types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_castVARBINARY_int64_int64));
+
+ // gdv_fn_castVARBINARY_float32
+ args = {
+ types->i64_type(), // context
+ types->float_type(), // float value
+ types->i64_type(), // int64_t out value length
+ types->i64_ptr_type() // int32_t out_length
+ };
+
+ engine->AddGlobalMappingForFunc(
+ "gdv_fn_castVARBINARY_float32_int64", types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_castVARBINARY_float32_int64));
+
+ // gdv_fn_castVARBINARY_float64
+ args = {
+ types->i64_type(), // context
+ types->i64_type(), // double value
+ types->i64_type(), // int64_t out value length
+ types->i32_ptr_type() // int32_t out_length
+ };
+
+ engine->AddGlobalMappingForFunc(
+ "gdv_fn_castVARBINARY_float64_int64", types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_castVARBINARY_float64_int64));
+
+ // gdv_fn_dec_from_string
+ args = {
+ types->i64_type(), // context
+ types->i8_ptr_type(), // const char* in
+ types->i32_type(), // int32_t in_length
+ types->i32_ptr_type(), // int32_t* precision_from_str
+ types->i32_ptr_type(), // int32_t* scale_from_str
+ types->i64_ptr_type(), // int64_t* dec_high_from_str
+ types->i64_ptr_type(), // int64_t* dec_low_from_str
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_dec_from_string",
+ types->i32_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_dec_from_string));
+
+ // gdv_fn_dec_to_string
+ args = {
+ types->i64_type(), // context
+ types->i64_type(), // int64_t x_high
+ types->i64_type(), // int64_t x_low
+ types->i32_type(), // int32_t x_scale
+ types->i64_ptr_type(), // int64_t* dec_str_len
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_dec_to_string",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_dec_to_string));
+
+ // gdv_fn_like_utf8_utf8
+ args = {types->i64_type(), // int64_t ptr
+ types->i8_ptr_type(), // const char* data
+ types->i32_type(), // int data_len
+ types->i8_ptr_type(), // const char* pattern
+ types->i32_type()}; // int pattern_len
+
+ engine->AddGlobalMappingForFunc("gdv_fn_like_utf8_utf8",
+ types->i1_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_like_utf8_utf8));
+
+ // gdv_fn_like_utf8_utf8_utf8
+ args = {types->i64_type(), // int64_t ptr
+ types->i8_ptr_type(), // const char* data
+ types->i32_type(), // int data_len
+ types->i8_ptr_type(), // const char* pattern
+ types->i32_type(), // int pattern_len
+ types->i8_ptr_type(), // const char* escape_char
+ types->i32_type()}; // int escape_char_len
+
+ engine->AddGlobalMappingForFunc("gdv_fn_like_utf8_utf8_utf8",
+ types->i1_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_like_utf8_utf8_utf8));
+
+ // gdv_fn_ilike_utf8_utf8
+ args = {types->i64_type(), // int64_t ptr
+ types->i8_ptr_type(), // const char* data
+ types->i32_type(), // int data_len
+ types->i8_ptr_type(), // const char* pattern
+ types->i32_type()}; // int pattern_len
+
+ engine->AddGlobalMappingForFunc("gdv_fn_ilike_utf8_utf8",
+ types->i1_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_ilike_utf8_utf8));
+
+ // gdv_fn_regexp_replace_utf8_utf8
+ args = {types->i64_type(), // int64_t ptr
+ types->i64_type(), // int64_t holder_ptr
+ types->i8_ptr_type(), // const char* data
+ types->i32_type(), // int data_len
+ types->i8_ptr_type(), // const char* pattern
+ types->i32_type(), // int pattern_len
+ types->i8_ptr_type(), // const char* replace_string
+ types->i32_type(), // int32_t replace_string_len
+ types->i32_ptr_type()}; // int32_t* out_length
+
+ engine->AddGlobalMappingForFunc(
+ "gdv_fn_regexp_replace_utf8_utf8", types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_regexp_replace_utf8_utf8));
+
+ // gdv_fn_to_date_utf8_utf8
+ args = {types->i64_type(), // int64_t execution_context
+ types->i64_type(), // int64_t holder_ptr
+ types->i8_ptr_type(), // const char* data
+ types->i32_type(), // int data_len
+ types->i1_type(), // bool in1_validity
+ types->i8_ptr_type(), // const char* pattern
+ types->i32_type(), // int pattern_len
+ types->i1_type(), // bool in2_validity
+ types->ptr_type(types->i8_type())}; // bool* out_valid
+
+ engine->AddGlobalMappingForFunc("gdv_fn_to_date_utf8_utf8",
+ types->i64_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_to_date_utf8_utf8));
+
+ // gdv_fn_to_date_utf8_utf8_int32
+ args = {types->i64_type(), // int64_t execution_context
+ types->i64_type(), // int64_t holder_ptr
+ types->i8_ptr_type(), // const char* data
+ types->i32_type(), // int data_len
+ types->i1_type(), // bool in1_validity
+ types->i8_ptr_type(), // const char* pattern
+ types->i32_type(), // int pattern_len
+ types->i1_type(), // bool in2_validity
+ types->i32_type(), // int32_t suppress_errors
+ types->i1_type(), // bool in3_validity
+ types->ptr_type(types->i8_type())}; // bool* out_valid
+
+ engine->AddGlobalMappingForFunc(
+ "gdv_fn_to_date_utf8_utf8_int32", types->i64_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_to_date_utf8_utf8_int32));
+
+ // gdv_fn_in_expr_lookup_int32
+ args = {types->i64_type(), // int64_t in holder ptr
+ types->i32_type(), // int32 value
+ types->i1_type()}; // bool in_validity
+
+ engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_int32",
+ types->i1_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_in_expr_lookup_int32));
+
+ // gdv_fn_in_expr_lookup_int64
+ args = {types->i64_type(), // int64_t in holder ptr
+ types->i64_type(), // int64 value
+ types->i1_type()}; // bool in_validity
+
+ engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_int64",
+ types->i1_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_in_expr_lookup_int64));
+
+ // gdv_fn_in_expr_lookup_decimal
+ args = {types->i64_type(), // int64_t in holder ptr
+ types->i64_type(), // high decimal value
+ types->i64_type(), // low decimal value
+ types->i32_type(), // decimal precision value
+ types->i32_type(), // decimal scale value
+ types->i1_type()}; // bool in_validity
+
+ engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_decimal",
+ types->i1_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_in_expr_lookup_decimal));
+
+ // gdv_fn_in_expr_lookup_utf8
+ args = {types->i64_type(), // int64_t in holder ptr
+ types->i8_ptr_type(), // const char* value
+ types->i32_type(), // int value_len
+ types->i1_type()}; // bool in_validity
+
+ engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_utf8",
+ types->i1_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_in_expr_lookup_utf8));
+ // gdv_fn_in_expr_lookup_float
+ args = {types->i64_type(), // int64_t in holder ptr
+ types->float_type(), // float value
+ types->i1_type()}; // bool in_validity
+
+ engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_float",
+ types->i1_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_in_expr_lookup_float));
+ // gdv_fn_in_expr_lookup_double
+ args = {types->i64_type(), // int64_t in holder ptr
+ types->double_type(), // double value
+ types->i1_type()}; // bool in_validity
+
+ engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_double",
+ types->i1_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_in_expr_lookup_double));
+ // gdv_fn_populate_varlen_vector
+ args = {types->i64_type(), // int64_t execution_context
+ types->i8_ptr_type(), // int8_t* data ptr
+ types->i32_ptr_type(), // int32_t* offsets ptr
+ types->i64_type(), // int64_t slot
+ types->i8_ptr_type(), // const char* entry_buf
+ types->i32_type()}; // int32_t entry__len
+
+ engine->AddGlobalMappingForFunc("gdv_fn_populate_varlen_vector",
+ types->i32_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_populate_varlen_vector));
+
+ // gdv_fn_random
+ args = {types->i64_type()};
+ engine->AddGlobalMappingForFunc("gdv_fn_random", types->double_type(), args,
+ reinterpret_cast<void*>(gdv_fn_random));
+
+ args = {types->i64_type(), types->i32_type(), types->i1_type()};
+ engine->AddGlobalMappingForFunc("gdv_fn_random_with_seed", types->double_type(), args,
+ reinterpret_cast<void*>(gdv_fn_random_with_seed));
+
+ args = {types->i64_type(), // int64_t context_ptr
+ types->i8_ptr_type(), // const char* data
+ types->i32_type()}; // int32_t lenr
+
+ engine->AddGlobalMappingForFunc("gdv_fn_castINT_utf8", types->i32_type(), args,
+ reinterpret_cast<void*>(gdv_fn_castINT_utf8));
+
+ args = {types->i64_type(), // int64_t context_ptr
+ types->i8_ptr_type(), // const char* data
+ types->i32_type()}; // int32_t lenr
+
+ engine->AddGlobalMappingForFunc("gdv_fn_castBIGINT_utf8", types->i64_type(), args,
+ reinterpret_cast<void*>(gdv_fn_castBIGINT_utf8));
+
+ args = {types->i64_type(), // int64_t context_ptr
+ types->i8_ptr_type(), // const char* data
+ types->i32_type()}; // int32_t lenr
+
+ engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT4_utf8", types->float_type(), args,
+ reinterpret_cast<void*>(gdv_fn_castFLOAT4_utf8));
+
+ args = {types->i64_type(), // int64_t context_ptr
+ types->i8_ptr_type(), // const char* data
+ types->i32_type()}; // int32_t lenr
+
+ engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT8_utf8", types->double_type(), args,
+ reinterpret_cast<void*>(gdv_fn_castFLOAT8_utf8));
+
+ // gdv_fn_castVARCHAR_int32_int64
+ args = {types->i64_type(), // int64_t execution_context
+ types->i32_type(), // int32_t value
+ types->i64_type(), // int64_t len
+ types->i32_ptr_type()}; // int32_t* out_len
+ engine->AddGlobalMappingForFunc(
+ "gdv_fn_castVARCHAR_int32_int64", types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_castVARCHAR_int32_int64));
+
+ // gdv_fn_castVARCHAR_int64_int64
+ args = {types->i64_type(), // int64_t execution_context
+ types->i64_type(), // int64_t value
+ types->i64_type(), // int64_t len
+ types->i32_ptr_type()}; // int32_t* out_len
+ engine->AddGlobalMappingForFunc(
+ "gdv_fn_castVARCHAR_int64_int64", types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_castVARCHAR_int64_int64));
+
+ // gdv_fn_castVARCHAR_float32_int64
+ args = {types->i64_type(), // int64_t execution_context
+ types->float_type(), // float value
+ types->i64_type(), // int64_t len
+ types->i32_ptr_type()}; // int32_t* out_len
+ engine->AddGlobalMappingForFunc(
+ "gdv_fn_castVARCHAR_float32_int64", types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_castVARCHAR_float32_int64));
+
+ // gdv_fn_castVARCHAR_float64_int64
+ args = {types->i64_type(), // int64_t execution_context
+ types->double_type(), // double value
+ types->i64_type(), // int64_t len
+ types->i32_ptr_type()}; // int32_t* out_len
+ engine->AddGlobalMappingForFunc(
+ "gdv_fn_castVARCHAR_float64_int64", types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_castVARCHAR_float64_int64));
+
+ args = {types->i64_type(), // int64_t context_ptr
+ types->i8_ptr_type(), // const char* data
+ types->i32_type()}; // int32_t lenr
+
+ engine->AddGlobalMappingForFunc("gdv_fn_castINT_varbinary", types->i32_type(), args,
+ reinterpret_cast<void*>(gdv_fn_castINT_varbinary));
+
+ args = {types->i64_type(), // int64_t context_ptr
+ types->i8_ptr_type(), // const char* data
+ types->i32_type()}; // int32_t lenr
+
+ engine->AddGlobalMappingForFunc("gdv_fn_castBIGINT_varbinary", types->i64_type(), args,
+ reinterpret_cast<void*>(gdv_fn_castBIGINT_varbinary));
+
+ args = {types->i64_type(), // int64_t context_ptr
+ types->i8_ptr_type(), // const char* data
+ types->i32_type()}; // int32_t lenr
+
+ engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT4_varbinary", types->float_type(),
+ args,
+ reinterpret_cast<void*>(gdv_fn_castFLOAT4_varbinary));
+
+ args = {types->i64_type(), // int64_t context_ptr
+ types->i8_ptr_type(), // const char* data
+ types->i32_type()}; // int32_t lenr
+
+ engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT8_varbinary", types->double_type(),
+ args,
+ reinterpret_cast<void*>(gdv_fn_castFLOAT8_varbinary));
+
+ // gdv_fn_sha1_int8
+ args = {
+ types->i64_type(), // context
+ types->i8_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_int8",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_int8));
+
+ // gdv_fn_sha1_int16
+ args = {
+ types->i64_type(), // context
+ types->i16_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_int16",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_int16));
+
+ // gdv_fn_sha1_int32
+ args = {
+ types->i64_type(), // context
+ types->i32_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_int32",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_int32));
+
+ // gdv_fn_sha1_int32
+ args = {
+ types->i64_type(), // context
+ types->i64_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_int64",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_int64));
+
+ // gdv_fn_sha1_uint8
+ args = {
+ types->i64_type(), // context
+ types->i8_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_uint8",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_uint8));
+
+ // gdv_fn_sha1_uint16
+ args = {
+ types->i64_type(), // context
+ types->i16_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_uint16",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_uint16));
+
+ // gdv_fn_sha1_uint32
+ args = {
+ types->i64_type(), // context
+ types->i32_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_uint32",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_uint32));
+
+ // gdv_fn_sha1_uint64
+ args = {
+ types->i64_type(), // context
+ types->i64_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_uint64",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_uint64));
+
+ // gdv_fn_sha1_float32
+ args = {
+ types->i64_type(), // context
+ types->float_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_float32",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_float32));
+
+ // gdv_fn_sha1_float64
+ args = {
+ types->i64_type(), // context
+ types->double_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_float64",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_float64));
+
+ // gdv_fn_sha1_boolean
+ args = {
+ types->i64_type(), // context
+ types->i1_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_boolean",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_boolean));
+
+ // gdv_fn_sha1_date64
+ args = {
+ types->i64_type(), // context
+ types->i64_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_date64",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_date64));
+
+ // gdv_fn_sha1_date32
+ args = {
+ types->i64_type(), // context
+ types->i32_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_date32",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_date32));
+
+ // gdv_fn_sha1_time32
+ args = {
+ types->i64_type(), // context
+ types->i32_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_time32",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_time32));
+
+ // gdv_fn_sha1_timestamp
+ args = {
+ types->i64_type(), // context
+ types->i64_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_timestamp",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_timestamp));
+
+ // gdv_fn_sha1_from_utf8
+ args = {
+ types->i64_type(), // context
+ types->i8_ptr_type(), // const char*
+ types->i32_type(), // value_length
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_utf8",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_utf8));
+
+ // gdv_fn_sha1_from_binary
+ args = {
+ types->i64_type(), // context
+ types->i8_ptr_type(), // const char*
+ types->i32_type(), // value_length
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_binary",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_binary));
+
+ // gdv_fn_sha256_int8
+ args = {
+ types->i64_type(), // context
+ types->i8_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_int8",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_int8));
+
+ // gdv_fn_sha256_int16
+ args = {
+ types->i64_type(), // context
+ types->i16_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_int16",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_int16));
+
+ // gdv_fn_sha256_int32
+ args = {
+ types->i64_type(), // context
+ types->i32_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_int32",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_int32));
+
+ // gdv_fn_sha256_int32
+ args = {
+ types->i64_type(), // context
+ types->i64_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_int64",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_int64));
+
+ // gdv_fn_sha256_uint8
+ args = {
+ types->i64_type(), // context
+ types->i8_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_uint8",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_uint8));
+
+ // gdv_fn_sha256_uint16
+ args = {
+ types->i64_type(), // context
+ types->i16_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_uint16",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_uint16));
+
+ // gdv_fn_sha256_uint32
+ args = {
+ types->i64_type(), // context
+ types->i32_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_uint32",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_uint32));
+
+ // gdv_fn_sha256_uint64
+ args = {
+ types->i64_type(), // context
+ types->i64_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_uint64",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_uint64));
+
+ // gdv_fn_sha256_float32
+ args = {
+ types->i64_type(), // context
+ types->float_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_float32",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_float32));
+
+ // gdv_fn_sha256_float64
+ args = {
+ types->i64_type(), // context
+ types->double_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_float64",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_float64));
+
+ // gdv_fn_sha256_boolean
+ args = {
+ types->i64_type(), // context
+ types->i1_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_boolean",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_boolean));
+
+ // gdv_fn_sha256_date64
+ args = {
+ types->i64_type(), // context
+ types->i64_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_date64",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_date64));
+
+ // gdv_fn_sha256_date32
+ args = {
+ types->i64_type(), // context
+ types->i32_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_date32",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_date32));
+
+ // gdv_fn_sha256_time32
+ args = {
+ types->i64_type(), // context
+ types->i32_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_time32",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_time32));
+
+ // gdv_fn_sha256_timestamp
+ args = {
+ types->i64_type(), // context
+ types->i64_type(), // value
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out_length
+ };
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_timestamp",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_timestamp));
+
+ // gdv_fn_hash_sha256_from_utf8
+ args = {
+ types->i64_type(), // context
+ types->i8_ptr_type(), // const char*
+ types->i32_type(), // value_length
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_utf8",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_utf8));
+
+ // gdv_fn_hash_sha256_from_binary
+ args = {
+ types->i64_type(), // context
+ types->i8_ptr_type(), // const char*
+ types->i32_type(), // value_length
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_binary",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_binary));
+
+ // gdv_fn_sha1_decimal128
+ args = {
+ types->i64_type(), // context
+ types->i64_type(), // high_bits
+ types->i64_type(), // low_bits
+ types->i32_type(), // precision
+ types->i32_type(), // scale
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out length
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_sha1_decimal128",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha1_decimal128));
+ // gdv_fn_sha256_decimal128
+ args = {
+ types->i64_type(), // context
+ types->i64_type(), // high_bits
+ types->i64_type(), // low_bits
+ types->i32_type(), // precision
+ types->i32_type(), // scale
+ types->i1_type(), // validity
+ types->i32_ptr_type() // out length
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_sha256_decimal128",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_sha256_decimal128));
+
+ // gdv_fn_base64_encode_utf8
+ args = {
+ types->i64_type(), // context
+ types->i8_ptr_type(), // in
+ types->i32_type(), // in_len
+ types->i32_ptr_type(), // out_len
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_base64_encode_binary",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_base64_encode_binary));
+
+ // gdv_fn_base64_decode_utf8
+ args = {
+ types->i64_type(), // context
+ types->i8_ptr_type(), // in
+ types->i32_type(), // in_len
+ types->i32_ptr_type(), // out_len
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_base64_decode_utf8",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_base64_decode_utf8));
+
+ // gdv_fn_upper_utf8
+ args = {
+ types->i64_type(), // context
+ types->i8_ptr_type(), // data
+ types->i32_type(), // data_len
+ types->i32_ptr_type(), // out_len
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_upper_utf8",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_upper_utf8));
+ // gdv_fn_lower_utf8
+ args = {
+ types->i64_type(), // context
+ types->i8_ptr_type(), // data
+ types->i32_type(), // data_len
+ types->i32_ptr_type(), // out_len
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_lower_utf8",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_lower_utf8));
+
+ // gdv_fn_initcap_utf8
+ args = {
+ types->i64_type(), // context
+ types->i8_ptr_type(), // const char*
+ types->i32_type(), // value_length
+ types->i32_ptr_type() // out_length
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_initcap_utf8",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_initcap_utf8));
+}
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/gdv_function_stubs.h b/src/arrow/cpp/src/gandiva/gdv_function_stubs.h
new file mode 100644
index 000000000..670ac94df
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/gdv_function_stubs.h
@@ -0,0 +1,173 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+
+#include "gandiva/visibility.h"
+
+/// Stub functions that can be accessed from LLVM.
+extern "C" {
+
+using gdv_boolean = bool;
+using gdv_int8 = int8_t;
+using gdv_int16 = int16_t;
+using gdv_int32 = int32_t;
+using gdv_int64 = int64_t;
+using gdv_uint8 = uint8_t;
+using gdv_uint16 = uint16_t;
+using gdv_uint32 = uint32_t;
+using gdv_uint64 = uint64_t;
+using gdv_float32 = float;
+using gdv_float64 = double;
+using gdv_date64 = int64_t;
+using gdv_date32 = int32_t;
+using gdv_time32 = int32_t;
+using gdv_timestamp = int64_t;
+using gdv_utf8 = char*;
+using gdv_binary = char*;
+using gdv_day_time_interval = int64_t;
+using gdv_month_interval = int32_t;
+
+#ifdef GANDIVA_UNIT_TEST
+// unit tests may be compiled without O2, so inlining may not happen.
+#define GDV_FORCE_INLINE
+#else
+#ifdef _MSC_VER
+#define GDV_FORCE_INLINE __forceinline
+#else
+#define GDV_FORCE_INLINE inline __attribute__((always_inline))
+#endif
+#endif
+
+bool gdv_fn_like_utf8_utf8(int64_t ptr, const char* data, int data_len,
+ const char* pattern, int pattern_len);
+
+bool gdv_fn_like_utf8_utf8_utf8(int64_t ptr, const char* data, int data_len,
+ const char* pattern, int pattern_len,
+ const char* escape_char, int escape_char_len);
+
+bool gdv_fn_ilike_utf8_utf8(int64_t ptr, const char* data, int data_len,
+ const char* pattern, int pattern_len);
+
+int64_t gdv_fn_to_date_utf8_utf8_int32(int64_t context, int64_t ptr, const char* data,
+ int data_len, bool in1_validity,
+ const char* pattern, int pattern_len,
+ bool in2_validity, int32_t suppress_errors,
+ bool in3_validity, bool* out_valid);
+
+void gdv_fn_context_set_error_msg(int64_t context_ptr, const char* err_msg);
+
+uint8_t* gdv_fn_context_arena_malloc(int64_t context_ptr, int32_t data_len);
+
+void gdv_fn_context_arena_reset(int64_t context_ptr);
+
+bool in_expr_lookup_int32(int64_t ptr, int32_t value, bool in_validity);
+
+bool in_expr_lookup_int64(int64_t ptr, int64_t value, bool in_validity);
+
+bool in_expr_lookup_utf8(int64_t ptr, const char* data, int data_len, bool in_validity);
+
+int gdv_fn_time_with_zone(int* time_fields, const char* zone, int zone_len,
+ int64_t* ret_time);
+
+GANDIVA_EXPORT
+const char* gdv_fn_base64_encode_binary(int64_t context, const char* in, int32_t in_len,
+ int32_t* out_len);
+
+GANDIVA_EXPORT
+const char* gdv_fn_base64_decode_utf8(int64_t context, const char* in, int32_t in_len,
+ int32_t* out_len);
+
+GANDIVA_EXPORT
+const char* gdv_fn_castVARBINARY_int32_int64(int64_t context, gdv_int32 value,
+ int64_t out_len, int32_t* out_length);
+
+GANDIVA_EXPORT
+const char* gdv_fn_castVARBINARY_int64_int64(int64_t context, gdv_int64 value,
+ int64_t out_len, int32_t* out_length);
+
+GANDIVA_EXPORT
+const char* gdv_fn_sha256_decimal128(int64_t context, int64_t x_high, uint64_t x_low,
+ int32_t x_precision, int32_t x_scale,
+ gdv_boolean x_isvalid, int32_t* out_length);
+
+GANDIVA_EXPORT
+const char* gdv_fn_sha1_decimal128(int64_t context, int64_t x_high, uint64_t x_low,
+ int32_t x_precision, int32_t x_scale,
+ gdv_boolean x_isvalid, int32_t* out_length);
+
+int32_t gdv_fn_dec_from_string(int64_t context, const char* in, int32_t in_length,
+ int32_t* precision_from_str, int32_t* scale_from_str,
+ int64_t* dec_high_from_str, uint64_t* dec_low_from_str);
+
+char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low,
+ int32_t x_scale, int32_t* dec_str_len);
+
+GANDIVA_EXPORT
+int32_t gdv_fn_castINT_utf8(int64_t context, const char* data, int32_t data_len);
+
+GANDIVA_EXPORT
+int64_t gdv_fn_castBIGINT_utf8(int64_t context, const char* data, int32_t data_len);
+
+GANDIVA_EXPORT
+float gdv_fn_castFLOAT4_utf8(int64_t context, const char* data, int32_t data_len);
+
+GANDIVA_EXPORT
+double gdv_fn_castFLOAT8_utf8(int64_t context, const char* data, int32_t data_len);
+
+GANDIVA_EXPORT
+const char* gdv_fn_castVARCHAR_int32_int64(int64_t context, int32_t value, int64_t len,
+ int32_t* out_len);
+GANDIVA_EXPORT
+const char* gdv_fn_castVARCHAR_int64_int64(int64_t context, int64_t value, int64_t len,
+ int32_t* out_len);
+GANDIVA_EXPORT
+const char* gdv_fn_castVARCHAR_float32_int64(int64_t context, float value, int64_t len,
+ int32_t* out_len);
+GANDIVA_EXPORT
+const char* gdv_fn_castVARCHAR_float64_int64(int64_t context, double value, int64_t len,
+ int32_t* out_len);
+
+GANDIVA_EXPORT
+int32_t gdv_fn_utf8_char_length(char c);
+
+GANDIVA_EXPORT
+const char* gdv_fn_upper_utf8(int64_t context, const char* data, int32_t data_len,
+ int32_t* out_len);
+
+GANDIVA_EXPORT
+const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_len,
+ int32_t* out_len);
+
+GANDIVA_EXPORT
+const char* gdv_fn_initcap_utf8(int64_t context, const char* data, int32_t data_len,
+ int32_t* out_len);
+
+GANDIVA_EXPORT
+int32_t gdv_fn_castINT_varbinary(gdv_int64 context, const char* in, int32_t in_len);
+
+GANDIVA_EXPORT
+int64_t gdv_fn_castBIGINT_varbinary(gdv_int64 context, const char* in, int32_t in_len);
+
+GANDIVA_EXPORT
+float gdv_fn_castFLOAT4_varbinary(gdv_int64 context, const char* in, int32_t in_len);
+
+GANDIVA_EXPORT
+double gdv_fn_castFLOAT8_varbinary(gdv_int64 context, const char* in, int32_t in_len);
+}
diff --git a/src/arrow/cpp/src/gandiva/gdv_function_stubs_test.cc b/src/arrow/cpp/src/gandiva/gdv_function_stubs_test.cc
new file mode 100644
index 000000000..f7c21981c
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/gdv_function_stubs_test.cc
@@ -0,0 +1,769 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/gdv_function_stubs.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "gandiva/execution_context.h"
+
+namespace gandiva {
+
+TEST(TestGdvFnStubs, TestCastVarbinaryNumeric) {
+ gandiva::ExecutionContext ctx;
+
+ int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+ int32_t out_len = 0;
+
+ // tests for integer values as input
+ const char* out_str = gdv_fn_castVARBINARY_int32_int64(ctx_ptr, -46, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "-46");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARBINARY_int32_int64(ctx_ptr, 2147483647, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "2147483647");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARBINARY_int32_int64(ctx_ptr, -2147483647 - 1, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "-2147483648");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARBINARY_int32_int64(ctx_ptr, 0, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "0");
+ EXPECT_FALSE(ctx.has_error());
+
+ // test with required length less than actual buffer length
+ out_str = gdv_fn_castVARBINARY_int32_int64(ctx_ptr, 34567, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "345");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARBINARY_int32_int64(ctx_ptr, 347, 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ gdv_fn_castVARBINARY_int32_int64(ctx_ptr, 347, -1, &out_len);
+ EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer length can not be negative"));
+ ctx.Reset();
+
+ // tests for big integer values as input
+ out_str =
+ gdv_fn_castVARBINARY_int64_int64(ctx_ptr, 9223372036854775807LL, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "9223372036854775807");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARBINARY_int64_int64(ctx_ptr, -9223372036854775807LL - 1, 100,
+ &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "-9223372036854775808");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARBINARY_int64_int64(ctx_ptr, 0, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "0");
+ EXPECT_FALSE(ctx.has_error());
+
+ // test with required length less than actual buffer length
+ out_str = gdv_fn_castVARBINARY_int64_int64(ctx_ptr, 12345, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "123");
+ EXPECT_FALSE(ctx.has_error());
+}
+
+TEST(TestGdvFnStubs, TestBase64Encode) {
+ gandiva::ExecutionContext ctx;
+
+ auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+ int32_t out_len = 0;
+
+ auto value = gdv_fn_base64_encode_binary(ctx_ptr, "hello", 5, &out_len);
+ std::string out_value = std::string(value, out_len);
+ EXPECT_EQ(out_value, "aGVsbG8=");
+
+ value = gdv_fn_base64_encode_binary(ctx_ptr, "test", 4, &out_len);
+ out_value = std::string(value, out_len);
+ EXPECT_EQ(out_value, "dGVzdA==");
+
+ value = gdv_fn_base64_encode_binary(ctx_ptr, "hive", 4, &out_len);
+ out_value = std::string(value, out_len);
+ EXPECT_EQ(out_value, "aGl2ZQ==");
+
+ value = gdv_fn_base64_encode_binary(ctx_ptr, "", 0, &out_len);
+ out_value = std::string(value, out_len);
+ EXPECT_EQ(out_value, "");
+
+ value = gdv_fn_base64_encode_binary(ctx_ptr, "test", -5, &out_len);
+ out_value = std::string(value, out_len);
+ EXPECT_EQ(out_value, "");
+ EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer length can not be negative"));
+ ctx.Reset();
+}
+
+TEST(TestGdvFnStubs, TestBase64Decode) {
+ gandiva::ExecutionContext ctx;
+
+ auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+ int32_t out_len = 0;
+
+ auto value = gdv_fn_base64_decode_utf8(ctx_ptr, "aGVsbG8=", 8, &out_len);
+ std::string out_value = std::string(value, out_len);
+ EXPECT_EQ(out_value, "hello");
+
+ value = gdv_fn_base64_decode_utf8(ctx_ptr, "dGVzdA==", 8, &out_len);
+ out_value = std::string(value, out_len);
+ EXPECT_EQ(out_value, "test");
+
+ value = gdv_fn_base64_decode_utf8(ctx_ptr, "aGl2ZQ==", 8, &out_len);
+ out_value = std::string(value, out_len);
+ EXPECT_EQ(out_value, "hive");
+
+ value = gdv_fn_base64_decode_utf8(ctx_ptr, "", 0, &out_len);
+ out_value = std::string(value, out_len);
+ EXPECT_EQ(out_value, "");
+
+ value = gdv_fn_base64_decode_utf8(ctx_ptr, "test", -5, &out_len);
+ out_value = std::string(value, out_len);
+ EXPECT_EQ(out_value, "");
+ EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer length can not be negative"));
+ ctx.Reset();
+}
+
+TEST(TestGdvFnStubs, TestCastINT) {
+ gandiva::ExecutionContext ctx;
+
+ int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+
+ EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "-45", 3), -45);
+ EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "0", 1), 0);
+ EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "2147483647", 10), 2147483647);
+ EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "02147483647", 11), 2147483647);
+ EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "-2147483648", 11), -2147483648LL);
+ EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "-02147483648", 12), -2147483648LL);
+ EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, " 12 ", 4), 12);
+
+ gdv_fn_castINT_utf8(ctx_ptr, "2147483648", 10);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string 2147483648 to int32"));
+ ctx.Reset();
+
+ gdv_fn_castINT_utf8(ctx_ptr, "-2147483649", 11);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string -2147483649 to int32"));
+ ctx.Reset();
+
+ gdv_fn_castINT_utf8(ctx_ptr, "12.34", 5);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string 12.34 to int32"));
+ ctx.Reset();
+
+ gdv_fn_castINT_utf8(ctx_ptr, "abc", 3);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string abc to int32"));
+ ctx.Reset();
+
+ gdv_fn_castINT_utf8(ctx_ptr, "", 0);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string to int32"));
+ ctx.Reset();
+
+ gdv_fn_castINT_utf8(ctx_ptr, "-", 1);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string - to int32"));
+ ctx.Reset();
+}
+
+TEST(TestGdvFnStubs, TestCastBIGINT) {
+ gandiva::ExecutionContext ctx;
+
+ int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+
+ EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "-45", 3), -45);
+ EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "0", 1), 0);
+ EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "9223372036854775807", 19),
+ 9223372036854775807LL);
+ EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "09223372036854775807", 20),
+ 9223372036854775807LL);
+ EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "-9223372036854775808", 20),
+ -9223372036854775807LL - 1);
+ EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "-009223372036854775808", 22),
+ -9223372036854775807LL - 1);
+ EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, " 12 ", 4), 12);
+
+ gdv_fn_castBIGINT_utf8(ctx_ptr, "9223372036854775808", 19);
+ EXPECT_THAT(
+ ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string 9223372036854775808 to int64"));
+ ctx.Reset();
+
+ gdv_fn_castBIGINT_utf8(ctx_ptr, "-9223372036854775809", 20);
+ EXPECT_THAT(
+ ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string -9223372036854775809 to int64"));
+ ctx.Reset();
+
+ gdv_fn_castBIGINT_utf8(ctx_ptr, "12.34", 5);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string 12.34 to int64"));
+ ctx.Reset();
+
+ gdv_fn_castBIGINT_utf8(ctx_ptr, "abc", 3);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string abc to int64"));
+ ctx.Reset();
+
+ gdv_fn_castBIGINT_utf8(ctx_ptr, "", 0);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string to int64"));
+ ctx.Reset();
+
+ gdv_fn_castBIGINT_utf8(ctx_ptr, "-", 1);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string - to int64"));
+ ctx.Reset();
+}
+
+TEST(TestGdvFnStubs, TestCastFloat4) {
+ gandiva::ExecutionContext ctx;
+
+ int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+
+ EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, "-45.34", 6), -45.34f);
+ EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, "0", 1), 0.0f);
+ EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, "5", 1), 5.0f);
+ EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, " 3.4 ", 5), 3.4f);
+
+ gdv_fn_castFLOAT4_utf8(ctx_ptr, "", 0);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string to float"));
+ ctx.Reset();
+
+ gdv_fn_castFLOAT4_utf8(ctx_ptr, "e", 1);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string e to float"));
+ ctx.Reset();
+}
+
+TEST(TestGdvFnStubs, TestCastFloat8) {
+ gandiva::ExecutionContext ctx;
+
+ int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+
+ EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, "-45.34", 6), -45.34);
+ EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, "0", 1), 0.0);
+ EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, "5", 1), 5.0);
+ EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, " 3.4 ", 5), 3.4);
+
+ gdv_fn_castFLOAT8_utf8(ctx_ptr, "", 0);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string to double"));
+ ctx.Reset();
+
+ gdv_fn_castFLOAT8_utf8(ctx_ptr, "e", 1);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string e to double"));
+ ctx.Reset();
+}
+
+TEST(TestGdvFnStubs, TestCastVARCHARFromInt32) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+ int32_t out_len = 0;
+
+ const char* out_str = gdv_fn_castVARCHAR_int32_int64(ctx_ptr, -46, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "-46");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_int32_int64(ctx_ptr, 2147483647, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "2147483647");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_int32_int64(ctx_ptr, -2147483647 - 1, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "-2147483648");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_int32_int64(ctx_ptr, 0, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "0");
+ EXPECT_FALSE(ctx.has_error());
+
+ // test with required length less than actual buffer length
+ out_str = gdv_fn_castVARCHAR_int32_int64(ctx_ptr, 34567, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "345");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_int32_int64(ctx_ptr, 347, 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_int32_int64(ctx_ptr, 347, -1, &out_len);
+ EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer length can not be negative"));
+ ctx.Reset();
+}
+
+TEST(TestGdvFnStubs, TestCastVARCHARFromInt64) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+ int32_t out_len = 0;
+
+ const char* out_str =
+ gdv_fn_castVARCHAR_int64_int64(ctx_ptr, 9223372036854775807LL, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "9223372036854775807");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str =
+ gdv_fn_castVARCHAR_int64_int64(ctx_ptr, -9223372036854775807LL - 1, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "-9223372036854775808");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_int64_int64(ctx_ptr, 0, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "0");
+ EXPECT_FALSE(ctx.has_error());
+
+ // test with required length less than actual buffer length
+ out_str = gdv_fn_castVARCHAR_int64_int64(ctx_ptr, 12345, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "123");
+ EXPECT_FALSE(ctx.has_error());
+}
+
+TEST(TestGdvFnStubs, TestCastVARCHARFromFloat) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+ int32_t out_len = 0;
+
+ const char* out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 4.567f, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "4.567");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, -3.4567f, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "-3.4567");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 0.00001f, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "1.0E-5");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 0.00099999f, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "9.9999E-4");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 0.0f, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "0.0");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 10.00000f, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "10.0");
+ EXPECT_FALSE(ctx.has_error());
+
+ // test with required length less than actual buffer length
+ out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 1.2345f, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "1.2");
+ EXPECT_FALSE(ctx.has_error());
+}
+
+TEST(TestGdvFnStubs, TestCastVARCHARFromDouble) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+ int32_t out_len = 0;
+
+ const char* out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, 4.567, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "4.567");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, -3.4567, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "-3.4567");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, 0.00001, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "1.0E-5");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 0.00099999f, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "9.9999E-4");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, 0.0, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "0.0");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, 10.0000000000, 100, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "10.0");
+ EXPECT_FALSE(ctx.has_error());
+
+ // test with required length less than actual buffer length
+ out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, 1.2345, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "1.2");
+ EXPECT_FALSE(ctx.has_error());
+}
+
+TEST(TestGdvFnStubs, TestUpper) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+
+ const char* out_str = gdv_fn_upper_utf8(ctx_ptr, "AbcDEfGh", 8, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "ABCDEFGH");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_upper_utf8(ctx_ptr, "asdfj", 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "ASDFJ");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_upper_utf8(ctx_ptr, "s;dcGS,jO!l", 11, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "S;DCGS,JO!L");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_upper_utf8(ctx_ptr, "münchen", 8, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "MÜNCHEN");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_upper_utf8(ctx_ptr, "CITROËN", 8, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "CITROËN");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_upper_utf8(ctx_ptr, "âBćDëFGH", 11, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "ÂBĆDËFGH");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_upper_utf8(ctx_ptr, "øhpqRšvñ", 11, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "ØHPQRŠVÑ");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_upper_utf8(ctx_ptr, "Möbelträgerfüße", 19, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "MÖBELTRÄGERFÜẞE");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_upper_utf8(ctx_ptr, "{õhp,PQŚv}ń+", 15, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "{ÕHP,PQŚV}Ń+");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_upper_utf8(ctx_ptr, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ std::string d("AbOJjÜoß\xc3");
+ out_str = gdv_fn_upper_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr(
+ "unexpected byte \\c3 encountered while decoding utf8 string"));
+ ctx.Reset();
+
+ std::string e(
+ "åbÑg\xe0\xa0"
+ "åBUå");
+ out_str = gdv_fn_upper_utf8(ctx_ptr, e.data(), static_cast<int>(e.length()), &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr(
+ "unexpected byte \\e0 encountered while decoding utf8 string"));
+ ctx.Reset();
+}
+
+TEST(TestGdvFnStubs, TestLower) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+
+ const char* out_str = gdv_fn_lower_utf8(ctx_ptr, "AbcDEfGh", 8, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "abcdefgh");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_lower_utf8(ctx_ptr, "asdfj", 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "asdfj");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_lower_utf8(ctx_ptr, "S;DCgs,Jo!L", 11, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "s;dcgs,jo!l");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_lower_utf8(ctx_ptr, "MÜNCHEN", 8, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "münchen");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_lower_utf8(ctx_ptr, "citroën", 8, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "citroën");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_lower_utf8(ctx_ptr, "ÂbĆDËFgh", 11, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "âbćdëfgh");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_lower_utf8(ctx_ptr, "ØHPQrŠvÑ", 11, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "øhpqršvñ");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_lower_utf8(ctx_ptr, "MÖBELTRÄGERFÜẞE", 20, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "möbelträgerfüße");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_lower_utf8(ctx_ptr, "{ÕHP,pqśv}Ń+", 15, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "{õhp,pqśv}ń+");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_lower_utf8(ctx_ptr, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ std::string d("AbOJjÜoß\xc3");
+ out_str = gdv_fn_lower_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr(
+ "unexpected byte \\c3 encountered while decoding utf8 string"));
+ ctx.Reset();
+
+ std::string e(
+ "åbÑg\xe0\xa0"
+ "åBUå");
+ out_str = gdv_fn_lower_utf8(ctx_ptr, e.data(), static_cast<int>(e.length()), &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr(
+ "unexpected byte \\e0 encountered while decoding utf8 string"));
+ ctx.Reset();
+}
+
+TEST(TestGdvFnStubs, TestInitCap) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+
+ const char* out_str = gdv_fn_initcap_utf8(ctx_ptr, "test string", 11, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Test String");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_initcap_utf8(ctx_ptr, "asdfj\nhlqf", 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Asdfj\nHlqf");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_initcap_utf8(ctx_ptr, "s;DCgs,Jo!l", 11, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "S;Dcgs,Jo!L");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_initcap_utf8(ctx_ptr, " mÜNCHEN", 9, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), " München");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_initcap_utf8(ctx_ptr, "citroën CaR", 12, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Citroën Car");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_initcap_utf8(ctx_ptr, "ÂbĆDËFgh\néll", 16, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Âbćdëfgh\nÉll");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_initcap_utf8(ctx_ptr, " øhpqršvñ \n\n", 17, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), " Øhpqršvñ \n\n");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str =
+ gdv_fn_initcap_utf8(ctx_ptr, "möbelträgerfüße \nmöbelträgerfüße", 42, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Möbelträgerfüße \nMöbelträgerfüße");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_initcap_utf8(ctx_ptr, "{ÕHP,pqśv}Ń+", 15, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "{Õhp,Pqśv}Ń+");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_initcap_utf8(ctx_ptr, "sɦasasdsɦsd\"sdsdɦ", 19, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Sɦasasdsɦsd\"Sdsdɦ");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_initcap_utf8(ctx_ptr, "mysuperscipt@number²isfine", 27, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Mysuperscipt@Number²Isfine");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_initcap_utf8(ctx_ptr, "Ő<tŵas̓老ƕɱ¢vIYwށ", 25, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Ő<Tŵas̓老Ƕɱ¢Viywށ");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_initcap_utf8(ctx_ptr, "ↆcheckↆnumberisspace", 24, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "ↆcheckↆnumberisspace");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_initcap_utf8(ctx_ptr, "testing ᾌTitleᾌcase", 23, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Testing ᾌtitleᾄcase");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_initcap_utf8(ctx_ptr, "ʳTesting mʳodified", 20, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "ʳTesting MʳOdified");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_initcap_utf8(ctx_ptr, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ std::string d("AbOJjÜoß\xc3");
+ out_str =
+ gdv_fn_initcap_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr(
+ "unexpected byte \\c3 encountered while decoding utf8 string"));
+ ctx.Reset();
+
+ std::string e(
+ "åbÑg\xe0\xa0"
+ "åBUå");
+ out_str =
+ gdv_fn_initcap_utf8(ctx_ptr, e.data(), static_cast<int>(e.length()), &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr(
+ "unexpected byte \\e0 encountered while decoding utf8 string"));
+ ctx.Reset();
+}
+
+TEST(TestGdvFnStubs, TestCastVarbinaryINT) {
+ gandiva::ExecutionContext ctx;
+
+ int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+
+ EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "-45", 3), -45);
+ EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "0", 1), 0);
+ EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "2147483647", 10), 2147483647);
+ EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "\x32\x33", 2), 23);
+ EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "02147483647", 11), 2147483647);
+ EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "-2147483648", 11), -2147483648LL);
+ EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "-02147483648", 12), -2147483648LL);
+ EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, " 12 ", 4), 12);
+
+ gdv_fn_castINT_varbinary(ctx_ptr, "2147483648", 10);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string 2147483648 to int32"));
+ ctx.Reset();
+
+ gdv_fn_castINT_varbinary(ctx_ptr, "-2147483649", 11);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string -2147483649 to int32"));
+ ctx.Reset();
+
+ gdv_fn_castINT_varbinary(ctx_ptr, "12.34", 5);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string 12.34 to int32"));
+ ctx.Reset();
+
+ gdv_fn_castINT_varbinary(ctx_ptr, "abc", 3);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string abc to int32"));
+ ctx.Reset();
+
+ gdv_fn_castINT_varbinary(ctx_ptr, "", 0);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string to int32"));
+ ctx.Reset();
+
+ gdv_fn_castINT_varbinary(ctx_ptr, "-", 1);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string - to int32"));
+ ctx.Reset();
+}
+
+TEST(TestGdvFnStubs, TestCastVarbinaryBIGINT) {
+ gandiva::ExecutionContext ctx;
+
+ int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+
+ EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "-45", 3), -45);
+ EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "0", 1), 0);
+ EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "9223372036854775807", 19),
+ 9223372036854775807LL);
+ EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "09223372036854775807", 20),
+ 9223372036854775807LL);
+ EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "-9223372036854775808", 20),
+ -9223372036854775807LL - 1);
+ EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "-009223372036854775808", 22),
+ -9223372036854775807LL - 1);
+ EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, " 12 ", 4), 12);
+
+ EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr,
+ "\x39\x39\x39\x39\x39\x39\x39\x39\x39\x39", 10),
+ 9999999999LL);
+
+ gdv_fn_castBIGINT_varbinary(ctx_ptr, "9223372036854775808", 19);
+ EXPECT_THAT(
+ ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string 9223372036854775808 to int64"));
+ ctx.Reset();
+
+ gdv_fn_castBIGINT_varbinary(ctx_ptr, "-9223372036854775809", 20);
+ EXPECT_THAT(
+ ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string -9223372036854775809 to int64"));
+ ctx.Reset();
+
+ gdv_fn_castBIGINT_varbinary(ctx_ptr, "12.34", 5);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string 12.34 to int64"));
+ ctx.Reset();
+
+ gdv_fn_castBIGINT_varbinary(ctx_ptr, "abc", 3);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string abc to int64"));
+ ctx.Reset();
+
+ gdv_fn_castBIGINT_varbinary(ctx_ptr, "", 0);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string to int64"));
+ ctx.Reset();
+
+ gdv_fn_castBIGINT_varbinary(ctx_ptr, "-", 1);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string - to int64"));
+ ctx.Reset();
+}
+
+TEST(TestGdvFnStubs, TestCastVarbinaryFloat4) {
+ gandiva::ExecutionContext ctx;
+
+ int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+
+ EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "-45.34", 6), -45.34f);
+ EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "0", 1), 0.0f);
+ EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "5", 1), 5.0f);
+ EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, " 3.4 ", 5), 3.4f);
+ EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, " \x33\x2E\x34 ", 5), 3.4f);
+
+ gdv_fn_castFLOAT4_varbinary(ctx_ptr, "", 0);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string to float"));
+ ctx.Reset();
+
+ gdv_fn_castFLOAT4_varbinary(ctx_ptr, "e", 1);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string e to float"));
+ ctx.Reset();
+}
+
+TEST(TestGdvFnStubs, TestCastVarbinaryFloat8) {
+ gandiva::ExecutionContext ctx;
+
+ int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+
+ EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, "-45.34", 6), -45.34);
+ EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, "0", 1), 0.0);
+ EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, "5", 1), 5.0);
+ EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, " \x33\x2E\x34 ", 5), 3.4);
+
+ gdv_fn_castFLOAT8_varbinary(ctx_ptr, "", 0);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string to double"));
+ ctx.Reset();
+
+ gdv_fn_castFLOAT8_varbinary(ctx_ptr, "e", 1);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Failed to cast the string e to double"));
+ ctx.Reset();
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/greedy_dual_size_cache.h b/src/arrow/cpp/src/gandiva/greedy_dual_size_cache.h
new file mode 100644
index 000000000..cb5c38e07
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/greedy_dual_size_cache.h
@@ -0,0 +1,154 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <list>
+#include <queue>
+#include <set>
+#include <unordered_map>
+#include <utility>
+
+#include "arrow/util/optional.h"
+
+// modified cache to support evict policy using the GreedyDual-Size algorithm.
+namespace gandiva {
+// Defines a base value object supported on the cache that may contain properties
+template <typename ValueType>
+class ValueCacheObject {
+ public:
+ ValueCacheObject(ValueType module, uint64_t cost) : module(module), cost(cost) {}
+ ValueType module;
+ uint64_t cost;
+ bool operator<(const ValueCacheObject& other) const { return cost < other.cost; }
+};
+
+// A particular cache based on the GreedyDual-Size cache which is a generalization of LRU
+// which defines costs for each cache values.
+// The algorithm associates a cost, C, with each cache value. Initially, when the value
+// is brought into cache, C is set to be the cost related to the value (the cost is
+// always non-negative). When a replacement needs to be made, the value with the lowest C
+// cost is replaced, and then all values reduce their C costs by the minimum value of C
+// over all the values already in the cache.
+// If a value is accessed, its C value is restored to its initial cost. Thus, the C costs
+// of recently accessed values retain a larger portion of the original cost than those of
+// values that have not been accessed for a long time. The C costs are reduced as time
+// goes and are restored when accessed.
+
+template <class Key, class Value>
+class GreedyDualSizeCache {
+ // inner class to define the priority item
+ class PriorityItem {
+ public:
+ PriorityItem(uint64_t actual_priority, uint64_t original_priority, Key key)
+ : actual_priority(actual_priority),
+ original_priority(original_priority),
+ cache_key(key) {}
+ // this ensure that the items with low priority stays in the beginning of the queue,
+ // so it can be the one removed by evict operation
+ bool operator<(const PriorityItem& other) const {
+ return actual_priority < other.actual_priority;
+ }
+ uint64_t actual_priority;
+ uint64_t original_priority;
+ Key cache_key;
+ };
+
+ public:
+ struct hasher {
+ template <typename I>
+ std::size_t operator()(const I& i) const {
+ return i.Hash();
+ }
+ };
+ // a map from 'key' to a pair of Value and a pointer to the priority value
+ using map_type = std::unordered_map<
+ Key, std::pair<ValueCacheObject<Value>, typename std::set<PriorityItem>::iterator>,
+ hasher>;
+
+ explicit GreedyDualSizeCache(size_t capacity) : inflation_(0), capacity_(capacity) {}
+
+ ~GreedyDualSizeCache() = default;
+
+ size_t size() const { return map_.size(); }
+
+ size_t capacity() const { return capacity_; }
+
+ bool empty() const { return map_.empty(); }
+
+ bool contains(const Key& key) { return map_.find(key) != map_.end(); }
+
+ void insert(const Key& key, const ValueCacheObject<Value>& value) {
+ typename map_type::iterator i = map_.find(key);
+ // check if element is not in the cache to add it
+ if (i == map_.end()) {
+ // insert item into the cache, but first check if it is full, to evict an item
+ // if it is necessary
+ if (size() >= capacity_) {
+ evict();
+ }
+
+ // insert the new item
+ auto item =
+ priority_set_.insert(PriorityItem(value.cost + inflation_, value.cost, key));
+ // save on map the value and the priority item iterator position
+ map_.emplace(key, std::make_pair(value, item.first));
+ }
+ }
+
+ arrow::util::optional<ValueCacheObject<Value>> get(const Key& key) {
+ // lookup value in the cache
+ typename map_type::iterator value_for_key = map_.find(key);
+ if (value_for_key == map_.end()) {
+ // value not in cache
+ return arrow::util::nullopt;
+ }
+ PriorityItem item = *value_for_key->second.second;
+ // if the value was found on the cache, update its cost (original + inflation)
+ if (item.actual_priority != item.original_priority + inflation_) {
+ priority_set_.erase(value_for_key->second.second);
+ auto iter = priority_set_.insert(PriorityItem(
+ item.original_priority + inflation_, item.original_priority, item.cache_key));
+ value_for_key->second.second = iter.first;
+ }
+ return value_for_key->second.first;
+ }
+
+ void clear() {
+ map_.clear();
+ priority_set_.clear();
+ }
+
+ private:
+ void evict() {
+ // TODO: inflation overflow is unlikely to happen but needs to be handled
+ // for correctness.
+ // evict item from the beginning of the set. This set is ordered from the
+ // lower priority value to the higher priority value.
+ typename std::set<PriorityItem>::iterator i = priority_set_.begin();
+ // update the inflation cost related to the evicted item
+ inflation_ = (*i).actual_priority;
+ map_.erase((*i).cache_key);
+ priority_set_.erase(i);
+ }
+
+ map_type map_;
+ std::set<PriorityItem> priority_set_;
+ uint64_t inflation_;
+ size_t capacity_;
+};
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/greedy_dual_size_cache_test.cc b/src/arrow/cpp/src/gandiva/greedy_dual_size_cache_test.cc
new file mode 100644
index 000000000..3c72eef70
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/greedy_dual_size_cache_test.cc
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/greedy_dual_size_cache.h"
+
+#include <string>
+#include <typeinfo>
+
+#include <gtest/gtest.h>
+
+namespace gandiva {
+
+class GreedyDualSizeCacheKey {
+ public:
+ explicit GreedyDualSizeCacheKey(int tmp) : tmp_(tmp) {}
+ std::size_t Hash() const { return tmp_; }
+ bool operator==(const GreedyDualSizeCacheKey& other) const {
+ return tmp_ == other.tmp_;
+ }
+
+ private:
+ int tmp_;
+};
+
+class TestGreedyDualSizeCache : public ::testing::Test {
+ public:
+ TestGreedyDualSizeCache() : cache_(2) {}
+
+ protected:
+ GreedyDualSizeCache<GreedyDualSizeCacheKey, std::string> cache_;
+};
+
+TEST_F(TestGreedyDualSizeCache, TestEvict) {
+ // check if the cache is evicting the items with low priority on cache
+ cache_.insert(GreedyDualSizeCacheKey(1), ValueCacheObject<std::string>("1", 1));
+ cache_.insert(GreedyDualSizeCacheKey(2), ValueCacheObject<std::string>("2", 10));
+ cache_.insert(GreedyDualSizeCacheKey(3), ValueCacheObject<std::string>("3", 20));
+ cache_.insert(GreedyDualSizeCacheKey(4), ValueCacheObject<std::string>("4", 15));
+ cache_.insert(GreedyDualSizeCacheKey(1), ValueCacheObject<std::string>("5", 1));
+ ASSERT_EQ(2, cache_.size());
+ // we check initially the values that won't be on the cache, since the get operation
+ // may affect the entity costs, which is not the purpose of this test
+ ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(2)), arrow::util::nullopt);
+ ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(3)), arrow::util::nullopt);
+ ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(1))->module, "5");
+ ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(4))->module, "4");
+}
+
+TEST_F(TestGreedyDualSizeCache, TestGreedyDualSizeBehavior) {
+ // insert 1 and 3 evicting 2 (this eviction will increase the inflation cost by 20)
+ cache_.insert(GreedyDualSizeCacheKey(1), ValueCacheObject<std::string>("1", 40));
+ cache_.insert(GreedyDualSizeCacheKey(2), ValueCacheObject<std::string>("2", 20));
+ cache_.insert(GreedyDualSizeCacheKey(3), ValueCacheObject<std::string>("3", 30));
+
+ // when accessing key 3, its actual cost will be increased by the inflation, so in the
+ // next eviction, the key 1 will be evicted, since the key 1 actual cost (original(40))
+ // is smaller than key 3 actual increased cost (original(30) + inflation(20))
+ ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(3))->module, "3");
+
+ // try to insert key 2 and expect the eviction of key 1
+ cache_.insert(GreedyDualSizeCacheKey(2), ValueCacheObject<std::string>("2", 20));
+ ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(1)), arrow::util::nullopt);
+
+ // when accessing key 2, its original cost should be increased by inflation, so when
+ // inserting the key 1 again, now the key 3 should be evicted
+ ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(2))->module, "2");
+ cache_.insert(GreedyDualSizeCacheKey(1), ValueCacheObject<std::string>("1", 20));
+
+ ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(1))->module, "1");
+ ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(2))->module, "2");
+ ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(3)), arrow::util::nullopt);
+ ASSERT_EQ(2, cache_.size());
+}
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/hash_utils.cc b/src/arrow/cpp/src/gandiva/hash_utils.cc
new file mode 100644
index 000000000..8ebf60a9b
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/hash_utils.cc
@@ -0,0 +1,134 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/hash_utils.h"
+#include <cstring>
+#include "arrow/util/logging.h"
+#include "gandiva/gdv_function_stubs.h"
+#include "openssl/evp.h"
+
+namespace gandiva {
+/// Hashes a generic message using the SHA256 algorithm
+GANDIVA_EXPORT
+const char* gdv_hash_using_sha256(int64_t context, const void* message,
+ size_t message_length, int32_t* out_length) {
+ constexpr int sha256_result_length = 64;
+ return gdv_hash_using_sha(context, message, message_length, EVP_sha256(),
+ sha256_result_length, out_length);
+}
+
+/// Hashes a generic message using the SHA1 algorithm
+GANDIVA_EXPORT
+const char* gdv_hash_using_sha1(int64_t context, const void* message,
+ size_t message_length, int32_t* out_length) {
+ constexpr int sha1_result_length = 40;
+ return gdv_hash_using_sha(context, message, message_length, EVP_sha1(),
+ sha1_result_length, out_length);
+}
+
+/// \brief Hashes a generic message using SHA algorithm.
+///
+/// It uses the EVP API in the OpenSSL library to generate
+/// the hash. The type of the hash is defined by the
+/// \b hash_type \b parameter.
+GANDIVA_EXPORT
+const char* gdv_hash_using_sha(int64_t context, const void* message,
+ size_t message_length, const EVP_MD* hash_type,
+ uint32_t result_buf_size, int32_t* out_length) {
+ EVP_MD_CTX* md_ctx = EVP_MD_CTX_new();
+
+ if (md_ctx == nullptr) {
+ gdv_fn_context_set_error_msg(context,
+ "Could not create the context for SHA processing.");
+ *out_length = 0;
+ return "";
+ }
+
+ int evp_success_status = 1;
+
+ if (EVP_DigestInit_ex(md_ctx, hash_type, nullptr) != evp_success_status ||
+ EVP_DigestUpdate(md_ctx, message, message_length) != evp_success_status) {
+ gdv_fn_context_set_error_msg(context,
+ "Could not obtain the hash for the defined value.");
+ EVP_MD_CTX_free(md_ctx);
+
+ *out_length = 0;
+ return "";
+ }
+
+ // Create the temporary buffer used by the EVP to generate the hash
+ unsigned int hash_digest_size = EVP_MD_size(hash_type);
+ auto* result = static_cast<unsigned char*>(OPENSSL_malloc(hash_digest_size));
+
+ if (result == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for SHA processing");
+ EVP_MD_CTX_free(md_ctx);
+ *out_length = 0;
+ return "";
+ }
+
+ unsigned int result_length;
+ EVP_DigestFinal_ex(md_ctx, result, &result_length);
+
+ if (result_length != hash_digest_size && result_buf_size != (2 * hash_digest_size)) {
+ gdv_fn_context_set_error_msg(context,
+ "Could not obtain the hash for the defined value");
+ EVP_MD_CTX_free(md_ctx);
+ OPENSSL_free(result);
+
+ *out_length = 0;
+ return "";
+ }
+
+ auto result_buffer =
+ reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, result_buf_size));
+
+ if (result_buffer == nullptr) {
+ gdv_fn_context_set_error_msg(context,
+ "Could not allocate memory for the result buffer");
+ // Free the resources used by the EVP
+ EVP_MD_CTX_free(md_ctx);
+ OPENSSL_free(result);
+
+ *out_length = 0;
+ return "";
+ }
+
+ unsigned int result_buff_index = 0;
+ for (unsigned int j = 0; j < result_length; j++) {
+ DCHECK(result_buff_index >= 0 && result_buff_index < result_buf_size);
+
+ unsigned char hex_number = result[j];
+ result_buff_index +=
+ snprintf(result_buffer + result_buff_index, result_buf_size, "%02x", hex_number);
+ }
+
+ // Free the resources used by the EVP to avoid memory leaks
+ EVP_MD_CTX_free(md_ctx);
+ OPENSSL_free(result);
+
+ *out_length = result_buf_size;
+ return result_buffer;
+}
+
+GANDIVA_EXPORT
+uint64_t gdv_double_to_long(double value) {
+ uint64_t result;
+ memcpy(&result, &value, sizeof(result));
+ return result;
+}
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/hash_utils.h b/src/arrow/cpp/src/gandiva/hash_utils.h
new file mode 100644
index 000000000..483993f30
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/hash_utils.h
@@ -0,0 +1,44 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef ARROW_SRC_HASH_UTILS_H_
+#define ARROW_SRC_HASH_UTILS_H_
+
+#include <cstdint>
+#include <cstdlib>
+#include "gandiva/visibility.h"
+#include "openssl/evp.h"
+
+namespace gandiva {
+GANDIVA_EXPORT
+const char* gdv_hash_using_sha256(int64_t context, const void* message,
+ size_t message_length, int32_t* out_length);
+
+GANDIVA_EXPORT
+const char* gdv_hash_using_sha1(int64_t context, const void* message,
+ size_t message_length, int32_t* out_length);
+
+GANDIVA_EXPORT
+const char* gdv_hash_using_sha(int64_t context, const void* message,
+ size_t message_length, const EVP_MD* hash_type,
+ uint32_t result_buf_size, int32_t* out_length);
+
+GANDIVA_EXPORT
+uint64_t gdv_double_to_long(double value);
+} // namespace gandiva
+
+#endif // ARROW_SRC_HASH_UTILS_H_
diff --git a/src/arrow/cpp/src/gandiva/hash_utils_test.cc b/src/arrow/cpp/src/gandiva/hash_utils_test.cc
new file mode 100644
index 000000000..a8f55e1ed
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/hash_utils_test.cc
@@ -0,0 +1,164 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include <unordered_set>
+
+#include "gandiva/execution_context.h"
+#include "gandiva/hash_utils.h"
+
+TEST(TestShaHashUtils, TestSha1Numeric) {
+ gandiva::ExecutionContext ctx;
+
+ auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+
+ std::vector<uint64_t> values_to_be_hashed;
+
+ // Generate a list of values to obtains the SHA1 hash
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.0));
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.1));
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.2));
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(-0.10000001));
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(-0.0000001));
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(1.000000));
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(-0.0000002));
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.999999));
+
+ // Checks if the hash value is different for each one of the values
+ std::unordered_set<std::string> sha_values;
+
+ int sha1_size = 40;
+
+ for (auto value : values_to_be_hashed) {
+ int out_length;
+ const char* sha_1 =
+ gandiva::gdv_hash_using_sha1(ctx_ptr, &value, sizeof(value), &out_length);
+ std::string sha1_as_str(sha_1, out_length);
+ EXPECT_EQ(sha1_as_str.size(), sha1_size);
+
+ // The value can not exists inside the set with the hash results
+ EXPECT_EQ(sha_values.find(sha1_as_str), sha_values.end());
+ sha_values.insert(sha1_as_str);
+ }
+}
+
+TEST(TestShaHashUtils, TestSha256Numeric) {
+ gandiva::ExecutionContext ctx;
+
+ auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+
+ std::vector<uint64_t> values_to_be_hashed;
+
+ // Generate a list of values to obtains the SHA1 hash
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.0));
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.1));
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.2));
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(-0.10000001));
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(-0.0000001));
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(1.000000));
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(-0.0000002));
+ values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.999999));
+
+ // Checks if the hash value is different for each one of the values
+ std::unordered_set<std::string> sha_values;
+
+ int sha256_size = 64;
+
+ for (auto value : values_to_be_hashed) {
+ int out_length;
+ const char* sha_256 =
+ gandiva::gdv_hash_using_sha256(ctx_ptr, &value, sizeof(value), &out_length);
+ std::string sha256_as_str(sha_256, out_length);
+ EXPECT_EQ(sha256_as_str.size(), sha256_size);
+
+ // The value can not exists inside the set with the hash results
+ EXPECT_EQ(sha_values.find(sha256_as_str), sha_values.end());
+ sha_values.insert(sha256_as_str);
+ }
+}
+
+TEST(TestShaHashUtils, TestSha1Varlen) {
+ gandiva::ExecutionContext ctx;
+
+ auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+
+ std::string first_string =
+ "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeıʃn\nY [ˈʏpsilɔn], "
+ "Yen [jɛn], Yoga [ˈjoːgɑ]";
+
+ std::string second_string =
+ "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeın\nY [ˈʏpsilɔn], "
+ "Yen [jɛn], Yoga [ˈjoːgɑ] コンニチハ";
+
+ // The strings expected hashes are obtained from shell executing the following command:
+ // echo -n <output-string> | openssl dgst sha1
+ std::string expected_first_result = "160fcdbc2fa694d884868f5fae7a4bae82706185";
+ std::string expected_second_result = "a456b3e0f88669d2482170a42fade226a815bee1";
+
+ // Generate the hashes and compare with expected outputs
+ const int sha1_size = 40;
+ int out_length;
+
+ const char* sha_1 = gandiva::gdv_hash_using_sha1(ctx_ptr, first_string.c_str(),
+ first_string.size(), &out_length);
+ std::string sha1_as_str(sha_1, out_length);
+ EXPECT_EQ(sha1_as_str.size(), sha1_size);
+ EXPECT_EQ(sha1_as_str, expected_first_result);
+
+ const char* sha_2 = gandiva::gdv_hash_using_sha1(ctx_ptr, second_string.c_str(),
+ second_string.size(), &out_length);
+ std::string sha2_as_str(sha_2, out_length);
+ EXPECT_EQ(sha2_as_str.size(), sha1_size);
+ EXPECT_EQ(sha2_as_str, expected_second_result);
+}
+
+TEST(TestShaHashUtils, TestSha256Varlen) {
+ gandiva::ExecutionContext ctx;
+
+ auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+
+ std::string first_string =
+ "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeıʃn\nY [ˈʏpsilɔn], "
+ "Yen [jɛn], Yoga [ˈjoːgɑ]";
+
+ std::string second_string =
+ "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeın\nY [ˈʏpsilɔn], "
+ "Yen [jɛn], Yoga [ˈjoːgɑ] コンニチハ";
+
+ // The strings expected hashes are obtained from shell executing the following command:
+ // echo -n <output-string> | openssl dgst sha1
+ std::string expected_first_result =
+ "55aeb2e789871dbd289edae94d4c1c82a1c25ca0bcd5a873924da2fefdd57acb";
+ std::string expected_second_result =
+ "86b29c13d0d0e26ea8f85bfa649dc9b8622ae59a4da2409d7d9b463e86e796f2";
+
+ // Generate the hashes and compare with expected outputs
+ const int sha256_size = 64;
+ int out_length;
+
+ const char* sha_1 = gandiva::gdv_hash_using_sha256(ctx_ptr, first_string.c_str(),
+ first_string.size(), &out_length);
+ std::string sha1_as_str(sha_1, out_length);
+ EXPECT_EQ(sha1_as_str.size(), sha256_size);
+ EXPECT_EQ(sha1_as_str, expected_first_result);
+
+ const char* sha_2 = gandiva::gdv_hash_using_sha256(ctx_ptr, second_string.c_str(),
+ second_string.size(), &out_length);
+ std::string sha2_as_str(sha_2, out_length);
+ EXPECT_EQ(sha2_as_str.size(), sha256_size);
+ EXPECT_EQ(sha2_as_str, expected_second_result);
+}
diff --git a/src/arrow/cpp/src/gandiva/in_holder.h b/src/arrow/cpp/src/gandiva/in_holder.h
new file mode 100644
index 000000000..d55ab5ec5
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/in_holder.h
@@ -0,0 +1,91 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+#include <unordered_set>
+
+#include "arrow/util/hashing.h"
+#include "gandiva/arrow.h"
+#include "gandiva/decimal_scalar.h"
+#include "gandiva/gandiva_aliases.h"
+
+namespace gandiva {
+
+/// Function Holder for IN Expressions
+template <typename Type>
+class InHolder {
+ public:
+ explicit InHolder(const std::unordered_set<Type>& values) {
+ values_.max_load_factor(0.25f);
+ for (auto& value : values) {
+ values_.insert(value);
+ }
+ }
+
+ bool HasValue(Type value) const { return values_.count(value) == 1; }
+
+ private:
+ std::unordered_set<Type> values_;
+};
+
+template <>
+class InHolder<gandiva::DecimalScalar128> {
+ public:
+ explicit InHolder(const std::unordered_set<gandiva::DecimalScalar128>& values) {
+ values_.max_load_factor(0.25f);
+ for (auto& value : values) {
+ values_.insert(value);
+ }
+ }
+
+ bool HasValue(gandiva::DecimalScalar128 value) const {
+ return values_.count(value) == 1;
+ }
+
+ private:
+ std::unordered_set<gandiva::DecimalScalar128> values_;
+};
+
+template <>
+class InHolder<std::string> {
+ public:
+ explicit InHolder(std::unordered_set<std::string> values) : values_(std::move(values)) {
+ values_lookup_.max_load_factor(0.25f);
+ for (const std::string& value : values_) {
+ values_lookup_.emplace(value);
+ }
+ }
+
+ bool HasValue(arrow::util::string_view value) const {
+ return values_lookup_.count(value) == 1;
+ }
+
+ private:
+ struct string_view_hash {
+ public:
+ std::size_t operator()(arrow::util::string_view v) const {
+ return arrow::internal::ComputeStringHash<0>(v.data(), v.length());
+ }
+ };
+
+ std::unordered_set<arrow::util::string_view, string_view_hash> values_lookup_;
+ const std::unordered_set<std::string> values_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/jni/CMakeLists.txt b/src/arrow/cpp/src/gandiva/jni/CMakeLists.txt
new file mode 100644
index 000000000..046934141
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/jni/CMakeLists.txt
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(CMAKE_VERSION VERSION_LESS 3.11)
+ message(FATAL_ERROR "Building the Gandiva JNI bindings requires CMake version >= 3.11")
+endif()
+
+if(MSVC)
+ add_definitions(-DPROTOBUF_USE_DLLS)
+endif()
+
+# Find JNI
+find_package(JNI REQUIRED)
+
+set(PROTO_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR})
+set(PROTO_OUTPUT_FILES "${PROTO_OUTPUT_DIR}/Types.pb.cc")
+set(PROTO_OUTPUT_FILES ${PROTO_OUTPUT_FILES} "${PROTO_OUTPUT_DIR}/Types.pb.h")
+
+set_source_files_properties(${PROTO_OUTPUT_FILES} PROPERTIES GENERATED TRUE)
+
+get_filename_component(ABS_GANDIVA_PROTO
+ ${CMAKE_SOURCE_DIR}/src/gandiva/proto/Types.proto ABSOLUTE)
+
+add_custom_command(OUTPUT ${PROTO_OUTPUT_FILES}
+ COMMAND ${ARROW_PROTOBUF_PROTOC} --proto_path
+ ${CMAKE_SOURCE_DIR}/src/gandiva/proto --cpp_out
+ ${PROTO_OUTPUT_DIR}
+ ${CMAKE_SOURCE_DIR}/src/gandiva/proto/Types.proto
+ DEPENDS ${ABS_GANDIVA_PROTO} ${ARROW_PROTOBUF_LIBPROTOBUF}
+ COMMENT "Running PROTO compiler on Types.proto"
+ VERBATIM)
+
+add_custom_target(gandiva_jni_proto ALL DEPENDS ${PROTO_OUTPUT_FILES})
+set(PROTO_SRCS "${PROTO_OUTPUT_DIR}/Types.pb.cc")
+set(PROTO_HDRS "${PROTO_OUTPUT_DIR}/Types.pb.h")
+
+# Create the jni header file (from the java class).
+set(JNI_HEADERS_DIR "${CMAKE_CURRENT_BINARY_DIR}/java")
+add_subdirectory(../../../../java/gandiva ./java/gandiva)
+
+set(GANDIVA_LINK_LIBS ${ARROW_PROTOBUF_LIBPROTOBUF})
+if(ARROW_BUILD_STATIC)
+ list(APPEND GANDIVA_LINK_LIBS gandiva_static)
+else()
+ list(APPEND GANDIVA_LINK_LIBS gandiva_shared)
+endif()
+
+set(GANDIVA_JNI_SOURCES
+ config_builder.cc
+ config_holder.cc
+ expression_registry_helper.cc
+ jni_common.cc
+ ${PROTO_SRCS})
+
+# For users of gandiva_jni library (including integ tests), include-dir is :
+# /usr/**/include dir after install,
+# cpp/include during build
+# For building gandiva_jni library itself, include-dir (in addition to above) is :
+# cpp/src
+add_arrow_lib(gandiva_jni
+ SOURCES
+ ${GANDIVA_JNI_SOURCES}
+ OUTPUTS
+ GANDIVA_JNI_LIBRARIES
+ SHARED_PRIVATE_LINK_LIBS
+ ${GANDIVA_LINK_LIBS}
+ STATIC_LINK_LIBS
+ ${GANDIVA_LINK_LIBS}
+ DEPENDENCIES
+ ${GANDIVA_LINK_LIBS}
+ gandiva_java
+ gandiva_jni_headers
+ gandiva_jni_proto
+ EXTRA_INCLUDES
+ $<INSTALL_INTERFACE:include>
+ $<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/include>
+ $<BUILD_INTERFACE:${JNI_HEADERS_DIR}>
+ PRIVATE_INCLUDES
+ ${JNI_INCLUDE_DIRS}
+ ${CMAKE_CURRENT_BINARY_DIR})
+
+add_dependencies(gandiva ${GANDIVA_JNI_LIBRARIES})
+
+if(ARROW_BUILD_SHARED)
+ # filter out everything that is not needed for the jni bridge
+ # statically linked stdc++ has conflicts with stdc++ loaded by other libraries.
+ if(CXX_LINKER_SUPPORTS_VERSION_SCRIPT)
+ set_target_properties(gandiva_jni_shared
+ PROPERTIES LINK_FLAGS
+ "-Wl,--version-script=${CMAKE_SOURCE_DIR}/src/gandiva/jni/symbols.map"
+ )
+ endif()
+endif()
diff --git a/src/arrow/cpp/src/gandiva/jni/config_builder.cc b/src/arrow/cpp/src/gandiva/jni/config_builder.cc
new file mode 100644
index 000000000..b115210ce
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/jni/config_builder.cc
@@ -0,0 +1,53 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <string>
+
+#include "gandiva/configuration.h"
+#include "gandiva/jni/config_holder.h"
+#include "gandiva/jni/env_helper.h"
+#include "jni/org_apache_arrow_gandiva_evaluator_ConfigurationBuilder.h"
+
+using gandiva::ConfigHolder;
+using gandiva::Configuration;
+using gandiva::ConfigurationBuilder;
+
+/*
+ * Class: org_apache_arrow_gandiva_evaluator_ConfigBuilder
+ * Method: buildConfigInstance
+ * Signature: (ZZ)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_apache_arrow_gandiva_evaluator_ConfigurationBuilder_buildConfigInstance(
+ JNIEnv* env, jobject configuration, jboolean optimize, jboolean target_host_cpu) {
+ ConfigurationBuilder configuration_builder;
+ std::shared_ptr<Configuration> config = configuration_builder.build();
+ config->set_optimize(optimize);
+ config->target_host_cpu(target_host_cpu);
+ return ConfigHolder::MapInsert(config);
+}
+
+/*
+ * Class: org_apache_arrow_gandiva_evaluator_ConfigBuilder
+ * Method: releaseConfigInstance
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL
+Java_org_apache_arrow_gandiva_evaluator_ConfigurationBuilder_releaseConfigInstance(
+ JNIEnv* env, jobject configuration, jlong config_id) {
+ ConfigHolder::MapErase(config_id);
+}
diff --git a/src/arrow/cpp/src/gandiva/jni/config_holder.cc b/src/arrow/cpp/src/gandiva/jni/config_holder.cc
new file mode 100644
index 000000000..11d305c81
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/jni/config_holder.cc
@@ -0,0 +1,30 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/jni/config_holder.h"
+
+#include <cstdint>
+
+namespace gandiva {
+int64_t ConfigHolder::config_id_ = 1;
+
+// map of configuration objects created so far
+std::unordered_map<int64_t, std::shared_ptr<Configuration>>
+ ConfigHolder::configuration_map_;
+
+std::mutex ConfigHolder::g_mtx_;
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/jni/config_holder.h b/src/arrow/cpp/src/gandiva/jni/config_holder.h
new file mode 100644
index 000000000..3fdb7a01d
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/jni/config_holder.h
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+
+#include "gandiva/configuration.h"
+
+namespace gandiva {
+
+class ConfigHolder {
+ public:
+ static int64_t MapInsert(std::shared_ptr<Configuration> config) {
+ g_mtx_.lock();
+
+ int64_t result = config_id_++;
+ configuration_map_.insert(
+ std::pair<int64_t, std::shared_ptr<Configuration>>(result, config));
+
+ g_mtx_.unlock();
+ return result;
+ }
+
+ static void MapErase(int64_t config_id_) {
+ g_mtx_.lock();
+ configuration_map_.erase(config_id_);
+ g_mtx_.unlock();
+ }
+
+ static std::shared_ptr<Configuration> MapLookup(int64_t config_id_) {
+ std::shared_ptr<Configuration> result = nullptr;
+
+ try {
+ result = configuration_map_.at(config_id_);
+ } catch (const std::out_of_range&) {
+ }
+
+ return result;
+ }
+
+ private:
+ // map of configuration objects created so far
+ static std::unordered_map<int64_t, std::shared_ptr<Configuration>> configuration_map_;
+
+ static std::mutex g_mtx_;
+
+ // atomic counter for projector module ids
+ static int64_t config_id_;
+};
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/jni/env_helper.h b/src/arrow/cpp/src/gandiva/jni/env_helper.h
new file mode 100644
index 000000000..5ae13c807
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/jni/env_helper.h
@@ -0,0 +1,23 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <jni.h>
+
+// class references
+extern jclass configuration_builder_class_;
diff --git a/src/arrow/cpp/src/gandiva/jni/expression_registry_helper.cc b/src/arrow/cpp/src/gandiva/jni/expression_registry_helper.cc
new file mode 100644
index 000000000..0d1f74ba6
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/jni/expression_registry_helper.cc
@@ -0,0 +1,190 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "jni/org_apache_arrow_gandiva_evaluator_ExpressionRegistryJniHelper.h"
+
+#include <memory>
+
+#include "Types.pb.h"
+#include "arrow/util/logging.h"
+#include "gandiva/arrow.h"
+#include "gandiva/expression_registry.h"
+
+using gandiva::DataTypePtr;
+using gandiva::ExpressionRegistry;
+
+types::TimeUnit MapTimeUnit(arrow::TimeUnit::type& unit) {
+ switch (unit) {
+ case arrow::TimeUnit::MILLI:
+ return types::TimeUnit::MILLISEC;
+ case arrow::TimeUnit::SECOND:
+ return types::TimeUnit::SEC;
+ case arrow::TimeUnit::MICRO:
+ return types::TimeUnit::MICROSEC;
+ case arrow::TimeUnit::NANO:
+ return types::TimeUnit::NANOSEC;
+ }
+ // satisfy gcc. should be unreachable.
+ return types::TimeUnit::SEC;
+}
+
+void ArrowToProtobuf(DataTypePtr type, types::ExtGandivaType* gandiva_data_type) {
+ switch (type->id()) {
+ case arrow::Type::BOOL:
+ gandiva_data_type->set_type(types::GandivaType::BOOL);
+ break;
+ case arrow::Type::UINT8:
+ gandiva_data_type->set_type(types::GandivaType::UINT8);
+ break;
+ case arrow::Type::INT8:
+ gandiva_data_type->set_type(types::GandivaType::INT8);
+ break;
+ case arrow::Type::UINT16:
+ gandiva_data_type->set_type(types::GandivaType::UINT16);
+ break;
+ case arrow::Type::INT16:
+ gandiva_data_type->set_type(types::GandivaType::INT16);
+ break;
+ case arrow::Type::UINT32:
+ gandiva_data_type->set_type(types::GandivaType::UINT32);
+ break;
+ case arrow::Type::INT32:
+ gandiva_data_type->set_type(types::GandivaType::INT32);
+ break;
+ case arrow::Type::UINT64:
+ gandiva_data_type->set_type(types::GandivaType::UINT64);
+ break;
+ case arrow::Type::INT64:
+ gandiva_data_type->set_type(types::GandivaType::INT64);
+ break;
+ case arrow::Type::HALF_FLOAT:
+ gandiva_data_type->set_type(types::GandivaType::HALF_FLOAT);
+ break;
+ case arrow::Type::FLOAT:
+ gandiva_data_type->set_type(types::GandivaType::FLOAT);
+ break;
+ case arrow::Type::DOUBLE:
+ gandiva_data_type->set_type(types::GandivaType::DOUBLE);
+ break;
+ case arrow::Type::STRING:
+ gandiva_data_type->set_type(types::GandivaType::UTF8);
+ break;
+ case arrow::Type::BINARY:
+ gandiva_data_type->set_type(types::GandivaType::BINARY);
+ break;
+ case arrow::Type::DATE32:
+ gandiva_data_type->set_type(types::GandivaType::DATE32);
+ break;
+ case arrow::Type::DATE64:
+ gandiva_data_type->set_type(types::GandivaType::DATE64);
+ break;
+ case arrow::Type::TIMESTAMP: {
+ gandiva_data_type->set_type(types::GandivaType::TIMESTAMP);
+ std::shared_ptr<arrow::TimestampType> cast_time_stamp_type =
+ std::dynamic_pointer_cast<arrow::TimestampType>(type);
+ arrow::TimeUnit::type unit = cast_time_stamp_type->unit();
+ types::TimeUnit time_unit = MapTimeUnit(unit);
+ gandiva_data_type->set_timeunit(time_unit);
+ break;
+ }
+ case arrow::Type::TIME32: {
+ gandiva_data_type->set_type(types::GandivaType::TIME32);
+ std::shared_ptr<arrow::Time32Type> cast_time_32_type =
+ std::dynamic_pointer_cast<arrow::Time32Type>(type);
+ arrow::TimeUnit::type unit = cast_time_32_type->unit();
+ types::TimeUnit time_unit = MapTimeUnit(unit);
+ gandiva_data_type->set_timeunit(time_unit);
+ break;
+ }
+ case arrow::Type::TIME64: {
+ gandiva_data_type->set_type(types::GandivaType::TIME32);
+ std::shared_ptr<arrow::Time64Type> cast_time_64_type =
+ std::dynamic_pointer_cast<arrow::Time64Type>(type);
+ arrow::TimeUnit::type unit = cast_time_64_type->unit();
+ types::TimeUnit time_unit = MapTimeUnit(unit);
+ gandiva_data_type->set_timeunit(time_unit);
+ break;
+ }
+ case arrow::Type::NA:
+ gandiva_data_type->set_type(types::GandivaType::NONE);
+ break;
+ case arrow::Type::DECIMAL: {
+ gandiva_data_type->set_type(types::GandivaType::DECIMAL);
+ gandiva_data_type->set_precision(0);
+ gandiva_data_type->set_scale(0);
+ break;
+ }
+ case arrow::Type::INTERVAL_MONTHS:
+ gandiva_data_type->set_type(types::GandivaType::INTERVAL);
+ gandiva_data_type->set_intervaltype(types::IntervalType::YEAR_MONTH);
+ break;
+ case arrow::Type::INTERVAL_DAY_TIME:
+ gandiva_data_type->set_type(types::GandivaType::INTERVAL);
+ gandiva_data_type->set_intervaltype(types::IntervalType::DAY_TIME);
+ break;
+ default:
+ // un-supported types. test ensures that
+ // when one of these are added build breaks.
+ DCHECK(false);
+ }
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_org_apache_arrow_gandiva_evaluator_ExpressionRegistryJniHelper_getGandivaSupportedDataTypes( // NOLINT
+ JNIEnv* env, jobject types_helper) {
+ types::GandivaDataTypes gandiva_data_types;
+ auto supported_types = ExpressionRegistry::supported_types();
+ for (auto const& type : supported_types) {
+ types::ExtGandivaType* gandiva_data_type = gandiva_data_types.add_datatype();
+ ArrowToProtobuf(type, gandiva_data_type);
+ }
+ auto size = gandiva_data_types.ByteSizeLong();
+ std::unique_ptr<jbyte[]> buffer{new jbyte[size]};
+ gandiva_data_types.SerializeToArray(reinterpret_cast<void*>(buffer.get()), size);
+ jbyteArray ret = env->NewByteArray(size);
+ env->SetByteArrayRegion(ret, 0, size, buffer.get());
+ return ret;
+}
+
+/*
+ * Class: org_apache_arrow_gandiva_types_ExpressionRegistryJniHelper
+ * Method: getGandivaSupportedFunctions
+ * Signature: ()[B
+ */
+JNIEXPORT jbyteArray JNICALL
+Java_org_apache_arrow_gandiva_evaluator_ExpressionRegistryJniHelper_getGandivaSupportedFunctions( // NOLINT
+ JNIEnv* env, jobject types_helper) {
+ ExpressionRegistry expr_registry;
+ types::GandivaFunctions gandiva_functions;
+ for (auto function = expr_registry.function_signature_begin();
+ function != expr_registry.function_signature_end(); function++) {
+ types::FunctionSignature* function_signature = gandiva_functions.add_function();
+ function_signature->set_name((*function).base_name());
+ types::ExtGandivaType* return_type = function_signature->mutable_returntype();
+ ArrowToProtobuf((*function).ret_type(), return_type);
+ for (auto& param_type : (*function).param_types()) {
+ types::ExtGandivaType* proto_param_type = function_signature->add_paramtypes();
+ ArrowToProtobuf(param_type, proto_param_type);
+ }
+ }
+ auto size = gandiva_functions.ByteSizeLong();
+ std::unique_ptr<jbyte[]> buffer{new jbyte[size]};
+ gandiva_functions.SerializeToArray(reinterpret_cast<void*>(buffer.get()), size);
+ jbyteArray ret = env->NewByteArray(size);
+ env->SetByteArrayRegion(ret, 0, size, buffer.get());
+ return ret;
+}
diff --git a/src/arrow/cpp/src/gandiva/jni/id_to_module_map.h b/src/arrow/cpp/src/gandiva/jni/id_to_module_map.h
new file mode 100644
index 000000000..98100955b
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/jni/id_to_module_map.h
@@ -0,0 +1,66 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <unordered_map>
+#include <utility>
+
+namespace gandiva {
+
+template <typename HOLDER>
+class IdToModuleMap {
+ public:
+ IdToModuleMap() : module_id_(kInitModuleId) {}
+
+ jlong Insert(HOLDER holder) {
+ mtx_.lock();
+ jlong result = module_id_++;
+ map_.insert(std::pair<jlong, HOLDER>(result, holder));
+ mtx_.unlock();
+ return result;
+ }
+
+ void Erase(jlong module_id) {
+ mtx_.lock();
+ map_.erase(module_id);
+ mtx_.unlock();
+ }
+
+ HOLDER Lookup(jlong module_id) {
+ HOLDER result = nullptr;
+ mtx_.lock();
+ try {
+ result = map_.at(module_id);
+ } catch (const std::out_of_range&) {
+ }
+ mtx_.unlock();
+ return result;
+ }
+
+ private:
+ static const int kInitModuleId = 4;
+
+ int64_t module_id_;
+ std::mutex mtx_;
+ // map from module ids returned to Java and module pointers
+ std::unordered_map<jlong, HOLDER> map_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/jni/jni_common.cc b/src/arrow/cpp/src/gandiva/jni/jni_common.cc
new file mode 100644
index 000000000..5a4cbb031
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/jni/jni_common.cc
@@ -0,0 +1,1055 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <google/protobuf/io/coded_stream.h>
+
+#include <map>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <arrow/builder.h>
+#include <arrow/record_batch.h>
+#include <arrow/type.h>
+
+#include "Types.pb.h"
+#include "gandiva/configuration.h"
+#include "gandiva/decimal_scalar.h"
+#include "gandiva/filter.h"
+#include "gandiva/jni/config_holder.h"
+#include "gandiva/jni/env_helper.h"
+#include "gandiva/jni/id_to_module_map.h"
+#include "gandiva/jni/module_holder.h"
+#include "gandiva/projector.h"
+#include "gandiva/selection_vector.h"
+#include "gandiva/tree_expr_builder.h"
+#include "jni/org_apache_arrow_gandiva_evaluator_JniWrapper.h"
+
+using gandiva::ConditionPtr;
+using gandiva::DataTypePtr;
+using gandiva::ExpressionPtr;
+using gandiva::ExpressionVector;
+using gandiva::FieldPtr;
+using gandiva::FieldVector;
+using gandiva::Filter;
+using gandiva::NodePtr;
+using gandiva::NodeVector;
+using gandiva::Projector;
+using gandiva::SchemaPtr;
+using gandiva::Status;
+using gandiva::TreeExprBuilder;
+
+using gandiva::ArrayDataVector;
+using gandiva::ConfigHolder;
+using gandiva::Configuration;
+using gandiva::ConfigurationBuilder;
+using gandiva::FilterHolder;
+using gandiva::ProjectorHolder;
+
+// forward declarations
+NodePtr ProtoTypeToNode(const types::TreeNode& node);
+
+static jint JNI_VERSION = JNI_VERSION_1_6;
+
+// extern refs - initialized for other modules.
+jclass configuration_builder_class_;
+
+// refs for self.
+static jclass gandiva_exception_;
+static jclass vector_expander_class_;
+static jclass vector_expander_ret_class_;
+static jmethodID vector_expander_method_;
+static jfieldID vector_expander_ret_address_;
+static jfieldID vector_expander_ret_capacity_;
+
+// module maps
+gandiva::IdToModuleMap<std::shared_ptr<ProjectorHolder>> projector_modules_;
+gandiva::IdToModuleMap<std::shared_ptr<FilterHolder>> filter_modules_;
+
+jint JNI_OnLoad(JavaVM* vm, void* reserved) {
+ JNIEnv* env;
+ if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
+ return JNI_ERR;
+ }
+ jclass local_configuration_builder_class_ =
+ env->FindClass("org/apache/arrow/gandiva/evaluator/ConfigurationBuilder");
+ configuration_builder_class_ =
+ (jclass)env->NewGlobalRef(local_configuration_builder_class_);
+ env->DeleteLocalRef(local_configuration_builder_class_);
+
+ jclass localExceptionClass =
+ env->FindClass("org/apache/arrow/gandiva/exceptions/GandivaException");
+ gandiva_exception_ = (jclass)env->NewGlobalRef(localExceptionClass);
+ env->ExceptionDescribe();
+ env->DeleteLocalRef(localExceptionClass);
+
+ jclass local_expander_class =
+ env->FindClass("org/apache/arrow/gandiva/evaluator/VectorExpander");
+ vector_expander_class_ = (jclass)env->NewGlobalRef(local_expander_class);
+ env->DeleteLocalRef(local_expander_class);
+
+ vector_expander_method_ = env->GetMethodID(
+ vector_expander_class_, "expandOutputVectorAtIndex",
+ "(IJ)Lorg/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult;");
+
+ jclass local_expander_ret_class =
+ env->FindClass("org/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult");
+ vector_expander_ret_class_ = (jclass)env->NewGlobalRef(local_expander_ret_class);
+ env->DeleteLocalRef(local_expander_ret_class);
+
+ vector_expander_ret_address_ =
+ env->GetFieldID(vector_expander_ret_class_, "address", "J");
+ vector_expander_ret_capacity_ =
+ env->GetFieldID(vector_expander_ret_class_, "capacity", "J");
+ return JNI_VERSION;
+}
+
+void JNI_OnUnload(JavaVM* vm, void* reserved) {
+ JNIEnv* env;
+ vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION);
+ env->DeleteGlobalRef(configuration_builder_class_);
+ env->DeleteGlobalRef(gandiva_exception_);
+ env->DeleteGlobalRef(vector_expander_class_);
+ env->DeleteGlobalRef(vector_expander_ret_class_);
+}
+
+DataTypePtr ProtoTypeToTime32(const types::ExtGandivaType& ext_type) {
+ switch (ext_type.timeunit()) {
+ case types::SEC:
+ return arrow::time32(arrow::TimeUnit::SECOND);
+ case types::MILLISEC:
+ return arrow::time32(arrow::TimeUnit::MILLI);
+ default:
+ std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for time32\n";
+ return nullptr;
+ }
+}
+
+DataTypePtr ProtoTypeToTime64(const types::ExtGandivaType& ext_type) {
+ switch (ext_type.timeunit()) {
+ case types::MICROSEC:
+ return arrow::time64(arrow::TimeUnit::MICRO);
+ case types::NANOSEC:
+ return arrow::time64(arrow::TimeUnit::NANO);
+ default:
+ std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for time64\n";
+ return nullptr;
+ }
+}
+
+DataTypePtr ProtoTypeToTimestamp(const types::ExtGandivaType& ext_type) {
+ switch (ext_type.timeunit()) {
+ case types::SEC:
+ return arrow::timestamp(arrow::TimeUnit::SECOND);
+ case types::MILLISEC:
+ return arrow::timestamp(arrow::TimeUnit::MILLI);
+ case types::MICROSEC:
+ return arrow::timestamp(arrow::TimeUnit::MICRO);
+ case types::NANOSEC:
+ return arrow::timestamp(arrow::TimeUnit::NANO);
+ default:
+ std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for timestamp\n";
+ return nullptr;
+ }
+}
+
+DataTypePtr ProtoTypeToInterval(const types::ExtGandivaType& ext_type) {
+ switch (ext_type.intervaltype()) {
+ case types::YEAR_MONTH:
+ return arrow::month_interval();
+ case types::DAY_TIME:
+ return arrow::day_time_interval();
+ default:
+ std::cerr << "Unknown interval type: " << ext_type.intervaltype() << "\n";
+ return nullptr;
+ }
+}
+
+DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) {
+ switch (ext_type.type()) {
+ case types::NONE:
+ return arrow::null();
+ case types::BOOL:
+ return arrow::boolean();
+ case types::UINT8:
+ return arrow::uint8();
+ case types::INT8:
+ return arrow::int8();
+ case types::UINT16:
+ return arrow::uint16();
+ case types::INT16:
+ return arrow::int16();
+ case types::UINT32:
+ return arrow::uint32();
+ case types::INT32:
+ return arrow::int32();
+ case types::UINT64:
+ return arrow::uint64();
+ case types::INT64:
+ return arrow::int64();
+ case types::HALF_FLOAT:
+ return arrow::float16();
+ case types::FLOAT:
+ return arrow::float32();
+ case types::DOUBLE:
+ return arrow::float64();
+ case types::UTF8:
+ return arrow::utf8();
+ case types::BINARY:
+ return arrow::binary();
+ case types::DATE32:
+ return arrow::date32();
+ case types::DATE64:
+ return arrow::date64();
+ case types::DECIMAL:
+ // TODO: error handling
+ return arrow::decimal(ext_type.precision(), ext_type.scale());
+ case types::TIME32:
+ return ProtoTypeToTime32(ext_type);
+ case types::TIME64:
+ return ProtoTypeToTime64(ext_type);
+ case types::TIMESTAMP:
+ return ProtoTypeToTimestamp(ext_type);
+ case types::INTERVAL:
+ return ProtoTypeToInterval(ext_type);
+ case types::FIXED_SIZE_BINARY:
+ case types::LIST:
+ case types::STRUCT:
+ case types::UNION:
+ case types::DICTIONARY:
+ case types::MAP:
+ std::cerr << "Unhandled data type: " << ext_type.type() << "\n";
+ return nullptr;
+
+ default:
+ std::cerr << "Unknown data type: " << ext_type.type() << "\n";
+ return nullptr;
+ }
+}
+
+FieldPtr ProtoTypeToField(const types::Field& f) {
+ const std::string& name = f.name();
+ DataTypePtr type = ProtoTypeToDataType(f.type());
+ bool nullable = true;
+ if (f.has_nullable()) {
+ nullable = f.nullable();
+ }
+
+ return field(name, type, nullable);
+}
+
+NodePtr ProtoTypeToFieldNode(const types::FieldNode& node) {
+ FieldPtr field_ptr = ProtoTypeToField(node.field());
+ if (field_ptr == nullptr) {
+ std::cerr << "Unable to create field node from protobuf\n";
+ return nullptr;
+ }
+
+ return TreeExprBuilder::MakeField(field_ptr);
+}
+
+NodePtr ProtoTypeToFnNode(const types::FunctionNode& node) {
+ const std::string& name = node.functionname();
+ NodeVector children;
+
+ for (int i = 0; i < node.inargs_size(); i++) {
+ const types::TreeNode& arg = node.inargs(i);
+
+ NodePtr n = ProtoTypeToNode(arg);
+ if (n == nullptr) {
+ std::cerr << "Unable to create argument for function: " << name << "\n";
+ return nullptr;
+ }
+
+ children.push_back(n);
+ }
+
+ DataTypePtr return_type = ProtoTypeToDataType(node.returntype());
+ if (return_type == nullptr) {
+ std::cerr << "Unknown return type for function: " << name << "\n";
+ return nullptr;
+ }
+
+ return TreeExprBuilder::MakeFunction(name, children, return_type);
+}
+
+NodePtr ProtoTypeToIfNode(const types::IfNode& node) {
+ NodePtr cond = ProtoTypeToNode(node.cond());
+ if (cond == nullptr) {
+ std::cerr << "Unable to create cond node for if node\n";
+ return nullptr;
+ }
+
+ NodePtr then_node = ProtoTypeToNode(node.thennode());
+ if (then_node == nullptr) {
+ std::cerr << "Unable to create then node for if node\n";
+ return nullptr;
+ }
+
+ NodePtr else_node = ProtoTypeToNode(node.elsenode());
+ if (else_node == nullptr) {
+ std::cerr << "Unable to create else node for if node\n";
+ return nullptr;
+ }
+
+ DataTypePtr return_type = ProtoTypeToDataType(node.returntype());
+ if (return_type == nullptr) {
+ std::cerr << "Unknown return type for if node\n";
+ return nullptr;
+ }
+
+ return TreeExprBuilder::MakeIf(cond, then_node, else_node, return_type);
+}
+
+NodePtr ProtoTypeToAndNode(const types::AndNode& node) {
+ NodeVector children;
+
+ for (int i = 0; i < node.args_size(); i++) {
+ const types::TreeNode& arg = node.args(i);
+
+ NodePtr n = ProtoTypeToNode(arg);
+ if (n == nullptr) {
+ std::cerr << "Unable to create argument for boolean and\n";
+ return nullptr;
+ }
+ children.push_back(n);
+ }
+ return TreeExprBuilder::MakeAnd(children);
+}
+
+NodePtr ProtoTypeToOrNode(const types::OrNode& node) {
+ NodeVector children;
+
+ for (int i = 0; i < node.args_size(); i++) {
+ const types::TreeNode& arg = node.args(i);
+
+ NodePtr n = ProtoTypeToNode(arg);
+ if (n == nullptr) {
+ std::cerr << "Unable to create argument for boolean or\n";
+ return nullptr;
+ }
+ children.push_back(n);
+ }
+ return TreeExprBuilder::MakeOr(children);
+}
+
+NodePtr ProtoTypeToInNode(const types::InNode& node) {
+ NodePtr field = ProtoTypeToNode(node.node());
+
+ if (node.has_intvalues()) {
+ std::unordered_set<int32_t> int_values;
+ for (int i = 0; i < node.intvalues().intvalues_size(); i++) {
+ int_values.insert(node.intvalues().intvalues(i).value());
+ }
+ return TreeExprBuilder::MakeInExpressionInt32(field, int_values);
+ }
+
+ if (node.has_longvalues()) {
+ std::unordered_set<int64_t> long_values;
+ for (int i = 0; i < node.longvalues().longvalues_size(); i++) {
+ long_values.insert(node.longvalues().longvalues(i).value());
+ }
+ return TreeExprBuilder::MakeInExpressionInt64(field, long_values);
+ }
+
+ if (node.has_decimalvalues()) {
+ std::unordered_set<gandiva::DecimalScalar128> decimal_values;
+ for (int i = 0; i < node.decimalvalues().decimalvalues_size(); i++) {
+ decimal_values.insert(
+ gandiva::DecimalScalar128(node.decimalvalues().decimalvalues(i).value(),
+ node.decimalvalues().decimalvalues(i).precision(),
+ node.decimalvalues().decimalvalues(i).scale()));
+ }
+ return TreeExprBuilder::MakeInExpressionDecimal(field, decimal_values);
+ }
+
+ if (node.has_floatvalues()) {
+ std::unordered_set<float> float_values;
+ for (int i = 0; i < node.floatvalues().floatvalues_size(); i++) {
+ float_values.insert(node.floatvalues().floatvalues(i).value());
+ }
+ return TreeExprBuilder::MakeInExpressionFloat(field, float_values);
+ }
+
+ if (node.has_doublevalues()) {
+ std::unordered_set<double> double_values;
+ for (int i = 0; i < node.doublevalues().doublevalues_size(); i++) {
+ double_values.insert(node.doublevalues().doublevalues(i).value());
+ }
+ return TreeExprBuilder::MakeInExpressionDouble(field, double_values);
+ }
+
+ if (node.has_stringvalues()) {
+ std::unordered_set<std::string> stringvalues;
+ for (int i = 0; i < node.stringvalues().stringvalues_size(); i++) {
+ stringvalues.insert(node.stringvalues().stringvalues(i).value());
+ }
+ return TreeExprBuilder::MakeInExpressionString(field, stringvalues);
+ }
+
+ if (node.has_binaryvalues()) {
+ std::unordered_set<std::string> stringvalues;
+ for (int i = 0; i < node.binaryvalues().binaryvalues_size(); i++) {
+ stringvalues.insert(node.binaryvalues().binaryvalues(i).value());
+ }
+ return TreeExprBuilder::MakeInExpressionBinary(field, stringvalues);
+ }
+ // not supported yet.
+ std::cerr << "Unknown constant type for in expression.\n";
+ return nullptr;
+}
+
+NodePtr ProtoTypeToNullNode(const types::NullNode& node) {
+ DataTypePtr data_type = ProtoTypeToDataType(node.type());
+ if (data_type == nullptr) {
+ std::cerr << "Unknown type " << data_type->ToString() << " for null node\n";
+ return nullptr;
+ }
+
+ return TreeExprBuilder::MakeNull(data_type);
+}
+
+NodePtr ProtoTypeToNode(const types::TreeNode& node) {
+ if (node.has_fieldnode()) {
+ return ProtoTypeToFieldNode(node.fieldnode());
+ }
+
+ if (node.has_fnnode()) {
+ return ProtoTypeToFnNode(node.fnnode());
+ }
+
+ if (node.has_ifnode()) {
+ return ProtoTypeToIfNode(node.ifnode());
+ }
+
+ if (node.has_andnode()) {
+ return ProtoTypeToAndNode(node.andnode());
+ }
+
+ if (node.has_ornode()) {
+ return ProtoTypeToOrNode(node.ornode());
+ }
+
+ if (node.has_innode()) {
+ return ProtoTypeToInNode(node.innode());
+ }
+
+ if (node.has_nullnode()) {
+ return ProtoTypeToNullNode(node.nullnode());
+ }
+
+ if (node.has_intnode()) {
+ return TreeExprBuilder::MakeLiteral(node.intnode().value());
+ }
+
+ if (node.has_floatnode()) {
+ return TreeExprBuilder::MakeLiteral(node.floatnode().value());
+ }
+
+ if (node.has_longnode()) {
+ return TreeExprBuilder::MakeLiteral(node.longnode().value());
+ }
+
+ if (node.has_booleannode()) {
+ return TreeExprBuilder::MakeLiteral(node.booleannode().value());
+ }
+
+ if (node.has_doublenode()) {
+ return TreeExprBuilder::MakeLiteral(node.doublenode().value());
+ }
+
+ if (node.has_stringnode()) {
+ return TreeExprBuilder::MakeStringLiteral(node.stringnode().value());
+ }
+
+ if (node.has_binarynode()) {
+ return TreeExprBuilder::MakeBinaryLiteral(node.binarynode().value());
+ }
+
+ if (node.has_decimalnode()) {
+ std::string value = node.decimalnode().value();
+ gandiva::DecimalScalar128 literal(value, node.decimalnode().precision(),
+ node.decimalnode().scale());
+ return TreeExprBuilder::MakeDecimalLiteral(literal);
+ }
+ std::cerr << "Unknown node type in protobuf\n";
+ return nullptr;
+}
+
+ExpressionPtr ProtoTypeToExpression(const types::ExpressionRoot& root) {
+ NodePtr root_node = ProtoTypeToNode(root.root());
+ if (root_node == nullptr) {
+ std::cerr << "Unable to create expression node from expression protobuf\n";
+ return nullptr;
+ }
+
+ FieldPtr field = ProtoTypeToField(root.resulttype());
+ if (field == nullptr) {
+ std::cerr << "Unable to extra return field from expression protobuf\n";
+ return nullptr;
+ }
+
+ return TreeExprBuilder::MakeExpression(root_node, field);
+}
+
+ConditionPtr ProtoTypeToCondition(const types::Condition& condition) {
+ NodePtr root_node = ProtoTypeToNode(condition.root());
+ if (root_node == nullptr) {
+ return nullptr;
+ }
+
+ return TreeExprBuilder::MakeCondition(root_node);
+}
+
+SchemaPtr ProtoTypeToSchema(const types::Schema& schema) {
+ std::vector<FieldPtr> fields;
+
+ for (int i = 0; i < schema.columns_size(); i++) {
+ FieldPtr field = ProtoTypeToField(schema.columns(i));
+ if (field == nullptr) {
+ std::cerr << "Unable to extract arrow field from schema\n";
+ return nullptr;
+ }
+
+ fields.push_back(field);
+ }
+
+ return arrow::schema(fields);
+}
+
+// Common for both projector and filters.
+
+bool ParseProtobuf(uint8_t* buf, int bufLen, google::protobuf::Message* msg) {
+ google::protobuf::io::CodedInputStream cis(buf, bufLen);
+ cis.SetRecursionLimit(1000);
+ return msg->ParseFromCodedStream(&cis);
+}
+
+Status make_record_batch_with_buf_addrs(SchemaPtr schema, int num_rows,
+ jlong* in_buf_addrs, jlong* in_buf_sizes,
+ int in_bufs_len,
+ std::shared_ptr<arrow::RecordBatch>* batch) {
+ std::vector<std::shared_ptr<arrow::ArrayData>> columns;
+ auto num_fields = schema->num_fields();
+ int buf_idx = 0;
+ int sz_idx = 0;
+
+ for (int i = 0; i < num_fields; i++) {
+ auto field = schema->field(i);
+ std::vector<std::shared_ptr<arrow::Buffer>> buffers;
+
+ if (buf_idx >= in_bufs_len) {
+ return Status::Invalid("insufficient number of in_buf_addrs");
+ }
+ jlong validity_addr = in_buf_addrs[buf_idx++];
+ jlong validity_size = in_buf_sizes[sz_idx++];
+ auto validity = std::shared_ptr<arrow::Buffer>(
+ new arrow::Buffer(reinterpret_cast<uint8_t*>(validity_addr), validity_size));
+ buffers.push_back(validity);
+
+ if (buf_idx >= in_bufs_len) {
+ return Status::Invalid("insufficient number of in_buf_addrs");
+ }
+ jlong value_addr = in_buf_addrs[buf_idx++];
+ jlong value_size = in_buf_sizes[sz_idx++];
+ auto data = std::shared_ptr<arrow::Buffer>(
+ new arrow::Buffer(reinterpret_cast<uint8_t*>(value_addr), value_size));
+ buffers.push_back(data);
+
+ if (arrow::is_binary_like(field->type()->id())) {
+ if (buf_idx >= in_bufs_len) {
+ return Status::Invalid("insufficient number of in_buf_addrs");
+ }
+
+ // add offsets buffer for variable-len fields.
+ jlong offsets_addr = in_buf_addrs[buf_idx++];
+ jlong offsets_size = in_buf_sizes[sz_idx++];
+ auto offsets = std::shared_ptr<arrow::Buffer>(
+ new arrow::Buffer(reinterpret_cast<uint8_t*>(offsets_addr), offsets_size));
+ buffers.push_back(offsets);
+ }
+
+ auto array_data = arrow::ArrayData::Make(field->type(), num_rows, std::move(buffers));
+ columns.push_back(array_data);
+ }
+ *batch = arrow::RecordBatch::Make(schema, num_rows, columns);
+ return Status::OK();
+}
+
+// projector related functions.
+void releaseProjectorInput(jbyteArray schema_arr, jbyte* schema_bytes,
+ jbyteArray exprs_arr, jbyte* exprs_bytes, JNIEnv* env) {
+ env->ReleaseByteArrayElements(schema_arr, schema_bytes, JNI_ABORT);
+ env->ReleaseByteArrayElements(exprs_arr, exprs_bytes, JNI_ABORT);
+}
+
+JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildProjector(
+ JNIEnv* env, jobject obj, jbyteArray schema_arr, jbyteArray exprs_arr,
+ jint selection_vector_type, jlong configuration_id) {
+ jlong module_id = 0LL;
+ std::shared_ptr<Projector> projector;
+ std::shared_ptr<ProjectorHolder> holder;
+
+ types::Schema schema;
+ jsize schema_len = env->GetArrayLength(schema_arr);
+ jbyte* schema_bytes = env->GetByteArrayElements(schema_arr, 0);
+
+ types::ExpressionList exprs;
+ jsize exprs_len = env->GetArrayLength(exprs_arr);
+ jbyte* exprs_bytes = env->GetByteArrayElements(exprs_arr, 0);
+
+ ExpressionVector expr_vector;
+ SchemaPtr schema_ptr;
+ FieldVector ret_types;
+ gandiva::Status status;
+ auto mode = gandiva::SelectionVector::MODE_NONE;
+
+ std::shared_ptr<Configuration> config = ConfigHolder::MapLookup(configuration_id);
+ std::stringstream ss;
+
+ if (config == nullptr) {
+ ss << "configuration is mandatory.";
+ releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
+ goto err_out;
+ }
+
+ if (!ParseProtobuf(reinterpret_cast<uint8_t*>(schema_bytes), schema_len, &schema)) {
+ ss << "Unable to parse schema protobuf\n";
+ releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
+ goto err_out;
+ }
+
+ if (!ParseProtobuf(reinterpret_cast<uint8_t*>(exprs_bytes), exprs_len, &exprs)) {
+ releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
+ ss << "Unable to parse expressions protobuf\n";
+ goto err_out;
+ }
+
+ // convert types::Schema to arrow::Schema
+ schema_ptr = ProtoTypeToSchema(schema);
+ if (schema_ptr == nullptr) {
+ ss << "Unable to construct arrow schema object from schema protobuf\n";
+ releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
+ goto err_out;
+ }
+
+ // create Expression out of the list of exprs
+ for (int i = 0; i < exprs.exprs_size(); i++) {
+ ExpressionPtr root = ProtoTypeToExpression(exprs.exprs(i));
+
+ if (root == nullptr) {
+ ss << "Unable to construct expression object from expression protobuf\n";
+ releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
+ goto err_out;
+ }
+
+ expr_vector.push_back(root);
+ ret_types.push_back(root->result());
+ }
+
+ switch (selection_vector_type) {
+ case types::SV_NONE:
+ mode = gandiva::SelectionVector::MODE_NONE;
+ break;
+ case types::SV_INT16:
+ mode = gandiva::SelectionVector::MODE_UINT16;
+ break;
+ case types::SV_INT32:
+ mode = gandiva::SelectionVector::MODE_UINT32;
+ break;
+ }
+ // good to invoke the evaluator now
+ status = Projector::Make(schema_ptr, expr_vector, mode, config, &projector);
+
+ if (!status.ok()) {
+ ss << "Failed to make LLVM module due to " << status.message() << "\n";
+ releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
+ goto err_out;
+ }
+
+ // store the result in a map
+ holder = std::shared_ptr<ProjectorHolder>(
+ new ProjectorHolder(schema_ptr, ret_types, std::move(projector)));
+ module_id = projector_modules_.Insert(holder);
+ releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
+ return module_id;
+
+err_out:
+ env->ThrowNew(gandiva_exception_, ss.str().c_str());
+ return module_id;
+}
+
+///
+/// \brief Resizable buffer which resizes by doing a callback into java.
+///
+class JavaResizableBuffer : public arrow::ResizableBuffer {
+ public:
+ JavaResizableBuffer(JNIEnv* env, jobject jexpander, int32_t vector_idx, uint8_t* buffer,
+ int32_t len)
+ : ResizableBuffer(buffer, len),
+ env_(env),
+ jexpander_(jexpander),
+ vector_idx_(vector_idx) {
+ size_ = 0;
+ }
+
+ Status Resize(const int64_t new_size, bool shrink_to_fit) override;
+
+ Status Reserve(const int64_t new_capacity) override {
+ return Status::NotImplemented("reserve not implemented");
+ }
+
+ private:
+ JNIEnv* env_;
+ jobject jexpander_;
+ int32_t vector_idx_;
+};
+
+Status JavaResizableBuffer::Resize(const int64_t new_size, bool shrink_to_fit) {
+ if (shrink_to_fit == true) {
+ return Status::NotImplemented("shrink not implemented");
+ }
+
+ if (ARROW_PREDICT_TRUE(new_size < capacity())) {
+ // no need to expand.
+ size_ = new_size;
+ return Status::OK();
+ }
+
+ // callback into java to expand the buffer
+ jobject ret =
+ env_->CallObjectMethod(jexpander_, vector_expander_method_, vector_idx_, new_size);
+ if (env_->ExceptionCheck()) {
+ env_->ExceptionDescribe();
+ env_->ExceptionClear();
+ return Status::OutOfMemory("buffer expand failed in java");
+ }
+
+ jlong ret_address = env_->GetLongField(ret, vector_expander_ret_address_);
+ jlong ret_capacity = env_->GetLongField(ret, vector_expander_ret_capacity_);
+ DCHECK_GE(ret_capacity, new_size);
+
+ data_ = reinterpret_cast<uint8_t*>(ret_address);
+ size_ = new_size;
+ capacity_ = ret_capacity;
+ return Status::OK();
+}
+
+#define CHECK_OUT_BUFFER_IDX_AND_BREAK(idx, len) \
+ if (idx >= len) { \
+ status = gandiva::Status::Invalid("insufficient number of out_buf_addrs"); \
+ break; \
+ }
+
+JNIEXPORT void JNICALL
+Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector(
+ JNIEnv* env, jobject object, jobject jexpander, jlong module_id, jint num_rows,
+ jlongArray buf_addrs, jlongArray buf_sizes, jint sel_vec_type, jint sel_vec_rows,
+ jlong sel_vec_addr, jlong sel_vec_size, jlongArray out_buf_addrs,
+ jlongArray out_buf_sizes) {
+ Status status;
+ std::shared_ptr<ProjectorHolder> holder = projector_modules_.Lookup(module_id);
+ if (holder == nullptr) {
+ std::stringstream ss;
+ ss << "Unknown module id " << module_id;
+ env->ThrowNew(gandiva_exception_, ss.str().c_str());
+ return;
+ }
+
+ int in_bufs_len = env->GetArrayLength(buf_addrs);
+ if (in_bufs_len != env->GetArrayLength(buf_sizes)) {
+ env->ThrowNew(gandiva_exception_, "mismatch in arraylen of buf_addrs and buf_sizes");
+ return;
+ }
+
+ int out_bufs_len = env->GetArrayLength(out_buf_addrs);
+ if (out_bufs_len != env->GetArrayLength(out_buf_sizes)) {
+ env->ThrowNew(gandiva_exception_,
+ "mismatch in arraylen of out_buf_addrs and out_buf_sizes");
+ return;
+ }
+
+ jlong* in_buf_addrs = env->GetLongArrayElements(buf_addrs, 0);
+ jlong* in_buf_sizes = env->GetLongArrayElements(buf_sizes, 0);
+
+ jlong* out_bufs = env->GetLongArrayElements(out_buf_addrs, 0);
+ jlong* out_sizes = env->GetLongArrayElements(out_buf_sizes, 0);
+
+ do {
+ std::shared_ptr<arrow::RecordBatch> in_batch;
+ status = make_record_batch_with_buf_addrs(holder->schema(), num_rows, in_buf_addrs,
+ in_buf_sizes, in_bufs_len, &in_batch);
+ if (!status.ok()) {
+ break;
+ }
+
+ std::shared_ptr<gandiva::SelectionVector> selection_vector;
+ auto selection_buffer = std::make_shared<arrow::Buffer>(
+ reinterpret_cast<uint8_t*>(sel_vec_addr), sel_vec_size);
+ int output_row_count = 0;
+ switch (sel_vec_type) {
+ case types::SV_NONE: {
+ output_row_count = num_rows;
+ break;
+ }
+ case types::SV_INT16: {
+ status = gandiva::SelectionVector::MakeImmutableInt16(
+ sel_vec_rows, selection_buffer, &selection_vector);
+ output_row_count = sel_vec_rows;
+ break;
+ }
+ case types::SV_INT32: {
+ status = gandiva::SelectionVector::MakeImmutableInt32(
+ sel_vec_rows, selection_buffer, &selection_vector);
+ output_row_count = sel_vec_rows;
+ break;
+ }
+ }
+ if (!status.ok()) {
+ break;
+ }
+
+ auto ret_types = holder->rettypes();
+ ArrayDataVector output;
+ int buf_idx = 0;
+ int sz_idx = 0;
+ int output_vector_idx = 0;
+ for (FieldPtr field : ret_types) {
+ std::vector<std::shared_ptr<arrow::Buffer>> buffers;
+
+ CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len);
+ uint8_t* validity_buf = reinterpret_cast<uint8_t*>(out_bufs[buf_idx++]);
+ jlong bitmap_sz = out_sizes[sz_idx++];
+ buffers.push_back(std::make_shared<arrow::MutableBuffer>(validity_buf, bitmap_sz));
+
+ if (arrow::is_binary_like(field->type()->id())) {
+ CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len);
+ uint8_t* offsets_buf = reinterpret_cast<uint8_t*>(out_bufs[buf_idx++]);
+ jlong offsets_sz = out_sizes[sz_idx++];
+ buffers.push_back(
+ std::make_shared<arrow::MutableBuffer>(offsets_buf, offsets_sz));
+ }
+
+ CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len);
+ uint8_t* value_buf = reinterpret_cast<uint8_t*>(out_bufs[buf_idx++]);
+ jlong data_sz = out_sizes[sz_idx++];
+ if (arrow::is_binary_like(field->type()->id())) {
+ if (jexpander == nullptr) {
+ status = Status::Invalid(
+ "expression has variable len output columns, but the expander object is "
+ "null");
+ break;
+ }
+ buffers.push_back(std::make_shared<JavaResizableBuffer>(
+ env, jexpander, output_vector_idx, value_buf, data_sz));
+ } else {
+ buffers.push_back(std::make_shared<arrow::MutableBuffer>(value_buf, data_sz));
+ }
+
+ auto array_data = arrow::ArrayData::Make(field->type(), output_row_count, buffers);
+ output.push_back(array_data);
+ ++output_vector_idx;
+ }
+ if (!status.ok()) {
+ break;
+ }
+ status = holder->projector()->Evaluate(*in_batch, selection_vector.get(), output);
+ } while (0);
+
+ env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT);
+ env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT);
+ env->ReleaseLongArrayElements(out_buf_addrs, out_bufs, JNI_ABORT);
+ env->ReleaseLongArrayElements(out_buf_sizes, out_sizes, JNI_ABORT);
+
+ if (!status.ok()) {
+ std::stringstream ss;
+ ss << "Evaluate returned " << status.message() << "\n";
+ env->ThrowNew(gandiva_exception_, status.message().c_str());
+ return;
+ }
+}
+
+JNIEXPORT void JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_closeProjector(
+ JNIEnv* env, jobject cls, jlong module_id) {
+ projector_modules_.Erase(module_id);
+}
+
+// filter related functions.
+void releaseFilterInput(jbyteArray schema_arr, jbyte* schema_bytes,
+ jbyteArray condition_arr, jbyte* condition_bytes, JNIEnv* env) {
+ env->ReleaseByteArrayElements(schema_arr, schema_bytes, JNI_ABORT);
+ env->ReleaseByteArrayElements(condition_arr, condition_bytes, JNI_ABORT);
+}
+
+JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildFilter(
+ JNIEnv* env, jobject obj, jbyteArray schema_arr, jbyteArray condition_arr,
+ jlong configuration_id) {
+ jlong module_id = 0LL;
+ std::shared_ptr<Filter> filter;
+ std::shared_ptr<FilterHolder> holder;
+
+ types::Schema schema;
+ jsize schema_len = env->GetArrayLength(schema_arr);
+ jbyte* schema_bytes = env->GetByteArrayElements(schema_arr, 0);
+
+ types::Condition condition;
+ jsize condition_len = env->GetArrayLength(condition_arr);
+ jbyte* condition_bytes = env->GetByteArrayElements(condition_arr, 0);
+
+ ConditionPtr condition_ptr;
+ SchemaPtr schema_ptr;
+ gandiva::Status status;
+
+ std::shared_ptr<Configuration> config = ConfigHolder::MapLookup(configuration_id);
+ std::stringstream ss;
+
+ if (config == nullptr) {
+ ss << "configuration is mandatory.";
+ releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+ goto err_out;
+ }
+
+ if (!ParseProtobuf(reinterpret_cast<uint8_t*>(schema_bytes), schema_len, &schema)) {
+ ss << "Unable to parse schema protobuf\n";
+ releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+ goto err_out;
+ }
+
+ if (!ParseProtobuf(reinterpret_cast<uint8_t*>(condition_bytes), condition_len,
+ &condition)) {
+ ss << "Unable to parse condition protobuf\n";
+ releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+ goto err_out;
+ }
+
+ // convert types::Schema to arrow::Schema
+ schema_ptr = ProtoTypeToSchema(schema);
+ if (schema_ptr == nullptr) {
+ ss << "Unable to construct arrow schema object from schema protobuf\n";
+ releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+ goto err_out;
+ }
+
+ condition_ptr = ProtoTypeToCondition(condition);
+ if (condition_ptr == nullptr) {
+ ss << "Unable to construct condition object from condition protobuf\n";
+ releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+ goto err_out;
+ }
+
+ // good to invoke the filter builder now
+ status = Filter::Make(schema_ptr, condition_ptr, config, &filter);
+ if (!status.ok()) {
+ ss << "Failed to make LLVM module due to " << status.message() << "\n";
+ releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+ goto err_out;
+ }
+
+ // store the result in a map
+ holder = std::shared_ptr<FilterHolder>(new FilterHolder(schema_ptr, std::move(filter)));
+ module_id = filter_modules_.Insert(holder);
+ releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+ return module_id;
+
+err_out:
+ env->ThrowNew(gandiva_exception_, ss.str().c_str());
+ return module_id;
+}
+
+JNIEXPORT jint JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateFilter(
+ JNIEnv* env, jobject cls, jlong module_id, jint num_rows, jlongArray buf_addrs,
+ jlongArray buf_sizes, jint jselection_vector_type, jlong out_buf_addr,
+ jlong out_buf_size) {
+ gandiva::Status status;
+ std::shared_ptr<FilterHolder> holder = filter_modules_.Lookup(module_id);
+ if (holder == nullptr) {
+ env->ThrowNew(gandiva_exception_, "Unknown module id\n");
+ return -1;
+ }
+
+ int in_bufs_len = env->GetArrayLength(buf_addrs);
+ if (in_bufs_len != env->GetArrayLength(buf_sizes)) {
+ env->ThrowNew(gandiva_exception_, "mismatch in arraylen of buf_addrs and buf_sizes");
+ return -1;
+ }
+
+ jlong* in_buf_addrs = env->GetLongArrayElements(buf_addrs, 0);
+ jlong* in_buf_sizes = env->GetLongArrayElements(buf_sizes, 0);
+ std::shared_ptr<gandiva::SelectionVector> selection_vector;
+
+ do {
+ std::shared_ptr<arrow::RecordBatch> in_batch;
+
+ status = make_record_batch_with_buf_addrs(holder->schema(), num_rows, in_buf_addrs,
+ in_buf_sizes, in_bufs_len, &in_batch);
+ if (!status.ok()) {
+ break;
+ }
+
+ auto selection_vector_type =
+ static_cast<types::SelectionVectorType>(jselection_vector_type);
+ auto out_buffer = std::make_shared<arrow::MutableBuffer>(
+ reinterpret_cast<uint8_t*>(out_buf_addr), out_buf_size);
+ switch (selection_vector_type) {
+ case types::SV_INT16:
+ status =
+ gandiva::SelectionVector::MakeInt16(num_rows, out_buffer, &selection_vector);
+ break;
+ case types::SV_INT32:
+ status =
+ gandiva::SelectionVector::MakeInt32(num_rows, out_buffer, &selection_vector);
+ break;
+ default:
+ status = gandiva::Status::Invalid("unknown selection vector type");
+ }
+ if (!status.ok()) {
+ break;
+ }
+
+ status = holder->filter()->Evaluate(*in_batch, selection_vector);
+ } while (0);
+
+ env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT);
+ env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT);
+
+ if (!status.ok()) {
+ std::stringstream ss;
+ ss << "Evaluate returned " << status.message() << "\n";
+ env->ThrowNew(gandiva_exception_, status.message().c_str());
+ return -1;
+ } else {
+ int64_t num_slots = selection_vector->GetNumSlots();
+ // Check integer overflow
+ if (num_slots > INT_MAX) {
+ std::stringstream ss;
+ ss << "The selection vector has " << num_slots
+ << " slots, which is larger than the " << INT_MAX << " limit.\n";
+ const std::string message = ss.str();
+ env->ThrowNew(gandiva_exception_, message.c_str());
+ return -1;
+ }
+ return static_cast<int>(num_slots);
+ }
+}
+
+JNIEXPORT void JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_closeFilter(
+ JNIEnv* env, jobject cls, jlong module_id) {
+ filter_modules_.Erase(module_id);
+}
diff --git a/src/arrow/cpp/src/gandiva/jni/module_holder.h b/src/arrow/cpp/src/gandiva/jni/module_holder.h
new file mode 100644
index 000000000..929c64231
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/jni/module_holder.h
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <utility>
+
+#include "gandiva/arrow.h"
+
+namespace gandiva {
+
+class Projector;
+class Filter;
+
+class ProjectorHolder {
+ public:
+ ProjectorHolder(SchemaPtr schema, FieldVector ret_types,
+ std::shared_ptr<Projector> projector)
+ : schema_(schema), ret_types_(ret_types), projector_(std::move(projector)) {}
+
+ SchemaPtr schema() { return schema_; }
+ FieldVector rettypes() { return ret_types_; }
+ std::shared_ptr<Projector> projector() { return projector_; }
+
+ private:
+ SchemaPtr schema_;
+ FieldVector ret_types_;
+ std::shared_ptr<Projector> projector_;
+};
+
+class FilterHolder {
+ public:
+ FilterHolder(SchemaPtr schema, std::shared_ptr<Filter> filter)
+ : schema_(schema), filter_(std::move(filter)) {}
+
+ SchemaPtr schema() { return schema_; }
+ std::shared_ptr<Filter> filter() { return filter_; }
+
+ private:
+ SchemaPtr schema_;
+ std::shared_ptr<Filter> filter_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/jni/symbols.map b/src/arrow/cpp/src/gandiva/jni/symbols.map
new file mode 100644
index 000000000..e0f5def41
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/jni/symbols.map
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+{
+ global: extern "C++" { gandiva*; }; Java*; JNI*;
+ local: *;
+};
diff --git a/src/arrow/cpp/src/gandiva/like_holder.cc b/src/arrow/cpp/src/gandiva/like_holder.cc
new file mode 100644
index 000000000..af9ac67d6
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/like_holder.cc
@@ -0,0 +1,156 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/like_holder.h"
+
+#include <regex>
+#include "gandiva/node.h"
+#include "gandiva/regex_util.h"
+
+namespace gandiva {
+
+RE2 LikeHolder::starts_with_regex_(R"((\w|\s)*\.\*)");
+RE2 LikeHolder::ends_with_regex_(R"(\.\*(\w|\s)*)");
+RE2 LikeHolder::is_substr_regex_(R"(\.\*(\w|\s)*\.\*)");
+
+// Short-circuit pattern matches for the following common sub cases :
+// - starts_with, ends_with and is_substr
+const FunctionNode LikeHolder::TryOptimize(const FunctionNode& node) {
+ std::shared_ptr<LikeHolder> holder;
+ auto status = Make(node, &holder);
+ if (status.ok()) {
+ std::string& pattern = holder->pattern_;
+ auto literal_type = node.children().at(1)->return_type();
+
+ if (RE2::FullMatch(pattern, starts_with_regex_)) {
+ auto prefix = pattern.substr(0, pattern.length() - 2); // trim .*
+ auto prefix_node =
+ std::make_shared<LiteralNode>(literal_type, LiteralHolder(prefix), false);
+ return FunctionNode("starts_with", {node.children().at(0), prefix_node},
+ node.return_type());
+ } else if (RE2::FullMatch(pattern, ends_with_regex_)) {
+ auto suffix = pattern.substr(2); // skip .*
+ auto suffix_node =
+ std::make_shared<LiteralNode>(literal_type, LiteralHolder(suffix), false);
+ return FunctionNode("ends_with", {node.children().at(0), suffix_node},
+ node.return_type());
+ } else if (RE2::FullMatch(pattern, is_substr_regex_)) {
+ auto substr =
+ pattern.substr(2, pattern.length() - 4); // trim starting and ending .*
+ auto substr_node =
+ std::make_shared<LiteralNode>(literal_type, LiteralHolder(substr), false);
+ return FunctionNode("is_substr", {node.children().at(0), substr_node},
+ node.return_type());
+ }
+ }
+
+ // Could not optimize, return original node.
+ return node;
+}
+
+static bool IsArrowStringLiteral(arrow::Type::type type) {
+ return type == arrow::Type::STRING || type == arrow::Type::BINARY;
+}
+
+Status LikeHolder::Make(const FunctionNode& node, std::shared_ptr<LikeHolder>* holder) {
+ ARROW_RETURN_IF(node.children().size() != 2 && node.children().size() != 3,
+ Status::Invalid("'like' function requires two or three parameters"));
+
+ auto literal = dynamic_cast<LiteralNode*>(node.children().at(1).get());
+ ARROW_RETURN_IF(
+ literal == nullptr,
+ Status::Invalid("'like' function requires a literal as the second parameter"));
+
+ auto literal_type = literal->return_type()->id();
+ ARROW_RETURN_IF(
+ !IsArrowStringLiteral(literal_type),
+ Status::Invalid(
+ "'like' function requires a string literal as the second parameter"));
+
+ RE2::Options regex_op;
+ if (node.descriptor()->name() == "ilike") {
+ regex_op.set_case_sensitive(false); // set case-insensitive for ilike function.
+
+ return Make(arrow::util::get<std::string>(literal->holder()), holder, regex_op);
+ }
+ if (node.children().size() == 2) {
+ return Make(arrow::util::get<std::string>(literal->holder()), holder);
+ } else {
+ auto escape_char = dynamic_cast<LiteralNode*>(node.children().at(2).get());
+ ARROW_RETURN_IF(
+ escape_char == nullptr,
+ Status::Invalid("'like' function requires a literal as the third parameter"));
+
+ auto escape_char_type = escape_char->return_type()->id();
+ ARROW_RETURN_IF(
+ !IsArrowStringLiteral(escape_char_type),
+ Status::Invalid(
+ "'like' function requires a string literal as the third parameter"));
+ return Make(arrow::util::get<std::string>(literal->holder()),
+ arrow::util::get<std::string>(escape_char->holder()), holder);
+ }
+}
+
+Status LikeHolder::Make(const std::string& sql_pattern,
+ std::shared_ptr<LikeHolder>* holder) {
+ std::string pcre_pattern;
+ ARROW_RETURN_NOT_OK(RegexUtil::SqlLikePatternToPcre(sql_pattern, pcre_pattern));
+
+ auto lholder = std::shared_ptr<LikeHolder>(new LikeHolder(pcre_pattern));
+ ARROW_RETURN_IF(!lholder->regex_.ok(),
+ Status::Invalid("Building RE2 pattern '", pcre_pattern, "' failed"));
+
+ *holder = lholder;
+ return Status::OK();
+}
+
+Status LikeHolder::Make(const std::string& sql_pattern, const std::string& escape_char,
+ std::shared_ptr<LikeHolder>* holder) {
+ ARROW_RETURN_IF(escape_char.length() > 1,
+ Status::Invalid("The length of escape char ", escape_char,
+ " in 'like' function is greater than 1"));
+ std::string pcre_pattern;
+ if (escape_char.length() == 1) {
+ ARROW_RETURN_NOT_OK(
+ RegexUtil::SqlLikePatternToPcre(sql_pattern, escape_char.at(0), pcre_pattern));
+ } else {
+ ARROW_RETURN_NOT_OK(RegexUtil::SqlLikePatternToPcre(sql_pattern, pcre_pattern));
+ }
+
+ auto lholder = std::shared_ptr<LikeHolder>(new LikeHolder(pcre_pattern));
+ ARROW_RETURN_IF(!lholder->regex_.ok(),
+ Status::Invalid("Building RE2 pattern '", pcre_pattern, "' failed"));
+
+ *holder = lholder;
+ return Status::OK();
+}
+
+Status LikeHolder::Make(const std::string& sql_pattern,
+ std::shared_ptr<LikeHolder>* holder, RE2::Options regex_op) {
+ std::string pcre_pattern;
+ ARROW_RETURN_NOT_OK(RegexUtil::SqlLikePatternToPcre(sql_pattern, pcre_pattern));
+
+ std::shared_ptr<LikeHolder> lholder;
+ lholder = std::shared_ptr<LikeHolder>(new LikeHolder(pcre_pattern, regex_op));
+
+ ARROW_RETURN_IF(!lholder->regex_.ok(),
+ Status::Invalid("Building RE2 pattern '", pcre_pattern, "' failed"));
+
+ *holder = lholder;
+ return Status::OK();
+}
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/like_holder.h b/src/arrow/cpp/src/gandiva/like_holder.h
new file mode 100644
index 000000000..73e58017d
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/like_holder.h
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include <re2/re2.h>
+
+#include "arrow/status.h"
+
+#include "gandiva/function_holder.h"
+#include "gandiva/node.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// Function Holder for SQL 'like'
+class GANDIVA_EXPORT LikeHolder : public FunctionHolder {
+ public:
+ ~LikeHolder() override = default;
+
+ static Status Make(const FunctionNode& node, std::shared_ptr<LikeHolder>* holder);
+
+ static Status Make(const std::string& sql_pattern, std::shared_ptr<LikeHolder>* holder);
+
+ static Status Make(const std::string& sql_pattern, const std::string& escape_char,
+ std::shared_ptr<LikeHolder>* holder);
+
+ static Status Make(const std::string& sql_pattern, std::shared_ptr<LikeHolder>* holder,
+ RE2::Options regex_op);
+
+ // Try and optimise a function node with a "like" pattern.
+ static const FunctionNode TryOptimize(const FunctionNode& node);
+
+ /// Return true if the data matches the pattern.
+ bool operator()(const std::string& data) { return RE2::FullMatch(data, regex_); }
+
+ private:
+ explicit LikeHolder(const std::string& pattern) : pattern_(pattern), regex_(pattern) {}
+
+ LikeHolder(const std::string& pattern, RE2::Options regex_op)
+ : pattern_(pattern), regex_(pattern, regex_op) {}
+
+ std::string pattern_; // posix pattern string, to help debugging
+ RE2 regex_; // compiled regex for the pattern
+
+ static RE2 starts_with_regex_; // pre-compiled pattern for matching starts_with
+ static RE2 ends_with_regex_; // pre-compiled pattern for matching ends_with
+ static RE2 is_substr_regex_; // pre-compiled pattern for matching is_substr
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/like_holder_test.cc b/src/arrow/cpp/src/gandiva/like_holder_test.cc
new file mode 100644
index 000000000..a52533a11
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/like_holder_test.cc
@@ -0,0 +1,281 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/like_holder.h"
+#include "gandiva/regex_util.h"
+
+#include <memory>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+namespace gandiva {
+
+class TestLikeHolder : public ::testing::Test {
+ public:
+ RE2::Options regex_op;
+ FunctionNode BuildLike(std::string pattern) {
+ auto field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8()));
+ auto pattern_node =
+ std::make_shared<LiteralNode>(arrow::utf8(), LiteralHolder(pattern), false);
+ return FunctionNode("like", {field, pattern_node}, arrow::boolean());
+ }
+
+ FunctionNode BuildLike(std::string pattern, char escape_char) {
+ auto field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8()));
+ auto pattern_node =
+ std::make_shared<LiteralNode>(arrow::utf8(), LiteralHolder(pattern), false);
+ auto escape_char_node = std::make_shared<LiteralNode>(
+ arrow::int8(), LiteralHolder((int8_t)escape_char), false);
+ return FunctionNode("like", {field, pattern_node, escape_char_node},
+ arrow::boolean());
+ }
+};
+
+TEST_F(TestLikeHolder, TestMatchAny) {
+ std::shared_ptr<LikeHolder> like_holder;
+
+ auto status = LikeHolder::Make("ab%", &like_holder, regex_op);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& like = *like_holder;
+ EXPECT_TRUE(like("ab"));
+ EXPECT_TRUE(like("abc"));
+ EXPECT_TRUE(like("abcd"));
+
+ EXPECT_FALSE(like("a"));
+ EXPECT_FALSE(like("cab"));
+}
+
+TEST_F(TestLikeHolder, TestMatchOne) {
+ std::shared_ptr<LikeHolder> like_holder;
+
+ auto status = LikeHolder::Make("ab_", &like_holder, regex_op);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& like = *like_holder;
+ EXPECT_TRUE(like("abc"));
+ EXPECT_TRUE(like("abd"));
+
+ EXPECT_FALSE(like("a"));
+ EXPECT_FALSE(like("abcd"));
+ EXPECT_FALSE(like("dabc"));
+}
+
+TEST_F(TestLikeHolder, TestPcreSpecial) {
+ std::shared_ptr<LikeHolder> like_holder;
+
+ auto status = LikeHolder::Make(".*ab_", &like_holder, regex_op);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& like = *like_holder;
+ EXPECT_TRUE(like(".*abc")); // . and * aren't special in sql regex
+ EXPECT_FALSE(like("xxabc"));
+}
+
+TEST_F(TestLikeHolder, TestRegexEscape) {
+ std::string res;
+ auto status = RegexUtil::SqlLikePatternToPcre("#%hello#_abc_def##", '#', res);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ EXPECT_EQ(res, "%hello_abc.def#");
+}
+
+TEST_F(TestLikeHolder, TestDot) {
+ std::shared_ptr<LikeHolder> like_holder;
+
+ auto status = LikeHolder::Make("abc.", &like_holder, regex_op);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& like = *like_holder;
+ EXPECT_FALSE(like("abcd"));
+}
+
+TEST_F(TestLikeHolder, TestOptimise) {
+ // optimise for 'starts_with'
+ auto fnode = LikeHolder::TryOptimize(BuildLike("xy 123z%"));
+ EXPECT_EQ(fnode.descriptor()->name(), "starts_with");
+ EXPECT_EQ(fnode.ToString(), "bool starts_with((string) in, (const string) xy 123z)");
+
+ // optimise for 'ends_with'
+ fnode = LikeHolder::TryOptimize(BuildLike("%xyz"));
+ EXPECT_EQ(fnode.descriptor()->name(), "ends_with");
+ EXPECT_EQ(fnode.ToString(), "bool ends_with((string) in, (const string) xyz)");
+
+ // optimise for 'is_substr'
+ fnode = LikeHolder::TryOptimize(BuildLike("%abc%"));
+ EXPECT_EQ(fnode.descriptor()->name(), "is_substr");
+ EXPECT_EQ(fnode.ToString(), "bool is_substr((string) in, (const string) abc)");
+
+ // no optimisation for others.
+ fnode = LikeHolder::TryOptimize(BuildLike("xyz_"));
+ EXPECT_EQ(fnode.descriptor()->name(), "like");
+
+ fnode = LikeHolder::TryOptimize(BuildLike("_xyz"));
+ EXPECT_EQ(fnode.descriptor()->name(), "like");
+
+ fnode = LikeHolder::TryOptimize(BuildLike("_xyz_"));
+ EXPECT_EQ(fnode.descriptor()->name(), "like");
+
+ fnode = LikeHolder::TryOptimize(BuildLike("%xyz_"));
+ EXPECT_EQ(fnode.descriptor()->name(), "like");
+
+ fnode = LikeHolder::TryOptimize(BuildLike("x_yz%"));
+ EXPECT_EQ(fnode.descriptor()->name(), "like");
+
+ // no optimisation for escaped pattern.
+ fnode = LikeHolder::TryOptimize(BuildLike("\\%xyz", '\\'));
+ EXPECT_EQ(fnode.descriptor()->name(), "like");
+ EXPECT_EQ(fnode.ToString(),
+ "bool like((string) in, (const string) \\%xyz, (const int8) \\)");
+}
+
+TEST_F(TestLikeHolder, TestMatchOneEscape) {
+ std::shared_ptr<LikeHolder> like_holder;
+
+ auto status = LikeHolder::Make("ab\\_", "\\", &like_holder);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& like = *like_holder;
+
+ EXPECT_TRUE(like("ab_"));
+
+ EXPECT_FALSE(like("abc"));
+ EXPECT_FALSE(like("abd"));
+ EXPECT_FALSE(like("a"));
+ EXPECT_FALSE(like("abcd"));
+ EXPECT_FALSE(like("dabc"));
+}
+
+TEST_F(TestLikeHolder, TestMatchManyEscape) {
+ std::shared_ptr<LikeHolder> like_holder;
+
+ auto status = LikeHolder::Make("ab\\%", "\\", &like_holder);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& like = *like_holder;
+
+ EXPECT_TRUE(like("ab%"));
+
+ EXPECT_FALSE(like("abc"));
+ EXPECT_FALSE(like("abd"));
+ EXPECT_FALSE(like("a"));
+ EXPECT_FALSE(like("abcd"));
+ EXPECT_FALSE(like("dabc"));
+}
+
+TEST_F(TestLikeHolder, TestMatchEscape) {
+ std::shared_ptr<LikeHolder> like_holder;
+
+ auto status = LikeHolder::Make("ab\\\\", "\\", &like_holder);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& like = *like_holder;
+
+ EXPECT_TRUE(like("ab\\"));
+
+ EXPECT_FALSE(like("abc"));
+}
+
+TEST_F(TestLikeHolder, TestEmptyEscapeChar) {
+ std::shared_ptr<LikeHolder> like_holder;
+
+ auto status = LikeHolder::Make("ab\\_", "", &like_holder);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& like = *like_holder;
+
+ EXPECT_TRUE(like("ab\\c"));
+ EXPECT_TRUE(like("ab\\_"));
+
+ EXPECT_FALSE(like("ab\\_d"));
+ EXPECT_FALSE(like("ab__"));
+}
+
+TEST_F(TestLikeHolder, TestMultipleEscapeChar) {
+ std::shared_ptr<LikeHolder> like_holder;
+
+ auto status = LikeHolder::Make("ab\\_", "\\\\", &like_holder);
+ EXPECT_EQ(status.ok(), false) << status.message();
+}
+class TestILikeHolder : public ::testing::Test {
+ public:
+ RE2::Options regex_op;
+ FunctionNode BuildILike(std::string pattern) {
+ auto field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8()));
+ auto pattern_node =
+ std::make_shared<LiteralNode>(arrow::utf8(), LiteralHolder(pattern), false);
+ return FunctionNode("ilike", {field, pattern_node}, arrow::boolean());
+ }
+};
+
+TEST_F(TestILikeHolder, TestMatchAny) {
+ std::shared_ptr<LikeHolder> like_holder;
+
+ regex_op.set_case_sensitive(false);
+ auto status = LikeHolder::Make("ab%", &like_holder, regex_op);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& like = *like_holder;
+ EXPECT_TRUE(like("ab"));
+ EXPECT_TRUE(like("aBc"));
+ EXPECT_TRUE(like("ABCD"));
+
+ EXPECT_FALSE(like("a"));
+ EXPECT_FALSE(like("cab"));
+}
+
+TEST_F(TestILikeHolder, TestMatchOne) {
+ std::shared_ptr<LikeHolder> like_holder;
+
+ regex_op.set_case_sensitive(false);
+ auto status = LikeHolder::Make("Ab_", &like_holder, regex_op);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& like = *like_holder;
+ EXPECT_TRUE(like("abc"));
+ EXPECT_TRUE(like("aBd"));
+
+ EXPECT_FALSE(like("A"));
+ EXPECT_FALSE(like("Abcd"));
+ EXPECT_FALSE(like("DaBc"));
+}
+
+TEST_F(TestILikeHolder, TestPcreSpecial) {
+ std::shared_ptr<LikeHolder> like_holder;
+
+ regex_op.set_case_sensitive(false);
+ auto status = LikeHolder::Make(".*aB_", &like_holder, regex_op);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& like = *like_holder;
+ EXPECT_TRUE(like(".*Abc")); // . and * aren't special in sql regex
+ EXPECT_FALSE(like("xxAbc"));
+}
+
+TEST_F(TestILikeHolder, TestDot) {
+ std::shared_ptr<LikeHolder> like_holder;
+
+ regex_op.set_case_sensitive(false);
+ auto status = LikeHolder::Make("aBc.", &like_holder, regex_op);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& like = *like_holder;
+ EXPECT_FALSE(like("abcd"));
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/literal_holder.cc b/src/arrow/cpp/src/gandiva/literal_holder.cc
new file mode 100644
index 000000000..beed8119c
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/literal_holder.cc
@@ -0,0 +1,45 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <sstream>
+
+#include "gandiva/literal_holder.h"
+
+namespace gandiva {
+
+namespace {
+
+template <typename OStream>
+struct LiteralToStream {
+ OStream& ostream_;
+
+ template <typename Value>
+ void operator()(const Value& v) {
+ ostream_ << v;
+ }
+};
+
+} // namespace
+
+std::string ToString(const LiteralHolder& holder) {
+ std::stringstream ss;
+ LiteralToStream<std::stringstream> visitor{ss};
+ ::arrow::util::visit(visitor, holder);
+ return ss.str();
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/literal_holder.h b/src/arrow/cpp/src/gandiva/literal_holder.h
new file mode 100644
index 000000000..c4712aafc
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/literal_holder.h
@@ -0,0 +1,36 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+
+#include <arrow/util/variant.h>
+
+#include <arrow/type.h>
+#include "gandiva/decimal_scalar.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+using LiteralHolder =
+ arrow::util::Variant<bool, float, double, int8_t, int16_t, int32_t, int64_t, uint8_t,
+ uint16_t, uint32_t, uint64_t, std::string, DecimalScalar128>;
+
+GANDIVA_EXPORT std::string ToString(const LiteralHolder& holder);
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/llvm_generator.cc b/src/arrow/cpp/src/gandiva/llvm_generator.cc
new file mode 100644
index 000000000..0129e5278
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/llvm_generator.cc
@@ -0,0 +1,1400 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/llvm_generator.h"
+
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gandiva/bitmap_accumulator.h"
+#include "gandiva/decimal_ir.h"
+#include "gandiva/dex.h"
+#include "gandiva/expr_decomposer.h"
+#include "gandiva/expression.h"
+#include "gandiva/lvalue.h"
+
+namespace gandiva {
+
+#define ADD_TRACE(...) \
+ if (enable_ir_traces_) { \
+ AddTrace(__VA_ARGS__); \
+ }
+
+LLVMGenerator::LLVMGenerator() : enable_ir_traces_(false) {}
+
+Status LLVMGenerator::Make(std::shared_ptr<Configuration> config,
+ std::unique_ptr<LLVMGenerator>* llvm_generator) {
+ std::unique_ptr<LLVMGenerator> llvmgen_obj(new LLVMGenerator());
+
+ ARROW_RETURN_NOT_OK(Engine::Make(config, &(llvmgen_obj->engine_)));
+ *llvm_generator = std::move(llvmgen_obj);
+
+ return Status::OK();
+}
+
+Status LLVMGenerator::Add(const ExpressionPtr expr, const FieldDescriptorPtr output) {
+ int idx = static_cast<int>(compiled_exprs_.size());
+ // decompose the expression to separate out value and validities.
+ ExprDecomposer decomposer(function_registry_, annotator_);
+ ValueValidityPairPtr value_validity;
+ ARROW_RETURN_NOT_OK(decomposer.Decompose(*expr->root(), &value_validity));
+ // Generate the IR function for the decomposed expression.
+ std::unique_ptr<CompiledExpr> compiled_expr(new CompiledExpr(value_validity, output));
+ llvm::Function* ir_function = nullptr;
+ ARROW_RETURN_NOT_OK(CodeGenExprValue(value_validity->value_expr(),
+ annotator_.buffer_count(), output, idx,
+ &ir_function, selection_vector_mode_));
+ compiled_expr->SetIRFunction(selection_vector_mode_, ir_function);
+
+ compiled_exprs_.push_back(std::move(compiled_expr));
+ return Status::OK();
+}
+
+/// Build and optimise module for projection expression.
+Status LLVMGenerator::Build(const ExpressionVector& exprs, SelectionVector::Mode mode) {
+ selection_vector_mode_ = mode;
+ for (auto& expr : exprs) {
+ auto output = annotator_.AddOutputFieldDescriptor(expr->result());
+ ARROW_RETURN_NOT_OK(Add(expr, output));
+ }
+
+ // Compile and inject into the process' memory the generated function.
+ ARROW_RETURN_NOT_OK(engine_->FinalizeModule());
+
+ // setup the jit functions for each expression.
+ for (auto& compiled_expr : compiled_exprs_) {
+ auto ir_fn = compiled_expr->GetIRFunction(mode);
+ auto jit_fn = reinterpret_cast<EvalFunc>(engine_->CompiledFunction(ir_fn));
+ compiled_expr->SetJITFunction(selection_vector_mode_, jit_fn);
+ }
+
+ return Status::OK();
+}
+
+/// Execute the compiled module against the provided vectors.
+Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch,
+ const ArrayDataVector& output_vector) {
+ return Execute(record_batch, nullptr, output_vector);
+}
+
+/// Execute the compiled module against the provided vectors based on the type of
+/// selection vector.
+Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch,
+ const SelectionVector* selection_vector,
+ const ArrayDataVector& output_vector) {
+ DCHECK_GT(record_batch.num_rows(), 0);
+
+ auto eval_batch = annotator_.PrepareEvalBatch(record_batch, output_vector);
+ DCHECK_GT(eval_batch->GetNumBuffers(), 0);
+
+ auto mode = SelectionVector::MODE_NONE;
+ if (selection_vector != nullptr) {
+ mode = selection_vector->GetMode();
+ }
+ if (mode != selection_vector_mode_) {
+ return Status::Invalid("llvm expression built for selection vector mode ",
+ selection_vector_mode_, " received vector with mode ", mode);
+ }
+
+ for (auto& compiled_expr : compiled_exprs_) {
+ // generate data/offset vectors.
+ const uint8_t* selection_buffer = nullptr;
+ auto num_output_rows = record_batch.num_rows();
+ if (selection_vector != nullptr) {
+ selection_buffer = selection_vector->GetBuffer().data();
+ num_output_rows = selection_vector->GetNumSlots();
+ }
+
+ EvalFunc jit_function = compiled_expr->GetJITFunction(mode);
+ jit_function(eval_batch->GetBufferArray(), eval_batch->GetBufferOffsetArray(),
+ eval_batch->GetLocalBitMapArray(), selection_buffer,
+ (int64_t)eval_batch->GetExecutionContext(), num_output_rows);
+
+ // check for execution errors
+ ARROW_RETURN_IF(
+ eval_batch->GetExecutionContext()->has_error(),
+ Status::ExecutionError(eval_batch->GetExecutionContext()->get_error()));
+
+ // generate validity vectors.
+ ComputeBitMapsForExpr(*compiled_expr, *eval_batch, selection_vector);
+ }
+
+ return Status::OK();
+}
+
+llvm::Value* LLVMGenerator::LoadVectorAtIndex(llvm::Value* arg_addrs, int idx,
+ const std::string& name) {
+ auto* idx_val = types()->i32_constant(idx);
+ auto* offset = CreateGEP(ir_builder(), arg_addrs, idx_val, name + "_mem_addr");
+ return CreateLoad(ir_builder(), offset, name + "_mem");
+}
+
+/// Get reference to validity array at specified index in the args list.
+llvm::Value* LLVMGenerator::GetValidityReference(llvm::Value* arg_addrs, int idx,
+ FieldPtr field) {
+ const std::string& name = field->name();
+ llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
+ return ir_builder()->CreateIntToPtr(load, types()->i64_ptr_type(), name + "_varray");
+}
+
+/// Get reference to data array at specified index in the args list.
+llvm::Value* LLVMGenerator::GetDataBufferPtrReference(llvm::Value* arg_addrs, int idx,
+ FieldPtr field) {
+ const std::string& name = field->name();
+ llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
+ return ir_builder()->CreateIntToPtr(load, types()->i8_ptr_type(), name + "_buf_ptr");
+}
+
+/// Get reference to data array at specified index in the args list.
+llvm::Value* LLVMGenerator::GetDataReference(llvm::Value* arg_addrs, int idx,
+ FieldPtr field) {
+ const std::string& name = field->name();
+ llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
+ llvm::Type* base_type = types()->DataVecType(field->type());
+ llvm::Value* ret;
+ if (base_type->isPointerTy()) {
+ ret = ir_builder()->CreateIntToPtr(load, base_type, name + "_darray");
+ } else {
+ llvm::Type* pointer_type = types()->ptr_type(base_type);
+ ret = ir_builder()->CreateIntToPtr(load, pointer_type, name + "_darray");
+ }
+ return ret;
+}
+
+/// Get reference to offsets array at specified index in the args list.
+llvm::Value* LLVMGenerator::GetOffsetsReference(llvm::Value* arg_addrs, int idx,
+ FieldPtr field) {
+ const std::string& name = field->name();
+ llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);
+ return ir_builder()->CreateIntToPtr(load, types()->i32_ptr_type(), name + "_oarray");
+}
+
+/// Get reference to local bitmap array at specified index in the args list.
+llvm::Value* LLVMGenerator::GetLocalBitMapReference(llvm::Value* arg_bitmaps, int idx) {
+ llvm::Value* load = LoadVectorAtIndex(arg_bitmaps, idx, "");
+ return ir_builder()->CreateIntToPtr(load, types()->i64_ptr_type(),
+ std::to_string(idx) + "_lbmap");
+}
+
+/// \brief Generate code for one expression.
+
+// Sample IR code for "c1:int + c2:int"
+//
+// The C-code equivalent is :
+// ------------------------------
+// int expr_0(int64_t *addrs, int64_t *local_bitmaps,
+// int64_t execution_context_ptr, int64_t nrecords) {
+// int *outVec = (int *) addrs[5];
+// int *c0Vec = (int *) addrs[1];
+// int *c1Vec = (int *) addrs[3];
+// for (int loop_var = 0; loop_var < nrecords; ++loop_var) {
+// int c0 = c0Vec[loop_var];
+// int c1 = c1Vec[loop_var];
+// int out = c0 + c1;
+// outVec[loop_var] = out;
+// }
+// }
+//
+// IR Code
+// --------
+//
+// define i32 @expr_0(i64* %args, i64* %local_bitmaps, i64 %execution_context_ptr, , i64
+// %nrecords) { entry:
+// %outmemAddr = getelementptr i64, i64* %args, i32 5
+// %outmem = load i64, i64* %outmemAddr
+// %outVec = inttoptr i64 %outmem to i32*
+// %c0memAddr = getelementptr i64, i64* %args, i32 1
+// %c0mem = load i64, i64* %c0memAddr
+// %c0Vec = inttoptr i64 %c0mem to i32*
+// %c1memAddr = getelementptr i64, i64* %args, i32 3
+// %c1mem = load i64, i64* %c1memAddr
+// %c1Vec = inttoptr i64 %c1mem to i32*
+// br label %loop
+// loop: ; preds = %loop, %entry
+// %loop_var = phi i64 [ 0, %entry ], [ %"loop_var+1", %loop ]
+// %"loop_var+1" = add i64 %loop_var, 1
+// %0 = getelementptr i32, i32* %c0Vec, i32 %loop_var
+// %c0 = load i32, i32* %0
+// %1 = getelementptr i32, i32* %c1Vec, i32 %loop_var
+// %c1 = load i32, i32* %1
+// %add_int_int = call i32 @add_int_int(i32 %c0, i32 %c1)
+// %2 = getelementptr i32, i32* %outVec, i32 %loop_var
+// store i32 %add_int_int, i32* %2
+// %"loop_var < nrec" = icmp slt i64 %"loop_var+1", %nrecords
+// br i1 %"loop_var < nrec", label %loop, label %exit
+// exit: ; preds = %loop
+// ret i32 0
+// }
+Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count,
+ FieldDescriptorPtr output, int suffix_idx,
+ llvm::Function** fn,
+ SelectionVector::Mode selection_vector_mode) {
+ llvm::IRBuilder<>* builder = ir_builder();
+ // Create fn prototype :
+ // int expr_1 (long **addrs, long *offsets, long **bitmaps,
+ // long *context_ptr, long nrec)
+ std::vector<llvm::Type*> arguments;
+ arguments.push_back(types()->i64_ptr_type()); // addrs
+ arguments.push_back(types()->i64_ptr_type()); // offsets
+ arguments.push_back(types()->i64_ptr_type()); // bitmaps
+ switch (selection_vector_mode) {
+ case SelectionVector::MODE_NONE:
+ case SelectionVector::MODE_UINT16:
+ arguments.push_back(types()->ptr_type(types()->i16_type()));
+ break;
+ case SelectionVector::MODE_UINT32:
+ arguments.push_back(types()->i32_ptr_type());
+ break;
+ case SelectionVector::MODE_UINT64:
+ arguments.push_back(types()->i64_ptr_type());
+ }
+ arguments.push_back(types()->i64_type()); // ctx_ptr
+ arguments.push_back(types()->i64_type()); // nrec
+ llvm::FunctionType* prototype =
+ llvm::FunctionType::get(types()->i32_type(), arguments, false /*isVarArg*/);
+
+ // Create fn
+ std::string func_name = "expr_" + std::to_string(suffix_idx) + "_" +
+ std::to_string(static_cast<int>(selection_vector_mode));
+ engine_->AddFunctionToCompile(func_name);
+ *fn = llvm::Function::Create(prototype, llvm::GlobalValue::ExternalLinkage, func_name,
+ module());
+ ARROW_RETURN_IF((*fn == nullptr), Status::CodeGenError("Error creating function."));
+
+ // Name the arguments
+ llvm::Function::arg_iterator args = (*fn)->arg_begin();
+ llvm::Value* arg_addrs = &*args;
+ arg_addrs->setName("inputs_addr");
+ ++args;
+ llvm::Value* arg_addr_offsets = &*args;
+ arg_addr_offsets->setName("inputs_addr_offsets");
+ ++args;
+ llvm::Value* arg_local_bitmaps = &*args;
+ arg_local_bitmaps->setName("local_bitmaps");
+ ++args;
+ llvm::Value* arg_selection_vector = &*args;
+ arg_selection_vector->setName("selection_vector");
+ ++args;
+ llvm::Value* arg_context_ptr = &*args;
+ arg_context_ptr->setName("context_ptr");
+ ++args;
+ llvm::Value* arg_nrecords = &*args;
+ arg_nrecords->setName("nrecords");
+
+ llvm::BasicBlock* loop_entry = llvm::BasicBlock::Create(*context(), "entry", *fn);
+ llvm::BasicBlock* loop_body = llvm::BasicBlock::Create(*context(), "loop", *fn);
+ llvm::BasicBlock* loop_exit = llvm::BasicBlock::Create(*context(), "exit", *fn);
+
+ // Add reference to output vector (in entry block)
+ builder->SetInsertPoint(loop_entry);
+ llvm::Value* output_ref =
+ GetDataReference(arg_addrs, output->data_idx(), output->field());
+ llvm::Value* output_buffer_ptr_ref = GetDataBufferPtrReference(
+ arg_addrs, output->data_buffer_ptr_idx(), output->field());
+ llvm::Value* output_offset_ref =
+ GetOffsetsReference(arg_addrs, output->offsets_idx(), output->field());
+
+ std::vector<llvm::Value*> slice_offsets;
+ for (int idx = 0; idx < buffer_count; idx++) {
+ auto offsetAddr = CreateGEP(builder, arg_addr_offsets, types()->i32_constant(idx));
+ auto offset = CreateLoad(builder, offsetAddr);
+ slice_offsets.push_back(offset);
+ }
+
+ // Loop body
+ builder->SetInsertPoint(loop_body);
+
+ // define loop_var : start with 0, +1 after each iter
+ llvm::PHINode* loop_var = builder->CreatePHI(types()->i64_type(), 2, "loop_var");
+
+ llvm::Value* position_var = loop_var;
+ if (selection_vector_mode != SelectionVector::MODE_NONE) {
+ position_var = builder->CreateIntCast(
+ CreateLoad(builder, CreateGEP(builder, arg_selection_vector, loop_var),
+ "uncasted_position_var"),
+ types()->i64_type(), true, "position_var");
+ }
+
+ // The visitor can add code to both the entry/loop blocks.
+ Visitor visitor(this, *fn, loop_entry, arg_addrs, arg_local_bitmaps, slice_offsets,
+ arg_context_ptr, position_var);
+ value_expr->Accept(visitor);
+ LValuePtr output_value = visitor.result();
+
+ // The "current" block may have changed due to code generation in the visitor.
+ llvm::BasicBlock* loop_body_tail = builder->GetInsertBlock();
+
+ // add jump to "loop block" at the end of the "setup block".
+ builder->SetInsertPoint(loop_entry);
+ builder->CreateBr(loop_body);
+
+ // save the value in the output vector.
+ builder->SetInsertPoint(loop_body_tail);
+
+ auto output_type_id = output->Type()->id();
+ if (output_type_id == arrow::Type::BOOL) {
+ SetPackedBitValue(output_ref, loop_var, output_value->data());
+ } else if (arrow::is_primitive(output_type_id) ||
+ output_type_id == arrow::Type::DECIMAL) {
+ llvm::Value* slot_offset = CreateGEP(builder, output_ref, loop_var);
+ builder->CreateStore(output_value->data(), slot_offset);
+ } else if (arrow::is_binary_like(output_type_id)) {
+ // Var-len output. Make a function call to populate the data.
+ // if there is an error, the fn sets it in the context. And, will be returned at the
+ // end of this row batch.
+ AddFunctionCall("gdv_fn_populate_varlen_vector", types()->i32_type(),
+ {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, loop_var,
+ output_value->data(), output_value->length()});
+ } else {
+ return Status::NotImplemented("output type ", output->Type()->ToString(),
+ " not supported");
+ }
+ ADD_TRACE("saving result " + output->Name() + " value %T", output_value->data());
+
+ if (visitor.has_arena_allocs()) {
+ // Reset allocations to avoid excessive memory usage. Once the result is copied to
+ // the output vector (store instruction above), any memory allocations in this
+ // iteration of the loop are no longer needed.
+ std::vector<llvm::Value*> reset_args;
+ reset_args.push_back(arg_context_ptr);
+ AddFunctionCall("gdv_fn_context_arena_reset", types()->void_type(), reset_args);
+ }
+
+ // check loop_var
+ loop_var->addIncoming(types()->i64_constant(0), loop_entry);
+ llvm::Value* loop_update =
+ builder->CreateAdd(loop_var, types()->i64_constant(1), "loop_var+1");
+ loop_var->addIncoming(loop_update, loop_body_tail);
+
+ llvm::Value* loop_var_check =
+ builder->CreateICmpSLT(loop_update, arg_nrecords, "loop_var < nrec");
+ builder->CreateCondBr(loop_var_check, loop_body, loop_exit);
+
+ // Loop exit
+ builder->SetInsertPoint(loop_exit);
+ builder->CreateRet(types()->i32_constant(0));
+ return Status::OK();
+}
+
+/// Return value of a bit in bitMap.
+llvm::Value* LLVMGenerator::GetPackedBitValue(llvm::Value* bitmap,
+ llvm::Value* position) {
+ ADD_TRACE("fetch bit at position %T", position);
+
+ llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
+ bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
+ return AddFunctionCall("bitMapGetBit", types()->i1_type(), {bitmap8, position});
+}
+
+/// Set the value of a bit in bitMap.
+void LLVMGenerator::SetPackedBitValue(llvm::Value* bitmap, llvm::Value* position,
+ llvm::Value* value) {
+ ADD_TRACE("set bit at position %T", position);
+ ADD_TRACE(" to value %T ", value);
+
+ llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
+ bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
+ AddFunctionCall("bitMapSetBit", types()->void_type(), {bitmap8, position, value});
+}
+
+/// Return value of a bit in validity bitMap (handles null bitmaps too).
+llvm::Value* LLVMGenerator::GetPackedValidityBitValue(llvm::Value* bitmap,
+ llvm::Value* position) {
+ ADD_TRACE("fetch validity bit at position %T", position);
+
+ llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
+ bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
+ return AddFunctionCall("bitMapValidityGetBit", types()->i1_type(), {bitmap8, position});
+}
+
+/// Clear the bit in bitMap if value = false.
+void LLVMGenerator::ClearPackedBitValueIfFalse(llvm::Value* bitmap, llvm::Value* position,
+ llvm::Value* value) {
+ ADD_TRACE("ClearIfFalse bit at position %T", position);
+ ADD_TRACE(" value %T ", value);
+
+ llvm::Value* bitmap8 = ir_builder()->CreateBitCast(
+ bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast");
+ AddFunctionCall("bitMapClearBitIfFalse", types()->void_type(),
+ {bitmap8, position, value});
+}
+
+/// Extract the bitmap addresses, and do an intersection.
+void LLVMGenerator::ComputeBitMapsForExpr(const CompiledExpr& compiled_expr,
+ const EvalBatch& eval_batch,
+ const SelectionVector* selection_vector) {
+ auto validities = compiled_expr.value_validity()->validity_exprs();
+
+ // Extract all the source bitmap addresses.
+ BitMapAccumulator accumulator(eval_batch);
+ for (auto& validity_dex : validities) {
+ validity_dex->Accept(accumulator);
+ }
+
+ // Extract the destination bitmap address.
+ int out_idx = compiled_expr.output()->validity_idx();
+ uint8_t* dst_bitmap = eval_batch.GetBuffer(out_idx);
+ // Compute the destination bitmap.
+ if (selection_vector == nullptr) {
+ accumulator.ComputeResult(dst_bitmap);
+ } else {
+ /// The output bitmap is an intersection of some input/local bitmaps. However, with a
+ /// selection vector, only the bits corresponding to the indices in the selection
+ /// vector need to set in the output bitmap. This is done in two steps :
+ ///
+ /// 1. Do the intersection of input/local bitmaps to generate a temporary bitmap.
+ /// 2. copy just the relevant bits from the temporary bitmap to the output bitmap.
+ LocalBitMapsHolder bit_map_holder(eval_batch.num_records(), 1);
+ uint8_t* temp_bitmap = bit_map_holder.GetLocalBitMap(0);
+ accumulator.ComputeResult(temp_bitmap);
+
+ auto num_out_records = selection_vector->GetNumSlots();
+ // the memset isn't required, doing it just for valgrind.
+ memset(dst_bitmap, 0, arrow::BitUtil::BytesForBits(num_out_records));
+ for (auto i = 0; i < num_out_records; ++i) {
+ auto bit = arrow::BitUtil::GetBit(temp_bitmap, selection_vector->GetIndex(i));
+ arrow::BitUtil::SetBitTo(dst_bitmap, i, bit);
+ }
+ }
+}
+
+llvm::Value* LLVMGenerator::AddFunctionCall(const std::string& full_name,
+ llvm::Type* ret_type,
+ const std::vector<llvm::Value*>& args) {
+ // find the llvm function.
+ llvm::Function* fn = module()->getFunction(full_name);
+ DCHECK_NE(fn, nullptr) << "missing function " << full_name;
+
+ if (enable_ir_traces_ && !full_name.compare("printf") &&
+ !full_name.compare("printff")) {
+ // Trace for debugging
+ ADD_TRACE("invoke native fn " + full_name);
+ }
+
+ // build a call to the llvm function.
+ llvm::Value* value;
+ if (ret_type->isVoidTy()) {
+ // void functions can't have a name for the call.
+ value = ir_builder()->CreateCall(fn, args);
+ } else {
+ value = ir_builder()->CreateCall(fn, args, full_name);
+ DCHECK(value->getType() == ret_type);
+ }
+
+ return value;
+}
+
+std::shared_ptr<DecimalLValue> LLVMGenerator::BuildDecimalLValue(llvm::Value* value,
+ DataTypePtr arrow_type) {
+ // only decimals of size 128-bit supported.
+ DCHECK(is_decimal_128(arrow_type));
+ auto decimal_type =
+ arrow::internal::checked_cast<arrow::DecimalType*>(arrow_type.get());
+ return std::make_shared<DecimalLValue>(value, nullptr,
+ types()->i32_constant(decimal_type->precision()),
+ types()->i32_constant(decimal_type->scale()));
+}
+
+#define ADD_VISITOR_TRACE(...) \
+ if (generator_->enable_ir_traces_) { \
+ generator_->AddTrace(__VA_ARGS__); \
+ }
+
+// Visitor for generating the code for a decomposed expression.
+LLVMGenerator::Visitor::Visitor(LLVMGenerator* generator, llvm::Function* function,
+ llvm::BasicBlock* entry_block, llvm::Value* arg_addrs,
+ llvm::Value* arg_local_bitmaps,
+ std::vector<llvm::Value*> slice_offsets,
+ llvm::Value* arg_context_ptr, llvm::Value* loop_var)
+ : generator_(generator),
+ function_(function),
+ entry_block_(entry_block),
+ arg_addrs_(arg_addrs),
+ arg_local_bitmaps_(arg_local_bitmaps),
+ slice_offsets_(slice_offsets),
+ arg_context_ptr_(arg_context_ptr),
+ loop_var_(loop_var),
+ has_arena_allocs_(false) {
+ ADD_VISITOR_TRACE("Iteration %T", loop_var);
+}
+
+void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) {
+ llvm::IRBuilder<>* builder = ir_builder();
+ llvm::Value* slot_ref = GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field());
+ llvm::Value* slot_index = builder->CreateAdd(loop_var_, GetSliceOffset(dex.DataIdx()));
+ llvm::Value* slot_value;
+ std::shared_ptr<LValue> lvalue;
+
+ switch (dex.FieldType()->id()) {
+ case arrow::Type::BOOL:
+ slot_value = generator_->GetPackedBitValue(slot_ref, slot_index);
+ lvalue = std::make_shared<LValue>(slot_value);
+ break;
+
+ case arrow::Type::DECIMAL: {
+ auto slot_offset = CreateGEP(builder, slot_ref, slot_index);
+ slot_value = CreateLoad(builder, slot_offset, dex.FieldName());
+ lvalue = generator_->BuildDecimalLValue(slot_value, dex.FieldType());
+ break;
+ }
+
+ default: {
+ auto slot_offset = CreateGEP(builder, slot_ref, slot_index);
+ slot_value = CreateLoad(builder, slot_offset, dex.FieldName());
+ lvalue = std::make_shared<LValue>(slot_value);
+ break;
+ }
+ }
+ ADD_VISITOR_TRACE("visit fixed-len data vector " + dex.FieldName() + " value %T",
+ slot_value);
+ result_ = lvalue;
+}
+
+void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) {
+ llvm::IRBuilder<>* builder = ir_builder();
+ llvm::Value* slot;
+
+ // compute len from the offsets array.
+ llvm::Value* offsets_slot_ref =
+ GetBufferReference(dex.OffsetsIdx(), kBufferTypeOffsets, dex.Field());
+ llvm::Value* offsets_slot_index =
+ builder->CreateAdd(loop_var_, GetSliceOffset(dex.OffsetsIdx()));
+
+ // => offset_start = offsets[loop_var]
+ slot = CreateGEP(builder, offsets_slot_ref, offsets_slot_index);
+ llvm::Value* offset_start = CreateLoad(builder, slot, "offset_start");
+
+ // => offset_end = offsets[loop_var + 1]
+ llvm::Value* offsets_slot_index_next = builder->CreateAdd(
+ offsets_slot_index, generator_->types()->i64_constant(1), "loop_var+1");
+ slot = CreateGEP(builder, offsets_slot_ref, offsets_slot_index_next);
+ llvm::Value* offset_end = CreateLoad(builder, slot, "offset_end");
+
+ // => len_value = offset_end - offset_start
+ llvm::Value* len_value =
+ builder->CreateSub(offset_end, offset_start, dex.FieldName() + "Len");
+
+ // get the data from the data array, at offset 'offset_start'.
+ llvm::Value* data_slot_ref =
+ GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field());
+ llvm::Value* data_value = CreateGEP(builder, data_slot_ref, offset_start);
+ ADD_VISITOR_TRACE("visit var-len data vector " + dex.FieldName() + " len %T",
+ len_value);
+ result_.reset(new LValue(data_value, len_value));
+}
+
+void LLVMGenerator::Visitor::Visit(const VectorReadValidityDex& dex) {
+ llvm::IRBuilder<>* builder = ir_builder();
+ llvm::Value* slot_ref =
+ GetBufferReference(dex.ValidityIdx(), kBufferTypeValidity, dex.Field());
+ llvm::Value* slot_index =
+ builder->CreateAdd(loop_var_, GetSliceOffset(dex.ValidityIdx()));
+ llvm::Value* validity = generator_->GetPackedValidityBitValue(slot_ref, slot_index);
+
+ ADD_VISITOR_TRACE("visit validity vector " + dex.FieldName() + " value %T", validity);
+ result_.reset(new LValue(validity));
+}
+
+void LLVMGenerator::Visitor::Visit(const LocalBitMapValidityDex& dex) {
+ llvm::Value* slot_ref = GetLocalBitMapReference(dex.local_bitmap_idx());
+ llvm::Value* validity = generator_->GetPackedBitValue(slot_ref, loop_var_);
+
+ ADD_VISITOR_TRACE(
+ "visit local bitmap " + std::to_string(dex.local_bitmap_idx()) + " value %T",
+ validity);
+ result_.reset(new LValue(validity));
+}
+
+void LLVMGenerator::Visitor::Visit(const TrueDex& dex) {
+ result_.reset(new LValue(generator_->types()->true_constant()));
+}
+
+void LLVMGenerator::Visitor::Visit(const FalseDex& dex) {
+ result_.reset(new LValue(generator_->types()->false_constant()));
+}
+
+void LLVMGenerator::Visitor::Visit(const LiteralDex& dex) {
+ LLVMTypes* types = generator_->types();
+ llvm::Value* value = nullptr;
+ llvm::Value* len = nullptr;
+
+ switch (dex.type()->id()) {
+ case arrow::Type::BOOL:
+ value = types->i1_constant(arrow::util::get<bool>(dex.holder()));
+ break;
+
+ case arrow::Type::UINT8:
+ value = types->i8_constant(arrow::util::get<uint8_t>(dex.holder()));
+ break;
+
+ case arrow::Type::UINT16:
+ value = types->i16_constant(arrow::util::get<uint16_t>(dex.holder()));
+ break;
+
+ case arrow::Type::UINT32:
+ value = types->i32_constant(arrow::util::get<uint32_t>(dex.holder()));
+ break;
+
+ case arrow::Type::UINT64:
+ value = types->i64_constant(arrow::util::get<uint64_t>(dex.holder()));
+ break;
+
+ case arrow::Type::INT8:
+ value = types->i8_constant(arrow::util::get<int8_t>(dex.holder()));
+ break;
+
+ case arrow::Type::INT16:
+ value = types->i16_constant(arrow::util::get<int16_t>(dex.holder()));
+ break;
+
+ case arrow::Type::FLOAT:
+ value = types->float_constant(arrow::util::get<float>(dex.holder()));
+ break;
+
+ case arrow::Type::DOUBLE:
+ value = types->double_constant(arrow::util::get<double>(dex.holder()));
+ break;
+
+ case arrow::Type::STRING:
+ case arrow::Type::BINARY: {
+ const std::string& str = arrow::util::get<std::string>(dex.holder());
+
+ llvm::Constant* str_int_cast = types->i64_constant((int64_t)str.c_str());
+ value = llvm::ConstantExpr::getIntToPtr(str_int_cast, types->i8_ptr_type());
+ len = types->i32_constant(static_cast<int32_t>(str.length()));
+ break;
+ }
+
+ case arrow::Type::INT32:
+ case arrow::Type::DATE32:
+ case arrow::Type::TIME32:
+ case arrow::Type::INTERVAL_MONTHS:
+ value = types->i32_constant(arrow::util::get<int32_t>(dex.holder()));
+ break;
+
+ case arrow::Type::INT64:
+ case arrow::Type::DATE64:
+ case arrow::Type::TIME64:
+ case arrow::Type::TIMESTAMP:
+ case arrow::Type::INTERVAL_DAY_TIME:
+ value = types->i64_constant(arrow::util::get<int64_t>(dex.holder()));
+ break;
+
+ case arrow::Type::DECIMAL: {
+ // build code for struct
+ auto scalar = arrow::util::get<DecimalScalar128>(dex.holder());
+ // ConstantInt doesn't have a get method that takes int128 or a pair of int64. so,
+ // passing the string representation instead.
+ auto int128_value =
+ llvm::ConstantInt::get(llvm::Type::getInt128Ty(*generator_->context()),
+ Decimal128(scalar.value()).ToIntegerString(), 10);
+ auto type = arrow::decimal(scalar.precision(), scalar.scale());
+ auto lvalue = generator_->BuildDecimalLValue(int128_value, type);
+ // set it as the l-value and return.
+ result_ = lvalue;
+ return;
+ }
+
+ default:
+ DCHECK(0);
+ }
+ ADD_VISITOR_TRACE("visit Literal %T", value);
+ result_.reset(new LValue(value, len));
+}
+
+void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) {
+ const std::string& function_name = dex.func_descriptor()->name();
+ ADD_VISITOR_TRACE("visit NonNullableFunc base function " + function_name);
+
+ const NativeFunction* native_function = dex.native_function();
+
+ // build the function params (ignore validity).
+ auto params = BuildParams(dex.function_holder().get(), dex.args(), false,
+ native_function->NeedsContext());
+
+ auto arrow_return_type = dex.func_descriptor()->return_type();
+ if (native_function->CanReturnErrors()) {
+ // slow path : if a function can return errors, skip invoking the function
+ // unless all of the input args are valid. Otherwise, it can cause spurious errors.
+
+ llvm::IRBuilder<>* builder = ir_builder();
+ LLVMTypes* types = generator_->types();
+ auto arrow_type_id = arrow_return_type->id();
+ auto result_type = types->IRType(arrow_type_id);
+
+ // Build combined validity of the args.
+ llvm::Value* is_valid = types->true_constant();
+ for (auto& pair : dex.args()) {
+ auto arg_validity = BuildCombinedValidity(pair->validity_exprs());
+ is_valid = builder->CreateAnd(is_valid, arg_validity, "validityBitAnd");
+ }
+
+ // then block
+ auto then_lambda = [&] {
+ ADD_VISITOR_TRACE("fn " + function_name +
+ " can return errors : all args valid, invoke fn");
+ return BuildFunctionCall(native_function, arrow_return_type, &params);
+ };
+
+ // else block
+ auto else_lambda = [&] {
+ ADD_VISITOR_TRACE("fn " + function_name +
+ " can return errors : not all args valid, return dummy value");
+ llvm::Value* else_value = types->NullConstant(result_type);
+ llvm::Value* else_value_len = nullptr;
+ if (arrow::is_binary_like(arrow_type_id)) {
+ else_value_len = types->i32_constant(0);
+ }
+ return std::make_shared<LValue>(else_value, else_value_len);
+ };
+
+ result_ = BuildIfElse(is_valid, then_lambda, else_lambda, arrow_return_type);
+ } else {
+ // fast path : invoke function without computing validities.
+ result_ = BuildFunctionCall(native_function, arrow_return_type, &params);
+ }
+}
+
+void LLVMGenerator::Visitor::Visit(const NullableNeverFuncDex& dex) {
+ ADD_VISITOR_TRACE("visit NullableNever base function " + dex.func_descriptor()->name());
+ const NativeFunction* native_function = dex.native_function();
+
+ // build function params along with validity.
+ auto params = BuildParams(dex.function_holder().get(), dex.args(), true,
+ native_function->NeedsContext());
+
+ auto arrow_return_type = dex.func_descriptor()->return_type();
+ result_ = BuildFunctionCall(native_function, arrow_return_type, &params);
+}
+
+void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex& dex) {
+ ADD_VISITOR_TRACE("visit NullableInternal base function " +
+ dex.func_descriptor()->name());
+ llvm::IRBuilder<>* builder = ir_builder();
+ LLVMTypes* types = generator_->types();
+
+ const NativeFunction* native_function = dex.native_function();
+
+ // build function params along with validity.
+ auto params = BuildParams(dex.function_holder().get(), dex.args(), true,
+ native_function->NeedsContext());
+
+ // add an extra arg for validity (allocated on stack).
+ llvm::AllocaInst* result_valid_ptr =
+ new llvm::AllocaInst(types->i8_type(), 0, "result_valid", entry_block_);
+ params.push_back(result_valid_ptr);
+
+ auto arrow_return_type = dex.func_descriptor()->return_type();
+ result_ = BuildFunctionCall(native_function, arrow_return_type, &params);
+
+ // load the result validity and truncate to i1.
+ llvm::Value* result_valid_i8 = CreateLoad(builder, result_valid_ptr);
+ llvm::Value* result_valid = builder->CreateTrunc(result_valid_i8, types->i1_type());
+
+ // set validity bit in the local bitmap.
+ ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), result_valid);
+}
+
+void LLVMGenerator::Visitor::Visit(const IfDex& dex) {
+ ADD_VISITOR_TRACE("visit IfExpression");
+ llvm::IRBuilder<>* builder = ir_builder();
+
+ // Evaluate condition.
+ LValuePtr if_condition = BuildValueAndValidity(dex.condition_vv());
+
+ // Check if the result is valid, and there is match.
+ llvm::Value* validAndMatched =
+ builder->CreateAnd(if_condition->data(), if_condition->validity(), "validAndMatch");
+
+ // then block
+ auto then_lambda = [&] {
+ ADD_VISITOR_TRACE("branch to then block");
+ LValuePtr then_lvalue = BuildValueAndValidity(dex.then_vv());
+ ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), then_lvalue->validity());
+ ADD_VISITOR_TRACE("IfExpression result validity %T in matching then",
+ then_lvalue->validity());
+ return then_lvalue;
+ };
+
+ // else block
+ auto else_lambda = [&] {
+ LValuePtr else_lvalue;
+ if (dex.is_terminal_else()) {
+ ADD_VISITOR_TRACE("branch to terminal else block");
+
+ else_lvalue = BuildValueAndValidity(dex.else_vv());
+ // update the local bitmap with the validity.
+ ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), else_lvalue->validity());
+ ADD_VISITOR_TRACE("IfExpression result validity %T in terminal else",
+ else_lvalue->validity());
+ } else {
+ ADD_VISITOR_TRACE("branch to non-terminal else block");
+
+ // this is a non-terminal else. let the child (nested if/else) handle validity.
+ auto value_expr = dex.else_vv().value_expr();
+ value_expr->Accept(*this);
+ else_lvalue = result();
+ }
+ return else_lvalue;
+ };
+
+ // build the if-else condition.
+ result_ = BuildIfElse(validAndMatched, then_lambda, else_lambda, dex.result_type());
+ if (arrow::is_binary_like(dex.result_type()->id())) {
+ ADD_VISITOR_TRACE("IfElse result length %T", result_->length());
+ }
+ ADD_VISITOR_TRACE("IfElse result value %T", result_->data());
+}
+
+// Boolean AND
+// if any arg is valid and false,
+// short-circuit and return FALSE (value=false, valid=true)
+// else if all args are valid and true
+// return TRUE (value=true, valid=true)
+// else
+// return NULL (value=true, valid=false)
+
+void LLVMGenerator::Visitor::Visit(const BooleanAndDex& dex) {
+ ADD_VISITOR_TRACE("visit BooleanAndExpression");
+ llvm::IRBuilder<>* builder = ir_builder();
+ LLVMTypes* types = generator_->types();
+ llvm::LLVMContext* context = generator_->context();
+
+ // Create blocks for short-circuit.
+ llvm::BasicBlock* short_circuit_bb =
+ llvm::BasicBlock::Create(*context, "short_circuit", function_);
+ llvm::BasicBlock* non_short_circuit_bb =
+ llvm::BasicBlock::Create(*context, "non_short_circuit", function_);
+ llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context, "merge", function_);
+
+ llvm::Value* all_exprs_valid = types->true_constant();
+ for (auto& pair : dex.args()) {
+ LValuePtr current = BuildValueAndValidity(*pair);
+
+ ADD_VISITOR_TRACE("BooleanAndExpression arg value %T", current->data());
+ ADD_VISITOR_TRACE("BooleanAndExpression arg validity %T", current->validity());
+
+ // short-circuit if valid and false
+ llvm::Value* is_false = builder->CreateNot(current->data());
+ llvm::Value* valid_and_false =
+ builder->CreateAnd(is_false, current->validity(), "valid_and_false");
+
+ llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context, "else", function_);
+ builder->CreateCondBr(valid_and_false, short_circuit_bb, else_bb);
+
+ // Emit the else block.
+ builder->SetInsertPoint(else_bb);
+ // remember if any nulls were encountered.
+ all_exprs_valid =
+ builder->CreateAnd(all_exprs_valid, current->validity(), "validityBitAnd");
+ // continue to evaluate the next pair in list.
+ }
+ builder->CreateBr(non_short_circuit_bb);
+
+ // Short-circuit case (at least one of the expressions is valid and false).
+ // No need to set validity bit (valid by default).
+ builder->SetInsertPoint(short_circuit_bb);
+ ADD_VISITOR_TRACE("BooleanAndExpression result value false");
+ ADD_VISITOR_TRACE("BooleanAndExpression result validity true");
+ builder->CreateBr(merge_bb);
+
+ // non short-circuit case (All expressions are either true or null).
+ // result valid if all of the exprs are non-null.
+ builder->SetInsertPoint(non_short_circuit_bb);
+ ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), all_exprs_valid);
+ ADD_VISITOR_TRACE("BooleanAndExpression result value true");
+ ADD_VISITOR_TRACE("BooleanAndExpression result validity %T", all_exprs_valid);
+ builder->CreateBr(merge_bb);
+
+ builder->SetInsertPoint(merge_bb);
+ llvm::PHINode* result_value = builder->CreatePHI(types->i1_type(), 2, "res_value");
+ result_value->addIncoming(types->false_constant(), short_circuit_bb);
+ result_value->addIncoming(types->true_constant(), non_short_circuit_bb);
+ result_.reset(new LValue(result_value));
+}
+
+// Boolean OR
+// if any arg is valid and true,
+// short-circuit and return TRUE (value=true, valid=true)
+// else if all args are valid and false
+// return FALSE (value=false, valid=true)
+// else
+// return NULL (value=false, valid=false)
+
+void LLVMGenerator::Visitor::Visit(const BooleanOrDex& dex) {
+ ADD_VISITOR_TRACE("visit BooleanOrExpression");
+ llvm::IRBuilder<>* builder = ir_builder();
+ LLVMTypes* types = generator_->types();
+ llvm::LLVMContext* context = generator_->context();
+
+ // Create blocks for short-circuit.
+ llvm::BasicBlock* short_circuit_bb =
+ llvm::BasicBlock::Create(*context, "short_circuit", function_);
+ llvm::BasicBlock* non_short_circuit_bb =
+ llvm::BasicBlock::Create(*context, "non_short_circuit", function_);
+ llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context, "merge", function_);
+
+ llvm::Value* all_exprs_valid = types->true_constant();
+ for (auto& pair : dex.args()) {
+ LValuePtr current = BuildValueAndValidity(*pair);
+
+ ADD_VISITOR_TRACE("BooleanOrExpression arg value %T", current->data());
+ ADD_VISITOR_TRACE("BooleanOrExpression arg validity %T", current->validity());
+
+ // short-circuit if valid and true.
+ llvm::Value* valid_and_true =
+ builder->CreateAnd(current->data(), current->validity(), "valid_and_true");
+
+ llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context, "else", function_);
+ builder->CreateCondBr(valid_and_true, short_circuit_bb, else_bb);
+
+ // Emit the else block.
+ builder->SetInsertPoint(else_bb);
+ // remember if any nulls were encountered.
+ all_exprs_valid =
+ builder->CreateAnd(all_exprs_valid, current->validity(), "validityBitAnd");
+ // continue to evaluate the next pair in list.
+ }
+ builder->CreateBr(non_short_circuit_bb);
+
+ // Short-circuit case (at least one of the expressions is valid and true).
+ // No need to set validity bit (valid by default).
+ builder->SetInsertPoint(short_circuit_bb);
+ ADD_VISITOR_TRACE("BooleanOrExpression result value true");
+ ADD_VISITOR_TRACE("BooleanOrExpression result validity true");
+ builder->CreateBr(merge_bb);
+
+ // non short-circuit case (All expressions are either false or null).
+ // result valid if all of the exprs are non-null.
+ builder->SetInsertPoint(non_short_circuit_bb);
+ ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), all_exprs_valid);
+ ADD_VISITOR_TRACE("BooleanOrExpression result value false");
+ ADD_VISITOR_TRACE("BooleanOrExpression result validity %T", all_exprs_valid);
+ builder->CreateBr(merge_bb);
+
+ builder->SetInsertPoint(merge_bb);
+ llvm::PHINode* result_value = builder->CreatePHI(types->i1_type(), 2, "res_value");
+ result_value->addIncoming(types->true_constant(), short_circuit_bb);
+ result_value->addIncoming(types->false_constant(), non_short_circuit_bb);
+ result_.reset(new LValue(result_value));
+}
+
+template <typename Type>
+void LLVMGenerator::Visitor::VisitInExpression(const InExprDexBase<Type>& dex) {
+ ADD_VISITOR_TRACE("visit In Expression");
+ LLVMTypes* types = generator_->types();
+ std::vector<llvm::Value*> params;
+
+ const InExprDex<Type>& dex_instance = dynamic_cast<const InExprDex<Type>&>(dex);
+ /* add the holder at the beginning */
+ llvm::Constant* ptr_int_cast =
+ types->i64_constant((int64_t)(dex_instance.in_holder().get()));
+ params.push_back(ptr_int_cast);
+
+ /* eval expr result */
+ for (auto& pair : dex.args()) {
+ DexPtr value_expr = pair->value_expr();
+ value_expr->Accept(*this);
+ LValue& result_ref = *result();
+ params.push_back(result_ref.data());
+
+ /* length if the result is a string */
+ if (result_ref.length() != nullptr) {
+ params.push_back(result_ref.length());
+ }
+
+ /* push the validity of eval expr result */
+ llvm::Value* validity_expr = BuildCombinedValidity(pair->validity_exprs());
+ params.push_back(validity_expr);
+ }
+
+ llvm::Type* ret_type = types->IRType(arrow::Type::type::BOOL);
+
+ llvm::Value* value;
+
+ value = generator_->AddFunctionCall(dex.runtime_function(), ret_type, params);
+
+ result_.reset(new LValue(value));
+}
+
+template <>
+void LLVMGenerator::Visitor::VisitInExpression<gandiva::DecimalScalar128>(
+ const InExprDexBase<gandiva::DecimalScalar128>& dex) {
+ ADD_VISITOR_TRACE("visit In Expression");
+ LLVMTypes* types = generator_->types();
+ std::vector<llvm::Value*> params;
+ DecimalIR decimalIR(generator_->engine_.get());
+
+ const InExprDex<gandiva::DecimalScalar128>& dex_instance =
+ dynamic_cast<const InExprDex<gandiva::DecimalScalar128>&>(dex);
+ /* add the holder at the beginning */
+ llvm::Constant* ptr_int_cast =
+ types->i64_constant((int64_t)(dex_instance.in_holder().get()));
+ params.push_back(ptr_int_cast);
+
+ /* eval expr result */
+ for (auto& pair : dex.args()) {
+ DexPtr value_expr = pair->value_expr();
+ value_expr->Accept(*this);
+ LValue& result_ref = *result();
+ params.push_back(result_ref.data());
+
+ llvm::Constant* precision = types->i32_constant(dex.get_precision());
+ llvm::Constant* scale = types->i32_constant(dex.get_scale());
+ params.push_back(precision);
+ params.push_back(scale);
+
+ /* push the validity of eval expr result */
+ llvm::Value* validity_expr = BuildCombinedValidity(pair->validity_exprs());
+ params.push_back(validity_expr);
+ }
+
+ llvm::Type* ret_type = types->IRType(arrow::Type::type::BOOL);
+
+ llvm::Value* value;
+
+ value = decimalIR.CallDecimalFunction(dex.runtime_function(), ret_type, params);
+
+ result_.reset(new LValue(value));
+}
+
+void LLVMGenerator::Visitor::Visit(const InExprDexBase<int32_t>& dex) {
+ VisitInExpression<int32_t>(dex);
+}
+
+void LLVMGenerator::Visitor::Visit(const InExprDexBase<int64_t>& dex) {
+ VisitInExpression<int64_t>(dex);
+}
+
+void LLVMGenerator::Visitor::Visit(const InExprDexBase<float>& dex) {
+ VisitInExpression<float>(dex);
+}
+void LLVMGenerator::Visitor::Visit(const InExprDexBase<double>& dex) {
+ VisitInExpression<double>(dex);
+}
+
+void LLVMGenerator::Visitor::Visit(const InExprDexBase<gandiva::DecimalScalar128>& dex) {
+ VisitInExpression<gandiva::DecimalScalar128>(dex);
+}
+
+void LLVMGenerator::Visitor::Visit(const InExprDexBase<std::string>& dex) {
+ VisitInExpression<std::string>(dex);
+}
+
+LValuePtr LLVMGenerator::Visitor::BuildIfElse(llvm::Value* condition,
+ std::function<LValuePtr()> then_func,
+ std::function<LValuePtr()> else_func,
+ DataTypePtr result_type) {
+ llvm::IRBuilder<>* builder = ir_builder();
+ llvm::LLVMContext* context = generator_->context();
+ LLVMTypes* types = generator_->types();
+
+ // Create blocks for the then, else and merge cases.
+ llvm::BasicBlock* then_bb = llvm::BasicBlock::Create(*context, "then", function_);
+ llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context, "else", function_);
+ llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context, "merge", function_);
+
+ builder->CreateCondBr(condition, then_bb, else_bb);
+
+ // Emit the then block.
+ builder->SetInsertPoint(then_bb);
+ LValuePtr then_lvalue = then_func();
+ builder->CreateBr(merge_bb);
+
+ // refresh then_bb for phi (could have changed due to code generation of then_vv).
+ then_bb = builder->GetInsertBlock();
+
+ // Emit the else block.
+ builder->SetInsertPoint(else_bb);
+ LValuePtr else_lvalue = else_func();
+ builder->CreateBr(merge_bb);
+
+ // refresh else_bb for phi (could have changed due to code generation of else_vv).
+ else_bb = builder->GetInsertBlock();
+
+ // Emit the merge block.
+ builder->SetInsertPoint(merge_bb);
+ auto llvm_type = types->IRType(result_type->id());
+ llvm::PHINode* result_value = builder->CreatePHI(llvm_type, 2, "res_value");
+ result_value->addIncoming(then_lvalue->data(), then_bb);
+ result_value->addIncoming(else_lvalue->data(), else_bb);
+
+ LValuePtr ret;
+ switch (result_type->id()) {
+ case arrow::Type::STRING:
+ case arrow::Type::BINARY: {
+ llvm::PHINode* result_length;
+ result_length = builder->CreatePHI(types->i32_type(), 2, "res_length");
+ result_length->addIncoming(then_lvalue->length(), then_bb);
+ result_length->addIncoming(else_lvalue->length(), else_bb);
+ ret = std::make_shared<LValue>(result_value, result_length);
+ break;
+ }
+
+ case arrow::Type::DECIMAL:
+ ret = generator_->BuildDecimalLValue(result_value, result_type);
+ break;
+
+ default:
+ ret = std::make_shared<LValue>(result_value);
+ break;
+ }
+ return ret;
+}
+
+LValuePtr LLVMGenerator::Visitor::BuildValueAndValidity(const ValueValidityPair& pair) {
+ // generate code for value
+ auto value_expr = pair.value_expr();
+ value_expr->Accept(*this);
+ auto value = result()->data();
+ auto length = result()->length();
+
+ // generate code for validity
+ auto validity = BuildCombinedValidity(pair.validity_exprs());
+
+ return std::make_shared<LValue>(value, length, validity);
+}
+
+LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func,
+ DataTypePtr arrow_return_type,
+ std::vector<llvm::Value*>* params) {
+ auto types = generator_->types();
+ auto arrow_return_type_id = arrow_return_type->id();
+ auto llvm_return_type = types->IRType(arrow_return_type_id);
+ DecimalIR decimalIR(generator_->engine_.get());
+
+ if (arrow_return_type_id == arrow::Type::DECIMAL) {
+ // For decimal fns, the output precision/scale are passed along as parameters.
+ //
+ // convert from this :
+ // out = add_decimal(v1, p1, s1, v2, p2, s2)
+ // to:
+ // out = add_decimal(v1, p1, s1, v2, p2, s2, out_p, out_s)
+
+ // Append the out_precision and out_scale
+ auto ret_lvalue = generator_->BuildDecimalLValue(nullptr, arrow_return_type);
+ params->push_back(ret_lvalue->precision());
+ params->push_back(ret_lvalue->scale());
+
+ // Make the function call
+ auto out = decimalIR.CallDecimalFunction(func->pc_name(), llvm_return_type, *params);
+ ret_lvalue->set_data(out);
+ return std::move(ret_lvalue);
+ } else {
+ bool isDecimalFunction = false;
+ for (auto& arg : *params) {
+ if (arg->getType() == types->i128_type()) {
+ isDecimalFunction = true;
+ }
+ }
+ // add extra arg for return length for variable len return types (allocated on stack).
+ llvm::AllocaInst* result_len_ptr = nullptr;
+ if (arrow::is_binary_like(arrow_return_type_id)) {
+ result_len_ptr = new llvm::AllocaInst(generator_->types()->i32_type(), 0,
+ "result_len", entry_block_);
+ params->push_back(result_len_ptr);
+ has_arena_allocs_ = true;
+ }
+
+ // Make the function call
+ llvm::IRBuilder<>* builder = ir_builder();
+ auto value =
+ isDecimalFunction
+ ? decimalIR.CallDecimalFunction(func->pc_name(), llvm_return_type, *params)
+ : generator_->AddFunctionCall(func->pc_name(), llvm_return_type, *params);
+ auto value_len =
+ (result_len_ptr == nullptr) ? nullptr : CreateLoad(builder, result_len_ptr);
+ return std::make_shared<LValue>(value, value_len);
+ }
+}
+
+std::vector<llvm::Value*> LLVMGenerator::Visitor::BuildParams(
+ FunctionHolder* holder, const ValueValidityPairVector& args, bool with_validity,
+ bool with_context) {
+ LLVMTypes* types = generator_->types();
+ std::vector<llvm::Value*> params;
+
+ // add context if required.
+ if (with_context) {
+ params.push_back(arg_context_ptr_);
+ }
+
+ // if the function has holder, add the holder pointer.
+ if (holder != nullptr) {
+ auto ptr = types->i64_constant((int64_t)holder);
+ params.push_back(ptr);
+ }
+
+ // build the function params, along with the validities.
+ for (auto& pair : args) {
+ // build value.
+ DexPtr value_expr = pair->value_expr();
+ value_expr->Accept(*this);
+ LValue& result_ref = *result();
+
+ // append all the parameters corresponding to this LValue.
+ result_ref.AppendFunctionParams(&params);
+
+ // build validity.
+ if (with_validity) {
+ llvm::Value* validity_expr = BuildCombinedValidity(pair->validity_exprs());
+ params.push_back(validity_expr);
+ }
+ }
+
+ return params;
+}
+
+// Bitwise-AND of a vector of bits to get the combined validity.
+llvm::Value* LLVMGenerator::Visitor::BuildCombinedValidity(const DexVector& validities) {
+ llvm::IRBuilder<>* builder = ir_builder();
+ LLVMTypes* types = generator_->types();
+
+ llvm::Value* isValid = types->true_constant();
+ for (auto& dex : validities) {
+ dex->Accept(*this);
+ isValid = builder->CreateAnd(isValid, result()->data(), "validityBitAnd");
+ }
+ ADD_VISITOR_TRACE("combined validity is %T", isValid);
+ return isValid;
+}
+
+llvm::Value* LLVMGenerator::Visitor::GetBufferReference(int idx, BufferType buffer_type,
+ FieldPtr field) {
+ llvm::IRBuilder<>* builder = ir_builder();
+
+ // Switch to the entry block to create a reference.
+ llvm::BasicBlock* saved_block = builder->GetInsertBlock();
+ builder->SetInsertPoint(entry_block_);
+
+ llvm::Value* slot_ref = nullptr;
+ switch (buffer_type) {
+ case kBufferTypeValidity:
+ slot_ref = generator_->GetValidityReference(arg_addrs_, idx, field);
+ break;
+
+ case kBufferTypeData:
+ slot_ref = generator_->GetDataReference(arg_addrs_, idx, field);
+ break;
+
+ case kBufferTypeOffsets:
+ slot_ref = generator_->GetOffsetsReference(arg_addrs_, idx, field);
+ break;
+ }
+
+ // Revert to the saved block.
+ builder->SetInsertPoint(saved_block);
+ return slot_ref;
+}
+
+llvm::Value* LLVMGenerator::Visitor::GetSliceOffset(int idx) {
+ return slice_offsets_[idx];
+}
+
+llvm::Value* LLVMGenerator::Visitor::GetLocalBitMapReference(int idx) {
+ llvm::IRBuilder<>* builder = ir_builder();
+
+ // Switch to the entry block to create a reference.
+ llvm::BasicBlock* saved_block = builder->GetInsertBlock();
+ builder->SetInsertPoint(entry_block_);
+
+ llvm::Value* slot_ref = generator_->GetLocalBitMapReference(arg_local_bitmaps_, idx);
+
+ // Revert to the saved block.
+ builder->SetInsertPoint(saved_block);
+ return slot_ref;
+}
+
+/// The local bitmap is pre-filled with 1s. Clear only if invalid.
+void LLVMGenerator::Visitor::ClearLocalBitMapIfNotValid(int local_bitmap_idx,
+ llvm::Value* is_valid) {
+ llvm::Value* slot_ref = GetLocalBitMapReference(local_bitmap_idx);
+ generator_->ClearPackedBitValueIfFalse(slot_ref, loop_var_, is_valid);
+}
+
+// Hooks for tracing/printfs.
+//
+// replace %T with the type-specific format specifier.
+// For some reason, float/double literals are getting lost when printing with the generic
+// printf. so, use a wrapper instead.
+std::string LLVMGenerator::ReplaceFormatInTrace(const std::string& in_msg,
+ llvm::Value* value,
+ std::string* print_fn) {
+ std::string msg = in_msg;
+ std::size_t pos = msg.find("%T");
+ if (pos == std::string::npos) {
+ DCHECK(0);
+ return msg;
+ }
+
+ llvm::Type* type = value->getType();
+ const char* fmt = "";
+ if (type->isIntegerTy(1) || type->isIntegerTy(8) || type->isIntegerTy(16) ||
+ type->isIntegerTy(32)) {
+ fmt = "%d";
+ } else if (type->isIntegerTy(64)) {
+ // bigint
+ fmt = "%lld";
+ } else if (type->isFloatTy()) {
+ // float
+ fmt = "%f";
+ *print_fn = "print_float";
+ } else if (type->isDoubleTy()) {
+ // float
+ fmt = "%lf";
+ *print_fn = "print_double";
+ } else if (type->isPointerTy()) {
+ // string
+ fmt = "%s";
+ } else {
+ DCHECK(0);
+ }
+ msg.replace(pos, 2, fmt);
+ return msg;
+}
+
+void LLVMGenerator::AddTrace(const std::string& msg, llvm::Value* value) {
+ if (!enable_ir_traces_) {
+ return;
+ }
+
+ std::string dmsg = "IR_TRACE:: " + msg + "\n";
+ std::string print_fn_name = "printf";
+ if (value != nullptr) {
+ dmsg = ReplaceFormatInTrace(dmsg, value, &print_fn_name);
+ }
+ trace_strings_.push_back(dmsg);
+
+ // cast this to an llvm pointer.
+ const char* str = trace_strings_.back().c_str();
+ llvm::Constant* str_int_cast = types()->i64_constant((int64_t)str);
+ llvm::Constant* str_ptr_cast =
+ llvm::ConstantExpr::getIntToPtr(str_int_cast, types()->i8_ptr_type());
+
+ std::vector<llvm::Value*> args;
+ args.push_back(str_ptr_cast);
+ if (value != nullptr) {
+ args.push_back(value);
+ }
+ AddFunctionCall(print_fn_name, types()->i32_type(), args);
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/llvm_generator.h b/src/arrow/cpp/src/gandiva/llvm_generator.h
new file mode 100644
index 000000000..ff6d84602
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/llvm_generator.h
@@ -0,0 +1,253 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/util/macros.h"
+
+#include "gandiva/annotator.h"
+#include "gandiva/compiled_expr.h"
+#include "gandiva/configuration.h"
+#include "gandiva/dex_visitor.h"
+#include "gandiva/engine.h"
+#include "gandiva/execution_context.h"
+#include "gandiva/function_registry.h"
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/llvm_types.h"
+#include "gandiva/lvalue.h"
+#include "gandiva/selection_vector.h"
+#include "gandiva/value_validity_pair.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+class FunctionHolder;
+
+/// Builds an LLVM module and generates code for the specified set of expressions.
+class GANDIVA_EXPORT LLVMGenerator {
+ public:
+ /// \brief Factory method to initialize the generator.
+ static Status Make(std::shared_ptr<Configuration> config,
+ std::unique_ptr<LLVMGenerator>* llvm_generator);
+
+ /// \brief Build the code for the expression trees for default mode. Each
+ /// element in the vector represents an expression tree
+ Status Build(const ExpressionVector& exprs, SelectionVector::Mode mode);
+
+ /// \brief Build the code for the expression trees for default mode. Each
+ /// element in the vector represents an expression tree
+ Status Build(const ExpressionVector& exprs) {
+ return Build(exprs, SelectionVector::Mode::MODE_NONE);
+ }
+
+ /// \brief Execute the built expression against the provided arguments for
+ /// default mode.
+ Status Execute(const arrow::RecordBatch& record_batch,
+ const ArrayDataVector& output_vector);
+
+ /// \brief Execute the built expression against the provided arguments for
+ /// all modes. Only works on the records specified in the selection_vector.
+ Status Execute(const arrow::RecordBatch& record_batch,
+ const SelectionVector* selection_vector,
+ const ArrayDataVector& output_vector);
+
+ SelectionVector::Mode selection_vector_mode() { return selection_vector_mode_; }
+ LLVMTypes* types() { return engine_->types(); }
+ llvm::Module* module() { return engine_->module(); }
+ std::string DumpIR() { return engine_->DumpIR(); }
+
+ private:
+ LLVMGenerator();
+
+ FRIEND_TEST(TestLLVMGenerator, VerifyPCFunctions);
+ FRIEND_TEST(TestLLVMGenerator, TestAdd);
+ FRIEND_TEST(TestLLVMGenerator, TestNullInternal);
+
+ llvm::LLVMContext* context() { return engine_->context(); }
+ llvm::IRBuilder<>* ir_builder() { return engine_->ir_builder(); }
+
+ /// Visitor to generate the code for a decomposed expression.
+ class Visitor : public DexVisitor {
+ public:
+ Visitor(LLVMGenerator* generator, llvm::Function* function,
+ llvm::BasicBlock* entry_block, llvm::Value* arg_addrs,
+ llvm::Value* arg_local_bitmaps, std::vector<llvm::Value*> slice_offsets,
+ llvm::Value* arg_context_ptr, llvm::Value* loop_var);
+
+ void Visit(const VectorReadValidityDex& dex) override;
+ void Visit(const VectorReadFixedLenValueDex& dex) override;
+ void Visit(const VectorReadVarLenValueDex& dex) override;
+ void Visit(const LocalBitMapValidityDex& dex) override;
+ void Visit(const TrueDex& dex) override;
+ void Visit(const FalseDex& dex) override;
+ void Visit(const LiteralDex& dex) override;
+ void Visit(const NonNullableFuncDex& dex) override;
+ void Visit(const NullableNeverFuncDex& dex) override;
+ void Visit(const NullableInternalFuncDex& dex) override;
+ void Visit(const IfDex& dex) override;
+ void Visit(const BooleanAndDex& dex) override;
+ void Visit(const BooleanOrDex& dex) override;
+ void Visit(const InExprDexBase<int32_t>& dex) override;
+ void Visit(const InExprDexBase<int64_t>& dex) override;
+ void Visit(const InExprDexBase<float>& dex) override;
+ void Visit(const InExprDexBase<double>& dex) override;
+ void Visit(const InExprDexBase<gandiva::DecimalScalar128>& dex) override;
+ void Visit(const InExprDexBase<std::string>& dex) override;
+ template <typename Type>
+ void VisitInExpression(const InExprDexBase<Type>& dex);
+
+ LValuePtr result() { return result_; }
+
+ bool has_arena_allocs() { return has_arena_allocs_; }
+
+ private:
+ enum BufferType { kBufferTypeValidity = 0, kBufferTypeData, kBufferTypeOffsets };
+
+ llvm::IRBuilder<>* ir_builder() { return generator_->ir_builder(); }
+ llvm::Module* module() { return generator_->module(); }
+
+ // Generate the code to build the combined validity (bitwise and) from the
+ // vector of validities.
+ llvm::Value* BuildCombinedValidity(const DexVector& validities);
+
+ // Generate the code to build the validity and the value for the given pair.
+ LValuePtr BuildValueAndValidity(const ValueValidityPair& pair);
+
+ // Generate code to build the params.
+ std::vector<llvm::Value*> BuildParams(FunctionHolder* holder,
+ const ValueValidityPairVector& args,
+ bool with_validity, bool with_context);
+
+ // Generate code to onvoke a function call.
+ LValuePtr BuildFunctionCall(const NativeFunction* func, DataTypePtr arrow_return_type,
+ std::vector<llvm::Value*>* params);
+
+ // Generate code for an if-else condition.
+ LValuePtr BuildIfElse(llvm::Value* condition, std::function<LValuePtr()> then_func,
+ std::function<LValuePtr()> else_func,
+ DataTypePtr arrow_return_type);
+
+ // Switch to the entry_block and get reference of the validity/value/offsets buffer
+ llvm::Value* GetBufferReference(int idx, BufferType buffer_type, FieldPtr field);
+
+ // Get the slice offset of the validity/value/offsets buffer
+ llvm::Value* GetSliceOffset(int idx);
+
+ // Switch to the entry_block and get reference to the local bitmap.
+ llvm::Value* GetLocalBitMapReference(int idx);
+
+ // Clear the bit in the local bitmap, if is_valid is 'false'
+ void ClearLocalBitMapIfNotValid(int local_bitmap_idx, llvm::Value* is_valid);
+
+ LLVMGenerator* generator_;
+ LValuePtr result_;
+ llvm::Function* function_;
+ llvm::BasicBlock* entry_block_;
+ llvm::Value* arg_addrs_;
+ llvm::Value* arg_local_bitmaps_;
+ std::vector<llvm::Value*> slice_offsets_;
+ llvm::Value* arg_context_ptr_;
+ llvm::Value* loop_var_;
+ bool has_arena_allocs_;
+ };
+
+ // Generate the code for one expression for default mode, with the output of
+ // the expression going to 'output'.
+ Status Add(const ExpressionPtr expr, const FieldDescriptorPtr output);
+
+ /// Generate code to load the vector at specified index in the 'arg_addrs' array.
+ llvm::Value* LoadVectorAtIndex(llvm::Value* arg_addrs, int idx,
+ const std::string& name);
+
+ /// Generate code to load the vector at specified index and cast it as bitmap.
+ llvm::Value* GetValidityReference(llvm::Value* arg_addrs, int idx, FieldPtr field);
+
+ /// Generate code to load the vector at specified index and cast it as data array.
+ llvm::Value* GetDataReference(llvm::Value* arg_addrs, int idx, FieldPtr field);
+
+ /// Generate code to load the vector at specified index and cast it as offsets array.
+ llvm::Value* GetOffsetsReference(llvm::Value* arg_addrs, int idx, FieldPtr field);
+
+ /// Generate code to load the vector at specified index and cast it as buffer pointer.
+ llvm::Value* GetDataBufferPtrReference(llvm::Value* arg_addrs, int idx, FieldPtr field);
+
+ /// Generate code for the value array of one expression.
+ Status CodeGenExprValue(DexPtr value_expr, int num_buffers, FieldDescriptorPtr output,
+ int suffix_idx, llvm::Function** fn,
+ SelectionVector::Mode selection_vector_mode);
+
+ /// Generate code to load the local bitmap specified index and cast it as bitmap.
+ llvm::Value* GetLocalBitMapReference(llvm::Value* arg_bitmaps, int idx);
+
+ /// Generate code to get the bit value at 'position' in the bitmap.
+ llvm::Value* GetPackedBitValue(llvm::Value* bitmap, llvm::Value* position);
+
+ /// Generate code to get the bit value at 'position' in the validity bitmap.
+ llvm::Value* GetPackedValidityBitValue(llvm::Value* bitmap, llvm::Value* position);
+
+ /// Generate code to set the bit value at 'position' in the bitmap to 'value'.
+ void SetPackedBitValue(llvm::Value* bitmap, llvm::Value* position, llvm::Value* value);
+
+ /// Generate code to clear the bit value at 'position' in the bitmap if 'value'
+ /// is false.
+ void ClearPackedBitValueIfFalse(llvm::Value* bitmap, llvm::Value* position,
+ llvm::Value* value);
+
+ // Generate code to build a DecimalLValue with specified value/precision/scale.
+ std::shared_ptr<DecimalLValue> BuildDecimalLValue(llvm::Value* value,
+ DataTypePtr arrow_type);
+
+ /// Generate code to make a function call (to a pre-compiled IR function) which takes
+ /// 'args' and has a return type 'ret_type'.
+ llvm::Value* AddFunctionCall(const std::string& full_name, llvm::Type* ret_type,
+ const std::vector<llvm::Value*>& args);
+
+ /// Compute the result bitmap for the expression.
+ ///
+ /// \param[in] compiled_expr the compiled expression (includes the bitmap indices to be
+ /// used for computing the validity bitmap of the result).
+ /// \param[in] eval_batch (includes input/output buffer addresses)
+ /// \param[in] selection_vector the list of selected positions
+ void ComputeBitMapsForExpr(const CompiledExpr& compiled_expr,
+ const EvalBatch& eval_batch,
+ const SelectionVector* selection_vector);
+
+ /// Replace the %T in the trace msg with the correct type corresponding to 'type'
+ /// eg. %d for int32, %ld for int64, ..
+ std::string ReplaceFormatInTrace(const std::string& msg, llvm::Value* value,
+ std::string* print_fn);
+
+ /// Generate the code to print a trace msg with one optional argument (%T)
+ void AddTrace(const std::string& msg, llvm::Value* value = NULLPTR);
+
+ std::unique_ptr<Engine> engine_;
+ std::vector<std::unique_ptr<CompiledExpr>> compiled_exprs_;
+ FunctionRegistry function_registry_;
+ Annotator annotator_;
+ SelectionVector::Mode selection_vector_mode_;
+
+ // used for debug
+ bool enable_ir_traces_;
+ std::vector<std::string> trace_strings_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/llvm_generator_test.cc b/src/arrow/cpp/src/gandiva/llvm_generator_test.cc
new file mode 100644
index 000000000..bdc3b0051
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/llvm_generator_test.cc
@@ -0,0 +1,116 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/llvm_generator.h"
+
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "gandiva/configuration.h"
+#include "gandiva/dex.h"
+#include "gandiva/expression.h"
+#include "gandiva/func_descriptor.h"
+#include "gandiva/function_registry.h"
+#include "gandiva/tests/test_util.h"
+
+namespace gandiva {
+
+typedef int64_t (*add_vector_func_t)(int64_t* elements, int nelements);
+
+class TestLLVMGenerator : public ::testing::Test {
+ protected:
+ FunctionRegistry registry_;
+};
+
+// Verify that a valid pc function exists for every function in the registry.
+TEST_F(TestLLVMGenerator, VerifyPCFunctions) {
+ std::unique_ptr<LLVMGenerator> generator;
+ ASSERT_OK(LLVMGenerator::Make(TestConfiguration(), &generator));
+
+ llvm::Module* module = generator->module();
+ for (auto& iter : registry_) {
+ EXPECT_NE(module->getFunction(iter.pc_name()), nullptr);
+ }
+}
+
+TEST_F(TestLLVMGenerator, TestAdd) {
+ // Setup LLVM generator to do an arithmetic add of two vectors
+ std::unique_ptr<LLVMGenerator> generator;
+ ASSERT_OK(LLVMGenerator::Make(TestConfiguration(), &generator));
+ Annotator annotator;
+
+ auto field0 = std::make_shared<arrow::Field>("f0", arrow::int32());
+ auto desc0 = annotator.CheckAndAddInputFieldDescriptor(field0);
+ auto validity_dex0 = std::make_shared<VectorReadValidityDex>(desc0);
+ auto value_dex0 = std::make_shared<VectorReadFixedLenValueDex>(desc0);
+ auto pair0 = std::make_shared<ValueValidityPair>(validity_dex0, value_dex0);
+
+ auto field1 = std::make_shared<arrow::Field>("f1", arrow::int32());
+ auto desc1 = annotator.CheckAndAddInputFieldDescriptor(field1);
+ auto validity_dex1 = std::make_shared<VectorReadValidityDex>(desc1);
+ auto value_dex1 = std::make_shared<VectorReadFixedLenValueDex>(desc1);
+ auto pair1 = std::make_shared<ValueValidityPair>(validity_dex1, value_dex1);
+
+ DataTypeVector params{arrow::int32(), arrow::int32()};
+ auto func_desc = std::make_shared<FuncDescriptor>("add", params, arrow::int32());
+ FunctionSignature signature(func_desc->name(), func_desc->params(),
+ func_desc->return_type());
+ const NativeFunction* native_func =
+ generator->function_registry_.LookupSignature(signature);
+
+ std::vector<ValueValidityPairPtr> pairs{pair0, pair1};
+ auto func_dex = std::make_shared<NonNullableFuncDex>(func_desc, native_func,
+ FunctionHolderPtr(nullptr), pairs);
+
+ auto field_sum = std::make_shared<arrow::Field>("out", arrow::int32());
+ auto desc_sum = annotator.CheckAndAddInputFieldDescriptor(field_sum);
+
+ llvm::Function* ir_func = nullptr;
+
+ ASSERT_OK(generator->CodeGenExprValue(func_dex, 4, desc_sum, 0, &ir_func,
+ SelectionVector::MODE_NONE));
+
+ ASSERT_OK(generator->engine_->FinalizeModule());
+ auto ir = generator->engine_->DumpIR();
+ EXPECT_THAT(ir, testing::HasSubstr("vector.body"));
+
+ EvalFunc eval_func = (EvalFunc)generator->engine_->CompiledFunction(ir_func);
+
+ constexpr size_t kNumRecords = 4;
+ std::array<uint32_t, kNumRecords> a0{1, 2, 3, 4};
+ std::array<uint32_t, kNumRecords> a1{5, 6, 7, 8};
+ uint64_t in_bitmap = 0xffffffffffffffffull;
+
+ std::array<uint32_t, kNumRecords> out{0, 0, 0, 0};
+ uint64_t out_bitmap = 0;
+
+ std::array<uint8_t*, 6> addrs{
+ reinterpret_cast<uint8_t*>(a0.data()), reinterpret_cast<uint8_t*>(&in_bitmap),
+ reinterpret_cast<uint8_t*>(a1.data()), reinterpret_cast<uint8_t*>(&in_bitmap),
+ reinterpret_cast<uint8_t*>(out.data()), reinterpret_cast<uint8_t*>(&out_bitmap),
+ };
+ std::array<int64_t, 6> addr_offsets{0, 0, 0, 0, 0, 0};
+ eval_func(addrs.data(), addr_offsets.data(), nullptr, nullptr,
+ 0 /* dummy context ptr */, kNumRecords);
+
+ EXPECT_THAT(out, testing::ElementsAre(6, 8, 10, 12));
+ EXPECT_EQ(out_bitmap, 0ULL);
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/llvm_includes.h b/src/arrow/cpp/src/gandiva/llvm_includes.h
new file mode 100644
index 000000000..37f915eb5
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/llvm_includes.h
@@ -0,0 +1,56 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4141)
+#pragma warning(disable : 4146)
+#pragma warning(disable : 4244)
+#pragma warning(disable : 4267)
+#pragma warning(disable : 4291)
+#pragma warning(disable : 4624)
+#endif
+
+#include <llvm/ExecutionEngine/ExecutionEngine.h>
+#include <llvm/IR/IRBuilder.h>
+#include <llvm/IR/LLVMContext.h>
+#include <llvm/IR/Module.h>
+
+#if LLVM_VERSION_MAJOR >= 10
+#define LLVM_ALIGN(alignment) (llvm::Align((alignment)))
+#else
+#define LLVM_ALIGN(alignment) (alignment)
+#endif
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+// Workaround for deprecated builder methods as of LLVM 13: ARROW-14363
+inline llvm::Value* CreateGEP(llvm::IRBuilder<>* builder, llvm::Value* Ptr,
+ llvm::ArrayRef<llvm::Value*> IdxList,
+ const llvm::Twine& Name = "") {
+ return builder->CreateGEP(Ptr->getType()->getScalarType()->getPointerElementType(), Ptr,
+ IdxList, Name);
+}
+
+inline llvm::LoadInst* CreateLoad(llvm::IRBuilder<>* builder, llvm::Value* Ptr,
+ const llvm::Twine& Name = "") {
+ return builder->CreateLoad(Ptr->getType()->getPointerElementType(), Ptr, Name);
+}
diff --git a/src/arrow/cpp/src/gandiva/llvm_types.cc b/src/arrow/cpp/src/gandiva/llvm_types.cc
new file mode 100644
index 000000000..de322a8c0
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/llvm_types.cc
@@ -0,0 +1,48 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/llvm_types.h"
+
+namespace gandiva {
+
+// LLVM doesn't distinguish between signed and unsigned types.
+
+LLVMTypes::LLVMTypes(llvm::LLVMContext& context) : context_(context) {
+ arrow_id_to_llvm_type_map_ = {{arrow::Type::type::BOOL, i1_type()},
+ {arrow::Type::type::INT8, i8_type()},
+ {arrow::Type::type::INT16, i16_type()},
+ {arrow::Type::type::INT32, i32_type()},
+ {arrow::Type::type::INT64, i64_type()},
+ {arrow::Type::type::UINT8, i8_type()},
+ {arrow::Type::type::UINT16, i16_type()},
+ {arrow::Type::type::UINT32, i32_type()},
+ {arrow::Type::type::UINT64, i64_type()},
+ {arrow::Type::type::FLOAT, float_type()},
+ {arrow::Type::type::DOUBLE, double_type()},
+ {arrow::Type::type::DATE32, i32_type()},
+ {arrow::Type::type::DATE64, i64_type()},
+ {arrow::Type::type::TIME32, i32_type()},
+ {arrow::Type::type::TIME64, i64_type()},
+ {arrow::Type::type::TIMESTAMP, i64_type()},
+ {arrow::Type::type::STRING, i8_ptr_type()},
+ {arrow::Type::type::BINARY, i8_ptr_type()},
+ {arrow::Type::type::DECIMAL, i128_type()},
+ {arrow::Type::type::INTERVAL_MONTHS, i32_type()},
+ {arrow::Type::type::INTERVAL_DAY_TIME, i64_type()}};
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/llvm_types.h b/src/arrow/cpp/src/gandiva/llvm_types.h
new file mode 100644
index 000000000..d6f095271
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/llvm_types.h
@@ -0,0 +1,130 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <map>
+#include <vector>
+
+#include "arrow/util/logging.h"
+#include "gandiva/arrow.h"
+#include "gandiva/llvm_includes.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// \brief Holder for llvm types, and mappings between arrow types and llvm types.
+class GANDIVA_EXPORT LLVMTypes {
+ public:
+ explicit LLVMTypes(llvm::LLVMContext& context);
+
+ llvm::Type* void_type() { return llvm::Type::getVoidTy(context_); }
+
+ llvm::Type* i1_type() { return llvm::Type::getInt1Ty(context_); }
+
+ llvm::Type* i8_type() { return llvm::Type::getInt8Ty(context_); }
+
+ llvm::Type* i16_type() { return llvm::Type::getInt16Ty(context_); }
+
+ llvm::Type* i32_type() { return llvm::Type::getInt32Ty(context_); }
+
+ llvm::Type* i64_type() { return llvm::Type::getInt64Ty(context_); }
+
+ llvm::Type* i128_type() { return llvm::Type::getInt128Ty(context_); }
+
+ llvm::StructType* i128_split_type() {
+ // struct with high/low bits (see decimal_ops.cc:DecimalSplit)
+ return llvm::StructType::get(context_, {i64_type(), i64_type()}, false);
+ }
+
+ llvm::Type* float_type() { return llvm::Type::getFloatTy(context_); }
+
+ llvm::Type* double_type() { return llvm::Type::getDoubleTy(context_); }
+
+ llvm::PointerType* ptr_type(llvm::Type* type) { return type->getPointerTo(); }
+
+ llvm::PointerType* i8_ptr_type() { return ptr_type(i8_type()); }
+
+ llvm::PointerType* i32_ptr_type() { return ptr_type(i32_type()); }
+
+ llvm::PointerType* i64_ptr_type() { return ptr_type(i64_type()); }
+
+ llvm::PointerType* i128_ptr_type() { return ptr_type(i128_type()); }
+
+ template <typename ctype, size_t N = (sizeof(ctype) * CHAR_BIT)>
+ llvm::Constant* int_constant(ctype val) {
+ return llvm::ConstantInt::get(context_, llvm::APInt(N, val));
+ }
+
+ llvm::Constant* i1_constant(bool val) { return int_constant<bool, 1>(val); }
+ llvm::Constant* i8_constant(int8_t val) { return int_constant(val); }
+ llvm::Constant* i16_constant(int16_t val) { return int_constant(val); }
+ llvm::Constant* i32_constant(int32_t val) { return int_constant(val); }
+ llvm::Constant* i64_constant(int64_t val) { return int_constant(val); }
+ llvm::Constant* i128_constant(int64_t val) { return int_constant<int64_t, 128>(val); }
+
+ llvm::Constant* true_constant() { return i1_constant(true); }
+ llvm::Constant* false_constant() { return i1_constant(false); }
+
+ llvm::Constant* i128_zero() { return i128_constant(0); }
+ llvm::Constant* i128_one() { return i128_constant(1); }
+
+ llvm::Constant* float_constant(float val) {
+ return llvm::ConstantFP::get(float_type(), val);
+ }
+
+ llvm::Constant* double_constant(double val) {
+ return llvm::ConstantFP::get(double_type(), val);
+ }
+
+ llvm::Constant* NullConstant(llvm::Type* type) {
+ if (type->isIntegerTy()) {
+ return llvm::ConstantInt::get(type, 0);
+ } else if (type->isFloatingPointTy()) {
+ return llvm::ConstantFP::get(type, 0);
+ } else {
+ DCHECK(type->isPointerTy());
+ return llvm::ConstantPointerNull::getNullValue(type);
+ }
+ }
+
+ /// For a given data type, find the ir type used for the data vector slot.
+ llvm::Type* DataVecType(const DataTypePtr& data_type) {
+ return IRType(data_type->id());
+ }
+
+ /// For a given minor type, find the corresponding ir type.
+ llvm::Type* IRType(arrow::Type::type arrow_type) {
+ auto found = arrow_id_to_llvm_type_map_.find(arrow_type);
+ return (found == arrow_id_to_llvm_type_map_.end()) ? NULL : found->second;
+ }
+
+ std::vector<arrow::Type::type> GetSupportedArrowTypes() {
+ std::vector<arrow::Type::type> retval;
+ for (auto const& element : arrow_id_to_llvm_type_map_) {
+ retval.push_back(element.first);
+ }
+ return retval;
+ }
+
+ private:
+ std::map<arrow::Type::type, llvm::Type*> arrow_id_to_llvm_type_map_;
+
+ llvm::LLVMContext& context_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/llvm_types_test.cc b/src/arrow/cpp/src/gandiva/llvm_types_test.cc
new file mode 100644
index 000000000..666968306
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/llvm_types_test.cc
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/llvm_types.h"
+
+#include <gtest/gtest.h>
+
+namespace gandiva {
+
+class TestLLVMTypes : public ::testing::Test {
+ protected:
+ virtual void SetUp() { types_ = new LLVMTypes(context_); }
+ virtual void TearDown() { delete types_; }
+
+ llvm::LLVMContext context_;
+ LLVMTypes* types_;
+};
+
+TEST_F(TestLLVMTypes, TestFound) {
+ EXPECT_EQ(types_->IRType(arrow::Type::BOOL), types_->i1_type());
+ EXPECT_EQ(types_->IRType(arrow::Type::INT32), types_->i32_type());
+ EXPECT_EQ(types_->IRType(arrow::Type::INT64), types_->i64_type());
+ EXPECT_EQ(types_->IRType(arrow::Type::FLOAT), types_->float_type());
+ EXPECT_EQ(types_->IRType(arrow::Type::DOUBLE), types_->double_type());
+ EXPECT_EQ(types_->IRType(arrow::Type::DATE64), types_->i64_type());
+ EXPECT_EQ(types_->IRType(arrow::Type::TIME64), types_->i64_type());
+ EXPECT_EQ(types_->IRType(arrow::Type::TIMESTAMP), types_->i64_type());
+
+ EXPECT_EQ(types_->DataVecType(arrow::boolean()), types_->i1_type());
+ EXPECT_EQ(types_->DataVecType(arrow::int32()), types_->i32_type());
+ EXPECT_EQ(types_->DataVecType(arrow::int64()), types_->i64_type());
+ EXPECT_EQ(types_->DataVecType(arrow::float32()), types_->float_type());
+ EXPECT_EQ(types_->DataVecType(arrow::float64()), types_->double_type());
+ EXPECT_EQ(types_->DataVecType(arrow::date64()), types_->i64_type());
+ EXPECT_EQ(types_->DataVecType(arrow::time64(arrow::TimeUnit::MICRO)),
+ types_->i64_type());
+ EXPECT_EQ(types_->DataVecType(arrow::timestamp(arrow::TimeUnit::MILLI)),
+ types_->i64_type());
+}
+
+TEST_F(TestLLVMTypes, TestNotFound) {
+ EXPECT_EQ(types_->IRType(arrow::Type::SPARSE_UNION), nullptr);
+ EXPECT_EQ(types_->IRType(arrow::Type::DENSE_UNION), nullptr);
+ EXPECT_EQ(types_->DataVecType(arrow::null()), nullptr);
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/local_bitmaps_holder.h b/src/arrow/cpp/src/gandiva/local_bitmaps_holder.h
new file mode 100644
index 000000000..a172fb973
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/local_bitmaps_holder.h
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include <arrow/util/logging.h>
+#include "gandiva/arrow.h"
+#include "gandiva/gandiva_aliases.h"
+
+namespace gandiva {
+
+/// \brief The buffers corresponding to one batch of records, used for
+/// expression evaluation.
+class LocalBitMapsHolder {
+ public:
+ LocalBitMapsHolder(int64_t num_records, int num_local_bitmaps);
+
+ int GetNumLocalBitMaps() const { return static_cast<int>(local_bitmaps_vec_.size()); }
+
+ int64_t GetLocalBitMapSize() const { return local_bitmap_size_; }
+
+ uint8_t** GetLocalBitMapArray() const { return local_bitmaps_array_.get(); }
+
+ uint8_t* GetLocalBitMap(int idx) const {
+ DCHECK(idx <= GetNumLocalBitMaps());
+ return local_bitmaps_array_.get()[idx];
+ }
+
+ private:
+ /// number of records in the current batch.
+ int64_t num_records_;
+
+ /// A container of 'local_bitmaps_', each sized to accommodate 'num_records'.
+ std::vector<std::unique_ptr<uint8_t[]>> local_bitmaps_vec_;
+
+ /// An array of the local bitmaps.
+ std::unique_ptr<uint8_t*[]> local_bitmaps_array_;
+
+ int64_t local_bitmap_size_;
+};
+
+inline LocalBitMapsHolder::LocalBitMapsHolder(int64_t num_records, int num_local_bitmaps)
+ : num_records_(num_records) {
+ // alloc an array for the pointers to the bitmaps.
+ if (num_local_bitmaps > 0) {
+ local_bitmaps_array_.reset(new uint8_t*[num_local_bitmaps]);
+ }
+
+ // 64-bit aligned bitmaps.
+ int64_t roundUp64Multiple = (num_records_ + 63) >> 6;
+ local_bitmap_size_ = roundUp64Multiple * 8;
+
+ // Alloc 'num_local_bitmaps_' number of bitmaps, each of capacity 'num_records_'.
+ for (int i = 0; i < num_local_bitmaps; ++i) {
+ // TODO : round-up to a slab friendly multiple.
+ std::unique_ptr<uint8_t[]> bitmap(new uint8_t[local_bitmap_size_]);
+
+ // keep pointer to the bitmap in the array.
+ (local_bitmaps_array_.get())[i] = bitmap.get();
+
+ // pre-fill with 1s (assuming that the probability of is_valid is higher).
+ memset(bitmap.get(), 0xff, local_bitmap_size_);
+ local_bitmaps_vec_.push_back(std::move(bitmap));
+ }
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/lvalue.h b/src/arrow/cpp/src/gandiva/lvalue.h
new file mode 100644
index 000000000..df292855b
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/lvalue.h
@@ -0,0 +1,77 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <vector>
+
+#include "arrow/util/macros.h"
+
+#include "arrow/util/logging.h"
+#include "gandiva/llvm_includes.h"
+
+namespace gandiva {
+
+/// \brief Tracks validity/value builders in LLVM.
+class GANDIVA_EXPORT LValue {
+ public:
+ explicit LValue(llvm::Value* data, llvm::Value* length = NULLPTR,
+ llvm::Value* validity = NULLPTR)
+ : data_(data), length_(length), validity_(validity) {}
+ virtual ~LValue() = default;
+
+ llvm::Value* data() { return data_; }
+ llvm::Value* length() { return length_; }
+ llvm::Value* validity() { return validity_; }
+
+ void set_data(llvm::Value* data) { data_ = data; }
+
+ // Append the params required when passing this as a function parameter.
+ virtual void AppendFunctionParams(std::vector<llvm::Value*>* params) {
+ params->push_back(data_);
+ if (length_ != NULLPTR) {
+ params->push_back(length_);
+ }
+ }
+
+ private:
+ llvm::Value* data_;
+ llvm::Value* length_;
+ llvm::Value* validity_;
+};
+
+class GANDIVA_EXPORT DecimalLValue : public LValue {
+ public:
+ DecimalLValue(llvm::Value* data, llvm::Value* validity, llvm::Value* precision,
+ llvm::Value* scale)
+ : LValue(data, NULLPTR, validity), precision_(precision), scale_(scale) {}
+
+ llvm::Value* precision() { return precision_; }
+ llvm::Value* scale() { return scale_; }
+
+ void AppendFunctionParams(std::vector<llvm::Value*>* params) override {
+ LValue::AppendFunctionParams(params);
+ params->push_back(precision_);
+ params->push_back(scale_);
+ }
+
+ private:
+ llvm::Value* precision_;
+ llvm::Value* scale_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/make_precompiled_bitcode.py b/src/arrow/cpp/src/gandiva/make_precompiled_bitcode.py
new file mode 100644
index 000000000..97d96f8a8
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/make_precompiled_bitcode.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+
+marker = b"<DATA_CHARS>"
+
+def expand(data):
+ """
+ Expand *data* as a initializer list of hexadecimal char escapes.
+ """
+ expanded_data = ", ".join([hex(c) for c in bytearray(data)])
+ return expanded_data.encode('ascii')
+
+
+def apply_template(template, data):
+ if template.count(marker) != 1:
+ raise ValueError("Invalid template")
+ return template.replace(marker, expand(data))
+
+
+if __name__ == "__main__":
+ if len(sys.argv) != 4:
+ raise ValueError("Usage: {0} <template file> <data file> "
+ "<output file>".format(sys.argv[0]))
+ with open(sys.argv[1], "rb") as f:
+ template = f.read()
+ with open(sys.argv[2], "rb") as f:
+ data = f.read()
+
+ expanded_data = apply_template(template, data)
+ with open(sys.argv[3], "wb") as f:
+ f.write(expanded_data)
diff --git a/src/arrow/cpp/src/gandiva/native_function.h b/src/arrow/cpp/src/gandiva/native_function.h
new file mode 100644
index 000000000..1268a2567
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/native_function.h
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gandiva/arrow.h"
+#include "gandiva/function_signature.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+enum ResultNullableType {
+ /// result validity is an intersection of the validity of the children.
+ kResultNullIfNull,
+ /// result is always valid.
+ kResultNullNever,
+ /// result validity depends on some internal logic.
+ kResultNullInternal,
+};
+
+/// \brief Holder for the mapping from a function in an expression to a
+/// precompiled function.
+class GANDIVA_EXPORT NativeFunction {
+ public:
+ // function attributes.
+ static constexpr int32_t kNeedsContext = (1 << 1);
+ static constexpr int32_t kNeedsFunctionHolder = (1 << 2);
+ static constexpr int32_t kCanReturnErrors = (1 << 3);
+
+ const std::vector<FunctionSignature>& signatures() const { return signatures_; }
+ std::string pc_name() const { return pc_name_; }
+ ResultNullableType result_nullable_type() const { return result_nullable_type_; }
+
+ bool NeedsContext() const { return (flags_ & kNeedsContext) != 0; }
+ bool NeedsFunctionHolder() const { return (flags_ & kNeedsFunctionHolder) != 0; }
+ bool CanReturnErrors() const { return (flags_ & kCanReturnErrors) != 0; }
+
+ NativeFunction(const std::string& base_name, const std::vector<std::string>& aliases,
+ const DataTypeVector& param_types, DataTypePtr ret_type,
+ const ResultNullableType& result_nullable_type,
+ const std::string& pc_name, int32_t flags = 0)
+ : signatures_(),
+ flags_(flags),
+ result_nullable_type_(result_nullable_type),
+ pc_name_(pc_name) {
+ signatures_.push_back(FunctionSignature(base_name, param_types, ret_type));
+ for (auto& func_name : aliases) {
+ signatures_.push_back(FunctionSignature(func_name, param_types, ret_type));
+ }
+ }
+
+ private:
+ std::vector<FunctionSignature> signatures_;
+
+ /// attributes
+ int32_t flags_;
+ ResultNullableType result_nullable_type_;
+
+ /// pre-compiled function name.
+ std::string pc_name_;
+};
+
+} // end namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/node.h b/src/arrow/cpp/src/gandiva/node.h
new file mode 100644
index 000000000..20807d4a0
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/node.h
@@ -0,0 +1,299 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <sstream>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "arrow/status.h"
+
+#include "gandiva/arrow.h"
+#include "gandiva/func_descriptor.h"
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/literal_holder.h"
+#include "gandiva/node_visitor.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// \brief Represents a node in the expression tree. Validity and value are
+/// in a joined state.
+class GANDIVA_EXPORT Node {
+ public:
+ explicit Node(DataTypePtr return_type) : return_type_(return_type) {}
+
+ virtual ~Node() = default;
+
+ const DataTypePtr& return_type() const { return return_type_; }
+
+ /// Derived classes should simply invoke the Visit api of the visitor.
+ virtual Status Accept(NodeVisitor& visitor) const = 0;
+
+ virtual std::string ToString() const = 0;
+
+ protected:
+ DataTypePtr return_type_;
+};
+
+/// \brief Node in the expression tree, representing a literal.
+class GANDIVA_EXPORT LiteralNode : public Node {
+ public:
+ LiteralNode(DataTypePtr type, const LiteralHolder& holder, bool is_null)
+ : Node(type), holder_(holder), is_null_(is_null) {}
+
+ Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); }
+
+ const LiteralHolder& holder() const { return holder_; }
+
+ bool is_null() const { return is_null_; }
+
+ std::string ToString() const override {
+ std::stringstream ss;
+ ss << "(const " << return_type()->ToString() << ") ";
+ if (is_null()) {
+ ss << std::string("null");
+ return ss.str();
+ }
+
+ ss << gandiva::ToString(holder_);
+ // The default formatter prints in decimal can cause a loss in precision. so,
+ // print in hex. Can't use hexfloat since gcc 4.9 doesn't support it.
+ if (return_type()->id() == arrow::Type::DOUBLE) {
+ double dvalue = arrow::util::get<double>(holder_);
+ uint64_t bits;
+ memcpy(&bits, &dvalue, sizeof(bits));
+ ss << " raw(" << std::hex << bits << ")";
+ } else if (return_type()->id() == arrow::Type::FLOAT) {
+ float fvalue = arrow::util::get<float>(holder_);
+ uint32_t bits;
+ memcpy(&bits, &fvalue, sizeof(bits));
+ ss << " raw(" << std::hex << bits << ")";
+ }
+ return ss.str();
+ }
+
+ private:
+ LiteralHolder holder_;
+ bool is_null_;
+};
+
+/// \brief Node in the expression tree, representing an arrow field.
+class GANDIVA_EXPORT FieldNode : public Node {
+ public:
+ explicit FieldNode(FieldPtr field) : Node(field->type()), field_(field) {}
+
+ Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); }
+
+ const FieldPtr& field() const { return field_; }
+
+ std::string ToString() const override {
+ return "(" + field()->type()->ToString() + ") " + field()->name();
+ }
+
+ private:
+ FieldPtr field_;
+};
+
+/// \brief Node in the expression tree, representing a function.
+class GANDIVA_EXPORT FunctionNode : public Node {
+ public:
+ FunctionNode(const std::string& name, const NodeVector& children, DataTypePtr retType);
+
+ Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); }
+
+ const FuncDescriptorPtr& descriptor() const { return descriptor_; }
+ const NodeVector& children() const { return children_; }
+
+ std::string ToString() const override {
+ std::stringstream ss;
+ ss << descriptor()->return_type()->ToString() << " " << descriptor()->name() << "(";
+ bool skip_comma = true;
+ for (auto& child : children()) {
+ if (skip_comma) {
+ ss << child->ToString();
+ skip_comma = false;
+ } else {
+ ss << ", " << child->ToString();
+ }
+ }
+ ss << ")";
+ return ss.str();
+ }
+
+ private:
+ FuncDescriptorPtr descriptor_;
+ NodeVector children_;
+};
+
+inline FunctionNode::FunctionNode(const std::string& name, const NodeVector& children,
+ DataTypePtr return_type)
+ : Node(return_type), children_(children) {
+ DataTypeVector param_types;
+ for (auto& child : children) {
+ param_types.push_back(child->return_type());
+ }
+
+ descriptor_ = FuncDescriptorPtr(new FuncDescriptor(name, param_types, return_type));
+}
+
+/// \brief Node in the expression tree, representing an if-else expression.
+class GANDIVA_EXPORT IfNode : public Node {
+ public:
+ IfNode(NodePtr condition, NodePtr then_node, NodePtr else_node, DataTypePtr result_type)
+ : Node(result_type),
+ condition_(condition),
+ then_node_(then_node),
+ else_node_(else_node) {}
+
+ Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); }
+
+ const NodePtr& condition() const { return condition_; }
+ const NodePtr& then_node() const { return then_node_; }
+ const NodePtr& else_node() const { return else_node_; }
+
+ std::string ToString() const override {
+ std::stringstream ss;
+ ss << "if (" << condition()->ToString() << ") { ";
+ ss << then_node()->ToString() << " } else { ";
+ ss << else_node()->ToString() << " }";
+ return ss.str();
+ }
+
+ private:
+ NodePtr condition_;
+ NodePtr then_node_;
+ NodePtr else_node_;
+};
+
+/// \brief Node in the expression tree, representing an and/or boolean expression.
+class GANDIVA_EXPORT BooleanNode : public Node {
+ public:
+ enum ExprType : char { AND, OR };
+
+ BooleanNode(ExprType expr_type, const NodeVector& children)
+ : Node(arrow::boolean()), expr_type_(expr_type), children_(children) {}
+
+ Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); }
+
+ ExprType expr_type() const { return expr_type_; }
+
+ const NodeVector& children() const { return children_; }
+
+ std::string ToString() const override {
+ std::stringstream ss;
+ bool first = true;
+ for (auto& child : children_) {
+ if (!first) {
+ if (expr_type() == BooleanNode::AND) {
+ ss << " && ";
+ } else {
+ ss << " || ";
+ }
+ }
+ ss << child->ToString();
+ first = false;
+ }
+ return ss.str();
+ }
+
+ private:
+ ExprType expr_type_;
+ NodeVector children_;
+};
+
+/// \brief Node in expression tree, representing an in expression.
+template <typename Type>
+class InExpressionNode : public Node {
+ public:
+ InExpressionNode(NodePtr eval_expr, const std::unordered_set<Type>& values)
+ : Node(arrow::boolean()), eval_expr_(eval_expr), values_(values) {}
+
+ const NodePtr& eval_expr() const { return eval_expr_; }
+
+ const std::unordered_set<Type>& values() const { return values_; }
+
+ Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); }
+
+ std::string ToString() const override {
+ std::stringstream ss;
+ ss << eval_expr_->ToString() << " IN (";
+ bool add_comma = false;
+ for (auto& value : values_) {
+ if (add_comma) {
+ ss << ", ";
+ }
+ // add type in the front to differentiate
+ ss << value;
+ add_comma = true;
+ }
+ ss << ")";
+ return ss.str();
+ }
+
+ private:
+ NodePtr eval_expr_;
+ std::unordered_set<Type> values_;
+};
+
+template <>
+class InExpressionNode<gandiva::DecimalScalar128> : public Node {
+ public:
+ InExpressionNode(NodePtr eval_expr,
+ std::unordered_set<gandiva::DecimalScalar128>& values,
+ int32_t precision, int32_t scale)
+ : Node(arrow::boolean()),
+ eval_expr_(std::move(eval_expr)),
+ values_(std::move(values)),
+ precision_(precision),
+ scale_(scale) {}
+
+ int32_t get_precision() const { return precision_; }
+
+ int32_t get_scale() const { return scale_; }
+
+ const NodePtr& eval_expr() const { return eval_expr_; }
+
+ const std::unordered_set<gandiva::DecimalScalar128>& values() const { return values_; }
+
+ Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); }
+
+ std::string ToString() const override {
+ std::stringstream ss;
+ ss << eval_expr_->ToString() << " IN (";
+ bool add_comma = false;
+ for (auto& value : values_) {
+ if (add_comma) {
+ ss << ", ";
+ }
+ // add type in the front to differentiate
+ ss << value;
+ add_comma = true;
+ }
+ ss << ")";
+ return ss.str();
+ }
+
+ private:
+ NodePtr eval_expr_;
+ std::unordered_set<gandiva::DecimalScalar128> values_;
+ int32_t precision_, scale_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/node_visitor.h b/src/arrow/cpp/src/gandiva/node_visitor.h
new file mode 100644
index 000000000..8f233f5b7
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/node_visitor.h
@@ -0,0 +1,56 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cmath>
+#include <string>
+
+#include "arrow/status.h"
+
+#include "arrow/util/logging.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+class FieldNode;
+class FunctionNode;
+class IfNode;
+class LiteralNode;
+class BooleanNode;
+template <typename Type>
+class InExpressionNode;
+
+/// \brief Visitor for nodes in the expression tree.
+class GANDIVA_EXPORT NodeVisitor {
+ public:
+ virtual ~NodeVisitor() = default;
+
+ virtual Status Visit(const FieldNode& node) = 0;
+ virtual Status Visit(const FunctionNode& node) = 0;
+ virtual Status Visit(const IfNode& node) = 0;
+ virtual Status Visit(const LiteralNode& node) = 0;
+ virtual Status Visit(const BooleanNode& node) = 0;
+ virtual Status Visit(const InExpressionNode<int32_t>& node) = 0;
+ virtual Status Visit(const InExpressionNode<int64_t>& node) = 0;
+ virtual Status Visit(const InExpressionNode<float>& node) = 0;
+ virtual Status Visit(const InExpressionNode<double>& node) = 0;
+ virtual Status Visit(const InExpressionNode<gandiva::DecimalScalar128>& node) = 0;
+ virtual Status Visit(const InExpressionNode<std::string>& node) = 0;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/pch.h b/src/arrow/cpp/src/gandiva/pch.h
new file mode 100644
index 000000000..f3d9b2fad
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/pch.h
@@ -0,0 +1,24 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Often-used headers, for precompiling.
+// If updating this header, please make sure you check compilation speed
+// before checking in. Adding headers which are not used extremely often
+// may incur a slowdown, since it makes the precompiled header heavier to load.
+
+#include "arrow/pch.h"
+#include "gandiva/llvm_types.h"
diff --git a/src/arrow/cpp/src/gandiva/precompiled/CMakeLists.txt b/src/arrow/cpp/src/gandiva/precompiled/CMakeLists.txt
new file mode 100644
index 000000000..650b80f6b
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/CMakeLists.txt
@@ -0,0 +1,142 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+project(gandiva)
+
+set(PRECOMPILED_SRCS
+ arithmetic_ops.cc
+ bitmap.cc
+ decimal_ops.cc
+ decimal_wrapper.cc
+ extended_math_ops.cc
+ hash.cc
+ print.cc
+ string_ops.cc
+ time.cc
+ timestamp_arithmetic.cc
+ ../../arrow/util/basic_decimal.cc)
+
+if(MSVC)
+ # clang pretends to be a particular version of MSVC. 191[0-9] is
+ # Visual Studio 2017, and the standard library uses C++14 features,
+ # so we have to use that -std version to get the IR compilation to work
+ if(MSVC_VERSION MATCHES "^191[0-9]$")
+ set(FMS_COMPATIBILITY 19.10)
+ else()
+ message(FATAL_ERROR "Unsupported MSVC_VERSION=${MSVC_VERSION}")
+ endif()
+ set(PLATFORM_CLANG_OPTIONS -std=c++14 -fms-compatibility
+ -fms-compatibility-version=${FMS_COMPATIBILITY})
+else()
+ set(PLATFORM_CLANG_OPTIONS -std=c++11)
+endif()
+
+# Create bitcode for each of the source files.
+foreach(SRC_FILE ${PRECOMPILED_SRCS})
+ get_filename_component(SRC_BASE ${SRC_FILE} NAME_WE)
+ get_filename_component(ABSOLUTE_SRC ${SRC_FILE} ABSOLUTE)
+ set(BC_FILE ${CMAKE_CURRENT_BINARY_DIR}/${SRC_BASE}.bc)
+ set(PRECOMPILE_COMMAND)
+ if(CMAKE_OSX_SYSROOT)
+ list(APPEND
+ PRECOMPILE_COMMAND
+ ${CMAKE_COMMAND}
+ -E
+ env
+ SDKROOT=${CMAKE_OSX_SYSROOT})
+ endif()
+ list(APPEND
+ PRECOMPILE_COMMAND
+ ${CLANG_EXECUTABLE}
+ ${PLATFORM_CLANG_OPTIONS}
+ -DGANDIVA_IR
+ -DNDEBUG # DCHECK macros not implemented in precompiled code
+ -DARROW_STATIC # Do not set __declspec(dllimport) on MSVC on Arrow symbols
+ -DGANDIVA_STATIC # Do not set __declspec(dllimport) on MSVC on Gandiva symbols
+ -fno-use-cxa-atexit # Workaround for unresolved __dso_handle
+ -emit-llvm
+ -O3
+ -c
+ ${ABSOLUTE_SRC}
+ -o
+ ${BC_FILE}
+ ${ARROW_GANDIVA_PC_CXX_FLAGS}
+ -I${CMAKE_SOURCE_DIR}/src
+ -I${ARROW_BINARY_DIR}/src)
+
+ if(NOT ARROW_USE_NATIVE_INT128)
+ list(APPEND PRECOMPILE_COMMAND -I${Boost_INCLUDE_DIR})
+ endif()
+ add_custom_command(OUTPUT ${BC_FILE}
+ COMMAND ${PRECOMPILE_COMMAND}
+ DEPENDS ${SRC_FILE})
+ list(APPEND BC_FILES ${BC_FILE})
+endforeach()
+
+# link all of the bitcode files into a single bitcode file.
+add_custom_command(OUTPUT ${GANDIVA_PRECOMPILED_BC_PATH}
+ COMMAND ${LLVM_LINK_EXECUTABLE} -o ${GANDIVA_PRECOMPILED_BC_PATH}
+ ${BC_FILES}
+ DEPENDS ${BC_FILES})
+
+# turn the bitcode file into a C++ static data variable.
+add_custom_command(OUTPUT ${GANDIVA_PRECOMPILED_CC_PATH}
+ COMMAND ${PYTHON_EXECUTABLE}
+ "${CMAKE_CURRENT_SOURCE_DIR}/../make_precompiled_bitcode.py"
+ ${GANDIVA_PRECOMPILED_CC_IN_PATH}
+ ${GANDIVA_PRECOMPILED_BC_PATH} ${GANDIVA_PRECOMPILED_CC_PATH}
+ DEPENDS ${GANDIVA_PRECOMPILED_CC_IN_PATH}
+ ${GANDIVA_PRECOMPILED_BC_PATH})
+
+add_custom_target(precompiled ALL DEPENDS ${GANDIVA_PRECOMPILED_BC_PATH}
+ ${GANDIVA_PRECOMPILED_CC_PATH})
+
+# testing
+if(ARROW_BUILD_TESTS)
+ add_executable(gandiva-precompiled-test
+ ../context_helper.cc
+ bitmap_test.cc
+ bitmap.cc
+ epoch_time_point_test.cc
+ time_test.cc
+ time.cc
+ timestamp_arithmetic.cc
+ ../cast_time.cc
+ ../../arrow/vendored/datetime/tz.cpp
+ hash_test.cc
+ hash.cc
+ string_ops_test.cc
+ string_ops.cc
+ arithmetic_ops_test.cc
+ arithmetic_ops.cc
+ extended_math_ops_test.cc
+ extended_math_ops.cc
+ decimal_ops_test.cc
+ decimal_ops.cc
+ ../decimal_type_util.cc
+ ../decimal_xlarge.cc)
+ target_include_directories(gandiva-precompiled-test PRIVATE ${CMAKE_SOURCE_DIR}/src)
+ target_link_libraries(gandiva-precompiled-test PRIVATE ${ARROW_TEST_LINK_LIBS})
+ target_compile_definitions(gandiva-precompiled-test PRIVATE GANDIVA_UNIT_TEST=1
+ ARROW_STATIC GANDIVA_STATIC)
+ set(TEST_PATH "${EXECUTABLE_OUTPUT_PATH}/gandiva-precompiled-test")
+ add_test(gandiva-precompiled-test ${TEST_PATH})
+ set_property(TEST gandiva-precompiled-test
+ APPEND
+ PROPERTY LABELS "unittest;gandiva-tests")
+ add_dependencies(gandiva-tests gandiva-precompiled-test)
+endif()
diff --git a/src/arrow/cpp/src/gandiva/precompiled/arithmetic_ops.cc b/src/arrow/cpp/src/gandiva/precompiled/arithmetic_ops.cc
new file mode 100644
index 000000000..c736c38d3
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/arithmetic_ops.cc
@@ -0,0 +1,274 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern "C" {
+
+#include <math.h>
+#include "./types.h"
+
+// Expand inner macro for all numeric types.
+#define NUMERIC_TYPES(INNER, NAME, OP) \
+ INNER(NAME, int8, OP) \
+ INNER(NAME, int16, OP) \
+ INNER(NAME, int32, OP) \
+ INNER(NAME, int64, OP) \
+ INNER(NAME, uint8, OP) \
+ INNER(NAME, uint16, OP) \
+ INNER(NAME, uint32, OP) \
+ INNER(NAME, uint64, OP) \
+ INNER(NAME, float32, OP) \
+ INNER(NAME, float64, OP)
+
+// Expand inner macros for all date/time types.
+#define DATE_TYPES(INNER, NAME, OP) \
+ INNER(NAME, date64, OP) \
+ INNER(NAME, date32, OP) \
+ INNER(NAME, timestamp, OP) \
+ INNER(NAME, time32, OP)
+
+#define NUMERIC_DATE_TYPES(INNER, NAME, OP) \
+ NUMERIC_TYPES(INNER, NAME, OP) \
+ DATE_TYPES(INNER, NAME, OP)
+
+#define NUMERIC_BOOL_DATE_TYPES(INNER, NAME, OP) \
+ NUMERIC_TYPES(INNER, NAME, OP) \
+ DATE_TYPES(INNER, NAME, OP) \
+ INNER(NAME, boolean, OP)
+
+#define MOD_OP(NAME, IN_TYPE1, IN_TYPE2, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE NAME##_##IN_TYPE1##_##IN_TYPE2(gdv_##IN_TYPE1 left, \
+ gdv_##IN_TYPE2 right) { \
+ return (right == 0 ? static_cast<gdv_##OUT_TYPE>(left) \
+ : static_cast<gdv_##OUT_TYPE>(left % right)); \
+ }
+
+// Symmetric binary fns : left, right params and return type are same.
+#define BINARY_SYMMETRIC(NAME, TYPE, OP) \
+ FORCE_INLINE \
+ gdv_##TYPE NAME##_##TYPE##_##TYPE(gdv_##TYPE left, gdv_##TYPE right) { \
+ return static_cast<gdv_##TYPE>(left OP right); \
+ }
+
+NUMERIC_TYPES(BINARY_SYMMETRIC, add, +)
+NUMERIC_TYPES(BINARY_SYMMETRIC, subtract, -)
+NUMERIC_TYPES(BINARY_SYMMETRIC, multiply, *)
+BINARY_SYMMETRIC(bitwise_and, int32, &)
+BINARY_SYMMETRIC(bitwise_and, int64, &)
+BINARY_SYMMETRIC(bitwise_or, int32, |)
+BINARY_SYMMETRIC(bitwise_or, int64, |)
+BINARY_SYMMETRIC(bitwise_xor, int32, ^)
+BINARY_SYMMETRIC(bitwise_xor, int64, ^)
+
+#undef BINARY_SYMMETRIC
+
+MOD_OP(mod, int64, int32, int32)
+MOD_OP(mod, int64, int64, int64)
+
+#undef MOD_OP
+
+gdv_float64 mod_float64_float64(int64_t context, gdv_float64 x, gdv_float64 y) {
+ if (y == 0.0) {
+ char const* err_msg = "divide by zero error";
+ gdv_fn_context_set_error_msg(context, err_msg);
+ return 0.0;
+ }
+ return fmod(x, y);
+}
+
+// Relational binary fns : left, right params are same, return is bool.
+#define BINARY_RELATIONAL(NAME, TYPE, OP) \
+ FORCE_INLINE \
+ bool NAME##_##TYPE##_##TYPE(gdv_##TYPE left, gdv_##TYPE right) { return left OP right; }
+
+NUMERIC_BOOL_DATE_TYPES(BINARY_RELATIONAL, equal, ==)
+NUMERIC_BOOL_DATE_TYPES(BINARY_RELATIONAL, not_equal, !=)
+NUMERIC_DATE_TYPES(BINARY_RELATIONAL, less_than, <)
+NUMERIC_DATE_TYPES(BINARY_RELATIONAL, less_than_or_equal_to, <=)
+NUMERIC_DATE_TYPES(BINARY_RELATIONAL, greater_than, >)
+NUMERIC_DATE_TYPES(BINARY_RELATIONAL, greater_than_or_equal_to, >=)
+
+#undef BINARY_RELATIONAL
+
+// cast fns : takes one param type, returns another type.
+#define CAST_UNARY(NAME, IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE NAME##_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_##OUT_TYPE>(in); \
+ }
+
+CAST_UNARY(castBIGINT, int32, int64)
+CAST_UNARY(castINT, int64, int32)
+CAST_UNARY(castFLOAT4, int32, float32)
+CAST_UNARY(castFLOAT4, int64, float32)
+CAST_UNARY(castFLOAT8, int32, float64)
+CAST_UNARY(castFLOAT8, int64, float64)
+CAST_UNARY(castFLOAT8, float32, float64)
+CAST_UNARY(castFLOAT4, float64, float32)
+
+#undef CAST_UNARY
+
+// cast float types to int types.
+#define CAST_INT_FLOAT(NAME, IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE NAME##_##IN_TYPE(gdv_##IN_TYPE in) { \
+ gdv_##OUT_TYPE out = static_cast<gdv_##OUT_TYPE>(round(in)); \
+ return out; \
+ }
+
+CAST_INT_FLOAT(castBIGINT, float32, int64)
+CAST_INT_FLOAT(castBIGINT, float64, int64)
+CAST_INT_FLOAT(castINT, float32, int32)
+CAST_INT_FLOAT(castINT, float64, int32)
+
+#undef CAST_INT_FLOAT
+
+// simple nullable functions, result value = fn(input validity)
+#define VALIDITY_OP(NAME, TYPE, OP) \
+ FORCE_INLINE \
+ bool NAME##_##TYPE(gdv_##TYPE in, gdv_boolean is_valid) { return OP is_valid; }
+
+NUMERIC_BOOL_DATE_TYPES(VALIDITY_OP, isnull, !)
+NUMERIC_BOOL_DATE_TYPES(VALIDITY_OP, isnotnull, +)
+NUMERIC_TYPES(VALIDITY_OP, isnumeric, +)
+
+#undef VALIDITY_OP
+
+#define NUMERIC_FUNCTION(INNER) \
+ INNER(int8) \
+ INNER(int16) \
+ INNER(int32) \
+ INNER(int64) \
+ INNER(uint8) \
+ INNER(uint16) \
+ INNER(uint32) \
+ INNER(uint64) \
+ INNER(float32) \
+ INNER(float64)
+
+#define DATE_FUNCTION(INNER) \
+ INNER(date32) \
+ INNER(date64) \
+ INNER(timestamp) \
+ INNER(time32)
+
+#define NUMERIC_BOOL_DATE_FUNCTION(INNER) \
+ NUMERIC_FUNCTION(INNER) \
+ DATE_FUNCTION(INNER) \
+ INNER(boolean)
+
+FORCE_INLINE
+gdv_boolean not_boolean(gdv_boolean in) { return !in; }
+
+// is_distinct_from
+#define IS_DISTINCT_FROM(TYPE) \
+ FORCE_INLINE \
+ bool is_distinct_from_##TYPE##_##TYPE(gdv_##TYPE in1, gdv_boolean is_valid1, \
+ gdv_##TYPE in2, gdv_boolean is_valid2) { \
+ if (is_valid1 != is_valid2) { \
+ return true; \
+ } \
+ if (!is_valid1) { \
+ return false; \
+ } \
+ return in1 != in2; \
+ }
+
+// is_not_distinct_from
+#define IS_NOT_DISTINCT_FROM(TYPE) \
+ FORCE_INLINE \
+ bool is_not_distinct_from_##TYPE##_##TYPE(gdv_##TYPE in1, gdv_boolean is_valid1, \
+ gdv_##TYPE in2, gdv_boolean is_valid2) { \
+ if (is_valid1 != is_valid2) { \
+ return false; \
+ } \
+ if (!is_valid1) { \
+ return true; \
+ } \
+ return in1 == in2; \
+ }
+
+NUMERIC_BOOL_DATE_FUNCTION(IS_DISTINCT_FROM)
+NUMERIC_BOOL_DATE_FUNCTION(IS_NOT_DISTINCT_FROM)
+
+#undef IS_DISTINCT_FROM
+#undef IS_NOT_DISTINCT_FROM
+
+#define DIVIDE(TYPE) \
+ FORCE_INLINE \
+ gdv_##TYPE divide_##TYPE##_##TYPE(gdv_int64 context, gdv_##TYPE in1, gdv_##TYPE in2) { \
+ if (in2 == 0) { \
+ char const* err_msg = "divide by zero error"; \
+ gdv_fn_context_set_error_msg(context, err_msg); \
+ return 0; \
+ } \
+ return static_cast<gdv_##TYPE>(in1 / in2); \
+ }
+
+NUMERIC_FUNCTION(DIVIDE)
+
+#undef DIVIDE
+
+#define DIV(TYPE) \
+ FORCE_INLINE \
+ gdv_##TYPE div_##TYPE##_##TYPE(gdv_int64 context, gdv_##TYPE in1, gdv_##TYPE in2) { \
+ if (in2 == 0) { \
+ char const* err_msg = "divide by zero error"; \
+ gdv_fn_context_set_error_msg(context, err_msg); \
+ return 0; \
+ } \
+ return static_cast<gdv_##TYPE>(in1 / in2); \
+ }
+
+DIV(int32)
+DIV(int64)
+
+#undef DIV
+
+#define DIV_FLOAT(TYPE) \
+ FORCE_INLINE \
+ gdv_##TYPE div_##TYPE##_##TYPE(gdv_int64 context, gdv_##TYPE in1, gdv_##TYPE in2) { \
+ if (in2 == 0) { \
+ char const* err_msg = "divide by zero error"; \
+ gdv_fn_context_set_error_msg(context, err_msg); \
+ return 0; \
+ } \
+ return static_cast<gdv_##TYPE>(::trunc(in1 / in2)); \
+ }
+
+DIV_FLOAT(float32)
+DIV_FLOAT(float64)
+
+#undef DIV_FLOAT
+
+#define BITWISE_NOT(TYPE) \
+ FORCE_INLINE \
+ gdv_##TYPE bitwise_not_##TYPE(gdv_##TYPE in) { return static_cast<gdv_##TYPE>(~in); }
+
+BITWISE_NOT(int32)
+BITWISE_NOT(int64)
+
+#undef BITWISE_NOT
+
+#undef DATE_FUNCTION
+#undef DATE_TYPES
+#undef NUMERIC_BOOL_DATE_TYPES
+#undef NUMERIC_DATE_TYPES
+#undef NUMERIC_FUNCTION
+#undef NUMERIC_TYPES
+
+} // extern "C"
diff --git a/src/arrow/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc b/src/arrow/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc
new file mode 100644
index 000000000..36b50bcfd
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc
@@ -0,0 +1,180 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "../execution_context.h"
+#include "gandiva/precompiled/types.h"
+
+namespace gandiva {
+
+TEST(TestArithmeticOps, TestIsDistinctFrom) {
+ EXPECT_EQ(is_distinct_from_timestamp_timestamp(1000, true, 1000, false), true);
+ EXPECT_EQ(is_distinct_from_timestamp_timestamp(1000, false, 1000, true), true);
+ EXPECT_EQ(is_distinct_from_timestamp_timestamp(1000, false, 1000, false), false);
+ EXPECT_EQ(is_distinct_from_timestamp_timestamp(1000, true, 1000, true), false);
+
+ EXPECT_EQ(is_not_distinct_from_int32_int32(1000, true, 1000, false), false);
+ EXPECT_EQ(is_not_distinct_from_int32_int32(1000, false, 1000, true), false);
+ EXPECT_EQ(is_not_distinct_from_int32_int32(1000, false, 1000, false), true);
+ EXPECT_EQ(is_not_distinct_from_int32_int32(1000, true, 1000, true), true);
+}
+
+TEST(TestArithmeticOps, TestMod) {
+ gandiva::ExecutionContext context;
+ EXPECT_EQ(mod_int64_int32(10, 0), 10);
+
+ const double acceptable_abs_error = 0.00000000001; // 1e-10
+
+ EXPECT_DOUBLE_EQ(mod_float64_float64(reinterpret_cast<gdv_int64>(&context), 2.5, 0.0),
+ 0.0);
+ EXPECT_TRUE(context.has_error());
+ EXPECT_EQ(context.get_error(), "divide by zero error");
+
+ context.Reset();
+ EXPECT_NEAR(mod_float64_float64(reinterpret_cast<gdv_int64>(&context), 2.5, 1.2), 0.1,
+ acceptable_abs_error);
+ EXPECT_FALSE(context.has_error());
+
+ context.Reset();
+ EXPECT_DOUBLE_EQ(mod_float64_float64(reinterpret_cast<gdv_int64>(&context), 2.5, 2.5),
+ 0.0);
+ EXPECT_FALSE(context.has_error());
+
+ context.Reset();
+ EXPECT_NEAR(mod_float64_float64(reinterpret_cast<gdv_int64>(&context), 9.2, 3.7), 1.8,
+ acceptable_abs_error);
+ EXPECT_FALSE(context.has_error());
+}
+
+TEST(TestArithmeticOps, TestDivide) {
+ gandiva::ExecutionContext context;
+ EXPECT_EQ(divide_int64_int64(reinterpret_cast<gdv_int64>(&context), 10, 0), 0);
+ EXPECT_EQ(context.has_error(), true);
+ EXPECT_EQ(context.get_error(), "divide by zero error");
+
+ context.Reset();
+ EXPECT_EQ(divide_int64_int64(reinterpret_cast<gdv_int64>(&context), 10, 2), 5);
+ EXPECT_EQ(context.has_error(), false);
+}
+
+TEST(TestArithmeticOps, TestDiv) {
+ gandiva::ExecutionContext context;
+ EXPECT_EQ(div_int64_int64(reinterpret_cast<gdv_int64>(&context), 101, 0), 0);
+ EXPECT_EQ(context.has_error(), true);
+ EXPECT_EQ(context.get_error(), "divide by zero error");
+ context.Reset();
+
+ EXPECT_EQ(div_int64_int64(reinterpret_cast<gdv_int64>(&context), 101, 111), 0);
+ EXPECT_EQ(context.has_error(), false);
+ context.Reset();
+
+ EXPECT_EQ(div_float64_float64(reinterpret_cast<gdv_int64>(&context), 1010.1010, 2.1),
+ 481.0);
+ EXPECT_EQ(context.has_error(), false);
+ context.Reset();
+
+ EXPECT_EQ(
+ div_float64_float64(reinterpret_cast<gdv_int64>(&context), 1010.1010, 0.00000),
+ 0.0);
+ EXPECT_EQ(context.has_error(), true);
+ EXPECT_EQ(context.get_error(), "divide by zero error");
+ context.Reset();
+
+ EXPECT_EQ(div_float32_float32(reinterpret_cast<gdv_int64>(&context), 1010.1010f, 2.1f),
+ 481.0f);
+ EXPECT_EQ(context.has_error(), false);
+ context.Reset();
+}
+
+TEST(TestArithmeticOps, TestBitwiseOps) {
+ // bitwise AND
+ EXPECT_EQ(bitwise_and_int32_int32(0x0147D, 0x17159), 0x01059);
+ EXPECT_EQ(bitwise_and_int32_int32(0xFFFFFFCC, 0x00000297), 0x00000284);
+ EXPECT_EQ(bitwise_and_int32_int32(0x000, 0x285), 0x000);
+ EXPECT_EQ(bitwise_and_int64_int64(0x563672F83, 0x0D9FCF85B), 0x041642803);
+ EXPECT_EQ(bitwise_and_int64_int64(0xFFFFFFFFFFDA8F6A, 0xFFFFFFFFFFFF791C),
+ 0xFFFFFFFFFFDA0908);
+ EXPECT_EQ(bitwise_and_int64_int64(0x6A5B1, 0x00000), 0x00000);
+
+ // bitwise OR
+ EXPECT_EQ(bitwise_or_int32_int32(0x0147D, 0x17159), 0x1757D);
+ EXPECT_EQ(bitwise_or_int32_int32(0xFFFFFFCC, 0x00000297), 0xFFFFFFDF);
+ EXPECT_EQ(bitwise_or_int32_int32(0x000, 0x285), 0x285);
+ EXPECT_EQ(bitwise_or_int64_int64(0x563672F83, 0x0D9FCF85B), 0x5FBFFFFDB);
+ EXPECT_EQ(bitwise_or_int64_int64(0xFFFFFFFFFFDA8F6A, 0xFFFFFFFFFFFF791C),
+ 0xFFFFFFFFFFFFFF7E);
+ EXPECT_EQ(bitwise_or_int64_int64(0x6A5B1, 0x00000), 0x6A5B1);
+
+ // bitwise XOR
+ EXPECT_EQ(bitwise_xor_int32_int32(0x0147D, 0x17159), 0x16524);
+ EXPECT_EQ(bitwise_xor_int32_int32(0xFFFFFFCC, 0x00000297), 0XFFFFFD5B);
+ EXPECT_EQ(bitwise_xor_int32_int32(0x000, 0x285), 0x285);
+ EXPECT_EQ(bitwise_xor_int64_int64(0x563672F83, 0x0D9FCF85B), 0x5BA9BD7D8);
+ EXPECT_EQ(bitwise_xor_int64_int64(0xFFFFFFFFFFDA8F6A, 0xFFFFFFFFFFFF791C), 0X25F676);
+ EXPECT_EQ(bitwise_xor_int64_int64(0x6A5B1, 0x00000), 0x6A5B1);
+ EXPECT_EQ(bitwise_xor_int64_int64(0x6A5B1, 0x6A5B1), 0x00000);
+
+ // bitwise NOT
+ EXPECT_EQ(bitwise_not_int32(0x00017159), 0xFFFE8EA6);
+ EXPECT_EQ(bitwise_not_int32(0xFFFFF226), 0x00000DD9);
+ EXPECT_EQ(bitwise_not_int64(0x000000008BCAE9B4), 0xFFFFFFFF7435164B);
+ EXPECT_EQ(bitwise_not_int64(0xFFFFFF966C8D7997), 0x0000006993728668);
+ EXPECT_EQ(bitwise_not_int64(0x0000000000000000), 0xFFFFFFFFFFFFFFFF);
+}
+
+TEST(TestArithmeticOps, TestIntCastFloatDouble) {
+ // castINT from floats
+ EXPECT_EQ(castINT_float32(6.6f), 7);
+ EXPECT_EQ(castINT_float32(-6.6f), -7);
+ EXPECT_EQ(castINT_float32(-6.3f), -6);
+ EXPECT_EQ(castINT_float32(0.0f), 0);
+ EXPECT_EQ(castINT_float32(-0), 0);
+
+ // castINT from doubles
+ EXPECT_EQ(castINT_float64(6.6), 7);
+ EXPECT_EQ(castINT_float64(-6.6), -7);
+ EXPECT_EQ(castINT_float64(-6.3), -6);
+ EXPECT_EQ(castINT_float64(0.0), 0);
+ EXPECT_EQ(castINT_float64(-0), 0);
+ EXPECT_EQ(castINT_float64(999999.99999999999999999999999), 1000000);
+ EXPECT_EQ(castINT_float64(-999999.99999999999999999999999), -1000000);
+ EXPECT_EQ(castINT_float64(INT32_MAX), 2147483647);
+ EXPECT_EQ(castINT_float64(-2147483647), -2147483647);
+}
+
+TEST(TestArithmeticOps, TestBigIntCastFloatDouble) {
+ // castINT from floats
+ EXPECT_EQ(castBIGINT_float32(6.6f), 7);
+ EXPECT_EQ(castBIGINT_float32(-6.6f), -7);
+ EXPECT_EQ(castBIGINT_float32(-6.3f), -6);
+ EXPECT_EQ(castBIGINT_float32(0.0f), 0);
+ EXPECT_EQ(castBIGINT_float32(-0), 0);
+
+ // castINT from doubles
+ EXPECT_EQ(castBIGINT_float64(6.6), 7);
+ EXPECT_EQ(castBIGINT_float64(-6.6), -7);
+ EXPECT_EQ(castBIGINT_float64(-6.3), -6);
+ EXPECT_EQ(castBIGINT_float64(0.0), 0);
+ EXPECT_EQ(castBIGINT_float64(-0), 0);
+ EXPECT_EQ(castBIGINT_float64(999999.99999999999999999999999), 1000000);
+ EXPECT_EQ(castBIGINT_float64(-999999.99999999999999999999999), -1000000);
+ EXPECT_EQ(castBIGINT_float64(INT32_MAX), 2147483647);
+ EXPECT_EQ(castBIGINT_float64(-2147483647), -2147483647);
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/precompiled/bitmap.cc b/src/arrow/cpp/src/gandiva/precompiled/bitmap.cc
new file mode 100644
index 000000000..332f08dbe
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/bitmap.cc
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// BitMap functions
+
+#include "arrow/util/bit_util.h"
+
+extern "C" {
+
+#include "./types.h"
+
+#define BITS_TO_BYTES(x) ((x + 7) / 8)
+#define BITS_TO_WORDS(x) ((x + 63) / 64)
+
+#define POS_TO_BYTE_INDEX(p) (p / 8)
+#define POS_TO_BIT_INDEX(p) (p % 8)
+
+FORCE_INLINE
+bool bitMapGetBit(const uint8_t* bmap, int64_t position) {
+ return arrow::BitUtil::GetBit(bmap, position);
+}
+
+FORCE_INLINE
+bool bitMapValidityGetBit(const uint8_t* bmap, int64_t position) {
+ if (bmap == nullptr) {
+ // if validity bitmap is null, all entries are valid.
+ return true;
+ } else {
+ return bitMapGetBit(bmap, position);
+ }
+}
+
+FORCE_INLINE
+void bitMapSetBit(uint8_t* bmap, int64_t position, bool value) {
+ arrow::BitUtil::SetBitTo(bmap, position, value);
+}
+
+// Clear the bit if value = false. Does nothing if value = true.
+FORCE_INLINE
+void bitMapClearBitIfFalse(uint8_t* bmap, int64_t position, bool value) {
+ if (!value) {
+ arrow::BitUtil::ClearBit(bmap, position);
+ }
+}
+
+} // extern "C"
diff --git a/src/arrow/cpp/src/gandiva/precompiled/bitmap_test.cc b/src/arrow/cpp/src/gandiva/precompiled/bitmap_test.cc
new file mode 100644
index 000000000..ac3084ade
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/bitmap_test.cc
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include "gandiva/precompiled/types.h"
+
+namespace gandiva {
+
+TEST(TestBitMap, TestSimple) {
+ static const int kNumBytes = 16;
+ uint8_t bit_map[kNumBytes];
+ memset(bit_map, 0, kNumBytes);
+
+ EXPECT_EQ(bitMapGetBit(bit_map, 100), false);
+
+ // set 100th bit and verify
+ bitMapSetBit(bit_map, 100, true);
+ EXPECT_EQ(bitMapGetBit(bit_map, 100), true);
+
+ // clear 100th bit and verify
+ bitMapSetBit(bit_map, 100, false);
+ EXPECT_EQ(bitMapGetBit(bit_map, 100), false);
+}
+
+TEST(TestBitMap, TestClearIfFalse) {
+ static const int kNumBytes = 32;
+ uint8_t bit_map[kNumBytes];
+ memset(bit_map, 0, kNumBytes);
+
+ bitMapSetBit(bit_map, 24, true);
+
+ // bit should remain unchanged.
+ bitMapClearBitIfFalse(bit_map, 24, true);
+ EXPECT_EQ(bitMapGetBit(bit_map, 24), true);
+
+ // bit should be cleared.
+ bitMapClearBitIfFalse(bit_map, 24, false);
+ EXPECT_EQ(bitMapGetBit(bit_map, 24), false);
+
+ // this function should have no impact if the bit is already clear.
+ bitMapClearBitIfFalse(bit_map, 24, true);
+ EXPECT_EQ(bitMapGetBit(bit_map, 24), false);
+
+ bitMapClearBitIfFalse(bit_map, 24, false);
+ EXPECT_EQ(bitMapGetBit(bit_map, 24), false);
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/precompiled/decimal_ops.cc b/src/arrow/cpp/src/gandiva/precompiled/decimal_ops.cc
new file mode 100644
index 000000000..61cac6062
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/decimal_ops.cc
@@ -0,0 +1,723 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Algorithms adapted from Apache Impala
+
+#include "gandiva/precompiled/decimal_ops.h"
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+
+#include "arrow/util/logging.h"
+#include "gandiva/decimal_type_util.h"
+#include "gandiva/decimal_xlarge.h"
+#include "gandiva/gdv_function_stubs.h"
+
+// Several operations (multiply, divide, mod, ..) require converting to 256-bit, and we
+// use the boost library for doing 256-bit operations. To avoid references to boost from
+// the precompiled-to-ir code (this causes issues with symbol resolution at runtime), we
+// use a wrapper exported from the CPP code. The wrapper functions are named gdv_xlarge_xx
+
+namespace gandiva {
+namespace decimalops {
+
+using arrow::BasicDecimal128;
+
+static BasicDecimal128 CheckAndIncreaseScale(const BasicDecimal128& in, int32_t delta) {
+ return (delta <= 0) ? in : in.IncreaseScaleBy(delta);
+}
+
+static BasicDecimal128 CheckAndReduceScale(const BasicDecimal128& in, int32_t delta) {
+ return (delta <= 0) ? in : in.ReduceScaleBy(delta);
+}
+
+/// Adjust x and y to the same scale, and add them.
+static BasicDecimal128 AddFastPath(const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y, int32_t out_scale) {
+ auto higher_scale = std::max(x.scale(), y.scale());
+
+ auto x_scaled = CheckAndIncreaseScale(x.value(), higher_scale - x.scale());
+ auto y_scaled = CheckAndIncreaseScale(y.value(), higher_scale - y.scale());
+ return x_scaled + y_scaled;
+}
+
+/// Add x and y, caller has ensured there can be no overflow.
+static BasicDecimal128 AddNoOverflow(const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y, int32_t out_scale) {
+ auto higher_scale = std::max(x.scale(), y.scale());
+ auto sum = AddFastPath(x, y, out_scale);
+ return CheckAndReduceScale(sum, higher_scale - out_scale);
+}
+
+/// Both x_value and y_value must be >= 0
+static BasicDecimal128 AddLargePositive(const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y,
+ int32_t out_scale) {
+ DCHECK_GE(x.value(), 0);
+ DCHECK_GE(y.value(), 0);
+
+ // separate out whole/fractions.
+ BasicDecimal128 x_left, x_right, y_left, y_right;
+ x.value().GetWholeAndFraction(x.scale(), &x_left, &x_right);
+ y.value().GetWholeAndFraction(y.scale(), &y_left, &y_right);
+
+ // Adjust fractional parts to higher scale.
+ auto higher_scale = std::max(x.scale(), y.scale());
+ auto x_right_scaled = CheckAndIncreaseScale(x_right, higher_scale - x.scale());
+ auto y_right_scaled = CheckAndIncreaseScale(y_right, higher_scale - y.scale());
+
+ BasicDecimal128 right;
+ BasicDecimal128 carry_to_left;
+ auto multiplier = BasicDecimal128::GetScaleMultiplier(higher_scale);
+ if (x_right_scaled >= multiplier - y_right_scaled) {
+ right = x_right_scaled - (multiplier - y_right_scaled);
+ carry_to_left = 1;
+ } else {
+ right = x_right_scaled + y_right_scaled;
+ carry_to_left = 0;
+ }
+ right = CheckAndReduceScale(right, higher_scale - out_scale);
+
+ auto left = x_left + y_left + carry_to_left;
+ return (left * BasicDecimal128::GetScaleMultiplier(out_scale)) + right;
+}
+
+/// x_value and y_value cannot be 0, and one must be positive and the other negative.
+static BasicDecimal128 AddLargeNegative(const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y,
+ int32_t out_scale) {
+ DCHECK_NE(x.value(), 0);
+ DCHECK_NE(y.value(), 0);
+ DCHECK((x.value() < 0 && y.value() > 0) || (x.value() > 0 && y.value() < 0));
+
+ // separate out whole/fractions.
+ BasicDecimal128 x_left, x_right, y_left, y_right;
+ x.value().GetWholeAndFraction(x.scale(), &x_left, &x_right);
+ y.value().GetWholeAndFraction(y.scale(), &y_left, &y_right);
+
+ // Adjust fractional parts to higher scale.
+ auto higher_scale = std::max(x.scale(), y.scale());
+ x_right = CheckAndIncreaseScale(x_right, higher_scale - x.scale());
+ y_right = CheckAndIncreaseScale(y_right, higher_scale - y.scale());
+
+ // Overflow not possible because one is +ve and the other is -ve.
+ auto left = x_left + y_left;
+ auto right = x_right + y_right;
+
+ // If the whole and fractional parts have different signs, then we need to make the
+ // fractional part have the same sign as the whole part. If either left or right is
+ // zero, then nothing needs to be done.
+ if (left < 0 && right > 0) {
+ left += 1;
+ right -= BasicDecimal128::GetScaleMultiplier(higher_scale);
+ } else if (left > 0 && right < 0) {
+ left -= 1;
+ right += BasicDecimal128::GetScaleMultiplier(higher_scale);
+ }
+ right = CheckAndReduceScale(right, higher_scale - out_scale);
+ return (left * BasicDecimal128::GetScaleMultiplier(out_scale)) + right;
+}
+
+static BasicDecimal128 AddLarge(const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y, int32_t out_scale) {
+ if (x.value() >= 0 && y.value() >= 0) {
+ // both positive or 0
+ return AddLargePositive(x, y, out_scale);
+ } else if (x.value() <= 0 && y.value() <= 0) {
+ // both negative or 0
+ BasicDecimalScalar128 x_neg(-x.value(), x.precision(), x.scale());
+ BasicDecimalScalar128 y_neg(-y.value(), y.precision(), y.scale());
+ return -AddLargePositive(x_neg, y_neg, out_scale);
+ } else {
+ // one positive and the other negative
+ return AddLargeNegative(x, y, out_scale);
+ }
+}
+
+// Suppose we have a number that requires x bits to be represented and we scale it up by
+// 10^scale_by. Let's say now y bits are required to represent it. This function returns
+// the maximum possible y - x for a given 'scale_by'.
+inline int32_t MaxBitsRequiredIncreaseAfterScaling(int32_t scale_by) {
+ // We rely on the following formula:
+ // bits_required(x * 10^y) <= bits_required(x) + floor(log2(10^y)) + 1
+ // We precompute floor(log2(10^x)) + 1 for x = 0, 1, 2...75, 76
+ DCHECK_GE(scale_by, 0);
+ DCHECK_LE(scale_by, 76);
+ static const int32_t floor_log2_plus_one[] = {
+ 0, 4, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40, 44, 47, 50,
+ 54, 57, 60, 64, 67, 70, 74, 77, 80, 84, 87, 90, 94, 97, 100, 103,
+ 107, 110, 113, 117, 120, 123, 127, 130, 133, 137, 140, 143, 147, 150, 153, 157,
+ 160, 163, 167, 170, 173, 177, 180, 183, 187, 190, 193, 196, 200, 203, 206, 210,
+ 213, 216, 220, 223, 226, 230, 233, 236, 240, 243, 246, 250, 253};
+ return floor_log2_plus_one[scale_by];
+}
+
+// If we have a number with 'num_lz' leading zeros, and we scale it up by 10^scale_by,
+// this function returns the minimum number of leading zeros the result can have.
+inline int32_t MinLeadingZerosAfterScaling(int32_t num_lz, int32_t scale_by) {
+ DCHECK_GE(scale_by, 0);
+ DCHECK_LE(scale_by, 76);
+ int32_t result = num_lz - MaxBitsRequiredIncreaseAfterScaling(scale_by);
+ return result;
+}
+
+// Returns the maximum possible number of bits required to represent num * 10^scale_by.
+inline int32_t MaxBitsRequiredAfterScaling(const BasicDecimalScalar128& num,
+ int32_t scale_by) {
+ auto value = num.value();
+ auto value_abs = value.Abs();
+
+ int32_t num_occupied = 128 - value_abs.CountLeadingBinaryZeros();
+ DCHECK_GE(scale_by, 0);
+ DCHECK_LE(scale_by, 76);
+ return num_occupied + MaxBitsRequiredIncreaseAfterScaling(scale_by);
+}
+
+// Returns the minimum number of leading zero x or y would have after one of them gets
+// scaled up to match the scale of the other one.
+inline int32_t MinLeadingZeros(const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y) {
+ auto x_value = x.value();
+ auto x_value_abs = x_value.Abs();
+
+ auto y_value = y.value();
+ auto y_value_abs = y_value.Abs();
+
+ int32_t x_lz = x_value_abs.CountLeadingBinaryZeros();
+ int32_t y_lz = y_value_abs.CountLeadingBinaryZeros();
+ if (x.scale() < y.scale()) {
+ x_lz = MinLeadingZerosAfterScaling(x_lz, y.scale() - x.scale());
+ } else if (x.scale() > y.scale()) {
+ y_lz = MinLeadingZerosAfterScaling(y_lz, x.scale() - y.scale());
+ }
+ return std::min(x_lz, y_lz);
+}
+
+BasicDecimal128 Add(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y,
+ int32_t out_precision, int32_t out_scale) {
+ if (out_precision < DecimalTypeUtil::kMaxPrecision) {
+ // fast-path add
+ return AddFastPath(x, y, out_scale);
+ } else {
+ int32_t min_lz = MinLeadingZeros(x, y);
+ if (min_lz >= 3) {
+ // If both numbers have at least MIN_LZ leading zeros, we can add them directly
+ // without the risk of overflow.
+ // We want the result to have at least 2 leading zeros, which ensures that it fits
+ // into the maximum decimal because 2^126 - 1 < 10^38 - 1. If both x and y have at
+ // least 3 leading zeros, then we are guaranteed that the result will have at lest 2
+ // leading zeros.
+ return AddNoOverflow(x, y, out_scale);
+ } else {
+ // slower-version : add whole/fraction parts separately, and then, combine.
+ return AddLarge(x, y, out_scale);
+ }
+ }
+}
+
+BasicDecimal128 Subtract(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y,
+ int32_t out_precision, int32_t out_scale) {
+ return Add(x, {-y.value(), y.precision(), y.scale()}, out_precision, out_scale);
+}
+
+// Multiply when the out_precision is 38, and there is no trimming of the scale i.e
+// the intermediate value is the same as the final value.
+static BasicDecimal128 MultiplyMaxPrecisionNoScaleDown(const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y,
+ int32_t out_scale,
+ bool* overflow) {
+ DCHECK_EQ(x.scale() + y.scale(), out_scale);
+
+ BasicDecimal128 result;
+ auto x_abs = BasicDecimal128::Abs(x.value());
+ auto y_abs = BasicDecimal128::Abs(y.value());
+
+ if (x_abs > BasicDecimal128::GetMaxValue() / y_abs) {
+ *overflow = true;
+ } else {
+ // We've verified that the result will fit into 128 bits.
+ *overflow = false;
+ result = x.value() * y.value();
+ }
+ return result;
+}
+
+// Multiply when the out_precision is 38, and there is trimming of the scale i.e
+// the intermediate value could be larger than the final value.
+static BasicDecimal128 MultiplyMaxPrecisionAndScaleDown(const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y,
+ int32_t out_scale,
+ bool* overflow) {
+ auto delta_scale = x.scale() + y.scale() - out_scale;
+ DCHECK_GT(delta_scale, 0);
+
+ *overflow = false;
+ BasicDecimal128 result;
+ auto x_abs = BasicDecimal128::Abs(x.value());
+ auto y_abs = BasicDecimal128::Abs(y.value());
+
+ // It's possible that the intermediate value does not fit in 128-bits, but the
+ // final value will (after scaling down).
+ bool needs_int256 = false;
+ int32_t total_leading_zeros =
+ x_abs.CountLeadingBinaryZeros() + y_abs.CountLeadingBinaryZeros();
+ // This check is quick, but conservative. In some cases it will indicate that
+ // converting to 256 bits is necessary, when it's not actually the case.
+ needs_int256 = total_leading_zeros <= 128;
+ if (ARROW_PREDICT_FALSE(needs_int256)) {
+ int64_t result_high;
+ uint64_t result_low;
+
+ // This requires converting to 256-bit, and we use the boost library for that. To
+ // avoid references to boost from the precompiled-to-ir code (this causes issues
+ // with symbol resolution at runtime), we use a wrapper exported from the CPP code.
+ gdv_xlarge_multiply_and_scale_down(x.value().high_bits(), x.value().low_bits(),
+ y.value().high_bits(), y.value().low_bits(),
+ delta_scale, &result_high, &result_low, overflow);
+ result = BasicDecimal128(result_high, result_low);
+ } else {
+ if (ARROW_PREDICT_TRUE(delta_scale <= 38)) {
+ // The largest value that result can have here is (2^64 - 1) * (2^63 - 1), which is
+ // greater than BasicDecimal128::kMaxValue.
+ result = x.value() * y.value();
+ // Since delta_scale is greater than zero, result can now be at most
+ // ((2^64 - 1) * (2^63 - 1)) / 10, which is less than BasicDecimal128::kMaxValue, so
+ // there cannot be any overflow.
+ result = result.ReduceScaleBy(delta_scale);
+ } else {
+ // We are multiplying decimal(38, 38) by decimal(38, 38). The result should be a
+ // decimal(38, 37), so delta scale = 38 + 38 - 37 = 39. Since we are not in the
+ // 256 bit intermediate value case and we are scaling down by 39, then we are
+ // guaranteed that the result is 0 (even if we try to round). The largest possible
+ // intermediate result is 38 "9"s. If we scale down by 39, the leftmost 9 is now
+ // two digits to the right of the rightmost "visible" one. The reason why we have
+ // to handle this case separately is because a scale multiplier with a delta_scale
+ // 39 does not fit into 128 bit.
+ DCHECK_EQ(delta_scale, 39);
+ result = 0;
+ }
+ }
+ return result;
+}
+
+// Multiply when the out_precision is 38.
+static BasicDecimal128 MultiplyMaxPrecision(const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y,
+ int32_t out_scale, bool* overflow) {
+ auto delta_scale = x.scale() + y.scale() - out_scale;
+ DCHECK_GE(delta_scale, 0);
+ if (delta_scale == 0) {
+ return MultiplyMaxPrecisionNoScaleDown(x, y, out_scale, overflow);
+ } else {
+ return MultiplyMaxPrecisionAndScaleDown(x, y, out_scale, overflow);
+ }
+}
+
+BasicDecimal128 Multiply(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y,
+ int32_t out_precision, int32_t out_scale, bool* overflow) {
+ BasicDecimal128 result;
+ *overflow = false;
+ if (out_precision < DecimalTypeUtil::kMaxPrecision) {
+ // fast-path multiply
+ result = x.value() * y.value();
+ DCHECK_EQ(x.scale() + y.scale(), out_scale);
+ DCHECK_LE(BasicDecimal128::Abs(result), BasicDecimal128::GetMaxValue());
+ } else if (x.value() == 0 || y.value() == 0) {
+ // Handle this separately to avoid divide-by-zero errors.
+ result = BasicDecimal128(0, 0);
+ } else {
+ result = MultiplyMaxPrecision(x, y, out_scale, overflow);
+ }
+ DCHECK(*overflow || BasicDecimal128::Abs(result) <= BasicDecimal128::GetMaxValue());
+ return result;
+}
+
+BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y, int32_t out_precision,
+ int32_t out_scale, bool* overflow) {
+ if (y.value() == 0) {
+ char const* err_msg = "divide by zero error";
+ gdv_fn_context_set_error_msg(context, err_msg);
+ return 0;
+ }
+
+ // scale up to the output scale, and do an integer division.
+ int32_t delta_scale = out_scale + y.scale() - x.scale();
+ DCHECK_GE(delta_scale, 0);
+
+ BasicDecimal128 result;
+ auto num_bits_required_after_scaling = MaxBitsRequiredAfterScaling(x, delta_scale);
+ if (num_bits_required_after_scaling <= 127) {
+ // fast-path. The dividend fits in 128-bit after scaling too.
+ *overflow = false;
+
+ // do the division.
+ auto x_scaled = CheckAndIncreaseScale(x.value(), delta_scale);
+ BasicDecimal128 remainder;
+ auto status = x_scaled.Divide(y.value(), &result, &remainder);
+ DCHECK_EQ(status, arrow::DecimalStatus::kSuccess);
+
+ // round-up
+ if (BasicDecimal128::Abs(2 * remainder) >= BasicDecimal128::Abs(y.value())) {
+ result += (x.value().Sign() ^ y.value().Sign()) + 1;
+ }
+ } else {
+ // convert to 256-bit and do the divide.
+ *overflow = delta_scale > 38 && num_bits_required_after_scaling > 255;
+ if (!*overflow) {
+ int64_t result_high;
+ uint64_t result_low;
+
+ gdv_xlarge_scale_up_and_divide(x.value().high_bits(), x.value().low_bits(),
+ y.value().high_bits(), y.value().low_bits(),
+ delta_scale, &result_high, &result_low, overflow);
+ result = BasicDecimal128(result_high, result_low);
+ }
+ }
+ return result;
+}
+
+BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y, int32_t out_precision,
+ int32_t out_scale, bool* overflow) {
+ if (y.value() == 0) {
+ char const* err_msg = "divide by zero error";
+ gdv_fn_context_set_error_msg(context, err_msg);
+ return 0;
+ }
+
+ // Adsjust x and y to the same scale (higher one), and then, do a integer mod.
+ *overflow = false;
+ BasicDecimal128 result;
+ int32_t min_lz = MinLeadingZeros(x, y);
+ if (min_lz >= 2) {
+ auto higher_scale = std::max(x.scale(), y.scale());
+ auto x_scaled = CheckAndIncreaseScale(x.value(), higher_scale - x.scale());
+ auto y_scaled = CheckAndIncreaseScale(y.value(), higher_scale - y.scale());
+ result = x_scaled % y_scaled;
+ DCHECK_LE(BasicDecimal128::Abs(result), BasicDecimal128::GetMaxValue());
+ } else {
+ int64_t result_high;
+ uint64_t result_low;
+
+ gdv_xlarge_mod(x.value().high_bits(), x.value().low_bits(), x.scale(),
+ y.value().high_bits(), y.value().low_bits(), y.scale(), &result_high,
+ &result_low);
+ result = BasicDecimal128(result_high, result_low);
+ }
+ DCHECK(BasicDecimal128::Abs(result) <= BasicDecimal128::Abs(x.value()) ||
+ BasicDecimal128::Abs(result) <= BasicDecimal128::Abs(y.value()));
+ return result;
+}
+
+int32_t CompareSameScale(const BasicDecimal128& x, const BasicDecimal128& y) {
+ if (x == y) {
+ return 0;
+ } else if (x < y) {
+ return -1;
+ } else {
+ return 1;
+ }
+}
+
+int32_t Compare(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y) {
+ int32_t delta_scale = x.scale() - y.scale();
+
+ // fast-path : both are of the same scale.
+ if (delta_scale == 0) {
+ return CompareSameScale(x.value(), y.value());
+ }
+
+ // Check if we'll need more than 256-bits after adjusting the scale.
+ bool need256 =
+ (delta_scale < 0 && x.precision() - delta_scale > DecimalTypeUtil::kMaxPrecision) ||
+ (y.precision() + delta_scale > DecimalTypeUtil::kMaxPrecision);
+ if (need256) {
+ return gdv_xlarge_compare(x.value().high_bits(), x.value().low_bits(), x.scale(),
+ y.value().high_bits(), y.value().low_bits(), y.scale());
+ } else {
+ BasicDecimal128 x_scaled;
+ BasicDecimal128 y_scaled;
+
+ if (delta_scale < 0) {
+ x_scaled = x.value().IncreaseScaleBy(-delta_scale);
+ y_scaled = y.value();
+ } else {
+ x_scaled = x.value();
+ y_scaled = y.value().IncreaseScaleBy(delta_scale);
+ }
+ return CompareSameScale(x_scaled, y_scaled);
+ }
+}
+
+#define DECIMAL_OVERFLOW_IF(condition, overflow) \
+ do { \
+ if (*overflow || (condition)) { \
+ *overflow = true; \
+ return 0; \
+ } \
+ } while (0)
+
+static BasicDecimal128 GetMaxValue(int32_t precision) {
+ return BasicDecimal128::GetScaleMultiplier(precision) - 1;
+}
+
+// Compute the double scale multipliers once.
+static std::array<double, DecimalTypeUtil::kMaxPrecision + 1> kDoubleScaleMultipliers =
+ ([]() -> std::array<double, DecimalTypeUtil::kMaxPrecision + 1> {
+ std::array<double, DecimalTypeUtil::kMaxPrecision + 1> values;
+ values[0] = 1.0;
+ for (int32_t idx = 1; idx <= DecimalTypeUtil::kMaxPrecision; idx++) {
+ values[idx] = values[idx - 1] * 10;
+ }
+ return values;
+ })();
+
+BasicDecimal128 FromDouble(double in, int32_t precision, int32_t scale, bool* overflow) {
+ // Multiply decimal with the scale
+ auto unscaled = in * kDoubleScaleMultipliers[scale];
+ DECIMAL_OVERFLOW_IF(std::isnan(unscaled), overflow);
+
+ unscaled = std::round(unscaled);
+
+ // convert scaled double to int128
+ int32_t sign = unscaled < 0 ? -1 : 1;
+ auto unscaled_abs = std::abs(unscaled);
+
+ // overflow if > 2^127 - 1
+ DECIMAL_OVERFLOW_IF(unscaled_abs > std::ldexp(static_cast<double>(1), 127) - 1,
+ overflow);
+
+ uint64_t high_bits = static_cast<uint64_t>(std::ldexp(unscaled_abs, -64));
+ uint64_t low_bits = static_cast<uint64_t>(
+ unscaled_abs - std::ldexp(static_cast<double>(high_bits), 64));
+
+ auto result = BasicDecimal128(static_cast<int64_t>(high_bits), low_bits);
+
+ // overflow if > max value based on precision
+ DECIMAL_OVERFLOW_IF(result > GetMaxValue(precision), overflow);
+ return result * sign;
+}
+
+double ToDouble(const BasicDecimalScalar128& in, bool* overflow) {
+ // convert int128 to double
+ int64_t sign = in.value().Sign();
+ auto value_abs = BasicDecimal128::Abs(in.value());
+ double unscaled = static_cast<double>(value_abs.low_bits()) +
+ std::ldexp(static_cast<double>(value_abs.high_bits()), 64);
+
+ // scale double.
+ return (unscaled * sign) / kDoubleScaleMultipliers[in.scale()];
+}
+
+BasicDecimal128 FromInt64(int64_t in, int32_t precision, int32_t scale, bool* overflow) {
+ // check if multiplying by scale will cause an overflow.
+ DECIMAL_OVERFLOW_IF(std::abs(in) > GetMaxValue(precision - scale), overflow);
+ return in * BasicDecimal128::GetScaleMultiplier(scale);
+}
+
+// Helper function to modify the scale and/or precision of a decimal value.
+static BasicDecimal128 ModifyScaleAndPrecision(const BasicDecimalScalar128& x,
+ int32_t out_precision, int32_t out_scale,
+ bool* overflow) {
+ int32_t delta_scale = out_scale - x.scale();
+ if (delta_scale >= 0) {
+ // check if multiplying by delta_scale will cause an overflow.
+ DECIMAL_OVERFLOW_IF(
+ BasicDecimal128::Abs(x.value()) > GetMaxValue(out_precision - delta_scale),
+ overflow);
+ return x.value().IncreaseScaleBy(delta_scale);
+ } else {
+ // Do not do any rounding, that is handled by the caller.
+ auto result = x.value().ReduceScaleBy(-delta_scale, false);
+ DECIMAL_OVERFLOW_IF(BasicDecimal128::Abs(result) > GetMaxValue(out_precision),
+ overflow);
+ return result;
+ }
+}
+
+enum RoundType {
+ kRoundTypeCeil, // +1 if +ve and trailing value is > 0, else no rounding.
+ kRoundTypeFloor, // -1 if -ve and trailing value is < 0, else no rounding.
+ kRoundTypeTrunc, // no rounding, truncate the trailing digits.
+ kRoundTypeHalfRoundUp, // if +ve and trailing value is >= half of base, +1.
+ // else if -ve and trailing value is >= half of base, -1.
+};
+
+// Compute the rounding delta for the givven rounding type.
+static int32_t ComputeRoundingDelta(const BasicDecimal128& x, int32_t x_scale,
+ int32_t out_scale, RoundType type) {
+ if (type == kRoundTypeTrunc || // no rounding for this type.
+ out_scale >= x_scale) { // no digits dropped, so no rounding.
+ return 0;
+ }
+
+ int32_t result = 0;
+ switch (type) {
+ case kRoundTypeHalfRoundUp: {
+ auto base = BasicDecimal128::GetScaleMultiplier(x_scale - out_scale);
+ auto trailing = x % base;
+ if (trailing == 0) {
+ result = 0;
+ } else if (trailing.Abs() < base / 2) {
+ result = 0;
+ } else {
+ result = (x < 0) ? -1 : 1;
+ }
+ break;
+ }
+
+ case kRoundTypeCeil:
+ if (x < 0) {
+ // no rounding for -ve
+ result = 0;
+ } else {
+ auto base = BasicDecimal128::GetScaleMultiplier(x_scale - out_scale);
+ auto trailing = x % base;
+ result = (trailing == 0) ? 0 : 1;
+ }
+ break;
+
+ case kRoundTypeFloor:
+ if (x > 0) {
+ // no rounding for +ve
+ result = 0;
+ } else {
+ auto base = BasicDecimal128::GetScaleMultiplier(x_scale - out_scale);
+ auto trailing = x % base;
+ result = (trailing == 0) ? 0 : -1;
+ }
+ break;
+
+ case kRoundTypeTrunc:
+ break;
+ }
+ return result;
+}
+
+// Modify the scale and round.
+static BasicDecimal128 RoundWithPositiveScale(const BasicDecimalScalar128& x,
+ int32_t out_precision, int32_t out_scale,
+ RoundType round_type, bool* overflow) {
+ DCHECK_GE(out_scale, 0);
+
+ auto scaled = ModifyScaleAndPrecision(x, out_precision, out_scale, overflow);
+ if (*overflow) {
+ return 0;
+ }
+
+ auto delta = ComputeRoundingDelta(x.value(), x.scale(), out_scale, round_type);
+ if (delta == 0) {
+ return scaled;
+ }
+
+ // If there is a rounding delta, the output scale must be less than the input scale.
+ // That means at least one digit is dropped after the decimal. The delta add can add
+ // utmost one digit before the decimal. So, overflow will occur only if the output
+ // precision has changed.
+ DCHECK_GT(x.scale(), out_scale);
+ auto result = scaled + delta;
+ DECIMAL_OVERFLOW_IF(out_precision < x.precision() &&
+ BasicDecimal128::Abs(result) > GetMaxValue(out_precision),
+ overflow);
+ return result;
+}
+
+// Modify scale to drop all digits to the right of the decimal and round.
+// Then, zero out 'rounding_scale' number of digits to the left of the decimal point.
+static BasicDecimal128 RoundWithNegativeScale(const BasicDecimalScalar128& x,
+ int32_t out_precision,
+ int32_t rounding_scale,
+ RoundType round_type, bool* overflow) {
+ DCHECK_LT(rounding_scale, 0);
+
+ // get rid of the fractional part.
+ auto scaled = ModifyScaleAndPrecision(x, out_precision, 0, overflow);
+ auto rounding_delta = ComputeRoundingDelta(scaled, 0, -rounding_scale, round_type);
+
+ auto base = BasicDecimal128::GetScaleMultiplier(-rounding_scale);
+ auto delta = rounding_delta * base - (scaled % base);
+ DECIMAL_OVERFLOW_IF(BasicDecimal128::Abs(scaled) >
+ GetMaxValue(out_precision) - BasicDecimal128::Abs(delta),
+ overflow);
+ return scaled + delta;
+}
+
+BasicDecimal128 Round(const BasicDecimalScalar128& x, int32_t out_precision,
+ int32_t out_scale, int32_t rounding_scale, bool* overflow) {
+ // no-op if target scale is same as arg scale
+ if (x.scale() == out_scale && rounding_scale >= 0) {
+ return x.value();
+ }
+
+ if (rounding_scale < 0) {
+ return RoundWithNegativeScale(x, out_precision, rounding_scale,
+ RoundType::kRoundTypeHalfRoundUp, overflow);
+ } else {
+ return RoundWithPositiveScale(x, out_precision, rounding_scale,
+ RoundType::kRoundTypeHalfRoundUp, overflow);
+ }
+}
+
+BasicDecimal128 Truncate(const BasicDecimalScalar128& x, int32_t out_precision,
+ int32_t out_scale, int32_t rounding_scale, bool* overflow) {
+ // no-op if target scale is same as arg scale
+ if (x.scale() == out_scale && rounding_scale >= 0) {
+ return x.value();
+ }
+
+ if (rounding_scale < 0) {
+ return RoundWithNegativeScale(x, out_precision, rounding_scale,
+ RoundType::kRoundTypeTrunc, overflow);
+ } else {
+ return RoundWithPositiveScale(x, out_precision, rounding_scale,
+ RoundType::kRoundTypeTrunc, overflow);
+ }
+}
+
+BasicDecimal128 Ceil(const BasicDecimalScalar128& x, bool* overflow) {
+ return RoundWithPositiveScale(x, x.precision(), 0, RoundType::kRoundTypeCeil, overflow);
+}
+
+BasicDecimal128 Floor(const BasicDecimalScalar128& x, bool* overflow) {
+ return RoundWithPositiveScale(x, x.precision(), 0, RoundType::kRoundTypeFloor,
+ overflow);
+}
+
+BasicDecimal128 Convert(const BasicDecimalScalar128& x, int32_t out_precision,
+ int32_t out_scale, bool* overflow) {
+ DCHECK_GE(out_scale, 0);
+ DCHECK_LE(out_scale, DecimalTypeUtil::kMaxScale);
+ DCHECK_GT(out_precision, 0);
+ DCHECK_LE(out_precision, DecimalTypeUtil::kMaxScale);
+
+ return RoundWithPositiveScale(x, out_precision, out_scale,
+ RoundType::kRoundTypeHalfRoundUp, overflow);
+}
+
+int64_t ToInt64(const BasicDecimalScalar128& in, bool* overflow) {
+ auto rounded = RoundWithPositiveScale(in, in.precision(), 0 /*scale*/,
+ RoundType::kRoundTypeHalfRoundUp, overflow);
+ DECIMAL_OVERFLOW_IF((rounded > std::numeric_limits<int64_t>::max()) ||
+ (rounded < std::numeric_limits<int64_t>::min()),
+ overflow);
+ return static_cast<int64_t>(rounded.low_bits());
+}
+
+} // namespace decimalops
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/precompiled/decimal_ops.h b/src/arrow/cpp/src/gandiva/precompiled/decimal_ops.h
new file mode 100644
index 000000000..292dce220
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/decimal_ops.h
@@ -0,0 +1,90 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+#include "gandiva/basic_decimal_scalar.h"
+
+namespace gandiva {
+namespace decimalops {
+
+/// Return the sum of 'x' and 'y'.
+/// out_precision and out_scale are passed along for efficiency, they must match
+/// the rules in DecimalTypeSql::GetResultType.
+arrow::BasicDecimal128 Add(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y,
+ int32_t out_precision, int32_t out_scale);
+
+/// Subtract 'y' from 'x', and return the result.
+arrow::BasicDecimal128 Subtract(const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y, int32_t out_precision,
+ int32_t out_scale);
+
+/// Multiply 'x' from 'y', and return the result.
+arrow::BasicDecimal128 Multiply(const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y, int32_t out_precision,
+ int32_t out_scale, bool* overflow);
+
+/// Divide 'x' by 'y', and return the result.
+arrow::BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y, int32_t out_precision,
+ int32_t out_scale, bool* overflow);
+
+/// Divide 'x' by 'y', and return the remainder.
+arrow::BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y, int32_t out_precision,
+ int32_t out_scale, bool* overflow);
+
+/// Compare two decimals. Returns :
+/// 0 if x == y
+/// 1 if x > y
+/// -1 if x < y
+int32_t Compare(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y);
+
+/// Convert to decimal from double.
+BasicDecimal128 FromDouble(double in, int32_t precision, int32_t scale, bool* overflow);
+
+/// Convert from decimal to double.
+double ToDouble(const BasicDecimalScalar128& in, bool* overflow);
+
+/// Convert to decimal from gdv_int64.
+BasicDecimal128 FromInt64(int64_t in, int32_t precision, int32_t scale, bool* overflow);
+
+/// Convert from decimal to gdv_int64
+int64_t ToInt64(const BasicDecimalScalar128& in, bool* overflow);
+
+/// Convert from one decimal scale/precision to another.
+BasicDecimal128 Convert(const BasicDecimalScalar128& x, int32_t out_precision,
+ int32_t out_scale, bool* overflow);
+
+/// round decimal.
+BasicDecimal128 Round(const BasicDecimalScalar128& x, int32_t out_precision,
+ int32_t out_scale, int32_t rounding_scale, bool* overflow);
+
+/// truncate decimal.
+BasicDecimal128 Truncate(const BasicDecimalScalar128& x, int32_t out_precision,
+ int32_t out_scale, int32_t rounding_scale, bool* overflow);
+
+/// ceil decimal
+BasicDecimal128 Ceil(const BasicDecimalScalar128& x, bool* overflow);
+
+/// floor decimal
+BasicDecimal128 Floor(const BasicDecimalScalar128& x, bool* overflow);
+
+} // namespace decimalops
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/precompiled/decimal_ops_test.cc b/src/arrow/cpp/src/gandiva/precompiled/decimal_ops_test.cc
new file mode 100644
index 000000000..be8a1fe8a
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/decimal_ops_test.cc
@@ -0,0 +1,1095 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <limits>
+#include <memory>
+#include <tuple>
+#include <vector>
+
+#include "arrow/testing/gtest_util.h"
+#include "gandiva/decimal_scalar.h"
+#include "gandiva/decimal_type_util.h"
+#include "gandiva/execution_context.h"
+#include "gandiva/precompiled/decimal_ops.h"
+#include "gandiva/precompiled/types.h"
+
+namespace gandiva {
+
+const arrow::Decimal128 kThirtyFive9s(std::string(35, '9'));
+const arrow::Decimal128 kThirtySix9s(std::string(36, '9'));
+const arrow::Decimal128 kThirtyEight9s(std::string(38, '9'));
+
+class TestDecimalSql : public ::testing::Test {
+ protected:
+ static void Verify(DecimalTypeUtil::Op op, const DecimalScalar128& x,
+ const DecimalScalar128& y, const DecimalScalar128& expected_result,
+ bool expected_overflow);
+
+ static void VerifyAllSign(DecimalTypeUtil::Op op, const DecimalScalar128& left,
+ const DecimalScalar128& right,
+ const DecimalScalar128& expected_output,
+ bool expected_overflow);
+
+ void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected_result) {
+ // TODO: overflow checks
+ return Verify(DecimalTypeUtil::kOpAdd, x, y, expected_result, false);
+ }
+
+ void SubtractAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected_result) {
+ // TODO: overflow checks
+ return Verify(DecimalTypeUtil::kOpSubtract, x, y, expected_result, false);
+ }
+
+ void MultiplyAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected_result,
+ bool expected_overflow) {
+ return Verify(DecimalTypeUtil::kOpMultiply, x, y, expected_result, expected_overflow);
+ }
+
+ void MultiplyAndVerifyAllSign(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected_result,
+ bool expected_overflow) {
+ return VerifyAllSign(DecimalTypeUtil::kOpMultiply, x, y, expected_result,
+ expected_overflow);
+ }
+
+ void DivideAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected_result, bool expected_overflow) {
+ return Verify(DecimalTypeUtil::kOpDivide, x, y, expected_result, expected_overflow);
+ }
+
+ void DivideAndVerifyAllSign(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected_result,
+ bool expected_overflow) {
+ return VerifyAllSign(DecimalTypeUtil::kOpDivide, x, y, expected_result,
+ expected_overflow);
+ }
+
+ void ModAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected_result, bool expected_overflow) {
+ return Verify(DecimalTypeUtil::kOpMod, x, y, expected_result, expected_overflow);
+ }
+
+ void ModAndVerifyAllSign(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected_result,
+ bool expected_overflow) {
+ return VerifyAllSign(DecimalTypeUtil::kOpMod, x, y, expected_result,
+ expected_overflow);
+ }
+};
+
+#define EXPECT_DECIMAL_EQ(op, x, y, expected_result, expected_overflow, actual_result, \
+ actual_overflow) \
+ { \
+ EXPECT_TRUE(expected_overflow == actual_overflow) \
+ << op << "(" << (x).ToString() << " and " << (y).ToString() << ")" \
+ << " expected overflow : " << expected_overflow \
+ << " actual overflow : " << actual_overflow; \
+ if (!expected_overflow) { \
+ EXPECT_TRUE(expected_result == actual_result) \
+ << op << "(" << (x).ToString() << " and " << (y).ToString() << ")" \
+ << " expected : " << expected_result.ToString() \
+ << " actual : " << actual_result.ToString(); \
+ } \
+ }
+
+void TestDecimalSql::Verify(DecimalTypeUtil::Op op, const DecimalScalar128& x,
+ const DecimalScalar128& y,
+ const DecimalScalar128& expected_result,
+ bool expected_overflow) {
+ auto t1 = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale());
+ auto t2 = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale());
+ bool overflow = false;
+ int64_t context = 0;
+
+ Decimal128TypePtr out_type;
+ ARROW_EXPECT_OK(DecimalTypeUtil::GetResultType(op, {t1, t2}, &out_type));
+
+ arrow::BasicDecimal128 out_value;
+ std::string op_name;
+ switch (op) {
+ case DecimalTypeUtil::kOpAdd:
+ op_name = "add";
+ out_value = decimalops::Add(x, y, out_type->precision(), out_type->scale());
+ break;
+
+ case DecimalTypeUtil::kOpSubtract:
+ op_name = "subtract";
+ out_value = decimalops::Subtract(x, y, out_type->precision(), out_type->scale());
+ break;
+
+ case DecimalTypeUtil::kOpMultiply:
+ op_name = "multiply";
+ out_value =
+ decimalops::Multiply(x, y, out_type->precision(), out_type->scale(), &overflow);
+ break;
+
+ case DecimalTypeUtil::kOpDivide:
+ op_name = "divide";
+ out_value = decimalops::Divide(context, x, y, out_type->precision(),
+ out_type->scale(), &overflow);
+ break;
+
+ case DecimalTypeUtil::kOpMod:
+ op_name = "mod";
+ out_value = decimalops::Mod(context, x, y, out_type->precision(), out_type->scale(),
+ &overflow);
+ break;
+
+ default:
+ // not implemented.
+ ASSERT_FALSE(true);
+ }
+ EXPECT_DECIMAL_EQ(op_name, x, y, expected_result, expected_overflow,
+ DecimalScalar128(out_value, out_type->precision(), out_type->scale()),
+ overflow);
+}
+
+void TestDecimalSql::VerifyAllSign(DecimalTypeUtil::Op op, const DecimalScalar128& left,
+ const DecimalScalar128& right,
+ const DecimalScalar128& expected_output,
+ bool expected_overflow) {
+ // both +ve
+ Verify(op, left, right, expected_output, expected_overflow);
+
+ // left -ve
+ Verify(op, -left, right, -expected_output, expected_overflow);
+
+ if (op == DecimalTypeUtil::kOpMod) {
+ // right -ve
+ Verify(op, left, -right, expected_output, expected_overflow);
+
+ // both -ve
+ Verify(op, -left, -right, -expected_output, expected_overflow);
+ } else {
+ ASSERT_TRUE(op == DecimalTypeUtil::kOpMultiply || op == DecimalTypeUtil::kOpDivide);
+
+ // right -ve
+ Verify(op, left, -right, -expected_output, expected_overflow);
+
+ // both -ve
+ Verify(op, -left, -right, expected_output, expected_overflow);
+ }
+}
+
+TEST_F(TestDecimalSql, Add) {
+ // fast-path
+ AddAndVerify(DecimalScalar128{"201", 30, 3}, // x
+ DecimalScalar128{"301", 30, 3}, // y
+ DecimalScalar128{"502", 31, 3}); // expected
+
+ // max precision
+ AddAndVerify(DecimalScalar128{"09999999999999999999999999999999000000", 38, 5}, // x
+ DecimalScalar128{"100", 38, 7}, // y
+ DecimalScalar128{"99999999999999999999999999999990000010", 38, 6});
+
+ // Both -ve
+ AddAndVerify(DecimalScalar128{"-201", 30, 3}, // x
+ DecimalScalar128{"-301", 30, 2}, // y
+ DecimalScalar128{"-3211", 32, 3}); // expected
+
+ // -ve and max precision
+ AddAndVerify(DecimalScalar128{"-09999999999999999999999999999999000000", 38, 5}, // x
+ DecimalScalar128{"-100", 38, 7}, // y
+ DecimalScalar128{"-99999999999999999999999999999990000010", 38, 6});
+}
+
+TEST_F(TestDecimalSql, Subtract) {
+ // fast-path
+ SubtractAndVerify(DecimalScalar128{"201", 30, 3}, // x
+ DecimalScalar128{"301", 30, 3}, // y
+ DecimalScalar128{"-100", 31, 3}); // expected
+
+ // max precision
+ SubtractAndVerify(
+ DecimalScalar128{"09999999999999999999999999999999000000", 38, 5}, // x
+ DecimalScalar128{"100", 38, 7}, // y
+ DecimalScalar128{"99999999999999999999999999999989999990", 38, 6});
+
+ // Both -ve
+ SubtractAndVerify(DecimalScalar128{"-201", 30, 3}, // x
+ DecimalScalar128{"-301", 30, 2}, // y
+ DecimalScalar128{"2809", 32, 3}); // expected
+
+ // -ve and max precision
+ SubtractAndVerify(
+ DecimalScalar128{"-09999999999999999999999999999999000000", 38, 5}, // x
+ DecimalScalar128{"-100", 38, 7}, // y
+ DecimalScalar128{"-99999999999999999999999999999989999990", 38, 6});
+}
+
+TEST_F(TestDecimalSql, Multiply) {
+ // fast-path : out_precision < 38
+ MultiplyAndVerifyAllSign(DecimalScalar128{"201", 10, 3}, // x
+ DecimalScalar128{"301", 10, 2}, // y
+ DecimalScalar128{"60501", 21, 5}, // expected
+ false); // overflow
+
+ // right 0
+ MultiplyAndVerify(DecimalScalar128{"201", 20, 3}, // x
+ DecimalScalar128{"0", 20, 2}, // y
+ DecimalScalar128{"0", 38, 5}, // expected
+ false); // overflow
+
+ // left 0
+ MultiplyAndVerify(DecimalScalar128{"0", 20, 3}, // x
+ DecimalScalar128{"301", 20, 2}, // y
+ DecimalScalar128{"0", 38, 5}, // expected
+ false); // overflow
+
+ // out_precision == 38, small input values, no trimming of scale (scale <= 6 doesn't
+ // get trimmed).
+ MultiplyAndVerify(DecimalScalar128{"201", 20, 3}, // x
+ DecimalScalar128{"301", 20, 2}, // y
+ DecimalScalar128{"60501", 38, 5}, // expected
+ false); // overflow
+
+ // out_precision == 38, large values, no trimming of scale (scale <= 6 doesn't
+ // get trimmed).
+ MultiplyAndVerifyAllSign(
+ DecimalScalar128{"201", 20, 3}, // x
+ DecimalScalar128{kThirtyFive9s, 35, 2}, // y
+ DecimalScalar128{"20099999999999999999999999999999999799", 38, 5}, // expected
+ false); // overflow
+
+ // out_precision == 38, very large values, no trimming of scale (scale <= 6 doesn't
+ // get trimmed). overflow expected.
+ MultiplyAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x
+ DecimalScalar128{kThirtySix9s, 35, 2}, // y
+ DecimalScalar128{"0", 38, 5}, // expected
+ true); // overflow
+
+ MultiplyAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x
+ DecimalScalar128{kThirtyEight9s, 35, 2}, // y
+ DecimalScalar128{"0", 38, 5}, // expected
+ true); // overflow
+
+ // out_precision == 38, small input values, trimming of scale.
+ MultiplyAndVerifyAllSign(DecimalScalar128{"201", 20, 5}, // x
+ DecimalScalar128{"301", 20, 5}, // y
+ DecimalScalar128{"61", 38, 7}, // expected
+ false); // overflow
+
+ // out_precision == 38, large values, trimming of scale.
+ MultiplyAndVerifyAllSign(
+ DecimalScalar128{"201", 20, 5}, // x
+ DecimalScalar128{kThirtyFive9s, 35, 5}, // y
+ DecimalScalar128{"2010000000000000000000000000000000", 38, 6}, // expected
+ false); // overflow
+
+ // out_precision == 38, very large values, trimming of scale (requires convert to 256).
+ MultiplyAndVerifyAllSign(
+ DecimalScalar128{kThirtyFive9s, 38, 20}, // x
+ DecimalScalar128{kThirtySix9s, 38, 20}, // y
+ DecimalScalar128{"9999999999999999999999999999999999890", 38, 6}, // expected
+ false); // overflow
+
+ // out_precision == 38, very large values, trimming of scale (requires convert to 256).
+ // should cause overflow.
+ MultiplyAndVerifyAllSign(DecimalScalar128{kThirtyFive9s, 38, 4}, // x
+ DecimalScalar128{kThirtySix9s, 38, 4}, // y
+ DecimalScalar128{"0", 38, 6}, // expected
+ true); // overflow
+
+ // corner cases.
+ MultiplyAndVerifyAllSign(
+ DecimalScalar128{0, UINT64_MAX, 38, 4}, // x
+ DecimalScalar128{0, UINT64_MAX, 38, 4}, // y
+ DecimalScalar128{"3402823669209384634264811192843491082", 38, 6}, // expected
+ false); // overflow
+
+ MultiplyAndVerifyAllSign(
+ DecimalScalar128{0, UINT64_MAX, 38, 4}, // x
+ DecimalScalar128{0, INT64_MAX, 38, 4}, // y
+ DecimalScalar128{"1701411834604692317040171876053197783", 38, 6}, // expected
+ false); // overflow
+
+ MultiplyAndVerifyAllSign(DecimalScalar128{"201", 38, 38}, // x
+ DecimalScalar128{"301", 38, 38}, // y
+ DecimalScalar128{"0", 38, 37}, // expected
+ false); // overflow
+
+ MultiplyAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 38, 38}, // x
+ DecimalScalar128{0, UINT64_MAX, 38, 38}, // y
+ DecimalScalar128{"0", 38, 37}, // expected
+ false); // overflow
+
+ MultiplyAndVerifyAllSign(
+ DecimalScalar128{kThirtyFive9s, 38, 38}, // x
+ DecimalScalar128{kThirtySix9s, 38, 38}, // y
+ DecimalScalar128{"100000000000000000000000000000000", 38, 37}, // expected
+ false); // overflow
+}
+
+TEST_F(TestDecimalSql, Divide) {
+ DivideAndVerifyAllSign(DecimalScalar128{"201", 10, 3}, // x
+ DecimalScalar128{"301", 10, 2}, // y
+ DecimalScalar128{"6677740863787", 23, 14}, // expected
+ false); // overflow
+
+ DivideAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x
+ DecimalScalar128{"301", 20, 2}, // y
+ DecimalScalar128{"667774086378737542", 38, 19}, // expected
+ false); // overflow
+
+ DivideAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x
+ DecimalScalar128{kThirtyFive9s, 35, 2}, // y
+ DecimalScalar128{"0", 38, 19}, // expected
+ false); // overflow
+
+ DivideAndVerifyAllSign(
+ DecimalScalar128{kThirtyFive9s, 35, 6}, // x
+ DecimalScalar128{"201", 20, 3}, // y
+ DecimalScalar128{"497512437810945273631840796019900493", 38, 6}, // expected
+ false); // overflow
+
+ DivideAndVerifyAllSign(DecimalScalar128{kThirtyEight9s, 38, 20}, // x
+ DecimalScalar128{kThirtyFive9s, 38, 20}, // y
+ DecimalScalar128{"1000000000", 38, 6}, // expected
+ false); // overflow
+
+ DivideAndVerifyAllSign(DecimalScalar128{"31939128063561476055", 38, 8}, // x
+ DecimalScalar128{"10000", 20, 0}, // y
+ DecimalScalar128{"3193912806356148", 38, 8}, // expected
+ false);
+
+ // Corner cases
+ DivideAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 38, 4}, // x
+ DecimalScalar128{0, UINT64_MAX, 38, 4}, // y
+ DecimalScalar128{"1000000", 38, 6}, // expected
+ false); // overflow
+
+ DivideAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 38, 4}, // x
+ DecimalScalar128{0, INT64_MAX, 38, 4}, // y
+ DecimalScalar128{"2000000", 38, 6}, // expected
+ false); // overflow
+
+ DivideAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 19, 5}, // x
+ DecimalScalar128{0, INT64_MAX, 19, 5}, // y
+ DecimalScalar128{"20000000000000000001", 38, 19}, // expected
+ false); // overflow
+
+ DivideAndVerifyAllSign(DecimalScalar128{kThirtyFive9s, 38, 37}, // x
+ DecimalScalar128{kThirtyFive9s, 38, 38}, // y
+ DecimalScalar128{"10000000", 38, 6}, // expected
+ false); // overflow
+
+ // overflow
+ DivideAndVerifyAllSign(DecimalScalar128{kThirtyEight9s, 38, 6}, // x
+ DecimalScalar128{"201", 20, 3}, // y
+ DecimalScalar128{"0", 38, 6}, // expected
+ true);
+}
+
+TEST_F(TestDecimalSql, Mod) {
+ ModAndVerifyAllSign(DecimalScalar128{"201", 10, 3}, // x
+ DecimalScalar128{"301", 10, 2}, // y
+ DecimalScalar128{"201", 10, 3}, // expected
+ false); // overflow
+
+ ModAndVerify(DecimalScalar128{"201", 20, 2}, // x
+ DecimalScalar128{"301", 20, 3}, // y
+ DecimalScalar128{"204", 20, 3}, // expected
+ false); // overflow
+
+ ModAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x
+ DecimalScalar128{kThirtyFive9s, 35, 2}, // y
+ DecimalScalar128{"201", 20, 3}, // expected
+ false); // overflow
+
+ ModAndVerifyAllSign(DecimalScalar128{kThirtyFive9s, 35, 6}, // x
+ DecimalScalar128{"201", 20, 3}, // y
+ DecimalScalar128{"180999", 23, 6}, // expected
+ false); // overflow
+
+ ModAndVerifyAllSign(DecimalScalar128{kThirtyEight9s, 38, 20}, // x
+ DecimalScalar128{kThirtyFive9s, 38, 21}, // y
+ DecimalScalar128{"9990", 38, 21}, // expected
+ false); // overflow
+
+ ModAndVerifyAllSign(DecimalScalar128{"31939128063561476055", 38, 8}, // x
+ DecimalScalar128{"10000", 20, 0}, // y
+ DecimalScalar128{"63561476055", 28, 8}, // expected
+ false);
+
+ ModAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 38, 4}, // x
+ DecimalScalar128{0, UINT64_MAX, 38, 4}, // y
+ DecimalScalar128{"0", 38, 4}, // expected
+ false); // overflow
+
+ ModAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 38, 4}, // x
+ DecimalScalar128{0, INT64_MAX, 38, 4}, // y
+ DecimalScalar128{"1", 38, 4}, // expected
+ false); // overflow
+}
+
+TEST_F(TestDecimalSql, DivideByZero) {
+ gandiva::ExecutionContext context;
+ int32_t result_precision;
+ int32_t result_scale;
+ bool overflow;
+
+ // divide-by-zero should cause an error.
+ context.Reset();
+ result_precision = 38;
+ result_scale = 19;
+ decimalops::Divide(reinterpret_cast<gdv_int64>(&context),
+ DecimalScalar128{"201", 20, 3}, DecimalScalar128{"0", 20, 2},
+ result_precision, result_scale, &overflow);
+ EXPECT_TRUE(context.has_error());
+ EXPECT_EQ(context.get_error(), "divide by zero error");
+
+ // divide-by-nonzero should not cause an error.
+ context.Reset();
+ decimalops::Divide(reinterpret_cast<gdv_int64>(&context),
+ DecimalScalar128{"201", 20, 3}, DecimalScalar128{"1", 20, 2},
+ result_precision, result_scale, &overflow);
+ EXPECT_FALSE(context.has_error());
+
+ // mod-by-zero should cause an error.
+ context.Reset();
+ result_precision = 20;
+ result_scale = 3;
+ decimalops::Mod(reinterpret_cast<gdv_int64>(&context), DecimalScalar128{"201", 20, 3},
+ DecimalScalar128{"0", 20, 2}, result_precision, result_scale,
+ &overflow);
+ EXPECT_TRUE(context.has_error());
+ EXPECT_EQ(context.get_error(), "divide by zero error");
+
+ // mod-by-nonzero should not cause an error.
+ context.Reset();
+ decimalops::Mod(reinterpret_cast<gdv_int64>(&context), DecimalScalar128{"201", 20, 3},
+ DecimalScalar128{"1", 20, 2}, result_precision, result_scale,
+ &overflow);
+ EXPECT_FALSE(context.has_error());
+}
+
+TEST_F(TestDecimalSql, Compare) {
+ // x.scale == y.scale
+ EXPECT_EQ(
+ 0, decimalops::Compare(DecimalScalar128{100, 38, 6}, DecimalScalar128{100, 38, 6}));
+ EXPECT_EQ(
+ 1, decimalops::Compare(DecimalScalar128{200, 38, 6}, DecimalScalar128{100, 38, 6}));
+ EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{100, 38, 6},
+ DecimalScalar128{200, 38, 6}));
+
+ // x.scale == y.scale, with -ve.
+ EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{-100, 38, 6},
+ DecimalScalar128{-100, 38, 6}));
+ EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{-200, 38, 6},
+ DecimalScalar128{-100, 38, 6}));
+ EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{-100, 38, 6},
+ DecimalScalar128{-200, 38, 6}));
+ EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{100, 38, 6},
+ DecimalScalar128{-200, 38, 6}));
+
+ for (int32_t precision : {16, 36, 38}) {
+ // x_scale > y_scale
+ EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{10000, precision, 6},
+ DecimalScalar128{100, precision, 4}));
+ EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{20000, precision, 6},
+ DecimalScalar128{100, precision, 4}));
+ EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{10000, precision, 6},
+ DecimalScalar128{200, precision, 4}));
+
+ // x.scale > y.scale, with -ve
+ EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{-10000, precision, 6},
+ DecimalScalar128{-100, precision, 4}));
+ EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{-20000, precision, 6},
+ DecimalScalar128{-100, precision, 4}));
+ EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{-10000, precision, 6},
+ DecimalScalar128{-200, precision, 4}));
+ EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{10000, precision, 6},
+ DecimalScalar128{-200, precision, 4}));
+
+ // x.scale < y.scale
+ EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{100, precision, 4},
+ DecimalScalar128{10000, precision, 6}));
+ EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{200, precision, 4},
+ DecimalScalar128{10000, precision, 6}));
+ EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{100, precision, 4},
+ DecimalScalar128{20000, precision, 6}));
+
+ // x.scale < y.scale, with -ve
+ EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{-100, precision, 4},
+ DecimalScalar128{-10000, precision, 6}));
+ EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{-200, precision, 4},
+ DecimalScalar128{-10000, precision, 6}));
+ EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{-100, precision, 4},
+ DecimalScalar128{-20000, precision, 6}));
+ EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{100, precision, 4},
+ DecimalScalar128{-200, precision, 6}));
+ }
+
+ // large cases.
+ EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{kThirtyEight9s, 38, 6},
+ DecimalScalar128{kThirtyEight9s, 38, 6}));
+
+ EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{kThirtyEight9s, 38, 6},
+ DecimalScalar128{kThirtySix9s, 38, 4}));
+
+ EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{kThirtyEight9s, 38, 6},
+ DecimalScalar128{kThirtyEight9s, 38, 4}));
+}
+
+TEST_F(TestDecimalSql, Round) {
+ // expected, input, rounding_scale, overflow
+ using TupleType = std::tuple<DecimalScalar128, DecimalScalar128, int32_t, bool>;
+ std::vector<TupleType> test_values = {
+ // examples from
+ // https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_round
+ std::make_tuple(DecimalScalar128{-1, 36, 0}, DecimalScalar128{-123, 38, 2}, 0,
+ false),
+ std::make_tuple(DecimalScalar128{-2, 36, 0}, DecimalScalar128{-158, 38, 2}, 0,
+ false),
+ std::make_tuple(DecimalScalar128{2, 36, 0}, DecimalScalar128{158, 38, 2}, 0, false),
+ std::make_tuple(DecimalScalar128{-13, 36, 1}, DecimalScalar128{-1298, 38, 3}, 1,
+ false),
+ std::make_tuple(DecimalScalar128{-1, 35, 0}, DecimalScalar128{-1298, 38, 3}, 0,
+ false),
+ std::make_tuple(DecimalScalar128{20, 35, 0}, DecimalScalar128{23298, 38, 3}, -1,
+ false),
+ std::make_tuple(DecimalScalar128{100, 38, 0}, DecimalScalar128{122, 38, 0}, -2,
+ false),
+ std::make_tuple(DecimalScalar128{3, 37, 0}, DecimalScalar128{25, 38, 1}, 0, false),
+
+ // border cases
+ std::make_tuple(DecimalScalar128{INT64_MIN / 100, 36, 0},
+ DecimalScalar128{INT64_MIN, 38, 2}, 0, false),
+
+ std::make_tuple(DecimalScalar128{INT64_MIN, 38, 0},
+ DecimalScalar128{INT64_MIN, 38, 0}, 0, false),
+ std::make_tuple(DecimalScalar128{0, 0, 36, 0}, DecimalScalar128{0, 0, 38, 2}, 0,
+ false),
+ std::make_tuple(DecimalScalar128{INT64_MAX, 38, 0},
+ DecimalScalar128{INT64_MAX, 38, 0}, 0, false),
+
+ std::make_tuple(DecimalScalar128{INT64_MAX / 100, 36, 0},
+ DecimalScalar128{INT64_MAX, 38, 2}, 0, false),
+
+ // large scales
+ std::make_tuple(DecimalScalar128{0, 0, 22, 0}, DecimalScalar128{12345, 38, 16}, 0,
+ false),
+
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128{124}, 22, 0},
+ DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(14), 38, 16}, 0, false),
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128{-124}, 22, 0},
+ DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(14), 38, 16}, 0,
+ false),
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128{124}, 6, 0},
+ DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(30), 38, 32}, 0, false),
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128{-124}, 6, 0},
+ DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(30), 38, 32}, 0,
+ false),
+
+ // scale bigger than arg
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(32), 38, 32},
+ DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(32), 38, 32}, 35,
+ false),
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(32), 38, 32},
+ DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(32), 38, 32}, 35,
+ false),
+
+ // overflow
+ std::make_tuple(DecimalScalar128{0, 0, 1, 0}, DecimalScalar128{99, 2, 1}, 0, true),
+ };
+
+ for (auto iter : test_values) {
+ auto expected = std::get<0>(iter);
+ auto input = std::get<1>(iter);
+ auto rounding_scale = std::get<2>(iter);
+ auto expected_overflow = std::get<3>(iter);
+ bool overflow = false;
+
+ EXPECT_EQ(expected.value(),
+ decimalops::Round(input, expected.precision(), expected.scale(),
+ rounding_scale, &overflow))
+ << " failed on input " << input << " rounding scale " << rounding_scale;
+ if (expected_overflow) {
+ ASSERT_TRUE(overflow) << "overflow expected for input " << input;
+ } else {
+ ASSERT_FALSE(overflow) << "overflow not expected for input " << input;
+ }
+ }
+}
+
+TEST_F(TestDecimalSql, Truncate) {
+ // expected, input, rounding_scale, overflow
+ using TupleType = std::tuple<DecimalScalar128, DecimalScalar128, int32_t, bool>;
+ std::vector<TupleType> test_values = {
+ // examples from
+ // https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_truncate
+ std::make_tuple(DecimalScalar128{12, 36, 1}, DecimalScalar128{1223, 38, 3}, 1,
+ false),
+ std::make_tuple(DecimalScalar128{19, 36, 1}, DecimalScalar128{1999, 38, 3}, 1,
+ false),
+ std::make_tuple(DecimalScalar128{1, 35, 0}, DecimalScalar128{1999, 38, 3}, 0,
+ false),
+ std::make_tuple(DecimalScalar128{-19, 36, 1}, DecimalScalar128{-1999, 38, 3}, 1,
+ false),
+ std::make_tuple(DecimalScalar128{100, 38, 0}, DecimalScalar128{122, 38, 0}, -2,
+ false),
+ std::make_tuple(DecimalScalar128{1028, 38, 0}, DecimalScalar128{1028, 38, 0}, 0,
+ false),
+
+ // border cases
+ std::make_tuple(DecimalScalar128{BasicDecimal128{INT64_MIN / 100}, 36, 0},
+ DecimalScalar128{INT64_MIN, 38, 2}, 0, false),
+
+ std::make_tuple(DecimalScalar128{INT64_MIN, 38, 0},
+ DecimalScalar128{INT64_MIN, 38, 0}, 0, false),
+ std::make_tuple(DecimalScalar128{0, 0, 38, 0}, DecimalScalar128{0, 0, 38, 2}, 0,
+ false),
+ std::make_tuple(DecimalScalar128{INT64_MAX, 38, 0},
+ DecimalScalar128{INT64_MAX, 38, 0}, 0, false),
+
+ std::make_tuple(DecimalScalar128{BasicDecimal128(INT64_MAX / 100), 36, 0},
+ DecimalScalar128{INT64_MAX, 38, 2}, 0, false),
+
+ // large scales
+ std::make_tuple(DecimalScalar128{BasicDecimal128{0, 0}, 22, 0},
+ DecimalScalar128{12345, 38, 16}, 0, false),
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128{123}, 22, 0},
+ DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(14), 38, 16}, 0, false),
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128{-123}, 22, 0},
+ DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(14), 38, 16}, 0,
+ false),
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128{123}, 6, 0},
+ DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(30), 38, 32}, 0, false),
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128{-123}, 6, 0},
+ DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(30), 38, 32}, 0,
+ false),
+
+ // overflow
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(32), 38, 32},
+ DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(32), 38, 32}, 35,
+ false),
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(32), 38, 32},
+ DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(32), 38, 32}, 35,
+ false),
+ };
+
+ for (auto iter : test_values) {
+ auto expected = std::get<0>(iter);
+ auto input = std::get<1>(iter);
+ auto rounding_scale = std::get<2>(iter);
+ auto expected_overflow = std::get<3>(iter);
+ bool overflow = false;
+
+ EXPECT_EQ(expected.value(),
+ decimalops::Truncate(input, expected.precision(), expected.scale(),
+ rounding_scale, &overflow))
+ << " failed on input " << input << " rounding scale " << rounding_scale;
+ if (expected_overflow) {
+ ASSERT_TRUE(overflow) << "overflow expected for input " << input;
+ } else {
+ ASSERT_FALSE(overflow) << "overflow not expected for input " << input;
+ }
+ }
+}
+
+TEST_F(TestDecimalSql, Ceil) {
+ // expected, input, overflow
+ std::vector<std::tuple<BasicDecimal128, DecimalScalar128, bool>> test_values = {
+ // https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_ceil
+ std::make_tuple(2, DecimalScalar128{123, 38, 2}, false),
+ std::make_tuple(-1, DecimalScalar128{-123, 38, 2}, false),
+
+ // border cases
+ std::make_tuple(BasicDecimal128{INT64_MIN / 100},
+ DecimalScalar128{INT64_MIN, 38, 2}, false),
+
+ std::make_tuple(INT64_MIN, DecimalScalar128{INT64_MIN, 38, 0}, false),
+ std::make_tuple(BasicDecimal128{0, 0}, DecimalScalar128{0, 0, 38, 2}, false),
+ std::make_tuple(INT64_MAX, DecimalScalar128{INT64_MAX, 38, 0}, false),
+
+ std::make_tuple(BasicDecimal128(INT64_MAX / 100 + 1),
+ DecimalScalar128{INT64_MAX, 38, 2}, false),
+
+ // large scales
+ std::make_tuple(BasicDecimal128{0, 1}, DecimalScalar128{12345, 38, 16}, false),
+ std::make_tuple(
+ BasicDecimal128{124},
+ DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(14), 38, 16}, false),
+ std::make_tuple(
+ BasicDecimal128{-123},
+ DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(14), 38, 16}, false),
+ std::make_tuple(
+ BasicDecimal128{124},
+ DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(30), 38, 32}, false),
+ std::make_tuple(
+ BasicDecimal128{-123},
+ DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(30), 38, 32}, false),
+ };
+
+ for (auto iter : test_values) {
+ auto expected = std::get<0>(iter);
+ auto input = std::get<1>(iter);
+ auto expected_overflow = std::get<2>(iter);
+ bool overflow = false;
+
+ EXPECT_EQ(expected, decimalops::Ceil(input, &overflow))
+ << " failed on input " << input;
+ if (expected_overflow) {
+ ASSERT_TRUE(overflow) << "overflow expected for input " << input;
+ } else {
+ ASSERT_FALSE(overflow) << "overflow not expected for input " << input;
+ }
+ }
+}
+
+TEST_F(TestDecimalSql, Floor) {
+ // expected, input, overflow
+ std::vector<std::tuple<BasicDecimal128, DecimalScalar128, bool>> test_values = {
+ // https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_floor
+ std::make_tuple(1, DecimalScalar128{123, 38, 2}, false),
+ std::make_tuple(-2, DecimalScalar128{-123, 38, 2}, false),
+
+ // border cases
+ std::make_tuple(BasicDecimal128{INT64_MIN / 100 - 1},
+ DecimalScalar128{INT64_MIN, 38, 2}, false),
+
+ std::make_tuple(INT64_MIN, DecimalScalar128{INT64_MIN, 38, 0}, false),
+ std::make_tuple(BasicDecimal128{0, 0}, DecimalScalar128{0, 0, 38, 2}, false),
+ std::make_tuple(INT64_MAX, DecimalScalar128{INT64_MAX, 38, 0}, false),
+
+ std::make_tuple(BasicDecimal128{INT64_MAX / 100},
+ DecimalScalar128{INT64_MAX, 38, 2}, false),
+
+ // large scales
+ std::make_tuple(BasicDecimal128{0, 0}, DecimalScalar128{12345, 38, 16}, false),
+ std::make_tuple(
+ BasicDecimal128{123},
+ DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(14), 38, 16}, false),
+ std::make_tuple(
+ BasicDecimal128{-124},
+ DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(14), 38, 16}, false),
+ std::make_tuple(
+ BasicDecimal128{123},
+ DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(30), 38, 32}, false),
+ std::make_tuple(
+ BasicDecimal128{-124},
+ DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(30), 38, 32}, false),
+ };
+
+ for (auto iter : test_values) {
+ auto expected = std::get<0>(iter);
+ auto input = std::get<1>(iter);
+ auto expected_overflow = std::get<2>(iter);
+ bool overflow = false;
+
+ EXPECT_EQ(expected, decimalops::Floor(input, &overflow))
+ << " failed on input " << input;
+ if (expected_overflow) {
+ ASSERT_TRUE(overflow) << "overflow expected for input " << input;
+ } else {
+ ASSERT_FALSE(overflow) << "overflow not expected for input " << input;
+ }
+ }
+}
+
+TEST_F(TestDecimalSql, Convert) {
+ // expected, input, overflow
+ std::vector<std::tuple<DecimalScalar128, DecimalScalar128, bool>> test_values = {
+ // simple cases
+ std::make_tuple(DecimalScalar128{12, 38, 1}, DecimalScalar128{123, 38, 2}, false),
+ std::make_tuple(DecimalScalar128{1230, 38, 3}, DecimalScalar128{123, 38, 2}, false),
+ std::make_tuple(DecimalScalar128{123, 38, 2}, DecimalScalar128{123, 38, 2}, false),
+
+ std::make_tuple(DecimalScalar128{-12, 38, 1}, DecimalScalar128{-123, 38, 2}, false),
+ std::make_tuple(DecimalScalar128{-1230, 38, 3}, DecimalScalar128{-123, 38, 2},
+ false),
+ std::make_tuple(DecimalScalar128{-123, 38, 2}, DecimalScalar128{-123, 38, 2},
+ false),
+
+ // border cases
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128(INT64_MIN).ReduceScaleBy(1), 38, 1},
+ DecimalScalar128{INT64_MIN, 38, 2}, false),
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128(INT64_MIN).IncreaseScaleBy(1), 38, 3},
+ DecimalScalar128{INT64_MIN, 38, 2}, false),
+ std::make_tuple(DecimalScalar128{-3, 38, 1}, DecimalScalar128{-32, 38, 2}, false),
+ std::make_tuple(DecimalScalar128{0, 0, 38, 1}, DecimalScalar128{0, 0, 38, 2},
+ false),
+ std::make_tuple(DecimalScalar128{3, 38, 1}, DecimalScalar128{32, 38, 2}, false),
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128(INT64_MAX).ReduceScaleBy(1), 38, 1},
+ DecimalScalar128{INT64_MAX, 38, 2}, false),
+ std::make_tuple(
+ DecimalScalar128{BasicDecimal128(INT64_MAX).IncreaseScaleBy(1), 38, 3},
+ DecimalScalar128{INT64_MAX, 38, 2}, false),
+
+ // large scales
+ std::make_tuple(DecimalScalar128{BasicDecimal128(123).IncreaseScaleBy(16), 38, 18},
+ DecimalScalar128{123, 38, 2}, false),
+ std::make_tuple(DecimalScalar128{BasicDecimal128(-123).IncreaseScaleBy(16), 38, 18},
+ DecimalScalar128{-123, 38, 2}, false),
+ std::make_tuple(DecimalScalar128{BasicDecimal128(123).IncreaseScaleBy(30), 38, 32},
+ DecimalScalar128{123, 38, 2}, false),
+ std::make_tuple(DecimalScalar128{BasicDecimal128(-123).IncreaseScaleBy(30), 38, 32},
+ DecimalScalar128{-123, 38, 2}, false),
+
+ // overflow due to scaling up.
+ std::make_tuple(DecimalScalar128{0, 0, 38, 36}, DecimalScalar128{12345, 38, 2},
+ true),
+ std::make_tuple(DecimalScalar128{0, 0, 38, 36}, DecimalScalar128{-12345, 38, 2},
+ true),
+
+ // overflow due to precision.
+ std::make_tuple(DecimalScalar128{0, 0, 5, 3}, DecimalScalar128{12345, 5, 2}, true),
+ };
+
+ for (auto iter : test_values) {
+ auto expected = std::get<0>(iter);
+ auto input = std::get<1>(iter);
+ auto expected_overflow = std::get<2>(iter);
+ bool overflow = false;
+
+ EXPECT_EQ(expected.value(), decimalops::Convert(input, expected.precision(),
+ expected.scale(), &overflow))
+ << " failed on input " << input;
+
+ if (expected_overflow) {
+ ASSERT_TRUE(overflow) << "overflow expected for input " << input;
+ } else {
+ ASSERT_FALSE(overflow) << "overflow not expected for input " << input;
+ }
+ }
+}
+
+// double can store up to this integer value without losing precision
+static const int64_t kMaxDoubleInt = 1ull << 53;
+
+TEST_F(TestDecimalSql, FromDouble) {
+ // expected, input, overflow
+ std::vector<std::tuple<DecimalScalar128, double, bool>> test_values = {
+ // simple cases
+ std::make_tuple(DecimalScalar128{-16285, 38, 3}, -16.285, false),
+ std::make_tuple(DecimalScalar128{-162850, 38, 4}, -16.285, false),
+ std::make_tuple(DecimalScalar128{-1629, 38, 2}, -16.285, false),
+
+ std::make_tuple(DecimalScalar128{16285, 38, 3}, 16.285, false),
+ std::make_tuple(DecimalScalar128{162850, 38, 4}, 16.285, false),
+ std::make_tuple(DecimalScalar128{1629, 38, 2}, 16.285, false),
+
+ // round up
+ std::make_tuple(DecimalScalar128{1, 18, 0}, 1.15470053838, false),
+ std::make_tuple(DecimalScalar128{-1, 18, 0}, -1.15470053838, false),
+ std::make_tuple(DecimalScalar128{2, 18, 0}, 1.55470053838, false),
+ std::make_tuple(DecimalScalar128{-2, 18, 0}, -1.55470053838, false),
+
+ // border cases
+ std::make_tuple(DecimalScalar128{-kMaxDoubleInt, 38, 0},
+ static_cast<double>(-kMaxDoubleInt), false),
+ std::make_tuple(DecimalScalar128{-32, 38, 0}, -32, false),
+ std::make_tuple(DecimalScalar128{0, 0, 38, 0}, 0, false),
+ std::make_tuple(DecimalScalar128{32, 38, 0}, 32, false),
+ std::make_tuple(DecimalScalar128{kMaxDoubleInt, 38, 0},
+ static_cast<double>(kMaxDoubleInt), false),
+
+ // large scales
+ std::make_tuple(DecimalScalar128{123, 38, 16}, 1.23E-14, false),
+ std::make_tuple(DecimalScalar128{123, 38, 32}, 1.23E-30, false),
+ std::make_tuple(DecimalScalar128{1230, 38, 33}, 1.23E-30, false),
+ std::make_tuple(DecimalScalar128{123, 38, 38}, 1.23E-36, false),
+
+ // very small doubles
+ std::make_tuple(DecimalScalar128{0, 0, 38, 0}, std::numeric_limits<double>::min(),
+ false),
+ std::make_tuple(DecimalScalar128{0, 0, 38, 0}, -std::numeric_limits<double>::min(),
+ false),
+
+ // overflow due to large -ve double
+ std::make_tuple(DecimalScalar128{0, 0, 38, 0}, -std::numeric_limits<double>::max(),
+ true),
+ // overflow due to large +ve double
+ std::make_tuple(DecimalScalar128{0, 0, 38, 0}, std::numeric_limits<double>::max(),
+ true),
+ // overflow due to scaling up.
+ std::make_tuple(DecimalScalar128{0, 0, 38, 36}, 123.45, true),
+ // overflow due to precision.
+ std::make_tuple(DecimalScalar128{0, 0, 4, 2}, 12345.67, true),
+ };
+
+ for (auto iter : test_values) {
+ auto dscalar = std::get<0>(iter);
+ auto input = std::get<1>(iter);
+ auto expected_overflow = std::get<2>(iter);
+ bool overflow = false;
+
+ EXPECT_EQ(dscalar.value(), decimalops::FromDouble(input, dscalar.precision(),
+ dscalar.scale(), &overflow))
+ << " failed on input " << input;
+
+ if (expected_overflow) {
+ ASSERT_TRUE(overflow) << "overflow expected for input " << input;
+ } else {
+ ASSERT_FALSE(overflow) << "overflow not expected for input " << input;
+ }
+ }
+}
+
+#define EXPECT_FUZZY_EQ(x, y) \
+ EXPECT_TRUE(x - y <= 0.00001) << "expected " << x << ", got " << y
+
+TEST_F(TestDecimalSql, ToDouble) {
+ // expected, input, overflow
+ std::vector<std::tuple<double, DecimalScalar128>> test_values = {
+ // simple ones
+ std::make_tuple(-16.285, DecimalScalar128{-16285, 38, 3}),
+ std::make_tuple(-162.85, DecimalScalar128{-16285, 38, 2}),
+ std::make_tuple(-1.6285, DecimalScalar128{-16285, 38, 4}),
+
+ // large scales
+ std::make_tuple(1.23E-14, DecimalScalar128{123, 38, 16}),
+ std::make_tuple(1.23E-30, DecimalScalar128{123, 38, 32}),
+ std::make_tuple(1.23E-36, DecimalScalar128{123, 38, 38}),
+
+ // border cases
+ std::make_tuple(static_cast<double>(-kMaxDoubleInt),
+ DecimalScalar128{-kMaxDoubleInt, 38, 0}),
+ std::make_tuple(-32, DecimalScalar128{-32, 38, 0}),
+ std::make_tuple(0, DecimalScalar128{0, 0, 38, 0}),
+ std::make_tuple(32, DecimalScalar128{32, 38, 0}),
+ std::make_tuple(static_cast<double>(kMaxDoubleInt),
+ DecimalScalar128{kMaxDoubleInt, 38, 0}),
+ };
+ for (auto iter : test_values) {
+ auto input = std::get<1>(iter);
+ bool overflow = false;
+
+ EXPECT_FUZZY_EQ(std::get<0>(iter), decimalops::ToDouble(input, &overflow));
+ ASSERT_FALSE(overflow) << "overflow not expected for input " << input;
+ }
+}
+
+TEST_F(TestDecimalSql, FromInt64) {
+ // expected, input, overflow
+ std::vector<std::tuple<DecimalScalar128, int64_t, bool>> test_values = {
+ // simple cases
+ std::make_tuple(DecimalScalar128{-16000, 38, 3}, -16, false),
+ std::make_tuple(DecimalScalar128{-160000, 38, 4}, -16, false),
+ std::make_tuple(DecimalScalar128{-1600, 38, 2}, -16, false),
+
+ std::make_tuple(DecimalScalar128{16000, 38, 3}, 16, false),
+ std::make_tuple(DecimalScalar128{160000, 38, 4}, 16, false),
+ std::make_tuple(DecimalScalar128{1600, 38, 2}, 16, false),
+
+ // border cases
+ std::make_tuple(DecimalScalar128{INT64_MIN, 38, 0}, INT64_MIN, false),
+ std::make_tuple(DecimalScalar128{-32, 38, 0}, -32, false),
+ std::make_tuple(DecimalScalar128{0, 0, 38, 0}, 0, false),
+ std::make_tuple(DecimalScalar128{32, 38, 0}, 32, false),
+ std::make_tuple(DecimalScalar128{INT64_MAX, 38, 0}, INT64_MAX, false),
+
+ // large scales
+ std::make_tuple(DecimalScalar128{BasicDecimal128(123).IncreaseScaleBy(16), 38, 16},
+ 123, false),
+ std::make_tuple(DecimalScalar128{BasicDecimal128(123).IncreaseScaleBy(32), 38, 32},
+ 123, false),
+ std::make_tuple(DecimalScalar128{BasicDecimal128(-123).IncreaseScaleBy(16), 38, 16},
+ -123, false),
+ std::make_tuple(DecimalScalar128{BasicDecimal128(-123).IncreaseScaleBy(32), 38, 32},
+ -123, false),
+
+ // overflow due to scaling up.
+ std::make_tuple(DecimalScalar128{0, 0, 38, 36}, 123, true),
+ // overflow due to precision.
+ std::make_tuple(DecimalScalar128{0, 0, 4, 2}, 12345, true),
+ };
+
+ for (auto iter : test_values) {
+ auto dscalar = std::get<0>(iter);
+ auto input = std::get<1>(iter);
+ auto expected_overflow = std::get<2>(iter);
+ bool overflow = false;
+
+ EXPECT_EQ(dscalar.value(), decimalops::FromInt64(input, dscalar.precision(),
+ dscalar.scale(), &overflow))
+ << " failed on input " << input;
+
+ if (expected_overflow) {
+ ASSERT_TRUE(overflow) << "overflow expected for input " << input;
+ } else {
+ ASSERT_FALSE(overflow) << "overflow not expected for input " << input;
+ }
+ }
+}
+
+TEST_F(TestDecimalSql, ToInt64) {
+ // expected, input, overflow
+ std::vector<std::tuple<int64_t, DecimalScalar128, bool>> test_values = {
+ // simple ones
+ std::make_tuple(-16, DecimalScalar128{-16285, 38, 3}, false),
+ std::make_tuple(-163, DecimalScalar128{-16285, 38, 2}, false),
+ std::make_tuple(-2, DecimalScalar128{-16285, 38, 4}, false),
+
+ // border cases
+ std::make_tuple(INT64_MIN, DecimalScalar128{INT64_MIN, 38, 0}, false),
+ std::make_tuple(-32, DecimalScalar128{-32, 38, 0}, false),
+ std::make_tuple(0, DecimalScalar128{0, 0, 38, 0}, false),
+ std::make_tuple(32, DecimalScalar128{32, 38, 0}, false),
+ std::make_tuple(INT64_MAX, DecimalScalar128{INT64_MAX, 38, 0}, false),
+
+ // large scales
+ std::make_tuple(0, DecimalScalar128{123, 38, 16}, false),
+ std::make_tuple(0, DecimalScalar128{123, 38, 32}, false),
+ std::make_tuple(0, DecimalScalar128{123, 38, 38}, false),
+
+ // overflow test cases
+ // very large
+ std::make_tuple(0, DecimalScalar128{32768, 16, 38, 2}, true),
+ std::make_tuple(0, DecimalScalar128{INT64_MAX, UINT64_MAX, 38, 10}, true),
+ // very small
+ std::make_tuple(0, -DecimalScalar128{32768, 16, 38, 2}, true),
+ std::make_tuple(0, -DecimalScalar128{INT64_MAX, UINT64_MAX, 38, 10}, true),
+ };
+
+ for (auto iter : test_values) {
+ auto expected_value = std::get<0>(iter);
+ auto input = std::get<1>(iter);
+ auto expected_overflow = std::get<2>(iter);
+ bool overflow = false;
+
+ EXPECT_EQ(expected_value, decimalops::ToInt64(input, &overflow))
+ << " failed on input " << input;
+ if (expected_overflow) {
+ ASSERT_TRUE(overflow) << "overflow expected for input " << input;
+ } else {
+ ASSERT_FALSE(overflow) << "overflow not expected for input " << input;
+ }
+ }
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/precompiled/decimal_wrapper.cc b/src/arrow/cpp/src/gandiva/precompiled/decimal_wrapper.cc
new file mode 100644
index 000000000..082d5832d
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/decimal_wrapper.cc
@@ -0,0 +1,433 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/precompiled/decimal_ops.h"
+#include "gandiva/precompiled/types.h"
+
+extern "C" {
+
+FORCE_INLINE
+void add_large_decimal128_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, int64_t y_high, uint64_t y_low,
+ int32_t y_precision, int32_t y_scale,
+ int32_t out_precision, int32_t out_scale,
+ int64_t* out_high, uint64_t* out_low) {
+ gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale);
+ gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale);
+
+ arrow::BasicDecimal128 out = gandiva::decimalops::Add(x, y, out_precision, out_scale);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+}
+
+FORCE_INLINE
+void multiply_decimal128_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, int64_t y_high, uint64_t y_low,
+ int32_t y_precision, int32_t y_scale,
+ int32_t out_precision, int32_t out_scale,
+ int64_t* out_high, uint64_t* out_low) {
+ gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale);
+ gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale);
+ bool overflow;
+
+ // TODO ravindra: generate error on overflows (ARROW-4570).
+ arrow::BasicDecimal128 out =
+ gandiva::decimalops::Multiply(x, y, out_precision, out_scale, &overflow);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+}
+
+FORCE_INLINE
+void divide_decimal128_decimal128(int64_t context, int64_t x_high, uint64_t x_low,
+ int32_t x_precision, int32_t x_scale, int64_t y_high,
+ uint64_t y_low, int32_t y_precision, int32_t y_scale,
+ int32_t out_precision, int32_t out_scale,
+ int64_t* out_high, uint64_t* out_low) {
+ gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale);
+ gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale);
+ bool overflow;
+
+ // TODO ravindra: generate error on overflows (ARROW-4570).
+ arrow::BasicDecimal128 out =
+ gandiva::decimalops::Divide(context, x, y, out_precision, out_scale, &overflow);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+}
+
+FORCE_INLINE
+void mod_decimal128_decimal128(int64_t context, int64_t x_high, uint64_t x_low,
+ int32_t x_precision, int32_t x_scale, int64_t y_high,
+ uint64_t y_low, int32_t y_precision, int32_t y_scale,
+ int32_t out_precision, int32_t out_scale,
+ int64_t* out_high, uint64_t* out_low) {
+ gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale);
+ gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale);
+ bool overflow;
+
+ // TODO ravindra: generate error on overflows (ARROW-4570).
+ arrow::BasicDecimal128 out =
+ gandiva::decimalops::Mod(context, x, y, out_precision, out_scale, &overflow);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+}
+
+FORCE_INLINE
+int32_t compare_decimal128_decimal128_internal(int64_t x_high, uint64_t x_low,
+ int32_t x_precision, int32_t x_scale,
+ int64_t y_high, uint64_t y_low,
+ int32_t y_precision, int32_t y_scale) {
+ gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale);
+ gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale);
+
+ return gandiva::decimalops::Compare(x, y);
+}
+
+FORCE_INLINE
+void abs_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, int32_t x_scale,
+ int32_t out_precision, int32_t out_scale, int64_t* out_high,
+ uint64_t* out_low) {
+ gandiva::BasicDecimal128 x(x_high, x_low);
+ x.Abs();
+ *out_high = x.high_bits();
+ *out_low = x.low_bits();
+}
+
+FORCE_INLINE
+void ceil_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, int32_t x_scale,
+ int32_t out_precision, int32_t out_scale, int64_t* out_high,
+ uint64_t* out_low) {
+ gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale);
+
+ bool overflow = false;
+ auto out = gandiva::decimalops::Ceil(x, &overflow);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+}
+
+FORCE_INLINE
+void floor_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, int32_t out_precision, int32_t out_scale,
+ int64_t* out_high, uint64_t* out_low) {
+ gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale);
+
+ bool overflow = false;
+ auto out = gandiva::decimalops::Floor(x, &overflow);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+}
+
+FORCE_INLINE
+void round_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, int32_t out_precision, int32_t out_scale,
+ int64_t* out_high, uint64_t* out_low) {
+ gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale);
+
+ bool overflow = false;
+ auto out = gandiva::decimalops::Round(x, out_precision, 0, 0, &overflow);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+}
+
+FORCE_INLINE
+void round_decimal128_int32(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, int32_t rounding_scale,
+ int32_t out_precision, int32_t out_scale, int64_t* out_high,
+ uint64_t* out_low) {
+ gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale);
+
+ bool overflow = false;
+ auto out =
+ gandiva::decimalops::Round(x, out_precision, out_scale, rounding_scale, &overflow);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+}
+
+FORCE_INLINE
+void truncate_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, int32_t out_precision, int32_t out_scale,
+ int64_t* out_high, uint64_t* out_low) {
+ gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale);
+
+ bool overflow = false;
+ auto out = gandiva::decimalops::Truncate(x, out_precision, 0, 0, &overflow);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+}
+
+FORCE_INLINE
+void truncate_decimal128_int32(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, int32_t rounding_scale,
+ int32_t out_precision, int32_t out_scale,
+ int64_t* out_high, uint64_t* out_low) {
+ gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale);
+
+ bool overflow = false;
+ auto out = gandiva::decimalops::Truncate(x, out_precision, out_scale, rounding_scale,
+ &overflow);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+}
+
+FORCE_INLINE
+double castFLOAT8_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale) {
+ gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale);
+
+ bool overflow = false;
+ return gandiva::decimalops::ToDouble(x, &overflow);
+}
+
+FORCE_INLINE
+int64_t castBIGINT_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale) {
+ gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale);
+
+ bool overflow = false;
+ return gandiva::decimalops::ToInt64(x, &overflow);
+}
+
+FORCE_INLINE
+void castDECIMAL_int64(int64_t in, int32_t x_precision, int32_t x_scale,
+ int64_t* out_high, uint64_t* out_low) {
+ bool overflow = false;
+ auto out = gandiva::decimalops::FromInt64(in, x_precision, x_scale, &overflow);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+}
+
+FORCE_INLINE
+void castDECIMAL_int32(int32_t in, int32_t x_precision, int32_t x_scale,
+ int64_t* out_high, uint64_t* out_low) {
+ castDECIMAL_int64(in, x_precision, x_scale, out_high, out_low);
+}
+
+FORCE_INLINE
+void castDECIMAL_float64(double in, int32_t x_precision, int32_t x_scale,
+ int64_t* out_high, uint64_t* out_low) {
+ bool overflow = false;
+ auto out = gandiva::decimalops::FromDouble(in, x_precision, x_scale, &overflow);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+}
+
+FORCE_INLINE
+void castDECIMAL_float32(float in, int32_t x_precision, int32_t x_scale,
+ int64_t* out_high, uint64_t* out_low) {
+ castDECIMAL_float64(in, x_precision, x_scale, out_high, out_low);
+}
+
+FORCE_INLINE
+bool castDecimal_internal(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, int32_t out_precision, int32_t out_scale,
+ int64_t* out_high, int64_t* out_low) {
+ gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale);
+ bool overflow = false;
+ auto out = gandiva::decimalops::Convert(x, out_precision, out_scale, &overflow);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+ return overflow;
+}
+
+FORCE_INLINE
+void castDECIMAL_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, int32_t out_precision, int32_t out_scale,
+ int64_t* out_high, int64_t* out_low) {
+ castDecimal_internal(x_high, x_low, x_precision, x_scale, out_precision, out_scale,
+ out_high, out_low);
+}
+
+FORCE_INLINE
+void castDECIMALNullOnOverflow_decimal128(int64_t x_high, uint64_t x_low,
+ int32_t x_precision, int32_t x_scale,
+ bool x_isvalid, bool* out_valid,
+ int32_t out_precision, int32_t out_scale,
+ int64_t* out_high, int64_t* out_low) {
+ *out_valid = true;
+
+ if (!x_isvalid) {
+ *out_valid = false;
+ return;
+ }
+
+ if (castDecimal_internal(x_high, x_low, x_precision, x_scale, out_precision, out_scale,
+ out_high, out_low)) {
+ *out_valid = false;
+ }
+}
+
+FORCE_INLINE
+int32_t hash32_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, gdv_boolean x_isvalid) {
+ return x_isvalid
+ ? hash32_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, 0)
+ : 0;
+}
+
+FORCE_INLINE
+int32_t hash_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, gdv_boolean x_isvalid) {
+ return hash32_decimal128(x_high, x_low, x_precision, x_scale, x_isvalid);
+}
+
+FORCE_INLINE
+int64_t hash64_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, gdv_boolean x_isvalid) {
+ return x_isvalid
+ ? hash64_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, 0)
+ : 0;
+}
+
+FORCE_INLINE
+int32_t hash32WithSeed_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, gdv_boolean x_isvalid, int32_t seed,
+ gdv_boolean seed_isvalid) {
+ if (!x_isvalid) {
+ return seed;
+ }
+ return hash32_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, seed);
+}
+
+FORCE_INLINE
+int64_t hash64WithSeed_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, gdv_boolean x_isvalid, int64_t seed,
+ gdv_boolean seed_isvalid) {
+ if (!x_isvalid) {
+ return seed;
+ }
+ return hash64_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, seed);
+}
+
+FORCE_INLINE
+int32_t hash32AsDouble_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, gdv_boolean x_isvalid) {
+ return x_isvalid
+ ? hash32_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, 0)
+ : 0;
+}
+
+FORCE_INLINE
+int64_t hash64AsDouble_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, gdv_boolean x_isvalid) {
+ return x_isvalid
+ ? hash64_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, 0)
+ : 0;
+}
+
+FORCE_INLINE
+int32_t hash32AsDoubleWithSeed_decimal128(int64_t x_high, uint64_t x_low,
+ int32_t x_precision, int32_t x_scale,
+ gdv_boolean x_isvalid, int32_t seed,
+ gdv_boolean seed_isvalid) {
+ if (!x_isvalid) {
+ return seed;
+ }
+ return hash32_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, seed);
+}
+
+FORCE_INLINE
+int64_t hash64AsDoubleWithSeed_decimal128(int64_t x_high, uint64_t x_low,
+ int32_t x_precision, int32_t x_scale,
+ gdv_boolean x_isvalid, int64_t seed,
+ gdv_boolean seed_isvalid) {
+ if (!x_isvalid) {
+ return seed;
+ }
+ return hash64_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, seed);
+}
+
+FORCE_INLINE
+gdv_boolean isnull_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, gdv_boolean x_isvalid) {
+ return !x_isvalid;
+}
+
+FORCE_INLINE
+gdv_boolean isnotnull_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, gdv_boolean x_isvalid) {
+ return x_isvalid;
+}
+
+FORCE_INLINE
+gdv_boolean isnumeric_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
+ int32_t x_scale, gdv_boolean x_isvalid) {
+ return x_isvalid;
+}
+
+FORCE_INLINE
+gdv_boolean is_not_distinct_from_decimal128_decimal128(
+ int64_t x_high, uint64_t x_low, int32_t x_precision, int32_t x_scale,
+ gdv_boolean x_isvalid, int64_t y_high, uint64_t y_low, int32_t y_precision,
+ int32_t y_scale, gdv_boolean y_isvalid) {
+ if (x_isvalid != y_isvalid) {
+ return false;
+ }
+ if (!x_isvalid) {
+ return true;
+ }
+ return 0 == compare_decimal128_decimal128_internal(x_high, x_low, x_precision, x_scale,
+ y_high, y_low, y_precision, y_scale);
+}
+
+FORCE_INLINE
+gdv_boolean is_distinct_from_decimal128_decimal128(int64_t x_high, uint64_t x_low,
+ int32_t x_precision, int32_t x_scale,
+ gdv_boolean x_isvalid, int64_t y_high,
+ uint64_t y_low, int32_t y_precision,
+ int32_t y_scale,
+ gdv_boolean y_isvalid) {
+ return !is_not_distinct_from_decimal128_decimal128(x_high, x_low, x_precision, x_scale,
+ x_isvalid, y_high, y_low,
+ y_precision, y_scale, y_isvalid);
+}
+
+FORCE_INLINE
+void castDECIMAL_utf8(int64_t context, const char* in, int32_t in_length,
+ int32_t out_precision, int32_t out_scale, int64_t* out_high,
+ uint64_t* out_low) {
+ int64_t dec_high_from_str;
+ uint64_t dec_low_from_str;
+ int32_t precision_from_str;
+ int32_t scale_from_str;
+ int32_t status =
+ gdv_fn_dec_from_string(context, in, in_length, &precision_from_str, &scale_from_str,
+ &dec_high_from_str, &dec_low_from_str);
+ if (status != 0) {
+ return;
+ }
+
+ gandiva::BasicDecimalScalar128 x({dec_high_from_str, dec_low_from_str},
+ precision_from_str, scale_from_str);
+ bool overflow = false;
+ auto out = gandiva::decimalops::Convert(x, out_precision, out_scale, &overflow);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+}
+
+FORCE_INLINE
+char* castVARCHAR_decimal128_int64(int64_t context, int64_t x_high, uint64_t x_low,
+ int32_t x_precision, int32_t x_scale,
+ int64_t out_len_param, int32_t* out_length) {
+ int32_t full_dec_str_len;
+ char* dec_str =
+ gdv_fn_dec_to_string(context, x_high, x_low, x_scale, &full_dec_str_len);
+ int32_t trunc_dec_str_len =
+ out_len_param < full_dec_str_len ? out_len_param : full_dec_str_len;
+ *out_length = trunc_dec_str_len;
+ return dec_str;
+}
+
+} // extern "C"
diff --git a/src/arrow/cpp/src/gandiva/precompiled/epoch_time_point.h b/src/arrow/cpp/src/gandiva/precompiled/epoch_time_point.h
new file mode 100644
index 000000000..45cfb28ca
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/epoch_time_point.h
@@ -0,0 +1,118 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+// TODO(wesm): IR compilation does not have any include directories set
+#include "../../arrow/vendored/datetime/date.h"
+
+bool is_leap_year(int yy);
+bool did_days_overflow(arrow_vendored::date::year_month_day ymd);
+int last_possible_day_in_month(int month, int year);
+
+// A point of time measured in millis since epoch.
+class EpochTimePoint {
+ public:
+ explicit EpochTimePoint(std::chrono::milliseconds millis_since_epoch)
+ : tp_(millis_since_epoch) {}
+
+ explicit EpochTimePoint(int64_t millis_since_epoch)
+ : EpochTimePoint(std::chrono::milliseconds(millis_since_epoch)) {}
+
+ int TmYear() const { return static_cast<int>(YearMonthDay().year()) - 1900; }
+
+ int TmMon() const { return static_cast<unsigned int>(YearMonthDay().month()) - 1; }
+
+ int TmYday() const {
+ auto to_days = arrow_vendored::date::floor<arrow_vendored::date::days>(tp_);
+ auto first_day_in_year = arrow_vendored::date::sys_days{
+ YearMonthDay().year() / arrow_vendored::date::jan / 1};
+ return (to_days - first_day_in_year).count();
+ }
+
+ int TmMday() const { return static_cast<unsigned int>(YearMonthDay().day()); }
+
+ int TmWday() const {
+ auto to_days = arrow_vendored::date::floor<arrow_vendored::date::days>(tp_);
+ return (arrow_vendored::date::weekday{to_days} - // NOLINT
+ arrow_vendored::date::Sunday)
+ .count();
+ }
+
+ int TmHour() const { return static_cast<int>(TimeOfDay().hours().count()); }
+
+ int TmMin() const { return static_cast<int>(TimeOfDay().minutes().count()); }
+
+ int TmSec() const {
+ // TODO(wesm): UNIX y2k issue on int=gdv_int32 platforms
+ return static_cast<int>(TimeOfDay().seconds().count());
+ }
+
+ EpochTimePoint AddYears(int num_years) const {
+ auto ymd = YearMonthDay() + arrow_vendored::date::years(num_years);
+ return EpochTimePoint((arrow_vendored::date::sys_days{ymd} + // NOLINT
+ TimeOfDay().to_duration())
+ .time_since_epoch());
+ }
+
+ EpochTimePoint AddMonths(int num_months) const {
+ auto ymd = YearMonthDay() + arrow_vendored::date::months(num_months);
+
+ EpochTimePoint tp = EpochTimePoint((arrow_vendored::date::sys_days{ymd} + // NOLINT
+ TimeOfDay().to_duration())
+ .time_since_epoch());
+
+ if (did_days_overflow(ymd)) {
+ int days_to_offset =
+ last_possible_day_in_month(static_cast<int>(ymd.year()),
+ static_cast<unsigned int>(ymd.month())) -
+ static_cast<unsigned int>(ymd.day());
+ tp = tp.AddDays(days_to_offset);
+ }
+ return tp;
+ }
+
+ EpochTimePoint AddDays(int num_days) const {
+ auto days_since_epoch = arrow_vendored::date::sys_days{YearMonthDay()} + // NOLINT
+ arrow_vendored::date::days(num_days);
+ return EpochTimePoint(
+ (days_since_epoch + TimeOfDay().to_duration()).time_since_epoch());
+ }
+
+ EpochTimePoint ClearTimeOfDay() const {
+ return EpochTimePoint((tp_ - TimeOfDay().to_duration()).time_since_epoch());
+ }
+
+ bool operator==(const EpochTimePoint& other) const { return tp_ == other.tp_; }
+
+ int64_t MillisSinceEpoch() const { return tp_.time_since_epoch().count(); }
+
+ arrow_vendored::date::time_of_day<std::chrono::milliseconds> TimeOfDay() const {
+ auto millis_since_midnight =
+ tp_ - arrow_vendored::date::floor<arrow_vendored::date::days>(tp_);
+ return arrow_vendored::date::time_of_day<std::chrono::milliseconds>(
+ millis_since_midnight);
+ }
+
+ private:
+ arrow_vendored::date::year_month_day YearMonthDay() const {
+ return arrow_vendored::date::year_month_day{
+ arrow_vendored::date::floor<arrow_vendored::date::days>(tp_)}; // NOLINT
+ }
+
+ std::chrono::time_point<std::chrono::system_clock, std::chrono::milliseconds> tp_;
+};
diff --git a/src/arrow/cpp/src/gandiva/precompiled/epoch_time_point_test.cc b/src/arrow/cpp/src/gandiva/precompiled/epoch_time_point_test.cc
new file mode 100644
index 000000000..9180aac07
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/epoch_time_point_test.cc
@@ -0,0 +1,103 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <ctime>
+
+#include <gtest/gtest.h>
+#include "./epoch_time_point.h"
+#include "gandiva/precompiled/testing.h"
+#include "gandiva/precompiled/types.h"
+
+#include "gandiva/date_utils.h"
+
+namespace gandiva {
+
+TEST(TestEpochTimePoint, TestTm) {
+ auto ts = StringToTimestamp("2015-05-07 10:20:34");
+ EpochTimePoint tp(ts);
+
+ struct tm* tm_ptr;
+#if defined(_WIN32)
+ __time64_t tsec = ts / 1000;
+ tm_ptr = _gmtime64(&tsec);
+#else
+ struct tm tm;
+ time_t tsec = ts / 1000;
+ tm_ptr = gmtime_r(&tsec, &tm);
+#endif
+
+ EXPECT_EQ(tp.TmYear(), tm_ptr->tm_year);
+ EXPECT_EQ(tp.TmMon(), tm_ptr->tm_mon);
+ EXPECT_EQ(tp.TmYday(), tm_ptr->tm_yday);
+ EXPECT_EQ(tp.TmMday(), tm_ptr->tm_mday);
+ EXPECT_EQ(tp.TmWday(), tm_ptr->tm_wday);
+ EXPECT_EQ(tp.TmHour(), tm_ptr->tm_hour);
+ EXPECT_EQ(tp.TmMin(), tm_ptr->tm_min);
+ EXPECT_EQ(tp.TmSec(), tm_ptr->tm_sec);
+}
+
+TEST(TestEpochTimePoint, TestAddYears) {
+ EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddYears(2),
+ EpochTimePoint(StringToTimestamp("2017-05-05 10:20:34")));
+
+ EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddYears(0),
+ EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")));
+
+ EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddYears(-1),
+ EpochTimePoint(StringToTimestamp("2014-05-05 10:20:34")));
+}
+
+TEST(TestEpochTimePoint, TestAddMonths) {
+ EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddMonths(2),
+ EpochTimePoint(StringToTimestamp("2015-07-05 10:20:34")));
+
+ EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddMonths(11),
+ EpochTimePoint(StringToTimestamp("2016-04-05 10:20:34")));
+
+ EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddMonths(0),
+ EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")));
+
+ EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddMonths(-1),
+ EpochTimePoint(StringToTimestamp("2015-04-05 10:20:34")));
+
+ EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddMonths(-10),
+ EpochTimePoint(StringToTimestamp("2014-07-05 10:20:34")));
+}
+
+TEST(TestEpochTimePoint, TestAddDays) {
+ EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddDays(2),
+ EpochTimePoint(StringToTimestamp("2015-05-07 10:20:34")));
+
+ EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddDays(11),
+ EpochTimePoint(StringToTimestamp("2015-05-16 10:20:34")));
+
+ EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddDays(0),
+ EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")));
+
+ EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddDays(-1),
+ EpochTimePoint(StringToTimestamp("2015-05-04 10:20:34")));
+
+ EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddDays(-10),
+ EpochTimePoint(StringToTimestamp("2015-04-25 10:20:34")));
+}
+
+TEST(TestEpochTimePoint, TestClearTimeOfDay) {
+ EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).ClearTimeOfDay(),
+ EpochTimePoint(StringToTimestamp("2015-05-05 00:00:00")));
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/precompiled/extended_math_ops.cc b/src/arrow/cpp/src/gandiva/precompiled/extended_math_ops.cc
new file mode 100644
index 000000000..365b08a6d
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/extended_math_ops.cc
@@ -0,0 +1,410 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef M_PI
+#define M_PI 3.14159265358979323846
+#endif
+
+#include "arrow/util/logging.h"
+#include "gandiva/precompiled/decimal_ops.h"
+
+extern "C" {
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include "./types.h"
+
+// Expand the inner fn for types that support extended math.
+#define ENUMERIC_TYPES_UNARY(INNER, OUT_TYPE) \
+ INNER(int32, OUT_TYPE) \
+ INNER(uint32, OUT_TYPE) \
+ INNER(int64, OUT_TYPE) \
+ INNER(uint64, OUT_TYPE) \
+ INNER(float32, OUT_TYPE) \
+ INNER(float64, OUT_TYPE)
+
+// Cubic root
+#define CBRT(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE cbrt_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_float64>(cbrtl(static_cast<long double>(in))); \
+ }
+
+ENUMERIC_TYPES_UNARY(CBRT, float64)
+
+// Exponent
+#define EXP(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE exp_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_float64>(expl(static_cast<long double>(in))); \
+ }
+
+ENUMERIC_TYPES_UNARY(EXP, float64)
+
+// log
+#define LOG(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE log_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_float64>(logl(static_cast<long double>(in))); \
+ }
+
+ENUMERIC_TYPES_UNARY(LOG, float64)
+
+// log base 10
+#define LOG10(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE log10_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_float64>(log10l(static_cast<long double>(in))); \
+ }
+
+#define LOGL(VALUE) static_cast<gdv_float64>(logl(static_cast<long double>(VALUE)))
+
+ENUMERIC_TYPES_UNARY(LOG10, float64)
+
+FORCE_INLINE
+void set_error_for_logbase(int64_t execution_context, double base) {
+ char const* prefix = "divide by zero error with log of base";
+ int size = static_cast<int>(strlen(prefix)) + 64;
+ char* error = reinterpret_cast<char*>(malloc(size));
+ snprintf(error, size, "%s %f", prefix, base);
+ gdv_fn_context_set_error_msg(execution_context, error);
+ free(static_cast<char*>(error));
+}
+
+// log with base
+#define LOG_WITH_BASE(IN_TYPE1, IN_TYPE2, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE log_##IN_TYPE1##_##IN_TYPE2(gdv_int64 context, gdv_##IN_TYPE1 base, \
+ gdv_##IN_TYPE2 value) { \
+ gdv_##OUT_TYPE log_of_base = LOGL(base); \
+ if (log_of_base == 0) { \
+ set_error_for_logbase(context, static_cast<gdv_float64>(base)); \
+ return 0; \
+ } \
+ return LOGL(value) / LOGL(base); \
+ }
+
+LOG_WITH_BASE(int32, int32, float64)
+LOG_WITH_BASE(uint32, uint32, float64)
+LOG_WITH_BASE(int64, int64, float64)
+LOG_WITH_BASE(uint64, uint64, float64)
+LOG_WITH_BASE(float32, float32, float64)
+LOG_WITH_BASE(float64, float64, float64)
+
+// Sin
+#define SIN(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE sin_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_##OUT_TYPE>(sin(static_cast<long double>(in))); \
+ }
+ENUMERIC_TYPES_UNARY(SIN, float64)
+
+// Asin
+#define ASIN(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE asin_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_##OUT_TYPE>(asin(static_cast<long double>(in))); \
+ }
+ENUMERIC_TYPES_UNARY(ASIN, float64)
+
+// Cos
+#define COS(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE cos_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_##OUT_TYPE>(cos(static_cast<long double>(in))); \
+ }
+ENUMERIC_TYPES_UNARY(COS, float64)
+
+// Acos
+#define ACOS(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE acos_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_##OUT_TYPE>(acos(static_cast<long double>(in))); \
+ }
+ENUMERIC_TYPES_UNARY(ACOS, float64)
+
+// Tan
+#define TAN(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE tan_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_##OUT_TYPE>(tan(static_cast<long double>(in))); \
+ }
+ENUMERIC_TYPES_UNARY(TAN, float64)
+
+// Atan
+#define ATAN(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE atan_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_##OUT_TYPE>(atan(static_cast<long double>(in))); \
+ }
+ENUMERIC_TYPES_UNARY(ATAN, float64)
+
+// Sinh
+#define SINH(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE sinh_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_##OUT_TYPE>(sinh(static_cast<long double>(in))); \
+ }
+ENUMERIC_TYPES_UNARY(SINH, float64)
+
+// Cosh
+#define COSH(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE cosh_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_##OUT_TYPE>(cosh(static_cast<long double>(in))); \
+ }
+ENUMERIC_TYPES_UNARY(COSH, float64)
+
+// Tanh
+#define TANH(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE tanh_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_##OUT_TYPE>(tanh(static_cast<long double>(in))); \
+ }
+ENUMERIC_TYPES_UNARY(TANH, float64)
+
+// Atan2
+#define ATAN2(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE atan2_##IN_TYPE##_##IN_TYPE(gdv_##IN_TYPE in1, gdv_##IN_TYPE in2) { \
+ return static_cast<gdv_##OUT_TYPE>( \
+ atan2(static_cast<long double>(in1), static_cast<long double>(in2))); \
+ }
+ENUMERIC_TYPES_UNARY(ATAN2, float64)
+
+// Cot
+#define COT(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE cot_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_##OUT_TYPE>(tan(M_PI / 2 - static_cast<long double>(in))); \
+ }
+ENUMERIC_TYPES_UNARY(COT, float64)
+
+// Radians
+#define RADIANS(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE radians_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_##OUT_TYPE>(static_cast<long double>(in) * M_PI / 180.0); \
+ }
+ENUMERIC_TYPES_UNARY(RADIANS, float64)
+
+// Degrees
+#define DEGREES(IN_TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE degrees_##IN_TYPE(gdv_##IN_TYPE in) { \
+ return static_cast<gdv_##OUT_TYPE>(static_cast<long double>(in) * 180.0 / M_PI); \
+ }
+ENUMERIC_TYPES_UNARY(DEGREES, float64)
+
+// power
+#define POWER(IN_TYPE1, IN_TYPE2, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE power_##IN_TYPE1##_##IN_TYPE2(gdv_##IN_TYPE1 in1, gdv_##IN_TYPE2 in2) { \
+ return static_cast<gdv_float64>(powl(in1, in2)); \
+ }
+POWER(float64, float64, float64)
+
+FORCE_INLINE
+gdv_int32 round_int32(gdv_int32 num) { return num; }
+
+FORCE_INLINE
+gdv_int64 round_int64(gdv_int64 num) { return num; }
+
+// rounds the number to the nearest integer
+#define ROUND_DECIMAL(TYPE) \
+ FORCE_INLINE \
+ gdv_##TYPE round_##TYPE(gdv_##TYPE num) { \
+ return static_cast<gdv_##TYPE>(trunc(num + ((num >= 0) ? 0.5 : -0.5))); \
+ }
+
+ROUND_DECIMAL(float32)
+ROUND_DECIMAL(float64)
+
+// rounds the number to the given scale
+#define ROUND_DECIMAL_TO_SCALE(TYPE) \
+ FORCE_INLINE \
+ gdv_##TYPE round_##TYPE##_int32(gdv_##TYPE number, gdv_int32 out_scale) { \
+ gdv_float64 scale_multiplier = get_scale_multiplier(out_scale); \
+ return static_cast<gdv_##TYPE>( \
+ trunc(number * scale_multiplier + ((number >= 0) ? 0.5 : -0.5)) / \
+ scale_multiplier); \
+ }
+
+ROUND_DECIMAL_TO_SCALE(float32)
+ROUND_DECIMAL_TO_SCALE(float64)
+
+FORCE_INLINE
+gdv_int32 round_int32_int32(gdv_int32 number, gdv_int32 precision) {
+ // for integers, there is nothing following the decimal point,
+ // so round() always returns the same number if precision >= 0
+ if (precision >= 0) {
+ return number;
+ }
+ gdv_int32 abs_precision = -precision;
+ // This is to ensure that there is no overflow while calculating 10^precision, 9 is
+ // the smallest N for which 10^N does not fit into 32 bits, so we can safely return 0
+ if (abs_precision > 9) {
+ return 0;
+ }
+ gdv_int32 num_sign = (number > 0) ? 1 : -1;
+ gdv_int32 abs_number = number * num_sign;
+ gdv_int32 power_of_10 = static_cast<gdv_int32>(get_power_of_10(abs_precision));
+ gdv_int32 remainder = abs_number % power_of_10;
+ abs_number -= remainder;
+ // if the fractional part of the quotient >= 0.5, round to next higher integer
+ if (remainder >= power_of_10 / 2) {
+ abs_number += power_of_10;
+ }
+ return abs_number * num_sign;
+}
+
+FORCE_INLINE
+gdv_int64 round_int64_int32(gdv_int64 number, gdv_int32 precision) {
+ // for long integers, there is nothing following the decimal point,
+ // so round() always returns the same number if precision >= 0
+ if (precision >= 0) {
+ return number;
+ }
+ gdv_int32 abs_precision = -precision;
+ // This is to ensure that there is no overflow while calculating 10^precision, 19 is
+ // the smallest N for which 10^N does not fit into 64 bits, so we can safely return 0
+ if (abs_precision > 18) {
+ return 0;
+ }
+ gdv_int32 num_sign = (number > 0) ? 1 : -1;
+ gdv_int64 abs_number = number * num_sign;
+ gdv_int64 power_of_10 = get_power_of_10(abs_precision);
+ gdv_int64 remainder = abs_number % power_of_10;
+ abs_number -= remainder;
+ // if the fractional part of the quotient >= 0.5, round to next higher integer
+ if (remainder >= power_of_10 / 2) {
+ abs_number += power_of_10;
+ }
+ return abs_number * num_sign;
+}
+
+FORCE_INLINE
+gdv_int64 get_power_of_10(gdv_int32 exp) {
+ DCHECK_GE(exp, 0);
+ DCHECK_LE(exp, 18);
+ static const gdv_int64 power_of_10[] = {1,
+ 10,
+ 100,
+ 1000,
+ 10000,
+ 100000,
+ 1000000,
+ 10000000,
+ 100000000,
+ 1000000000,
+ 10000000000,
+ 100000000000,
+ 1000000000000,
+ 10000000000000,
+ 100000000000000,
+ 1000000000000000,
+ 10000000000000000,
+ 100000000000000000,
+ 1000000000000000000};
+ return power_of_10[exp];
+}
+
+FORCE_INLINE
+gdv_int64 truncate_int64_int32(gdv_int64 in, gdv_int32 out_scale) {
+ bool overflow = false;
+ arrow::BasicDecimal128 decimal = gandiva::decimalops::FromInt64(in, 38, 0, &overflow);
+ arrow::BasicDecimal128 decimal_with_outscale =
+ gandiva::decimalops::Truncate(gandiva::BasicDecimalScalar128(decimal, 38, 0), 38,
+ out_scale, out_scale, &overflow);
+ if (out_scale < 0) {
+ out_scale = 0;
+ }
+ return gandiva::decimalops::ToInt64(
+ gandiva::BasicDecimalScalar128(decimal_with_outscale, 38, out_scale), &overflow);
+}
+
+FORCE_INLINE
+gdv_float64 get_scale_multiplier(gdv_int32 scale) {
+ static const gdv_float64 values[] = {1.0,
+ 10.0,
+ 100.0,
+ 1000.0,
+ 10000.0,
+ 100000.0,
+ 1000000.0,
+ 10000000.0,
+ 100000000.0,
+ 1000000000.0,
+ 10000000000.0,
+ 100000000000.0,
+ 1000000000000.0,
+ 10000000000000.0,
+ 100000000000000.0,
+ 1000000000000000.0,
+ 10000000000000000.0,
+ 100000000000000000.0,
+ 1000000000000000000.0,
+ 10000000000000000000.0};
+ if (scale >= 0 && scale < 20) {
+ return values[scale];
+ }
+ return power_float64_float64(10.0, scale);
+}
+
+// returns the binary representation of a given integer (e.g. 928 -> 1110100000)
+#define BIN_INTEGER(IN_TYPE) \
+ FORCE_INLINE \
+ const char* bin_##IN_TYPE(int64_t context, gdv_##IN_TYPE value, int32_t* out_len) { \
+ *out_len = 0; \
+ int32_t len = 8 * sizeof(value); \
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, len)); \
+ if (ret == nullptr) { \
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output"); \
+ return ""; \
+ } \
+ /* handle case when value is zero */ \
+ if (value == 0) { \
+ *out_len = 1; \
+ ret[0] = '0'; \
+ return ret; \
+ } \
+ /* generate binary representation iteratively */ \
+ gdv_u##IN_TYPE i; \
+ int8_t count = 0; \
+ bool first = false; /* flag for not printing left zeros in positive numbers */ \
+ for (i = static_cast<gdv_u##IN_TYPE>(1) << (len - 1); i > 0; i = i / 2) { \
+ if ((value & i) != 0) { \
+ ret[count] = '1'; \
+ if (!first) first = true; \
+ } else { \
+ if (!first) continue; \
+ ret[count] = '0'; \
+ } \
+ count += 1; \
+ } \
+ *out_len = count; \
+ return ret; \
+ }
+
+BIN_INTEGER(int32)
+BIN_INTEGER(int64)
+
+#undef BIN_INTEGER
+
+} // extern "C"
diff --git a/src/arrow/cpp/src/gandiva/precompiled/extended_math_ops_test.cc b/src/arrow/cpp/src/gandiva/precompiled/extended_math_ops_test.cc
new file mode 100644
index 000000000..147b4035c
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/extended_math_ops_test.cc
@@ -0,0 +1,349 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef M_PI
+#define M_PI 3.14159265358979323846
+#endif
+
+#include <gtest/gtest.h>
+#include <cmath>
+#include "gandiva/execution_context.h"
+#include "gandiva/precompiled/types.h"
+
+namespace gandiva {
+
+static const double MAX_ERROR = 0.00005;
+
+void VerifyFuzzyEquals(double actual, double expected, double max_error = MAX_ERROR) {
+ EXPECT_TRUE(fabs(actual - expected) < max_error) << actual << " != " << expected;
+}
+
+TEST(TestExtendedMathOps, TestCbrt) {
+ VerifyFuzzyEquals(cbrt_int32(27), 3);
+ VerifyFuzzyEquals(cbrt_int64(27), 3);
+ VerifyFuzzyEquals(cbrt_float32(27), 3);
+ VerifyFuzzyEquals(cbrt_float64(27), 3);
+ VerifyFuzzyEquals(cbrt_float64(-27), -3);
+
+ VerifyFuzzyEquals(cbrt_float32(15.625), 2.5);
+ VerifyFuzzyEquals(cbrt_float64(15.625), 2.5);
+}
+
+TEST(TestExtendedMathOps, TestExp) {
+ double val = 20.085536923187668;
+
+ VerifyFuzzyEquals(exp_int32(3), val);
+ VerifyFuzzyEquals(exp_int64(3), val);
+ VerifyFuzzyEquals(exp_float32(3), val);
+ VerifyFuzzyEquals(exp_float64(3), val);
+}
+
+TEST(TestExtendedMathOps, TestLog) {
+ double val = 4.1588830833596715;
+
+ VerifyFuzzyEquals(log_int32(64), val);
+ VerifyFuzzyEquals(log_int64(64), val);
+ VerifyFuzzyEquals(log_float32(64), val);
+ VerifyFuzzyEquals(log_float64(64), val);
+
+ EXPECT_EQ(log_int32(0), -std::numeric_limits<double>::infinity());
+}
+
+TEST(TestExtendedMathOps, TestLog10) {
+ VerifyFuzzyEquals(log10_int32(100), 2);
+ VerifyFuzzyEquals(log10_int64(100), 2);
+ VerifyFuzzyEquals(log10_float32(100), 2);
+ VerifyFuzzyEquals(log10_float64(100), 2);
+}
+
+TEST(TestExtendedMathOps, TestPower) {
+ VerifyFuzzyEquals(power_float64_float64(2, 5.4), 42.22425314473263);
+ VerifyFuzzyEquals(power_float64_float64(5.4, 2), 29.160000000000004);
+}
+
+TEST(TestExtendedMathOps, TestLogWithBase) {
+ gandiva::ExecutionContext context;
+ gdv_float64 out =
+ log_int32_int32(reinterpret_cast<gdv_int64>(&context), 1 /*base*/, 10 /*value*/);
+ VerifyFuzzyEquals(out, 0);
+ EXPECT_EQ(context.has_error(), true);
+ EXPECT_TRUE(context.get_error().find("divide by zero error") != std::string::npos)
+ << context.get_error();
+
+ gandiva::ExecutionContext context1;
+ out = log_int32_int32(reinterpret_cast<gdv_int64>(&context), 2 /*base*/, 64 /*value*/);
+ VerifyFuzzyEquals(out, 6);
+ EXPECT_EQ(context1.has_error(), false);
+}
+
+TEST(TestExtendedMathOps, TestRoundDecimal) {
+ EXPECT_FLOAT_EQ(round_float32(1234.245f), 1234);
+ EXPECT_FLOAT_EQ(round_float32(-11.7892f), -12);
+ EXPECT_FLOAT_EQ(round_float32(1.4999999f), 1);
+ EXPECT_EQ(std::signbit(round_float32(0)), 0);
+ EXPECT_FLOAT_EQ(round_float32_int32(1234.789f, 2), 1234.79f);
+ EXPECT_FLOAT_EQ(round_float32_int32(1234.12345f, -3), 1000);
+ EXPECT_FLOAT_EQ(round_float32_int32(-1234.4567f, 3), -1234.457f);
+ EXPECT_FLOAT_EQ(round_float32_int32(-1234.4567f, -3), -1000);
+ EXPECT_FLOAT_EQ(round_float32_int32(1234.4567f, 0), 1234);
+ EXPECT_FLOAT_EQ(round_float32_int32(1.5499999523162842f, 1), 1.5f);
+ EXPECT_EQ(std::signbit(round_float32_int32(0, 5)), 0);
+ EXPECT_FLOAT_EQ(round_float32_int32(static_cast<float>(1.55), 1), 1.5f);
+ EXPECT_FLOAT_EQ(round_float32_int32(static_cast<float>(9.134123), 2), 9.13f);
+ EXPECT_FLOAT_EQ(round_float32_int32(static_cast<float>(-1.923), 1), -1.9f);
+
+ VerifyFuzzyEquals(round_float64(1234.245), 1234);
+ VerifyFuzzyEquals(round_float64(-11.7892), -12);
+ VerifyFuzzyEquals(round_float64(1.4999999), 1);
+ EXPECT_EQ(std::signbit(round_float64(0)), 0);
+ VerifyFuzzyEquals(round_float64_int32(1234.789, 2), 1234.79);
+ VerifyFuzzyEquals(round_float64_int32(1234.12345, -3), 1000);
+ VerifyFuzzyEquals(round_float64_int32(-1234.4567, 3), -1234.457);
+ VerifyFuzzyEquals(round_float64_int32(-1234.4567, -3), -1000);
+ VerifyFuzzyEquals(round_float64_int32(1234.4567, 0), 1234);
+ EXPECT_EQ(std::signbit(round_float64_int32(0, -2)), 0);
+ VerifyFuzzyEquals(round_float64_int32((double)INT_MAX + 1, 0), (double)INT_MAX + 1);
+ VerifyFuzzyEquals(round_float64_int32((double)INT_MIN - 1, 0), (double)INT_MIN - 1);
+}
+
+TEST(TestExtendedMathOps, TestRound) {
+ EXPECT_EQ(round_int32(21134), 21134);
+ EXPECT_EQ(round_int32(-132422), -132422);
+ EXPECT_EQ(round_int32_int32(7589, -1), 7590);
+ EXPECT_EQ(round_int32_int32(8532, -2), 8500);
+ EXPECT_EQ(round_int32_int32(-8579, -1), -8580);
+ EXPECT_EQ(round_int32_int32(-8612, -2), -8600);
+ EXPECT_EQ(round_int32_int32(758, 2), 758);
+ EXPECT_EQ(round_int32_int32(8612, -5), 0);
+
+ EXPECT_EQ(round_int64(3453562312), 3453562312);
+ EXPECT_EQ(round_int64(-23453462343), -23453462343);
+ EXPECT_EQ(round_int64_int32(3453562312, -2), 3453562300);
+ EXPECT_EQ(round_int64_int32(3453562343, -5), 3453600000);
+ EXPECT_EQ(round_int64_int32(345353425343, 12), 345353425343);
+ EXPECT_EQ(round_int64_int32(-23453462343, -4), -23453460000);
+ EXPECT_EQ(round_int64_int32(-23453462343, -5), -23453500000);
+ EXPECT_EQ(round_int64_int32(345353425343, -12), 0);
+}
+
+TEST(TestExtendedMathOps, TestTruncate) {
+ EXPECT_EQ(truncate_int64_int32(1234, 4), 1234);
+ EXPECT_EQ(truncate_int64_int32(-1234, 4), -1234);
+ EXPECT_EQ(truncate_int64_int32(1234, -4), 0);
+ EXPECT_EQ(truncate_int64_int32(-1234, -2), -1200);
+ EXPECT_EQ(truncate_int64_int32(8124674407369523212, 0), 8124674407369523212);
+ EXPECT_EQ(truncate_int64_int32(8124674407369523212, -2), 8124674407369523200);
+}
+
+TEST(TestExtendedMathOps, TestTrigonometricFunctions) {
+ auto pi_float = static_cast<float>(M_PI);
+ // Sin functions
+ VerifyFuzzyEquals(sin_float32(0), sin(0));
+ VerifyFuzzyEquals(sin_float32(0), sin(0));
+ VerifyFuzzyEquals(sin_float32(pi_float / 2), sin(M_PI / 2));
+ VerifyFuzzyEquals(sin_float32(pi_float), sin(M_PI));
+ VerifyFuzzyEquals(sin_float32(-pi_float / 2), sin(-M_PI / 2));
+ VerifyFuzzyEquals(sin_float64(0), sin(0));
+ VerifyFuzzyEquals(sin_float64(M_PI / 2), sin(M_PI / 2));
+ VerifyFuzzyEquals(sin_float64(M_PI), sin(M_PI));
+ VerifyFuzzyEquals(sin_float64(-M_PI / 2), sin(-M_PI / 2));
+ VerifyFuzzyEquals(sin_int32(0), sin(0));
+ VerifyFuzzyEquals(sin_int64(0), sin(0));
+
+ // Cos functions
+ VerifyFuzzyEquals(cos_float32(0), cos(0));
+ VerifyFuzzyEquals(cos_float32(pi_float / 2), cos(M_PI / 2));
+ VerifyFuzzyEquals(cos_float32(pi_float), cos(M_PI));
+ VerifyFuzzyEquals(cos_float32(-pi_float / 2), cos(-M_PI / 2));
+ VerifyFuzzyEquals(cos_float64(0), cos(0));
+ VerifyFuzzyEquals(cos_float64(M_PI / 2), cos(M_PI / 2));
+ VerifyFuzzyEquals(cos_float64(M_PI), cos(M_PI));
+ VerifyFuzzyEquals(cos_float64(-M_PI / 2), cos(-M_PI / 2));
+ VerifyFuzzyEquals(cos_int32(0), cos(0));
+ VerifyFuzzyEquals(cos_int64(0), cos(0));
+
+ // Asin functions
+ VerifyFuzzyEquals(asin_float32(-1.0), asin(-1.0));
+ VerifyFuzzyEquals(asin_float32(1.0), asin(1.0));
+ VerifyFuzzyEquals(asin_float64(-1.0), asin(-1.0));
+ VerifyFuzzyEquals(asin_float64(1.0), asin(1.0));
+ VerifyFuzzyEquals(asin_int32(0), asin(0));
+ VerifyFuzzyEquals(asin_int64(0), asin(0));
+
+ // Acos functions
+ VerifyFuzzyEquals(acos_float32(-1.0), acos(-1.0));
+ VerifyFuzzyEquals(acos_float32(1.0), acos(1.0));
+ VerifyFuzzyEquals(acos_float64(-1.0), acos(-1.0));
+ VerifyFuzzyEquals(acos_float64(1.0), acos(1.0));
+ VerifyFuzzyEquals(acos_int32(0), acos(0));
+ VerifyFuzzyEquals(acos_int64(0), acos(0));
+
+ // Tan
+ VerifyFuzzyEquals(tan_float32(pi_float), tan(M_PI));
+ VerifyFuzzyEquals(tan_float32(-pi_float), tan(-M_PI));
+ VerifyFuzzyEquals(tan_float64(M_PI), tan(M_PI));
+ VerifyFuzzyEquals(tan_float64(-M_PI), tan(-M_PI));
+ VerifyFuzzyEquals(tan_int32(0), tan(0));
+ VerifyFuzzyEquals(tan_int64(0), tan(0));
+
+ // Atan
+ VerifyFuzzyEquals(atan_float32(pi_float), atan(M_PI));
+ VerifyFuzzyEquals(atan_float32(-pi_float), atan(-M_PI));
+ VerifyFuzzyEquals(atan_float64(M_PI), atan(M_PI));
+ VerifyFuzzyEquals(atan_float64(-M_PI), atan(-M_PI));
+ VerifyFuzzyEquals(atan_int32(0), atan(0));
+ VerifyFuzzyEquals(atan_int64(0), atan(0));
+
+ // Sinh functions
+ VerifyFuzzyEquals(sinh_float32(0), sinh(0));
+ VerifyFuzzyEquals(sinh_float32(pi_float / 2), sinh(M_PI / 2));
+ VerifyFuzzyEquals(sinh_float32(pi_float), sinh(M_PI));
+ VerifyFuzzyEquals(sinh_float32(-pi_float / 2), sinh(-M_PI / 2));
+ VerifyFuzzyEquals(sinh_float64(0), sinh(0));
+ VerifyFuzzyEquals(sinh_float64(M_PI / 2), sinh(M_PI / 2));
+ VerifyFuzzyEquals(sinh_float64(M_PI), sinh(M_PI));
+ VerifyFuzzyEquals(sinh_float64(-M_PI / 2), sinh(-M_PI / 2));
+ VerifyFuzzyEquals(sinh_int32(0), sinh(0));
+ VerifyFuzzyEquals(sinh_int64(0), sinh(0));
+
+ // Cosh functions
+ VerifyFuzzyEquals(cosh_float32(0), cosh(0));
+ VerifyFuzzyEquals(cosh_float32(pi_float / 2), cosh(M_PI / 2));
+ VerifyFuzzyEquals(cosh_float32(pi_float), cosh(M_PI));
+ VerifyFuzzyEquals(cosh_float32(-pi_float / 2), cosh(-M_PI / 2));
+ VerifyFuzzyEquals(cosh_float64(0), cosh(0));
+ VerifyFuzzyEquals(cosh_float64(M_PI / 2), cosh(M_PI / 2));
+ VerifyFuzzyEquals(cosh_float64(M_PI), cosh(M_PI));
+ VerifyFuzzyEquals(cosh_float64(-M_PI / 2), cosh(-M_PI / 2));
+ VerifyFuzzyEquals(cosh_int32(0), cosh(0));
+ VerifyFuzzyEquals(cosh_int64(0), cosh(0));
+
+ // Tanh
+ VerifyFuzzyEquals(tanh_float32(pi_float), tanh(M_PI));
+ VerifyFuzzyEquals(tanh_float32(-pi_float), tanh(-M_PI));
+ VerifyFuzzyEquals(tanh_float64(M_PI), tanh(M_PI));
+ VerifyFuzzyEquals(tanh_float64(-M_PI), tanh(-M_PI));
+ VerifyFuzzyEquals(tanh_int32(0), tanh(0));
+ VerifyFuzzyEquals(tanh_int64(0), tanh(0));
+
+ // Atan2
+ VerifyFuzzyEquals(atan2_float32_float32(1, 0), atan2(1, 0));
+ VerifyFuzzyEquals(atan2_float32_float32(-1.0, 0), atan2(-1, 0));
+ VerifyFuzzyEquals(atan2_float64_float64(1.0, 0.0), atan2(1, 0));
+ VerifyFuzzyEquals(atan2_float64_float64(-1, 0), atan2(-1, 0));
+ VerifyFuzzyEquals(atan2_int32_int32(1, 0), atan2(1, 0));
+ VerifyFuzzyEquals(atan2_int64_int64(-1, 0), atan2(-1, 0));
+
+ // Radians
+ VerifyFuzzyEquals(radians_float32(0), 0);
+ VerifyFuzzyEquals(radians_float32(180.0), M_PI);
+ VerifyFuzzyEquals(radians_float32(90.0), M_PI / 2);
+ VerifyFuzzyEquals(radians_float64(0), 0);
+ VerifyFuzzyEquals(radians_float64(180.0), M_PI);
+ VerifyFuzzyEquals(radians_float64(90.0), M_PI / 2);
+ VerifyFuzzyEquals(radians_int32(180), M_PI);
+ VerifyFuzzyEquals(radians_int64(90), M_PI / 2);
+
+ // Degrees
+ VerifyFuzzyEquals(degrees_float32(0), 0.0);
+ VerifyFuzzyEquals(degrees_float32(pi_float), 180.0);
+ VerifyFuzzyEquals(degrees_float32(pi_float / 2), 90.0);
+ VerifyFuzzyEquals(degrees_float64(0), 0.0);
+ VerifyFuzzyEquals(degrees_float64(M_PI), 180.0);
+ VerifyFuzzyEquals(degrees_float64(M_PI / 2), 90.0);
+ VerifyFuzzyEquals(degrees_int32(1), 57.2958);
+ VerifyFuzzyEquals(degrees_int64(1), 57.2958);
+
+ // Cot
+ VerifyFuzzyEquals(cot_float32(pi_float / 2), tan(M_PI / 2 - M_PI / 2));
+ VerifyFuzzyEquals(cot_float64(M_PI / 2), tan(M_PI / 2 - M_PI / 2));
+}
+
+TEST(TestExtendedMathOps, TestBinRepresentation) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+
+ const char* out_str = bin_int32(ctx_ptr, 7, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "111");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = bin_int32(ctx_ptr, 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "0");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = bin_int32(ctx_ptr, 28550, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "110111110000110");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = bin_int32(ctx_ptr, -28550, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "11111111111111111001000001111010");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = bin_int32(ctx_ptr, 58117, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "1110001100000101");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = bin_int32(ctx_ptr, -58117, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "11111111111111110001110011111011");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = bin_int32(ctx_ptr, INT32_MAX, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "1111111111111111111111111111111");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = bin_int32(ctx_ptr, INT32_MIN, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "10000000000000000000000000000000");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = bin_int64(ctx_ptr, 7, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "111");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = bin_int64(ctx_ptr, 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "0");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = bin_int64(ctx_ptr, 28550, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "110111110000110");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = bin_int64(ctx_ptr, -28550, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len),
+ "1111111111111111111111111111111111111111111111111001000001111010");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = bin_int64(ctx_ptr, 58117, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "1110001100000101");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = bin_int64(ctx_ptr, -58117, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len),
+ "1111111111111111111111111111111111111111111111110001110011111011");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = bin_int64(ctx_ptr, INT64_MAX, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len),
+ "111111111111111111111111111111111111111111111111111111111111111");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = bin_int64(ctx_ptr, INT64_MIN, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len),
+ "1000000000000000000000000000000000000000000000000000000000000000");
+ EXPECT_FALSE(ctx.has_error());
+}
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/precompiled/hash.cc b/src/arrow/cpp/src/gandiva/precompiled/hash.cc
new file mode 100644
index 000000000..eacf36230
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/hash.cc
@@ -0,0 +1,407 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern "C" {
+
+#include <string.h>
+
+#include "./types.h"
+
+static inline gdv_uint64 rotate_left(gdv_uint64 val, int distance) {
+ return (val << distance) | (val >> (64 - distance));
+}
+
+//
+// MurmurHash3 was written by Austin Appleby, and is placed in the public
+// domain.
+// See http://smhasher.googlecode.com/svn/trunk/MurmurHash3.cpp
+// MurmurHash3_x64_128
+//
+static inline gdv_uint64 fmix64(gdv_uint64 k) {
+ k ^= k >> 33;
+ k *= 0xff51afd7ed558ccduLL;
+ k ^= k >> 33;
+ k *= 0xc4ceb9fe1a85ec53uLL;
+ k ^= k >> 33;
+ return k;
+}
+
+static inline gdv_uint64 murmur3_64(gdv_uint64 val, gdv_int32 seed) {
+ gdv_uint64 h1 = seed;
+ gdv_uint64 h2 = seed;
+
+ gdv_uint64 c1 = 0x87c37b91114253d5ull;
+ gdv_uint64 c2 = 0x4cf5ad432745937full;
+
+ int length = 8;
+ gdv_uint64 k1 = 0;
+
+ k1 = val;
+ k1 *= c1;
+ k1 = rotate_left(k1, 31);
+ k1 *= c2;
+ h1 ^= k1;
+
+ h1 ^= length;
+ h2 ^= length;
+
+ h1 += h2;
+ h2 += h1;
+
+ h1 = fmix64(h1);
+ h2 = fmix64(h2);
+
+ h1 += h2;
+
+ // h2 += h1;
+ // murmur3_128 should return 128 bit (h1,h2), now we return only 64bits,
+ return h1;
+}
+
+static inline gdv_uint32 murmur3_32(gdv_uint64 val, gdv_int32 seed) {
+ gdv_uint64 c1 = 0xcc9e2d51ull;
+ gdv_uint64 c2 = 0x1b873593ull;
+ int length = 8;
+ static gdv_uint64 UINT_MASK = 0xffffffffull;
+ gdv_uint64 lh1 = seed & UINT_MASK;
+ for (int i = 0; i < 2; i++) {
+ gdv_uint64 lk1 = ((val >> i * 32) & UINT_MASK);
+ lk1 *= c1;
+ lk1 &= UINT_MASK;
+
+ lk1 = ((lk1 << 15) & UINT_MASK) | (lk1 >> 17);
+
+ lk1 *= c2;
+ lk1 &= UINT_MASK;
+
+ lh1 ^= lk1;
+ lh1 = ((lh1 << 13) & UINT_MASK) | (lh1 >> 19);
+
+ lh1 = lh1 * 5 + 0xe6546b64L;
+ lh1 = UINT_MASK & lh1;
+ }
+ lh1 ^= length;
+
+ lh1 ^= lh1 >> 16;
+ lh1 *= 0x85ebca6bull;
+ lh1 = UINT_MASK & lh1;
+ lh1 ^= lh1 >> 13;
+ lh1 *= 0xc2b2ae35ull;
+ lh1 = UINT_MASK & lh1;
+ lh1 ^= lh1 >> 16;
+
+ return static_cast<gdv_uint32>(lh1);
+}
+
+static inline gdv_uint64 double_to_long_bits(double value) {
+ gdv_uint64 result;
+ memcpy(&result, &value, sizeof(result));
+ return result;
+}
+
+FORCE_INLINE gdv_int64 hash64(double val, gdv_int64 seed) {
+ return murmur3_64(double_to_long_bits(val), static_cast<gdv_int32>(seed));
+}
+
+FORCE_INLINE gdv_int32 hash32(double val, gdv_int32 seed) {
+ return murmur3_32(double_to_long_bits(val), seed);
+}
+
+// Wrappers for all the numeric/data/time arrow types
+
+#define HASH64_WITH_SEED_OP(NAME, TYPE) \
+ FORCE_INLINE \
+ gdv_int64 NAME##_##TYPE(gdv_##TYPE in, gdv_boolean is_valid, gdv_int64 seed, \
+ gdv_boolean seed_isvalid) { \
+ if (!is_valid) { \
+ return seed; \
+ } \
+ return hash64(static_cast<double>(in), seed); \
+ }
+
+#define HASH32_WITH_SEED_OP(NAME, TYPE) \
+ FORCE_INLINE \
+ gdv_int32 NAME##_##TYPE(gdv_##TYPE in, gdv_boolean is_valid, gdv_int32 seed, \
+ gdv_boolean seed_isvalid) { \
+ if (!is_valid) { \
+ return seed; \
+ } \
+ return hash32(static_cast<double>(in), seed); \
+ }
+
+#define HASH64_OP(NAME, TYPE) \
+ FORCE_INLINE \
+ gdv_int64 NAME##_##TYPE(gdv_##TYPE in, gdv_boolean is_valid) { \
+ return is_valid ? hash64(static_cast<double>(in), 0) : 0; \
+ }
+
+#define HASH32_OP(NAME, TYPE) \
+ FORCE_INLINE \
+ gdv_int32 NAME##_##TYPE(gdv_##TYPE in, gdv_boolean is_valid) { \
+ return is_valid ? hash32(static_cast<double>(in), 0) : 0; \
+ }
+
+// Expand inner macro for all numeric types.
+#define NUMERIC_BOOL_DATE_TYPES(INNER, NAME) \
+ INNER(NAME, int8) \
+ INNER(NAME, int16) \
+ INNER(NAME, int32) \
+ INNER(NAME, int64) \
+ INNER(NAME, uint8) \
+ INNER(NAME, uint16) \
+ INNER(NAME, uint32) \
+ INNER(NAME, uint64) \
+ INNER(NAME, float32) \
+ INNER(NAME, float64) \
+ INNER(NAME, boolean) \
+ INNER(NAME, date64) \
+ INNER(NAME, date32) \
+ INNER(NAME, time32) \
+ INNER(NAME, timestamp)
+
+NUMERIC_BOOL_DATE_TYPES(HASH32_OP, hash)
+NUMERIC_BOOL_DATE_TYPES(HASH32_OP, hash32)
+NUMERIC_BOOL_DATE_TYPES(HASH32_OP, hash32AsDouble)
+NUMERIC_BOOL_DATE_TYPES(HASH32_WITH_SEED_OP, hash32WithSeed)
+NUMERIC_BOOL_DATE_TYPES(HASH32_WITH_SEED_OP, hash32AsDoubleWithSeed)
+
+NUMERIC_BOOL_DATE_TYPES(HASH64_OP, hash64)
+NUMERIC_BOOL_DATE_TYPES(HASH64_OP, hash64AsDouble)
+NUMERIC_BOOL_DATE_TYPES(HASH64_WITH_SEED_OP, hash64WithSeed)
+NUMERIC_BOOL_DATE_TYPES(HASH64_WITH_SEED_OP, hash64AsDoubleWithSeed)
+
+#undef NUMERIC_BOOL_DATE_TYPES
+
+static inline gdv_uint64 murmur3_64_buf(const gdv_uint8* key, gdv_int32 len,
+ gdv_int32 seed) {
+ gdv_uint64 h1 = seed;
+ gdv_uint64 h2 = seed;
+ gdv_uint64 c1 = 0x87c37b91114253d5ull;
+ gdv_uint64 c2 = 0x4cf5ad432745937full;
+
+ const gdv_uint64* blocks = reinterpret_cast<const gdv_uint64*>(key);
+ int nblocks = len / 16;
+ for (int i = 0; i < nblocks; i++) {
+ gdv_uint64 k1 = blocks[i * 2 + 0];
+ gdv_uint64 k2 = blocks[i * 2 + 1];
+
+ k1 *= c1;
+ k1 = rotate_left(k1, 31);
+ k1 *= c2;
+ h1 ^= k1;
+ h1 = rotate_left(h1, 27);
+ h1 += h2;
+ h1 = h1 * 5 + 0x52dce729;
+ k2 *= c2;
+ k2 = rotate_left(k2, 33);
+ k2 *= c1;
+ h2 ^= k2;
+ h2 = rotate_left(h2, 31);
+ h2 += h1;
+ h2 = h2 * 5 + 0x38495ab5;
+ }
+
+ // tail
+ gdv_uint64 k1 = 0;
+ gdv_uint64 k2 = 0;
+
+ const gdv_uint8* tail = reinterpret_cast<const gdv_uint8*>(key + nblocks * 16);
+ switch (len & 15) {
+ case 15:
+ k2 = static_cast<gdv_uint64>(tail[14]) << 48;
+ case 14:
+ k2 ^= static_cast<gdv_uint64>(tail[13]) << 40;
+ case 13:
+ k2 ^= static_cast<gdv_uint64>(tail[12]) << 32;
+ case 12:
+ k2 ^= static_cast<gdv_uint64>(tail[11]) << 24;
+ case 11:
+ k2 ^= static_cast<gdv_uint64>(tail[10]) << 16;
+ case 10:
+ k2 ^= static_cast<gdv_uint64>(tail[9]) << 8;
+ case 9:
+ k2 ^= static_cast<gdv_uint64>(tail[8]);
+ k2 *= c2;
+ k2 = rotate_left(k2, 33);
+ k2 *= c1;
+ h2 ^= k2;
+ case 8:
+ k1 ^= static_cast<gdv_uint64>(tail[7]) << 56;
+ case 7:
+ k1 ^= static_cast<gdv_uint64>(tail[6]) << 48;
+ case 6:
+ k1 ^= static_cast<gdv_uint64>(tail[5]) << 40;
+ case 5:
+ k1 ^= static_cast<gdv_uint64>(tail[4]) << 32;
+ case 4:
+ k1 ^= static_cast<gdv_uint64>(tail[3]) << 24;
+ case 3:
+ k1 ^= static_cast<gdv_uint64>(tail[2]) << 16;
+ case 2:
+ k1 ^= static_cast<gdv_uint64>(tail[1]) << 8;
+ case 1:
+ k1 ^= static_cast<gdv_uint64>(tail[0]) << 0;
+ k1 *= c1;
+ k1 = rotate_left(k1, 31);
+ k1 *= c2;
+ h1 ^= k1;
+ }
+
+ h1 ^= len;
+ h2 ^= len;
+
+ h1 += h2;
+ h2 += h1;
+
+ h1 = fmix64(h1);
+ h2 = fmix64(h2);
+
+ h1 += h2;
+ // h2 += h1;
+ // returning 64-bits of the 128-bit hash.
+ return h1;
+}
+
+static gdv_uint32 murmur3_32_buf(const gdv_uint8* key, gdv_int32 len, gdv_int32 seed) {
+ static const gdv_uint64 c1 = 0xcc9e2d51ull;
+ static const gdv_uint64 c2 = 0x1b873593ull;
+ static const gdv_uint64 UINT_MASK = 0xffffffffull;
+ gdv_uint64 lh1 = seed;
+ const gdv_uint32* blocks = reinterpret_cast<const gdv_uint32*>(key);
+ int nblocks = len / 4;
+ const gdv_uint8* tail = reinterpret_cast<const gdv_uint8*>(key + nblocks * 4);
+ for (int i = 0; i < nblocks; i++) {
+ gdv_uint64 lk1 = static_cast<gdv_uint64>(blocks[i]);
+
+ // k1 *= c1;
+ lk1 *= c1;
+ lk1 &= UINT_MASK;
+
+ lk1 = ((lk1 << 15) & UINT_MASK) | (lk1 >> 17);
+
+ lk1 *= c2;
+ lk1 = lk1 & UINT_MASK;
+ lh1 ^= lk1;
+ lh1 = ((lh1 << 13) & UINT_MASK) | (lh1 >> 19);
+
+ lh1 = lh1 * 5 + 0xe6546b64ull;
+ lh1 = UINT_MASK & lh1;
+ }
+
+ // tail
+ gdv_uint64 lk1 = 0;
+
+ switch (len & 3) {
+ case 3:
+ lk1 = (tail[2] & 0xff) << 16;
+ case 2:
+ lk1 |= (tail[1] & 0xff) << 8;
+ case 1:
+ lk1 |= (tail[0] & 0xff);
+ lk1 *= c1;
+ lk1 = UINT_MASK & lk1;
+ lk1 = ((lk1 << 15) & UINT_MASK) | (lk1 >> 17);
+
+ lk1 *= c2;
+ lk1 = lk1 & UINT_MASK;
+
+ lh1 ^= lk1;
+ }
+
+ // finalization
+ lh1 ^= len;
+
+ lh1 ^= lh1 >> 16;
+ lh1 *= 0x85ebca6b;
+ lh1 = UINT_MASK & lh1;
+ lh1 ^= lh1 >> 13;
+
+ lh1 *= 0xc2b2ae35;
+ lh1 = UINT_MASK & lh1;
+ lh1 ^= lh1 >> 16;
+
+ return static_cast<gdv_uint32>(lh1 & UINT_MASK);
+}
+
+FORCE_INLINE gdv_int64 hash64_buf(const gdv_uint8* buf, int len, gdv_int64 seed) {
+ return murmur3_64_buf(buf, len, static_cast<gdv_int32>(seed));
+}
+
+FORCE_INLINE gdv_int32 hash32_buf(const gdv_uint8* buf, int len, gdv_int32 seed) {
+ return murmur3_32_buf(buf, len, seed);
+}
+
+// Wrappers for the varlen types
+
+#define HASH64_BUF_WITH_SEED_OP(NAME, TYPE) \
+ FORCE_INLINE \
+ gdv_int64 NAME##_##TYPE(gdv_##TYPE in, gdv_int32 len, gdv_boolean is_valid, \
+ gdv_int64 seed, gdv_boolean seed_isvalid) { \
+ if (!is_valid) { \
+ return seed; \
+ } \
+ return hash64_buf(reinterpret_cast<const uint8_t*>(in), len, seed); \
+ }
+
+#define HASH32_BUF_WITH_SEED_OP(NAME, TYPE) \
+ FORCE_INLINE \
+ gdv_int32 NAME##_##TYPE(gdv_##TYPE in, gdv_int32 len, gdv_boolean is_valid, \
+ gdv_int32 seed, gdv_boolean seed_isvalid) { \
+ if (!is_valid) { \
+ return seed; \
+ } \
+ return hash32_buf(reinterpret_cast<const uint8_t*>(in), len, seed); \
+ }
+
+#define HASH64_BUF_OP(NAME, TYPE) \
+ FORCE_INLINE \
+ gdv_int64 NAME##_##TYPE(gdv_##TYPE in, gdv_int32 len, gdv_boolean is_valid) { \
+ return is_valid ? hash64_buf(reinterpret_cast<const uint8_t*>(in), len, 0) : 0; \
+ }
+
+#define HASH32_BUF_OP(NAME, TYPE) \
+ FORCE_INLINE \
+ gdv_int32 NAME##_##TYPE(gdv_##TYPE in, gdv_int32 len, gdv_boolean is_valid) { \
+ return is_valid ? hash32_buf(reinterpret_cast<const uint8_t*>(in), len, 0) : 0; \
+ }
+
+// Expand inner macro for all non-numeric types.
+#define VAR_LEN_TYPES(INNER, NAME) \
+ INNER(NAME, utf8) \
+ INNER(NAME, binary)
+
+VAR_LEN_TYPES(HASH32_BUF_OP, hash)
+VAR_LEN_TYPES(HASH32_BUF_OP, hash32)
+VAR_LEN_TYPES(HASH32_BUF_OP, hash32AsDouble)
+VAR_LEN_TYPES(HASH32_BUF_WITH_SEED_OP, hash32WithSeed)
+VAR_LEN_TYPES(HASH32_BUF_WITH_SEED_OP, hash32AsDoubleWithSeed)
+
+VAR_LEN_TYPES(HASH64_BUF_OP, hash64)
+VAR_LEN_TYPES(HASH64_BUF_OP, hash64AsDouble)
+VAR_LEN_TYPES(HASH64_BUF_WITH_SEED_OP, hash64WithSeed)
+VAR_LEN_TYPES(HASH64_BUF_WITH_SEED_OP, hash64AsDoubleWithSeed)
+
+#undef HASH32_BUF_OP
+#undef HASH32_BUF_WITH_SEED_OP
+#undef HASH32_OP
+#undef HASH32_WITH_SEED_OP
+#undef HASH64_BUF_OP
+#undef HASH64_BUF_WITH_SEED_OP
+#undef HASH64_OP
+#undef HASH64_WITH_SEED_OP
+
+} // extern "C"
diff --git a/src/arrow/cpp/src/gandiva/precompiled/hash_test.cc b/src/arrow/cpp/src/gandiva/precompiled/hash_test.cc
new file mode 100644
index 000000000..0a51dced2
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/hash_test.cc
@@ -0,0 +1,122 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <time.h>
+
+#include <gtest/gtest.h>
+#include "gandiva/precompiled/types.h"
+
+namespace gandiva {
+
+TEST(TestHash, TestHash32) {
+ gdv_int8 s8 = 0;
+ gdv_uint8 u8 = 0;
+ gdv_int16 s16 = 0;
+ gdv_uint16 u16 = 0;
+ gdv_int32 s32 = 0;
+ gdv_uint32 u32 = 0;
+ gdv_int64 s64 = 0;
+ gdv_uint64 u64 = 0;
+ gdv_float32 f32 = 0;
+ gdv_float64 f64 = 0;
+
+ // hash of 0 should be non-zero (zero is the hash value for nulls).
+ gdv_int32 zero_hash = hash32(s8, 0);
+ EXPECT_NE(zero_hash, 0);
+
+ // for a given value, all numeric types must have the same hash.
+ EXPECT_EQ(hash32(u8, 0), zero_hash);
+ EXPECT_EQ(hash32(s16, 0), zero_hash);
+ EXPECT_EQ(hash32(u16, 0), zero_hash);
+ EXPECT_EQ(hash32(s32, 0), zero_hash);
+ EXPECT_EQ(hash32(u32, 0), zero_hash);
+ EXPECT_EQ(hash32(static_cast<double>(s64), 0), zero_hash);
+ EXPECT_EQ(hash32(static_cast<double>(u64), 0), zero_hash);
+ EXPECT_EQ(hash32(f32, 0), zero_hash);
+ EXPECT_EQ(hash32(f64, 0), zero_hash);
+
+ // hash must change with a change in seed.
+ EXPECT_NE(hash32(s8, 1), zero_hash);
+
+ // for a given value and seed, all numeric types must have the same hash.
+ EXPECT_EQ(hash32(s8, 1), hash32(s16, 1));
+ EXPECT_EQ(hash32(s8, 1), hash32(u32, 1));
+ EXPECT_EQ(hash32(s8, 1), hash32(f32, 1));
+ EXPECT_EQ(hash32(s8, 1), hash32(f64, 1));
+}
+
+TEST(TestHash, TestHash64) {
+ gdv_int8 s8 = 0;
+ gdv_uint8 u8 = 0;
+ gdv_int16 s16 = 0;
+ gdv_uint16 u16 = 0;
+ gdv_int32 s32 = 0;
+ gdv_uint32 u32 = 0;
+ gdv_int64 s64 = 0;
+ gdv_uint64 u64 = 0;
+ gdv_float32 f32 = 0;
+ gdv_float64 f64 = 0;
+
+ // hash of 0 should be non-zero (zero is the hash value for nulls).
+ gdv_int64 zero_hash = hash64(s8, 0);
+ EXPECT_NE(zero_hash, 0);
+ EXPECT_NE(hash64(u8, 0), hash32(u8, 0));
+
+ // for a given value, all numeric types must have the same hash.
+ EXPECT_EQ(hash64(u8, 0), zero_hash);
+ EXPECT_EQ(hash64(s16, 0), zero_hash);
+ EXPECT_EQ(hash64(u16, 0), zero_hash);
+ EXPECT_EQ(hash64(s32, 0), zero_hash);
+ EXPECT_EQ(hash64(u32, 0), zero_hash);
+ EXPECT_EQ(hash64(static_cast<double>(s64), 0), zero_hash);
+ EXPECT_EQ(hash64(static_cast<double>(u64), 0), zero_hash);
+ EXPECT_EQ(hash64(f32, 0), zero_hash);
+ EXPECT_EQ(hash64(f64, 0), zero_hash);
+
+ // hash must change with a change in seed.
+ EXPECT_NE(hash64(s8, 1), zero_hash);
+
+ // for a given value and seed, all numeric types must have the same hash.
+ EXPECT_EQ(hash64(s8, 1), hash64(s16, 1));
+ EXPECT_EQ(hash64(s8, 1), hash64(u32, 1));
+ EXPECT_EQ(hash64(s8, 1), hash64(f32, 1));
+}
+
+TEST(TestHash, TestHashBuf) {
+ const char* buf = "hello";
+ int buf_len = 5;
+
+ // hash should be non-zero (zero is the hash value for nulls).
+ EXPECT_NE(hash32_buf((const gdv_uint8*)buf, buf_len, 0), 0);
+ EXPECT_NE(hash64_buf((const gdv_uint8*)buf, buf_len, 0), 0);
+
+ // hash must change if the string is changed.
+ EXPECT_NE(hash32_buf((const gdv_uint8*)buf, buf_len, 0),
+ hash32_buf((const gdv_uint8*)buf, buf_len - 1, 0));
+
+ EXPECT_NE(hash64_buf((const gdv_uint8*)buf, buf_len, 0),
+ hash64_buf((const gdv_uint8*)buf, buf_len - 1, 0));
+
+ // hash must change if the seed is changed.
+ EXPECT_NE(hash32_buf((const gdv_uint8*)buf, buf_len, 0),
+ hash32_buf((const gdv_uint8*)buf, buf_len, 1));
+
+ EXPECT_NE(hash64_buf((const gdv_uint8*)buf, buf_len, 0),
+ hash64_buf((const gdv_uint8*)buf, buf_len, 1));
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/precompiled/print.cc b/src/arrow/cpp/src/gandiva/precompiled/print.cc
new file mode 100644
index 000000000..ecb90e1a3
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/print.cc
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern "C" {
+
+#include <stdio.h>
+
+#include "./types.h"
+
+int print_double(char* msg, double val) { return printf(msg, val); }
+
+int print_float(char* msg, float val) { return printf(msg, val); }
+
+} // extern "C"
diff --git a/src/arrow/cpp/src/gandiva/precompiled/string_ops.cc b/src/arrow/cpp/src/gandiva/precompiled/string_ops.cc
new file mode 100644
index 000000000..48c24b862
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/string_ops.cc
@@ -0,0 +1,2198 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// String functions
+#include "arrow/util/value_parsing.h"
+
+extern "C" {
+
+#include <algorithm>
+#include <climits>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+
+#include "./types.h"
+
+FORCE_INLINE
+gdv_int32 octet_length_utf8(const gdv_utf8 input, gdv_int32 length) { return length; }
+
+FORCE_INLINE
+gdv_int32 bit_length_utf8(const gdv_utf8 input, gdv_int32 length) { return length * 8; }
+
+FORCE_INLINE
+gdv_int32 octet_length_binary(const gdv_binary input, gdv_int32 length) { return length; }
+
+FORCE_INLINE
+gdv_int32 bit_length_binary(const gdv_binary input, gdv_int32 length) {
+ return length * 8;
+}
+
+FORCE_INLINE
+int match_string(const char* input, gdv_int32 input_len, gdv_int32 start_pos,
+ const char* delim, gdv_int32 delim_len) {
+ for (int i = start_pos; i < input_len; i++) {
+ int left_chars = input_len - i;
+ if ((left_chars >= delim_len) && memcmp(input + i, delim, delim_len) == 0) {
+ return i + delim_len;
+ }
+ }
+
+ return -1;
+}
+
+FORCE_INLINE
+gdv_int32 mem_compare(const char* left, gdv_int32 left_len, const char* right,
+ gdv_int32 right_len) {
+ int min = left_len;
+ if (right_len < min) {
+ min = right_len;
+ }
+
+ int cmp_ret = memcmp(left, right, min);
+ if (cmp_ret != 0) {
+ return cmp_ret;
+ } else {
+ return left_len - right_len;
+ }
+}
+
+// Expand inner macro for all varlen types.
+#define VAR_LEN_OP_TYPES(INNER, NAME, OP) \
+ INNER(NAME, utf8, OP) \
+ INNER(NAME, binary, OP)
+
+// Relational binary fns : left, right params are same, return is bool.
+#define BINARY_RELATIONAL(NAME, TYPE, OP) \
+ FORCE_INLINE \
+ bool NAME##_##TYPE##_##TYPE(const gdv_##TYPE left, gdv_int32 left_len, \
+ const gdv_##TYPE right, gdv_int32 right_len) { \
+ return mem_compare(left, left_len, right, right_len) OP 0; \
+ }
+
+VAR_LEN_OP_TYPES(BINARY_RELATIONAL, equal, ==)
+VAR_LEN_OP_TYPES(BINARY_RELATIONAL, not_equal, !=)
+VAR_LEN_OP_TYPES(BINARY_RELATIONAL, less_than, <)
+VAR_LEN_OP_TYPES(BINARY_RELATIONAL, less_than_or_equal_to, <=)
+VAR_LEN_OP_TYPES(BINARY_RELATIONAL, greater_than, >)
+VAR_LEN_OP_TYPES(BINARY_RELATIONAL, greater_than_or_equal_to, >=)
+
+#undef BINARY_RELATIONAL
+#undef VAR_LEN_OP_TYPES
+
+// Expand inner macro for all varlen types.
+#define VAR_LEN_TYPES(INNER, NAME) \
+ INNER(NAME, utf8) \
+ INNER(NAME, binary)
+
+FORCE_INLINE
+int to_binary_from_hex(char ch) {
+ if (ch >= 'A' && ch <= 'F') {
+ return 10 + (ch - 'A');
+ } else if (ch >= 'a' && ch <= 'f') {
+ return 10 + (ch - 'a');
+ }
+ return ch - '0';
+}
+
+FORCE_INLINE
+bool starts_with_utf8_utf8(const char* data, gdv_int32 data_len, const char* prefix,
+ gdv_int32 prefix_len) {
+ return ((data_len >= prefix_len) && (memcmp(data, prefix, prefix_len) == 0));
+}
+
+FORCE_INLINE
+bool ends_with_utf8_utf8(const char* data, gdv_int32 data_len, const char* suffix,
+ gdv_int32 suffix_len) {
+ return ((data_len >= suffix_len) &&
+ (memcmp(data + data_len - suffix_len, suffix, suffix_len) == 0));
+}
+
+FORCE_INLINE
+bool is_substr_utf8_utf8(const char* data, int32_t data_len, const char* substr,
+ int32_t substr_len) {
+ for (int32_t i = 0; i <= data_len - substr_len; ++i) {
+ if (memcmp(data + i, substr, substr_len) == 0) {
+ return true;
+ }
+ }
+ return false;
+}
+
+FORCE_INLINE
+gdv_int32 utf8_char_length(char c) {
+ if ((signed char)c >= 0) { // 1-byte char (0x00 ~ 0x7F)
+ return 1;
+ } else if ((c & 0xE0) == 0xC0) { // 2-byte char
+ return 2;
+ } else if ((c & 0xF0) == 0xE0) { // 3-byte char
+ return 3;
+ } else if ((c & 0xF8) == 0xF0) { // 4-byte char
+ return 4;
+ }
+ // invalid char
+ return 0;
+}
+
+FORCE_INLINE
+void set_error_for_invalid_utf(int64_t execution_context, char val) {
+ char const* fmt = "unexpected byte \\%02hhx encountered while decoding utf8 string";
+ int size = static_cast<int>(strlen(fmt)) + 64;
+ char* error = reinterpret_cast<char*>(malloc(size));
+ snprintf(error, size, fmt, (unsigned char)val);
+ gdv_fn_context_set_error_msg(execution_context, error);
+ free(error);
+}
+
+FORCE_INLINE
+bool validate_utf8_following_bytes(const char* data, int32_t data_len,
+ int32_t char_index) {
+ for (int j = 1; j < data_len; ++j) {
+ if ((data[char_index + j] & 0xC0) != 0x80) { // bytes following head-byte of glyph
+ return false;
+ }
+ }
+ return true;
+}
+
+// Count the number of utf8 characters
+// return 0 for invalid/incomplete input byte sequences
+FORCE_INLINE
+gdv_int32 utf8_length(gdv_int64 context, const char* data, gdv_int32 data_len) {
+ int char_len = 0;
+ int count = 0;
+ for (int i = 0; i < data_len; i += char_len) {
+ char_len = utf8_char_length(data[i]);
+ if (char_len == 0 || i + char_len > data_len) { // invalid byte or incomplete glyph
+ set_error_for_invalid_utf(context, data[i]);
+ return 0;
+ }
+ for (int j = 1; j < char_len; ++j) {
+ if ((data[i + j] & 0xC0) != 0x80) { // bytes following head-byte of glyph
+ set_error_for_invalid_utf(context, data[i + j]);
+ return 0;
+ }
+ }
+ ++count;
+ }
+ return count;
+}
+
+// Count the number of utf8 characters, ignoring invalid char, considering size 1
+FORCE_INLINE
+gdv_int32 utf8_length_ignore_invalid(const char* data, gdv_int32 data_len) {
+ int char_len = 0;
+ int count = 0;
+ for (int i = 0; i < data_len; i += char_len) {
+ char_len = utf8_char_length(data[i]);
+ if (char_len == 0 || i + char_len > data_len) { // invalid byte or incomplete glyph
+ // if invalid byte or incomplete glyph, ignore it
+ char_len = 1;
+ }
+ for (int j = 1; j < char_len; ++j) {
+ if ((data[i + j] & 0xC0) != 0x80) { // bytes following head-byte of glyph
+ char_len += 1;
+ }
+ }
+ ++count;
+ }
+ return count;
+}
+
+// Get the byte position corresponding to a character position for a non-empty utf8
+// sequence
+FORCE_INLINE
+gdv_int32 utf8_byte_pos(gdv_int64 context, const char* str, gdv_int32 str_len,
+ gdv_int32 char_pos) {
+ int char_len = 0;
+ int byte_index = 0;
+ for (gdv_int32 char_index = 0; char_index < char_pos && byte_index < str_len;
+ char_index++) {
+ char_len = utf8_char_length(str[byte_index]);
+ if (char_len == 0 ||
+ byte_index + char_len > str_len) { // invalid byte or incomplete glyph
+ set_error_for_invalid_utf(context, str[byte_index]);
+ return -1;
+ }
+ byte_index += char_len;
+ }
+ return byte_index;
+}
+
+#define UTF8_LENGTH(NAME, TYPE) \
+ FORCE_INLINE \
+ gdv_int32 NAME##_##TYPE(gdv_int64 context, gdv_##TYPE in, gdv_int32 in_len) { \
+ return utf8_length(context, in, in_len); \
+ }
+
+UTF8_LENGTH(char_length, utf8)
+UTF8_LENGTH(length, utf8)
+UTF8_LENGTH(lengthUtf8, binary)
+
+// Returns a string of 'n' spaces.
+#define SPACE_STR(IN_TYPE) \
+ GANDIVA_EXPORT \
+ const char* space_##IN_TYPE(gdv_int64 ctx, gdv_##IN_TYPE n, int32_t* out_len) { \
+ gdv_int32 n_times = static_cast<gdv_int32>(n); \
+ if (n_times <= 0) { \
+ *out_len = 0; \
+ return ""; \
+ } \
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(ctx, n_times)); \
+ if (ret == nullptr) { \
+ gdv_fn_context_set_error_msg(ctx, "Could not allocate memory for output string"); \
+ *out_len = 0; \
+ return ""; \
+ } \
+ for (int i = 0; i < n_times; i++) { \
+ ret[i] = ' '; \
+ } \
+ *out_len = n_times; \
+ return ret; \
+ }
+
+SPACE_STR(int32)
+SPACE_STR(int64)
+
+// Reverse a utf8 sequence
+FORCE_INLINE
+const char* reverse_utf8(gdv_int64 context, const char* data, gdv_int32 data_len,
+ int32_t* out_len) {
+ if (data_len == 0) {
+ *out_len = 0;
+ return "";
+ }
+
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, data_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+
+ gdv_int32 char_len;
+ for (gdv_int32 i = 0; i < data_len; i += char_len) {
+ char_len = utf8_char_length(data[i]);
+
+ if (char_len == 0 || i + char_len > data_len) { // invalid byte or incomplete glyph
+ set_error_for_invalid_utf(context, data[i]);
+ *out_len = 0;
+ return "";
+ }
+
+ for (gdv_int32 j = 0; j < char_len; ++j) {
+ if (j > 0 && (data[i + j] & 0xC0) != 0x80) { // bytes following head-byte of glyph
+ set_error_for_invalid_utf(context, data[i + j]);
+ *out_len = 0;
+ return "";
+ }
+ ret[data_len - i - char_len + j] = data[i + j];
+ }
+ }
+ *out_len = data_len;
+ return ret;
+}
+
+// Trims whitespaces from the left end of the input utf8 sequence
+FORCE_INLINE
+const char* ltrim_utf8(gdv_int64 context, const char* data, gdv_int32 data_len,
+ int32_t* out_len) {
+ if (data_len == 0) {
+ *out_len = 0;
+ return "";
+ }
+
+ gdv_int32 start = 0;
+ // start denotes the first position of non-space characters in the input string
+ while (start < data_len && data[start] == ' ') {
+ ++start;
+ }
+
+ *out_len = data_len - start;
+ return data + start;
+}
+
+// Trims whitespaces from the right end of the input utf8 sequence
+FORCE_INLINE
+const char* rtrim_utf8(gdv_int64 context, const char* data, gdv_int32 data_len,
+ int32_t* out_len) {
+ if (data_len == 0) {
+ *out_len = 0;
+ return "";
+ }
+
+ gdv_int32 end = data_len - 1;
+ // end denotes the last position of non-space characters in the input string
+ while (end >= 0 && data[end] == ' ') {
+ --end;
+ }
+
+ *out_len = end + 1;
+ return data;
+}
+
+// Trims whitespaces from both the ends of the input utf8 sequence
+FORCE_INLINE
+const char* btrim_utf8(gdv_int64 context, const char* data, gdv_int32 data_len,
+ int32_t* out_len) {
+ if (data_len == 0) {
+ *out_len = 0;
+ return "";
+ }
+
+ gdv_int32 start = 0, end = data_len - 1;
+ // start and end denote the first and last positions of non-space
+ // characters in the input string respectively
+ while (start <= end && data[start] == ' ') {
+ ++start;
+ }
+ while (end >= start && data[end] == ' ') {
+ --end;
+ }
+
+ // string has some leading/trailing spaces and some non-space characters
+ *out_len = end - start + 1;
+ return data + start;
+}
+
+// Trims characters present in the trim text from the left end of the base text
+FORCE_INLINE
+const char* ltrim_utf8_utf8(gdv_int64 context, const char* basetext,
+ gdv_int32 basetext_len, const char* trimtext,
+ gdv_int32 trimtext_len, int32_t* out_len) {
+ if (basetext_len == 0) {
+ *out_len = 0;
+ return "";
+ } else if (trimtext_len == 0) {
+ *out_len = basetext_len;
+ return basetext;
+ }
+
+ gdv_int32 start_ptr, char_len;
+ // scan the base text from left to right and increment the start pointer till
+ // there is a character which is not present in the trim text
+ for (start_ptr = 0; start_ptr < basetext_len; start_ptr += char_len) {
+ char_len = utf8_char_length(basetext[start_ptr]);
+ if (char_len == 0 || start_ptr + char_len > basetext_len) {
+ // invalid byte or incomplete glyph
+ set_error_for_invalid_utf(context, basetext[start_ptr]);
+ *out_len = 0;
+ return "";
+ }
+ if (!is_substr_utf8_utf8(trimtext, trimtext_len, basetext + start_ptr, char_len)) {
+ break;
+ }
+ }
+
+ *out_len = basetext_len - start_ptr;
+ return basetext + start_ptr;
+}
+
+// Trims characters present in the trim text from the right end of the base text
+FORCE_INLINE
+const char* rtrim_utf8_utf8(gdv_int64 context, const char* basetext,
+ gdv_int32 basetext_len, const char* trimtext,
+ gdv_int32 trimtext_len, int32_t* out_len) {
+ if (basetext_len == 0) {
+ *out_len = 0;
+ return "";
+ } else if (trimtext_len == 0) {
+ *out_len = basetext_len;
+ return basetext;
+ }
+
+ gdv_int32 char_len, end_ptr, byte_cnt = 1;
+ // scan the base text from right to left and decrement the end pointer till
+ // there is a character which is not present in the trim text
+ for (end_ptr = basetext_len - 1; end_ptr >= 0; --end_ptr) {
+ char_len = utf8_char_length(basetext[end_ptr]);
+ if (char_len == 0) { // trailing bytes of multibyte character
+ ++byte_cnt;
+ continue;
+ }
+ // this is the first byte of a character, hence check if char_len = char_cnt
+ if (byte_cnt != char_len) { // invalid byte or incomplete glyph
+ set_error_for_invalid_utf(context, basetext[end_ptr]);
+ *out_len = 0;
+ return "";
+ }
+ byte_cnt = 1; // reset the counter*/
+ if (!is_substr_utf8_utf8(trimtext, trimtext_len, basetext + end_ptr, char_len)) {
+ break;
+ }
+ }
+
+ // when all characters in the basetext are part of the trimtext
+ if (end_ptr == -1) {
+ *out_len = 0;
+ return "";
+ }
+
+ end_ptr += utf8_char_length(basetext[end_ptr]); // point to the next character
+ *out_len = end_ptr;
+ return basetext;
+}
+
+// Trims characters present in the trim text from both ends of the base text
+FORCE_INLINE
+const char* btrim_utf8_utf8(gdv_int64 context, const char* basetext,
+ gdv_int32 basetext_len, const char* trimtext,
+ gdv_int32 trimtext_len, int32_t* out_len) {
+ if (basetext_len == 0) {
+ *out_len = 0;
+ return "";
+ } else if (trimtext_len == 0) {
+ *out_len = basetext_len;
+ return basetext;
+ }
+
+ gdv_int32 start_ptr, end_ptr, char_len, byte_cnt = 1;
+ // scan the base text from left to right and increment the start and decrement the
+ // end pointers till there are characters which are not present in the trim text
+ for (start_ptr = 0; start_ptr < basetext_len; start_ptr += char_len) {
+ char_len = utf8_char_length(basetext[start_ptr]);
+ if (char_len == 0 || start_ptr + char_len > basetext_len) {
+ // invalid byte or incomplete glyph
+ set_error_for_invalid_utf(context, basetext[start_ptr]);
+ *out_len = 0;
+ return "";
+ }
+ if (!is_substr_utf8_utf8(trimtext, trimtext_len, basetext + start_ptr, char_len)) {
+ break;
+ }
+ }
+ for (end_ptr = basetext_len - 1; end_ptr >= start_ptr; --end_ptr) {
+ char_len = utf8_char_length(basetext[end_ptr]);
+ if (char_len == 0) { // trailing byte in multibyte character
+ ++byte_cnt;
+ continue;
+ }
+ // this is the first byte of a character, hence check if char_len = char_cnt
+ if (byte_cnt != char_len) { // invalid byte or incomplete glyph
+ set_error_for_invalid_utf(context, basetext[end_ptr]);
+ *out_len = 0;
+ return "";
+ }
+ byte_cnt = 1; // reset the counter*/
+ if (!is_substr_utf8_utf8(trimtext, trimtext_len, basetext + end_ptr, char_len)) {
+ break;
+ }
+ }
+
+ // when all characters are trimmed, start_ptr has been incremented to basetext_len and
+ // end_ptr still points to basetext_len - 1, hence we need to handle this case
+ if (start_ptr > end_ptr) {
+ *out_len = 0;
+ return "";
+ }
+
+ end_ptr += utf8_char_length(basetext[end_ptr]); // point to the next character
+ *out_len = end_ptr - start_ptr;
+ return basetext + start_ptr;
+}
+
+FORCE_INLINE
+gdv_boolean compare_lower_strings(const char* base_str, gdv_int32 base_str_len,
+ const char* str, gdv_int32 str_len) {
+ if (base_str_len != str_len) {
+ return false;
+ }
+ for (int i = 0; i < str_len; i++) {
+ // convert char to lower
+ char cur = str[i];
+ // 'A' - 'Z' : 0x41 - 0x5a
+ // 'a' - 'z' : 0x61 - 0x7a
+ if (cur >= 0x41 && cur <= 0x5a) {
+ cur = static_cast<char>(cur + 0x20);
+ }
+ // if the character does not match, break the flow
+ if (cur != base_str[i]) break;
+ // if the character matches and it is the last iteration, return true
+ if (i == str_len - 1) return true;
+ }
+ return false;
+}
+
+// Try to cast the received string ('0', '1', 'true', 'false'), ignoring leading
+// and trailing spaces, also ignoring lower and upper case.
+FORCE_INLINE
+gdv_boolean castBIT_utf8(gdv_int64 context, const char* data, gdv_int32 data_len) {
+ if (data_len <= 0) {
+ gdv_fn_context_set_error_msg(context, "Invalid value for boolean.");
+ return false;
+ }
+
+ // trim leading and trailing spaces
+ int32_t trimmed_len;
+ int32_t start = 0, end = data_len - 1;
+ while (start <= end && data[start] == ' ') {
+ ++start;
+ }
+ while (end >= start && data[end] == ' ') {
+ --end;
+ }
+ trimmed_len = end - start + 1;
+ const char* trimmed_data = data + start;
+
+ // compare received string with the valid bool string values '1', '0', 'true', 'false'
+ if (trimmed_len == 1) {
+ // case for '0' and '1' value
+ if (trimmed_data[0] == '1') return true;
+ if (trimmed_data[0] == '0') return false;
+ } else if (trimmed_len == 4) {
+ // case for matching 'true'
+ if (compare_lower_strings("true", 4, trimmed_data, trimmed_len)) return true;
+ } else if (trimmed_len == 5) {
+ // case for matching 'false'
+ if (compare_lower_strings("false", 5, trimmed_data, trimmed_len)) return false;
+ }
+ // if no 'true', 'false', '0' or '1' value is found, set an error
+ gdv_fn_context_set_error_msg(context, "Invalid value for boolean.");
+ return false;
+}
+
+FORCE_INLINE
+const char* castVARCHAR_bool_int64(gdv_int64 context, gdv_boolean value,
+ gdv_int64 out_len, gdv_int32* out_length) {
+ gdv_int32 len = static_cast<gdv_int32>(out_len);
+ if (len < 0) {
+ gdv_fn_context_set_error_msg(context, "Output buffer length can't be negative");
+ *out_length = 0;
+ return "";
+ }
+ const char* out =
+ reinterpret_cast<const char*>(gdv_fn_context_arena_malloc(context, 5));
+ out = value ? "true" : "false";
+ *out_length = value ? ((len > 4) ? 4 : len) : ((len > 5) ? 5 : len);
+ return out;
+}
+
+// Truncates the string to given length
+#define CAST_VARCHAR_FROM_VARLEN_TYPE(TYPE) \
+ FORCE_INLINE \
+ const char* castVARCHAR_##TYPE##_int64(gdv_int64 context, const char* data, \
+ gdv_int32 data_len, int64_t out_len, \
+ int32_t* out_length) { \
+ int32_t len = static_cast<int32_t>(out_len); \
+ \
+ if (len < 0) { \
+ gdv_fn_context_set_error_msg(context, "Output buffer length can't be negative"); \
+ *out_length = 0; \
+ return ""; \
+ } \
+ \
+ if (len >= data_len || len == 0) { \
+ *out_length = data_len; \
+ return data; \
+ } \
+ \
+ int32_t remaining = len; \
+ int32_t index = 0; \
+ bool is_multibyte = false; \
+ do { \
+ /* In utf8, MSB of a single byte unicode char is always 0, \
+ * whereas for a multibyte character the MSB of each byte is 1. \
+ * So for a single byte char, a bitwise-and with x80 (10000000) will be 0 \
+ * and it won't be 0 for bytes of a multibyte char. \
+ */ \
+ char* data_ptr = const_cast<char*>(data); \
+ \
+ /* advance byte by byte till the 8-byte boundary then advance 8 bytes */ \
+ auto num_bytes = reinterpret_cast<uintptr_t>(data_ptr) & 0x07; \
+ num_bytes = (8 - num_bytes) & 0x07; \
+ while (num_bytes > 0) { \
+ uint8_t* ptr = reinterpret_cast<uint8_t*>(data_ptr + index); \
+ if ((*ptr & 0x80) != 0) { \
+ is_multibyte = true; \
+ break; \
+ } \
+ index++; \
+ remaining--; \
+ num_bytes--; \
+ } \
+ if (is_multibyte) break; \
+ while (remaining >= 8) { \
+ uint64_t* ptr = reinterpret_cast<uint64_t*>(data_ptr + index); \
+ if ((*ptr & 0x8080808080808080) != 0) { \
+ is_multibyte = true; \
+ break; \
+ } \
+ index += 8; \
+ remaining -= 8; \
+ } \
+ if (is_multibyte) break; \
+ if (remaining >= 4) { \
+ uint32_t* ptr = reinterpret_cast<uint32_t*>(data_ptr + index); \
+ if ((*ptr & 0x80808080) != 0) break; \
+ index += 4; \
+ remaining -= 4; \
+ } \
+ while (remaining > 0) { \
+ uint8_t* ptr = reinterpret_cast<uint8_t*>(data_ptr + index); \
+ if ((*ptr & 0x80) != 0) { \
+ is_multibyte = true; \
+ break; \
+ } \
+ index++; \
+ remaining--; \
+ } \
+ if (is_multibyte) break; \
+ /* reached here; all are single byte characters */ \
+ *out_length = len; \
+ return data; \
+ } while (false); \
+ \
+ /* detected multibyte utf8 characters; slow path */ \
+ int32_t byte_pos = \
+ utf8_byte_pos(context, data + index, data_len - index, len - index); \
+ if (byte_pos < 0) { \
+ *out_length = 0; \
+ return ""; \
+ } \
+ \
+ *out_length = index + byte_pos; \
+ return data; \
+ }
+
+CAST_VARCHAR_FROM_VARLEN_TYPE(utf8)
+CAST_VARCHAR_FROM_VARLEN_TYPE(binary)
+
+#undef CAST_VARCHAR_FROM_VARLEN_TYPE
+
+// Add functions for castVARBINARY
+#define CAST_VARBINARY_FROM_STRING_AND_BINARY(TYPE) \
+ GANDIVA_EXPORT \
+ const char* castVARBINARY_##TYPE##_int64(gdv_int64 context, const char* data, \
+ gdv_int32 data_len, int64_t out_len, \
+ int32_t* out_length) { \
+ int32_t len = static_cast<int32_t>(out_len); \
+ if (len < 0) { \
+ gdv_fn_context_set_error_msg(context, "Output buffer length can't be negative"); \
+ *out_length = 0; \
+ return ""; \
+ } \
+ \
+ if (len >= data_len || len == 0) { \
+ *out_length = data_len; \
+ } else { \
+ *out_length = len; \
+ } \
+ return data; \
+ }
+
+CAST_VARBINARY_FROM_STRING_AND_BINARY(utf8)
+CAST_VARBINARY_FROM_STRING_AND_BINARY(binary)
+
+#undef CAST_VARBINARY_FROM_STRING_AND_BINARY
+
+#define IS_NULL(NAME, TYPE) \
+ FORCE_INLINE \
+ bool NAME##_##TYPE(gdv_##TYPE in, gdv_int32 len, gdv_boolean is_valid) { \
+ return !is_valid; \
+ }
+
+VAR_LEN_TYPES(IS_NULL, isnull)
+
+#undef IS_NULL
+
+#define IS_NOT_NULL(NAME, TYPE) \
+ FORCE_INLINE \
+ bool NAME##_##TYPE(gdv_##TYPE in, gdv_int32 len, gdv_boolean is_valid) { \
+ return is_valid; \
+ }
+
+VAR_LEN_TYPES(IS_NOT_NULL, isnotnull)
+
+#undef IS_NOT_NULL
+#undef VAR_LEN_TYPES
+
+/*
+ We follow Oracle semantics for offset:
+ - If position is positive, then the first glyph in the substring is determined by
+ counting that many glyphs forward from the beginning of the input. (i.e., for position ==
+ 1 the first glyph in the substring will be identical to the first glyph in the input)
+
+ - If position is negative, then the first glyph in the substring is determined by
+ counting that many glyphs backward from the end of the input. (i.e., for position == -1
+ the first glyph in the substring will be identical to the last glyph in the input)
+
+ - If position is 0 then it is treated as 1.
+ */
+FORCE_INLINE
+const char* substr_utf8_int64_int64(gdv_int64 context, const char* input,
+ gdv_int32 in_data_len, gdv_int64 position,
+ gdv_int64 substring_length, gdv_int32* out_data_len) {
+ if (substring_length <= 0 || input == nullptr || in_data_len <= 0) {
+ *out_data_len = 0;
+ return "";
+ }
+
+ gdv_int64 in_glyphs_count =
+ static_cast<gdv_int64>(utf8_length(context, input, in_data_len));
+
+ // in_glyphs_count is zero if input has invalid glyphs
+ if (in_glyphs_count == 0) {
+ *out_data_len = 0;
+ return "";
+ }
+
+ gdv_int64 from_glyph; // from_glyph==0 indicates the first glyph of the input
+ if (position > 0) {
+ from_glyph = position - 1;
+ } else if (position < 0) {
+ from_glyph = in_glyphs_count + position;
+ } else {
+ from_glyph = 0;
+ }
+
+ if (from_glyph < 0 || from_glyph >= in_glyphs_count) {
+ *out_data_len = 0;
+ return "";
+ }
+
+ gdv_int64 out_glyphs_count = substring_length;
+ if (substring_length > in_glyphs_count - from_glyph) {
+ out_glyphs_count = in_glyphs_count - from_glyph;
+ }
+
+ gdv_int64 in_data_len64 = static_cast<gdv_int64>(in_data_len);
+ gdv_int64 start_pos = 0;
+ gdv_int64 end_pos = in_data_len64;
+
+ gdv_int64 current_glyph = 0;
+ gdv_int64 pos = 0;
+ while (pos < in_data_len64) {
+ if (current_glyph == from_glyph) {
+ start_pos = pos;
+ }
+ pos += static_cast<gdv_int64>(utf8_char_length(input[pos]));
+ if (current_glyph - from_glyph + 1 == out_glyphs_count) {
+ end_pos = pos;
+ }
+ current_glyph++;
+ }
+
+ if (end_pos > in_data_len64 || end_pos > INT_MAX) {
+ end_pos = in_data_len64;
+ }
+
+ *out_data_len = static_cast<gdv_int32>(end_pos - start_pos);
+ char* ret =
+ reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_data_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_data_len = 0;
+ return "";
+ }
+ memcpy(ret, input + start_pos, *out_data_len);
+ return ret;
+}
+
+FORCE_INLINE
+const char* substr_utf8_int64(gdv_int64 context, const char* input, gdv_int32 in_len,
+ gdv_int64 offset64, gdv_int32* out_len) {
+ return substr_utf8_int64_int64(context, input, in_len, offset64, in_len, out_len);
+}
+
+FORCE_INLINE
+const char* repeat_utf8_int32(gdv_int64 context, const char* in, gdv_int32 in_len,
+ gdv_int32 repeat_number, gdv_int32* out_len) {
+ // if the repeat number is zero, then return empty string
+ if (repeat_number == 0 || in_len <= 0) {
+ *out_len = 0;
+ return "";
+ }
+ // if the repeat number is a negative number, an error is set on context
+ if (repeat_number < 0) {
+ gdv_fn_context_set_error_msg(context, "Repeat number can't be negative");
+ *out_len = 0;
+ return "";
+ }
+ *out_len = repeat_number * in_len;
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ for (int i = 0; i < repeat_number; ++i) {
+ memcpy(ret + (i * in_len), in, in_len);
+ }
+ return ret;
+}
+
+FORCE_INLINE
+const char* concat_utf8_utf8(gdv_int64 context, const char* left, gdv_int32 left_len,
+ bool left_validity, const char* right, gdv_int32 right_len,
+ bool right_validity, gdv_int32* out_len) {
+ if (!left_validity) {
+ left_len = 0;
+ }
+ if (!right_validity) {
+ right_len = 0;
+ }
+ return concatOperator_utf8_utf8(context, left, left_len, right, right_len, out_len);
+}
+
+FORCE_INLINE
+const char* concatOperator_utf8_utf8(gdv_int64 context, const char* left,
+ gdv_int32 left_len, const char* right,
+ gdv_int32 right_len, gdv_int32* out_len) {
+ *out_len = left_len + right_len;
+ if (*out_len <= 0) {
+ *out_len = 0;
+ return "";
+ }
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ memcpy(ret, left, left_len);
+ memcpy(ret + left_len, right, right_len);
+ return ret;
+}
+
+FORCE_INLINE
+const char* concat_utf8_utf8_utf8(gdv_int64 context, const char* in1, gdv_int32 in1_len,
+ bool in1_validity, const char* in2, gdv_int32 in2_len,
+ bool in2_validity, const char* in3, gdv_int32 in3_len,
+ bool in3_validity, gdv_int32* out_len) {
+ if (!in1_validity) {
+ in1_len = 0;
+ }
+ if (!in2_validity) {
+ in2_len = 0;
+ }
+ if (!in3_validity) {
+ in3_len = 0;
+ }
+ return concatOperator_utf8_utf8_utf8(context, in1, in1_len, in2, in2_len, in3, in3_len,
+ out_len);
+}
+
+FORCE_INLINE
+const char* concatOperator_utf8_utf8_utf8(gdv_int64 context, const char* in1,
+ gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3,
+ gdv_int32 in3_len, gdv_int32* out_len) {
+ *out_len = in1_len + in2_len + in3_len;
+ if (*out_len <= 0) {
+ *out_len = 0;
+ return "";
+ }
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ memcpy(ret, in1, in1_len);
+ memcpy(ret + in1_len, in2, in2_len);
+ memcpy(ret + in1_len + in2_len, in3, in3_len);
+ return ret;
+}
+
+FORCE_INLINE
+const char* concat_utf8_utf8_utf8_utf8(gdv_int64 context, const char* in1,
+ gdv_int32 in1_len, bool in1_validity,
+ const char* in2, gdv_int32 in2_len,
+ bool in2_validity, const char* in3,
+ gdv_int32 in3_len, bool in3_validity,
+ const char* in4, gdv_int32 in4_len,
+ bool in4_validity, gdv_int32* out_len) {
+ if (!in1_validity) {
+ in1_len = 0;
+ }
+ if (!in2_validity) {
+ in2_len = 0;
+ }
+ if (!in3_validity) {
+ in3_len = 0;
+ }
+ if (!in4_validity) {
+ in4_len = 0;
+ }
+ return concatOperator_utf8_utf8_utf8_utf8(context, in1, in1_len, in2, in2_len, in3,
+ in3_len, in4, in4_len, out_len);
+}
+
+FORCE_INLINE
+const char* concatOperator_utf8_utf8_utf8_utf8(gdv_int64 context, const char* in1,
+ gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3,
+ gdv_int32 in3_len, const char* in4,
+ gdv_int32 in4_len, gdv_int32* out_len) {
+ *out_len = in1_len + in2_len + in3_len + in4_len;
+ if (*out_len <= 0) {
+ *out_len = 0;
+ return "";
+ }
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ memcpy(ret, in1, in1_len);
+ memcpy(ret + in1_len, in2, in2_len);
+ memcpy(ret + in1_len + in2_len, in3, in3_len);
+ memcpy(ret + in1_len + in2_len + in3_len, in4, in4_len);
+ return ret;
+}
+
+FORCE_INLINE
+const char* concat_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity,
+ const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3,
+ gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len,
+ bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity,
+ gdv_int32* out_len) {
+ if (!in1_validity) {
+ in1_len = 0;
+ }
+ if (!in2_validity) {
+ in2_len = 0;
+ }
+ if (!in3_validity) {
+ in3_len = 0;
+ }
+ if (!in4_validity) {
+ in4_len = 0;
+ }
+ if (!in5_validity) {
+ in5_len = 0;
+ }
+ return concatOperator_utf8_utf8_utf8_utf8_utf8(context, in1, in1_len, in2, in2_len, in3,
+ in3_len, in4, in4_len, in5, in5_len,
+ out_len);
+}
+
+FORCE_INLINE
+const char* concatOperator_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4,
+ gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, gdv_int32* out_len) {
+ *out_len = in1_len + in2_len + in3_len + in4_len + in5_len;
+ if (*out_len <= 0) {
+ *out_len = 0;
+ return "";
+ }
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ memcpy(ret, in1, in1_len);
+ memcpy(ret + in1_len, in2, in2_len);
+ memcpy(ret + in1_len + in2_len, in3, in3_len);
+ memcpy(ret + in1_len + in2_len + in3_len, in4, in4_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len, in5, in5_len);
+ return ret;
+}
+
+FORCE_INLINE
+const char* concat_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity,
+ const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3,
+ gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len,
+ bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity,
+ const char* in6, gdv_int32 in6_len, bool in6_validity, gdv_int32* out_len) {
+ if (!in1_validity) {
+ in1_len = 0;
+ }
+ if (!in2_validity) {
+ in2_len = 0;
+ }
+ if (!in3_validity) {
+ in3_len = 0;
+ }
+ if (!in4_validity) {
+ in4_len = 0;
+ }
+ if (!in5_validity) {
+ in5_len = 0;
+ }
+ if (!in6_validity) {
+ in6_len = 0;
+ }
+ return concatOperator_utf8_utf8_utf8_utf8_utf8_utf8(context, in1, in1_len, in2, in2_len,
+ in3, in3_len, in4, in4_len, in5,
+ in5_len, in6, in6_len, out_len);
+}
+
+FORCE_INLINE
+const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4,
+ gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6,
+ gdv_int32 in6_len, gdv_int32* out_len) {
+ *out_len = in1_len + in2_len + in3_len + in4_len + in5_len + in6_len;
+ if (*out_len <= 0) {
+ *out_len = 0;
+ return "";
+ }
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ memcpy(ret, in1, in1_len);
+ memcpy(ret + in1_len, in2, in2_len);
+ memcpy(ret + in1_len + in2_len, in3, in3_len);
+ memcpy(ret + in1_len + in2_len + in3_len, in4, in4_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len, in5, in5_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len, in6, in6_len);
+ return ret;
+}
+
+FORCE_INLINE
+const char* concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity,
+ const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3,
+ gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len,
+ bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity,
+ const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7,
+ gdv_int32 in7_len, bool in7_validity, gdv_int32* out_len) {
+ if (!in1_validity) {
+ in1_len = 0;
+ }
+ if (!in2_validity) {
+ in2_len = 0;
+ }
+ if (!in3_validity) {
+ in3_len = 0;
+ }
+ if (!in4_validity) {
+ in4_len = 0;
+ }
+ if (!in5_validity) {
+ in5_len = 0;
+ }
+ if (!in6_validity) {
+ in6_len = 0;
+ }
+ if (!in7_validity) {
+ in7_len = 0;
+ }
+ return concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ context, in1, in1_len, in2, in2_len, in3, in3_len, in4, in4_len, in5, in5_len, in6,
+ in6_len, in7, in7_len, out_len);
+}
+
+FORCE_INLINE
+const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4,
+ gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6,
+ gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, gdv_int32* out_len) {
+ *out_len = in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len;
+ if (*out_len <= 0) {
+ *out_len = 0;
+ return "";
+ }
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ memcpy(ret, in1, in1_len);
+ memcpy(ret + in1_len, in2, in2_len);
+ memcpy(ret + in1_len + in2_len, in3, in3_len);
+ memcpy(ret + in1_len + in2_len + in3_len, in4, in4_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len, in5, in5_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len, in6, in6_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len, in7, in7_len);
+ return ret;
+}
+
+FORCE_INLINE
+const char* concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity,
+ const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3,
+ gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len,
+ bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity,
+ const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7,
+ gdv_int32 in7_len, bool in7_validity, const char* in8, gdv_int32 in8_len,
+ bool in8_validity, gdv_int32* out_len) {
+ if (!in1_validity) {
+ in1_len = 0;
+ }
+ if (!in2_validity) {
+ in2_len = 0;
+ }
+ if (!in3_validity) {
+ in3_len = 0;
+ }
+ if (!in4_validity) {
+ in4_len = 0;
+ }
+ if (!in5_validity) {
+ in5_len = 0;
+ }
+ if (!in6_validity) {
+ in6_len = 0;
+ }
+ if (!in7_validity) {
+ in7_len = 0;
+ }
+ if (!in8_validity) {
+ in8_len = 0;
+ }
+ return concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ context, in1, in1_len, in2, in2_len, in3, in3_len, in4, in4_len, in5, in5_len, in6,
+ in6_len, in7, in7_len, in8, in8_len, out_len);
+}
+
+FORCE_INLINE
+const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4,
+ gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6,
+ gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, const char* in8,
+ gdv_int32 in8_len, gdv_int32* out_len) {
+ *out_len =
+ in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len + in8_len;
+ if (*out_len <= 0) {
+ *out_len = 0;
+ return "";
+ }
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ memcpy(ret, in1, in1_len);
+ memcpy(ret + in1_len, in2, in2_len);
+ memcpy(ret + in1_len + in2_len, in3, in3_len);
+ memcpy(ret + in1_len + in2_len + in3_len, in4, in4_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len, in5, in5_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len, in6, in6_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len, in7, in7_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len, in8,
+ in8_len);
+ return ret;
+}
+
+FORCE_INLINE
+const char* concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity,
+ const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3,
+ gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len,
+ bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity,
+ const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7,
+ gdv_int32 in7_len, bool in7_validity, const char* in8, gdv_int32 in8_len,
+ bool in8_validity, const char* in9, gdv_int32 in9_len, bool in9_validity,
+ gdv_int32* out_len) {
+ if (!in1_validity) {
+ in1_len = 0;
+ }
+ if (!in2_validity) {
+ in2_len = 0;
+ }
+ if (!in3_validity) {
+ in3_len = 0;
+ }
+ if (!in4_validity) {
+ in4_len = 0;
+ }
+ if (!in5_validity) {
+ in5_len = 0;
+ }
+ if (!in6_validity) {
+ in6_len = 0;
+ }
+ if (!in7_validity) {
+ in7_len = 0;
+ }
+ if (!in8_validity) {
+ in8_len = 0;
+ }
+ if (!in9_validity) {
+ in9_len = 0;
+ }
+ return concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ context, in1, in1_len, in2, in2_len, in3, in3_len, in4, in4_len, in5, in5_len, in6,
+ in6_len, in7, in7_len, in8, in8_len, in9, in9_len, out_len);
+}
+
+FORCE_INLINE
+const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4,
+ gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6,
+ gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, const char* in8,
+ gdv_int32 in8_len, const char* in9, gdv_int32 in9_len, gdv_int32* out_len) {
+ *out_len = in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len +
+ in8_len + in9_len;
+ if (*out_len <= 0) {
+ *out_len = 0;
+ return "";
+ }
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ memcpy(ret, in1, in1_len);
+ memcpy(ret + in1_len, in2, in2_len);
+ memcpy(ret + in1_len + in2_len, in3, in3_len);
+ memcpy(ret + in1_len + in2_len + in3_len, in4, in4_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len, in5, in5_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len, in6, in6_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len, in7, in7_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len, in8,
+ in8_len);
+ memcpy(
+ ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len + in8_len,
+ in9, in9_len);
+ return ret;
+}
+
+FORCE_INLINE
+const char* concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity,
+ const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3,
+ gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len,
+ bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity,
+ const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7,
+ gdv_int32 in7_len, bool in7_validity, const char* in8, gdv_int32 in8_len,
+ bool in8_validity, const char* in9, gdv_int32 in9_len, bool in9_validity,
+ const char* in10, gdv_int32 in10_len, bool in10_validity, gdv_int32* out_len) {
+ if (!in1_validity) {
+ in1_len = 0;
+ }
+ if (!in2_validity) {
+ in2_len = 0;
+ }
+ if (!in3_validity) {
+ in3_len = 0;
+ }
+ if (!in4_validity) {
+ in4_len = 0;
+ }
+ if (!in5_validity) {
+ in5_len = 0;
+ }
+ if (!in6_validity) {
+ in6_len = 0;
+ }
+ if (!in7_validity) {
+ in7_len = 0;
+ }
+ if (!in8_validity) {
+ in8_len = 0;
+ }
+ if (!in9_validity) {
+ in9_len = 0;
+ }
+ if (!in10_validity) {
+ in10_len = 0;
+ }
+ return concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ context, in1, in1_len, in2, in2_len, in3, in3_len, in4, in4_len, in5, in5_len, in6,
+ in6_len, in7, in7_len, in8, in8_len, in9, in9_len, in10, in10_len, out_len);
+}
+
+FORCE_INLINE
+const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4,
+ gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6,
+ gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, const char* in8,
+ gdv_int32 in8_len, const char* in9, gdv_int32 in9_len, const char* in10,
+ gdv_int32 in10_len, gdv_int32* out_len) {
+ *out_len = in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len +
+ in8_len + in9_len + in10_len;
+ if (*out_len <= 0) {
+ *out_len = 0;
+ return "";
+ }
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ memcpy(ret, in1, in1_len);
+ memcpy(ret + in1_len, in2, in2_len);
+ memcpy(ret + in1_len + in2_len, in3, in3_len);
+ memcpy(ret + in1_len + in2_len + in3_len, in4, in4_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len, in5, in5_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len, in6, in6_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len, in7, in7_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len, in8,
+ in8_len);
+ memcpy(
+ ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len + in8_len,
+ in9, in9_len);
+ memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len +
+ in8_len + in9_len,
+ in10, in10_len);
+ return ret;
+}
+
+// Returns the numeric value of the first character of str.
+GANDIVA_EXPORT
+gdv_int32 ascii_utf8(const char* data, gdv_int32 data_len) {
+ if (data_len == 0) {
+ return 0;
+ }
+ return static_cast<gdv_int32>(data[0]);
+}
+
+FORCE_INLINE
+const char* convert_fromUTF8_binary(gdv_int64 context, const char* bin_in, gdv_int32 len,
+ gdv_int32* out_len) {
+ *out_len = len;
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ memcpy(ret, bin_in, *out_len);
+ return ret;
+}
+
+FORCE_INLINE
+const char* convert_replace_invalid_fromUTF8_binary(int64_t context, const char* text_in,
+ int32_t text_len,
+ const char* char_to_replace,
+ int32_t char_to_replace_len,
+ int32_t* out_len) {
+ if (char_to_replace_len > 1) {
+ gdv_fn_context_set_error_msg(context, "Replacement of multiple bytes not supported");
+ *out_len = 0;
+ return "";
+ }
+ // actually the convert_replace function replaces invalid chars with an ASCII
+ // character so the output length will be the same as the input length
+ *out_len = text_len;
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ int32_t valid_bytes_to_cpy = 0;
+ int32_t out_byte_counter = 0;
+ int32_t in_byte_counter = 0;
+ int32_t char_len;
+ // scan the base text from left to right and increment the start pointer till
+ // looking for invalid chars to substitute
+ for (int text_index = 0; text_index < text_len; text_index += char_len) {
+ char_len = utf8_char_length(text_in[text_index]);
+ // only memory copy the bytes when detect invalid char
+ if (char_len == 0 || text_index + char_len > text_len ||
+ !validate_utf8_following_bytes(text_in, char_len, text_index)) {
+ // define char_len = 1 to increase text_index by 1 (as ASCII char fits in 1 byte)
+ char_len = 1;
+ // first copy the valid bytes until now and then replace the invalid character
+ memcpy(ret + out_byte_counter, text_in + in_byte_counter, valid_bytes_to_cpy);
+ // if the replacement char is empty, the invalid char should be ignored
+ if (char_to_replace_len == 0) {
+ out_byte_counter += valid_bytes_to_cpy;
+ } else {
+ ret[out_byte_counter + valid_bytes_to_cpy] = char_to_replace[0];
+ out_byte_counter += valid_bytes_to_cpy + char_len;
+ }
+ in_byte_counter += valid_bytes_to_cpy + char_len;
+ valid_bytes_to_cpy = 0;
+ continue;
+ }
+ valid_bytes_to_cpy += char_len;
+ }
+ // if invalid chars were not found, return the original string
+ if (out_byte_counter == 0 && in_byte_counter == 0) return text_in;
+ // if there are still valid bytes to copy, do it
+ if (valid_bytes_to_cpy != 0) {
+ memcpy(ret + out_byte_counter, text_in + in_byte_counter, valid_bytes_to_cpy);
+ }
+ // the out length will be the out bytes copied + the missing end bytes copied
+ *out_len = valid_bytes_to_cpy + out_byte_counter;
+ return ret;
+}
+
+// The function reverse a char array in-place
+static inline void reverse_char_buf(char* buf, int32_t len) {
+ char temp;
+
+ for (int32_t i = 0; i < len / 2; i++) {
+ int32_t pos_swp = len - (1 + i);
+ temp = buf[pos_swp];
+ buf[pos_swp] = buf[i];
+ buf[i] = temp;
+ }
+}
+
+// Converts a double variable to binary
+FORCE_INLINE
+const char* convert_toDOUBLE(int64_t context, double value, int32_t* out_len) {
+ *out_len = sizeof(value);
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context,
+ "Could not allocate memory for the output string");
+
+ *out_len = 0;
+ return "";
+ }
+
+ memcpy(ret, &value, *out_len);
+
+ return ret;
+}
+
+FORCE_INLINE
+const char* convert_toDOUBLE_be(int64_t context, double value, int32_t* out_len) {
+ // The function behaves like convert_toDOUBLE, but always return the result
+ // in big endian format
+ char* ret = const_cast<char*>(convert_toDOUBLE(context, value, out_len));
+
+#if ARROW_LITTLE_ENDIAN
+ reverse_char_buf(ret, *out_len);
+#endif
+
+ return ret;
+}
+
+// Converts a float variable to binary
+FORCE_INLINE
+const char* convert_toFLOAT(int64_t context, float value, int32_t* out_len) {
+ *out_len = sizeof(value);
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context,
+ "Could not allocate memory for the output string");
+
+ *out_len = 0;
+ return "";
+ }
+
+ memcpy(ret, &value, *out_len);
+
+ return ret;
+}
+
+FORCE_INLINE
+const char* convert_toFLOAT_be(int64_t context, float value, int32_t* out_len) {
+ // The function behaves like convert_toFLOAT, but always return the result
+ // in big endian format
+ char* ret = const_cast<char*>(convert_toFLOAT(context, value, out_len));
+
+#if ARROW_LITTLE_ENDIAN
+ reverse_char_buf(ret, *out_len);
+#endif
+
+ return ret;
+}
+
+// Converts a bigint(int with 64 bits) variable to binary
+FORCE_INLINE
+const char* convert_toBIGINT(int64_t context, int64_t value, int32_t* out_len) {
+ *out_len = sizeof(value);
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context,
+ "Could not allocate memory for the output string");
+
+ *out_len = 0;
+ return "";
+ }
+
+ memcpy(ret, &value, *out_len);
+
+ return ret;
+}
+
+FORCE_INLINE
+const char* convert_toBIGINT_be(int64_t context, int64_t value, int32_t* out_len) {
+ // The function behaves like convert_toBIGINT, but always return the result
+ // in big endian format
+ char* ret = const_cast<char*>(convert_toBIGINT(context, value, out_len));
+
+#if ARROW_LITTLE_ENDIAN
+ reverse_char_buf(ret, *out_len);
+#endif
+
+ return ret;
+}
+
+// Converts an integer(with 32 bits) variable to binary
+FORCE_INLINE
+const char* convert_toINT(int64_t context, int32_t value, int32_t* out_len) {
+ *out_len = sizeof(value);
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context,
+ "Could not allocate memory for the output string");
+
+ *out_len = 0;
+ return "";
+ }
+
+ memcpy(ret, &value, *out_len);
+
+ return ret;
+}
+
+FORCE_INLINE
+const char* convert_toINT_be(int64_t context, int32_t value, int32_t* out_len) {
+ // The function behaves like convert_toINT, but always return the result
+ // in big endian format
+ char* ret = const_cast<char*>(convert_toINT(context, value, out_len));
+
+#if ARROW_LITTLE_ENDIAN
+ reverse_char_buf(ret, *out_len);
+#endif
+
+ return ret;
+}
+
+// Converts a boolean variable to binary
+FORCE_INLINE
+const char* convert_toBOOLEAN(int64_t context, bool value, int32_t* out_len) {
+ *out_len = sizeof(value);
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context,
+ "Could not allocate memory for the output string");
+
+ *out_len = 0;
+ return "";
+ }
+
+ memcpy(ret, &value, *out_len);
+
+ return ret;
+}
+
+// Converts a time variable to binary
+FORCE_INLINE
+const char* convert_toTIME_EPOCH(int64_t context, int32_t value, int32_t* out_len) {
+ return convert_toINT(context, value, out_len);
+}
+
+FORCE_INLINE
+const char* convert_toTIME_EPOCH_be(int64_t context, int32_t value, int32_t* out_len) {
+ // The function behaves as convert_toTIME_EPOCH, but
+ // returns the bytes in big endian format
+ return convert_toINT_be(context, value, out_len);
+}
+
+// Converts a timestamp variable to binary
+FORCE_INLINE
+const char* convert_toTIMESTAMP_EPOCH(int64_t context, int64_t timestamp,
+ int32_t* out_len) {
+ return convert_toBIGINT(context, timestamp, out_len);
+}
+
+FORCE_INLINE
+const char* convert_toTIMESTAMP_EPOCH_be(int64_t context, int64_t timestamp,
+ int32_t* out_len) {
+ // The function behaves as convert_toTIMESTAMP_EPOCH, but
+ // returns the bytes in big endian format
+ return convert_toBIGINT_be(context, timestamp, out_len);
+}
+
+// Converts a date variable to binary
+FORCE_INLINE
+const char* convert_toDATE_EPOCH(int64_t context, int64_t date, int32_t* out_len) {
+ return convert_toBIGINT(context, date, out_len);
+}
+
+FORCE_INLINE
+const char* convert_toDATE_EPOCH_be(int64_t context, int64_t date, int32_t* out_len) {
+ // The function behaves as convert_toDATE_EPOCH, but
+ // returns the bytes in big endian format
+ return convert_toBIGINT_be(context, date, out_len);
+}
+
+// Converts a string variable to binary
+FORCE_INLINE
+const char* convert_toUTF8(int64_t context, const char* value, int32_t value_len,
+ int32_t* out_len) {
+ *out_len = value_len;
+ return value;
+}
+
+// Search for a string within another string
+// Same as "locate(substr, str)", except for the reverse order of the arguments.
+FORCE_INLINE
+gdv_int32 strpos_utf8_utf8(gdv_int64 context, const char* str, gdv_int32 str_len,
+ const char* sub_str, gdv_int32 sub_str_len) {
+ return locate_utf8_utf8_int32(context, sub_str, sub_str_len, str, str_len, 1);
+}
+
+// Search for a string within another string
+FORCE_INLINE
+gdv_int32 locate_utf8_utf8(gdv_int64 context, const char* sub_str, gdv_int32 sub_str_len,
+ const char* str, gdv_int32 str_len) {
+ return locate_utf8_utf8_int32(context, sub_str, sub_str_len, str, str_len, 1);
+}
+
+// Search for a string within another string starting at position start-pos (1-indexed)
+FORCE_INLINE
+gdv_int32 locate_utf8_utf8_int32(gdv_int64 context, const char* sub_str,
+ gdv_int32 sub_str_len, const char* str,
+ gdv_int32 str_len, gdv_int32 start_pos) {
+ if (start_pos < 1) {
+ gdv_fn_context_set_error_msg(context, "Start position must be greater than 0");
+ return 0;
+ }
+
+ if (str_len == 0 || sub_str_len == 0) {
+ return 0;
+ }
+
+ gdv_int32 byte_pos = utf8_byte_pos(context, str, str_len, start_pos - 1);
+ if (byte_pos < 0 || byte_pos >= str_len) {
+ return 0;
+ }
+ for (gdv_int32 i = byte_pos; i <= str_len - sub_str_len; ++i) {
+ if (memcmp(str + i, sub_str, sub_str_len) == 0) {
+ return utf8_length(context, str, i) + 1;
+ }
+ }
+ return 0;
+}
+
+FORCE_INLINE
+const char* replace_with_max_len_utf8_utf8_utf8(gdv_int64 context, const char* text,
+ gdv_int32 text_len, const char* from_str,
+ gdv_int32 from_str_len,
+ const char* to_str, gdv_int32 to_str_len,
+ gdv_int32 max_length,
+ gdv_int32* out_len) {
+ // if from_str is empty or its length exceeds that of original string,
+ // return the original string
+ if (from_str_len <= 0 || from_str_len > text_len) {
+ *out_len = text_len;
+ return text;
+ }
+
+ bool found = false;
+ gdv_int32 text_index = 0;
+ char* out;
+ gdv_int32 out_index = 0;
+ gdv_int32 last_match_index =
+ 0; // defer copying string from last_match_index till next match is found
+
+ for (; text_index <= text_len - from_str_len;) {
+ if (memcmp(text + text_index, from_str, from_str_len) == 0) {
+ if (out_index + text_index - last_match_index + to_str_len > max_length) {
+ gdv_fn_context_set_error_msg(context, "Buffer overflow for output string");
+ *out_len = 0;
+ return "";
+ }
+ if (!found) {
+ // found match for first time
+ out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, max_length));
+ if (out == nullptr) {
+ gdv_fn_context_set_error_msg(context,
+ "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ found = true;
+ }
+ // first copy the part deferred till now
+ memcpy(out + out_index, text + last_match_index, (text_index - last_match_index));
+ out_index += text_index - last_match_index;
+ // then copy the target string
+ memcpy(out + out_index, to_str, to_str_len);
+ out_index += to_str_len;
+
+ text_index += from_str_len;
+ last_match_index = text_index;
+ } else {
+ text_index++;
+ }
+ }
+
+ if (!found) {
+ *out_len = text_len;
+ return text;
+ }
+
+ if (out_index + text_len - last_match_index > max_length) {
+ gdv_fn_context_set_error_msg(context, "Buffer overflow for output string");
+ *out_len = 0;
+ return "";
+ }
+ memcpy(out + out_index, text + last_match_index, text_len - last_match_index);
+ out_index += text_len - last_match_index;
+ *out_len = out_index;
+ return out;
+}
+
+FORCE_INLINE
+const char* replace_utf8_utf8_utf8(gdv_int64 context, const char* text,
+ gdv_int32 text_len, const char* from_str,
+ gdv_int32 from_str_len, const char* to_str,
+ gdv_int32 to_str_len, gdv_int32* out_len) {
+ return replace_with_max_len_utf8_utf8_utf8(context, text, text_len, from_str,
+ from_str_len, to_str, to_str_len, 65535,
+ out_len);
+}
+
+FORCE_INLINE
+const char* lpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 text_len,
+ gdv_int32 return_length, const char* fill_text,
+ gdv_int32 fill_text_len, gdv_int32* out_len) {
+ // if the text length or the defined return length (number of characters to return)
+ // is <=0, then return an empty string.
+ if (text_len == 0 || return_length <= 0) {
+ *out_len = 0;
+ return "";
+ }
+
+ // count the number of utf8 characters on text, ignoring invalid bytes
+ int text_char_count = utf8_length_ignore_invalid(text, text_len);
+
+ if (return_length == text_char_count ||
+ (return_length > text_char_count && fill_text_len == 0)) {
+ // case where the return length is same as the text's length, or if it need to
+ // fill into text but "fill_text" is empty, then return text directly.
+ *out_len = text_len;
+ return text;
+ } else if (return_length < text_char_count) {
+ // case where it truncates the result on return length.
+ *out_len = utf8_byte_pos(context, text, text_len, return_length);
+ return text;
+ } else {
+ // case (return_length > text_char_count)
+ // case where it needs to copy "fill_text" on the string left. The total number
+ // of chars to copy is given by (return_length - text_char_count)
+ char* ret =
+ reinterpret_cast<gdv_binary>(gdv_fn_context_arena_malloc(context, return_length));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context,
+ "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ // try to fulfill the return string with the "fill_text" continuously
+ int32_t copied_chars_count = 0;
+ int32_t copied_chars_position = 0;
+ while (copied_chars_count < return_length - text_char_count) {
+ int32_t char_len;
+ int32_t fill_index;
+ // for each char, evaluate its length to consider it when mem copying
+ for (fill_index = 0; fill_index < fill_text_len; fill_index += char_len) {
+ if (copied_chars_count >= return_length - text_char_count) {
+ break;
+ }
+ char_len = utf8_char_length(fill_text[fill_index]);
+ // ignore invalid char on the fill text, considering it as size 1
+ if (char_len == 0) char_len += 1;
+ copied_chars_count++;
+ }
+ memcpy(ret + copied_chars_position, fill_text, fill_index);
+ copied_chars_position += fill_index;
+ }
+ // after fulfilling the text, copy the main string
+ memcpy(ret + copied_chars_position, text, text_len);
+ *out_len = copied_chars_position + text_len;
+ return ret;
+ }
+}
+
+FORCE_INLINE
+const char* rpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 text_len,
+ gdv_int32 return_length, const char* fill_text,
+ gdv_int32 fill_text_len, gdv_int32* out_len) {
+ // if the text length or the defined return length (number of characters to return)
+ // is <=0, then return an empty string.
+ if (text_len == 0 || return_length <= 0) {
+ *out_len = 0;
+ return "";
+ }
+
+ // count the number of utf8 characters on text, ignoring invalid bytes
+ int text_char_count = utf8_length_ignore_invalid(text, text_len);
+
+ if (return_length == text_char_count ||
+ (return_length > text_char_count && fill_text_len == 0)) {
+ // case where the return length is same as the text's length, or if it need to
+ // fill into text but "fill_text" is empty, then return text directly.
+ *out_len = text_len;
+ return text;
+ } else if (return_length < text_char_count) {
+ // case where it truncates the result on return length.
+ *out_len = utf8_byte_pos(context, text, text_len, return_length);
+ return text;
+ } else {
+ // case (return_length > text_char_count)
+ // case where it needs to copy "fill_text" on the string right
+ char* ret =
+ reinterpret_cast<gdv_binary>(gdv_fn_context_arena_malloc(context, return_length));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context,
+ "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ // fulfill the initial text copying the main input string
+ memcpy(ret, text, text_len);
+ // try to fulfill the return string with the "fill_text" continuously
+ int32_t copied_chars_count = 0;
+ int32_t copied_chars_position = 0;
+ while (text_char_count + copied_chars_count < return_length) {
+ int32_t char_len;
+ int32_t fill_length;
+ // for each char, evaluate its length to consider it when mem copying
+ for (fill_length = 0; fill_length < fill_text_len; fill_length += char_len) {
+ if (text_char_count + copied_chars_count >= return_length) {
+ break;
+ }
+ char_len = utf8_char_length(fill_text[fill_length]);
+ // ignore invalid char on the fill text, considering it as size 1
+ if (char_len == 0) char_len += 1;
+ copied_chars_count++;
+ }
+ memcpy(ret + text_len + copied_chars_position, fill_text, fill_length);
+ copied_chars_position += fill_length;
+ }
+ *out_len = copied_chars_position + text_len;
+ return ret;
+ }
+}
+
+FORCE_INLINE
+const char* lpad_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len,
+ gdv_int32 return_length, gdv_int32* out_len) {
+ return lpad_utf8_int32_utf8(context, text, text_len, return_length, " ", 1, out_len);
+}
+
+FORCE_INLINE
+const char* rpad_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len,
+ gdv_int32 return_length, gdv_int32* out_len) {
+ return rpad_utf8_int32_utf8(context, text, text_len, return_length, " ", 1, out_len);
+}
+
+FORCE_INLINE
+const char* split_part(gdv_int64 context, const char* text, gdv_int32 text_len,
+ const char* delimiter, gdv_int32 delim_len, gdv_int32 index,
+ gdv_int32* out_len) {
+ *out_len = 0;
+ if (index < 1) {
+ char error_message[100];
+ snprintf(error_message, sizeof(error_message),
+ "Index in split_part must be positive, value provided was %d", index);
+ gdv_fn_context_set_error_msg(context, error_message);
+ return "";
+ }
+
+ if (delim_len == 0 || text_len == 0) {
+ // output will just be text if no delimiter is provided
+ *out_len = text_len;
+ return text;
+ }
+
+ int i = 0, match_no = 1;
+
+ while (i < text_len) {
+ // find the position where delimiter matched for the first time
+ int match_pos = match_string(text, text_len, i, delimiter, delim_len);
+ if (match_pos == -1 && match_no != index) {
+ // reached the end without finding a match.
+ return "";
+ } else {
+ // Found a match. If the match number is index then return this match
+ if (match_no == index) {
+ int end_pos = match_pos - delim_len;
+
+ if (match_pos == -1) {
+ // end position should be last position of the string as we have the last
+ // delimiter
+ end_pos = text_len;
+ }
+
+ *out_len = end_pos - i;
+ char* out_str =
+ reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (out_str == nullptr) {
+ gdv_fn_context_set_error_msg(context,
+ "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ memcpy(out_str, text + i, *out_len);
+ return out_str;
+ } else {
+ i = match_pos;
+ match_no++;
+ }
+ }
+ }
+
+ return "";
+}
+
+// Returns the x leftmost characters of a given string. Cases:
+// LEFT("TestString", 10) => "TestString"
+// LEFT("TestString", 3) => "Tes"
+// LEFT("TestString", -3) => "TestStr"
+FORCE_INLINE
+const char* left_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len,
+ gdv_int32 number, gdv_int32* out_len) {
+ // returns the 'number' left most characters of a given text
+ if (text_len == 0 || number == 0) {
+ *out_len = 0;
+ return "";
+ }
+
+ // iterate over the utf8 string validating each character
+ int char_len;
+ int char_count = 0;
+ int byte_index = 0;
+ for (int i = 0; i < text_len; i += char_len) {
+ char_len = utf8_char_length(text[i]);
+ if (char_len == 0 || i + char_len > text_len) { // invalid byte or incomplete glyph
+ set_error_for_invalid_utf(context, text[i]);
+ *out_len = 0;
+ return "";
+ }
+ for (int j = 1; j < char_len; ++j) {
+ if ((text[i + j] & 0xC0) != 0x80) { // bytes following head-byte of glyph
+ set_error_for_invalid_utf(context, text[i + j]);
+ *out_len = 0;
+ return "";
+ }
+ }
+ byte_index += char_len;
+ ++char_count;
+ // Define the rules to stop the iteration over the string
+ // case where left('abc', 5) -> 'abc'
+ if (number > 0 && char_count == number) break;
+ // case where left('abc', -5) ==> ''
+ if (number < 0 && char_count == number + text_len) break;
+ }
+
+ *out_len = byte_index;
+ return text;
+}
+
+// Returns the x rightmost characters of a given string. Cases:
+// RIGHT("TestString", 10) => "TestString"
+// RIGHT("TestString", 3) => "ing"
+// RIGHT("TestString", -3) => "tString"
+FORCE_INLINE
+const char* right_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len,
+ gdv_int32 number, gdv_int32* out_len) {
+ // returns the 'number' left most characters of a given text
+ if (text_len == 0 || number == 0) {
+ *out_len = 0;
+ return "";
+ }
+
+ // initially counts the number of utf8 characters in the defined text
+ int32_t char_count = utf8_length(context, text, text_len);
+ // char_count is zero if input has invalid utf8 char
+ if (char_count == 0) {
+ *out_len = 0;
+ return "";
+ }
+
+ int32_t start_char_pos; // the char result start position (inclusive)
+ int32_t end_char_len; // the char result end position (inclusive)
+ if (number > 0) {
+ // case where right('abc', 5) ==> 'abc' start_char_pos=1.
+ start_char_pos = (char_count > number) ? char_count - number : 0;
+ end_char_len = char_count - start_char_pos;
+ } else {
+ start_char_pos = number * -1;
+ end_char_len = char_count - start_char_pos;
+ }
+
+ // calculate the start byte position and the output length
+ int32_t start_byte_pos = utf8_byte_pos(context, text, text_len, start_char_pos);
+ *out_len = utf8_byte_pos(context, text, text_len, end_char_len);
+
+ // try to allocate memory for the response
+ char* ret =
+ reinterpret_cast<gdv_binary>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ memcpy(ret, text + start_byte_pos, *out_len);
+ return ret;
+}
+
+FORCE_INLINE
+const char* binary_string(gdv_int64 context, const char* text, gdv_int32 text_len,
+ gdv_int32* out_len) {
+ gdv_binary ret =
+ reinterpret_cast<gdv_binary>(gdv_fn_context_arena_malloc(context, text_len));
+
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+
+ if (text_len == 0) {
+ *out_len = 0;
+ return "";
+ }
+
+ // converting hex encoded string to normal string
+ int j = 0;
+ for (int i = 0; i < text_len; i++, j++) {
+ if (text[i] == '\\' && i + 3 < text_len &&
+ (text[i + 1] == 'x' || text[i + 1] == 'X')) {
+ char hd1 = text[i + 2];
+ char hd2 = text[i + 3];
+ if (isxdigit(hd1) && isxdigit(hd2)) {
+ // [a-fA-F0-9]
+ ret[j] = to_binary_from_hex(hd1) * 16 + to_binary_from_hex(hd2);
+ i += 3;
+ } else {
+ ret[j] = text[i];
+ }
+ } else {
+ ret[j] = text[i];
+ }
+ }
+ *out_len = j;
+ return ret;
+}
+
+#define CAST_INT_BIGINT_VARBINARY(OUT_TYPE, TYPE_NAME) \
+ FORCE_INLINE \
+ OUT_TYPE \
+ cast##TYPE_NAME##_varbinary(gdv_int64 context, const char* in, int32_t in_len) { \
+ if (in_len == 0) { \
+ gdv_fn_context_set_error_msg(context, "Can't cast an empty string."); \
+ return -1; \
+ } \
+ char sign = in[0]; \
+ \
+ bool negative = false; \
+ if (sign == '-') { \
+ negative = true; \
+ /* Ignores the sign char in the hexadecimal string */ \
+ in++; \
+ in_len--; \
+ } \
+ \
+ if (negative && in_len == 0) { \
+ gdv_fn_context_set_error_msg(context, \
+ "Can't cast hexadecimal with only a minus sign."); \
+ return -1; \
+ } \
+ \
+ OUT_TYPE result = 0; \
+ int digit; \
+ \
+ int read_index = 0; \
+ while (read_index < in_len) { \
+ char c1 = in[read_index]; \
+ if (isxdigit(c1)) { \
+ digit = to_binary_from_hex(c1); \
+ \
+ OUT_TYPE next = result * 16 - digit; \
+ \
+ if (next > result) { \
+ gdv_fn_context_set_error_msg(context, "Integer overflow."); \
+ return -1; \
+ } \
+ result = next; \
+ read_index++; \
+ } else { \
+ gdv_fn_context_set_error_msg(context, \
+ "The hexadecimal given has invalid characters."); \
+ return -1; \
+ } \
+ } \
+ if (!negative) { \
+ result *= -1; \
+ \
+ if (result < 0) { \
+ gdv_fn_context_set_error_msg(context, "Integer overflow."); \
+ return -1; \
+ } \
+ } \
+ return result; \
+ }
+
+CAST_INT_BIGINT_VARBINARY(int32_t, INT)
+CAST_INT_BIGINT_VARBINARY(int64_t, BIGINT)
+
+#undef CAST_INT_BIGINT_VARBINARY
+
+// Produces the binary representation of a string y characters long derived by starting
+// at offset 'x' and considering the defined length 'y'. Notice that the offset index
+// may be a negative number (starting from the end of the string), or a positive number
+// starting on index 1. Cases:
+// BYTE_SUBSTR("TestString", 1, 10) => "TestString"
+// BYTE_SUBSTR("TestString", 5, 10) => "String"
+// BYTE_SUBSTR("TestString", -6, 10) => "String"
+// BYTE_SUBSTR("TestString", -600, 10) => "TestString"
+FORCE_INLINE
+const char* byte_substr_binary_int32_int32(gdv_int64 context, const char* text,
+ gdv_int32 text_len, gdv_int32 offset,
+ gdv_int32 length, gdv_int32* out_len) {
+ // the first offset position for a string is 1, so not consider offset == 0
+ // also, the length should be always a positive number
+ if (text_len == 0 || offset == 0 || length <= 0) {
+ *out_len = 0;
+ return "";
+ }
+
+ char* ret =
+ reinterpret_cast<gdv_binary>(gdv_fn_context_arena_malloc(context, text_len));
+
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+
+ int32_t startPos = 0;
+ if (offset >= 0) {
+ startPos = offset - 1;
+ } else if (text_len + offset >= 0) {
+ startPos = text_len + offset;
+ }
+
+ // calculate end position from length and truncate to upper value bounds
+ if (startPos + length > text_len) {
+ *out_len = text_len - startPos;
+ } else {
+ *out_len = length;
+ }
+
+ memcpy(ret, text + startPos, *out_len);
+ return ret;
+}
+} // extern "C"
diff --git a/src/arrow/cpp/src/gandiva/precompiled/string_ops_test.cc b/src/arrow/cpp/src/gandiva/precompiled/string_ops_test.cc
new file mode 100644
index 000000000..6221dffb3
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/string_ops_test.cc
@@ -0,0 +1,1758 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <limits>
+
+#include "gandiva/execution_context.h"
+#include "gandiva/precompiled/types.h"
+
+namespace gandiva {
+
+TEST(TestStringOps, TestCompare) {
+ const char* left = "abcd789";
+ const char* right = "abcd123";
+
+ // 0 for equal
+ EXPECT_EQ(mem_compare(left, 4, right, 4), 0);
+
+ // compare lengths if the prefixes match
+ EXPECT_GT(mem_compare(left, 5, right, 4), 0);
+ EXPECT_LT(mem_compare(left, 4, right, 5), 0);
+
+ // compare bytes if the prefixes don't match
+ EXPECT_GT(mem_compare(left, 5, right, 5), 0);
+ EXPECT_GT(mem_compare(left, 5, right, 7), 0);
+ EXPECT_GT(mem_compare(left, 7, right, 5), 0);
+}
+
+TEST(TestStringOps, TestAscii) {
+ // ASCII
+ EXPECT_EQ(ascii_utf8("ABC", 3), 65);
+ EXPECT_EQ(ascii_utf8("abc", 3), 97);
+ EXPECT_EQ(ascii_utf8("Hello World!", 12), 72);
+ EXPECT_EQ(ascii_utf8("This is us", 10), 84);
+ EXPECT_EQ(ascii_utf8("", 0), 0);
+ EXPECT_EQ(ascii_utf8("123", 3), 49);
+ EXPECT_EQ(ascii_utf8("999", 3), 57);
+}
+
+TEST(TestStringOps, TestBeginsEnds) {
+ // starts_with
+ EXPECT_TRUE(starts_with_utf8_utf8("hello sir", 9, "hello", 5));
+ EXPECT_TRUE(starts_with_utf8_utf8("hellos", 6, "hello", 5));
+ EXPECT_TRUE(starts_with_utf8_utf8("hello", 5, "hello", 5));
+ EXPECT_FALSE(starts_with_utf8_utf8("hell", 4, "hello", 5));
+ EXPECT_FALSE(starts_with_utf8_utf8("world hello", 11, "hello", 5));
+
+ // ends_with
+ EXPECT_TRUE(ends_with_utf8_utf8("hello sir", 9, "sir", 3));
+ EXPECT_TRUE(ends_with_utf8_utf8("ssir", 4, "sir", 3));
+ EXPECT_TRUE(ends_with_utf8_utf8("sir", 3, "sir", 3));
+ EXPECT_FALSE(ends_with_utf8_utf8("ir", 2, "sir", 3));
+ EXPECT_FALSE(ends_with_utf8_utf8("hello", 5, "sir", 3));
+}
+
+TEST(TestStringOps, TestSpace) {
+ // Space - returns a string with 'n' spaces
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ int32_t out_len = 0;
+
+ auto out = space_int32(ctx_ptr, 1, &out_len);
+ EXPECT_EQ(std::string(out, out_len), " ");
+ out = space_int32(ctx_ptr, 10, &out_len);
+ EXPECT_EQ(std::string(out, out_len), " ");
+ out = space_int32(ctx_ptr, 5, &out_len);
+ EXPECT_EQ(std::string(out, out_len), " ");
+ out = space_int32(ctx_ptr, -5, &out_len);
+ EXPECT_EQ(std::string(out, out_len), "");
+
+ out = space_int64(ctx_ptr, 2, &out_len);
+ EXPECT_EQ(std::string(out, out_len), " ");
+ out = space_int64(ctx_ptr, 9, &out_len);
+ EXPECT_EQ(std::string(out, out_len), " ");
+ out = space_int64(ctx_ptr, 4, &out_len);
+ EXPECT_EQ(std::string(out, out_len), " ");
+ out = space_int64(ctx_ptr, -5, &out_len);
+ EXPECT_EQ(std::string(out, out_len), "");
+}
+
+TEST(TestStringOps, TestIsSubstr) {
+ EXPECT_TRUE(is_substr_utf8_utf8("hello world", 11, "world", 5));
+ EXPECT_TRUE(is_substr_utf8_utf8("hello world", 11, "lo wo", 5));
+ EXPECT_FALSE(is_substr_utf8_utf8("hello world", 11, "adsed", 5));
+ EXPECT_FALSE(is_substr_utf8_utf8("hel", 3, "hello", 5));
+ EXPECT_TRUE(is_substr_utf8_utf8("hello", 5, "hello", 5));
+ EXPECT_TRUE(is_substr_utf8_utf8("hello world", 11, "", 0));
+}
+
+TEST(TestStringOps, TestCharLength) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+
+ EXPECT_EQ(utf8_length(ctx_ptr, "hello sir", 9), 9);
+
+ std::string a("âpple");
+ EXPECT_EQ(utf8_length(ctx_ptr, a.data(), static_cast<int>(a.length())), 5);
+
+ std::string b("मदन");
+ EXPECT_EQ(utf8_length(ctx_ptr, b.data(), static_cast<int>(b.length())), 3);
+
+ // invalid utf8
+ std::string c("\xf8\x28");
+ EXPECT_EQ(utf8_length(ctx_ptr, c.data(), static_cast<int>(c.length())), 0);
+ EXPECT_TRUE(ctx.get_error().find(
+ "unexpected byte \\f8 encountered while decoding utf8 string") !=
+ std::string::npos)
+ << ctx.get_error();
+ ctx.Reset();
+
+ std::string d("aa\xc3");
+ EXPECT_EQ(utf8_length(ctx_ptr, d.data(), static_cast<int>(d.length())), 0);
+ EXPECT_TRUE(ctx.get_error().find(
+ "unexpected byte \\c3 encountered while decoding utf8 string") !=
+ std::string::npos)
+ << ctx.get_error();
+ ctx.Reset();
+
+ std::string e(
+ "a\xc3"
+ "a");
+ EXPECT_EQ(utf8_length(ctx_ptr, e.data(), static_cast<int>(e.length())), 0);
+ EXPECT_TRUE(ctx.get_error().find(
+ "unexpected byte \\61 encountered while decoding utf8 string") !=
+ std::string::npos)
+ << ctx.get_error();
+ ctx.Reset();
+
+ std::string f(
+ "a\xc3\xe3"
+ "a");
+ EXPECT_EQ(utf8_length(ctx_ptr, f.data(), static_cast<int>(f.length())), 0);
+ EXPECT_TRUE(ctx.get_error().find(
+ "unexpected byte \\e3 encountered while decoding utf8 string") !=
+ std::string::npos)
+ << ctx.get_error();
+ ctx.Reset();
+}
+
+TEST(TestStringOps, TestConvertReplaceInvalidUtf8Char) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+
+ // invalid utf8 (xf8 is invalid but x28 is not - x28 = '(')
+ std::string a(
+ "ok-\xf8\x28"
+ "-a");
+ auto a_in_out_len = static_cast<int>(a.length());
+ const char* a_str = convert_replace_invalid_fromUTF8_binary(
+ ctx_ptr, a.data(), a_in_out_len, "a", 1, &a_in_out_len);
+ EXPECT_EQ(std::string(a_str, a_in_out_len), "ok-a(-a");
+ EXPECT_FALSE(ctx.has_error());
+
+ // invalid utf8 (xa0 and xa1 are invalid)
+ std::string b("ok-\xa0\xa1-valid");
+ auto b_in_out_len = static_cast<int>(b.length());
+ const char* b_str = convert_replace_invalid_fromUTF8_binary(
+ ctx_ptr, b.data(), b_in_out_len, "b", 1, &b_in_out_len);
+ EXPECT_EQ(std::string(b_str, b_in_out_len), "ok-bb-valid");
+ EXPECT_FALSE(ctx.has_error());
+
+ // full valid utf8
+ std::string c("all-valid");
+ auto c_in_out_len = static_cast<int>(c.length());
+ const char* c_str = convert_replace_invalid_fromUTF8_binary(
+ ctx_ptr, c.data(), c_in_out_len, "c", 1, &c_in_out_len);
+ EXPECT_EQ(std::string(c_str, c_in_out_len), "all-valid");
+ EXPECT_FALSE(ctx.has_error());
+
+ // valid utf8 (महसुस is 4-char string, each char of which is likely a multibyte char)
+ std::string d("ok-महसुस-valid-new");
+ auto d_in_out_len = static_cast<int>(d.length());
+ const char* d_str = convert_replace_invalid_fromUTF8_binary(
+ ctx_ptr, d.data(), d_in_out_len, "d", 1, &d_in_out_len);
+ EXPECT_EQ(std::string(d_str, d_in_out_len), "ok-महसुस-valid-new");
+ EXPECT_FALSE(ctx.has_error());
+
+ // full valid utf8, but invalid replacement char length
+ std::string e("all-valid");
+ auto e_in_out_len = static_cast<int>(e.length());
+ const char* e_str = convert_replace_invalid_fromUTF8_binary(
+ ctx_ptr, e.data(), e_in_out_len, "ee", 2, &e_in_out_len);
+ EXPECT_EQ(std::string(e_str, e_in_out_len), "");
+ EXPECT_TRUE(ctx.has_error());
+ ctx.Reset();
+
+ // invalid utf8 (xa0 and xa1 are invalid) with empty replacement char length
+ std::string f("ok-\xa0\xa1-valid");
+ auto f_in_out_len = static_cast<int>(f.length());
+ const char* f_str = convert_replace_invalid_fromUTF8_binary(
+ ctx_ptr, f.data(), f_in_out_len, "", 0, &f_in_out_len);
+ EXPECT_EQ(std::string(f_str, f_in_out_len), "ok--valid");
+ EXPECT_FALSE(ctx.has_error());
+ ctx.Reset();
+
+ // invalid utf8 (xa0 and xa1 are invalid) with empty replacement char length
+ std::string g("\xa0\xa1-ok-\xa0\xa1-valid-\xa0\xa1");
+ auto g_in_out_len = static_cast<int>(g.length());
+ const char* g_str = convert_replace_invalid_fromUTF8_binary(
+ ctx_ptr, g.data(), g_in_out_len, "", 0, &g_in_out_len);
+ EXPECT_EQ(std::string(g_str, g_in_out_len), "-ok--valid-");
+ EXPECT_FALSE(ctx.has_error());
+ ctx.Reset();
+
+ std::string h("\xa0\xa1-valid");
+ auto h_in_out_len = static_cast<int>(h.length());
+ const char* h_str = convert_replace_invalid_fromUTF8_binary(
+ ctx_ptr, h.data(), h_in_out_len, "", 0, &h_in_out_len);
+ EXPECT_EQ(std::string(h_str, h_in_out_len), "-valid");
+ EXPECT_FALSE(ctx.has_error());
+ ctx.Reset();
+
+ std::string i("\xa0\xa1-valid-\xa0\xa1-valid-\xa0\xa1");
+ auto i_in_out_len = static_cast<int>(i.length());
+ const char* i_str = convert_replace_invalid_fromUTF8_binary(
+ ctx_ptr, i.data(), i_in_out_len, "", 0, &i_in_out_len);
+ EXPECT_EQ(std::string(i_str, i_in_out_len), "-valid--valid-");
+ EXPECT_FALSE(ctx.has_error());
+ ctx.Reset();
+}
+
+TEST(TestStringOps, TestRepeat) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+
+ const char* out_str = repeat_utf8_int32(ctx_ptr, "abc", 3, 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "abcabc");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = repeat_utf8_int32(ctx_ptr, "a", 1, 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "aaaaa");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = repeat_utf8_int32(ctx_ptr, "", 0, 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = repeat_utf8_int32(ctx_ptr, "", -20, 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = repeat_utf8_int32(ctx_ptr, "a", 1, -10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Repeat number can't be negative"));
+ ctx.Reset();
+}
+
+TEST(TestStringOps, TestCastBoolToVarchar) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+
+ const char* out_str = castVARCHAR_bool_int64(ctx_ptr, true, 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "tr");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_bool_int64(ctx_ptr, true, 7, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "true");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_bool_int64(ctx_ptr, false, 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "fals");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_bool_int64(ctx_ptr, false, 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "false");
+ EXPECT_FALSE(ctx.has_error());
+
+ castVARCHAR_bool_int64(ctx_ptr, true, -3, &out_len);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Output buffer length can't be negative"));
+ ctx.Reset();
+}
+
+TEST(TestStringOps, TestCastVarcharToBool) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, "true", 4), true);
+ EXPECT_FALSE(ctx.has_error());
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, " true ", 14), true);
+ EXPECT_FALSE(ctx.has_error());
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, "true ", 9), true);
+ EXPECT_FALSE(ctx.has_error());
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, " true", 9), true);
+ EXPECT_FALSE(ctx.has_error());
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, "TRUE", 4), true);
+ EXPECT_FALSE(ctx.has_error());
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, "TrUe", 4), true);
+ EXPECT_FALSE(ctx.has_error());
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, "1", 1), true);
+ EXPECT_FALSE(ctx.has_error());
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, " 1", 3), true);
+ EXPECT_FALSE(ctx.has_error());
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, "false", 5), false);
+ EXPECT_FALSE(ctx.has_error());
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, "false ", 10), false);
+ EXPECT_FALSE(ctx.has_error());
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, " false", 10), false);
+ EXPECT_FALSE(ctx.has_error());
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, "0", 1), false);
+ EXPECT_FALSE(ctx.has_error());
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, "0 ", 4), false);
+ EXPECT_FALSE(ctx.has_error());
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, "FALSE", 5), false);
+ EXPECT_FALSE(ctx.has_error());
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, "FaLsE", 5), false);
+ EXPECT_FALSE(ctx.has_error());
+
+ EXPECT_EQ(castBIT_utf8(ctx_ptr, "test", 4), false);
+ EXPECT_TRUE(ctx.has_error());
+ EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid value for boolean"));
+ ctx.Reset();
+}
+
+TEST(TestStringOps, TestCastVarchar) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+
+ // BINARY TESTS
+ const char* out_str = castVARCHAR_binary_int64(ctx_ptr, "asdf", 4, 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "a");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "asdf", 4, 6, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "asdf");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "asdf", 4, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "asd");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "asdf", 4, 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "asdf");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "asdf", 4, 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "asdf");
+ EXPECT_FALSE(ctx.has_error());
+
+ // do not truncate if output length is 0
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "asdf", 4, 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "asdf");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "", 0, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "çåå†", 9, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "çåå");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "çåå†", 9, 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "çåå†");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "çåå†", 9, 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "çåå†");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "çåå†", 9, 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "çåå†");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "çåå†", 9, 6, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "çåå†");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "abc", 3, -1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Output buffer length can't be negative"));
+ ctx.Reset();
+
+ std::string z("aa\xc3");
+ out_str = castVARCHAR_binary_int64(ctx_ptr, z.data(), static_cast<int>(z.length()), 2,
+ &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "aa");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812341234", 16, 16, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "1234567812341234");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812341234", 16, 15, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "123456781234123");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812341234", 16, 12, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "123456781234");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812341234", 16, 8, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "12345678");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812341234", 16, 7, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "1234567");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812341234", 16, 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "1234");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812341234", 16, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "123");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812çåå†123456", 25, 16, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "1234567812çåå†12");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "123456781234çåå†1234", 25, 15, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "123456781234çåå");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "12çåå†34567812123456", 25, 16, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "12çåå†3456781212");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "çåå†1234567812123456", 25, 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "çåå†");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "çåå†1234567812123456", 25, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "çåå");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_binary_int64(ctx_ptr, "123456781234çåå†", 21, 40, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "123456781234çåå†");
+ EXPECT_FALSE(ctx.has_error());
+
+ std::string f("123456781234çåå\xc3");
+ out_str = castVARCHAR_binary_int64(ctx_ptr, f.data(), static_cast<int32_t>(f.length()),
+ 16, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr(
+ "unexpected byte \\c3 encountered while decoding utf8 string"));
+ ctx.Reset();
+
+ // UTF8 TESTS
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "asdf", 4, 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "a");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "asdf", 4, 6, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "asdf");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "asdf", 4, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "asd");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "asdf", 4, 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "asdf");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "asdf", 4, 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "asdf");
+ EXPECT_FALSE(ctx.has_error());
+
+ // do not truncate if output length is 0
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "asdf", 4, 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "asdf");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "", 0, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "çåå†", 9, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "çåå");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "çåå†", 9, 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "çåå†");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "çåå†", 9, 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "çåå†");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "çåå†", 9, 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "çåå†");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "çåå†", 9, 6, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "çåå†");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "abc", 3, -1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Output buffer length can't be negative"));
+ ctx.Reset();
+
+ std::string d("aa\xc3");
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, d.data(), static_cast<int>(d.length()), 2,
+ &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "aa");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812341234", 16, 16, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "1234567812341234");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812341234", 16, 15, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "123456781234123");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812341234", 16, 12, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "123456781234");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812341234", 16, 8, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "12345678");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812341234", 16, 7, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "1234567");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812341234", 16, 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "1234");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812341234", 16, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "123");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812çåå†123456", 25, 16, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "1234567812çåå†12");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "123456781234çåå†1234", 25, 15, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "123456781234çåå");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "12çåå†34567812123456", 25, 16, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "12çåå†3456781212");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "çåå†1234567812123456", 25, 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "çåå†");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "çåå†1234567812123456", 25, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "çåå");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, "123456781234çåå†", 21, 40, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "123456781234çåå†");
+ EXPECT_FALSE(ctx.has_error());
+
+ std::string y("123456781234çåå\xc3");
+ out_str = castVARCHAR_utf8_int64(ctx_ptr, y.data(), static_cast<int32_t>(y.length()),
+ 16, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr(
+ "unexpected byte \\c3 encountered while decoding utf8 string"));
+ ctx.Reset();
+}
+
+TEST(TestStringOps, TestSubstring) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+
+ const char* out_str = substr_utf8_int64_int64(ctx_ptr, "asdf", 4, 1, 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = substr_utf8_int64_int64(ctx_ptr, "asdf", 4, 1, 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "as");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = substr_utf8_int64_int64(ctx_ptr, "asdf", 4, 1, 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "asdf");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = substr_utf8_int64_int64(ctx_ptr, "asdf", 4, 0, 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "asdf");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = substr_utf8_int64_int64(ctx_ptr, "asdf", 4, -2, 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "df");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = substr_utf8_int64_int64(ctx_ptr, "asdf", 4, -5, 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = substr_utf8_int64_int64(ctx_ptr, "अपाचे एरो", 25, 1, 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "अपाचे");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = substr_utf8_int64_int64(ctx_ptr, "अपाचे एरो", 25, 7, 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "एरो");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = substr_utf8_int64_int64(ctx_ptr, "çåå†", 9, 4, 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "†");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = substr_utf8_int64_int64(ctx_ptr, "çåå†", 9, 2, 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "åå");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = substr_utf8_int64_int64(ctx_ptr, "çåå†", 9, 0, 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "çå");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = substr_utf8_int64_int64(ctx_ptr, "afg", 4, 0, -5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = substr_utf8_int64_int64(ctx_ptr, "", 0, 5, 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = substr_utf8_int64(ctx_ptr, "abcd", 4, 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "bcd");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = substr_utf8_int64(ctx_ptr, "abcd", 4, 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "abcd");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = substr_utf8_int64(ctx_ptr, "çåå†", 9, 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "åå†");
+ EXPECT_FALSE(ctx.has_error());
+}
+
+TEST(TestStringOps, TestSubstringInvalidInputs) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+
+ char bytes[] = {'\xA7', 'a'};
+ const char* out_str = substr_utf8_int64_int64(ctx_ptr, bytes, 2, 1, 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_TRUE(ctx.has_error());
+ ctx.Reset();
+
+ char midbytes[] = {'c', '\xA7', 'a'};
+ out_str = substr_utf8_int64_int64(ctx_ptr, midbytes, 3, 1, 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_TRUE(ctx.has_error());
+ ctx.Reset();
+
+ char midbytes2[] = {'\xC3', 'a', 'a'};
+ out_str = substr_utf8_int64_int64(ctx_ptr, midbytes2, 3, 1, 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_TRUE(ctx.has_error());
+ ctx.Reset();
+
+ char endbytes[] = {'a', 'a', '\xA7'};
+ out_str = substr_utf8_int64_int64(ctx_ptr, endbytes, 3, 1, 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_TRUE(ctx.has_error());
+ ctx.Reset();
+
+ char endbytes2[] = {'a', 'a', '\xC3'};
+ out_str = substr_utf8_int64_int64(ctx_ptr, endbytes2, 3, 1, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_TRUE(ctx.has_error());
+ ctx.Reset();
+
+ out_str = substr_utf8_int64_int64(ctx_ptr, "çåå†", 9, 2147483656, 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+}
+
+TEST(TestGdvFnStubs, TestCastVarbinaryUtf8) {
+ gandiva::ExecutionContext ctx;
+
+ int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+ int32_t out_len = 0;
+ const char* input = "abc";
+ const char* out;
+
+ out = castVARBINARY_utf8_int64(ctx_ptr, input, 3, 0, &out_len);
+ EXPECT_EQ(std::string(out, out_len), input);
+
+ out = castVARBINARY_utf8_int64(ctx_ptr, input, 3, 1, &out_len);
+ EXPECT_EQ(std::string(out, out_len), "a");
+
+ out = castVARBINARY_utf8_int64(ctx_ptr, input, 3, 500, &out_len);
+ EXPECT_EQ(std::string(out, out_len), input);
+
+ out = castVARBINARY_utf8_int64(ctx_ptr, input, 3, -10, &out_len);
+ EXPECT_EQ(std::string(out, out_len), "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Output buffer length can't be negative"));
+ ctx.Reset();
+}
+
+TEST(TestGdvFnStubs, TestCastVarbinaryBinary) {
+ gandiva::ExecutionContext ctx;
+
+ int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+ int32_t out_len = 0;
+ const char* input = "\\x41\\x42\\x43";
+ const char* out;
+
+ out = castVARBINARY_binary_int64(ctx_ptr, input, 12, 0, &out_len);
+ EXPECT_EQ(std::string(out, out_len), input);
+
+ out = castVARBINARY_binary_int64(ctx_ptr, input, 8, 8, &out_len);
+ EXPECT_EQ(std::string(out, out_len), "\\x41\\x42");
+
+ out = castVARBINARY_binary_int64(ctx_ptr, input, 12, 500, &out_len);
+ EXPECT_EQ(std::string(out, out_len), input);
+
+ out = castVARBINARY_binary_int64(ctx_ptr, input, 12, -10, &out_len);
+ EXPECT_EQ(std::string(out, out_len), "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Output buffer length can't be negative"));
+ ctx.Reset();
+}
+
+TEST(TestStringOps, TestConcat) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+
+ const char* out_str =
+ concat_utf8_utf8(ctx_ptr, "abcd", 4, true, "\npq", 3, false, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "abcd");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concatOperator_utf8_utf8(ctx_ptr, "asdf", 4, "jkl", 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "asdfjkl");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concatOperator_utf8_utf8(ctx_ptr, "asdf", 4, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "asdf");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concatOperator_utf8_utf8(ctx_ptr, "", 0, "jkl", 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "jkl");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concatOperator_utf8_utf8(ctx_ptr, "", 0, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concatOperator_utf8_utf8(ctx_ptr, "abcd\n", 5, "a", 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "abcd\na");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concat_utf8_utf8_utf8(ctx_ptr, "abcd", 4, false, "\npq", 3, true, "ard", 3,
+ true, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "\npqard");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str =
+ concatOperator_utf8_utf8_utf8(ctx_ptr, "abcd\n", 5, "a", 1, "bcd", 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "abcd\nabcd");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concatOperator_utf8_utf8_utf8(ctx_ptr, "abcd", 4, "a", 1, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "abcda");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concatOperator_utf8_utf8_utf8(ctx_ptr, "", 0, "a", 1, "pqrs", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "apqrs");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concat_utf8_utf8_utf8_utf8(ctx_ptr, "abcd", 4, false, "\npq", 3, true, "ard",
+ 3, true, "uvw", 3, false, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "\npqard");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concatOperator_utf8_utf8_utf8_utf8(ctx_ptr, "pqrs", 4, "", 0, "\nabc", 4, "y",
+ 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "pqrs\nabcy");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concat_utf8_utf8_utf8_utf8_utf8(ctx_ptr, "abcd", 4, false, "\npq", 3, true,
+ "ard", 3, true, "uvw", 3, false, "abc\n", 4,
+ true, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "\npqardabc\n");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concatOperator_utf8_utf8_utf8_utf8_utf8(ctx_ptr, "pqrs", 4, "", 0, "\nabc", 4,
+ "y", 1, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "pqrs\nabcy");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concat_utf8_utf8_utf8_utf8_utf8_utf8(
+ ctx_ptr, "abcd", 4, false, "\npq", 3, true, "ard", 3, true, "uvw", 3, false,
+ "abc\n", 4, true, "sdfgs", 5, true, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "\npqardabc\nsdfgs");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concatOperator_utf8_utf8_utf8_utf8_utf8_utf8(
+ ctx_ptr, "pqrs", 4, "", 0, "\nabc", 4, "y", 1, "", 0, "\nbcd", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "pqrs\nabcy\nbcd");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ ctx_ptr, "abcd", 4, false, "\npq", 3, true, "ard", 3, true, "uvw", 3, false,
+ "abc\n", 4, true, "sdfgs", 5, true, "wfw", 3, false, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "\npqardabc\nsdfgs");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ ctx_ptr, "", 0, "pqrs", 4, "abc\n", 4, "y", 1, "", 0, "asdf", 4, "jkl", 3,
+ &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "pqrsabc\nyasdfjkl");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ ctx_ptr, "abcd", 4, false, "\npq", 3, true, "ard", 3, true, "uvw", 3, false,
+ "abc\n", 4, true, "sdfgs", 5, true, "wfw", 3, false, "", 0, true, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "\npqardabc\nsdfgs");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ ctx_ptr, "", 0, "pqrs", 4, "abc\n", 4, "y", 1, "", 0, "asdf", 4, "jkl", 3, "", 0,
+ &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "pqrsabc\nyasdfjkl");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ ctx_ptr, "abcd", 4, false, "\npq", 3, true, "ard", 3, true, "uvw", 3, false,
+ "abc\n", 4, true, "sdfgs", 5, true, "wfw", 3, false, "", 0, true, "qwert|n", 7,
+ true, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "\npqardabc\nsdfgsqwert|n");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ ctx_ptr, "", 0, "pqrs", 4, "abc\n", 4, "y", 1, "", 0, "asdf", 4, "jkl", 3, "", 0,
+ "sfl\n", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "pqrsabc\nyasdfjklsfl\n");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ ctx_ptr, "abcd", 4, false, "\npq", 3, true, "ard", 3, true, "uvw", 3, false,
+ "abc\n", 4, true, "sdfgs", 5, true, "wfw", 3, false, "", 0, true, "qwert|n", 7,
+ true, "ewfwe", 5, false, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "\npqardabc\nsdfgsqwert|n");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ ctx_ptr, "", 0, "pqrs", 4, "abc\n", 4, "y", 1, "", 0, "asdf", 4, "", 0, "jkl", 3,
+ "sfl\n", 4, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "pqrsabc\nyasdfjklsfl\n");
+ EXPECT_FALSE(ctx.has_error());
+}
+
+TEST(TestStringOps, TestReverse) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+
+ const char* out_str;
+ out_str = reverse_utf8(ctx_ptr, "TestString", 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "gnirtStseT");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = reverse_utf8(ctx_ptr, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = reverse_utf8(ctx_ptr, "çåå†", 9, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "†ååç");
+ EXPECT_FALSE(ctx.has_error());
+
+ std::string d("aa\xc3");
+ out_str = reverse_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr(
+ "unexpected byte \\c3 encountered while decoding utf8 string"));
+ ctx.Reset();
+}
+
+TEST(TestStringOps, TestLtrim) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+ const char* out_str;
+
+ out_str = ltrim_utf8(ctx_ptr, "TestString ", 12, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString ");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = ltrim_utf8(ctx_ptr, " TestString ", 18, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString ");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = ltrim_utf8(ctx_ptr, " Test çåå†bD", 18, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Test çåå†bD");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = ltrim_utf8(ctx_ptr, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = ltrim_utf8(ctx_ptr, " ", 6, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = ltrim_utf8_utf8(ctx_ptr, "", 0, "TestString", 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = ltrim_utf8_utf8(ctx_ptr, "TestString", 10, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = ltrim_utf8_utf8(ctx_ptr, "abcbbaccabbcdef", 15, "abc", 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "def");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = ltrim_utf8_utf8(ctx_ptr, "abcbbaccabbcdef", 15, "ababbac", 7, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "def");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = ltrim_utf8_utf8(ctx_ptr, "ååçåå†eç†Dd", 21, "çåå†", 9, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "eç†Dd");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = ltrim_utf8_utf8(ctx_ptr, "ç†ååçåå†", 18, "çåå†", 9, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ std::string d(
+ "aa\xc3"
+ "bcd");
+ out_str =
+ ltrim_utf8_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), "a", 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len),
+ "\xc3"
+ "bcd");
+ EXPECT_FALSE(ctx.has_error());
+
+ std::string e(
+ "åå\xe0\xa0"
+ "bcd");
+ out_str =
+ ltrim_utf8_utf8(ctx_ptr, e.data(), static_cast<int>(e.length()), "å", 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len),
+ "\xE0\xa0"
+ "bcd");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = ltrim_utf8_utf8(ctx_ptr, "TestString", 10, "abcd", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = ltrim_utf8_utf8(ctx_ptr, "acbabbcabb", 10, "abcbd", 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+}
+
+TEST(TestStringOps, TestLpadString) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+ const char* out_str;
+
+ // LPAD function tests - with defined fill pad text
+ out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 4, "fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Test");
+
+ out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 10, "fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+
+ out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 0, 10, "fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+
+ out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 0, "fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+
+ out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, -500, "fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+
+ out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 500, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+
+ out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 18, "Fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "FillFillTestString");
+
+ out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 15, "Fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "FillFTestString");
+
+ out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 20, "Fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "FillFillFiTestString");
+
+ out_str = lpad_utf8_int32_utf8(ctx_ptr, "абвгд", 10, 7, "д", 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "ддабвгд");
+
+ out_str = lpad_utf8_int32_utf8(ctx_ptr, "абвгд", 10, 20, "абвгд", 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "абвгдабвгдабвгдабвгд");
+
+ out_str = lpad_utf8_int32_utf8(ctx_ptr, "hello", 5, 6, "д", 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "дhello");
+
+ // LPAD function tests - with NO pad text
+ out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Test");
+
+ out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+
+ out_str = lpad_utf8_int32(ctx_ptr, "TestString", 0, 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+
+ out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+
+ out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, -500, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+
+ out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, 18, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), " TestString");
+
+ out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, 15, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), " TestString");
+
+ out_str = lpad_utf8_int32(ctx_ptr, "абвгд", 10, 7, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), " абвгд");
+}
+
+TEST(TestStringOps, TestRpadString) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+ const char* out_str;
+
+ // RPAD function tests - with defined fill pad text
+ out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 4, "fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Test");
+
+ out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 10, "fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+
+ out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 0, 10, "fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+
+ out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 0, "fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+
+ out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, -500, "fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+
+ out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 500, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+
+ out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 18, "Fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestStringFillFill");
+
+ out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 15, "Fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestStringFillF");
+
+ out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 20, "Fill", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestStringFillFillFi");
+
+ out_str = rpad_utf8_int32_utf8(ctx_ptr, "абвгд", 10, 7, "д", 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "абвгддд");
+
+ out_str = rpad_utf8_int32_utf8(ctx_ptr, "абвгд", 10, 20, "абвгд", 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "абвгдабвгдабвгдабвгд");
+
+ out_str = rpad_utf8_int32_utf8(ctx_ptr, "hello", 5, 6, "д", 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "helloд");
+
+ // RPAD function tests - with NO pad text
+ out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Test");
+
+ out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+
+ out_str = rpad_utf8_int32(ctx_ptr, "TestString", 0, 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+
+ out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+
+ out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, -500, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+
+ out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, 18, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString ");
+
+ out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, 15, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString ");
+
+ out_str = rpad_utf8_int32(ctx_ptr, "абвгд", 10, 7, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "абвгд ");
+}
+
+TEST(TestStringOps, TestRtrim) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+ const char* out_str;
+
+ out_str = rtrim_utf8(ctx_ptr, " TestString", 12, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), " TestString");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = rtrim_utf8(ctx_ptr, " TestString ", 18, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), " TestString");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = rtrim_utf8(ctx_ptr, "Test çåå†bD ", 20, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Test çåå†bD");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = rtrim_utf8(ctx_ptr, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = rtrim_utf8(ctx_ptr, " ", 6, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = rtrim_utf8_utf8(ctx_ptr, "", 0, "TestString", 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = rtrim_utf8_utf8(ctx_ptr, "TestString", 10, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = rtrim_utf8_utf8(ctx_ptr, "TestString", 10, "ring", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestSt");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = rtrim_utf8_utf8(ctx_ptr, "defabcbbaccabbc", 15, "abc", 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "def");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = rtrim_utf8_utf8(ctx_ptr, "defabcbbaccabbc", 15, "ababbac", 7, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "def");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = rtrim_utf8_utf8(ctx_ptr, "eDdç†ååçåå†", 21, "çåå†", 9, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "eDd");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = rtrim_utf8_utf8(ctx_ptr, "ç†ååçåå†", 18, "çåå†", 9, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ std::string d(
+ "\xc3"
+ "aaa");
+ out_str =
+ rtrim_utf8_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), "a", 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_TRUE(ctx.has_error());
+ ctx.Reset();
+
+ std::string e(
+ "\xe0\xa0"
+ "åå");
+ out_str =
+ rtrim_utf8_utf8(ctx_ptr, e.data(), static_cast<int>(e.length()), "å", 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_TRUE(ctx.has_error());
+ ctx.Reset();
+
+ out_str = rtrim_utf8_utf8(ctx_ptr, "åeçå", 7, "çå", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "åe");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = rtrim_utf8_utf8(ctx_ptr, "TestString", 10, "abcd", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = rtrim_utf8_utf8(ctx_ptr, "acbabbcabb", 10, "abcbd", 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+}
+
+TEST(TestStringOps, TestBtrim) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+ const char* out_str;
+
+ out_str = btrim_utf8(ctx_ptr, "TestString", 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = btrim_utf8(ctx_ptr, " TestString ", 18, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = btrim_utf8(ctx_ptr, " Test çåå†bD ", 21, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Test çåå†bD");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = btrim_utf8(ctx_ptr, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = btrim_utf8(ctx_ptr, " ", 6, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = btrim_utf8_utf8(ctx_ptr, "", 0, "TestString", 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = btrim_utf8_utf8(ctx_ptr, "TestString", 10, "Test", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "String");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = btrim_utf8_utf8(ctx_ptr, "TestString", 10, "String", 6, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Tes");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = btrim_utf8_utf8(ctx_ptr, "TestString", 10, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = btrim_utf8_utf8(ctx_ptr, "abcbbadefccabbc", 15, "abc", 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "def");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = btrim_utf8_utf8(ctx_ptr, "abcbbadefccabbc", 15, "ababbac", 7, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "def");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = btrim_utf8_utf8(ctx_ptr, "ååçåå†Ddeç†", 21, "çåå†", 9, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Dde");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = btrim_utf8_utf8(ctx_ptr, "ç†ååçåå†", 18, "çåå†", 9, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+ ctx.Reset();
+
+ std::string d(
+ "acd\xc3"
+ "aaa");
+ out_str =
+ btrim_utf8_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), "a", 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_TRUE(ctx.has_error());
+ ctx.Reset();
+
+ std::string e(
+ "åbc\xe0\xa0"
+ "åå");
+ out_str =
+ btrim_utf8_utf8(ctx_ptr, e.data(), static_cast<int>(e.length()), "å", 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_TRUE(ctx.has_error());
+ ctx.Reset();
+
+ std::string f(
+ "aa\xc3"
+ "bcd");
+ out_str =
+ btrim_utf8_utf8(ctx_ptr, f.data(), static_cast<int>(f.length()), "a", 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len),
+ "\xc3"
+ "bcd");
+ EXPECT_FALSE(ctx.has_error());
+
+ std::string g(
+ "åå\xe0\xa0"
+ "bcå");
+ out_str =
+ btrim_utf8_utf8(ctx_ptr, g.data(), static_cast<int>(g.length()), "å", 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len),
+ "\xe0\xa0"
+ "bc");
+
+ out_str = btrim_utf8_utf8(ctx_ptr, "åe†çå", 10, "çå", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "e†");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = btrim_utf8_utf8(ctx_ptr, "TestString", 10, "abcd", 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = btrim_utf8_utf8(ctx_ptr, "acbabbcabb", 10, "abcbd", 5, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+}
+
+TEST(TestStringOps, TestLocate) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+
+ int pos;
+
+ pos = locate_utf8_utf8(ctx_ptr, "String", 6, "TestString", 10);
+ EXPECT_EQ(pos, 5);
+ EXPECT_FALSE(ctx.has_error());
+
+ pos = locate_utf8_utf8_int32(ctx_ptr, "String", 6, "TestString", 10, 1);
+ EXPECT_EQ(pos, 5);
+ EXPECT_FALSE(ctx.has_error());
+
+ pos = locate_utf8_utf8_int32(ctx_ptr, "abc", 3, "abcabc", 6, 2);
+ EXPECT_EQ(pos, 4);
+ EXPECT_FALSE(ctx.has_error());
+
+ pos = locate_utf8_utf8(ctx_ptr, "çåå", 6, "s†å†emçåå†d", 21);
+ EXPECT_EQ(pos, 7);
+ EXPECT_FALSE(ctx.has_error());
+
+ pos = locate_utf8_utf8_int32(ctx_ptr, "bar", 3, "†barbar", 9, 3);
+ EXPECT_EQ(pos, 5);
+ EXPECT_FALSE(ctx.has_error());
+
+ pos = locate_utf8_utf8_int32(ctx_ptr, "sub", 3, "", 0, 1);
+ EXPECT_EQ(pos, 0);
+ EXPECT_FALSE(ctx.has_error());
+
+ pos = locate_utf8_utf8_int32(ctx_ptr, "", 0, "str", 3, 1);
+ EXPECT_EQ(pos, 0);
+ EXPECT_FALSE(ctx.has_error());
+
+ pos = locate_utf8_utf8_int32(ctx_ptr, "bar", 3, "barbar", 6, 0);
+ EXPECT_EQ(pos, 0);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Start position must be greater than 0"));
+ ctx.Reset();
+
+ pos = locate_utf8_utf8_int32(ctx_ptr, "bar", 3, "barbar", 6, 7);
+ EXPECT_EQ(pos, 0);
+ EXPECT_FALSE(ctx.has_error());
+
+ std::string d(
+ "a\xff"
+ "c");
+ pos =
+ locate_utf8_utf8_int32(ctx_ptr, "c", 1, d.data(), static_cast<int>(d.length()), 3);
+ EXPECT_EQ(pos, 0);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr(
+ "unexpected byte \\ff encountered while decoding utf8 string"));
+ ctx.Reset();
+}
+
+TEST(TestStringOps, TestByteSubstr) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+
+ const char* out_str;
+ out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 5, 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "String");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, -6, 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "String");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 0, 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 0, -500, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 1, 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 1, 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Test");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 1, 1000, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 5, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Str");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 5, 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "String");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, -100, 10, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+ EXPECT_FALSE(ctx.has_error());
+}
+
+TEST(TestStringOps, TestStrPos) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+
+ int pos;
+
+ pos = strpos_utf8_utf8(ctx_ptr, "TestString", 10, "String", 6);
+ EXPECT_EQ(pos, 5);
+ EXPECT_FALSE(ctx.has_error());
+
+ pos = strpos_utf8_utf8(ctx_ptr, "TestString", 10, "String", 6);
+ EXPECT_EQ(pos, 5);
+ EXPECT_FALSE(ctx.has_error());
+
+ pos = strpos_utf8_utf8(ctx_ptr, "abcabc", 6, "abc", 3);
+ EXPECT_EQ(pos, 1);
+ EXPECT_FALSE(ctx.has_error());
+
+ pos = strpos_utf8_utf8(ctx_ptr, "s†å†emçåå†d", 21, "çåå", 6);
+ EXPECT_EQ(pos, 7);
+ EXPECT_FALSE(ctx.has_error());
+
+ pos = strpos_utf8_utf8(ctx_ptr, "†barbar", 9, "bar", 3);
+ EXPECT_EQ(pos, 2);
+ EXPECT_FALSE(ctx.has_error());
+
+ pos = strpos_utf8_utf8(ctx_ptr, "", 0, "sub", 3);
+ EXPECT_EQ(pos, 0);
+ EXPECT_FALSE(ctx.has_error());
+
+ pos = strpos_utf8_utf8(ctx_ptr, "str", 3, "", 0);
+ EXPECT_EQ(pos, 0);
+ EXPECT_FALSE(ctx.has_error());
+
+ std::string d(
+ "a\xff"
+ "c");
+ pos = strpos_utf8_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), "c", 1);
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr(
+ "unexpected byte \\ff encountered while decoding utf8 string"));
+ ctx.Reset();
+}
+
+TEST(TestStringOps, TestReplace) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+
+ const char* out_str;
+ out_str = replace_utf8_utf8_utf8(ctx_ptr, "TestString1String2", 18, "String", 6,
+ "Replace", 7, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestReplace1Replace2");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str =
+ replace_utf8_utf8_utf8(ctx_ptr, "TestString1", 11, "String", 6, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Test1");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = replace_utf8_utf8_utf8(ctx_ptr, "", 0, "test", 4, "rep", 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = replace_utf8_utf8_utf8(ctx_ptr, "dž†çåå†", 17, "†", 3, "t", 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Çttçååt");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = replace_utf8_utf8_utf8(ctx_ptr, "TestString", 10, "", 0, "rep", 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str =
+ replace_utf8_utf8_utf8(ctx_ptr, "Test", 4, "TestString", 10, "rep", 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "Test");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str = replace_utf8_utf8_utf8(ctx_ptr, "Test", 4, "Test", 4, "", 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_FALSE(ctx.has_error());
+
+ out_str =
+ replace_utf8_utf8_utf8(ctx_ptr, "TestString", 10, "abc", 3, "xyz", 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "TestString");
+ EXPECT_FALSE(ctx.has_error());
+
+ replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "Hell", 4, "ell", 3, "ollow", 5, 5,
+ &out_len);
+ EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer overflow for output string"));
+ ctx.Reset();
+
+ replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "eeee", 4, "e", 1, "aaaa", 4, 14,
+ &out_len);
+ EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer overflow for output string"));
+ ctx.Reset();
+}
+
+TEST(TestStringOps, TestLeftString) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+ const char* out_str;
+
+ out_str = left_utf8_int32(ctx_ptr, "TestString", 10, 10, &out_len);
+ std::string output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "TestString");
+
+ out_str = left_utf8_int32(ctx_ptr, "", 0, 0, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "");
+
+ out_str = left_utf8_int32(ctx_ptr, "", 0, 500, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "");
+
+ out_str = left_utf8_int32(ctx_ptr, "TestString", 10, 3, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "Tes");
+
+ out_str = left_utf8_int32(ctx_ptr, "TestString", 10, -3, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "TestStr");
+
+ // the text length for this string is 10 (each utf8 char is represented by two bytes)
+ out_str = left_utf8_int32(ctx_ptr, "абвгд", 10, 3, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "абв");
+}
+
+TEST(TestStringOps, TestRightString) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+ const char* out_str;
+
+ out_str = right_utf8_int32(ctx_ptr, "TestString", 10, 10, &out_len);
+ std::string output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "TestString");
+
+ out_str = right_utf8_int32(ctx_ptr, "", 0, 0, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "");
+
+ out_str = right_utf8_int32(ctx_ptr, "", 0, 500, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "");
+
+ out_str = right_utf8_int32(ctx_ptr, "TestString", 10, 3, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "ing");
+
+ out_str = right_utf8_int32(ctx_ptr, "TestString", 10, -3, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "tString");
+
+ // the text length for this string is 10 (each utf8 char is represented by two bytes)
+ out_str = right_utf8_int32(ctx_ptr, "абвгд", 10, 3, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "вгд");
+}
+
+TEST(TestStringOps, TestBinaryString) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+ const char* out_str;
+
+ out_str = binary_string(ctx_ptr, "TestString", 10, &out_len);
+ std::string output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "TestString");
+
+ out_str = binary_string(ctx_ptr, "", 0, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "");
+
+ out_str = binary_string(ctx_ptr, "T", 1, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "T");
+
+ out_str = binary_string(ctx_ptr, "\\x41\\x42\\x43", 12, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "ABC");
+
+ out_str = binary_string(ctx_ptr, "\\x41", 4, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "A");
+
+ out_str = binary_string(ctx_ptr, "\\x6d\\x6D", 8, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "mm");
+
+ out_str = binary_string(ctx_ptr, "\\x6f\\x6d", 8, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "om");
+
+ out_str = binary_string(ctx_ptr, "\\x4f\\x4D", 8, &out_len);
+ output = std::string(out_str, out_len);
+ EXPECT_EQ(output, "OM");
+}
+
+TEST(TestStringOps, TestSplitPart) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+ const char* out_str;
+
+ out_str = split_part(ctx_ptr, "A,B,C", 5, ",", 1, 0, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+ EXPECT_THAT(
+ ctx.get_error(),
+ ::testing::HasSubstr("Index in split_part must be positive, value provided was 0"));
+
+ out_str = split_part(ctx_ptr, "A,B,C", 5, ",", 1, 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "A");
+
+ out_str = split_part(ctx_ptr, "A,B,C", 5, ",", 1, 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "B");
+
+ out_str = split_part(ctx_ptr, "A,B,C", 5, ",", 1, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "C");
+
+ out_str = split_part(ctx_ptr, "abc~@~def~@~ghi", 15, "~@~", 3, 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "abc");
+
+ out_str = split_part(ctx_ptr, "abc~@~def~@~ghi", 15, "~@~", 3, 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "def");
+
+ out_str = split_part(ctx_ptr, "abc~@~def~@~ghi", 15, "~@~", 3, 3, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "ghi");
+
+ // Result must be empty when the index is > no of elements
+ out_str = split_part(ctx_ptr, "123|456|789", 11, "|", 1, 4, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+
+ out_str = split_part(ctx_ptr, "123|", 4, "|", 1, 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "123");
+
+ out_str = split_part(ctx_ptr, "|123", 4, "|", 1, 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "");
+
+ out_str = split_part(ctx_ptr, "ç†ååçåå†", 18, "å", 2, 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "ç†");
+
+ out_str = split_part(ctx_ptr, "ç†ååçåå†", 18, "†åå", 6, 1, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "ç");
+
+ out_str = split_part(ctx_ptr, "ç†ååçåå†", 18, "†", 3, 2, &out_len);
+ EXPECT_EQ(std::string(out_str, out_len), "ååçåå");
+}
+
+TEST(TestStringOps, TestConvertTo) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+ const char* out_str;
+
+ const int32_t ALL_BYTES_MATCH = 0;
+
+ int32_t integer_value = std::numeric_limits<int32_t>::max();
+ out_str = convert_toINT(ctx_ptr, integer_value, &out_len);
+ EXPECT_EQ(out_len, sizeof(integer_value));
+ EXPECT_EQ(ALL_BYTES_MATCH, memcmp(out_str, &integer_value, out_len));
+
+ int64_t big_integer_value = std::numeric_limits<int64_t>::max();
+ out_str = convert_toBIGINT(ctx_ptr, big_integer_value, &out_len);
+ EXPECT_EQ(out_len, sizeof(big_integer_value));
+ EXPECT_EQ(ALL_BYTES_MATCH, memcmp(out_str, &big_integer_value, out_len));
+
+ float float_value = std::numeric_limits<float>::max();
+ out_str = convert_toFLOAT(ctx_ptr, float_value, &out_len);
+ EXPECT_EQ(out_len, sizeof(float_value));
+ EXPECT_EQ(ALL_BYTES_MATCH, memcmp(out_str, &float_value, out_len));
+
+ double double_value = std::numeric_limits<double>::max();
+ out_str = convert_toDOUBLE(ctx_ptr, double_value, &out_len);
+ EXPECT_EQ(out_len, sizeof(double_value));
+ EXPECT_EQ(ALL_BYTES_MATCH, memcmp(out_str, &double_value, out_len));
+
+ const char* test_string = "test string";
+ int32_t str_len = 11;
+ out_str = convert_toUTF8(ctx_ptr, test_string, str_len, &out_len);
+ EXPECT_EQ(out_len, str_len);
+ EXPECT_EQ(ALL_BYTES_MATCH, memcmp(out_str, test_string, out_len));
+}
+
+TEST(TestStringOps, TestConvertToBigEndian) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+ gdv_int32 out_len = 0;
+ gdv_int32 out_len_big_endian = 0;
+ const char* out_str;
+ const char* out_str_big_endian;
+
+ int64_t big_integer_value = std::numeric_limits<int64_t>::max();
+ out_str = convert_toBIGINT(ctx_ptr, big_integer_value, &out_len);
+ out_str_big_endian =
+ convert_toBIGINT_be(ctx_ptr, big_integer_value, &out_len_big_endian);
+ EXPECT_EQ(out_len_big_endian, sizeof(big_integer_value));
+ EXPECT_EQ(out_len_big_endian, out_len);
+
+#if ARROW_LITTLE_ENDIAN
+ // Checks that bytes are in reverse order
+ for (auto i = 0; i < out_len; i++) {
+ EXPECT_EQ(out_str[i], out_str_big_endian[out_len - (i + 1)]);
+ }
+#else
+ for (auto i = 0; i < out_len; i++) {
+ EXPECT_EQ(out_str[i], out_str_big_endian[i]);
+ }
+#endif
+
+ double double_value = std::numeric_limits<double>::max();
+ out_str = convert_toDOUBLE(ctx_ptr, double_value, &out_len);
+ out_str_big_endian = convert_toDOUBLE_be(ctx_ptr, double_value, &out_len_big_endian);
+ EXPECT_EQ(out_len_big_endian, sizeof(double_value));
+ EXPECT_EQ(out_len_big_endian, out_len);
+
+#if ARROW_LITTLE_ENDIAN
+ // Checks that bytes are in reverse order
+ for (auto i = 0; i < out_len; i++) {
+ EXPECT_EQ(out_str[i], out_str_big_endian[out_len - (i + 1)]);
+ }
+#else
+ for (auto i = 0; i < out_len; i++) {
+ EXPECT_EQ(out_str[i], out_str_big_endian[i]);
+ }
+#endif
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/precompiled/testing.h b/src/arrow/cpp/src/gandiva/precompiled/testing.h
new file mode 100644
index 000000000..c41bc5471
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/testing.h
@@ -0,0 +1,43 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <ctime>
+#include <string>
+
+#include <gtest/gtest.h>
+
+#include "arrow/util/logging.h"
+#include "arrow/util/value_parsing.h"
+
+#include "gandiva/date_utils.h"
+#include "gandiva/precompiled/types.h"
+
+namespace gandiva {
+
+static inline gdv_timestamp StringToTimestamp(const std::string& s) {
+ int64_t out = 0;
+ bool success = ::arrow::internal::ParseTimestampStrptime(
+ s.c_str(), s.length(), "%Y-%m-%d %H:%M:%S", /*ignore_time_in_day=*/false,
+ /*allow_trailing_chars=*/false, ::arrow::TimeUnit::SECOND, &out);
+ DCHECK(success);
+ ARROW_UNUSED(success);
+ return out * 1000;
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/precompiled/time.cc b/src/arrow/cpp/src/gandiva/precompiled/time.cc
new file mode 100644
index 000000000..336f69226
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/time.cc
@@ -0,0 +1,894 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "./epoch_time_point.h"
+
+extern "C" {
+
+#define __STDC_FORMAT_MACROS
+#include <inttypes.h>
+#include <stdlib.h>
+#include <string.h>
+#include <time.h>
+
+#include "./time_constants.h"
+#include "./time_fields.h"
+#include "./types.h"
+
+#define MINS_IN_HOUR 60
+#define SECONDS_IN_MINUTE 60
+#define SECONDS_IN_HOUR (SECONDS_IN_MINUTE) * (MINS_IN_HOUR)
+
+#define HOURS_IN_DAY 24
+
+// Expand inner macro for all date types.
+#define DATE_TYPES(INNER) \
+ INNER(date64) \
+ INNER(timestamp)
+
+// Expand inner macro for all base numeric types.
+#define NUMERIC_TYPES(INNER) \
+ INNER(int8) \
+ INNER(int16) \
+ INNER(int32) \
+ INNER(int64) \
+ INNER(uint8) \
+ INNER(uint16) \
+ INNER(uint32) \
+ INNER(uint64) \
+ INNER(float32) \
+ INNER(float64)
+
+// Extract millennium
+#define EXTRACT_MILLENNIUM(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractMillennium##_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ return (1900 + tp.TmYear() - 1) / 1000 + 1; \
+ }
+
+DATE_TYPES(EXTRACT_MILLENNIUM)
+
+// Extract century
+#define EXTRACT_CENTURY(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractCentury##_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ return (1900 + tp.TmYear() - 1) / 100 + 1; \
+ }
+
+DATE_TYPES(EXTRACT_CENTURY)
+
+// Extract decade
+#define EXTRACT_DECADE(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractDecade##_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ return (1900 + tp.TmYear()) / 10; \
+ }
+
+DATE_TYPES(EXTRACT_DECADE)
+
+// Extract year.
+#define EXTRACT_YEAR(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractYear##_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ return 1900 + tp.TmYear(); \
+ }
+
+DATE_TYPES(EXTRACT_YEAR)
+
+#define EXTRACT_DOY(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractDoy##_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ return 1 + tp.TmYday(); \
+ }
+
+DATE_TYPES(EXTRACT_DOY)
+
+#define EXTRACT_QUARTER(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractQuarter##_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ return tp.TmMon() / 3 + 1; \
+ }
+
+DATE_TYPES(EXTRACT_QUARTER)
+
+#define EXTRACT_MONTH(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractMonth##_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ return 1 + tp.TmMon(); \
+ }
+
+DATE_TYPES(EXTRACT_MONTH)
+
+#define JAN1_WDAY(tp) ((tp.TmWday() - (tp.TmYday() % 7) + 7) % 7)
+
+bool IsLeapYear(int yy) {
+ if ((yy % 4) != 0) {
+ // not divisible by 4
+ return false;
+ }
+
+ // yy = 4x
+ if ((yy % 400) == 0) {
+ // yy = 400x
+ return true;
+ }
+
+ // yy = 4x, return true if yy != 100x
+ return ((yy % 100) != 0);
+}
+
+// Day belongs to current year
+// Note that TmYday is 0 for Jan 1 (subtract 1 from day in the below examples)
+//
+// If Jan 1 is Mon, (TmYday) / 7 + 1 (Jan 1->WK1, Jan 8->WK2, etc)
+// If Jan 1 is Tues, (TmYday + 1) / 7 + 1 (Jan 1->WK1, Jan 7->WK2, etc)
+// If Jan 1 is Wed, (TmYday + 2) / 7 + 1
+// If Jan 1 is Thu, (TmYday + 3) / 7 + 1
+//
+// If Jan 1 is Fri, Sat or Sun, the first few days belong to the previous year
+// If Jan 1 is Fri, (TmYday - 3) / 7 + 1 (Jan 4->WK1, Jan 11->WK2)
+// If Jan 1 is Sat, (TmYday - 2) / 7 + 1 (Jan 3->WK1, Jan 10->WK2)
+// If Jan 1 is Sun, (TmYday - 1) / 7 + 1 (Jan 2->WK1, Jan 9->WK2)
+int weekOfCurrentYear(const EpochTimePoint& tp) {
+ int jan1_wday = JAN1_WDAY(tp);
+ switch (jan1_wday) {
+ // Monday
+ case 1:
+ // Tuesday
+ case 2:
+ // Wednesday
+ case 3:
+ // Thursday
+ case 4: {
+ return (tp.TmYday() + jan1_wday - 1) / 7 + 1;
+ }
+ // Friday
+ case 5:
+ // Saturday
+ case 6: {
+ return (tp.TmYday() - (8 - jan1_wday)) / 7 + 1;
+ }
+ // Sunday
+ case 0: {
+ return (tp.TmYday() - 1) / 7 + 1;
+ }
+ }
+
+ // cannot reach here
+ // keep compiler happy
+ return 0;
+}
+
+// Jan 1-3
+// If Jan 1 is one of Mon, Tue, Wed, Thu - belongs to week of current year
+// If Jan 1 is Fri/Sat/Sun - belongs to previous year
+int getJanWeekOfYear(const EpochTimePoint& tp) {
+ int jan1_wday = JAN1_WDAY(tp);
+
+ if ((jan1_wday >= 1) && (jan1_wday <= 4)) {
+ // Jan 1-3 with the week belonging to this year
+ return 1;
+ }
+
+ if (jan1_wday == 5) {
+ // Jan 1 is a Fri
+ // Jan 1-3 belong to previous year. Dec 31 of previous year same week # as Jan 1-3
+ // previous year is a leap year:
+ // Prev Jan 1 is a Wed. Jan 6th is Mon
+ // Dec 31 - Jan 6 = 366 - 5 = 361
+ // week from Jan 6 = (361 - 1) / 7 + 1 = 52
+ // week # in previous year = 52 + 1 = 53
+ //
+ // previous year is not a leap year. Jan 1 is Thu. Jan 5th is Mon
+ // Dec 31 - Jan 5 = 365 - 4 = 361
+ // week from Jan 5 = (361 - 1) / 7 + 1 = 52
+ // week # in previous year = 52 + 1 = 53
+ return 53;
+ }
+
+ if (jan1_wday == 0) {
+ // Jan 1 is a Sun
+ if (tp.TmMday() > 1) {
+ // Jan 2 and 3 belong to current year
+ return 1;
+ }
+
+ // day belongs to previous year. Same as Dec 31
+ // Same as the case where Jan 1 is a Fri, except that previous year
+ // does not have an extra week
+ // Hence, return 52
+ return 52;
+ }
+
+ // Jan 1 is a Sat
+ // Jan 1-2 belong to previous year
+ if (tp.TmMday() == 3) {
+ // Jan 3, return 1
+ return 1;
+ }
+
+ // prev Jan 1 is leap year
+ // prev Jan 1 is a Thu
+ // return 53 (extra week)
+ if (IsLeapYear(1900 + tp.TmYear() - 1)) {
+ return 53;
+ }
+
+ // prev Jan 1 is not a leap year
+ // prev Jan 1 is a Fri
+ // return 52 (no extra week)
+ return 52;
+}
+
+// Dec 29-31
+int getDecWeekOfYear(const EpochTimePoint& tp) {
+ int next_jan1_wday = (tp.TmWday() + (31 - tp.TmMday()) + 1) % 7;
+
+ if (next_jan1_wday == 4) {
+ // next Jan 1 is a Thu
+ // day belongs to week 1 of next year
+ return 1;
+ }
+
+ if (next_jan1_wday == 3) {
+ // next Jan 1 is a Wed
+ // Dec 31 and 30 belong to next year - return 1
+ if (tp.TmMday() != 29) {
+ return 1;
+ }
+
+ // Dec 29 belongs to current year
+ return weekOfCurrentYear(tp);
+ }
+
+ if (next_jan1_wday == 2) {
+ // next Jan 1 is a Tue
+ // Dec 31 belongs to next year - return 1
+ if (tp.TmMday() == 31) {
+ return 1;
+ }
+
+ // Dec 29 and 30 belong to current year
+ return weekOfCurrentYear(tp);
+ }
+
+ // next Jan 1 is a Fri/Sat/Sun. No day from this year belongs to that week
+ // next Jan 1 is a Mon. No day from this year belongs to that week
+ return weekOfCurrentYear(tp);
+}
+
+// Week of year is determined by ISO 8601 standard
+// Take a look at: https://en.wikipedia.org/wiki/ISO_week_date
+//
+// Important points to note:
+// Week starts with a Monday and ends with a Sunday
+// A week can have some days in this year and some days in the previous/next year
+// This is true for the first and last weeks
+//
+// The first week of the year should have at-least 4 days in the current year
+// The last week of the year should have at-least 4 days in the current year
+//
+// A given day might belong to the first week of the next year - e.g Dec 29, 30 and 31
+// A given day might belong to the last week of the previous year - e.g. Jan 1, 2 and 3
+//
+// Algorithm:
+// If day belongs to week in current year, weekOfCurrentYear
+//
+// If day is Jan 1-3, see getJanWeekOfYear
+// If day is Dec 29-21, see getDecWeekOfYear
+//
+gdv_int64 weekOfYear(const EpochTimePoint& tp) {
+ if (tp.TmYday() < 3) {
+ // Jan 1-3
+ return getJanWeekOfYear(tp);
+ }
+
+ if ((tp.TmMon() == 11) && (tp.TmMday() >= 29)) {
+ // Dec 29-31
+ return getDecWeekOfYear(tp);
+ }
+
+ return weekOfCurrentYear(tp);
+}
+
+#define EXTRACT_WEEK(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractWeek##_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ return weekOfYear(tp); \
+ }
+
+DATE_TYPES(EXTRACT_WEEK)
+
+#define EXTRACT_DOW(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractDow##_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ return 1 + tp.TmWday(); \
+ }
+
+DATE_TYPES(EXTRACT_DOW)
+
+#define EXTRACT_DAY(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractDay##_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ return tp.TmMday(); \
+ }
+
+DATE_TYPES(EXTRACT_DAY)
+
+#define EXTRACT_HOUR(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractHour##_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ return tp.TmHour(); \
+ }
+
+DATE_TYPES(EXTRACT_HOUR)
+
+#define EXTRACT_MINUTE(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractMinute##_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ return tp.TmMin(); \
+ }
+
+DATE_TYPES(EXTRACT_MINUTE)
+
+#define EXTRACT_SECOND(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractSecond##_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ return tp.TmSec(); \
+ }
+
+DATE_TYPES(EXTRACT_SECOND)
+
+#define EXTRACT_EPOCH(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractEpoch##_##TYPE(gdv_##TYPE millis) { return MILLIS_TO_SEC(millis); }
+
+DATE_TYPES(EXTRACT_EPOCH)
+
+// Functions that work on millis in a day
+#define EXTRACT_SECOND_TIME(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractSecond##_##TYPE(gdv_##TYPE millis) { \
+ gdv_int64 seconds_of_day = MILLIS_TO_SEC(millis); \
+ gdv_int64 sec = seconds_of_day % SECONDS_IN_MINUTE; \
+ return sec; \
+ }
+
+EXTRACT_SECOND_TIME(time32)
+
+#define EXTRACT_MINUTE_TIME(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractMinute##_##TYPE(gdv_##TYPE millis) { \
+ gdv_##TYPE mins = MILLIS_TO_MINS(millis); \
+ return (mins % (MINS_IN_HOUR)); \
+ }
+
+EXTRACT_MINUTE_TIME(time32)
+
+#define EXTRACT_HOUR_TIME(TYPE) \
+ FORCE_INLINE \
+ gdv_int64 extractHour##_##TYPE(gdv_##TYPE millis) { return MILLIS_TO_HOUR(millis); }
+
+EXTRACT_HOUR_TIME(time32)
+
+#define DATE_TRUNC_FIXED_UNIT(NAME, TYPE, NMILLIS_IN_UNIT) \
+ FORCE_INLINE \
+ gdv_##TYPE NAME##_##TYPE(gdv_##TYPE millis) { \
+ return ((millis / NMILLIS_IN_UNIT) * NMILLIS_IN_UNIT); \
+ }
+
+#define DATE_TRUNC_WEEK(TYPE) \
+ FORCE_INLINE \
+ gdv_##TYPE date_trunc_Week_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ int ndays_to_trunc = 0; \
+ if (tp.TmWday() == 0) { \
+ /* Sunday */ \
+ ndays_to_trunc = 6; \
+ } else { \
+ /* All other days */ \
+ ndays_to_trunc = tp.TmWday() - 1; \
+ } \
+ return tp.AddDays(-ndays_to_trunc).ClearTimeOfDay().MillisSinceEpoch(); \
+ }
+
+#define DATE_TRUNC_MONTH_UNITS(NAME, TYPE, NMONTHS_IN_UNIT) \
+ FORCE_INLINE \
+ gdv_##TYPE NAME##_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ int ndays_to_trunc = tp.TmMday() - 1; \
+ int nmonths_to_trunc = \
+ tp.TmMon() - ((tp.TmMon() / NMONTHS_IN_UNIT) * NMONTHS_IN_UNIT); \
+ return tp.AddDays(-ndays_to_trunc) \
+ .AddMonths(-nmonths_to_trunc) \
+ .ClearTimeOfDay() \
+ .MillisSinceEpoch(); \
+ }
+
+#define DATE_TRUNC_YEAR_UNITS(NAME, TYPE, NYEARS_IN_UNIT, OFF_BY) \
+ FORCE_INLINE \
+ gdv_##TYPE NAME##_##TYPE(gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ int ndays_to_trunc = tp.TmMday() - 1; \
+ int nmonths_to_trunc = tp.TmMon(); \
+ int year = 1900 + tp.TmYear(); \
+ year = ((year - OFF_BY) / NYEARS_IN_UNIT) * NYEARS_IN_UNIT + OFF_BY; \
+ int nyears_to_trunc = tp.TmYear() - (year - 1900); \
+ return tp.AddDays(-ndays_to_trunc) \
+ .AddMonths(-nmonths_to_trunc) \
+ .AddYears(-nyears_to_trunc) \
+ .ClearTimeOfDay() \
+ .MillisSinceEpoch(); \
+ }
+
+#define DATE_TRUNC_FUNCTIONS(TYPE) \
+ DATE_TRUNC_FIXED_UNIT(date_trunc_Second, TYPE, MILLIS_IN_SEC) \
+ DATE_TRUNC_FIXED_UNIT(date_trunc_Minute, TYPE, MILLIS_IN_MIN) \
+ DATE_TRUNC_FIXED_UNIT(date_trunc_Hour, TYPE, MILLIS_IN_HOUR) \
+ DATE_TRUNC_FIXED_UNIT(date_trunc_Day, TYPE, MILLIS_IN_DAY) \
+ DATE_TRUNC_WEEK(TYPE) \
+ DATE_TRUNC_MONTH_UNITS(date_trunc_Month, TYPE, 1) \
+ DATE_TRUNC_MONTH_UNITS(date_trunc_Quarter, TYPE, 3) \
+ DATE_TRUNC_MONTH_UNITS(date_trunc_Year, TYPE, 12) \
+ DATE_TRUNC_YEAR_UNITS(date_trunc_Decade, TYPE, 10, 0) \
+ DATE_TRUNC_YEAR_UNITS(date_trunc_Century, TYPE, 100, 1) \
+ DATE_TRUNC_YEAR_UNITS(date_trunc_Millennium, TYPE, 1000, 1)
+
+DATE_TRUNC_FUNCTIONS(date64)
+DATE_TRUNC_FUNCTIONS(timestamp)
+
+#define LAST_DAY_FUNC(TYPE) \
+ FORCE_INLINE \
+ gdv_date64 last_day_from_##TYPE(gdv_date64 millis) { \
+ EpochTimePoint received_day(millis); \
+ const auto& day_without_hours_and_sec = received_day.ClearTimeOfDay(); \
+ \
+ int received_day_in_month = day_without_hours_and_sec.TmMday(); \
+ const auto& first_day_in_month = \
+ day_without_hours_and_sec.AddDays(1 - received_day_in_month); \
+ \
+ const auto& month_last_day = first_day_in_month.AddMonths(1).AddDays(-1); \
+ \
+ return month_last_day.MillisSinceEpoch(); \
+ }
+
+DATE_TYPES(LAST_DAY_FUNC)
+
+FORCE_INLINE
+gdv_date64 castDATE_int64(gdv_int64 in) { return in; }
+
+FORCE_INLINE
+gdv_date32 castDATE_int32(gdv_int32 in) { return in; }
+
+FORCE_INLINE
+gdv_date64 castDATE_date32(gdv_date32 days) {
+ return days * static_cast<gdv_date64>(MILLIS_IN_DAY);
+}
+
+static int days_in_month[] = {31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31};
+
+bool IsLastDayOfMonth(const EpochTimePoint& tp) {
+ if (tp.TmMon() != 1) {
+ // not February. Don't worry about leap year
+ return (tp.TmMday() == days_in_month[tp.TmMon()]);
+ }
+
+ // this is February, check if the day is 28 or 29
+ if (tp.TmMday() < 28) {
+ return false;
+ }
+
+ if (tp.TmMday() == 29) {
+ // Feb 29th
+ return true;
+ }
+
+ // check if year is non-leap year
+ return !IsLeapYear(tp.TmYear());
+}
+
+FORCE_INLINE
+bool is_valid_time(const int hours, const int minutes, const int seconds) {
+ return hours >= 0 && hours < 24 && minutes >= 0 && minutes < 60 && seconds >= 0 &&
+ seconds < 60;
+}
+
+// MONTHS_BETWEEN returns number of months between dates date1 and date2.
+// If date1 is later than date2, then the result is positive.
+// If date1 is earlier than date2, then the result is negative.
+// If date1 and date2 are either the same days of the month or both last days of months,
+// then the result is always an integer. Otherwise Oracle Database calculates the
+// fractional portion of the result based on a 31-day month and considers the difference
+// in time components date1 and date2
+#define MONTHS_BETWEEN(TYPE) \
+ FORCE_INLINE \
+ double months_between##_##TYPE##_##TYPE(uint64_t endEpoch, uint64_t startEpoch) { \
+ EpochTimePoint endTime(endEpoch); \
+ EpochTimePoint startTime(startEpoch); \
+ int endYear = endTime.TmYear(); \
+ int endMonth = endTime.TmMon(); \
+ int startYear = startTime.TmYear(); \
+ int startMonth = startTime.TmMon(); \
+ int monthsDiff = (endYear - startYear) * 12 + (endMonth - startMonth); \
+ if ((endTime.TmMday() == startTime.TmMday()) || \
+ (IsLastDayOfMonth(endTime) && IsLastDayOfMonth(startTime))) { \
+ return static_cast<double>(monthsDiff); \
+ } \
+ double diffDays = static_cast<double>(endTime.TmMday() - startTime.TmMday()) / \
+ static_cast<double>(31); \
+ double diffHours = static_cast<double>(endTime.TmHour() - startTime.TmHour()) + \
+ static_cast<double>(endTime.TmMin() - startTime.TmMin()) / \
+ static_cast<double>(MINS_IN_HOUR) + \
+ static_cast<double>(endTime.TmSec() - startTime.TmSec()) / \
+ static_cast<double>(SECONDS_IN_HOUR); \
+ return static_cast<double>(monthsDiff) + diffDays + \
+ diffHours / static_cast<double>(HOURS_IN_DAY * 31); \
+ }
+
+DATE_TYPES(MONTHS_BETWEEN)
+
+FORCE_INLINE
+void set_error_for_date(gdv_int32 length, const char* input, const char* msg,
+ int64_t execution_context) {
+ int size = length + static_cast<int>(strlen(msg)) + 1;
+ char* error = reinterpret_cast<char*>(malloc(size));
+ snprintf(error, size, "%s%s", msg, input);
+ gdv_fn_context_set_error_msg(execution_context, error);
+ free(error);
+}
+
+gdv_date64 castDATE_utf8(int64_t context, const char* input, gdv_int32 length) {
+ using arrow_vendored::date::day;
+ using arrow_vendored::date::month;
+ using arrow_vendored::date::sys_days;
+ using arrow_vendored::date::year;
+ using arrow_vendored::date::year_month_day;
+ using gandiva::TimeFields;
+ // format : 0 is year, 1 is month and 2 is day.
+ int dateFields[3];
+ int dateIndex = 0, index = 0, value = 0;
+ int year_str_len = 0;
+ while (dateIndex < 3 && index < length) {
+ if (!isdigit(input[index])) {
+ dateFields[dateIndex++] = value;
+ value = 0;
+ } else {
+ value = (value * 10) + (input[index] - '0');
+ if (dateIndex == TimeFields::kYear) {
+ year_str_len++;
+ }
+ }
+ index++;
+ }
+
+ if (dateIndex < 3) {
+ // If we reached the end of input, we would have not encountered a separator
+ // store the last value
+ dateFields[dateIndex++] = value;
+ }
+ const char* msg = "Not a valid date value ";
+ if (dateIndex != 3) {
+ set_error_for_date(length, input, msg, context);
+ return 0;
+ }
+
+ /* Handle two digit years
+ * If range of two digits is between 70 - 99 then year = 1970 - 1999
+ * Else if two digits is between 00 - 69 = 2000 - 2069
+ */
+ if (dateFields[TimeFields::kYear] < 100 && year_str_len < 4) {
+ if (dateFields[TimeFields::kYear] < 70) {
+ dateFields[TimeFields::kYear] += 2000;
+ } else {
+ dateFields[TimeFields::kYear] += 1900;
+ }
+ }
+ year_month_day date = year(dateFields[TimeFields::kYear]) /
+ month(dateFields[TimeFields::kMonth]) /
+ day(dateFields[TimeFields::kDay]);
+ if (!date.ok()) {
+ set_error_for_date(length, input, msg, context);
+ return 0;
+ }
+ return std::chrono::time_point_cast<std::chrono::milliseconds>(sys_days(date))
+ .time_since_epoch()
+ .count();
+}
+
+/*
+ * Input consists of mandatory and optional fields.
+ * Mandatory fields are year, month and day.
+ * Optional fields are time, displacement and zone.
+ * Format is <year-month-day>[ hours:minutes:seconds][.millis][ displacement|zone]
+ */
+gdv_timestamp castTIMESTAMP_utf8(int64_t context, const char* input, gdv_int32 length) {
+ using arrow_vendored::date::day;
+ using arrow_vendored::date::month;
+ using arrow_vendored::date::sys_days;
+ using arrow_vendored::date::year;
+ using arrow_vendored::date::year_month_day;
+ using gandiva::TimeFields;
+ using std::chrono::hours;
+ using std::chrono::milliseconds;
+ using std::chrono::minutes;
+ using std::chrono::seconds;
+
+ int ts_fields[9] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
+ gdv_boolean add_displacement = true;
+ gdv_boolean encountered_zone = false;
+ int year_str_len = 0, sub_seconds_len = 0;
+ int ts_field_index = TimeFields::kYear, index = 0, value = 0;
+ while (ts_field_index < TimeFields::kMax && index < length) {
+ if (isdigit(input[index])) {
+ value = (value * 10) + (input[index] - '0');
+ if (ts_field_index == TimeFields::kYear) {
+ year_str_len++;
+ }
+ if (ts_field_index == TimeFields::kSubSeconds) {
+ sub_seconds_len++;
+ }
+ } else {
+ ts_fields[ts_field_index] = value;
+ value = 0;
+
+ switch (input[index]) {
+ case '.':
+ case ':':
+ case ' ':
+ ts_field_index++;
+ break;
+ case '+':
+ // +08:00, means time zone is 8 hours ahead. Need to subtract.
+ add_displacement = false;
+ ts_field_index = TimeFields::kDisplacementHours;
+ break;
+ case '-':
+ // Overloaded as date separator and negative displacement.
+ ts_field_index = (ts_field_index < 3) ? (ts_field_index + 1)
+ : TimeFields::kDisplacementHours;
+ break;
+ default:
+ encountered_zone = true;
+ break;
+ }
+ }
+ if (encountered_zone) {
+ break;
+ }
+ index++;
+ }
+
+ // Store the last value
+ if (ts_field_index < TimeFields::kMax) {
+ ts_fields[ts_field_index++] = value;
+ }
+
+ // adjust the year
+ if (ts_fields[TimeFields::kYear] < 100 && year_str_len < 4) {
+ if (ts_fields[TimeFields::kYear] < 70) {
+ ts_fields[TimeFields::kYear] += 2000;
+ } else {
+ ts_fields[TimeFields::kYear] += 1900;
+ }
+ }
+
+ // adjust the milliseconds
+ if (sub_seconds_len > 0) {
+ if (sub_seconds_len > 3) {
+ const char* msg = "Invalid millis for timestamp value ";
+ set_error_for_date(length, input, msg, context);
+ return 0;
+ }
+ while (sub_seconds_len < 3) {
+ ts_fields[TimeFields::kSubSeconds] *= 10;
+ sub_seconds_len++;
+ }
+ }
+ // handle timezone
+ if (encountered_zone) {
+ int err = 0;
+ gdv_timestamp ret_time = 0;
+ err = gdv_fn_time_with_zone(&ts_fields[0], (input + index), (length - index),
+ &ret_time);
+ if (err) {
+ const char* msg = "Invalid timestamp or unknown zone for timestamp value ";
+ set_error_for_date(length, input, msg, context);
+ return 0;
+ }
+ return ret_time;
+ }
+
+ year_month_day date = year(ts_fields[TimeFields::kYear]) /
+ month(ts_fields[TimeFields::kMonth]) /
+ day(ts_fields[TimeFields::kDay]);
+ if (!date.ok()) {
+ const char* msg = "Not a valid day for timestamp value ";
+ set_error_for_date(length, input, msg, context);
+ return 0;
+ }
+
+ if (!is_valid_time(ts_fields[TimeFields::kHours], ts_fields[TimeFields::kMinutes],
+ ts_fields[TimeFields::kSeconds])) {
+ const char* msg = "Not a valid time for timestamp value ";
+ set_error_for_date(length, input, msg, context);
+ return 0;
+ }
+
+ auto date_time = sys_days(date) + hours(ts_fields[TimeFields::kHours]) +
+ minutes(ts_fields[TimeFields::kMinutes]) +
+ seconds(ts_fields[TimeFields::kSeconds]) +
+ milliseconds(ts_fields[TimeFields::kSubSeconds]);
+ if (ts_fields[TimeFields::kDisplacementHours] ||
+ ts_fields[TimeFields::kDisplacementMinutes]) {
+ auto displacement_time = hours(ts_fields[TimeFields::kDisplacementHours]) +
+ minutes(ts_fields[TimeFields::kDisplacementMinutes]);
+ date_time = (add_displacement) ? (date_time + displacement_time)
+ : (date_time - displacement_time);
+ }
+ return std::chrono::time_point_cast<milliseconds>(date_time).time_since_epoch().count();
+}
+
+gdv_timestamp castTIMESTAMP_date64(gdv_date64 date_in_millis) { return date_in_millis; }
+
+gdv_timestamp castTIMESTAMP_int64(gdv_int64 in) { return in; }
+
+gdv_date64 castDATE_timestamp(gdv_timestamp timestamp_in_millis) {
+ EpochTimePoint tp(timestamp_in_millis);
+ return tp.ClearTimeOfDay().MillisSinceEpoch();
+}
+
+gdv_time32 castTIME_timestamp(gdv_timestamp timestamp_in_millis) {
+ // Retrieves a timestamp and returns the number of milliseconds since the midnight
+ EpochTimePoint tp(timestamp_in_millis);
+ auto tp_at_midnight = tp.ClearTimeOfDay();
+
+ int64_t millis_since_midnight =
+ tp.MillisSinceEpoch() - tp_at_midnight.MillisSinceEpoch();
+
+ return static_cast<int32_t>(millis_since_midnight);
+}
+
+const char* castVARCHAR_timestamp_int64(gdv_int64 context, gdv_timestamp in,
+ gdv_int64 length, gdv_int32* out_len) {
+ gdv_int64 year = extractYear_timestamp(in);
+ gdv_int64 month = extractMonth_timestamp(in);
+ gdv_int64 day = extractDay_timestamp(in);
+ gdv_int64 hour = extractHour_timestamp(in);
+ gdv_int64 minute = extractMinute_timestamp(in);
+ gdv_int64 second = extractSecond_timestamp(in);
+ gdv_int64 millis = in % MILLIS_IN_SEC;
+
+ static const int kTimeStampStringLen = 23;
+ const int char_buffer_length = kTimeStampStringLen + 1; // snprintf adds \0
+ char char_buffer[char_buffer_length];
+
+ // yyyy-MM-dd hh:mm:ss.sss
+ int res = snprintf(char_buffer, char_buffer_length,
+ "%04" PRId64 "-%02" PRId64 "-%02" PRId64 " %02" PRId64 ":%02" PRId64
+ ":%02" PRId64 ".%03" PRId64,
+ year, month, day, hour, minute, second, millis);
+ if (res < 0) {
+ gdv_fn_context_set_error_msg(context, "Could not format the timestamp");
+ return "";
+ }
+
+ *out_len = static_cast<gdv_int32>(length);
+ if (*out_len > kTimeStampStringLen) {
+ *out_len = kTimeStampStringLen;
+ }
+
+ if (*out_len <= 0) {
+ if (*out_len < 0) {
+ gdv_fn_context_set_error_msg(context, "Length of output string cannot be negative");
+ }
+ *out_len = 0;
+ return "";
+ }
+
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+
+ memcpy(ret, char_buffer, *out_len);
+ return ret;
+}
+
+FORCE_INLINE
+gdv_int64 extractDay_daytimeinterval(gdv_day_time_interval in) {
+ gdv_int32 days = static_cast<gdv_int32>(in & 0x00000000FFFFFFFF);
+ return static_cast<gdv_int64>(days);
+}
+
+FORCE_INLINE
+gdv_int64 extractMillis_daytimeinterval(gdv_day_time_interval in) {
+ gdv_int32 millis = static_cast<gdv_int32>((in & 0xFFFFFFFF00000000) >> 32);
+ return static_cast<gdv_int64>(millis);
+}
+
+FORCE_INLINE
+gdv_int64 castBIGINT_daytimeinterval(gdv_day_time_interval in) {
+ return extractMillis_daytimeinterval(in) +
+ extractDay_daytimeinterval(in) * MILLIS_IN_DAY;
+}
+
+// Convert the seconds since epoch argument to timestamp
+#define TO_TIMESTAMP(TYPE) \
+ FORCE_INLINE \
+ gdv_timestamp to_timestamp##_##TYPE(gdv_##TYPE seconds) { \
+ return static_cast<gdv_timestamp>(seconds * MILLIS_IN_SEC); \
+ }
+
+NUMERIC_TYPES(TO_TIMESTAMP)
+
+// Convert the seconds since epoch argument to time
+#define TO_TIME(TYPE) \
+ FORCE_INLINE \
+ gdv_time32 to_time##_##TYPE(gdv_##TYPE seconds) { \
+ EpochTimePoint tp(static_cast<int64_t>(seconds * MILLIS_IN_SEC)); \
+ return static_cast<gdv_time32>(tp.TimeOfDay().to_duration().count()); \
+ }
+
+NUMERIC_TYPES(TO_TIME)
+
+#define CAST_INT_YEAR_INTERVAL(TYPE, OUT_TYPE) \
+ FORCE_INLINE \
+ gdv_##OUT_TYPE TYPE##_year_interval(gdv_month_interval in) { \
+ return static_cast<gdv_##OUT_TYPE>(in / 12.0); \
+ }
+
+CAST_INT_YEAR_INTERVAL(castBIGINT, int64)
+CAST_INT_YEAR_INTERVAL(castINT, int32)
+
+#define CAST_NULLABLE_INTERVAL_DAY(TYPE) \
+ FORCE_INLINE \
+ gdv_day_time_interval castNULLABLEINTERVALDAY_##TYPE(gdv_##TYPE in) { \
+ return static_cast<gdv_day_time_interval>(in); \
+ }
+
+CAST_NULLABLE_INTERVAL_DAY(int32)
+CAST_NULLABLE_INTERVAL_DAY(int64)
+
+#define CAST_NULLABLE_INTERVAL_YEAR(TYPE) \
+ FORCE_INLINE \
+ gdv_month_interval castNULLABLEINTERVALYEAR_##TYPE(int64_t context, gdv_##TYPE in) { \
+ gdv_month_interval value = static_cast<gdv_month_interval>(in); \
+ if (value != in) { \
+ gdv_fn_context_set_error_msg(context, "Integer overflow"); \
+ } \
+ return value; \
+ }
+
+CAST_NULLABLE_INTERVAL_YEAR(int32)
+CAST_NULLABLE_INTERVAL_YEAR(int64)
+
+} // extern "C"
diff --git a/src/arrow/cpp/src/gandiva/precompiled/time_constants.h b/src/arrow/cpp/src/gandiva/precompiled/time_constants.h
new file mode 100644
index 000000000..015ef4bf9
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/time_constants.h
@@ -0,0 +1,30 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#define MILLIS_IN_SEC (1000)
+#define MILLIS_IN_MIN (60 * MILLIS_IN_SEC)
+#define MILLIS_IN_HOUR (60 * MILLIS_IN_MIN)
+#define MILLIS_IN_DAY (24 * MILLIS_IN_HOUR)
+#define MILLIS_IN_WEEK (7 * MILLIS_IN_DAY)
+
+#define MILLIS_TO_SEC(millis) ((millis) / MILLIS_IN_SEC)
+#define MILLIS_TO_MINS(millis) ((millis) / MILLIS_IN_MIN)
+#define MILLIS_TO_HOUR(millis) ((millis) / MILLIS_IN_HOUR)
+#define MILLIS_TO_DAY(millis) ((millis) / MILLIS_IN_DAY)
+#define MILLIS_TO_WEEK(millis) ((millis) / MILLIS_IN_WEEK)
diff --git a/src/arrow/cpp/src/gandiva/precompiled/time_fields.h b/src/arrow/cpp/src/gandiva/precompiled/time_fields.h
new file mode 100644
index 000000000..d5277e743
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/time_fields.h
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+namespace gandiva {
+
+enum TimeFields {
+ kYear,
+ kMonth,
+ kDay,
+ kHours,
+ kMinutes,
+ kSeconds,
+ kSubSeconds,
+ kDisplacementHours,
+ kDisplacementMinutes,
+ kMax
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/precompiled/time_test.cc b/src/arrow/cpp/src/gandiva/precompiled/time_test.cc
new file mode 100644
index 000000000..332ffa332
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/time_test.cc
@@ -0,0 +1,953 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include <time.h>
+
+#include "../execution_context.h"
+#include "gandiva/precompiled/testing.h"
+#include "gandiva/precompiled/types.h"
+
+namespace gandiva {
+
+TEST(TestTime, TestCastDate) {
+ ExecutionContext context;
+ int64_t context_ptr = reinterpret_cast<int64_t>(&context);
+
+ EXPECT_EQ(castDATE_utf8(context_ptr, "1967-12-1", 9), -65836800000);
+ EXPECT_EQ(castDATE_utf8(context_ptr, "2067-12-1", 9), 3089923200000);
+
+ EXPECT_EQ(castDATE_utf8(context_ptr, "7-12-1", 6), 1196467200000);
+ EXPECT_EQ(castDATE_utf8(context_ptr, "67-12-1", 7), 3089923200000);
+ EXPECT_EQ(castDATE_utf8(context_ptr, "067-12-1", 8), 3089923200000);
+ EXPECT_EQ(castDATE_utf8(context_ptr, "0067-12-1", 9), -60023980800000);
+ EXPECT_EQ(castDATE_utf8(context_ptr, "00067-12-1", 10), -60023980800000);
+ EXPECT_EQ(castDATE_utf8(context_ptr, "167-12-1", 8), -56868307200000);
+
+ EXPECT_EQ(castDATE_utf8(context_ptr, "1972-12-1", 9), 92016000000);
+ EXPECT_EQ(castDATE_utf8(context_ptr, "72-12-1", 7), 92016000000);
+
+ EXPECT_EQ(castDATE_utf8(context_ptr, "1972222222", 10), 0);
+ EXPECT_EQ(context.get_error(), "Not a valid date value 1972222222");
+ context.Reset();
+
+ EXPECT_EQ(castDATE_utf8(context_ptr, "blahblah", 8), 0);
+ EXPECT_EQ(castDATE_utf8(context_ptr, "1967-12-1bb", 11), -65836800000);
+
+ EXPECT_EQ(castDATE_utf8(context_ptr, "67-12-1", 7), 3089923200000);
+ EXPECT_EQ(castDATE_utf8(context_ptr, "67-1-1", 6), 3061065600000);
+ EXPECT_EQ(castDATE_utf8(context_ptr, "71-1-1", 6), 31536000000);
+ EXPECT_EQ(castDATE_utf8(context_ptr, "71-45-1", 7), 0);
+ EXPECT_EQ(castDATE_utf8(context_ptr, "71-12-XX", 8), 0);
+
+ EXPECT_EQ(castDATE_date32(1), 86400000);
+}
+
+TEST(TestTime, TestCastTimestamp) {
+ ExecutionContext context;
+ int64_t context_ptr = reinterpret_cast<int64_t>(&context);
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "1967-12-1", 9), -65836800000);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2067-12-1", 9), 3089923200000);
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "7-12-1", 6), 1196467200000);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "67-12-1", 7), 3089923200000);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "067-12-1", 8), 3089923200000);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "0067-12-1", 9), -60023980800000);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "00067-12-1", 10), -60023980800000);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "167-12-1", 8), -56868307200000);
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "1972-12-1", 9), 92016000000);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "72-12-1", 7), 92016000000);
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "1972-12-1", 9), 92016000000);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "67-12-1", 7), 3089923200000);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "67-1-1", 6), 3061065600000);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "71-1-1", 6), 31536000000);
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30", 18), 969702330000);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.920", 22), 969702330920);
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.920 +08:00", 29),
+ 969673530920);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.920 -11:45", 29),
+ 969744630920);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "65-03-04 00:20:40.920 +00:30", 28),
+ 3003349840920);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "1932-05-18 11:30:00.920 +11:30", 30),
+ -1187308799080);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "1857-02-11 20:31:40.920 -05:30", 30),
+ -3562264699080);
+ EXPECT_EQ(castTIMESTAMP_date64(
+ castDATE_utf8(context_ptr, "2000-09-23 9:45:30.920 +08:00", 29)),
+ castTIMESTAMP_utf8(context_ptr, "2000-09-23 0:00:00.000 +00:00", 29));
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.1", 20),
+ castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30", 18) + 100);
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.10", 20),
+ castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30", 18) + 100);
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.100", 20),
+ castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30", 18) + 100);
+
+ // error cases
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-01-01 24:00:00", 19), 0);
+ EXPECT_EQ(context.get_error(),
+ "Not a valid time for timestamp value 2000-01-01 24:00:00");
+ context.Reset();
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-01-01 00:60:00", 19), 0);
+ EXPECT_EQ(context.get_error(),
+ "Not a valid time for timestamp value 2000-01-01 00:60:00");
+ context.Reset();
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-01-01 00:00:100", 20), 0);
+ EXPECT_EQ(context.get_error(),
+ "Not a valid time for timestamp value 2000-01-01 00:00:100");
+ context.Reset();
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-01-01 00:00:00.0001", 24), 0);
+ EXPECT_EQ(context.get_error(),
+ "Invalid millis for timestamp value 2000-01-01 00:00:00.0001");
+ context.Reset();
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-01-01 00:00:00.1000", 24), 0);
+ EXPECT_EQ(context.get_error(),
+ "Invalid millis for timestamp value 2000-01-01 00:00:00.1000");
+ context.Reset();
+}
+
+#ifndef _WIN32
+
+// TODO(wesm): ARROW-4495. Need to address TZ database issues on Windows
+
+TEST(TestTime, TestCastTimestampWithTZ) {
+ ExecutionContext context;
+ int64_t context_ptr = reinterpret_cast<int64_t>(&context);
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.920 Canada/Pacific", 37),
+ 969727530920);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2012-02-28 23:30:59 Asia/Kolkata", 32),
+ 1330452059000);
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "1923-10-07 03:03:03 America/New_York", 36),
+ -1459094217000);
+}
+
+TEST(TestTime, TestCastTimestampErrors) {
+ ExecutionContext context;
+ int64_t context_ptr = reinterpret_cast<int64_t>(&context);
+
+ // error cases
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "20000923", 8), 0);
+ EXPECT_EQ(context.get_error(), "Not a valid day for timestamp value 20000923");
+ context.Reset();
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-2b", 10), 0);
+ EXPECT_EQ(context.get_error(),
+ "Invalid timestamp or unknown zone for timestamp value 2000-09-2b");
+ context.Reset();
+
+ EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.920 Unknown/Zone", 35),
+ 0);
+ EXPECT_EQ(context.get_error(),
+ "Invalid timestamp or unknown zone for timestamp value 2000-09-23 "
+ "9:45:30.920 Unknown/Zone");
+ context.Reset();
+}
+
+#endif
+
+TEST(TestTime, TestExtractTime) {
+ // 10:20:33
+ gdv_int32 time_as_millis_in_day = 37233000;
+
+ EXPECT_EQ(extractHour_time32(time_as_millis_in_day), 10);
+ EXPECT_EQ(extractMinute_time32(time_as_millis_in_day), 20);
+ EXPECT_EQ(extractSecond_time32(time_as_millis_in_day), 33);
+}
+
+TEST(TestTime, TestTimestampDiffMonth) {
+ gdv_timestamp ts1 = StringToTimestamp("2019-06-30 00:00:00");
+ gdv_timestamp ts2 = StringToTimestamp("2019-05-31 00:00:00");
+ EXPECT_EQ(timestampdiffMonth_timestamp_timestamp(ts1, ts2), -1);
+
+ ts1 = StringToTimestamp("2019-06-30 00:00:00");
+ ts2 = StringToTimestamp("2019-02-28 00:00:00");
+ EXPECT_EQ(timestampdiffMonth_timestamp_timestamp(ts1, ts2), -4);
+
+ ts1 = StringToTimestamp("2019-06-30 00:00:00");
+ ts2 = StringToTimestamp("2019-03-31 00:00:00");
+ EXPECT_EQ(timestampdiffMonth_timestamp_timestamp(ts1, ts2), -3);
+
+ ts1 = StringToTimestamp("2019-06-30 00:00:00");
+ ts2 = StringToTimestamp("2019-06-30 00:00:00");
+ EXPECT_EQ(timestampdiffMonth_timestamp_timestamp(ts1, ts2), 0);
+
+ ts1 = StringToTimestamp("2019-06-30 00:00:00");
+ ts2 = StringToTimestamp("2019-07-31 00:00:00");
+ EXPECT_EQ(timestampdiffMonth_timestamp_timestamp(ts1, ts2), 1);
+
+ ts1 = StringToTimestamp("2019-06-30 00:00:00");
+ ts2 = StringToTimestamp("2019-07-30 00:00:00");
+ EXPECT_EQ(timestampdiffMonth_timestamp_timestamp(ts1, ts2), 1);
+
+ ts1 = StringToTimestamp("2019-06-30 00:00:00");
+ ts2 = StringToTimestamp("2019-07-29 00:00:00");
+ EXPECT_EQ(timestampdiffMonth_timestamp_timestamp(ts1, ts2), 0);
+}
+
+TEST(TestTime, TestExtractTimestamp) {
+ gdv_timestamp ts = StringToTimestamp("1970-05-02 10:20:33");
+
+ EXPECT_EQ(extractMillennium_timestamp(ts), 2);
+ EXPECT_EQ(extractCentury_timestamp(ts), 20);
+ EXPECT_EQ(extractDecade_timestamp(ts), 197);
+ EXPECT_EQ(extractYear_timestamp(ts), 1970);
+ EXPECT_EQ(extractDoy_timestamp(ts), 122);
+ EXPECT_EQ(extractMonth_timestamp(ts), 5);
+ EXPECT_EQ(extractDow_timestamp(ts), 7);
+ EXPECT_EQ(extractDay_timestamp(ts), 2);
+ EXPECT_EQ(extractHour_timestamp(ts), 10);
+ EXPECT_EQ(extractMinute_timestamp(ts), 20);
+ EXPECT_EQ(extractSecond_timestamp(ts), 33);
+}
+
+TEST(TestTime, TimeStampTrunc) {
+ EXPECT_EQ(date_trunc_Second_date64(StringToTimestamp("2015-05-05 10:20:34")),
+ StringToTimestamp("2015-05-05 10:20:34"));
+ EXPECT_EQ(date_trunc_Minute_date64(StringToTimestamp("2015-05-05 10:20:34")),
+ StringToTimestamp("2015-05-05 10:20:00"));
+ EXPECT_EQ(date_trunc_Hour_date64(StringToTimestamp("2015-05-05 10:20:34")),
+ StringToTimestamp("2015-05-05 10:00:00"));
+ EXPECT_EQ(date_trunc_Day_date64(StringToTimestamp("2015-05-05 10:20:34")),
+ StringToTimestamp("2015-05-05 00:00:00"));
+ EXPECT_EQ(date_trunc_Month_date64(StringToTimestamp("2015-05-05 10:20:34")),
+ StringToTimestamp("2015-05-01 00:00:00"));
+ EXPECT_EQ(date_trunc_Quarter_date64(StringToTimestamp("2015-05-05 10:20:34")),
+ StringToTimestamp("2015-04-01 00:00:00"));
+ EXPECT_EQ(date_trunc_Year_date64(StringToTimestamp("2015-05-05 10:20:34")),
+ StringToTimestamp("2015-01-01 00:00:00"));
+ EXPECT_EQ(date_trunc_Decade_date64(StringToTimestamp("2015-05-05 10:20:34")),
+ StringToTimestamp("2010-01-01 00:00:00"));
+ EXPECT_EQ(date_trunc_Century_date64(StringToTimestamp("2115-05-05 10:20:34")),
+ StringToTimestamp("2101-01-01 00:00:00"));
+ EXPECT_EQ(date_trunc_Millennium_date64(StringToTimestamp("2115-05-05 10:20:34")),
+ StringToTimestamp("2001-01-01 00:00:00"));
+
+ // truncate week going to previous year
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-01 10:10:10")),
+ StringToTimestamp("2010-12-27 00:00:00"));
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-02 10:10:10")),
+ StringToTimestamp("2010-12-27 00:00:00"));
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-03 10:10:10")),
+ StringToTimestamp("2011-01-03 00:00:00"));
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-04 10:10:10")),
+ StringToTimestamp("2011-01-03 00:00:00"));
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-05 10:10:10")),
+ StringToTimestamp("2011-01-03 00:00:00"));
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-06 10:10:10")),
+ StringToTimestamp("2011-01-03 00:00:00"));
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-07 10:10:10")),
+ StringToTimestamp("2011-01-03 00:00:00"));
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-08 10:10:10")),
+ StringToTimestamp("2011-01-03 00:00:00"));
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-09 10:10:10")),
+ StringToTimestamp("2011-01-03 00:00:00"));
+
+ // truncate week for Feb in a leap year
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-02-28 10:10:10")),
+ StringToTimestamp("2000-02-28 00:00:00"));
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-02-29 10:10:10")),
+ StringToTimestamp("2000-02-28 00:00:00"));
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-01 10:10:10")),
+ StringToTimestamp("2000-02-28 00:00:00"));
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-02 10:10:10")),
+ StringToTimestamp("2000-02-28 00:00:00"));
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-03 10:10:10")),
+ StringToTimestamp("2000-02-28 00:00:00"));
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-04 10:10:10")),
+ StringToTimestamp("2000-02-28 00:00:00"));
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-05 10:10:10")),
+ StringToTimestamp("2000-02-28 00:00:00"));
+ EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-06 10:10:10")),
+ StringToTimestamp("2000-03-06 00:00:00"));
+}
+
+TEST(TestTime, TimeStampAdd) {
+ EXPECT_EQ(
+ timestampaddSecond_int32_timestamp(30, StringToTimestamp("2000-05-01 10:20:34")),
+ StringToTimestamp("2000-05-01 10:21:04"));
+
+ EXPECT_EQ(
+ timestampaddSecond_timestamp_int32(StringToTimestamp("2000-05-01 10:20:34"), 30),
+ StringToTimestamp("2000-05-01 10:21:04"));
+
+ EXPECT_EQ(
+ timestampaddMinute_int64_timestamp(-30, StringToTimestamp("2000-05-01 10:20:34")),
+ StringToTimestamp("2000-05-01 09:50:34"));
+
+ EXPECT_EQ(
+ timestampaddMinute_timestamp_int64(StringToTimestamp("2000-05-01 10:20:34"), -30),
+ StringToTimestamp("2000-05-01 09:50:34"));
+
+ EXPECT_EQ(
+ timestampaddHour_int32_timestamp(20, StringToTimestamp("2000-05-01 10:20:34")),
+ StringToTimestamp("2000-05-02 06:20:34"));
+
+ EXPECT_EQ(
+ timestampaddHour_timestamp_int32(StringToTimestamp("2000-05-01 10:20:34"), 20),
+ StringToTimestamp("2000-05-02 06:20:34"));
+
+ EXPECT_EQ(
+ timestampaddDay_int64_timestamp(-35, StringToTimestamp("2000-05-01 10:20:34")),
+ StringToTimestamp("2000-03-27 10:20:34"));
+
+ EXPECT_EQ(
+ timestampaddDay_timestamp_int64(StringToTimestamp("2000-05-01 10:20:34"), -35),
+ StringToTimestamp("2000-03-27 10:20:34"));
+
+ EXPECT_EQ(timestampaddWeek_int32_timestamp(4, StringToTimestamp("2000-05-01 10:20:34")),
+ StringToTimestamp("2000-05-29 10:20:34"));
+
+ EXPECT_EQ(timestampaddWeek_timestamp_int32(StringToTimestamp("2000-05-01 10:20:34"), 4),
+ StringToTimestamp("2000-05-29 10:20:34"));
+
+ EXPECT_EQ(timestampaddWeek_timestamp_int32(StringToTimestamp("2000-05-01 10:20:34"), 4),
+ StringToTimestamp("2000-05-29 10:20:34"));
+
+ EXPECT_EQ(
+ timestampaddMonth_int64_timestamp(10, StringToTimestamp("2000-05-01 10:20:34")),
+ StringToTimestamp("2001-03-01 10:20:34"));
+
+ EXPECT_EQ(
+ timestampaddMonth_int64_timestamp(1, StringToTimestamp("2000-01-31 10:20:34")),
+ StringToTimestamp("2000-2-29 10:20:34"));
+ EXPECT_EQ(
+ timestampaddMonth_int64_timestamp(13, StringToTimestamp("2001-01-31 10:20:34")),
+ StringToTimestamp("2002-02-28 10:20:34"));
+
+ EXPECT_EQ(
+ timestampaddMonth_int64_timestamp(11, StringToTimestamp("2000-05-31 10:20:34")),
+ StringToTimestamp("2001-04-30 10:20:34"));
+
+ EXPECT_EQ(
+ timestampaddMonth_timestamp_int64(StringToTimestamp("2000-05-31 10:20:34"), 11),
+ StringToTimestamp("2001-04-30 10:20:34"));
+
+ EXPECT_EQ(
+ timestampaddQuarter_int32_timestamp(-2, StringToTimestamp("2000-05-01 10:20:34")),
+ StringToTimestamp("1999-11-01 10:20:34"));
+
+ EXPECT_EQ(timestampaddYear_int64_timestamp(2, StringToTimestamp("2000-05-01 10:20:34")),
+ StringToTimestamp("2002-05-01 10:20:34"));
+
+ EXPECT_EQ(
+ timestampaddQuarter_int32_timestamp(-5, StringToTimestamp("2000-05-01 10:20:34")),
+ StringToTimestamp("1999-02-01 10:20:34"));
+ EXPECT_EQ(
+ timestampaddQuarter_int32_timestamp(-6, StringToTimestamp("2000-05-01 10:20:34")),
+ StringToTimestamp("1998-11-01 10:20:34"));
+
+ // date_add
+ EXPECT_EQ(date_add_int32_timestamp(7, StringToTimestamp("2000-05-01 00:00:00")),
+ StringToTimestamp("2000-05-08 00:00:00"));
+
+ EXPECT_EQ(add_int32_timestamp(4, StringToTimestamp("2000-05-01 00:00:00")),
+ StringToTimestamp("2000-05-05 00:00:00"));
+
+ EXPECT_EQ(add_int64_timestamp(7, StringToTimestamp("2000-05-01 00:00:00")),
+ StringToTimestamp("2000-05-08 00:00:00"));
+
+ EXPECT_EQ(date_add_int64_timestamp(4, StringToTimestamp("2000-05-01 00:00:00")),
+ StringToTimestamp("2000-05-05 00:00:00"));
+
+ EXPECT_EQ(date_add_int64_timestamp(4, StringToTimestamp("2000-02-27 00:00:00")),
+ StringToTimestamp("2000-03-02 00:00:00"));
+
+ EXPECT_EQ(add_date64_int64(StringToTimestamp("2000-02-27 00:00:00"), 4),
+ StringToTimestamp("2000-03-02 00:00:00"));
+
+ // date_sub
+ EXPECT_EQ(date_sub_timestamp_int32(StringToTimestamp("2000-05-01 00:00:00"), 7),
+ StringToTimestamp("2000-04-24 00:00:00"));
+
+ EXPECT_EQ(subtract_timestamp_int32(StringToTimestamp("2000-05-01 00:00:00"), -7),
+ StringToTimestamp("2000-05-08 00:00:00"));
+
+ EXPECT_EQ(date_diff_timestamp_int64(StringToTimestamp("2000-05-01 00:00:00"), 365),
+ StringToTimestamp("1999-05-02 00:00:00"));
+
+ EXPECT_EQ(date_diff_timestamp_int64(StringToTimestamp("2000-03-01 00:00:00"), 1),
+ StringToTimestamp("2000-02-29 00:00:00"));
+
+ EXPECT_EQ(date_diff_timestamp_int64(StringToTimestamp("2000-02-29 00:00:00"), 365),
+ StringToTimestamp("1999-03-01 00:00:00"));
+}
+
+// test cases from http://www.staff.science.uu.nl/~gent0113/calendar/isocalendar.htm
+TEST(TestTime, TestExtractWeek) {
+ std::vector<std::string> data;
+
+ // A type
+ // Jan 1, 2 and 3
+ data.push_back("2006-01-01 10:10:10");
+ data.push_back("52");
+ data.push_back("2006-01-02 10:10:10");
+ data.push_back("1");
+ data.push_back("2006-01-03 10:10:10");
+ data.push_back("1");
+ // middle, Monday and Sunday
+ data.push_back("2006-04-24 10:10:10");
+ data.push_back("17");
+ data.push_back("2006-04-30 10:10:10");
+ data.push_back("17");
+ // Dec 29-31
+ data.push_back("2006-12-29 10:10:10");
+ data.push_back("52");
+ data.push_back("2006-12-30 10:10:10");
+ data.push_back("52");
+ data.push_back("2006-12-31 10:10:10");
+ data.push_back("52");
+ // B(C) type
+ // Jan 1, 2 and 3
+ data.push_back("2011-01-01 10:10:10");
+ data.push_back("52");
+ data.push_back("2011-01-02 10:10:10");
+ data.push_back("52");
+ data.push_back("2011-01-03 10:10:10");
+ data.push_back("1");
+ // middle, Monday and Sunday
+ data.push_back("2011-07-18 10:10:10");
+ data.push_back("29");
+ data.push_back("2011-07-24 10:10:10");
+ data.push_back("29");
+ // Dec 29-31
+ data.push_back("2011-12-29 10:10:10");
+ data.push_back("52");
+ data.push_back("2011-12-30 10:10:10");
+ data.push_back("52");
+ data.push_back("2011-12-31 10:10:10");
+ data.push_back("52");
+ // B(DC) type
+ // Jan 1, 2 and 3
+ data.push_back("2005-01-01 10:10:10");
+ data.push_back("53");
+ data.push_back("2005-01-02 10:10:10");
+ data.push_back("53");
+ data.push_back("2005-01-03 10:10:10");
+ data.push_back("1");
+ // middle, Monday and Sunday
+ data.push_back("2005-11-07 10:10:10");
+ data.push_back("45");
+ data.push_back("2005-11-13 10:10:10");
+ data.push_back("45");
+ // Dec 29-31
+ data.push_back("2005-12-29 10:10:10");
+ data.push_back("52");
+ data.push_back("2005-12-30 10:10:10");
+ data.push_back("52");
+ data.push_back("2005-12-31 10:10:10");
+ data.push_back("52");
+ // C type
+ // Jan 1, 2 and 3
+ data.push_back("2010-01-01 10:10:10");
+ data.push_back("53");
+ data.push_back("2010-01-02 10:10:10");
+ data.push_back("53");
+ data.push_back("2010-01-03 10:10:10");
+ data.push_back("53");
+ // middle, Monday and Sunday
+ data.push_back("2010-09-13 10:10:10");
+ data.push_back("37");
+ data.push_back("2010-09-19 10:10:10");
+ data.push_back("37");
+ // Dec 29-31
+ data.push_back("2010-12-29 10:10:10");
+ data.push_back("52");
+ data.push_back("2010-12-30 10:10:10");
+ data.push_back("52");
+ data.push_back("2010-12-31 10:10:10");
+ data.push_back("52");
+ // D type
+ // Jan 1, 2 and 3
+ data.push_back("2037-01-01 10:10:10");
+ data.push_back("1");
+ data.push_back("2037-01-02 10:10:10");
+ data.push_back("1");
+ data.push_back("2037-01-03 10:10:10");
+ data.push_back("1");
+ // middle, Monday and Sunday
+ data.push_back("2037-08-17 10:10:10");
+ data.push_back("34");
+ data.push_back("2037-08-23 10:10:10");
+ data.push_back("34");
+ // Dec 29-31
+ data.push_back("2037-12-29 10:10:10");
+ data.push_back("53");
+ data.push_back("2037-12-30 10:10:10");
+ data.push_back("53");
+ data.push_back("2037-12-31 10:10:10");
+ data.push_back("53");
+ // E type
+ // Jan 1, 2 and 3
+ data.push_back("2014-01-01 10:10:10");
+ data.push_back("1");
+ data.push_back("2014-01-02 10:10:10");
+ data.push_back("1");
+ data.push_back("2014-01-03 10:10:10");
+ data.push_back("1");
+ // middle, Monday and Sunday
+ data.push_back("2014-01-13 10:10:10");
+ data.push_back("3");
+ data.push_back("2014-01-19 10:10:10");
+ data.push_back("3");
+ // Dec 29-31
+ data.push_back("2014-12-29 10:10:10");
+ data.push_back("1");
+ data.push_back("2014-12-30 10:10:10");
+ data.push_back("1");
+ data.push_back("2014-12-31 10:10:10");
+ data.push_back("1");
+ // F type
+ // Jan 1, 2 and 3
+ data.push_back("2019-01-01 10:10:10");
+ data.push_back("1");
+ data.push_back("2019-01-02 10:10:10");
+ data.push_back("1");
+ data.push_back("2019-01-03 10:10:10");
+ data.push_back("1");
+ // middle, Monday and Sunday
+ data.push_back("2019-02-11 10:10:10");
+ data.push_back("7");
+ data.push_back("2019-02-17 10:10:10");
+ data.push_back("7");
+ // Dec 29-31
+ data.push_back("2019-12-29 10:10:10");
+ data.push_back("52");
+ data.push_back("2019-12-30 10:10:10");
+ data.push_back("1");
+ data.push_back("2019-12-31 10:10:10");
+ data.push_back("1");
+ // G type
+ // Jan 1, 2 and 3
+ data.push_back("2001-01-01 10:10:10");
+ data.push_back("1");
+ data.push_back("2001-01-02 10:10:10");
+ data.push_back("1");
+ data.push_back("2001-01-03 10:10:10");
+ data.push_back("1");
+ // middle, Monday and Sunday
+ data.push_back("2001-03-19 10:10:10");
+ data.push_back("12");
+ data.push_back("2001-03-25 10:10:10");
+ data.push_back("12");
+ // Dec 29-31
+ data.push_back("2001-12-29 10:10:10");
+ data.push_back("52");
+ data.push_back("2001-12-30 10:10:10");
+ data.push_back("52");
+ data.push_back("2001-12-31 10:10:10");
+ data.push_back("1");
+ // AG type
+ // Jan 1, 2 and 3
+ data.push_back("2012-01-01 10:10:10");
+ data.push_back("52");
+ data.push_back("2012-01-02 10:10:10");
+ data.push_back("1");
+ data.push_back("2012-01-03 10:10:10");
+ data.push_back("1");
+ // middle, Monday and Sunday
+ data.push_back("2012-04-02 10:10:10");
+ data.push_back("14");
+ data.push_back("2012-04-08 10:10:10");
+ data.push_back("14");
+ // Dec 29-31
+ data.push_back("2012-12-29 10:10:10");
+ data.push_back("52");
+ data.push_back("2012-12-30 10:10:10");
+ data.push_back("52");
+ data.push_back("2012-12-31 10:10:10");
+ data.push_back("1");
+ // BA type
+ // Jan 1, 2 and 3
+ data.push_back("2000-01-01 10:10:10");
+ data.push_back("52");
+ data.push_back("2000-01-02 10:10:10");
+ data.push_back("52");
+ data.push_back("2000-01-03 10:10:10");
+ data.push_back("1");
+ // middle, Monday and Sunday
+ data.push_back("2000-05-22 10:10:10");
+ data.push_back("21");
+ data.push_back("2000-05-28 10:10:10");
+ data.push_back("21");
+ // Dec 29-31
+ data.push_back("2000-12-29 10:10:10");
+ data.push_back("52");
+ data.push_back("2000-12-30 10:10:10");
+ data.push_back("52");
+ data.push_back("2000-12-31 10:10:10");
+ data.push_back("52");
+ // CB type
+ // Jan 1, 2 and 3
+ data.push_back("2016-01-01 10:10:10");
+ data.push_back("53");
+ data.push_back("2016-01-02 10:10:10");
+ data.push_back("53");
+ data.push_back("2016-01-03 10:10:10");
+ data.push_back("53");
+ // middle, Monday and Sunday
+ data.push_back("2016-06-20 10:10:10");
+ data.push_back("25");
+ data.push_back("2016-06-26 10:10:10");
+ data.push_back("25");
+ // Dec 29-31
+ data.push_back("2016-12-29 10:10:10");
+ data.push_back("52");
+ data.push_back("2016-12-30 10:10:10");
+ data.push_back("52");
+ data.push_back("2016-12-31 10:10:10");
+ data.push_back("52");
+ // DC type
+ // Jan 1, 2 and 3
+ data.push_back("2004-01-01 10:10:10");
+ data.push_back("1");
+ data.push_back("2004-01-02 10:10:10");
+ data.push_back("1");
+ data.push_back("2004-01-03 10:10:10");
+ data.push_back("1");
+ // middle, Monday and Sunday
+ data.push_back("2004-07-19 10:10:10");
+ data.push_back("30");
+ data.push_back("2004-07-25 10:10:10");
+ data.push_back("30");
+ // Dec 29-31
+ data.push_back("2004-12-29 10:10:10");
+ data.push_back("53");
+ data.push_back("2004-12-30 10:10:10");
+ data.push_back("53");
+ data.push_back("2004-12-31 10:10:10");
+ data.push_back("53");
+ // ED type
+ // Jan 1, 2 and 3
+ data.push_back("2020-01-01 10:10:10");
+ data.push_back("1");
+ data.push_back("2020-01-02 10:10:10");
+ data.push_back("1");
+ data.push_back("2020-01-03 10:10:10");
+ data.push_back("1");
+ // middle, Monday and Sunday
+ data.push_back("2020-08-17 10:10:10");
+ data.push_back("34");
+ data.push_back("2020-08-23 10:10:10");
+ data.push_back("34");
+ // Dec 29-31
+ data.push_back("2020-12-29 10:10:10");
+ data.push_back("53");
+ data.push_back("2020-12-30 10:10:10");
+ data.push_back("53");
+ data.push_back("2020-12-31 10:10:10");
+ data.push_back("53");
+ // FE type
+ // Jan 1, 2 and 3
+ data.push_back("2008-01-01 10:10:10");
+ data.push_back("1");
+ data.push_back("2008-01-02 10:10:10");
+ data.push_back("1");
+ data.push_back("2008-01-03 10:10:10");
+ data.push_back("1");
+ // middle, Monday and Sunday
+ data.push_back("2008-09-15 10:10:10");
+ data.push_back("38");
+ data.push_back("2008-09-21 10:10:10");
+ data.push_back("38");
+ // Dec 29-31
+ data.push_back("2008-12-29 10:10:10");
+ data.push_back("1");
+ data.push_back("2008-12-30 10:10:10");
+ data.push_back("1");
+ data.push_back("2008-12-31 10:10:10");
+ data.push_back("1");
+ // GF type
+ // Jan 1, 2 and 3
+ data.push_back("2024-01-01 10:10:10");
+ data.push_back("1");
+ data.push_back("2024-01-02 10:10:10");
+ data.push_back("1");
+ data.push_back("2024-01-03 10:10:10");
+ data.push_back("1");
+ // middle, Monday and Sunday
+ data.push_back("2024-10-07 10:10:10");
+ data.push_back("41");
+ data.push_back("2024-10-13 10:10:10");
+ data.push_back("41");
+ // Dec 29-31
+ data.push_back("2024-12-29 10:10:10");
+ data.push_back("52");
+ data.push_back("2024-12-30 10:10:10");
+ data.push_back("1");
+ data.push_back("2024-12-31 10:10:10");
+ data.push_back("1");
+
+ for (uint32_t i = 0; i < data.size(); i += 2) {
+ gdv_timestamp ts = StringToTimestamp(data.at(i).c_str());
+ gdv_int64 exp = atol(data.at(i + 1).c_str());
+ EXPECT_EQ(extractWeek_timestamp(ts), exp);
+ }
+}
+
+TEST(TestTime, TestMonthsBetween) {
+ std::vector<std::string> testStrings = {
+ "1995-03-02 00:00:00", "1995-02-02 00:00:00", "1.0",
+ "1995-02-02 00:00:00", "1995-03-02 00:00:00", "-1.0",
+ "1995-03-31 00:00:00", "1995-02-28 00:00:00", "1.0",
+ "1996-03-31 00:00:00", "1996-02-28 00:00:00", "1.09677418",
+ "1996-03-31 00:00:00", "1996-02-29 00:00:00", "1.0",
+ "1996-05-31 00:00:00", "1996-04-30 00:00:00", "1.0",
+ "1996-05-31 00:00:00", "1996-03-31 00:00:00", "2.0",
+ "1996-05-31 00:00:00", "1996-03-30 00:00:00", "2.03225806",
+ "1996-03-15 00:00:00", "1996-02-14 00:00:00", "1.03225806",
+ "1995-02-02 00:00:00", "1995-01-01 00:00:00", "1.03225806",
+ "1995-02-02 10:00:00", "1995-01-01 11:00:00", "1.03091397"};
+
+ for (uint32_t i = 0; i < testStrings.size();) {
+ gdv_timestamp endTs = StringToTimestamp(testStrings[i++].c_str());
+ gdv_timestamp startTs = StringToTimestamp(testStrings[i++].c_str());
+
+ double expectedResult = atof(testStrings[i++].c_str());
+ double actualResult = months_between_timestamp_timestamp(endTs, startTs);
+
+ double diff = actualResult - expectedResult;
+ if (diff < 0) {
+ diff = expectedResult - actualResult;
+ }
+
+ EXPECT_TRUE(diff < 0.001);
+ }
+}
+
+TEST(TestTime, castVarcharTimestamp) {
+ ExecutionContext context;
+ int64_t context_ptr = reinterpret_cast<int64_t>(&context);
+ gdv_int32 out_len;
+ gdv_timestamp ts = StringToTimestamp("2000-05-01 10:20:34");
+ const char* out = castVARCHAR_timestamp_int64(context_ptr, ts, 30L, &out_len);
+ EXPECT_EQ(std::string(out, out_len), "2000-05-01 10:20:34.000");
+
+ out = castVARCHAR_timestamp_int64(context_ptr, ts, 19L, &out_len);
+ EXPECT_EQ(std::string(out, out_len), "2000-05-01 10:20:34");
+
+ out = castVARCHAR_timestamp_int64(context_ptr, ts, 0L, &out_len);
+ EXPECT_EQ(std::string(out, out_len), "");
+
+ ts = StringToTimestamp("2-5-1 00:00:04");
+ out = castVARCHAR_timestamp_int64(context_ptr, ts, 24L, &out_len);
+ EXPECT_EQ(std::string(out, out_len), "0002-05-01 00:00:04.000");
+}
+
+TEST(TestTime, TestCastTimestampToDate) {
+ gdv_timestamp ts = StringToTimestamp("2000-05-01 10:20:34");
+ auto out = castDATE_timestamp(ts);
+ EXPECT_EQ(StringToTimestamp("2000-05-01 00:00:00"), out);
+}
+
+TEST(TestTime, TestCastTimestampToTime) {
+ gdv_timestamp ts = StringToTimestamp("2000-05-01 10:20:34");
+ auto expected_response =
+ static_cast<int32_t>(ts - StringToTimestamp("2000-05-01 00:00:00"));
+ auto out = castTIME_timestamp(ts);
+ EXPECT_EQ(expected_response, out);
+
+ // Test when the defined value is midnight, so the returned value must 0
+ ts = StringToTimestamp("1998-12-01 00:00:00");
+ expected_response = 0;
+ out = castTIME_timestamp(ts);
+ EXPECT_EQ(expected_response, out);
+
+ ts = StringToTimestamp("2015-09-16 23:59:59");
+ expected_response = static_cast<int32_t>(ts - StringToTimestamp("2015-09-16 00:00:00"));
+ out = castTIME_timestamp(ts);
+ EXPECT_EQ(expected_response, out);
+}
+
+TEST(TestTime, TestLastDay) {
+ // leap year test
+ gdv_timestamp ts = StringToTimestamp("2016-02-11 03:20:34");
+ auto out = last_day_from_timestamp(ts);
+ EXPECT_EQ(StringToTimestamp("2016-02-29 00:00:00"), out);
+
+ ts = StringToTimestamp("2016-02-29 23:59:59");
+ out = last_day_from_timestamp(ts);
+ EXPECT_EQ(StringToTimestamp("2016-02-29 00:00:00"), out);
+
+ ts = StringToTimestamp("2016-01-30 23:59:00");
+ out = last_day_from_timestamp(ts);
+ EXPECT_EQ(StringToTimestamp("2016-01-31 00:00:00"), out);
+
+ // normal year
+ ts = StringToTimestamp("2017-02-03 23:59:59");
+ out = last_day_from_timestamp(ts);
+ EXPECT_EQ(StringToTimestamp("2017-02-28 00:00:00"), out);
+
+ // december
+ ts = StringToTimestamp("2015-12-03 03:12:59");
+ out = last_day_from_timestamp(ts);
+ EXPECT_EQ(StringToTimestamp("2015-12-31 00:00:00"), out);
+}
+
+TEST(TestTime, TestToTimestamp) {
+ auto ts = StringToTimestamp("1970-01-01 00:00:00");
+ EXPECT_EQ(ts, to_timestamp_int32(0));
+ EXPECT_EQ(ts, to_timestamp_int64(0));
+ EXPECT_EQ(ts, to_timestamp_float32(0));
+ EXPECT_EQ(ts, to_timestamp_float64(0));
+
+ ts = StringToTimestamp("1970-01-01 00:00:01");
+ EXPECT_EQ(ts, to_timestamp_int32(1));
+ EXPECT_EQ(ts, to_timestamp_int64(1));
+ EXPECT_EQ(ts, to_timestamp_float32(1));
+ EXPECT_EQ(ts, to_timestamp_float64(1));
+
+ ts = StringToTimestamp("1970-01-01 00:01:00");
+ EXPECT_EQ(ts, to_timestamp_int32(60));
+ EXPECT_EQ(ts, to_timestamp_int64(60));
+ EXPECT_EQ(ts, to_timestamp_float32(60));
+ EXPECT_EQ(ts, to_timestamp_float64(60));
+
+ ts = StringToTimestamp("1970-01-01 01:00:00");
+ EXPECT_EQ(ts, to_timestamp_int32(3600));
+ EXPECT_EQ(ts, to_timestamp_int64(3600));
+ EXPECT_EQ(ts, to_timestamp_float32(3600));
+ EXPECT_EQ(ts, to_timestamp_float64(3600));
+
+ ts = StringToTimestamp("1970-01-02 00:00:00");
+ EXPECT_EQ(ts, to_timestamp_int32(86400));
+ EXPECT_EQ(ts, to_timestamp_int64(86400));
+ EXPECT_EQ(ts, to_timestamp_float32(86400));
+ EXPECT_EQ(ts, to_timestamp_float64(86400));
+
+ // tests with fractional part
+ ts = StringToTimestamp("1970-01-01 00:00:01") + 500;
+ EXPECT_EQ(ts, to_timestamp_float32(1.500f));
+ EXPECT_EQ(ts, to_timestamp_float64(1.500));
+
+ ts = StringToTimestamp("1970-01-01 00:01:01") + 600;
+ EXPECT_EQ(ts, to_timestamp_float32(61.600f));
+ EXPECT_EQ(ts, to_timestamp_float64(61.600));
+
+ ts = StringToTimestamp("1970-01-01 01:00:01") + 400;
+ EXPECT_EQ(ts, to_timestamp_float32(3601.400f));
+ EXPECT_EQ(ts, to_timestamp_float64(3601.400));
+}
+
+TEST(TestTime, TestToTimeNumeric) {
+ // input timestamp in seconds: 1970-01-01 00:00:00
+ int64_t expected_output = 0; // 0 milliseconds
+ EXPECT_EQ(expected_output, to_time_int32(0));
+ EXPECT_EQ(expected_output, to_time_int64(0));
+ EXPECT_EQ(expected_output, to_time_float32(0.000f));
+ EXPECT_EQ(expected_output, to_time_float64(0.000));
+
+ // input timestamp in seconds: 1970-01-01 00:00:01
+ expected_output = 1000; // 1 seconds
+ EXPECT_EQ(expected_output, to_time_int32(1));
+ EXPECT_EQ(expected_output, to_time_int64(1));
+ EXPECT_EQ(expected_output, to_time_float32(1.000f));
+ EXPECT_EQ(expected_output, to_time_float64(1.000));
+
+ // input timestamp in seconds: 1970-01-01 01:00:00
+ expected_output = 3600000; // 3600 seconds
+ EXPECT_EQ(expected_output, to_time_int32(3600));
+ EXPECT_EQ(expected_output, to_time_int64(3600));
+ EXPECT_EQ(expected_output, to_time_float32(3600.000f));
+ EXPECT_EQ(expected_output, to_time_float64(3600.000));
+
+ // input timestamp in seconds: 1970-01-01 23:59:59
+ expected_output = 86399000; // 86399 seconds
+ EXPECT_EQ(expected_output, to_time_int32(86399));
+ EXPECT_EQ(expected_output, to_time_int64(86399));
+ EXPECT_EQ(expected_output, to_time_float32(86399.000f));
+ EXPECT_EQ(expected_output, to_time_float64(86399.000));
+
+ // input timestamp in seconds: 2020-01-01 00:00:01
+ expected_output = 1000; // 1 second
+ EXPECT_EQ(expected_output, to_time_int64(1577836801));
+ EXPECT_EQ(expected_output, to_time_float64(1577836801.000));
+
+ // tests with fractional part
+ // input timestamp in seconds: 1970-01-01 00:00:01.500
+ expected_output = 1500; // 1.5 seconds
+ EXPECT_EQ(expected_output, to_time_float32(1.500f));
+ EXPECT_EQ(expected_output, to_time_float64(1.500));
+
+ // input timestamp in seconds: 1970-01-01 00:01:01.500
+ expected_output = 61500; // 61.5 seconds
+ EXPECT_EQ(expected_output, to_time_float32(61.500f));
+ EXPECT_EQ(expected_output, to_time_float64(61.500));
+
+ // input timestamp in seconds: 1970-01-01 01:00:01.500
+ expected_output = 3601500; // 3601.5 seconds
+ EXPECT_EQ(expected_output, to_time_float32(3601.500f));
+ EXPECT_EQ(expected_output, to_time_float64(3601.500));
+}
+
+TEST(TestTime, TestCastIntDayInterval) {
+ EXPECT_EQ(castBIGINT_daytimeinterval(10), 864000000);
+ EXPECT_EQ(castBIGINT_daytimeinterval(-100), -8640000001);
+ EXPECT_EQ(castBIGINT_daytimeinterval(-0), 0);
+}
+
+TEST(TestTime, TestCastIntYearInterval) {
+ EXPECT_EQ(castINT_year_interval(24), 2);
+ EXPECT_EQ(castINT_year_interval(-24), -2);
+ EXPECT_EQ(castINT_year_interval(-23), -1);
+
+ EXPECT_EQ(castBIGINT_year_interval(24), 2);
+ EXPECT_EQ(castBIGINT_year_interval(-24), -2);
+ EXPECT_EQ(castBIGINT_year_interval(-23), -1);
+}
+
+TEST(TestTime, TestCastNullableInterval) {
+ ExecutionContext context;
+ auto context_ptr = reinterpret_cast<int64_t>(&context);
+ // Test castNULLABLEINTERVALDAY for int and bigint
+ EXPECT_EQ(castNULLABLEINTERVALDAY_int32(1), 1);
+ EXPECT_EQ(castNULLABLEINTERVALDAY_int32(12), 12);
+ EXPECT_EQ(castNULLABLEINTERVALDAY_int32(-55), -55);
+ EXPECT_EQ(castNULLABLEINTERVALDAY_int32(-1201), -1201);
+ EXPECT_EQ(castNULLABLEINTERVALDAY_int64(1), 1);
+ EXPECT_EQ(castNULLABLEINTERVALDAY_int64(12), 12);
+ EXPECT_EQ(castNULLABLEINTERVALDAY_int64(-55), -55);
+ EXPECT_EQ(castNULLABLEINTERVALDAY_int64(-1201), -1201);
+
+ // Test castNULLABLEINTERVALYEAR for int and bigint
+ EXPECT_EQ(castNULLABLEINTERVALYEAR_int32(context_ptr, 1), 1);
+ EXPECT_EQ(castNULLABLEINTERVALYEAR_int32(context_ptr, 12), 12);
+ EXPECT_EQ(castNULLABLEINTERVALYEAR_int32(context_ptr, 55), 55);
+ EXPECT_EQ(castNULLABLEINTERVALYEAR_int32(context_ptr, 1201), 1201);
+ EXPECT_EQ(castNULLABLEINTERVALYEAR_int64(context_ptr, 1), 1);
+ EXPECT_EQ(castNULLABLEINTERVALYEAR_int64(context_ptr, 12), 12);
+ EXPECT_EQ(castNULLABLEINTERVALYEAR_int64(context_ptr, 55), 55);
+ EXPECT_EQ(castNULLABLEINTERVALYEAR_int64(context_ptr, 1201), 1201);
+ // validate overflow error when using bigint as input
+ castNULLABLEINTERVALYEAR_int64(context_ptr, INT64_MAX);
+ EXPECT_EQ(context.get_error(), "Integer overflow");
+ context.Reset();
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/precompiled/timestamp_arithmetic.cc b/src/arrow/cpp/src/gandiva/precompiled/timestamp_arithmetic.cc
new file mode 100644
index 000000000..695605b3c
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/timestamp_arithmetic.cc
@@ -0,0 +1,283 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "./epoch_time_point.h"
+
+// The first row is for non-leap years
+static int days_in_a_month[2][12] = {{31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31},
+ {31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}};
+
+bool is_leap_year(int yy) {
+ if ((yy % 4) != 0) {
+ // not divisible by 4
+ return false;
+ }
+ // yy = 4x
+ if ((yy % 400) == 0) {
+ // yy = 400x
+ return true;
+ }
+ // yy = 4x, return true if yy != 100x
+ return ((yy % 100) != 0);
+}
+
+bool is_last_day_of_month(const EpochTimePoint& tp) {
+ int matrix_index = is_leap_year(tp.TmYear()) ? 1 : 0;
+
+ return (tp.TmMday() == days_in_a_month[matrix_index][tp.TmMon()]);
+}
+
+bool did_days_overflow(arrow_vendored::date::year_month_day ymd) {
+ int year = static_cast<int>(ymd.year());
+ int month = static_cast<unsigned int>(ymd.month());
+ int days = static_cast<unsigned int>(ymd.day());
+
+ int matrix_index = is_leap_year(year) ? 1 : 0;
+
+ return days > days_in_a_month[matrix_index][month - 1];
+}
+
+int last_possible_day_in_month(int year, int month) {
+ int matrix_index = is_leap_year(year) ? 1 : 0;
+
+ return days_in_a_month[matrix_index][month - 1];
+}
+
+extern "C" {
+
+#include <time.h>
+
+#include "./time_constants.h"
+#include "./types.h"
+
+#define TIMESTAMP_DIFF_FIXED_UNITS(TYPE, NAME, FROM_MILLIS) \
+ FORCE_INLINE \
+ gdv_int32 NAME##_##TYPE##_##TYPE(gdv_##TYPE start_millis, gdv_##TYPE end_millis) { \
+ return static_cast<int32_t>(FROM_MILLIS(end_millis - start_millis)); \
+ }
+
+#define SIGN_ADJUST_DIFF(is_positive, diff) ((is_positive) ? (diff) : -(diff))
+#define MONTHS_TO_TIMEUNIT(diff, num_months) (diff) / (num_months)
+
+// Assuming end_millis > start_millis, the algorithm to find the diff in months is:
+// diff_in_months = year_diff * 12 + month_diff
+// This is approximately correct, except when the last month has not fully elapsed
+//
+// a) If end_day > start_day, return diff_in_months e.g. diff(2015-09-10, 2017-03-31)
+// b) If end_day < start_day, return diff_in_months - 1 e.g. diff(2015-09-30, 2017-03-10)
+// c) If end_day = start_day, check for millis e.g. diff(2017-03-10, 2015-03-10)
+// Need to check if end_millis_in_day > start_millis_in_day
+// c1) If end_millis_in_day >= start_millis_in_day, return diff_in_months
+// c2) else return diff_in_months - 1
+#define TIMESTAMP_DIFF_MONTH_UNITS(TYPE, NAME, N_MONTHS) \
+ FORCE_INLINE \
+ gdv_int32 NAME##_##TYPE##_##TYPE(gdv_##TYPE start_millis, gdv_##TYPE end_millis) { \
+ gdv_int32 diff; \
+ bool is_positive = (end_millis > start_millis); \
+ if (!is_positive) { \
+ /* if end_millis < start_millis, swap and multiply by -1 at the end */ \
+ gdv_##TYPE tmp = start_millis; \
+ start_millis = end_millis; \
+ end_millis = tmp; \
+ } \
+ EpochTimePoint start_tm(start_millis); \
+ EpochTimePoint end_tm(end_millis); \
+ gdv_int32 months_diff; \
+ months_diff = static_cast<gdv_int32>(12 * (end_tm.TmYear() - start_tm.TmYear()) + \
+ (end_tm.TmMon() - start_tm.TmMon())); \
+ if (end_tm.TmMday() > start_tm.TmMday()) { \
+ /* case a */ \
+ diff = MONTHS_TO_TIMEUNIT(months_diff, N_MONTHS); \
+ return SIGN_ADJUST_DIFF(is_positive, diff); \
+ } \
+ if (end_tm.TmMday() < start_tm.TmMday()) { \
+ /* case b */ \
+ months_diff += (is_last_day_of_month(end_tm) ? 1 : 0); \
+ diff = MONTHS_TO_TIMEUNIT(months_diff - 1, N_MONTHS); \
+ return SIGN_ADJUST_DIFF(is_positive, diff); \
+ } \
+ gdv_int32 end_day_millis = \
+ static_cast<gdv_int32>(end_tm.TmHour() * MILLIS_IN_HOUR + \
+ end_tm.TmMin() * MILLIS_IN_MIN + end_tm.TmSec()); \
+ gdv_int32 start_day_millis = \
+ static_cast<gdv_int32>(start_tm.TmHour() * MILLIS_IN_HOUR + \
+ start_tm.TmMin() * MILLIS_IN_MIN + start_tm.TmSec()); \
+ if (end_day_millis >= start_day_millis) { \
+ /* case c1 */ \
+ diff = MONTHS_TO_TIMEUNIT(months_diff, N_MONTHS); \
+ return SIGN_ADJUST_DIFF(is_positive, diff); \
+ } \
+ /* case c2 */ \
+ diff = MONTHS_TO_TIMEUNIT(months_diff - 1, N_MONTHS); \
+ return SIGN_ADJUST_DIFF(is_positive, diff); \
+ }
+
+#define TIMESTAMP_DIFF(TYPE) \
+ TIMESTAMP_DIFF_FIXED_UNITS(TYPE, timestampdiffSecond, MILLIS_TO_SEC) \
+ TIMESTAMP_DIFF_FIXED_UNITS(TYPE, timestampdiffMinute, MILLIS_TO_MINS) \
+ TIMESTAMP_DIFF_FIXED_UNITS(TYPE, timestampdiffHour, MILLIS_TO_HOUR) \
+ TIMESTAMP_DIFF_FIXED_UNITS(TYPE, timestampdiffDay, MILLIS_TO_DAY) \
+ TIMESTAMP_DIFF_FIXED_UNITS(TYPE, timestampdiffWeek, MILLIS_TO_WEEK) \
+ TIMESTAMP_DIFF_MONTH_UNITS(TYPE, timestampdiffMonth, 1) \
+ TIMESTAMP_DIFF_MONTH_UNITS(TYPE, timestampdiffQuarter, 3) \
+ TIMESTAMP_DIFF_MONTH_UNITS(TYPE, timestampdiffYear, 12)
+
+TIMESTAMP_DIFF(timestamp)
+
+#define ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \
+ FORCE_INLINE \
+ gdv_##TYPE NAME##_int32_##TYPE(gdv_int32 count, gdv_##TYPE millis) { \
+ return millis + TO_MILLIS * static_cast<gdv_##TYPE>(count); \
+ }
+
+// Documentation of mktime suggests that it handles
+// TmMon() being negative, and also TmMon() being >= 12 by
+// adjusting TmYear() accordingly
+//
+// Using gmtime_r() and timegm() instead of localtime_r() and mktime()
+// since the input millis are since epoch
+#define ADD_INT32_TO_TIMESTAMP_MONTH_UNITS(TYPE, NAME, N_MONTHS) \
+ FORCE_INLINE \
+ gdv_##TYPE NAME##_int32_##TYPE(gdv_int32 count, gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ return tp.AddMonths(static_cast<int>(count * N_MONTHS)).MillisSinceEpoch(); \
+ }
+
+// TODO: Handle overflow while converting gdv_int64 to millis
+#define ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \
+ FORCE_INLINE \
+ gdv_##TYPE NAME##_int64_##TYPE(gdv_int64 count, gdv_##TYPE millis) { \
+ return millis + TO_MILLIS * static_cast<gdv_##TYPE>(count); \
+ }
+
+#define ADD_INT64_TO_TIMESTAMP_MONTH_UNITS(TYPE, NAME, N_MONTHS) \
+ FORCE_INLINE \
+ gdv_##TYPE NAME##_int64_##TYPE(gdv_int64 count, gdv_##TYPE millis) { \
+ EpochTimePoint tp(millis); \
+ return tp.AddMonths(static_cast<int>(count * N_MONTHS)).MillisSinceEpoch(); \
+ }
+
+#define ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \
+ FORCE_INLINE \
+ gdv_##TYPE NAME##_##TYPE##_int32(gdv_##TYPE millis, gdv_int32 count) { \
+ return millis + TO_MILLIS * static_cast<gdv_##TYPE>(count); \
+ }
+
+#define ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \
+ FORCE_INLINE \
+ gdv_##TYPE NAME##_##TYPE##_int64(gdv_##TYPE millis, gdv_int64 count) { \
+ return millis + TO_MILLIS * static_cast<gdv_##TYPE>(count); \
+ }
+
+#define ADD_TIMESTAMP_TO_INT32_MONTH_UNITS(TYPE, NAME, N_MONTHS) \
+ FORCE_INLINE \
+ gdv_##TYPE NAME##_##TYPE##_int32(gdv_##TYPE millis, gdv_int32 count) { \
+ EpochTimePoint tp(millis); \
+ return tp.AddMonths(static_cast<int>(count * N_MONTHS)).MillisSinceEpoch(); \
+ }
+
+#define ADD_TIMESTAMP_TO_INT64_MONTH_UNITS(TYPE, NAME, N_MONTHS) \
+ FORCE_INLINE \
+ gdv_##TYPE NAME##_##TYPE##_int64(gdv_##TYPE millis, gdv_int64 count) { \
+ EpochTimePoint tp(millis); \
+ return tp.AddMonths(static_cast<int>(count * N_MONTHS)).MillisSinceEpoch(); \
+ }
+
+#define ADD_TIMESTAMP_INT32_FIXEDUNITS(TYPE, NAME, TO_MILLIS) \
+ ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \
+ ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(TYPE, NAME, TO_MILLIS)
+
+#define ADD_TIMESTAMP_INT32_MONTHUNITS(TYPE, NAME, N_MONTHS) \
+ ADD_INT32_TO_TIMESTAMP_MONTH_UNITS(TYPE, NAME, N_MONTHS) \
+ ADD_TIMESTAMP_TO_INT32_MONTH_UNITS(TYPE, NAME, N_MONTHS)
+
+#define TIMESTAMP_ADD_INT32(TYPE) \
+ ADD_TIMESTAMP_INT32_FIXEDUNITS(TYPE, timestampaddSecond, MILLIS_IN_SEC) \
+ ADD_TIMESTAMP_INT32_FIXEDUNITS(TYPE, timestampaddMinute, MILLIS_IN_MIN) \
+ ADD_TIMESTAMP_INT32_FIXEDUNITS(TYPE, timestampaddHour, MILLIS_IN_HOUR) \
+ ADD_TIMESTAMP_INT32_FIXEDUNITS(TYPE, timestampaddDay, MILLIS_IN_DAY) \
+ ADD_TIMESTAMP_INT32_FIXEDUNITS(TYPE, timestampaddWeek, MILLIS_IN_WEEK) \
+ ADD_TIMESTAMP_INT32_MONTHUNITS(TYPE, timestampaddMonth, 1) \
+ ADD_TIMESTAMP_INT32_MONTHUNITS(TYPE, timestampaddQuarter, 3) \
+ ADD_TIMESTAMP_INT32_MONTHUNITS(TYPE, timestampaddYear, 12)
+
+#define ADD_TIMESTAMP_INT64_FIXEDUNITS(TYPE, NAME, TO_MILLIS) \
+ ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \
+ ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(TYPE, NAME, TO_MILLIS)
+
+#define ADD_TIMESTAMP_INT64_MONTHUNITS(TYPE, NAME, N_MONTHS) \
+ ADD_INT64_TO_TIMESTAMP_MONTH_UNITS(TYPE, NAME, N_MONTHS) \
+ ADD_TIMESTAMP_TO_INT64_MONTH_UNITS(TYPE, NAME, N_MONTHS)
+
+#define TIMESTAMP_ADD_INT64(TYPE) \
+ ADD_TIMESTAMP_INT64_FIXEDUNITS(TYPE, timestampaddSecond, MILLIS_IN_SEC) \
+ ADD_TIMESTAMP_INT64_FIXEDUNITS(TYPE, timestampaddMinute, MILLIS_IN_MIN) \
+ ADD_TIMESTAMP_INT64_FIXEDUNITS(TYPE, timestampaddHour, MILLIS_IN_HOUR) \
+ ADD_TIMESTAMP_INT64_FIXEDUNITS(TYPE, timestampaddDay, MILLIS_IN_DAY) \
+ ADD_TIMESTAMP_INT64_FIXEDUNITS(TYPE, timestampaddWeek, MILLIS_IN_WEEK) \
+ ADD_TIMESTAMP_INT64_MONTHUNITS(TYPE, timestampaddMonth, 1) \
+ ADD_TIMESTAMP_INT64_MONTHUNITS(TYPE, timestampaddQuarter, 3) \
+ ADD_TIMESTAMP_INT64_MONTHUNITS(TYPE, timestampaddYear, 12)
+
+#define TIMESTAMP_ADD_INT(TYPE) \
+ TIMESTAMP_ADD_INT32(TYPE) \
+ TIMESTAMP_ADD_INT64(TYPE)
+
+TIMESTAMP_ADD_INT(date64)
+TIMESTAMP_ADD_INT(timestamp)
+
+// add gdv_int32 to timestamp
+ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(date64, date_add, MILLIS_IN_DAY)
+ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(date64, add, MILLIS_IN_DAY)
+ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(timestamp, date_add, MILLIS_IN_DAY)
+ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(timestamp, add, MILLIS_IN_DAY)
+
+// add gdv_int64 to timestamp
+ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(date64, date_add, MILLIS_IN_DAY)
+ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(date64, add, MILLIS_IN_DAY)
+ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(timestamp, date_add, MILLIS_IN_DAY)
+ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(timestamp, add, MILLIS_IN_DAY)
+
+// date_sub, subtract, date_diff on gdv_int32
+ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(date64, date_sub, -1 * MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(date64, subtract, -1 * MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(date64, date_diff, -1 * MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(timestamp, date_sub, -1 * MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(timestamp, subtract, -1 * MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(timestamp, date_diff, -1 * MILLIS_IN_DAY)
+
+// date_sub, subtract, date_diff on gdv_int64
+ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(date64, date_sub, -1 * MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(date64, subtract, -1 * MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(date64, date_diff, -1 * MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(timestamp, date_sub, -1 * MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(timestamp, subtract, -1 * MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(timestamp, date_diff, -1 * MILLIS_IN_DAY)
+
+// add timestamp to gdv_int32
+ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(date64, date_add, MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(date64, add, MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(timestamp, date_add, MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(timestamp, add, MILLIS_IN_DAY)
+
+// add timestamp to gdv_int64
+ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(date64, date_add, MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(date64, add, MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(timestamp, date_add, MILLIS_IN_DAY)
+ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(timestamp, add, MILLIS_IN_DAY)
+
+} // extern "C"
diff --git a/src/arrow/cpp/src/gandiva/precompiled/types.h b/src/arrow/cpp/src/gandiva/precompiled/types.h
new file mode 100644
index 000000000..987ee2c6d
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled/types.h
@@ -0,0 +1,592 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+
+#include "gandiva/gdv_function_stubs.h"
+
+// Use the same names as in arrow data types. Makes it easy to write pre-processor macros.
+using gdv_boolean = bool;
+using gdv_int8 = int8_t;
+using gdv_int16 = int16_t;
+using gdv_int32 = int32_t;
+using gdv_int64 = int64_t;
+using gdv_uint8 = uint8_t;
+using gdv_uint16 = uint16_t;
+using gdv_uint32 = uint32_t;
+using gdv_uint64 = uint64_t;
+using gdv_float32 = float;
+using gdv_float64 = double;
+using gdv_date64 = int64_t;
+using gdv_date32 = int32_t;
+using gdv_time32 = int32_t;
+using gdv_timestamp = int64_t;
+using gdv_utf8 = char*;
+using gdv_binary = char*;
+using gdv_day_time_interval = int64_t;
+
+#ifdef GANDIVA_UNIT_TEST
+// unit tests may be compiled without O2, so inlining may not happen.
+#define FORCE_INLINE
+#else
+#define FORCE_INLINE __attribute__((always_inline))
+#endif
+
+extern "C" {
+
+bool bitMapGetBit(const unsigned char* bmap, int64_t position);
+void bitMapSetBit(unsigned char* bmap, int64_t position, bool value);
+void bitMapClearBitIfFalse(unsigned char* bmap, int64_t position, bool value);
+
+gdv_int64 extractMillennium_timestamp(gdv_timestamp millis);
+gdv_int64 extractCentury_timestamp(gdv_timestamp millis);
+gdv_int64 extractDecade_timestamp(gdv_timestamp millis);
+gdv_int64 extractYear_timestamp(gdv_timestamp millis);
+gdv_int64 extractDoy_timestamp(gdv_timestamp millis);
+gdv_int64 extractQuarter_timestamp(gdv_timestamp millis);
+gdv_int64 extractMonth_timestamp(gdv_timestamp millis);
+gdv_int64 extractWeek_timestamp(gdv_timestamp millis);
+gdv_int64 extractDow_timestamp(gdv_timestamp millis);
+gdv_int64 extractDay_timestamp(gdv_timestamp millis);
+gdv_int64 extractHour_timestamp(gdv_timestamp millis);
+gdv_int64 extractMinute_timestamp(gdv_timestamp millis);
+gdv_int64 extractSecond_timestamp(gdv_timestamp millis);
+gdv_int64 extractHour_time32(gdv_int32 millis_in_day);
+gdv_int64 extractMinute_time32(gdv_int32 millis_in_day);
+gdv_int64 extractSecond_time32(gdv_int32 millis_in_day);
+
+gdv_int32 hash32(double val, gdv_int32 seed);
+gdv_int32 hash32_buf(const gdv_uint8* buf, int len, gdv_int32 seed);
+gdv_int64 hash64(double val, gdv_int64 seed);
+gdv_int64 hash64_buf(const gdv_uint8* buf, int len, gdv_int64 seed);
+
+gdv_int32 timestampdiffMonth_timestamp_timestamp(gdv_timestamp, gdv_timestamp);
+
+gdv_int64 timestampaddSecond_int32_timestamp(gdv_int32, gdv_timestamp);
+gdv_int64 timestampaddMinute_int32_timestamp(gdv_int32, gdv_timestamp);
+gdv_int64 timestampaddHour_int32_timestamp(gdv_int32, gdv_timestamp);
+gdv_int64 timestampaddDay_int32_timestamp(gdv_int32, gdv_timestamp);
+gdv_int64 timestampaddWeek_int32_timestamp(gdv_int32, gdv_timestamp);
+gdv_int64 timestampaddMonth_int32_timestamp(gdv_int32, gdv_timestamp);
+gdv_int64 timestampaddQuarter_int32_timestamp(gdv_int32, gdv_timestamp);
+gdv_int64 timestampaddYear_int32_timestamp(gdv_int32, gdv_timestamp);
+
+gdv_int64 timestampaddSecond_timestamp_int32(gdv_timestamp, gdv_int32);
+gdv_int64 timestampaddMinute_timestamp_int32(gdv_timestamp, gdv_int32);
+gdv_int64 timestampaddHour_timestamp_int32(gdv_timestamp, gdv_int32);
+gdv_int64 timestampaddDay_timestamp_int32(gdv_timestamp, gdv_int32);
+gdv_int64 timestampaddWeek_timestamp_int32(gdv_timestamp, gdv_int32);
+gdv_int64 timestampaddMonth_timestamp_int32(gdv_timestamp, gdv_int32);
+gdv_int64 timestampaddQuarter_timestamp_int32(gdv_timestamp, gdv_int32);
+gdv_int64 timestampaddYear_timestamp_int32(gdv_timestamp, gdv_int32);
+
+gdv_int64 timestampaddSecond_int64_timestamp(gdv_int64, gdv_timestamp);
+gdv_int64 timestampaddMinute_int64_timestamp(gdv_int64, gdv_timestamp);
+gdv_int64 timestampaddHour_int64_timestamp(gdv_int64, gdv_timestamp);
+gdv_int64 timestampaddDay_int64_timestamp(gdv_int64, gdv_timestamp);
+gdv_int64 timestampaddWeek_int64_timestamp(gdv_int64, gdv_timestamp);
+gdv_int64 timestampaddMonth_int64_timestamp(gdv_int64, gdv_timestamp);
+gdv_int64 timestampaddQuarter_int64_timestamp(gdv_int64, gdv_timestamp);
+gdv_int64 timestampaddYear_int64_timestamp(gdv_int64, gdv_timestamp);
+
+gdv_int64 timestampaddSecond_timestamp_int64(gdv_timestamp, gdv_int64);
+gdv_int64 timestampaddMinute_timestamp_int64(gdv_timestamp, gdv_int64);
+gdv_int64 timestampaddHour_timestamp_int64(gdv_timestamp, gdv_int64);
+gdv_int64 timestampaddDay_timestamp_int64(gdv_timestamp, gdv_int64);
+gdv_int64 timestampaddWeek_timestamp_int64(gdv_timestamp, gdv_int64);
+gdv_int64 timestampaddMonth_timestamp_int64(gdv_timestamp, gdv_int64);
+gdv_int64 timestampaddQuarter_timestamp_int64(gdv_timestamp, gdv_int64);
+gdv_int64 timestampaddYear_timestamp_int64(gdv_timestamp, gdv_int64);
+
+gdv_int64 date_add_int32_timestamp(gdv_int32, gdv_timestamp);
+gdv_int64 add_int64_timestamp(gdv_int64, gdv_timestamp);
+gdv_int64 add_int32_timestamp(gdv_int32, gdv_timestamp);
+gdv_int64 date_add_int64_timestamp(gdv_int64, gdv_timestamp);
+gdv_timestamp add_date64_int64(gdv_date64, gdv_int64);
+
+gdv_timestamp to_timestamp_int32(gdv_int32);
+gdv_timestamp to_timestamp_int64(gdv_int64);
+gdv_timestamp to_timestamp_float32(gdv_float32);
+gdv_timestamp to_timestamp_float64(gdv_float64);
+
+gdv_time32 to_time_int32(gdv_int32);
+gdv_time32 to_time_int64(gdv_int64);
+gdv_time32 to_time_float32(gdv_float32);
+gdv_time32 to_time_float64(gdv_float64);
+
+gdv_int64 date_sub_timestamp_int32(gdv_timestamp, gdv_int32);
+gdv_int64 subtract_timestamp_int32(gdv_timestamp, gdv_int32);
+gdv_int64 date_diff_timestamp_int64(gdv_timestamp, gdv_int64);
+
+gdv_boolean castBIT_utf8(gdv_int64 context, const char* data, gdv_int32 data_len);
+
+bool is_distinct_from_timestamp_timestamp(gdv_int64, bool, gdv_int64, bool);
+bool is_not_distinct_from_int32_int32(gdv_int32, bool, gdv_int32, bool);
+
+gdv_int64 date_trunc_Second_date64(gdv_date64);
+gdv_int64 date_trunc_Minute_date64(gdv_date64);
+gdv_int64 date_trunc_Hour_date64(gdv_date64);
+gdv_int64 date_trunc_Day_date64(gdv_date64);
+gdv_int64 date_trunc_Month_date64(gdv_date64);
+gdv_int64 date_trunc_Quarter_date64(gdv_date64);
+gdv_int64 date_trunc_Year_date64(gdv_date64);
+gdv_int64 date_trunc_Decade_date64(gdv_date64);
+gdv_int64 date_trunc_Century_date64(gdv_date64);
+gdv_int64 date_trunc_Millennium_date64(gdv_date64);
+
+gdv_int64 date_trunc_Week_timestamp(gdv_timestamp);
+double months_between_timestamp_timestamp(gdv_uint64, gdv_uint64);
+
+gdv_int32 mem_compare(const char* left, gdv_int32 left_len, const char* right,
+ gdv_int32 right_len);
+
+gdv_int32 mod_int64_int32(gdv_int64 left, gdv_int32 right);
+gdv_float64 mod_float64_float64(gdv_int64 context, gdv_float64 left, gdv_float64 right);
+
+gdv_int64 divide_int64_int64(gdv_int64 context, gdv_int64 in1, gdv_int64 in2);
+
+gdv_int64 div_int64_int64(gdv_int64 context, gdv_int64 in1, gdv_int64 in2);
+gdv_float32 div_float32_float32(gdv_int64 context, gdv_float32 in1, gdv_float32 in2);
+gdv_float64 div_float64_float64(gdv_int64 context, gdv_float64 in1, gdv_float64 in2);
+
+gdv_float32 round_float32(gdv_float32);
+gdv_float64 round_float64(gdv_float64);
+gdv_float32 round_float32_int32(gdv_float32 number, gdv_int32 out_scale);
+gdv_float64 round_float64_int32(gdv_float64 number, gdv_int32 out_scale);
+gdv_float64 get_scale_multiplier(gdv_int32);
+gdv_int32 round_int32_int32(gdv_int32 number, gdv_int32 precision);
+gdv_int64 round_int64_int32(gdv_int64 number, gdv_int32 precision);
+gdv_int32 round_int32(gdv_int32);
+gdv_int64 round_int64(gdv_int64);
+gdv_int64 get_power_of_10(gdv_int32);
+
+const char* bin_int32(int64_t context, gdv_int32 value, int32_t* out_len);
+const char* bin_int64(int64_t context, gdv_int64 value, int32_t* out_len);
+
+gdv_float64 cbrt_int32(gdv_int32);
+gdv_float64 cbrt_int64(gdv_int64);
+gdv_float64 cbrt_float32(gdv_float32);
+gdv_float64 cbrt_float64(gdv_float64);
+
+gdv_float64 exp_int32(gdv_int32);
+gdv_float64 exp_int64(gdv_int64);
+gdv_float64 exp_float32(gdv_float32);
+gdv_float64 exp_float64(gdv_float64);
+
+gdv_float64 log_int32(gdv_int32);
+gdv_float64 log_int64(gdv_int64);
+gdv_float64 log_float32(gdv_float32);
+gdv_float64 log_float64(gdv_float64);
+
+gdv_float64 log10_int32(gdv_int32);
+gdv_float64 log10_int64(gdv_int64);
+gdv_float64 log10_float32(gdv_float32);
+gdv_float64 log10_float64(gdv_float64);
+
+gdv_float64 sin_int32(gdv_int32);
+gdv_float64 sin_int64(gdv_int64);
+gdv_float64 sin_float32(gdv_float32);
+gdv_float64 sin_float64(gdv_float64);
+gdv_float64 cos_int32(gdv_int32);
+gdv_float64 cos_int64(gdv_int64);
+gdv_float64 cos_float32(gdv_float32);
+gdv_float64 cos_float64(gdv_float64);
+gdv_float64 asin_int32(gdv_int32);
+gdv_float64 asin_int64(gdv_int64);
+gdv_float64 asin_float32(gdv_float32);
+gdv_float64 asin_float64(gdv_float64);
+gdv_float64 acos_int32(gdv_int32);
+gdv_float64 acos_int64(gdv_int64);
+gdv_float64 acos_float32(gdv_float32);
+gdv_float64 acos_float64(gdv_float64);
+gdv_float64 tan_int32(gdv_int32);
+gdv_float64 tan_int64(gdv_int64);
+gdv_float64 tan_float32(gdv_float32);
+gdv_float64 tan_float64(gdv_float64);
+gdv_float64 atan_int32(gdv_int32);
+gdv_float64 atan_int64(gdv_int64);
+gdv_float64 atan_float32(gdv_float32);
+gdv_float64 atan_float64(gdv_float64);
+gdv_float64 sinh_int32(gdv_int32);
+gdv_float64 sinh_int64(gdv_int64);
+gdv_float64 sinh_float32(gdv_float32);
+gdv_float64 sinh_float64(gdv_float64);
+gdv_float64 cosh_int32(gdv_int32);
+gdv_float64 cosh_int64(gdv_int64);
+gdv_float64 cosh_float32(gdv_float32);
+gdv_float64 cosh_float64(gdv_float64);
+gdv_float64 tanh_int32(gdv_int32);
+gdv_float64 tanh_int64(gdv_int64);
+gdv_float64 tanh_float32(gdv_float32);
+gdv_float64 tanh_float64(gdv_float64);
+gdv_float64 atan2_int32_int32(gdv_int32 in1, gdv_int32 in2);
+gdv_float64 atan2_int64_int64(gdv_int64 in1, gdv_int64 in2);
+gdv_float64 atan2_float32_float32(gdv_float32 in1, gdv_float32 in2);
+gdv_float64 atan2_float64_float64(gdv_float64 in1, gdv_float64 in2);
+gdv_float64 cot_float32(gdv_float32);
+gdv_float64 cot_float64(gdv_float64);
+gdv_float64 radians_int32(gdv_int32);
+gdv_float64 radians_int64(gdv_int64);
+gdv_float64 radians_float32(gdv_float32);
+gdv_float64 radians_float64(gdv_float64);
+gdv_float64 degrees_int32(gdv_int32);
+gdv_float64 degrees_int64(gdv_int64);
+gdv_float64 degrees_float32(gdv_float32);
+gdv_float64 degrees_float64(gdv_float64);
+
+gdv_int32 bitwise_and_int32_int32(gdv_int32 in1, gdv_int32 in2);
+gdv_int64 bitwise_and_int64_int64(gdv_int64 in1, gdv_int64 in2);
+gdv_int32 bitwise_or_int32_int32(gdv_int32 in1, gdv_int32 in2);
+gdv_int64 bitwise_or_int64_int64(gdv_int64 in1, gdv_int64 in2);
+gdv_int32 bitwise_xor_int32_int32(gdv_int32 in1, gdv_int32 in2);
+gdv_int64 bitwise_xor_int64_int64(gdv_int64 in1, gdv_int64 in2);
+gdv_int32 bitwise_not_int32(gdv_int32);
+gdv_int64 bitwise_not_int64(gdv_int64);
+
+gdv_float64 power_float64_float64(gdv_float64, gdv_float64);
+
+gdv_float64 log_int32_int32(gdv_int64 context, gdv_int32 base, gdv_int32 value);
+
+bool starts_with_utf8_utf8(const char* data, gdv_int32 data_len, const char* prefix,
+ gdv_int32 prefix_len);
+bool ends_with_utf8_utf8(const char* data, gdv_int32 data_len, const char* suffix,
+ gdv_int32 suffix_len);
+bool is_substr_utf8_utf8(const char* data, gdv_int32 data_len, const char* substr,
+ gdv_int32 substr_len);
+
+gdv_int32 utf8_length(gdv_int64 context, const char* data, gdv_int32 data_len);
+
+gdv_int32 utf8_last_char_pos(gdv_int64 context, const char* data, gdv_int32 data_len);
+
+gdv_date64 castDATE_utf8(int64_t execution_context, const char* input, gdv_int32 length);
+
+gdv_date64 castDATE_int64(gdv_int64 date);
+
+gdv_date64 castDATE_date32(gdv_date32 date);
+
+gdv_date32 castDATE_int32(gdv_int32 date);
+
+gdv_timestamp castTIMESTAMP_utf8(int64_t execution_context, const char* input,
+ gdv_int32 length);
+gdv_timestamp castTIMESTAMP_date64(gdv_date64);
+gdv_timestamp castTIMESTAMP_int64(gdv_int64);
+gdv_date64 castDATE_timestamp(gdv_timestamp);
+gdv_time32 castTIME_timestamp(gdv_timestamp timestamp_in_millis);
+const char* castVARCHAR_timestamp_int64(int64_t, gdv_timestamp, gdv_int64, gdv_int32*);
+gdv_date64 last_day_from_timestamp(gdv_date64 millis);
+
+gdv_int64 truncate_int64_int32(gdv_int64 in, gdv_int32 out_scale);
+
+const char* repeat_utf8_int32(gdv_int64 context, const char* in, gdv_int32 in_len,
+ gdv_int32 repeat_times, gdv_int32* out_len);
+
+const char* substr_utf8_int64_int64(gdv_int64 context, const char* input,
+ gdv_int32 in_len, gdv_int64 offset64,
+ gdv_int64 length, gdv_int32* out_len);
+const char* substr_utf8_int64(gdv_int64 context, const char* input, gdv_int32 in_len,
+ gdv_int64 offset64, gdv_int32* out_len);
+
+const char* concat_utf8_utf8(gdv_int64 context, const char* left, gdv_int32 left_len,
+ bool left_validity, const char* right, gdv_int32 right_len,
+ bool right_validity, gdv_int32* out_len);
+const char* concat_utf8_utf8_utf8(gdv_int64 context, const char* in1, gdv_int32 in1_len,
+ bool in1_validity, const char* in2, gdv_int32 in2_len,
+ bool in2_validity, const char* in3, gdv_int32 in3_len,
+ bool in3_validity, gdv_int32* out_len);
+const char* concat_utf8_utf8_utf8_utf8(gdv_int64 context, const char* in1,
+ gdv_int32 in1_len, bool in1_validity,
+ const char* in2, gdv_int32 in2_len,
+ bool in2_validity, const char* in3,
+ gdv_int32 in3_len, bool in3_validity,
+ const char* in4, gdv_int32 in4_len,
+ bool in4_validity, gdv_int32* out_len);
+const char* space_int32(gdv_int64 ctx, gdv_int32 n, int32_t* out_len);
+const char* space_int64(gdv_int64 ctx, gdv_int64 n, int32_t* out_len);
+const char* concat_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity,
+ const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3,
+ gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len,
+ bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity,
+ gdv_int32* out_len);
+const char* concat_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity,
+ const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3,
+ gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len,
+ bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity,
+ const char* in6, gdv_int32 in6_len, bool in6_validity, gdv_int32* out_len);
+const char* concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity,
+ const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3,
+ gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len,
+ bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity,
+ const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7,
+ gdv_int32 in7_len, bool in7_validity, gdv_int32* out_len);
+const char* concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity,
+ const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3,
+ gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len,
+ bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity,
+ const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7,
+ gdv_int32 in7_len, bool in7_validity, const char* in8, gdv_int32 in8_len,
+ bool in8_validity, gdv_int32* out_len);
+const char* concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity,
+ const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3,
+ gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len,
+ bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity,
+ const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7,
+ gdv_int32 in7_len, bool in7_validity, const char* in8, gdv_int32 in8_len,
+ bool in8_validity, const char* in9, gdv_int32 in9_len, bool in9_validity,
+ gdv_int32* out_len);
+const char* concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity,
+ const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3,
+ gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len,
+ bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity,
+ const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7,
+ gdv_int32 in7_len, bool in7_validity, const char* in8, gdv_int32 in8_len,
+ bool in8_validity, const char* in9, gdv_int32 in9_len, bool in9_validity,
+ const char* in10, gdv_int32 in10_len, bool in10_validity, gdv_int32* out_len);
+
+const char* concatOperator_utf8_utf8(gdv_int64 context, const char* left,
+ gdv_int32 left_len, const char* right,
+ gdv_int32 right_len, gdv_int32* out_len);
+const char* concatOperator_utf8_utf8_utf8(gdv_int64 context, const char* in1,
+ gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3,
+ gdv_int32 in3_len, gdv_int32* out_len);
+const char* concatOperator_utf8_utf8_utf8_utf8(gdv_int64 context, const char* in1,
+ gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3,
+ gdv_int32 in3_len, const char* in4,
+ gdv_int32 in4_len, gdv_int32* out_len);
+const char* concatOperator_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4,
+ gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, gdv_int32* out_len);
+const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4,
+ gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6,
+ gdv_int32 in6_len, gdv_int32* out_len);
+const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4,
+ gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6,
+ gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, gdv_int32* out_len);
+const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4,
+ gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6,
+ gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, const char* in8,
+ gdv_int32 in8_len, gdv_int32* out_len);
+const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4,
+ gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6,
+ gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, const char* in8,
+ gdv_int32 in8_len, const char* in9, gdv_int32 in9_len, gdv_int32* out_len);
+const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8(
+ gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2,
+ gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4,
+ gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6,
+ gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, const char* in8,
+ gdv_int32 in8_len, const char* in9, gdv_int32 in9_len, const char* in10,
+ gdv_int32 in10_len, gdv_int32* out_len);
+
+const char* castVARCHAR_binary_int64(gdv_int64 context, const char* data,
+ gdv_int32 data_len, int64_t out_len,
+ int32_t* out_length);
+
+const char* castVARCHAR_utf8_int64(gdv_int64 context, const char* data,
+ gdv_int32 data_len, int64_t out_len,
+ int32_t* out_length);
+
+const char* castVARBINARY_utf8_int64(gdv_int64 context, const char* data,
+ gdv_int32 data_len, int64_t out_len,
+ int32_t* out_length);
+
+const char* castVARBINARY_binary_int64(gdv_int64 context, const char* data,
+ gdv_int32 data_len, int64_t out_len,
+ int32_t* out_length);
+
+const char* reverse_utf8(gdv_int64 context, const char* data, gdv_int32 data_len,
+ int32_t* out_len);
+
+const char* ltrim_utf8(gdv_int64 context, const char* data, gdv_int32 data_len,
+ int32_t* out_len);
+
+const char* rtrim_utf8(gdv_int64 context, const char* data, gdv_int32 data_len,
+ int32_t* out_len);
+
+const char* btrim_utf8(gdv_int64 context, const char* data, gdv_int32 data_len,
+ int32_t* out_len);
+
+const char* ltrim_utf8_utf8(gdv_int64 context, const char* basetext,
+ gdv_int32 basetext_len, const char* trimtext,
+ gdv_int32 trimtext_len, int32_t* out_len);
+
+const char* rtrim_utf8_utf8(gdv_int64 context, const char* basetext,
+ gdv_int32 basetext_len, const char* trimtext,
+ gdv_int32 trimtext_len, int32_t* out_len);
+
+const char* btrim_utf8_utf8(gdv_int64 context, const char* basetext,
+ gdv_int32 basetext_len, const char* trimtext,
+ gdv_int32 trimtext_len, int32_t* out_len);
+
+gdv_int32 ascii_utf8(const char* data, gdv_int32 data_len);
+
+gdv_int32 locate_utf8_utf8(gdv_int64 context, const char* sub_str, gdv_int32 sub_str_len,
+ const char* str, gdv_int32 str_len);
+
+gdv_int32 strpos_utf8_utf8(gdv_int64 context, const char* str, gdv_int32 str_len,
+ const char* sub_str, gdv_int32 sub_str_len);
+
+gdv_int32 locate_utf8_utf8_int32(gdv_int64 context, const char* sub_str,
+ gdv_int32 sub_str_len, const char* str,
+ gdv_int32 str_len, gdv_int32 start_pos);
+
+const char* lpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 text_len,
+ gdv_int32 return_length, const char* fill_text,
+ gdv_int32 fill_text_len, gdv_int32* out_len);
+
+const char* rpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 text_len,
+ gdv_int32 return_length, const char* fill_text,
+ gdv_int32 fill_text_len, gdv_int32* out_len);
+
+const char* lpad_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len,
+ gdv_int32 return_length, gdv_int32* out_len);
+
+const char* rpad_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len,
+ gdv_int32 return_length, gdv_int32* out_len);
+
+const char* replace_with_max_len_utf8_utf8_utf8(gdv_int64 context, const char* text,
+ gdv_int32 text_len, const char* from_str,
+ gdv_int32 from_str_len,
+ const char* to_str, gdv_int32 to_str_len,
+ gdv_int32 max_length, gdv_int32* out_len);
+
+const char* replace_utf8_utf8_utf8(gdv_int64 context, const char* text,
+ gdv_int32 text_len, const char* from_str,
+ gdv_int32 from_str_len, const char* to_str,
+ gdv_int32 to_str_len, gdv_int32* out_len);
+
+const char* convert_replace_invalid_fromUTF8_binary(int64_t context, const char* text_in,
+ int32_t text_len,
+ const char* char_to_replace,
+ int32_t char_to_replace_len,
+ int32_t* out_len);
+
+const char* convert_toDOUBLE(int64_t context, double value, int32_t* out_len);
+
+const char* convert_toDOUBLE_be(int64_t context, double value, int32_t* out_len);
+
+const char* convert_toFLOAT(int64_t context, float value, int32_t* out_len);
+
+const char* convert_toFLOAT_be(int64_t context, float value, int32_t* out_len);
+
+const char* convert_toBIGINT(int64_t context, int64_t value, int32_t* out_len);
+
+const char* convert_toBIGINT_be(int64_t context, int64_t value, int32_t* out_len);
+
+const char* convert_toINT(int64_t context, int32_t value, int32_t* out_len);
+
+const char* convert_toINT_be(int64_t context, int32_t value, int32_t* out_len);
+
+const char* convert_toBOOLEAN(int64_t context, bool value, int32_t* out_len);
+
+const char* convert_toTIME_EPOCH(int64_t context, int32_t value, int32_t* out_len);
+
+const char* convert_toTIME_EPOCH_be(int64_t context, int32_t value, int32_t* out_len);
+
+const char* convert_toTIMESTAMP_EPOCH(int64_t context, int64_t timestamp,
+ int32_t* out_len);
+const char* convert_toTIMESTAMP_EPOCH_be(int64_t context, int64_t timestamp,
+ int32_t* out_len);
+
+const char* convert_toDATE_EPOCH(int64_t context, int64_t date, int32_t* out_len);
+
+const char* convert_toDATE_EPOCH_be(int64_t context, int64_t date, int32_t* out_len);
+
+const char* convert_toUTF8(int64_t context, const char* value, int32_t value_len,
+ int32_t* out_len);
+
+const char* split_part(gdv_int64 context, const char* text, gdv_int32 text_len,
+ const char* splitter, gdv_int32 split_len, gdv_int32 index,
+ gdv_int32* out_len);
+
+const char* byte_substr_binary_int32_int32(gdv_int64 context, const char* text,
+ gdv_int32 text_len, gdv_int32 offset,
+ gdv_int32 length, gdv_int32* out_len);
+
+const char* castVARCHAR_bool_int64(gdv_int64 context, gdv_boolean value,
+ gdv_int64 out_len, gdv_int32* out_length);
+
+const char* castVARCHAR_int32_int64(int64_t context, int32_t value, int64_t len,
+ int32_t* out_len);
+
+const char* castVARCHAR_int64_int64(int64_t context, int64_t value, int64_t len,
+ int32_t* out_len);
+
+const char* castVARCHAR_float32_int64(int64_t context, float value, int64_t len,
+ int32_t* out_len);
+
+const char* castVARCHAR_float64_int64(int64_t context, double value, int64_t len,
+ int32_t* out_len);
+
+const char* left_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len,
+ gdv_int32 number, gdv_int32* out_len);
+
+const char* right_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len,
+ gdv_int32 number, gdv_int32* out_len);
+
+const char* binary_string(gdv_int64 context, const char* text, gdv_int32 text_len,
+ gdv_int32* out_len);
+
+int32_t castINT_utf8(int64_t context, const char* data, int32_t len);
+
+int64_t castBIGINT_utf8(int64_t context, const char* data, int32_t len);
+
+float castFLOAT4_utf8(int64_t context, const char* data, int32_t len);
+
+double castFLOAT8_utf8(int64_t context, const char* data, int32_t len);
+
+int32_t castINT_float32(gdv_float32 value);
+
+int32_t castINT_float64(gdv_float64 value);
+
+int64_t castBIGINT_float32(gdv_float32 value);
+
+int64_t castBIGINT_float64(gdv_float64 value);
+
+int64_t castBIGINT_daytimeinterval(gdv_day_time_interval in);
+
+int32_t castINT_year_interval(gdv_month_interval in);
+
+int64_t castBIGINT_year_interval(gdv_month_interval in);
+
+gdv_day_time_interval castNULLABLEINTERVALDAY_int32(gdv_int32 in);
+
+gdv_day_time_interval castNULLABLEINTERVALDAY_int64(gdv_int64 in);
+
+gdv_month_interval castNULLABLEINTERVALYEAR_int32(int64_t context, gdv_int32 in);
+
+gdv_month_interval castNULLABLEINTERVALYEAR_int64(int64_t context, gdv_int64 in);
+
+} // extern "C"
diff --git a/src/arrow/cpp/src/gandiva/precompiled_bitcode.cc.in b/src/arrow/cpp/src/gandiva/precompiled_bitcode.cc.in
new file mode 100644
index 000000000..9c382961d
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/precompiled_bitcode.cc.in
@@ -0,0 +1,26 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <string>
+
+namespace gandiva {
+
+// Content of precompiled bitcode file.
+extern const unsigned char kPrecompiledBitcode[] = { <DATA_CHARS> };
+extern const size_t kPrecompiledBitcodeSize = sizeof(kPrecompiledBitcode);
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/projector.cc b/src/arrow/cpp/src/gandiva/projector.cc
new file mode 100644
index 000000000..ff167538f
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/projector.cc
@@ -0,0 +1,369 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/projector.h"
+
+#include <memory>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include "arrow/util/hash_util.h"
+#include "arrow/util/logging.h"
+
+#include "gandiva/cache.h"
+#include "gandiva/expr_validator.h"
+#include "gandiva/llvm_generator.h"
+
+namespace gandiva {
+
+class ProjectorCacheKey {
+ public:
+ ProjectorCacheKey(SchemaPtr schema, std::shared_ptr<Configuration> configuration,
+ ExpressionVector expression_vector, SelectionVector::Mode mode)
+ : schema_(schema), configuration_(configuration), mode_(mode), uniqifier_(0) {
+ static const int kSeedValue = 4;
+ size_t result = kSeedValue;
+ for (auto& expr : expression_vector) {
+ std::string expr_as_string = expr->ToString();
+ expressions_as_strings_.push_back(expr_as_string);
+ arrow::internal::hash_combine(result, expr_as_string);
+ UpdateUniqifier(expr_as_string);
+ }
+ arrow::internal::hash_combine(result, static_cast<size_t>(mode));
+ arrow::internal::hash_combine(result, configuration->Hash());
+ arrow::internal::hash_combine(result, schema_->ToString());
+ arrow::internal::hash_combine(result, uniqifier_);
+ hash_code_ = result;
+ }
+
+ std::size_t Hash() const { return hash_code_; }
+
+ bool operator==(const ProjectorCacheKey& other) const {
+ // arrow schema does not overload equality operators.
+ if (!(schema_->Equals(*other.schema().get(), true))) {
+ return false;
+ }
+
+ if (*configuration_ != *other.configuration_) {
+ return false;
+ }
+
+ if (expressions_as_strings_ != other.expressions_as_strings_) {
+ return false;
+ }
+
+ if (mode_ != other.mode_) {
+ return false;
+ }
+
+ if (uniqifier_ != other.uniqifier_) {
+ return false;
+ }
+ return true;
+ }
+
+ bool operator!=(const ProjectorCacheKey& other) const { return !(*this == other); }
+
+ SchemaPtr schema() const { return schema_; }
+
+ std::string ToString() const {
+ std::stringstream ss;
+ // indent, window, indent_size, null_rep and skip new lines.
+ arrow::PrettyPrintOptions options{0, 10, 2, "null", true};
+ DCHECK_OK(PrettyPrint(*schema_.get(), options, &ss));
+
+ ss << "Expressions: [";
+ bool first = true;
+ for (auto& expr : expressions_as_strings_) {
+ if (first) {
+ first = false;
+ } else {
+ ss << ", ";
+ }
+
+ ss << expr;
+ }
+ ss << "]";
+ return ss.str();
+ }
+
+ private:
+ void UpdateUniqifier(const std::string& expr) {
+ if (uniqifier_ == 0) {
+ // caching of expressions with re2 patterns causes lock contention. So, use
+ // multiple instances to reduce contention.
+ if (expr.find(" like(") != std::string::npos) {
+ uniqifier_ = std::hash<std::thread::id>()(std::this_thread::get_id()) % 16;
+ }
+ }
+ }
+
+ const SchemaPtr schema_;
+ const std::shared_ptr<Configuration> configuration_;
+ SelectionVector::Mode mode_;
+ std::vector<std::string> expressions_as_strings_;
+ size_t hash_code_;
+ uint32_t uniqifier_;
+};
+
+Projector::Projector(std::unique_ptr<LLVMGenerator> llvm_generator, SchemaPtr schema,
+ const FieldVector& output_fields,
+ std::shared_ptr<Configuration> configuration)
+ : llvm_generator_(std::move(llvm_generator)),
+ schema_(schema),
+ output_fields_(output_fields),
+ configuration_(configuration) {}
+
+Projector::~Projector() {}
+
+Status Projector::Make(SchemaPtr schema, const ExpressionVector& exprs,
+ std::shared_ptr<Projector>* projector) {
+ return Projector::Make(schema, exprs, SelectionVector::Mode::MODE_NONE,
+ ConfigurationBuilder::DefaultConfiguration(), projector);
+}
+
+Status Projector::Make(SchemaPtr schema, const ExpressionVector& exprs,
+ std::shared_ptr<Configuration> configuration,
+ std::shared_ptr<Projector>* projector) {
+ return Projector::Make(schema, exprs, SelectionVector::Mode::MODE_NONE, configuration,
+ projector);
+}
+
+Status Projector::Make(SchemaPtr schema, const ExpressionVector& exprs,
+ SelectionVector::Mode selection_vector_mode,
+ std::shared_ptr<Configuration> configuration,
+ std::shared_ptr<Projector>* projector) {
+ ARROW_RETURN_IF(schema == nullptr, Status::Invalid("Schema cannot be null"));
+ ARROW_RETURN_IF(exprs.empty(), Status::Invalid("Expressions cannot be empty"));
+ ARROW_RETURN_IF(configuration == nullptr,
+ Status::Invalid("Configuration cannot be null"));
+
+ // see if equivalent projector was already built
+ static Cache<ProjectorCacheKey, std::shared_ptr<Projector>> cache;
+ ProjectorCacheKey cache_key(schema, configuration, exprs, selection_vector_mode);
+ std::shared_ptr<Projector> cached_projector = cache.GetModule(cache_key);
+ if (cached_projector != nullptr) {
+ *projector = cached_projector;
+ return Status::OK();
+ }
+
+ // Build LLVM generator, and generate code for the specified expressions
+ std::unique_ptr<LLVMGenerator> llvm_gen;
+ ARROW_RETURN_NOT_OK(LLVMGenerator::Make(configuration, &llvm_gen));
+
+ // Run the validation on the expressions.
+ // Return if any of the expression is invalid since
+ // we will not be able to process further.
+ ExprValidator expr_validator(llvm_gen->types(), schema);
+ for (auto& expr : exprs) {
+ ARROW_RETURN_NOT_OK(expr_validator.Validate(expr));
+ }
+
+ // Start measuring build time
+ auto begin = std::chrono::high_resolution_clock::now();
+ ARROW_RETURN_NOT_OK(llvm_gen->Build(exprs, selection_vector_mode));
+ // Stop measuring time and calculate the elapsed time
+ auto end = std::chrono::high_resolution_clock::now();
+ auto elapsed =
+ std::chrono::duration_cast<std::chrono::milliseconds>(end - begin).count();
+
+ // save the output field types. Used for validation at Evaluate() time.
+ std::vector<FieldPtr> output_fields;
+ output_fields.reserve(exprs.size());
+ for (auto& expr : exprs) {
+ output_fields.push_back(expr->result());
+ }
+
+ // Instantiate the projector with the completely built llvm generator
+ *projector = std::shared_ptr<Projector>(
+ new Projector(std::move(llvm_gen), schema, output_fields, configuration));
+ ValueCacheObject<std::shared_ptr<Projector>> value_cache(*projector, elapsed);
+ cache.PutModule(cache_key, value_cache);
+
+ return Status::OK();
+}
+
+Status Projector::Evaluate(const arrow::RecordBatch& batch,
+ const ArrayDataVector& output_data_vecs) {
+ return Evaluate(batch, nullptr, output_data_vecs);
+}
+
+Status Projector::Evaluate(const arrow::RecordBatch& batch,
+ const SelectionVector* selection_vector,
+ const ArrayDataVector& output_data_vecs) {
+ ARROW_RETURN_NOT_OK(ValidateEvaluateArgsCommon(batch));
+
+ if (output_data_vecs.size() != output_fields_.size()) {
+ std::stringstream ss;
+ ss << "number of buffers for output_data_vecs is " << output_data_vecs.size()
+ << ", expected " << output_fields_.size();
+ return Status::Invalid(ss.str());
+ }
+
+ int idx = 0;
+ for (auto& array_data : output_data_vecs) {
+ if (array_data == nullptr) {
+ std::stringstream ss;
+ ss << "array for output field " << output_fields_[idx]->name() << "is null.";
+ return Status::Invalid(ss.str());
+ }
+
+ auto num_rows =
+ selection_vector == nullptr ? batch.num_rows() : selection_vector->GetNumSlots();
+
+ ARROW_RETURN_NOT_OK(
+ ValidateArrayDataCapacity(*array_data, *(output_fields_[idx]), num_rows));
+ ++idx;
+ }
+ return llvm_generator_->Execute(batch, selection_vector, output_data_vecs);
+}
+
+Status Projector::Evaluate(const arrow::RecordBatch& batch, arrow::MemoryPool* pool,
+ arrow::ArrayVector* output) {
+ return Evaluate(batch, nullptr, pool, output);
+}
+
+Status Projector::Evaluate(const arrow::RecordBatch& batch,
+ const SelectionVector* selection_vector,
+ arrow::MemoryPool* pool, arrow::ArrayVector* output) {
+ ARROW_RETURN_NOT_OK(ValidateEvaluateArgsCommon(batch));
+ ARROW_RETURN_IF(output == nullptr, Status::Invalid("Output must be non-null."));
+ ARROW_RETURN_IF(pool == nullptr, Status::Invalid("Memory pool must be non-null."));
+
+ auto num_rows =
+ selection_vector == nullptr ? batch.num_rows() : selection_vector->GetNumSlots();
+ // Allocate the output data vecs.
+ ArrayDataVector output_data_vecs;
+ for (auto& field : output_fields_) {
+ ArrayDataPtr output_data;
+
+ ARROW_RETURN_NOT_OK(AllocArrayData(field->type(), num_rows, pool, &output_data));
+ output_data_vecs.push_back(output_data);
+ }
+
+ // Execute the expression(s).
+ ARROW_RETURN_NOT_OK(
+ llvm_generator_->Execute(batch, selection_vector, output_data_vecs));
+
+ // Create and return array arrays.
+ output->clear();
+ for (auto& array_data : output_data_vecs) {
+ output->push_back(arrow::MakeArray(array_data));
+ }
+ return Status::OK();
+}
+
+// TODO : handle complex vectors (list/map/..)
+Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records,
+ arrow::MemoryPool* pool, ArrayDataPtr* array_data) {
+ arrow::Status astatus;
+ std::vector<std::shared_ptr<arrow::Buffer>> buffers;
+
+ // The output vector always has a null bitmap.
+ int64_t size = arrow::BitUtil::BytesForBits(num_records);
+ ARROW_ASSIGN_OR_RAISE(auto bitmap_buffer, arrow::AllocateBuffer(size, pool));
+ buffers.push_back(std::move(bitmap_buffer));
+
+ // String/Binary vectors have an offsets array.
+ auto type_id = type->id();
+ if (arrow::is_binary_like(type_id)) {
+ auto offsets_len = arrow::BitUtil::BytesForBits((num_records + 1) * 32);
+
+ ARROW_ASSIGN_OR_RAISE(auto offsets_buffer, arrow::AllocateBuffer(offsets_len, pool));
+ buffers.push_back(std::move(offsets_buffer));
+ }
+
+ // The output vector always has a data array.
+ int64_t data_len;
+ if (arrow::is_primitive(type_id) || type_id == arrow::Type::DECIMAL) {
+ const auto& fw_type = dynamic_cast<const arrow::FixedWidthType&>(*type);
+ data_len = arrow::BitUtil::BytesForBits(num_records * fw_type.bit_width());
+ } else if (arrow::is_binary_like(type_id)) {
+ // we don't know the expected size for varlen output vectors.
+ data_len = 0;
+ } else {
+ return Status::Invalid("Unsupported output data type " + type->ToString());
+ }
+ ARROW_ASSIGN_OR_RAISE(auto data_buffer, arrow::AllocateResizableBuffer(data_len, pool));
+
+ // This is not strictly required but valgrind gets confused and detects this
+ // as uninitialized memory access. See arrow::util::SetBitTo().
+ if (type->id() == arrow::Type::BOOL) {
+ memset(data_buffer->mutable_data(), 0, data_len);
+ }
+ buffers.push_back(std::move(data_buffer));
+
+ *array_data = arrow::ArrayData::Make(type, num_records, std::move(buffers));
+ return Status::OK();
+}
+
+Status Projector::ValidateEvaluateArgsCommon(const arrow::RecordBatch& batch) {
+ ARROW_RETURN_IF(!batch.schema()->Equals(*schema_),
+ Status::Invalid("Schema in RecordBatch must match schema in Make()"));
+ ARROW_RETURN_IF(batch.num_rows() == 0,
+ Status::Invalid("RecordBatch must be non-empty."));
+
+ return Status::OK();
+}
+
+Status Projector::ValidateArrayDataCapacity(const arrow::ArrayData& array_data,
+ const arrow::Field& field,
+ int64_t num_records) {
+ ARROW_RETURN_IF(array_data.buffers.size() < 2,
+ Status::Invalid("ArrayData must have at least 2 buffers"));
+
+ int64_t min_bitmap_len = arrow::BitUtil::BytesForBits(num_records);
+ int64_t bitmap_len = array_data.buffers[0]->capacity();
+ ARROW_RETURN_IF(
+ bitmap_len < min_bitmap_len,
+ Status::Invalid("Bitmap buffer too small for ", field.name(), " expected minimum ",
+ min_bitmap_len, " actual size ", bitmap_len));
+
+ auto type_id = field.type()->id();
+ if (arrow::is_binary_like(type_id)) {
+ // validate size of offsets buffer.
+ int64_t min_offsets_len = arrow::BitUtil::BytesForBits((num_records + 1) * 32);
+ int64_t offsets_len = array_data.buffers[1]->capacity();
+ ARROW_RETURN_IF(
+ offsets_len < min_offsets_len,
+ Status::Invalid("offsets buffer too small for ", field.name(),
+ " minimum required ", min_offsets_len, " actual ", offsets_len));
+
+ // check that it's resizable.
+ auto resizable = dynamic_cast<arrow::ResizableBuffer*>(array_data.buffers[2].get());
+ ARROW_RETURN_IF(
+ resizable == nullptr,
+ Status::Invalid("data buffer for varlen output vectors must be resizable"));
+ } else if (arrow::is_primitive(type_id) || type_id == arrow::Type::DECIMAL) {
+ // verify size of data buffer.
+ const auto& fw_type = dynamic_cast<const arrow::FixedWidthType&>(*field.type());
+ int64_t min_data_len =
+ arrow::BitUtil::BytesForBits(num_records * fw_type.bit_width());
+ int64_t data_len = array_data.buffers[1]->capacity();
+ ARROW_RETURN_IF(data_len < min_data_len,
+ Status::Invalid("Data buffer too small for ", field.name()));
+ } else {
+ return Status::Invalid("Unsupported output data type " + field.type()->ToString());
+ }
+
+ return Status::OK();
+}
+
+std::string Projector::DumpIR() { return llvm_generator_->DumpIR(); }
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/projector.h b/src/arrow/cpp/src/gandiva/projector.h
new file mode 100644
index 000000000..20b36c9d8
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/projector.h
@@ -0,0 +1,143 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/status.h"
+
+#include "gandiva/arrow.h"
+#include "gandiva/configuration.h"
+#include "gandiva/expression.h"
+#include "gandiva/selection_vector.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+class LLVMGenerator;
+
+/// \brief projection using expressions.
+///
+/// A projector is built for a specific schema and vector of expressions.
+/// Once the projector is built, it can be used to evaluate many row batches.
+class GANDIVA_EXPORT Projector {
+ public:
+ // Inline dtor will attempt to resolve the destructor for
+ // LLVMGenerator on MSVC, so we compile the dtor in the object code
+ ~Projector();
+
+ /// Build a default projector for the given schema to evaluate
+ /// the vector of expressions.
+ ///
+ /// \param[in] schema schema for the record batches, and the expressions.
+ /// \param[in] exprs vector of expressions.
+ /// \param[out] projector the returned projector object
+ static Status Make(SchemaPtr schema, const ExpressionVector& exprs,
+ std::shared_ptr<Projector>* projector);
+
+ /// Build a projector for the given schema to evaluate the vector of expressions.
+ /// Customize the projector with runtime configuration.
+ ///
+ /// \param[in] schema schema for the record batches, and the expressions.
+ /// \param[in] exprs vector of expressions.
+ /// \param[in] configuration run time configuration.
+ /// \param[out] projector the returned projector object
+ static Status Make(SchemaPtr schema, const ExpressionVector& exprs,
+ std::shared_ptr<Configuration> configuration,
+ std::shared_ptr<Projector>* projector);
+
+ /// Build a projector for the given schema to evaluate the vector of expressions.
+ /// Customize the projector with runtime configuration.
+ ///
+ /// \param[in] schema schema for the record batches, and the expressions.
+ /// \param[in] exprs vector of expressions.
+ /// \param[in] selection_vector_mode mode of selection vector
+ /// \param[in] configuration run time configuration.
+ /// \param[out] projector the returned projector object
+ static Status Make(SchemaPtr schema, const ExpressionVector& exprs,
+ SelectionVector::Mode selection_vector_mode,
+ std::shared_ptr<Configuration> configuration,
+ std::shared_ptr<Projector>* projector);
+
+ /// Evaluate the specified record batch, and return the allocated and populated output
+ /// arrays. The output arrays will be allocated from the memory pool 'pool', and added
+ /// to the vector 'output'.
+ ///
+ /// \param[in] batch the record batch. schema should be the same as the one in 'Make'
+ /// \param[in] pool memory pool used to allocate output arrays (if required).
+ /// \param[out] output the vector of allocated/populated arrays.
+ Status Evaluate(const arrow::RecordBatch& batch, arrow::MemoryPool* pool,
+ arrow::ArrayVector* output);
+
+ /// Evaluate the specified record batch, and populate the output arrays. The output
+ /// arrays of sufficient capacity must be allocated by the caller.
+ ///
+ /// \param[in] batch the record batch. schema should be the same as the one in 'Make'
+ /// \param[in,out] output vector of arrays, the arrays are allocated by the caller and
+ /// populated by Evaluate.
+ Status Evaluate(const arrow::RecordBatch& batch, const ArrayDataVector& output);
+
+ /// Evaluate the specified record batch, and return the allocated and populated output
+ /// arrays. The output arrays will be allocated from the memory pool 'pool', and added
+ /// to the vector 'output'.
+ ///
+ /// \param[in] batch the record batch. schema should be the same as the one in 'Make'
+ /// \param[in] selection_vector selection vector which has filtered row positions.
+ /// \param[in] pool memory pool used to allocate output arrays (if required).
+ /// \param[out] output the vector of allocated/populated arrays.
+ Status Evaluate(const arrow::RecordBatch& batch,
+ const SelectionVector* selection_vector, arrow::MemoryPool* pool,
+ arrow::ArrayVector* output);
+
+ /// Evaluate the specified record batch, and populate the output arrays at the filtered
+ /// positions. The output arrays of sufficient capacity must be allocated by the caller.
+ ///
+ /// \param[in] batch the record batch. schema should be the same as the one in 'Make'
+ /// \param[in] selection_vector selection vector which has the filtered row positions
+ /// \param[in,out] output vector of arrays, the arrays are allocated by the caller and
+ /// populated by Evaluate.
+ Status Evaluate(const arrow::RecordBatch& batch,
+ const SelectionVector* selection_vector, const ArrayDataVector& output);
+
+ std::string DumpIR();
+
+ private:
+ Projector(std::unique_ptr<LLVMGenerator> llvm_generator, SchemaPtr schema,
+ const FieldVector& output_fields, std::shared_ptr<Configuration>);
+
+ /// Allocate an ArrowData of length 'length'.
+ Status AllocArrayData(const DataTypePtr& type, int64_t num_records,
+ arrow::MemoryPool* pool, ArrayDataPtr* array_data);
+
+ /// Validate that the ArrayData has sufficient capacity to accommodate 'num_records'.
+ Status ValidateArrayDataCapacity(const arrow::ArrayData& array_data,
+ const arrow::Field& field, int64_t num_records);
+
+ /// Validate the common args for Evaluate() APIs.
+ Status ValidateEvaluateArgsCommon(const arrow::RecordBatch& batch);
+
+ std::unique_ptr<LLVMGenerator> llvm_generator_;
+ SchemaPtr schema_;
+ FieldVector output_fields_;
+ std::shared_ptr<Configuration> configuration_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/proto/Types.proto b/src/arrow/cpp/src/gandiva/proto/Types.proto
new file mode 100644
index 000000000..eb0d996b9
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/proto/Types.proto
@@ -0,0 +1,255 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+syntax = "proto2";
+package types;
+
+option java_package = "org.apache.arrow.gandiva.ipc";
+option java_outer_classname = "GandivaTypes";
+option optimize_for = SPEED;
+
+enum GandivaType {
+ NONE = 0; // arrow::Type::NA
+ BOOL = 1; // arrow::Type::BOOL
+ UINT8 = 2; // arrow::Type::UINT8
+ INT8 = 3; // arrow::Type::INT8
+ UINT16 = 4; // represents arrow::Type fields in src/arrow/type.h
+ INT16 = 5;
+ UINT32 = 6;
+ INT32 = 7;
+ UINT64 = 8;
+ INT64 = 9;
+ HALF_FLOAT = 10;
+ FLOAT = 11;
+ DOUBLE = 12;
+ UTF8 = 13;
+ BINARY = 14;
+ FIXED_SIZE_BINARY = 15;
+ DATE32 = 16;
+ DATE64 = 17;
+ TIMESTAMP = 18;
+ TIME32 = 19;
+ TIME64 = 20;
+ INTERVAL = 21;
+ DECIMAL = 22;
+ LIST = 23;
+ STRUCT = 24;
+ UNION = 25;
+ DICTIONARY = 26;
+ MAP = 27;
+}
+
+enum DateUnit {
+ DAY = 0;
+ MILLI = 1;
+}
+
+enum TimeUnit {
+ SEC = 0;
+ MILLISEC = 1;
+ MICROSEC = 2;
+ NANOSEC = 3;
+}
+
+enum IntervalType {
+ YEAR_MONTH = 0;
+ DAY_TIME = 1;
+}
+
+enum SelectionVectorType {
+ SV_NONE = 0;
+ SV_INT16 = 1;
+ SV_INT32 = 2;
+}
+
+message ExtGandivaType {
+ optional GandivaType type = 1;
+ optional uint32 width = 2; // used by FIXED_SIZE_BINARY
+ optional int32 precision = 3; // used by DECIMAL
+ optional int32 scale = 4; // used by DECIMAL
+ optional DateUnit dateUnit = 5; // used by DATE32/DATE64
+ optional TimeUnit timeUnit = 6; // used by TIME32/TIME64
+ optional string timeZone = 7; // used by TIMESTAMP
+ optional IntervalType intervalType = 8; // used by INTERVAL
+}
+
+message Field {
+ // name of the field
+ optional string name = 1;
+ optional ExtGandivaType type = 2;
+ optional bool nullable = 3;
+ // for complex data types like structs, unions
+ repeated Field children = 4;
+}
+
+message FieldNode {
+ optional Field field = 1;
+}
+
+message FunctionNode {
+ optional string functionName = 1;
+ repeated TreeNode inArgs = 2;
+ optional ExtGandivaType returnType = 3;
+}
+
+message IfNode {
+ optional TreeNode cond = 1;
+ optional TreeNode thenNode = 2;
+ optional TreeNode elseNode = 3;
+ optional ExtGandivaType returnType = 4;
+}
+
+message AndNode {
+ repeated TreeNode args = 1;
+}
+
+message OrNode {
+ repeated TreeNode args = 1;
+}
+
+message NullNode {
+ optional ExtGandivaType type = 1;
+}
+
+message IntNode {
+ optional int32 value = 1;
+}
+
+message FloatNode {
+ optional float value = 1;
+}
+
+message DoubleNode {
+ optional double value = 1;
+}
+
+message BooleanNode {
+ optional bool value = 1;
+}
+
+message LongNode {
+ optional int64 value = 1;
+}
+
+message StringNode {
+ optional bytes value = 1;
+}
+
+message BinaryNode {
+ optional bytes value = 1;
+}
+
+message DecimalNode {
+ optional string value = 1;
+ optional int32 precision = 2;
+ optional int32 scale = 3;
+}
+
+
+message TreeNode {
+ optional FieldNode fieldNode = 1;
+ optional FunctionNode fnNode = 2;
+
+ // control expressions
+ optional IfNode ifNode = 6;
+ optional AndNode andNode = 7;
+ optional OrNode orNode = 8;
+
+ // literals
+ optional NullNode nullNode = 11;
+ optional IntNode intNode = 12;
+ optional FloatNode floatNode = 13;
+ optional LongNode longNode = 14;
+ optional BooleanNode booleanNode = 15;
+ optional DoubleNode doubleNode = 16;
+ optional StringNode stringNode = 17;
+ optional BinaryNode binaryNode = 18;
+ optional DecimalNode decimalNode = 19;
+
+ // in expr
+ optional InNode inNode = 21;
+}
+
+message ExpressionRoot {
+ optional TreeNode root = 1;
+ optional Field resultType = 2;
+}
+
+message ExpressionList {
+ repeated ExpressionRoot exprs = 2;
+}
+
+message Condition {
+ optional TreeNode root = 1;
+}
+
+message Schema {
+ repeated Field columns = 1;
+}
+
+message GandivaDataTypes {
+ repeated ExtGandivaType dataType = 1;
+}
+
+message GandivaFunctions {
+ repeated FunctionSignature function = 1;
+}
+
+message FunctionSignature {
+ optional string name = 1;
+ optional ExtGandivaType returnType = 2;
+ repeated ExtGandivaType paramTypes = 3;
+}
+
+message InNode {
+ optional TreeNode node = 1;
+ optional IntConstants intValues = 2;
+ optional LongConstants longValues = 3;
+ optional StringConstants stringValues = 4;
+ optional BinaryConstants binaryValues = 5;
+ optional DecimalConstants decimalValues = 6;
+ optional FloatConstants floatValues = 7;
+ optional DoubleConstants doubleValues = 8;
+}
+
+message IntConstants {
+ repeated IntNode intValues = 1;
+}
+
+message LongConstants {
+ repeated LongNode longValues = 1;
+}
+
+message DecimalConstants {
+ repeated DecimalNode decimalValues = 1;
+}
+
+message FloatConstants {
+ repeated FloatNode floatValues = 1;
+}
+
+message DoubleConstants {
+ repeated DoubleNode doubleValues = 1;
+}
+
+message StringConstants {
+ repeated StringNode stringValues = 1;
+}
+
+message BinaryConstants {
+ repeated BinaryNode binaryValues = 1;
+}
diff --git a/src/arrow/cpp/src/gandiva/random_generator_holder.cc b/src/arrow/cpp/src/gandiva/random_generator_holder.cc
new file mode 100644
index 000000000..3471c87d9
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/random_generator_holder.cc
@@ -0,0 +1,45 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/random_generator_holder.h"
+#include "gandiva/node.h"
+
+namespace gandiva {
+Status RandomGeneratorHolder::Make(const FunctionNode& node,
+ std::shared_ptr<RandomGeneratorHolder>* holder) {
+ ARROW_RETURN_IF(node.children().size() > 1,
+ Status::Invalid("'random' function requires at most one parameter"));
+
+ if (node.children().size() == 0) {
+ *holder = std::shared_ptr<RandomGeneratorHolder>(new RandomGeneratorHolder());
+ return Status::OK();
+ }
+
+ auto literal = dynamic_cast<LiteralNode*>(node.children().at(0).get());
+ ARROW_RETURN_IF(literal == nullptr,
+ Status::Invalid("'random' function requires a literal as parameter"));
+
+ auto literal_type = literal->return_type()->id();
+ ARROW_RETURN_IF(
+ literal_type != arrow::Type::INT32,
+ Status::Invalid("'random' function requires an int32 literal as parameter"));
+
+ *holder = std::shared_ptr<RandomGeneratorHolder>(new RandomGeneratorHolder(
+ literal->is_null() ? 0 : arrow::util::get<int32_t>(literal->holder())));
+ return Status::OK();
+}
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/random_generator_holder.h b/src/arrow/cpp/src/gandiva/random_generator_holder.h
new file mode 100644
index 000000000..65b6607e8
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/random_generator_holder.h
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <random>
+
+#include "arrow/status.h"
+#include "arrow/util/io_util.h"
+
+#include "gandiva/function_holder.h"
+#include "gandiva/node.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// Function Holder for 'random'
+class GANDIVA_EXPORT RandomGeneratorHolder : public FunctionHolder {
+ public:
+ ~RandomGeneratorHolder() override = default;
+
+ static Status Make(const FunctionNode& node,
+ std::shared_ptr<RandomGeneratorHolder>* holder);
+
+ double operator()() { return distribution_(generator_); }
+
+ private:
+ explicit RandomGeneratorHolder(int seed) : distribution_(0, 1) {
+ int64_t seed64 = static_cast<int64_t>(seed);
+ seed64 = (seed64 ^ 0x00000005DEECE66D) & 0x0000ffffffffffff;
+ generator_.seed(static_cast<uint64_t>(seed64));
+ }
+
+ RandomGeneratorHolder() : distribution_(0, 1) {
+ generator_.seed(::arrow::internal::GetRandomSeed());
+ }
+
+ std::mt19937_64 generator_;
+ std::uniform_real_distribution<> distribution_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/random_generator_holder_test.cc b/src/arrow/cpp/src/gandiva/random_generator_holder_test.cc
new file mode 100644
index 000000000..4b16c1b7d
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/random_generator_holder_test.cc
@@ -0,0 +1,103 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/random_generator_holder.h"
+
+#include <memory>
+
+#include <gtest/gtest.h>
+
+namespace gandiva {
+
+class TestRandGenHolder : public ::testing::Test {
+ public:
+ FunctionNode BuildRandFunc() { return FunctionNode("random", {}, arrow::float64()); }
+
+ FunctionNode BuildRandWithSeedFunc(int32_t seed, bool seed_is_null) {
+ auto seed_node =
+ std::make_shared<LiteralNode>(arrow::int32(), LiteralHolder(seed), seed_is_null);
+ return FunctionNode("rand", {seed_node}, arrow::float64());
+ }
+};
+
+TEST_F(TestRandGenHolder, NoSeed) {
+ std::shared_ptr<RandomGeneratorHolder> rand_gen_holder;
+ FunctionNode rand_func = BuildRandFunc();
+ auto status = RandomGeneratorHolder::Make(rand_func, &rand_gen_holder);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& random = *rand_gen_holder;
+ EXPECT_NE(random(), random());
+}
+
+TEST_F(TestRandGenHolder, WithValidEqualSeeds) {
+ std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_1;
+ std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_2;
+ FunctionNode rand_func_1 = BuildRandWithSeedFunc(12, false);
+ FunctionNode rand_func_2 = BuildRandWithSeedFunc(12, false);
+ auto status = RandomGeneratorHolder::Make(rand_func_1, &rand_gen_holder_1);
+ EXPECT_EQ(status.ok(), true) << status.message();
+ status = RandomGeneratorHolder::Make(rand_func_2, &rand_gen_holder_2);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& random_1 = *rand_gen_holder_1;
+ auto& random_2 = *rand_gen_holder_2;
+ EXPECT_EQ(random_1(), random_2());
+ EXPECT_EQ(random_1(), random_2());
+ EXPECT_GT(random_1(), 0);
+ EXPECT_NE(random_1(), random_2());
+ EXPECT_LT(random_2(), 1);
+ EXPECT_EQ(random_1(), random_2());
+}
+
+TEST_F(TestRandGenHolder, WithValidSeeds) {
+ std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_1;
+ std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_2;
+ std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_3;
+ FunctionNode rand_func_1 = BuildRandWithSeedFunc(11, false);
+ FunctionNode rand_func_2 = BuildRandWithSeedFunc(12, false);
+ FunctionNode rand_func_3 = BuildRandWithSeedFunc(-12, false);
+ auto status = RandomGeneratorHolder::Make(rand_func_1, &rand_gen_holder_1);
+ EXPECT_EQ(status.ok(), true) << status.message();
+ status = RandomGeneratorHolder::Make(rand_func_2, &rand_gen_holder_2);
+ EXPECT_EQ(status.ok(), true) << status.message();
+ status = RandomGeneratorHolder::Make(rand_func_3, &rand_gen_holder_3);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& random_1 = *rand_gen_holder_1;
+ auto& random_2 = *rand_gen_holder_2;
+ auto& random_3 = *rand_gen_holder_3;
+ EXPECT_NE(random_2(), random_3());
+ EXPECT_NE(random_1(), random_2());
+}
+
+TEST_F(TestRandGenHolder, WithInValidSeed) {
+ std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_1;
+ std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_2;
+ FunctionNode rand_func_1 = BuildRandWithSeedFunc(12, true);
+ FunctionNode rand_func_2 = BuildRandWithSeedFunc(0, false);
+ auto status = RandomGeneratorHolder::Make(rand_func_1, &rand_gen_holder_1);
+ EXPECT_EQ(status.ok(), true) << status.message();
+ status = RandomGeneratorHolder::Make(rand_func_2, &rand_gen_holder_2);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ auto& random_1 = *rand_gen_holder_1;
+ auto& random_2 = *rand_gen_holder_2;
+ EXPECT_EQ(random_1(), random_2());
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/regex_util.cc b/src/arrow/cpp/src/gandiva/regex_util.cc
new file mode 100644
index 000000000..abdd579d1
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/regex_util.cc
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/regex_util.h"
+
+namespace gandiva {
+
+const std::set<char> RegexUtil::pcre_regex_specials_ = {
+ '[', ']', '(', ')', '|', '^', '-', '+', '*', '?', '{', '}', '$', '\\', '.'};
+
+Status RegexUtil::SqlLikePatternToPcre(const std::string& sql_pattern, char escape_char,
+ std::string& pcre_pattern) {
+ /// Characters that are considered special by pcre regex. These needs to be
+ /// escaped with '\\'.
+ pcre_pattern.clear();
+ for (size_t idx = 0; idx < sql_pattern.size(); ++idx) {
+ auto cur = sql_pattern.at(idx);
+
+ // Escape any char that is special for pcre regex
+ if (pcre_regex_specials_.find(cur) != pcre_regex_specials_.end()) {
+ pcre_pattern += "\\";
+ }
+
+ if (cur == escape_char) {
+ // escape char must be followed by '_', '%' or the escape char itself.
+ ++idx;
+ ARROW_RETURN_IF(
+ idx == sql_pattern.size(),
+ Status::Invalid("Unexpected escape char at the end of pattern ", sql_pattern));
+
+ cur = sql_pattern.at(idx);
+ if (cur == '_' || cur == '%' || cur == escape_char) {
+ pcre_pattern += cur;
+ } else {
+ return Status::Invalid("Invalid escape sequence in pattern ", sql_pattern,
+ " at offset ", idx);
+ }
+ } else if (cur == '_') {
+ pcre_pattern += '.';
+ } else if (cur == '%') {
+ pcre_pattern += ".*";
+ } else {
+ pcre_pattern += cur;
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/regex_util.h b/src/arrow/cpp/src/gandiva/regex_util.h
new file mode 100644
index 000000000..cf0002b8c
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/regex_util.h
@@ -0,0 +1,45 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <set>
+#include <sstream>
+#include <string>
+
+#include "gandiva/arrow.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// \brief Utility class for converting sql patterns to pcre patterns.
+class GANDIVA_EXPORT RegexUtil {
+ public:
+ // Convert an sql pattern to a pcre pattern
+ static Status SqlLikePatternToPcre(const std::string& like_pattern, char escape_char,
+ std::string& pcre_pattern);
+
+ static Status SqlLikePatternToPcre(const std::string& like_pattern,
+ std::string& pcre_pattern) {
+ return SqlLikePatternToPcre(like_pattern, 0 /*escape_char*/, pcre_pattern);
+ }
+
+ private:
+ static const std::set<char> pcre_regex_specials_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/replace_holder.cc b/src/arrow/cpp/src/gandiva/replace_holder.cc
new file mode 100644
index 000000000..8b42b585f
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/replace_holder.cc
@@ -0,0 +1,65 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/replace_holder.h"
+
+#include "gandiva/node.h"
+#include "gandiva/regex_util.h"
+
+namespace gandiva {
+
+static bool IsArrowStringLiteral(arrow::Type::type type) {
+ return type == arrow::Type::STRING || type == arrow::Type::BINARY;
+}
+
+Status ReplaceHolder::Make(const FunctionNode& node,
+ std::shared_ptr<ReplaceHolder>* holder) {
+ ARROW_RETURN_IF(node.children().size() != 3,
+ Status::Invalid("'replace' function requires three parameters"));
+
+ auto literal = dynamic_cast<LiteralNode*>(node.children().at(1).get());
+ ARROW_RETURN_IF(
+ literal == nullptr,
+ Status::Invalid("'replace' function requires a literal as the second parameter"));
+
+ auto literal_type = literal->return_type()->id();
+ ARROW_RETURN_IF(
+ !IsArrowStringLiteral(literal_type),
+ Status::Invalid(
+ "'replace' function requires a string literal as the second parameter"));
+
+ return Make(arrow::util::get<std::string>(literal->holder()), holder);
+}
+
+Status ReplaceHolder::Make(const std::string& sql_pattern,
+ std::shared_ptr<ReplaceHolder>* holder) {
+ auto lholder = std::shared_ptr<ReplaceHolder>(new ReplaceHolder(sql_pattern));
+ ARROW_RETURN_IF(!lholder->regex_.ok(),
+ Status::Invalid("Building RE2 pattern '", sql_pattern, "' failed"));
+
+ *holder = lholder;
+ return Status::OK();
+}
+
+void ReplaceHolder::return_error(ExecutionContext* context, std::string& data,
+ std::string& replace_string) {
+ std::string err_msg = "Error replacing '" + replace_string + "' on the given string '" +
+ data + "' for the given pattern: " + pattern_;
+ context->set_error_msg(err_msg.c_str());
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/replace_holder.h b/src/arrow/cpp/src/gandiva/replace_holder.h
new file mode 100644
index 000000000..79150d7aa
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/replace_holder.h
@@ -0,0 +1,97 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <re2/re2.h>
+
+#include <memory>
+#include <string>
+
+#include "arrow/status.h"
+#include "gandiva/execution_context.h"
+#include "gandiva/function_holder.h"
+#include "gandiva/node.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// Function Holder for 'replace'
+class GANDIVA_EXPORT ReplaceHolder : public FunctionHolder {
+ public:
+ ~ReplaceHolder() override = default;
+
+ static Status Make(const FunctionNode& node, std::shared_ptr<ReplaceHolder>* holder);
+
+ static Status Make(const std::string& sql_pattern,
+ std::shared_ptr<ReplaceHolder>* holder);
+
+ /// Return a new string with the pattern that matched the regex replaced for
+ /// the replace_input parameter.
+ const char* operator()(ExecutionContext* ctx, const char* user_input,
+ int32_t user_input_len, const char* replace_input,
+ int32_t replace_input_len, int32_t* out_length) {
+ std::string user_input_as_str(user_input, user_input_len);
+ std::string replace_input_as_str(replace_input, replace_input_len);
+
+ int32_t total_replaces =
+ RE2::GlobalReplace(&user_input_as_str, regex_, replace_input_as_str);
+
+ if (total_replaces < 0) {
+ return_error(ctx, user_input_as_str, replace_input_as_str);
+ *out_length = 0;
+ return "";
+ }
+
+ if (total_replaces == 0) {
+ *out_length = user_input_len;
+ return user_input;
+ }
+
+ *out_length = static_cast<int32_t>(user_input_as_str.size());
+
+ // This condition treats the case where the whole string is replaced by an empty
+ // string
+ if (*out_length == 0) {
+ return "";
+ }
+
+ char* result_buffer = reinterpret_cast<char*>(ctx->arena()->Allocate(*out_length));
+
+ if (result_buffer == NULLPTR) {
+ ctx->set_error_msg("Could not allocate memory for result");
+ *out_length = 0;
+ return "";
+ }
+
+ memcpy(result_buffer, user_input_as_str.data(), *out_length);
+
+ return result_buffer;
+ }
+
+ private:
+ explicit ReplaceHolder(const std::string& pattern)
+ : pattern_(pattern), regex_(pattern) {}
+
+ void return_error(ExecutionContext* context, std::string& data,
+ std::string& replace_string);
+
+ std::string pattern_; // posix pattern string, to help debugging
+ RE2 regex_; // compiled regex for the pattern
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/replace_holder_test.cc b/src/arrow/cpp/src/gandiva/replace_holder_test.cc
new file mode 100644
index 000000000..b0830d4f0
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/replace_holder_test.cc
@@ -0,0 +1,129 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/replace_holder.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <vector>
+
+namespace gandiva {
+
+class TestReplaceHolder : public ::testing::Test {
+ protected:
+ ExecutionContext execution_context_;
+};
+
+TEST_F(TestReplaceHolder, TestMultipleReplace) {
+ std::shared_ptr<ReplaceHolder> replace_holder;
+
+ auto status = ReplaceHolder::Make("ana", &replace_holder);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ std::string input_string = "banana";
+ std::string replace_string;
+ int32_t out_length = 0;
+
+ auto& replace = *replace_holder;
+ const char* ret =
+ replace(&execution_context_, input_string.c_str(),
+ static_cast<int32_t>(input_string.length()), replace_string.c_str(),
+ static_cast<int32_t>(replace_string.length()), &out_length);
+ std::string ret_as_str(ret, out_length);
+ EXPECT_EQ(out_length, 3);
+ EXPECT_EQ(ret_as_str, "bna");
+
+ input_string = "bananaana";
+
+ ret = replace(&execution_context_, input_string.c_str(),
+ static_cast<int32_t>(input_string.length()), replace_string.c_str(),
+ static_cast<int32_t>(replace_string.length()), &out_length);
+ ret_as_str = std::string(ret, out_length);
+ EXPECT_EQ(out_length, 3);
+ EXPECT_EQ(ret_as_str, "bna");
+
+ input_string = "bananana";
+
+ ret = replace(&execution_context_, input_string.c_str(),
+ static_cast<int32_t>(input_string.length()), replace_string.c_str(),
+ static_cast<int32_t>(replace_string.length()), &out_length);
+ ret_as_str = std::string(ret, out_length);
+ EXPECT_EQ(out_length, 2);
+ EXPECT_EQ(ret_as_str, "bn");
+
+ input_string = "anaana";
+
+ ret = replace(&execution_context_, input_string.c_str(),
+ static_cast<int32_t>(input_string.length()), replace_string.c_str(),
+ static_cast<int32_t>(replace_string.length()), &out_length);
+ ret_as_str = std::string(ret, out_length);
+ EXPECT_EQ(out_length, 0);
+ EXPECT_FALSE(execution_context_.has_error());
+ EXPECT_EQ(ret_as_str, "");
+}
+
+TEST_F(TestReplaceHolder, TestNoMatchPattern) {
+ std::shared_ptr<ReplaceHolder> replace_holder;
+
+ auto status = ReplaceHolder::Make("ana", &replace_holder);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ std::string input_string = "apple";
+ std::string replace_string;
+ int32_t out_length = 0;
+
+ auto& replace = *replace_holder;
+ const char* ret =
+ replace(&execution_context_, input_string.c_str(),
+ static_cast<int32_t>(input_string.length()), replace_string.c_str(),
+ static_cast<int32_t>(replace_string.length()), &out_length);
+ std::string ret_as_string(ret, out_length);
+ EXPECT_EQ(out_length, 5);
+ EXPECT_EQ(ret_as_string, "apple");
+}
+
+TEST_F(TestReplaceHolder, TestReplaceSameSize) {
+ std::shared_ptr<ReplaceHolder> replace_holder;
+
+ auto status = ReplaceHolder::Make("a", &replace_holder);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ std::string input_string = "ananindeua";
+ std::string replace_string = "b";
+ int32_t out_length = 0;
+
+ auto& replace = *replace_holder;
+ const char* ret =
+ replace(&execution_context_, input_string.c_str(),
+ static_cast<int32_t>(input_string.length()), replace_string.c_str(),
+ static_cast<int32_t>(replace_string.length()), &out_length);
+ std::string ret_as_string(ret, out_length);
+ EXPECT_EQ(out_length, 10);
+ EXPECT_EQ(ret_as_string, "bnbnindeub");
+}
+
+TEST_F(TestReplaceHolder, TestReplaceInvalidPattern) {
+ std::shared_ptr<ReplaceHolder> replace_holder;
+
+ auto status = ReplaceHolder::Make("+", &replace_holder);
+ EXPECT_EQ(status.ok(), false) << status.message();
+
+ execution_context_.Reset();
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/selection_vector.cc b/src/arrow/cpp/src/gandiva/selection_vector.cc
new file mode 100644
index 000000000..a30bba686
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/selection_vector.cc
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/selection_vector.h"
+
+#include <memory>
+#include <sstream>
+#include <utility>
+#include <vector>
+
+#include "arrow/util/bit_util.h"
+#include "arrow/util/endian.h"
+
+#include "gandiva/selection_vector_impl.h"
+
+namespace gandiva {
+
+constexpr SelectionVector::Mode SelectionVector::kAllModes[kNumModes];
+
+Status SelectionVector::PopulateFromBitMap(const uint8_t* bitmap, int64_t bitmap_size,
+ int64_t max_bitmap_index) {
+ const uint64_t max_idx = static_cast<uint64_t>(max_bitmap_index);
+ ARROW_RETURN_IF(bitmap_size % 8, Status::Invalid("Bitmap size ", bitmap_size,
+ " must be aligned to 64-bit size"));
+ ARROW_RETURN_IF(max_bitmap_index < 0,
+ Status::Invalid("Max bitmap index must be positive"));
+ ARROW_RETURN_IF(
+ max_idx > GetMaxSupportedValue(),
+ Status::Invalid("max_bitmap_index ", max_idx, " must be <= maxSupportedValue ",
+ GetMaxSupportedValue(), " in selection vector"));
+
+ int64_t max_slots = GetMaxSlots();
+
+ // jump 8-bytes at a time, add the index corresponding to each valid bit to the
+ // the selection vector.
+ int64_t selection_idx = 0;
+ const uint64_t* bitmap_64 = reinterpret_cast<const uint64_t*>(bitmap);
+ for (int64_t bitmap_idx = 0; bitmap_idx < bitmap_size / 8; ++bitmap_idx) {
+ uint64_t current_word = arrow::BitUtil::ToLittleEndian(bitmap_64[bitmap_idx]);
+
+ while (current_word != 0) {
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4146)
+#endif
+ // MSVC warns about negating an unsigned type. We suppress it for now
+ uint64_t highest_only = current_word & -current_word;
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+ int pos_in_word = arrow::BitUtil::CountTrailingZeros(highest_only);
+
+ int64_t pos_in_bitmap = bitmap_idx * 64 + pos_in_word;
+ if (pos_in_bitmap > max_bitmap_index) {
+ // the bitmap may be slightly larger for alignment/padding.
+ break;
+ }
+
+ ARROW_RETURN_IF(selection_idx >= max_slots,
+ Status::Invalid("selection vector has no remaining slots"));
+
+ SetIndex(selection_idx, pos_in_bitmap);
+ ++selection_idx;
+
+ current_word ^= highest_only;
+ }
+ }
+
+ SetNumSlots(selection_idx);
+ return Status::OK();
+}
+
+Status SelectionVector::MakeInt16(int64_t max_slots,
+ std::shared_ptr<arrow::Buffer> buffer,
+ std::shared_ptr<SelectionVector>* selection_vector) {
+ ARROW_RETURN_NOT_OK(SelectionVectorInt16::ValidateBuffer(max_slots, buffer));
+ *selection_vector = std::make_shared<SelectionVectorInt16>(max_slots, buffer);
+ return Status::OK();
+}
+
+Status SelectionVector::MakeInt16(int64_t max_slots, arrow::MemoryPool* pool,
+ std::shared_ptr<SelectionVector>* selection_vector) {
+ std::shared_ptr<arrow::Buffer> buffer;
+ ARROW_RETURN_NOT_OK(SelectionVectorInt16::AllocateBuffer(max_slots, pool, &buffer));
+ *selection_vector = std::make_shared<SelectionVectorInt16>(max_slots, buffer);
+ return Status::OK();
+}
+
+Status SelectionVector::MakeImmutableInt16(
+ int64_t num_slots, std::shared_ptr<arrow::Buffer> buffer,
+ std::shared_ptr<SelectionVector>* selection_vector) {
+ *selection_vector =
+ std::make_shared<SelectionVectorInt16>(num_slots, num_slots, buffer);
+ return Status::OK();
+}
+
+Status SelectionVector::MakeInt32(int64_t max_slots,
+ std::shared_ptr<arrow::Buffer> buffer,
+ std::shared_ptr<SelectionVector>* selection_vector) {
+ ARROW_RETURN_NOT_OK(SelectionVectorInt32::ValidateBuffer(max_slots, buffer));
+ *selection_vector = std::make_shared<SelectionVectorInt32>(max_slots, buffer);
+
+ return Status::OK();
+}
+
+Status SelectionVector::MakeInt32(int64_t max_slots, arrow::MemoryPool* pool,
+ std::shared_ptr<SelectionVector>* selection_vector) {
+ std::shared_ptr<arrow::Buffer> buffer;
+ ARROW_RETURN_NOT_OK(SelectionVectorInt32::AllocateBuffer(max_slots, pool, &buffer));
+ *selection_vector = std::make_shared<SelectionVectorInt32>(max_slots, buffer);
+
+ return Status::OK();
+}
+
+Status SelectionVector::MakeImmutableInt32(
+ int64_t num_slots, std::shared_ptr<arrow::Buffer> buffer,
+ std::shared_ptr<SelectionVector>* selection_vector) {
+ *selection_vector =
+ std::make_shared<SelectionVectorInt32>(num_slots, num_slots, buffer);
+ return Status::OK();
+}
+
+Status SelectionVector::MakeInt64(int64_t max_slots,
+ std::shared_ptr<arrow::Buffer> buffer,
+ std::shared_ptr<SelectionVector>* selection_vector) {
+ ARROW_RETURN_NOT_OK(SelectionVectorInt64::ValidateBuffer(max_slots, buffer));
+ *selection_vector = std::make_shared<SelectionVectorInt64>(max_slots, buffer);
+
+ return Status::OK();
+}
+
+Status SelectionVector::MakeInt64(int64_t max_slots, arrow::MemoryPool* pool,
+ std::shared_ptr<SelectionVector>* selection_vector) {
+ std::shared_ptr<arrow::Buffer> buffer;
+ ARROW_RETURN_NOT_OK(SelectionVectorInt64::AllocateBuffer(max_slots, pool, &buffer));
+ *selection_vector = std::make_shared<SelectionVectorInt64>(max_slots, buffer);
+
+ return Status::OK();
+}
+
+template <typename C_TYPE, typename A_TYPE, SelectionVector::Mode mode>
+Status SelectionVectorImpl<C_TYPE, A_TYPE, mode>::AllocateBuffer(
+ int64_t max_slots, arrow::MemoryPool* pool, std::shared_ptr<arrow::Buffer>* buffer) {
+ auto buffer_len = max_slots * sizeof(C_TYPE);
+ ARROW_ASSIGN_OR_RAISE(*buffer, arrow::AllocateBuffer(buffer_len, pool));
+
+ return Status::OK();
+}
+
+template <typename C_TYPE, typename A_TYPE, SelectionVector::Mode mode>
+Status SelectionVectorImpl<C_TYPE, A_TYPE, mode>::ValidateBuffer(
+ int64_t max_slots, std::shared_ptr<arrow::Buffer> buffer) {
+ ARROW_RETURN_IF(!buffer->is_mutable(),
+ Status::Invalid("buffer for selection vector must be mutable"));
+
+ const int64_t min_len = max_slots * sizeof(C_TYPE);
+ ARROW_RETURN_IF(buffer->size() < min_len,
+ Status::Invalid("Buffer for selection vector is too small"));
+
+ return Status::OK();
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/selection_vector.h b/src/arrow/cpp/src/gandiva/selection_vector.h
new file mode 100644
index 000000000..1c0fef1c5
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/selection_vector.h
@@ -0,0 +1,151 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/status.h"
+
+#include "arrow/util/logging.h"
+#include "gandiva/arrow.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// \brief Selection Vector : vector of indices in a row-batch for a selection,
+/// backed by an arrow-array.
+class GANDIVA_EXPORT SelectionVector {
+ public:
+ virtual ~SelectionVector() = default;
+
+ enum Mode : int {
+ MODE_NONE,
+ MODE_UINT16,
+ MODE_UINT32,
+ MODE_UINT64,
+ MODE_MAX = MODE_UINT64, // dummy
+ };
+ static constexpr int kNumModes = static_cast<int>(MODE_MAX) + 1;
+ static constexpr Mode kAllModes[kNumModes] = {MODE_NONE, MODE_UINT16, MODE_UINT32,
+ MODE_UINT64};
+
+ /// Get the value at a given index.
+ virtual uint64_t GetIndex(int64_t index) const = 0;
+
+ /// Set the value at a given index.
+ virtual void SetIndex(int64_t index, uint64_t value) = 0;
+
+ // Get the max supported value in the selection vector.
+ virtual uint64_t GetMaxSupportedValue() const = 0;
+
+ /// The maximum slots (capacity) of the selection vector.
+ virtual int64_t GetMaxSlots() const = 0;
+
+ /// The number of slots (size) of the selection vector.
+ virtual int64_t GetNumSlots() const = 0;
+
+ /// Set the number of slots in the selection vector.
+ virtual void SetNumSlots(int64_t num_slots) = 0;
+
+ /// Convert to arrow-array.
+ virtual ArrayPtr ToArray() const = 0;
+
+ /// Get the underlying arrow buffer.
+ virtual arrow::Buffer& GetBuffer() const = 0;
+
+ /// Mode of SelectionVector
+ virtual Mode GetMode() const = 0;
+
+ /// \brief populate selection vector for all the set bits in the bitmap.
+ ///
+ /// \param[in] bitmap the bitmap
+ /// \param[in] bitmap_size size of the bitmap in bytes
+ /// \param[in] max_bitmap_index max valid index in bitmap (can be lesser than
+ /// capacity in the bitmap, due to alignment/padding).
+ Status PopulateFromBitMap(const uint8_t* bitmap, int64_t bitmap_size,
+ int64_t max_bitmap_index);
+
+ /// \brief make selection vector with int16 type records.
+ ///
+ /// \param[in] max_slots max number of slots
+ /// \param[in] buffer buffer sized to accommodate max_slots
+ /// \param[out] selection_vector selection vector backed by 'buffer'
+ static Status MakeInt16(int64_t max_slots, std::shared_ptr<arrow::Buffer> buffer,
+ std::shared_ptr<SelectionVector>* selection_vector);
+
+ /// \param[in] max_slots max number of slots
+ /// \param[in] pool memory pool to allocate buffer
+ /// \param[out] selection_vector selection vector backed by a buffer allocated from the
+ /// pool.
+ static Status MakeInt16(int64_t max_slots, arrow::MemoryPool* pool,
+ std::shared_ptr<SelectionVector>* selection_vector);
+
+ /// \brief creates a selection vector with pre populated buffer.
+ ///
+ /// \param[in] num_slots size of the selection vector
+ /// \param[in] buffer pre-populated buffer
+ /// \param[out] selection_vector selection vector backed by 'buffer'
+ static Status MakeImmutableInt16(int64_t num_slots,
+ std::shared_ptr<arrow::Buffer> buffer,
+ std::shared_ptr<SelectionVector>* selection_vector);
+
+ /// \brief make selection vector with int32 type records.
+ ///
+ /// \param[in] max_slots max number of slots
+ /// \param[in] buffer buffer sized to accommodate max_slots
+ /// \param[out] selection_vector selection vector backed by 'buffer'
+ static Status MakeInt32(int64_t max_slots, std::shared_ptr<arrow::Buffer> buffer,
+ std::shared_ptr<SelectionVector>* selection_vector);
+
+ /// \brief make selection vector with int32 type records.
+ ///
+ /// \param[in] max_slots max number of slots
+ /// \param[in] pool memory pool to allocate buffer
+ /// \param[out] selection_vector selection vector backed by a buffer allocated from the
+ /// pool.
+ static Status MakeInt32(int64_t max_slots, arrow::MemoryPool* pool,
+ std::shared_ptr<SelectionVector>* selection_vector);
+
+ /// \brief creates a selection vector with pre populated buffer.
+ ///
+ /// \param[in] num_slots size of the selection vector
+ /// \param[in] buffer pre-populated buffer
+ /// \param[out] selection_vector selection vector backed by 'buffer'
+ static Status MakeImmutableInt32(int64_t num_slots,
+ std::shared_ptr<arrow::Buffer> buffer,
+ std::shared_ptr<SelectionVector>* selection_vector);
+
+ /// \brief make selection vector with int64 type records.
+ ///
+ /// \param[in] max_slots max number of slots
+ /// \param[in] buffer buffer sized to accommodate max_slots
+ /// \param[out] selection_vector selection vector backed by 'buffer'
+ static Status MakeInt64(int64_t max_slots, std::shared_ptr<arrow::Buffer> buffer,
+ std::shared_ptr<SelectionVector>* selection_vector);
+
+ /// \brief make selection vector with int64 type records.
+ ///
+ /// \param[in] max_slots max number of slots
+ /// \param[in] pool memory pool to allocate buffer
+ /// \param[out] selection_vector selection vector backed by a buffer allocated from the
+ /// pool.
+ static Status MakeInt64(int64_t max_slots, arrow::MemoryPool* pool,
+ std::shared_ptr<SelectionVector>* selection_vector);
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/selection_vector_impl.h b/src/arrow/cpp/src/gandiva/selection_vector_impl.h
new file mode 100644
index 000000000..dc9724ca8
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/selection_vector_impl.h
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <limits>
+#include <memory>
+
+#include "arrow/status.h"
+#include "arrow/util/macros.h"
+
+#include "arrow/util/logging.h"
+#include "gandiva/arrow.h"
+#include "gandiva/selection_vector.h"
+
+namespace gandiva {
+
+/// \brief template implementation of selection vector with a specific ctype and arrow
+/// type.
+template <typename C_TYPE, typename A_TYPE, SelectionVector::Mode mode>
+class SelectionVectorImpl : public SelectionVector {
+ public:
+ SelectionVectorImpl(int64_t max_slots, std::shared_ptr<arrow::Buffer> buffer)
+ : max_slots_(max_slots), num_slots_(0), buffer_(buffer), mode_(mode) {
+ raw_data_ = reinterpret_cast<C_TYPE*>(buffer->mutable_data());
+ }
+
+ SelectionVectorImpl(int64_t max_slots, int64_t num_slots,
+ std::shared_ptr<arrow::Buffer> buffer)
+ : max_slots_(max_slots), num_slots_(num_slots), buffer_(buffer), mode_(mode) {
+ if (buffer) {
+ raw_data_ = const_cast<C_TYPE*>(reinterpret_cast<const C_TYPE*>(buffer->data()));
+ }
+ }
+
+ uint64_t GetIndex(int64_t index) const override { return raw_data_[index]; }
+
+ void SetIndex(int64_t index, uint64_t value) override {
+ raw_data_[index] = static_cast<C_TYPE>(value);
+ }
+
+ ArrayPtr ToArray() const override;
+
+ int64_t GetMaxSlots() const override { return max_slots_; }
+
+ int64_t GetNumSlots() const override { return num_slots_; }
+
+ void SetNumSlots(int64_t num_slots) override {
+ DCHECK_LE(num_slots, max_slots_);
+ num_slots_ = num_slots;
+ }
+
+ uint64_t GetMaxSupportedValue() const override {
+ return std::numeric_limits<C_TYPE>::max();
+ }
+
+ Mode GetMode() const override { return mode_; }
+
+ arrow::Buffer& GetBuffer() const override { return *buffer_; }
+
+ static Status AllocateBuffer(int64_t max_slots, arrow::MemoryPool* pool,
+ std::shared_ptr<arrow::Buffer>* buffer);
+
+ static Status ValidateBuffer(int64_t max_slots, std::shared_ptr<arrow::Buffer> buffer);
+
+ protected:
+ /// maximum slots in the vector
+ int64_t max_slots_;
+
+ /// number of slots in the vector
+ int64_t num_slots_;
+
+ std::shared_ptr<arrow::Buffer> buffer_;
+ C_TYPE* raw_data_;
+
+ /// SelectionVector mode
+ Mode mode_;
+};
+
+template <typename C_TYPE, typename A_TYPE, SelectionVector::Mode mode>
+ArrayPtr SelectionVectorImpl<C_TYPE, A_TYPE, mode>::ToArray() const {
+ auto data_type = arrow::TypeTraits<A_TYPE>::type_singleton();
+ auto array_data = arrow::ArrayData::Make(data_type, num_slots_, {NULLPTR, buffer_});
+ return arrow::MakeArray(array_data);
+}
+
+using SelectionVectorInt16 =
+ SelectionVectorImpl<uint16_t, arrow::UInt16Type, SelectionVector::MODE_UINT16>;
+using SelectionVectorInt32 =
+ SelectionVectorImpl<uint32_t, arrow::UInt32Type, SelectionVector::MODE_UINT32>;
+using SelectionVectorInt64 =
+ SelectionVectorImpl<uint64_t, arrow::UInt64Type, SelectionVector::MODE_UINT64>;
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/selection_vector_test.cc b/src/arrow/cpp/src/gandiva/selection_vector_test.cc
new file mode 100644
index 000000000..686892901
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/selection_vector_test.cc
@@ -0,0 +1,270 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/selection_vector.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+
+namespace gandiva {
+
+class TestSelectionVector : public ::testing::Test {
+ protected:
+ virtual void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ arrow::MemoryPool* pool_;
+};
+
+static inline uint32_t RoundUpNumi64(uint32_t value) { return (value + 63) >> 6; }
+
+TEST_F(TestSelectionVector, TestInt16Make) {
+ int max_slots = 10;
+
+ // Test with pool allocation
+ std::shared_ptr<SelectionVector> selection;
+ auto status = SelectionVector::MakeInt16(max_slots, pool_, &selection);
+ EXPECT_EQ(status.ok(), true) << status.message();
+ EXPECT_EQ(selection->GetMaxSlots(), max_slots);
+ EXPECT_EQ(selection->GetNumSlots(), 0);
+
+ // Test with pre-alloced buffer
+ std::shared_ptr<SelectionVector> selection2;
+ auto buffer_len = max_slots * sizeof(int16_t);
+ ASSERT_OK_AND_ASSIGN(auto buffer, arrow::AllocateBuffer(buffer_len, pool_));
+
+ status = SelectionVector::MakeInt16(max_slots, std::move(buffer), &selection2);
+ EXPECT_EQ(status.ok(), true) << status.message();
+ EXPECT_EQ(selection2->GetMaxSlots(), max_slots);
+ EXPECT_EQ(selection2->GetNumSlots(), 0);
+}
+
+TEST_F(TestSelectionVector, TestInt16MakeNegative) {
+ int max_slots = 10;
+
+ std::shared_ptr<SelectionVector> selection;
+ auto buffer_len = max_slots * sizeof(int16_t);
+
+ // alloc a buffer that's insufficient.
+ ASSERT_OK_AND_ASSIGN(auto buffer, arrow::AllocateBuffer(buffer_len - 16, pool_));
+
+ auto status = SelectionVector::MakeInt16(max_slots, std::move(buffer), &selection);
+ EXPECT_EQ(status.IsInvalid(), true);
+}
+
+TEST_F(TestSelectionVector, TestInt16Set) {
+ int max_slots = 10;
+
+ std::shared_ptr<SelectionVector> selection;
+ auto status = SelectionVector::MakeInt16(max_slots, pool_, &selection);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ selection->SetIndex(0, 100);
+ EXPECT_EQ(selection->GetIndex(0), 100);
+
+ selection->SetIndex(1, 200);
+ EXPECT_EQ(selection->GetIndex(1), 200);
+
+ selection->SetNumSlots(2);
+ EXPECT_EQ(selection->GetNumSlots(), 2);
+
+ // TopArray() should return an array with 100,200
+ auto array_raw = selection->ToArray();
+ const auto& array = dynamic_cast<const arrow::UInt16Array&>(*array_raw);
+ EXPECT_EQ(array.length(), 2) << array_raw->ToString();
+ EXPECT_EQ(array.Value(0), 100) << array_raw->ToString();
+ EXPECT_EQ(array.Value(1), 200) << array_raw->ToString();
+}
+
+TEST_F(TestSelectionVector, TestInt16PopulateFromBitMap) {
+ int max_slots = 200;
+
+ std::shared_ptr<SelectionVector> selection;
+ auto status = SelectionVector::MakeInt16(max_slots, pool_, &selection);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ int bitmap_size = RoundUpNumi64(max_slots) * 8;
+ std::vector<uint8_t> bitmap(bitmap_size);
+
+ arrow::BitUtil::SetBit(&bitmap[0], 0);
+ arrow::BitUtil::SetBit(&bitmap[0], 5);
+ arrow::BitUtil::SetBit(&bitmap[0], 121);
+ arrow::BitUtil::SetBit(&bitmap[0], 220);
+
+ status = selection->PopulateFromBitMap(&bitmap[0], bitmap_size, max_slots - 1);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ EXPECT_EQ(selection->GetNumSlots(), 3);
+ EXPECT_EQ(selection->GetIndex(0), 0);
+ EXPECT_EQ(selection->GetIndex(1), 5);
+ EXPECT_EQ(selection->GetIndex(2), 121);
+}
+
+TEST_F(TestSelectionVector, TestInt16PopulateFromBitMapNegative) {
+ int max_slots = 2;
+
+ std::shared_ptr<SelectionVector> selection;
+ auto status = SelectionVector::MakeInt16(max_slots, pool_, &selection);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ int bitmap_size = 16;
+ std::vector<uint8_t> bitmap(bitmap_size);
+
+ arrow::BitUtil::SetBit(&bitmap[0], 0);
+ arrow::BitUtil::SetBit(&bitmap[0], 1);
+ arrow::BitUtil::SetBit(&bitmap[0], 2);
+
+ // The bitmap has three set bits, whereas the selection vector has capacity for only 2.
+ status = selection->PopulateFromBitMap(&bitmap[0], bitmap_size, 2);
+ EXPECT_EQ(status.IsInvalid(), true);
+}
+
+TEST_F(TestSelectionVector, TestInt32Set) {
+ int max_slots = 10;
+
+ std::shared_ptr<SelectionVector> selection;
+ auto status = SelectionVector::MakeInt32(max_slots, pool_, &selection);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ selection->SetIndex(0, 100);
+ EXPECT_EQ(selection->GetIndex(0), 100);
+
+ selection->SetIndex(1, 200);
+ EXPECT_EQ(selection->GetIndex(1), 200);
+
+ selection->SetIndex(2, 100000);
+ EXPECT_EQ(selection->GetIndex(2), 100000);
+
+ selection->SetNumSlots(3);
+ EXPECT_EQ(selection->GetNumSlots(), 3);
+
+ // TopArray() should return an array with 100,200,100000
+ auto array_raw = selection->ToArray();
+ const auto& array = dynamic_cast<const arrow::UInt32Array&>(*array_raw);
+ EXPECT_EQ(array.length(), 3) << array_raw->ToString();
+ EXPECT_EQ(array.Value(0), 100) << array_raw->ToString();
+ EXPECT_EQ(array.Value(1), 200) << array_raw->ToString();
+ EXPECT_EQ(array.Value(2), 100000) << array_raw->ToString();
+}
+
+TEST_F(TestSelectionVector, TestInt32PopulateFromBitMap) {
+ int max_slots = 200;
+
+ std::shared_ptr<SelectionVector> selection;
+ auto status = SelectionVector::MakeInt32(max_slots, pool_, &selection);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ int bitmap_size = RoundUpNumi64(max_slots) * 8;
+ std::vector<uint8_t> bitmap(bitmap_size);
+
+ arrow::BitUtil::SetBit(&bitmap[0], 0);
+ arrow::BitUtil::SetBit(&bitmap[0], 5);
+ arrow::BitUtil::SetBit(&bitmap[0], 121);
+ arrow::BitUtil::SetBit(&bitmap[0], 220);
+
+ status = selection->PopulateFromBitMap(&bitmap[0], bitmap_size, max_slots - 1);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ EXPECT_EQ(selection->GetNumSlots(), 3);
+ EXPECT_EQ(selection->GetIndex(0), 0);
+ EXPECT_EQ(selection->GetIndex(1), 5);
+ EXPECT_EQ(selection->GetIndex(2), 121);
+}
+
+TEST_F(TestSelectionVector, TestInt32MakeNegative) {
+ int max_slots = 10;
+
+ std::shared_ptr<SelectionVector> selection;
+ auto buffer_len = max_slots * sizeof(int32_t);
+
+ // alloc a buffer that's insufficient.
+ ASSERT_OK_AND_ASSIGN(auto buffer, arrow::AllocateBuffer(buffer_len - 1, pool_));
+
+ auto status = SelectionVector::MakeInt32(max_slots, std::move(buffer), &selection);
+ EXPECT_EQ(status.IsInvalid(), true);
+}
+
+TEST_F(TestSelectionVector, TestInt64Set) {
+ int max_slots = 10;
+
+ std::shared_ptr<SelectionVector> selection;
+ auto status = SelectionVector::MakeInt64(max_slots, pool_, &selection);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ selection->SetIndex(0, 100);
+ EXPECT_EQ(selection->GetIndex(0), 100);
+
+ selection->SetIndex(1, 200);
+ EXPECT_EQ(selection->GetIndex(1), 200);
+
+ selection->SetIndex(2, 100000);
+ EXPECT_EQ(selection->GetIndex(2), 100000);
+
+ selection->SetNumSlots(3);
+ EXPECT_EQ(selection->GetNumSlots(), 3);
+
+ // TopArray() should return an array with 100,200,100000
+ auto array_raw = selection->ToArray();
+ const auto& array = dynamic_cast<const arrow::UInt64Array&>(*array_raw);
+ EXPECT_EQ(array.length(), 3) << array_raw->ToString();
+ EXPECT_EQ(array.Value(0), 100) << array_raw->ToString();
+ EXPECT_EQ(array.Value(1), 200) << array_raw->ToString();
+ EXPECT_EQ(array.Value(2), 100000) << array_raw->ToString();
+}
+
+TEST_F(TestSelectionVector, TestInt64PopulateFromBitMap) {
+ int max_slots = 200;
+
+ std::shared_ptr<SelectionVector> selection;
+ auto status = SelectionVector::MakeInt64(max_slots, pool_, &selection);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ int bitmap_size = RoundUpNumi64(max_slots) * 8;
+ std::vector<uint8_t> bitmap(bitmap_size);
+
+ arrow::BitUtil::SetBit(&bitmap[0], 0);
+ arrow::BitUtil::SetBit(&bitmap[0], 5);
+ arrow::BitUtil::SetBit(&bitmap[0], 121);
+ arrow::BitUtil::SetBit(&bitmap[0], 220);
+
+ status = selection->PopulateFromBitMap(&bitmap[0], bitmap_size, max_slots - 1);
+ EXPECT_EQ(status.ok(), true) << status.message();
+
+ EXPECT_EQ(selection->GetNumSlots(), 3);
+ EXPECT_EQ(selection->GetIndex(0), 0);
+ EXPECT_EQ(selection->GetIndex(1), 5);
+ EXPECT_EQ(selection->GetIndex(2), 121);
+}
+
+TEST_F(TestSelectionVector, TestInt64MakeNegative) {
+ int max_slots = 10;
+
+ std::shared_ptr<SelectionVector> selection;
+ auto buffer_len = max_slots * sizeof(int64_t);
+
+ // alloc a buffer that's insufficient.
+ ASSERT_OK_AND_ASSIGN(auto buffer, arrow::AllocateBuffer(buffer_len - 1, pool_));
+
+ auto status = SelectionVector::MakeInt64(max_slots, std::move(buffer), &selection);
+ EXPECT_EQ(status.IsInvalid(), true);
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/simple_arena.h b/src/arrow/cpp/src/gandiva/simple_arena.h
new file mode 100644
index 000000000..da00b3397
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/simple_arena.h
@@ -0,0 +1,160 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "gandiva/arrow.h"
+
+namespace gandiva {
+
+/// \brief Simple arena allocator.
+///
+/// Memory is allocated from system in units of chunk-size, and dished out in the
+/// requested sizes. If the requested size > chunk-size, allocate directly from the
+/// system.
+///
+/// The allocated memory gets released only when the arena is destroyed, or on
+/// Reset.
+///
+/// This code is not multi-thread safe, and avoids all locking for efficiency.
+///
+class SimpleArena {
+ public:
+ explicit SimpleArena(arrow::MemoryPool* pool, int64_t min_chunk_size = 4096);
+
+ ~SimpleArena();
+
+ // Allocate buffer of requested size.
+ uint8_t* Allocate(int64_t size);
+
+ // Reset arena state.
+ void Reset();
+
+ // total bytes allocated from system.
+ int64_t total_bytes() { return total_bytes_; }
+
+ // total bytes available for allocations.
+ int64_t avail_bytes() { return avail_bytes_; }
+
+ private:
+ struct Chunk {
+ Chunk(uint8_t* buf, int64_t size) : buf_(buf), size_(size) {}
+
+ uint8_t* buf_;
+ int64_t size_;
+ };
+
+ // Allocate new chunk.
+ arrow::Status AllocateChunk(int64_t size);
+
+ // release memory from buffers.
+ void ReleaseChunks(bool retain_first);
+
+ // Memory pool used for allocs.
+ arrow::MemoryPool* pool_;
+
+ // The chunk-size used for allocations from system.
+ int64_t min_chunk_size_;
+
+ // Total bytes allocated from system.
+ int64_t total_bytes_;
+
+ // Bytes available from allocated chunk.
+ int64_t avail_bytes_;
+
+ // buffer from current chunk.
+ uint8_t* avail_buf_;
+
+ // List of allocated chunks.
+ std::vector<Chunk> chunks_;
+};
+
+inline SimpleArena::SimpleArena(arrow::MemoryPool* pool, int64_t min_chunk_size)
+ : pool_(pool),
+ min_chunk_size_(min_chunk_size),
+ total_bytes_(0),
+ avail_bytes_(0),
+ avail_buf_(NULL) {}
+
+inline SimpleArena::~SimpleArena() { ReleaseChunks(false /*retain_first*/); }
+
+inline uint8_t* SimpleArena::Allocate(int64_t size) {
+ if (avail_bytes_ < size) {
+ auto status = AllocateChunk(std::max(size, min_chunk_size_));
+ if (!status.ok()) {
+ return NULL;
+ }
+ }
+
+ uint8_t* ret = avail_buf_;
+ avail_buf_ += size;
+ avail_bytes_ -= size;
+ return ret;
+}
+
+inline arrow::Status SimpleArena::AllocateChunk(int64_t size) {
+ uint8_t* out;
+
+ auto status = pool_->Allocate(size, &out);
+ ARROW_RETURN_NOT_OK(status);
+
+ chunks_.emplace_back(out, size);
+ avail_buf_ = out;
+ avail_bytes_ = size; // left-over bytes in the previous chunk cannot be used anymore.
+ total_bytes_ += size;
+ return arrow::Status::OK();
+}
+
+// In the most common case, a chunk will be allocated when processing the first record.
+// And, the same chunk can be used for processing the remaining records in the batch.
+// By retaining the first chunk, the number of malloc calls are reduced to one per batch,
+// instead of one per record.
+inline void SimpleArena::Reset() {
+ if (chunks_.size() == 0) {
+ // if there are no chunks, nothing to do.
+ return;
+ }
+
+ // Release all but the first chunk.
+ if (chunks_.size() > 1) {
+ ReleaseChunks(true);
+ chunks_.erase(chunks_.begin() + 1, chunks_.end());
+ }
+
+ avail_buf_ = chunks_.at(0).buf_;
+ avail_bytes_ = total_bytes_ = chunks_.at(0).size_;
+}
+
+inline void SimpleArena::ReleaseChunks(bool retain_first) {
+ for (auto& chunk : chunks_) {
+ if (retain_first) {
+ // skip freeing first chunk.
+ retain_first = false;
+ continue;
+ }
+ pool_->Free(chunk.buf_, chunk.size_);
+ }
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/simple_arena_test.cc b/src/arrow/cpp/src/gandiva/simple_arena_test.cc
new file mode 100644
index 000000000..60831280c
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/simple_arena_test.cc
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/simple_arena.h"
+
+#include <gtest/gtest.h>
+
+#include "arrow/memory_pool.h"
+
+namespace gandiva {
+
+class TestSimpleArena : public ::testing::Test {};
+
+TEST_F(TestSimpleArena, TestAlloc) {
+ int64_t chunk_size = 4096;
+ SimpleArena arena(arrow::default_memory_pool(), chunk_size);
+
+ // Small allocations should come from the same chunk.
+ int64_t small_size = 100;
+ for (int64_t i = 0; i < 20; ++i) {
+ auto p = arena.Allocate(small_size);
+ EXPECT_NE(p, nullptr);
+
+ EXPECT_EQ(arena.total_bytes(), chunk_size);
+ EXPECT_EQ(arena.avail_bytes(), chunk_size - (i + 1) * small_size);
+ }
+
+ // large allocations require separate chunks
+ int64_t large_size = 100 * chunk_size;
+ auto p = arena.Allocate(large_size);
+ EXPECT_NE(p, nullptr);
+ EXPECT_EQ(arena.total_bytes(), chunk_size + large_size);
+ EXPECT_EQ(arena.avail_bytes(), 0);
+}
+
+// small followed by big, then reset
+TEST_F(TestSimpleArena, TestReset1) {
+ int64_t chunk_size = 4096;
+ SimpleArena arena(arrow::default_memory_pool(), chunk_size);
+
+ int64_t small_size = 100;
+ auto p = arena.Allocate(small_size);
+ EXPECT_NE(p, nullptr);
+
+ int64_t large_size = 100 * chunk_size;
+ p = arena.Allocate(large_size);
+ EXPECT_NE(p, nullptr);
+
+ EXPECT_EQ(arena.total_bytes(), chunk_size + large_size);
+ EXPECT_EQ(arena.avail_bytes(), 0);
+ arena.Reset();
+ EXPECT_EQ(arena.total_bytes(), chunk_size);
+ EXPECT_EQ(arena.avail_bytes(), chunk_size);
+
+ // should re-use buffer after reset.
+ p = arena.Allocate(small_size);
+ EXPECT_NE(p, nullptr);
+ EXPECT_EQ(arena.total_bytes(), chunk_size);
+ EXPECT_EQ(arena.avail_bytes(), chunk_size - small_size);
+}
+
+// big followed by small, then reset
+TEST_F(TestSimpleArena, TestReset2) {
+ int64_t chunk_size = 4096;
+ SimpleArena arena(arrow::default_memory_pool(), chunk_size);
+
+ int64_t large_size = 100 * chunk_size;
+ auto p = arena.Allocate(large_size);
+ EXPECT_NE(p, nullptr);
+
+ int64_t small_size = 100;
+ p = arena.Allocate(small_size);
+ EXPECT_NE(p, nullptr);
+
+ EXPECT_EQ(arena.total_bytes(), chunk_size + large_size);
+ EXPECT_EQ(arena.avail_bytes(), chunk_size - small_size);
+ arena.Reset();
+ EXPECT_EQ(arena.total_bytes(), large_size);
+ EXPECT_EQ(arena.avail_bytes(), large_size);
+
+ // should re-use buffer after reset.
+ p = arena.Allocate(small_size);
+ EXPECT_NE(p, nullptr);
+ EXPECT_EQ(arena.total_bytes(), large_size);
+ EXPECT_EQ(arena.avail_bytes(), large_size - small_size);
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/symbols.map b/src/arrow/cpp/src/gandiva/symbols.map
new file mode 100644
index 000000000..77f000106
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/symbols.map
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+{
+ # Symbols marked as 'local' are not exported by the DSO and thus may not
+ # be used by client applications.
+ local:
+ # devtoolset / static-libstdc++ symbols
+ __cxa_*;
+ __once_proxy;
+
+ extern "C++" {
+ # devtoolset or -static-libstdc++ - the Red Hat devtoolset statically
+ # links c++11 symbols into binaries so that the result may be executed on
+ # a system with an older libstdc++ which doesn't include the necessary
+ # c++11 symbols.
+ std::*;
+ *std::__once_call*;
+ };
+};
+
diff --git a/src/arrow/cpp/src/gandiva/tests/CMakeLists.txt b/src/arrow/cpp/src/gandiva/tests/CMakeLists.txt
new file mode 100644
index 000000000..5fa2da16c
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/CMakeLists.txt
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+add_gandiva_test(filter_test)
+add_gandiva_test(projector_test)
+add_gandiva_test(projector_build_validation_test)
+add_gandiva_test(if_expr_test)
+add_gandiva_test(literal_test)
+add_gandiva_test(boolean_expr_test)
+add_gandiva_test(binary_test)
+add_gandiva_test(date_time_test)
+add_gandiva_test(to_string_test)
+add_gandiva_test(utf8_test)
+add_gandiva_test(hash_test)
+add_gandiva_test(in_expr_test)
+add_gandiva_test(null_validity_test)
+add_gandiva_test(decimal_test)
+add_gandiva_test(decimal_single_test)
+add_gandiva_test(filter_project_test)
+
+if(ARROW_BUILD_STATIC)
+ add_gandiva_test(projector_test_static SOURCES projector_test.cc USE_STATIC_LINKING)
+ add_arrow_benchmark(micro_benchmarks
+ PREFIX
+ "gandiva"
+ EXTRA_LINK_LIBS
+ gandiva_static)
+endif()
diff --git a/src/arrow/cpp/src/gandiva/tests/binary_test.cc b/src/arrow/cpp/src/gandiva/tests/binary_test.cc
new file mode 100644
index 000000000..591c5befc
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/binary_test.cc
@@ -0,0 +1,136 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "gandiva/node.h"
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::binary;
+using arrow::boolean;
+using arrow::int32;
+
+class TestBinary : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+TEST_F(TestBinary, TestSimple) {
+ // schema for input fields
+ auto field_a = field("a", binary());
+ auto field_b = field("b", binary());
+ auto schema = arrow::schema({field_a, field_b});
+
+ // output fields
+ auto res = field("res", int32());
+
+ // build expressions.
+ // a > b ? octet_length(a) : octet_length(b)
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto node_b = TreeExprBuilder::MakeField(field_b);
+ auto octet_len_a = TreeExprBuilder::MakeFunction("octet_length", {node_a}, int32());
+ auto octet_len_b = TreeExprBuilder::MakeFunction("octet_length", {node_b}, int32());
+
+ auto is_greater =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
+ auto if_greater =
+ TreeExprBuilder::MakeIf(is_greater, octet_len_a, octet_len_b, int32());
+ auto expr = TreeExprBuilder::MakeExpression(if_greater, res);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a =
+ MakeArrowArrayBinary({"foo", "hello", "hi", "bye"}, {true, true, true, false});
+ auto array_b =
+ MakeArrowArrayBinary({"fo", "hellos", "hi", "bye"}, {true, true, true, true});
+
+ // expected output
+ auto exp = MakeArrowArrayInt32({3, 6, 2, 3}, {true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestBinary, TestIfElse) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::binary());
+ auto field1 = field("f1", arrow::binary());
+
+ auto schema = arrow::schema({field0, field1});
+
+ auto f0 = TreeExprBuilder::MakeField(field0);
+ auto f1 = TreeExprBuilder::MakeField(field1);
+
+ // output fields
+ auto field_result = field("out", arrow::binary());
+
+ // Build expression
+ auto cond = TreeExprBuilder::MakeFunction("isnotnull", {f0}, arrow::boolean());
+ auto ifexpr = TreeExprBuilder::MakeIf(cond, f0, f1, arrow::binary());
+ auto expr = TreeExprBuilder::MakeExpression(ifexpr, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_f0 =
+ MakeArrowArrayBinary({"foo", "hello", "hi", "bye"}, {true, true, true, false});
+ auto array_f1 =
+ MakeArrowArrayBinary({"fe", "fi", "fo", "fum"}, {true, true, true, true});
+
+ // expected output
+ auto exp =
+ MakeArrowArrayBinary({"foo", "hello", "hi", "fum"}, {true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_f0, array_f1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/boolean_expr_test.cc b/src/arrow/cpp/src/gandiva/tests/boolean_expr_test.cc
new file mode 100644
index 000000000..9226f3571
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/boolean_expr_test.cc
@@ -0,0 +1,388 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::int32;
+
+class TestBooleanExpr : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+TEST_F(TestBooleanExpr, SimpleAnd) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto schema = arrow::schema({fielda, fieldb});
+
+ // output fields
+ auto field_result = field("res", boolean());
+
+ // build expression.
+ // (a > 0) && (b > 0)
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto literal_0 = TreeExprBuilder::MakeLiteral((int32_t)0);
+ auto a_gt_0 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_0}, boolean());
+ auto b_gt_0 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_0}, boolean());
+
+ auto node_and = TreeExprBuilder::MakeAnd({a_gt_0, b_gt_0});
+ auto expr = TreeExprBuilder::MakeExpression(node_and, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // FALSE_VALID && ? => FALSE_VALID
+ int num_records = 4;
+ auto arraya = MakeArrowArrayInt32({-2, -2, -2, -2}, {true, true, true, true});
+ auto arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false});
+ auto exp = MakeArrowArrayBool({false, false, false, false}, {true, true, true, true});
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb});
+
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+
+ // FALSE_INVALID && ?
+ num_records = 4;
+ arraya = MakeArrowArrayInt32({-2, -2, -2, -2}, {false, false, false, false});
+ arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false});
+ exp = MakeArrowArrayBool({false, false, false, false}, {true, false, false, false});
+ in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb});
+ outputs.clear();
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+
+ // TRUE_VALID && ?
+ num_records = 4;
+ arraya = MakeArrowArrayInt32({2, 2, 2, 2}, {true, true, true, true});
+ arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false});
+ exp = MakeArrowArrayBool({false, false, true, false}, {true, false, true, false});
+ in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb});
+ outputs.clear();
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+
+ // TRUE_INVALID && ?
+ num_records = 4;
+ arraya = MakeArrowArrayInt32({2, 2, 2, 2}, {false, false, false, false});
+ arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false});
+ exp = MakeArrowArrayBool({false, false, false, false}, {true, false, false, false});
+ in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb});
+ outputs.clear();
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestBooleanExpr, SimpleOr) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto schema = arrow::schema({fielda, fieldb});
+
+ // output fields
+ auto field_result = field("res", boolean());
+
+ // build expression.
+ // (a > 0) || (b > 0)
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto literal_0 = TreeExprBuilder::MakeLiteral((int32_t)0);
+ auto a_gt_0 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_0}, boolean());
+ auto b_gt_0 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_0}, boolean());
+
+ auto node_or = TreeExprBuilder::MakeOr({a_gt_0, b_gt_0});
+ auto expr = TreeExprBuilder::MakeExpression(node_or, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // TRUE_VALID && ? => TRUE_VALID
+ int num_records = 4;
+ auto arraya = MakeArrowArrayInt32({2, 2, 2, 2}, {true, true, true, true});
+ auto arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false});
+ auto exp = MakeArrowArrayBool({true, true, true, true}, {true, true, true, true});
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb});
+
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+
+ // TRUE_INVALID && ?
+ num_records = 4;
+ arraya = MakeArrowArrayInt32({2, 2, 2, 2}, {false, false, false, false});
+ arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false});
+ exp = MakeArrowArrayBool({false, false, true, false}, {false, false, true, false});
+ in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb});
+ outputs.clear();
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+
+ // FALSE_VALID && ?
+ num_records = 4;
+ arraya = MakeArrowArrayInt32({-2, -2, -2, -2}, {true, true, true, true});
+ arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false});
+ exp = MakeArrowArrayBool({false, false, true, false}, {true, false, true, false});
+ in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb});
+ outputs.clear();
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+
+ // FALSE_INVALID && ?
+ num_records = 4;
+ arraya = MakeArrowArrayInt32({-2, -2, -2, -2}, {false, false, false, false});
+ arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false});
+ exp = MakeArrowArrayBool({false, false, true, false}, {false, false, true, false});
+ in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb});
+ outputs.clear();
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestBooleanExpr, AndThree) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto fieldc = field("c", int32());
+ auto schema = arrow::schema({fielda, fieldb, fieldc});
+
+ // output fields
+ auto field_result = field("res", boolean());
+
+ // build expression.
+ // (a > 0) && (b > 0) && (c > 0)
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto node_c = TreeExprBuilder::MakeField(fieldc);
+ auto literal_0 = TreeExprBuilder::MakeLiteral((int32_t)0);
+ auto a_gt_0 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_0}, boolean());
+ auto b_gt_0 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_0}, boolean());
+ auto c_gt_0 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_c, literal_0}, boolean());
+
+ auto node_and = TreeExprBuilder::MakeAnd({a_gt_0, b_gt_0, c_gt_0});
+ auto expr = TreeExprBuilder::MakeExpression(node_and, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ int num_records = 8;
+ std::vector<bool> validity({true, true, true, true, true, true, true, true});
+ auto arraya = MakeArrowArrayInt32({2, 2, 2, 0, 2, 0, 0, 0}, validity);
+ auto arrayb = MakeArrowArrayInt32({2, 2, 0, 2, 0, 2, 0, 0}, validity);
+ auto arrayc = MakeArrowArrayInt32({2, 0, 2, 2, 0, 0, 2, 0}, validity);
+ auto exp = MakeArrowArrayBool({true, false, false, false, false, false, false, false},
+ validity);
+
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb, arrayc});
+
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestBooleanExpr, OrThree) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto fieldc = field("c", int32());
+ auto schema = arrow::schema({fielda, fieldb, fieldc});
+
+ // output fields
+ auto field_result = field("res", boolean());
+
+ // build expression.
+ // (a > 0) || (b > 0) || (c > 0)
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto node_c = TreeExprBuilder::MakeField(fieldc);
+ auto literal_0 = TreeExprBuilder::MakeLiteral((int32_t)0);
+ auto a_gt_0 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_0}, boolean());
+ auto b_gt_0 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_0}, boolean());
+ auto c_gt_0 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_c, literal_0}, boolean());
+
+ auto node_or = TreeExprBuilder::MakeOr({a_gt_0, b_gt_0, c_gt_0});
+ auto expr = TreeExprBuilder::MakeExpression(node_or, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ int num_records = 8;
+ std::vector<bool> validity({true, true, true, true, true, true, true, true});
+ auto arraya = MakeArrowArrayInt32({2, 2, 2, 0, 2, 0, 0, 0}, validity);
+ auto arrayb = MakeArrowArrayInt32({2, 2, 0, 2, 0, 2, 0, 0}, validity);
+ auto arrayc = MakeArrowArrayInt32({2, 0, 2, 2, 0, 0, 2, 0}, validity);
+ auto exp =
+ MakeArrowArrayBool({true, true, true, true, true, true, true, false}, validity);
+
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb, arrayc});
+
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestBooleanExpr, BooleanAndInsideIf) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto schema = arrow::schema({fielda, fieldb});
+
+ // output fields
+ auto field_result = field("res", boolean());
+
+ // build expression.
+ // if (a > 2 && b > 2)
+ // a > 3 && b > 3
+ // else
+ // a > 1 && b > 1
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto literal_1 = TreeExprBuilder::MakeLiteral((int32_t)1);
+ auto literal_2 = TreeExprBuilder::MakeLiteral((int32_t)2);
+ auto literal_3 = TreeExprBuilder::MakeLiteral((int32_t)3);
+ auto a_gt_1 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_1}, boolean());
+ auto a_gt_2 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_2}, boolean());
+ auto a_gt_3 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_3}, boolean());
+ auto b_gt_1 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_1}, boolean());
+ auto b_gt_2 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_2}, boolean());
+ auto b_gt_3 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_3}, boolean());
+
+ auto and_1 = TreeExprBuilder::MakeAnd({a_gt_1, b_gt_1});
+ auto and_2 = TreeExprBuilder::MakeAnd({a_gt_2, b_gt_2});
+ auto and_3 = TreeExprBuilder::MakeAnd({a_gt_3, b_gt_3});
+
+ auto node_if = TreeExprBuilder::MakeIf(and_2, and_3, and_1, arrow::boolean());
+ auto expr = TreeExprBuilder::MakeExpression(node_if, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ int num_records = 4;
+ std::vector<bool> validity({true, true, true, true});
+ auto arraya = MakeArrowArrayInt32({4, 4, 2, 1}, validity);
+ auto arrayb = MakeArrowArrayInt32({5, 3, 3, 1}, validity);
+ auto exp = MakeArrowArrayBool({true, false, true, false}, validity);
+
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb});
+
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestBooleanExpr, IfInsideBooleanAnd) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto schema = arrow::schema({fielda, fieldb});
+
+ // output fields
+ auto field_result = field("res", boolean());
+
+ // build expression.
+ // (if (a > b) a > 3 else b > 3) && (if (a > b) a > 2 else b > 2)
+
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto literal_2 = TreeExprBuilder::MakeLiteral((int32_t)2);
+ auto literal_3 = TreeExprBuilder::MakeLiteral((int32_t)3);
+ auto a_gt_b =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
+ auto a_gt_2 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_2}, boolean());
+ auto a_gt_3 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_3}, boolean());
+ auto b_gt_2 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_2}, boolean());
+ auto b_gt_3 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_3}, boolean());
+
+ auto if_3 = TreeExprBuilder::MakeIf(a_gt_b, a_gt_3, b_gt_3, arrow::boolean());
+ auto if_2 = TreeExprBuilder::MakeIf(a_gt_b, a_gt_2, b_gt_2, arrow::boolean());
+ auto node_and = TreeExprBuilder::MakeAnd({if_3, if_2});
+ auto expr = TreeExprBuilder::MakeExpression(node_and, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ int num_records = 4;
+ std::vector<bool> validity({true, true, true, true});
+ auto arraya = MakeArrowArrayInt32({4, 3, 3, 2}, validity);
+ auto arrayb = MakeArrowArrayInt32({3, 4, 2, 3}, validity);
+ auto exp = MakeArrowArrayBool({true, true, false, false}, validity);
+
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb});
+
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/date_time_test.cc b/src/arrow/cpp/src/gandiva/tests/date_time_test.cc
new file mode 100644
index 000000000..77139125f
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/date_time_test.cc
@@ -0,0 +1,602 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include <math.h>
+#include <time.h>
+
+#include "arrow/memory_pool.h"
+#include "gandiva/precompiled/time_constants.h"
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::date32;
+using arrow::date64;
+using arrow::float32;
+using arrow::int32;
+using arrow::int64;
+using arrow::timestamp;
+
+class TestProjector : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+time_t Epoch() {
+ // HACK: MSVC mktime() fails on UTC times before 1970-01-01 00:00:00.
+ // But it first converts its argument from local time to UTC time,
+ // so we ask for 1970-01-02 to avoid failing in timezones ahead of UTC.
+ struct tm y1970;
+ memset(&y1970, 0, sizeof(struct tm));
+ y1970.tm_year = 70;
+ y1970.tm_mon = 0;
+ y1970.tm_mday = 2;
+ y1970.tm_hour = 0;
+ y1970.tm_min = 0;
+ y1970.tm_sec = 0;
+ time_t epoch = mktime(&y1970);
+ if (epoch == static_cast<time_t>(-1)) {
+ ARROW_LOG(FATAL) << "mktime() failed";
+ }
+ // Adjust for the 24h offset above.
+ return epoch - 24 * 3600;
+}
+
+int32_t MillisInDay(int32_t hh, int32_t mm, int32_t ss, int32_t millis) {
+ int32_t mins = hh * 60 + mm;
+ int32_t secs = mins * 60 + ss;
+
+ return secs * 1000 + millis;
+}
+
+int64_t MillisSince(time_t base_line, int32_t yy, int32_t mm, int32_t dd, int32_t hr,
+ int32_t min, int32_t sec, int32_t millis) {
+ struct tm given_ts;
+ memset(&given_ts, 0, sizeof(struct tm));
+ given_ts.tm_year = (yy - 1900);
+ given_ts.tm_mon = (mm - 1);
+ given_ts.tm_mday = dd;
+ given_ts.tm_hour = hr;
+ given_ts.tm_min = min;
+ given_ts.tm_sec = sec;
+
+ time_t ts = mktime(&given_ts);
+ if (ts == static_cast<time_t>(-1)) {
+ ARROW_LOG(FATAL) << "mktime() failed";
+ }
+ // time_t is an arithmetic type on both POSIX and Windows, we can simply
+ // subtract to get a duration in seconds.
+ return static_cast<int64_t>(ts - base_line) * 1000 + millis;
+}
+
+int32_t DaysSince(time_t base_line, int32_t yy, int32_t mm, int32_t dd, int32_t hr,
+ int32_t min, int32_t sec, int32_t millis) {
+ struct tm given_ts;
+ memset(&given_ts, 0, sizeof(struct tm));
+ given_ts.tm_year = (yy - 1900);
+ given_ts.tm_mon = (mm - 1);
+ given_ts.tm_mday = dd;
+ given_ts.tm_hour = hr;
+ given_ts.tm_min = min;
+ given_ts.tm_sec = sec;
+
+ time_t ts = mktime(&given_ts);
+ if (ts == static_cast<time_t>(-1)) {
+ ARROW_LOG(FATAL) << "mktime() failed";
+ }
+ // time_t is an arithmetic type on both POSIX and Windows, we can simply
+ // subtract to get a duration in seconds.
+ return static_cast<int32_t>(((ts - base_line) * 1000 + millis) / MILLIS_IN_DAY);
+}
+
+TEST_F(TestProjector, TestIsNull) {
+ auto d0 = field("d0", date64());
+ auto t0 = field("t0", time32(arrow::TimeUnit::MILLI));
+ auto schema = arrow::schema({d0, t0});
+
+ // output fields
+ auto b0 = field("isnull", boolean());
+
+ // isnull and isnotnull
+ auto isnull_expr = TreeExprBuilder::MakeExpression("isnull", {d0}, b0);
+ auto isnotnull_expr = TreeExprBuilder::MakeExpression("isnotnull", {t0}, b0);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {isnull_expr, isnotnull_expr},
+ TestConfiguration(), &projector);
+ ASSERT_TRUE(status.ok());
+
+ int num_records = 4;
+ std::vector<int64_t> d0_data = {0, 100, 0, 1000};
+ auto t0_data = {0, 100, 0, 1000};
+ auto validity = {false, true, false, true};
+ auto d0_array =
+ MakeArrowTypeArray<arrow::Date64Type, int64_t>(date64(), d0_data, validity);
+ auto t0_array = MakeArrowTypeArray<arrow::Time32Type, int32_t>(
+ time32(arrow::TimeUnit::MILLI), t0_data, validity);
+
+ // expected output
+ auto exp_isnull =
+ MakeArrowArrayBool({true, false, true, false}, {true, true, true, true});
+ auto exp_isnotnull = MakeArrowArrayBool(validity, {true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {d0_array, t0_array});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_isnull, outputs.at(0));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_isnotnull, outputs.at(1));
+}
+
+TEST_F(TestProjector, TestDate32IsNull) {
+ auto d0 = field("d0", date32());
+ auto schema = arrow::schema({d0});
+
+ // output fields
+ auto b0 = field("isnull", boolean());
+
+ // isnull and isnotnull
+ auto isnull_expr = TreeExprBuilder::MakeExpression("isnull", {d0}, b0);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {isnull_expr}, TestConfiguration(), &projector);
+ ASSERT_TRUE(status.ok());
+
+ int num_records = 4;
+ std::vector<int32_t> d0_data = {0, 100, 0, 1000};
+ auto validity = {false, true, false, true};
+ auto d0_array =
+ MakeArrowTypeArray<arrow::Date32Type, int32_t>(date32(), d0_data, validity);
+
+ // expected output
+ auto exp_isnull =
+ MakeArrowArrayBool({true, false, true, false}, {true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {d0_array});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_isnull, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestDateTime) {
+ auto field0 = field("f0", date64());
+ auto field1 = field("f1", date32());
+ auto field2 = field("f2", timestamp(arrow::TimeUnit::MILLI));
+ auto schema = arrow::schema({field0, field1, field2});
+
+ // output fields
+ auto field_year = field("yy", int64());
+ auto field_month = field("mm", int64());
+ auto field_day = field("dd", int64());
+ auto field_hour = field("hh", int64());
+ auto field_date64 = field("date64", date64());
+
+ // extract year and month from date
+ auto date2year_expr =
+ TreeExprBuilder::MakeExpression("extractYear", {field0}, field_year);
+ auto date2month_expr =
+ TreeExprBuilder::MakeExpression("extractMonth", {field0}, field_month);
+
+ // extract year and month from date32, cast to date64 first
+ auto node_f1 = TreeExprBuilder::MakeField(field1);
+ auto date32_to_date64_func =
+ TreeExprBuilder::MakeFunction("castDATE", {node_f1}, date64());
+
+ auto date64_2year_func =
+ TreeExprBuilder::MakeFunction("extractYear", {date32_to_date64_func}, int64());
+ auto date64_2year_expr = TreeExprBuilder::MakeExpression(date64_2year_func, field_year);
+
+ auto date64_2month_func =
+ TreeExprBuilder::MakeFunction("extractMonth", {date32_to_date64_func}, int64());
+ auto date64_2month_expr =
+ TreeExprBuilder::MakeExpression(date64_2month_func, field_month);
+
+ // extract month and day from timestamp
+ auto ts2month_expr =
+ TreeExprBuilder::MakeExpression("extractMonth", {field2}, field_month);
+ auto ts2day_expr = TreeExprBuilder::MakeExpression("extractDay", {field2}, field_day);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema,
+ {date2year_expr, date2month_expr, date64_2year_expr,
+ date64_2month_expr, ts2month_expr, ts2day_expr},
+ TestConfiguration(), &projector);
+ ASSERT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ time_t epoch = Epoch();
+ int num_records = 4;
+ auto validity = {true, true, true, true};
+ std::vector<int64_t> field0_data = {MillisSince(epoch, 2000, 1, 1, 5, 0, 0, 0),
+ MillisSince(epoch, 1999, 12, 31, 5, 0, 0, 0),
+ MillisSince(epoch, 2015, 6, 30, 20, 0, 0, 0),
+ MillisSince(epoch, 2015, 7, 1, 20, 0, 0, 0)};
+ auto array0 =
+ MakeArrowTypeArray<arrow::Date64Type, int64_t>(date64(), field0_data, validity);
+
+ std::vector<int32_t> field1_data = {DaysSince(epoch, 2000, 1, 1, 5, 0, 0, 0),
+ DaysSince(epoch, 1999, 12, 31, 5, 0, 0, 0),
+ DaysSince(epoch, 2015, 6, 30, 20, 0, 0, 0),
+ DaysSince(epoch, 2015, 7, 1, 20, 0, 0, 0)};
+ auto array1 =
+ MakeArrowTypeArray<arrow::Date32Type, int32_t>(date32(), field1_data, validity);
+
+ std::vector<int64_t> field2_data = {MillisSince(epoch, 1999, 12, 31, 5, 0, 0, 0),
+ MillisSince(epoch, 2000, 1, 2, 5, 0, 0, 0),
+ MillisSince(epoch, 2015, 7, 1, 1, 0, 0, 0),
+ MillisSince(epoch, 2015, 6, 29, 23, 0, 0, 0)};
+
+ auto array2 = MakeArrowTypeArray<arrow::TimestampType, int64_t>(
+ arrow::timestamp(arrow::TimeUnit::MILLI), field2_data, validity);
+
+ // expected output
+ // date 2 year and date 2 month for date64
+ auto exp_yy_from_date64 = MakeArrowArrayInt64({2000, 1999, 2015, 2015}, validity);
+ auto exp_mm_from_date64 = MakeArrowArrayInt64({1, 12, 6, 7}, validity);
+
+ // date 2 year and date 2 month for date32
+ auto exp_yy_from_date32 = MakeArrowArrayInt64({2000, 1999, 2015, 2015}, validity);
+ auto exp_mm_from_date32 = MakeArrowArrayInt64({1, 12, 6, 7}, validity);
+
+ // ts 2 month and ts 2 day
+ auto exp_mm_from_ts = MakeArrowArrayInt64({12, 1, 7, 6}, validity);
+ auto exp_dd_from_ts = MakeArrowArrayInt64({31, 2, 1, 29}, validity);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_yy_from_date64, outputs.at(0));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_mm_from_date64, outputs.at(1));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_yy_from_date32, outputs.at(2));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_mm_from_date32, outputs.at(3));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_mm_from_ts, outputs.at(4));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_dd_from_ts, outputs.at(5));
+}
+
+TEST_F(TestProjector, TestTime) {
+ auto field0 = field("f0", time32(arrow::TimeUnit::MILLI));
+ auto schema = arrow::schema({field0});
+
+ auto field_min = field("mm", int64());
+ auto field_hour = field("hh", int64());
+
+ // extract day and hour from time32
+ auto time2min_expr =
+ TreeExprBuilder::MakeExpression("extractMinute", {field0}, field_min);
+ auto time2hour_expr =
+ TreeExprBuilder::MakeExpression("extractHour", {field0}, field_hour);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {time2min_expr, time2hour_expr},
+ TestConfiguration(), &projector);
+ ASSERT_TRUE(status.ok());
+
+ // create input data
+ int num_records = 4;
+ auto validity = {true, true, true, true};
+ std::vector<int32_t> field_data = {
+ MillisInDay(5, 35, 25, 0), // 5:35:25
+ MillisInDay(0, 59, 0, 0), // 0:59:12
+ MillisInDay(12, 30, 0, 0), // 12:30:0
+ MillisInDay(23, 0, 0, 0) // 23:0:0
+ };
+ auto array = MakeArrowTypeArray<arrow::Time32Type, int32_t>(
+ time32(arrow::TimeUnit::MILLI), field_data, validity);
+
+ // expected output
+ auto exp_min = MakeArrowArrayInt64({35, 59, 30, 0}, validity);
+ auto exp_hour = MakeArrowArrayInt64({5, 0, 12, 23}, validity);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_min, outputs.at(0));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_hour, outputs.at(1));
+}
+
+TEST_F(TestProjector, TestTimestampDiff) {
+ auto f0 = field("f0", timestamp(arrow::TimeUnit::MILLI));
+ auto f1 = field("f1", timestamp(arrow::TimeUnit::MILLI));
+ auto schema = arrow::schema({f0, f1});
+
+ // output fields
+ auto diff_seconds = field("ss", int32());
+
+ // get diff
+ auto diff_secs_expr =
+ TreeExprBuilder::MakeExpression("timestampdiffSecond", {f0, f1}, diff_seconds);
+
+ auto diff_mins_expr =
+ TreeExprBuilder::MakeExpression("timestampdiffMinute", {f0, f1}, diff_seconds);
+
+ auto diff_hours_expr =
+ TreeExprBuilder::MakeExpression("timestampdiffHour", {f0, f1}, diff_seconds);
+
+ auto diff_days_expr =
+ TreeExprBuilder::MakeExpression("timestampdiffDay", {f0, f1}, diff_seconds);
+
+ auto diff_days_expr_with_datediff_fn =
+ TreeExprBuilder::MakeExpression("datediff", {f0, f1}, diff_seconds);
+
+ auto diff_weeks_expr =
+ TreeExprBuilder::MakeExpression("timestampdiffWeek", {f0, f1}, diff_seconds);
+
+ auto diff_months_expr =
+ TreeExprBuilder::MakeExpression("timestampdiffMonth", {f0, f1}, diff_seconds);
+
+ auto diff_quarters_expr =
+ TreeExprBuilder::MakeExpression("timestampdiffQuarter", {f0, f1}, diff_seconds);
+
+ auto diff_years_expr =
+ TreeExprBuilder::MakeExpression("timestampdiffYear", {f0, f1}, diff_seconds);
+
+ std::shared_ptr<Projector> projector;
+ auto exprs = {diff_secs_expr,
+ diff_mins_expr,
+ diff_hours_expr,
+ diff_days_expr,
+ diff_days_expr_with_datediff_fn,
+ diff_weeks_expr,
+ diff_months_expr,
+ diff_quarters_expr,
+ diff_years_expr};
+ auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector);
+ ASSERT_TRUE(status.ok());
+
+ time_t epoch = Epoch();
+
+ // 2015-09-10T20:49:42.000
+ auto start_millis = MillisSince(epoch, 2015, 9, 10, 20, 49, 42, 0);
+ // 2017-03-30T22:50:59.050
+ auto end_millis = MillisSince(epoch, 2017, 3, 30, 22, 50, 59, 50);
+ std::vector<int64_t> f0_data = {start_millis, end_millis,
+ // 2015-09-10T20:49:42.999
+ start_millis + 999,
+ // 2015-09-10T20:49:42.999
+ MillisSince(epoch, 2015, 9, 10, 20, 49, 42, 999)};
+ std::vector<int64_t> f1_data = {end_millis, start_millis,
+ // 2015-09-10T20:49:42.999
+ start_millis + 999,
+ // 2015-09-9T21:49:42.999 (23 hours behind)
+ MillisSince(epoch, 2015, 9, 9, 21, 49, 42, 999)};
+
+ int64_t num_records = f0_data.size();
+ std::vector<bool> validity(num_records, true);
+ auto array0 = MakeArrowTypeArray<arrow::TimestampType, int64_t>(
+ arrow::timestamp(arrow::TimeUnit::MILLI), f0_data, validity);
+ auto array1 = MakeArrowTypeArray<arrow::TimestampType, int64_t>(
+ arrow::timestamp(arrow::TimeUnit::MILLI), f1_data, validity);
+
+ // expected output
+ std::vector<ArrayPtr> exp_output;
+ exp_output.push_back(
+ MakeArrowArrayInt32({48996077, -48996077, 0, -23 * 3600}, validity));
+ exp_output.push_back(MakeArrowArrayInt32({816601, -816601, 0, -23 * 60}, validity));
+ exp_output.push_back(MakeArrowArrayInt32({13610, -13610, 0, -23}, validity));
+ exp_output.push_back(MakeArrowArrayInt32({567, -567, 0, 0}, validity));
+ exp_output.push_back(MakeArrowArrayInt32({567, -567, 0, 0}, validity));
+ exp_output.push_back(MakeArrowArrayInt32({81, -81, 0, 0}, validity));
+ exp_output.push_back(MakeArrowArrayInt32({18, -18, 0, 0}, validity));
+ exp_output.push_back(MakeArrowArrayInt32({6, -6, 0, 0}, validity));
+ exp_output.push_back(MakeArrowArrayInt32({1, -1, 0, 0}, validity));
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ for (uint32_t i = 0; i < exp_output.size(); i++) {
+ EXPECT_ARROW_ARRAY_EQUALS(exp_output.at(i), outputs.at(i));
+ }
+}
+
+TEST_F(TestProjector, TestTimestampDiffMonth) {
+ auto f0 = field("f0", timestamp(arrow::TimeUnit::MILLI));
+ auto f1 = field("f1", timestamp(arrow::TimeUnit::MILLI));
+ auto schema = arrow::schema({f0, f1});
+
+ // output fields
+ auto diff_seconds = field("ss", int32());
+
+ auto diff_months_expr =
+ TreeExprBuilder::MakeExpression("timestampdiffMonth", {f0, f1}, diff_seconds);
+
+ std::shared_ptr<Projector> projector;
+ auto status =
+ Projector::Make(schema, {diff_months_expr}, TestConfiguration(), &projector);
+ std::cout << status.message();
+ ASSERT_TRUE(status.ok());
+
+ time_t epoch = Epoch();
+
+ // Create a row-batch with some sample data
+ std::vector<int64_t> f0_data = {MillisSince(epoch, 2019, 1, 31, 0, 0, 0, 0),
+ MillisSince(epoch, 2020, 1, 31, 0, 0, 0, 0),
+ MillisSince(epoch, 2020, 1, 31, 0, 0, 0, 0),
+ MillisSince(epoch, 2019, 3, 31, 0, 0, 0, 0),
+ MillisSince(epoch, 2020, 3, 30, 0, 0, 0, 0),
+ MillisSince(epoch, 2020, 5, 31, 0, 0, 0, 0)};
+ std::vector<int64_t> f1_data = {MillisSince(epoch, 2019, 2, 28, 0, 0, 0, 0),
+ MillisSince(epoch, 2020, 2, 28, 0, 0, 0, 0),
+ MillisSince(epoch, 2020, 2, 29, 0, 0, 0, 0),
+ MillisSince(epoch, 2019, 4, 30, 0, 0, 0, 0),
+ MillisSince(epoch, 2020, 2, 29, 0, 0, 0, 0),
+ MillisSince(epoch, 2020, 9, 30, 0, 0, 0, 0)};
+ int64_t num_records = f0_data.size();
+ std::vector<bool> validity(num_records, true);
+
+ auto array0 = MakeArrowTypeArray<arrow::TimestampType, int64_t>(
+ arrow::timestamp(arrow::TimeUnit::MILLI), f0_data, validity);
+ auto array1 = MakeArrowTypeArray<arrow::TimestampType, int64_t>(
+ arrow::timestamp(arrow::TimeUnit::MILLI), f1_data, validity);
+
+ // expected output
+ std::vector<ArrayPtr> exp_output;
+ exp_output.push_back(MakeArrowArrayInt32({1, 0, 1, 1, -1, 4}, validity));
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ for (uint32_t i = 0; i < exp_output.size(); i++) {
+ EXPECT_ARROW_ARRAY_EQUALS(exp_output.at(i), outputs.at(i));
+ }
+}
+
+TEST_F(TestProjector, TestMonthsBetween) {
+ auto f0 = field("f0", arrow::date64());
+ auto f1 = field("f1", arrow::date64());
+ auto schema = arrow::schema({f0, f1});
+
+ // output fields
+ auto output = field("out", arrow::float64());
+
+ auto months_between_expr =
+ TreeExprBuilder::MakeExpression("months_between", {f0, f1}, output);
+
+ std::shared_ptr<Projector> projector;
+ auto status =
+ Projector::Make(schema, {months_between_expr}, TestConfiguration(), &projector);
+ std::cout << status.message();
+ ASSERT_TRUE(status.ok());
+
+ time_t epoch = Epoch();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto validity = {true, true, true, true};
+ std::vector<int64_t> f0_data = {MillisSince(epoch, 1995, 3, 2, 0, 0, 0, 0),
+ MillisSince(epoch, 1995, 2, 2, 0, 0, 0, 0),
+ MillisSince(epoch, 1995, 3, 31, 0, 0, 0, 0),
+ MillisSince(epoch, 1996, 3, 31, 0, 0, 0, 0)};
+
+ auto array0 =
+ MakeArrowTypeArray<arrow::Date64Type, int64_t>(date64(), f0_data, validity);
+
+ std::vector<int64_t> f1_data = {MillisSince(epoch, 1995, 2, 2, 0, 0, 0, 0),
+ MillisSince(epoch, 1995, 3, 2, 0, 0, 0, 0),
+ MillisSince(epoch, 1995, 2, 28, 0, 0, 0, 0),
+ MillisSince(epoch, 1996, 2, 29, 0, 0, 0, 0)};
+
+ auto array1 =
+ MakeArrowTypeArray<arrow::Date64Type, int64_t>(date64(), f1_data, validity);
+
+ // expected output
+ auto exp_output = MakeArrowArrayFloat64({1.0, -1.0, 1.0, 1.0}, validity);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_output, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestLastDay) {
+ auto f0 = field("f0", arrow::date64());
+ auto schema = arrow::schema({f0});
+
+ // output fields
+ auto output = field("out", arrow::date64());
+
+ auto last_day_expr = TreeExprBuilder::MakeExpression("last_day", {f0}, output);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {last_day_expr}, TestConfiguration(), &projector);
+ std::cout << status.message();
+ ASSERT_TRUE(status.ok());
+
+ time_t epoch = Epoch();
+
+ // Create a row-batch with some sample data
+ // Used a leap year as example.
+ int num_records = 5;
+ auto validity = {true, true, true, true, true};
+ std::vector<int64_t> f0_data = {MillisSince(epoch, 2016, 2, 3, 8, 20, 10, 34),
+ MillisSince(epoch, 2016, 2, 29, 23, 59, 59, 59),
+ MillisSince(epoch, 2016, 1, 30, 1, 15, 20, 0),
+ MillisSince(epoch, 2017, 2, 3, 23, 15, 20, 0),
+ MillisSince(epoch, 2015, 12, 30, 22, 50, 11, 0)};
+
+ auto array0 =
+ MakeArrowTypeArray<arrow::Date64Type, int64_t>(date64(), f0_data, validity);
+
+ std::vector<int64_t> f0_output_data = {MillisSince(epoch, 2016, 2, 29, 0, 0, 0, 0),
+ MillisSince(epoch, 2016, 2, 29, 0, 0, 0, 0),
+ MillisSince(epoch, 2016, 1, 31, 0, 0, 0, 0),
+ MillisSince(epoch, 2017, 2, 28, 0, 0, 0, 0),
+ MillisSince(epoch, 2015, 12, 31, 0, 0, 0, 0)};
+
+ // expected output
+ auto exp_output = MakeArrowArrayDate64(f0_output_data, validity);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_output, outputs.at(0));
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/decimal_single_test.cc b/src/arrow/cpp/src/gandiva/tests/decimal_single_test.cc
new file mode 100644
index 000000000..666ee4a68
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/decimal_single_test.cc
@@ -0,0 +1,305 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <sstream>
+
+#include <gtest/gtest.h>
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+
+#include "gandiva/decimal_scalar.h"
+#include "gandiva/decimal_type_util.h"
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+using arrow::Decimal128;
+
+namespace gandiva {
+
+#define EXPECT_DECIMAL_RESULT(op, x, y, expected, actual) \
+ EXPECT_EQ(expected, actual) << op << " (" << (x).ToString() << "),(" << (y).ToString() \
+ << ")" \
+ << " expected : " << (expected).ToString() \
+ << " actual : " << (actual).ToString();
+
+DecimalScalar128 decimal_literal(const char* value, int precision, int scale) {
+ std::string value_string = std::string(value);
+ return DecimalScalar128(value_string, precision, scale);
+}
+
+class TestDecimalOps : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ ArrayPtr MakeDecimalVector(const DecimalScalar128& in);
+
+ void Verify(DecimalTypeUtil::Op, const std::string& function, const DecimalScalar128& x,
+ const DecimalScalar128& y, const DecimalScalar128& expected);
+
+ void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected) {
+ Verify(DecimalTypeUtil::kOpAdd, "add", x, y, expected);
+ }
+
+ void SubtractAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected) {
+ Verify(DecimalTypeUtil::kOpSubtract, "subtract", x, y, expected);
+ }
+
+ void MultiplyAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected) {
+ Verify(DecimalTypeUtil::kOpMultiply, "multiply", x, y, expected);
+ }
+
+ void DivideAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected) {
+ Verify(DecimalTypeUtil::kOpDivide, "divide", x, y, expected);
+ }
+
+ void ModAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected) {
+ Verify(DecimalTypeUtil::kOpMod, "mod", x, y, expected);
+ }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+ArrayPtr TestDecimalOps::MakeDecimalVector(const DecimalScalar128& in) {
+ std::vector<arrow::Decimal128> ret;
+
+ Decimal128 decimal_value = in.value();
+
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(in.precision(), in.scale());
+ return MakeArrowArrayDecimal(decimal_type, {decimal_value}, {true});
+}
+
+void TestDecimalOps::Verify(DecimalTypeUtil::Op op, const std::string& function,
+ const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected) {
+ auto x_type = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale());
+ auto y_type = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale());
+ auto field_x = field("x", x_type);
+ auto field_y = field("y", y_type);
+ auto schema = arrow::schema({field_x, field_y});
+
+ Decimal128TypePtr output_type;
+ auto status = DecimalTypeUtil::GetResultType(op, {x_type, y_type}, &output_type);
+ ARROW_EXPECT_OK(status);
+
+ // output fields
+ auto res = field("res", output_type);
+
+ // build expression : x op y
+ auto expr = TreeExprBuilder::MakeExpression(function, {field_x, field_y}, res);
+
+ // Build a projector for the expression.
+ std::shared_ptr<Projector> projector;
+ status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ ARROW_EXPECT_OK(status);
+
+ // Create a row-batch with some sample data
+ auto array_a = MakeDecimalVector(x);
+ auto array_b = MakeDecimalVector(y);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, 1 /*num_records*/, {array_a, array_b});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ ARROW_EXPECT_OK(status);
+
+ // Validate results
+ auto out_array = dynamic_cast<arrow::Decimal128Array*>(outputs[0].get());
+ const Decimal128 out_value(out_array->GetValue(0));
+
+ auto dtype = dynamic_cast<arrow::Decimal128Type*>(out_array->type().get());
+ std::string value_string = out_value.ToString(0);
+ DecimalScalar128 actual{value_string, dtype->precision(), dtype->scale()};
+
+ EXPECT_DECIMAL_RESULT(function, x, y, expected, actual);
+}
+
+TEST_F(TestDecimalOps, TestAdd) {
+ // fast-path
+ AddAndVerify(decimal_literal("201", 30, 3), // x
+ decimal_literal("301", 30, 3), // y
+ decimal_literal("502", 31, 3)); // expected
+
+ AddAndVerify(decimal_literal("201", 30, 3), // x
+ decimal_literal("301", 30, 2), // y
+ decimal_literal("3211", 32, 3)); // expected
+
+ AddAndVerify(decimal_literal("201", 30, 3), // x
+ decimal_literal("301", 30, 4), // y
+ decimal_literal("2311", 32, 4)); // expected
+
+ // max precision, but no overflow
+ AddAndVerify(decimal_literal("201", 38, 3), // x
+ decimal_literal("301", 38, 3), // y
+ decimal_literal("502", 38, 3)); // expected
+
+ AddAndVerify(decimal_literal("201", 38, 3), // x
+ decimal_literal("301", 38, 2), // y
+ decimal_literal("3211", 38, 3)); // expected
+
+ AddAndVerify(decimal_literal("201", 38, 3), // x
+ decimal_literal("301", 38, 4), // y
+ decimal_literal("2311", 38, 4)); // expected
+
+ AddAndVerify(decimal_literal("201", 38, 3), // x
+ decimal_literal("301", 38, 7), // y
+ decimal_literal("201030", 38, 6)); // expected
+
+ AddAndVerify(decimal_literal("1201", 38, 3), // x
+ decimal_literal("1801", 38, 3), // y
+ decimal_literal("3002", 38, 3)); // carry-over from fractional
+
+ // max precision
+ AddAndVerify(decimal_literal("09999999999999999999999999999999000000", 38, 5), // x
+ decimal_literal("100", 38, 7), // y
+ decimal_literal("99999999999999999999999999999990000010", 38, 6));
+
+ AddAndVerify(decimal_literal("-09999999999999999999999999999999000000", 38, 5), // x
+ decimal_literal("100", 38, 7), // y
+ decimal_literal("-99999999999999999999999999999989999990", 38, 6));
+
+ AddAndVerify(decimal_literal("09999999999999999999999999999999000000", 38, 5), // x
+ decimal_literal("-100", 38, 7), // y
+ decimal_literal("99999999999999999999999999999989999990", 38, 6));
+
+ AddAndVerify(decimal_literal("-09999999999999999999999999999999000000", 38, 5), // x
+ decimal_literal("-100", 38, 7), // y
+ decimal_literal("-99999999999999999999999999999990000010", 38, 6));
+
+ AddAndVerify(decimal_literal("09999999999999999999999999999999999999", 38, 6), // x
+ decimal_literal("89999999999999999999999999999999999999", 38, 7), // y
+ decimal_literal("18999999999999999999999999999999999999", 38, 6));
+
+ // Both -ve
+ AddAndVerify(decimal_literal("-201", 30, 3), // x
+ decimal_literal("-301", 30, 2), // y
+ decimal_literal("-3211", 32, 3)); // expected
+
+ AddAndVerify(decimal_literal("-201", 38, 3), // x
+ decimal_literal("-301", 38, 4), // y
+ decimal_literal("-2311", 38, 4)); // expected
+
+ // Mix of +ve and -ve
+ AddAndVerify(decimal_literal("-201", 30, 3), // x
+ decimal_literal("301", 30, 2), // y
+ decimal_literal("2809", 32, 3)); // expected
+
+ AddAndVerify(decimal_literal("-201", 38, 3), // x
+ decimal_literal("301", 38, 4), // y
+ decimal_literal("-1709", 38, 4)); // expected
+
+ AddAndVerify(decimal_literal("201", 38, 3), // x
+ decimal_literal("-301", 38, 7), // y
+ decimal_literal("200970", 38, 6)); // expected
+
+ AddAndVerify(decimal_literal("-1901", 38, 4), // x
+ decimal_literal("1801", 38, 4), // y
+ decimal_literal("-100", 38, 4)); // expected
+
+ AddAndVerify(decimal_literal("1801", 38, 4), // x
+ decimal_literal("-1901", 38, 4), // y
+ decimal_literal("-100", 38, 4)); // expected
+
+ // rounding +ve
+ AddAndVerify(decimal_literal("1000999", 38, 6), // x
+ decimal_literal("10000999", 38, 7), // y
+ decimal_literal("2001099", 38, 6));
+
+ AddAndVerify(decimal_literal("1000999", 38, 6), // x
+ decimal_literal("10000995", 38, 7), // y
+ decimal_literal("2001099", 38, 6));
+
+ AddAndVerify(decimal_literal("1000999", 38, 6), // x
+ decimal_literal("10000992", 38, 7), // y
+ decimal_literal("2001098", 38, 6));
+
+ // rounding -ve
+ AddAndVerify(decimal_literal("-1000999", 38, 6), // x
+ decimal_literal("-10000999", 38, 7), // y
+ decimal_literal("-2001099", 38, 6));
+
+ AddAndVerify(decimal_literal("-1000999", 38, 6), // x
+ decimal_literal("-10000995", 38, 7), // y
+ decimal_literal("-2001099", 38, 6));
+
+ AddAndVerify(decimal_literal("-1000999", 38, 6), // x
+ decimal_literal("-10000992", 38, 7), // y
+ decimal_literal("-2001098", 38, 6));
+}
+
+// subtract is a wrapper over add. so, minimal tests are sufficient.
+TEST_F(TestDecimalOps, TestSubtract) {
+ // fast-path
+ SubtractAndVerify(decimal_literal("201", 30, 3), // x
+ decimal_literal("301", 30, 3), // y
+ decimal_literal("-100", 31, 3)); // expected
+
+ // max precision
+ SubtractAndVerify(
+ decimal_literal("09999999999999999999999999999999000000", 38, 5), // x
+ decimal_literal("100", 38, 7), // y
+ decimal_literal("99999999999999999999999999999989999990", 38, 6));
+
+ // Mix of +ve and -ve
+ SubtractAndVerify(decimal_literal("-201", 30, 3), // x
+ decimal_literal("301", 30, 2), // y
+ decimal_literal("-3211", 32, 3)); // expected
+}
+
+// Lots of unit tests for multiply/divide/mod in decimal_ops_test.cc. So, keeping these
+// basic.
+TEST_F(TestDecimalOps, TestMultiply) {
+ // fast-path
+ MultiplyAndVerify(decimal_literal("201", 10, 3), // x
+ decimal_literal("301", 10, 2), // y
+ decimal_literal("60501", 21, 5)); // expected
+
+ // max precision
+ MultiplyAndVerify(DecimalScalar128(std::string(35, '9'), 38, 20), // x
+ DecimalScalar128(std::string(36, '9'), 38, 20), // x
+ DecimalScalar128("9999999999999999999999999999999999890", 38, 6));
+}
+
+TEST_F(TestDecimalOps, TestDivide) {
+ DivideAndVerify(decimal_literal("201", 10, 3), // x
+ decimal_literal("301", 10, 2), // y
+ decimal_literal("6677740863787", 23, 14)); // expected
+
+ DivideAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20), // x
+ DecimalScalar128(std::string(35, '9'), 38, 20), // x
+ DecimalScalar128("1000000000", 38, 6));
+}
+
+TEST_F(TestDecimalOps, TestMod) {
+ ModAndVerify(decimal_literal("201", 20, 2), // x
+ decimal_literal("301", 20, 3), // y
+ decimal_literal("204", 20, 3)); // expected
+
+ ModAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20), // x
+ DecimalScalar128(std::string(35, '9'), 38, 21), // x
+ DecimalScalar128("9990", 38, 21));
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/decimal_test.cc b/src/arrow/cpp/src/gandiva/tests/decimal_test.cc
new file mode 100644
index 000000000..31f2dedf5
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/decimal_test.cc
@@ -0,0 +1,1194 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <sstream>
+
+#include <gtest/gtest.h>
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/util/decimal.h"
+
+#include "gandiva/decimal_type_util.h"
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+using arrow::boolean;
+using arrow::Decimal128;
+using arrow::utf8;
+
+namespace gandiva {
+
+class TestDecimal : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ std::vector<Decimal128> MakeDecimalVector(std::vector<std::string> values,
+ int32_t scale);
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+std::vector<Decimal128> TestDecimal::MakeDecimalVector(std::vector<std::string> values,
+ int32_t scale) {
+ std::vector<arrow::Decimal128> ret;
+ for (auto str : values) {
+ Decimal128 str_value;
+ int32_t str_precision;
+ int32_t str_scale;
+
+ DCHECK_OK(Decimal128::FromString(str, &str_value, &str_precision, &str_scale));
+
+ Decimal128 scaled_value;
+ if (str_scale == scale) {
+ scaled_value = str_value;
+ } else {
+ scaled_value = str_value.Rescale(str_scale, scale).ValueOrDie();
+ }
+ ret.push_back(scaled_value);
+ }
+ return ret;
+}
+
+TEST_F(TestDecimal, TestSimple) {
+ // schema for input fields
+ constexpr int32_t precision = 36;
+ constexpr int32_t scale = 18;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+ auto field_a = field("a", decimal_type);
+ auto field_b = field("b", decimal_type);
+ auto field_c = field("c", decimal_type);
+ auto schema = arrow::schema({field_a, field_b, field_c});
+
+ Decimal128TypePtr add2_type;
+ auto status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd,
+ {decimal_type, decimal_type}, &add2_type);
+
+ Decimal128TypePtr output_type;
+ status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd,
+ {add2_type, decimal_type}, &output_type);
+
+ // output fields
+ auto res = field("res0", output_type);
+
+ // build expression : a + b + c
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto node_b = TreeExprBuilder::MakeField(field_b);
+ auto node_c = TreeExprBuilder::MakeField(field_c);
+ auto add2 = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, add2_type);
+ auto add3 = TreeExprBuilder::MakeFunction("add", {add2, node_c}, output_type);
+ auto expr = TreeExprBuilder::MakeExpression(add3, res);
+
+ // Build a projector for the expression.
+ std::shared_ptr<Projector> projector;
+ status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ DCHECK_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a =
+ MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "2", "3", "4"}, scale),
+ {false, true, true, true});
+ auto array_b =
+ MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"2", "3", "4", "5"}, scale),
+ {false, true, true, true});
+ auto array_c =
+ MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"3", "4", "5", "6"}, scale),
+ {true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch =
+ arrow::RecordBatch::Make(schema, num_records, {array_a, array_b, array_c});
+
+ auto expected =
+ MakeArrowArrayDecimal(output_type, MakeDecimalVector({"6", "9", "12", "15"}, scale),
+ {false, true, true, true});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ DCHECK_OK(status);
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(expected, outputs[0]);
+}
+
+TEST_F(TestDecimal, TestLiteral) {
+ // schema for input fields
+ constexpr int32_t precision = 36;
+ constexpr int32_t scale = 18;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+ auto field_a = field("a", decimal_type);
+ auto schema = arrow::schema({
+ field_a,
+ });
+
+ Decimal128TypePtr add2_type;
+ auto status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd,
+ {decimal_type, decimal_type}, &add2_type);
+
+ // output fields
+ auto res = field("res0", add2_type);
+
+ // build expression : a + b + c
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ static std::string decimal_point_six = "6";
+ DecimalScalar128 literal(decimal_point_six, 2, 1);
+ auto node_b = TreeExprBuilder::MakeDecimalLiteral(literal);
+ auto add2 = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, add2_type);
+ auto expr = TreeExprBuilder::MakeExpression(add2, res);
+
+ // Build a projector for the expression.
+ std::shared_ptr<Projector> projector;
+ status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ DCHECK_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a =
+ MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "2", "3", "4"}, scale),
+ {false, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ auto expected = MakeArrowArrayDecimal(
+ add2_type, MakeDecimalVector({"1.6", "2.6", "3.6", "4.6"}, scale),
+ {false, true, true, true});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ DCHECK_OK(status);
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(expected, outputs[0]);
+}
+
+TEST_F(TestDecimal, TestIfElse) {
+ // schema for input fields
+ constexpr int32_t precision = 36;
+ constexpr int32_t scale = 18;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+ auto field_a = field("a", decimal_type);
+ auto field_b = field("b", decimal_type);
+ auto field_c = field("c", arrow::boolean());
+ auto schema = arrow::schema({field_a, field_b, field_c});
+
+ // output fields
+ auto field_result = field("res", decimal_type);
+
+ // build expression.
+ // if (c)
+ // a
+ // else
+ // b
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto node_b = TreeExprBuilder::MakeField(field_b);
+ auto node_c = TreeExprBuilder::MakeField(field_c);
+ auto if_node = TreeExprBuilder::MakeIf(node_c, node_a, node_b, decimal_type);
+
+ auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ Status status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ DCHECK_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a =
+ MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "2", "3", "4"}, scale),
+ {false, true, true, true});
+ auto array_b =
+ MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"2", "3", "4", "5"}, scale),
+ {true, true, true, true});
+
+ auto array_c = MakeArrowArrayBool({true, false, true, false}, {true, true, true, true});
+
+ // expected output
+ auto exp =
+ MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"0", "3", "3", "5"}, scale),
+ {false, true, true, true});
+
+ // prepare input record batch
+ auto in_batch =
+ arrow::RecordBatch::Make(schema, num_records, {array_a, array_b, array_c});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ DCHECK_OK(status);
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestDecimal, TestCompare) {
+ // schema for input fields
+ constexpr int32_t precision = 36;
+ constexpr int32_t scale = 18;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+ auto field_a = field("a", decimal_type);
+ auto field_b = field("b", decimal_type);
+ auto schema = arrow::schema({field_a, field_b});
+
+ // build expressions
+ auto exprs = std::vector<ExpressionPtr>{
+ TreeExprBuilder::MakeExpression("equal", {field_a, field_b},
+ field("res_eq", boolean())),
+ TreeExprBuilder::MakeExpression("not_equal", {field_a, field_b},
+ field("res_ne", boolean())),
+ TreeExprBuilder::MakeExpression("less_than", {field_a, field_b},
+ field("res_lt", boolean())),
+ TreeExprBuilder::MakeExpression("less_than_or_equal_to", {field_a, field_b},
+ field("res_le", boolean())),
+ TreeExprBuilder::MakeExpression("greater_than", {field_a, field_b},
+ field("res_gt", boolean())),
+ TreeExprBuilder::MakeExpression("greater_than_or_equal_to", {field_a, field_b},
+ field("res_ge", boolean())),
+ };
+
+ // Build a projector for the expression.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector);
+ DCHECK_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a =
+ MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "2", "3", "-4"}, scale),
+ {true, true, true, true});
+ auto array_b =
+ MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "3", "2", "-3"}, scale),
+ {true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ DCHECK_OK(status);
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({true, false, false, false}),
+ outputs[0]); // equal
+ EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({false, true, true, true}),
+ outputs[1]); // not_equal
+ EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({false, true, false, true}),
+ outputs[2]); // less_than
+ EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({true, true, false, true}),
+ outputs[3]); // less_than_or_equal_to
+ EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({false, false, true, false}),
+ outputs[4]); // greater_than
+ EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({true, false, true, false}),
+ outputs[5]); // greater_than_or_equal_to
+}
+
+// ARROW-9092: This test is conditionally disabled when building with LLVM 9
+// because it hangs.
+#if GANDIVA_LLVM_VERSION != 9
+
+TEST_F(TestDecimal, TestRoundFunctions) {
+ // schema for input fields
+ constexpr int32_t precision = 38;
+ constexpr int32_t scale = 2;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+ auto field_a = field("a", decimal_type);
+ auto schema = arrow::schema({field_a});
+
+ auto scale_1 = TreeExprBuilder::MakeLiteral(1);
+
+ // build expressions
+ auto exprs = std::vector<ExpressionPtr>{
+ TreeExprBuilder::MakeExpression("abs", {field_a}, field("res_abs", decimal_type)),
+ TreeExprBuilder::MakeExpression("ceil", {field_a},
+ field("res_ceil", arrow::decimal(precision, 0))),
+ TreeExprBuilder::MakeExpression("floor", {field_a},
+ field("res_floor", arrow::decimal(precision, 0))),
+ TreeExprBuilder::MakeExpression("round", {field_a},
+ field("res_round", arrow::decimal(precision, 0))),
+ TreeExprBuilder::MakeExpression(
+ "truncate", {field_a}, field("res_truncate", arrow::decimal(precision, 0))),
+
+ TreeExprBuilder::MakeExpression(
+ TreeExprBuilder::MakeFunction("round",
+ {TreeExprBuilder::MakeField(field_a), scale_1},
+ arrow::decimal(precision, 1)),
+ field("res_round_3", arrow::decimal(precision, 1))),
+
+ TreeExprBuilder::MakeExpression(
+ TreeExprBuilder::MakeFunction("truncate",
+ {TreeExprBuilder::MakeField(field_a), scale_1},
+ arrow::decimal(precision, 1)),
+ field("res_truncate_3", arrow::decimal(precision, 1))),
+ };
+
+ // Build a projector for the expression.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector);
+ DCHECK_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto validity = {true, true, true, true};
+ auto array_a = MakeArrowArrayDecimal(
+ decimal_type, MakeDecimalVector({"1.23", "1.58", "-1.23", "-1.58"}, scale),
+ validity);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ DCHECK_OK(status);
+
+ // Validate results
+
+ // abs(x)
+ EXPECT_ARROW_ARRAY_EQUALS(
+ MakeArrowArrayDecimal(decimal_type,
+ MakeDecimalVector({"1.23", "1.58", "1.23", "1.58"}, scale),
+ validity),
+ outputs[0]);
+
+ // ceil(x)
+ EXPECT_ARROW_ARRAY_EQUALS(
+ MakeArrowArrayDecimal(arrow::decimal(precision, 0),
+ MakeDecimalVector({"2", "2", "-1", "-1"}, 0), validity),
+ outputs[1]);
+
+ // floor(x)
+ EXPECT_ARROW_ARRAY_EQUALS(
+ MakeArrowArrayDecimal(arrow::decimal(precision, 0),
+ MakeDecimalVector({"1", "1", "-2", "-2"}, 0), validity),
+ outputs[2]);
+
+ // round(x)
+ EXPECT_ARROW_ARRAY_EQUALS(
+ MakeArrowArrayDecimal(arrow::decimal(precision, 0),
+ MakeDecimalVector({"1", "2", "-1", "-2"}, 0), validity),
+ outputs[3]);
+
+ // truncate(x)
+ EXPECT_ARROW_ARRAY_EQUALS(
+ MakeArrowArrayDecimal(arrow::decimal(precision, 0),
+ MakeDecimalVector({"1", "1", "-1", "-1"}, 0), validity),
+ outputs[4]);
+
+ // round(x, 1)
+ EXPECT_ARROW_ARRAY_EQUALS(
+ MakeArrowArrayDecimal(arrow::decimal(precision, 1),
+ MakeDecimalVector({"1.2", "1.6", "-1.2", "-1.6"}, 1),
+ validity),
+ outputs[5]);
+
+ // truncate(x, 1)
+ EXPECT_ARROW_ARRAY_EQUALS(
+ MakeArrowArrayDecimal(arrow::decimal(precision, 1),
+ MakeDecimalVector({"1.2", "1.5", "-1.2", "-1.5"}, 1),
+ validity),
+ outputs[6]);
+}
+
+#endif // GANDIVA_LLVM_VERSION != 9
+
+TEST_F(TestDecimal, TestCastFunctions) {
+ // schema for input fields
+ constexpr int32_t precision = 38;
+ constexpr int32_t scale = 2;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+ auto decimal_type_scale_1 = std::make_shared<arrow::Decimal128Type>(precision, 1);
+ auto field_int32 = field("int32", arrow::int32());
+ auto field_int64 = field("int64", arrow::int64());
+ auto field_float32 = field("float32", arrow::float32());
+ auto field_float64 = field("float64", arrow::float64());
+ auto field_dec = field("dec", decimal_type);
+ auto schema =
+ arrow::schema({field_int32, field_int64, field_float32, field_float64, field_dec});
+
+ // build expressions
+ auto exprs = std::vector<ExpressionPtr>{
+ TreeExprBuilder::MakeExpression("castDECIMAL", {field_int32},
+ field("int32_to_dec", decimal_type)),
+ TreeExprBuilder::MakeExpression("castDECIMAL", {field_int64},
+ field("int64_to_dec", decimal_type)),
+ TreeExprBuilder::MakeExpression("castDECIMAL", {field_float32},
+ field("float32_to_dec", decimal_type)),
+ TreeExprBuilder::MakeExpression("castDECIMAL", {field_float64},
+ field("float64_to_dec", decimal_type)),
+ TreeExprBuilder::MakeExpression("castDECIMAL", {field_dec},
+ field("dec_to_dec", decimal_type_scale_1)),
+ TreeExprBuilder::MakeExpression("castBIGINT", {field_dec},
+ field("dec_to_int64", arrow::int64())),
+ TreeExprBuilder::MakeExpression("castFLOAT8", {field_dec},
+ field("dec_to_float64", arrow::float64())),
+ };
+
+ // Build a projector for the expression.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector);
+ DCHECK_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto validity = {true, true, true, true};
+
+ auto array_int32 = MakeArrowArrayInt32({123, 158, -123, -158});
+ auto array_int64 = MakeArrowArrayInt64({123, 158, -123, -158});
+ auto array_float32 = MakeArrowArrayFloat32({1.23f, 1.58f, -1.23f, -1.58f});
+ auto array_float64 = MakeArrowArrayFloat64({1.23, 1.58, -1.23, -1.58});
+ auto array_dec = MakeArrowArrayDecimal(
+ decimal_type, MakeDecimalVector({"1.23", "1.58", "-1.23", "-1.58"}, scale),
+ validity);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(
+ schema, num_records,
+ {array_int32, array_int64, array_float32, array_float64, array_dec});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ DCHECK_OK(status);
+
+ // Validate results
+ auto expected_int_dec = MakeArrowArrayDecimal(
+ decimal_type, MakeDecimalVector({"123", "158", "-123", "-158"}, scale), validity);
+
+ // castDECIMAL(int32)
+ EXPECT_ARROW_ARRAY_EQUALS(expected_int_dec, outputs[0]);
+
+ // castDECIMAL(int64)
+ EXPECT_ARROW_ARRAY_EQUALS(expected_int_dec, outputs[1]);
+
+ // castDECIMAL(float32)
+ EXPECT_ARROW_ARRAY_EQUALS(array_dec, outputs[2]);
+
+ // castDECIMAL(float64)
+ EXPECT_ARROW_ARRAY_EQUALS(array_dec, outputs[3]);
+
+ // castDECIMAL(decimal)
+ EXPECT_ARROW_ARRAY_EQUALS(
+ MakeArrowArrayDecimal(arrow::decimal(precision, 1),
+ MakeDecimalVector({"1.2", "1.6", "-1.2", "-1.6"}, 1),
+ validity),
+ outputs[4]);
+
+ // castBIGINT(decimal)
+ EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayInt64({1, 2, -1, -2}), outputs[5]);
+
+ // castDOUBLE(decimal)
+ EXPECT_ARROW_ARRAY_EQUALS(array_float64, outputs[6]);
+}
+
+// isnull, isnumeric
+TEST_F(TestDecimal, TestIsNullNumericFunctions) {
+ // schema for input fields
+ constexpr int32_t precision = 38;
+ constexpr int32_t scale = 2;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+ auto field_dec = field("dec", decimal_type);
+ auto schema = arrow::schema({field_dec});
+
+ // build expressions
+ auto exprs = std::vector<ExpressionPtr>{
+ TreeExprBuilder::MakeExpression("isnull", {field_dec},
+ field("isnull", arrow::boolean())),
+
+ TreeExprBuilder::MakeExpression("isnotnull", {field_dec},
+ field("isnotnull", arrow::boolean())),
+ TreeExprBuilder::MakeExpression("isnumeric", {field_dec},
+ field("isnumeric", arrow::boolean()))};
+
+ // Build a projector for the expression.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector);
+ DCHECK_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto validity = {false, true, true, true, false};
+
+ auto array_dec = MakeArrowArrayDecimal(
+ decimal_type, MakeDecimalVector({"1.51", "1.23", "1.23", "-1.23", "-1.24"}, scale),
+ validity);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_dec});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ DCHECK_OK(status);
+
+ // Validate results
+ auto is_null = outputs.at(0);
+ auto is_not_null = outputs.at(1);
+ auto is_numeric = outputs.at(2);
+
+ // isnull
+ EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({true, false, false, false, true}),
+ outputs[0]);
+
+ // isnotnull
+ EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool(validity), outputs[1]);
+
+ // isnumeric
+ EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool(validity), outputs[2]);
+}
+
+TEST_F(TestDecimal, TestIsDistinct) {
+ // schema for input fields
+ constexpr int32_t precision = 38;
+ constexpr int32_t scale_1 = 2;
+ auto decimal_type_1 = std::make_shared<arrow::Decimal128Type>(precision, scale_1);
+ auto field_dec_1 = field("dec_1", decimal_type_1);
+ constexpr int32_t scale_2 = 1;
+ auto decimal_type_2 = std::make_shared<arrow::Decimal128Type>(precision, scale_2);
+ auto field_dec_2 = field("dec_2", decimal_type_2);
+
+ auto schema = arrow::schema({field_dec_1, field_dec_2});
+
+ // build expressions
+ auto exprs = std::vector<ExpressionPtr>{
+ TreeExprBuilder::MakeExpression("is_distinct_from", {field_dec_1, field_dec_2},
+ field("isdistinct", arrow::boolean())),
+
+ TreeExprBuilder::MakeExpression("is_not_distinct_from", {field_dec_1, field_dec_2},
+ field("isnotdistinct", arrow::boolean()))};
+
+ // Build a projector for the expression.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector);
+ DCHECK_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+
+ auto validity_1 = {true, false, true, true};
+ auto array_dec_1 = MakeArrowArrayDecimal(
+ decimal_type_1, MakeDecimalVector({"1.51", "1.23", "1.20", "-1.20"}, scale_1),
+ validity_1);
+
+ auto validity_2 = {true, false, false, true};
+ auto array_dec_2 = MakeArrowArrayDecimal(
+ decimal_type_2, MakeDecimalVector({"1.5", "1.2", "1.2", "-1.2"}, scale_2),
+ validity_2);
+
+ // prepare input record batch
+ auto in_batch =
+ arrow::RecordBatch::Make(schema, num_records, {array_dec_1, array_dec_2});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ DCHECK_OK(status);
+
+ // Validate results
+ auto is_distinct = std::dynamic_pointer_cast<arrow::BooleanArray>(outputs.at(0));
+ auto is_not_distinct = std::dynamic_pointer_cast<arrow::BooleanArray>(outputs.at(1));
+
+ // isdistinct
+ EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({true, false, true, false}), outputs[0]);
+
+ // isnotdistinct
+ EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({false, true, false, true}), outputs[1]);
+}
+
+// decimal hashes without seed
+TEST_F(TestDecimal, TestHashFunctions) {
+ // schema for input fields
+ constexpr int32_t precision = 38;
+ constexpr int32_t scale = 2;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+ auto field_dec = field("dec", decimal_type);
+ auto literal_seed32 = TreeExprBuilder::MakeLiteral((int32_t)10);
+ auto literal_seed64 = TreeExprBuilder::MakeLiteral((int64_t)10);
+ auto schema = arrow::schema({field_dec});
+
+ // build expressions
+ auto exprs = std::vector<ExpressionPtr>{
+ TreeExprBuilder::MakeExpression("hash", {field_dec},
+ field("hash_of_dec", arrow::int32())),
+
+ TreeExprBuilder::MakeExpression("hash64", {field_dec},
+ field("hash64_of_dec", arrow::int64())),
+
+ TreeExprBuilder::MakeExpression("hash32AsDouble", {field_dec},
+ field("hash32_as_double", arrow::int32())),
+
+ TreeExprBuilder::MakeExpression("hash64AsDouble", {field_dec},
+ field("hash64_as_double", arrow::int64()))};
+
+ // Build a projector for the expression.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector);
+ DCHECK_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto validity = {false, true, true, true, true};
+
+ auto array_dec = MakeArrowArrayDecimal(
+ decimal_type, MakeDecimalVector({"1.51", "1.23", "1.23", "-1.23", "-1.24"}, scale),
+ validity);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_dec});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ DCHECK_OK(status);
+
+ // Validate results
+ auto int32_arr = std::dynamic_pointer_cast<arrow::Int32Array>(outputs.at(0));
+ EXPECT_EQ(int32_arr->null_count(), 0);
+ EXPECT_EQ(int32_arr->Value(0), 0);
+ EXPECT_EQ(int32_arr->Value(1), int32_arr->Value(2));
+ EXPECT_NE(int32_arr->Value(2), int32_arr->Value(3));
+ EXPECT_NE(int32_arr->Value(3), int32_arr->Value(4));
+
+ auto int64_arr = std::dynamic_pointer_cast<arrow::Int64Array>(outputs.at(1));
+ EXPECT_EQ(int64_arr->null_count(), 0);
+ EXPECT_EQ(int64_arr->Value(0), 0);
+ EXPECT_EQ(int64_arr->Value(1), int64_arr->Value(2));
+ EXPECT_NE(int64_arr->Value(2), int64_arr->Value(3));
+ EXPECT_NE(int64_arr->Value(3), int64_arr->Value(4));
+}
+
+TEST_F(TestDecimal, TestHash32WithSeed) {
+ constexpr int32_t precision = 38;
+ constexpr int32_t scale = 2;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+ auto field_dec_1 = field("dec1", decimal_type);
+ auto field_dec_2 = field("dec2", decimal_type);
+ auto schema = arrow::schema({field_dec_1, field_dec_2});
+
+ auto res = field("hash32_with_seed", arrow::int32());
+
+ auto field_1_nodePtr = TreeExprBuilder::MakeField(field_dec_1);
+ auto field_2_nodePtr = TreeExprBuilder::MakeField(field_dec_2);
+
+ auto hash32 =
+ TreeExprBuilder::MakeFunction("hash32", {field_2_nodePtr}, arrow::int32());
+ auto hash32_with_seed =
+ TreeExprBuilder::MakeFunction("hash32", {field_1_nodePtr, hash32}, arrow::int32());
+ auto expr = TreeExprBuilder::MakeExpression(hash32, field("hash32", arrow::int32()));
+ auto exprWS = TreeExprBuilder::MakeExpression(hash32_with_seed, res);
+
+ auto exprs = std::vector<ExpressionPtr>{expr, exprWS};
+
+ // Build a projector for the expression.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector);
+ DCHECK_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto validity_1 = {false, false, true, true, true};
+
+ auto array_dec_1 = MakeArrowArrayDecimal(
+ decimal_type, MakeDecimalVector({"1.51", "1.23", "1.23", "-1.23", "-1.24"}, scale),
+ validity_1);
+
+ auto validity_2 = {false, true, false, true, true};
+
+ auto array_dec_2 = MakeArrowArrayDecimal(
+ decimal_type, MakeDecimalVector({"1.51", "1.23", "1.23", "-1.23", "-1.24"}, scale),
+ validity_2);
+
+ // prepare input record batch
+ auto in_batch =
+ arrow::RecordBatch::Make(schema, num_records, {array_dec_1, array_dec_2});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ DCHECK_OK(status);
+
+ // Validate results
+ auto int32_arr = std::dynamic_pointer_cast<arrow::Int32Array>(outputs.at(0));
+ auto int32_arr_WS = std::dynamic_pointer_cast<arrow::Int32Array>(outputs.at(1));
+ EXPECT_EQ(int32_arr->null_count(), 0);
+ // seed 0, null decimal
+ EXPECT_EQ(int32_arr_WS->Value(0), 0);
+ // null decimal => hash = seed
+ EXPECT_EQ(int32_arr_WS->Value(1), int32_arr->Value(1));
+ // seed = 0 => hash = hash without seed
+ EXPECT_EQ(int32_arr_WS->Value(2), int32_arr->Value(1));
+ // different inputs => different outputs
+ EXPECT_NE(int32_arr_WS->Value(3), int32_arr_WS->Value(4));
+ // hash with, without seed are not equal
+ EXPECT_NE(int32_arr_WS->Value(4), int32_arr->Value(4));
+}
+
+TEST_F(TestDecimal, TestHash64WithSeed) {
+ constexpr int32_t precision = 38;
+ constexpr int32_t scale = 2;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+ auto field_dec_1 = field("dec1", decimal_type);
+ auto field_dec_2 = field("dec2", decimal_type);
+ auto schema = arrow::schema({field_dec_1, field_dec_2});
+
+ auto res = field("hash64_with_seed", arrow::int64());
+
+ auto field_1_nodePtr = TreeExprBuilder::MakeField(field_dec_1);
+ auto field_2_nodePtr = TreeExprBuilder::MakeField(field_dec_2);
+
+ auto hash64 =
+ TreeExprBuilder::MakeFunction("hash64", {field_2_nodePtr}, arrow::int64());
+ auto hash64_with_seed =
+ TreeExprBuilder::MakeFunction("hash64", {field_1_nodePtr, hash64}, arrow::int64());
+ auto expr = TreeExprBuilder::MakeExpression(hash64, field("hash64", arrow::int64()));
+ auto exprWS = TreeExprBuilder::MakeExpression(hash64_with_seed, res);
+
+ auto exprs = std::vector<ExpressionPtr>{expr, exprWS};
+
+ // Build a projector for the expression.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector);
+ DCHECK_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto validity_1 = {false, false, true, true, true};
+
+ auto array_dec_1 = MakeArrowArrayDecimal(
+ decimal_type, MakeDecimalVector({"1.51", "1.23", "1.23", "-1.23", "-1.24"}, scale),
+ validity_1);
+
+ auto validity_2 = {false, true, false, true, true};
+
+ auto array_dec_2 = MakeArrowArrayDecimal(
+ decimal_type, MakeDecimalVector({"1.51", "1.23", "1.23", "-1.23", "-1.24"}, scale),
+ validity_2);
+
+ // prepare input record batch
+ auto in_batch =
+ arrow::RecordBatch::Make(schema, num_records, {array_dec_1, array_dec_2});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ DCHECK_OK(status);
+
+ // Validate results
+ auto int64_arr = std::dynamic_pointer_cast<arrow::Int64Array>(outputs.at(0));
+ auto int64_arr_WS = std::dynamic_pointer_cast<arrow::Int64Array>(outputs.at(1));
+ EXPECT_EQ(int64_arr->null_count(), 0);
+ // seed 0, null decimal
+ EXPECT_EQ(int64_arr_WS->Value(0), 0);
+ // null decimal => hash = seed
+ EXPECT_EQ(int64_arr_WS->Value(1), int64_arr->Value(1));
+ // seed = 0 => hash = hash without seed
+ EXPECT_EQ(int64_arr_WS->Value(2), int64_arr->Value(1));
+ // different inputs => different outputs
+ EXPECT_NE(int64_arr_WS->Value(3), int64_arr_WS->Value(4));
+ // hash with, without seed are not equal
+ EXPECT_NE(int64_arr_WS->Value(4), int64_arr->Value(4));
+}
+
+TEST_F(TestDecimal, TestNullDecimalConstant) {
+ // schema for input fields
+ constexpr int32_t precision = 36;
+ constexpr int32_t scale = 18;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+ auto field_b = field("b", decimal_type);
+ auto field_c = field("c", arrow::boolean());
+ auto schema = arrow::schema({field_b, field_c});
+
+ // output fields
+ auto field_result = field("res", decimal_type);
+
+ // build expression.
+ // if (c)
+ // null
+ // else
+ // b
+ auto node_a = TreeExprBuilder::MakeNull(decimal_type);
+ auto node_b = TreeExprBuilder::MakeField(field_b);
+ auto node_c = TreeExprBuilder::MakeField(field_c);
+ auto if_node = TreeExprBuilder::MakeIf(node_c, node_a, node_b, decimal_type);
+
+ auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ Status status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ DCHECK_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+
+ auto array_b =
+ MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"2", "3", "4", "5"}, scale),
+ {true, true, true, true});
+
+ auto array_c = MakeArrowArrayBool({true, false, true, false}, {true, true, true, true});
+
+ // expected output
+ auto exp =
+ MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"0", "3", "3", "5"}, scale),
+ {false, true, false, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_b, array_c});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ DCHECK_OK(status);
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestDecimal, TestCastVarCharDecimal) {
+ // schema for input fields
+ constexpr int32_t precision = 38;
+ constexpr int32_t scale = 2;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+
+ auto field_dec = field("dec", decimal_type);
+ auto field_res_str = field("res_str", utf8());
+ auto field_res_str_1 = field("res_str_1", utf8());
+ auto schema = arrow::schema({field_dec, field_res_str, field_res_str_1});
+
+ // output fields
+ auto res_str = field("res_str", utf8());
+ auto equals_res_bool = field("equals_res", boolean());
+
+ // build expressions.
+ auto node_dec = TreeExprBuilder::MakeField(field_dec);
+ auto node_res_str = TreeExprBuilder::MakeField(field_res_str);
+ auto node_res_str_1 = TreeExprBuilder::MakeField(field_res_str_1);
+ // limits decimal string to input length
+ auto str_len_limit = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(5));
+ auto str_len_limit_1 = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(1));
+ auto cast_varchar =
+ TreeExprBuilder::MakeFunction("castVARCHAR", {node_dec, str_len_limit}, utf8());
+ auto cast_varchar_1 =
+ TreeExprBuilder::MakeFunction("castVARCHAR", {node_dec, str_len_limit_1}, utf8());
+ auto equals =
+ TreeExprBuilder::MakeFunction("equal", {cast_varchar, node_res_str}, boolean());
+ auto equals_1 =
+ TreeExprBuilder::MakeFunction("equal", {cast_varchar_1, node_res_str_1}, boolean());
+ auto expr = TreeExprBuilder::MakeExpression(equals, equals_res_bool);
+ auto expr_1 = TreeExprBuilder::MakeExpression(equals_1, equals_res_bool);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+
+ auto status = Projector::Make(schema, {expr, expr_1}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array_dec = MakeArrowArrayDecimal(
+ decimal_type,
+ MakeDecimalVector({"10.51", "1.23", "100.23", "-1000.23", "-0000.10"}, scale),
+ {true, false, true, true, true});
+ auto array_str_res = MakeArrowArrayUtf8({"10.51", "-null-", "100.2", "-1000", "-0.10"},
+ {true, false, true, true, true});
+ auto array_str_res_1 =
+ MakeArrowArrayUtf8({"1", "-null-", "1", "-", "-"}, {true, false, true, true, true});
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records,
+ {array_dec, array_str_res, array_str_res_1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ auto exp = MakeArrowArrayBool({true, false, true, true, true},
+ {true, false, true, true, true});
+ auto exp_1 = MakeArrowArrayBool({true, false, true, true, true},
+ {true, false, true, true, true});
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs[0]);
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs[1]);
+}
+
+TEST_F(TestDecimal, TestCastDecimalVarChar) {
+ // schema for input fields
+ constexpr int32_t precision = 4;
+ constexpr int32_t scale = 2;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+
+ auto field_str = field("in_str", utf8());
+ auto schema = arrow::schema({field_str});
+
+ // output fields
+ auto res_dec = field("res_dec", decimal_type);
+
+ // build expressions.
+ auto node_str = TreeExprBuilder::MakeField(field_str);
+ auto cast_decimal =
+ TreeExprBuilder::MakeFunction("castDECIMAL", {node_str}, decimal_type);
+ auto expr = TreeExprBuilder::MakeExpression(cast_decimal, res_dec);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+
+ auto array_str = MakeArrowArrayUtf8({"10.5134", "-0.0", "-0.1", "10.516", "-1000"},
+ {true, false, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_str});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ auto array_dec = MakeArrowArrayDecimal(
+ decimal_type, MakeDecimalVector({"10.51", "1.23", "-0.10", "10.52", "0.00"}, scale),
+ {true, false, true, true, true});
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(array_dec, outputs[0]);
+}
+
+TEST_F(TestDecimal, TestCastDecimalVarCharInvalidInput) {
+ // schema for input fields
+ constexpr int32_t precision = 38;
+ constexpr int32_t scale = 0;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+
+ auto field_str = field("in_str", utf8());
+ auto schema = arrow::schema({field_str});
+
+ // output fields
+ auto res_dec = field("res_dec", decimal_type);
+
+ // build expressions.
+ auto node_str = TreeExprBuilder::MakeField(field_str);
+ auto cast_decimal =
+ TreeExprBuilder::MakeFunction("castDECIMAL", {node_str}, decimal_type);
+ auto expr = TreeExprBuilder::MakeExpression(cast_decimal, res_dec);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+
+ // invalid input
+ auto invalid_in = MakeArrowArrayUtf8({"a10.5134", "-0.0", "-0.1", "10.516", "-1000"},
+ {true, false, true, true, true});
+
+ // prepare input record batch
+ auto in_batch_1 = arrow::RecordBatch::Make(schema, num_records, {invalid_in});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs_1;
+ status = projector->Evaluate(*in_batch_1, pool_, &outputs_1);
+ EXPECT_FALSE(status.ok()) << status.message();
+ EXPECT_NE(status.message().find("not a valid decimal128 number"), std::string::npos);
+}
+
+TEST_F(TestDecimal, TestVarCharDecimalNestedCast) {
+ // schema for input fields
+ constexpr int32_t precision = 38;
+ constexpr int32_t scale = 2;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+
+ auto field_dec = field("dec", decimal_type);
+ auto schema = arrow::schema({field_dec});
+
+ // output fields
+ auto field_dec_res = field("dec_res", decimal_type);
+
+ // build expressions.
+ auto node_dec = TreeExprBuilder::MakeField(field_dec);
+
+ // limits decimal string to input length
+ auto str_len_limit = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(5));
+ auto cast_varchar =
+ TreeExprBuilder::MakeFunction("castVARCHAR", {node_dec, str_len_limit}, utf8());
+ auto cast_decimal =
+ TreeExprBuilder::MakeFunction("castDECIMAL", {cast_varchar}, decimal_type);
+
+ auto expr = TreeExprBuilder::MakeExpression(cast_decimal, field_dec_res);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array_dec = MakeArrowArrayDecimal(
+ decimal_type,
+ MakeDecimalVector({"10.51", "1.23", "100.23", "-1000.23", "-0000.10"}, scale),
+ {true, false, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_dec});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ auto array_dec_res = MakeArrowArrayDecimal(
+ decimal_type,
+ MakeDecimalVector({"10.51", "1.23", "100.20", "-1000.00", "-0.10"}, scale),
+ {true, false, true, true, true});
+ EXPECT_ARROW_ARRAY_EQUALS(array_dec_res, outputs[0]);
+}
+
+TEST_F(TestDecimal, TestCastDecimalOverflow) {
+ // schema for input fields
+ constexpr int32_t precision_in = 5;
+ constexpr int32_t scale_in = 2;
+ constexpr int32_t precision_out = 3;
+ constexpr int32_t scale_out = 1;
+ auto decimal_5_2 = std::make_shared<arrow::Decimal128Type>(precision_in, scale_in);
+ auto decimal_3_1 = std::make_shared<arrow::Decimal128Type>(precision_out, scale_out);
+
+ auto field_dec = field("dec", decimal_5_2);
+ auto schema = arrow::schema({field_dec});
+
+ // build expressions
+ auto exprs = std::vector<ExpressionPtr>{
+ TreeExprBuilder::MakeExpression("castDECIMAL", {field_dec},
+ field("dec_to_dec", decimal_3_1)),
+ TreeExprBuilder::MakeExpression("castDECIMALNullOnOverflow", {field_dec},
+ field("dec_to_dec_null_overflow", decimal_3_1)),
+ };
+
+ // Build a projector for the expression.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector);
+ DCHECK_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto validity = {true, true, true, true};
+
+ auto array_dec = MakeArrowArrayDecimal(
+ decimal_5_2, MakeDecimalVector({"1.23", "671.58", "-1.23", "-1.58"}, scale_in),
+ validity);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_dec});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ DCHECK_OK(status);
+
+ // Validate results
+ // castDECIMAL(decimal)
+ EXPECT_ARROW_ARRAY_EQUALS(
+ MakeArrowArrayDecimal(arrow::decimal(precision_out, 1),
+ MakeDecimalVector({"1.2", "0.0", "-1.2", "-1.6"}, 1),
+ validity),
+ outputs[0]);
+
+ // castDECIMALNullOnOverflow(decimal)
+ EXPECT_ARROW_ARRAY_EQUALS(
+ MakeArrowArrayDecimal(arrow::decimal(precision_out, 1),
+ MakeDecimalVector({"1.2", "1.6", "-1.2", "-1.6"}, 1),
+ {true, false, true, true}),
+ outputs[1]);
+}
+
+TEST_F(TestDecimal, TestSha) {
+ // schema for input fields
+ const std::shared_ptr<arrow::DataType>& decimal_5_2 = arrow::decimal128(5, 2);
+ auto field_a = field("a", decimal_5_2);
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res_0 = field("res0", utf8());
+ auto res_1 = field("res1", utf8());
+
+ // build expressions.
+ // hashSHA1(a)
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto hashSha1 = TreeExprBuilder::MakeFunction("hashSHA1", {node_a}, utf8());
+ auto expr_0 = TreeExprBuilder::MakeExpression(hashSha1, res_0);
+
+ auto hashSha256 = TreeExprBuilder::MakeFunction("hashSHA256", {node_a}, utf8());
+ auto expr_1 = TreeExprBuilder::MakeExpression(hashSha256, res_1);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status =
+ Projector::Make(schema, {expr_0, expr_1}, TestConfiguration(), &projector);
+ ASSERT_OK(status) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 3;
+ auto validity_array = {false, true, true};
+
+ auto array_dec = MakeArrowArrayDecimal(
+ decimal_5_2, MakeDecimalVector({"3.45", "0", "0.01"}, 2), validity_array);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_dec});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ ASSERT_OK(status);
+
+ auto response = outputs.at(0);
+ EXPECT_EQ(response->null_count(), 0);
+ EXPECT_NE(response->GetScalar(0).ValueOrDie()->ToString(), "");
+
+ // Checks if the hash size in response is correct
+ const int sha1_hash_size = 40;
+ for (int i = 1; i < num_records; ++i) {
+ const auto& value_at_position = response->GetScalar(i).ValueOrDie()->ToString();
+
+ EXPECT_EQ(value_at_position.size(), sha1_hash_size);
+ EXPECT_NE(value_at_position, response->GetScalar(i - 1).ValueOrDie()->ToString());
+ }
+
+ response = outputs.at(1);
+ EXPECT_EQ(response->null_count(), 0);
+ EXPECT_NE(response->GetScalar(0).ValueOrDie()->ToString(), "");
+
+ // Checks if the hash size in response is correct
+ const int sha256_hash_size = 64;
+ for (int i = 1; i < num_records; ++i) {
+ const auto& value_at_position = response->GetScalar(i).ValueOrDie()->ToString();
+
+ EXPECT_EQ(value_at_position.size(), sha256_hash_size);
+ EXPECT_NE(value_at_position, response->GetScalar(i - 1).ValueOrDie()->ToString());
+ }
+}
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/filter_project_test.cc b/src/arrow/cpp/src/gandiva/tests/filter_project_test.cc
new file mode 100644
index 000000000..0607feaef
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/filter_project_test.cc
@@ -0,0 +1,276 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include "arrow/memory_pool.h"
+#include "gandiva/filter.h"
+#include "gandiva/projector.h"
+#include "gandiva/selection_vector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::float32;
+using arrow::int32;
+
+class TestFilterProject : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+TEST_F(TestFilterProject, TestSimple16) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto field1 = field("f1", int32());
+ auto field2 = field("f2", int32());
+ auto resultField = field("result", int32());
+ auto schema = arrow::schema({field0, field1, field2});
+
+ // Build condition f0 < f1
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ auto node_f1 = TreeExprBuilder::MakeField(field1);
+ auto node_f2 = TreeExprBuilder::MakeField(field2);
+ auto less_than_function =
+ TreeExprBuilder::MakeFunction("less_than", {node_f0, node_f1}, arrow::boolean());
+ auto condition = TreeExprBuilder::MakeCondition(less_than_function);
+ auto sum_expr = TreeExprBuilder::MakeExpression("add", {field1, field2}, resultField);
+
+ auto configuration = TestConfiguration();
+
+ std::shared_ptr<Filter> filter;
+ std::shared_ptr<Projector> projector;
+
+ auto status = Filter::Make(schema, condition, configuration, &filter);
+ EXPECT_TRUE(status.ok());
+
+ status = Projector::Make(schema, {sum_expr}, SelectionVector::MODE_UINT16,
+ configuration, &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array0 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, true});
+ auto array1 = MakeArrowArrayInt32({5, 9, 3, 17, 6}, {true, true, true, true, true});
+ auto array2 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, false});
+ // expected output
+ auto result = MakeArrowArrayInt32({6, 11, 0}, {true, true, false});
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2});
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+
+ status = projector->Evaluate(*in_batch, selection_vector.get(), pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(result, outputs.at(0));
+}
+
+TEST_F(TestFilterProject, TestSimple32) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto field1 = field("f1", int32());
+ auto field2 = field("f2", int32());
+ auto resultField = field("result", int32());
+ auto schema = arrow::schema({field0, field1, field2});
+
+ // Build condition f0 < f1
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ auto node_f1 = TreeExprBuilder::MakeField(field1);
+ auto node_f2 = TreeExprBuilder::MakeField(field2);
+ auto less_than_function =
+ TreeExprBuilder::MakeFunction("less_than", {node_f0, node_f1}, arrow::boolean());
+ auto condition = TreeExprBuilder::MakeCondition(less_than_function);
+ auto sum_expr = TreeExprBuilder::MakeExpression("add", {field1, field2}, resultField);
+
+ auto configuration = TestConfiguration();
+
+ std::shared_ptr<Filter> filter;
+ std::shared_ptr<Projector> projector;
+
+ auto status = Filter::Make(schema, condition, configuration, &filter);
+ EXPECT_TRUE(status.ok());
+
+ status = Projector::Make(schema, {sum_expr}, SelectionVector::MODE_UINT32,
+ configuration, &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array0 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, true});
+ auto array1 = MakeArrowArrayInt32({5, 9, 3, 17, 6}, {true, true, true, true, true});
+ auto array2 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, false});
+ // expected output
+ auto result = MakeArrowArrayInt32({6, 11, 0}, {true, true, false});
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2});
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt32(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+
+ status = projector->Evaluate(*in_batch, selection_vector.get(), pool_, &outputs);
+ ASSERT_OK(status);
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(result, outputs.at(0));
+}
+
+TEST_F(TestFilterProject, TestSimple64) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto field1 = field("f1", int32());
+ auto field2 = field("f2", int32());
+ auto resultField = field("result", int32());
+ auto schema = arrow::schema({field0, field1, field2});
+
+ // Build condition f0 < f1
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ auto node_f1 = TreeExprBuilder::MakeField(field1);
+ auto node_f2 = TreeExprBuilder::MakeField(field2);
+ auto less_than_function =
+ TreeExprBuilder::MakeFunction("less_than", {node_f0, node_f1}, arrow::boolean());
+ auto condition = TreeExprBuilder::MakeCondition(less_than_function);
+ auto sum_expr = TreeExprBuilder::MakeExpression("add", {field1, field2}, resultField);
+
+ auto configuration = TestConfiguration();
+
+ std::shared_ptr<Filter> filter;
+ std::shared_ptr<Projector> projector;
+
+ auto status = Filter::Make(schema, condition, configuration, &filter);
+ EXPECT_TRUE(status.ok());
+
+ status = Projector::Make(schema, {sum_expr}, SelectionVector::MODE_UINT64,
+ configuration, &projector);
+ ASSERT_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array0 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, true});
+ auto array1 = MakeArrowArrayInt32({5, 9, 3, 17, 6}, {true, true, true, true, true});
+ auto array2 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, false});
+ // expected output
+ auto result = MakeArrowArrayInt32({6, 11, 0}, {true, true, false});
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2});
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt64(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+
+ status = projector->Evaluate(*in_batch, selection_vector.get(), pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(result, outputs.at(0));
+}
+
+TEST_F(TestFilterProject, TestSimpleIf) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto fieldc = field("c", int32());
+ auto schema = arrow::schema({fielda, fieldb, fieldc});
+
+ // output fields
+ auto field_result = field("res", int32());
+
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto node_c = TreeExprBuilder::MakeField(fieldc);
+
+ auto greater_than_function =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
+ auto filter_condition = TreeExprBuilder::MakeCondition(greater_than_function);
+
+ auto project_condition =
+ TreeExprBuilder::MakeFunction("less_than", {node_b, node_c}, boolean());
+ auto if_node = TreeExprBuilder::MakeIf(project_condition, node_b, node_c, int32());
+
+ auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
+ auto configuration = TestConfiguration();
+
+ // Build a filter for the expressions.
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, filter_condition, configuration, &filter);
+ EXPECT_TRUE(status.ok());
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ status = Projector::Make(schema, {expr}, SelectionVector::MODE_UINT32, configuration,
+ &projector);
+ ASSERT_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 6;
+ auto array0 =
+ MakeArrowArrayInt32({10, 12, -20, 5, 21, 29}, {true, true, true, true, true, true});
+ auto array1 =
+ MakeArrowArrayInt32({5, 15, 15, 17, 12, 3}, {true, true, true, true, true, true});
+ auto array2 = MakeArrowArrayInt32({1, 25, 11, 30, -21, 30},
+ {true, true, true, true, true, false});
+
+ // Create a selection vector
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt32(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // expected output
+ auto exp = MakeArrowArrayInt32({1, -21, 0}, {true, true, false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2});
+
+ // Evaluate filter
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate project
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, selection_vector.get(), pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/filter_test.cc b/src/arrow/cpp/src/gandiva/tests/filter_test.cc
new file mode 100644
index 000000000..d4433f11e
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/filter_test.cc
@@ -0,0 +1,340 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/filter.h"
+#include <gtest/gtest.h>
+#include "arrow/memory_pool.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::float32;
+using arrow::int32;
+
+class TestFilter : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+TEST_F(TestFilter, TestFilterCache) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto field1 = field("f1", int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // Build condition f0 + f1 < 10
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ auto node_f1 = TreeExprBuilder::MakeField(field1);
+ auto sum_func =
+ TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32());
+ auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10);
+ auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10},
+ arrow::boolean());
+ auto condition = TreeExprBuilder::MakeCondition(less_than_10);
+ auto configuration = TestConfiguration();
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, configuration, &filter);
+ EXPECT_TRUE(status.ok());
+
+ // same schema and condition, should return the same filter as above.
+ std::shared_ptr<Filter> cached_filter;
+ status = Filter::Make(schema, condition, configuration, &cached_filter);
+ EXPECT_TRUE(status.ok());
+ EXPECT_TRUE(cached_filter.get() == filter.get());
+
+ // schema is different should return a new filter.
+ auto field2 = field("f2", int32());
+ auto different_schema = arrow::schema({field0, field1, field2});
+ std::shared_ptr<Filter> should_be_new_filter;
+ status =
+ Filter::Make(different_schema, condition, configuration, &should_be_new_filter);
+ EXPECT_TRUE(status.ok());
+ EXPECT_TRUE(cached_filter.get() != should_be_new_filter.get());
+
+ // condition is different, should return a new filter.
+ auto greater_than_10 = TreeExprBuilder::MakeFunction(
+ "greater_than", {sum_func, literal_10}, arrow::boolean());
+ auto new_condition = TreeExprBuilder::MakeCondition(greater_than_10);
+ std::shared_ptr<Filter> should_be_new_filter1;
+ status = Filter::Make(schema, new_condition, configuration, &should_be_new_filter1);
+ EXPECT_TRUE(status.ok());
+ EXPECT_TRUE(cached_filter.get() != should_be_new_filter1.get());
+}
+
+TEST_F(TestFilter, TestSimple) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto field1 = field("f1", int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // Build condition f0 + f1 < 10
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ auto node_f1 = TreeExprBuilder::MakeField(field1);
+ auto sum_func =
+ TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32());
+ auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10);
+ auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10},
+ arrow::boolean());
+ auto condition = TreeExprBuilder::MakeCondition(less_than_10);
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array0 = MakeArrowArrayInt32({1, 2, 3, 4, 6}, {true, true, true, false, true});
+ auto array1 = MakeArrowArrayInt32({5, 9, 6, 17, 3}, {true, true, false, true, true});
+ // expected output (indices for which condition matches)
+ auto exp = MakeArrowArrayUint16({0, 4});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+TEST_F(TestFilter, TestSimpleCustomConfig) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto field1 = field("f1", int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // Build condition f0 != f1
+ auto condition = TreeExprBuilder::MakeCondition("not_equal", {field0, field1});
+
+ ConfigurationBuilder config_builder;
+ std::shared_ptr<Configuration> config = config_builder.build();
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false});
+ auto array1 = MakeArrowArrayInt32({11, 2, 3, 17}, {true, true, false, true});
+ // expected output
+ auto exp = MakeArrowArrayUint16({0});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+TEST_F(TestFilter, TestZeroCopy) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto schema = arrow::schema({field0});
+
+ // Build condition
+ auto condition = TreeExprBuilder::MakeCondition("isnotnull", {field0});
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false});
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ // expected output
+ auto exp = MakeArrowArrayUint16({0, 1, 2});
+
+ // allocate selection buffers
+ int64_t data_sz = sizeof(int16_t) * num_records;
+ std::unique_ptr<uint8_t[]> data(new uint8_t[data_sz]);
+ std::shared_ptr<arrow::MutableBuffer> data_buf =
+ std::make_shared<arrow::MutableBuffer>(data.get(), data_sz);
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt16(num_records, data_buf, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+TEST_F(TestFilter, TestZeroCopyNegative) {
+ ArrayPtr output;
+
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto schema = arrow::schema({field0});
+
+ // Build expression
+ auto condition = TreeExprBuilder::MakeCondition("isnotnull", {field0});
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false});
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ // expected output
+ auto exp = MakeArrowArrayInt16({0, 1, 2});
+
+ // allocate output buffers
+ int64_t data_sz = sizeof(int16_t) * num_records;
+ std::unique_ptr<uint8_t[]> data(new uint8_t[data_sz]);
+ std::shared_ptr<arrow::MutableBuffer> data_buf =
+ std::make_shared<arrow::MutableBuffer>(data.get(), data_sz);
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt16(num_records, data_buf, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // the batch can't be empty.
+ auto bad_batch = arrow::RecordBatch::Make(schema, 0 /*num_records*/, {array0});
+ status = filter->Evaluate(*bad_batch, selection_vector);
+ EXPECT_EQ(status.code(), StatusCode::Invalid);
+
+ // the selection_vector can't be null.
+ std::shared_ptr<SelectionVector> null_selection;
+ status = filter->Evaluate(*in_batch, null_selection);
+ EXPECT_EQ(status.code(), StatusCode::Invalid);
+
+ // the selection vector must be suitably sized.
+ std::shared_ptr<SelectionVector> bad_selection;
+ status = SelectionVector::MakeInt16(num_records - 1, data_buf, &bad_selection);
+ EXPECT_TRUE(status.ok());
+
+ status = filter->Evaluate(*in_batch, bad_selection);
+ EXPECT_EQ(status.code(), StatusCode::Invalid);
+}
+
+TEST_F(TestFilter, TestSimpleSVInt32) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto field1 = field("f1", int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // Build condition f0 + f1 < 10
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ auto node_f1 = TreeExprBuilder::MakeField(field1);
+ auto sum_func =
+ TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32());
+ auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10);
+ auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10},
+ arrow::boolean());
+ auto condition = TreeExprBuilder::MakeCondition(less_than_10);
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array0 = MakeArrowArrayInt32({1, 2, 3, 4, 6}, {true, true, true, false, true});
+ auto array1 = MakeArrowArrayInt32({5, 9, 6, 17, 3}, {true, true, false, true, true});
+ // expected output (indices for which condition matches)
+ auto exp = MakeArrowArrayUint32({0, 4});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt32(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+TEST_F(TestFilter, TestOffset) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto field1 = field("f1", int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // Build condition f0 + f1 < 10
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ auto node_f1 = TreeExprBuilder::MakeField(field1);
+ auto sum_func =
+ TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32());
+ auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10);
+ auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10},
+ arrow::boolean());
+ auto condition = TreeExprBuilder::MakeCondition(less_than_10);
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array0 =
+ MakeArrowArrayInt32({0, 1, 2, 3, 4, 6}, {true, true, true, true, false, true});
+ array0 = array0->Slice(1);
+ auto array1 = MakeArrowArrayInt32({5, 9, 6, 17, 3}, {true, true, false, true, true});
+ // expected output (indices for which condition matches)
+ auto exp = MakeArrowArrayUint16({3});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+ in_batch = in_batch->Slice(1);
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/generate_data.h b/src/arrow/cpp/src/gandiva/tests/generate_data.h
new file mode 100644
index 000000000..9fb0e4eae
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/generate_data.h
@@ -0,0 +1,152 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <stdlib.h>
+#include <random>
+#include <string>
+
+#include "arrow/util/decimal.h"
+#include "arrow/util/io_util.h"
+
+#pragma once
+
+namespace gandiva {
+
+template <typename C_TYPE>
+class DataGenerator {
+ public:
+ virtual ~DataGenerator() = default;
+
+ virtual C_TYPE GenerateData() = 0;
+};
+
+class Random {
+ public:
+ Random() : gen_(::arrow::internal::GetRandomSeed()) {}
+ explicit Random(uint64_t seed) : gen_(seed) {}
+
+ int32_t next() { return gen_(); }
+
+ private:
+ std::default_random_engine gen_;
+};
+
+class Int32DataGenerator : public DataGenerator<int32_t> {
+ public:
+ Int32DataGenerator() {}
+
+ int32_t GenerateData() { return random_.next(); }
+
+ protected:
+ Random random_;
+};
+
+class BoundedInt32DataGenerator : public Int32DataGenerator {
+ public:
+ explicit BoundedInt32DataGenerator(uint32_t upperBound)
+ : Int32DataGenerator(), upperBound_(upperBound) {}
+
+ int32_t GenerateData() {
+ int32_t value = (random_.next() % upperBound_);
+ return value;
+ }
+
+ protected:
+ uint32_t upperBound_;
+};
+
+class Int64DataGenerator : public DataGenerator<int64_t> {
+ public:
+ Int64DataGenerator() {}
+
+ int64_t GenerateData() { return random_.next(); }
+
+ protected:
+ Random random_;
+};
+
+class Decimal128DataGenerator : public DataGenerator<arrow::Decimal128> {
+ public:
+ explicit Decimal128DataGenerator(bool large) : large_(large) {}
+
+ arrow::Decimal128 GenerateData() {
+ uint64_t low = random_.next();
+ int64_t high = random_.next();
+ if (large_) {
+ high += (1ull << 62);
+ }
+ return arrow::Decimal128(high, low);
+ }
+
+ protected:
+ bool large_;
+ Random random_;
+};
+
+class FastUtf8DataGenerator : public DataGenerator<std::string> {
+ public:
+ explicit FastUtf8DataGenerator(int max_len) : max_len_(max_len), cur_char_('a') {}
+
+ std::string GenerateData() {
+ std::string generated_str;
+
+ int slen = random_.next() % max_len_;
+ for (int i = 0; i < slen; ++i) {
+ generated_str += generate_next_char();
+ }
+ return generated_str;
+ }
+
+ private:
+ char generate_next_char() {
+ ++cur_char_;
+ if (cur_char_ > 'z') {
+ cur_char_ = 'a';
+ }
+ return cur_char_;
+ }
+
+ Random random_;
+ unsigned int max_len_;
+ char cur_char_;
+};
+
+class Utf8IntDataGenerator : public DataGenerator<std::string> {
+ public:
+ Utf8IntDataGenerator() {}
+
+ std::string GenerateData() { return std::to_string(random_.next()); }
+
+ private:
+ Random random_;
+};
+
+class Utf8FloatDataGenerator : public DataGenerator<std::string> {
+ public:
+ Utf8FloatDataGenerator() {}
+
+ std::string GenerateData() {
+ return std::to_string(
+ static_cast<float>(random_.next()) /
+ static_cast<float>(RAND_MAX / 100)); // random float between 0.0 to 100.0
+ }
+
+ private:
+ Random random_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/hash_test.cc b/src/arrow/cpp/src/gandiva/tests/hash_test.cc
new file mode 100644
index 000000000..40ebc50a2
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/hash_test.cc
@@ -0,0 +1,615 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <sstream>
+
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::float32;
+using arrow::float64;
+using arrow::int32;
+using arrow::int64;
+using arrow::utf8;
+
+class TestHash : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+TEST_F(TestHash, TestSimple) {
+ // schema for input fields
+ auto field_a = field("a", int32());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res_0 = field("res0", int32());
+ auto res_1 = field("res1", int64());
+
+ // build expression.
+ // hash32(a, 10)
+ // hash64(a)
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10);
+ auto hash32 = TreeExprBuilder::MakeFunction("hash32", {node_a, literal_10}, int32());
+ auto hash64 = TreeExprBuilder::MakeFunction("hash64", {node_a}, int64());
+ auto expr_0 = TreeExprBuilder::MakeExpression(hash32, res_0);
+ auto expr_1 = TreeExprBuilder::MakeExpression(hash64, res_1);
+
+ // Build a projector for the expression.
+ std::shared_ptr<Projector> projector;
+ auto status =
+ Projector::Make(schema, {expr_0, expr_1}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a = MakeArrowArrayInt32({1, 2, 3, 4}, {false, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ auto int32_arr = std::dynamic_pointer_cast<arrow::Int32Array>(outputs.at(0));
+ EXPECT_EQ(int32_arr->null_count(), 0);
+ EXPECT_EQ(int32_arr->Value(0), 10);
+ for (int i = 1; i < num_records; ++i) {
+ EXPECT_NE(int32_arr->Value(i), int32_arr->Value(i - 1));
+ }
+
+ auto int64_arr = std::dynamic_pointer_cast<arrow::Int64Array>(outputs.at(1));
+ EXPECT_EQ(int64_arr->null_count(), 0);
+ EXPECT_EQ(int64_arr->Value(0), 0);
+ for (int i = 1; i < num_records; ++i) {
+ EXPECT_NE(int64_arr->Value(i), int64_arr->Value(i - 1));
+ }
+}
+
+TEST_F(TestHash, TestBuf) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res_0 = field("res0", int32());
+ auto res_1 = field("res1", int64());
+
+ // build expressions.
+ // hash32(a)
+ // hash64(a, 10)
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto literal_10 = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(10));
+ auto hash32 = TreeExprBuilder::MakeFunction("hash32", {node_a}, int32());
+ auto hash64 = TreeExprBuilder::MakeFunction("hash64", {node_a, literal_10}, int64());
+ auto expr_0 = TreeExprBuilder::MakeExpression(hash32, res_0);
+ auto expr_1 = TreeExprBuilder::MakeExpression(hash64, res_1);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status =
+ Projector::Make(schema, {expr_0, expr_1}, TestConfiguration(), &projector);
+ ASSERT_OK(status) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a =
+ MakeArrowArrayUtf8({"foo", "hello", "bye", "hi"}, {false, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ ASSERT_OK(status);
+
+ // Validate results
+ auto int32_arr = std::dynamic_pointer_cast<arrow::Int32Array>(outputs.at(0));
+ EXPECT_EQ(int32_arr->null_count(), 0);
+ EXPECT_EQ(int32_arr->Value(0), 0);
+ for (int i = 1; i < num_records; ++i) {
+ EXPECT_NE(int32_arr->Value(i), int32_arr->Value(i - 1));
+ }
+
+ auto int64_arr = std::dynamic_pointer_cast<arrow::Int64Array>(outputs.at(1));
+ EXPECT_EQ(int64_arr->null_count(), 0);
+ EXPECT_EQ(int64_arr->Value(0), 10);
+ for (int i = 1; i < num_records; ++i) {
+ EXPECT_NE(int64_arr->Value(i), int64_arr->Value(i - 1));
+ }
+}
+
+TEST_F(TestHash, TestSha256Simple) {
+ // schema for input fields
+ auto field_a = field("a", int32());
+ auto field_b = field("b", int64());
+ auto field_c = field("c", float32());
+ auto field_d = field("d", float64());
+ auto schema = arrow::schema({field_a, field_b, field_c, field_d});
+
+ // output fields
+ auto res_0 = field("res0", utf8());
+ auto res_1 = field("res1", utf8());
+ auto res_2 = field("res2", utf8());
+ auto res_3 = field("res3", utf8());
+
+ // build expressions.
+ // hashSHA256(a)
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto hashSha256_1 = TreeExprBuilder::MakeFunction("hashSHA256", {node_a}, utf8());
+ auto expr_0 = TreeExprBuilder::MakeExpression(hashSha256_1, res_0);
+
+ auto node_b = TreeExprBuilder::MakeField(field_b);
+ auto hashSha256_2 = TreeExprBuilder::MakeFunction("hashSHA256", {node_b}, utf8());
+ auto expr_1 = TreeExprBuilder::MakeExpression(hashSha256_2, res_1);
+
+ auto node_c = TreeExprBuilder::MakeField(field_c);
+ auto hashSha256_3 = TreeExprBuilder::MakeFunction("hashSHA256", {node_c}, utf8());
+ auto expr_2 = TreeExprBuilder::MakeExpression(hashSha256_3, res_2);
+
+ auto node_d = TreeExprBuilder::MakeField(field_d);
+ auto hashSha256_4 = TreeExprBuilder::MakeFunction("hashSHA256", {node_d}, utf8());
+ auto expr_3 = TreeExprBuilder::MakeExpression(hashSha256_4, res_3);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr_0, expr_1, expr_2, expr_3},
+ TestConfiguration(), &projector);
+ ASSERT_OK(status) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 2;
+ auto validity_array = {false, true};
+
+ auto array_int32 = MakeArrowArrayInt32({1, 0}, validity_array);
+
+ auto array_int64 = MakeArrowArrayInt64({1, 0}, validity_array);
+
+ auto array_float32 = MakeArrowArrayFloat32({1.0, 0.0}, validity_array);
+
+ auto array_float64 = MakeArrowArrayFloat64({1.0, 0.0}, validity_array);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(
+ schema, num_records, {array_int32, array_int64, array_float32, array_float64});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ ASSERT_OK(status);
+
+ auto response_int32 = outputs.at(0);
+ auto response_int64 = outputs.at(1);
+ auto response_float32 = outputs.at(2);
+ auto response_float64 = outputs.at(3);
+
+ // Checks if the null and zero representation for numeric values
+ // are consistent between the types
+ EXPECT_ARROW_ARRAY_EQUALS(response_int32, response_int64);
+ EXPECT_ARROW_ARRAY_EQUALS(response_int64, response_float32);
+ EXPECT_ARROW_ARRAY_EQUALS(response_float32, response_float64);
+
+ const int sha256_hash_size = 64;
+
+ // Checks if the hash size in response is correct
+ for (int i = 1; i < num_records; ++i) {
+ const auto& value_at_position = response_int32->GetScalar(i).ValueOrDie()->ToString();
+
+ EXPECT_EQ(value_at_position.size(), sha256_hash_size);
+ EXPECT_NE(value_at_position,
+ response_int32->GetScalar(i - 1).ValueOrDie()->ToString());
+ }
+}
+
+TEST_F(TestHash, TestSha256Varlen) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res_0 = field("res0", utf8());
+
+ // build expressions.
+ // hashSHA256(a)
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto hashSha256 = TreeExprBuilder::MakeFunction("hashSHA256", {node_a}, utf8());
+ auto expr_0 = TreeExprBuilder::MakeExpression(hashSha256, res_0);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr_0}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 3;
+
+ std::string first_string =
+ "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeıʃn\nY "
+ "[ˈʏpsilɔn], Yen [jɛn], Yoga [ˈjoːgɑ]";
+ std::string second_string =
+ "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeın\nY "
+ "[ˈʏpsilɔn], Yen [jɛn], Yoga [ˈjoːgɑ] コンニチハ";
+
+ auto array_a =
+ MakeArrowArrayUtf8({"foo", first_string, second_string}, {false, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ ASSERT_OK(status);
+
+ auto response = outputs.at(0);
+ const int sha256_hash_size = 64;
+
+ EXPECT_EQ(response->null_count(), 0);
+
+ // Checks that the null value was hashed
+ EXPECT_NE(response->GetScalar(0).ValueOrDie()->ToString(), "");
+ EXPECT_EQ(response->GetScalar(0).ValueOrDie()->ToString().size(), sha256_hash_size);
+
+ // Check that all generated hashes were different
+ for (int i = 1; i < num_records; ++i) {
+ const auto& value_at_position = response->GetScalar(i).ValueOrDie()->ToString();
+
+ EXPECT_EQ(value_at_position.size(), sha256_hash_size);
+ EXPECT_NE(value_at_position, response->GetScalar(i - 1).ValueOrDie()->ToString());
+ }
+}
+
+TEST_F(TestHash, TestSha1Simple) {
+ // schema for input fields
+ auto field_a = field("a", int32());
+ auto field_b = field("b", int64());
+ auto field_c = field("c", float32());
+ auto field_d = field("d", float64());
+ auto schema = arrow::schema({field_a, field_b, field_c, field_d});
+
+ // output fields
+ auto res_0 = field("res0", utf8());
+ auto res_1 = field("res1", utf8());
+ auto res_2 = field("res2", utf8());
+ auto res_3 = field("res3", utf8());
+
+ // build expressions.
+ // hashSHA1(a)
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto hashSha1_1 = TreeExprBuilder::MakeFunction("hashSHA1", {node_a}, utf8());
+ auto expr_0 = TreeExprBuilder::MakeExpression(hashSha1_1, res_0);
+
+ auto node_b = TreeExprBuilder::MakeField(field_b);
+ auto hashSha1_2 = TreeExprBuilder::MakeFunction("hashSHA1", {node_b}, utf8());
+ auto expr_1 = TreeExprBuilder::MakeExpression(hashSha1_2, res_1);
+
+ auto node_c = TreeExprBuilder::MakeField(field_c);
+ auto hashSha1_3 = TreeExprBuilder::MakeFunction("hashSHA1", {node_c}, utf8());
+ auto expr_2 = TreeExprBuilder::MakeExpression(hashSha1_3, res_2);
+
+ auto node_d = TreeExprBuilder::MakeField(field_d);
+ auto hashSha1_4 = TreeExprBuilder::MakeFunction("hashSHA1", {node_d}, utf8());
+ auto expr_3 = TreeExprBuilder::MakeExpression(hashSha1_4, res_3);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr_0, expr_1, expr_2, expr_3},
+ TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 2;
+ auto validity_array = {false, true};
+
+ auto array_int32 = MakeArrowArrayInt32({1, 0}, validity_array);
+
+ auto array_int64 = MakeArrowArrayInt64({1, 0}, validity_array);
+
+ auto array_float32 = MakeArrowArrayFloat32({1.0, 0.0}, validity_array);
+
+ auto array_float64 = MakeArrowArrayFloat64({1.0, 0.0}, validity_array);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(
+ schema, num_records, {array_int32, array_int64, array_float32, array_float64});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ ASSERT_OK(status);
+
+ auto response_int32 = outputs.at(0);
+ auto response_int64 = outputs.at(1);
+ auto response_float32 = outputs.at(2);
+ auto response_float64 = outputs.at(3);
+
+ // Checks if the null and zero representation for numeric values
+ // are consistent between the types
+ EXPECT_ARROW_ARRAY_EQUALS(response_int32, response_int64);
+ EXPECT_ARROW_ARRAY_EQUALS(response_int64, response_float32);
+ EXPECT_ARROW_ARRAY_EQUALS(response_float32, response_float64);
+
+ const int sha1_hash_size = 40;
+
+ // Checks if the hash size in response is correct
+ for (int i = 1; i < num_records; ++i) {
+ const auto& value_at_position = response_int32->GetScalar(i).ValueOrDie()->ToString();
+
+ EXPECT_EQ(value_at_position.size(), sha1_hash_size);
+ EXPECT_NE(value_at_position,
+ response_int32->GetScalar(i - 1).ValueOrDie()->ToString());
+ }
+}
+
+TEST_F(TestHash, TestSha1Varlen) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res_0 = field("res0", utf8());
+
+ // build expressions.
+ // hashSHA1(a)
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto hashSha1 = TreeExprBuilder::MakeFunction("hashSHA1", {node_a}, utf8());
+ auto expr_0 = TreeExprBuilder::MakeExpression(hashSha1, res_0);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr_0}, TestConfiguration(), &projector);
+ ASSERT_OK(status) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 3;
+
+ std::string first_string =
+ "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeıʃn\nY [ˈʏpsilɔn], "
+ "Yen [jɛn], Yoga [ˈjoːgɑ]";
+ std::string second_string =
+ "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeın\nY [ˈʏpsilɔn], "
+ "Yen [jɛn], Yoga [ˈjoːgɑ] コンニチハ";
+
+ auto array_a =
+ MakeArrowArrayUtf8({"", first_string, second_string}, {false, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ ASSERT_OK(status);
+
+ auto response = outputs.at(0);
+ const int sha1_hash_size = 40;
+
+ EXPECT_EQ(response->null_count(), 0);
+
+ // Checks that the null value was hashed
+ EXPECT_NE(response->GetScalar(0).ValueOrDie()->ToString(), "");
+ EXPECT_EQ(response->GetScalar(0).ValueOrDie()->ToString().size(), sha1_hash_size);
+
+ // Check that all generated hashes were different
+ for (int i = 1; i < num_records; ++i) {
+ const auto& value_at_position = response->GetScalar(i).ValueOrDie()->ToString();
+
+ EXPECT_EQ(value_at_position.size(), sha1_hash_size);
+ EXPECT_NE(value_at_position, response->GetScalar(i - 1).ValueOrDie()->ToString());
+ }
+}
+
+TEST_F(TestHash, TestSha1FunctionsAlias) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto field_b = field("c", int64());
+ auto field_c = field("e", float64());
+ auto schema = arrow::schema({field_a, field_b, field_c});
+
+ // output fields
+ auto res_0 = field("res0", utf8());
+ auto res_0_sha1 = field("res0sha1", utf8());
+ auto res_0_sha = field("res0sha", utf8());
+
+ auto res_1 = field("res1", utf8());
+ auto res_1_sha1 = field("res1sha1", utf8());
+ auto res_1_sha = field("res1sha", utf8());
+
+ auto res_2 = field("res2", utf8());
+ auto res_2_sha1 = field("res2_sha1", utf8());
+ auto res_2_sha = field("res2_sha", utf8());
+
+ // build expressions.
+ // hashSHA1(a)
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto hashSha1 = TreeExprBuilder::MakeFunction("hashSHA1", {node_a}, utf8());
+ auto expr_0 = TreeExprBuilder::MakeExpression(hashSha1, res_0);
+ auto sha1 = TreeExprBuilder::MakeFunction("sha1", {node_a}, utf8());
+ auto expr_0_sha1 = TreeExprBuilder::MakeExpression(sha1, res_0_sha1);
+ auto sha = TreeExprBuilder::MakeFunction("sha", {node_a}, utf8());
+ auto expr_0_sha = TreeExprBuilder::MakeExpression(sha, res_0_sha);
+
+ auto node_b = TreeExprBuilder::MakeField(field_b);
+ auto hashSha1_1 = TreeExprBuilder::MakeFunction("hashSHA1", {node_b}, utf8());
+ auto expr_1 = TreeExprBuilder::MakeExpression(hashSha1_1, res_1);
+ auto sha1_1 = TreeExprBuilder::MakeFunction("sha1", {node_b}, utf8());
+ auto expr_1_sha1 = TreeExprBuilder::MakeExpression(sha1_1, res_1_sha1);
+ auto sha_1 = TreeExprBuilder::MakeFunction("sha", {node_b}, utf8());
+ auto expr_1_sha = TreeExprBuilder::MakeExpression(sha_1, res_1_sha);
+
+ auto node_c = TreeExprBuilder::MakeField(field_c);
+ auto hashSha1_2 = TreeExprBuilder::MakeFunction("hashSHA1", {node_c}, utf8());
+ auto expr_2 = TreeExprBuilder::MakeExpression(hashSha1_2, res_2);
+ auto sha1_2 = TreeExprBuilder::MakeFunction("sha1", {node_c}, utf8());
+ auto expr_2_sha1 = TreeExprBuilder::MakeExpression(sha1_2, res_2_sha1);
+ auto sha_2 = TreeExprBuilder::MakeFunction("sha", {node_c}, utf8());
+ auto expr_2_sha = TreeExprBuilder::MakeExpression(sha_2, res_2_sha);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema,
+ {expr_0, expr_0_sha, expr_0_sha1, expr_1, expr_1_sha,
+ expr_1_sha1, expr_2, expr_2_sha, expr_2_sha1},
+ TestConfiguration(), &projector);
+ ASSERT_OK(status) << status.message();
+
+ // Create a row-batch with some sample data
+ int32_t num_records = 3;
+
+ std::string first_string =
+ "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeıʃn\nY [ˈʏpsilɔn], "
+ "Yen [jɛn], Yoga [ˈjoːgɑ]";
+ std::string second_string =
+ "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeın\nY [ˈʏpsilɔn], "
+ "Yen [jɛn], Yoga [ˈjoːgɑ] コンニチハ";
+
+ auto array_utf8 =
+ MakeArrowArrayUtf8({"", first_string, second_string}, {false, true, true});
+
+ auto validity_array = {false, true, true};
+
+ auto array_int64 = MakeArrowArrayInt64({1, 0, 32423}, validity_array);
+
+ auto array_float64 = MakeArrowArrayFloat64({1.0, 0.0, 324893.3849}, validity_array);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records,
+ {array_utf8, array_int64, array_float64});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ ASSERT_OK(status);
+
+ // Checks that the response for the hashSHA1, sha and sha1 are equals for the first
+ // field of utf8 type
+ EXPECT_ARROW_ARRAY_EQUALS(outputs.at(0), outputs.at(1)); // hashSha1 and sha
+ EXPECT_ARROW_ARRAY_EQUALS(outputs.at(1), outputs.at(2)); // sha and sha1
+
+ // Checks that the response for the hashSHA1, sha and sha1 are equals for the second
+ // field of int64 type
+ EXPECT_ARROW_ARRAY_EQUALS(outputs.at(3), outputs.at(4)); // hashSha1 and sha
+ EXPECT_ARROW_ARRAY_EQUALS(outputs.at(4), outputs.at(5)); // sha and sha1
+
+ // Checks that the response for the hashSHA1, sha and sha1 are equals for the first
+ // field of float64 type
+ EXPECT_ARROW_ARRAY_EQUALS(outputs.at(6), outputs.at(7)); // hashSha1 and sha responses
+ EXPECT_ARROW_ARRAY_EQUALS(outputs.at(7), outputs.at(8)); // sha and sha1 responses
+}
+
+TEST_F(TestHash, TestSha256FunctionsAlias) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto field_b = field("c", int64());
+ auto field_c = field("e", float64());
+ auto schema = arrow::schema({field_a, field_b, field_c});
+
+ // output fields
+ auto res_0 = field("res0", utf8());
+ auto res_0_sha256 = field("res0sha256", utf8());
+
+ auto res_1 = field("res1", utf8());
+ auto res_1_sha256 = field("res1sha256", utf8());
+
+ auto res_2 = field("res2", utf8());
+ auto res_2_sha256 = field("res2_sha256", utf8());
+
+ // build expressions.
+ // hashSHA1(a)
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto hashSha2 = TreeExprBuilder::MakeFunction("hashSHA256", {node_a}, utf8());
+ auto expr_0 = TreeExprBuilder::MakeExpression(hashSha2, res_0);
+ auto sha256 = TreeExprBuilder::MakeFunction("sha256", {node_a}, utf8());
+ auto expr_0_sha256 = TreeExprBuilder::MakeExpression(sha256, res_0_sha256);
+
+ auto node_b = TreeExprBuilder::MakeField(field_b);
+ auto hashSha2_1 = TreeExprBuilder::MakeFunction("hashSHA256", {node_b}, utf8());
+ auto expr_1 = TreeExprBuilder::MakeExpression(hashSha2_1, res_1);
+ auto sha256_1 = TreeExprBuilder::MakeFunction("sha256", {node_b}, utf8());
+ auto expr_1_sha256 = TreeExprBuilder::MakeExpression(sha256_1, res_1_sha256);
+
+ auto node_c = TreeExprBuilder::MakeField(field_c);
+ auto hashSha2_2 = TreeExprBuilder::MakeFunction("hashSHA256", {node_c}, utf8());
+ auto expr_2 = TreeExprBuilder::MakeExpression(hashSha2_2, res_2);
+ auto sha256_2 = TreeExprBuilder::MakeFunction("sha256", {node_c}, utf8());
+ auto expr_2_sha256 = TreeExprBuilder::MakeExpression(sha256_2, res_2_sha256);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(
+ schema, {expr_0, expr_0_sha256, expr_1, expr_1_sha256, expr_2, expr_2_sha256},
+ TestConfiguration(), &projector);
+ ASSERT_OK(status) << status.message();
+
+ // Create a row-batch with some sample data
+ int32_t num_records = 3;
+
+ std::string first_string =
+ "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeıʃn\nY [ˈʏpsilɔn], "
+ "Yen [jɛn], Yoga [ˈjoːgɑ]";
+ std::string second_string =
+ "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeın\nY [ˈʏpsilɔn], "
+ "Yen [jɛn], Yoga [ˈjoːgɑ] コンニチハ";
+
+ auto array_utf8 =
+ MakeArrowArrayUtf8({"", first_string, second_string}, {false, true, true});
+
+ auto validity_array = {false, true, true};
+
+ auto array_int64 = MakeArrowArrayInt64({1, 0, 32423}, validity_array);
+
+ auto array_float64 = MakeArrowArrayFloat64({1.0, 0.0, 324893.3849}, validity_array);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records,
+ {array_utf8, array_int64, array_float64});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ ASSERT_OK(status);
+
+ // Checks that the response for the hashSHA2, sha256 and sha2 are equals for the first
+ // field of utf8 type
+ EXPECT_ARROW_ARRAY_EQUALS(outputs.at(0), outputs.at(1)); // hashSha2 and sha256
+
+ // Checks that the response for the hashSHA2, sha256 and sha2 are equals for the second
+ // field of int64 type
+ EXPECT_ARROW_ARRAY_EQUALS(outputs.at(2), outputs.at(3)); // hashSha2 and sha256
+
+ // Checks that the response for the hashSHA2, sha256 and sha2 are equals for the first
+ // field of float64 type
+ EXPECT_ARROW_ARRAY_EQUALS(outputs.at(4),
+ outputs.at(5)); // hashSha2 and sha256 responses
+}
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/huge_table_test.cc b/src/arrow/cpp/src/gandiva/tests/huge_table_test.cc
new file mode 100644
index 000000000..46f814b47
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/huge_table_test.cc
@@ -0,0 +1,157 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include "arrow/memory_pool.h"
+#include "gandiva/filter.h"
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::float32;
+using arrow::int32;
+
+class LARGE_MEMORY_TEST(TestHugeProjector) : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+class LARGE_MEMORY_TEST(TestHugeFilter) : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+TEST_F(LARGE_MEMORY_TEST(TestHugeProjector), SimpleTestSumHuge) {
+ auto atype = arrow::TypeTraits<arrow::Int32Type>::type_singleton();
+
+ // schema for input fields
+ auto field0 = field("f0", atype);
+ auto field1 = field("f1", atype);
+ auto schema = arrow::schema({field0, field1});
+
+ // output fields
+ auto field_sum = field("add", atype);
+
+ // Build expression
+ auto sum_expr = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum);
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {sum_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ // Cause an overflow in int32_t
+ int64_t num_records = static_cast<int64_t>(INT32_MAX) + 3;
+ std::vector<int32_t> input0 = {2, 29, 5, 37, 11, 59, 17, 19};
+ std::vector<int32_t> input1 = {23, 3, 31, 7, 41, 47, 13};
+ std::vector<bool> validity;
+
+ std::vector<int32_t> arr1;
+ std::vector<int32_t> arr2;
+ // expected output
+ std::vector<int32_t> sum1;
+
+ for (int64_t i = 0; i < num_records; i++) {
+ arr1.push_back(input0[i % 8]);
+ arr2.push_back(input1[i % 7]);
+ sum1.push_back(input0[i % 8] + input1[i % 7]);
+ validity.push_back(true);
+ }
+
+ auto exp_sum = MakeArrowArray<arrow::Int32Type, int32_t>(sum1, validity);
+ auto array0 = MakeArrowArray<arrow::Int32Type, int32_t>(arr1, validity);
+ auto array1 = MakeArrowArray<arrow::Int32Type, int32_t>(arr2, validity);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_sum, outputs.at(0));
+}
+
+TEST_F(LARGE_MEMORY_TEST(TestHugeFilter), TestSimpleHugeFilter) {
+ // Create a row-batch with some sample data
+ // Cause an overflow in int32_t
+ int64_t num_records = static_cast<int64_t>(INT32_MAX) + 3;
+ std::vector<int32_t> input0 = {2, 29, 5, 37, 11, 59, 17, 19};
+ std::vector<int32_t> input1 = {23, 3, 31, 7, 41, 47, 13};
+ std::vector<bool> validity;
+
+ std::vector<int32_t> arr1;
+ std::vector<int32_t> arr2;
+ // expected output
+ std::vector<uint64_t> sel;
+
+ for (int64_t i = 0; i < num_records; i++) {
+ arr1.push_back(input0[i % 8]);
+ arr2.push_back(input1[i % 7]);
+ if (input0[i % 8] + input1[i % 7] > 50) {
+ sel.push_back(i);
+ }
+ validity.push_back(true);
+ }
+
+ auto exp = MakeArrowArrayUint64(sel);
+
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto field1 = field("f1", int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // Build condition f0 + f1 < 50
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ auto node_f1 = TreeExprBuilder::MakeField(field1);
+ auto sum_func =
+ TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32());
+ auto literal_50 = TreeExprBuilder::MakeLiteral((int32_t)50);
+ auto less_than_50 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_50},
+ arrow::boolean());
+ auto condition = TreeExprBuilder::MakeCondition(less_than_50);
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+ EXPECT_TRUE(status.ok());
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arr1, arr2});
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt64(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/if_expr_test.cc b/src/arrow/cpp/src/gandiva/tests/if_expr_test.cc
new file mode 100644
index 000000000..54b6d43b4
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/if_expr_test.cc
@@ -0,0 +1,378 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::float32;
+using arrow::int32;
+
+class TestIfExpr : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+TEST_F(TestIfExpr, TestSimple) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto schema = arrow::schema({fielda, fieldb});
+
+ // output fields
+ auto field_result = field("res", int32());
+
+ // build expression.
+ // if (a > b)
+ // a
+ // else
+ // b
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto condition =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
+ auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, int32());
+
+ auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 = MakeArrowArrayInt32({10, 12, -20, 5}, {true, true, true, false});
+ auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true});
+
+ // expected output
+ auto exp = MakeArrowArrayInt32({10, 15, 15, 17}, {true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestIfExpr, TestSimpleArithmetic) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto schema = arrow::schema({fielda, fieldb});
+
+ // output fields
+ auto field_result = field("res", int32());
+
+ // build expression.
+ // if (a > b)
+ // a + b
+ // else
+ // a - b
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto condition =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
+ auto sum = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32());
+ auto sub = TreeExprBuilder::MakeFunction("subtract", {node_a, node_b}, int32());
+ auto if_node = TreeExprBuilder::MakeIf(condition, sum, sub, int32());
+
+ auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 = MakeArrowArrayInt32({10, 12, -20, 5}, {true, true, true, false});
+ auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true});
+
+ // expected output
+ auto exp = MakeArrowArrayInt32({15, -3, -35, 0}, {true, true, true, false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestIfExpr, TestNested) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto schema = arrow::schema({fielda, fieldb});
+
+ // output fields
+ auto field_result = field("res", int32());
+
+ // build expression.
+ // if (a > b)
+ // a + b
+ // else if (a < b)
+ // a - b
+ // else
+ // a * b
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto condition_gt =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
+ auto condition_lt =
+ TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean());
+ auto sum = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32());
+ auto sub = TreeExprBuilder::MakeFunction("subtract", {node_a, node_b}, int32());
+ auto mult = TreeExprBuilder::MakeFunction("multiply", {node_a, node_b}, int32());
+ auto else_node = TreeExprBuilder::MakeIf(condition_lt, sub, mult, int32());
+ auto if_node = TreeExprBuilder::MakeIf(condition_gt, sum, else_node, int32());
+
+ auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 = MakeArrowArrayInt32({10, 12, 15, 5}, {true, true, true, false});
+ auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true});
+
+ // expected output
+ auto exp = MakeArrowArrayInt32({15, -3, 225, 0}, {true, true, true, false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestIfExpr, TestNestedInIf) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto fieldc = field("c", int32());
+ auto schema = arrow::schema({fielda, fieldb, fieldc});
+
+ // output fields
+ auto field_result = field("res", int32());
+
+ // build expression.
+ // if (a > 10)
+ // if (a < 20)
+ // a + b
+ // else
+ // b + c
+ // else
+ // a + c
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto node_c = TreeExprBuilder::MakeField(fieldc);
+
+ auto literal_10 = TreeExprBuilder::MakeLiteral(10);
+ auto literal_20 = TreeExprBuilder::MakeLiteral(20);
+
+ auto gt_10 =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_10}, boolean());
+ auto lt_20 =
+ TreeExprBuilder::MakeFunction("less_than", {node_a, literal_20}, boolean());
+ auto sum_ab = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32());
+ auto sum_bc = TreeExprBuilder::MakeFunction("add", {node_b, node_c}, int32());
+ auto sum_ac = TreeExprBuilder::MakeFunction("add", {node_a, node_c}, int32());
+
+ auto if_lt_20 = TreeExprBuilder::MakeIf(lt_20, sum_ab, sum_bc, int32());
+ auto if_gt_10 = TreeExprBuilder::MakeIf(gt_10, if_lt_20, sum_ac, int32());
+
+ auto expr = TreeExprBuilder::MakeExpression(if_gt_10, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 6;
+ auto array_a =
+ MakeArrowArrayInt32({21, 15, 5, 22, 15, 5}, {true, true, true, true, true, true});
+ auto array_b = MakeArrowArrayInt32({20, 18, 19, 20, 18, 19},
+ {true, true, true, false, false, false});
+ auto array_c = MakeArrowArrayInt32({35, 45, 55, 35, 45, 55},
+ {true, true, true, false, false, false});
+
+ // expected output
+ auto exp =
+ MakeArrowArrayInt32({55, 33, 60, 0, 0, 0}, {true, true, true, false, false, false});
+
+ // prepare input record batch
+ auto in_batch =
+ arrow::RecordBatch::Make(schema, num_records, {array_a, array_b, array_c});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestIfExpr, TestNestedInCondition) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto schema = arrow::schema({fielda, fieldb});
+
+ // output fields
+ auto field_result = field("res", int32());
+
+ // build expression.
+ // if (if (a > b) then true else if (a < b) false else null)
+ // 1
+ // else if !(if (a > b) then true else if (a < b) false else null)
+ // 2
+ // else
+ // 3
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto literal_1 = TreeExprBuilder::MakeLiteral(1);
+ auto literal_2 = TreeExprBuilder::MakeLiteral(2);
+ auto literal_3 = TreeExprBuilder::MakeLiteral(3);
+ auto literal_true = TreeExprBuilder::MakeLiteral(true);
+ auto literal_false = TreeExprBuilder::MakeLiteral(false);
+ auto literal_null = TreeExprBuilder::MakeNull(boolean());
+
+ auto a_gt_b =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
+ auto a_lt_b = TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean());
+ auto cond_else =
+ TreeExprBuilder::MakeIf(a_lt_b, literal_false, literal_null, boolean());
+ auto cond_if = TreeExprBuilder::MakeIf(a_gt_b, literal_true, cond_else, boolean());
+ auto not_cond_if = TreeExprBuilder::MakeFunction("not", {cond_if}, boolean());
+
+ auto outer_else = TreeExprBuilder::MakeIf(not_cond_if, literal_2, literal_3, int32());
+ auto outer_if = TreeExprBuilder::MakeIf(cond_if, literal_1, outer_else, int32());
+ auto expr = TreeExprBuilder::MakeExpression(outer_if, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 6;
+ auto array_a =
+ MakeArrowArrayInt32({21, 15, 5, 22, 15, 5}, {true, true, true, true, true, true});
+ auto array_b = MakeArrowArrayInt32({20, 18, 19, 20, 18, 19},
+ {true, true, true, false, false, false});
+ // expected output
+ auto exp =
+ MakeArrowArrayInt32({1, 2, 2, 3, 3, 3}, {true, true, true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestIfExpr, TestBigNested) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto schema = arrow::schema({fielda});
+
+ // output fields
+ auto field_result = field("res", int32());
+
+ // build expression.
+ // if (a < 10)
+ // 10
+ // else if (a < 20)
+ // 20
+ // ..
+ // ..
+ // else if (a < 190)
+ // 190
+ // else
+ // 200
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto top_node = TreeExprBuilder::MakeLiteral(200);
+ for (int thresh = 190; thresh > 0; thresh -= 10) {
+ auto literal = TreeExprBuilder::MakeLiteral(thresh);
+ auto condition =
+ TreeExprBuilder::MakeFunction("less_than", {node_a, literal}, boolean());
+ auto if_node = TreeExprBuilder::MakeIf(condition, literal, top_node, int32());
+ top_node = if_node;
+ }
+ auto expr = TreeExprBuilder::MakeExpression(top_node, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 = MakeArrowArrayInt32({10, 102, 158, 302}, {true, true, true, true});
+
+ // expected output
+ auto exp = MakeArrowArrayInt32({20, 110, 160, 200}, {true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/in_expr_test.cc b/src/arrow/cpp/src/gandiva/tests/in_expr_test.cc
new file mode 100644
index 000000000..fc1a8a71b
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/in_expr_test.cc
@@ -0,0 +1,278 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include <cmath>
+
+#include "arrow/memory_pool.h"
+#include "gandiva/filter.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::float32;
+using arrow::float64;
+using arrow::int32;
+
+class TestIn : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+std::vector<Decimal128> MakeDecimalVector(std::vector<std::string> values) {
+ std::vector<arrow::Decimal128> ret;
+ for (auto str : values) {
+ Decimal128 decimal_value;
+ int32_t decimal_precision;
+ int32_t decimal_scale;
+
+ DCHECK_OK(
+ Decimal128::FromString(str, &decimal_value, &decimal_precision, &decimal_scale));
+
+ ret.push_back(decimal_value);
+ }
+ return ret;
+}
+
+TEST_F(TestIn, TestInSimple) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto field1 = field("f1", int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // Build In f0 + f1 in (6, 11)
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ auto node_f1 = TreeExprBuilder::MakeField(field1);
+ auto sum_func =
+ TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32());
+ std::unordered_set<int32_t> in_constants({6, 11});
+ auto in_expr = TreeExprBuilder::MakeInExpressionInt32(sum_func, in_constants);
+ auto condition = TreeExprBuilder::MakeCondition(in_expr);
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array0 = MakeArrowArrayInt32({1, 2, 3, 4, 6}, {true, true, true, false, true});
+ auto array1 = MakeArrowArrayInt32({5, 9, 6, 17, 5}, {true, true, false, true, false});
+ // expected output (indices for which condition matches)
+ auto exp = MakeArrowArrayUint16({0, 1});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+TEST_F(TestIn, TestInFloat) {
+ // schema for input fields
+ auto field0 = field("f0", float32());
+ auto schema = arrow::schema({field0});
+
+ // Build In f0 + f1 in (6, 11)
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+
+ std::unordered_set<float> in_constants({6.5f, 12.0f, 11.5f});
+ auto in_expr = TreeExprBuilder::MakeInExpressionFloat(node_f0, in_constants);
+ auto condition = TreeExprBuilder::MakeCondition(in_expr);
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array0 =
+ MakeArrowArrayFloat32({6.5f, 11.5f, 4, 3.15f, 6}, {true, true, false, true, true});
+ // expected output (indices for which condition matches)
+ auto exp = MakeArrowArrayUint16({0, 1});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+TEST_F(TestIn, TestInDouble) {
+ // schema for input fields
+ auto field0 = field("double0", float64());
+ auto field1 = field("double1", float64());
+ auto schema = arrow::schema({field0, field1});
+
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ auto node_f1 = TreeExprBuilder::MakeField(field1);
+ auto sum_func =
+ TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::float64());
+ std::unordered_set<double> in_constants({3.14159265359, 15.5555555});
+ auto in_expr = TreeExprBuilder::MakeInExpressionDouble(sum_func, in_constants);
+ auto condition = TreeExprBuilder::MakeCondition(in_expr);
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array0 = MakeArrowArrayFloat64({1, 2, 3, 4, 11}, {true, true, true, false, false});
+ auto array1 = MakeArrowArrayFloat64({5, 9, 0.14159265359, 17, 4.5555555},
+ {true, true, true, true, true});
+
+ // expected output (indices for which condition matches)
+ auto exp = MakeArrowArrayUint16({2});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+TEST_F(TestIn, TestInDecimal) {
+ int32_t precision = 38;
+ int32_t scale = 5;
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+
+ // schema for input fields
+ auto field0 = field("f0", arrow::decimal(precision, scale));
+ auto schema = arrow::schema({field0});
+
+ // Build In f0 + f1 in (6, 11)
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+
+ gandiva::DecimalScalar128 d0("6", precision, scale);
+ gandiva::DecimalScalar128 d1("12", precision, scale);
+ gandiva::DecimalScalar128 d2("11", precision, scale);
+ std::unordered_set<gandiva::DecimalScalar128> in_constants({d0, d1, d2});
+ auto in_expr = TreeExprBuilder::MakeInExpressionDecimal(node_f0, in_constants);
+ auto condition = TreeExprBuilder::MakeCondition(in_expr);
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto values0 = MakeDecimalVector({"1", "2", "0", "-6", "6"});
+ auto array0 =
+ MakeArrowArrayDecimal(decimal_type, values0, {true, true, true, false, true});
+ // expected output (indices for which condition matches)
+ auto exp = MakeArrowArrayUint16({4});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+TEST_F(TestIn, TestInString) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::utf8());
+ auto schema = arrow::schema({field0});
+
+ // Build f0 in ("test" ,"me")
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ std::unordered_set<std::string> in_constants({"test", "me"});
+ auto in_expr = TreeExprBuilder::MakeInExpressionString(node_f0, in_constants);
+
+ auto condition = TreeExprBuilder::MakeCondition(in_expr);
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array_a = MakeArrowArrayUtf8({"test", "lol", "me", "arrow", "test"},
+ {true, true, true, true, false});
+ // expected output (indices for which condition matches)
+ auto exp = MakeArrowArrayUint16({0, 2});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+TEST_F(TestIn, TestInStringValidationError) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::int32());
+ auto schema = arrow::schema({field0});
+
+ // Build f0 in ("test" ,"me")
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ std::unordered_set<std::string> in_constants({"test", "me"});
+ auto in_expr = TreeExprBuilder::MakeInExpressionString(node_f0, in_constants);
+ auto condition = TreeExprBuilder::MakeCondition(in_expr);
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+
+ EXPECT_TRUE(status.IsExpressionValidationError());
+ std::string expected_error = "Evaluation expression for IN clause returns ";
+ EXPECT_TRUE(status.message().find(expected_error) != std::string::npos);
+}
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/literal_test.cc b/src/arrow/cpp/src/gandiva/tests/literal_test.cc
new file mode 100644
index 000000000..b5ffff031
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/literal_test.cc
@@ -0,0 +1,232 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::float32;
+using arrow::float64;
+using arrow::int32;
+using arrow::int64;
+
+class TestLiteral : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+TEST_F(TestLiteral, TestSimpleArithmetic) {
+ // schema for input fields
+ auto field_a = field("a", boolean());
+ auto field_b = field("b", int32());
+ auto field_c = field("c", int64());
+ auto field_d = field("d", float32());
+ auto field_e = field("e", float64());
+ auto schema = arrow::schema({field_a, field_b, field_c, field_d, field_e});
+
+ // output fields
+ auto res_a = field("a+1", boolean());
+ auto res_b = field("b+1", int32());
+ auto res_c = field("c+1", int64());
+ auto res_d = field("d+1", float32());
+ auto res_e = field("e+1", float64());
+
+ // build expressions.
+ // a == true
+ // b + 1
+ // c + 1
+ // d + 1
+ // e + 1
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto literal_a = TreeExprBuilder::MakeLiteral(true);
+ auto func_a = TreeExprBuilder::MakeFunction("equal", {node_a, literal_a}, boolean());
+ auto expr_a = TreeExprBuilder::MakeExpression(func_a, res_a);
+
+ auto node_b = TreeExprBuilder::MakeField(field_b);
+ auto literal_b = TreeExprBuilder::MakeLiteral((int32_t)1);
+ auto func_b = TreeExprBuilder::MakeFunction("add", {node_b, literal_b}, int32());
+ auto expr_b = TreeExprBuilder::MakeExpression(func_b, res_b);
+
+ auto node_c = TreeExprBuilder::MakeField(field_c);
+ auto literal_c = TreeExprBuilder::MakeLiteral((int64_t)1);
+ auto func_c = TreeExprBuilder::MakeFunction("add", {node_c, literal_c}, int64());
+ auto expr_c = TreeExprBuilder::MakeExpression(func_c, res_c);
+
+ auto node_d = TreeExprBuilder::MakeField(field_d);
+ auto literal_d = TreeExprBuilder::MakeLiteral(static_cast<float>(1));
+ auto func_d = TreeExprBuilder::MakeFunction("add", {node_d, literal_d}, float32());
+ auto expr_d = TreeExprBuilder::MakeExpression(func_d, res_d);
+
+ auto node_e = TreeExprBuilder::MakeField(field_e);
+ auto literal_e = TreeExprBuilder::MakeLiteral(static_cast<double>(1));
+ auto func_e = TreeExprBuilder::MakeFunction("add", {node_e, literal_e}, float64());
+ auto expr_e = TreeExprBuilder::MakeExpression(func_e, res_e);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr_a, expr_b, expr_c, expr_d, expr_e},
+ TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a = MakeArrowArrayBool({true, true, false, true}, {true, true, true, false});
+ auto array_b = MakeArrowArrayInt32({5, 15, -15, 17}, {true, true, true, false});
+ auto array_c = MakeArrowArrayInt64({5, 15, -15, 17}, {true, true, true, false});
+ auto array_d = MakeArrowArrayFloat32({5.2f, 15, -15.6f, 17}, {true, true, true, false});
+ auto array_e = MakeArrowArrayFloat64({5.6f, 15, -15.9f, 17}, {true, true, true, false});
+
+ // expected output
+ auto exp_a = MakeArrowArrayBool({true, true, false, false}, {true, true, true, false});
+ auto exp_b = MakeArrowArrayInt32({6, 16, -14, 0}, {true, true, true, false});
+ auto exp_c = MakeArrowArrayInt64({6, 16, -14, 0}, {true, true, true, false});
+ auto exp_d = MakeArrowArrayFloat32({6.2f, 16, -14.6f, 0}, {true, true, true, false});
+ auto exp_e = MakeArrowArrayFloat64({6.6f, 16, -14.9f, 0}, {true, true, true, false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records,
+ {array_a, array_b, array_c, array_d, array_e});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_a, outputs.at(0));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_b, outputs.at(1));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_c, outputs.at(2));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_d, outputs.at(3));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_e, outputs.at(4));
+}
+
+TEST_F(TestLiteral, TestLiteralHash) {
+ auto schema = arrow::schema({});
+ // output fields
+ auto res = field("a", int32());
+ auto int_literal = TreeExprBuilder::MakeLiteral((int32_t)2);
+ auto expr = TreeExprBuilder::MakeExpression(int_literal, res);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ auto res1 = field("a", int64());
+ auto int_literal1 = TreeExprBuilder::MakeLiteral((int64_t)2);
+ auto expr1 = TreeExprBuilder::MakeExpression(int_literal1, res1);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector1;
+ status = Projector::Make(schema, {expr1}, TestConfiguration(), &projector1);
+ EXPECT_TRUE(status.ok()) << status.message();
+ EXPECT_TRUE(projector.get() != projector1.get());
+}
+
+TEST_F(TestLiteral, TestNullLiteral) {
+ // schema for input fields
+ auto field_a = field("a", int32());
+ auto field_b = field("b", int32());
+ auto schema = arrow::schema({field_a, field_b});
+
+ // output fields
+ auto res = field("a+b+null", int32());
+
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto node_b = TreeExprBuilder::MakeField(field_b);
+ auto literal_c = TreeExprBuilder::MakeNull(arrow::int32());
+ auto add_a_b = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32());
+ auto add_a_b_c = TreeExprBuilder::MakeFunction("add", {add_a_b, literal_c}, int32());
+ auto expr = TreeExprBuilder::MakeExpression(add_a_b_c, res);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a = MakeArrowArrayInt32({5, 15, -15, 17}, {true, true, true, false});
+ auto array_b = MakeArrowArrayInt32({5, 15, -15, 17}, {true, true, true, false});
+
+ // expected output
+ auto exp = MakeArrowArrayInt32({0, 0, 0, 0}, {false, false, false, false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestLiteral, TestNullLiteralInIf) {
+ // schema for input fields
+ auto field_a = field("a", float64());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res = field("res", float64());
+
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto literal_5 = TreeExprBuilder::MakeLiteral(5.0);
+ auto a_gt_5 = TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_5},
+ arrow::boolean());
+ auto literal_null = TreeExprBuilder::MakeNull(arrow::float64());
+ auto if_node =
+ TreeExprBuilder::MakeIf(a_gt_5, literal_5, literal_null, arrow::float64());
+ auto expr = TreeExprBuilder::MakeExpression(if_node, res);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a = MakeArrowArrayFloat64({6, 15, -15, 17}, {true, true, true, false});
+
+ // expected output
+ auto exp = MakeArrowArrayFloat64({5, 5, 0, 0}, {true, true, false, false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/micro_benchmarks.cc b/src/arrow/cpp/src/gandiva/tests/micro_benchmarks.cc
new file mode 100644
index 000000000..35c77e3dd
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/micro_benchmarks.cc
@@ -0,0 +1,456 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <stdlib.h>
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "benchmark/benchmark.h"
+#include "gandiva/decimal_type_util.h"
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tests/timed_evaluate.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::int32;
+using arrow::int64;
+using arrow::utf8;
+
+static void TimedTestAdd3(benchmark::State& state) {
+ // schema for input fields
+ auto field0 = field("f0", int64());
+ auto field1 = field("f1", int64());
+ auto field2 = field("f2", int64());
+ auto schema = arrow::schema({field0, field1, field2});
+ auto pool_ = arrow::default_memory_pool();
+
+ // output field
+ auto field_sum = field("add", int64());
+
+ // Build expression
+ auto part_sum = TreeExprBuilder::MakeFunction(
+ "add", {TreeExprBuilder::MakeField(field1), TreeExprBuilder::MakeField(field2)},
+ int64());
+ auto sum = TreeExprBuilder::MakeFunction(
+ "add", {TreeExprBuilder::MakeField(field0), part_sum}, int64());
+
+ auto sum_expr = TreeExprBuilder::MakeExpression(sum, field_sum);
+
+ std::shared_ptr<Projector> projector;
+ ASSERT_OK(Projector::Make(schema, {sum_expr}, TestConfiguration(), &projector));
+
+ Int64DataGenerator data_generator;
+ ProjectEvaluator evaluator(projector);
+
+ Status status = TimedEvaluate<arrow::Int64Type, int64_t>(
+ schema, evaluator, data_generator, pool_, 1 * MILLION, 16 * THOUSAND, state);
+ ASSERT_OK(status);
+}
+
+static void TimedTestBigNested(benchmark::State& state) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto schema = arrow::schema({fielda});
+ auto pool_ = arrow::default_memory_pool();
+
+ // output fields
+ auto field_result = field("res", int32());
+
+ // build expression.
+ // if (a < 10)
+ // 10
+ // else if (a < 20)
+ // 20
+ // ..
+ // ..
+ // else if (a < 190)
+ // 190
+ // else
+ // 200
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto top_node = TreeExprBuilder::MakeLiteral(200);
+ for (int thresh = 190; thresh > 0; thresh -= 10) {
+ auto literal = TreeExprBuilder::MakeLiteral(thresh);
+ auto condition =
+ TreeExprBuilder::MakeFunction("less_than", {node_a, literal}, boolean());
+ auto if_node = TreeExprBuilder::MakeIf(condition, literal, top_node, int32());
+ top_node = if_node;
+ }
+ auto expr = TreeExprBuilder::MakeExpression(top_node, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector));
+
+ BoundedInt32DataGenerator data_generator(250);
+ ProjectEvaluator evaluator(projector);
+
+ Status status = TimedEvaluate<arrow::Int32Type, int32_t>(
+ schema, evaluator, data_generator, pool_, 1 * MILLION, 16 * THOUSAND, state);
+ ASSERT_TRUE(status.ok());
+}
+
+static void TimedTestExtractYear(benchmark::State& state) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::date64());
+ auto schema = arrow::schema({field0});
+ auto pool_ = arrow::default_memory_pool();
+
+ // output field
+ auto field_res = field("res", int64());
+
+ // Build expression
+ auto expr = TreeExprBuilder::MakeExpression("extractYear", {field0}, field_res);
+
+ std::shared_ptr<Projector> projector;
+ ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector));
+
+ Int64DataGenerator data_generator;
+ ProjectEvaluator evaluator(projector);
+
+ Status status = TimedEvaluate<arrow::Date64Type, int64_t>(
+ schema, evaluator, data_generator, pool_, 1 * MILLION, 16 * THOUSAND, state);
+ ASSERT_TRUE(status.ok());
+}
+
+static void TimedTestFilterAdd2(benchmark::State& state) {
+ // schema for input fields
+ auto field0 = field("f0", int64());
+ auto field1 = field("f1", int64());
+ auto field2 = field("f2", int64());
+ auto schema = arrow::schema({field0, field1, field2});
+ auto pool_ = arrow::default_memory_pool();
+
+ // Build expression
+ auto sum = TreeExprBuilder::MakeFunction(
+ "add", {TreeExprBuilder::MakeField(field1), TreeExprBuilder::MakeField(field0)},
+ int64());
+ auto less_than = TreeExprBuilder::MakeFunction(
+ "less_than", {sum, TreeExprBuilder::MakeField(field2)}, boolean());
+ auto condition = TreeExprBuilder::MakeCondition(less_than);
+
+ std::shared_ptr<Filter> filter;
+ ASSERT_OK(Filter::Make(schema, condition, TestConfiguration(), &filter));
+
+ Int64DataGenerator data_generator;
+ FilterEvaluator evaluator(filter);
+
+ Status status = TimedEvaluate<arrow::Int64Type, int64_t>(
+ schema, evaluator, data_generator, pool_, MILLION, 16 * THOUSAND, state);
+ ASSERT_TRUE(status.ok());
+}
+
+static void TimedTestFilterLike(benchmark::State& state) {
+ // schema for input fields
+ auto fielda = field("a", utf8());
+ auto schema = arrow::schema({fielda});
+ auto pool_ = arrow::default_memory_pool();
+
+ // build expression.
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto pattern_node = TreeExprBuilder::MakeStringLiteral("%yellow%");
+ auto like_yellow =
+ TreeExprBuilder::MakeFunction("like", {node_a, pattern_node}, arrow::boolean());
+ auto condition = TreeExprBuilder::MakeCondition(like_yellow);
+
+ std::shared_ptr<Filter> filter;
+ ASSERT_OK(Filter::Make(schema, condition, TestConfiguration(), &filter));
+
+ FastUtf8DataGenerator data_generator(32);
+ FilterEvaluator evaluator(filter);
+
+ Status status = TimedEvaluate<arrow::StringType, std::string>(
+ schema, evaluator, data_generator, pool_, 1 * MILLION, 16 * THOUSAND, state);
+ ASSERT_TRUE(status.ok());
+}
+
+static void TimedTestCastFloatFromString(benchmark::State& state) {
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+ auto pool = arrow::default_memory_pool();
+
+ auto field_result = field("res", arrow::float64());
+
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto fn = TreeExprBuilder::MakeFunction("castFLOAT8", {node_a}, arrow::float64());
+ auto expr = TreeExprBuilder::MakeExpression(fn, field_result);
+
+ std::shared_ptr<Projector> projector;
+ ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector));
+
+ Utf8FloatDataGenerator data_generator;
+ ProjectEvaluator evaluator(projector);
+
+ Status status = TimedEvaluate<arrow::StringType, std::string>(
+ schema, evaluator, data_generator, pool, 1 * MILLION, 16 * THOUSAND, state);
+ ASSERT_TRUE(status.ok());
+}
+
+static void TimedTestCastIntFromString(benchmark::State& state) {
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+ auto pool = arrow::default_memory_pool();
+
+ auto field_result = field("res", int32());
+
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto fn = TreeExprBuilder::MakeFunction("castINT", {node_a}, int32());
+ auto expr = TreeExprBuilder::MakeExpression(fn, field_result);
+
+ std::shared_ptr<Projector> projector;
+ ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector));
+
+ Utf8IntDataGenerator data_generator;
+ ProjectEvaluator evaluator(projector);
+
+ Status status = TimedEvaluate<arrow::StringType, std::string>(
+ schema, evaluator, data_generator, pool, 1 * MILLION, 16 * THOUSAND, state);
+ ASSERT_TRUE(status.ok());
+}
+
+static void TimedTestAllocs(benchmark::State& state) {
+ // schema for input fields
+ auto field_a = field("a", arrow::utf8());
+ auto schema = arrow::schema({field_a});
+ auto pool_ = arrow::default_memory_pool();
+
+ // output field
+ auto field_res = field("res", int32());
+
+ // Build expression
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto upper = TreeExprBuilder::MakeFunction("upper", {node_a}, utf8());
+ auto length = TreeExprBuilder::MakeFunction("octet_length", {upper}, int32());
+ auto expr = TreeExprBuilder::MakeExpression(length, field_res);
+
+ std::shared_ptr<Projector> projector;
+ ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector));
+
+ FastUtf8DataGenerator data_generator(64);
+ ProjectEvaluator evaluator(projector);
+
+ Status status = TimedEvaluate<arrow::StringType, std::string>(
+ schema, evaluator, data_generator, pool_, 1 * MILLION, 16 * THOUSAND, state);
+ ASSERT_TRUE(status.ok());
+}
+// following two tests are for benchmark optimization of
+// in expr. will be used in follow-up PRs to optimize in expr.
+
+static void TimedTestMultiOr(benchmark::State& state) {
+ // schema for input fields
+ auto fielda = field("a", utf8());
+ auto schema = arrow::schema({fielda});
+ auto pool_ = arrow::default_memory_pool();
+
+ // output fields
+ auto field_result = field("res", boolean());
+
+ // build expression.
+ // booleanOr(a = string1, a = string2, ..)
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+
+ NodeVector boolean_functions;
+ FastUtf8DataGenerator data_generator1(250);
+ for (int thresh = 1; thresh <= 32; thresh++) {
+ auto literal = TreeExprBuilder::MakeStringLiteral(data_generator1.GenerateData());
+ auto condition = TreeExprBuilder::MakeFunction("equal", {node_a, literal}, boolean());
+ boolean_functions.push_back(condition);
+ }
+
+ auto boolean_or = TreeExprBuilder::MakeOr(boolean_functions);
+ auto expr = TreeExprBuilder::MakeExpression(boolean_or, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector));
+
+ FastUtf8DataGenerator data_generator(250);
+ ProjectEvaluator evaluator(projector);
+ Status status = TimedEvaluate<arrow::StringType, std::string>(
+ schema, evaluator, data_generator, pool_, 100 * THOUSAND, 16 * THOUSAND, state);
+ ASSERT_OK(status);
+}
+
+static void TimedTestInExpr(benchmark::State& state) {
+ // schema for input fields
+ auto fielda = field("a", utf8());
+ auto schema = arrow::schema({fielda});
+ auto pool_ = arrow::default_memory_pool();
+
+ // output fields
+ auto field_result = field("res", boolean());
+
+ // build expression.
+ // a in (string1, string2, ..)
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+
+ std::unordered_set<std::string> values;
+ FastUtf8DataGenerator data_generator1(250);
+ for (int i = 1; i <= 32; i++) {
+ values.insert(data_generator1.GenerateData());
+ }
+ auto boolean_or = TreeExprBuilder::MakeInExpressionString(node_a, values);
+ auto expr = TreeExprBuilder::MakeExpression(boolean_or, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector));
+
+ FastUtf8DataGenerator data_generator(250);
+ ProjectEvaluator evaluator(projector);
+
+ Status status = TimedEvaluate<arrow::StringType, std::string>(
+ schema, evaluator, data_generator, pool_, 100 * THOUSAND, 16 * THOUSAND, state);
+
+ ASSERT_OK(status);
+}
+
+static void DoDecimalAdd3(benchmark::State& state, int32_t precision, int32_t scale,
+ bool large = false) {
+ // schema for input fields
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+ auto field0 = field("f0", decimal_type);
+ auto field1 = field("f1", decimal_type);
+ auto field2 = field("f2", decimal_type);
+ auto schema = arrow::schema({field0, field1, field2});
+
+ Decimal128TypePtr add2_type;
+ auto status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd,
+ {decimal_type, decimal_type}, &add2_type);
+
+ Decimal128TypePtr output_type;
+ status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd,
+ {add2_type, decimal_type}, &output_type);
+
+ // output field
+ auto field_sum = field("add", output_type);
+
+ // Build expression
+ auto part_sum = TreeExprBuilder::MakeFunction(
+ "add", {TreeExprBuilder::MakeField(field1), TreeExprBuilder::MakeField(field2)},
+ add2_type);
+ auto sum = TreeExprBuilder::MakeFunction(
+ "add", {TreeExprBuilder::MakeField(field0), part_sum}, output_type);
+
+ auto sum_expr = TreeExprBuilder::MakeExpression(sum, field_sum);
+
+ std::shared_ptr<Projector> projector;
+ status = Projector::Make(schema, {sum_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ Decimal128DataGenerator data_generator(large);
+ ProjectEvaluator evaluator(projector);
+
+ status = TimedEvaluate<arrow::Decimal128Type, arrow::Decimal128>(
+ schema, evaluator, data_generator, arrow::default_memory_pool(), 1 * MILLION,
+ 16 * THOUSAND, state);
+ ASSERT_OK(status);
+}
+
+static void DoDecimalAdd2(benchmark::State& state, int32_t precision, int32_t scale,
+ bool large = false) {
+ // schema for input fields
+ auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+ auto field0 = field("f0", decimal_type);
+ auto field1 = field("f1", decimal_type);
+ auto schema = arrow::schema({field0, field1});
+
+ Decimal128TypePtr output_type;
+ auto status = DecimalTypeUtil::GetResultType(
+ DecimalTypeUtil::kOpAdd, {decimal_type, decimal_type}, &output_type);
+
+ // output field
+ auto field_sum = field("add", output_type);
+
+ // Build expression
+ auto sum = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum);
+
+ std::shared_ptr<Projector> projector;
+ status = Projector::Make(schema, {sum}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ Decimal128DataGenerator data_generator(large);
+ ProjectEvaluator evaluator(projector);
+
+ status = TimedEvaluate<arrow::Decimal128Type, arrow::Decimal128>(
+ schema, evaluator, data_generator, arrow::default_memory_pool(), 1 * MILLION,
+ 16 * THOUSAND, state);
+ ASSERT_OK(status);
+}
+
+static void DecimalAdd2Fast(benchmark::State& state) {
+ // use lesser precision to test the fast-path
+ DoDecimalAdd2(state, DecimalTypeUtil::kMaxPrecision - 6, 18);
+}
+
+static void DecimalAdd2LeadingZeroes(benchmark::State& state) {
+ // use max precision to test the large-integer-path
+ DoDecimalAdd2(state, DecimalTypeUtil::kMaxPrecision, 6);
+}
+
+static void DecimalAdd2LeadingZeroesWithDiv(benchmark::State& state) {
+ // use max precision to test the large-integer-path
+ DoDecimalAdd2(state, DecimalTypeUtil::kMaxPrecision, 18);
+}
+
+static void DecimalAdd2Large(benchmark::State& state) {
+ // use max precision to test the large-integer-path
+ DoDecimalAdd2(state, DecimalTypeUtil::kMaxPrecision, 18, true);
+}
+
+static void DecimalAdd3Fast(benchmark::State& state) {
+ // use lesser precision to test the fast-path
+ DoDecimalAdd3(state, DecimalTypeUtil::kMaxPrecision - 6, 18);
+}
+
+static void DecimalAdd3LeadingZeroes(benchmark::State& state) {
+ // use max precision to test the large-integer-path
+ DoDecimalAdd3(state, DecimalTypeUtil::kMaxPrecision, 6);
+}
+
+static void DecimalAdd3LeadingZeroesWithDiv(benchmark::State& state) {
+ // use max precision to test the large-integer-path
+ DoDecimalAdd3(state, DecimalTypeUtil::kMaxPrecision, 18);
+}
+
+static void DecimalAdd3Large(benchmark::State& state) {
+ // use max precision to test the large-integer-path
+ DoDecimalAdd3(state, DecimalTypeUtil::kMaxPrecision, 18, true);
+}
+
+BENCHMARK(TimedTestAdd3)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestBigNested)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestExtractYear)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestFilterAdd2)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestFilterLike)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestCastFloatFromString)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestCastIntFromString)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestAllocs)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestMultiOr)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestInExpr)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd2Fast)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd2LeadingZeroes)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd2LeadingZeroesWithDiv)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd2Large)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd3Fast)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd3LeadingZeroes)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd3LeadingZeroesWithDiv)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd3Large)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/null_validity_test.cc b/src/arrow/cpp/src/gandiva/tests/null_validity_test.cc
new file mode 100644
index 000000000..0374b68d4
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/null_validity_test.cc
@@ -0,0 +1,175 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include "arrow/memory_pool.h"
+#include "gandiva/filter.h"
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::int32;
+using arrow::utf8;
+
+class TestNullValidity : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+// Create an array without a validity buffer.
+ArrayPtr MakeArrowArrayInt32WithNullValidity(std::vector<int32_t> in_data) {
+ auto array = MakeArrowArrayInt32(in_data);
+ return std::make_shared<arrow::Int32Array>(in_data.size(), array->data()->buffers[1],
+ nullptr, 0);
+}
+
+TEST_F(TestNullValidity, TestFunc) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto field1 = field("f1", int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // Build condition f0 + f1 < 10
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ auto node_f1 = TreeExprBuilder::MakeField(field1);
+ auto sum_func =
+ TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32());
+ auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10);
+ auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10},
+ arrow::boolean());
+ auto condition = TreeExprBuilder::MakeCondition(less_than_10);
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+
+ // Create an array without a validity buffer.
+ auto array0 = MakeArrowArrayInt32WithNullValidity({1, 2, 3, 4, 6});
+ auto array1 = MakeArrowArrayInt32({5, 9, 6, 17, 3}, {true, true, false, true, true});
+ // expected output (indices for which condition matches)
+ auto exp = MakeArrowArrayUint16({0, 4});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+TEST_F(TestNullValidity, TestIfElse) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto schema = arrow::schema({fielda, fieldb});
+
+ // output fields
+ auto field_result = field("res", int32());
+
+ // build expression.
+ // if (a > b)
+ // a
+ // else
+ // b
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto condition =
+ TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
+ auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, int32());
+
+ auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 = MakeArrowArrayInt32WithNullValidity({10, 12, -20, 5});
+ auto array1 = MakeArrowArrayInt32({5, 15, 15, 17});
+
+ // expected output
+ auto exp = MakeArrowArrayInt32({10, 15, 15, 17}, {true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestNullValidity, TestUtf8) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res = field("res1", int32());
+
+ // build expressions.
+ // length(a)
+ auto expr = TreeExprBuilder::MakeExpression("length", {field_a}, res);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array_v = MakeArrowArrayUtf8({"foo", "hello", "bye", "hi", "मदन"});
+ auto array_a = std::make_shared<arrow::StringArray>(
+ num_records, array_v->data()->buffers[1], array_v->data()->buffers[2]);
+
+ // expected output
+ auto exp = MakeArrowArrayInt32({3, 5, 3, 2, 3}, {true, true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/projector_build_validation_test.cc b/src/arrow/cpp/src/gandiva/tests/projector_build_validation_test.cc
new file mode 100644
index 000000000..5b86844f9
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/projector_build_validation_test.cc
@@ -0,0 +1,287 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include "arrow/memory_pool.h"
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::float32;
+using arrow::int32;
+
+class TestProjector : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+TEST_F(TestProjector, TestNonexistentFunction) {
+ // schema for input fields
+ auto field0 = field("f0", float32());
+ auto field1 = field("f2", float32());
+ auto schema = arrow::schema({field0, field1});
+
+ // output fields
+ auto field_result = field("res", boolean());
+
+ // Build expression
+ auto lt_expr = TreeExprBuilder::MakeExpression("nonexistent_function", {field0, field1},
+ field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {lt_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.IsExpressionValidationError());
+ std::string expected_error =
+ "Function bool nonexistent_function(float, float) not supported yet.";
+ EXPECT_TRUE(status.message().find(expected_error) != std::string::npos);
+}
+
+TEST_F(TestProjector, TestNotMatchingDataType) {
+ // schema for input fields
+ auto field0 = field("f0", float32());
+ auto schema = arrow::schema({field0});
+
+ // output fields
+ auto field_result = field("res", boolean());
+
+ // Build expression
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ auto lt_expr = TreeExprBuilder::MakeExpression(node_f0, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {lt_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.IsExpressionValidationError());
+ std::string expected_error =
+ "Return type of root node float does not match that of expression bool";
+ EXPECT_TRUE(status.message().find(expected_error) != std::string::npos);
+}
+
+TEST_F(TestProjector, TestNotSupportedDataType) {
+ // schema for input fields
+ auto field0 = field("f0", list(int32()));
+ auto schema = arrow::schema({field0});
+
+ // output fields
+ auto field_result = field("res", list(int32()));
+
+ // Build expression
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ auto lt_expr = TreeExprBuilder::MakeExpression(node_f0, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {lt_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.IsExpressionValidationError());
+ std::string expected_error = "Field f0 has unsupported data type list";
+ EXPECT_TRUE(status.message().find(expected_error) != std::string::npos);
+}
+
+TEST_F(TestProjector, TestIncorrectSchemaMissingField) {
+ // schema for input fields
+ auto field0 = field("f0", float32());
+ auto field1 = field("f2", float32());
+ auto schema = arrow::schema({field0, field0});
+
+ // output fields
+ auto field_result = field("res", boolean());
+
+ // Build expression
+ auto lt_expr =
+ TreeExprBuilder::MakeExpression("less_than", {field0, field1}, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {lt_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.IsExpressionValidationError());
+ std::string expected_error = "Field f2 not in schema";
+ EXPECT_TRUE(status.message().find(expected_error) != std::string::npos);
+}
+
+TEST_F(TestProjector, TestIncorrectSchemaTypeNotMatching) {
+ // schema for input fields
+ auto field0 = field("f0", float32());
+ auto field1 = field("f2", float32());
+ auto field2 = field("f2", int32());
+ auto schema = arrow::schema({field0, field2});
+
+ // output fields
+ auto field_result = field("res", boolean());
+
+ // Build expression
+ auto lt_expr =
+ TreeExprBuilder::MakeExpression("less_than", {field0, field1}, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {lt_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.IsExpressionValidationError());
+ std::string expected_error =
+ "Field definition in schema f2: int32 different from field in expression f2: float";
+ EXPECT_TRUE(status.message().find(expected_error) != std::string::npos);
+}
+
+TEST_F(TestProjector, TestIfNotSupportedFunction) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto schema = arrow::schema({fielda, fieldb});
+
+ // output fields
+ auto field_result = field("res", int32());
+
+ // build expression.
+ // if (a > b)
+ // a
+ // else
+ // b
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto condition =
+ TreeExprBuilder::MakeFunction("nonexistent_function", {node_a, node_b}, boolean());
+ auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, int32());
+
+ auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.IsExpressionValidationError());
+}
+
+TEST_F(TestProjector, TestIfNotMatchingReturnType) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto schema = arrow::schema({fielda, fieldb});
+
+ // output fields
+ auto field_result = field("res", int32());
+
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto condition =
+ TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean());
+ auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, boolean());
+
+ auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.IsExpressionValidationError());
+}
+
+TEST_F(TestProjector, TestElseNotMatchingReturnType) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto fieldc = field("c", boolean());
+ auto schema = arrow::schema({fielda, fieldb, fieldc});
+
+ // output fields
+ auto field_result = field("res", int32());
+
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto node_c = TreeExprBuilder::MakeField(fieldc);
+ auto condition =
+ TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean());
+ auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_c, int32());
+
+ auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.IsExpressionValidationError());
+}
+
+TEST_F(TestProjector, TestElseNotSupportedType) {
+ // schema for input fields
+ auto fielda = field("a", int32());
+ auto fieldb = field("b", int32());
+ auto fieldc = field("c", list(int32()));
+ auto schema = arrow::schema({fielda, fieldb});
+
+ // output fields
+ auto field_result = field("res", int32());
+
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto node_c = TreeExprBuilder::MakeField(fieldc);
+ auto condition =
+ TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean());
+ auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_c, int32());
+
+ auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.IsExpressionValidationError());
+ EXPECT_EQ(status.code(), StatusCode::ExpressionValidationError);
+}
+
+TEST_F(TestProjector, TestAndMinChildren) {
+ // schema for input fields
+ auto fielda = field("a", boolean());
+ auto schema = arrow::schema({fielda});
+
+ // output fields
+ auto field_result = field("res", boolean());
+
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto and_node = TreeExprBuilder::MakeAnd({node_a});
+
+ auto expr = TreeExprBuilder::MakeExpression(and_node, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.IsExpressionValidationError());
+}
+
+TEST_F(TestProjector, TestAndBooleanArgType) {
+ // schema for input fields
+ auto fielda = field("a", boolean());
+ auto fieldb = field("b", int32());
+ auto schema = arrow::schema({fielda, fieldb});
+
+ // output fields
+ auto field_result = field("res", int32());
+
+ auto node_a = TreeExprBuilder::MakeField(fielda);
+ auto node_b = TreeExprBuilder::MakeField(fieldb);
+ auto and_node = TreeExprBuilder::MakeAnd({node_a, node_b});
+
+ auto expr = TreeExprBuilder::MakeExpression(and_node, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.IsExpressionValidationError());
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/projector_test.cc b/src/arrow/cpp/src/gandiva/tests/projector_test.cc
new file mode 100644
index 000000000..120207773
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/projector_test.cc
@@ -0,0 +1,1609 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef M_PI
+#define M_PI 3.14159265358979323846
+#endif
+
+#include "gandiva/projector.h"
+
+#include <gtest/gtest.h>
+
+#include <cmath>
+
+#include "arrow/memory_pool.h"
+#include "gandiva/literal_holder.h"
+#include "gandiva/node.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::float32;
+using arrow::int32;
+
+class TestProjector : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+TEST_F(TestProjector, TestProjectCache) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto field1 = field("f2", int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // output fields
+ auto field_sum = field("add", int32());
+ auto field_sub = field("subtract", int32());
+
+ // Build expression
+ auto sum_expr = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum);
+ auto sub_expr =
+ TreeExprBuilder::MakeExpression("subtract", {field0, field1}, field_sub);
+
+ auto configuration = TestConfiguration();
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {sum_expr, sub_expr}, configuration, &projector);
+ ASSERT_OK(status);
+
+ // everything is same, should return the same projector.
+ auto schema_same = arrow::schema({field0, field1});
+ std::shared_ptr<Projector> cached_projector;
+ status = Projector::Make(schema_same, {sum_expr, sub_expr}, configuration,
+ &cached_projector);
+ ASSERT_OK(status);
+ EXPECT_EQ(cached_projector, projector);
+
+ // schema is different should return a new projector.
+ auto field2 = field("f2", int32());
+ auto different_schema = arrow::schema({field0, field1, field2});
+ std::shared_ptr<Projector> should_be_new_projector;
+ status = Projector::Make(different_schema, {sum_expr, sub_expr}, configuration,
+ &should_be_new_projector);
+ ASSERT_OK(status);
+ EXPECT_NE(cached_projector, should_be_new_projector);
+
+ // expression list is different should return a new projector.
+ std::shared_ptr<Projector> should_be_new_projector1;
+ status = Projector::Make(schema, {sum_expr}, configuration, &should_be_new_projector1);
+ ASSERT_OK(status);
+ EXPECT_NE(cached_projector, should_be_new_projector1);
+
+ // another instance of the same configuration, should return the same projector.
+ status = Projector::Make(schema, {sum_expr, sub_expr}, TestConfiguration(),
+ &cached_projector);
+ ASSERT_OK(status);
+ EXPECT_EQ(cached_projector, projector);
+}
+
+TEST_F(TestProjector, TestProjectCacheFieldNames) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto field1 = field("f1", int32());
+ auto field2 = field("f2", int32());
+ auto schema = arrow::schema({field0, field1, field2});
+
+ // output fields
+ auto sum_01 = field("sum_01", int32());
+ auto sum_12 = field("sum_12", int32());
+
+ auto sum_expr_01 = TreeExprBuilder::MakeExpression("add", {field0, field1}, sum_01);
+ std::shared_ptr<Projector> projector_01;
+ auto status =
+ Projector::Make(schema, {sum_expr_01}, TestConfiguration(), &projector_01);
+ EXPECT_TRUE(status.ok());
+
+ auto sum_expr_12 = TreeExprBuilder::MakeExpression("add", {field1, field2}, sum_12);
+ std::shared_ptr<Projector> projector_12;
+ status = Projector::Make(schema, {sum_expr_12}, TestConfiguration(), &projector_12);
+ EXPECT_TRUE(status.ok());
+
+ // add(f0, f1) != add(f1, f2)
+ EXPECT_TRUE(projector_01.get() != projector_12.get());
+}
+
+TEST_F(TestProjector, TestProjectCacheDouble) {
+ auto schema = arrow::schema({});
+ auto res = field("result", arrow::float64());
+
+ double d0 = 1.23456788912345677E18;
+ double d1 = 1.23456789012345677E18;
+
+ auto literal0 = TreeExprBuilder::MakeLiteral(d0);
+ auto expr0 = TreeExprBuilder::MakeExpression(literal0, res);
+ auto configuration = TestConfiguration();
+
+ std::shared_ptr<Projector> projector0;
+ auto status = Projector::Make(schema, {expr0}, configuration, &projector0);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ auto literal1 = TreeExprBuilder::MakeLiteral(d1);
+ auto expr1 = TreeExprBuilder::MakeExpression(literal1, res);
+ std::shared_ptr<Projector> projector1;
+ status = Projector::Make(schema, {expr1}, configuration, &projector1);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ EXPECT_TRUE(projector0.get() != projector1.get());
+}
+
+TEST_F(TestProjector, TestProjectCacheFloat) {
+ auto schema = arrow::schema({});
+ auto res = field("result", arrow::float32());
+
+ float f0 = static_cast<float>(12345678891.000000);
+ float f1 = f0 - 1000;
+
+ auto literal0 = TreeExprBuilder::MakeLiteral(f0);
+ auto expr0 = TreeExprBuilder::MakeExpression(literal0, res);
+ std::shared_ptr<Projector> projector0;
+ auto status = Projector::Make(schema, {expr0}, TestConfiguration(), &projector0);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ auto literal1 = TreeExprBuilder::MakeLiteral(f1);
+ auto expr1 = TreeExprBuilder::MakeExpression(literal1, res);
+ std::shared_ptr<Projector> projector1;
+ status = Projector::Make(schema, {expr1}, TestConfiguration(), &projector1);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ EXPECT_TRUE(projector0.get() != projector1.get());
+}
+
+TEST_F(TestProjector, TestProjectCacheLiteral) {
+ auto schema = arrow::schema({});
+ auto res = field("result", arrow::decimal(38, 5));
+
+ DecimalScalar128 d0("12345678", 38, 5);
+ DecimalScalar128 d1("98756432", 38, 5);
+
+ auto literal0 = TreeExprBuilder::MakeDecimalLiteral(d0);
+ auto expr0 = TreeExprBuilder::MakeExpression(literal0, res);
+ std::shared_ptr<Projector> projector0;
+ ASSERT_OK(Projector::Make(schema, {expr0}, TestConfiguration(), &projector0));
+
+ auto literal1 = TreeExprBuilder::MakeDecimalLiteral(d1);
+ auto expr1 = TreeExprBuilder::MakeExpression(literal1, res);
+ std::shared_ptr<Projector> projector1;
+ ASSERT_OK(Projector::Make(schema, {expr1}, TestConfiguration(), &projector1));
+
+ EXPECT_NE(projector0.get(), projector1.get());
+}
+
+TEST_F(TestProjector, TestProjectCacheDecimalCast) {
+ auto field_float64 = field("float64", arrow::float64());
+ auto schema = arrow::schema({field_float64});
+
+ auto res_31_13 = field("result", arrow::decimal(31, 13));
+ auto expr0 = TreeExprBuilder::MakeExpression("castDECIMAL", {field_float64}, res_31_13);
+ std::shared_ptr<Projector> projector0;
+ ASSERT_OK(Projector::Make(schema, {expr0}, TestConfiguration(), &projector0));
+
+ // if the output scale is different, the cache can't be used.
+ auto res_31_14 = field("result", arrow::decimal(31, 14));
+ auto expr1 = TreeExprBuilder::MakeExpression("castDECIMAL", {field_float64}, res_31_14);
+ std::shared_ptr<Projector> projector1;
+ ASSERT_OK(Projector::Make(schema, {expr1}, TestConfiguration(), &projector1));
+ EXPECT_NE(projector0.get(), projector1.get());
+
+ // if the output scale/precision are same, should get a cache hit.
+ auto res_31_13_alt = field("result", arrow::decimal(31, 13));
+ auto expr2 =
+ TreeExprBuilder::MakeExpression("castDECIMAL", {field_float64}, res_31_13_alt);
+ std::shared_ptr<Projector> projector2;
+ ASSERT_OK(Projector::Make(schema, {expr2}, TestConfiguration(), &projector2));
+ EXPECT_EQ(projector0.get(), projector2.get());
+}
+
+TEST_F(TestProjector, TestIntSumSub) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto field1 = field("f2", int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // output fields
+ auto field_sum = field("add", int32());
+ auto field_sub = field("subtract", int32());
+
+ // Build expression
+ auto sum_expr = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum);
+ auto sub_expr =
+ TreeExprBuilder::MakeExpression("subtract", {field0, field1}, field_sub);
+
+ std::shared_ptr<Projector> projector;
+ auto status =
+ Projector::Make(schema, {sum_expr, sub_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false});
+ auto array1 = MakeArrowArrayInt32({11, 13, 15, 17}, {true, true, false, true});
+ // expected output
+ auto exp_sum = MakeArrowArrayInt32({12, 15, 0, 0}, {true, true, false, false});
+ auto exp_sub = MakeArrowArrayInt32({-10, -11, 0, 0}, {true, true, false, false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_sum, outputs.at(0));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_sub, outputs.at(1));
+}
+
+template <typename TYPE, typename C_TYPE>
+static void TestArithmeticOpsForType(arrow::MemoryPool* pool) {
+ auto atype = arrow::TypeTraits<TYPE>::type_singleton();
+
+ // schema for input fields
+ auto field0 = field("f0", atype);
+ auto field1 = field("f1", atype);
+ auto schema = arrow::schema({field0, field1});
+
+ // output fields
+ auto field_sum = field("add", atype);
+ auto field_sub = field("subtract", atype);
+ auto field_mul = field("multiply", atype);
+ auto field_div = field("divide", atype);
+ auto field_eq = field("equal", arrow::boolean());
+ auto field_lt = field("less_than", arrow::boolean());
+
+ // Build expression
+ auto sum_expr = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum);
+ auto sub_expr =
+ TreeExprBuilder::MakeExpression("subtract", {field0, field1}, field_sub);
+ auto mul_expr =
+ TreeExprBuilder::MakeExpression("multiply", {field0, field1}, field_mul);
+ auto div_expr = TreeExprBuilder::MakeExpression("divide", {field0, field1}, field_div);
+ auto eq_expr = TreeExprBuilder::MakeExpression("equal", {field0, field1}, field_eq);
+ auto lt_expr = TreeExprBuilder::MakeExpression("less_than", {field0, field1}, field_lt);
+
+ std::shared_ptr<Projector> projector;
+ auto status =
+ Projector::Make(schema, {sum_expr, sub_expr, mul_expr, div_expr, eq_expr, lt_expr},
+ TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 12;
+ std::vector<C_TYPE> input0 = {1, 2, 53, 84, 5, 15, 0, 1, 52, 83, 4, 120};
+ std::vector<C_TYPE> input1 = {10, 15, 23, 84, 4, 51, 68, 9, 16, 18, 19, 37};
+ std::vector<bool> validity = {true, true, true, true, true, true,
+ true, true, true, true, true, true};
+
+ auto array0 = MakeArrowArray<TYPE, C_TYPE>(input0, validity);
+ auto array1 = MakeArrowArray<TYPE, C_TYPE>(input1, validity);
+
+ // expected output
+ std::vector<C_TYPE> sum;
+ std::vector<C_TYPE> sub;
+ std::vector<C_TYPE> mul;
+ std::vector<C_TYPE> div;
+ std::vector<bool> eq;
+ std::vector<bool> lt;
+ for (int i = 0; i < num_records; i++) {
+ sum.push_back(static_cast<C_TYPE>(input0[i] + input1[i]));
+ sub.push_back(static_cast<C_TYPE>(input0[i] - input1[i]));
+ mul.push_back(static_cast<C_TYPE>(input0[i] * input1[i]));
+ div.push_back(static_cast<C_TYPE>(input0[i] / input1[i]));
+ eq.push_back(input0[i] == input1[i]);
+ lt.push_back(input0[i] < input1[i]);
+ }
+ auto exp_sum = MakeArrowArray<TYPE, C_TYPE>(sum, validity);
+ auto exp_sub = MakeArrowArray<TYPE, C_TYPE>(sub, validity);
+ auto exp_mul = MakeArrowArray<TYPE, C_TYPE>(mul, validity);
+ auto exp_div = MakeArrowArray<TYPE, C_TYPE>(div, validity);
+ auto exp_eq = MakeArrowArray<arrow::BooleanType, bool>(eq, validity);
+ auto exp_lt = MakeArrowArray<arrow::BooleanType, bool>(lt, validity);
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_sum, outputs.at(0));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_sub, outputs.at(1));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_mul, outputs.at(2));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_div, outputs.at(3));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_eq, outputs.at(4));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_lt, outputs.at(5));
+}
+
+TEST_F(TestProjector, TestAllIntTypes) {
+ TestArithmeticOpsForType<arrow::UInt8Type, uint8_t>(pool_);
+ TestArithmeticOpsForType<arrow::UInt16Type, uint16_t>(pool_);
+ TestArithmeticOpsForType<arrow::UInt32Type, uint32_t>(pool_);
+ TestArithmeticOpsForType<arrow::UInt64Type, uint64_t>(pool_);
+ TestArithmeticOpsForType<arrow::Int8Type, int8_t>(pool_);
+ TestArithmeticOpsForType<arrow::Int16Type, int16_t>(pool_);
+ TestArithmeticOpsForType<arrow::Int32Type, int32_t>(pool_);
+ TestArithmeticOpsForType<arrow::Int64Type, int64_t>(pool_);
+}
+
+TEST_F(TestProjector, TestExtendedMath) {
+ // schema for input fields
+ auto field0 = arrow::field("f0", arrow::float64());
+ auto field1 = arrow::field("f1", arrow::float64());
+ auto schema = arrow::schema({field0, field1});
+
+ // output fields
+ auto field_cbrt = arrow::field("cbrt", arrow::float64());
+ auto field_exp = arrow::field("exp", arrow::float64());
+ auto field_log = arrow::field("log", arrow::float64());
+ auto field_log10 = arrow::field("log10", arrow::float64());
+ auto field_logb = arrow::field("logb", arrow::float64());
+ auto field_power = arrow::field("power", arrow::float64());
+ auto field_sin = arrow::field("sin", arrow::float64());
+ auto field_cos = arrow::field("cos", arrow::float64());
+ auto field_asin = arrow::field("asin", arrow::float64());
+ auto field_acos = arrow::field("acos", arrow::float64());
+ auto field_tan = arrow::field("tan", arrow::float64());
+ auto field_atan = arrow::field("atan", arrow::float64());
+ auto field_sinh = arrow::field("sinh", arrow::float64());
+ auto field_cosh = arrow::field("cosh", arrow::float64());
+ auto field_tanh = arrow::field("tanh", arrow::float64());
+ auto field_atan2 = arrow::field("atan2", arrow::float64());
+ auto field_cot = arrow::field("cot", arrow::float64());
+ auto field_radians = arrow::field("radians", arrow::float64());
+ auto field_degrees = arrow::field("degrees", arrow::float64());
+
+ // Build expression
+ auto cbrt_expr = TreeExprBuilder::MakeExpression("cbrt", {field0}, field_cbrt);
+ auto exp_expr = TreeExprBuilder::MakeExpression("exp", {field0}, field_exp);
+ auto log_expr = TreeExprBuilder::MakeExpression("log", {field0}, field_log);
+ auto log10_expr = TreeExprBuilder::MakeExpression("log10", {field0}, field_log10);
+ auto logb_expr = TreeExprBuilder::MakeExpression("log", {field0, field1}, field_logb);
+ auto power_expr =
+ TreeExprBuilder::MakeExpression("power", {field0, field1}, field_power);
+ auto sin_expr = TreeExprBuilder::MakeExpression("sin", {field0}, field_sin);
+ auto cos_expr = TreeExprBuilder::MakeExpression("cos", {field0}, field_cos);
+ auto asin_expr = TreeExprBuilder::MakeExpression("asin", {field0}, field_asin);
+ auto acos_expr = TreeExprBuilder::MakeExpression("acos", {field0}, field_acos);
+ auto tan_expr = TreeExprBuilder::MakeExpression("tan", {field0}, field_tan);
+ auto atan_expr = TreeExprBuilder::MakeExpression("atan", {field0}, field_atan);
+ auto sinh_expr = TreeExprBuilder::MakeExpression("sinh", {field0}, field_sinh);
+ auto cosh_expr = TreeExprBuilder::MakeExpression("cosh", {field0}, field_cosh);
+ auto tanh_expr = TreeExprBuilder::MakeExpression("tanh", {field0}, field_tanh);
+ auto atan2_expr =
+ TreeExprBuilder::MakeExpression("atan2", {field0, field1}, field_atan2);
+ auto cot_expr = TreeExprBuilder::MakeExpression("cot", {field0}, field_cot);
+ auto radians_expr = TreeExprBuilder::MakeExpression("radians", {field0}, field_radians);
+ auto degrees_expr = TreeExprBuilder::MakeExpression("degrees", {field0}, field_degrees);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(
+ schema,
+ {cbrt_expr, exp_expr, log_expr, log10_expr, logb_expr, power_expr, sin_expr,
+ cos_expr, asin_expr, acos_expr, tan_expr, atan_expr, sinh_expr, cosh_expr,
+ tanh_expr, atan2_expr, cot_expr, radians_expr, degrees_expr},
+ TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ std::vector<double> input0 = {16, 10, -14, 8.3};
+ std::vector<double> input1 = {2, 3, 5, 7};
+ std::vector<bool> validity = {true, true, true, true};
+
+ auto array0 = MakeArrowArray<arrow::DoubleType, double>(input0, validity);
+ auto array1 = MakeArrowArray<arrow::DoubleType, double>(input1, validity);
+
+ // expected output
+ std::vector<double> cbrt_vals;
+ std::vector<double> exp_vals;
+ std::vector<double> log_vals;
+ std::vector<double> log10_vals;
+ std::vector<double> logb_vals;
+ std::vector<double> power_vals;
+ std::vector<double> sin_vals;
+ std::vector<double> cos_vals;
+ std::vector<double> asin_vals;
+ std::vector<double> acos_vals;
+ std::vector<double> tan_vals;
+ std::vector<double> atan_vals;
+ std::vector<double> sinh_vals;
+ std::vector<double> cosh_vals;
+ std::vector<double> tanh_vals;
+ std::vector<double> atan2_vals;
+ std::vector<double> cot_vals;
+ std::vector<double> radians_vals;
+ std::vector<double> degrees_vals;
+ for (int i = 0; i < num_records; i++) {
+ cbrt_vals.push_back(static_cast<double>(cbrtl(input0[i])));
+ exp_vals.push_back(static_cast<double>(expl(input0[i])));
+ log_vals.push_back(static_cast<double>(logl(input0[i])));
+ log10_vals.push_back(static_cast<double>(log10l(input0[i])));
+ logb_vals.push_back(static_cast<double>(logl(input1[i]) / logl(input0[i])));
+ power_vals.push_back(static_cast<double>(powl(input0[i], input1[i])));
+ sin_vals.push_back(static_cast<double>(sin(input0[i])));
+ cos_vals.push_back(static_cast<double>(cos(input0[i])));
+ asin_vals.push_back(static_cast<double>(asin(input0[i])));
+ acos_vals.push_back(static_cast<double>(acos(input0[i])));
+ tan_vals.push_back(static_cast<double>(tan(input0[i])));
+ atan_vals.push_back(static_cast<double>(atan(input0[i])));
+ sinh_vals.push_back(static_cast<double>(sinh(input0[i])));
+ cosh_vals.push_back(static_cast<double>(cosh(input0[i])));
+ tanh_vals.push_back(static_cast<double>(tanh(input0[i])));
+ atan2_vals.push_back(static_cast<double>(atan2(input0[i], input1[i])));
+ cot_vals.push_back(static_cast<double>(tan(M_PI / 2 - input0[i])));
+ radians_vals.push_back(static_cast<double>(input0[i] * M_PI / 180.0));
+ degrees_vals.push_back(static_cast<double>(input0[i] * 180.0 / M_PI));
+ }
+ auto expected_cbrt = MakeArrowArray<arrow::DoubleType, double>(cbrt_vals, validity);
+ auto expected_exp = MakeArrowArray<arrow::DoubleType, double>(exp_vals, validity);
+ auto expected_log = MakeArrowArray<arrow::DoubleType, double>(log_vals, validity);
+ auto expected_log10 = MakeArrowArray<arrow::DoubleType, double>(log10_vals, validity);
+ auto expected_logb = MakeArrowArray<arrow::DoubleType, double>(logb_vals, validity);
+ auto expected_power = MakeArrowArray<arrow::DoubleType, double>(power_vals, validity);
+ auto expected_sin = MakeArrowArray<arrow::DoubleType, double>(sin_vals, validity);
+ auto expected_cos = MakeArrowArray<arrow::DoubleType, double>(cos_vals, validity);
+ auto expected_asin = MakeArrowArray<arrow::DoubleType, double>(asin_vals, validity);
+ auto expected_acos = MakeArrowArray<arrow::DoubleType, double>(acos_vals, validity);
+ auto expected_tan = MakeArrowArray<arrow::DoubleType, double>(tan_vals, validity);
+ auto expected_atan = MakeArrowArray<arrow::DoubleType, double>(atan_vals, validity);
+ auto expected_sinh = MakeArrowArray<arrow::DoubleType, double>(sinh_vals, validity);
+ auto expected_cosh = MakeArrowArray<arrow::DoubleType, double>(cosh_vals, validity);
+ auto expected_tanh = MakeArrowArray<arrow::DoubleType, double>(tanh_vals, validity);
+ auto expected_atan2 = MakeArrowArray<arrow::DoubleType, double>(atan2_vals, validity);
+ auto expected_cot = MakeArrowArray<arrow::DoubleType, double>(cot_vals, validity);
+ auto expected_radians =
+ MakeArrowArray<arrow::DoubleType, double>(radians_vals, validity);
+ auto expected_degrees =
+ MakeArrowArray<arrow::DoubleType, double>(degrees_vals, validity);
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ double epsilon = 1E-13;
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_cbrt, outputs.at(0), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_exp, outputs.at(1), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_log, outputs.at(2), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_log10, outputs.at(3), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_logb, outputs.at(4), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_power, outputs.at(5), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_sin, outputs.at(6), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_cos, outputs.at(7), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_asin, outputs.at(8), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_acos, outputs.at(9), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_tan, outputs.at(10), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_atan, outputs.at(11), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_sinh, outputs.at(12), 1E-08);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_cosh, outputs.at(13), 1E-08);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_tanh, outputs.at(14), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_atan2, outputs.at(15), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_cot, outputs.at(16), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_radians, outputs.at(17), epsilon);
+ EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_degrees, outputs.at(18), epsilon);
+}
+
+TEST_F(TestProjector, TestFloatLessThan) {
+ // schema for input fields
+ auto field0 = field("f0", float32());
+ auto field1 = field("f2", float32());
+ auto schema = arrow::schema({field0, field1});
+
+ // output fields
+ auto field_result = field("res", boolean());
+
+ // Build expression
+ auto lt_expr =
+ TreeExprBuilder::MakeExpression("less_than", {field0, field1}, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {lt_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 3;
+ auto array0 = MakeArrowArrayFloat32({1.0f, 8.9f, 3.0f}, {true, true, false});
+ auto array1 = MakeArrowArrayFloat32({4.0f, 3.4f, 6.8f}, {true, true, true});
+ // expected output
+ auto exp = MakeArrowArrayBool({true, false, false}, {true, true, false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestIsNotNull) {
+ // schema for input fields
+ auto field0 = field("f0", float32());
+ auto schema = arrow::schema({field0});
+
+ // output fields
+ auto field_result = field("res", boolean());
+
+ // Build expression
+ auto myexpr = TreeExprBuilder::MakeExpression("isnotnull", {field0}, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {myexpr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 3;
+ auto array0 = MakeArrowArrayFloat32({1.0f, 8.9f, 3.0f}, {true, true, false});
+ // expected output
+ auto exp = MakeArrowArrayBool({true, true, false}, {true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestZeroCopy) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto schema = arrow::schema({field0});
+
+ // output fields
+ auto res = field("res", float32());
+
+ // Build expression
+ auto cast_expr = TreeExprBuilder::MakeExpression("castFLOAT4", {field0}, res);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {cast_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false});
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ // expected output
+ auto exp = MakeArrowArrayFloat32({1, 2, 3, 0}, {true, true, true, false});
+
+ // allocate output buffers
+ int64_t bitmap_sz = arrow::BitUtil::BytesForBits(num_records);
+ int64_t bitmap_capacity = arrow::BitUtil::RoundUpToMultipleOf64(bitmap_sz);
+ std::vector<uint8_t> bitmap(bitmap_capacity);
+ std::shared_ptr<arrow::MutableBuffer> bitmap_buf =
+ std::make_shared<arrow::MutableBuffer>(&bitmap[0], bitmap_capacity);
+
+ int64_t data_sz = sizeof(float) * num_records;
+ std::vector<uint8_t> data(bitmap_capacity);
+ std::shared_ptr<arrow::MutableBuffer> data_buf =
+ std::make_shared<arrow::MutableBuffer>(&data[0], data_sz);
+
+ auto array_data =
+ arrow::ArrayData::Make(float32(), num_records, {bitmap_buf, data_buf});
+
+ // Evaluate expression
+ status = projector->Evaluate(*in_batch, {array_data});
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ auto output = arrow::MakeArray(array_data);
+ EXPECT_ARROW_ARRAY_EQUALS(exp, output);
+}
+
+TEST_F(TestProjector, TestZeroCopyNegative) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto schema = arrow::schema({field0});
+
+ // output fields
+ auto res = field("res", float32());
+
+ // Build expression
+ auto cast_expr = TreeExprBuilder::MakeExpression("castFLOAT4", {field0}, res);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {cast_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false});
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ // expected output
+ auto exp = MakeArrowArrayFloat32({1, 2, 3, 0}, {true, true, true, false});
+
+ // allocate output buffers
+ int64_t bitmap_sz = arrow::BitUtil::BytesForBits(num_records);
+ std::unique_ptr<uint8_t[]> bitmap(new uint8_t[bitmap_sz]);
+ std::shared_ptr<arrow::MutableBuffer> bitmap_buf =
+ std::make_shared<arrow::MutableBuffer>(bitmap.get(), bitmap_sz);
+
+ int64_t data_sz = sizeof(float) * num_records;
+ std::unique_ptr<uint8_t[]> data(new uint8_t[data_sz]);
+ std::shared_ptr<arrow::MutableBuffer> data_buf =
+ std::make_shared<arrow::MutableBuffer>(data.get(), data_sz);
+
+ auto array_data =
+ arrow::ArrayData::Make(float32(), num_records, {bitmap_buf, data_buf});
+
+ // the batch can't be empty.
+ auto bad_batch = arrow::RecordBatch::Make(schema, 0 /*num_records*/, {array0});
+ status = projector->Evaluate(*bad_batch, {array_data});
+ EXPECT_EQ(status.code(), StatusCode::Invalid);
+
+ // the output array can't be null.
+ std::shared_ptr<arrow::ArrayData> null_array_data;
+ status = projector->Evaluate(*in_batch, {null_array_data});
+ EXPECT_EQ(status.code(), StatusCode::Invalid);
+
+ // the output array must have at least two buffers.
+ auto bad_array_data = arrow::ArrayData::Make(float32(), num_records, {bitmap_buf});
+ status = projector->Evaluate(*in_batch, {bad_array_data});
+ EXPECT_EQ(status.code(), StatusCode::Invalid);
+
+ // the output buffers must have sufficiently sized data_buf.
+ std::shared_ptr<arrow::MutableBuffer> bad_data_buf =
+ std::make_shared<arrow::MutableBuffer>(data.get(), data_sz - 1);
+ auto bad_array_data2 =
+ arrow::ArrayData::Make(float32(), num_records, {bitmap_buf, bad_data_buf});
+ status = projector->Evaluate(*in_batch, {bad_array_data2});
+ EXPECT_EQ(status.code(), StatusCode::Invalid);
+
+ // the output buffers must have sufficiently sized bitmap_buf.
+ std::shared_ptr<arrow::MutableBuffer> bad_bitmap_buf =
+ std::make_shared<arrow::MutableBuffer>(bitmap.get(), bitmap_sz - 1);
+ auto bad_array_data3 =
+ arrow::ArrayData::Make(float32(), num_records, {bad_bitmap_buf, data_buf});
+ status = projector->Evaluate(*in_batch, {bad_array_data3});
+ EXPECT_EQ(status.code(), StatusCode::Invalid);
+}
+
+TEST_F(TestProjector, TestDivideZero) {
+ // schema for input fields
+ auto field0 = field("f0", int32());
+ auto field1 = field("f2", int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // output fields
+ auto field_div = field("divide", int32());
+
+ // Build expression
+ auto div_expr = TreeExprBuilder::MakeExpression("divide", {field0, field1}, field_div);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {div_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array0 = MakeArrowArrayInt32({2, 3, 4, 5, 6}, {true, true, true, true, true});
+ auto array1 = MakeArrowArrayInt32({1, 2, 2, 0, 0}, {true, true, false, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_EQ(status.code(), StatusCode::ExecutionError);
+ std::string expected_error = "divide by zero error";
+ EXPECT_TRUE(status.message().find(expected_error) != std::string::npos);
+
+ // Testing for second batch that has no error should succeed.
+ num_records = 5;
+ array0 = MakeArrowArrayInt32({2, 3, 4, 5, 6}, {true, true, true, true, true});
+ array1 = MakeArrowArrayInt32({1, 2, 2, 1, 1}, {true, true, false, true, true});
+
+ // prepare input record batch
+ in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+ // expected output
+ auto exp = MakeArrowArrayInt32({2, 1, 2, 5, 6}, {true, true, false, true, true});
+
+ // Evaluate expression
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestModZero) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::int64());
+ auto field1 = field("f2", int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // output fields
+ auto field_div = field("mod", int32());
+
+ // Build expression
+ auto mod_expr = TreeExprBuilder::MakeExpression("mod", {field0, field1}, field_div);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {mod_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 = MakeArrowArrayInt64({2, 3, 4, 5}, {true, true, true, true});
+ auto array1 = MakeArrowArrayInt32({1, 2, 2, 0}, {true, true, false, true});
+ // expected output
+ auto exp_mod = MakeArrowArrayInt32({0, 1, 0, 5}, {true, true, false, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_mod, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestConcat) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::utf8());
+ auto field1 = field("f1", arrow::utf8());
+ auto schema = arrow::schema({field0, field1});
+
+ // output fields
+ auto field_concat = field("concat", arrow::utf8());
+
+ // Build expression
+ auto concat_expr =
+ TreeExprBuilder::MakeExpression("concat", {field0, field1}, field_concat);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {concat_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 6;
+ auto array0 = MakeArrowArrayUtf8({"ab", "", "ab", "invalid", "valid", "invalid"},
+ {true, true, true, false, true, false});
+ auto array1 = MakeArrowArrayUtf8({"cd", "cd", "", "valid", "invalid", "invalid"},
+ {true, true, true, true, false, false});
+ // expected output
+ auto exp_concat = MakeArrowArrayUtf8({"abcd", "cd", "ab", "valid", "valid", ""},
+ {true, true, true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_concat, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestBase64) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::binary());
+ auto schema = arrow::schema({field0});
+
+ // output fields
+ auto field_base = field("base64", arrow::utf8());
+
+ // Build expression
+ auto base_expr = TreeExprBuilder::MakeExpression("base64", {field0}, field_base);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {base_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 =
+ MakeArrowArrayBinary({"hello", "", "test", "hive"}, {true, true, true, true});
+ // expected output
+ auto exp_base = MakeArrowArrayUtf8({"aGVsbG8=", "", "dGVzdA==", "aGl2ZQ=="},
+ {true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_base, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestUnbase64) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::utf8());
+ auto schema = arrow::schema({field0});
+
+ // output fields
+ auto field_base = field("base64", arrow::binary());
+
+ // Build expression
+ auto base_expr = TreeExprBuilder::MakeExpression("unbase64", {field0}, field_base);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {base_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 = MakeArrowArrayUtf8({"aGVsbG8=", "", "dGVzdA==", "aGl2ZQ=="},
+ {true, true, true, true});
+ // expected output
+ auto exp_unbase =
+ MakeArrowArrayBinary({"hello", "", "test", "hive"}, {true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_unbase, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestLeftString) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::utf8());
+ auto field1 = field("f1", arrow::int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // output fields
+ auto field_concat = field("left", arrow::utf8());
+
+ // Build expression
+ auto concat_expr =
+ TreeExprBuilder::MakeExpression("left", {field0, field1}, field_concat);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {concat_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 6;
+ auto array0 = MakeArrowArrayUtf8({"ab", "", "ab", "invalid", "valid", "invalid"},
+ {true, true, true, true, true, true});
+ auto array1 =
+ MakeArrowArrayInt32({1, 500, 2, -5, 5, 0}, {true, true, true, true, true, true});
+ // expected output
+ auto exp_left = MakeArrowArrayUtf8({"a", "", "ab", "in", "valid", ""},
+ {true, true, true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_left, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestRightString) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::utf8());
+ auto field1 = field("f1", arrow::int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // output fields
+ auto field_concat = field("right", arrow::utf8());
+
+ // Build expression
+ auto concat_expr =
+ TreeExprBuilder::MakeExpression("right", {field0, field1}, field_concat);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {concat_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 6;
+ auto array0 = MakeArrowArrayUtf8({"ab", "", "ab", "invalid", "valid", "invalid"},
+ {true, true, true, true, true, true});
+ auto array1 =
+ MakeArrowArrayInt32({1, 500, 2, -5, 5, 0}, {true, true, true, true, true, true});
+ // expected output
+ auto exp_left = MakeArrowArrayUtf8({"b", "", "ab", "id", "valid", ""},
+ {true, true, true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_left, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestOffset) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::int32());
+ auto field1 = field("f1", arrow::int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // output fields
+ auto field_sum = field("sum", arrow::int32());
+
+ // Build expression
+ auto sum_expr = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {sum_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 = MakeArrowArrayInt32({1, 2, 3, 4, 5}, {true, true, true, true, false});
+ array0 = array0->Slice(1);
+ auto array1 = MakeArrowArrayInt32({5, 6, 7, 8}, {true, false, true, true});
+ // expected output
+ auto exp_sum = MakeArrowArrayInt32({9, 11, 13}, {false, true, false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+ in_batch = in_batch->Slice(1);
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_sum, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestByteSubString) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::binary());
+ auto field1 = field("f1", arrow::int32());
+ auto field2 = field("f2", arrow::int32());
+ auto schema = arrow::schema({field0, field1, field2});
+
+ // output fields
+ auto field_byte_substr = field("bytesubstring", arrow::binary());
+
+ // Build expression
+ auto byte_substr_expr = TreeExprBuilder::MakeExpression(
+ "bytesubstring", {field0, field1, field2}, field_byte_substr);
+
+ std::shared_ptr<Projector> projector;
+ auto status =
+ Projector::Make(schema, {byte_substr_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 6;
+ auto array0 = MakeArrowArrayBinary({"ab", "", "ab", "invalid", "valid", "invalid"},
+ {true, true, true, true, true, true});
+ auto array1 =
+ MakeArrowArrayInt32({0, 1, 1, 1, 3, 3}, {true, true, true, true, true, true});
+ auto array2 =
+ MakeArrowArrayInt32({0, 1, 1, 2, 3, 3}, {true, true, true, true, true, true});
+ // expected output
+ auto exp_byte_substr = MakeArrowArrayBinary({"", "", "a", "in", "lid", "val"},
+ {true, true, true, true, true, true});
+
+ // prepare input record batch
+ auto in = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_byte_substr, outputs.at(0));
+}
+
+// Test to ensure behaviour of cast functions when the validity is false for an input. The
+// function should not run for that input.
+TEST_F(TestProjector, TestCastFunction) {
+ auto field0 = field("f0", arrow::utf8());
+ auto schema = arrow::schema({field0});
+
+ // output fields
+ auto res_float4 = field("res_float4", arrow::float32());
+ auto res_float8 = field("res_float8", arrow::float64());
+ auto res_int4 = field("castINT", arrow::int32());
+ auto res_int8 = field("castBIGINT", arrow::int64());
+
+ // Build expression
+ auto cast_expr_float4 =
+ TreeExprBuilder::MakeExpression("castFLOAT4", {field0}, res_float4);
+ auto cast_expr_float8 =
+ TreeExprBuilder::MakeExpression("castFLOAT8", {field0}, res_float8);
+ auto cast_expr_int4 = TreeExprBuilder::MakeExpression("castINT", {field0}, res_int4);
+ auto cast_expr_int8 = TreeExprBuilder::MakeExpression("castBIGINT", {field0}, res_int8);
+
+ std::shared_ptr<Projector> projector;
+
+ // {cast_expr_float4, cast_expr_float8, cast_expr_int4, cast_expr_int8}
+ auto status = Projector::Make(
+ schema, {cast_expr_float4, cast_expr_float8, cast_expr_int4, cast_expr_int8},
+ TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+
+ // Last validity is false and the cast functions throw error when input is empty. Should
+ // not be evaluated due to addition of NativeFunction::kCanReturnErrors
+ auto array0 = MakeArrowArrayUtf8({"1", "2", "3", ""}, {true, true, true, false});
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ auto out_float4 = MakeArrowArrayFloat32({1, 2, 3, 0}, {true, true, true, false});
+ auto out_float8 = MakeArrowArrayFloat64({1, 2, 3, 0}, {true, true, true, false});
+ auto out_int4 = MakeArrowArrayInt32({1, 2, 3, 0}, {true, true, true, false});
+ auto out_int8 = MakeArrowArrayInt64({1, 2, 3, 0}, {true, true, true, false});
+
+ arrow::ArrayVector outputs;
+
+ // Evaluate expression
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ EXPECT_ARROW_ARRAY_EQUALS(out_float4, outputs.at(0));
+ EXPECT_ARROW_ARRAY_EQUALS(out_float8, outputs.at(1));
+ EXPECT_ARROW_ARRAY_EQUALS(out_int4, outputs.at(2));
+ EXPECT_ARROW_ARRAY_EQUALS(out_int8, outputs.at(3));
+}
+
+TEST_F(TestProjector, TestCastBitFunction) {
+ auto field0 = field("f0", arrow::utf8());
+ auto schema = arrow::schema({field0});
+
+ // output fields
+ auto res_bit = field("res_bit", arrow::boolean());
+
+ // Build expression
+ auto cast_bit = TreeExprBuilder::MakeExpression("castBIT", {field0}, res_bit);
+
+ std::shared_ptr<Projector> projector;
+
+ auto status = Projector::Make(schema, {cast_bit}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto arr = MakeArrowArrayUtf8({"1", "true", "false", "0"}, {true, true, true, true});
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arr});
+
+ auto out = MakeArrowArrayBool({true, true, false, false}, {true, true, true, true});
+
+ arrow::ArrayVector outputs;
+
+ // Evaluate expression
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ EXPECT_ARROW_ARRAY_EQUALS(out, outputs.at(0));
+}
+
+// Test to ensure behaviour of cast functions when the validity is false for an input. The
+// function should not run for that input.
+TEST_F(TestProjector, TestCastVarbinaryFunction) {
+ auto field0 = field("f0", arrow::binary());
+ auto schema = arrow::schema({field0});
+
+ // output fields
+ auto res_int4 = field("res_int4", arrow::int32());
+ auto res_int8 = field("res_int8", arrow::int64());
+ auto res_float4 = field("res_float4", arrow::float32());
+ auto res_float8 = field("res_float8", arrow::float64());
+
+ // Build expression
+ auto cast_expr_int4 = TreeExprBuilder::MakeExpression("castINT", {field0}, res_int4);
+ auto cast_expr_int8 = TreeExprBuilder::MakeExpression("castBIGINT", {field0}, res_int8);
+ auto cast_expr_float4 =
+ TreeExprBuilder::MakeExpression("castFLOAT4", {field0}, res_float4);
+ auto cast_expr_float8 =
+ TreeExprBuilder::MakeExpression("castFLOAT8", {field0}, res_float8);
+
+ std::shared_ptr<Projector> projector;
+
+ // {cast_expr_float4, cast_expr_float8, cast_expr_int4, cast_expr_int8}
+ auto status = Projector::Make(
+ schema, {cast_expr_int4, cast_expr_int8, cast_expr_float4, cast_expr_float8},
+ TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+
+ // Last validity is false and the cast functions throw error when input is empty. Should
+ // not be evaluated due to addition of NativeFunction::kCanReturnErrors
+ auto array0 =
+ MakeArrowArrayBinary({"37", "-99999", "99999", "4"}, {true, true, true, false});
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ auto out_int4 = MakeArrowArrayInt32({37, -99999, 99999, 0}, {true, true, true, false});
+ auto out_int8 = MakeArrowArrayInt64({37, -99999, 99999, 0}, {true, true, true, false});
+ auto out_float4 =
+ MakeArrowArrayFloat32({37, -99999, 99999, 0}, {true, true, true, false});
+ auto out_float8 =
+ MakeArrowArrayFloat64({37, -99999, 99999, 0}, {true, true, true, false});
+
+ arrow::ArrayVector outputs;
+
+ // Evaluate expression
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ EXPECT_ARROW_ARRAY_EQUALS(out_int4, outputs.at(0));
+ EXPECT_ARROW_ARRAY_EQUALS(out_int8, outputs.at(1));
+ EXPECT_ARROW_ARRAY_EQUALS(out_float4, outputs.at(2));
+ EXPECT_ARROW_ARRAY_EQUALS(out_float8, outputs.at(3));
+}
+
+TEST_F(TestProjector, TestToDate) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::utf8());
+ auto field_node = std::make_shared<FieldNode>(field0);
+ auto schema = arrow::schema({field0});
+
+ // output fields
+ auto field_result = field("res", arrow::date64());
+
+ auto pattern_node = std::make_shared<LiteralNode>(
+ arrow::utf8(), LiteralHolder(std::string("YYYY-MM-DD")), false);
+
+ // Build expression
+ auto fn_node = TreeExprBuilder::MakeFunction("to_date", {field_node, pattern_node},
+ arrow::date64());
+ auto expr = TreeExprBuilder::MakeExpression(fn_node, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 3;
+ auto array0 =
+ MakeArrowArrayUtf8({"1986-12-01", "2012-12-01", "invalid"}, {true, true, false});
+ // expected output
+ auto exp = MakeArrowArrayDate64({533779200000, 1354320000000, 0}, {true, true, false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+// ARROW-11617
+TEST_F(TestProjector, TestIfElseOpt) {
+ // schema for input
+ auto field0 = field("f0", int32());
+ auto field1 = field("f1", int32());
+ auto field2 = field("f2", int32());
+ auto schema = arrow::schema({field0, field1, field2});
+
+ auto f0 = std::make_shared<FieldNode>(field0);
+ auto f1 = std::make_shared<FieldNode>(field1);
+ auto f2 = std::make_shared<FieldNode>(field2);
+
+ // output fields
+ auto field_result = field("out", int32());
+
+ // Expr - (f0, f1 - null; f2 non null)
+ //
+ // if (is not null(f0))
+ // then f0
+ // else add((
+ // if (is not null (f1))
+ // then f1
+ // else f2
+ // ), f1)
+
+ auto cond_node_inner = TreeExprBuilder::MakeFunction("isnotnull", {f1}, boolean());
+ auto if_node_inner = TreeExprBuilder::MakeIf(cond_node_inner, f1, f2, int32());
+
+ auto cond_node_outer = TreeExprBuilder::MakeFunction("isnotnull", {f0}, boolean());
+ auto else_node_outer =
+ TreeExprBuilder::MakeFunction("add", {if_node_inner, f1}, int32());
+
+ auto if_node_outer =
+ TreeExprBuilder::MakeIf(cond_node_outer, f1, else_node_outer, int32());
+ auto expr = TreeExprBuilder::MakeExpression(if_node_outer, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 1;
+ auto array0 = MakeArrowArrayInt32({0}, {false});
+ auto array1 = MakeArrowArrayInt32({0}, {false});
+ auto array2 = MakeArrowArrayInt32({99}, {true});
+ // expected output
+ auto exp = MakeArrowArrayInt32({0}, {false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestRepeat) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::utf8());
+ auto field1 = field("f1", arrow::int32());
+ auto schema = arrow::schema({field0, field1});
+
+ // output fields
+ auto field_repeat = field("repeat", arrow::utf8());
+
+ // Build expression
+ auto repeat_expr =
+ TreeExprBuilder::MakeExpression("repeat", {field0, field1}, field_repeat);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {repeat_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array0 =
+ MakeArrowArrayUtf8({"ab", "a", "car", "valid", ""}, {true, true, true, true, true});
+ auto array1 = MakeArrowArrayInt32({2, 1, 3, 2, 10}, {true, true, true, true, true});
+ // expected output
+ auto exp_repeat = MakeArrowArrayUtf8({"abab", "a", "carcarcar", "validvalid", ""},
+ {true, true, true, true, true});
+
+ // prepare input record batch
+ auto in = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_repeat, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestLpad) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::utf8());
+ auto field1 = field("f1", arrow::int32());
+ auto field2 = field("f2", arrow::utf8());
+ auto schema = arrow::schema({field0, field1, field2});
+
+ // output fields
+ auto field_lpad = field("lpad", arrow::utf8());
+
+ // Build expression
+ auto lpad_expr =
+ TreeExprBuilder::MakeExpression("lpad", {field0, field1, field2}, field_lpad);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {lpad_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 7;
+ auto array0 = MakeArrowArrayUtf8({"ab", "a", "ab", "invalid", "valid", "invalid", ""},
+ {true, true, true, true, true, true, true});
+ auto array1 = MakeArrowArrayInt32({1, 5, 3, 12, 0, 2, 10},
+ {true, true, true, true, true, true, true});
+ auto array2 = MakeArrowArrayUtf8({"z", "z", "c", "valid", "invalid", "invalid", ""},
+ {true, true, true, true, true, true, true});
+ // expected output
+ auto exp_lpad = MakeArrowArrayUtf8({"a", "zzzza", "cab", "validinvalid", "", "in", ""},
+ {true, true, true, true, true, true, true});
+
+ // prepare input record batch
+ auto in = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_lpad, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestRpad) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::utf8());
+ auto field1 = field("f1", arrow::int32());
+ auto field2 = field("f2", arrow::utf8());
+ auto schema = arrow::schema({field0, field1, field2});
+
+ // output fields
+ auto field_rpad = field("rpad", arrow::utf8());
+
+ // Build expression
+ auto rpad_expr =
+ TreeExprBuilder::MakeExpression("rpad", {field0, field1, field2}, field_rpad);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {rpad_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 7;
+ auto array0 = MakeArrowArrayUtf8({"ab", "a", "ab", "invalid", "valid", "invalid", ""},
+ {true, true, true, true, true, true, true});
+ auto array1 = MakeArrowArrayInt32({1, 5, 3, 12, 0, 2, 10},
+ {true, true, true, true, true, true, true});
+ auto array2 = MakeArrowArrayUtf8({"z", "z", "c", "valid", "invalid", "invalid", ""},
+ {true, true, true, true, true, true, true});
+ // expected output
+ auto exp_rpad = MakeArrowArrayUtf8({"a", "azzzz", "abc", "invalidvalid", "", "in", ""},
+ {true, true, true, true, true, true, true});
+
+ // prepare input record batch
+ auto in = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_rpad, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestBinRepresentation) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::int64());
+ auto schema = arrow::schema({field0});
+
+ // output fields
+ auto field_result = field("bin", arrow::utf8());
+
+ // Build expression
+ auto myexpr = TreeExprBuilder::MakeExpression("bin", {field0}, field_result);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {myexpr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 3;
+ auto array0 = MakeArrowArrayInt64({7, -28550, 58117}, {true, true, true});
+ // expected output
+ auto exp = MakeArrowArrayUtf8(
+ {"111", "1111111111111111111111111111111111111111111111111001000001111010",
+ "1110001100000101"},
+ {true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestBigIntCastFunction) {
+ // input fields
+ auto field0 = field("f0", arrow::float32());
+ auto field1 = field("f1", arrow::float64());
+ auto field2 = field("f2", arrow::day_time_interval());
+ auto field3 = field("f3", arrow::month_interval());
+ auto schema = arrow::schema({field0, field1, field2, field3});
+
+ // output fields
+ auto res_int64 = field("res", arrow::int64());
+
+ // Build expression
+ auto cast_expr_float4 =
+ TreeExprBuilder::MakeExpression("castBIGINT", {field0}, res_int64);
+ auto cast_expr_float8 =
+ TreeExprBuilder::MakeExpression("castBIGINT", {field1}, res_int64);
+ auto cast_expr_day_interval =
+ TreeExprBuilder::MakeExpression("castBIGINT", {field2}, res_int64);
+ auto cast_expr_year_interval =
+ TreeExprBuilder::MakeExpression("castBIGINT", {field3}, res_int64);
+
+ std::shared_ptr<Projector> projector;
+
+ // {cast_expr_float4, cast_expr_float8, cast_expr_day_interval,
+ // cast_expr_year_interval}
+ auto status = Projector::Make(schema,
+ {cast_expr_float4, cast_expr_float8,
+ cast_expr_day_interval, cast_expr_year_interval},
+ TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+
+ // Last validity is false and the cast functions throw error when input is empty. Should
+ // not be evaluated due to addition of NativeFunction::kCanReturnErrors
+ auto array0 =
+ MakeArrowArrayFloat32({6.6f, -6.6f, 9.999999f, 0}, {true, true, true, false});
+ auto array1 =
+ MakeArrowArrayFloat64({6.6, -6.6, 9.99999999999, 0}, {true, true, true, false});
+ auto array2 = MakeArrowArrayInt64({100, 25, -0, 0}, {true, true, true, false});
+ auto array3 = MakeArrowArrayInt32({25, -25, -0, 0}, {true, true, true, false});
+ auto in_batch =
+ arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2, array3});
+
+ auto out_float4 = MakeArrowArrayInt64({7, -7, 10, 0}, {true, true, true, false});
+ auto out_float8 = MakeArrowArrayInt64({7, -7, 10, 0}, {true, true, true, false});
+ auto out_days_interval =
+ MakeArrowArrayInt64({8640000000, 2160000000, 0, 0}, {true, true, true, false});
+ auto out_year_interval = MakeArrowArrayInt64({2, -2, 0, 0}, {true, true, true, false});
+
+ arrow::ArrayVector outputs;
+
+ // Evaluate expression
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ EXPECT_ARROW_ARRAY_EQUALS(out_float4, outputs.at(0));
+ EXPECT_ARROW_ARRAY_EQUALS(out_float8, outputs.at(1));
+ EXPECT_ARROW_ARRAY_EQUALS(out_days_interval, outputs.at(2));
+ EXPECT_ARROW_ARRAY_EQUALS(out_year_interval, outputs.at(3));
+}
+
+TEST_F(TestProjector, TestIntCastFunction) {
+ // input fields
+ auto field0 = field("f0", arrow::float32());
+ auto field1 = field("f1", arrow::float64());
+ auto field2 = field("f2", arrow::month_interval());
+ auto schema = arrow::schema({field0, field1, field2});
+
+ // output fields
+ auto res_int32 = field("res", arrow::int32());
+
+ // Build expression
+ auto cast_expr_float4 = TreeExprBuilder::MakeExpression("castINT", {field0}, res_int32);
+ auto cast_expr_float8 = TreeExprBuilder::MakeExpression("castINT", {field1}, res_int32);
+ auto cast_expr_year_interval =
+ TreeExprBuilder::MakeExpression("castINT", {field2}, res_int32);
+
+ std::shared_ptr<Projector> projector;
+
+ // {cast_expr_float4, cast_expr_float8, cast_expr_day_interval,
+ // cast_expr_year_interval}
+ auto status = Projector::Make(
+ schema, {cast_expr_float4, cast_expr_float8, cast_expr_year_interval},
+ TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+
+ // Last validity is false and the cast functions throw error when input is empty. Should
+ // not be evaluated due to addition of NativeFunction::kCanReturnErrors
+ auto array0 =
+ MakeArrowArrayFloat32({6.6f, -6.6f, 9.999999f, 0}, {true, true, true, false});
+ auto array1 =
+ MakeArrowArrayFloat64({6.6, -6.6, 9.99999999999, 0}, {true, true, true, false});
+ auto array2 = MakeArrowArrayInt32({25, -25, -0, 0}, {true, true, true, false});
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2});
+
+ auto out_float4 = MakeArrowArrayInt32({7, -7, 10, 0}, {true, true, true, false});
+ auto out_float8 = MakeArrowArrayInt32({7, -7, 10, 0}, {true, true, true, false});
+ auto out_year_interval = MakeArrowArrayInt32({2, -2, 0, 0}, {true, true, true, false});
+
+ arrow::ArrayVector outputs;
+
+ // Evaluate expression
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ EXPECT_ARROW_ARRAY_EQUALS(out_float4, outputs.at(0));
+ EXPECT_ARROW_ARRAY_EQUALS(out_float8, outputs.at(1));
+ EXPECT_ARROW_ARRAY_EQUALS(out_year_interval, outputs.at(2));
+}
+
+TEST_F(TestProjector, TestCastNullableIntYearInterval) {
+ // input fields
+ auto field1 = field("f1", arrow::month_interval());
+ auto schema = arrow::schema({field1});
+
+ // output fields
+ auto res_int32 = field("res", arrow::int32());
+ auto res_int64 = field("res", arrow::int64());
+
+ // Build expression
+ auto cast_expr_int32 =
+ TreeExprBuilder::MakeExpression("castNULLABLEINT", {field1}, res_int32);
+ auto cast_expr_int64 =
+ TreeExprBuilder::MakeExpression("castNULLABLEBIGINT", {field1}, res_int64);
+
+ std::shared_ptr<Projector> projector;
+
+ // {cast_expr_int32, cast_expr_int64, cast_expr_day_interval,
+ // cast_expr_year_interval}
+ auto status = Projector::Make(schema, {cast_expr_int32, cast_expr_int64},
+ TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+
+ // Last validity is false and the cast functions throw error when input is empty. Should
+ // not be evaluated due to addition of NativeFunction::kCanReturnErrors
+ auto array0 = MakeArrowArrayInt32({12, -24, -0, 0}, {true, true, true, false});
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ auto out_int32 = MakeArrowArrayInt32({1, -2, -0, 0}, {true, true, true, false});
+ auto out_int64 = MakeArrowArrayInt64({1, -2, -0, 0}, {true, true, true, false});
+
+ arrow::ArrayVector outputs;
+
+ // Evaluate expression
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ EXPECT_ARROW_ARRAY_EQUALS(out_int32, outputs.at(0));
+ EXPECT_ARROW_ARRAY_EQUALS(out_int64, outputs.at(1));
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/test_util.h b/src/arrow/cpp/src/gandiva/tests/test_util.h
new file mode 100644
index 000000000..54270436c
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/test_util.h
@@ -0,0 +1,103 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <chrono>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/testing/gtest_util.h"
+#include "gandiva/arrow.h"
+#include "gandiva/configuration.h"
+
+#pragma once
+
+namespace gandiva {
+
+// Helper function to create an arrow-array of type ARROWTYPE
+// from primitive vectors of data & validity.
+//
+// arrow/testing/gtest_util.h has good utility classes for this purpose.
+// Using those
+template <typename TYPE, typename C_TYPE>
+static inline ArrayPtr MakeArrowArray(std::vector<C_TYPE> values,
+ std::vector<bool> validity) {
+ ArrayPtr out;
+ arrow::ArrayFromVector<TYPE, C_TYPE>(validity, values, &out);
+ return out;
+}
+
+template <typename TYPE, typename C_TYPE>
+static inline ArrayPtr MakeArrowArray(std::vector<C_TYPE> values) {
+ ArrayPtr out;
+ arrow::ArrayFromVector<TYPE, C_TYPE>(values, &out);
+ return out;
+}
+
+template <typename TYPE, typename C_TYPE>
+static inline ArrayPtr MakeArrowArray(const std::shared_ptr<arrow::DataType>& type,
+ std::vector<C_TYPE> values,
+ std::vector<bool> validity) {
+ ArrayPtr out;
+ arrow::ArrayFromVector<TYPE, C_TYPE>(type, validity, values, &out);
+ return out;
+}
+
+template <typename TYPE, typename C_TYPE>
+static inline ArrayPtr MakeArrowTypeArray(const std::shared_ptr<arrow::DataType>& type,
+ const std::vector<C_TYPE>& values,
+ const std::vector<bool>& validity) {
+ ArrayPtr out;
+ arrow::ArrayFromVector<TYPE, C_TYPE>(type, validity, values, &out);
+ return out;
+}
+
+#define MakeArrowArrayBool MakeArrowArray<arrow::BooleanType, bool>
+#define MakeArrowArrayInt8 MakeArrowArray<arrow::Int8Type, int8_t>
+#define MakeArrowArrayInt16 MakeArrowArray<arrow::Int16Type, int16_t>
+#define MakeArrowArrayInt32 MakeArrowArray<arrow::Int32Type, int32_t>
+#define MakeArrowArrayInt64 MakeArrowArray<arrow::Int64Type, int64_t>
+#define MakeArrowArrayUint8 MakeArrowArray<arrow::UInt8Type, uint8_t>
+#define MakeArrowArrayUint16 MakeArrowArray<arrow::UInt16Type, uint16_t>
+#define MakeArrowArrayUint32 MakeArrowArray<arrow::UInt32Type, uint32_t>
+#define MakeArrowArrayUint64 MakeArrowArray<arrow::UInt64Type, uint64_t>
+#define MakeArrowArrayFloat32 MakeArrowArray<arrow::FloatType, float>
+#define MakeArrowArrayFloat64 MakeArrowArray<arrow::DoubleType, double>
+#define MakeArrowArrayDate64 MakeArrowArray<arrow::Date64Type, int64_t>
+#define MakeArrowArrayUtf8 MakeArrowArray<arrow::StringType, std::string>
+#define MakeArrowArrayBinary MakeArrowArray<arrow::BinaryType, std::string>
+#define MakeArrowArrayDecimal MakeArrowArray<arrow::Decimal128Type, arrow::Decimal128>
+
+#define EXPECT_ARROW_ARRAY_EQUALS(a, b) \
+ EXPECT_TRUE((a)->Equals(b, arrow::EqualOptions().nans_equal(true))) \
+ << "expected array: " << (a)->ToString() << " actual array: " << (b)->ToString()
+
+#define EXPECT_ARROW_ARRAY_APPROX_EQUALS(a, b, epsilon) \
+ EXPECT_TRUE( \
+ (a)->ApproxEquals(b, arrow::EqualOptions().atol(epsilon).nans_equal(true))) \
+ << "expected array: " << (a)->ToString() << " actual array: " << (b)->ToString()
+
+#define EXPECT_ARROW_TYPE_EQUALS(a, b) \
+ EXPECT_TRUE((a)->Equals(b)) << "expected type: " << (a)->ToString() \
+ << " actual type: " << (b)->ToString()
+
+static inline std::shared_ptr<Configuration> TestConfiguration() {
+ auto builder = ConfigurationBuilder();
+ return builder.DefaultConfiguration();
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/timed_evaluate.h b/src/arrow/cpp/src/gandiva/tests/timed_evaluate.h
new file mode 100644
index 000000000..eba0f5eb9
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/timed_evaluate.h
@@ -0,0 +1,136 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <vector>
+#include "benchmark/benchmark.h"
+#include "gandiva/arrow.h"
+#include "gandiva/filter.h"
+#include "gandiva/projector.h"
+#include "gandiva/tests/generate_data.h"
+
+#pragma once
+
+#define THOUSAND (1024)
+#define MILLION (1024 * 1024)
+#define NUM_BATCHES 16
+
+namespace gandiva {
+
+template <typename C_TYPE>
+std::vector<C_TYPE> GenerateData(int num_records, DataGenerator<C_TYPE>& data_generator) {
+ std::vector<C_TYPE> data;
+
+ for (int i = 0; i < num_records; i++) {
+ data.push_back(data_generator.GenerateData());
+ }
+
+ return data;
+}
+
+class BaseEvaluator {
+ public:
+ virtual ~BaseEvaluator() = default;
+
+ virtual Status Evaluate(arrow::RecordBatch& batch, arrow::MemoryPool* pool) = 0;
+};
+
+class ProjectEvaluator : public BaseEvaluator {
+ public:
+ explicit ProjectEvaluator(std::shared_ptr<Projector> projector)
+ : projector_(projector) {}
+
+ Status Evaluate(arrow::RecordBatch& batch, arrow::MemoryPool* pool) override {
+ arrow::ArrayVector outputs;
+ return projector_->Evaluate(batch, pool, &outputs);
+ }
+
+ private:
+ std::shared_ptr<Projector> projector_;
+};
+
+class FilterEvaluator : public BaseEvaluator {
+ public:
+ explicit FilterEvaluator(std::shared_ptr<Filter> filter) : filter_(filter) {}
+
+ Status Evaluate(arrow::RecordBatch& batch, arrow::MemoryPool* pool) override {
+ if (selection_ == nullptr || selection_->GetMaxSlots() < batch.num_rows()) {
+ auto status = SelectionVector::MakeInt16(batch.num_rows(), pool, &selection_);
+ if (!status.ok()) {
+ return status;
+ }
+ }
+ return filter_->Evaluate(batch, selection_);
+ }
+
+ private:
+ std::shared_ptr<Filter> filter_;
+ std::shared_ptr<SelectionVector> selection_;
+};
+
+template <typename TYPE, typename C_TYPE>
+Status TimedEvaluate(SchemaPtr schema, BaseEvaluator& evaluator,
+ DataGenerator<C_TYPE>& data_generator, arrow::MemoryPool* pool,
+ int num_records, int batch_size, benchmark::State& state) {
+ int num_remaining = num_records;
+ int num_fields = schema->num_fields();
+ int num_calls = 0;
+ Status status;
+
+ // Generate batches of data
+ std::shared_ptr<arrow::RecordBatch> batches[NUM_BATCHES];
+ for (int i = 0; i < NUM_BATCHES; i++) {
+ // generate data for all columns in the schema
+ std::vector<ArrayPtr> columns;
+ for (int col = 0; col < num_fields; col++) {
+ std::vector<C_TYPE> data = GenerateData<C_TYPE>(batch_size, data_generator);
+ std::vector<bool> validity(batch_size, true);
+ ArrayPtr col_data =
+ MakeArrowArray<TYPE, C_TYPE>(schema->field(col)->type(), data, validity);
+
+ columns.push_back(col_data);
+ }
+
+ // make the record batch
+ std::shared_ptr<arrow::RecordBatch> batch =
+ arrow::RecordBatch::Make(schema, batch_size, columns);
+ batches[i] = batch;
+ }
+
+ for (auto _ : state) {
+ int num_in_batch = batch_size;
+ num_remaining = num_records;
+ while (num_remaining > 0) {
+ if (batch_size > num_remaining) {
+ num_in_batch = num_remaining;
+ }
+
+ status = evaluator.Evaluate(*(batches[num_calls % NUM_BATCHES]), pool);
+ if (!status.ok()) {
+ state.SkipWithError("Evaluation of the batch failed");
+ return status;
+ }
+
+ num_calls++;
+ num_remaining -= num_in_batch;
+ }
+ }
+
+ return Status::OK();
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/to_string_test.cc b/src/arrow/cpp/src/gandiva/tests/to_string_test.cc
new file mode 100644
index 000000000..55db6e92b
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/to_string_test.cc
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include <math.h>
+#include <time.h>
+#include "arrow/memory_pool.h"
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::float64;
+using arrow::int32;
+using arrow::int64;
+
+class TestToString : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+#define CHECK_EXPR_TO_STRING(e, str) EXPECT_STREQ(e->ToString().c_str(), str)
+
+TEST_F(TestToString, TestAll) {
+ auto literal_node = TreeExprBuilder::MakeLiteral((uint64_t)100);
+ auto literal_expr =
+ TreeExprBuilder::MakeExpression(literal_node, arrow::field("r", int64()));
+ CHECK_EXPR_TO_STRING(literal_expr, "(const uint64) 100");
+
+ auto f0 = arrow::field("f0", float64());
+ auto f0_node = TreeExprBuilder::MakeField(f0);
+ auto f0_expr = TreeExprBuilder::MakeExpression(f0_node, f0);
+ CHECK_EXPR_TO_STRING(f0_expr, "(double) f0");
+
+ auto f1 = arrow::field("f1", int64());
+ auto f2 = arrow::field("f2", int64());
+ auto f1_node = TreeExprBuilder::MakeField(f1);
+ auto f2_node = TreeExprBuilder::MakeField(f2);
+ auto add_node = TreeExprBuilder::MakeFunction("add", {f1_node, f2_node}, int64());
+ auto add_expr = TreeExprBuilder::MakeExpression(add_node, f1);
+ CHECK_EXPR_TO_STRING(add_expr, "int64 add((int64) f1, (int64) f2)");
+
+ auto cond_node = TreeExprBuilder::MakeFunction(
+ "lesser_than", {f0_node, TreeExprBuilder::MakeLiteral(static_cast<float>(0))},
+ boolean());
+ auto then_node = TreeExprBuilder::MakeField(f1);
+ auto else_node = TreeExprBuilder::MakeField(f2);
+
+ auto if_node = TreeExprBuilder::MakeIf(cond_node, then_node, else_node, int64());
+ auto if_expr = TreeExprBuilder::MakeExpression(if_node, f1);
+
+ CHECK_EXPR_TO_STRING(if_expr,
+ "if (bool lesser_than((double) f0, (const float) 0 raw(0))) { "
+ "(int64) f1 } else { (int64) f2 }");
+
+ auto f1_gt_100 =
+ TreeExprBuilder::MakeFunction("greater_than", {f1_node, literal_node}, boolean());
+ auto f2_equals_100 =
+ TreeExprBuilder::MakeFunction("equals", {f2_node, literal_node}, boolean());
+ auto and_node = TreeExprBuilder::MakeAnd({f1_gt_100, f2_equals_100});
+ auto and_expr =
+ TreeExprBuilder::MakeExpression(and_node, arrow::field("f0", boolean()));
+
+ CHECK_EXPR_TO_STRING(and_expr,
+ "bool greater_than((int64) f1, (const uint64) 100) && bool "
+ "equals((int64) f2, (const uint64) 100)");
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/utf8_test.cc b/src/arrow/cpp/src/gandiva/tests/utf8_test.cc
new file mode 100644
index 000000000..e19d6712d
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/utf8_test.cc
@@ -0,0 +1,751 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::date64;
+using arrow::int32;
+using arrow::int64;
+using arrow::utf8;
+
+class TestUtf8 : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+TEST_F(TestUtf8, TestSimple) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res_1 = field("res1", int32());
+ auto res_2 = field("res2", boolean());
+ auto res_3 = field("res3", int32());
+
+ // build expressions.
+ // octet_length(a)
+ // octet_length(a) == bit_length(a) / 8
+ // length(a)
+ auto expr_a = TreeExprBuilder::MakeExpression("octet_length", {field_a}, res_1);
+
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto octet_length = TreeExprBuilder::MakeFunction("octet_length", {node_a}, int32());
+ auto literal_8 = TreeExprBuilder::MakeLiteral((int32_t)8);
+ auto bit_length = TreeExprBuilder::MakeFunction("bit_length", {node_a}, int32());
+ auto div_8 = TreeExprBuilder::MakeFunction("divide", {bit_length, literal_8}, int32());
+ auto is_equal =
+ TreeExprBuilder::MakeFunction("equal", {octet_length, div_8}, boolean());
+ auto expr_b = TreeExprBuilder::MakeExpression(is_equal, res_2);
+ auto expr_c = TreeExprBuilder::MakeExpression("length", {field_a}, res_3);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status =
+ Projector::Make(schema, {expr_a, expr_b, expr_c}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array_a = MakeArrowArrayUtf8({"foo", "hello", "bye", "hi", "मदन"},
+ {true, true, false, true, true});
+
+ // expected output
+ auto exp_1 = MakeArrowArrayInt32({3, 5, 0, 2, 9}, {true, true, false, true, true});
+ auto exp_2 = MakeArrowArrayBool({true, true, false, true, true},
+ {true, true, false, true, true});
+ auto exp_3 = MakeArrowArrayInt32({3, 5, 0, 2, 3}, {true, true, false, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_1, outputs.at(0));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_2, outputs.at(1));
+ EXPECT_ARROW_ARRAY_EQUALS(exp_3, outputs.at(2));
+}
+
+TEST_F(TestUtf8, TestLiteral) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res = field("res", boolean());
+
+ // build expressions.
+ // a == literal(s)
+
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto literal_s = TreeExprBuilder::MakeStringLiteral("hello");
+ auto is_equal = TreeExprBuilder::MakeFunction("equal", {node_a, literal_s}, boolean());
+ auto expr = TreeExprBuilder::MakeExpression(is_equal, res);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a =
+ MakeArrowArrayUtf8({"foo", "hello", "bye", "hi"}, {true, true, true, false});
+
+ // expected output
+ auto exp = MakeArrowArrayBool({false, true, false, false}, {true, true, true, false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestUtf8, TestNullLiteral) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res = field("res", boolean());
+
+ // build expressions.
+ // a == literal(null)
+
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto literal_null = TreeExprBuilder::MakeNull(arrow::utf8());
+ auto is_equal =
+ TreeExprBuilder::MakeFunction("equal", {node_a, literal_null}, boolean());
+ auto expr = TreeExprBuilder::MakeExpression(is_equal, res);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a =
+ MakeArrowArrayUtf8({"foo", "hello", "bye", "hi"}, {true, true, true, false});
+
+ // expected output
+ auto exp =
+ MakeArrowArrayBool({false, false, false, false}, {false, false, false, false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestUtf8, TestLike) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res = field("res", boolean());
+
+ // build expressions.
+ // like(literal(s), a)
+
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto literal_s = TreeExprBuilder::MakeStringLiteral("%spark%");
+ auto is_like = TreeExprBuilder::MakeFunction("like", {node_a, literal_s}, boolean());
+ auto expr = TreeExprBuilder::MakeExpression(is_like, res);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a = MakeArrowArrayUtf8({"park", "sparkle", "bright spark and fire", "spark"},
+ {true, true, true, true});
+
+ // expected output
+ auto exp = MakeArrowArrayBool({false, true, true, true}, {true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestUtf8, TestLikeWithEscape) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res = field("res", boolean());
+
+ // build expressions.
+ // like(literal(s), a, '\')
+
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto literal_s = TreeExprBuilder::MakeStringLiteral("%pa\\%rk%");
+ auto escape_char = TreeExprBuilder::MakeStringLiteral("\\");
+ auto is_like =
+ TreeExprBuilder::MakeFunction("like", {node_a, literal_s, escape_char}, boolean());
+ auto expr = TreeExprBuilder::MakeExpression(is_like, res);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a = MakeArrowArrayUtf8(
+ {"park", "spa%rkle", "bright spa%rk and fire", "spark"}, {true, true, true, true});
+
+ // expected output
+ auto exp = MakeArrowArrayBool({false, true, true, false}, {true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestUtf8, TestBeginsEnds) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res1 = field("res1", boolean());
+ auto res2 = field("res2", boolean());
+
+ // build expressions.
+ // like(literal("spark%"), a)
+ // like(literal("%spark"), a)
+
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto literal_begin = TreeExprBuilder::MakeStringLiteral("spark%");
+ auto is_like1 =
+ TreeExprBuilder::MakeFunction("like", {node_a, literal_begin}, boolean());
+ auto expr1 = TreeExprBuilder::MakeExpression(is_like1, res1);
+
+ auto literal_end = TreeExprBuilder::MakeStringLiteral("%spark");
+ auto is_like2 = TreeExprBuilder::MakeFunction("like", {node_a, literal_end}, boolean());
+ auto expr2 = TreeExprBuilder::MakeExpression(is_like2, res2);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr1, expr2}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a =
+ MakeArrowArrayUtf8({"park", "sparkle", "bright spark and fire", "fiery spark"},
+ {true, true, true, true});
+
+ // expected output
+ auto exp1 = MakeArrowArrayBool({false, true, false, false}, {true, true, true, true});
+ auto exp2 = MakeArrowArrayBool({false, false, false, true}, {true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp1, outputs.at(0));
+ EXPECT_ARROW_ARRAY_EQUALS(exp2, outputs.at(1));
+}
+
+TEST_F(TestUtf8, TestInternalAllocs) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res = field("res", boolean());
+
+ // build expressions.
+ // like(upper(a), literal("%SPARK%"))
+
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto upper_a = TreeExprBuilder::MakeFunction("upper", {node_a}, utf8());
+ auto literal_spark = TreeExprBuilder::MakeStringLiteral("%SPARK%");
+ auto is_like =
+ TreeExprBuilder::MakeFunction("like", {upper_a, literal_spark}, boolean());
+ auto expr = TreeExprBuilder::MakeExpression(is_like, res);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array_a = MakeArrowArrayUtf8(
+ {"park", "Sparkle", "bright spark and fire", "fiery SPARK", "मदन"},
+ {true, true, false, true, true});
+
+ // expected output
+ auto exp = MakeArrowArrayBool({false, true, false, true, false},
+ {true, true, false, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestUtf8, TestCastDate) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res_1 = field("res1", int64());
+
+ // build expressions.
+ // extractYear(castDATE(a))
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto cast_function = TreeExprBuilder::MakeFunction("castDATE", {node_a}, date64());
+ auto extract_year =
+ TreeExprBuilder::MakeFunction("extractYear", {cast_function}, int64());
+ auto expr = TreeExprBuilder::MakeExpression(extract_year, res_1);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a = MakeArrowArrayUtf8({"1967-12-1", "67-12-01", "incorrect", "67-45-11"},
+ {true, true, false, true});
+
+ // expected output
+ auto exp_1 = MakeArrowArrayInt64({1967, 2067, 0, 0}, {true, true, false, false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_EQ(status.code(), StatusCode::ExecutionError);
+ std::string expected_error = "Not a valid date value ";
+ EXPECT_TRUE(status.message().find(expected_error) != std::string::npos);
+
+ auto array_a_2 = MakeArrowArrayUtf8({"1967-12-1", "67-12-01", "67-1-1", "91-1-1"},
+ {true, true, true, true});
+ auto exp_2 = MakeArrowArrayInt64({1967, 2067, 2067, 1991}, {true, true, true, true});
+ auto in_batch_2 = arrow::RecordBatch::Make(schema, num_records, {array_a_2});
+ arrow::ArrayVector outputs2;
+ status = projector->Evaluate(*in_batch_2, pool_, &outputs2);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_2, outputs2.at(0));
+}
+
+TEST_F(TestUtf8, TestToDateNoError) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res_1 = field("res1", int64());
+
+ // build expressions.
+ // extractYear(castDATE(a))
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto node_b = TreeExprBuilder::MakeStringLiteral("YYYY-MM-DD");
+ auto node_c = TreeExprBuilder::MakeLiteral(1);
+
+ auto cast_function =
+ TreeExprBuilder::MakeFunction("to_date", {node_a, node_b, node_c}, date64());
+ auto extract_year =
+ TreeExprBuilder::MakeFunction("extractYear", {cast_function}, int64());
+ auto expr = TreeExprBuilder::MakeExpression(extract_year, res_1);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a = MakeArrowArrayUtf8({"1967-12-1", "67-12-01", "incorrect", "67-45-11"},
+ {true, true, false, true});
+
+ // expected output
+ auto exp_1 = MakeArrowArrayInt64({1967, 67, 0, 0}, {true, true, false, false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+ EXPECT_ARROW_ARRAY_EQUALS(exp_1, outputs.at(0));
+
+ // Create a row-batch with some sample data
+ auto array_a_2 = MakeArrowArrayUtf8(
+ {"1967-12-1", "1967-12-01", "1967-11-11", "1991-11-11"}, {true, true, true, true});
+ auto exp_2 = MakeArrowArrayInt64({1967, 1967, 1967, 1991}, {true, true, true, true});
+ auto in_batch_2 = arrow::RecordBatch::Make(schema, num_records, {array_a_2});
+ arrow::ArrayVector outputs2;
+ status = projector->Evaluate(*in_batch_2, pool_, &outputs2);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_2, outputs2.at(0));
+}
+
+TEST_F(TestUtf8, TestToDateError) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+
+ // output fields
+ auto res_1 = field("res1", int64());
+
+ // build expressions.
+ // extractYear(castDATE(a))
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto node_b = TreeExprBuilder::MakeStringLiteral("YYYY-MM-DD");
+ auto node_c = TreeExprBuilder::MakeLiteral(0);
+
+ auto cast_function =
+ TreeExprBuilder::MakeFunction("to_date", {node_a, node_b, node_c}, date64());
+ auto extract_year =
+ TreeExprBuilder::MakeFunction("extractYear", {cast_function}, int64());
+ auto expr = TreeExprBuilder::MakeExpression(extract_year, res_1);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a = MakeArrowArrayUtf8({"1967-12-1", "67-12-01", "incorrect", "67-45-11"},
+ {true, true, false, true});
+
+ // expected output
+ auto exp_1 = MakeArrowArrayInt64({1967, 67, 0, 0}, {true, true, false, false});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_EQ(status.code(), StatusCode::ExecutionError);
+ std::string expected_error = "Error parsing value 67-45-11 for given format";
+ EXPECT_TRUE(status.message().find(expected_error) != std::string::npos)
+ << status.message();
+}
+
+TEST_F(TestUtf8, TestIsNull) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto schema = arrow::schema({field_a});
+
+ // build expressions
+ auto exprs = std::vector<ExpressionPtr>{
+ TreeExprBuilder::MakeExpression("isnull", {field_a}, field("is_null", boolean())),
+ TreeExprBuilder::MakeExpression("isnotnull", {field_a},
+ field("is_not_null", boolean())),
+ };
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector);
+ DCHECK_OK(status);
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_a = MakeArrowArrayUtf8({"hello", "world", "incorrect", "universe"},
+ {true, true, false, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+
+ // validate results
+ EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({false, false, true, false}),
+ outputs[0]); // isnull
+ EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({true, true, false, true}),
+ outputs[1]); // isnotnull
+}
+
+TEST_F(TestUtf8, TestVarlenOutput) {
+ // schema for input fields
+ auto field_a = field("a", boolean());
+ auto schema = arrow::schema({field_a});
+
+ // build expressions.
+ // if (a) literal_hi else literal_bye
+ auto if_node = TreeExprBuilder::MakeIf(
+ TreeExprBuilder::MakeField(field_a), TreeExprBuilder::MakeStringLiteral("hi"),
+ TreeExprBuilder::MakeStringLiteral("bye"), utf8());
+ auto expr = TreeExprBuilder::MakeExpression(if_node, field("res", utf8()));
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+
+ // assert that it fails gracefully.
+ ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector));
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array_in =
+ MakeArrowArrayBool({true, false, false, false}, {true, true, true, false});
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_in});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ ASSERT_OK(projector->Evaluate(*in_batch, pool_, &outputs));
+
+ // expected output
+ auto exp = MakeArrowArrayUtf8({"hi", "bye", "bye", "bye"}, {true, true, true, true});
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+TEST_F(TestUtf8, TestConvertUtf8) {
+ // schema for input fields
+ auto field_a = field("a", arrow::binary());
+ auto field_c = field("c", utf8());
+ auto schema = arrow::schema({field_a, field_c});
+
+ // output fields
+ auto res = field("res", boolean());
+
+ // build expressions.
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto node_c = TreeExprBuilder::MakeField(field_c);
+
+ // define char to replace
+ auto node_b = TreeExprBuilder::MakeStringLiteral("z");
+
+ auto convert_replace_utf8 =
+ TreeExprBuilder::MakeFunction("convert_replaceUTF8", {node_a, node_b}, utf8());
+ auto equals =
+ TreeExprBuilder::MakeFunction("equal", {convert_replace_utf8, node_c}, boolean());
+ auto expr = TreeExprBuilder::MakeExpression(equals, res);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 3;
+ auto array_a = MakeArrowArrayUtf8({"ok-\xf8\x28"
+ "-a",
+ "all-valid", "ok-\xa0\xa1-valid"},
+ {true, true, true});
+
+ auto array_b =
+ MakeArrowArrayUtf8({"ok-z(-a", "all-valid", "ok-zz-valid"}, {true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ auto exp = MakeArrowArrayBool({true, true, true}, {true, true, true});
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs[0]);
+}
+
+TEST_F(TestUtf8, TestCastVarChar) {
+ // schema for input fields
+ auto field_a = field("a", utf8());
+ auto field_c = field("c", utf8());
+ auto schema = arrow::schema({field_a, field_c});
+
+ // output fields
+ auto res = field("res", boolean());
+
+ // build expressions.
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto node_c = TreeExprBuilder::MakeField(field_c);
+ // truncates the string to input length
+ auto node_b = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(10));
+ auto cast_varchar =
+ TreeExprBuilder::MakeFunction("castVARCHAR", {node_a, node_b}, utf8());
+ auto equals = TreeExprBuilder::MakeFunction("equal", {cast_varchar, node_c}, boolean());
+ auto expr = TreeExprBuilder::MakeExpression(equals, res);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array_a = MakeArrowArrayUtf8(
+ {"park", "Sparkle", "bright spark and fire", "fiery SPARK", "मदन"},
+ {true, true, false, true, true});
+
+ auto array_b =
+ MakeArrowArrayUtf8({"park", "Sparkle", "bright spar", "fiery SPAR", "मदन"},
+ {true, true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ auto exp = MakeArrowArrayBool({true, true, false, true, true},
+ {true, true, false, true, true});
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs[0]);
+}
+
+TEST_F(TestUtf8, TestAscii) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::utf8());
+ auto schema = arrow::schema({field0});
+
+ // output fields
+ auto field_asc = field("ascii", arrow::int32());
+
+ // Build expression
+ auto asc_expr = TreeExprBuilder::MakeExpression("ascii", {field0}, field_asc);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {asc_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 6;
+ auto array0 = MakeArrowArrayUtf8({"ABC", "", "abc", "Hello World", "123", "999"},
+ {true, true, true, true, true, true});
+ // expected output
+ auto exp_asc =
+ MakeArrowArrayInt32({65, 0, 97, 72, 49, 57}, {true, true, true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_asc, outputs.at(0));
+}
+
+TEST_F(TestUtf8, TestSpace) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::int64());
+ auto schema = arrow::schema({field0});
+
+ // output fields
+ auto field_space = field("space", arrow::utf8());
+
+ // Build expression
+ auto space_expr = TreeExprBuilder::MakeExpression("space", {field0}, field_space);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {space_expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 4;
+ auto array0 = MakeArrowArrayInt64({1, 0, -5, 2}, {true, true, true, true});
+ // expected output
+ auto exp_space = MakeArrowArrayUtf8({" ", "", "", " "}, {true, true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_space, outputs.at(0));
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/to_date_holder.cc b/src/arrow/cpp/src/gandiva/to_date_holder.cc
new file mode 100644
index 000000000..1b7e2864f
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/to_date_holder.cc
@@ -0,0 +1,116 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/to_date_holder.h"
+
+#include <algorithm>
+#include <string>
+
+#include "arrow/util/value_parsing.h"
+#include "arrow/vendored/datetime.h"
+#include "gandiva/date_utils.h"
+#include "gandiva/execution_context.h"
+#include "gandiva/node.h"
+
+namespace gandiva {
+
+Status ToDateHolder::Make(const FunctionNode& node,
+ std::shared_ptr<ToDateHolder>* holder) {
+ if (node.children().size() != 2 && node.children().size() != 3) {
+ return Status::Invalid("'to_date' function requires two or three parameters");
+ }
+
+ auto literal_pattern = dynamic_cast<LiteralNode*>(node.children().at(1).get());
+ if (literal_pattern == nullptr) {
+ return Status::Invalid(
+ "'to_date' function requires a literal as the second parameter");
+ }
+
+ auto literal_type = literal_pattern->return_type()->id();
+ if (literal_type != arrow::Type::STRING && literal_type != arrow::Type::BINARY) {
+ return Status::Invalid(
+ "'to_date' function requires a string literal as the second parameter");
+ }
+ auto pattern = arrow::util::get<std::string>(literal_pattern->holder());
+
+ int suppress_errors = 0;
+ if (node.children().size() == 3) {
+ auto literal_suppress_errors =
+ dynamic_cast<LiteralNode*>(node.children().at(2).get());
+ if (literal_pattern == nullptr) {
+ return Status::Invalid(
+ "The (optional) third parameter to 'to_date' function needs to an integer "
+ "literal to indicate whether to suppress the error");
+ }
+
+ literal_type = literal_suppress_errors->return_type()->id();
+ if (literal_type != arrow::Type::INT32) {
+ return Status::Invalid(
+ "The (optional) third parameter to 'to_date' function needs to an integer "
+ "literal to indicate whether to suppress the error");
+ }
+ suppress_errors = arrow::util::get<int>(literal_suppress_errors->holder());
+ }
+
+ return Make(pattern, suppress_errors, holder);
+}
+
+Status ToDateHolder::Make(const std::string& sql_pattern, int32_t suppress_errors,
+ std::shared_ptr<ToDateHolder>* holder) {
+ std::shared_ptr<std::string> transformed_pattern;
+ ARROW_RETURN_NOT_OK(DateUtils::ToInternalFormat(sql_pattern, &transformed_pattern));
+ auto lholder = std::shared_ptr<ToDateHolder>(
+ new ToDateHolder(*(transformed_pattern.get()), suppress_errors));
+ *holder = lholder;
+ return Status::OK();
+}
+
+int64_t ToDateHolder::operator()(ExecutionContext* context, const char* data,
+ int data_len, bool in_valid, bool* out_valid) {
+ *out_valid = false;
+ if (!in_valid) {
+ return 0;
+ }
+
+ // Issues
+ // 1. processes date that do not match the format.
+ // 2. does not process time in format +08:00 (or) id.
+ int64_t seconds_since_epoch = 0;
+ if (!::arrow::internal::ParseTimestampStrptime(
+ data, data_len, pattern_.c_str(),
+ /*ignore_time_in_day=*/true, /*allow_trailing_chars=*/true,
+ ::arrow::TimeUnit::SECOND, &seconds_since_epoch)) {
+ return_error(context, data, data_len);
+ return 0;
+ }
+
+ *out_valid = true;
+ return seconds_since_epoch * 1000;
+}
+
+void ToDateHolder::return_error(ExecutionContext* context, const char* data,
+ int data_len) {
+ if (suppress_errors_ == 1) {
+ return;
+ }
+
+ std::string err_msg =
+ "Error parsing value " + std::string(data, data_len) + " for given format.";
+ context->set_error_msg(err_msg.c_str());
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/to_date_holder.h b/src/arrow/cpp/src/gandiva/to_date_holder.h
new file mode 100644
index 000000000..1211b6a30
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/to_date_holder.h
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "arrow/status.h"
+
+#include "gandiva/execution_context.h"
+#include "gandiva/function_holder.h"
+#include "gandiva/node.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// Function Holder for SQL 'to_date'
+class GANDIVA_EXPORT ToDateHolder : public FunctionHolder {
+ public:
+ ~ToDateHolder() override = default;
+
+ static Status Make(const FunctionNode& node, std::shared_ptr<ToDateHolder>* holder);
+
+ static Status Make(const std::string& sql_pattern, int32_t suppress_errors,
+ std::shared_ptr<ToDateHolder>* holder);
+
+ /// Return true if the data matches the pattern.
+ int64_t operator()(ExecutionContext* context, const char* data, int data_len,
+ bool in_valid, bool* out_valid);
+
+ private:
+ ToDateHolder(const std::string& pattern, int32_t suppress_errors)
+ : pattern_(pattern), suppress_errors_(suppress_errors) {}
+
+ void return_error(ExecutionContext* context, const char* data, int data_len);
+
+ std::string pattern_; // date format string
+
+ int32_t suppress_errors_; // should throw exception on runtime errors
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/to_date_holder_test.cc b/src/arrow/cpp/src/gandiva/to_date_holder_test.cc
new file mode 100644
index 000000000..a420774bf
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/to_date_holder_test.cc
@@ -0,0 +1,152 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <vector>
+
+#include "arrow/testing/gtest_util.h"
+
+#include "gandiva/execution_context.h"
+#include "gandiva/to_date_holder.h"
+#include "precompiled/epoch_time_point.h"
+
+#include <gtest/gtest.h>
+
+namespace gandiva {
+
+class TestToDateHolder : public ::testing::Test {
+ public:
+ FunctionNode BuildToDate(std::string pattern) {
+ auto field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8()));
+ auto pattern_node =
+ std::make_shared<LiteralNode>(arrow::utf8(), LiteralHolder(pattern), false);
+ auto suppress_error_node =
+ std::make_shared<LiteralNode>(arrow::int32(), LiteralHolder(0), false);
+ return FunctionNode("to_date_utf8_utf8_int32",
+ {field, pattern_node, suppress_error_node}, arrow::int64());
+ }
+
+ protected:
+ ExecutionContext execution_context_;
+};
+
+TEST_F(TestToDateHolder, TestSimpleDateTime) {
+ std::shared_ptr<ToDateHolder> to_date_holder;
+ ASSERT_OK(ToDateHolder::Make("YYYY-MM-DD HH:MI:SS", 1, &to_date_holder));
+
+ auto& to_date = *to_date_holder;
+ bool out_valid;
+ std::string s("1986-12-01 01:01:01");
+ int64_t millis_since_epoch =
+ to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid);
+ EXPECT_EQ(millis_since_epoch, 533779200000);
+
+ s = std::string("1986-12-01 01:01:01.11");
+ millis_since_epoch =
+ to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid);
+ EXPECT_EQ(millis_since_epoch, 533779200000);
+
+ s = std::string("1986-12-01 01:01:01 +0800");
+ millis_since_epoch =
+ to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid);
+ EXPECT_EQ(millis_since_epoch, 533779200000);
+
+#if 0
+ // TODO : this fails parsing with date::parse and strptime on linux
+ s = std::string("1886-12-01 00:00:00");
+ millis_since_epoch =
+ to_date(&execution_context_, s.data(), (int) s.length(), true, &out_valid);
+ EXPECT_EQ(out_valid, true);
+ EXPECT_EQ(millis_since_epoch, -2621894400000);
+#endif
+
+ s = std::string("1886-12-01 01:01:01");
+ millis_since_epoch =
+ to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid);
+ EXPECT_EQ(millis_since_epoch, -2621894400000);
+
+ s = std::string("1986-12-11 01:30:00");
+ millis_since_epoch =
+ to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid);
+ EXPECT_EQ(millis_since_epoch, 534643200000);
+}
+
+TEST_F(TestToDateHolder, TestSimpleDate) {
+ std::shared_ptr<ToDateHolder> to_date_holder;
+ ASSERT_OK(ToDateHolder::Make("YYYY-MM-DD", 1, &to_date_holder));
+
+ auto& to_date = *to_date_holder;
+ bool out_valid;
+ std::string s("1986-12-01");
+ int64_t millis_since_epoch =
+ to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid);
+ EXPECT_EQ(millis_since_epoch, 533779200000);
+
+ s = std::string("1986-12-01");
+ millis_since_epoch =
+ to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid);
+ EXPECT_EQ(millis_since_epoch, 533779200000);
+
+ s = std::string("1886-12-1");
+ millis_since_epoch =
+ to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid);
+ EXPECT_EQ(millis_since_epoch, -2621894400000);
+
+ s = std::string("2012-12-1");
+ millis_since_epoch =
+ to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid);
+ EXPECT_EQ(millis_since_epoch, 1354320000000);
+
+ // wrong month. should return 0 since we are suppressing errors.
+ s = std::string("1986-21-01 01:01:01 +0800");
+ millis_since_epoch =
+ to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid);
+ EXPECT_EQ(millis_since_epoch, 0);
+}
+
+TEST_F(TestToDateHolder, TestSimpleDateTimeError) {
+ std::shared_ptr<ToDateHolder> to_date_holder;
+
+ auto status = ToDateHolder::Make("YYYY-MM-DD HH:MI:SS", 0, &to_date_holder);
+ EXPECT_EQ(status.ok(), true) << status.message();
+ auto& to_date = *to_date_holder;
+ bool out_valid;
+
+ std::string s("1986-01-40 01:01:01 +0800");
+ int64_t millis_since_epoch =
+ to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid);
+ EXPECT_EQ(0, millis_since_epoch);
+ std::string expected_error =
+ "Error parsing value 1986-01-40 01:01:01 +0800 for given format";
+ EXPECT_TRUE(execution_context_.get_error().find(expected_error) != std::string::npos)
+ << status.message();
+
+ // not valid should not return error
+ execution_context_.Reset();
+ millis_since_epoch = to_date(&execution_context_, "nullptr", 7, false, &out_valid);
+ EXPECT_EQ(millis_since_epoch, 0);
+ EXPECT_TRUE(execution_context_.has_error() == false);
+}
+
+TEST_F(TestToDateHolder, TestSimpleDateTimeMakeError) {
+ std::shared_ptr<ToDateHolder> to_date_holder;
+ // reject time stamps for now.
+ auto status = ToDateHolder::Make("YYYY-MM-DD HH:MI:SS tzo", 0, &to_date_holder);
+ EXPECT_EQ(status.IsInvalid(), true) << status.message();
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tree_expr_builder.cc b/src/arrow/cpp/src/gandiva/tree_expr_builder.cc
new file mode 100644
index 000000000..de8e3445a
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tree_expr_builder.cc
@@ -0,0 +1,223 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/tree_expr_builder.h"
+
+#include <iostream>
+#include <utility>
+
+#include "gandiva/decimal_type_util.h"
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/node.h"
+
+namespace gandiva {
+
+#define MAKE_LITERAL(atype, ctype) \
+ NodePtr TreeExprBuilder::MakeLiteral(ctype value) { \
+ return std::make_shared<LiteralNode>(atype, LiteralHolder(value), false); \
+ }
+
+MAKE_LITERAL(arrow::boolean(), bool)
+MAKE_LITERAL(arrow::int8(), int8_t)
+MAKE_LITERAL(arrow::int16(), int16_t)
+MAKE_LITERAL(arrow::int32(), int32_t)
+MAKE_LITERAL(arrow::int64(), int64_t)
+MAKE_LITERAL(arrow::uint8(), uint8_t)
+MAKE_LITERAL(arrow::uint16(), uint16_t)
+MAKE_LITERAL(arrow::uint32(), uint32_t)
+MAKE_LITERAL(arrow::uint64(), uint64_t)
+MAKE_LITERAL(arrow::float32(), float)
+MAKE_LITERAL(arrow::float64(), double)
+
+NodePtr TreeExprBuilder::MakeStringLiteral(const std::string& value) {
+ return std::make_shared<LiteralNode>(arrow::utf8(), LiteralHolder(value), false);
+}
+
+NodePtr TreeExprBuilder::MakeBinaryLiteral(const std::string& value) {
+ return std::make_shared<LiteralNode>(arrow::binary(), LiteralHolder(value), false);
+}
+
+NodePtr TreeExprBuilder::MakeDecimalLiteral(const DecimalScalar128& value) {
+ return std::make_shared<LiteralNode>(arrow::decimal(value.precision(), value.scale()),
+ LiteralHolder(value), false);
+}
+
+NodePtr TreeExprBuilder::MakeNull(DataTypePtr data_type) {
+ static const std::string empty;
+
+ if (data_type == nullptr) {
+ return nullptr;
+ }
+
+ switch (data_type->id()) {
+ case arrow::Type::BOOL:
+ return std::make_shared<LiteralNode>(data_type, LiteralHolder(false), true);
+ case arrow::Type::INT8:
+ return std::make_shared<LiteralNode>(data_type, LiteralHolder((int8_t)0), true);
+ case arrow::Type::INT16:
+ return std::make_shared<LiteralNode>(data_type, LiteralHolder((int16_t)0), true);
+ return std::make_shared<LiteralNode>(data_type, LiteralHolder((int32_t)0), true);
+ case arrow::Type::UINT8:
+ return std::make_shared<LiteralNode>(data_type, LiteralHolder((uint8_t)0), true);
+ case arrow::Type::UINT16:
+ return std::make_shared<LiteralNode>(data_type, LiteralHolder((uint16_t)0), true);
+ case arrow::Type::UINT32:
+ return std::make_shared<LiteralNode>(data_type, LiteralHolder((uint32_t)0), true);
+ case arrow::Type::UINT64:
+ return std::make_shared<LiteralNode>(data_type, LiteralHolder((uint64_t)0), true);
+ case arrow::Type::FLOAT:
+ return std::make_shared<LiteralNode>(data_type,
+ LiteralHolder(static_cast<float>(0)), true);
+ case arrow::Type::DOUBLE:
+ return std::make_shared<LiteralNode>(data_type,
+ LiteralHolder(static_cast<double>(0)), true);
+ case arrow::Type::STRING:
+ case arrow::Type::BINARY:
+ return std::make_shared<LiteralNode>(data_type, LiteralHolder(empty), true);
+ case arrow::Type::INT32:
+ case arrow::Type::DATE32:
+ case arrow::Type::TIME32:
+ case arrow::Type::INTERVAL_MONTHS:
+ return std::make_shared<LiteralNode>(data_type, LiteralHolder((int32_t)0), true);
+ case arrow::Type::INT64:
+ case arrow::Type::DATE64:
+ case arrow::Type::TIME64:
+ case arrow::Type::TIMESTAMP:
+ case arrow::Type::INTERVAL_DAY_TIME:
+ return std::make_shared<LiteralNode>(data_type, LiteralHolder((int64_t)0), true);
+ case arrow::Type::DECIMAL: {
+ std::shared_ptr<arrow::DecimalType> decimal_type =
+ arrow::internal::checked_pointer_cast<arrow::DecimalType>(data_type);
+ DecimalScalar128 literal(decimal_type->precision(), decimal_type->scale());
+ return std::make_shared<LiteralNode>(data_type, LiteralHolder(literal), true);
+ }
+ default:
+ return nullptr;
+ }
+}
+
+NodePtr TreeExprBuilder::MakeField(FieldPtr field) {
+ return NodePtr(new FieldNode(field));
+}
+
+NodePtr TreeExprBuilder::MakeFunction(const std::string& name, const NodeVector& params,
+ DataTypePtr result_type) {
+ if (result_type == nullptr) {
+ return nullptr;
+ }
+ return std::make_shared<FunctionNode>(name, params, result_type);
+}
+
+NodePtr TreeExprBuilder::MakeIf(NodePtr condition, NodePtr then_node, NodePtr else_node,
+ DataTypePtr result_type) {
+ if (condition == nullptr || then_node == nullptr || else_node == nullptr ||
+ result_type == nullptr) {
+ return nullptr;
+ }
+ return std::make_shared<IfNode>(condition, then_node, else_node, result_type);
+}
+
+NodePtr TreeExprBuilder::MakeAnd(const NodeVector& children) {
+ return std::make_shared<BooleanNode>(BooleanNode::AND, children);
+}
+
+NodePtr TreeExprBuilder::MakeOr(const NodeVector& children) {
+ return std::make_shared<BooleanNode>(BooleanNode::OR, children);
+}
+
+// set this to true to print expressions for debugging purposes
+static bool print_expr = false;
+
+ExpressionPtr TreeExprBuilder::MakeExpression(NodePtr root_node, FieldPtr result_field) {
+ if (result_field == nullptr) {
+ return nullptr;
+ }
+ if (print_expr) {
+ std::cout << "Expression: " << root_node->ToString() << "\n";
+ }
+ return ExpressionPtr(new Expression(root_node, result_field));
+}
+
+ExpressionPtr TreeExprBuilder::MakeExpression(const std::string& function,
+ const FieldVector& in_fields,
+ FieldPtr out_field) {
+ if (out_field == nullptr) {
+ return nullptr;
+ }
+ std::vector<NodePtr> field_nodes;
+ for (auto& field : in_fields) {
+ auto node = MakeField(field);
+ field_nodes.push_back(node);
+ }
+ auto func_node = MakeFunction(function, field_nodes, out_field->type());
+ return MakeExpression(func_node, out_field);
+}
+
+ConditionPtr TreeExprBuilder::MakeCondition(NodePtr root_node) {
+ if (root_node == nullptr) {
+ return nullptr;
+ }
+ if (print_expr) {
+ std::cout << "Condition: " << root_node->ToString() << "\n";
+ }
+
+ return ConditionPtr(new Condition(root_node));
+}
+
+ConditionPtr TreeExprBuilder::MakeCondition(const std::string& function,
+ const FieldVector& in_fields) {
+ std::vector<NodePtr> field_nodes;
+ for (auto& field : in_fields) {
+ auto node = MakeField(field);
+ field_nodes.push_back(node);
+ }
+
+ auto func_node = MakeFunction(function, field_nodes, arrow::boolean());
+ return ConditionPtr(new Condition(func_node));
+}
+
+NodePtr TreeExprBuilder::MakeInExpressionDecimal(
+ NodePtr node, std::unordered_set<gandiva::DecimalScalar128>& constants) {
+ int32_t precision = 0;
+ int32_t scale = 0;
+ if (!constants.empty()) {
+ precision = constants.begin()->precision();
+ scale = constants.begin()->scale();
+ }
+ return std::make_shared<InExpressionNode<gandiva::DecimalScalar128>>(node, constants,
+ precision, scale);
+}
+
+#define MAKE_IN(NAME, ctype) \
+ NodePtr TreeExprBuilder::MakeInExpression##NAME( \
+ NodePtr node, const std::unordered_set<ctype>& values) { \
+ return std::make_shared<InExpressionNode<ctype>>(node, values); \
+ }
+
+MAKE_IN(Int32, int32_t);
+MAKE_IN(Int64, int64_t);
+MAKE_IN(Date32, int32_t);
+MAKE_IN(Date64, int64_t);
+MAKE_IN(TimeStamp, int64_t);
+MAKE_IN(Time32, int32_t);
+MAKE_IN(Time64, int64_t);
+MAKE_IN(Float, float);
+MAKE_IN(Double, double);
+MAKE_IN(String, std::string);
+MAKE_IN(Binary, std::string);
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tree_expr_builder.h b/src/arrow/cpp/src/gandiva/tree_expr_builder.h
new file mode 100644
index 000000000..94a4a1793
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tree_expr_builder.h
@@ -0,0 +1,139 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cmath>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "arrow/type.h"
+#include "gandiva/condition.h"
+#include "gandiva/decimal_scalar.h"
+#include "gandiva/expression.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// \brief Tree Builder for a nested expression.
+class GANDIVA_EXPORT TreeExprBuilder {
+ public:
+ /// \brief create a node on a literal.
+ static NodePtr MakeLiteral(bool value);
+ static NodePtr MakeLiteral(uint8_t value);
+ static NodePtr MakeLiteral(uint16_t value);
+ static NodePtr MakeLiteral(uint32_t value);
+ static NodePtr MakeLiteral(uint64_t value);
+ static NodePtr MakeLiteral(int8_t value);
+ static NodePtr MakeLiteral(int16_t value);
+ static NodePtr MakeLiteral(int32_t value);
+ static NodePtr MakeLiteral(int64_t value);
+ static NodePtr MakeLiteral(float value);
+ static NodePtr MakeLiteral(double value);
+ static NodePtr MakeStringLiteral(const std::string& value);
+ static NodePtr MakeBinaryLiteral(const std::string& value);
+ static NodePtr MakeDecimalLiteral(const DecimalScalar128& value);
+
+ /// \brief create a node on a null literal.
+ /// returns null if data_type is null or if it's not a supported datatype.
+ static NodePtr MakeNull(DataTypePtr data_type);
+
+ /// \brief create a node on arrow field.
+ /// returns null if input is null.
+ static NodePtr MakeField(FieldPtr field);
+
+ /// \brief create a node with a function.
+ /// returns null if return_type is null
+ static NodePtr MakeFunction(const std::string& name, const NodeVector& params,
+ DataTypePtr return_type);
+
+ /// \brief create a node with an if-else expression.
+ /// returns null if any of the inputs is null.
+ static NodePtr MakeIf(NodePtr condition, NodePtr then_node, NodePtr else_node,
+ DataTypePtr result_type);
+
+ /// \brief create a node with a boolean AND expression.
+ static NodePtr MakeAnd(const NodeVector& children);
+
+ /// \brief create a node with a boolean OR expression.
+ static NodePtr MakeOr(const NodeVector& children);
+
+ /// \brief create an expression with the specified root_node, and the
+ /// result written to result_field.
+ /// returns null if the result_field is null.
+ static ExpressionPtr MakeExpression(NodePtr root_node, FieldPtr result_field);
+
+ /// \brief convenience function for simple function expressions.
+ /// returns null if the out_field is null.
+ static ExpressionPtr MakeExpression(const std::string& function,
+ const FieldVector& in_fields, FieldPtr out_field);
+
+ /// \brief create a condition with the specified root_node
+ static ConditionPtr MakeCondition(NodePtr root_node);
+
+ /// \brief convenience function for simple function conditions.
+ static ConditionPtr MakeCondition(const std::string& function,
+ const FieldVector& in_fields);
+
+ /// \brief creates an in expression
+ static NodePtr MakeInExpressionInt32(NodePtr node,
+ const std::unordered_set<int32_t>& constants);
+
+ static NodePtr MakeInExpressionInt64(NodePtr node,
+ const std::unordered_set<int64_t>& constants);
+
+ static NodePtr MakeInExpressionDecimal(
+ NodePtr node, std::unordered_set<gandiva::DecimalScalar128>& constants);
+
+ static NodePtr MakeInExpressionString(NodePtr node,
+ const std::unordered_set<std::string>& constants);
+
+ static NodePtr MakeInExpressionBinary(NodePtr node,
+ const std::unordered_set<std::string>& constants);
+
+ /// \brief creates an in expression for float
+ static NodePtr MakeInExpressionFloat(NodePtr node,
+ const std::unordered_set<float>& constants);
+
+ /// \brief creates an in expression for double
+ static NodePtr MakeInExpressionDouble(NodePtr node,
+ const std::unordered_set<double>& constants);
+
+ /// \brief Date as s/millis since epoch.
+ static NodePtr MakeInExpressionDate32(NodePtr node,
+ const std::unordered_set<int32_t>& constants);
+
+ /// \brief Date as millis/us/ns since epoch.
+ static NodePtr MakeInExpressionDate64(NodePtr node,
+ const std::unordered_set<int64_t>& constants);
+
+ /// \brief Time as s/millis of day
+ static NodePtr MakeInExpressionTime32(NodePtr node,
+ const std::unordered_set<int32_t>& constants);
+
+ /// \brief Time as millis/us/ns of day
+ static NodePtr MakeInExpressionTime64(NodePtr node,
+ const std::unordered_set<int64_t>& constants);
+
+ /// \brief Timestamp as millis since epoch.
+ static NodePtr MakeInExpressionTimeStamp(NodePtr node,
+ const std::unordered_set<int64_t>& constants);
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tree_expr_test.cc b/src/arrow/cpp/src/gandiva/tree_expr_test.cc
new file mode 100644
index 000000000..e70cf1289
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tree_expr_test.cc
@@ -0,0 +1,159 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/tree_expr_builder.h"
+
+#include <gtest/gtest.h>
+#include "gandiva/annotator.h"
+#include "gandiva/dex.h"
+#include "gandiva/expr_decomposer.h"
+#include "gandiva/function_registry.h"
+#include "gandiva/function_signature.h"
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/node.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::int32;
+
+class TestExprTree : public ::testing::Test {
+ public:
+ void SetUp() {
+ i0_ = field("i0", int32());
+ i1_ = field("i1", int32());
+
+ b0_ = field("b0", boolean());
+ }
+
+ protected:
+ FieldPtr i0_; // int32
+ FieldPtr i1_; // int32
+
+ FieldPtr b0_; // bool
+ FunctionRegistry registry_;
+};
+
+TEST_F(TestExprTree, TestField) {
+ Annotator annotator;
+
+ auto n0 = TreeExprBuilder::MakeField(i0_);
+ EXPECT_EQ(n0->return_type(), int32());
+
+ auto n1 = TreeExprBuilder::MakeField(b0_);
+ EXPECT_EQ(n1->return_type(), boolean());
+
+ ExprDecomposer decomposer(registry_, annotator);
+ ValueValidityPairPtr pair;
+ auto status = decomposer.Decompose(*n1, &pair);
+ DCHECK_EQ(status.ok(), true) << status.message();
+
+ auto value = pair->value_expr();
+ auto value_dex = std::dynamic_pointer_cast<VectorReadFixedLenValueDex>(value);
+ EXPECT_EQ(value_dex->FieldType(), boolean());
+
+ EXPECT_EQ(pair->validity_exprs().size(), 1);
+ auto validity = pair->validity_exprs().at(0);
+ auto validity_dex = std::dynamic_pointer_cast<VectorReadValidityDex>(validity);
+ EXPECT_NE(validity_dex->ValidityIdx(), value_dex->DataIdx());
+}
+
+TEST_F(TestExprTree, TestBinary) {
+ Annotator annotator;
+
+ auto left = TreeExprBuilder::MakeField(i0_);
+ auto right = TreeExprBuilder::MakeField(i1_);
+
+ auto n = TreeExprBuilder::MakeFunction("add", {left, right}, int32());
+ auto add = std::dynamic_pointer_cast<FunctionNode>(n);
+
+ auto func_desc = add->descriptor();
+ FunctionSignature sign(func_desc->name(), func_desc->params(),
+ func_desc->return_type());
+
+ EXPECT_EQ(add->return_type(), int32());
+ EXPECT_TRUE(sign == FunctionSignature("add", {int32(), int32()}, int32()));
+
+ ExprDecomposer decomposer(registry_, annotator);
+ ValueValidityPairPtr pair;
+ auto status = decomposer.Decompose(*n, &pair);
+ DCHECK_EQ(status.ok(), true) << status.message();
+
+ auto value = pair->value_expr();
+ auto null_if_null = std::dynamic_pointer_cast<NonNullableFuncDex>(value);
+
+ FunctionSignature signature("add", {int32(), int32()}, int32());
+ const NativeFunction* fn = registry_.LookupSignature(signature);
+ EXPECT_EQ(null_if_null->native_function(), fn);
+}
+
+TEST_F(TestExprTree, TestUnary) {
+ Annotator annotator;
+
+ auto arg = TreeExprBuilder::MakeField(i0_);
+ auto n = TreeExprBuilder::MakeFunction("isnumeric", {arg}, boolean());
+
+ auto unaryFn = std::dynamic_pointer_cast<FunctionNode>(n);
+ auto func_desc = unaryFn->descriptor();
+ FunctionSignature sign(func_desc->name(), func_desc->params(),
+ func_desc->return_type());
+ EXPECT_EQ(unaryFn->return_type(), boolean());
+ EXPECT_TRUE(sign == FunctionSignature("isnumeric", {int32()}, boolean()));
+
+ ExprDecomposer decomposer(registry_, annotator);
+ ValueValidityPairPtr pair;
+ auto status = decomposer.Decompose(*n, &pair);
+ DCHECK_EQ(status.ok(), true) << status.message();
+
+ auto value = pair->value_expr();
+ auto never_null = std::dynamic_pointer_cast<NullableNeverFuncDex>(value);
+
+ FunctionSignature signature("isnumeric", {int32()}, boolean());
+ const NativeFunction* fn = registry_.LookupSignature(signature);
+ EXPECT_EQ(never_null->native_function(), fn);
+}
+
+TEST_F(TestExprTree, TestExpression) {
+ Annotator annotator;
+ auto left = TreeExprBuilder::MakeField(i0_);
+ auto right = TreeExprBuilder::MakeField(i1_);
+
+ auto n = TreeExprBuilder::MakeFunction("add", {left, right}, int32());
+ auto e = TreeExprBuilder::MakeExpression(n, field("r", int32()));
+ auto root_node = e->root();
+ EXPECT_EQ(root_node->return_type(), int32());
+
+ auto add_node = std::dynamic_pointer_cast<FunctionNode>(root_node);
+ auto func_desc = add_node->descriptor();
+ FunctionSignature sign(func_desc->name(), func_desc->params(),
+ func_desc->return_type());
+ EXPECT_TRUE(sign == FunctionSignature("add", {int32(), int32()}, int32()));
+
+ ExprDecomposer decomposer(registry_, annotator);
+ ValueValidityPairPtr pair;
+ auto status = decomposer.Decompose(*root_node, &pair);
+ DCHECK_EQ(status.ok(), true) << status.message();
+
+ auto value = pair->value_expr();
+ auto null_if_null = std::dynamic_pointer_cast<NonNullableFuncDex>(value);
+
+ FunctionSignature signature("add", {int32(), int32()}, int32());
+ const NativeFunction* fn = registry_.LookupSignature(signature);
+ EXPECT_EQ(null_if_null->native_function(), fn);
+}
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/value_validity_pair.h b/src/arrow/cpp/src/gandiva/value_validity_pair.h
new file mode 100644
index 000000000..e5943b230
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/value_validity_pair.h
@@ -0,0 +1,48 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <vector>
+
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/// Pair of vector/validities generated after decomposing an expression tree/subtree.
+class GANDIVA_EXPORT ValueValidityPair {
+ public:
+ ValueValidityPair(const DexVector& validity_exprs, DexPtr value_expr)
+ : validity_exprs_(validity_exprs), value_expr_(value_expr) {}
+
+ ValueValidityPair(DexPtr validity_expr, DexPtr value_expr) : value_expr_(value_expr) {
+ validity_exprs_.push_back(validity_expr);
+ }
+
+ explicit ValueValidityPair(DexPtr value_expr) : value_expr_(value_expr) {}
+
+ const DexVector& validity_exprs() const { return validity_exprs_; }
+
+ const DexPtr& value_expr() const { return value_expr_; }
+
+ private:
+ DexVector validity_exprs_;
+ DexPtr value_expr_;
+};
+
+} // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/visibility.h b/src/arrow/cpp/src/gandiva/visibility.h
new file mode 100644
index 000000000..450b3056b
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/visibility.h
@@ -0,0 +1,48 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#if defined(_WIN32) || defined(__CYGWIN__)
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4251)
+#else
+#pragma GCC diagnostic ignored "-Wattributes"
+#endif
+
+#ifdef GANDIVA_STATIC
+#define GANDIVA_EXPORT
+#elif defined(GANDIVA_EXPORTING)
+#define GANDIVA_EXPORT __declspec(dllexport)
+#else
+#define GANDIVA_EXPORT __declspec(dllimport)
+#endif
+
+#define GANDIVA_NO_EXPORT
+#else // Not Windows
+#ifndef GANDIVA_EXPORT
+#define GANDIVA_EXPORT __attribute__((visibility("default")))
+#endif
+#ifndef GANDIVA_NO_EXPORT
+#define GANDIVA_NO_EXPORT __attribute__((visibility("hidden")))
+#endif
+#endif // Non-Windows
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
diff --git a/src/arrow/cpp/src/generated/Expression_generated.h b/src/arrow/cpp/src/generated/Expression_generated.h
new file mode 100644
index 000000000..730b00149
--- /dev/null
+++ b/src/arrow/cpp/src/generated/Expression_generated.h
@@ -0,0 +1,1870 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_EXPRESSION_ORG_APACHE_ARROW_COMPUTEIR_FLATBUF_H_
+#define FLATBUFFERS_GENERATED_EXPRESSION_ORG_APACHE_ARROW_COMPUTEIR_FLATBUF_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+#include "Schema_generated.h"
+#include "Literal_generated.h"
+
+namespace org {
+namespace apache {
+namespace arrow {
+namespace computeir {
+namespace flatbuf {
+
+struct MapKey;
+struct MapKeyBuilder;
+
+struct StructField;
+struct StructFieldBuilder;
+
+struct ArraySubscript;
+struct ArraySubscriptBuilder;
+
+struct ArraySlice;
+struct ArraySliceBuilder;
+
+struct FieldIndex;
+struct FieldIndexBuilder;
+
+struct FieldRef;
+struct FieldRefBuilder;
+
+struct Call;
+struct CallBuilder;
+
+struct CaseFragment;
+struct CaseFragmentBuilder;
+
+struct ConditionalCase;
+struct ConditionalCaseBuilder;
+
+struct SimpleCase;
+struct SimpleCaseBuilder;
+
+struct SortKey;
+struct SortKeyBuilder;
+
+struct Unbounded;
+struct UnboundedBuilder;
+
+struct Preceding;
+struct PrecedingBuilder;
+
+struct Following;
+struct FollowingBuilder;
+
+struct CurrentRow;
+struct CurrentRowBuilder;
+
+struct WindowCall;
+struct WindowCallBuilder;
+
+struct Cast;
+struct CastBuilder;
+
+struct Expression;
+struct ExpressionBuilder;
+
+/// A union of possible dereference operations
+enum class Deref : uint8_t {
+ NONE = 0,
+ /// Access a value for a given map key
+ MapKey = 1,
+ /// Access the value at a struct field
+ StructField = 2,
+ /// Access the element at a given index in an array
+ ArraySubscript = 3,
+ /// Access a range of elements in an array
+ ArraySlice = 4,
+ /// Access a field of a relation
+ FieldIndex = 5,
+ MIN = NONE,
+ MAX = FieldIndex
+};
+
+inline const Deref (&EnumValuesDeref())[6] {
+ static const Deref values[] = {
+ Deref::NONE,
+ Deref::MapKey,
+ Deref::StructField,
+ Deref::ArraySubscript,
+ Deref::ArraySlice,
+ Deref::FieldIndex
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesDeref() {
+ static const char * const names[7] = {
+ "NONE",
+ "MapKey",
+ "StructField",
+ "ArraySubscript",
+ "ArraySlice",
+ "FieldIndex",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameDeref(Deref e) {
+ if (flatbuffers::IsOutRange(e, Deref::NONE, Deref::FieldIndex)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesDeref()[index];
+}
+
+template<typename T> struct DerefTraits {
+ static const Deref enum_value = Deref::NONE;
+};
+
+template<> struct DerefTraits<org::apache::arrow::computeir::flatbuf::MapKey> {
+ static const Deref enum_value = Deref::MapKey;
+};
+
+template<> struct DerefTraits<org::apache::arrow::computeir::flatbuf::StructField> {
+ static const Deref enum_value = Deref::StructField;
+};
+
+template<> struct DerefTraits<org::apache::arrow::computeir::flatbuf::ArraySubscript> {
+ static const Deref enum_value = Deref::ArraySubscript;
+};
+
+template<> struct DerefTraits<org::apache::arrow::computeir::flatbuf::ArraySlice> {
+ static const Deref enum_value = Deref::ArraySlice;
+};
+
+template<> struct DerefTraits<org::apache::arrow::computeir::flatbuf::FieldIndex> {
+ static const Deref enum_value = Deref::FieldIndex;
+};
+
+bool VerifyDeref(flatbuffers::Verifier &verifier, const void *obj, Deref type);
+bool VerifyDerefVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+/// Whether lesser values should precede greater or vice versa,
+/// also whether nulls should preced or follow values
+enum class Ordering : uint8_t {
+ ASCENDING_THEN_NULLS = 0,
+ DESCENDING_THEN_NULLS = 1,
+ NULLS_THEN_ASCENDING = 2,
+ NULLS_THEN_DESCENDING = 3,
+ MIN = ASCENDING_THEN_NULLS,
+ MAX = NULLS_THEN_DESCENDING
+};
+
+inline const Ordering (&EnumValuesOrdering())[4] {
+ static const Ordering values[] = {
+ Ordering::ASCENDING_THEN_NULLS,
+ Ordering::DESCENDING_THEN_NULLS,
+ Ordering::NULLS_THEN_ASCENDING,
+ Ordering::NULLS_THEN_DESCENDING
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesOrdering() {
+ static const char * const names[5] = {
+ "ASCENDING_THEN_NULLS",
+ "DESCENDING_THEN_NULLS",
+ "NULLS_THEN_ASCENDING",
+ "NULLS_THEN_DESCENDING",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameOrdering(Ordering e) {
+ if (flatbuffers::IsOutRange(e, Ordering::ASCENDING_THEN_NULLS, Ordering::NULLS_THEN_DESCENDING)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesOrdering()[index];
+}
+
+/// A concrete bound, which can be an expression or unbounded
+enum class ConcreteBoundImpl : uint8_t {
+ NONE = 0,
+ Expression = 1,
+ Unbounded = 2,
+ MIN = NONE,
+ MAX = Unbounded
+};
+
+inline const ConcreteBoundImpl (&EnumValuesConcreteBoundImpl())[3] {
+ static const ConcreteBoundImpl values[] = {
+ ConcreteBoundImpl::NONE,
+ ConcreteBoundImpl::Expression,
+ ConcreteBoundImpl::Unbounded
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesConcreteBoundImpl() {
+ static const char * const names[4] = {
+ "NONE",
+ "Expression",
+ "Unbounded",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameConcreteBoundImpl(ConcreteBoundImpl e) {
+ if (flatbuffers::IsOutRange(e, ConcreteBoundImpl::NONE, ConcreteBoundImpl::Unbounded)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesConcreteBoundImpl()[index];
+}
+
+template<typename T> struct ConcreteBoundImplTraits {
+ static const ConcreteBoundImpl enum_value = ConcreteBoundImpl::NONE;
+};
+
+template<> struct ConcreteBoundImplTraits<org::apache::arrow::computeir::flatbuf::Expression> {
+ static const ConcreteBoundImpl enum_value = ConcreteBoundImpl::Expression;
+};
+
+template<> struct ConcreteBoundImplTraits<org::apache::arrow::computeir::flatbuf::Unbounded> {
+ static const ConcreteBoundImpl enum_value = ConcreteBoundImpl::Unbounded;
+};
+
+bool VerifyConcreteBoundImpl(flatbuffers::Verifier &verifier, const void *obj, ConcreteBoundImpl type);
+bool VerifyConcreteBoundImplVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+enum class Bound : uint8_t {
+ NONE = 0,
+ Preceding = 1,
+ Following = 2,
+ CurrentRow = 3,
+ MIN = NONE,
+ MAX = CurrentRow
+};
+
+inline const Bound (&EnumValuesBound())[4] {
+ static const Bound values[] = {
+ Bound::NONE,
+ Bound::Preceding,
+ Bound::Following,
+ Bound::CurrentRow
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesBound() {
+ static const char * const names[5] = {
+ "NONE",
+ "Preceding",
+ "Following",
+ "CurrentRow",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameBound(Bound e) {
+ if (flatbuffers::IsOutRange(e, Bound::NONE, Bound::CurrentRow)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesBound()[index];
+}
+
+template<typename T> struct BoundTraits {
+ static const Bound enum_value = Bound::NONE;
+};
+
+template<> struct BoundTraits<org::apache::arrow::computeir::flatbuf::Preceding> {
+ static const Bound enum_value = Bound::Preceding;
+};
+
+template<> struct BoundTraits<org::apache::arrow::computeir::flatbuf::Following> {
+ static const Bound enum_value = Bound::Following;
+};
+
+template<> struct BoundTraits<org::apache::arrow::computeir::flatbuf::CurrentRow> {
+ static const Bound enum_value = Bound::CurrentRow;
+};
+
+bool VerifyBound(flatbuffers::Verifier &verifier, const void *obj, Bound type);
+bool VerifyBoundVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+/// The kind of window function to be executed
+enum class Frame : uint8_t {
+ Rows = 0,
+ Range = 1,
+ MIN = Rows,
+ MAX = Range
+};
+
+inline const Frame (&EnumValuesFrame())[2] {
+ static const Frame values[] = {
+ Frame::Rows,
+ Frame::Range
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesFrame() {
+ static const char * const names[3] = {
+ "Rows",
+ "Range",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameFrame(Frame e) {
+ if (flatbuffers::IsOutRange(e, Frame::Rows, Frame::Range)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesFrame()[index];
+}
+
+/// Various expression types
+///
+/// WindowCall is a separate variant
+/// due to special options for each that don't apply to generic
+/// function calls. Again this is done to make it easier
+/// for consumers to deal with the structure of the operation
+enum class ExpressionImpl : uint8_t {
+ NONE = 0,
+ Literal = 1,
+ FieldRef = 2,
+ Call = 3,
+ ConditionalCase = 4,
+ SimpleCase = 5,
+ WindowCall = 6,
+ Cast = 7,
+ MIN = NONE,
+ MAX = Cast
+};
+
+inline const ExpressionImpl (&EnumValuesExpressionImpl())[8] {
+ static const ExpressionImpl values[] = {
+ ExpressionImpl::NONE,
+ ExpressionImpl::Literal,
+ ExpressionImpl::FieldRef,
+ ExpressionImpl::Call,
+ ExpressionImpl::ConditionalCase,
+ ExpressionImpl::SimpleCase,
+ ExpressionImpl::WindowCall,
+ ExpressionImpl::Cast
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesExpressionImpl() {
+ static const char * const names[9] = {
+ "NONE",
+ "Literal",
+ "FieldRef",
+ "Call",
+ "ConditionalCase",
+ "SimpleCase",
+ "WindowCall",
+ "Cast",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameExpressionImpl(ExpressionImpl e) {
+ if (flatbuffers::IsOutRange(e, ExpressionImpl::NONE, ExpressionImpl::Cast)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesExpressionImpl()[index];
+}
+
+template<typename T> struct ExpressionImplTraits {
+ static const ExpressionImpl enum_value = ExpressionImpl::NONE;
+};
+
+template<> struct ExpressionImplTraits<org::apache::arrow::computeir::flatbuf::Literal> {
+ static const ExpressionImpl enum_value = ExpressionImpl::Literal;
+};
+
+template<> struct ExpressionImplTraits<org::apache::arrow::computeir::flatbuf::FieldRef> {
+ static const ExpressionImpl enum_value = ExpressionImpl::FieldRef;
+};
+
+template<> struct ExpressionImplTraits<org::apache::arrow::computeir::flatbuf::Call> {
+ static const ExpressionImpl enum_value = ExpressionImpl::Call;
+};
+
+template<> struct ExpressionImplTraits<org::apache::arrow::computeir::flatbuf::ConditionalCase> {
+ static const ExpressionImpl enum_value = ExpressionImpl::ConditionalCase;
+};
+
+template<> struct ExpressionImplTraits<org::apache::arrow::computeir::flatbuf::SimpleCase> {
+ static const ExpressionImpl enum_value = ExpressionImpl::SimpleCase;
+};
+
+template<> struct ExpressionImplTraits<org::apache::arrow::computeir::flatbuf::WindowCall> {
+ static const ExpressionImpl enum_value = ExpressionImpl::WindowCall;
+};
+
+template<> struct ExpressionImplTraits<org::apache::arrow::computeir::flatbuf::Cast> {
+ static const ExpressionImpl enum_value = ExpressionImpl::Cast;
+};
+
+bool VerifyExpressionImpl(flatbuffers::Verifier &verifier, const void *obj, ExpressionImpl type);
+bool VerifyExpressionImplVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+/// Access a value for a given map key
+struct MapKey FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef MapKeyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_KEY = 4
+ };
+ /// Any expression can be a map key.
+ const org::apache::arrow::computeir::flatbuf::Expression *key() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Expression *>(VT_KEY);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_KEY) &&
+ verifier.VerifyTable(key()) &&
+ verifier.EndTable();
+ }
+};
+
+struct MapKeyBuilder {
+ typedef MapKey Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_key(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> key) {
+ fbb_.AddOffset(MapKey::VT_KEY, key);
+ }
+ explicit MapKeyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ MapKeyBuilder &operator=(const MapKeyBuilder &);
+ flatbuffers::Offset<MapKey> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<MapKey>(end);
+ fbb_.Required(o, MapKey::VT_KEY);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<MapKey> CreateMapKey(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> key = 0) {
+ MapKeyBuilder builder_(_fbb);
+ builder_.add_key(key);
+ return builder_.Finish();
+}
+
+/// Struct field access
+struct StructField FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef StructFieldBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_POSITION = 4
+ };
+ /// The position of the field in the struct schema
+ uint32_t position() const {
+ return GetField<uint32_t>(VT_POSITION, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint32_t>(verifier, VT_POSITION) &&
+ verifier.EndTable();
+ }
+};
+
+struct StructFieldBuilder {
+ typedef StructField Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_position(uint32_t position) {
+ fbb_.AddElement<uint32_t>(StructField::VT_POSITION, position, 0);
+ }
+ explicit StructFieldBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ StructFieldBuilder &operator=(const StructFieldBuilder &);
+ flatbuffers::Offset<StructField> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<StructField>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<StructField> CreateStructField(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ uint32_t position = 0) {
+ StructFieldBuilder builder_(_fbb);
+ builder_.add_position(position);
+ return builder_.Finish();
+}
+
+/// Zero-based array index
+struct ArraySubscript FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ArraySubscriptBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_POSITION = 4
+ };
+ uint32_t position() const {
+ return GetField<uint32_t>(VT_POSITION, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint32_t>(verifier, VT_POSITION) &&
+ verifier.EndTable();
+ }
+};
+
+struct ArraySubscriptBuilder {
+ typedef ArraySubscript Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_position(uint32_t position) {
+ fbb_.AddElement<uint32_t>(ArraySubscript::VT_POSITION, position, 0);
+ }
+ explicit ArraySubscriptBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ArraySubscriptBuilder &operator=(const ArraySubscriptBuilder &);
+ flatbuffers::Offset<ArraySubscript> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ArraySubscript>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ArraySubscript> CreateArraySubscript(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ uint32_t position = 0) {
+ ArraySubscriptBuilder builder_(_fbb);
+ builder_.add_position(position);
+ return builder_.Finish();
+}
+
+/// Zero-based range of elements in an array
+struct ArraySlice FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ArraySliceBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_START_INCLUSIVE = 4,
+ VT_END_EXCLUSIVE = 6
+ };
+ /// The start of an array slice, inclusive
+ uint32_t start_inclusive() const {
+ return GetField<uint32_t>(VT_START_INCLUSIVE, 0);
+ }
+ /// The end of an array slice, exclusive
+ uint32_t end_exclusive() const {
+ return GetField<uint32_t>(VT_END_EXCLUSIVE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint32_t>(verifier, VT_START_INCLUSIVE) &&
+ VerifyField<uint32_t>(verifier, VT_END_EXCLUSIVE) &&
+ verifier.EndTable();
+ }
+};
+
+struct ArraySliceBuilder {
+ typedef ArraySlice Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_start_inclusive(uint32_t start_inclusive) {
+ fbb_.AddElement<uint32_t>(ArraySlice::VT_START_INCLUSIVE, start_inclusive, 0);
+ }
+ void add_end_exclusive(uint32_t end_exclusive) {
+ fbb_.AddElement<uint32_t>(ArraySlice::VT_END_EXCLUSIVE, end_exclusive, 0);
+ }
+ explicit ArraySliceBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ArraySliceBuilder &operator=(const ArraySliceBuilder &);
+ flatbuffers::Offset<ArraySlice> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ArraySlice>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ArraySlice> CreateArraySlice(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ uint32_t start_inclusive = 0,
+ uint32_t end_exclusive = 0) {
+ ArraySliceBuilder builder_(_fbb);
+ builder_.add_end_exclusive(end_exclusive);
+ builder_.add_start_inclusive(start_inclusive);
+ return builder_.Finish();
+}
+
+/// Field name in a relation, in ordinal position of the relation's schema.
+struct FieldIndex FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FieldIndexBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_POSITION = 4
+ };
+ uint32_t position() const {
+ return GetField<uint32_t>(VT_POSITION, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint32_t>(verifier, VT_POSITION) &&
+ verifier.EndTable();
+ }
+};
+
+struct FieldIndexBuilder {
+ typedef FieldIndex Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_position(uint32_t position) {
+ fbb_.AddElement<uint32_t>(FieldIndex::VT_POSITION, position, 0);
+ }
+ explicit FieldIndexBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FieldIndexBuilder &operator=(const FieldIndexBuilder &);
+ flatbuffers::Offset<FieldIndex> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<FieldIndex>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<FieldIndex> CreateFieldIndex(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ uint32_t position = 0) {
+ FieldIndexBuilder builder_(_fbb);
+ builder_.add_position(position);
+ return builder_.Finish();
+}
+
+/// Access the data of a field
+struct FieldRef FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FieldRefBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_REF_TYPE = 4,
+ VT_REF = 6,
+ VT_RELATION_INDEX = 8
+ };
+ org::apache::arrow::computeir::flatbuf::Deref ref_type() const {
+ return static_cast<org::apache::arrow::computeir::flatbuf::Deref>(GetField<uint8_t>(VT_REF_TYPE, 0));
+ }
+ const void *ref() const {
+ return GetPointer<const void *>(VT_REF);
+ }
+ template<typename T> const T *ref_as() const;
+ const org::apache::arrow::computeir::flatbuf::MapKey *ref_as_MapKey() const {
+ return ref_type() == org::apache::arrow::computeir::flatbuf::Deref::MapKey ? static_cast<const org::apache::arrow::computeir::flatbuf::MapKey *>(ref()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::StructField *ref_as_StructField() const {
+ return ref_type() == org::apache::arrow::computeir::flatbuf::Deref::StructField ? static_cast<const org::apache::arrow::computeir::flatbuf::StructField *>(ref()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::ArraySubscript *ref_as_ArraySubscript() const {
+ return ref_type() == org::apache::arrow::computeir::flatbuf::Deref::ArraySubscript ? static_cast<const org::apache::arrow::computeir::flatbuf::ArraySubscript *>(ref()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::ArraySlice *ref_as_ArraySlice() const {
+ return ref_type() == org::apache::arrow::computeir::flatbuf::Deref::ArraySlice ? static_cast<const org::apache::arrow::computeir::flatbuf::ArraySlice *>(ref()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::FieldIndex *ref_as_FieldIndex() const {
+ return ref_type() == org::apache::arrow::computeir::flatbuf::Deref::FieldIndex ? static_cast<const org::apache::arrow::computeir::flatbuf::FieldIndex *>(ref()) : nullptr;
+ }
+ /// For Expressions which might reference fields in multiple Relations,
+ /// this index may be provided to indicate which Relation's fields
+ /// `ref` points into. For example in the case of a join,
+ /// 0 refers to the left relation and 1 to the right relation.
+ int32_t relation_index() const {
+ return GetField<int32_t>(VT_RELATION_INDEX, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_REF_TYPE) &&
+ VerifyOffsetRequired(verifier, VT_REF) &&
+ VerifyDeref(verifier, ref(), ref_type()) &&
+ VerifyField<int32_t>(verifier, VT_RELATION_INDEX) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const org::apache::arrow::computeir::flatbuf::MapKey *FieldRef::ref_as<org::apache::arrow::computeir::flatbuf::MapKey>() const {
+ return ref_as_MapKey();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::StructField *FieldRef::ref_as<org::apache::arrow::computeir::flatbuf::StructField>() const {
+ return ref_as_StructField();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::ArraySubscript *FieldRef::ref_as<org::apache::arrow::computeir::flatbuf::ArraySubscript>() const {
+ return ref_as_ArraySubscript();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::ArraySlice *FieldRef::ref_as<org::apache::arrow::computeir::flatbuf::ArraySlice>() const {
+ return ref_as_ArraySlice();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::FieldIndex *FieldRef::ref_as<org::apache::arrow::computeir::flatbuf::FieldIndex>() const {
+ return ref_as_FieldIndex();
+}
+
+struct FieldRefBuilder {
+ typedef FieldRef Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_ref_type(org::apache::arrow::computeir::flatbuf::Deref ref_type) {
+ fbb_.AddElement<uint8_t>(FieldRef::VT_REF_TYPE, static_cast<uint8_t>(ref_type), 0);
+ }
+ void add_ref(flatbuffers::Offset<void> ref) {
+ fbb_.AddOffset(FieldRef::VT_REF, ref);
+ }
+ void add_relation_index(int32_t relation_index) {
+ fbb_.AddElement<int32_t>(FieldRef::VT_RELATION_INDEX, relation_index, 0);
+ }
+ explicit FieldRefBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FieldRefBuilder &operator=(const FieldRefBuilder &);
+ flatbuffers::Offset<FieldRef> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<FieldRef>(end);
+ fbb_.Required(o, FieldRef::VT_REF);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<FieldRef> CreateFieldRef(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::computeir::flatbuf::Deref ref_type = org::apache::arrow::computeir::flatbuf::Deref::NONE,
+ flatbuffers::Offset<void> ref = 0,
+ int32_t relation_index = 0) {
+ FieldRefBuilder builder_(_fbb);
+ builder_.add_relation_index(relation_index);
+ builder_.add_ref(ref);
+ builder_.add_ref_type(ref_type);
+ return builder_.Finish();
+}
+
+/// A function call expression
+struct Call FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef CallBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_NAME = 4,
+ VT_ARGUMENTS = 6,
+ VT_ORDERINGS = 8
+ };
+ /// The function to call
+ const flatbuffers::String *name() const {
+ return GetPointer<const flatbuffers::String *>(VT_NAME);
+ }
+ /// The arguments passed to `name`.
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>> *arguments() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>> *>(VT_ARGUMENTS);
+ }
+ /// Possible ordering of input. These are useful
+ /// in aggregates where ordering in meaningful such as
+ /// string concatenation
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>> *orderings() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>> *>(VT_ORDERINGS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_NAME) &&
+ verifier.VerifyString(name()) &&
+ VerifyOffsetRequired(verifier, VT_ARGUMENTS) &&
+ verifier.VerifyVector(arguments()) &&
+ verifier.VerifyVectorOfTables(arguments()) &&
+ VerifyOffset(verifier, VT_ORDERINGS) &&
+ verifier.VerifyVector(orderings()) &&
+ verifier.VerifyVectorOfTables(orderings()) &&
+ verifier.EndTable();
+ }
+};
+
+struct CallBuilder {
+ typedef Call Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_name(flatbuffers::Offset<flatbuffers::String> name) {
+ fbb_.AddOffset(Call::VT_NAME, name);
+ }
+ void add_arguments(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>>> arguments) {
+ fbb_.AddOffset(Call::VT_ARGUMENTS, arguments);
+ }
+ void add_orderings(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>>> orderings) {
+ fbb_.AddOffset(Call::VT_ORDERINGS, orderings);
+ }
+ explicit CallBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ CallBuilder &operator=(const CallBuilder &);
+ flatbuffers::Offset<Call> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Call>(end);
+ fbb_.Required(o, Call::VT_NAME);
+ fbb_.Required(o, Call::VT_ARGUMENTS);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Call> CreateCall(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> name = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>>> arguments = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>>> orderings = 0) {
+ CallBuilder builder_(_fbb);
+ builder_.add_orderings(orderings);
+ builder_.add_arguments(arguments);
+ builder_.add_name(name);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Call> CreateCallDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *name = nullptr,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>> *arguments = nullptr,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>> *orderings = nullptr) {
+ auto name__ = name ? _fbb.CreateString(name) : 0;
+ auto arguments__ = arguments ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>>(*arguments) : 0;
+ auto orderings__ = orderings ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>>(*orderings) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateCall(
+ _fbb,
+ name__,
+ arguments__,
+ orderings__);
+}
+
+/// A single WHEN x THEN y fragment.
+struct CaseFragment FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef CaseFragmentBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_MATCH = 4,
+ VT_RESULT = 6
+ };
+ const org::apache::arrow::computeir::flatbuf::Expression *match() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Expression *>(VT_MATCH);
+ }
+ const org::apache::arrow::computeir::flatbuf::Expression *result() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Expression *>(VT_RESULT);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_MATCH) &&
+ verifier.VerifyTable(match()) &&
+ VerifyOffsetRequired(verifier, VT_RESULT) &&
+ verifier.VerifyTable(result()) &&
+ verifier.EndTable();
+ }
+};
+
+struct CaseFragmentBuilder {
+ typedef CaseFragment Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_match(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> match) {
+ fbb_.AddOffset(CaseFragment::VT_MATCH, match);
+ }
+ void add_result(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> result) {
+ fbb_.AddOffset(CaseFragment::VT_RESULT, result);
+ }
+ explicit CaseFragmentBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ CaseFragmentBuilder &operator=(const CaseFragmentBuilder &);
+ flatbuffers::Offset<CaseFragment> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<CaseFragment>(end);
+ fbb_.Required(o, CaseFragment::VT_MATCH);
+ fbb_.Required(o, CaseFragment::VT_RESULT);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<CaseFragment> CreateCaseFragment(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> match = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> result = 0) {
+ CaseFragmentBuilder builder_(_fbb);
+ builder_.add_result(result);
+ builder_.add_match(match);
+ return builder_.Finish();
+}
+
+/// Conditional case statement expression
+struct ConditionalCase FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ConditionalCaseBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_CONDITIONS = 4,
+ VT_ELSE_ = 6
+ };
+ /// List of conditions to evaluate
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::CaseFragment>> *conditions() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::CaseFragment>> *>(VT_CONDITIONS);
+ }
+ /// The default value if no cases match. This is typically NULL in SQL
+ /// implementations.
+ ///
+ /// Defaulting to NULL is a frontend choice, so producers must specify NULL
+ /// if that's their desired behavior.
+ const org::apache::arrow::computeir::flatbuf::Expression *else_() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Expression *>(VT_ELSE_);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_CONDITIONS) &&
+ verifier.VerifyVector(conditions()) &&
+ verifier.VerifyVectorOfTables(conditions()) &&
+ VerifyOffsetRequired(verifier, VT_ELSE_) &&
+ verifier.VerifyTable(else_()) &&
+ verifier.EndTable();
+ }
+};
+
+struct ConditionalCaseBuilder {
+ typedef ConditionalCase Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_conditions(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::CaseFragment>>> conditions) {
+ fbb_.AddOffset(ConditionalCase::VT_CONDITIONS, conditions);
+ }
+ void add_else_(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> else_) {
+ fbb_.AddOffset(ConditionalCase::VT_ELSE_, else_);
+ }
+ explicit ConditionalCaseBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ConditionalCaseBuilder &operator=(const ConditionalCaseBuilder &);
+ flatbuffers::Offset<ConditionalCase> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ConditionalCase>(end);
+ fbb_.Required(o, ConditionalCase::VT_CONDITIONS);
+ fbb_.Required(o, ConditionalCase::VT_ELSE_);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ConditionalCase> CreateConditionalCase(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::CaseFragment>>> conditions = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> else_ = 0) {
+ ConditionalCaseBuilder builder_(_fbb);
+ builder_.add_else_(else_);
+ builder_.add_conditions(conditions);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<ConditionalCase> CreateConditionalCaseDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::CaseFragment>> *conditions = nullptr,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> else_ = 0) {
+ auto conditions__ = conditions ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::CaseFragment>>(*conditions) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateConditionalCase(
+ _fbb,
+ conditions__,
+ else_);
+}
+
+/// Switch-style case expression
+struct SimpleCase FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SimpleCaseBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_EXPRESSION = 4,
+ VT_MATCHES = 6,
+ VT_ELSE_ = 8
+ };
+ /// The expression whose value will be matched
+ const org::apache::arrow::computeir::flatbuf::Expression *expression() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Expression *>(VT_EXPRESSION);
+ }
+ /// Matches for `expression`
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::CaseFragment>> *matches() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::CaseFragment>> *>(VT_MATCHES);
+ }
+ /// The default value if no cases match
+ const org::apache::arrow::computeir::flatbuf::Expression *else_() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Expression *>(VT_ELSE_);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_EXPRESSION) &&
+ verifier.VerifyTable(expression()) &&
+ VerifyOffsetRequired(verifier, VT_MATCHES) &&
+ verifier.VerifyVector(matches()) &&
+ verifier.VerifyVectorOfTables(matches()) &&
+ VerifyOffsetRequired(verifier, VT_ELSE_) &&
+ verifier.VerifyTable(else_()) &&
+ verifier.EndTable();
+ }
+};
+
+struct SimpleCaseBuilder {
+ typedef SimpleCase Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_expression(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> expression) {
+ fbb_.AddOffset(SimpleCase::VT_EXPRESSION, expression);
+ }
+ void add_matches(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::CaseFragment>>> matches) {
+ fbb_.AddOffset(SimpleCase::VT_MATCHES, matches);
+ }
+ void add_else_(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> else_) {
+ fbb_.AddOffset(SimpleCase::VT_ELSE_, else_);
+ }
+ explicit SimpleCaseBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SimpleCaseBuilder &operator=(const SimpleCaseBuilder &);
+ flatbuffers::Offset<SimpleCase> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SimpleCase>(end);
+ fbb_.Required(o, SimpleCase::VT_EXPRESSION);
+ fbb_.Required(o, SimpleCase::VT_MATCHES);
+ fbb_.Required(o, SimpleCase::VT_ELSE_);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SimpleCase> CreateSimpleCase(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> expression = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::CaseFragment>>> matches = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> else_ = 0) {
+ SimpleCaseBuilder builder_(_fbb);
+ builder_.add_else_(else_);
+ builder_.add_matches(matches);
+ builder_.add_expression(expression);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<SimpleCase> CreateSimpleCaseDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> expression = 0,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::CaseFragment>> *matches = nullptr,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> else_ = 0) {
+ auto matches__ = matches ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::CaseFragment>>(*matches) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateSimpleCase(
+ _fbb,
+ expression,
+ matches__,
+ else_);
+}
+
+/// An expression with an order
+struct SortKey FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SortKeyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_EXPRESSION = 4,
+ VT_ORDERING = 6
+ };
+ const org::apache::arrow::computeir::flatbuf::Expression *expression() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Expression *>(VT_EXPRESSION);
+ }
+ org::apache::arrow::computeir::flatbuf::Ordering ordering() const {
+ return static_cast<org::apache::arrow::computeir::flatbuf::Ordering>(GetField<uint8_t>(VT_ORDERING, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_EXPRESSION) &&
+ verifier.VerifyTable(expression()) &&
+ VerifyField<uint8_t>(verifier, VT_ORDERING) &&
+ verifier.EndTable();
+ }
+};
+
+struct SortKeyBuilder {
+ typedef SortKey Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_expression(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> expression) {
+ fbb_.AddOffset(SortKey::VT_EXPRESSION, expression);
+ }
+ void add_ordering(org::apache::arrow::computeir::flatbuf::Ordering ordering) {
+ fbb_.AddElement<uint8_t>(SortKey::VT_ORDERING, static_cast<uint8_t>(ordering), 0);
+ }
+ explicit SortKeyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SortKeyBuilder &operator=(const SortKeyBuilder &);
+ flatbuffers::Offset<SortKey> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SortKey>(end);
+ fbb_.Required(o, SortKey::VT_EXPRESSION);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SortKey> CreateSortKey(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> expression = 0,
+ org::apache::arrow::computeir::flatbuf::Ordering ordering = org::apache::arrow::computeir::flatbuf::Ordering::ASCENDING_THEN_NULLS) {
+ SortKeyBuilder builder_(_fbb);
+ builder_.add_expression(expression);
+ builder_.add_ordering(ordering);
+ return builder_.Finish();
+}
+
+/// An unbounded window bound
+struct Unbounded FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef UnboundedBuilder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+};
+
+struct UnboundedBuilder {
+ typedef Unbounded Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit UnboundedBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ UnboundedBuilder &operator=(const UnboundedBuilder &);
+ flatbuffers::Offset<Unbounded> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Unbounded>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Unbounded> CreateUnbounded(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ UnboundedBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+/// Boundary is preceding rows, determined by the contained expression
+struct Preceding FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PrecedingBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_IMPL_TYPE = 4,
+ VT_IMPL = 6
+ };
+ org::apache::arrow::computeir::flatbuf::ConcreteBoundImpl impl_type() const {
+ return static_cast<org::apache::arrow::computeir::flatbuf::ConcreteBoundImpl>(GetField<uint8_t>(VT_IMPL_TYPE, 0));
+ }
+ const void *impl() const {
+ return GetPointer<const void *>(VT_IMPL);
+ }
+ template<typename T> const T *impl_as() const;
+ const org::apache::arrow::computeir::flatbuf::Expression *impl_as_Expression() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::ConcreteBoundImpl::Expression ? static_cast<const org::apache::arrow::computeir::flatbuf::Expression *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Unbounded *impl_as_Unbounded() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::ConcreteBoundImpl::Unbounded ? static_cast<const org::apache::arrow::computeir::flatbuf::Unbounded *>(impl()) : nullptr;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_IMPL_TYPE) &&
+ VerifyOffsetRequired(verifier, VT_IMPL) &&
+ VerifyConcreteBoundImpl(verifier, impl(), impl_type()) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Expression *Preceding::impl_as<org::apache::arrow::computeir::flatbuf::Expression>() const {
+ return impl_as_Expression();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Unbounded *Preceding::impl_as<org::apache::arrow::computeir::flatbuf::Unbounded>() const {
+ return impl_as_Unbounded();
+}
+
+struct PrecedingBuilder {
+ typedef Preceding Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_impl_type(org::apache::arrow::computeir::flatbuf::ConcreteBoundImpl impl_type) {
+ fbb_.AddElement<uint8_t>(Preceding::VT_IMPL_TYPE, static_cast<uint8_t>(impl_type), 0);
+ }
+ void add_impl(flatbuffers::Offset<void> impl) {
+ fbb_.AddOffset(Preceding::VT_IMPL, impl);
+ }
+ explicit PrecedingBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PrecedingBuilder &operator=(const PrecedingBuilder &);
+ flatbuffers::Offset<Preceding> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Preceding>(end);
+ fbb_.Required(o, Preceding::VT_IMPL);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Preceding> CreatePreceding(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::computeir::flatbuf::ConcreteBoundImpl impl_type = org::apache::arrow::computeir::flatbuf::ConcreteBoundImpl::NONE,
+ flatbuffers::Offset<void> impl = 0) {
+ PrecedingBuilder builder_(_fbb);
+ builder_.add_impl(impl);
+ builder_.add_impl_type(impl_type);
+ return builder_.Finish();
+}
+
+/// Boundary is following rows, determined by the contained expression
+struct Following FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FollowingBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_IMPL_TYPE = 4,
+ VT_IMPL = 6
+ };
+ org::apache::arrow::computeir::flatbuf::ConcreteBoundImpl impl_type() const {
+ return static_cast<org::apache::arrow::computeir::flatbuf::ConcreteBoundImpl>(GetField<uint8_t>(VT_IMPL_TYPE, 0));
+ }
+ const void *impl() const {
+ return GetPointer<const void *>(VT_IMPL);
+ }
+ template<typename T> const T *impl_as() const;
+ const org::apache::arrow::computeir::flatbuf::Expression *impl_as_Expression() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::ConcreteBoundImpl::Expression ? static_cast<const org::apache::arrow::computeir::flatbuf::Expression *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Unbounded *impl_as_Unbounded() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::ConcreteBoundImpl::Unbounded ? static_cast<const org::apache::arrow::computeir::flatbuf::Unbounded *>(impl()) : nullptr;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_IMPL_TYPE) &&
+ VerifyOffsetRequired(verifier, VT_IMPL) &&
+ VerifyConcreteBoundImpl(verifier, impl(), impl_type()) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Expression *Following::impl_as<org::apache::arrow::computeir::flatbuf::Expression>() const {
+ return impl_as_Expression();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Unbounded *Following::impl_as<org::apache::arrow::computeir::flatbuf::Unbounded>() const {
+ return impl_as_Unbounded();
+}
+
+struct FollowingBuilder {
+ typedef Following Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_impl_type(org::apache::arrow::computeir::flatbuf::ConcreteBoundImpl impl_type) {
+ fbb_.AddElement<uint8_t>(Following::VT_IMPL_TYPE, static_cast<uint8_t>(impl_type), 0);
+ }
+ void add_impl(flatbuffers::Offset<void> impl) {
+ fbb_.AddOffset(Following::VT_IMPL, impl);
+ }
+ explicit FollowingBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FollowingBuilder &operator=(const FollowingBuilder &);
+ flatbuffers::Offset<Following> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Following>(end);
+ fbb_.Required(o, Following::VT_IMPL);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Following> CreateFollowing(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::computeir::flatbuf::ConcreteBoundImpl impl_type = org::apache::arrow::computeir::flatbuf::ConcreteBoundImpl::NONE,
+ flatbuffers::Offset<void> impl = 0) {
+ FollowingBuilder builder_(_fbb);
+ builder_.add_impl(impl);
+ builder_.add_impl_type(impl_type);
+ return builder_.Finish();
+}
+
+/// Boundary is the current row
+struct CurrentRow FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef CurrentRowBuilder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+};
+
+struct CurrentRowBuilder {
+ typedef CurrentRow Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit CurrentRowBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ CurrentRowBuilder &operator=(const CurrentRowBuilder &);
+ flatbuffers::Offset<CurrentRow> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<CurrentRow>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<CurrentRow> CreateCurrentRow(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ CurrentRowBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+/// An expression representing a window function call.
+struct WindowCall FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef WindowCallBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_EXPRESSION = 4,
+ VT_KIND = 6,
+ VT_PARTITIONS = 8,
+ VT_ORDERINGS = 10,
+ VT_LOWER_BOUND_TYPE = 12,
+ VT_LOWER_BOUND = 14,
+ VT_UPPER_BOUND_TYPE = 16,
+ VT_UPPER_BOUND = 18
+ };
+ /// The expression to operate over
+ const org::apache::arrow::computeir::flatbuf::Expression *expression() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Expression *>(VT_EXPRESSION);
+ }
+ /// The kind of window frame
+ org::apache::arrow::computeir::flatbuf::Frame kind() const {
+ return static_cast<org::apache::arrow::computeir::flatbuf::Frame>(GetField<uint8_t>(VT_KIND, 0));
+ }
+ /// Partition keys
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>> *partitions() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>> *>(VT_PARTITIONS);
+ }
+ /// Sort keys
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>> *orderings() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>> *>(VT_ORDERINGS);
+ }
+ org::apache::arrow::computeir::flatbuf::Bound lower_bound_type() const {
+ return static_cast<org::apache::arrow::computeir::flatbuf::Bound>(GetField<uint8_t>(VT_LOWER_BOUND_TYPE, 0));
+ }
+ /// Lower window bound
+ const void *lower_bound() const {
+ return GetPointer<const void *>(VT_LOWER_BOUND);
+ }
+ template<typename T> const T *lower_bound_as() const;
+ const org::apache::arrow::computeir::flatbuf::Preceding *lower_bound_as_Preceding() const {
+ return lower_bound_type() == org::apache::arrow::computeir::flatbuf::Bound::Preceding ? static_cast<const org::apache::arrow::computeir::flatbuf::Preceding *>(lower_bound()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Following *lower_bound_as_Following() const {
+ return lower_bound_type() == org::apache::arrow::computeir::flatbuf::Bound::Following ? static_cast<const org::apache::arrow::computeir::flatbuf::Following *>(lower_bound()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::CurrentRow *lower_bound_as_CurrentRow() const {
+ return lower_bound_type() == org::apache::arrow::computeir::flatbuf::Bound::CurrentRow ? static_cast<const org::apache::arrow::computeir::flatbuf::CurrentRow *>(lower_bound()) : nullptr;
+ }
+ org::apache::arrow::computeir::flatbuf::Bound upper_bound_type() const {
+ return static_cast<org::apache::arrow::computeir::flatbuf::Bound>(GetField<uint8_t>(VT_UPPER_BOUND_TYPE, 0));
+ }
+ /// Upper window bound
+ const void *upper_bound() const {
+ return GetPointer<const void *>(VT_UPPER_BOUND);
+ }
+ template<typename T> const T *upper_bound_as() const;
+ const org::apache::arrow::computeir::flatbuf::Preceding *upper_bound_as_Preceding() const {
+ return upper_bound_type() == org::apache::arrow::computeir::flatbuf::Bound::Preceding ? static_cast<const org::apache::arrow::computeir::flatbuf::Preceding *>(upper_bound()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Following *upper_bound_as_Following() const {
+ return upper_bound_type() == org::apache::arrow::computeir::flatbuf::Bound::Following ? static_cast<const org::apache::arrow::computeir::flatbuf::Following *>(upper_bound()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::CurrentRow *upper_bound_as_CurrentRow() const {
+ return upper_bound_type() == org::apache::arrow::computeir::flatbuf::Bound::CurrentRow ? static_cast<const org::apache::arrow::computeir::flatbuf::CurrentRow *>(upper_bound()) : nullptr;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_EXPRESSION) &&
+ verifier.VerifyTable(expression()) &&
+ VerifyField<uint8_t>(verifier, VT_KIND) &&
+ VerifyOffsetRequired(verifier, VT_PARTITIONS) &&
+ verifier.VerifyVector(partitions()) &&
+ verifier.VerifyVectorOfTables(partitions()) &&
+ VerifyOffsetRequired(verifier, VT_ORDERINGS) &&
+ verifier.VerifyVector(orderings()) &&
+ verifier.VerifyVectorOfTables(orderings()) &&
+ VerifyField<uint8_t>(verifier, VT_LOWER_BOUND_TYPE) &&
+ VerifyOffsetRequired(verifier, VT_LOWER_BOUND) &&
+ VerifyBound(verifier, lower_bound(), lower_bound_type()) &&
+ VerifyField<uint8_t>(verifier, VT_UPPER_BOUND_TYPE) &&
+ VerifyOffsetRequired(verifier, VT_UPPER_BOUND) &&
+ VerifyBound(verifier, upper_bound(), upper_bound_type()) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Preceding *WindowCall::lower_bound_as<org::apache::arrow::computeir::flatbuf::Preceding>() const {
+ return lower_bound_as_Preceding();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Following *WindowCall::lower_bound_as<org::apache::arrow::computeir::flatbuf::Following>() const {
+ return lower_bound_as_Following();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::CurrentRow *WindowCall::lower_bound_as<org::apache::arrow::computeir::flatbuf::CurrentRow>() const {
+ return lower_bound_as_CurrentRow();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Preceding *WindowCall::upper_bound_as<org::apache::arrow::computeir::flatbuf::Preceding>() const {
+ return upper_bound_as_Preceding();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Following *WindowCall::upper_bound_as<org::apache::arrow::computeir::flatbuf::Following>() const {
+ return upper_bound_as_Following();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::CurrentRow *WindowCall::upper_bound_as<org::apache::arrow::computeir::flatbuf::CurrentRow>() const {
+ return upper_bound_as_CurrentRow();
+}
+
+struct WindowCallBuilder {
+ typedef WindowCall Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_expression(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> expression) {
+ fbb_.AddOffset(WindowCall::VT_EXPRESSION, expression);
+ }
+ void add_kind(org::apache::arrow::computeir::flatbuf::Frame kind) {
+ fbb_.AddElement<uint8_t>(WindowCall::VT_KIND, static_cast<uint8_t>(kind), 0);
+ }
+ void add_partitions(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>>> partitions) {
+ fbb_.AddOffset(WindowCall::VT_PARTITIONS, partitions);
+ }
+ void add_orderings(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>>> orderings) {
+ fbb_.AddOffset(WindowCall::VT_ORDERINGS, orderings);
+ }
+ void add_lower_bound_type(org::apache::arrow::computeir::flatbuf::Bound lower_bound_type) {
+ fbb_.AddElement<uint8_t>(WindowCall::VT_LOWER_BOUND_TYPE, static_cast<uint8_t>(lower_bound_type), 0);
+ }
+ void add_lower_bound(flatbuffers::Offset<void> lower_bound) {
+ fbb_.AddOffset(WindowCall::VT_LOWER_BOUND, lower_bound);
+ }
+ void add_upper_bound_type(org::apache::arrow::computeir::flatbuf::Bound upper_bound_type) {
+ fbb_.AddElement<uint8_t>(WindowCall::VT_UPPER_BOUND_TYPE, static_cast<uint8_t>(upper_bound_type), 0);
+ }
+ void add_upper_bound(flatbuffers::Offset<void> upper_bound) {
+ fbb_.AddOffset(WindowCall::VT_UPPER_BOUND, upper_bound);
+ }
+ explicit WindowCallBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ WindowCallBuilder &operator=(const WindowCallBuilder &);
+ flatbuffers::Offset<WindowCall> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<WindowCall>(end);
+ fbb_.Required(o, WindowCall::VT_EXPRESSION);
+ fbb_.Required(o, WindowCall::VT_PARTITIONS);
+ fbb_.Required(o, WindowCall::VT_ORDERINGS);
+ fbb_.Required(o, WindowCall::VT_LOWER_BOUND);
+ fbb_.Required(o, WindowCall::VT_UPPER_BOUND);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<WindowCall> CreateWindowCall(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> expression = 0,
+ org::apache::arrow::computeir::flatbuf::Frame kind = org::apache::arrow::computeir::flatbuf::Frame::Rows,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>>> partitions = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>>> orderings = 0,
+ org::apache::arrow::computeir::flatbuf::Bound lower_bound_type = org::apache::arrow::computeir::flatbuf::Bound::NONE,
+ flatbuffers::Offset<void> lower_bound = 0,
+ org::apache::arrow::computeir::flatbuf::Bound upper_bound_type = org::apache::arrow::computeir::flatbuf::Bound::NONE,
+ flatbuffers::Offset<void> upper_bound = 0) {
+ WindowCallBuilder builder_(_fbb);
+ builder_.add_upper_bound(upper_bound);
+ builder_.add_lower_bound(lower_bound);
+ builder_.add_orderings(orderings);
+ builder_.add_partitions(partitions);
+ builder_.add_expression(expression);
+ builder_.add_upper_bound_type(upper_bound_type);
+ builder_.add_lower_bound_type(lower_bound_type);
+ builder_.add_kind(kind);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<WindowCall> CreateWindowCallDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> expression = 0,
+ org::apache::arrow::computeir::flatbuf::Frame kind = org::apache::arrow::computeir::flatbuf::Frame::Rows,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>> *partitions = nullptr,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>> *orderings = nullptr,
+ org::apache::arrow::computeir::flatbuf::Bound lower_bound_type = org::apache::arrow::computeir::flatbuf::Bound::NONE,
+ flatbuffers::Offset<void> lower_bound = 0,
+ org::apache::arrow::computeir::flatbuf::Bound upper_bound_type = org::apache::arrow::computeir::flatbuf::Bound::NONE,
+ flatbuffers::Offset<void> upper_bound = 0) {
+ auto partitions__ = partitions ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>>(*partitions) : 0;
+ auto orderings__ = orderings ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>>(*orderings) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateWindowCall(
+ _fbb,
+ expression,
+ kind,
+ partitions__,
+ orderings__,
+ lower_bound_type,
+ lower_bound,
+ upper_bound_type,
+ upper_bound);
+}
+
+/// A cast expression
+struct Cast FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef CastBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OPERAND = 4,
+ VT_TO = 6
+ };
+ /// The expression to cast
+ const org::apache::arrow::computeir::flatbuf::Expression *operand() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Expression *>(VT_OPERAND);
+ }
+ /// The type to cast to. This value is a `Field` to allow complete representation
+ /// of arrow types.
+ ///
+ /// `Type` is unable to completely represent complex types like lists and
+ /// maps.
+ const org::apache::arrow::flatbuf::Field *to() const {
+ return GetPointer<const org::apache::arrow::flatbuf::Field *>(VT_TO);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_OPERAND) &&
+ verifier.VerifyTable(operand()) &&
+ VerifyOffsetRequired(verifier, VT_TO) &&
+ verifier.VerifyTable(to()) &&
+ verifier.EndTable();
+ }
+};
+
+struct CastBuilder {
+ typedef Cast Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_operand(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> operand) {
+ fbb_.AddOffset(Cast::VT_OPERAND, operand);
+ }
+ void add_to(flatbuffers::Offset<org::apache::arrow::flatbuf::Field> to) {
+ fbb_.AddOffset(Cast::VT_TO, to);
+ }
+ explicit CastBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ CastBuilder &operator=(const CastBuilder &);
+ flatbuffers::Offset<Cast> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Cast>(end);
+ fbb_.Required(o, Cast::VT_OPERAND);
+ fbb_.Required(o, Cast::VT_TO);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Cast> CreateCast(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> operand = 0,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Field> to = 0) {
+ CastBuilder builder_(_fbb);
+ builder_.add_to(to);
+ builder_.add_operand(operand);
+ return builder_.Finish();
+}
+
+/// Expression types
+///
+/// Expressions have a concrete `impl` value, which is a specific operation.
+///
+/// This is a workaround for flatbuffers' lack of support for direct use of
+/// union types.
+struct Expression FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ExpressionBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_IMPL_TYPE = 4,
+ VT_IMPL = 6
+ };
+ org::apache::arrow::computeir::flatbuf::ExpressionImpl impl_type() const {
+ return static_cast<org::apache::arrow::computeir::flatbuf::ExpressionImpl>(GetField<uint8_t>(VT_IMPL_TYPE, 0));
+ }
+ const void *impl() const {
+ return GetPointer<const void *>(VT_IMPL);
+ }
+ template<typename T> const T *impl_as() const;
+ const org::apache::arrow::computeir::flatbuf::Literal *impl_as_Literal() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::ExpressionImpl::Literal ? static_cast<const org::apache::arrow::computeir::flatbuf::Literal *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::FieldRef *impl_as_FieldRef() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::ExpressionImpl::FieldRef ? static_cast<const org::apache::arrow::computeir::flatbuf::FieldRef *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Call *impl_as_Call() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::ExpressionImpl::Call ? static_cast<const org::apache::arrow::computeir::flatbuf::Call *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::ConditionalCase *impl_as_ConditionalCase() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::ExpressionImpl::ConditionalCase ? static_cast<const org::apache::arrow::computeir::flatbuf::ConditionalCase *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::SimpleCase *impl_as_SimpleCase() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::ExpressionImpl::SimpleCase ? static_cast<const org::apache::arrow::computeir::flatbuf::SimpleCase *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::WindowCall *impl_as_WindowCall() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::ExpressionImpl::WindowCall ? static_cast<const org::apache::arrow::computeir::flatbuf::WindowCall *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Cast *impl_as_Cast() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::ExpressionImpl::Cast ? static_cast<const org::apache::arrow::computeir::flatbuf::Cast *>(impl()) : nullptr;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_IMPL_TYPE) &&
+ VerifyOffsetRequired(verifier, VT_IMPL) &&
+ VerifyExpressionImpl(verifier, impl(), impl_type()) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Literal *Expression::impl_as<org::apache::arrow::computeir::flatbuf::Literal>() const {
+ return impl_as_Literal();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::FieldRef *Expression::impl_as<org::apache::arrow::computeir::flatbuf::FieldRef>() const {
+ return impl_as_FieldRef();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Call *Expression::impl_as<org::apache::arrow::computeir::flatbuf::Call>() const {
+ return impl_as_Call();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::ConditionalCase *Expression::impl_as<org::apache::arrow::computeir::flatbuf::ConditionalCase>() const {
+ return impl_as_ConditionalCase();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::SimpleCase *Expression::impl_as<org::apache::arrow::computeir::flatbuf::SimpleCase>() const {
+ return impl_as_SimpleCase();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::WindowCall *Expression::impl_as<org::apache::arrow::computeir::flatbuf::WindowCall>() const {
+ return impl_as_WindowCall();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Cast *Expression::impl_as<org::apache::arrow::computeir::flatbuf::Cast>() const {
+ return impl_as_Cast();
+}
+
+struct ExpressionBuilder {
+ typedef Expression Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_impl_type(org::apache::arrow::computeir::flatbuf::ExpressionImpl impl_type) {
+ fbb_.AddElement<uint8_t>(Expression::VT_IMPL_TYPE, static_cast<uint8_t>(impl_type), 0);
+ }
+ void add_impl(flatbuffers::Offset<void> impl) {
+ fbb_.AddOffset(Expression::VT_IMPL, impl);
+ }
+ explicit ExpressionBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ExpressionBuilder &operator=(const ExpressionBuilder &);
+ flatbuffers::Offset<Expression> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Expression>(end);
+ fbb_.Required(o, Expression::VT_IMPL);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Expression> CreateExpression(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::computeir::flatbuf::ExpressionImpl impl_type = org::apache::arrow::computeir::flatbuf::ExpressionImpl::NONE,
+ flatbuffers::Offset<void> impl = 0) {
+ ExpressionBuilder builder_(_fbb);
+ builder_.add_impl(impl);
+ builder_.add_impl_type(impl_type);
+ return builder_.Finish();
+}
+
+inline bool VerifyDeref(flatbuffers::Verifier &verifier, const void *obj, Deref type) {
+ switch (type) {
+ case Deref::NONE: {
+ return true;
+ }
+ case Deref::MapKey: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::MapKey *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Deref::StructField: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::StructField *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Deref::ArraySubscript: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::ArraySubscript *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Deref::ArraySlice: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::ArraySlice *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Deref::FieldIndex: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::FieldIndex *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return true;
+ }
+}
+
+inline bool VerifyDerefVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifyDeref(
+ verifier, values->Get(i), types->GetEnum<Deref>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline bool VerifyConcreteBoundImpl(flatbuffers::Verifier &verifier, const void *obj, ConcreteBoundImpl type) {
+ switch (type) {
+ case ConcreteBoundImpl::NONE: {
+ return true;
+ }
+ case ConcreteBoundImpl::Expression: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Expression *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case ConcreteBoundImpl::Unbounded: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Unbounded *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return true;
+ }
+}
+
+inline bool VerifyConcreteBoundImplVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifyConcreteBoundImpl(
+ verifier, values->Get(i), types->GetEnum<ConcreteBoundImpl>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline bool VerifyBound(flatbuffers::Verifier &verifier, const void *obj, Bound type) {
+ switch (type) {
+ case Bound::NONE: {
+ return true;
+ }
+ case Bound::Preceding: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Preceding *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Bound::Following: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Following *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Bound::CurrentRow: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::CurrentRow *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return true;
+ }
+}
+
+inline bool VerifyBoundVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifyBound(
+ verifier, values->Get(i), types->GetEnum<Bound>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline bool VerifyExpressionImpl(flatbuffers::Verifier &verifier, const void *obj, ExpressionImpl type) {
+ switch (type) {
+ case ExpressionImpl::NONE: {
+ return true;
+ }
+ case ExpressionImpl::Literal: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Literal *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case ExpressionImpl::FieldRef: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::FieldRef *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case ExpressionImpl::Call: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Call *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case ExpressionImpl::ConditionalCase: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::ConditionalCase *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case ExpressionImpl::SimpleCase: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::SimpleCase *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case ExpressionImpl::WindowCall: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::WindowCall *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case ExpressionImpl::Cast: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Cast *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return true;
+ }
+}
+
+inline bool VerifyExpressionImplVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifyExpressionImpl(
+ verifier, values->Get(i), types->GetEnum<ExpressionImpl>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline const org::apache::arrow::computeir::flatbuf::Expression *GetExpression(const void *buf) {
+ return flatbuffers::GetRoot<org::apache::arrow::computeir::flatbuf::Expression>(buf);
+}
+
+inline const org::apache::arrow::computeir::flatbuf::Expression *GetSizePrefixedExpression(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<org::apache::arrow::computeir::flatbuf::Expression>(buf);
+}
+
+inline bool VerifyExpressionBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifyBuffer<org::apache::arrow::computeir::flatbuf::Expression>(nullptr);
+}
+
+inline bool VerifySizePrefixedExpressionBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<org::apache::arrow::computeir::flatbuf::Expression>(nullptr);
+}
+
+inline void FinishExpressionBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> root) {
+ fbb.Finish(root);
+}
+
+inline void FinishSizePrefixedExpressionBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> root) {
+ fbb.FinishSizePrefixed(root);
+}
+
+} // namespace flatbuf
+} // namespace computeir
+} // namespace arrow
+} // namespace apache
+} // namespace org
+
+#endif // FLATBUFFERS_GENERATED_EXPRESSION_ORG_APACHE_ARROW_COMPUTEIR_FLATBUF_H_
diff --git a/src/arrow/cpp/src/generated/File_generated.h b/src/arrow/cpp/src/generated/File_generated.h
new file mode 100644
index 000000000..06953c4a0
--- /dev/null
+++ b/src/arrow/cpp/src/generated/File_generated.h
@@ -0,0 +1,200 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_FILE_ORG_APACHE_ARROW_FLATBUF_H_
+#define FLATBUFFERS_GENERATED_FILE_ORG_APACHE_ARROW_FLATBUF_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+#include "Schema_generated.h"
+
+namespace org {
+namespace apache {
+namespace arrow {
+namespace flatbuf {
+
+struct Footer;
+struct FooterBuilder;
+
+struct Block;
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) Block FLATBUFFERS_FINAL_CLASS {
+ private:
+ int64_t offset_;
+ int32_t metaDataLength_;
+ int32_t padding0__;
+ int64_t bodyLength_;
+
+ public:
+ Block() {
+ memset(static_cast<void *>(this), 0, sizeof(Block));
+ }
+ Block(int64_t _offset, int32_t _metaDataLength, int64_t _bodyLength)
+ : offset_(flatbuffers::EndianScalar(_offset)),
+ metaDataLength_(flatbuffers::EndianScalar(_metaDataLength)),
+ padding0__(0),
+ bodyLength_(flatbuffers::EndianScalar(_bodyLength)) {
+ (void)padding0__;
+ }
+ /// Index to the start of the RecordBlock (note this is past the Message header)
+ int64_t offset() const {
+ return flatbuffers::EndianScalar(offset_);
+ }
+ /// Length of the metadata
+ int32_t metaDataLength() const {
+ return flatbuffers::EndianScalar(metaDataLength_);
+ }
+ /// Length of the data (this is aligned so there can be a gap between this and
+ /// the metadata).
+ int64_t bodyLength() const {
+ return flatbuffers::EndianScalar(bodyLength_);
+ }
+};
+FLATBUFFERS_STRUCT_END(Block, 24);
+
+/// ----------------------------------------------------------------------
+/// Arrow File metadata
+///
+struct Footer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FooterBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VERSION = 4,
+ VT_SCHEMA = 6,
+ VT_DICTIONARIES = 8,
+ VT_RECORDBATCHES = 10,
+ VT_CUSTOM_METADATA = 12
+ };
+ org::apache::arrow::flatbuf::MetadataVersion version() const {
+ return static_cast<org::apache::arrow::flatbuf::MetadataVersion>(GetField<int16_t>(VT_VERSION, 0));
+ }
+ const org::apache::arrow::flatbuf::Schema *schema() const {
+ return GetPointer<const org::apache::arrow::flatbuf::Schema *>(VT_SCHEMA);
+ }
+ const flatbuffers::Vector<const org::apache::arrow::flatbuf::Block *> *dictionaries() const {
+ return GetPointer<const flatbuffers::Vector<const org::apache::arrow::flatbuf::Block *> *>(VT_DICTIONARIES);
+ }
+ const flatbuffers::Vector<const org::apache::arrow::flatbuf::Block *> *recordBatches() const {
+ return GetPointer<const flatbuffers::Vector<const org::apache::arrow::flatbuf::Block *> *>(VT_RECORDBATCHES);
+ }
+ /// User-defined metadata
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>> *custom_metadata() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>> *>(VT_CUSTOM_METADATA);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int16_t>(verifier, VT_VERSION) &&
+ VerifyOffset(verifier, VT_SCHEMA) &&
+ verifier.VerifyTable(schema()) &&
+ VerifyOffset(verifier, VT_DICTIONARIES) &&
+ verifier.VerifyVector(dictionaries()) &&
+ VerifyOffset(verifier, VT_RECORDBATCHES) &&
+ verifier.VerifyVector(recordBatches()) &&
+ VerifyOffset(verifier, VT_CUSTOM_METADATA) &&
+ verifier.VerifyVector(custom_metadata()) &&
+ verifier.VerifyVectorOfTables(custom_metadata()) &&
+ verifier.EndTable();
+ }
+};
+
+struct FooterBuilder {
+ typedef Footer Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_version(org::apache::arrow::flatbuf::MetadataVersion version) {
+ fbb_.AddElement<int16_t>(Footer::VT_VERSION, static_cast<int16_t>(version), 0);
+ }
+ void add_schema(flatbuffers::Offset<org::apache::arrow::flatbuf::Schema> schema) {
+ fbb_.AddOffset(Footer::VT_SCHEMA, schema);
+ }
+ void add_dictionaries(flatbuffers::Offset<flatbuffers::Vector<const org::apache::arrow::flatbuf::Block *>> dictionaries) {
+ fbb_.AddOffset(Footer::VT_DICTIONARIES, dictionaries);
+ }
+ void add_recordBatches(flatbuffers::Offset<flatbuffers::Vector<const org::apache::arrow::flatbuf::Block *>> recordBatches) {
+ fbb_.AddOffset(Footer::VT_RECORDBATCHES, recordBatches);
+ }
+ void add_custom_metadata(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>>> custom_metadata) {
+ fbb_.AddOffset(Footer::VT_CUSTOM_METADATA, custom_metadata);
+ }
+ explicit FooterBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FooterBuilder &operator=(const FooterBuilder &);
+ flatbuffers::Offset<Footer> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Footer>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Footer> CreateFooter(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::MetadataVersion version = org::apache::arrow::flatbuf::MetadataVersion::V1,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Schema> schema = 0,
+ flatbuffers::Offset<flatbuffers::Vector<const org::apache::arrow::flatbuf::Block *>> dictionaries = 0,
+ flatbuffers::Offset<flatbuffers::Vector<const org::apache::arrow::flatbuf::Block *>> recordBatches = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>>> custom_metadata = 0) {
+ FooterBuilder builder_(_fbb);
+ builder_.add_custom_metadata(custom_metadata);
+ builder_.add_recordBatches(recordBatches);
+ builder_.add_dictionaries(dictionaries);
+ builder_.add_schema(schema);
+ builder_.add_version(version);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Footer> CreateFooterDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::MetadataVersion version = org::apache::arrow::flatbuf::MetadataVersion::V1,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Schema> schema = 0,
+ const std::vector<org::apache::arrow::flatbuf::Block> *dictionaries = nullptr,
+ const std::vector<org::apache::arrow::flatbuf::Block> *recordBatches = nullptr,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>> *custom_metadata = nullptr) {
+ auto dictionaries__ = dictionaries ? _fbb.CreateVectorOfStructs<org::apache::arrow::flatbuf::Block>(*dictionaries) : 0;
+ auto recordBatches__ = recordBatches ? _fbb.CreateVectorOfStructs<org::apache::arrow::flatbuf::Block>(*recordBatches) : 0;
+ auto custom_metadata__ = custom_metadata ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>>(*custom_metadata) : 0;
+ return org::apache::arrow::flatbuf::CreateFooter(
+ _fbb,
+ version,
+ schema,
+ dictionaries__,
+ recordBatches__,
+ custom_metadata__);
+}
+
+inline const org::apache::arrow::flatbuf::Footer *GetFooter(const void *buf) {
+ return flatbuffers::GetRoot<org::apache::arrow::flatbuf::Footer>(buf);
+}
+
+inline const org::apache::arrow::flatbuf::Footer *GetSizePrefixedFooter(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<org::apache::arrow::flatbuf::Footer>(buf);
+}
+
+inline bool VerifyFooterBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifyBuffer<org::apache::arrow::flatbuf::Footer>(nullptr);
+}
+
+inline bool VerifySizePrefixedFooterBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<org::apache::arrow::flatbuf::Footer>(nullptr);
+}
+
+inline void FinishFooterBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Footer> root) {
+ fbb.Finish(root);
+}
+
+inline void FinishSizePrefixedFooterBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Footer> root) {
+ fbb.FinishSizePrefixed(root);
+}
+
+} // namespace flatbuf
+} // namespace arrow
+} // namespace apache
+} // namespace org
+
+#endif // FLATBUFFERS_GENERATED_FILE_ORG_APACHE_ARROW_FLATBUF_H_
diff --git a/src/arrow/cpp/src/generated/Literal_generated.h b/src/arrow/cpp/src/generated/Literal_generated.h
new file mode 100644
index 000000000..ea095a824
--- /dev/null
+++ b/src/arrow/cpp/src/generated/Literal_generated.h
@@ -0,0 +1,2037 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_LITERAL_ORG_APACHE_ARROW_COMPUTEIR_FLATBUF_H_
+#define FLATBUFFERS_GENERATED_LITERAL_ORG_APACHE_ARROW_COMPUTEIR_FLATBUF_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+#include "Schema_generated.h"
+
+namespace org {
+namespace apache {
+namespace arrow {
+namespace computeir {
+namespace flatbuf {
+
+struct ListLiteral;
+struct ListLiteralBuilder;
+
+struct StructLiteral;
+struct StructLiteralBuilder;
+
+struct KeyValue;
+struct KeyValueBuilder;
+
+struct MapLiteral;
+struct MapLiteralBuilder;
+
+struct Int8Literal;
+struct Int8LiteralBuilder;
+
+struct Int16Literal;
+struct Int16LiteralBuilder;
+
+struct Int32Literal;
+struct Int32LiteralBuilder;
+
+struct Int64Literal;
+struct Int64LiteralBuilder;
+
+struct UInt8Literal;
+struct UInt8LiteralBuilder;
+
+struct UInt16Literal;
+struct UInt16LiteralBuilder;
+
+struct UInt32Literal;
+struct UInt32LiteralBuilder;
+
+struct UInt64Literal;
+struct UInt64LiteralBuilder;
+
+struct Float16Literal;
+struct Float16LiteralBuilder;
+
+struct Float32Literal;
+struct Float32LiteralBuilder;
+
+struct Float64Literal;
+struct Float64LiteralBuilder;
+
+struct DecimalLiteral;
+struct DecimalLiteralBuilder;
+
+struct BooleanLiteral;
+struct BooleanLiteralBuilder;
+
+struct DateLiteral;
+struct DateLiteralBuilder;
+
+struct TimeLiteral;
+struct TimeLiteralBuilder;
+
+struct TimestampLiteral;
+struct TimestampLiteralBuilder;
+
+struct IntervalLiteralMonths;
+struct IntervalLiteralMonthsBuilder;
+
+struct IntervalLiteralDaysMilliseconds;
+struct IntervalLiteralDaysMillisecondsBuilder;
+
+struct IntervalLiteral;
+struct IntervalLiteralBuilder;
+
+struct DurationLiteral;
+struct DurationLiteralBuilder;
+
+struct BinaryLiteral;
+struct BinaryLiteralBuilder;
+
+struct FixedSizeBinaryLiteral;
+struct FixedSizeBinaryLiteralBuilder;
+
+struct StringLiteral;
+struct StringLiteralBuilder;
+
+struct Literal;
+struct LiteralBuilder;
+
+enum class IntervalLiteralImpl : uint8_t {
+ NONE = 0,
+ IntervalLiteralMonths = 1,
+ IntervalLiteralDaysMilliseconds = 2,
+ MIN = NONE,
+ MAX = IntervalLiteralDaysMilliseconds
+};
+
+inline const IntervalLiteralImpl (&EnumValuesIntervalLiteralImpl())[3] {
+ static const IntervalLiteralImpl values[] = {
+ IntervalLiteralImpl::NONE,
+ IntervalLiteralImpl::IntervalLiteralMonths,
+ IntervalLiteralImpl::IntervalLiteralDaysMilliseconds
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesIntervalLiteralImpl() {
+ static const char * const names[4] = {
+ "NONE",
+ "IntervalLiteralMonths",
+ "IntervalLiteralDaysMilliseconds",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameIntervalLiteralImpl(IntervalLiteralImpl e) {
+ if (flatbuffers::IsOutRange(e, IntervalLiteralImpl::NONE, IntervalLiteralImpl::IntervalLiteralDaysMilliseconds)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesIntervalLiteralImpl()[index];
+}
+
+template<typename T> struct IntervalLiteralImplTraits {
+ static const IntervalLiteralImpl enum_value = IntervalLiteralImpl::NONE;
+};
+
+template<> struct IntervalLiteralImplTraits<org::apache::arrow::computeir::flatbuf::IntervalLiteralMonths> {
+ static const IntervalLiteralImpl enum_value = IntervalLiteralImpl::IntervalLiteralMonths;
+};
+
+template<> struct IntervalLiteralImplTraits<org::apache::arrow::computeir::flatbuf::IntervalLiteralDaysMilliseconds> {
+ static const IntervalLiteralImpl enum_value = IntervalLiteralImpl::IntervalLiteralDaysMilliseconds;
+};
+
+bool VerifyIntervalLiteralImpl(flatbuffers::Verifier &verifier, const void *obj, IntervalLiteralImpl type);
+bool VerifyIntervalLiteralImplVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+enum class LiteralImpl : uint8_t {
+ NONE = 0,
+ BooleanLiteral = 1,
+ Int8Literal = 2,
+ Int16Literal = 3,
+ Int32Literal = 4,
+ Int64Literal = 5,
+ UInt8Literal = 6,
+ UInt16Literal = 7,
+ UInt32Literal = 8,
+ UInt64Literal = 9,
+ DateLiteral = 10,
+ TimeLiteral = 11,
+ TimestampLiteral = 12,
+ IntervalLiteral = 13,
+ DurationLiteral = 14,
+ DecimalLiteral = 15,
+ Float16Literal = 16,
+ Float32Literal = 17,
+ Float64Literal = 18,
+ ListLiteral = 19,
+ StructLiteral = 20,
+ MapLiteral = 21,
+ StringLiteral = 22,
+ BinaryLiteral = 23,
+ FixedSizeBinaryLiteral = 24,
+ MIN = NONE,
+ MAX = FixedSizeBinaryLiteral
+};
+
+inline const LiteralImpl (&EnumValuesLiteralImpl())[25] {
+ static const LiteralImpl values[] = {
+ LiteralImpl::NONE,
+ LiteralImpl::BooleanLiteral,
+ LiteralImpl::Int8Literal,
+ LiteralImpl::Int16Literal,
+ LiteralImpl::Int32Literal,
+ LiteralImpl::Int64Literal,
+ LiteralImpl::UInt8Literal,
+ LiteralImpl::UInt16Literal,
+ LiteralImpl::UInt32Literal,
+ LiteralImpl::UInt64Literal,
+ LiteralImpl::DateLiteral,
+ LiteralImpl::TimeLiteral,
+ LiteralImpl::TimestampLiteral,
+ LiteralImpl::IntervalLiteral,
+ LiteralImpl::DurationLiteral,
+ LiteralImpl::DecimalLiteral,
+ LiteralImpl::Float16Literal,
+ LiteralImpl::Float32Literal,
+ LiteralImpl::Float64Literal,
+ LiteralImpl::ListLiteral,
+ LiteralImpl::StructLiteral,
+ LiteralImpl::MapLiteral,
+ LiteralImpl::StringLiteral,
+ LiteralImpl::BinaryLiteral,
+ LiteralImpl::FixedSizeBinaryLiteral
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesLiteralImpl() {
+ static const char * const names[26] = {
+ "NONE",
+ "BooleanLiteral",
+ "Int8Literal",
+ "Int16Literal",
+ "Int32Literal",
+ "Int64Literal",
+ "UInt8Literal",
+ "UInt16Literal",
+ "UInt32Literal",
+ "UInt64Literal",
+ "DateLiteral",
+ "TimeLiteral",
+ "TimestampLiteral",
+ "IntervalLiteral",
+ "DurationLiteral",
+ "DecimalLiteral",
+ "Float16Literal",
+ "Float32Literal",
+ "Float64Literal",
+ "ListLiteral",
+ "StructLiteral",
+ "MapLiteral",
+ "StringLiteral",
+ "BinaryLiteral",
+ "FixedSizeBinaryLiteral",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameLiteralImpl(LiteralImpl e) {
+ if (flatbuffers::IsOutRange(e, LiteralImpl::NONE, LiteralImpl::FixedSizeBinaryLiteral)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesLiteralImpl()[index];
+}
+
+template<typename T> struct LiteralImplTraits {
+ static const LiteralImpl enum_value = LiteralImpl::NONE;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::BooleanLiteral> {
+ static const LiteralImpl enum_value = LiteralImpl::BooleanLiteral;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::Int8Literal> {
+ static const LiteralImpl enum_value = LiteralImpl::Int8Literal;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::Int16Literal> {
+ static const LiteralImpl enum_value = LiteralImpl::Int16Literal;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::Int32Literal> {
+ static const LiteralImpl enum_value = LiteralImpl::Int32Literal;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::Int64Literal> {
+ static const LiteralImpl enum_value = LiteralImpl::Int64Literal;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::UInt8Literal> {
+ static const LiteralImpl enum_value = LiteralImpl::UInt8Literal;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::UInt16Literal> {
+ static const LiteralImpl enum_value = LiteralImpl::UInt16Literal;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::UInt32Literal> {
+ static const LiteralImpl enum_value = LiteralImpl::UInt32Literal;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::UInt64Literal> {
+ static const LiteralImpl enum_value = LiteralImpl::UInt64Literal;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::DateLiteral> {
+ static const LiteralImpl enum_value = LiteralImpl::DateLiteral;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::TimeLiteral> {
+ static const LiteralImpl enum_value = LiteralImpl::TimeLiteral;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::TimestampLiteral> {
+ static const LiteralImpl enum_value = LiteralImpl::TimestampLiteral;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::IntervalLiteral> {
+ static const LiteralImpl enum_value = LiteralImpl::IntervalLiteral;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::DurationLiteral> {
+ static const LiteralImpl enum_value = LiteralImpl::DurationLiteral;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::DecimalLiteral> {
+ static const LiteralImpl enum_value = LiteralImpl::DecimalLiteral;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::Float16Literal> {
+ static const LiteralImpl enum_value = LiteralImpl::Float16Literal;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::Float32Literal> {
+ static const LiteralImpl enum_value = LiteralImpl::Float32Literal;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::Float64Literal> {
+ static const LiteralImpl enum_value = LiteralImpl::Float64Literal;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::ListLiteral> {
+ static const LiteralImpl enum_value = LiteralImpl::ListLiteral;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::StructLiteral> {
+ static const LiteralImpl enum_value = LiteralImpl::StructLiteral;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::MapLiteral> {
+ static const LiteralImpl enum_value = LiteralImpl::MapLiteral;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::StringLiteral> {
+ static const LiteralImpl enum_value = LiteralImpl::StringLiteral;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::BinaryLiteral> {
+ static const LiteralImpl enum_value = LiteralImpl::BinaryLiteral;
+};
+
+template<> struct LiteralImplTraits<org::apache::arrow::computeir::flatbuf::FixedSizeBinaryLiteral> {
+ static const LiteralImpl enum_value = LiteralImpl::FixedSizeBinaryLiteral;
+};
+
+bool VerifyLiteralImpl(flatbuffers::Verifier &verifier, const void *obj, LiteralImpl type);
+bool VerifyLiteralImplVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+struct ListLiteral FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ListLiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUES = 4
+ };
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>> *values() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>> *>(VT_VALUES);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_VALUES) &&
+ verifier.VerifyVector(values()) &&
+ verifier.VerifyVectorOfTables(values()) &&
+ verifier.EndTable();
+ }
+};
+
+struct ListLiteralBuilder {
+ typedef ListLiteral Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_values(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>>> values) {
+ fbb_.AddOffset(ListLiteral::VT_VALUES, values);
+ }
+ explicit ListLiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ListLiteralBuilder &operator=(const ListLiteralBuilder &);
+ flatbuffers::Offset<ListLiteral> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ListLiteral>(end);
+ fbb_.Required(o, ListLiteral::VT_VALUES);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ListLiteral> CreateListLiteral(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>>> values = 0) {
+ ListLiteralBuilder builder_(_fbb);
+ builder_.add_values(values);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<ListLiteral> CreateListLiteralDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>> *values = nullptr) {
+ auto values__ = values ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>>(*values) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateListLiteral(
+ _fbb,
+ values__);
+}
+
+struct StructLiteral FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef StructLiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUES = 4
+ };
+ /// Values for each struct field; the order must match the order of fields
+ /// in the `type` field of `Literal`.
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>> *values() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>> *>(VT_VALUES);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_VALUES) &&
+ verifier.VerifyVector(values()) &&
+ verifier.VerifyVectorOfTables(values()) &&
+ verifier.EndTable();
+ }
+};
+
+struct StructLiteralBuilder {
+ typedef StructLiteral Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_values(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>>> values) {
+ fbb_.AddOffset(StructLiteral::VT_VALUES, values);
+ }
+ explicit StructLiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ StructLiteralBuilder &operator=(const StructLiteralBuilder &);
+ flatbuffers::Offset<StructLiteral> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<StructLiteral>(end);
+ fbb_.Required(o, StructLiteral::VT_VALUES);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<StructLiteral> CreateStructLiteral(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>>> values = 0) {
+ StructLiteralBuilder builder_(_fbb);
+ builder_.add_values(values);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<StructLiteral> CreateStructLiteralDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>> *values = nullptr) {
+ auto values__ = values ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>>(*values) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateStructLiteral(
+ _fbb,
+ values__);
+}
+
+struct KeyValue FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef KeyValueBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_KEY = 4,
+ VT_VALUE = 6
+ };
+ const org::apache::arrow::computeir::flatbuf::Literal *key() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Literal *>(VT_KEY);
+ }
+ const org::apache::arrow::computeir::flatbuf::Literal *value() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Literal *>(VT_VALUE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_KEY) &&
+ verifier.VerifyTable(key()) &&
+ VerifyOffsetRequired(verifier, VT_VALUE) &&
+ verifier.VerifyTable(value()) &&
+ verifier.EndTable();
+ }
+};
+
+struct KeyValueBuilder {
+ typedef KeyValue Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_key(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal> key) {
+ fbb_.AddOffset(KeyValue::VT_KEY, key);
+ }
+ void add_value(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal> value) {
+ fbb_.AddOffset(KeyValue::VT_VALUE, value);
+ }
+ explicit KeyValueBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ KeyValueBuilder &operator=(const KeyValueBuilder &);
+ flatbuffers::Offset<KeyValue> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<KeyValue>(end);
+ fbb_.Required(o, KeyValue::VT_KEY);
+ fbb_.Required(o, KeyValue::VT_VALUE);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<KeyValue> CreateKeyValue(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal> key = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal> value = 0) {
+ KeyValueBuilder builder_(_fbb);
+ builder_.add_value(value);
+ builder_.add_key(key);
+ return builder_.Finish();
+}
+
+struct MapLiteral FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef MapLiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUES = 4
+ };
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::KeyValue>> *values() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::KeyValue>> *>(VT_VALUES);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_VALUES) &&
+ verifier.VerifyVector(values()) &&
+ verifier.VerifyVectorOfTables(values()) &&
+ verifier.EndTable();
+ }
+};
+
+struct MapLiteralBuilder {
+ typedef MapLiteral Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_values(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::KeyValue>>> values) {
+ fbb_.AddOffset(MapLiteral::VT_VALUES, values);
+ }
+ explicit MapLiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ MapLiteralBuilder &operator=(const MapLiteralBuilder &);
+ flatbuffers::Offset<MapLiteral> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<MapLiteral>(end);
+ fbb_.Required(o, MapLiteral::VT_VALUES);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<MapLiteral> CreateMapLiteral(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::KeyValue>>> values = 0) {
+ MapLiteralBuilder builder_(_fbb);
+ builder_.add_values(values);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<MapLiteral> CreateMapLiteralDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::KeyValue>> *values = nullptr) {
+ auto values__ = values ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::KeyValue>>(*values) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateMapLiteral(
+ _fbb,
+ values__);
+}
+
+struct Int8Literal FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef Int8LiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ int8_t value() const {
+ return GetField<int8_t>(VT_VALUE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct Int8LiteralBuilder {
+ typedef Int8Literal Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(int8_t value) {
+ fbb_.AddElement<int8_t>(Int8Literal::VT_VALUE, value, 0);
+ }
+ explicit Int8LiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ Int8LiteralBuilder &operator=(const Int8LiteralBuilder &);
+ flatbuffers::Offset<Int8Literal> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Int8Literal>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Int8Literal> CreateInt8Literal(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int8_t value = 0) {
+ Int8LiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct Int16Literal FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef Int16LiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ int16_t value() const {
+ return GetField<int16_t>(VT_VALUE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int16_t>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct Int16LiteralBuilder {
+ typedef Int16Literal Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(int16_t value) {
+ fbb_.AddElement<int16_t>(Int16Literal::VT_VALUE, value, 0);
+ }
+ explicit Int16LiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ Int16LiteralBuilder &operator=(const Int16LiteralBuilder &);
+ flatbuffers::Offset<Int16Literal> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Int16Literal>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Int16Literal> CreateInt16Literal(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int16_t value = 0) {
+ Int16LiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct Int32Literal FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef Int32LiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ int32_t value() const {
+ return GetField<int32_t>(VT_VALUE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct Int32LiteralBuilder {
+ typedef Int32Literal Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(int32_t value) {
+ fbb_.AddElement<int32_t>(Int32Literal::VT_VALUE, value, 0);
+ }
+ explicit Int32LiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ Int32LiteralBuilder &operator=(const Int32LiteralBuilder &);
+ flatbuffers::Offset<Int32Literal> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Int32Literal>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Int32Literal> CreateInt32Literal(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t value = 0) {
+ Int32LiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct Int64Literal FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef Int64LiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ int64_t value() const {
+ return GetField<int64_t>(VT_VALUE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct Int64LiteralBuilder {
+ typedef Int64Literal Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(int64_t value) {
+ fbb_.AddElement<int64_t>(Int64Literal::VT_VALUE, value, 0);
+ }
+ explicit Int64LiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ Int64LiteralBuilder &operator=(const Int64LiteralBuilder &);
+ flatbuffers::Offset<Int64Literal> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Int64Literal>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Int64Literal> CreateInt64Literal(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t value = 0) {
+ Int64LiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct UInt8Literal FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef UInt8LiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ uint8_t value() const {
+ return GetField<uint8_t>(VT_VALUE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct UInt8LiteralBuilder {
+ typedef UInt8Literal Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(uint8_t value) {
+ fbb_.AddElement<uint8_t>(UInt8Literal::VT_VALUE, value, 0);
+ }
+ explicit UInt8LiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ UInt8LiteralBuilder &operator=(const UInt8LiteralBuilder &);
+ flatbuffers::Offset<UInt8Literal> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<UInt8Literal>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<UInt8Literal> CreateUInt8Literal(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ uint8_t value = 0) {
+ UInt8LiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct UInt16Literal FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef UInt16LiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ uint16_t value() const {
+ return GetField<uint16_t>(VT_VALUE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint16_t>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct UInt16LiteralBuilder {
+ typedef UInt16Literal Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(uint16_t value) {
+ fbb_.AddElement<uint16_t>(UInt16Literal::VT_VALUE, value, 0);
+ }
+ explicit UInt16LiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ UInt16LiteralBuilder &operator=(const UInt16LiteralBuilder &);
+ flatbuffers::Offset<UInt16Literal> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<UInt16Literal>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<UInt16Literal> CreateUInt16Literal(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ uint16_t value = 0) {
+ UInt16LiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct UInt32Literal FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef UInt32LiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ uint32_t value() const {
+ return GetField<uint32_t>(VT_VALUE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint32_t>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct UInt32LiteralBuilder {
+ typedef UInt32Literal Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(uint32_t value) {
+ fbb_.AddElement<uint32_t>(UInt32Literal::VT_VALUE, value, 0);
+ }
+ explicit UInt32LiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ UInt32LiteralBuilder &operator=(const UInt32LiteralBuilder &);
+ flatbuffers::Offset<UInt32Literal> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<UInt32Literal>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<UInt32Literal> CreateUInt32Literal(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ uint32_t value = 0) {
+ UInt32LiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct UInt64Literal FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef UInt64LiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ uint64_t value() const {
+ return GetField<uint64_t>(VT_VALUE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint64_t>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct UInt64LiteralBuilder {
+ typedef UInt64Literal Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(uint64_t value) {
+ fbb_.AddElement<uint64_t>(UInt64Literal::VT_VALUE, value, 0);
+ }
+ explicit UInt64LiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ UInt64LiteralBuilder &operator=(const UInt64LiteralBuilder &);
+ flatbuffers::Offset<UInt64Literal> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<UInt64Literal>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<UInt64Literal> CreateUInt64Literal(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ uint64_t value = 0) {
+ UInt64LiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct Float16Literal FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef Float16LiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ uint16_t value() const {
+ return GetField<uint16_t>(VT_VALUE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint16_t>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct Float16LiteralBuilder {
+ typedef Float16Literal Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(uint16_t value) {
+ fbb_.AddElement<uint16_t>(Float16Literal::VT_VALUE, value, 0);
+ }
+ explicit Float16LiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ Float16LiteralBuilder &operator=(const Float16LiteralBuilder &);
+ flatbuffers::Offset<Float16Literal> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Float16Literal>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Float16Literal> CreateFloat16Literal(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ uint16_t value = 0) {
+ Float16LiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct Float32Literal FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef Float32LiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ float value() const {
+ return GetField<float>(VT_VALUE, 0.0f);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<float>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct Float32LiteralBuilder {
+ typedef Float32Literal Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(float value) {
+ fbb_.AddElement<float>(Float32Literal::VT_VALUE, value, 0.0f);
+ }
+ explicit Float32LiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ Float32LiteralBuilder &operator=(const Float32LiteralBuilder &);
+ flatbuffers::Offset<Float32Literal> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Float32Literal>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Float32Literal> CreateFloat32Literal(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ float value = 0.0f) {
+ Float32LiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct Float64Literal FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef Float64LiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ double value() const {
+ return GetField<double>(VT_VALUE, 0.0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<double>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct Float64LiteralBuilder {
+ typedef Float64Literal Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(double value) {
+ fbb_.AddElement<double>(Float64Literal::VT_VALUE, value, 0.0);
+ }
+ explicit Float64LiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ Float64LiteralBuilder &operator=(const Float64LiteralBuilder &);
+ flatbuffers::Offset<Float64Literal> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Float64Literal>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Float64Literal> CreateFloat64Literal(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ double value = 0.0) {
+ Float64LiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct DecimalLiteral FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef DecimalLiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ /// Bytes of a Decimal value; bytes must be in little-endian order.
+ const flatbuffers::Vector<int8_t> *value() const {
+ return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_VALUE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_VALUE) &&
+ verifier.VerifyVector(value()) &&
+ verifier.EndTable();
+ }
+};
+
+struct DecimalLiteralBuilder {
+ typedef DecimalLiteral Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(flatbuffers::Offset<flatbuffers::Vector<int8_t>> value) {
+ fbb_.AddOffset(DecimalLiteral::VT_VALUE, value);
+ }
+ explicit DecimalLiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ DecimalLiteralBuilder &operator=(const DecimalLiteralBuilder &);
+ flatbuffers::Offset<DecimalLiteral> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<DecimalLiteral>(end);
+ fbb_.Required(o, DecimalLiteral::VT_VALUE);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<DecimalLiteral> CreateDecimalLiteral(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int8_t>> value = 0) {
+ DecimalLiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<DecimalLiteral> CreateDecimalLiteralDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int8_t> *value = nullptr) {
+ auto value__ = value ? _fbb.CreateVector<int8_t>(*value) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateDecimalLiteral(
+ _fbb,
+ value__);
+}
+
+struct BooleanLiteral FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef BooleanLiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ bool value() const {
+ return GetField<uint8_t>(VT_VALUE, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct BooleanLiteralBuilder {
+ typedef BooleanLiteral Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(bool value) {
+ fbb_.AddElement<uint8_t>(BooleanLiteral::VT_VALUE, static_cast<uint8_t>(value), 0);
+ }
+ explicit BooleanLiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ BooleanLiteralBuilder &operator=(const BooleanLiteralBuilder &);
+ flatbuffers::Offset<BooleanLiteral> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<BooleanLiteral>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<BooleanLiteral> CreateBooleanLiteral(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ bool value = false) {
+ BooleanLiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct DateLiteral FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef DateLiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ int64_t value() const {
+ return GetField<int64_t>(VT_VALUE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct DateLiteralBuilder {
+ typedef DateLiteral Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(int64_t value) {
+ fbb_.AddElement<int64_t>(DateLiteral::VT_VALUE, value, 0);
+ }
+ explicit DateLiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ DateLiteralBuilder &operator=(const DateLiteralBuilder &);
+ flatbuffers::Offset<DateLiteral> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<DateLiteral>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<DateLiteral> CreateDateLiteral(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t value = 0) {
+ DateLiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct TimeLiteral FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef TimeLiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ int64_t value() const {
+ return GetField<int64_t>(VT_VALUE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct TimeLiteralBuilder {
+ typedef TimeLiteral Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(int64_t value) {
+ fbb_.AddElement<int64_t>(TimeLiteral::VT_VALUE, value, 0);
+ }
+ explicit TimeLiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TimeLiteralBuilder &operator=(const TimeLiteralBuilder &);
+ flatbuffers::Offset<TimeLiteral> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TimeLiteral>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TimeLiteral> CreateTimeLiteral(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t value = 0) {
+ TimeLiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct TimestampLiteral FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef TimestampLiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ int64_t value() const {
+ return GetField<int64_t>(VT_VALUE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct TimestampLiteralBuilder {
+ typedef TimestampLiteral Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(int64_t value) {
+ fbb_.AddElement<int64_t>(TimestampLiteral::VT_VALUE, value, 0);
+ }
+ explicit TimestampLiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TimestampLiteralBuilder &operator=(const TimestampLiteralBuilder &);
+ flatbuffers::Offset<TimestampLiteral> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TimestampLiteral>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TimestampLiteral> CreateTimestampLiteral(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t value = 0) {
+ TimestampLiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct IntervalLiteralMonths FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef IntervalLiteralMonthsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_MONTHS = 4
+ };
+ int32_t months() const {
+ return GetField<int32_t>(VT_MONTHS, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_MONTHS) &&
+ verifier.EndTable();
+ }
+};
+
+struct IntervalLiteralMonthsBuilder {
+ typedef IntervalLiteralMonths Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_months(int32_t months) {
+ fbb_.AddElement<int32_t>(IntervalLiteralMonths::VT_MONTHS, months, 0);
+ }
+ explicit IntervalLiteralMonthsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ IntervalLiteralMonthsBuilder &operator=(const IntervalLiteralMonthsBuilder &);
+ flatbuffers::Offset<IntervalLiteralMonths> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<IntervalLiteralMonths>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<IntervalLiteralMonths> CreateIntervalLiteralMonths(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t months = 0) {
+ IntervalLiteralMonthsBuilder builder_(_fbb);
+ builder_.add_months(months);
+ return builder_.Finish();
+}
+
+struct IntervalLiteralDaysMilliseconds FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef IntervalLiteralDaysMillisecondsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_DAYS = 4,
+ VT_MILLISECONDS = 6
+ };
+ int32_t days() const {
+ return GetField<int32_t>(VT_DAYS, 0);
+ }
+ int32_t milliseconds() const {
+ return GetField<int32_t>(VT_MILLISECONDS, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_DAYS) &&
+ VerifyField<int32_t>(verifier, VT_MILLISECONDS) &&
+ verifier.EndTable();
+ }
+};
+
+struct IntervalLiteralDaysMillisecondsBuilder {
+ typedef IntervalLiteralDaysMilliseconds Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_days(int32_t days) {
+ fbb_.AddElement<int32_t>(IntervalLiteralDaysMilliseconds::VT_DAYS, days, 0);
+ }
+ void add_milliseconds(int32_t milliseconds) {
+ fbb_.AddElement<int32_t>(IntervalLiteralDaysMilliseconds::VT_MILLISECONDS, milliseconds, 0);
+ }
+ explicit IntervalLiteralDaysMillisecondsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ IntervalLiteralDaysMillisecondsBuilder &operator=(const IntervalLiteralDaysMillisecondsBuilder &);
+ flatbuffers::Offset<IntervalLiteralDaysMilliseconds> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<IntervalLiteralDaysMilliseconds>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<IntervalLiteralDaysMilliseconds> CreateIntervalLiteralDaysMilliseconds(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t days = 0,
+ int32_t milliseconds = 0) {
+ IntervalLiteralDaysMillisecondsBuilder builder_(_fbb);
+ builder_.add_milliseconds(milliseconds);
+ builder_.add_days(days);
+ return builder_.Finish();
+}
+
+struct IntervalLiteral FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef IntervalLiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE_TYPE = 4,
+ VT_VALUE = 6
+ };
+ org::apache::arrow::computeir::flatbuf::IntervalLiteralImpl value_type() const {
+ return static_cast<org::apache::arrow::computeir::flatbuf::IntervalLiteralImpl>(GetField<uint8_t>(VT_VALUE_TYPE, 0));
+ }
+ const void *value() const {
+ return GetPointer<const void *>(VT_VALUE);
+ }
+ template<typename T> const T *value_as() const;
+ const org::apache::arrow::computeir::flatbuf::IntervalLiteralMonths *value_as_IntervalLiteralMonths() const {
+ return value_type() == org::apache::arrow::computeir::flatbuf::IntervalLiteralImpl::IntervalLiteralMonths ? static_cast<const org::apache::arrow::computeir::flatbuf::IntervalLiteralMonths *>(value()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::IntervalLiteralDaysMilliseconds *value_as_IntervalLiteralDaysMilliseconds() const {
+ return value_type() == org::apache::arrow::computeir::flatbuf::IntervalLiteralImpl::IntervalLiteralDaysMilliseconds ? static_cast<const org::apache::arrow::computeir::flatbuf::IntervalLiteralDaysMilliseconds *>(value()) : nullptr;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_VALUE_TYPE) &&
+ VerifyOffsetRequired(verifier, VT_VALUE) &&
+ VerifyIntervalLiteralImpl(verifier, value(), value_type()) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const org::apache::arrow::computeir::flatbuf::IntervalLiteralMonths *IntervalLiteral::value_as<org::apache::arrow::computeir::flatbuf::IntervalLiteralMonths>() const {
+ return value_as_IntervalLiteralMonths();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::IntervalLiteralDaysMilliseconds *IntervalLiteral::value_as<org::apache::arrow::computeir::flatbuf::IntervalLiteralDaysMilliseconds>() const {
+ return value_as_IntervalLiteralDaysMilliseconds();
+}
+
+struct IntervalLiteralBuilder {
+ typedef IntervalLiteral Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value_type(org::apache::arrow::computeir::flatbuf::IntervalLiteralImpl value_type) {
+ fbb_.AddElement<uint8_t>(IntervalLiteral::VT_VALUE_TYPE, static_cast<uint8_t>(value_type), 0);
+ }
+ void add_value(flatbuffers::Offset<void> value) {
+ fbb_.AddOffset(IntervalLiteral::VT_VALUE, value);
+ }
+ explicit IntervalLiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ IntervalLiteralBuilder &operator=(const IntervalLiteralBuilder &);
+ flatbuffers::Offset<IntervalLiteral> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<IntervalLiteral>(end);
+ fbb_.Required(o, IntervalLiteral::VT_VALUE);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<IntervalLiteral> CreateIntervalLiteral(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::computeir::flatbuf::IntervalLiteralImpl value_type = org::apache::arrow::computeir::flatbuf::IntervalLiteralImpl::NONE,
+ flatbuffers::Offset<void> value = 0) {
+ IntervalLiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ builder_.add_value_type(value_type);
+ return builder_.Finish();
+}
+
+struct DurationLiteral FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef DurationLiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ int64_t value() const {
+ return GetField<int64_t>(VT_VALUE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_VALUE) &&
+ verifier.EndTable();
+ }
+};
+
+struct DurationLiteralBuilder {
+ typedef DurationLiteral Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(int64_t value) {
+ fbb_.AddElement<int64_t>(DurationLiteral::VT_VALUE, value, 0);
+ }
+ explicit DurationLiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ DurationLiteralBuilder &operator=(const DurationLiteralBuilder &);
+ flatbuffers::Offset<DurationLiteral> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<DurationLiteral>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<DurationLiteral> CreateDurationLiteral(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t value = 0) {
+ DurationLiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+struct BinaryLiteral FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef BinaryLiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ const flatbuffers::Vector<int8_t> *value() const {
+ return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_VALUE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_VALUE) &&
+ verifier.VerifyVector(value()) &&
+ verifier.EndTable();
+ }
+};
+
+struct BinaryLiteralBuilder {
+ typedef BinaryLiteral Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(flatbuffers::Offset<flatbuffers::Vector<int8_t>> value) {
+ fbb_.AddOffset(BinaryLiteral::VT_VALUE, value);
+ }
+ explicit BinaryLiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ BinaryLiteralBuilder &operator=(const BinaryLiteralBuilder &);
+ flatbuffers::Offset<BinaryLiteral> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<BinaryLiteral>(end);
+ fbb_.Required(o, BinaryLiteral::VT_VALUE);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<BinaryLiteral> CreateBinaryLiteral(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int8_t>> value = 0) {
+ BinaryLiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<BinaryLiteral> CreateBinaryLiteralDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int8_t> *value = nullptr) {
+ auto value__ = value ? _fbb.CreateVector<int8_t>(*value) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateBinaryLiteral(
+ _fbb,
+ value__);
+}
+
+struct FixedSizeBinaryLiteral FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FixedSizeBinaryLiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ const flatbuffers::Vector<int8_t> *value() const {
+ return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_VALUE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_VALUE) &&
+ verifier.VerifyVector(value()) &&
+ verifier.EndTable();
+ }
+};
+
+struct FixedSizeBinaryLiteralBuilder {
+ typedef FixedSizeBinaryLiteral Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(flatbuffers::Offset<flatbuffers::Vector<int8_t>> value) {
+ fbb_.AddOffset(FixedSizeBinaryLiteral::VT_VALUE, value);
+ }
+ explicit FixedSizeBinaryLiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FixedSizeBinaryLiteralBuilder &operator=(const FixedSizeBinaryLiteralBuilder &);
+ flatbuffers::Offset<FixedSizeBinaryLiteral> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<FixedSizeBinaryLiteral>(end);
+ fbb_.Required(o, FixedSizeBinaryLiteral::VT_VALUE);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<FixedSizeBinaryLiteral> CreateFixedSizeBinaryLiteral(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int8_t>> value = 0) {
+ FixedSizeBinaryLiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<FixedSizeBinaryLiteral> CreateFixedSizeBinaryLiteralDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int8_t> *value = nullptr) {
+ auto value__ = value ? _fbb.CreateVector<int8_t>(*value) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateFixedSizeBinaryLiteral(
+ _fbb,
+ value__);
+}
+
+struct StringLiteral FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef StringLiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VALUE = 4
+ };
+ const flatbuffers::String *value() const {
+ return GetPointer<const flatbuffers::String *>(VT_VALUE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_VALUE) &&
+ verifier.VerifyString(value()) &&
+ verifier.EndTable();
+ }
+};
+
+struct StringLiteralBuilder {
+ typedef StringLiteral Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_value(flatbuffers::Offset<flatbuffers::String> value) {
+ fbb_.AddOffset(StringLiteral::VT_VALUE, value);
+ }
+ explicit StringLiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ StringLiteralBuilder &operator=(const StringLiteralBuilder &);
+ flatbuffers::Offset<StringLiteral> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<StringLiteral>(end);
+ fbb_.Required(o, StringLiteral::VT_VALUE);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<StringLiteral> CreateStringLiteral(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> value = 0) {
+ StringLiteralBuilder builder_(_fbb);
+ builder_.add_value(value);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<StringLiteral> CreateStringLiteralDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *value = nullptr) {
+ auto value__ = value ? _fbb.CreateString(value) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateStringLiteral(
+ _fbb,
+ value__);
+}
+
+struct Literal FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef LiteralBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_IMPL_TYPE = 4,
+ VT_IMPL = 6,
+ VT_TYPE = 8
+ };
+ org::apache::arrow::computeir::flatbuf::LiteralImpl impl_type() const {
+ return static_cast<org::apache::arrow::computeir::flatbuf::LiteralImpl>(GetField<uint8_t>(VT_IMPL_TYPE, 0));
+ }
+ /// Literal value data; for null literals do not include this field.
+ const void *impl() const {
+ return GetPointer<const void *>(VT_IMPL);
+ }
+ template<typename T> const T *impl_as() const;
+ const org::apache::arrow::computeir::flatbuf::BooleanLiteral *impl_as_BooleanLiteral() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::BooleanLiteral ? static_cast<const org::apache::arrow::computeir::flatbuf::BooleanLiteral *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Int8Literal *impl_as_Int8Literal() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::Int8Literal ? static_cast<const org::apache::arrow::computeir::flatbuf::Int8Literal *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Int16Literal *impl_as_Int16Literal() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::Int16Literal ? static_cast<const org::apache::arrow::computeir::flatbuf::Int16Literal *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Int32Literal *impl_as_Int32Literal() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::Int32Literal ? static_cast<const org::apache::arrow::computeir::flatbuf::Int32Literal *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Int64Literal *impl_as_Int64Literal() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::Int64Literal ? static_cast<const org::apache::arrow::computeir::flatbuf::Int64Literal *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::UInt8Literal *impl_as_UInt8Literal() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::UInt8Literal ? static_cast<const org::apache::arrow::computeir::flatbuf::UInt8Literal *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::UInt16Literal *impl_as_UInt16Literal() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::UInt16Literal ? static_cast<const org::apache::arrow::computeir::flatbuf::UInt16Literal *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::UInt32Literal *impl_as_UInt32Literal() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::UInt32Literal ? static_cast<const org::apache::arrow::computeir::flatbuf::UInt32Literal *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::UInt64Literal *impl_as_UInt64Literal() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::UInt64Literal ? static_cast<const org::apache::arrow::computeir::flatbuf::UInt64Literal *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::DateLiteral *impl_as_DateLiteral() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::DateLiteral ? static_cast<const org::apache::arrow::computeir::flatbuf::DateLiteral *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::TimeLiteral *impl_as_TimeLiteral() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::TimeLiteral ? static_cast<const org::apache::arrow::computeir::flatbuf::TimeLiteral *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::TimestampLiteral *impl_as_TimestampLiteral() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::TimestampLiteral ? static_cast<const org::apache::arrow::computeir::flatbuf::TimestampLiteral *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::IntervalLiteral *impl_as_IntervalLiteral() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::IntervalLiteral ? static_cast<const org::apache::arrow::computeir::flatbuf::IntervalLiteral *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::DurationLiteral *impl_as_DurationLiteral() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::DurationLiteral ? static_cast<const org::apache::arrow::computeir::flatbuf::DurationLiteral *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::DecimalLiteral *impl_as_DecimalLiteral() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::DecimalLiteral ? static_cast<const org::apache::arrow::computeir::flatbuf::DecimalLiteral *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Float16Literal *impl_as_Float16Literal() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::Float16Literal ? static_cast<const org::apache::arrow::computeir::flatbuf::Float16Literal *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Float32Literal *impl_as_Float32Literal() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::Float32Literal ? static_cast<const org::apache::arrow::computeir::flatbuf::Float32Literal *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Float64Literal *impl_as_Float64Literal() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::Float64Literal ? static_cast<const org::apache::arrow::computeir::flatbuf::Float64Literal *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::ListLiteral *impl_as_ListLiteral() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::ListLiteral ? static_cast<const org::apache::arrow::computeir::flatbuf::ListLiteral *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::StructLiteral *impl_as_StructLiteral() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::StructLiteral ? static_cast<const org::apache::arrow::computeir::flatbuf::StructLiteral *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::MapLiteral *impl_as_MapLiteral() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::MapLiteral ? static_cast<const org::apache::arrow::computeir::flatbuf::MapLiteral *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::StringLiteral *impl_as_StringLiteral() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::StringLiteral ? static_cast<const org::apache::arrow::computeir::flatbuf::StringLiteral *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::BinaryLiteral *impl_as_BinaryLiteral() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::BinaryLiteral ? static_cast<const org::apache::arrow::computeir::flatbuf::BinaryLiteral *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::FixedSizeBinaryLiteral *impl_as_FixedSizeBinaryLiteral() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::LiteralImpl::FixedSizeBinaryLiteral ? static_cast<const org::apache::arrow::computeir::flatbuf::FixedSizeBinaryLiteral *>(impl()) : nullptr;
+ }
+ /// Type of the literal value. This must match `impl`.
+ const org::apache::arrow::flatbuf::Field *type() const {
+ return GetPointer<const org::apache::arrow::flatbuf::Field *>(VT_TYPE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_IMPL_TYPE) &&
+ VerifyOffset(verifier, VT_IMPL) &&
+ VerifyLiteralImpl(verifier, impl(), impl_type()) &&
+ VerifyOffsetRequired(verifier, VT_TYPE) &&
+ verifier.VerifyTable(type()) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const org::apache::arrow::computeir::flatbuf::BooleanLiteral *Literal::impl_as<org::apache::arrow::computeir::flatbuf::BooleanLiteral>() const {
+ return impl_as_BooleanLiteral();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Int8Literal *Literal::impl_as<org::apache::arrow::computeir::flatbuf::Int8Literal>() const {
+ return impl_as_Int8Literal();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Int16Literal *Literal::impl_as<org::apache::arrow::computeir::flatbuf::Int16Literal>() const {
+ return impl_as_Int16Literal();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Int32Literal *Literal::impl_as<org::apache::arrow::computeir::flatbuf::Int32Literal>() const {
+ return impl_as_Int32Literal();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Int64Literal *Literal::impl_as<org::apache::arrow::computeir::flatbuf::Int64Literal>() const {
+ return impl_as_Int64Literal();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::UInt8Literal *Literal::impl_as<org::apache::arrow::computeir::flatbuf::UInt8Literal>() const {
+ return impl_as_UInt8Literal();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::UInt16Literal *Literal::impl_as<org::apache::arrow::computeir::flatbuf::UInt16Literal>() const {
+ return impl_as_UInt16Literal();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::UInt32Literal *Literal::impl_as<org::apache::arrow::computeir::flatbuf::UInt32Literal>() const {
+ return impl_as_UInt32Literal();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::UInt64Literal *Literal::impl_as<org::apache::arrow::computeir::flatbuf::UInt64Literal>() const {
+ return impl_as_UInt64Literal();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::DateLiteral *Literal::impl_as<org::apache::arrow::computeir::flatbuf::DateLiteral>() const {
+ return impl_as_DateLiteral();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::TimeLiteral *Literal::impl_as<org::apache::arrow::computeir::flatbuf::TimeLiteral>() const {
+ return impl_as_TimeLiteral();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::TimestampLiteral *Literal::impl_as<org::apache::arrow::computeir::flatbuf::TimestampLiteral>() const {
+ return impl_as_TimestampLiteral();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::IntervalLiteral *Literal::impl_as<org::apache::arrow::computeir::flatbuf::IntervalLiteral>() const {
+ return impl_as_IntervalLiteral();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::DurationLiteral *Literal::impl_as<org::apache::arrow::computeir::flatbuf::DurationLiteral>() const {
+ return impl_as_DurationLiteral();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::DecimalLiteral *Literal::impl_as<org::apache::arrow::computeir::flatbuf::DecimalLiteral>() const {
+ return impl_as_DecimalLiteral();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Float16Literal *Literal::impl_as<org::apache::arrow::computeir::flatbuf::Float16Literal>() const {
+ return impl_as_Float16Literal();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Float32Literal *Literal::impl_as<org::apache::arrow::computeir::flatbuf::Float32Literal>() const {
+ return impl_as_Float32Literal();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Float64Literal *Literal::impl_as<org::apache::arrow::computeir::flatbuf::Float64Literal>() const {
+ return impl_as_Float64Literal();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::ListLiteral *Literal::impl_as<org::apache::arrow::computeir::flatbuf::ListLiteral>() const {
+ return impl_as_ListLiteral();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::StructLiteral *Literal::impl_as<org::apache::arrow::computeir::flatbuf::StructLiteral>() const {
+ return impl_as_StructLiteral();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::MapLiteral *Literal::impl_as<org::apache::arrow::computeir::flatbuf::MapLiteral>() const {
+ return impl_as_MapLiteral();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::StringLiteral *Literal::impl_as<org::apache::arrow::computeir::flatbuf::StringLiteral>() const {
+ return impl_as_StringLiteral();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::BinaryLiteral *Literal::impl_as<org::apache::arrow::computeir::flatbuf::BinaryLiteral>() const {
+ return impl_as_BinaryLiteral();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::FixedSizeBinaryLiteral *Literal::impl_as<org::apache::arrow::computeir::flatbuf::FixedSizeBinaryLiteral>() const {
+ return impl_as_FixedSizeBinaryLiteral();
+}
+
+struct LiteralBuilder {
+ typedef Literal Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_impl_type(org::apache::arrow::computeir::flatbuf::LiteralImpl impl_type) {
+ fbb_.AddElement<uint8_t>(Literal::VT_IMPL_TYPE, static_cast<uint8_t>(impl_type), 0);
+ }
+ void add_impl(flatbuffers::Offset<void> impl) {
+ fbb_.AddOffset(Literal::VT_IMPL, impl);
+ }
+ void add_type(flatbuffers::Offset<org::apache::arrow::flatbuf::Field> type) {
+ fbb_.AddOffset(Literal::VT_TYPE, type);
+ }
+ explicit LiteralBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ LiteralBuilder &operator=(const LiteralBuilder &);
+ flatbuffers::Offset<Literal> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Literal>(end);
+ fbb_.Required(o, Literal::VT_TYPE);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Literal> CreateLiteral(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::computeir::flatbuf::LiteralImpl impl_type = org::apache::arrow::computeir::flatbuf::LiteralImpl::NONE,
+ flatbuffers::Offset<void> impl = 0,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Field> type = 0) {
+ LiteralBuilder builder_(_fbb);
+ builder_.add_type(type);
+ builder_.add_impl(impl);
+ builder_.add_impl_type(impl_type);
+ return builder_.Finish();
+}
+
+inline bool VerifyIntervalLiteralImpl(flatbuffers::Verifier &verifier, const void *obj, IntervalLiteralImpl type) {
+ switch (type) {
+ case IntervalLiteralImpl::NONE: {
+ return true;
+ }
+ case IntervalLiteralImpl::IntervalLiteralMonths: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::IntervalLiteralMonths *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case IntervalLiteralImpl::IntervalLiteralDaysMilliseconds: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::IntervalLiteralDaysMilliseconds *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return true;
+ }
+}
+
+inline bool VerifyIntervalLiteralImplVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifyIntervalLiteralImpl(
+ verifier, values->Get(i), types->GetEnum<IntervalLiteralImpl>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline bool VerifyLiteralImpl(flatbuffers::Verifier &verifier, const void *obj, LiteralImpl type) {
+ switch (type) {
+ case LiteralImpl::NONE: {
+ return true;
+ }
+ case LiteralImpl::BooleanLiteral: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::BooleanLiteral *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::Int8Literal: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Int8Literal *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::Int16Literal: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Int16Literal *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::Int32Literal: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Int32Literal *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::Int64Literal: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Int64Literal *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::UInt8Literal: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::UInt8Literal *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::UInt16Literal: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::UInt16Literal *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::UInt32Literal: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::UInt32Literal *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::UInt64Literal: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::UInt64Literal *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::DateLiteral: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::DateLiteral *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::TimeLiteral: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::TimeLiteral *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::TimestampLiteral: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::TimestampLiteral *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::IntervalLiteral: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::IntervalLiteral *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::DurationLiteral: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::DurationLiteral *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::DecimalLiteral: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::DecimalLiteral *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::Float16Literal: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Float16Literal *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::Float32Literal: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Float32Literal *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::Float64Literal: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Float64Literal *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::ListLiteral: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::ListLiteral *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::StructLiteral: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::StructLiteral *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::MapLiteral: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::MapLiteral *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::StringLiteral: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::StringLiteral *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::BinaryLiteral: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::BinaryLiteral *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case LiteralImpl::FixedSizeBinaryLiteral: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::FixedSizeBinaryLiteral *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return true;
+ }
+}
+
+inline bool VerifyLiteralImplVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifyLiteralImpl(
+ verifier, values->Get(i), types->GetEnum<LiteralImpl>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline const org::apache::arrow::computeir::flatbuf::Literal *GetLiteral(const void *buf) {
+ return flatbuffers::GetRoot<org::apache::arrow::computeir::flatbuf::Literal>(buf);
+}
+
+inline const org::apache::arrow::computeir::flatbuf::Literal *GetSizePrefixedLiteral(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<org::apache::arrow::computeir::flatbuf::Literal>(buf);
+}
+
+inline bool VerifyLiteralBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifyBuffer<org::apache::arrow::computeir::flatbuf::Literal>(nullptr);
+}
+
+inline bool VerifySizePrefixedLiteralBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<org::apache::arrow::computeir::flatbuf::Literal>(nullptr);
+}
+
+inline void FinishLiteralBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal> root) {
+ fbb.Finish(root);
+}
+
+inline void FinishSizePrefixedLiteralBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal> root) {
+ fbb.FinishSizePrefixed(root);
+}
+
+} // namespace flatbuf
+} // namespace computeir
+} // namespace arrow
+} // namespace apache
+} // namespace org
+
+#endif // FLATBUFFERS_GENERATED_LITERAL_ORG_APACHE_ARROW_COMPUTEIR_FLATBUF_H_
diff --git a/src/arrow/cpp/src/generated/Message_generated.h b/src/arrow/cpp/src/generated/Message_generated.h
new file mode 100644
index 000000000..1c51c6eaf
--- /dev/null
+++ b/src/arrow/cpp/src/generated/Message_generated.h
@@ -0,0 +1,659 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_MESSAGE_ORG_APACHE_ARROW_FLATBUF_H_
+#define FLATBUFFERS_GENERATED_MESSAGE_ORG_APACHE_ARROW_FLATBUF_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+#include "Schema_generated.h"
+#include "SparseTensor_generated.h"
+#include "Tensor_generated.h"
+
+namespace org {
+namespace apache {
+namespace arrow {
+namespace flatbuf {
+
+struct FieldNode;
+
+struct BodyCompression;
+struct BodyCompressionBuilder;
+
+struct RecordBatch;
+struct RecordBatchBuilder;
+
+struct DictionaryBatch;
+struct DictionaryBatchBuilder;
+
+struct Message;
+struct MessageBuilder;
+
+enum class CompressionType : int8_t {
+ LZ4_FRAME = 0,
+ ZSTD = 1,
+ MIN = LZ4_FRAME,
+ MAX = ZSTD
+};
+
+inline const CompressionType (&EnumValuesCompressionType())[2] {
+ static const CompressionType values[] = {
+ CompressionType::LZ4_FRAME,
+ CompressionType::ZSTD
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesCompressionType() {
+ static const char * const names[3] = {
+ "LZ4_FRAME",
+ "ZSTD",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameCompressionType(CompressionType e) {
+ if (flatbuffers::IsOutRange(e, CompressionType::LZ4_FRAME, CompressionType::ZSTD)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesCompressionType()[index];
+}
+
+/// Provided for forward compatibility in case we need to support different
+/// strategies for compressing the IPC message body (like whole-body
+/// compression rather than buffer-level) in the future
+enum class BodyCompressionMethod : int8_t {
+ /// Each constituent buffer is first compressed with the indicated
+ /// compressor, and then written with the uncompressed length in the first 8
+ /// bytes as a 64-bit little-endian signed integer followed by the compressed
+ /// buffer bytes (and then padding as required by the protocol). The
+ /// uncompressed length may be set to -1 to indicate that the data that
+ /// follows is not compressed, which can be useful for cases where
+ /// compression does not yield appreciable savings.
+ BUFFER = 0,
+ MIN = BUFFER,
+ MAX = BUFFER
+};
+
+inline const BodyCompressionMethod (&EnumValuesBodyCompressionMethod())[1] {
+ static const BodyCompressionMethod values[] = {
+ BodyCompressionMethod::BUFFER
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesBodyCompressionMethod() {
+ static const char * const names[2] = {
+ "BUFFER",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameBodyCompressionMethod(BodyCompressionMethod e) {
+ if (flatbuffers::IsOutRange(e, BodyCompressionMethod::BUFFER, BodyCompressionMethod::BUFFER)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesBodyCompressionMethod()[index];
+}
+
+/// ----------------------------------------------------------------------
+/// The root Message type
+/// This union enables us to easily send different message types without
+/// redundant storage, and in the future we can easily add new message types.
+///
+/// Arrow implementations do not need to implement all of the message types,
+/// which may include experimental metadata types. For maximum compatibility,
+/// it is best to send data using RecordBatch
+enum class MessageHeader : uint8_t {
+ NONE = 0,
+ Schema = 1,
+ DictionaryBatch = 2,
+ RecordBatch = 3,
+ Tensor = 4,
+ SparseTensor = 5,
+ MIN = NONE,
+ MAX = SparseTensor
+};
+
+inline const MessageHeader (&EnumValuesMessageHeader())[6] {
+ static const MessageHeader values[] = {
+ MessageHeader::NONE,
+ MessageHeader::Schema,
+ MessageHeader::DictionaryBatch,
+ MessageHeader::RecordBatch,
+ MessageHeader::Tensor,
+ MessageHeader::SparseTensor
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesMessageHeader() {
+ static const char * const names[7] = {
+ "NONE",
+ "Schema",
+ "DictionaryBatch",
+ "RecordBatch",
+ "Tensor",
+ "SparseTensor",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameMessageHeader(MessageHeader e) {
+ if (flatbuffers::IsOutRange(e, MessageHeader::NONE, MessageHeader::SparseTensor)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesMessageHeader()[index];
+}
+
+template<typename T> struct MessageHeaderTraits {
+ static const MessageHeader enum_value = MessageHeader::NONE;
+};
+
+template<> struct MessageHeaderTraits<org::apache::arrow::flatbuf::Schema> {
+ static const MessageHeader enum_value = MessageHeader::Schema;
+};
+
+template<> struct MessageHeaderTraits<org::apache::arrow::flatbuf::DictionaryBatch> {
+ static const MessageHeader enum_value = MessageHeader::DictionaryBatch;
+};
+
+template<> struct MessageHeaderTraits<org::apache::arrow::flatbuf::RecordBatch> {
+ static const MessageHeader enum_value = MessageHeader::RecordBatch;
+};
+
+template<> struct MessageHeaderTraits<org::apache::arrow::flatbuf::Tensor> {
+ static const MessageHeader enum_value = MessageHeader::Tensor;
+};
+
+template<> struct MessageHeaderTraits<org::apache::arrow::flatbuf::SparseTensor> {
+ static const MessageHeader enum_value = MessageHeader::SparseTensor;
+};
+
+bool VerifyMessageHeader(flatbuffers::Verifier &verifier, const void *obj, MessageHeader type);
+bool VerifyMessageHeaderVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+/// ----------------------------------------------------------------------
+/// Data structures for describing a table row batch (a collection of
+/// equal-length Arrow arrays)
+/// Metadata about a field at some level of a nested type tree (but not
+/// its children).
+///
+/// For example, a List<Int16> with values `[[1, 2, 3], null, [4], [5, 6], null]`
+/// would have {length: 5, null_count: 2} for its List node, and {length: 6,
+/// null_count: 0} for its Int16 node, as separate FieldNode structs
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) FieldNode FLATBUFFERS_FINAL_CLASS {
+ private:
+ int64_t length_;
+ int64_t null_count_;
+
+ public:
+ FieldNode() {
+ memset(static_cast<void *>(this), 0, sizeof(FieldNode));
+ }
+ FieldNode(int64_t _length, int64_t _null_count)
+ : length_(flatbuffers::EndianScalar(_length)),
+ null_count_(flatbuffers::EndianScalar(_null_count)) {
+ }
+ /// The number of value slots in the Arrow array at this level of a nested
+ /// tree
+ int64_t length() const {
+ return flatbuffers::EndianScalar(length_);
+ }
+ /// The number of observed nulls. Fields with null_count == 0 may choose not
+ /// to write their physical validity bitmap out as a materialized buffer,
+ /// instead setting the length of the bitmap buffer to 0.
+ int64_t null_count() const {
+ return flatbuffers::EndianScalar(null_count_);
+ }
+};
+FLATBUFFERS_STRUCT_END(FieldNode, 16);
+
+/// Optional compression for the memory buffers constituting IPC message
+/// bodies. Intended for use with RecordBatch but could be used for other
+/// message types
+struct BodyCompression FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef BodyCompressionBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_CODEC = 4,
+ VT_METHOD = 6
+ };
+ /// Compressor library
+ org::apache::arrow::flatbuf::CompressionType codec() const {
+ return static_cast<org::apache::arrow::flatbuf::CompressionType>(GetField<int8_t>(VT_CODEC, 0));
+ }
+ /// Indicates the way the record batch body was compressed
+ org::apache::arrow::flatbuf::BodyCompressionMethod method() const {
+ return static_cast<org::apache::arrow::flatbuf::BodyCompressionMethod>(GetField<int8_t>(VT_METHOD, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_CODEC) &&
+ VerifyField<int8_t>(verifier, VT_METHOD) &&
+ verifier.EndTable();
+ }
+};
+
+struct BodyCompressionBuilder {
+ typedef BodyCompression Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_codec(org::apache::arrow::flatbuf::CompressionType codec) {
+ fbb_.AddElement<int8_t>(BodyCompression::VT_CODEC, static_cast<int8_t>(codec), 0);
+ }
+ void add_method(org::apache::arrow::flatbuf::BodyCompressionMethod method) {
+ fbb_.AddElement<int8_t>(BodyCompression::VT_METHOD, static_cast<int8_t>(method), 0);
+ }
+ explicit BodyCompressionBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ BodyCompressionBuilder &operator=(const BodyCompressionBuilder &);
+ flatbuffers::Offset<BodyCompression> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<BodyCompression>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<BodyCompression> CreateBodyCompression(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::CompressionType codec = org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME,
+ org::apache::arrow::flatbuf::BodyCompressionMethod method = org::apache::arrow::flatbuf::BodyCompressionMethod::BUFFER) {
+ BodyCompressionBuilder builder_(_fbb);
+ builder_.add_method(method);
+ builder_.add_codec(codec);
+ return builder_.Finish();
+}
+
+/// A data header describing the shared memory layout of a "record" or "row"
+/// batch. Some systems call this a "row batch" internally and others a "record
+/// batch".
+struct RecordBatch FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef RecordBatchBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_LENGTH = 4,
+ VT_NODES = 6,
+ VT_BUFFERS = 8,
+ VT_COMPRESSION = 10
+ };
+ /// number of records / rows. The arrays in the batch should all have this
+ /// length
+ int64_t length() const {
+ return GetField<int64_t>(VT_LENGTH, 0);
+ }
+ /// Nodes correspond to the pre-ordered flattened logical schema
+ const flatbuffers::Vector<const org::apache::arrow::flatbuf::FieldNode *> *nodes() const {
+ return GetPointer<const flatbuffers::Vector<const org::apache::arrow::flatbuf::FieldNode *> *>(VT_NODES);
+ }
+ /// Buffers correspond to the pre-ordered flattened buffer tree
+ ///
+ /// The number of buffers appended to this list depends on the schema. For
+ /// example, most primitive arrays will have 2 buffers, 1 for the validity
+ /// bitmap and 1 for the values. For struct arrays, there will only be a
+ /// single buffer for the validity (nulls) bitmap
+ const flatbuffers::Vector<const org::apache::arrow::flatbuf::Buffer *> *buffers() const {
+ return GetPointer<const flatbuffers::Vector<const org::apache::arrow::flatbuf::Buffer *> *>(VT_BUFFERS);
+ }
+ /// Optional compression of the message body
+ const org::apache::arrow::flatbuf::BodyCompression *compression() const {
+ return GetPointer<const org::apache::arrow::flatbuf::BodyCompression *>(VT_COMPRESSION);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_LENGTH) &&
+ VerifyOffset(verifier, VT_NODES) &&
+ verifier.VerifyVector(nodes()) &&
+ VerifyOffset(verifier, VT_BUFFERS) &&
+ verifier.VerifyVector(buffers()) &&
+ VerifyOffset(verifier, VT_COMPRESSION) &&
+ verifier.VerifyTable(compression()) &&
+ verifier.EndTable();
+ }
+};
+
+struct RecordBatchBuilder {
+ typedef RecordBatch Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_length(int64_t length) {
+ fbb_.AddElement<int64_t>(RecordBatch::VT_LENGTH, length, 0);
+ }
+ void add_nodes(flatbuffers::Offset<flatbuffers::Vector<const org::apache::arrow::flatbuf::FieldNode *>> nodes) {
+ fbb_.AddOffset(RecordBatch::VT_NODES, nodes);
+ }
+ void add_buffers(flatbuffers::Offset<flatbuffers::Vector<const org::apache::arrow::flatbuf::Buffer *>> buffers) {
+ fbb_.AddOffset(RecordBatch::VT_BUFFERS, buffers);
+ }
+ void add_compression(flatbuffers::Offset<org::apache::arrow::flatbuf::BodyCompression> compression) {
+ fbb_.AddOffset(RecordBatch::VT_COMPRESSION, compression);
+ }
+ explicit RecordBatchBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ RecordBatchBuilder &operator=(const RecordBatchBuilder &);
+ flatbuffers::Offset<RecordBatch> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<RecordBatch>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<RecordBatch> CreateRecordBatch(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t length = 0,
+ flatbuffers::Offset<flatbuffers::Vector<const org::apache::arrow::flatbuf::FieldNode *>> nodes = 0,
+ flatbuffers::Offset<flatbuffers::Vector<const org::apache::arrow::flatbuf::Buffer *>> buffers = 0,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::BodyCompression> compression = 0) {
+ RecordBatchBuilder builder_(_fbb);
+ builder_.add_length(length);
+ builder_.add_compression(compression);
+ builder_.add_buffers(buffers);
+ builder_.add_nodes(nodes);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<RecordBatch> CreateRecordBatchDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t length = 0,
+ const std::vector<org::apache::arrow::flatbuf::FieldNode> *nodes = nullptr,
+ const std::vector<org::apache::arrow::flatbuf::Buffer> *buffers = nullptr,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::BodyCompression> compression = 0) {
+ auto nodes__ = nodes ? _fbb.CreateVectorOfStructs<org::apache::arrow::flatbuf::FieldNode>(*nodes) : 0;
+ auto buffers__ = buffers ? _fbb.CreateVectorOfStructs<org::apache::arrow::flatbuf::Buffer>(*buffers) : 0;
+ return org::apache::arrow::flatbuf::CreateRecordBatch(
+ _fbb,
+ length,
+ nodes__,
+ buffers__,
+ compression);
+}
+
+/// For sending dictionary encoding information. Any Field can be
+/// dictionary-encoded, but in this case none of its children may be
+/// dictionary-encoded.
+/// There is one vector / column per dictionary, but that vector / column
+/// may be spread across multiple dictionary batches by using the isDelta
+/// flag
+struct DictionaryBatch FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef DictionaryBatchBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_ID = 4,
+ VT_DATA = 6,
+ VT_ISDELTA = 8
+ };
+ int64_t id() const {
+ return GetField<int64_t>(VT_ID, 0);
+ }
+ const org::apache::arrow::flatbuf::RecordBatch *data() const {
+ return GetPointer<const org::apache::arrow::flatbuf::RecordBatch *>(VT_DATA);
+ }
+ /// If isDelta is true the values in the dictionary are to be appended to a
+ /// dictionary with the indicated id. If isDelta is false this dictionary
+ /// should replace the existing dictionary.
+ bool isDelta() const {
+ return GetField<uint8_t>(VT_ISDELTA, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_ID) &&
+ VerifyOffset(verifier, VT_DATA) &&
+ verifier.VerifyTable(data()) &&
+ VerifyField<uint8_t>(verifier, VT_ISDELTA) &&
+ verifier.EndTable();
+ }
+};
+
+struct DictionaryBatchBuilder {
+ typedef DictionaryBatch Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_id(int64_t id) {
+ fbb_.AddElement<int64_t>(DictionaryBatch::VT_ID, id, 0);
+ }
+ void add_data(flatbuffers::Offset<org::apache::arrow::flatbuf::RecordBatch> data) {
+ fbb_.AddOffset(DictionaryBatch::VT_DATA, data);
+ }
+ void add_isDelta(bool isDelta) {
+ fbb_.AddElement<uint8_t>(DictionaryBatch::VT_ISDELTA, static_cast<uint8_t>(isDelta), 0);
+ }
+ explicit DictionaryBatchBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ DictionaryBatchBuilder &operator=(const DictionaryBatchBuilder &);
+ flatbuffers::Offset<DictionaryBatch> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<DictionaryBatch>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<DictionaryBatch> CreateDictionaryBatch(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t id = 0,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::RecordBatch> data = 0,
+ bool isDelta = false) {
+ DictionaryBatchBuilder builder_(_fbb);
+ builder_.add_id(id);
+ builder_.add_data(data);
+ builder_.add_isDelta(isDelta);
+ return builder_.Finish();
+}
+
+struct Message FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef MessageBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VERSION = 4,
+ VT_HEADER_TYPE = 6,
+ VT_HEADER = 8,
+ VT_BODYLENGTH = 10,
+ VT_CUSTOM_METADATA = 12
+ };
+ org::apache::arrow::flatbuf::MetadataVersion version() const {
+ return static_cast<org::apache::arrow::flatbuf::MetadataVersion>(GetField<int16_t>(VT_VERSION, 0));
+ }
+ org::apache::arrow::flatbuf::MessageHeader header_type() const {
+ return static_cast<org::apache::arrow::flatbuf::MessageHeader>(GetField<uint8_t>(VT_HEADER_TYPE, 0));
+ }
+ const void *header() const {
+ return GetPointer<const void *>(VT_HEADER);
+ }
+ template<typename T> const T *header_as() const;
+ const org::apache::arrow::flatbuf::Schema *header_as_Schema() const {
+ return header_type() == org::apache::arrow::flatbuf::MessageHeader::Schema ? static_cast<const org::apache::arrow::flatbuf::Schema *>(header()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::DictionaryBatch *header_as_DictionaryBatch() const {
+ return header_type() == org::apache::arrow::flatbuf::MessageHeader::DictionaryBatch ? static_cast<const org::apache::arrow::flatbuf::DictionaryBatch *>(header()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::RecordBatch *header_as_RecordBatch() const {
+ return header_type() == org::apache::arrow::flatbuf::MessageHeader::RecordBatch ? static_cast<const org::apache::arrow::flatbuf::RecordBatch *>(header()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Tensor *header_as_Tensor() const {
+ return header_type() == org::apache::arrow::flatbuf::MessageHeader::Tensor ? static_cast<const org::apache::arrow::flatbuf::Tensor *>(header()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::SparseTensor *header_as_SparseTensor() const {
+ return header_type() == org::apache::arrow::flatbuf::MessageHeader::SparseTensor ? static_cast<const org::apache::arrow::flatbuf::SparseTensor *>(header()) : nullptr;
+ }
+ int64_t bodyLength() const {
+ return GetField<int64_t>(VT_BODYLENGTH, 0);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>> *custom_metadata() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>> *>(VT_CUSTOM_METADATA);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int16_t>(verifier, VT_VERSION) &&
+ VerifyField<uint8_t>(verifier, VT_HEADER_TYPE) &&
+ VerifyOffset(verifier, VT_HEADER) &&
+ VerifyMessageHeader(verifier, header(), header_type()) &&
+ VerifyField<int64_t>(verifier, VT_BODYLENGTH) &&
+ VerifyOffset(verifier, VT_CUSTOM_METADATA) &&
+ verifier.VerifyVector(custom_metadata()) &&
+ verifier.VerifyVectorOfTables(custom_metadata()) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const org::apache::arrow::flatbuf::Schema *Message::header_as<org::apache::arrow::flatbuf::Schema>() const {
+ return header_as_Schema();
+}
+
+template<> inline const org::apache::arrow::flatbuf::DictionaryBatch *Message::header_as<org::apache::arrow::flatbuf::DictionaryBatch>() const {
+ return header_as_DictionaryBatch();
+}
+
+template<> inline const org::apache::arrow::flatbuf::RecordBatch *Message::header_as<org::apache::arrow::flatbuf::RecordBatch>() const {
+ return header_as_RecordBatch();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Tensor *Message::header_as<org::apache::arrow::flatbuf::Tensor>() const {
+ return header_as_Tensor();
+}
+
+template<> inline const org::apache::arrow::flatbuf::SparseTensor *Message::header_as<org::apache::arrow::flatbuf::SparseTensor>() const {
+ return header_as_SparseTensor();
+}
+
+struct MessageBuilder {
+ typedef Message Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_version(org::apache::arrow::flatbuf::MetadataVersion version) {
+ fbb_.AddElement<int16_t>(Message::VT_VERSION, static_cast<int16_t>(version), 0);
+ }
+ void add_header_type(org::apache::arrow::flatbuf::MessageHeader header_type) {
+ fbb_.AddElement<uint8_t>(Message::VT_HEADER_TYPE, static_cast<uint8_t>(header_type), 0);
+ }
+ void add_header(flatbuffers::Offset<void> header) {
+ fbb_.AddOffset(Message::VT_HEADER, header);
+ }
+ void add_bodyLength(int64_t bodyLength) {
+ fbb_.AddElement<int64_t>(Message::VT_BODYLENGTH, bodyLength, 0);
+ }
+ void add_custom_metadata(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>>> custom_metadata) {
+ fbb_.AddOffset(Message::VT_CUSTOM_METADATA, custom_metadata);
+ }
+ explicit MessageBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ MessageBuilder &operator=(const MessageBuilder &);
+ flatbuffers::Offset<Message> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Message>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Message> CreateMessage(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::MetadataVersion version = org::apache::arrow::flatbuf::MetadataVersion::V1,
+ org::apache::arrow::flatbuf::MessageHeader header_type = org::apache::arrow::flatbuf::MessageHeader::NONE,
+ flatbuffers::Offset<void> header = 0,
+ int64_t bodyLength = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>>> custom_metadata = 0) {
+ MessageBuilder builder_(_fbb);
+ builder_.add_bodyLength(bodyLength);
+ builder_.add_custom_metadata(custom_metadata);
+ builder_.add_header(header);
+ builder_.add_version(version);
+ builder_.add_header_type(header_type);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Message> CreateMessageDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::MetadataVersion version = org::apache::arrow::flatbuf::MetadataVersion::V1,
+ org::apache::arrow::flatbuf::MessageHeader header_type = org::apache::arrow::flatbuf::MessageHeader::NONE,
+ flatbuffers::Offset<void> header = 0,
+ int64_t bodyLength = 0,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>> *custom_metadata = nullptr) {
+ auto custom_metadata__ = custom_metadata ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>>(*custom_metadata) : 0;
+ return org::apache::arrow::flatbuf::CreateMessage(
+ _fbb,
+ version,
+ header_type,
+ header,
+ bodyLength,
+ custom_metadata__);
+}
+
+inline bool VerifyMessageHeader(flatbuffers::Verifier &verifier, const void *obj, MessageHeader type) {
+ switch (type) {
+ case MessageHeader::NONE: {
+ return true;
+ }
+ case MessageHeader::Schema: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Schema *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case MessageHeader::DictionaryBatch: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::DictionaryBatch *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case MessageHeader::RecordBatch: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::RecordBatch *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case MessageHeader::Tensor: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Tensor *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case MessageHeader::SparseTensor: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::SparseTensor *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return true;
+ }
+}
+
+inline bool VerifyMessageHeaderVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifyMessageHeader(
+ verifier, values->Get(i), types->GetEnum<MessageHeader>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline const org::apache::arrow::flatbuf::Message *GetMessage(const void *buf) {
+ return flatbuffers::GetRoot<org::apache::arrow::flatbuf::Message>(buf);
+}
+
+inline const org::apache::arrow::flatbuf::Message *GetSizePrefixedMessage(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<org::apache::arrow::flatbuf::Message>(buf);
+}
+
+inline bool VerifyMessageBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifyBuffer<org::apache::arrow::flatbuf::Message>(nullptr);
+}
+
+inline bool VerifySizePrefixedMessageBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<org::apache::arrow::flatbuf::Message>(nullptr);
+}
+
+inline void FinishMessageBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Message> root) {
+ fbb.Finish(root);
+}
+
+inline void FinishSizePrefixedMessageBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Message> root) {
+ fbb.FinishSizePrefixed(root);
+}
+
+} // namespace flatbuf
+} // namespace arrow
+} // namespace apache
+} // namespace org
+
+#endif // FLATBUFFERS_GENERATED_MESSAGE_ORG_APACHE_ARROW_FLATBUF_H_
diff --git a/src/arrow/cpp/src/generated/Plan_generated.h b/src/arrow/cpp/src/generated/Plan_generated.h
new file mode 100644
index 000000000..33f02af58
--- /dev/null
+++ b/src/arrow/cpp/src/generated/Plan_generated.h
@@ -0,0 +1,115 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_PLAN_ORG_APACHE_ARROW_COMPUTEIR_FLATBUF_H_
+#define FLATBUFFERS_GENERATED_PLAN_ORG_APACHE_ARROW_COMPUTEIR_FLATBUF_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+#include "Schema_generated.h"
+#include "Expression_generated.h"
+#include "Literal_generated.h"
+#include "Relation_generated.h"
+
+namespace org {
+namespace apache {
+namespace arrow {
+namespace computeir {
+namespace flatbuf {
+
+struct Plan;
+struct PlanBuilder;
+
+/// A specification of a query.
+struct Plan FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlanBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_SINKS = 4
+ };
+ /// One or more output relations.
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation>> *sinks() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation>> *>(VT_SINKS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_SINKS) &&
+ verifier.VerifyVector(sinks()) &&
+ verifier.VerifyVectorOfTables(sinks()) &&
+ verifier.EndTable();
+ }
+};
+
+struct PlanBuilder {
+ typedef Plan Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_sinks(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation>>> sinks) {
+ fbb_.AddOffset(Plan::VT_SINKS, sinks);
+ }
+ explicit PlanBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlanBuilder &operator=(const PlanBuilder &);
+ flatbuffers::Offset<Plan> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Plan>(end);
+ fbb_.Required(o, Plan::VT_SINKS);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Plan> CreatePlan(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation>>> sinks = 0) {
+ PlanBuilder builder_(_fbb);
+ builder_.add_sinks(sinks);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Plan> CreatePlanDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation>> *sinks = nullptr) {
+ auto sinks__ = sinks ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation>>(*sinks) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreatePlan(
+ _fbb,
+ sinks__);
+}
+
+inline const org::apache::arrow::computeir::flatbuf::Plan *GetPlan(const void *buf) {
+ return flatbuffers::GetRoot<org::apache::arrow::computeir::flatbuf::Plan>(buf);
+}
+
+inline const org::apache::arrow::computeir::flatbuf::Plan *GetSizePrefixedPlan(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<org::apache::arrow::computeir::flatbuf::Plan>(buf);
+}
+
+inline bool VerifyPlanBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifyBuffer<org::apache::arrow::computeir::flatbuf::Plan>(nullptr);
+}
+
+inline bool VerifySizePrefixedPlanBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<org::apache::arrow::computeir::flatbuf::Plan>(nullptr);
+}
+
+inline void FinishPlanBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Plan> root) {
+ fbb.Finish(root);
+}
+
+inline void FinishSizePrefixedPlanBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Plan> root) {
+ fbb.FinishSizePrefixed(root);
+}
+
+} // namespace flatbuf
+} // namespace computeir
+} // namespace arrow
+} // namespace apache
+} // namespace org
+
+#endif // FLATBUFFERS_GENERATED_PLAN_ORG_APACHE_ARROW_COMPUTEIR_FLATBUF_H_
diff --git a/src/arrow/cpp/src/generated/Relation_generated.h b/src/arrow/cpp/src/generated/Relation_generated.h
new file mode 100644
index 000000000..6c9d9bc92
--- /dev/null
+++ b/src/arrow/cpp/src/generated/Relation_generated.h
@@ -0,0 +1,1647 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_RELATION_ORG_APACHE_ARROW_COMPUTEIR_FLATBUF_H_
+#define FLATBUFFERS_GENERATED_RELATION_ORG_APACHE_ARROW_COMPUTEIR_FLATBUF_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+#include "Schema_generated.h"
+#include "Expression_generated.h"
+#include "Literal_generated.h"
+
+namespace org {
+namespace apache {
+namespace arrow {
+namespace computeir {
+namespace flatbuf {
+
+struct Remap;
+struct RemapBuilder;
+
+struct PassThrough;
+struct PassThroughBuilder;
+
+struct RelId;
+struct RelIdBuilder;
+
+struct RelBase;
+struct RelBaseBuilder;
+
+struct Filter;
+struct FilterBuilder;
+
+struct Project;
+struct ProjectBuilder;
+
+struct Grouping;
+struct GroupingBuilder;
+
+struct Aggregate;
+struct AggregateBuilder;
+
+struct Join;
+struct JoinBuilder;
+
+struct OrderBy;
+struct OrderByBuilder;
+
+struct Limit;
+struct LimitBuilder;
+
+struct SetOperation;
+struct SetOperationBuilder;
+
+struct LiteralColumn;
+struct LiteralColumnBuilder;
+
+struct LiteralRelation;
+struct LiteralRelationBuilder;
+
+struct Source;
+struct SourceBuilder;
+
+struct Relation;
+struct RelationBuilder;
+
+/// A union for the different colum remapping variants
+enum class Emit : uint8_t {
+ NONE = 0,
+ Remap = 1,
+ PassThrough = 2,
+ MIN = NONE,
+ MAX = PassThrough
+};
+
+inline const Emit (&EnumValuesEmit())[3] {
+ static const Emit values[] = {
+ Emit::NONE,
+ Emit::Remap,
+ Emit::PassThrough
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesEmit() {
+ static const char * const names[4] = {
+ "NONE",
+ "Remap",
+ "PassThrough",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameEmit(Emit e) {
+ if (flatbuffers::IsOutRange(e, Emit::NONE, Emit::PassThrough)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesEmit()[index];
+}
+
+template<typename T> struct EmitTraits {
+ static const Emit enum_value = Emit::NONE;
+};
+
+template<> struct EmitTraits<org::apache::arrow::computeir::flatbuf::Remap> {
+ static const Emit enum_value = Emit::Remap;
+};
+
+template<> struct EmitTraits<org::apache::arrow::computeir::flatbuf::PassThrough> {
+ static const Emit enum_value = Emit::PassThrough;
+};
+
+bool VerifyEmit(flatbuffers::Verifier &verifier, const void *obj, Emit type);
+bool VerifyEmitVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+enum class JoinKind : uint8_t {
+ Anti = 0,
+ Cross = 1,
+ FullOuter = 2,
+ Inner = 3,
+ LeftOuter = 4,
+ LeftSemi = 5,
+ RightOuter = 6,
+ MIN = Anti,
+ MAX = RightOuter
+};
+
+inline const JoinKind (&EnumValuesJoinKind())[7] {
+ static const JoinKind values[] = {
+ JoinKind::Anti,
+ JoinKind::Cross,
+ JoinKind::FullOuter,
+ JoinKind::Inner,
+ JoinKind::LeftOuter,
+ JoinKind::LeftSemi,
+ JoinKind::RightOuter
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesJoinKind() {
+ static const char * const names[8] = {
+ "Anti",
+ "Cross",
+ "FullOuter",
+ "Inner",
+ "LeftOuter",
+ "LeftSemi",
+ "RightOuter",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameJoinKind(JoinKind e) {
+ if (flatbuffers::IsOutRange(e, JoinKind::Anti, JoinKind::RightOuter)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesJoinKind()[index];
+}
+
+/// The kind of set operation being performed.
+enum class SetOpKind : uint8_t {
+ Union = 0,
+ Intersection = 1,
+ Difference = 2,
+ MIN = Union,
+ MAX = Difference
+};
+
+inline const SetOpKind (&EnumValuesSetOpKind())[3] {
+ static const SetOpKind values[] = {
+ SetOpKind::Union,
+ SetOpKind::Intersection,
+ SetOpKind::Difference
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesSetOpKind() {
+ static const char * const names[4] = {
+ "Union",
+ "Intersection",
+ "Difference",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameSetOpKind(SetOpKind e) {
+ if (flatbuffers::IsOutRange(e, SetOpKind::Union, SetOpKind::Difference)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesSetOpKind()[index];
+}
+
+/// The varieties of relations
+enum class RelationImpl : uint8_t {
+ NONE = 0,
+ Aggregate = 1,
+ Filter = 2,
+ Join = 3,
+ Limit = 4,
+ LiteralRelation = 5,
+ OrderBy = 6,
+ Project = 7,
+ SetOperation = 8,
+ Source = 9,
+ MIN = NONE,
+ MAX = Source
+};
+
+inline const RelationImpl (&EnumValuesRelationImpl())[10] {
+ static const RelationImpl values[] = {
+ RelationImpl::NONE,
+ RelationImpl::Aggregate,
+ RelationImpl::Filter,
+ RelationImpl::Join,
+ RelationImpl::Limit,
+ RelationImpl::LiteralRelation,
+ RelationImpl::OrderBy,
+ RelationImpl::Project,
+ RelationImpl::SetOperation,
+ RelationImpl::Source
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesRelationImpl() {
+ static const char * const names[11] = {
+ "NONE",
+ "Aggregate",
+ "Filter",
+ "Join",
+ "Limit",
+ "LiteralRelation",
+ "OrderBy",
+ "Project",
+ "SetOperation",
+ "Source",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameRelationImpl(RelationImpl e) {
+ if (flatbuffers::IsOutRange(e, RelationImpl::NONE, RelationImpl::Source)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesRelationImpl()[index];
+}
+
+template<typename T> struct RelationImplTraits {
+ static const RelationImpl enum_value = RelationImpl::NONE;
+};
+
+template<> struct RelationImplTraits<org::apache::arrow::computeir::flatbuf::Aggregate> {
+ static const RelationImpl enum_value = RelationImpl::Aggregate;
+};
+
+template<> struct RelationImplTraits<org::apache::arrow::computeir::flatbuf::Filter> {
+ static const RelationImpl enum_value = RelationImpl::Filter;
+};
+
+template<> struct RelationImplTraits<org::apache::arrow::computeir::flatbuf::Join> {
+ static const RelationImpl enum_value = RelationImpl::Join;
+};
+
+template<> struct RelationImplTraits<org::apache::arrow::computeir::flatbuf::Limit> {
+ static const RelationImpl enum_value = RelationImpl::Limit;
+};
+
+template<> struct RelationImplTraits<org::apache::arrow::computeir::flatbuf::LiteralRelation> {
+ static const RelationImpl enum_value = RelationImpl::LiteralRelation;
+};
+
+template<> struct RelationImplTraits<org::apache::arrow::computeir::flatbuf::OrderBy> {
+ static const RelationImpl enum_value = RelationImpl::OrderBy;
+};
+
+template<> struct RelationImplTraits<org::apache::arrow::computeir::flatbuf::Project> {
+ static const RelationImpl enum_value = RelationImpl::Project;
+};
+
+template<> struct RelationImplTraits<org::apache::arrow::computeir::flatbuf::SetOperation> {
+ static const RelationImpl enum_value = RelationImpl::SetOperation;
+};
+
+template<> struct RelationImplTraits<org::apache::arrow::computeir::flatbuf::Source> {
+ static const RelationImpl enum_value = RelationImpl::Source;
+};
+
+bool VerifyRelationImpl(flatbuffers::Verifier &verifier, const void *obj, RelationImpl type);
+bool VerifyRelationImplVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+/// A data type indicating that a different mapping of columns
+/// should occur in the output.
+///
+/// For example:
+///
+/// Given a query `SELECT b, a FROM t` where `t` has columns a, b, c
+/// the mapping value for the projection would equal [1, 0].
+struct Remap FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef RemapBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_MAPPING = 4
+ };
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::FieldIndex>> *mapping() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::FieldIndex>> *>(VT_MAPPING);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_MAPPING) &&
+ verifier.VerifyVector(mapping()) &&
+ verifier.VerifyVectorOfTables(mapping()) &&
+ verifier.EndTable();
+ }
+};
+
+struct RemapBuilder {
+ typedef Remap Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_mapping(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::FieldIndex>>> mapping) {
+ fbb_.AddOffset(Remap::VT_MAPPING, mapping);
+ }
+ explicit RemapBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ RemapBuilder &operator=(const RemapBuilder &);
+ flatbuffers::Offset<Remap> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Remap>(end);
+ fbb_.Required(o, Remap::VT_MAPPING);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Remap> CreateRemap(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::FieldIndex>>> mapping = 0) {
+ RemapBuilder builder_(_fbb);
+ builder_.add_mapping(mapping);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Remap> CreateRemapDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::FieldIndex>> *mapping = nullptr) {
+ auto mapping__ = mapping ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::FieldIndex>>(*mapping) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateRemap(
+ _fbb,
+ mapping__);
+}
+
+struct PassThrough FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PassThroughBuilder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+};
+
+struct PassThroughBuilder {
+ typedef PassThrough Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit PassThroughBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PassThroughBuilder &operator=(const PassThroughBuilder &);
+ flatbuffers::Offset<PassThrough> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PassThrough>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PassThrough> CreatePassThrough(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ PassThroughBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+/// An identifier for relations in a query.
+///
+/// A table is used here to allow plan implementations optionality.
+struct RelId FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef RelIdBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_ID = 4
+ };
+ uint64_t id() const {
+ return GetField<uint64_t>(VT_ID, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint64_t>(verifier, VT_ID) &&
+ verifier.EndTable();
+ }
+};
+
+struct RelIdBuilder {
+ typedef RelId Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_id(uint64_t id) {
+ fbb_.AddElement<uint64_t>(RelId::VT_ID, id, 0);
+ }
+ explicit RelIdBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ RelIdBuilder &operator=(const RelIdBuilder &);
+ flatbuffers::Offset<RelId> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<RelId>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<RelId> CreateRelId(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ uint64_t id = 0) {
+ RelIdBuilder builder_(_fbb);
+ builder_.add_id(id);
+ return builder_.Finish();
+}
+
+/// Fields common to every relational operator
+struct RelBase FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef RelBaseBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OUTPUT_MAPPING_TYPE = 4,
+ VT_OUTPUT_MAPPING = 6,
+ VT_ID = 8
+ };
+ org::apache::arrow::computeir::flatbuf::Emit output_mapping_type() const {
+ return static_cast<org::apache::arrow::computeir::flatbuf::Emit>(GetField<uint8_t>(VT_OUTPUT_MAPPING_TYPE, 0));
+ }
+ /// Output remapping of ordinal columns for a given operation
+ const void *output_mapping() const {
+ return GetPointer<const void *>(VT_OUTPUT_MAPPING);
+ }
+ template<typename T> const T *output_mapping_as() const;
+ const org::apache::arrow::computeir::flatbuf::Remap *output_mapping_as_Remap() const {
+ return output_mapping_type() == org::apache::arrow::computeir::flatbuf::Emit::Remap ? static_cast<const org::apache::arrow::computeir::flatbuf::Remap *>(output_mapping()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::PassThrough *output_mapping_as_PassThrough() const {
+ return output_mapping_type() == org::apache::arrow::computeir::flatbuf::Emit::PassThrough ? static_cast<const org::apache::arrow::computeir::flatbuf::PassThrough *>(output_mapping()) : nullptr;
+ }
+ /// An identifiier for a relation. The identifier should be unique over the
+ /// entire plan. Optional.
+ const org::apache::arrow::computeir::flatbuf::RelId *id() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::RelId *>(VT_ID);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_OUTPUT_MAPPING_TYPE) &&
+ VerifyOffsetRequired(verifier, VT_OUTPUT_MAPPING) &&
+ VerifyEmit(verifier, output_mapping(), output_mapping_type()) &&
+ VerifyOffset(verifier, VT_ID) &&
+ verifier.VerifyTable(id()) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Remap *RelBase::output_mapping_as<org::apache::arrow::computeir::flatbuf::Remap>() const {
+ return output_mapping_as_Remap();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::PassThrough *RelBase::output_mapping_as<org::apache::arrow::computeir::flatbuf::PassThrough>() const {
+ return output_mapping_as_PassThrough();
+}
+
+struct RelBaseBuilder {
+ typedef RelBase Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_output_mapping_type(org::apache::arrow::computeir::flatbuf::Emit output_mapping_type) {
+ fbb_.AddElement<uint8_t>(RelBase::VT_OUTPUT_MAPPING_TYPE, static_cast<uint8_t>(output_mapping_type), 0);
+ }
+ void add_output_mapping(flatbuffers::Offset<void> output_mapping) {
+ fbb_.AddOffset(RelBase::VT_OUTPUT_MAPPING, output_mapping);
+ }
+ void add_id(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelId> id) {
+ fbb_.AddOffset(RelBase::VT_ID, id);
+ }
+ explicit RelBaseBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ RelBaseBuilder &operator=(const RelBaseBuilder &);
+ flatbuffers::Offset<RelBase> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<RelBase>(end);
+ fbb_.Required(o, RelBase::VT_OUTPUT_MAPPING);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<RelBase> CreateRelBase(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::computeir::flatbuf::Emit output_mapping_type = org::apache::arrow::computeir::flatbuf::Emit::NONE,
+ flatbuffers::Offset<void> output_mapping = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelId> id = 0) {
+ RelBaseBuilder builder_(_fbb);
+ builder_.add_id(id);
+ builder_.add_output_mapping(output_mapping);
+ builder_.add_output_mapping_type(output_mapping_type);
+ return builder_.Finish();
+}
+
+/// Filter operation
+struct Filter FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FilterBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BASE = 4,
+ VT_REL = 6,
+ VT_PREDICATE = 8
+ };
+ /// Common options
+ const org::apache::arrow::computeir::flatbuf::RelBase *base() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::RelBase *>(VT_BASE);
+ }
+ /// Child relation
+ const org::apache::arrow::computeir::flatbuf::Relation *rel() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Relation *>(VT_REL);
+ }
+ /// The expression which will be evaluated against input rows
+ /// to determine whether they should be excluded from the
+ /// filter relation's output.
+ const org::apache::arrow::computeir::flatbuf::Expression *predicate() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Expression *>(VT_PREDICATE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_BASE) &&
+ verifier.VerifyTable(base()) &&
+ VerifyOffsetRequired(verifier, VT_REL) &&
+ verifier.VerifyTable(rel()) &&
+ VerifyOffsetRequired(verifier, VT_PREDICATE) &&
+ verifier.VerifyTable(predicate()) &&
+ verifier.EndTable();
+ }
+};
+
+struct FilterBuilder {
+ typedef Filter Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_base(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base) {
+ fbb_.AddOffset(Filter::VT_BASE, base);
+ }
+ void add_rel(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> rel) {
+ fbb_.AddOffset(Filter::VT_REL, rel);
+ }
+ void add_predicate(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> predicate) {
+ fbb_.AddOffset(Filter::VT_PREDICATE, predicate);
+ }
+ explicit FilterBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FilterBuilder &operator=(const FilterBuilder &);
+ flatbuffers::Offset<Filter> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Filter>(end);
+ fbb_.Required(o, Filter::VT_BASE);
+ fbb_.Required(o, Filter::VT_REL);
+ fbb_.Required(o, Filter::VT_PREDICATE);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Filter> CreateFilter(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> rel = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> predicate = 0) {
+ FilterBuilder builder_(_fbb);
+ builder_.add_predicate(predicate);
+ builder_.add_rel(rel);
+ builder_.add_base(base);
+ return builder_.Finish();
+}
+
+/// Projection
+struct Project FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ProjectBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BASE = 4,
+ VT_REL = 6,
+ VT_EXPRESSIONS = 8
+ };
+ /// Common options
+ const org::apache::arrow::computeir::flatbuf::RelBase *base() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::RelBase *>(VT_BASE);
+ }
+ /// Child relation
+ const org::apache::arrow::computeir::flatbuf::Relation *rel() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Relation *>(VT_REL);
+ }
+ /// Expressions which will be evaluated to produce to
+ /// the rows of the project relation's output.
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>> *expressions() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>> *>(VT_EXPRESSIONS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_BASE) &&
+ verifier.VerifyTable(base()) &&
+ VerifyOffsetRequired(verifier, VT_REL) &&
+ verifier.VerifyTable(rel()) &&
+ VerifyOffsetRequired(verifier, VT_EXPRESSIONS) &&
+ verifier.VerifyVector(expressions()) &&
+ verifier.VerifyVectorOfTables(expressions()) &&
+ verifier.EndTable();
+ }
+};
+
+struct ProjectBuilder {
+ typedef Project Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_base(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base) {
+ fbb_.AddOffset(Project::VT_BASE, base);
+ }
+ void add_rel(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> rel) {
+ fbb_.AddOffset(Project::VT_REL, rel);
+ }
+ void add_expressions(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>>> expressions) {
+ fbb_.AddOffset(Project::VT_EXPRESSIONS, expressions);
+ }
+ explicit ProjectBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ProjectBuilder &operator=(const ProjectBuilder &);
+ flatbuffers::Offset<Project> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Project>(end);
+ fbb_.Required(o, Project::VT_BASE);
+ fbb_.Required(o, Project::VT_REL);
+ fbb_.Required(o, Project::VT_EXPRESSIONS);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Project> CreateProject(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> rel = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>>> expressions = 0) {
+ ProjectBuilder builder_(_fbb);
+ builder_.add_expressions(expressions);
+ builder_.add_rel(rel);
+ builder_.add_base(base);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Project> CreateProjectDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> rel = 0,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>> *expressions = nullptr) {
+ auto expressions__ = expressions ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>>(*expressions) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateProject(
+ _fbb,
+ base,
+ rel,
+ expressions__);
+}
+
+/// A set of grouping keys
+struct Grouping FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef GroupingBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_KEYS = 4
+ };
+ /// Expressions to group by
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>> *keys() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>> *>(VT_KEYS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_KEYS) &&
+ verifier.VerifyVector(keys()) &&
+ verifier.VerifyVectorOfTables(keys()) &&
+ verifier.EndTable();
+ }
+};
+
+struct GroupingBuilder {
+ typedef Grouping Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_keys(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>>> keys) {
+ fbb_.AddOffset(Grouping::VT_KEYS, keys);
+ }
+ explicit GroupingBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ GroupingBuilder &operator=(const GroupingBuilder &);
+ flatbuffers::Offset<Grouping> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Grouping>(end);
+ fbb_.Required(o, Grouping::VT_KEYS);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Grouping> CreateGrouping(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>>> keys = 0) {
+ GroupingBuilder builder_(_fbb);
+ builder_.add_keys(keys);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Grouping> CreateGroupingDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>> *keys = nullptr) {
+ auto keys__ = keys ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>>(*keys) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateGrouping(
+ _fbb,
+ keys__);
+}
+
+/// Aggregate operation
+struct Aggregate FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef AggregateBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BASE = 4,
+ VT_REL = 6,
+ VT_MEASURES = 8,
+ VT_GROUPINGS = 10
+ };
+ /// Common options
+ const org::apache::arrow::computeir::flatbuf::RelBase *base() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::RelBase *>(VT_BASE);
+ }
+ /// Child relation
+ const org::apache::arrow::computeir::flatbuf::Relation *rel() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Relation *>(VT_REL);
+ }
+ /// Expressions which will be evaluated to produce to
+ /// the rows of the aggregate relation's output.
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>> *measures() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>> *>(VT_MEASURES);
+ }
+ /// Keys by which `aggregations` will be grouped.
+ ///
+ /// The nested list here is to support grouping sets
+ /// eg
+ ///
+ /// SELECT a, b, c, sum(d)
+ /// FROM t
+ /// GROUP BY
+ /// GROUPING SETS (
+ /// (a, b, c),
+ /// (a, b),
+ /// (a),
+ /// ()
+ /// );
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Grouping>> *groupings() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Grouping>> *>(VT_GROUPINGS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_BASE) &&
+ verifier.VerifyTable(base()) &&
+ VerifyOffsetRequired(verifier, VT_REL) &&
+ verifier.VerifyTable(rel()) &&
+ VerifyOffsetRequired(verifier, VT_MEASURES) &&
+ verifier.VerifyVector(measures()) &&
+ verifier.VerifyVectorOfTables(measures()) &&
+ VerifyOffsetRequired(verifier, VT_GROUPINGS) &&
+ verifier.VerifyVector(groupings()) &&
+ verifier.VerifyVectorOfTables(groupings()) &&
+ verifier.EndTable();
+ }
+};
+
+struct AggregateBuilder {
+ typedef Aggregate Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_base(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base) {
+ fbb_.AddOffset(Aggregate::VT_BASE, base);
+ }
+ void add_rel(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> rel) {
+ fbb_.AddOffset(Aggregate::VT_REL, rel);
+ }
+ void add_measures(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>>> measures) {
+ fbb_.AddOffset(Aggregate::VT_MEASURES, measures);
+ }
+ void add_groupings(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Grouping>>> groupings) {
+ fbb_.AddOffset(Aggregate::VT_GROUPINGS, groupings);
+ }
+ explicit AggregateBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ AggregateBuilder &operator=(const AggregateBuilder &);
+ flatbuffers::Offset<Aggregate> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Aggregate>(end);
+ fbb_.Required(o, Aggregate::VT_BASE);
+ fbb_.Required(o, Aggregate::VT_REL);
+ fbb_.Required(o, Aggregate::VT_MEASURES);
+ fbb_.Required(o, Aggregate::VT_GROUPINGS);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Aggregate> CreateAggregate(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> rel = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>>> measures = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Grouping>>> groupings = 0) {
+ AggregateBuilder builder_(_fbb);
+ builder_.add_groupings(groupings);
+ builder_.add_measures(measures);
+ builder_.add_rel(rel);
+ builder_.add_base(base);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Aggregate> CreateAggregateDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> rel = 0,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>> *measures = nullptr,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Grouping>> *groupings = nullptr) {
+ auto measures__ = measures ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression>>(*measures) : 0;
+ auto groupings__ = groupings ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Grouping>>(*groupings) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateAggregate(
+ _fbb,
+ base,
+ rel,
+ measures__,
+ groupings__);
+}
+
+/// Join between two tables
+struct Join FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef JoinBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BASE = 4,
+ VT_LEFT = 6,
+ VT_RIGHT = 8,
+ VT_ON_EXPRESSION = 10,
+ VT_JOIN_KIND = 12
+ };
+ /// Common options
+ const org::apache::arrow::computeir::flatbuf::RelBase *base() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::RelBase *>(VT_BASE);
+ }
+ /// Left relation
+ const org::apache::arrow::computeir::flatbuf::Relation *left() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Relation *>(VT_LEFT);
+ }
+ /// Right relation
+ const org::apache::arrow::computeir::flatbuf::Relation *right() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Relation *>(VT_RIGHT);
+ }
+ /// The expression which will be evaluated against rows from each
+ /// input to determine whether they should be included in the
+ /// join relation's output.
+ const org::apache::arrow::computeir::flatbuf::Expression *on_expression() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Expression *>(VT_ON_EXPRESSION);
+ }
+ /// The kind of join to use.
+ org::apache::arrow::computeir::flatbuf::JoinKind join_kind() const {
+ return static_cast<org::apache::arrow::computeir::flatbuf::JoinKind>(GetField<uint8_t>(VT_JOIN_KIND, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_BASE) &&
+ verifier.VerifyTable(base()) &&
+ VerifyOffsetRequired(verifier, VT_LEFT) &&
+ verifier.VerifyTable(left()) &&
+ VerifyOffsetRequired(verifier, VT_RIGHT) &&
+ verifier.VerifyTable(right()) &&
+ VerifyOffsetRequired(verifier, VT_ON_EXPRESSION) &&
+ verifier.VerifyTable(on_expression()) &&
+ VerifyField<uint8_t>(verifier, VT_JOIN_KIND) &&
+ verifier.EndTable();
+ }
+};
+
+struct JoinBuilder {
+ typedef Join Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_base(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base) {
+ fbb_.AddOffset(Join::VT_BASE, base);
+ }
+ void add_left(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> left) {
+ fbb_.AddOffset(Join::VT_LEFT, left);
+ }
+ void add_right(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> right) {
+ fbb_.AddOffset(Join::VT_RIGHT, right);
+ }
+ void add_on_expression(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> on_expression) {
+ fbb_.AddOffset(Join::VT_ON_EXPRESSION, on_expression);
+ }
+ void add_join_kind(org::apache::arrow::computeir::flatbuf::JoinKind join_kind) {
+ fbb_.AddElement<uint8_t>(Join::VT_JOIN_KIND, static_cast<uint8_t>(join_kind), 0);
+ }
+ explicit JoinBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ JoinBuilder &operator=(const JoinBuilder &);
+ flatbuffers::Offset<Join> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Join>(end);
+ fbb_.Required(o, Join::VT_BASE);
+ fbb_.Required(o, Join::VT_LEFT);
+ fbb_.Required(o, Join::VT_RIGHT);
+ fbb_.Required(o, Join::VT_ON_EXPRESSION);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Join> CreateJoin(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> left = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> right = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Expression> on_expression = 0,
+ org::apache::arrow::computeir::flatbuf::JoinKind join_kind = org::apache::arrow::computeir::flatbuf::JoinKind::Anti) {
+ JoinBuilder builder_(_fbb);
+ builder_.add_on_expression(on_expression);
+ builder_.add_right(right);
+ builder_.add_left(left);
+ builder_.add_base(base);
+ builder_.add_join_kind(join_kind);
+ return builder_.Finish();
+}
+
+/// Order by relation
+struct OrderBy FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef OrderByBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BASE = 4,
+ VT_REL = 6,
+ VT_KEYS = 8
+ };
+ /// Common options
+ const org::apache::arrow::computeir::flatbuf::RelBase *base() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::RelBase *>(VT_BASE);
+ }
+ /// Child relation
+ const org::apache::arrow::computeir::flatbuf::Relation *rel() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Relation *>(VT_REL);
+ }
+ /// Define sort order for rows of output.
+ /// Keys with higher precedence are ordered ahead of other keys.
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>> *keys() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>> *>(VT_KEYS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_BASE) &&
+ verifier.VerifyTable(base()) &&
+ VerifyOffsetRequired(verifier, VT_REL) &&
+ verifier.VerifyTable(rel()) &&
+ VerifyOffsetRequired(verifier, VT_KEYS) &&
+ verifier.VerifyVector(keys()) &&
+ verifier.VerifyVectorOfTables(keys()) &&
+ verifier.EndTable();
+ }
+};
+
+struct OrderByBuilder {
+ typedef OrderBy Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_base(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base) {
+ fbb_.AddOffset(OrderBy::VT_BASE, base);
+ }
+ void add_rel(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> rel) {
+ fbb_.AddOffset(OrderBy::VT_REL, rel);
+ }
+ void add_keys(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>>> keys) {
+ fbb_.AddOffset(OrderBy::VT_KEYS, keys);
+ }
+ explicit OrderByBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ OrderByBuilder &operator=(const OrderByBuilder &);
+ flatbuffers::Offset<OrderBy> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<OrderBy>(end);
+ fbb_.Required(o, OrderBy::VT_BASE);
+ fbb_.Required(o, OrderBy::VT_REL);
+ fbb_.Required(o, OrderBy::VT_KEYS);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<OrderBy> CreateOrderBy(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> rel = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>>> keys = 0) {
+ OrderByBuilder builder_(_fbb);
+ builder_.add_keys(keys);
+ builder_.add_rel(rel);
+ builder_.add_base(base);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<OrderBy> CreateOrderByDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> rel = 0,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>> *keys = nullptr) {
+ auto keys__ = keys ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::SortKey>>(*keys) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateOrderBy(
+ _fbb,
+ base,
+ rel,
+ keys__);
+}
+
+/// Limit operation
+struct Limit FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef LimitBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BASE = 4,
+ VT_REL = 6,
+ VT_OFFSET = 8,
+ VT_COUNT = 10
+ };
+ /// Common options
+ const org::apache::arrow::computeir::flatbuf::RelBase *base() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::RelBase *>(VT_BASE);
+ }
+ /// Child relation
+ const org::apache::arrow::computeir::flatbuf::Relation *rel() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::Relation *>(VT_REL);
+ }
+ /// Starting index of rows
+ uint32_t offset() const {
+ return GetField<uint32_t>(VT_OFFSET, 0);
+ }
+ /// The maximum number of rows of output.
+ uint32_t count() const {
+ return GetField<uint32_t>(VT_COUNT, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_BASE) &&
+ verifier.VerifyTable(base()) &&
+ VerifyOffsetRequired(verifier, VT_REL) &&
+ verifier.VerifyTable(rel()) &&
+ VerifyField<uint32_t>(verifier, VT_OFFSET) &&
+ VerifyField<uint32_t>(verifier, VT_COUNT) &&
+ verifier.EndTable();
+ }
+};
+
+struct LimitBuilder {
+ typedef Limit Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_base(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base) {
+ fbb_.AddOffset(Limit::VT_BASE, base);
+ }
+ void add_rel(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> rel) {
+ fbb_.AddOffset(Limit::VT_REL, rel);
+ }
+ void add_offset(uint32_t offset) {
+ fbb_.AddElement<uint32_t>(Limit::VT_OFFSET, offset, 0);
+ }
+ void add_count(uint32_t count) {
+ fbb_.AddElement<uint32_t>(Limit::VT_COUNT, count, 0);
+ }
+ explicit LimitBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ LimitBuilder &operator=(const LimitBuilder &);
+ flatbuffers::Offset<Limit> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Limit>(end);
+ fbb_.Required(o, Limit::VT_BASE);
+ fbb_.Required(o, Limit::VT_REL);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Limit> CreateLimit(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base = 0,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> rel = 0,
+ uint32_t offset = 0,
+ uint32_t count = 0) {
+ LimitBuilder builder_(_fbb);
+ builder_.add_count(count);
+ builder_.add_offset(offset);
+ builder_.add_rel(rel);
+ builder_.add_base(base);
+ return builder_.Finish();
+}
+
+/// A set operation on two or more relations
+struct SetOperation FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SetOperationBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BASE = 4,
+ VT_RELS = 6,
+ VT_SET_OP = 8
+ };
+ /// Common options
+ const org::apache::arrow::computeir::flatbuf::RelBase *base() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::RelBase *>(VT_BASE);
+ }
+ /// Child relations
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation>> *rels() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation>> *>(VT_RELS);
+ }
+ /// The kind of set operation
+ org::apache::arrow::computeir::flatbuf::SetOpKind set_op() const {
+ return static_cast<org::apache::arrow::computeir::flatbuf::SetOpKind>(GetField<uint8_t>(VT_SET_OP, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_BASE) &&
+ verifier.VerifyTable(base()) &&
+ VerifyOffsetRequired(verifier, VT_RELS) &&
+ verifier.VerifyVector(rels()) &&
+ verifier.VerifyVectorOfTables(rels()) &&
+ VerifyField<uint8_t>(verifier, VT_SET_OP) &&
+ verifier.EndTable();
+ }
+};
+
+struct SetOperationBuilder {
+ typedef SetOperation Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_base(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base) {
+ fbb_.AddOffset(SetOperation::VT_BASE, base);
+ }
+ void add_rels(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation>>> rels) {
+ fbb_.AddOffset(SetOperation::VT_RELS, rels);
+ }
+ void add_set_op(org::apache::arrow::computeir::flatbuf::SetOpKind set_op) {
+ fbb_.AddElement<uint8_t>(SetOperation::VT_SET_OP, static_cast<uint8_t>(set_op), 0);
+ }
+ explicit SetOperationBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SetOperationBuilder &operator=(const SetOperationBuilder &);
+ flatbuffers::Offset<SetOperation> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SetOperation>(end);
+ fbb_.Required(o, SetOperation::VT_BASE);
+ fbb_.Required(o, SetOperation::VT_RELS);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SetOperation> CreateSetOperation(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation>>> rels = 0,
+ org::apache::arrow::computeir::flatbuf::SetOpKind set_op = org::apache::arrow::computeir::flatbuf::SetOpKind::Union) {
+ SetOperationBuilder builder_(_fbb);
+ builder_.add_rels(rels);
+ builder_.add_base(base);
+ builder_.add_set_op(set_op);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<SetOperation> CreateSetOperationDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base = 0,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation>> *rels = nullptr,
+ org::apache::arrow::computeir::flatbuf::SetOpKind set_op = org::apache::arrow::computeir::flatbuf::SetOpKind::Union) {
+ auto rels__ = rels ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation>>(*rels) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateSetOperation(
+ _fbb,
+ base,
+ rels__,
+ set_op);
+}
+
+/// A single column of literal values.
+struct LiteralColumn FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef LiteralColumnBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_ELEMENTS = 4
+ };
+ /// The literal values of the column
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>> *elements() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>> *>(VT_ELEMENTS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_ELEMENTS) &&
+ verifier.VerifyVector(elements()) &&
+ verifier.VerifyVectorOfTables(elements()) &&
+ verifier.EndTable();
+ }
+};
+
+struct LiteralColumnBuilder {
+ typedef LiteralColumn Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_elements(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>>> elements) {
+ fbb_.AddOffset(LiteralColumn::VT_ELEMENTS, elements);
+ }
+ explicit LiteralColumnBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ LiteralColumnBuilder &operator=(const LiteralColumnBuilder &);
+ flatbuffers::Offset<LiteralColumn> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<LiteralColumn>(end);
+ fbb_.Required(o, LiteralColumn::VT_ELEMENTS);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<LiteralColumn> CreateLiteralColumn(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>>> elements = 0) {
+ LiteralColumnBuilder builder_(_fbb);
+ builder_.add_elements(elements);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<LiteralColumn> CreateLiteralColumnDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>> *elements = nullptr) {
+ auto elements__ = elements ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Literal>>(*elements) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateLiteralColumn(
+ _fbb,
+ elements__);
+}
+
+/// Literal relation
+struct LiteralRelation FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef LiteralRelationBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BASE = 4,
+ VT_COLUMNS = 6
+ };
+ /// Common options
+ const org::apache::arrow::computeir::flatbuf::RelBase *base() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::RelBase *>(VT_BASE);
+ }
+ /// The columns of this literal relation.
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::LiteralColumn>> *columns() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::LiteralColumn>> *>(VT_COLUMNS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_BASE) &&
+ verifier.VerifyTable(base()) &&
+ VerifyOffsetRequired(verifier, VT_COLUMNS) &&
+ verifier.VerifyVector(columns()) &&
+ verifier.VerifyVectorOfTables(columns()) &&
+ verifier.EndTable();
+ }
+};
+
+struct LiteralRelationBuilder {
+ typedef LiteralRelation Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_base(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base) {
+ fbb_.AddOffset(LiteralRelation::VT_BASE, base);
+ }
+ void add_columns(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::LiteralColumn>>> columns) {
+ fbb_.AddOffset(LiteralRelation::VT_COLUMNS, columns);
+ }
+ explicit LiteralRelationBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ LiteralRelationBuilder &operator=(const LiteralRelationBuilder &);
+ flatbuffers::Offset<LiteralRelation> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<LiteralRelation>(end);
+ fbb_.Required(o, LiteralRelation::VT_BASE);
+ fbb_.Required(o, LiteralRelation::VT_COLUMNS);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<LiteralRelation> CreateLiteralRelation(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::LiteralColumn>>> columns = 0) {
+ LiteralRelationBuilder builder_(_fbb);
+ builder_.add_columns(columns);
+ builder_.add_base(base);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<LiteralRelation> CreateLiteralRelationDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base = 0,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::LiteralColumn>> *columns = nullptr) {
+ auto columns__ = columns ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::LiteralColumn>>(*columns) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateLiteralRelation(
+ _fbb,
+ base,
+ columns__);
+}
+
+/// An external source of tabular data
+struct Source FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SourceBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BASE = 4,
+ VT_NAME = 6,
+ VT_SCHEMA = 8
+ };
+ const org::apache::arrow::computeir::flatbuf::RelBase *base() const {
+ return GetPointer<const org::apache::arrow::computeir::flatbuf::RelBase *>(VT_BASE);
+ }
+ const flatbuffers::String *name() const {
+ return GetPointer<const flatbuffers::String *>(VT_NAME);
+ }
+ const org::apache::arrow::flatbuf::Schema *schema() const {
+ return GetPointer<const org::apache::arrow::flatbuf::Schema *>(VT_SCHEMA);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_BASE) &&
+ verifier.VerifyTable(base()) &&
+ VerifyOffsetRequired(verifier, VT_NAME) &&
+ verifier.VerifyString(name()) &&
+ VerifyOffsetRequired(verifier, VT_SCHEMA) &&
+ verifier.VerifyTable(schema()) &&
+ verifier.EndTable();
+ }
+};
+
+struct SourceBuilder {
+ typedef Source Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_base(flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base) {
+ fbb_.AddOffset(Source::VT_BASE, base);
+ }
+ void add_name(flatbuffers::Offset<flatbuffers::String> name) {
+ fbb_.AddOffset(Source::VT_NAME, name);
+ }
+ void add_schema(flatbuffers::Offset<org::apache::arrow::flatbuf::Schema> schema) {
+ fbb_.AddOffset(Source::VT_SCHEMA, schema);
+ }
+ explicit SourceBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SourceBuilder &operator=(const SourceBuilder &);
+ flatbuffers::Offset<Source> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Source>(end);
+ fbb_.Required(o, Source::VT_BASE);
+ fbb_.Required(o, Source::VT_NAME);
+ fbb_.Required(o, Source::VT_SCHEMA);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Source> CreateSource(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base = 0,
+ flatbuffers::Offset<flatbuffers::String> name = 0,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Schema> schema = 0) {
+ SourceBuilder builder_(_fbb);
+ builder_.add_schema(schema);
+ builder_.add_name(name);
+ builder_.add_base(base);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Source> CreateSourceDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::RelBase> base = 0,
+ const char *name = nullptr,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Schema> schema = 0) {
+ auto name__ = name ? _fbb.CreateString(name) : 0;
+ return org::apache::arrow::computeir::flatbuf::CreateSource(
+ _fbb,
+ base,
+ name__,
+ schema);
+}
+
+/// A table holding an instance of the possible relation types.
+struct Relation FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef RelationBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_IMPL_TYPE = 4,
+ VT_IMPL = 6
+ };
+ org::apache::arrow::computeir::flatbuf::RelationImpl impl_type() const {
+ return static_cast<org::apache::arrow::computeir::flatbuf::RelationImpl>(GetField<uint8_t>(VT_IMPL_TYPE, 0));
+ }
+ const void *impl() const {
+ return GetPointer<const void *>(VT_IMPL);
+ }
+ template<typename T> const T *impl_as() const;
+ const org::apache::arrow::computeir::flatbuf::Aggregate *impl_as_Aggregate() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::RelationImpl::Aggregate ? static_cast<const org::apache::arrow::computeir::flatbuf::Aggregate *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Filter *impl_as_Filter() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::RelationImpl::Filter ? static_cast<const org::apache::arrow::computeir::flatbuf::Filter *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Join *impl_as_Join() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::RelationImpl::Join ? static_cast<const org::apache::arrow::computeir::flatbuf::Join *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Limit *impl_as_Limit() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::RelationImpl::Limit ? static_cast<const org::apache::arrow::computeir::flatbuf::Limit *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::LiteralRelation *impl_as_LiteralRelation() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::RelationImpl::LiteralRelation ? static_cast<const org::apache::arrow::computeir::flatbuf::LiteralRelation *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::OrderBy *impl_as_OrderBy() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::RelationImpl::OrderBy ? static_cast<const org::apache::arrow::computeir::flatbuf::OrderBy *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Project *impl_as_Project() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::RelationImpl::Project ? static_cast<const org::apache::arrow::computeir::flatbuf::Project *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::SetOperation *impl_as_SetOperation() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::RelationImpl::SetOperation ? static_cast<const org::apache::arrow::computeir::flatbuf::SetOperation *>(impl()) : nullptr;
+ }
+ const org::apache::arrow::computeir::flatbuf::Source *impl_as_Source() const {
+ return impl_type() == org::apache::arrow::computeir::flatbuf::RelationImpl::Source ? static_cast<const org::apache::arrow::computeir::flatbuf::Source *>(impl()) : nullptr;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_IMPL_TYPE) &&
+ VerifyOffsetRequired(verifier, VT_IMPL) &&
+ VerifyRelationImpl(verifier, impl(), impl_type()) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Aggregate *Relation::impl_as<org::apache::arrow::computeir::flatbuf::Aggregate>() const {
+ return impl_as_Aggregate();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Filter *Relation::impl_as<org::apache::arrow::computeir::flatbuf::Filter>() const {
+ return impl_as_Filter();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Join *Relation::impl_as<org::apache::arrow::computeir::flatbuf::Join>() const {
+ return impl_as_Join();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Limit *Relation::impl_as<org::apache::arrow::computeir::flatbuf::Limit>() const {
+ return impl_as_Limit();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::LiteralRelation *Relation::impl_as<org::apache::arrow::computeir::flatbuf::LiteralRelation>() const {
+ return impl_as_LiteralRelation();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::OrderBy *Relation::impl_as<org::apache::arrow::computeir::flatbuf::OrderBy>() const {
+ return impl_as_OrderBy();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Project *Relation::impl_as<org::apache::arrow::computeir::flatbuf::Project>() const {
+ return impl_as_Project();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::SetOperation *Relation::impl_as<org::apache::arrow::computeir::flatbuf::SetOperation>() const {
+ return impl_as_SetOperation();
+}
+
+template<> inline const org::apache::arrow::computeir::flatbuf::Source *Relation::impl_as<org::apache::arrow::computeir::flatbuf::Source>() const {
+ return impl_as_Source();
+}
+
+struct RelationBuilder {
+ typedef Relation Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_impl_type(org::apache::arrow::computeir::flatbuf::RelationImpl impl_type) {
+ fbb_.AddElement<uint8_t>(Relation::VT_IMPL_TYPE, static_cast<uint8_t>(impl_type), 0);
+ }
+ void add_impl(flatbuffers::Offset<void> impl) {
+ fbb_.AddOffset(Relation::VT_IMPL, impl);
+ }
+ explicit RelationBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ RelationBuilder &operator=(const RelationBuilder &);
+ flatbuffers::Offset<Relation> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Relation>(end);
+ fbb_.Required(o, Relation::VT_IMPL);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Relation> CreateRelation(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::computeir::flatbuf::RelationImpl impl_type = org::apache::arrow::computeir::flatbuf::RelationImpl::NONE,
+ flatbuffers::Offset<void> impl = 0) {
+ RelationBuilder builder_(_fbb);
+ builder_.add_impl(impl);
+ builder_.add_impl_type(impl_type);
+ return builder_.Finish();
+}
+
+inline bool VerifyEmit(flatbuffers::Verifier &verifier, const void *obj, Emit type) {
+ switch (type) {
+ case Emit::NONE: {
+ return true;
+ }
+ case Emit::Remap: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Remap *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Emit::PassThrough: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::PassThrough *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return true;
+ }
+}
+
+inline bool VerifyEmitVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifyEmit(
+ verifier, values->Get(i), types->GetEnum<Emit>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline bool VerifyRelationImpl(flatbuffers::Verifier &verifier, const void *obj, RelationImpl type) {
+ switch (type) {
+ case RelationImpl::NONE: {
+ return true;
+ }
+ case RelationImpl::Aggregate: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Aggregate *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case RelationImpl::Filter: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Filter *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case RelationImpl::Join: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Join *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case RelationImpl::Limit: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Limit *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case RelationImpl::LiteralRelation: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::LiteralRelation *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case RelationImpl::OrderBy: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::OrderBy *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case RelationImpl::Project: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Project *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case RelationImpl::SetOperation: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::SetOperation *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case RelationImpl::Source: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::computeir::flatbuf::Source *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return true;
+ }
+}
+
+inline bool VerifyRelationImplVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifyRelationImpl(
+ verifier, values->Get(i), types->GetEnum<RelationImpl>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline const org::apache::arrow::computeir::flatbuf::Relation *GetRelation(const void *buf) {
+ return flatbuffers::GetRoot<org::apache::arrow::computeir::flatbuf::Relation>(buf);
+}
+
+inline const org::apache::arrow::computeir::flatbuf::Relation *GetSizePrefixedRelation(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<org::apache::arrow::computeir::flatbuf::Relation>(buf);
+}
+
+inline bool VerifyRelationBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifyBuffer<org::apache::arrow::computeir::flatbuf::Relation>(nullptr);
+}
+
+inline bool VerifySizePrefixedRelationBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<org::apache::arrow::computeir::flatbuf::Relation>(nullptr);
+}
+
+inline void FinishRelationBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> root) {
+ fbb.Finish(root);
+}
+
+inline void FinishSizePrefixedRelationBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::computeir::flatbuf::Relation> root) {
+ fbb.FinishSizePrefixed(root);
+}
+
+} // namespace flatbuf
+} // namespace computeir
+} // namespace arrow
+} // namespace apache
+} // namespace org
+
+#endif // FLATBUFFERS_GENERATED_RELATION_ORG_APACHE_ARROW_COMPUTEIR_FLATBUF_H_
diff --git a/src/arrow/cpp/src/generated/Schema_generated.h b/src/arrow/cpp/src/generated/Schema_generated.h
new file mode 100644
index 000000000..79ffa661e
--- /dev/null
+++ b/src/arrow/cpp/src/generated/Schema_generated.h
@@ -0,0 +1,2367 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_SCHEMA_ORG_APACHE_ARROW_FLATBUF_H_
+#define FLATBUFFERS_GENERATED_SCHEMA_ORG_APACHE_ARROW_FLATBUF_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+namespace org {
+namespace apache {
+namespace arrow {
+namespace flatbuf {
+
+struct Null;
+struct NullBuilder;
+
+struct Struct_;
+struct Struct_Builder;
+
+struct List;
+struct ListBuilder;
+
+struct LargeList;
+struct LargeListBuilder;
+
+struct FixedSizeList;
+struct FixedSizeListBuilder;
+
+struct Map;
+struct MapBuilder;
+
+struct Union;
+struct UnionBuilder;
+
+struct Int;
+struct IntBuilder;
+
+struct FloatingPoint;
+struct FloatingPointBuilder;
+
+struct Utf8;
+struct Utf8Builder;
+
+struct Binary;
+struct BinaryBuilder;
+
+struct LargeUtf8;
+struct LargeUtf8Builder;
+
+struct LargeBinary;
+struct LargeBinaryBuilder;
+
+struct FixedSizeBinary;
+struct FixedSizeBinaryBuilder;
+
+struct Bool;
+struct BoolBuilder;
+
+struct Decimal;
+struct DecimalBuilder;
+
+struct Date;
+struct DateBuilder;
+
+struct Time;
+struct TimeBuilder;
+
+struct Timestamp;
+struct TimestampBuilder;
+
+struct Interval;
+struct IntervalBuilder;
+
+struct Duration;
+struct DurationBuilder;
+
+struct KeyValue;
+struct KeyValueBuilder;
+
+struct DictionaryEncoding;
+struct DictionaryEncodingBuilder;
+
+struct Field;
+struct FieldBuilder;
+
+struct Buffer;
+
+struct Schema;
+struct SchemaBuilder;
+
+enum class MetadataVersion : int16_t {
+ /// 0.1.0 (October 2016).
+ V1 = 0,
+ /// 0.2.0 (February 2017). Non-backwards compatible with V1.
+ V2 = 1,
+ /// 0.3.0 -> 0.7.1 (May - December 2017). Non-backwards compatible with V2.
+ V3 = 2,
+ /// >= 0.8.0 (December 2017). Non-backwards compatible with V3.
+ V4 = 3,
+ /// >= 1.0.0 (July 2020. Backwards compatible with V4 (V5 readers can read V4
+ /// metadata and IPC messages). Implementations are recommended to provide a
+ /// V4 compatibility mode with V5 format changes disabled.
+ ///
+ /// Incompatible changes between V4 and V5:
+ /// - Union buffer layout has changed. In V5, Unions don't have a validity
+ /// bitmap buffer.
+ V5 = 4,
+ MIN = V1,
+ MAX = V5
+};
+
+inline const MetadataVersion (&EnumValuesMetadataVersion())[5] {
+ static const MetadataVersion values[] = {
+ MetadataVersion::V1,
+ MetadataVersion::V2,
+ MetadataVersion::V3,
+ MetadataVersion::V4,
+ MetadataVersion::V5
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesMetadataVersion() {
+ static const char * const names[6] = {
+ "V1",
+ "V2",
+ "V3",
+ "V4",
+ "V5",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameMetadataVersion(MetadataVersion e) {
+ if (flatbuffers::IsOutRange(e, MetadataVersion::V1, MetadataVersion::V5)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesMetadataVersion()[index];
+}
+
+/// Represents Arrow Features that might not have full support
+/// within implementations. This is intended to be used in
+/// two scenarios:
+/// 1. A mechanism for readers of Arrow Streams
+/// and files to understand that the stream or file makes
+/// use of a feature that isn't supported or unknown to
+/// the implementation (and therefore can meet the Arrow
+/// forward compatibility guarantees).
+/// 2. A means of negotiating between a client and server
+/// what features a stream is allowed to use. The enums
+/// values here are intented to represent higher level
+/// features, additional details maybe negotiated
+/// with key-value pairs specific to the protocol.
+///
+/// Enums added to this list should be assigned power-of-two values
+/// to facilitate exchanging and comparing bitmaps for supported
+/// features.
+enum class Feature : int64_t {
+ /// Needed to make flatbuffers happy.
+ UNUSED = 0,
+ /// The stream makes use of multiple full dictionaries with the
+ /// same ID and assumes clients implement dictionary replacement
+ /// correctly.
+ DICTIONARY_REPLACEMENT = 1LL,
+ /// The stream makes use of compressed bodies as described
+ /// in Message.fbs.
+ COMPRESSED_BODY = 2LL,
+ MIN = UNUSED,
+ MAX = COMPRESSED_BODY
+};
+
+inline const Feature (&EnumValuesFeature())[3] {
+ static const Feature values[] = {
+ Feature::UNUSED,
+ Feature::DICTIONARY_REPLACEMENT,
+ Feature::COMPRESSED_BODY
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesFeature() {
+ static const char * const names[4] = {
+ "UNUSED",
+ "DICTIONARY_REPLACEMENT",
+ "COMPRESSED_BODY",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameFeature(Feature e) {
+ if (flatbuffers::IsOutRange(e, Feature::UNUSED, Feature::COMPRESSED_BODY)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesFeature()[index];
+}
+
+enum class UnionMode : int16_t {
+ Sparse = 0,
+ Dense = 1,
+ MIN = Sparse,
+ MAX = Dense
+};
+
+inline const UnionMode (&EnumValuesUnionMode())[2] {
+ static const UnionMode values[] = {
+ UnionMode::Sparse,
+ UnionMode::Dense
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesUnionMode() {
+ static const char * const names[3] = {
+ "Sparse",
+ "Dense",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameUnionMode(UnionMode e) {
+ if (flatbuffers::IsOutRange(e, UnionMode::Sparse, UnionMode::Dense)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesUnionMode()[index];
+}
+
+enum class Precision : int16_t {
+ HALF = 0,
+ SINGLE = 1,
+ DOUBLE = 2,
+ MIN = HALF,
+ MAX = DOUBLE
+};
+
+inline const Precision (&EnumValuesPrecision())[3] {
+ static const Precision values[] = {
+ Precision::HALF,
+ Precision::SINGLE,
+ Precision::DOUBLE
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesPrecision() {
+ static const char * const names[4] = {
+ "HALF",
+ "SINGLE",
+ "DOUBLE",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNamePrecision(Precision e) {
+ if (flatbuffers::IsOutRange(e, Precision::HALF, Precision::DOUBLE)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesPrecision()[index];
+}
+
+enum class DateUnit : int16_t {
+ DAY = 0,
+ MILLISECOND = 1,
+ MIN = DAY,
+ MAX = MILLISECOND
+};
+
+inline const DateUnit (&EnumValuesDateUnit())[2] {
+ static const DateUnit values[] = {
+ DateUnit::DAY,
+ DateUnit::MILLISECOND
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesDateUnit() {
+ static const char * const names[3] = {
+ "DAY",
+ "MILLISECOND",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameDateUnit(DateUnit e) {
+ if (flatbuffers::IsOutRange(e, DateUnit::DAY, DateUnit::MILLISECOND)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesDateUnit()[index];
+}
+
+enum class TimeUnit : int16_t {
+ SECOND = 0,
+ MILLISECOND = 1,
+ MICROSECOND = 2,
+ NANOSECOND = 3,
+ MIN = SECOND,
+ MAX = NANOSECOND
+};
+
+inline const TimeUnit (&EnumValuesTimeUnit())[4] {
+ static const TimeUnit values[] = {
+ TimeUnit::SECOND,
+ TimeUnit::MILLISECOND,
+ TimeUnit::MICROSECOND,
+ TimeUnit::NANOSECOND
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesTimeUnit() {
+ static const char * const names[5] = {
+ "SECOND",
+ "MILLISECOND",
+ "MICROSECOND",
+ "NANOSECOND",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameTimeUnit(TimeUnit e) {
+ if (flatbuffers::IsOutRange(e, TimeUnit::SECOND, TimeUnit::NANOSECOND)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesTimeUnit()[index];
+}
+
+enum class IntervalUnit : int16_t {
+ YEAR_MONTH = 0,
+ DAY_TIME = 1,
+ MONTH_DAY_NANO = 2,
+ MIN = YEAR_MONTH,
+ MAX = MONTH_DAY_NANO
+};
+
+inline const IntervalUnit (&EnumValuesIntervalUnit())[3] {
+ static const IntervalUnit values[] = {
+ IntervalUnit::YEAR_MONTH,
+ IntervalUnit::DAY_TIME,
+ IntervalUnit::MONTH_DAY_NANO
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesIntervalUnit() {
+ static const char * const names[4] = {
+ "YEAR_MONTH",
+ "DAY_TIME",
+ "MONTH_DAY_NANO",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameIntervalUnit(IntervalUnit e) {
+ if (flatbuffers::IsOutRange(e, IntervalUnit::YEAR_MONTH, IntervalUnit::MONTH_DAY_NANO)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesIntervalUnit()[index];
+}
+
+/// ----------------------------------------------------------------------
+/// Top-level Type value, enabling extensible type-specific metadata. We can
+/// add new logical types to Type without breaking backwards compatibility
+enum class Type : uint8_t {
+ NONE = 0,
+ Null = 1,
+ Int = 2,
+ FloatingPoint = 3,
+ Binary = 4,
+ Utf8 = 5,
+ Bool = 6,
+ Decimal = 7,
+ Date = 8,
+ Time = 9,
+ Timestamp = 10,
+ Interval = 11,
+ List = 12,
+ Struct_ = 13,
+ Union = 14,
+ FixedSizeBinary = 15,
+ FixedSizeList = 16,
+ Map = 17,
+ Duration = 18,
+ LargeBinary = 19,
+ LargeUtf8 = 20,
+ LargeList = 21,
+ MIN = NONE,
+ MAX = LargeList
+};
+
+inline const Type (&EnumValuesType())[22] {
+ static const Type values[] = {
+ Type::NONE,
+ Type::Null,
+ Type::Int,
+ Type::FloatingPoint,
+ Type::Binary,
+ Type::Utf8,
+ Type::Bool,
+ Type::Decimal,
+ Type::Date,
+ Type::Time,
+ Type::Timestamp,
+ Type::Interval,
+ Type::List,
+ Type::Struct_,
+ Type::Union,
+ Type::FixedSizeBinary,
+ Type::FixedSizeList,
+ Type::Map,
+ Type::Duration,
+ Type::LargeBinary,
+ Type::LargeUtf8,
+ Type::LargeList
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesType() {
+ static const char * const names[23] = {
+ "NONE",
+ "Null",
+ "Int",
+ "FloatingPoint",
+ "Binary",
+ "Utf8",
+ "Bool",
+ "Decimal",
+ "Date",
+ "Time",
+ "Timestamp",
+ "Interval",
+ "List",
+ "Struct_",
+ "Union",
+ "FixedSizeBinary",
+ "FixedSizeList",
+ "Map",
+ "Duration",
+ "LargeBinary",
+ "LargeUtf8",
+ "LargeList",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameType(Type e) {
+ if (flatbuffers::IsOutRange(e, Type::NONE, Type::LargeList)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesType()[index];
+}
+
+template<typename T> struct TypeTraits {
+ static const Type enum_value = Type::NONE;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::Null> {
+ static const Type enum_value = Type::Null;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::Int> {
+ static const Type enum_value = Type::Int;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::FloatingPoint> {
+ static const Type enum_value = Type::FloatingPoint;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::Binary> {
+ static const Type enum_value = Type::Binary;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::Utf8> {
+ static const Type enum_value = Type::Utf8;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::Bool> {
+ static const Type enum_value = Type::Bool;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::Decimal> {
+ static const Type enum_value = Type::Decimal;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::Date> {
+ static const Type enum_value = Type::Date;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::Time> {
+ static const Type enum_value = Type::Time;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::Timestamp> {
+ static const Type enum_value = Type::Timestamp;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::Interval> {
+ static const Type enum_value = Type::Interval;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::List> {
+ static const Type enum_value = Type::List;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::Struct_> {
+ static const Type enum_value = Type::Struct_;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::Union> {
+ static const Type enum_value = Type::Union;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::FixedSizeBinary> {
+ static const Type enum_value = Type::FixedSizeBinary;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::FixedSizeList> {
+ static const Type enum_value = Type::FixedSizeList;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::Map> {
+ static const Type enum_value = Type::Map;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::Duration> {
+ static const Type enum_value = Type::Duration;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::LargeBinary> {
+ static const Type enum_value = Type::LargeBinary;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::LargeUtf8> {
+ static const Type enum_value = Type::LargeUtf8;
+};
+
+template<> struct TypeTraits<org::apache::arrow::flatbuf::LargeList> {
+ static const Type enum_value = Type::LargeList;
+};
+
+bool VerifyType(flatbuffers::Verifier &verifier, const void *obj, Type type);
+bool VerifyTypeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+/// ----------------------------------------------------------------------
+/// Dictionary encoding metadata
+/// Maintained for forwards compatibility, in the future
+/// Dictionaries might be explicit maps between integers and values
+/// allowing for non-contiguous index values
+enum class DictionaryKind : int16_t {
+ DenseArray = 0,
+ MIN = DenseArray,
+ MAX = DenseArray
+};
+
+inline const DictionaryKind (&EnumValuesDictionaryKind())[1] {
+ static const DictionaryKind values[] = {
+ DictionaryKind::DenseArray
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesDictionaryKind() {
+ static const char * const names[2] = {
+ "DenseArray",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameDictionaryKind(DictionaryKind e) {
+ if (flatbuffers::IsOutRange(e, DictionaryKind::DenseArray, DictionaryKind::DenseArray)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesDictionaryKind()[index];
+}
+
+/// ----------------------------------------------------------------------
+/// Endianness of the platform producing the data
+enum class Endianness : int16_t {
+ Little = 0,
+ Big = 1,
+ MIN = Little,
+ MAX = Big
+};
+
+inline const Endianness (&EnumValuesEndianness())[2] {
+ static const Endianness values[] = {
+ Endianness::Little,
+ Endianness::Big
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesEndianness() {
+ static const char * const names[3] = {
+ "Little",
+ "Big",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameEndianness(Endianness e) {
+ if (flatbuffers::IsOutRange(e, Endianness::Little, Endianness::Big)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesEndianness()[index];
+}
+
+/// ----------------------------------------------------------------------
+/// A Buffer represents a single contiguous memory segment
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) Buffer FLATBUFFERS_FINAL_CLASS {
+ private:
+ int64_t offset_;
+ int64_t length_;
+
+ public:
+ Buffer() {
+ memset(static_cast<void *>(this), 0, sizeof(Buffer));
+ }
+ Buffer(int64_t _offset, int64_t _length)
+ : offset_(flatbuffers::EndianScalar(_offset)),
+ length_(flatbuffers::EndianScalar(_length)) {
+ }
+ /// The relative offset into the shared memory page where the bytes for this
+ /// buffer starts
+ int64_t offset() const {
+ return flatbuffers::EndianScalar(offset_);
+ }
+ /// The absolute length (in bytes) of the memory buffer. The memory is found
+ /// from offset (inclusive) to offset + length (non-inclusive). When building
+ /// messages using the encapsulated IPC message, padding bytes may be written
+ /// after a buffer, but such padding bytes do not need to be accounted for in
+ /// the size here.
+ int64_t length() const {
+ return flatbuffers::EndianScalar(length_);
+ }
+};
+FLATBUFFERS_STRUCT_END(Buffer, 16);
+
+/// These are stored in the flatbuffer in the Type union below
+struct Null FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef NullBuilder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+};
+
+struct NullBuilder {
+ typedef Null Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit NullBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ NullBuilder &operator=(const NullBuilder &);
+ flatbuffers::Offset<Null> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Null>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Null> CreateNull(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ NullBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+/// A Struct_ in the flatbuffer metadata is the same as an Arrow Struct
+/// (according to the physical memory layout). We used Struct_ here as
+/// Struct is a reserved word in Flatbuffers
+struct Struct_ FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef Struct_Builder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+};
+
+struct Struct_Builder {
+ typedef Struct_ Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit Struct_Builder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ Struct_Builder &operator=(const Struct_Builder &);
+ flatbuffers::Offset<Struct_> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Struct_>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Struct_> CreateStruct_(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ Struct_Builder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct List FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ListBuilder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+};
+
+struct ListBuilder {
+ typedef List Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit ListBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ListBuilder &operator=(const ListBuilder &);
+ flatbuffers::Offset<List> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<List>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<List> CreateList(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ ListBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+/// Same as List, but with 64-bit offsets, allowing to represent
+/// extremely large data values.
+struct LargeList FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef LargeListBuilder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+};
+
+struct LargeListBuilder {
+ typedef LargeList Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit LargeListBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ LargeListBuilder &operator=(const LargeListBuilder &);
+ flatbuffers::Offset<LargeList> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<LargeList>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<LargeList> CreateLargeList(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ LargeListBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct FixedSizeList FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FixedSizeListBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_LISTSIZE = 4
+ };
+ /// Number of list items per value
+ int32_t listSize() const {
+ return GetField<int32_t>(VT_LISTSIZE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_LISTSIZE) &&
+ verifier.EndTable();
+ }
+};
+
+struct FixedSizeListBuilder {
+ typedef FixedSizeList Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_listSize(int32_t listSize) {
+ fbb_.AddElement<int32_t>(FixedSizeList::VT_LISTSIZE, listSize, 0);
+ }
+ explicit FixedSizeListBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FixedSizeListBuilder &operator=(const FixedSizeListBuilder &);
+ flatbuffers::Offset<FixedSizeList> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<FixedSizeList>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<FixedSizeList> CreateFixedSizeList(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t listSize = 0) {
+ FixedSizeListBuilder builder_(_fbb);
+ builder_.add_listSize(listSize);
+ return builder_.Finish();
+}
+
+/// A Map is a logical nested type that is represented as
+///
+/// List<entries: Struct<key: K, value: V>>
+///
+/// In this layout, the keys and values are each respectively contiguous. We do
+/// not constrain the key and value types, so the application is responsible
+/// for ensuring that the keys are hashable and unique. Whether the keys are sorted
+/// may be set in the metadata for this field.
+///
+/// In a field with Map type, the field has a child Struct field, which then
+/// has two children: key type and the second the value type. The names of the
+/// child fields may be respectively "entries", "key", and "value", but this is
+/// not enforced.
+///
+/// Map
+/// ```text
+/// - child[0] entries: Struct
+/// - child[0] key: K
+/// - child[1] value: V
+/// ```
+/// Neither the "entries" field nor the "key" field may be nullable.
+///
+/// The metadata is structured so that Arrow systems without special handling
+/// for Map can make Map an alias for List. The "layout" attribute for the Map
+/// field must have the same contents as a List.
+struct Map FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef MapBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_KEYSSORTED = 4
+ };
+ /// Set to true if the keys within each value are sorted
+ bool keysSorted() const {
+ return GetField<uint8_t>(VT_KEYSSORTED, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_KEYSSORTED) &&
+ verifier.EndTable();
+ }
+};
+
+struct MapBuilder {
+ typedef Map Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_keysSorted(bool keysSorted) {
+ fbb_.AddElement<uint8_t>(Map::VT_KEYSSORTED, static_cast<uint8_t>(keysSorted), 0);
+ }
+ explicit MapBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ MapBuilder &operator=(const MapBuilder &);
+ flatbuffers::Offset<Map> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Map>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Map> CreateMap(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ bool keysSorted = false) {
+ MapBuilder builder_(_fbb);
+ builder_.add_keysSorted(keysSorted);
+ return builder_.Finish();
+}
+
+/// A union is a complex type with children in Field
+/// By default ids in the type vector refer to the offsets in the children
+/// optionally typeIds provides an indirection between the child offset and the type id
+/// for each child `typeIds[offset]` is the id used in the type vector
+struct Union FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef UnionBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_MODE = 4,
+ VT_TYPEIDS = 6
+ };
+ org::apache::arrow::flatbuf::UnionMode mode() const {
+ return static_cast<org::apache::arrow::flatbuf::UnionMode>(GetField<int16_t>(VT_MODE, 0));
+ }
+ const flatbuffers::Vector<int32_t> *typeIds() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_TYPEIDS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int16_t>(verifier, VT_MODE) &&
+ VerifyOffset(verifier, VT_TYPEIDS) &&
+ verifier.VerifyVector(typeIds()) &&
+ verifier.EndTable();
+ }
+};
+
+struct UnionBuilder {
+ typedef Union Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_mode(org::apache::arrow::flatbuf::UnionMode mode) {
+ fbb_.AddElement<int16_t>(Union::VT_MODE, static_cast<int16_t>(mode), 0);
+ }
+ void add_typeIds(flatbuffers::Offset<flatbuffers::Vector<int32_t>> typeIds) {
+ fbb_.AddOffset(Union::VT_TYPEIDS, typeIds);
+ }
+ explicit UnionBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ UnionBuilder &operator=(const UnionBuilder &);
+ flatbuffers::Offset<Union> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Union>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Union> CreateUnion(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::UnionMode mode = org::apache::arrow::flatbuf::UnionMode::Sparse,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> typeIds = 0) {
+ UnionBuilder builder_(_fbb);
+ builder_.add_typeIds(typeIds);
+ builder_.add_mode(mode);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Union> CreateUnionDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::UnionMode mode = org::apache::arrow::flatbuf::UnionMode::Sparse,
+ const std::vector<int32_t> *typeIds = nullptr) {
+ auto typeIds__ = typeIds ? _fbb.CreateVector<int32_t>(*typeIds) : 0;
+ return org::apache::arrow::flatbuf::CreateUnion(
+ _fbb,
+ mode,
+ typeIds__);
+}
+
+struct Int FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef IntBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BITWIDTH = 4,
+ VT_IS_SIGNED = 6
+ };
+ int32_t bitWidth() const {
+ return GetField<int32_t>(VT_BITWIDTH, 0);
+ }
+ bool is_signed() const {
+ return GetField<uint8_t>(VT_IS_SIGNED, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_BITWIDTH) &&
+ VerifyField<uint8_t>(verifier, VT_IS_SIGNED) &&
+ verifier.EndTable();
+ }
+};
+
+struct IntBuilder {
+ typedef Int Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_bitWidth(int32_t bitWidth) {
+ fbb_.AddElement<int32_t>(Int::VT_BITWIDTH, bitWidth, 0);
+ }
+ void add_is_signed(bool is_signed) {
+ fbb_.AddElement<uint8_t>(Int::VT_IS_SIGNED, static_cast<uint8_t>(is_signed), 0);
+ }
+ explicit IntBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ IntBuilder &operator=(const IntBuilder &);
+ flatbuffers::Offset<Int> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Int>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Int> CreateInt(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t bitWidth = 0,
+ bool is_signed = false) {
+ IntBuilder builder_(_fbb);
+ builder_.add_bitWidth(bitWidth);
+ builder_.add_is_signed(is_signed);
+ return builder_.Finish();
+}
+
+struct FloatingPoint FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FloatingPointBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_PRECISION = 4
+ };
+ org::apache::arrow::flatbuf::Precision precision() const {
+ return static_cast<org::apache::arrow::flatbuf::Precision>(GetField<int16_t>(VT_PRECISION, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int16_t>(verifier, VT_PRECISION) &&
+ verifier.EndTable();
+ }
+};
+
+struct FloatingPointBuilder {
+ typedef FloatingPoint Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_precision(org::apache::arrow::flatbuf::Precision precision) {
+ fbb_.AddElement<int16_t>(FloatingPoint::VT_PRECISION, static_cast<int16_t>(precision), 0);
+ }
+ explicit FloatingPointBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FloatingPointBuilder &operator=(const FloatingPointBuilder &);
+ flatbuffers::Offset<FloatingPoint> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<FloatingPoint>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<FloatingPoint> CreateFloatingPoint(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::Precision precision = org::apache::arrow::flatbuf::Precision::HALF) {
+ FloatingPointBuilder builder_(_fbb);
+ builder_.add_precision(precision);
+ return builder_.Finish();
+}
+
+/// Unicode with UTF-8 encoding
+struct Utf8 FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef Utf8Builder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+};
+
+struct Utf8Builder {
+ typedef Utf8 Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit Utf8Builder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ Utf8Builder &operator=(const Utf8Builder &);
+ flatbuffers::Offset<Utf8> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Utf8>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Utf8> CreateUtf8(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ Utf8Builder builder_(_fbb);
+ return builder_.Finish();
+}
+
+/// Opaque binary data
+struct Binary FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef BinaryBuilder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+};
+
+struct BinaryBuilder {
+ typedef Binary Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit BinaryBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ BinaryBuilder &operator=(const BinaryBuilder &);
+ flatbuffers::Offset<Binary> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Binary>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Binary> CreateBinary(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ BinaryBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+/// Same as Utf8, but with 64-bit offsets, allowing to represent
+/// extremely large data values.
+struct LargeUtf8 FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef LargeUtf8Builder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+};
+
+struct LargeUtf8Builder {
+ typedef LargeUtf8 Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit LargeUtf8Builder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ LargeUtf8Builder &operator=(const LargeUtf8Builder &);
+ flatbuffers::Offset<LargeUtf8> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<LargeUtf8>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<LargeUtf8> CreateLargeUtf8(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ LargeUtf8Builder builder_(_fbb);
+ return builder_.Finish();
+}
+
+/// Same as Binary, but with 64-bit offsets, allowing to represent
+/// extremely large data values.
+struct LargeBinary FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef LargeBinaryBuilder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+};
+
+struct LargeBinaryBuilder {
+ typedef LargeBinary Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit LargeBinaryBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ LargeBinaryBuilder &operator=(const LargeBinaryBuilder &);
+ flatbuffers::Offset<LargeBinary> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<LargeBinary>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<LargeBinary> CreateLargeBinary(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ LargeBinaryBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct FixedSizeBinary FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FixedSizeBinaryBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BYTEWIDTH = 4
+ };
+ /// Number of bytes per value
+ int32_t byteWidth() const {
+ return GetField<int32_t>(VT_BYTEWIDTH, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_BYTEWIDTH) &&
+ verifier.EndTable();
+ }
+};
+
+struct FixedSizeBinaryBuilder {
+ typedef FixedSizeBinary Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_byteWidth(int32_t byteWidth) {
+ fbb_.AddElement<int32_t>(FixedSizeBinary::VT_BYTEWIDTH, byteWidth, 0);
+ }
+ explicit FixedSizeBinaryBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FixedSizeBinaryBuilder &operator=(const FixedSizeBinaryBuilder &);
+ flatbuffers::Offset<FixedSizeBinary> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<FixedSizeBinary>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<FixedSizeBinary> CreateFixedSizeBinary(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t byteWidth = 0) {
+ FixedSizeBinaryBuilder builder_(_fbb);
+ builder_.add_byteWidth(byteWidth);
+ return builder_.Finish();
+}
+
+struct Bool FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef BoolBuilder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+};
+
+struct BoolBuilder {
+ typedef Bool Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit BoolBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ BoolBuilder &operator=(const BoolBuilder &);
+ flatbuffers::Offset<Bool> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Bool>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Bool> CreateBool(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ BoolBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+/// Exact decimal value represented as an integer value in two's
+/// complement. Currently only 128-bit (16-byte) and 256-bit (32-byte) integers
+/// are used. The representation uses the endianness indicated
+/// in the Schema.
+struct Decimal FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef DecimalBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_PRECISION = 4,
+ VT_SCALE = 6,
+ VT_BITWIDTH = 8
+ };
+ /// Total number of decimal digits
+ int32_t precision() const {
+ return GetField<int32_t>(VT_PRECISION, 0);
+ }
+ /// Number of digits after the decimal point "."
+ int32_t scale() const {
+ return GetField<int32_t>(VT_SCALE, 0);
+ }
+ /// Number of bits per value. The only accepted widths are 128 and 256.
+ /// We use bitWidth for consistency with Int::bitWidth.
+ int32_t bitWidth() const {
+ return GetField<int32_t>(VT_BITWIDTH, 128);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_PRECISION) &&
+ VerifyField<int32_t>(verifier, VT_SCALE) &&
+ VerifyField<int32_t>(verifier, VT_BITWIDTH) &&
+ verifier.EndTable();
+ }
+};
+
+struct DecimalBuilder {
+ typedef Decimal Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_precision(int32_t precision) {
+ fbb_.AddElement<int32_t>(Decimal::VT_PRECISION, precision, 0);
+ }
+ void add_scale(int32_t scale) {
+ fbb_.AddElement<int32_t>(Decimal::VT_SCALE, scale, 0);
+ }
+ void add_bitWidth(int32_t bitWidth) {
+ fbb_.AddElement<int32_t>(Decimal::VT_BITWIDTH, bitWidth, 128);
+ }
+ explicit DecimalBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ DecimalBuilder &operator=(const DecimalBuilder &);
+ flatbuffers::Offset<Decimal> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Decimal>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Decimal> CreateDecimal(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t precision = 0,
+ int32_t scale = 0,
+ int32_t bitWidth = 128) {
+ DecimalBuilder builder_(_fbb);
+ builder_.add_bitWidth(bitWidth);
+ builder_.add_scale(scale);
+ builder_.add_precision(precision);
+ return builder_.Finish();
+}
+
+/// Date is either a 32-bit or 64-bit signed integer type representing an
+/// elapsed time since UNIX epoch (1970-01-01), stored in either of two units:
+///
+/// * Milliseconds (64 bits) indicating UNIX time elapsed since the epoch (no
+/// leap seconds), where the values are evenly divisible by 86400000
+/// * Days (32 bits) since the UNIX epoch
+struct Date FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef DateBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_UNIT = 4
+ };
+ org::apache::arrow::flatbuf::DateUnit unit() const {
+ return static_cast<org::apache::arrow::flatbuf::DateUnit>(GetField<int16_t>(VT_UNIT, 1));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int16_t>(verifier, VT_UNIT) &&
+ verifier.EndTable();
+ }
+};
+
+struct DateBuilder {
+ typedef Date Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_unit(org::apache::arrow::flatbuf::DateUnit unit) {
+ fbb_.AddElement<int16_t>(Date::VT_UNIT, static_cast<int16_t>(unit), 1);
+ }
+ explicit DateBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ DateBuilder &operator=(const DateBuilder &);
+ flatbuffers::Offset<Date> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Date>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Date> CreateDate(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::DateUnit unit = org::apache::arrow::flatbuf::DateUnit::MILLISECOND) {
+ DateBuilder builder_(_fbb);
+ builder_.add_unit(unit);
+ return builder_.Finish();
+}
+
+/// Time is either a 32-bit or 64-bit signed integer type representing an
+/// elapsed time since midnight, stored in either of four units: seconds,
+/// milliseconds, microseconds or nanoseconds.
+///
+/// The integer `bitWidth` depends on the `unit` and must be one of the following:
+/// * SECOND and MILLISECOND: 32 bits
+/// * MICROSECOND and NANOSECOND: 64 bits
+///
+/// The allowed values are between 0 (inclusive) and 86400 (=24*60*60) seconds
+/// (exclusive), adjusted for the time unit (for example, up to 86400000
+/// exclusive for the MILLISECOND unit).
+/// This definition doesn't allow for leap seconds. Time values from
+/// measurements with leap seconds will need to be corrected when ingesting
+/// into Arrow (for example by replacing the value 86400 with 86399).
+struct Time FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef TimeBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_UNIT = 4,
+ VT_BITWIDTH = 6
+ };
+ org::apache::arrow::flatbuf::TimeUnit unit() const {
+ return static_cast<org::apache::arrow::flatbuf::TimeUnit>(GetField<int16_t>(VT_UNIT, 1));
+ }
+ int32_t bitWidth() const {
+ return GetField<int32_t>(VT_BITWIDTH, 32);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int16_t>(verifier, VT_UNIT) &&
+ VerifyField<int32_t>(verifier, VT_BITWIDTH) &&
+ verifier.EndTable();
+ }
+};
+
+struct TimeBuilder {
+ typedef Time Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_unit(org::apache::arrow::flatbuf::TimeUnit unit) {
+ fbb_.AddElement<int16_t>(Time::VT_UNIT, static_cast<int16_t>(unit), 1);
+ }
+ void add_bitWidth(int32_t bitWidth) {
+ fbb_.AddElement<int32_t>(Time::VT_BITWIDTH, bitWidth, 32);
+ }
+ explicit TimeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TimeBuilder &operator=(const TimeBuilder &);
+ flatbuffers::Offset<Time> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Time>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Time> CreateTime(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::TimeUnit unit = org::apache::arrow::flatbuf::TimeUnit::MILLISECOND,
+ int32_t bitWidth = 32) {
+ TimeBuilder builder_(_fbb);
+ builder_.add_bitWidth(bitWidth);
+ builder_.add_unit(unit);
+ return builder_.Finish();
+}
+
+/// Timestamp is a 64-bit signed integer representing an elapsed time since a
+/// fixed epoch, stored in either of four units: seconds, milliseconds,
+/// microseconds or nanoseconds, and is optionally annotated with a timezone.
+///
+/// Timestamp values do not include any leap seconds (in other words, all
+/// days are considered 86400 seconds long).
+///
+/// Timestamps with a non-empty timezone
+/// ------------------------------------
+///
+/// If a Timestamp column has a non-empty timezone value, its epoch is
+/// 1970-01-01 00:00:00 (January 1st 1970, midnight) in the *UTC* timezone
+/// (the Unix epoch), regardless of the Timestamp's own timezone.
+///
+/// Therefore, timestamp values with a non-empty timezone correspond to
+/// physical points in time together with some additional information about
+/// how the data was obtained and/or how to display it (the timezone).
+///
+/// For example, the timestamp value 0 with the timezone string "Europe/Paris"
+/// corresponds to "January 1st 1970, 00h00" in the UTC timezone, but the
+/// application may prefer to display it as "January 1st 1970, 01h00" in
+/// the Europe/Paris timezone (which is the same physical point in time).
+///
+/// One consequence is that timestamp values with a non-empty timezone
+/// can be compared and ordered directly, since they all share the same
+/// well-known point of reference (the Unix epoch).
+///
+/// Timestamps with an unset / empty timezone
+/// -----------------------------------------
+///
+/// If a Timestamp column has no timezone value, its epoch is
+/// 1970-01-01 00:00:00 (January 1st 1970, midnight) in an *unknown* timezone.
+///
+/// Therefore, timestamp values without a timezone cannot be meaningfully
+/// interpreted as physical points in time, but only as calendar / clock
+/// indications ("wall clock time") in an unspecified timezone.
+///
+/// For example, the timestamp value 0 with an empty timezone string
+/// corresponds to "January 1st 1970, 00h00" in an unknown timezone: there
+/// is not enough information to interpret it as a well-defined physical
+/// point in time.
+///
+/// One consequence is that timestamp values without a timezone cannot
+/// be reliably compared or ordered, since they may have different points of
+/// reference. In particular, it is *not* possible to interpret an unset
+/// or empty timezone as the same as "UTC".
+///
+/// Conversion between timezones
+/// ----------------------------
+///
+/// If a Timestamp column has a non-empty timezone, changing the timezone
+/// to a different non-empty value is a metadata-only operation:
+/// the timestamp values need not change as their point of reference remains
+/// the same (the Unix epoch).
+///
+/// However, if a Timestamp column has no timezone value, changing it to a
+/// non-empty value requires to think about the desired semantics.
+/// One possibility is to assume that the original timestamp values are
+/// relative to the epoch of the timezone being set; timestamp values should
+/// then adjusted to the Unix epoch (for example, changing the timezone from
+/// empty to "Europe/Paris" would require converting the timestamp values
+/// from "Europe/Paris" to "UTC", which seems counter-intuitive but is
+/// nevertheless correct).
+///
+/// Guidelines for encoding data from external libraries
+/// ----------------------------------------------------
+///
+/// Date & time libraries often have multiple different data types for temporal
+/// data. In order to ease interoperability between different implementations the
+/// Arrow project has some recommendations for encoding these types into a Timestamp
+/// column.
+///
+/// An "instant" represents a physical point in time that has no relevant timezone
+/// (for example, astronomical data). To encode an instant, use a Timestamp with
+/// the timezone string set to "UTC", and make sure the Timestamp values
+/// are relative to the UTC epoch (January 1st 1970, midnight).
+///
+/// A "zoned date-time" represents a physical point in time annotated with an
+/// informative timezone (for example, the timezone in which the data was
+/// recorded). To encode a zoned date-time, use a Timestamp with the timezone
+/// string set to the name of the timezone, and make sure the Timestamp values
+/// are relative to the UTC epoch (January 1st 1970, midnight).
+///
+/// (There is some ambiguity between an instant and a zoned date-time with the
+/// UTC timezone. Both of these are stored the same in Arrow. Typically,
+/// this distinction does not matter. If it does, then an application should
+/// use custom metadata or an extension type to distinguish between the two cases.)
+///
+/// An "offset date-time" represents a physical point in time combined with an
+/// explicit offset from UTC. To encode an offset date-time, use a Timestamp
+/// with the timezone string set to the numeric timezone offset string
+/// (e.g. "+03:00"), and make sure the Timestamp values are relative to
+/// the UTC epoch (January 1st 1970, midnight).
+///
+/// A "naive date-time" (also called "local date-time" in some libraries)
+/// represents a wall clock time combined with a calendar date, but with
+/// no indication of how to map this information to a physical point in time.
+/// Naive date-times must be handled with care because of this missing
+/// information, and also because daylight saving time (DST) may make
+/// some values ambiguous or non-existent. A naive date-time may be
+/// stored as a struct with Date and Time fields. However, it may also be
+/// encoded into a Timestamp column with an empty timezone. The timestamp
+/// values should be computed "as if" the timezone of the date-time values
+/// was UTC; for example, the naive date-time "January 1st 1970, 00h00" would
+/// be encoded as timestamp value 0.
+struct Timestamp FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef TimestampBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_UNIT = 4,
+ VT_TIMEZONE = 6
+ };
+ org::apache::arrow::flatbuf::TimeUnit unit() const {
+ return static_cast<org::apache::arrow::flatbuf::TimeUnit>(GetField<int16_t>(VT_UNIT, 0));
+ }
+ /// The timezone is an optional string indicating the name of a timezone,
+ /// one of:
+ ///
+ /// * As used in the Olson timezone database (the "tz database" or
+ /// "tzdata"), such as "America/New_York".
+ /// * An absolute timezone offset of the form "+XX:XX" or "-XX:XX",
+ /// such as "+07:30".
+ ///
+ /// Whether a timezone string is present indicates different semantics about
+ /// the data (see above).
+ const flatbuffers::String *timezone() const {
+ return GetPointer<const flatbuffers::String *>(VT_TIMEZONE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int16_t>(verifier, VT_UNIT) &&
+ VerifyOffset(verifier, VT_TIMEZONE) &&
+ verifier.VerifyString(timezone()) &&
+ verifier.EndTable();
+ }
+};
+
+struct TimestampBuilder {
+ typedef Timestamp Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_unit(org::apache::arrow::flatbuf::TimeUnit unit) {
+ fbb_.AddElement<int16_t>(Timestamp::VT_UNIT, static_cast<int16_t>(unit), 0);
+ }
+ void add_timezone(flatbuffers::Offset<flatbuffers::String> timezone) {
+ fbb_.AddOffset(Timestamp::VT_TIMEZONE, timezone);
+ }
+ explicit TimestampBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TimestampBuilder &operator=(const TimestampBuilder &);
+ flatbuffers::Offset<Timestamp> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Timestamp>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Timestamp> CreateTimestamp(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::TimeUnit unit = org::apache::arrow::flatbuf::TimeUnit::SECOND,
+ flatbuffers::Offset<flatbuffers::String> timezone = 0) {
+ TimestampBuilder builder_(_fbb);
+ builder_.add_timezone(timezone);
+ builder_.add_unit(unit);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Timestamp> CreateTimestampDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::TimeUnit unit = org::apache::arrow::flatbuf::TimeUnit::SECOND,
+ const char *timezone = nullptr) {
+ auto timezone__ = timezone ? _fbb.CreateString(timezone) : 0;
+ return org::apache::arrow::flatbuf::CreateTimestamp(
+ _fbb,
+ unit,
+ timezone__);
+}
+
+struct Interval FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef IntervalBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_UNIT = 4
+ };
+ org::apache::arrow::flatbuf::IntervalUnit unit() const {
+ return static_cast<org::apache::arrow::flatbuf::IntervalUnit>(GetField<int16_t>(VT_UNIT, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int16_t>(verifier, VT_UNIT) &&
+ verifier.EndTable();
+ }
+};
+
+struct IntervalBuilder {
+ typedef Interval Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_unit(org::apache::arrow::flatbuf::IntervalUnit unit) {
+ fbb_.AddElement<int16_t>(Interval::VT_UNIT, static_cast<int16_t>(unit), 0);
+ }
+ explicit IntervalBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ IntervalBuilder &operator=(const IntervalBuilder &);
+ flatbuffers::Offset<Interval> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Interval>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Interval> CreateInterval(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::IntervalUnit unit = org::apache::arrow::flatbuf::IntervalUnit::YEAR_MONTH) {
+ IntervalBuilder builder_(_fbb);
+ builder_.add_unit(unit);
+ return builder_.Finish();
+}
+
+struct Duration FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef DurationBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_UNIT = 4
+ };
+ org::apache::arrow::flatbuf::TimeUnit unit() const {
+ return static_cast<org::apache::arrow::flatbuf::TimeUnit>(GetField<int16_t>(VT_UNIT, 1));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int16_t>(verifier, VT_UNIT) &&
+ verifier.EndTable();
+ }
+};
+
+struct DurationBuilder {
+ typedef Duration Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_unit(org::apache::arrow::flatbuf::TimeUnit unit) {
+ fbb_.AddElement<int16_t>(Duration::VT_UNIT, static_cast<int16_t>(unit), 1);
+ }
+ explicit DurationBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ DurationBuilder &operator=(const DurationBuilder &);
+ flatbuffers::Offset<Duration> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Duration>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Duration> CreateDuration(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::TimeUnit unit = org::apache::arrow::flatbuf::TimeUnit::MILLISECOND) {
+ DurationBuilder builder_(_fbb);
+ builder_.add_unit(unit);
+ return builder_.Finish();
+}
+
+/// ----------------------------------------------------------------------
+/// user defined key value pairs to add custom metadata to arrow
+/// key namespacing is the responsibility of the user
+struct KeyValue FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef KeyValueBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_KEY = 4,
+ VT_VALUE = 6
+ };
+ const flatbuffers::String *key() const {
+ return GetPointer<const flatbuffers::String *>(VT_KEY);
+ }
+ const flatbuffers::String *value() const {
+ return GetPointer<const flatbuffers::String *>(VT_VALUE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_KEY) &&
+ verifier.VerifyString(key()) &&
+ VerifyOffset(verifier, VT_VALUE) &&
+ verifier.VerifyString(value()) &&
+ verifier.EndTable();
+ }
+};
+
+struct KeyValueBuilder {
+ typedef KeyValue Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_key(flatbuffers::Offset<flatbuffers::String> key) {
+ fbb_.AddOffset(KeyValue::VT_KEY, key);
+ }
+ void add_value(flatbuffers::Offset<flatbuffers::String> value) {
+ fbb_.AddOffset(KeyValue::VT_VALUE, value);
+ }
+ explicit KeyValueBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ KeyValueBuilder &operator=(const KeyValueBuilder &);
+ flatbuffers::Offset<KeyValue> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<KeyValue>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<KeyValue> CreateKeyValue(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> key = 0,
+ flatbuffers::Offset<flatbuffers::String> value = 0) {
+ KeyValueBuilder builder_(_fbb);
+ builder_.add_value(value);
+ builder_.add_key(key);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<KeyValue> CreateKeyValueDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *key = nullptr,
+ const char *value = nullptr) {
+ auto key__ = key ? _fbb.CreateString(key) : 0;
+ auto value__ = value ? _fbb.CreateString(value) : 0;
+ return org::apache::arrow::flatbuf::CreateKeyValue(
+ _fbb,
+ key__,
+ value__);
+}
+
+struct DictionaryEncoding FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef DictionaryEncodingBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_ID = 4,
+ VT_INDEXTYPE = 6,
+ VT_ISORDERED = 8,
+ VT_DICTIONARYKIND = 10
+ };
+ /// The known dictionary id in the application where this data is used. In
+ /// the file or streaming formats, the dictionary ids are found in the
+ /// DictionaryBatch messages
+ int64_t id() const {
+ return GetField<int64_t>(VT_ID, 0);
+ }
+ /// The dictionary indices are constrained to be non-negative integers. If
+ /// this field is null, the indices must be signed int32. To maximize
+ /// cross-language compatibility and performance, implementations are
+ /// recommended to prefer signed integer types over unsigned integer types
+ /// and to avoid uint64 indices unless they are required by an application.
+ const org::apache::arrow::flatbuf::Int *indexType() const {
+ return GetPointer<const org::apache::arrow::flatbuf::Int *>(VT_INDEXTYPE);
+ }
+ /// By default, dictionaries are not ordered, or the order does not have
+ /// semantic meaning. In some statistical, applications, dictionary-encoding
+ /// is used to represent ordered categorical data, and we provide a way to
+ /// preserve that metadata here
+ bool isOrdered() const {
+ return GetField<uint8_t>(VT_ISORDERED, 0) != 0;
+ }
+ org::apache::arrow::flatbuf::DictionaryKind dictionaryKind() const {
+ return static_cast<org::apache::arrow::flatbuf::DictionaryKind>(GetField<int16_t>(VT_DICTIONARYKIND, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_ID) &&
+ VerifyOffset(verifier, VT_INDEXTYPE) &&
+ verifier.VerifyTable(indexType()) &&
+ VerifyField<uint8_t>(verifier, VT_ISORDERED) &&
+ VerifyField<int16_t>(verifier, VT_DICTIONARYKIND) &&
+ verifier.EndTable();
+ }
+};
+
+struct DictionaryEncodingBuilder {
+ typedef DictionaryEncoding Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_id(int64_t id) {
+ fbb_.AddElement<int64_t>(DictionaryEncoding::VT_ID, id, 0);
+ }
+ void add_indexType(flatbuffers::Offset<org::apache::arrow::flatbuf::Int> indexType) {
+ fbb_.AddOffset(DictionaryEncoding::VT_INDEXTYPE, indexType);
+ }
+ void add_isOrdered(bool isOrdered) {
+ fbb_.AddElement<uint8_t>(DictionaryEncoding::VT_ISORDERED, static_cast<uint8_t>(isOrdered), 0);
+ }
+ void add_dictionaryKind(org::apache::arrow::flatbuf::DictionaryKind dictionaryKind) {
+ fbb_.AddElement<int16_t>(DictionaryEncoding::VT_DICTIONARYKIND, static_cast<int16_t>(dictionaryKind), 0);
+ }
+ explicit DictionaryEncodingBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ DictionaryEncodingBuilder &operator=(const DictionaryEncodingBuilder &);
+ flatbuffers::Offset<DictionaryEncoding> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<DictionaryEncoding>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<DictionaryEncoding> CreateDictionaryEncoding(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t id = 0,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Int> indexType = 0,
+ bool isOrdered = false,
+ org::apache::arrow::flatbuf::DictionaryKind dictionaryKind = org::apache::arrow::flatbuf::DictionaryKind::DenseArray) {
+ DictionaryEncodingBuilder builder_(_fbb);
+ builder_.add_id(id);
+ builder_.add_indexType(indexType);
+ builder_.add_dictionaryKind(dictionaryKind);
+ builder_.add_isOrdered(isOrdered);
+ return builder_.Finish();
+}
+
+/// ----------------------------------------------------------------------
+/// A field represents a named column in a record / row batch or child of a
+/// nested type.
+struct Field FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FieldBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_NAME = 4,
+ VT_NULLABLE = 6,
+ VT_TYPE_TYPE = 8,
+ VT_TYPE = 10,
+ VT_DICTIONARY = 12,
+ VT_CHILDREN = 14,
+ VT_CUSTOM_METADATA = 16
+ };
+ /// Name is not required, in i.e. a List
+ const flatbuffers::String *name() const {
+ return GetPointer<const flatbuffers::String *>(VT_NAME);
+ }
+ /// Whether or not this field can contain nulls. Should be true in general.
+ bool nullable() const {
+ return GetField<uint8_t>(VT_NULLABLE, 0) != 0;
+ }
+ org::apache::arrow::flatbuf::Type type_type() const {
+ return static_cast<org::apache::arrow::flatbuf::Type>(GetField<uint8_t>(VT_TYPE_TYPE, 0));
+ }
+ /// This is the type of the decoded value if the field is dictionary encoded.
+ const void *type() const {
+ return GetPointer<const void *>(VT_TYPE);
+ }
+ template<typename T> const T *type_as() const;
+ const org::apache::arrow::flatbuf::Null *type_as_Null() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Null ? static_cast<const org::apache::arrow::flatbuf::Null *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Int *type_as_Int() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Int ? static_cast<const org::apache::arrow::flatbuf::Int *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::FloatingPoint *type_as_FloatingPoint() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::FloatingPoint ? static_cast<const org::apache::arrow::flatbuf::FloatingPoint *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Binary *type_as_Binary() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Binary ? static_cast<const org::apache::arrow::flatbuf::Binary *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Utf8 *type_as_Utf8() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Utf8 ? static_cast<const org::apache::arrow::flatbuf::Utf8 *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Bool *type_as_Bool() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Bool ? static_cast<const org::apache::arrow::flatbuf::Bool *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Decimal *type_as_Decimal() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Decimal ? static_cast<const org::apache::arrow::flatbuf::Decimal *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Date *type_as_Date() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Date ? static_cast<const org::apache::arrow::flatbuf::Date *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Time *type_as_Time() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Time ? static_cast<const org::apache::arrow::flatbuf::Time *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Timestamp *type_as_Timestamp() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Timestamp ? static_cast<const org::apache::arrow::flatbuf::Timestamp *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Interval *type_as_Interval() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Interval ? static_cast<const org::apache::arrow::flatbuf::Interval *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::List *type_as_List() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::List ? static_cast<const org::apache::arrow::flatbuf::List *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Struct_ *type_as_Struct_() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Struct_ ? static_cast<const org::apache::arrow::flatbuf::Struct_ *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Union *type_as_Union() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Union ? static_cast<const org::apache::arrow::flatbuf::Union *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::FixedSizeBinary *type_as_FixedSizeBinary() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::FixedSizeBinary ? static_cast<const org::apache::arrow::flatbuf::FixedSizeBinary *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::FixedSizeList *type_as_FixedSizeList() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::FixedSizeList ? static_cast<const org::apache::arrow::flatbuf::FixedSizeList *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Map *type_as_Map() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Map ? static_cast<const org::apache::arrow::flatbuf::Map *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Duration *type_as_Duration() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Duration ? static_cast<const org::apache::arrow::flatbuf::Duration *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::LargeBinary *type_as_LargeBinary() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::LargeBinary ? static_cast<const org::apache::arrow::flatbuf::LargeBinary *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::LargeUtf8 *type_as_LargeUtf8() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::LargeUtf8 ? static_cast<const org::apache::arrow::flatbuf::LargeUtf8 *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::LargeList *type_as_LargeList() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::LargeList ? static_cast<const org::apache::arrow::flatbuf::LargeList *>(type()) : nullptr;
+ }
+ /// Present only if the field is dictionary encoded.
+ const org::apache::arrow::flatbuf::DictionaryEncoding *dictionary() const {
+ return GetPointer<const org::apache::arrow::flatbuf::DictionaryEncoding *>(VT_DICTIONARY);
+ }
+ /// children apply only to nested data types like Struct, List and Union. For
+ /// primitive types children will have length 0.
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>> *children() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>> *>(VT_CHILDREN);
+ }
+ /// User-defined metadata
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>> *custom_metadata() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>> *>(VT_CUSTOM_METADATA);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_NAME) &&
+ verifier.VerifyString(name()) &&
+ VerifyField<uint8_t>(verifier, VT_NULLABLE) &&
+ VerifyField<uint8_t>(verifier, VT_TYPE_TYPE) &&
+ VerifyOffset(verifier, VT_TYPE) &&
+ VerifyType(verifier, type(), type_type()) &&
+ VerifyOffset(verifier, VT_DICTIONARY) &&
+ verifier.VerifyTable(dictionary()) &&
+ VerifyOffset(verifier, VT_CHILDREN) &&
+ verifier.VerifyVector(children()) &&
+ verifier.VerifyVectorOfTables(children()) &&
+ VerifyOffset(verifier, VT_CUSTOM_METADATA) &&
+ verifier.VerifyVector(custom_metadata()) &&
+ verifier.VerifyVectorOfTables(custom_metadata()) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const org::apache::arrow::flatbuf::Null *Field::type_as<org::apache::arrow::flatbuf::Null>() const {
+ return type_as_Null();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Int *Field::type_as<org::apache::arrow::flatbuf::Int>() const {
+ return type_as_Int();
+}
+
+template<> inline const org::apache::arrow::flatbuf::FloatingPoint *Field::type_as<org::apache::arrow::flatbuf::FloatingPoint>() const {
+ return type_as_FloatingPoint();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Binary *Field::type_as<org::apache::arrow::flatbuf::Binary>() const {
+ return type_as_Binary();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Utf8 *Field::type_as<org::apache::arrow::flatbuf::Utf8>() const {
+ return type_as_Utf8();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Bool *Field::type_as<org::apache::arrow::flatbuf::Bool>() const {
+ return type_as_Bool();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Decimal *Field::type_as<org::apache::arrow::flatbuf::Decimal>() const {
+ return type_as_Decimal();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Date *Field::type_as<org::apache::arrow::flatbuf::Date>() const {
+ return type_as_Date();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Time *Field::type_as<org::apache::arrow::flatbuf::Time>() const {
+ return type_as_Time();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Timestamp *Field::type_as<org::apache::arrow::flatbuf::Timestamp>() const {
+ return type_as_Timestamp();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Interval *Field::type_as<org::apache::arrow::flatbuf::Interval>() const {
+ return type_as_Interval();
+}
+
+template<> inline const org::apache::arrow::flatbuf::List *Field::type_as<org::apache::arrow::flatbuf::List>() const {
+ return type_as_List();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Struct_ *Field::type_as<org::apache::arrow::flatbuf::Struct_>() const {
+ return type_as_Struct_();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Union *Field::type_as<org::apache::arrow::flatbuf::Union>() const {
+ return type_as_Union();
+}
+
+template<> inline const org::apache::arrow::flatbuf::FixedSizeBinary *Field::type_as<org::apache::arrow::flatbuf::FixedSizeBinary>() const {
+ return type_as_FixedSizeBinary();
+}
+
+template<> inline const org::apache::arrow::flatbuf::FixedSizeList *Field::type_as<org::apache::arrow::flatbuf::FixedSizeList>() const {
+ return type_as_FixedSizeList();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Map *Field::type_as<org::apache::arrow::flatbuf::Map>() const {
+ return type_as_Map();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Duration *Field::type_as<org::apache::arrow::flatbuf::Duration>() const {
+ return type_as_Duration();
+}
+
+template<> inline const org::apache::arrow::flatbuf::LargeBinary *Field::type_as<org::apache::arrow::flatbuf::LargeBinary>() const {
+ return type_as_LargeBinary();
+}
+
+template<> inline const org::apache::arrow::flatbuf::LargeUtf8 *Field::type_as<org::apache::arrow::flatbuf::LargeUtf8>() const {
+ return type_as_LargeUtf8();
+}
+
+template<> inline const org::apache::arrow::flatbuf::LargeList *Field::type_as<org::apache::arrow::flatbuf::LargeList>() const {
+ return type_as_LargeList();
+}
+
+struct FieldBuilder {
+ typedef Field Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_name(flatbuffers::Offset<flatbuffers::String> name) {
+ fbb_.AddOffset(Field::VT_NAME, name);
+ }
+ void add_nullable(bool nullable) {
+ fbb_.AddElement<uint8_t>(Field::VT_NULLABLE, static_cast<uint8_t>(nullable), 0);
+ }
+ void add_type_type(org::apache::arrow::flatbuf::Type type_type) {
+ fbb_.AddElement<uint8_t>(Field::VT_TYPE_TYPE, static_cast<uint8_t>(type_type), 0);
+ }
+ void add_type(flatbuffers::Offset<void> type) {
+ fbb_.AddOffset(Field::VT_TYPE, type);
+ }
+ void add_dictionary(flatbuffers::Offset<org::apache::arrow::flatbuf::DictionaryEncoding> dictionary) {
+ fbb_.AddOffset(Field::VT_DICTIONARY, dictionary);
+ }
+ void add_children(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>>> children) {
+ fbb_.AddOffset(Field::VT_CHILDREN, children);
+ }
+ void add_custom_metadata(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>>> custom_metadata) {
+ fbb_.AddOffset(Field::VT_CUSTOM_METADATA, custom_metadata);
+ }
+ explicit FieldBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FieldBuilder &operator=(const FieldBuilder &);
+ flatbuffers::Offset<Field> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Field>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Field> CreateField(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> name = 0,
+ bool nullable = false,
+ org::apache::arrow::flatbuf::Type type_type = org::apache::arrow::flatbuf::Type::NONE,
+ flatbuffers::Offset<void> type = 0,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::DictionaryEncoding> dictionary = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>>> children = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>>> custom_metadata = 0) {
+ FieldBuilder builder_(_fbb);
+ builder_.add_custom_metadata(custom_metadata);
+ builder_.add_children(children);
+ builder_.add_dictionary(dictionary);
+ builder_.add_type(type);
+ builder_.add_name(name);
+ builder_.add_type_type(type_type);
+ builder_.add_nullable(nullable);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Field> CreateFieldDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *name = nullptr,
+ bool nullable = false,
+ org::apache::arrow::flatbuf::Type type_type = org::apache::arrow::flatbuf::Type::NONE,
+ flatbuffers::Offset<void> type = 0,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::DictionaryEncoding> dictionary = 0,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>> *children = nullptr,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>> *custom_metadata = nullptr) {
+ auto name__ = name ? _fbb.CreateString(name) : 0;
+ auto children__ = children ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>>(*children) : 0;
+ auto custom_metadata__ = custom_metadata ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>>(*custom_metadata) : 0;
+ return org::apache::arrow::flatbuf::CreateField(
+ _fbb,
+ name__,
+ nullable,
+ type_type,
+ type,
+ dictionary,
+ children__,
+ custom_metadata__);
+}
+
+/// ----------------------------------------------------------------------
+/// A Schema describes the columns in a row batch
+struct Schema FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SchemaBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_ENDIANNESS = 4,
+ VT_FIELDS = 6,
+ VT_CUSTOM_METADATA = 8,
+ VT_FEATURES = 10
+ };
+ /// endianness of the buffer
+ /// it is Little Endian by default
+ /// if endianness doesn't match the underlying system then the vectors need to be converted
+ org::apache::arrow::flatbuf::Endianness endianness() const {
+ return static_cast<org::apache::arrow::flatbuf::Endianness>(GetField<int16_t>(VT_ENDIANNESS, 0));
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>> *fields() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>> *>(VT_FIELDS);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>> *custom_metadata() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>> *>(VT_CUSTOM_METADATA);
+ }
+ /// Features used in the stream/file.
+ const flatbuffers::Vector<int64_t> *features() const {
+ return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_FEATURES);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int16_t>(verifier, VT_ENDIANNESS) &&
+ VerifyOffset(verifier, VT_FIELDS) &&
+ verifier.VerifyVector(fields()) &&
+ verifier.VerifyVectorOfTables(fields()) &&
+ VerifyOffset(verifier, VT_CUSTOM_METADATA) &&
+ verifier.VerifyVector(custom_metadata()) &&
+ verifier.VerifyVectorOfTables(custom_metadata()) &&
+ VerifyOffset(verifier, VT_FEATURES) &&
+ verifier.VerifyVector(features()) &&
+ verifier.EndTable();
+ }
+};
+
+struct SchemaBuilder {
+ typedef Schema Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_endianness(org::apache::arrow::flatbuf::Endianness endianness) {
+ fbb_.AddElement<int16_t>(Schema::VT_ENDIANNESS, static_cast<int16_t>(endianness), 0);
+ }
+ void add_fields(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>>> fields) {
+ fbb_.AddOffset(Schema::VT_FIELDS, fields);
+ }
+ void add_custom_metadata(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>>> custom_metadata) {
+ fbb_.AddOffset(Schema::VT_CUSTOM_METADATA, custom_metadata);
+ }
+ void add_features(flatbuffers::Offset<flatbuffers::Vector<int64_t>> features) {
+ fbb_.AddOffset(Schema::VT_FEATURES, features);
+ }
+ explicit SchemaBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SchemaBuilder &operator=(const SchemaBuilder &);
+ flatbuffers::Offset<Schema> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Schema>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Schema> CreateSchema(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::Endianness endianness = org::apache::arrow::flatbuf::Endianness::Little,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>>> fields = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>>> custom_metadata = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int64_t>> features = 0) {
+ SchemaBuilder builder_(_fbb);
+ builder_.add_features(features);
+ builder_.add_custom_metadata(custom_metadata);
+ builder_.add_fields(fields);
+ builder_.add_endianness(endianness);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Schema> CreateSchemaDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::Endianness endianness = org::apache::arrow::flatbuf::Endianness::Little,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>> *fields = nullptr,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>> *custom_metadata = nullptr,
+ const std::vector<int64_t> *features = nullptr) {
+ auto fields__ = fields ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>>(*fields) : 0;
+ auto custom_metadata__ = custom_metadata ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::flatbuf::KeyValue>>(*custom_metadata) : 0;
+ auto features__ = features ? _fbb.CreateVector<int64_t>(*features) : 0;
+ return org::apache::arrow::flatbuf::CreateSchema(
+ _fbb,
+ endianness,
+ fields__,
+ custom_metadata__,
+ features__);
+}
+
+inline bool VerifyType(flatbuffers::Verifier &verifier, const void *obj, Type type) {
+ switch (type) {
+ case Type::NONE: {
+ return true;
+ }
+ case Type::Null: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Null *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::Int: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Int *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::FloatingPoint: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::FloatingPoint *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::Binary: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Binary *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::Utf8: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Utf8 *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::Bool: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Bool *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::Decimal: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Decimal *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::Date: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Date *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::Time: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Time *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::Timestamp: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Timestamp *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::Interval: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Interval *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::List: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::List *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::Struct_: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Struct_ *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::Union: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Union *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::FixedSizeBinary: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::FixedSizeBinary *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::FixedSizeList: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::FixedSizeList *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::Map: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Map *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::Duration: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::Duration *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::LargeBinary: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::LargeBinary *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::LargeUtf8: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::LargeUtf8 *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Type::LargeList: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::LargeList *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return true;
+ }
+}
+
+inline bool VerifyTypeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifyType(
+ verifier, values->Get(i), types->GetEnum<Type>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline const org::apache::arrow::flatbuf::Schema *GetSchema(const void *buf) {
+ return flatbuffers::GetRoot<org::apache::arrow::flatbuf::Schema>(buf);
+}
+
+inline const org::apache::arrow::flatbuf::Schema *GetSizePrefixedSchema(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<org::apache::arrow::flatbuf::Schema>(buf);
+}
+
+inline bool VerifySchemaBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifyBuffer<org::apache::arrow::flatbuf::Schema>(nullptr);
+}
+
+inline bool VerifySizePrefixedSchemaBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<org::apache::arrow::flatbuf::Schema>(nullptr);
+}
+
+inline void FinishSchemaBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Schema> root) {
+ fbb.Finish(root);
+}
+
+inline void FinishSizePrefixedSchemaBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Schema> root) {
+ fbb.FinishSizePrefixed(root);
+}
+
+} // namespace flatbuf
+} // namespace arrow
+} // namespace apache
+} // namespace org
+
+#endif // FLATBUFFERS_GENERATED_SCHEMA_ORG_APACHE_ARROW_FLATBUF_H_
diff --git a/src/arrow/cpp/src/generated/SparseTensor_generated.h b/src/arrow/cpp/src/generated/SparseTensor_generated.h
new file mode 100644
index 000000000..a66269182
--- /dev/null
+++ b/src/arrow/cpp/src/generated/SparseTensor_generated.h
@@ -0,0 +1,921 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_SPARSETENSOR_ORG_APACHE_ARROW_FLATBUF_H_
+#define FLATBUFFERS_GENERATED_SPARSETENSOR_ORG_APACHE_ARROW_FLATBUF_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+#include "Schema_generated.h"
+#include "Tensor_generated.h"
+
+namespace org {
+namespace apache {
+namespace arrow {
+namespace flatbuf {
+
+struct SparseTensorIndexCOO;
+struct SparseTensorIndexCOOBuilder;
+
+struct SparseMatrixIndexCSX;
+struct SparseMatrixIndexCSXBuilder;
+
+struct SparseTensorIndexCSF;
+struct SparseTensorIndexCSFBuilder;
+
+struct SparseTensor;
+struct SparseTensorBuilder;
+
+enum class SparseMatrixCompressedAxis : int16_t {
+ Row = 0,
+ Column = 1,
+ MIN = Row,
+ MAX = Column
+};
+
+inline const SparseMatrixCompressedAxis (&EnumValuesSparseMatrixCompressedAxis())[2] {
+ static const SparseMatrixCompressedAxis values[] = {
+ SparseMatrixCompressedAxis::Row,
+ SparseMatrixCompressedAxis::Column
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesSparseMatrixCompressedAxis() {
+ static const char * const names[3] = {
+ "Row",
+ "Column",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameSparseMatrixCompressedAxis(SparseMatrixCompressedAxis e) {
+ if (flatbuffers::IsOutRange(e, SparseMatrixCompressedAxis::Row, SparseMatrixCompressedAxis::Column)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesSparseMatrixCompressedAxis()[index];
+}
+
+enum class SparseTensorIndex : uint8_t {
+ NONE = 0,
+ SparseTensorIndexCOO = 1,
+ SparseMatrixIndexCSX = 2,
+ SparseTensorIndexCSF = 3,
+ MIN = NONE,
+ MAX = SparseTensorIndexCSF
+};
+
+inline const SparseTensorIndex (&EnumValuesSparseTensorIndex())[4] {
+ static const SparseTensorIndex values[] = {
+ SparseTensorIndex::NONE,
+ SparseTensorIndex::SparseTensorIndexCOO,
+ SparseTensorIndex::SparseMatrixIndexCSX,
+ SparseTensorIndex::SparseTensorIndexCSF
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesSparseTensorIndex() {
+ static const char * const names[5] = {
+ "NONE",
+ "SparseTensorIndexCOO",
+ "SparseMatrixIndexCSX",
+ "SparseTensorIndexCSF",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameSparseTensorIndex(SparseTensorIndex e) {
+ if (flatbuffers::IsOutRange(e, SparseTensorIndex::NONE, SparseTensorIndex::SparseTensorIndexCSF)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesSparseTensorIndex()[index];
+}
+
+template<typename T> struct SparseTensorIndexTraits {
+ static const SparseTensorIndex enum_value = SparseTensorIndex::NONE;
+};
+
+template<> struct SparseTensorIndexTraits<org::apache::arrow::flatbuf::SparseTensorIndexCOO> {
+ static const SparseTensorIndex enum_value = SparseTensorIndex::SparseTensorIndexCOO;
+};
+
+template<> struct SparseTensorIndexTraits<org::apache::arrow::flatbuf::SparseMatrixIndexCSX> {
+ static const SparseTensorIndex enum_value = SparseTensorIndex::SparseMatrixIndexCSX;
+};
+
+template<> struct SparseTensorIndexTraits<org::apache::arrow::flatbuf::SparseTensorIndexCSF> {
+ static const SparseTensorIndex enum_value = SparseTensorIndex::SparseTensorIndexCSF;
+};
+
+bool VerifySparseTensorIndex(flatbuffers::Verifier &verifier, const void *obj, SparseTensorIndex type);
+bool VerifySparseTensorIndexVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+/// ----------------------------------------------------------------------
+/// EXPERIMENTAL: Data structures for sparse tensors
+/// Coordinate (COO) format of sparse tensor index.
+///
+/// COO's index list are represented as a NxM matrix,
+/// where N is the number of non-zero values,
+/// and M is the number of dimensions of a sparse tensor.
+///
+/// indicesBuffer stores the location and size of the data of this indices
+/// matrix. The value type and the stride of the indices matrix is
+/// specified in indicesType and indicesStrides fields.
+///
+/// For example, let X be a 2x3x4x5 tensor, and it has the following
+/// 6 non-zero values:
+/// ```text
+/// X[0, 1, 2, 0] := 1
+/// X[1, 1, 2, 3] := 2
+/// X[0, 2, 1, 0] := 3
+/// X[0, 1, 3, 0] := 4
+/// X[0, 1, 2, 1] := 5
+/// X[1, 2, 0, 4] := 6
+/// ```
+/// In COO format, the index matrix of X is the following 4x6 matrix:
+/// ```text
+/// [[0, 0, 0, 0, 1, 1],
+/// [1, 1, 1, 2, 1, 2],
+/// [2, 2, 3, 1, 2, 0],
+/// [0, 1, 0, 0, 3, 4]]
+/// ```
+/// When isCanonical is true, the indices is sorted in lexicographical order
+/// (row-major order), and it does not have duplicated entries. Otherwise,
+/// the indices may not be sorted, or may have duplicated entries.
+struct SparseTensorIndexCOO FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SparseTensorIndexCOOBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_INDICESTYPE = 4,
+ VT_INDICESSTRIDES = 6,
+ VT_INDICESBUFFER = 8,
+ VT_ISCANONICAL = 10
+ };
+ /// The type of values in indicesBuffer
+ const org::apache::arrow::flatbuf::Int *indicesType() const {
+ return GetPointer<const org::apache::arrow::flatbuf::Int *>(VT_INDICESTYPE);
+ }
+ /// Non-negative byte offsets to advance one value cell along each dimension
+ /// If omitted, default to row-major order (C-like).
+ const flatbuffers::Vector<int64_t> *indicesStrides() const {
+ return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_INDICESSTRIDES);
+ }
+ /// The location and size of the indices matrix's data
+ const org::apache::arrow::flatbuf::Buffer *indicesBuffer() const {
+ return GetStruct<const org::apache::arrow::flatbuf::Buffer *>(VT_INDICESBUFFER);
+ }
+ /// This flag is true if and only if the indices matrix is sorted in
+ /// row-major order, and does not have duplicated entries.
+ /// This sort order is the same as of Tensorflow's SparseTensor,
+ /// but it is inverse order of SciPy's canonical coo_matrix
+ /// (SciPy employs column-major order for its coo_matrix).
+ bool isCanonical() const {
+ return GetField<uint8_t>(VT_ISCANONICAL, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_INDICESTYPE) &&
+ verifier.VerifyTable(indicesType()) &&
+ VerifyOffset(verifier, VT_INDICESSTRIDES) &&
+ verifier.VerifyVector(indicesStrides()) &&
+ VerifyFieldRequired<org::apache::arrow::flatbuf::Buffer>(verifier, VT_INDICESBUFFER) &&
+ VerifyField<uint8_t>(verifier, VT_ISCANONICAL) &&
+ verifier.EndTable();
+ }
+};
+
+struct SparseTensorIndexCOOBuilder {
+ typedef SparseTensorIndexCOO Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_indicesType(flatbuffers::Offset<org::apache::arrow::flatbuf::Int> indicesType) {
+ fbb_.AddOffset(SparseTensorIndexCOO::VT_INDICESTYPE, indicesType);
+ }
+ void add_indicesStrides(flatbuffers::Offset<flatbuffers::Vector<int64_t>> indicesStrides) {
+ fbb_.AddOffset(SparseTensorIndexCOO::VT_INDICESSTRIDES, indicesStrides);
+ }
+ void add_indicesBuffer(const org::apache::arrow::flatbuf::Buffer *indicesBuffer) {
+ fbb_.AddStruct(SparseTensorIndexCOO::VT_INDICESBUFFER, indicesBuffer);
+ }
+ void add_isCanonical(bool isCanonical) {
+ fbb_.AddElement<uint8_t>(SparseTensorIndexCOO::VT_ISCANONICAL, static_cast<uint8_t>(isCanonical), 0);
+ }
+ explicit SparseTensorIndexCOOBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SparseTensorIndexCOOBuilder &operator=(const SparseTensorIndexCOOBuilder &);
+ flatbuffers::Offset<SparseTensorIndexCOO> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SparseTensorIndexCOO>(end);
+ fbb_.Required(o, SparseTensorIndexCOO::VT_INDICESTYPE);
+ fbb_.Required(o, SparseTensorIndexCOO::VT_INDICESBUFFER);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SparseTensorIndexCOO> CreateSparseTensorIndexCOO(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Int> indicesType = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int64_t>> indicesStrides = 0,
+ const org::apache::arrow::flatbuf::Buffer *indicesBuffer = 0,
+ bool isCanonical = false) {
+ SparseTensorIndexCOOBuilder builder_(_fbb);
+ builder_.add_indicesBuffer(indicesBuffer);
+ builder_.add_indicesStrides(indicesStrides);
+ builder_.add_indicesType(indicesType);
+ builder_.add_isCanonical(isCanonical);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<SparseTensorIndexCOO> CreateSparseTensorIndexCOODirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Int> indicesType = 0,
+ const std::vector<int64_t> *indicesStrides = nullptr,
+ const org::apache::arrow::flatbuf::Buffer *indicesBuffer = 0,
+ bool isCanonical = false) {
+ auto indicesStrides__ = indicesStrides ? _fbb.CreateVector<int64_t>(*indicesStrides) : 0;
+ return org::apache::arrow::flatbuf::CreateSparseTensorIndexCOO(
+ _fbb,
+ indicesType,
+ indicesStrides__,
+ indicesBuffer,
+ isCanonical);
+}
+
+/// Compressed Sparse format, that is matrix-specific.
+struct SparseMatrixIndexCSX FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SparseMatrixIndexCSXBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_COMPRESSEDAXIS = 4,
+ VT_INDPTRTYPE = 6,
+ VT_INDPTRBUFFER = 8,
+ VT_INDICESTYPE = 10,
+ VT_INDICESBUFFER = 12
+ };
+ /// Which axis, row or column, is compressed
+ org::apache::arrow::flatbuf::SparseMatrixCompressedAxis compressedAxis() const {
+ return static_cast<org::apache::arrow::flatbuf::SparseMatrixCompressedAxis>(GetField<int16_t>(VT_COMPRESSEDAXIS, 0));
+ }
+ /// The type of values in indptrBuffer
+ const org::apache::arrow::flatbuf::Int *indptrType() const {
+ return GetPointer<const org::apache::arrow::flatbuf::Int *>(VT_INDPTRTYPE);
+ }
+ /// indptrBuffer stores the location and size of indptr array that
+ /// represents the range of the rows.
+ /// The i-th row spans from `indptr[i]` to `indptr[i+1]` in the data.
+ /// The length of this array is 1 + (the number of rows), and the type
+ /// of index value is long.
+ ///
+ /// For example, let X be the following 6x4 matrix:
+ /// ```text
+ /// X := [[0, 1, 2, 0],
+ /// [0, 0, 3, 0],
+ /// [0, 4, 0, 5],
+ /// [0, 0, 0, 0],
+ /// [6, 0, 7, 8],
+ /// [0, 9, 0, 0]].
+ /// ```
+ /// The array of non-zero values in X is:
+ /// ```text
+ /// values(X) = [1, 2, 3, 4, 5, 6, 7, 8, 9].
+ /// ```
+ /// And the indptr of X is:
+ /// ```text
+ /// indptr(X) = [0, 2, 3, 5, 5, 8, 10].
+ /// ```
+ const org::apache::arrow::flatbuf::Buffer *indptrBuffer() const {
+ return GetStruct<const org::apache::arrow::flatbuf::Buffer *>(VT_INDPTRBUFFER);
+ }
+ /// The type of values in indicesBuffer
+ const org::apache::arrow::flatbuf::Int *indicesType() const {
+ return GetPointer<const org::apache::arrow::flatbuf::Int *>(VT_INDICESTYPE);
+ }
+ /// indicesBuffer stores the location and size of the array that
+ /// contains the column indices of the corresponding non-zero values.
+ /// The type of index value is long.
+ ///
+ /// For example, the indices of the above X is:
+ /// ```text
+ /// indices(X) = [1, 2, 2, 1, 3, 0, 2, 3, 1].
+ /// ```
+ /// Note that the indices are sorted in lexicographical order for each row.
+ const org::apache::arrow::flatbuf::Buffer *indicesBuffer() const {
+ return GetStruct<const org::apache::arrow::flatbuf::Buffer *>(VT_INDICESBUFFER);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int16_t>(verifier, VT_COMPRESSEDAXIS) &&
+ VerifyOffsetRequired(verifier, VT_INDPTRTYPE) &&
+ verifier.VerifyTable(indptrType()) &&
+ VerifyFieldRequired<org::apache::arrow::flatbuf::Buffer>(verifier, VT_INDPTRBUFFER) &&
+ VerifyOffsetRequired(verifier, VT_INDICESTYPE) &&
+ verifier.VerifyTable(indicesType()) &&
+ VerifyFieldRequired<org::apache::arrow::flatbuf::Buffer>(verifier, VT_INDICESBUFFER) &&
+ verifier.EndTable();
+ }
+};
+
+struct SparseMatrixIndexCSXBuilder {
+ typedef SparseMatrixIndexCSX Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_compressedAxis(org::apache::arrow::flatbuf::SparseMatrixCompressedAxis compressedAxis) {
+ fbb_.AddElement<int16_t>(SparseMatrixIndexCSX::VT_COMPRESSEDAXIS, static_cast<int16_t>(compressedAxis), 0);
+ }
+ void add_indptrType(flatbuffers::Offset<org::apache::arrow::flatbuf::Int> indptrType) {
+ fbb_.AddOffset(SparseMatrixIndexCSX::VT_INDPTRTYPE, indptrType);
+ }
+ void add_indptrBuffer(const org::apache::arrow::flatbuf::Buffer *indptrBuffer) {
+ fbb_.AddStruct(SparseMatrixIndexCSX::VT_INDPTRBUFFER, indptrBuffer);
+ }
+ void add_indicesType(flatbuffers::Offset<org::apache::arrow::flatbuf::Int> indicesType) {
+ fbb_.AddOffset(SparseMatrixIndexCSX::VT_INDICESTYPE, indicesType);
+ }
+ void add_indicesBuffer(const org::apache::arrow::flatbuf::Buffer *indicesBuffer) {
+ fbb_.AddStruct(SparseMatrixIndexCSX::VT_INDICESBUFFER, indicesBuffer);
+ }
+ explicit SparseMatrixIndexCSXBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SparseMatrixIndexCSXBuilder &operator=(const SparseMatrixIndexCSXBuilder &);
+ flatbuffers::Offset<SparseMatrixIndexCSX> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SparseMatrixIndexCSX>(end);
+ fbb_.Required(o, SparseMatrixIndexCSX::VT_INDPTRTYPE);
+ fbb_.Required(o, SparseMatrixIndexCSX::VT_INDPTRBUFFER);
+ fbb_.Required(o, SparseMatrixIndexCSX::VT_INDICESTYPE);
+ fbb_.Required(o, SparseMatrixIndexCSX::VT_INDICESBUFFER);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SparseMatrixIndexCSX> CreateSparseMatrixIndexCSX(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::SparseMatrixCompressedAxis compressedAxis = org::apache::arrow::flatbuf::SparseMatrixCompressedAxis::Row,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Int> indptrType = 0,
+ const org::apache::arrow::flatbuf::Buffer *indptrBuffer = 0,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Int> indicesType = 0,
+ const org::apache::arrow::flatbuf::Buffer *indicesBuffer = 0) {
+ SparseMatrixIndexCSXBuilder builder_(_fbb);
+ builder_.add_indicesBuffer(indicesBuffer);
+ builder_.add_indicesType(indicesType);
+ builder_.add_indptrBuffer(indptrBuffer);
+ builder_.add_indptrType(indptrType);
+ builder_.add_compressedAxis(compressedAxis);
+ return builder_.Finish();
+}
+
+/// Compressed Sparse Fiber (CSF) sparse tensor index.
+struct SparseTensorIndexCSF FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SparseTensorIndexCSFBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_INDPTRTYPE = 4,
+ VT_INDPTRBUFFERS = 6,
+ VT_INDICESTYPE = 8,
+ VT_INDICESBUFFERS = 10,
+ VT_AXISORDER = 12
+ };
+ /// CSF is a generalization of compressed sparse row (CSR) index.
+ /// See [smith2017knl](http://shaden.io/pub-files/smith2017knl.pdf)
+ ///
+ /// CSF index recursively compresses each dimension of a tensor into a set
+ /// of prefix trees. Each path from a root to leaf forms one tensor
+ /// non-zero index. CSF is implemented with two arrays of buffers and one
+ /// arrays of integers.
+ ///
+ /// For example, let X be a 2x3x4x5 tensor and let it have the following
+ /// 8 non-zero values:
+ /// ```text
+ /// X[0, 0, 0, 1] := 1
+ /// X[0, 0, 0, 2] := 2
+ /// X[0, 1, 0, 0] := 3
+ /// X[0, 1, 0, 2] := 4
+ /// X[0, 1, 1, 0] := 5
+ /// X[1, 1, 1, 0] := 6
+ /// X[1, 1, 1, 1] := 7
+ /// X[1, 1, 1, 2] := 8
+ /// ```
+ /// As a prefix tree this would be represented as:
+ /// ```text
+ /// 0 1
+ /// / \ |
+ /// 0 1 1
+ /// / / \ |
+ /// 0 0 1 1
+ /// /| /| | /| |
+ /// 1 2 0 2 0 0 1 2
+ /// ```
+ /// The type of values in indptrBuffers
+ const org::apache::arrow::flatbuf::Int *indptrType() const {
+ return GetPointer<const org::apache::arrow::flatbuf::Int *>(VT_INDPTRTYPE);
+ }
+ /// indptrBuffers stores the sparsity structure.
+ /// Each two consecutive dimensions in a tensor correspond to a buffer in
+ /// indptrBuffers. A pair of consecutive values at `indptrBuffers[dim][i]`
+ /// and `indptrBuffers[dim][i + 1]` signify a range of nodes in
+ /// `indicesBuffers[dim + 1]` who are children of `indicesBuffers[dim][i]` node.
+ ///
+ /// For example, the indptrBuffers for the above X is:
+ /// ```text
+ /// indptrBuffer(X) = [
+ /// [0, 2, 3],
+ /// [0, 1, 3, 4],
+ /// [0, 2, 4, 5, 8]
+ /// ].
+ /// ```
+ const flatbuffers::Vector<const org::apache::arrow::flatbuf::Buffer *> *indptrBuffers() const {
+ return GetPointer<const flatbuffers::Vector<const org::apache::arrow::flatbuf::Buffer *> *>(VT_INDPTRBUFFERS);
+ }
+ /// The type of values in indicesBuffers
+ const org::apache::arrow::flatbuf::Int *indicesType() const {
+ return GetPointer<const org::apache::arrow::flatbuf::Int *>(VT_INDICESTYPE);
+ }
+ /// indicesBuffers stores values of nodes.
+ /// Each tensor dimension corresponds to a buffer in indicesBuffers.
+ /// For example, the indicesBuffers for the above X is:
+ /// ```text
+ /// indicesBuffer(X) = [
+ /// [0, 1],
+ /// [0, 1, 1],
+ /// [0, 0, 1, 1],
+ /// [1, 2, 0, 2, 0, 0, 1, 2]
+ /// ].
+ /// ```
+ const flatbuffers::Vector<const org::apache::arrow::flatbuf::Buffer *> *indicesBuffers() const {
+ return GetPointer<const flatbuffers::Vector<const org::apache::arrow::flatbuf::Buffer *> *>(VT_INDICESBUFFERS);
+ }
+ /// axisOrder stores the sequence in which dimensions were traversed to
+ /// produce the prefix tree.
+ /// For example, the axisOrder for the above X is:
+ /// ```text
+ /// axisOrder(X) = [0, 1, 2, 3].
+ /// ```
+ const flatbuffers::Vector<int32_t> *axisOrder() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_AXISORDER);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffsetRequired(verifier, VT_INDPTRTYPE) &&
+ verifier.VerifyTable(indptrType()) &&
+ VerifyOffsetRequired(verifier, VT_INDPTRBUFFERS) &&
+ verifier.VerifyVector(indptrBuffers()) &&
+ VerifyOffsetRequired(verifier, VT_INDICESTYPE) &&
+ verifier.VerifyTable(indicesType()) &&
+ VerifyOffsetRequired(verifier, VT_INDICESBUFFERS) &&
+ verifier.VerifyVector(indicesBuffers()) &&
+ VerifyOffsetRequired(verifier, VT_AXISORDER) &&
+ verifier.VerifyVector(axisOrder()) &&
+ verifier.EndTable();
+ }
+};
+
+struct SparseTensorIndexCSFBuilder {
+ typedef SparseTensorIndexCSF Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_indptrType(flatbuffers::Offset<org::apache::arrow::flatbuf::Int> indptrType) {
+ fbb_.AddOffset(SparseTensorIndexCSF::VT_INDPTRTYPE, indptrType);
+ }
+ void add_indptrBuffers(flatbuffers::Offset<flatbuffers::Vector<const org::apache::arrow::flatbuf::Buffer *>> indptrBuffers) {
+ fbb_.AddOffset(SparseTensorIndexCSF::VT_INDPTRBUFFERS, indptrBuffers);
+ }
+ void add_indicesType(flatbuffers::Offset<org::apache::arrow::flatbuf::Int> indicesType) {
+ fbb_.AddOffset(SparseTensorIndexCSF::VT_INDICESTYPE, indicesType);
+ }
+ void add_indicesBuffers(flatbuffers::Offset<flatbuffers::Vector<const org::apache::arrow::flatbuf::Buffer *>> indicesBuffers) {
+ fbb_.AddOffset(SparseTensorIndexCSF::VT_INDICESBUFFERS, indicesBuffers);
+ }
+ void add_axisOrder(flatbuffers::Offset<flatbuffers::Vector<int32_t>> axisOrder) {
+ fbb_.AddOffset(SparseTensorIndexCSF::VT_AXISORDER, axisOrder);
+ }
+ explicit SparseTensorIndexCSFBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SparseTensorIndexCSFBuilder &operator=(const SparseTensorIndexCSFBuilder &);
+ flatbuffers::Offset<SparseTensorIndexCSF> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SparseTensorIndexCSF>(end);
+ fbb_.Required(o, SparseTensorIndexCSF::VT_INDPTRTYPE);
+ fbb_.Required(o, SparseTensorIndexCSF::VT_INDPTRBUFFERS);
+ fbb_.Required(o, SparseTensorIndexCSF::VT_INDICESTYPE);
+ fbb_.Required(o, SparseTensorIndexCSF::VT_INDICESBUFFERS);
+ fbb_.Required(o, SparseTensorIndexCSF::VT_AXISORDER);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SparseTensorIndexCSF> CreateSparseTensorIndexCSF(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Int> indptrType = 0,
+ flatbuffers::Offset<flatbuffers::Vector<const org::apache::arrow::flatbuf::Buffer *>> indptrBuffers = 0,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Int> indicesType = 0,
+ flatbuffers::Offset<flatbuffers::Vector<const org::apache::arrow::flatbuf::Buffer *>> indicesBuffers = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> axisOrder = 0) {
+ SparseTensorIndexCSFBuilder builder_(_fbb);
+ builder_.add_axisOrder(axisOrder);
+ builder_.add_indicesBuffers(indicesBuffers);
+ builder_.add_indicesType(indicesType);
+ builder_.add_indptrBuffers(indptrBuffers);
+ builder_.add_indptrType(indptrType);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<SparseTensorIndexCSF> CreateSparseTensorIndexCSFDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Int> indptrType = 0,
+ const std::vector<org::apache::arrow::flatbuf::Buffer> *indptrBuffers = nullptr,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Int> indicesType = 0,
+ const std::vector<org::apache::arrow::flatbuf::Buffer> *indicesBuffers = nullptr,
+ const std::vector<int32_t> *axisOrder = nullptr) {
+ auto indptrBuffers__ = indptrBuffers ? _fbb.CreateVectorOfStructs<org::apache::arrow::flatbuf::Buffer>(*indptrBuffers) : 0;
+ auto indicesBuffers__ = indicesBuffers ? _fbb.CreateVectorOfStructs<org::apache::arrow::flatbuf::Buffer>(*indicesBuffers) : 0;
+ auto axisOrder__ = axisOrder ? _fbb.CreateVector<int32_t>(*axisOrder) : 0;
+ return org::apache::arrow::flatbuf::CreateSparseTensorIndexCSF(
+ _fbb,
+ indptrType,
+ indptrBuffers__,
+ indicesType,
+ indicesBuffers__,
+ axisOrder__);
+}
+
+struct SparseTensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SparseTensorBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_TYPE_TYPE = 4,
+ VT_TYPE = 6,
+ VT_SHAPE = 8,
+ VT_NON_ZERO_LENGTH = 10,
+ VT_SPARSEINDEX_TYPE = 12,
+ VT_SPARSEINDEX = 14,
+ VT_DATA = 16
+ };
+ org::apache::arrow::flatbuf::Type type_type() const {
+ return static_cast<org::apache::arrow::flatbuf::Type>(GetField<uint8_t>(VT_TYPE_TYPE, 0));
+ }
+ /// The type of data contained in a value cell.
+ /// Currently only fixed-width value types are supported,
+ /// no strings or nested types.
+ const void *type() const {
+ return GetPointer<const void *>(VT_TYPE);
+ }
+ template<typename T> const T *type_as() const;
+ const org::apache::arrow::flatbuf::Null *type_as_Null() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Null ? static_cast<const org::apache::arrow::flatbuf::Null *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Int *type_as_Int() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Int ? static_cast<const org::apache::arrow::flatbuf::Int *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::FloatingPoint *type_as_FloatingPoint() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::FloatingPoint ? static_cast<const org::apache::arrow::flatbuf::FloatingPoint *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Binary *type_as_Binary() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Binary ? static_cast<const org::apache::arrow::flatbuf::Binary *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Utf8 *type_as_Utf8() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Utf8 ? static_cast<const org::apache::arrow::flatbuf::Utf8 *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Bool *type_as_Bool() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Bool ? static_cast<const org::apache::arrow::flatbuf::Bool *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Decimal *type_as_Decimal() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Decimal ? static_cast<const org::apache::arrow::flatbuf::Decimal *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Date *type_as_Date() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Date ? static_cast<const org::apache::arrow::flatbuf::Date *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Time *type_as_Time() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Time ? static_cast<const org::apache::arrow::flatbuf::Time *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Timestamp *type_as_Timestamp() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Timestamp ? static_cast<const org::apache::arrow::flatbuf::Timestamp *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Interval *type_as_Interval() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Interval ? static_cast<const org::apache::arrow::flatbuf::Interval *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::List *type_as_List() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::List ? static_cast<const org::apache::arrow::flatbuf::List *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Struct_ *type_as_Struct_() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Struct_ ? static_cast<const org::apache::arrow::flatbuf::Struct_ *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Union *type_as_Union() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Union ? static_cast<const org::apache::arrow::flatbuf::Union *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::FixedSizeBinary *type_as_FixedSizeBinary() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::FixedSizeBinary ? static_cast<const org::apache::arrow::flatbuf::FixedSizeBinary *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::FixedSizeList *type_as_FixedSizeList() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::FixedSizeList ? static_cast<const org::apache::arrow::flatbuf::FixedSizeList *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Map *type_as_Map() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Map ? static_cast<const org::apache::arrow::flatbuf::Map *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Duration *type_as_Duration() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Duration ? static_cast<const org::apache::arrow::flatbuf::Duration *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::LargeBinary *type_as_LargeBinary() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::LargeBinary ? static_cast<const org::apache::arrow::flatbuf::LargeBinary *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::LargeUtf8 *type_as_LargeUtf8() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::LargeUtf8 ? static_cast<const org::apache::arrow::flatbuf::LargeUtf8 *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::LargeList *type_as_LargeList() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::LargeList ? static_cast<const org::apache::arrow::flatbuf::LargeList *>(type()) : nullptr;
+ }
+ /// The dimensions of the tensor, optionally named.
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::TensorDim>> *shape() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::TensorDim>> *>(VT_SHAPE);
+ }
+ /// The number of non-zero values in a sparse tensor.
+ int64_t non_zero_length() const {
+ return GetField<int64_t>(VT_NON_ZERO_LENGTH, 0);
+ }
+ org::apache::arrow::flatbuf::SparseTensorIndex sparseIndex_type() const {
+ return static_cast<org::apache::arrow::flatbuf::SparseTensorIndex>(GetField<uint8_t>(VT_SPARSEINDEX_TYPE, 0));
+ }
+ /// Sparse tensor index
+ const void *sparseIndex() const {
+ return GetPointer<const void *>(VT_SPARSEINDEX);
+ }
+ template<typename T> const T *sparseIndex_as() const;
+ const org::apache::arrow::flatbuf::SparseTensorIndexCOO *sparseIndex_as_SparseTensorIndexCOO() const {
+ return sparseIndex_type() == org::apache::arrow::flatbuf::SparseTensorIndex::SparseTensorIndexCOO ? static_cast<const org::apache::arrow::flatbuf::SparseTensorIndexCOO *>(sparseIndex()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::SparseMatrixIndexCSX *sparseIndex_as_SparseMatrixIndexCSX() const {
+ return sparseIndex_type() == org::apache::arrow::flatbuf::SparseTensorIndex::SparseMatrixIndexCSX ? static_cast<const org::apache::arrow::flatbuf::SparseMatrixIndexCSX *>(sparseIndex()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::SparseTensorIndexCSF *sparseIndex_as_SparseTensorIndexCSF() const {
+ return sparseIndex_type() == org::apache::arrow::flatbuf::SparseTensorIndex::SparseTensorIndexCSF ? static_cast<const org::apache::arrow::flatbuf::SparseTensorIndexCSF *>(sparseIndex()) : nullptr;
+ }
+ /// The location and size of the tensor's data
+ const org::apache::arrow::flatbuf::Buffer *data() const {
+ return GetStruct<const org::apache::arrow::flatbuf::Buffer *>(VT_DATA);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_TYPE_TYPE) &&
+ VerifyOffsetRequired(verifier, VT_TYPE) &&
+ VerifyType(verifier, type(), type_type()) &&
+ VerifyOffsetRequired(verifier, VT_SHAPE) &&
+ verifier.VerifyVector(shape()) &&
+ verifier.VerifyVectorOfTables(shape()) &&
+ VerifyField<int64_t>(verifier, VT_NON_ZERO_LENGTH) &&
+ VerifyField<uint8_t>(verifier, VT_SPARSEINDEX_TYPE) &&
+ VerifyOffsetRequired(verifier, VT_SPARSEINDEX) &&
+ VerifySparseTensorIndex(verifier, sparseIndex(), sparseIndex_type()) &&
+ VerifyFieldRequired<org::apache::arrow::flatbuf::Buffer>(verifier, VT_DATA) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const org::apache::arrow::flatbuf::Null *SparseTensor::type_as<org::apache::arrow::flatbuf::Null>() const {
+ return type_as_Null();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Int *SparseTensor::type_as<org::apache::arrow::flatbuf::Int>() const {
+ return type_as_Int();
+}
+
+template<> inline const org::apache::arrow::flatbuf::FloatingPoint *SparseTensor::type_as<org::apache::arrow::flatbuf::FloatingPoint>() const {
+ return type_as_FloatingPoint();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Binary *SparseTensor::type_as<org::apache::arrow::flatbuf::Binary>() const {
+ return type_as_Binary();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Utf8 *SparseTensor::type_as<org::apache::arrow::flatbuf::Utf8>() const {
+ return type_as_Utf8();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Bool *SparseTensor::type_as<org::apache::arrow::flatbuf::Bool>() const {
+ return type_as_Bool();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Decimal *SparseTensor::type_as<org::apache::arrow::flatbuf::Decimal>() const {
+ return type_as_Decimal();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Date *SparseTensor::type_as<org::apache::arrow::flatbuf::Date>() const {
+ return type_as_Date();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Time *SparseTensor::type_as<org::apache::arrow::flatbuf::Time>() const {
+ return type_as_Time();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Timestamp *SparseTensor::type_as<org::apache::arrow::flatbuf::Timestamp>() const {
+ return type_as_Timestamp();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Interval *SparseTensor::type_as<org::apache::arrow::flatbuf::Interval>() const {
+ return type_as_Interval();
+}
+
+template<> inline const org::apache::arrow::flatbuf::List *SparseTensor::type_as<org::apache::arrow::flatbuf::List>() const {
+ return type_as_List();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Struct_ *SparseTensor::type_as<org::apache::arrow::flatbuf::Struct_>() const {
+ return type_as_Struct_();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Union *SparseTensor::type_as<org::apache::arrow::flatbuf::Union>() const {
+ return type_as_Union();
+}
+
+template<> inline const org::apache::arrow::flatbuf::FixedSizeBinary *SparseTensor::type_as<org::apache::arrow::flatbuf::FixedSizeBinary>() const {
+ return type_as_FixedSizeBinary();
+}
+
+template<> inline const org::apache::arrow::flatbuf::FixedSizeList *SparseTensor::type_as<org::apache::arrow::flatbuf::FixedSizeList>() const {
+ return type_as_FixedSizeList();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Map *SparseTensor::type_as<org::apache::arrow::flatbuf::Map>() const {
+ return type_as_Map();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Duration *SparseTensor::type_as<org::apache::arrow::flatbuf::Duration>() const {
+ return type_as_Duration();
+}
+
+template<> inline const org::apache::arrow::flatbuf::LargeBinary *SparseTensor::type_as<org::apache::arrow::flatbuf::LargeBinary>() const {
+ return type_as_LargeBinary();
+}
+
+template<> inline const org::apache::arrow::flatbuf::LargeUtf8 *SparseTensor::type_as<org::apache::arrow::flatbuf::LargeUtf8>() const {
+ return type_as_LargeUtf8();
+}
+
+template<> inline const org::apache::arrow::flatbuf::LargeList *SparseTensor::type_as<org::apache::arrow::flatbuf::LargeList>() const {
+ return type_as_LargeList();
+}
+
+template<> inline const org::apache::arrow::flatbuf::SparseTensorIndexCOO *SparseTensor::sparseIndex_as<org::apache::arrow::flatbuf::SparseTensorIndexCOO>() const {
+ return sparseIndex_as_SparseTensorIndexCOO();
+}
+
+template<> inline const org::apache::arrow::flatbuf::SparseMatrixIndexCSX *SparseTensor::sparseIndex_as<org::apache::arrow::flatbuf::SparseMatrixIndexCSX>() const {
+ return sparseIndex_as_SparseMatrixIndexCSX();
+}
+
+template<> inline const org::apache::arrow::flatbuf::SparseTensorIndexCSF *SparseTensor::sparseIndex_as<org::apache::arrow::flatbuf::SparseTensorIndexCSF>() const {
+ return sparseIndex_as_SparseTensorIndexCSF();
+}
+
+struct SparseTensorBuilder {
+ typedef SparseTensor Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_type_type(org::apache::arrow::flatbuf::Type type_type) {
+ fbb_.AddElement<uint8_t>(SparseTensor::VT_TYPE_TYPE, static_cast<uint8_t>(type_type), 0);
+ }
+ void add_type(flatbuffers::Offset<void> type) {
+ fbb_.AddOffset(SparseTensor::VT_TYPE, type);
+ }
+ void add_shape(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::TensorDim>>> shape) {
+ fbb_.AddOffset(SparseTensor::VT_SHAPE, shape);
+ }
+ void add_non_zero_length(int64_t non_zero_length) {
+ fbb_.AddElement<int64_t>(SparseTensor::VT_NON_ZERO_LENGTH, non_zero_length, 0);
+ }
+ void add_sparseIndex_type(org::apache::arrow::flatbuf::SparseTensorIndex sparseIndex_type) {
+ fbb_.AddElement<uint8_t>(SparseTensor::VT_SPARSEINDEX_TYPE, static_cast<uint8_t>(sparseIndex_type), 0);
+ }
+ void add_sparseIndex(flatbuffers::Offset<void> sparseIndex) {
+ fbb_.AddOffset(SparseTensor::VT_SPARSEINDEX, sparseIndex);
+ }
+ void add_data(const org::apache::arrow::flatbuf::Buffer *data) {
+ fbb_.AddStruct(SparseTensor::VT_DATA, data);
+ }
+ explicit SparseTensorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SparseTensorBuilder &operator=(const SparseTensorBuilder &);
+ flatbuffers::Offset<SparseTensor> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SparseTensor>(end);
+ fbb_.Required(o, SparseTensor::VT_TYPE);
+ fbb_.Required(o, SparseTensor::VT_SHAPE);
+ fbb_.Required(o, SparseTensor::VT_SPARSEINDEX);
+ fbb_.Required(o, SparseTensor::VT_DATA);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SparseTensor> CreateSparseTensor(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::Type type_type = org::apache::arrow::flatbuf::Type::NONE,
+ flatbuffers::Offset<void> type = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::TensorDim>>> shape = 0,
+ int64_t non_zero_length = 0,
+ org::apache::arrow::flatbuf::SparseTensorIndex sparseIndex_type = org::apache::arrow::flatbuf::SparseTensorIndex::NONE,
+ flatbuffers::Offset<void> sparseIndex = 0,
+ const org::apache::arrow::flatbuf::Buffer *data = 0) {
+ SparseTensorBuilder builder_(_fbb);
+ builder_.add_non_zero_length(non_zero_length);
+ builder_.add_data(data);
+ builder_.add_sparseIndex(sparseIndex);
+ builder_.add_shape(shape);
+ builder_.add_type(type);
+ builder_.add_sparseIndex_type(sparseIndex_type);
+ builder_.add_type_type(type_type);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<SparseTensor> CreateSparseTensorDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::Type type_type = org::apache::arrow::flatbuf::Type::NONE,
+ flatbuffers::Offset<void> type = 0,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::flatbuf::TensorDim>> *shape = nullptr,
+ int64_t non_zero_length = 0,
+ org::apache::arrow::flatbuf::SparseTensorIndex sparseIndex_type = org::apache::arrow::flatbuf::SparseTensorIndex::NONE,
+ flatbuffers::Offset<void> sparseIndex = 0,
+ const org::apache::arrow::flatbuf::Buffer *data = 0) {
+ auto shape__ = shape ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::flatbuf::TensorDim>>(*shape) : 0;
+ return org::apache::arrow::flatbuf::CreateSparseTensor(
+ _fbb,
+ type_type,
+ type,
+ shape__,
+ non_zero_length,
+ sparseIndex_type,
+ sparseIndex,
+ data);
+}
+
+inline bool VerifySparseTensorIndex(flatbuffers::Verifier &verifier, const void *obj, SparseTensorIndex type) {
+ switch (type) {
+ case SparseTensorIndex::NONE: {
+ return true;
+ }
+ case SparseTensorIndex::SparseTensorIndexCOO: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::SparseTensorIndexCOO *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case SparseTensorIndex::SparseMatrixIndexCSX: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::SparseMatrixIndexCSX *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case SparseTensorIndex::SparseTensorIndexCSF: {
+ auto ptr = reinterpret_cast<const org::apache::arrow::flatbuf::SparseTensorIndexCSF *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return true;
+ }
+}
+
+inline bool VerifySparseTensorIndexVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifySparseTensorIndex(
+ verifier, values->Get(i), types->GetEnum<SparseTensorIndex>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline const org::apache::arrow::flatbuf::SparseTensor *GetSparseTensor(const void *buf) {
+ return flatbuffers::GetRoot<org::apache::arrow::flatbuf::SparseTensor>(buf);
+}
+
+inline const org::apache::arrow::flatbuf::SparseTensor *GetSizePrefixedSparseTensor(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<org::apache::arrow::flatbuf::SparseTensor>(buf);
+}
+
+inline bool VerifySparseTensorBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifyBuffer<org::apache::arrow::flatbuf::SparseTensor>(nullptr);
+}
+
+inline bool VerifySizePrefixedSparseTensorBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<org::apache::arrow::flatbuf::SparseTensor>(nullptr);
+}
+
+inline void FinishSparseTensorBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::SparseTensor> root) {
+ fbb.Finish(root);
+}
+
+inline void FinishSizePrefixedSparseTensorBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::SparseTensor> root) {
+ fbb.FinishSizePrefixed(root);
+}
+
+} // namespace flatbuf
+} // namespace arrow
+} // namespace apache
+} // namespace org
+
+#endif // FLATBUFFERS_GENERATED_SPARSETENSOR_ORG_APACHE_ARROW_FLATBUF_H_
diff --git a/src/arrow/cpp/src/generated/Tensor_generated.h b/src/arrow/cpp/src/generated/Tensor_generated.h
new file mode 100644
index 000000000..062a3b91a
--- /dev/null
+++ b/src/arrow/cpp/src/generated/Tensor_generated.h
@@ -0,0 +1,387 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_TENSOR_ORG_APACHE_ARROW_FLATBUF_H_
+#define FLATBUFFERS_GENERATED_TENSOR_ORG_APACHE_ARROW_FLATBUF_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+#include "Schema_generated.h"
+
+namespace org {
+namespace apache {
+namespace arrow {
+namespace flatbuf {
+
+struct TensorDim;
+struct TensorDimBuilder;
+
+struct Tensor;
+struct TensorBuilder;
+
+/// ----------------------------------------------------------------------
+/// Data structures for dense tensors
+/// Shape data for a single axis in a tensor
+struct TensorDim FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef TensorDimBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_SIZE = 4,
+ VT_NAME = 6
+ };
+ /// Length of dimension
+ int64_t size() const {
+ return GetField<int64_t>(VT_SIZE, 0);
+ }
+ /// Name of the dimension, optional
+ const flatbuffers::String *name() const {
+ return GetPointer<const flatbuffers::String *>(VT_NAME);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_SIZE) &&
+ VerifyOffset(verifier, VT_NAME) &&
+ verifier.VerifyString(name()) &&
+ verifier.EndTable();
+ }
+};
+
+struct TensorDimBuilder {
+ typedef TensorDim Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_size(int64_t size) {
+ fbb_.AddElement<int64_t>(TensorDim::VT_SIZE, size, 0);
+ }
+ void add_name(flatbuffers::Offset<flatbuffers::String> name) {
+ fbb_.AddOffset(TensorDim::VT_NAME, name);
+ }
+ explicit TensorDimBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TensorDimBuilder &operator=(const TensorDimBuilder &);
+ flatbuffers::Offset<TensorDim> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TensorDim>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TensorDim> CreateTensorDim(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t size = 0,
+ flatbuffers::Offset<flatbuffers::String> name = 0) {
+ TensorDimBuilder builder_(_fbb);
+ builder_.add_size(size);
+ builder_.add_name(name);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TensorDim> CreateTensorDimDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t size = 0,
+ const char *name = nullptr) {
+ auto name__ = name ? _fbb.CreateString(name) : 0;
+ return org::apache::arrow::flatbuf::CreateTensorDim(
+ _fbb,
+ size,
+ name__);
+}
+
+struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef TensorBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_TYPE_TYPE = 4,
+ VT_TYPE = 6,
+ VT_SHAPE = 8,
+ VT_STRIDES = 10,
+ VT_DATA = 12
+ };
+ org::apache::arrow::flatbuf::Type type_type() const {
+ return static_cast<org::apache::arrow::flatbuf::Type>(GetField<uint8_t>(VT_TYPE_TYPE, 0));
+ }
+ /// The type of data contained in a value cell. Currently only fixed-width
+ /// value types are supported, no strings or nested types
+ const void *type() const {
+ return GetPointer<const void *>(VT_TYPE);
+ }
+ template<typename T> const T *type_as() const;
+ const org::apache::arrow::flatbuf::Null *type_as_Null() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Null ? static_cast<const org::apache::arrow::flatbuf::Null *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Int *type_as_Int() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Int ? static_cast<const org::apache::arrow::flatbuf::Int *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::FloatingPoint *type_as_FloatingPoint() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::FloatingPoint ? static_cast<const org::apache::arrow::flatbuf::FloatingPoint *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Binary *type_as_Binary() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Binary ? static_cast<const org::apache::arrow::flatbuf::Binary *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Utf8 *type_as_Utf8() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Utf8 ? static_cast<const org::apache::arrow::flatbuf::Utf8 *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Bool *type_as_Bool() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Bool ? static_cast<const org::apache::arrow::flatbuf::Bool *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Decimal *type_as_Decimal() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Decimal ? static_cast<const org::apache::arrow::flatbuf::Decimal *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Date *type_as_Date() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Date ? static_cast<const org::apache::arrow::flatbuf::Date *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Time *type_as_Time() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Time ? static_cast<const org::apache::arrow::flatbuf::Time *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Timestamp *type_as_Timestamp() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Timestamp ? static_cast<const org::apache::arrow::flatbuf::Timestamp *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Interval *type_as_Interval() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Interval ? static_cast<const org::apache::arrow::flatbuf::Interval *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::List *type_as_List() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::List ? static_cast<const org::apache::arrow::flatbuf::List *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Struct_ *type_as_Struct_() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Struct_ ? static_cast<const org::apache::arrow::flatbuf::Struct_ *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Union *type_as_Union() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Union ? static_cast<const org::apache::arrow::flatbuf::Union *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::FixedSizeBinary *type_as_FixedSizeBinary() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::FixedSizeBinary ? static_cast<const org::apache::arrow::flatbuf::FixedSizeBinary *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::FixedSizeList *type_as_FixedSizeList() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::FixedSizeList ? static_cast<const org::apache::arrow::flatbuf::FixedSizeList *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Map *type_as_Map() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Map ? static_cast<const org::apache::arrow::flatbuf::Map *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::Duration *type_as_Duration() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::Duration ? static_cast<const org::apache::arrow::flatbuf::Duration *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::LargeBinary *type_as_LargeBinary() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::LargeBinary ? static_cast<const org::apache::arrow::flatbuf::LargeBinary *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::LargeUtf8 *type_as_LargeUtf8() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::LargeUtf8 ? static_cast<const org::apache::arrow::flatbuf::LargeUtf8 *>(type()) : nullptr;
+ }
+ const org::apache::arrow::flatbuf::LargeList *type_as_LargeList() const {
+ return type_type() == org::apache::arrow::flatbuf::Type::LargeList ? static_cast<const org::apache::arrow::flatbuf::LargeList *>(type()) : nullptr;
+ }
+ /// The dimensions of the tensor, optionally named
+ const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::TensorDim>> *shape() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::TensorDim>> *>(VT_SHAPE);
+ }
+ /// Non-negative byte offsets to advance one value cell along each dimension
+ /// If omitted, default to row-major order (C-like).
+ const flatbuffers::Vector<int64_t> *strides() const {
+ return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_STRIDES);
+ }
+ /// The location and size of the tensor's data
+ const org::apache::arrow::flatbuf::Buffer *data() const {
+ return GetStruct<const org::apache::arrow::flatbuf::Buffer *>(VT_DATA);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_TYPE_TYPE) &&
+ VerifyOffsetRequired(verifier, VT_TYPE) &&
+ VerifyType(verifier, type(), type_type()) &&
+ VerifyOffsetRequired(verifier, VT_SHAPE) &&
+ verifier.VerifyVector(shape()) &&
+ verifier.VerifyVectorOfTables(shape()) &&
+ VerifyOffset(verifier, VT_STRIDES) &&
+ verifier.VerifyVector(strides()) &&
+ VerifyFieldRequired<org::apache::arrow::flatbuf::Buffer>(verifier, VT_DATA) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const org::apache::arrow::flatbuf::Null *Tensor::type_as<org::apache::arrow::flatbuf::Null>() const {
+ return type_as_Null();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Int *Tensor::type_as<org::apache::arrow::flatbuf::Int>() const {
+ return type_as_Int();
+}
+
+template<> inline const org::apache::arrow::flatbuf::FloatingPoint *Tensor::type_as<org::apache::arrow::flatbuf::FloatingPoint>() const {
+ return type_as_FloatingPoint();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Binary *Tensor::type_as<org::apache::arrow::flatbuf::Binary>() const {
+ return type_as_Binary();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Utf8 *Tensor::type_as<org::apache::arrow::flatbuf::Utf8>() const {
+ return type_as_Utf8();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Bool *Tensor::type_as<org::apache::arrow::flatbuf::Bool>() const {
+ return type_as_Bool();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Decimal *Tensor::type_as<org::apache::arrow::flatbuf::Decimal>() const {
+ return type_as_Decimal();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Date *Tensor::type_as<org::apache::arrow::flatbuf::Date>() const {
+ return type_as_Date();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Time *Tensor::type_as<org::apache::arrow::flatbuf::Time>() const {
+ return type_as_Time();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Timestamp *Tensor::type_as<org::apache::arrow::flatbuf::Timestamp>() const {
+ return type_as_Timestamp();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Interval *Tensor::type_as<org::apache::arrow::flatbuf::Interval>() const {
+ return type_as_Interval();
+}
+
+template<> inline const org::apache::arrow::flatbuf::List *Tensor::type_as<org::apache::arrow::flatbuf::List>() const {
+ return type_as_List();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Struct_ *Tensor::type_as<org::apache::arrow::flatbuf::Struct_>() const {
+ return type_as_Struct_();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Union *Tensor::type_as<org::apache::arrow::flatbuf::Union>() const {
+ return type_as_Union();
+}
+
+template<> inline const org::apache::arrow::flatbuf::FixedSizeBinary *Tensor::type_as<org::apache::arrow::flatbuf::FixedSizeBinary>() const {
+ return type_as_FixedSizeBinary();
+}
+
+template<> inline const org::apache::arrow::flatbuf::FixedSizeList *Tensor::type_as<org::apache::arrow::flatbuf::FixedSizeList>() const {
+ return type_as_FixedSizeList();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Map *Tensor::type_as<org::apache::arrow::flatbuf::Map>() const {
+ return type_as_Map();
+}
+
+template<> inline const org::apache::arrow::flatbuf::Duration *Tensor::type_as<org::apache::arrow::flatbuf::Duration>() const {
+ return type_as_Duration();
+}
+
+template<> inline const org::apache::arrow::flatbuf::LargeBinary *Tensor::type_as<org::apache::arrow::flatbuf::LargeBinary>() const {
+ return type_as_LargeBinary();
+}
+
+template<> inline const org::apache::arrow::flatbuf::LargeUtf8 *Tensor::type_as<org::apache::arrow::flatbuf::LargeUtf8>() const {
+ return type_as_LargeUtf8();
+}
+
+template<> inline const org::apache::arrow::flatbuf::LargeList *Tensor::type_as<org::apache::arrow::flatbuf::LargeList>() const {
+ return type_as_LargeList();
+}
+
+struct TensorBuilder {
+ typedef Tensor Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_type_type(org::apache::arrow::flatbuf::Type type_type) {
+ fbb_.AddElement<uint8_t>(Tensor::VT_TYPE_TYPE, static_cast<uint8_t>(type_type), 0);
+ }
+ void add_type(flatbuffers::Offset<void> type) {
+ fbb_.AddOffset(Tensor::VT_TYPE, type);
+ }
+ void add_shape(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::TensorDim>>> shape) {
+ fbb_.AddOffset(Tensor::VT_SHAPE, shape);
+ }
+ void add_strides(flatbuffers::Offset<flatbuffers::Vector<int64_t>> strides) {
+ fbb_.AddOffset(Tensor::VT_STRIDES, strides);
+ }
+ void add_data(const org::apache::arrow::flatbuf::Buffer *data) {
+ fbb_.AddStruct(Tensor::VT_DATA, data);
+ }
+ explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TensorBuilder &operator=(const TensorBuilder &);
+ flatbuffers::Offset<Tensor> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Tensor>(end);
+ fbb_.Required(o, Tensor::VT_TYPE);
+ fbb_.Required(o, Tensor::VT_SHAPE);
+ fbb_.Required(o, Tensor::VT_DATA);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Tensor> CreateTensor(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::Type type_type = org::apache::arrow::flatbuf::Type::NONE,
+ flatbuffers::Offset<void> type = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<org::apache::arrow::flatbuf::TensorDim>>> shape = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int64_t>> strides = 0,
+ const org::apache::arrow::flatbuf::Buffer *data = 0) {
+ TensorBuilder builder_(_fbb);
+ builder_.add_data(data);
+ builder_.add_strides(strides);
+ builder_.add_shape(shape);
+ builder_.add_type(type);
+ builder_.add_type_type(type_type);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Tensor> CreateTensorDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ org::apache::arrow::flatbuf::Type type_type = org::apache::arrow::flatbuf::Type::NONE,
+ flatbuffers::Offset<void> type = 0,
+ const std::vector<flatbuffers::Offset<org::apache::arrow::flatbuf::TensorDim>> *shape = nullptr,
+ const std::vector<int64_t> *strides = nullptr,
+ const org::apache::arrow::flatbuf::Buffer *data = 0) {
+ auto shape__ = shape ? _fbb.CreateVector<flatbuffers::Offset<org::apache::arrow::flatbuf::TensorDim>>(*shape) : 0;
+ auto strides__ = strides ? _fbb.CreateVector<int64_t>(*strides) : 0;
+ return org::apache::arrow::flatbuf::CreateTensor(
+ _fbb,
+ type_type,
+ type,
+ shape__,
+ strides__,
+ data);
+}
+
+inline const org::apache::arrow::flatbuf::Tensor *GetTensor(const void *buf) {
+ return flatbuffers::GetRoot<org::apache::arrow::flatbuf::Tensor>(buf);
+}
+
+inline const org::apache::arrow::flatbuf::Tensor *GetSizePrefixedTensor(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<org::apache::arrow::flatbuf::Tensor>(buf);
+}
+
+inline bool VerifyTensorBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifyBuffer<org::apache::arrow::flatbuf::Tensor>(nullptr);
+}
+
+inline bool VerifySizePrefixedTensorBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<org::apache::arrow::flatbuf::Tensor>(nullptr);
+}
+
+inline void FinishTensorBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Tensor> root) {
+ fbb.Finish(root);
+}
+
+inline void FinishSizePrefixedTensorBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<org::apache::arrow::flatbuf::Tensor> root) {
+ fbb.FinishSizePrefixed(root);
+}
+
+} // namespace flatbuf
+} // namespace arrow
+} // namespace apache
+} // namespace org
+
+#endif // FLATBUFFERS_GENERATED_TENSOR_ORG_APACHE_ARROW_FLATBUF_H_
diff --git a/src/arrow/cpp/src/generated/feather_generated.h b/src/arrow/cpp/src/generated/feather_generated.h
new file mode 100644
index 000000000..b925eb2bc
--- /dev/null
+++ b/src/arrow/cpp/src/generated/feather_generated.h
@@ -0,0 +1,863 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_FEATHER_ARROW_IPC_FEATHER_FBS_H_
+#define FLATBUFFERS_GENERATED_FEATHER_ARROW_IPC_FEATHER_FBS_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+namespace arrow {
+namespace ipc {
+namespace feather {
+namespace fbs {
+
+struct PrimitiveArray;
+struct PrimitiveArrayBuilder;
+
+struct CategoryMetadata;
+struct CategoryMetadataBuilder;
+
+struct TimestampMetadata;
+struct TimestampMetadataBuilder;
+
+struct DateMetadata;
+struct DateMetadataBuilder;
+
+struct TimeMetadata;
+struct TimeMetadataBuilder;
+
+struct Column;
+struct ColumnBuilder;
+
+struct CTable;
+struct CTableBuilder;
+
+/// Feather is an experimental serialization format implemented using
+/// techniques from Apache Arrow. It was created as a proof-of-concept of an
+/// interoperable file format for storing data frames originating in Python or
+/// R. It enabled the developers to sidestep some of the open design questions
+/// in Arrow from early 2016 and instead create something simple and useful for
+/// the intended use cases.
+enum class Type : int8_t {
+ BOOL = 0,
+ INT8 = 1,
+ INT16 = 2,
+ INT32 = 3,
+ INT64 = 4,
+ UINT8 = 5,
+ UINT16 = 6,
+ UINT32 = 7,
+ UINT64 = 8,
+ FLOAT = 9,
+ DOUBLE = 10,
+ UTF8 = 11,
+ BINARY = 12,
+ CATEGORY = 13,
+ TIMESTAMP = 14,
+ DATE = 15,
+ TIME = 16,
+ LARGE_UTF8 = 17,
+ LARGE_BINARY = 18,
+ MIN = BOOL,
+ MAX = LARGE_BINARY
+};
+
+inline const Type (&EnumValuesType())[19] {
+ static const Type values[] = {
+ Type::BOOL,
+ Type::INT8,
+ Type::INT16,
+ Type::INT32,
+ Type::INT64,
+ Type::UINT8,
+ Type::UINT16,
+ Type::UINT32,
+ Type::UINT64,
+ Type::FLOAT,
+ Type::DOUBLE,
+ Type::UTF8,
+ Type::BINARY,
+ Type::CATEGORY,
+ Type::TIMESTAMP,
+ Type::DATE,
+ Type::TIME,
+ Type::LARGE_UTF8,
+ Type::LARGE_BINARY
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesType() {
+ static const char * const names[20] = {
+ "BOOL",
+ "INT8",
+ "INT16",
+ "INT32",
+ "INT64",
+ "UINT8",
+ "UINT16",
+ "UINT32",
+ "UINT64",
+ "FLOAT",
+ "DOUBLE",
+ "UTF8",
+ "BINARY",
+ "CATEGORY",
+ "TIMESTAMP",
+ "DATE",
+ "TIME",
+ "LARGE_UTF8",
+ "LARGE_BINARY",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameType(Type e) {
+ if (flatbuffers::IsOutRange(e, Type::BOOL, Type::LARGE_BINARY)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesType()[index];
+}
+
+enum class Encoding : int8_t {
+ PLAIN = 0,
+ /// Data is stored dictionary-encoded
+ /// dictionary size: <INT32 Dictionary size>
+ /// dictionary data: <TYPE primitive array>
+ /// dictionary index: <INT32 primitive array>
+ ///
+ /// TODO: do we care about storing the index values in a smaller typeclass
+ DICTIONARY = 1,
+ MIN = PLAIN,
+ MAX = DICTIONARY
+};
+
+inline const Encoding (&EnumValuesEncoding())[2] {
+ static const Encoding values[] = {
+ Encoding::PLAIN,
+ Encoding::DICTIONARY
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesEncoding() {
+ static const char * const names[3] = {
+ "PLAIN",
+ "DICTIONARY",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameEncoding(Encoding e) {
+ if (flatbuffers::IsOutRange(e, Encoding::PLAIN, Encoding::DICTIONARY)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesEncoding()[index];
+}
+
+enum class TimeUnit : int8_t {
+ SECOND = 0,
+ MILLISECOND = 1,
+ MICROSECOND = 2,
+ NANOSECOND = 3,
+ MIN = SECOND,
+ MAX = NANOSECOND
+};
+
+inline const TimeUnit (&EnumValuesTimeUnit())[4] {
+ static const TimeUnit values[] = {
+ TimeUnit::SECOND,
+ TimeUnit::MILLISECOND,
+ TimeUnit::MICROSECOND,
+ TimeUnit::NANOSECOND
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesTimeUnit() {
+ static const char * const names[5] = {
+ "SECOND",
+ "MILLISECOND",
+ "MICROSECOND",
+ "NANOSECOND",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameTimeUnit(TimeUnit e) {
+ if (flatbuffers::IsOutRange(e, TimeUnit::SECOND, TimeUnit::NANOSECOND)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesTimeUnit()[index];
+}
+
+enum class TypeMetadata : uint8_t {
+ NONE = 0,
+ CategoryMetadata = 1,
+ TimestampMetadata = 2,
+ DateMetadata = 3,
+ TimeMetadata = 4,
+ MIN = NONE,
+ MAX = TimeMetadata
+};
+
+inline const TypeMetadata (&EnumValuesTypeMetadata())[5] {
+ static const TypeMetadata values[] = {
+ TypeMetadata::NONE,
+ TypeMetadata::CategoryMetadata,
+ TypeMetadata::TimestampMetadata,
+ TypeMetadata::DateMetadata,
+ TypeMetadata::TimeMetadata
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesTypeMetadata() {
+ static const char * const names[6] = {
+ "NONE",
+ "CategoryMetadata",
+ "TimestampMetadata",
+ "DateMetadata",
+ "TimeMetadata",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameTypeMetadata(TypeMetadata e) {
+ if (flatbuffers::IsOutRange(e, TypeMetadata::NONE, TypeMetadata::TimeMetadata)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesTypeMetadata()[index];
+}
+
+template<typename T> struct TypeMetadataTraits {
+ static const TypeMetadata enum_value = TypeMetadata::NONE;
+};
+
+template<> struct TypeMetadataTraits<arrow::ipc::feather::fbs::CategoryMetadata> {
+ static const TypeMetadata enum_value = TypeMetadata::CategoryMetadata;
+};
+
+template<> struct TypeMetadataTraits<arrow::ipc::feather::fbs::TimestampMetadata> {
+ static const TypeMetadata enum_value = TypeMetadata::TimestampMetadata;
+};
+
+template<> struct TypeMetadataTraits<arrow::ipc::feather::fbs::DateMetadata> {
+ static const TypeMetadata enum_value = TypeMetadata::DateMetadata;
+};
+
+template<> struct TypeMetadataTraits<arrow::ipc::feather::fbs::TimeMetadata> {
+ static const TypeMetadata enum_value = TypeMetadata::TimeMetadata;
+};
+
+bool VerifyTypeMetadata(flatbuffers::Verifier &verifier, const void *obj, TypeMetadata type);
+bool VerifyTypeMetadataVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+struct PrimitiveArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PrimitiveArrayBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_TYPE = 4,
+ VT_ENCODING = 6,
+ VT_OFFSET = 8,
+ VT_LENGTH = 10,
+ VT_NULL_COUNT = 12,
+ VT_TOTAL_BYTES = 14
+ };
+ arrow::ipc::feather::fbs::Type type() const {
+ return static_cast<arrow::ipc::feather::fbs::Type>(GetField<int8_t>(VT_TYPE, 0));
+ }
+ arrow::ipc::feather::fbs::Encoding encoding() const {
+ return static_cast<arrow::ipc::feather::fbs::Encoding>(GetField<int8_t>(VT_ENCODING, 0));
+ }
+ /// Relative memory offset of the start of the array data excluding the size
+ /// of the metadata
+ int64_t offset() const {
+ return GetField<int64_t>(VT_OFFSET, 0);
+ }
+ /// The number of logical values in the array
+ int64_t length() const {
+ return GetField<int64_t>(VT_LENGTH, 0);
+ }
+ /// The number of observed nulls
+ int64_t null_count() const {
+ return GetField<int64_t>(VT_NULL_COUNT, 0);
+ }
+ /// The total size of the actual data in the file
+ int64_t total_bytes() const {
+ return GetField<int64_t>(VT_TOTAL_BYTES, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_TYPE) &&
+ VerifyField<int8_t>(verifier, VT_ENCODING) &&
+ VerifyField<int64_t>(verifier, VT_OFFSET) &&
+ VerifyField<int64_t>(verifier, VT_LENGTH) &&
+ VerifyField<int64_t>(verifier, VT_NULL_COUNT) &&
+ VerifyField<int64_t>(verifier, VT_TOTAL_BYTES) &&
+ verifier.EndTable();
+ }
+};
+
+struct PrimitiveArrayBuilder {
+ typedef PrimitiveArray Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_type(arrow::ipc::feather::fbs::Type type) {
+ fbb_.AddElement<int8_t>(PrimitiveArray::VT_TYPE, static_cast<int8_t>(type), 0);
+ }
+ void add_encoding(arrow::ipc::feather::fbs::Encoding encoding) {
+ fbb_.AddElement<int8_t>(PrimitiveArray::VT_ENCODING, static_cast<int8_t>(encoding), 0);
+ }
+ void add_offset(int64_t offset) {
+ fbb_.AddElement<int64_t>(PrimitiveArray::VT_OFFSET, offset, 0);
+ }
+ void add_length(int64_t length) {
+ fbb_.AddElement<int64_t>(PrimitiveArray::VT_LENGTH, length, 0);
+ }
+ void add_null_count(int64_t null_count) {
+ fbb_.AddElement<int64_t>(PrimitiveArray::VT_NULL_COUNT, null_count, 0);
+ }
+ void add_total_bytes(int64_t total_bytes) {
+ fbb_.AddElement<int64_t>(PrimitiveArray::VT_TOTAL_BYTES, total_bytes, 0);
+ }
+ explicit PrimitiveArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PrimitiveArrayBuilder &operator=(const PrimitiveArrayBuilder &);
+ flatbuffers::Offset<PrimitiveArray> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PrimitiveArray>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PrimitiveArray> CreatePrimitiveArray(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ arrow::ipc::feather::fbs::Type type = arrow::ipc::feather::fbs::Type::BOOL,
+ arrow::ipc::feather::fbs::Encoding encoding = arrow::ipc::feather::fbs::Encoding::PLAIN,
+ int64_t offset = 0,
+ int64_t length = 0,
+ int64_t null_count = 0,
+ int64_t total_bytes = 0) {
+ PrimitiveArrayBuilder builder_(_fbb);
+ builder_.add_total_bytes(total_bytes);
+ builder_.add_null_count(null_count);
+ builder_.add_length(length);
+ builder_.add_offset(offset);
+ builder_.add_encoding(encoding);
+ builder_.add_type(type);
+ return builder_.Finish();
+}
+
+struct CategoryMetadata FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef CategoryMetadataBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_LEVELS = 4,
+ VT_ORDERED = 6
+ };
+ /// The category codes are presumed to be integers that are valid indexes into
+ /// the levels array
+ const arrow::ipc::feather::fbs::PrimitiveArray *levels() const {
+ return GetPointer<const arrow::ipc::feather::fbs::PrimitiveArray *>(VT_LEVELS);
+ }
+ bool ordered() const {
+ return GetField<uint8_t>(VT_ORDERED, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_LEVELS) &&
+ verifier.VerifyTable(levels()) &&
+ VerifyField<uint8_t>(verifier, VT_ORDERED) &&
+ verifier.EndTable();
+ }
+};
+
+struct CategoryMetadataBuilder {
+ typedef CategoryMetadata Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_levels(flatbuffers::Offset<arrow::ipc::feather::fbs::PrimitiveArray> levels) {
+ fbb_.AddOffset(CategoryMetadata::VT_LEVELS, levels);
+ }
+ void add_ordered(bool ordered) {
+ fbb_.AddElement<uint8_t>(CategoryMetadata::VT_ORDERED, static_cast<uint8_t>(ordered), 0);
+ }
+ explicit CategoryMetadataBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ CategoryMetadataBuilder &operator=(const CategoryMetadataBuilder &);
+ flatbuffers::Offset<CategoryMetadata> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<CategoryMetadata>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<CategoryMetadata> CreateCategoryMetadata(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<arrow::ipc::feather::fbs::PrimitiveArray> levels = 0,
+ bool ordered = false) {
+ CategoryMetadataBuilder builder_(_fbb);
+ builder_.add_levels(levels);
+ builder_.add_ordered(ordered);
+ return builder_.Finish();
+}
+
+struct TimestampMetadata FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef TimestampMetadataBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_UNIT = 4,
+ VT_TIMEZONE = 6
+ };
+ arrow::ipc::feather::fbs::TimeUnit unit() const {
+ return static_cast<arrow::ipc::feather::fbs::TimeUnit>(GetField<int8_t>(VT_UNIT, 0));
+ }
+ /// Timestamp data is assumed to be UTC, but the time zone is stored here for
+ /// presentation as localized
+ const flatbuffers::String *timezone() const {
+ return GetPointer<const flatbuffers::String *>(VT_TIMEZONE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_UNIT) &&
+ VerifyOffset(verifier, VT_TIMEZONE) &&
+ verifier.VerifyString(timezone()) &&
+ verifier.EndTable();
+ }
+};
+
+struct TimestampMetadataBuilder {
+ typedef TimestampMetadata Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_unit(arrow::ipc::feather::fbs::TimeUnit unit) {
+ fbb_.AddElement<int8_t>(TimestampMetadata::VT_UNIT, static_cast<int8_t>(unit), 0);
+ }
+ void add_timezone(flatbuffers::Offset<flatbuffers::String> timezone) {
+ fbb_.AddOffset(TimestampMetadata::VT_TIMEZONE, timezone);
+ }
+ explicit TimestampMetadataBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TimestampMetadataBuilder &operator=(const TimestampMetadataBuilder &);
+ flatbuffers::Offset<TimestampMetadata> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TimestampMetadata>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TimestampMetadata> CreateTimestampMetadata(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ arrow::ipc::feather::fbs::TimeUnit unit = arrow::ipc::feather::fbs::TimeUnit::SECOND,
+ flatbuffers::Offset<flatbuffers::String> timezone = 0) {
+ TimestampMetadataBuilder builder_(_fbb);
+ builder_.add_timezone(timezone);
+ builder_.add_unit(unit);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TimestampMetadata> CreateTimestampMetadataDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ arrow::ipc::feather::fbs::TimeUnit unit = arrow::ipc::feather::fbs::TimeUnit::SECOND,
+ const char *timezone = nullptr) {
+ auto timezone__ = timezone ? _fbb.CreateString(timezone) : 0;
+ return arrow::ipc::feather::fbs::CreateTimestampMetadata(
+ _fbb,
+ unit,
+ timezone__);
+}
+
+struct DateMetadata FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef DateMetadataBuilder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+};
+
+struct DateMetadataBuilder {
+ typedef DateMetadata Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit DateMetadataBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ DateMetadataBuilder &operator=(const DateMetadataBuilder &);
+ flatbuffers::Offset<DateMetadata> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<DateMetadata>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<DateMetadata> CreateDateMetadata(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ DateMetadataBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct TimeMetadata FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef TimeMetadataBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_UNIT = 4
+ };
+ arrow::ipc::feather::fbs::TimeUnit unit() const {
+ return static_cast<arrow::ipc::feather::fbs::TimeUnit>(GetField<int8_t>(VT_UNIT, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_UNIT) &&
+ verifier.EndTable();
+ }
+};
+
+struct TimeMetadataBuilder {
+ typedef TimeMetadata Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_unit(arrow::ipc::feather::fbs::TimeUnit unit) {
+ fbb_.AddElement<int8_t>(TimeMetadata::VT_UNIT, static_cast<int8_t>(unit), 0);
+ }
+ explicit TimeMetadataBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TimeMetadataBuilder &operator=(const TimeMetadataBuilder &);
+ flatbuffers::Offset<TimeMetadata> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TimeMetadata>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TimeMetadata> CreateTimeMetadata(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ arrow::ipc::feather::fbs::TimeUnit unit = arrow::ipc::feather::fbs::TimeUnit::SECOND) {
+ TimeMetadataBuilder builder_(_fbb);
+ builder_.add_unit(unit);
+ return builder_.Finish();
+}
+
+struct Column FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ColumnBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_NAME = 4,
+ VT_VALUES = 6,
+ VT_METADATA_TYPE = 8,
+ VT_METADATA = 10,
+ VT_USER_METADATA = 12
+ };
+ const flatbuffers::String *name() const {
+ return GetPointer<const flatbuffers::String *>(VT_NAME);
+ }
+ const arrow::ipc::feather::fbs::PrimitiveArray *values() const {
+ return GetPointer<const arrow::ipc::feather::fbs::PrimitiveArray *>(VT_VALUES);
+ }
+ arrow::ipc::feather::fbs::TypeMetadata metadata_type() const {
+ return static_cast<arrow::ipc::feather::fbs::TypeMetadata>(GetField<uint8_t>(VT_METADATA_TYPE, 0));
+ }
+ const void *metadata() const {
+ return GetPointer<const void *>(VT_METADATA);
+ }
+ template<typename T> const T *metadata_as() const;
+ const arrow::ipc::feather::fbs::CategoryMetadata *metadata_as_CategoryMetadata() const {
+ return metadata_type() == arrow::ipc::feather::fbs::TypeMetadata::CategoryMetadata ? static_cast<const arrow::ipc::feather::fbs::CategoryMetadata *>(metadata()) : nullptr;
+ }
+ const arrow::ipc::feather::fbs::TimestampMetadata *metadata_as_TimestampMetadata() const {
+ return metadata_type() == arrow::ipc::feather::fbs::TypeMetadata::TimestampMetadata ? static_cast<const arrow::ipc::feather::fbs::TimestampMetadata *>(metadata()) : nullptr;
+ }
+ const arrow::ipc::feather::fbs::DateMetadata *metadata_as_DateMetadata() const {
+ return metadata_type() == arrow::ipc::feather::fbs::TypeMetadata::DateMetadata ? static_cast<const arrow::ipc::feather::fbs::DateMetadata *>(metadata()) : nullptr;
+ }
+ const arrow::ipc::feather::fbs::TimeMetadata *metadata_as_TimeMetadata() const {
+ return metadata_type() == arrow::ipc::feather::fbs::TypeMetadata::TimeMetadata ? static_cast<const arrow::ipc::feather::fbs::TimeMetadata *>(metadata()) : nullptr;
+ }
+ /// This should (probably) be JSON
+ const flatbuffers::String *user_metadata() const {
+ return GetPointer<const flatbuffers::String *>(VT_USER_METADATA);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_NAME) &&
+ verifier.VerifyString(name()) &&
+ VerifyOffset(verifier, VT_VALUES) &&
+ verifier.VerifyTable(values()) &&
+ VerifyField<uint8_t>(verifier, VT_METADATA_TYPE) &&
+ VerifyOffset(verifier, VT_METADATA) &&
+ VerifyTypeMetadata(verifier, metadata(), metadata_type()) &&
+ VerifyOffset(verifier, VT_USER_METADATA) &&
+ verifier.VerifyString(user_metadata()) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const arrow::ipc::feather::fbs::CategoryMetadata *Column::metadata_as<arrow::ipc::feather::fbs::CategoryMetadata>() const {
+ return metadata_as_CategoryMetadata();
+}
+
+template<> inline const arrow::ipc::feather::fbs::TimestampMetadata *Column::metadata_as<arrow::ipc::feather::fbs::TimestampMetadata>() const {
+ return metadata_as_TimestampMetadata();
+}
+
+template<> inline const arrow::ipc::feather::fbs::DateMetadata *Column::metadata_as<arrow::ipc::feather::fbs::DateMetadata>() const {
+ return metadata_as_DateMetadata();
+}
+
+template<> inline const arrow::ipc::feather::fbs::TimeMetadata *Column::metadata_as<arrow::ipc::feather::fbs::TimeMetadata>() const {
+ return metadata_as_TimeMetadata();
+}
+
+struct ColumnBuilder {
+ typedef Column Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_name(flatbuffers::Offset<flatbuffers::String> name) {
+ fbb_.AddOffset(Column::VT_NAME, name);
+ }
+ void add_values(flatbuffers::Offset<arrow::ipc::feather::fbs::PrimitiveArray> values) {
+ fbb_.AddOffset(Column::VT_VALUES, values);
+ }
+ void add_metadata_type(arrow::ipc::feather::fbs::TypeMetadata metadata_type) {
+ fbb_.AddElement<uint8_t>(Column::VT_METADATA_TYPE, static_cast<uint8_t>(metadata_type), 0);
+ }
+ void add_metadata(flatbuffers::Offset<void> metadata) {
+ fbb_.AddOffset(Column::VT_METADATA, metadata);
+ }
+ void add_user_metadata(flatbuffers::Offset<flatbuffers::String> user_metadata) {
+ fbb_.AddOffset(Column::VT_USER_METADATA, user_metadata);
+ }
+ explicit ColumnBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ColumnBuilder &operator=(const ColumnBuilder &);
+ flatbuffers::Offset<Column> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Column>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Column> CreateColumn(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> name = 0,
+ flatbuffers::Offset<arrow::ipc::feather::fbs::PrimitiveArray> values = 0,
+ arrow::ipc::feather::fbs::TypeMetadata metadata_type = arrow::ipc::feather::fbs::TypeMetadata::NONE,
+ flatbuffers::Offset<void> metadata = 0,
+ flatbuffers::Offset<flatbuffers::String> user_metadata = 0) {
+ ColumnBuilder builder_(_fbb);
+ builder_.add_user_metadata(user_metadata);
+ builder_.add_metadata(metadata);
+ builder_.add_values(values);
+ builder_.add_name(name);
+ builder_.add_metadata_type(metadata_type);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Column> CreateColumnDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *name = nullptr,
+ flatbuffers::Offset<arrow::ipc::feather::fbs::PrimitiveArray> values = 0,
+ arrow::ipc::feather::fbs::TypeMetadata metadata_type = arrow::ipc::feather::fbs::TypeMetadata::NONE,
+ flatbuffers::Offset<void> metadata = 0,
+ const char *user_metadata = nullptr) {
+ auto name__ = name ? _fbb.CreateString(name) : 0;
+ auto user_metadata__ = user_metadata ? _fbb.CreateString(user_metadata) : 0;
+ return arrow::ipc::feather::fbs::CreateColumn(
+ _fbb,
+ name__,
+ values,
+ metadata_type,
+ metadata,
+ user_metadata__);
+}
+
+struct CTable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef CTableBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_DESCRIPTION = 4,
+ VT_NUM_ROWS = 6,
+ VT_COLUMNS = 8,
+ VT_VERSION = 10,
+ VT_METADATA = 12
+ };
+ /// Some text (or a name) metadata about what the file is, optional
+ const flatbuffers::String *description() const {
+ return GetPointer<const flatbuffers::String *>(VT_DESCRIPTION);
+ }
+ int64_t num_rows() const {
+ return GetField<int64_t>(VT_NUM_ROWS, 0);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<arrow::ipc::feather::fbs::Column>> *columns() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<arrow::ipc::feather::fbs::Column>> *>(VT_COLUMNS);
+ }
+ /// Version number of the Feather format
+ ///
+ /// Internal versions 0, 1, and 2: Implemented in Apache Arrow <= 0.16.0 and
+ /// wesm/feather. Uses "custom" metadata defined in this file.
+ int32_t version() const {
+ return GetField<int32_t>(VT_VERSION, 0);
+ }
+ /// Table metadata (likely JSON), not yet used
+ const flatbuffers::String *metadata() const {
+ return GetPointer<const flatbuffers::String *>(VT_METADATA);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_DESCRIPTION) &&
+ verifier.VerifyString(description()) &&
+ VerifyField<int64_t>(verifier, VT_NUM_ROWS) &&
+ VerifyOffset(verifier, VT_COLUMNS) &&
+ verifier.VerifyVector(columns()) &&
+ verifier.VerifyVectorOfTables(columns()) &&
+ VerifyField<int32_t>(verifier, VT_VERSION) &&
+ VerifyOffset(verifier, VT_METADATA) &&
+ verifier.VerifyString(metadata()) &&
+ verifier.EndTable();
+ }
+};
+
+struct CTableBuilder {
+ typedef CTable Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_description(flatbuffers::Offset<flatbuffers::String> description) {
+ fbb_.AddOffset(CTable::VT_DESCRIPTION, description);
+ }
+ void add_num_rows(int64_t num_rows) {
+ fbb_.AddElement<int64_t>(CTable::VT_NUM_ROWS, num_rows, 0);
+ }
+ void add_columns(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<arrow::ipc::feather::fbs::Column>>> columns) {
+ fbb_.AddOffset(CTable::VT_COLUMNS, columns);
+ }
+ void add_version(int32_t version) {
+ fbb_.AddElement<int32_t>(CTable::VT_VERSION, version, 0);
+ }
+ void add_metadata(flatbuffers::Offset<flatbuffers::String> metadata) {
+ fbb_.AddOffset(CTable::VT_METADATA, metadata);
+ }
+ explicit CTableBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ CTableBuilder &operator=(const CTableBuilder &);
+ flatbuffers::Offset<CTable> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<CTable>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<CTable> CreateCTable(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> description = 0,
+ int64_t num_rows = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<arrow::ipc::feather::fbs::Column>>> columns = 0,
+ int32_t version = 0,
+ flatbuffers::Offset<flatbuffers::String> metadata = 0) {
+ CTableBuilder builder_(_fbb);
+ builder_.add_num_rows(num_rows);
+ builder_.add_metadata(metadata);
+ builder_.add_version(version);
+ builder_.add_columns(columns);
+ builder_.add_description(description);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<CTable> CreateCTableDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *description = nullptr,
+ int64_t num_rows = 0,
+ const std::vector<flatbuffers::Offset<arrow::ipc::feather::fbs::Column>> *columns = nullptr,
+ int32_t version = 0,
+ const char *metadata = nullptr) {
+ auto description__ = description ? _fbb.CreateString(description) : 0;
+ auto columns__ = columns ? _fbb.CreateVector<flatbuffers::Offset<arrow::ipc::feather::fbs::Column>>(*columns) : 0;
+ auto metadata__ = metadata ? _fbb.CreateString(metadata) : 0;
+ return arrow::ipc::feather::fbs::CreateCTable(
+ _fbb,
+ description__,
+ num_rows,
+ columns__,
+ version,
+ metadata__);
+}
+
+inline bool VerifyTypeMetadata(flatbuffers::Verifier &verifier, const void *obj, TypeMetadata type) {
+ switch (type) {
+ case TypeMetadata::NONE: {
+ return true;
+ }
+ case TypeMetadata::CategoryMetadata: {
+ auto ptr = reinterpret_cast<const arrow::ipc::feather::fbs::CategoryMetadata *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case TypeMetadata::TimestampMetadata: {
+ auto ptr = reinterpret_cast<const arrow::ipc::feather::fbs::TimestampMetadata *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case TypeMetadata::DateMetadata: {
+ auto ptr = reinterpret_cast<const arrow::ipc::feather::fbs::DateMetadata *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case TypeMetadata::TimeMetadata: {
+ auto ptr = reinterpret_cast<const arrow::ipc::feather::fbs::TimeMetadata *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return true;
+ }
+}
+
+inline bool VerifyTypeMetadataVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifyTypeMetadata(
+ verifier, values->Get(i), types->GetEnum<TypeMetadata>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline const arrow::ipc::feather::fbs::CTable *GetCTable(const void *buf) {
+ return flatbuffers::GetRoot<arrow::ipc::feather::fbs::CTable>(buf);
+}
+
+inline const arrow::ipc::feather::fbs::CTable *GetSizePrefixedCTable(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<arrow::ipc::feather::fbs::CTable>(buf);
+}
+
+inline bool VerifyCTableBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifyBuffer<arrow::ipc::feather::fbs::CTable>(nullptr);
+}
+
+inline bool VerifySizePrefixedCTableBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<arrow::ipc::feather::fbs::CTable>(nullptr);
+}
+
+inline void FinishCTableBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<arrow::ipc::feather::fbs::CTable> root) {
+ fbb.Finish(root);
+}
+
+inline void FinishSizePrefixedCTableBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<arrow::ipc::feather::fbs::CTable> root) {
+ fbb.FinishSizePrefixed(root);
+}
+
+} // namespace fbs
+} // namespace feather
+} // namespace ipc
+} // namespace arrow
+
+#endif // FLATBUFFERS_GENERATED_FEATHER_ARROW_IPC_FEATHER_FBS_H_
diff --git a/src/arrow/cpp/src/generated/parquet_constants.cpp b/src/arrow/cpp/src/generated/parquet_constants.cpp
new file mode 100644
index 000000000..b1b4ce626
--- /dev/null
+++ b/src/arrow/cpp/src/generated/parquet_constants.cpp
@@ -0,0 +1,17 @@
+/**
+ * Autogenerated by Thrift Compiler (0.13.0)
+ *
+ * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
+ * @generated
+ */
+#include "parquet_constants.h"
+
+namespace parquet { namespace format {
+
+const parquetConstants g_parquet_constants;
+
+parquetConstants::parquetConstants() {
+}
+
+}} // namespace
+
diff --git a/src/arrow/cpp/src/generated/parquet_constants.h b/src/arrow/cpp/src/generated/parquet_constants.h
new file mode 100644
index 000000000..1e288c7cd
--- /dev/null
+++ b/src/arrow/cpp/src/generated/parquet_constants.h
@@ -0,0 +1,24 @@
+/**
+ * Autogenerated by Thrift Compiler (0.13.0)
+ *
+ * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
+ * @generated
+ */
+#ifndef parquet_CONSTANTS_H
+#define parquet_CONSTANTS_H
+
+#include "parquet_types.h"
+
+namespace parquet { namespace format {
+
+class parquetConstants {
+ public:
+ parquetConstants();
+
+};
+
+extern const parquetConstants g_parquet_constants;
+
+}} // namespace
+
+#endif
diff --git a/src/arrow/cpp/src/generated/parquet_types.cpp b/src/arrow/cpp/src/generated/parquet_types.cpp
new file mode 100644
index 000000000..cccd92e2e
--- /dev/null
+++ b/src/arrow/cpp/src/generated/parquet_types.cpp
@@ -0,0 +1,7413 @@
+/**
+ * Autogenerated by Thrift Compiler (0.13.0)
+ *
+ * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
+ * @generated
+ */
+#include "parquet_types.h"
+
+#include <algorithm>
+#include <ostream>
+
+#include <thrift/TToString.h>
+
+namespace parquet { namespace format {
+
+int _kTypeValues[] = {
+ Type::BOOLEAN,
+ Type::INT32,
+ Type::INT64,
+ Type::INT96,
+ Type::FLOAT,
+ Type::DOUBLE,
+ Type::BYTE_ARRAY,
+ Type::FIXED_LEN_BYTE_ARRAY
+};
+const char* _kTypeNames[] = {
+ "BOOLEAN",
+ "INT32",
+ "INT64",
+ "INT96",
+ "FLOAT",
+ "DOUBLE",
+ "BYTE_ARRAY",
+ "FIXED_LEN_BYTE_ARRAY"
+};
+const std::map<int, const char*> _Type_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(8, _kTypeValues, _kTypeNames), ::apache::thrift::TEnumIterator(-1, NULL, NULL));
+
+std::ostream& operator<<(std::ostream& out, const Type::type& val) {
+ std::map<int, const char*>::const_iterator it = _Type_VALUES_TO_NAMES.find(val);
+ if (it != _Type_VALUES_TO_NAMES.end()) {
+ out << it->second;
+ } else {
+ out << static_cast<int>(val);
+ }
+ return out;
+}
+
+std::string to_string(const Type::type& val) {
+ std::map<int, const char*>::const_iterator it = _Type_VALUES_TO_NAMES.find(val);
+ if (it != _Type_VALUES_TO_NAMES.end()) {
+ return std::string(it->second);
+ } else {
+ return std::to_string(static_cast<int>(val));
+ }
+}
+
+int _kConvertedTypeValues[] = {
+ ConvertedType::UTF8,
+ ConvertedType::MAP,
+ ConvertedType::MAP_KEY_VALUE,
+ ConvertedType::LIST,
+ ConvertedType::ENUM,
+ ConvertedType::DECIMAL,
+ ConvertedType::DATE,
+ ConvertedType::TIME_MILLIS,
+ ConvertedType::TIME_MICROS,
+ ConvertedType::TIMESTAMP_MILLIS,
+ ConvertedType::TIMESTAMP_MICROS,
+ ConvertedType::UINT_8,
+ ConvertedType::UINT_16,
+ ConvertedType::UINT_32,
+ ConvertedType::UINT_64,
+ ConvertedType::INT_8,
+ ConvertedType::INT_16,
+ ConvertedType::INT_32,
+ ConvertedType::INT_64,
+ ConvertedType::JSON,
+ ConvertedType::BSON,
+ ConvertedType::INTERVAL
+};
+const char* _kConvertedTypeNames[] = {
+ "UTF8",
+ "MAP",
+ "MAP_KEY_VALUE",
+ "LIST",
+ "ENUM",
+ "DECIMAL",
+ "DATE",
+ "TIME_MILLIS",
+ "TIME_MICROS",
+ "TIMESTAMP_MILLIS",
+ "TIMESTAMP_MICROS",
+ "UINT_8",
+ "UINT_16",
+ "UINT_32",
+ "UINT_64",
+ "INT_8",
+ "INT_16",
+ "INT_32",
+ "INT_64",
+ "JSON",
+ "BSON",
+ "INTERVAL"
+};
+const std::map<int, const char*> _ConvertedType_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(22, _kConvertedTypeValues, _kConvertedTypeNames), ::apache::thrift::TEnumIterator(-1, NULL, NULL));
+
+std::ostream& operator<<(std::ostream& out, const ConvertedType::type& val) {
+ std::map<int, const char*>::const_iterator it = _ConvertedType_VALUES_TO_NAMES.find(val);
+ if (it != _ConvertedType_VALUES_TO_NAMES.end()) {
+ out << it->second;
+ } else {
+ out << static_cast<int>(val);
+ }
+ return out;
+}
+
+std::string to_string(const ConvertedType::type& val) {
+ std::map<int, const char*>::const_iterator it = _ConvertedType_VALUES_TO_NAMES.find(val);
+ if (it != _ConvertedType_VALUES_TO_NAMES.end()) {
+ return std::string(it->second);
+ } else {
+ return std::to_string(static_cast<int>(val));
+ }
+}
+
+int _kFieldRepetitionTypeValues[] = {
+ FieldRepetitionType::REQUIRED,
+ FieldRepetitionType::OPTIONAL,
+ FieldRepetitionType::REPEATED
+};
+const char* _kFieldRepetitionTypeNames[] = {
+ "REQUIRED",
+ "OPTIONAL",
+ "REPEATED"
+};
+const std::map<int, const char*> _FieldRepetitionType_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(3, _kFieldRepetitionTypeValues, _kFieldRepetitionTypeNames), ::apache::thrift::TEnumIterator(-1, NULL, NULL));
+
+std::ostream& operator<<(std::ostream& out, const FieldRepetitionType::type& val) {
+ std::map<int, const char*>::const_iterator it = _FieldRepetitionType_VALUES_TO_NAMES.find(val);
+ if (it != _FieldRepetitionType_VALUES_TO_NAMES.end()) {
+ out << it->second;
+ } else {
+ out << static_cast<int>(val);
+ }
+ return out;
+}
+
+std::string to_string(const FieldRepetitionType::type& val) {
+ std::map<int, const char*>::const_iterator it = _FieldRepetitionType_VALUES_TO_NAMES.find(val);
+ if (it != _FieldRepetitionType_VALUES_TO_NAMES.end()) {
+ return std::string(it->second);
+ } else {
+ return std::to_string(static_cast<int>(val));
+ }
+}
+
+int _kEncodingValues[] = {
+ Encoding::PLAIN,
+ Encoding::PLAIN_DICTIONARY,
+ Encoding::RLE,
+ Encoding::BIT_PACKED,
+ Encoding::DELTA_BINARY_PACKED,
+ Encoding::DELTA_LENGTH_BYTE_ARRAY,
+ Encoding::DELTA_BYTE_ARRAY,
+ Encoding::RLE_DICTIONARY,
+ Encoding::BYTE_STREAM_SPLIT
+};
+const char* _kEncodingNames[] = {
+ "PLAIN",
+ "PLAIN_DICTIONARY",
+ "RLE",
+ "BIT_PACKED",
+ "DELTA_BINARY_PACKED",
+ "DELTA_LENGTH_BYTE_ARRAY",
+ "DELTA_BYTE_ARRAY",
+ "RLE_DICTIONARY",
+ "BYTE_STREAM_SPLIT"
+};
+const std::map<int, const char*> _Encoding_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(9, _kEncodingValues, _kEncodingNames), ::apache::thrift::TEnumIterator(-1, NULL, NULL));
+
+std::ostream& operator<<(std::ostream& out, const Encoding::type& val) {
+ std::map<int, const char*>::const_iterator it = _Encoding_VALUES_TO_NAMES.find(val);
+ if (it != _Encoding_VALUES_TO_NAMES.end()) {
+ out << it->second;
+ } else {
+ out << static_cast<int>(val);
+ }
+ return out;
+}
+
+std::string to_string(const Encoding::type& val) {
+ std::map<int, const char*>::const_iterator it = _Encoding_VALUES_TO_NAMES.find(val);
+ if (it != _Encoding_VALUES_TO_NAMES.end()) {
+ return std::string(it->second);
+ } else {
+ return std::to_string(static_cast<int>(val));
+ }
+}
+
+int _kCompressionCodecValues[] = {
+ CompressionCodec::UNCOMPRESSED,
+ CompressionCodec::SNAPPY,
+ CompressionCodec::GZIP,
+ CompressionCodec::LZO,
+ CompressionCodec::BROTLI,
+ CompressionCodec::LZ4,
+ CompressionCodec::ZSTD,
+ CompressionCodec::LZ4_RAW
+};
+const char* _kCompressionCodecNames[] = {
+ "UNCOMPRESSED",
+ "SNAPPY",
+ "GZIP",
+ "LZO",
+ "BROTLI",
+ "LZ4",
+ "ZSTD",
+ "LZ4_RAW"
+};
+const std::map<int, const char*> _CompressionCodec_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(8, _kCompressionCodecValues, _kCompressionCodecNames), ::apache::thrift::TEnumIterator(-1, NULL, NULL));
+
+std::ostream& operator<<(std::ostream& out, const CompressionCodec::type& val) {
+ std::map<int, const char*>::const_iterator it = _CompressionCodec_VALUES_TO_NAMES.find(val);
+ if (it != _CompressionCodec_VALUES_TO_NAMES.end()) {
+ out << it->second;
+ } else {
+ out << static_cast<int>(val);
+ }
+ return out;
+}
+
+std::string to_string(const CompressionCodec::type& val) {
+ std::map<int, const char*>::const_iterator it = _CompressionCodec_VALUES_TO_NAMES.find(val);
+ if (it != _CompressionCodec_VALUES_TO_NAMES.end()) {
+ return std::string(it->second);
+ } else {
+ return std::to_string(static_cast<int>(val));
+ }
+}
+
+int _kPageTypeValues[] = {
+ PageType::DATA_PAGE,
+ PageType::INDEX_PAGE,
+ PageType::DICTIONARY_PAGE,
+ PageType::DATA_PAGE_V2
+};
+const char* _kPageTypeNames[] = {
+ "DATA_PAGE",
+ "INDEX_PAGE",
+ "DICTIONARY_PAGE",
+ "DATA_PAGE_V2"
+};
+const std::map<int, const char*> _PageType_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(4, _kPageTypeValues, _kPageTypeNames), ::apache::thrift::TEnumIterator(-1, NULL, NULL));
+
+std::ostream& operator<<(std::ostream& out, const PageType::type& val) {
+ std::map<int, const char*>::const_iterator it = _PageType_VALUES_TO_NAMES.find(val);
+ if (it != _PageType_VALUES_TO_NAMES.end()) {
+ out << it->second;
+ } else {
+ out << static_cast<int>(val);
+ }
+ return out;
+}
+
+std::string to_string(const PageType::type& val) {
+ std::map<int, const char*>::const_iterator it = _PageType_VALUES_TO_NAMES.find(val);
+ if (it != _PageType_VALUES_TO_NAMES.end()) {
+ return std::string(it->second);
+ } else {
+ return std::to_string(static_cast<int>(val));
+ }
+}
+
+int _kBoundaryOrderValues[] = {
+ BoundaryOrder::UNORDERED,
+ BoundaryOrder::ASCENDING,
+ BoundaryOrder::DESCENDING
+};
+const char* _kBoundaryOrderNames[] = {
+ "UNORDERED",
+ "ASCENDING",
+ "DESCENDING"
+};
+const std::map<int, const char*> _BoundaryOrder_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(3, _kBoundaryOrderValues, _kBoundaryOrderNames), ::apache::thrift::TEnumIterator(-1, NULL, NULL));
+
+std::ostream& operator<<(std::ostream& out, const BoundaryOrder::type& val) {
+ std::map<int, const char*>::const_iterator it = _BoundaryOrder_VALUES_TO_NAMES.find(val);
+ if (it != _BoundaryOrder_VALUES_TO_NAMES.end()) {
+ out << it->second;
+ } else {
+ out << static_cast<int>(val);
+ }
+ return out;
+}
+
+std::string to_string(const BoundaryOrder::type& val) {
+ std::map<int, const char*>::const_iterator it = _BoundaryOrder_VALUES_TO_NAMES.find(val);
+ if (it != _BoundaryOrder_VALUES_TO_NAMES.end()) {
+ return std::string(it->second);
+ } else {
+ return std::to_string(static_cast<int>(val));
+ }
+}
+
+
+Statistics::~Statistics() noexcept {
+}
+
+
+void Statistics::__set_max(const std::string& val) {
+ this->max = val;
+__isset.max = true;
+}
+
+void Statistics::__set_min(const std::string& val) {
+ this->min = val;
+__isset.min = true;
+}
+
+void Statistics::__set_null_count(const int64_t val) {
+ this->null_count = val;
+__isset.null_count = true;
+}
+
+void Statistics::__set_distinct_count(const int64_t val) {
+ this->distinct_count = val;
+__isset.distinct_count = true;
+}
+
+void Statistics::__set_max_value(const std::string& val) {
+ this->max_value = val;
+__isset.max_value = true;
+}
+
+void Statistics::__set_min_value(const std::string& val) {
+ this->min_value = val;
+__isset.min_value = true;
+}
+std::ostream& operator<<(std::ostream& out, const Statistics& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t Statistics::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readBinary(this->max);
+ this->__isset.max = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readBinary(this->min);
+ this->__isset.min = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->null_count);
+ this->__isset.null_count = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 4:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->distinct_count);
+ this->__isset.distinct_count = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 5:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readBinary(this->max_value);
+ this->__isset.max_value = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 6:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readBinary(this->min_value);
+ this->__isset.min_value = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t Statistics::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("Statistics");
+
+ if (this->__isset.max) {
+ xfer += oprot->writeFieldBegin("max", ::apache::thrift::protocol::T_STRING, 1);
+ xfer += oprot->writeBinary(this->max);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.min) {
+ xfer += oprot->writeFieldBegin("min", ::apache::thrift::protocol::T_STRING, 2);
+ xfer += oprot->writeBinary(this->min);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.null_count) {
+ xfer += oprot->writeFieldBegin("null_count", ::apache::thrift::protocol::T_I64, 3);
+ xfer += oprot->writeI64(this->null_count);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.distinct_count) {
+ xfer += oprot->writeFieldBegin("distinct_count", ::apache::thrift::protocol::T_I64, 4);
+ xfer += oprot->writeI64(this->distinct_count);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.max_value) {
+ xfer += oprot->writeFieldBegin("max_value", ::apache::thrift::protocol::T_STRING, 5);
+ xfer += oprot->writeBinary(this->max_value);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.min_value) {
+ xfer += oprot->writeFieldBegin("min_value", ::apache::thrift::protocol::T_STRING, 6);
+ xfer += oprot->writeBinary(this->min_value);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(Statistics &a, Statistics &b) {
+ using ::std::swap;
+ swap(a.max, b.max);
+ swap(a.min, b.min);
+ swap(a.null_count, b.null_count);
+ swap(a.distinct_count, b.distinct_count);
+ swap(a.max_value, b.max_value);
+ swap(a.min_value, b.min_value);
+ swap(a.__isset, b.__isset);
+}
+
+Statistics::Statistics(const Statistics& other0) {
+ max = other0.max;
+ min = other0.min;
+ null_count = other0.null_count;
+ distinct_count = other0.distinct_count;
+ max_value = other0.max_value;
+ min_value = other0.min_value;
+ __isset = other0.__isset;
+}
+Statistics& Statistics::operator=(const Statistics& other1) {
+ max = other1.max;
+ min = other1.min;
+ null_count = other1.null_count;
+ distinct_count = other1.distinct_count;
+ max_value = other1.max_value;
+ min_value = other1.min_value;
+ __isset = other1.__isset;
+ return *this;
+}
+void Statistics::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "Statistics(";
+ out << "max="; (__isset.max ? (out << to_string(max)) : (out << "<null>"));
+ out << ", " << "min="; (__isset.min ? (out << to_string(min)) : (out << "<null>"));
+ out << ", " << "null_count="; (__isset.null_count ? (out << to_string(null_count)) : (out << "<null>"));
+ out << ", " << "distinct_count="; (__isset.distinct_count ? (out << to_string(distinct_count)) : (out << "<null>"));
+ out << ", " << "max_value="; (__isset.max_value ? (out << to_string(max_value)) : (out << "<null>"));
+ out << ", " << "min_value="; (__isset.min_value ? (out << to_string(min_value)) : (out << "<null>"));
+ out << ")";
+}
+
+
+StringType::~StringType() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const StringType& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t StringType::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t StringType::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("StringType");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(StringType &a, StringType &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+StringType::StringType(const StringType& other2) {
+ (void) other2;
+}
+StringType& StringType::operator=(const StringType& other3) {
+ (void) other3;
+ return *this;
+}
+void StringType::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "StringType(";
+ out << ")";
+}
+
+
+UUIDType::~UUIDType() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const UUIDType& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t UUIDType::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t UUIDType::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("UUIDType");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(UUIDType &a, UUIDType &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+UUIDType::UUIDType(const UUIDType& other4) {
+ (void) other4;
+}
+UUIDType& UUIDType::operator=(const UUIDType& other5) {
+ (void) other5;
+ return *this;
+}
+void UUIDType::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "UUIDType(";
+ out << ")";
+}
+
+
+MapType::~MapType() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const MapType& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t MapType::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t MapType::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("MapType");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(MapType &a, MapType &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+MapType::MapType(const MapType& other6) {
+ (void) other6;
+}
+MapType& MapType::operator=(const MapType& other7) {
+ (void) other7;
+ return *this;
+}
+void MapType::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "MapType(";
+ out << ")";
+}
+
+
+ListType::~ListType() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const ListType& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t ListType::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t ListType::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("ListType");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(ListType &a, ListType &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+ListType::ListType(const ListType& other8) {
+ (void) other8;
+}
+ListType& ListType::operator=(const ListType& other9) {
+ (void) other9;
+ return *this;
+}
+void ListType::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "ListType(";
+ out << ")";
+}
+
+
+EnumType::~EnumType() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const EnumType& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t EnumType::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t EnumType::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("EnumType");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(EnumType &a, EnumType &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+EnumType::EnumType(const EnumType& other10) {
+ (void) other10;
+}
+EnumType& EnumType::operator=(const EnumType& other11) {
+ (void) other11;
+ return *this;
+}
+void EnumType::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "EnumType(";
+ out << ")";
+}
+
+
+DateType::~DateType() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const DateType& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t DateType::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t DateType::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("DateType");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(DateType &a, DateType &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+DateType::DateType(const DateType& other12) {
+ (void) other12;
+}
+DateType& DateType::operator=(const DateType& other13) {
+ (void) other13;
+ return *this;
+}
+void DateType::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "DateType(";
+ out << ")";
+}
+
+
+NullType::~NullType() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const NullType& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t NullType::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t NullType::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("NullType");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(NullType &a, NullType &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+NullType::NullType(const NullType& other14) {
+ (void) other14;
+}
+NullType& NullType::operator=(const NullType& other15) {
+ (void) other15;
+ return *this;
+}
+void NullType::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "NullType(";
+ out << ")";
+}
+
+
+DecimalType::~DecimalType() noexcept {
+}
+
+
+void DecimalType::__set_scale(const int32_t val) {
+ this->scale = val;
+}
+
+void DecimalType::__set_precision(const int32_t val) {
+ this->precision = val;
+}
+std::ostream& operator<<(std::ostream& out, const DecimalType& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t DecimalType::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_scale = false;
+ bool isset_precision = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->scale);
+ isset_scale = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->precision);
+ isset_precision = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_scale)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_precision)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t DecimalType::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("DecimalType");
+
+ xfer += oprot->writeFieldBegin("scale", ::apache::thrift::protocol::T_I32, 1);
+ xfer += oprot->writeI32(this->scale);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("precision", ::apache::thrift::protocol::T_I32, 2);
+ xfer += oprot->writeI32(this->precision);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(DecimalType &a, DecimalType &b) {
+ using ::std::swap;
+ swap(a.scale, b.scale);
+ swap(a.precision, b.precision);
+}
+
+DecimalType::DecimalType(const DecimalType& other16) {
+ scale = other16.scale;
+ precision = other16.precision;
+}
+DecimalType& DecimalType::operator=(const DecimalType& other17) {
+ scale = other17.scale;
+ precision = other17.precision;
+ return *this;
+}
+void DecimalType::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "DecimalType(";
+ out << "scale=" << to_string(scale);
+ out << ", " << "precision=" << to_string(precision);
+ out << ")";
+}
+
+
+MilliSeconds::~MilliSeconds() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const MilliSeconds& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t MilliSeconds::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t MilliSeconds::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("MilliSeconds");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(MilliSeconds &a, MilliSeconds &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+MilliSeconds::MilliSeconds(const MilliSeconds& other18) {
+ (void) other18;
+}
+MilliSeconds& MilliSeconds::operator=(const MilliSeconds& other19) {
+ (void) other19;
+ return *this;
+}
+void MilliSeconds::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "MilliSeconds(";
+ out << ")";
+}
+
+
+MicroSeconds::~MicroSeconds() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const MicroSeconds& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t MicroSeconds::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t MicroSeconds::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("MicroSeconds");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(MicroSeconds &a, MicroSeconds &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+MicroSeconds::MicroSeconds(const MicroSeconds& other20) {
+ (void) other20;
+}
+MicroSeconds& MicroSeconds::operator=(const MicroSeconds& other21) {
+ (void) other21;
+ return *this;
+}
+void MicroSeconds::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "MicroSeconds(";
+ out << ")";
+}
+
+
+NanoSeconds::~NanoSeconds() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const NanoSeconds& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t NanoSeconds::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t NanoSeconds::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("NanoSeconds");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(NanoSeconds &a, NanoSeconds &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+NanoSeconds::NanoSeconds(const NanoSeconds& other22) {
+ (void) other22;
+}
+NanoSeconds& NanoSeconds::operator=(const NanoSeconds& other23) {
+ (void) other23;
+ return *this;
+}
+void NanoSeconds::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "NanoSeconds(";
+ out << ")";
+}
+
+
+TimeUnit::~TimeUnit() noexcept {
+}
+
+
+void TimeUnit::__set_MILLIS(const MilliSeconds& val) {
+ this->MILLIS = val;
+__isset.MILLIS = true;
+}
+
+void TimeUnit::__set_MICROS(const MicroSeconds& val) {
+ this->MICROS = val;
+__isset.MICROS = true;
+}
+
+void TimeUnit::__set_NANOS(const NanoSeconds& val) {
+ this->NANOS = val;
+__isset.NANOS = true;
+}
+std::ostream& operator<<(std::ostream& out, const TimeUnit& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t TimeUnit::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->MILLIS.read(iprot);
+ this->__isset.MILLIS = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->MICROS.read(iprot);
+ this->__isset.MICROS = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->NANOS.read(iprot);
+ this->__isset.NANOS = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t TimeUnit::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("TimeUnit");
+
+ if (this->__isset.MILLIS) {
+ xfer += oprot->writeFieldBegin("MILLIS", ::apache::thrift::protocol::T_STRUCT, 1);
+ xfer += this->MILLIS.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.MICROS) {
+ xfer += oprot->writeFieldBegin("MICROS", ::apache::thrift::protocol::T_STRUCT, 2);
+ xfer += this->MICROS.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.NANOS) {
+ xfer += oprot->writeFieldBegin("NANOS", ::apache::thrift::protocol::T_STRUCT, 3);
+ xfer += this->NANOS.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(TimeUnit &a, TimeUnit &b) {
+ using ::std::swap;
+ swap(a.MILLIS, b.MILLIS);
+ swap(a.MICROS, b.MICROS);
+ swap(a.NANOS, b.NANOS);
+ swap(a.__isset, b.__isset);
+}
+
+TimeUnit::TimeUnit(const TimeUnit& other24) {
+ MILLIS = other24.MILLIS;
+ MICROS = other24.MICROS;
+ NANOS = other24.NANOS;
+ __isset = other24.__isset;
+}
+TimeUnit& TimeUnit::operator=(const TimeUnit& other25) {
+ MILLIS = other25.MILLIS;
+ MICROS = other25.MICROS;
+ NANOS = other25.NANOS;
+ __isset = other25.__isset;
+ return *this;
+}
+void TimeUnit::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "TimeUnit(";
+ out << "MILLIS="; (__isset.MILLIS ? (out << to_string(MILLIS)) : (out << "<null>"));
+ out << ", " << "MICROS="; (__isset.MICROS ? (out << to_string(MICROS)) : (out << "<null>"));
+ out << ", " << "NANOS="; (__isset.NANOS ? (out << to_string(NANOS)) : (out << "<null>"));
+ out << ")";
+}
+
+
+TimestampType::~TimestampType() noexcept {
+}
+
+
+void TimestampType::__set_isAdjustedToUTC(const bool val) {
+ this->isAdjustedToUTC = val;
+}
+
+void TimestampType::__set_unit(const TimeUnit& val) {
+ this->unit = val;
+}
+std::ostream& operator<<(std::ostream& out, const TimestampType& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t TimestampType::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_isAdjustedToUTC = false;
+ bool isset_unit = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_BOOL) {
+ xfer += iprot->readBool(this->isAdjustedToUTC);
+ isset_isAdjustedToUTC = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->unit.read(iprot);
+ isset_unit = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_isAdjustedToUTC)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_unit)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t TimestampType::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("TimestampType");
+
+ xfer += oprot->writeFieldBegin("isAdjustedToUTC", ::apache::thrift::protocol::T_BOOL, 1);
+ xfer += oprot->writeBool(this->isAdjustedToUTC);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("unit", ::apache::thrift::protocol::T_STRUCT, 2);
+ xfer += this->unit.write(oprot);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(TimestampType &a, TimestampType &b) {
+ using ::std::swap;
+ swap(a.isAdjustedToUTC, b.isAdjustedToUTC);
+ swap(a.unit, b.unit);
+}
+
+TimestampType::TimestampType(const TimestampType& other26) {
+ isAdjustedToUTC = other26.isAdjustedToUTC;
+ unit = other26.unit;
+}
+TimestampType& TimestampType::operator=(const TimestampType& other27) {
+ isAdjustedToUTC = other27.isAdjustedToUTC;
+ unit = other27.unit;
+ return *this;
+}
+void TimestampType::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "TimestampType(";
+ out << "isAdjustedToUTC=" << to_string(isAdjustedToUTC);
+ out << ", " << "unit=" << to_string(unit);
+ out << ")";
+}
+
+
+TimeType::~TimeType() noexcept {
+}
+
+
+void TimeType::__set_isAdjustedToUTC(const bool val) {
+ this->isAdjustedToUTC = val;
+}
+
+void TimeType::__set_unit(const TimeUnit& val) {
+ this->unit = val;
+}
+std::ostream& operator<<(std::ostream& out, const TimeType& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t TimeType::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_isAdjustedToUTC = false;
+ bool isset_unit = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_BOOL) {
+ xfer += iprot->readBool(this->isAdjustedToUTC);
+ isset_isAdjustedToUTC = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->unit.read(iprot);
+ isset_unit = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_isAdjustedToUTC)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_unit)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t TimeType::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("TimeType");
+
+ xfer += oprot->writeFieldBegin("isAdjustedToUTC", ::apache::thrift::protocol::T_BOOL, 1);
+ xfer += oprot->writeBool(this->isAdjustedToUTC);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("unit", ::apache::thrift::protocol::T_STRUCT, 2);
+ xfer += this->unit.write(oprot);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(TimeType &a, TimeType &b) {
+ using ::std::swap;
+ swap(a.isAdjustedToUTC, b.isAdjustedToUTC);
+ swap(a.unit, b.unit);
+}
+
+TimeType::TimeType(const TimeType& other28) {
+ isAdjustedToUTC = other28.isAdjustedToUTC;
+ unit = other28.unit;
+}
+TimeType& TimeType::operator=(const TimeType& other29) {
+ isAdjustedToUTC = other29.isAdjustedToUTC;
+ unit = other29.unit;
+ return *this;
+}
+void TimeType::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "TimeType(";
+ out << "isAdjustedToUTC=" << to_string(isAdjustedToUTC);
+ out << ", " << "unit=" << to_string(unit);
+ out << ")";
+}
+
+
+IntType::~IntType() noexcept {
+}
+
+
+void IntType::__set_bitWidth(const int8_t val) {
+ this->bitWidth = val;
+}
+
+void IntType::__set_isSigned(const bool val) {
+ this->isSigned = val;
+}
+std::ostream& operator<<(std::ostream& out, const IntType& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t IntType::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_bitWidth = false;
+ bool isset_isSigned = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_BYTE) {
+ xfer += iprot->readByte(this->bitWidth);
+ isset_bitWidth = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_BOOL) {
+ xfer += iprot->readBool(this->isSigned);
+ isset_isSigned = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_bitWidth)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_isSigned)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t IntType::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("IntType");
+
+ xfer += oprot->writeFieldBegin("bitWidth", ::apache::thrift::protocol::T_BYTE, 1);
+ xfer += oprot->writeByte(this->bitWidth);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("isSigned", ::apache::thrift::protocol::T_BOOL, 2);
+ xfer += oprot->writeBool(this->isSigned);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(IntType &a, IntType &b) {
+ using ::std::swap;
+ swap(a.bitWidth, b.bitWidth);
+ swap(a.isSigned, b.isSigned);
+}
+
+IntType::IntType(const IntType& other30) {
+ bitWidth = other30.bitWidth;
+ isSigned = other30.isSigned;
+}
+IntType& IntType::operator=(const IntType& other31) {
+ bitWidth = other31.bitWidth;
+ isSigned = other31.isSigned;
+ return *this;
+}
+void IntType::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "IntType(";
+ out << "bitWidth=" << to_string(bitWidth);
+ out << ", " << "isSigned=" << to_string(isSigned);
+ out << ")";
+}
+
+
+JsonType::~JsonType() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const JsonType& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t JsonType::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t JsonType::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("JsonType");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(JsonType &a, JsonType &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+JsonType::JsonType(const JsonType& other32) {
+ (void) other32;
+}
+JsonType& JsonType::operator=(const JsonType& other33) {
+ (void) other33;
+ return *this;
+}
+void JsonType::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "JsonType(";
+ out << ")";
+}
+
+
+BsonType::~BsonType() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const BsonType& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t BsonType::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t BsonType::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("BsonType");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(BsonType &a, BsonType &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+BsonType::BsonType(const BsonType& other34) {
+ (void) other34;
+}
+BsonType& BsonType::operator=(const BsonType& other35) {
+ (void) other35;
+ return *this;
+}
+void BsonType::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "BsonType(";
+ out << ")";
+}
+
+
+LogicalType::~LogicalType() noexcept {
+}
+
+
+void LogicalType::__set_STRING(const StringType& val) {
+ this->STRING = val;
+__isset.STRING = true;
+}
+
+void LogicalType::__set_MAP(const MapType& val) {
+ this->MAP = val;
+__isset.MAP = true;
+}
+
+void LogicalType::__set_LIST(const ListType& val) {
+ this->LIST = val;
+__isset.LIST = true;
+}
+
+void LogicalType::__set_ENUM(const EnumType& val) {
+ this->ENUM = val;
+__isset.ENUM = true;
+}
+
+void LogicalType::__set_DECIMAL(const DecimalType& val) {
+ this->DECIMAL = val;
+__isset.DECIMAL = true;
+}
+
+void LogicalType::__set_DATE(const DateType& val) {
+ this->DATE = val;
+__isset.DATE = true;
+}
+
+void LogicalType::__set_TIME(const TimeType& val) {
+ this->TIME = val;
+__isset.TIME = true;
+}
+
+void LogicalType::__set_TIMESTAMP(const TimestampType& val) {
+ this->TIMESTAMP = val;
+__isset.TIMESTAMP = true;
+}
+
+void LogicalType::__set_INTEGER(const IntType& val) {
+ this->INTEGER = val;
+__isset.INTEGER = true;
+}
+
+void LogicalType::__set_UNKNOWN(const NullType& val) {
+ this->UNKNOWN = val;
+__isset.UNKNOWN = true;
+}
+
+void LogicalType::__set_JSON(const JsonType& val) {
+ this->JSON = val;
+__isset.JSON = true;
+}
+
+void LogicalType::__set_BSON(const BsonType& val) {
+ this->BSON = val;
+__isset.BSON = true;
+}
+
+void LogicalType::__set_UUID(const UUIDType& val) {
+ this->UUID = val;
+__isset.UUID = true;
+}
+std::ostream& operator<<(std::ostream& out, const LogicalType& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t LogicalType::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->STRING.read(iprot);
+ this->__isset.STRING = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->MAP.read(iprot);
+ this->__isset.MAP = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->LIST.read(iprot);
+ this->__isset.LIST = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 4:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->ENUM.read(iprot);
+ this->__isset.ENUM = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 5:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->DECIMAL.read(iprot);
+ this->__isset.DECIMAL = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 6:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->DATE.read(iprot);
+ this->__isset.DATE = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 7:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->TIME.read(iprot);
+ this->__isset.TIME = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 8:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->TIMESTAMP.read(iprot);
+ this->__isset.TIMESTAMP = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 10:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->INTEGER.read(iprot);
+ this->__isset.INTEGER = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 11:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->UNKNOWN.read(iprot);
+ this->__isset.UNKNOWN = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 12:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->JSON.read(iprot);
+ this->__isset.JSON = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 13:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->BSON.read(iprot);
+ this->__isset.BSON = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 14:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->UUID.read(iprot);
+ this->__isset.UUID = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t LogicalType::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("LogicalType");
+
+ if (this->__isset.STRING) {
+ xfer += oprot->writeFieldBegin("STRING", ::apache::thrift::protocol::T_STRUCT, 1);
+ xfer += this->STRING.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.MAP) {
+ xfer += oprot->writeFieldBegin("MAP", ::apache::thrift::protocol::T_STRUCT, 2);
+ xfer += this->MAP.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.LIST) {
+ xfer += oprot->writeFieldBegin("LIST", ::apache::thrift::protocol::T_STRUCT, 3);
+ xfer += this->LIST.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.ENUM) {
+ xfer += oprot->writeFieldBegin("ENUM", ::apache::thrift::protocol::T_STRUCT, 4);
+ xfer += this->ENUM.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.DECIMAL) {
+ xfer += oprot->writeFieldBegin("DECIMAL", ::apache::thrift::protocol::T_STRUCT, 5);
+ xfer += this->DECIMAL.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.DATE) {
+ xfer += oprot->writeFieldBegin("DATE", ::apache::thrift::protocol::T_STRUCT, 6);
+ xfer += this->DATE.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.TIME) {
+ xfer += oprot->writeFieldBegin("TIME", ::apache::thrift::protocol::T_STRUCT, 7);
+ xfer += this->TIME.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.TIMESTAMP) {
+ xfer += oprot->writeFieldBegin("TIMESTAMP", ::apache::thrift::protocol::T_STRUCT, 8);
+ xfer += this->TIMESTAMP.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.INTEGER) {
+ xfer += oprot->writeFieldBegin("INTEGER", ::apache::thrift::protocol::T_STRUCT, 10);
+ xfer += this->INTEGER.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.UNKNOWN) {
+ xfer += oprot->writeFieldBegin("UNKNOWN", ::apache::thrift::protocol::T_STRUCT, 11);
+ xfer += this->UNKNOWN.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.JSON) {
+ xfer += oprot->writeFieldBegin("JSON", ::apache::thrift::protocol::T_STRUCT, 12);
+ xfer += this->JSON.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.BSON) {
+ xfer += oprot->writeFieldBegin("BSON", ::apache::thrift::protocol::T_STRUCT, 13);
+ xfer += this->BSON.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.UUID) {
+ xfer += oprot->writeFieldBegin("UUID", ::apache::thrift::protocol::T_STRUCT, 14);
+ xfer += this->UUID.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(LogicalType &a, LogicalType &b) {
+ using ::std::swap;
+ swap(a.STRING, b.STRING);
+ swap(a.MAP, b.MAP);
+ swap(a.LIST, b.LIST);
+ swap(a.ENUM, b.ENUM);
+ swap(a.DECIMAL, b.DECIMAL);
+ swap(a.DATE, b.DATE);
+ swap(a.TIME, b.TIME);
+ swap(a.TIMESTAMP, b.TIMESTAMP);
+ swap(a.INTEGER, b.INTEGER);
+ swap(a.UNKNOWN, b.UNKNOWN);
+ swap(a.JSON, b.JSON);
+ swap(a.BSON, b.BSON);
+ swap(a.UUID, b.UUID);
+ swap(a.__isset, b.__isset);
+}
+
+LogicalType::LogicalType(const LogicalType& other36) {
+ STRING = other36.STRING;
+ MAP = other36.MAP;
+ LIST = other36.LIST;
+ ENUM = other36.ENUM;
+ DECIMAL = other36.DECIMAL;
+ DATE = other36.DATE;
+ TIME = other36.TIME;
+ TIMESTAMP = other36.TIMESTAMP;
+ INTEGER = other36.INTEGER;
+ UNKNOWN = other36.UNKNOWN;
+ JSON = other36.JSON;
+ BSON = other36.BSON;
+ UUID = other36.UUID;
+ __isset = other36.__isset;
+}
+LogicalType& LogicalType::operator=(const LogicalType& other37) {
+ STRING = other37.STRING;
+ MAP = other37.MAP;
+ LIST = other37.LIST;
+ ENUM = other37.ENUM;
+ DECIMAL = other37.DECIMAL;
+ DATE = other37.DATE;
+ TIME = other37.TIME;
+ TIMESTAMP = other37.TIMESTAMP;
+ INTEGER = other37.INTEGER;
+ UNKNOWN = other37.UNKNOWN;
+ JSON = other37.JSON;
+ BSON = other37.BSON;
+ UUID = other37.UUID;
+ __isset = other37.__isset;
+ return *this;
+}
+void LogicalType::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "LogicalType(";
+ out << "STRING="; (__isset.STRING ? (out << to_string(STRING)) : (out << "<null>"));
+ out << ", " << "MAP="; (__isset.MAP ? (out << to_string(MAP)) : (out << "<null>"));
+ out << ", " << "LIST="; (__isset.LIST ? (out << to_string(LIST)) : (out << "<null>"));
+ out << ", " << "ENUM="; (__isset.ENUM ? (out << to_string(ENUM)) : (out << "<null>"));
+ out << ", " << "DECIMAL="; (__isset.DECIMAL ? (out << to_string(DECIMAL)) : (out << "<null>"));
+ out << ", " << "DATE="; (__isset.DATE ? (out << to_string(DATE)) : (out << "<null>"));
+ out << ", " << "TIME="; (__isset.TIME ? (out << to_string(TIME)) : (out << "<null>"));
+ out << ", " << "TIMESTAMP="; (__isset.TIMESTAMP ? (out << to_string(TIMESTAMP)) : (out << "<null>"));
+ out << ", " << "INTEGER="; (__isset.INTEGER ? (out << to_string(INTEGER)) : (out << "<null>"));
+ out << ", " << "UNKNOWN="; (__isset.UNKNOWN ? (out << to_string(UNKNOWN)) : (out << "<null>"));
+ out << ", " << "JSON="; (__isset.JSON ? (out << to_string(JSON)) : (out << "<null>"));
+ out << ", " << "BSON="; (__isset.BSON ? (out << to_string(BSON)) : (out << "<null>"));
+ out << ", " << "UUID="; (__isset.UUID ? (out << to_string(UUID)) : (out << "<null>"));
+ out << ")";
+}
+
+
+SchemaElement::~SchemaElement() noexcept {
+}
+
+
+void SchemaElement::__set_type(const Type::type val) {
+ this->type = val;
+__isset.type = true;
+}
+
+void SchemaElement::__set_type_length(const int32_t val) {
+ this->type_length = val;
+__isset.type_length = true;
+}
+
+void SchemaElement::__set_repetition_type(const FieldRepetitionType::type val) {
+ this->repetition_type = val;
+__isset.repetition_type = true;
+}
+
+void SchemaElement::__set_name(const std::string& val) {
+ this->name = val;
+}
+
+void SchemaElement::__set_num_children(const int32_t val) {
+ this->num_children = val;
+__isset.num_children = true;
+}
+
+void SchemaElement::__set_converted_type(const ConvertedType::type val) {
+ this->converted_type = val;
+__isset.converted_type = true;
+}
+
+void SchemaElement::__set_scale(const int32_t val) {
+ this->scale = val;
+__isset.scale = true;
+}
+
+void SchemaElement::__set_precision(const int32_t val) {
+ this->precision = val;
+__isset.precision = true;
+}
+
+void SchemaElement::__set_field_id(const int32_t val) {
+ this->field_id = val;
+__isset.field_id = true;
+}
+
+void SchemaElement::__set_logicalType(const LogicalType& val) {
+ this->logicalType = val;
+__isset.logicalType = true;
+}
+std::ostream& operator<<(std::ostream& out, const SchemaElement& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t SchemaElement::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_name = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ int32_t ecast38;
+ xfer += iprot->readI32(ecast38);
+ this->type = (Type::type)ecast38;
+ this->__isset.type = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->type_length);
+ this->__isset.type_length = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ int32_t ecast39;
+ xfer += iprot->readI32(ecast39);
+ this->repetition_type = (FieldRepetitionType::type)ecast39;
+ this->__isset.repetition_type = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 4:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readString(this->name);
+ isset_name = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 5:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->num_children);
+ this->__isset.num_children = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 6:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ int32_t ecast40;
+ xfer += iprot->readI32(ecast40);
+ this->converted_type = (ConvertedType::type)ecast40;
+ this->__isset.converted_type = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 7:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->scale);
+ this->__isset.scale = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 8:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->precision);
+ this->__isset.precision = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 9:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->field_id);
+ this->__isset.field_id = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 10:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->logicalType.read(iprot);
+ this->__isset.logicalType = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_name)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t SchemaElement::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("SchemaElement");
+
+ if (this->__isset.type) {
+ xfer += oprot->writeFieldBegin("type", ::apache::thrift::protocol::T_I32, 1);
+ xfer += oprot->writeI32((int32_t)this->type);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.type_length) {
+ xfer += oprot->writeFieldBegin("type_length", ::apache::thrift::protocol::T_I32, 2);
+ xfer += oprot->writeI32(this->type_length);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.repetition_type) {
+ xfer += oprot->writeFieldBegin("repetition_type", ::apache::thrift::protocol::T_I32, 3);
+ xfer += oprot->writeI32((int32_t)this->repetition_type);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldBegin("name", ::apache::thrift::protocol::T_STRING, 4);
+ xfer += oprot->writeString(this->name);
+ xfer += oprot->writeFieldEnd();
+
+ if (this->__isset.num_children) {
+ xfer += oprot->writeFieldBegin("num_children", ::apache::thrift::protocol::T_I32, 5);
+ xfer += oprot->writeI32(this->num_children);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.converted_type) {
+ xfer += oprot->writeFieldBegin("converted_type", ::apache::thrift::protocol::T_I32, 6);
+ xfer += oprot->writeI32((int32_t)this->converted_type);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.scale) {
+ xfer += oprot->writeFieldBegin("scale", ::apache::thrift::protocol::T_I32, 7);
+ xfer += oprot->writeI32(this->scale);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.precision) {
+ xfer += oprot->writeFieldBegin("precision", ::apache::thrift::protocol::T_I32, 8);
+ xfer += oprot->writeI32(this->precision);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.field_id) {
+ xfer += oprot->writeFieldBegin("field_id", ::apache::thrift::protocol::T_I32, 9);
+ xfer += oprot->writeI32(this->field_id);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.logicalType) {
+ xfer += oprot->writeFieldBegin("logicalType", ::apache::thrift::protocol::T_STRUCT, 10);
+ xfer += this->logicalType.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(SchemaElement &a, SchemaElement &b) {
+ using ::std::swap;
+ swap(a.type, b.type);
+ swap(a.type_length, b.type_length);
+ swap(a.repetition_type, b.repetition_type);
+ swap(a.name, b.name);
+ swap(a.num_children, b.num_children);
+ swap(a.converted_type, b.converted_type);
+ swap(a.scale, b.scale);
+ swap(a.precision, b.precision);
+ swap(a.field_id, b.field_id);
+ swap(a.logicalType, b.logicalType);
+ swap(a.__isset, b.__isset);
+}
+
+SchemaElement::SchemaElement(const SchemaElement& other41) {
+ type = other41.type;
+ type_length = other41.type_length;
+ repetition_type = other41.repetition_type;
+ name = other41.name;
+ num_children = other41.num_children;
+ converted_type = other41.converted_type;
+ scale = other41.scale;
+ precision = other41.precision;
+ field_id = other41.field_id;
+ logicalType = other41.logicalType;
+ __isset = other41.__isset;
+}
+SchemaElement& SchemaElement::operator=(const SchemaElement& other42) {
+ type = other42.type;
+ type_length = other42.type_length;
+ repetition_type = other42.repetition_type;
+ name = other42.name;
+ num_children = other42.num_children;
+ converted_type = other42.converted_type;
+ scale = other42.scale;
+ precision = other42.precision;
+ field_id = other42.field_id;
+ logicalType = other42.logicalType;
+ __isset = other42.__isset;
+ return *this;
+}
+void SchemaElement::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "SchemaElement(";
+ out << "type="; (__isset.type ? (out << to_string(type)) : (out << "<null>"));
+ out << ", " << "type_length="; (__isset.type_length ? (out << to_string(type_length)) : (out << "<null>"));
+ out << ", " << "repetition_type="; (__isset.repetition_type ? (out << to_string(repetition_type)) : (out << "<null>"));
+ out << ", " << "name=" << to_string(name);
+ out << ", " << "num_children="; (__isset.num_children ? (out << to_string(num_children)) : (out << "<null>"));
+ out << ", " << "converted_type="; (__isset.converted_type ? (out << to_string(converted_type)) : (out << "<null>"));
+ out << ", " << "scale="; (__isset.scale ? (out << to_string(scale)) : (out << "<null>"));
+ out << ", " << "precision="; (__isset.precision ? (out << to_string(precision)) : (out << "<null>"));
+ out << ", " << "field_id="; (__isset.field_id ? (out << to_string(field_id)) : (out << "<null>"));
+ out << ", " << "logicalType="; (__isset.logicalType ? (out << to_string(logicalType)) : (out << "<null>"));
+ out << ")";
+}
+
+
+DataPageHeader::~DataPageHeader() noexcept {
+}
+
+
+void DataPageHeader::__set_num_values(const int32_t val) {
+ this->num_values = val;
+}
+
+void DataPageHeader::__set_encoding(const Encoding::type val) {
+ this->encoding = val;
+}
+
+void DataPageHeader::__set_definition_level_encoding(const Encoding::type val) {
+ this->definition_level_encoding = val;
+}
+
+void DataPageHeader::__set_repetition_level_encoding(const Encoding::type val) {
+ this->repetition_level_encoding = val;
+}
+
+void DataPageHeader::__set_statistics(const Statistics& val) {
+ this->statistics = val;
+__isset.statistics = true;
+}
+std::ostream& operator<<(std::ostream& out, const DataPageHeader& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t DataPageHeader::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_num_values = false;
+ bool isset_encoding = false;
+ bool isset_definition_level_encoding = false;
+ bool isset_repetition_level_encoding = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->num_values);
+ isset_num_values = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ int32_t ecast43;
+ xfer += iprot->readI32(ecast43);
+ this->encoding = (Encoding::type)ecast43;
+ isset_encoding = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ int32_t ecast44;
+ xfer += iprot->readI32(ecast44);
+ this->definition_level_encoding = (Encoding::type)ecast44;
+ isset_definition_level_encoding = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 4:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ int32_t ecast45;
+ xfer += iprot->readI32(ecast45);
+ this->repetition_level_encoding = (Encoding::type)ecast45;
+ isset_repetition_level_encoding = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 5:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->statistics.read(iprot);
+ this->__isset.statistics = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_num_values)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_encoding)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_definition_level_encoding)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_repetition_level_encoding)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t DataPageHeader::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("DataPageHeader");
+
+ xfer += oprot->writeFieldBegin("num_values", ::apache::thrift::protocol::T_I32, 1);
+ xfer += oprot->writeI32(this->num_values);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("encoding", ::apache::thrift::protocol::T_I32, 2);
+ xfer += oprot->writeI32((int32_t)this->encoding);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("definition_level_encoding", ::apache::thrift::protocol::T_I32, 3);
+ xfer += oprot->writeI32((int32_t)this->definition_level_encoding);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("repetition_level_encoding", ::apache::thrift::protocol::T_I32, 4);
+ xfer += oprot->writeI32((int32_t)this->repetition_level_encoding);
+ xfer += oprot->writeFieldEnd();
+
+ if (this->__isset.statistics) {
+ xfer += oprot->writeFieldBegin("statistics", ::apache::thrift::protocol::T_STRUCT, 5);
+ xfer += this->statistics.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(DataPageHeader &a, DataPageHeader &b) {
+ using ::std::swap;
+ swap(a.num_values, b.num_values);
+ swap(a.encoding, b.encoding);
+ swap(a.definition_level_encoding, b.definition_level_encoding);
+ swap(a.repetition_level_encoding, b.repetition_level_encoding);
+ swap(a.statistics, b.statistics);
+ swap(a.__isset, b.__isset);
+}
+
+DataPageHeader::DataPageHeader(const DataPageHeader& other46) {
+ num_values = other46.num_values;
+ encoding = other46.encoding;
+ definition_level_encoding = other46.definition_level_encoding;
+ repetition_level_encoding = other46.repetition_level_encoding;
+ statistics = other46.statistics;
+ __isset = other46.__isset;
+}
+DataPageHeader& DataPageHeader::operator=(const DataPageHeader& other47) {
+ num_values = other47.num_values;
+ encoding = other47.encoding;
+ definition_level_encoding = other47.definition_level_encoding;
+ repetition_level_encoding = other47.repetition_level_encoding;
+ statistics = other47.statistics;
+ __isset = other47.__isset;
+ return *this;
+}
+void DataPageHeader::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "DataPageHeader(";
+ out << "num_values=" << to_string(num_values);
+ out << ", " << "encoding=" << to_string(encoding);
+ out << ", " << "definition_level_encoding=" << to_string(definition_level_encoding);
+ out << ", " << "repetition_level_encoding=" << to_string(repetition_level_encoding);
+ out << ", " << "statistics="; (__isset.statistics ? (out << to_string(statistics)) : (out << "<null>"));
+ out << ")";
+}
+
+
+IndexPageHeader::~IndexPageHeader() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const IndexPageHeader& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t IndexPageHeader::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t IndexPageHeader::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("IndexPageHeader");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(IndexPageHeader &a, IndexPageHeader &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+IndexPageHeader::IndexPageHeader(const IndexPageHeader& other48) {
+ (void) other48;
+}
+IndexPageHeader& IndexPageHeader::operator=(const IndexPageHeader& other49) {
+ (void) other49;
+ return *this;
+}
+void IndexPageHeader::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "IndexPageHeader(";
+ out << ")";
+}
+
+
+DictionaryPageHeader::~DictionaryPageHeader() noexcept {
+}
+
+
+void DictionaryPageHeader::__set_num_values(const int32_t val) {
+ this->num_values = val;
+}
+
+void DictionaryPageHeader::__set_encoding(const Encoding::type val) {
+ this->encoding = val;
+}
+
+void DictionaryPageHeader::__set_is_sorted(const bool val) {
+ this->is_sorted = val;
+__isset.is_sorted = true;
+}
+std::ostream& operator<<(std::ostream& out, const DictionaryPageHeader& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t DictionaryPageHeader::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_num_values = false;
+ bool isset_encoding = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->num_values);
+ isset_num_values = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ int32_t ecast50;
+ xfer += iprot->readI32(ecast50);
+ this->encoding = (Encoding::type)ecast50;
+ isset_encoding = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_BOOL) {
+ xfer += iprot->readBool(this->is_sorted);
+ this->__isset.is_sorted = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_num_values)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_encoding)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t DictionaryPageHeader::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("DictionaryPageHeader");
+
+ xfer += oprot->writeFieldBegin("num_values", ::apache::thrift::protocol::T_I32, 1);
+ xfer += oprot->writeI32(this->num_values);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("encoding", ::apache::thrift::protocol::T_I32, 2);
+ xfer += oprot->writeI32((int32_t)this->encoding);
+ xfer += oprot->writeFieldEnd();
+
+ if (this->__isset.is_sorted) {
+ xfer += oprot->writeFieldBegin("is_sorted", ::apache::thrift::protocol::T_BOOL, 3);
+ xfer += oprot->writeBool(this->is_sorted);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(DictionaryPageHeader &a, DictionaryPageHeader &b) {
+ using ::std::swap;
+ swap(a.num_values, b.num_values);
+ swap(a.encoding, b.encoding);
+ swap(a.is_sorted, b.is_sorted);
+ swap(a.__isset, b.__isset);
+}
+
+DictionaryPageHeader::DictionaryPageHeader(const DictionaryPageHeader& other51) {
+ num_values = other51.num_values;
+ encoding = other51.encoding;
+ is_sorted = other51.is_sorted;
+ __isset = other51.__isset;
+}
+DictionaryPageHeader& DictionaryPageHeader::operator=(const DictionaryPageHeader& other52) {
+ num_values = other52.num_values;
+ encoding = other52.encoding;
+ is_sorted = other52.is_sorted;
+ __isset = other52.__isset;
+ return *this;
+}
+void DictionaryPageHeader::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "DictionaryPageHeader(";
+ out << "num_values=" << to_string(num_values);
+ out << ", " << "encoding=" << to_string(encoding);
+ out << ", " << "is_sorted="; (__isset.is_sorted ? (out << to_string(is_sorted)) : (out << "<null>"));
+ out << ")";
+}
+
+
+DataPageHeaderV2::~DataPageHeaderV2() noexcept {
+}
+
+
+void DataPageHeaderV2::__set_num_values(const int32_t val) {
+ this->num_values = val;
+}
+
+void DataPageHeaderV2::__set_num_nulls(const int32_t val) {
+ this->num_nulls = val;
+}
+
+void DataPageHeaderV2::__set_num_rows(const int32_t val) {
+ this->num_rows = val;
+}
+
+void DataPageHeaderV2::__set_encoding(const Encoding::type val) {
+ this->encoding = val;
+}
+
+void DataPageHeaderV2::__set_definition_levels_byte_length(const int32_t val) {
+ this->definition_levels_byte_length = val;
+}
+
+void DataPageHeaderV2::__set_repetition_levels_byte_length(const int32_t val) {
+ this->repetition_levels_byte_length = val;
+}
+
+void DataPageHeaderV2::__set_is_compressed(const bool val) {
+ this->is_compressed = val;
+__isset.is_compressed = true;
+}
+
+void DataPageHeaderV2::__set_statistics(const Statistics& val) {
+ this->statistics = val;
+__isset.statistics = true;
+}
+std::ostream& operator<<(std::ostream& out, const DataPageHeaderV2& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t DataPageHeaderV2::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_num_values = false;
+ bool isset_num_nulls = false;
+ bool isset_num_rows = false;
+ bool isset_encoding = false;
+ bool isset_definition_levels_byte_length = false;
+ bool isset_repetition_levels_byte_length = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->num_values);
+ isset_num_values = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->num_nulls);
+ isset_num_nulls = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->num_rows);
+ isset_num_rows = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 4:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ int32_t ecast53;
+ xfer += iprot->readI32(ecast53);
+ this->encoding = (Encoding::type)ecast53;
+ isset_encoding = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 5:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->definition_levels_byte_length);
+ isset_definition_levels_byte_length = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 6:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->repetition_levels_byte_length);
+ isset_repetition_levels_byte_length = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 7:
+ if (ftype == ::apache::thrift::protocol::T_BOOL) {
+ xfer += iprot->readBool(this->is_compressed);
+ this->__isset.is_compressed = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 8:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->statistics.read(iprot);
+ this->__isset.statistics = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_num_values)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_num_nulls)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_num_rows)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_encoding)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_definition_levels_byte_length)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_repetition_levels_byte_length)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t DataPageHeaderV2::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("DataPageHeaderV2");
+
+ xfer += oprot->writeFieldBegin("num_values", ::apache::thrift::protocol::T_I32, 1);
+ xfer += oprot->writeI32(this->num_values);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("num_nulls", ::apache::thrift::protocol::T_I32, 2);
+ xfer += oprot->writeI32(this->num_nulls);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("num_rows", ::apache::thrift::protocol::T_I32, 3);
+ xfer += oprot->writeI32(this->num_rows);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("encoding", ::apache::thrift::protocol::T_I32, 4);
+ xfer += oprot->writeI32((int32_t)this->encoding);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("definition_levels_byte_length", ::apache::thrift::protocol::T_I32, 5);
+ xfer += oprot->writeI32(this->definition_levels_byte_length);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("repetition_levels_byte_length", ::apache::thrift::protocol::T_I32, 6);
+ xfer += oprot->writeI32(this->repetition_levels_byte_length);
+ xfer += oprot->writeFieldEnd();
+
+ if (this->__isset.is_compressed) {
+ xfer += oprot->writeFieldBegin("is_compressed", ::apache::thrift::protocol::T_BOOL, 7);
+ xfer += oprot->writeBool(this->is_compressed);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.statistics) {
+ xfer += oprot->writeFieldBegin("statistics", ::apache::thrift::protocol::T_STRUCT, 8);
+ xfer += this->statistics.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(DataPageHeaderV2 &a, DataPageHeaderV2 &b) {
+ using ::std::swap;
+ swap(a.num_values, b.num_values);
+ swap(a.num_nulls, b.num_nulls);
+ swap(a.num_rows, b.num_rows);
+ swap(a.encoding, b.encoding);
+ swap(a.definition_levels_byte_length, b.definition_levels_byte_length);
+ swap(a.repetition_levels_byte_length, b.repetition_levels_byte_length);
+ swap(a.is_compressed, b.is_compressed);
+ swap(a.statistics, b.statistics);
+ swap(a.__isset, b.__isset);
+}
+
+DataPageHeaderV2::DataPageHeaderV2(const DataPageHeaderV2& other54) {
+ num_values = other54.num_values;
+ num_nulls = other54.num_nulls;
+ num_rows = other54.num_rows;
+ encoding = other54.encoding;
+ definition_levels_byte_length = other54.definition_levels_byte_length;
+ repetition_levels_byte_length = other54.repetition_levels_byte_length;
+ is_compressed = other54.is_compressed;
+ statistics = other54.statistics;
+ __isset = other54.__isset;
+}
+DataPageHeaderV2& DataPageHeaderV2::operator=(const DataPageHeaderV2& other55) {
+ num_values = other55.num_values;
+ num_nulls = other55.num_nulls;
+ num_rows = other55.num_rows;
+ encoding = other55.encoding;
+ definition_levels_byte_length = other55.definition_levels_byte_length;
+ repetition_levels_byte_length = other55.repetition_levels_byte_length;
+ is_compressed = other55.is_compressed;
+ statistics = other55.statistics;
+ __isset = other55.__isset;
+ return *this;
+}
+void DataPageHeaderV2::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "DataPageHeaderV2(";
+ out << "num_values=" << to_string(num_values);
+ out << ", " << "num_nulls=" << to_string(num_nulls);
+ out << ", " << "num_rows=" << to_string(num_rows);
+ out << ", " << "encoding=" << to_string(encoding);
+ out << ", " << "definition_levels_byte_length=" << to_string(definition_levels_byte_length);
+ out << ", " << "repetition_levels_byte_length=" << to_string(repetition_levels_byte_length);
+ out << ", " << "is_compressed="; (__isset.is_compressed ? (out << to_string(is_compressed)) : (out << "<null>"));
+ out << ", " << "statistics="; (__isset.statistics ? (out << to_string(statistics)) : (out << "<null>"));
+ out << ")";
+}
+
+
+SplitBlockAlgorithm::~SplitBlockAlgorithm() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const SplitBlockAlgorithm& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t SplitBlockAlgorithm::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t SplitBlockAlgorithm::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("SplitBlockAlgorithm");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(SplitBlockAlgorithm &a, SplitBlockAlgorithm &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+SplitBlockAlgorithm::SplitBlockAlgorithm(const SplitBlockAlgorithm& other56) {
+ (void) other56;
+}
+SplitBlockAlgorithm& SplitBlockAlgorithm::operator=(const SplitBlockAlgorithm& other57) {
+ (void) other57;
+ return *this;
+}
+void SplitBlockAlgorithm::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "SplitBlockAlgorithm(";
+ out << ")";
+}
+
+
+BloomFilterAlgorithm::~BloomFilterAlgorithm() noexcept {
+}
+
+
+void BloomFilterAlgorithm::__set_BLOCK(const SplitBlockAlgorithm& val) {
+ this->BLOCK = val;
+__isset.BLOCK = true;
+}
+std::ostream& operator<<(std::ostream& out, const BloomFilterAlgorithm& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t BloomFilterAlgorithm::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->BLOCK.read(iprot);
+ this->__isset.BLOCK = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t BloomFilterAlgorithm::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("BloomFilterAlgorithm");
+
+ if (this->__isset.BLOCK) {
+ xfer += oprot->writeFieldBegin("BLOCK", ::apache::thrift::protocol::T_STRUCT, 1);
+ xfer += this->BLOCK.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(BloomFilterAlgorithm &a, BloomFilterAlgorithm &b) {
+ using ::std::swap;
+ swap(a.BLOCK, b.BLOCK);
+ swap(a.__isset, b.__isset);
+}
+
+BloomFilterAlgorithm::BloomFilterAlgorithm(const BloomFilterAlgorithm& other58) {
+ BLOCK = other58.BLOCK;
+ __isset = other58.__isset;
+}
+BloomFilterAlgorithm& BloomFilterAlgorithm::operator=(const BloomFilterAlgorithm& other59) {
+ BLOCK = other59.BLOCK;
+ __isset = other59.__isset;
+ return *this;
+}
+void BloomFilterAlgorithm::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "BloomFilterAlgorithm(";
+ out << "BLOCK="; (__isset.BLOCK ? (out << to_string(BLOCK)) : (out << "<null>"));
+ out << ")";
+}
+
+
+XxHash::~XxHash() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const XxHash& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t XxHash::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t XxHash::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("XxHash");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(XxHash &a, XxHash &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+XxHash::XxHash(const XxHash& other60) {
+ (void) other60;
+}
+XxHash& XxHash::operator=(const XxHash& other61) {
+ (void) other61;
+ return *this;
+}
+void XxHash::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "XxHash(";
+ out << ")";
+}
+
+
+BloomFilterHash::~BloomFilterHash() noexcept {
+}
+
+
+void BloomFilterHash::__set_XXHASH(const XxHash& val) {
+ this->XXHASH = val;
+__isset.XXHASH = true;
+}
+std::ostream& operator<<(std::ostream& out, const BloomFilterHash& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t BloomFilterHash::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->XXHASH.read(iprot);
+ this->__isset.XXHASH = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t BloomFilterHash::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("BloomFilterHash");
+
+ if (this->__isset.XXHASH) {
+ xfer += oprot->writeFieldBegin("XXHASH", ::apache::thrift::protocol::T_STRUCT, 1);
+ xfer += this->XXHASH.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(BloomFilterHash &a, BloomFilterHash &b) {
+ using ::std::swap;
+ swap(a.XXHASH, b.XXHASH);
+ swap(a.__isset, b.__isset);
+}
+
+BloomFilterHash::BloomFilterHash(const BloomFilterHash& other62) {
+ XXHASH = other62.XXHASH;
+ __isset = other62.__isset;
+}
+BloomFilterHash& BloomFilterHash::operator=(const BloomFilterHash& other63) {
+ XXHASH = other63.XXHASH;
+ __isset = other63.__isset;
+ return *this;
+}
+void BloomFilterHash::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "BloomFilterHash(";
+ out << "XXHASH="; (__isset.XXHASH ? (out << to_string(XXHASH)) : (out << "<null>"));
+ out << ")";
+}
+
+
+Uncompressed::~Uncompressed() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const Uncompressed& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t Uncompressed::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t Uncompressed::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("Uncompressed");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(Uncompressed &a, Uncompressed &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+Uncompressed::Uncompressed(const Uncompressed& other64) {
+ (void) other64;
+}
+Uncompressed& Uncompressed::operator=(const Uncompressed& other65) {
+ (void) other65;
+ return *this;
+}
+void Uncompressed::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "Uncompressed(";
+ out << ")";
+}
+
+
+BloomFilterCompression::~BloomFilterCompression() noexcept {
+}
+
+
+void BloomFilterCompression::__set_UNCOMPRESSED(const Uncompressed& val) {
+ this->UNCOMPRESSED = val;
+__isset.UNCOMPRESSED = true;
+}
+std::ostream& operator<<(std::ostream& out, const BloomFilterCompression& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t BloomFilterCompression::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->UNCOMPRESSED.read(iprot);
+ this->__isset.UNCOMPRESSED = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t BloomFilterCompression::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("BloomFilterCompression");
+
+ if (this->__isset.UNCOMPRESSED) {
+ xfer += oprot->writeFieldBegin("UNCOMPRESSED", ::apache::thrift::protocol::T_STRUCT, 1);
+ xfer += this->UNCOMPRESSED.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(BloomFilterCompression &a, BloomFilterCompression &b) {
+ using ::std::swap;
+ swap(a.UNCOMPRESSED, b.UNCOMPRESSED);
+ swap(a.__isset, b.__isset);
+}
+
+BloomFilterCompression::BloomFilterCompression(const BloomFilterCompression& other66) {
+ UNCOMPRESSED = other66.UNCOMPRESSED;
+ __isset = other66.__isset;
+}
+BloomFilterCompression& BloomFilterCompression::operator=(const BloomFilterCompression& other67) {
+ UNCOMPRESSED = other67.UNCOMPRESSED;
+ __isset = other67.__isset;
+ return *this;
+}
+void BloomFilterCompression::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "BloomFilterCompression(";
+ out << "UNCOMPRESSED="; (__isset.UNCOMPRESSED ? (out << to_string(UNCOMPRESSED)) : (out << "<null>"));
+ out << ")";
+}
+
+
+BloomFilterHeader::~BloomFilterHeader() noexcept {
+}
+
+
+void BloomFilterHeader::__set_numBytes(const int32_t val) {
+ this->numBytes = val;
+}
+
+void BloomFilterHeader::__set_algorithm(const BloomFilterAlgorithm& val) {
+ this->algorithm = val;
+}
+
+void BloomFilterHeader::__set_hash(const BloomFilterHash& val) {
+ this->hash = val;
+}
+
+void BloomFilterHeader::__set_compression(const BloomFilterCompression& val) {
+ this->compression = val;
+}
+std::ostream& operator<<(std::ostream& out, const BloomFilterHeader& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t BloomFilterHeader::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_numBytes = false;
+ bool isset_algorithm = false;
+ bool isset_hash = false;
+ bool isset_compression = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->numBytes);
+ isset_numBytes = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->algorithm.read(iprot);
+ isset_algorithm = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->hash.read(iprot);
+ isset_hash = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 4:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->compression.read(iprot);
+ isset_compression = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_numBytes)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_algorithm)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_hash)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_compression)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t BloomFilterHeader::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("BloomFilterHeader");
+
+ xfer += oprot->writeFieldBegin("numBytes", ::apache::thrift::protocol::T_I32, 1);
+ xfer += oprot->writeI32(this->numBytes);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("algorithm", ::apache::thrift::protocol::T_STRUCT, 2);
+ xfer += this->algorithm.write(oprot);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("hash", ::apache::thrift::protocol::T_STRUCT, 3);
+ xfer += this->hash.write(oprot);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("compression", ::apache::thrift::protocol::T_STRUCT, 4);
+ xfer += this->compression.write(oprot);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(BloomFilterHeader &a, BloomFilterHeader &b) {
+ using ::std::swap;
+ swap(a.numBytes, b.numBytes);
+ swap(a.algorithm, b.algorithm);
+ swap(a.hash, b.hash);
+ swap(a.compression, b.compression);
+}
+
+BloomFilterHeader::BloomFilterHeader(const BloomFilterHeader& other68) {
+ numBytes = other68.numBytes;
+ algorithm = other68.algorithm;
+ hash = other68.hash;
+ compression = other68.compression;
+}
+BloomFilterHeader& BloomFilterHeader::operator=(const BloomFilterHeader& other69) {
+ numBytes = other69.numBytes;
+ algorithm = other69.algorithm;
+ hash = other69.hash;
+ compression = other69.compression;
+ return *this;
+}
+void BloomFilterHeader::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "BloomFilterHeader(";
+ out << "numBytes=" << to_string(numBytes);
+ out << ", " << "algorithm=" << to_string(algorithm);
+ out << ", " << "hash=" << to_string(hash);
+ out << ", " << "compression=" << to_string(compression);
+ out << ")";
+}
+
+
+PageHeader::~PageHeader() noexcept {
+}
+
+
+void PageHeader::__set_type(const PageType::type val) {
+ this->type = val;
+}
+
+void PageHeader::__set_uncompressed_page_size(const int32_t val) {
+ this->uncompressed_page_size = val;
+}
+
+void PageHeader::__set_compressed_page_size(const int32_t val) {
+ this->compressed_page_size = val;
+}
+
+void PageHeader::__set_crc(const int32_t val) {
+ this->crc = val;
+__isset.crc = true;
+}
+
+void PageHeader::__set_data_page_header(const DataPageHeader& val) {
+ this->data_page_header = val;
+__isset.data_page_header = true;
+}
+
+void PageHeader::__set_index_page_header(const IndexPageHeader& val) {
+ this->index_page_header = val;
+__isset.index_page_header = true;
+}
+
+void PageHeader::__set_dictionary_page_header(const DictionaryPageHeader& val) {
+ this->dictionary_page_header = val;
+__isset.dictionary_page_header = true;
+}
+
+void PageHeader::__set_data_page_header_v2(const DataPageHeaderV2& val) {
+ this->data_page_header_v2 = val;
+__isset.data_page_header_v2 = true;
+}
+std::ostream& operator<<(std::ostream& out, const PageHeader& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t PageHeader::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_type = false;
+ bool isset_uncompressed_page_size = false;
+ bool isset_compressed_page_size = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ int32_t ecast70;
+ xfer += iprot->readI32(ecast70);
+ this->type = (PageType::type)ecast70;
+ isset_type = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->uncompressed_page_size);
+ isset_uncompressed_page_size = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->compressed_page_size);
+ isset_compressed_page_size = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 4:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->crc);
+ this->__isset.crc = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 5:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->data_page_header.read(iprot);
+ this->__isset.data_page_header = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 6:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->index_page_header.read(iprot);
+ this->__isset.index_page_header = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 7:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->dictionary_page_header.read(iprot);
+ this->__isset.dictionary_page_header = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 8:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->data_page_header_v2.read(iprot);
+ this->__isset.data_page_header_v2 = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_type)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_uncompressed_page_size)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_compressed_page_size)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t PageHeader::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("PageHeader");
+
+ xfer += oprot->writeFieldBegin("type", ::apache::thrift::protocol::T_I32, 1);
+ xfer += oprot->writeI32((int32_t)this->type);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("uncompressed_page_size", ::apache::thrift::protocol::T_I32, 2);
+ xfer += oprot->writeI32(this->uncompressed_page_size);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("compressed_page_size", ::apache::thrift::protocol::T_I32, 3);
+ xfer += oprot->writeI32(this->compressed_page_size);
+ xfer += oprot->writeFieldEnd();
+
+ if (this->__isset.crc) {
+ xfer += oprot->writeFieldBegin("crc", ::apache::thrift::protocol::T_I32, 4);
+ xfer += oprot->writeI32(this->crc);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.data_page_header) {
+ xfer += oprot->writeFieldBegin("data_page_header", ::apache::thrift::protocol::T_STRUCT, 5);
+ xfer += this->data_page_header.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.index_page_header) {
+ xfer += oprot->writeFieldBegin("index_page_header", ::apache::thrift::protocol::T_STRUCT, 6);
+ xfer += this->index_page_header.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.dictionary_page_header) {
+ xfer += oprot->writeFieldBegin("dictionary_page_header", ::apache::thrift::protocol::T_STRUCT, 7);
+ xfer += this->dictionary_page_header.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.data_page_header_v2) {
+ xfer += oprot->writeFieldBegin("data_page_header_v2", ::apache::thrift::protocol::T_STRUCT, 8);
+ xfer += this->data_page_header_v2.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(PageHeader &a, PageHeader &b) {
+ using ::std::swap;
+ swap(a.type, b.type);
+ swap(a.uncompressed_page_size, b.uncompressed_page_size);
+ swap(a.compressed_page_size, b.compressed_page_size);
+ swap(a.crc, b.crc);
+ swap(a.data_page_header, b.data_page_header);
+ swap(a.index_page_header, b.index_page_header);
+ swap(a.dictionary_page_header, b.dictionary_page_header);
+ swap(a.data_page_header_v2, b.data_page_header_v2);
+ swap(a.__isset, b.__isset);
+}
+
+PageHeader::PageHeader(const PageHeader& other71) {
+ type = other71.type;
+ uncompressed_page_size = other71.uncompressed_page_size;
+ compressed_page_size = other71.compressed_page_size;
+ crc = other71.crc;
+ data_page_header = other71.data_page_header;
+ index_page_header = other71.index_page_header;
+ dictionary_page_header = other71.dictionary_page_header;
+ data_page_header_v2 = other71.data_page_header_v2;
+ __isset = other71.__isset;
+}
+PageHeader& PageHeader::operator=(const PageHeader& other72) {
+ type = other72.type;
+ uncompressed_page_size = other72.uncompressed_page_size;
+ compressed_page_size = other72.compressed_page_size;
+ crc = other72.crc;
+ data_page_header = other72.data_page_header;
+ index_page_header = other72.index_page_header;
+ dictionary_page_header = other72.dictionary_page_header;
+ data_page_header_v2 = other72.data_page_header_v2;
+ __isset = other72.__isset;
+ return *this;
+}
+void PageHeader::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "PageHeader(";
+ out << "type=" << to_string(type);
+ out << ", " << "uncompressed_page_size=" << to_string(uncompressed_page_size);
+ out << ", " << "compressed_page_size=" << to_string(compressed_page_size);
+ out << ", " << "crc="; (__isset.crc ? (out << to_string(crc)) : (out << "<null>"));
+ out << ", " << "data_page_header="; (__isset.data_page_header ? (out << to_string(data_page_header)) : (out << "<null>"));
+ out << ", " << "index_page_header="; (__isset.index_page_header ? (out << to_string(index_page_header)) : (out << "<null>"));
+ out << ", " << "dictionary_page_header="; (__isset.dictionary_page_header ? (out << to_string(dictionary_page_header)) : (out << "<null>"));
+ out << ", " << "data_page_header_v2="; (__isset.data_page_header_v2 ? (out << to_string(data_page_header_v2)) : (out << "<null>"));
+ out << ")";
+}
+
+
+KeyValue::~KeyValue() noexcept {
+}
+
+
+void KeyValue::__set_key(const std::string& val) {
+ this->key = val;
+}
+
+void KeyValue::__set_value(const std::string& val) {
+ this->value = val;
+__isset.value = true;
+}
+std::ostream& operator<<(std::ostream& out, const KeyValue& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t KeyValue::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_key = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readString(this->key);
+ isset_key = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readString(this->value);
+ this->__isset.value = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_key)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t KeyValue::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("KeyValue");
+
+ xfer += oprot->writeFieldBegin("key", ::apache::thrift::protocol::T_STRING, 1);
+ xfer += oprot->writeString(this->key);
+ xfer += oprot->writeFieldEnd();
+
+ if (this->__isset.value) {
+ xfer += oprot->writeFieldBegin("value", ::apache::thrift::protocol::T_STRING, 2);
+ xfer += oprot->writeString(this->value);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(KeyValue &a, KeyValue &b) {
+ using ::std::swap;
+ swap(a.key, b.key);
+ swap(a.value, b.value);
+ swap(a.__isset, b.__isset);
+}
+
+KeyValue::KeyValue(const KeyValue& other73) {
+ key = other73.key;
+ value = other73.value;
+ __isset = other73.__isset;
+}
+KeyValue& KeyValue::operator=(const KeyValue& other74) {
+ key = other74.key;
+ value = other74.value;
+ __isset = other74.__isset;
+ return *this;
+}
+void KeyValue::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "KeyValue(";
+ out << "key=" << to_string(key);
+ out << ", " << "value="; (__isset.value ? (out << to_string(value)) : (out << "<null>"));
+ out << ")";
+}
+
+
+SortingColumn::~SortingColumn() noexcept {
+}
+
+
+void SortingColumn::__set_column_idx(const int32_t val) {
+ this->column_idx = val;
+}
+
+void SortingColumn::__set_descending(const bool val) {
+ this->descending = val;
+}
+
+void SortingColumn::__set_nulls_first(const bool val) {
+ this->nulls_first = val;
+}
+std::ostream& operator<<(std::ostream& out, const SortingColumn& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t SortingColumn::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_column_idx = false;
+ bool isset_descending = false;
+ bool isset_nulls_first = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->column_idx);
+ isset_column_idx = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_BOOL) {
+ xfer += iprot->readBool(this->descending);
+ isset_descending = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_BOOL) {
+ xfer += iprot->readBool(this->nulls_first);
+ isset_nulls_first = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_column_idx)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_descending)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_nulls_first)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t SortingColumn::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("SortingColumn");
+
+ xfer += oprot->writeFieldBegin("column_idx", ::apache::thrift::protocol::T_I32, 1);
+ xfer += oprot->writeI32(this->column_idx);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("descending", ::apache::thrift::protocol::T_BOOL, 2);
+ xfer += oprot->writeBool(this->descending);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("nulls_first", ::apache::thrift::protocol::T_BOOL, 3);
+ xfer += oprot->writeBool(this->nulls_first);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(SortingColumn &a, SortingColumn &b) {
+ using ::std::swap;
+ swap(a.column_idx, b.column_idx);
+ swap(a.descending, b.descending);
+ swap(a.nulls_first, b.nulls_first);
+}
+
+SortingColumn::SortingColumn(const SortingColumn& other75) {
+ column_idx = other75.column_idx;
+ descending = other75.descending;
+ nulls_first = other75.nulls_first;
+}
+SortingColumn& SortingColumn::operator=(const SortingColumn& other76) {
+ column_idx = other76.column_idx;
+ descending = other76.descending;
+ nulls_first = other76.nulls_first;
+ return *this;
+}
+void SortingColumn::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "SortingColumn(";
+ out << "column_idx=" << to_string(column_idx);
+ out << ", " << "descending=" << to_string(descending);
+ out << ", " << "nulls_first=" << to_string(nulls_first);
+ out << ")";
+}
+
+
+PageEncodingStats::~PageEncodingStats() noexcept {
+}
+
+
+void PageEncodingStats::__set_page_type(const PageType::type val) {
+ this->page_type = val;
+}
+
+void PageEncodingStats::__set_encoding(const Encoding::type val) {
+ this->encoding = val;
+}
+
+void PageEncodingStats::__set_count(const int32_t val) {
+ this->count = val;
+}
+std::ostream& operator<<(std::ostream& out, const PageEncodingStats& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t PageEncodingStats::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_page_type = false;
+ bool isset_encoding = false;
+ bool isset_count = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ int32_t ecast77;
+ xfer += iprot->readI32(ecast77);
+ this->page_type = (PageType::type)ecast77;
+ isset_page_type = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ int32_t ecast78;
+ xfer += iprot->readI32(ecast78);
+ this->encoding = (Encoding::type)ecast78;
+ isset_encoding = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->count);
+ isset_count = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_page_type)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_encoding)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_count)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t PageEncodingStats::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("PageEncodingStats");
+
+ xfer += oprot->writeFieldBegin("page_type", ::apache::thrift::protocol::T_I32, 1);
+ xfer += oprot->writeI32((int32_t)this->page_type);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("encoding", ::apache::thrift::protocol::T_I32, 2);
+ xfer += oprot->writeI32((int32_t)this->encoding);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("count", ::apache::thrift::protocol::T_I32, 3);
+ xfer += oprot->writeI32(this->count);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(PageEncodingStats &a, PageEncodingStats &b) {
+ using ::std::swap;
+ swap(a.page_type, b.page_type);
+ swap(a.encoding, b.encoding);
+ swap(a.count, b.count);
+}
+
+PageEncodingStats::PageEncodingStats(const PageEncodingStats& other79) {
+ page_type = other79.page_type;
+ encoding = other79.encoding;
+ count = other79.count;
+}
+PageEncodingStats& PageEncodingStats::operator=(const PageEncodingStats& other80) {
+ page_type = other80.page_type;
+ encoding = other80.encoding;
+ count = other80.count;
+ return *this;
+}
+void PageEncodingStats::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "PageEncodingStats(";
+ out << "page_type=" << to_string(page_type);
+ out << ", " << "encoding=" << to_string(encoding);
+ out << ", " << "count=" << to_string(count);
+ out << ")";
+}
+
+
+ColumnMetaData::~ColumnMetaData() noexcept {
+}
+
+
+void ColumnMetaData::__set_type(const Type::type val) {
+ this->type = val;
+}
+
+void ColumnMetaData::__set_encodings(const std::vector<Encoding::type> & val) {
+ this->encodings = val;
+}
+
+void ColumnMetaData::__set_path_in_schema(const std::vector<std::string> & val) {
+ this->path_in_schema = val;
+}
+
+void ColumnMetaData::__set_codec(const CompressionCodec::type val) {
+ this->codec = val;
+}
+
+void ColumnMetaData::__set_num_values(const int64_t val) {
+ this->num_values = val;
+}
+
+void ColumnMetaData::__set_total_uncompressed_size(const int64_t val) {
+ this->total_uncompressed_size = val;
+}
+
+void ColumnMetaData::__set_total_compressed_size(const int64_t val) {
+ this->total_compressed_size = val;
+}
+
+void ColumnMetaData::__set_key_value_metadata(const std::vector<KeyValue> & val) {
+ this->key_value_metadata = val;
+__isset.key_value_metadata = true;
+}
+
+void ColumnMetaData::__set_data_page_offset(const int64_t val) {
+ this->data_page_offset = val;
+}
+
+void ColumnMetaData::__set_index_page_offset(const int64_t val) {
+ this->index_page_offset = val;
+__isset.index_page_offset = true;
+}
+
+void ColumnMetaData::__set_dictionary_page_offset(const int64_t val) {
+ this->dictionary_page_offset = val;
+__isset.dictionary_page_offset = true;
+}
+
+void ColumnMetaData::__set_statistics(const Statistics& val) {
+ this->statistics = val;
+__isset.statistics = true;
+}
+
+void ColumnMetaData::__set_encoding_stats(const std::vector<PageEncodingStats> & val) {
+ this->encoding_stats = val;
+__isset.encoding_stats = true;
+}
+
+void ColumnMetaData::__set_bloom_filter_offset(const int64_t val) {
+ this->bloom_filter_offset = val;
+__isset.bloom_filter_offset = true;
+}
+std::ostream& operator<<(std::ostream& out, const ColumnMetaData& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t ColumnMetaData::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_type = false;
+ bool isset_encodings = false;
+ bool isset_path_in_schema = false;
+ bool isset_codec = false;
+ bool isset_num_values = false;
+ bool isset_total_uncompressed_size = false;
+ bool isset_total_compressed_size = false;
+ bool isset_data_page_offset = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ int32_t ecast81;
+ xfer += iprot->readI32(ecast81);
+ this->type = (Type::type)ecast81;
+ isset_type = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->encodings.clear();
+ uint32_t _size82;
+ ::apache::thrift::protocol::TType _etype85;
+ xfer += iprot->readListBegin(_etype85, _size82);
+ this->encodings.resize(_size82);
+ uint32_t _i86;
+ for (_i86 = 0; _i86 < _size82; ++_i86)
+ {
+ int32_t ecast87;
+ xfer += iprot->readI32(ecast87);
+ this->encodings[_i86] = (Encoding::type)ecast87;
+ }
+ xfer += iprot->readListEnd();
+ }
+ isset_encodings = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->path_in_schema.clear();
+ uint32_t _size88;
+ ::apache::thrift::protocol::TType _etype91;
+ xfer += iprot->readListBegin(_etype91, _size88);
+ this->path_in_schema.resize(_size88);
+ uint32_t _i92;
+ for (_i92 = 0; _i92 < _size88; ++_i92)
+ {
+ xfer += iprot->readString(this->path_in_schema[_i92]);
+ }
+ xfer += iprot->readListEnd();
+ }
+ isset_path_in_schema = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 4:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ int32_t ecast93;
+ xfer += iprot->readI32(ecast93);
+ this->codec = (CompressionCodec::type)ecast93;
+ isset_codec = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 5:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->num_values);
+ isset_num_values = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 6:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->total_uncompressed_size);
+ isset_total_uncompressed_size = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 7:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->total_compressed_size);
+ isset_total_compressed_size = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 8:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->key_value_metadata.clear();
+ uint32_t _size94;
+ ::apache::thrift::protocol::TType _etype97;
+ xfer += iprot->readListBegin(_etype97, _size94);
+ this->key_value_metadata.resize(_size94);
+ uint32_t _i98;
+ for (_i98 = 0; _i98 < _size94; ++_i98)
+ {
+ xfer += this->key_value_metadata[_i98].read(iprot);
+ }
+ xfer += iprot->readListEnd();
+ }
+ this->__isset.key_value_metadata = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 9:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->data_page_offset);
+ isset_data_page_offset = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 10:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->index_page_offset);
+ this->__isset.index_page_offset = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 11:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->dictionary_page_offset);
+ this->__isset.dictionary_page_offset = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 12:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->statistics.read(iprot);
+ this->__isset.statistics = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 13:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->encoding_stats.clear();
+ uint32_t _size99;
+ ::apache::thrift::protocol::TType _etype102;
+ xfer += iprot->readListBegin(_etype102, _size99);
+ this->encoding_stats.resize(_size99);
+ uint32_t _i103;
+ for (_i103 = 0; _i103 < _size99; ++_i103)
+ {
+ xfer += this->encoding_stats[_i103].read(iprot);
+ }
+ xfer += iprot->readListEnd();
+ }
+ this->__isset.encoding_stats = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 14:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->bloom_filter_offset);
+ this->__isset.bloom_filter_offset = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_type)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_encodings)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_path_in_schema)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_codec)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_num_values)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_total_uncompressed_size)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_total_compressed_size)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_data_page_offset)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t ColumnMetaData::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("ColumnMetaData");
+
+ xfer += oprot->writeFieldBegin("type", ::apache::thrift::protocol::T_I32, 1);
+ xfer += oprot->writeI32((int32_t)this->type);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("encodings", ::apache::thrift::protocol::T_LIST, 2);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_I32, static_cast<uint32_t>(this->encodings.size()));
+ std::vector<Encoding::type> ::const_iterator _iter104;
+ for (_iter104 = this->encodings.begin(); _iter104 != this->encodings.end(); ++_iter104)
+ {
+ xfer += oprot->writeI32((int32_t)(*_iter104));
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("path_in_schema", ::apache::thrift::protocol::T_LIST, 3);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRING, static_cast<uint32_t>(this->path_in_schema.size()));
+ std::vector<std::string> ::const_iterator _iter105;
+ for (_iter105 = this->path_in_schema.begin(); _iter105 != this->path_in_schema.end(); ++_iter105)
+ {
+ xfer += oprot->writeString((*_iter105));
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("codec", ::apache::thrift::protocol::T_I32, 4);
+ xfer += oprot->writeI32((int32_t)this->codec);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("num_values", ::apache::thrift::protocol::T_I64, 5);
+ xfer += oprot->writeI64(this->num_values);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("total_uncompressed_size", ::apache::thrift::protocol::T_I64, 6);
+ xfer += oprot->writeI64(this->total_uncompressed_size);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("total_compressed_size", ::apache::thrift::protocol::T_I64, 7);
+ xfer += oprot->writeI64(this->total_compressed_size);
+ xfer += oprot->writeFieldEnd();
+
+ if (this->__isset.key_value_metadata) {
+ xfer += oprot->writeFieldBegin("key_value_metadata", ::apache::thrift::protocol::T_LIST, 8);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast<uint32_t>(this->key_value_metadata.size()));
+ std::vector<KeyValue> ::const_iterator _iter106;
+ for (_iter106 = this->key_value_metadata.begin(); _iter106 != this->key_value_metadata.end(); ++_iter106)
+ {
+ xfer += (*_iter106).write(oprot);
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldBegin("data_page_offset", ::apache::thrift::protocol::T_I64, 9);
+ xfer += oprot->writeI64(this->data_page_offset);
+ xfer += oprot->writeFieldEnd();
+
+ if (this->__isset.index_page_offset) {
+ xfer += oprot->writeFieldBegin("index_page_offset", ::apache::thrift::protocol::T_I64, 10);
+ xfer += oprot->writeI64(this->index_page_offset);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.dictionary_page_offset) {
+ xfer += oprot->writeFieldBegin("dictionary_page_offset", ::apache::thrift::protocol::T_I64, 11);
+ xfer += oprot->writeI64(this->dictionary_page_offset);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.statistics) {
+ xfer += oprot->writeFieldBegin("statistics", ::apache::thrift::protocol::T_STRUCT, 12);
+ xfer += this->statistics.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.encoding_stats) {
+ xfer += oprot->writeFieldBegin("encoding_stats", ::apache::thrift::protocol::T_LIST, 13);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast<uint32_t>(this->encoding_stats.size()));
+ std::vector<PageEncodingStats> ::const_iterator _iter107;
+ for (_iter107 = this->encoding_stats.begin(); _iter107 != this->encoding_stats.end(); ++_iter107)
+ {
+ xfer += (*_iter107).write(oprot);
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.bloom_filter_offset) {
+ xfer += oprot->writeFieldBegin("bloom_filter_offset", ::apache::thrift::protocol::T_I64, 14);
+ xfer += oprot->writeI64(this->bloom_filter_offset);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(ColumnMetaData &a, ColumnMetaData &b) {
+ using ::std::swap;
+ swap(a.type, b.type);
+ swap(a.encodings, b.encodings);
+ swap(a.path_in_schema, b.path_in_schema);
+ swap(a.codec, b.codec);
+ swap(a.num_values, b.num_values);
+ swap(a.total_uncompressed_size, b.total_uncompressed_size);
+ swap(a.total_compressed_size, b.total_compressed_size);
+ swap(a.key_value_metadata, b.key_value_metadata);
+ swap(a.data_page_offset, b.data_page_offset);
+ swap(a.index_page_offset, b.index_page_offset);
+ swap(a.dictionary_page_offset, b.dictionary_page_offset);
+ swap(a.statistics, b.statistics);
+ swap(a.encoding_stats, b.encoding_stats);
+ swap(a.bloom_filter_offset, b.bloom_filter_offset);
+ swap(a.__isset, b.__isset);
+}
+
+ColumnMetaData::ColumnMetaData(const ColumnMetaData& other108) {
+ type = other108.type;
+ encodings = other108.encodings;
+ path_in_schema = other108.path_in_schema;
+ codec = other108.codec;
+ num_values = other108.num_values;
+ total_uncompressed_size = other108.total_uncompressed_size;
+ total_compressed_size = other108.total_compressed_size;
+ key_value_metadata = other108.key_value_metadata;
+ data_page_offset = other108.data_page_offset;
+ index_page_offset = other108.index_page_offset;
+ dictionary_page_offset = other108.dictionary_page_offset;
+ statistics = other108.statistics;
+ encoding_stats = other108.encoding_stats;
+ bloom_filter_offset = other108.bloom_filter_offset;
+ __isset = other108.__isset;
+}
+ColumnMetaData& ColumnMetaData::operator=(const ColumnMetaData& other109) {
+ type = other109.type;
+ encodings = other109.encodings;
+ path_in_schema = other109.path_in_schema;
+ codec = other109.codec;
+ num_values = other109.num_values;
+ total_uncompressed_size = other109.total_uncompressed_size;
+ total_compressed_size = other109.total_compressed_size;
+ key_value_metadata = other109.key_value_metadata;
+ data_page_offset = other109.data_page_offset;
+ index_page_offset = other109.index_page_offset;
+ dictionary_page_offset = other109.dictionary_page_offset;
+ statistics = other109.statistics;
+ encoding_stats = other109.encoding_stats;
+ bloom_filter_offset = other109.bloom_filter_offset;
+ __isset = other109.__isset;
+ return *this;
+}
+void ColumnMetaData::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "ColumnMetaData(";
+ out << "type=" << to_string(type);
+ out << ", " << "encodings=" << to_string(encodings);
+ out << ", " << "path_in_schema=" << to_string(path_in_schema);
+ out << ", " << "codec=" << to_string(codec);
+ out << ", " << "num_values=" << to_string(num_values);
+ out << ", " << "total_uncompressed_size=" << to_string(total_uncompressed_size);
+ out << ", " << "total_compressed_size=" << to_string(total_compressed_size);
+ out << ", " << "key_value_metadata="; (__isset.key_value_metadata ? (out << to_string(key_value_metadata)) : (out << "<null>"));
+ out << ", " << "data_page_offset=" << to_string(data_page_offset);
+ out << ", " << "index_page_offset="; (__isset.index_page_offset ? (out << to_string(index_page_offset)) : (out << "<null>"));
+ out << ", " << "dictionary_page_offset="; (__isset.dictionary_page_offset ? (out << to_string(dictionary_page_offset)) : (out << "<null>"));
+ out << ", " << "statistics="; (__isset.statistics ? (out << to_string(statistics)) : (out << "<null>"));
+ out << ", " << "encoding_stats="; (__isset.encoding_stats ? (out << to_string(encoding_stats)) : (out << "<null>"));
+ out << ", " << "bloom_filter_offset="; (__isset.bloom_filter_offset ? (out << to_string(bloom_filter_offset)) : (out << "<null>"));
+ out << ")";
+}
+
+
+EncryptionWithFooterKey::~EncryptionWithFooterKey() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const EncryptionWithFooterKey& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t EncryptionWithFooterKey::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t EncryptionWithFooterKey::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("EncryptionWithFooterKey");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(EncryptionWithFooterKey &a, EncryptionWithFooterKey &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+EncryptionWithFooterKey::EncryptionWithFooterKey(const EncryptionWithFooterKey& other110) {
+ (void) other110;
+}
+EncryptionWithFooterKey& EncryptionWithFooterKey::operator=(const EncryptionWithFooterKey& other111) {
+ (void) other111;
+ return *this;
+}
+void EncryptionWithFooterKey::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "EncryptionWithFooterKey(";
+ out << ")";
+}
+
+
+EncryptionWithColumnKey::~EncryptionWithColumnKey() noexcept {
+}
+
+
+void EncryptionWithColumnKey::__set_path_in_schema(const std::vector<std::string> & val) {
+ this->path_in_schema = val;
+}
+
+void EncryptionWithColumnKey::__set_key_metadata(const std::string& val) {
+ this->key_metadata = val;
+__isset.key_metadata = true;
+}
+std::ostream& operator<<(std::ostream& out, const EncryptionWithColumnKey& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t EncryptionWithColumnKey::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_path_in_schema = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->path_in_schema.clear();
+ uint32_t _size112;
+ ::apache::thrift::protocol::TType _etype115;
+ xfer += iprot->readListBegin(_etype115, _size112);
+ this->path_in_schema.resize(_size112);
+ uint32_t _i116;
+ for (_i116 = 0; _i116 < _size112; ++_i116)
+ {
+ xfer += iprot->readString(this->path_in_schema[_i116]);
+ }
+ xfer += iprot->readListEnd();
+ }
+ isset_path_in_schema = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readBinary(this->key_metadata);
+ this->__isset.key_metadata = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_path_in_schema)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t EncryptionWithColumnKey::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("EncryptionWithColumnKey");
+
+ xfer += oprot->writeFieldBegin("path_in_schema", ::apache::thrift::protocol::T_LIST, 1);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRING, static_cast<uint32_t>(this->path_in_schema.size()));
+ std::vector<std::string> ::const_iterator _iter117;
+ for (_iter117 = this->path_in_schema.begin(); _iter117 != this->path_in_schema.end(); ++_iter117)
+ {
+ xfer += oprot->writeString((*_iter117));
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+
+ if (this->__isset.key_metadata) {
+ xfer += oprot->writeFieldBegin("key_metadata", ::apache::thrift::protocol::T_STRING, 2);
+ xfer += oprot->writeBinary(this->key_metadata);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(EncryptionWithColumnKey &a, EncryptionWithColumnKey &b) {
+ using ::std::swap;
+ swap(a.path_in_schema, b.path_in_schema);
+ swap(a.key_metadata, b.key_metadata);
+ swap(a.__isset, b.__isset);
+}
+
+EncryptionWithColumnKey::EncryptionWithColumnKey(const EncryptionWithColumnKey& other118) {
+ path_in_schema = other118.path_in_schema;
+ key_metadata = other118.key_metadata;
+ __isset = other118.__isset;
+}
+EncryptionWithColumnKey& EncryptionWithColumnKey::operator=(const EncryptionWithColumnKey& other119) {
+ path_in_schema = other119.path_in_schema;
+ key_metadata = other119.key_metadata;
+ __isset = other119.__isset;
+ return *this;
+}
+void EncryptionWithColumnKey::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "EncryptionWithColumnKey(";
+ out << "path_in_schema=" << to_string(path_in_schema);
+ out << ", " << "key_metadata="; (__isset.key_metadata ? (out << to_string(key_metadata)) : (out << "<null>"));
+ out << ")";
+}
+
+
+ColumnCryptoMetaData::~ColumnCryptoMetaData() noexcept {
+}
+
+
+void ColumnCryptoMetaData::__set_ENCRYPTION_WITH_FOOTER_KEY(const EncryptionWithFooterKey& val) {
+ this->ENCRYPTION_WITH_FOOTER_KEY = val;
+__isset.ENCRYPTION_WITH_FOOTER_KEY = true;
+}
+
+void ColumnCryptoMetaData::__set_ENCRYPTION_WITH_COLUMN_KEY(const EncryptionWithColumnKey& val) {
+ this->ENCRYPTION_WITH_COLUMN_KEY = val;
+__isset.ENCRYPTION_WITH_COLUMN_KEY = true;
+}
+std::ostream& operator<<(std::ostream& out, const ColumnCryptoMetaData& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t ColumnCryptoMetaData::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->ENCRYPTION_WITH_FOOTER_KEY.read(iprot);
+ this->__isset.ENCRYPTION_WITH_FOOTER_KEY = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->ENCRYPTION_WITH_COLUMN_KEY.read(iprot);
+ this->__isset.ENCRYPTION_WITH_COLUMN_KEY = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t ColumnCryptoMetaData::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("ColumnCryptoMetaData");
+
+ if (this->__isset.ENCRYPTION_WITH_FOOTER_KEY) {
+ xfer += oprot->writeFieldBegin("ENCRYPTION_WITH_FOOTER_KEY", ::apache::thrift::protocol::T_STRUCT, 1);
+ xfer += this->ENCRYPTION_WITH_FOOTER_KEY.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.ENCRYPTION_WITH_COLUMN_KEY) {
+ xfer += oprot->writeFieldBegin("ENCRYPTION_WITH_COLUMN_KEY", ::apache::thrift::protocol::T_STRUCT, 2);
+ xfer += this->ENCRYPTION_WITH_COLUMN_KEY.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(ColumnCryptoMetaData &a, ColumnCryptoMetaData &b) {
+ using ::std::swap;
+ swap(a.ENCRYPTION_WITH_FOOTER_KEY, b.ENCRYPTION_WITH_FOOTER_KEY);
+ swap(a.ENCRYPTION_WITH_COLUMN_KEY, b.ENCRYPTION_WITH_COLUMN_KEY);
+ swap(a.__isset, b.__isset);
+}
+
+ColumnCryptoMetaData::ColumnCryptoMetaData(const ColumnCryptoMetaData& other120) {
+ ENCRYPTION_WITH_FOOTER_KEY = other120.ENCRYPTION_WITH_FOOTER_KEY;
+ ENCRYPTION_WITH_COLUMN_KEY = other120.ENCRYPTION_WITH_COLUMN_KEY;
+ __isset = other120.__isset;
+}
+ColumnCryptoMetaData& ColumnCryptoMetaData::operator=(const ColumnCryptoMetaData& other121) {
+ ENCRYPTION_WITH_FOOTER_KEY = other121.ENCRYPTION_WITH_FOOTER_KEY;
+ ENCRYPTION_WITH_COLUMN_KEY = other121.ENCRYPTION_WITH_COLUMN_KEY;
+ __isset = other121.__isset;
+ return *this;
+}
+void ColumnCryptoMetaData::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "ColumnCryptoMetaData(";
+ out << "ENCRYPTION_WITH_FOOTER_KEY="; (__isset.ENCRYPTION_WITH_FOOTER_KEY ? (out << to_string(ENCRYPTION_WITH_FOOTER_KEY)) : (out << "<null>"));
+ out << ", " << "ENCRYPTION_WITH_COLUMN_KEY="; (__isset.ENCRYPTION_WITH_COLUMN_KEY ? (out << to_string(ENCRYPTION_WITH_COLUMN_KEY)) : (out << "<null>"));
+ out << ")";
+}
+
+
+ColumnChunk::~ColumnChunk() noexcept {
+}
+
+
+void ColumnChunk::__set_file_path(const std::string& val) {
+ this->file_path = val;
+__isset.file_path = true;
+}
+
+void ColumnChunk::__set_file_offset(const int64_t val) {
+ this->file_offset = val;
+}
+
+void ColumnChunk::__set_meta_data(const ColumnMetaData& val) {
+ this->meta_data = val;
+__isset.meta_data = true;
+}
+
+void ColumnChunk::__set_offset_index_offset(const int64_t val) {
+ this->offset_index_offset = val;
+__isset.offset_index_offset = true;
+}
+
+void ColumnChunk::__set_offset_index_length(const int32_t val) {
+ this->offset_index_length = val;
+__isset.offset_index_length = true;
+}
+
+void ColumnChunk::__set_column_index_offset(const int64_t val) {
+ this->column_index_offset = val;
+__isset.column_index_offset = true;
+}
+
+void ColumnChunk::__set_column_index_length(const int32_t val) {
+ this->column_index_length = val;
+__isset.column_index_length = true;
+}
+
+void ColumnChunk::__set_crypto_metadata(const ColumnCryptoMetaData& val) {
+ this->crypto_metadata = val;
+__isset.crypto_metadata = true;
+}
+
+void ColumnChunk::__set_encrypted_column_metadata(const std::string& val) {
+ this->encrypted_column_metadata = val;
+__isset.encrypted_column_metadata = true;
+}
+std::ostream& operator<<(std::ostream& out, const ColumnChunk& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t ColumnChunk::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_file_offset = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readString(this->file_path);
+ this->__isset.file_path = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->file_offset);
+ isset_file_offset = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->meta_data.read(iprot);
+ this->__isset.meta_data = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 4:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->offset_index_offset);
+ this->__isset.offset_index_offset = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 5:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->offset_index_length);
+ this->__isset.offset_index_length = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 6:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->column_index_offset);
+ this->__isset.column_index_offset = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 7:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->column_index_length);
+ this->__isset.column_index_length = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 8:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->crypto_metadata.read(iprot);
+ this->__isset.crypto_metadata = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 9:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readBinary(this->encrypted_column_metadata);
+ this->__isset.encrypted_column_metadata = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_file_offset)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t ColumnChunk::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("ColumnChunk");
+
+ if (this->__isset.file_path) {
+ xfer += oprot->writeFieldBegin("file_path", ::apache::thrift::protocol::T_STRING, 1);
+ xfer += oprot->writeString(this->file_path);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldBegin("file_offset", ::apache::thrift::protocol::T_I64, 2);
+ xfer += oprot->writeI64(this->file_offset);
+ xfer += oprot->writeFieldEnd();
+
+ if (this->__isset.meta_data) {
+ xfer += oprot->writeFieldBegin("meta_data", ::apache::thrift::protocol::T_STRUCT, 3);
+ xfer += this->meta_data.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.offset_index_offset) {
+ xfer += oprot->writeFieldBegin("offset_index_offset", ::apache::thrift::protocol::T_I64, 4);
+ xfer += oprot->writeI64(this->offset_index_offset);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.offset_index_length) {
+ xfer += oprot->writeFieldBegin("offset_index_length", ::apache::thrift::protocol::T_I32, 5);
+ xfer += oprot->writeI32(this->offset_index_length);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.column_index_offset) {
+ xfer += oprot->writeFieldBegin("column_index_offset", ::apache::thrift::protocol::T_I64, 6);
+ xfer += oprot->writeI64(this->column_index_offset);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.column_index_length) {
+ xfer += oprot->writeFieldBegin("column_index_length", ::apache::thrift::protocol::T_I32, 7);
+ xfer += oprot->writeI32(this->column_index_length);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.crypto_metadata) {
+ xfer += oprot->writeFieldBegin("crypto_metadata", ::apache::thrift::protocol::T_STRUCT, 8);
+ xfer += this->crypto_metadata.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.encrypted_column_metadata) {
+ xfer += oprot->writeFieldBegin("encrypted_column_metadata", ::apache::thrift::protocol::T_STRING, 9);
+ xfer += oprot->writeBinary(this->encrypted_column_metadata);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(ColumnChunk &a, ColumnChunk &b) {
+ using ::std::swap;
+ swap(a.file_path, b.file_path);
+ swap(a.file_offset, b.file_offset);
+ swap(a.meta_data, b.meta_data);
+ swap(a.offset_index_offset, b.offset_index_offset);
+ swap(a.offset_index_length, b.offset_index_length);
+ swap(a.column_index_offset, b.column_index_offset);
+ swap(a.column_index_length, b.column_index_length);
+ swap(a.crypto_metadata, b.crypto_metadata);
+ swap(a.encrypted_column_metadata, b.encrypted_column_metadata);
+ swap(a.__isset, b.__isset);
+}
+
+ColumnChunk::ColumnChunk(const ColumnChunk& other122) {
+ file_path = other122.file_path;
+ file_offset = other122.file_offset;
+ meta_data = other122.meta_data;
+ offset_index_offset = other122.offset_index_offset;
+ offset_index_length = other122.offset_index_length;
+ column_index_offset = other122.column_index_offset;
+ column_index_length = other122.column_index_length;
+ crypto_metadata = other122.crypto_metadata;
+ encrypted_column_metadata = other122.encrypted_column_metadata;
+ __isset = other122.__isset;
+}
+ColumnChunk& ColumnChunk::operator=(const ColumnChunk& other123) {
+ file_path = other123.file_path;
+ file_offset = other123.file_offset;
+ meta_data = other123.meta_data;
+ offset_index_offset = other123.offset_index_offset;
+ offset_index_length = other123.offset_index_length;
+ column_index_offset = other123.column_index_offset;
+ column_index_length = other123.column_index_length;
+ crypto_metadata = other123.crypto_metadata;
+ encrypted_column_metadata = other123.encrypted_column_metadata;
+ __isset = other123.__isset;
+ return *this;
+}
+void ColumnChunk::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "ColumnChunk(";
+ out << "file_path="; (__isset.file_path ? (out << to_string(file_path)) : (out << "<null>"));
+ out << ", " << "file_offset=" << to_string(file_offset);
+ out << ", " << "meta_data="; (__isset.meta_data ? (out << to_string(meta_data)) : (out << "<null>"));
+ out << ", " << "offset_index_offset="; (__isset.offset_index_offset ? (out << to_string(offset_index_offset)) : (out << "<null>"));
+ out << ", " << "offset_index_length="; (__isset.offset_index_length ? (out << to_string(offset_index_length)) : (out << "<null>"));
+ out << ", " << "column_index_offset="; (__isset.column_index_offset ? (out << to_string(column_index_offset)) : (out << "<null>"));
+ out << ", " << "column_index_length="; (__isset.column_index_length ? (out << to_string(column_index_length)) : (out << "<null>"));
+ out << ", " << "crypto_metadata="; (__isset.crypto_metadata ? (out << to_string(crypto_metadata)) : (out << "<null>"));
+ out << ", " << "encrypted_column_metadata="; (__isset.encrypted_column_metadata ? (out << to_string(encrypted_column_metadata)) : (out << "<null>"));
+ out << ")";
+}
+
+
+RowGroup::~RowGroup() noexcept {
+}
+
+
+void RowGroup::__set_columns(const std::vector<ColumnChunk> & val) {
+ this->columns = val;
+}
+
+void RowGroup::__set_total_byte_size(const int64_t val) {
+ this->total_byte_size = val;
+}
+
+void RowGroup::__set_num_rows(const int64_t val) {
+ this->num_rows = val;
+}
+
+void RowGroup::__set_sorting_columns(const std::vector<SortingColumn> & val) {
+ this->sorting_columns = val;
+__isset.sorting_columns = true;
+}
+
+void RowGroup::__set_file_offset(const int64_t val) {
+ this->file_offset = val;
+__isset.file_offset = true;
+}
+
+void RowGroup::__set_total_compressed_size(const int64_t val) {
+ this->total_compressed_size = val;
+__isset.total_compressed_size = true;
+}
+
+void RowGroup::__set_ordinal(const int16_t val) {
+ this->ordinal = val;
+__isset.ordinal = true;
+}
+std::ostream& operator<<(std::ostream& out, const RowGroup& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t RowGroup::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_columns = false;
+ bool isset_total_byte_size = false;
+ bool isset_num_rows = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->columns.clear();
+ uint32_t _size124;
+ ::apache::thrift::protocol::TType _etype127;
+ xfer += iprot->readListBegin(_etype127, _size124);
+ this->columns.resize(_size124);
+ uint32_t _i128;
+ for (_i128 = 0; _i128 < _size124; ++_i128)
+ {
+ xfer += this->columns[_i128].read(iprot);
+ }
+ xfer += iprot->readListEnd();
+ }
+ isset_columns = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->total_byte_size);
+ isset_total_byte_size = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->num_rows);
+ isset_num_rows = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 4:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->sorting_columns.clear();
+ uint32_t _size129;
+ ::apache::thrift::protocol::TType _etype132;
+ xfer += iprot->readListBegin(_etype132, _size129);
+ this->sorting_columns.resize(_size129);
+ uint32_t _i133;
+ for (_i133 = 0; _i133 < _size129; ++_i133)
+ {
+ xfer += this->sorting_columns[_i133].read(iprot);
+ }
+ xfer += iprot->readListEnd();
+ }
+ this->__isset.sorting_columns = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 5:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->file_offset);
+ this->__isset.file_offset = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 6:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->total_compressed_size);
+ this->__isset.total_compressed_size = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 7:
+ if (ftype == ::apache::thrift::protocol::T_I16) {
+ xfer += iprot->readI16(this->ordinal);
+ this->__isset.ordinal = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_columns)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_total_byte_size)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_num_rows)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t RowGroup::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("RowGroup");
+
+ xfer += oprot->writeFieldBegin("columns", ::apache::thrift::protocol::T_LIST, 1);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast<uint32_t>(this->columns.size()));
+ std::vector<ColumnChunk> ::const_iterator _iter134;
+ for (_iter134 = this->columns.begin(); _iter134 != this->columns.end(); ++_iter134)
+ {
+ xfer += (*_iter134).write(oprot);
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("total_byte_size", ::apache::thrift::protocol::T_I64, 2);
+ xfer += oprot->writeI64(this->total_byte_size);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("num_rows", ::apache::thrift::protocol::T_I64, 3);
+ xfer += oprot->writeI64(this->num_rows);
+ xfer += oprot->writeFieldEnd();
+
+ if (this->__isset.sorting_columns) {
+ xfer += oprot->writeFieldBegin("sorting_columns", ::apache::thrift::protocol::T_LIST, 4);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast<uint32_t>(this->sorting_columns.size()));
+ std::vector<SortingColumn> ::const_iterator _iter135;
+ for (_iter135 = this->sorting_columns.begin(); _iter135 != this->sorting_columns.end(); ++_iter135)
+ {
+ xfer += (*_iter135).write(oprot);
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.file_offset) {
+ xfer += oprot->writeFieldBegin("file_offset", ::apache::thrift::protocol::T_I64, 5);
+ xfer += oprot->writeI64(this->file_offset);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.total_compressed_size) {
+ xfer += oprot->writeFieldBegin("total_compressed_size", ::apache::thrift::protocol::T_I64, 6);
+ xfer += oprot->writeI64(this->total_compressed_size);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.ordinal) {
+ xfer += oprot->writeFieldBegin("ordinal", ::apache::thrift::protocol::T_I16, 7);
+ xfer += oprot->writeI16(this->ordinal);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(RowGroup &a, RowGroup &b) {
+ using ::std::swap;
+ swap(a.columns, b.columns);
+ swap(a.total_byte_size, b.total_byte_size);
+ swap(a.num_rows, b.num_rows);
+ swap(a.sorting_columns, b.sorting_columns);
+ swap(a.file_offset, b.file_offset);
+ swap(a.total_compressed_size, b.total_compressed_size);
+ swap(a.ordinal, b.ordinal);
+ swap(a.__isset, b.__isset);
+}
+
+RowGroup::RowGroup(const RowGroup& other136) {
+ columns = other136.columns;
+ total_byte_size = other136.total_byte_size;
+ num_rows = other136.num_rows;
+ sorting_columns = other136.sorting_columns;
+ file_offset = other136.file_offset;
+ total_compressed_size = other136.total_compressed_size;
+ ordinal = other136.ordinal;
+ __isset = other136.__isset;
+}
+RowGroup& RowGroup::operator=(const RowGroup& other137) {
+ columns = other137.columns;
+ total_byte_size = other137.total_byte_size;
+ num_rows = other137.num_rows;
+ sorting_columns = other137.sorting_columns;
+ file_offset = other137.file_offset;
+ total_compressed_size = other137.total_compressed_size;
+ ordinal = other137.ordinal;
+ __isset = other137.__isset;
+ return *this;
+}
+void RowGroup::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "RowGroup(";
+ out << "columns=" << to_string(columns);
+ out << ", " << "total_byte_size=" << to_string(total_byte_size);
+ out << ", " << "num_rows=" << to_string(num_rows);
+ out << ", " << "sorting_columns="; (__isset.sorting_columns ? (out << to_string(sorting_columns)) : (out << "<null>"));
+ out << ", " << "file_offset="; (__isset.file_offset ? (out << to_string(file_offset)) : (out << "<null>"));
+ out << ", " << "total_compressed_size="; (__isset.total_compressed_size ? (out << to_string(total_compressed_size)) : (out << "<null>"));
+ out << ", " << "ordinal="; (__isset.ordinal ? (out << to_string(ordinal)) : (out << "<null>"));
+ out << ")";
+}
+
+
+TypeDefinedOrder::~TypeDefinedOrder() noexcept {
+}
+
+std::ostream& operator<<(std::ostream& out, const TypeDefinedOrder& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t TypeDefinedOrder::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ xfer += iprot->skip(ftype);
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t TypeDefinedOrder::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("TypeDefinedOrder");
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(TypeDefinedOrder &a, TypeDefinedOrder &b) {
+ using ::std::swap;
+ (void) a;
+ (void) b;
+}
+
+TypeDefinedOrder::TypeDefinedOrder(const TypeDefinedOrder& other138) {
+ (void) other138;
+}
+TypeDefinedOrder& TypeDefinedOrder::operator=(const TypeDefinedOrder& other139) {
+ (void) other139;
+ return *this;
+}
+void TypeDefinedOrder::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "TypeDefinedOrder(";
+ out << ")";
+}
+
+
+ColumnOrder::~ColumnOrder() noexcept {
+}
+
+
+void ColumnOrder::__set_TYPE_ORDER(const TypeDefinedOrder& val) {
+ this->TYPE_ORDER = val;
+__isset.TYPE_ORDER = true;
+}
+std::ostream& operator<<(std::ostream& out, const ColumnOrder& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t ColumnOrder::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->TYPE_ORDER.read(iprot);
+ this->__isset.TYPE_ORDER = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t ColumnOrder::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("ColumnOrder");
+
+ if (this->__isset.TYPE_ORDER) {
+ xfer += oprot->writeFieldBegin("TYPE_ORDER", ::apache::thrift::protocol::T_STRUCT, 1);
+ xfer += this->TYPE_ORDER.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(ColumnOrder &a, ColumnOrder &b) {
+ using ::std::swap;
+ swap(a.TYPE_ORDER, b.TYPE_ORDER);
+ swap(a.__isset, b.__isset);
+}
+
+ColumnOrder::ColumnOrder(const ColumnOrder& other140) {
+ TYPE_ORDER = other140.TYPE_ORDER;
+ __isset = other140.__isset;
+}
+ColumnOrder& ColumnOrder::operator=(const ColumnOrder& other141) {
+ TYPE_ORDER = other141.TYPE_ORDER;
+ __isset = other141.__isset;
+ return *this;
+}
+void ColumnOrder::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "ColumnOrder(";
+ out << "TYPE_ORDER="; (__isset.TYPE_ORDER ? (out << to_string(TYPE_ORDER)) : (out << "<null>"));
+ out << ")";
+}
+
+
+PageLocation::~PageLocation() noexcept {
+}
+
+
+void PageLocation::__set_offset(const int64_t val) {
+ this->offset = val;
+}
+
+void PageLocation::__set_compressed_page_size(const int32_t val) {
+ this->compressed_page_size = val;
+}
+
+void PageLocation::__set_first_row_index(const int64_t val) {
+ this->first_row_index = val;
+}
+std::ostream& operator<<(std::ostream& out, const PageLocation& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t PageLocation::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_offset = false;
+ bool isset_compressed_page_size = false;
+ bool isset_first_row_index = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->offset);
+ isset_offset = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->compressed_page_size);
+ isset_compressed_page_size = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->first_row_index);
+ isset_first_row_index = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_offset)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_compressed_page_size)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_first_row_index)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t PageLocation::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("PageLocation");
+
+ xfer += oprot->writeFieldBegin("offset", ::apache::thrift::protocol::T_I64, 1);
+ xfer += oprot->writeI64(this->offset);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("compressed_page_size", ::apache::thrift::protocol::T_I32, 2);
+ xfer += oprot->writeI32(this->compressed_page_size);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("first_row_index", ::apache::thrift::protocol::T_I64, 3);
+ xfer += oprot->writeI64(this->first_row_index);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(PageLocation &a, PageLocation &b) {
+ using ::std::swap;
+ swap(a.offset, b.offset);
+ swap(a.compressed_page_size, b.compressed_page_size);
+ swap(a.first_row_index, b.first_row_index);
+}
+
+PageLocation::PageLocation(const PageLocation& other142) {
+ offset = other142.offset;
+ compressed_page_size = other142.compressed_page_size;
+ first_row_index = other142.first_row_index;
+}
+PageLocation& PageLocation::operator=(const PageLocation& other143) {
+ offset = other143.offset;
+ compressed_page_size = other143.compressed_page_size;
+ first_row_index = other143.first_row_index;
+ return *this;
+}
+void PageLocation::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "PageLocation(";
+ out << "offset=" << to_string(offset);
+ out << ", " << "compressed_page_size=" << to_string(compressed_page_size);
+ out << ", " << "first_row_index=" << to_string(first_row_index);
+ out << ")";
+}
+
+
+OffsetIndex::~OffsetIndex() noexcept {
+}
+
+
+void OffsetIndex::__set_page_locations(const std::vector<PageLocation> & val) {
+ this->page_locations = val;
+}
+std::ostream& operator<<(std::ostream& out, const OffsetIndex& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t OffsetIndex::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_page_locations = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->page_locations.clear();
+ uint32_t _size144;
+ ::apache::thrift::protocol::TType _etype147;
+ xfer += iprot->readListBegin(_etype147, _size144);
+ this->page_locations.resize(_size144);
+ uint32_t _i148;
+ for (_i148 = 0; _i148 < _size144; ++_i148)
+ {
+ xfer += this->page_locations[_i148].read(iprot);
+ }
+ xfer += iprot->readListEnd();
+ }
+ isset_page_locations = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_page_locations)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t OffsetIndex::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("OffsetIndex");
+
+ xfer += oprot->writeFieldBegin("page_locations", ::apache::thrift::protocol::T_LIST, 1);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast<uint32_t>(this->page_locations.size()));
+ std::vector<PageLocation> ::const_iterator _iter149;
+ for (_iter149 = this->page_locations.begin(); _iter149 != this->page_locations.end(); ++_iter149)
+ {
+ xfer += (*_iter149).write(oprot);
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(OffsetIndex &a, OffsetIndex &b) {
+ using ::std::swap;
+ swap(a.page_locations, b.page_locations);
+}
+
+OffsetIndex::OffsetIndex(const OffsetIndex& other150) {
+ page_locations = other150.page_locations;
+}
+OffsetIndex& OffsetIndex::operator=(const OffsetIndex& other151) {
+ page_locations = other151.page_locations;
+ return *this;
+}
+void OffsetIndex::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "OffsetIndex(";
+ out << "page_locations=" << to_string(page_locations);
+ out << ")";
+}
+
+
+ColumnIndex::~ColumnIndex() noexcept {
+}
+
+
+void ColumnIndex::__set_null_pages(const std::vector<bool> & val) {
+ this->null_pages = val;
+}
+
+void ColumnIndex::__set_min_values(const std::vector<std::string> & val) {
+ this->min_values = val;
+}
+
+void ColumnIndex::__set_max_values(const std::vector<std::string> & val) {
+ this->max_values = val;
+}
+
+void ColumnIndex::__set_boundary_order(const BoundaryOrder::type val) {
+ this->boundary_order = val;
+}
+
+void ColumnIndex::__set_null_counts(const std::vector<int64_t> & val) {
+ this->null_counts = val;
+__isset.null_counts = true;
+}
+std::ostream& operator<<(std::ostream& out, const ColumnIndex& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t ColumnIndex::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_null_pages = false;
+ bool isset_min_values = false;
+ bool isset_max_values = false;
+ bool isset_boundary_order = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->null_pages.clear();
+ uint32_t _size152;
+ ::apache::thrift::protocol::TType _etype155;
+ xfer += iprot->readListBegin(_etype155, _size152);
+ this->null_pages.resize(_size152);
+ uint32_t _i156;
+ for (_i156 = 0; _i156 < _size152; ++_i156)
+ {
+ xfer += iprot->readBool(this->null_pages[_i156]);
+ }
+ xfer += iprot->readListEnd();
+ }
+ isset_null_pages = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->min_values.clear();
+ uint32_t _size157;
+ ::apache::thrift::protocol::TType _etype160;
+ xfer += iprot->readListBegin(_etype160, _size157);
+ this->min_values.resize(_size157);
+ uint32_t _i161;
+ for (_i161 = 0; _i161 < _size157; ++_i161)
+ {
+ xfer += iprot->readBinary(this->min_values[_i161]);
+ }
+ xfer += iprot->readListEnd();
+ }
+ isset_min_values = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->max_values.clear();
+ uint32_t _size162;
+ ::apache::thrift::protocol::TType _etype165;
+ xfer += iprot->readListBegin(_etype165, _size162);
+ this->max_values.resize(_size162);
+ uint32_t _i166;
+ for (_i166 = 0; _i166 < _size162; ++_i166)
+ {
+ xfer += iprot->readBinary(this->max_values[_i166]);
+ }
+ xfer += iprot->readListEnd();
+ }
+ isset_max_values = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 4:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ int32_t ecast167;
+ xfer += iprot->readI32(ecast167);
+ this->boundary_order = (BoundaryOrder::type)ecast167;
+ isset_boundary_order = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 5:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->null_counts.clear();
+ uint32_t _size168;
+ ::apache::thrift::protocol::TType _etype171;
+ xfer += iprot->readListBegin(_etype171, _size168);
+ this->null_counts.resize(_size168);
+ uint32_t _i172;
+ for (_i172 = 0; _i172 < _size168; ++_i172)
+ {
+ xfer += iprot->readI64(this->null_counts[_i172]);
+ }
+ xfer += iprot->readListEnd();
+ }
+ this->__isset.null_counts = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_null_pages)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_min_values)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_max_values)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_boundary_order)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t ColumnIndex::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("ColumnIndex");
+
+ xfer += oprot->writeFieldBegin("null_pages", ::apache::thrift::protocol::T_LIST, 1);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_BOOL, static_cast<uint32_t>(this->null_pages.size()));
+ std::vector<bool> ::const_iterator _iter173;
+ for (_iter173 = this->null_pages.begin(); _iter173 != this->null_pages.end(); ++_iter173)
+ {
+ xfer += oprot->writeBool((*_iter173));
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("min_values", ::apache::thrift::protocol::T_LIST, 2);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRING, static_cast<uint32_t>(this->min_values.size()));
+ std::vector<std::string> ::const_iterator _iter174;
+ for (_iter174 = this->min_values.begin(); _iter174 != this->min_values.end(); ++_iter174)
+ {
+ xfer += oprot->writeBinary((*_iter174));
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("max_values", ::apache::thrift::protocol::T_LIST, 3);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRING, static_cast<uint32_t>(this->max_values.size()));
+ std::vector<std::string> ::const_iterator _iter175;
+ for (_iter175 = this->max_values.begin(); _iter175 != this->max_values.end(); ++_iter175)
+ {
+ xfer += oprot->writeBinary((*_iter175));
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("boundary_order", ::apache::thrift::protocol::T_I32, 4);
+ xfer += oprot->writeI32((int32_t)this->boundary_order);
+ xfer += oprot->writeFieldEnd();
+
+ if (this->__isset.null_counts) {
+ xfer += oprot->writeFieldBegin("null_counts", ::apache::thrift::protocol::T_LIST, 5);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_I64, static_cast<uint32_t>(this->null_counts.size()));
+ std::vector<int64_t> ::const_iterator _iter176;
+ for (_iter176 = this->null_counts.begin(); _iter176 != this->null_counts.end(); ++_iter176)
+ {
+ xfer += oprot->writeI64((*_iter176));
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(ColumnIndex &a, ColumnIndex &b) {
+ using ::std::swap;
+ swap(a.null_pages, b.null_pages);
+ swap(a.min_values, b.min_values);
+ swap(a.max_values, b.max_values);
+ swap(a.boundary_order, b.boundary_order);
+ swap(a.null_counts, b.null_counts);
+ swap(a.__isset, b.__isset);
+}
+
+ColumnIndex::ColumnIndex(const ColumnIndex& other177) {
+ null_pages = other177.null_pages;
+ min_values = other177.min_values;
+ max_values = other177.max_values;
+ boundary_order = other177.boundary_order;
+ null_counts = other177.null_counts;
+ __isset = other177.__isset;
+}
+ColumnIndex& ColumnIndex::operator=(const ColumnIndex& other178) {
+ null_pages = other178.null_pages;
+ min_values = other178.min_values;
+ max_values = other178.max_values;
+ boundary_order = other178.boundary_order;
+ null_counts = other178.null_counts;
+ __isset = other178.__isset;
+ return *this;
+}
+void ColumnIndex::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "ColumnIndex(";
+ out << "null_pages=" << to_string(null_pages);
+ out << ", " << "min_values=" << to_string(min_values);
+ out << ", " << "max_values=" << to_string(max_values);
+ out << ", " << "boundary_order=" << to_string(boundary_order);
+ out << ", " << "null_counts="; (__isset.null_counts ? (out << to_string(null_counts)) : (out << "<null>"));
+ out << ")";
+}
+
+
+AesGcmV1::~AesGcmV1() noexcept {
+}
+
+
+void AesGcmV1::__set_aad_prefix(const std::string& val) {
+ this->aad_prefix = val;
+__isset.aad_prefix = true;
+}
+
+void AesGcmV1::__set_aad_file_unique(const std::string& val) {
+ this->aad_file_unique = val;
+__isset.aad_file_unique = true;
+}
+
+void AesGcmV1::__set_supply_aad_prefix(const bool val) {
+ this->supply_aad_prefix = val;
+__isset.supply_aad_prefix = true;
+}
+std::ostream& operator<<(std::ostream& out, const AesGcmV1& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t AesGcmV1::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readBinary(this->aad_prefix);
+ this->__isset.aad_prefix = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readBinary(this->aad_file_unique);
+ this->__isset.aad_file_unique = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_BOOL) {
+ xfer += iprot->readBool(this->supply_aad_prefix);
+ this->__isset.supply_aad_prefix = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t AesGcmV1::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("AesGcmV1");
+
+ if (this->__isset.aad_prefix) {
+ xfer += oprot->writeFieldBegin("aad_prefix", ::apache::thrift::protocol::T_STRING, 1);
+ xfer += oprot->writeBinary(this->aad_prefix);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.aad_file_unique) {
+ xfer += oprot->writeFieldBegin("aad_file_unique", ::apache::thrift::protocol::T_STRING, 2);
+ xfer += oprot->writeBinary(this->aad_file_unique);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.supply_aad_prefix) {
+ xfer += oprot->writeFieldBegin("supply_aad_prefix", ::apache::thrift::protocol::T_BOOL, 3);
+ xfer += oprot->writeBool(this->supply_aad_prefix);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(AesGcmV1 &a, AesGcmV1 &b) {
+ using ::std::swap;
+ swap(a.aad_prefix, b.aad_prefix);
+ swap(a.aad_file_unique, b.aad_file_unique);
+ swap(a.supply_aad_prefix, b.supply_aad_prefix);
+ swap(a.__isset, b.__isset);
+}
+
+AesGcmV1::AesGcmV1(const AesGcmV1& other179) {
+ aad_prefix = other179.aad_prefix;
+ aad_file_unique = other179.aad_file_unique;
+ supply_aad_prefix = other179.supply_aad_prefix;
+ __isset = other179.__isset;
+}
+AesGcmV1& AesGcmV1::operator=(const AesGcmV1& other180) {
+ aad_prefix = other180.aad_prefix;
+ aad_file_unique = other180.aad_file_unique;
+ supply_aad_prefix = other180.supply_aad_prefix;
+ __isset = other180.__isset;
+ return *this;
+}
+void AesGcmV1::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "AesGcmV1(";
+ out << "aad_prefix="; (__isset.aad_prefix ? (out << to_string(aad_prefix)) : (out << "<null>"));
+ out << ", " << "aad_file_unique="; (__isset.aad_file_unique ? (out << to_string(aad_file_unique)) : (out << "<null>"));
+ out << ", " << "supply_aad_prefix="; (__isset.supply_aad_prefix ? (out << to_string(supply_aad_prefix)) : (out << "<null>"));
+ out << ")";
+}
+
+
+AesGcmCtrV1::~AesGcmCtrV1() noexcept {
+}
+
+
+void AesGcmCtrV1::__set_aad_prefix(const std::string& val) {
+ this->aad_prefix = val;
+__isset.aad_prefix = true;
+}
+
+void AesGcmCtrV1::__set_aad_file_unique(const std::string& val) {
+ this->aad_file_unique = val;
+__isset.aad_file_unique = true;
+}
+
+void AesGcmCtrV1::__set_supply_aad_prefix(const bool val) {
+ this->supply_aad_prefix = val;
+__isset.supply_aad_prefix = true;
+}
+std::ostream& operator<<(std::ostream& out, const AesGcmCtrV1& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t AesGcmCtrV1::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readBinary(this->aad_prefix);
+ this->__isset.aad_prefix = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readBinary(this->aad_file_unique);
+ this->__isset.aad_file_unique = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_BOOL) {
+ xfer += iprot->readBool(this->supply_aad_prefix);
+ this->__isset.supply_aad_prefix = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t AesGcmCtrV1::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("AesGcmCtrV1");
+
+ if (this->__isset.aad_prefix) {
+ xfer += oprot->writeFieldBegin("aad_prefix", ::apache::thrift::protocol::T_STRING, 1);
+ xfer += oprot->writeBinary(this->aad_prefix);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.aad_file_unique) {
+ xfer += oprot->writeFieldBegin("aad_file_unique", ::apache::thrift::protocol::T_STRING, 2);
+ xfer += oprot->writeBinary(this->aad_file_unique);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.supply_aad_prefix) {
+ xfer += oprot->writeFieldBegin("supply_aad_prefix", ::apache::thrift::protocol::T_BOOL, 3);
+ xfer += oprot->writeBool(this->supply_aad_prefix);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(AesGcmCtrV1 &a, AesGcmCtrV1 &b) {
+ using ::std::swap;
+ swap(a.aad_prefix, b.aad_prefix);
+ swap(a.aad_file_unique, b.aad_file_unique);
+ swap(a.supply_aad_prefix, b.supply_aad_prefix);
+ swap(a.__isset, b.__isset);
+}
+
+AesGcmCtrV1::AesGcmCtrV1(const AesGcmCtrV1& other181) {
+ aad_prefix = other181.aad_prefix;
+ aad_file_unique = other181.aad_file_unique;
+ supply_aad_prefix = other181.supply_aad_prefix;
+ __isset = other181.__isset;
+}
+AesGcmCtrV1& AesGcmCtrV1::operator=(const AesGcmCtrV1& other182) {
+ aad_prefix = other182.aad_prefix;
+ aad_file_unique = other182.aad_file_unique;
+ supply_aad_prefix = other182.supply_aad_prefix;
+ __isset = other182.__isset;
+ return *this;
+}
+void AesGcmCtrV1::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "AesGcmCtrV1(";
+ out << "aad_prefix="; (__isset.aad_prefix ? (out << to_string(aad_prefix)) : (out << "<null>"));
+ out << ", " << "aad_file_unique="; (__isset.aad_file_unique ? (out << to_string(aad_file_unique)) : (out << "<null>"));
+ out << ", " << "supply_aad_prefix="; (__isset.supply_aad_prefix ? (out << to_string(supply_aad_prefix)) : (out << "<null>"));
+ out << ")";
+}
+
+
+EncryptionAlgorithm::~EncryptionAlgorithm() noexcept {
+}
+
+
+void EncryptionAlgorithm::__set_AES_GCM_V1(const AesGcmV1& val) {
+ this->AES_GCM_V1 = val;
+__isset.AES_GCM_V1 = true;
+}
+
+void EncryptionAlgorithm::__set_AES_GCM_CTR_V1(const AesGcmCtrV1& val) {
+ this->AES_GCM_CTR_V1 = val;
+__isset.AES_GCM_CTR_V1 = true;
+}
+std::ostream& operator<<(std::ostream& out, const EncryptionAlgorithm& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t EncryptionAlgorithm::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->AES_GCM_V1.read(iprot);
+ this->__isset.AES_GCM_V1 = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->AES_GCM_CTR_V1.read(iprot);
+ this->__isset.AES_GCM_CTR_V1 = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ return xfer;
+}
+
+uint32_t EncryptionAlgorithm::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("EncryptionAlgorithm");
+
+ if (this->__isset.AES_GCM_V1) {
+ xfer += oprot->writeFieldBegin("AES_GCM_V1", ::apache::thrift::protocol::T_STRUCT, 1);
+ xfer += this->AES_GCM_V1.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.AES_GCM_CTR_V1) {
+ xfer += oprot->writeFieldBegin("AES_GCM_CTR_V1", ::apache::thrift::protocol::T_STRUCT, 2);
+ xfer += this->AES_GCM_CTR_V1.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(EncryptionAlgorithm &a, EncryptionAlgorithm &b) {
+ using ::std::swap;
+ swap(a.AES_GCM_V1, b.AES_GCM_V1);
+ swap(a.AES_GCM_CTR_V1, b.AES_GCM_CTR_V1);
+ swap(a.__isset, b.__isset);
+}
+
+EncryptionAlgorithm::EncryptionAlgorithm(const EncryptionAlgorithm& other183) {
+ AES_GCM_V1 = other183.AES_GCM_V1;
+ AES_GCM_CTR_V1 = other183.AES_GCM_CTR_V1;
+ __isset = other183.__isset;
+}
+EncryptionAlgorithm& EncryptionAlgorithm::operator=(const EncryptionAlgorithm& other184) {
+ AES_GCM_V1 = other184.AES_GCM_V1;
+ AES_GCM_CTR_V1 = other184.AES_GCM_CTR_V1;
+ __isset = other184.__isset;
+ return *this;
+}
+void EncryptionAlgorithm::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "EncryptionAlgorithm(";
+ out << "AES_GCM_V1="; (__isset.AES_GCM_V1 ? (out << to_string(AES_GCM_V1)) : (out << "<null>"));
+ out << ", " << "AES_GCM_CTR_V1="; (__isset.AES_GCM_CTR_V1 ? (out << to_string(AES_GCM_CTR_V1)) : (out << "<null>"));
+ out << ")";
+}
+
+
+FileMetaData::~FileMetaData() noexcept {
+}
+
+
+void FileMetaData::__set_version(const int32_t val) {
+ this->version = val;
+}
+
+void FileMetaData::__set_schema(const std::vector<SchemaElement> & val) {
+ this->schema = val;
+}
+
+void FileMetaData::__set_num_rows(const int64_t val) {
+ this->num_rows = val;
+}
+
+void FileMetaData::__set_row_groups(const std::vector<RowGroup> & val) {
+ this->row_groups = val;
+}
+
+void FileMetaData::__set_key_value_metadata(const std::vector<KeyValue> & val) {
+ this->key_value_metadata = val;
+__isset.key_value_metadata = true;
+}
+
+void FileMetaData::__set_created_by(const std::string& val) {
+ this->created_by = val;
+__isset.created_by = true;
+}
+
+void FileMetaData::__set_column_orders(const std::vector<ColumnOrder> & val) {
+ this->column_orders = val;
+__isset.column_orders = true;
+}
+
+void FileMetaData::__set_encryption_algorithm(const EncryptionAlgorithm& val) {
+ this->encryption_algorithm = val;
+__isset.encryption_algorithm = true;
+}
+
+void FileMetaData::__set_footer_signing_key_metadata(const std::string& val) {
+ this->footer_signing_key_metadata = val;
+__isset.footer_signing_key_metadata = true;
+}
+std::ostream& operator<<(std::ostream& out, const FileMetaData& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t FileMetaData::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_version = false;
+ bool isset_schema = false;
+ bool isset_num_rows = false;
+ bool isset_row_groups = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_I32) {
+ xfer += iprot->readI32(this->version);
+ isset_version = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->schema.clear();
+ uint32_t _size185;
+ ::apache::thrift::protocol::TType _etype188;
+ xfer += iprot->readListBegin(_etype188, _size185);
+ this->schema.resize(_size185);
+ uint32_t _i189;
+ for (_i189 = 0; _i189 < _size185; ++_i189)
+ {
+ xfer += this->schema[_i189].read(iprot);
+ }
+ xfer += iprot->readListEnd();
+ }
+ isset_schema = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 3:
+ if (ftype == ::apache::thrift::protocol::T_I64) {
+ xfer += iprot->readI64(this->num_rows);
+ isset_num_rows = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 4:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->row_groups.clear();
+ uint32_t _size190;
+ ::apache::thrift::protocol::TType _etype193;
+ xfer += iprot->readListBegin(_etype193, _size190);
+ this->row_groups.resize(_size190);
+ uint32_t _i194;
+ for (_i194 = 0; _i194 < _size190; ++_i194)
+ {
+ xfer += this->row_groups[_i194].read(iprot);
+ }
+ xfer += iprot->readListEnd();
+ }
+ isset_row_groups = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 5:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->key_value_metadata.clear();
+ uint32_t _size195;
+ ::apache::thrift::protocol::TType _etype198;
+ xfer += iprot->readListBegin(_etype198, _size195);
+ this->key_value_metadata.resize(_size195);
+ uint32_t _i199;
+ for (_i199 = 0; _i199 < _size195; ++_i199)
+ {
+ xfer += this->key_value_metadata[_i199].read(iprot);
+ }
+ xfer += iprot->readListEnd();
+ }
+ this->__isset.key_value_metadata = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 6:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readString(this->created_by);
+ this->__isset.created_by = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 7:
+ if (ftype == ::apache::thrift::protocol::T_LIST) {
+ {
+ this->column_orders.clear();
+ uint32_t _size200;
+ ::apache::thrift::protocol::TType _etype203;
+ xfer += iprot->readListBegin(_etype203, _size200);
+ this->column_orders.resize(_size200);
+ uint32_t _i204;
+ for (_i204 = 0; _i204 < _size200; ++_i204)
+ {
+ xfer += this->column_orders[_i204].read(iprot);
+ }
+ xfer += iprot->readListEnd();
+ }
+ this->__isset.column_orders = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 8:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->encryption_algorithm.read(iprot);
+ this->__isset.encryption_algorithm = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 9:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readBinary(this->footer_signing_key_metadata);
+ this->__isset.footer_signing_key_metadata = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_version)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_schema)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_num_rows)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ if (!isset_row_groups)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t FileMetaData::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("FileMetaData");
+
+ xfer += oprot->writeFieldBegin("version", ::apache::thrift::protocol::T_I32, 1);
+ xfer += oprot->writeI32(this->version);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("schema", ::apache::thrift::protocol::T_LIST, 2);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast<uint32_t>(this->schema.size()));
+ std::vector<SchemaElement> ::const_iterator _iter205;
+ for (_iter205 = this->schema.begin(); _iter205 != this->schema.end(); ++_iter205)
+ {
+ xfer += (*_iter205).write(oprot);
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("num_rows", ::apache::thrift::protocol::T_I64, 3);
+ xfer += oprot->writeI64(this->num_rows);
+ xfer += oprot->writeFieldEnd();
+
+ xfer += oprot->writeFieldBegin("row_groups", ::apache::thrift::protocol::T_LIST, 4);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast<uint32_t>(this->row_groups.size()));
+ std::vector<RowGroup> ::const_iterator _iter206;
+ for (_iter206 = this->row_groups.begin(); _iter206 != this->row_groups.end(); ++_iter206)
+ {
+ xfer += (*_iter206).write(oprot);
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+
+ if (this->__isset.key_value_metadata) {
+ xfer += oprot->writeFieldBegin("key_value_metadata", ::apache::thrift::protocol::T_LIST, 5);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast<uint32_t>(this->key_value_metadata.size()));
+ std::vector<KeyValue> ::const_iterator _iter207;
+ for (_iter207 = this->key_value_metadata.begin(); _iter207 != this->key_value_metadata.end(); ++_iter207)
+ {
+ xfer += (*_iter207).write(oprot);
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.created_by) {
+ xfer += oprot->writeFieldBegin("created_by", ::apache::thrift::protocol::T_STRING, 6);
+ xfer += oprot->writeString(this->created_by);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.column_orders) {
+ xfer += oprot->writeFieldBegin("column_orders", ::apache::thrift::protocol::T_LIST, 7);
+ {
+ xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast<uint32_t>(this->column_orders.size()));
+ std::vector<ColumnOrder> ::const_iterator _iter208;
+ for (_iter208 = this->column_orders.begin(); _iter208 != this->column_orders.end(); ++_iter208)
+ {
+ xfer += (*_iter208).write(oprot);
+ }
+ xfer += oprot->writeListEnd();
+ }
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.encryption_algorithm) {
+ xfer += oprot->writeFieldBegin("encryption_algorithm", ::apache::thrift::protocol::T_STRUCT, 8);
+ xfer += this->encryption_algorithm.write(oprot);
+ xfer += oprot->writeFieldEnd();
+ }
+ if (this->__isset.footer_signing_key_metadata) {
+ xfer += oprot->writeFieldBegin("footer_signing_key_metadata", ::apache::thrift::protocol::T_STRING, 9);
+ xfer += oprot->writeBinary(this->footer_signing_key_metadata);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(FileMetaData &a, FileMetaData &b) {
+ using ::std::swap;
+ swap(a.version, b.version);
+ swap(a.schema, b.schema);
+ swap(a.num_rows, b.num_rows);
+ swap(a.row_groups, b.row_groups);
+ swap(a.key_value_metadata, b.key_value_metadata);
+ swap(a.created_by, b.created_by);
+ swap(a.column_orders, b.column_orders);
+ swap(a.encryption_algorithm, b.encryption_algorithm);
+ swap(a.footer_signing_key_metadata, b.footer_signing_key_metadata);
+ swap(a.__isset, b.__isset);
+}
+
+FileMetaData::FileMetaData(const FileMetaData& other209) {
+ version = other209.version;
+ schema = other209.schema;
+ num_rows = other209.num_rows;
+ row_groups = other209.row_groups;
+ key_value_metadata = other209.key_value_metadata;
+ created_by = other209.created_by;
+ column_orders = other209.column_orders;
+ encryption_algorithm = other209.encryption_algorithm;
+ footer_signing_key_metadata = other209.footer_signing_key_metadata;
+ __isset = other209.__isset;
+}
+FileMetaData& FileMetaData::operator=(const FileMetaData& other210) {
+ version = other210.version;
+ schema = other210.schema;
+ num_rows = other210.num_rows;
+ row_groups = other210.row_groups;
+ key_value_metadata = other210.key_value_metadata;
+ created_by = other210.created_by;
+ column_orders = other210.column_orders;
+ encryption_algorithm = other210.encryption_algorithm;
+ footer_signing_key_metadata = other210.footer_signing_key_metadata;
+ __isset = other210.__isset;
+ return *this;
+}
+void FileMetaData::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "FileMetaData(";
+ out << "version=" << to_string(version);
+ out << ", " << "schema=" << to_string(schema);
+ out << ", " << "num_rows=" << to_string(num_rows);
+ out << ", " << "row_groups=" << to_string(row_groups);
+ out << ", " << "key_value_metadata="; (__isset.key_value_metadata ? (out << to_string(key_value_metadata)) : (out << "<null>"));
+ out << ", " << "created_by="; (__isset.created_by ? (out << to_string(created_by)) : (out << "<null>"));
+ out << ", " << "column_orders="; (__isset.column_orders ? (out << to_string(column_orders)) : (out << "<null>"));
+ out << ", " << "encryption_algorithm="; (__isset.encryption_algorithm ? (out << to_string(encryption_algorithm)) : (out << "<null>"));
+ out << ", " << "footer_signing_key_metadata="; (__isset.footer_signing_key_metadata ? (out << to_string(footer_signing_key_metadata)) : (out << "<null>"));
+ out << ")";
+}
+
+
+FileCryptoMetaData::~FileCryptoMetaData() noexcept {
+}
+
+
+void FileCryptoMetaData::__set_encryption_algorithm(const EncryptionAlgorithm& val) {
+ this->encryption_algorithm = val;
+}
+
+void FileCryptoMetaData::__set_key_metadata(const std::string& val) {
+ this->key_metadata = val;
+__isset.key_metadata = true;
+}
+std::ostream& operator<<(std::ostream& out, const FileCryptoMetaData& obj)
+{
+ obj.printTo(out);
+ return out;
+}
+
+
+uint32_t FileCryptoMetaData::read(::apache::thrift::protocol::TProtocol* iprot) {
+
+ ::apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);
+ uint32_t xfer = 0;
+ std::string fname;
+ ::apache::thrift::protocol::TType ftype;
+ int16_t fid;
+
+ xfer += iprot->readStructBegin(fname);
+
+ using ::apache::thrift::protocol::TProtocolException;
+
+ bool isset_encryption_algorithm = false;
+
+ while (true)
+ {
+ xfer += iprot->readFieldBegin(fname, ftype, fid);
+ if (ftype == ::apache::thrift::protocol::T_STOP) {
+ break;
+ }
+ switch (fid)
+ {
+ case 1:
+ if (ftype == ::apache::thrift::protocol::T_STRUCT) {
+ xfer += this->encryption_algorithm.read(iprot);
+ isset_encryption_algorithm = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ case 2:
+ if (ftype == ::apache::thrift::protocol::T_STRING) {
+ xfer += iprot->readBinary(this->key_metadata);
+ this->__isset.key_metadata = true;
+ } else {
+ xfer += iprot->skip(ftype);
+ }
+ break;
+ default:
+ xfer += iprot->skip(ftype);
+ break;
+ }
+ xfer += iprot->readFieldEnd();
+ }
+
+ xfer += iprot->readStructEnd();
+
+ if (!isset_encryption_algorithm)
+ throw TProtocolException(TProtocolException::INVALID_DATA);
+ return xfer;
+}
+
+uint32_t FileCryptoMetaData::write(::apache::thrift::protocol::TProtocol* oprot) const {
+ uint32_t xfer = 0;
+ ::apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);
+ xfer += oprot->writeStructBegin("FileCryptoMetaData");
+
+ xfer += oprot->writeFieldBegin("encryption_algorithm", ::apache::thrift::protocol::T_STRUCT, 1);
+ xfer += this->encryption_algorithm.write(oprot);
+ xfer += oprot->writeFieldEnd();
+
+ if (this->__isset.key_metadata) {
+ xfer += oprot->writeFieldBegin("key_metadata", ::apache::thrift::protocol::T_STRING, 2);
+ xfer += oprot->writeBinary(this->key_metadata);
+ xfer += oprot->writeFieldEnd();
+ }
+ xfer += oprot->writeFieldStop();
+ xfer += oprot->writeStructEnd();
+ return xfer;
+}
+
+void swap(FileCryptoMetaData &a, FileCryptoMetaData &b) {
+ using ::std::swap;
+ swap(a.encryption_algorithm, b.encryption_algorithm);
+ swap(a.key_metadata, b.key_metadata);
+ swap(a.__isset, b.__isset);
+}
+
+FileCryptoMetaData::FileCryptoMetaData(const FileCryptoMetaData& other211) {
+ encryption_algorithm = other211.encryption_algorithm;
+ key_metadata = other211.key_metadata;
+ __isset = other211.__isset;
+}
+FileCryptoMetaData& FileCryptoMetaData::operator=(const FileCryptoMetaData& other212) {
+ encryption_algorithm = other212.encryption_algorithm;
+ key_metadata = other212.key_metadata;
+ __isset = other212.__isset;
+ return *this;
+}
+void FileCryptoMetaData::printTo(std::ostream& out) const {
+ using ::apache::thrift::to_string;
+ out << "FileCryptoMetaData(";
+ out << "encryption_algorithm=" << to_string(encryption_algorithm);
+ out << ", " << "key_metadata="; (__isset.key_metadata ? (out << to_string(key_metadata)) : (out << "<null>"));
+ out << ")";
+}
+
+}} // namespace
diff --git a/src/arrow/cpp/src/generated/parquet_types.h b/src/arrow/cpp/src/generated/parquet_types.h
new file mode 100644
index 000000000..3d7edd409
--- /dev/null
+++ b/src/arrow/cpp/src/generated/parquet_types.h
@@ -0,0 +1,2917 @@
+/**
+ * Autogenerated by Thrift Compiler (0.13.0)
+ *
+ * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
+ * @generated
+ */
+#ifndef parquet_TYPES_H
+#define parquet_TYPES_H
+
+#include <iosfwd>
+
+#include <thrift/Thrift.h>
+#include <thrift/TApplicationException.h>
+#include <thrift/TBase.h>
+#include <thrift/protocol/TProtocol.h>
+#include <thrift/transport/TTransport.h>
+
+#include <functional>
+#include <memory>
+
+#include "parquet/windows_compatibility.h"
+
+namespace parquet { namespace format {
+
+struct Type {
+ enum type {
+ BOOLEAN = 0,
+ INT32 = 1,
+ INT64 = 2,
+ INT96 = 3,
+ FLOAT = 4,
+ DOUBLE = 5,
+ BYTE_ARRAY = 6,
+ FIXED_LEN_BYTE_ARRAY = 7
+ };
+};
+
+extern const std::map<int, const char*> _Type_VALUES_TO_NAMES;
+
+std::ostream& operator<<(std::ostream& out, const Type::type& val);
+
+std::string to_string(const Type::type& val);
+
+struct ConvertedType {
+ enum type {
+ UTF8 = 0,
+ MAP = 1,
+ MAP_KEY_VALUE = 2,
+ LIST = 3,
+ ENUM = 4,
+ DECIMAL = 5,
+ DATE = 6,
+ TIME_MILLIS = 7,
+ TIME_MICROS = 8,
+ TIMESTAMP_MILLIS = 9,
+ TIMESTAMP_MICROS = 10,
+ UINT_8 = 11,
+ UINT_16 = 12,
+ UINT_32 = 13,
+ UINT_64 = 14,
+ INT_8 = 15,
+ INT_16 = 16,
+ INT_32 = 17,
+ INT_64 = 18,
+ JSON = 19,
+ BSON = 20,
+ INTERVAL = 21
+ };
+};
+
+extern const std::map<int, const char*> _ConvertedType_VALUES_TO_NAMES;
+
+std::ostream& operator<<(std::ostream& out, const ConvertedType::type& val);
+
+std::string to_string(const ConvertedType::type& val);
+
+struct FieldRepetitionType {
+ enum type {
+ REQUIRED = 0,
+ OPTIONAL = 1,
+ REPEATED = 2
+ };
+};
+
+extern const std::map<int, const char*> _FieldRepetitionType_VALUES_TO_NAMES;
+
+std::ostream& operator<<(std::ostream& out, const FieldRepetitionType::type& val);
+
+std::string to_string(const FieldRepetitionType::type& val);
+
+struct Encoding {
+ enum type {
+ PLAIN = 0,
+ PLAIN_DICTIONARY = 2,
+ RLE = 3,
+ BIT_PACKED = 4,
+ DELTA_BINARY_PACKED = 5,
+ DELTA_LENGTH_BYTE_ARRAY = 6,
+ DELTA_BYTE_ARRAY = 7,
+ RLE_DICTIONARY = 8,
+ BYTE_STREAM_SPLIT = 9
+ };
+};
+
+extern const std::map<int, const char*> _Encoding_VALUES_TO_NAMES;
+
+std::ostream& operator<<(std::ostream& out, const Encoding::type& val);
+
+std::string to_string(const Encoding::type& val);
+
+struct CompressionCodec {
+ enum type {
+ UNCOMPRESSED = 0,
+ SNAPPY = 1,
+ GZIP = 2,
+ LZO = 3,
+ BROTLI = 4,
+ LZ4 = 5,
+ ZSTD = 6,
+ LZ4_RAW = 7
+ };
+};
+
+extern const std::map<int, const char*> _CompressionCodec_VALUES_TO_NAMES;
+
+std::ostream& operator<<(std::ostream& out, const CompressionCodec::type& val);
+
+std::string to_string(const CompressionCodec::type& val);
+
+struct PageType {
+ enum type {
+ DATA_PAGE = 0,
+ INDEX_PAGE = 1,
+ DICTIONARY_PAGE = 2,
+ DATA_PAGE_V2 = 3
+ };
+};
+
+extern const std::map<int, const char*> _PageType_VALUES_TO_NAMES;
+
+std::ostream& operator<<(std::ostream& out, const PageType::type& val);
+
+std::string to_string(const PageType::type& val);
+
+struct BoundaryOrder {
+ enum type {
+ UNORDERED = 0,
+ ASCENDING = 1,
+ DESCENDING = 2
+ };
+};
+
+extern const std::map<int, const char*> _BoundaryOrder_VALUES_TO_NAMES;
+
+std::ostream& operator<<(std::ostream& out, const BoundaryOrder::type& val);
+
+std::string to_string(const BoundaryOrder::type& val);
+
+class Statistics;
+
+class StringType;
+
+class UUIDType;
+
+class MapType;
+
+class ListType;
+
+class EnumType;
+
+class DateType;
+
+class NullType;
+
+class DecimalType;
+
+class MilliSeconds;
+
+class MicroSeconds;
+
+class NanoSeconds;
+
+class TimeUnit;
+
+class TimestampType;
+
+class TimeType;
+
+class IntType;
+
+class JsonType;
+
+class BsonType;
+
+class LogicalType;
+
+class SchemaElement;
+
+class DataPageHeader;
+
+class IndexPageHeader;
+
+class DictionaryPageHeader;
+
+class DataPageHeaderV2;
+
+class SplitBlockAlgorithm;
+
+class BloomFilterAlgorithm;
+
+class XxHash;
+
+class BloomFilterHash;
+
+class Uncompressed;
+
+class BloomFilterCompression;
+
+class BloomFilterHeader;
+
+class PageHeader;
+
+class KeyValue;
+
+class SortingColumn;
+
+class PageEncodingStats;
+
+class ColumnMetaData;
+
+class EncryptionWithFooterKey;
+
+class EncryptionWithColumnKey;
+
+class ColumnCryptoMetaData;
+
+class ColumnChunk;
+
+class RowGroup;
+
+class TypeDefinedOrder;
+
+class ColumnOrder;
+
+class PageLocation;
+
+class OffsetIndex;
+
+class ColumnIndex;
+
+class AesGcmV1;
+
+class AesGcmCtrV1;
+
+class EncryptionAlgorithm;
+
+class FileMetaData;
+
+class FileCryptoMetaData;
+
+typedef struct _Statistics__isset {
+ _Statistics__isset() : max(false), min(false), null_count(false), distinct_count(false), max_value(false), min_value(false) {}
+ bool max :1;
+ bool min :1;
+ bool null_count :1;
+ bool distinct_count :1;
+ bool max_value :1;
+ bool min_value :1;
+} _Statistics__isset;
+
+class Statistics : public virtual ::apache::thrift::TBase {
+ public:
+
+ Statistics(const Statistics&);
+ Statistics& operator=(const Statistics&);
+ Statistics() : max(), min(), null_count(0), distinct_count(0), max_value(), min_value() {
+ }
+
+ virtual ~Statistics() noexcept;
+ std::string max;
+ std::string min;
+ int64_t null_count;
+ int64_t distinct_count;
+ std::string max_value;
+ std::string min_value;
+
+ _Statistics__isset __isset;
+
+ void __set_max(const std::string& val);
+
+ void __set_min(const std::string& val);
+
+ void __set_null_count(const int64_t val);
+
+ void __set_distinct_count(const int64_t val);
+
+ void __set_max_value(const std::string& val);
+
+ void __set_min_value(const std::string& val);
+
+ bool operator == (const Statistics & rhs) const
+ {
+ if (__isset.max != rhs.__isset.max)
+ return false;
+ else if (__isset.max && !(max == rhs.max))
+ return false;
+ if (__isset.min != rhs.__isset.min)
+ return false;
+ else if (__isset.min && !(min == rhs.min))
+ return false;
+ if (__isset.null_count != rhs.__isset.null_count)
+ return false;
+ else if (__isset.null_count && !(null_count == rhs.null_count))
+ return false;
+ if (__isset.distinct_count != rhs.__isset.distinct_count)
+ return false;
+ else if (__isset.distinct_count && !(distinct_count == rhs.distinct_count))
+ return false;
+ if (__isset.max_value != rhs.__isset.max_value)
+ return false;
+ else if (__isset.max_value && !(max_value == rhs.max_value))
+ return false;
+ if (__isset.min_value != rhs.__isset.min_value)
+ return false;
+ else if (__isset.min_value && !(min_value == rhs.min_value))
+ return false;
+ return true;
+ }
+ bool operator != (const Statistics &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const Statistics & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(Statistics &a, Statistics &b);
+
+std::ostream& operator<<(std::ostream& out, const Statistics& obj);
+
+
+class StringType : public virtual ::apache::thrift::TBase {
+ public:
+
+ StringType(const StringType&);
+ StringType& operator=(const StringType&);
+ StringType() {
+ }
+
+ virtual ~StringType() noexcept;
+
+ bool operator == (const StringType & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const StringType &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const StringType & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(StringType &a, StringType &b);
+
+std::ostream& operator<<(std::ostream& out, const StringType& obj);
+
+
+class UUIDType : public virtual ::apache::thrift::TBase {
+ public:
+
+ UUIDType(const UUIDType&);
+ UUIDType& operator=(const UUIDType&);
+ UUIDType() {
+ }
+
+ virtual ~UUIDType() noexcept;
+
+ bool operator == (const UUIDType & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const UUIDType &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const UUIDType & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(UUIDType &a, UUIDType &b);
+
+std::ostream& operator<<(std::ostream& out, const UUIDType& obj);
+
+
+class MapType : public virtual ::apache::thrift::TBase {
+ public:
+
+ MapType(const MapType&);
+ MapType& operator=(const MapType&);
+ MapType() {
+ }
+
+ virtual ~MapType() noexcept;
+
+ bool operator == (const MapType & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const MapType &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const MapType & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(MapType &a, MapType &b);
+
+std::ostream& operator<<(std::ostream& out, const MapType& obj);
+
+
+class ListType : public virtual ::apache::thrift::TBase {
+ public:
+
+ ListType(const ListType&);
+ ListType& operator=(const ListType&);
+ ListType() {
+ }
+
+ virtual ~ListType() noexcept;
+
+ bool operator == (const ListType & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const ListType &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const ListType & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(ListType &a, ListType &b);
+
+std::ostream& operator<<(std::ostream& out, const ListType& obj);
+
+
+class EnumType : public virtual ::apache::thrift::TBase {
+ public:
+
+ EnumType(const EnumType&);
+ EnumType& operator=(const EnumType&);
+ EnumType() {
+ }
+
+ virtual ~EnumType() noexcept;
+
+ bool operator == (const EnumType & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const EnumType &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const EnumType & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(EnumType &a, EnumType &b);
+
+std::ostream& operator<<(std::ostream& out, const EnumType& obj);
+
+
+class DateType : public virtual ::apache::thrift::TBase {
+ public:
+
+ DateType(const DateType&);
+ DateType& operator=(const DateType&);
+ DateType() {
+ }
+
+ virtual ~DateType() noexcept;
+
+ bool operator == (const DateType & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const DateType &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const DateType & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(DateType &a, DateType &b);
+
+std::ostream& operator<<(std::ostream& out, const DateType& obj);
+
+
+class NullType : public virtual ::apache::thrift::TBase {
+ public:
+
+ NullType(const NullType&);
+ NullType& operator=(const NullType&);
+ NullType() {
+ }
+
+ virtual ~NullType() noexcept;
+
+ bool operator == (const NullType & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const NullType &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const NullType & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(NullType &a, NullType &b);
+
+std::ostream& operator<<(std::ostream& out, const NullType& obj);
+
+
+class DecimalType : public virtual ::apache::thrift::TBase {
+ public:
+
+ DecimalType(const DecimalType&);
+ DecimalType& operator=(const DecimalType&);
+ DecimalType() : scale(0), precision(0) {
+ }
+
+ virtual ~DecimalType() noexcept;
+ int32_t scale;
+ int32_t precision;
+
+ void __set_scale(const int32_t val);
+
+ void __set_precision(const int32_t val);
+
+ bool operator == (const DecimalType & rhs) const
+ {
+ if (!(scale == rhs.scale))
+ return false;
+ if (!(precision == rhs.precision))
+ return false;
+ return true;
+ }
+ bool operator != (const DecimalType &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const DecimalType & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(DecimalType &a, DecimalType &b);
+
+std::ostream& operator<<(std::ostream& out, const DecimalType& obj);
+
+
+class MilliSeconds : public virtual ::apache::thrift::TBase {
+ public:
+
+ MilliSeconds(const MilliSeconds&);
+ MilliSeconds& operator=(const MilliSeconds&);
+ MilliSeconds() {
+ }
+
+ virtual ~MilliSeconds() noexcept;
+
+ bool operator == (const MilliSeconds & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const MilliSeconds &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const MilliSeconds & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(MilliSeconds &a, MilliSeconds &b);
+
+std::ostream& operator<<(std::ostream& out, const MilliSeconds& obj);
+
+
+class MicroSeconds : public virtual ::apache::thrift::TBase {
+ public:
+
+ MicroSeconds(const MicroSeconds&);
+ MicroSeconds& operator=(const MicroSeconds&);
+ MicroSeconds() {
+ }
+
+ virtual ~MicroSeconds() noexcept;
+
+ bool operator == (const MicroSeconds & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const MicroSeconds &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const MicroSeconds & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(MicroSeconds &a, MicroSeconds &b);
+
+std::ostream& operator<<(std::ostream& out, const MicroSeconds& obj);
+
+
+class NanoSeconds : public virtual ::apache::thrift::TBase {
+ public:
+
+ NanoSeconds(const NanoSeconds&);
+ NanoSeconds& operator=(const NanoSeconds&);
+ NanoSeconds() {
+ }
+
+ virtual ~NanoSeconds() noexcept;
+
+ bool operator == (const NanoSeconds & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const NanoSeconds &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const NanoSeconds & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(NanoSeconds &a, NanoSeconds &b);
+
+std::ostream& operator<<(std::ostream& out, const NanoSeconds& obj);
+
+typedef struct _TimeUnit__isset {
+ _TimeUnit__isset() : MILLIS(false), MICROS(false), NANOS(false) {}
+ bool MILLIS :1;
+ bool MICROS :1;
+ bool NANOS :1;
+} _TimeUnit__isset;
+
+class TimeUnit : public virtual ::apache::thrift::TBase {
+ public:
+
+ TimeUnit(const TimeUnit&);
+ TimeUnit& operator=(const TimeUnit&);
+ TimeUnit() {
+ }
+
+ virtual ~TimeUnit() noexcept;
+ MilliSeconds MILLIS;
+ MicroSeconds MICROS;
+ NanoSeconds NANOS;
+
+ _TimeUnit__isset __isset;
+
+ void __set_MILLIS(const MilliSeconds& val);
+
+ void __set_MICROS(const MicroSeconds& val);
+
+ void __set_NANOS(const NanoSeconds& val);
+
+ bool operator == (const TimeUnit & rhs) const
+ {
+ if (__isset.MILLIS != rhs.__isset.MILLIS)
+ return false;
+ else if (__isset.MILLIS && !(MILLIS == rhs.MILLIS))
+ return false;
+ if (__isset.MICROS != rhs.__isset.MICROS)
+ return false;
+ else if (__isset.MICROS && !(MICROS == rhs.MICROS))
+ return false;
+ if (__isset.NANOS != rhs.__isset.NANOS)
+ return false;
+ else if (__isset.NANOS && !(NANOS == rhs.NANOS))
+ return false;
+ return true;
+ }
+ bool operator != (const TimeUnit &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const TimeUnit & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(TimeUnit &a, TimeUnit &b);
+
+std::ostream& operator<<(std::ostream& out, const TimeUnit& obj);
+
+
+class TimestampType : public virtual ::apache::thrift::TBase {
+ public:
+
+ TimestampType(const TimestampType&);
+ TimestampType& operator=(const TimestampType&);
+ TimestampType() : isAdjustedToUTC(0) {
+ }
+
+ virtual ~TimestampType() noexcept;
+ bool isAdjustedToUTC;
+ TimeUnit unit;
+
+ void __set_isAdjustedToUTC(const bool val);
+
+ void __set_unit(const TimeUnit& val);
+
+ bool operator == (const TimestampType & rhs) const
+ {
+ if (!(isAdjustedToUTC == rhs.isAdjustedToUTC))
+ return false;
+ if (!(unit == rhs.unit))
+ return false;
+ return true;
+ }
+ bool operator != (const TimestampType &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const TimestampType & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(TimestampType &a, TimestampType &b);
+
+std::ostream& operator<<(std::ostream& out, const TimestampType& obj);
+
+
+class TimeType : public virtual ::apache::thrift::TBase {
+ public:
+
+ TimeType(const TimeType&);
+ TimeType& operator=(const TimeType&);
+ TimeType() : isAdjustedToUTC(0) {
+ }
+
+ virtual ~TimeType() noexcept;
+ bool isAdjustedToUTC;
+ TimeUnit unit;
+
+ void __set_isAdjustedToUTC(const bool val);
+
+ void __set_unit(const TimeUnit& val);
+
+ bool operator == (const TimeType & rhs) const
+ {
+ if (!(isAdjustedToUTC == rhs.isAdjustedToUTC))
+ return false;
+ if (!(unit == rhs.unit))
+ return false;
+ return true;
+ }
+ bool operator != (const TimeType &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const TimeType & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(TimeType &a, TimeType &b);
+
+std::ostream& operator<<(std::ostream& out, const TimeType& obj);
+
+
+class IntType : public virtual ::apache::thrift::TBase {
+ public:
+
+ IntType(const IntType&);
+ IntType& operator=(const IntType&);
+ IntType() : bitWidth(0), isSigned(0) {
+ }
+
+ virtual ~IntType() noexcept;
+ int8_t bitWidth;
+ bool isSigned;
+
+ void __set_bitWidth(const int8_t val);
+
+ void __set_isSigned(const bool val);
+
+ bool operator == (const IntType & rhs) const
+ {
+ if (!(bitWidth == rhs.bitWidth))
+ return false;
+ if (!(isSigned == rhs.isSigned))
+ return false;
+ return true;
+ }
+ bool operator != (const IntType &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const IntType & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(IntType &a, IntType &b);
+
+std::ostream& operator<<(std::ostream& out, const IntType& obj);
+
+
+class JsonType : public virtual ::apache::thrift::TBase {
+ public:
+
+ JsonType(const JsonType&);
+ JsonType& operator=(const JsonType&);
+ JsonType() {
+ }
+
+ virtual ~JsonType() noexcept;
+
+ bool operator == (const JsonType & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const JsonType &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const JsonType & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(JsonType &a, JsonType &b);
+
+std::ostream& operator<<(std::ostream& out, const JsonType& obj);
+
+
+class BsonType : public virtual ::apache::thrift::TBase {
+ public:
+
+ BsonType(const BsonType&);
+ BsonType& operator=(const BsonType&);
+ BsonType() {
+ }
+
+ virtual ~BsonType() noexcept;
+
+ bool operator == (const BsonType & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const BsonType &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const BsonType & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(BsonType &a, BsonType &b);
+
+std::ostream& operator<<(std::ostream& out, const BsonType& obj);
+
+typedef struct _LogicalType__isset {
+ _LogicalType__isset() : STRING(false), MAP(false), LIST(false), ENUM(false), DECIMAL(false), DATE(false), TIME(false), TIMESTAMP(false), INTEGER(false), UNKNOWN(false), JSON(false), BSON(false), UUID(false) {}
+ bool STRING :1;
+ bool MAP :1;
+ bool LIST :1;
+ bool ENUM :1;
+ bool DECIMAL :1;
+ bool DATE :1;
+ bool TIME :1;
+ bool TIMESTAMP :1;
+ bool INTEGER :1;
+ bool UNKNOWN :1;
+ bool JSON :1;
+ bool BSON :1;
+ bool UUID :1;
+} _LogicalType__isset;
+
+class LogicalType : public virtual ::apache::thrift::TBase {
+ public:
+
+ LogicalType(const LogicalType&);
+ LogicalType& operator=(const LogicalType&);
+ LogicalType() {
+ }
+
+ virtual ~LogicalType() noexcept;
+ StringType STRING;
+ MapType MAP;
+ ListType LIST;
+ EnumType ENUM;
+ DecimalType DECIMAL;
+ DateType DATE;
+ TimeType TIME;
+ TimestampType TIMESTAMP;
+ IntType INTEGER;
+ NullType UNKNOWN;
+ JsonType JSON;
+ BsonType BSON;
+ UUIDType UUID;
+
+ _LogicalType__isset __isset;
+
+ void __set_STRING(const StringType& val);
+
+ void __set_MAP(const MapType& val);
+
+ void __set_LIST(const ListType& val);
+
+ void __set_ENUM(const EnumType& val);
+
+ void __set_DECIMAL(const DecimalType& val);
+
+ void __set_DATE(const DateType& val);
+
+ void __set_TIME(const TimeType& val);
+
+ void __set_TIMESTAMP(const TimestampType& val);
+
+ void __set_INTEGER(const IntType& val);
+
+ void __set_UNKNOWN(const NullType& val);
+
+ void __set_JSON(const JsonType& val);
+
+ void __set_BSON(const BsonType& val);
+
+ void __set_UUID(const UUIDType& val);
+
+ bool operator == (const LogicalType & rhs) const
+ {
+ if (__isset.STRING != rhs.__isset.STRING)
+ return false;
+ else if (__isset.STRING && !(STRING == rhs.STRING))
+ return false;
+ if (__isset.MAP != rhs.__isset.MAP)
+ return false;
+ else if (__isset.MAP && !(MAP == rhs.MAP))
+ return false;
+ if (__isset.LIST != rhs.__isset.LIST)
+ return false;
+ else if (__isset.LIST && !(LIST == rhs.LIST))
+ return false;
+ if (__isset.ENUM != rhs.__isset.ENUM)
+ return false;
+ else if (__isset.ENUM && !(ENUM == rhs.ENUM))
+ return false;
+ if (__isset.DECIMAL != rhs.__isset.DECIMAL)
+ return false;
+ else if (__isset.DECIMAL && !(DECIMAL == rhs.DECIMAL))
+ return false;
+ if (__isset.DATE != rhs.__isset.DATE)
+ return false;
+ else if (__isset.DATE && !(DATE == rhs.DATE))
+ return false;
+ if (__isset.TIME != rhs.__isset.TIME)
+ return false;
+ else if (__isset.TIME && !(TIME == rhs.TIME))
+ return false;
+ if (__isset.TIMESTAMP != rhs.__isset.TIMESTAMP)
+ return false;
+ else if (__isset.TIMESTAMP && !(TIMESTAMP == rhs.TIMESTAMP))
+ return false;
+ if (__isset.INTEGER != rhs.__isset.INTEGER)
+ return false;
+ else if (__isset.INTEGER && !(INTEGER == rhs.INTEGER))
+ return false;
+ if (__isset.UNKNOWN != rhs.__isset.UNKNOWN)
+ return false;
+ else if (__isset.UNKNOWN && !(UNKNOWN == rhs.UNKNOWN))
+ return false;
+ if (__isset.JSON != rhs.__isset.JSON)
+ return false;
+ else if (__isset.JSON && !(JSON == rhs.JSON))
+ return false;
+ if (__isset.BSON != rhs.__isset.BSON)
+ return false;
+ else if (__isset.BSON && !(BSON == rhs.BSON))
+ return false;
+ if (__isset.UUID != rhs.__isset.UUID)
+ return false;
+ else if (__isset.UUID && !(UUID == rhs.UUID))
+ return false;
+ return true;
+ }
+ bool operator != (const LogicalType &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const LogicalType & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(LogicalType &a, LogicalType &b);
+
+std::ostream& operator<<(std::ostream& out, const LogicalType& obj);
+
+typedef struct _SchemaElement__isset {
+ _SchemaElement__isset() : type(false), type_length(false), repetition_type(false), num_children(false), converted_type(false), scale(false), precision(false), field_id(false), logicalType(false) {}
+ bool type :1;
+ bool type_length :1;
+ bool repetition_type :1;
+ bool num_children :1;
+ bool converted_type :1;
+ bool scale :1;
+ bool precision :1;
+ bool field_id :1;
+ bool logicalType :1;
+} _SchemaElement__isset;
+
+class SchemaElement : public virtual ::apache::thrift::TBase {
+ public:
+
+ SchemaElement(const SchemaElement&);
+ SchemaElement& operator=(const SchemaElement&);
+ SchemaElement() : type((Type::type)0), type_length(0), repetition_type((FieldRepetitionType::type)0), name(), num_children(0), converted_type((ConvertedType::type)0), scale(0), precision(0), field_id(0) {
+ }
+
+ virtual ~SchemaElement() noexcept;
+ Type::type type;
+ int32_t type_length;
+ FieldRepetitionType::type repetition_type;
+ std::string name;
+ int32_t num_children;
+ ConvertedType::type converted_type;
+ int32_t scale;
+ int32_t precision;
+ int32_t field_id;
+ LogicalType logicalType;
+
+ _SchemaElement__isset __isset;
+
+ void __set_type(const Type::type val);
+
+ void __set_type_length(const int32_t val);
+
+ void __set_repetition_type(const FieldRepetitionType::type val);
+
+ void __set_name(const std::string& val);
+
+ void __set_num_children(const int32_t val);
+
+ void __set_converted_type(const ConvertedType::type val);
+
+ void __set_scale(const int32_t val);
+
+ void __set_precision(const int32_t val);
+
+ void __set_field_id(const int32_t val);
+
+ void __set_logicalType(const LogicalType& val);
+
+ bool operator == (const SchemaElement & rhs) const
+ {
+ if (__isset.type != rhs.__isset.type)
+ return false;
+ else if (__isset.type && !(type == rhs.type))
+ return false;
+ if (__isset.type_length != rhs.__isset.type_length)
+ return false;
+ else if (__isset.type_length && !(type_length == rhs.type_length))
+ return false;
+ if (__isset.repetition_type != rhs.__isset.repetition_type)
+ return false;
+ else if (__isset.repetition_type && !(repetition_type == rhs.repetition_type))
+ return false;
+ if (!(name == rhs.name))
+ return false;
+ if (__isset.num_children != rhs.__isset.num_children)
+ return false;
+ else if (__isset.num_children && !(num_children == rhs.num_children))
+ return false;
+ if (__isset.converted_type != rhs.__isset.converted_type)
+ return false;
+ else if (__isset.converted_type && !(converted_type == rhs.converted_type))
+ return false;
+ if (__isset.scale != rhs.__isset.scale)
+ return false;
+ else if (__isset.scale && !(scale == rhs.scale))
+ return false;
+ if (__isset.precision != rhs.__isset.precision)
+ return false;
+ else if (__isset.precision && !(precision == rhs.precision))
+ return false;
+ if (__isset.field_id != rhs.__isset.field_id)
+ return false;
+ else if (__isset.field_id && !(field_id == rhs.field_id))
+ return false;
+ if (__isset.logicalType != rhs.__isset.logicalType)
+ return false;
+ else if (__isset.logicalType && !(logicalType == rhs.logicalType))
+ return false;
+ return true;
+ }
+ bool operator != (const SchemaElement &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const SchemaElement & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(SchemaElement &a, SchemaElement &b);
+
+std::ostream& operator<<(std::ostream& out, const SchemaElement& obj);
+
+typedef struct _DataPageHeader__isset {
+ _DataPageHeader__isset() : statistics(false) {}
+ bool statistics :1;
+} _DataPageHeader__isset;
+
+class DataPageHeader : public virtual ::apache::thrift::TBase {
+ public:
+
+ DataPageHeader(const DataPageHeader&);
+ DataPageHeader& operator=(const DataPageHeader&);
+ DataPageHeader() : num_values(0), encoding((Encoding::type)0), definition_level_encoding((Encoding::type)0), repetition_level_encoding((Encoding::type)0) {
+ }
+
+ virtual ~DataPageHeader() noexcept;
+ int32_t num_values;
+ Encoding::type encoding;
+ Encoding::type definition_level_encoding;
+ Encoding::type repetition_level_encoding;
+ Statistics statistics;
+
+ _DataPageHeader__isset __isset;
+
+ void __set_num_values(const int32_t val);
+
+ void __set_encoding(const Encoding::type val);
+
+ void __set_definition_level_encoding(const Encoding::type val);
+
+ void __set_repetition_level_encoding(const Encoding::type val);
+
+ void __set_statistics(const Statistics& val);
+
+ bool operator == (const DataPageHeader & rhs) const
+ {
+ if (!(num_values == rhs.num_values))
+ return false;
+ if (!(encoding == rhs.encoding))
+ return false;
+ if (!(definition_level_encoding == rhs.definition_level_encoding))
+ return false;
+ if (!(repetition_level_encoding == rhs.repetition_level_encoding))
+ return false;
+ if (__isset.statistics != rhs.__isset.statistics)
+ return false;
+ else if (__isset.statistics && !(statistics == rhs.statistics))
+ return false;
+ return true;
+ }
+ bool operator != (const DataPageHeader &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const DataPageHeader & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(DataPageHeader &a, DataPageHeader &b);
+
+std::ostream& operator<<(std::ostream& out, const DataPageHeader& obj);
+
+
+class IndexPageHeader : public virtual ::apache::thrift::TBase {
+ public:
+
+ IndexPageHeader(const IndexPageHeader&);
+ IndexPageHeader& operator=(const IndexPageHeader&);
+ IndexPageHeader() {
+ }
+
+ virtual ~IndexPageHeader() noexcept;
+
+ bool operator == (const IndexPageHeader & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const IndexPageHeader &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const IndexPageHeader & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(IndexPageHeader &a, IndexPageHeader &b);
+
+std::ostream& operator<<(std::ostream& out, const IndexPageHeader& obj);
+
+typedef struct _DictionaryPageHeader__isset {
+ _DictionaryPageHeader__isset() : is_sorted(false) {}
+ bool is_sorted :1;
+} _DictionaryPageHeader__isset;
+
+class DictionaryPageHeader : public virtual ::apache::thrift::TBase {
+ public:
+
+ DictionaryPageHeader(const DictionaryPageHeader&);
+ DictionaryPageHeader& operator=(const DictionaryPageHeader&);
+ DictionaryPageHeader() : num_values(0), encoding((Encoding::type)0), is_sorted(0) {
+ }
+
+ virtual ~DictionaryPageHeader() noexcept;
+ int32_t num_values;
+ Encoding::type encoding;
+ bool is_sorted;
+
+ _DictionaryPageHeader__isset __isset;
+
+ void __set_num_values(const int32_t val);
+
+ void __set_encoding(const Encoding::type val);
+
+ void __set_is_sorted(const bool val);
+
+ bool operator == (const DictionaryPageHeader & rhs) const
+ {
+ if (!(num_values == rhs.num_values))
+ return false;
+ if (!(encoding == rhs.encoding))
+ return false;
+ if (__isset.is_sorted != rhs.__isset.is_sorted)
+ return false;
+ else if (__isset.is_sorted && !(is_sorted == rhs.is_sorted))
+ return false;
+ return true;
+ }
+ bool operator != (const DictionaryPageHeader &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const DictionaryPageHeader & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(DictionaryPageHeader &a, DictionaryPageHeader &b);
+
+std::ostream& operator<<(std::ostream& out, const DictionaryPageHeader& obj);
+
+typedef struct _DataPageHeaderV2__isset {
+ _DataPageHeaderV2__isset() : is_compressed(true), statistics(false) {}
+ bool is_compressed :1;
+ bool statistics :1;
+} _DataPageHeaderV2__isset;
+
+class DataPageHeaderV2 : public virtual ::apache::thrift::TBase {
+ public:
+
+ DataPageHeaderV2(const DataPageHeaderV2&);
+ DataPageHeaderV2& operator=(const DataPageHeaderV2&);
+ DataPageHeaderV2() : num_values(0), num_nulls(0), num_rows(0), encoding((Encoding::type)0), definition_levels_byte_length(0), repetition_levels_byte_length(0), is_compressed(true) {
+ }
+
+ virtual ~DataPageHeaderV2() noexcept;
+ int32_t num_values;
+ int32_t num_nulls;
+ int32_t num_rows;
+ Encoding::type encoding;
+ int32_t definition_levels_byte_length;
+ int32_t repetition_levels_byte_length;
+ bool is_compressed;
+ Statistics statistics;
+
+ _DataPageHeaderV2__isset __isset;
+
+ void __set_num_values(const int32_t val);
+
+ void __set_num_nulls(const int32_t val);
+
+ void __set_num_rows(const int32_t val);
+
+ void __set_encoding(const Encoding::type val);
+
+ void __set_definition_levels_byte_length(const int32_t val);
+
+ void __set_repetition_levels_byte_length(const int32_t val);
+
+ void __set_is_compressed(const bool val);
+
+ void __set_statistics(const Statistics& val);
+
+ bool operator == (const DataPageHeaderV2 & rhs) const
+ {
+ if (!(num_values == rhs.num_values))
+ return false;
+ if (!(num_nulls == rhs.num_nulls))
+ return false;
+ if (!(num_rows == rhs.num_rows))
+ return false;
+ if (!(encoding == rhs.encoding))
+ return false;
+ if (!(definition_levels_byte_length == rhs.definition_levels_byte_length))
+ return false;
+ if (!(repetition_levels_byte_length == rhs.repetition_levels_byte_length))
+ return false;
+ if (__isset.is_compressed != rhs.__isset.is_compressed)
+ return false;
+ else if (__isset.is_compressed && !(is_compressed == rhs.is_compressed))
+ return false;
+ if (__isset.statistics != rhs.__isset.statistics)
+ return false;
+ else if (__isset.statistics && !(statistics == rhs.statistics))
+ return false;
+ return true;
+ }
+ bool operator != (const DataPageHeaderV2 &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const DataPageHeaderV2 & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(DataPageHeaderV2 &a, DataPageHeaderV2 &b);
+
+std::ostream& operator<<(std::ostream& out, const DataPageHeaderV2& obj);
+
+
+class SplitBlockAlgorithm : public virtual ::apache::thrift::TBase {
+ public:
+
+ SplitBlockAlgorithm(const SplitBlockAlgorithm&);
+ SplitBlockAlgorithm& operator=(const SplitBlockAlgorithm&);
+ SplitBlockAlgorithm() {
+ }
+
+ virtual ~SplitBlockAlgorithm() noexcept;
+
+ bool operator == (const SplitBlockAlgorithm & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const SplitBlockAlgorithm &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const SplitBlockAlgorithm & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(SplitBlockAlgorithm &a, SplitBlockAlgorithm &b);
+
+std::ostream& operator<<(std::ostream& out, const SplitBlockAlgorithm& obj);
+
+typedef struct _BloomFilterAlgorithm__isset {
+ _BloomFilterAlgorithm__isset() : BLOCK(false) {}
+ bool BLOCK :1;
+} _BloomFilterAlgorithm__isset;
+
+class BloomFilterAlgorithm : public virtual ::apache::thrift::TBase {
+ public:
+
+ BloomFilterAlgorithm(const BloomFilterAlgorithm&);
+ BloomFilterAlgorithm& operator=(const BloomFilterAlgorithm&);
+ BloomFilterAlgorithm() {
+ }
+
+ virtual ~BloomFilterAlgorithm() noexcept;
+ SplitBlockAlgorithm BLOCK;
+
+ _BloomFilterAlgorithm__isset __isset;
+
+ void __set_BLOCK(const SplitBlockAlgorithm& val);
+
+ bool operator == (const BloomFilterAlgorithm & rhs) const
+ {
+ if (__isset.BLOCK != rhs.__isset.BLOCK)
+ return false;
+ else if (__isset.BLOCK && !(BLOCK == rhs.BLOCK))
+ return false;
+ return true;
+ }
+ bool operator != (const BloomFilterAlgorithm &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const BloomFilterAlgorithm & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(BloomFilterAlgorithm &a, BloomFilterAlgorithm &b);
+
+std::ostream& operator<<(std::ostream& out, const BloomFilterAlgorithm& obj);
+
+
+class XxHash : public virtual ::apache::thrift::TBase {
+ public:
+
+ XxHash(const XxHash&);
+ XxHash& operator=(const XxHash&);
+ XxHash() {
+ }
+
+ virtual ~XxHash() noexcept;
+
+ bool operator == (const XxHash & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const XxHash &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const XxHash & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(XxHash &a, XxHash &b);
+
+std::ostream& operator<<(std::ostream& out, const XxHash& obj);
+
+typedef struct _BloomFilterHash__isset {
+ _BloomFilterHash__isset() : XXHASH(false) {}
+ bool XXHASH :1;
+} _BloomFilterHash__isset;
+
+class BloomFilterHash : public virtual ::apache::thrift::TBase {
+ public:
+
+ BloomFilterHash(const BloomFilterHash&);
+ BloomFilterHash& operator=(const BloomFilterHash&);
+ BloomFilterHash() {
+ }
+
+ virtual ~BloomFilterHash() noexcept;
+ XxHash XXHASH;
+
+ _BloomFilterHash__isset __isset;
+
+ void __set_XXHASH(const XxHash& val);
+
+ bool operator == (const BloomFilterHash & rhs) const
+ {
+ if (__isset.XXHASH != rhs.__isset.XXHASH)
+ return false;
+ else if (__isset.XXHASH && !(XXHASH == rhs.XXHASH))
+ return false;
+ return true;
+ }
+ bool operator != (const BloomFilterHash &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const BloomFilterHash & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(BloomFilterHash &a, BloomFilterHash &b);
+
+std::ostream& operator<<(std::ostream& out, const BloomFilterHash& obj);
+
+
+class Uncompressed : public virtual ::apache::thrift::TBase {
+ public:
+
+ Uncompressed(const Uncompressed&);
+ Uncompressed& operator=(const Uncompressed&);
+ Uncompressed() {
+ }
+
+ virtual ~Uncompressed() noexcept;
+
+ bool operator == (const Uncompressed & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const Uncompressed &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const Uncompressed & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(Uncompressed &a, Uncompressed &b);
+
+std::ostream& operator<<(std::ostream& out, const Uncompressed& obj);
+
+typedef struct _BloomFilterCompression__isset {
+ _BloomFilterCompression__isset() : UNCOMPRESSED(false) {}
+ bool UNCOMPRESSED :1;
+} _BloomFilterCompression__isset;
+
+class BloomFilterCompression : public virtual ::apache::thrift::TBase {
+ public:
+
+ BloomFilterCompression(const BloomFilterCompression&);
+ BloomFilterCompression& operator=(const BloomFilterCompression&);
+ BloomFilterCompression() {
+ }
+
+ virtual ~BloomFilterCompression() noexcept;
+ Uncompressed UNCOMPRESSED;
+
+ _BloomFilterCompression__isset __isset;
+
+ void __set_UNCOMPRESSED(const Uncompressed& val);
+
+ bool operator == (const BloomFilterCompression & rhs) const
+ {
+ if (__isset.UNCOMPRESSED != rhs.__isset.UNCOMPRESSED)
+ return false;
+ else if (__isset.UNCOMPRESSED && !(UNCOMPRESSED == rhs.UNCOMPRESSED))
+ return false;
+ return true;
+ }
+ bool operator != (const BloomFilterCompression &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const BloomFilterCompression & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(BloomFilterCompression &a, BloomFilterCompression &b);
+
+std::ostream& operator<<(std::ostream& out, const BloomFilterCompression& obj);
+
+
+class BloomFilterHeader : public virtual ::apache::thrift::TBase {
+ public:
+
+ BloomFilterHeader(const BloomFilterHeader&);
+ BloomFilterHeader& operator=(const BloomFilterHeader&);
+ BloomFilterHeader() : numBytes(0) {
+ }
+
+ virtual ~BloomFilterHeader() noexcept;
+ int32_t numBytes;
+ BloomFilterAlgorithm algorithm;
+ BloomFilterHash hash;
+ BloomFilterCompression compression;
+
+ void __set_numBytes(const int32_t val);
+
+ void __set_algorithm(const BloomFilterAlgorithm& val);
+
+ void __set_hash(const BloomFilterHash& val);
+
+ void __set_compression(const BloomFilterCompression& val);
+
+ bool operator == (const BloomFilterHeader & rhs) const
+ {
+ if (!(numBytes == rhs.numBytes))
+ return false;
+ if (!(algorithm == rhs.algorithm))
+ return false;
+ if (!(hash == rhs.hash))
+ return false;
+ if (!(compression == rhs.compression))
+ return false;
+ return true;
+ }
+ bool operator != (const BloomFilterHeader &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const BloomFilterHeader & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(BloomFilterHeader &a, BloomFilterHeader &b);
+
+std::ostream& operator<<(std::ostream& out, const BloomFilterHeader& obj);
+
+typedef struct _PageHeader__isset {
+ _PageHeader__isset() : crc(false), data_page_header(false), index_page_header(false), dictionary_page_header(false), data_page_header_v2(false) {}
+ bool crc :1;
+ bool data_page_header :1;
+ bool index_page_header :1;
+ bool dictionary_page_header :1;
+ bool data_page_header_v2 :1;
+} _PageHeader__isset;
+
+class PageHeader : public virtual ::apache::thrift::TBase {
+ public:
+
+ PageHeader(const PageHeader&);
+ PageHeader& operator=(const PageHeader&);
+ PageHeader() : type((PageType::type)0), uncompressed_page_size(0), compressed_page_size(0), crc(0) {
+ }
+
+ virtual ~PageHeader() noexcept;
+ PageType::type type;
+ int32_t uncompressed_page_size;
+ int32_t compressed_page_size;
+ int32_t crc;
+ DataPageHeader data_page_header;
+ IndexPageHeader index_page_header;
+ DictionaryPageHeader dictionary_page_header;
+ DataPageHeaderV2 data_page_header_v2;
+
+ _PageHeader__isset __isset;
+
+ void __set_type(const PageType::type val);
+
+ void __set_uncompressed_page_size(const int32_t val);
+
+ void __set_compressed_page_size(const int32_t val);
+
+ void __set_crc(const int32_t val);
+
+ void __set_data_page_header(const DataPageHeader& val);
+
+ void __set_index_page_header(const IndexPageHeader& val);
+
+ void __set_dictionary_page_header(const DictionaryPageHeader& val);
+
+ void __set_data_page_header_v2(const DataPageHeaderV2& val);
+
+ bool operator == (const PageHeader & rhs) const
+ {
+ if (!(type == rhs.type))
+ return false;
+ if (!(uncompressed_page_size == rhs.uncompressed_page_size))
+ return false;
+ if (!(compressed_page_size == rhs.compressed_page_size))
+ return false;
+ if (__isset.crc != rhs.__isset.crc)
+ return false;
+ else if (__isset.crc && !(crc == rhs.crc))
+ return false;
+ if (__isset.data_page_header != rhs.__isset.data_page_header)
+ return false;
+ else if (__isset.data_page_header && !(data_page_header == rhs.data_page_header))
+ return false;
+ if (__isset.index_page_header != rhs.__isset.index_page_header)
+ return false;
+ else if (__isset.index_page_header && !(index_page_header == rhs.index_page_header))
+ return false;
+ if (__isset.dictionary_page_header != rhs.__isset.dictionary_page_header)
+ return false;
+ else if (__isset.dictionary_page_header && !(dictionary_page_header == rhs.dictionary_page_header))
+ return false;
+ if (__isset.data_page_header_v2 != rhs.__isset.data_page_header_v2)
+ return false;
+ else if (__isset.data_page_header_v2 && !(data_page_header_v2 == rhs.data_page_header_v2))
+ return false;
+ return true;
+ }
+ bool operator != (const PageHeader &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const PageHeader & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(PageHeader &a, PageHeader &b);
+
+std::ostream& operator<<(std::ostream& out, const PageHeader& obj);
+
+typedef struct _KeyValue__isset {
+ _KeyValue__isset() : value(false) {}
+ bool value :1;
+} _KeyValue__isset;
+
+class KeyValue : public virtual ::apache::thrift::TBase {
+ public:
+
+ KeyValue(const KeyValue&);
+ KeyValue& operator=(const KeyValue&);
+ KeyValue() : key(), value() {
+ }
+
+ virtual ~KeyValue() noexcept;
+ std::string key;
+ std::string value;
+
+ _KeyValue__isset __isset;
+
+ void __set_key(const std::string& val);
+
+ void __set_value(const std::string& val);
+
+ bool operator == (const KeyValue & rhs) const
+ {
+ if (!(key == rhs.key))
+ return false;
+ if (__isset.value != rhs.__isset.value)
+ return false;
+ else if (__isset.value && !(value == rhs.value))
+ return false;
+ return true;
+ }
+ bool operator != (const KeyValue &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const KeyValue & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(KeyValue &a, KeyValue &b);
+
+std::ostream& operator<<(std::ostream& out, const KeyValue& obj);
+
+
+class SortingColumn : public virtual ::apache::thrift::TBase {
+ public:
+
+ SortingColumn(const SortingColumn&);
+ SortingColumn& operator=(const SortingColumn&);
+ SortingColumn() : column_idx(0), descending(0), nulls_first(0) {
+ }
+
+ virtual ~SortingColumn() noexcept;
+ int32_t column_idx;
+ bool descending;
+ bool nulls_first;
+
+ void __set_column_idx(const int32_t val);
+
+ void __set_descending(const bool val);
+
+ void __set_nulls_first(const bool val);
+
+ bool operator == (const SortingColumn & rhs) const
+ {
+ if (!(column_idx == rhs.column_idx))
+ return false;
+ if (!(descending == rhs.descending))
+ return false;
+ if (!(nulls_first == rhs.nulls_first))
+ return false;
+ return true;
+ }
+ bool operator != (const SortingColumn &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const SortingColumn & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(SortingColumn &a, SortingColumn &b);
+
+std::ostream& operator<<(std::ostream& out, const SortingColumn& obj);
+
+
+class PageEncodingStats : public virtual ::apache::thrift::TBase {
+ public:
+
+ PageEncodingStats(const PageEncodingStats&);
+ PageEncodingStats& operator=(const PageEncodingStats&);
+ PageEncodingStats() : page_type((PageType::type)0), encoding((Encoding::type)0), count(0) {
+ }
+
+ virtual ~PageEncodingStats() noexcept;
+ PageType::type page_type;
+ Encoding::type encoding;
+ int32_t count;
+
+ void __set_page_type(const PageType::type val);
+
+ void __set_encoding(const Encoding::type val);
+
+ void __set_count(const int32_t val);
+
+ bool operator == (const PageEncodingStats & rhs) const
+ {
+ if (!(page_type == rhs.page_type))
+ return false;
+ if (!(encoding == rhs.encoding))
+ return false;
+ if (!(count == rhs.count))
+ return false;
+ return true;
+ }
+ bool operator != (const PageEncodingStats &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const PageEncodingStats & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(PageEncodingStats &a, PageEncodingStats &b);
+
+std::ostream& operator<<(std::ostream& out, const PageEncodingStats& obj);
+
+typedef struct _ColumnMetaData__isset {
+ _ColumnMetaData__isset() : key_value_metadata(false), index_page_offset(false), dictionary_page_offset(false), statistics(false), encoding_stats(false), bloom_filter_offset(false) {}
+ bool key_value_metadata :1;
+ bool index_page_offset :1;
+ bool dictionary_page_offset :1;
+ bool statistics :1;
+ bool encoding_stats :1;
+ bool bloom_filter_offset :1;
+} _ColumnMetaData__isset;
+
+class ColumnMetaData : public virtual ::apache::thrift::TBase {
+ public:
+
+ ColumnMetaData(const ColumnMetaData&);
+ ColumnMetaData& operator=(const ColumnMetaData&);
+ ColumnMetaData() : type((Type::type)0), codec((CompressionCodec::type)0), num_values(0), total_uncompressed_size(0), total_compressed_size(0), data_page_offset(0), index_page_offset(0), dictionary_page_offset(0), bloom_filter_offset(0) {
+ }
+
+ virtual ~ColumnMetaData() noexcept;
+ Type::type type;
+ std::vector<Encoding::type> encodings;
+ std::vector<std::string> path_in_schema;
+ CompressionCodec::type codec;
+ int64_t num_values;
+ int64_t total_uncompressed_size;
+ int64_t total_compressed_size;
+ std::vector<KeyValue> key_value_metadata;
+ int64_t data_page_offset;
+ int64_t index_page_offset;
+ int64_t dictionary_page_offset;
+ Statistics statistics;
+ std::vector<PageEncodingStats> encoding_stats;
+ int64_t bloom_filter_offset;
+
+ _ColumnMetaData__isset __isset;
+
+ void __set_type(const Type::type val);
+
+ void __set_encodings(const std::vector<Encoding::type> & val);
+
+ void __set_path_in_schema(const std::vector<std::string> & val);
+
+ void __set_codec(const CompressionCodec::type val);
+
+ void __set_num_values(const int64_t val);
+
+ void __set_total_uncompressed_size(const int64_t val);
+
+ void __set_total_compressed_size(const int64_t val);
+
+ void __set_key_value_metadata(const std::vector<KeyValue> & val);
+
+ void __set_data_page_offset(const int64_t val);
+
+ void __set_index_page_offset(const int64_t val);
+
+ void __set_dictionary_page_offset(const int64_t val);
+
+ void __set_statistics(const Statistics& val);
+
+ void __set_encoding_stats(const std::vector<PageEncodingStats> & val);
+
+ void __set_bloom_filter_offset(const int64_t val);
+
+ bool operator == (const ColumnMetaData & rhs) const
+ {
+ if (!(type == rhs.type))
+ return false;
+ if (!(encodings == rhs.encodings))
+ return false;
+ if (!(path_in_schema == rhs.path_in_schema))
+ return false;
+ if (!(codec == rhs.codec))
+ return false;
+ if (!(num_values == rhs.num_values))
+ return false;
+ if (!(total_uncompressed_size == rhs.total_uncompressed_size))
+ return false;
+ if (!(total_compressed_size == rhs.total_compressed_size))
+ return false;
+ if (__isset.key_value_metadata != rhs.__isset.key_value_metadata)
+ return false;
+ else if (__isset.key_value_metadata && !(key_value_metadata == rhs.key_value_metadata))
+ return false;
+ if (!(data_page_offset == rhs.data_page_offset))
+ return false;
+ if (__isset.index_page_offset != rhs.__isset.index_page_offset)
+ return false;
+ else if (__isset.index_page_offset && !(index_page_offset == rhs.index_page_offset))
+ return false;
+ if (__isset.dictionary_page_offset != rhs.__isset.dictionary_page_offset)
+ return false;
+ else if (__isset.dictionary_page_offset && !(dictionary_page_offset == rhs.dictionary_page_offset))
+ return false;
+ if (__isset.statistics != rhs.__isset.statistics)
+ return false;
+ else if (__isset.statistics && !(statistics == rhs.statistics))
+ return false;
+ if (__isset.encoding_stats != rhs.__isset.encoding_stats)
+ return false;
+ else if (__isset.encoding_stats && !(encoding_stats == rhs.encoding_stats))
+ return false;
+ if (__isset.bloom_filter_offset != rhs.__isset.bloom_filter_offset)
+ return false;
+ else if (__isset.bloom_filter_offset && !(bloom_filter_offset == rhs.bloom_filter_offset))
+ return false;
+ return true;
+ }
+ bool operator != (const ColumnMetaData &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const ColumnMetaData & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(ColumnMetaData &a, ColumnMetaData &b);
+
+std::ostream& operator<<(std::ostream& out, const ColumnMetaData& obj);
+
+
+class EncryptionWithFooterKey : public virtual ::apache::thrift::TBase {
+ public:
+
+ EncryptionWithFooterKey(const EncryptionWithFooterKey&);
+ EncryptionWithFooterKey& operator=(const EncryptionWithFooterKey&);
+ EncryptionWithFooterKey() {
+ }
+
+ virtual ~EncryptionWithFooterKey() noexcept;
+
+ bool operator == (const EncryptionWithFooterKey & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const EncryptionWithFooterKey &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const EncryptionWithFooterKey & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(EncryptionWithFooterKey &a, EncryptionWithFooterKey &b);
+
+std::ostream& operator<<(std::ostream& out, const EncryptionWithFooterKey& obj);
+
+typedef struct _EncryptionWithColumnKey__isset {
+ _EncryptionWithColumnKey__isset() : key_metadata(false) {}
+ bool key_metadata :1;
+} _EncryptionWithColumnKey__isset;
+
+class EncryptionWithColumnKey : public virtual ::apache::thrift::TBase {
+ public:
+
+ EncryptionWithColumnKey(const EncryptionWithColumnKey&);
+ EncryptionWithColumnKey& operator=(const EncryptionWithColumnKey&);
+ EncryptionWithColumnKey() : key_metadata() {
+ }
+
+ virtual ~EncryptionWithColumnKey() noexcept;
+ std::vector<std::string> path_in_schema;
+ std::string key_metadata;
+
+ _EncryptionWithColumnKey__isset __isset;
+
+ void __set_path_in_schema(const std::vector<std::string> & val);
+
+ void __set_key_metadata(const std::string& val);
+
+ bool operator == (const EncryptionWithColumnKey & rhs) const
+ {
+ if (!(path_in_schema == rhs.path_in_schema))
+ return false;
+ if (__isset.key_metadata != rhs.__isset.key_metadata)
+ return false;
+ else if (__isset.key_metadata && !(key_metadata == rhs.key_metadata))
+ return false;
+ return true;
+ }
+ bool operator != (const EncryptionWithColumnKey &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const EncryptionWithColumnKey & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(EncryptionWithColumnKey &a, EncryptionWithColumnKey &b);
+
+std::ostream& operator<<(std::ostream& out, const EncryptionWithColumnKey& obj);
+
+typedef struct _ColumnCryptoMetaData__isset {
+ _ColumnCryptoMetaData__isset() : ENCRYPTION_WITH_FOOTER_KEY(false), ENCRYPTION_WITH_COLUMN_KEY(false) {}
+ bool ENCRYPTION_WITH_FOOTER_KEY :1;
+ bool ENCRYPTION_WITH_COLUMN_KEY :1;
+} _ColumnCryptoMetaData__isset;
+
+class ColumnCryptoMetaData : public virtual ::apache::thrift::TBase {
+ public:
+
+ ColumnCryptoMetaData(const ColumnCryptoMetaData&);
+ ColumnCryptoMetaData& operator=(const ColumnCryptoMetaData&);
+ ColumnCryptoMetaData() {
+ }
+
+ virtual ~ColumnCryptoMetaData() noexcept;
+ EncryptionWithFooterKey ENCRYPTION_WITH_FOOTER_KEY;
+ EncryptionWithColumnKey ENCRYPTION_WITH_COLUMN_KEY;
+
+ _ColumnCryptoMetaData__isset __isset;
+
+ void __set_ENCRYPTION_WITH_FOOTER_KEY(const EncryptionWithFooterKey& val);
+
+ void __set_ENCRYPTION_WITH_COLUMN_KEY(const EncryptionWithColumnKey& val);
+
+ bool operator == (const ColumnCryptoMetaData & rhs) const
+ {
+ if (__isset.ENCRYPTION_WITH_FOOTER_KEY != rhs.__isset.ENCRYPTION_WITH_FOOTER_KEY)
+ return false;
+ else if (__isset.ENCRYPTION_WITH_FOOTER_KEY && !(ENCRYPTION_WITH_FOOTER_KEY == rhs.ENCRYPTION_WITH_FOOTER_KEY))
+ return false;
+ if (__isset.ENCRYPTION_WITH_COLUMN_KEY != rhs.__isset.ENCRYPTION_WITH_COLUMN_KEY)
+ return false;
+ else if (__isset.ENCRYPTION_WITH_COLUMN_KEY && !(ENCRYPTION_WITH_COLUMN_KEY == rhs.ENCRYPTION_WITH_COLUMN_KEY))
+ return false;
+ return true;
+ }
+ bool operator != (const ColumnCryptoMetaData &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const ColumnCryptoMetaData & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(ColumnCryptoMetaData &a, ColumnCryptoMetaData &b);
+
+std::ostream& operator<<(std::ostream& out, const ColumnCryptoMetaData& obj);
+
+typedef struct _ColumnChunk__isset {
+ _ColumnChunk__isset() : file_path(false), meta_data(false), offset_index_offset(false), offset_index_length(false), column_index_offset(false), column_index_length(false), crypto_metadata(false), encrypted_column_metadata(false) {}
+ bool file_path :1;
+ bool meta_data :1;
+ bool offset_index_offset :1;
+ bool offset_index_length :1;
+ bool column_index_offset :1;
+ bool column_index_length :1;
+ bool crypto_metadata :1;
+ bool encrypted_column_metadata :1;
+} _ColumnChunk__isset;
+
+class ColumnChunk : public virtual ::apache::thrift::TBase {
+ public:
+
+ ColumnChunk(const ColumnChunk&);
+ ColumnChunk& operator=(const ColumnChunk&);
+ ColumnChunk() : file_path(), file_offset(0), offset_index_offset(0), offset_index_length(0), column_index_offset(0), column_index_length(0), encrypted_column_metadata() {
+ }
+
+ virtual ~ColumnChunk() noexcept;
+ std::string file_path;
+ int64_t file_offset;
+ ColumnMetaData meta_data;
+ int64_t offset_index_offset;
+ int32_t offset_index_length;
+ int64_t column_index_offset;
+ int32_t column_index_length;
+ ColumnCryptoMetaData crypto_metadata;
+ std::string encrypted_column_metadata;
+
+ _ColumnChunk__isset __isset;
+
+ void __set_file_path(const std::string& val);
+
+ void __set_file_offset(const int64_t val);
+
+ void __set_meta_data(const ColumnMetaData& val);
+
+ void __set_offset_index_offset(const int64_t val);
+
+ void __set_offset_index_length(const int32_t val);
+
+ void __set_column_index_offset(const int64_t val);
+
+ void __set_column_index_length(const int32_t val);
+
+ void __set_crypto_metadata(const ColumnCryptoMetaData& val);
+
+ void __set_encrypted_column_metadata(const std::string& val);
+
+ bool operator == (const ColumnChunk & rhs) const
+ {
+ if (__isset.file_path != rhs.__isset.file_path)
+ return false;
+ else if (__isset.file_path && !(file_path == rhs.file_path))
+ return false;
+ if (!(file_offset == rhs.file_offset))
+ return false;
+ if (__isset.meta_data != rhs.__isset.meta_data)
+ return false;
+ else if (__isset.meta_data && !(meta_data == rhs.meta_data))
+ return false;
+ if (__isset.offset_index_offset != rhs.__isset.offset_index_offset)
+ return false;
+ else if (__isset.offset_index_offset && !(offset_index_offset == rhs.offset_index_offset))
+ return false;
+ if (__isset.offset_index_length != rhs.__isset.offset_index_length)
+ return false;
+ else if (__isset.offset_index_length && !(offset_index_length == rhs.offset_index_length))
+ return false;
+ if (__isset.column_index_offset != rhs.__isset.column_index_offset)
+ return false;
+ else if (__isset.column_index_offset && !(column_index_offset == rhs.column_index_offset))
+ return false;
+ if (__isset.column_index_length != rhs.__isset.column_index_length)
+ return false;
+ else if (__isset.column_index_length && !(column_index_length == rhs.column_index_length))
+ return false;
+ if (__isset.crypto_metadata != rhs.__isset.crypto_metadata)
+ return false;
+ else if (__isset.crypto_metadata && !(crypto_metadata == rhs.crypto_metadata))
+ return false;
+ if (__isset.encrypted_column_metadata != rhs.__isset.encrypted_column_metadata)
+ return false;
+ else if (__isset.encrypted_column_metadata && !(encrypted_column_metadata == rhs.encrypted_column_metadata))
+ return false;
+ return true;
+ }
+ bool operator != (const ColumnChunk &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const ColumnChunk & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(ColumnChunk &a, ColumnChunk &b);
+
+std::ostream& operator<<(std::ostream& out, const ColumnChunk& obj);
+
+typedef struct _RowGroup__isset {
+ _RowGroup__isset() : sorting_columns(false), file_offset(false), total_compressed_size(false), ordinal(false) {}
+ bool sorting_columns :1;
+ bool file_offset :1;
+ bool total_compressed_size :1;
+ bool ordinal :1;
+} _RowGroup__isset;
+
+class RowGroup : public virtual ::apache::thrift::TBase {
+ public:
+
+ RowGroup(const RowGroup&);
+ RowGroup& operator=(const RowGroup&);
+ RowGroup() : total_byte_size(0), num_rows(0), file_offset(0), total_compressed_size(0), ordinal(0) {
+ }
+
+ virtual ~RowGroup() noexcept;
+ std::vector<ColumnChunk> columns;
+ int64_t total_byte_size;
+ int64_t num_rows;
+ std::vector<SortingColumn> sorting_columns;
+ int64_t file_offset;
+ int64_t total_compressed_size;
+ int16_t ordinal;
+
+ _RowGroup__isset __isset;
+
+ void __set_columns(const std::vector<ColumnChunk> & val);
+
+ void __set_total_byte_size(const int64_t val);
+
+ void __set_num_rows(const int64_t val);
+
+ void __set_sorting_columns(const std::vector<SortingColumn> & val);
+
+ void __set_file_offset(const int64_t val);
+
+ void __set_total_compressed_size(const int64_t val);
+
+ void __set_ordinal(const int16_t val);
+
+ bool operator == (const RowGroup & rhs) const
+ {
+ if (!(columns == rhs.columns))
+ return false;
+ if (!(total_byte_size == rhs.total_byte_size))
+ return false;
+ if (!(num_rows == rhs.num_rows))
+ return false;
+ if (__isset.sorting_columns != rhs.__isset.sorting_columns)
+ return false;
+ else if (__isset.sorting_columns && !(sorting_columns == rhs.sorting_columns))
+ return false;
+ if (__isset.file_offset != rhs.__isset.file_offset)
+ return false;
+ else if (__isset.file_offset && !(file_offset == rhs.file_offset))
+ return false;
+ if (__isset.total_compressed_size != rhs.__isset.total_compressed_size)
+ return false;
+ else if (__isset.total_compressed_size && !(total_compressed_size == rhs.total_compressed_size))
+ return false;
+ if (__isset.ordinal != rhs.__isset.ordinal)
+ return false;
+ else if (__isset.ordinal && !(ordinal == rhs.ordinal))
+ return false;
+ return true;
+ }
+ bool operator != (const RowGroup &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const RowGroup & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(RowGroup &a, RowGroup &b);
+
+std::ostream& operator<<(std::ostream& out, const RowGroup& obj);
+
+
+class TypeDefinedOrder : public virtual ::apache::thrift::TBase {
+ public:
+
+ TypeDefinedOrder(const TypeDefinedOrder&);
+ TypeDefinedOrder& operator=(const TypeDefinedOrder&);
+ TypeDefinedOrder() {
+ }
+
+ virtual ~TypeDefinedOrder() noexcept;
+
+ bool operator == (const TypeDefinedOrder & /* rhs */) const
+ {
+ return true;
+ }
+ bool operator != (const TypeDefinedOrder &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const TypeDefinedOrder & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(TypeDefinedOrder &a, TypeDefinedOrder &b);
+
+std::ostream& operator<<(std::ostream& out, const TypeDefinedOrder& obj);
+
+typedef struct _ColumnOrder__isset {
+ _ColumnOrder__isset() : TYPE_ORDER(false) {}
+ bool TYPE_ORDER :1;
+} _ColumnOrder__isset;
+
+class ColumnOrder : public virtual ::apache::thrift::TBase {
+ public:
+
+ ColumnOrder(const ColumnOrder&);
+ ColumnOrder& operator=(const ColumnOrder&);
+ ColumnOrder() {
+ }
+
+ virtual ~ColumnOrder() noexcept;
+ TypeDefinedOrder TYPE_ORDER;
+
+ _ColumnOrder__isset __isset;
+
+ void __set_TYPE_ORDER(const TypeDefinedOrder& val);
+
+ bool operator == (const ColumnOrder & rhs) const
+ {
+ if (__isset.TYPE_ORDER != rhs.__isset.TYPE_ORDER)
+ return false;
+ else if (__isset.TYPE_ORDER && !(TYPE_ORDER == rhs.TYPE_ORDER))
+ return false;
+ return true;
+ }
+ bool operator != (const ColumnOrder &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const ColumnOrder & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(ColumnOrder &a, ColumnOrder &b);
+
+std::ostream& operator<<(std::ostream& out, const ColumnOrder& obj);
+
+
+class PageLocation : public virtual ::apache::thrift::TBase {
+ public:
+
+ PageLocation(const PageLocation&);
+ PageLocation& operator=(const PageLocation&);
+ PageLocation() : offset(0), compressed_page_size(0), first_row_index(0) {
+ }
+
+ virtual ~PageLocation() noexcept;
+ int64_t offset;
+ int32_t compressed_page_size;
+ int64_t first_row_index;
+
+ void __set_offset(const int64_t val);
+
+ void __set_compressed_page_size(const int32_t val);
+
+ void __set_first_row_index(const int64_t val);
+
+ bool operator == (const PageLocation & rhs) const
+ {
+ if (!(offset == rhs.offset))
+ return false;
+ if (!(compressed_page_size == rhs.compressed_page_size))
+ return false;
+ if (!(first_row_index == rhs.first_row_index))
+ return false;
+ return true;
+ }
+ bool operator != (const PageLocation &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const PageLocation & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(PageLocation &a, PageLocation &b);
+
+std::ostream& operator<<(std::ostream& out, const PageLocation& obj);
+
+
+class OffsetIndex : public virtual ::apache::thrift::TBase {
+ public:
+
+ OffsetIndex(const OffsetIndex&);
+ OffsetIndex& operator=(const OffsetIndex&);
+ OffsetIndex() {
+ }
+
+ virtual ~OffsetIndex() noexcept;
+ std::vector<PageLocation> page_locations;
+
+ void __set_page_locations(const std::vector<PageLocation> & val);
+
+ bool operator == (const OffsetIndex & rhs) const
+ {
+ if (!(page_locations == rhs.page_locations))
+ return false;
+ return true;
+ }
+ bool operator != (const OffsetIndex &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const OffsetIndex & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(OffsetIndex &a, OffsetIndex &b);
+
+std::ostream& operator<<(std::ostream& out, const OffsetIndex& obj);
+
+typedef struct _ColumnIndex__isset {
+ _ColumnIndex__isset() : null_counts(false) {}
+ bool null_counts :1;
+} _ColumnIndex__isset;
+
+class ColumnIndex : public virtual ::apache::thrift::TBase {
+ public:
+
+ ColumnIndex(const ColumnIndex&);
+ ColumnIndex& operator=(const ColumnIndex&);
+ ColumnIndex() : boundary_order((BoundaryOrder::type)0) {
+ }
+
+ virtual ~ColumnIndex() noexcept;
+ std::vector<bool> null_pages;
+ std::vector<std::string> min_values;
+ std::vector<std::string> max_values;
+ BoundaryOrder::type boundary_order;
+ std::vector<int64_t> null_counts;
+
+ _ColumnIndex__isset __isset;
+
+ void __set_null_pages(const std::vector<bool> & val);
+
+ void __set_min_values(const std::vector<std::string> & val);
+
+ void __set_max_values(const std::vector<std::string> & val);
+
+ void __set_boundary_order(const BoundaryOrder::type val);
+
+ void __set_null_counts(const std::vector<int64_t> & val);
+
+ bool operator == (const ColumnIndex & rhs) const
+ {
+ if (!(null_pages == rhs.null_pages))
+ return false;
+ if (!(min_values == rhs.min_values))
+ return false;
+ if (!(max_values == rhs.max_values))
+ return false;
+ if (!(boundary_order == rhs.boundary_order))
+ return false;
+ if (__isset.null_counts != rhs.__isset.null_counts)
+ return false;
+ else if (__isset.null_counts && !(null_counts == rhs.null_counts))
+ return false;
+ return true;
+ }
+ bool operator != (const ColumnIndex &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const ColumnIndex & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(ColumnIndex &a, ColumnIndex &b);
+
+std::ostream& operator<<(std::ostream& out, const ColumnIndex& obj);
+
+typedef struct _AesGcmV1__isset {
+ _AesGcmV1__isset() : aad_prefix(false), aad_file_unique(false), supply_aad_prefix(false) {}
+ bool aad_prefix :1;
+ bool aad_file_unique :1;
+ bool supply_aad_prefix :1;
+} _AesGcmV1__isset;
+
+class AesGcmV1 : public virtual ::apache::thrift::TBase {
+ public:
+
+ AesGcmV1(const AesGcmV1&);
+ AesGcmV1& operator=(const AesGcmV1&);
+ AesGcmV1() : aad_prefix(), aad_file_unique(), supply_aad_prefix(0) {
+ }
+
+ virtual ~AesGcmV1() noexcept;
+ std::string aad_prefix;
+ std::string aad_file_unique;
+ bool supply_aad_prefix;
+
+ _AesGcmV1__isset __isset;
+
+ void __set_aad_prefix(const std::string& val);
+
+ void __set_aad_file_unique(const std::string& val);
+
+ void __set_supply_aad_prefix(const bool val);
+
+ bool operator == (const AesGcmV1 & rhs) const
+ {
+ if (__isset.aad_prefix != rhs.__isset.aad_prefix)
+ return false;
+ else if (__isset.aad_prefix && !(aad_prefix == rhs.aad_prefix))
+ return false;
+ if (__isset.aad_file_unique != rhs.__isset.aad_file_unique)
+ return false;
+ else if (__isset.aad_file_unique && !(aad_file_unique == rhs.aad_file_unique))
+ return false;
+ if (__isset.supply_aad_prefix != rhs.__isset.supply_aad_prefix)
+ return false;
+ else if (__isset.supply_aad_prefix && !(supply_aad_prefix == rhs.supply_aad_prefix))
+ return false;
+ return true;
+ }
+ bool operator != (const AesGcmV1 &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const AesGcmV1 & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(AesGcmV1 &a, AesGcmV1 &b);
+
+std::ostream& operator<<(std::ostream& out, const AesGcmV1& obj);
+
+typedef struct _AesGcmCtrV1__isset {
+ _AesGcmCtrV1__isset() : aad_prefix(false), aad_file_unique(false), supply_aad_prefix(false) {}
+ bool aad_prefix :1;
+ bool aad_file_unique :1;
+ bool supply_aad_prefix :1;
+} _AesGcmCtrV1__isset;
+
+class AesGcmCtrV1 : public virtual ::apache::thrift::TBase {
+ public:
+
+ AesGcmCtrV1(const AesGcmCtrV1&);
+ AesGcmCtrV1& operator=(const AesGcmCtrV1&);
+ AesGcmCtrV1() : aad_prefix(), aad_file_unique(), supply_aad_prefix(0) {
+ }
+
+ virtual ~AesGcmCtrV1() noexcept;
+ std::string aad_prefix;
+ std::string aad_file_unique;
+ bool supply_aad_prefix;
+
+ _AesGcmCtrV1__isset __isset;
+
+ void __set_aad_prefix(const std::string& val);
+
+ void __set_aad_file_unique(const std::string& val);
+
+ void __set_supply_aad_prefix(const bool val);
+
+ bool operator == (const AesGcmCtrV1 & rhs) const
+ {
+ if (__isset.aad_prefix != rhs.__isset.aad_prefix)
+ return false;
+ else if (__isset.aad_prefix && !(aad_prefix == rhs.aad_prefix))
+ return false;
+ if (__isset.aad_file_unique != rhs.__isset.aad_file_unique)
+ return false;
+ else if (__isset.aad_file_unique && !(aad_file_unique == rhs.aad_file_unique))
+ return false;
+ if (__isset.supply_aad_prefix != rhs.__isset.supply_aad_prefix)
+ return false;
+ else if (__isset.supply_aad_prefix && !(supply_aad_prefix == rhs.supply_aad_prefix))
+ return false;
+ return true;
+ }
+ bool operator != (const AesGcmCtrV1 &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const AesGcmCtrV1 & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(AesGcmCtrV1 &a, AesGcmCtrV1 &b);
+
+std::ostream& operator<<(std::ostream& out, const AesGcmCtrV1& obj);
+
+typedef struct _EncryptionAlgorithm__isset {
+ _EncryptionAlgorithm__isset() : AES_GCM_V1(false), AES_GCM_CTR_V1(false) {}
+ bool AES_GCM_V1 :1;
+ bool AES_GCM_CTR_V1 :1;
+} _EncryptionAlgorithm__isset;
+
+class EncryptionAlgorithm : public virtual ::apache::thrift::TBase {
+ public:
+
+ EncryptionAlgorithm(const EncryptionAlgorithm&);
+ EncryptionAlgorithm& operator=(const EncryptionAlgorithm&);
+ EncryptionAlgorithm() {
+ }
+
+ virtual ~EncryptionAlgorithm() noexcept;
+ AesGcmV1 AES_GCM_V1;
+ AesGcmCtrV1 AES_GCM_CTR_V1;
+
+ _EncryptionAlgorithm__isset __isset;
+
+ void __set_AES_GCM_V1(const AesGcmV1& val);
+
+ void __set_AES_GCM_CTR_V1(const AesGcmCtrV1& val);
+
+ bool operator == (const EncryptionAlgorithm & rhs) const
+ {
+ if (__isset.AES_GCM_V1 != rhs.__isset.AES_GCM_V1)
+ return false;
+ else if (__isset.AES_GCM_V1 && !(AES_GCM_V1 == rhs.AES_GCM_V1))
+ return false;
+ if (__isset.AES_GCM_CTR_V1 != rhs.__isset.AES_GCM_CTR_V1)
+ return false;
+ else if (__isset.AES_GCM_CTR_V1 && !(AES_GCM_CTR_V1 == rhs.AES_GCM_CTR_V1))
+ return false;
+ return true;
+ }
+ bool operator != (const EncryptionAlgorithm &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const EncryptionAlgorithm & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(EncryptionAlgorithm &a, EncryptionAlgorithm &b);
+
+std::ostream& operator<<(std::ostream& out, const EncryptionAlgorithm& obj);
+
+typedef struct _FileMetaData__isset {
+ _FileMetaData__isset() : key_value_metadata(false), created_by(false), column_orders(false), encryption_algorithm(false), footer_signing_key_metadata(false) {}
+ bool key_value_metadata :1;
+ bool created_by :1;
+ bool column_orders :1;
+ bool encryption_algorithm :1;
+ bool footer_signing_key_metadata :1;
+} _FileMetaData__isset;
+
+class FileMetaData : public virtual ::apache::thrift::TBase {
+ public:
+
+ FileMetaData(const FileMetaData&);
+ FileMetaData& operator=(const FileMetaData&);
+ FileMetaData() : version(0), num_rows(0), created_by(), footer_signing_key_metadata() {
+ }
+
+ virtual ~FileMetaData() noexcept;
+ int32_t version;
+ std::vector<SchemaElement> schema;
+ int64_t num_rows;
+ std::vector<RowGroup> row_groups;
+ std::vector<KeyValue> key_value_metadata;
+ std::string created_by;
+ std::vector<ColumnOrder> column_orders;
+ EncryptionAlgorithm encryption_algorithm;
+ std::string footer_signing_key_metadata;
+
+ _FileMetaData__isset __isset;
+
+ void __set_version(const int32_t val);
+
+ void __set_schema(const std::vector<SchemaElement> & val);
+
+ void __set_num_rows(const int64_t val);
+
+ void __set_row_groups(const std::vector<RowGroup> & val);
+
+ void __set_key_value_metadata(const std::vector<KeyValue> & val);
+
+ void __set_created_by(const std::string& val);
+
+ void __set_column_orders(const std::vector<ColumnOrder> & val);
+
+ void __set_encryption_algorithm(const EncryptionAlgorithm& val);
+
+ void __set_footer_signing_key_metadata(const std::string& val);
+
+ bool operator == (const FileMetaData & rhs) const
+ {
+ if (!(version == rhs.version))
+ return false;
+ if (!(schema == rhs.schema))
+ return false;
+ if (!(num_rows == rhs.num_rows))
+ return false;
+ if (!(row_groups == rhs.row_groups))
+ return false;
+ if (__isset.key_value_metadata != rhs.__isset.key_value_metadata)
+ return false;
+ else if (__isset.key_value_metadata && !(key_value_metadata == rhs.key_value_metadata))
+ return false;
+ if (__isset.created_by != rhs.__isset.created_by)
+ return false;
+ else if (__isset.created_by && !(created_by == rhs.created_by))
+ return false;
+ if (__isset.column_orders != rhs.__isset.column_orders)
+ return false;
+ else if (__isset.column_orders && !(column_orders == rhs.column_orders))
+ return false;
+ if (__isset.encryption_algorithm != rhs.__isset.encryption_algorithm)
+ return false;
+ else if (__isset.encryption_algorithm && !(encryption_algorithm == rhs.encryption_algorithm))
+ return false;
+ if (__isset.footer_signing_key_metadata != rhs.__isset.footer_signing_key_metadata)
+ return false;
+ else if (__isset.footer_signing_key_metadata && !(footer_signing_key_metadata == rhs.footer_signing_key_metadata))
+ return false;
+ return true;
+ }
+ bool operator != (const FileMetaData &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const FileMetaData & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(FileMetaData &a, FileMetaData &b);
+
+std::ostream& operator<<(std::ostream& out, const FileMetaData& obj);
+
+typedef struct _FileCryptoMetaData__isset {
+ _FileCryptoMetaData__isset() : key_metadata(false) {}
+ bool key_metadata :1;
+} _FileCryptoMetaData__isset;
+
+class FileCryptoMetaData : public virtual ::apache::thrift::TBase {
+ public:
+
+ FileCryptoMetaData(const FileCryptoMetaData&);
+ FileCryptoMetaData& operator=(const FileCryptoMetaData&);
+ FileCryptoMetaData() : key_metadata() {
+ }
+
+ virtual ~FileCryptoMetaData() noexcept;
+ EncryptionAlgorithm encryption_algorithm;
+ std::string key_metadata;
+
+ _FileCryptoMetaData__isset __isset;
+
+ void __set_encryption_algorithm(const EncryptionAlgorithm& val);
+
+ void __set_key_metadata(const std::string& val);
+
+ bool operator == (const FileCryptoMetaData & rhs) const
+ {
+ if (!(encryption_algorithm == rhs.encryption_algorithm))
+ return false;
+ if (__isset.key_metadata != rhs.__isset.key_metadata)
+ return false;
+ else if (__isset.key_metadata && !(key_metadata == rhs.key_metadata))
+ return false;
+ return true;
+ }
+ bool operator != (const FileCryptoMetaData &rhs) const {
+ return !(*this == rhs);
+ }
+
+ bool operator < (const FileCryptoMetaData & ) const;
+
+ uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
+ uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
+
+ virtual void printTo(std::ostream& out) const;
+};
+
+void swap(FileCryptoMetaData &a, FileCryptoMetaData &b);
+
+std::ostream& operator<<(std::ostream& out, const FileCryptoMetaData& obj);
+
+}} // namespace
+
+#endif
diff --git a/src/arrow/cpp/src/jni/CMakeLists.txt b/src/arrow/cpp/src/jni/CMakeLists.txt
new file mode 100644
index 000000000..3a5cc7fca
--- /dev/null
+++ b/src/arrow/cpp/src/jni/CMakeLists.txt
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#
+# arrow_jni
+#
+if(ARROW_ORC)
+ add_subdirectory(orc)
+endif()
+
+if(ARROW_DATASET)
+ add_subdirectory(dataset)
+endif()
diff --git a/src/arrow/cpp/src/jni/dataset/CMakeLists.txt b/src/arrow/cpp/src/jni/dataset/CMakeLists.txt
new file mode 100644
index 000000000..f3e309b61
--- /dev/null
+++ b/src/arrow/cpp/src/jni/dataset/CMakeLists.txt
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitationsn
+# under the License.
+
+#
+# arrow_dataset_jni
+#
+
+project(arrow_dataset_jni)
+
+cmake_minimum_required(VERSION 3.11)
+
+find_package(JNI REQUIRED)
+
+add_custom_target(arrow_dataset_jni)
+
+set(JNI_HEADERS_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated")
+
+add_subdirectory(../../../../java/dataset ./java)
+
+set(ARROW_BUILD_STATIC OFF)
+
+set(ARROW_DATASET_JNI_LIBS arrow_dataset_static)
+
+set(ARROW_DATASET_JNI_SOURCES jni_wrapper.cc jni_util.cc)
+
+add_arrow_lib(arrow_dataset_jni
+ BUILD_SHARED
+ SOURCES
+ ${ARROW_DATASET_JNI_SOURCES}
+ OUTPUTS
+ ARROW_DATASET_JNI_LIBRARIES
+ SHARED_PRIVATE_LINK_LIBS
+ ${ARROW_DATASET_JNI_LIBS}
+ STATIC_LINK_LIBS
+ ${ARROW_DATASET_JNI_LIBS}
+ EXTRA_INCLUDES
+ ${JNI_HEADERS_DIR}
+ PRIVATE_INCLUDES
+ ${JNI_INCLUDE_DIRS}
+ DEPENDENCIES
+ arrow_static
+ arrow_dataset_java)
+
+add_dependencies(arrow_dataset_jni ${ARROW_DATASET_JNI_LIBRARIES})
+
+add_arrow_test(dataset_jni_test
+ SOURCES
+ jni_util_test.cc
+ jni_util.cc
+ EXTRA_INCLUDES
+ ${JNI_INCLUDE_DIRS})
diff --git a/src/arrow/cpp/src/jni/dataset/jni_util.cc b/src/arrow/cpp/src/jni/dataset/jni_util.cc
new file mode 100644
index 000000000..113669a4c
--- /dev/null
+++ b/src/arrow/cpp/src/jni/dataset/jni_util.cc
@@ -0,0 +1,242 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "jni/dataset/jni_util.h"
+
+#include "arrow/util/logging.h"
+
+#include <mutex>
+
+namespace arrow {
+namespace dataset {
+namespace jni {
+
+class ReservationListenableMemoryPool::Impl {
+ public:
+ explicit Impl(arrow::MemoryPool* pool, std::shared_ptr<ReservationListener> listener,
+ int64_t block_size)
+ : pool_(pool),
+ listener_(listener),
+ block_size_(block_size),
+ blocks_reserved_(0),
+ bytes_reserved_(0) {}
+
+ arrow::Status Allocate(int64_t size, uint8_t** out) {
+ RETURN_NOT_OK(UpdateReservation(size));
+ arrow::Status error = pool_->Allocate(size, out);
+ if (!error.ok()) {
+ RETURN_NOT_OK(UpdateReservation(-size));
+ return error;
+ }
+ return arrow::Status::OK();
+ }
+
+ arrow::Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) {
+ bool reserved = false;
+ int64_t diff = new_size - old_size;
+ if (new_size >= old_size) {
+ // new_size >= old_size, pre-reserve bytes from listener before allocating
+ // from underlying pool
+ RETURN_NOT_OK(UpdateReservation(diff));
+ reserved = true;
+ }
+ arrow::Status error = pool_->Reallocate(old_size, new_size, ptr);
+ if (!error.ok()) {
+ if (reserved) {
+ // roll back reservations on error
+ RETURN_NOT_OK(UpdateReservation(-diff));
+ }
+ return error;
+ }
+ if (!reserved) {
+ // otherwise (e.g. new_size < old_size), make updates after calling underlying pool
+ RETURN_NOT_OK(UpdateReservation(diff));
+ }
+ return arrow::Status::OK();
+ }
+
+ void Free(uint8_t* buffer, int64_t size) {
+ pool_->Free(buffer, size);
+ // FIXME: See ARROW-11143, currently method ::Free doesn't allow Status return
+ arrow::Status s = UpdateReservation(-size);
+ if (!s.ok()) {
+ ARROW_LOG(FATAL) << "Failed to update reservation while freeing bytes: "
+ << s.message();
+ return;
+ }
+ }
+
+ arrow::Status UpdateReservation(int64_t diff) {
+ int64_t granted = Reserve(diff);
+ if (granted == 0) {
+ return arrow::Status::OK();
+ }
+ if (granted < 0) {
+ RETURN_NOT_OK(listener_->OnRelease(-granted));
+ return arrow::Status::OK();
+ }
+ RETURN_NOT_OK(listener_->OnReservation(granted));
+ return arrow::Status::OK();
+ }
+
+ int64_t Reserve(int64_t diff) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ bytes_reserved_ += diff;
+ int64_t new_block_count;
+ if (bytes_reserved_ == 0) {
+ new_block_count = 0;
+ } else {
+ // ceil to get the required block number
+ new_block_count = (bytes_reserved_ - 1) / block_size_ + 1;
+ }
+ int64_t bytes_granted = (new_block_count - blocks_reserved_) * block_size_;
+ blocks_reserved_ = new_block_count;
+ return bytes_granted;
+ }
+
+ int64_t bytes_allocated() { return pool_->bytes_allocated(); }
+
+ int64_t max_memory() { return pool_->max_memory(); }
+
+ std::string backend_name() { return pool_->backend_name(); }
+
+ std::shared_ptr<ReservationListener> get_listener() { return listener_; }
+
+ private:
+ arrow::MemoryPool* pool_;
+ std::shared_ptr<ReservationListener> listener_;
+ int64_t block_size_;
+ int64_t blocks_reserved_;
+ int64_t bytes_reserved_;
+ std::mutex mutex_;
+};
+
+ReservationListenableMemoryPool::ReservationListenableMemoryPool(
+ MemoryPool* pool, std::shared_ptr<ReservationListener> listener, int64_t block_size) {
+ impl_.reset(new Impl(pool, listener, block_size));
+}
+
+arrow::Status ReservationListenableMemoryPool::Allocate(int64_t size, uint8_t** out) {
+ return impl_->Allocate(size, out);
+}
+
+arrow::Status ReservationListenableMemoryPool::Reallocate(int64_t old_size,
+ int64_t new_size,
+ uint8_t** ptr) {
+ return impl_->Reallocate(old_size, new_size, ptr);
+}
+
+void ReservationListenableMemoryPool::Free(uint8_t* buffer, int64_t size) {
+ return impl_->Free(buffer, size);
+}
+
+int64_t ReservationListenableMemoryPool::bytes_allocated() const {
+ return impl_->bytes_allocated();
+}
+
+int64_t ReservationListenableMemoryPool::max_memory() const {
+ return impl_->max_memory();
+}
+
+std::string ReservationListenableMemoryPool::backend_name() const {
+ return impl_->backend_name();
+}
+
+std::shared_ptr<ReservationListener> ReservationListenableMemoryPool::get_listener() {
+ return impl_->get_listener();
+}
+
+ReservationListenableMemoryPool::~ReservationListenableMemoryPool() {}
+
+jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) {
+ jclass local_class = env->FindClass(class_name);
+ jclass global_class = (jclass)env->NewGlobalRef(local_class);
+ env->DeleteLocalRef(local_class);
+ return global_class;
+}
+
+arrow::Result<jmethodID> GetMethodID(JNIEnv* env, jclass this_class, const char* name,
+ const char* sig) {
+ jmethodID ret = env->GetMethodID(this_class, name, sig);
+ if (ret == nullptr) {
+ std::string error_message = "Unable to find method " + std::string(name) +
+ " within signature" + std::string(sig);
+ return arrow::Status::Invalid(error_message);
+ }
+ return ret;
+}
+
+arrow::Result<jmethodID> GetStaticMethodID(JNIEnv* env, jclass this_class,
+ const char* name, const char* sig) {
+ jmethodID ret = env->GetStaticMethodID(this_class, name, sig);
+ if (ret == nullptr) {
+ std::string error_message = "Unable to find static method " + std::string(name) +
+ " within signature" + std::string(sig);
+ return arrow::Status::Invalid(error_message);
+ }
+ return ret;
+}
+
+std::string JStringToCString(JNIEnv* env, jstring string) {
+ if (string == nullptr) {
+ return std::string();
+ }
+ const char* chars = env->GetStringUTFChars(string, nullptr);
+ std::string ret(chars);
+ env->ReleaseStringUTFChars(string, chars);
+ return ret;
+}
+
+std::vector<std::string> ToStringVector(JNIEnv* env, jobjectArray& str_array) {
+ int length = env->GetArrayLength(str_array);
+ std::vector<std::string> vector;
+ for (int i = 0; i < length; i++) {
+ auto string = reinterpret_cast<jstring>(env->GetObjectArrayElement(str_array, i));
+ vector.push_back(JStringToCString(env, string));
+ }
+ return vector;
+}
+
+arrow::Result<jbyteArray> ToSchemaByteArray(JNIEnv* env,
+ std::shared_ptr<arrow::Schema> schema) {
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr<arrow::Buffer> buffer,
+ arrow::ipc::SerializeSchema(*schema, arrow::default_memory_pool()))
+
+ jbyteArray out = env->NewByteArray(buffer->size());
+ auto src = reinterpret_cast<const jbyte*>(buffer->data());
+ env->SetByteArrayRegion(out, 0, buffer->size(), src);
+ return out;
+}
+
+arrow::Result<std::shared_ptr<arrow::Schema>> FromSchemaByteArray(
+ JNIEnv* env, jbyteArray schemaBytes) {
+ arrow::ipc::DictionaryMemo in_memo;
+ int schemaBytes_len = env->GetArrayLength(schemaBytes);
+ jbyte* schemaBytes_data = env->GetByteArrayElements(schemaBytes, nullptr);
+ auto serialized_schema = std::make_shared<arrow::Buffer>(
+ reinterpret_cast<uint8_t*>(schemaBytes_data), schemaBytes_len);
+ arrow::io::BufferReader buf_reader(serialized_schema);
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Schema> schema,
+ arrow::ipc::ReadSchema(&buf_reader, &in_memo))
+ env->ReleaseByteArrayElements(schemaBytes, schemaBytes_data, JNI_ABORT);
+ return schema;
+}
+
+} // namespace jni
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/jni/dataset/jni_util.h b/src/arrow/cpp/src/jni/dataset/jni_util.h
new file mode 100644
index 000000000..c76033ae6
--- /dev/null
+++ b/src/arrow/cpp/src/jni/dataset/jni_util.h
@@ -0,0 +1,135 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/array.h"
+#include "arrow/io/api.h"
+#include "arrow/ipc/api.h"
+#include "arrow/memory_pool.h"
+#include "arrow/result.h"
+#include "arrow/type.h"
+
+#include <jni.h>
+
+namespace arrow {
+namespace dataset {
+namespace jni {
+
+jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name);
+
+arrow::Result<jmethodID> GetMethodID(JNIEnv* env, jclass this_class, const char* name,
+ const char* sig);
+
+arrow::Result<jmethodID> GetStaticMethodID(JNIEnv* env, jclass this_class,
+ const char* name, const char* sig);
+
+std::string JStringToCString(JNIEnv* env, jstring string);
+
+std::vector<std::string> ToStringVector(JNIEnv* env, jobjectArray& str_array);
+
+arrow::Result<jbyteArray> ToSchemaByteArray(JNIEnv* env,
+ std::shared_ptr<arrow::Schema> schema);
+
+arrow::Result<std::shared_ptr<arrow::Schema>> FromSchemaByteArray(JNIEnv* env,
+ jbyteArray schemaBytes);
+
+/// \brief Create a new shared_ptr on heap from shared_ptr t to prevent
+/// the managed object from being garbage-collected.
+///
+/// \return address of the newly created shared pointer
+template <typename T>
+jlong CreateNativeRef(std::shared_ptr<T> t) {
+ std::shared_ptr<T>* retained_ptr = new std::shared_ptr<T>(t);
+ return reinterpret_cast<jlong>(retained_ptr);
+}
+
+/// \brief Get the shared_ptr that was derived via function CreateNativeRef.
+///
+/// \param[in] ref address of the shared_ptr
+/// \return the shared_ptr object
+template <typename T>
+std::shared_ptr<T> RetrieveNativeInstance(jlong ref) {
+ std::shared_ptr<T>* retrieved_ptr = reinterpret_cast<std::shared_ptr<T>*>(ref);
+ return *retrieved_ptr;
+}
+
+/// \brief Destroy a shared_ptr using its memory address.
+///
+/// \param[in] ref address of the shared_ptr
+template <typename T>
+void ReleaseNativeRef(jlong ref) {
+ std::shared_ptr<T>* retrieved_ptr = reinterpret_cast<std::shared_ptr<T>*>(ref);
+ delete retrieved_ptr;
+}
+
+/// Listener to act on reservations/unreservations from ReservationListenableMemoryPool.
+///
+/// Note the memory pool will call this listener only on block-level memory
+/// reservation/unreservation is granted. So the invocation parameter "size" is always
+/// multiple of block size (by default, 512k) specified in memory pool.
+class ReservationListener {
+ public:
+ virtual ~ReservationListener() = default;
+
+ virtual arrow::Status OnReservation(int64_t size) = 0;
+ virtual arrow::Status OnRelease(int64_t size) = 0;
+
+ protected:
+ ReservationListener() = default;
+};
+
+/// A memory pool implementation for pre-reserving memory blocks from a
+/// customizable listener. This will typically be used when memory allocations
+/// have to be subject to another "virtual" resource manager, which just tracks or
+/// limits number of bytes of application's overall memory usage. The underlying
+/// memory pool will still be responsible for actual malloc/free operations.
+class ReservationListenableMemoryPool : public arrow::MemoryPool {
+ public:
+ /// \brief Constructor.
+ ///
+ /// \param[in] pool the underlying memory pool
+ /// \param[in] listener a listener for block-level reservations/releases.
+ /// \param[in] block_size size of each block to reserve from the listener
+ explicit ReservationListenableMemoryPool(MemoryPool* pool,
+ std::shared_ptr<ReservationListener> listener,
+ int64_t block_size = 512 * 1024);
+
+ ~ReservationListenableMemoryPool();
+
+ arrow::Status Allocate(int64_t size, uint8_t** out) override;
+
+ arrow::Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) override;
+
+ void Free(uint8_t* buffer, int64_t size) override;
+
+ int64_t bytes_allocated() const override;
+
+ int64_t max_memory() const override;
+
+ std::string backend_name() const override;
+
+ std::shared_ptr<ReservationListener> get_listener();
+
+ private:
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+} // namespace jni
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/jni/dataset/jni_util_test.cc b/src/arrow/cpp/src/jni/dataset/jni_util_test.cc
new file mode 100644
index 000000000..589f00b1c
--- /dev/null
+++ b/src/arrow/cpp/src/jni/dataset/jni_util_test.cc
@@ -0,0 +1,134 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include "arrow/memory_pool.h"
+#include "arrow/testing/gtest_util.h"
+#include "jni/dataset/jni_util.h"
+
+namespace arrow {
+namespace dataset {
+namespace jni {
+
+class MyListener : public ReservationListener {
+ public:
+ Status OnReservation(int64_t size) override {
+ bytes_reserved_ += size;
+ reservation_count_++;
+ return arrow::Status::OK();
+ }
+
+ Status OnRelease(int64_t size) override {
+ bytes_reserved_ -= size;
+ release_count_++;
+ return arrow::Status::OK();
+ }
+
+ int64_t bytes_reserved() { return bytes_reserved_; }
+
+ int32_t reservation_count() const { return reservation_count_; }
+
+ int32_t release_count() const { return release_count_; }
+
+ private:
+ int64_t bytes_reserved_;
+ int32_t reservation_count_;
+ int32_t release_count_;
+};
+
+TEST(ReservationListenableMemoryPool, Basic) {
+ auto pool = MemoryPool::CreateDefault();
+ auto listener = std::make_shared<MyListener>();
+ ReservationListenableMemoryPool rlp(pool.get(), listener);
+
+ uint8_t* data;
+ ASSERT_OK(rlp.Allocate(100, &data));
+
+ uint8_t* data2;
+ ASSERT_OK(rlp.Allocate(100, &data2));
+
+ rlp.Free(data, 100);
+ rlp.Free(data2, 100);
+
+ ASSERT_EQ(200, rlp.max_memory());
+ ASSERT_EQ(200, pool->max_memory());
+}
+
+TEST(ReservationListenableMemoryPool, Listener) {
+ auto pool = MemoryPool::CreateDefault();
+ auto listener = std::make_shared<MyListener>();
+ ReservationListenableMemoryPool rlp(pool.get(), listener);
+
+ uint8_t* data;
+ ASSERT_OK(rlp.Allocate(100, &data));
+
+ uint8_t* data2;
+ ASSERT_OK(rlp.Allocate(100, &data2));
+
+ ASSERT_EQ(200, rlp.bytes_allocated());
+ ASSERT_EQ(512 * 1024, listener->bytes_reserved());
+
+ rlp.Free(data, 100);
+ rlp.Free(data2, 100);
+
+ ASSERT_EQ(0, rlp.bytes_allocated());
+ ASSERT_EQ(0, listener->bytes_reserved());
+ ASSERT_EQ(1, listener->reservation_count());
+ ASSERT_EQ(1, listener->release_count());
+}
+
+TEST(ReservationListenableMemoryPool, BlockSize) {
+ auto pool = MemoryPool::CreateDefault();
+ auto listener = std::make_shared<MyListener>();
+ ReservationListenableMemoryPool rlp(pool.get(), listener, 100);
+
+ uint8_t* data;
+ ASSERT_OK(rlp.Allocate(100, &data));
+
+ ASSERT_EQ(100, rlp.bytes_allocated());
+ ASSERT_EQ(100, listener->bytes_reserved());
+
+ rlp.Free(data, 100);
+
+ ASSERT_EQ(0, rlp.bytes_allocated());
+ ASSERT_EQ(0, listener->bytes_reserved());
+}
+
+TEST(ReservationListenableMemoryPool, BlockSize2) {
+ auto pool = MemoryPool::CreateDefault();
+ auto listener = std::make_shared<MyListener>();
+ ReservationListenableMemoryPool rlp(pool.get(), listener, 99);
+
+ uint8_t* data;
+ ASSERT_OK(rlp.Allocate(100, &data));
+
+ ASSERT_EQ(100, rlp.bytes_allocated());
+ ASSERT_EQ(198, listener->bytes_reserved());
+
+ rlp.Free(data, 100);
+
+ ASSERT_EQ(0, rlp.bytes_allocated());
+ ASSERT_EQ(0, listener->bytes_reserved());
+
+ ASSERT_EQ(1, listener->reservation_count());
+ ASSERT_EQ(1, listener->release_count());
+}
+
+} // namespace jni
+} // namespace dataset
+} // namespace arrow
diff --git a/src/arrow/cpp/src/jni/dataset/jni_wrapper.cc b/src/arrow/cpp/src/jni/dataset/jni_wrapper.cc
new file mode 100644
index 000000000..558f0880b
--- /dev/null
+++ b/src/arrow/cpp/src/jni/dataset/jni_wrapper.cc
@@ -0,0 +1,545 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <mutex>
+
+#include "arrow/array.h"
+#include "arrow/dataset/api.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/filesystem/localfs.h"
+#include "arrow/ipc/api.h"
+#include "arrow/util/iterator.h"
+
+#include "jni/dataset/jni_util.h"
+
+#include "org_apache_arrow_dataset_file_JniWrapper.h"
+#include "org_apache_arrow_dataset_jni_JniWrapper.h"
+#include "org_apache_arrow_dataset_jni_NativeMemoryPool.h"
+
+namespace {
+
+jclass illegal_access_exception_class;
+jclass illegal_argument_exception_class;
+jclass runtime_exception_class;
+
+jclass record_batch_handle_class;
+jclass record_batch_handle_field_class;
+jclass record_batch_handle_buffer_class;
+jclass java_reservation_listener_class;
+
+jmethodID record_batch_handle_constructor;
+jmethodID record_batch_handle_field_constructor;
+jmethodID record_batch_handle_buffer_constructor;
+jmethodID reserve_memory_method;
+jmethodID unreserve_memory_method;
+
+jlong default_memory_pool_id = -1L;
+
+jint JNI_VERSION = JNI_VERSION_1_6;
+
+class JniPendingException : public std::runtime_error {
+ public:
+ explicit JniPendingException(const std::string& arg) : runtime_error(arg) {}
+};
+
+void ThrowPendingException(const std::string& message) {
+ throw JniPendingException(message);
+}
+
+template <typename T>
+T JniGetOrThrow(arrow::Result<T> result) {
+ if (!result.status().ok()) {
+ ThrowPendingException(result.status().message());
+ }
+ return std::move(result).ValueOrDie();
+}
+
+void JniAssertOkOrThrow(arrow::Status status) {
+ if (!status.ok()) {
+ ThrowPendingException(status.message());
+ }
+}
+
+void JniThrow(std::string message) { ThrowPendingException(message); }
+
+arrow::Result<std::shared_ptr<arrow::dataset::FileFormat>> GetFileFormat(
+ jint file_format_id) {
+ switch (file_format_id) {
+ case 0:
+ return std::make_shared<arrow::dataset::ParquetFileFormat>();
+ default:
+ std::string error_message =
+ "illegal file format id: " + std::to_string(file_format_id);
+ return arrow::Status::Invalid(error_message);
+ }
+}
+
+class ReserveFromJava : public arrow::dataset::jni::ReservationListener {
+ public:
+ ReserveFromJava(JavaVM* vm, jobject java_reservation_listener)
+ : vm_(vm), java_reservation_listener_(java_reservation_listener) {}
+
+ arrow::Status OnReservation(int64_t size) override {
+ JNIEnv* env;
+ if (vm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
+ return arrow::Status::Invalid("JNIEnv was not attached to current thread");
+ }
+ env->CallObjectMethod(java_reservation_listener_, reserve_memory_method, size);
+ if (env->ExceptionCheck()) {
+ env->ExceptionDescribe();
+ env->ExceptionClear();
+ return arrow::Status::Invalid("Error calling Java side reservation listener");
+ }
+ return arrow::Status::OK();
+ }
+
+ arrow::Status OnRelease(int64_t size) override {
+ JNIEnv* env;
+ if (vm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
+ return arrow::Status::Invalid("JNIEnv was not attached to current thread");
+ }
+ env->CallObjectMethod(java_reservation_listener_, unreserve_memory_method, size);
+ if (env->ExceptionCheck()) {
+ env->ExceptionDescribe();
+ env->ExceptionClear();
+ return arrow::Status::Invalid("Error calling Java side reservation listener");
+ }
+ return arrow::Status::OK();
+ }
+
+ jobject GetJavaReservationListener() { return java_reservation_listener_; }
+
+ private:
+ JavaVM* vm_;
+ jobject java_reservation_listener_;
+};
+
+/// \class DisposableScannerAdaptor
+/// \brief An adaptor that iterates over a Scanner instance then returns RecordBatches
+/// directly.
+///
+/// This lessens the complexity of the JNI bridge to make sure it to be easier to
+/// maintain. On Java-side, NativeScanner can only produces a single NativeScanTask
+/// instance during its whole lifecycle. Each task stands for a DisposableScannerAdaptor
+/// instance through JNI bridge.
+///
+class DisposableScannerAdaptor {
+ public:
+ DisposableScannerAdaptor(std::shared_ptr<arrow::dataset::Scanner> scanner,
+ arrow::dataset::TaggedRecordBatchIterator batch_itr)
+ : scanner_(std::move(scanner)), batch_itr_(std::move(batch_itr)) {}
+
+ static arrow::Result<std::shared_ptr<DisposableScannerAdaptor>> Create(
+ std::shared_ptr<arrow::dataset::Scanner> scanner) {
+ ARROW_ASSIGN_OR_RAISE(auto batch_itr, scanner->ScanBatches());
+ return std::make_shared<DisposableScannerAdaptor>(scanner, std::move(batch_itr));
+ }
+
+ arrow::Result<std::shared_ptr<arrow::RecordBatch>> Next() {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::RecordBatch> batch, NextBatch());
+ return batch;
+ }
+
+ const std::shared_ptr<arrow::dataset::Scanner>& GetScanner() const { return scanner_; }
+
+ private:
+ std::shared_ptr<arrow::dataset::Scanner> scanner_;
+ arrow::dataset::TaggedRecordBatchIterator batch_itr_;
+
+ arrow::Result<std::shared_ptr<arrow::RecordBatch>> NextBatch() {
+ ARROW_ASSIGN_OR_RAISE(auto batch, batch_itr_.Next())
+ return batch.record_batch;
+ }
+};
+
+} // namespace
+
+using arrow::dataset::jni::CreateGlobalClassReference;
+using arrow::dataset::jni::CreateNativeRef;
+using arrow::dataset::jni::FromSchemaByteArray;
+using arrow::dataset::jni::GetMethodID;
+using arrow::dataset::jni::JStringToCString;
+using arrow::dataset::jni::ReleaseNativeRef;
+using arrow::dataset::jni::RetrieveNativeInstance;
+using arrow::dataset::jni::ToSchemaByteArray;
+using arrow::dataset::jni::ToStringVector;
+
+using arrow::dataset::jni::ReservationListenableMemoryPool;
+using arrow::dataset::jni::ReservationListener;
+
+#define JNI_METHOD_START try {
+// macro ended
+
+#define JNI_METHOD_END(fallback_expr) \
+ } \
+ catch (JniPendingException & e) { \
+ env->ThrowNew(runtime_exception_class, e.what()); \
+ return fallback_expr; \
+ }
+// macro ended
+
+jint JNI_OnLoad(JavaVM* vm, void* reserved) {
+ JNIEnv* env;
+ if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
+ return JNI_ERR;
+ }
+ JNI_METHOD_START
+ illegal_access_exception_class =
+ CreateGlobalClassReference(env, "Ljava/lang/IllegalAccessException;");
+ illegal_argument_exception_class =
+ CreateGlobalClassReference(env, "Ljava/lang/IllegalArgumentException;");
+ runtime_exception_class =
+ CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;");
+
+ record_batch_handle_class =
+ CreateGlobalClassReference(env,
+ "Lorg/apache/arrow/"
+ "dataset/jni/NativeRecordBatchHandle;");
+ record_batch_handle_field_class =
+ CreateGlobalClassReference(env,
+ "Lorg/apache/arrow/"
+ "dataset/jni/NativeRecordBatchHandle$Field;");
+ record_batch_handle_buffer_class =
+ CreateGlobalClassReference(env,
+ "Lorg/apache/arrow/"
+ "dataset/jni/NativeRecordBatchHandle$Buffer;");
+ java_reservation_listener_class =
+ CreateGlobalClassReference(env,
+ "Lorg/apache/arrow/"
+ "dataset/jni/ReservationListener;");
+
+ record_batch_handle_constructor =
+ JniGetOrThrow(GetMethodID(env, record_batch_handle_class, "<init>",
+ "(J[Lorg/apache/arrow/dataset/"
+ "jni/NativeRecordBatchHandle$Field;"
+ "[Lorg/apache/arrow/dataset/"
+ "jni/NativeRecordBatchHandle$Buffer;)V"));
+ record_batch_handle_field_constructor =
+ JniGetOrThrow(GetMethodID(env, record_batch_handle_field_class, "<init>", "(JJ)V"));
+ record_batch_handle_buffer_constructor = JniGetOrThrow(
+ GetMethodID(env, record_batch_handle_buffer_class, "<init>", "(JJJJ)V"));
+ reserve_memory_method =
+ JniGetOrThrow(GetMethodID(env, java_reservation_listener_class, "reserve", "(J)V"));
+ unreserve_memory_method = JniGetOrThrow(
+ GetMethodID(env, java_reservation_listener_class, "unreserve", "(J)V"));
+
+ default_memory_pool_id = reinterpret_cast<jlong>(arrow::default_memory_pool());
+
+ return JNI_VERSION;
+ JNI_METHOD_END(JNI_ERR)
+}
+
+void JNI_OnUnload(JavaVM* vm, void* reserved) {
+ JNIEnv* env;
+ vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION);
+ env->DeleteGlobalRef(illegal_access_exception_class);
+ env->DeleteGlobalRef(illegal_argument_exception_class);
+ env->DeleteGlobalRef(runtime_exception_class);
+ env->DeleteGlobalRef(record_batch_handle_class);
+ env->DeleteGlobalRef(record_batch_handle_field_class);
+ env->DeleteGlobalRef(record_batch_handle_buffer_class);
+ env->DeleteGlobalRef(java_reservation_listener_class);
+
+ default_memory_pool_id = -1L;
+}
+
+/*
+ * Class: org_apache_arrow_dataset_jni_NativeMemoryPool
+ * Method: getDefaultMemoryPool
+ * Signature: ()J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_apache_arrow_dataset_jni_NativeMemoryPool_getDefaultMemoryPool(JNIEnv* env,
+ jclass) {
+ JNI_METHOD_START
+ return default_memory_pool_id;
+ JNI_METHOD_END(-1L)
+}
+
+/*
+ * Class: org_apache_arrow_dataset_jni_NativeMemoryPool
+ * Method: createListenableMemoryPool
+ * Signature: (Lorg/apache/arrow/memory/ReservationListener;)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_apache_arrow_dataset_jni_NativeMemoryPool_createListenableMemoryPool(
+ JNIEnv* env, jclass, jobject jlistener) {
+ JNI_METHOD_START
+ jobject jlistener_ref = env->NewGlobalRef(jlistener);
+ JavaVM* vm;
+ if (env->GetJavaVM(&vm) != JNI_OK) {
+ JniThrow("Unable to get JavaVM instance");
+ }
+ std::shared_ptr<ReservationListener> listener =
+ std::make_shared<ReserveFromJava>(vm, jlistener_ref);
+ auto memory_pool =
+ new ReservationListenableMemoryPool(arrow::default_memory_pool(), listener);
+ return reinterpret_cast<jlong>(memory_pool);
+ JNI_METHOD_END(-1L)
+}
+
+/*
+ * Class: org_apache_arrow_dataset_jni_NativeMemoryPool
+ * Method: releaseMemoryPool
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL
+Java_org_apache_arrow_dataset_jni_NativeMemoryPool_releaseMemoryPool(
+ JNIEnv* env, jclass, jlong memory_pool_id) {
+ JNI_METHOD_START
+ if (memory_pool_id == default_memory_pool_id) {
+ return;
+ }
+ ReservationListenableMemoryPool* pool =
+ reinterpret_cast<ReservationListenableMemoryPool*>(memory_pool_id);
+ if (pool == nullptr) {
+ return;
+ }
+ std::shared_ptr<ReserveFromJava> rm =
+ std::dynamic_pointer_cast<ReserveFromJava>(pool->get_listener());
+ if (rm == nullptr) {
+ delete pool;
+ return;
+ }
+ delete pool;
+ env->DeleteGlobalRef(rm->GetJavaReservationListener());
+ JNI_METHOD_END()
+}
+
+/*
+ * Class: org_apache_arrow_dataset_jni_NativeMemoryPool
+ * Method: bytesAllocated
+ * Signature: (J)J
+ */
+JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_NativeMemoryPool_bytesAllocated(
+ JNIEnv* env, jclass, jlong memory_pool_id) {
+ JNI_METHOD_START
+ arrow::MemoryPool* pool = reinterpret_cast<arrow::MemoryPool*>(memory_pool_id);
+ if (pool == nullptr) {
+ JniThrow("Memory pool instance not found. It may not exist nor has been closed");
+ }
+ return pool->bytes_allocated();
+ JNI_METHOD_END(-1L)
+}
+
+/*
+ * Class: org_apache_arrow_dataset_jni_JniWrapper
+ * Method: closeDatasetFactory
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeDatasetFactory(
+ JNIEnv* env, jobject, jlong id) {
+ JNI_METHOD_START
+ ReleaseNativeRef<arrow::dataset::DatasetFactory>(id);
+ JNI_METHOD_END()
+}
+
+/*
+ * Class: org_apache_arrow_dataset_jni_JniWrapper
+ * Method: inspectSchema
+ * Signature: (J)[B
+ */
+JNIEXPORT jbyteArray JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_inspectSchema(
+ JNIEnv* env, jobject, jlong dataset_factor_id) {
+ JNI_METHOD_START
+ std::shared_ptr<arrow::dataset::DatasetFactory> d =
+ RetrieveNativeInstance<arrow::dataset::DatasetFactory>(dataset_factor_id);
+ std::shared_ptr<arrow::Schema> schema = JniGetOrThrow(d->Inspect());
+ return JniGetOrThrow(ToSchemaByteArray(env, schema));
+ JNI_METHOD_END(nullptr)
+}
+
+/*
+ * Class: org_apache_arrow_dataset_jni_JniWrapper
+ * Method: createDataset
+ * Signature: (J[B)J
+ */
+JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createDataset(
+ JNIEnv* env, jobject, jlong dataset_factory_id, jbyteArray schema_bytes) {
+ JNI_METHOD_START
+ std::shared_ptr<arrow::dataset::DatasetFactory> d =
+ RetrieveNativeInstance<arrow::dataset::DatasetFactory>(dataset_factory_id);
+ std::shared_ptr<arrow::Schema> schema;
+ schema = JniGetOrThrow(FromSchemaByteArray(env, schema_bytes));
+ std::shared_ptr<arrow::dataset::Dataset> dataset = JniGetOrThrow(d->Finish(schema));
+ return CreateNativeRef(dataset);
+ JNI_METHOD_END(-1L)
+}
+
+/*
+ * Class: org_apache_arrow_dataset_jni_JniWrapper
+ * Method: closeDataset
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeDataset(
+ JNIEnv* env, jobject, jlong id) {
+ JNI_METHOD_START
+ ReleaseNativeRef<arrow::dataset::Dataset>(id);
+ JNI_METHOD_END()
+}
+
+/*
+ * Class: org_apache_arrow_dataset_jni_JniWrapper
+ * Method: createScanner
+ * Signature: (J[Ljava/lang/String;JJ)J
+ */
+JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScanner(
+ JNIEnv* env, jobject, jlong dataset_id, jobjectArray columns, jlong batch_size,
+ jlong memory_pool_id) {
+ JNI_METHOD_START
+ arrow::MemoryPool* pool = reinterpret_cast<arrow::MemoryPool*>(memory_pool_id);
+ if (pool == nullptr) {
+ JniThrow("Memory pool does not exist or has been closed");
+ }
+ std::shared_ptr<arrow::dataset::Dataset> dataset =
+ RetrieveNativeInstance<arrow::dataset::Dataset>(dataset_id);
+ std::shared_ptr<arrow::dataset::ScannerBuilder> scanner_builder =
+ JniGetOrThrow(dataset->NewScan());
+ JniAssertOkOrThrow(scanner_builder->Pool(pool));
+ if (columns != nullptr) {
+ std::vector<std::string> column_vector = ToStringVector(env, columns);
+ JniAssertOkOrThrow(scanner_builder->Project(column_vector));
+ }
+ JniAssertOkOrThrow(scanner_builder->BatchSize(batch_size));
+
+ auto scanner = JniGetOrThrow(scanner_builder->Finish());
+ std::shared_ptr<DisposableScannerAdaptor> scanner_adaptor =
+ JniGetOrThrow(DisposableScannerAdaptor::Create(scanner));
+ jlong id = CreateNativeRef(scanner_adaptor);
+ return id;
+ JNI_METHOD_END(-1L)
+}
+
+/*
+ * Class: org_apache_arrow_dataset_jni_JniWrapper
+ * Method: closeScanner
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeScanner(
+ JNIEnv* env, jobject, jlong scanner_id) {
+ JNI_METHOD_START
+ ReleaseNativeRef<DisposableScannerAdaptor>(scanner_id);
+ JNI_METHOD_END()
+}
+
+/*
+ * Class: org_apache_arrow_dataset_jni_JniWrapper
+ * Method: getSchemaFromScanner
+ * Signature: (J)[B
+ */
+JNIEXPORT jbyteArray JNICALL
+Java_org_apache_arrow_dataset_jni_JniWrapper_getSchemaFromScanner(JNIEnv* env, jobject,
+ jlong scanner_id) {
+ JNI_METHOD_START
+ std::shared_ptr<arrow::Schema> schema =
+ RetrieveNativeInstance<DisposableScannerAdaptor>(scanner_id)
+ ->GetScanner()
+ ->options()
+ ->projected_schema;
+ return JniGetOrThrow(ToSchemaByteArray(env, schema));
+ JNI_METHOD_END(nullptr)
+}
+
+/*
+ * Class: org_apache_arrow_dataset_jni_JniWrapper
+ * Method: nextRecordBatch
+ * Signature: (J)Lorg/apache/arrow/dataset/jni/NativeRecordBatchHandle;
+ */
+JNIEXPORT jobject JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_nextRecordBatch(
+ JNIEnv* env, jobject, jlong scanner_id) {
+ JNI_METHOD_START
+ std::shared_ptr<DisposableScannerAdaptor> scanner_adaptor =
+ RetrieveNativeInstance<DisposableScannerAdaptor>(scanner_id);
+
+ std::shared_ptr<arrow::RecordBatch> record_batch =
+ JniGetOrThrow(scanner_adaptor->Next());
+ if (record_batch == nullptr) {
+ return nullptr; // stream ended
+ }
+ std::shared_ptr<arrow::Schema> schema = record_batch->schema();
+ jobjectArray field_array =
+ env->NewObjectArray(schema->num_fields(), record_batch_handle_field_class, nullptr);
+
+ std::vector<std::shared_ptr<arrow::Buffer>> buffers;
+ for (int i = 0; i < schema->num_fields(); ++i) {
+ auto column = record_batch->column(i);
+ auto dataArray = column->data();
+ jobject field = env->NewObject(record_batch_handle_field_class,
+ record_batch_handle_field_constructor,
+ column->length(), column->null_count());
+ env->SetObjectArrayElement(field_array, i, field);
+
+ for (auto& buffer : dataArray->buffers) {
+ buffers.push_back(buffer);
+ }
+ }
+
+ jobjectArray buffer_array =
+ env->NewObjectArray(buffers.size(), record_batch_handle_buffer_class, nullptr);
+
+ for (size_t j = 0; j < buffers.size(); ++j) {
+ auto buffer = buffers[j];
+ uint8_t* data = nullptr;
+ int64_t size = 0;
+ int64_t capacity = 0;
+ if (buffer != nullptr) {
+ data = (uint8_t*)buffer->data();
+ size = buffer->size();
+ capacity = buffer->capacity();
+ }
+ jobject buffer_handle = env->NewObject(record_batch_handle_buffer_class,
+ record_batch_handle_buffer_constructor,
+ CreateNativeRef(buffer), data, size, capacity);
+ env->SetObjectArrayElement(buffer_array, j, buffer_handle);
+ }
+
+ jobject ret = env->NewObject(record_batch_handle_class, record_batch_handle_constructor,
+ record_batch->num_rows(), field_array, buffer_array);
+ return ret;
+ JNI_METHOD_END(nullptr)
+}
+
+/*
+ * Class: org_apache_arrow_dataset_jni_JniWrapper
+ * Method: releaseBuffer
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_releaseBuffer(
+ JNIEnv* env, jobject, jlong id) {
+ JNI_METHOD_START
+ ReleaseNativeRef<arrow::Buffer>(id);
+ JNI_METHOD_END()
+}
+
+/*
+ * Class: org_apache_arrow_dataset_file_JniWrapper
+ * Method: makeFileSystemDatasetFactory
+ * Signature: (Ljava/lang/String;II)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory(
+ JNIEnv* env, jobject, jstring uri, jint file_format_id) {
+ JNI_METHOD_START
+ std::shared_ptr<arrow::dataset::FileFormat> file_format =
+ JniGetOrThrow(GetFileFormat(file_format_id));
+ arrow::dataset::FileSystemFactoryOptions options;
+ std::shared_ptr<arrow::dataset::DatasetFactory> d =
+ JniGetOrThrow(arrow::dataset::FileSystemDatasetFactory::Make(
+ JStringToCString(env, uri), file_format, options));
+ return CreateNativeRef(d);
+ JNI_METHOD_END(-1L)
+}
diff --git a/src/arrow/cpp/src/jni/orc/CMakeLists.txt b/src/arrow/cpp/src/jni/orc/CMakeLists.txt
new file mode 100644
index 000000000..eceda5294
--- /dev/null
+++ b/src/arrow/cpp/src/jni/orc/CMakeLists.txt
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#
+# arrow_orc_jni
+#
+
+project(arrow_orc_jni)
+
+cmake_minimum_required(VERSION 3.11)
+
+find_package(JNI REQUIRED)
+
+add_custom_target(arrow_orc_jni)
+
+set(JNI_HEADERS_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated")
+
+add_subdirectory(../../../../java/adapter/orc ./java)
+
+add_arrow_lib(arrow_orc_jni
+ BUILD_SHARED
+ ON
+ BUILD_STATIC
+ OFF
+ SOURCES
+ jni_wrapper.cpp
+ OUTPUTS
+ ARROW_ORC_JNI_LIBRARIES
+ SHARED_PRIVATE_LINK_LIBS
+ arrow_static
+ EXTRA_INCLUDES
+ ${JNI_HEADERS_DIR}
+ PRIVATE_INCLUDES
+ ${JNI_INCLUDE_DIRS}
+ DEPENDENCIES
+ arrow_static
+ arrow_orc_java)
+
+add_dependencies(arrow_orc_jni ${ARROW_ORC_JNI_LIBRARIES})
diff --git a/src/arrow/cpp/src/jni/orc/concurrent_map.h b/src/arrow/cpp/src/jni/orc/concurrent_map.h
new file mode 100644
index 000000000..b56088662
--- /dev/null
+++ b/src/arrow/cpp/src/jni/orc/concurrent_map.h
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ */
+
+#pragma once
+
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace jni {
+
+/**
+ * An utility class that map module id to module pointers.
+ * @tparam Holder class of the object to hold.
+ */
+template <typename Holder>
+class ConcurrentMap {
+ public:
+ ConcurrentMap() : module_id_(init_module_id_) {}
+
+ jlong Insert(Holder holder) {
+ std::lock_guard<std::mutex> lock(mtx_);
+ jlong result = module_id_++;
+ map_.insert(std::pair<jlong, Holder>(result, holder));
+ return result;
+ }
+
+ void Erase(jlong module_id) {
+ std::lock_guard<std::mutex> lock(mtx_);
+ map_.erase(module_id);
+ }
+
+ Holder Lookup(jlong module_id) {
+ std::lock_guard<std::mutex> lock(mtx_);
+ auto it = map_.find(module_id);
+ if (it != map_.end()) {
+ return it->second;
+ }
+ return NULLPTR;
+ }
+
+ void Clear() {
+ std::lock_guard<std::mutex> lock(mtx_);
+ map_.clear();
+ }
+
+ private:
+ // Initialize the module id starting value to a number greater than zero
+ // to allow for easier debugging of uninitialized java variables.
+ static constexpr int init_module_id_ = 4;
+
+ int64_t module_id_;
+ std::mutex mtx_;
+ // map from module ids returned to Java and module pointers
+ std::unordered_map<jlong, Holder> map_;
+};
+
+} // namespace jni
+} // namespace arrow
diff --git a/src/arrow/cpp/src/jni/orc/jni_wrapper.cpp b/src/arrow/cpp/src/jni/orc/jni_wrapper.cpp
new file mode 100644
index 000000000..cc629c9c4
--- /dev/null
+++ b/src/arrow/cpp/src/jni/orc/jni_wrapper.cpp
@@ -0,0 +1,306 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cassert>
+#include <iostream>
+#include <string>
+
+#include <arrow/adapters/orc/adapter.h>
+#include <arrow/array.h>
+#include <arrow/buffer.h>
+#include <arrow/io/api.h>
+#include <arrow/ipc/api.h>
+#include <arrow/util/checked_cast.h>
+#include <arrow/util/logging.h>
+
+#include "org_apache_arrow_adapter_orc_OrcMemoryJniWrapper.h"
+#include "org_apache_arrow_adapter_orc_OrcReaderJniWrapper.h"
+#include "org_apache_arrow_adapter_orc_OrcStripeReaderJniWrapper.h"
+
+#include "./concurrent_map.h"
+
+using ORCFileReader = arrow::adapters::orc::ORCFileReader;
+using RecordBatchReader = arrow::RecordBatchReader;
+
+static jclass io_exception_class;
+static jclass illegal_access_exception_class;
+static jclass illegal_argument_exception_class;
+
+static jclass orc_field_node_class;
+static jmethodID orc_field_node_constructor;
+
+static jclass orc_memory_class;
+static jmethodID orc_memory_constructor;
+
+static jclass record_batch_class;
+static jmethodID record_batch_constructor;
+
+static jint JNI_VERSION = JNI_VERSION_1_6;
+
+using arrow::internal::checked_cast;
+using arrow::jni::ConcurrentMap;
+
+static ConcurrentMap<std::shared_ptr<arrow::Buffer>> buffer_holder_;
+static ConcurrentMap<std::shared_ptr<RecordBatchReader>> orc_stripe_reader_holder_;
+static ConcurrentMap<std::shared_ptr<ORCFileReader>> orc_reader_holder_;
+
+jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) {
+ jclass local_class = env->FindClass(class_name);
+ jclass global_class = (jclass)env->NewGlobalRef(local_class);
+ env->DeleteLocalRef(local_class);
+ return global_class;
+}
+
+jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const char* sig) {
+ jmethodID ret = env->GetMethodID(this_class, name, sig);
+ if (ret == nullptr) {
+ std::string error_message = "Unable to find method " + std::string(name) +
+ " within signature" + std::string(sig);
+ env->ThrowNew(illegal_access_exception_class, error_message.c_str());
+ }
+
+ return ret;
+}
+
+std::string JStringToCString(JNIEnv* env, jstring string) {
+ int32_t jlen, clen;
+ clen = env->GetStringUTFLength(string);
+ jlen = env->GetStringLength(string);
+ std::vector<char> buffer(clen);
+ env->GetStringUTFRegion(string, 0, jlen, buffer.data());
+ return std::string(buffer.data(), clen);
+}
+
+std::shared_ptr<ORCFileReader> GetFileReader(JNIEnv* env, jlong id) {
+ auto reader = orc_reader_holder_.Lookup(id);
+ if (!reader) {
+ std::string error_message = "invalid reader id " + std::to_string(id);
+ env->ThrowNew(illegal_argument_exception_class, error_message.c_str());
+ }
+
+ return reader;
+}
+
+std::shared_ptr<RecordBatchReader> GetStripeReader(JNIEnv* env, jlong id) {
+ auto reader = orc_stripe_reader_holder_.Lookup(id);
+ if (!reader) {
+ std::string error_message = "invalid stripe reader id " + std::to_string(id);
+ env->ThrowNew(illegal_argument_exception_class, error_message.c_str());
+ }
+
+ return reader;
+}
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+jint JNI_OnLoad(JavaVM* vm, void* reserved) {
+ JNIEnv* env;
+ if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
+ return JNI_ERR;
+ }
+
+ io_exception_class = CreateGlobalClassReference(env, "Ljava/io/IOException;");
+ illegal_access_exception_class =
+ CreateGlobalClassReference(env, "Ljava/lang/IllegalAccessException;");
+ illegal_argument_exception_class =
+ CreateGlobalClassReference(env, "Ljava/lang/IllegalArgumentException;");
+
+ orc_field_node_class =
+ CreateGlobalClassReference(env, "Lorg/apache/arrow/adapter/orc/OrcFieldNode;");
+ orc_field_node_constructor = GetMethodID(env, orc_field_node_class, "<init>", "(II)V");
+
+ orc_memory_class = CreateGlobalClassReference(
+ env, "Lorg/apache/arrow/adapter/orc/OrcMemoryJniWrapper;");
+ orc_memory_constructor = GetMethodID(env, orc_memory_class, "<init>", "(JJJJ)V");
+
+ record_batch_class =
+ CreateGlobalClassReference(env, "Lorg/apache/arrow/adapter/orc/OrcRecordBatch;");
+ record_batch_constructor =
+ GetMethodID(env, record_batch_class, "<init>",
+ "(I[Lorg/apache/arrow/adapter/orc/OrcFieldNode;"
+ "[Lorg/apache/arrow/adapter/orc/OrcMemoryJniWrapper;)V");
+
+ env->ExceptionDescribe();
+
+ return JNI_VERSION;
+}
+
+void JNI_OnUnload(JavaVM* vm, void* reserved) {
+ JNIEnv* env;
+ vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION);
+ env->DeleteGlobalRef(io_exception_class);
+ env->DeleteGlobalRef(illegal_access_exception_class);
+ env->DeleteGlobalRef(illegal_argument_exception_class);
+ env->DeleteGlobalRef(orc_field_node_class);
+ env->DeleteGlobalRef(orc_memory_class);
+ env->DeleteGlobalRef(record_batch_class);
+
+ buffer_holder_.Clear();
+ orc_stripe_reader_holder_.Clear();
+ orc_reader_holder_.Clear();
+}
+
+JNIEXPORT jlong JNICALL Java_org_apache_arrow_adapter_orc_OrcReaderJniWrapper_open(
+ JNIEnv* env, jobject this_obj, jstring file_path) {
+ std::string path = JStringToCString(env, file_path);
+
+ if (path.find("hdfs://") == 0) {
+ env->ThrowNew(io_exception_class, "hdfs path not supported yet.");
+ }
+ auto maybe_file = arrow::io::ReadableFile::Open(path);
+
+ if (!maybe_file.ok()) {
+ return -static_cast<jlong>(maybe_file.status().code());
+ }
+ auto maybe_reader = ORCFileReader::Open(*maybe_file, arrow::default_memory_pool());
+ if (!maybe_reader.ok()) {
+ env->ThrowNew(io_exception_class, std::string("Failed open file" + path).c_str());
+ }
+ return orc_reader_holder_.Insert(
+ std::shared_ptr<ORCFileReader>(*std::move(maybe_reader)));
+}
+
+JNIEXPORT void JNICALL Java_org_apache_arrow_adapter_orc_OrcReaderJniWrapper_close(
+ JNIEnv* env, jobject this_obj, jlong id) {
+ orc_reader_holder_.Erase(id);
+}
+
+JNIEXPORT jboolean JNICALL Java_org_apache_arrow_adapter_orc_OrcReaderJniWrapper_seek(
+ JNIEnv* env, jobject this_obj, jlong id, jint row_number) {
+ auto reader = GetFileReader(env, id);
+ return reader->Seek(row_number).ok();
+}
+
+JNIEXPORT jint JNICALL
+Java_org_apache_arrow_adapter_orc_OrcReaderJniWrapper_getNumberOfStripes(JNIEnv* env,
+ jobject this_obj,
+ jlong id) {
+ auto reader = GetFileReader(env, id);
+ return reader->NumberOfStripes();
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_apache_arrow_adapter_orc_OrcReaderJniWrapper_nextStripeReader(JNIEnv* env,
+ jobject this_obj,
+ jlong id,
+ jlong batch_size) {
+ auto reader = GetFileReader(env, id);
+
+ auto maybe_stripe_reader = reader->NextStripeReader(batch_size);
+ if (!maybe_stripe_reader.ok()) {
+ return static_cast<jlong>(maybe_stripe_reader.status().code()) * -1;
+ }
+ if (*maybe_stripe_reader == nullptr) {
+ return static_cast<jlong>(arrow::StatusCode::Invalid) * -1;
+ }
+
+ return orc_stripe_reader_holder_.Insert(*maybe_stripe_reader);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_org_apache_arrow_adapter_orc_OrcStripeReaderJniWrapper_getSchema(JNIEnv* env,
+ jclass this_cls,
+ jlong id) {
+ auto stripe_reader = GetStripeReader(env, id);
+
+ auto schema = stripe_reader->schema();
+
+ auto maybe_buffer = arrow::ipc::SerializeSchema(*schema, arrow::default_memory_pool());
+ if (!maybe_buffer.ok()) {
+ return nullptr;
+ }
+ auto buffer = *std::move(maybe_buffer);
+
+ jbyteArray ret = env->NewByteArray(buffer->size());
+ auto src = reinterpret_cast<const jbyte*>(buffer->data());
+ env->SetByteArrayRegion(ret, 0, buffer->size(), src);
+ return ret;
+}
+
+JNIEXPORT jobject JNICALL
+Java_org_apache_arrow_adapter_orc_OrcStripeReaderJniWrapper_next(JNIEnv* env,
+ jclass this_cls,
+ jlong id) {
+ auto stripe_reader = GetStripeReader(env, id);
+
+ std::shared_ptr<arrow::RecordBatch> record_batch;
+ auto status = stripe_reader->ReadNext(&record_batch);
+ if (!status.ok() || !record_batch) {
+ return nullptr;
+ }
+
+ auto schema = stripe_reader->schema();
+
+ // TODO: ARROW-4714 Ensure JVM has sufficient capacity to create local references
+ // create OrcFieldNode[]
+ jobjectArray field_array =
+ env->NewObjectArray(schema->num_fields(), orc_field_node_class, nullptr);
+
+ std::vector<std::shared_ptr<arrow::Buffer>> buffers;
+ for (int i = 0; i < schema->num_fields(); ++i) {
+ auto column = record_batch->column(i);
+ auto dataArray = column->data();
+ jobject field = env->NewObject(orc_field_node_class, orc_field_node_constructor,
+ column->length(), column->null_count());
+ env->SetObjectArrayElement(field_array, i, field);
+
+ for (auto& buffer : dataArray->buffers) {
+ buffers.push_back(buffer);
+ }
+ }
+
+ // create OrcMemoryJniWrapper[]
+ jobjectArray memory_array =
+ env->NewObjectArray(buffers.size(), orc_memory_class, nullptr);
+
+ for (size_t j = 0; j < buffers.size(); ++j) {
+ auto buffer = buffers[j];
+ uint8_t* data = nullptr;
+ int size = 0;
+ int64_t capacity = 0;
+ if (buffer != nullptr) {
+ data = (uint8_t*)buffer->data();
+ size = (int)buffer->size();
+ capacity = buffer->capacity();
+ }
+ jobject memory = env->NewObject(orc_memory_class, orc_memory_constructor,
+ buffer_holder_.Insert(buffer), data, size, capacity);
+ env->SetObjectArrayElement(memory_array, j, memory);
+ }
+
+ // create OrcRecordBatch
+ jobject ret = env->NewObject(record_batch_class, record_batch_constructor,
+ record_batch->num_rows(), field_array, memory_array);
+
+ return ret;
+}
+
+JNIEXPORT void JNICALL Java_org_apache_arrow_adapter_orc_OrcStripeReaderJniWrapper_close(
+ JNIEnv* env, jclass this_cls, jlong id) {
+ orc_stripe_reader_holder_.Erase(id);
+}
+
+JNIEXPORT void JNICALL Java_org_apache_arrow_adapter_orc_OrcMemoryJniWrapper_release(
+ JNIEnv* env, jobject this_obj, jlong id) {
+ buffer_holder_.Erase(id);
+}
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/src/arrow/cpp/src/parquet/CMakeLists.txt b/src/arrow/cpp/src/parquet/CMakeLists.txt
new file mode 100644
index 000000000..cbf5882f9
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/CMakeLists.txt
@@ -0,0 +1,414 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+add_custom_target(parquet-all)
+add_custom_target(parquet)
+add_custom_target(parquet-benchmarks)
+add_custom_target(parquet-tests)
+add_dependencies(parquet-all parquet parquet-tests parquet-benchmarks)
+
+function(ADD_PARQUET_TEST REL_TEST_NAME)
+ set(one_value_args)
+ set(multi_value_args EXTRA_DEPENDENCIES LABELS)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+
+ set(TEST_ARGUMENTS PREFIX "parquet" LABELS "parquet-tests")
+
+ if(ARROW_TEST_LINKAGE STREQUAL "static")
+ add_test_case(${REL_TEST_NAME}
+ STATIC_LINK_LIBS
+ ${PARQUET_STATIC_TEST_LINK_LIBS}
+ ${TEST_ARGUMENTS}
+ ${ARG_UNPARSED_ARGUMENTS})
+ else()
+ add_test_case(${REL_TEST_NAME}
+ STATIC_LINK_LIBS
+ ${PARQUET_SHARED_TEST_LINK_LIBS}
+ ${TEST_ARGUMENTS}
+ ${ARG_UNPARSED_ARGUMENTS})
+ endif()
+endfunction()
+
+function(ADD_PARQUET_FUZZ_TARGET REL_FUZZING_NAME)
+ set(options)
+ set(one_value_args PREFIX)
+ set(multi_value_args)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+
+ if(ARG_PREFIX)
+ set(PREFIX ${ARG_PREFIX})
+ else()
+ set(PREFIX "parquet")
+ endif()
+
+ if(ARROW_BUILD_STATIC)
+ set(LINK_LIBS parquet_static)
+ else()
+ set(LINK_LIBS parquet_shared)
+ endif()
+ add_fuzz_target(${REL_FUZZING_NAME}
+ PREFIX
+ ${PREFIX}
+ LINK_LIBS
+ ${LINK_LIBS}
+ ${ARG_UNPARSED_ARGUMENTS})
+endfunction()
+
+function(ADD_PARQUET_BENCHMARK REL_TEST_NAME)
+ set(options)
+ set(one_value_args PREFIX)
+ set(multi_value_args)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+ if(ARG_PREFIX)
+ set(PREFIX ${ARG_PREFIX})
+ else()
+ set(PREFIX "parquet")
+ endif()
+ add_benchmark(${REL_TEST_NAME}
+ PREFIX
+ ${PREFIX}
+ LABELS
+ "parquet-benchmarks"
+ ${PARQUET_BENCHMARK_LINK_OPTION}
+ ${ARG_UNPARSED_ARGUMENTS})
+endfunction()
+
+# ----------------------------------------------------------------------
+# Link libraries setup
+
+# TODO(wesm): Handling of ABI/SO version
+
+if(ARROW_BUILD_STATIC)
+ set(PARQUET_STATIC_LINK_LIBS arrow_static)
+ set(ARROW_LIBRARIES_FOR_STATIC_TESTS arrow_testing_static arrow_static)
+else()
+ set(ARROW_LIBRARIES_FOR_STATIC_TESTS arrow_testing_shared arrow_shared)
+endif()
+
+set(PARQUET_MIN_TEST_LIBS GTest::gtest_main GTest::gtest)
+
+if(APPLE)
+ set(PARQUET_MIN_TEST_LIBS ${PARQUET_MIN_TEST_LIBS} ${CMAKE_DL_LIBS})
+elseif(NOT MSVC)
+ set(PARQUET_MIN_TEST_LIBS ${PARQUET_MIN_TEST_LIBS} pthread ${CMAKE_DL_LIBS})
+endif()
+
+set(PARQUET_SHARED_TEST_LINK_LIBS arrow_testing_shared ${PARQUET_MIN_TEST_LIBS}
+ parquet_shared thrift::thrift)
+
+set(PARQUET_STATIC_TEST_LINK_LIBS ${PARQUET_MIN_TEST_LIBS} parquet_static thrift::thrift
+ ${ARROW_LIBRARIES_FOR_STATIC_TESTS})
+
+#
+# Generated Thrift sources
+set_source_files_properties(src/generated/parquet_types.cpp src/generated/parquet_types.h
+ src/generated/parquet_constants.cpp
+ src/generated/parquet_constants.h
+ PROPERTIES SKIP_PRECOMPILE_HEADERS ON
+ SKIP_UNITY_BUILD_INCLUSION ON)
+
+if(NOT MSVC)
+ set_source_files_properties(src/parquet/parquet_types.cpp
+ PROPERTIES COMPILE_FLAGS -Wno-unused-variable)
+endif()
+
+#
+# Library config
+
+set(PARQUET_SRCS
+ arrow/path_internal.cc
+ arrow/reader.cc
+ arrow/reader_internal.cc
+ arrow/schema.cc
+ arrow/schema_internal.cc
+ arrow/writer.cc
+ bloom_filter.cc
+ column_reader.cc
+ column_scanner.cc
+ column_writer.cc
+ encoding.cc
+ encryption/encryption.cc
+ encryption/internal_file_decryptor.cc
+ encryption/internal_file_encryptor.cc
+ exception.cc
+ file_reader.cc
+ file_writer.cc
+ level_comparison.cc
+ level_conversion.cc
+ metadata.cc
+ murmur3.cc
+ "${ARROW_SOURCE_DIR}/src/generated/parquet_constants.cpp"
+ "${ARROW_SOURCE_DIR}/src/generated/parquet_types.cpp"
+ platform.cc
+ printer.cc
+ properties.cc
+ schema.cc
+ statistics.cc
+ stream_reader.cc
+ stream_writer.cc
+ types.cc)
+
+if(ARROW_HAVE_RUNTIME_AVX2)
+ # AVX2 is used as a proxy for BMI2.
+ list(APPEND PARQUET_SRCS level_comparison_avx2.cc level_conversion_bmi2.cc)
+ set_source_files_properties(level_comparison_avx2.cc
+ PROPERTIES SKIP_PRECOMPILE_HEADERS ON COMPILE_FLAGS
+ "${ARROW_AVX2_FLAG}")
+ # WARNING: DO NOT BLINDLY COPY THIS CODE FOR OTHER BMI2 USE CASES.
+ # This code is always guarded by runtime dispatch which verifies
+ # BMI2 is present. For a very small number of CPUs AVX2 does not
+ # imply BMI2.
+ set_source_files_properties(level_conversion_bmi2.cc
+ PROPERTIES SKIP_PRECOMPILE_HEADERS ON
+ COMPILE_FLAGS
+ "${ARROW_AVX2_FLAG} -DARROW_HAVE_BMI2 -mbmi2")
+endif()
+
+if(PARQUET_REQUIRE_ENCRYPTION)
+ set(PARQUET_SRCS ${PARQUET_SRCS} encryption/encryption_internal.cc)
+ # Encryption key management
+ set(PARQUET_SRCS
+ ${PARQUET_SRCS}
+ encryption/crypto_factory.cc
+ encryption/file_key_unwrapper.cc
+ encryption/file_key_wrapper.cc
+ encryption/kms_client.cc
+ encryption/key_material.cc
+ encryption/key_metadata.cc
+ encryption/key_toolkit.cc
+ encryption/key_toolkit_internal.cc
+ encryption/local_wrap_kms_client.cc)
+else()
+ set(PARQUET_SRCS ${PARQUET_SRCS} encryption/encryption_internal_nossl.cc)
+endif()
+
+if(NOT PARQUET_MINIMAL_DEPENDENCY)
+ set(PARQUET_SHARED_LINK_LIBS arrow_shared)
+
+ # These are libraries that we will link privately with parquet_shared (as they
+ # do not need to be linked transitively by other linkers)
+ set(PARQUET_SHARED_PRIVATE_LINK_LIBS thrift::thrift)
+
+ # Link publicly with parquet_static (because internal users need to
+ # transitively link all dependencies)
+ set(PARQUET_STATIC_LINK_LIBS ${PARQUET_STATIC_LINK_LIBS} thrift::thrift)
+
+ # Although we don't link parquet_objlib against anything, we need it to depend
+ # on these libs as we may generate their headers via ExternalProject_Add
+ if(ARROW_BUILD_SHARED)
+ set(PARQUET_DEPENDENCIES ${PARQUET_DEPENDENCIES} ${PARQUET_SHARED_LINK_LIBS}
+ ${PARQUET_SHARED_PRIVATE_LINK_LIBS})
+ endif()
+
+ if(ARROW_BUILD_STATIC)
+ set(PARQUET_DEPENDENCIES ${PARQUET_DEPENDENCIES} ${PARQUET_STATIC_LINK_LIBS})
+ endif()
+
+endif(NOT PARQUET_MINIMAL_DEPENDENCY)
+
+if(CXX_LINKER_SUPPORTS_VERSION_SCRIPT)
+ set(PARQUET_SHARED_LINK_FLAGS
+ "-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/symbols.map")
+endif()
+
+add_arrow_lib(parquet
+ CMAKE_PACKAGE_NAME
+ Parquet
+ PKG_CONFIG_NAME
+ parquet
+ SOURCES
+ ${PARQUET_SRCS}
+ PRECOMPILED_HEADERS
+ "$<$<COMPILE_LANGUAGE:CXX>:parquet/pch.h>"
+ OUTPUTS
+ PARQUET_LIBRARIES
+ DEPENDENCIES
+ ${PARQUET_DEPENDENCIES}
+ SHARED_LINK_FLAGS
+ ${PARQUET_SHARED_LINK_FLAGS}
+ SHARED_LINK_LIBS
+ ${PARQUET_SHARED_LINK_LIBS}
+ SHARED_PRIVATE_LINK_LIBS
+ ${PARQUET_SHARED_PRIVATE_LINK_LIBS}
+ STATIC_LINK_LIBS
+ ${PARQUET_STATIC_LINK_LIBS})
+
+if(WIN32 AND NOT (ARROW_TEST_LINKAGE STREQUAL "static"))
+ add_library(parquet_test_support STATIC
+ "${ARROW_SOURCE_DIR}/src/generated/parquet_constants.cpp"
+ "${ARROW_SOURCE_DIR}/src/generated/parquet_types.cpp")
+ add_dependencies(parquet_test_support thrift::thrift)
+ set(PARQUET_SHARED_TEST_LINK_LIBS ${PARQUET_SHARED_TEST_LINK_LIBS} parquet_test_support)
+ set(PARQUET_LIBRARIES ${PARQUET_LIBRARIES} parquet_test_support)
+endif()
+
+if(NOT ARROW_BUILD_SHARED)
+ set(PARQUET_BENCHMARK_LINK_OPTION STATIC_LINK_LIBS benchmark::benchmark_main
+ ${PARQUET_STATIC_TEST_LINK_LIBS})
+else()
+ set(PARQUET_BENCHMARK_LINK_OPTION EXTRA_LINK_LIBS ${PARQUET_SHARED_TEST_LINK_LIBS})
+endif()
+
+if(ARROW_BUILD_STATIC AND WIN32)
+ # ARROW-4848: Static Parquet lib needs to import static symbols on Windows
+ target_compile_definitions(parquet_static PUBLIC ARROW_STATIC)
+endif()
+
+add_dependencies(parquet ${PARQUET_LIBRARIES} thrift::thrift)
+
+add_definitions(-DPARQUET_THRIFT_VERSION_MAJOR=${THRIFT_VERSION_MAJOR})
+add_definitions(-DPARQUET_THRIFT_VERSION_MINOR=${THRIFT_VERSION_MINOR})
+
+# Thrift requires these definitions for some types that we use
+foreach(LIB_TARGET ${PARQUET_LIBRARIES})
+ target_compile_definitions(${LIB_TARGET}
+ PRIVATE PARQUET_EXPORTING
+ PRIVATE HAVE_INTTYPES_H
+ PRIVATE HAVE_NETDB_H)
+ if(WIN32)
+ target_compile_definitions(${LIB_TARGET} PRIVATE NOMINMAX)
+ else()
+ target_compile_definitions(${LIB_TARGET} PRIVATE HAVE_NETINET_IN_H)
+ endif()
+endforeach()
+
+if(WIN32 AND ARROW_BUILD_STATIC)
+ target_compile_definitions(parquet_static PUBLIC PARQUET_STATIC)
+endif()
+
+add_subdirectory(api)
+add_subdirectory(arrow)
+add_subdirectory(encryption)
+
+arrow_install_all_headers("parquet")
+
+configure_file(parquet_version.h.in "${CMAKE_CURRENT_BINARY_DIR}/parquet_version.h" @ONLY)
+
+install(FILES "${CMAKE_CURRENT_BINARY_DIR}/parquet_version.h"
+ DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/parquet")
+
+add_parquet_test(internals-test
+ SOURCES
+ bloom_filter_test.cc
+ properties_test.cc
+ statistics_test.cc
+ encoding_test.cc
+ metadata_test.cc
+ public_api_test.cc
+ types_test.cc
+ test_util.cc)
+
+set_source_files_properties(public_api_test.cc PROPERTIES SKIP_PRECOMPILE_HEADERS ON
+ SKIP_UNITY_BUILD_INCLUSION ON)
+
+add_parquet_test(reader-test
+ SOURCES
+ column_reader_test.cc
+ level_conversion_test.cc
+ column_scanner_test.cc
+ reader_test.cc
+ stream_reader_test.cc
+ test_util.cc)
+
+add_parquet_test(writer-test
+ SOURCES
+ column_writer_test.cc
+ file_serialize_test.cc
+ stream_writer_test.cc
+ test_util.cc)
+
+add_parquet_test(arrow-test
+ SOURCES
+ arrow/arrow_reader_writer_test.cc
+ arrow/arrow_schema_test.cc
+ arrow/arrow_statistics_test.cc
+ test_util.cc)
+
+add_parquet_test(arrow-internals-test
+ SOURCES
+ arrow/path_internal_test.cc
+ arrow/reconstruct_internal_test.cc
+ test_util.cc)
+
+if(PARQUET_REQUIRE_ENCRYPTION)
+ add_parquet_test(encryption-test
+ SOURCES
+ encryption/write_configurations_test.cc
+ encryption/read_configurations_test.cc
+ encryption/properties_test.cc
+ encryption/test_encryption_util.cc
+ test_util.cc)
+ add_parquet_test(encryption-key-management-test
+ SOURCES
+ encryption/key_management_test.cc
+ encryption/key_metadata_test.cc
+ encryption/key_wrapping_test.cc
+ encryption/test_encryption_util.cc
+ encryption/test_in_memory_kms.cc
+ encryption/two_level_cache_with_expiration_test.cc
+ test_util.cc)
+endif()
+
+# Those tests need to use static linking as they access thrift-generated
+# symbols which are not exported by parquet.dll on Windows (PARQUET-1420).
+add_parquet_test(file_deserialize_test SOURCES file_deserialize_test.cc test_util.cc)
+add_parquet_test(schema_test)
+
+add_parquet_benchmark(column_io_benchmark)
+add_parquet_benchmark(encoding_benchmark)
+add_parquet_benchmark(level_conversion_benchmark)
+add_parquet_benchmark(arrow/reader_writer_benchmark PREFIX "parquet-arrow")
+
+if(ARROW_WITH_BROTLI)
+ add_definitions(-DARROW_WITH_BROTLI)
+endif()
+
+if(ARROW_WITH_BZ2)
+ add_definitions(-DARROW_WITH_BZ2)
+endif()
+
+if(ARROW_WITH_LZ4)
+ add_definitions(-DARROW_WITH_LZ4)
+endif()
+
+if(ARROW_WITH_SNAPPY)
+ add_definitions(-DARROW_WITH_SNAPPY)
+endif()
+
+if(ARROW_WITH_ZLIB)
+ add_definitions(-DARROW_WITH_ZLIB)
+endif()
+
+if(ARROW_WITH_ZSTD)
+ add_definitions(-DARROW_WITH_ZSTD)
+endif()
+
+if(ARROW_CSV)
+ add_definitions(-DARROW_CSV)
+endif()
diff --git a/src/arrow/cpp/src/parquet/ParquetConfig.cmake.in b/src/arrow/cpp/src/parquet/ParquetConfig.cmake.in
new file mode 100644
index 000000000..afdecc517
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/ParquetConfig.cmake.in
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This config sets the following variables in your project::
+#
+# Parquet_FOUND - true if Parquet found on the system
+# PARQUET_FULL_SO_VERSION - full shared library version of the found Parquet
+# PARQUET_SO_VERSION - shared library version of the found Parquet
+# PARQUET_VERSION - version of the found Parquet
+#
+# This config sets the following targets in your project::
+#
+# parquet_shared - for linked as shared library if shared library is built
+# parquet_static - for linked as static library if static library is built
+
+@PACKAGE_INIT@
+
+include(CMakeFindDependencyMacro)
+find_dependency(Arrow)
+
+set(PARQUET_VERSION "@ARROW_VERSION@")
+set(PARQUET_SO_VERSION "@ARROW_SO_VERSION@")
+set(PARQUET_FULL_SO_VERSION "@ARROW_FULL_SO_VERSION@")
+
+# Load targets only once. If we load targets multiple times, CMake reports
+# already existent target error.
+if(NOT (TARGET parquet_shared OR TARGET parquet_static))
+ include("${CMAKE_CURRENT_LIST_DIR}/ParquetTargets.cmake")
+endif()
diff --git a/src/arrow/cpp/src/parquet/README b/src/arrow/cpp/src/parquet/README
new file mode 100644
index 000000000..fc16a46ca
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/README
@@ -0,0 +1,10 @@
+The CompatibilityTest of bloom_filter-test.cc is used to test cross compatibility of
+Bloom filters between parquet-mr and parquet-cpp. It reads the Bloom filter binary
+generated by the Bloom filter class in the parquet-mr project and tests whether the
+values inserted before could be filtered or not.
+
+The Bloom filter binary is generated by three steps from Parquet-mr:
+Step 1: Construct a Bloom filter with 1024 bytes of bitset.
+Step 2: Insert hashes of "hello", "parquet", "bloom", "filter" strings to Bloom filter
+by calling hash and insert APIs.
+Step 3: Call writeTo API to write to File.
diff --git a/src/arrow/cpp/src/parquet/api/CMakeLists.txt b/src/arrow/cpp/src/parquet/api/CMakeLists.txt
new file mode 100644
index 000000000..d44d04934
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/api/CMakeLists.txt
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Headers: public api
+arrow_install_all_headers("parquet/api")
diff --git a/src/arrow/cpp/src/parquet/api/io.h b/src/arrow/cpp/src/parquet/api/io.h
new file mode 100644
index 000000000..28a00f12a
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/api/io.h
@@ -0,0 +1,20 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "parquet/exception.h"
diff --git a/src/arrow/cpp/src/parquet/api/reader.h b/src/arrow/cpp/src/parquet/api/reader.h
new file mode 100644
index 000000000..7e746e8c5
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/api/reader.h
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+// Column reader API
+#include "parquet/column_reader.h"
+#include "parquet/column_scanner.h"
+#include "parquet/exception.h"
+#include "parquet/file_reader.h"
+#include "parquet/metadata.h"
+#include "parquet/platform.h"
+#include "parquet/printer.h"
+#include "parquet/properties.h"
+#include "parquet/statistics.h"
+
+// Schemas
+#include "parquet/api/schema.h"
+
+// IO
+#include "parquet/api/io.h"
diff --git a/src/arrow/cpp/src/parquet/api/schema.h b/src/arrow/cpp/src/parquet/api/schema.h
new file mode 100644
index 000000000..7ca714f47
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/api/schema.h
@@ -0,0 +1,21 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+// Schemas
+#include "parquet/schema.h"
diff --git a/src/arrow/cpp/src/parquet/api/writer.h b/src/arrow/cpp/src/parquet/api/writer.h
new file mode 100644
index 000000000..b072dcf74
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/api/writer.h
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "parquet/api/io.h"
+#include "parquet/api/schema.h"
+#include "parquet/column_writer.h"
+#include "parquet/exception.h"
+#include "parquet/file_writer.h"
+#include "parquet/statistics.h"
diff --git a/src/arrow/cpp/src/parquet/arrow/CMakeLists.txt b/src/arrow/cpp/src/parquet/arrow/CMakeLists.txt
new file mode 100644
index 000000000..ac708a0e4
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/CMakeLists.txt
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+arrow_install_all_headers("parquet/arrow")
+
+if(ARROW_FUZZING)
+ add_executable(parquet-arrow-generate-fuzz-corpus generate_fuzz_corpus.cc)
+ if(ARROW_BUILD_STATIC)
+ target_link_libraries(parquet-arrow-generate-fuzz-corpus parquet_static
+ arrow_testing_static)
+ else()
+ target_link_libraries(parquet-arrow-generate-fuzz-corpus parquet_shared
+ arrow_testing_shared)
+ endif()
+endif()
+
+add_parquet_fuzz_target(fuzz PREFIX "parquet-arrow")
diff --git a/src/arrow/cpp/src/parquet/arrow/arrow_reader_writer_test.cc b/src/arrow/cpp/src/parquet/arrow/arrow_reader_writer_test.cc
new file mode 100644
index 000000000..fa8b15d2b
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/arrow_reader_writer_test.cc
@@ -0,0 +1,4343 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifdef _MSC_VER
+#pragma warning(push)
+// Disable forcing value to bool warnings
+#pragma warning(disable : 4800)
+#endif
+
+#include "gtest/gtest.h"
+
+#include <cstdint>
+#include <functional>
+#include <iostream>
+#include <sstream>
+#include <vector>
+
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_decimal.h"
+#include "arrow/array/builder_dict.h"
+#include "arrow/array/builder_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/chunked_array.h"
+#include "arrow/compute/api.h"
+#include "arrow/io/api.h"
+#include "arrow/record_batch.h"
+#include "arrow/scalar.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/range.h"
+
+#ifdef ARROW_CSV
+#include "arrow/csv/api.h"
+#endif
+
+#include "parquet/api/reader.h"
+#include "parquet/api/writer.h"
+
+#include "parquet/arrow/reader.h"
+#include "parquet/arrow/reader_internal.h"
+#include "parquet/arrow/schema.h"
+#include "parquet/arrow/test_util.h"
+#include "parquet/arrow/writer.h"
+#include "parquet/column_writer.h"
+#include "parquet/file_writer.h"
+#include "parquet/test_util.h"
+
+using arrow::Array;
+using arrow::ArrayData;
+using arrow::ArrayFromJSON;
+using arrow::ArrayVector;
+using arrow::ArrayVisitor;
+using arrow::Buffer;
+using arrow::ChunkedArray;
+using arrow::DataType;
+using arrow::Datum;
+using arrow::DecimalType;
+using arrow::default_memory_pool;
+using arrow::ListArray;
+using arrow::PrimitiveArray;
+using arrow::ResizableBuffer;
+using arrow::Scalar;
+using arrow::Status;
+using arrow::Table;
+using arrow::TimeUnit;
+using arrow::compute::DictionaryEncode;
+using arrow::internal::checked_cast;
+using arrow::internal::checked_pointer_cast;
+using arrow::internal::Iota;
+using arrow::io::BufferReader;
+
+using arrow::randint;
+using arrow::random_is_valid;
+
+using ArrowId = ::arrow::Type;
+using ParquetType = parquet::Type;
+
+using parquet::arrow::FromParquetSchema;
+using parquet::schema::GroupNode;
+using parquet::schema::NodePtr;
+using parquet::schema::PrimitiveNode;
+
+namespace parquet {
+namespace arrow {
+
+static constexpr int SMALL_SIZE = 100;
+#ifdef PARQUET_VALGRIND
+static constexpr int LARGE_SIZE = 1000;
+#else
+static constexpr int LARGE_SIZE = 10000;
+#endif
+
+static constexpr uint32_t kDefaultSeed = 0;
+
+std::shared_ptr<const LogicalType> get_logical_type(const DataType& type) {
+ switch (type.id()) {
+ case ArrowId::UINT8:
+ return LogicalType::Int(8, false);
+ case ArrowId::INT8:
+ return LogicalType::Int(8, true);
+ case ArrowId::UINT16:
+ return LogicalType::Int(16, false);
+ case ArrowId::INT16:
+ return LogicalType::Int(16, true);
+ case ArrowId::UINT32:
+ return LogicalType::Int(32, false);
+ case ArrowId::INT32:
+ return LogicalType::Int(32, true);
+ case ArrowId::UINT64:
+ return LogicalType::Int(64, false);
+ case ArrowId::INT64:
+ return LogicalType::Int(64, true);
+ case ArrowId::STRING:
+ return LogicalType::String();
+ case ArrowId::DATE32:
+ return LogicalType::Date();
+ case ArrowId::DATE64:
+ return LogicalType::Date();
+ case ArrowId::TIMESTAMP: {
+ const auto& ts_type = static_cast<const ::arrow::TimestampType&>(type);
+ const bool adjusted_to_utc = !(ts_type.timezone().empty());
+ switch (ts_type.unit()) {
+ case TimeUnit::MILLI:
+ return LogicalType::Timestamp(adjusted_to_utc, LogicalType::TimeUnit::MILLIS);
+ case TimeUnit::MICRO:
+ return LogicalType::Timestamp(adjusted_to_utc, LogicalType::TimeUnit::MICROS);
+ case TimeUnit::NANO:
+ return LogicalType::Timestamp(adjusted_to_utc, LogicalType::TimeUnit::NANOS);
+ default:
+ DCHECK(false)
+ << "Only MILLI, MICRO, and NANO units supported for Arrow TIMESTAMP.";
+ }
+ break;
+ }
+ case ArrowId::TIME32:
+ return LogicalType::Time(false, LogicalType::TimeUnit::MILLIS);
+ case ArrowId::TIME64: {
+ const auto& tm_type = static_cast<const ::arrow::TimeType&>(type);
+ switch (tm_type.unit()) {
+ case TimeUnit::MICRO:
+ return LogicalType::Time(false, LogicalType::TimeUnit::MICROS);
+ case TimeUnit::NANO:
+ return LogicalType::Time(false, LogicalType::TimeUnit::NANOS);
+ default:
+ DCHECK(false) << "Only MICRO and NANO units supported for Arrow TIME64.";
+ }
+ break;
+ }
+ case ArrowId::DICTIONARY: {
+ const ::arrow::DictionaryType& dict_type =
+ static_cast<const ::arrow::DictionaryType&>(type);
+ return get_logical_type(*dict_type.value_type());
+ }
+ case ArrowId::DECIMAL128: {
+ const auto& dec_type = static_cast<const ::arrow::Decimal128Type&>(type);
+ return LogicalType::Decimal(dec_type.precision(), dec_type.scale());
+ }
+ case ArrowId::DECIMAL256: {
+ const auto& dec_type = static_cast<const ::arrow::Decimal256Type&>(type);
+ return LogicalType::Decimal(dec_type.precision(), dec_type.scale());
+ }
+
+ default:
+ break;
+ }
+ return LogicalType::None();
+}
+
+ParquetType::type get_physical_type(const DataType& type) {
+ switch (type.id()) {
+ case ArrowId::BOOL:
+ return ParquetType::BOOLEAN;
+ case ArrowId::UINT8:
+ case ArrowId::INT8:
+ case ArrowId::UINT16:
+ case ArrowId::INT16:
+ case ArrowId::UINT32:
+ case ArrowId::INT32:
+ return ParquetType::INT32;
+ case ArrowId::UINT64:
+ case ArrowId::INT64:
+ return ParquetType::INT64;
+ case ArrowId::FLOAT:
+ return ParquetType::FLOAT;
+ case ArrowId::DOUBLE:
+ return ParquetType::DOUBLE;
+ case ArrowId::BINARY:
+ case ArrowId::LARGE_BINARY:
+ return ParquetType::BYTE_ARRAY;
+ case ArrowId::STRING:
+ case ArrowId::LARGE_STRING:
+ return ParquetType::BYTE_ARRAY;
+ case ArrowId::FIXED_SIZE_BINARY:
+ case ArrowId::DECIMAL128:
+ case ArrowId::DECIMAL256:
+ return ParquetType::FIXED_LEN_BYTE_ARRAY;
+ case ArrowId::DATE32:
+ return ParquetType::INT32;
+ case ArrowId::DATE64:
+ // Convert to date32 internally
+ return ParquetType::INT32;
+ case ArrowId::TIME32:
+ return ParquetType::INT32;
+ case ArrowId::TIME64:
+ return ParquetType::INT64;
+ case ArrowId::TIMESTAMP:
+ return ParquetType::INT64;
+ case ArrowId::DICTIONARY: {
+ const ::arrow::DictionaryType& dict_type =
+ static_cast<const ::arrow::DictionaryType&>(type);
+ return get_physical_type(*dict_type.value_type());
+ }
+ default:
+ break;
+ }
+ DCHECK(false) << "cannot reach this code";
+ return ParquetType::INT32;
+}
+
+template <typename TestType>
+struct test_traits {};
+
+template <>
+struct test_traits<::arrow::BooleanType> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::BOOLEAN;
+ static uint8_t const value;
+};
+
+const uint8_t test_traits<::arrow::BooleanType>::value(1);
+
+template <>
+struct test_traits<::arrow::UInt8Type> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::INT32;
+ static uint8_t const value;
+};
+
+const uint8_t test_traits<::arrow::UInt8Type>::value(64);
+
+template <>
+struct test_traits<::arrow::Int8Type> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::INT32;
+ static int8_t const value;
+};
+
+const int8_t test_traits<::arrow::Int8Type>::value(-64);
+
+template <>
+struct test_traits<::arrow::UInt16Type> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::INT32;
+ static uint16_t const value;
+};
+
+const uint16_t test_traits<::arrow::UInt16Type>::value(1024);
+
+template <>
+struct test_traits<::arrow::Int16Type> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::INT32;
+ static int16_t const value;
+};
+
+const int16_t test_traits<::arrow::Int16Type>::value(-1024);
+
+template <>
+struct test_traits<::arrow::UInt32Type> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::INT32;
+ static uint32_t const value;
+};
+
+const uint32_t test_traits<::arrow::UInt32Type>::value(1024);
+
+template <>
+struct test_traits<::arrow::Int32Type> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::INT32;
+ static int32_t const value;
+};
+
+const int32_t test_traits<::arrow::Int32Type>::value(-1024);
+
+template <>
+struct test_traits<::arrow::UInt64Type> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::INT64;
+ static uint64_t const value;
+};
+
+const uint64_t test_traits<::arrow::UInt64Type>::value(1024);
+
+template <>
+struct test_traits<::arrow::Int64Type> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::INT64;
+ static int64_t const value;
+};
+
+const int64_t test_traits<::arrow::Int64Type>::value(-1024);
+
+template <>
+struct test_traits<::arrow::TimestampType> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::INT64;
+ static int64_t const value;
+};
+
+const int64_t test_traits<::arrow::TimestampType>::value(14695634030000);
+
+template <>
+struct test_traits<::arrow::Date32Type> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::INT32;
+ static int32_t const value;
+};
+
+const int32_t test_traits<::arrow::Date32Type>::value(170000);
+
+template <>
+struct test_traits<::arrow::FloatType> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::FLOAT;
+ static float const value;
+};
+
+const float test_traits<::arrow::FloatType>::value(2.1f);
+
+template <>
+struct test_traits<::arrow::DoubleType> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::DOUBLE;
+ static double const value;
+};
+
+const double test_traits<::arrow::DoubleType>::value(4.2);
+
+template <>
+struct test_traits<::arrow::StringType> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::BYTE_ARRAY;
+ static std::string const value;
+};
+
+template <>
+struct test_traits<::arrow::BinaryType> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::BYTE_ARRAY;
+ static std::string const value;
+};
+
+template <>
+struct test_traits<::arrow::FixedSizeBinaryType> {
+ static constexpr ParquetType::type parquet_enum = ParquetType::FIXED_LEN_BYTE_ARRAY;
+ static std::string const value;
+};
+
+const std::string test_traits<::arrow::StringType>::value("Test"); // NOLINT
+const std::string test_traits<::arrow::BinaryType>::value({0, 1, 2, 3}); // NOLINT
+const std::string test_traits<::arrow::FixedSizeBinaryType>::value("Fixed"); // NOLINT
+
+template <typename T>
+using ParquetDataType = PhysicalType<test_traits<T>::parquet_enum>;
+
+template <typename T>
+using ParquetWriter = TypedColumnWriter<ParquetDataType<T>>;
+
+void WriteTableToBuffer(const std::shared_ptr<Table>& table, int64_t row_group_size,
+ const std::shared_ptr<ArrowWriterProperties>& arrow_properties,
+ std::shared_ptr<Buffer>* out) {
+ auto sink = CreateOutputStream();
+
+ auto write_props = WriterProperties::Builder().write_batch_size(100)->build();
+
+ ASSERT_OK_NO_THROW(WriteTable(*table, ::arrow::default_memory_pool(), sink,
+ row_group_size, write_props, arrow_properties));
+ ASSERT_OK_AND_ASSIGN(*out, sink->Finish());
+}
+
+void DoRoundtrip(const std::shared_ptr<Table>& table, int64_t row_group_size,
+ std::shared_ptr<Table>* out,
+ const std::shared_ptr<::parquet::WriterProperties>& writer_properties =
+ ::parquet::default_writer_properties(),
+ const std::shared_ptr<ArrowWriterProperties>& arrow_writer_properties =
+ default_arrow_writer_properties(),
+ const ArrowReaderProperties& arrow_reader_properties =
+ default_arrow_reader_properties()) {
+ auto sink = CreateOutputStream();
+ ASSERT_OK_NO_THROW(WriteTable(*table, ::arrow::default_memory_pool(), sink,
+ row_group_size, writer_properties,
+ arrow_writer_properties));
+ ASSERT_OK_AND_ASSIGN(auto buffer, sink->Finish());
+
+ std::unique_ptr<FileReader> reader;
+ FileReaderBuilder builder;
+ ASSERT_OK_NO_THROW(builder.Open(std::make_shared<BufferReader>(buffer)));
+ ASSERT_OK(builder.properties(arrow_reader_properties)->Build(&reader));
+ ASSERT_OK_NO_THROW(reader->ReadTable(out));
+}
+
+void CheckConfiguredRoundtrip(
+ const std::shared_ptr<Table>& input_table,
+ const std::shared_ptr<Table>& expected_table = nullptr,
+ const std::shared_ptr<::parquet::WriterProperties>& writer_properties =
+ ::parquet::default_writer_properties(),
+ const std::shared_ptr<ArrowWriterProperties>& arrow_writer_properties =
+ default_arrow_writer_properties()) {
+ std::shared_ptr<Table> actual_table;
+ ASSERT_NO_FATAL_FAILURE(DoRoundtrip(input_table, input_table->num_rows(), &actual_table,
+ writer_properties, arrow_writer_properties));
+ if (expected_table) {
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertSchemaEqual(*actual_table->schema(),
+ *expected_table->schema(),
+ /*check_metadata=*/false));
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*expected_table, *actual_table));
+ } else {
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertSchemaEqual(*actual_table->schema(),
+ *input_table->schema(),
+ /*check_metadata=*/false));
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*input_table, *actual_table));
+ }
+}
+
+void DoSimpleRoundtrip(const std::shared_ptr<Table>& table, bool use_threads,
+ int64_t row_group_size, const std::vector<int>& column_subset,
+ std::shared_ptr<Table>* out,
+ const std::shared_ptr<ArrowWriterProperties>& arrow_properties =
+ default_arrow_writer_properties()) {
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_NO_FATAL_FAILURE(
+ WriteTableToBuffer(table, row_group_size, arrow_properties, &buffer));
+
+ std::unique_ptr<FileReader> reader;
+ ASSERT_OK_NO_THROW(OpenFile(std::make_shared<BufferReader>(buffer),
+ ::arrow::default_memory_pool(), &reader));
+
+ reader->set_use_threads(use_threads);
+ if (column_subset.size() > 0) {
+ ASSERT_OK_NO_THROW(reader->ReadTable(column_subset, out));
+ } else {
+ // Read everything
+ ASSERT_OK_NO_THROW(reader->ReadTable(out));
+ }
+}
+
+void DoRoundTripWithBatches(
+ const std::shared_ptr<Table>& table, bool use_threads, int64_t row_group_size,
+ const std::vector<int>& column_subset, std::shared_ptr<Table>* out,
+ const std::shared_ptr<ArrowWriterProperties>& arrow_writer_properties =
+ default_arrow_writer_properties()) {
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_NO_FATAL_FAILURE(
+ WriteTableToBuffer(table, row_group_size, arrow_writer_properties, &buffer));
+
+ std::unique_ptr<FileReader> reader;
+ FileReaderBuilder builder;
+ ASSERT_OK_NO_THROW(builder.Open(std::make_shared<BufferReader>(buffer)));
+ ArrowReaderProperties arrow_reader_properties;
+ arrow_reader_properties.set_batch_size(row_group_size - 1);
+ ASSERT_OK_NO_THROW(builder.memory_pool(::arrow::default_memory_pool())
+ ->properties(arrow_reader_properties)
+ ->Build(&reader));
+ std::unique_ptr<::arrow::RecordBatchReader> batch_reader;
+ if (column_subset.size() > 0) {
+ ASSERT_OK_NO_THROW(reader->GetRecordBatchReader(
+ Iota(reader->parquet_reader()->metadata()->num_row_groups()), column_subset,
+ &batch_reader));
+ } else {
+ // Read everything
+
+ ASSERT_OK_NO_THROW(reader->GetRecordBatchReader(
+ Iota(reader->parquet_reader()->metadata()->num_row_groups()), &batch_reader));
+ }
+ ASSERT_OK_AND_ASSIGN(*out, Table::FromRecordBatchReader(batch_reader.get()));
+}
+
+void CheckSimpleRoundtrip(
+ const std::shared_ptr<Table>& table, int64_t row_group_size,
+ const std::shared_ptr<ArrowWriterProperties>& arrow_writer_properties =
+ default_arrow_writer_properties()) {
+ std::shared_ptr<Table> result;
+ ASSERT_NO_FATAL_FAILURE(DoSimpleRoundtrip(table, false /* use_threads */,
+ row_group_size, {}, &result,
+ arrow_writer_properties));
+ ::arrow::AssertSchemaEqual(*table->schema(), *result->schema(),
+ /*check_metadata=*/false);
+ ASSERT_OK(result->ValidateFull());
+
+ ::arrow::AssertTablesEqual(*table, *result, false);
+
+ ASSERT_NO_FATAL_FAILURE(DoRoundTripWithBatches(table, false /* use_threads */,
+ row_group_size, {}, &result,
+ arrow_writer_properties));
+ ::arrow::AssertSchemaEqual(*table->schema(), *result->schema(),
+ /*check_metadata=*/false);
+ ASSERT_OK(result->ValidateFull());
+
+ ::arrow::AssertTablesEqual(*table, *result, false);
+}
+
+static std::shared_ptr<GroupNode> MakeSimpleSchema(const DataType& type,
+ Repetition::type repetition) {
+ int32_t byte_width = -1;
+
+ switch (type.id()) {
+ case ::arrow::Type::DICTIONARY: {
+ const auto& dict_type = static_cast<const ::arrow::DictionaryType&>(type);
+ const DataType& values_type = *dict_type.value_type();
+ switch (values_type.id()) {
+ case ::arrow::Type::FIXED_SIZE_BINARY:
+ byte_width =
+ static_cast<const ::arrow::FixedSizeBinaryType&>(values_type).byte_width();
+ break;
+ case ::arrow::Type::DECIMAL128:
+ case ::arrow::Type::DECIMAL256: {
+ const auto& decimal_type = static_cast<const DecimalType&>(values_type);
+ byte_width = DecimalType::DecimalSize(decimal_type.precision());
+ } break;
+ default:
+ break;
+ }
+ } break;
+ case ::arrow::Type::FIXED_SIZE_BINARY:
+ byte_width = static_cast<const ::arrow::FixedSizeBinaryType&>(type).byte_width();
+ break;
+ case ::arrow::Type::DECIMAL128:
+ case ::arrow::Type::DECIMAL256: {
+ const auto& decimal_type = static_cast<const DecimalType&>(type);
+ byte_width = DecimalType::DecimalSize(decimal_type.precision());
+ } break;
+ default:
+ break;
+ }
+ auto pnode = PrimitiveNode::Make("column1", repetition, get_logical_type(type),
+ get_physical_type(type), byte_width);
+ NodePtr node_ =
+ GroupNode::Make("schema", Repetition::REQUIRED, std::vector<NodePtr>({pnode}));
+ return std::static_pointer_cast<GroupNode>(node_);
+}
+
+void ReadSingleColumnFileStatistics(std::unique_ptr<FileReader> file_reader,
+ std::shared_ptr<Scalar>* min,
+ std::shared_ptr<Scalar>* max) {
+ auto metadata = file_reader->parquet_reader()->metadata();
+ ASSERT_EQ(1, metadata->num_row_groups());
+ ASSERT_EQ(1, metadata->num_columns());
+
+ auto row_group = metadata->RowGroup(0);
+ ASSERT_EQ(1, row_group->num_columns());
+
+ auto column = row_group->ColumnChunk(0);
+ ASSERT_TRUE(column->is_stats_set());
+ auto statistics = column->statistics();
+
+ ASSERT_OK(StatisticsAsScalars(*statistics, min, max));
+}
+
+void DownsampleInt96RoundTrip(std::shared_ptr<Array> arrow_vector_in,
+ std::shared_ptr<Array> arrow_vector_out,
+ ::arrow::TimeUnit::type unit) {
+ // Create single input table of NS to be written to parquet with INT96
+ auto input_schema =
+ ::arrow::schema({::arrow::field("f", ::arrow::timestamp(TimeUnit::NANO))});
+ auto input = Table::Make(input_schema, {arrow_vector_in});
+
+ // Create an expected schema for each resulting table (one for each "downsampled" ts)
+ auto ex_schema = ::arrow::schema({::arrow::field("f", ::arrow::timestamp(unit))});
+ auto ex_result = Table::Make(ex_schema, {arrow_vector_out});
+
+ std::shared_ptr<Table> result;
+
+ ArrowReaderProperties arrow_reader_prop;
+ arrow_reader_prop.set_coerce_int96_timestamp_unit(unit);
+
+ ASSERT_NO_FATAL_FAILURE(DoRoundtrip(
+ input, input->num_rows(), &result, default_writer_properties(),
+ ArrowWriterProperties::Builder().enable_deprecated_int96_timestamps()->build(),
+ arrow_reader_prop));
+
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertSchemaEqual(*ex_result->schema(),
+ *result->schema(),
+ /*check_metadata=*/false));
+
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*ex_result, *result));
+}
+
+// Non-template base class for TestParquetIO, to avoid code duplication
+class ParquetIOTestBase : public ::testing::Test {
+ public:
+ virtual void SetUp() {}
+
+ std::unique_ptr<ParquetFileWriter> MakeWriter(
+ const std::shared_ptr<GroupNode>& schema) {
+ sink_ = CreateOutputStream();
+ return ParquetFileWriter::Open(sink_, schema);
+ }
+
+ void ReaderFromSink(std::unique_ptr<FileReader>* out) {
+ ASSERT_OK_AND_ASSIGN(auto buffer, sink_->Finish());
+ ASSERT_OK_NO_THROW(OpenFile(std::make_shared<BufferReader>(buffer),
+ ::arrow::default_memory_pool(), out));
+ }
+
+ void ReadSingleColumnFile(std::unique_ptr<FileReader> file_reader,
+ std::shared_ptr<Array>* out) {
+ std::unique_ptr<ColumnReader> column_reader;
+ ASSERT_OK_NO_THROW(file_reader->GetColumn(0, &column_reader));
+ ASSERT_NE(nullptr, column_reader.get());
+
+ std::shared_ptr<ChunkedArray> chunked_out;
+ ASSERT_OK(column_reader->NextBatch(SMALL_SIZE, &chunked_out));
+
+ ASSERT_EQ(1, chunked_out->num_chunks());
+ *out = chunked_out->chunk(0);
+ ASSERT_NE(nullptr, out->get());
+ ASSERT_OK((*out)->ValidateFull());
+ }
+
+ void ReadAndCheckSingleColumnFile(const Array& values) {
+ std::shared_ptr<Array> out;
+
+ std::unique_ptr<FileReader> reader;
+ ReaderFromSink(&reader);
+ ReadSingleColumnFile(std::move(reader), &out);
+
+ AssertArraysEqual(values, *out);
+ }
+
+ void ReadTableFromFile(std::unique_ptr<FileReader> reader, bool expect_metadata,
+ std::shared_ptr<Table>* out) {
+ ASSERT_OK_NO_THROW(reader->ReadTable(out));
+ auto key_value_metadata =
+ reader->parquet_reader()->metadata()->key_value_metadata().get();
+ if (!expect_metadata) {
+ ASSERT_EQ(nullptr, key_value_metadata);
+ } else {
+ ASSERT_NE(nullptr, key_value_metadata);
+ }
+ ASSERT_NE(nullptr, out->get());
+ }
+
+ void ReadTableFromFile(std::unique_ptr<FileReader> reader,
+ std::shared_ptr<Table>* out) {
+ ReadTableFromFile(std::move(reader), /*expect_metadata=*/false, out);
+ }
+
+ void RoundTripSingleColumn(
+ const std::shared_ptr<Array>& values, const std::shared_ptr<Array>& expected,
+ const std::shared_ptr<::parquet::ArrowWriterProperties>& arrow_properties,
+ bool nullable = true) {
+ std::shared_ptr<Table> table = MakeSimpleTable(values, nullable);
+ this->ResetSink();
+ ASSERT_OK_NO_THROW(WriteTable(*table, ::arrow::default_memory_pool(), this->sink_,
+ values->length(), default_writer_properties(),
+ arrow_properties));
+
+ std::shared_ptr<Table> out;
+ std::unique_ptr<FileReader> reader;
+ ASSERT_NO_FATAL_FAILURE(this->ReaderFromSink(&reader));
+ const bool expect_metadata = arrow_properties->store_schema();
+ ASSERT_NO_FATAL_FAILURE(
+ this->ReadTableFromFile(std::move(reader), expect_metadata, &out));
+ ASSERT_EQ(1, out->num_columns());
+ ASSERT_EQ(table->num_rows(), out->num_rows());
+
+ const auto chunked_array = out->column(0);
+ ASSERT_EQ(1, chunked_array->num_chunks());
+
+ AssertArraysEqual(*expected, *chunked_array->chunk(0), /*verbose=*/true);
+ }
+
+ // Prepare table of empty lists, with null values array (ARROW-2744)
+ void PrepareEmptyListsTable(int64_t size, std::shared_ptr<Table>* out) {
+ std::shared_ptr<Array> lists;
+ ASSERT_OK(MakeEmptyListsArray(size, &lists));
+ *out = MakeSimpleTable(lists, true /* nullable_lists */);
+ }
+
+ void ReadAndCheckSingleColumnTable(const std::shared_ptr<Array>& values) {
+ std::shared_ptr<::arrow::Table> out;
+ std::unique_ptr<FileReader> reader;
+ ReaderFromSink(&reader);
+ ReadTableFromFile(std::move(reader), &out);
+ ASSERT_EQ(1, out->num_columns());
+ ASSERT_EQ(values->length(), out->num_rows());
+
+ std::shared_ptr<ChunkedArray> chunked_array = out->column(0);
+ ASSERT_EQ(1, chunked_array->num_chunks());
+ auto result = chunked_array->chunk(0);
+
+ AssertArraysEqual(*values, *result);
+ }
+
+ void CheckRoundTrip(const std::shared_ptr<Table>& table) {
+ CheckSimpleRoundtrip(table, table->num_rows());
+ }
+
+ template <typename ArrayType>
+ void WriteColumn(const std::shared_ptr<GroupNode>& schema,
+ const std::shared_ptr<ArrayType>& values) {
+ SchemaDescriptor descriptor;
+ ASSERT_NO_THROW(descriptor.Init(schema));
+ std::shared_ptr<::arrow::Schema> arrow_schema;
+ ArrowReaderProperties props;
+ ASSERT_OK_NO_THROW(FromParquetSchema(&descriptor, props, &arrow_schema));
+
+ std::unique_ptr<FileWriter> writer;
+ ASSERT_OK_NO_THROW(FileWriter::Make(::arrow::default_memory_pool(),
+ MakeWriter(schema), arrow_schema,
+ default_arrow_writer_properties(), &writer));
+ ASSERT_OK_NO_THROW(writer->NewRowGroup(values->length()));
+ ASSERT_OK_NO_THROW(writer->WriteColumnChunk(*values));
+ ASSERT_OK_NO_THROW(writer->Close());
+ // writer->Close() should be idempotent
+ ASSERT_OK_NO_THROW(writer->Close());
+ }
+
+ void ResetSink() { sink_ = CreateOutputStream(); }
+
+ std::shared_ptr<::arrow::io::BufferOutputStream> sink_;
+};
+
+class TestReadDecimals : public ParquetIOTestBase {
+ public:
+ void CheckReadFromByteArrays(const std::shared_ptr<const LogicalType>& logical_type,
+ const std::vector<std::vector<uint8_t>>& values,
+ const Array& expected) {
+ std::vector<ByteArray> byte_arrays(values.size());
+ std::transform(values.begin(), values.end(), byte_arrays.begin(),
+ [](const std::vector<uint8_t>& bytes) {
+ return ByteArray(static_cast<uint32_t>(bytes.size()), bytes.data());
+ });
+
+ auto node = PrimitiveNode::Make("decimals", Repetition::REQUIRED, logical_type,
+ Type::BYTE_ARRAY);
+ auto schema =
+ GroupNode::Make("schema", Repetition::REQUIRED, std::vector<NodePtr>{node});
+
+ auto file_writer = MakeWriter(checked_pointer_cast<GroupNode>(schema));
+ auto column_writer = file_writer->AppendRowGroup()->NextColumn();
+ auto typed_writer = checked_cast<TypedColumnWriter<ByteArrayType>*>(column_writer);
+ typed_writer->WriteBatch(static_cast<int64_t>(byte_arrays.size()),
+ /*def_levels=*/nullptr,
+ /*rep_levels=*/nullptr, byte_arrays.data());
+ column_writer->Close();
+ file_writer->Close();
+
+ ReadAndCheckSingleColumnFile(expected);
+ }
+};
+
+// The Decimal roundtrip tests always go through the FixedLenByteArray path,
+// check the ByteArray case manually.
+
+TEST_F(TestReadDecimals, Decimal128ByteArray) {
+ const std::vector<std::vector<uint8_t>> big_endian_decimals = {
+ // 123456
+ {1, 226, 64},
+ // 987654
+ {15, 18, 6},
+ // -123456
+ {255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 254, 29, 192},
+ };
+
+ auto expected =
+ ArrayFromJSON(::arrow::decimal128(6, 3), R"(["123.456", "987.654", "-123.456"])");
+ CheckReadFromByteArrays(LogicalType::Decimal(6, 3), big_endian_decimals, *expected);
+}
+
+TEST_F(TestReadDecimals, Decimal256ByteArray) {
+ const std::vector<std::vector<uint8_t>> big_endian_decimals = {
+ // 123456
+ {1, 226, 64},
+ // 987654
+ {15, 18, 6},
+ // -123456
+ {255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 254, 29, 192},
+ };
+
+ auto expected =
+ ArrayFromJSON(::arrow::decimal256(40, 3), R"(["123.456", "987.654", "-123.456"])");
+ CheckReadFromByteArrays(LogicalType::Decimal(40, 3), big_endian_decimals, *expected);
+}
+
+template <typename TestType>
+class TestParquetIO : public ParquetIOTestBase {
+ public:
+ void PrepareListTable(int64_t size, bool nullable_lists, bool nullable_elements,
+ int64_t null_count, std::shared_ptr<Table>* out) {
+ std::shared_ptr<Array> values;
+ ASSERT_OK(NullableArray<TestType>(size * size, nullable_elements ? null_count : 0,
+ kDefaultSeed, &values));
+ // Also test that slice offsets are respected
+ values = values->Slice(5, values->length() - 5);
+ std::shared_ptr<ListArray> lists;
+ ASSERT_OK(MakeListArray(values, size, nullable_lists ? null_count : 0, "element",
+ nullable_elements, &lists));
+ *out = MakeSimpleTable(lists->Slice(3, size - 6), nullable_lists);
+ }
+
+ void PrepareListOfListTable(int64_t size, bool nullable_parent_lists,
+ bool nullable_lists, bool nullable_elements,
+ int64_t null_count, std::shared_ptr<Table>* out) {
+ std::shared_ptr<Array> values;
+ ASSERT_OK(NullableArray<TestType>(size * 6, nullable_elements ? null_count : 0,
+ kDefaultSeed, &values));
+ std::shared_ptr<ListArray> lists;
+ ASSERT_OK(MakeListArray(values, size * 3, nullable_lists ? null_count : 0, "item",
+ nullable_elements, &lists));
+ std::shared_ptr<ListArray> parent_lists;
+ ASSERT_OK(MakeListArray(lists, size, nullable_parent_lists ? null_count : 0, "item",
+ nullable_lists, &parent_lists));
+ *out = MakeSimpleTable(parent_lists, nullable_parent_lists);
+ }
+};
+
+// Below, we only test types which map bijectively to logical Parquet types
+// (these tests don't serialize the original Arrow schema in Parquet metadata).
+// Other Arrow types are tested elsewhere:
+// - UInt32Type is serialized as Parquet INT64 in Parquet 1.0 (but not 2.0)
+// - LargeBinaryType and LargeStringType are serialized as Parquet BYTE_ARRAY
+// (and deserialized as BinaryType and StringType, respectively)
+
+typedef ::testing::Types<
+ ::arrow::BooleanType, ::arrow::UInt8Type, ::arrow::Int8Type, ::arrow::UInt16Type,
+ ::arrow::Int16Type, ::arrow::Int32Type, ::arrow::UInt64Type, ::arrow::Int64Type,
+ ::arrow::Date32Type, ::arrow::FloatType, ::arrow::DoubleType, ::arrow::StringType,
+ ::arrow::BinaryType, ::arrow::FixedSizeBinaryType, DecimalWithPrecisionAndScale<1>,
+ DecimalWithPrecisionAndScale<5>, DecimalWithPrecisionAndScale<10>,
+ DecimalWithPrecisionAndScale<19>, DecimalWithPrecisionAndScale<23>,
+ DecimalWithPrecisionAndScale<27>, DecimalWithPrecisionAndScale<38>,
+ Decimal256WithPrecisionAndScale<39>, Decimal256WithPrecisionAndScale<56>,
+ Decimal256WithPrecisionAndScale<76>>
+ TestTypes;
+
+TYPED_TEST_SUITE(TestParquetIO, TestTypes);
+
+TYPED_TEST(TestParquetIO, SingleColumnRequiredWrite) {
+ std::shared_ptr<Array> values;
+ ASSERT_OK(NonNullArray<TypeParam>(SMALL_SIZE, &values));
+
+ std::shared_ptr<GroupNode> schema =
+ MakeSimpleSchema(*values->type(), Repetition::REQUIRED);
+ ASSERT_NO_FATAL_FAILURE(this->WriteColumn(schema, values));
+
+ ASSERT_NO_FATAL_FAILURE(this->ReadAndCheckSingleColumnFile(*values));
+}
+
+TYPED_TEST(TestParquetIO, ZeroChunksTable) {
+ auto values = std::make_shared<ChunkedArray>(::arrow::ArrayVector{}, ::arrow::int32());
+ auto table = MakeSimpleTable(values, false);
+
+ this->ResetSink();
+ ASSERT_OK_NO_THROW(
+ WriteTable(*table, ::arrow::default_memory_pool(), this->sink_, SMALL_SIZE));
+
+ std::shared_ptr<Table> out;
+ std::unique_ptr<FileReader> reader;
+ ASSERT_NO_FATAL_FAILURE(this->ReaderFromSink(&reader));
+ ASSERT_NO_FATAL_FAILURE(this->ReadTableFromFile(std::move(reader), &out));
+ ASSERT_EQ(1, out->num_columns());
+ ASSERT_EQ(0, out->num_rows());
+ ASSERT_EQ(0, out->column(0)->length());
+ // odd: even though zero chunks were written, a single empty chunk is read
+ ASSERT_EQ(1, out->column(0)->num_chunks());
+}
+
+TYPED_TEST(TestParquetIO, SingleColumnTableRequiredWrite) {
+ std::shared_ptr<Array> values;
+ ASSERT_OK(NonNullArray<TypeParam>(SMALL_SIZE, &values));
+ std::shared_ptr<Table> table = MakeSimpleTable(values, false);
+
+ this->ResetSink();
+ ASSERT_OK_NO_THROW(WriteTable(*table, ::arrow::default_memory_pool(), this->sink_,
+ values->length(), default_writer_properties()));
+
+ std::shared_ptr<Table> out;
+ std::unique_ptr<FileReader> reader;
+ ASSERT_NO_FATAL_FAILURE(this->ReaderFromSink(&reader));
+ ASSERT_NO_FATAL_FAILURE(this->ReadTableFromFile(std::move(reader), &out));
+ ASSERT_EQ(1, out->num_columns());
+ EXPECT_EQ(table->num_rows(), out->num_rows());
+
+ std::shared_ptr<ChunkedArray> chunked_array = out->column(0);
+ ASSERT_EQ(1, chunked_array->num_chunks());
+
+ AssertArraysEqual(*values, *chunked_array->chunk(0));
+}
+
+TYPED_TEST(TestParquetIO, SingleColumnOptionalReadWrite) {
+ // This also tests max_definition_level = 1
+ std::shared_ptr<Array> values;
+
+ ASSERT_OK(NullableArray<TypeParam>(SMALL_SIZE, 10, kDefaultSeed, &values));
+
+ std::shared_ptr<GroupNode> schema =
+ MakeSimpleSchema(*values->type(), Repetition::OPTIONAL);
+ ASSERT_NO_FATAL_FAILURE(this->WriteColumn(schema, values));
+
+ ASSERT_NO_FATAL_FAILURE(this->ReadAndCheckSingleColumnFile(*values));
+}
+
+TYPED_TEST(TestParquetIO, SingleColumnOptionalDictionaryWrite) {
+ // Skip tests for BOOL as we don't create dictionaries for it.
+ if (TypeParam::type_id == ::arrow::Type::BOOL) {
+ return;
+ }
+
+ std::shared_ptr<Array> values;
+
+ ASSERT_OK(NullableArray<TypeParam>(SMALL_SIZE, 10, kDefaultSeed, &values));
+
+ ASSERT_OK_AND_ASSIGN(Datum out, DictionaryEncode(values));
+ std::shared_ptr<Array> dict_values = MakeArray(out.array());
+ std::shared_ptr<GroupNode> schema =
+ MakeSimpleSchema(*dict_values->type(), Repetition::OPTIONAL);
+ ASSERT_NO_FATAL_FAILURE(this->WriteColumn(schema, dict_values));
+
+ ASSERT_NO_FATAL_FAILURE(this->ReadAndCheckSingleColumnFile(*values));
+}
+
+TYPED_TEST(TestParquetIO, SingleColumnRequiredSliceWrite) {
+ std::shared_ptr<Array> values;
+ ASSERT_OK(NonNullArray<TypeParam>(2 * SMALL_SIZE, &values));
+ std::shared_ptr<GroupNode> schema =
+ MakeSimpleSchema(*values->type(), Repetition::REQUIRED);
+
+ std::shared_ptr<Array> sliced_values = values->Slice(SMALL_SIZE / 2, SMALL_SIZE);
+ ASSERT_NO_FATAL_FAILURE(this->WriteColumn(schema, sliced_values));
+ ASSERT_NO_FATAL_FAILURE(this->ReadAndCheckSingleColumnFile(*sliced_values));
+
+ // Slice offset 1 higher
+ sliced_values = values->Slice(SMALL_SIZE / 2 + 1, SMALL_SIZE);
+ ASSERT_NO_FATAL_FAILURE(this->WriteColumn(schema, sliced_values));
+ ASSERT_NO_FATAL_FAILURE(this->ReadAndCheckSingleColumnFile(*sliced_values));
+}
+
+TYPED_TEST(TestParquetIO, SingleColumnOptionalSliceWrite) {
+ std::shared_ptr<Array> values;
+ ASSERT_OK(NullableArray<TypeParam>(2 * SMALL_SIZE, SMALL_SIZE, kDefaultSeed, &values));
+ std::shared_ptr<GroupNode> schema =
+ MakeSimpleSchema(*values->type(), Repetition::OPTIONAL);
+
+ std::shared_ptr<Array> sliced_values = values->Slice(SMALL_SIZE / 2, SMALL_SIZE);
+ ASSERT_NO_FATAL_FAILURE(this->WriteColumn(schema, sliced_values));
+ ASSERT_NO_FATAL_FAILURE(this->ReadAndCheckSingleColumnFile(*sliced_values));
+
+ // Slice offset 1 higher, thus different null bitmap.
+ sliced_values = values->Slice(SMALL_SIZE / 2 + 1, SMALL_SIZE);
+ ASSERT_NO_FATAL_FAILURE(this->WriteColumn(schema, sliced_values));
+ ASSERT_NO_FATAL_FAILURE(this->ReadAndCheckSingleColumnFile(*sliced_values));
+}
+
+TYPED_TEST(TestParquetIO, SingleColumnTableOptionalReadWrite) {
+ // This also tests max_definition_level = 1
+ std::shared_ptr<Array> values;
+
+ ASSERT_OK(NullableArray<TypeParam>(SMALL_SIZE, 10, kDefaultSeed, &values));
+ std::shared_ptr<Table> table = MakeSimpleTable(values, true);
+ ASSERT_NO_FATAL_FAILURE(this->CheckRoundTrip(table));
+}
+
+TYPED_TEST(TestParquetIO, SingleEmptyListsColumnReadWrite) {
+ std::shared_ptr<Table> table;
+ ASSERT_NO_FATAL_FAILURE(this->PrepareEmptyListsTable(SMALL_SIZE, &table));
+ ASSERT_NO_FATAL_FAILURE(this->CheckRoundTrip(table));
+}
+
+TYPED_TEST(TestParquetIO, SingleNullableListNullableColumnReadWrite) {
+ std::shared_ptr<Table> table;
+ this->PrepareListTable(SMALL_SIZE, true, true, 10, &table);
+ this->CheckRoundTrip(table);
+}
+
+TYPED_TEST(TestParquetIO, SingleRequiredListNullableColumnReadWrite) {
+ std::shared_ptr<Table> table;
+ ASSERT_NO_FATAL_FAILURE(this->PrepareListTable(SMALL_SIZE, false, true, 10, &table));
+ ASSERT_NO_FATAL_FAILURE(this->CheckRoundTrip(table));
+}
+
+TYPED_TEST(TestParquetIO, SingleNullableListRequiredColumnReadWrite) {
+ std::shared_ptr<Table> table;
+ ASSERT_NO_FATAL_FAILURE(this->PrepareListTable(SMALL_SIZE, true, false, 10, &table));
+ ASSERT_NO_FATAL_FAILURE(this->CheckRoundTrip(table));
+}
+
+TYPED_TEST(TestParquetIO, SingleRequiredListRequiredColumnReadWrite) {
+ std::shared_ptr<Table> table;
+ ASSERT_NO_FATAL_FAILURE(this->PrepareListTable(SMALL_SIZE, false, false, 0, &table));
+ ASSERT_NO_FATAL_FAILURE(this->CheckRoundTrip(table));
+}
+
+TYPED_TEST(TestParquetIO, SingleNullableListRequiredListRequiredColumnReadWrite) {
+ std::shared_ptr<Table> table;
+ ASSERT_NO_FATAL_FAILURE(
+ this->PrepareListOfListTable(SMALL_SIZE, true, false, false, 0, &table));
+ ASSERT_NO_FATAL_FAILURE(this->CheckRoundTrip(table));
+}
+
+TYPED_TEST(TestParquetIO, SingleColumnRequiredChunkedWrite) {
+ std::shared_ptr<Array> values;
+ ASSERT_OK(NonNullArray<TypeParam>(SMALL_SIZE, &values));
+ int64_t chunk_size = values->length() / 4;
+
+ std::shared_ptr<GroupNode> schema =
+ MakeSimpleSchema(*values->type(), Repetition::REQUIRED);
+ SchemaDescriptor descriptor;
+ ASSERT_NO_THROW(descriptor.Init(schema));
+ std::shared_ptr<::arrow::Schema> arrow_schema;
+ ArrowReaderProperties props;
+ ASSERT_OK_NO_THROW(FromParquetSchema(&descriptor, props, &arrow_schema));
+
+ std::unique_ptr<FileWriter> writer;
+ ASSERT_OK_NO_THROW(FileWriter::Make(::arrow::default_memory_pool(),
+ this->MakeWriter(schema), arrow_schema,
+ default_arrow_writer_properties(), &writer));
+ for (int i = 0; i < 4; i++) {
+ ASSERT_OK_NO_THROW(writer->NewRowGroup(chunk_size));
+ std::shared_ptr<Array> sliced_array = values->Slice(i * chunk_size, chunk_size);
+ ASSERT_OK_NO_THROW(writer->WriteColumnChunk(*sliced_array));
+ }
+ ASSERT_OK_NO_THROW(writer->Close());
+
+ ASSERT_NO_FATAL_FAILURE(this->ReadAndCheckSingleColumnFile(*values));
+}
+
+TYPED_TEST(TestParquetIO, SingleColumnTableRequiredChunkedWrite) {
+ std::shared_ptr<Array> values;
+ ASSERT_OK(NonNullArray<TypeParam>(LARGE_SIZE, &values));
+ std::shared_ptr<Table> table = MakeSimpleTable(values, false);
+
+ this->ResetSink();
+ ASSERT_OK_NO_THROW(WriteTable(*table, default_memory_pool(), this->sink_, 512,
+ default_writer_properties()));
+
+ ASSERT_NO_FATAL_FAILURE(this->ReadAndCheckSingleColumnTable(values));
+}
+
+TYPED_TEST(TestParquetIO, SingleColumnTableRequiredChunkedWriteArrowIO) {
+ std::shared_ptr<Array> values;
+ ASSERT_OK(NonNullArray<TypeParam>(LARGE_SIZE, &values));
+ std::shared_ptr<Table> table = MakeSimpleTable(values, false);
+
+ this->ResetSink();
+ auto buffer = AllocateBuffer();
+
+ {
+ // BufferOutputStream closed on gc
+ auto arrow_sink_ = std::make_shared<::arrow::io::BufferOutputStream>(buffer);
+ ASSERT_OK_NO_THROW(WriteTable(*table, default_memory_pool(), arrow_sink_, 512,
+ default_writer_properties()));
+
+ // XXX: Remove this after ARROW-455 completed
+ ASSERT_OK(arrow_sink_->Close());
+ }
+
+ auto pbuffer = std::make_shared<Buffer>(buffer->data(), buffer->size());
+
+ auto source = std::make_shared<BufferReader>(pbuffer);
+ std::shared_ptr<::arrow::Table> out;
+ std::unique_ptr<FileReader> reader;
+ ASSERT_OK_NO_THROW(OpenFile(source, ::arrow::default_memory_pool(), &reader));
+ ASSERT_NO_FATAL_FAILURE(this->ReadTableFromFile(std::move(reader), &out));
+ ASSERT_EQ(1, out->num_columns());
+ ASSERT_EQ(values->length(), out->num_rows());
+
+ std::shared_ptr<ChunkedArray> chunked_array = out->column(0);
+ ASSERT_EQ(1, chunked_array->num_chunks());
+
+ AssertArraysEqual(*values, *chunked_array->chunk(0));
+}
+
+TYPED_TEST(TestParquetIO, SingleColumnOptionalChunkedWrite) {
+ int64_t chunk_size = SMALL_SIZE / 4;
+ std::shared_ptr<Array> values;
+
+ ASSERT_OK(NullableArray<TypeParam>(SMALL_SIZE, 10, kDefaultSeed, &values));
+
+ std::shared_ptr<GroupNode> schema =
+ MakeSimpleSchema(*values->type(), Repetition::OPTIONAL);
+ SchemaDescriptor descriptor;
+ ASSERT_NO_THROW(descriptor.Init(schema));
+ std::shared_ptr<::arrow::Schema> arrow_schema;
+ ArrowReaderProperties props;
+ ASSERT_OK_NO_THROW(FromParquetSchema(&descriptor, props, &arrow_schema));
+
+ std::unique_ptr<FileWriter> writer;
+ ASSERT_OK_NO_THROW(FileWriter::Make(::arrow::default_memory_pool(),
+ this->MakeWriter(schema), arrow_schema,
+ default_arrow_writer_properties(), &writer));
+ for (int i = 0; i < 4; i++) {
+ ASSERT_OK_NO_THROW(writer->NewRowGroup(chunk_size));
+ std::shared_ptr<Array> sliced_array = values->Slice(i * chunk_size, chunk_size);
+ ASSERT_OK_NO_THROW(writer->WriteColumnChunk(*sliced_array));
+ }
+ ASSERT_OK_NO_THROW(writer->Close());
+
+ ASSERT_NO_FATAL_FAILURE(this->ReadAndCheckSingleColumnFile(*values));
+}
+
+TYPED_TEST(TestParquetIO, SingleColumnTableOptionalChunkedWrite) {
+ // This also tests max_definition_level = 1
+ std::shared_ptr<Array> values;
+
+ ASSERT_OK(NullableArray<TypeParam>(LARGE_SIZE, 100, kDefaultSeed, &values));
+ std::shared_ptr<Table> table = MakeSimpleTable(values, true);
+ this->ResetSink();
+ ASSERT_OK_NO_THROW(WriteTable(*table, ::arrow::default_memory_pool(), this->sink_, 512,
+ default_writer_properties()));
+
+ ASSERT_NO_FATAL_FAILURE(this->ReadAndCheckSingleColumnTable(values));
+}
+
+TYPED_TEST(TestParquetIO, FileMetaDataWrite) {
+ std::shared_ptr<Array> values;
+ ASSERT_OK(NonNullArray<TypeParam>(SMALL_SIZE, &values));
+ std::shared_ptr<Table> table = MakeSimpleTable(values, false);
+ this->ResetSink();
+ ASSERT_OK_NO_THROW(WriteTable(*table, ::arrow::default_memory_pool(), this->sink_,
+ values->length(), default_writer_properties()));
+
+ std::unique_ptr<FileReader> reader;
+ ASSERT_NO_FATAL_FAILURE(this->ReaderFromSink(&reader));
+ auto metadata = reader->parquet_reader()->metadata();
+ ASSERT_EQ(1, metadata->num_columns());
+ EXPECT_EQ(table->num_rows(), metadata->num_rows());
+
+ this->ResetSink();
+
+ ASSERT_OK_NO_THROW(::parquet::arrow::WriteFileMetaData(*metadata, this->sink_.get()));
+
+ ASSERT_NO_FATAL_FAILURE(this->ReaderFromSink(&reader));
+ auto metadata_written = reader->parquet_reader()->metadata();
+ ASSERT_EQ(metadata->size(), metadata_written->size());
+ ASSERT_EQ(metadata->num_row_groups(), metadata_written->num_row_groups());
+ ASSERT_EQ(metadata->num_rows(), metadata_written->num_rows());
+ ASSERT_EQ(metadata->num_columns(), metadata_written->num_columns());
+ ASSERT_EQ(metadata->RowGroup(0)->num_rows(), metadata_written->RowGroup(0)->num_rows());
+}
+
+TYPED_TEST(TestParquetIO, CheckIterativeColumnRead) {
+ // ARROW-5608: Test using ColumnReader with small batch size (1) and non-repeated
+ // nullable fields with ASAN.
+ std::shared_ptr<Array> values;
+ ASSERT_OK(NonNullArray<TypeParam>(SMALL_SIZE, &values));
+ std::shared_ptr<Table> table = MakeSimpleTable(values, true);
+ this->ResetSink();
+ ASSERT_OK_NO_THROW(WriteTable(*table, ::arrow::default_memory_pool(), this->sink_,
+ values->length(), default_writer_properties()));
+
+ std::unique_ptr<FileReader> reader;
+ this->ReaderFromSink(&reader);
+ std::unique_ptr<ColumnReader> column_reader;
+ ASSERT_OK_NO_THROW(reader->GetColumn(0, &column_reader));
+ ASSERT_NE(nullptr, column_reader.get());
+
+ // Read one record at a time.
+ std::vector<std::shared_ptr<::arrow::Array>> batches;
+
+ for (int64_t i = 0; i < values->length(); ++i) {
+ std::shared_ptr<::arrow::ChunkedArray> batch;
+ ASSERT_OK_NO_THROW(column_reader->NextBatch(1, &batch));
+ ASSERT_EQ(1, batch->length());
+ ASSERT_EQ(1, batch->num_chunks());
+ batches.push_back(batch->chunk(0));
+ }
+
+ auto chunked = std::make_shared<::arrow::ChunkedArray>(batches);
+ auto chunked_table = ::arrow::Table::Make(table->schema(), {chunked});
+ ASSERT_TRUE(table->Equals(*chunked_table));
+}
+
+using TestInt96ParquetIO = TestParquetIO<::arrow::TimestampType>;
+
+TEST_F(TestInt96ParquetIO, ReadIntoTimestamp) {
+ // This test explicitly tests the conversion from an Impala-style timestamp
+ // to a nanoseconds-since-epoch one.
+
+ // 2nd January 1970, 11:35min 145738543ns
+ Int96 day;
+ day.value[2] = UINT32_C(2440589);
+ int64_t seconds = (11 * 60 + 35) * 60;
+ Int96SetNanoSeconds(
+ day, seconds * INT64_C(1000) * INT64_C(1000) * INT64_C(1000) + 145738543);
+ // Compute the corresponding nanosecond timestamp
+ struct tm datetime;
+ memset(&datetime, 0, sizeof(struct tm));
+ datetime.tm_year = 70;
+ datetime.tm_mon = 0;
+ datetime.tm_mday = 2;
+ datetime.tm_hour = 11;
+ datetime.tm_min = 35;
+ struct tm epoch;
+ memset(&epoch, 0, sizeof(struct tm));
+
+ epoch.tm_year = 70;
+ epoch.tm_mday = 1;
+ // Nanoseconds since the epoch
+ int64_t val = lrint(difftime(mktime(&datetime), mktime(&epoch))) * INT64_C(1000000000);
+ val += 145738543;
+
+ std::vector<std::shared_ptr<schema::Node>> fields(
+ {schema::PrimitiveNode::Make("int96", Repetition::REQUIRED, ParquetType::INT96)});
+ std::shared_ptr<schema::GroupNode> schema = std::static_pointer_cast<GroupNode>(
+ schema::GroupNode::Make("schema", Repetition::REQUIRED, fields));
+
+ // We cannot write this column with Arrow, so we have to use the plain parquet-cpp API
+ // to write an Int96 file.
+ this->ResetSink();
+ auto writer = ParquetFileWriter::Open(this->sink_, schema);
+ RowGroupWriter* rg_writer = writer->AppendRowGroup();
+ ColumnWriter* c_writer = rg_writer->NextColumn();
+ auto typed_writer = dynamic_cast<TypedColumnWriter<Int96Type>*>(c_writer);
+ ASSERT_NE(typed_writer, nullptr);
+ typed_writer->WriteBatch(1, nullptr, nullptr, &day);
+ c_writer->Close();
+ rg_writer->Close();
+ writer->Close();
+
+ ::arrow::TimestampBuilder builder(::arrow::timestamp(TimeUnit::NANO),
+ ::arrow::default_memory_pool());
+ ASSERT_OK(builder.Append(val));
+ std::shared_ptr<Array> values;
+ ASSERT_OK(builder.Finish(&values));
+ ASSERT_NO_FATAL_FAILURE(this->ReadAndCheckSingleColumnFile(*values));
+}
+
+using TestUInt32ParquetIO = TestParquetIO<::arrow::UInt32Type>;
+
+TEST_F(TestUInt32ParquetIO, Parquet_2_0_Compatibility) {
+ // This also tests max_definition_level = 1
+ std::shared_ptr<Array> values;
+
+ ASSERT_OK(NullableArray<::arrow::UInt32Type>(LARGE_SIZE, 100, kDefaultSeed, &values));
+ std::shared_ptr<Table> table = MakeSimpleTable(values, true);
+
+ // Parquet 2.4 roundtrip should yield an uint32_t column again
+ this->ResetSink();
+ std::shared_ptr<::parquet::WriterProperties> properties =
+ ::parquet::WriterProperties::Builder()
+ .version(ParquetVersion::PARQUET_2_4)
+ ->build();
+ ASSERT_OK_NO_THROW(
+ WriteTable(*table, default_memory_pool(), this->sink_, 512, properties));
+ ASSERT_NO_FATAL_FAILURE(this->ReadAndCheckSingleColumnTable(values));
+}
+
+TEST_F(TestUInt32ParquetIO, Parquet_1_0_Compatibility) {
+ // This also tests max_definition_level = 1
+ std::shared_ptr<Array> arr;
+ ASSERT_OK(NullableArray<::arrow::UInt32Type>(LARGE_SIZE, 100, kDefaultSeed, &arr));
+
+ std::shared_ptr<::arrow::UInt32Array> values =
+ std::dynamic_pointer_cast<::arrow::UInt32Array>(arr);
+
+ std::shared_ptr<Table> table = MakeSimpleTable(values, true);
+
+ // Parquet 1.0 returns an int64_t column as there is no way to tell a Parquet 1.0
+ // reader that a column is unsigned.
+ this->ResetSink();
+ std::shared_ptr<::parquet::WriterProperties> properties =
+ ::parquet::WriterProperties::Builder()
+ .version(ParquetVersion::PARQUET_1_0)
+ ->build();
+ ASSERT_OK_NO_THROW(
+ WriteTable(*table, ::arrow::default_memory_pool(), this->sink_, 512, properties));
+
+ std::shared_ptr<ResizableBuffer> int64_data = AllocateBuffer();
+ {
+ ASSERT_OK(int64_data->Resize(sizeof(int64_t) * values->length()));
+ auto int64_data_ptr = reinterpret_cast<int64_t*>(int64_data->mutable_data());
+ auto uint32_data_ptr = reinterpret_cast<const uint32_t*>(values->values()->data());
+ const auto cast_uint32_to_int64 = [](uint32_t value) {
+ return static_cast<int64_t>(value);
+ };
+ std::transform(uint32_data_ptr, uint32_data_ptr + values->length(), int64_data_ptr,
+ cast_uint32_to_int64);
+ }
+
+ std::vector<std::shared_ptr<Buffer>> buffers{values->null_bitmap(), int64_data};
+ auto arr_data = std::make_shared<ArrayData>(::arrow::int64(), values->length(), buffers,
+ values->null_count());
+ std::shared_ptr<Array> expected_values = MakeArray(arr_data);
+ ASSERT_NE(expected_values, NULLPTR);
+
+ const auto& expected = static_cast<const ::arrow::Int64Array&>(*expected_values);
+ ASSERT_GT(values->length(), 0);
+ ASSERT_EQ(values->length(), expected.length());
+
+ // TODO(phillipc): Is there a better way to compare these two arrays?
+ // AssertArraysEqual requires the same type, but we only care about values in this case
+ for (int i = 0; i < expected.length(); ++i) {
+ const bool value_is_valid = values->IsValid(i);
+ const bool expected_value_is_valid = expected.IsValid(i);
+
+ ASSERT_EQ(expected_value_is_valid, value_is_valid);
+
+ if (value_is_valid) {
+ uint32_t value = values->Value(i);
+ int64_t expected_value = expected.Value(i);
+ ASSERT_EQ(expected_value, static_cast<int64_t>(value));
+ }
+ }
+}
+
+using TestStringParquetIO = TestParquetIO<::arrow::StringType>;
+
+TEST_F(TestStringParquetIO, EmptyStringColumnRequiredWrite) {
+ std::shared_ptr<Array> values;
+ ::arrow::StringBuilder builder;
+ for (size_t i = 0; i < SMALL_SIZE; i++) {
+ ASSERT_OK(builder.Append(""));
+ }
+ ASSERT_OK(builder.Finish(&values));
+ std::shared_ptr<Table> table = MakeSimpleTable(values, false);
+ this->ResetSink();
+ ASSERT_OK_NO_THROW(WriteTable(*table, ::arrow::default_memory_pool(), this->sink_,
+ values->length(), default_writer_properties()));
+
+ std::shared_ptr<Table> out;
+ std::unique_ptr<FileReader> reader;
+ ASSERT_NO_FATAL_FAILURE(this->ReaderFromSink(&reader));
+ ASSERT_NO_FATAL_FAILURE(this->ReadTableFromFile(std::move(reader), &out));
+ ASSERT_EQ(1, out->num_columns());
+ ASSERT_EQ(table->num_rows(), out->num_rows());
+
+ std::shared_ptr<ChunkedArray> chunked_array = out->column(0);
+ ASSERT_EQ(1, chunked_array->num_chunks());
+
+ AssertArraysEqual(*values, *chunked_array->chunk(0));
+}
+
+using TestLargeBinaryParquetIO = TestParquetIO<::arrow::LargeBinaryType>;
+
+TEST_F(TestLargeBinaryParquetIO, Basics) {
+ const char* json = "[\"foo\", \"\", null, \"\xff\"]";
+
+ const auto large_type = ::arrow::large_binary();
+ const auto narrow_type = ::arrow::binary();
+ const auto large_array = ::arrow::ArrayFromJSON(large_type, json);
+ const auto narrow_array = ::arrow::ArrayFromJSON(narrow_type, json);
+
+ // When the original Arrow schema isn't stored, a LargeBinary array
+ // is decoded as Binary (since there is no specific Parquet logical
+ // type for it).
+ this->RoundTripSingleColumn(large_array, narrow_array,
+ default_arrow_writer_properties());
+
+ // When the original Arrow schema is stored, the LargeBinary array
+ // is read back as LargeBinary.
+ const auto arrow_properties =
+ ::parquet::ArrowWriterProperties::Builder().store_schema()->build();
+ this->RoundTripSingleColumn(large_array, large_array, arrow_properties);
+}
+
+using TestLargeStringParquetIO = TestParquetIO<::arrow::LargeStringType>;
+
+TEST_F(TestLargeStringParquetIO, Basics) {
+ const char* json = R"(["foo", "", null, "bar"])";
+
+ const auto large_type = ::arrow::large_utf8();
+ const auto narrow_type = ::arrow::utf8();
+ const auto large_array = ::arrow::ArrayFromJSON(large_type, json);
+ const auto narrow_array = ::arrow::ArrayFromJSON(narrow_type, json);
+
+ // When the original Arrow schema isn't stored, a LargeBinary array
+ // is decoded as Binary (since there is no specific Parquet logical
+ // type for it).
+ this->RoundTripSingleColumn(large_array, narrow_array,
+ default_arrow_writer_properties());
+
+ // When the original Arrow schema is stored, the LargeBinary array
+ // is read back as LargeBinary.
+ const auto arrow_properties =
+ ::parquet::ArrowWriterProperties::Builder().store_schema()->build();
+ this->RoundTripSingleColumn(large_array, large_array, arrow_properties);
+}
+
+using TestNullParquetIO = TestParquetIO<::arrow::NullType>;
+
+TEST_F(TestNullParquetIO, NullColumn) {
+ for (int32_t num_rows : {0, SMALL_SIZE}) {
+ std::shared_ptr<Array> values = std::make_shared<::arrow::NullArray>(num_rows);
+ std::shared_ptr<Table> table = MakeSimpleTable(values, true /* nullable */);
+ this->ResetSink();
+
+ const int64_t chunk_size = std::max(static_cast<int64_t>(1), table->num_rows());
+ ASSERT_OK_NO_THROW(WriteTable(*table, ::arrow::default_memory_pool(), this->sink_,
+ chunk_size, default_writer_properties()));
+
+ std::shared_ptr<Table> out;
+ std::unique_ptr<FileReader> reader;
+ ASSERT_NO_FATAL_FAILURE(this->ReaderFromSink(&reader));
+ ASSERT_NO_FATAL_FAILURE(this->ReadTableFromFile(std::move(reader), &out));
+ ASSERT_EQ(1, out->num_columns());
+ ASSERT_EQ(num_rows, out->num_rows());
+
+ std::shared_ptr<ChunkedArray> chunked_array = out->column(0);
+ ASSERT_EQ(1, chunked_array->num_chunks());
+ AssertArraysEqual(*values, *chunked_array->chunk(0));
+ }
+}
+
+TEST_F(TestNullParquetIO, NullListColumn) {
+ std::vector<int32_t> offsets1 = {0};
+ std::vector<int32_t> offsets2 = {0, 2, 2, 3, 115};
+ for (std::vector<int32_t> offsets : {offsets1, offsets2}) {
+ std::shared_ptr<Array> offsets_array, values_array, list_array;
+ ::arrow::ArrayFromVector<::arrow::Int32Type, int32_t>(offsets, &offsets_array);
+ values_array = std::make_shared<::arrow::NullArray>(offsets.back());
+ ASSERT_OK_AND_ASSIGN(list_array,
+ ::arrow::ListArray::FromArrays(*offsets_array, *values_array));
+
+ std::shared_ptr<Table> table = MakeSimpleTable(list_array, false /* nullable */);
+ this->ResetSink();
+
+ const int64_t chunk_size = std::max(static_cast<int64_t>(1), table->num_rows());
+ ASSERT_OK_NO_THROW(WriteTable(*table, ::arrow::default_memory_pool(), this->sink_,
+ chunk_size, default_writer_properties()));
+
+ std::shared_ptr<Table> out;
+ std::unique_ptr<FileReader> reader;
+ this->ReaderFromSink(&reader);
+ this->ReadTableFromFile(std::move(reader), &out);
+ ASSERT_EQ(1, out->num_columns());
+ ASSERT_EQ(offsets.size() - 1, out->num_rows());
+
+ std::shared_ptr<ChunkedArray> chunked_array = out->column(0);
+ ASSERT_EQ(1, chunked_array->num_chunks());
+ AssertArraysEqual(*list_array, *chunked_array->chunk(0));
+ }
+}
+
+TEST_F(TestNullParquetIO, NullDictionaryColumn) {
+ ASSERT_OK_AND_ASSIGN(auto null_bitmap, ::arrow::AllocateEmptyBitmap(SMALL_SIZE));
+
+ ASSERT_OK_AND_ASSIGN(auto indices, MakeArrayOfNull(::arrow::int8(), SMALL_SIZE));
+ std::shared_ptr<::arrow::DictionaryType> dict_type =
+ std::make_shared<::arrow::DictionaryType>(::arrow::int8(), ::arrow::null());
+
+ std::shared_ptr<Array> dict = std::make_shared<::arrow::NullArray>(0);
+ std::shared_ptr<Array> dict_values =
+ std::make_shared<::arrow::DictionaryArray>(dict_type, indices, dict);
+ std::shared_ptr<Table> table = MakeSimpleTable(dict_values, true);
+ this->ResetSink();
+ ASSERT_OK_NO_THROW(WriteTable(*table, ::arrow::default_memory_pool(), this->sink_,
+ dict_values->length(), default_writer_properties()));
+
+ std::shared_ptr<Table> out;
+ std::unique_ptr<FileReader> reader;
+ ASSERT_NO_FATAL_FAILURE(this->ReaderFromSink(&reader));
+ ASSERT_NO_FATAL_FAILURE(this->ReadTableFromFile(std::move(reader), &out));
+ ASSERT_EQ(1, out->num_columns());
+ ASSERT_EQ(100, out->num_rows());
+
+ std::shared_ptr<ChunkedArray> chunked_array = out->column(0);
+ ASSERT_EQ(1, chunked_array->num_chunks());
+
+ std::shared_ptr<Array> expected_values =
+ std::make_shared<::arrow::NullArray>(SMALL_SIZE);
+ AssertArraysEqual(*expected_values, *chunked_array->chunk(0));
+}
+
+template <typename T>
+using ParquetCDataType = typename ParquetDataType<T>::c_type;
+
+template <typename T>
+struct c_type_trait {
+ using ArrowCType = typename T::c_type;
+};
+
+template <>
+struct c_type_trait<::arrow::BooleanType> {
+ using ArrowCType = uint8_t;
+};
+
+template <typename TestType>
+class TestPrimitiveParquetIO : public TestParquetIO<TestType> {
+ public:
+ typedef typename c_type_trait<TestType>::ArrowCType T;
+
+ void MakeTestFile(std::vector<T>& values, int num_chunks,
+ std::unique_ptr<FileReader>* reader) {
+ TestType dummy;
+
+ std::shared_ptr<GroupNode> schema = MakeSimpleSchema(dummy, Repetition::REQUIRED);
+ std::unique_ptr<ParquetFileWriter> file_writer = this->MakeWriter(schema);
+ size_t chunk_size = values.size() / num_chunks;
+ // Convert to Parquet's expected physical type
+ std::vector<uint8_t> values_buffer(sizeof(ParquetCDataType<TestType>) *
+ values.size());
+ auto values_parquet =
+ reinterpret_cast<ParquetCDataType<TestType>*>(values_buffer.data());
+ std::copy(values.cbegin(), values.cend(), values_parquet);
+ for (int i = 0; i < num_chunks; i++) {
+ auto row_group_writer = file_writer->AppendRowGroup();
+ auto column_writer =
+ static_cast<ParquetWriter<TestType>*>(row_group_writer->NextColumn());
+ ParquetCDataType<TestType>* data = values_parquet + i * chunk_size;
+ column_writer->WriteBatch(chunk_size, nullptr, nullptr, data);
+ column_writer->Close();
+ row_group_writer->Close();
+ }
+ file_writer->Close();
+ this->ReaderFromSink(reader);
+ }
+
+ void CheckSingleColumnRequiredTableRead(int num_chunks) {
+ std::vector<T> values(SMALL_SIZE, test_traits<TestType>::value);
+ std::unique_ptr<FileReader> file_reader;
+ ASSERT_NO_FATAL_FAILURE(MakeTestFile(values, num_chunks, &file_reader));
+
+ std::shared_ptr<Table> out;
+ this->ReadTableFromFile(std::move(file_reader), &out);
+ ASSERT_EQ(1, out->num_columns());
+ ASSERT_EQ(SMALL_SIZE, out->num_rows());
+
+ std::shared_ptr<ChunkedArray> chunked_array = out->column(0);
+ ASSERT_EQ(1, chunked_array->num_chunks());
+ ExpectArrayT<TestType>(values.data(), chunked_array->chunk(0).get());
+ }
+
+ void CheckSingleColumnRequiredRead(int num_chunks) {
+ std::vector<T> values(SMALL_SIZE, test_traits<TestType>::value);
+ std::unique_ptr<FileReader> file_reader;
+ ASSERT_NO_FATAL_FAILURE(MakeTestFile(values, num_chunks, &file_reader));
+
+ std::shared_ptr<Array> out;
+ this->ReadSingleColumnFile(std::move(file_reader), &out);
+
+ ExpectArrayT<TestType>(values.data(), out.get());
+ }
+
+ void CheckSingleColumnStatisticsRequiredRead() {
+ std::vector<T> values(SMALL_SIZE, test_traits<TestType>::value);
+ std::unique_ptr<FileReader> file_reader;
+ ASSERT_NO_FATAL_FAILURE(MakeTestFile(values, 1, &file_reader));
+
+ std::shared_ptr<Scalar> min, max;
+ ReadSingleColumnFileStatistics(std::move(file_reader), &min, &max);
+
+ ASSERT_OK_AND_ASSIGN(
+ auto value, ::arrow::MakeScalar(::arrow::TypeTraits<TestType>::type_singleton(),
+ test_traits<TestType>::value));
+
+ ASSERT_TRUE(value->Equals(*min));
+ ASSERT_TRUE(value->Equals(*max));
+ }
+};
+
+typedef ::testing::Types<::arrow::BooleanType, ::arrow::UInt8Type, ::arrow::Int8Type,
+ ::arrow::UInt16Type, ::arrow::Int16Type, ::arrow::UInt32Type,
+ ::arrow::Int32Type, ::arrow::UInt64Type, ::arrow::Int64Type,
+ ::arrow::FloatType, ::arrow::DoubleType>
+ PrimitiveTestTypes;
+
+TYPED_TEST_SUITE(TestPrimitiveParquetIO, PrimitiveTestTypes);
+
+TYPED_TEST(TestPrimitiveParquetIO, SingleColumnRequiredRead) {
+ ASSERT_NO_FATAL_FAILURE(this->CheckSingleColumnRequiredRead(1));
+}
+
+TYPED_TEST(TestPrimitiveParquetIO, SingleColumnStatisticsRequiredRead) {
+ ASSERT_NO_FATAL_FAILURE(this->CheckSingleColumnStatisticsRequiredRead());
+}
+
+TYPED_TEST(TestPrimitiveParquetIO, SingleColumnRequiredTableRead) {
+ ASSERT_NO_FATAL_FAILURE(this->CheckSingleColumnRequiredTableRead(1));
+}
+
+TYPED_TEST(TestPrimitiveParquetIO, SingleColumnRequiredChunkedRead) {
+ ASSERT_NO_FATAL_FAILURE(this->CheckSingleColumnRequiredRead(4));
+}
+
+TYPED_TEST(TestPrimitiveParquetIO, SingleColumnRequiredChunkedTableRead) {
+ ASSERT_NO_FATAL_FAILURE(this->CheckSingleColumnRequiredTableRead(4));
+}
+
+void MakeDateTimeTypesTable(std::shared_ptr<Table>* out, bool expected = false) {
+ using ::arrow::ArrayFromVector;
+
+ std::vector<bool> is_valid = {true, true, true, false, true, true};
+
+ // These are only types that roundtrip without modification
+ auto f0 = field("f0", ::arrow::date32());
+ auto f1 = field("f1", ::arrow::timestamp(TimeUnit::MILLI));
+ auto f2 = field("f2", ::arrow::timestamp(TimeUnit::MICRO));
+ auto f3 = field("f3", ::arrow::timestamp(TimeUnit::NANO));
+ auto f3_x = field("f3", ::arrow::timestamp(TimeUnit::MICRO));
+ auto f4 = field("f4", ::arrow::time32(TimeUnit::MILLI));
+ auto f5 = field("f5", ::arrow::time64(TimeUnit::MICRO));
+ auto f6 = field("f6", ::arrow::time64(TimeUnit::NANO));
+
+ std::shared_ptr<::arrow::Schema> schema(
+ new ::arrow::Schema({f0, f1, f2, (expected ? f3_x : f3), f4, f5, f6}));
+
+ std::vector<int32_t> t32_values = {1489269000, 1489270000, 1489271000,
+ 1489272000, 1489272000, 1489273000};
+ std::vector<int64_t> t64_ns_values = {1489269000000, 1489270000000, 1489271000000,
+ 1489272000000, 1489272000000, 1489273000000};
+ std::vector<int64_t> t64_us_values = {1489269000, 1489270000, 1489271000,
+ 1489272000, 1489272000, 1489273000};
+ std::vector<int64_t> t64_ms_values = {1489269, 1489270, 1489271,
+ 1489272, 1489272, 1489273};
+
+ std::shared_ptr<Array> a0, a1, a2, a3, a3_x, a4, a5, a6;
+ ArrayFromVector<::arrow::Date32Type, int32_t>(f0->type(), is_valid, t32_values, &a0);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(f1->type(), is_valid, t64_ms_values,
+ &a1);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(f2->type(), is_valid, t64_us_values,
+ &a2);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(f3->type(), is_valid, t64_ns_values,
+ &a3);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(f3_x->type(), is_valid, t64_us_values,
+ &a3_x);
+ ArrayFromVector<::arrow::Time32Type, int32_t>(f4->type(), is_valid, t32_values, &a4);
+ ArrayFromVector<::arrow::Time64Type, int64_t>(f5->type(), is_valid, t64_us_values, &a5);
+ ArrayFromVector<::arrow::Time64Type, int64_t>(f6->type(), is_valid, t64_ns_values, &a6);
+
+ *out = Table::Make(schema, {a0, a1, a2, expected ? a3_x : a3, a4, a5, a6});
+}
+
+TEST(TestArrowReadWrite, DateTimeTypes) {
+ std::shared_ptr<Table> table, result;
+
+ MakeDateTimeTypesTable(&table);
+ ASSERT_NO_FATAL_FAILURE(
+ DoSimpleRoundtrip(table, false /* use_threads */, table->num_rows(), {}, &result));
+
+ MakeDateTimeTypesTable(&table, true); // build expected result
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertSchemaEqual(*table->schema(), *result->schema(),
+ /*check_metadata=*/false));
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*table, *result));
+}
+
+TEST(TestArrowReadWrite, UseDeprecatedInt96) {
+ using ::arrow::ArrayFromVector;
+ using ::arrow::field;
+ using ::arrow::schema;
+
+ std::vector<bool> is_valid = {true, true, true, false, true, true};
+
+ auto t_s = ::arrow::timestamp(TimeUnit::SECOND);
+ auto t_ms = ::arrow::timestamp(TimeUnit::MILLI);
+ auto t_us = ::arrow::timestamp(TimeUnit::MICRO);
+ auto t_ns = ::arrow::timestamp(TimeUnit::NANO);
+
+ std::vector<int64_t> s_values = {1489269, 1489270, 1489271, 1489272, 1489272, 1489273};
+ std::vector<int64_t> ms_values = {1489269000, 1489270000, 1489271000,
+ 1489272001, 1489272000, 1489273000};
+ std::vector<int64_t> us_values = {1489269000000, 1489270000000, 1489271000000,
+ 1489272000001, 1489272000000, 1489273000000};
+ std::vector<int64_t> ns_values = {1489269000000000LL, 1489270000000000LL,
+ 1489271000000000LL, 1489272000000001LL,
+ 1489272000000000LL, 1489273000000000LL};
+
+ std::shared_ptr<Array> a_s, a_ms, a_us, a_ns;
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_s, is_valid, s_values, &a_s);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_ms, is_valid, ms_values, &a_ms);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_us, is_valid, us_values, &a_us);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_ns, is_valid, ns_values, &a_ns);
+
+ // Each input is typed with a unique TimeUnit
+ auto input_schema = schema(
+ {field("f_s", t_s), field("f_ms", t_ms), field("f_us", t_us), field("f_ns", t_ns)});
+ auto input = Table::Make(input_schema, {a_s, a_ms, a_us, a_ns});
+
+ // When reading parquet files, all int96 schema fields are converted to
+ // timestamp nanoseconds
+ auto ex_schema = schema({field("f_s", t_ns), field("f_ms", t_ns), field("f_us", t_ns),
+ field("f_ns", t_ns)});
+ auto ex_result = Table::Make(ex_schema, {a_ns, a_ns, a_ns, a_ns});
+
+ std::shared_ptr<Table> result;
+ ASSERT_NO_FATAL_FAILURE(DoSimpleRoundtrip(
+ input, false /* use_threads */, input->num_rows(), {}, &result,
+ ArrowWriterProperties::Builder().enable_deprecated_int96_timestamps()->build()));
+
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertSchemaEqual(*ex_result->schema(),
+ *result->schema(),
+ /*check_metadata=*/false));
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*ex_result, *result));
+
+ // Ensure enable_deprecated_int96_timestamps as precedence over
+ // coerce_timestamps.
+ ASSERT_NO_FATAL_FAILURE(DoSimpleRoundtrip(input, false /* use_threads */,
+ input->num_rows(), {}, &result,
+ ArrowWriterProperties::Builder()
+ .enable_deprecated_int96_timestamps()
+ ->coerce_timestamps(TimeUnit::MILLI)
+ ->build()));
+
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertSchemaEqual(*ex_result->schema(),
+ *result->schema(),
+ /*check_metadata=*/false));
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*ex_result, *result));
+}
+
+TEST(TestArrowReadWrite, DownsampleDeprecatedInt96) {
+ using ::arrow::ArrayFromJSON;
+ using ::arrow::field;
+ using ::arrow::schema;
+
+ // Timestamp values at 2000-01-01 00:00:00,
+ // then with increment unit of 1ns, 1us, 1ms and 1s.
+ auto a_nano =
+ ArrayFromJSON(timestamp(TimeUnit::NANO),
+ "[946684800000000000, 946684800000000001, 946684800000001000, "
+ "946684800001000000, 946684801000000000]");
+ auto a_micro = ArrayFromJSON(timestamp(TimeUnit::MICRO),
+ "[946684800000000, 946684800000000, 946684800000001, "
+ "946684800001000, 946684801000000]");
+ auto a_milli = ArrayFromJSON(
+ timestamp(TimeUnit::MILLI),
+ "[946684800000, 946684800000, 946684800000, 946684800001, 946684801000]");
+ auto a_second =
+ ArrayFromJSON(timestamp(TimeUnit::SECOND),
+ "[946684800, 946684800, 946684800, 946684800, 946684801]");
+
+ ASSERT_NO_FATAL_FAILURE(DownsampleInt96RoundTrip(a_nano, a_nano, TimeUnit::NANO));
+ ASSERT_NO_FATAL_FAILURE(DownsampleInt96RoundTrip(a_nano, a_micro, TimeUnit::MICRO));
+ ASSERT_NO_FATAL_FAILURE(DownsampleInt96RoundTrip(a_nano, a_milli, TimeUnit::MILLI));
+ ASSERT_NO_FATAL_FAILURE(DownsampleInt96RoundTrip(a_nano, a_second, TimeUnit::SECOND));
+}
+
+TEST(TestArrowReadWrite, CoerceTimestamps) {
+ using ::arrow::ArrayFromVector;
+ using ::arrow::field;
+
+ std::vector<bool> is_valid = {true, true, true, false, true, true};
+
+ auto t_s = ::arrow::timestamp(TimeUnit::SECOND);
+ auto t_ms = ::arrow::timestamp(TimeUnit::MILLI);
+ auto t_us = ::arrow::timestamp(TimeUnit::MICRO);
+ auto t_ns = ::arrow::timestamp(TimeUnit::NANO);
+
+ std::vector<int64_t> s_values = {1489269, 1489270, 1489271, 1489272, 1489272, 1489273};
+ std::vector<int64_t> ms_values = {1489269000, 1489270000, 1489271000,
+ 1489272001, 1489272000, 1489273000};
+ std::vector<int64_t> us_values = {1489269000000, 1489270000000, 1489271000000,
+ 1489272000001, 1489272000000, 1489273000000};
+ std::vector<int64_t> ns_values = {1489269000000000LL, 1489270000000000LL,
+ 1489271000000000LL, 1489272000000001LL,
+ 1489272000000000LL, 1489273000000000LL};
+
+ std::shared_ptr<Array> a_s, a_ms, a_us, a_ns;
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_s, is_valid, s_values, &a_s);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_ms, is_valid, ms_values, &a_ms);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_us, is_valid, us_values, &a_us);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_ns, is_valid, ns_values, &a_ns);
+
+ // Input table, all data as is
+ auto s1 = ::arrow::schema(
+ {field("f_s", t_s), field("f_ms", t_ms), field("f_us", t_us), field("f_ns", t_ns)});
+ auto input = Table::Make(s1, {a_s, a_ms, a_us, a_ns});
+
+ // Result when coercing to milliseconds
+ auto s2 = ::arrow::schema({field("f_s", t_ms), field("f_ms", t_ms), field("f_us", t_ms),
+ field("f_ns", t_ms)});
+ auto ex_milli_result = Table::Make(s2, {a_ms, a_ms, a_ms, a_ms});
+ std::shared_ptr<Table> milli_result;
+ ASSERT_NO_FATAL_FAILURE(DoSimpleRoundtrip(
+ input, false /* use_threads */, input->num_rows(), {}, &milli_result,
+ ArrowWriterProperties::Builder().coerce_timestamps(TimeUnit::MILLI)->build()));
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertSchemaEqual(*ex_milli_result->schema(),
+ *milli_result->schema(),
+ /*check_metadata=*/false));
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*ex_milli_result, *milli_result));
+
+ // Result when coercing to microseconds
+ auto s3 = ::arrow::schema({field("f_s", t_us), field("f_ms", t_us), field("f_us", t_us),
+ field("f_ns", t_us)});
+ auto ex_micro_result = Table::Make(s3, {a_us, a_us, a_us, a_us});
+ std::shared_ptr<Table> micro_result;
+ ASSERT_NO_FATAL_FAILURE(DoSimpleRoundtrip(
+ input, false /* use_threads */, input->num_rows(), {}, &micro_result,
+ ArrowWriterProperties::Builder().coerce_timestamps(TimeUnit::MICRO)->build()));
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertSchemaEqual(*ex_micro_result->schema(),
+ *micro_result->schema(),
+ /*check_metadata=*/false));
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*ex_micro_result, *micro_result));
+}
+
+TEST(TestArrowReadWrite, CoerceTimestampsLosePrecision) {
+ using ::arrow::ArrayFromVector;
+ using ::arrow::field;
+
+ // PARQUET-1078, coerce Arrow timestamps to either TIMESTAMP_MILLIS or TIMESTAMP_MICROS
+ std::vector<bool> is_valid = {true, true, true, false, true, true};
+
+ auto t_s = ::arrow::timestamp(TimeUnit::SECOND);
+ auto t_ms = ::arrow::timestamp(TimeUnit::MILLI);
+ auto t_us = ::arrow::timestamp(TimeUnit::MICRO);
+ auto t_ns = ::arrow::timestamp(TimeUnit::NANO);
+
+ std::vector<int64_t> s_values = {1489269, 1489270, 1489271, 1489272, 1489272, 1489273};
+ std::vector<int64_t> ms_values = {1489269001, 1489270001, 1489271001,
+ 1489272001, 1489272001, 1489273001};
+ std::vector<int64_t> us_values = {1489269000001, 1489270000001, 1489271000001,
+ 1489272000001, 1489272000001, 1489273000001};
+ std::vector<int64_t> ns_values = {1489269000000001LL, 1489270000000001LL,
+ 1489271000000001LL, 1489272000000001LL,
+ 1489272000000001LL, 1489273000000001LL};
+
+ std::shared_ptr<Array> a_s, a_ms, a_us, a_ns;
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_s, is_valid, s_values, &a_s);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_ms, is_valid, ms_values, &a_ms);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_us, is_valid, us_values, &a_us);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_ns, is_valid, ns_values, &a_ns);
+
+ auto s1 = ::arrow::schema({field("f_s", t_s)});
+ auto s2 = ::arrow::schema({field("f_ms", t_ms)});
+ auto s3 = ::arrow::schema({field("f_us", t_us)});
+ auto s4 = ::arrow::schema({field("f_ns", t_ns)});
+
+ auto t1 = Table::Make(s1, {a_s});
+ auto t2 = Table::Make(s2, {a_ms});
+ auto t3 = Table::Make(s3, {a_us});
+ auto t4 = Table::Make(s4, {a_ns});
+
+ // OK to write to millis
+ auto coerce_millis =
+ (ArrowWriterProperties::Builder().coerce_timestamps(TimeUnit::MILLI)->build());
+ ASSERT_OK_NO_THROW(WriteTable(*t1, ::arrow::default_memory_pool(), CreateOutputStream(),
+ 10, default_writer_properties(), coerce_millis));
+
+ ASSERT_OK_NO_THROW(WriteTable(*t2, ::arrow::default_memory_pool(), CreateOutputStream(),
+ 10, default_writer_properties(), coerce_millis));
+
+ // Loss of precision
+ ASSERT_RAISES(Invalid,
+ WriteTable(*t3, ::arrow::default_memory_pool(), CreateOutputStream(), 10,
+ default_writer_properties(), coerce_millis));
+ ASSERT_RAISES(Invalid,
+ WriteTable(*t4, ::arrow::default_memory_pool(), CreateOutputStream(), 10,
+ default_writer_properties(), coerce_millis));
+
+ // OK to lose micros/nanos -> millis precision if we explicitly allow it
+ auto allow_truncation_to_millis = (ArrowWriterProperties::Builder()
+ .coerce_timestamps(TimeUnit::MILLI)
+ ->allow_truncated_timestamps()
+ ->build());
+ ASSERT_OK_NO_THROW(WriteTable(*t3, ::arrow::default_memory_pool(), CreateOutputStream(),
+ 10, default_writer_properties(),
+ allow_truncation_to_millis));
+ ASSERT_OK_NO_THROW(WriteTable(*t4, ::arrow::default_memory_pool(), CreateOutputStream(),
+ 10, default_writer_properties(),
+ allow_truncation_to_millis));
+
+ // OK to write to micros
+ auto coerce_micros =
+ (ArrowWriterProperties::Builder().coerce_timestamps(TimeUnit::MICRO)->build());
+ ASSERT_OK_NO_THROW(WriteTable(*t1, ::arrow::default_memory_pool(), CreateOutputStream(),
+ 10, default_writer_properties(), coerce_micros));
+ ASSERT_OK_NO_THROW(WriteTable(*t2, ::arrow::default_memory_pool(), CreateOutputStream(),
+ 10, default_writer_properties(), coerce_micros));
+ ASSERT_OK_NO_THROW(WriteTable(*t3, ::arrow::default_memory_pool(), CreateOutputStream(),
+ 10, default_writer_properties(), coerce_micros));
+
+ // Loss of precision
+ ASSERT_RAISES(Invalid,
+ WriteTable(*t4, ::arrow::default_memory_pool(), CreateOutputStream(), 10,
+ default_writer_properties(), coerce_micros));
+
+ // OK to lose nanos -> micros precision if we explicitly allow it
+ auto allow_truncation_to_micros = (ArrowWriterProperties::Builder()
+ .coerce_timestamps(TimeUnit::MICRO)
+ ->allow_truncated_timestamps()
+ ->build());
+ ASSERT_OK_NO_THROW(WriteTable(*t4, ::arrow::default_memory_pool(), CreateOutputStream(),
+ 10, default_writer_properties(),
+ allow_truncation_to_micros));
+}
+
+TEST(TestArrowReadWrite, ImplicitSecondToMillisecondTimestampCoercion) {
+ using ::arrow::ArrayFromVector;
+ using ::arrow::field;
+ using ::arrow::schema;
+
+ std::vector<bool> is_valid = {true, true, true, false, true, true};
+
+ auto t_s = ::arrow::timestamp(TimeUnit::SECOND);
+ auto t_ms = ::arrow::timestamp(TimeUnit::MILLI);
+
+ std::vector<int64_t> s_values = {1489269, 1489270, 1489271, 1489272, 1489272, 1489273};
+ std::vector<int64_t> ms_values = {1489269000, 1489270000, 1489271000,
+ 1489272000, 1489272000, 1489273000};
+
+ std::shared_ptr<Array> a_s, a_ms;
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_s, is_valid, s_values, &a_s);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_ms, is_valid, ms_values, &a_ms);
+
+ auto si = schema({field("timestamp", t_s)});
+ auto sx = schema({field("timestamp", t_ms)});
+
+ auto ti = Table::Make(si, {a_s}); // input
+ auto tx = Table::Make(sx, {a_ms}); // expected output
+ std::shared_ptr<Table> to; // actual output
+
+ // default properties (without explicit coercion instructions) used ...
+ ASSERT_NO_FATAL_FAILURE(
+ DoSimpleRoundtrip(ti, false /* use_threads */, ti->num_rows(), {}, &to));
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertSchemaEqual(*tx->schema(), *to->schema(),
+ /*check_metadata=*/false));
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*tx, *to));
+}
+
+TEST(TestArrowReadWrite, ParquetVersionTimestampDifferences) {
+ using ::arrow::ArrayFromVector;
+ using ::arrow::field;
+ using ::arrow::schema;
+
+ auto t_s = ::arrow::timestamp(TimeUnit::SECOND);
+ auto t_ms = ::arrow::timestamp(TimeUnit::MILLI);
+ auto t_us = ::arrow::timestamp(TimeUnit::MICRO);
+ auto t_ns = ::arrow::timestamp(TimeUnit::NANO);
+
+ const int N = 24;
+ int64_t instant = INT64_C(1262304000); // 2010-01-01T00:00:00 seconds offset
+ std::vector<int64_t> d_s, d_ms, d_us, d_ns;
+ for (int i = 0; i < N; ++i) {
+ d_s.push_back(instant);
+ d_ms.push_back(instant * INT64_C(1000));
+ d_us.push_back(instant * INT64_C(1000000));
+ d_ns.push_back(instant * INT64_C(1000000000));
+ instant += 3600;
+ }
+
+ std::shared_ptr<Array> a_s, a_ms, a_us, a_ns;
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_s, d_s, &a_s);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_ms, d_ms, &a_ms);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_us, d_us, &a_us);
+ ArrayFromVector<::arrow::TimestampType, int64_t>(t_ns, d_ns, &a_ns);
+
+ auto input_schema = schema({field("ts:s", t_s), field("ts:ms", t_ms),
+ field("ts:us", t_us), field("ts:ns", t_ns)});
+ auto input_table = Table::Make(input_schema, {a_s, a_ms, a_us, a_ns});
+
+ auto parquet_version_1_properties = ::parquet::default_writer_properties();
+ ARROW_SUPPRESS_DEPRECATION_WARNING
+ auto parquet_version_2_0_properties = ::parquet::WriterProperties::Builder()
+ .version(ParquetVersion::PARQUET_2_0)
+ ->build();
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
+ auto parquet_version_2_4_properties = ::parquet::WriterProperties::Builder()
+ .version(ParquetVersion::PARQUET_2_4)
+ ->build();
+ auto parquet_version_2_6_properties = ::parquet::WriterProperties::Builder()
+ .version(ParquetVersion::PARQUET_2_6)
+ ->build();
+ const std::vector<std::shared_ptr<WriterProperties>> all_properties = {
+ parquet_version_1_properties, parquet_version_2_0_properties,
+ parquet_version_2_4_properties, parquet_version_2_6_properties};
+
+ {
+ // Using Parquet version 1.0 and 2.4 defaults, seconds should be coerced to
+ // milliseconds and nanoseconds should be coerced to microseconds
+ auto expected_schema = schema({field("ts:s", t_ms), field("ts:ms", t_ms),
+ field("ts:us", t_us), field("ts:ns", t_us)});
+ auto expected_table = Table::Make(expected_schema, {a_ms, a_ms, a_us, a_us});
+ ASSERT_NO_FATAL_FAILURE(CheckConfiguredRoundtrip(input_table, expected_table,
+ parquet_version_1_properties));
+ ASSERT_NO_FATAL_FAILURE(CheckConfiguredRoundtrip(input_table, expected_table,
+ parquet_version_2_4_properties));
+ }
+ {
+ // Using Parquet version 2.0 and 2.6 defaults, seconds should be coerced to
+ // milliseconds and nanoseconds should be retained
+ auto expected_schema = schema({field("ts:s", t_ms), field("ts:ms", t_ms),
+ field("ts:us", t_us), field("ts:ns", t_ns)});
+ auto expected_table = Table::Make(expected_schema, {a_ms, a_ms, a_us, a_ns});
+ ASSERT_NO_FATAL_FAILURE(CheckConfiguredRoundtrip(input_table, expected_table,
+ parquet_version_2_0_properties));
+ ASSERT_NO_FATAL_FAILURE(CheckConfiguredRoundtrip(input_table, expected_table,
+ parquet_version_2_6_properties));
+ }
+
+ auto arrow_coerce_to_seconds_properties =
+ ArrowWriterProperties::Builder().coerce_timestamps(TimeUnit::SECOND)->build();
+ auto arrow_coerce_to_millis_properties =
+ ArrowWriterProperties::Builder().coerce_timestamps(TimeUnit::MILLI)->build();
+ auto arrow_coerce_to_micros_properties =
+ ArrowWriterProperties::Builder().coerce_timestamps(TimeUnit::MICRO)->build();
+ auto arrow_coerce_to_nanos_properties =
+ ArrowWriterProperties::Builder().coerce_timestamps(TimeUnit::NANO)->build();
+
+ for (const auto& properties : all_properties) {
+ // Using all Parquet versions, coercing to milliseconds or microseconds is allowed
+ ARROW_SCOPED_TRACE("format = ", ParquetVersionToString(properties->version()));
+ auto expected_schema = schema({field("ts:s", t_ms), field("ts:ms", t_ms),
+ field("ts:us", t_ms), field("ts:ns", t_ms)});
+ auto expected_table = Table::Make(expected_schema, {a_ms, a_ms, a_ms, a_ms});
+ ASSERT_NO_FATAL_FAILURE(CheckConfiguredRoundtrip(
+ input_table, expected_table, properties, arrow_coerce_to_millis_properties));
+
+ expected_schema = schema({field("ts:s", t_us), field("ts:ms", t_us),
+ field("ts:us", t_us), field("ts:ns", t_us)});
+ expected_table = Table::Make(expected_schema, {a_us, a_us, a_us, a_us});
+ ASSERT_NO_FATAL_FAILURE(CheckConfiguredRoundtrip(
+ input_table, expected_table, properties, arrow_coerce_to_micros_properties));
+
+ // Neither Parquet version allows coercing to seconds
+ std::shared_ptr<Table> actual_table;
+ ASSERT_RAISES(NotImplemented,
+ WriteTable(*input_table, ::arrow::default_memory_pool(),
+ CreateOutputStream(), input_table->num_rows(), properties,
+ arrow_coerce_to_seconds_properties));
+ }
+ // Using Parquet versions 1.0 and 2.4, coercing to (int64) nanoseconds is not allowed
+ for (const auto& properties :
+ {parquet_version_1_properties, parquet_version_2_4_properties}) {
+ ARROW_SCOPED_TRACE("format = ", ParquetVersionToString(properties->version()));
+ std::shared_ptr<Table> actual_table;
+ ASSERT_RAISES(NotImplemented,
+ WriteTable(*input_table, ::arrow::default_memory_pool(),
+ CreateOutputStream(), input_table->num_rows(), properties,
+ arrow_coerce_to_nanos_properties));
+ }
+ // Using Parquet versions "2.0" and 2.6, coercing to (int64) nanoseconds is allowed
+ for (const auto& properties :
+ {parquet_version_2_0_properties, parquet_version_2_6_properties}) {
+ ARROW_SCOPED_TRACE("format = ", ParquetVersionToString(properties->version()));
+ auto expected_schema = schema({field("ts:s", t_ns), field("ts:ms", t_ns),
+ field("ts:us", t_ns), field("ts:ns", t_ns)});
+ auto expected_table = Table::Make(expected_schema, {a_ns, a_ns, a_ns, a_ns});
+ ASSERT_NO_FATAL_FAILURE(CheckConfiguredRoundtrip(
+ input_table, expected_table, properties, arrow_coerce_to_nanos_properties));
+ }
+
+ // Using all Parquet versions, coercing to nanoseconds is allowed if Int96
+ // storage is used
+ auto arrow_enable_int96_properties =
+ ArrowWriterProperties::Builder().enable_deprecated_int96_timestamps()->build();
+ for (const auto& properties : all_properties) {
+ ARROW_SCOPED_TRACE("format = ", ParquetVersionToString(properties->version()));
+ auto expected_schema = schema({field("ts:s", t_ns), field("ts:ms", t_ns),
+ field("ts:us", t_ns), field("ts:ns", t_ns)});
+ auto expected_table = Table::Make(expected_schema, {a_ns, a_ns, a_ns, a_ns});
+ ASSERT_NO_FATAL_FAILURE(CheckConfiguredRoundtrip(
+ input_table, expected_table, properties, arrow_enable_int96_properties));
+ }
+}
+
+TEST(TestArrowReadWrite, ConvertedDateTimeTypes) {
+ using ::arrow::ArrayFromVector;
+
+ std::vector<bool> is_valid = {true, true, true, false, true, true};
+
+ auto f0 = field("f0", ::arrow::date64());
+ auto f1 = field("f1", ::arrow::time32(TimeUnit::SECOND));
+ auto f2 = field("f2", ::arrow::date64());
+ auto f3 = field("f3", ::arrow::time32(TimeUnit::SECOND));
+
+ auto schema = ::arrow::schema({f0, f1, f2, f3});
+
+ std::vector<int64_t> a0_values = {1489190400000, 1489276800000, 1489363200000,
+ 1489449600000, 1489536000000, 1489622400000};
+ std::vector<int32_t> a1_values = {0, 1, 2, 3, 4, 5};
+
+ std::shared_ptr<Array> a0, a1, a0_nonnull, a1_nonnull, x0, x1, x0_nonnull, x1_nonnull;
+
+ ArrayFromVector<::arrow::Date64Type, int64_t>(f0->type(), is_valid, a0_values, &a0);
+ ArrayFromVector<::arrow::Date64Type, int64_t>(f0->type(), a0_values, &a0_nonnull);
+
+ ArrayFromVector<::arrow::Time32Type, int32_t>(f1->type(), is_valid, a1_values, &a1);
+ ArrayFromVector<::arrow::Time32Type, int32_t>(f1->type(), a1_values, &a1_nonnull);
+
+ auto table = Table::Make(schema, {a0, a1, a0_nonnull, a1_nonnull});
+
+ // Expected schema and values
+ auto e0 = field("f0", ::arrow::date32());
+ auto e1 = field("f1", ::arrow::time32(TimeUnit::MILLI));
+ auto e2 = field("f2", ::arrow::date32());
+ auto e3 = field("f3", ::arrow::time32(TimeUnit::MILLI));
+ auto ex_schema = ::arrow::schema({e0, e1, e2, e3});
+
+ std::vector<int32_t> x0_values = {17236, 17237, 17238, 17239, 17240, 17241};
+ std::vector<int32_t> x1_values = {0, 1000, 2000, 3000, 4000, 5000};
+ ArrayFromVector<::arrow::Date32Type, int32_t>(e0->type(), is_valid, x0_values, &x0);
+ ArrayFromVector<::arrow::Date32Type, int32_t>(e0->type(), x0_values, &x0_nonnull);
+
+ ArrayFromVector<::arrow::Time32Type, int32_t>(e1->type(), is_valid, x1_values, &x1);
+ ArrayFromVector<::arrow::Time32Type, int32_t>(e1->type(), x1_values, &x1_nonnull);
+
+ auto ex_table = Table::Make(ex_schema, {x0, x1, x0_nonnull, x1_nonnull});
+
+ std::shared_ptr<Table> result;
+ ASSERT_NO_FATAL_FAILURE(
+ DoSimpleRoundtrip(table, false /* use_threads */, table->num_rows(), {}, &result));
+
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertSchemaEqual(*ex_table->schema(),
+ *result->schema(),
+ /*check_metadata=*/false));
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*ex_table, *result));
+}
+
+void MakeDoubleTable(int num_columns, int num_rows, int nchunks,
+ std::shared_ptr<Table>* out) {
+ std::vector<std::shared_ptr<::arrow::ChunkedArray>> columns(num_columns);
+ std::vector<std::shared_ptr<::arrow::Field>> fields(num_columns);
+
+ for (int i = 0; i < num_columns; ++i) {
+ std::vector<std::shared_ptr<Array>> arrays;
+ std::shared_ptr<Array> values;
+ ASSERT_OK(NullableArray<::arrow::DoubleType>(num_rows, num_rows / 10,
+ static_cast<uint32_t>(i), &values));
+ std::stringstream ss;
+ ss << "col" << i;
+
+ for (int j = 0; j < nchunks; ++j) {
+ arrays.push_back(values);
+ }
+ columns[i] = std::make_shared<ChunkedArray>(arrays);
+ fields[i] = ::arrow::field(ss.str(), values->type());
+ }
+ auto schema = std::make_shared<::arrow::Schema>(fields);
+ *out = Table::Make(schema, columns, num_rows);
+}
+
+void MakeSimpleListArray(int num_rows, int max_value_length, const std::string& item_name,
+ std::shared_ptr<DataType>* out_type,
+ std::shared_ptr<Array>* out_array) {
+ std::vector<int32_t> length_draws;
+ randint(num_rows, 0, max_value_length, &length_draws);
+
+ std::vector<int32_t> offset_values;
+
+ // Make sure some of them are length 0
+ int32_t total_elements = 0;
+ for (size_t i = 0; i < length_draws.size(); ++i) {
+ if (length_draws[i] < max_value_length / 10) {
+ length_draws[i] = 0;
+ }
+ offset_values.push_back(total_elements);
+ total_elements += length_draws[i];
+ }
+ offset_values.push_back(total_elements);
+
+ std::vector<int8_t> value_draws;
+ randint(total_elements, 0, 100, &value_draws);
+
+ std::vector<bool> is_valid;
+ random_is_valid(total_elements, 0.1, &is_valid);
+
+ std::shared_ptr<Array> values, offsets;
+ ::arrow::ArrayFromVector<::arrow::Int8Type, int8_t>(::arrow::int8(), is_valid,
+ value_draws, &values);
+ ::arrow::ArrayFromVector<::arrow::Int32Type, int32_t>(offset_values, &offsets);
+
+ *out_type = ::arrow::list(::arrow::field(item_name, ::arrow::int8()));
+ *out_array = std::make_shared<ListArray>(*out_type, offsets->length() - 1,
+ offsets->data()->buffers[1], values);
+}
+
+TEST(TestArrowReadWrite, MultithreadedRead) {
+ const int num_columns = 20;
+ const int num_rows = 1000;
+ const bool use_threads = true;
+
+ std::shared_ptr<Table> table;
+ ASSERT_NO_FATAL_FAILURE(MakeDoubleTable(num_columns, num_rows, 1, &table));
+
+ std::shared_ptr<Table> result;
+ ASSERT_NO_FATAL_FAILURE(
+ DoSimpleRoundtrip(table, use_threads, table->num_rows(), {}, &result));
+
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*table, *result));
+}
+
+TEST(TestArrowReadWrite, ReadSingleRowGroup) {
+ const int num_columns = 10;
+ const int num_rows = 100;
+
+ std::shared_ptr<Table> table;
+ ASSERT_NO_FATAL_FAILURE(MakeDoubleTable(num_columns, num_rows, 1, &table));
+
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_NO_FATAL_FAILURE(WriteTableToBuffer(table, num_rows / 2,
+ default_arrow_writer_properties(), &buffer));
+
+ std::unique_ptr<FileReader> reader;
+ ASSERT_OK_NO_THROW(OpenFile(std::make_shared<BufferReader>(buffer),
+ ::arrow::default_memory_pool(), &reader));
+
+ ASSERT_EQ(2, reader->num_row_groups());
+
+ std::shared_ptr<Table> r1, r2, r3, r4;
+ // Read everything
+ ASSERT_OK_NO_THROW(reader->ReadRowGroup(0, &r1));
+ ASSERT_OK_NO_THROW(reader->RowGroup(1)->ReadTable(&r2));
+ ASSERT_OK_NO_THROW(reader->ReadRowGroups({0, 1}, &r3));
+ ASSERT_OK_NO_THROW(reader->ReadRowGroups({1}, &r4));
+
+ std::shared_ptr<Table> concatenated;
+
+ ASSERT_OK_AND_ASSIGN(concatenated, ::arrow::ConcatenateTables({r1, r2}));
+ AssertTablesEqual(*concatenated, *table, /*same_chunk_layout=*/false);
+
+ AssertTablesEqual(*table, *r3, /*same_chunk_layout=*/false);
+ ASSERT_TRUE(r2->Equals(*r4));
+ ASSERT_OK_AND_ASSIGN(concatenated, ::arrow::ConcatenateTables({r1, r4}));
+
+ AssertTablesEqual(*table, *concatenated, /*same_chunk_layout=*/false);
+}
+
+// Exercise reading table manually with nested RowGroup and Column loops, i.e.
+//
+// for (int i = 0; i < n_row_groups; i++)
+// for (int j = 0; j < n_cols; j++)
+// reader->RowGroup(i)->Column(j)->Read(&chunked_array);
+::arrow::Result<std::shared_ptr<Table>> ReadTableManually(FileReader* reader) {
+ std::vector<std::shared_ptr<Table>> tables;
+
+ std::shared_ptr<::arrow::Schema> schema;
+ RETURN_NOT_OK(reader->GetSchema(&schema));
+
+ int n_row_groups = reader->num_row_groups();
+ int n_columns = schema->num_fields();
+ for (int i = 0; i < n_row_groups; i++) {
+ std::vector<std::shared_ptr<ChunkedArray>> columns{static_cast<size_t>(n_columns)};
+
+ for (int j = 0; j < n_columns; j++) {
+ RETURN_NOT_OK(reader->RowGroup(i)->Column(j)->Read(&columns[j]));
+ }
+
+ tables.push_back(Table::Make(schema, columns));
+ }
+
+ return ConcatenateTables(tables);
+}
+
+TEST(TestArrowReadWrite, ReadTableManually) {
+ const int num_columns = 1;
+ const int num_rows = 128;
+
+ std::shared_ptr<Table> expected;
+ ASSERT_NO_FATAL_FAILURE(MakeDoubleTable(num_columns, num_rows, 1, &expected));
+
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_NO_FATAL_FAILURE(WriteTableToBuffer(expected, num_rows / 2,
+ default_arrow_writer_properties(), &buffer));
+
+ std::unique_ptr<FileReader> reader;
+ ASSERT_OK_NO_THROW(OpenFile(std::make_shared<BufferReader>(buffer),
+ ::arrow::default_memory_pool(), &reader));
+
+ ASSERT_EQ(2, reader->num_row_groups());
+
+ ASSERT_OK_AND_ASSIGN(auto actual, ReadTableManually(reader.get()));
+
+ AssertTablesEqual(*actual, *expected, /*same_chunk_layout=*/false);
+}
+
+void TestGetRecordBatchReader(
+ ArrowReaderProperties properties = default_arrow_reader_properties()) {
+ const int num_columns = 20;
+ const int num_rows = 1000;
+ const int batch_size = 100;
+
+ std::shared_ptr<Table> table;
+ ASSERT_NO_FATAL_FAILURE(MakeDoubleTable(num_columns, num_rows, 1, &table));
+
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_NO_FATAL_FAILURE(WriteTableToBuffer(table, num_rows / 2,
+ default_arrow_writer_properties(), &buffer));
+
+ properties.set_batch_size(batch_size);
+
+ std::unique_ptr<FileReader> reader;
+ FileReaderBuilder builder;
+ ASSERT_OK(builder.Open(std::make_shared<BufferReader>(buffer)));
+ ASSERT_OK(builder.properties(properties)->Build(&reader));
+
+ // Read the whole file, one batch at a time.
+ std::shared_ptr<::arrow::RecordBatchReader> rb_reader;
+ ASSERT_OK_NO_THROW(reader->GetRecordBatchReader({0, 1}, &rb_reader));
+ std::shared_ptr<::arrow::RecordBatch> actual_batch, expected_batch;
+ ::arrow::TableBatchReader table_reader(*table);
+ table_reader.set_chunksize(batch_size);
+
+ for (int i = 0; i < 10; ++i) {
+ ASSERT_OK(rb_reader->ReadNext(&actual_batch));
+ ASSERT_OK(table_reader.ReadNext(&expected_batch));
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertBatchesEqual(*expected_batch, *actual_batch));
+ }
+
+ ASSERT_OK(rb_reader->ReadNext(&actual_batch));
+ ASSERT_EQ(nullptr, actual_batch);
+
+ // ARROW-6005: Read just the second row group
+ ASSERT_OK_NO_THROW(reader->GetRecordBatchReader({1}, &rb_reader));
+ std::shared_ptr<Table> second_rowgroup = table->Slice(num_rows / 2);
+ ::arrow::TableBatchReader second_table_reader(*second_rowgroup);
+ second_table_reader.set_chunksize(batch_size);
+
+ for (int i = 0; i < 5; ++i) {
+ ASSERT_OK(rb_reader->ReadNext(&actual_batch));
+ ASSERT_OK(second_table_reader.ReadNext(&expected_batch));
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertBatchesEqual(*expected_batch, *actual_batch));
+ }
+
+ ASSERT_OK(rb_reader->ReadNext(&actual_batch));
+ ASSERT_EQ(nullptr, actual_batch);
+}
+
+TEST(TestArrowReadWrite, GetRecordBatchReader) { TestGetRecordBatchReader(); }
+
+// Same as the test above, but using coalesced reads.
+TEST(TestArrowReadWrite, CoalescedReads) {
+ ArrowReaderProperties arrow_properties = default_arrow_reader_properties();
+ arrow_properties.set_pre_buffer(true);
+ TestGetRecordBatchReader(arrow_properties);
+}
+
+// Use coalesced reads, and explicitly wait for I/O to complete.
+TEST(TestArrowReadWrite, WaitCoalescedReads) {
+ ArrowReaderProperties properties = default_arrow_reader_properties();
+ const int num_rows = 10;
+ const int num_columns = 5;
+
+ std::shared_ptr<Table> table;
+ ASSERT_NO_FATAL_FAILURE(MakeDoubleTable(num_columns, num_rows, 1, &table));
+
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_NO_FATAL_FAILURE(
+ WriteTableToBuffer(table, num_rows, default_arrow_writer_properties(), &buffer));
+
+ std::unique_ptr<FileReader> reader;
+ FileReaderBuilder builder;
+ ASSERT_OK(builder.Open(std::make_shared<BufferReader>(buffer)));
+ ASSERT_OK(builder.properties(properties)->Build(&reader));
+ // Pre-buffer data and wait for I/O to complete.
+ reader->parquet_reader()->PreBuffer({0}, {0, 1, 2, 3, 4}, ::arrow::io::IOContext(),
+ ::arrow::io::CacheOptions::Defaults());
+ ASSERT_OK(reader->parquet_reader()->WhenBuffered({0}, {0, 1, 2, 3, 4}).status());
+
+ std::shared_ptr<::arrow::RecordBatchReader> rb_reader;
+ ASSERT_OK_NO_THROW(reader->GetRecordBatchReader({0}, {0, 1, 2, 3, 4}, &rb_reader));
+
+ std::shared_ptr<::arrow::RecordBatch> actual_batch;
+ ASSERT_OK(rb_reader->ReadNext(&actual_batch));
+
+ ASSERT_NE(actual_batch, nullptr);
+ ASSERT_EQ(actual_batch->num_columns(), num_columns);
+ ASSERT_EQ(actual_batch->num_rows(), num_rows);
+}
+
+TEST(TestArrowReadWrite, GetRecordBatchReaderNoColumns) {
+ ArrowReaderProperties properties = default_arrow_reader_properties();
+ const int num_rows = 10;
+ const int num_columns = 20;
+
+ std::shared_ptr<Table> table;
+ ASSERT_NO_FATAL_FAILURE(MakeDoubleTable(num_columns, num_rows, 1, &table));
+
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_NO_FATAL_FAILURE(
+ WriteTableToBuffer(table, num_rows, default_arrow_writer_properties(), &buffer));
+
+ std::unique_ptr<FileReader> reader;
+ FileReaderBuilder builder;
+ ASSERT_OK(builder.Open(std::make_shared<BufferReader>(buffer)));
+ ASSERT_OK(builder.properties(properties)->Build(&reader));
+
+ std::shared_ptr<::arrow::RecordBatchReader> rb_reader;
+ ASSERT_OK_NO_THROW(reader->GetRecordBatchReader({0}, {}, &rb_reader));
+
+ std::shared_ptr<::arrow::RecordBatch> actual_batch;
+ ASSERT_OK(rb_reader->ReadNext(&actual_batch));
+
+ ASSERT_NE(actual_batch, nullptr);
+ ASSERT_EQ(actual_batch->num_columns(), 0);
+ ASSERT_EQ(actual_batch->num_rows(), num_rows);
+}
+
+TEST(TestArrowReadWrite, GetRecordBatchGenerator) {
+ ArrowReaderProperties properties = default_arrow_reader_properties();
+ const int num_rows = 1024;
+ const int row_group_size = 512;
+ const int num_columns = 2;
+
+ std::shared_ptr<Table> table;
+ ASSERT_NO_FATAL_FAILURE(MakeDoubleTable(num_columns, num_rows, 1, &table));
+
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_NO_FATAL_FAILURE(WriteTableToBuffer(table, row_group_size,
+ default_arrow_writer_properties(), &buffer));
+
+ std::shared_ptr<FileReader> reader;
+ {
+ std::unique_ptr<FileReader> unique_reader;
+ FileReaderBuilder builder;
+ ASSERT_OK(builder.Open(std::make_shared<BufferReader>(buffer)));
+ ASSERT_OK(builder.properties(properties)->Build(&unique_reader));
+ reader = std::move(unique_reader);
+ }
+
+ auto check_batches = [](const std::shared_ptr<::arrow::RecordBatch>& batch,
+ int num_columns, int num_rows) {
+ ASSERT_NE(batch, nullptr);
+ ASSERT_EQ(batch->num_columns(), num_columns);
+ ASSERT_EQ(batch->num_rows(), num_rows);
+ };
+ {
+ ASSERT_OK_AND_ASSIGN(auto batch_generator,
+ reader->GetRecordBatchGenerator(reader, {0, 1}, {0, 1}));
+ auto fut1 = batch_generator();
+ auto fut2 = batch_generator();
+ auto fut3 = batch_generator();
+ ASSERT_OK_AND_ASSIGN(auto batch1, fut1.result());
+ ASSERT_OK_AND_ASSIGN(auto batch2, fut2.result());
+ ASSERT_OK_AND_ASSIGN(auto batch3, fut3.result());
+ ASSERT_EQ(batch3, nullptr);
+ check_batches(batch1, num_columns, row_group_size);
+ check_batches(batch2, num_columns, row_group_size);
+ ASSERT_OK_AND_ASSIGN(auto actual, ::arrow::Table::FromRecordBatches(
+ batch1->schema(), {batch1, batch2}));
+ AssertTablesEqual(*table, *actual, /*same_chunk_layout=*/false);
+ }
+ {
+ // No columns case
+ ASSERT_OK_AND_ASSIGN(auto batch_generator,
+ reader->GetRecordBatchGenerator(reader, {0, 1}, {}));
+ auto fut1 = batch_generator();
+ auto fut2 = batch_generator();
+ auto fut3 = batch_generator();
+ ASSERT_OK_AND_ASSIGN(auto batch1, fut1.result());
+ ASSERT_OK_AND_ASSIGN(auto batch2, fut2.result());
+ ASSERT_OK_AND_ASSIGN(auto batch3, fut3.result());
+ ASSERT_EQ(batch3, nullptr);
+ check_batches(batch1, 0, row_group_size);
+ check_batches(batch2, 0, row_group_size);
+ }
+}
+
+TEST(TestArrowReadWrite, ScanContents) {
+ const int num_columns = 20;
+ const int num_rows = 1000;
+
+ std::shared_ptr<Table> table;
+ ASSERT_NO_FATAL_FAILURE(MakeDoubleTable(num_columns, num_rows, 1, &table));
+
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_NO_FATAL_FAILURE(WriteTableToBuffer(table, num_rows / 2,
+ default_arrow_writer_properties(), &buffer));
+
+ std::unique_ptr<FileReader> reader;
+ ASSERT_OK_NO_THROW(OpenFile(std::make_shared<BufferReader>(buffer),
+ ::arrow::default_memory_pool(), &reader));
+
+ int64_t num_rows_returned = 0;
+ ASSERT_OK_NO_THROW(reader->ScanContents({}, 256, &num_rows_returned));
+ ASSERT_EQ(num_rows, num_rows_returned);
+
+ ASSERT_OK_NO_THROW(reader->ScanContents({0, 1, 2}, 256, &num_rows_returned));
+ ASSERT_EQ(num_rows, num_rows_returned);
+}
+
+TEST(TestArrowReadWrite, ReadColumnSubset) {
+ const int num_columns = 20;
+ const int num_rows = 1000;
+ const bool use_threads = true;
+
+ std::shared_ptr<Table> table;
+ ASSERT_NO_FATAL_FAILURE(MakeDoubleTable(num_columns, num_rows, 1, &table));
+
+ std::shared_ptr<Table> result;
+ std::vector<int> column_subset = {0, 4, 8, 10};
+ ASSERT_NO_FATAL_FAILURE(
+ DoSimpleRoundtrip(table, use_threads, table->num_rows(), column_subset, &result));
+
+ std::vector<std::shared_ptr<::arrow::ChunkedArray>> ex_columns;
+ std::vector<std::shared_ptr<::arrow::Field>> ex_fields;
+ for (int i : column_subset) {
+ ex_columns.push_back(table->column(i));
+ ex_fields.push_back(table->field(i));
+ }
+
+ auto ex_schema = ::arrow::schema(ex_fields);
+ auto expected = Table::Make(ex_schema, ex_columns);
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*expected, *result));
+}
+
+TEST(TestArrowReadWrite, ReadCoalescedColumnSubset) {
+ const int num_columns = 20;
+ const int num_rows = 1000;
+
+ std::shared_ptr<Table> table;
+ ASSERT_NO_FATAL_FAILURE(MakeDoubleTable(num_columns, num_rows, 1, &table));
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_NO_FATAL_FAILURE(WriteTableToBuffer(table, num_rows / 2,
+ default_arrow_writer_properties(), &buffer));
+
+ std::unique_ptr<FileReader> reader;
+ FileReaderBuilder builder;
+ ReaderProperties properties = default_reader_properties();
+ ArrowReaderProperties arrow_properties = default_arrow_reader_properties();
+ arrow_properties.set_pre_buffer(true);
+ ASSERT_OK(builder.Open(std::make_shared<BufferReader>(buffer), properties));
+ ASSERT_OK(builder.properties(arrow_properties)->Build(&reader));
+ reader->set_use_threads(true);
+
+ // Test multiple subsets to ensure we can read from the file multiple times
+ std::vector<std::vector<int>> column_subsets = {
+ {0, 4, 8, 10}, {0, 1, 2, 3}, {5, 17, 18, 19}};
+
+ for (std::vector<int>& column_subset : column_subsets) {
+ std::shared_ptr<Table> result;
+ ASSERT_OK(reader->ReadTable(column_subset, &result));
+
+ std::vector<std::shared_ptr<::arrow::ChunkedArray>> ex_columns;
+ std::vector<std::shared_ptr<::arrow::Field>> ex_fields;
+ for (int i : column_subset) {
+ ex_columns.push_back(table->column(i));
+ ex_fields.push_back(table->field(i));
+ }
+
+ auto ex_schema = ::arrow::schema(ex_fields);
+ auto expected = Table::Make(ex_schema, ex_columns);
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*expected, *result));
+ }
+}
+
+TEST(TestArrowReadWrite, ListLargeRecords) {
+ // PARQUET-1308: This test passed on Linux when num_rows was smaller
+ const int num_rows = 2000;
+ const int row_group_size = 100;
+
+ std::shared_ptr<Array> list_array;
+ std::shared_ptr<DataType> list_type;
+
+ MakeSimpleListArray(num_rows, 20, "item", &list_type, &list_array);
+
+ auto schema = ::arrow::schema({::arrow::field("a", list_type)});
+
+ std::shared_ptr<Table> table = Table::Make(schema, {list_array});
+
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_NO_FATAL_FAILURE(WriteTableToBuffer(table, row_group_size,
+ default_arrow_writer_properties(), &buffer));
+
+ std::unique_ptr<FileReader> reader;
+ ASSERT_OK_NO_THROW(OpenFile(std::make_shared<BufferReader>(buffer),
+ ::arrow::default_memory_pool(), &reader));
+
+ // Read everything
+ std::shared_ptr<Table> result;
+ ASSERT_OK_NO_THROW(reader->ReadTable(&result));
+ ASSERT_NO_FATAL_FAILURE(::arrow::AssertTablesEqual(*table, *result));
+
+ // Read 1 record at a time
+ ASSERT_OK_NO_THROW(OpenFile(std::make_shared<BufferReader>(buffer),
+ ::arrow::default_memory_pool(), &reader));
+
+ std::unique_ptr<ColumnReader> col_reader;
+ ASSERT_OK(reader->GetColumn(0, &col_reader));
+
+ std::vector<std::shared_ptr<Array>> pieces;
+ for (int i = 0; i < num_rows; ++i) {
+ std::shared_ptr<ChunkedArray> chunked_piece;
+ ASSERT_OK(col_reader->NextBatch(1, &chunked_piece));
+ ASSERT_EQ(1, chunked_piece->length());
+ ASSERT_EQ(1, chunked_piece->num_chunks());
+ pieces.push_back(chunked_piece->chunk(0));
+ }
+ auto chunked = std::make_shared<::arrow::ChunkedArray>(pieces);
+ auto chunked_table = Table::Make(table->schema(), {chunked});
+
+ ASSERT_TRUE(table->Equals(*chunked_table));
+}
+
+typedef std::function<void(int, std::shared_ptr<DataType>*, std::shared_ptr<Array>*)>
+ ArrayFactory;
+
+template <typename ArrowType>
+struct GenerateArrayFunctor {
+ explicit GenerateArrayFunctor(double pct_null = 0.1) : pct_null(pct_null) {}
+
+ void operator()(int length, std::shared_ptr<DataType>* type,
+ std::shared_ptr<Array>* array) {
+ using T = typename ArrowType::c_type;
+
+ // TODO(wesm): generate things other than integers
+ std::vector<T> draws;
+ randint(length, 0, 100, &draws);
+
+ std::vector<bool> is_valid;
+ random_is_valid(length, this->pct_null, &is_valid);
+
+ *type = ::arrow::TypeTraits<ArrowType>::type_singleton();
+ ::arrow::ArrayFromVector<ArrowType, T>(*type, is_valid, draws, array);
+ }
+
+ double pct_null;
+};
+
+typedef std::function<void(int, std::shared_ptr<DataType>*, std::shared_ptr<Array>*)>
+ ArrayFactory;
+
+auto GenerateInt32 = [](int length, std::shared_ptr<DataType>* type,
+ std::shared_ptr<Array>* array) {
+ GenerateArrayFunctor<::arrow::Int32Type> func;
+ func(length, type, array);
+};
+
+auto GenerateList = [](int length, std::shared_ptr<DataType>* type,
+ std::shared_ptr<Array>* array) {
+ MakeSimpleListArray(length, 100, "element", type, array);
+};
+
+std::shared_ptr<Table> InvalidTable() {
+ auto type = ::arrow::int8();
+ auto field = ::arrow::field("a", type);
+ auto schema = ::arrow::schema({field, field});
+
+ // Invalid due to array size not matching
+ auto array1 = ArrayFromJSON(type, "[1, 2]");
+ auto array2 = ArrayFromJSON(type, "[1]");
+ return Table::Make(schema, {array1, array2});
+}
+
+TEST(TestArrowReadWrite, InvalidTable) {
+ // ARROW-4774: Shouldn't segfault on writing an invalid table.
+ auto sink = CreateOutputStream();
+ auto invalid_table = InvalidTable();
+
+ ASSERT_RAISES(Invalid, WriteTable(*invalid_table, ::arrow::default_memory_pool(),
+ CreateOutputStream(), 1, default_writer_properties(),
+ default_arrow_writer_properties()));
+}
+
+TEST(TestArrowReadWrite, TableWithChunkedColumns) {
+ std::vector<ArrayFactory> functions = {GenerateInt32, GenerateList};
+
+ std::vector<int> chunk_sizes = {2, 4, 10, 2};
+ const int64_t total_length = 18;
+
+ for (const auto& datagen_func : functions) {
+ ::arrow::ArrayVector arrays;
+ std::shared_ptr<Array> arr;
+ std::shared_ptr<DataType> type;
+ datagen_func(total_length, &type, &arr);
+
+ int64_t offset = 0;
+ for (int chunk_size : chunk_sizes) {
+ arrays.push_back(arr->Slice(offset, chunk_size));
+ offset += chunk_size;
+ }
+
+ auto field = ::arrow::field("fname", type);
+ auto schema = ::arrow::schema({field});
+ auto table = Table::Make(schema, {std::make_shared<ChunkedArray>(arrays)});
+
+ ASSERT_NO_FATAL_FAILURE(CheckSimpleRoundtrip(table, 2));
+ ASSERT_NO_FATAL_FAILURE(CheckSimpleRoundtrip(table, 3));
+ ASSERT_NO_FATAL_FAILURE(CheckSimpleRoundtrip(table, 10));
+ }
+}
+
+TEST(TestArrowReadWrite, ManySmallLists) {
+ // ARROW-11607: The actual scenario this forces is no data reads for
+ // a first batch, and then a single element read for the second batch.
+
+ // Constructs
+ std::shared_ptr<::arrow::Int32Builder> value_builder =
+ std::make_shared<::arrow::Int32Builder>();
+ constexpr int64_t kNullCount = 6;
+ auto type = ::arrow::list(::arrow::int32());
+ std::vector<std::shared_ptr<Array>> arrays(1);
+ arrays[0] = ArrayFromJSON(type, R"([null, null, null, null, null, null, [1]])");
+
+ auto field = ::arrow::field("fname", type);
+ auto schema = ::arrow::schema({field});
+ auto table = Table::Make(schema, {std::make_shared<ChunkedArray>(arrays)});
+ ASSERT_EQ(table->num_rows(), kNullCount + 1);
+
+ CheckSimpleRoundtrip(table, /*row_group_size=*/kNullCount,
+ default_arrow_writer_properties());
+}
+
+TEST(TestArrowReadWrite, TableWithDuplicateColumns) {
+ // See ARROW-1974
+ using ::arrow::ArrayFromVector;
+
+ auto f0 = field("duplicate", ::arrow::int8());
+ auto f1 = field("duplicate", ::arrow::int16());
+ auto schema = ::arrow::schema({f0, f1});
+
+ std::vector<int8_t> a0_values = {1, 2, 3};
+ std::vector<int16_t> a1_values = {14, 15, 16};
+
+ std::shared_ptr<Array> a0, a1;
+
+ ArrayFromVector<::arrow::Int8Type, int8_t>(a0_values, &a0);
+ ArrayFromVector<::arrow::Int16Type, int16_t>(a1_values, &a1);
+
+ auto table = Table::Make(schema, {a0, a1});
+ ASSERT_NO_FATAL_FAILURE(CheckSimpleRoundtrip(table, table->num_rows()));
+}
+
+TEST(ArrowReadWrite, EmptyStruct) {
+ // ARROW-10928: empty struct type not supported
+ {
+ // Empty struct as only column
+ auto fields = ::arrow::FieldVector{
+ ::arrow::field("structs", ::arrow::struct_(::arrow::FieldVector{}))};
+ auto schema = ::arrow::schema(fields);
+ auto columns = ArrayVector{ArrayFromJSON(fields[0]->type(), "[null, {}]")};
+ auto table = Table::Make(schema, columns);
+
+ auto sink = CreateOutputStream();
+ ASSERT_RAISES(
+ NotImplemented,
+ WriteTable(*table, ::arrow::default_memory_pool(), sink, /*chunk_size=*/1,
+ default_writer_properties(), default_arrow_writer_properties()));
+ }
+ {
+ // Empty struct as nested column
+ auto fields = ::arrow::FieldVector{::arrow::field(
+ "structs", ::arrow::list(::arrow::struct_(::arrow::FieldVector{})))};
+ auto schema = ::arrow::schema(fields);
+ auto columns =
+ ArrayVector{ArrayFromJSON(fields[0]->type(), "[null, [], [null, {}]]")};
+ auto table = Table::Make(schema, columns);
+
+ auto sink = CreateOutputStream();
+ ASSERT_RAISES(
+ NotImplemented,
+ WriteTable(*table, ::arrow::default_memory_pool(), sink, /*chunk_size=*/1,
+ default_writer_properties(), default_arrow_writer_properties()));
+ }
+ {
+ // Empty struct along other column
+ auto fields = ::arrow::FieldVector{
+ ::arrow::field("structs", ::arrow::struct_(::arrow::FieldVector{})),
+ ::arrow::field("ints", ::arrow::int32())};
+ auto schema = ::arrow::schema(fields);
+ auto columns = ArrayVector{ArrayFromJSON(fields[0]->type(), "[null, {}]"),
+ ArrayFromJSON(fields[1]->type(), "[1, 2]")};
+ auto table = Table::Make(schema, columns);
+
+ auto sink = CreateOutputStream();
+ ASSERT_RAISES(
+ NotImplemented,
+ WriteTable(*table, ::arrow::default_memory_pool(), sink, /*chunk_size=*/1,
+ default_writer_properties(), default_arrow_writer_properties()));
+ }
+}
+
+TEST(ArrowReadWrite, SimpleStructRoundTrip) {
+ auto links = field(
+ "Links", ::arrow::struct_({field("Backward", ::arrow::int64(), /*nullable=*/true),
+ field("Forward", ::arrow::int64(), /*nullable=*/true)}));
+
+ auto links_id_array = ::arrow::ArrayFromJSON(links->type(),
+ "[{\"Backward\": null, \"Forward\": 20}, "
+ "{\"Backward\": 10, \"Forward\": 40}]");
+
+ CheckSimpleRoundtrip(
+ ::arrow::Table::Make(std::make_shared<::arrow::Schema>(
+ std::vector<std::shared_ptr<::arrow::Field>>{links}),
+ {links_id_array}),
+ 2);
+}
+
+TEST(ArrowReadWrite, SingleColumnNullableStruct) {
+ auto links =
+ field("Links",
+ ::arrow::struct_({field("Backward", ::arrow::int64(), /*nullable=*/true)}));
+
+ auto links_id_array = ::arrow::ArrayFromJSON(links->type(),
+ "[null, "
+ "{\"Backward\": 10}"
+ "]");
+
+ CheckSimpleRoundtrip(
+ ::arrow::Table::Make(std::make_shared<::arrow::Schema>(
+ std::vector<std::shared_ptr<::arrow::Field>>{links}),
+ {links_id_array}),
+ 3);
+}
+
+TEST(ArrowReadWrite, NestedRequiredField) {
+ auto int_field = ::arrow::field("int_array", ::arrow::int32(), /*nullable=*/false);
+ auto int_array = ::arrow::ArrayFromJSON(int_field->type(), "[0, 1, 2, 3, 4, 5, 7, 8]");
+ auto struct_field =
+ ::arrow::field("root", ::arrow::struct_({int_field}), /*nullable=*/true);
+ std::shared_ptr<Buffer> validity_bitmap;
+ ASSERT_OK_AND_ASSIGN(validity_bitmap, ::arrow::AllocateBitmap(8));
+ validity_bitmap->mutable_data()[0] = 0xCC;
+
+ auto struct_data = ArrayData::Make(struct_field->type(), /*length=*/8,
+ {validity_bitmap}, {int_array->data()});
+ CheckSimpleRoundtrip(::arrow::Table::Make(::arrow::schema({struct_field}),
+ {::arrow::MakeArray(struct_data)}),
+ /*row_group_size=*/8);
+}
+
+TEST(ArrowReadWrite, Decimal256) {
+ using ::arrow::Decimal256;
+ using ::arrow::field;
+
+ auto type = ::arrow::decimal256(8, 4);
+
+ const char* json = R"(["1.0000", null, "-1.2345", "-1000.5678",
+ "-9999.9999", "9999.9999"])";
+ auto array = ::arrow::ArrayFromJSON(type, json);
+ auto table = ::arrow::Table::Make(::arrow::schema({field("root", type)}), {array});
+ auto props_store_schema = ArrowWriterProperties::Builder().store_schema()->build();
+ CheckSimpleRoundtrip(table, 2, props_store_schema);
+}
+
+TEST(ArrowReadWrite, DecimalStats) {
+ using ::arrow::Decimal128;
+ using ::arrow::field;
+
+ auto type = ::arrow::decimal128(/*precision=*/8, /*scale=*/0);
+
+ const char* json = R"(["255", "128", null, "0", "1", "-127", "-128", "-129", "-255"])";
+ auto array = ::arrow::ArrayFromJSON(type, json);
+ auto table = ::arrow::Table::Make(::arrow::schema({field("root", type)}), {array});
+
+ std::shared_ptr<Buffer> buffer;
+ ASSERT_NO_FATAL_FAILURE(WriteTableToBuffer(table, /*row_grop_size=*/100,
+ default_arrow_writer_properties(), &buffer));
+
+ std::unique_ptr<FileReader> reader;
+ ASSERT_OK_NO_THROW(OpenFile(std::make_shared<BufferReader>(buffer),
+ ::arrow::default_memory_pool(), &reader));
+
+ std::shared_ptr<Scalar> min, max;
+ ReadSingleColumnFileStatistics(std::move(reader), &min, &max);
+
+ std::shared_ptr<Scalar> expected_min, expected_max;
+ ASSERT_OK_AND_ASSIGN(expected_min, array->GetScalar(array->length() - 1));
+ ASSERT_OK_AND_ASSIGN(expected_max, array->GetScalar(0));
+ ::arrow::AssertScalarsEqual(*expected_min, *min, /*verbose=*/true);
+ ::arrow::AssertScalarsEqual(*expected_max, *max, /*verbose=*/true);
+}
+
+TEST(ArrowReadWrite, NestedNullableField) {
+ auto int_field = ::arrow::field("int_array", ::arrow::int32());
+ auto int_array =
+ ::arrow::ArrayFromJSON(int_field->type(), "[0, null, 2, null, 4, 5, null, 8]");
+ auto struct_field =
+ ::arrow::field("root", ::arrow::struct_({int_field}), /*nullable=*/true);
+ std::shared_ptr<Buffer> validity_bitmap;
+ ASSERT_OK_AND_ASSIGN(validity_bitmap, ::arrow::AllocateBitmap(8));
+ validity_bitmap->mutable_data()[0] = 0xCC;
+
+ auto struct_data = ArrayData::Make(struct_field->type(), /*length=*/8,
+ {validity_bitmap}, {int_array->data()});
+ CheckSimpleRoundtrip(::arrow::Table::Make(::arrow::schema({struct_field}),
+ {::arrow::MakeArray(struct_data)}),
+ /*row_group_size=*/8);
+}
+
+TEST(TestArrowReadWrite, CanonicalNestedRoundTrip) {
+ auto doc_id = field("DocId", ::arrow::int64(), /*nullable=*/false);
+ auto links = field(
+ "Links",
+ ::arrow::struct_({field("Backward", list(::arrow::int64()), /*nullable=*/false),
+ field("Forward", list(::arrow::int64()), /*nullable=*/false)}));
+ auto name_struct = field(
+ "NameStruct",
+ ::arrow::struct_(
+ {field("Language",
+ ::arrow::list(field(
+ "lang_struct",
+ ::arrow::struct_({field("Code", ::arrow::utf8(), /*nullable=*/false),
+ field("Country", ::arrow::utf8())})))),
+ field("Url", ::arrow::utf8())}));
+ auto name = field("Name", ::arrow::list(name_struct), /*nullable=*/false);
+ auto schema = std::make_shared<::arrow::Schema>(
+ std::vector<std::shared_ptr<::arrow::Field>>({doc_id, links, name}));
+
+ auto doc_id_array = ::arrow::ArrayFromJSON(doc_id->type(), "[10, 20]");
+ auto links_id_array =
+ ::arrow::ArrayFromJSON(links->type(),
+ "[{\"Backward\":[], \"Forward\":[20, 40, 60]}, "
+ "{\"Backward\":[10, 30], \"Forward\":[80]}]");
+
+ // Written without C++11 string literal because many editors don't have C++11
+ // string literals implemented properly
+ auto name_array = ::arrow::ArrayFromJSON(
+ name->type(),
+ "[[{\"Language\": [{\"Code\": \"en_us\", \"Country\":\"us\"},"
+ "{\"Code\": \"en_us\", \"Country\": null}],"
+ "\"Url\": \"http://A\"},"
+ "{\"Url\": \"http://B\"},"
+ "{\"Language\": [{\"Code\": \"en-gb\", \"Country\": \"gb\"}]}],"
+ "[{\"Url\": \"http://C\"}]]");
+ auto expected =
+ ::arrow::Table::Make(schema, {doc_id_array, links_id_array, name_array});
+ CheckSimpleRoundtrip(expected, 2);
+}
+
+TEST(ArrowReadWrite, ListOfStruct) {
+ using ::arrow::field;
+
+ auto type = ::arrow::list(::arrow::struct_(
+ {field("a", ::arrow::int16(), /*nullable=*/false), field("b", ::arrow::utf8())}));
+
+ const char* json = R"([
+ [{"a": 4, "b": "foo"}, {"a": 5}, {"a": 6, "b": "bar"}],
+ [null, {"a": 7}],
+ null,
+ []])";
+ auto array = ::arrow::ArrayFromJSON(type, json);
+ auto table = ::arrow::Table::Make(::arrow::schema({field("root", type)}), {array});
+ CheckSimpleRoundtrip(table, 2);
+}
+
+TEST(ArrowReadWrite, ListOfStructOfList1) {
+ using ::arrow::field;
+ using ::arrow::list;
+ using ::arrow::struct_;
+
+ auto type = list(struct_({field("a", ::arrow::int16(), /*nullable=*/false),
+ field("b", list(::arrow::int64()))}));
+
+ const char* json = R"([
+ [{"a": 123, "b": [1, 2, null, 3]}, null],
+ null,
+ [],
+ [{"a": 456}, {"a": 789, "b": []}, {"a": 876, "b": [4, 5, 6]}]])";
+ auto array = ::arrow::ArrayFromJSON(type, json);
+ auto table = ::arrow::Table::Make(::arrow::schema({field("root", type)}), {array});
+ CheckSimpleRoundtrip(table, 2);
+}
+
+TEST(ArrowReadWrite, ListWithNoValues) {
+ using ::arrow::Buffer;
+ using ::arrow::field;
+
+ auto type = list(field("item", ::arrow::int32(), /*nullable=*/false));
+ auto array = ::arrow::ArrayFromJSON(type, "[null, []]");
+ auto table = ::arrow::Table::Make(::arrow::schema({field("root", type)}), {array});
+ auto props_store_schema = ArrowWriterProperties::Builder().store_schema()->build();
+ CheckSimpleRoundtrip(table, 2, props_store_schema);
+}
+
+TEST(ArrowReadWrite, Map) {
+ using ::arrow::field;
+ using ::arrow::map;
+
+ auto type = map(::arrow::int16(), ::arrow::utf8());
+
+ const char* json = R"([
+ [[1, "a"], [2, "b"]],
+ [[3, "c"]],
+ [],
+ null,
+ [[4, "d"], [5, "e"], [6, "f"]]
+ ])";
+ auto array = ::arrow::ArrayFromJSON(type, json);
+ auto table = ::arrow::Table::Make(::arrow::schema({field("root", type)}), {array});
+ auto props_store_schema = ArrowWriterProperties::Builder().store_schema()->build();
+ CheckSimpleRoundtrip(table, 2, props_store_schema);
+}
+
+TEST(ArrowReadWrite, LargeList) {
+ using ::arrow::field;
+ using ::arrow::large_list;
+ using ::arrow::struct_;
+
+ auto type = large_list(::arrow::int16());
+
+ const char* json = R"([
+ [1, 2, 3],
+ [4, 5, 6],
+ [7, 8, 9]])";
+ auto array = ::arrow::ArrayFromJSON(type, json);
+ auto table = ::arrow::Table::Make(::arrow::schema({field("root", type)}), {array});
+ auto props_store_schema = ArrowWriterProperties::Builder().store_schema()->build();
+ CheckSimpleRoundtrip(table, 2, props_store_schema);
+}
+
+TEST(ArrowReadWrite, FixedSizeList) {
+ using ::arrow::field;
+ using ::arrow::fixed_size_list;
+ using ::arrow::struct_;
+
+ auto type = fixed_size_list(::arrow::int16(), /*size=*/3);
+
+ const char* json = R"([
+ [1, 2, 3],
+ [4, 5, 6],
+ [7, 8, 9]])";
+ auto array = ::arrow::ArrayFromJSON(type, json);
+ auto table = ::arrow::Table::Make(::arrow::schema({field("root", type)}), {array});
+ auto props_store_schema = ArrowWriterProperties::Builder().store_schema()->build();
+ CheckSimpleRoundtrip(table, 2, props_store_schema);
+}
+
+TEST(ArrowReadWrite, ListOfStructOfList2) {
+ using ::arrow::field;
+ using ::arrow::list;
+ using ::arrow::struct_;
+
+ auto type =
+ list(field("item",
+ struct_({field("a", ::arrow::int16(), /*nullable=*/false),
+ field("b", list(::arrow::int64()), /*nullable=*/false)}),
+ /*nullable=*/false));
+
+ const char* json = R"([
+ [{"a": 123, "b": [1, 2, 3]}],
+ null,
+ [],
+ [{"a": 456, "b": []}, {"a": 789, "b": [null]}, {"a": 876, "b": [4, 5, 6]}]])";
+ auto array = ::arrow::ArrayFromJSON(type, json);
+ auto table = ::arrow::Table::Make(::arrow::schema({field("root", type)}), {array});
+ CheckSimpleRoundtrip(table, 2);
+}
+
+TEST(ArrowReadWrite, StructOfLists) {
+ using ::arrow::field;
+ using ::arrow::list;
+
+ auto type = ::arrow::struct_(
+ {field("a", list(::arrow::utf8()), /*nullable=*/false),
+ field("b", list(field("f", ::arrow::int64(), /*nullable=*/false)))});
+
+ const char* json = R"([
+ {"a": ["1", "2"], "b": []},
+ {"a": [], "b": [3, 4, 5]},
+ {"a": ["6"], "b": null},
+ {"a": [null, "7"], "b": [8]}])";
+ auto array = ::arrow::ArrayFromJSON(type, json);
+ auto table = ::arrow::Table::Make(::arrow::schema({field("root", type)}), {array});
+ CheckSimpleRoundtrip(table, 2);
+}
+
+TEST(ArrowReadWrite, ListOfStructOfLists1) {
+ using ::arrow::field;
+ using ::arrow::list;
+
+ auto type = list(::arrow::struct_(
+ {field("a", list(::arrow::utf8()), /*nullable=*/false),
+ field("b", list(field("f", ::arrow::int64(), /*nullable=*/false)))}));
+
+ const char* json = R"([
+ [{"a": ["1", "2"], "b": []}, null],
+ [],
+ null,
+ [null],
+ [{"a": [], "b": [3, 4, 5]}, {"a": ["6"], "b": null}],
+ [null, {"a": [null, "7"], "b": [8]}]])";
+ auto array = ::arrow::ArrayFromJSON(type, json);
+ auto table = ::arrow::Table::Make(::arrow::schema({field("root", type)}), {array});
+ CheckSimpleRoundtrip(table, 2);
+}
+
+TEST(ArrowReadWrite, ListOfStructOfLists2) {
+ using ::arrow::field;
+ using ::arrow::list;
+
+ auto type = list(
+ field("x",
+ ::arrow::struct_(
+ {field("a", list(::arrow::utf8()), /*nullable=*/false),
+ field("b", list(field("f", ::arrow::int64(), /*nullable=*/false)))}),
+ /*nullable=*/false));
+
+ const char* json = R"([
+ [{"a": ["1", "2"], "b": []}],
+ [],
+ null,
+ [],
+ [{"a": [], "b": [3, 4, 5]}, {"a": ["6"], "b": null}],
+ [{"a": [null, "7"], "b": [8]}]])";
+ auto array = ::arrow::ArrayFromJSON(type, json);
+ auto table = ::arrow::Table::Make(::arrow::schema({field("root", type)}), {array});
+ CheckSimpleRoundtrip(table, 2);
+}
+
+TEST(ArrowReadWrite, ListOfStructOfLists3) {
+ using ::arrow::field;
+ using ::arrow::list;
+
+ auto type = list(field(
+ "x",
+ ::arrow::struct_({field("a", list(::arrow::utf8()), /*nullable=*/false),
+ field("b", list(field("f", ::arrow::int64(), /*nullable=*/false)),
+ /*nullable=*/false)}),
+ /*nullable=*/false));
+
+ const char* json = R"([
+ [{"a": ["1", "2"], "b": []}],
+ [],
+ null,
+ [],
+ [{"a": [], "b": [3, 4, 5]}, {"a": ["6"], "b": []}],
+ [{"a": [null, "7"], "b": [8]}]])";
+ auto array = ::arrow::ArrayFromJSON(type, json);
+ auto table = ::arrow::Table::Make(::arrow::schema({field("root", type)}), {array});
+ CheckSimpleRoundtrip(table, 2);
+}
+
+TEST(TestArrowReadWrite, DictionaryColumnChunkedWrite) {
+ // This is a regression test for this:
+ //
+ // https://issues.apache.org/jira/browse/ARROW-1938
+ //
+ // As of the writing of this test, columns of type
+ // dictionary are written as their raw/expanded values.
+ // The regression was that the whole column was being
+ // written for each chunk.
+ using ::arrow::ArrayFromVector;
+
+ std::vector<std::string> values = {"first", "second", "third"};
+ auto type = ::arrow::utf8();
+ std::shared_ptr<Array> dict_values;
+ ArrayFromVector<::arrow::StringType, std::string>(values, &dict_values);
+
+ auto value_type = ::arrow::utf8();
+ auto dict_type = ::arrow::dictionary(::arrow::int32(), value_type);
+
+ auto f0 = field("dictionary", dict_type);
+ std::vector<std::shared_ptr<::arrow::Field>> fields;
+ fields.emplace_back(f0);
+ auto schema = ::arrow::schema(fields);
+
+ std::shared_ptr<Array> f0_values, f1_values;
+ ArrayFromVector<::arrow::Int32Type, int32_t>({0, 1, 0, 2, 1}, &f0_values);
+ ArrayFromVector<::arrow::Int32Type, int32_t>({2, 0, 1, 0, 2}, &f1_values);
+ ::arrow::ArrayVector dict_arrays = {
+ std::make_shared<::arrow::DictionaryArray>(dict_type, f0_values, dict_values),
+ std::make_shared<::arrow::DictionaryArray>(dict_type, f1_values, dict_values)};
+
+ std::vector<std::shared_ptr<ChunkedArray>> columns;
+ columns.emplace_back(std::make_shared<ChunkedArray>(dict_arrays));
+
+ auto table = Table::Make(schema, columns);
+
+ std::shared_ptr<Table> result;
+ ASSERT_NO_FATAL_FAILURE(DoSimpleRoundtrip(table, 1,
+ // Just need to make sure that we make
+ // a chunk size that is smaller than the
+ // total number of values
+ 2, {}, &result));
+
+ std::vector<std::string> expected_values = {"first", "second", "first", "third",
+ "second", "third", "first", "second",
+ "first", "third"};
+ columns.clear();
+
+ std::shared_ptr<Array> expected_array;
+ ArrayFromVector<::arrow::StringType, std::string>(expected_values, &expected_array);
+
+ // The column name gets changed on output to the name of the
+ // field, and it also turns into a nullable column
+ columns.emplace_back(std::make_shared<ChunkedArray>(expected_array));
+
+ schema = ::arrow::schema({::arrow::field("dictionary", ::arrow::utf8())});
+
+ auto expected_table = Table::Make(schema, columns);
+
+ ::arrow::AssertTablesEqual(*expected_table, *result, false);
+}
+
+TEST(TestArrowReadWrite, NonUniqueDictionaryValues) {
+ // ARROW-10237
+ auto dict_with_dupes = ArrayFromJSON(::arrow::utf8(), R"(["a", "a", "b"])");
+ // test with all valid 4-long `indices`
+ for (int i = 0; i < 4 * 4 * 4 * 4; ++i) {
+ int j = i;
+ ASSERT_OK_AND_ASSIGN(
+ auto indices,
+ ArrayFromBuilderVisitor(::arrow::int32(), 4, [&](::arrow::Int32Builder* b) {
+ if (j % 4 < dict_with_dupes->length()) {
+ b->UnsafeAppend(j % 4);
+ } else {
+ b->UnsafeAppendNull();
+ }
+ j /= 4;
+ }));
+ ASSERT_OK_AND_ASSIGN(auto plain, ::arrow::compute::Take(*dict_with_dupes, *indices));
+ ASSERT_OK_AND_ASSIGN(auto encoded,
+ ::arrow::DictionaryArray::FromArrays(indices, dict_with_dupes));
+
+ auto table = Table::Make(::arrow::schema({::arrow::field("d", encoded->type())}),
+ ::arrow::ArrayVector{encoded});
+
+ ASSERT_OK(table->ValidateFull());
+
+ std::shared_ptr<Table> round_tripped;
+ ASSERT_NO_FATAL_FAILURE(DoSimpleRoundtrip(table, true, 20, {}, &round_tripped));
+
+ ASSERT_OK(round_tripped->ValidateFull());
+ ::arrow::AssertArraysEqual(*plain, *round_tripped->column(0)->chunk(0), true);
+ }
+}
+
+TEST(TestArrowWrite, CheckChunkSize) {
+ const int num_columns = 2;
+ const int num_rows = 128;
+ const int64_t chunk_size = 0; // note the chunk_size is 0
+ std::shared_ptr<Table> table;
+ ASSERT_NO_FATAL_FAILURE(MakeDoubleTable(num_columns, num_rows, 1, &table));
+
+ auto sink = CreateOutputStream();
+
+ ASSERT_RAISES(Invalid,
+ WriteTable(*table, ::arrow::default_memory_pool(), sink, chunk_size));
+}
+
+class TestNestedSchemaRead : public ::testing::TestWithParam<Repetition::type> {
+ protected:
+ // make it *3 to make it easily divisible by 3
+ const int NUM_SIMPLE_TEST_ROWS = SMALL_SIZE * 3;
+ std::shared_ptr<::arrow::Int32Array> values_array_ = nullptr;
+
+ void InitReader() {
+ ASSERT_OK_AND_ASSIGN(auto buffer, nested_parquet_->Finish());
+ ASSERT_OK_NO_THROW(OpenFile(std::make_shared<BufferReader>(buffer),
+ ::arrow::default_memory_pool(), &reader_));
+ }
+
+ void InitNewParquetFile(const std::shared_ptr<GroupNode>& schema, int num_rows) {
+ nested_parquet_ = CreateOutputStream();
+
+ writer_ = parquet::ParquetFileWriter::Open(nested_parquet_, schema,
+ default_writer_properties());
+ row_group_writer_ = writer_->AppendRowGroup();
+ }
+
+ void FinalizeParquetFile() {
+ row_group_writer_->Close();
+ writer_->Close();
+ }
+
+ void MakeValues(int num_rows) {
+ std::shared_ptr<Array> arr;
+ ASSERT_OK(NullableArray<::arrow::Int32Type>(num_rows, 0, kDefaultSeed, &arr));
+ values_array_ = std::dynamic_pointer_cast<::arrow::Int32Array>(arr);
+ }
+
+ void WriteColumnData(size_t num_rows, int16_t* def_levels, int16_t* rep_levels,
+ int32_t* values) {
+ auto typed_writer =
+ static_cast<TypedColumnWriter<Int32Type>*>(row_group_writer_->NextColumn());
+ typed_writer->WriteBatch(num_rows, def_levels, rep_levels, values);
+ }
+
+ void ValidateArray(const Array& array, size_t expected_nulls) {
+ ASSERT_EQ(array.length(), values_array_->length());
+ ASSERT_EQ(array.null_count(), expected_nulls);
+ // Also independently count the nulls
+ auto local_null_count = 0;
+ for (int i = 0; i < array.length(); i++) {
+ if (array.IsNull(i)) {
+ local_null_count++;
+ }
+ }
+ ASSERT_EQ(local_null_count, expected_nulls);
+ ASSERT_OK(array.ValidateFull());
+ }
+
+ void ValidateColumnArray(const ::arrow::Int32Array& array, size_t expected_nulls) {
+ ValidateArray(array, expected_nulls);
+ int j = 0;
+ for (int i = 0; i < values_array_->length(); i++) {
+ if (array.IsNull(i)) {
+ continue;
+ }
+ ASSERT_EQ(array.Value(i), values_array_->Value(j));
+ j++;
+ }
+ }
+
+ void ValidateTableArrayTypes(const Table& table) {
+ for (int i = 0; i < table.num_columns(); i++) {
+ const std::shared_ptr<::arrow::Field> schema_field = table.schema()->field(i);
+ const std::shared_ptr<ChunkedArray> column = table.column(i);
+ // Compare with the array type
+ ASSERT_TRUE(schema_field->type()->Equals(column->chunk(0)->type()));
+ }
+ }
+
+ // A parquet with a simple nested schema
+ void CreateSimpleNestedParquet(Repetition::type struct_repetition) {
+ std::vector<NodePtr> parquet_fields;
+ // TODO(itaiin): We are using parquet low-level file api to create the nested parquet
+ // this needs to change when a nested writes are implemented
+
+ // create the schema:
+ // <struct_repetition> group group1 {
+ // required int32 leaf1;
+ // optional int32 leaf2;
+ // }
+ // required int32 leaf3;
+
+ parquet_fields.push_back(GroupNode::Make(
+ "group1", struct_repetition,
+ {PrimitiveNode::Make("leaf1", Repetition::REQUIRED, ParquetType::INT32),
+ PrimitiveNode::Make("leaf2", Repetition::OPTIONAL, ParquetType::INT32)}));
+ parquet_fields.push_back(
+ PrimitiveNode::Make("leaf3", Repetition::REQUIRED, ParquetType::INT32));
+
+ auto schema_node = GroupNode::Make("schema", Repetition::REQUIRED, parquet_fields);
+
+ // Create definition levels for the different columns that contain interleaved
+ // nulls and values at all nesting levels
+
+ // definition levels for optional fields
+ std::vector<int16_t> leaf1_def_levels(NUM_SIMPLE_TEST_ROWS);
+ std::vector<int16_t> leaf2_def_levels(NUM_SIMPLE_TEST_ROWS);
+ std::vector<int16_t> leaf3_def_levels(NUM_SIMPLE_TEST_ROWS);
+ for (int i = 0; i < NUM_SIMPLE_TEST_ROWS; i++) {
+ // leaf1 is required within the optional group1, so it is only null
+ // when the group is null
+ leaf1_def_levels[i] = (i % 3 == 0) ? 0 : 1;
+ // leaf2 is optional, can be null in the primitive (def-level 1) or
+ // struct level (def-level 0)
+ leaf2_def_levels[i] = static_cast<int16_t>(i % 3);
+ // leaf3 is required
+ leaf3_def_levels[i] = 0;
+ }
+
+ std::vector<int16_t> rep_levels(NUM_SIMPLE_TEST_ROWS, 0);
+
+ // Produce values for the columns
+ MakeValues(NUM_SIMPLE_TEST_ROWS);
+ int32_t* values = reinterpret_cast<int32_t*>(values_array_->values()->mutable_data());
+
+ // Create the actual parquet file
+ InitNewParquetFile(std::static_pointer_cast<GroupNode>(schema_node),
+ NUM_SIMPLE_TEST_ROWS);
+
+ // leaf1 column
+ WriteColumnData(NUM_SIMPLE_TEST_ROWS, leaf1_def_levels.data(), rep_levels.data(),
+ values);
+ // leaf2 column
+ WriteColumnData(NUM_SIMPLE_TEST_ROWS, leaf2_def_levels.data(), rep_levels.data(),
+ values);
+ // leaf3 column
+ WriteColumnData(NUM_SIMPLE_TEST_ROWS, leaf3_def_levels.data(), rep_levels.data(),
+ values);
+
+ FinalizeParquetFile();
+ InitReader();
+ }
+
+ NodePtr CreateSingleTypedNestedGroup(int index, int depth, int num_children,
+ Repetition::type node_repetition,
+ ParquetType::type leaf_type) {
+ std::vector<NodePtr> children;
+
+ for (int i = 0; i < num_children; i++) {
+ if (depth <= 1) {
+ children.push_back(PrimitiveNode::Make("leaf", node_repetition, leaf_type));
+ } else {
+ children.push_back(CreateSingleTypedNestedGroup(i, depth - 1, num_children,
+ node_repetition, leaf_type));
+ }
+ }
+
+ std::stringstream ss;
+ ss << "group-" << depth << "-" << index;
+ return NodePtr(GroupNode::Make(ss.str(), node_repetition, children));
+ }
+
+ // A deeply nested schema
+ void CreateMultiLevelNestedParquet(int num_trees, int tree_depth, int num_children,
+ int num_rows, Repetition::type node_repetition) {
+ // Create the schema
+ std::vector<NodePtr> parquet_fields;
+ for (int i = 0; i < num_trees; i++) {
+ parquet_fields.push_back(CreateSingleTypedNestedGroup(
+ i, tree_depth, num_children, node_repetition, ParquetType::INT32));
+ }
+ auto schema_node = GroupNode::Make("schema", Repetition::REQUIRED, parquet_fields);
+
+ int num_columns = num_trees * static_cast<int>((std::pow(num_children, tree_depth)));
+
+ std::vector<int16_t> def_levels;
+ std::vector<int16_t> rep_levels;
+
+ int num_levels = 0;
+ while (num_levels < num_rows) {
+ if (node_repetition == Repetition::REQUIRED) {
+ def_levels.push_back(0); // all are required
+ } else {
+ int16_t level = static_cast<int16_t>(num_levels % (tree_depth + 2));
+ def_levels.push_back(level); // all are optional
+ }
+ rep_levels.push_back(0); // none is repeated
+ ++num_levels;
+ }
+
+ // Produce values for the columns
+ MakeValues(num_rows);
+ int32_t* values = reinterpret_cast<int32_t*>(values_array_->values()->mutable_data());
+
+ // Create the actual parquet file
+ InitNewParquetFile(std::static_pointer_cast<GroupNode>(schema_node), num_rows);
+
+ for (int i = 0; i < num_columns; i++) {
+ WriteColumnData(num_rows, def_levels.data(), rep_levels.data(), values);
+ }
+ FinalizeParquetFile();
+ InitReader();
+ }
+
+ class DeepParquetTestVisitor : public ArrayVisitor {
+ public:
+ DeepParquetTestVisitor(Repetition::type node_repetition,
+ std::shared_ptr<::arrow::Int32Array> expected)
+ : node_repetition_(node_repetition), expected_(expected) {}
+
+ Status Validate(std::shared_ptr<Array> tree) { return tree->Accept(this); }
+
+ virtual Status Visit(const ::arrow::Int32Array& array) {
+ if (node_repetition_ == Repetition::REQUIRED) {
+ if (!array.Equals(expected_)) {
+ return Status::Invalid("leaf array data mismatch");
+ }
+ } else if (node_repetition_ == Repetition::OPTIONAL) {
+ if (array.length() != expected_->length()) {
+ return Status::Invalid("Bad leaf array length");
+ }
+ // expect only 1 value every `depth` row
+ if (array.null_count() != SMALL_SIZE) {
+ return Status::Invalid("Unexpected null count");
+ }
+ } else {
+ return Status::NotImplemented("Unsupported repetition");
+ }
+ return Status::OK();
+ }
+
+ virtual Status Visit(const ::arrow::StructArray& array) {
+ for (int32_t i = 0; i < array.num_fields(); ++i) {
+ auto child = array.field(i);
+ if (node_repetition_ == Repetition::REQUIRED) {
+ RETURN_NOT_OK(child->Accept(this));
+ } else if (node_repetition_ == Repetition::OPTIONAL) {
+ // Null count Must be a multiple of SMALL_SIZE
+ if (array.null_count() % SMALL_SIZE != 0) {
+ return Status::Invalid("Unexpected struct null count");
+ }
+ } else {
+ return Status::NotImplemented("Unsupported repetition");
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ Repetition::type node_repetition_;
+ std::shared_ptr<::arrow::Int32Array> expected_;
+ };
+
+ std::shared_ptr<::arrow::io::BufferOutputStream> nested_parquet_;
+ std::unique_ptr<FileReader> reader_;
+ std::unique_ptr<ParquetFileWriter> writer_;
+ RowGroupWriter* row_group_writer_;
+};
+
+TEST_F(TestNestedSchemaRead, ReadIntoTableFull) {
+ ASSERT_NO_FATAL_FAILURE(CreateSimpleNestedParquet(Repetition::OPTIONAL));
+
+ std::shared_ptr<Table> table;
+ ASSERT_OK_NO_THROW(reader_->ReadTable(&table));
+ ASSERT_EQ(table->num_rows(), NUM_SIMPLE_TEST_ROWS);
+ ASSERT_EQ(table->num_columns(), 2);
+ ASSERT_EQ(table->schema()->field(0)->type()->num_fields(), 2);
+ ASSERT_NO_FATAL_FAILURE(ValidateTableArrayTypes(*table));
+
+ auto struct_field_array =
+ std::static_pointer_cast<::arrow::StructArray>(table->column(0)->chunk(0));
+ auto leaf1_array =
+ std::static_pointer_cast<::arrow::Int32Array>(struct_field_array->field(0));
+ auto leaf2_array =
+ std::static_pointer_cast<::arrow::Int32Array>(struct_field_array->field(1));
+ auto leaf3_array =
+ std::static_pointer_cast<::arrow::Int32Array>(table->column(1)->chunk(0));
+
+ // validate struct and leaf arrays
+
+ // validate struct array
+ ASSERT_NO_FATAL_FAILURE(ValidateArray(*struct_field_array, NUM_SIMPLE_TEST_ROWS / 3));
+ // validate leaf1
+ ASSERT_NO_FATAL_FAILURE(ValidateColumnArray(*leaf1_array, NUM_SIMPLE_TEST_ROWS / 3));
+ // validate leaf2
+ ASSERT_NO_FATAL_FAILURE(
+ ValidateColumnArray(*leaf2_array, NUM_SIMPLE_TEST_ROWS * 2 / 3));
+ // validate leaf3
+ ASSERT_NO_FATAL_FAILURE(ValidateColumnArray(*leaf3_array, 0));
+}
+
+TEST_F(TestNestedSchemaRead, ReadTablePartial) {
+ ASSERT_NO_FATAL_FAILURE(CreateSimpleNestedParquet(Repetition::OPTIONAL));
+ std::shared_ptr<Table> table;
+
+ // columns: {group1.leaf1, leaf3}
+ ASSERT_OK_NO_THROW(reader_->ReadTable({0, 2}, &table));
+ ASSERT_EQ(table->num_rows(), NUM_SIMPLE_TEST_ROWS);
+ ASSERT_EQ(table->num_columns(), 2);
+ ASSERT_EQ(table->schema()->field(0)->name(), "group1");
+ ASSERT_EQ(table->schema()->field(1)->name(), "leaf3");
+ ASSERT_EQ(table->schema()->field(0)->type()->num_fields(), 1);
+ ASSERT_NO_FATAL_FAILURE(ValidateTableArrayTypes(*table));
+
+ // columns: {group1.leaf1, leaf3}
+ ASSERT_OK_NO_THROW(reader_->ReadRowGroup(0, {0, 2}, &table));
+ ASSERT_EQ(table->num_rows(), NUM_SIMPLE_TEST_ROWS);
+ ASSERT_EQ(table->num_columns(), 2);
+ ASSERT_EQ(table->schema()->field(0)->name(), "group1");
+ ASSERT_EQ(table->schema()->field(1)->name(), "leaf3");
+ ASSERT_EQ(table->schema()->field(0)->type()->num_fields(), 1);
+ ASSERT_NO_FATAL_FAILURE(ValidateTableArrayTypes(*table));
+
+ // columns: {group1.leaf1, group1.leaf2}
+ ASSERT_OK_NO_THROW(reader_->ReadTable({0, 1}, &table));
+ ASSERT_EQ(table->num_rows(), NUM_SIMPLE_TEST_ROWS);
+ ASSERT_EQ(table->num_columns(), 1);
+ ASSERT_EQ(table->schema()->field(0)->name(), "group1");
+ ASSERT_EQ(table->schema()->field(0)->type()->num_fields(), 2);
+ ASSERT_NO_FATAL_FAILURE(ValidateTableArrayTypes(*table));
+
+ // columns: {leaf3}
+ ASSERT_OK_NO_THROW(reader_->ReadTable({2}, &table));
+ ASSERT_EQ(table->num_rows(), NUM_SIMPLE_TEST_ROWS);
+ ASSERT_EQ(table->num_columns(), 1);
+ ASSERT_EQ(table->schema()->field(0)->name(), "leaf3");
+ ASSERT_EQ(table->schema()->field(0)->type()->num_fields(), 0);
+ ASSERT_NO_FATAL_FAILURE(ValidateTableArrayTypes(*table));
+
+ // Test with different ordering
+ ASSERT_OK_NO_THROW(reader_->ReadTable({2, 0}, &table));
+ ASSERT_EQ(table->num_rows(), NUM_SIMPLE_TEST_ROWS);
+ ASSERT_EQ(table->num_columns(), 2);
+ ASSERT_EQ(table->schema()->field(0)->name(), "leaf3");
+ ASSERT_EQ(table->schema()->field(1)->name(), "group1");
+ ASSERT_EQ(table->schema()->field(1)->type()->num_fields(), 1);
+ ASSERT_NO_FATAL_FAILURE(ValidateTableArrayTypes(*table));
+}
+
+TEST_P(TestNestedSchemaRead, DeepNestedSchemaRead) {
+#ifdef PARQUET_VALGRIND
+ const int num_trees = 3;
+ const int depth = 3;
+#else
+ const int num_trees = 2;
+ const int depth = 2;
+#endif
+ const int num_children = 3;
+ int num_rows = SMALL_SIZE * (depth + 2);
+ ASSERT_NO_FATAL_FAILURE(CreateMultiLevelNestedParquet(num_trees, depth, num_children,
+ num_rows, GetParam()));
+ std::shared_ptr<Table> table;
+ ASSERT_OK_NO_THROW(reader_->ReadTable(&table));
+ ASSERT_EQ(table->num_columns(), num_trees);
+ ASSERT_EQ(table->num_rows(), num_rows);
+
+ DeepParquetTestVisitor visitor(GetParam(), values_array_);
+ for (int i = 0; i < table->num_columns(); i++) {
+ auto tree = table->column(i)->chunk(0);
+ ASSERT_OK_NO_THROW(visitor.Validate(tree));
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(Repetition_type, TestNestedSchemaRead,
+ ::testing::Values(Repetition::REQUIRED, Repetition::OPTIONAL));
+
+TEST(TestImpalaConversion, ArrowTimestampToImpalaTimestamp) {
+ // June 20, 2017 16:32:56 and 123456789 nanoseconds
+ int64_t nanoseconds = INT64_C(1497976376123456789);
+
+ Int96 calculated;
+
+ Int96 expected = {{UINT32_C(632093973), UINT32_C(13871), UINT32_C(2457925)}};
+ ::parquet::internal::NanosecondsToImpalaTimestamp(nanoseconds, &calculated);
+ ASSERT_EQ(expected, calculated);
+}
+
+void TryReadDataFile(const std::string& path,
+ ::arrow::StatusCode expected_code = ::arrow::StatusCode::OK) {
+ auto pool = ::arrow::default_memory_pool();
+
+ std::unique_ptr<FileReader> arrow_reader;
+ Status s =
+ FileReader::Make(pool, ParquetFileReader::OpenFile(path, false), &arrow_reader);
+ if (s.ok()) {
+ std::shared_ptr<::arrow::Table> table;
+ s = arrow_reader->ReadTable(&table);
+ }
+
+ ASSERT_EQ(s.code(), expected_code)
+ << "Expected reading file to return " << arrow::Status::CodeAsString(expected_code)
+ << ", but got " << s.ToString();
+}
+
+TEST(TestArrowReaderAdHoc, Int96BadMemoryAccess) {
+ // PARQUET-995
+ TryReadDataFile(test::get_data_file("alltypes_plain.parquet"));
+}
+
+TEST(TestArrowReaderAdHoc, CorruptedSchema) {
+ // PARQUET-1481
+ auto path = test::get_data_file("PARQUET-1481.parquet", /*is_good=*/false);
+ TryReadDataFile(path, ::arrow::StatusCode::IOError);
+}
+
+TEST(TestArrowReaderAdHoc, LARGE_MEMORY_TEST(LargeStringColumn)) {
+ // ARROW-3762
+ ::arrow::StringBuilder builder;
+ int64_t length = 1 << 30;
+ ASSERT_OK(builder.Resize(length));
+ ASSERT_OK(builder.ReserveData(length));
+ for (int64_t i = 0; i < length; ++i) {
+ builder.UnsafeAppend("1", 1);
+ }
+ std::shared_ptr<Array> array;
+ ASSERT_OK(builder.Finish(&array));
+ auto table =
+ Table::Make(::arrow::schema({::arrow::field("x", ::arrow::utf8())}), {array});
+ std::shared_ptr<SchemaDescriptor> schm;
+ ASSERT_OK_NO_THROW(
+ ToParquetSchema(table->schema().get(), *default_writer_properties(), &schm));
+
+ auto sink = CreateOutputStream();
+
+ auto schm_node = std::static_pointer_cast<GroupNode>(
+ GroupNode::Make("schema", Repetition::REQUIRED, {schm->group_node()->field(0)}));
+
+ auto writer = ParquetFileWriter::Open(sink, schm_node);
+
+ std::unique_ptr<FileWriter> arrow_writer;
+ ASSERT_OK_NO_THROW(FileWriter::Make(::arrow::default_memory_pool(), std::move(writer),
+ table->schema(), default_arrow_writer_properties(),
+ &arrow_writer));
+ for (int i : {0, 1}) {
+ ASSERT_OK_NO_THROW(arrow_writer->WriteTable(*table, table->num_rows())) << i;
+ }
+ ASSERT_OK_NO_THROW(arrow_writer->Close());
+
+ ASSERT_OK_AND_ASSIGN(auto tables_buffer, sink->Finish());
+
+ // drop to save memory
+ table.reset();
+ array.reset();
+
+ auto reader = ParquetFileReader::Open(std::make_shared<BufferReader>(tables_buffer));
+ std::unique_ptr<FileReader> arrow_reader;
+ ASSERT_OK(FileReader::Make(default_memory_pool(), std::move(reader), &arrow_reader));
+ ASSERT_OK_NO_THROW(arrow_reader->ReadTable(&table));
+ ASSERT_OK(table->ValidateFull());
+
+ // ARROW-9297: ensure RecordBatchReader also works
+ reader = ParquetFileReader::Open(std::make_shared<BufferReader>(tables_buffer));
+ ASSERT_OK(FileReader::Make(default_memory_pool(), std::move(reader), &arrow_reader));
+ std::shared_ptr<::arrow::RecordBatchReader> batch_reader;
+ std::vector<int> all_row_groups =
+ ::arrow::internal::Iota(reader->metadata()->num_row_groups());
+ ASSERT_OK_NO_THROW(arrow_reader->GetRecordBatchReader(all_row_groups, &batch_reader));
+ ASSERT_OK_AND_ASSIGN(auto batched_table,
+ ::arrow::Table::FromRecordBatchReader(batch_reader.get()));
+
+ ASSERT_OK(batched_table->ValidateFull());
+ AssertTablesEqual(*table, *batched_table, /*same_chunk_layout=*/false);
+}
+
+TEST(TestArrowReaderAdHoc, HandleDictPageOffsetZero) {
+ // PARQUET-1402: parquet-mr writes files this way which tripped up
+ // some business logic
+ TryReadDataFile(test::get_data_file("dict-page-offset-zero.parquet"));
+}
+
+TEST(TestArrowReaderAdHoc, WriteBatchedNestedNullableStringColumn) {
+ // ARROW-10493
+ std::vector<std::shared_ptr<::arrow::Field>> fields{
+ ::arrow::field("s", ::arrow::utf8(), /*nullable=*/true),
+ ::arrow::field("d", ::arrow::decimal128(4, 2), /*nullable=*/true),
+ ::arrow::field("b", ::arrow::boolean(), /*nullable=*/true),
+ ::arrow::field("i8", ::arrow::int8(), /*nullable=*/true),
+ ::arrow::field("i64", ::arrow::int64(), /*nullable=*/true)};
+ auto type = ::arrow::struct_(fields);
+ auto outer_array = ::arrow::ArrayFromJSON(
+ type,
+ R"([{"s": "abc", "d": "1.23", "b": true, "i8": 10, "i64": 11 },
+ {"s": "de", "d": "3.45", "b": true, "i8": 12, "i64": 13 },
+ {"s": "fghi", "d": "6.78", "b": false, "i8": 14, "i64": 15 },
+ {},
+ {"s": "jklmo", "d": "9.10", "b": true, "i8": 16, "i64": 17 },
+ null,
+ {"s": "p", "d": "11.12", "b": false, "i8": 18, "i64": 19 },
+ {"s": "qrst", "d": "13.14", "b": false, "i8": 20, "i64": 21 },
+ {},
+ {"s": "uvw", "d": "15.16", "b": true, "i8": 22, "i64": 23 },
+ {"s": "x", "d": "17.18", "b": false, "i8": 24, "i64": 25 },
+ {},
+ null])");
+
+ auto expected = Table::Make(
+ ::arrow::schema({::arrow::field("outer", type, /*nullable=*/true)}), {outer_array});
+
+ auto write_props = WriterProperties::Builder().write_batch_size(4)->build();
+
+ std::shared_ptr<Table> actual;
+ DoRoundtrip(expected, /*row_group_size=*/outer_array->length(), &actual, write_props);
+ ::arrow::AssertTablesEqual(*expected, *actual, /*same_chunk_layout=*/false);
+}
+
+class TestArrowReaderAdHocSparkAndHvr
+ : public ::testing::TestWithParam<
+ std::tuple<std::string, std::shared_ptr<DataType>>> {};
+
+TEST_P(TestArrowReaderAdHocSparkAndHvr, ReadDecimals) {
+ std::string path(test::get_data_dir());
+
+ std::string filename;
+ std::shared_ptr<DataType> decimal_type;
+ std::tie(filename, decimal_type) = GetParam();
+
+ path += "/" + filename;
+ ASSERT_GT(path.size(), 0);
+
+ auto pool = ::arrow::default_memory_pool();
+
+ std::unique_ptr<FileReader> arrow_reader;
+ ASSERT_OK_NO_THROW(
+ FileReader::Make(pool, ParquetFileReader::OpenFile(path, false), &arrow_reader));
+ std::shared_ptr<::arrow::Table> table;
+ ASSERT_OK_NO_THROW(arrow_reader->ReadTable(&table));
+
+ std::shared_ptr<::arrow::Schema> schema;
+ ASSERT_OK_NO_THROW(arrow_reader->GetSchema(&schema));
+ ASSERT_EQ(1, schema->num_fields());
+ ASSERT_TRUE(schema->field(0)->type()->Equals(*decimal_type));
+
+ ASSERT_EQ(1, table->num_columns());
+
+ constexpr int32_t expected_length = 24;
+
+ auto value_column = table->column(0);
+ ASSERT_EQ(expected_length, value_column->length());
+
+ ASSERT_EQ(1, value_column->num_chunks());
+
+ auto chunk = value_column->chunk(0);
+
+ std::shared_ptr<Array> expected_array;
+
+ ::arrow::Decimal128Builder builder(decimal_type, pool);
+
+ for (int32_t i = 0; i < expected_length; ++i) {
+ ::arrow::Decimal128 value((i + 1) * 100);
+ ASSERT_OK(builder.Append(value));
+ }
+ ASSERT_OK(builder.Finish(&expected_array));
+ AssertArraysEqual(*expected_array, *chunk);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ ReadDecimals, TestArrowReaderAdHocSparkAndHvr,
+ ::testing::Values(
+ std::make_tuple("int32_decimal.parquet", ::arrow::decimal(4, 2)),
+ std::make_tuple("int64_decimal.parquet", ::arrow::decimal(10, 2)),
+ std::make_tuple("fixed_length_decimal.parquet", ::arrow::decimal(25, 2)),
+ std::make_tuple("fixed_length_decimal_legacy.parquet", ::arrow::decimal(13, 2)),
+ std::make_tuple("byte_array_decimal.parquet", ::arrow::decimal(4, 2))));
+
+// direct-as-possible translation of
+// pyarrow/tests/test_parquet.py::test_validate_schema_write_table
+TEST(TestArrowWriterAdHoc, SchemaMismatch) {
+ auto pool = ::arrow::default_memory_pool();
+ auto writer_schm = ::arrow::schema({field("POS", ::arrow::uint32())});
+ auto table_schm = ::arrow::schema({field("POS", ::arrow::int64())});
+ using ::arrow::io::BufferOutputStream;
+ ASSERT_OK_AND_ASSIGN(auto outs, BufferOutputStream::Create(1 << 10, pool));
+ auto props = default_writer_properties();
+ std::unique_ptr<arrow::FileWriter> writer;
+ ASSERT_OK(arrow::FileWriter::Open(*writer_schm, pool, outs, props, &writer));
+ std::shared_ptr<::arrow::Array> col;
+ ::arrow::Int64Builder builder;
+ ASSERT_OK(builder.Append(1));
+ ASSERT_OK(builder.Finish(&col));
+ auto tbl = ::arrow::Table::Make(table_schm, {col});
+ ASSERT_RAISES(Invalid, writer->WriteTable(*tbl, 1));
+}
+
+class TestArrowWriteDictionary : public ::testing::TestWithParam<ParquetDataPageVersion> {
+ public:
+ ParquetDataPageVersion GetParquetDataPageVersion() { return GetParam(); }
+};
+
+TEST_P(TestArrowWriteDictionary, Statistics) {
+ std::vector<std::shared_ptr<::arrow::Array>> test_dictionaries = {
+ ArrayFromJSON(::arrow::utf8(), R"(["b", "c", "d", "a", "b", "c", "d", "a"])"),
+ ArrayFromJSON(::arrow::utf8(), R"(["b", "c", "d", "a", "b", "c", "d", "a"])"),
+ ArrayFromJSON(::arrow::binary(), R"(["d", "c", "b", "a", "d", "c", "b", "a"])"),
+ ArrayFromJSON(::arrow::large_utf8(), R"(["a", "b", "c", "a", "b", "c"])")};
+ std::vector<std::shared_ptr<::arrow::Array>> test_indices = {
+ ArrayFromJSON(::arrow::int32(), R"([0, null, 3, 0, null, 3])"),
+ ArrayFromJSON(::arrow::int32(), R"([0, 1, null, 0, 1, null])"),
+ ArrayFromJSON(::arrow::int32(), R"([0, 1, 3, 0, 1, 3])"),
+ ArrayFromJSON(::arrow::int32(), R"([null, null, null, null, null, null])")};
+ // Arrays will be written with 3 values per row group, 2 values per data page. The
+ // row groups are identical for ease of testing.
+ std::vector<int32_t> expected_valid_counts = {2, 2, 3, 0};
+ std::vector<int32_t> expected_null_counts = {1, 1, 0, 3};
+ std::vector<int> expected_num_data_pages = {2, 2, 2, 1};
+ std::vector<std::vector<int32_t>> expected_valid_by_page = {
+ {1, 1}, {2, 0}, {2, 1}, {0}};
+ std::vector<std::vector<int64_t>> expected_null_by_page = {{1, 0}, {0, 1}, {0, 0}, {3}};
+ std::vector<int32_t> expected_dict_counts = {4, 4, 4, 3};
+ // Pairs of (min, max)
+ std::vector<std::vector<std::string>> expected_min_max_ = {
+ {"a", "b"}, {"b", "c"}, {"a", "d"}, {"", ""}};
+
+ for (std::size_t case_index = 0; case_index < test_dictionaries.size(); case_index++) {
+ SCOPED_TRACE(test_dictionaries[case_index]->type()->ToString());
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<::arrow::Array> dict_encoded,
+ ::arrow::DictionaryArray::FromArrays(
+ test_indices[case_index], test_dictionaries[case_index]));
+ std::shared_ptr<::arrow::Schema> schema =
+ ::arrow::schema({::arrow::field("values", dict_encoded->type())});
+ std::shared_ptr<::arrow::Table> table = ::arrow::Table::Make(schema, {dict_encoded});
+
+ std::shared_ptr<::arrow::ResizableBuffer> serialized_data = AllocateBuffer();
+ auto out_stream = std::make_shared<::arrow::io::BufferOutputStream>(serialized_data);
+ std::shared_ptr<WriterProperties> writer_properties =
+ WriterProperties::Builder()
+ .max_row_group_length(3)
+ ->data_page_version(this->GetParquetDataPageVersion())
+ ->write_batch_size(2)
+ ->data_pagesize(2)
+ ->build();
+ std::unique_ptr<FileWriter> writer;
+ ASSERT_OK(FileWriter::Open(*schema, ::arrow::default_memory_pool(), out_stream,
+ writer_properties, default_arrow_writer_properties(),
+ &writer));
+ ASSERT_OK(writer->WriteTable(*table, std::numeric_limits<int64_t>::max()));
+ ASSERT_OK(writer->Close());
+ ASSERT_OK(out_stream->Close());
+
+ auto buffer_reader = std::make_shared<::arrow::io::BufferReader>(serialized_data);
+ std::unique_ptr<ParquetFileReader> parquet_reader =
+ ParquetFileReader::Open(std::move(buffer_reader));
+
+ // Check row group statistics
+ std::shared_ptr<FileMetaData> metadata = parquet_reader->metadata();
+ ASSERT_EQ(metadata->num_row_groups(), 2);
+ for (int row_group_index = 0; row_group_index < 2; row_group_index++) {
+ ASSERT_EQ(metadata->RowGroup(row_group_index)->num_columns(), 1);
+ std::shared_ptr<Statistics> stats =
+ metadata->RowGroup(row_group_index)->ColumnChunk(0)->statistics();
+
+ EXPECT_EQ(stats->num_values(), expected_valid_counts[case_index]);
+ EXPECT_EQ(stats->null_count(), expected_null_counts[case_index]);
+
+ std::vector<std::string> case_expected_min_max = expected_min_max_[case_index];
+ EXPECT_EQ(stats->EncodeMin(), case_expected_min_max[0]);
+ EXPECT_EQ(stats->EncodeMax(), case_expected_min_max[1]);
+ }
+
+ for (int row_group_index = 0; row_group_index < 2; row_group_index++) {
+ std::unique_ptr<PageReader> page_reader =
+ parquet_reader->RowGroup(row_group_index)->GetColumnPageReader(0);
+ std::shared_ptr<Page> page = page_reader->NextPage();
+ ASSERT_NE(page, nullptr);
+ DictionaryPage* dict_page = (DictionaryPage*)page.get();
+ ASSERT_EQ(dict_page->num_values(), expected_dict_counts[case_index]);
+ for (int page_index = 0; page_index < expected_num_data_pages[case_index];
+ page_index++) {
+ page = page_reader->NextPage();
+ ASSERT_NE(page, nullptr);
+ DataPage* data_page = (DataPage*)page.get();
+ const EncodedStatistics& stats = data_page->statistics();
+ EXPECT_EQ(stats.null_count, expected_null_by_page[case_index][page_index]);
+ EXPECT_EQ(stats.has_min, false);
+ EXPECT_EQ(stats.has_max, false);
+ EXPECT_EQ(data_page->num_values(),
+ expected_valid_by_page[case_index][page_index] +
+ expected_null_by_page[case_index][page_index]);
+ }
+ ASSERT_EQ(page_reader->NextPage(), nullptr);
+ }
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(WriteDictionary, TestArrowWriteDictionary,
+ ::testing::Values(ParquetDataPageVersion::V1,
+ ParquetDataPageVersion::V2));
+// ----------------------------------------------------------------------
+// Tests for directly reading DictionaryArray
+
+class TestArrowReadDictionary : public ::testing::TestWithParam<double> {
+ public:
+ static constexpr int kNumRowGroups = 16;
+
+ struct {
+ int num_rows = 1024 * kNumRowGroups;
+ int num_row_groups = kNumRowGroups;
+ int num_uniques = 128;
+ } options;
+
+ void SetUp() override {
+ properties_ = default_arrow_reader_properties();
+
+ GenerateData(GetParam());
+ }
+
+ void GenerateData(double null_probability) {
+ constexpr int64_t min_length = 2;
+ constexpr int64_t max_length = 100;
+ ::arrow::random::RandomArrayGenerator rag(0);
+ dense_values_ = rag.StringWithRepeats(options.num_rows, options.num_uniques,
+ min_length, max_length, null_probability);
+ expected_dense_ = MakeSimpleTable(dense_values_, /*nullable=*/true);
+ }
+
+ void TearDown() override {}
+
+ void WriteSimple() {
+ // Write `num_row_groups` row groups; each row group will have a different dictionary
+ ASSERT_NO_FATAL_FAILURE(
+ WriteTableToBuffer(expected_dense_, options.num_rows / options.num_row_groups,
+ default_arrow_writer_properties(), &buffer_));
+ }
+
+ void CheckReadWholeFile(const Table& expected) {
+ ASSERT_OK_AND_ASSIGN(auto reader, GetReader());
+
+ std::shared_ptr<Table> actual;
+ ASSERT_OK_NO_THROW(reader->ReadTable(&actual));
+ ::arrow::AssertTablesEqual(expected, *actual, /*same_chunk_layout=*/false);
+ }
+
+ void CheckStreamReadWholeFile(const Table& expected) {
+ ASSERT_OK_AND_ASSIGN(auto reader, GetReader());
+
+ std::unique_ptr<::arrow::RecordBatchReader> rb;
+ ASSERT_OK(reader->GetRecordBatchReader(
+ ::arrow::internal::Iota(options.num_row_groups), &rb));
+
+ std::shared_ptr<Table> actual;
+ ASSERT_OK_NO_THROW(rb->ReadAll(&actual));
+ ::arrow::AssertTablesEqual(expected, *actual, /*same_chunk_layout=*/false);
+ }
+
+ static std::vector<double> null_probabilities() { return {0.0, 0.5, 1}; }
+
+ protected:
+ std::shared_ptr<Array> dense_values_;
+ std::shared_ptr<Table> expected_dense_;
+ std::shared_ptr<Table> expected_dict_;
+ std::shared_ptr<Buffer> buffer_;
+ ArrowReaderProperties properties_;
+
+ ::arrow::Result<std::unique_ptr<FileReader>> GetReader() {
+ std::unique_ptr<FileReader> reader;
+
+ FileReaderBuilder builder;
+ RETURN_NOT_OK(builder.Open(std::make_shared<BufferReader>(buffer_)));
+ RETURN_NOT_OK(builder.properties(properties_)->Build(&reader));
+
+ return std::move(reader);
+ }
+};
+
+void AsDictionary32Encoded(const Array& arr, std::shared_ptr<Array>* out) {
+ ::arrow::StringDictionary32Builder builder(default_memory_pool());
+ const auto& string_array = static_cast<const ::arrow::StringArray&>(arr);
+ ASSERT_OK(builder.AppendArray(string_array));
+ ASSERT_OK(builder.Finish(out));
+}
+
+TEST_P(TestArrowReadDictionary, ReadWholeFileDict) {
+ properties_.set_read_dictionary(0, true);
+
+ WriteSimple();
+
+ auto num_row_groups = options.num_row_groups;
+ auto chunk_size = options.num_rows / num_row_groups;
+
+ std::vector<std::shared_ptr<Array>> chunks(num_row_groups);
+ for (int i = 0; i < num_row_groups; ++i) {
+ AsDictionary32Encoded(*dense_values_->Slice(chunk_size * i, chunk_size), &chunks[i]);
+ }
+ auto ex_table = MakeSimpleTable(std::make_shared<ChunkedArray>(chunks),
+ /*nullable=*/true);
+ CheckReadWholeFile(*ex_table);
+}
+
+TEST_P(TestArrowReadDictionary, ZeroChunksListOfDictionary) {
+ // ARROW-8799
+ properties_.set_read_dictionary(0, true);
+ dense_values_.reset();
+ auto values = std::make_shared<ChunkedArray>(::arrow::ArrayVector{},
+ ::arrow::list(::arrow::utf8()));
+ options.num_rows = 0;
+ options.num_uniques = 0;
+ options.num_row_groups = 1;
+ expected_dense_ = MakeSimpleTable(values, false);
+
+ WriteSimple();
+
+ ASSERT_OK_AND_ASSIGN(auto reader, GetReader());
+
+ std::unique_ptr<ColumnReader> column_reader;
+ ASSERT_OK_NO_THROW(reader->GetColumn(0, &column_reader));
+
+ std::shared_ptr<ChunkedArray> chunked_out;
+ ASSERT_OK(column_reader->NextBatch(1 << 15, &chunked_out));
+
+ ASSERT_EQ(chunked_out->length(), 0);
+ ASSERT_EQ(chunked_out->num_chunks(), 1);
+}
+
+TEST_P(TestArrowReadDictionary, IncrementalReads) {
+ // ARROW-6895
+ options.num_rows = 100;
+ options.num_uniques = 10;
+ SetUp();
+
+ properties_.set_read_dictionary(0, true);
+
+ // Just write a single row group
+ ASSERT_NO_FATAL_FAILURE(WriteTableToBuffer(
+ expected_dense_, options.num_rows, default_arrow_writer_properties(), &buffer_));
+
+ // Read in one shot
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr<FileReader> reader, GetReader());
+ std::shared_ptr<Table> expected;
+ ASSERT_OK_NO_THROW(reader->ReadTable(&expected));
+
+ ASSERT_OK_AND_ASSIGN(reader, GetReader());
+ std::unique_ptr<ColumnReader> col;
+ ASSERT_OK(reader->GetColumn(0, &col));
+
+ int num_reads = 4;
+ int batch_size = options.num_rows / num_reads;
+
+ for (int i = 0; i < num_reads; ++i) {
+ std::shared_ptr<ChunkedArray> chunk;
+ ASSERT_OK(col->NextBatch(batch_size, &chunk));
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> result_dense,
+ ::arrow::compute::Cast(*chunk->chunk(0), ::arrow::utf8()));
+ AssertArraysEqual(*dense_values_->Slice(i * batch_size, batch_size), *result_dense);
+ }
+}
+
+TEST_P(TestArrowReadDictionary, StreamReadWholeFileDict) {
+ // ARROW-6895 and ARROW-7545 reading a parquet file with a dictionary of
+ // binary data, e.g. String, will return invalid values when using the
+ // RecordBatchReader (stream) interface. In some cases, this will trigger an
+ // infinite loop of the calling thread.
+
+ // Recompute generated data with only one row-group
+ options.num_row_groups = 1;
+ options.num_rows = 16;
+ options.num_uniques = 7;
+ SetUp();
+ WriteSimple();
+
+ // Would trigger an infinite loop when requesting a batch greater than the
+ // number of available rows in a row group.
+ properties_.set_batch_size(options.num_rows * 2);
+ CheckStreamReadWholeFile(*expected_dense_);
+}
+
+TEST_P(TestArrowReadDictionary, ReadWholeFileDense) {
+ properties_.set_read_dictionary(0, false);
+ WriteSimple();
+ CheckReadWholeFile(*expected_dense_);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ ReadDictionary, TestArrowReadDictionary,
+ ::testing::ValuesIn(TestArrowReadDictionary::null_probabilities()));
+
+TEST(TestArrowWriteDictionaries, ChangingDictionaries) {
+ constexpr int num_unique = 50;
+ constexpr int repeat = 10000;
+ constexpr int64_t min_length = 2;
+ constexpr int64_t max_length = 20;
+ ::arrow::random::RandomArrayGenerator rag(0);
+ auto values = rag.StringWithRepeats(repeat * num_unique, num_unique, min_length,
+ max_length, /*null_probability=*/0.1);
+ auto expected = MakeSimpleTable(values, /*nullable=*/true);
+
+ const int num_chunks = 10;
+ std::vector<std::shared_ptr<Array>> chunks(num_chunks);
+ const int64_t chunk_size = values->length() / num_chunks;
+ for (int i = 0; i < num_chunks; ++i) {
+ AsDictionary32Encoded(*values->Slice(chunk_size * i, chunk_size), &chunks[i]);
+ }
+
+ auto dict_table = MakeSimpleTable(std::make_shared<ChunkedArray>(chunks),
+ /*nullable=*/true);
+
+ std::shared_ptr<Table> actual;
+ DoRoundtrip(dict_table, /*row_group_size=*/values->length() / 2, &actual);
+ ::arrow::AssertTablesEqual(*expected, *actual, /*same_chunk_layout=*/false);
+}
+
+TEST(TestArrowWriteDictionaries, AutoReadAsDictionary) {
+ constexpr int num_unique = 50;
+ constexpr int repeat = 100;
+ constexpr int64_t min_length = 2;
+ constexpr int64_t max_length = 20;
+ ::arrow::random::RandomArrayGenerator rag(0);
+ auto values = rag.StringWithRepeats(repeat * num_unique, num_unique, min_length,
+ max_length, /*null_probability=*/0.1);
+ std::shared_ptr<Array> dict_values;
+ AsDictionary32Encoded(*values, &dict_values);
+
+ auto expected = MakeSimpleTable(dict_values, /*nullable=*/true);
+ auto expected_dense = MakeSimpleTable(values, /*nullable=*/true);
+
+ auto props_store_schema = ArrowWriterProperties::Builder().store_schema()->build();
+ std::shared_ptr<Table> actual, actual_dense;
+
+ DoRoundtrip(expected, values->length(), &actual, default_writer_properties(),
+ props_store_schema);
+ ::arrow::AssertTablesEqual(*expected, *actual);
+
+ auto props_no_store_schema = ArrowWriterProperties::Builder().build();
+ DoRoundtrip(expected, values->length(), &actual_dense, default_writer_properties(),
+ props_no_store_schema);
+ ::arrow::AssertTablesEqual(*expected_dense, *actual_dense);
+}
+
+TEST(TestArrowWriteDictionaries, NestedSubfield) {
+ auto offsets = ::arrow::ArrayFromJSON(::arrow::int32(), "[0, 0, 2, 3]");
+ auto indices = ::arrow::ArrayFromJSON(::arrow::int32(), "[0, 0, 0]");
+ auto dict = ::arrow::ArrayFromJSON(::arrow::utf8(), "[\"foo\"]");
+
+ auto dict_ty = ::arrow::dictionary(::arrow::int32(), ::arrow::utf8());
+ ASSERT_OK_AND_ASSIGN(auto dict_values,
+ ::arrow::DictionaryArray::FromArrays(dict_ty, indices, dict));
+ ASSERT_OK_AND_ASSIGN(auto values,
+ ::arrow::ListArray::FromArrays(*offsets, *dict_values));
+
+ auto table = MakeSimpleTable(values, /*nullable=*/true);
+
+ auto props_store_schema = ArrowWriterProperties::Builder().store_schema()->build();
+ std::shared_ptr<Table> actual;
+ DoRoundtrip(table, values->length(), &actual, default_writer_properties(),
+ props_store_schema);
+
+ ::arrow::AssertTablesEqual(*table, *actual);
+}
+
+#ifdef ARROW_CSV
+TEST(TestArrowReadDeltaEncoding, DeltaBinaryPacked) {
+ auto file = test::get_data_file("delta_binary_packed.parquet");
+ auto expect_file = test::get_data_file("delta_binary_packed_expect.csv");
+ auto pool = ::arrow::default_memory_pool();
+ std::unique_ptr<FileReader> parquet_reader;
+ std::shared_ptr<::arrow::Table> table;
+ ASSERT_OK(
+ FileReader::Make(pool, ParquetFileReader::OpenFile(file, false), &parquet_reader));
+ ASSERT_OK(parquet_reader->ReadTable(&table));
+
+ ASSERT_OK_AND_ASSIGN(auto input_file, ::arrow::io::ReadableFile::Open(expect_file));
+ auto convert_options = ::arrow::csv::ConvertOptions::Defaults();
+ for (int i = 0; i <= 64; ++i) {
+ std::string column_name = "bitwidth" + std::to_string(i);
+ convert_options.column_types[column_name] = ::arrow::int64();
+ }
+ convert_options.column_types["int_value"] = ::arrow::int32();
+ ASSERT_OK_AND_ASSIGN(auto csv_reader,
+ ::arrow::csv::TableReader::Make(
+ ::arrow::io::default_io_context(), input_file,
+ ::arrow::csv::ReadOptions::Defaults(),
+ ::arrow::csv::ParseOptions::Defaults(), convert_options));
+ ASSERT_OK_AND_ASSIGN(auto expect_table, csv_reader->Read());
+
+ ::arrow::AssertTablesEqual(*table, *expect_table);
+}
+#else
+TEST(TestArrowReadDeltaEncoding, DeltaBinaryPacked) {
+ GTEST_SKIP() << "Test needs CSV reader";
+}
+#endif
+
+struct NestedFilterTestCase {
+ std::shared_ptr<::arrow::DataType> write_schema;
+ std::vector<int> indices_to_read;
+ std::shared_ptr<::arrow::DataType> expected_schema;
+ std::string write_data;
+ std::string read_data;
+
+ // For Valgrind
+ friend std::ostream& operator<<(std::ostream& os, const NestedFilterTestCase& param) {
+ os << "NestedFilterTestCase{write_schema = " << param.write_schema->ToString() << "}";
+ return os;
+ }
+};
+class TestNestedSchemaFilteredReader
+ : public ::testing::TestWithParam<NestedFilterTestCase> {};
+
+TEST_P(TestNestedSchemaFilteredReader, ReadWrite) {
+ std::shared_ptr<::arrow::io::BufferOutputStream> sink = CreateOutputStream();
+ auto write_props = WriterProperties::Builder().build();
+ std::shared_ptr<::arrow::Array> array =
+ ArrayFromJSON(GetParam().write_schema, GetParam().write_data);
+
+ ASSERT_OK_NO_THROW(
+ WriteTable(**Table::FromRecordBatches({::arrow::RecordBatch::Make(
+ ::arrow::schema({::arrow::field("col", array->type())}),
+ array->length(), {array})}),
+ ::arrow::default_memory_pool(), sink, /*chunk_size=*/100, write_props,
+ ArrowWriterProperties::Builder().store_schema()->build()));
+ std::shared_ptr<::arrow::Buffer> buffer;
+ ASSERT_OK_AND_ASSIGN(buffer, sink->Finish());
+
+ std::unique_ptr<FileReader> reader;
+ FileReaderBuilder builder;
+ ASSERT_OK_NO_THROW(builder.Open(std::make_shared<BufferReader>(buffer)));
+ ASSERT_OK(builder.properties(default_arrow_reader_properties())->Build(&reader));
+ std::shared_ptr<::arrow::Table> read_table;
+ ASSERT_OK_NO_THROW(reader->ReadTable(GetParam().indices_to_read, &read_table));
+
+ std::shared_ptr<::arrow::Array> expected =
+ ArrayFromJSON(GetParam().expected_schema, GetParam().read_data);
+ AssertArraysEqual(*read_table->column(0)->chunk(0), *expected, /*verbose=*/true);
+}
+
+std::vector<NestedFilterTestCase> GenerateListFilterTestCases() {
+ auto struct_type = ::arrow::struct_(
+ {::arrow::field("a", ::arrow::int64()), ::arrow::field("b", ::arrow::int64())});
+
+ constexpr auto kWriteData = R"([[{"a": 1, "b": 2}]])";
+ constexpr auto kReadData = R"([[{"a": 1}]])";
+
+ std::vector<NestedFilterTestCase> cases;
+ auto first_selected_type = ::arrow::struct_({struct_type->field(0)});
+ cases.push_back({::arrow::list(struct_type),
+ /*indices=*/{0}, ::arrow::list(first_selected_type), kWriteData,
+ kReadData});
+ cases.push_back({::arrow::large_list(struct_type),
+ /*indices=*/{0}, ::arrow::large_list(first_selected_type), kWriteData,
+ kReadData});
+ cases.push_back({::arrow::fixed_size_list(struct_type, /*list_size=*/1),
+ /*indices=*/{0},
+ ::arrow::fixed_size_list(first_selected_type, /*list_size=*/1),
+ kWriteData, kReadData});
+ return cases;
+}
+
+INSTANTIATE_TEST_SUITE_P(ListFilteredReads, TestNestedSchemaFilteredReader,
+ ::testing::ValuesIn(GenerateListFilterTestCases()));
+
+std::vector<NestedFilterTestCase> GenerateNestedStructFilteredTestCases() {
+ using ::arrow::field;
+ using ::arrow::struct_;
+ auto struct_type = struct_(
+ {field("t1", struct_({field("a", ::arrow::int64()), field("b", ::arrow::int64())})),
+ field("t2", ::arrow::int64())});
+
+ constexpr auto kWriteData = R"([{"t1": {"a": 1, "b":2}, "t2": 3}])";
+
+ std::vector<NestedFilterTestCase> cases;
+ auto selected_type = ::arrow::struct_(
+ {field("t1", struct_({field("a", ::arrow::int64())})), struct_type->field(1)});
+ cases.push_back({struct_type,
+ /*indices=*/{0, 2}, selected_type, kWriteData,
+ /*expected=*/R"([{"t1": {"a": 1}, "t2": 3}])"});
+ selected_type = ::arrow::struct_(
+ {field("t1", struct_({field("b", ::arrow::int64())})), struct_type->field(1)});
+
+ cases.push_back({struct_type,
+ /*indices=*/{1, 2}, selected_type, kWriteData,
+ /*expected=*/R"([{"t1": {"b": 2}, "t2": 3}])"});
+
+ return cases;
+}
+
+INSTANTIATE_TEST_SUITE_P(StructFilteredReads, TestNestedSchemaFilteredReader,
+ ::testing::ValuesIn(GenerateNestedStructFilteredTestCases()));
+
+std::vector<NestedFilterTestCase> GenerateMapFilteredTestCases() {
+ using ::arrow::field;
+ using ::arrow::struct_;
+ auto map_type = std::static_pointer_cast<::arrow::MapType>(::arrow::map(
+ struct_({field("a", ::arrow::int64()), field("b", ::arrow::int64())}),
+ struct_({field("c", ::arrow::int64()), field("d", ::arrow::int64())})));
+
+ constexpr auto kWriteData = R"([[[{"a": 0, "b": 1}, {"c": 2, "d": 3}]]])";
+ std::vector<NestedFilterTestCase> cases;
+ // Remove the value element completely converts to a list of struct.
+ cases.push_back(
+ {map_type,
+ /*indices=*/{0, 1},
+ /*selected_type=*/
+ ::arrow::list(field("col", struct_({map_type->key_field()}), /*nullable=*/false)),
+ kWriteData, /*expected_data=*/R"([[{"key": {"a": 0, "b":1}}]])"});
+ // The "col" field name below comes from how naming is done when writing out the
+ // array (it is assigned the column name col.
+
+ // Removing the full key converts to a list of struct.
+ cases.push_back(
+ {map_type,
+ /*indices=*/{3},
+ /*selected_type=*/
+ ::arrow::list(field(
+ "col", struct_({field("value", struct_({field("d", ::arrow::int64())}))}),
+ /*nullable=*/false)),
+ kWriteData, /*expected_data=*/R"([[{"value": {"d": 3}}]])"});
+ // Selecting the full key and a value maintains the map
+ cases.push_back(
+ {map_type, /*indices=*/{0, 1, 2},
+ /*selected_type=*/
+ ::arrow::map(map_type->key_type(), struct_({field("c", ::arrow::int64())})),
+ kWriteData, /*expected=*/R"([[[{"a": 0, "b": 1}, {"c": 2}]]])"});
+
+ // Selecting the partial key (with some part of the value converts to
+ // list of structs (because the key might no longer be unique).
+ cases.push_back(
+ {map_type, /*indices=*/{1, 2, 3},
+ /*selected_type=*/
+ ::arrow::list(field("col",
+ struct_({field("key", struct_({field("b", ::arrow::int64())}),
+ /*nullable=*/false),
+ map_type->item_field()}),
+ /*nullable=*/false)),
+ kWriteData, /*expected=*/R"([[{"key":{"b": 1}, "value": {"c": 2, "d": 3}}]])"});
+
+ return cases;
+}
+
+INSTANTIATE_TEST_SUITE_P(MapFilteredReads, TestNestedSchemaFilteredReader,
+ ::testing::ValuesIn(GenerateMapFilteredTestCases()));
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/arrow_schema_test.cc b/src/arrow/cpp/src/parquet/arrow/arrow_schema_test.cc
new file mode 100644
index 000000000..99ce0c962
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/arrow_schema_test.cc
@@ -0,0 +1,1701 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "parquet/arrow/reader.h"
+#include "parquet/arrow/reader_internal.h"
+#include "parquet/arrow/schema.h"
+#include "parquet/file_reader.h"
+#include "parquet/schema.h"
+#include "parquet/schema_internal.h"
+#include "parquet/test_util.h"
+#include "parquet/thrift_internal.h"
+
+#include "arrow/array.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/key_value_metadata.h"
+
+using arrow::ArrayFromVector;
+using arrow::Field;
+using arrow::TimeUnit;
+
+using ParquetType = parquet::Type;
+using parquet::ConvertedType;
+using parquet::LogicalType;
+using parquet::Repetition;
+using parquet::format::SchemaElement;
+using parquet::internal::LevelInfo;
+using parquet::schema::GroupNode;
+using parquet::schema::NodePtr;
+using parquet::schema::PrimitiveNode;
+
+using ::testing::ElementsAre;
+
+namespace parquet {
+
+namespace arrow {
+
+const auto BOOL = ::arrow::boolean();
+const auto UINT8 = ::arrow::uint8();
+const auto INT32 = ::arrow::int32();
+const auto INT64 = ::arrow::int64();
+const auto FLOAT = ::arrow::float32();
+const auto DOUBLE = ::arrow::float64();
+const auto UTF8 = ::arrow::utf8();
+const auto TIMESTAMP_MS = ::arrow::timestamp(TimeUnit::MILLI);
+const auto TIMESTAMP_US = ::arrow::timestamp(TimeUnit::MICRO);
+const auto TIMESTAMP_NS = ::arrow::timestamp(TimeUnit::NANO);
+const auto BINARY = ::arrow::binary();
+const auto DECIMAL_8_4 = std::make_shared<::arrow::Decimal128Type>(8, 4);
+
+class TestConvertParquetSchema : public ::testing::Test {
+ public:
+ virtual void SetUp() {}
+
+ void CheckFlatSchema(const std::shared_ptr<::arrow::Schema>& expected_schema,
+ bool check_metadata = false) {
+ ASSERT_EQ(expected_schema->num_fields(), result_schema_->num_fields());
+ for (int i = 0; i < expected_schema->num_fields(); ++i) {
+ auto result_field = result_schema_->field(i);
+ auto expected_field = expected_schema->field(i);
+ EXPECT_TRUE(result_field->Equals(expected_field, check_metadata))
+ << "Field " << i << "\n result: " << result_field->ToString()
+ << "\n expected: " << expected_field->ToString();
+ }
+ }
+
+ ::arrow::Status ConvertSchema(
+ const std::vector<NodePtr>& nodes,
+ const std::shared_ptr<const KeyValueMetadata>& key_value_metadata = nullptr) {
+ NodePtr schema = GroupNode::Make("schema", Repetition::REPEATED, nodes);
+ descr_.Init(schema);
+ ArrowReaderProperties props;
+ return FromParquetSchema(&descr_, props, key_value_metadata, &result_schema_);
+ }
+
+ protected:
+ SchemaDescriptor descr_;
+ std::shared_ptr<::arrow::Schema> result_schema_;
+};
+
+TEST_F(TestConvertParquetSchema, ParquetFlatPrimitives) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("boolean", Repetition::REQUIRED, ParquetType::BOOLEAN));
+ arrow_fields.push_back(::arrow::field("boolean", BOOL, false));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("int32", Repetition::REQUIRED, ParquetType::INT32));
+ arrow_fields.push_back(::arrow::field("int32", INT32, false));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("int64", Repetition::REQUIRED, ParquetType::INT64));
+ arrow_fields.push_back(::arrow::field("int64", INT64, false));
+
+ parquet_fields.push_back(PrimitiveNode::Make("timestamp", Repetition::REQUIRED,
+ ParquetType::INT64,
+ ConvertedType::TIMESTAMP_MILLIS));
+ arrow_fields.push_back(
+ ::arrow::field("timestamp", ::arrow::timestamp(TimeUnit::MILLI), false));
+
+ parquet_fields.push_back(PrimitiveNode::Make("timestamp[us]", Repetition::REQUIRED,
+ ParquetType::INT64,
+ ConvertedType::TIMESTAMP_MICROS));
+ arrow_fields.push_back(
+ ::arrow::field("timestamp[us]", ::arrow::timestamp(TimeUnit::MICRO), false));
+
+ parquet_fields.push_back(PrimitiveNode::Make("date", Repetition::REQUIRED,
+ ParquetType::INT32, ConvertedType::DATE));
+ arrow_fields.push_back(::arrow::field("date", ::arrow::date32(), false));
+
+ parquet_fields.push_back(PrimitiveNode::Make(
+ "time32", Repetition::REQUIRED, ParquetType::INT32, ConvertedType::TIME_MILLIS));
+ arrow_fields.push_back(
+ ::arrow::field("time32", ::arrow::time32(TimeUnit::MILLI), false));
+
+ parquet_fields.push_back(PrimitiveNode::Make(
+ "time64", Repetition::REQUIRED, ParquetType::INT64, ConvertedType::TIME_MICROS));
+ arrow_fields.push_back(
+ ::arrow::field("time64", ::arrow::time64(TimeUnit::MICRO), false));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("timestamp96", Repetition::REQUIRED, ParquetType::INT96));
+ arrow_fields.push_back(::arrow::field("timestamp96", TIMESTAMP_NS, false));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("float", Repetition::OPTIONAL, ParquetType::FLOAT));
+ arrow_fields.push_back(::arrow::field("float", FLOAT));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("double", Repetition::OPTIONAL, ParquetType::DOUBLE));
+ arrow_fields.push_back(::arrow::field("double", DOUBLE));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("binary", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY));
+ arrow_fields.push_back(::arrow::field("binary", BINARY));
+
+ parquet_fields.push_back(PrimitiveNode::Make(
+ "string", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY, ConvertedType::UTF8));
+ arrow_fields.push_back(::arrow::field("string", UTF8));
+
+ parquet_fields.push_back(PrimitiveNode::Make("flba-binary", Repetition::OPTIONAL,
+ ParquetType::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::NONE, 12));
+ arrow_fields.push_back(::arrow::field("flba-binary", ::arrow::fixed_size_binary(12)));
+
+ auto arrow_schema = ::arrow::schema(arrow_fields);
+ ASSERT_OK(ConvertSchema(parquet_fields));
+
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(arrow_schema));
+}
+
+TEST_F(TestConvertParquetSchema, ParquetAnnotatedFields) {
+ struct FieldConstructionArguments {
+ std::string name;
+ std::shared_ptr<const LogicalType> logical_type;
+ parquet::Type::type physical_type;
+ int physical_length;
+ std::shared_ptr<::arrow::DataType> datatype;
+ };
+
+ std::vector<FieldConstructionArguments> cases = {
+ {"string", LogicalType::String(), ParquetType::BYTE_ARRAY, -1, ::arrow::utf8()},
+ {"enum", LogicalType::Enum(), ParquetType::BYTE_ARRAY, -1, ::arrow::binary()},
+ {"decimal(8, 2)", LogicalType::Decimal(8, 2), ParquetType::INT32, -1,
+ ::arrow::decimal(8, 2)},
+ {"decimal(16, 4)", LogicalType::Decimal(16, 4), ParquetType::INT64, -1,
+ ::arrow::decimal(16, 4)},
+ {"decimal(32, 8)", LogicalType::Decimal(32, 8), ParquetType::FIXED_LEN_BYTE_ARRAY,
+ 16, ::arrow::decimal(32, 8)},
+ {"date", LogicalType::Date(), ParquetType::INT32, -1, ::arrow::date32()},
+ {"time(ms)", LogicalType::Time(true, LogicalType::TimeUnit::MILLIS),
+ ParquetType::INT32, -1, ::arrow::time32(::arrow::TimeUnit::MILLI)},
+ {"time(us)", LogicalType::Time(true, LogicalType::TimeUnit::MICROS),
+ ParquetType::INT64, -1, ::arrow::time64(::arrow::TimeUnit::MICRO)},
+ {"time(ns)", LogicalType::Time(true, LogicalType::TimeUnit::NANOS),
+ ParquetType::INT64, -1, ::arrow::time64(::arrow::TimeUnit::NANO)},
+ {"time(ms)", LogicalType::Time(false, LogicalType::TimeUnit::MILLIS),
+ ParquetType::INT32, -1, ::arrow::time32(::arrow::TimeUnit::MILLI)},
+ {"time(us)", LogicalType::Time(false, LogicalType::TimeUnit::MICROS),
+ ParquetType::INT64, -1, ::arrow::time64(::arrow::TimeUnit::MICRO)},
+ {"time(ns)", LogicalType::Time(false, LogicalType::TimeUnit::NANOS),
+ ParquetType::INT64, -1, ::arrow::time64(::arrow::TimeUnit::NANO)},
+ {"timestamp(true, ms)", LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS),
+ ParquetType::INT64, -1, ::arrow::timestamp(::arrow::TimeUnit::MILLI, "UTC")},
+ {"timestamp(true, us)", LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS),
+ ParquetType::INT64, -1, ::arrow::timestamp(::arrow::TimeUnit::MICRO, "UTC")},
+ {"timestamp(true, ns)", LogicalType::Timestamp(true, LogicalType::TimeUnit::NANOS),
+ ParquetType::INT64, -1, ::arrow::timestamp(::arrow::TimeUnit::NANO, "UTC")},
+ {"timestamp(false, ms)",
+ LogicalType::Timestamp(false, LogicalType::TimeUnit::MILLIS), ParquetType::INT64,
+ -1, ::arrow::timestamp(::arrow::TimeUnit::MILLI)},
+ {"timestamp(false, us)",
+ LogicalType::Timestamp(false, LogicalType::TimeUnit::MICROS), ParquetType::INT64,
+ -1, ::arrow::timestamp(::arrow::TimeUnit::MICRO)},
+ {"timestamp(false, ns)",
+ LogicalType::Timestamp(false, LogicalType::TimeUnit::NANOS), ParquetType::INT64,
+ -1, ::arrow::timestamp(::arrow::TimeUnit::NANO)},
+ {"int(8, false)", LogicalType::Int(8, false), ParquetType::INT32, -1,
+ ::arrow::uint8()},
+ {"int(8, true)", LogicalType::Int(8, true), ParquetType::INT32, -1,
+ ::arrow::int8()},
+ {"int(16, false)", LogicalType::Int(16, false), ParquetType::INT32, -1,
+ ::arrow::uint16()},
+ {"int(16, true)", LogicalType::Int(16, true), ParquetType::INT32, -1,
+ ::arrow::int16()},
+ {"int(32, false)", LogicalType::Int(32, false), ParquetType::INT32, -1,
+ ::arrow::uint32()},
+ {"int(32, true)", LogicalType::Int(32, true), ParquetType::INT32, -1,
+ ::arrow::int32()},
+ {"int(64, false)", LogicalType::Int(64, false), ParquetType::INT64, -1,
+ ::arrow::uint64()},
+ {"int(64, true)", LogicalType::Int(64, true), ParquetType::INT64, -1,
+ ::arrow::int64()},
+ {"json", LogicalType::JSON(), ParquetType::BYTE_ARRAY, -1, ::arrow::binary()},
+ {"bson", LogicalType::BSON(), ParquetType::BYTE_ARRAY, -1, ::arrow::binary()},
+ {"interval", LogicalType::Interval(), ParquetType::FIXED_LEN_BYTE_ARRAY, 12,
+ ::arrow::fixed_size_binary(12)},
+ {"uuid", LogicalType::UUID(), ParquetType::FIXED_LEN_BYTE_ARRAY, 16,
+ ::arrow::fixed_size_binary(16)},
+ {"none", LogicalType::None(), ParquetType::BOOLEAN, -1, ::arrow::boolean()},
+ {"none", LogicalType::None(), ParquetType::INT32, -1, ::arrow::int32()},
+ {"none", LogicalType::None(), ParquetType::INT64, -1, ::arrow::int64()},
+ {"none", LogicalType::None(), ParquetType::FLOAT, -1, ::arrow::float32()},
+ {"none", LogicalType::None(), ParquetType::DOUBLE, -1, ::arrow::float64()},
+ {"none", LogicalType::None(), ParquetType::BYTE_ARRAY, -1, ::arrow::binary()},
+ {"none", LogicalType::None(), ParquetType::FIXED_LEN_BYTE_ARRAY, 64,
+ ::arrow::fixed_size_binary(64)},
+ {"null", LogicalType::Null(), ParquetType::BYTE_ARRAY, -1, ::arrow::null()},
+ };
+
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ for (const FieldConstructionArguments& c : cases) {
+ parquet_fields.push_back(PrimitiveNode::Make(c.name, Repetition::OPTIONAL,
+ c.logical_type, c.physical_type,
+ c.physical_length));
+ arrow_fields.push_back(::arrow::field(c.name, c.datatype));
+ }
+
+ ASSERT_OK(ConvertSchema(parquet_fields));
+ auto arrow_schema = ::arrow::schema(arrow_fields);
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(arrow_schema));
+}
+
+TEST_F(TestConvertParquetSchema, DuplicateFieldNames) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("xxx", Repetition::REQUIRED, ParquetType::BOOLEAN));
+ auto arrow_field1 = ::arrow::field("xxx", BOOL, false);
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("xxx", Repetition::REQUIRED, ParquetType::INT32));
+ auto arrow_field2 = ::arrow::field("xxx", INT32, false);
+
+ ASSERT_OK(ConvertSchema(parquet_fields));
+ arrow_fields = {arrow_field1, arrow_field2};
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(::arrow::schema(arrow_fields)));
+}
+
+TEST_F(TestConvertParquetSchema, ParquetKeyValueMetadata) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("boolean", Repetition::REQUIRED, ParquetType::BOOLEAN));
+ arrow_fields.push_back(::arrow::field("boolean", BOOL, false));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("int32", Repetition::REQUIRED, ParquetType::INT32));
+ arrow_fields.push_back(::arrow::field("int32", INT32, false));
+
+ auto key_value_metadata = std::make_shared<KeyValueMetadata>();
+ key_value_metadata->Append("foo", "bar");
+ key_value_metadata->Append("biz", "baz");
+ ASSERT_OK(ConvertSchema(parquet_fields, key_value_metadata));
+
+ auto arrow_metadata = result_schema_->metadata();
+ ASSERT_EQ("foo", arrow_metadata->key(0));
+ ASSERT_EQ("bar", arrow_metadata->value(0));
+ ASSERT_EQ("biz", arrow_metadata->key(1));
+ ASSERT_EQ("baz", arrow_metadata->value(1));
+}
+
+TEST_F(TestConvertParquetSchema, ParquetEmptyKeyValueMetadata) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("int32", Repetition::REQUIRED, ParquetType::INT32));
+ arrow_fields.push_back(::arrow::field("int32", INT32, false));
+
+ std::shared_ptr<KeyValueMetadata> key_value_metadata = nullptr;
+ ASSERT_OK(ConvertSchema(parquet_fields, key_value_metadata));
+
+ auto arrow_metadata = result_schema_->metadata();
+ ASSERT_EQ(arrow_metadata, nullptr);
+}
+
+TEST_F(TestConvertParquetSchema, ParquetFlatDecimals) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ parquet_fields.push_back(PrimitiveNode::Make("flba-decimal", Repetition::OPTIONAL,
+ ParquetType::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::DECIMAL, 4, 8, 4));
+ arrow_fields.push_back(::arrow::field("flba-decimal", DECIMAL_8_4));
+
+ parquet_fields.push_back(PrimitiveNode::Make("binary-decimal", Repetition::OPTIONAL,
+ ParquetType::BYTE_ARRAY,
+ ConvertedType::DECIMAL, -1, 8, 4));
+ arrow_fields.push_back(::arrow::field("binary-decimal", DECIMAL_8_4));
+
+ parquet_fields.push_back(PrimitiveNode::Make("int32-decimal", Repetition::OPTIONAL,
+ ParquetType::INT32, ConvertedType::DECIMAL,
+ -1, 8, 4));
+ arrow_fields.push_back(::arrow::field("int32-decimal", DECIMAL_8_4));
+
+ parquet_fields.push_back(PrimitiveNode::Make("int64-decimal", Repetition::OPTIONAL,
+ ParquetType::INT64, ConvertedType::DECIMAL,
+ -1, 8, 4));
+ arrow_fields.push_back(::arrow::field("int64-decimal", DECIMAL_8_4));
+
+ auto arrow_schema = ::arrow::schema(arrow_fields);
+ ASSERT_OK(ConvertSchema(parquet_fields));
+
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(arrow_schema));
+}
+
+TEST_F(TestConvertParquetSchema, ParquetMaps) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ // MAP encoding example taken from parquet-format/LogicalTypes.md
+
+ // Two column map.
+ {
+ auto key = PrimitiveNode::Make("key", Repetition::REQUIRED, ParquetType::BYTE_ARRAY,
+ ConvertedType::UTF8);
+ auto value = PrimitiveNode::Make("value", Repetition::OPTIONAL,
+ ParquetType::BYTE_ARRAY, ConvertedType::UTF8);
+
+ auto list = GroupNode::Make("key_value", Repetition::REPEATED, {key, value});
+ parquet_fields.push_back(
+ GroupNode::Make("my_map", Repetition::REQUIRED, {list}, LogicalType::Map()));
+ auto arrow_value = ::arrow::field("string", UTF8, /*nullable=*/true);
+ auto arrow_map = ::arrow::map(/*key=*/UTF8, arrow_value);
+ arrow_fields.push_back(::arrow::field("my_map", arrow_map, false));
+ }
+ // Single column map (i.e. set) gets converted to list of struct.
+ {
+ auto key = PrimitiveNode::Make("key", Repetition::REQUIRED, ParquetType::BYTE_ARRAY,
+ ConvertedType::UTF8);
+
+ auto list = GroupNode::Make("key_value", Repetition::REPEATED, {key});
+ parquet_fields.push_back(
+ GroupNode::Make("my_set", Repetition::REQUIRED, {list}, LogicalType::Map()));
+ auto arrow_list = ::arrow::list({::arrow::field("key", UTF8, /*nullable=*/false)});
+ arrow_fields.push_back(::arrow::field("my_set", arrow_list, false));
+ }
+
+ auto arrow_schema = ::arrow::schema(arrow_fields);
+ ASSERT_OK(ConvertSchema(parquet_fields));
+
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(arrow_schema));
+}
+
+TEST_F(TestConvertParquetSchema, ParquetLists) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ // LIST encoding example taken from parquet-format/LogicalTypes.md
+
+ // // List<String> (list non-null, elements nullable)
+ // required group my_list (LIST) {
+ // repeated group list {
+ // optional binary element (UTF8);
+ // }
+ // }
+ {
+ auto element = PrimitiveNode::Make("string", Repetition::OPTIONAL,
+ ParquetType::BYTE_ARRAY, ConvertedType::UTF8);
+ auto list = GroupNode::Make("list", Repetition::REPEATED, {element});
+ parquet_fields.push_back(
+ GroupNode::Make("my_list", Repetition::REQUIRED, {list}, ConvertedType::LIST));
+ auto arrow_element = ::arrow::field("string", UTF8, true);
+ auto arrow_list = ::arrow::list(arrow_element);
+ arrow_fields.push_back(::arrow::field("my_list", arrow_list, false));
+ }
+
+ // // List<String> (list nullable, elements non-null)
+ // optional group my_list (LIST) {
+ // repeated group list {
+ // required binary element (UTF8);
+ // }
+ // }
+ {
+ auto element = PrimitiveNode::Make("string", Repetition::REQUIRED,
+ ParquetType::BYTE_ARRAY, ConvertedType::UTF8);
+ auto list = GroupNode::Make("list", Repetition::REPEATED, {element});
+ parquet_fields.push_back(
+ GroupNode::Make("my_list", Repetition::OPTIONAL, {list}, ConvertedType::LIST));
+ auto arrow_element = ::arrow::field("string", UTF8, false);
+ auto arrow_list = ::arrow::list(arrow_element);
+ arrow_fields.push_back(::arrow::field("my_list", arrow_list, true));
+ }
+
+ // Element types can be nested structures. For example, a list of lists:
+ //
+ // // List<List<Integer>>
+ // optional group array_of_arrays (LIST) {
+ // repeated group list {
+ // required group element (LIST) {
+ // repeated group list {
+ // required int32 element;
+ // }
+ // }
+ // }
+ // }
+ {
+ auto inner_element =
+ PrimitiveNode::Make("int32", Repetition::REQUIRED, ParquetType::INT32);
+ auto inner_list = GroupNode::Make("list", Repetition::REPEATED, {inner_element});
+ auto element = GroupNode::Make("element", Repetition::REQUIRED, {inner_list},
+ ConvertedType::LIST);
+ auto list = GroupNode::Make("list", Repetition::REPEATED, {element});
+ parquet_fields.push_back(GroupNode::Make("array_of_arrays", Repetition::OPTIONAL,
+ {list}, ConvertedType::LIST));
+ auto arrow_inner_element = ::arrow::field("int32", INT32, false);
+ auto arrow_inner_list = ::arrow::list(arrow_inner_element);
+ auto arrow_element = ::arrow::field("element", arrow_inner_list, false);
+ auto arrow_list = ::arrow::list(arrow_element);
+ arrow_fields.push_back(::arrow::field("array_of_arrays", arrow_list, true));
+ }
+
+ // // List<String> (list nullable, elements non-null)
+ // optional group my_list (LIST) {
+ // repeated group element {
+ // required binary str (UTF8);
+ // };
+ // }
+ {
+ auto element = PrimitiveNode::Make("str", Repetition::REQUIRED,
+ ParquetType::BYTE_ARRAY, ConvertedType::UTF8);
+ auto list = GroupNode::Make("element", Repetition::REPEATED, {element});
+ parquet_fields.push_back(
+ GroupNode::Make("my_list", Repetition::OPTIONAL, {list}, ConvertedType::LIST));
+ auto arrow_element = ::arrow::field("str", UTF8, false);
+ auto arrow_list = ::arrow::list(arrow_element);
+ arrow_fields.push_back(::arrow::field("my_list", arrow_list, true));
+ }
+
+ // // List<Integer> (nullable list, non-null elements)
+ // optional group my_list (LIST) {
+ // repeated int32 element;
+ // }
+ {
+ auto element =
+ PrimitiveNode::Make("element", Repetition::REPEATED, ParquetType::INT32);
+ parquet_fields.push_back(
+ GroupNode::Make("my_list", Repetition::OPTIONAL, {element}, ConvertedType::LIST));
+ auto arrow_element = ::arrow::field("element", INT32, false);
+ auto arrow_list = ::arrow::list(arrow_element);
+ arrow_fields.push_back(::arrow::field("my_list", arrow_list, true));
+ }
+
+ // // List<Tuple<String, Integer>> (nullable list, non-null elements)
+ // optional group my_list (LIST) {
+ // repeated group element {
+ // required binary str (UTF8);
+ // required int32 num;
+ // };
+ // }
+ {
+ auto str_element = PrimitiveNode::Make("str", Repetition::REQUIRED,
+ ParquetType::BYTE_ARRAY, ConvertedType::UTF8);
+ auto num_element =
+ PrimitiveNode::Make("num", Repetition::REQUIRED, ParquetType::INT32);
+ auto element =
+ GroupNode::Make("element", Repetition::REPEATED, {str_element, num_element});
+ parquet_fields.push_back(
+ GroupNode::Make("my_list", Repetition::OPTIONAL, {element}, ConvertedType::LIST));
+ auto arrow_str = ::arrow::field("str", UTF8, false);
+ auto arrow_num = ::arrow::field("num", INT32, false);
+ std::vector<std::shared_ptr<Field>> fields({arrow_str, arrow_num});
+ auto arrow_struct = ::arrow::struct_(fields);
+ auto arrow_element = ::arrow::field("element", arrow_struct, false);
+ auto arrow_list = ::arrow::list(arrow_element);
+ arrow_fields.push_back(::arrow::field("my_list", arrow_list, true));
+ }
+
+ // // List<OneTuple<String>> (nullable list, non-null elements)
+ // optional group my_list (LIST) {
+ // repeated group array {
+ // required binary str (UTF8);
+ // };
+ // }
+ // Special case: group is named array
+ {
+ auto element = PrimitiveNode::Make("str", Repetition::REQUIRED,
+ ParquetType::BYTE_ARRAY, ConvertedType::UTF8);
+ auto array = GroupNode::Make("array", Repetition::REPEATED, {element});
+ parquet_fields.push_back(
+ GroupNode::Make("my_list", Repetition::OPTIONAL, {array}, ConvertedType::LIST));
+ auto arrow_str = ::arrow::field("str", UTF8, false);
+ std::vector<std::shared_ptr<Field>> fields({arrow_str});
+ auto arrow_struct = ::arrow::struct_(fields);
+ auto arrow_element = ::arrow::field("array", arrow_struct, false);
+ auto arrow_list = ::arrow::list(arrow_element);
+ arrow_fields.push_back(::arrow::field("my_list", arrow_list, true));
+ }
+
+ // // List<OneTuple<String>> (nullable list, non-null elements)
+ // optional group my_list (LIST) {
+ // repeated group my_list_tuple {
+ // required binary str (UTF8);
+ // };
+ // }
+ // Special case: group named ends in _tuple
+ {
+ auto element = PrimitiveNode::Make("str", Repetition::REQUIRED,
+ ParquetType::BYTE_ARRAY, ConvertedType::UTF8);
+ auto array = GroupNode::Make("my_list_tuple", Repetition::REPEATED, {element});
+ parquet_fields.push_back(
+ GroupNode::Make("my_list", Repetition::OPTIONAL, {array}, ConvertedType::LIST));
+ auto arrow_str = ::arrow::field("str", UTF8, false);
+ std::vector<std::shared_ptr<Field>> fields({arrow_str});
+ auto arrow_struct = ::arrow::struct_(fields);
+ auto arrow_element = ::arrow::field("my_list_tuple", arrow_struct, false);
+ auto arrow_list = ::arrow::list(arrow_element);
+ arrow_fields.push_back(::arrow::field("my_list", arrow_list, true));
+ }
+
+ // One-level encoding: Only allows required lists with required cells
+ // repeated value_type name
+ {
+ parquet_fields.push_back(
+ PrimitiveNode::Make("name", Repetition::REPEATED, ParquetType::INT32));
+ auto arrow_element = ::arrow::field("name", INT32, false);
+ auto arrow_list = ::arrow::list(arrow_element);
+ arrow_fields.push_back(::arrow::field("name", arrow_list, false));
+ }
+
+ auto arrow_schema = ::arrow::schema(arrow_fields);
+ ASSERT_OK(ConvertSchema(parquet_fields));
+
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(arrow_schema));
+}
+
+TEST_F(TestConvertParquetSchema, UnsupportedThings) {
+ std::vector<NodePtr> unsupported_nodes;
+
+ for (const NodePtr& node : unsupported_nodes) {
+ ASSERT_RAISES(NotImplemented, ConvertSchema({node}));
+ }
+}
+
+TEST_F(TestConvertParquetSchema, ParquetNestedSchema) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ // required group group1 {
+ // required bool leaf1;
+ // required int32 leaf2;
+ // }
+ // required int64 leaf3;
+ {
+ parquet_fields.push_back(GroupNode::Make(
+ "group1", Repetition::REQUIRED,
+ {PrimitiveNode::Make("leaf1", Repetition::REQUIRED, ParquetType::BOOLEAN),
+ PrimitiveNode::Make("leaf2", Repetition::REQUIRED, ParquetType::INT32)}));
+ parquet_fields.push_back(
+ PrimitiveNode::Make("leaf3", Repetition::REQUIRED, ParquetType::INT64));
+
+ auto group1_fields = {::arrow::field("leaf1", BOOL, false),
+ ::arrow::field("leaf2", INT32, false)};
+ auto arrow_group1_type = ::arrow::struct_(group1_fields);
+ arrow_fields.push_back(::arrow::field("group1", arrow_group1_type, false));
+ arrow_fields.push_back(::arrow::field("leaf3", INT64, false));
+ }
+
+ auto arrow_schema = ::arrow::schema(arrow_fields);
+ ASSERT_OK(ConvertSchema(parquet_fields));
+
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(arrow_schema));
+}
+
+TEST_F(TestConvertParquetSchema, ParquetNestedSchema2) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ // Full Parquet Schema:
+ // required group group1 {
+ // required int64 leaf1;
+ // required int64 leaf2;
+ // }
+ // required group group2 {
+ // required int64 leaf3;
+ // required int64 leaf4;
+ // }
+ // required int64 leaf5;
+ {
+ parquet_fields.push_back(GroupNode::Make(
+ "group1", Repetition::REQUIRED,
+ {PrimitiveNode::Make("leaf1", Repetition::REQUIRED, ParquetType::INT64),
+ PrimitiveNode::Make("leaf2", Repetition::REQUIRED, ParquetType::INT64)}));
+ parquet_fields.push_back(GroupNode::Make(
+ "group2", Repetition::REQUIRED,
+ {PrimitiveNode::Make("leaf3", Repetition::REQUIRED, ParquetType::INT64),
+ PrimitiveNode::Make("leaf4", Repetition::REQUIRED, ParquetType::INT64)}));
+ parquet_fields.push_back(
+ PrimitiveNode::Make("leaf5", Repetition::REQUIRED, ParquetType::INT64));
+
+ auto group1_fields = {::arrow::field("leaf1", INT64, false),
+ ::arrow::field("leaf2", INT64, false)};
+ auto arrow_group1_type = ::arrow::struct_(group1_fields);
+ auto group2_fields = {::arrow::field("leaf3", INT64, false),
+ ::arrow::field("leaf4", INT64, false)};
+ auto arrow_group2_type = ::arrow::struct_(group2_fields);
+ arrow_fields.push_back(::arrow::field("group1", arrow_group1_type, false));
+ arrow_fields.push_back(::arrow::field("group2", arrow_group2_type, false));
+ arrow_fields.push_back(::arrow::field("leaf5", INT64, false));
+ }
+
+ auto arrow_schema = ::arrow::schema(arrow_fields);
+ ASSERT_OK(ConvertSchema(parquet_fields));
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(arrow_schema));
+}
+
+TEST_F(TestConvertParquetSchema, ParquetRepeatedNestedSchema) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+ {
+ // optional int32 leaf1;
+ // repeated group outerGroup {
+ // optional int32 leaf2;
+ // repeated group innerGroup {
+ // optional int32 leaf3;
+ // }
+ // }
+ parquet_fields.push_back(
+ PrimitiveNode::Make("leaf1", Repetition::OPTIONAL, ParquetType::INT32));
+ parquet_fields.push_back(GroupNode::Make(
+ "outerGroup", Repetition::REPEATED,
+ {PrimitiveNode::Make("leaf2", Repetition::OPTIONAL, ParquetType::INT32),
+ GroupNode::Make(
+ "innerGroup", Repetition::REPEATED,
+ {PrimitiveNode::Make("leaf3", Repetition::OPTIONAL, ParquetType::INT32)})}));
+
+ auto inner_group_fields = {::arrow::field("leaf3", INT32, true)};
+ auto inner_group_type = ::arrow::struct_(inner_group_fields);
+ auto outer_group_fields = {
+ ::arrow::field("leaf2", INT32, true),
+ ::arrow::field(
+ "innerGroup",
+ ::arrow::list(::arrow::field("innerGroup", inner_group_type, false)), false)};
+ auto outer_group_type = ::arrow::struct_(outer_group_fields);
+
+ arrow_fields.push_back(::arrow::field("leaf1", INT32, true));
+ arrow_fields.push_back(::arrow::field(
+ "outerGroup",
+ ::arrow::list(::arrow::field("outerGroup", outer_group_type, false)), false));
+ }
+ auto arrow_schema = ::arrow::schema(arrow_fields);
+ ASSERT_OK(ConvertSchema(parquet_fields));
+
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(arrow_schema));
+}
+
+class TestConvertArrowSchema : public ::testing::Test {
+ public:
+ virtual void SetUp() {}
+
+ void CheckFlatSchema(const std::vector<NodePtr>& nodes) {
+ NodePtr schema_node = GroupNode::Make("schema", Repetition::REPEATED, nodes);
+ const GroupNode* expected_schema_node =
+ static_cast<const GroupNode*>(schema_node.get());
+ const GroupNode* result_schema_node = result_schema_->group_node();
+
+ ASSERT_EQ(expected_schema_node->field_count(), result_schema_node->field_count());
+
+ for (int i = 0; i < expected_schema_node->field_count(); i++) {
+ auto lhs = result_schema_node->field(i);
+ auto rhs = expected_schema_node->field(i);
+ EXPECT_TRUE(lhs->Equals(rhs.get()));
+ }
+ }
+
+ ::arrow::Status ConvertSchema(
+ const std::vector<std::shared_ptr<Field>>& fields,
+ std::shared_ptr<::parquet::ArrowWriterProperties> arrow_properties =
+ ::parquet::default_arrow_writer_properties()) {
+ arrow_schema_ = ::arrow::schema(fields);
+ std::shared_ptr<::parquet::WriterProperties> properties =
+ ::parquet::default_writer_properties();
+ return ToParquetSchema(arrow_schema_.get(), *properties.get(), *arrow_properties,
+ &result_schema_);
+ }
+
+ protected:
+ std::shared_ptr<::arrow::Schema> arrow_schema_;
+ std::shared_ptr<SchemaDescriptor> result_schema_;
+};
+
+TEST_F(TestConvertArrowSchema, ParquetFlatPrimitives) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("boolean", Repetition::REQUIRED, ParquetType::BOOLEAN));
+ arrow_fields.push_back(::arrow::field("boolean", BOOL, false));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("int32", Repetition::REQUIRED, ParquetType::INT32));
+ arrow_fields.push_back(::arrow::field("int32", INT32, false));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("int64", Repetition::REQUIRED, ParquetType::INT64));
+ arrow_fields.push_back(::arrow::field("int64", INT64, false));
+
+ parquet_fields.push_back(PrimitiveNode::Make("date", Repetition::REQUIRED,
+ ParquetType::INT32, ConvertedType::DATE));
+ arrow_fields.push_back(::arrow::field("date", ::arrow::date32(), false));
+
+ parquet_fields.push_back(PrimitiveNode::Make("date64", Repetition::REQUIRED,
+ ParquetType::INT32, ConvertedType::DATE));
+ arrow_fields.push_back(::arrow::field("date64", ::arrow::date64(), false));
+
+ parquet_fields.push_back(PrimitiveNode::Make("timestamp", Repetition::REQUIRED,
+ ParquetType::INT64,
+ ConvertedType::TIMESTAMP_MILLIS));
+ arrow_fields.push_back(
+ ::arrow::field("timestamp", ::arrow::timestamp(TimeUnit::MILLI, "UTC"), false));
+
+ parquet_fields.push_back(PrimitiveNode::Make("timestamp[us]", Repetition::REQUIRED,
+ ParquetType::INT64,
+ ConvertedType::TIMESTAMP_MICROS));
+ arrow_fields.push_back(
+ ::arrow::field("timestamp[us]", ::arrow::timestamp(TimeUnit::MICRO, "UTC"), false));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("float", Repetition::OPTIONAL, ParquetType::FLOAT));
+ arrow_fields.push_back(::arrow::field("float", FLOAT));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("double", Repetition::OPTIONAL, ParquetType::DOUBLE));
+ arrow_fields.push_back(::arrow::field("double", DOUBLE));
+
+ parquet_fields.push_back(PrimitiveNode::Make(
+ "string", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY, ConvertedType::UTF8));
+ arrow_fields.push_back(::arrow::field("string", UTF8));
+
+ parquet_fields.push_back(PrimitiveNode::Make(
+ "binary", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY, ConvertedType::NONE));
+ arrow_fields.push_back(::arrow::field("binary", BINARY));
+
+ ASSERT_OK(ConvertSchema(arrow_fields));
+
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(parquet_fields));
+}
+
+TEST_F(TestConvertArrowSchema, ArrowFields) {
+ struct FieldConstructionArguments {
+ std::string name;
+ std::shared_ptr<::arrow::DataType> datatype;
+ std::shared_ptr<const LogicalType> logical_type;
+ parquet::Type::type physical_type;
+ int physical_length;
+ };
+
+ std::vector<FieldConstructionArguments> cases = {
+ {"boolean", ::arrow::boolean(), LogicalType::None(), ParquetType::BOOLEAN, -1},
+ {"binary", ::arrow::binary(), LogicalType::None(), ParquetType::BYTE_ARRAY, -1},
+ {"large_binary", ::arrow::large_binary(), LogicalType::None(),
+ ParquetType::BYTE_ARRAY, -1},
+ {"fixed_size_binary", ::arrow::fixed_size_binary(64), LogicalType::None(),
+ ParquetType::FIXED_LEN_BYTE_ARRAY, 64},
+ {"uint8", ::arrow::uint8(), LogicalType::Int(8, false), ParquetType::INT32, -1},
+ {"int8", ::arrow::int8(), LogicalType::Int(8, true), ParquetType::INT32, -1},
+ {"uint16", ::arrow::uint16(), LogicalType::Int(16, false), ParquetType::INT32, -1},
+ {"int16", ::arrow::int16(), LogicalType::Int(16, true), ParquetType::INT32, -1},
+ {"uint32", ::arrow::uint32(), LogicalType::None(), ParquetType::INT64,
+ -1}, // Parquet 1.0
+ {"int32", ::arrow::int32(), LogicalType::None(), ParquetType::INT32, -1},
+ {"uint64", ::arrow::uint64(), LogicalType::Int(64, false), ParquetType::INT64, -1},
+ {"int64", ::arrow::int64(), LogicalType::None(), ParquetType::INT64, -1},
+ {"float32", ::arrow::float32(), LogicalType::None(), ParquetType::FLOAT, -1},
+ {"float64", ::arrow::float64(), LogicalType::None(), ParquetType::DOUBLE, -1},
+ {"utf8", ::arrow::utf8(), LogicalType::String(), ParquetType::BYTE_ARRAY, -1},
+ {"large_utf8", ::arrow::large_utf8(), LogicalType::String(),
+ ParquetType::BYTE_ARRAY, -1},
+ {"decimal(1, 0)", ::arrow::decimal(1, 0), LogicalType::Decimal(1, 0),
+ ParquetType::FIXED_LEN_BYTE_ARRAY, 1},
+ {"decimal(8, 2)", ::arrow::decimal(8, 2), LogicalType::Decimal(8, 2),
+ ParquetType::FIXED_LEN_BYTE_ARRAY, 4},
+ {"decimal(16, 4)", ::arrow::decimal(16, 4), LogicalType::Decimal(16, 4),
+ ParquetType::FIXED_LEN_BYTE_ARRAY, 7},
+ {"decimal(32, 8)", ::arrow::decimal(32, 8), LogicalType::Decimal(32, 8),
+ ParquetType::FIXED_LEN_BYTE_ARRAY, 14},
+ {"time32", ::arrow::time32(::arrow::TimeUnit::MILLI),
+ LogicalType::Time(true, LogicalType::TimeUnit::MILLIS), ParquetType::INT32, -1},
+ {"time64(microsecond)", ::arrow::time64(::arrow::TimeUnit::MICRO),
+ LogicalType::Time(true, LogicalType::TimeUnit::MICROS), ParquetType::INT64, -1},
+ {"time64(nanosecond)", ::arrow::time64(::arrow::TimeUnit::NANO),
+ LogicalType::Time(true, LogicalType::TimeUnit::NANOS), ParquetType::INT64, -1},
+ {"timestamp(millisecond)", ::arrow::timestamp(::arrow::TimeUnit::MILLI),
+ LogicalType::Timestamp(false, LogicalType::TimeUnit::MILLIS,
+ /*is_from_converted_type=*/false,
+ /*force_set_converted_type=*/true),
+ ParquetType::INT64, -1},
+ {"timestamp(microsecond)", ::arrow::timestamp(::arrow::TimeUnit::MICRO),
+ LogicalType::Timestamp(false, LogicalType::TimeUnit::MICROS,
+ /*is_from_converted_type=*/false,
+ /*force_set_converted_type=*/true),
+ ParquetType::INT64, -1},
+ // Parquet v1, values converted to microseconds
+ {"timestamp(nanosecond)", ::arrow::timestamp(::arrow::TimeUnit::NANO),
+ LogicalType::Timestamp(false, LogicalType::TimeUnit::MICROS,
+ /*is_from_converted_type=*/false,
+ /*force_set_converted_type=*/true),
+ ParquetType::INT64, -1},
+ {"timestamp(millisecond, UTC)", ::arrow::timestamp(::arrow::TimeUnit::MILLI, "UTC"),
+ LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS), ParquetType::INT64,
+ -1},
+ {"timestamp(microsecond, UTC)", ::arrow::timestamp(::arrow::TimeUnit::MICRO, "UTC"),
+ LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), ParquetType::INT64,
+ -1},
+ {"timestamp(nanosecond, UTC)", ::arrow::timestamp(::arrow::TimeUnit::NANO, "UTC"),
+ LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), ParquetType::INT64,
+ -1},
+ {"timestamp(millisecond, CET)", ::arrow::timestamp(::arrow::TimeUnit::MILLI, "CET"),
+ LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS), ParquetType::INT64,
+ -1},
+ {"timestamp(microsecond, CET)", ::arrow::timestamp(::arrow::TimeUnit::MICRO, "CET"),
+ LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), ParquetType::INT64,
+ -1},
+ {"timestamp(nanosecond, CET)", ::arrow::timestamp(::arrow::TimeUnit::NANO, "CET"),
+ LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), ParquetType::INT64,
+ -1}};
+
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+ std::vector<NodePtr> parquet_fields;
+
+ for (const FieldConstructionArguments& c : cases) {
+ arrow_fields.push_back(::arrow::field(c.name, c.datatype, false));
+ parquet_fields.push_back(PrimitiveNode::Make(c.name, Repetition::REQUIRED,
+ c.logical_type, c.physical_type,
+ c.physical_length));
+ }
+
+ ASSERT_OK(ConvertSchema(arrow_fields));
+ CheckFlatSchema(parquet_fields);
+ // ASSERT_NO_FATAL_FAILURE();
+}
+
+TEST_F(TestConvertArrowSchema, ArrowNonconvertibleFields) {
+ struct FieldConstructionArguments {
+ std::string name;
+ std::shared_ptr<::arrow::DataType> datatype;
+ };
+
+ std::vector<FieldConstructionArguments> cases = {
+ {"float16", ::arrow::float16()},
+ };
+
+ for (const FieldConstructionArguments& c : cases) {
+ auto field = ::arrow::field(c.name, c.datatype);
+ ASSERT_RAISES(NotImplemented, ConvertSchema({field}));
+ }
+}
+
+TEST_F(TestConvertArrowSchema, ParquetFlatPrimitivesAsDictionaries) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+ std::shared_ptr<::arrow::Array> dict;
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("int32", Repetition::REQUIRED, ParquetType::INT32));
+ arrow_fields.push_back(::arrow::field(
+ "int32", ::arrow::dictionary(::arrow::int8(), ::arrow::int32()), false));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("int64", Repetition::REQUIRED, ParquetType::INT64));
+ arrow_fields.push_back(::arrow::field(
+ "int64", ::arrow::dictionary(::arrow::int8(), ::arrow::int64()), false));
+
+ parquet_fields.push_back(PrimitiveNode::Make("date", Repetition::REQUIRED,
+ ParquetType::INT32, ConvertedType::DATE));
+ arrow_fields.push_back(::arrow::field(
+ "date", ::arrow::dictionary(::arrow::int8(), ::arrow::date32()), false));
+
+ parquet_fields.push_back(PrimitiveNode::Make("date64", Repetition::REQUIRED,
+ ParquetType::INT32, ConvertedType::DATE));
+ arrow_fields.push_back(::arrow::field(
+ "date64", ::arrow::dictionary(::arrow::int8(), ::arrow::date64()), false));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("float", Repetition::OPTIONAL, ParquetType::FLOAT));
+ arrow_fields.push_back(
+ ::arrow::field("float", ::arrow::dictionary(::arrow::int8(), ::arrow::float32())));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("double", Repetition::OPTIONAL, ParquetType::DOUBLE));
+ arrow_fields.push_back(
+ ::arrow::field("double", ::arrow::dictionary(::arrow::int8(), ::arrow::float64())));
+
+ parquet_fields.push_back(PrimitiveNode::Make(
+ "string", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY, ConvertedType::UTF8));
+ arrow_fields.push_back(
+ ::arrow::field("string", ::arrow::dictionary(::arrow::int8(), ::arrow::utf8())));
+
+ parquet_fields.push_back(PrimitiveNode::Make(
+ "binary", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY, ConvertedType::NONE));
+ arrow_fields.push_back(
+ ::arrow::field("binary", ::arrow::dictionary(::arrow::int8(), ::arrow::binary())));
+
+ ASSERT_OK(ConvertSchema(arrow_fields));
+
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(parquet_fields));
+}
+
+TEST_F(TestConvertArrowSchema, ParquetLists) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ // parquet_arrow will always generate 3-level LIST encodings
+
+ // // List<String> (list non-null, elements nullable)
+ // required group my_list (LIST) {
+ // repeated group list {
+ // optional binary element (UTF8);
+ // }
+ // }
+ {
+ auto element = PrimitiveNode::Make("string", Repetition::OPTIONAL,
+ ParquetType::BYTE_ARRAY, ConvertedType::UTF8);
+ auto list = GroupNode::Make("list", Repetition::REPEATED, {element});
+ parquet_fields.push_back(
+ GroupNode::Make("my_list", Repetition::REQUIRED, {list}, ConvertedType::LIST));
+ auto arrow_element = ::arrow::field("string", UTF8, true);
+ auto arrow_list = ::arrow::list(arrow_element);
+ arrow_fields.push_back(::arrow::field("my_list", arrow_list, false));
+ }
+
+ // // List<String> (list nullable, elements non-null)
+ // optional group my_list (LIST) {
+ // repeated group list {
+ // required binary element (UTF8);
+ // }
+ // }
+ {
+ auto element = PrimitiveNode::Make("string", Repetition::REQUIRED,
+ ParquetType::BYTE_ARRAY, ConvertedType::UTF8);
+ auto list = GroupNode::Make("list", Repetition::REPEATED, {element});
+ parquet_fields.push_back(
+ GroupNode::Make("my_list", Repetition::OPTIONAL, {list}, ConvertedType::LIST));
+ auto arrow_element = ::arrow::field("string", UTF8, false);
+ auto arrow_list = ::arrow::list(arrow_element);
+ arrow_fields.push_back(::arrow::field("my_list", arrow_list, true));
+ }
+
+ ASSERT_OK(ConvertSchema(arrow_fields));
+
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(parquet_fields));
+}
+
+TEST_F(TestConvertArrowSchema, ParquetMaps) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ // // Map<String, String> (map and map values nullable)
+ // optional group my_map (MAP) {
+ // repeated group key_value {
+ // required binary key (UTF8);
+ // optional binary value (UTF8);
+ // }
+ // }
+ {
+ auto key = PrimitiveNode::Make("key", Repetition::REQUIRED, ParquetType::BYTE_ARRAY,
+ ConvertedType::UTF8);
+ auto value = PrimitiveNode::Make("value", Repetition::OPTIONAL,
+ ParquetType::BYTE_ARRAY, ConvertedType::UTF8);
+
+ auto list = GroupNode::Make("key_value", Repetition::REPEATED, {key, value});
+ parquet_fields.push_back(
+ GroupNode::Make("my_map", Repetition::OPTIONAL, {list}, ConvertedType::MAP));
+ auto arrow_key = ::arrow::field("string", UTF8, /*nullable=*/false);
+ auto arrow_value = ::arrow::field("other_string", UTF8, /*nullable=*/true);
+ auto arrow_map = ::arrow::map(arrow_key->type(), arrow_value, /*nullable=*/false);
+ arrow_fields.push_back(::arrow::field("my_map", arrow_map, /*nullable=*/true));
+ }
+
+ // // Map<String, String> (non-nullable)
+ // required group my_map (MAP) {
+ // repeated group key_value {
+ // required binary key (UTF8);
+ // required binary value (UTF8);
+ // }
+ // }
+ {
+ auto key = PrimitiveNode::Make("key", Repetition::REQUIRED, ParquetType::BYTE_ARRAY,
+ ConvertedType::UTF8);
+ auto value = PrimitiveNode::Make("value", Repetition::REQUIRED,
+ ParquetType::BYTE_ARRAY, ConvertedType::UTF8);
+
+ auto list = GroupNode::Make("key_value", Repetition::REPEATED, {key, value});
+ parquet_fields.push_back(
+ GroupNode::Make("my_map", Repetition::REQUIRED, {list}, ConvertedType::MAP));
+ auto arrow_key = ::arrow::field("string", UTF8, /*nullable=*/false);
+ auto arrow_value = ::arrow::field("other_string", UTF8, /*nullable=*/false);
+ auto arrow_map = ::arrow::map(arrow_key->type(), arrow_value);
+ arrow_fields.push_back(::arrow::field("my_map", arrow_map, /*nullable=*/false));
+ }
+
+ ASSERT_OK(ConvertSchema(arrow_fields));
+
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(parquet_fields));
+}
+
+TEST_F(TestConvertArrowSchema, ParquetOtherLists) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ // parquet_arrow will always generate 3-level LIST encodings
+
+ // // LargeList<String> (list-like non-null, elements nullable)
+ // required group my_list (LIST) {
+ // repeated group list {
+ // optional binary element (UTF8);
+ // }
+ // }
+ {
+ auto element = PrimitiveNode::Make("string", Repetition::OPTIONAL,
+ ParquetType::BYTE_ARRAY, ConvertedType::UTF8);
+ auto list = GroupNode::Make("list", Repetition::REPEATED, {element});
+ parquet_fields.push_back(
+ GroupNode::Make("my_list", Repetition::REQUIRED, {list}, ConvertedType::LIST));
+ auto arrow_element = ::arrow::field("string", UTF8, true);
+ auto arrow_list = ::arrow::large_list(arrow_element);
+ arrow_fields.push_back(::arrow::field("my_list", arrow_list, false));
+ }
+ // // FixedSizeList[10]<String> (list-like non-null, elements nullable)
+ // required group my_list (LIST) {
+ // repeated group list {
+ // optional binary element (UTF8);
+ // }
+ // }
+ {
+ auto element = PrimitiveNode::Make("string", Repetition::OPTIONAL,
+ ParquetType::BYTE_ARRAY, ConvertedType::UTF8);
+ auto list = GroupNode::Make("list", Repetition::REPEATED, {element});
+ parquet_fields.push_back(
+ GroupNode::Make("my_list", Repetition::REQUIRED, {list}, ConvertedType::LIST));
+ auto arrow_element = ::arrow::field("string", UTF8, true);
+ auto arrow_list = ::arrow::fixed_size_list(arrow_element, 10);
+ arrow_fields.push_back(::arrow::field("my_list", arrow_list, false));
+ }
+
+ ASSERT_OK(ConvertSchema(arrow_fields));
+
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(parquet_fields));
+}
+
+TEST_F(TestConvertArrowSchema, ParquetNestedComplianceEnabledNullable) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ // parquet_arrow will always generate 3-level LIST encodings
+
+ // // List<String> (list non-null, elements nullable)
+ // required group my_list (LIST) {
+ // repeated group list {
+ // optional binary element (UTF8);
+ // }
+ // }
+ {
+ auto element = PrimitiveNode::Make("element", Repetition::OPTIONAL,
+ ParquetType::BYTE_ARRAY, ConvertedType::UTF8);
+ auto list = GroupNode::Make("list", Repetition::REPEATED, {element});
+ parquet_fields.push_back(
+ GroupNode::Make("my_list", Repetition::REQUIRED, {list}, ConvertedType::LIST));
+ auto arrow_element = ::arrow::field("string", UTF8, true);
+ auto arrow_list = ::arrow::large_list(arrow_element);
+ arrow_fields.push_back(::arrow::field("my_list", arrow_list, false));
+ }
+
+ ArrowWriterProperties::Builder builder;
+ builder.enable_compliant_nested_types();
+ auto arrow_properties = builder.build();
+
+ ASSERT_OK(ConvertSchema(arrow_fields, arrow_properties));
+
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(parquet_fields));
+}
+
+TEST_F(TestConvertArrowSchema, ParquetNestedComplianceEnabledNotNullable) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ // parquet_arrow will always generate 3-level LIST encodings
+
+ // // List<String> (list non-null, elements nullable)
+ // optional group my_list (LIST) {
+ // repeated group list {
+ // optional binary element (UTF8);
+ // }
+ // }
+ {
+ auto element = PrimitiveNode::Make("element", Repetition::REQUIRED,
+ ParquetType::BYTE_ARRAY, ConvertedType::UTF8);
+ auto list = GroupNode::Make("list", Repetition::REPEATED, {element});
+ parquet_fields.push_back(
+ GroupNode::Make("my_list", Repetition::REQUIRED, {list}, ConvertedType::LIST));
+ auto arrow_element = ::arrow::field("string", UTF8, false);
+ auto arrow_list = ::arrow::large_list(arrow_element);
+ arrow_fields.push_back(::arrow::field("my_list", arrow_list, false));
+ }
+
+ ArrowWriterProperties::Builder builder;
+ builder.enable_compliant_nested_types();
+ auto arrow_properties = builder.build();
+
+ ASSERT_OK(ConvertSchema(arrow_fields, arrow_properties));
+
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(parquet_fields));
+}
+
+TEST_F(TestConvertArrowSchema, ParquetFlatDecimals) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+
+ // TODO: Test Decimal Arrow -> Parquet conversion
+
+ ASSERT_OK(ConvertSchema(arrow_fields));
+
+ ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(parquet_fields));
+}
+
+class TestConvertRoundTrip : public ::testing::Test {
+ public:
+ ::arrow::Status RoundTripSchema(
+ const std::vector<std::shared_ptr<Field>>& fields,
+ std::shared_ptr<::parquet::ArrowWriterProperties> arrow_properties =
+ ::parquet::default_arrow_writer_properties()) {
+ arrow_schema_ = ::arrow::schema(fields);
+ std::shared_ptr<::parquet::WriterProperties> properties =
+ ::parquet::default_writer_properties();
+ RETURN_NOT_OK(ToParquetSchema(arrow_schema_.get(), *properties.get(),
+ *arrow_properties, &parquet_schema_));
+ ::parquet::schema::ToParquet(parquet_schema_->group_node(), &parquet_format_schema_);
+ auto parquet_schema = ::parquet::schema::FromParquet(parquet_format_schema_);
+ return FromParquetSchema(parquet_schema.get(), &result_schema_);
+ }
+
+ protected:
+ std::shared_ptr<::arrow::Schema> arrow_schema_;
+ std::shared_ptr<SchemaDescriptor> parquet_schema_;
+ std::vector<SchemaElement> parquet_format_schema_;
+ std::shared_ptr<::arrow::Schema> result_schema_;
+};
+
+int GetFieldId(const ::arrow::Field& field) {
+ if (field.metadata() == nullptr) {
+ return -1;
+ }
+ auto maybe_field = field.metadata()->Get("PARQUET:field_id");
+ if (!maybe_field.ok()) {
+ return -1;
+ }
+ return std::stoi(maybe_field.ValueOrDie());
+}
+
+void GetFieldIdsDfs(const ::arrow::FieldVector& fields, std::vector<int>* field_ids) {
+ for (const auto& field : fields) {
+ field_ids->push_back(GetFieldId(*field));
+ GetFieldIdsDfs(field->type()->fields(), field_ids);
+ }
+}
+
+std::vector<int> GetFieldIdsDfs(const ::arrow::FieldVector& fields) {
+ std::vector<int> field_ids;
+ GetFieldIdsDfs(fields, &field_ids);
+ return field_ids;
+}
+
+std::vector<int> GetParquetFieldIdsHelper(const parquet::schema::Node* node) {
+ std::vector<int> field_ids;
+ field_ids.push_back(node->field_id());
+ if (node->is_group()) {
+ const GroupNode* group_node = static_cast<const GroupNode*>(node);
+ for (int i = 0; i < group_node->field_count(); i++) {
+ for (auto id : GetParquetFieldIdsHelper(group_node->field(i).get())) {
+ field_ids.push_back(id);
+ }
+ }
+ }
+ return field_ids;
+}
+
+std::vector<int> GetParquetFieldIds(std::shared_ptr<SchemaDescriptor> parquet_schema) {
+ return GetParquetFieldIdsHelper(
+ static_cast<const parquet::schema::Node*>(parquet_schema->group_node()));
+}
+
+std::vector<int> GetThriftFieldIds(
+ const std::vector<SchemaElement>& parquet_format_schema) {
+ std::vector<int> field_ids;
+ for (const auto& element : parquet_format_schema) {
+ field_ids.push_back(element.field_id);
+ }
+ return field_ids;
+}
+
+TEST_F(TestConvertRoundTrip, FieldIdMissingIfNotSpecified) {
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+ arrow_fields.push_back(::arrow::field("simple", ::arrow::int32(), false));
+ /// { "nested": { "outer": { "inner" }, "sibling" } }
+ arrow_fields.push_back(::arrow::field(
+ "nested",
+ ::arrow::struct_({::arrow::field("outer", ::arrow::struct_({::arrow::field(
+ "inner", ::arrow::utf8())})),
+ ::arrow::field("sibling", ::arrow::date32())}),
+ false));
+
+ ASSERT_OK(RoundTripSchema(arrow_fields));
+ auto field_ids = GetFieldIdsDfs(result_schema_->fields());
+ for (int actual_id : field_ids) {
+ ASSERT_EQ(actual_id, -1);
+ }
+ auto parquet_field_ids = GetParquetFieldIds(parquet_schema_);
+ for (int actual_id : parquet_field_ids) {
+ ASSERT_EQ(actual_id, -1);
+ }
+ // In our unit test a "not set" thrift field has a value of 0
+ auto thrift_field_ids = GetThriftFieldIds(parquet_format_schema_);
+ for (int actual_id : thrift_field_ids) {
+ ASSERT_EQ(actual_id, 0);
+ }
+}
+
+std::shared_ptr<::arrow::KeyValueMetadata> FieldIdMetadata(int field_id) {
+ return ::arrow::key_value_metadata({"PARQUET:field_id"}, {std::to_string(field_id)});
+}
+
+TEST_F(TestConvertRoundTrip, FieldIdPreserveExisting) {
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+ arrow_fields.push_back(
+ ::arrow::field("simple", ::arrow::int32(), /*nullable=*/true, FieldIdMetadata(2)));
+ /// { "nested": { "outer": { "inner" }, "sibling" }
+ arrow_fields.push_back(::arrow::field(
+ "nested",
+ ::arrow::struct_({::arrow::field("outer", ::arrow::struct_({::arrow::field(
+ "inner", ::arrow::utf8())})),
+ ::arrow::field("sibling", ::arrow::date32(), /*nullable=*/true,
+ FieldIdMetadata(17))}),
+ false));
+
+ ASSERT_OK(RoundTripSchema(arrow_fields));
+ auto field_ids = GetFieldIdsDfs(result_schema_->fields());
+ auto expected_field_ids = std::vector<int>{2, -1, -1, -1, 17};
+ ASSERT_EQ(field_ids, expected_field_ids);
+
+ // Parquet has a field id for the schema itself
+ expected_field_ids = std::vector<int>{-1, 2, -1, -1, -1, 17};
+ auto parquet_ids = GetParquetFieldIds(parquet_schema_);
+ ASSERT_EQ(parquet_ids, expected_field_ids);
+
+ // In our unit test a "not set" thrift field has a value of 0
+ expected_field_ids = std::vector<int>{0, 2, 0, 0, 0, 17};
+ auto thrift_field_ids = GetThriftFieldIds(parquet_format_schema_);
+ ASSERT_EQ(thrift_field_ids, expected_field_ids);
+}
+
+TEST(InvalidSchema, ParquetNegativeDecimalScale) {
+ const auto& type = ::arrow::decimal(23, -2);
+ const auto& field = ::arrow::field("f0", type);
+ const auto& arrow_schema = ::arrow::schema({field});
+ std::shared_ptr<::parquet::WriterProperties> properties =
+ ::parquet::default_writer_properties();
+ std::shared_ptr<SchemaDescriptor> result_schema;
+
+ ASSERT_RAISES(IOError,
+ ToParquetSchema(arrow_schema.get(), *properties.get(), &result_schema));
+}
+
+TEST(InvalidSchema, NonNullableNullType) {
+ const auto& field = ::arrow::field("f0", ::arrow::null(), /*nullable=*/false);
+ const auto& arrow_schema = ::arrow::schema({field});
+ std::shared_ptr<::parquet::WriterProperties> properties =
+ ::parquet::default_writer_properties();
+ std::shared_ptr<SchemaDescriptor> result_schema;
+ ASSERT_RAISES(Invalid,
+ ToParquetSchema(arrow_schema.get(), *properties.get(), &result_schema));
+}
+
+TEST(TestFromParquetSchema, CorruptMetadata) {
+ // PARQUET-1565: ensure that an IOError is returned when the parquet file contains
+ // corrupted metadata.
+ auto path = test::get_data_file("PARQUET-1481.parquet", /*is_good=*/false);
+
+ std::unique_ptr<parquet::ParquetFileReader> reader =
+ parquet::ParquetFileReader::OpenFile(path);
+ const auto parquet_schema = reader->metadata()->schema();
+ std::shared_ptr<::arrow::Schema> arrow_schema;
+ ArrowReaderProperties props;
+ ASSERT_RAISES(IOError, FromParquetSchema(parquet_schema, props, &arrow_schema));
+}
+
+//
+// Test LevelInfo computation from a Parquet schema
+// (for Parquet -> Arrow reading).
+//
+
+::arrow::Result<std::deque<LevelInfo>> RootToTreeLeafLevels(
+ const SchemaManifest& manifest, int column_number) {
+ std::deque<LevelInfo> out;
+ const SchemaField* field = nullptr;
+ RETURN_NOT_OK(manifest.GetColumnField(column_number, &field));
+ while (field != nullptr) {
+ out.push_front(field->level_info);
+ field = manifest.GetParent(field);
+ }
+ return out;
+}
+
+class TestLevels : public ::testing::Test {
+ public:
+ virtual void SetUp() {}
+
+ ::arrow::Status MaybeSetParquetSchema(const NodePtr& column) {
+ descriptor_.reset(new SchemaDescriptor());
+ manifest_.reset(new SchemaManifest());
+ descriptor_->Init(GroupNode::Make("root", Repetition::REQUIRED, {column}));
+ return SchemaManifest::Make(descriptor_.get(),
+ std::shared_ptr<const ::arrow::KeyValueMetadata>(),
+ ArrowReaderProperties(), manifest_.get());
+ }
+ void SetParquetSchema(const NodePtr& column) {
+ ASSERT_OK(MaybeSetParquetSchema(column));
+ }
+
+ protected:
+ std::unique_ptr<SchemaDescriptor> descriptor_;
+ std::unique_ptr<SchemaManifest> manifest_;
+};
+
+TEST_F(TestLevels, TestPrimitive) {
+ SetParquetSchema(
+ PrimitiveNode::Make("node_name", Repetition::REQUIRED, ParquetType::BOOLEAN));
+ ASSERT_OK_AND_ASSIGN(std::deque<LevelInfo> levels,
+ RootToTreeLeafLevels(*manifest_, /*column_number=*/0));
+ EXPECT_THAT(levels, ElementsAre(LevelInfo{/*null_slot_usage=*/1,
+ /*def_level=*/0, /*rep_level=*/0,
+ /*ancestor_list_def_level*/ 0}));
+ SetParquetSchema(
+ PrimitiveNode::Make("node_name", Repetition::OPTIONAL, ParquetType::BOOLEAN));
+ ASSERT_OK_AND_ASSIGN(levels, RootToTreeLeafLevels(*manifest_, /*column_number=*/0));
+ EXPECT_THAT(levels, ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1,
+ /*rep_level=*/0,
+ /*ancestor_list_def_level*/ 0}));
+
+ // Arrow schema: list(bool not null) not null
+ SetParquetSchema(
+ PrimitiveNode::Make("node_name", Repetition::REPEATED, ParquetType::BOOLEAN));
+ ASSERT_OK_AND_ASSIGN(levels, RootToTreeLeafLevels(*manifest_, /*column_number=*/0));
+ EXPECT_THAT(
+ levels,
+ ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 0}, // List Field
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 1})); // primitive field
+}
+
+TEST_F(TestLevels, TestMaps) {
+ // Two column map.
+ auto key = PrimitiveNode::Make("key", Repetition::REQUIRED, ParquetType::BYTE_ARRAY,
+ ConvertedType::UTF8);
+ auto value = PrimitiveNode::Make("value", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY,
+ ConvertedType::UTF8);
+
+ auto list = GroupNode::Make("key_value", Repetition::REPEATED, {key, value});
+ SetParquetSchema(
+ GroupNode::Make("my_map", Repetition::OPTIONAL, {list}, LogicalType::Map()));
+ ASSERT_OK_AND_ASSIGN(std::deque<LevelInfo> levels,
+ RootToTreeLeafLevels(*manifest_, /*column_number=*/0));
+ EXPECT_THAT(
+ levels,
+ ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/2, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 0},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/2, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 2},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/2, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 2}));
+
+ ASSERT_OK_AND_ASSIGN(levels, RootToTreeLeafLevels(*manifest_, /*column_number=*/1));
+ EXPECT_THAT(
+ levels,
+ ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/2, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 0},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/2, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 2},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/3, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 2}));
+
+ // single column map.
+ key = PrimitiveNode::Make("key", Repetition::REQUIRED, ParquetType::BYTE_ARRAY,
+ ConvertedType::UTF8);
+
+ list = GroupNode::Make("key_value", Repetition::REPEATED, {key});
+ SetParquetSchema(
+ GroupNode::Make("my_set", Repetition::REQUIRED, {list}, LogicalType::Map()));
+
+ ASSERT_OK_AND_ASSIGN(levels, RootToTreeLeafLevels(*manifest_, /*column_number=*/0));
+ EXPECT_THAT(
+ levels,
+ ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 0},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 1}));
+}
+
+TEST_F(TestLevels, TestSimpleGroups) {
+ // Arrow schema: struct(child: struct(inner: boolean not null))
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "child", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("inner", Repetition::REQUIRED, ParquetType::BOOLEAN)})}));
+ ASSERT_OK_AND_ASSIGN(std::deque<LevelInfo> levels,
+ RootToTreeLeafLevels(*manifest_, /*column_number=*/0));
+ EXPECT_THAT(
+ levels,
+ ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/0,
+ /*ancestor_list_def_level*/ 0},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/2, /*rep_level=*/0,
+ /*ancestor_list_def_level*/ 0},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/2, /*rep_level=*/0,
+ /*ancestor_list_def_level*/ 0}));
+
+ // Arrow schema: struct(child: struct(inner: boolean ))
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "child", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("inner", Repetition::OPTIONAL, ParquetType::BOOLEAN)})}));
+ ASSERT_OK_AND_ASSIGN(levels, RootToTreeLeafLevels(*manifest_, /*column_number=*/0));
+ EXPECT_THAT(
+ levels,
+ ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/0,
+ /*ancestor_list_def_level*/ 0},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/2, /*rep_level=*/0,
+ /*ancestor_list_def_level*/ 0},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/3, /*rep_level=*/0,
+ /*ancestor_list_def_level*/ 0}));
+
+ // Arrow schema: struct(child: struct(inner: boolean)) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "child", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("inner", Repetition::OPTIONAL, ParquetType::BOOLEAN)})}));
+ ASSERT_OK_AND_ASSIGN(levels, RootToTreeLeafLevels(*manifest_, /*column_number=*/0));
+ EXPECT_THAT(
+ levels,
+ ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/0, /*rep_level=*/0,
+ /*ancestor_list_def_level*/ 0},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/0,
+ /*ancestor_list_def_level*/ 0},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/2, /*rep_level=*/0,
+ /*ancestor_list_def_level*/ 0}));
+}
+
+TEST_F(TestLevels, TestRepeatedGroups) {
+ // Arrow schema: list(bool)
+ SetParquetSchema(GroupNode::Make(
+ "child_list", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::OPTIONAL, ParquetType::BOOLEAN)})},
+ LogicalType::List()));
+
+ ASSERT_OK_AND_ASSIGN(std::deque<LevelInfo> levels,
+ RootToTreeLeafLevels(*manifest_, /*column_number=*/0));
+ EXPECT_THAT(
+ levels,
+ ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/2, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 0},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/3, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 2}));
+
+ // Arrow schema: list(bool) not null
+ SetParquetSchema(GroupNode::Make(
+ "child_list", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::OPTIONAL, ParquetType::BOOLEAN)})},
+ LogicalType::List()));
+
+ ASSERT_OK_AND_ASSIGN(levels, RootToTreeLeafLevels(*manifest_, /*column_number=*/0));
+ EXPECT_THAT(
+ levels,
+ ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 0},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/2, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 1}));
+
+ // Arrow schema: list(bool not null)
+ SetParquetSchema(GroupNode::Make(
+ "child_list", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED, ParquetType::BOOLEAN)})},
+ LogicalType::List()));
+
+ ASSERT_OK_AND_ASSIGN(levels, RootToTreeLeafLevels(*manifest_, /*column_number=*/0));
+ EXPECT_THAT(
+ levels,
+ ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/2, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 0},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/2, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 2}));
+
+ // Arrow schema: list(bool not null) not null
+ SetParquetSchema(GroupNode::Make(
+ "child_list", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED, ParquetType::BOOLEAN)})},
+ LogicalType::List()));
+
+ ASSERT_OK_AND_ASSIGN(levels, RootToTreeLeafLevels(*manifest_, /*column_number=*/0));
+ EXPECT_THAT(
+ levels,
+ ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 0},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 1}));
+
+ // Arrow schema: list(struct(child: struct(list(bool not null) not null)) non null) not
+ // null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REPEATED,
+ {GroupNode::Make(
+ "child", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("inner", Repetition::REPEATED, ParquetType::BOOLEAN)})}));
+ ASSERT_OK_AND_ASSIGN(levels, RootToTreeLeafLevels(*manifest_, /*column_number=*/0));
+ EXPECT_THAT(
+ levels,
+ ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 0},
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 1},
+
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/2, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 1}, // optional child struct
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/3, /*rep_level=*/2,
+ /*ancestor_list_def_level*/ 1}, // repeated field
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/3, /*rep_level=*/2,
+ /*ancestor_list_def_level*/ 3})); // inner field
+
+ // Arrow schema: list(struct(child_list: list(struct(f0: bool f1: bool))) not null) not
+ // null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REPEATED,
+ {GroupNode::Make(
+ "child_list", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {GroupNode::Make(
+ "element", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("f0", Repetition::OPTIONAL, ParquetType::BOOLEAN),
+ PrimitiveNode::Make("f1", Repetition::REQUIRED,
+ ParquetType::BOOLEAN)})})},
+ LogicalType::List())}));
+ ASSERT_OK_AND_ASSIGN(levels, RootToTreeLeafLevels(*manifest_, /*column_number=*/0));
+ EXPECT_THAT(
+ levels,
+ ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 0}, // parent list
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 1}, // parent struct
+
+ // Def_level=2 is handled together with def_level=3
+ // When decoding. Def_level=2 indicates present but empty
+ // list. def_level=3 indicates a present element in the
+ // list.
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/3, /*rep_level=*/2,
+ /*ancestor_list_def_level*/ 1}, // list field
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/4, /*rep_level=*/2,
+ /*ancestor_list_def_level*/ 3}, // inner struct field
+
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/5, /*rep_level=*/2,
+ /*ancestor_list_def_level*/ 3})); // f0 bool field
+
+ ASSERT_OK_AND_ASSIGN(levels, RootToTreeLeafLevels(*manifest_, /*column_number=*/1));
+ EXPECT_THAT(
+ levels,
+ ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 0}, // parent list
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 1}, // parent struct
+ // Def_level=2 is handled together with def_level=3
+ // When decoding. Def_level=2 indicate present but empty
+ // list. def_level=3 indicates a present element in the
+ // list.
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/3, /*rep_level=*/2,
+ /*ancestor_list_def_level*/ 1}, // list field
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/4, /*rep_level=*/2,
+ /*ancestor_list_def_level*/ 3}, // inner struct field
+
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/4, /*rep_level=*/2,
+ /*ancestor_list_def_level*/ 3})); // f1 bool field
+
+ // Arrow schema: list(struct(child_list: list(bool not null)) not null) not null
+ // Legacy 2-level encoding (required for backwards compatibility. See
+ // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#nested-types
+ // for definitions).
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REPEATED,
+ {GroupNode::Make(
+ "child_list", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("bool", Repetition::REPEATED, ParquetType::BOOLEAN)},
+ LogicalType::List())}));
+
+ ASSERT_OK_AND_ASSIGN(levels, RootToTreeLeafLevels(*manifest_, /*column_number=*/0));
+ EXPECT_THAT(
+ levels,
+ ElementsAre(LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 0}, // parent list
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/1, /*rep_level=*/1,
+ /*ancestor_list_def_level*/ 1}, // parent struct
+
+ // Def_level=2 is handled together with def_level=3
+ // When decoding. Def_level=2 indicate present but empty
+ // list. def_level=3 indicates a present element in the
+ // list.
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/3, /*rep_level=*/2,
+ /*ancestor_list_def_level*/ 1}, // list field
+ LevelInfo{/*null_slot_usage=*/1, /*def_level=*/3, /*rep_level=*/2,
+ /*ancestor_list_def_level*/ 3})); // inner bool
+}
+
+TEST_F(TestLevels, ListErrors) {
+ {
+ ::arrow::Status error = MaybeSetParquetSchema(GroupNode::Make(
+ "child_list", Repetition::REPEATED,
+ {PrimitiveNode::Make("bool", Repetition::REPEATED, ParquetType::BOOLEAN)},
+ LogicalType::List()));
+ ASSERT_RAISES(Invalid, error);
+ std::string expected("LIST-annotated groups must not be repeated.");
+ EXPECT_EQ(error.message().substr(0, expected.size()), expected);
+ }
+ {
+ ::arrow::Status error = MaybeSetParquetSchema(GroupNode::Make(
+ "child_list", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("f1", Repetition::REPEATED, ParquetType::BOOLEAN),
+ PrimitiveNode::Make("f2", Repetition::REPEATED, ParquetType::BOOLEAN)},
+ LogicalType::List()));
+ ASSERT_RAISES(Invalid, error);
+ std::string expected("LIST-annotated groups must have a single child.");
+ EXPECT_EQ(error.message().substr(0, expected.size()), expected);
+ }
+
+ {
+ ::arrow::Status error = MaybeSetParquetSchema(GroupNode::Make(
+ "child_list", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("f1", Repetition::OPTIONAL, ParquetType::BOOLEAN)},
+ LogicalType::List()));
+ ASSERT_RAISES(Invalid, error);
+ std::string expected(
+ "Non-repeated nodes in a LIST-annotated group are not supported.");
+ EXPECT_EQ(error.message().substr(0, expected.size()), expected);
+ }
+}
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/arrow_statistics_test.cc b/src/arrow/cpp/src/parquet/arrow/arrow_statistics_test.cc
new file mode 100644
index 000000000..6684300c0
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/arrow_statistics_test.cc
@@ -0,0 +1,161 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gtest/gtest.h"
+
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+
+#include "parquet/api/reader.h"
+#include "parquet/api/writer.h"
+
+#include "parquet/arrow/schema.h"
+#include "parquet/arrow/writer.h"
+#include "parquet/file_writer.h"
+#include "parquet/test_util.h"
+
+using arrow::ArrayFromJSON;
+using arrow::Buffer;
+using arrow::default_memory_pool;
+using arrow::ResizableBuffer;
+using arrow::Table;
+
+using arrow::io::BufferReader;
+
+namespace parquet {
+namespace arrow {
+
+struct StatisticsTestParam {
+ std::shared_ptr<::arrow::Table> table;
+ int expected_null_count;
+ // This is the non-null count and not the num_values in the page headers.
+ int expected_value_count;
+ std::string expected_min;
+ std::string expected_max;
+};
+
+// Define a custom print since the default Googletest print trips Valgrind
+void PrintTo(const StatisticsTestParam& param, std::ostream* os) {
+ (*os) << "StatisticsTestParam{"
+ << "table.schema=" << param.table->schema()->ToString()
+ << ", expected_null_count=" << param.expected_null_count
+ << ", expected_value_count=" << param.expected_value_count
+ << ", expected_min=" << param.expected_min
+ << ", expected_max=" << param.expected_max << "}";
+}
+
+class ParameterizedStatisticsTest : public ::testing::TestWithParam<StatisticsTestParam> {
+};
+
+std::string GetManyEmptyLists() {
+ std::string many_empty_lists = "[";
+ for (int i = 0; i < 2000; ++i) {
+ many_empty_lists += "[],";
+ }
+ many_empty_lists += "[1,2,3,4,5,6,7,8,null]]";
+ return many_empty_lists;
+}
+
+// PARQUET-2067: Tests that nulls from parent fields are included in null statistics.
+TEST_P(ParameterizedStatisticsTest, NoNullCountWrittenForRepeatedFields) {
+ std::shared_ptr<::arrow::ResizableBuffer> serialized_data = AllocateBuffer();
+ auto out_stream = std::make_shared<::arrow::io::BufferOutputStream>(serialized_data);
+ std::unique_ptr<FileWriter> writer;
+ ASSERT_OK(FileWriter::Open(*GetParam().table->schema(), default_memory_pool(),
+ out_stream, default_writer_properties(),
+ default_arrow_writer_properties(), &writer));
+ ASSERT_OK(writer->WriteTable(*GetParam().table, std::numeric_limits<int64_t>::max()));
+ ASSERT_OK(writer->Close());
+ ASSERT_OK(out_stream->Close());
+
+ auto buffer_reader = std::make_shared<::arrow::io::BufferReader>(serialized_data);
+ auto parquet_reader = ParquetFileReader::Open(std::move(buffer_reader));
+ std::shared_ptr<FileMetaData> metadata = parquet_reader->metadata();
+ std::shared_ptr<Statistics> stats = metadata->RowGroup(0)->ColumnChunk(0)->statistics();
+ EXPECT_EQ(stats->null_count(), GetParam().expected_null_count);
+ EXPECT_EQ(stats->num_values(), GetParam().expected_value_count);
+ ASSERT_TRUE(stats->HasMinMax());
+ EXPECT_EQ(stats->EncodeMin(), GetParam().expected_min);
+ EXPECT_EQ(stats->EncodeMax(), GetParam().expected_max);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ StatsTests, ParameterizedStatisticsTest,
+ ::testing::Values(
+ StatisticsTestParam{
+ /*table=*/Table::Make(::arrow::schema({::arrow::field("a", ::arrow::utf8())}),
+ {ArrayFromJSON(::arrow::utf8(),
+ R"(["1", null, "3"])")}),
+ /*expected_null_count=*/1, /* empty list counts as null as well */
+ /*expected_value_count=*/2,
+ /*expected_min=*/"1",
+ /*expected_max=*/"3"},
+ StatisticsTestParam{
+ /*table=*/Table::Make(
+ ::arrow::schema({::arrow::field("a", list(::arrow::utf8()))}),
+ {ArrayFromJSON(list(::arrow::utf8()),
+ R"([["1"], [], null, ["1", null, "3"]])")}),
+ /*expected_null_count=*/3, /* empty list counts as null as well */
+ /*expected_value_count=*/3,
+ /*expected_min=*/"1",
+ /*expected_max=*/"3"},
+ StatisticsTestParam{
+ /*table=*/Table::Make(
+ ::arrow::schema({::arrow::field("a", ::arrow::int64())}),
+ {ArrayFromJSON(::arrow::int64(), R"([1, null, 3, null])")}),
+ /*expected_null_count=*/2, /* empty list counts as null as well */
+ /*expected_value_count=*/2,
+ /*expected_min=*/std::string("\x1\0\0\0\0\0\0\0", 8),
+ /*expected_max=*/std::string("\x3\0\0\0\0\0\0\0", 8)},
+ StatisticsTestParam{
+ /*table=*/Table::Make(
+ ::arrow::schema({::arrow::field("a", list(::arrow::utf8()))}),
+ {ArrayFromJSON(list(::arrow::utf8()), R"([["1"], [], ["1", "3"]])")}),
+ /*expected_null_count=*/1, /* empty list counts as null as well */
+ /*expected_value_count=*/3,
+ /*expected_min=*/"1",
+ /*expected_max=*/"3"},
+ StatisticsTestParam{
+ /*table=*/Table::Make(
+ ::arrow::schema({::arrow::field("a", list(::arrow::int64()))}),
+ {ArrayFromJSON(list(::arrow::int64()),
+ R"([[1], [], null, [1, null, 3]])")}),
+ /*expected_null_count=*/3, /* empty list counts as null as well */
+ /*expected_value_count=*/3,
+ /*expected_min=*/std::string("\x1\0\0\0\0\0\0\0", 8),
+ /*expected_max=*/std::string("\x3\0\0\0\0\0\0\0", 8)},
+ StatisticsTestParam{
+ /*table=*/Table::Make(
+ ::arrow::schema({::arrow::field("a", list(::arrow::int64()), false)}),
+ {ArrayFromJSON(list(::arrow::int64()), GetManyEmptyLists())}),
+ /*expected_null_count=*/2001, /* empty list counts as null as well */
+ /*expected_value_count=*/8,
+ /*expected_min=*/std::string("\x1\0\0\0\0\0\0\0", 8),
+ /*expected_max=*/std::string("\x8\0\0\0\0\0\0\0", 8)},
+ StatisticsTestParam{
+ /*table=*/Table::Make(
+ ::arrow::schema({::arrow::field("a", list(dictionary(::arrow::int32(),
+ ::arrow::utf8())))}),
+ {ArrayFromJSON(list(dictionary(::arrow::int32(), ::arrow::utf8())),
+ R"([null, ["z", null, "z"], null, null, null])")}),
+ /*expected_null_count=*/5,
+ /*expected_value_count=*/2,
+ /*expected_min=*/"z",
+ /*expected_max=*/"z"}));
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/fuzz.cc b/src/arrow/cpp/src/parquet/arrow/fuzz.cc
new file mode 100644
index 000000000..f1c724508
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/fuzz.cc
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/status.h"
+#include "parquet/arrow/reader.h"
+
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
+ auto status = parquet::arrow::internal::FuzzReader(data, static_cast<int64_t>(size));
+ ARROW_UNUSED(status);
+ return 0;
+}
diff --git a/src/arrow/cpp/src/parquet/arrow/generate_fuzz_corpus.cc b/src/arrow/cpp/src/parquet/arrow/generate_fuzz_corpus.cc
new file mode 100644
index 000000000..33c3a1461
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/generate_fuzz_corpus.cc
@@ -0,0 +1,198 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// A command line executable that generates a bunch of valid Parquet files
+// containing example record batches. Those are used as fuzzing seeds
+// to make fuzzing more efficient.
+
+#include <cstdlib>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/io/file.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/scalar.h"
+#include "arrow/table.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/compression.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/key_value_metadata.h"
+#include "parquet/arrow/writer.h"
+
+namespace arrow {
+
+using ::arrow::internal::CreateDir;
+using ::arrow::internal::PlatformFilename;
+using ::parquet::WriterProperties;
+
+static constexpr int32_t kBatchSize = 1000;
+static constexpr int32_t kChunkSize = kBatchSize * 3 / 8;
+
+std::shared_ptr<WriterProperties> GetWriterProperties() {
+ WriterProperties::Builder builder{};
+ builder.disable_dictionary("no_dict");
+ builder.compression("compressed", Compression::BROTLI);
+ return builder.build();
+}
+
+Result<std::shared_ptr<RecordBatch>> ExampleBatch1() {
+ constexpr double kNullProbability = 0.2;
+
+ random::RandomArrayGenerator gen(42);
+ std::shared_ptr<Array> a, b, c, d, e, f, g, h, no_dict, compressed;
+ std::shared_ptr<Field> f_a, f_b, f_c, f_d, f_e, f_f, f_g, f_h, f_no_dict, f_compressed;
+
+ a = gen.Int16(kBatchSize, -10000, 10000, kNullProbability);
+ f_a = field("a", a->type());
+
+ b = gen.Float64(kBatchSize, -1e10, 1e10, /*null_probability=*/0.0);
+ f_b = field("b", b->type());
+
+ // A column of tiny strings that will hopefully trigger dict encoding
+ c = gen.String(kBatchSize, 0, 3, kNullProbability);
+ f_c = field("c", c->type());
+
+ // A column of lists
+ {
+ auto values = gen.Int64(kBatchSize * 10, -10000, 10000, kNullProbability);
+ auto offsets = gen.Offsets(kBatchSize + 1, 0, static_cast<int32_t>(values->length()));
+ ARROW_ASSIGN_OR_RAISE(d, ListArray::FromArrays(*offsets, *values));
+ }
+ f_d = field("d", d->type());
+
+ // A column of a repeated constant that will hopefully trigger RLE encoding
+ ARROW_ASSIGN_OR_RAISE(e, MakeArrayFromScalar(Int16Scalar(42), kBatchSize));
+ f_e = field("e", e->type());
+
+ // A column of lists of lists
+ {
+ auto inner_values = gen.Int64(kBatchSize * 9, -10000, 10000, kNullProbability);
+ auto inner_offsets =
+ gen.Offsets(kBatchSize * 3 + 1, 0, static_cast<int32_t>(inner_values->length()),
+ kNullProbability);
+ ARROW_ASSIGN_OR_RAISE(auto inner_lists,
+ ListArray::FromArrays(*inner_offsets, *inner_values));
+ auto offsets = gen.Offsets(
+ kBatchSize + 1, 0, static_cast<int32_t>(inner_lists->length()), kNullProbability);
+ ARROW_ASSIGN_OR_RAISE(f, ListArray::FromArrays(*offsets, *inner_lists));
+ }
+ f_f = field("f", f->type());
+
+ // A column of nested non-nullable structs
+ {
+ ARROW_ASSIGN_OR_RAISE(
+ auto inner_a,
+ StructArray::Make({a, b}, std::vector<std::string>{"inner1_aa", "inner1_ab"}));
+ ARROW_ASSIGN_OR_RAISE(
+ g, StructArray::Make({inner_a, c},
+ {field("inner1_a", inner_a->type(), /*nullable=*/false),
+ field("inner1_c", c->type())}));
+ }
+ f_g = field("g", g->type(), /*nullable=*/false);
+
+ // A column of nested nullable structs
+ {
+ auto null_bitmap = gen.NullBitmap(kBatchSize, kNullProbability);
+ ARROW_ASSIGN_OR_RAISE(
+ auto inner_a,
+ StructArray::Make({a, b}, std::vector<std::string>{"inner2_aa", "inner2_ab"},
+ std::move(null_bitmap)));
+ null_bitmap = gen.NullBitmap(kBatchSize, kNullProbability);
+ ARROW_ASSIGN_OR_RAISE(
+ h,
+ StructArray::Make({inner_a, c}, std::vector<std::string>{"inner2_a", "inner2_c"},
+ std::move(null_bitmap)));
+ }
+ f_h = field("h", h->type());
+
+ // A non-dict-encoded column (see GetWriterProperties)
+ no_dict = gen.String(kBatchSize, 0, 30, kNullProbability);
+ f_no_dict = field("no_dict", no_dict->type());
+
+ // A non-dict-encoded column (see GetWriterProperties)
+ compressed = gen.Int64(kBatchSize, -10, 10, kNullProbability);
+ f_compressed = field("compressed", compressed->type());
+
+ auto schema =
+ ::arrow::schema({f_a, f_b, f_c, f_d, f_e, f_f, f_g, f_h, f_compressed, f_no_dict});
+ auto md = key_value_metadata({"key1", "key2"}, {"value1", ""});
+ schema = schema->WithMetadata(md);
+ return RecordBatch::Make(schema, kBatchSize,
+ {a, b, c, d, e, f, g, h, compressed, no_dict});
+}
+
+Result<std::vector<std::shared_ptr<RecordBatch>>> Batches() {
+ std::vector<std::shared_ptr<RecordBatch>> batches;
+ ARROW_ASSIGN_OR_RAISE(auto batch, ExampleBatch1());
+ batches.push_back(batch);
+ return batches;
+}
+
+Status DoMain(const std::string& out_dir) {
+ ARROW_ASSIGN_OR_RAISE(auto dir_fn, PlatformFilename::FromString(out_dir));
+ RETURN_NOT_OK(CreateDir(dir_fn));
+
+ int sample_num = 1;
+ auto sample_name = [&]() -> std::string {
+ return "pq-table-" + std::to_string(sample_num++);
+ };
+
+ ARROW_ASSIGN_OR_RAISE(auto batches, Batches());
+
+ auto writer_properties = GetWriterProperties();
+
+ for (const auto& batch : batches) {
+ RETURN_NOT_OK(batch->ValidateFull());
+ ARROW_ASSIGN_OR_RAISE(auto table, Table::FromRecordBatches({batch}));
+
+ ARROW_ASSIGN_OR_RAISE(auto sample_fn, dir_fn.Join(sample_name()));
+ std::cerr << sample_fn.ToString() << std::endl;
+ ARROW_ASSIGN_OR_RAISE(auto file, io::FileOutputStream::Open(sample_fn.ToString()));
+ RETURN_NOT_OK(::parquet::arrow::WriteTable(*table, default_memory_pool(), file,
+ kChunkSize, writer_properties));
+ RETURN_NOT_OK(file->Close());
+ }
+ return Status::OK();
+}
+
+ARROW_NORETURN void Usage() {
+ std::cerr << "Usage: parquet-arrow-generate-fuzz-corpus "
+ << "<output directory>" << std::endl;
+ std::exit(2);
+}
+
+int Main(int argc, char** argv) {
+ if (argc != 2) {
+ Usage();
+ }
+ auto out_dir = std::string(argv[1]);
+
+ Status st = DoMain(out_dir);
+ if (!st.ok()) {
+ std::cerr << st.ToString() << std::endl;
+ return 1;
+ }
+ return 0;
+}
+
+} // namespace arrow
+
+int main(int argc, char** argv) { return ::arrow::Main(argc, argv); }
diff --git a/src/arrow/cpp/src/parquet/arrow/path_internal.cc b/src/arrow/cpp/src/parquet/arrow/path_internal.cc
new file mode 100644
index 000000000..7f706e50d
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/path_internal.cc
@@ -0,0 +1,901 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Overview.
+//
+// The strategy used for this code for repetition/definition
+// is to dissect the top level array into a list of paths
+// from the top level array to the final primitive (possibly
+// dictionary encoded array). It then evaluates each one of
+// those paths to produce results for the callback iteratively.
+//
+// This approach was taken to reduce the aggregate memory required if we were
+// to build all def/rep levels in parallel as apart of a tree traversal. It
+// also allows for straightforward parallelization at the path level if that is
+// desired in the future.
+//
+// The main downside to this approach is it duplicates effort for nodes
+// that share common ancestors. This can be mitigated to some degree
+// by adding in optimizations that detect leaf arrays that share
+// the same common list ancestor and reuse the repetition levels
+// from the first leaf encountered (only definition levels greater
+// the list ancestor need to be re-evaluated. This is left for future
+// work.
+//
+// Algorithm.
+//
+// As mentioned above this code dissects arrays into constituent parts:
+// nullability data, and list offset data. It tries to optimize for
+// some special cases, where it is known ahead of time that a step
+// can be skipped (e.g. a nullable array happens to have all of its
+// values) or batch filled (a nullable array has all null values).
+// One further optimization that is not implemented but could be done
+// in the future is special handling for nested list arrays that
+// have some intermediate data which indicates the final array contains only
+// nulls.
+//
+// In general, the algorithm attempts to batch work at each node as much
+// as possible. For nullability nodes this means finding runs of null
+// values and batch filling those interspersed with finding runs of non-null values
+// to process in batch at the next column.
+//
+// Similarly, list runs of empty lists are all processed in one batch
+// followed by either:
+// - A single list entry for non-terminal lists (i.e. the upper part of a nested list)
+// - Runs of non-empty lists for the terminal list (i.e. the lowest part of a nested
+// list).
+//
+// This makes use of the following observations.
+// 1. Null values at any node on the path are terminal (repetition and definition
+// level can be set directly when a Null value is encountered).
+// 2. Empty lists share this eager termination property with Null values.
+// 3. In order to keep repetition/definition level populated the algorithm is lazy
+// in assigning repetition levels. The algorithm tracks whether it is currently
+// in the middle of a list by comparing the lengths of repetition/definition levels.
+// If it is currently in the middle of a list the the number of repetition levels
+// populated will be greater than definition levels (the start of a List requires
+// adding the first element). If there are equal numbers of definition and repetition
+// levels populated this indicates a list is waiting to be started and the next list
+// encountered will have its repetition level signify the beginning of the list.
+//
+// Other implementation notes.
+//
+// This code hasn't been benchmarked (or assembly analyzed) but did the following
+// as optimizations (yes premature optimization is the root of all evil).
+// - This code does not use recursion, instead it constructs its own stack and manages
+// updating elements accordingly.
+// - It tries to avoid using Status for common return states.
+// - Avoids virtual dispatch in favor of if/else statements on a set of well known
+// classes.
+
+#include "parquet/arrow/path_internal.h"
+
+#include <atomic>
+#include <cstddef>
+#include <memory>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/extension_type.h"
+#include "arrow/memory_pool.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_visit.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/variant.h"
+#include "arrow/visitor_inline.h"
+#include "parquet/properties.h"
+
+namespace parquet {
+namespace arrow {
+
+namespace {
+
+using ::arrow::Array;
+using ::arrow::Status;
+using ::arrow::TypedBufferBuilder;
+
+constexpr static int16_t kLevelNotSet = -1;
+
+/// \brief Simple result of a iterating over a column to determine values.
+enum IterationResult {
+ /// Processing is done at this node. Move back up the path
+ /// to continue processing.
+ kDone = -1,
+ /// Move down towards the leaf for processing.
+ kNext = 1,
+ /// An error occurred while processing.
+ kError = 2
+};
+
+#define RETURN_IF_ERROR(iteration_result) \
+ do { \
+ if (ARROW_PREDICT_FALSE(iteration_result == kError)) { \
+ return iteration_result; \
+ } \
+ } while (false)
+
+int64_t LazyNullCount(const Array& array) { return array.data()->null_count.load(); }
+
+bool LazyNoNulls(const Array& array) {
+ int64_t null_count = LazyNullCount(array);
+ return null_count == 0 ||
+ // kUnkownNullCount comparison is needed to account
+ // for null arrays.
+ (null_count == ::arrow::kUnknownNullCount &&
+ array.null_bitmap_data() == nullptr);
+}
+
+struct PathWriteContext {
+ PathWriteContext(::arrow::MemoryPool* pool,
+ std::shared_ptr<::arrow::ResizableBuffer> def_levels_buffer)
+ : rep_levels(pool), def_levels(std::move(def_levels_buffer), pool) {}
+ IterationResult ReserveDefLevels(int64_t elements) {
+ last_status = def_levels.Reserve(elements);
+ if (ARROW_PREDICT_TRUE(last_status.ok())) {
+ return kDone;
+ }
+ return kError;
+ }
+
+ IterationResult AppendDefLevel(int16_t def_level) {
+ last_status = def_levels.Append(def_level);
+ if (ARROW_PREDICT_TRUE(last_status.ok())) {
+ return kDone;
+ }
+ return kError;
+ }
+
+ IterationResult AppendDefLevels(int64_t count, int16_t def_level) {
+ last_status = def_levels.Append(count, def_level);
+ if (ARROW_PREDICT_TRUE(last_status.ok())) {
+ return kDone;
+ }
+ return kError;
+ }
+
+ void UnsafeAppendDefLevel(int16_t def_level) { def_levels.UnsafeAppend(def_level); }
+
+ IterationResult AppendRepLevel(int16_t rep_level) {
+ last_status = rep_levels.Append(rep_level);
+
+ if (ARROW_PREDICT_TRUE(last_status.ok())) {
+ return kDone;
+ }
+ return kError;
+ }
+
+ IterationResult AppendRepLevels(int64_t count, int16_t rep_level) {
+ last_status = rep_levels.Append(count, rep_level);
+ if (ARROW_PREDICT_TRUE(last_status.ok())) {
+ return kDone;
+ }
+ return kError;
+ }
+
+ bool EqualRepDefLevelsLengths() const {
+ return rep_levels.length() == def_levels.length();
+ }
+
+ // Incorporates |range| into visited elements. If the |range| is contiguous
+ // with the last range, extend the last range, otherwise add |range| separately
+ // tot he list.
+ void RecordPostListVisit(const ElementRange& range) {
+ if (!visited_elements.empty() && range.start == visited_elements.back().end) {
+ visited_elements.back().end = range.end;
+ return;
+ }
+ visited_elements.push_back(range);
+ }
+
+ Status last_status;
+ TypedBufferBuilder<int16_t> rep_levels;
+ TypedBufferBuilder<int16_t> def_levels;
+ std::vector<ElementRange> visited_elements;
+};
+
+IterationResult FillRepLevels(int64_t count, int16_t rep_level,
+ PathWriteContext* context) {
+ if (rep_level == kLevelNotSet) {
+ return kDone;
+ }
+ int64_t fill_count = count;
+ // This condition occurs (rep and dep levels equals), in one of
+ // in a few cases:
+ // 1. Before any list is encountered.
+ // 2. After rep-level has been filled in due to null/empty
+ // values above it.
+ // 3. After finishing a list.
+ if (!context->EqualRepDefLevelsLengths()) {
+ fill_count--;
+ }
+ return context->AppendRepLevels(fill_count, rep_level);
+}
+
+// A node for handling an array that is discovered to have all
+// null elements. It is referred to as a TerminalNode because
+// traversal of nodes will not continue it when generating
+// rep/def levels. However, there could be many nested children
+// elements beyond it in the Array that is being processed.
+class AllNullsTerminalNode {
+ public:
+ explicit AllNullsTerminalNode(int16_t def_level, int16_t rep_level = kLevelNotSet)
+ : def_level_(def_level), rep_level_(rep_level) {}
+ void SetRepLevelIfNull(int16_t rep_level) { rep_level_ = rep_level; }
+ IterationResult Run(const ElementRange& range, PathWriteContext* context) {
+ int64_t size = range.Size();
+ RETURN_IF_ERROR(FillRepLevels(size, rep_level_, context));
+ return context->AppendDefLevels(size, def_level_);
+ }
+
+ private:
+ int16_t def_level_;
+ int16_t rep_level_;
+};
+
+// Handles the case where all remaining arrays until the leaf have no nulls
+// (and are not interrupted by lists). Unlike AllNullsTerminalNode this is
+// always the last node in a path. We don't need an analogue to the AllNullsTerminalNode
+// because if all values are present at an intermediate array no node is added for it
+// (the def-level for the next nullable node is incremented).
+struct AllPresentTerminalNode {
+ IterationResult Run(const ElementRange& range, PathWriteContext* context) {
+ return context->AppendDefLevels(range.end - range.start, def_level);
+ // No need to worry about rep levels, because this state should
+ // only be applicable for after all list/repeated values
+ // have been evaluated in the path.
+ }
+ int16_t def_level;
+};
+
+/// Node for handling the case when the leaf-array is nullable
+/// and contains null elements.
+struct NullableTerminalNode {
+ NullableTerminalNode() = default;
+
+ NullableTerminalNode(const uint8_t* bitmap, int64_t element_offset,
+ int16_t def_level_if_present)
+ : bitmap_(bitmap),
+ element_offset_(element_offset),
+ def_level_if_present_(def_level_if_present),
+ def_level_if_null_(def_level_if_present - 1) {}
+
+ IterationResult Run(const ElementRange& range, PathWriteContext* context) {
+ int64_t elements = range.Size();
+ RETURN_IF_ERROR(context->ReserveDefLevels(elements));
+
+ DCHECK_GT(elements, 0);
+
+ auto bit_visitor = [&](bool is_set) {
+ context->UnsafeAppendDefLevel(is_set ? def_level_if_present_ : def_level_if_null_);
+ };
+
+ if (elements > 16) { // 16 guarantees at least one unrolled loop.
+ ::arrow::internal::VisitBitsUnrolled(bitmap_, range.start + element_offset_,
+ elements, bit_visitor);
+ } else {
+ ::arrow::internal::VisitBits(bitmap_, range.start + element_offset_, elements,
+ bit_visitor);
+ }
+ return kDone;
+ }
+ const uint8_t* bitmap_;
+ int64_t element_offset_;
+ int16_t def_level_if_present_;
+ int16_t def_level_if_null_;
+};
+
+// List nodes handle populating rep_level for Arrow Lists and def-level for empty lists.
+// Nullability (both list and children) is handled by other Nodes. By
+// construction all list nodes will be intermediate nodes (they will always be followed by
+// at least one other node).
+//
+// Type parameters:
+// |RangeSelector| - A strategy for determine the the range of the child node to
+// process.
+// this varies depending on the type of list (int32_t* offsets, int64_t* offsets of
+// fixed.
+template <typename RangeSelector>
+class ListPathNode {
+ public:
+ ListPathNode(RangeSelector selector, int16_t rep_lev, int16_t def_level_if_empty)
+ : selector_(std::move(selector)),
+ prev_rep_level_(rep_lev - 1),
+ rep_level_(rep_lev),
+ def_level_if_empty_(def_level_if_empty) {}
+
+ int16_t rep_level() const { return rep_level_; }
+
+ IterationResult Run(ElementRange* range, ElementRange* child_range,
+ PathWriteContext* context) {
+ if (range->Empty()) {
+ return kDone;
+ }
+ // Find the first non-empty list (skipping a run of empties).
+ int64_t empty_elements = 0;
+ do {
+ // Retrieve the range of elements that this list contains.
+ *child_range = selector_.GetRange(range->start);
+ if (!child_range->Empty()) {
+ break;
+ }
+ ++empty_elements;
+ ++range->start;
+ } while (!range->Empty());
+
+ // Post condition:
+ // * range is either empty (we are done processing at this node)
+ // or start corresponds a non-empty list.
+ // * If range is non-empty child_range contains
+ // the bounds of non-empty list.
+
+ // Handle any skipped over empty lists.
+ if (empty_elements > 0) {
+ RETURN_IF_ERROR(FillRepLevels(empty_elements, prev_rep_level_, context));
+ RETURN_IF_ERROR(context->AppendDefLevels(empty_elements, def_level_if_empty_));
+ }
+ // Start of a new list. Note that for nested lists adding the element
+ // here effectively suppresses this code until we either encounter null
+ // elements or empty lists between here and the innermost list (since
+ // we make the rep levels repetition and definition levels unequal).
+ // Similarly when we are backtracking up the stack the repetition and
+ // definition levels are again equal so if we encounter an intermediate list
+ // with more elements this will detect it as a new list.
+ if (context->EqualRepDefLevelsLengths() && !range->Empty()) {
+ RETURN_IF_ERROR(context->AppendRepLevel(prev_rep_level_));
+ }
+
+ if (range->Empty()) {
+ return kDone;
+ }
+
+ ++range->start;
+ if (is_last_) {
+ // If this is the last repeated node, we can extend try
+ // to extend the child range as wide as possible before
+ // continuing to the next node.
+ return FillForLast(range, child_range, context);
+ }
+ return kNext;
+ }
+
+ void SetLast() { is_last_ = true; }
+
+ private:
+ IterationResult FillForLast(ElementRange* range, ElementRange* child_range,
+ PathWriteContext* context) {
+ // First fill int the remainder of the list.
+ RETURN_IF_ERROR(FillRepLevels(child_range->Size(), rep_level_, context));
+ // Once we've reached this point the following preconditions should hold:
+ // 1. There are no more repeated path nodes to deal with.
+ // 2. All elements in |range| represent contiguous elements in the
+ // child array (Null values would have shortened the range to ensure
+ // all remaining list elements are present (though they may be empty lists)).
+ // 3. No element of range spans a parent list (intermediate
+ // list nodes only handle one list entry at a time).
+ //
+ // Given these preconditions it should be safe to fill runs on non-empty
+ // lists here and expand the range in the child node accordingly.
+
+ while (!range->Empty()) {
+ ElementRange size_check = selector_.GetRange(range->start);
+ if (size_check.Empty()) {
+ // The empty range will need to be handled after we pass down the accumulated
+ // range because it affects def_level placement and we need to get the children
+ // def_levels entered first.
+ break;
+ }
+ // This is the start of a new list. We can be sure it only applies
+ // to the previous list (and doesn't jump to the start of any list
+ // further up in nesting due to the constraints mentioned at the start
+ // of the function).
+ RETURN_IF_ERROR(context->AppendRepLevel(prev_rep_level_));
+ RETURN_IF_ERROR(context->AppendRepLevels(size_check.Size() - 1, rep_level_));
+ DCHECK_EQ(size_check.start, child_range->end)
+ << size_check.start << " != " << child_range->end;
+ child_range->end = size_check.end;
+ ++range->start;
+ }
+
+ // Do book-keeping to track the elements of the arrays that are actually visited
+ // beyond this point. This is necessary to identify "gaps" in values that should
+ // not be processed (written out to parquet).
+ context->RecordPostListVisit(*child_range);
+ return kNext;
+ }
+
+ RangeSelector selector_;
+ int16_t prev_rep_level_;
+ int16_t rep_level_;
+ int16_t def_level_if_empty_;
+ bool is_last_ = false;
+};
+
+template <typename OffsetType>
+struct VarRangeSelector {
+ ElementRange GetRange(int64_t index) const {
+ return ElementRange{offsets[index], offsets[index + 1]};
+ }
+
+ // Either int32_t* or int64_t*.
+ const OffsetType* offsets;
+};
+
+struct FixedSizedRangeSelector {
+ ElementRange GetRange(int64_t index) const {
+ int64_t start = index * list_size;
+ return ElementRange{start, start + list_size};
+ }
+ int list_size;
+};
+
+// An intermediate node that handles null values.
+class NullableNode {
+ public:
+ NullableNode(const uint8_t* null_bitmap, int64_t entry_offset,
+ int16_t def_level_if_null, int16_t rep_level_if_null = kLevelNotSet)
+ : null_bitmap_(null_bitmap),
+ entry_offset_(entry_offset),
+ valid_bits_reader_(MakeReader(ElementRange{0, 0})),
+ def_level_if_null_(def_level_if_null),
+ rep_level_if_null_(rep_level_if_null),
+ new_range_(true) {}
+
+ void SetRepLevelIfNull(int16_t rep_level) { rep_level_if_null_ = rep_level; }
+
+ ::arrow::internal::BitRunReader MakeReader(const ElementRange& range) {
+ return ::arrow::internal::BitRunReader(null_bitmap_, entry_offset_ + range.start,
+ range.Size());
+ }
+
+ IterationResult Run(ElementRange* range, ElementRange* child_range,
+ PathWriteContext* context) {
+ if (new_range_) {
+ // Reset the reader each time we are starting fresh on a range.
+ // We can't rely on continuity because nulls above can
+ // cause discontinuities.
+ valid_bits_reader_ = MakeReader(*range);
+ }
+ child_range->start = range->start;
+ ::arrow::internal::BitRun run = valid_bits_reader_.NextRun();
+ if (!run.set) {
+ range->start += run.length;
+ RETURN_IF_ERROR(FillRepLevels(run.length, rep_level_if_null_, context));
+ RETURN_IF_ERROR(context->AppendDefLevels(run.length, def_level_if_null_));
+ run = valid_bits_reader_.NextRun();
+ }
+ if (range->Empty()) {
+ new_range_ = true;
+ return kDone;
+ }
+ child_range->end = child_range->start = range->start;
+ child_range->end += run.length;
+
+ DCHECK(!child_range->Empty());
+ range->start += child_range->Size();
+ new_range_ = false;
+ return kNext;
+ }
+
+ const uint8_t* null_bitmap_;
+ int64_t entry_offset_;
+ ::arrow::internal::BitRunReader valid_bits_reader_;
+ int16_t def_level_if_null_;
+ int16_t rep_level_if_null_;
+
+ // Whether the next invocation will be a new range.
+ bool new_range_ = true;
+};
+
+using ListNode = ListPathNode<VarRangeSelector<int32_t>>;
+using LargeListNode = ListPathNode<VarRangeSelector<int64_t>>;
+using FixedSizeListNode = ListPathNode<FixedSizedRangeSelector>;
+
+// Contains static information derived from traversing the schema.
+struct PathInfo {
+ // The vectors are expected to the same length info.
+
+ // Note index order matters here.
+ using Node = ::arrow::util::Variant<NullableTerminalNode, ListNode, LargeListNode,
+ FixedSizeListNode, NullableNode,
+ AllPresentTerminalNode, AllNullsTerminalNode>;
+
+ std::vector<Node> path;
+ std::shared_ptr<Array> primitive_array;
+ int16_t max_def_level = 0;
+ int16_t max_rep_level = 0;
+ bool has_dictionary = false;
+ bool leaf_is_nullable = false;
+};
+
+/// Contains logic for writing a single leaf node to parquet.
+/// This tracks the path from root to leaf.
+///
+/// |writer| will be called after all of the definition/repetition
+/// values have been calculated for root_range with the calculated
+/// values. It is intended to abstract the complexity of writing
+/// the levels and values to parquet.
+Status WritePath(ElementRange root_range, PathInfo* path_info,
+ ArrowWriteContext* arrow_context,
+ MultipathLevelBuilder::CallbackFunction writer) {
+ std::vector<ElementRange> stack(path_info->path.size());
+ MultipathLevelBuilderResult builder_result;
+ builder_result.leaf_array = path_info->primitive_array;
+ builder_result.leaf_is_nullable = path_info->leaf_is_nullable;
+
+ if (path_info->max_def_level == 0) {
+ // This case only occurs when there are no nullable or repeated
+ // columns in the path from the root to leaf.
+ int64_t leaf_length = builder_result.leaf_array->length();
+ builder_result.def_rep_level_count = leaf_length;
+ builder_result.post_list_visited_elements.push_back({0, leaf_length});
+ return writer(builder_result);
+ }
+ stack[0] = root_range;
+ RETURN_NOT_OK(
+ arrow_context->def_levels_buffer->Resize(/*new_size=*/0, /*shrink_to_fit*/ false));
+ PathWriteContext context(arrow_context->memory_pool, arrow_context->def_levels_buffer);
+ // We should need at least this many entries so reserve the space ahead of time.
+ RETURN_NOT_OK(context.def_levels.Reserve(root_range.Size()));
+ if (path_info->max_rep_level > 0) {
+ RETURN_NOT_OK(context.rep_levels.Reserve(root_range.Size()));
+ }
+
+ auto stack_base = &stack[0];
+ auto stack_position = stack_base;
+ // This is the main loop for calculated rep/def levels. The nodes
+ // in the path implement a chain-of-responsibility like pattern
+ // where each node can add some number of repetition/definition
+ // levels to PathWriteContext and also delegate to the next node
+ // in the path to add values. The values are added through each Run(...)
+ // call and the choice to delegate to the next node (or return to the
+ // previous node) is communicated by the return value of Run(...).
+ // The loop terminates after the first node indicates all values in
+ // |root_range| are processed.
+ while (stack_position >= stack_base) {
+ PathInfo::Node& node = path_info->path[stack_position - stack_base];
+ struct {
+ IterationResult operator()(NullableNode* node) {
+ return node->Run(stack_position, stack_position + 1, context);
+ }
+ IterationResult operator()(ListNode* node) {
+ return node->Run(stack_position, stack_position + 1, context);
+ }
+ IterationResult operator()(NullableTerminalNode* node) {
+ return node->Run(*stack_position, context);
+ }
+ IterationResult operator()(FixedSizeListNode* node) {
+ return node->Run(stack_position, stack_position + 1, context);
+ }
+ IterationResult operator()(AllPresentTerminalNode* node) {
+ return node->Run(*stack_position, context);
+ }
+ IterationResult operator()(AllNullsTerminalNode* node) {
+ return node->Run(*stack_position, context);
+ }
+ IterationResult operator()(LargeListNode* node) {
+ return node->Run(stack_position, stack_position + 1, context);
+ }
+ ElementRange* stack_position;
+ PathWriteContext* context;
+ } visitor = {stack_position, &context};
+
+ IterationResult result = ::arrow::util::visit(visitor, &node);
+
+ if (ARROW_PREDICT_FALSE(result == kError)) {
+ DCHECK(!context.last_status.ok());
+ return context.last_status;
+ }
+ stack_position += static_cast<int>(result);
+ }
+ RETURN_NOT_OK(context.last_status);
+ builder_result.def_rep_level_count = context.def_levels.length();
+
+ if (context.rep_levels.length() > 0) {
+ // This case only occurs when there was a repeated element that needs to be
+ // processed.
+ builder_result.rep_levels = context.rep_levels.data();
+ std::swap(builder_result.post_list_visited_elements, context.visited_elements);
+ // If it is possible when processing lists that all lists where empty. In this
+ // case no elements would have been added to post_list_visited_elements. By
+ // added an empty element we avoid special casing in downstream consumers.
+ if (builder_result.post_list_visited_elements.empty()) {
+ builder_result.post_list_visited_elements.push_back({0, 0});
+ }
+ } else {
+ builder_result.post_list_visited_elements.push_back(
+ {0, builder_result.leaf_array->length()});
+ builder_result.rep_levels = nullptr;
+ }
+
+ builder_result.def_levels = context.def_levels.data();
+ return writer(builder_result);
+}
+
+struct FixupVisitor {
+ int max_rep_level = -1;
+ int16_t rep_level_if_null = kLevelNotSet;
+
+ template <typename T>
+ void HandleListNode(T* arg) {
+ if (arg->rep_level() == max_rep_level) {
+ arg->SetLast();
+ // after the last list node we don't need to fill
+ // rep levels on null.
+ rep_level_if_null = kLevelNotSet;
+ } else {
+ rep_level_if_null = arg->rep_level();
+ }
+ }
+ void operator()(ListNode* node) { HandleListNode(node); }
+ void operator()(LargeListNode* node) { HandleListNode(node); }
+ void operator()(FixedSizeListNode* node) { HandleListNode(node); }
+
+ // For non-list intermediate nodes.
+ template <typename T>
+ void HandleIntermediateNode(T* arg) {
+ if (rep_level_if_null != kLevelNotSet) {
+ arg->SetRepLevelIfNull(rep_level_if_null);
+ }
+ }
+
+ void operator()(NullableNode* arg) { HandleIntermediateNode(arg); }
+
+ void operator()(AllNullsTerminalNode* arg) {
+ // Even though no processing happens past this point we
+ // still need to adjust it if a list occurred after an
+ // all null array.
+ HandleIntermediateNode(arg);
+ }
+
+ void operator()(NullableTerminalNode*) {}
+ void operator()(AllPresentTerminalNode*) {}
+};
+
+PathInfo Fixup(PathInfo info) {
+ // We only need to fixup the path if there were repeated
+ // elements on it.
+ if (info.max_rep_level == 0) {
+ return info;
+ }
+ FixupVisitor visitor;
+ visitor.max_rep_level = info.max_rep_level;
+ if (visitor.max_rep_level > 0) {
+ visitor.rep_level_if_null = 0;
+ }
+ for (size_t x = 0; x < info.path.size(); x++) {
+ ::arrow::util::visit(visitor, &info.path[x]);
+ }
+ return info;
+}
+
+class PathBuilder {
+ public:
+ explicit PathBuilder(bool start_nullable) : nullable_in_parent_(start_nullable) {}
+ template <typename T>
+ void AddTerminalInfo(const T& array) {
+ info_.leaf_is_nullable = nullable_in_parent_;
+ if (nullable_in_parent_) {
+ info_.max_def_level++;
+ }
+ // We don't use null_count() because if the null_count isn't known
+ // and the array does in fact contain nulls, we will end up
+ // traversing the null bitmap twice (once here and once when calculating
+ // rep/def levels).
+ if (LazyNoNulls(array)) {
+ info_.path.emplace_back(AllPresentTerminalNode{info_.max_def_level});
+ } else if (LazyNullCount(array) == array.length()) {
+ info_.path.emplace_back(AllNullsTerminalNode(info_.max_def_level - 1));
+ } else {
+ info_.path.emplace_back(NullableTerminalNode(array.null_bitmap_data(),
+ array.offset(), info_.max_def_level));
+ }
+ info_.primitive_array = std::make_shared<T>(array.data());
+ paths_.push_back(Fixup(info_));
+ }
+
+ template <typename T>
+ ::arrow::enable_if_t<std::is_base_of<::arrow::FlatArray, T>::value, Status> Visit(
+ const T& array) {
+ AddTerminalInfo(array);
+ return Status::OK();
+ }
+
+ template <typename T>
+ ::arrow::enable_if_t<std::is_same<::arrow::ListArray, T>::value ||
+ std::is_same<::arrow::LargeListArray, T>::value,
+ Status>
+ Visit(const T& array) {
+ MaybeAddNullable(array);
+ // Increment necessary due to empty lists.
+ info_.max_def_level++;
+ info_.max_rep_level++;
+ // raw_value_offsets() accounts for any slice offset.
+ ListPathNode<VarRangeSelector<typename T::offset_type>> node(
+ VarRangeSelector<typename T::offset_type>{array.raw_value_offsets()},
+ info_.max_rep_level, info_.max_def_level - 1);
+ info_.path.emplace_back(std::move(node));
+ nullable_in_parent_ = array.list_type()->value_field()->nullable();
+ return VisitInline(*array.values());
+ }
+
+ Status Visit(const ::arrow::DictionaryArray& array) {
+ // Only currently handle DictionaryArray where the dictionary is a
+ // primitive type
+ if (array.dict_type()->value_type()->num_fields() > 0) {
+ return Status::NotImplemented(
+ "Writing DictionaryArray with nested dictionary "
+ "type not yet supported");
+ }
+ if (array.dictionary()->null_count() > 0) {
+ return Status::NotImplemented(
+ "Writing DictionaryArray with null encoded in dictionary "
+ "type not yet supported");
+ }
+ AddTerminalInfo(array);
+ return Status::OK();
+ }
+
+ void MaybeAddNullable(const Array& array) {
+ if (!nullable_in_parent_) {
+ return;
+ }
+ info_.max_def_level++;
+ // We don't use null_count() because if the null_count isn't known
+ // and the array does in fact contain nulls, we will end up
+ // traversing the null bitmap twice (once here and once when calculating
+ // rep/def levels). Because this isn't terminal this might not be
+ // the right decision for structs that share the same nullable
+ // parents.
+ if (LazyNoNulls(array)) {
+ // Don't add anything because there won't be any point checking
+ // null values for the array. There will always be at least
+ // one more array to handle nullability.
+ return;
+ }
+ if (LazyNullCount(array) == array.length()) {
+ info_.path.emplace_back(AllNullsTerminalNode(info_.max_def_level - 1));
+ return;
+ }
+ info_.path.emplace_back(
+ NullableNode(array.null_bitmap_data(), array.offset(),
+ /* def_level_if_null = */ info_.max_def_level - 1));
+ }
+
+ Status VisitInline(const Array& array);
+
+ Status Visit(const ::arrow::MapArray& array) {
+ return Visit(static_cast<const ::arrow::ListArray&>(array));
+ }
+
+ Status Visit(const ::arrow::StructArray& array) {
+ MaybeAddNullable(array);
+ PathInfo info_backup = info_;
+ for (int x = 0; x < array.num_fields(); x++) {
+ nullable_in_parent_ = array.type()->field(x)->nullable();
+ RETURN_NOT_OK(VisitInline(*array.field(x)));
+ info_ = info_backup;
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const ::arrow::FixedSizeListArray& array) {
+ MaybeAddNullable(array);
+ int32_t list_size = array.list_type()->list_size();
+ // Technically we could encode fixed size lists with two level encodings
+ // but since we always use 3 level encoding we increment def levels as
+ // well.
+ info_.max_def_level++;
+ info_.max_rep_level++;
+ info_.path.emplace_back(FixedSizeListNode(FixedSizedRangeSelector{list_size},
+ info_.max_rep_level, info_.max_def_level));
+ nullable_in_parent_ = array.list_type()->value_field()->nullable();
+ if (array.offset() > 0) {
+ return VisitInline(*array.values()->Slice(array.value_offset(0)));
+ }
+ return VisitInline(*array.values());
+ }
+
+ Status Visit(const ::arrow::ExtensionArray& array) {
+ return VisitInline(*array.storage());
+ }
+
+#define NOT_IMPLEMENTED_VISIT(ArrowTypePrefix) \
+ Status Visit(const ::arrow::ArrowTypePrefix##Array& array) { \
+ return Status::NotImplemented("Level generation for " #ArrowTypePrefix \
+ " not supported yet"); \
+ }
+
+ // Union types aren't supported in Parquet.
+ NOT_IMPLEMENTED_VISIT(Union)
+
+#undef NOT_IMPLEMENTED_VISIT
+ std::vector<PathInfo>& paths() { return paths_; }
+
+ private:
+ PathInfo info_;
+ std::vector<PathInfo> paths_;
+ bool nullable_in_parent_;
+};
+
+Status PathBuilder::VisitInline(const Array& array) {
+ return ::arrow::VisitArrayInline(array, this);
+}
+
+#undef RETURN_IF_ERROR
+} // namespace
+
+class MultipathLevelBuilderImpl : public MultipathLevelBuilder {
+ public:
+ MultipathLevelBuilderImpl(std::shared_ptr<::arrow::ArrayData> data,
+ std::unique_ptr<PathBuilder> path_builder)
+ : root_range_{0, data->length},
+ data_(std::move(data)),
+ path_builder_(std::move(path_builder)) {}
+
+ int GetLeafCount() const override {
+ return static_cast<int>(path_builder_->paths().size());
+ }
+
+ ::arrow::Status Write(int leaf_index, ArrowWriteContext* context,
+ CallbackFunction write_leaf_callback) override {
+ DCHECK_GE(leaf_index, 0);
+ DCHECK_LT(leaf_index, GetLeafCount());
+ return WritePath(root_range_, &path_builder_->paths()[leaf_index], context,
+ std::move(write_leaf_callback));
+ }
+
+ private:
+ ElementRange root_range_;
+ // Reference holder to ensure the data stays valid.
+ std::shared_ptr<::arrow::ArrayData> data_;
+ std::unique_ptr<PathBuilder> path_builder_;
+};
+
+// static
+::arrow::Result<std::unique_ptr<MultipathLevelBuilder>> MultipathLevelBuilder::Make(
+ const ::arrow::Array& array, bool array_field_nullable) {
+ auto constructor = ::arrow::internal::make_unique<PathBuilder>(array_field_nullable);
+ RETURN_NOT_OK(VisitArrayInline(array, constructor.get()));
+ return ::arrow::internal::make_unique<MultipathLevelBuilderImpl>(
+ array.data(), std::move(constructor));
+}
+
+// static
+Status MultipathLevelBuilder::Write(const Array& array, bool array_field_nullable,
+ ArrowWriteContext* context,
+ MultipathLevelBuilder::CallbackFunction callback) {
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<MultipathLevelBuilder> builder,
+ MultipathLevelBuilder::Make(array, array_field_nullable));
+ PathBuilder constructor(array_field_nullable);
+ RETURN_NOT_OK(VisitArrayInline(array, &constructor));
+ for (int leaf_idx = 0; leaf_idx < builder->GetLeafCount(); leaf_idx++) {
+ RETURN_NOT_OK(builder->Write(leaf_idx, context, callback));
+ }
+ return Status::OK();
+}
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/path_internal.h b/src/arrow/cpp/src/parquet/arrow/path_internal.h
new file mode 100644
index 000000000..c5b7fdfda
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/path_internal.h
@@ -0,0 +1,155 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+
+#include "parquet/platform.h"
+
+namespace arrow {
+
+class Array;
+
+} // namespace arrow
+
+namespace parquet {
+
+struct ArrowWriteContext;
+
+namespace arrow {
+
+// This files contain internal implementation details and should not be considered
+// part of the public API.
+
+// The MultipathLevelBuilder is intended to fully support all Arrow nested types that
+// map to parquet types (i.e. Everything but Unions).
+//
+
+/// \brief Half open range of elements in an array.
+struct ElementRange {
+ /// Upper bound of range (inclusive)
+ int64_t start;
+ /// Upper bound of range (exclusive)
+ int64_t end;
+
+ bool Empty() const { return start == end; }
+
+ int64_t Size() const { return end - start; }
+};
+
+/// \brief Result for a single leaf array when running the builder on the
+/// its root.
+struct MultipathLevelBuilderResult {
+ /// \brief The Array containing only the values to write (after all nesting has
+ /// been processed.
+ ///
+ /// No additional processing is done on this array (it is copied as is when
+ /// visited via a DFS).
+ std::shared_ptr<::arrow::Array> leaf_array;
+
+ /// \brief Might be null.
+ const int16_t* def_levels = nullptr;
+
+ /// \brief Might be null.
+ const int16_t* rep_levels = nullptr;
+
+ /// \brief Number of items (int16_t) contained in def/rep_levels when present.
+ int64_t def_rep_level_count = 0;
+
+ /// \brief Contains element ranges of the required visiting on the
+ /// descendants of the final list ancestor for any leaf node.
+ ///
+ /// The algorithm will attempt to consolidate visited ranges into
+ /// the smallest number possible.
+ ///
+ /// This data is necessary to pass along because after producing
+ /// def-rep levels for each leaf array it is impossible to determine
+ /// which values have to be sent to parquet when a null list value
+ /// in a nullable ListArray is non-empty.
+ ///
+ /// This allows for the parquet writing to determine which values ultimately
+ /// needs to be written.
+ std::vector<ElementRange> post_list_visited_elements;
+
+ /// Whether the leaf array is nullable.
+ bool leaf_is_nullable;
+};
+
+/// \brief Logic for being able to write out nesting (rep/def level) data that is
+/// needed for writing to parquet.
+class PARQUET_EXPORT MultipathLevelBuilder {
+ public:
+ /// \brief A callback function that will receive results from the call to
+ /// Write(...) below. The MultipathLevelBuilderResult passed in will
+ /// only remain valid for the function call (i.e. storing it and relying
+ /// for its data to be consistent afterwards will result in undefined
+ /// behavior.
+ using CallbackFunction =
+ std::function<::arrow::Status(const MultipathLevelBuilderResult&)>;
+
+ /// \brief Determine rep/def level information for the array.
+ ///
+ /// The callback will be invoked for each leaf Array that is a
+ /// descendant of array. Each leaf array is processed in a depth
+ /// first traversal-order.
+ ///
+ /// \param[in] array The array to process.
+ /// \param[in] array_field_nullable Whether the algorithm should consider
+ /// the the array column as nullable (as determined by its type's parent
+ /// field).
+ /// \param[in, out] context for use when allocating memory, etc.
+ /// \param[out] write_leaf_callback Callback to receive results.
+ /// There will be one call to the write_leaf_callback for each leaf node.
+ static ::arrow::Status Write(const ::arrow::Array& array, bool array_field_nullable,
+ ArrowWriteContext* context,
+ CallbackFunction write_leaf_callback);
+
+ /// \brief Construct a new instance of the builder.
+ ///
+ /// \param[in] array The array to process.
+ /// \param[in] array_field_nullable Whether the algorithm should consider
+ /// the the array column as nullable (as determined by its type's parent
+ /// field).
+ static ::arrow::Result<std::unique_ptr<MultipathLevelBuilder>> Make(
+ const ::arrow::Array& array, bool array_field_nullable);
+
+ virtual ~MultipathLevelBuilder() = default;
+
+ /// \brief Returns the number of leaf columns that need to be written
+ /// to Parquet.
+ virtual int GetLeafCount() const = 0;
+
+ /// \brief Calls write_leaf_callback with the MultipathLevelBuilderResult corresponding
+ /// to |leaf_index|.
+ ///
+ /// \param[in] leaf_index The index of the leaf column to write. Must be in the range
+ /// [0, GetLeafCount()].
+ /// \param[in, out] context for use when allocating memory, etc.
+ /// \param[out] write_leaf_callback Callback to receive the result.
+ virtual ::arrow::Status Write(int leaf_index, ArrowWriteContext* context,
+ CallbackFunction write_leaf_callback) = 0;
+};
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/path_internal_test.cc b/src/arrow/cpp/src/parquet/arrow/path_internal_test.cc
new file mode 100644
index 000000000..464580700
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/path_internal_test.cc
@@ -0,0 +1,648 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/arrow/path_internal.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+
+#include "parquet/properties.h"
+
+namespace parquet {
+namespace arrow {
+
+using ::arrow::default_memory_pool;
+using ::arrow::field;
+using ::arrow::fixed_size_list;
+using ::arrow::Status;
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+using ::testing::Eq;
+using ::testing::NotNull;
+using ::testing::SizeIs;
+
+class CapturedResult {
+ public:
+ bool null_rep_levels = false;
+ bool null_def_levels = false;
+ std::vector<ElementRange> post_list_elements;
+
+ CapturedResult(const int16_t* def_levels, const int16_t* rep_levels,
+ int64_t def_rep_level_count,
+ std::vector<ElementRange> post_list_visited_elements) {
+ if (def_levels != nullptr) {
+ def_levels_ = std::vector<int16_t>(def_levels, def_levels + def_rep_level_count);
+ } else {
+ null_def_levels = true;
+ }
+ if (rep_levels != nullptr) {
+ rep_levels_ = std::vector<int16_t>(rep_levels, rep_levels + def_rep_level_count);
+ } else {
+ null_rep_levels = true;
+ }
+ post_list_elements = std::move(post_list_visited_elements);
+ }
+
+ explicit CapturedResult(MultipathLevelBuilderResult result)
+ : CapturedResult(result.def_levels, result.rep_levels, result.def_rep_level_count,
+ std::move(result.post_list_visited_elements)) {}
+
+ void CheckLevelsWithNullRepLevels(const std::vector<int16_t>& expected_def) {
+ EXPECT_TRUE(null_rep_levels);
+ ASSERT_FALSE(null_def_levels);
+ EXPECT_THAT(def_levels_, ElementsAreArray(expected_def));
+ }
+
+ void CheckLevels(const std::vector<int16_t>& expected_def,
+ const std::vector<int16_t>& expected_rep) const {
+ ASSERT_FALSE(null_def_levels);
+ ASSERT_FALSE(null_rep_levels);
+ EXPECT_THAT(def_levels_, ElementsAreArray(expected_def));
+ EXPECT_THAT(rep_levels_, ElementsAreArray(expected_rep));
+ }
+
+ friend std::ostream& operator<<(std::ostream& os, const CapturedResult& result) {
+ // This print method is to silence valgrind issues. What's printed
+ // is not important because all asserts happen directly on
+ // members.
+ os << "CapturedResult (null def, null_rep):" << result.null_def_levels << " "
+ << result.null_rep_levels;
+ return os;
+ }
+
+ private:
+ std::vector<int16_t> def_levels_;
+ std::vector<int16_t> rep_levels_;
+};
+
+struct Callback {
+ Status operator()(const MultipathLevelBuilderResult& result) {
+ results->emplace_back(result);
+ return Status::OK();
+ }
+ std::vector<CapturedResult>* results;
+};
+
+class MultipathLevelBuilderTest : public testing::Test {
+ protected:
+ std::vector<CapturedResult> results_;
+ Callback callback_{&results_};
+ std::shared_ptr<ArrowWriterProperties> arrow_properties_ =
+ default_arrow_writer_properties();
+ ArrowWriteContext context_ =
+ ArrowWriteContext(default_memory_pool(), arrow_properties_.get());
+};
+
+TEST_F(MultipathLevelBuilderTest, NonNullableSingleListNonNullableEntries) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/false);
+ auto list_type = large_list(entries);
+ // Translates to parquet schema:
+ // required group bag {
+ // repeated group [unseen] (List) {
+ // required int64 Entries;
+ // }
+ // }
+ // So:
+ // def level 0: an empty list
+ // def level 1: a non-null entry
+ auto array = ::arrow::ArrayFromJSON(list_type, R"([[1], [2, 3], [4, 5, 6]])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/false, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ const CapturedResult& result = results_[0];
+
+ result.CheckLevels(/*def_levels=*/std::vector<int16_t>(/*count=*/6, 1),
+ /*rep_levels=*/{0, 0, 1, 0, 1, 1});
+
+ ASSERT_THAT(result.post_list_elements, SizeIs(1));
+ EXPECT_THAT(result.post_list_elements[0].start, Eq(0));
+ EXPECT_THAT(result.post_list_elements[0].end, Eq(6));
+}
+
+TEST_F(MultipathLevelBuilderTest, NullableSingleListWithAllNullsLists) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/false);
+ auto list_type = list(entries);
+ // Translates to parquet schema:
+ // optional group bag {
+ // repeated group [unseen] (List) {
+ // required int64 Entries;
+ // }
+ // }
+ // So:
+ // def level 0: a null list
+ // def level 1: an empty list
+ // def level 2: a non-null entry
+
+ auto array = ::arrow::ArrayFromJSON(list_type, R"([null, null, null, null])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ const CapturedResult& result = results_[0];
+ result.CheckLevels(/*def_levels=*/std::vector<int16_t>(/*count=*/4, 0),
+ /*rep_levels=*/std::vector<int16_t>(4, 0));
+}
+
+TEST_F(MultipathLevelBuilderTest, NullableSingleListWithMixedElements) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/false);
+ auto list_type = list(entries);
+ // Translates to parquet schema:
+ // optional group bag {
+ // repeated group [unseen] (List) {
+ // required int64 Entries;
+ // }
+ // }
+ // So:
+ // def level 0: a null list
+ // def level 1: an empty list
+ // def level 2: a non-null entry
+
+ auto array = ::arrow::ArrayFromJSON(list_type, R"([null, [], null, [1]])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ const CapturedResult& result = results_[0];
+ result.CheckLevels(/*def_levels=*/std::vector<int16_t>{0, 1, 0, 2},
+ /*rep_levels=*/std::vector<int16_t>(/*count=*/4, 0));
+}
+
+TEST_F(MultipathLevelBuilderTest, EmptyLists) {
+ // ARROW-13676 - ensure no out of bounds list memory accesses.
+ auto entries = field("Entries", ::arrow::int64());
+ auto list_type = list(entries);
+ // Number of elements is important, to work past buffer padding hiding
+ // the issue.
+ auto array = ::arrow::ArrayFromJSON(list_type, R"([
+ [],[],[],[],[],[],[],[],[],[],[],[],[],[],[]])");
+
+ // Translates to parquet schema:
+ // optional group bag {
+ // repeated group [unseen] (List) {
+ // optional int64 Entries;
+ // }
+ // }
+ // So:
+ // def level 0: a null list
+ // def level 1: an empty list
+ // def level 2: a null entry
+ // def level 3: a non-null entry
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ const CapturedResult& result = results_[0];
+ result.CheckLevels(/*def_levels=*/std::vector<int16_t>(/*count=*/15, 1),
+ /*rep_levels=*/std::vector<int16_t>(15, 0));
+}
+
+TEST_F(MultipathLevelBuilderTest, NullableSingleListWithAllEmptyLists) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/false);
+ auto list_type = list(entries);
+ // Translates to parquet schema:
+ // optional group bag {
+ // repeated group [unseen] (List) {
+ // required int64 Entries;
+ // }
+ // }
+ // So:
+ // def level 0: a null list
+ // def level 1: an empty list
+ // def level 2: a non-null entry
+
+ auto array = ::arrow::ArrayFromJSON(list_type, R"([[], [], [], []])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ const CapturedResult& result = results_[0];
+ result.CheckLevels(/*def_levels=*/std::vector<int16_t>(/*count=*/4, 1),
+ /*rep_levels=*/std::vector<int16_t>(4, 0));
+}
+
+// This Parquet schema used for the next several tests
+//
+// optional group bag {
+// repeated group [unseen] (List) {
+// optional int64 Entries;
+// }
+// }
+// So:
+// def level 0: a null list
+// def level 1: an empty list
+// def level 2: a null entry
+// def level 3: a non-null entry
+
+TEST_F(MultipathLevelBuilderTest, NullableSingleListWithAllNullEntries) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/true);
+ auto list_type = list(entries);
+
+ auto array = ::arrow::ArrayFromJSON(list_type, R"([[null], [null], [null], [null]])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ const CapturedResult& result = results_[0];
+ result.CheckLevels(/*def_levels=*/std::vector<int16_t>(/*count=*/4, 2),
+ /*rep_levels=*/std::vector<int16_t>(4, 0));
+ ASSERT_THAT(result.post_list_elements, SizeIs(1));
+ EXPECT_THAT(result.post_list_elements[0].start, Eq(0));
+ EXPECT_THAT(result.post_list_elements[0].end, Eq(4));
+}
+
+TEST_F(MultipathLevelBuilderTest, NullableSingleListWithAllPresentEntries) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/true);
+ auto list_type = list(entries);
+
+ auto array = ::arrow::ArrayFromJSON(list_type, R"([[], [], [1], [], [2, 3]])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ const CapturedResult& result = results_[0];
+ result.CheckLevels(/*def_levels=*/std::vector<int16_t>{1, 1, 3, 1, 3, 3},
+ /*rep_levels=*/std::vector<int16_t>{0, 0, 0, 0, 0, 1});
+
+ ASSERT_THAT(result.post_list_elements, SizeIs(1));
+ EXPECT_THAT(result.post_list_elements[0].start, Eq(0));
+ EXPECT_THAT(result.post_list_elements[0].end, Eq(3));
+}
+
+TEST_F(MultipathLevelBuilderTest, NullableSingleListWithAllEmptyEntries) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/true);
+ auto list_type = list(entries);
+
+ auto array = ::arrow::ArrayFromJSON(list_type, R"([[], [], [], [], []])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ const CapturedResult& result = results_[0];
+ result.CheckLevels(/*def_levels=*/std::vector<int16_t>(/*count=*/5, 1),
+ /*rep_levels=*/std::vector<int16_t>(/*count=*/5, 0));
+}
+
+TEST_F(MultipathLevelBuilderTest, NullableSingleListWithSomeNullEntriesAndSomeNullLists) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/true);
+ auto list_type = list(entries);
+
+ auto array = ::arrow::ArrayFromJSON(
+ list_type, R"([null, [1 , 2, 3], [], [], null, null, [4, 5], [null]])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ const CapturedResult& result = results_[0];
+
+ result.CheckLevels(
+ /*def_levels=*/std::vector<int16_t>{0, 3, 3, 3, 1, 1, 0, 0, 3, 3, 2},
+ /*rep_levels=*/std::vector<int16_t>{0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0});
+}
+
+// This Parquet schema used for the following tests
+//
+// optional group bag {
+// repeated group outer_list (List) {
+// option group nullable {
+// repeated group inner_list (List) {
+// optional int64 Entries;
+// }
+// }
+// }
+// }
+// So:
+// def level 0: a outer list
+// def level 1: an empty outer list
+// def level 2: a null inner list
+// def level 3: an empty inner list
+// def level 4: a null entry
+// def level 5: a non-null entry
+
+TEST_F(MultipathLevelBuilderTest, NestedListsWithSomeEntries) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/true);
+ auto list_field = field("list", list(entries), /*nullable=*/true);
+ auto nested_list_type = list(list_field);
+ auto array = ::arrow::ArrayFromJSON(
+ nested_list_type, R"([null, [[1 , 2, 3], [4, 5]], [[], [], []], []])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ const CapturedResult& result = results_[0];
+ result.CheckLevels(/*def_levels=*/std::vector<int16_t>{0, 5, 5, 5, 5, 5, 3, 3, 3, 1},
+ /*rep_levels=*/std::vector<int16_t>{0, 0, 2, 2, 1, 2, 0, 1, 1, 0});
+}
+
+TEST_F(MultipathLevelBuilderTest, NestedListsWithSomeNulls) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/true);
+ auto list_field = field("list", list(entries), /*nullable=*/true);
+ auto nested_list_type = list(list_field);
+
+ auto array = ::arrow::ArrayFromJSON(nested_list_type,
+ R"([null, [[1, null, 3], null, null], [[4, 5]]])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ const CapturedResult& result = results_[0];
+ result.CheckLevels(/*def_levels=*/std::vector<int16_t>{0, 5, 4, 5, 2, 2, 5, 5},
+ /*rep_levels=*/std::vector<int16_t>{0, 0, 2, 2, 1, 1, 0, 2});
+}
+
+TEST_F(MultipathLevelBuilderTest, NestedListsWithSomeNullsSomeEmptys) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/true);
+ auto list_field = field("list", list(entries), /*nullable=*/true);
+ auto nested_list_type = list(list_field);
+ auto array = ::arrow::ArrayFromJSON(nested_list_type,
+ R"([null, [[1 , null, 3], [], []], [[4, 5]]])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ const CapturedResult& result = results_[0];
+ result.CheckLevels(/*def_levels=*/std::vector<int16_t>{0, 5, 4, 5, 3, 3, 5, 5},
+ /*rep_levels=*/std::vector<int16_t>{0, 0, 2, 2, 1, 1, 0, 2});
+}
+
+// TripleNested schema
+//
+// optional group bag {
+// repeated group outer_list (List) {
+// option group nullable {
+// repeated group middle_list (List) {
+// option group nullable {
+// repeated group inner_list (List) {
+// optional int64 Entries;
+// }
+// }
+// }
+// }
+// }
+// }
+// So:
+// def level 0: a outer list
+// def level 1: an empty outer list
+// def level 2: a null middle list
+// def level 3: an empty middle list
+// def level 4: an null inner list
+// def level 5: an empty inner list
+// def level 6: a null entry
+// def level 7: a non-null entry
+
+TEST_F(MultipathLevelBuilderTest, TripleNestedListsAllPresent) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/true);
+ auto list_field = field("list", list(entries), /*nullable=*/true);
+ auto nested_list_type = list(list_field);
+ auto double_nested_list_type = list(nested_list_type);
+
+ auto array = ::arrow::ArrayFromJSON(double_nested_list_type,
+ R"([ [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]] ])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ const CapturedResult& result = results_[0];
+ result.CheckLevels(/*def_levels=*/std::vector<int16_t>(/*counter=*/9, 7),
+ /*rep_levels=*/std::vector<int16_t>{
+ 0, 3, 3, 2, 3, 3, 1, 3, 3 // first row
+ });
+}
+
+TEST_F(MultipathLevelBuilderTest, TripleNestedListsWithSomeNullsSomeEmptys) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/true);
+ auto list_field = field("list", list(entries), /*nullable=*/true);
+ auto nested_list_type = list(list_field);
+ auto double_nested_list_type = list(nested_list_type);
+ auto array = ::arrow::ArrayFromJSON(double_nested_list_type,
+ R"([
+ [null, [[1 , null, 3], []], []],
+ [[[]], [[], [1, 2]], null, [[3]]],
+ null,
+ []
+ ])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ const CapturedResult& result = results_[0];
+ result.CheckLevels(/*def_levels=*/std::vector<int16_t>{2, 7, 6, 7, 5, 3, // first row
+ 5, 5, 7, 7, 2, 7, // second row
+ 0, // third row
+ 1},
+ /*rep_levels=*/std::vector<int16_t>{0, 1, 3, 3, 2, 1, // first row
+ 0, 1, 2, 3, 1, 1, // second row
+ 0, 0});
+}
+
+TEST_F(MultipathLevelBuilderTest, QuadNestedListsAllPresent) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/true);
+ auto list_field = field("list", list(entries), /*nullable=*/true);
+ auto nested_list_type = list(list_field);
+ auto double_nested_list_type = list(nested_list_type);
+ auto triple_nested_list_type = list(double_nested_list_type);
+
+ auto array = ::arrow::ArrayFromJSON(triple_nested_list_type,
+ R"([ [[[[1, 2], [3, 4]], [[5]]], [[[6, 7, 8]]]],
+ [[[[1, 2], [3, 4]], [[5]]], [[[6, 7, 8]]]] ])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ const CapturedResult& result = results_[0];
+ result.CheckLevels(/*def_levels=*/std::vector<int16_t>(16, 9),
+ /*rep_levels=*/std::vector<int16_t>{
+ 0, 4, 3, 4, 2, 1, 4, 4, //
+ 0, 4, 3, 4, 2, 1, 4, 4 //
+ });
+}
+
+TEST_F(MultipathLevelBuilderTest, TestStruct) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/true);
+ auto list_field = field("list", list(entries), /*nullable=*/true);
+ auto struct_type = ::arrow::struct_({list_field, entries});
+
+ auto array = ::arrow::ArrayFromJSON(struct_type,
+ R"([{"Entries" : 1, "list": [2, 3]},
+ {"Entries" : 4, "list": [5, 6]},
+ null])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+ ASSERT_THAT(results_, SizeIs(2));
+ results_[0].CheckLevels(/*def_levels=*/std::vector<int16_t>{4, 4, 4, 4, 0},
+ /*rep_levels=*/std::vector<int16_t>{0, 1, 0, 1, 0});
+ results_[1].CheckLevelsWithNullRepLevels(
+ /*def_levels=*/std::vector<int16_t>({2, 2, 0}));
+}
+
+TEST_F(MultipathLevelBuilderTest, TestFixedSizeListNullableElements) {
+ auto entries = field("Entries", ::arrow::int64());
+ auto list_type = fixed_size_list(entries, 2);
+ auto array = ::arrow::ArrayFromJSON(list_type, R"([null, [2, 3], [4, 5], null])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ results_[0].CheckLevels(/*def_levels=*/std::vector<int16_t>{0, 3, 3, 3, 3, 0},
+ /*rep_levels=*/std::vector<int16_t>{0, 0, 1, 0, 1, 0});
+
+ // Null slots take up space in a fixed size list (they can in variable size
+ // lists as well) but the actual written values are only the "middle" elements
+ // in this case.
+ ASSERT_THAT(results_[0].post_list_elements, SizeIs(1));
+ EXPECT_THAT(results_[0].post_list_elements[0].start, Eq(2));
+ EXPECT_THAT(results_[0].post_list_elements[0].end, Eq(6));
+}
+
+TEST_F(MultipathLevelBuilderTest, TestFixedSizeList) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/false);
+ auto list_type = fixed_size_list(entries, 2);
+ auto array = ::arrow::ArrayFromJSON(list_type, R"([null, [2, 3], [4, 5], null])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ results_[0].CheckLevels(/*def_levels=*/std::vector<int16_t>{0, 2, 2, 2, 2, 0},
+ /*rep_levels=*/std::vector<int16_t>{0, 0, 1, 0, 1, 0});
+
+ // Null slots take up space in a fixed size list (they can in variable size
+ // lists as well) but the actual written values are only the "middle" elements
+ // in this case.
+ ASSERT_THAT(results_[0].post_list_elements, SizeIs(1));
+ EXPECT_THAT(results_[0].post_list_elements[0].start, Eq(2));
+ EXPECT_THAT(results_[0].post_list_elements[0].end, Eq(6));
+}
+
+TEST_F(MultipathLevelBuilderTest, TestFixedSizeListMissingMiddleHasTwoVisitedRanges) {
+ auto entries = field("Entries", ::arrow::int64(), /*nullable=*/false);
+ auto list_type = fixed_size_list(entries, 2);
+ auto array = ::arrow::ArrayFromJSON(list_type, R"([[0, 1], null, [2, 3]])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+
+ // Null slots take up space in a fixed size list (they can in variable size
+ // lists as well) but the actual written values are only the head and tail elements
+ // in this case.
+ ASSERT_THAT(results_[0].post_list_elements, SizeIs(2));
+ EXPECT_THAT(results_[0].post_list_elements[0].start, Eq(0));
+ EXPECT_THAT(results_[0].post_list_elements[0].end, Eq(2));
+
+ EXPECT_THAT(results_[0].post_list_elements[1].start, Eq(4));
+ EXPECT_THAT(results_[0].post_list_elements[1].end, Eq(6));
+}
+
+TEST_F(MultipathLevelBuilderTest, TestMap) {
+ auto map_type = ::arrow::map(::arrow::int64(), ::arrow::utf8());
+
+ auto array = ::arrow::ArrayFromJSON(map_type,
+ R"([[[1, "a"], [2, "b"]],
+ [[3, "c"], [4, null]],
+ [],
+ null])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/true, &context_, callback_));
+ ASSERT_THAT(results_, SizeIs(2));
+ // optional group bag {
+ // repeated group [unseen] (Map) {
+ // optional group KeyValue {
+ // required int64 key;
+ // optional string value;
+ // }
+ // }
+ // }
+ // So for keys:
+ // def level 0: a null map
+ // def level 1: an empty maps
+ // def level 2: a defined key.
+ //
+ // and for values:
+ // def level 0: a null map
+ // def level 1: an empty maps
+ // def level 2: a null value
+ // def level 3: a present value.
+ //
+
+ results_[0].CheckLevels(/*def_levels=*/
+ std::vector<int16_t>{
+ 2, 2, //
+ 2, 2, //
+ 1, //
+ 0 //
+ },
+ /*rep_levels=*/std::vector<int16_t>{0, 1, //
+ 0, 1, //
+ 0, 0});
+ // entries
+ results_[1].CheckLevels(/*def_levels=*/
+ std::vector<int16_t>{
+ 3, 3, //
+ 3, 2, //
+ 1, //
+ 0 //
+ },
+ /*rep_levels=*/std::vector<int16_t>{0, 1, //
+ 0, 1, //
+ 0, //
+ 0});
+}
+
+TEST_F(MultipathLevelBuilderTest, TestPrimitiveNonNullable) {
+ auto array = ::arrow::ArrayFromJSON(::arrow::int64(), R"([1, 2, 3, 4])");
+
+ ASSERT_OK(
+ MultipathLevelBuilder::Write(*array, /*nullable=*/false, &context_, callback_));
+
+ ASSERT_THAT(results_, SizeIs(1));
+ EXPECT_TRUE(results_[0].null_rep_levels);
+ EXPECT_TRUE(results_[0].null_def_levels);
+
+ ASSERT_THAT(results_[0].post_list_elements, SizeIs(1));
+ EXPECT_THAT(results_[0].post_list_elements[0].start, Eq(0));
+ EXPECT_THAT(results_[0].post_list_elements[0].end, Eq(4));
+}
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/reader.cc b/src/arrow/cpp/src/parquet/arrow/reader.cc
new file mode 100644
index 000000000..1c2331864
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/reader.cc
@@ -0,0 +1,1305 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/arrow/reader.h"
+
+#include <algorithm>
+#include <cstring>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/extension_type.h"
+#include "arrow/io/memory.h"
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/future.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/parallel.h"
+#include "arrow/util/range.h"
+#include "parquet/arrow/reader_internal.h"
+#include "parquet/column_reader.h"
+#include "parquet/exception.h"
+#include "parquet/file_reader.h"
+#include "parquet/metadata.h"
+#include "parquet/properties.h"
+#include "parquet/schema.h"
+
+using arrow::Array;
+using arrow::ArrayData;
+using arrow::BooleanArray;
+using arrow::ChunkedArray;
+using arrow::DataType;
+using arrow::ExtensionType;
+using arrow::Field;
+using arrow::Future;
+using arrow::Int32Array;
+using arrow::ListArray;
+using arrow::MemoryPool;
+using arrow::RecordBatchReader;
+using arrow::ResizableBuffer;
+using arrow::Status;
+using arrow::StructArray;
+using arrow::Table;
+using arrow::TimestampArray;
+
+using arrow::internal::checked_cast;
+using arrow::internal::Iota;
+
+// Help reduce verbosity
+using ParquetReader = parquet::ParquetFileReader;
+
+using parquet::internal::RecordReader;
+
+namespace BitUtil = arrow::BitUtil;
+
+namespace parquet {
+namespace arrow {
+namespace {
+
+::arrow::Result<std::shared_ptr<ArrayData>> ChunksToSingle(const ChunkedArray& chunked) {
+ switch (chunked.num_chunks()) {
+ case 0: {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> array,
+ ::arrow::MakeArrayOfNull(chunked.type(), 0));
+ return array->data();
+ }
+ case 1:
+ return chunked.chunk(0)->data();
+ default:
+ // ARROW-3762(wesm): If item reader yields a chunked array, we reject as
+ // this is not yet implemented
+ return Status::NotImplemented(
+ "Nested data conversions not implemented for chunked array outputs");
+ }
+}
+
+} // namespace
+
+class ColumnReaderImpl : public ColumnReader {
+ public:
+ virtual Status GetDefLevels(const int16_t** data, int64_t* length) = 0;
+ virtual Status GetRepLevels(const int16_t** data, int64_t* length) = 0;
+ virtual const std::shared_ptr<Field> field() = 0;
+
+ ::arrow::Status NextBatch(int64_t batch_size,
+ std::shared_ptr<::arrow::ChunkedArray>* out) final {
+ RETURN_NOT_OK(LoadBatch(batch_size));
+ RETURN_NOT_OK(BuildArray(batch_size, out));
+ for (int x = 0; x < (*out)->num_chunks(); x++) {
+ RETURN_NOT_OK((*out)->chunk(x)->Validate());
+ }
+ return Status::OK();
+ }
+
+ virtual ::arrow::Status LoadBatch(int64_t num_records) = 0;
+
+ virtual ::arrow::Status BuildArray(int64_t length_upper_bound,
+ std::shared_ptr<::arrow::ChunkedArray>* out) = 0;
+ virtual bool IsOrHasRepeatedChild() const = 0;
+};
+
+namespace {
+
+std::shared_ptr<std::unordered_set<int>> VectorToSharedSet(
+ const std::vector<int>& values) {
+ std::shared_ptr<std::unordered_set<int>> result(new std::unordered_set<int>());
+ result->insert(values.begin(), values.end());
+ return result;
+}
+
+// Forward declaration
+Status GetReader(const SchemaField& field, const std::shared_ptr<ReaderContext>& context,
+ std::unique_ptr<ColumnReaderImpl>* out);
+
+// ----------------------------------------------------------------------
+// FileReaderImpl forward declaration
+
+class FileReaderImpl : public FileReader {
+ public:
+ FileReaderImpl(MemoryPool* pool, std::unique_ptr<ParquetFileReader> reader,
+ ArrowReaderProperties properties)
+ : pool_(pool),
+ reader_(std::move(reader)),
+ reader_properties_(std::move(properties)) {}
+
+ Status Init() {
+ return SchemaManifest::Make(reader_->metadata()->schema(),
+ reader_->metadata()->key_value_metadata(),
+ reader_properties_, &manifest_);
+ }
+
+ FileColumnIteratorFactory SomeRowGroupsFactory(std::vector<int> row_groups) {
+ return [row_groups](int i, ParquetFileReader* reader) {
+ return new FileColumnIterator(i, reader, row_groups);
+ };
+ }
+
+ FileColumnIteratorFactory AllRowGroupsFactory() {
+ return SomeRowGroupsFactory(Iota(reader_->metadata()->num_row_groups()));
+ }
+
+ Status BoundsCheckColumn(int column) {
+ if (column < 0 || column >= this->num_columns()) {
+ return Status::Invalid("Column index out of bounds (got ", column,
+ ", should be "
+ "between 0 and ",
+ this->num_columns() - 1, ")");
+ }
+ return Status::OK();
+ }
+
+ Status BoundsCheckRowGroup(int row_group) {
+ // row group indices check
+ if (row_group < 0 || row_group >= num_row_groups()) {
+ return Status::Invalid("Some index in row_group_indices is ", row_group,
+ ", which is either < 0 or >= num_row_groups(",
+ num_row_groups(), ")");
+ }
+ return Status::OK();
+ }
+
+ Status BoundsCheck(const std::vector<int>& row_groups,
+ const std::vector<int>& column_indices) {
+ for (int i : row_groups) {
+ RETURN_NOT_OK(BoundsCheckRowGroup(i));
+ }
+ for (int i : column_indices) {
+ RETURN_NOT_OK(BoundsCheckColumn(i));
+ }
+ return Status::OK();
+ }
+
+ std::shared_ptr<RowGroupReader> RowGroup(int row_group_index) override;
+
+ Status ReadTable(const std::vector<int>& indices,
+ std::shared_ptr<Table>* out) override {
+ return ReadRowGroups(Iota(reader_->metadata()->num_row_groups()), indices, out);
+ }
+
+ Status GetFieldReader(int i,
+ const std::shared_ptr<std::unordered_set<int>>& included_leaves,
+ const std::vector<int>& row_groups,
+ std::unique_ptr<ColumnReaderImpl>* out) {
+ auto ctx = std::make_shared<ReaderContext>();
+ ctx->reader = reader_.get();
+ ctx->pool = pool_;
+ ctx->iterator_factory = SomeRowGroupsFactory(row_groups);
+ ctx->filter_leaves = true;
+ ctx->included_leaves = included_leaves;
+ return GetReader(manifest_.schema_fields[i], ctx, out);
+ }
+
+ Status GetFieldReaders(const std::vector<int>& column_indices,
+ const std::vector<int>& row_groups,
+ std::vector<std::shared_ptr<ColumnReaderImpl>>* out,
+ std::shared_ptr<::arrow::Schema>* out_schema) {
+ // We only need to read schema fields which have columns indicated
+ // in the indices vector
+ ARROW_ASSIGN_OR_RAISE(std::vector<int> field_indices,
+ manifest_.GetFieldIndices(column_indices));
+
+ auto included_leaves = VectorToSharedSet(column_indices);
+
+ out->resize(field_indices.size());
+ ::arrow::FieldVector out_fields(field_indices.size());
+ for (size_t i = 0; i < out->size(); ++i) {
+ std::unique_ptr<ColumnReaderImpl> reader;
+ RETURN_NOT_OK(
+ GetFieldReader(field_indices[i], included_leaves, row_groups, &reader));
+
+ out_fields[i] = reader->field();
+ out->at(i) = std::move(reader);
+ }
+
+ *out_schema = ::arrow::schema(std::move(out_fields), manifest_.schema_metadata);
+ return Status::OK();
+ }
+
+ Status GetColumn(int i, FileColumnIteratorFactory iterator_factory,
+ std::unique_ptr<ColumnReader>* out);
+
+ Status GetColumn(int i, std::unique_ptr<ColumnReader>* out) override {
+ return GetColumn(i, AllRowGroupsFactory(), out);
+ }
+
+ Status GetSchema(std::shared_ptr<::arrow::Schema>* out) override {
+ return FromParquetSchema(reader_->metadata()->schema(), reader_properties_,
+ reader_->metadata()->key_value_metadata(), out);
+ }
+
+ Status ReadSchemaField(int i, std::shared_ptr<ChunkedArray>* out) override {
+ auto included_leaves = VectorToSharedSet(Iota(reader_->metadata()->num_columns()));
+ std::vector<int> row_groups = Iota(reader_->metadata()->num_row_groups());
+
+ std::unique_ptr<ColumnReaderImpl> reader;
+ RETURN_NOT_OK(GetFieldReader(i, included_leaves, row_groups, &reader));
+
+ return ReadColumn(i, row_groups, reader.get(), out);
+ }
+
+ Status ReadColumn(int i, const std::vector<int>& row_groups, ColumnReader* reader,
+ std::shared_ptr<ChunkedArray>* out) {
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ // TODO(wesm): This calculation doesn't make much sense when we have repeated
+ // schema nodes
+ int64_t records_to_read = 0;
+ for (auto row_group : row_groups) {
+ // Can throw exception
+ records_to_read +=
+ reader_->metadata()->RowGroup(row_group)->ColumnChunk(i)->num_values();
+ }
+ return reader->NextBatch(records_to_read, out);
+ END_PARQUET_CATCH_EXCEPTIONS
+ }
+
+ Status ReadColumn(int i, const std::vector<int>& row_groups,
+ std::shared_ptr<ChunkedArray>* out) {
+ std::unique_ptr<ColumnReader> flat_column_reader;
+ RETURN_NOT_OK(GetColumn(i, SomeRowGroupsFactory(row_groups), &flat_column_reader));
+ return ReadColumn(i, row_groups, flat_column_reader.get(), out);
+ }
+
+ Status ReadColumn(int i, std::shared_ptr<ChunkedArray>* out) override {
+ return ReadColumn(i, Iota(reader_->metadata()->num_row_groups()), out);
+ }
+
+ Status ReadTable(std::shared_ptr<Table>* table) override {
+ return ReadTable(Iota(reader_->metadata()->num_columns()), table);
+ }
+
+ Status ReadRowGroups(const std::vector<int>& row_groups,
+ const std::vector<int>& indices,
+ std::shared_ptr<Table>* table) override;
+
+ // Helper method used by ReadRowGroups - read the given row groups/columns, skipping
+ // bounds checks and pre-buffering. Takes a shared_ptr to self to keep the reader
+ // alive in async contexts.
+ Future<std::shared_ptr<Table>> DecodeRowGroups(
+ std::shared_ptr<FileReaderImpl> self, const std::vector<int>& row_groups,
+ const std::vector<int>& column_indices, ::arrow::internal::Executor* cpu_executor);
+
+ Status ReadRowGroups(const std::vector<int>& row_groups,
+ std::shared_ptr<Table>* table) override {
+ return ReadRowGroups(row_groups, Iota(reader_->metadata()->num_columns()), table);
+ }
+
+ Status ReadRowGroup(int row_group_index, const std::vector<int>& column_indices,
+ std::shared_ptr<Table>* out) override {
+ return ReadRowGroups({row_group_index}, column_indices, out);
+ }
+
+ Status ReadRowGroup(int i, std::shared_ptr<Table>* table) override {
+ return ReadRowGroup(i, Iota(reader_->metadata()->num_columns()), table);
+ }
+
+ Status GetRecordBatchReader(const std::vector<int>& row_group_indices,
+ const std::vector<int>& column_indices,
+ std::unique_ptr<RecordBatchReader>* out) override;
+
+ Status GetRecordBatchReader(const std::vector<int>& row_group_indices,
+ std::unique_ptr<RecordBatchReader>* out) override {
+ return GetRecordBatchReader(row_group_indices,
+ Iota(reader_->metadata()->num_columns()), out);
+ }
+
+ ::arrow::Result<::arrow::AsyncGenerator<std::shared_ptr<::arrow::RecordBatch>>>
+ GetRecordBatchGenerator(std::shared_ptr<FileReader> reader,
+ const std::vector<int> row_group_indices,
+ const std::vector<int> column_indices,
+ ::arrow::internal::Executor* cpu_executor,
+ int row_group_readahead) override;
+
+ int num_columns() const { return reader_->metadata()->num_columns(); }
+
+ ParquetFileReader* parquet_reader() const override { return reader_.get(); }
+
+ int num_row_groups() const override { return reader_->metadata()->num_row_groups(); }
+
+ void set_use_threads(bool use_threads) override {
+ reader_properties_.set_use_threads(use_threads);
+ }
+
+ void set_batch_size(int64_t batch_size) override {
+ reader_properties_.set_batch_size(batch_size);
+ }
+
+ const ArrowReaderProperties& properties() const override { return reader_properties_; }
+
+ const SchemaManifest& manifest() const override { return manifest_; }
+
+ Status ScanContents(std::vector<int> columns, const int32_t column_batch_size,
+ int64_t* num_rows) override {
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ *num_rows = ScanFileContents(columns, column_batch_size, reader_.get());
+ return Status::OK();
+ END_PARQUET_CATCH_EXCEPTIONS
+ }
+
+ MemoryPool* pool_;
+ std::unique_ptr<ParquetFileReader> reader_;
+ ArrowReaderProperties reader_properties_;
+
+ SchemaManifest manifest_;
+};
+
+class RowGroupRecordBatchReader : public ::arrow::RecordBatchReader {
+ public:
+ RowGroupRecordBatchReader(::arrow::RecordBatchIterator batches,
+ std::shared_ptr<::arrow::Schema> schema)
+ : batches_(std::move(batches)), schema_(std::move(schema)) {}
+
+ ~RowGroupRecordBatchReader() override {}
+
+ Status ReadNext(std::shared_ptr<::arrow::RecordBatch>* out) override {
+ return batches_.Next().Value(out);
+ }
+
+ std::shared_ptr<::arrow::Schema> schema() const override { return schema_; }
+
+ private:
+ ::arrow::Iterator<std::shared_ptr<::arrow::RecordBatch>> batches_;
+ std::shared_ptr<::arrow::Schema> schema_;
+};
+
+class ColumnChunkReaderImpl : public ColumnChunkReader {
+ public:
+ ColumnChunkReaderImpl(FileReaderImpl* impl, int row_group_index, int column_index)
+ : impl_(impl), column_index_(column_index), row_group_index_(row_group_index) {}
+
+ Status Read(std::shared_ptr<::arrow::ChunkedArray>* out) override {
+ return impl_->ReadColumn(column_index_, {row_group_index_}, out);
+ }
+
+ private:
+ FileReaderImpl* impl_;
+ int column_index_;
+ int row_group_index_;
+};
+
+class RowGroupReaderImpl : public RowGroupReader {
+ public:
+ RowGroupReaderImpl(FileReaderImpl* impl, int row_group_index)
+ : impl_(impl), row_group_index_(row_group_index) {}
+
+ std::shared_ptr<ColumnChunkReader> Column(int column_index) override {
+ return std::shared_ptr<ColumnChunkReader>(
+ new ColumnChunkReaderImpl(impl_, row_group_index_, column_index));
+ }
+
+ Status ReadTable(const std::vector<int>& column_indices,
+ std::shared_ptr<::arrow::Table>* out) override {
+ return impl_->ReadRowGroup(row_group_index_, column_indices, out);
+ }
+
+ Status ReadTable(std::shared_ptr<::arrow::Table>* out) override {
+ return impl_->ReadRowGroup(row_group_index_, out);
+ }
+
+ private:
+ FileReaderImpl* impl_;
+ int row_group_index_;
+};
+
+// ----------------------------------------------------------------------
+// Column reader implementations
+
+// Leaf reader is for primitive arrays and primitive children of nested arrays
+class LeafReader : public ColumnReaderImpl {
+ public:
+ LeafReader(std::shared_ptr<ReaderContext> ctx, std::shared_ptr<Field> field,
+ std::unique_ptr<FileColumnIterator> input,
+ ::parquet::internal::LevelInfo leaf_info)
+ : ctx_(std::move(ctx)),
+ field_(std::move(field)),
+ input_(std::move(input)),
+ descr_(input_->descr()) {
+ record_reader_ = RecordReader::Make(
+ descr_, leaf_info, ctx_->pool, field_->type()->id() == ::arrow::Type::DICTIONARY);
+ NextRowGroup();
+ }
+
+ Status GetDefLevels(const int16_t** data, int64_t* length) final {
+ *data = record_reader_->def_levels();
+ *length = record_reader_->levels_position();
+ return Status::OK();
+ }
+
+ Status GetRepLevels(const int16_t** data, int64_t* length) final {
+ *data = record_reader_->rep_levels();
+ *length = record_reader_->levels_position();
+ return Status::OK();
+ }
+
+ bool IsOrHasRepeatedChild() const final { return false; }
+
+ Status LoadBatch(int64_t records_to_read) final {
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ out_ = nullptr;
+ record_reader_->Reset();
+ // Pre-allocation gives much better performance for flat columns
+ record_reader_->Reserve(records_to_read);
+ while (records_to_read > 0) {
+ if (!record_reader_->HasMoreData()) {
+ break;
+ }
+ int64_t records_read = record_reader_->ReadRecords(records_to_read);
+ records_to_read -= records_read;
+ if (records_read == 0) {
+ NextRowGroup();
+ }
+ }
+ RETURN_NOT_OK(TransferColumnData(record_reader_.get(), field_->type(), descr_,
+ ctx_->pool, &out_));
+ return Status::OK();
+ END_PARQUET_CATCH_EXCEPTIONS
+ }
+
+ ::arrow::Status BuildArray(int64_t length_upper_bound,
+ std::shared_ptr<::arrow::ChunkedArray>* out) final {
+ *out = out_;
+ return Status::OK();
+ }
+
+ const std::shared_ptr<Field> field() override { return field_; }
+
+ private:
+ std::shared_ptr<ChunkedArray> out_;
+ void NextRowGroup() {
+ std::unique_ptr<PageReader> page_reader = input_->NextChunk();
+ record_reader_->SetPageReader(std::move(page_reader));
+ }
+
+ std::shared_ptr<ReaderContext> ctx_;
+ std::shared_ptr<Field> field_;
+ std::unique_ptr<FileColumnIterator> input_;
+ const ColumnDescriptor* descr_;
+ std::shared_ptr<RecordReader> record_reader_;
+};
+
+// Column reader for extension arrays
+class ExtensionReader : public ColumnReaderImpl {
+ public:
+ ExtensionReader(std::shared_ptr<Field> field,
+ std::unique_ptr<ColumnReaderImpl> storage_reader)
+ : field_(std::move(field)), storage_reader_(std::move(storage_reader)) {}
+
+ Status GetDefLevels(const int16_t** data, int64_t* length) override {
+ return storage_reader_->GetDefLevels(data, length);
+ }
+
+ Status GetRepLevels(const int16_t** data, int64_t* length) override {
+ return storage_reader_->GetRepLevels(data, length);
+ }
+
+ Status LoadBatch(int64_t number_of_records) final {
+ return storage_reader_->LoadBatch(number_of_records);
+ }
+
+ Status BuildArray(int64_t length_upper_bound,
+ std::shared_ptr<ChunkedArray>* out) override {
+ std::shared_ptr<ChunkedArray> storage;
+ RETURN_NOT_OK(storage_reader_->BuildArray(length_upper_bound, &storage));
+ *out = ExtensionType::WrapArray(field_->type(), storage);
+ return Status::OK();
+ }
+
+ bool IsOrHasRepeatedChild() const final {
+ return storage_reader_->IsOrHasRepeatedChild();
+ }
+
+ const std::shared_ptr<Field> field() override { return field_; }
+
+ private:
+ std::shared_ptr<Field> field_;
+ std::unique_ptr<ColumnReaderImpl> storage_reader_;
+};
+
+template <typename IndexType>
+class ListReader : public ColumnReaderImpl {
+ public:
+ ListReader(std::shared_ptr<ReaderContext> ctx, std::shared_ptr<Field> field,
+ ::parquet::internal::LevelInfo level_info,
+ std::unique_ptr<ColumnReaderImpl> child_reader)
+ : ctx_(std::move(ctx)),
+ field_(std::move(field)),
+ level_info_(level_info),
+ item_reader_(std::move(child_reader)) {}
+
+ Status GetDefLevels(const int16_t** data, int64_t* length) override {
+ return item_reader_->GetDefLevels(data, length);
+ }
+
+ Status GetRepLevels(const int16_t** data, int64_t* length) override {
+ return item_reader_->GetRepLevels(data, length);
+ }
+
+ bool IsOrHasRepeatedChild() const final { return true; }
+
+ Status LoadBatch(int64_t number_of_records) final {
+ return item_reader_->LoadBatch(number_of_records);
+ }
+
+ virtual ::arrow::Result<std::shared_ptr<ChunkedArray>> AssembleArray(
+ std::shared_ptr<ArrayData> data) {
+ if (field_->type()->id() == ::arrow::Type::MAP) {
+ // Error out if data is not map-compliant instead of aborting in MakeArray below
+ RETURN_NOT_OK(::arrow::MapArray::ValidateChildData(data->child_data));
+ }
+ std::shared_ptr<Array> result = ::arrow::MakeArray(data);
+ return std::make_shared<ChunkedArray>(result);
+ }
+
+ Status BuildArray(int64_t length_upper_bound,
+ std::shared_ptr<ChunkedArray>* out) override {
+ const int16_t* def_levels;
+ const int16_t* rep_levels;
+ int64_t num_levels;
+ RETURN_NOT_OK(item_reader_->GetDefLevels(&def_levels, &num_levels));
+ RETURN_NOT_OK(item_reader_->GetRepLevels(&rep_levels, &num_levels));
+
+ std::shared_ptr<ResizableBuffer> validity_buffer;
+ ::parquet::internal::ValidityBitmapInputOutput validity_io;
+ validity_io.values_read_upper_bound = length_upper_bound;
+ if (field_->nullable()) {
+ ARROW_ASSIGN_OR_RAISE(
+ validity_buffer,
+ AllocateResizableBuffer(BitUtil::BytesForBits(length_upper_bound), ctx_->pool));
+ validity_io.valid_bits = validity_buffer->mutable_data();
+ }
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr<ResizableBuffer> offsets_buffer,
+ AllocateResizableBuffer(
+ sizeof(IndexType) * std::max(int64_t{1}, length_upper_bound + 1),
+ ctx_->pool));
+ // Ensure zero initialization in case we have reached a zero length list (and
+ // because first entry is always zero).
+ IndexType* offset_data = reinterpret_cast<IndexType*>(offsets_buffer->mutable_data());
+ offset_data[0] = 0;
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ ::parquet::internal::DefRepLevelsToList(def_levels, rep_levels, num_levels,
+ level_info_, &validity_io, offset_data);
+ END_PARQUET_CATCH_EXCEPTIONS
+
+ RETURN_NOT_OK(item_reader_->BuildArray(offset_data[validity_io.values_read], out));
+
+ // Resize to actual number of elements returned.
+ RETURN_NOT_OK(
+ offsets_buffer->Resize((validity_io.values_read + 1) * sizeof(IndexType)));
+ if (validity_buffer != nullptr) {
+ RETURN_NOT_OK(
+ validity_buffer->Resize(BitUtil::BytesForBits(validity_io.values_read)));
+ validity_buffer->ZeroPadding();
+ }
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ArrayData> item_chunk, ChunksToSingle(**out));
+
+ std::vector<std::shared_ptr<Buffer>> buffers{
+ validity_io.null_count > 0 ? validity_buffer : nullptr, offsets_buffer};
+ auto data = std::make_shared<ArrayData>(
+ field_->type(),
+ /*length=*/validity_io.values_read, std::move(buffers),
+ std::vector<std::shared_ptr<ArrayData>>{item_chunk}, validity_io.null_count);
+
+ ARROW_ASSIGN_OR_RAISE(*out, AssembleArray(std::move(data)));
+ return Status::OK();
+ }
+
+ const std::shared_ptr<Field> field() override { return field_; }
+
+ private:
+ std::shared_ptr<ReaderContext> ctx_;
+ std::shared_ptr<Field> field_;
+ ::parquet::internal::LevelInfo level_info_;
+ std::unique_ptr<ColumnReaderImpl> item_reader_;
+};
+
+class PARQUET_NO_EXPORT FixedSizeListReader : public ListReader<int32_t> {
+ public:
+ FixedSizeListReader(std::shared_ptr<ReaderContext> ctx, std::shared_ptr<Field> field,
+ ::parquet::internal::LevelInfo level_info,
+ std::unique_ptr<ColumnReaderImpl> child_reader)
+ : ListReader(std::move(ctx), std::move(field), level_info,
+ std::move(child_reader)) {}
+ ::arrow::Result<std::shared_ptr<ChunkedArray>> AssembleArray(
+ std::shared_ptr<ArrayData> data) final {
+ DCHECK_EQ(data->buffers.size(), 2);
+ DCHECK_EQ(field()->type()->id(), ::arrow::Type::FIXED_SIZE_LIST);
+ const auto& type = checked_cast<::arrow::FixedSizeListType&>(*field()->type());
+ const int32_t* offsets = reinterpret_cast<const int32_t*>(data->buffers[1]->data());
+ for (int x = 1; x <= data->length; x++) {
+ int32_t size = offsets[x] - offsets[x - 1];
+ if (size != type.list_size()) {
+ return Status::Invalid("Expected all lists to be of size=", type.list_size(),
+ " but index ", x, " had size=", size);
+ }
+ }
+ data->buffers.resize(1);
+ std::shared_ptr<Array> result = ::arrow::MakeArray(data);
+ return std::make_shared<ChunkedArray>(result);
+ }
+};
+
+class PARQUET_NO_EXPORT StructReader : public ColumnReaderImpl {
+ public:
+ explicit StructReader(std::shared_ptr<ReaderContext> ctx,
+ std::shared_ptr<Field> filtered_field,
+ ::parquet::internal::LevelInfo level_info,
+ std::vector<std::unique_ptr<ColumnReaderImpl>> children)
+ : ctx_(std::move(ctx)),
+ filtered_field_(std::move(filtered_field)),
+ level_info_(level_info),
+ children_(std::move(children)) {
+ // There could be a mix of children some might be repeated some might not be.
+ // If possible use one that isn't since that will be guaranteed to have the least
+ // number of levels to reconstruct a nullable bitmap.
+ auto result = std::find_if(children_.begin(), children_.end(),
+ [](const std::unique_ptr<ColumnReaderImpl>& child) {
+ return !child->IsOrHasRepeatedChild();
+ });
+ if (result != children_.end()) {
+ def_rep_level_child_ = result->get();
+ has_repeated_child_ = false;
+ } else if (!children_.empty()) {
+ def_rep_level_child_ = children_.front().get();
+ has_repeated_child_ = true;
+ }
+ }
+
+ bool IsOrHasRepeatedChild() const final { return has_repeated_child_; }
+
+ Status LoadBatch(int64_t records_to_read) override {
+ for (const std::unique_ptr<ColumnReaderImpl>& reader : children_) {
+ RETURN_NOT_OK(reader->LoadBatch(records_to_read));
+ }
+ return Status::OK();
+ }
+ Status BuildArray(int64_t length_upper_bound,
+ std::shared_ptr<ChunkedArray>* out) override;
+ Status GetDefLevels(const int16_t** data, int64_t* length) override;
+ Status GetRepLevels(const int16_t** data, int64_t* length) override;
+ const std::shared_ptr<Field> field() override { return filtered_field_; }
+
+ private:
+ const std::shared_ptr<ReaderContext> ctx_;
+ const std::shared_ptr<Field> filtered_field_;
+ const ::parquet::internal::LevelInfo level_info_;
+ const std::vector<std::unique_ptr<ColumnReaderImpl>> children_;
+ ColumnReaderImpl* def_rep_level_child_ = nullptr;
+ bool has_repeated_child_;
+};
+
+Status StructReader::GetDefLevels(const int16_t** data, int64_t* length) {
+ *data = nullptr;
+ if (children_.size() == 0) {
+ *length = 0;
+ return Status::Invalid("StructReader had no children");
+ }
+
+ // This method should only be called when this struct or one of its parents
+ // are optional/repeated or it has a repeated child.
+ // Meaning all children must have rep/def levels associated
+ // with them.
+ RETURN_NOT_OK(def_rep_level_child_->GetDefLevels(data, length));
+ return Status::OK();
+}
+
+Status StructReader::GetRepLevels(const int16_t** data, int64_t* length) {
+ *data = nullptr;
+ if (children_.size() == 0) {
+ *length = 0;
+ return Status::Invalid("StructReader had no childre");
+ }
+
+ // This method should only be called when this struct or one of its parents
+ // are optional/repeated or it has repeated child.
+ // Meaning all children must have rep/def levels associated
+ // with them.
+ RETURN_NOT_OK(def_rep_level_child_->GetRepLevels(data, length));
+ return Status::OK();
+}
+
+Status StructReader::BuildArray(int64_t length_upper_bound,
+ std::shared_ptr<ChunkedArray>* out) {
+ std::vector<std::shared_ptr<ArrayData>> children_array_data;
+ std::shared_ptr<ResizableBuffer> null_bitmap;
+
+ ::parquet::internal::ValidityBitmapInputOutput validity_io;
+ validity_io.values_read_upper_bound = length_upper_bound;
+ // This simplifies accounting below.
+ validity_io.values_read = length_upper_bound;
+
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ const int16_t* def_levels;
+ const int16_t* rep_levels;
+ int64_t num_levels;
+
+ if (has_repeated_child_) {
+ ARROW_ASSIGN_OR_RAISE(
+ null_bitmap,
+ AllocateResizableBuffer(BitUtil::BytesForBits(length_upper_bound), ctx_->pool));
+ validity_io.valid_bits = null_bitmap->mutable_data();
+ RETURN_NOT_OK(GetDefLevels(&def_levels, &num_levels));
+ RETURN_NOT_OK(GetRepLevels(&rep_levels, &num_levels));
+ DefRepLevelsToBitmap(def_levels, rep_levels, num_levels, level_info_, &validity_io);
+ } else if (filtered_field_->nullable()) {
+ ARROW_ASSIGN_OR_RAISE(
+ null_bitmap,
+ AllocateResizableBuffer(BitUtil::BytesForBits(length_upper_bound), ctx_->pool));
+ validity_io.valid_bits = null_bitmap->mutable_data();
+ RETURN_NOT_OK(GetDefLevels(&def_levels, &num_levels));
+ DefLevelsToBitmap(def_levels, num_levels, level_info_, &validity_io);
+ }
+
+ // Ensure all values are initialized.
+ if (null_bitmap) {
+ RETURN_NOT_OK(null_bitmap->Resize(BitUtil::BytesForBits(validity_io.values_read)));
+ null_bitmap->ZeroPadding();
+ }
+
+ END_PARQUET_CATCH_EXCEPTIONS
+ // Gather children arrays and def levels
+ for (auto& child : children_) {
+ std::shared_ptr<ChunkedArray> field;
+ RETURN_NOT_OK(child->BuildArray(validity_io.values_read, &field));
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ArrayData> array_data, ChunksToSingle(*field));
+ children_array_data.push_back(std::move(array_data));
+ }
+
+ if (!filtered_field_->nullable() && !has_repeated_child_) {
+ validity_io.values_read = children_array_data.front()->length;
+ }
+
+ std::vector<std::shared_ptr<Buffer>> buffers{validity_io.null_count > 0 ? null_bitmap
+ : nullptr};
+ auto data =
+ std::make_shared<ArrayData>(filtered_field_->type(),
+ /*length=*/validity_io.values_read, std::move(buffers),
+ std::move(children_array_data));
+ std::shared_ptr<Array> result = ::arrow::MakeArray(data);
+
+ *out = std::make_shared<ChunkedArray>(result);
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// File reader implementation
+
+Status GetReader(const SchemaField& field, const std::shared_ptr<Field>& arrow_field,
+ const std::shared_ptr<ReaderContext>& ctx,
+ std::unique_ptr<ColumnReaderImpl>* out) {
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+
+ auto type_id = arrow_field->type()->id();
+
+ if (type_id == ::arrow::Type::EXTENSION) {
+ auto storage_field = arrow_field->WithType(
+ checked_cast<const ExtensionType&>(*arrow_field->type()).storage_type());
+ RETURN_NOT_OK(GetReader(field, storage_field, ctx, out));
+ out->reset(new ExtensionReader(arrow_field, std::move(*out)));
+ return Status::OK();
+ }
+
+ if (field.children.size() == 0) {
+ if (!field.is_leaf()) {
+ return Status::Invalid("Parquet non-leaf node has no children");
+ }
+ if (!ctx->IncludesLeaf(field.column_index)) {
+ *out = nullptr;
+ return Status::OK();
+ }
+ std::unique_ptr<FileColumnIterator> input(
+ ctx->iterator_factory(field.column_index, ctx->reader));
+ out->reset(new LeafReader(ctx, arrow_field, std::move(input), field.level_info));
+ } else if (type_id == ::arrow::Type::LIST || type_id == ::arrow::Type::MAP ||
+ type_id == ::arrow::Type::FIXED_SIZE_LIST ||
+ type_id == ::arrow::Type::LARGE_LIST) {
+ auto list_field = arrow_field;
+ auto child = &field.children[0];
+ std::unique_ptr<ColumnReaderImpl> child_reader;
+ RETURN_NOT_OK(GetReader(*child, ctx, &child_reader));
+ if (child_reader == nullptr) {
+ *out = nullptr;
+ return Status::OK();
+ }
+
+ // These two types might not be equal if there column pruning occurred.
+ // further down the stack.
+ const std::shared_ptr<DataType> reader_child_type = child_reader->field()->type();
+ // This should really never happen but was raised as a question on the code
+ // review, this should be pretty cheap check so leave it in.
+ if (ARROW_PREDICT_FALSE(list_field->type()->num_fields() != 1)) {
+ return Status::Invalid("expected exactly one child field for: ",
+ list_field->ToString());
+ }
+ const DataType& schema_child_type = *(list_field->type()->field(0)->type());
+ if (type_id == ::arrow::Type::MAP) {
+ if (reader_child_type->num_fields() != 2 ||
+ !reader_child_type->field(0)->type()->Equals(
+ *schema_child_type.field(0)->type())) {
+ // This case applies if either key or value are completed filtered
+ // out so we can take the type as is or the key was partially
+ // so keeping it as a map no longer makes sence.
+ list_field = list_field->WithType(::arrow::list(child_reader->field()));
+ } else if (!reader_child_type->field(1)->type()->Equals(
+ *schema_child_type.field(1)->type())) {
+ list_field = list_field->WithType(std::make_shared<::arrow::MapType>(
+ reader_child_type->field(
+ 0), // field 0 is unchanged baed on previous if statement
+ reader_child_type->field(1)));
+ }
+ // Map types are list<struct<key, value>> so use ListReader
+ // for reconstruction.
+ out->reset(new ListReader<int32_t>(ctx, list_field, field.level_info,
+ std::move(child_reader)));
+ } else if (type_id == ::arrow::Type::LIST) {
+ if (!reader_child_type->Equals(schema_child_type)) {
+ list_field = list_field->WithType(::arrow::list(reader_child_type));
+ }
+
+ out->reset(new ListReader<int32_t>(ctx, list_field, field.level_info,
+ std::move(child_reader)));
+ } else if (type_id == ::arrow::Type::LARGE_LIST) {
+ if (!reader_child_type->Equals(schema_child_type)) {
+ list_field = list_field->WithType(::arrow::large_list(reader_child_type));
+ }
+
+ out->reset(new ListReader<int64_t>(ctx, list_field, field.level_info,
+ std::move(child_reader)));
+ } else if (type_id == ::arrow::Type::FIXED_SIZE_LIST) {
+ if (!reader_child_type->Equals(schema_child_type)) {
+ auto& fixed_list_type =
+ checked_cast<const ::arrow::FixedSizeListType&>(*list_field->type());
+ int32_t list_size = fixed_list_type.list_size();
+ list_field =
+ list_field->WithType(::arrow::fixed_size_list(reader_child_type, list_size));
+ }
+
+ out->reset(new FixedSizeListReader(ctx, list_field, field.level_info,
+ std::move(child_reader)));
+ } else {
+ return Status::UnknownError("Unknown list type: ", field.field->ToString());
+ }
+ } else if (type_id == ::arrow::Type::STRUCT) {
+ std::vector<std::shared_ptr<Field>> child_fields;
+ int arrow_field_idx = 0;
+ std::vector<std::unique_ptr<ColumnReaderImpl>> child_readers;
+ for (const auto& child : field.children) {
+ std::unique_ptr<ColumnReaderImpl> child_reader;
+ RETURN_NOT_OK(GetReader(child, ctx, &child_reader));
+ if (!child_reader) {
+ arrow_field_idx++;
+ // If all children were pruned, then we do not try to read this field
+ continue;
+ }
+ std::shared_ptr<::arrow::Field> child_field = child.field;
+ const DataType& reader_child_type = *child_reader->field()->type();
+ const DataType& schema_child_type =
+ *arrow_field->type()->field(arrow_field_idx++)->type();
+ // These might not be equal if column pruning occurred.
+ if (!schema_child_type.Equals(reader_child_type)) {
+ child_field = child_field->WithType(child_reader->field()->type());
+ }
+ child_fields.push_back(child_field);
+ child_readers.emplace_back(std::move(child_reader));
+ }
+ if (child_fields.size() == 0) {
+ *out = nullptr;
+ return Status::OK();
+ }
+ auto filtered_field =
+ ::arrow::field(arrow_field->name(), ::arrow::struct_(child_fields),
+ arrow_field->nullable(), arrow_field->metadata());
+ out->reset(new StructReader(ctx, filtered_field, field.level_info,
+ std::move(child_readers)));
+ } else {
+ return Status::Invalid("Unsupported nested type: ", arrow_field->ToString());
+ }
+ return Status::OK();
+
+ END_PARQUET_CATCH_EXCEPTIONS
+}
+
+Status GetReader(const SchemaField& field, const std::shared_ptr<ReaderContext>& ctx,
+ std::unique_ptr<ColumnReaderImpl>* out) {
+ return GetReader(field, field.field, ctx, out);
+}
+
+} // namespace
+
+Status FileReaderImpl::GetRecordBatchReader(const std::vector<int>& row_groups,
+ const std::vector<int>& column_indices,
+ std::unique_ptr<RecordBatchReader>* out) {
+ RETURN_NOT_OK(BoundsCheck(row_groups, column_indices));
+
+ if (reader_properties_.pre_buffer()) {
+ // PARQUET-1698/PARQUET-1820: pre-buffer row groups/column chunks if enabled
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ reader_->PreBuffer(row_groups, column_indices, reader_properties_.io_context(),
+ reader_properties_.cache_options());
+ END_PARQUET_CATCH_EXCEPTIONS
+ }
+
+ std::vector<std::shared_ptr<ColumnReaderImpl>> readers;
+ std::shared_ptr<::arrow::Schema> batch_schema;
+ RETURN_NOT_OK(GetFieldReaders(column_indices, row_groups, &readers, &batch_schema));
+
+ if (readers.empty()) {
+ // Just generate all batches right now; they're cheap since they have no columns.
+ int64_t batch_size = properties().batch_size();
+ auto max_sized_batch =
+ ::arrow::RecordBatch::Make(batch_schema, batch_size, ::arrow::ArrayVector{});
+
+ ::arrow::RecordBatchVector batches;
+
+ for (int row_group : row_groups) {
+ int64_t num_rows = parquet_reader()->metadata()->RowGroup(row_group)->num_rows();
+
+ batches.insert(batches.end(), num_rows / batch_size, max_sized_batch);
+
+ if (int64_t trailing_rows = num_rows % batch_size) {
+ batches.push_back(max_sized_batch->Slice(0, trailing_rows));
+ }
+ }
+
+ *out = ::arrow::internal::make_unique<RowGroupRecordBatchReader>(
+ ::arrow::MakeVectorIterator(std::move(batches)), std::move(batch_schema));
+
+ return Status::OK();
+ }
+
+ int64_t num_rows = 0;
+ for (int row_group : row_groups) {
+ num_rows += parquet_reader()->metadata()->RowGroup(row_group)->num_rows();
+ }
+
+ using ::arrow::RecordBatchIterator;
+
+ // NB: This lambda will be invoked outside the scope of this call to
+ // `GetRecordBatchReader()`, so it must capture `readers` and `batch_schema` by value.
+ // `this` is a non-owning pointer so we are relying on the parent FileReader outliving
+ // this RecordBatchReader.
+ ::arrow::Iterator<RecordBatchIterator> batches = ::arrow::MakeFunctionIterator(
+ [readers, batch_schema, num_rows,
+ this]() mutable -> ::arrow::Result<RecordBatchIterator> {
+ ::arrow::ChunkedArrayVector columns(readers.size());
+
+ // don't reserve more rows than necessary
+ int64_t batch_size = std::min(properties().batch_size(), num_rows);
+ num_rows -= batch_size;
+
+ RETURN_NOT_OK(::arrow::internal::OptionalParallelFor(
+ reader_properties_.use_threads(), static_cast<int>(readers.size()),
+ [&](int i) { return readers[i]->NextBatch(batch_size, &columns[i]); }));
+
+ for (const auto& column : columns) {
+ if (column == nullptr || column->length() == 0) {
+ return ::arrow::IterationTraits<RecordBatchIterator>::End();
+ }
+ }
+
+ auto table = ::arrow::Table::Make(batch_schema, std::move(columns));
+ auto table_reader = std::make_shared<::arrow::TableBatchReader>(*table);
+
+ // NB: explicitly preserve table so that table_reader doesn't outlive it
+ return ::arrow::MakeFunctionIterator(
+ [table, table_reader] { return table_reader->Next(); });
+ });
+
+ *out = ::arrow::internal::make_unique<RowGroupRecordBatchReader>(
+ ::arrow::MakeFlattenIterator(std::move(batches)), std::move(batch_schema));
+
+ return Status::OK();
+}
+
+/// Given a file reader and a list of row groups, this is a generator of record
+/// batch generators (where each sub-generator is the contents of a single row group).
+class RowGroupGenerator {
+ public:
+ using RecordBatchGenerator =
+ ::arrow::AsyncGenerator<std::shared_ptr<::arrow::RecordBatch>>;
+
+ explicit RowGroupGenerator(std::shared_ptr<FileReaderImpl> arrow_reader,
+ ::arrow::internal::Executor* cpu_executor,
+ std::vector<int> row_groups, std::vector<int> column_indices)
+ : arrow_reader_(std::move(arrow_reader)),
+ cpu_executor_(cpu_executor),
+ row_groups_(std::move(row_groups)),
+ column_indices_(std::move(column_indices)),
+ index_(0) {}
+
+ ::arrow::Future<RecordBatchGenerator> operator()() {
+ if (index_ >= row_groups_.size()) {
+ return ::arrow::AsyncGeneratorEnd<RecordBatchGenerator>();
+ }
+ int row_group = row_groups_[index_++];
+ std::vector<int> column_indices = column_indices_;
+ auto reader = arrow_reader_;
+ if (!reader->properties().pre_buffer()) {
+ return SubmitRead(cpu_executor_, reader, row_group, column_indices);
+ }
+ auto ready = reader->parquet_reader()->WhenBuffered({row_group}, column_indices);
+ if (cpu_executor_) ready = cpu_executor_->TransferAlways(ready);
+ return ready.Then([=]() -> ::arrow::Future<RecordBatchGenerator> {
+ return ReadOneRowGroup(cpu_executor_, reader, row_group, column_indices);
+ });
+ }
+
+ private:
+ // Synchronous fallback for when pre-buffer isn't enabled.
+ //
+ // Making the Parquet reader truly asynchronous requires heavy refactoring, so the
+ // generator piggybacks on ReadRangeCache. The lazy ReadRangeCache can be used for
+ // async I/O without forcing readahead.
+ static ::arrow::Future<RecordBatchGenerator> SubmitRead(
+ ::arrow::internal::Executor* cpu_executor, std::shared_ptr<FileReaderImpl> self,
+ const int row_group, const std::vector<int>& column_indices) {
+ if (!cpu_executor) {
+ return ReadOneRowGroup(cpu_executor, self, row_group, column_indices);
+ }
+ // If we have an executor, then force transfer (even if I/O was complete)
+ return ::arrow::DeferNotOk(cpu_executor->Submit(ReadOneRowGroup, cpu_executor, self,
+ row_group, column_indices));
+ }
+
+ static ::arrow::Future<RecordBatchGenerator> ReadOneRowGroup(
+ ::arrow::internal::Executor* cpu_executor, std::shared_ptr<FileReaderImpl> self,
+ const int row_group, const std::vector<int>& column_indices) {
+ // Skips bound checks/pre-buffering, since we've done that already
+ const int64_t batch_size = self->properties().batch_size();
+ return self->DecodeRowGroups(self, {row_group}, column_indices, cpu_executor)
+ .Then([batch_size](const std::shared_ptr<Table>& table)
+ -> ::arrow::Result<RecordBatchGenerator> {
+ ::arrow::TableBatchReader table_reader(*table);
+ table_reader.set_chunksize(batch_size);
+ ::arrow::RecordBatchVector batches;
+ RETURN_NOT_OK(table_reader.ReadAll(&batches));
+ return ::arrow::MakeVectorGenerator(std::move(batches));
+ });
+ }
+
+ std::shared_ptr<FileReaderImpl> arrow_reader_;
+ ::arrow::internal::Executor* cpu_executor_;
+ std::vector<int> row_groups_;
+ std::vector<int> column_indices_;
+ size_t index_;
+};
+
+::arrow::Result<::arrow::AsyncGenerator<std::shared_ptr<::arrow::RecordBatch>>>
+FileReaderImpl::GetRecordBatchGenerator(std::shared_ptr<FileReader> reader,
+ const std::vector<int> row_group_indices,
+ const std::vector<int> column_indices,
+ ::arrow::internal::Executor* cpu_executor,
+ int row_group_readahead) {
+ RETURN_NOT_OK(BoundsCheck(row_group_indices, column_indices));
+ if (reader_properties_.pre_buffer()) {
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ reader_->PreBuffer(row_group_indices, column_indices, reader_properties_.io_context(),
+ reader_properties_.cache_options());
+ END_PARQUET_CATCH_EXCEPTIONS
+ }
+ ::arrow::AsyncGenerator<RowGroupGenerator::RecordBatchGenerator> row_group_generator =
+ RowGroupGenerator(::arrow::internal::checked_pointer_cast<FileReaderImpl>(reader),
+ cpu_executor, row_group_indices, column_indices);
+ if (row_group_readahead > 0) {
+ row_group_generator = ::arrow::MakeReadaheadGenerator(std::move(row_group_generator),
+ row_group_readahead);
+ }
+ return ::arrow::MakeConcatenatedGenerator(std::move(row_group_generator));
+}
+
+Status FileReaderImpl::GetColumn(int i, FileColumnIteratorFactory iterator_factory,
+ std::unique_ptr<ColumnReader>* out) {
+ RETURN_NOT_OK(BoundsCheckColumn(i));
+ auto ctx = std::make_shared<ReaderContext>();
+ ctx->reader = reader_.get();
+ ctx->pool = pool_;
+ ctx->iterator_factory = iterator_factory;
+ ctx->filter_leaves = false;
+ std::unique_ptr<ColumnReaderImpl> result;
+ RETURN_NOT_OK(GetReader(manifest_.schema_fields[i], ctx, &result));
+ out->reset(result.release());
+ return Status::OK();
+}
+
+Status FileReaderImpl::ReadRowGroups(const std::vector<int>& row_groups,
+ const std::vector<int>& column_indices,
+ std::shared_ptr<Table>* out) {
+ RETURN_NOT_OK(BoundsCheck(row_groups, column_indices));
+
+ // PARQUET-1698/PARQUET-1820: pre-buffer row groups/column chunks if enabled
+ if (reader_properties_.pre_buffer()) {
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ parquet_reader()->PreBuffer(row_groups, column_indices,
+ reader_properties_.io_context(),
+ reader_properties_.cache_options());
+ END_PARQUET_CATCH_EXCEPTIONS
+ }
+
+ auto fut = DecodeRowGroups(/*self=*/nullptr, row_groups, column_indices,
+ /*cpu_executor=*/nullptr);
+ ARROW_ASSIGN_OR_RAISE(*out, fut.MoveResult());
+ return Status::OK();
+}
+
+Future<std::shared_ptr<Table>> FileReaderImpl::DecodeRowGroups(
+ std::shared_ptr<FileReaderImpl> self, const std::vector<int>& row_groups,
+ const std::vector<int>& column_indices, ::arrow::internal::Executor* cpu_executor) {
+ // `self` is used solely to keep `this` alive in an async context - but we use this
+ // in a sync context too so use `this` over `self`
+ std::vector<std::shared_ptr<ColumnReaderImpl>> readers;
+ std::shared_ptr<::arrow::Schema> result_schema;
+ RETURN_NOT_OK(GetFieldReaders(column_indices, row_groups, &readers, &result_schema));
+ // OptionalParallelForAsync requires an executor
+ if (!cpu_executor) cpu_executor = ::arrow::internal::GetCpuThreadPool();
+
+ auto read_column = [row_groups, self, this](size_t i,
+ std::shared_ptr<ColumnReaderImpl> reader)
+ -> ::arrow::Result<std::shared_ptr<::arrow::ChunkedArray>> {
+ std::shared_ptr<::arrow::ChunkedArray> column;
+ RETURN_NOT_OK(ReadColumn(static_cast<int>(i), row_groups, reader.get(), &column));
+ return column;
+ };
+ auto make_table = [result_schema, row_groups, self,
+ this](const ::arrow::ChunkedArrayVector& columns)
+ -> ::arrow::Result<std::shared_ptr<Table>> {
+ int64_t num_rows = 0;
+ if (!columns.empty()) {
+ num_rows = columns[0]->length();
+ } else {
+ for (int i : row_groups) {
+ num_rows += parquet_reader()->metadata()->RowGroup(i)->num_rows();
+ }
+ }
+ auto table = Table::Make(std::move(result_schema), columns, num_rows);
+ RETURN_NOT_OK(table->Validate());
+ return table;
+ };
+ return ::arrow::internal::OptionalParallelForAsync(reader_properties_.use_threads(),
+ std::move(readers), read_column,
+ cpu_executor)
+ .Then(std::move(make_table));
+}
+
+std::shared_ptr<RowGroupReader> FileReaderImpl::RowGroup(int row_group_index) {
+ return std::make_shared<RowGroupReaderImpl>(this, row_group_index);
+}
+
+// ----------------------------------------------------------------------
+// Public factory functions
+
+Status FileReader::GetRecordBatchReader(const std::vector<int>& row_group_indices,
+ std::shared_ptr<RecordBatchReader>* out) {
+ std::unique_ptr<RecordBatchReader> tmp;
+ ARROW_RETURN_NOT_OK(GetRecordBatchReader(row_group_indices, &tmp));
+ out->reset(tmp.release());
+ return Status::OK();
+}
+
+Status FileReader::GetRecordBatchReader(const std::vector<int>& row_group_indices,
+ const std::vector<int>& column_indices,
+ std::shared_ptr<RecordBatchReader>* out) {
+ std::unique_ptr<RecordBatchReader> tmp;
+ ARROW_RETURN_NOT_OK(GetRecordBatchReader(row_group_indices, column_indices, &tmp));
+ out->reset(tmp.release());
+ return Status::OK();
+}
+
+Status FileReader::Make(::arrow::MemoryPool* pool,
+ std::unique_ptr<ParquetFileReader> reader,
+ const ArrowReaderProperties& properties,
+ std::unique_ptr<FileReader>* out) {
+ out->reset(new FileReaderImpl(pool, std::move(reader), properties));
+ return static_cast<FileReaderImpl*>(out->get())->Init();
+}
+
+Status FileReader::Make(::arrow::MemoryPool* pool,
+ std::unique_ptr<ParquetFileReader> reader,
+ std::unique_ptr<FileReader>* out) {
+ return Make(pool, std::move(reader), default_arrow_reader_properties(), out);
+}
+
+FileReaderBuilder::FileReaderBuilder()
+ : pool_(::arrow::default_memory_pool()),
+ properties_(default_arrow_reader_properties()) {}
+
+Status FileReaderBuilder::Open(std::shared_ptr<::arrow::io::RandomAccessFile> file,
+ const ReaderProperties& properties,
+ std::shared_ptr<FileMetaData> metadata) {
+ PARQUET_CATCH_NOT_OK(raw_reader_ = ParquetReader::Open(std::move(file), properties,
+ std::move(metadata)));
+ return Status::OK();
+}
+
+FileReaderBuilder* FileReaderBuilder::memory_pool(::arrow::MemoryPool* pool) {
+ pool_ = pool;
+ return this;
+}
+
+FileReaderBuilder* FileReaderBuilder::properties(
+ const ArrowReaderProperties& arg_properties) {
+ properties_ = arg_properties;
+ return this;
+}
+
+Status FileReaderBuilder::Build(std::unique_ptr<FileReader>* out) {
+ return FileReader::Make(pool_, std::move(raw_reader_), properties_, out);
+}
+
+Status OpenFile(std::shared_ptr<::arrow::io::RandomAccessFile> file, MemoryPool* pool,
+ std::unique_ptr<FileReader>* reader) {
+ FileReaderBuilder builder;
+ RETURN_NOT_OK(builder.Open(std::move(file)));
+ return builder.memory_pool(pool)->Build(reader);
+}
+
+namespace internal {
+
+Status FuzzReader(std::unique_ptr<FileReader> reader) {
+ auto st = Status::OK();
+ for (int i = 0; i < reader->num_row_groups(); ++i) {
+ std::shared_ptr<Table> table;
+ auto row_group_status = reader->ReadRowGroup(i, &table);
+ if (row_group_status.ok()) {
+ row_group_status &= table->ValidateFull();
+ }
+ st &= row_group_status;
+ }
+ return st;
+}
+
+Status FuzzReader(const uint8_t* data, int64_t size) {
+ auto buffer = std::make_shared<::arrow::Buffer>(data, size);
+ auto file = std::make_shared<::arrow::io::BufferReader>(buffer);
+ FileReaderBuilder builder;
+ RETURN_NOT_OK(builder.Open(std::move(file)));
+
+ std::unique_ptr<FileReader> reader;
+ RETURN_NOT_OK(builder.Build(&reader));
+ return FuzzReader(std::move(reader));
+}
+
+} // namespace internal
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/reader.h b/src/arrow/cpp/src/parquet/arrow/reader.h
new file mode 100644
index 000000000..85f2d7401
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/reader.h
@@ -0,0 +1,344 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+// N.B. we don't include async_generator.h as it's relatively heavy
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "parquet/file_reader.h"
+#include "parquet/platform.h"
+#include "parquet/properties.h"
+
+namespace arrow {
+
+class ChunkedArray;
+class KeyValueMetadata;
+class RecordBatchReader;
+struct Scalar;
+class Schema;
+class Table;
+class RecordBatch;
+
+} // namespace arrow
+
+namespace parquet {
+
+class FileMetaData;
+class SchemaDescriptor;
+
+namespace arrow {
+
+class ColumnChunkReader;
+class ColumnReader;
+struct SchemaManifest;
+class RowGroupReader;
+
+/// \brief Arrow read adapter class for deserializing Parquet files as Arrow row batches.
+///
+/// This interfaces caters for different use cases and thus provides different
+/// interfaces. In its most simplistic form, we cater for a user that wants to
+/// read the whole Parquet at once with the `FileReader::ReadTable` method.
+///
+/// More advanced users that also want to implement parallelism on top of each
+/// single Parquet files should do this on the RowGroup level. For this, they can
+/// call `FileReader::RowGroup(i)->ReadTable` to receive only the specified
+/// RowGroup as a table.
+///
+/// In the most advanced situation, where a consumer wants to independently read
+/// RowGroups in parallel and consume each column individually, they can call
+/// `FileReader::RowGroup(i)->Column(j)->Read` and receive an `arrow::Column`
+/// instance.
+///
+/// The parquet format supports an optional integer field_id which can be assigned
+/// to a field. Arrow will convert these field IDs to a metadata key named
+/// PARQUET:field_id on the appropriate field.
+// TODO(wesm): nested data does not always make sense with this user
+// interface unless you are only reading a single leaf node from a branch of
+// a table. For example:
+//
+// repeated group data {
+// optional group record {
+// optional int32 val1;
+// optional byte_array val2;
+// optional bool val3;
+// }
+// optional int32 val4;
+// }
+//
+// In the Parquet file, there are 3 leaf nodes:
+//
+// * data.record.val1
+// * data.record.val2
+// * data.record.val3
+// * data.val4
+//
+// When materializing this data in an Arrow array, we would have:
+//
+// data: list<struct<
+// record: struct<
+// val1: int32,
+// val2: string (= list<uint8>),
+// val3: bool,
+// >,
+// val4: int32
+// >>
+//
+// However, in the Parquet format, each leaf node has its own repetition and
+// definition levels describing the structure of the intermediate nodes in
+// this array structure. Thus, we will need to scan the leaf data for a group
+// of leaf nodes part of the same type tree to create a single result Arrow
+// nested array structure.
+//
+// This is additionally complicated "chunky" repeated fields or very large byte
+// arrays
+class PARQUET_EXPORT FileReader {
+ public:
+ /// Factory function to create a FileReader from a ParquetFileReader and properties
+ static ::arrow::Status Make(::arrow::MemoryPool* pool,
+ std::unique_ptr<ParquetFileReader> reader,
+ const ArrowReaderProperties& properties,
+ std::unique_ptr<FileReader>* out);
+
+ /// Factory function to create a FileReader from a ParquetFileReader
+ static ::arrow::Status Make(::arrow::MemoryPool* pool,
+ std::unique_ptr<ParquetFileReader> reader,
+ std::unique_ptr<FileReader>* out);
+
+ // Since the distribution of columns amongst a Parquet file's row groups may
+ // be uneven (the number of values in each column chunk can be different), we
+ // provide a column-oriented read interface. The ColumnReader hides the
+ // details of paging through the file's row groups and yielding
+ // fully-materialized arrow::Array instances
+ //
+ // Returns error status if the column of interest is not flat.
+ virtual ::arrow::Status GetColumn(int i, std::unique_ptr<ColumnReader>* out) = 0;
+
+ /// \brief Return arrow schema for all the columns.
+ virtual ::arrow::Status GetSchema(std::shared_ptr<::arrow::Schema>* out) = 0;
+
+ /// \brief Read column as a whole into a chunked array.
+ ///
+ /// The indicated column index is relative to the schema
+ virtual ::arrow::Status ReadColumn(int i,
+ std::shared_ptr<::arrow::ChunkedArray>* out) = 0;
+
+ // NOTE: Experimental API
+ // Reads a specific top level schema field into an Array
+ // The index i refers the index of the top level schema field, which may
+ // be nested or flat - e.g.
+ //
+ // 0 foo.bar
+ // foo.bar.baz
+ // foo.qux
+ // 1 foo2
+ // 2 foo3
+ //
+ // i=0 will read the entire foo struct, i=1 the foo2 primitive column etc
+ virtual ::arrow::Status ReadSchemaField(
+ int i, std::shared_ptr<::arrow::ChunkedArray>* out) = 0;
+
+ /// \brief Return a RecordBatchReader of row groups selected from row_group_indices.
+ ///
+ /// Note that the ordering in row_group_indices matters. FileReaders must outlive
+ /// their RecordBatchReaders.
+ ///
+ /// \returns error Status if row_group_indices contains an invalid index
+ virtual ::arrow::Status GetRecordBatchReader(
+ const std::vector<int>& row_group_indices,
+ std::unique_ptr<::arrow::RecordBatchReader>* out) = 0;
+
+ ::arrow::Status GetRecordBatchReader(const std::vector<int>& row_group_indices,
+ std::shared_ptr<::arrow::RecordBatchReader>* out);
+
+ /// \brief Return a RecordBatchReader of row groups selected from
+ /// row_group_indices, whose columns are selected by column_indices.
+ ///
+ /// Note that the ordering in row_group_indices and column_indices
+ /// matter. FileReaders must outlive their RecordBatchReaders.
+ ///
+ /// \returns error Status if either row_group_indices or column_indices
+ /// contains an invalid index
+ virtual ::arrow::Status GetRecordBatchReader(
+ const std::vector<int>& row_group_indices, const std::vector<int>& column_indices,
+ std::unique_ptr<::arrow::RecordBatchReader>* out) = 0;
+
+ /// \brief Return a generator of record batches.
+ ///
+ /// The FileReader must outlive the generator, so this requires that you pass in a
+ /// shared_ptr.
+ ///
+ /// \returns error Result if either row_group_indices or column_indices contains an
+ /// invalid index
+ virtual ::arrow::Result<
+ std::function<::arrow::Future<std::shared_ptr<::arrow::RecordBatch>>()>>
+ GetRecordBatchGenerator(std::shared_ptr<FileReader> reader,
+ const std::vector<int> row_group_indices,
+ const std::vector<int> column_indices,
+ ::arrow::internal::Executor* cpu_executor = NULLPTR,
+ int row_group_readahead = 0) = 0;
+
+ ::arrow::Status GetRecordBatchReader(const std::vector<int>& row_group_indices,
+ const std::vector<int>& column_indices,
+ std::shared_ptr<::arrow::RecordBatchReader>* out);
+
+ /// Read all columns into a Table
+ virtual ::arrow::Status ReadTable(std::shared_ptr<::arrow::Table>* out) = 0;
+
+ /// \brief Read the given columns into a Table
+ ///
+ /// The indicated column indices are relative to the schema
+ virtual ::arrow::Status ReadTable(const std::vector<int>& column_indices,
+ std::shared_ptr<::arrow::Table>* out) = 0;
+
+ virtual ::arrow::Status ReadRowGroup(int i, const std::vector<int>& column_indices,
+ std::shared_ptr<::arrow::Table>* out) = 0;
+
+ virtual ::arrow::Status ReadRowGroup(int i, std::shared_ptr<::arrow::Table>* out) = 0;
+
+ virtual ::arrow::Status ReadRowGroups(const std::vector<int>& row_groups,
+ const std::vector<int>& column_indices,
+ std::shared_ptr<::arrow::Table>* out) = 0;
+
+ virtual ::arrow::Status ReadRowGroups(const std::vector<int>& row_groups,
+ std::shared_ptr<::arrow::Table>* out) = 0;
+
+ /// \brief Scan file contents with one thread, return number of rows
+ virtual ::arrow::Status ScanContents(std::vector<int> columns,
+ const int32_t column_batch_size,
+ int64_t* num_rows) = 0;
+
+ /// \brief Return a reader for the RowGroup, this object must not outlive the
+ /// FileReader.
+ virtual std::shared_ptr<RowGroupReader> RowGroup(int row_group_index) = 0;
+
+ /// \brief The number of row groups in the file
+ virtual int num_row_groups() const = 0;
+
+ virtual ParquetFileReader* parquet_reader() const = 0;
+
+ /// Set whether to use multiple threads during reads of multiple columns.
+ /// By default only one thread is used.
+ virtual void set_use_threads(bool use_threads) = 0;
+
+ /// Set number of records to read per batch for the RecordBatchReader.
+ virtual void set_batch_size(int64_t batch_size) = 0;
+
+ virtual const ArrowReaderProperties& properties() const = 0;
+
+ virtual const SchemaManifest& manifest() const = 0;
+
+ virtual ~FileReader() = default;
+};
+
+class RowGroupReader {
+ public:
+ virtual ~RowGroupReader() = default;
+ virtual std::shared_ptr<ColumnChunkReader> Column(int column_index) = 0;
+ virtual ::arrow::Status ReadTable(const std::vector<int>& column_indices,
+ std::shared_ptr<::arrow::Table>* out) = 0;
+ virtual ::arrow::Status ReadTable(std::shared_ptr<::arrow::Table>* out) = 0;
+
+ private:
+ struct Iterator;
+};
+
+class ColumnChunkReader {
+ public:
+ virtual ~ColumnChunkReader() = default;
+ virtual ::arrow::Status Read(std::shared_ptr<::arrow::ChunkedArray>* out) = 0;
+};
+
+// At this point, the column reader is a stream iterator. It only knows how to
+// read the next batch of values for a particular column from the file until it
+// runs out.
+//
+// We also do not expose any internal Parquet details, such as row groups. This
+// might change in the future.
+class PARQUET_EXPORT ColumnReader {
+ public:
+ virtual ~ColumnReader() = default;
+
+ // Scan the next array of the indicated size. The actual size of the
+ // returned array may be less than the passed size depending how much data is
+ // available in the file.
+ //
+ // When all the data in the file has been exhausted, the result is set to
+ // nullptr.
+ //
+ // Returns Status::OK on a successful read, including if you have exhausted
+ // the data available in the file.
+ virtual ::arrow::Status NextBatch(int64_t batch_size,
+ std::shared_ptr<::arrow::ChunkedArray>* out) = 0;
+};
+
+/// \brief Experimental helper class for bindings (like Python) that struggle
+/// either with std::move or C++ exceptions
+class PARQUET_EXPORT FileReaderBuilder {
+ public:
+ FileReaderBuilder();
+
+ /// Create FileReaderBuilder from Arrow file and optional properties / metadata
+ ::arrow::Status Open(std::shared_ptr<::arrow::io::RandomAccessFile> file,
+ const ReaderProperties& properties = default_reader_properties(),
+ std::shared_ptr<FileMetaData> metadata = NULLPTR);
+
+ ParquetFileReader* raw_reader() { return raw_reader_.get(); }
+
+ /// Set Arrow MemoryPool for memory allocation
+ FileReaderBuilder* memory_pool(::arrow::MemoryPool* pool);
+ /// Set Arrow reader properties
+ FileReaderBuilder* properties(const ArrowReaderProperties& arg_properties);
+ /// Build FileReader instance
+ ::arrow::Status Build(std::unique_ptr<FileReader>* out);
+
+ private:
+ ::arrow::MemoryPool* pool_;
+ ArrowReaderProperties properties_;
+ std::unique_ptr<ParquetFileReader> raw_reader_;
+};
+
+/// \defgroup parquet-arrow-reader-factories Factory functions for Parquet Arrow readers
+///
+/// @{
+
+/// \brief Build FileReader from Arrow file and MemoryPool
+///
+/// Advanced settings are supported through the FileReaderBuilder class.
+PARQUET_EXPORT
+::arrow::Status OpenFile(std::shared_ptr<::arrow::io::RandomAccessFile>,
+ ::arrow::MemoryPool* allocator,
+ std::unique_ptr<FileReader>* reader);
+
+/// @}
+
+PARQUET_EXPORT
+::arrow::Status StatisticsAsScalars(const Statistics& Statistics,
+ std::shared_ptr<::arrow::Scalar>* min,
+ std::shared_ptr<::arrow::Scalar>* max);
+
+namespace internal {
+
+PARQUET_EXPORT
+::arrow::Status FuzzReader(const uint8_t* data, int64_t size);
+
+} // namespace internal
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/reader_internal.cc b/src/arrow/cpp/src/parquet/arrow/reader_internal.cc
new file mode 100644
index 000000000..f13687079
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/reader_internal.cc
@@ -0,0 +1,791 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/arrow/reader_internal.h"
+
+#include <algorithm>
+#include <climits>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/compute/api.h"
+#include "arrow/datum.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/base64.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/ubsan.h"
+#include "arrow/visitor_inline.h"
+#include "parquet/arrow/reader.h"
+#include "parquet/arrow/schema.h"
+#include "parquet/arrow/schema_internal.h"
+#include "parquet/column_reader.h"
+#include "parquet/platform.h"
+#include "parquet/properties.h"
+#include "parquet/schema.h"
+#include "parquet/statistics.h"
+#include "parquet/types.h"
+// Required after "arrow/util/int_util_internal.h" (for OPTIONAL)
+#include "parquet/windows_compatibility.h"
+
+using arrow::Array;
+using arrow::BooleanArray;
+using arrow::ChunkedArray;
+using arrow::DataType;
+using arrow::Datum;
+using arrow::Decimal128;
+using arrow::Decimal128Array;
+using arrow::Decimal128Type;
+using arrow::Decimal256;
+using arrow::Decimal256Array;
+using arrow::Decimal256Type;
+using arrow::Field;
+using arrow::Int32Array;
+using arrow::ListArray;
+using arrow::MemoryPool;
+using arrow::ResizableBuffer;
+using arrow::Status;
+using arrow::StructArray;
+using arrow::Table;
+using arrow::TimestampArray;
+
+using ::arrow::BitUtil::FromBigEndian;
+using ::arrow::internal::checked_cast;
+using ::arrow::internal::checked_pointer_cast;
+using ::arrow::internal::SafeLeftShift;
+using ::arrow::util::SafeLoadAs;
+
+using parquet::internal::BinaryRecordReader;
+using parquet::internal::DictionaryRecordReader;
+using parquet::internal::RecordReader;
+using parquet::schema::GroupNode;
+using parquet::schema::Node;
+using parquet::schema::PrimitiveNode;
+using ParquetType = parquet::Type;
+
+namespace BitUtil = arrow::BitUtil;
+
+namespace parquet {
+namespace arrow {
+namespace {
+
+template <typename ArrowType>
+using ArrayType = typename ::arrow::TypeTraits<ArrowType>::ArrayType;
+
+template <typename CType, typename StatisticsType>
+Status MakeMinMaxScalar(const StatisticsType& statistics,
+ std::shared_ptr<::arrow::Scalar>* min,
+ std::shared_ptr<::arrow::Scalar>* max) {
+ *min = ::arrow::MakeScalar(static_cast<CType>(statistics.min()));
+ *max = ::arrow::MakeScalar(static_cast<CType>(statistics.max()));
+ return Status::OK();
+}
+
+template <typename CType, typename StatisticsType>
+Status MakeMinMaxTypedScalar(const StatisticsType& statistics,
+ std::shared_ptr<DataType> type,
+ std::shared_ptr<::arrow::Scalar>* min,
+ std::shared_ptr<::arrow::Scalar>* max) {
+ ARROW_ASSIGN_OR_RAISE(*min, ::arrow::MakeScalar(type, statistics.min()));
+ ARROW_ASSIGN_OR_RAISE(*max, ::arrow::MakeScalar(type, statistics.max()));
+ return Status::OK();
+}
+
+template <typename StatisticsType>
+Status MakeMinMaxIntegralScalar(const StatisticsType& statistics,
+ const ::arrow::DataType& arrow_type,
+ std::shared_ptr<::arrow::Scalar>* min,
+ std::shared_ptr<::arrow::Scalar>* max) {
+ const auto column_desc = statistics.descr();
+ const auto& logical_type = column_desc->logical_type();
+ const auto& integer = checked_pointer_cast<const IntLogicalType>(logical_type);
+ const bool is_signed = integer->is_signed();
+
+ switch (integer->bit_width()) {
+ case 8:
+ return is_signed ? MakeMinMaxScalar<int8_t>(statistics, min, max)
+ : MakeMinMaxScalar<uint8_t>(statistics, min, max);
+ case 16:
+ return is_signed ? MakeMinMaxScalar<int16_t>(statistics, min, max)
+ : MakeMinMaxScalar<uint16_t>(statistics, min, max);
+ case 32:
+ return is_signed ? MakeMinMaxScalar<int32_t>(statistics, min, max)
+ : MakeMinMaxScalar<uint32_t>(statistics, min, max);
+ case 64:
+ return is_signed ? MakeMinMaxScalar<int64_t>(statistics, min, max)
+ : MakeMinMaxScalar<uint64_t>(statistics, min, max);
+ }
+
+ return Status::OK();
+}
+
+static Status FromInt32Statistics(const Int32Statistics& statistics,
+ const LogicalType& logical_type,
+ std::shared_ptr<::arrow::Scalar>* min,
+ std::shared_ptr<::arrow::Scalar>* max) {
+ ARROW_ASSIGN_OR_RAISE(auto type, FromInt32(logical_type));
+
+ switch (logical_type.type()) {
+ case LogicalType::Type::INT:
+ return MakeMinMaxIntegralScalar(statistics, *type, min, max);
+ break;
+ case LogicalType::Type::DATE:
+ case LogicalType::Type::TIME:
+ case LogicalType::Type::NONE:
+ return MakeMinMaxTypedScalar<int32_t>(statistics, type, min, max);
+ break;
+ default:
+ break;
+ }
+
+ return Status::NotImplemented("Cannot extract statistics for type ");
+}
+
+static Status FromInt64Statistics(const Int64Statistics& statistics,
+ const LogicalType& logical_type,
+ std::shared_ptr<::arrow::Scalar>* min,
+ std::shared_ptr<::arrow::Scalar>* max) {
+ ARROW_ASSIGN_OR_RAISE(auto type, FromInt64(logical_type));
+
+ switch (logical_type.type()) {
+ case LogicalType::Type::INT:
+ return MakeMinMaxIntegralScalar(statistics, *type, min, max);
+ break;
+ case LogicalType::Type::TIME:
+ case LogicalType::Type::TIMESTAMP:
+ case LogicalType::Type::NONE:
+ return MakeMinMaxTypedScalar<int64_t>(statistics, type, min, max);
+ break;
+ default:
+ break;
+ }
+
+ return Status::NotImplemented("Cannot extract statistics for type ");
+}
+
+template <typename DecimalType>
+Result<std::shared_ptr<::arrow::Scalar>> FromBigEndianString(
+ const std::string& data, std::shared_ptr<DataType> arrow_type) {
+ ARROW_ASSIGN_OR_RAISE(
+ DecimalType decimal,
+ DecimalType::FromBigEndian(reinterpret_cast<const uint8_t*>(data.data()),
+ static_cast<int32_t>(data.size())));
+ return ::arrow::MakeScalar(std::move(arrow_type), decimal);
+}
+
+// Extracts Min and Max scalar from bytes like types (i.e. types where
+// decimal is encoded as little endian.
+Status ExtractDecimalMinMaxFromBytesType(const Statistics& statistics,
+ const LogicalType& logical_type,
+ std::shared_ptr<::arrow::Scalar>* min,
+ std::shared_ptr<::arrow::Scalar>* max) {
+ const DecimalLogicalType& decimal_type =
+ checked_cast<const DecimalLogicalType&>(logical_type);
+
+ Result<std::shared_ptr<DataType>> maybe_type =
+ Decimal128Type::Make(decimal_type.precision(), decimal_type.scale());
+ std::shared_ptr<DataType> arrow_type;
+ if (maybe_type.ok()) {
+ arrow_type = maybe_type.ValueOrDie();
+ ARROW_ASSIGN_OR_RAISE(
+ *min, FromBigEndianString<Decimal128>(statistics.EncodeMin(), arrow_type));
+ ARROW_ASSIGN_OR_RAISE(*max, FromBigEndianString<Decimal128>(statistics.EncodeMax(),
+ std::move(arrow_type)));
+ return Status::OK();
+ }
+ // Fallback to see if Decimal256 can represent the type.
+ ARROW_ASSIGN_OR_RAISE(
+ arrow_type, Decimal256Type::Make(decimal_type.precision(), decimal_type.scale()));
+ ARROW_ASSIGN_OR_RAISE(
+ *min, FromBigEndianString<Decimal256>(statistics.EncodeMin(), arrow_type));
+ ARROW_ASSIGN_OR_RAISE(*max, FromBigEndianString<Decimal256>(statistics.EncodeMax(),
+ std::move(arrow_type)));
+
+ return Status::OK();
+}
+
+Status ByteArrayStatisticsAsScalars(const Statistics& statistics,
+ std::shared_ptr<::arrow::Scalar>* min,
+ std::shared_ptr<::arrow::Scalar>* max) {
+ auto logical_type = statistics.descr()->logical_type();
+ if (logical_type->type() == LogicalType::Type::DECIMAL) {
+ return ExtractDecimalMinMaxFromBytesType(statistics, *logical_type, min, max);
+ }
+ std::shared_ptr<::arrow::DataType> type;
+ if (statistics.descr()->physical_type() == Type::FIXED_LEN_BYTE_ARRAY) {
+ type = ::arrow::fixed_size_binary(statistics.descr()->type_length());
+ } else {
+ type = logical_type->type() == LogicalType::Type::STRING ? ::arrow::utf8()
+ : ::arrow::binary();
+ }
+ ARROW_ASSIGN_OR_RAISE(
+ *min, ::arrow::MakeScalar(type, Buffer::FromString(statistics.EncodeMin())));
+ ARROW_ASSIGN_OR_RAISE(
+ *max, ::arrow::MakeScalar(type, Buffer::FromString(statistics.EncodeMax())));
+
+ return Status::OK();
+}
+
+} // namespace
+
+Status StatisticsAsScalars(const Statistics& statistics,
+ std::shared_ptr<::arrow::Scalar>* min,
+ std::shared_ptr<::arrow::Scalar>* max) {
+ if (!statistics.HasMinMax()) {
+ return Status::Invalid("Statistics has no min max.");
+ }
+
+ auto column_desc = statistics.descr();
+ if (column_desc == nullptr) {
+ return Status::Invalid("Statistics carries no descriptor, can't infer arrow type.");
+ }
+
+ auto physical_type = column_desc->physical_type();
+ auto logical_type = column_desc->logical_type();
+ switch (physical_type) {
+ case Type::BOOLEAN:
+ return MakeMinMaxScalar<bool, BoolStatistics>(
+ checked_cast<const BoolStatistics&>(statistics), min, max);
+ case Type::FLOAT:
+ return MakeMinMaxScalar<float, FloatStatistics>(
+ checked_cast<const FloatStatistics&>(statistics), min, max);
+ case Type::DOUBLE:
+ return MakeMinMaxScalar<double, DoubleStatistics>(
+ checked_cast<const DoubleStatistics&>(statistics), min, max);
+ case Type::INT32:
+ return FromInt32Statistics(checked_cast<const Int32Statistics&>(statistics),
+ *logical_type, min, max);
+ case Type::INT64:
+ return FromInt64Statistics(checked_cast<const Int64Statistics&>(statistics),
+ *logical_type, min, max);
+ case Type::BYTE_ARRAY:
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return ByteArrayStatisticsAsScalars(statistics, min, max);
+ default:
+ return Status::NotImplemented("Extract statistics unsupported for physical_type ",
+ physical_type, " unsupported.");
+ }
+
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Primitive types
+
+namespace {
+
+template <typename ArrowType, typename ParquetType>
+Status TransferInt(RecordReader* reader, MemoryPool* pool,
+ const std::shared_ptr<DataType>& type, Datum* out) {
+ using ArrowCType = typename ArrowType::c_type;
+ using ParquetCType = typename ParquetType::c_type;
+ int64_t length = reader->values_written();
+ ARROW_ASSIGN_OR_RAISE(auto data,
+ ::arrow::AllocateBuffer(length * sizeof(ArrowCType), pool));
+
+ auto values = reinterpret_cast<const ParquetCType*>(reader->values());
+ auto out_ptr = reinterpret_cast<ArrowCType*>(data->mutable_data());
+ std::copy(values, values + length, out_ptr);
+ *out = std::make_shared<ArrayType<ArrowType>>(
+ type, length, std::move(data), reader->ReleaseIsValid(), reader->null_count());
+ return Status::OK();
+}
+
+std::shared_ptr<Array> TransferZeroCopy(RecordReader* reader,
+ const std::shared_ptr<DataType>& type) {
+ std::vector<std::shared_ptr<Buffer>> buffers = {reader->ReleaseIsValid(),
+ reader->ReleaseValues()};
+ auto data = std::make_shared<::arrow::ArrayData>(type, reader->values_written(),
+ buffers, reader->null_count());
+ return ::arrow::MakeArray(data);
+}
+
+Status TransferBool(RecordReader* reader, MemoryPool* pool, Datum* out) {
+ int64_t length = reader->values_written();
+
+ const int64_t buffer_size = BitUtil::BytesForBits(length);
+ ARROW_ASSIGN_OR_RAISE(auto data, ::arrow::AllocateBuffer(buffer_size, pool));
+
+ // Transfer boolean values to packed bitmap
+ auto values = reinterpret_cast<const bool*>(reader->values());
+ uint8_t* data_ptr = data->mutable_data();
+ memset(data_ptr, 0, buffer_size);
+
+ for (int64_t i = 0; i < length; i++) {
+ if (values[i]) {
+ ::arrow::BitUtil::SetBit(data_ptr, i);
+ }
+ }
+
+ *out = std::make_shared<BooleanArray>(length, std::move(data), reader->ReleaseIsValid(),
+ reader->null_count());
+ return Status::OK();
+}
+
+Status TransferInt96(RecordReader* reader, MemoryPool* pool,
+ const std::shared_ptr<DataType>& type, Datum* out,
+ const ::arrow::TimeUnit::type int96_arrow_time_unit) {
+ int64_t length = reader->values_written();
+ auto values = reinterpret_cast<const Int96*>(reader->values());
+ ARROW_ASSIGN_OR_RAISE(auto data,
+ ::arrow::AllocateBuffer(length * sizeof(int64_t), pool));
+ auto data_ptr = reinterpret_cast<int64_t*>(data->mutable_data());
+ for (int64_t i = 0; i < length; i++) {
+ if (values[i].value[2] == 0) {
+ // Happens for null entries: avoid triggering UBSAN as that Int96 timestamp
+ // isn't representable as a 64-bit Unix timestamp.
+ *data_ptr++ = 0;
+ } else {
+ switch (int96_arrow_time_unit) {
+ case ::arrow::TimeUnit::NANO:
+ *data_ptr++ = Int96GetNanoSeconds(values[i]);
+ break;
+ case ::arrow::TimeUnit::MICRO:
+ *data_ptr++ = Int96GetMicroSeconds(values[i]);
+ break;
+ case ::arrow::TimeUnit::MILLI:
+ *data_ptr++ = Int96GetMilliSeconds(values[i]);
+ break;
+ case ::arrow::TimeUnit::SECOND:
+ *data_ptr++ = Int96GetSeconds(values[i]);
+ break;
+ }
+ }
+ }
+ *out = std::make_shared<TimestampArray>(type, length, std::move(data),
+ reader->ReleaseIsValid(), reader->null_count());
+ return Status::OK();
+}
+
+Status TransferDate64(RecordReader* reader, MemoryPool* pool,
+ const std::shared_ptr<DataType>& type, Datum* out) {
+ int64_t length = reader->values_written();
+ auto values = reinterpret_cast<const int32_t*>(reader->values());
+
+ ARROW_ASSIGN_OR_RAISE(auto data,
+ ::arrow::AllocateBuffer(length * sizeof(int64_t), pool));
+ auto out_ptr = reinterpret_cast<int64_t*>(data->mutable_data());
+
+ for (int64_t i = 0; i < length; i++) {
+ *out_ptr++ = static_cast<int64_t>(values[i]) * kMillisecondsPerDay;
+ }
+
+ *out = std::make_shared<::arrow::Date64Array>(
+ type, length, std::move(data), reader->ReleaseIsValid(), reader->null_count());
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Binary, direct to dictionary-encoded
+
+Status TransferDictionary(RecordReader* reader,
+ const std::shared_ptr<DataType>& logical_value_type,
+ std::shared_ptr<ChunkedArray>* out) {
+ auto dict_reader = dynamic_cast<DictionaryRecordReader*>(reader);
+ DCHECK(dict_reader);
+ *out = dict_reader->GetResult();
+ if (!logical_value_type->Equals(*(*out)->type())) {
+ ARROW_ASSIGN_OR_RAISE(*out, (*out)->View(logical_value_type));
+ }
+ return Status::OK();
+}
+
+Status TransferBinary(RecordReader* reader, MemoryPool* pool,
+ const std::shared_ptr<DataType>& logical_value_type,
+ std::shared_ptr<ChunkedArray>* out) {
+ if (reader->read_dictionary()) {
+ return TransferDictionary(
+ reader, ::arrow::dictionary(::arrow::int32(), logical_value_type), out);
+ }
+ ::arrow::compute::ExecContext ctx(pool);
+ ::arrow::compute::CastOptions cast_options;
+ cast_options.allow_invalid_utf8 = true; // avoid spending time validating UTF8 data
+
+ auto binary_reader = dynamic_cast<BinaryRecordReader*>(reader);
+ DCHECK(binary_reader);
+ auto chunks = binary_reader->GetBuilderChunks();
+ for (auto& chunk : chunks) {
+ if (!chunk->type()->Equals(*logical_value_type)) {
+ // XXX: if a LargeBinary chunk is larger than 2GB, the MSBs of offsets
+ // will be lost because they are first created as int32 and then cast to int64.
+ ARROW_ASSIGN_OR_RAISE(
+ chunk, ::arrow::compute::Cast(*chunk, logical_value_type, cast_options, &ctx));
+ }
+ }
+ *out = std::make_shared<ChunkedArray>(chunks, logical_value_type);
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// INT32 / INT64 / BYTE_ARRAY / FIXED_LEN_BYTE_ARRAY -> Decimal128 || Decimal256
+
+template <typename DecimalType>
+Status RawBytesToDecimalBytes(const uint8_t* value, int32_t byte_width,
+ uint8_t* out_buf) {
+ ARROW_ASSIGN_OR_RAISE(DecimalType t, DecimalType::FromBigEndian(value, byte_width));
+ t.ToBytes(out_buf);
+ return ::arrow::Status::OK();
+}
+
+template <typename DecimalArrayType>
+struct DecimalTypeTrait;
+
+template <>
+struct DecimalTypeTrait<::arrow::Decimal128Array> {
+ using value = ::arrow::Decimal128;
+};
+
+template <>
+struct DecimalTypeTrait<::arrow::Decimal256Array> {
+ using value = ::arrow::Decimal256;
+};
+
+template <typename DecimalArrayType, typename ParquetType>
+struct DecimalConverter {
+ static inline Status ConvertToDecimal(const Array& array,
+ const std::shared_ptr<DataType>&,
+ MemoryPool* pool, std::shared_ptr<Array>*) {
+ return Status::NotImplemented("not implemented");
+ }
+};
+
+template <typename DecimalArrayType>
+struct DecimalConverter<DecimalArrayType, FLBAType> {
+ static inline Status ConvertToDecimal(const Array& array,
+ const std::shared_ptr<DataType>& type,
+ MemoryPool* pool, std::shared_ptr<Array>* out) {
+ const auto& fixed_size_binary_array =
+ checked_cast<const ::arrow::FixedSizeBinaryArray&>(array);
+
+ // The byte width of each decimal value
+ const int32_t type_length =
+ checked_cast<const ::arrow::DecimalType&>(*type).byte_width();
+
+ // number of elements in the entire array
+ const int64_t length = fixed_size_binary_array.length();
+
+ // Get the byte width of the values in the FixedSizeBinaryArray. Most of the time
+ // this will be different from the decimal array width because we write the minimum
+ // number of bytes necessary to represent a given precision
+ const int32_t byte_width =
+ checked_cast<const ::arrow::FixedSizeBinaryType&>(*fixed_size_binary_array.type())
+ .byte_width();
+ // allocate memory for the decimal array
+ ARROW_ASSIGN_OR_RAISE(auto data, ::arrow::AllocateBuffer(length * type_length, pool));
+
+ // raw bytes that we can write to
+ uint8_t* out_ptr = data->mutable_data();
+
+ // convert each FixedSizeBinary value to valid decimal bytes
+ const int64_t null_count = fixed_size_binary_array.null_count();
+
+ using DecimalType = typename DecimalTypeTrait<DecimalArrayType>::value;
+ if (null_count > 0) {
+ for (int64_t i = 0; i < length; ++i, out_ptr += type_length) {
+ if (!fixed_size_binary_array.IsNull(i)) {
+ RETURN_NOT_OK(RawBytesToDecimalBytes<DecimalType>(
+ fixed_size_binary_array.GetValue(i), byte_width, out_ptr));
+ } else {
+ std::memset(out_ptr, 0, type_length);
+ }
+ }
+ } else {
+ for (int64_t i = 0; i < length; ++i, out_ptr += type_length) {
+ RETURN_NOT_OK(RawBytesToDecimalBytes<DecimalType>(
+ fixed_size_binary_array.GetValue(i), byte_width, out_ptr));
+ }
+ }
+
+ *out = std::make_shared<DecimalArrayType>(
+ type, length, std::move(data), fixed_size_binary_array.null_bitmap(), null_count);
+
+ return Status::OK();
+ }
+};
+
+template <typename DecimalArrayType>
+struct DecimalConverter<DecimalArrayType, ByteArrayType> {
+ static inline Status ConvertToDecimal(const Array& array,
+ const std::shared_ptr<DataType>& type,
+ MemoryPool* pool, std::shared_ptr<Array>* out) {
+ const auto& binary_array = checked_cast<const ::arrow::BinaryArray&>(array);
+ const int64_t length = binary_array.length();
+
+ const auto& decimal_type = checked_cast<const ::arrow::DecimalType&>(*type);
+ const int64_t type_length = decimal_type.byte_width();
+
+ ARROW_ASSIGN_OR_RAISE(auto data, ::arrow::AllocateBuffer(length * type_length, pool));
+
+ // raw bytes that we can write to
+ uint8_t* out_ptr = data->mutable_data();
+
+ const int64_t null_count = binary_array.null_count();
+
+ // convert each BinaryArray value to valid decimal bytes
+ for (int64_t i = 0; i < length; i++, out_ptr += type_length) {
+ int32_t record_len = 0;
+ const uint8_t* record_loc = binary_array.GetValue(i, &record_len);
+
+ if (record_len < 0 || record_len > type_length) {
+ return Status::Invalid("Invalid BYTE_ARRAY length for ", type->ToString());
+ }
+
+ auto out_ptr_view = reinterpret_cast<uint64_t*>(out_ptr);
+ out_ptr_view[0] = 0;
+ out_ptr_view[1] = 0;
+
+ // only convert rows that are not null if there are nulls, or
+ // all rows, if there are not
+ if ((null_count > 0 && !binary_array.IsNull(i)) || null_count <= 0) {
+ using DecimalType = typename DecimalTypeTrait<DecimalArrayType>::value;
+ RETURN_NOT_OK(
+ RawBytesToDecimalBytes<DecimalType>(record_loc, record_len, out_ptr));
+ }
+ }
+ *out = std::make_shared<DecimalArrayType>(type, length, std::move(data),
+ binary_array.null_bitmap(), null_count);
+ return Status::OK();
+ }
+};
+
+/// \brief Convert an Int32 or Int64 array into a Decimal128Array
+/// The parquet spec allows systems to write decimals in int32, int64 if the values are
+/// small enough to fit in less 4 bytes or less than 8 bytes, respectively.
+/// This function implements the conversion from int32 and int64 arrays to decimal arrays.
+template <
+ typename ParquetIntegerType,
+ typename = ::arrow::enable_if_t<std::is_same<ParquetIntegerType, Int32Type>::value ||
+ std::is_same<ParquetIntegerType, Int64Type>::value>>
+static Status DecimalIntegerTransfer(RecordReader* reader, MemoryPool* pool,
+ const std::shared_ptr<DataType>& type, Datum* out) {
+ // Decimal128 and Decimal256 are only Arrow constructs. Parquet does not
+ // specifically distinguish between decimal byte widths.
+ // Decimal256 isn't relevant here because the Arrow-Parquet C++ bindings never
+ // write Decimal values as integers and if the decimal value can fit in an
+ // integer it is wasteful to use Decimal256. Put another way, the only
+ // way an integer column could be construed as Decimal256 is if an arrow
+ // schema was stored as metadata in the file indicating the column was
+ // Decimal256. The current Arrow-Parquet C++ bindings will never do this.
+ DCHECK(type->id() == ::arrow::Type::DECIMAL128);
+
+ const int64_t length = reader->values_written();
+
+ using ElementType = typename ParquetIntegerType::c_type;
+ static_assert(std::is_same<ElementType, int32_t>::value ||
+ std::is_same<ElementType, int64_t>::value,
+ "ElementType must be int32_t or int64_t");
+
+ const auto values = reinterpret_cast<const ElementType*>(reader->values());
+
+ const auto& decimal_type = checked_cast<const ::arrow::DecimalType&>(*type);
+ const int64_t type_length = decimal_type.byte_width();
+
+ ARROW_ASSIGN_OR_RAISE(auto data, ::arrow::AllocateBuffer(length * type_length, pool));
+ uint8_t* out_ptr = data->mutable_data();
+
+ using ::arrow::BitUtil::FromLittleEndian;
+
+ for (int64_t i = 0; i < length; ++i, out_ptr += type_length) {
+ // sign/zero extend int32_t values, otherwise a no-op
+ const auto value = static_cast<int64_t>(values[i]);
+
+ ::arrow::Decimal128 decimal(value);
+ decimal.ToBytes(out_ptr);
+ }
+
+ if (reader->nullable_values()) {
+ std::shared_ptr<ResizableBuffer> is_valid = reader->ReleaseIsValid();
+ *out = std::make_shared<Decimal128Array>(type, length, std::move(data), is_valid,
+ reader->null_count());
+ } else {
+ *out = std::make_shared<Decimal128Array>(type, length, std::move(data));
+ }
+ return Status::OK();
+}
+
+/// \brief Convert an arrow::BinaryArray to an arrow::Decimal{128,256}Array
+/// We do this by:
+/// 1. Creating an arrow::BinaryArray from the RecordReader's builder
+/// 2. Allocating a buffer for the arrow::Decimal{128,256}Array
+/// 3. Converting the big-endian bytes in each BinaryArray entry to two integers
+/// representing the high and low bits of each decimal value.
+template <typename DecimalArrayType, typename ParquetType>
+Status TransferDecimal(RecordReader* reader, MemoryPool* pool,
+ const std::shared_ptr<DataType>& type, Datum* out) {
+ auto binary_reader = dynamic_cast<BinaryRecordReader*>(reader);
+ DCHECK(binary_reader);
+ ::arrow::ArrayVector chunks = binary_reader->GetBuilderChunks();
+ for (size_t i = 0; i < chunks.size(); ++i) {
+ std::shared_ptr<Array> chunk_as_decimal;
+ auto fn = &DecimalConverter<DecimalArrayType, ParquetType>::ConvertToDecimal;
+ RETURN_NOT_OK(fn(*chunks[i], type, pool, &chunk_as_decimal));
+ // Replace the chunk, which will hopefully also free memory as we go
+ chunks[i] = chunk_as_decimal;
+ }
+ *out = std::make_shared<ChunkedArray>(chunks, type);
+ return Status::OK();
+}
+
+} // namespace
+
+#define TRANSFER_INT32(ENUM, ArrowType) \
+ case ::arrow::Type::ENUM: { \
+ Status s = TransferInt<ArrowType, Int32Type>(reader, pool, value_type, &result); \
+ RETURN_NOT_OK(s); \
+ } break;
+
+#define TRANSFER_INT64(ENUM, ArrowType) \
+ case ::arrow::Type::ENUM: { \
+ Status s = TransferInt<ArrowType, Int64Type>(reader, pool, value_type, &result); \
+ RETURN_NOT_OK(s); \
+ } break;
+
+Status TransferColumnData(RecordReader* reader, std::shared_ptr<DataType> value_type,
+ const ColumnDescriptor* descr, MemoryPool* pool,
+ std::shared_ptr<ChunkedArray>* out) {
+ Datum result;
+ std::shared_ptr<ChunkedArray> chunked_result;
+ switch (value_type->id()) {
+ case ::arrow::Type::DICTIONARY: {
+ RETURN_NOT_OK(TransferDictionary(reader, value_type, &chunked_result));
+ result = chunked_result;
+ } break;
+ case ::arrow::Type::NA: {
+ result = std::make_shared<::arrow::NullArray>(reader->values_written());
+ break;
+ }
+ case ::arrow::Type::INT32:
+ case ::arrow::Type::INT64:
+ case ::arrow::Type::FLOAT:
+ case ::arrow::Type::DOUBLE:
+ result = TransferZeroCopy(reader, value_type);
+ break;
+ case ::arrow::Type::BOOL:
+ RETURN_NOT_OK(TransferBool(reader, pool, &result));
+ break;
+ TRANSFER_INT32(UINT8, ::arrow::UInt8Type);
+ TRANSFER_INT32(INT8, ::arrow::Int8Type);
+ TRANSFER_INT32(UINT16, ::arrow::UInt16Type);
+ TRANSFER_INT32(INT16, ::arrow::Int16Type);
+ TRANSFER_INT32(UINT32, ::arrow::UInt32Type);
+ TRANSFER_INT64(UINT64, ::arrow::UInt64Type);
+ TRANSFER_INT32(DATE32, ::arrow::Date32Type);
+ TRANSFER_INT32(TIME32, ::arrow::Time32Type);
+ TRANSFER_INT64(TIME64, ::arrow::Time64Type);
+ case ::arrow::Type::DATE64:
+ RETURN_NOT_OK(TransferDate64(reader, pool, value_type, &result));
+ break;
+ case ::arrow::Type::FIXED_SIZE_BINARY:
+ case ::arrow::Type::BINARY:
+ case ::arrow::Type::STRING:
+ case ::arrow::Type::LARGE_BINARY:
+ case ::arrow::Type::LARGE_STRING: {
+ RETURN_NOT_OK(TransferBinary(reader, pool, value_type, &chunked_result));
+ result = chunked_result;
+ } break;
+ case ::arrow::Type::DECIMAL128: {
+ switch (descr->physical_type()) {
+ case ::parquet::Type::INT32: {
+ auto fn = DecimalIntegerTransfer<Int32Type>;
+ RETURN_NOT_OK(fn(reader, pool, value_type, &result));
+ } break;
+ case ::parquet::Type::INT64: {
+ auto fn = &DecimalIntegerTransfer<Int64Type>;
+ RETURN_NOT_OK(fn(reader, pool, value_type, &result));
+ } break;
+ case ::parquet::Type::BYTE_ARRAY: {
+ auto fn = &TransferDecimal<Decimal128Array, ByteArrayType>;
+ RETURN_NOT_OK(fn(reader, pool, value_type, &result));
+ } break;
+ case ::parquet::Type::FIXED_LEN_BYTE_ARRAY: {
+ auto fn = &TransferDecimal<Decimal128Array, FLBAType>;
+ RETURN_NOT_OK(fn(reader, pool, value_type, &result));
+ } break;
+ default:
+ return Status::Invalid(
+ "Physical type for decimal128 must be int32, int64, byte array, or fixed "
+ "length binary");
+ }
+ } break;
+ case ::arrow::Type::DECIMAL256:
+ switch (descr->physical_type()) {
+ case ::parquet::Type::BYTE_ARRAY: {
+ auto fn = &TransferDecimal<Decimal256Array, ByteArrayType>;
+ RETURN_NOT_OK(fn(reader, pool, value_type, &result));
+ } break;
+ case ::parquet::Type::FIXED_LEN_BYTE_ARRAY: {
+ auto fn = &TransferDecimal<Decimal256Array, FLBAType>;
+ RETURN_NOT_OK(fn(reader, pool, value_type, &result));
+ } break;
+ default:
+ return Status::Invalid(
+ "Physical type for decimal256 must be fixed length binary");
+ }
+ break;
+
+ case ::arrow::Type::TIMESTAMP: {
+ const ::arrow::TimestampType& timestamp_type =
+ checked_cast<::arrow::TimestampType&>(*value_type);
+ if (descr->physical_type() == ::parquet::Type::INT96) {
+ RETURN_NOT_OK(
+ TransferInt96(reader, pool, value_type, &result, timestamp_type.unit()));
+ } else {
+ switch (timestamp_type.unit()) {
+ case ::arrow::TimeUnit::MILLI:
+ case ::arrow::TimeUnit::MICRO:
+ case ::arrow::TimeUnit::NANO:
+ result = TransferZeroCopy(reader, value_type);
+ break;
+ default:
+ return Status::NotImplemented("TimeUnit not supported");
+ }
+ }
+ } break;
+ default:
+ return Status::NotImplemented("No support for reading columns of type ",
+ value_type->ToString());
+ }
+
+ if (result.kind() == Datum::ARRAY) {
+ *out = std::make_shared<ChunkedArray>(result.make_array());
+ } else if (result.kind() == Datum::CHUNKED_ARRAY) {
+ *out = result.chunked_array();
+ } else {
+ DCHECK(false) << "Should be impossible, result was " << result.ToString();
+ }
+
+ return Status::OK();
+}
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/reader_internal.h b/src/arrow/cpp/src/parquet/arrow/reader_internal.h
new file mode 100644
index 000000000..ad0b78157
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/reader_internal.h
@@ -0,0 +1,122 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+#include <deque>
+#include <functional>
+#include <memory>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "parquet/arrow/schema.h"
+#include "parquet/column_reader.h"
+#include "parquet/file_reader.h"
+#include "parquet/metadata.h"
+#include "parquet/platform.h"
+#include "parquet/schema.h"
+
+namespace arrow {
+
+class Array;
+class ChunkedArray;
+class DataType;
+class Field;
+class KeyValueMetadata;
+class Schema;
+
+} // namespace arrow
+
+using arrow::Status;
+
+namespace parquet {
+
+class ArrowReaderProperties;
+
+namespace arrow {
+
+class ColumnReaderImpl;
+
+// ----------------------------------------------------------------------
+// Iteration utilities
+
+// Abstraction to decouple row group iteration details from the ColumnReader,
+// so we can read only a single row group if we want
+class FileColumnIterator {
+ public:
+ explicit FileColumnIterator(int column_index, ParquetFileReader* reader,
+ std::vector<int> row_groups)
+ : column_index_(column_index),
+ reader_(reader),
+ schema_(reader->metadata()->schema()),
+ row_groups_(row_groups.begin(), row_groups.end()) {}
+
+ virtual ~FileColumnIterator() {}
+
+ std::unique_ptr<::parquet::PageReader> NextChunk() {
+ if (row_groups_.empty()) {
+ return nullptr;
+ }
+
+ auto row_group_reader = reader_->RowGroup(row_groups_.front());
+ row_groups_.pop_front();
+ return row_group_reader->GetColumnPageReader(column_index_);
+ }
+
+ const SchemaDescriptor* schema() const { return schema_; }
+
+ const ColumnDescriptor* descr() const { return schema_->Column(column_index_); }
+
+ std::shared_ptr<FileMetaData> metadata() const { return reader_->metadata(); }
+
+ int column_index() const { return column_index_; }
+
+ protected:
+ int column_index_;
+ ParquetFileReader* reader_;
+ const SchemaDescriptor* schema_;
+ std::deque<int> row_groups_;
+};
+
+using FileColumnIteratorFactory =
+ std::function<FileColumnIterator*(int, ParquetFileReader*)>;
+
+Status TransferColumnData(::parquet::internal::RecordReader* reader,
+ std::shared_ptr<::arrow::DataType> value_type,
+ const ColumnDescriptor* descr, ::arrow::MemoryPool* pool,
+ std::shared_ptr<::arrow::ChunkedArray>* out);
+
+struct ReaderContext {
+ ParquetFileReader* reader;
+ ::arrow::MemoryPool* pool;
+ FileColumnIteratorFactory iterator_factory;
+ bool filter_leaves;
+ std::shared_ptr<std::unordered_set<int>> included_leaves;
+
+ bool IncludesLeaf(int leaf_index) const {
+ if (this->filter_leaves) {
+ return this->included_leaves->find(leaf_index) != this->included_leaves->end();
+ }
+ return true;
+ }
+};
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/reader_writer_benchmark.cc b/src/arrow/cpp/src/parquet/arrow/reader_writer_benchmark.cc
new file mode 100644
index 000000000..6445bb027
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/reader_writer_benchmark.cc
@@ -0,0 +1,585 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include <array>
+#include <iostream>
+#include <random>
+
+#include "parquet/arrow/reader.h"
+#include "parquet/arrow/writer.h"
+#include "parquet/column_reader.h"
+#include "parquet/column_writer.h"
+#include "parquet/file_reader.h"
+#include "parquet/file_writer.h"
+#include "parquet/platform.h"
+
+#include "arrow/array.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/io/memory.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/logging.h"
+
+using arrow::Array;
+using arrow::ArrayVector;
+using arrow::BooleanBuilder;
+using arrow::FieldVector;
+using arrow::NumericBuilder;
+
+#define EXIT_NOT_OK(s) \
+ do { \
+ ::arrow::Status _s = (s); \
+ if (ARROW_PREDICT_FALSE(!_s.ok())) { \
+ std::cout << "Exiting: " << _s.ToString() << std::endl; \
+ exit(EXIT_FAILURE); \
+ } \
+ } while (0)
+
+namespace parquet {
+
+using arrow::FileReader;
+using arrow::WriteTable;
+using schema::PrimitiveNode;
+
+namespace benchmark {
+
+// This should result in multiple pages for most primitive types
+constexpr int64_t BENCHMARK_SIZE = 10 * 1024 * 1024;
+
+template <typename ParquetType>
+struct benchmark_traits {};
+
+template <>
+struct benchmark_traits<Int32Type> {
+ using arrow_type = ::arrow::Int32Type;
+};
+
+template <>
+struct benchmark_traits<Int64Type> {
+ using arrow_type = ::arrow::Int64Type;
+};
+
+template <>
+struct benchmark_traits<DoubleType> {
+ using arrow_type = ::arrow::DoubleType;
+};
+
+template <>
+struct benchmark_traits<BooleanType> {
+ using arrow_type = ::arrow::BooleanType;
+};
+
+template <typename ParquetType>
+using ArrowType = typename benchmark_traits<ParquetType>::arrow_type;
+
+template <typename ParquetType>
+std::shared_ptr<ColumnDescriptor> MakeSchema(Repetition::type repetition) {
+ auto node = PrimitiveNode::Make("int64", repetition, ParquetType::type_num);
+ return std::make_shared<ColumnDescriptor>(node, repetition != Repetition::REQUIRED,
+ repetition == Repetition::REPEATED);
+}
+
+template <bool nullable, typename ParquetType>
+void SetBytesProcessed(::benchmark::State& state, int64_t num_values = BENCHMARK_SIZE) {
+ const int64_t items_processed = state.iterations() * num_values;
+ const int64_t bytes_processed = items_processed * sizeof(typename ParquetType::c_type);
+
+ state.SetItemsProcessed(bytes_processed);
+ state.SetBytesProcessed(bytes_processed);
+}
+
+constexpr int64_t kAlternatingOrNa = -1;
+
+template <typename T>
+std::vector<T> RandomVector(int64_t true_percentage, int64_t vector_size,
+ const std::array<T, 2>& sample_values, int seed = 500) {
+ std::vector<T> values(vector_size, {});
+ if (true_percentage == kAlternatingOrNa) {
+ int n = {0};
+ std::generate(values.begin(), values.end(), [&n] { return n++ % 2; });
+ } else {
+ std::default_random_engine rng(seed);
+ double true_probability = static_cast<double>(true_percentage) / 100.0;
+ std::bernoulli_distribution dist(true_probability);
+ std::generate(values.begin(), values.end(), [&] { return sample_values[dist(rng)]; });
+ }
+ return values;
+}
+
+template <typename ParquetType>
+std::shared_ptr<::arrow::Table> TableFromVector(
+ const std::vector<typename ParquetType::c_type>& vec, bool nullable,
+ int64_t null_percentage = kAlternatingOrNa) {
+ if (!nullable) {
+ ARROW_CHECK_EQ(null_percentage, kAlternatingOrNa);
+ }
+ std::shared_ptr<::arrow::DataType> type = std::make_shared<ArrowType<ParquetType>>();
+ NumericBuilder<ArrowType<ParquetType>> builder;
+ if (nullable) {
+ // Note true values select index 1 of sample_values
+ auto valid_bytes = RandomVector<uint8_t>(/*true_percentage=*/null_percentage,
+ vec.size(), /*sample_values=*/{1, 0});
+ EXIT_NOT_OK(builder.AppendValues(vec.data(), vec.size(), valid_bytes.data()));
+ } else {
+ EXIT_NOT_OK(builder.AppendValues(vec.data(), vec.size(), nullptr));
+ }
+ std::shared_ptr<::arrow::Array> array;
+ EXIT_NOT_OK(builder.Finish(&array));
+
+ auto field = ::arrow::field("column", type, nullable);
+ auto schema = ::arrow::schema({field});
+ return ::arrow::Table::Make(schema, {array});
+}
+
+template <>
+std::shared_ptr<::arrow::Table> TableFromVector<BooleanType>(const std::vector<bool>& vec,
+ bool nullable,
+ int64_t null_percentage) {
+ BooleanBuilder builder;
+ if (nullable) {
+ auto valid_bytes = RandomVector<bool>(/*true_percentage=*/null_percentage, vec.size(),
+ {true, false});
+ EXIT_NOT_OK(builder.AppendValues(vec, valid_bytes));
+ } else {
+ EXIT_NOT_OK(builder.AppendValues(vec));
+ }
+ std::shared_ptr<::arrow::Array> array;
+ EXIT_NOT_OK(builder.Finish(&array));
+
+ auto field = ::arrow::field("column", ::arrow::boolean(), nullable);
+ auto schema = std::make_shared<::arrow::Schema>(
+ std::vector<std::shared_ptr<::arrow::Field>>({field}));
+ return ::arrow::Table::Make(schema, {array});
+}
+
+template <bool nullable, typename ParquetType>
+static void BM_WriteColumn(::benchmark::State& state) {
+ using T = typename ParquetType::c_type;
+ std::vector<T> values(BENCHMARK_SIZE, static_cast<T>(128));
+ std::shared_ptr<::arrow::Table> table = TableFromVector<ParquetType>(values, nullable);
+
+ while (state.KeepRunning()) {
+ auto output = CreateOutputStream();
+ EXIT_NOT_OK(
+ WriteTable(*table, ::arrow::default_memory_pool(), output, BENCHMARK_SIZE));
+ }
+ SetBytesProcessed<nullable, ParquetType>(state);
+}
+
+BENCHMARK_TEMPLATE2(BM_WriteColumn, false, Int32Type);
+BENCHMARK_TEMPLATE2(BM_WriteColumn, true, Int32Type);
+
+BENCHMARK_TEMPLATE2(BM_WriteColumn, false, Int64Type);
+BENCHMARK_TEMPLATE2(BM_WriteColumn, true, Int64Type);
+
+BENCHMARK_TEMPLATE2(BM_WriteColumn, false, DoubleType);
+BENCHMARK_TEMPLATE2(BM_WriteColumn, true, DoubleType);
+
+BENCHMARK_TEMPLATE2(BM_WriteColumn, false, BooleanType);
+BENCHMARK_TEMPLATE2(BM_WriteColumn, true, BooleanType);
+
+template <typename T>
+struct Examples {
+ static constexpr std::array<T, 2> values() { return {127, 128}; }
+};
+
+template <>
+struct Examples<bool> {
+ static constexpr std::array<bool, 2> values() { return {false, true}; }
+};
+
+static void BenchmarkReadTable(::benchmark::State& state, const ::arrow::Table& table,
+ int64_t num_values = -1, int64_t bytes_per_value = -1) {
+ auto output = CreateOutputStream();
+ EXIT_NOT_OK(
+ WriteTable(table, ::arrow::default_memory_pool(), output, table.num_rows()));
+ PARQUET_ASSIGN_OR_THROW(auto buffer, output->Finish());
+
+ while (state.KeepRunning()) {
+ auto reader =
+ ParquetFileReader::Open(std::make_shared<::arrow::io::BufferReader>(buffer));
+ std::unique_ptr<FileReader> arrow_reader;
+ EXIT_NOT_OK(FileReader::Make(::arrow::default_memory_pool(), std::move(reader),
+ &arrow_reader));
+ std::shared_ptr<::arrow::Table> table;
+ EXIT_NOT_OK(arrow_reader->ReadTable(&table));
+ }
+
+ if (num_values == -1) {
+ num_values = table.num_rows();
+ }
+ state.SetItemsProcessed(num_values * state.iterations());
+ if (bytes_per_value != -1) {
+ state.SetBytesProcessed(num_values * state.iterations() * bytes_per_value);
+ }
+}
+
+static void BenchmarkReadArray(::benchmark::State& state,
+ const std::shared_ptr<Array>& array, bool nullable,
+ int64_t num_values = -1, int64_t bytes_per_value = -1) {
+ auto schema = ::arrow::schema({field("s", array->type(), nullable)});
+ auto table = ::arrow::Table::Make(schema, {array}, array->length());
+
+ EXIT_NOT_OK(table->Validate());
+
+ BenchmarkReadTable(state, *table, num_values, bytes_per_value);
+}
+
+//
+// Benchmark reading a primitive column
+//
+
+template <bool nullable, typename ParquetType>
+static void BM_ReadColumn(::benchmark::State& state) {
+ using T = typename ParquetType::c_type;
+
+ auto values = RandomVector<T>(/*percentage=*/state.range(1), BENCHMARK_SIZE,
+ Examples<T>::values());
+
+ std::shared_ptr<::arrow::Table> table =
+ TableFromVector<ParquetType>(values, nullable, state.range(0));
+
+ BenchmarkReadTable(state, *table, table->num_rows(),
+ sizeof(typename ParquetType::c_type));
+}
+
+// There are two parameters here that cover different data distributions.
+// null_percentage governs distribution and therefore runs of null values.
+// first_value_percentage governs distribution of values (we select from 1 of 2)
+// so when 0 or 100 RLE is triggered all the time. When a value in the range (0, 100)
+// there will be some percentage of RLE encoded values and some percentage of literal
+// encoded values (RLE is much less likely with percentages close to 50).
+BENCHMARK_TEMPLATE2(BM_ReadColumn, false, Int32Type)
+ ->Args({/*null_percentage=*/kAlternatingOrNa, 1})
+ ->Args({/*null_percentage=*/kAlternatingOrNa, 10})
+ ->Args({/*null_percentage=*/kAlternatingOrNa, 50});
+
+BENCHMARK_TEMPLATE2(BM_ReadColumn, true, Int32Type)
+ ->Args({/*null_percentage=*/kAlternatingOrNa, /*first_value_percentage=*/0})
+ ->Args({/*null_percentage=*/1, /*first_value_percentage=*/1})
+ ->Args({/*null_percentage=*/10, /*first_value_percentage=*/10})
+ ->Args({/*null_percentage=*/25, /*first_value_percentage=*/5})
+ ->Args({/*null_percentage=*/50, /*first_value_percentage=*/50})
+ ->Args({/*null_percentage=*/50, /*first_value_percentage=*/0})
+ ->Args({/*null_percentage=*/99, /*first_value_percentage=*/50})
+ ->Args({/*null_percentage=*/99, /*first_value_percentage=*/0});
+
+BENCHMARK_TEMPLATE2(BM_ReadColumn, false, Int64Type)
+ ->Args({/*null_percentage=*/kAlternatingOrNa, 1})
+ ->Args({/*null_percentage=*/kAlternatingOrNa, 10})
+ ->Args({/*null_percentage=*/kAlternatingOrNa, 50});
+BENCHMARK_TEMPLATE2(BM_ReadColumn, true, Int64Type)
+ ->Args({/*null_percentage=*/kAlternatingOrNa, /*first_value_percentage=*/0})
+ ->Args({/*null_percentage=*/1, /*first_value_percentage=*/1})
+ ->Args({/*null_percentage=*/5, /*first_value_percentage=*/5})
+ ->Args({/*null_percentage=*/10, /*first_value_percentage=*/5})
+ ->Args({/*null_percentage=*/25, /*first_value_percentage=*/10})
+ ->Args({/*null_percentage=*/30, /*first_value_percentage=*/10})
+ ->Args({/*null_percentage=*/35, /*first_value_percentage=*/10})
+ ->Args({/*null_percentage=*/45, /*first_value_percentage=*/25})
+ ->Args({/*null_percentage=*/50, /*first_value_percentage=*/50})
+ ->Args({/*null_percentage=*/50, /*first_value_percentage=*/1})
+ ->Args({/*null_percentage=*/75, /*first_value_percentage=*/1})
+ ->Args({/*null_percentage=*/99, /*first_value_percentage=*/50})
+ ->Args({/*null_percentage=*/99, /*first_value_percentage=*/0});
+
+BENCHMARK_TEMPLATE2(BM_ReadColumn, false, DoubleType)
+ ->Args({kAlternatingOrNa, 0})
+ ->Args({kAlternatingOrNa, 20});
+// Less coverage because int64_t should be pretty good representation for nullability and
+// repeating values.
+BENCHMARK_TEMPLATE2(BM_ReadColumn, true, DoubleType)
+ ->Args({/*null_percentage=*/kAlternatingOrNa, /*first_value_percentage=*/0})
+ ->Args({/*null_percentage=*/10, /*first_value_percentage=*/50})
+ ->Args({/*null_percentage=*/25, /*first_value_percentage=*/25});
+
+BENCHMARK_TEMPLATE2(BM_ReadColumn, false, BooleanType)
+ ->Args({kAlternatingOrNa, 0})
+ ->Args({1, 20});
+BENCHMARK_TEMPLATE2(BM_ReadColumn, true, BooleanType)
+ ->Args({kAlternatingOrNa, 1})
+ ->Args({5, 10});
+
+//
+// Benchmark reading a nested column
+//
+
+const std::vector<int64_t> kNestedNullPercents = {0, 1, 50, 99};
+
+// XXX We can use ArgsProduct() starting from Benchmark 1.5.2
+static void NestedReadArguments(::benchmark::internal::Benchmark* b) {
+ for (const auto null_percentage : kNestedNullPercents) {
+ b->Arg(null_percentage);
+ }
+}
+
+static std::shared_ptr<Array> MakeStructArray(::arrow::random::RandomArrayGenerator* rng,
+ const ArrayVector& children,
+ double null_probability,
+ bool propagate_validity = false) {
+ ARROW_CHECK_GT(children.size(), 0);
+ const int64_t length = children[0]->length();
+
+ std::shared_ptr<::arrow::Buffer> null_bitmap;
+ if (null_probability > 0.0) {
+ null_bitmap = rng->NullBitmap(length, null_probability);
+ if (propagate_validity) {
+ // HACK: the Parquet writer currently doesn't allow non-empty list
+ // entries where a parent node is null (for instance, a struct-of-list
+ // where the outer struct is marked null but the inner list value is
+ // non-empty).
+ for (const auto& child : children) {
+ null_bitmap = *::arrow::internal::BitmapOr(
+ ::arrow::default_memory_pool(), null_bitmap->data(), 0,
+ child->null_bitmap_data(), 0, length, 0);
+ }
+ }
+ }
+ FieldVector fields(children.size());
+ char field_name = 'a';
+ for (size_t i = 0; i < children.size(); ++i) {
+ fields[i] = field(std::string{field_name++}, children[i]->type(),
+ /*nullable=*/null_probability > 0.0);
+ }
+ return *::arrow::StructArray::Make(children, std::move(fields), null_bitmap);
+}
+
+// Make a (int32, int64) struct array
+static std::shared_ptr<Array> MakeStructArray(::arrow::random::RandomArrayGenerator* rng,
+ int64_t size, double null_probability) {
+ auto values1 = rng->Int32(size, -5, 5, null_probability);
+ auto values2 = rng->Int64(size, -12345678912345LL, 12345678912345LL, null_probability);
+ return MakeStructArray(rng, {values1, values2}, null_probability);
+}
+
+static void BM_ReadStructColumn(::benchmark::State& state) {
+ constexpr int64_t kNumValues = BENCHMARK_SIZE / 10;
+ const double null_probability = static_cast<double>(state.range(0)) / 100.0;
+ const bool nullable = (null_probability != 0.0);
+
+ ARROW_CHECK_GE(null_probability, 0.0);
+
+ const int64_t kBytesPerValue = sizeof(int32_t) + sizeof(int64_t);
+
+ ::arrow::random::RandomArrayGenerator rng(42);
+ auto array = MakeStructArray(&rng, kNumValues, null_probability);
+
+ BenchmarkReadArray(state, array, nullable, kNumValues, kBytesPerValue);
+}
+
+BENCHMARK(BM_ReadStructColumn)->Apply(NestedReadArguments);
+
+static void BM_ReadStructOfStructColumn(::benchmark::State& state) {
+ constexpr int64_t kNumValues = BENCHMARK_SIZE / 10;
+ const double null_probability = static_cast<double>(state.range(0)) / 100.0;
+ const bool nullable = (null_probability != 0.0);
+
+ ARROW_CHECK_GE(null_probability, 0.0);
+
+ const int64_t kBytesPerValue = 2 * (sizeof(int32_t) + sizeof(int64_t));
+
+ ::arrow::random::RandomArrayGenerator rng(42);
+ auto values1 = MakeStructArray(&rng, kNumValues, null_probability);
+ auto values2 = MakeStructArray(&rng, kNumValues, null_probability);
+ auto array = MakeStructArray(&rng, {values1, values2}, null_probability);
+
+ BenchmarkReadArray(state, array, nullable, kNumValues, kBytesPerValue);
+}
+
+BENCHMARK(BM_ReadStructOfStructColumn)->Apply(NestedReadArguments);
+
+static void BM_ReadStructOfListColumn(::benchmark::State& state) {
+ constexpr int64_t kNumValues = BENCHMARK_SIZE / 10;
+ const double null_probability = static_cast<double>(state.range(0)) / 100.0;
+ const bool nullable = (null_probability != 0.0);
+
+ ARROW_CHECK_GE(null_probability, 0.0);
+
+ ::arrow::random::RandomArrayGenerator rng(42);
+
+ const int64_t kBytesPerValue = sizeof(int32_t) + sizeof(int64_t);
+
+ auto values1 = rng.Int32(kNumValues, -5, 5, null_probability);
+ auto values2 =
+ rng.Int64(kNumValues, -12345678912345LL, 12345678912345LL, null_probability);
+ auto list1 = rng.List(*values1, kNumValues / 10, null_probability);
+ auto list2 = rng.List(*values2, kNumValues / 10, null_probability);
+ auto array = MakeStructArray(&rng, {list1, list2}, null_probability,
+ /*propagate_validity =*/true);
+
+ BenchmarkReadArray(state, array, nullable, kNumValues, kBytesPerValue);
+}
+
+BENCHMARK(BM_ReadStructOfListColumn)->Apply(NestedReadArguments);
+
+static void BM_ReadListColumn(::benchmark::State& state) {
+ constexpr int64_t kNumValues = BENCHMARK_SIZE / 10;
+ const double null_probability = static_cast<double>(state.range(0)) / 100.0;
+ const bool nullable = (null_probability != 0.0);
+
+ ARROW_CHECK_GE(null_probability, 0.0);
+
+ ::arrow::random::RandomArrayGenerator rng(42);
+
+ auto values = rng.Int64(kNumValues, /*min=*/-5, /*max=*/5, null_probability);
+ const int64_t kBytesPerValue = sizeof(int64_t);
+
+ auto array = rng.List(*values, kNumValues / 10, null_probability);
+
+ BenchmarkReadArray(state, array, nullable, kNumValues, kBytesPerValue);
+}
+
+BENCHMARK(BM_ReadListColumn)->Apply(NestedReadArguments);
+
+static void BM_ReadListOfStructColumn(::benchmark::State& state) {
+ constexpr int64_t kNumValues = BENCHMARK_SIZE / 10;
+ const double null_probability = static_cast<double>(state.range(0)) / 100.0;
+ const bool nullable = (null_probability != 0.0);
+
+ ARROW_CHECK_GE(null_probability, 0.0);
+
+ ::arrow::random::RandomArrayGenerator rng(42);
+
+ auto values = MakeStructArray(&rng, kNumValues, null_probability);
+ const int64_t kBytesPerValue = sizeof(int32_t) + sizeof(int64_t);
+
+ auto array = rng.List(*values, kNumValues / 10, null_probability);
+
+ BenchmarkReadArray(state, array, nullable, kNumValues, kBytesPerValue);
+}
+
+BENCHMARK(BM_ReadListOfStructColumn)->Apply(NestedReadArguments);
+
+static void BM_ReadListOfListColumn(::benchmark::State& state) {
+ constexpr int64_t kNumValues = BENCHMARK_SIZE / 10;
+ const double null_probability = static_cast<double>(state.range(0)) / 100.0;
+ const bool nullable = (null_probability != 0.0);
+
+ ARROW_CHECK_GE(null_probability, 0.0);
+
+ ::arrow::random::RandomArrayGenerator rng(42);
+
+ auto values = rng.Int64(kNumValues, /*min=*/-5, /*max=*/5, null_probability);
+ const int64_t kBytesPerValue = sizeof(int64_t);
+
+ auto inner = rng.List(*values, kNumValues / 10, null_probability);
+ auto array = rng.List(*inner, kNumValues / 100, null_probability);
+
+ BenchmarkReadArray(state, array, nullable, kNumValues, kBytesPerValue);
+}
+
+BENCHMARK(BM_ReadListOfListColumn)->Apply(NestedReadArguments);
+
+//
+// Benchmark different ways of reading select row groups
+//
+
+static void BM_ReadIndividualRowGroups(::benchmark::State& state) {
+ std::vector<int64_t> values(BENCHMARK_SIZE, 128);
+ std::shared_ptr<::arrow::Table> table = TableFromVector<Int64Type>(values, true);
+ auto output = CreateOutputStream();
+ // This writes 10 RowGroups
+ EXIT_NOT_OK(
+ WriteTable(*table, ::arrow::default_memory_pool(), output, BENCHMARK_SIZE / 10));
+
+ PARQUET_ASSIGN_OR_THROW(auto buffer, output->Finish());
+
+ while (state.KeepRunning()) {
+ auto reader =
+ ParquetFileReader::Open(std::make_shared<::arrow::io::BufferReader>(buffer));
+ std::unique_ptr<FileReader> arrow_reader;
+ EXIT_NOT_OK(FileReader::Make(::arrow::default_memory_pool(), std::move(reader),
+ &arrow_reader));
+
+ std::vector<std::shared_ptr<::arrow::Table>> tables;
+ for (int i = 0; i < arrow_reader->num_row_groups(); i++) {
+ // Only read the even numbered RowGroups
+ if ((i % 2) == 0) {
+ std::shared_ptr<::arrow::Table> table;
+ EXIT_NOT_OK(arrow_reader->RowGroup(i)->ReadTable(&table));
+ tables.push_back(table);
+ }
+ }
+
+ std::shared_ptr<::arrow::Table> final_table;
+ PARQUET_ASSIGN_OR_THROW(final_table, ConcatenateTables(tables));
+ }
+ SetBytesProcessed<true, Int64Type>(state);
+}
+
+BENCHMARK(BM_ReadIndividualRowGroups);
+
+static void BM_ReadMultipleRowGroups(::benchmark::State& state) {
+ std::vector<int64_t> values(BENCHMARK_SIZE, 128);
+ std::shared_ptr<::arrow::Table> table = TableFromVector<Int64Type>(values, true);
+ auto output = CreateOutputStream();
+ // This writes 10 RowGroups
+ EXIT_NOT_OK(
+ WriteTable(*table, ::arrow::default_memory_pool(), output, BENCHMARK_SIZE / 10));
+ PARQUET_ASSIGN_OR_THROW(auto buffer, output->Finish());
+ std::vector<int> rgs{0, 2, 4, 6, 8};
+
+ while (state.KeepRunning()) {
+ auto reader =
+ ParquetFileReader::Open(std::make_shared<::arrow::io::BufferReader>(buffer));
+ std::unique_ptr<FileReader> arrow_reader;
+ EXIT_NOT_OK(FileReader::Make(::arrow::default_memory_pool(), std::move(reader),
+ &arrow_reader));
+ std::shared_ptr<::arrow::Table> table;
+ EXIT_NOT_OK(arrow_reader->ReadRowGroups(rgs, &table));
+ }
+ SetBytesProcessed<true, Int64Type>(state);
+}
+
+BENCHMARK(BM_ReadMultipleRowGroups);
+
+static void BM_ReadMultipleRowGroupsGenerator(::benchmark::State& state) {
+ std::vector<int64_t> values(BENCHMARK_SIZE, 128);
+ std::shared_ptr<::arrow::Table> table = TableFromVector<Int64Type>(values, true);
+ auto output = CreateOutputStream();
+ // This writes 10 RowGroups
+ EXIT_NOT_OK(
+ WriteTable(*table, ::arrow::default_memory_pool(), output, BENCHMARK_SIZE / 10));
+ PARQUET_ASSIGN_OR_THROW(auto buffer, output->Finish());
+ std::vector<int> rgs{0, 2, 4, 6, 8};
+
+ while (state.KeepRunning()) {
+ auto reader =
+ ParquetFileReader::Open(std::make_shared<::arrow::io::BufferReader>(buffer));
+ std::unique_ptr<FileReader> unique_reader;
+ EXIT_NOT_OK(FileReader::Make(::arrow::default_memory_pool(), std::move(reader),
+ &unique_reader));
+ std::shared_ptr<FileReader> arrow_reader = std::move(unique_reader);
+ ASSIGN_OR_ABORT(auto generator,
+ arrow_reader->GetRecordBatchGenerator(arrow_reader, rgs, {0}));
+ auto fut = ::arrow::CollectAsyncGenerator(generator);
+ ASSIGN_OR_ABORT(auto batches, fut.result());
+ ASSIGN_OR_ABORT(auto actual, ::arrow::Table::FromRecordBatches(std::move(batches)));
+ }
+ SetBytesProcessed<true, Int64Type>(state);
+}
+
+BENCHMARK(BM_ReadMultipleRowGroupsGenerator);
+
+} // namespace benchmark
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/reconstruct_internal_test.cc b/src/arrow/cpp/src/parquet/arrow/reconstruct_internal_test.cc
new file mode 100644
index 000000000..495b69f9e
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/reconstruct_internal_test.cc
@@ -0,0 +1,1639 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/arrow/path_internal.h"
+
+#include <algorithm>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/array/concatenate.h"
+#include "arrow/chunked_array.h"
+#include "arrow/io/memory.h"
+#include "arrow/result.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+#include "parquet/arrow/reader.h"
+#include "parquet/arrow/schema.h"
+#include "parquet/column_writer.h"
+#include "parquet/file_writer.h"
+#include "parquet/properties.h"
+
+using arrow::Array;
+using arrow::ArrayFromJSON;
+using arrow::AssertArraysEqual;
+using arrow::ChunkedArray;
+using arrow::DataType;
+using arrow::field;
+using arrow::int32;
+using arrow::int64;
+using arrow::list;
+using arrow::MemoryPool;
+using arrow::Result;
+using arrow::Status;
+using arrow::struct_;
+using arrow::internal::checked_cast;
+using arrow::internal::checked_pointer_cast;
+using arrow::io::BufferOutputStream;
+using arrow::io::BufferReader;
+
+using testing::ElementsAre;
+using testing::ElementsAreArray;
+using testing::Eq;
+using testing::NotNull;
+using testing::SizeIs;
+
+namespace parquet {
+namespace arrow {
+
+using parquet::schema::GroupNode;
+using parquet::schema::NodePtr;
+using parquet::schema::PrimitiveNode;
+
+using ParquetType = parquet::Type::type;
+template <ParquetType T>
+using ParquetTraits = parquet::type_traits<T>;
+
+using LevelVector = std::vector<int16_t>;
+// For readability
+using DefLevels = LevelVector;
+using RepLevels = LevelVector;
+using Int32Vector = std::vector<int32_t>;
+using Int64Vector = std::vector<int64_t>;
+
+// A Parquet file builder that allows writing values one leaf column at a time
+class FileBuilder {
+ public:
+ static Result<std::shared_ptr<FileBuilder>> Make(const NodePtr& group_node,
+ int num_columns) {
+ auto self = std::make_shared<FileBuilder>();
+ RETURN_NOT_OK(self->Open(group_node, num_columns));
+ return self;
+ }
+
+ Result<std::shared_ptr<Buffer>> Finish() {
+ DCHECK_EQ(column_index_, num_columns_);
+ row_group_writer_->Close();
+ file_writer_->Close();
+ return stream_->Finish();
+ }
+
+ // Write a leaf (primitive) column
+ template <ParquetType TYPE, typename C_TYPE = typename ParquetTraits<TYPE>::value_type>
+ Status WriteColumn(const LevelVector& def_levels, const LevelVector& rep_levels,
+ const std::vector<C_TYPE>& values) {
+ auto column_writer = row_group_writer_->NextColumn();
+ auto column_descr = column_writer->descr();
+ const int16_t max_def_level = column_descr->max_definition_level();
+ const int16_t max_rep_level = column_descr->max_repetition_level();
+ CheckTestedLevels(def_levels, max_def_level);
+ CheckTestedLevels(rep_levels, max_rep_level);
+
+ auto typed_writer =
+ checked_cast<TypedColumnWriter<PhysicalType<TYPE>>*>(column_writer);
+
+ const int64_t num_values = static_cast<int64_t>(
+ (max_def_level > 0) ? def_levels.size()
+ : (max_rep_level > 0) ? rep_levels.size() : values.size());
+ const int64_t values_written = typed_writer->WriteBatch(
+ num_values, LevelPointerOrNull(def_levels, max_def_level),
+ LevelPointerOrNull(rep_levels, max_rep_level), values.data());
+ DCHECK_EQ(values_written, static_cast<int64_t>(values.size())); // Sanity check
+
+ column_writer->Close();
+ ++column_index_;
+ return Status::OK();
+ }
+
+ protected:
+ Status Open(const NodePtr& group_node, int num_columns) {
+ ARROW_ASSIGN_OR_RAISE(stream_, BufferOutputStream::Create());
+ file_writer_ =
+ ParquetFileWriter::Open(stream_, checked_pointer_cast<GroupNode>(group_node));
+ row_group_writer_ = file_writer_->AppendRowGroup();
+ num_columns_ = num_columns;
+ column_index_ = 0;
+ return Status::OK();
+ }
+
+ void CheckTestedLevels(const LevelVector& levels, int16_t max_level) {
+ // Tests are expected to exercise all possible levels in [0, max_level]
+ if (!levels.empty()) {
+ const int16_t max_seen_level = *std::max_element(levels.begin(), levels.end());
+ DCHECK_EQ(max_seen_level, max_level);
+ }
+ }
+
+ const int16_t* LevelPointerOrNull(const LevelVector& levels, int16_t max_level) {
+ if (max_level > 0) {
+ DCHECK_GT(levels.size(), 0);
+ return levels.data();
+ } else {
+ DCHECK_EQ(levels.size(), 0);
+ return nullptr;
+ }
+ }
+
+ std::shared_ptr<BufferOutputStream> stream_;
+ std::unique_ptr<ParquetFileWriter> file_writer_;
+ RowGroupWriter* row_group_writer_;
+ int num_columns_;
+ int column_index_;
+};
+
+// A Parquet file tester that allows reading Arrow columns, corresponding to
+// children of the top-level group node.
+class FileTester {
+ public:
+ static Result<std::shared_ptr<FileTester>> Make(std::shared_ptr<Buffer> buffer,
+ MemoryPool* pool) {
+ auto self = std::make_shared<FileTester>();
+ RETURN_NOT_OK(self->Open(buffer, pool));
+ return self;
+ }
+
+ Result<std::shared_ptr<Array>> ReadColumn(int column_index) {
+ std::shared_ptr<ChunkedArray> column;
+ RETURN_NOT_OK(file_reader_->ReadColumn(column_index, &column));
+ return ::arrow::Concatenate(column->chunks(), pool_);
+ }
+
+ void CheckColumn(int column_index, const Array& expected) {
+ ASSERT_OK_AND_ASSIGN(const auto actual, ReadColumn(column_index));
+ ASSERT_OK(actual->ValidateFull());
+ AssertArraysEqual(expected, *actual, /*verbose=*/true);
+ }
+
+ protected:
+ Status Open(std::shared_ptr<Buffer> buffer, MemoryPool* pool) {
+ pool_ = pool;
+ return OpenFile(std::make_shared<BufferReader>(buffer), pool_, &file_reader_);
+ }
+
+ MemoryPool* pool_;
+ std::unique_ptr<FileReader> file_reader_;
+};
+
+class TestReconstructColumn : public testing::Test {
+ public:
+ void SetUp() override { pool_ = ::arrow::default_memory_pool(); }
+
+ // Write the next leaf (primitive) column
+ template <ParquetType TYPE, typename C_TYPE = typename ParquetTraits<TYPE>::value_type>
+ Status WriteColumn(const LevelVector& def_levels, const LevelVector& rep_levels,
+ const std::vector<C_TYPE>& values) {
+ if (!builder_) {
+ ARROW_ASSIGN_OR_RAISE(builder_,
+ FileBuilder::Make(group_node_, descriptor_->num_columns()));
+ }
+ return builder_->WriteColumn<TYPE, C_TYPE>(def_levels, rep_levels, values);
+ }
+
+ template <typename C_TYPE>
+ Status WriteInt32Column(const LevelVector& def_levels, const LevelVector& rep_levels,
+ const std::vector<C_TYPE>& values) {
+ return WriteColumn<ParquetType::INT32>(def_levels, rep_levels, values);
+ }
+
+ template <typename C_TYPE>
+ Status WriteInt64Column(const LevelVector& def_levels, const LevelVector& rep_levels,
+ const std::vector<C_TYPE>& values) {
+ return WriteColumn<ParquetType::INT64>(def_levels, rep_levels, values);
+ }
+
+ // Read a Arrow column and check its values
+ void CheckColumn(int column_index, const Array& expected) {
+ if (!tester_) {
+ ASSERT_OK_AND_ASSIGN(auto buffer, builder_->Finish());
+ ASSERT_OK_AND_ASSIGN(tester_, FileTester::Make(buffer, pool_));
+ }
+ tester_->CheckColumn(column_index, expected);
+ }
+
+ void CheckColumn(const Array& expected) { CheckColumn(/*column_index=*/0, expected); }
+
+ // One-column shortcut
+ template <ParquetType TYPE, typename C_TYPE = typename ParquetTraits<TYPE>::value_type>
+ void AssertReconstruct(const Array& expected, const LevelVector& def_levels,
+ const LevelVector& rep_levels,
+ const std::vector<C_TYPE>& values) {
+ ASSERT_OK((WriteColumn<TYPE, C_TYPE>(def_levels, rep_levels, values)));
+ CheckColumn(/*column_index=*/0, expected);
+ }
+
+ ::arrow::Status MaybeSetParquetSchema(const NodePtr& column) {
+ descriptor_.reset(new SchemaDescriptor());
+ manifest_.reset(new SchemaManifest());
+ group_node_ = GroupNode::Make("root", Repetition::REQUIRED, {column});
+ descriptor_->Init(group_node_);
+ return SchemaManifest::Make(descriptor_.get(),
+ std::shared_ptr<const ::arrow::KeyValueMetadata>(),
+ ArrowReaderProperties(), manifest_.get());
+ }
+
+ void SetParquetSchema(const NodePtr& column) {
+ ASSERT_OK(MaybeSetParquetSchema(column));
+ }
+
+ protected:
+ MemoryPool* pool_;
+ NodePtr group_node_;
+ std::unique_ptr<SchemaDescriptor> descriptor_;
+ std::unique_ptr<SchemaManifest> manifest_;
+
+ std::shared_ptr<FileBuilder> builder_;
+ std::shared_ptr<FileTester> tester_;
+};
+
+static std::shared_ptr<DataType> OneFieldStruct(const std::string& name,
+ std::shared_ptr<DataType> type,
+ bool nullable = true) {
+ return struct_({field(name, type, nullable)});
+}
+
+static std::shared_ptr<DataType> List(std::shared_ptr<DataType> type,
+ bool nullable = true) {
+ // TODO should field name "element" (Parquet convention for List nodes)
+ // be changed to "item" (Arrow convention for List types)?
+ return list(field("element", type, nullable));
+}
+
+//
+// Primitive columns with no intermediate group node
+//
+
+TEST_F(TestReconstructColumn, PrimitiveOptional) {
+ SetParquetSchema(
+ PrimitiveNode::Make("node_name", Repetition::OPTIONAL, ParquetType::INT32));
+
+ LevelVector def_levels = {1, 0, 1, 1};
+ LevelVector rep_levels = {};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto expected = ArrayFromJSON(int32(), "[4, null, 5, 6]");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, PrimitiveRequired) {
+ SetParquetSchema(
+ PrimitiveNode::Make("node_name", Repetition::REQUIRED, ParquetType::INT32));
+
+ LevelVector def_levels = {};
+ LevelVector rep_levels = {};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto expected = ArrayFromJSON(int32(), "[4, 5, 6]");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, PrimitiveRepeated) {
+ // Arrow schema: list(int32 not null) not null
+ this->SetParquetSchema(
+ PrimitiveNode::Make("node_name", Repetition::REPEATED, ParquetType::INT32));
+
+ LevelVector def_levels = {0, 1, 1, 1};
+ LevelVector rep_levels = {0, 0, 1, 0};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto expected = ArrayFromJSON(list(field("node_name", int32(), /*nullable=*/false)),
+ "[[], [4, 5], [6]]");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+//
+// Struct encodings (one field each)
+//
+
+TEST_F(TestReconstructColumn, NestedRequiredRequired) {
+ // Arrow schema: struct(a: int32 not null) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {PrimitiveNode::Make("a", Repetition::REQUIRED, ParquetType::INT32)}));
+
+ LevelVector def_levels = {};
+ LevelVector rep_levels = {};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto expected = ArrayFromJSON(OneFieldStruct("a", int32(), false),
+ R"([{"a": 4}, {"a": 5}, {"a": 6}])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, NestedOptionalRequired) {
+ // Arrow schema: struct(a: int32 not null)
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("a", Repetition::REQUIRED, ParquetType::INT32)}));
+
+ LevelVector def_levels = {0, 1, 1, 1};
+ LevelVector rep_levels = {};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto expected = ArrayFromJSON(OneFieldStruct("a", int32(), false),
+ R"([null, {"a": 4}, {"a": 5}, {"a": 6}])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, NestedRequiredOptional) {
+ // Arrow schema: struct(a: int32) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {PrimitiveNode::Make("a", Repetition::OPTIONAL, ParquetType::INT32)}));
+
+ LevelVector def_levels = {0, 1, 1, 1};
+ LevelVector rep_levels = {};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto expected = ArrayFromJSON(OneFieldStruct("a", int32()),
+ R"([{"a": null}, {"a": 4}, {"a": 5}, {"a": 6}])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, NestedOptionalOptional) {
+ // Arrow schema: struct(a: int32)
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("a", Repetition::OPTIONAL, ParquetType::INT32)}));
+
+ LevelVector def_levels = {0, 1, 2, 2};
+ LevelVector rep_levels = {};
+ std::vector<int32_t> values = {4, 5};
+
+ auto expected = ArrayFromJSON(OneFieldStruct("a", int32()),
+ R"([null, {"a": null}, {"a": 4}, {"a": 5}])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+//
+// Nested struct encodings (one field each)
+//
+
+TEST_F(TestReconstructColumn, NestedRequiredRequiredRequired) {
+ // Arrow schema: struct(a: struct(b: int32 not null) not null) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "a", Repetition::REQUIRED,
+ {PrimitiveNode::Make("b", Repetition::REQUIRED, ParquetType::INT32)})}));
+
+ LevelVector def_levels = {};
+ LevelVector rep_levels = {};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto expected =
+ ArrayFromJSON(OneFieldStruct("a", OneFieldStruct("b", int32(), false), false),
+ R"([{"a": {"b": 4}},
+ {"a": {"b": 5}},
+ {"a": {"b": 6}}
+ ])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, NestedRequiredOptionalRequired) {
+ // Arrow schema: struct(a: struct(b: int32 not null)) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "a", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("b", Repetition::REQUIRED, ParquetType::INT32)})}));
+
+ LevelVector def_levels = {1, 0, 1, 1};
+ LevelVector rep_levels = {};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto expected = ArrayFromJSON(OneFieldStruct("a", OneFieldStruct("b", int32(), false)),
+ R"([{"a": {"b": 4}},
+ {"a": null},
+ {"a": {"b": 5}},
+ {"a": {"b": 6}}
+ ])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, NestedOptionalRequiredOptional) {
+ // Arrow schema: struct(a: struct(b: int32) not null)
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "a", Repetition::REQUIRED,
+ {PrimitiveNode::Make("b", Repetition::OPTIONAL, ParquetType::INT32)})}));
+
+ LevelVector def_levels = {1, 2, 0, 2, 2};
+ LevelVector rep_levels = {};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto expected = ArrayFromJSON(OneFieldStruct("a", OneFieldStruct("b", int32()), false),
+ R"([{"a": {"b": null}},
+ {"a": {"b": 4}},
+ null,
+ {"a": {"b": 5}},
+ {"a": {"b": 6}}
+ ])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, NestedOptionalOptionalOptional) {
+ // Arrow schema: struct(a: struct(b: int32) not null)
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "a", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("b", Repetition::OPTIONAL, ParquetType::INT32)})}));
+
+ LevelVector def_levels = {1, 2, 0, 3, 3, 3};
+ LevelVector rep_levels = {};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto expected = ArrayFromJSON(OneFieldStruct("a", OneFieldStruct("b", int32())),
+ R"([{"a": null},
+ {"a": {"b": null}},
+ null,
+ {"a": {"b": 4}},
+ {"a": {"b": 5}},
+ {"a": {"b": 6}}
+ ])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+//
+// Struct encodings (two fields)
+//
+
+TEST_F(TestReconstructColumn, NestedTwoFields1) {
+ // Arrow schema: struct(a: int32 not null, b: int64 not null) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {PrimitiveNode::Make("a", Repetition::REQUIRED, ParquetType::INT32),
+ PrimitiveNode::Make("b", Repetition::REQUIRED, ParquetType::INT64)}));
+
+ ASSERT_OK(WriteInt32Column(DefLevels{}, RepLevels{}, Int32Vector{4, 5, 6}));
+ ASSERT_OK(WriteInt64Column(DefLevels{}, RepLevels{}, Int64Vector{7, 8, 9}));
+
+ auto type = struct_(
+ {field("a", int32(), /*nullable=*/false), field("b", int64(), /*nullable=*/false)});
+ auto expected = ArrayFromJSON(type, R"([{"a": 4, "b": 7},
+ {"a": 5, "b": 8},
+ {"a": 6, "b": 9}])");
+
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, NestedTwoFields2) {
+ // Arrow schema: struct(a: int32 not null, b: int64) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {PrimitiveNode::Make("a", Repetition::REQUIRED, ParquetType::INT32),
+ PrimitiveNode::Make("b", Repetition::OPTIONAL, ParquetType::INT64)}));
+
+ ASSERT_OK(WriteInt32Column(DefLevels{}, RepLevels{}, Int32Vector{4, 5, 6}));
+ ASSERT_OK(WriteInt64Column(DefLevels{0, 1, 1}, RepLevels{}, Int64Vector{7, 8}));
+
+ auto type = struct_({field("a", int32(), /*nullable=*/false), field("b", int64())});
+ auto expected = ArrayFromJSON(type, R"([{"a": 4, "b": null},
+ {"a": 5, "b": 7},
+ {"a": 6, "b": 8}])");
+
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, NestedTwoFields3) {
+ // Arrow schema: struct(a: int32 not null, b: int64 not null)
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("a", Repetition::REQUIRED, ParquetType::INT32),
+ PrimitiveNode::Make("b", Repetition::REQUIRED, ParquetType::INT64)}));
+
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 1}, RepLevels{}, Int32Vector{4, 5}));
+ ASSERT_OK(WriteInt64Column(DefLevels{0, 1, 1}, RepLevels{}, Int64Vector{7, 8}));
+
+ auto type = struct_(
+ {field("a", int32(), /*nullable=*/false), field("b", int64(), /*nullable=*/false)});
+ auto expected = ArrayFromJSON(type, R"([null,
+ {"a": 4, "b": 7},
+ {"a": 5, "b": 8}])");
+
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, NestedTwoFields4) {
+ // Arrow schema: struct(a: int32, b: int64 not null)
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("a", Repetition::OPTIONAL, ParquetType::INT32),
+ PrimitiveNode::Make("b", Repetition::REQUIRED, ParquetType::INT64)}));
+
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 2}, RepLevels{}, Int32Vector{4}));
+ ASSERT_OK(WriteInt64Column(DefLevels{0, 1, 1}, RepLevels{}, Int64Vector{7, 8}));
+
+ auto type = struct_({field("a", int32()), field("b", int64(), /*nullable=*/false)});
+ auto expected = ArrayFromJSON(type, R"([null,
+ {"a": null, "b": 7},
+ {"a": 4, "b": 8}])");
+
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, NestedTwoFields5) {
+ // Arrow schema: struct(a: int32, b: int64)
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("a", Repetition::OPTIONAL, ParquetType::INT32),
+ PrimitiveNode::Make("b", Repetition::OPTIONAL, ParquetType::INT64)}));
+
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 2}, RepLevels{}, Int32Vector{4}));
+ ASSERT_OK(WriteInt64Column(DefLevels{0, 2, 1}, RepLevels{}, Int64Vector{7}));
+
+ auto type = struct_({field("a", int32()), field("b", int64())});
+ auto expected = ArrayFromJSON(type, R"([null,
+ {"a": null, "b": 7},
+ {"a": 4, "b": null}])");
+
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+//
+// Nested struct encodings (two fields)
+//
+
+TEST_F(TestReconstructColumn, NestedNestedTwoFields1) {
+ // Arrow schema: struct(a: struct(aa: int32 not null,
+ // ab: int64 not null) not null,
+ // b: int32 not null) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "a", Repetition::REQUIRED,
+ {PrimitiveNode::Make("aa", Repetition::REQUIRED, ParquetType::INT32),
+ PrimitiveNode::Make("ab", Repetition::REQUIRED, ParquetType::INT64)}),
+ PrimitiveNode::Make("b", Repetition::REQUIRED, ParquetType::INT32)}));
+
+ // aa
+ ASSERT_OK(WriteInt32Column(DefLevels{}, RepLevels{}, Int32Vector{4, 5, 6}));
+ // ab
+ ASSERT_OK(WriteInt64Column(DefLevels{}, RepLevels{}, Int64Vector{7, 8, 9}));
+ // b
+ ASSERT_OK(WriteInt32Column(DefLevels{}, RepLevels{}, Int32Vector{10, 11, 12}));
+
+ auto type = struct_({field("a",
+ struct_({field("aa", int32(), /*nullable=*/false),
+ field("ab", int64(), /*nullable=*/false)}),
+ /*nullable=*/false),
+ field("b", int32(), /*nullable=*/false)});
+ auto expected = ArrayFromJSON(type, R"([{"a": {"aa": 4, "ab": 7}, "b": 10},
+ {"a": {"aa": 5, "ab": 8}, "b": 11},
+ {"a": {"aa": 6, "ab": 9}, "b": 12}])");
+
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, NestedNestedTwoFields2) {
+ // Arrow schema: struct(a: struct(aa: int32,
+ // ab: int64 not null) not null,
+ // b: int32 not null) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "a", Repetition::REQUIRED,
+ {PrimitiveNode::Make("aa", Repetition::OPTIONAL, ParquetType::INT32),
+ PrimitiveNode::Make("ab", Repetition::REQUIRED, ParquetType::INT64)}),
+ PrimitiveNode::Make("b", Repetition::REQUIRED, ParquetType::INT32)}));
+
+ // aa
+ ASSERT_OK(WriteInt32Column(DefLevels{1, 0, 1}, RepLevels{}, Int32Vector{4, 5}));
+ // ab
+ ASSERT_OK(WriteInt64Column(DefLevels{}, RepLevels{}, Int64Vector{7, 8, 9}));
+ // b
+ ASSERT_OK(WriteInt32Column(DefLevels{}, RepLevels{}, Int32Vector{10, 11, 12}));
+
+ auto type = struct_(
+ {field("a",
+ struct_({field("aa", int32()), field("ab", int64(), /*nullable=*/false)}),
+ /*nullable=*/false),
+ field("b", int32(), /*nullable=*/false)});
+ auto expected = ArrayFromJSON(type, R"([{"a": {"aa": 4, "ab": 7}, "b": 10},
+ {"a": {"aa": null, "ab": 8}, "b": 11},
+ {"a": {"aa": 5, "ab": 9}, "b": 12}])");
+
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, NestedNestedTwoFields3) {
+ // Arrow schema: struct(a: struct(aa: int32 not null,
+ // ab: int64) not null,
+ // b: int32) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "a", Repetition::REQUIRED,
+ {PrimitiveNode::Make("aa", Repetition::REQUIRED, ParquetType::INT32),
+ PrimitiveNode::Make("ab", Repetition::OPTIONAL, ParquetType::INT64)}),
+ PrimitiveNode::Make("b", Repetition::OPTIONAL, ParquetType::INT32)}));
+
+ // aa
+ ASSERT_OK(WriteInt32Column(DefLevels{}, RepLevels{}, Int32Vector{4, 5, 6}));
+ // ab
+ ASSERT_OK(WriteInt64Column(DefLevels{0, 1, 1}, RepLevels{}, Int64Vector{7, 8}));
+ // b
+ ASSERT_OK(WriteInt32Column(DefLevels{1, 0, 1}, RepLevels{}, Int32Vector{10, 11}));
+
+ auto type = struct_(
+ {field("a",
+ struct_({field("aa", int32(), /*nullable=*/false), field("ab", int64())}),
+ /*nullable=*/false),
+ field("b", int32())});
+ auto expected = ArrayFromJSON(type, R"([{"a": {"aa": 4, "ab": null}, "b": 10},
+ {"a": {"aa": 5, "ab": 7}, "b": null},
+ {"a": {"aa": 6, "ab": 8}, "b": 11}])");
+
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, NestedNestedTwoFields4) {
+ // Arrow schema: struct(a: struct(aa: int32 not null,
+ // ab: int64),
+ // b: int32 not null) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "a", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("aa", Repetition::REQUIRED, ParquetType::INT32),
+ PrimitiveNode::Make("ab", Repetition::OPTIONAL, ParquetType::INT64)}),
+ PrimitiveNode::Make("b", Repetition::REQUIRED, ParquetType::INT32)}));
+
+ // aa
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 1}, RepLevels{}, Int32Vector{4, 5}));
+ // ab
+ ASSERT_OK(WriteInt64Column(DefLevels{0, 1, 2}, RepLevels{}, Int64Vector{7}));
+ // b
+ ASSERT_OK(WriteInt32Column(DefLevels{}, RepLevels{}, Int32Vector{10, 11, 12}));
+
+ auto type = struct_({field("a", struct_({field("aa", int32(), /*nullable=*/false),
+ field("ab", int64())})),
+ field("b", int32(), /*nullable=*/false)});
+ auto expected = ArrayFromJSON(type, R"([{"a": null, "b": 10},
+ {"a": {"aa": 4, "ab": null}, "b": 11},
+ {"a": {"aa": 5, "ab": 7}, "b": 12}])");
+
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, NestedNestedTwoFields5) {
+ // Arrow schema: struct(a: struct(aa: int32 not null,
+ // ab: int64) not null,
+ // b: int32)
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "a", Repetition::REQUIRED,
+ {PrimitiveNode::Make("aa", Repetition::REQUIRED, ParquetType::INT32),
+ PrimitiveNode::Make("ab", Repetition::OPTIONAL, ParquetType::INT64)}),
+ PrimitiveNode::Make("b", Repetition::OPTIONAL, ParquetType::INT32)}));
+
+ // aa
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 1}, RepLevels{}, Int32Vector{4, 5}));
+ // ab
+ ASSERT_OK(WriteInt64Column(DefLevels{0, 1, 2}, RepLevels{}, Int64Vector{7}));
+ // b
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 2, 1}, RepLevels{}, Int32Vector{10}));
+
+ auto type = struct_(
+ {field("a",
+ struct_({field("aa", int32(), /*nullable=*/false), field("ab", int64())}),
+ /*nullable=*/false),
+ field("b", int32())});
+ auto expected = ArrayFromJSON(type, R"([null,
+ {"a": {"aa": 4, "ab": null}, "b": 10},
+ {"a": {"aa": 5, "ab": 7}, "b": null}])");
+
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, NestedNestedTwoFields6) {
+ // Arrow schema: struct(a: struct(aa: int32 not null,
+ // ab: int64),
+ // b: int32)
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "a", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("aa", Repetition::REQUIRED, ParquetType::INT32),
+ PrimitiveNode::Make("ab", Repetition::OPTIONAL, ParquetType::INT64)}),
+ PrimitiveNode::Make("b", Repetition::OPTIONAL, ParquetType::INT32)}));
+
+ // aa
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 2, 2}, RepLevels{}, Int32Vector{4, 5}));
+ // ab
+ ASSERT_OK(WriteInt64Column(DefLevels{0, 1, 2, 3}, RepLevels{}, Int64Vector{7}));
+ // b
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 2, 1, 2}, RepLevels{}, Int32Vector{10, 11}));
+
+ auto type = struct_({field("a", struct_({field("aa", int32(), /*nullable=*/false),
+ field("ab", int64())})),
+ field("b", int32())});
+ auto expected = ArrayFromJSON(type, R"([null,
+ {"a": null, "b": 10},
+ {"a": {"aa": 4, "ab": null}, "b": null},
+ {"a": {"aa": 5, "ab": 7}, "b": 11}])");
+
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+//
+// Three-level list encodings
+//
+
+TEST_F(TestReconstructColumn, ThreeLevelListRequiredRequired) {
+ // Arrow schema: list(int32 not null) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED, ParquetType::INT32)})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 1, 1};
+ LevelVector rep_levels = {0, 0, 1, 0};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ // TODO should field name "element" (Parquet convention for List nodes)
+ // be changed to "item" (Arrow convention for List types)?
+ auto expected = ArrayFromJSON(List(int32(), /*nullable=*/false), "[[], [4, 5], [6]]");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, ThreeLevelListOptionalRequired) {
+ // Arrow schema: list(int32 not null)
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED, ParquetType::INT32)})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 2, 2, 2};
+ LevelVector rep_levels = {0, 0, 0, 1, 0};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto expected =
+ ArrayFromJSON(List(int32(), /*nullable=*/false), "[null, [], [4, 5], [6]]");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, ThreeLevelListRequiredOptional) {
+ // Arrow schema: list(int32) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::OPTIONAL, ParquetType::INT32)})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 2, 2, 2};
+ LevelVector rep_levels = {0, 0, 1, 0, 1};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto expected = ArrayFromJSON(List(int32()), "[[], [null, 4], [5, 6]]");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, ThreeLevelListOptionalOptional) {
+ // Arrow schema: list(int32)
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::OPTIONAL, ParquetType::INT32)})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 2, 3, 3, 3};
+ LevelVector rep_levels = {0, 0, 0, 1, 0, 1};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto expected = ArrayFromJSON(List(int32()), "[null, [], [null, 4], [5, 6]]");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+//
+// Legacy list encodings
+//
+
+TEST_F(TestReconstructColumn, TwoLevelListRequired) {
+ // Arrow schema: list(int32 not null) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {PrimitiveNode::Make("element", Repetition::REPEATED, ParquetType::INT32)},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 1, 1};
+ LevelVector rep_levels = {0, 0, 1, 0};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ // TODO should field name "element" (Parquet convention for List nodes)
+ // be changed to "item" (Arrow convention for List types)?
+ auto expected = ArrayFromJSON(List(int32(), /*nullable=*/false), "[[], [4, 5], [6]]");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, TwoLevelListOptional) {
+ // Arrow schema: list(int32 not null)
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("element", Repetition::REPEATED, ParquetType::INT32)},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 2, 2, 2};
+ LevelVector rep_levels = {0, 0, 0, 1, 0};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto expected =
+ ArrayFromJSON(List(int32(), /*nullable=*/false), "[null, [], [4, 5], [6]]");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+//
+// List-in-struct
+//
+
+TEST_F(TestReconstructColumn, NestedList1) {
+ // Arrow schema: struct(a: list(int32 not null) not null) not null
+ SetParquetSchema(GroupNode::Make(
+ "a", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "p", Repetition::REQUIRED,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED,
+ ParquetType::INT32)})},
+ LogicalType::List())}));
+
+ LevelVector def_levels = {0, 1, 1, 1};
+ LevelVector rep_levels = {0, 0, 1, 0};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = OneFieldStruct("p", List(int32(), /*nullable=*/false),
+ /*nullable=*/false);
+ auto expected = ArrayFromJSON(type, R"([{"p": []},
+ {"p": [4, 5]},
+ {"p": [6]}])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, NestedList2) {
+ // Arrow schema: struct(a: list(int32 not null) not null)
+ SetParquetSchema(GroupNode::Make(
+ "a", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "p", Repetition::REQUIRED,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED,
+ ParquetType::INT32)})},
+ LogicalType::List())}));
+
+ LevelVector def_levels = {0, 1, 2, 2, 2};
+ LevelVector rep_levels = {0, 0, 0, 1, 0};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = OneFieldStruct("p", List(int32(), /*nullable=*/false),
+ /*nullable=*/false);
+ auto expected = ArrayFromJSON(type, R"([null,
+ {"p": []},
+ {"p": [4, 5]},
+ {"p": [6]}])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, NestedList3) {
+ // Arrow schema: struct(a: list(int32 not null)) not null
+ SetParquetSchema(GroupNode::Make(
+ "a", Repetition::REQUIRED, // column name (column a is a struct of)
+ {GroupNode::Make(
+ "p", Repetition::OPTIONAL, // name in struct
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED,
+ ParquetType::INT32)})},
+ LogicalType::List())}));
+
+ LevelVector def_levels = {0, 1, 2, 2, 2};
+ LevelVector rep_levels = {0, 0, 0, 1, 0};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = OneFieldStruct("p", List(int32(), /*nullable=*/false));
+ auto expected = ArrayFromJSON(type, R"([{"p": null},
+ {"p": []},
+ {"p": [4, 5]},
+ {"p": [6]}])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, NestedList4) {
+ // Arrow schema: struct(a: list(int32 not null))
+ SetParquetSchema(GroupNode::Make(
+ "a", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "p", Repetition::OPTIONAL,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED,
+ ParquetType::INT32)})},
+ LogicalType::List())}));
+
+ LevelVector def_levels = {0, 1, 2, 3, 3, 3};
+ LevelVector rep_levels = {0, 0, 0, 0, 1, 0};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = OneFieldStruct("p", List(int32(), /*nullable=*/false));
+ auto expected = ArrayFromJSON(type, R"([null,
+ {"p": null},
+ {"p": []},
+ {"p": [4, 5]},
+ {"p": [6]}])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, NestedList5) {
+ // Arrow schema: struct(a: list(int32) not null)
+ SetParquetSchema(GroupNode::Make(
+ "a", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "p", Repetition::REQUIRED,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::OPTIONAL,
+ ParquetType::INT32)})},
+ LogicalType::List())}));
+
+ LevelVector def_levels = {0, 1, 3, 2, 3, 3};
+ LevelVector rep_levels = {0, 0, 0, 1, 0, 1};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = OneFieldStruct("p", List(int32()), /*nullable=*/false);
+ auto expected = ArrayFromJSON(type, R"([null,
+ {"p": []},
+ {"p": [4, null]},
+ {"p": [5, 6]}])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, NestedList6) {
+ // Arrow schema: struct(a: list(int32))
+ SetParquetSchema(GroupNode::Make(
+ "a", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "p", Repetition::OPTIONAL,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::OPTIONAL,
+ ParquetType::INT32)})},
+ LogicalType::List())}));
+
+ LevelVector def_levels = {0, 1, 2, 4, 3, 4, 4};
+ LevelVector rep_levels = {0, 0, 0, 0, 1, 0, 1};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = OneFieldStruct("p", List(int32()));
+ auto expected = ArrayFromJSON(type, R"([null,
+ {"p": null},
+ {"p": []},
+ {"p": [4, null]},
+ {"p": [5, 6]}])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+//
+// Struct-in-list
+//
+
+TEST_F(TestReconstructColumn, ListNested1) {
+ // Arrow schema: list(struct(a: int32 not null) not null) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {GroupNode::Make("element", Repetition::REQUIRED,
+ {PrimitiveNode::Make("a", Repetition::REQUIRED,
+ ParquetType::INT32)})})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 1, 1};
+ LevelVector rep_levels = {0, 0, 1, 0};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = List(OneFieldStruct("a", int32(), /*nullable=*/false),
+ /*nullable=*/false);
+ auto expected = ArrayFromJSON(type,
+ R"([[],
+ [{"a": 4}, {"a": 5}],
+ [{"a": 6}]])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, ListNested2) {
+ // Arrow schema: list(struct(a: int32 not null) not null)
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {GroupNode::Make("element", Repetition::REQUIRED,
+ {PrimitiveNode::Make("a", Repetition::REQUIRED,
+ ParquetType::INT32)})})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 2, 2, 2};
+ LevelVector rep_levels = {0, 0, 0, 1, 0};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = List(OneFieldStruct("a", int32(), /*nullable=*/false),
+ /*nullable=*/false);
+ auto expected = ArrayFromJSON(type,
+ R"([null,
+ [],
+ [{"a": 4}, {"a": 5}],
+ [{"a": 6}]])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, ListNested3) {
+ // Arrow schema: list(struct(a: int32 not null)) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {GroupNode::Make("element", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("a", Repetition::REQUIRED,
+ ParquetType::INT32)})})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 2, 2, 2};
+ LevelVector rep_levels = {0, 0, 1, 1, 0};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = List(OneFieldStruct("a", int32(), /*nullable=*/false));
+ auto expected = ArrayFromJSON(type,
+ R"([[],
+ [null, {"a": 4}, {"a": 5}],
+ [{"a": 6}]])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, ListNested4) {
+ // Arrow schema: list(struct(a: int32 not null))
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {GroupNode::Make("element", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("a", Repetition::REQUIRED,
+ ParquetType::INT32)})})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 2, 3, 3, 3};
+ LevelVector rep_levels = {0, 0, 0, 1, 1, 0};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = List(OneFieldStruct("a", int32(), /*nullable=*/false));
+ auto expected = ArrayFromJSON(type,
+ R"([null,
+ [],
+ [null, {"a": 4}, {"a": 5}],
+ [{"a": 6}]])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, ListNested5) {
+ // Arrow schema: list(struct(a: int32) not null)
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {GroupNode::Make("element", Repetition::REQUIRED,
+ {PrimitiveNode::Make("a", Repetition::OPTIONAL,
+ ParquetType::INT32)})})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 2, 3, 3, 3};
+ LevelVector rep_levels = {0, 0, 0, 1, 0, 1};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = List(OneFieldStruct("a", int32()),
+ /*nullable=*/false);
+ auto expected = ArrayFromJSON(type,
+ R"([null,
+ [],
+ [{"a": null}, {"a": 4}],
+ [{"a": 5}, {"a": 6}]])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, ListNested6) {
+ // Arrow schema: list(struct(a: int32))
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {GroupNode::Make("element", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("a", Repetition::OPTIONAL,
+ ParquetType::INT32)})})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 2, 3, 4, 4, 4};
+ LevelVector rep_levels = {0, 0, 0, 1, 1, 0, 1};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = List(OneFieldStruct("a", int32()));
+ auto expected = ArrayFromJSON(type,
+ R"([null,
+ [],
+ [null, {"a": null}, {"a": 4}],
+ [{"a": 5}, {"a": 6}]])");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+//
+// Struct (two fields)-in-list
+//
+
+TEST_F(TestReconstructColumn, ListNestedTwoFields1) {
+ // Arrow schema: list(struct(a: int32 not null,
+ // b: int64 not null) not null) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {GroupNode::Make(
+ "element", Repetition::REQUIRED,
+ {PrimitiveNode::Make("a", Repetition::REQUIRED, ParquetType::INT32),
+ PrimitiveNode::Make("b", Repetition::REQUIRED, ParquetType::INT64)})})},
+ LogicalType::List()));
+
+ // a
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 1, 1}, RepLevels{0, 0, 1, 0},
+ Int32Vector{4, 5, 6}));
+ // b
+ ASSERT_OK(WriteInt64Column(DefLevels{0, 1, 1, 1}, RepLevels{0, 0, 1, 0},
+ Int64Vector{7, 8, 9}));
+
+ auto type = List(struct_({field("a", int32(), /*nullable=*/false),
+ field("b", int64(), /*nullable=*/false)}),
+ /*nullable=*/false);
+ auto expected = ArrayFromJSON(type,
+ R"([[],
+ [{"a": 4, "b": 7}, {"a": 5, "b": 8}],
+ [{"a": 6, "b": 9}]])");
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, ListNestedTwoFields2) {
+ // Arrow schema: list(struct(a: int32,
+ // b: int64 not null) not null) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {GroupNode::Make(
+ "element", Repetition::REQUIRED,
+ {PrimitiveNode::Make("a", Repetition::OPTIONAL, ParquetType::INT32),
+ PrimitiveNode::Make("b", Repetition::REQUIRED, ParquetType::INT64)})})},
+ LogicalType::List()));
+
+ // a
+ ASSERT_OK(
+ WriteInt32Column(DefLevels{0, 2, 1, 2}, RepLevels{0, 0, 1, 0}, Int32Vector{4, 5}));
+ // b
+ ASSERT_OK(WriteInt64Column(DefLevels{0, 1, 1, 1}, RepLevels{0, 0, 1, 0},
+ Int64Vector{7, 8, 9}));
+
+ auto type =
+ List(struct_({field("a", int32()), field("b", int64(), /*nullable=*/false)}),
+ /*nullable=*/false);
+ auto expected = ArrayFromJSON(type,
+ R"([[],
+ [{"a": 4, "b": 7}, {"a": null, "b": 8}],
+ [{"a": 5, "b": 9}]])");
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, ListNestedTwoFields3) {
+ // Arrow schema: list(struct(a: int32 not null,
+ // b: int64 not null)) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {GroupNode::Make(
+ "element", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("a", Repetition::REQUIRED, ParquetType::INT32),
+ PrimitiveNode::Make("b", Repetition::REQUIRED, ParquetType::INT64)})})},
+ LogicalType::List()));
+
+ // a
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 2, 2, 2}, RepLevels{0, 0, 1, 1, 0},
+ Int32Vector{4, 5, 6}));
+ // b
+ ASSERT_OK(WriteInt64Column(DefLevels{0, 1, 2, 2, 2}, RepLevels{0, 0, 1, 1, 0},
+ Int64Vector{7, 8, 9}));
+
+ auto type = List(struct_({field("a", int32(), /*nullable=*/false),
+ field("b", int64(), /*nullable=*/false)}));
+ auto expected = ArrayFromJSON(type,
+ R"([[],
+ [null, {"a": 4, "b": 7}, {"a": 5, "b": 8}],
+ [{"a": 6, "b": 9}]])");
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, ListNestedTwoFields4) {
+ // Arrow schema: list(struct(a: int32,
+ // b: int64 not null) not null)
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {GroupNode::Make(
+ "element", Repetition::REQUIRED,
+ {PrimitiveNode::Make("a", Repetition::OPTIONAL, ParquetType::INT32),
+ PrimitiveNode::Make("b", Repetition::REQUIRED, ParquetType::INT64)})})},
+ LogicalType::List()));
+
+ // a
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 3, 2, 3}, RepLevels{0, 0, 0, 1, 0},
+ Int32Vector{4, 5}));
+ // b
+ ASSERT_OK(WriteInt64Column(DefLevels{0, 1, 2, 2, 2}, RepLevels{0, 0, 0, 1, 0},
+ Int64Vector{7, 8, 9}));
+
+ auto type =
+ List(struct_({field("a", int32()), field("b", int64(), /*nullable=*/false)}),
+ /*nullable=*/false);
+ auto expected = ArrayFromJSON(type,
+ R"([null,
+ [],
+ [{"a": 4, "b": 7}, {"a": null, "b": 8}],
+ [{"a": 5, "b": 9}]])");
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, ListNestedTwoFields5) {
+ // Arrow schema: list(struct(a: int32,
+ // b: int64 not null))
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {GroupNode::Make(
+ "element", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("a", Repetition::OPTIONAL, ParquetType::INT32),
+ PrimitiveNode::Make("b", Repetition::REQUIRED, ParquetType::INT64)})})},
+ LogicalType::List()));
+
+ // a
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 4, 2, 3}, RepLevels{0, 0, 0, 1, 0},
+ Int32Vector{4}));
+ // b
+ ASSERT_OK(WriteInt64Column(DefLevels{0, 1, 3, 2, 3}, RepLevels{0, 0, 0, 1, 0},
+ Int64Vector{7, 8}));
+
+ auto type =
+ List(struct_({field("a", int32()), field("b", int64(), /*nullable=*/false)}));
+ auto expected = ArrayFromJSON(type,
+ R"([null,
+ [],
+ [{"a": 4, "b": 7}, null],
+ [{"a": null, "b": 8}]])");
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, ListNestedTwoFields6) {
+ // Arrow schema: list(struct(a: int32,
+ // b: int64))
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {GroupNode::Make(
+ "element", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("a", Repetition::OPTIONAL, ParquetType::INT32),
+ PrimitiveNode::Make("b", Repetition::OPTIONAL, ParquetType::INT64)})})},
+ LogicalType::List()));
+
+ // a
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 4, 2, 3}, RepLevels{0, 0, 0, 1, 0},
+ Int32Vector{4}));
+ // b
+ ASSERT_OK(WriteInt64Column(DefLevels{0, 1, 3, 2, 4}, RepLevels{0, 0, 0, 1, 0},
+ Int64Vector{7}));
+
+ auto type = List(struct_({field("a", int32()), field("b", int64())}));
+ auto expected = ArrayFromJSON(type,
+ R"([null,
+ [],
+ [{"a": 4, "b": null}, null],
+ [{"a": null, "b": 7}]])");
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+//
+// List-in-struct (two fields)
+//
+
+TEST_F(TestReconstructColumn, NestedTwoFieldsList1) {
+ // Arrow schema: struct(a: int64 not null,
+ // b: list(int32 not null) not null
+ // ) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {PrimitiveNode::Make("a", Repetition::REQUIRED, ParquetType::INT64),
+ GroupNode::Make(
+ "b", Repetition::REQUIRED,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED,
+ ParquetType::INT32)})},
+ LogicalType::List())}));
+
+ // a
+ ASSERT_OK(WriteInt64Column(DefLevels{}, RepLevels{}, Int64Vector{4, 5, 6}));
+ // b
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 1, 1}, RepLevels{0, 0, 1, 0},
+ Int32Vector{7, 8, 9}));
+
+ auto type =
+ struct_({field("a", int64(), /*nullable=*/false),
+ field("b", List(int32(), /*nullable=*/false), /*nullable=*/false)});
+ auto expected = ArrayFromJSON(type,
+ R"([{"a": 4, "b": []},
+ {"a": 5, "b": [7, 8]},
+ {"a": 6, "b": [9]}])");
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, NestedTwoFieldsList2) {
+ // Arrow schema: struct(a: int64 not null,
+ // b: list(int32 not null)
+ // ) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {PrimitiveNode::Make("a", Repetition::REQUIRED, ParquetType::INT64),
+ GroupNode::Make(
+ "b", Repetition::OPTIONAL,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED,
+ ParquetType::INT32)})},
+ LogicalType::List())}));
+
+ // a
+ ASSERT_OK(WriteInt64Column(DefLevels{}, RepLevels{}, Int64Vector{3, 4, 5, 6}));
+ // b
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 2, 2, 2}, RepLevels{0, 0, 0, 1, 0},
+ Int32Vector{7, 8, 9}));
+
+ auto type = struct_({field("a", int64(), /*nullable=*/false),
+ field("b", List(int32(), /*nullable=*/false))});
+ auto expected = ArrayFromJSON(type,
+ R"([{"a": 3, "b": null},
+ {"a": 4, "b": []},
+ {"a": 5, "b": [7, 8]},
+ {"a": 6, "b": [9]}])");
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, NestedTwoFieldsList3) {
+ // Arrow schema: struct(a: int64,
+ // b: list(int32 not null)
+ // ) not null
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::REQUIRED,
+ {PrimitiveNode::Make("a", Repetition::OPTIONAL, ParquetType::INT64),
+ GroupNode::Make(
+ "b", Repetition::OPTIONAL,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED,
+ ParquetType::INT32)})},
+ LogicalType::List())}));
+
+ // a
+ ASSERT_OK(WriteInt64Column(DefLevels{1, 1, 0, 1}, RepLevels{}, Int64Vector{4, 5, 6}));
+ // b
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 2, 2, 2}, RepLevels{0, 0, 0, 1, 0},
+ Int32Vector{7, 8, 9}));
+
+ auto type =
+ struct_({field("a", int64()), field("b", List(int32(), /*nullable=*/false))});
+ auto expected = ArrayFromJSON(type,
+ R"([{"a": 4, "b": null},
+ {"a": 5, "b": []},
+ {"a": null, "b": [7, 8]},
+ {"a": 6, "b": [9]}])");
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, NestedTwoFieldsList4) {
+ // Arrow schema: struct(a: int64,
+ // b: list(int32 not null)
+ // )
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("a", Repetition::OPTIONAL, ParquetType::INT64),
+ GroupNode::Make(
+ "b", Repetition::OPTIONAL,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED,
+ ParquetType::INT32)})},
+ LogicalType::List())}));
+
+ // a
+ ASSERT_OK(
+ WriteInt64Column(DefLevels{0, 2, 2, 1, 2}, RepLevels{}, Int64Vector{4, 5, 6}));
+ // b
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 2, 3, 3, 3}, RepLevels{0, 0, 0, 0, 1, 0},
+ Int32Vector{7, 8, 9}));
+
+ auto type =
+ struct_({field("a", int64()), field("b", List(int32(), /*nullable=*/false))});
+ auto expected = ArrayFromJSON(type,
+ R"([null,
+ {"a": 4, "b": null},
+ {"a": 5, "b": []},
+ {"a": null, "b": [7, 8]},
+ {"a": 6, "b": [9]}])");
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+TEST_F(TestReconstructColumn, NestedTwoFieldsList5) {
+ // Arrow schema: struct(a: int64, b: list(int32))
+ SetParquetSchema(GroupNode::Make(
+ "parent", Repetition::OPTIONAL,
+ {PrimitiveNode::Make("a", Repetition::OPTIONAL, ParquetType::INT64),
+ GroupNode::Make(
+ "b", Repetition::OPTIONAL,
+ {GroupNode::Make("list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::OPTIONAL,
+ ParquetType::INT32)})},
+ LogicalType::List())}));
+
+ // a
+ ASSERT_OK(
+ WriteInt64Column(DefLevels{0, 2, 2, 1, 2}, RepLevels{}, Int64Vector{4, 5, 6}));
+ // b
+ ASSERT_OK(WriteInt32Column(DefLevels{0, 1, 2, 4, 3, 4}, RepLevels{0, 0, 0, 0, 1, 0},
+ Int32Vector{7, 8}));
+
+ auto type = struct_({field("a", int64()), field("b", List(int32()))});
+ auto expected = ArrayFromJSON(type,
+ R"([null,
+ {"a": 4, "b": null},
+ {"a": 5, "b": []},
+ {"a": null, "b": [7, null]},
+ {"a": 6, "b": [8]}])");
+ CheckColumn(/*column_index=*/0, *expected);
+}
+
+//
+// List-in-list
+//
+
+TEST_F(TestReconstructColumn, ListList1) {
+ // Arrow schema: list(list(int32 not null) not null) not null
+ auto inner_list = GroupNode::Make(
+ "element", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED, ParquetType::INT32)})},
+ LogicalType::List());
+ SetParquetSchema(
+ GroupNode::Make("parent", Repetition::REQUIRED,
+ {GroupNode::Make("list", Repetition::REPEATED, {inner_list})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 2, 2, 2};
+ LevelVector rep_levels = {0, 0, 1, 0, 2};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = List(List(int32(), /*nullable=*/false), /*nullable=*/false);
+ auto expected = ArrayFromJSON(type, "[[], [[], [4]], [[5, 6]]]");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, ListList2) {
+ // Arrow schema: list(list(int32 not null) not null)
+ auto inner_list = GroupNode::Make(
+ "element", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED, ParquetType::INT32)})},
+ LogicalType::List());
+ SetParquetSchema(
+ GroupNode::Make("parent", Repetition::OPTIONAL,
+ {GroupNode::Make("list", Repetition::REPEATED, {inner_list})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 2, 3, 3, 3};
+ LevelVector rep_levels = {0, 0, 0, 1, 0, 2};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = List(List(int32(), /*nullable=*/false), /*nullable=*/false);
+ auto expected = ArrayFromJSON(type, "[null, [], [[], [4]], [[5, 6]]]");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, ListList3) {
+ // Arrow schema: list(list(int32 not null)) not null
+ auto inner_list = GroupNode::Make(
+ "element", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED, ParquetType::INT32)})},
+ LogicalType::List());
+ SetParquetSchema(
+ GroupNode::Make("parent", Repetition::REQUIRED,
+ {GroupNode::Make("list", Repetition::REPEATED, {inner_list})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 2, 3, 3, 3};
+ LevelVector rep_levels = {0, 0, 1, 0, 1, 2};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = List(List(int32(), /*nullable=*/false));
+ auto expected = ArrayFromJSON(type, "[[], [null, []], [[4], [5, 6]]]");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, ListList4) {
+ // Arrow schema: list(list(int32 not null))
+ auto inner_list = GroupNode::Make(
+ "element", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::REQUIRED, ParquetType::INT32)})},
+ LogicalType::List());
+ SetParquetSchema(
+ GroupNode::Make("parent", Repetition::OPTIONAL,
+ {GroupNode::Make("list", Repetition::REPEATED, {inner_list})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 2, 3, 4, 4, 4};
+ LevelVector rep_levels = {0, 0, 0, 1, 1, 0, 2};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = List(List(int32(), /*nullable=*/false));
+ auto expected = ArrayFromJSON(type, "[null, [], [null, [], [4]], [[5, 6]]]");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, ListList5) {
+ // Arrow schema: list(list(int32) not null)
+ auto inner_list = GroupNode::Make(
+ "element", Repetition::REQUIRED,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::OPTIONAL, ParquetType::INT32)})},
+ LogicalType::List());
+ SetParquetSchema(
+ GroupNode::Make("parent", Repetition::OPTIONAL,
+ {GroupNode::Make("list", Repetition::REPEATED, {inner_list})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 2, 4, 4, 3, 4};
+ LevelVector rep_levels = {0, 0, 0, 1, 0, 1, 2};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = List(List(int32()), /*nullable=*/false);
+ auto expected = ArrayFromJSON(type, "[null, [], [[], [4]], [[5], [null, 6]]]");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+TEST_F(TestReconstructColumn, ListList6) {
+ // Arrow schema: list(list(int32))
+ auto inner_list = GroupNode::Make(
+ "element", Repetition::OPTIONAL,
+ {GroupNode::Make(
+ "list", Repetition::REPEATED,
+ {PrimitiveNode::Make("element", Repetition::OPTIONAL, ParquetType::INT32)})},
+ LogicalType::List());
+ SetParquetSchema(
+ GroupNode::Make("parent", Repetition::OPTIONAL,
+ {GroupNode::Make("list", Repetition::REPEATED, {inner_list})},
+ LogicalType::List()));
+
+ LevelVector def_levels = {0, 1, 2, 3, 4, 5, 5, 5};
+ LevelVector rep_levels = {0, 0, 0, 1, 1, 2, 0, 2};
+ std::vector<int32_t> values = {4, 5, 6};
+
+ auto type = List(List(int32()));
+ auto expected = ArrayFromJSON(type, "[null, [], [null, [], [null, 4]], [[5, 6]]]");
+ AssertReconstruct<ParquetType::INT32>(*expected, def_levels, rep_levels, values);
+}
+
+// TODO legacy-list-in-struct etc.?
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/schema.cc b/src/arrow/cpp/src/parquet/arrow/schema.cc
new file mode 100644
index 000000000..19f78a507
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/schema.cc
@@ -0,0 +1,1093 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/arrow/schema.h"
+
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "arrow/extension_type.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/api.h"
+#include "arrow/result_internal.h"
+#include "arrow/type.h"
+#include "arrow/util/base64.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/value_parsing.h"
+
+#include "parquet/arrow/schema_internal.h"
+#include "parquet/exception.h"
+#include "parquet/metadata.h"
+#include "parquet/properties.h"
+#include "parquet/types.h"
+
+using arrow::DecimalType;
+using arrow::Field;
+using arrow::FieldVector;
+using arrow::KeyValueMetadata;
+using arrow::Status;
+using arrow::internal::checked_cast;
+
+using ArrowType = arrow::DataType;
+using ArrowTypeId = arrow::Type;
+
+using parquet::Repetition;
+using parquet::schema::GroupNode;
+using parquet::schema::Node;
+using parquet::schema::NodePtr;
+using parquet::schema::PrimitiveNode;
+
+using ParquetType = parquet::Type;
+using parquet::ConvertedType;
+using parquet::LogicalType;
+
+using parquet::internal::LevelInfo;
+
+namespace parquet {
+
+namespace arrow {
+
+// ----------------------------------------------------------------------
+// Parquet to Arrow schema conversion
+
+namespace {
+
+Repetition::type RepetitionFromNullable(bool is_nullable) {
+ return is_nullable ? Repetition::OPTIONAL : Repetition::REQUIRED;
+}
+
+Status FieldToNode(const std::string& name, const std::shared_ptr<Field>& field,
+ const WriterProperties& properties,
+ const ArrowWriterProperties& arrow_properties, NodePtr* out);
+
+Status ListToNode(const std::shared_ptr<::arrow::BaseListType>& type,
+ const std::string& name, bool nullable,
+ const WriterProperties& properties,
+ const ArrowWriterProperties& arrow_properties, NodePtr* out) {
+ NodePtr element;
+ std::string value_name =
+ arrow_properties.compliant_nested_types() ? "element" : type->value_field()->name();
+ RETURN_NOT_OK(FieldToNode(value_name, type->value_field(), properties, arrow_properties,
+ &element));
+
+ NodePtr list = GroupNode::Make("list", Repetition::REPEATED, {element});
+ *out = GroupNode::Make(name, RepetitionFromNullable(nullable), {list},
+ LogicalType::List());
+ return Status::OK();
+}
+
+Status MapToNode(const std::shared_ptr<::arrow::MapType>& type, const std::string& name,
+ bool nullable, const WriterProperties& properties,
+ const ArrowWriterProperties& arrow_properties, NodePtr* out) {
+ // TODO: Should we offer a non-compliant mode that forwards the type names?
+ NodePtr key_node;
+ RETURN_NOT_OK(
+ FieldToNode("key", type->key_field(), properties, arrow_properties, &key_node));
+
+ NodePtr value_node;
+ RETURN_NOT_OK(FieldToNode("value", type->item_field(), properties, arrow_properties,
+ &value_node));
+
+ NodePtr key_value =
+ GroupNode::Make("key_value", Repetition::REPEATED, {key_node, value_node});
+ *out = GroupNode::Make(name, RepetitionFromNullable(nullable), {key_value},
+ LogicalType::Map());
+ return Status::OK();
+}
+
+Status StructToNode(const std::shared_ptr<::arrow::StructType>& type,
+ const std::string& name, bool nullable,
+ const WriterProperties& properties,
+ const ArrowWriterProperties& arrow_properties, NodePtr* out) {
+ std::vector<NodePtr> children(type->num_fields());
+ if (type->num_fields() != 0) {
+ for (int i = 0; i < type->num_fields(); i++) {
+ RETURN_NOT_OK(FieldToNode(type->field(i)->name(), type->field(i), properties,
+ arrow_properties, &children[i]));
+ }
+ } else {
+ // XXX (ARROW-10928) We could add a dummy primitive node but that would
+ // require special handling when writing and reading, to avoid column index
+ // mismatches.
+ return Status::NotImplemented("Cannot write struct type '", name,
+ "' with no child field to Parquet. "
+ "Consider adding a dummy child field.");
+ }
+
+ *out = GroupNode::Make(name, RepetitionFromNullable(nullable), std::move(children));
+ return Status::OK();
+}
+
+static std::shared_ptr<const LogicalType> TimestampLogicalTypeFromArrowTimestamp(
+ const ::arrow::TimestampType& timestamp_type, ::arrow::TimeUnit::type time_unit) {
+ const bool utc = !(timestamp_type.timezone().empty());
+ // ARROW-5878(wesm): for forward compatibility reasons, and because
+ // there's no other way to signal to old readers that values are
+ // timestamps, we force the ConvertedType field to be set to the
+ // corresponding TIMESTAMP_* value. This does cause some ambiguity
+ // as Parquet readers have not been consistent about the
+ // interpretation of TIMESTAMP_* values as being UTC-normalized.
+ switch (time_unit) {
+ case ::arrow::TimeUnit::MILLI:
+ return LogicalType::Timestamp(utc, LogicalType::TimeUnit::MILLIS,
+ /*is_from_converted_type=*/false,
+ /*force_set_converted_type=*/true);
+ case ::arrow::TimeUnit::MICRO:
+ return LogicalType::Timestamp(utc, LogicalType::TimeUnit::MICROS,
+ /*is_from_converted_type=*/false,
+ /*force_set_converted_type=*/true);
+ case ::arrow::TimeUnit::NANO:
+ return LogicalType::Timestamp(utc, LogicalType::TimeUnit::NANOS);
+ case ::arrow::TimeUnit::SECOND:
+ // No equivalent parquet logical type.
+ break;
+ }
+ return LogicalType::None();
+}
+
+static Status GetTimestampMetadata(const ::arrow::TimestampType& type,
+ const WriterProperties& properties,
+ const ArrowWriterProperties& arrow_properties,
+ ParquetType::type* physical_type,
+ std::shared_ptr<const LogicalType>* logical_type) {
+ const bool coerce = arrow_properties.coerce_timestamps_enabled();
+ const auto target_unit =
+ coerce ? arrow_properties.coerce_timestamps_unit() : type.unit();
+ const auto version = properties.version();
+
+ // The user is explicitly asking for Impala int96 encoding, there is no
+ // logical type.
+ if (arrow_properties.support_deprecated_int96_timestamps()) {
+ *physical_type = ParquetType::INT96;
+ return Status::OK();
+ }
+
+ *physical_type = ParquetType::INT64;
+ *logical_type = TimestampLogicalTypeFromArrowTimestamp(type, target_unit);
+
+ // The user is explicitly asking for timestamp data to be converted to the
+ // specified units (target_unit).
+ if (coerce) {
+ if (version == ::parquet::ParquetVersion::PARQUET_1_0 ||
+ version == ::parquet::ParquetVersion::PARQUET_2_4) {
+ switch (target_unit) {
+ case ::arrow::TimeUnit::MILLI:
+ case ::arrow::TimeUnit::MICRO:
+ break;
+ case ::arrow::TimeUnit::NANO:
+ case ::arrow::TimeUnit::SECOND:
+ return Status::NotImplemented("For Parquet version ",
+ ::parquet::ParquetVersionToString(version),
+ ", can only coerce Arrow timestamps to "
+ "milliseconds or microseconds");
+ }
+ } else {
+ switch (target_unit) {
+ case ::arrow::TimeUnit::MILLI:
+ case ::arrow::TimeUnit::MICRO:
+ case ::arrow::TimeUnit::NANO:
+ break;
+ case ::arrow::TimeUnit::SECOND:
+ return Status::NotImplemented("For Parquet version ",
+ ::parquet::ParquetVersionToString(version),
+ ", can only coerce Arrow timestamps to "
+ "milliseconds, microseconds, or nanoseconds");
+ }
+ }
+ return Status::OK();
+ }
+
+ // The user implicitly wants timestamp data to retain its original time units,
+ // however the ConvertedType field used to indicate logical types for Parquet
+ // version <= 2.4 fields does not allow for nanosecond time units and so nanoseconds
+ // must be coerced to microseconds.
+ if ((version == ::parquet::ParquetVersion::PARQUET_1_0 ||
+ version == ::parquet::ParquetVersion::PARQUET_2_4) &&
+ type.unit() == ::arrow::TimeUnit::NANO) {
+ *logical_type =
+ TimestampLogicalTypeFromArrowTimestamp(type, ::arrow::TimeUnit::MICRO);
+ return Status::OK();
+ }
+
+ // The user implicitly wants timestamp data to retain its original time units,
+ // however the Arrow seconds time unit can not be represented (annotated) in
+ // any version of Parquet and so must be coerced to milliseconds.
+ if (type.unit() == ::arrow::TimeUnit::SECOND) {
+ *logical_type =
+ TimestampLogicalTypeFromArrowTimestamp(type, ::arrow::TimeUnit::MILLI);
+ return Status::OK();
+ }
+
+ return Status::OK();
+}
+
+static constexpr char FIELD_ID_KEY[] = "PARQUET:field_id";
+
+std::shared_ptr<::arrow::KeyValueMetadata> FieldIdMetadata(int field_id) {
+ if (field_id >= 0) {
+ return ::arrow::key_value_metadata({FIELD_ID_KEY}, {std::to_string(field_id)});
+ } else {
+ return nullptr;
+ }
+}
+
+int FieldIdFromMetadata(
+ const std::shared_ptr<const ::arrow::KeyValueMetadata>& metadata) {
+ if (!metadata) {
+ return -1;
+ }
+ int key = metadata->FindKey(FIELD_ID_KEY);
+ if (key < 0) {
+ return -1;
+ }
+ std::string field_id_str = metadata->value(key);
+ int field_id;
+ if (::arrow::internal::ParseValue<::arrow::Int32Type>(
+ field_id_str.c_str(), field_id_str.length(), &field_id)) {
+ if (field_id < 0) {
+ // Thrift should convert any negative value to null but normalize to -1 here in case
+ // we later check this in logic.
+ return -1;
+ }
+ return field_id;
+ } else {
+ return -1;
+ }
+}
+
+Status FieldToNode(const std::string& name, const std::shared_ptr<Field>& field,
+ const WriterProperties& properties,
+ const ArrowWriterProperties& arrow_properties, NodePtr* out) {
+ std::shared_ptr<const LogicalType> logical_type = LogicalType::None();
+ ParquetType::type type;
+ Repetition::type repetition = RepetitionFromNullable(field->nullable());
+
+ int length = -1;
+ int precision = -1;
+ int scale = -1;
+
+ switch (field->type()->id()) {
+ case ArrowTypeId::NA: {
+ type = ParquetType::INT32;
+ logical_type = LogicalType::Null();
+ if (repetition != Repetition::OPTIONAL) {
+ return Status::Invalid("NullType Arrow field must be nullable");
+ }
+ } break;
+ case ArrowTypeId::BOOL:
+ type = ParquetType::BOOLEAN;
+ break;
+ case ArrowTypeId::UINT8:
+ type = ParquetType::INT32;
+ logical_type = LogicalType::Int(8, false);
+ break;
+ case ArrowTypeId::INT8:
+ type = ParquetType::INT32;
+ logical_type = LogicalType::Int(8, true);
+ break;
+ case ArrowTypeId::UINT16:
+ type = ParquetType::INT32;
+ logical_type = LogicalType::Int(16, false);
+ break;
+ case ArrowTypeId::INT16:
+ type = ParquetType::INT32;
+ logical_type = LogicalType::Int(16, true);
+ break;
+ case ArrowTypeId::UINT32:
+ if (properties.version() == ::parquet::ParquetVersion::PARQUET_1_0) {
+ type = ParquetType::INT64;
+ } else {
+ type = ParquetType::INT32;
+ logical_type = LogicalType::Int(32, false);
+ }
+ break;
+ case ArrowTypeId::INT32:
+ type = ParquetType::INT32;
+ break;
+ case ArrowTypeId::UINT64:
+ type = ParquetType::INT64;
+ logical_type = LogicalType::Int(64, false);
+ break;
+ case ArrowTypeId::INT64:
+ type = ParquetType::INT64;
+ break;
+ case ArrowTypeId::FLOAT:
+ type = ParquetType::FLOAT;
+ break;
+ case ArrowTypeId::DOUBLE:
+ type = ParquetType::DOUBLE;
+ break;
+ case ArrowTypeId::LARGE_STRING:
+ case ArrowTypeId::STRING:
+ type = ParquetType::BYTE_ARRAY;
+ logical_type = LogicalType::String();
+ break;
+ case ArrowTypeId::LARGE_BINARY:
+ case ArrowTypeId::BINARY:
+ type = ParquetType::BYTE_ARRAY;
+ break;
+ case ArrowTypeId::FIXED_SIZE_BINARY: {
+ type = ParquetType::FIXED_LEN_BYTE_ARRAY;
+ const auto& fixed_size_binary_type =
+ static_cast<const ::arrow::FixedSizeBinaryType&>(*field->type());
+ length = fixed_size_binary_type.byte_width();
+ } break;
+ case ArrowTypeId::DECIMAL128:
+ case ArrowTypeId::DECIMAL256: {
+ type = ParquetType::FIXED_LEN_BYTE_ARRAY;
+ const auto& decimal_type = static_cast<const ::arrow::DecimalType&>(*field->type());
+ precision = decimal_type.precision();
+ scale = decimal_type.scale();
+ length = DecimalType::DecimalSize(precision);
+ PARQUET_CATCH_NOT_OK(logical_type = LogicalType::Decimal(precision, scale));
+ } break;
+ case ArrowTypeId::DATE32:
+ type = ParquetType::INT32;
+ logical_type = LogicalType::Date();
+ break;
+ case ArrowTypeId::DATE64:
+ type = ParquetType::INT32;
+ logical_type = LogicalType::Date();
+ break;
+ case ArrowTypeId::TIMESTAMP:
+ RETURN_NOT_OK(
+ GetTimestampMetadata(static_cast<::arrow::TimestampType&>(*field->type()),
+ properties, arrow_properties, &type, &logical_type));
+ break;
+ case ArrowTypeId::TIME32:
+ type = ParquetType::INT32;
+ logical_type =
+ LogicalType::Time(/*is_adjusted_to_utc=*/true, LogicalType::TimeUnit::MILLIS);
+ break;
+ case ArrowTypeId::TIME64: {
+ type = ParquetType::INT64;
+ auto time_type = static_cast<::arrow::Time64Type*>(field->type().get());
+ if (time_type->unit() == ::arrow::TimeUnit::NANO) {
+ logical_type =
+ LogicalType::Time(/*is_adjusted_to_utc=*/true, LogicalType::TimeUnit::NANOS);
+ } else {
+ logical_type =
+ LogicalType::Time(/*is_adjusted_to_utc=*/true, LogicalType::TimeUnit::MICROS);
+ }
+ } break;
+ case ArrowTypeId::STRUCT: {
+ auto struct_type = std::static_pointer_cast<::arrow::StructType>(field->type());
+ return StructToNode(struct_type, name, field->nullable(), properties,
+ arrow_properties, out);
+ }
+ case ArrowTypeId::FIXED_SIZE_LIST:
+ case ArrowTypeId::LARGE_LIST:
+ case ArrowTypeId::LIST: {
+ auto list_type = std::static_pointer_cast<::arrow::BaseListType>(field->type());
+ return ListToNode(list_type, name, field->nullable(), properties, arrow_properties,
+ out);
+ }
+ case ArrowTypeId::DICTIONARY: {
+ // Parquet has no Dictionary type, dictionary-encoded is handled on
+ // the encoding, not the schema level.
+ const ::arrow::DictionaryType& dict_type =
+ static_cast<const ::arrow::DictionaryType&>(*field->type());
+ std::shared_ptr<::arrow::Field> unpacked_field = ::arrow::field(
+ name, dict_type.value_type(), field->nullable(), field->metadata());
+ return FieldToNode(name, unpacked_field, properties, arrow_properties, out);
+ }
+ case ArrowTypeId::EXTENSION: {
+ auto ext_type = std::static_pointer_cast<::arrow::ExtensionType>(field->type());
+ std::shared_ptr<::arrow::Field> storage_field = ::arrow::field(
+ name, ext_type->storage_type(), field->nullable(), field->metadata());
+ return FieldToNode(name, storage_field, properties, arrow_properties, out);
+ }
+ case ArrowTypeId::MAP: {
+ auto map_type = std::static_pointer_cast<::arrow::MapType>(field->type());
+ return MapToNode(map_type, name, field->nullable(), properties, arrow_properties,
+ out);
+ }
+
+ default: {
+ // TODO: DENSE_UNION, SPARE_UNION, JSON_SCALAR, DECIMAL_TEXT, VARCHAR
+ return Status::NotImplemented(
+ "Unhandled type for Arrow to Parquet schema conversion: ",
+ field->type()->ToString());
+ }
+ }
+
+ int field_id = FieldIdFromMetadata(field->metadata());
+ PARQUET_CATCH_NOT_OK(*out = PrimitiveNode::Make(name, repetition, logical_type, type,
+ length, field_id));
+
+ return Status::OK();
+}
+
+struct SchemaTreeContext {
+ SchemaManifest* manifest;
+ ArrowReaderProperties properties;
+ const SchemaDescriptor* schema;
+
+ void LinkParent(const SchemaField* child, const SchemaField* parent) {
+ manifest->child_to_parent[child] = parent;
+ }
+
+ void RecordLeaf(const SchemaField* leaf) {
+ manifest->column_index_to_field[leaf->column_index] = leaf;
+ }
+};
+
+bool IsDictionaryReadSupported(const ArrowType& type) {
+ // Only supported currently for BYTE_ARRAY types
+ return type.id() == ::arrow::Type::BINARY || type.id() == ::arrow::Type::STRING;
+}
+
+// ----------------------------------------------------------------------
+// Schema logic
+
+::arrow::Result<std::shared_ptr<ArrowType>> GetTypeForNode(
+ int column_index, const schema::PrimitiveNode& primitive_node,
+ SchemaTreeContext* ctx) {
+ ASSIGN_OR_RAISE(
+ std::shared_ptr<ArrowType> storage_type,
+ GetArrowType(primitive_node, ctx->properties.coerce_int96_timestamp_unit()));
+ if (ctx->properties.read_dictionary(column_index) &&
+ IsDictionaryReadSupported(*storage_type)) {
+ return ::arrow::dictionary(::arrow::int32(), storage_type);
+ }
+ return storage_type;
+}
+
+Status NodeToSchemaField(const Node& node, LevelInfo current_levels,
+ SchemaTreeContext* ctx, const SchemaField* parent,
+ SchemaField* out);
+
+Status GroupToSchemaField(const GroupNode& node, LevelInfo current_levels,
+ SchemaTreeContext* ctx, const SchemaField* parent,
+ SchemaField* out);
+
+Status PopulateLeaf(int column_index, const std::shared_ptr<Field>& field,
+ LevelInfo current_levels, SchemaTreeContext* ctx,
+ const SchemaField* parent, SchemaField* out) {
+ out->field = field;
+ out->column_index = column_index;
+ out->level_info = current_levels;
+ ctx->RecordLeaf(out);
+ ctx->LinkParent(out, parent);
+ return Status::OK();
+}
+
+// Special case mentioned in the format spec:
+// If the name is array or ends in _tuple, this should be a list of struct
+// even for single child elements.
+bool HasStructListName(const GroupNode& node) {
+ ::arrow::util::string_view name{node.name()};
+ return name == "array" || name.ends_with("_tuple");
+}
+
+Status GroupToStruct(const GroupNode& node, LevelInfo current_levels,
+ SchemaTreeContext* ctx, const SchemaField* parent,
+ SchemaField* out) {
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+ out->children.resize(node.field_count());
+ // All level increments for the node are expected to happen by callers.
+ // This is required because repeated elements need to have there own
+ // SchemaField.
+
+ for (int i = 0; i < node.field_count(); i++) {
+ RETURN_NOT_OK(
+ NodeToSchemaField(*node.field(i), current_levels, ctx, out, &out->children[i]));
+ arrow_fields.push_back(out->children[i].field);
+ }
+ auto struct_type = ::arrow::struct_(arrow_fields);
+ out->field = ::arrow::field(node.name(), struct_type, node.is_optional(),
+ FieldIdMetadata(node.field_id()));
+ out->level_info = current_levels;
+ return Status::OK();
+}
+
+Status ListToSchemaField(const GroupNode& group, LevelInfo current_levels,
+ SchemaTreeContext* ctx, const SchemaField* parent,
+ SchemaField* out);
+
+Status MapToSchemaField(const GroupNode& group, LevelInfo current_levels,
+ SchemaTreeContext* ctx, const SchemaField* parent,
+ SchemaField* out) {
+ if (group.field_count() != 1) {
+ return Status::Invalid("MAP-annotated groups must have a single child.");
+ }
+ if (group.is_repeated()) {
+ return Status::Invalid("MAP-annotated groups must not be repeated.");
+ }
+
+ const Node& key_value_node = *group.field(0);
+
+ if (!key_value_node.is_repeated()) {
+ return Status::Invalid(
+ "Non-repeated key value in a MAP-annotated group are not supported.");
+ }
+
+ if (!key_value_node.is_group()) {
+ return Status::Invalid("Key-value node must be a group.");
+ }
+
+ const GroupNode& key_value = checked_cast<const GroupNode&>(key_value_node);
+ if (key_value.field_count() != 1 && key_value.field_count() != 2) {
+ return Status::Invalid("Key-value map node must have 1 or 2 child elements. Found: ",
+ key_value.field_count());
+ }
+ const Node& key_node = *key_value.field(0);
+ if (!key_node.is_required()) {
+ return Status::Invalid("Map keys must be annotated as required.");
+ }
+ // Arrow doesn't support 1 column maps (i.e. Sets). The options are to either
+ // make the values column nullable, or process the map as a list. We choose the latter
+ // as it is simpler.
+ if (key_value.field_count() == 1) {
+ return ListToSchemaField(group, current_levels, ctx, parent, out);
+ }
+
+ current_levels.Increment(group);
+ int16_t repeated_ancestor_def_level = current_levels.IncrementRepeated();
+
+ out->children.resize(1);
+ SchemaField* key_value_field = &out->children[0];
+
+ key_value_field->children.resize(2);
+ SchemaField* key_field = &key_value_field->children[0];
+ SchemaField* value_field = &key_value_field->children[1];
+
+ ctx->LinkParent(out, parent);
+ ctx->LinkParent(key_value_field, out);
+ ctx->LinkParent(key_field, key_value_field);
+ ctx->LinkParent(value_field, key_value_field);
+
+ // required/optional group name=whatever {
+ // repeated group name=key_values{
+ // required TYPE key;
+ // required/optional TYPE value;
+ // }
+ // }
+ //
+
+ RETURN_NOT_OK(NodeToSchemaField(*key_value.field(0), current_levels, ctx,
+ key_value_field, key_field));
+ RETURN_NOT_OK(NodeToSchemaField(*key_value.field(1), current_levels, ctx,
+ key_value_field, value_field));
+
+ key_value_field->field = ::arrow::field(
+ group.name(), ::arrow::struct_({key_field->field, value_field->field}),
+ /*nullable=*/false, FieldIdMetadata(key_value.field_id()));
+ key_value_field->level_info = current_levels;
+
+ out->field = ::arrow::field(group.name(),
+ ::arrow::map(key_field->field->type(), value_field->field),
+ group.is_optional(), FieldIdMetadata(group.field_id()));
+ out->level_info = current_levels;
+ // At this point current levels contains the def level for this list,
+ // we need to reset to the prior parent.
+ out->level_info.repeated_ancestor_def_level = repeated_ancestor_def_level;
+ return Status::OK();
+}
+
+Status ListToSchemaField(const GroupNode& group, LevelInfo current_levels,
+ SchemaTreeContext* ctx, const SchemaField* parent,
+ SchemaField* out) {
+ if (group.field_count() != 1) {
+ return Status::Invalid("LIST-annotated groups must have a single child.");
+ }
+ if (group.is_repeated()) {
+ return Status::Invalid("LIST-annotated groups must not be repeated.");
+ }
+ current_levels.Increment(group);
+
+ out->children.resize(group.field_count());
+ SchemaField* child_field = &out->children[0];
+
+ ctx->LinkParent(out, parent);
+ ctx->LinkParent(child_field, out);
+
+ const Node& list_node = *group.field(0);
+
+ if (!list_node.is_repeated()) {
+ return Status::Invalid(
+ "Non-repeated nodes in a LIST-annotated group are not supported.");
+ }
+
+ int16_t repeated_ancestor_def_level = current_levels.IncrementRepeated();
+ if (list_node.is_group()) {
+ // Resolve 3-level encoding
+ //
+ // required/optional group name=whatever {
+ // repeated group name=list {
+ // required/optional TYPE item;
+ // }
+ // }
+ //
+ // yields list<item: TYPE ?nullable> ?nullable
+ //
+ // We distinguish the special case that we have
+ //
+ // required/optional group name=whatever {
+ // repeated group name=array or $SOMETHING_tuple {
+ // required/optional TYPE item;
+ // }
+ // }
+ //
+ // In this latter case, the inner type of the list should be a struct
+ // rather than a primitive value
+ //
+ // yields list<item: struct<item: TYPE ?nullable> not null> ?nullable
+ const auto& list_group = static_cast<const GroupNode&>(list_node);
+ // Special case mentioned in the format spec:
+ // If the name is array or ends in _tuple, this should be a list of struct
+ // even for single child elements.
+ if (list_group.field_count() == 1 && !HasStructListName(list_group)) {
+ // List of primitive type
+ RETURN_NOT_OK(
+ NodeToSchemaField(*list_group.field(0), current_levels, ctx, out, child_field));
+ } else {
+ RETURN_NOT_OK(GroupToStruct(list_group, current_levels, ctx, out, child_field));
+ }
+ } else {
+ // Two-level list encoding
+ //
+ // required/optional group LIST {
+ // repeated TYPE;
+ // }
+ const auto& primitive_node = static_cast<const PrimitiveNode&>(list_node);
+ int column_index = ctx->schema->GetColumnIndex(primitive_node);
+ ASSIGN_OR_RAISE(std::shared_ptr<ArrowType> type,
+ GetTypeForNode(column_index, primitive_node, ctx));
+ auto item_field = ::arrow::field(list_node.name(), type, /*nullable=*/false,
+ FieldIdMetadata(list_node.field_id()));
+ RETURN_NOT_OK(
+ PopulateLeaf(column_index, item_field, current_levels, ctx, out, child_field));
+ }
+ out->field = ::arrow::field(group.name(), ::arrow::list(child_field->field),
+ group.is_optional(), FieldIdMetadata(group.field_id()));
+ out->level_info = current_levels;
+ // At this point current levels contains the def level for this list,
+ // we need to reset to the prior parent.
+ out->level_info.repeated_ancestor_def_level = repeated_ancestor_def_level;
+ return Status::OK();
+}
+
+Status GroupToSchemaField(const GroupNode& node, LevelInfo current_levels,
+ SchemaTreeContext* ctx, const SchemaField* parent,
+ SchemaField* out) {
+ if (node.logical_type()->is_list()) {
+ return ListToSchemaField(node, current_levels, ctx, parent, out);
+ } else if (node.logical_type()->is_map()) {
+ return MapToSchemaField(node, current_levels, ctx, parent, out);
+ }
+ std::shared_ptr<ArrowType> type;
+ if (node.is_repeated()) {
+ // Simple repeated struct
+ //
+ // repeated group $NAME {
+ // r/o TYPE[0] f0
+ // r/o TYPE[1] f1
+ // }
+ out->children.resize(1);
+
+ int16_t repeated_ancestor_def_level = current_levels.IncrementRepeated();
+ RETURN_NOT_OK(GroupToStruct(node, current_levels, ctx, out, &out->children[0]));
+ out->field = ::arrow::field(node.name(), ::arrow::list(out->children[0].field),
+ /*nullable=*/false, FieldIdMetadata(node.field_id()));
+
+ ctx->LinkParent(&out->children[0], out);
+ out->level_info = current_levels;
+ // At this point current_levels contains this list as the def level, we need to
+ // use the previous ancenstor of thi slist.
+ out->level_info.repeated_ancestor_def_level = repeated_ancestor_def_level;
+ return Status::OK();
+ } else {
+ current_levels.Increment(node);
+ return GroupToStruct(node, current_levels, ctx, parent, out);
+ }
+}
+
+Status NodeToSchemaField(const Node& node, LevelInfo current_levels,
+ SchemaTreeContext* ctx, const SchemaField* parent,
+ SchemaField* out) {
+ // Workhorse function for converting a Parquet schema node to an Arrow
+ // type. Handles different conventions for nested data.
+
+ ctx->LinkParent(out, parent);
+
+ // Now, walk the schema and create a ColumnDescriptor for each leaf node
+ if (node.is_group()) {
+ // A nested field, but we don't know what kind yet
+ return GroupToSchemaField(static_cast<const GroupNode&>(node), current_levels, ctx,
+ parent, out);
+ } else {
+ // Either a normal flat primitive type, or a list type encoded with 1-level
+ // list encoding. Note that the 3-level encoding is the form recommended by
+ // the parquet specification, but technically we can have either
+ //
+ // required/optional $TYPE $FIELD_NAME
+ //
+ // or
+ //
+ // repeated $TYPE $FIELD_NAME
+ const auto& primitive_node = static_cast<const PrimitiveNode&>(node);
+ int column_index = ctx->schema->GetColumnIndex(primitive_node);
+ ASSIGN_OR_RAISE(std::shared_ptr<ArrowType> type,
+ GetTypeForNode(column_index, primitive_node, ctx));
+ if (node.is_repeated()) {
+ // One-level list encoding, e.g.
+ // a: repeated int32;
+ int16_t repeated_ancestor_def_level = current_levels.IncrementRepeated();
+ out->children.resize(1);
+ auto child_field = ::arrow::field(node.name(), type, /*nullable=*/false);
+ RETURN_NOT_OK(PopulateLeaf(column_index, child_field, current_levels, ctx, out,
+ &out->children[0]));
+
+ out->field = ::arrow::field(node.name(), ::arrow::list(child_field),
+ /*nullable=*/false, FieldIdMetadata(node.field_id()));
+ out->level_info = current_levels;
+ // At this point current_levels has consider this list the ancestor so restore
+ // the actual ancenstor.
+ out->level_info.repeated_ancestor_def_level = repeated_ancestor_def_level;
+ return Status::OK();
+ } else {
+ current_levels.Increment(node);
+ // A normal (required/optional) primitive node
+ return PopulateLeaf(column_index,
+ ::arrow::field(node.name(), type, node.is_optional(),
+ FieldIdMetadata(node.field_id())),
+ current_levels, ctx, parent, out);
+ }
+ }
+}
+
+// Get the original Arrow schema, as serialized in the Parquet metadata
+Status GetOriginSchema(const std::shared_ptr<const KeyValueMetadata>& metadata,
+ std::shared_ptr<const KeyValueMetadata>* clean_metadata,
+ std::shared_ptr<::arrow::Schema>* out) {
+ if (metadata == nullptr) {
+ *out = nullptr;
+ *clean_metadata = nullptr;
+ return Status::OK();
+ }
+
+ static const std::string kArrowSchemaKey = "ARROW:schema";
+ int schema_index = metadata->FindKey(kArrowSchemaKey);
+ if (schema_index == -1) {
+ *out = nullptr;
+ *clean_metadata = metadata;
+ return Status::OK();
+ }
+
+ // The original Arrow schema was serialized using the store_schema option.
+ // We deserialize it here and use it to inform read options such as
+ // dictionary-encoded fields.
+ auto decoded = ::arrow::util::base64_decode(metadata->value(schema_index));
+ auto schema_buf = std::make_shared<Buffer>(decoded);
+
+ ::arrow::ipc::DictionaryMemo dict_memo;
+ ::arrow::io::BufferReader input(schema_buf);
+
+ ARROW_ASSIGN_OR_RAISE(*out, ::arrow::ipc::ReadSchema(&input, &dict_memo));
+
+ if (metadata->size() > 1) {
+ // Copy the metadata without the schema key
+ auto new_metadata = ::arrow::key_value_metadata({}, {});
+ new_metadata->reserve(metadata->size() - 1);
+ for (int64_t i = 0; i < metadata->size(); ++i) {
+ if (i == schema_index) continue;
+ new_metadata->Append(metadata->key(i), metadata->value(i));
+ }
+ *clean_metadata = new_metadata;
+ } else {
+ // No other keys, let metadata be null
+ *clean_metadata = nullptr;
+ }
+ return Status::OK();
+}
+
+// Restore original Arrow field information that was serialized as Parquet metadata
+// but that is not necessarily present in the field reconstitued from Parquet data
+// (for example, Parquet timestamp types doesn't carry timezone information).
+
+Result<bool> ApplyOriginalMetadata(const Field& origin_field, SchemaField* inferred);
+
+std::function<std::shared_ptr<::arrow::DataType>(FieldVector)> GetNestedFactory(
+ const ArrowType& origin_type, const ArrowType& inferred_type) {
+ switch (inferred_type.id()) {
+ case ::arrow::Type::STRUCT:
+ if (origin_type.id() == ::arrow::Type::STRUCT) {
+ return ::arrow::struct_;
+ }
+ break;
+ case ::arrow::Type::LIST:
+ if (origin_type.id() == ::arrow::Type::LIST) {
+ return [](FieldVector fields) {
+ DCHECK_EQ(fields.size(), 1);
+ return ::arrow::list(std::move(fields[0]));
+ };
+ }
+ if (origin_type.id() == ::arrow::Type::LARGE_LIST) {
+ return [](FieldVector fields) {
+ DCHECK_EQ(fields.size(), 1);
+ return ::arrow::large_list(std::move(fields[0]));
+ };
+ }
+ if (origin_type.id() == ::arrow::Type::FIXED_SIZE_LIST) {
+ const auto list_size =
+ checked_cast<const ::arrow::FixedSizeListType&>(origin_type).list_size();
+ return [list_size](FieldVector fields) {
+ DCHECK_EQ(fields.size(), 1);
+ return ::arrow::fixed_size_list(std::move(fields[0]), list_size);
+ };
+ }
+ break;
+ default:
+ break;
+ }
+ return {};
+}
+
+Result<bool> ApplyOriginalStorageMetadata(const Field& origin_field,
+ SchemaField* inferred) {
+ bool modified = false;
+
+ auto origin_type = origin_field.type();
+ auto inferred_type = inferred->field->type();
+
+ const int num_children = inferred_type->num_fields();
+
+ if (num_children > 0 && origin_type->num_fields() == num_children) {
+ DCHECK_EQ(static_cast<int>(inferred->children.size()), num_children);
+ const auto factory = GetNestedFactory(*origin_type, *inferred_type);
+ if (factory) {
+ // The type may be modified (e.g. LargeList) while the children stay the same
+ modified |= origin_type->id() != inferred_type->id();
+
+ // Apply original metadata recursively to children
+ for (int i = 0; i < inferred_type->num_fields(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(
+ const bool child_modified,
+ ApplyOriginalMetadata(*origin_type->field(i), &inferred->children[i]));
+ modified |= child_modified;
+ }
+ if (modified) {
+ // Recreate this field using the modified child fields
+ ::arrow::FieldVector modified_children(inferred_type->num_fields());
+ for (int i = 0; i < inferred_type->num_fields(); ++i) {
+ modified_children[i] = inferred->children[i].field;
+ }
+ inferred->field =
+ inferred->field->WithType(factory(std::move(modified_children)));
+ }
+ }
+ }
+
+ if (origin_type->id() == ::arrow::Type::TIMESTAMP &&
+ inferred_type->id() == ::arrow::Type::TIMESTAMP) {
+ // Restore time zone, if any
+ const auto& ts_type = checked_cast<const ::arrow::TimestampType&>(*inferred_type);
+ const auto& ts_origin_type =
+ checked_cast<const ::arrow::TimestampType&>(*origin_type);
+
+ // If the data is tz-aware, then set the original time zone, since Parquet
+ // has no native storage for timezones
+ if (ts_type.timezone() == "UTC" && ts_origin_type.timezone() != "") {
+ if (ts_type.unit() == ts_origin_type.unit()) {
+ inferred->field = inferred->field->WithType(origin_type);
+ } else {
+ auto ts_type_new = ::arrow::timestamp(ts_type.unit(), ts_origin_type.timezone());
+ inferred->field = inferred->field->WithType(ts_type_new);
+ }
+ }
+ modified = true;
+ }
+
+ if (origin_type->id() == ::arrow::Type::DICTIONARY &&
+ inferred_type->id() != ::arrow::Type::DICTIONARY &&
+ IsDictionaryReadSupported(*inferred_type)) {
+ // Direct dictionary reads are only suppored for a couple primitive types,
+ // so no need to recurse on value types.
+ const auto& dict_origin_type =
+ checked_cast<const ::arrow::DictionaryType&>(*origin_type);
+ inferred->field = inferred->field->WithType(
+ ::arrow::dictionary(::arrow::int32(), inferred_type, dict_origin_type.ordered()));
+ modified = true;
+ }
+
+ if ((origin_type->id() == ::arrow::Type::LARGE_BINARY &&
+ inferred_type->id() == ::arrow::Type::BINARY) ||
+ (origin_type->id() == ::arrow::Type::LARGE_STRING &&
+ inferred_type->id() == ::arrow::Type::STRING)) {
+ // Read back binary-like arrays with the intended offset width.
+ inferred->field = inferred->field->WithType(origin_type);
+ modified = true;
+ }
+
+ if (origin_type->id() == ::arrow::Type::DECIMAL256 &&
+ inferred_type->id() == ::arrow::Type::DECIMAL128) {
+ inferred->field = inferred->field->WithType(origin_type);
+ modified = true;
+ }
+
+ // Restore field metadata
+ std::shared_ptr<const KeyValueMetadata> field_metadata = origin_field.metadata();
+ if (field_metadata != nullptr) {
+ if (inferred->field->metadata()) {
+ // Prefer the metadata keys (like field_id) from the current metadata
+ field_metadata = field_metadata->Merge(*inferred->field->metadata());
+ }
+ inferred->field = inferred->field->WithMetadata(field_metadata);
+ modified = true;
+ }
+
+ return modified;
+}
+
+Result<bool> ApplyOriginalMetadata(const Field& origin_field, SchemaField* inferred) {
+ bool modified = false;
+
+ auto origin_type = origin_field.type();
+ auto inferred_type = inferred->field->type();
+
+ if (origin_type->id() == ::arrow::Type::EXTENSION) {
+ const auto& ex_type = checked_cast<const ::arrow::ExtensionType&>(*origin_type);
+ auto origin_storage_field = origin_field.WithType(ex_type.storage_type());
+
+ // Apply metadata recursively to storage type
+ RETURN_NOT_OK(ApplyOriginalStorageMetadata(*origin_storage_field, inferred));
+
+ // Restore extension type, if the storage type is the same as inferred
+ // from the Parquet type
+ if (ex_type.storage_type()->Equals(*inferred->field->type())) {
+ inferred->field = inferred->field->WithType(origin_type);
+ }
+ modified = true;
+ } else {
+ ARROW_ASSIGN_OR_RAISE(modified, ApplyOriginalStorageMetadata(origin_field, inferred));
+ }
+
+ return modified;
+}
+
+} // namespace
+
+Status FieldToNode(const std::shared_ptr<Field>& field,
+ const WriterProperties& properties,
+ const ArrowWriterProperties& arrow_properties, NodePtr* out) {
+ return FieldToNode(field->name(), field, properties, arrow_properties, out);
+}
+
+Status ToParquetSchema(const ::arrow::Schema* arrow_schema,
+ const WriterProperties& properties,
+ const ArrowWriterProperties& arrow_properties,
+ std::shared_ptr<SchemaDescriptor>* out) {
+ std::vector<NodePtr> nodes(arrow_schema->num_fields());
+ for (int i = 0; i < arrow_schema->num_fields(); i++) {
+ RETURN_NOT_OK(
+ FieldToNode(arrow_schema->field(i), properties, arrow_properties, &nodes[i]));
+ }
+
+ NodePtr schema = GroupNode::Make("schema", Repetition::REQUIRED, nodes);
+ *out = std::make_shared<::parquet::SchemaDescriptor>();
+ PARQUET_CATCH_NOT_OK((*out)->Init(schema));
+
+ return Status::OK();
+}
+
+Status ToParquetSchema(const ::arrow::Schema* arrow_schema,
+ const WriterProperties& properties,
+ std::shared_ptr<SchemaDescriptor>* out) {
+ return ToParquetSchema(arrow_schema, properties, *default_arrow_writer_properties(),
+ out);
+}
+
+Status FromParquetSchema(
+ const SchemaDescriptor* schema, const ArrowReaderProperties& properties,
+ const std::shared_ptr<const KeyValueMetadata>& key_value_metadata,
+ std::shared_ptr<::arrow::Schema>* out) {
+ SchemaManifest manifest;
+ RETURN_NOT_OK(SchemaManifest::Make(schema, key_value_metadata, properties, &manifest));
+ std::vector<std::shared_ptr<Field>> fields(manifest.schema_fields.size());
+
+ for (int i = 0; i < static_cast<int>(fields.size()); i++) {
+ const auto& schema_field = manifest.schema_fields[i];
+ fields[i] = schema_field.field;
+ }
+ if (manifest.origin_schema) {
+ // ARROW-8980: If the ARROW:schema was in the input metadata, then
+ // manifest.origin_schema will have it scrubbed out
+ *out = ::arrow::schema(fields, manifest.origin_schema->metadata());
+ } else {
+ *out = ::arrow::schema(fields, key_value_metadata);
+ }
+ return Status::OK();
+}
+
+Status FromParquetSchema(const SchemaDescriptor* parquet_schema,
+ const ArrowReaderProperties& properties,
+ std::shared_ptr<::arrow::Schema>* out) {
+ return FromParquetSchema(parquet_schema, properties, nullptr, out);
+}
+
+Status FromParquetSchema(const SchemaDescriptor* parquet_schema,
+ std::shared_ptr<::arrow::Schema>* out) {
+ ArrowReaderProperties properties;
+ return FromParquetSchema(parquet_schema, properties, nullptr, out);
+}
+
+Status SchemaManifest::Make(const SchemaDescriptor* schema,
+ const std::shared_ptr<const KeyValueMetadata>& metadata,
+ const ArrowReaderProperties& properties,
+ SchemaManifest* manifest) {
+ SchemaTreeContext ctx;
+ ctx.manifest = manifest;
+ ctx.properties = properties;
+ ctx.schema = schema;
+ const GroupNode& schema_node = *schema->group_node();
+ manifest->descr = schema;
+ manifest->schema_fields.resize(schema_node.field_count());
+
+ // Try to deserialize original Arrow schema
+ RETURN_NOT_OK(
+ GetOriginSchema(metadata, &manifest->schema_metadata, &manifest->origin_schema));
+ // Ignore original schema if it's not compatible with the Parquet schema
+ if (manifest->origin_schema != nullptr &&
+ manifest->origin_schema->num_fields() != schema_node.field_count()) {
+ manifest->origin_schema = nullptr;
+ }
+
+ for (int i = 0; i < static_cast<int>(schema_node.field_count()); ++i) {
+ SchemaField* out_field = &manifest->schema_fields[i];
+ RETURN_NOT_OK(NodeToSchemaField(*schema_node.field(i), LevelInfo(), &ctx,
+ /*parent=*/nullptr, out_field));
+
+ // TODO(wesm): as follow up to ARROW-3246, we should really pass the origin
+ // schema (if any) through all functions in the schema reconstruction, but
+ // I'm being lazy and just setting dictionary fields at the top level for
+ // now
+ if (manifest->origin_schema == nullptr) {
+ continue;
+ }
+
+ auto origin_field = manifest->origin_schema->field(i);
+ RETURN_NOT_OK(ApplyOriginalMetadata(*origin_field, out_field));
+ }
+ return Status::OK();
+}
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/schema.h b/src/arrow/cpp/src/parquet/arrow/schema.h
new file mode 100644
index 000000000..dd60fde43
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/schema.h
@@ -0,0 +1,184 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cassert>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+
+#include "parquet/level_conversion.h"
+#include "parquet/platform.h"
+#include "parquet/schema.h"
+
+namespace parquet {
+
+class ArrowReaderProperties;
+class ArrowWriterProperties;
+class WriterProperties;
+
+namespace arrow {
+
+/// \defgroup arrow-to-parquet-schema-conversion Functions to convert an Arrow
+/// schema into a Parquet schema.
+///
+/// @{
+
+PARQUET_EXPORT
+::arrow::Status FieldToNode(const std::shared_ptr<::arrow::Field>& field,
+ const WriterProperties& properties,
+ const ArrowWriterProperties& arrow_properties,
+ schema::NodePtr* out);
+
+PARQUET_EXPORT
+::arrow::Status ToParquetSchema(const ::arrow::Schema* arrow_schema,
+ const WriterProperties& properties,
+ const ArrowWriterProperties& arrow_properties,
+ std::shared_ptr<SchemaDescriptor>* out);
+
+PARQUET_EXPORT
+::arrow::Status ToParquetSchema(const ::arrow::Schema* arrow_schema,
+ const WriterProperties& properties,
+ std::shared_ptr<SchemaDescriptor>* out);
+
+/// @}
+
+/// \defgroup parquet-to-arrow-schema-conversion Functions to convert a Parquet
+/// schema into an Arrow schema.
+///
+/// @{
+
+PARQUET_EXPORT
+::arrow::Status FromParquetSchema(
+ const SchemaDescriptor* parquet_schema, const ArrowReaderProperties& properties,
+ const std::shared_ptr<const ::arrow::KeyValueMetadata>& key_value_metadata,
+ std::shared_ptr<::arrow::Schema>* out);
+
+PARQUET_EXPORT
+::arrow::Status FromParquetSchema(const SchemaDescriptor* parquet_schema,
+ const ArrowReaderProperties& properties,
+ std::shared_ptr<::arrow::Schema>* out);
+
+PARQUET_EXPORT
+::arrow::Status FromParquetSchema(const SchemaDescriptor* parquet_schema,
+ std::shared_ptr<::arrow::Schema>* out);
+
+/// @}
+
+/// \brief Bridge between an arrow::Field and parquet column indices.
+struct PARQUET_EXPORT SchemaField {
+ std::shared_ptr<::arrow::Field> field;
+ std::vector<SchemaField> children;
+
+ // Only set for leaf nodes
+ int column_index = -1;
+
+ parquet::internal::LevelInfo level_info;
+
+ bool is_leaf() const { return column_index != -1; }
+};
+
+/// \brief Bridge between a parquet Schema and an arrow Schema.
+///
+/// Expose parquet columns as a tree structure. Useful traverse and link
+/// between arrow's Schema and parquet's Schema.
+struct PARQUET_EXPORT SchemaManifest {
+ static ::arrow::Status Make(
+ const SchemaDescriptor* schema,
+ const std::shared_ptr<const ::arrow::KeyValueMetadata>& metadata,
+ const ArrowReaderProperties& properties, SchemaManifest* manifest);
+
+ const SchemaDescriptor* descr;
+ std::shared_ptr<::arrow::Schema> origin_schema;
+ std::shared_ptr<const ::arrow::KeyValueMetadata> schema_metadata;
+ std::vector<SchemaField> schema_fields;
+
+ std::unordered_map<int, const SchemaField*> column_index_to_field;
+ std::unordered_map<const SchemaField*, const SchemaField*> child_to_parent;
+
+ ::arrow::Status GetColumnField(int column_index, const SchemaField** out) const {
+ auto it = column_index_to_field.find(column_index);
+ if (it == column_index_to_field.end()) {
+ return ::arrow::Status::KeyError("Column index ", column_index,
+ " not found in schema manifest, may be malformed");
+ }
+ *out = it->second;
+ return ::arrow::Status::OK();
+ }
+
+ const SchemaField* GetParent(const SchemaField* field) const {
+ // Returns nullptr also if not found
+ auto it = child_to_parent.find(field);
+ if (it == child_to_parent.end()) {
+ return NULLPTR;
+ }
+ return it->second;
+ }
+
+ /// Coalesce a list of field indices (relative to the equivalent arrow::Schema) which
+ /// correspond to the column root (first node below the parquet schema's root group) of
+ /// each leaf referenced in column_indices.
+ ///
+ /// For example, for leaves `a.b.c`, `a.b.d.e`, and `i.j.k` (column_indices=[0,1,3])
+ /// the roots are `a` and `i` (return=[0,2]).
+ ///
+ /// root
+ /// -- a <------
+ /// -- -- b | |
+ /// -- -- -- c |
+ /// -- -- -- d |
+ /// -- -- -- -- e
+ /// -- f
+ /// -- -- g
+ /// -- -- -- h
+ /// -- i <---
+ /// -- -- j |
+ /// -- -- -- k
+ ::arrow::Result<std::vector<int>> GetFieldIndices(
+ const std::vector<int>& column_indices) const {
+ const schema::GroupNode* group = descr->group_node();
+ std::unordered_set<int> already_added;
+
+ std::vector<int> out;
+ for (int column_idx : column_indices) {
+ if (column_idx < 0 || column_idx >= descr->num_columns()) {
+ return ::arrow::Status::IndexError("Column index ", column_idx, " is not valid");
+ }
+
+ auto field_node = descr->GetColumnRoot(column_idx);
+ auto field_idx = group->FieldIndex(*field_node);
+ if (field_idx == -1) {
+ return ::arrow::Status::IndexError("Column index ", column_idx, " is not valid");
+ }
+
+ if (already_added.insert(field_idx).second) {
+ out.push_back(field_idx);
+ }
+ }
+ return out;
+ }
+};
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/schema_internal.cc b/src/arrow/cpp/src/parquet/arrow/schema_internal.cc
new file mode 100644
index 000000000..064bf4f55
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/schema_internal.cc
@@ -0,0 +1,222 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/arrow/schema_internal.h"
+
+#include "arrow/type.h"
+
+using ArrowType = ::arrow::DataType;
+using ArrowTypeId = ::arrow::Type;
+using ParquetType = parquet::Type;
+
+namespace parquet {
+
+namespace arrow {
+
+using ::arrow::Result;
+using ::arrow::Status;
+using ::arrow::internal::checked_cast;
+
+Result<std::shared_ptr<ArrowType>> MakeArrowDecimal(const LogicalType& logical_type) {
+ const auto& decimal = checked_cast<const DecimalLogicalType&>(logical_type);
+ if (decimal.precision() <= ::arrow::Decimal128Type::kMaxPrecision) {
+ return ::arrow::Decimal128Type::Make(decimal.precision(), decimal.scale());
+ }
+ return ::arrow::Decimal256Type::Make(decimal.precision(), decimal.scale());
+}
+
+Result<std::shared_ptr<ArrowType>> MakeArrowInt(const LogicalType& logical_type) {
+ const auto& integer = checked_cast<const IntLogicalType&>(logical_type);
+ switch (integer.bit_width()) {
+ case 8:
+ return integer.is_signed() ? ::arrow::int8() : ::arrow::uint8();
+ case 16:
+ return integer.is_signed() ? ::arrow::int16() : ::arrow::uint16();
+ case 32:
+ return integer.is_signed() ? ::arrow::int32() : ::arrow::uint32();
+ default:
+ return Status::TypeError(logical_type.ToString(),
+ " can not annotate physical type Int32");
+ }
+}
+
+Result<std::shared_ptr<ArrowType>> MakeArrowInt64(const LogicalType& logical_type) {
+ const auto& integer = checked_cast<const IntLogicalType&>(logical_type);
+ switch (integer.bit_width()) {
+ case 64:
+ return integer.is_signed() ? ::arrow::int64() : ::arrow::uint64();
+ default:
+ return Status::TypeError(logical_type.ToString(),
+ " can not annotate physical type Int64");
+ }
+}
+
+Result<std::shared_ptr<ArrowType>> MakeArrowTime32(const LogicalType& logical_type) {
+ const auto& time = checked_cast<const TimeLogicalType&>(logical_type);
+ switch (time.time_unit()) {
+ case LogicalType::TimeUnit::MILLIS:
+ return ::arrow::time32(::arrow::TimeUnit::MILLI);
+ default:
+ return Status::TypeError(logical_type.ToString(),
+ " can not annotate physical type Time32");
+ }
+}
+
+Result<std::shared_ptr<ArrowType>> MakeArrowTime64(const LogicalType& logical_type) {
+ const auto& time = checked_cast<const TimeLogicalType&>(logical_type);
+ switch (time.time_unit()) {
+ case LogicalType::TimeUnit::MICROS:
+ return ::arrow::time64(::arrow::TimeUnit::MICRO);
+ case LogicalType::TimeUnit::NANOS:
+ return ::arrow::time64(::arrow::TimeUnit::NANO);
+ default:
+ return Status::TypeError(logical_type.ToString(),
+ " can not annotate physical type Time64");
+ }
+}
+
+Result<std::shared_ptr<ArrowType>> MakeArrowTimestamp(const LogicalType& logical_type) {
+ const auto& timestamp = checked_cast<const TimestampLogicalType&>(logical_type);
+ const bool utc_normalized =
+ timestamp.is_from_converted_type() ? false : timestamp.is_adjusted_to_utc();
+ static const char* utc_timezone = "UTC";
+ switch (timestamp.time_unit()) {
+ case LogicalType::TimeUnit::MILLIS:
+ return (utc_normalized ? ::arrow::timestamp(::arrow::TimeUnit::MILLI, utc_timezone)
+ : ::arrow::timestamp(::arrow::TimeUnit::MILLI));
+ case LogicalType::TimeUnit::MICROS:
+ return (utc_normalized ? ::arrow::timestamp(::arrow::TimeUnit::MICRO, utc_timezone)
+ : ::arrow::timestamp(::arrow::TimeUnit::MICRO));
+ case LogicalType::TimeUnit::NANOS:
+ return (utc_normalized ? ::arrow::timestamp(::arrow::TimeUnit::NANO, utc_timezone)
+ : ::arrow::timestamp(::arrow::TimeUnit::NANO));
+ default:
+ return Status::TypeError("Unrecognized time unit in timestamp logical_type: ",
+ logical_type.ToString());
+ }
+}
+
+Result<std::shared_ptr<ArrowType>> FromByteArray(const LogicalType& logical_type) {
+ switch (logical_type.type()) {
+ case LogicalType::Type::STRING:
+ return ::arrow::utf8();
+ case LogicalType::Type::DECIMAL:
+ return MakeArrowDecimal(logical_type);
+ case LogicalType::Type::NONE:
+ case LogicalType::Type::ENUM:
+ case LogicalType::Type::JSON:
+ case LogicalType::Type::BSON:
+ return ::arrow::binary();
+ default:
+ return Status::NotImplemented("Unhandled logical logical_type ",
+ logical_type.ToString(), " for binary array");
+ }
+}
+
+Result<std::shared_ptr<ArrowType>> FromFLBA(const LogicalType& logical_type,
+ int32_t physical_length) {
+ switch (logical_type.type()) {
+ case LogicalType::Type::DECIMAL:
+ return MakeArrowDecimal(logical_type);
+ case LogicalType::Type::NONE:
+ case LogicalType::Type::INTERVAL:
+ case LogicalType::Type::UUID:
+ return ::arrow::fixed_size_binary(physical_length);
+ default:
+ return Status::NotImplemented("Unhandled logical logical_type ",
+ logical_type.ToString(),
+ " for fixed-length binary array");
+ }
+}
+
+::arrow::Result<std::shared_ptr<ArrowType>> FromInt32(const LogicalType& logical_type) {
+ switch (logical_type.type()) {
+ case LogicalType::Type::INT:
+ return MakeArrowInt(logical_type);
+ case LogicalType::Type::DATE:
+ return ::arrow::date32();
+ case LogicalType::Type::TIME:
+ return MakeArrowTime32(logical_type);
+ case LogicalType::Type::DECIMAL:
+ return MakeArrowDecimal(logical_type);
+ case LogicalType::Type::NONE:
+ return ::arrow::int32();
+ default:
+ return Status::NotImplemented("Unhandled logical type ", logical_type.ToString(),
+ " for INT32");
+ }
+}
+
+Result<std::shared_ptr<ArrowType>> FromInt64(const LogicalType& logical_type) {
+ switch (logical_type.type()) {
+ case LogicalType::Type::INT:
+ return MakeArrowInt64(logical_type);
+ case LogicalType::Type::DECIMAL:
+ return MakeArrowDecimal(logical_type);
+ case LogicalType::Type::TIMESTAMP:
+ return MakeArrowTimestamp(logical_type);
+ case LogicalType::Type::TIME:
+ return MakeArrowTime64(logical_type);
+ case LogicalType::Type::NONE:
+ return ::arrow::int64();
+ default:
+ return Status::NotImplemented("Unhandled logical type ", logical_type.ToString(),
+ " for INT64");
+ }
+}
+
+Result<std::shared_ptr<ArrowType>> GetArrowType(
+ Type::type physical_type, const LogicalType& logical_type, int type_length,
+ const ::arrow::TimeUnit::type int96_arrow_time_unit) {
+ if (logical_type.is_invalid() || logical_type.is_null()) {
+ return ::arrow::null();
+ }
+
+ switch (physical_type) {
+ case ParquetType::BOOLEAN:
+ return ::arrow::boolean();
+ case ParquetType::INT32:
+ return FromInt32(logical_type);
+ case ParquetType::INT64:
+ return FromInt64(logical_type);
+ case ParquetType::INT96:
+ return ::arrow::timestamp(int96_arrow_time_unit);
+ case ParquetType::FLOAT:
+ return ::arrow::float32();
+ case ParquetType::DOUBLE:
+ return ::arrow::float64();
+ case ParquetType::BYTE_ARRAY:
+ return FromByteArray(logical_type);
+ case ParquetType::FIXED_LEN_BYTE_ARRAY:
+ return FromFLBA(logical_type, type_length);
+ default: {
+ // PARQUET-1565: This can occur if the file is corrupt
+ return Status::IOError("Invalid physical column type: ",
+ TypeToString(physical_type));
+ }
+ }
+}
+
+Result<std::shared_ptr<ArrowType>> GetArrowType(
+ const schema::PrimitiveNode& primitive,
+ const ::arrow::TimeUnit::type int96_arrow_time_unit) {
+ return GetArrowType(primitive.physical_type(), *primitive.logical_type(),
+ primitive.type_length(), int96_arrow_time_unit);
+}
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/schema_internal.h b/src/arrow/cpp/src/parquet/arrow/schema_internal.h
new file mode 100644
index 000000000..fb837c3ee
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/schema_internal.h
@@ -0,0 +1,51 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/result.h"
+#include "parquet/schema.h"
+
+namespace arrow {
+class DataType;
+}
+
+namespace parquet {
+namespace arrow {
+
+using ::arrow::Result;
+
+Result<std::shared_ptr<::arrow::DataType>> FromByteArray(const LogicalType& logical_type);
+Result<std::shared_ptr<::arrow::DataType>> FromFLBA(const LogicalType& logical_type,
+ int32_t physical_length);
+Result<std::shared_ptr<::arrow::DataType>> FromInt32(const LogicalType& logical_type);
+Result<std::shared_ptr<::arrow::DataType>> FromInt64(const LogicalType& logical_type);
+
+Result<std::shared_ptr<::arrow::DataType>> GetArrowType(Type::type physical_type,
+ const LogicalType& logical_type,
+ int type_length);
+
+Result<std::shared_ptr<::arrow::DataType>> GetArrowType(
+ Type::type physical_type, const LogicalType& logical_type, int type_length,
+ ::arrow::TimeUnit::type int96_arrow_time_unit = ::arrow::TimeUnit::NANO);
+
+Result<std::shared_ptr<::arrow::DataType>> GetArrowType(
+ const schema::PrimitiveNode& primitive,
+ ::arrow::TimeUnit::type int96_arrow_time_unit = ::arrow::TimeUnit::NANO);
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/test_util.h b/src/arrow/cpp/src/parquet/arrow/test_util.h
new file mode 100644
index 000000000..fb1d39876
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/test_util.h
@@ -0,0 +1,512 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <limits>
+#include <memory>
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_decimal.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/type_fwd.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/decimal.h"
+#include "parquet/column_reader.h"
+
+namespace parquet {
+
+using internal::RecordReader;
+
+namespace arrow {
+
+using ::arrow::Array;
+using ::arrow::ChunkedArray;
+using ::arrow::Status;
+
+template <int32_t PRECISION>
+struct DecimalWithPrecisionAndScale {
+ static_assert(PRECISION >= 1 && PRECISION <= 38, "Invalid precision value");
+
+ using type = ::arrow::Decimal128Type;
+ static constexpr ::arrow::Type::type type_id = ::arrow::Decimal128Type::type_id;
+ static constexpr int32_t precision = PRECISION;
+ static constexpr int32_t scale = PRECISION - 1;
+};
+
+template <int32_t PRECISION>
+struct Decimal256WithPrecisionAndScale {
+ static_assert(PRECISION >= 1 && PRECISION <= 76, "Invalid precision value");
+
+ using type = ::arrow::Decimal256Type;
+ static constexpr ::arrow::Type::type type_id = ::arrow::Decimal256Type::type_id;
+ static constexpr int32_t precision = PRECISION;
+ static constexpr int32_t scale = PRECISION - 1;
+};
+
+template <class ArrowType>
+::arrow::enable_if_floating_point<ArrowType, Status> NonNullArray(
+ size_t size, std::shared_ptr<Array>* out) {
+ using c_type = typename ArrowType::c_type;
+ std::vector<c_type> values;
+ ::arrow::random_real(size, 0, static_cast<c_type>(0), static_cast<c_type>(1), &values);
+ ::arrow::NumericBuilder<ArrowType> builder;
+ RETURN_NOT_OK(builder.AppendValues(values.data(), values.size()));
+ return builder.Finish(out);
+}
+
+template <class ArrowType>
+::arrow::enable_if_integer<ArrowType, Status> NonNullArray(size_t size,
+ std::shared_ptr<Array>* out) {
+ std::vector<typename ArrowType::c_type> values;
+ ::arrow::randint(size, 0, 64, &values);
+
+ // Passing data type so this will work with TimestampType too
+ ::arrow::NumericBuilder<ArrowType> builder(std::make_shared<ArrowType>(),
+ ::arrow::default_memory_pool());
+ RETURN_NOT_OK(builder.AppendValues(values.data(), values.size()));
+ return builder.Finish(out);
+}
+
+template <class ArrowType>
+::arrow::enable_if_date<ArrowType, Status> NonNullArray(size_t size,
+ std::shared_ptr<Array>* out) {
+ std::vector<typename ArrowType::c_type> values;
+ ::arrow::randint(size, 0, 24, &values);
+ for (size_t i = 0; i < size; i++) {
+ values[i] *= 86400000;
+ }
+
+ // Passing data type so this will work with TimestampType too
+ ::arrow::NumericBuilder<ArrowType> builder(std::make_shared<ArrowType>(),
+ ::arrow::default_memory_pool());
+ RETURN_NOT_OK(builder.AppendValues(values.data(), values.size()));
+ return builder.Finish(out);
+}
+
+template <class ArrowType>
+::arrow::enable_if_base_binary<ArrowType, Status> NonNullArray(
+ size_t size, std::shared_ptr<Array>* out) {
+ using BuilderType = typename ::arrow::TypeTraits<ArrowType>::BuilderType;
+ BuilderType builder;
+ for (size_t i = 0; i < size; i++) {
+ RETURN_NOT_OK(builder.Append("test-string"));
+ }
+ return builder.Finish(out);
+}
+
+template <typename ArrowType>
+::arrow::enable_if_fixed_size_binary<ArrowType, Status> NonNullArray(
+ size_t size, std::shared_ptr<Array>* out) {
+ using BuilderType = typename ::arrow::TypeTraits<ArrowType>::BuilderType;
+ // set byte_width to the length of "fixed": 5
+ // todo: find a way to generate test data with more diversity.
+ BuilderType builder(::arrow::fixed_size_binary(5));
+ for (size_t i = 0; i < size; i++) {
+ RETURN_NOT_OK(builder.Append("fixed"));
+ }
+ return builder.Finish(out);
+}
+
+static void random_decimals(int64_t n, uint32_t seed, int32_t precision, uint8_t* out) {
+ std::default_random_engine gen(seed);
+ std::uniform_int_distribution<uint32_t> d(0, std::numeric_limits<uint8_t>::max());
+ const int32_t required_bytes = ::arrow::DecimalType::DecimalSize(precision);
+ int32_t byte_width = precision <= 38 ? 16 : 32;
+ std::fill(out, out + byte_width * n, '\0');
+
+ for (int64_t i = 0; i < n; ++i, out += byte_width) {
+ std::generate(out, out + required_bytes,
+ [&d, &gen] { return static_cast<uint8_t>(d(gen)); });
+
+ // sign extend if the sign bit is set for the last byte generated
+ // 0b10000000 == 0x80 == 128
+ if ((out[required_bytes - 1] & '\x80') != 0) {
+ std::fill(out + required_bytes, out + byte_width, '\xFF');
+ }
+ }
+}
+
+template <typename ArrowType, int32_t precision = ArrowType::precision>
+::arrow::enable_if_t<
+ std::is_same<ArrowType, DecimalWithPrecisionAndScale<precision>>::value, Status>
+NonNullArray(size_t size, std::shared_ptr<Array>* out) {
+ constexpr int32_t kDecimalPrecision = precision;
+ constexpr int32_t kDecimalScale = DecimalWithPrecisionAndScale<precision>::scale;
+
+ const auto type = ::arrow::decimal(kDecimalPrecision, kDecimalScale);
+ ::arrow::Decimal128Builder builder(type);
+ const int32_t byte_width =
+ static_cast<const ::arrow::Decimal128Type&>(*type).byte_width();
+
+ constexpr int32_t seed = 0;
+
+ ARROW_ASSIGN_OR_RAISE(auto out_buf, ::arrow::AllocateBuffer(size * byte_width));
+ random_decimals(size, seed, kDecimalPrecision, out_buf->mutable_data());
+
+ RETURN_NOT_OK(builder.AppendValues(out_buf->data(), size));
+ return builder.Finish(out);
+}
+
+template <typename ArrowType, int32_t precision = ArrowType::precision>
+::arrow::enable_if_t<
+ std::is_same<ArrowType, Decimal256WithPrecisionAndScale<precision>>::value, Status>
+NonNullArray(size_t size, std::shared_ptr<Array>* out) {
+ constexpr int32_t kDecimalPrecision = precision;
+ constexpr int32_t kDecimalScale = Decimal256WithPrecisionAndScale<precision>::scale;
+
+ const auto type = ::arrow::decimal256(kDecimalPrecision, kDecimalScale);
+ ::arrow::Decimal256Builder builder(type);
+ const int32_t byte_width =
+ static_cast<const ::arrow::Decimal256Type&>(*type).byte_width();
+
+ constexpr int32_t seed = 0;
+
+ ARROW_ASSIGN_OR_RAISE(auto out_buf, ::arrow::AllocateBuffer(size * byte_width));
+ random_decimals(size, seed, kDecimalPrecision, out_buf->mutable_data());
+
+ RETURN_NOT_OK(builder.AppendValues(out_buf->data(), size));
+ return builder.Finish(out);
+}
+
+template <class ArrowType>
+::arrow::enable_if_boolean<ArrowType, Status> NonNullArray(size_t size,
+ std::shared_ptr<Array>* out) {
+ std::vector<uint8_t> values;
+ ::arrow::randint(size, 0, 1, &values);
+ ::arrow::BooleanBuilder builder;
+ RETURN_NOT_OK(builder.AppendValues(values.data(), values.size()));
+ return builder.Finish(out);
+}
+
+// This helper function only supports (size/2) nulls.
+template <typename ArrowType>
+::arrow::enable_if_floating_point<ArrowType, Status> NullableArray(
+ size_t size, size_t num_nulls, uint32_t seed, std::shared_ptr<Array>* out) {
+ using c_type = typename ArrowType::c_type;
+ std::vector<c_type> values;
+ ::arrow::random_real(size, seed, static_cast<c_type>(-1e10), static_cast<c_type>(1e10),
+ &values);
+ std::vector<uint8_t> valid_bytes(size, 1);
+
+ for (size_t i = 0; i < num_nulls; i++) {
+ valid_bytes[i * 2] = 0;
+ }
+
+ ::arrow::NumericBuilder<ArrowType> builder;
+ RETURN_NOT_OK(builder.AppendValues(values.data(), values.size(), valid_bytes.data()));
+ return builder.Finish(out);
+}
+
+// This helper function only supports (size/2) nulls.
+template <typename ArrowType>
+::arrow::enable_if_integer<ArrowType, Status> NullableArray(size_t size, size_t num_nulls,
+ uint32_t seed,
+ std::shared_ptr<Array>* out) {
+ std::vector<typename ArrowType::c_type> values;
+
+ // Seed is random in Arrow right now
+ (void)seed;
+ ::arrow::randint(size, 0, 64, &values);
+ std::vector<uint8_t> valid_bytes(size, 1);
+
+ for (size_t i = 0; i < num_nulls; i++) {
+ valid_bytes[i * 2] = 0;
+ }
+
+ // Passing data type so this will work with TimestampType too
+ ::arrow::NumericBuilder<ArrowType> builder(std::make_shared<ArrowType>(),
+ ::arrow::default_memory_pool());
+ RETURN_NOT_OK(builder.AppendValues(values.data(), values.size(), valid_bytes.data()));
+ return builder.Finish(out);
+}
+
+template <typename ArrowType>
+::arrow::enable_if_date<ArrowType, Status> NullableArray(size_t size, size_t num_nulls,
+ uint32_t seed,
+ std::shared_ptr<Array>* out) {
+ std::vector<typename ArrowType::c_type> values;
+
+ // Seed is random in Arrow right now
+ (void)seed;
+ ::arrow::randint(size, 0, 24, &values);
+ for (size_t i = 0; i < size; i++) {
+ values[i] *= 86400000;
+ }
+ std::vector<uint8_t> valid_bytes(size, 1);
+
+ for (size_t i = 0; i < num_nulls; i++) {
+ valid_bytes[i * 2] = 0;
+ }
+
+ // Passing data type so this will work with TimestampType too
+ ::arrow::NumericBuilder<ArrowType> builder(std::make_shared<ArrowType>(),
+ ::arrow::default_memory_pool());
+ RETURN_NOT_OK(builder.AppendValues(values.data(), values.size(), valid_bytes.data()));
+ return builder.Finish(out);
+}
+
+// This helper function only supports (size/2) nulls yet.
+template <typename ArrowType>
+::arrow::enable_if_base_binary<ArrowType, Status> NullableArray(
+ size_t size, size_t num_nulls, uint32_t seed, std::shared_ptr<::arrow::Array>* out) {
+ std::vector<uint8_t> valid_bytes(size, 1);
+
+ for (size_t i = 0; i < num_nulls; i++) {
+ valid_bytes[i * 2] = 0;
+ }
+
+ using BuilderType = typename ::arrow::TypeTraits<ArrowType>::BuilderType;
+ BuilderType builder;
+
+ const int kBufferSize = 10;
+ uint8_t buffer[kBufferSize];
+ for (size_t i = 0; i < size; i++) {
+ if (!valid_bytes[i]) {
+ RETURN_NOT_OK(builder.AppendNull());
+ } else {
+ ::arrow::random_bytes(kBufferSize, seed + static_cast<uint32_t>(i), buffer);
+ if (ArrowType::is_utf8) {
+ // Trivially force data to be valid UTF8 by making it all ASCII
+ for (auto& byte : buffer) {
+ byte &= 0x7f;
+ }
+ }
+ RETURN_NOT_OK(builder.Append(buffer, kBufferSize));
+ }
+ }
+ return builder.Finish(out);
+}
+
+// This helper function only supports (size/2) nulls yet,
+// same as NullableArray<String|Binary>(..)
+template <typename ArrowType>
+::arrow::enable_if_fixed_size_binary<ArrowType, Status> NullableArray(
+ size_t size, size_t num_nulls, uint32_t seed, std::shared_ptr<::arrow::Array>* out) {
+ std::vector<uint8_t> valid_bytes(size, 1);
+
+ for (size_t i = 0; i < num_nulls; i++) {
+ valid_bytes[i * 2] = 0;
+ }
+
+ using BuilderType = typename ::arrow::TypeTraits<ArrowType>::BuilderType;
+ const int byte_width = 10;
+ BuilderType builder(::arrow::fixed_size_binary(byte_width));
+
+ const int kBufferSize = byte_width;
+ uint8_t buffer[kBufferSize];
+ for (size_t i = 0; i < size; i++) {
+ if (!valid_bytes[i]) {
+ RETURN_NOT_OK(builder.AppendNull());
+ } else {
+ ::arrow::random_bytes(kBufferSize, seed + static_cast<uint32_t>(i), buffer);
+ RETURN_NOT_OK(builder.Append(buffer));
+ }
+ }
+ return builder.Finish(out);
+}
+
+template <typename ArrowType, int32_t precision = ArrowType::precision>
+::arrow::enable_if_t<
+ std::is_same<ArrowType, DecimalWithPrecisionAndScale<precision>>::value, Status>
+NullableArray(size_t size, size_t num_nulls, uint32_t seed,
+ std::shared_ptr<::arrow::Array>* out) {
+ std::vector<uint8_t> valid_bytes(size, '\1');
+
+ for (size_t i = 0; i < num_nulls; ++i) {
+ valid_bytes[i * 2] = '\0';
+ }
+
+ constexpr int32_t kDecimalPrecision = precision;
+ constexpr int32_t kDecimalScale = DecimalWithPrecisionAndScale<precision>::scale;
+ const auto type = ::arrow::decimal(kDecimalPrecision, kDecimalScale);
+ const int32_t byte_width =
+ static_cast<const ::arrow::Decimal128Type&>(*type).byte_width();
+
+ ARROW_ASSIGN_OR_RAISE(auto out_buf, ::arrow::AllocateBuffer(size * byte_width));
+
+ random_decimals(size, seed, precision, out_buf->mutable_data());
+
+ ::arrow::Decimal128Builder builder(type);
+ RETURN_NOT_OK(builder.AppendValues(out_buf->data(), size, valid_bytes.data()));
+ return builder.Finish(out);
+}
+
+template <typename ArrowType, int32_t precision = ArrowType::precision>
+::arrow::enable_if_t<
+ std::is_same<ArrowType, Decimal256WithPrecisionAndScale<precision>>::value, Status>
+NullableArray(size_t size, size_t num_nulls, uint32_t seed,
+ std::shared_ptr<::arrow::Array>* out) {
+ std::vector<uint8_t> valid_bytes(size, '\1');
+
+ for (size_t i = 0; i < num_nulls; ++i) {
+ valid_bytes[i * 2] = '\0';
+ }
+
+ constexpr int32_t kDecimalPrecision = precision;
+ constexpr int32_t kDecimalScale = Decimal256WithPrecisionAndScale<precision>::scale;
+ const auto type = ::arrow::decimal256(kDecimalPrecision, kDecimalScale);
+ const int32_t byte_width =
+ static_cast<const ::arrow::Decimal256Type&>(*type).byte_width();
+
+ ARROW_ASSIGN_OR_RAISE(auto out_buf, ::arrow::AllocateBuffer(size * byte_width));
+
+ random_decimals(size, seed, precision, out_buf->mutable_data());
+
+ ::arrow::Decimal256Builder builder(type);
+ RETURN_NOT_OK(builder.AppendValues(out_buf->data(), size, valid_bytes.data()));
+ return builder.Finish(out);
+}
+
+// This helper function only supports (size/2) nulls yet.
+template <class ArrowType>
+::arrow::enable_if_boolean<ArrowType, Status> NullableArray(size_t size, size_t num_nulls,
+ uint32_t seed,
+ std::shared_ptr<Array>* out) {
+ std::vector<uint8_t> values;
+
+ // Seed is random in Arrow right now
+ (void)seed;
+
+ ::arrow::randint(size, 0, 1, &values);
+ std::vector<uint8_t> valid_bytes(size, 1);
+
+ for (size_t i = 0; i < num_nulls; i++) {
+ valid_bytes[i * 2] = 0;
+ }
+
+ ::arrow::BooleanBuilder builder;
+ RETURN_NOT_OK(builder.AppendValues(values.data(), values.size(), valid_bytes.data()));
+ return builder.Finish(out);
+}
+
+/// Wrap an Array into a ListArray by splitting it up into size lists.
+///
+/// This helper function only supports (size/2) nulls.
+Status MakeListArray(const std::shared_ptr<Array>& values, int64_t size,
+ int64_t null_count, const std::string& item_name,
+ bool nullable_values, std::shared_ptr<::arrow::ListArray>* out) {
+ // We always include an empty list
+ int64_t non_null_entries = size - null_count - 1;
+ int64_t length_per_entry = values->length() / non_null_entries;
+
+ auto offsets = AllocateBuffer();
+ RETURN_NOT_OK(offsets->Resize((size + 1) * sizeof(int32_t)));
+ int32_t* offsets_ptr = reinterpret_cast<int32_t*>(offsets->mutable_data());
+
+ auto null_bitmap = AllocateBuffer();
+ int64_t bitmap_size = ::arrow::BitUtil::BytesForBits(size);
+ RETURN_NOT_OK(null_bitmap->Resize(bitmap_size));
+ uint8_t* null_bitmap_ptr = null_bitmap->mutable_data();
+ memset(null_bitmap_ptr, 0, bitmap_size);
+
+ int32_t current_offset = 0;
+ for (int64_t i = 0; i < size; i++) {
+ offsets_ptr[i] = current_offset;
+ if (!(((i % 2) == 0) && ((i / 2) < null_count))) {
+ // Non-null list (list with index 1 is always empty).
+ ::arrow::BitUtil::SetBit(null_bitmap_ptr, i);
+ if (i != 1) {
+ current_offset += static_cast<int32_t>(length_per_entry);
+ }
+ }
+ }
+ offsets_ptr[size] = static_cast<int32_t>(values->length());
+
+ auto value_field = ::arrow::field(item_name, values->type(), nullable_values);
+ *out = std::make_shared<::arrow::ListArray>(::arrow::list(value_field), size, offsets,
+ values, null_bitmap, null_count);
+
+ return Status::OK();
+}
+
+// Make an array containing only empty lists, with a null values array
+Status MakeEmptyListsArray(int64_t size, std::shared_ptr<Array>* out_array) {
+ // Allocate an offsets buffer containing only zeroes
+ const int64_t offsets_nbytes = (size + 1) * sizeof(int32_t);
+ ARROW_ASSIGN_OR_RAISE(auto offsets_buffer, ::arrow::AllocateBuffer(offsets_nbytes));
+ memset(offsets_buffer->mutable_data(), 0, offsets_nbytes);
+
+ auto value_field =
+ ::arrow::field("item", ::arrow::float64(), false /* nullable_values */);
+ auto list_type = ::arrow::list(value_field);
+
+ std::vector<std::shared_ptr<Buffer>> child_buffers = {nullptr /* null bitmap */,
+ nullptr /* values */};
+ auto child_data =
+ ::arrow::ArrayData::Make(value_field->type(), 0, std::move(child_buffers));
+
+ std::vector<std::shared_ptr<Buffer>> buffers = {nullptr /* bitmap */,
+ std::move(offsets_buffer)};
+ auto array_data = ::arrow::ArrayData::Make(list_type, size, std::move(buffers));
+ array_data->child_data.push_back(child_data);
+
+ *out_array = ::arrow::MakeArray(array_data);
+ return Status::OK();
+}
+
+std::shared_ptr<::arrow::Table> MakeSimpleTable(
+ const std::shared_ptr<ChunkedArray>& values, bool nullable) {
+ auto schema = ::arrow::schema({::arrow::field("col", values->type(), nullable)});
+ return ::arrow::Table::Make(schema, {values});
+}
+
+std::shared_ptr<::arrow::Table> MakeSimpleTable(const std::shared_ptr<Array>& values,
+ bool nullable) {
+ auto carr = std::make_shared<::arrow::ChunkedArray>(values);
+ return MakeSimpleTable(carr, nullable);
+}
+
+template <typename T>
+void ExpectArray(T* expected, Array* result) {
+ auto p_array = static_cast<::arrow::PrimitiveArray*>(result);
+ for (int i = 0; i < result->length(); i++) {
+ EXPECT_EQ(expected[i], reinterpret_cast<const T*>(p_array->values()->data())[i]);
+ }
+}
+
+template <typename ArrowType>
+void ExpectArrayT(void* expected, Array* result) {
+ ::arrow::PrimitiveArray* p_array = static_cast<::arrow::PrimitiveArray*>(result);
+ for (int64_t i = 0; i < result->length(); i++) {
+ EXPECT_EQ(reinterpret_cast<typename ArrowType::c_type*>(expected)[i],
+ reinterpret_cast<const typename ArrowType::c_type*>(
+ p_array->values()->data())[i]);
+ }
+}
+
+template <>
+void ExpectArrayT<::arrow::BooleanType>(void* expected, Array* result) {
+ ::arrow::BooleanBuilder builder;
+ ARROW_EXPECT_OK(
+ builder.AppendValues(reinterpret_cast<uint8_t*>(expected), result->length()));
+
+ std::shared_ptr<Array> expected_array;
+ ARROW_EXPECT_OK(builder.Finish(&expected_array));
+ EXPECT_TRUE(result->Equals(*expected_array));
+}
+
+} // namespace arrow
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/writer.cc b/src/arrow/cpp/src/parquet/arrow/writer.cc
new file mode 100644
index 000000000..a2776b456
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/writer.cc
@@ -0,0 +1,480 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/arrow/writer.h"
+
+#include <algorithm>
+#include <deque>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/extension_type.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/util/base64.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/visitor_inline.h"
+
+#include "parquet/arrow/path_internal.h"
+#include "parquet/arrow/reader_internal.h"
+#include "parquet/arrow/schema.h"
+#include "parquet/column_writer.h"
+#include "parquet/exception.h"
+#include "parquet/file_writer.h"
+#include "parquet/platform.h"
+#include "parquet/schema.h"
+
+using arrow::Array;
+using arrow::BinaryArray;
+using arrow::BooleanArray;
+using arrow::ChunkedArray;
+using arrow::DataType;
+using arrow::DictionaryArray;
+using arrow::ExtensionArray;
+using arrow::ExtensionType;
+using arrow::Field;
+using arrow::FixedSizeBinaryArray;
+using arrow::ListArray;
+using arrow::MemoryPool;
+using arrow::NumericArray;
+using arrow::PrimitiveArray;
+using arrow::ResizableBuffer;
+using arrow::Status;
+using arrow::Table;
+using arrow::TimeUnit;
+
+using arrow::internal::checked_cast;
+
+using parquet::ParquetFileWriter;
+using parquet::ParquetVersion;
+using parquet::schema::GroupNode;
+
+namespace parquet {
+namespace arrow {
+
+namespace {
+
+int CalculateLeafCount(const DataType* type) {
+ if (type->id() == ::arrow::Type::EXTENSION) {
+ type = checked_cast<const ExtensionType&>(*type).storage_type().get();
+ }
+ // Note num_fields() can be 0 for an empty struct type
+ if (!::arrow::is_nested(type->id())) {
+ // Primitive type.
+ return 1;
+ }
+
+ int num_leaves = 0;
+ for (const auto& field : type->fields()) {
+ num_leaves += CalculateLeafCount(field->type().get());
+ }
+ return num_leaves;
+}
+
+// Determines if the |schema_field|'s root ancestor is nullable.
+bool HasNullableRoot(const SchemaManifest& schema_manifest,
+ const SchemaField* schema_field) {
+ DCHECK(schema_field != nullptr);
+ const SchemaField* current_field = schema_field;
+ bool nullable = schema_field->field->nullable();
+ while (current_field != nullptr) {
+ nullable = current_field->field->nullable();
+ current_field = schema_manifest.GetParent(current_field);
+ }
+ return nullable;
+}
+
+// Manages writing nested parquet columns with support for all nested types
+// supported by parquet.
+class ArrowColumnWriterV2 {
+ public:
+ // Constructs a new object (use Make() method below to construct from
+ // A ChunkedArray).
+ // level_builders should contain one MultipathLevelBuilder per chunk of the
+ // Arrow-column to write.
+ ArrowColumnWriterV2(std::vector<std::unique_ptr<MultipathLevelBuilder>> level_builders,
+ int leaf_count, RowGroupWriter* row_group_writer)
+ : level_builders_(std::move(level_builders)),
+ leaf_count_(leaf_count),
+ row_group_writer_(row_group_writer) {}
+
+ // Writes out all leaf parquet columns to the RowGroupWriter that this
+ // object was constructed with. Each leaf column is written fully before
+ // the next column is written (i.e. no buffering is assumed).
+ //
+ // Columns are written in DFS order.
+ Status Write(ArrowWriteContext* ctx) {
+ for (int leaf_idx = 0; leaf_idx < leaf_count_; leaf_idx++) {
+ ColumnWriter* column_writer;
+ PARQUET_CATCH_NOT_OK(column_writer = row_group_writer_->NextColumn());
+ for (auto& level_builder : level_builders_) {
+ RETURN_NOT_OK(level_builder->Write(
+ leaf_idx, ctx, [&](const MultipathLevelBuilderResult& result) {
+ size_t visited_component_size = result.post_list_visited_elements.size();
+ DCHECK_GT(visited_component_size, 0);
+ if (visited_component_size != 1) {
+ return Status::NotImplemented(
+ "Lists with non-zero length null components are not supported");
+ }
+ const ElementRange& range = result.post_list_visited_elements[0];
+ std::shared_ptr<Array> values_array =
+ result.leaf_array->Slice(range.start, range.Size());
+
+ return column_writer->WriteArrow(result.def_levels, result.rep_levels,
+ result.def_rep_level_count, *values_array,
+ ctx, result.leaf_is_nullable);
+ }));
+ }
+
+ PARQUET_CATCH_NOT_OK(column_writer->Close());
+ }
+ return Status::OK();
+ }
+
+ // Make a new object by converting each chunk in |data| to a MultipathLevelBuilder.
+ //
+ // It is necessary to create a new builder per array because the MultipathlevelBuilder
+ // extracts the data necessary for writing each leaf column at construction time.
+ // (it optimizes based on null count) and with slicing via |offset| ephemeral
+ // chunks are created which need to be tracked across each leaf column-write.
+ // This decision could potentially be revisited if we wanted to use "buffered"
+ // RowGroupWriters (we could construct each builder on demand in that case).
+ static ::arrow::Result<std::unique_ptr<ArrowColumnWriterV2>> Make(
+ const ChunkedArray& data, int64_t offset, const int64_t size,
+ const SchemaManifest& schema_manifest, RowGroupWriter* row_group_writer) {
+ int64_t absolute_position = 0;
+ int chunk_index = 0;
+ int64_t chunk_offset = 0;
+ if (data.length() == 0) {
+ return ::arrow::internal::make_unique<ArrowColumnWriterV2>(
+ std::vector<std::unique_ptr<MultipathLevelBuilder>>{},
+ CalculateLeafCount(data.type().get()), row_group_writer);
+ }
+ while (chunk_index < data.num_chunks() && absolute_position < offset) {
+ const int64_t chunk_length = data.chunk(chunk_index)->length();
+ if (absolute_position + chunk_length > offset) {
+ // Relative offset into the chunk to reach the desired start offset for
+ // writing
+ chunk_offset = offset - absolute_position;
+ break;
+ } else {
+ ++chunk_index;
+ absolute_position += chunk_length;
+ }
+ }
+
+ if (absolute_position >= data.length()) {
+ return Status::Invalid("Cannot write data at offset past end of chunked array");
+ }
+
+ int64_t values_written = 0;
+ std::vector<std::unique_ptr<MultipathLevelBuilder>> builders;
+ const int leaf_count = CalculateLeafCount(data.type().get());
+ bool is_nullable = false;
+ // The row_group_writer hasn't been advanced yet so add 1 to the current
+ // which is the one this instance will start writing for.
+ int column_index = row_group_writer->current_column() + 1;
+ for (int leaf_offset = 0; leaf_offset < leaf_count; ++leaf_offset) {
+ const SchemaField* schema_field = nullptr;
+ RETURN_NOT_OK(
+ schema_manifest.GetColumnField(column_index + leaf_offset, &schema_field));
+ bool nullable_root = HasNullableRoot(schema_manifest, schema_field);
+ if (leaf_offset == 0) {
+ is_nullable = nullable_root;
+ }
+
+// Don't validate common ancestry for all leafs if not in debug.
+#ifndef NDEBUG
+ break;
+#else
+ if (is_nullable != nullable_root) {
+ return Status::UnknownError(
+ "Unexpected mismatched nullability between column index",
+ column_index + leaf_offset, " and ", column_index);
+ }
+#endif
+ }
+ while (values_written < size) {
+ const Array& chunk = *data.chunk(chunk_index);
+ const int64_t available_values = chunk.length() - chunk_offset;
+ const int64_t chunk_write_size = std::min(size - values_written, available_values);
+
+ // The chunk offset here will be 0 except for possibly the first chunk
+ // because of the advancing logic above
+ std::shared_ptr<Array> array_to_write = chunk.Slice(chunk_offset, chunk_write_size);
+
+ if (array_to_write->length() > 0) {
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<MultipathLevelBuilder> builder,
+ MultipathLevelBuilder::Make(*array_to_write, is_nullable));
+ if (leaf_count != builder->GetLeafCount()) {
+ return Status::UnknownError("data type leaf_count != builder_leaf_count",
+ leaf_count, " ", builder->GetLeafCount());
+ }
+ builders.emplace_back(std::move(builder));
+ }
+
+ if (chunk_write_size == available_values) {
+ chunk_offset = 0;
+ ++chunk_index;
+ }
+ values_written += chunk_write_size;
+ }
+ return ::arrow::internal::make_unique<ArrowColumnWriterV2>(
+ std::move(builders), leaf_count, row_group_writer);
+ }
+
+ private:
+ // One builder per column-chunk.
+ std::vector<std::unique_ptr<MultipathLevelBuilder>> level_builders_;
+ int leaf_count_;
+ RowGroupWriter* row_group_writer_;
+};
+
+} // namespace
+
+// ----------------------------------------------------------------------
+// FileWriter implementation
+
+class FileWriterImpl : public FileWriter {
+ public:
+ FileWriterImpl(std::shared_ptr<::arrow::Schema> schema, MemoryPool* pool,
+ std::unique_ptr<ParquetFileWriter> writer,
+ std::shared_ptr<ArrowWriterProperties> arrow_properties)
+ : schema_(std::move(schema)),
+ writer_(std::move(writer)),
+ row_group_writer_(nullptr),
+ column_write_context_(pool, arrow_properties.get()),
+ arrow_properties_(std::move(arrow_properties)),
+ closed_(false) {}
+
+ Status Init() {
+ return SchemaManifest::Make(writer_->schema(), /*schema_metadata=*/nullptr,
+ default_arrow_reader_properties(), &schema_manifest_);
+ }
+
+ Status NewRowGroup(int64_t chunk_size) override {
+ if (row_group_writer_ != nullptr) {
+ PARQUET_CATCH_NOT_OK(row_group_writer_->Close());
+ }
+ PARQUET_CATCH_NOT_OK(row_group_writer_ = writer_->AppendRowGroup());
+ return Status::OK();
+ }
+
+ Status Close() override {
+ if (!closed_) {
+ // Make idempotent
+ closed_ = true;
+ if (row_group_writer_ != nullptr) {
+ PARQUET_CATCH_NOT_OK(row_group_writer_->Close());
+ }
+ PARQUET_CATCH_NOT_OK(writer_->Close());
+ }
+ return Status::OK();
+ }
+
+ Status WriteColumnChunk(const Array& data) override {
+ // A bit awkward here since cannot instantiate ChunkedArray from const Array&
+ auto chunk = ::arrow::MakeArray(data.data());
+ auto chunked_array = std::make_shared<::arrow::ChunkedArray>(chunk);
+ return WriteColumnChunk(chunked_array, 0, data.length());
+ }
+
+ Status WriteColumnChunk(const std::shared_ptr<ChunkedArray>& data, int64_t offset,
+ int64_t size) override {
+ if (arrow_properties_->engine_version() == ArrowWriterProperties::V2 ||
+ arrow_properties_->engine_version() == ArrowWriterProperties::V1) {
+ ARROW_ASSIGN_OR_RAISE(
+ std::unique_ptr<ArrowColumnWriterV2> writer,
+ ArrowColumnWriterV2::Make(*data, offset, size, schema_manifest_,
+ row_group_writer_));
+ return writer->Write(&column_write_context_);
+ }
+ return Status::NotImplemented("Unknown engine version.");
+ }
+
+ Status WriteColumnChunk(const std::shared_ptr<::arrow::ChunkedArray>& data) override {
+ return WriteColumnChunk(data, 0, data->length());
+ }
+
+ std::shared_ptr<::arrow::Schema> schema() const override { return schema_; }
+
+ Status WriteTable(const Table& table, int64_t chunk_size) override {
+ RETURN_NOT_OK(table.Validate());
+
+ if (chunk_size <= 0 && table.num_rows() > 0) {
+ return Status::Invalid("chunk size per row_group must be greater than 0");
+ } else if (!table.schema()->Equals(*schema_, false)) {
+ return Status::Invalid("table schema does not match this writer's. table:'",
+ table.schema()->ToString(), "' this:'", schema_->ToString(),
+ "'");
+ } else if (chunk_size > this->properties().max_row_group_length()) {
+ chunk_size = this->properties().max_row_group_length();
+ }
+
+ auto WriteRowGroup = [&](int64_t offset, int64_t size) {
+ RETURN_NOT_OK(NewRowGroup(size));
+ for (int i = 0; i < table.num_columns(); i++) {
+ RETURN_NOT_OK(WriteColumnChunk(table.column(i), offset, size));
+ }
+ return Status::OK();
+ };
+
+ if (table.num_rows() == 0) {
+ // Append a row group with 0 rows
+ RETURN_NOT_OK_ELSE(WriteRowGroup(0, 0), PARQUET_IGNORE_NOT_OK(Close()));
+ return Status::OK();
+ }
+
+ for (int chunk = 0; chunk * chunk_size < table.num_rows(); chunk++) {
+ int64_t offset = chunk * chunk_size;
+ RETURN_NOT_OK_ELSE(
+ WriteRowGroup(offset, std::min(chunk_size, table.num_rows() - offset)),
+ PARQUET_IGNORE_NOT_OK(Close()));
+ }
+ return Status::OK();
+ }
+
+ const WriterProperties& properties() const { return *writer_->properties(); }
+
+ ::arrow::MemoryPool* memory_pool() const override {
+ return column_write_context_.memory_pool;
+ }
+
+ const std::shared_ptr<FileMetaData> metadata() const override {
+ return writer_->metadata();
+ }
+
+ private:
+ friend class FileWriter;
+
+ std::shared_ptr<::arrow::Schema> schema_;
+
+ SchemaManifest schema_manifest_;
+
+ std::unique_ptr<ParquetFileWriter> writer_;
+ RowGroupWriter* row_group_writer_;
+ ArrowWriteContext column_write_context_;
+ std::shared_ptr<ArrowWriterProperties> arrow_properties_;
+ bool closed_;
+};
+
+FileWriter::~FileWriter() {}
+
+Status FileWriter::Make(::arrow::MemoryPool* pool,
+ std::unique_ptr<ParquetFileWriter> writer,
+ std::shared_ptr<::arrow::Schema> schema,
+ std::shared_ptr<ArrowWriterProperties> arrow_properties,
+ std::unique_ptr<FileWriter>* out) {
+ std::unique_ptr<FileWriterImpl> impl(new FileWriterImpl(
+ std::move(schema), pool, std::move(writer), std::move(arrow_properties)));
+ RETURN_NOT_OK(impl->Init());
+ *out = std::move(impl);
+ return Status::OK();
+}
+
+Status FileWriter::Open(const ::arrow::Schema& schema, ::arrow::MemoryPool* pool,
+ std::shared_ptr<::arrow::io::OutputStream> sink,
+ std::shared_ptr<WriterProperties> properties,
+ std::unique_ptr<FileWriter>* writer) {
+ return Open(std::move(schema), pool, std::move(sink), std::move(properties),
+ default_arrow_writer_properties(), writer);
+}
+
+Status GetSchemaMetadata(const ::arrow::Schema& schema, ::arrow::MemoryPool* pool,
+ const ArrowWriterProperties& properties,
+ std::shared_ptr<const KeyValueMetadata>* out) {
+ if (!properties.store_schema()) {
+ *out = nullptr;
+ return Status::OK();
+ }
+
+ static const std::string kArrowSchemaKey = "ARROW:schema";
+ std::shared_ptr<KeyValueMetadata> result;
+ if (schema.metadata()) {
+ result = schema.metadata()->Copy();
+ } else {
+ result = ::arrow::key_value_metadata({}, {});
+ }
+
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> serialized,
+ ::arrow::ipc::SerializeSchema(schema, pool));
+
+ // The serialized schema is not UTF-8, which is required for Thrift
+ std::string schema_as_string = serialized->ToString();
+ std::string schema_base64 = ::arrow::util::base64_encode(schema_as_string);
+ result->Append(kArrowSchemaKey, schema_base64);
+ *out = result;
+ return Status::OK();
+}
+
+Status FileWriter::Open(const ::arrow::Schema& schema, ::arrow::MemoryPool* pool,
+ std::shared_ptr<::arrow::io::OutputStream> sink,
+ std::shared_ptr<WriterProperties> properties,
+ std::shared_ptr<ArrowWriterProperties> arrow_properties,
+ std::unique_ptr<FileWriter>* writer) {
+ std::shared_ptr<SchemaDescriptor> parquet_schema;
+ RETURN_NOT_OK(
+ ToParquetSchema(&schema, *properties, *arrow_properties, &parquet_schema));
+
+ auto schema_node = std::static_pointer_cast<GroupNode>(parquet_schema->schema_root());
+
+ std::shared_ptr<const KeyValueMetadata> metadata;
+ RETURN_NOT_OK(GetSchemaMetadata(schema, pool, *arrow_properties, &metadata));
+
+ std::unique_ptr<ParquetFileWriter> base_writer;
+ PARQUET_CATCH_NOT_OK(base_writer = ParquetFileWriter::Open(std::move(sink), schema_node,
+ std::move(properties),
+ std::move(metadata)));
+
+ auto schema_ptr = std::make_shared<::arrow::Schema>(schema);
+ return Make(pool, std::move(base_writer), std::move(schema_ptr),
+ std::move(arrow_properties), writer);
+}
+
+Status WriteFileMetaData(const FileMetaData& file_metadata,
+ ::arrow::io::OutputStream* sink) {
+ PARQUET_CATCH_NOT_OK(::parquet::WriteFileMetaData(file_metadata, sink));
+ return Status::OK();
+}
+
+Status WriteMetaDataFile(const FileMetaData& file_metadata,
+ ::arrow::io::OutputStream* sink) {
+ PARQUET_CATCH_NOT_OK(::parquet::WriteMetaDataFile(file_metadata, sink));
+ return Status::OK();
+}
+
+Status WriteTable(const ::arrow::Table& table, ::arrow::MemoryPool* pool,
+ std::shared_ptr<::arrow::io::OutputStream> sink, int64_t chunk_size,
+ std::shared_ptr<WriterProperties> properties,
+ std::shared_ptr<ArrowWriterProperties> arrow_properties) {
+ std::unique_ptr<FileWriter> writer;
+ RETURN_NOT_OK(FileWriter::Open(*table.schema(), pool, std::move(sink),
+ std::move(properties), std::move(arrow_properties),
+ &writer));
+ RETURN_NOT_OK(writer->WriteTable(table, chunk_size));
+ return writer->Close();
+}
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/arrow/writer.h b/src/arrow/cpp/src/parquet/arrow/writer.h
new file mode 100644
index 000000000..f31f3d03d
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/arrow/writer.h
@@ -0,0 +1,109 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "parquet/platform.h"
+#include "parquet/properties.h"
+
+namespace arrow {
+
+class Array;
+class ChunkedArray;
+class Schema;
+class Table;
+
+} // namespace arrow
+
+namespace parquet {
+
+class FileMetaData;
+class ParquetFileWriter;
+
+namespace arrow {
+
+/// \brief Iterative FileWriter class
+///
+/// Start a new RowGroup or Chunk with NewRowGroup.
+/// Write column-by-column the whole column chunk.
+///
+/// If PARQUET:field_id is present as a metadata key on a field, and the corresponding
+/// value is a nonnegative integer, then it will be used as the field_id in the parquet
+/// file.
+class PARQUET_EXPORT FileWriter {
+ public:
+ static ::arrow::Status Make(MemoryPool* pool, std::unique_ptr<ParquetFileWriter> writer,
+ std::shared_ptr<::arrow::Schema> schema,
+ std::shared_ptr<ArrowWriterProperties> arrow_properties,
+ std::unique_ptr<FileWriter>* out);
+
+ static ::arrow::Status Open(const ::arrow::Schema& schema, MemoryPool* pool,
+ std::shared_ptr<::arrow::io::OutputStream> sink,
+ std::shared_ptr<WriterProperties> properties,
+ std::unique_ptr<FileWriter>* writer);
+
+ static ::arrow::Status Open(const ::arrow::Schema& schema, MemoryPool* pool,
+ std::shared_ptr<::arrow::io::OutputStream> sink,
+ std::shared_ptr<WriterProperties> properties,
+ std::shared_ptr<ArrowWriterProperties> arrow_properties,
+ std::unique_ptr<FileWriter>* writer);
+
+ virtual std::shared_ptr<::arrow::Schema> schema() const = 0;
+
+ /// \brief Write a Table to Parquet.
+ virtual ::arrow::Status WriteTable(const ::arrow::Table& table, int64_t chunk_size) = 0;
+
+ virtual ::arrow::Status NewRowGroup(int64_t chunk_size) = 0;
+ virtual ::arrow::Status WriteColumnChunk(const ::arrow::Array& data) = 0;
+
+ /// \brief Write ColumnChunk in row group using slice of a ChunkedArray
+ virtual ::arrow::Status WriteColumnChunk(
+ const std::shared_ptr<::arrow::ChunkedArray>& data, int64_t offset,
+ int64_t size) = 0;
+
+ virtual ::arrow::Status WriteColumnChunk(
+ const std::shared_ptr<::arrow::ChunkedArray>& data) = 0;
+ virtual ::arrow::Status Close() = 0;
+ virtual ~FileWriter();
+
+ virtual MemoryPool* memory_pool() const = 0;
+ virtual const std::shared_ptr<FileMetaData> metadata() const = 0;
+};
+
+/// \brief Write Parquet file metadata only to indicated Arrow OutputStream
+PARQUET_EXPORT
+::arrow::Status WriteFileMetaData(const FileMetaData& file_metadata,
+ ::arrow::io::OutputStream* sink);
+
+/// \brief Write metadata-only Parquet file to indicated Arrow OutputStream
+PARQUET_EXPORT
+::arrow::Status WriteMetaDataFile(const FileMetaData& file_metadata,
+ ::arrow::io::OutputStream* sink);
+
+/// \brief Write a Table to Parquet.
+::arrow::Status PARQUET_EXPORT
+WriteTable(const ::arrow::Table& table, MemoryPool* pool,
+ std::shared_ptr<::arrow::io::OutputStream> sink, int64_t chunk_size,
+ std::shared_ptr<WriterProperties> properties = default_writer_properties(),
+ std::shared_ptr<ArrowWriterProperties> arrow_properties =
+ default_arrow_writer_properties());
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/bloom_filter.cc b/src/arrow/cpp/src/parquet/bloom_filter.cc
new file mode 100644
index 000000000..f6f6d327d
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/bloom_filter.cc
@@ -0,0 +1,162 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <cstring>
+
+#include "arrow/result.h"
+#include "arrow/util/logging.h"
+#include "parquet/bloom_filter.h"
+#include "parquet/exception.h"
+#include "parquet/murmur3.h"
+
+namespace parquet {
+constexpr uint32_t BlockSplitBloomFilter::SALT[kBitsSetPerBlock];
+
+BlockSplitBloomFilter::BlockSplitBloomFilter()
+ : pool_(::arrow::default_memory_pool()),
+ hash_strategy_(HashStrategy::MURMUR3_X64_128),
+ algorithm_(Algorithm::BLOCK) {}
+
+void BlockSplitBloomFilter::Init(uint32_t num_bytes) {
+ if (num_bytes < kMinimumBloomFilterBytes) {
+ num_bytes = kMinimumBloomFilterBytes;
+ }
+
+ // Get next power of 2 if it is not power of 2.
+ if ((num_bytes & (num_bytes - 1)) != 0) {
+ num_bytes = static_cast<uint32_t>(::arrow::BitUtil::NextPower2(num_bytes));
+ }
+
+ if (num_bytes > kMaximumBloomFilterBytes) {
+ num_bytes = kMaximumBloomFilterBytes;
+ }
+
+ num_bytes_ = num_bytes;
+ PARQUET_ASSIGN_OR_THROW(data_, ::arrow::AllocateBuffer(num_bytes_, pool_));
+ memset(data_->mutable_data(), 0, num_bytes_);
+
+ this->hasher_.reset(new MurmurHash3());
+}
+
+void BlockSplitBloomFilter::Init(const uint8_t* bitset, uint32_t num_bytes) {
+ DCHECK(bitset != nullptr);
+
+ if (num_bytes < kMinimumBloomFilterBytes || num_bytes > kMaximumBloomFilterBytes ||
+ (num_bytes & (num_bytes - 1)) != 0) {
+ throw ParquetException("Given length of bitset is illegal");
+ }
+
+ num_bytes_ = num_bytes;
+ PARQUET_ASSIGN_OR_THROW(data_, ::arrow::AllocateBuffer(num_bytes_, pool_));
+ memcpy(data_->mutable_data(), bitset, num_bytes_);
+
+ this->hasher_.reset(new MurmurHash3());
+}
+
+BlockSplitBloomFilter BlockSplitBloomFilter::Deserialize(ArrowInputStream* input) {
+ uint32_t len, hash, algorithm;
+ int64_t bytes_available;
+
+ PARQUET_ASSIGN_OR_THROW(bytes_available, input->Read(sizeof(uint32_t), &len));
+ if (static_cast<uint32_t>(bytes_available) != sizeof(uint32_t)) {
+ throw ParquetException("Failed to deserialize from input stream");
+ }
+
+ PARQUET_ASSIGN_OR_THROW(bytes_available, input->Read(sizeof(uint32_t), &hash));
+ if (static_cast<uint32_t>(bytes_available) != sizeof(uint32_t)) {
+ throw ParquetException("Failed to deserialize from input stream");
+ }
+ if (static_cast<HashStrategy>(hash) != HashStrategy::MURMUR3_X64_128) {
+ throw ParquetException("Unsupported hash strategy");
+ }
+
+ PARQUET_ASSIGN_OR_THROW(bytes_available, input->Read(sizeof(uint32_t), &algorithm));
+ if (static_cast<uint32_t>(bytes_available) != sizeof(uint32_t)) {
+ throw ParquetException("Failed to deserialize from input stream");
+ }
+ if (static_cast<Algorithm>(algorithm) != BloomFilter::Algorithm::BLOCK) {
+ throw ParquetException("Unsupported Bloom filter algorithm");
+ }
+
+ BlockSplitBloomFilter bloom_filter;
+
+ PARQUET_ASSIGN_OR_THROW(auto buffer, input->Read(len));
+ bloom_filter.Init(buffer->data(), len);
+ return bloom_filter;
+}
+
+void BlockSplitBloomFilter::WriteTo(ArrowOutputStream* sink) const {
+ DCHECK(sink != nullptr);
+
+ PARQUET_THROW_NOT_OK(
+ sink->Write(reinterpret_cast<const uint8_t*>(&num_bytes_), sizeof(num_bytes_)));
+ PARQUET_THROW_NOT_OK(sink->Write(reinterpret_cast<const uint8_t*>(&hash_strategy_),
+ sizeof(hash_strategy_)));
+ PARQUET_THROW_NOT_OK(
+ sink->Write(reinterpret_cast<const uint8_t*>(&algorithm_), sizeof(algorithm_)));
+ PARQUET_THROW_NOT_OK(sink->Write(data_->mutable_data(), num_bytes_));
+}
+
+void BlockSplitBloomFilter::SetMask(uint32_t key, BlockMask& block_mask) const {
+ for (int i = 0; i < kBitsSetPerBlock; ++i) {
+ block_mask.item[i] = key * SALT[i];
+ }
+
+ for (int i = 0; i < kBitsSetPerBlock; ++i) {
+ block_mask.item[i] = block_mask.item[i] >> 27;
+ }
+
+ for (int i = 0; i < kBitsSetPerBlock; ++i) {
+ block_mask.item[i] = UINT32_C(0x1) << block_mask.item[i];
+ }
+}
+
+bool BlockSplitBloomFilter::FindHash(uint64_t hash) const {
+ const uint32_t bucket_index =
+ static_cast<uint32_t>((hash >> 32) & (num_bytes_ / kBytesPerFilterBlock - 1));
+ uint32_t key = static_cast<uint32_t>(hash);
+ uint32_t* bitset32 = reinterpret_cast<uint32_t*>(data_->mutable_data());
+
+ // Calculate mask for bucket.
+ BlockMask block_mask;
+ SetMask(key, block_mask);
+
+ for (int i = 0; i < kBitsSetPerBlock; ++i) {
+ if (0 == (bitset32[kBitsSetPerBlock * bucket_index + i] & block_mask.item[i])) {
+ return false;
+ }
+ }
+ return true;
+}
+
+void BlockSplitBloomFilter::InsertHash(uint64_t hash) {
+ const uint32_t bucket_index =
+ static_cast<uint32_t>(hash >> 32) & (num_bytes_ / kBytesPerFilterBlock - 1);
+ uint32_t key = static_cast<uint32_t>(hash);
+ uint32_t* bitset32 = reinterpret_cast<uint32_t*>(data_->mutable_data());
+
+ // Calculate mask for bucket.
+ BlockMask block_mask;
+ SetMask(key, block_mask);
+
+ for (int i = 0; i < kBitsSetPerBlock; i++) {
+ bitset32[bucket_index * kBitsSetPerBlock + i] |= block_mask.item[i];
+ }
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/bloom_filter.h b/src/arrow/cpp/src/parquet/bloom_filter.h
new file mode 100644
index 000000000..39f9561ae
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/bloom_filter.h
@@ -0,0 +1,247 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cmath>
+#include <cstdint>
+#include <memory>
+
+#include "arrow/util/bit_util.h"
+#include "arrow/util/logging.h"
+#include "parquet/hasher.h"
+#include "parquet/platform.h"
+#include "parquet/types.h"
+
+namespace parquet {
+
+// A Bloom filter is a compact structure to indicate whether an item is not in a set or
+// probably in a set. The Bloom filter usually consists of a bit set that represents a
+// set of elements, a hash strategy and a Bloom filter algorithm.
+class PARQUET_EXPORT BloomFilter {
+ public:
+ // Maximum Bloom filter size, it sets to HDFS default block size 128MB
+ // This value will be reconsidered when implementing Bloom filter producer.
+ static constexpr uint32_t kMaximumBloomFilterBytes = 128 * 1024 * 1024;
+
+ /// Determine whether an element exist in set or not.
+ ///
+ /// @param hash the element to contain.
+ /// @return false if value is definitely not in set, and true means PROBABLY
+ /// in set.
+ virtual bool FindHash(uint64_t hash) const = 0;
+
+ /// Insert element to set represented by Bloom filter bitset.
+ /// @param hash the hash of value to insert into Bloom filter.
+ virtual void InsertHash(uint64_t hash) = 0;
+
+ /// Write this Bloom filter to an output stream. A Bloom filter structure should
+ /// include bitset length, hash strategy, algorithm, and bitset.
+ ///
+ /// @param sink the output stream to write
+ virtual void WriteTo(ArrowOutputStream* sink) const = 0;
+
+ /// Get the number of bytes of bitset
+ virtual uint32_t GetBitsetSize() const = 0;
+
+ /// Compute hash for 32 bits value by using its plain encoding result.
+ ///
+ /// @param value the value to hash.
+ /// @return hash result.
+ virtual uint64_t Hash(int32_t value) const = 0;
+
+ /// Compute hash for 64 bits value by using its plain encoding result.
+ ///
+ /// @param value the value to hash.
+ /// @return hash result.
+ virtual uint64_t Hash(int64_t value) const = 0;
+
+ /// Compute hash for float value by using its plain encoding result.
+ ///
+ /// @param value the value to hash.
+ /// @return hash result.
+ virtual uint64_t Hash(float value) const = 0;
+
+ /// Compute hash for double value by using its plain encoding result.
+ ///
+ /// @param value the value to hash.
+ /// @return hash result.
+ virtual uint64_t Hash(double value) const = 0;
+
+ /// Compute hash for Int96 value by using its plain encoding result.
+ ///
+ /// @param value the value to hash.
+ /// @return hash result.
+ virtual uint64_t Hash(const Int96* value) const = 0;
+
+ /// Compute hash for ByteArray value by using its plain encoding result.
+ ///
+ /// @param value the value to hash.
+ /// @return hash result.
+ virtual uint64_t Hash(const ByteArray* value) const = 0;
+
+ /// Compute hash for fixed byte array value by using its plain encoding result.
+ ///
+ /// @param value the value address.
+ /// @param len the value length.
+ /// @return hash result.
+ virtual uint64_t Hash(const FLBA* value, uint32_t len) const = 0;
+
+ virtual ~BloomFilter() {}
+
+ protected:
+ // Hash strategy available for Bloom filter.
+ enum class HashStrategy : uint32_t { MURMUR3_X64_128 = 0 };
+
+ // Bloom filter algorithm.
+ enum class Algorithm : uint32_t { BLOCK = 0 };
+};
+
+// The BlockSplitBloomFilter is implemented using block-based Bloom filters from
+// Putze et al.'s "Cache-,Hash- and Space-Efficient Bloom filters". The basic idea is to
+// hash the item to a tiny Bloom filter which size fit a single cache line or smaller.
+//
+// This implementation sets 8 bits in each tiny Bloom filter. Each tiny Bloom
+// filter is 32 bytes to take advantage of 32-byte SIMD instructions.
+class PARQUET_EXPORT BlockSplitBloomFilter : public BloomFilter {
+ public:
+ /// The constructor of BlockSplitBloomFilter. It uses murmur3_x64_128 as hash function.
+ BlockSplitBloomFilter();
+
+ /// Initialize the BlockSplitBloomFilter. The range of num_bytes should be within
+ /// [kMinimumBloomFilterBytes, kMaximumBloomFilterBytes], it will be
+ /// rounded up/down to lower/upper bound if num_bytes is out of range and also
+ /// will be rounded up to a power of 2.
+ ///
+ /// @param num_bytes The number of bytes to store Bloom filter bitset.
+ void Init(uint32_t num_bytes);
+
+ /// Initialize the BlockSplitBloomFilter. It copies the bitset as underlying
+ /// bitset because the given bitset may not satisfy the 32-byte alignment requirement
+ /// which may lead to segfault when performing SIMD instructions. It is the caller's
+ /// responsibility to free the bitset passed in. This is used when reconstructing
+ /// a Bloom filter from a parquet file.
+ ///
+ /// @param bitset The given bitset to initialize the Bloom filter.
+ /// @param num_bytes The number of bytes of given bitset.
+ void Init(const uint8_t* bitset, uint32_t num_bytes);
+
+ // Minimum Bloom filter size, it sets to 32 bytes to fit a tiny Bloom filter.
+ static constexpr uint32_t kMinimumBloomFilterBytes = 32;
+
+ /// Calculate optimal size according to the number of distinct values and false
+ /// positive probability.
+ ///
+ /// @param ndv The number of distinct values.
+ /// @param fpp The false positive probability.
+ /// @return it always return a value between kMinimumBloomFilterBytes and
+ /// kMaximumBloomFilterBytes, and the return value is always a power of 2
+ static uint32_t OptimalNumOfBits(uint32_t ndv, double fpp) {
+ DCHECK(fpp > 0.0 && fpp < 1.0);
+ const double m = -8.0 * ndv / log(1 - pow(fpp, 1.0 / 8));
+ uint32_t num_bits;
+
+ // Handle overflow.
+ if (m < 0 || m > kMaximumBloomFilterBytes << 3) {
+ num_bits = static_cast<uint32_t>(kMaximumBloomFilterBytes << 3);
+ } else {
+ num_bits = static_cast<uint32_t>(m);
+ }
+
+ // Round up to lower bound
+ if (num_bits < kMinimumBloomFilterBytes << 3) {
+ num_bits = kMinimumBloomFilterBytes << 3;
+ }
+
+ // Get next power of 2 if bits is not power of 2.
+ if ((num_bits & (num_bits - 1)) != 0) {
+ num_bits = static_cast<uint32_t>(::arrow::BitUtil::NextPower2(num_bits));
+ }
+
+ // Round down to upper bound
+ if (num_bits > kMaximumBloomFilterBytes << 3) {
+ num_bits = kMaximumBloomFilterBytes << 3;
+ }
+
+ return num_bits;
+ }
+
+ bool FindHash(uint64_t hash) const override;
+ void InsertHash(uint64_t hash) override;
+ void WriteTo(ArrowOutputStream* sink) const override;
+ uint32_t GetBitsetSize() const override { return num_bytes_; }
+
+ uint64_t Hash(int64_t value) const override { return hasher_->Hash(value); }
+ uint64_t Hash(float value) const override { return hasher_->Hash(value); }
+ uint64_t Hash(double value) const override { return hasher_->Hash(value); }
+ uint64_t Hash(const Int96* value) const override { return hasher_->Hash(value); }
+ uint64_t Hash(const ByteArray* value) const override { return hasher_->Hash(value); }
+ uint64_t Hash(int32_t value) const override { return hasher_->Hash(value); }
+ uint64_t Hash(const FLBA* value, uint32_t len) const override {
+ return hasher_->Hash(value, len);
+ }
+
+ /// Deserialize the Bloom filter from an input stream. It is used when reconstructing
+ /// a Bloom filter from a parquet filter.
+ ///
+ /// @param input_stream The input stream from which to construct the Bloom filter
+ /// @return The BlockSplitBloomFilter.
+ static BlockSplitBloomFilter Deserialize(ArrowInputStream* input_stream);
+
+ private:
+ // Bytes in a tiny Bloom filter block.
+ static constexpr int kBytesPerFilterBlock = 32;
+
+ // The number of bits to be set in each tiny Bloom filter
+ static constexpr int kBitsSetPerBlock = 8;
+
+ // A mask structure used to set bits in each tiny Bloom filter.
+ struct BlockMask {
+ uint32_t item[kBitsSetPerBlock];
+ };
+
+ // The block-based algorithm needs eight odd SALT values to calculate eight indexes
+ // of bit to set, one bit in each 32-bit word.
+ static constexpr uint32_t SALT[kBitsSetPerBlock] = {
+ 0x47b6137bU, 0x44974d91U, 0x8824ad5bU, 0xa2b7289dU,
+ 0x705495c7U, 0x2df1424bU, 0x9efc4947U, 0x5c6bfb31U};
+
+ /// Set bits in mask array according to input key.
+ /// @param key the value to calculate mask values.
+ /// @param mask the mask array is used to set inside a block
+ void SetMask(uint32_t key, BlockMask& mask) const;
+
+ // Memory pool to allocate aligned buffer for bitset
+ ::arrow::MemoryPool* pool_;
+
+ // The underlying buffer of bitset.
+ std::shared_ptr<Buffer> data_;
+
+ // The number of bytes of Bloom filter bitset.
+ uint32_t num_bytes_;
+
+ // Hash strategy used in this Bloom filter.
+ HashStrategy hash_strategy_;
+
+ // Algorithm used in this Bloom filter.
+ Algorithm algorithm_;
+
+ // The hash pointer points to actual hash class used.
+ std::unique_ptr<Hasher> hasher_;
+};
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/bloom_filter_test.cc b/src/arrow/cpp/src/parquet/bloom_filter_test.cc
new file mode 100644
index 000000000..23aa4a580
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/bloom_filter_test.cc
@@ -0,0 +1,247 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/io/file.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+
+#include "parquet/bloom_filter.h"
+#include "parquet/exception.h"
+#include "parquet/murmur3.h"
+#include "parquet/platform.h"
+#include "parquet/test_util.h"
+#include "parquet/types.h"
+
+namespace parquet {
+namespace test {
+
+TEST(Murmur3Test, TestBloomFilter) {
+ uint64_t result;
+ const uint8_t bitset[8] = {0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7};
+ ByteArray byteArray(8, bitset);
+ MurmurHash3 murmur3;
+ result = murmur3.Hash(&byteArray);
+ EXPECT_EQ(result, UINT64_C(913737700387071329));
+}
+
+TEST(ConstructorTest, TestBloomFilter) {
+ BlockSplitBloomFilter bloom_filter;
+ EXPECT_NO_THROW(bloom_filter.Init(1000));
+
+ // It throws because the length cannot be zero
+ std::unique_ptr<uint8_t[]> bitset1(new uint8_t[1024]());
+ EXPECT_THROW(bloom_filter.Init(bitset1.get(), 0), ParquetException);
+
+ // It throws because the number of bytes of Bloom filter bitset must be a power of 2.
+ std::unique_ptr<uint8_t[]> bitset2(new uint8_t[1024]());
+ EXPECT_THROW(bloom_filter.Init(bitset2.get(), 1023), ParquetException);
+}
+
+// The BasicTest is used to test basic operations including InsertHash, FindHash and
+// serializing and de-serializing.
+TEST(BasicTest, TestBloomFilter) {
+ BlockSplitBloomFilter bloom_filter;
+ bloom_filter.Init(1024);
+
+ for (int i = 0; i < 10; i++) {
+ bloom_filter.InsertHash(bloom_filter.Hash(i));
+ }
+
+ for (int i = 0; i < 10; i++) {
+ EXPECT_TRUE(bloom_filter.FindHash(bloom_filter.Hash(i)));
+ }
+
+ // Serialize Bloom filter to memory output stream
+ auto sink = CreateOutputStream();
+ bloom_filter.WriteTo(sink.get());
+
+ // Deserialize Bloom filter from memory
+ ASSERT_OK_AND_ASSIGN(auto buffer, sink->Finish());
+ ::arrow::io::BufferReader source(buffer);
+
+ BlockSplitBloomFilter de_bloom = BlockSplitBloomFilter::Deserialize(&source);
+
+ for (int i = 0; i < 10; i++) {
+ EXPECT_TRUE(de_bloom.FindHash(de_bloom.Hash(i)));
+ }
+}
+
+// Helper function to generate random string.
+std::string GetRandomString(uint32_t length) {
+ // Character set used to generate random string
+ const std::string charset =
+ "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
+
+ std::default_random_engine gen(42);
+ std::uniform_int_distribution<uint32_t> dist(0, static_cast<int>(charset.size() - 1));
+ std::string ret(length, 'x');
+
+ for (uint32_t i = 0; i < length; i++) {
+ ret[i] = charset[dist(gen)];
+ }
+ return ret;
+}
+
+TEST(FPPTest, TestBloomFilter) {
+ // It counts the number of times FindHash returns true.
+ int exist = 0;
+
+ // Total count of elements that will be used
+#ifdef PARQUET_VALGRIND
+ const int total_count = 5000;
+#else
+ const int total_count = 100000;
+#endif
+
+ // Bloom filter fpp parameter
+ const double fpp = 0.01;
+
+ std::vector<std::string> members;
+ BlockSplitBloomFilter bloom_filter;
+ bloom_filter.Init(BlockSplitBloomFilter::OptimalNumOfBits(total_count, fpp));
+
+ // Insert elements into the Bloom filter
+ for (int i = 0; i < total_count; i++) {
+ // Insert random string which length is 8
+ std::string tmp = GetRandomString(8);
+ const ByteArray byte_array(8, reinterpret_cast<const uint8_t*>(tmp.c_str()));
+ members.push_back(tmp);
+ bloom_filter.InsertHash(bloom_filter.Hash(&byte_array));
+ }
+
+ for (int i = 0; i < total_count; i++) {
+ const ByteArray byte_array1(8, reinterpret_cast<const uint8_t*>(members[i].c_str()));
+ ASSERT_TRUE(bloom_filter.FindHash(bloom_filter.Hash(&byte_array1)));
+ std::string tmp = GetRandomString(7);
+ const ByteArray byte_array2(7, reinterpret_cast<const uint8_t*>(tmp.c_str()));
+
+ if (bloom_filter.FindHash(bloom_filter.Hash(&byte_array2))) {
+ exist++;
+ }
+ }
+
+ // The exist should be probably less than 1000 according default FPP 0.01.
+ EXPECT_LT(exist, total_count * fpp);
+}
+
+// The CompatibilityTest is used to test cross compatibility with parquet-mr, it reads
+// the Bloom filter binary generated by the Bloom filter class in the parquet-mr project
+// and tests whether the values inserted before could be filtered or not.
+
+// The Bloom filter binary is generated by three steps in from Parquet-mr.
+// Step 1: Construct a Bloom filter with 1024 bytes bitset.
+// Step 2: Insert "hello", "parquet", "bloom", "filter" to Bloom filter.
+// Step 3: Call writeTo API to write to File.
+TEST(CompatibilityTest, TestBloomFilter) {
+ const std::string test_string[4] = {"hello", "parquet", "bloom", "filter"};
+ const std::string bloom_filter_test_binary =
+ std::string(test::get_data_dir()) + "/bloom_filter.bin";
+
+ PARQUET_ASSIGN_OR_THROW(auto handle,
+ ::arrow::io::ReadableFile::Open(bloom_filter_test_binary));
+ PARQUET_ASSIGN_OR_THROW(int64_t size, handle->GetSize());
+
+ // 1024 bytes (bitset) + 4 bytes (hash) + 4 bytes (algorithm) + 4 bytes (length)
+ EXPECT_EQ(size, 1036);
+
+ std::unique_ptr<uint8_t[]> bitset(new uint8_t[size]());
+ PARQUET_ASSIGN_OR_THROW(auto buffer, handle->Read(size));
+
+ ::arrow::io::BufferReader source(buffer);
+ BlockSplitBloomFilter bloom_filter1 = BlockSplitBloomFilter::Deserialize(&source);
+
+ for (int i = 0; i < 4; i++) {
+ const ByteArray tmp(static_cast<uint32_t>(test_string[i].length()),
+ reinterpret_cast<const uint8_t*>(test_string[i].c_str()));
+ EXPECT_TRUE(bloom_filter1.FindHash(bloom_filter1.Hash(&tmp)));
+ }
+
+ // The following is used to check whether the new created Bloom filter in parquet-cpp is
+ // byte-for-byte identical to file at bloom_data_path which is created from parquet-mr
+ // with same inserted hashes.
+ BlockSplitBloomFilter bloom_filter2;
+ bloom_filter2.Init(bloom_filter1.GetBitsetSize());
+ for (int i = 0; i < 4; i++) {
+ const ByteArray byte_array(static_cast<uint32_t>(test_string[i].length()),
+ reinterpret_cast<const uint8_t*>(test_string[i].c_str()));
+ bloom_filter2.InsertHash(bloom_filter2.Hash(&byte_array));
+ }
+
+ // Serialize Bloom filter to memory output stream
+ auto sink = CreateOutputStream();
+ bloom_filter2.WriteTo(sink.get());
+ PARQUET_ASSIGN_OR_THROW(auto buffer1, sink->Finish());
+
+ PARQUET_THROW_NOT_OK(handle->Seek(0));
+ PARQUET_ASSIGN_OR_THROW(size, handle->GetSize());
+ PARQUET_ASSIGN_OR_THROW(auto buffer2, handle->Read(size));
+
+ EXPECT_TRUE((*buffer1).Equals(*buffer2));
+}
+
+// OptimalValueTest is used to test whether OptimalNumOfBits returns expected
+// numbers according to formula:
+// num_of_bits = -8.0 * ndv / log(1 - pow(fpp, 1.0 / 8.0))
+// where ndv is the number of distinct values and fpp is the false positive probability.
+// Also it is used to test whether OptimalNumOfBits returns value between
+// [MINIMUM_BLOOM_FILTER_SIZE, MAXIMUM_BLOOM_FILTER_SIZE].
+TEST(OptimalValueTest, TestBloomFilter) {
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(256, 0.01), UINT32_C(4096));
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(512, 0.01), UINT32_C(8192));
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(1024, 0.01), UINT32_C(16384));
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(2048, 0.01), UINT32_C(32768));
+
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(200, 0.01), UINT32_C(2048));
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(300, 0.01), UINT32_C(4096));
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(700, 0.01), UINT32_C(8192));
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(1500, 0.01), UINT32_C(16384));
+
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(200, 0.025), UINT32_C(2048));
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(300, 0.025), UINT32_C(4096));
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(700, 0.025), UINT32_C(8192));
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(1500, 0.025), UINT32_C(16384));
+
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(200, 0.05), UINT32_C(2048));
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(300, 0.05), UINT32_C(4096));
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(700, 0.05), UINT32_C(8192));
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(1500, 0.05), UINT32_C(16384));
+
+ // Boundary check
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(4, 0.01), UINT32_C(256));
+ EXPECT_EQ(BlockSplitBloomFilter::OptimalNumOfBits(4, 0.25), UINT32_C(256));
+
+ EXPECT_EQ(
+ BlockSplitBloomFilter::OptimalNumOfBits(std::numeric_limits<uint32_t>::max(), 0.01),
+ UINT32_C(1073741824));
+ EXPECT_EQ(
+ BlockSplitBloomFilter::OptimalNumOfBits(std::numeric_limits<uint32_t>::max(), 0.25),
+ UINT32_C(1073741824));
+}
+
+} // namespace test
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/column_io_benchmark.cc b/src/arrow/cpp/src/parquet/column_io_benchmark.cc
new file mode 100644
index 000000000..7f96516ef
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/column_io_benchmark.cc
@@ -0,0 +1,261 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/array.h"
+#include "arrow/io/memory.h"
+#include "arrow/testing/random.h"
+
+#include "parquet/column_reader.h"
+#include "parquet/column_writer.h"
+#include "parquet/file_reader.h"
+#include "parquet/metadata.h"
+#include "parquet/platform.h"
+#include "parquet/thrift_internal.h"
+
+namespace parquet {
+
+using schema::PrimitiveNode;
+
+namespace benchmark {
+
+std::shared_ptr<Int64Writer> BuildWriter(int64_t output_size,
+ const std::shared_ptr<ArrowOutputStream>& dst,
+ ColumnChunkMetaDataBuilder* metadata,
+ ColumnDescriptor* schema,
+ const WriterProperties* properties,
+ Compression::type codec) {
+ std::unique_ptr<PageWriter> pager =
+ PageWriter::Open(dst, codec, Codec::UseDefaultCompressionLevel(), metadata);
+ std::shared_ptr<ColumnWriter> writer =
+ ColumnWriter::Make(metadata, std::move(pager), properties);
+ return std::static_pointer_cast<Int64Writer>(writer);
+}
+
+std::shared_ptr<ColumnDescriptor> Int64Schema(Repetition::type repetition) {
+ auto node = PrimitiveNode::Make("int64", repetition, Type::INT64);
+ return std::make_shared<ColumnDescriptor>(node, repetition != Repetition::REQUIRED,
+ repetition == Repetition::REPEATED);
+}
+
+void SetBytesProcessed(::benchmark::State& state, Repetition::type repetition) {
+ int64_t bytes_processed = state.iterations() * state.range(0) * sizeof(int64_t);
+ if (repetition != Repetition::REQUIRED) {
+ bytes_processed += state.iterations() * state.range(0) * sizeof(int16_t);
+ }
+ if (repetition == Repetition::REPEATED) {
+ bytes_processed += state.iterations() * state.range(0) * sizeof(int16_t);
+ }
+ state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(int16_t));
+}
+
+template <Repetition::type repetition,
+ Compression::type codec = Compression::UNCOMPRESSED>
+static void BM_WriteInt64Column(::benchmark::State& state) {
+ format::ColumnChunk thrift_metadata;
+
+ ::arrow::random::RandomArrayGenerator rgen(1337);
+ auto values = rgen.Int64(state.range(0), 0, 1000000, 0);
+ const auto& i8_values = static_cast<const ::arrow::Int64Array&>(*values);
+
+ std::vector<int16_t> definition_levels(state.range(0), 1);
+ std::vector<int16_t> repetition_levels(state.range(0), 0);
+ std::shared_ptr<ColumnDescriptor> schema = Int64Schema(repetition);
+ std::shared_ptr<WriterProperties> properties = WriterProperties::Builder()
+ .compression(codec)
+ ->encoding(Encoding::PLAIN)
+ ->disable_dictionary()
+ ->build();
+ auto metadata = ColumnChunkMetaDataBuilder::Make(
+ properties, schema.get(), reinterpret_cast<uint8_t*>(&thrift_metadata));
+
+ while (state.KeepRunning()) {
+ auto stream = CreateOutputStream();
+ std::shared_ptr<Int64Writer> writer = BuildWriter(
+ state.range(0), stream, metadata.get(), schema.get(), properties.get(), codec);
+ writer->WriteBatch(i8_values.length(), definition_levels.data(),
+ repetition_levels.data(), i8_values.raw_values());
+ writer->Close();
+ }
+ SetBytesProcessed(state, repetition);
+}
+
+BENCHMARK_TEMPLATE(BM_WriteInt64Column, Repetition::REQUIRED)->Arg(1 << 20);
+BENCHMARK_TEMPLATE(BM_WriteInt64Column, Repetition::OPTIONAL)->Arg(1 << 20);
+BENCHMARK_TEMPLATE(BM_WriteInt64Column, Repetition::REPEATED)->Arg(1 << 20);
+BENCHMARK_TEMPLATE(BM_WriteInt64Column, Repetition::REQUIRED, Compression::SNAPPY)
+ ->Arg(1 << 20);
+BENCHMARK_TEMPLATE(BM_WriteInt64Column, Repetition::OPTIONAL, Compression::SNAPPY)
+ ->Arg(1 << 20);
+BENCHMARK_TEMPLATE(BM_WriteInt64Column, Repetition::REPEATED, Compression::SNAPPY)
+ ->Arg(1 << 20);
+
+BENCHMARK_TEMPLATE(BM_WriteInt64Column, Repetition::REQUIRED, Compression::LZ4)
+ ->Arg(1 << 20);
+BENCHMARK_TEMPLATE(BM_WriteInt64Column, Repetition::OPTIONAL, Compression::LZ4)
+ ->Arg(1 << 20);
+BENCHMARK_TEMPLATE(BM_WriteInt64Column, Repetition::REPEATED, Compression::LZ4)
+ ->Arg(1 << 20);
+
+BENCHMARK_TEMPLATE(BM_WriteInt64Column, Repetition::REQUIRED, Compression::ZSTD)
+ ->Arg(1 << 20);
+BENCHMARK_TEMPLATE(BM_WriteInt64Column, Repetition::OPTIONAL, Compression::ZSTD)
+ ->Arg(1 << 20);
+BENCHMARK_TEMPLATE(BM_WriteInt64Column, Repetition::REPEATED, Compression::ZSTD)
+ ->Arg(1 << 20);
+
+std::shared_ptr<Int64Reader> BuildReader(std::shared_ptr<Buffer>& buffer,
+ int64_t num_values, Compression::type codec,
+ ColumnDescriptor* schema) {
+ auto source = std::make_shared<::arrow::io::BufferReader>(buffer);
+ std::unique_ptr<PageReader> page_reader = PageReader::Open(source, num_values, codec);
+ return std::static_pointer_cast<Int64Reader>(
+ ColumnReader::Make(schema, std::move(page_reader)));
+}
+
+template <Repetition::type repetition,
+ Compression::type codec = Compression::UNCOMPRESSED>
+static void BM_ReadInt64Column(::benchmark::State& state) {
+ format::ColumnChunk thrift_metadata;
+ std::vector<int64_t> values(state.range(0), 128);
+ std::vector<int16_t> definition_levels(state.range(0), 1);
+ std::vector<int16_t> repetition_levels(state.range(0), 0);
+ std::shared_ptr<ColumnDescriptor> schema = Int64Schema(repetition);
+ std::shared_ptr<WriterProperties> properties = WriterProperties::Builder()
+ .compression(codec)
+ ->encoding(Encoding::PLAIN)
+ ->disable_dictionary()
+ ->build();
+
+ auto metadata = ColumnChunkMetaDataBuilder::Make(
+ properties, schema.get(), reinterpret_cast<uint8_t*>(&thrift_metadata));
+
+ auto stream = CreateOutputStream();
+ std::shared_ptr<Int64Writer> writer = BuildWriter(
+ state.range(0), stream, metadata.get(), schema.get(), properties.get(), codec);
+ writer->WriteBatch(values.size(), definition_levels.data(), repetition_levels.data(),
+ values.data());
+ writer->Close();
+
+ PARQUET_ASSIGN_OR_THROW(auto src, stream->Finish());
+ std::vector<int64_t> values_out(state.range(1));
+ std::vector<int16_t> definition_levels_out(state.range(1));
+ std::vector<int16_t> repetition_levels_out(state.range(1));
+ while (state.KeepRunning()) {
+ std::shared_ptr<Int64Reader> reader =
+ BuildReader(src, state.range(1), codec, schema.get());
+ int64_t values_read = 0;
+ for (size_t i = 0; i < values.size(); i += values_read) {
+ reader->ReadBatch(values_out.size(), definition_levels_out.data(),
+ repetition_levels_out.data(), values_out.data(), &values_read);
+ }
+ }
+ SetBytesProcessed(state, repetition);
+}
+
+void ReadColumnSetArgs(::benchmark::internal::Benchmark* bench) {
+ // Small column, tiny reads
+ bench->Args({1024, 16});
+ // Small column, full read
+ bench->Args({1024, 1024});
+ // Midsize column, midsize reads
+ bench->Args({65536, 1024});
+}
+
+BENCHMARK_TEMPLATE(BM_ReadInt64Column, Repetition::REQUIRED)->Apply(ReadColumnSetArgs);
+
+BENCHMARK_TEMPLATE(BM_ReadInt64Column, Repetition::OPTIONAL)->Apply(ReadColumnSetArgs);
+
+BENCHMARK_TEMPLATE(BM_ReadInt64Column, Repetition::REPEATED)->Apply(ReadColumnSetArgs);
+
+BENCHMARK_TEMPLATE(BM_ReadInt64Column, Repetition::REQUIRED, Compression::SNAPPY)
+ ->Apply(ReadColumnSetArgs);
+BENCHMARK_TEMPLATE(BM_ReadInt64Column, Repetition::OPTIONAL, Compression::SNAPPY)
+ ->Apply(ReadColumnSetArgs);
+BENCHMARK_TEMPLATE(BM_ReadInt64Column, Repetition::REPEATED, Compression::SNAPPY)
+ ->Apply(ReadColumnSetArgs);
+
+BENCHMARK_TEMPLATE(BM_ReadInt64Column, Repetition::REQUIRED, Compression::LZ4)
+ ->Apply(ReadColumnSetArgs);
+BENCHMARK_TEMPLATE(BM_ReadInt64Column, Repetition::OPTIONAL, Compression::LZ4)
+ ->Apply(ReadColumnSetArgs);
+BENCHMARK_TEMPLATE(BM_ReadInt64Column, Repetition::REPEATED, Compression::LZ4)
+ ->Apply(ReadColumnSetArgs);
+
+BENCHMARK_TEMPLATE(BM_ReadInt64Column, Repetition::REQUIRED, Compression::ZSTD)
+ ->Apply(ReadColumnSetArgs);
+BENCHMARK_TEMPLATE(BM_ReadInt64Column, Repetition::OPTIONAL, Compression::ZSTD)
+ ->Apply(ReadColumnSetArgs);
+BENCHMARK_TEMPLATE(BM_ReadInt64Column, Repetition::REPEATED, Compression::ZSTD)
+ ->Apply(ReadColumnSetArgs);
+
+static void BM_RleEncoding(::benchmark::State& state) {
+ std::vector<int16_t> levels(state.range(0), 0);
+ int64_t n = 0;
+ std::generate(levels.begin(), levels.end(),
+ [&state, &n] { return (n++ % state.range(1)) == 0; });
+ int16_t max_level = 1;
+ int64_t rle_size = LevelEncoder::MaxBufferSize(Encoding::RLE, max_level,
+ static_cast<int>(levels.size()));
+ auto buffer_rle = AllocateBuffer();
+ PARQUET_THROW_NOT_OK(buffer_rle->Resize(rle_size));
+
+ while (state.KeepRunning()) {
+ LevelEncoder level_encoder;
+ level_encoder.Init(Encoding::RLE, max_level, static_cast<int>(levels.size()),
+ buffer_rle->mutable_data(), static_cast<int>(buffer_rle->size()));
+ level_encoder.Encode(static_cast<int>(levels.size()), levels.data());
+ }
+ state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(int16_t));
+ state.SetItemsProcessed(state.iterations() * state.range(0));
+}
+
+BENCHMARK(BM_RleEncoding)->RangePair(1024, 65536, 1, 16);
+
+static void BM_RleDecoding(::benchmark::State& state) {
+ LevelEncoder level_encoder;
+ std::vector<int16_t> levels(state.range(0), 0);
+ int64_t n = 0;
+ std::generate(levels.begin(), levels.end(),
+ [&state, &n] { return (n++ % state.range(1)) == 0; });
+ int16_t max_level = 1;
+ int rle_size = LevelEncoder::MaxBufferSize(Encoding::RLE, max_level,
+ static_cast<int>(levels.size()));
+ auto buffer_rle = AllocateBuffer();
+ PARQUET_THROW_NOT_OK(buffer_rle->Resize(rle_size + sizeof(int32_t)));
+ level_encoder.Init(Encoding::RLE, max_level, static_cast<int>(levels.size()),
+ buffer_rle->mutable_data() + sizeof(int32_t), rle_size);
+ level_encoder.Encode(static_cast<int>(levels.size()), levels.data());
+ reinterpret_cast<int32_t*>(buffer_rle->mutable_data())[0] = level_encoder.len();
+
+ while (state.KeepRunning()) {
+ LevelDecoder level_decoder;
+ level_decoder.SetData(Encoding::RLE, max_level, static_cast<int>(levels.size()),
+ buffer_rle->data(), rle_size);
+ level_decoder.Decode(static_cast<int>(state.range(0)), levels.data());
+ }
+
+ state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(int16_t));
+ state.SetItemsProcessed(state.iterations() * state.range(0));
+}
+
+BENCHMARK(BM_RleDecoding)->RangePair(1024, 65536, 1, 16);
+
+} // namespace benchmark
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/column_page.h b/src/arrow/cpp/src/parquet/column_page.h
new file mode 100644
index 000000000..2fab77ed0
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/column_page.h
@@ -0,0 +1,160 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This module defines an abstract interface for iterating through pages in a
+// Parquet column chunk within a row group. It could be extended in the future
+// to iterate through all data pages in all chunks in a file.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "parquet/statistics.h"
+#include "parquet/types.h"
+
+namespace parquet {
+
+// TODO: Parallel processing is not yet safe because of memory-ownership
+// semantics (the PageReader may or may not own the memory referenced by a
+// page)
+//
+// TODO(wesm): In the future Parquet implementations may store the crc code
+// in format::PageHeader. parquet-mr currently does not, so we also skip it
+// here, both on the read and write path
+class Page {
+ public:
+ Page(const std::shared_ptr<Buffer>& buffer, PageType::type type)
+ : buffer_(buffer), type_(type) {}
+
+ PageType::type type() const { return type_; }
+
+ std::shared_ptr<Buffer> buffer() const { return buffer_; }
+
+ // @returns: a pointer to the page's data
+ const uint8_t* data() const { return buffer_->data(); }
+
+ // @returns: the total size in bytes of the page's data buffer
+ int32_t size() const { return static_cast<int32_t>(buffer_->size()); }
+
+ private:
+ std::shared_ptr<Buffer> buffer_;
+ PageType::type type_;
+};
+
+/// \brief Base type for DataPageV1 and DataPageV2 including common attributes
+class DataPage : public Page {
+ public:
+ int32_t num_values() const { return num_values_; }
+ Encoding::type encoding() const { return encoding_; }
+ int64_t uncompressed_size() const { return uncompressed_size_; }
+ const EncodedStatistics& statistics() const { return statistics_; }
+
+ virtual ~DataPage() = default;
+
+ protected:
+ DataPage(PageType::type type, const std::shared_ptr<Buffer>& buffer, int32_t num_values,
+ Encoding::type encoding, int64_t uncompressed_size,
+ const EncodedStatistics& statistics = EncodedStatistics())
+ : Page(buffer, type),
+ num_values_(num_values),
+ encoding_(encoding),
+ uncompressed_size_(uncompressed_size),
+ statistics_(statistics) {}
+
+ int32_t num_values_;
+ Encoding::type encoding_;
+ int64_t uncompressed_size_;
+ EncodedStatistics statistics_;
+};
+
+class DataPageV1 : public DataPage {
+ public:
+ DataPageV1(const std::shared_ptr<Buffer>& buffer, int32_t num_values,
+ Encoding::type encoding, Encoding::type definition_level_encoding,
+ Encoding::type repetition_level_encoding, int64_t uncompressed_size,
+ const EncodedStatistics& statistics = EncodedStatistics())
+ : DataPage(PageType::DATA_PAGE, buffer, num_values, encoding, uncompressed_size,
+ statistics),
+ definition_level_encoding_(definition_level_encoding),
+ repetition_level_encoding_(repetition_level_encoding) {}
+
+ Encoding::type repetition_level_encoding() const { return repetition_level_encoding_; }
+
+ Encoding::type definition_level_encoding() const { return definition_level_encoding_; }
+
+ private:
+ Encoding::type definition_level_encoding_;
+ Encoding::type repetition_level_encoding_;
+};
+
+class DataPageV2 : public DataPage {
+ public:
+ DataPageV2(const std::shared_ptr<Buffer>& buffer, int32_t num_values, int32_t num_nulls,
+ int32_t num_rows, Encoding::type encoding,
+ int32_t definition_levels_byte_length, int32_t repetition_levels_byte_length,
+ int64_t uncompressed_size, bool is_compressed = false,
+ const EncodedStatistics& statistics = EncodedStatistics())
+ : DataPage(PageType::DATA_PAGE_V2, buffer, num_values, encoding, uncompressed_size,
+ statistics),
+ num_nulls_(num_nulls),
+ num_rows_(num_rows),
+ definition_levels_byte_length_(definition_levels_byte_length),
+ repetition_levels_byte_length_(repetition_levels_byte_length),
+ is_compressed_(is_compressed) {}
+
+ int32_t num_nulls() const { return num_nulls_; }
+
+ int32_t num_rows() const { return num_rows_; }
+
+ int32_t definition_levels_byte_length() const { return definition_levels_byte_length_; }
+
+ int32_t repetition_levels_byte_length() const { return repetition_levels_byte_length_; }
+
+ bool is_compressed() const { return is_compressed_; }
+
+ private:
+ int32_t num_nulls_;
+ int32_t num_rows_;
+ int32_t definition_levels_byte_length_;
+ int32_t repetition_levels_byte_length_;
+ bool is_compressed_;
+};
+
+class DictionaryPage : public Page {
+ public:
+ DictionaryPage(const std::shared_ptr<Buffer>& buffer, int32_t num_values,
+ Encoding::type encoding, bool is_sorted = false)
+ : Page(buffer, PageType::DICTIONARY_PAGE),
+ num_values_(num_values),
+ encoding_(encoding),
+ is_sorted_(is_sorted) {}
+
+ int32_t num_values() const { return num_values_; }
+
+ Encoding::type encoding() const { return encoding_; }
+
+ bool is_sorted() const { return is_sorted_; }
+
+ private:
+ int32_t num_values_;
+ Encoding::type encoding_;
+ bool is_sorted_;
+};
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/column_reader.cc b/src/arrow/cpp/src/parquet/column_reader.cc
new file mode 100644
index 000000000..c7ad78c10
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/column_reader.cc
@@ -0,0 +1,1808 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/column_reader.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <exception>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_dict.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/chunked_array.h"
+#include "arrow/type.h"
+#include "arrow/util/bit_stream_utils.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/compression.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/rle_encoding.h"
+#include "parquet/column_page.h"
+#include "parquet/encoding.h"
+#include "parquet/encryption/encryption_internal.h"
+#include "parquet/encryption/internal_file_decryptor.h"
+#include "parquet/level_comparison.h"
+#include "parquet/level_conversion.h"
+#include "parquet/properties.h"
+#include "parquet/statistics.h"
+#include "parquet/thrift_internal.h" // IWYU pragma: keep
+// Required after "arrow/util/int_util_internal.h" (for OPTIONAL)
+#include "parquet/windows_compatibility.h"
+
+using arrow::MemoryPool;
+using arrow::internal::AddWithOverflow;
+using arrow::internal::checked_cast;
+using arrow::internal::MultiplyWithOverflow;
+
+namespace BitUtil = arrow::BitUtil;
+
+namespace parquet {
+namespace {
+inline bool HasSpacedValues(const ColumnDescriptor* descr) {
+ if (descr->max_repetition_level() > 0) {
+ // repeated+flat case
+ return !descr->schema_node()->is_required();
+ } else {
+ // non-repeated+nested case
+ // Find if a node forces nulls in the lowest level along the hierarchy
+ const schema::Node* node = descr->schema_node().get();
+ while (node) {
+ if (node->is_optional()) {
+ return true;
+ }
+ node = node->parent();
+ }
+ return false;
+ }
+}
+} // namespace
+
+LevelDecoder::LevelDecoder() : num_values_remaining_(0) {}
+
+LevelDecoder::~LevelDecoder() {}
+
+int LevelDecoder::SetData(Encoding::type encoding, int16_t max_level,
+ int num_buffered_values, const uint8_t* data,
+ int32_t data_size) {
+ max_level_ = max_level;
+ int32_t num_bytes = 0;
+ encoding_ = encoding;
+ num_values_remaining_ = num_buffered_values;
+ bit_width_ = BitUtil::Log2(max_level + 1);
+ switch (encoding) {
+ case Encoding::RLE: {
+ if (data_size < 4) {
+ throw ParquetException("Received invalid levels (corrupt data page?)");
+ }
+ num_bytes = ::arrow::util::SafeLoadAs<int32_t>(data);
+ if (num_bytes < 0 || num_bytes > data_size - 4) {
+ throw ParquetException("Received invalid number of bytes (corrupt data page?)");
+ }
+ const uint8_t* decoder_data = data + 4;
+ if (!rle_decoder_) {
+ rle_decoder_.reset(
+ new ::arrow::util::RleDecoder(decoder_data, num_bytes, bit_width_));
+ } else {
+ rle_decoder_->Reset(decoder_data, num_bytes, bit_width_);
+ }
+ return 4 + num_bytes;
+ }
+ case Encoding::BIT_PACKED: {
+ int num_bits = 0;
+ if (MultiplyWithOverflow(num_buffered_values, bit_width_, &num_bits)) {
+ throw ParquetException(
+ "Number of buffered values too large (corrupt data page?)");
+ }
+ num_bytes = static_cast<int32_t>(BitUtil::BytesForBits(num_bits));
+ if (num_bytes < 0 || num_bytes > data_size - 4) {
+ throw ParquetException("Received invalid number of bytes (corrupt data page?)");
+ }
+ if (!bit_packed_decoder_) {
+ bit_packed_decoder_.reset(new ::arrow::BitUtil::BitReader(data, num_bytes));
+ } else {
+ bit_packed_decoder_->Reset(data, num_bytes);
+ }
+ return num_bytes;
+ }
+ default:
+ throw ParquetException("Unknown encoding type for levels.");
+ }
+ return -1;
+}
+
+void LevelDecoder::SetDataV2(int32_t num_bytes, int16_t max_level,
+ int num_buffered_values, const uint8_t* data) {
+ max_level_ = max_level;
+ // Repetition and definition levels always uses RLE encoding
+ // in the DataPageV2 format.
+ if (num_bytes < 0) {
+ throw ParquetException("Invalid page header (corrupt data page?)");
+ }
+ encoding_ = Encoding::RLE;
+ num_values_remaining_ = num_buffered_values;
+ bit_width_ = BitUtil::Log2(max_level + 1);
+
+ if (!rle_decoder_) {
+ rle_decoder_.reset(new ::arrow::util::RleDecoder(data, num_bytes, bit_width_));
+ } else {
+ rle_decoder_->Reset(data, num_bytes, bit_width_);
+ }
+}
+
+int LevelDecoder::Decode(int batch_size, int16_t* levels) {
+ int num_decoded = 0;
+
+ int num_values = std::min(num_values_remaining_, batch_size);
+ if (encoding_ == Encoding::RLE) {
+ num_decoded = rle_decoder_->GetBatch(levels, num_values);
+ } else {
+ num_decoded = bit_packed_decoder_->GetBatch(bit_width_, levels, num_values);
+ }
+ if (num_decoded > 0) {
+ internal::MinMax min_max = internal::FindMinMax(levels, num_decoded);
+ if (ARROW_PREDICT_FALSE(min_max.min < 0 || min_max.max > max_level_)) {
+ std::stringstream ss;
+ ss << "Malformed levels. min: " << min_max.min << " max: " << min_max.max
+ << " out of range. Max Level: " << max_level_;
+ throw ParquetException(ss.str());
+ }
+ }
+ num_values_remaining_ -= num_decoded;
+ return num_decoded;
+}
+
+ReaderProperties default_reader_properties() {
+ static ReaderProperties default_reader_properties;
+ return default_reader_properties;
+}
+
+namespace {
+
+// Extracts encoded statistics from V1 and V2 data page headers
+template <typename H>
+EncodedStatistics ExtractStatsFromHeader(const H& header) {
+ EncodedStatistics page_statistics;
+ if (!header.__isset.statistics) {
+ return page_statistics;
+ }
+ const format::Statistics& stats = header.statistics;
+ if (stats.__isset.max) {
+ page_statistics.set_max(stats.max);
+ }
+ if (stats.__isset.min) {
+ page_statistics.set_min(stats.min);
+ }
+ if (stats.__isset.null_count) {
+ page_statistics.set_null_count(stats.null_count);
+ }
+ if (stats.__isset.distinct_count) {
+ page_statistics.set_distinct_count(stats.distinct_count);
+ }
+ return page_statistics;
+}
+
+// ----------------------------------------------------------------------
+// SerializedPageReader deserializes Thrift metadata and pages that have been
+// assembled in a serialized stream for storing in a Parquet files
+
+// This subclass delimits pages appearing in a serialized stream, each preceded
+// by a serialized Thrift format::PageHeader indicating the type of each page
+// and the page metadata.
+class SerializedPageReader : public PageReader {
+ public:
+ SerializedPageReader(std::shared_ptr<ArrowInputStream> stream, int64_t total_num_rows,
+ Compression::type codec, ::arrow::MemoryPool* pool,
+ const CryptoContext* crypto_ctx)
+ : stream_(std::move(stream)),
+ decompression_buffer_(AllocateBuffer(pool, 0)),
+ page_ordinal_(0),
+ seen_num_rows_(0),
+ total_num_rows_(total_num_rows),
+ decryption_buffer_(AllocateBuffer(pool, 0)) {
+ if (crypto_ctx != nullptr) {
+ crypto_ctx_ = *crypto_ctx;
+ InitDecryption();
+ }
+ max_page_header_size_ = kDefaultMaxPageHeaderSize;
+ decompressor_ = GetCodec(codec);
+ }
+
+ // Implement the PageReader interface
+ std::shared_ptr<Page> NextPage() override;
+
+ void set_max_page_header_size(uint32_t size) override { max_page_header_size_ = size; }
+
+ private:
+ void UpdateDecryption(const std::shared_ptr<Decryptor>& decryptor, int8_t module_type,
+ const std::string& page_aad);
+
+ void InitDecryption();
+
+ std::shared_ptr<Buffer> DecompressIfNeeded(std::shared_ptr<Buffer> page_buffer,
+ int compressed_len, int uncompressed_len,
+ int levels_byte_len = 0);
+
+ std::shared_ptr<ArrowInputStream> stream_;
+
+ format::PageHeader current_page_header_;
+ std::shared_ptr<Page> current_page_;
+
+ // Compression codec to use.
+ std::unique_ptr<::arrow::util::Codec> decompressor_;
+ std::shared_ptr<ResizableBuffer> decompression_buffer_;
+
+ // The fields below are used for calculation of AAD (additional authenticated data)
+ // suffix which is part of the Parquet Modular Encryption.
+ // The AAD suffix for a parquet module is built internally by
+ // concatenating different parts some of which include
+ // the row group ordinal, column ordinal and page ordinal.
+ // Please refer to the encryption specification for more details:
+ // https://github.com/apache/parquet-format/blob/encryption/Encryption.md#44-additional-authenticated-data
+
+ // The ordinal fields in the context below are used for AAD suffix calculation.
+ CryptoContext crypto_ctx_;
+ int16_t page_ordinal_; // page ordinal does not count the dictionary page
+
+ // Maximum allowed page size
+ uint32_t max_page_header_size_;
+
+ // Number of rows read in data pages so far
+ int64_t seen_num_rows_;
+
+ // Number of rows in all the data pages
+ int64_t total_num_rows_;
+
+ // data_page_aad_ and data_page_header_aad_ contain the AAD for data page and data page
+ // header in a single column respectively.
+ // While calculating AAD for different pages in a single column the pages AAD is
+ // updated by only the page ordinal.
+ std::string data_page_aad_;
+ std::string data_page_header_aad_;
+ // Encryption
+ std::shared_ptr<ResizableBuffer> decryption_buffer_;
+};
+
+void SerializedPageReader::InitDecryption() {
+ // Prepare the AAD for quick update later.
+ if (crypto_ctx_.data_decryptor != nullptr) {
+ DCHECK(!crypto_ctx_.data_decryptor->file_aad().empty());
+ data_page_aad_ = encryption::CreateModuleAad(
+ crypto_ctx_.data_decryptor->file_aad(), encryption::kDataPage,
+ crypto_ctx_.row_group_ordinal, crypto_ctx_.column_ordinal, kNonPageOrdinal);
+ }
+ if (crypto_ctx_.meta_decryptor != nullptr) {
+ DCHECK(!crypto_ctx_.meta_decryptor->file_aad().empty());
+ data_page_header_aad_ = encryption::CreateModuleAad(
+ crypto_ctx_.meta_decryptor->file_aad(), encryption::kDataPageHeader,
+ crypto_ctx_.row_group_ordinal, crypto_ctx_.column_ordinal, kNonPageOrdinal);
+ }
+}
+
+void SerializedPageReader::UpdateDecryption(const std::shared_ptr<Decryptor>& decryptor,
+ int8_t module_type,
+ const std::string& page_aad) {
+ DCHECK(decryptor != nullptr);
+ if (crypto_ctx_.start_decrypt_with_dictionary_page) {
+ std::string aad = encryption::CreateModuleAad(
+ decryptor->file_aad(), module_type, crypto_ctx_.row_group_ordinal,
+ crypto_ctx_.column_ordinal, kNonPageOrdinal);
+ decryptor->UpdateAad(aad);
+ } else {
+ encryption::QuickUpdatePageAad(page_aad, page_ordinal_);
+ decryptor->UpdateAad(page_aad);
+ }
+}
+
+std::shared_ptr<Page> SerializedPageReader::NextPage() {
+ // Loop here because there may be unhandled page types that we skip until
+ // finding a page that we do know what to do with
+
+ while (seen_num_rows_ < total_num_rows_) {
+ uint32_t header_size = 0;
+ uint32_t allowed_page_size = kDefaultPageHeaderSize;
+
+ // Page headers can be very large because of page statistics
+ // We try to deserialize a larger buffer progressively
+ // until a maximum allowed header limit
+ while (true) {
+ PARQUET_ASSIGN_OR_THROW(auto view, stream_->Peek(allowed_page_size));
+ if (view.size() == 0) {
+ return std::shared_ptr<Page>(nullptr);
+ }
+
+ // This gets used, then set by DeserializeThriftMsg
+ header_size = static_cast<uint32_t>(view.size());
+ try {
+ if (crypto_ctx_.meta_decryptor != nullptr) {
+ UpdateDecryption(crypto_ctx_.meta_decryptor, encryption::kDictionaryPageHeader,
+ data_page_header_aad_);
+ }
+ DeserializeThriftMsg(reinterpret_cast<const uint8_t*>(view.data()), &header_size,
+ &current_page_header_, crypto_ctx_.meta_decryptor);
+ break;
+ } catch (std::exception& e) {
+ // Failed to deserialize. Double the allowed page header size and try again
+ std::stringstream ss;
+ ss << e.what();
+ allowed_page_size *= 2;
+ if (allowed_page_size > max_page_header_size_) {
+ ss << "Deserializing page header failed.\n";
+ throw ParquetException(ss.str());
+ }
+ }
+ }
+ // Advance the stream offset
+ PARQUET_THROW_NOT_OK(stream_->Advance(header_size));
+
+ int compressed_len = current_page_header_.compressed_page_size;
+ int uncompressed_len = current_page_header_.uncompressed_page_size;
+ if (compressed_len < 0 || uncompressed_len < 0) {
+ throw ParquetException("Invalid page header");
+ }
+
+ if (crypto_ctx_.data_decryptor != nullptr) {
+ UpdateDecryption(crypto_ctx_.data_decryptor, encryption::kDictionaryPage,
+ data_page_aad_);
+ }
+
+ // Read the compressed data page.
+ PARQUET_ASSIGN_OR_THROW(auto page_buffer, stream_->Read(compressed_len));
+ if (page_buffer->size() != compressed_len) {
+ std::stringstream ss;
+ ss << "Page was smaller (" << page_buffer->size() << ") than expected ("
+ << compressed_len << ")";
+ ParquetException::EofException(ss.str());
+ }
+
+ // Decrypt it if we need to
+ if (crypto_ctx_.data_decryptor != nullptr) {
+ PARQUET_THROW_NOT_OK(decryption_buffer_->Resize(
+ compressed_len - crypto_ctx_.data_decryptor->CiphertextSizeDelta(), false));
+ compressed_len = crypto_ctx_.data_decryptor->Decrypt(
+ page_buffer->data(), compressed_len, decryption_buffer_->mutable_data());
+
+ page_buffer = decryption_buffer_;
+ }
+
+ const PageType::type page_type = LoadEnumSafe(&current_page_header_.type);
+
+ if (page_type == PageType::DICTIONARY_PAGE) {
+ crypto_ctx_.start_decrypt_with_dictionary_page = false;
+ const format::DictionaryPageHeader& dict_header =
+ current_page_header_.dictionary_page_header;
+
+ bool is_sorted = dict_header.__isset.is_sorted ? dict_header.is_sorted : false;
+ if (dict_header.num_values < 0) {
+ throw ParquetException("Invalid page header (negative number of values)");
+ }
+
+ // Uncompress if needed
+ page_buffer =
+ DecompressIfNeeded(std::move(page_buffer), compressed_len, uncompressed_len);
+
+ return std::make_shared<DictionaryPage>(page_buffer, dict_header.num_values,
+ LoadEnumSafe(&dict_header.encoding),
+ is_sorted);
+ } else if (page_type == PageType::DATA_PAGE) {
+ ++page_ordinal_;
+ const format::DataPageHeader& header = current_page_header_.data_page_header;
+
+ if (header.num_values < 0) {
+ throw ParquetException("Invalid page header (negative number of values)");
+ }
+ EncodedStatistics page_statistics = ExtractStatsFromHeader(header);
+ seen_num_rows_ += header.num_values;
+
+ // Uncompress if needed
+ page_buffer =
+ DecompressIfNeeded(std::move(page_buffer), compressed_len, uncompressed_len);
+
+ return std::make_shared<DataPageV1>(page_buffer, header.num_values,
+ LoadEnumSafe(&header.encoding),
+ LoadEnumSafe(&header.definition_level_encoding),
+ LoadEnumSafe(&header.repetition_level_encoding),
+ uncompressed_len, page_statistics);
+ } else if (page_type == PageType::DATA_PAGE_V2) {
+ ++page_ordinal_;
+ const format::DataPageHeaderV2& header = current_page_header_.data_page_header_v2;
+
+ if (header.num_values < 0) {
+ throw ParquetException("Invalid page header (negative number of values)");
+ }
+ if (header.definition_levels_byte_length < 0 ||
+ header.repetition_levels_byte_length < 0) {
+ throw ParquetException("Invalid page header (negative levels byte length)");
+ }
+ bool is_compressed = header.__isset.is_compressed ? header.is_compressed : false;
+ EncodedStatistics page_statistics = ExtractStatsFromHeader(header);
+ seen_num_rows_ += header.num_values;
+
+ // Uncompress if needed
+ int levels_byte_len;
+ if (AddWithOverflow(header.definition_levels_byte_length,
+ header.repetition_levels_byte_length, &levels_byte_len)) {
+ throw ParquetException("Levels size too large (corrupt file?)");
+ }
+ // DecompressIfNeeded doesn't take `is_compressed` into account as
+ // it's page type-agnostic.
+ if (is_compressed) {
+ page_buffer = DecompressIfNeeded(std::move(page_buffer), compressed_len,
+ uncompressed_len, levels_byte_len);
+ }
+
+ return std::make_shared<DataPageV2>(
+ page_buffer, header.num_values, header.num_nulls, header.num_rows,
+ LoadEnumSafe(&header.encoding), header.definition_levels_byte_length,
+ header.repetition_levels_byte_length, uncompressed_len, is_compressed,
+ page_statistics);
+ } else {
+ // We don't know what this page type is. We're allowed to skip non-data
+ // pages.
+ continue;
+ }
+ }
+ return std::shared_ptr<Page>(nullptr);
+}
+
+std::shared_ptr<Buffer> SerializedPageReader::DecompressIfNeeded(
+ std::shared_ptr<Buffer> page_buffer, int compressed_len, int uncompressed_len,
+ int levels_byte_len) {
+ if (decompressor_ == nullptr) {
+ return page_buffer;
+ }
+ if (compressed_len < levels_byte_len || uncompressed_len < levels_byte_len) {
+ throw ParquetException("Invalid page header");
+ }
+
+ // Grow the uncompressed buffer if we need to.
+ if (uncompressed_len > static_cast<int>(decompression_buffer_->size())) {
+ PARQUET_THROW_NOT_OK(decompression_buffer_->Resize(uncompressed_len, false));
+ }
+
+ if (levels_byte_len > 0) {
+ // First copy the levels as-is
+ uint8_t* decompressed = decompression_buffer_->mutable_data();
+ memcpy(decompressed, page_buffer->data(), levels_byte_len);
+ }
+
+ // Decompress the values
+ PARQUET_THROW_NOT_OK(decompressor_->Decompress(
+ compressed_len - levels_byte_len, page_buffer->data() + levels_byte_len,
+ uncompressed_len - levels_byte_len,
+ decompression_buffer_->mutable_data() + levels_byte_len));
+
+ return decompression_buffer_;
+}
+
+} // namespace
+
+std::unique_ptr<PageReader> PageReader::Open(std::shared_ptr<ArrowInputStream> stream,
+ int64_t total_num_rows,
+ Compression::type codec,
+ ::arrow::MemoryPool* pool,
+ const CryptoContext* ctx) {
+ return std::unique_ptr<PageReader>(
+ new SerializedPageReader(std::move(stream), total_num_rows, codec, pool, ctx));
+}
+
+namespace {
+
+// ----------------------------------------------------------------------
+// Impl base class for TypedColumnReader and RecordReader
+
+// PLAIN_DICTIONARY is deprecated but used to be used as a dictionary index
+// encoding.
+static bool IsDictionaryIndexEncoding(const Encoding::type& e) {
+ return e == Encoding::RLE_DICTIONARY || e == Encoding::PLAIN_DICTIONARY;
+}
+
+template <typename DType>
+class ColumnReaderImplBase {
+ public:
+ using T = typename DType::c_type;
+
+ ColumnReaderImplBase(const ColumnDescriptor* descr, ::arrow::MemoryPool* pool)
+ : descr_(descr),
+ max_def_level_(descr->max_definition_level()),
+ max_rep_level_(descr->max_repetition_level()),
+ num_buffered_values_(0),
+ num_decoded_values_(0),
+ pool_(pool),
+ current_decoder_(nullptr),
+ current_encoding_(Encoding::UNKNOWN) {}
+
+ virtual ~ColumnReaderImplBase() = default;
+
+ protected:
+ // Read up to batch_size values from the current data page into the
+ // pre-allocated memory T*
+ //
+ // @returns: the number of values read into the out buffer
+ int64_t ReadValues(int64_t batch_size, T* out) {
+ int64_t num_decoded = current_decoder_->Decode(out, static_cast<int>(batch_size));
+ return num_decoded;
+ }
+
+ // Read up to batch_size values from the current data page into the
+ // pre-allocated memory T*, leaving spaces for null entries according
+ // to the def_levels.
+ //
+ // @returns: the number of values read into the out buffer
+ int64_t ReadValuesSpaced(int64_t batch_size, T* out, int64_t null_count,
+ uint8_t* valid_bits, int64_t valid_bits_offset) {
+ return current_decoder_->DecodeSpaced(out, static_cast<int>(batch_size),
+ static_cast<int>(null_count), valid_bits,
+ valid_bits_offset);
+ }
+
+ // Read multiple definition levels into preallocated memory
+ //
+ // Returns the number of decoded definition levels
+ int64_t ReadDefinitionLevels(int64_t batch_size, int16_t* levels) {
+ if (max_def_level_ == 0) {
+ return 0;
+ }
+ return definition_level_decoder_.Decode(static_cast<int>(batch_size), levels);
+ }
+
+ bool HasNextInternal() {
+ // Either there is no data page available yet, or the data page has been
+ // exhausted
+ if (num_buffered_values_ == 0 || num_decoded_values_ == num_buffered_values_) {
+ if (!ReadNewPage() || num_buffered_values_ == 0) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ // Read multiple repetition levels into preallocated memory
+ // Returns the number of decoded repetition levels
+ int64_t ReadRepetitionLevels(int64_t batch_size, int16_t* levels) {
+ if (max_rep_level_ == 0) {
+ return 0;
+ }
+ return repetition_level_decoder_.Decode(static_cast<int>(batch_size), levels);
+ }
+
+ // Advance to the next data page
+ bool ReadNewPage() {
+ // Loop until we find the next data page.
+ while (true) {
+ current_page_ = pager_->NextPage();
+ if (!current_page_) {
+ // EOS
+ return false;
+ }
+
+ if (current_page_->type() == PageType::DICTIONARY_PAGE) {
+ ConfigureDictionary(static_cast<const DictionaryPage*>(current_page_.get()));
+ continue;
+ } else if (current_page_->type() == PageType::DATA_PAGE) {
+ const auto page = std::static_pointer_cast<DataPageV1>(current_page_);
+ const int64_t levels_byte_size = InitializeLevelDecoders(
+ *page, page->repetition_level_encoding(), page->definition_level_encoding());
+ InitializeDataDecoder(*page, levels_byte_size);
+ return true;
+ } else if (current_page_->type() == PageType::DATA_PAGE_V2) {
+ const auto page = std::static_pointer_cast<DataPageV2>(current_page_);
+ int64_t levels_byte_size = InitializeLevelDecodersV2(*page);
+ InitializeDataDecoder(*page, levels_byte_size);
+ return true;
+ } else {
+ // We don't know what this page type is. We're allowed to skip non-data
+ // pages.
+ continue;
+ }
+ }
+ return true;
+ }
+
+ void ConfigureDictionary(const DictionaryPage* page) {
+ int encoding = static_cast<int>(page->encoding());
+ if (page->encoding() == Encoding::PLAIN_DICTIONARY ||
+ page->encoding() == Encoding::PLAIN) {
+ encoding = static_cast<int>(Encoding::RLE_DICTIONARY);
+ }
+
+ auto it = decoders_.find(encoding);
+ if (it != decoders_.end()) {
+ throw ParquetException("Column cannot have more than one dictionary.");
+ }
+
+ if (page->encoding() == Encoding::PLAIN_DICTIONARY ||
+ page->encoding() == Encoding::PLAIN) {
+ auto dictionary = MakeTypedDecoder<DType>(Encoding::PLAIN, descr_);
+ dictionary->SetData(page->num_values(), page->data(), page->size());
+
+ // The dictionary is fully decoded during DictionaryDecoder::Init, so the
+ // DictionaryPage buffer is no longer required after this step
+ //
+ // TODO(wesm): investigate whether this all-or-nothing decoding of the
+ // dictionary makes sense and whether performance can be improved
+
+ std::unique_ptr<DictDecoder<DType>> decoder = MakeDictDecoder<DType>(descr_, pool_);
+ decoder->SetDict(dictionary.get());
+ decoders_[encoding] =
+ std::unique_ptr<DecoderType>(dynamic_cast<DecoderType*>(decoder.release()));
+ } else {
+ ParquetException::NYI("only plain dictionary encoding has been implemented");
+ }
+
+ new_dictionary_ = true;
+ current_decoder_ = decoders_[encoding].get();
+ DCHECK(current_decoder_);
+ }
+
+ // Initialize repetition and definition level decoders on the next data page.
+
+ // If the data page includes repetition and definition levels, we
+ // initialize the level decoders and return the number of encoded level bytes.
+ // The return value helps determine the number of bytes in the encoded data.
+ int64_t InitializeLevelDecoders(const DataPage& page,
+ Encoding::type repetition_level_encoding,
+ Encoding::type definition_level_encoding) {
+ // Read a data page.
+ num_buffered_values_ = page.num_values();
+
+ // Have not decoded any values from the data page yet
+ num_decoded_values_ = 0;
+
+ const uint8_t* buffer = page.data();
+ int32_t levels_byte_size = 0;
+ int32_t max_size = page.size();
+
+ // Data page Layout: Repetition Levels - Definition Levels - encoded values.
+ // Levels are encoded as rle or bit-packed.
+ // Init repetition levels
+ if (max_rep_level_ > 0) {
+ int32_t rep_levels_bytes = repetition_level_decoder_.SetData(
+ repetition_level_encoding, max_rep_level_,
+ static_cast<int>(num_buffered_values_), buffer, max_size);
+ buffer += rep_levels_bytes;
+ levels_byte_size += rep_levels_bytes;
+ max_size -= rep_levels_bytes;
+ }
+ // TODO figure a way to set max_def_level_ to 0
+ // if the initial value is invalid
+
+ // Init definition levels
+ if (max_def_level_ > 0) {
+ int32_t def_levels_bytes = definition_level_decoder_.SetData(
+ definition_level_encoding, max_def_level_,
+ static_cast<int>(num_buffered_values_), buffer, max_size);
+ levels_byte_size += def_levels_bytes;
+ max_size -= def_levels_bytes;
+ }
+
+ return levels_byte_size;
+ }
+
+ int64_t InitializeLevelDecodersV2(const DataPageV2& page) {
+ // Read a data page.
+ num_buffered_values_ = page.num_values();
+
+ // Have not decoded any values from the data page yet
+ num_decoded_values_ = 0;
+ const uint8_t* buffer = page.data();
+
+ const int64_t total_levels_length =
+ static_cast<int64_t>(page.repetition_levels_byte_length()) +
+ page.definition_levels_byte_length();
+
+ if (total_levels_length > page.size()) {
+ throw ParquetException("Data page too small for levels (corrupt header?)");
+ }
+
+ if (max_rep_level_ > 0) {
+ repetition_level_decoder_.SetDataV2(page.repetition_levels_byte_length(),
+ max_rep_level_,
+ static_cast<int>(num_buffered_values_), buffer);
+ buffer += page.repetition_levels_byte_length();
+ }
+
+ if (max_def_level_ > 0) {
+ definition_level_decoder_.SetDataV2(page.definition_levels_byte_length(),
+ max_def_level_,
+ static_cast<int>(num_buffered_values_), buffer);
+ }
+
+ return total_levels_length;
+ }
+
+ // Get a decoder object for this page or create a new decoder if this is the
+ // first page with this encoding.
+ void InitializeDataDecoder(const DataPage& page, int64_t levels_byte_size) {
+ const uint8_t* buffer = page.data() + levels_byte_size;
+ const int64_t data_size = page.size() - levels_byte_size;
+
+ if (data_size < 0) {
+ throw ParquetException("Page smaller than size of encoded levels");
+ }
+
+ Encoding::type encoding = page.encoding();
+
+ if (IsDictionaryIndexEncoding(encoding)) {
+ encoding = Encoding::RLE_DICTIONARY;
+ }
+
+ auto it = decoders_.find(static_cast<int>(encoding));
+ if (it != decoders_.end()) {
+ DCHECK(it->second.get() != nullptr);
+ if (encoding == Encoding::RLE_DICTIONARY) {
+ DCHECK(current_decoder_->encoding() == Encoding::RLE_DICTIONARY);
+ }
+ current_decoder_ = it->second.get();
+ } else {
+ switch (encoding) {
+ case Encoding::PLAIN: {
+ auto decoder = MakeTypedDecoder<DType>(Encoding::PLAIN, descr_);
+ current_decoder_ = decoder.get();
+ decoders_[static_cast<int>(encoding)] = std::move(decoder);
+ break;
+ }
+ case Encoding::BYTE_STREAM_SPLIT: {
+ auto decoder = MakeTypedDecoder<DType>(Encoding::BYTE_STREAM_SPLIT, descr_);
+ current_decoder_ = decoder.get();
+ decoders_[static_cast<int>(encoding)] = std::move(decoder);
+ break;
+ }
+ case Encoding::RLE_DICTIONARY:
+ throw ParquetException("Dictionary page must be before data page.");
+
+ case Encoding::DELTA_BINARY_PACKED: {
+ auto decoder = MakeTypedDecoder<DType>(Encoding::DELTA_BINARY_PACKED, descr_);
+ current_decoder_ = decoder.get();
+ decoders_[static_cast<int>(encoding)] = std::move(decoder);
+ break;
+ }
+ case Encoding::DELTA_LENGTH_BYTE_ARRAY:
+ case Encoding::DELTA_BYTE_ARRAY:
+ ParquetException::NYI("Unsupported encoding");
+
+ default:
+ throw ParquetException("Unknown encoding type.");
+ }
+ }
+ current_encoding_ = encoding;
+ current_decoder_->SetData(static_cast<int>(num_buffered_values_), buffer,
+ static_cast<int>(data_size));
+ }
+
+ const ColumnDescriptor* descr_;
+ const int16_t max_def_level_;
+ const int16_t max_rep_level_;
+
+ std::unique_ptr<PageReader> pager_;
+ std::shared_ptr<Page> current_page_;
+
+ // Not set if full schema for this field has no optional or repeated elements
+ LevelDecoder definition_level_decoder_;
+
+ // Not set for flat schemas.
+ LevelDecoder repetition_level_decoder_;
+
+ // The total number of values stored in the data page. This is the maximum of
+ // the number of encoded definition levels or encoded values. For
+ // non-repeated, required columns, this is equal to the number of encoded
+ // values. For repeated or optional values, there may be fewer data values
+ // than levels, and this tells you how many encoded levels there are in that
+ // case.
+ int64_t num_buffered_values_;
+
+ // The number of values from the current data page that have been decoded
+ // into memory
+ int64_t num_decoded_values_;
+
+ ::arrow::MemoryPool* pool_;
+
+ using DecoderType = TypedDecoder<DType>;
+ DecoderType* current_decoder_;
+ Encoding::type current_encoding_;
+
+ /// Flag to signal when a new dictionary has been set, for the benefit of
+ /// DictionaryRecordReader
+ bool new_dictionary_;
+
+ // The exposed encoding
+ ExposedEncoding exposed_encoding_ = ExposedEncoding::NO_ENCODING;
+
+ // Map of encoding type to the respective decoder object. For example, a
+ // column chunk's data pages may include both dictionary-encoded and
+ // plain-encoded data.
+ std::unordered_map<int, std::unique_ptr<DecoderType>> decoders_;
+
+ void ConsumeBufferedValues(int64_t num_values) { num_decoded_values_ += num_values; }
+};
+
+// ----------------------------------------------------------------------
+// TypedColumnReader implementations
+
+template <typename DType>
+class TypedColumnReaderImpl : public TypedColumnReader<DType>,
+ public ColumnReaderImplBase<DType> {
+ public:
+ using T = typename DType::c_type;
+
+ TypedColumnReaderImpl(const ColumnDescriptor* descr, std::unique_ptr<PageReader> pager,
+ ::arrow::MemoryPool* pool)
+ : ColumnReaderImplBase<DType>(descr, pool) {
+ this->pager_ = std::move(pager);
+ }
+
+ bool HasNext() override { return this->HasNextInternal(); }
+
+ int64_t ReadBatch(int64_t batch_size, int16_t* def_levels, int16_t* rep_levels,
+ T* values, int64_t* values_read) override;
+
+ int64_t ReadBatchSpaced(int64_t batch_size, int16_t* def_levels, int16_t* rep_levels,
+ T* values, uint8_t* valid_bits, int64_t valid_bits_offset,
+ int64_t* levels_read, int64_t* values_read,
+ int64_t* null_count) override;
+
+ int64_t Skip(int64_t num_rows_to_skip) override;
+
+ Type::type type() const override { return this->descr_->physical_type(); }
+
+ const ColumnDescriptor* descr() const override { return this->descr_; }
+
+ ExposedEncoding GetExposedEncoding() override { return this->exposed_encoding_; };
+
+ int64_t ReadBatchWithDictionary(int64_t batch_size, int16_t* def_levels,
+ int16_t* rep_levels, int32_t* indices,
+ int64_t* indices_read, const T** dict,
+ int32_t* dict_len) override;
+
+ protected:
+ void SetExposedEncoding(ExposedEncoding encoding) override {
+ this->exposed_encoding_ = encoding;
+ }
+
+ private:
+ // Read dictionary indices. Similar to ReadValues but decode data to dictionary indices.
+ // This function is called only by ReadBatchWithDictionary().
+ int64_t ReadDictionaryIndices(int64_t indices_to_read, int32_t* indices) {
+ auto decoder = dynamic_cast<DictDecoder<DType>*>(this->current_decoder_);
+ return decoder->DecodeIndices(static_cast<int>(indices_to_read), indices);
+ }
+
+ // Get dictionary. The dictionary should have been set by SetDict(). The dictionary is
+ // owned by the internal decoder and is destroyed when the reader is destroyed. This
+ // function is called only by ReadBatchWithDictionary() after dictionary is configured.
+ void GetDictionary(const T** dictionary, int32_t* dictionary_length) {
+ auto decoder = dynamic_cast<DictDecoder<DType>*>(this->current_decoder_);
+ decoder->GetDictionary(dictionary, dictionary_length);
+ }
+
+ // Read definition and repetition levels. Also return the number of definition levels
+ // and number of values to read. This function is called before reading values.
+ void ReadLevels(int64_t batch_size, int16_t* def_levels, int16_t* rep_levels,
+ int64_t* num_def_levels, int64_t* values_to_read) {
+ batch_size =
+ std::min(batch_size, this->num_buffered_values_ - this->num_decoded_values_);
+
+ // If the field is required and non-repeated, there are no definition levels
+ if (this->max_def_level_ > 0 && def_levels != nullptr) {
+ *num_def_levels = this->ReadDefinitionLevels(batch_size, def_levels);
+ // TODO(wesm): this tallying of values-to-decode can be performed with better
+ // cache-efficiency if fused with the level decoding.
+ for (int64_t i = 0; i < *num_def_levels; ++i) {
+ if (def_levels[i] == this->max_def_level_) {
+ ++(*values_to_read);
+ }
+ }
+ } else {
+ // Required field, read all values
+ *values_to_read = batch_size;
+ }
+
+ // Not present for non-repeated fields
+ if (this->max_rep_level_ > 0 && rep_levels != nullptr) {
+ int64_t num_rep_levels = this->ReadRepetitionLevels(batch_size, rep_levels);
+ if (def_levels != nullptr && *num_def_levels != num_rep_levels) {
+ throw ParquetException("Number of decoded rep / def levels did not match");
+ }
+ }
+ }
+};
+
+template <typename DType>
+int64_t TypedColumnReaderImpl<DType>::ReadBatchWithDictionary(
+ int64_t batch_size, int16_t* def_levels, int16_t* rep_levels, int32_t* indices,
+ int64_t* indices_read, const T** dict, int32_t* dict_len) {
+ bool has_dict_output = dict != nullptr && dict_len != nullptr;
+ // Similar logic as ReadValues to get pages.
+ if (!HasNext()) {
+ *indices_read = 0;
+ if (has_dict_output) {
+ *dict = nullptr;
+ *dict_len = 0;
+ }
+ return 0;
+ }
+
+ // Verify the current data page is dictionary encoded.
+ if (this->current_encoding_ != Encoding::RLE_DICTIONARY) {
+ std::stringstream ss;
+ ss << "Data page is not dictionary encoded. Encoding: "
+ << EncodingToString(this->current_encoding_);
+ throw ParquetException(ss.str());
+ }
+
+ // Get dictionary pointer and length.
+ if (has_dict_output) {
+ GetDictionary(dict, dict_len);
+ }
+
+ // Similar logic as ReadValues to get def levels and rep levels.
+ int64_t num_def_levels = 0;
+ int64_t indices_to_read = 0;
+ ReadLevels(batch_size, def_levels, rep_levels, &num_def_levels, &indices_to_read);
+
+ // Read dictionary indices.
+ *indices_read = ReadDictionaryIndices(indices_to_read, indices);
+ int64_t total_indices = std::max(num_def_levels, *indices_read);
+ this->ConsumeBufferedValues(total_indices);
+
+ return total_indices;
+}
+
+template <typename DType>
+int64_t TypedColumnReaderImpl<DType>::ReadBatch(int64_t batch_size, int16_t* def_levels,
+ int16_t* rep_levels, T* values,
+ int64_t* values_read) {
+ // HasNext invokes ReadNewPage
+ if (!HasNext()) {
+ *values_read = 0;
+ return 0;
+ }
+
+ // TODO(wesm): keep reading data pages until batch_size is reached, or the
+ // row group is finished
+ int64_t num_def_levels = 0;
+ int64_t values_to_read = 0;
+ ReadLevels(batch_size, def_levels, rep_levels, &num_def_levels, &values_to_read);
+
+ *values_read = this->ReadValues(values_to_read, values);
+ int64_t total_values = std::max(num_def_levels, *values_read);
+ this->ConsumeBufferedValues(total_values);
+
+ return total_values;
+}
+
+template <typename DType>
+int64_t TypedColumnReaderImpl<DType>::ReadBatchSpaced(
+ int64_t batch_size, int16_t* def_levels, int16_t* rep_levels, T* values,
+ uint8_t* valid_bits, int64_t valid_bits_offset, int64_t* levels_read,
+ int64_t* values_read, int64_t* null_count_out) {
+ // HasNext invokes ReadNewPage
+ if (!HasNext()) {
+ *levels_read = 0;
+ *values_read = 0;
+ *null_count_out = 0;
+ return 0;
+ }
+
+ int64_t total_values;
+ // TODO(wesm): keep reading data pages until batch_size is reached, or the
+ // row group is finished
+ batch_size =
+ std::min(batch_size, this->num_buffered_values_ - this->num_decoded_values_);
+
+ // If the field is required and non-repeated, there are no definition levels
+ if (this->max_def_level_ > 0) {
+ int64_t num_def_levels = this->ReadDefinitionLevels(batch_size, def_levels);
+
+ // Not present for non-repeated fields
+ if (this->max_rep_level_ > 0) {
+ int64_t num_rep_levels = this->ReadRepetitionLevels(batch_size, rep_levels);
+ if (num_def_levels != num_rep_levels) {
+ throw ParquetException("Number of decoded rep / def levels did not match");
+ }
+ }
+
+ const bool has_spaced_values = HasSpacedValues(this->descr_);
+ int64_t null_count = 0;
+ if (!has_spaced_values) {
+ int values_to_read = 0;
+ for (int64_t i = 0; i < num_def_levels; ++i) {
+ if (def_levels[i] == this->max_def_level_) {
+ ++values_to_read;
+ }
+ }
+ total_values = this->ReadValues(values_to_read, values);
+ ::arrow::BitUtil::SetBitsTo(valid_bits, valid_bits_offset,
+ /*length=*/total_values,
+ /*bits_are_set=*/true);
+ *values_read = total_values;
+ } else {
+ internal::LevelInfo info;
+ info.repeated_ancestor_def_level = this->max_def_level_ - 1;
+ info.def_level = this->max_def_level_;
+ info.rep_level = this->max_rep_level_;
+ internal::ValidityBitmapInputOutput validity_io;
+ validity_io.values_read_upper_bound = num_def_levels;
+ validity_io.valid_bits = valid_bits;
+ validity_io.valid_bits_offset = valid_bits_offset;
+ validity_io.null_count = null_count;
+ validity_io.values_read = *values_read;
+
+ internal::DefLevelsToBitmap(def_levels, num_def_levels, info, &validity_io);
+ null_count = validity_io.null_count;
+ *values_read = validity_io.values_read;
+
+ total_values =
+ this->ReadValuesSpaced(*values_read, values, static_cast<int>(null_count),
+ valid_bits, valid_bits_offset);
+ }
+ *levels_read = num_def_levels;
+ *null_count_out = null_count;
+
+ } else {
+ // Required field, read all values
+ total_values = this->ReadValues(batch_size, values);
+ ::arrow::BitUtil::SetBitsTo(valid_bits, valid_bits_offset,
+ /*length=*/total_values,
+ /*bits_are_set=*/true);
+ *null_count_out = 0;
+ *values_read = total_values;
+ *levels_read = total_values;
+ }
+
+ this->ConsumeBufferedValues(*levels_read);
+ return total_values;
+}
+
+template <typename DType>
+int64_t TypedColumnReaderImpl<DType>::Skip(int64_t num_rows_to_skip) {
+ int64_t rows_to_skip = num_rows_to_skip;
+ while (HasNext() && rows_to_skip > 0) {
+ // If the number of rows to skip is more than the number of undecoded values, skip the
+ // Page.
+ if (rows_to_skip > (this->num_buffered_values_ - this->num_decoded_values_)) {
+ rows_to_skip -= this->num_buffered_values_ - this->num_decoded_values_;
+ this->num_decoded_values_ = this->num_buffered_values_;
+ } else {
+ // We need to read this Page
+ // Jump to the right offset in the Page
+ int64_t batch_size = 1024; // ReadBatch with a smaller memory footprint
+ int64_t values_read = 0;
+
+ // This will be enough scratch space to accommodate 16-bit levels or any
+ // value type
+ int value_size = type_traits<DType::type_num>::value_byte_size;
+ std::shared_ptr<ResizableBuffer> scratch = AllocateBuffer(
+ this->pool_, batch_size * std::max<int>(sizeof(int16_t), value_size));
+
+ do {
+ batch_size = std::min(batch_size, rows_to_skip);
+ values_read =
+ ReadBatch(static_cast<int>(batch_size),
+ reinterpret_cast<int16_t*>(scratch->mutable_data()),
+ reinterpret_cast<int16_t*>(scratch->mutable_data()),
+ reinterpret_cast<T*>(scratch->mutable_data()), &values_read);
+ rows_to_skip -= values_read;
+ } while (values_read > 0 && rows_to_skip > 0);
+ }
+ }
+ return num_rows_to_skip - rows_to_skip;
+}
+
+} // namespace
+
+// ----------------------------------------------------------------------
+// Dynamic column reader constructor
+
+std::shared_ptr<ColumnReader> ColumnReader::Make(const ColumnDescriptor* descr,
+ std::unique_ptr<PageReader> pager,
+ MemoryPool* pool) {
+ switch (descr->physical_type()) {
+ case Type::BOOLEAN:
+ return std::make_shared<TypedColumnReaderImpl<BooleanType>>(descr, std::move(pager),
+ pool);
+ case Type::INT32:
+ return std::make_shared<TypedColumnReaderImpl<Int32Type>>(descr, std::move(pager),
+ pool);
+ case Type::INT64:
+ return std::make_shared<TypedColumnReaderImpl<Int64Type>>(descr, std::move(pager),
+ pool);
+ case Type::INT96:
+ return std::make_shared<TypedColumnReaderImpl<Int96Type>>(descr, std::move(pager),
+ pool);
+ case Type::FLOAT:
+ return std::make_shared<TypedColumnReaderImpl<FloatType>>(descr, std::move(pager),
+ pool);
+ case Type::DOUBLE:
+ return std::make_shared<TypedColumnReaderImpl<DoubleType>>(descr, std::move(pager),
+ pool);
+ case Type::BYTE_ARRAY:
+ return std::make_shared<TypedColumnReaderImpl<ByteArrayType>>(
+ descr, std::move(pager), pool);
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return std::make_shared<TypedColumnReaderImpl<FLBAType>>(descr, std::move(pager),
+ pool);
+ default:
+ ParquetException::NYI("type reader not implemented");
+ }
+ // Unreachable code, but suppress compiler warning
+ return std::shared_ptr<ColumnReader>(nullptr);
+}
+
+// ----------------------------------------------------------------------
+// RecordReader
+
+namespace internal {
+namespace {
+
+// The minimum number of repetition/definition levels to decode at a time, for
+// better vectorized performance when doing many smaller record reads
+constexpr int64_t kMinLevelBatchSize = 1024;
+
+template <typename DType>
+class TypedRecordReader : public ColumnReaderImplBase<DType>,
+ virtual public RecordReader {
+ public:
+ using T = typename DType::c_type;
+ using BASE = ColumnReaderImplBase<DType>;
+ TypedRecordReader(const ColumnDescriptor* descr, LevelInfo leaf_info, MemoryPool* pool)
+ : BASE(descr, pool) {
+ leaf_info_ = leaf_info;
+ nullable_values_ = leaf_info.HasNullableValues();
+ at_record_start_ = true;
+ records_read_ = 0;
+ values_written_ = 0;
+ values_capacity_ = 0;
+ null_count_ = 0;
+ levels_written_ = 0;
+ levels_position_ = 0;
+ levels_capacity_ = 0;
+ uses_values_ = !(descr->physical_type() == Type::BYTE_ARRAY);
+
+ if (uses_values_) {
+ values_ = AllocateBuffer(pool);
+ }
+ valid_bits_ = AllocateBuffer(pool);
+ def_levels_ = AllocateBuffer(pool);
+ rep_levels_ = AllocateBuffer(pool);
+ Reset();
+ }
+
+ int64_t available_values_current_page() const {
+ return this->num_buffered_values_ - this->num_decoded_values_;
+ }
+
+ // Compute the values capacity in bytes for the given number of elements
+ int64_t bytes_for_values(int64_t nitems) const {
+ int64_t type_size = GetTypeByteSize(this->descr_->physical_type());
+ int64_t bytes_for_values = -1;
+ if (MultiplyWithOverflow(nitems, type_size, &bytes_for_values)) {
+ throw ParquetException("Total size of items too large");
+ }
+ return bytes_for_values;
+ }
+
+ int64_t ReadRecords(int64_t num_records) override {
+ // Delimit records, then read values at the end
+ int64_t records_read = 0;
+
+ if (levels_position_ < levels_written_) {
+ records_read += ReadRecordData(num_records);
+ }
+
+ int64_t level_batch_size = std::max(kMinLevelBatchSize, num_records);
+
+ // If we are in the middle of a record, we continue until reaching the
+ // desired number of records or the end of the current record if we've found
+ // enough records
+ while (!at_record_start_ || records_read < num_records) {
+ // Is there more data to read in this row group?
+ if (!this->HasNextInternal()) {
+ if (!at_record_start_) {
+ // We ended the row group while inside a record that we haven't seen
+ // the end of yet. So increment the record count for the last record in
+ // the row group
+ ++records_read;
+ at_record_start_ = true;
+ }
+ break;
+ }
+
+ /// We perform multiple batch reads until we either exhaust the row group
+ /// or observe the desired number of records
+ int64_t batch_size = std::min(level_batch_size, available_values_current_page());
+
+ // No more data in column
+ if (batch_size == 0) {
+ break;
+ }
+
+ if (this->max_def_level_ > 0) {
+ ReserveLevels(batch_size);
+
+ int16_t* def_levels = this->def_levels() + levels_written_;
+ int16_t* rep_levels = this->rep_levels() + levels_written_;
+
+ // Not present for non-repeated fields
+ int64_t levels_read = 0;
+ if (this->max_rep_level_ > 0) {
+ levels_read = this->ReadDefinitionLevels(batch_size, def_levels);
+ if (this->ReadRepetitionLevels(batch_size, rep_levels) != levels_read) {
+ throw ParquetException("Number of decoded rep / def levels did not match");
+ }
+ } else if (this->max_def_level_ > 0) {
+ levels_read = this->ReadDefinitionLevels(batch_size, def_levels);
+ }
+
+ // Exhausted column chunk
+ if (levels_read == 0) {
+ break;
+ }
+
+ levels_written_ += levels_read;
+ records_read += ReadRecordData(num_records - records_read);
+ } else {
+ // No repetition or definition levels
+ batch_size = std::min(num_records - records_read, batch_size);
+ records_read += ReadRecordData(batch_size);
+ }
+ }
+
+ return records_read;
+ }
+
+ // We may outwardly have the appearance of having exhausted a column chunk
+ // when in fact we are in the middle of processing the last batch
+ bool has_values_to_process() const { return levels_position_ < levels_written_; }
+
+ std::shared_ptr<ResizableBuffer> ReleaseValues() override {
+ if (uses_values_) {
+ auto result = values_;
+ PARQUET_THROW_NOT_OK(result->Resize(bytes_for_values(values_written_), true));
+ values_ = AllocateBuffer(this->pool_);
+ values_capacity_ = 0;
+ return result;
+ } else {
+ return nullptr;
+ }
+ }
+
+ std::shared_ptr<ResizableBuffer> ReleaseIsValid() override {
+ if (leaf_info_.HasNullableValues()) {
+ auto result = valid_bits_;
+ PARQUET_THROW_NOT_OK(result->Resize(BitUtil::BytesForBits(values_written_), true));
+ valid_bits_ = AllocateBuffer(this->pool_);
+ return result;
+ } else {
+ return nullptr;
+ }
+ }
+
+ // Process written repetition/definition levels to reach the end of
+ // records. Process no more levels than necessary to delimit the indicated
+ // number of logical records. Updates internal state of RecordReader
+ //
+ // \return Number of records delimited
+ int64_t DelimitRecords(int64_t num_records, int64_t* values_seen) {
+ int64_t values_to_read = 0;
+ int64_t records_read = 0;
+
+ const int16_t* def_levels = this->def_levels() + levels_position_;
+ const int16_t* rep_levels = this->rep_levels() + levels_position_;
+
+ DCHECK_GT(this->max_rep_level_, 0);
+
+ // Count logical records and number of values to read
+ while (levels_position_ < levels_written_) {
+ const int16_t rep_level = *rep_levels++;
+ if (rep_level == 0) {
+ // If at_record_start_ is true, we are seeing the start of a record
+ // for the second time, such as after repeated calls to
+ // DelimitRecords. In this case we must continue until we find
+ // another record start or exhausting the ColumnChunk
+ if (!at_record_start_) {
+ // We've reached the end of a record; increment the record count.
+ ++records_read;
+ if (records_read == num_records) {
+ // We've found the number of records we were looking for. Set
+ // at_record_start_ to true and break
+ at_record_start_ = true;
+ break;
+ }
+ }
+ }
+ // We have decided to consume the level at this position; therefore we
+ // must advance until we find another record boundary
+ at_record_start_ = false;
+
+ const int16_t def_level = *def_levels++;
+ if (def_level == this->max_def_level_) {
+ ++values_to_read;
+ }
+ ++levels_position_;
+ }
+ *values_seen = values_to_read;
+ return records_read;
+ }
+
+ void Reserve(int64_t capacity) override {
+ ReserveLevels(capacity);
+ ReserveValues(capacity);
+ }
+
+ int64_t UpdateCapacity(int64_t capacity, int64_t size, int64_t extra_size) {
+ if (extra_size < 0) {
+ throw ParquetException("Negative size (corrupt file?)");
+ }
+ int64_t target_size = -1;
+ if (AddWithOverflow(size, extra_size, &target_size)) {
+ throw ParquetException("Allocation size too large (corrupt file?)");
+ }
+ if (target_size >= (1LL << 62)) {
+ throw ParquetException("Allocation size too large (corrupt file?)");
+ }
+ if (capacity >= target_size) {
+ return capacity;
+ }
+ return BitUtil::NextPower2(target_size);
+ }
+
+ void ReserveLevels(int64_t extra_levels) {
+ if (this->max_def_level_ > 0) {
+ const int64_t new_levels_capacity =
+ UpdateCapacity(levels_capacity_, levels_written_, extra_levels);
+ if (new_levels_capacity > levels_capacity_) {
+ constexpr auto kItemSize = static_cast<int64_t>(sizeof(int16_t));
+ int64_t capacity_in_bytes = -1;
+ if (MultiplyWithOverflow(new_levels_capacity, kItemSize, &capacity_in_bytes)) {
+ throw ParquetException("Allocation size too large (corrupt file?)");
+ }
+ PARQUET_THROW_NOT_OK(def_levels_->Resize(capacity_in_bytes, false));
+ if (this->max_rep_level_ > 0) {
+ PARQUET_THROW_NOT_OK(rep_levels_->Resize(capacity_in_bytes, false));
+ }
+ levels_capacity_ = new_levels_capacity;
+ }
+ }
+ }
+
+ void ReserveValues(int64_t extra_values) {
+ const int64_t new_values_capacity =
+ UpdateCapacity(values_capacity_, values_written_, extra_values);
+ if (new_values_capacity > values_capacity_) {
+ // XXX(wesm): A hack to avoid memory allocation when reading directly
+ // into builder classes
+ if (uses_values_) {
+ PARQUET_THROW_NOT_OK(
+ values_->Resize(bytes_for_values(new_values_capacity), false));
+ }
+ values_capacity_ = new_values_capacity;
+ }
+ if (leaf_info_.HasNullableValues()) {
+ int64_t valid_bytes_new = BitUtil::BytesForBits(values_capacity_);
+ if (valid_bits_->size() < valid_bytes_new) {
+ int64_t valid_bytes_old = BitUtil::BytesForBits(values_written_);
+ PARQUET_THROW_NOT_OK(valid_bits_->Resize(valid_bytes_new, false));
+
+ // Avoid valgrind warnings
+ memset(valid_bits_->mutable_data() + valid_bytes_old, 0,
+ valid_bytes_new - valid_bytes_old);
+ }
+ }
+ }
+
+ void Reset() override {
+ ResetValues();
+
+ if (levels_written_ > 0) {
+ const int64_t levels_remaining = levels_written_ - levels_position_;
+ // Shift remaining levels to beginning of buffer and trim to only the number
+ // of decoded levels remaining
+ int16_t* def_data = def_levels();
+ int16_t* rep_data = rep_levels();
+
+ std::copy(def_data + levels_position_, def_data + levels_written_, def_data);
+ PARQUET_THROW_NOT_OK(
+ def_levels_->Resize(levels_remaining * sizeof(int16_t), false));
+
+ if (this->max_rep_level_ > 0) {
+ std::copy(rep_data + levels_position_, rep_data + levels_written_, rep_data);
+ PARQUET_THROW_NOT_OK(
+ rep_levels_->Resize(levels_remaining * sizeof(int16_t), false));
+ }
+
+ levels_written_ -= levels_position_;
+ levels_position_ = 0;
+ levels_capacity_ = levels_remaining;
+ }
+
+ records_read_ = 0;
+
+ // Call Finish on the binary builders to reset them
+ }
+
+ void SetPageReader(std::unique_ptr<PageReader> reader) override {
+ at_record_start_ = true;
+ this->pager_ = std::move(reader);
+ ResetDecoders();
+ }
+
+ bool HasMoreData() const override { return this->pager_ != nullptr; }
+
+ // Dictionary decoders must be reset when advancing row groups
+ void ResetDecoders() { this->decoders_.clear(); }
+
+ virtual void ReadValuesSpaced(int64_t values_with_nulls, int64_t null_count) {
+ uint8_t* valid_bits = valid_bits_->mutable_data();
+ const int64_t valid_bits_offset = values_written_;
+
+ int64_t num_decoded = this->current_decoder_->DecodeSpaced(
+ ValuesHead<T>(), static_cast<int>(values_with_nulls),
+ static_cast<int>(null_count), valid_bits, valid_bits_offset);
+ DCHECK_EQ(num_decoded, values_with_nulls);
+ }
+
+ virtual void ReadValuesDense(int64_t values_to_read) {
+ int64_t num_decoded =
+ this->current_decoder_->Decode(ValuesHead<T>(), static_cast<int>(values_to_read));
+ DCHECK_EQ(num_decoded, values_to_read);
+ }
+
+ // Return number of logical records read
+ int64_t ReadRecordData(int64_t num_records) {
+ // Conservative upper bound
+ const int64_t possible_num_values =
+ std::max(num_records, levels_written_ - levels_position_);
+ ReserveValues(possible_num_values);
+
+ const int64_t start_levels_position = levels_position_;
+
+ int64_t values_to_read = 0;
+ int64_t records_read = 0;
+ if (this->max_rep_level_ > 0) {
+ records_read = DelimitRecords(num_records, &values_to_read);
+ } else if (this->max_def_level_ > 0) {
+ // No repetition levels, skip delimiting logic. Each level represents a
+ // null or not null entry
+ records_read = std::min(levels_written_ - levels_position_, num_records);
+
+ // This is advanced by DelimitRecords, which we skipped
+ levels_position_ += records_read;
+ } else {
+ records_read = values_to_read = num_records;
+ }
+
+ int64_t null_count = 0;
+ if (leaf_info_.HasNullableValues()) {
+ ValidityBitmapInputOutput validity_io;
+ validity_io.values_read_upper_bound = levels_position_ - start_levels_position;
+ validity_io.valid_bits = valid_bits_->mutable_data();
+ validity_io.valid_bits_offset = values_written_;
+
+ DefLevelsToBitmap(def_levels() + start_levels_position,
+ levels_position_ - start_levels_position, leaf_info_,
+ &validity_io);
+ values_to_read = validity_io.values_read - validity_io.null_count;
+ null_count = validity_io.null_count;
+ DCHECK_GE(values_to_read, 0);
+ ReadValuesSpaced(validity_io.values_read, null_count);
+ } else {
+ DCHECK_GE(values_to_read, 0);
+ ReadValuesDense(values_to_read);
+ }
+ if (this->leaf_info_.def_level > 0) {
+ // Optional, repeated, or some mix thereof
+ this->ConsumeBufferedValues(levels_position_ - start_levels_position);
+ } else {
+ // Flat, non-repeated
+ this->ConsumeBufferedValues(values_to_read);
+ }
+ // Total values, including null spaces, if any
+ values_written_ += values_to_read + null_count;
+ null_count_ += null_count;
+
+ return records_read;
+ }
+
+ void DebugPrintState() override {
+ const int16_t* def_levels = this->def_levels();
+ const int16_t* rep_levels = this->rep_levels();
+ const int64_t total_levels_read = levels_position_;
+
+ const T* vals = reinterpret_cast<const T*>(this->values());
+
+ std::cout << "def levels: ";
+ for (int64_t i = 0; i < total_levels_read; ++i) {
+ std::cout << def_levels[i] << " ";
+ }
+ std::cout << std::endl;
+
+ std::cout << "rep levels: ";
+ for (int64_t i = 0; i < total_levels_read; ++i) {
+ std::cout << rep_levels[i] << " ";
+ }
+ std::cout << std::endl;
+
+ std::cout << "values: ";
+ for (int64_t i = 0; i < this->values_written(); ++i) {
+ std::cout << vals[i] << " ";
+ }
+ std::cout << std::endl;
+ }
+
+ void ResetValues() {
+ if (values_written_ > 0) {
+ // Resize to 0, but do not shrink to fit
+ if (uses_values_) {
+ PARQUET_THROW_NOT_OK(values_->Resize(0, false));
+ }
+ PARQUET_THROW_NOT_OK(valid_bits_->Resize(0, false));
+ values_written_ = 0;
+ values_capacity_ = 0;
+ null_count_ = 0;
+ }
+ }
+
+ protected:
+ template <typename T>
+ T* ValuesHead() {
+ return reinterpret_cast<T*>(values_->mutable_data()) + values_written_;
+ }
+ LevelInfo leaf_info_;
+};
+
+class FLBARecordReader : public TypedRecordReader<FLBAType>,
+ virtual public BinaryRecordReader {
+ public:
+ FLBARecordReader(const ColumnDescriptor* descr, LevelInfo leaf_info,
+ ::arrow::MemoryPool* pool)
+ : TypedRecordReader<FLBAType>(descr, leaf_info, pool), builder_(nullptr) {
+ DCHECK_EQ(descr_->physical_type(), Type::FIXED_LEN_BYTE_ARRAY);
+ int byte_width = descr_->type_length();
+ std::shared_ptr<::arrow::DataType> type = ::arrow::fixed_size_binary(byte_width);
+ builder_.reset(new ::arrow::FixedSizeBinaryBuilder(type, this->pool_));
+ }
+
+ ::arrow::ArrayVector GetBuilderChunks() override {
+ std::shared_ptr<::arrow::Array> chunk;
+ PARQUET_THROW_NOT_OK(builder_->Finish(&chunk));
+ return ::arrow::ArrayVector({chunk});
+ }
+
+ void ReadValuesDense(int64_t values_to_read) override {
+ auto values = ValuesHead<FLBA>();
+ int64_t num_decoded =
+ this->current_decoder_->Decode(values, static_cast<int>(values_to_read));
+ DCHECK_EQ(num_decoded, values_to_read);
+
+ for (int64_t i = 0; i < num_decoded; i++) {
+ PARQUET_THROW_NOT_OK(builder_->Append(values[i].ptr));
+ }
+ ResetValues();
+ }
+
+ void ReadValuesSpaced(int64_t values_to_read, int64_t null_count) override {
+ uint8_t* valid_bits = valid_bits_->mutable_data();
+ const int64_t valid_bits_offset = values_written_;
+ auto values = ValuesHead<FLBA>();
+
+ int64_t num_decoded = this->current_decoder_->DecodeSpaced(
+ values, static_cast<int>(values_to_read), static_cast<int>(null_count),
+ valid_bits, valid_bits_offset);
+ DCHECK_EQ(num_decoded, values_to_read);
+
+ for (int64_t i = 0; i < num_decoded; i++) {
+ if (::arrow::BitUtil::GetBit(valid_bits, valid_bits_offset + i)) {
+ PARQUET_THROW_NOT_OK(builder_->Append(values[i].ptr));
+ } else {
+ PARQUET_THROW_NOT_OK(builder_->AppendNull());
+ }
+ }
+ ResetValues();
+ }
+
+ private:
+ std::unique_ptr<::arrow::FixedSizeBinaryBuilder> builder_;
+};
+
+class ByteArrayChunkedRecordReader : public TypedRecordReader<ByteArrayType>,
+ virtual public BinaryRecordReader {
+ public:
+ ByteArrayChunkedRecordReader(const ColumnDescriptor* descr, LevelInfo leaf_info,
+ ::arrow::MemoryPool* pool)
+ : TypedRecordReader<ByteArrayType>(descr, leaf_info, pool) {
+ DCHECK_EQ(descr_->physical_type(), Type::BYTE_ARRAY);
+ accumulator_.builder.reset(new ::arrow::BinaryBuilder(pool));
+ }
+
+ ::arrow::ArrayVector GetBuilderChunks() override {
+ ::arrow::ArrayVector result = accumulator_.chunks;
+ if (result.size() == 0 || accumulator_.builder->length() > 0) {
+ std::shared_ptr<::arrow::Array> last_chunk;
+ PARQUET_THROW_NOT_OK(accumulator_.builder->Finish(&last_chunk));
+ result.push_back(std::move(last_chunk));
+ }
+ accumulator_.chunks = {};
+ return result;
+ }
+
+ void ReadValuesDense(int64_t values_to_read) override {
+ int64_t num_decoded = this->current_decoder_->DecodeArrowNonNull(
+ static_cast<int>(values_to_read), &accumulator_);
+ DCHECK_EQ(num_decoded, values_to_read);
+ ResetValues();
+ }
+
+ void ReadValuesSpaced(int64_t values_to_read, int64_t null_count) override {
+ int64_t num_decoded = this->current_decoder_->DecodeArrow(
+ static_cast<int>(values_to_read), static_cast<int>(null_count),
+ valid_bits_->mutable_data(), values_written_, &accumulator_);
+ DCHECK_EQ(num_decoded, values_to_read - null_count);
+ ResetValues();
+ }
+
+ private:
+ // Helper data structure for accumulating builder chunks
+ typename EncodingTraits<ByteArrayType>::Accumulator accumulator_;
+};
+
+class ByteArrayDictionaryRecordReader : public TypedRecordReader<ByteArrayType>,
+ virtual public DictionaryRecordReader {
+ public:
+ ByteArrayDictionaryRecordReader(const ColumnDescriptor* descr, LevelInfo leaf_info,
+ ::arrow::MemoryPool* pool)
+ : TypedRecordReader<ByteArrayType>(descr, leaf_info, pool), builder_(pool) {
+ this->read_dictionary_ = true;
+ }
+
+ std::shared_ptr<::arrow::ChunkedArray> GetResult() override {
+ FlushBuilder();
+ std::vector<std::shared_ptr<::arrow::Array>> result;
+ std::swap(result, result_chunks_);
+ return std::make_shared<::arrow::ChunkedArray>(std::move(result), builder_.type());
+ }
+
+ void FlushBuilder() {
+ if (builder_.length() > 0) {
+ std::shared_ptr<::arrow::Array> chunk;
+ PARQUET_THROW_NOT_OK(builder_.Finish(&chunk));
+ result_chunks_.emplace_back(std::move(chunk));
+
+ // Also clears the dictionary memo table
+ builder_.Reset();
+ }
+ }
+
+ void MaybeWriteNewDictionary() {
+ if (this->new_dictionary_) {
+ /// If there is a new dictionary, we may need to flush the builder, then
+ /// insert the new dictionary values
+ FlushBuilder();
+ builder_.ResetFull();
+ auto decoder = dynamic_cast<BinaryDictDecoder*>(this->current_decoder_);
+ decoder->InsertDictionary(&builder_);
+ this->new_dictionary_ = false;
+ }
+ }
+
+ void ReadValuesDense(int64_t values_to_read) override {
+ int64_t num_decoded = 0;
+ if (current_encoding_ == Encoding::RLE_DICTIONARY) {
+ MaybeWriteNewDictionary();
+ auto decoder = dynamic_cast<BinaryDictDecoder*>(this->current_decoder_);
+ num_decoded = decoder->DecodeIndices(static_cast<int>(values_to_read), &builder_);
+ } else {
+ num_decoded = this->current_decoder_->DecodeArrowNonNull(
+ static_cast<int>(values_to_read), &builder_);
+
+ /// Flush values since they have been copied into the builder
+ ResetValues();
+ }
+ DCHECK_EQ(num_decoded, values_to_read);
+ }
+
+ void ReadValuesSpaced(int64_t values_to_read, int64_t null_count) override {
+ int64_t num_decoded = 0;
+ if (current_encoding_ == Encoding::RLE_DICTIONARY) {
+ MaybeWriteNewDictionary();
+ auto decoder = dynamic_cast<BinaryDictDecoder*>(this->current_decoder_);
+ num_decoded = decoder->DecodeIndicesSpaced(
+ static_cast<int>(values_to_read), static_cast<int>(null_count),
+ valid_bits_->mutable_data(), values_written_, &builder_);
+ } else {
+ num_decoded = this->current_decoder_->DecodeArrow(
+ static_cast<int>(values_to_read), static_cast<int>(null_count),
+ valid_bits_->mutable_data(), values_written_, &builder_);
+
+ /// Flush values since they have been copied into the builder
+ ResetValues();
+ }
+ DCHECK_EQ(num_decoded, values_to_read - null_count);
+ }
+
+ private:
+ using BinaryDictDecoder = DictDecoder<ByteArrayType>;
+
+ ::arrow::BinaryDictionary32Builder builder_;
+ std::vector<std::shared_ptr<::arrow::Array>> result_chunks_;
+};
+
+// TODO(wesm): Implement these to some satisfaction
+template <>
+void TypedRecordReader<Int96Type>::DebugPrintState() {}
+
+template <>
+void TypedRecordReader<ByteArrayType>::DebugPrintState() {}
+
+template <>
+void TypedRecordReader<FLBAType>::DebugPrintState() {}
+
+std::shared_ptr<RecordReader> MakeByteArrayRecordReader(const ColumnDescriptor* descr,
+ LevelInfo leaf_info,
+ ::arrow::MemoryPool* pool,
+ bool read_dictionary) {
+ if (read_dictionary) {
+ return std::make_shared<ByteArrayDictionaryRecordReader>(descr, leaf_info, pool);
+ } else {
+ return std::make_shared<ByteArrayChunkedRecordReader>(descr, leaf_info, pool);
+ }
+}
+
+} // namespace
+
+std::shared_ptr<RecordReader> RecordReader::Make(const ColumnDescriptor* descr,
+ LevelInfo leaf_info, MemoryPool* pool,
+ const bool read_dictionary) {
+ switch (descr->physical_type()) {
+ case Type::BOOLEAN:
+ return std::make_shared<TypedRecordReader<BooleanType>>(descr, leaf_info, pool);
+ case Type::INT32:
+ return std::make_shared<TypedRecordReader<Int32Type>>(descr, leaf_info, pool);
+ case Type::INT64:
+ return std::make_shared<TypedRecordReader<Int64Type>>(descr, leaf_info, pool);
+ case Type::INT96:
+ return std::make_shared<TypedRecordReader<Int96Type>>(descr, leaf_info, pool);
+ case Type::FLOAT:
+ return std::make_shared<TypedRecordReader<FloatType>>(descr, leaf_info, pool);
+ case Type::DOUBLE:
+ return std::make_shared<TypedRecordReader<DoubleType>>(descr, leaf_info, pool);
+ case Type::BYTE_ARRAY:
+ return MakeByteArrayRecordReader(descr, leaf_info, pool, read_dictionary);
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return std::make_shared<FLBARecordReader>(descr, leaf_info, pool);
+ default: {
+ // PARQUET-1481: This can occur if the file is corrupt
+ std::stringstream ss;
+ ss << "Invalid physical column type: " << static_cast<int>(descr->physical_type());
+ throw ParquetException(ss.str());
+ }
+ }
+ // Unreachable code, but suppress compiler warning
+ return nullptr;
+}
+
+} // namespace internal
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/column_reader.h b/src/arrow/cpp/src/parquet/column_reader.h
new file mode 100644
index 000000000..8c48e4d78
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/column_reader.h
@@ -0,0 +1,376 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "parquet/exception.h"
+#include "parquet/level_conversion.h"
+#include "parquet/platform.h"
+#include "parquet/schema.h"
+#include "parquet/types.h"
+
+namespace arrow {
+
+class Array;
+class ChunkedArray;
+
+namespace BitUtil {
+class BitReader;
+} // namespace BitUtil
+
+namespace util {
+class RleDecoder;
+} // namespace util
+
+} // namespace arrow
+
+namespace parquet {
+
+class Decryptor;
+class Page;
+
+// 16 MB is the default maximum page header size
+static constexpr uint32_t kDefaultMaxPageHeaderSize = 16 * 1024 * 1024;
+
+// 16 KB is the default expected page header size
+static constexpr uint32_t kDefaultPageHeaderSize = 16 * 1024;
+
+class PARQUET_EXPORT LevelDecoder {
+ public:
+ LevelDecoder();
+ ~LevelDecoder();
+
+ // Initialize the LevelDecoder state with new data
+ // and return the number of bytes consumed
+ int SetData(Encoding::type encoding, int16_t max_level, int num_buffered_values,
+ const uint8_t* data, int32_t data_size);
+
+ void SetDataV2(int32_t num_bytes, int16_t max_level, int num_buffered_values,
+ const uint8_t* data);
+
+ // Decodes a batch of levels into an array and returns the number of levels decoded
+ int Decode(int batch_size, int16_t* levels);
+
+ private:
+ int bit_width_;
+ int num_values_remaining_;
+ Encoding::type encoding_;
+ std::unique_ptr<::arrow::util::RleDecoder> rle_decoder_;
+ std::unique_ptr<::arrow::BitUtil::BitReader> bit_packed_decoder_;
+ int16_t max_level_;
+};
+
+struct CryptoContext {
+ CryptoContext(bool start_with_dictionary_page, int16_t rg_ordinal, int16_t col_ordinal,
+ std::shared_ptr<Decryptor> meta, std::shared_ptr<Decryptor> data)
+ : start_decrypt_with_dictionary_page(start_with_dictionary_page),
+ row_group_ordinal(rg_ordinal),
+ column_ordinal(col_ordinal),
+ meta_decryptor(std::move(meta)),
+ data_decryptor(std::move(data)) {}
+ CryptoContext() {}
+
+ bool start_decrypt_with_dictionary_page = false;
+ int16_t row_group_ordinal = -1;
+ int16_t column_ordinal = -1;
+ std::shared_ptr<Decryptor> meta_decryptor;
+ std::shared_ptr<Decryptor> data_decryptor;
+};
+
+// Abstract page iterator interface. This way, we can feed column pages to the
+// ColumnReader through whatever mechanism we choose
+class PARQUET_EXPORT PageReader {
+ public:
+ virtual ~PageReader() = default;
+
+ static std::unique_ptr<PageReader> Open(
+ std::shared_ptr<ArrowInputStream> stream, int64_t total_num_rows,
+ Compression::type codec, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool(),
+ const CryptoContext* ctx = NULLPTR);
+
+ // @returns: shared_ptr<Page>(nullptr) on EOS, std::shared_ptr<Page>
+ // containing new Page otherwise
+ virtual std::shared_ptr<Page> NextPage() = 0;
+
+ virtual void set_max_page_header_size(uint32_t size) = 0;
+};
+
+class PARQUET_EXPORT ColumnReader {
+ public:
+ virtual ~ColumnReader() = default;
+
+ static std::shared_ptr<ColumnReader> Make(
+ const ColumnDescriptor* descr, std::unique_ptr<PageReader> pager,
+ ::arrow::MemoryPool* pool = ::arrow::default_memory_pool());
+
+ // Returns true if there are still values in this column.
+ virtual bool HasNext() = 0;
+
+ virtual Type::type type() const = 0;
+
+ virtual const ColumnDescriptor* descr() const = 0;
+
+ // Get the encoding that can be exposed by this reader. If it returns
+ // dictionary encoding, then ReadBatchWithDictionary can be used to read data.
+ //
+ // \note API EXPERIMENTAL
+ virtual ExposedEncoding GetExposedEncoding() = 0;
+
+ protected:
+ friend class RowGroupReader;
+ // Set the encoding that can be exposed by this reader.
+ //
+ // \note API EXPERIMENTAL
+ virtual void SetExposedEncoding(ExposedEncoding encoding) = 0;
+};
+
+// API to read values from a single column. This is a main client facing API.
+template <typename DType>
+class TypedColumnReader : public ColumnReader {
+ public:
+ typedef typename DType::c_type T;
+
+ // Read a batch of repetition levels, definition levels, and values from the
+ // column.
+ //
+ // Since null values are not stored in the values, the number of values read
+ // may be less than the number of repetition and definition levels. With
+ // nested data this is almost certainly true.
+ //
+ // Set def_levels or rep_levels to nullptr if you want to skip reading them.
+ // This is only safe if you know through some other source that there are no
+ // undefined values.
+ //
+ // To fully exhaust a row group, you must read batches until the number of
+ // values read reaches the number of stored values according to the metadata.
+ //
+ // This API is the same for both V1 and V2 of the DataPage
+ //
+ // @returns: actual number of levels read (see values_read for number of values read)
+ virtual int64_t ReadBatch(int64_t batch_size, int16_t* def_levels, int16_t* rep_levels,
+ T* values, int64_t* values_read) = 0;
+
+ /// Read a batch of repetition levels, definition levels, and values from the
+ /// column and leave spaces for null entries on the lowest level in the values
+ /// buffer.
+ ///
+ /// In comparison to ReadBatch the length of repetition and definition levels
+ /// is the same as of the number of values read for max_definition_level == 1.
+ /// In the case of max_definition_level > 1, the repetition and definition
+ /// levels are larger than the values but the values include the null entries
+ /// with definition_level == (max_definition_level - 1).
+ ///
+ /// To fully exhaust a row group, you must read batches until the number of
+ /// values read reaches the number of stored values according to the metadata.
+ ///
+ /// @param batch_size the number of levels to read
+ /// @param[out] def_levels The Parquet definition levels, output has
+ /// the length levels_read.
+ /// @param[out] rep_levels The Parquet repetition levels, output has
+ /// the length levels_read.
+ /// @param[out] values The values in the lowest nested level including
+ /// spacing for nulls on the lowest levels; output has the length
+ /// values_read.
+ /// @param[out] valid_bits Memory allocated for a bitmap that indicates if
+ /// the row is null or on the maximum definition level. For performance
+ /// reasons the underlying buffer should be able to store 1 bit more than
+ /// required. If this requires an additional byte, this byte is only read
+ /// but never written to.
+ /// @param valid_bits_offset The offset in bits of the valid_bits where the
+ /// first relevant bit resides.
+ /// @param[out] levels_read The number of repetition/definition levels that were read.
+ /// @param[out] values_read The number of values read, this includes all
+ /// non-null entries as well as all null-entries on the lowest level
+ /// (i.e. definition_level == max_definition_level - 1)
+ /// @param[out] null_count The number of nulls on the lowest levels.
+ /// (i.e. (values_read - null_count) is total number of non-null entries)
+ ///
+ /// \deprecated Since 4.0.0
+ ARROW_DEPRECATED("Doesn't handle nesting correctly and unused outside of unit tests.")
+ virtual int64_t ReadBatchSpaced(int64_t batch_size, int16_t* def_levels,
+ int16_t* rep_levels, T* values, uint8_t* valid_bits,
+ int64_t valid_bits_offset, int64_t* levels_read,
+ int64_t* values_read, int64_t* null_count) = 0;
+
+ // Skip reading levels
+ // Returns the number of levels skipped
+ virtual int64_t Skip(int64_t num_rows_to_skip) = 0;
+
+ // Read a batch of repetition levels, definition levels, and indices from the
+ // column. And read the dictionary if a dictionary page is encountered during
+ // reading pages. This API is similar to ReadBatch(), with ability to read
+ // dictionary and indices. It is only valid to call this method when the reader can
+ // expose dictionary encoding. (i.e., the reader's GetExposedEncoding() returns
+ // DICTIONARY).
+ //
+ // The dictionary is read along with the data page. When there's no data page,
+ // the dictionary won't be returned.
+ //
+ // @param batch_size The batch size to read
+ // @param[out] def_levels The Parquet definition levels.
+ // @param[out] rep_levels The Parquet repetition levels.
+ // @param[out] indices The dictionary indices.
+ // @param[out] indices_read The number of indices read.
+ // @param[out] dict The pointer to dictionary values. It will return nullptr if
+ // there's no data page. Each column chunk only has one dictionary page. The dictionary
+ // is owned by the reader, so the caller is responsible for copying the dictionary
+ // values before the reader gets destroyed.
+ // @param[out] dict_len The dictionary length. It will return 0 if there's no data
+ // page.
+ // @returns: actual number of levels read (see indices_read for number of
+ // indices read
+ //
+ // \note API EXPERIMENTAL
+ virtual int64_t ReadBatchWithDictionary(int64_t batch_size, int16_t* def_levels,
+ int16_t* rep_levels, int32_t* indices,
+ int64_t* indices_read, const T** dict,
+ int32_t* dict_len) = 0;
+};
+
+namespace internal {
+
+/// \brief Stateful column reader that delimits semantic records for both flat
+/// and nested columns
+///
+/// \note API EXPERIMENTAL
+/// \since 1.3.0
+class RecordReader {
+ public:
+ static std::shared_ptr<RecordReader> Make(
+ const ColumnDescriptor* descr, LevelInfo leaf_info,
+ ::arrow::MemoryPool* pool = ::arrow::default_memory_pool(),
+ const bool read_dictionary = false);
+
+ virtual ~RecordReader() = default;
+
+ /// \brief Attempt to read indicated number of records from column chunk
+ /// \return number of records read
+ virtual int64_t ReadRecords(int64_t num_records) = 0;
+
+ /// \brief Pre-allocate space for data. Results in better flat read performance
+ virtual void Reserve(int64_t num_values) = 0;
+
+ /// \brief Clear consumed values and repetition/definition levels as the
+ /// result of calling ReadRecords
+ virtual void Reset() = 0;
+
+ /// \brief Transfer filled values buffer to caller. A new one will be
+ /// allocated in subsequent ReadRecords calls
+ virtual std::shared_ptr<ResizableBuffer> ReleaseValues() = 0;
+
+ /// \brief Transfer filled validity bitmap buffer to caller. A new one will
+ /// be allocated in subsequent ReadRecords calls
+ virtual std::shared_ptr<ResizableBuffer> ReleaseIsValid() = 0;
+
+ /// \brief Return true if the record reader has more internal data yet to
+ /// process
+ virtual bool HasMoreData() const = 0;
+
+ /// \brief Advance record reader to the next row group
+ /// \param[in] reader obtained from RowGroupReader::GetColumnPageReader
+ virtual void SetPageReader(std::unique_ptr<PageReader> reader) = 0;
+
+ virtual void DebugPrintState() = 0;
+
+ /// \brief Decoded definition levels
+ int16_t* def_levels() const {
+ return reinterpret_cast<int16_t*>(def_levels_->mutable_data());
+ }
+
+ /// \brief Decoded repetition levels
+ int16_t* rep_levels() const {
+ return reinterpret_cast<int16_t*>(rep_levels_->mutable_data());
+ }
+
+ /// \brief Decoded values, including nulls, if any
+ uint8_t* values() const { return values_->mutable_data(); }
+
+ /// \brief Number of values written including nulls (if any)
+ int64_t values_written() const { return values_written_; }
+
+ /// \brief Number of definition / repetition levels (from those that have
+ /// been decoded) that have been consumed inside the reader.
+ int64_t levels_position() const { return levels_position_; }
+
+ /// \brief Number of definition / repetition levels that have been written
+ /// internally in the reader
+ int64_t levels_written() const { return levels_written_; }
+
+ /// \brief Number of nulls in the leaf
+ int64_t null_count() const { return null_count_; }
+
+ /// \brief True if the leaf values are nullable
+ bool nullable_values() const { return nullable_values_; }
+
+ /// \brief True if reading directly as Arrow dictionary-encoded
+ bool read_dictionary() const { return read_dictionary_; }
+
+ protected:
+ bool nullable_values_;
+
+ bool at_record_start_;
+ int64_t records_read_;
+
+ int64_t values_written_;
+ int64_t values_capacity_;
+ int64_t null_count_;
+
+ int64_t levels_written_;
+ int64_t levels_position_;
+ int64_t levels_capacity_;
+
+ std::shared_ptr<::arrow::ResizableBuffer> values_;
+ // In the case of false, don't allocate the values buffer (when we directly read into
+ // builder classes).
+ bool uses_values_;
+
+ std::shared_ptr<::arrow::ResizableBuffer> valid_bits_;
+ std::shared_ptr<::arrow::ResizableBuffer> def_levels_;
+ std::shared_ptr<::arrow::ResizableBuffer> rep_levels_;
+
+ bool read_dictionary_ = false;
+};
+
+class BinaryRecordReader : virtual public RecordReader {
+ public:
+ virtual std::vector<std::shared_ptr<::arrow::Array>> GetBuilderChunks() = 0;
+};
+
+/// \brief Read records directly to dictionary-encoded Arrow form (int32
+/// indices). Only valid for BYTE_ARRAY columns
+class DictionaryRecordReader : virtual public RecordReader {
+ public:
+ virtual std::shared_ptr<::arrow::ChunkedArray> GetResult() = 0;
+};
+
+} // namespace internal
+
+using BoolReader = TypedColumnReader<BooleanType>;
+using Int32Reader = TypedColumnReader<Int32Type>;
+using Int64Reader = TypedColumnReader<Int64Type>;
+using Int96Reader = TypedColumnReader<Int96Type>;
+using FloatReader = TypedColumnReader<FloatType>;
+using DoubleReader = TypedColumnReader<DoubleType>;
+using ByteArrayReader = TypedColumnReader<ByteArrayType>;
+using FixedLenByteArrayReader = TypedColumnReader<FLBAType>;
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/column_reader_test.cc b/src/arrow/cpp/src/parquet/column_reader_test.cc
new file mode 100644
index 000000000..1a00161fb
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/column_reader_test.cc
@@ -0,0 +1,476 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/util/macros.h"
+#include "arrow/util/make_unique.h"
+#include "parquet/column_page.h"
+#include "parquet/column_reader.h"
+#include "parquet/schema.h"
+#include "parquet/test_util.h"
+#include "parquet/types.h"
+
+namespace parquet {
+
+using schema::NodePtr;
+
+namespace test {
+
+template <typename T>
+static inline bool vector_equal_with_def_levels(const std::vector<T>& left,
+ const std::vector<int16_t>& def_levels,
+ int16_t max_def_levels,
+ int16_t max_rep_levels,
+ const std::vector<T>& right) {
+ size_t i_left = 0;
+ size_t i_right = 0;
+ for (size_t i = 0; i < def_levels.size(); i++) {
+ if (def_levels[i] == max_def_levels) {
+ // Compare
+ if (left[i_left] != right[i_right]) {
+ std::cerr << "index " << i << " left was " << left[i_left] << " right was "
+ << right[i] << std::endl;
+ return false;
+ }
+ i_left++;
+ i_right++;
+ } else if (def_levels[i] == (max_def_levels - 1)) {
+ // Null entry on the lowest nested level
+ i_right++;
+ } else if (def_levels[i] < (max_def_levels - 1)) {
+ // Null entry on a higher nesting level, only supported for non-repeating data
+ if (max_rep_levels == 0) {
+ i_right++;
+ }
+ }
+ }
+
+ return true;
+}
+
+class TestPrimitiveReader : public ::testing::Test {
+ public:
+ void InitReader(const ColumnDescriptor* d) {
+ std::unique_ptr<PageReader> pager_;
+ pager_.reset(new test::MockPageReader(pages_));
+ reader_ = ColumnReader::Make(d, std::move(pager_));
+ }
+
+ void CheckResults() {
+ std::vector<int32_t> vresult(num_values_, -1);
+ std::vector<int16_t> dresult(num_levels_, -1);
+ std::vector<int16_t> rresult(num_levels_, -1);
+ int64_t values_read = 0;
+ int total_values_read = 0;
+ int batch_actual = 0;
+
+ Int32Reader* reader = static_cast<Int32Reader*>(reader_.get());
+ int32_t batch_size = 8;
+ int batch = 0;
+ // This will cover both the cases
+ // 1) batch_size < page_size (multiple ReadBatch from a single page)
+ // 2) batch_size > page_size (BatchRead limits to a single page)
+ do {
+ batch = static_cast<int>(reader->ReadBatch(
+ batch_size, &dresult[0] + batch_actual, &rresult[0] + batch_actual,
+ &vresult[0] + total_values_read, &values_read));
+ total_values_read += static_cast<int>(values_read);
+ batch_actual += batch;
+ batch_size = std::min(1 << 24, std::max(batch_size * 2, 4096));
+ } while (batch > 0);
+
+ ASSERT_EQ(num_levels_, batch_actual);
+ ASSERT_EQ(num_values_, total_values_read);
+ ASSERT_TRUE(vector_equal(values_, vresult));
+ if (max_def_level_ > 0) {
+ ASSERT_TRUE(vector_equal(def_levels_, dresult));
+ }
+ if (max_rep_level_ > 0) {
+ ASSERT_TRUE(vector_equal(rep_levels_, rresult));
+ }
+ // catch improper writes at EOS
+ batch_actual =
+ static_cast<int>(reader->ReadBatch(5, nullptr, nullptr, nullptr, &values_read));
+ ASSERT_EQ(0, batch_actual);
+ ASSERT_EQ(0, values_read);
+ }
+ void CheckResultsSpaced() {
+ std::vector<int32_t> vresult(num_levels_, -1);
+ std::vector<int16_t> dresult(num_levels_, -1);
+ std::vector<int16_t> rresult(num_levels_, -1);
+ std::vector<uint8_t> valid_bits(num_levels_, 255);
+ int total_values_read = 0;
+ int batch_actual = 0;
+ int levels_actual = 0;
+ int64_t null_count = -1;
+ int64_t levels_read = 0;
+ int64_t values_read;
+
+ Int32Reader* reader = static_cast<Int32Reader*>(reader_.get());
+ int32_t batch_size = 8;
+ int batch = 0;
+ // This will cover both the cases
+ // 1) batch_size < page_size (multiple ReadBatch from a single page)
+ // 2) batch_size > page_size (BatchRead limits to a single page)
+ do {
+ ARROW_SUPPRESS_DEPRECATION_WARNING
+ batch = static_cast<int>(reader->ReadBatchSpaced(
+ batch_size, dresult.data() + levels_actual, rresult.data() + levels_actual,
+ vresult.data() + batch_actual, valid_bits.data() + batch_actual, 0,
+ &levels_read, &values_read, &null_count));
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
+ total_values_read += batch - static_cast<int>(null_count);
+ batch_actual += batch;
+ levels_actual += static_cast<int>(levels_read);
+ batch_size = std::min(1 << 24, std::max(batch_size * 2, 4096));
+ } while ((batch > 0) || (levels_read > 0));
+
+ ASSERT_EQ(num_levels_, levels_actual);
+ ASSERT_EQ(num_values_, total_values_read);
+ if (max_def_level_ > 0) {
+ ASSERT_TRUE(vector_equal(def_levels_, dresult));
+ ASSERT_TRUE(vector_equal_with_def_levels(values_, dresult, max_def_level_,
+ max_rep_level_, vresult));
+ } else {
+ ASSERT_TRUE(vector_equal(values_, vresult));
+ }
+ if (max_rep_level_ > 0) {
+ ASSERT_TRUE(vector_equal(rep_levels_, rresult));
+ }
+ // catch improper writes at EOS
+ ARROW_SUPPRESS_DEPRECATION_WARNING
+ batch_actual = static_cast<int>(
+ reader->ReadBatchSpaced(5, nullptr, nullptr, nullptr, valid_bits.data(), 0,
+ &levels_read, &values_read, &null_count));
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
+ ASSERT_EQ(0, batch_actual);
+ ASSERT_EQ(0, null_count);
+ }
+
+ void Clear() {
+ values_.clear();
+ def_levels_.clear();
+ rep_levels_.clear();
+ pages_.clear();
+ reader_.reset();
+ }
+
+ void ExecutePlain(int num_pages, int levels_per_page, const ColumnDescriptor* d) {
+ num_values_ =
+ MakePages<Int32Type>(d, num_pages, levels_per_page, def_levels_, rep_levels_,
+ values_, data_buffer_, pages_, Encoding::PLAIN);
+ num_levels_ = num_pages * levels_per_page;
+ InitReader(d);
+ CheckResults();
+ Clear();
+
+ num_values_ =
+ MakePages<Int32Type>(d, num_pages, levels_per_page, def_levels_, rep_levels_,
+ values_, data_buffer_, pages_, Encoding::PLAIN);
+ num_levels_ = num_pages * levels_per_page;
+ InitReader(d);
+ CheckResultsSpaced();
+ Clear();
+ }
+
+ void ExecuteDict(int num_pages, int levels_per_page, const ColumnDescriptor* d) {
+ num_values_ =
+ MakePages<Int32Type>(d, num_pages, levels_per_page, def_levels_, rep_levels_,
+ values_, data_buffer_, pages_, Encoding::RLE_DICTIONARY);
+ num_levels_ = num_pages * levels_per_page;
+ InitReader(d);
+ CheckResults();
+ Clear();
+
+ num_values_ =
+ MakePages<Int32Type>(d, num_pages, levels_per_page, def_levels_, rep_levels_,
+ values_, data_buffer_, pages_, Encoding::RLE_DICTIONARY);
+ num_levels_ = num_pages * levels_per_page;
+ InitReader(d);
+ CheckResultsSpaced();
+ Clear();
+ }
+
+ protected:
+ int num_levels_;
+ int num_values_;
+ int16_t max_def_level_;
+ int16_t max_rep_level_;
+ std::vector<std::shared_ptr<Page>> pages_;
+ std::shared_ptr<ColumnReader> reader_;
+ std::vector<int32_t> values_;
+ std::vector<int16_t> def_levels_;
+ std::vector<int16_t> rep_levels_;
+ std::vector<uint8_t> data_buffer_; // For BA and FLBA
+};
+
+TEST_F(TestPrimitiveReader, TestInt32FlatRequired) {
+ int levels_per_page = 100;
+ int num_pages = 50;
+ max_def_level_ = 0;
+ max_rep_level_ = 0;
+ NodePtr type = schema::Int32("a", Repetition::REQUIRED);
+ const ColumnDescriptor descr(type, max_def_level_, max_rep_level_);
+ ASSERT_NO_FATAL_FAILURE(ExecutePlain(num_pages, levels_per_page, &descr));
+ ASSERT_NO_FATAL_FAILURE(ExecuteDict(num_pages, levels_per_page, &descr));
+}
+
+TEST_F(TestPrimitiveReader, TestInt32FlatOptional) {
+ int levels_per_page = 100;
+ int num_pages = 50;
+ max_def_level_ = 4;
+ max_rep_level_ = 0;
+ NodePtr type = schema::Int32("b", Repetition::OPTIONAL);
+ const ColumnDescriptor descr(type, max_def_level_, max_rep_level_);
+ ASSERT_NO_FATAL_FAILURE(ExecutePlain(num_pages, levels_per_page, &descr));
+ ASSERT_NO_FATAL_FAILURE(ExecuteDict(num_pages, levels_per_page, &descr));
+}
+
+TEST_F(TestPrimitiveReader, TestInt32FlatRepeated) {
+ int levels_per_page = 100;
+ int num_pages = 50;
+ max_def_level_ = 4;
+ max_rep_level_ = 2;
+ NodePtr type = schema::Int32("c", Repetition::REPEATED);
+ const ColumnDescriptor descr(type, max_def_level_, max_rep_level_);
+ ASSERT_NO_FATAL_FAILURE(ExecutePlain(num_pages, levels_per_page, &descr));
+ ASSERT_NO_FATAL_FAILURE(ExecuteDict(num_pages, levels_per_page, &descr));
+}
+
+TEST_F(TestPrimitiveReader, TestInt32FlatRequiredSkip) {
+ int levels_per_page = 100;
+ int num_pages = 5;
+ max_def_level_ = 0;
+ max_rep_level_ = 0;
+ NodePtr type = schema::Int32("b", Repetition::REQUIRED);
+ const ColumnDescriptor descr(type, max_def_level_, max_rep_level_);
+ MakePages<Int32Type>(&descr, num_pages, levels_per_page, def_levels_, rep_levels_,
+ values_, data_buffer_, pages_, Encoding::PLAIN);
+ InitReader(&descr);
+ std::vector<int32_t> vresult(levels_per_page / 2, -1);
+ std::vector<int16_t> dresult(levels_per_page / 2, -1);
+ std::vector<int16_t> rresult(levels_per_page / 2, -1);
+
+ Int32Reader* reader = static_cast<Int32Reader*>(reader_.get());
+ int64_t values_read = 0;
+
+ // 1) skip_size > page_size (multiple pages skipped)
+ // Skip first 2 pages
+ int64_t levels_skipped = reader->Skip(2 * levels_per_page);
+ ASSERT_EQ(2 * levels_per_page, levels_skipped);
+ // Read half a page
+ reader->ReadBatch(levels_per_page / 2, dresult.data(), rresult.data(), vresult.data(),
+ &values_read);
+ std::vector<int32_t> sub_values(
+ values_.begin() + 2 * levels_per_page,
+ values_.begin() + static_cast<int>(2.5 * static_cast<double>(levels_per_page)));
+ ASSERT_TRUE(vector_equal(sub_values, vresult));
+
+ // 2) skip_size == page_size (skip across two pages)
+ levels_skipped = reader->Skip(levels_per_page);
+ ASSERT_EQ(levels_per_page, levels_skipped);
+ // Read half a page
+ reader->ReadBatch(levels_per_page / 2, dresult.data(), rresult.data(), vresult.data(),
+ &values_read);
+ sub_values.clear();
+ sub_values.insert(
+ sub_values.end(),
+ values_.begin() + static_cast<int>(3.5 * static_cast<double>(levels_per_page)),
+ values_.begin() + 4 * levels_per_page);
+ ASSERT_TRUE(vector_equal(sub_values, vresult));
+
+ // 3) skip_size < page_size (skip limited to a single page)
+ // Skip half a page
+ levels_skipped = reader->Skip(levels_per_page / 2);
+ ASSERT_EQ(0.5 * levels_per_page, levels_skipped);
+ // Read half a page
+ reader->ReadBatch(levels_per_page / 2, dresult.data(), rresult.data(), vresult.data(),
+ &values_read);
+ sub_values.clear();
+ sub_values.insert(
+ sub_values.end(),
+ values_.begin() + static_cast<int>(4.5 * static_cast<double>(levels_per_page)),
+ values_.end());
+ ASSERT_TRUE(vector_equal(sub_values, vresult));
+
+ values_.clear();
+ def_levels_.clear();
+ rep_levels_.clear();
+ pages_.clear();
+ reader_.reset();
+}
+
+TEST_F(TestPrimitiveReader, TestDictionaryEncodedPages) {
+ max_def_level_ = 0;
+ max_rep_level_ = 0;
+ NodePtr type = schema::Int32("a", Repetition::REQUIRED);
+ const ColumnDescriptor descr(type, max_def_level_, max_rep_level_);
+ std::shared_ptr<ResizableBuffer> dummy = AllocateBuffer();
+
+ std::shared_ptr<DictionaryPage> dict_page =
+ std::make_shared<DictionaryPage>(dummy, 0, Encoding::PLAIN);
+ std::shared_ptr<DataPageV1> data_page = MakeDataPage<Int32Type>(
+ &descr, {}, 0, Encoding::RLE_DICTIONARY, {}, 0, {}, 0, {}, 0);
+ pages_.push_back(dict_page);
+ pages_.push_back(data_page);
+ InitReader(&descr);
+ // Tests Dict : PLAIN, Data : RLE_DICTIONARY
+ ASSERT_NO_THROW(reader_->HasNext());
+ pages_.clear();
+
+ dict_page = std::make_shared<DictionaryPage>(dummy, 0, Encoding::PLAIN_DICTIONARY);
+ data_page = MakeDataPage<Int32Type>(&descr, {}, 0, Encoding::PLAIN_DICTIONARY, {}, 0,
+ {}, 0, {}, 0);
+ pages_.push_back(dict_page);
+ pages_.push_back(data_page);
+ InitReader(&descr);
+ // Tests Dict : PLAIN_DICTIONARY, Data : PLAIN_DICTIONARY
+ ASSERT_NO_THROW(reader_->HasNext());
+ pages_.clear();
+
+ data_page = MakeDataPage<Int32Type>(&descr, {}, 0, Encoding::RLE_DICTIONARY, {}, 0, {},
+ 0, {}, 0);
+ pages_.push_back(data_page);
+ InitReader(&descr);
+ // Tests dictionary page must occur before data page
+ ASSERT_THROW(reader_->HasNext(), ParquetException);
+ pages_.clear();
+
+ dict_page = std::make_shared<DictionaryPage>(dummy, 0, Encoding::DELTA_BYTE_ARRAY);
+ pages_.push_back(dict_page);
+ InitReader(&descr);
+ // Tests only RLE_DICTIONARY is supported
+ ASSERT_THROW(reader_->HasNext(), ParquetException);
+ pages_.clear();
+
+ std::shared_ptr<DictionaryPage> dict_page1 =
+ std::make_shared<DictionaryPage>(dummy, 0, Encoding::PLAIN_DICTIONARY);
+ std::shared_ptr<DictionaryPage> dict_page2 =
+ std::make_shared<DictionaryPage>(dummy, 0, Encoding::PLAIN);
+ pages_.push_back(dict_page1);
+ pages_.push_back(dict_page2);
+ InitReader(&descr);
+ // Column cannot have more than one dictionary
+ ASSERT_THROW(reader_->HasNext(), ParquetException);
+ pages_.clear();
+
+ data_page = MakeDataPage<Int32Type>(&descr, {}, 0, Encoding::DELTA_BYTE_ARRAY, {}, 0,
+ {}, 0, {}, 0);
+ pages_.push_back(data_page);
+ InitReader(&descr);
+ // unsupported encoding
+ ASSERT_THROW(reader_->HasNext(), ParquetException);
+ pages_.clear();
+}
+
+TEST_F(TestPrimitiveReader, TestDictionaryEncodedPagesWithExposeEncoding) {
+ max_def_level_ = 0;
+ max_rep_level_ = 0;
+ int levels_per_page = 100;
+ int num_pages = 5;
+ std::vector<int16_t> def_levels;
+ std::vector<int16_t> rep_levels;
+ std::vector<ByteArray> values;
+ std::vector<uint8_t> buffer;
+ NodePtr type = schema::ByteArray("a", Repetition::REQUIRED);
+ const ColumnDescriptor descr(type, max_def_level_, max_rep_level_);
+
+ // Fully dictionary encoded
+ MakePages<ByteArrayType>(&descr, num_pages, levels_per_page, def_levels, rep_levels,
+ values, buffer, pages_, Encoding::RLE_DICTIONARY);
+ InitReader(&descr);
+
+ auto reader = static_cast<ByteArrayReader*>(reader_.get());
+ const ByteArray* dict = nullptr;
+ int32_t dict_len = 0;
+ int64_t total_indices = 0;
+ int64_t indices_read = 0;
+ int64_t value_size = values.size();
+ auto indices = ::arrow::internal::make_unique<int32_t[]>(value_size);
+ while (total_indices < value_size && reader->HasNext()) {
+ const ByteArray* tmp_dict = nullptr;
+ int32_t tmp_dict_len = 0;
+ EXPECT_NO_THROW(reader->ReadBatchWithDictionary(
+ value_size, /*def_levels=*/nullptr,
+ /*rep_levels=*/nullptr, indices.get() + total_indices, &indices_read, &tmp_dict,
+ &tmp_dict_len));
+ if (tmp_dict != nullptr) {
+ // Dictionary is read along with data
+ EXPECT_GT(indices_read, 0);
+ dict = tmp_dict;
+ dict_len = tmp_dict_len;
+ } else {
+ // Dictionary is not read when there's no data
+ EXPECT_EQ(indices_read, 0);
+ }
+ total_indices += indices_read;
+ }
+
+ EXPECT_EQ(total_indices, value_size);
+ for (int64_t i = 0; i < total_indices; ++i) {
+ EXPECT_LT(indices[i], dict_len);
+ EXPECT_EQ(dict[indices[i]].len, values[i].len);
+ EXPECT_EQ(memcmp(dict[indices[i]].ptr, values[i].ptr, values[i].len), 0);
+ }
+ pages_.clear();
+}
+
+TEST_F(TestPrimitiveReader, TestNonDictionaryEncodedPagesWithExposeEncoding) {
+ max_def_level_ = 0;
+ max_rep_level_ = 0;
+ int64_t value_size = 100;
+ std::vector<int32_t> values(value_size, 0);
+ NodePtr type = schema::Int32("a", Repetition::REQUIRED);
+ const ColumnDescriptor descr(type, max_def_level_, max_rep_level_);
+
+ // The data page falls back to plain encoding
+ std::shared_ptr<ResizableBuffer> dummy = AllocateBuffer();
+ std::shared_ptr<DictionaryPage> dict_page =
+ std::make_shared<DictionaryPage>(dummy, 0, Encoding::PLAIN);
+ std::shared_ptr<DataPageV1> data_page = MakeDataPage<Int32Type>(
+ &descr, values, static_cast<int>(value_size), Encoding::PLAIN, /*indices=*/{},
+ /*indices_size=*/0, /*def_levels=*/{}, /*max_def_level=*/0, /*rep_levels=*/{},
+ /*max_rep_level=*/0);
+ pages_.push_back(dict_page);
+ pages_.push_back(data_page);
+ InitReader(&descr);
+
+ auto reader = static_cast<ByteArrayReader*>(reader_.get());
+ const ByteArray* dict = nullptr;
+ int32_t dict_len = 0;
+ int64_t indices_read = 0;
+ auto indices = ::arrow::internal::make_unique<int32_t[]>(value_size);
+ // Dictionary cannot be exposed when it's not fully dictionary encoded
+ EXPECT_THROW(reader->ReadBatchWithDictionary(value_size, /*def_levels=*/nullptr,
+ /*rep_levels=*/nullptr, indices.get(),
+ &indices_read, &dict, &dict_len),
+ ParquetException);
+ pages_.clear();
+}
+
+} // namespace test
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/column_scanner.cc b/src/arrow/cpp/src/parquet/column_scanner.cc
new file mode 100644
index 000000000..9ab1663cc
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/column_scanner.cc
@@ -0,0 +1,91 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/column_scanner.h"
+
+#include <cstdint>
+#include <memory>
+
+#include "parquet/column_reader.h"
+
+using arrow::MemoryPool;
+
+namespace parquet {
+
+std::shared_ptr<Scanner> Scanner::Make(std::shared_ptr<ColumnReader> col_reader,
+ int64_t batch_size, MemoryPool* pool) {
+ switch (col_reader->type()) {
+ case Type::BOOLEAN:
+ return std::make_shared<BoolScanner>(std::move(col_reader), batch_size, pool);
+ case Type::INT32:
+ return std::make_shared<Int32Scanner>(std::move(col_reader), batch_size, pool);
+ case Type::INT64:
+ return std::make_shared<Int64Scanner>(std::move(col_reader), batch_size, pool);
+ case Type::INT96:
+ return std::make_shared<Int96Scanner>(std::move(col_reader), batch_size, pool);
+ case Type::FLOAT:
+ return std::make_shared<FloatScanner>(std::move(col_reader), batch_size, pool);
+ case Type::DOUBLE:
+ return std::make_shared<DoubleScanner>(std::move(col_reader), batch_size, pool);
+ case Type::BYTE_ARRAY:
+ return std::make_shared<ByteArrayScanner>(std::move(col_reader), batch_size, pool);
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return std::make_shared<FixedLenByteArrayScanner>(std::move(col_reader), batch_size,
+ pool);
+ default:
+ ParquetException::NYI("type reader not implemented");
+ }
+ // Unreachable code, but suppress compiler warning
+ return std::shared_ptr<Scanner>(nullptr);
+}
+
+int64_t ScanAllValues(int32_t batch_size, int16_t* def_levels, int16_t* rep_levels,
+ uint8_t* values, int64_t* values_buffered,
+ parquet::ColumnReader* reader) {
+ switch (reader->type()) {
+ case parquet::Type::BOOLEAN:
+ return ScanAll<parquet::BoolReader>(batch_size, def_levels, rep_levels, values,
+ values_buffered, reader);
+ case parquet::Type::INT32:
+ return ScanAll<parquet::Int32Reader>(batch_size, def_levels, rep_levels, values,
+ values_buffered, reader);
+ case parquet::Type::INT64:
+ return ScanAll<parquet::Int64Reader>(batch_size, def_levels, rep_levels, values,
+ values_buffered, reader);
+ case parquet::Type::INT96:
+ return ScanAll<parquet::Int96Reader>(batch_size, def_levels, rep_levels, values,
+ values_buffered, reader);
+ case parquet::Type::FLOAT:
+ return ScanAll<parquet::FloatReader>(batch_size, def_levels, rep_levels, values,
+ values_buffered, reader);
+ case parquet::Type::DOUBLE:
+ return ScanAll<parquet::DoubleReader>(batch_size, def_levels, rep_levels, values,
+ values_buffered, reader);
+ case parquet::Type::BYTE_ARRAY:
+ return ScanAll<parquet::ByteArrayReader>(batch_size, def_levels, rep_levels, values,
+ values_buffered, reader);
+ case parquet::Type::FIXED_LEN_BYTE_ARRAY:
+ return ScanAll<parquet::FixedLenByteArrayReader>(batch_size, def_levels, rep_levels,
+ values, values_buffered, reader);
+ default:
+ parquet::ParquetException::NYI("type reader not implemented");
+ }
+ // Unreachable code, but suppress compiler warning
+ return 0;
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/column_scanner.h b/src/arrow/cpp/src/parquet/column_scanner.h
new file mode 100644
index 000000000..d53435f03
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/column_scanner.h
@@ -0,0 +1,262 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <stdio.h>
+
+#include <cstdint>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "parquet/column_reader.h"
+#include "parquet/exception.h"
+#include "parquet/platform.h"
+#include "parquet/schema.h"
+#include "parquet/types.h"
+
+namespace parquet {
+
+static constexpr int64_t DEFAULT_SCANNER_BATCH_SIZE = 128;
+
+class PARQUET_EXPORT Scanner {
+ public:
+ explicit Scanner(std::shared_ptr<ColumnReader> reader,
+ int64_t batch_size = DEFAULT_SCANNER_BATCH_SIZE,
+ ::arrow::MemoryPool* pool = ::arrow::default_memory_pool())
+ : batch_size_(batch_size),
+ level_offset_(0),
+ levels_buffered_(0),
+ value_buffer_(AllocateBuffer(pool)),
+ value_offset_(0),
+ values_buffered_(0),
+ reader_(std::move(reader)) {
+ def_levels_.resize(descr()->max_definition_level() > 0 ? batch_size_ : 0);
+ rep_levels_.resize(descr()->max_repetition_level() > 0 ? batch_size_ : 0);
+ }
+
+ virtual ~Scanner() {}
+
+ static std::shared_ptr<Scanner> Make(
+ std::shared_ptr<ColumnReader> col_reader,
+ int64_t batch_size = DEFAULT_SCANNER_BATCH_SIZE,
+ ::arrow::MemoryPool* pool = ::arrow::default_memory_pool());
+
+ virtual void PrintNext(std::ostream& out, int width, bool with_levels = false) = 0;
+
+ bool HasNext() { return level_offset_ < levels_buffered_ || reader_->HasNext(); }
+
+ const ColumnDescriptor* descr() const { return reader_->descr(); }
+
+ int64_t batch_size() const { return batch_size_; }
+
+ void SetBatchSize(int64_t batch_size) { batch_size_ = batch_size; }
+
+ protected:
+ int64_t batch_size_;
+
+ std::vector<int16_t> def_levels_;
+ std::vector<int16_t> rep_levels_;
+ int level_offset_;
+ int levels_buffered_;
+
+ std::shared_ptr<ResizableBuffer> value_buffer_;
+ int value_offset_;
+ int64_t values_buffered_;
+ std::shared_ptr<ColumnReader> reader_;
+};
+
+template <typename DType>
+class PARQUET_TEMPLATE_CLASS_EXPORT TypedScanner : public Scanner {
+ public:
+ typedef typename DType::c_type T;
+
+ explicit TypedScanner(std::shared_ptr<ColumnReader> reader,
+ int64_t batch_size = DEFAULT_SCANNER_BATCH_SIZE,
+ ::arrow::MemoryPool* pool = ::arrow::default_memory_pool())
+ : Scanner(std::move(reader), batch_size, pool) {
+ typed_reader_ = static_cast<TypedColumnReader<DType>*>(reader_.get());
+ int value_byte_size = type_traits<DType::type_num>::value_byte_size;
+ PARQUET_THROW_NOT_OK(value_buffer_->Resize(batch_size_ * value_byte_size));
+ values_ = reinterpret_cast<T*>(value_buffer_->mutable_data());
+ }
+
+ virtual ~TypedScanner() {}
+
+ bool NextLevels(int16_t* def_level, int16_t* rep_level) {
+ if (level_offset_ == levels_buffered_) {
+ levels_buffered_ = static_cast<int>(
+ typed_reader_->ReadBatch(static_cast<int>(batch_size_), def_levels_.data(),
+ rep_levels_.data(), values_, &values_buffered_));
+
+ value_offset_ = 0;
+ level_offset_ = 0;
+ if (!levels_buffered_) {
+ return false;
+ }
+ }
+ *def_level = descr()->max_definition_level() > 0 ? def_levels_[level_offset_] : 0;
+ *rep_level = descr()->max_repetition_level() > 0 ? rep_levels_[level_offset_] : 0;
+ level_offset_++;
+ return true;
+ }
+
+ bool Next(T* val, int16_t* def_level, int16_t* rep_level, bool* is_null) {
+ if (level_offset_ == levels_buffered_) {
+ if (!HasNext()) {
+ // Out of data pages
+ return false;
+ }
+ }
+
+ NextLevels(def_level, rep_level);
+ *is_null = *def_level < descr()->max_definition_level();
+
+ if (*is_null) {
+ return true;
+ }
+
+ if (value_offset_ == values_buffered_) {
+ throw ParquetException("Value was non-null, but has not been buffered");
+ }
+ *val = values_[value_offset_++];
+ return true;
+ }
+
+ // Returns true if there is a next value
+ bool NextValue(T* val, bool* is_null) {
+ if (level_offset_ == levels_buffered_) {
+ if (!HasNext()) {
+ // Out of data pages
+ return false;
+ }
+ }
+
+ // Out of values
+ int16_t def_level = -1;
+ int16_t rep_level = -1;
+ NextLevels(&def_level, &rep_level);
+ *is_null = def_level < descr()->max_definition_level();
+
+ if (*is_null) {
+ return true;
+ }
+
+ if (value_offset_ == values_buffered_) {
+ throw ParquetException("Value was non-null, but has not been buffered");
+ }
+ *val = values_[value_offset_++];
+ return true;
+ }
+
+ virtual void PrintNext(std::ostream& out, int width, bool with_levels = false) {
+ T val{};
+ int16_t def_level = -1;
+ int16_t rep_level = -1;
+ bool is_null = false;
+ char buffer[80];
+
+ if (!Next(&val, &def_level, &rep_level, &is_null)) {
+ throw ParquetException("No more values buffered");
+ }
+
+ if (with_levels) {
+ out << " D:" << def_level << " R:" << rep_level << " ";
+ if (!is_null) {
+ out << "V:";
+ }
+ }
+
+ if (is_null) {
+ std::string null_fmt = format_fwf<ByteArrayType>(width);
+ snprintf(buffer, sizeof(buffer), null_fmt.c_str(), "NULL");
+ } else {
+ FormatValue(&val, buffer, sizeof(buffer), width);
+ }
+ out << buffer;
+ }
+
+ private:
+ // The ownership of this object is expressed through the reader_ variable in the base
+ TypedColumnReader<DType>* typed_reader_;
+
+ inline void FormatValue(void* val, char* buffer, int bufsize, int width);
+
+ T* values_;
+};
+
+template <typename DType>
+inline void TypedScanner<DType>::FormatValue(void* val, char* buffer, int bufsize,
+ int width) {
+ std::string fmt = format_fwf<DType>(width);
+ snprintf(buffer, bufsize, fmt.c_str(), *reinterpret_cast<T*>(val));
+}
+
+template <>
+inline void TypedScanner<Int96Type>::FormatValue(void* val, char* buffer, int bufsize,
+ int width) {
+ std::string fmt = format_fwf<Int96Type>(width);
+ std::string result = Int96ToString(*reinterpret_cast<Int96*>(val));
+ snprintf(buffer, bufsize, fmt.c_str(), result.c_str());
+}
+
+template <>
+inline void TypedScanner<ByteArrayType>::FormatValue(void* val, char* buffer, int bufsize,
+ int width) {
+ std::string fmt = format_fwf<ByteArrayType>(width);
+ std::string result = ByteArrayToString(*reinterpret_cast<ByteArray*>(val));
+ snprintf(buffer, bufsize, fmt.c_str(), result.c_str());
+}
+
+template <>
+inline void TypedScanner<FLBAType>::FormatValue(void* val, char* buffer, int bufsize,
+ int width) {
+ std::string fmt = format_fwf<FLBAType>(width);
+ std::string result = FixedLenByteArrayToString(
+ *reinterpret_cast<FixedLenByteArray*>(val), descr()->type_length());
+ snprintf(buffer, bufsize, fmt.c_str(), result.c_str());
+}
+
+typedef TypedScanner<BooleanType> BoolScanner;
+typedef TypedScanner<Int32Type> Int32Scanner;
+typedef TypedScanner<Int64Type> Int64Scanner;
+typedef TypedScanner<Int96Type> Int96Scanner;
+typedef TypedScanner<FloatType> FloatScanner;
+typedef TypedScanner<DoubleType> DoubleScanner;
+typedef TypedScanner<ByteArrayType> ByteArrayScanner;
+typedef TypedScanner<FLBAType> FixedLenByteArrayScanner;
+
+template <typename RType>
+int64_t ScanAll(int32_t batch_size, int16_t* def_levels, int16_t* rep_levels,
+ uint8_t* values, int64_t* values_buffered,
+ parquet::ColumnReader* reader) {
+ typedef typename RType::T Type;
+ auto typed_reader = static_cast<RType*>(reader);
+ auto vals = reinterpret_cast<Type*>(&values[0]);
+ return typed_reader->ReadBatch(batch_size, def_levels, rep_levels, vals,
+ values_buffered);
+}
+
+int64_t PARQUET_EXPORT ScanAllValues(int32_t batch_size, int16_t* def_levels,
+ int16_t* rep_levels, uint8_t* values,
+ int64_t* values_buffered,
+ parquet::ColumnReader* reader);
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/column_scanner_test.cc b/src/arrow/cpp/src/parquet/column_scanner_test.cc
new file mode 100644
index 000000000..f6d162e3d
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/column_scanner_test.cc
@@ -0,0 +1,229 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/testing/gtest_compat.h"
+
+#include "parquet/column_page.h"
+#include "parquet/column_scanner.h"
+#include "parquet/schema.h"
+#include "parquet/test_util.h"
+#include "parquet/types.h"
+
+namespace parquet {
+
+using schema::NodePtr;
+
+namespace test {
+
+template <typename Type>
+class TestFlatScanner : public ::testing::Test {
+ public:
+ using c_type = typename Type::c_type;
+
+ void InitScanner(const ColumnDescriptor* d) {
+ std::unique_ptr<PageReader> pager(new test::MockPageReader(pages_));
+ scanner_ = Scanner::Make(ColumnReader::Make(d, std::move(pager)));
+ }
+
+ void CheckResults(int batch_size, const ColumnDescriptor* d) {
+ TypedScanner<Type>* scanner = reinterpret_cast<TypedScanner<Type>*>(scanner_.get());
+ c_type val;
+ bool is_null = false;
+ int16_t def_level;
+ int16_t rep_level;
+ int j = 0;
+ scanner->SetBatchSize(batch_size);
+ for (int i = 0; i < num_levels_; i++) {
+ ASSERT_TRUE(scanner->Next(&val, &def_level, &rep_level, &is_null)) << i << j;
+ if (!is_null) {
+ ASSERT_EQ(values_[j], val) << i << "V" << j;
+ j++;
+ }
+ if (d->max_definition_level() > 0) {
+ ASSERT_EQ(def_levels_[i], def_level) << i << "D" << j;
+ }
+ if (d->max_repetition_level() > 0) {
+ ASSERT_EQ(rep_levels_[i], rep_level) << i << "R" << j;
+ }
+ }
+ ASSERT_EQ(num_values_, j);
+ ASSERT_FALSE(scanner->Next(&val, &def_level, &rep_level, &is_null));
+ }
+
+ void Clear() {
+ pages_.clear();
+ values_.clear();
+ def_levels_.clear();
+ rep_levels_.clear();
+ }
+
+ void Execute(int num_pages, int levels_per_page, int batch_size,
+ const ColumnDescriptor* d, Encoding::type encoding) {
+ num_values_ = MakePages<Type>(d, num_pages, levels_per_page, def_levels_, rep_levels_,
+ values_, data_buffer_, pages_, encoding);
+ num_levels_ = num_pages * levels_per_page;
+ InitScanner(d);
+ CheckResults(batch_size, d);
+ Clear();
+ }
+
+ void InitDescriptors(std::shared_ptr<ColumnDescriptor>& d1,
+ std::shared_ptr<ColumnDescriptor>& d2,
+ std::shared_ptr<ColumnDescriptor>& d3, int length) {
+ NodePtr type;
+ type = schema::PrimitiveNode::Make("c1", Repetition::REQUIRED, Type::type_num,
+ ConvertedType::NONE, length);
+ d1.reset(new ColumnDescriptor(type, 0, 0));
+ type = schema::PrimitiveNode::Make("c2", Repetition::OPTIONAL, Type::type_num,
+ ConvertedType::NONE, length);
+ d2.reset(new ColumnDescriptor(type, 4, 0));
+ type = schema::PrimitiveNode::Make("c3", Repetition::REPEATED, Type::type_num,
+ ConvertedType::NONE, length);
+ d3.reset(new ColumnDescriptor(type, 4, 2));
+ }
+
+ void ExecuteAll(int num_pages, int num_levels, int batch_size, int type_length,
+ Encoding::type encoding = Encoding::PLAIN) {
+ std::shared_ptr<ColumnDescriptor> d1;
+ std::shared_ptr<ColumnDescriptor> d2;
+ std::shared_ptr<ColumnDescriptor> d3;
+ InitDescriptors(d1, d2, d3, type_length);
+ // evaluate REQUIRED pages
+ Execute(num_pages, num_levels, batch_size, d1.get(), encoding);
+ // evaluate OPTIONAL pages
+ Execute(num_pages, num_levels, batch_size, d2.get(), encoding);
+ // evaluate REPEATED pages
+ Execute(num_pages, num_levels, batch_size, d3.get(), encoding);
+ }
+
+ protected:
+ int num_levels_;
+ int num_values_;
+ std::vector<std::shared_ptr<Page>> pages_;
+ std::shared_ptr<Scanner> scanner_;
+ std::vector<c_type> values_;
+ std::vector<int16_t> def_levels_;
+ std::vector<int16_t> rep_levels_;
+ std::vector<uint8_t> data_buffer_; // For BA and FLBA
+};
+
+static int num_levels_per_page = 100;
+static int num_pages = 20;
+static int batch_size = 32;
+
+typedef ::testing::Types<Int32Type, Int64Type, Int96Type, FloatType, DoubleType,
+ ByteArrayType>
+ TestTypes;
+
+using TestBooleanFlatScanner = TestFlatScanner<BooleanType>;
+using TestFLBAFlatScanner = TestFlatScanner<FLBAType>;
+
+TYPED_TEST_SUITE(TestFlatScanner, TestTypes);
+
+TYPED_TEST(TestFlatScanner, TestPlainScanner) {
+ ASSERT_NO_FATAL_FAILURE(
+ this->ExecuteAll(num_pages, num_levels_per_page, batch_size, 0, Encoding::PLAIN));
+}
+
+TYPED_TEST(TestFlatScanner, TestDictScanner) {
+ ASSERT_NO_FATAL_FAILURE(this->ExecuteAll(num_pages, num_levels_per_page, batch_size, 0,
+ Encoding::RLE_DICTIONARY));
+}
+
+TEST_F(TestBooleanFlatScanner, TestPlainScanner) {
+ ASSERT_NO_FATAL_FAILURE(
+ this->ExecuteAll(num_pages, num_levels_per_page, batch_size, 0));
+}
+
+TEST_F(TestFLBAFlatScanner, TestPlainScanner) {
+ ASSERT_NO_FATAL_FAILURE(
+ this->ExecuteAll(num_pages, num_levels_per_page, batch_size, FLBA_LENGTH));
+}
+
+TEST_F(TestFLBAFlatScanner, TestDictScanner) {
+ ASSERT_NO_FATAL_FAILURE(this->ExecuteAll(num_pages, num_levels_per_page, batch_size,
+ FLBA_LENGTH, Encoding::RLE_DICTIONARY));
+}
+
+TEST_F(TestFLBAFlatScanner, TestPlainDictScanner) {
+ ASSERT_NO_FATAL_FAILURE(this->ExecuteAll(num_pages, num_levels_per_page, batch_size,
+ FLBA_LENGTH, Encoding::PLAIN_DICTIONARY));
+}
+
+// PARQUET 502
+TEST_F(TestFLBAFlatScanner, TestSmallBatch) {
+ NodePtr type =
+ schema::PrimitiveNode::Make("c1", Repetition::REQUIRED, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::DECIMAL, FLBA_LENGTH, 10, 2);
+ const ColumnDescriptor d(type, 0, 0);
+ num_values_ = MakePages<FLBAType>(&d, 1, 100, def_levels_, rep_levels_, values_,
+ data_buffer_, pages_);
+ num_levels_ = 1 * 100;
+ InitScanner(&d);
+ ASSERT_NO_FATAL_FAILURE(CheckResults(1, &d));
+}
+
+TEST_F(TestFLBAFlatScanner, TestDescriptorAPI) {
+ NodePtr type =
+ schema::PrimitiveNode::Make("c1", Repetition::OPTIONAL, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::DECIMAL, FLBA_LENGTH, 10, 2);
+ const ColumnDescriptor d(type, 4, 0);
+ num_values_ = MakePages<FLBAType>(&d, 1, 100, def_levels_, rep_levels_, values_,
+ data_buffer_, pages_);
+ num_levels_ = 1 * 100;
+ InitScanner(&d);
+ TypedScanner<FLBAType>* scanner =
+ reinterpret_cast<TypedScanner<FLBAType>*>(scanner_.get());
+ ASSERT_EQ(10, scanner->descr()->type_precision());
+ ASSERT_EQ(2, scanner->descr()->type_scale());
+ ASSERT_EQ(FLBA_LENGTH, scanner->descr()->type_length());
+}
+
+TEST_F(TestFLBAFlatScanner, TestFLBAPrinterNext) {
+ NodePtr type =
+ schema::PrimitiveNode::Make("c1", Repetition::OPTIONAL, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::DECIMAL, FLBA_LENGTH, 10, 2);
+ const ColumnDescriptor d(type, 4, 0);
+ num_values_ = MakePages<FLBAType>(&d, 1, 100, def_levels_, rep_levels_, values_,
+ data_buffer_, pages_);
+ num_levels_ = 1 * 100;
+ InitScanner(&d);
+ TypedScanner<FLBAType>* scanner =
+ reinterpret_cast<TypedScanner<FLBAType>*>(scanner_.get());
+ scanner->SetBatchSize(batch_size);
+ std::stringstream ss_fail;
+ for (int i = 0; i < num_levels_; i++) {
+ std::stringstream ss;
+ scanner->PrintNext(ss, 17);
+ std::string result = ss.str();
+ ASSERT_LE(17, result.size()) << i;
+ }
+ ASSERT_THROW(scanner->PrintNext(ss_fail, 17), ParquetException);
+}
+
+} // namespace test
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/column_writer.cc b/src/arrow/cpp/src/parquet/column_writer.cc
new file mode 100644
index 000000000..df535bcbe
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/column_writer.cc
@@ -0,0 +1,2103 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/column_writer.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/compute/api.h"
+#include "arrow/io/memory.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_stream_utils.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/compression.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/rle_encoding.h"
+#include "arrow/visitor_inline.h"
+#include "parquet/column_page.h"
+#include "parquet/encoding.h"
+#include "parquet/encryption/encryption_internal.h"
+#include "parquet/encryption/internal_file_encryptor.h"
+#include "parquet/level_conversion.h"
+#include "parquet/metadata.h"
+#include "parquet/platform.h"
+#include "parquet/properties.h"
+#include "parquet/schema.h"
+#include "parquet/statistics.h"
+#include "parquet/thrift_internal.h"
+#include "parquet/types.h"
+
+using arrow::Array;
+using arrow::ArrayData;
+using arrow::Datum;
+using arrow::Result;
+using arrow::Status;
+using arrow::BitUtil::BitWriter;
+using arrow::internal::checked_cast;
+using arrow::internal::checked_pointer_cast;
+using arrow::util::RleEncoder;
+
+namespace BitUtil = arrow::BitUtil;
+
+namespace parquet {
+
+namespace {
+
+// Visitor that exracts the value buffer from a FlatArray at a given offset.
+struct ValueBufferSlicer {
+ template <typename T>
+ ::arrow::enable_if_base_binary<typename T::TypeClass, Status> Visit(const T& array) {
+ auto data = array.data();
+ buffer_ =
+ SliceBuffer(data->buffers[1], data->offset * sizeof(typename T::offset_type),
+ data->length * sizeof(typename T::offset_type));
+ return Status::OK();
+ }
+
+ template <typename T>
+ ::arrow::enable_if_fixed_size_binary<typename T::TypeClass, Status> Visit(
+ const T& array) {
+ auto data = array.data();
+ buffer_ = SliceBuffer(data->buffers[1], data->offset * array.byte_width(),
+ data->length * array.byte_width());
+ return Status::OK();
+ }
+
+ template <typename T>
+ ::arrow::enable_if_t<::arrow::has_c_type<typename T::TypeClass>::value &&
+ !std::is_same<BooleanType, typename T::TypeClass>::value,
+ Status>
+ Visit(const T& array) {
+ auto data = array.data();
+ buffer_ = SliceBuffer(
+ data->buffers[1],
+ ::arrow::TypeTraits<typename T::TypeClass>::bytes_required(data->offset),
+ ::arrow::TypeTraits<typename T::TypeClass>::bytes_required(data->length));
+ return Status::OK();
+ }
+
+ Status Visit(const ::arrow::BooleanArray& array) {
+ auto data = array.data();
+ if (BitUtil::IsMultipleOf8(data->offset)) {
+ buffer_ = SliceBuffer(data->buffers[1], BitUtil::BytesForBits(data->offset),
+ BitUtil::BytesForBits(data->length));
+ return Status::OK();
+ }
+ PARQUET_ASSIGN_OR_THROW(buffer_,
+ ::arrow::internal::CopyBitmap(pool_, data->buffers[1]->data(),
+ data->offset, data->length));
+ return Status::OK();
+ }
+#define NOT_IMPLEMENTED_VISIT(ArrowTypePrefix) \
+ Status Visit(const ::arrow::ArrowTypePrefix##Array& array) { \
+ return Status::NotImplemented("Slicing not implemented for " #ArrowTypePrefix); \
+ }
+
+ NOT_IMPLEMENTED_VISIT(Null);
+ NOT_IMPLEMENTED_VISIT(Union);
+ NOT_IMPLEMENTED_VISIT(List);
+ NOT_IMPLEMENTED_VISIT(LargeList);
+ NOT_IMPLEMENTED_VISIT(Struct);
+ NOT_IMPLEMENTED_VISIT(FixedSizeList);
+ NOT_IMPLEMENTED_VISIT(Dictionary);
+ NOT_IMPLEMENTED_VISIT(Extension);
+
+#undef NOT_IMPLEMENTED_VISIT
+
+ MemoryPool* pool_;
+ std::shared_ptr<Buffer> buffer_;
+};
+
+internal::LevelInfo ComputeLevelInfo(const ColumnDescriptor* descr) {
+ internal::LevelInfo level_info;
+ level_info.def_level = descr->max_definition_level();
+ level_info.rep_level = descr->max_repetition_level();
+
+ int16_t min_spaced_def_level = descr->max_definition_level();
+ const ::parquet::schema::Node* node = descr->schema_node().get();
+ while (node != nullptr && !node->is_repeated()) {
+ if (node->is_optional()) {
+ min_spaced_def_level--;
+ }
+ node = node->parent();
+ }
+ level_info.repeated_ancestor_def_level = min_spaced_def_level;
+ return level_info;
+}
+
+template <class T>
+inline const T* AddIfNotNull(const T* base, int64_t offset) {
+ if (base != nullptr) {
+ return base + offset;
+ }
+ return nullptr;
+}
+
+} // namespace
+
+LevelEncoder::LevelEncoder() {}
+LevelEncoder::~LevelEncoder() {}
+
+void LevelEncoder::Init(Encoding::type encoding, int16_t max_level,
+ int num_buffered_values, uint8_t* data, int data_size) {
+ bit_width_ = BitUtil::Log2(max_level + 1);
+ encoding_ = encoding;
+ switch (encoding) {
+ case Encoding::RLE: {
+ rle_encoder_.reset(new RleEncoder(data, data_size, bit_width_));
+ break;
+ }
+ case Encoding::BIT_PACKED: {
+ int num_bytes =
+ static_cast<int>(BitUtil::BytesForBits(num_buffered_values * bit_width_));
+ bit_packed_encoder_.reset(new BitWriter(data, num_bytes));
+ break;
+ }
+ default:
+ throw ParquetException("Unknown encoding type for levels.");
+ }
+}
+
+int LevelEncoder::MaxBufferSize(Encoding::type encoding, int16_t max_level,
+ int num_buffered_values) {
+ int bit_width = BitUtil::Log2(max_level + 1);
+ int num_bytes = 0;
+ switch (encoding) {
+ case Encoding::RLE: {
+ // TODO: Due to the way we currently check if the buffer is full enough,
+ // we need to have MinBufferSize as head room.
+ num_bytes = RleEncoder::MaxBufferSize(bit_width, num_buffered_values) +
+ RleEncoder::MinBufferSize(bit_width);
+ break;
+ }
+ case Encoding::BIT_PACKED: {
+ num_bytes =
+ static_cast<int>(BitUtil::BytesForBits(num_buffered_values * bit_width));
+ break;
+ }
+ default:
+ throw ParquetException("Unknown encoding type for levels.");
+ }
+ return num_bytes;
+}
+
+int LevelEncoder::Encode(int batch_size, const int16_t* levels) {
+ int num_encoded = 0;
+ if (!rle_encoder_ && !bit_packed_encoder_) {
+ throw ParquetException("Level encoders are not initialized.");
+ }
+
+ if (encoding_ == Encoding::RLE) {
+ for (int i = 0; i < batch_size; ++i) {
+ if (!rle_encoder_->Put(*(levels + i))) {
+ break;
+ }
+ ++num_encoded;
+ }
+ rle_encoder_->Flush();
+ rle_length_ = rle_encoder_->len();
+ } else {
+ for (int i = 0; i < batch_size; ++i) {
+ if (!bit_packed_encoder_->PutValue(*(levels + i), bit_width_)) {
+ break;
+ }
+ ++num_encoded;
+ }
+ bit_packed_encoder_->Flush();
+ }
+ return num_encoded;
+}
+
+// ----------------------------------------------------------------------
+// PageWriter implementation
+
+// This subclass delimits pages appearing in a serialized stream, each preceded
+// by a serialized Thrift format::PageHeader indicating the type of each page
+// and the page metadata.
+class SerializedPageWriter : public PageWriter {
+ public:
+ SerializedPageWriter(std::shared_ptr<ArrowOutputStream> sink, Compression::type codec,
+ int compression_level, ColumnChunkMetaDataBuilder* metadata,
+ int16_t row_group_ordinal, int16_t column_chunk_ordinal,
+ MemoryPool* pool = ::arrow::default_memory_pool(),
+ std::shared_ptr<Encryptor> meta_encryptor = nullptr,
+ std::shared_ptr<Encryptor> data_encryptor = nullptr)
+ : sink_(std::move(sink)),
+ metadata_(metadata),
+ pool_(pool),
+ num_values_(0),
+ dictionary_page_offset_(0),
+ data_page_offset_(0),
+ total_uncompressed_size_(0),
+ total_compressed_size_(0),
+ page_ordinal_(0),
+ row_group_ordinal_(row_group_ordinal),
+ column_ordinal_(column_chunk_ordinal),
+ meta_encryptor_(std::move(meta_encryptor)),
+ data_encryptor_(std::move(data_encryptor)),
+ encryption_buffer_(AllocateBuffer(pool, 0)) {
+ if (data_encryptor_ != nullptr || meta_encryptor_ != nullptr) {
+ InitEncryption();
+ }
+ compressor_ = GetCodec(codec, compression_level);
+ thrift_serializer_.reset(new ThriftSerializer);
+ }
+
+ int64_t WriteDictionaryPage(const DictionaryPage& page) override {
+ int64_t uncompressed_size = page.size();
+ std::shared_ptr<Buffer> compressed_data;
+ if (has_compressor()) {
+ auto buffer = std::static_pointer_cast<ResizableBuffer>(
+ AllocateBuffer(pool_, uncompressed_size));
+ Compress(*(page.buffer().get()), buffer.get());
+ compressed_data = std::static_pointer_cast<Buffer>(buffer);
+ } else {
+ compressed_data = page.buffer();
+ }
+
+ format::DictionaryPageHeader dict_page_header;
+ dict_page_header.__set_num_values(page.num_values());
+ dict_page_header.__set_encoding(ToThrift(page.encoding()));
+ dict_page_header.__set_is_sorted(page.is_sorted());
+
+ const uint8_t* output_data_buffer = compressed_data->data();
+ int32_t output_data_len = static_cast<int32_t>(compressed_data->size());
+
+ if (data_encryptor_.get()) {
+ UpdateEncryption(encryption::kDictionaryPage);
+ PARQUET_THROW_NOT_OK(encryption_buffer_->Resize(
+ data_encryptor_->CiphertextSizeDelta() + output_data_len, false));
+ output_data_len = data_encryptor_->Encrypt(compressed_data->data(), output_data_len,
+ encryption_buffer_->mutable_data());
+ output_data_buffer = encryption_buffer_->data();
+ }
+
+ format::PageHeader page_header;
+ page_header.__set_type(format::PageType::DICTIONARY_PAGE);
+ page_header.__set_uncompressed_page_size(static_cast<int32_t>(uncompressed_size));
+ page_header.__set_compressed_page_size(static_cast<int32_t>(output_data_len));
+ page_header.__set_dictionary_page_header(dict_page_header);
+ // TODO(PARQUET-594) crc checksum
+
+ PARQUET_ASSIGN_OR_THROW(int64_t start_pos, sink_->Tell());
+ if (dictionary_page_offset_ == 0) {
+ dictionary_page_offset_ = start_pos;
+ }
+
+ if (meta_encryptor_) {
+ UpdateEncryption(encryption::kDictionaryPageHeader);
+ }
+ const int64_t header_size =
+ thrift_serializer_->Serialize(&page_header, sink_.get(), meta_encryptor_);
+
+ PARQUET_THROW_NOT_OK(sink_->Write(output_data_buffer, output_data_len));
+
+ total_uncompressed_size_ += uncompressed_size + header_size;
+ total_compressed_size_ += output_data_len + header_size;
+ ++dict_encoding_stats_[page.encoding()];
+ return uncompressed_size + header_size;
+ }
+
+ void Close(bool has_dictionary, bool fallback) override {
+ if (meta_encryptor_ != nullptr) {
+ UpdateEncryption(encryption::kColumnMetaData);
+ }
+ // index_page_offset = -1 since they are not supported
+ metadata_->Finish(num_values_, dictionary_page_offset_, -1, data_page_offset_,
+ total_compressed_size_, total_uncompressed_size_, has_dictionary,
+ fallback, dict_encoding_stats_, data_encoding_stats_,
+ meta_encryptor_);
+ // Write metadata at end of column chunk
+ metadata_->WriteTo(sink_.get());
+ }
+
+ /**
+ * Compress a buffer.
+ */
+ void Compress(const Buffer& src_buffer, ResizableBuffer* dest_buffer) override {
+ DCHECK(compressor_ != nullptr);
+
+ // Compress the data
+ int64_t max_compressed_size =
+ compressor_->MaxCompressedLen(src_buffer.size(), src_buffer.data());
+
+ // Use Arrow::Buffer::shrink_to_fit = false
+ // underlying buffer only keeps growing. Resize to a smaller size does not reallocate.
+ PARQUET_THROW_NOT_OK(dest_buffer->Resize(max_compressed_size, false));
+
+ PARQUET_ASSIGN_OR_THROW(
+ int64_t compressed_size,
+ compressor_->Compress(src_buffer.size(), src_buffer.data(), max_compressed_size,
+ dest_buffer->mutable_data()));
+ PARQUET_THROW_NOT_OK(dest_buffer->Resize(compressed_size, false));
+ }
+
+ int64_t WriteDataPage(const DataPage& page) override {
+ const int64_t uncompressed_size = page.uncompressed_size();
+ std::shared_ptr<Buffer> compressed_data = page.buffer();
+ const uint8_t* output_data_buffer = compressed_data->data();
+ int32_t output_data_len = static_cast<int32_t>(compressed_data->size());
+
+ if (data_encryptor_.get()) {
+ PARQUET_THROW_NOT_OK(encryption_buffer_->Resize(
+ data_encryptor_->CiphertextSizeDelta() + output_data_len, false));
+ UpdateEncryption(encryption::kDataPage);
+ output_data_len = data_encryptor_->Encrypt(compressed_data->data(), output_data_len,
+ encryption_buffer_->mutable_data());
+ output_data_buffer = encryption_buffer_->data();
+ }
+
+ format::PageHeader page_header;
+ page_header.__set_uncompressed_page_size(static_cast<int32_t>(uncompressed_size));
+ page_header.__set_compressed_page_size(static_cast<int32_t>(output_data_len));
+ // TODO(PARQUET-594) crc checksum
+
+ if (page.type() == PageType::DATA_PAGE) {
+ const DataPageV1& v1_page = checked_cast<const DataPageV1&>(page);
+ SetDataPageHeader(page_header, v1_page);
+ } else if (page.type() == PageType::DATA_PAGE_V2) {
+ const DataPageV2& v2_page = checked_cast<const DataPageV2&>(page);
+ SetDataPageV2Header(page_header, v2_page);
+ } else {
+ throw ParquetException("Unexpected page type");
+ }
+
+ PARQUET_ASSIGN_OR_THROW(int64_t start_pos, sink_->Tell());
+ if (page_ordinal_ == 0) {
+ data_page_offset_ = start_pos;
+ }
+
+ if (meta_encryptor_) {
+ UpdateEncryption(encryption::kDataPageHeader);
+ }
+ const int64_t header_size =
+ thrift_serializer_->Serialize(&page_header, sink_.get(), meta_encryptor_);
+ PARQUET_THROW_NOT_OK(sink_->Write(output_data_buffer, output_data_len));
+
+ total_uncompressed_size_ += uncompressed_size + header_size;
+ total_compressed_size_ += output_data_len + header_size;
+ num_values_ += page.num_values();
+ ++data_encoding_stats_[page.encoding()];
+ ++page_ordinal_;
+ return uncompressed_size + header_size;
+ }
+
+ void SetDataPageHeader(format::PageHeader& page_header, const DataPageV1& page) {
+ format::DataPageHeader data_page_header;
+ data_page_header.__set_num_values(page.num_values());
+ data_page_header.__set_encoding(ToThrift(page.encoding()));
+ data_page_header.__set_definition_level_encoding(
+ ToThrift(page.definition_level_encoding()));
+ data_page_header.__set_repetition_level_encoding(
+ ToThrift(page.repetition_level_encoding()));
+ data_page_header.__set_statistics(ToThrift(page.statistics()));
+
+ page_header.__set_type(format::PageType::DATA_PAGE);
+ page_header.__set_data_page_header(data_page_header);
+ }
+
+ void SetDataPageV2Header(format::PageHeader& page_header, const DataPageV2 page) {
+ format::DataPageHeaderV2 data_page_header;
+ data_page_header.__set_num_values(page.num_values());
+ data_page_header.__set_num_nulls(page.num_nulls());
+ data_page_header.__set_num_rows(page.num_rows());
+ data_page_header.__set_encoding(ToThrift(page.encoding()));
+
+ data_page_header.__set_definition_levels_byte_length(
+ page.definition_levels_byte_length());
+ data_page_header.__set_repetition_levels_byte_length(
+ page.repetition_levels_byte_length());
+
+ data_page_header.__set_is_compressed(page.is_compressed());
+ data_page_header.__set_statistics(ToThrift(page.statistics()));
+
+ page_header.__set_type(format::PageType::DATA_PAGE_V2);
+ page_header.__set_data_page_header_v2(data_page_header);
+ }
+
+ bool has_compressor() override { return (compressor_ != nullptr); }
+
+ int64_t num_values() { return num_values_; }
+
+ int64_t dictionary_page_offset() { return dictionary_page_offset_; }
+
+ int64_t data_page_offset() { return data_page_offset_; }
+
+ int64_t total_compressed_size() { return total_compressed_size_; }
+
+ int64_t total_uncompressed_size() { return total_uncompressed_size_; }
+
+ private:
+ // To allow UpdateEncryption on Close
+ friend class BufferedPageWriter;
+
+ void InitEncryption() {
+ // Prepare the AAD for quick update later.
+ if (data_encryptor_ != nullptr) {
+ data_page_aad_ = encryption::CreateModuleAad(
+ data_encryptor_->file_aad(), encryption::kDataPage, row_group_ordinal_,
+ column_ordinal_, kNonPageOrdinal);
+ }
+ if (meta_encryptor_ != nullptr) {
+ data_page_header_aad_ = encryption::CreateModuleAad(
+ meta_encryptor_->file_aad(), encryption::kDataPageHeader, row_group_ordinal_,
+ column_ordinal_, kNonPageOrdinal);
+ }
+ }
+
+ void UpdateEncryption(int8_t module_type) {
+ switch (module_type) {
+ case encryption::kColumnMetaData: {
+ meta_encryptor_->UpdateAad(encryption::CreateModuleAad(
+ meta_encryptor_->file_aad(), module_type, row_group_ordinal_, column_ordinal_,
+ kNonPageOrdinal));
+ break;
+ }
+ case encryption::kDataPage: {
+ encryption::QuickUpdatePageAad(data_page_aad_, page_ordinal_);
+ data_encryptor_->UpdateAad(data_page_aad_);
+ break;
+ }
+ case encryption::kDataPageHeader: {
+ encryption::QuickUpdatePageAad(data_page_header_aad_, page_ordinal_);
+ meta_encryptor_->UpdateAad(data_page_header_aad_);
+ break;
+ }
+ case encryption::kDictionaryPageHeader: {
+ meta_encryptor_->UpdateAad(encryption::CreateModuleAad(
+ meta_encryptor_->file_aad(), module_type, row_group_ordinal_, column_ordinal_,
+ kNonPageOrdinal));
+ break;
+ }
+ case encryption::kDictionaryPage: {
+ data_encryptor_->UpdateAad(encryption::CreateModuleAad(
+ data_encryptor_->file_aad(), module_type, row_group_ordinal_, column_ordinal_,
+ kNonPageOrdinal));
+ break;
+ }
+ default:
+ throw ParquetException("Unknown module type in UpdateEncryption");
+ }
+ }
+
+ std::shared_ptr<ArrowOutputStream> sink_;
+ ColumnChunkMetaDataBuilder* metadata_;
+ MemoryPool* pool_;
+ int64_t num_values_;
+ int64_t dictionary_page_offset_;
+ int64_t data_page_offset_;
+ int64_t total_uncompressed_size_;
+ int64_t total_compressed_size_;
+ int16_t page_ordinal_;
+ int16_t row_group_ordinal_;
+ int16_t column_ordinal_;
+
+ std::unique_ptr<ThriftSerializer> thrift_serializer_;
+
+ // Compression codec to use.
+ std::unique_ptr<::arrow::util::Codec> compressor_;
+
+ std::string data_page_aad_;
+ std::string data_page_header_aad_;
+
+ std::shared_ptr<Encryptor> meta_encryptor_;
+ std::shared_ptr<Encryptor> data_encryptor_;
+
+ std::shared_ptr<ResizableBuffer> encryption_buffer_;
+
+ std::map<Encoding::type, int32_t> dict_encoding_stats_;
+ std::map<Encoding::type, int32_t> data_encoding_stats_;
+};
+
+// This implementation of the PageWriter writes to the final sink on Close .
+class BufferedPageWriter : public PageWriter {
+ public:
+ BufferedPageWriter(std::shared_ptr<ArrowOutputStream> sink, Compression::type codec,
+ int compression_level, ColumnChunkMetaDataBuilder* metadata,
+ int16_t row_group_ordinal, int16_t current_column_ordinal,
+ MemoryPool* pool = ::arrow::default_memory_pool(),
+ std::shared_ptr<Encryptor> meta_encryptor = nullptr,
+ std::shared_ptr<Encryptor> data_encryptor = nullptr)
+ : final_sink_(std::move(sink)), metadata_(metadata), has_dictionary_pages_(false) {
+ in_memory_sink_ = CreateOutputStream(pool);
+ pager_ = std::unique_ptr<SerializedPageWriter>(
+ new SerializedPageWriter(in_memory_sink_, codec, compression_level, metadata,
+ row_group_ordinal, current_column_ordinal, pool,
+ std::move(meta_encryptor), std::move(data_encryptor)));
+ }
+
+ int64_t WriteDictionaryPage(const DictionaryPage& page) override {
+ has_dictionary_pages_ = true;
+ return pager_->WriteDictionaryPage(page);
+ }
+
+ void Close(bool has_dictionary, bool fallback) override {
+ if (pager_->meta_encryptor_ != nullptr) {
+ pager_->UpdateEncryption(encryption::kColumnMetaData);
+ }
+ // index_page_offset = -1 since they are not supported
+ PARQUET_ASSIGN_OR_THROW(int64_t final_position, final_sink_->Tell());
+ // dictionary page offset should be 0 iff there are no dictionary pages
+ auto dictionary_page_offset =
+ has_dictionary_pages_ ? pager_->dictionary_page_offset() + final_position : 0;
+ metadata_->Finish(pager_->num_values(), dictionary_page_offset, -1,
+ pager_->data_page_offset() + final_position,
+ pager_->total_compressed_size(), pager_->total_uncompressed_size(),
+ has_dictionary, fallback, pager_->dict_encoding_stats_,
+ pager_->data_encoding_stats_, pager_->meta_encryptor_);
+
+ // Write metadata at end of column chunk
+ metadata_->WriteTo(in_memory_sink_.get());
+
+ // flush everything to the serialized sink
+ PARQUET_ASSIGN_OR_THROW(auto buffer, in_memory_sink_->Finish());
+ PARQUET_THROW_NOT_OK(final_sink_->Write(buffer));
+ }
+
+ int64_t WriteDataPage(const DataPage& page) override {
+ return pager_->WriteDataPage(page);
+ }
+
+ void Compress(const Buffer& src_buffer, ResizableBuffer* dest_buffer) override {
+ pager_->Compress(src_buffer, dest_buffer);
+ }
+
+ bool has_compressor() override { return pager_->has_compressor(); }
+
+ private:
+ std::shared_ptr<ArrowOutputStream> final_sink_;
+ ColumnChunkMetaDataBuilder* metadata_;
+ std::shared_ptr<::arrow::io::BufferOutputStream> in_memory_sink_;
+ std::unique_ptr<SerializedPageWriter> pager_;
+ bool has_dictionary_pages_;
+};
+
+std::unique_ptr<PageWriter> PageWriter::Open(
+ std::shared_ptr<ArrowOutputStream> sink, Compression::type codec,
+ int compression_level, ColumnChunkMetaDataBuilder* metadata,
+ int16_t row_group_ordinal, int16_t column_chunk_ordinal, MemoryPool* pool,
+ bool buffered_row_group, std::shared_ptr<Encryptor> meta_encryptor,
+ std::shared_ptr<Encryptor> data_encryptor) {
+ if (buffered_row_group) {
+ return std::unique_ptr<PageWriter>(
+ new BufferedPageWriter(std::move(sink), codec, compression_level, metadata,
+ row_group_ordinal, column_chunk_ordinal, pool,
+ std::move(meta_encryptor), std::move(data_encryptor)));
+ } else {
+ return std::unique_ptr<PageWriter>(
+ new SerializedPageWriter(std::move(sink), codec, compression_level, metadata,
+ row_group_ordinal, column_chunk_ordinal, pool,
+ std::move(meta_encryptor), std::move(data_encryptor)));
+ }
+}
+
+// ----------------------------------------------------------------------
+// ColumnWriter
+
+const std::shared_ptr<WriterProperties>& default_writer_properties() {
+ static std::shared_ptr<WriterProperties> default_writer_properties =
+ WriterProperties::Builder().build();
+ return default_writer_properties;
+}
+
+class ColumnWriterImpl {
+ public:
+ ColumnWriterImpl(ColumnChunkMetaDataBuilder* metadata,
+ std::unique_ptr<PageWriter> pager, const bool use_dictionary,
+ Encoding::type encoding, const WriterProperties* properties)
+ : metadata_(metadata),
+ descr_(metadata->descr()),
+ level_info_(ComputeLevelInfo(metadata->descr())),
+ pager_(std::move(pager)),
+ has_dictionary_(use_dictionary),
+ encoding_(encoding),
+ properties_(properties),
+ allocator_(properties->memory_pool()),
+ num_buffered_values_(0),
+ num_buffered_encoded_values_(0),
+ rows_written_(0),
+ total_bytes_written_(0),
+ total_compressed_bytes_(0),
+ closed_(false),
+ fallback_(false),
+ definition_levels_sink_(allocator_),
+ repetition_levels_sink_(allocator_) {
+ definition_levels_rle_ =
+ std::static_pointer_cast<ResizableBuffer>(AllocateBuffer(allocator_, 0));
+ repetition_levels_rle_ =
+ std::static_pointer_cast<ResizableBuffer>(AllocateBuffer(allocator_, 0));
+ uncompressed_data_ =
+ std::static_pointer_cast<ResizableBuffer>(AllocateBuffer(allocator_, 0));
+
+ if (pager_->has_compressor()) {
+ compressor_temp_buffer_ =
+ std::static_pointer_cast<ResizableBuffer>(AllocateBuffer(allocator_, 0));
+ }
+ }
+
+ virtual ~ColumnWriterImpl() = default;
+
+ int64_t Close();
+
+ protected:
+ virtual std::shared_ptr<Buffer> GetValuesBuffer() = 0;
+
+ // Serializes Dictionary Page if enabled
+ virtual void WriteDictionaryPage() = 0;
+
+ // Plain-encoded statistics of the current page
+ virtual EncodedStatistics GetPageStatistics() = 0;
+
+ // Plain-encoded statistics of the whole chunk
+ virtual EncodedStatistics GetChunkStatistics() = 0;
+
+ // Merges page statistics into chunk statistics, then resets the values
+ virtual void ResetPageStatistics() = 0;
+
+ // Adds Data Pages to an in memory buffer in dictionary encoding mode
+ // Serializes the Data Pages in other encoding modes
+ void AddDataPage();
+
+ void BuildDataPageV1(int64_t definition_levels_rle_size,
+ int64_t repetition_levels_rle_size, int64_t uncompressed_size,
+ const std::shared_ptr<Buffer>& values);
+ void BuildDataPageV2(int64_t definition_levels_rle_size,
+ int64_t repetition_levels_rle_size, int64_t uncompressed_size,
+ const std::shared_ptr<Buffer>& values);
+
+ // Serializes Data Pages
+ void WriteDataPage(const DataPage& page) {
+ total_bytes_written_ += pager_->WriteDataPage(page);
+ }
+
+ // Write multiple definition levels
+ void WriteDefinitionLevels(int64_t num_levels, const int16_t* levels) {
+ DCHECK(!closed_);
+ PARQUET_THROW_NOT_OK(
+ definition_levels_sink_.Append(levels, sizeof(int16_t) * num_levels));
+ }
+
+ // Write multiple repetition levels
+ void WriteRepetitionLevels(int64_t num_levels, const int16_t* levels) {
+ DCHECK(!closed_);
+ PARQUET_THROW_NOT_OK(
+ repetition_levels_sink_.Append(levels, sizeof(int16_t) * num_levels));
+ }
+
+ // RLE encode the src_buffer into dest_buffer and return the encoded size
+ int64_t RleEncodeLevels(const void* src_buffer, ResizableBuffer* dest_buffer,
+ int16_t max_level, bool include_length_prefix = true);
+
+ // Serialize the buffered Data Pages
+ void FlushBufferedDataPages();
+
+ ColumnChunkMetaDataBuilder* metadata_;
+ const ColumnDescriptor* descr_;
+ // scratch buffer if validity bits need to be recalculated.
+ std::shared_ptr<ResizableBuffer> bits_buffer_;
+ const internal::LevelInfo level_info_;
+
+ std::unique_ptr<PageWriter> pager_;
+
+ bool has_dictionary_;
+ Encoding::type encoding_;
+ const WriterProperties* properties_;
+
+ LevelEncoder level_encoder_;
+
+ MemoryPool* allocator_;
+
+ // The total number of values stored in the data page. This is the maximum of
+ // the number of encoded definition levels or encoded values. For
+ // non-repeated, required columns, this is equal to the number of encoded
+ // values. For repeated or optional values, there may be fewer data values
+ // than levels, and this tells you how many encoded levels there are in that
+ // case.
+ int64_t num_buffered_values_;
+
+ // The total number of stored values. For repeated or optional values, this
+ // number may be lower than num_buffered_values_.
+ int64_t num_buffered_encoded_values_;
+
+ // Total number of rows written with this ColumnWriter
+ int64_t rows_written_;
+
+ // Records the total number of uncompressed bytes written by the serializer
+ int64_t total_bytes_written_;
+
+ // Records the current number of compressed bytes in a column
+ int64_t total_compressed_bytes_;
+
+ // Flag to check if the Writer has been closed
+ bool closed_;
+
+ // Flag to infer if dictionary encoding has fallen back to PLAIN
+ bool fallback_;
+
+ ::arrow::BufferBuilder definition_levels_sink_;
+ ::arrow::BufferBuilder repetition_levels_sink_;
+
+ std::shared_ptr<ResizableBuffer> definition_levels_rle_;
+ std::shared_ptr<ResizableBuffer> repetition_levels_rle_;
+
+ std::shared_ptr<ResizableBuffer> uncompressed_data_;
+ std::shared_ptr<ResizableBuffer> compressor_temp_buffer_;
+
+ std::vector<std::unique_ptr<DataPage>> data_pages_;
+
+ private:
+ void InitSinks() {
+ definition_levels_sink_.Rewind(0);
+ repetition_levels_sink_.Rewind(0);
+ }
+
+ // Concatenate the encoded levels and values into one buffer
+ void ConcatenateBuffers(int64_t definition_levels_rle_size,
+ int64_t repetition_levels_rle_size,
+ const std::shared_ptr<Buffer>& values, uint8_t* combined) {
+ memcpy(combined, repetition_levels_rle_->data(), repetition_levels_rle_size);
+ combined += repetition_levels_rle_size;
+ memcpy(combined, definition_levels_rle_->data(), definition_levels_rle_size);
+ combined += definition_levels_rle_size;
+ memcpy(combined, values->data(), values->size());
+ }
+};
+
+// return the size of the encoded buffer
+int64_t ColumnWriterImpl::RleEncodeLevels(const void* src_buffer,
+ ResizableBuffer* dest_buffer, int16_t max_level,
+ bool include_length_prefix) {
+ // V1 DataPage includes the length of the RLE level as a prefix.
+ int32_t prefix_size = include_length_prefix ? sizeof(int32_t) : 0;
+
+ // TODO: This only works with due to some RLE specifics
+ int64_t rle_size = LevelEncoder::MaxBufferSize(Encoding::RLE, max_level,
+ static_cast<int>(num_buffered_values_)) +
+ prefix_size;
+
+ // Use Arrow::Buffer::shrink_to_fit = false
+ // underlying buffer only keeps growing. Resize to a smaller size does not reallocate.
+ PARQUET_THROW_NOT_OK(dest_buffer->Resize(rle_size, false));
+
+ level_encoder_.Init(Encoding::RLE, max_level, static_cast<int>(num_buffered_values_),
+ dest_buffer->mutable_data() + prefix_size,
+ static_cast<int>(dest_buffer->size() - prefix_size));
+ int encoded = level_encoder_.Encode(static_cast<int>(num_buffered_values_),
+ reinterpret_cast<const int16_t*>(src_buffer));
+ DCHECK_EQ(encoded, num_buffered_values_);
+
+ if (include_length_prefix) {
+ reinterpret_cast<int32_t*>(dest_buffer->mutable_data())[0] = level_encoder_.len();
+ }
+
+ return level_encoder_.len() + prefix_size;
+}
+
+void ColumnWriterImpl::AddDataPage() {
+ int64_t definition_levels_rle_size = 0;
+ int64_t repetition_levels_rle_size = 0;
+
+ std::shared_ptr<Buffer> values = GetValuesBuffer();
+ bool is_v1_data_page = properties_->data_page_version() == ParquetDataPageVersion::V1;
+
+ if (descr_->max_definition_level() > 0) {
+ definition_levels_rle_size = RleEncodeLevels(
+ definition_levels_sink_.data(), definition_levels_rle_.get(),
+ descr_->max_definition_level(), /*include_length_prefix=*/is_v1_data_page);
+ }
+
+ if (descr_->max_repetition_level() > 0) {
+ repetition_levels_rle_size = RleEncodeLevels(
+ repetition_levels_sink_.data(), repetition_levels_rle_.get(),
+ descr_->max_repetition_level(), /*include_length_prefix=*/is_v1_data_page);
+ }
+
+ int64_t uncompressed_size =
+ definition_levels_rle_size + repetition_levels_rle_size + values->size();
+
+ if (is_v1_data_page) {
+ BuildDataPageV1(definition_levels_rle_size, repetition_levels_rle_size,
+ uncompressed_size, values);
+ } else {
+ BuildDataPageV2(definition_levels_rle_size, repetition_levels_rle_size,
+ uncompressed_size, values);
+ }
+
+ // Re-initialize the sinks for next Page.
+ InitSinks();
+ num_buffered_values_ = 0;
+ num_buffered_encoded_values_ = 0;
+}
+
+void ColumnWriterImpl::BuildDataPageV1(int64_t definition_levels_rle_size,
+ int64_t repetition_levels_rle_size,
+ int64_t uncompressed_size,
+ const std::shared_ptr<Buffer>& values) {
+ // Use Arrow::Buffer::shrink_to_fit = false
+ // underlying buffer only keeps growing. Resize to a smaller size does not reallocate.
+ PARQUET_THROW_NOT_OK(uncompressed_data_->Resize(uncompressed_size, false));
+ ConcatenateBuffers(definition_levels_rle_size, repetition_levels_rle_size, values,
+ uncompressed_data_->mutable_data());
+
+ EncodedStatistics page_stats = GetPageStatistics();
+ page_stats.ApplyStatSizeLimits(properties_->max_statistics_size(descr_->path()));
+ page_stats.set_is_signed(SortOrder::SIGNED == descr_->sort_order());
+ ResetPageStatistics();
+
+ std::shared_ptr<Buffer> compressed_data;
+ if (pager_->has_compressor()) {
+ pager_->Compress(*(uncompressed_data_.get()), compressor_temp_buffer_.get());
+ compressed_data = compressor_temp_buffer_;
+ } else {
+ compressed_data = uncompressed_data_;
+ }
+
+ // Write the page to OutputStream eagerly if there is no dictionary or
+ // if dictionary encoding has fallen back to PLAIN
+ if (has_dictionary_ && !fallback_) { // Save pages until end of dictionary encoding
+ PARQUET_ASSIGN_OR_THROW(
+ auto compressed_data_copy,
+ compressed_data->CopySlice(0, compressed_data->size(), allocator_));
+ std::unique_ptr<DataPage> page_ptr(new DataPageV1(
+ compressed_data_copy, static_cast<int32_t>(num_buffered_values_), encoding_,
+ Encoding::RLE, Encoding::RLE, uncompressed_size, page_stats));
+ total_compressed_bytes_ += page_ptr->size() + sizeof(format::PageHeader);
+
+ data_pages_.push_back(std::move(page_ptr));
+ } else { // Eagerly write pages
+ DataPageV1 page(compressed_data, static_cast<int32_t>(num_buffered_values_),
+ encoding_, Encoding::RLE, Encoding::RLE, uncompressed_size,
+ page_stats);
+ WriteDataPage(page);
+ }
+}
+
+void ColumnWriterImpl::BuildDataPageV2(int64_t definition_levels_rle_size,
+ int64_t repetition_levels_rle_size,
+ int64_t uncompressed_size,
+ const std::shared_ptr<Buffer>& values) {
+ // Compress the values if needed. Repetition and definition levels are uncompressed in
+ // V2.
+ std::shared_ptr<Buffer> compressed_values;
+ if (pager_->has_compressor()) {
+ pager_->Compress(*values, compressor_temp_buffer_.get());
+ compressed_values = compressor_temp_buffer_;
+ } else {
+ compressed_values = values;
+ }
+
+ // Concatenate uncompressed levels and the possibly compressed values
+ int64_t combined_size =
+ definition_levels_rle_size + repetition_levels_rle_size + compressed_values->size();
+ std::shared_ptr<ResizableBuffer> combined = AllocateBuffer(allocator_, combined_size);
+
+ ConcatenateBuffers(definition_levels_rle_size, repetition_levels_rle_size,
+ compressed_values, combined->mutable_data());
+
+ EncodedStatistics page_stats = GetPageStatistics();
+ page_stats.ApplyStatSizeLimits(properties_->max_statistics_size(descr_->path()));
+ page_stats.set_is_signed(SortOrder::SIGNED == descr_->sort_order());
+ ResetPageStatistics();
+
+ int32_t num_values = static_cast<int32_t>(num_buffered_values_);
+ int32_t null_count = static_cast<int32_t>(page_stats.null_count);
+ int32_t def_levels_byte_length = static_cast<int32_t>(definition_levels_rle_size);
+ int32_t rep_levels_byte_length = static_cast<int32_t>(repetition_levels_rle_size);
+
+ // Write the page to OutputStream eagerly if there is no dictionary or
+ // if dictionary encoding has fallen back to PLAIN
+ if (has_dictionary_ && !fallback_) { // Save pages until end of dictionary encoding
+ PARQUET_ASSIGN_OR_THROW(auto data_copy,
+ combined->CopySlice(0, combined->size(), allocator_));
+ std::unique_ptr<DataPage> page_ptr(new DataPageV2(
+ combined, num_values, null_count, num_values, encoding_, def_levels_byte_length,
+ rep_levels_byte_length, uncompressed_size, pager_->has_compressor(), page_stats));
+ total_compressed_bytes_ += page_ptr->size() + sizeof(format::PageHeader);
+ data_pages_.push_back(std::move(page_ptr));
+ } else {
+ DataPageV2 page(combined, num_values, null_count, num_values, encoding_,
+ def_levels_byte_length, rep_levels_byte_length, uncompressed_size,
+ pager_->has_compressor(), page_stats);
+ WriteDataPage(page);
+ }
+}
+
+int64_t ColumnWriterImpl::Close() {
+ if (!closed_) {
+ closed_ = true;
+ if (has_dictionary_ && !fallback_) {
+ WriteDictionaryPage();
+ }
+
+ FlushBufferedDataPages();
+
+ EncodedStatistics chunk_statistics = GetChunkStatistics();
+ chunk_statistics.ApplyStatSizeLimits(
+ properties_->max_statistics_size(descr_->path()));
+ chunk_statistics.set_is_signed(SortOrder::SIGNED == descr_->sort_order());
+
+ // Write stats only if the column has at least one row written
+ if (rows_written_ > 0 && chunk_statistics.is_set()) {
+ metadata_->SetStatistics(chunk_statistics);
+ }
+ pager_->Close(has_dictionary_, fallback_);
+ }
+
+ return total_bytes_written_;
+}
+
+void ColumnWriterImpl::FlushBufferedDataPages() {
+ // Write all outstanding data to a new page
+ if (num_buffered_values_ > 0) {
+ AddDataPage();
+ }
+ for (const auto& page_ptr : data_pages_) {
+ WriteDataPage(*page_ptr);
+ }
+ data_pages_.clear();
+ total_compressed_bytes_ = 0;
+}
+
+// ----------------------------------------------------------------------
+// TypedColumnWriter
+
+template <typename Action>
+inline void DoInBatches(int64_t total, int64_t batch_size, Action&& action) {
+ int64_t num_batches = static_cast<int>(total / batch_size);
+ for (int round = 0; round < num_batches; round++) {
+ action(round * batch_size, batch_size);
+ }
+ // Write the remaining values
+ if (total % batch_size > 0) {
+ action(num_batches * batch_size, total % batch_size);
+ }
+}
+
+bool DictionaryDirectWriteSupported(const ::arrow::Array& array) {
+ DCHECK_EQ(array.type_id(), ::arrow::Type::DICTIONARY);
+ const ::arrow::DictionaryType& dict_type =
+ static_cast<const ::arrow::DictionaryType&>(*array.type());
+ return ::arrow::is_base_binary_like(dict_type.value_type()->id());
+}
+
+Status ConvertDictionaryToDense(const ::arrow::Array& array, MemoryPool* pool,
+ std::shared_ptr<::arrow::Array>* out) {
+ const ::arrow::DictionaryType& dict_type =
+ static_cast<const ::arrow::DictionaryType&>(*array.type());
+
+ ::arrow::compute::ExecContext ctx(pool);
+ ARROW_ASSIGN_OR_RAISE(Datum cast_output,
+ ::arrow::compute::Cast(array.data(), dict_type.value_type(),
+ ::arrow::compute::CastOptions(), &ctx));
+ *out = cast_output.make_array();
+ return Status::OK();
+}
+
+static inline bool IsDictionaryEncoding(Encoding::type encoding) {
+ return encoding == Encoding::PLAIN_DICTIONARY;
+}
+
+template <typename DType>
+class TypedColumnWriterImpl : public ColumnWriterImpl, public TypedColumnWriter<DType> {
+ public:
+ using T = typename DType::c_type;
+
+ TypedColumnWriterImpl(ColumnChunkMetaDataBuilder* metadata,
+ std::unique_ptr<PageWriter> pager, const bool use_dictionary,
+ Encoding::type encoding, const WriterProperties* properties)
+ : ColumnWriterImpl(metadata, std::move(pager), use_dictionary, encoding,
+ properties) {
+ current_encoder_ = MakeEncoder(DType::type_num, encoding, use_dictionary, descr_,
+ properties->memory_pool());
+ // We have to dynamic_cast as some compilers don't want to static_cast
+ // through virtual inheritance.
+ current_value_encoder_ = dynamic_cast<TypedEncoder<DType>*>(current_encoder_.get());
+ // Will be null if not using dictionary, but that's ok
+ current_dict_encoder_ = dynamic_cast<DictEncoder<DType>*>(current_encoder_.get());
+
+ if (properties->statistics_enabled(descr_->path()) &&
+ (SortOrder::UNKNOWN != descr_->sort_order())) {
+ page_statistics_ = MakeStatistics<DType>(descr_, allocator_);
+ chunk_statistics_ = MakeStatistics<DType>(descr_, allocator_);
+ }
+ }
+
+ int64_t Close() override { return ColumnWriterImpl::Close(); }
+
+ int64_t WriteBatch(int64_t num_values, const int16_t* def_levels,
+ const int16_t* rep_levels, const T* values) override {
+ // We check for DataPage limits only after we have inserted the values. If a user
+ // writes a large number of values, the DataPage size can be much above the limit.
+ // The purpose of this chunking is to bound this. Even if a user writes large number
+ // of values, the chunking will ensure the AddDataPage() is called at a reasonable
+ // pagesize limit
+ int64_t value_offset = 0;
+
+ auto WriteChunk = [&](int64_t offset, int64_t batch_size) {
+ int64_t values_to_write = WriteLevels(batch_size, AddIfNotNull(def_levels, offset),
+ AddIfNotNull(rep_levels, offset));
+
+ // PARQUET-780
+ if (values_to_write > 0) {
+ DCHECK_NE(nullptr, values);
+ }
+ WriteValues(AddIfNotNull(values, value_offset), values_to_write,
+ batch_size - values_to_write);
+ CommitWriteAndCheckPageLimit(batch_size, values_to_write);
+ value_offset += values_to_write;
+
+ // Dictionary size checked separately from data page size since we
+ // circumvent this check when writing ::arrow::DictionaryArray directly
+ CheckDictionarySizeLimit();
+ };
+ DoInBatches(num_values, properties_->write_batch_size(), WriteChunk);
+ return value_offset;
+ }
+
+ void WriteBatchSpaced(int64_t num_values, const int16_t* def_levels,
+ const int16_t* rep_levels, const uint8_t* valid_bits,
+ int64_t valid_bits_offset, const T* values) override {
+ // Like WriteBatch, but for spaced values
+ int64_t value_offset = 0;
+ auto WriteChunk = [&](int64_t offset, int64_t batch_size) {
+ int64_t batch_num_values = 0;
+ int64_t batch_num_spaced_values = 0;
+ int64_t null_count;
+ MaybeCalculateValidityBits(AddIfNotNull(def_levels, offset), batch_size,
+ &batch_num_values, &batch_num_spaced_values,
+ &null_count);
+
+ WriteLevelsSpaced(batch_size, AddIfNotNull(def_levels, offset),
+ AddIfNotNull(rep_levels, offset));
+ if (bits_buffer_ != nullptr) {
+ WriteValuesSpaced(AddIfNotNull(values, value_offset), batch_num_values,
+ batch_num_spaced_values, bits_buffer_->data(), /*offset=*/0,
+ /*num_levels=*/batch_size);
+ } else {
+ WriteValuesSpaced(AddIfNotNull(values, value_offset), batch_num_values,
+ batch_num_spaced_values, valid_bits,
+ valid_bits_offset + value_offset, /*num_levels=*/batch_size);
+ }
+ CommitWriteAndCheckPageLimit(batch_size, batch_num_spaced_values);
+ value_offset += batch_num_spaced_values;
+
+ // Dictionary size checked separately from data page size since we
+ // circumvent this check when writing ::arrow::DictionaryArray directly
+ CheckDictionarySizeLimit();
+ };
+ DoInBatches(num_values, properties_->write_batch_size(), WriteChunk);
+ }
+
+ Status WriteArrow(const int16_t* def_levels, const int16_t* rep_levels,
+ int64_t num_levels, const ::arrow::Array& leaf_array,
+ ArrowWriteContext* ctx, bool leaf_field_nullable) override {
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ // Leaf nulls are canonical when there is only a single null element after a list
+ // and it is at the leaf.
+ bool single_nullable_element =
+ (level_info_.def_level == level_info_.repeated_ancestor_def_level + 1) &&
+ leaf_field_nullable;
+ bool maybe_parent_nulls = level_info_.HasNullableValues() && !single_nullable_element;
+ if (maybe_parent_nulls) {
+ ARROW_ASSIGN_OR_RAISE(
+ bits_buffer_,
+ ::arrow::AllocateResizableBuffer(
+ BitUtil::BytesForBits(properties_->write_batch_size()), ctx->memory_pool));
+ bits_buffer_->ZeroPadding();
+ }
+
+ if (leaf_array.type()->id() == ::arrow::Type::DICTIONARY) {
+ return WriteArrowDictionary(def_levels, rep_levels, num_levels, leaf_array, ctx,
+ maybe_parent_nulls);
+ } else {
+ return WriteArrowDense(def_levels, rep_levels, num_levels, leaf_array, ctx,
+ maybe_parent_nulls);
+ }
+ END_PARQUET_CATCH_EXCEPTIONS
+ }
+
+ int64_t EstimatedBufferedValueBytes() const override {
+ return current_encoder_->EstimatedDataEncodedSize();
+ }
+
+ protected:
+ std::shared_ptr<Buffer> GetValuesBuffer() override {
+ return current_encoder_->FlushValues();
+ }
+
+ // Internal function to handle direct writing of ::arrow::DictionaryArray,
+ // since the standard logic concerning dictionary size limits and fallback to
+ // plain encoding is circumvented
+ Status WriteArrowDictionary(const int16_t* def_levels, const int16_t* rep_levels,
+ int64_t num_levels, const ::arrow::Array& array,
+ ArrowWriteContext* context, bool maybe_parent_nulls);
+
+ Status WriteArrowDense(const int16_t* def_levels, const int16_t* rep_levels,
+ int64_t num_levels, const ::arrow::Array& array,
+ ArrowWriteContext* context, bool maybe_parent_nulls);
+
+ void WriteDictionaryPage() override {
+ DCHECK(current_dict_encoder_);
+ std::shared_ptr<ResizableBuffer> buffer = AllocateBuffer(
+ properties_->memory_pool(), current_dict_encoder_->dict_encoded_size());
+ current_dict_encoder_->WriteDict(buffer->mutable_data());
+
+ DictionaryPage page(buffer, current_dict_encoder_->num_entries(),
+ properties_->dictionary_page_encoding());
+ total_bytes_written_ += pager_->WriteDictionaryPage(page);
+ }
+
+ EncodedStatistics GetPageStatistics() override {
+ EncodedStatistics result;
+ if (page_statistics_) result = page_statistics_->Encode();
+ return result;
+ }
+
+ EncodedStatistics GetChunkStatistics() override {
+ EncodedStatistics result;
+ if (chunk_statistics_) result = chunk_statistics_->Encode();
+ return result;
+ }
+
+ void ResetPageStatistics() override {
+ if (chunk_statistics_ != nullptr) {
+ chunk_statistics_->Merge(*page_statistics_);
+ page_statistics_->Reset();
+ }
+ }
+
+ Type::type type() const override { return descr_->physical_type(); }
+
+ const ColumnDescriptor* descr() const override { return descr_; }
+
+ int64_t rows_written() const override { return rows_written_; }
+
+ int64_t total_compressed_bytes() const override { return total_compressed_bytes_; }
+
+ int64_t total_bytes_written() const override { return total_bytes_written_; }
+
+ const WriterProperties* properties() override { return properties_; }
+
+ private:
+ using ValueEncoderType = typename EncodingTraits<DType>::Encoder;
+ using TypedStats = TypedStatistics<DType>;
+ std::unique_ptr<Encoder> current_encoder_;
+ // Downcasted observers of current_encoder_.
+ // The downcast is performed once as opposed to at every use since
+ // dynamic_cast is so expensive, and static_cast is not available due
+ // to virtual inheritance.
+ ValueEncoderType* current_value_encoder_;
+ DictEncoder<DType>* current_dict_encoder_;
+ std::shared_ptr<TypedStats> page_statistics_;
+ std::shared_ptr<TypedStats> chunk_statistics_;
+
+ // If writing a sequence of ::arrow::DictionaryArray to the writer, we keep the
+ // dictionary passed to DictEncoder<T>::PutDictionary so we can check
+ // subsequent array chunks to see either if materialization is required (in
+ // which case we call back to the dense write path)
+ std::shared_ptr<::arrow::Array> preserved_dictionary_;
+
+ int64_t WriteLevels(int64_t num_values, const int16_t* def_levels,
+ const int16_t* rep_levels) {
+ int64_t values_to_write = 0;
+ // If the field is required and non-repeated, there are no definition levels
+ if (descr_->max_definition_level() > 0) {
+ for (int64_t i = 0; i < num_values; ++i) {
+ if (def_levels[i] == descr_->max_definition_level()) {
+ ++values_to_write;
+ }
+ }
+
+ WriteDefinitionLevels(num_values, def_levels);
+ } else {
+ // Required field, write all values
+ values_to_write = num_values;
+ }
+
+ // Not present for non-repeated fields
+ if (descr_->max_repetition_level() > 0) {
+ // A row could include more than one value
+ // Count the occasions where we start a new row
+ for (int64_t i = 0; i < num_values; ++i) {
+ if (rep_levels[i] == 0) {
+ rows_written_++;
+ }
+ }
+
+ WriteRepetitionLevels(num_values, rep_levels);
+ } else {
+ // Each value is exactly one row
+ rows_written_ += num_values;
+ }
+ return values_to_write;
+ }
+
+ // This method will always update the three output parameters,
+ // out_values_to_write, out_spaced_values_to_write and null_count. Additionally
+ // it will update the validity bitmap if required (i.e. if at least one level
+ // of nullable structs directly precede the leaf node).
+ void MaybeCalculateValidityBits(const int16_t* def_levels, int64_t batch_size,
+ int64_t* out_values_to_write,
+ int64_t* out_spaced_values_to_write,
+ int64_t* null_count) {
+ if (bits_buffer_ == nullptr) {
+ if (level_info_.def_level == 0) {
+ // In this case def levels should be null and we only
+ // need to output counts which will always be equal to
+ // the batch size passed in (max def_level == 0 indicates
+ // there cannot be repeated or null fields).
+ DCHECK_EQ(def_levels, nullptr);
+ *out_values_to_write = batch_size;
+ *out_spaced_values_to_write = batch_size;
+ *null_count = 0;
+ } else {
+ for (int x = 0; x < batch_size; x++) {
+ *out_values_to_write += def_levels[x] == level_info_.def_level ? 1 : 0;
+ *out_spaced_values_to_write +=
+ def_levels[x] >= level_info_.repeated_ancestor_def_level ? 1 : 0;
+ }
+ *null_count = *out_values_to_write - *out_spaced_values_to_write;
+ }
+ return;
+ }
+ // Shrink to fit possible causes another allocation, and would only be necessary
+ // on the last batch.
+ int64_t new_bitmap_size = BitUtil::BytesForBits(batch_size);
+ if (new_bitmap_size != bits_buffer_->size()) {
+ PARQUET_THROW_NOT_OK(
+ bits_buffer_->Resize(new_bitmap_size, /*shrink_to_fit=*/false));
+ bits_buffer_->ZeroPadding();
+ }
+ internal::ValidityBitmapInputOutput io;
+ io.valid_bits = bits_buffer_->mutable_data();
+ io.values_read_upper_bound = batch_size;
+ internal::DefLevelsToBitmap(def_levels, batch_size, level_info_, &io);
+ *out_values_to_write = io.values_read - io.null_count;
+ *out_spaced_values_to_write = io.values_read;
+ *null_count = io.null_count;
+ }
+
+ Result<std::shared_ptr<Array>> MaybeReplaceValidity(std::shared_ptr<Array> array,
+ int64_t new_null_count,
+ ::arrow::MemoryPool* memory_pool) {
+ if (bits_buffer_ == nullptr) {
+ return array;
+ }
+ std::vector<std::shared_ptr<Buffer>> buffers = array->data()->buffers;
+ if (buffers.empty()) {
+ return array;
+ }
+ buffers[0] = bits_buffer_;
+ // Should be a leaf array.
+ DCHECK_GT(buffers.size(), 1);
+ ValueBufferSlicer slicer{memory_pool, /*buffer=*/nullptr};
+ if (array->data()->offset > 0) {
+ RETURN_NOT_OK(::arrow::VisitArrayInline(*array, &slicer));
+ buffers[1] = slicer.buffer_;
+ }
+ return ::arrow::MakeArray(std::make_shared<ArrayData>(
+ array->type(), array->length(), std::move(buffers), new_null_count));
+ }
+
+ void WriteLevelsSpaced(int64_t num_levels, const int16_t* def_levels,
+ const int16_t* rep_levels) {
+ // If the field is required and non-repeated, there are no definition levels
+ if (descr_->max_definition_level() > 0) {
+ WriteDefinitionLevels(num_levels, def_levels);
+ }
+ // Not present for non-repeated fields
+ if (descr_->max_repetition_level() > 0) {
+ // A row could include more than one value
+ // Count the occasions where we start a new row
+ for (int64_t i = 0; i < num_levels; ++i) {
+ if (rep_levels[i] == 0) {
+ rows_written_++;
+ }
+ }
+ WriteRepetitionLevels(num_levels, rep_levels);
+ } else {
+ // Each value is exactly one row
+ rows_written_ += num_levels;
+ }
+ }
+
+ void CommitWriteAndCheckPageLimit(int64_t num_levels, int64_t num_values) {
+ num_buffered_values_ += num_levels;
+ num_buffered_encoded_values_ += num_values;
+
+ if (current_encoder_->EstimatedDataEncodedSize() >= properties_->data_pagesize()) {
+ AddDataPage();
+ }
+ }
+
+ void FallbackToPlainEncoding() {
+ if (IsDictionaryEncoding(current_encoder_->encoding())) {
+ WriteDictionaryPage();
+ // Serialize the buffered Dictionary Indices
+ FlushBufferedDataPages();
+ fallback_ = true;
+ // Only PLAIN encoding is supported for fallback in V1
+ current_encoder_ = MakeEncoder(DType::type_num, Encoding::PLAIN, false, descr_,
+ properties_->memory_pool());
+ current_value_encoder_ = dynamic_cast<ValueEncoderType*>(current_encoder_.get());
+ current_dict_encoder_ = nullptr; // not using dict
+ encoding_ = Encoding::PLAIN;
+ }
+ }
+
+ // Checks if the Dictionary Page size limit is reached
+ // If the limit is reached, the Dictionary and Data Pages are serialized
+ // The encoding is switched to PLAIN
+ //
+ // Only one Dictionary Page is written.
+ // Fallback to PLAIN if dictionary page limit is reached.
+ void CheckDictionarySizeLimit() {
+ if (!has_dictionary_ || fallback_) {
+ // Either not using dictionary encoding, or we have already fallen back
+ // to PLAIN encoding because the size threshold was reached
+ return;
+ }
+
+ if (current_dict_encoder_->dict_encoded_size() >=
+ properties_->dictionary_pagesize_limit()) {
+ FallbackToPlainEncoding();
+ }
+ }
+
+ void WriteValues(const T* values, int64_t num_values, int64_t num_nulls) {
+ current_value_encoder_->Put(values, static_cast<int>(num_values));
+ if (page_statistics_ != nullptr) {
+ page_statistics_->Update(values, num_values, num_nulls);
+ }
+ }
+
+ void WriteValuesSpaced(const T* values, int64_t num_values, int64_t num_spaced_values,
+ const uint8_t* valid_bits, int64_t valid_bits_offset,
+ int64_t num_levels) {
+ if (num_values != num_spaced_values) {
+ current_value_encoder_->PutSpaced(values, static_cast<int>(num_spaced_values),
+ valid_bits, valid_bits_offset);
+ } else {
+ current_value_encoder_->Put(values, static_cast<int>(num_values));
+ }
+ if (page_statistics_ != nullptr) {
+ const int64_t num_nulls = num_levels - num_values;
+ page_statistics_->UpdateSpaced(values, valid_bits, valid_bits_offset,
+ num_spaced_values, num_values, num_nulls);
+ }
+ }
+};
+
+template <typename DType>
+Status TypedColumnWriterImpl<DType>::WriteArrowDictionary(
+ const int16_t* def_levels, const int16_t* rep_levels, int64_t num_levels,
+ const ::arrow::Array& array, ArrowWriteContext* ctx, bool maybe_parent_nulls) {
+ // If this is the first time writing a DictionaryArray, then there's
+ // a few possible paths to take:
+ //
+ // - If dictionary encoding is not enabled, convert to densely
+ // encoded and call WriteArrow
+ // - Dictionary encoding enabled
+ // - If this is the first time this is called, then we call
+ // PutDictionary into the encoder and then PutIndices on each
+ // chunk. We store the dictionary that was written in
+ // preserved_dictionary_ so that subsequent calls to this method
+ // can make sure the dictionary has not changed
+ // - On subsequent calls, we have to check whether the dictionary
+ // has changed. If it has, then we trigger the varying
+ // dictionary path and materialize each chunk and then call
+ // WriteArrow with that
+ auto WriteDense = [&] {
+ std::shared_ptr<::arrow::Array> dense_array;
+ RETURN_NOT_OK(
+ ConvertDictionaryToDense(array, properties_->memory_pool(), &dense_array));
+ return WriteArrowDense(def_levels, rep_levels, num_levels, *dense_array, ctx,
+ maybe_parent_nulls);
+ };
+
+ if (!IsDictionaryEncoding(current_encoder_->encoding()) ||
+ !DictionaryDirectWriteSupported(array)) {
+ // No longer dictionary-encoding for whatever reason, maybe we never were
+ // or we decided to stop. Note that WriteArrow can be invoked multiple
+ // times with both dense and dictionary-encoded versions of the same data
+ // without a problem. Any dense data will be hashed to indices until the
+ // dictionary page limit is reached, at which everything (dictionary and
+ // dense) will fall back to plain encoding
+ return WriteDense();
+ }
+
+ auto dict_encoder = dynamic_cast<DictEncoder<DType>*>(current_encoder_.get());
+ const auto& data = checked_cast<const ::arrow::DictionaryArray&>(array);
+ std::shared_ptr<::arrow::Array> dictionary = data.dictionary();
+ std::shared_ptr<::arrow::Array> indices = data.indices();
+
+ int64_t value_offset = 0;
+ auto WriteIndicesChunk = [&](int64_t offset, int64_t batch_size) {
+ int64_t batch_num_values = 0;
+ int64_t batch_num_spaced_values = 0;
+ int64_t null_count = ::arrow::kUnknownNullCount;
+ // Bits is not null for nullable values. At this point in the code we can't determine
+ // if the leaf array has the same null values as any parents it might have had so we
+ // need to recompute it from def levels.
+ MaybeCalculateValidityBits(AddIfNotNull(def_levels, offset), batch_size,
+ &batch_num_values, &batch_num_spaced_values, &null_count);
+ WriteLevelsSpaced(batch_size, AddIfNotNull(def_levels, offset),
+ AddIfNotNull(rep_levels, offset));
+ std::shared_ptr<Array> writeable_indices =
+ indices->Slice(value_offset, batch_num_spaced_values);
+ PARQUET_ASSIGN_OR_THROW(
+ writeable_indices,
+ MaybeReplaceValidity(writeable_indices, null_count, ctx->memory_pool));
+ dict_encoder->PutIndices(*writeable_indices);
+ CommitWriteAndCheckPageLimit(batch_size, batch_num_values);
+ value_offset += batch_num_spaced_values;
+ };
+
+ // Handle seeing dictionary for the first time
+ if (!preserved_dictionary_) {
+ // It's a new dictionary. Call PutDictionary and keep track of it
+ PARQUET_CATCH_NOT_OK(dict_encoder->PutDictionary(*dictionary));
+
+ // If there were duplicate value in the dictionary, the encoder's memo table
+ // will be out of sync with the indices in the Arrow array.
+ // The easiest solution for this uncommon case is to fallback to plain encoding.
+ if (dict_encoder->num_entries() != dictionary->length()) {
+ PARQUET_CATCH_NOT_OK(FallbackToPlainEncoding());
+ return WriteDense();
+ }
+
+ if (page_statistics_ != nullptr) {
+ // TODO(PARQUET-2068) This approach may make two copies. First, a copy of the
+ // indices array to a (hopefully smaller) referenced indices array. Second, a copy
+ // of the values array to a (probably not smaller) referenced values array.
+ //
+ // Once the MinMax kernel supports all data types we should use that kernel instead
+ // as it does not make any copies.
+ ::arrow::compute::ExecContext exec_ctx(ctx->memory_pool);
+ exec_ctx.set_use_threads(false);
+ PARQUET_ASSIGN_OR_THROW(::arrow::Datum referenced_indices,
+ ::arrow::compute::Unique(*indices, &exec_ctx));
+ std::shared_ptr<::arrow::Array> referenced_dictionary;
+ if (referenced_indices.length() == dictionary->length()) {
+ referenced_dictionary = dictionary;
+ } else {
+ PARQUET_ASSIGN_OR_THROW(
+ ::arrow::Datum referenced_dictionary_datum,
+ ::arrow::compute::Take(dictionary, referenced_indices,
+ ::arrow::compute::TakeOptions(/*boundscheck=*/false),
+ &exec_ctx));
+ referenced_dictionary = referenced_dictionary_datum.make_array();
+ }
+ int64_t non_null_count = indices->length() - indices->null_count();
+ page_statistics_->IncrementNullCount(num_levels - non_null_count);
+ page_statistics_->IncrementNumValues(non_null_count);
+ page_statistics_->Update(*referenced_dictionary, /*update_counts=*/false);
+ }
+ preserved_dictionary_ = dictionary;
+ } else if (!dictionary->Equals(*preserved_dictionary_)) {
+ // Dictionary has changed
+ PARQUET_CATCH_NOT_OK(FallbackToPlainEncoding());
+ return WriteDense();
+ }
+
+ PARQUET_CATCH_NOT_OK(
+ DoInBatches(num_levels, properties_->write_batch_size(), WriteIndicesChunk));
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Direct Arrow write path
+
+template <typename ParquetType, typename ArrowType, typename Enable = void>
+struct SerializeFunctor {
+ using ArrowCType = typename ArrowType::c_type;
+ using ArrayType = typename ::arrow::TypeTraits<ArrowType>::ArrayType;
+ using ParquetCType = typename ParquetType::c_type;
+ Status Serialize(const ArrayType& array, ArrowWriteContext*, ParquetCType* out) {
+ const ArrowCType* input = array.raw_values();
+ if (array.null_count() > 0) {
+ for (int i = 0; i < array.length(); i++) {
+ out[i] = static_cast<ParquetCType>(input[i]);
+ }
+ } else {
+ std::copy(input, input + array.length(), out);
+ }
+ return Status::OK();
+ }
+};
+
+template <typename ParquetType, typename ArrowType>
+Status WriteArrowSerialize(const ::arrow::Array& array, int64_t num_levels,
+ const int16_t* def_levels, const int16_t* rep_levels,
+ ArrowWriteContext* ctx, TypedColumnWriter<ParquetType>* writer,
+ bool maybe_parent_nulls) {
+ using ParquetCType = typename ParquetType::c_type;
+ using ArrayType = typename ::arrow::TypeTraits<ArrowType>::ArrayType;
+
+ ParquetCType* buffer = nullptr;
+ PARQUET_THROW_NOT_OK(ctx->GetScratchData<ParquetCType>(array.length(), &buffer));
+
+ SerializeFunctor<ParquetType, ArrowType> functor;
+ RETURN_NOT_OK(functor.Serialize(checked_cast<const ArrayType&>(array), ctx, buffer));
+ bool no_nulls =
+ writer->descr()->schema_node()->is_required() || (array.null_count() == 0);
+ if (!maybe_parent_nulls && no_nulls) {
+ PARQUET_CATCH_NOT_OK(writer->WriteBatch(num_levels, def_levels, rep_levels, buffer));
+ } else {
+ PARQUET_CATCH_NOT_OK(writer->WriteBatchSpaced(num_levels, def_levels, rep_levels,
+ array.null_bitmap_data(),
+ array.offset(), buffer));
+ }
+ return Status::OK();
+}
+
+template <typename ParquetType>
+Status WriteArrowZeroCopy(const ::arrow::Array& array, int64_t num_levels,
+ const int16_t* def_levels, const int16_t* rep_levels,
+ ArrowWriteContext* ctx, TypedColumnWriter<ParquetType>* writer,
+ bool maybe_parent_nulls) {
+ using T = typename ParquetType::c_type;
+ const auto& data = static_cast<const ::arrow::PrimitiveArray&>(array);
+ const T* values = nullptr;
+ // The values buffer may be null if the array is empty (ARROW-2744)
+ if (data.values() != nullptr) {
+ values = reinterpret_cast<const T*>(data.values()->data()) + data.offset();
+ } else {
+ DCHECK_EQ(data.length(), 0);
+ }
+ bool no_nulls =
+ writer->descr()->schema_node()->is_required() || (array.null_count() == 0);
+
+ if (!maybe_parent_nulls && no_nulls) {
+ PARQUET_CATCH_NOT_OK(writer->WriteBatch(num_levels, def_levels, rep_levels, values));
+ } else {
+ PARQUET_CATCH_NOT_OK(writer->WriteBatchSpaced(num_levels, def_levels, rep_levels,
+ data.null_bitmap_data(), data.offset(),
+ values));
+ }
+ return Status::OK();
+}
+
+#define WRITE_SERIALIZE_CASE(ArrowEnum, ArrowType, ParquetType) \
+ case ::arrow::Type::ArrowEnum: \
+ return WriteArrowSerialize<ParquetType, ::arrow::ArrowType>( \
+ array, num_levels, def_levels, rep_levels, ctx, this, maybe_parent_nulls);
+
+#define WRITE_ZERO_COPY_CASE(ArrowEnum, ArrowType, ParquetType) \
+ case ::arrow::Type::ArrowEnum: \
+ return WriteArrowZeroCopy<ParquetType>(array, num_levels, def_levels, rep_levels, \
+ ctx, this, maybe_parent_nulls);
+
+#define ARROW_UNSUPPORTED() \
+ std::stringstream ss; \
+ ss << "Arrow type " << array.type()->ToString() \
+ << " cannot be written to Parquet type " << descr_->ToString(); \
+ return Status::Invalid(ss.str());
+
+// ----------------------------------------------------------------------
+// Write Arrow to BooleanType
+
+template <>
+struct SerializeFunctor<BooleanType, ::arrow::BooleanType> {
+ Status Serialize(const ::arrow::BooleanArray& data, ArrowWriteContext*, bool* out) {
+ for (int i = 0; i < data.length(); i++) {
+ *out++ = data.Value(i);
+ }
+ return Status::OK();
+ }
+};
+
+template <>
+Status TypedColumnWriterImpl<BooleanType>::WriteArrowDense(
+ const int16_t* def_levels, const int16_t* rep_levels, int64_t num_levels,
+ const ::arrow::Array& array, ArrowWriteContext* ctx, bool maybe_parent_nulls) {
+ if (array.type_id() != ::arrow::Type::BOOL) {
+ ARROW_UNSUPPORTED();
+ }
+ return WriteArrowSerialize<BooleanType, ::arrow::BooleanType>(
+ array, num_levels, def_levels, rep_levels, ctx, this, maybe_parent_nulls);
+}
+
+// ----------------------------------------------------------------------
+// Write Arrow types to INT32
+
+template <>
+struct SerializeFunctor<Int32Type, ::arrow::Date64Type> {
+ Status Serialize(const ::arrow::Date64Array& array, ArrowWriteContext*, int32_t* out) {
+ const int64_t* input = array.raw_values();
+ for (int i = 0; i < array.length(); i++) {
+ *out++ = static_cast<int32_t>(*input++ / 86400000);
+ }
+ return Status::OK();
+ }
+};
+
+template <>
+struct SerializeFunctor<Int32Type, ::arrow::Time32Type> {
+ Status Serialize(const ::arrow::Time32Array& array, ArrowWriteContext*, int32_t* out) {
+ const int32_t* input = array.raw_values();
+ const auto& type = static_cast<const ::arrow::Time32Type&>(*array.type());
+ if (type.unit() == ::arrow::TimeUnit::SECOND) {
+ for (int i = 0; i < array.length(); i++) {
+ out[i] = input[i] * 1000;
+ }
+ } else {
+ std::copy(input, input + array.length(), out);
+ }
+ return Status::OK();
+ }
+};
+
+template <>
+Status TypedColumnWriterImpl<Int32Type>::WriteArrowDense(
+ const int16_t* def_levels, const int16_t* rep_levels, int64_t num_levels,
+ const ::arrow::Array& array, ArrowWriteContext* ctx, bool maybe_parent_nulls) {
+ switch (array.type()->id()) {
+ case ::arrow::Type::NA: {
+ PARQUET_CATCH_NOT_OK(WriteBatch(num_levels, def_levels, rep_levels, nullptr));
+ } break;
+ WRITE_SERIALIZE_CASE(INT8, Int8Type, Int32Type)
+ WRITE_SERIALIZE_CASE(UINT8, UInt8Type, Int32Type)
+ WRITE_SERIALIZE_CASE(INT16, Int16Type, Int32Type)
+ WRITE_SERIALIZE_CASE(UINT16, UInt16Type, Int32Type)
+ WRITE_SERIALIZE_CASE(UINT32, UInt32Type, Int32Type)
+ WRITE_ZERO_COPY_CASE(INT32, Int32Type, Int32Type)
+ WRITE_ZERO_COPY_CASE(DATE32, Date32Type, Int32Type)
+ WRITE_SERIALIZE_CASE(DATE64, Date64Type, Int32Type)
+ WRITE_SERIALIZE_CASE(TIME32, Time32Type, Int32Type)
+ default:
+ ARROW_UNSUPPORTED()
+ }
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Write Arrow to Int64 and Int96
+
+#define INT96_CONVERT_LOOP(ConversionFunction) \
+ for (int64_t i = 0; i < array.length(); i++) ConversionFunction(input[i], &out[i]);
+
+template <>
+struct SerializeFunctor<Int96Type, ::arrow::TimestampType> {
+ Status Serialize(const ::arrow::TimestampArray& array, ArrowWriteContext*, Int96* out) {
+ const int64_t* input = array.raw_values();
+ const auto& type = static_cast<const ::arrow::TimestampType&>(*array.type());
+ switch (type.unit()) {
+ case ::arrow::TimeUnit::NANO:
+ INT96_CONVERT_LOOP(internal::NanosecondsToImpalaTimestamp);
+ break;
+ case ::arrow::TimeUnit::MICRO:
+ INT96_CONVERT_LOOP(internal::MicrosecondsToImpalaTimestamp);
+ break;
+ case ::arrow::TimeUnit::MILLI:
+ INT96_CONVERT_LOOP(internal::MillisecondsToImpalaTimestamp);
+ break;
+ case ::arrow::TimeUnit::SECOND:
+ INT96_CONVERT_LOOP(internal::SecondsToImpalaTimestamp);
+ break;
+ }
+ return Status::OK();
+ }
+};
+
+#define COERCE_DIVIDE -1
+#define COERCE_INVALID 0
+#define COERCE_MULTIPLY +1
+
+static std::pair<int, int64_t> kTimestampCoercionFactors[4][4] = {
+ // from seconds ...
+ {{COERCE_INVALID, 0}, // ... to seconds
+ {COERCE_MULTIPLY, 1000}, // ... to millis
+ {COERCE_MULTIPLY, 1000000}, // ... to micros
+ {COERCE_MULTIPLY, INT64_C(1000000000)}}, // ... to nanos
+ // from millis ...
+ {{COERCE_INVALID, 0},
+ {COERCE_MULTIPLY, 1},
+ {COERCE_MULTIPLY, 1000},
+ {COERCE_MULTIPLY, 1000000}},
+ // from micros ...
+ {{COERCE_INVALID, 0},
+ {COERCE_DIVIDE, 1000},
+ {COERCE_MULTIPLY, 1},
+ {COERCE_MULTIPLY, 1000}},
+ // from nanos ...
+ {{COERCE_INVALID, 0},
+ {COERCE_DIVIDE, 1000000},
+ {COERCE_DIVIDE, 1000},
+ {COERCE_MULTIPLY, 1}}};
+
+template <>
+struct SerializeFunctor<Int64Type, ::arrow::TimestampType> {
+ Status Serialize(const ::arrow::TimestampArray& array, ArrowWriteContext* ctx,
+ int64_t* out) {
+ const auto& source_type = static_cast<const ::arrow::TimestampType&>(*array.type());
+ auto source_unit = source_type.unit();
+ const int64_t* values = array.raw_values();
+
+ ::arrow::TimeUnit::type target_unit = ctx->properties->coerce_timestamps_unit();
+ auto target_type = ::arrow::timestamp(target_unit);
+ bool truncation_allowed = ctx->properties->truncated_timestamps_allowed();
+
+ auto DivideBy = [&](const int64_t factor) {
+ for (int64_t i = 0; i < array.length(); i++) {
+ if (!truncation_allowed && array.IsValid(i) && (values[i] % factor != 0)) {
+ return Status::Invalid("Casting from ", source_type.ToString(), " to ",
+ target_type->ToString(),
+ " would lose data: ", values[i]);
+ }
+ out[i] = values[i] / factor;
+ }
+ return Status::OK();
+ };
+
+ auto MultiplyBy = [&](const int64_t factor) {
+ for (int64_t i = 0; i < array.length(); i++) {
+ out[i] = values[i] * factor;
+ }
+ return Status::OK();
+ };
+
+ const auto& coercion = kTimestampCoercionFactors[static_cast<int>(source_unit)]
+ [static_cast<int>(target_unit)];
+
+ // .first -> coercion operation; .second -> scale factor
+ DCHECK_NE(coercion.first, COERCE_INVALID);
+ return coercion.first == COERCE_DIVIDE ? DivideBy(coercion.second)
+ : MultiplyBy(coercion.second);
+ }
+};
+
+#undef COERCE_DIVIDE
+#undef COERCE_INVALID
+#undef COERCE_MULTIPLY
+
+Status WriteTimestamps(const ::arrow::Array& values, int64_t num_levels,
+ const int16_t* def_levels, const int16_t* rep_levels,
+ ArrowWriteContext* ctx, TypedColumnWriter<Int64Type>* writer,
+ bool maybe_parent_nulls) {
+ const auto& source_type = static_cast<const ::arrow::TimestampType&>(*values.type());
+
+ auto WriteCoerce = [&](const ArrowWriterProperties* properties) {
+ ArrowWriteContext temp_ctx = *ctx;
+ temp_ctx.properties = properties;
+ return WriteArrowSerialize<Int64Type, ::arrow::TimestampType>(
+ values, num_levels, def_levels, rep_levels, &temp_ctx, writer,
+ maybe_parent_nulls);
+ };
+
+ const ParquetVersion::type version = writer->properties()->version();
+
+ if (ctx->properties->coerce_timestamps_enabled()) {
+ // User explicitly requested coercion to specific unit
+ if (source_type.unit() == ctx->properties->coerce_timestamps_unit()) {
+ // No data conversion necessary
+ return WriteArrowZeroCopy<Int64Type>(values, num_levels, def_levels, rep_levels,
+ ctx, writer, maybe_parent_nulls);
+ } else {
+ return WriteCoerce(ctx->properties);
+ }
+ } else if ((version == ParquetVersion::PARQUET_1_0 ||
+ version == ParquetVersion::PARQUET_2_4) &&
+ source_type.unit() == ::arrow::TimeUnit::NANO) {
+ // Absent superseding user instructions, when writing Parquet version <= 2.4 files,
+ // timestamps in nanoseconds are coerced to microseconds
+ std::shared_ptr<ArrowWriterProperties> properties =
+ (ArrowWriterProperties::Builder())
+ .coerce_timestamps(::arrow::TimeUnit::MICRO)
+ ->disallow_truncated_timestamps()
+ ->build();
+ return WriteCoerce(properties.get());
+ } else if (source_type.unit() == ::arrow::TimeUnit::SECOND) {
+ // Absent superseding user instructions, timestamps in seconds are coerced to
+ // milliseconds
+ std::shared_ptr<ArrowWriterProperties> properties =
+ (ArrowWriterProperties::Builder())
+ .coerce_timestamps(::arrow::TimeUnit::MILLI)
+ ->build();
+ return WriteCoerce(properties.get());
+ } else {
+ // No data conversion necessary
+ return WriteArrowZeroCopy<Int64Type>(values, num_levels, def_levels, rep_levels, ctx,
+ writer, maybe_parent_nulls);
+ }
+}
+
+template <>
+Status TypedColumnWriterImpl<Int64Type>::WriteArrowDense(
+ const int16_t* def_levels, const int16_t* rep_levels, int64_t num_levels,
+ const ::arrow::Array& array, ArrowWriteContext* ctx, bool maybe_parent_nulls) {
+ switch (array.type()->id()) {
+ case ::arrow::Type::TIMESTAMP:
+ return WriteTimestamps(array, num_levels, def_levels, rep_levels, ctx, this,
+ maybe_parent_nulls);
+ WRITE_ZERO_COPY_CASE(INT64, Int64Type, Int64Type)
+ WRITE_SERIALIZE_CASE(UINT32, UInt32Type, Int64Type)
+ WRITE_SERIALIZE_CASE(UINT64, UInt64Type, Int64Type)
+ WRITE_ZERO_COPY_CASE(TIME64, Time64Type, Int64Type)
+ default:
+ ARROW_UNSUPPORTED();
+ }
+}
+
+template <>
+Status TypedColumnWriterImpl<Int96Type>::WriteArrowDense(
+ const int16_t* def_levels, const int16_t* rep_levels, int64_t num_levels,
+ const ::arrow::Array& array, ArrowWriteContext* ctx, bool maybe_parent_nulls) {
+ if (array.type_id() != ::arrow::Type::TIMESTAMP) {
+ ARROW_UNSUPPORTED();
+ }
+ return WriteArrowSerialize<Int96Type, ::arrow::TimestampType>(
+ array, num_levels, def_levels, rep_levels, ctx, this, maybe_parent_nulls);
+}
+
+// ----------------------------------------------------------------------
+// Floating point types
+
+template <>
+Status TypedColumnWriterImpl<FloatType>::WriteArrowDense(
+ const int16_t* def_levels, const int16_t* rep_levels, int64_t num_levels,
+ const ::arrow::Array& array, ArrowWriteContext* ctx, bool maybe_parent_nulls) {
+ if (array.type_id() != ::arrow::Type::FLOAT) {
+ ARROW_UNSUPPORTED();
+ }
+ return WriteArrowZeroCopy<FloatType>(array, num_levels, def_levels, rep_levels, ctx,
+ this, maybe_parent_nulls);
+}
+
+template <>
+Status TypedColumnWriterImpl<DoubleType>::WriteArrowDense(
+ const int16_t* def_levels, const int16_t* rep_levels, int64_t num_levels,
+ const ::arrow::Array& array, ArrowWriteContext* ctx, bool maybe_parent_nulls) {
+ if (array.type_id() != ::arrow::Type::DOUBLE) {
+ ARROW_UNSUPPORTED();
+ }
+ return WriteArrowZeroCopy<DoubleType>(array, num_levels, def_levels, rep_levels, ctx,
+ this, maybe_parent_nulls);
+}
+
+// ----------------------------------------------------------------------
+// Write Arrow to BYTE_ARRAY
+
+template <>
+Status TypedColumnWriterImpl<ByteArrayType>::WriteArrowDense(
+ const int16_t* def_levels, const int16_t* rep_levels, int64_t num_levels,
+ const ::arrow::Array& array, ArrowWriteContext* ctx, bool maybe_parent_nulls) {
+ if (!::arrow::is_base_binary_like(array.type()->id())) {
+ ARROW_UNSUPPORTED();
+ }
+
+ int64_t value_offset = 0;
+ auto WriteChunk = [&](int64_t offset, int64_t batch_size) {
+ int64_t batch_num_values = 0;
+ int64_t batch_num_spaced_values = 0;
+ int64_t null_count = 0;
+
+ MaybeCalculateValidityBits(AddIfNotNull(def_levels, offset), batch_size,
+ &batch_num_values, &batch_num_spaced_values, &null_count);
+ WriteLevelsSpaced(batch_size, AddIfNotNull(def_levels, offset),
+ AddIfNotNull(rep_levels, offset));
+ std::shared_ptr<Array> data_slice =
+ array.Slice(value_offset, batch_num_spaced_values);
+ PARQUET_ASSIGN_OR_THROW(
+ data_slice, MaybeReplaceValidity(data_slice, null_count, ctx->memory_pool));
+
+ current_encoder_->Put(*data_slice);
+ if (page_statistics_ != nullptr) {
+ page_statistics_->Update(*data_slice, /*update_counts=*/false);
+ // Null values in ancestors count as nulls.
+ int64_t non_null = data_slice->length() - data_slice->null_count();
+ page_statistics_->IncrementNullCount(batch_size - non_null);
+ page_statistics_->IncrementNumValues(non_null);
+ }
+ CommitWriteAndCheckPageLimit(batch_size, batch_num_values);
+ CheckDictionarySizeLimit();
+ value_offset += batch_num_spaced_values;
+ };
+
+ PARQUET_CATCH_NOT_OK(
+ DoInBatches(num_levels, properties_->write_batch_size(), WriteChunk));
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Write Arrow to FIXED_LEN_BYTE_ARRAY
+
+template <typename ParquetType, typename ArrowType>
+struct SerializeFunctor<
+ ParquetType, ArrowType,
+ ::arrow::enable_if_t<::arrow::is_fixed_size_binary_type<ArrowType>::value &&
+ !::arrow::is_decimal_type<ArrowType>::value>> {
+ Status Serialize(const ::arrow::FixedSizeBinaryArray& array, ArrowWriteContext*,
+ FLBA* out) {
+ if (array.null_count() == 0) {
+ // no nulls, just dump the data
+ // todo(advancedxy): use a writeBatch to avoid this step
+ for (int64_t i = 0; i < array.length(); i++) {
+ out[i] = FixedLenByteArray(array.GetValue(i));
+ }
+ } else {
+ for (int64_t i = 0; i < array.length(); i++) {
+ if (array.IsValid(i)) {
+ out[i] = FixedLenByteArray(array.GetValue(i));
+ }
+ }
+ }
+ return Status::OK();
+ }
+};
+
+// ----------------------------------------------------------------------
+// Write Arrow to Decimal128
+
+// Requires a custom serializer because decimal in parquet are in big-endian
+// format. Thus, a temporary local buffer is required.
+template <typename ParquetType, typename ArrowType>
+struct SerializeFunctor<ParquetType, ArrowType, ::arrow::enable_if_decimal<ArrowType>> {
+ Status Serialize(const typename ::arrow::TypeTraits<ArrowType>::ArrayType& array,
+ ArrowWriteContext* ctx, FLBA* out) {
+ AllocateScratch(array, ctx);
+ auto offset = Offset(array);
+
+ if (array.null_count() == 0) {
+ for (int64_t i = 0; i < array.length(); i++) {
+ out[i] = FixDecimalEndianess<ArrowType::kByteWidth>(array.GetValue(i), offset);
+ }
+ } else {
+ for (int64_t i = 0; i < array.length(); i++) {
+ out[i] = array.IsValid(i) ? FixDecimalEndianess<ArrowType::kByteWidth>(
+ array.GetValue(i), offset)
+ : FixedLenByteArray();
+ }
+ }
+
+ return Status::OK();
+ }
+
+ // Parquet's Decimal are stored with FixedLength values where the length is
+ // proportional to the precision. Arrow's Decimal are always stored with 16/32
+ // bytes. Thus the internal FLBA pointer must be adjusted by the offset calculated
+ // here.
+ int32_t Offset(const Array& array) {
+ auto decimal_type = checked_pointer_cast<::arrow::DecimalType>(array.type());
+ return decimal_type->byte_width() -
+ ::arrow::DecimalType::DecimalSize(decimal_type->precision());
+ }
+
+ void AllocateScratch(const typename ::arrow::TypeTraits<ArrowType>::ArrayType& array,
+ ArrowWriteContext* ctx) {
+ int64_t non_null_count = array.length() - array.null_count();
+ int64_t size = non_null_count * ArrowType::kByteWidth;
+ scratch_buffer = AllocateBuffer(ctx->memory_pool, size);
+ scratch = reinterpret_cast<int64_t*>(scratch_buffer->mutable_data());
+ }
+
+ template <int byte_width>
+ FixedLenByteArray FixDecimalEndianess(const uint8_t* in, int64_t offset) {
+ const auto* u64_in = reinterpret_cast<const int64_t*>(in);
+ auto out = reinterpret_cast<const uint8_t*>(scratch) + offset;
+ static_assert(byte_width == 16 || byte_width == 32,
+ "only 16 and 32 byte Decimals supported");
+ if (byte_width == 32) {
+ *scratch++ = ::arrow::BitUtil::ToBigEndian(u64_in[3]);
+ *scratch++ = ::arrow::BitUtil::ToBigEndian(u64_in[2]);
+ *scratch++ = ::arrow::BitUtil::ToBigEndian(u64_in[1]);
+ *scratch++ = ::arrow::BitUtil::ToBigEndian(u64_in[0]);
+ } else {
+ *scratch++ = ::arrow::BitUtil::ToBigEndian(u64_in[1]);
+ *scratch++ = ::arrow::BitUtil::ToBigEndian(u64_in[0]);
+ }
+ return FixedLenByteArray(out);
+ }
+
+ std::shared_ptr<ResizableBuffer> scratch_buffer;
+ int64_t* scratch;
+};
+
+template <>
+Status TypedColumnWriterImpl<FLBAType>::WriteArrowDense(
+ const int16_t* def_levels, const int16_t* rep_levels, int64_t num_levels,
+ const ::arrow::Array& array, ArrowWriteContext* ctx, bool maybe_parent_nulls) {
+ switch (array.type()->id()) {
+ WRITE_SERIALIZE_CASE(FIXED_SIZE_BINARY, FixedSizeBinaryType, FLBAType)
+ WRITE_SERIALIZE_CASE(DECIMAL128, Decimal128Type, FLBAType)
+ WRITE_SERIALIZE_CASE(DECIMAL256, Decimal256Type, FLBAType)
+ default:
+ break;
+ }
+ return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Dynamic column writer constructor
+
+std::shared_ptr<ColumnWriter> ColumnWriter::Make(ColumnChunkMetaDataBuilder* metadata,
+ std::unique_ptr<PageWriter> pager,
+ const WriterProperties* properties) {
+ const ColumnDescriptor* descr = metadata->descr();
+ const bool use_dictionary = properties->dictionary_enabled(descr->path()) &&
+ descr->physical_type() != Type::BOOLEAN;
+ Encoding::type encoding = properties->encoding(descr->path());
+ if (use_dictionary) {
+ encoding = properties->dictionary_index_encoding();
+ }
+ switch (descr->physical_type()) {
+ case Type::BOOLEAN:
+ return std::make_shared<TypedColumnWriterImpl<BooleanType>>(
+ metadata, std::move(pager), use_dictionary, encoding, properties);
+ case Type::INT32:
+ return std::make_shared<TypedColumnWriterImpl<Int32Type>>(
+ metadata, std::move(pager), use_dictionary, encoding, properties);
+ case Type::INT64:
+ return std::make_shared<TypedColumnWriterImpl<Int64Type>>(
+ metadata, std::move(pager), use_dictionary, encoding, properties);
+ case Type::INT96:
+ return std::make_shared<TypedColumnWriterImpl<Int96Type>>(
+ metadata, std::move(pager), use_dictionary, encoding, properties);
+ case Type::FLOAT:
+ return std::make_shared<TypedColumnWriterImpl<FloatType>>(
+ metadata, std::move(pager), use_dictionary, encoding, properties);
+ case Type::DOUBLE:
+ return std::make_shared<TypedColumnWriterImpl<DoubleType>>(
+ metadata, std::move(pager), use_dictionary, encoding, properties);
+ case Type::BYTE_ARRAY:
+ return std::make_shared<TypedColumnWriterImpl<ByteArrayType>>(
+ metadata, std::move(pager), use_dictionary, encoding, properties);
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return std::make_shared<TypedColumnWriterImpl<FLBAType>>(
+ metadata, std::move(pager), use_dictionary, encoding, properties);
+ default:
+ ParquetException::NYI("type reader not implemented");
+ }
+ // Unreachable code, but suppress compiler warning
+ return std::shared_ptr<ColumnWriter>(nullptr);
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/column_writer.h b/src/arrow/cpp/src/parquet/column_writer.h
new file mode 100644
index 000000000..0a6090217
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/column_writer.h
@@ -0,0 +1,270 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+
+#include "parquet/exception.h"
+#include "parquet/platform.h"
+#include "parquet/types.h"
+
+namespace arrow {
+
+class Array;
+
+namespace BitUtil {
+class BitWriter;
+} // namespace BitUtil
+
+namespace util {
+class RleEncoder;
+} // namespace util
+
+} // namespace arrow
+
+namespace parquet {
+
+struct ArrowWriteContext;
+class ColumnDescriptor;
+class DataPage;
+class DictionaryPage;
+class ColumnChunkMetaDataBuilder;
+class Encryptor;
+class WriterProperties;
+
+class PARQUET_EXPORT LevelEncoder {
+ public:
+ LevelEncoder();
+ ~LevelEncoder();
+
+ static int MaxBufferSize(Encoding::type encoding, int16_t max_level,
+ int num_buffered_values);
+
+ // Initialize the LevelEncoder.
+ void Init(Encoding::type encoding, int16_t max_level, int num_buffered_values,
+ uint8_t* data, int data_size);
+
+ // Encodes a batch of levels from an array and returns the number of levels encoded
+ int Encode(int batch_size, const int16_t* levels);
+
+ int32_t len() {
+ if (encoding_ != Encoding::RLE) {
+ throw ParquetException("Only implemented for RLE encoding");
+ }
+ return rle_length_;
+ }
+
+ private:
+ int bit_width_;
+ int rle_length_;
+ Encoding::type encoding_;
+ std::unique_ptr<::arrow::util::RleEncoder> rle_encoder_;
+ std::unique_ptr<::arrow::BitUtil::BitWriter> bit_packed_encoder_;
+};
+
+class PARQUET_EXPORT PageWriter {
+ public:
+ virtual ~PageWriter() {}
+
+ static std::unique_ptr<PageWriter> Open(
+ std::shared_ptr<ArrowOutputStream> sink, Compression::type codec,
+ int compression_level, ColumnChunkMetaDataBuilder* metadata,
+ int16_t row_group_ordinal = -1, int16_t column_chunk_ordinal = -1,
+ ::arrow::MemoryPool* pool = ::arrow::default_memory_pool(),
+ bool buffered_row_group = false,
+ std::shared_ptr<Encryptor> header_encryptor = NULLPTR,
+ std::shared_ptr<Encryptor> data_encryptor = NULLPTR);
+
+ // The Column Writer decides if dictionary encoding is used if set and
+ // if the dictionary encoding has fallen back to default encoding on reaching dictionary
+ // page limit
+ virtual void Close(bool has_dictionary, bool fallback) = 0;
+
+ // Return the number of uncompressed bytes written (including header size)
+ virtual int64_t WriteDataPage(const DataPage& page) = 0;
+
+ // Return the number of uncompressed bytes written (including header size)
+ virtual int64_t WriteDictionaryPage(const DictionaryPage& page) = 0;
+
+ virtual bool has_compressor() = 0;
+
+ virtual void Compress(const Buffer& src_buffer, ResizableBuffer* dest_buffer) = 0;
+};
+
+static constexpr int WRITE_BATCH_SIZE = 1000;
+class PARQUET_EXPORT ColumnWriter {
+ public:
+ virtual ~ColumnWriter() = default;
+
+ static std::shared_ptr<ColumnWriter> Make(ColumnChunkMetaDataBuilder*,
+ std::unique_ptr<PageWriter>,
+ const WriterProperties* properties);
+
+ /// \brief Closes the ColumnWriter, commits any buffered values to pages.
+ /// \return Total size of the column in bytes
+ virtual int64_t Close() = 0;
+
+ /// \brief The physical Parquet type of the column
+ virtual Type::type type() const = 0;
+
+ /// \brief The schema for the column
+ virtual const ColumnDescriptor* descr() const = 0;
+
+ /// \brief The number of rows written so far
+ virtual int64_t rows_written() const = 0;
+
+ /// \brief The total size of the compressed pages + page headers. Some values
+ /// might be still buffered and not written to a page yet
+ virtual int64_t total_compressed_bytes() const = 0;
+
+ /// \brief The total number of bytes written as serialized data and
+ /// dictionary pages to the ColumnChunk so far
+ virtual int64_t total_bytes_written() const = 0;
+
+ /// \brief The file-level writer properties
+ virtual const WriterProperties* properties() = 0;
+
+ /// \brief Write Apache Arrow columnar data directly to ColumnWriter. Returns
+ /// error status if the array data type is not compatible with the concrete
+ /// writer type.
+ ///
+ /// leaf_array is always a primitive (possibly dictionary encoded type).
+ /// Leaf_field_nullable indicates whether the leaf array is considered nullable
+ /// according to its schema in a Table or its parent array.
+ virtual ::arrow::Status WriteArrow(const int16_t* def_levels, const int16_t* rep_levels,
+ int64_t num_levels, const ::arrow::Array& leaf_array,
+ ArrowWriteContext* ctx,
+ bool leaf_field_nullable) = 0;
+};
+
+// API to write values to a single column. This is the main client facing API.
+template <typename DType>
+class TypedColumnWriter : public ColumnWriter {
+ public:
+ using T = typename DType::c_type;
+
+ // Write a batch of repetition levels, definition levels, and values to the
+ // column.
+ // `num_values` is the number of logical leaf values.
+ // `def_levels` (resp. `rep_levels`) can be null if the column's max definition level
+ // (resp. max repetition level) is 0.
+ // If not null, each of `def_levels` and `rep_levels` must have at least
+ // `num_values`.
+ //
+ // The number of physical values written (taken from `values`) is returned.
+ // It can be smaller than `num_values` is there are some undefined values.
+ virtual int64_t WriteBatch(int64_t num_values, const int16_t* def_levels,
+ const int16_t* rep_levels, const T* values) = 0;
+
+ /// Write a batch of repetition levels, definition levels, and values to the
+ /// column.
+ ///
+ /// In comparison to WriteBatch the length of repetition and definition levels
+ /// is the same as of the number of values read for max_definition_level == 1.
+ /// In the case of max_definition_level > 1, the repetition and definition
+ /// levels are larger than the values but the values include the null entries
+ /// with definition_level == (max_definition_level - 1). Thus we have to differentiate
+ /// in the parameters of this function if the input has the length of num_values or the
+ /// _number of rows in the lowest nesting level_.
+ ///
+ /// In the case that the most inner node in the Parquet is required, the _number of rows
+ /// in the lowest nesting level_ is equal to the number of non-null values. If the
+ /// inner-most schema node is optional, the _number of rows in the lowest nesting level_
+ /// also includes all values with definition_level == (max_definition_level - 1).
+ ///
+ /// @param num_values number of levels to write.
+ /// @param def_levels The Parquet definition levels, length is num_values
+ /// @param rep_levels The Parquet repetition levels, length is num_values
+ /// @param valid_bits Bitmap that indicates if the row is null on the lowest nesting
+ /// level. The length is number of rows in the lowest nesting level.
+ /// @param valid_bits_offset The offset in bits of the valid_bits where the
+ /// first relevant bit resides.
+ /// @param values The values in the lowest nested level including
+ /// spacing for nulls on the lowest levels; input has the length
+ /// of the number of rows on the lowest nesting level.
+ virtual void WriteBatchSpaced(int64_t num_values, const int16_t* def_levels,
+ const int16_t* rep_levels, const uint8_t* valid_bits,
+ int64_t valid_bits_offset, const T* values) = 0;
+
+ // Estimated size of the values that are not written to a page yet
+ virtual int64_t EstimatedBufferedValueBytes() const = 0;
+};
+
+using BoolWriter = TypedColumnWriter<BooleanType>;
+using Int32Writer = TypedColumnWriter<Int32Type>;
+using Int64Writer = TypedColumnWriter<Int64Type>;
+using Int96Writer = TypedColumnWriter<Int96Type>;
+using FloatWriter = TypedColumnWriter<FloatType>;
+using DoubleWriter = TypedColumnWriter<DoubleType>;
+using ByteArrayWriter = TypedColumnWriter<ByteArrayType>;
+using FixedLenByteArrayWriter = TypedColumnWriter<FLBAType>;
+
+namespace internal {
+
+/**
+ * Timestamp conversion constants
+ */
+constexpr int64_t kJulianEpochOffsetDays = INT64_C(2440588);
+
+template <int64_t UnitPerDay, int64_t NanosecondsPerUnit>
+inline void ArrowTimestampToImpalaTimestamp(const int64_t time, Int96* impala_timestamp) {
+ int64_t julian_days = (time / UnitPerDay) + kJulianEpochOffsetDays;
+ (*impala_timestamp).value[2] = (uint32_t)julian_days;
+
+ int64_t last_day_units = time % UnitPerDay;
+ auto last_day_nanos = last_day_units * NanosecondsPerUnit;
+ // impala_timestamp will be unaligned every other entry so do memcpy instead
+ // of assign and reinterpret cast to avoid undefined behavior.
+ std::memcpy(impala_timestamp, &last_day_nanos, sizeof(int64_t));
+}
+
+constexpr int64_t kSecondsInNanos = INT64_C(1000000000);
+
+inline void SecondsToImpalaTimestamp(const int64_t seconds, Int96* impala_timestamp) {
+ ArrowTimestampToImpalaTimestamp<kSecondsPerDay, kSecondsInNanos>(seconds,
+ impala_timestamp);
+}
+
+constexpr int64_t kMillisecondsInNanos = kSecondsInNanos / INT64_C(1000);
+
+inline void MillisecondsToImpalaTimestamp(const int64_t milliseconds,
+ Int96* impala_timestamp) {
+ ArrowTimestampToImpalaTimestamp<kMillisecondsPerDay, kMillisecondsInNanos>(
+ milliseconds, impala_timestamp);
+}
+
+constexpr int64_t kMicrosecondsInNanos = kMillisecondsInNanos / INT64_C(1000);
+
+inline void MicrosecondsToImpalaTimestamp(const int64_t microseconds,
+ Int96* impala_timestamp) {
+ ArrowTimestampToImpalaTimestamp<kMicrosecondsPerDay, kMicrosecondsInNanos>(
+ microseconds, impala_timestamp);
+}
+
+constexpr int64_t kNanosecondsInNanos = INT64_C(1);
+
+inline void NanosecondsToImpalaTimestamp(const int64_t nanoseconds,
+ Int96* impala_timestamp) {
+ ArrowTimestampToImpalaTimestamp<kNanosecondsPerDay, kNanosecondsInNanos>(
+ nanoseconds, impala_timestamp);
+}
+
+} // namespace internal
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/column_writer_test.cc b/src/arrow/cpp/src/parquet/column_writer_test.cc
new file mode 100644
index 000000000..e895b7359
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/column_writer_test.cc
@@ -0,0 +1,1019 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/io/buffered.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_builders.h"
+
+#include "parquet/column_reader.h"
+#include "parquet/column_writer.h"
+#include "parquet/file_writer.h"
+#include "parquet/metadata.h"
+#include "parquet/platform.h"
+#include "parquet/properties.h"
+#include "parquet/statistics.h"
+#include "parquet/test_util.h"
+#include "parquet/thrift_internal.h"
+#include "parquet/types.h"
+
+namespace BitUtil = arrow::BitUtil;
+
+namespace parquet {
+
+using schema::GroupNode;
+using schema::NodePtr;
+using schema::PrimitiveNode;
+
+namespace test {
+
+// The default size used in most tests.
+const int SMALL_SIZE = 100;
+#ifdef PARQUET_VALGRIND
+// Larger size to test some corner cases, only used in some specific cases.
+const int LARGE_SIZE = 10000;
+// Very large size to test dictionary fallback.
+const int VERY_LARGE_SIZE = 40000;
+// Reduced dictionary page size to use for testing dictionary fallback with valgrind
+const int64_t DICTIONARY_PAGE_SIZE = 1024;
+#else
+// Larger size to test some corner cases, only used in some specific cases.
+const int LARGE_SIZE = 100000;
+// Very large size to test dictionary fallback.
+const int VERY_LARGE_SIZE = 400000;
+// Dictionary page size to use for testing dictionary fallback
+const int64_t DICTIONARY_PAGE_SIZE = 1024 * 1024;
+#endif
+
+template <typename TestType>
+class TestPrimitiveWriter : public PrimitiveTypedTest<TestType> {
+ public:
+ void SetUp() {
+ this->SetupValuesOut(SMALL_SIZE);
+ writer_properties_ = default_writer_properties();
+ definition_levels_out_.resize(SMALL_SIZE);
+ repetition_levels_out_.resize(SMALL_SIZE);
+
+ this->SetUpSchema(Repetition::REQUIRED);
+
+ descr_ = this->schema_.Column(0);
+ }
+
+ Type::type type_num() { return TestType::type_num; }
+
+ void BuildReader(int64_t num_rows,
+ Compression::type compression = Compression::UNCOMPRESSED) {
+ ASSERT_OK_AND_ASSIGN(auto buffer, sink_->Finish());
+ auto source = std::make_shared<::arrow::io::BufferReader>(buffer);
+ std::unique_ptr<PageReader> page_reader =
+ PageReader::Open(std::move(source), num_rows, compression);
+ reader_ = std::static_pointer_cast<TypedColumnReader<TestType>>(
+ ColumnReader::Make(this->descr_, std::move(page_reader)));
+ }
+
+ std::shared_ptr<TypedColumnWriter<TestType>> BuildWriter(
+ int64_t output_size = SMALL_SIZE,
+ const ColumnProperties& column_properties = ColumnProperties(),
+ const ParquetVersion::type version = ParquetVersion::PARQUET_1_0) {
+ sink_ = CreateOutputStream();
+ WriterProperties::Builder wp_builder;
+ wp_builder.version(version);
+ if (column_properties.encoding() == Encoding::PLAIN_DICTIONARY ||
+ column_properties.encoding() == Encoding::RLE_DICTIONARY) {
+ wp_builder.enable_dictionary();
+ wp_builder.dictionary_pagesize_limit(DICTIONARY_PAGE_SIZE);
+ } else {
+ wp_builder.disable_dictionary();
+ wp_builder.encoding(column_properties.encoding());
+ }
+ wp_builder.max_statistics_size(column_properties.max_statistics_size());
+ writer_properties_ = wp_builder.build();
+
+ metadata_ = ColumnChunkMetaDataBuilder::Make(writer_properties_, this->descr_);
+ std::unique_ptr<PageWriter> pager =
+ PageWriter::Open(sink_, column_properties.compression(),
+ Codec::UseDefaultCompressionLevel(), metadata_.get());
+ std::shared_ptr<ColumnWriter> writer =
+ ColumnWriter::Make(metadata_.get(), std::move(pager), writer_properties_.get());
+ return std::static_pointer_cast<TypedColumnWriter<TestType>>(writer);
+ }
+
+ void ReadColumn(Compression::type compression = Compression::UNCOMPRESSED) {
+ BuildReader(static_cast<int64_t>(this->values_out_.size()), compression);
+ reader_->ReadBatch(static_cast<int>(this->values_out_.size()),
+ definition_levels_out_.data(), repetition_levels_out_.data(),
+ this->values_out_ptr_, &values_read_);
+ this->SyncValuesOut();
+ }
+
+ void ReadColumnFully(Compression::type compression = Compression::UNCOMPRESSED);
+
+ void TestRequiredWithEncoding(Encoding::type encoding) {
+ return TestRequiredWithSettings(encoding, Compression::UNCOMPRESSED, false, false);
+ }
+
+ void TestRequiredWithSettings(
+ Encoding::type encoding, Compression::type compression, bool enable_dictionary,
+ bool enable_statistics, int64_t num_rows = SMALL_SIZE,
+ int compression_level = Codec::UseDefaultCompressionLevel()) {
+ this->GenerateData(num_rows);
+
+ this->WriteRequiredWithSettings(encoding, compression, enable_dictionary,
+ enable_statistics, compression_level, num_rows);
+ ASSERT_NO_FATAL_FAILURE(this->ReadAndCompare(compression, num_rows));
+
+ this->WriteRequiredWithSettingsSpaced(encoding, compression, enable_dictionary,
+ enable_statistics, num_rows, compression_level);
+ ASSERT_NO_FATAL_FAILURE(this->ReadAndCompare(compression, num_rows));
+ }
+
+ void TestDictionaryFallbackEncoding(ParquetVersion::type version) {
+ this->GenerateData(VERY_LARGE_SIZE);
+ ColumnProperties column_properties;
+ column_properties.set_dictionary_enabled(true);
+
+ if (version == ParquetVersion::PARQUET_1_0) {
+ column_properties.set_encoding(Encoding::PLAIN_DICTIONARY);
+ } else {
+ column_properties.set_encoding(Encoding::RLE_DICTIONARY);
+ }
+
+ auto writer = this->BuildWriter(VERY_LARGE_SIZE, column_properties, version);
+
+ writer->WriteBatch(this->values_.size(), nullptr, nullptr, this->values_ptr_);
+ writer->Close();
+
+ // Read all rows so we are sure that also the non-dictionary pages are read correctly
+ this->SetupValuesOut(VERY_LARGE_SIZE);
+ this->ReadColumnFully();
+ ASSERT_EQ(VERY_LARGE_SIZE, this->values_read_);
+ this->values_.resize(VERY_LARGE_SIZE);
+ ASSERT_EQ(this->values_, this->values_out_);
+ std::vector<Encoding::type> encodings = this->metadata_encodings();
+
+ if (this->type_num() == Type::BOOLEAN) {
+ // Dictionary encoding is not allowed for boolean type
+ // There are 2 encodings (PLAIN, RLE) in a non dictionary encoding case
+ std::vector<Encoding::type> expected({Encoding::PLAIN, Encoding::RLE});
+ ASSERT_EQ(encodings, expected);
+ } else if (version == ParquetVersion::PARQUET_1_0) {
+ // There are 4 encodings (PLAIN_DICTIONARY, PLAIN, RLE, PLAIN) in a fallback case
+ // for version 1.0
+ std::vector<Encoding::type> expected(
+ {Encoding::PLAIN_DICTIONARY, Encoding::PLAIN, Encoding::RLE, Encoding::PLAIN});
+ ASSERT_EQ(encodings, expected);
+ } else {
+ // There are 4 encodings (RLE_DICTIONARY, PLAIN, RLE, PLAIN) in a fallback case for
+ // version 2.0
+ std::vector<Encoding::type> expected(
+ {Encoding::RLE_DICTIONARY, Encoding::PLAIN, Encoding::RLE, Encoding::PLAIN});
+ ASSERT_EQ(encodings, expected);
+ }
+
+ std::vector<parquet::PageEncodingStats> encoding_stats =
+ this->metadata_encoding_stats();
+ if (this->type_num() == Type::BOOLEAN) {
+ ASSERT_EQ(encoding_stats[0].encoding, Encoding::PLAIN);
+ ASSERT_EQ(encoding_stats[0].page_type, PageType::DATA_PAGE);
+ } else if (version == ParquetVersion::PARQUET_1_0) {
+ std::vector<Encoding::type> expected(
+ {Encoding::PLAIN_DICTIONARY, Encoding::PLAIN, Encoding::PLAIN_DICTIONARY});
+ ASSERT_EQ(encoding_stats[0].encoding, expected[0]);
+ ASSERT_EQ(encoding_stats[0].page_type, PageType::DICTIONARY_PAGE);
+ for (size_t i = 1; i < encoding_stats.size(); i++) {
+ ASSERT_EQ(encoding_stats[i].encoding, expected[i]);
+ ASSERT_EQ(encoding_stats[i].page_type, PageType::DATA_PAGE);
+ }
+ } else {
+ std::vector<Encoding::type> expected(
+ {Encoding::PLAIN, Encoding::PLAIN, Encoding::RLE_DICTIONARY});
+ ASSERT_EQ(encoding_stats[0].encoding, expected[0]);
+ ASSERT_EQ(encoding_stats[0].page_type, PageType::DICTIONARY_PAGE);
+ for (size_t i = 1; i < encoding_stats.size(); i++) {
+ ASSERT_EQ(encoding_stats[i].encoding, expected[i]);
+ ASSERT_EQ(encoding_stats[i].page_type, PageType::DATA_PAGE);
+ }
+ }
+ }
+
+ void WriteRequiredWithSettings(Encoding::type encoding, Compression::type compression,
+ bool enable_dictionary, bool enable_statistics,
+ int compression_level, int64_t num_rows) {
+ ColumnProperties column_properties(encoding, compression, enable_dictionary,
+ enable_statistics);
+ column_properties.set_compression_level(compression_level);
+ std::shared_ptr<TypedColumnWriter<TestType>> writer =
+ this->BuildWriter(num_rows, column_properties);
+ writer->WriteBatch(this->values_.size(), nullptr, nullptr, this->values_ptr_);
+ // The behaviour should be independent from the number of Close() calls
+ writer->Close();
+ writer->Close();
+ }
+
+ void WriteRequiredWithSettingsSpaced(Encoding::type encoding,
+ Compression::type compression,
+ bool enable_dictionary, bool enable_statistics,
+ int64_t num_rows, int compression_level) {
+ std::vector<uint8_t> valid_bits(
+ BitUtil::BytesForBits(static_cast<uint32_t>(this->values_.size())) + 1, 255);
+ ColumnProperties column_properties(encoding, compression, enable_dictionary,
+ enable_statistics);
+ column_properties.set_compression_level(compression_level);
+ std::shared_ptr<TypedColumnWriter<TestType>> writer =
+ this->BuildWriter(num_rows, column_properties);
+ writer->WriteBatchSpaced(this->values_.size(), nullptr, nullptr, valid_bits.data(), 0,
+ this->values_ptr_);
+ // The behaviour should be independent from the number of Close() calls
+ writer->Close();
+ writer->Close();
+ }
+
+ void ReadAndCompare(Compression::type compression, int64_t num_rows) {
+ this->SetupValuesOut(num_rows);
+ this->ReadColumnFully(compression);
+ auto comparator = MakeComparator<TestType>(this->descr_);
+ for (size_t i = 0; i < this->values_.size(); i++) {
+ if (comparator->Compare(this->values_[i], this->values_out_[i]) ||
+ comparator->Compare(this->values_out_[i], this->values_[i])) {
+ std::cout << "Failed at " << i << std::endl;
+ }
+ ASSERT_FALSE(comparator->Compare(this->values_[i], this->values_out_[i]));
+ ASSERT_FALSE(comparator->Compare(this->values_out_[i], this->values_[i]));
+ }
+ ASSERT_EQ(this->values_, this->values_out_);
+ }
+
+ int64_t metadata_num_values() {
+ // Metadata accessor must be created lazily.
+ // This is because the ColumnChunkMetaData semantics dictate the metadata object is
+ // complete (no changes to the metadata buffer can be made after instantiation)
+ auto metadata_accessor =
+ ColumnChunkMetaData::Make(metadata_->contents(), this->descr_);
+ return metadata_accessor->num_values();
+ }
+
+ bool metadata_is_stats_set() {
+ // Metadata accessor must be created lazily.
+ // This is because the ColumnChunkMetaData semantics dictate the metadata object is
+ // complete (no changes to the metadata buffer can be made after instantiation)
+ ApplicationVersion app_version(this->writer_properties_->created_by());
+ auto metadata_accessor =
+ ColumnChunkMetaData::Make(metadata_->contents(), this->descr_, &app_version);
+ return metadata_accessor->is_stats_set();
+ }
+
+ std::pair<bool, bool> metadata_stats_has_min_max() {
+ // Metadata accessor must be created lazily.
+ // This is because the ColumnChunkMetaData semantics dictate the metadata object is
+ // complete (no changes to the metadata buffer can be made after instantiation)
+ ApplicationVersion app_version(this->writer_properties_->created_by());
+ auto metadata_accessor =
+ ColumnChunkMetaData::Make(metadata_->contents(), this->descr_, &app_version);
+ auto encoded_stats = metadata_accessor->statistics()->Encode();
+ return {encoded_stats.has_min, encoded_stats.has_max};
+ }
+
+ std::vector<Encoding::type> metadata_encodings() {
+ // Metadata accessor must be created lazily.
+ // This is because the ColumnChunkMetaData semantics dictate the metadata object is
+ // complete (no changes to the metadata buffer can be made after instantiation)
+ auto metadata_accessor =
+ ColumnChunkMetaData::Make(metadata_->contents(), this->descr_);
+ return metadata_accessor->encodings();
+ }
+
+ std::vector<parquet::PageEncodingStats> metadata_encoding_stats() {
+ // Metadata accessor must be created lazily.
+ // This is because the ColumnChunkMetaData semantics dictate the metadata object is
+ // complete (no changes to the metadata buffer can be made after instantiation)
+ auto metadata_accessor =
+ ColumnChunkMetaData::Make(metadata_->contents(), this->descr_);
+ return metadata_accessor->encoding_stats();
+ }
+
+ protected:
+ int64_t values_read_;
+ // Keep the reader alive as for ByteArray the lifetime of the ByteArray
+ // content is bound to the reader.
+ std::shared_ptr<TypedColumnReader<TestType>> reader_;
+
+ std::vector<int16_t> definition_levels_out_;
+ std::vector<int16_t> repetition_levels_out_;
+
+ const ColumnDescriptor* descr_;
+
+ private:
+ std::unique_ptr<ColumnChunkMetaDataBuilder> metadata_;
+ std::shared_ptr<::arrow::io::BufferOutputStream> sink_;
+ std::shared_ptr<WriterProperties> writer_properties_;
+ std::vector<std::vector<uint8_t>> data_buffer_;
+};
+
+template <typename TestType>
+void TestPrimitiveWriter<TestType>::ReadColumnFully(Compression::type compression) {
+ int64_t total_values = static_cast<int64_t>(this->values_out_.size());
+ BuildReader(total_values, compression);
+ values_read_ = 0;
+ while (values_read_ < total_values) {
+ int64_t values_read_recently = 0;
+ reader_->ReadBatch(
+ static_cast<int>(this->values_out_.size()) - static_cast<int>(values_read_),
+ definition_levels_out_.data() + values_read_,
+ repetition_levels_out_.data() + values_read_,
+ this->values_out_ptr_ + values_read_, &values_read_recently);
+ values_read_ += values_read_recently;
+ }
+ this->SyncValuesOut();
+}
+
+template <>
+void TestPrimitiveWriter<Int96Type>::ReadAndCompare(Compression::type compression,
+ int64_t num_rows) {
+ this->SetupValuesOut(num_rows);
+ this->ReadColumnFully(compression);
+
+ auto comparator = MakeComparator<Int96Type>(Type::INT96, SortOrder::SIGNED);
+ for (size_t i = 0; i < this->values_.size(); i++) {
+ if (comparator->Compare(this->values_[i], this->values_out_[i]) ||
+ comparator->Compare(this->values_out_[i], this->values_[i])) {
+ std::cout << "Failed at " << i << std::endl;
+ }
+ ASSERT_FALSE(comparator->Compare(this->values_[i], this->values_out_[i]));
+ ASSERT_FALSE(comparator->Compare(this->values_out_[i], this->values_[i]));
+ }
+ ASSERT_EQ(this->values_, this->values_out_);
+}
+
+template <>
+void TestPrimitiveWriter<FLBAType>::ReadColumnFully(Compression::type compression) {
+ int64_t total_values = static_cast<int64_t>(this->values_out_.size());
+ BuildReader(total_values, compression);
+ this->data_buffer_.clear();
+
+ values_read_ = 0;
+ while (values_read_ < total_values) {
+ int64_t values_read_recently = 0;
+ reader_->ReadBatch(
+ static_cast<int>(this->values_out_.size()) - static_cast<int>(values_read_),
+ definition_levels_out_.data() + values_read_,
+ repetition_levels_out_.data() + values_read_,
+ this->values_out_ptr_ + values_read_, &values_read_recently);
+
+ // Copy contents of the pointers
+ std::vector<uint8_t> data(values_read_recently * this->descr_->type_length());
+ uint8_t* data_ptr = data.data();
+ for (int64_t i = 0; i < values_read_recently; i++) {
+ memcpy(data_ptr + this->descr_->type_length() * i,
+ this->values_out_[i + values_read_].ptr, this->descr_->type_length());
+ this->values_out_[i + values_read_].ptr =
+ data_ptr + this->descr_->type_length() * i;
+ }
+ data_buffer_.emplace_back(std::move(data));
+
+ values_read_ += values_read_recently;
+ }
+ this->SyncValuesOut();
+}
+
+typedef ::testing::Types<Int32Type, Int64Type, Int96Type, FloatType, DoubleType,
+ BooleanType, ByteArrayType, FLBAType>
+ TestTypes;
+
+TYPED_TEST_SUITE(TestPrimitiveWriter, TestTypes);
+
+using TestNullValuesWriter = TestPrimitiveWriter<Int32Type>;
+
+TYPED_TEST(TestPrimitiveWriter, RequiredPlain) {
+ this->TestRequiredWithEncoding(Encoding::PLAIN);
+}
+
+TYPED_TEST(TestPrimitiveWriter, RequiredDictionary) {
+ this->TestRequiredWithEncoding(Encoding::PLAIN_DICTIONARY);
+}
+
+/*
+TYPED_TEST(TestPrimitiveWriter, RequiredRLE) {
+ this->TestRequiredWithEncoding(Encoding::RLE);
+}
+
+TYPED_TEST(TestPrimitiveWriter, RequiredBitPacked) {
+ this->TestRequiredWithEncoding(Encoding::BIT_PACKED);
+}
+
+TYPED_TEST(TestPrimitiveWriter, RequiredDeltaBinaryPacked) {
+ this->TestRequiredWithEncoding(Encoding::DELTA_BINARY_PACKED);
+}
+
+TYPED_TEST(TestPrimitiveWriter, RequiredDeltaLengthByteArray) {
+ this->TestRequiredWithEncoding(Encoding::DELTA_LENGTH_BYTE_ARRAY);
+}
+
+TYPED_TEST(TestPrimitiveWriter, RequiredDeltaByteArray) {
+ this->TestRequiredWithEncoding(Encoding::DELTA_BYTE_ARRAY);
+}
+
+TYPED_TEST(TestPrimitiveWriter, RequiredRLEDictionary) {
+ this->TestRequiredWithEncoding(Encoding::RLE_DICTIONARY);
+}
+*/
+
+TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithStats) {
+ this->TestRequiredWithSettings(Encoding::PLAIN, Compression::UNCOMPRESSED, false, true,
+ LARGE_SIZE);
+}
+
+#ifdef ARROW_WITH_SNAPPY
+TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithSnappyCompression) {
+ this->TestRequiredWithSettings(Encoding::PLAIN, Compression::SNAPPY, false, false,
+ LARGE_SIZE);
+}
+
+TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithStatsAndSnappyCompression) {
+ this->TestRequiredWithSettings(Encoding::PLAIN, Compression::SNAPPY, false, true,
+ LARGE_SIZE);
+}
+#endif
+
+#ifdef ARROW_WITH_BROTLI
+TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithBrotliCompression) {
+ this->TestRequiredWithSettings(Encoding::PLAIN, Compression::BROTLI, false, false,
+ LARGE_SIZE);
+}
+
+TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithBrotliCompressionAndLevel) {
+ this->TestRequiredWithSettings(Encoding::PLAIN, Compression::BROTLI, false, false,
+ LARGE_SIZE, 10);
+}
+
+TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithStatsAndBrotliCompression) {
+ this->TestRequiredWithSettings(Encoding::PLAIN, Compression::BROTLI, false, true,
+ LARGE_SIZE);
+}
+
+#endif
+
+#ifdef ARROW_WITH_GZIP
+TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithGzipCompression) {
+ this->TestRequiredWithSettings(Encoding::PLAIN, Compression::GZIP, false, false,
+ LARGE_SIZE);
+}
+
+TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithGzipCompressionAndLevel) {
+ this->TestRequiredWithSettings(Encoding::PLAIN, Compression::GZIP, false, false,
+ LARGE_SIZE, 10);
+}
+
+TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithStatsAndGzipCompression) {
+ this->TestRequiredWithSettings(Encoding::PLAIN, Compression::GZIP, false, true,
+ LARGE_SIZE);
+}
+#endif
+
+#ifdef ARROW_WITH_LZ4
+TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithLz4Compression) {
+ this->TestRequiredWithSettings(Encoding::PLAIN, Compression::LZ4, false, false,
+ LARGE_SIZE);
+}
+
+TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithStatsAndLz4Compression) {
+ this->TestRequiredWithSettings(Encoding::PLAIN, Compression::LZ4, false, true,
+ LARGE_SIZE);
+}
+#endif
+
+#ifdef ARROW_WITH_ZSTD
+TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithZstdCompression) {
+ this->TestRequiredWithSettings(Encoding::PLAIN, Compression::ZSTD, false, false,
+ LARGE_SIZE);
+}
+
+TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithZstdCompressionAndLevel) {
+ this->TestRequiredWithSettings(Encoding::PLAIN, Compression::ZSTD, false, false,
+ LARGE_SIZE, 6);
+}
+
+TYPED_TEST(TestPrimitiveWriter, RequiredPlainWithStatsAndZstdCompression) {
+ this->TestRequiredWithSettings(Encoding::PLAIN, Compression::ZSTD, false, true,
+ LARGE_SIZE);
+}
+#endif
+
+TYPED_TEST(TestPrimitiveWriter, Optional) {
+ // Optional and non-repeated, with definition levels
+ // but no repetition levels
+ this->SetUpSchema(Repetition::OPTIONAL);
+
+ this->GenerateData(SMALL_SIZE);
+ std::vector<int16_t> definition_levels(SMALL_SIZE, 1);
+ definition_levels[1] = 0;
+
+ auto writer = this->BuildWriter();
+ writer->WriteBatch(this->values_.size(), definition_levels.data(), nullptr,
+ this->values_ptr_);
+ writer->Close();
+
+ // PARQUET-703
+ ASSERT_EQ(100, this->metadata_num_values());
+
+ this->ReadColumn();
+ ASSERT_EQ(99, this->values_read_);
+ this->values_out_.resize(99);
+ this->values_.resize(99);
+ ASSERT_EQ(this->values_, this->values_out_);
+}
+
+TYPED_TEST(TestPrimitiveWriter, OptionalSpaced) {
+ // Optional and non-repeated, with definition levels
+ // but no repetition levels
+ this->SetUpSchema(Repetition::OPTIONAL);
+
+ this->GenerateData(SMALL_SIZE);
+ std::vector<int16_t> definition_levels(SMALL_SIZE, 1);
+ std::vector<uint8_t> valid_bits(::arrow::BitUtil::BytesForBits(SMALL_SIZE), 255);
+
+ definition_levels[SMALL_SIZE - 1] = 0;
+ ::arrow::BitUtil::ClearBit(valid_bits.data(), SMALL_SIZE - 1);
+ definition_levels[1] = 0;
+ ::arrow::BitUtil::ClearBit(valid_bits.data(), 1);
+
+ auto writer = this->BuildWriter();
+ writer->WriteBatchSpaced(this->values_.size(), definition_levels.data(), nullptr,
+ valid_bits.data(), 0, this->values_ptr_);
+ writer->Close();
+
+ // PARQUET-703
+ ASSERT_EQ(100, this->metadata_num_values());
+
+ this->ReadColumn();
+ ASSERT_EQ(98, this->values_read_);
+ this->values_out_.resize(98);
+ this->values_.resize(99);
+ this->values_.erase(this->values_.begin() + 1);
+ ASSERT_EQ(this->values_, this->values_out_);
+}
+
+TYPED_TEST(TestPrimitiveWriter, Repeated) {
+ // Optional and repeated, so definition and repetition levels
+ this->SetUpSchema(Repetition::REPEATED);
+
+ this->GenerateData(SMALL_SIZE);
+ std::vector<int16_t> definition_levels(SMALL_SIZE, 1);
+ definition_levels[1] = 0;
+ std::vector<int16_t> repetition_levels(SMALL_SIZE, 0);
+
+ auto writer = this->BuildWriter();
+ writer->WriteBatch(this->values_.size(), definition_levels.data(),
+ repetition_levels.data(), this->values_ptr_);
+ writer->Close();
+
+ this->ReadColumn();
+ ASSERT_EQ(SMALL_SIZE - 1, this->values_read_);
+ this->values_out_.resize(SMALL_SIZE - 1);
+ this->values_.resize(SMALL_SIZE - 1);
+ ASSERT_EQ(this->values_, this->values_out_);
+}
+
+TYPED_TEST(TestPrimitiveWriter, RequiredLargeChunk) {
+ this->GenerateData(LARGE_SIZE);
+
+ // Test case 1: required and non-repeated, so no definition or repetition levels
+ auto writer = this->BuildWriter(LARGE_SIZE);
+ writer->WriteBatch(this->values_.size(), nullptr, nullptr, this->values_ptr_);
+ writer->Close();
+
+ // Just read the first SMALL_SIZE rows to ensure we could read it back in
+ this->ReadColumn();
+ ASSERT_EQ(SMALL_SIZE, this->values_read_);
+ this->values_.resize(SMALL_SIZE);
+ ASSERT_EQ(this->values_, this->values_out_);
+}
+
+// Test cases for dictionary fallback encoding
+TYPED_TEST(TestPrimitiveWriter, DictionaryFallbackVersion1_0) {
+ this->TestDictionaryFallbackEncoding(ParquetVersion::PARQUET_1_0);
+}
+
+TYPED_TEST(TestPrimitiveWriter, DictionaryFallbackVersion2_0) {
+ this->TestDictionaryFallbackEncoding(ParquetVersion::PARQUET_2_4);
+}
+
+TEST(TestWriter, NullValuesBuffer) {
+ std::shared_ptr<::arrow::io::BufferOutputStream> sink = CreateOutputStream();
+
+ const auto item_node = schema::PrimitiveNode::Make(
+ "item", Repetition::REQUIRED, LogicalType::Int(32, true), Type::INT32);
+ const auto list_node =
+ schema::GroupNode::Make("list", Repetition::REPEATED, {item_node});
+ const auto column_node = schema::GroupNode::Make(
+ "array_of_ints_column", Repetition::OPTIONAL, {list_node}, LogicalType::List());
+ const auto schema_node =
+ schema::GroupNode::Make("schema", Repetition::REQUIRED, {column_node});
+
+ auto file_writer = ParquetFileWriter::Open(
+ sink, std::dynamic_pointer_cast<schema::GroupNode>(schema_node));
+ auto group_writer = file_writer->AppendRowGroup();
+ auto column_writer = group_writer->NextColumn();
+ auto typed_writer = dynamic_cast<Int32Writer*>(column_writer);
+
+ const int64_t num_values = 1;
+ const int16_t def_levels[] = {0};
+ const int16_t rep_levels[] = {0};
+ const uint8_t valid_bits[] = {0};
+ const int64_t valid_bits_offset = 0;
+ const int32_t* values = nullptr;
+
+ typed_writer->WriteBatchSpaced(num_values, def_levels, rep_levels, valid_bits,
+ valid_bits_offset, values);
+}
+
+// PARQUET-719
+// Test case for NULL values
+TEST_F(TestNullValuesWriter, OptionalNullValueChunk) {
+ this->SetUpSchema(Repetition::OPTIONAL);
+
+ this->GenerateData(LARGE_SIZE);
+
+ std::vector<int16_t> definition_levels(LARGE_SIZE, 0);
+ std::vector<int16_t> repetition_levels(LARGE_SIZE, 0);
+
+ auto writer = this->BuildWriter(LARGE_SIZE);
+ // All values being written are NULL
+ writer->WriteBatch(this->values_.size(), definition_levels.data(),
+ repetition_levels.data(), nullptr);
+ writer->Close();
+
+ // Just read the first SMALL_SIZE rows to ensure we could read it back in
+ this->ReadColumn();
+ ASSERT_EQ(0, this->values_read_);
+}
+
+// PARQUET-764
+// Correct bitpacking for boolean write at non-byte boundaries
+using TestBooleanValuesWriter = TestPrimitiveWriter<BooleanType>;
+TEST_F(TestBooleanValuesWriter, AlternateBooleanValues) {
+ this->SetUpSchema(Repetition::REQUIRED);
+ auto writer = this->BuildWriter();
+ for (int i = 0; i < SMALL_SIZE; i++) {
+ bool value = (i % 2 == 0) ? true : false;
+ writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+ writer->Close();
+ this->ReadColumn();
+ for (int i = 0; i < SMALL_SIZE; i++) {
+ ASSERT_EQ((i % 2 == 0) ? true : false, this->values_out_[i]) << i;
+ }
+}
+
+// PARQUET-979
+// Prevent writing large MIN, MAX stats
+using TestByteArrayValuesWriter = TestPrimitiveWriter<ByteArrayType>;
+TEST_F(TestByteArrayValuesWriter, OmitStats) {
+ int min_len = 1024 * 4;
+ int max_len = 1024 * 8;
+ this->SetUpSchema(Repetition::REQUIRED);
+ auto writer = this->BuildWriter();
+
+ values_.resize(SMALL_SIZE);
+ InitWideByteArrayValues(SMALL_SIZE, this->values_, this->buffer_, min_len, max_len);
+ writer->WriteBatch(SMALL_SIZE, nullptr, nullptr, this->values_.data());
+ writer->Close();
+
+ auto has_min_max = this->metadata_stats_has_min_max();
+ ASSERT_FALSE(has_min_max.first);
+ ASSERT_FALSE(has_min_max.second);
+}
+
+// PARQUET-1405
+// Prevent writing large stats in the DataPageHeader
+TEST_F(TestByteArrayValuesWriter, OmitDataPageStats) {
+ int min_len = static_cast<int>(std::pow(10, 7));
+ int max_len = static_cast<int>(std::pow(10, 7));
+ this->SetUpSchema(Repetition::REQUIRED);
+ ColumnProperties column_properties;
+ column_properties.set_statistics_enabled(false);
+ auto writer = this->BuildWriter(SMALL_SIZE, column_properties);
+
+ values_.resize(1);
+ InitWideByteArrayValues(1, this->values_, this->buffer_, min_len, max_len);
+ writer->WriteBatch(1, nullptr, nullptr, this->values_.data());
+ writer->Close();
+
+ ASSERT_NO_THROW(this->ReadColumn());
+}
+
+TEST_F(TestByteArrayValuesWriter, LimitStats) {
+ int min_len = 1024 * 4;
+ int max_len = 1024 * 8;
+ this->SetUpSchema(Repetition::REQUIRED);
+ ColumnProperties column_properties;
+ column_properties.set_max_statistics_size(static_cast<size_t>(max_len));
+ auto writer = this->BuildWriter(SMALL_SIZE, column_properties);
+
+ values_.resize(SMALL_SIZE);
+ InitWideByteArrayValues(SMALL_SIZE, this->values_, this->buffer_, min_len, max_len);
+ writer->WriteBatch(SMALL_SIZE, nullptr, nullptr, this->values_.data());
+ writer->Close();
+
+ ASSERT_TRUE(this->metadata_is_stats_set());
+}
+
+TEST_F(TestByteArrayValuesWriter, CheckDefaultStats) {
+ this->SetUpSchema(Repetition::REQUIRED);
+ auto writer = this->BuildWriter();
+ this->GenerateData(SMALL_SIZE);
+
+ writer->WriteBatch(SMALL_SIZE, nullptr, nullptr, this->values_ptr_);
+ writer->Close();
+
+ ASSERT_TRUE(this->metadata_is_stats_set());
+}
+
+TEST(TestColumnWriter, RepeatedListsUpdateSpacedBug) {
+ // In ARROW-3930 we discovered a bug when writing from Arrow when we had data
+ // that looks like this:
+ //
+ // [null, [0, 1, null, 2, 3, 4, null]]
+
+ // Create schema
+ NodePtr item = schema::Int32("item"); // optional item
+ NodePtr list(GroupNode::Make("b", Repetition::REPEATED, {item}, ConvertedType::LIST));
+ NodePtr bag(GroupNode::Make("bag", Repetition::OPTIONAL, {list})); // optional list
+ std::vector<NodePtr> fields = {bag};
+ NodePtr root = GroupNode::Make("schema", Repetition::REPEATED, fields);
+
+ SchemaDescriptor schema;
+ schema.Init(root);
+
+ auto sink = CreateOutputStream();
+ auto props = WriterProperties::Builder().build();
+
+ auto metadata = ColumnChunkMetaDataBuilder::Make(props, schema.Column(0));
+ std::unique_ptr<PageWriter> pager =
+ PageWriter::Open(sink, Compression::UNCOMPRESSED,
+ Codec::UseDefaultCompressionLevel(), metadata.get());
+ std::shared_ptr<ColumnWriter> writer =
+ ColumnWriter::Make(metadata.get(), std::move(pager), props.get());
+ auto typed_writer = std::static_pointer_cast<TypedColumnWriter<Int32Type>>(writer);
+
+ std::vector<int16_t> def_levels = {1, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3};
+ std::vector<int16_t> rep_levels = {0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+ std::vector<int32_t> values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+
+ // Write the values into uninitialized memory
+ ASSERT_OK_AND_ASSIGN(auto values_buffer, ::arrow::AllocateBuffer(64));
+ memcpy(values_buffer->mutable_data(), values.data(), 13 * sizeof(int32_t));
+ auto values_data = reinterpret_cast<const int32_t*>(values_buffer->data());
+
+ std::shared_ptr<Buffer> valid_bits;
+ ASSERT_OK_AND_ASSIGN(valid_bits, ::arrow::internal::BytesToBits(
+ {1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1}));
+
+ // valgrind will warn about out of bounds access into def_levels_data
+ typed_writer->WriteBatchSpaced(14, def_levels.data(), rep_levels.data(),
+ valid_bits->data(), 0, values_data);
+ writer->Close();
+}
+
+void GenerateLevels(int min_repeat_factor, int max_repeat_factor, int max_level,
+ std::vector<int16_t>& input_levels) {
+ // for each repetition count up to max_repeat_factor
+ for (int repeat = min_repeat_factor; repeat <= max_repeat_factor; repeat++) {
+ // repeat count increases by a factor of 2 for every iteration
+ int repeat_count = (1 << repeat);
+ // generate levels for repetition count up to the maximum level
+ int16_t value = 0;
+ int bwidth = 0;
+ while (value <= max_level) {
+ for (int i = 0; i < repeat_count; i++) {
+ input_levels.push_back(value);
+ }
+ value = static_cast<int16_t>((2 << bwidth) - 1);
+ bwidth++;
+ }
+ }
+}
+
+void EncodeLevels(Encoding::type encoding, int16_t max_level, int num_levels,
+ const int16_t* input_levels, std::vector<uint8_t>& bytes) {
+ LevelEncoder encoder;
+ int levels_count = 0;
+ bytes.resize(2 * num_levels);
+ ASSERT_EQ(2 * num_levels, static_cast<int>(bytes.size()));
+ // encode levels
+ if (encoding == Encoding::RLE) {
+ // leave space to write the rle length value
+ encoder.Init(encoding, max_level, num_levels, bytes.data() + sizeof(int32_t),
+ static_cast<int>(bytes.size()));
+
+ levels_count = encoder.Encode(num_levels, input_levels);
+ (reinterpret_cast<int32_t*>(bytes.data()))[0] = encoder.len();
+ } else {
+ encoder.Init(encoding, max_level, num_levels, bytes.data(),
+ static_cast<int>(bytes.size()));
+ levels_count = encoder.Encode(num_levels, input_levels);
+ }
+ ASSERT_EQ(num_levels, levels_count);
+}
+
+void VerifyDecodingLevels(Encoding::type encoding, int16_t max_level,
+ std::vector<int16_t>& input_levels,
+ std::vector<uint8_t>& bytes) {
+ LevelDecoder decoder;
+ int levels_count = 0;
+ std::vector<int16_t> output_levels;
+ int num_levels = static_cast<int>(input_levels.size());
+
+ output_levels.resize(num_levels);
+ ASSERT_EQ(num_levels, static_cast<int>(output_levels.size()));
+
+ // Decode levels and test with multiple decode calls
+ decoder.SetData(encoding, max_level, num_levels, bytes.data(),
+ static_cast<int32_t>(bytes.size()));
+ int decode_count = 4;
+ int num_inner_levels = num_levels / decode_count;
+ // Try multiple decoding on a single SetData call
+ for (int ct = 0; ct < decode_count; ct++) {
+ int offset = ct * num_inner_levels;
+ levels_count = decoder.Decode(num_inner_levels, output_levels.data());
+ ASSERT_EQ(num_inner_levels, levels_count);
+ for (int i = 0; i < num_inner_levels; i++) {
+ EXPECT_EQ(input_levels[i + offset], output_levels[i]);
+ }
+ }
+ // check the remaining levels
+ int num_levels_completed = decode_count * (num_levels / decode_count);
+ int num_remaining_levels = num_levels - num_levels_completed;
+ if (num_remaining_levels > 0) {
+ levels_count = decoder.Decode(num_remaining_levels, output_levels.data());
+ ASSERT_EQ(num_remaining_levels, levels_count);
+ for (int i = 0; i < num_remaining_levels; i++) {
+ EXPECT_EQ(input_levels[i + num_levels_completed], output_levels[i]);
+ }
+ }
+ // Test zero Decode values
+ ASSERT_EQ(0, decoder.Decode(1, output_levels.data()));
+}
+
+void VerifyDecodingMultipleSetData(Encoding::type encoding, int16_t max_level,
+ std::vector<int16_t>& input_levels,
+ std::vector<std::vector<uint8_t>>& bytes) {
+ LevelDecoder decoder;
+ int levels_count = 0;
+ std::vector<int16_t> output_levels;
+
+ // Decode levels and test with multiple SetData calls
+ int setdata_count = static_cast<int>(bytes.size());
+ int num_levels = static_cast<int>(input_levels.size()) / setdata_count;
+ output_levels.resize(num_levels);
+ // Try multiple SetData
+ for (int ct = 0; ct < setdata_count; ct++) {
+ int offset = ct * num_levels;
+ ASSERT_EQ(num_levels, static_cast<int>(output_levels.size()));
+ decoder.SetData(encoding, max_level, num_levels, bytes[ct].data(),
+ static_cast<int32_t>(bytes[ct].size()));
+ levels_count = decoder.Decode(num_levels, output_levels.data());
+ ASSERT_EQ(num_levels, levels_count);
+ for (int i = 0; i < num_levels; i++) {
+ EXPECT_EQ(input_levels[i + offset], output_levels[i]);
+ }
+ }
+}
+
+// Test levels with maximum bit-width from 1 to 8
+// increase the repetition count for each iteration by a factor of 2
+TEST(TestLevels, TestLevelsDecodeMultipleBitWidth) {
+ int min_repeat_factor = 0;
+ int max_repeat_factor = 7; // 128
+ int max_bit_width = 8;
+ std::vector<int16_t> input_levels;
+ std::vector<uint8_t> bytes;
+ Encoding::type encodings[2] = {Encoding::RLE, Encoding::BIT_PACKED};
+
+ // for each encoding
+ for (int encode = 0; encode < 2; encode++) {
+ Encoding::type encoding = encodings[encode];
+ // BIT_PACKED requires a sequence of at least 8
+ if (encoding == Encoding::BIT_PACKED) min_repeat_factor = 3;
+ // for each maximum bit-width
+ for (int bit_width = 1; bit_width <= max_bit_width; bit_width++) {
+ // find the maximum level for the current bit_width
+ int16_t max_level = static_cast<int16_t>((1 << bit_width) - 1);
+ // Generate levels
+ GenerateLevels(min_repeat_factor, max_repeat_factor, max_level, input_levels);
+ ASSERT_NO_FATAL_FAILURE(EncodeLevels(encoding, max_level,
+ static_cast<int>(input_levels.size()),
+ input_levels.data(), bytes));
+ ASSERT_NO_FATAL_FAILURE(
+ VerifyDecodingLevels(encoding, max_level, input_levels, bytes));
+ input_levels.clear();
+ }
+ }
+}
+
+// Test multiple decoder SetData calls
+TEST(TestLevels, TestLevelsDecodeMultipleSetData) {
+ int min_repeat_factor = 3;
+ int max_repeat_factor = 7; // 128
+ int bit_width = 8;
+ int16_t max_level = static_cast<int16_t>((1 << bit_width) - 1);
+ std::vector<int16_t> input_levels;
+ std::vector<std::vector<uint8_t>> bytes;
+ Encoding::type encodings[2] = {Encoding::RLE, Encoding::BIT_PACKED};
+ GenerateLevels(min_repeat_factor, max_repeat_factor, max_level, input_levels);
+ int num_levels = static_cast<int>(input_levels.size());
+ int setdata_factor = 8;
+ int split_level_size = num_levels / setdata_factor;
+ bytes.resize(setdata_factor);
+
+ // for each encoding
+ for (int encode = 0; encode < 2; encode++) {
+ Encoding::type encoding = encodings[encode];
+ for (int rf = 0; rf < setdata_factor; rf++) {
+ int offset = rf * split_level_size;
+ ASSERT_NO_FATAL_FAILURE(EncodeLevels(
+ encoding, max_level, split_level_size,
+ reinterpret_cast<int16_t*>(input_levels.data()) + offset, bytes[rf]));
+ }
+ ASSERT_NO_FATAL_FAILURE(
+ VerifyDecodingMultipleSetData(encoding, max_level, input_levels, bytes));
+ }
+}
+
+TEST(TestLevelEncoder, MinimumBufferSize) {
+ // PARQUET-676, PARQUET-698
+ const int kNumToEncode = 1024;
+
+ std::vector<int16_t> levels;
+ for (int i = 0; i < kNumToEncode; ++i) {
+ if (i % 9 == 0) {
+ levels.push_back(0);
+ } else {
+ levels.push_back(1);
+ }
+ }
+
+ std::vector<uint8_t> output(
+ LevelEncoder::MaxBufferSize(Encoding::RLE, 1, kNumToEncode));
+
+ LevelEncoder encoder;
+ encoder.Init(Encoding::RLE, 1, kNumToEncode, output.data(),
+ static_cast<int>(output.size()));
+ int encode_count = encoder.Encode(kNumToEncode, levels.data());
+
+ ASSERT_EQ(kNumToEncode, encode_count);
+}
+
+TEST(TestLevelEncoder, MinimumBufferSize2) {
+ // PARQUET-708
+ // Test the worst case for bit_width=2 consisting of
+ // LiteralRun(size=8)
+ // RepeatedRun(size=8)
+ // LiteralRun(size=8)
+ // ...
+ const int kNumToEncode = 1024;
+
+ std::vector<int16_t> levels;
+ for (int i = 0; i < kNumToEncode; ++i) {
+ // This forces a literal run of 00000001
+ // followed by eight 1s
+ if ((i % 16) < 7) {
+ levels.push_back(0);
+ } else {
+ levels.push_back(1);
+ }
+ }
+
+ for (int16_t bit_width = 1; bit_width <= 8; bit_width++) {
+ std::vector<uint8_t> output(
+ LevelEncoder::MaxBufferSize(Encoding::RLE, bit_width, kNumToEncode));
+
+ LevelEncoder encoder;
+ encoder.Init(Encoding::RLE, bit_width, kNumToEncode, output.data(),
+ static_cast<int>(output.size()));
+ int encode_count = encoder.Encode(kNumToEncode, levels.data());
+
+ ASSERT_EQ(kNumToEncode, encode_count);
+ }
+}
+
+} // namespace test
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encoding.cc b/src/arrow/cpp/src/parquet/encoding.cc
new file mode 100644
index 000000000..2639c3dd4
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encoding.cc
@@ -0,0 +1,2597 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/encoding.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_dict.h"
+#include "arrow/stl_allocator.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bit_stream_utils.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/bitmap_writer.h"
+#include "arrow/util/byte_stream_split.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/hashing.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/rle_encoding.h"
+#include "arrow/util/ubsan.h"
+#include "arrow/visitor_inline.h"
+#include "parquet/exception.h"
+#include "parquet/platform.h"
+#include "parquet/schema.h"
+#include "parquet/types.h"
+
+namespace BitUtil = arrow::BitUtil;
+
+using arrow::Status;
+using arrow::VisitNullBitmapInline;
+using arrow::internal::checked_cast;
+
+template <typename T>
+using ArrowPoolVector = std::vector<T, ::arrow::stl::allocator<T>>;
+
+namespace parquet {
+namespace {
+
+constexpr int64_t kInMemoryDefaultCapacity = 1024;
+// The Parquet spec isn't very clear whether ByteArray lengths are signed or
+// unsigned, but the Java implementation uses signed ints.
+constexpr size_t kMaxByteArraySize = std::numeric_limits<int32_t>::max();
+
+class EncoderImpl : virtual public Encoder {
+ public:
+ EncoderImpl(const ColumnDescriptor* descr, Encoding::type encoding, MemoryPool* pool)
+ : descr_(descr),
+ encoding_(encoding),
+ pool_(pool),
+ type_length_(descr ? descr->type_length() : -1) {}
+
+ Encoding::type encoding() const override { return encoding_; }
+
+ MemoryPool* memory_pool() const override { return pool_; }
+
+ protected:
+ // For accessing type-specific metadata, like FIXED_LEN_BYTE_ARRAY
+ const ColumnDescriptor* descr_;
+ const Encoding::type encoding_;
+ MemoryPool* pool_;
+
+ /// Type length from descr
+ int type_length_;
+};
+
+// ----------------------------------------------------------------------
+// Plain encoder implementation
+
+template <typename DType>
+class PlainEncoder : public EncoderImpl, virtual public TypedEncoder<DType> {
+ public:
+ using T = typename DType::c_type;
+
+ explicit PlainEncoder(const ColumnDescriptor* descr, MemoryPool* pool)
+ : EncoderImpl(descr, Encoding::PLAIN, pool), sink_(pool) {}
+
+ int64_t EstimatedDataEncodedSize() override { return sink_.length(); }
+
+ std::shared_ptr<Buffer> FlushValues() override {
+ std::shared_ptr<Buffer> buffer;
+ PARQUET_THROW_NOT_OK(sink_.Finish(&buffer));
+ return buffer;
+ }
+
+ using TypedEncoder<DType>::Put;
+
+ void Put(const T* buffer, int num_values) override;
+
+ void Put(const ::arrow::Array& values) override;
+
+ void PutSpaced(const T* src, int num_values, const uint8_t* valid_bits,
+ int64_t valid_bits_offset) override {
+ if (valid_bits != NULLPTR) {
+ PARQUET_ASSIGN_OR_THROW(auto buffer, ::arrow::AllocateBuffer(num_values * sizeof(T),
+ this->memory_pool()));
+ T* data = reinterpret_cast<T*>(buffer->mutable_data());
+ int num_valid_values = ::arrow::util::internal::SpacedCompress<T>(
+ src, num_values, valid_bits, valid_bits_offset, data);
+ Put(data, num_valid_values);
+ } else {
+ Put(src, num_values);
+ }
+ }
+
+ void UnsafePutByteArray(const void* data, uint32_t length) {
+ DCHECK(length == 0 || data != nullptr) << "Value ptr cannot be NULL";
+ sink_.UnsafeAppend(&length, sizeof(uint32_t));
+ sink_.UnsafeAppend(data, static_cast<int64_t>(length));
+ }
+
+ void Put(const ByteArray& val) {
+ // Write the result to the output stream
+ const int64_t increment = static_cast<int64_t>(val.len + sizeof(uint32_t));
+ if (ARROW_PREDICT_FALSE(sink_.length() + increment > sink_.capacity())) {
+ PARQUET_THROW_NOT_OK(sink_.Reserve(increment));
+ }
+ UnsafePutByteArray(val.ptr, val.len);
+ }
+
+ protected:
+ template <typename ArrayType>
+ void PutBinaryArray(const ArrayType& array) {
+ const int64_t total_bytes =
+ array.value_offset(array.length()) - array.value_offset(0);
+ PARQUET_THROW_NOT_OK(sink_.Reserve(total_bytes + array.length() * sizeof(uint32_t)));
+
+ PARQUET_THROW_NOT_OK(::arrow::VisitArrayDataInline<typename ArrayType::TypeClass>(
+ *array.data(),
+ [&](::arrow::util::string_view view) {
+ if (ARROW_PREDICT_FALSE(view.size() > kMaxByteArraySize)) {
+ return Status::Invalid("Parquet cannot store strings with size 2GB or more");
+ }
+ UnsafePutByteArray(view.data(), static_cast<uint32_t>(view.size()));
+ return Status::OK();
+ },
+ []() { return Status::OK(); }));
+ }
+
+ ::arrow::BufferBuilder sink_;
+};
+
+template <typename DType>
+void PlainEncoder<DType>::Put(const T* buffer, int num_values) {
+ if (num_values > 0) {
+ PARQUET_THROW_NOT_OK(sink_.Append(buffer, num_values * sizeof(T)));
+ }
+}
+
+template <>
+inline void PlainEncoder<ByteArrayType>::Put(const ByteArray* src, int num_values) {
+ for (int i = 0; i < num_values; ++i) {
+ Put(src[i]);
+ }
+}
+
+template <typename ArrayType>
+void DirectPutImpl(const ::arrow::Array& values, ::arrow::BufferBuilder* sink) {
+ if (values.type_id() != ArrayType::TypeClass::type_id) {
+ std::string type_name = ArrayType::TypeClass::type_name();
+ throw ParquetException("direct put to " + type_name + " from " +
+ values.type()->ToString() + " not supported");
+ }
+
+ using value_type = typename ArrayType::value_type;
+ constexpr auto value_size = sizeof(value_type);
+ auto raw_values = checked_cast<const ArrayType&>(values).raw_values();
+
+ if (values.null_count() == 0) {
+ // no nulls, just dump the data
+ PARQUET_THROW_NOT_OK(sink->Append(raw_values, values.length() * value_size));
+ } else {
+ PARQUET_THROW_NOT_OK(
+ sink->Reserve((values.length() - values.null_count()) * value_size));
+
+ for (int64_t i = 0; i < values.length(); i++) {
+ if (values.IsValid(i)) {
+ sink->UnsafeAppend(&raw_values[i], value_size);
+ }
+ }
+ }
+}
+
+template <>
+void PlainEncoder<Int32Type>::Put(const ::arrow::Array& values) {
+ DirectPutImpl<::arrow::Int32Array>(values, &sink_);
+}
+
+template <>
+void PlainEncoder<Int64Type>::Put(const ::arrow::Array& values) {
+ DirectPutImpl<::arrow::Int64Array>(values, &sink_);
+}
+
+template <>
+void PlainEncoder<Int96Type>::Put(const ::arrow::Array& values) {
+ ParquetException::NYI("direct put to Int96");
+}
+
+template <>
+void PlainEncoder<FloatType>::Put(const ::arrow::Array& values) {
+ DirectPutImpl<::arrow::FloatArray>(values, &sink_);
+}
+
+template <>
+void PlainEncoder<DoubleType>::Put(const ::arrow::Array& values) {
+ DirectPutImpl<::arrow::DoubleArray>(values, &sink_);
+}
+
+template <typename DType>
+void PlainEncoder<DType>::Put(const ::arrow::Array& values) {
+ ParquetException::NYI("direct put of " + values.type()->ToString());
+}
+
+void AssertBaseBinary(const ::arrow::Array& values) {
+ if (!::arrow::is_base_binary_like(values.type_id())) {
+ throw ParquetException("Only BaseBinaryArray and subclasses supported");
+ }
+}
+
+template <>
+inline void PlainEncoder<ByteArrayType>::Put(const ::arrow::Array& values) {
+ AssertBaseBinary(values);
+
+ if (::arrow::is_binary_like(values.type_id())) {
+ PutBinaryArray(checked_cast<const ::arrow::BinaryArray&>(values));
+ } else {
+ DCHECK(::arrow::is_large_binary_like(values.type_id()));
+ PutBinaryArray(checked_cast<const ::arrow::LargeBinaryArray&>(values));
+ }
+}
+
+void AssertFixedSizeBinary(const ::arrow::Array& values, int type_length) {
+ if (values.type_id() != ::arrow::Type::FIXED_SIZE_BINARY &&
+ values.type_id() != ::arrow::Type::DECIMAL) {
+ throw ParquetException("Only FixedSizeBinaryArray and subclasses supported");
+ }
+ if (checked_cast<const ::arrow::FixedSizeBinaryType&>(*values.type()).byte_width() !=
+ type_length) {
+ throw ParquetException("Size mismatch: " + values.type()->ToString() +
+ " should have been " + std::to_string(type_length) + " wide");
+ }
+}
+
+template <>
+inline void PlainEncoder<FLBAType>::Put(const ::arrow::Array& values) {
+ AssertFixedSizeBinary(values, descr_->type_length());
+ const auto& data = checked_cast<const ::arrow::FixedSizeBinaryArray&>(values);
+
+ if (data.null_count() == 0) {
+ // no nulls, just dump the data
+ PARQUET_THROW_NOT_OK(
+ sink_.Append(data.raw_values(), data.length() * data.byte_width()));
+ } else {
+ const int64_t total_bytes =
+ data.length() * data.byte_width() - data.null_count() * data.byte_width();
+ PARQUET_THROW_NOT_OK(sink_.Reserve(total_bytes));
+ for (int64_t i = 0; i < data.length(); i++) {
+ if (data.IsValid(i)) {
+ sink_.UnsafeAppend(data.Value(i), data.byte_width());
+ }
+ }
+ }
+}
+
+template <>
+inline void PlainEncoder<FLBAType>::Put(const FixedLenByteArray* src, int num_values) {
+ if (descr_->type_length() == 0) {
+ return;
+ }
+ for (int i = 0; i < num_values; ++i) {
+ // Write the result to the output stream
+ DCHECK(src[i].ptr != nullptr) << "Value ptr cannot be NULL";
+ PARQUET_THROW_NOT_OK(sink_.Append(src[i].ptr, descr_->type_length()));
+ }
+}
+
+template <>
+class PlainEncoder<BooleanType> : public EncoderImpl, virtual public BooleanEncoder {
+ public:
+ explicit PlainEncoder(const ColumnDescriptor* descr, MemoryPool* pool)
+ : EncoderImpl(descr, Encoding::PLAIN, pool),
+ bits_available_(kInMemoryDefaultCapacity * 8),
+ bits_buffer_(AllocateBuffer(pool, kInMemoryDefaultCapacity)),
+ sink_(pool),
+ bit_writer_(bits_buffer_->mutable_data(),
+ static_cast<int>(bits_buffer_->size())) {}
+
+ int64_t EstimatedDataEncodedSize() override;
+ std::shared_ptr<Buffer> FlushValues() override;
+
+ void Put(const bool* src, int num_values) override;
+
+ void Put(const std::vector<bool>& src, int num_values) override;
+
+ void PutSpaced(const bool* src, int num_values, const uint8_t* valid_bits,
+ int64_t valid_bits_offset) override {
+ if (valid_bits != NULLPTR) {
+ PARQUET_ASSIGN_OR_THROW(auto buffer, ::arrow::AllocateBuffer(num_values * sizeof(T),
+ this->memory_pool()));
+ T* data = reinterpret_cast<T*>(buffer->mutable_data());
+ int num_valid_values = ::arrow::util::internal::SpacedCompress<T>(
+ src, num_values, valid_bits, valid_bits_offset, data);
+ Put(data, num_valid_values);
+ } else {
+ Put(src, num_values);
+ }
+ }
+
+ void Put(const ::arrow::Array& values) override {
+ if (values.type_id() != ::arrow::Type::BOOL) {
+ throw ParquetException("direct put to boolean from " + values.type()->ToString() +
+ " not supported");
+ }
+
+ const auto& data = checked_cast<const ::arrow::BooleanArray&>(values);
+ if (data.null_count() == 0) {
+ PARQUET_THROW_NOT_OK(sink_.Reserve(BitUtil::BytesForBits(data.length())));
+ // no nulls, just dump the data
+ ::arrow::internal::CopyBitmap(data.data()->GetValues<uint8_t>(1), data.offset(),
+ data.length(), sink_.mutable_data(), sink_.length());
+ } else {
+ auto n_valid = BitUtil::BytesForBits(data.length() - data.null_count());
+ PARQUET_THROW_NOT_OK(sink_.Reserve(n_valid));
+ ::arrow::internal::FirstTimeBitmapWriter writer(sink_.mutable_data(),
+ sink_.length(), n_valid);
+
+ for (int64_t i = 0; i < data.length(); i++) {
+ if (data.IsValid(i)) {
+ if (data.Value(i)) {
+ writer.Set();
+ } else {
+ writer.Clear();
+ }
+ writer.Next();
+ }
+ }
+ writer.Finish();
+ }
+ sink_.UnsafeAdvance(data.length());
+ }
+
+ private:
+ int bits_available_;
+ std::shared_ptr<ResizableBuffer> bits_buffer_;
+ ::arrow::BufferBuilder sink_;
+ ::arrow::BitUtil::BitWriter bit_writer_;
+
+ template <typename SequenceType>
+ void PutImpl(const SequenceType& src, int num_values);
+};
+
+template <typename SequenceType>
+void PlainEncoder<BooleanType>::PutImpl(const SequenceType& src, int num_values) {
+ int bit_offset = 0;
+ if (bits_available_ > 0) {
+ int bits_to_write = std::min(bits_available_, num_values);
+ for (int i = 0; i < bits_to_write; i++) {
+ bit_writer_.PutValue(src[i], 1);
+ }
+ bits_available_ -= bits_to_write;
+ bit_offset = bits_to_write;
+
+ if (bits_available_ == 0) {
+ bit_writer_.Flush();
+ PARQUET_THROW_NOT_OK(
+ sink_.Append(bit_writer_.buffer(), bit_writer_.bytes_written()));
+ bit_writer_.Clear();
+ }
+ }
+
+ int bits_remaining = num_values - bit_offset;
+ while (bit_offset < num_values) {
+ bits_available_ = static_cast<int>(bits_buffer_->size()) * 8;
+
+ int bits_to_write = std::min(bits_available_, bits_remaining);
+ for (int i = bit_offset; i < bit_offset + bits_to_write; i++) {
+ bit_writer_.PutValue(src[i], 1);
+ }
+ bit_offset += bits_to_write;
+ bits_available_ -= bits_to_write;
+ bits_remaining -= bits_to_write;
+
+ if (bits_available_ == 0) {
+ bit_writer_.Flush();
+ PARQUET_THROW_NOT_OK(
+ sink_.Append(bit_writer_.buffer(), bit_writer_.bytes_written()));
+ bit_writer_.Clear();
+ }
+ }
+}
+
+int64_t PlainEncoder<BooleanType>::EstimatedDataEncodedSize() {
+ int64_t position = sink_.length();
+ return position + bit_writer_.bytes_written();
+}
+
+std::shared_ptr<Buffer> PlainEncoder<BooleanType>::FlushValues() {
+ if (bits_available_ > 0) {
+ bit_writer_.Flush();
+ PARQUET_THROW_NOT_OK(sink_.Append(bit_writer_.buffer(), bit_writer_.bytes_written()));
+ bit_writer_.Clear();
+ bits_available_ = static_cast<int>(bits_buffer_->size()) * 8;
+ }
+
+ std::shared_ptr<Buffer> buffer;
+ PARQUET_THROW_NOT_OK(sink_.Finish(&buffer));
+ return buffer;
+}
+
+void PlainEncoder<BooleanType>::Put(const bool* src, int num_values) {
+ PutImpl(src, num_values);
+}
+
+void PlainEncoder<BooleanType>::Put(const std::vector<bool>& src, int num_values) {
+ PutImpl(src, num_values);
+}
+
+// ----------------------------------------------------------------------
+// DictEncoder<T> implementations
+
+template <typename DType>
+struct DictEncoderTraits {
+ using c_type = typename DType::c_type;
+ using MemoTableType = ::arrow::internal::ScalarMemoTable<c_type>;
+};
+
+template <>
+struct DictEncoderTraits<ByteArrayType> {
+ using MemoTableType = ::arrow::internal::BinaryMemoTable<::arrow::BinaryBuilder>;
+};
+
+template <>
+struct DictEncoderTraits<FLBAType> {
+ using MemoTableType = ::arrow::internal::BinaryMemoTable<::arrow::BinaryBuilder>;
+};
+
+// Initially 1024 elements
+static constexpr int32_t kInitialHashTableSize = 1 << 10;
+
+/// See the dictionary encoding section of
+/// https://github.com/Parquet/parquet-format. The encoding supports
+/// streaming encoding. Values are encoded as they are added while the
+/// dictionary is being constructed. At any time, the buffered values
+/// can be written out with the current dictionary size. More values
+/// can then be added to the encoder, including new dictionary
+/// entries.
+template <typename DType>
+class DictEncoderImpl : public EncoderImpl, virtual public DictEncoder<DType> {
+ using MemoTableType = typename DictEncoderTraits<DType>::MemoTableType;
+
+ public:
+ typedef typename DType::c_type T;
+
+ explicit DictEncoderImpl(const ColumnDescriptor* desc, MemoryPool* pool)
+ : EncoderImpl(desc, Encoding::PLAIN_DICTIONARY, pool),
+ buffered_indices_(::arrow::stl::allocator<int32_t>(pool)),
+ dict_encoded_size_(0),
+ memo_table_(pool, kInitialHashTableSize) {}
+
+ ~DictEncoderImpl() override { DCHECK(buffered_indices_.empty()); }
+
+ int dict_encoded_size() override { return dict_encoded_size_; }
+
+ int WriteIndices(uint8_t* buffer, int buffer_len) override {
+ // Write bit width in first byte
+ *buffer = static_cast<uint8_t>(bit_width());
+ ++buffer;
+ --buffer_len;
+
+ ::arrow::util::RleEncoder encoder(buffer, buffer_len, bit_width());
+
+ for (int32_t index : buffered_indices_) {
+ if (!encoder.Put(index)) return -1;
+ }
+ encoder.Flush();
+
+ ClearIndices();
+ return 1 + encoder.len();
+ }
+
+ void set_type_length(int type_length) { this->type_length_ = type_length; }
+
+ /// Returns a conservative estimate of the number of bytes needed to encode the buffered
+ /// indices. Used to size the buffer passed to WriteIndices().
+ int64_t EstimatedDataEncodedSize() override {
+ // Note: because of the way RleEncoder::CheckBufferFull() is called, we have to
+ // reserve
+ // an extra "RleEncoder::MinBufferSize" bytes. These extra bytes won't be used
+ // but not reserving them would cause the encoder to fail.
+ return 1 +
+ ::arrow::util::RleEncoder::MaxBufferSize(
+ bit_width(), static_cast<int>(buffered_indices_.size())) +
+ ::arrow::util::RleEncoder::MinBufferSize(bit_width());
+ }
+
+ /// The minimum bit width required to encode the currently buffered indices.
+ int bit_width() const override {
+ if (ARROW_PREDICT_FALSE(num_entries() == 0)) return 0;
+ if (ARROW_PREDICT_FALSE(num_entries() == 1)) return 1;
+ return BitUtil::Log2(num_entries());
+ }
+
+ /// Encode value. Note that this does not actually write any data, just
+ /// buffers the value's index to be written later.
+ inline void Put(const T& value);
+
+ // Not implemented for other data types
+ inline void PutByteArray(const void* ptr, int32_t length);
+
+ void Put(const T* src, int num_values) override {
+ for (int32_t i = 0; i < num_values; i++) {
+ Put(src[i]);
+ }
+ }
+
+ void PutSpaced(const T* src, int num_values, const uint8_t* valid_bits,
+ int64_t valid_bits_offset) override {
+ ::arrow::internal::VisitSetBitRunsVoid(valid_bits, valid_bits_offset, num_values,
+ [&](int64_t position, int64_t length) {
+ for (int64_t i = 0; i < length; i++) {
+ Put(src[i + position]);
+ }
+ });
+ }
+
+ using TypedEncoder<DType>::Put;
+
+ void Put(const ::arrow::Array& values) override;
+ void PutDictionary(const ::arrow::Array& values) override;
+
+ template <typename ArrowType, typename T = typename ArrowType::c_type>
+ void PutIndicesTyped(const ::arrow::Array& data) {
+ auto values = data.data()->GetValues<T>(1);
+ size_t buffer_position = buffered_indices_.size();
+ buffered_indices_.resize(buffer_position +
+ static_cast<size_t>(data.length() - data.null_count()));
+ ::arrow::internal::VisitSetBitRunsVoid(
+ data.null_bitmap_data(), data.offset(), data.length(),
+ [&](int64_t position, int64_t length) {
+ for (int64_t i = 0; i < length; ++i) {
+ buffered_indices_[buffer_position++] =
+ static_cast<int32_t>(values[i + position]);
+ }
+ });
+ }
+
+ void PutIndices(const ::arrow::Array& data) override {
+ switch (data.type()->id()) {
+ case ::arrow::Type::UINT8:
+ case ::arrow::Type::INT8:
+ return PutIndicesTyped<::arrow::UInt8Type>(data);
+ case ::arrow::Type::UINT16:
+ case ::arrow::Type::INT16:
+ return PutIndicesTyped<::arrow::UInt16Type>(data);
+ case ::arrow::Type::UINT32:
+ case ::arrow::Type::INT32:
+ return PutIndicesTyped<::arrow::UInt32Type>(data);
+ case ::arrow::Type::UINT64:
+ case ::arrow::Type::INT64:
+ return PutIndicesTyped<::arrow::UInt64Type>(data);
+ default:
+ throw ParquetException("Passed non-integer array to PutIndices");
+ }
+ }
+
+ std::shared_ptr<Buffer> FlushValues() override {
+ std::shared_ptr<ResizableBuffer> buffer =
+ AllocateBuffer(this->pool_, EstimatedDataEncodedSize());
+ int result_size = WriteIndices(buffer->mutable_data(),
+ static_cast<int>(EstimatedDataEncodedSize()));
+ PARQUET_THROW_NOT_OK(buffer->Resize(result_size, false));
+ return std::move(buffer);
+ }
+
+ /// Writes out the encoded dictionary to buffer. buffer must be preallocated to
+ /// dict_encoded_size() bytes.
+ void WriteDict(uint8_t* buffer) override;
+
+ /// The number of entries in the dictionary.
+ int num_entries() const override { return memo_table_.size(); }
+
+ private:
+ /// Clears all the indices (but leaves the dictionary).
+ void ClearIndices() { buffered_indices_.clear(); }
+
+ /// Indices that have not yet be written out by WriteIndices().
+ ArrowPoolVector<int32_t> buffered_indices_;
+
+ template <typename ArrayType>
+ void PutBinaryArray(const ArrayType& array) {
+ PARQUET_THROW_NOT_OK(::arrow::VisitArrayDataInline<typename ArrayType::TypeClass>(
+ *array.data(),
+ [&](::arrow::util::string_view view) {
+ if (ARROW_PREDICT_FALSE(view.size() > kMaxByteArraySize)) {
+ return Status::Invalid("Parquet cannot store strings with size 2GB or more");
+ }
+ PutByteArray(view.data(), static_cast<uint32_t>(view.size()));
+ return Status::OK();
+ },
+ []() { return Status::OK(); }));
+ }
+
+ template <typename ArrayType>
+ void PutBinaryDictionaryArray(const ArrayType& array) {
+ DCHECK_EQ(array.null_count(), 0);
+ for (int64_t i = 0; i < array.length(); i++) {
+ auto v = array.GetView(i);
+ if (ARROW_PREDICT_FALSE(v.size() > kMaxByteArraySize)) {
+ throw ParquetException("Parquet cannot store strings with size 2GB or more");
+ }
+ dict_encoded_size_ += static_cast<int>(v.size() + sizeof(uint32_t));
+ int32_t unused_memo_index;
+ PARQUET_THROW_NOT_OK(memo_table_.GetOrInsert(
+ v.data(), static_cast<int32_t>(v.size()), &unused_memo_index));
+ }
+ }
+
+ /// The number of bytes needed to encode the dictionary.
+ int dict_encoded_size_;
+
+ MemoTableType memo_table_;
+};
+
+template <typename DType>
+void DictEncoderImpl<DType>::WriteDict(uint8_t* buffer) {
+ // For primitive types, only a memcpy
+ DCHECK_EQ(static_cast<size_t>(dict_encoded_size_), sizeof(T) * memo_table_.size());
+ memo_table_.CopyValues(0 /* start_pos */, reinterpret_cast<T*>(buffer));
+}
+
+// ByteArray and FLBA already have the dictionary encoded in their data heaps
+template <>
+void DictEncoderImpl<ByteArrayType>::WriteDict(uint8_t* buffer) {
+ memo_table_.VisitValues(0, [&buffer](const ::arrow::util::string_view& v) {
+ uint32_t len = static_cast<uint32_t>(v.length());
+ memcpy(buffer, &len, sizeof(len));
+ buffer += sizeof(len);
+ memcpy(buffer, v.data(), len);
+ buffer += len;
+ });
+}
+
+template <>
+void DictEncoderImpl<FLBAType>::WriteDict(uint8_t* buffer) {
+ memo_table_.VisitValues(0, [&](const ::arrow::util::string_view& v) {
+ DCHECK_EQ(v.length(), static_cast<size_t>(type_length_));
+ memcpy(buffer, v.data(), type_length_);
+ buffer += type_length_;
+ });
+}
+
+template <typename DType>
+inline void DictEncoderImpl<DType>::Put(const T& v) {
+ // Put() implementation for primitive types
+ auto on_found = [](int32_t memo_index) {};
+ auto on_not_found = [this](int32_t memo_index) {
+ dict_encoded_size_ += static_cast<int>(sizeof(T));
+ };
+
+ int32_t memo_index;
+ PARQUET_THROW_NOT_OK(memo_table_.GetOrInsert(v, on_found, on_not_found, &memo_index));
+ buffered_indices_.push_back(memo_index);
+}
+
+template <typename DType>
+inline void DictEncoderImpl<DType>::PutByteArray(const void* ptr, int32_t length) {
+ DCHECK(false);
+}
+
+template <>
+inline void DictEncoderImpl<ByteArrayType>::PutByteArray(const void* ptr,
+ int32_t length) {
+ static const uint8_t empty[] = {0};
+
+ auto on_found = [](int32_t memo_index) {};
+ auto on_not_found = [&](int32_t memo_index) {
+ dict_encoded_size_ += static_cast<int>(length + sizeof(uint32_t));
+ };
+
+ DCHECK(ptr != nullptr || length == 0);
+ ptr = (ptr != nullptr) ? ptr : empty;
+ int32_t memo_index;
+ PARQUET_THROW_NOT_OK(
+ memo_table_.GetOrInsert(ptr, length, on_found, on_not_found, &memo_index));
+ buffered_indices_.push_back(memo_index);
+}
+
+template <>
+inline void DictEncoderImpl<ByteArrayType>::Put(const ByteArray& val) {
+ return PutByteArray(val.ptr, static_cast<int32_t>(val.len));
+}
+
+template <>
+inline void DictEncoderImpl<FLBAType>::Put(const FixedLenByteArray& v) {
+ static const uint8_t empty[] = {0};
+
+ auto on_found = [](int32_t memo_index) {};
+ auto on_not_found = [this](int32_t memo_index) { dict_encoded_size_ += type_length_; };
+
+ DCHECK(v.ptr != nullptr || type_length_ == 0);
+ const void* ptr = (v.ptr != nullptr) ? v.ptr : empty;
+ int32_t memo_index;
+ PARQUET_THROW_NOT_OK(
+ memo_table_.GetOrInsert(ptr, type_length_, on_found, on_not_found, &memo_index));
+ buffered_indices_.push_back(memo_index);
+}
+
+template <>
+void DictEncoderImpl<Int96Type>::Put(const ::arrow::Array& values) {
+ ParquetException::NYI("Direct put to Int96");
+}
+
+template <>
+void DictEncoderImpl<Int96Type>::PutDictionary(const ::arrow::Array& values) {
+ ParquetException::NYI("Direct put to Int96");
+}
+
+template <typename DType>
+void DictEncoderImpl<DType>::Put(const ::arrow::Array& values) {
+ using ArrayType = typename ::arrow::CTypeTraits<typename DType::c_type>::ArrayType;
+ const auto& data = checked_cast<const ArrayType&>(values);
+ if (data.null_count() == 0) {
+ // no nulls, just dump the data
+ for (int64_t i = 0; i < data.length(); i++) {
+ Put(data.Value(i));
+ }
+ } else {
+ for (int64_t i = 0; i < data.length(); i++) {
+ if (data.IsValid(i)) {
+ Put(data.Value(i));
+ }
+ }
+ }
+}
+
+template <>
+void DictEncoderImpl<FLBAType>::Put(const ::arrow::Array& values) {
+ AssertFixedSizeBinary(values, type_length_);
+ const auto& data = checked_cast<const ::arrow::FixedSizeBinaryArray&>(values);
+ if (data.null_count() == 0) {
+ // no nulls, just dump the data
+ for (int64_t i = 0; i < data.length(); i++) {
+ Put(FixedLenByteArray(data.Value(i)));
+ }
+ } else {
+ std::vector<uint8_t> empty(type_length_, 0);
+ for (int64_t i = 0; i < data.length(); i++) {
+ if (data.IsValid(i)) {
+ Put(FixedLenByteArray(data.Value(i)));
+ }
+ }
+ }
+}
+
+template <>
+void DictEncoderImpl<ByteArrayType>::Put(const ::arrow::Array& values) {
+ AssertBaseBinary(values);
+ if (::arrow::is_binary_like(values.type_id())) {
+ PutBinaryArray(checked_cast<const ::arrow::BinaryArray&>(values));
+ } else {
+ DCHECK(::arrow::is_large_binary_like(values.type_id()));
+ PutBinaryArray(checked_cast<const ::arrow::LargeBinaryArray&>(values));
+ }
+}
+
+template <typename DType>
+void AssertCanPutDictionary(DictEncoderImpl<DType>* encoder, const ::arrow::Array& dict) {
+ if (dict.null_count() > 0) {
+ throw ParquetException("Inserted dictionary cannot cannot contain nulls");
+ }
+
+ if (encoder->num_entries() > 0) {
+ throw ParquetException("Can only call PutDictionary on an empty DictEncoder");
+ }
+}
+
+template <typename DType>
+void DictEncoderImpl<DType>::PutDictionary(const ::arrow::Array& values) {
+ AssertCanPutDictionary(this, values);
+
+ using ArrayType = typename ::arrow::CTypeTraits<typename DType::c_type>::ArrayType;
+ const auto& data = checked_cast<const ArrayType&>(values);
+
+ dict_encoded_size_ += static_cast<int>(sizeof(typename DType::c_type) * data.length());
+ for (int64_t i = 0; i < data.length(); i++) {
+ int32_t unused_memo_index;
+ PARQUET_THROW_NOT_OK(memo_table_.GetOrInsert(data.Value(i), &unused_memo_index));
+ }
+}
+
+template <>
+void DictEncoderImpl<FLBAType>::PutDictionary(const ::arrow::Array& values) {
+ AssertFixedSizeBinary(values, type_length_);
+ AssertCanPutDictionary(this, values);
+
+ const auto& data = checked_cast<const ::arrow::FixedSizeBinaryArray&>(values);
+
+ dict_encoded_size_ += static_cast<int>(type_length_ * data.length());
+ for (int64_t i = 0; i < data.length(); i++) {
+ int32_t unused_memo_index;
+ PARQUET_THROW_NOT_OK(
+ memo_table_.GetOrInsert(data.Value(i), type_length_, &unused_memo_index));
+ }
+}
+
+template <>
+void DictEncoderImpl<ByteArrayType>::PutDictionary(const ::arrow::Array& values) {
+ AssertBaseBinary(values);
+ AssertCanPutDictionary(this, values);
+
+ if (::arrow::is_binary_like(values.type_id())) {
+ PutBinaryDictionaryArray(checked_cast<const ::arrow::BinaryArray&>(values));
+ } else {
+ DCHECK(::arrow::is_large_binary_like(values.type_id()));
+ PutBinaryDictionaryArray(checked_cast<const ::arrow::LargeBinaryArray&>(values));
+ }
+}
+
+// ----------------------------------------------------------------------
+// ByteStreamSplitEncoder<T> implementations
+
+template <typename DType>
+class ByteStreamSplitEncoder : public EncoderImpl, virtual public TypedEncoder<DType> {
+ public:
+ using T = typename DType::c_type;
+ using TypedEncoder<DType>::Put;
+
+ explicit ByteStreamSplitEncoder(
+ const ColumnDescriptor* descr,
+ ::arrow::MemoryPool* pool = ::arrow::default_memory_pool());
+
+ int64_t EstimatedDataEncodedSize() override;
+ std::shared_ptr<Buffer> FlushValues() override;
+
+ void Put(const T* buffer, int num_values) override;
+ void Put(const ::arrow::Array& values) override;
+ void PutSpaced(const T* src, int num_values, const uint8_t* valid_bits,
+ int64_t valid_bits_offset) override;
+
+ protected:
+ template <typename ArrowType>
+ void PutImpl(const ::arrow::Array& values) {
+ if (values.type_id() != ArrowType::type_id) {
+ throw ParquetException(std::string() + "direct put to " + ArrowType::type_name() +
+ " from " + values.type()->ToString() + " not supported");
+ }
+ const auto& data = *values.data();
+ PutSpaced(data.GetValues<typename ArrowType::c_type>(1),
+ static_cast<int>(data.length), data.GetValues<uint8_t>(0, 0), data.offset);
+ }
+
+ ::arrow::BufferBuilder sink_;
+ int64_t num_values_in_buffer_;
+};
+
+template <typename DType>
+ByteStreamSplitEncoder<DType>::ByteStreamSplitEncoder(const ColumnDescriptor* descr,
+ ::arrow::MemoryPool* pool)
+ : EncoderImpl(descr, Encoding::BYTE_STREAM_SPLIT, pool),
+ sink_{pool},
+ num_values_in_buffer_{0} {}
+
+template <typename DType>
+int64_t ByteStreamSplitEncoder<DType>::EstimatedDataEncodedSize() {
+ return sink_.length();
+}
+
+template <typename DType>
+std::shared_ptr<Buffer> ByteStreamSplitEncoder<DType>::FlushValues() {
+ std::shared_ptr<ResizableBuffer> output_buffer =
+ AllocateBuffer(this->memory_pool(), EstimatedDataEncodedSize());
+ uint8_t* output_buffer_raw = output_buffer->mutable_data();
+ const uint8_t* raw_values = sink_.data();
+ ::arrow::util::internal::ByteStreamSplitEncode<T>(raw_values, num_values_in_buffer_,
+ output_buffer_raw);
+ sink_.Reset();
+ num_values_in_buffer_ = 0;
+ return std::move(output_buffer);
+}
+
+template <typename DType>
+void ByteStreamSplitEncoder<DType>::Put(const T* buffer, int num_values) {
+ if (num_values > 0) {
+ PARQUET_THROW_NOT_OK(sink_.Append(buffer, num_values * sizeof(T)));
+ num_values_in_buffer_ += num_values;
+ }
+}
+
+template <>
+void ByteStreamSplitEncoder<FloatType>::Put(const ::arrow::Array& values) {
+ PutImpl<::arrow::FloatType>(values);
+}
+
+template <>
+void ByteStreamSplitEncoder<DoubleType>::Put(const ::arrow::Array& values) {
+ PutImpl<::arrow::DoubleType>(values);
+}
+
+template <typename DType>
+void ByteStreamSplitEncoder<DType>::PutSpaced(const T* src, int num_values,
+ const uint8_t* valid_bits,
+ int64_t valid_bits_offset) {
+ if (valid_bits != NULLPTR) {
+ PARQUET_ASSIGN_OR_THROW(auto buffer, ::arrow::AllocateBuffer(num_values * sizeof(T),
+ this->memory_pool()));
+ T* data = reinterpret_cast<T*>(buffer->mutable_data());
+ int num_valid_values = ::arrow::util::internal::SpacedCompress<T>(
+ src, num_values, valid_bits, valid_bits_offset, data);
+ Put(data, num_valid_values);
+ } else {
+ Put(src, num_values);
+ }
+}
+
+class DecoderImpl : virtual public Decoder {
+ public:
+ void SetData(int num_values, const uint8_t* data, int len) override {
+ num_values_ = num_values;
+ data_ = data;
+ len_ = len;
+ }
+
+ int values_left() const override { return num_values_; }
+ Encoding::type encoding() const override { return encoding_; }
+
+ protected:
+ explicit DecoderImpl(const ColumnDescriptor* descr, Encoding::type encoding)
+ : descr_(descr), encoding_(encoding), num_values_(0), data_(NULLPTR), len_(0) {}
+
+ // For accessing type-specific metadata, like FIXED_LEN_BYTE_ARRAY
+ const ColumnDescriptor* descr_;
+
+ const Encoding::type encoding_;
+ int num_values_;
+ const uint8_t* data_;
+ int len_;
+ int type_length_;
+};
+
+template <typename DType>
+class PlainDecoder : public DecoderImpl, virtual public TypedDecoder<DType> {
+ public:
+ using T = typename DType::c_type;
+ explicit PlainDecoder(const ColumnDescriptor* descr);
+
+ int Decode(T* buffer, int max_values) override;
+
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<DType>::Accumulator* builder) override;
+
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<DType>::DictAccumulator* builder) override;
+};
+
+template <>
+inline int PlainDecoder<Int96Type>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<Int96Type>::Accumulator* builder) {
+ ParquetException::NYI("DecodeArrow not supported for Int96");
+}
+
+template <>
+inline int PlainDecoder<Int96Type>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<Int96Type>::DictAccumulator* builder) {
+ ParquetException::NYI("DecodeArrow not supported for Int96");
+}
+
+template <>
+inline int PlainDecoder<BooleanType>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<BooleanType>::DictAccumulator* builder) {
+ ParquetException::NYI("dictionaries of BooleanType");
+}
+
+template <typename DType>
+int PlainDecoder<DType>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<DType>::Accumulator* builder) {
+ using value_type = typename DType::c_type;
+
+ constexpr int value_size = static_cast<int>(sizeof(value_type));
+ int values_decoded = num_values - null_count;
+ if (ARROW_PREDICT_FALSE(len_ < value_size * values_decoded)) {
+ ParquetException::EofException();
+ }
+
+ PARQUET_THROW_NOT_OK(builder->Reserve(num_values));
+
+ VisitNullBitmapInline(
+ valid_bits, valid_bits_offset, num_values, null_count,
+ [&]() {
+ builder->UnsafeAppend(::arrow::util::SafeLoadAs<value_type>(data_));
+ data_ += sizeof(value_type);
+ },
+ [&]() { builder->UnsafeAppendNull(); });
+
+ num_values_ -= values_decoded;
+ len_ -= sizeof(value_type) * values_decoded;
+ return values_decoded;
+}
+
+template <typename DType>
+int PlainDecoder<DType>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<DType>::DictAccumulator* builder) {
+ using value_type = typename DType::c_type;
+
+ constexpr int value_size = static_cast<int>(sizeof(value_type));
+ int values_decoded = num_values - null_count;
+ if (ARROW_PREDICT_FALSE(len_ < value_size * values_decoded)) {
+ ParquetException::EofException();
+ }
+
+ PARQUET_THROW_NOT_OK(builder->Reserve(num_values));
+
+ VisitNullBitmapInline(
+ valid_bits, valid_bits_offset, num_values, null_count,
+ [&]() {
+ PARQUET_THROW_NOT_OK(
+ builder->Append(::arrow::util::SafeLoadAs<value_type>(data_)));
+ data_ += sizeof(value_type);
+ },
+ [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); });
+
+ num_values_ -= values_decoded;
+ len_ -= sizeof(value_type) * values_decoded;
+ return values_decoded;
+}
+
+// Decode routine templated on C++ type rather than type enum
+template <typename T>
+inline int DecodePlain(const uint8_t* data, int64_t data_size, int num_values,
+ int type_length, T* out) {
+ int64_t bytes_to_decode = num_values * static_cast<int64_t>(sizeof(T));
+ if (bytes_to_decode > data_size || bytes_to_decode > INT_MAX) {
+ ParquetException::EofException();
+ }
+ // If bytes_to_decode == 0, data could be null
+ if (bytes_to_decode > 0) {
+ memcpy(out, data, bytes_to_decode);
+ }
+ return static_cast<int>(bytes_to_decode);
+}
+
+template <typename DType>
+PlainDecoder<DType>::PlainDecoder(const ColumnDescriptor* descr)
+ : DecoderImpl(descr, Encoding::PLAIN) {
+ if (descr_ && descr_->physical_type() == Type::FIXED_LEN_BYTE_ARRAY) {
+ type_length_ = descr_->type_length();
+ } else {
+ type_length_ = -1;
+ }
+}
+
+// Template specialization for BYTE_ARRAY. The written values do not own their
+// own data.
+
+static inline int64_t ReadByteArray(const uint8_t* data, int64_t data_size,
+ ByteArray* out) {
+ if (ARROW_PREDICT_FALSE(data_size < 4)) {
+ ParquetException::EofException();
+ }
+ const int32_t len = ::arrow::util::SafeLoadAs<int32_t>(data);
+ if (len < 0) {
+ throw ParquetException("Invalid BYTE_ARRAY value");
+ }
+ const int64_t consumed_length = static_cast<int64_t>(len) + 4;
+ if (ARROW_PREDICT_FALSE(data_size < consumed_length)) {
+ ParquetException::EofException();
+ }
+ *out = ByteArray{static_cast<uint32_t>(len), data + 4};
+ return consumed_length;
+}
+
+template <>
+inline int DecodePlain<ByteArray>(const uint8_t* data, int64_t data_size, int num_values,
+ int type_length, ByteArray* out) {
+ int bytes_decoded = 0;
+ for (int i = 0; i < num_values; ++i) {
+ const auto increment = ReadByteArray(data, data_size, out + i);
+ if (ARROW_PREDICT_FALSE(increment > INT_MAX - bytes_decoded)) {
+ throw ParquetException("BYTE_ARRAY chunk too large");
+ }
+ data += increment;
+ data_size -= increment;
+ bytes_decoded += static_cast<int>(increment);
+ }
+ return bytes_decoded;
+}
+
+// Template specialization for FIXED_LEN_BYTE_ARRAY. The written values do not
+// own their own data.
+template <>
+inline int DecodePlain<FixedLenByteArray>(const uint8_t* data, int64_t data_size,
+ int num_values, int type_length,
+ FixedLenByteArray* out) {
+ int64_t bytes_to_decode = static_cast<int64_t>(type_length) * num_values;
+ if (bytes_to_decode > data_size || bytes_to_decode > INT_MAX) {
+ ParquetException::EofException();
+ }
+ for (int i = 0; i < num_values; ++i) {
+ out[i].ptr = data;
+ data += type_length;
+ data_size -= type_length;
+ }
+ return static_cast<int>(bytes_to_decode);
+}
+
+template <typename DType>
+int PlainDecoder<DType>::Decode(T* buffer, int max_values) {
+ max_values = std::min(max_values, num_values_);
+ int bytes_consumed = DecodePlain<T>(data_, len_, max_values, type_length_, buffer);
+ data_ += bytes_consumed;
+ len_ -= bytes_consumed;
+ num_values_ -= max_values;
+ return max_values;
+}
+
+class PlainBooleanDecoder : public DecoderImpl,
+ virtual public TypedDecoder<BooleanType>,
+ virtual public BooleanDecoder {
+ public:
+ explicit PlainBooleanDecoder(const ColumnDescriptor* descr);
+ void SetData(int num_values, const uint8_t* data, int len) override;
+
+ // Two flavors of bool decoding
+ int Decode(uint8_t* buffer, int max_values) override;
+ int Decode(bool* buffer, int max_values) override;
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<BooleanType>::Accumulator* out) override;
+
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<BooleanType>::DictAccumulator* out) override;
+
+ private:
+ std::unique_ptr<::arrow::BitUtil::BitReader> bit_reader_;
+};
+
+PlainBooleanDecoder::PlainBooleanDecoder(const ColumnDescriptor* descr)
+ : DecoderImpl(descr, Encoding::PLAIN) {}
+
+void PlainBooleanDecoder::SetData(int num_values, const uint8_t* data, int len) {
+ num_values_ = num_values;
+ bit_reader_.reset(new BitUtil::BitReader(data, len));
+}
+
+int PlainBooleanDecoder::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<BooleanType>::Accumulator* builder) {
+ int values_decoded = num_values - null_count;
+ if (ARROW_PREDICT_FALSE(num_values_ < values_decoded)) {
+ ParquetException::EofException();
+ }
+
+ PARQUET_THROW_NOT_OK(builder->Reserve(num_values));
+
+ VisitNullBitmapInline(
+ valid_bits, valid_bits_offset, num_values, null_count,
+ [&]() {
+ bool value;
+ ARROW_IGNORE_EXPR(bit_reader_->GetValue(1, &value));
+ builder->UnsafeAppend(value);
+ },
+ [&]() { builder->UnsafeAppendNull(); });
+
+ num_values_ -= values_decoded;
+ return values_decoded;
+}
+
+inline int PlainBooleanDecoder::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<BooleanType>::DictAccumulator* builder) {
+ ParquetException::NYI("dictionaries of BooleanType");
+}
+
+int PlainBooleanDecoder::Decode(uint8_t* buffer, int max_values) {
+ max_values = std::min(max_values, num_values_);
+ bool val;
+ ::arrow::internal::BitmapWriter bit_writer(buffer, 0, max_values);
+ for (int i = 0; i < max_values; ++i) {
+ if (!bit_reader_->GetValue(1, &val)) {
+ ParquetException::EofException();
+ }
+ if (val) {
+ bit_writer.Set();
+ }
+ bit_writer.Next();
+ }
+ bit_writer.Finish();
+ num_values_ -= max_values;
+ return max_values;
+}
+
+int PlainBooleanDecoder::Decode(bool* buffer, int max_values) {
+ max_values = std::min(max_values, num_values_);
+ if (bit_reader_->GetBatch(1, buffer, max_values) != max_values) {
+ ParquetException::EofException();
+ }
+ num_values_ -= max_values;
+ return max_values;
+}
+
+struct ArrowBinaryHelper {
+ explicit ArrowBinaryHelper(typename EncodingTraits<ByteArrayType>::Accumulator* out) {
+ this->out = out;
+ this->builder = out->builder.get();
+ this->chunk_space_remaining =
+ ::arrow::kBinaryMemoryLimit - this->builder->value_data_length();
+ }
+
+ Status PushChunk() {
+ std::shared_ptr<::arrow::Array> result;
+ RETURN_NOT_OK(builder->Finish(&result));
+ out->chunks.push_back(result);
+ chunk_space_remaining = ::arrow::kBinaryMemoryLimit;
+ return Status::OK();
+ }
+
+ bool CanFit(int64_t length) const { return length <= chunk_space_remaining; }
+
+ void UnsafeAppend(const uint8_t* data, int32_t length) {
+ chunk_space_remaining -= length;
+ builder->UnsafeAppend(data, length);
+ }
+
+ void UnsafeAppendNull() { builder->UnsafeAppendNull(); }
+
+ Status Append(const uint8_t* data, int32_t length) {
+ chunk_space_remaining -= length;
+ return builder->Append(data, length);
+ }
+
+ Status AppendNull() { return builder->AppendNull(); }
+
+ typename EncodingTraits<ByteArrayType>::Accumulator* out;
+ ::arrow::BinaryBuilder* builder;
+ int64_t chunk_space_remaining;
+};
+
+template <>
+inline int PlainDecoder<ByteArrayType>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<ByteArrayType>::Accumulator* builder) {
+ ParquetException::NYI();
+}
+
+template <>
+inline int PlainDecoder<ByteArrayType>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<ByteArrayType>::DictAccumulator* builder) {
+ ParquetException::NYI();
+}
+
+template <>
+inline int PlainDecoder<FLBAType>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<FLBAType>::Accumulator* builder) {
+ int values_decoded = num_values - null_count;
+ if (ARROW_PREDICT_FALSE(len_ < descr_->type_length() * values_decoded)) {
+ ParquetException::EofException();
+ }
+
+ PARQUET_THROW_NOT_OK(builder->Reserve(num_values));
+
+ VisitNullBitmapInline(
+ valid_bits, valid_bits_offset, num_values, null_count,
+ [&]() {
+ builder->UnsafeAppend(data_);
+ data_ += descr_->type_length();
+ },
+ [&]() { builder->UnsafeAppendNull(); });
+
+ num_values_ -= values_decoded;
+ len_ -= descr_->type_length() * values_decoded;
+ return values_decoded;
+}
+
+template <>
+inline int PlainDecoder<FLBAType>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<FLBAType>::DictAccumulator* builder) {
+ int values_decoded = num_values - null_count;
+ if (ARROW_PREDICT_FALSE(len_ < descr_->type_length() * values_decoded)) {
+ ParquetException::EofException();
+ }
+
+ PARQUET_THROW_NOT_OK(builder->Reserve(num_values));
+
+ VisitNullBitmapInline(
+ valid_bits, valid_bits_offset, num_values, null_count,
+ [&]() {
+ PARQUET_THROW_NOT_OK(builder->Append(data_));
+ data_ += descr_->type_length();
+ },
+ [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); });
+
+ num_values_ -= values_decoded;
+ len_ -= descr_->type_length() * values_decoded;
+ return values_decoded;
+}
+
+class PlainByteArrayDecoder : public PlainDecoder<ByteArrayType>,
+ virtual public ByteArrayDecoder {
+ public:
+ using Base = PlainDecoder<ByteArrayType>;
+ using Base::DecodeSpaced;
+ using Base::PlainDecoder;
+
+ // ----------------------------------------------------------------------
+ // Dictionary read paths
+
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ ::arrow::BinaryDictionary32Builder* builder) override {
+ int result = 0;
+ PARQUET_THROW_NOT_OK(DecodeArrow(num_values, null_count, valid_bits,
+ valid_bits_offset, builder, &result));
+ return result;
+ }
+
+ // ----------------------------------------------------------------------
+ // Optimized dense binary read paths
+
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<ByteArrayType>::Accumulator* out) override {
+ int result = 0;
+ PARQUET_THROW_NOT_OK(DecodeArrowDense(num_values, null_count, valid_bits,
+ valid_bits_offset, out, &result));
+ return result;
+ }
+
+ private:
+ Status DecodeArrowDense(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<ByteArrayType>::Accumulator* out,
+ int* out_values_decoded) {
+ ArrowBinaryHelper helper(out);
+ int values_decoded = 0;
+
+ RETURN_NOT_OK(helper.builder->Reserve(num_values));
+ RETURN_NOT_OK(helper.builder->ReserveData(
+ std::min<int64_t>(len_, helper.chunk_space_remaining)));
+
+ int i = 0;
+ RETURN_NOT_OK(VisitNullBitmapInline(
+ valid_bits, valid_bits_offset, num_values, null_count,
+ [&]() {
+ if (ARROW_PREDICT_FALSE(len_ < 4)) {
+ ParquetException::EofException();
+ }
+ auto value_len = ::arrow::util::SafeLoadAs<int32_t>(data_);
+ if (ARROW_PREDICT_FALSE(value_len < 0 || value_len > INT32_MAX - 4)) {
+ return Status::Invalid("Invalid or corrupted value_len '", value_len, "'");
+ }
+ auto increment = value_len + 4;
+ if (ARROW_PREDICT_FALSE(len_ < increment)) {
+ ParquetException::EofException();
+ }
+ if (ARROW_PREDICT_FALSE(!helper.CanFit(value_len))) {
+ // This element would exceed the capacity of a chunk
+ RETURN_NOT_OK(helper.PushChunk());
+ RETURN_NOT_OK(helper.builder->Reserve(num_values - i));
+ RETURN_NOT_OK(helper.builder->ReserveData(
+ std::min<int64_t>(len_, helper.chunk_space_remaining)));
+ }
+ helper.UnsafeAppend(data_ + 4, value_len);
+ data_ += increment;
+ len_ -= increment;
+ ++values_decoded;
+ ++i;
+ return Status::OK();
+ },
+ [&]() {
+ helper.UnsafeAppendNull();
+ ++i;
+ return Status::OK();
+ }));
+
+ num_values_ -= values_decoded;
+ *out_values_decoded = values_decoded;
+ return Status::OK();
+ }
+
+ template <typename BuilderType>
+ Status DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset, BuilderType* builder,
+ int* out_values_decoded) {
+ RETURN_NOT_OK(builder->Reserve(num_values));
+ int values_decoded = 0;
+
+ RETURN_NOT_OK(VisitNullBitmapInline(
+ valid_bits, valid_bits_offset, num_values, null_count,
+ [&]() {
+ if (ARROW_PREDICT_FALSE(len_ < 4)) {
+ ParquetException::EofException();
+ }
+ auto value_len = ::arrow::util::SafeLoadAs<int32_t>(data_);
+ if (ARROW_PREDICT_FALSE(value_len < 0 || value_len > INT32_MAX - 4)) {
+ return Status::Invalid("Invalid or corrupted value_len '", value_len, "'");
+ }
+ auto increment = value_len + 4;
+ if (ARROW_PREDICT_FALSE(len_ < increment)) {
+ ParquetException::EofException();
+ }
+ RETURN_NOT_OK(builder->Append(data_ + 4, value_len));
+ data_ += increment;
+ len_ -= increment;
+ ++values_decoded;
+ return Status::OK();
+ },
+ [&]() { return builder->AppendNull(); }));
+
+ num_values_ -= values_decoded;
+ *out_values_decoded = values_decoded;
+ return Status::OK();
+ }
+};
+
+class PlainFLBADecoder : public PlainDecoder<FLBAType>, virtual public FLBADecoder {
+ public:
+ using Base = PlainDecoder<FLBAType>;
+ using Base::PlainDecoder;
+};
+
+// ----------------------------------------------------------------------
+// Dictionary encoding and decoding
+
+template <typename Type>
+class DictDecoderImpl : public DecoderImpl, virtual public DictDecoder<Type> {
+ public:
+ typedef typename Type::c_type T;
+
+ // Initializes the dictionary with values from 'dictionary'. The data in
+ // dictionary is not guaranteed to persist in memory after this call so the
+ // dictionary decoder needs to copy the data out if necessary.
+ explicit DictDecoderImpl(const ColumnDescriptor* descr,
+ MemoryPool* pool = ::arrow::default_memory_pool())
+ : DecoderImpl(descr, Encoding::RLE_DICTIONARY),
+ dictionary_(AllocateBuffer(pool, 0)),
+ dictionary_length_(0),
+ byte_array_data_(AllocateBuffer(pool, 0)),
+ byte_array_offsets_(AllocateBuffer(pool, 0)),
+ indices_scratch_space_(AllocateBuffer(pool, 0)) {}
+
+ // Perform type-specific initiatialization
+ void SetDict(TypedDecoder<Type>* dictionary) override;
+
+ void SetData(int num_values, const uint8_t* data, int len) override {
+ num_values_ = num_values;
+ if (len == 0) {
+ // Initialize dummy decoder to avoid crashes later on
+ idx_decoder_ = ::arrow::util::RleDecoder(data, len, /*bit_width=*/1);
+ return;
+ }
+ uint8_t bit_width = *data;
+ if (ARROW_PREDICT_FALSE(bit_width >= 64)) {
+ throw ParquetException("Invalid or corrupted bit_width");
+ }
+ idx_decoder_ = ::arrow::util::RleDecoder(++data, --len, bit_width);
+ }
+
+ int Decode(T* buffer, int num_values) override {
+ num_values = std::min(num_values, num_values_);
+ int decoded_values =
+ idx_decoder_.GetBatchWithDict(reinterpret_cast<const T*>(dictionary_->data()),
+ dictionary_length_, buffer, num_values);
+ if (decoded_values != num_values) {
+ ParquetException::EofException();
+ }
+ num_values_ -= num_values;
+ return num_values;
+ }
+
+ int DecodeSpaced(T* buffer, int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset) override {
+ num_values = std::min(num_values, num_values_);
+ if (num_values != idx_decoder_.GetBatchWithDictSpaced(
+ reinterpret_cast<const T*>(dictionary_->data()),
+ dictionary_length_, buffer, num_values, null_count, valid_bits,
+ valid_bits_offset)) {
+ ParquetException::EofException();
+ }
+ num_values_ -= num_values;
+ return num_values;
+ }
+
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<Type>::Accumulator* out) override;
+
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<Type>::DictAccumulator* out) override;
+
+ void InsertDictionary(::arrow::ArrayBuilder* builder) override;
+
+ int DecodeIndicesSpaced(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ ::arrow::ArrayBuilder* builder) override {
+ if (num_values > 0) {
+ // TODO(wesm): Refactor to batch reads for improved memory use. It is not
+ // trivial because the null_count is relative to the entire bitmap
+ PARQUET_THROW_NOT_OK(indices_scratch_space_->TypedResize<int32_t>(
+ num_values, /*shrink_to_fit=*/false));
+ }
+
+ auto indices_buffer =
+ reinterpret_cast<int32_t*>(indices_scratch_space_->mutable_data());
+
+ if (num_values != idx_decoder_.GetBatchSpaced(num_values, null_count, valid_bits,
+ valid_bits_offset, indices_buffer)) {
+ ParquetException::EofException();
+ }
+
+ /// XXX(wesm): Cannot append "valid bits" directly to the builder
+ std::vector<uint8_t> valid_bytes(num_values);
+ ::arrow::internal::BitmapReader bit_reader(valid_bits, valid_bits_offset, num_values);
+ for (int64_t i = 0; i < num_values; ++i) {
+ valid_bytes[i] = static_cast<uint8_t>(bit_reader.IsSet());
+ bit_reader.Next();
+ }
+
+ auto binary_builder = checked_cast<::arrow::BinaryDictionary32Builder*>(builder);
+ PARQUET_THROW_NOT_OK(
+ binary_builder->AppendIndices(indices_buffer, num_values, valid_bytes.data()));
+ num_values_ -= num_values - null_count;
+ return num_values - null_count;
+ }
+
+ int DecodeIndices(int num_values, ::arrow::ArrayBuilder* builder) override {
+ num_values = std::min(num_values, num_values_);
+ if (num_values > 0) {
+ // TODO(wesm): Refactor to batch reads for improved memory use. This is
+ // relatively simple here because we don't have to do any bookkeeping of
+ // nulls
+ PARQUET_THROW_NOT_OK(indices_scratch_space_->TypedResize<int32_t>(
+ num_values, /*shrink_to_fit=*/false));
+ }
+ auto indices_buffer =
+ reinterpret_cast<int32_t*>(indices_scratch_space_->mutable_data());
+ if (num_values != idx_decoder_.GetBatch(indices_buffer, num_values)) {
+ ParquetException::EofException();
+ }
+ auto binary_builder = checked_cast<::arrow::BinaryDictionary32Builder*>(builder);
+ PARQUET_THROW_NOT_OK(binary_builder->AppendIndices(indices_buffer, num_values));
+ num_values_ -= num_values;
+ return num_values;
+ }
+
+ int DecodeIndices(int num_values, int32_t* indices) override {
+ if (num_values != idx_decoder_.GetBatch(indices, num_values)) {
+ ParquetException::EofException();
+ }
+ num_values_ -= num_values;
+ return num_values;
+ }
+
+ void GetDictionary(const T** dictionary, int32_t* dictionary_length) override {
+ *dictionary_length = dictionary_length_;
+ *dictionary = reinterpret_cast<T*>(dictionary_->mutable_data());
+ }
+
+ protected:
+ Status IndexInBounds(int32_t index) {
+ if (ARROW_PREDICT_TRUE(0 <= index && index < dictionary_length_)) {
+ return Status::OK();
+ }
+ return Status::Invalid("Index not in dictionary bounds");
+ }
+
+ inline void DecodeDict(TypedDecoder<Type>* dictionary) {
+ dictionary_length_ = static_cast<int32_t>(dictionary->values_left());
+ PARQUET_THROW_NOT_OK(dictionary_->Resize(dictionary_length_ * sizeof(T),
+ /*shrink_to_fit=*/false));
+ dictionary->Decode(reinterpret_cast<T*>(dictionary_->mutable_data()),
+ dictionary_length_);
+ }
+
+ // Only one is set.
+ std::shared_ptr<ResizableBuffer> dictionary_;
+
+ int32_t dictionary_length_;
+
+ // Data that contains the byte array data (byte_array_dictionary_ just has the
+ // pointers).
+ std::shared_ptr<ResizableBuffer> byte_array_data_;
+
+ // Arrow-style byte offsets for each dictionary value. We maintain two
+ // representations of the dictionary, one as ByteArray* for non-Arrow
+ // consumers and this one for Arrow consumers. Since dictionaries are
+ // generally pretty small to begin with this doesn't mean too much extra
+ // memory use in most cases
+ std::shared_ptr<ResizableBuffer> byte_array_offsets_;
+
+ // Reusable buffer for decoding dictionary indices to be appended to a
+ // BinaryDictionary32Builder
+ std::shared_ptr<ResizableBuffer> indices_scratch_space_;
+
+ ::arrow::util::RleDecoder idx_decoder_;
+};
+
+template <typename Type>
+void DictDecoderImpl<Type>::SetDict(TypedDecoder<Type>* dictionary) {
+ DecodeDict(dictionary);
+}
+
+template <>
+void DictDecoderImpl<BooleanType>::SetDict(TypedDecoder<BooleanType>* dictionary) {
+ ParquetException::NYI("Dictionary encoding is not implemented for boolean values");
+}
+
+template <>
+void DictDecoderImpl<ByteArrayType>::SetDict(TypedDecoder<ByteArrayType>* dictionary) {
+ DecodeDict(dictionary);
+
+ auto dict_values = reinterpret_cast<ByteArray*>(dictionary_->mutable_data());
+
+ int total_size = 0;
+ for (int i = 0; i < dictionary_length_; ++i) {
+ total_size += dict_values[i].len;
+ }
+ PARQUET_THROW_NOT_OK(byte_array_data_->Resize(total_size,
+ /*shrink_to_fit=*/false));
+ PARQUET_THROW_NOT_OK(
+ byte_array_offsets_->Resize((dictionary_length_ + 1) * sizeof(int32_t),
+ /*shrink_to_fit=*/false));
+
+ int32_t offset = 0;
+ uint8_t* bytes_data = byte_array_data_->mutable_data();
+ int32_t* bytes_offsets =
+ reinterpret_cast<int32_t*>(byte_array_offsets_->mutable_data());
+ for (int i = 0; i < dictionary_length_; ++i) {
+ memcpy(bytes_data + offset, dict_values[i].ptr, dict_values[i].len);
+ bytes_offsets[i] = offset;
+ dict_values[i].ptr = bytes_data + offset;
+ offset += dict_values[i].len;
+ }
+ bytes_offsets[dictionary_length_] = offset;
+}
+
+template <>
+inline void DictDecoderImpl<FLBAType>::SetDict(TypedDecoder<FLBAType>* dictionary) {
+ DecodeDict(dictionary);
+
+ auto dict_values = reinterpret_cast<FLBA*>(dictionary_->mutable_data());
+
+ int fixed_len = descr_->type_length();
+ int total_size = dictionary_length_ * fixed_len;
+
+ PARQUET_THROW_NOT_OK(byte_array_data_->Resize(total_size,
+ /*shrink_to_fit=*/false));
+ uint8_t* bytes_data = byte_array_data_->mutable_data();
+ for (int32_t i = 0, offset = 0; i < dictionary_length_; ++i, offset += fixed_len) {
+ memcpy(bytes_data + offset, dict_values[i].ptr, fixed_len);
+ dict_values[i].ptr = bytes_data + offset;
+ }
+}
+
+template <>
+inline int DictDecoderImpl<Int96Type>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<Int96Type>::Accumulator* builder) {
+ ParquetException::NYI("DecodeArrow to Int96Type");
+}
+
+template <>
+inline int DictDecoderImpl<Int96Type>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<Int96Type>::DictAccumulator* builder) {
+ ParquetException::NYI("DecodeArrow to Int96Type");
+}
+
+template <>
+inline int DictDecoderImpl<ByteArrayType>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<ByteArrayType>::Accumulator* builder) {
+ ParquetException::NYI("DecodeArrow implemented elsewhere");
+}
+
+template <>
+inline int DictDecoderImpl<ByteArrayType>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<ByteArrayType>::DictAccumulator* builder) {
+ ParquetException::NYI("DecodeArrow implemented elsewhere");
+}
+
+template <typename DType>
+int DictDecoderImpl<DType>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<DType>::DictAccumulator* builder) {
+ PARQUET_THROW_NOT_OK(builder->Reserve(num_values));
+
+ auto dict_values = reinterpret_cast<const typename DType::c_type*>(dictionary_->data());
+
+ VisitNullBitmapInline(
+ valid_bits, valid_bits_offset, num_values, null_count,
+ [&]() {
+ int32_t index;
+ if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) {
+ throw ParquetException("");
+ }
+ PARQUET_THROW_NOT_OK(IndexInBounds(index));
+ PARQUET_THROW_NOT_OK(builder->Append(dict_values[index]));
+ },
+ [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); });
+
+ return num_values - null_count;
+}
+
+template <>
+int DictDecoderImpl<BooleanType>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<BooleanType>::DictAccumulator* builder) {
+ ParquetException::NYI("No dictionary encoding for BooleanType");
+}
+
+template <>
+inline int DictDecoderImpl<FLBAType>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<FLBAType>::Accumulator* builder) {
+ if (builder->byte_width() != descr_->type_length()) {
+ throw ParquetException("Byte width mismatch: builder was " +
+ std::to_string(builder->byte_width()) + " but decoder was " +
+ std::to_string(descr_->type_length()));
+ }
+
+ PARQUET_THROW_NOT_OK(builder->Reserve(num_values));
+
+ auto dict_values = reinterpret_cast<const FLBA*>(dictionary_->data());
+
+ VisitNullBitmapInline(
+ valid_bits, valid_bits_offset, num_values, null_count,
+ [&]() {
+ int32_t index;
+ if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) {
+ throw ParquetException("");
+ }
+ PARQUET_THROW_NOT_OK(IndexInBounds(index));
+ builder->UnsafeAppend(dict_values[index].ptr);
+ },
+ [&]() { builder->UnsafeAppendNull(); });
+
+ return num_values - null_count;
+}
+
+template <>
+int DictDecoderImpl<FLBAType>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<FLBAType>::DictAccumulator* builder) {
+ auto value_type =
+ checked_cast<const ::arrow::DictionaryType&>(*builder->type()).value_type();
+ auto byte_width =
+ checked_cast<const ::arrow::FixedSizeBinaryType&>(*value_type).byte_width();
+ if (byte_width != descr_->type_length()) {
+ throw ParquetException("Byte width mismatch: builder was " +
+ std::to_string(byte_width) + " but decoder was " +
+ std::to_string(descr_->type_length()));
+ }
+
+ PARQUET_THROW_NOT_OK(builder->Reserve(num_values));
+
+ auto dict_values = reinterpret_cast<const FLBA*>(dictionary_->data());
+
+ VisitNullBitmapInline(
+ valid_bits, valid_bits_offset, num_values, null_count,
+ [&]() {
+ int32_t index;
+ if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) {
+ throw ParquetException("");
+ }
+ PARQUET_THROW_NOT_OK(IndexInBounds(index));
+ PARQUET_THROW_NOT_OK(builder->Append(dict_values[index].ptr));
+ },
+ [&]() { PARQUET_THROW_NOT_OK(builder->AppendNull()); });
+
+ return num_values - null_count;
+}
+
+template <typename Type>
+int DictDecoderImpl<Type>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<Type>::Accumulator* builder) {
+ PARQUET_THROW_NOT_OK(builder->Reserve(num_values));
+
+ using value_type = typename Type::c_type;
+ auto dict_values = reinterpret_cast<const value_type*>(dictionary_->data());
+
+ VisitNullBitmapInline(
+ valid_bits, valid_bits_offset, num_values, null_count,
+ [&]() {
+ int32_t index;
+ if (ARROW_PREDICT_FALSE(!idx_decoder_.Get(&index))) {
+ throw ParquetException("");
+ }
+ PARQUET_THROW_NOT_OK(IndexInBounds(index));
+ builder->UnsafeAppend(dict_values[index]);
+ },
+ [&]() { builder->UnsafeAppendNull(); });
+
+ return num_values - null_count;
+}
+
+template <typename Type>
+void DictDecoderImpl<Type>::InsertDictionary(::arrow::ArrayBuilder* builder) {
+ ParquetException::NYI("InsertDictionary only implemented for BYTE_ARRAY types");
+}
+
+template <>
+void DictDecoderImpl<ByteArrayType>::InsertDictionary(::arrow::ArrayBuilder* builder) {
+ auto binary_builder = checked_cast<::arrow::BinaryDictionary32Builder*>(builder);
+
+ // Make a BinaryArray referencing the internal dictionary data
+ auto arr = std::make_shared<::arrow::BinaryArray>(
+ dictionary_length_, byte_array_offsets_, byte_array_data_);
+ PARQUET_THROW_NOT_OK(binary_builder->InsertMemoValues(*arr));
+}
+
+class DictByteArrayDecoderImpl : public DictDecoderImpl<ByteArrayType>,
+ virtual public ByteArrayDecoder {
+ public:
+ using BASE = DictDecoderImpl<ByteArrayType>;
+ using BASE::DictDecoderImpl;
+
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ ::arrow::BinaryDictionary32Builder* builder) override {
+ int result = 0;
+ if (null_count == 0) {
+ PARQUET_THROW_NOT_OK(DecodeArrowNonNull(num_values, builder, &result));
+ } else {
+ PARQUET_THROW_NOT_OK(DecodeArrow(num_values, null_count, valid_bits,
+ valid_bits_offset, builder, &result));
+ }
+ return result;
+ }
+
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<ByteArrayType>::Accumulator* out) override {
+ int result = 0;
+ if (null_count == 0) {
+ PARQUET_THROW_NOT_OK(DecodeArrowDenseNonNull(num_values, out, &result));
+ } else {
+ PARQUET_THROW_NOT_OK(DecodeArrowDense(num_values, null_count, valid_bits,
+ valid_bits_offset, out, &result));
+ }
+ return result;
+ }
+
+ private:
+ Status DecodeArrowDense(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<ByteArrayType>::Accumulator* out,
+ int* out_num_values) {
+ constexpr int32_t kBufferSize = 1024;
+ int32_t indices[kBufferSize];
+
+ ArrowBinaryHelper helper(out);
+
+ ::arrow::internal::BitmapReader bit_reader(valid_bits, valid_bits_offset, num_values);
+
+ auto dict_values = reinterpret_cast<const ByteArray*>(dictionary_->data());
+ int values_decoded = 0;
+ int num_appended = 0;
+ while (num_appended < num_values) {
+ bool is_valid = bit_reader.IsSet();
+ bit_reader.Next();
+
+ if (is_valid) {
+ int32_t batch_size =
+ std::min<int32_t>(kBufferSize, num_values - num_appended - null_count);
+ int num_indices = idx_decoder_.GetBatch(indices, batch_size);
+
+ if (ARROW_PREDICT_FALSE(num_indices < 1)) {
+ return Status::Invalid("Invalid number of indices '", num_indices, "'");
+ }
+
+ int i = 0;
+ while (true) {
+ // Consume all indices
+ if (is_valid) {
+ auto idx = indices[i];
+ RETURN_NOT_OK(IndexInBounds(idx));
+ const auto& val = dict_values[idx];
+ if (ARROW_PREDICT_FALSE(!helper.CanFit(val.len))) {
+ RETURN_NOT_OK(helper.PushChunk());
+ }
+ RETURN_NOT_OK(helper.Append(val.ptr, static_cast<int32_t>(val.len)));
+ ++i;
+ ++values_decoded;
+ } else {
+ RETURN_NOT_OK(helper.AppendNull());
+ --null_count;
+ }
+ ++num_appended;
+ if (i == num_indices) {
+ // Do not advance the bit_reader if we have fulfilled the decode
+ // request
+ break;
+ }
+ is_valid = bit_reader.IsSet();
+ bit_reader.Next();
+ }
+ } else {
+ RETURN_NOT_OK(helper.AppendNull());
+ --null_count;
+ ++num_appended;
+ }
+ }
+ *out_num_values = values_decoded;
+ return Status::OK();
+ }
+
+ Status DecodeArrowDenseNonNull(int num_values,
+ typename EncodingTraits<ByteArrayType>::Accumulator* out,
+ int* out_num_values) {
+ constexpr int32_t kBufferSize = 2048;
+ int32_t indices[kBufferSize];
+ int values_decoded = 0;
+
+ ArrowBinaryHelper helper(out);
+ auto dict_values = reinterpret_cast<const ByteArray*>(dictionary_->data());
+
+ while (values_decoded < num_values) {
+ int32_t batch_size = std::min<int32_t>(kBufferSize, num_values - values_decoded);
+ int num_indices = idx_decoder_.GetBatch(indices, batch_size);
+ if (num_indices == 0) ParquetException::EofException();
+ for (int i = 0; i < num_indices; ++i) {
+ auto idx = indices[i];
+ RETURN_NOT_OK(IndexInBounds(idx));
+ const auto& val = dict_values[idx];
+ if (ARROW_PREDICT_FALSE(!helper.CanFit(val.len))) {
+ RETURN_NOT_OK(helper.PushChunk());
+ }
+ RETURN_NOT_OK(helper.Append(val.ptr, static_cast<int32_t>(val.len)));
+ }
+ values_decoded += num_indices;
+ }
+ *out_num_values = values_decoded;
+ return Status::OK();
+ }
+
+ template <typename BuilderType>
+ Status DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset, BuilderType* builder,
+ int* out_num_values) {
+ constexpr int32_t kBufferSize = 1024;
+ int32_t indices[kBufferSize];
+
+ RETURN_NOT_OK(builder->Reserve(num_values));
+ ::arrow::internal::BitmapReader bit_reader(valid_bits, valid_bits_offset, num_values);
+
+ auto dict_values = reinterpret_cast<const ByteArray*>(dictionary_->data());
+
+ int values_decoded = 0;
+ int num_appended = 0;
+ while (num_appended < num_values) {
+ bool is_valid = bit_reader.IsSet();
+ bit_reader.Next();
+
+ if (is_valid) {
+ int32_t batch_size =
+ std::min<int32_t>(kBufferSize, num_values - num_appended - null_count);
+ int num_indices = idx_decoder_.GetBatch(indices, batch_size);
+
+ int i = 0;
+ while (true) {
+ // Consume all indices
+ if (is_valid) {
+ auto idx = indices[i];
+ RETURN_NOT_OK(IndexInBounds(idx));
+ const auto& val = dict_values[idx];
+ RETURN_NOT_OK(builder->Append(val.ptr, val.len));
+ ++i;
+ ++values_decoded;
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ --null_count;
+ }
+ ++num_appended;
+ if (i == num_indices) {
+ // Do not advance the bit_reader if we have fulfilled the decode
+ // request
+ break;
+ }
+ is_valid = bit_reader.IsSet();
+ bit_reader.Next();
+ }
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ --null_count;
+ ++num_appended;
+ }
+ }
+ *out_num_values = values_decoded;
+ return Status::OK();
+ }
+
+ template <typename BuilderType>
+ Status DecodeArrowNonNull(int num_values, BuilderType* builder, int* out_num_values) {
+ constexpr int32_t kBufferSize = 2048;
+ int32_t indices[kBufferSize];
+
+ RETURN_NOT_OK(builder->Reserve(num_values));
+
+ auto dict_values = reinterpret_cast<const ByteArray*>(dictionary_->data());
+
+ int values_decoded = 0;
+ while (values_decoded < num_values) {
+ int32_t batch_size = std::min<int32_t>(kBufferSize, num_values - values_decoded);
+ int num_indices = idx_decoder_.GetBatch(indices, batch_size);
+ if (num_indices == 0) ParquetException::EofException();
+ for (int i = 0; i < num_indices; ++i) {
+ auto idx = indices[i];
+ RETURN_NOT_OK(IndexInBounds(idx));
+ const auto& val = dict_values[idx];
+ RETURN_NOT_OK(builder->Append(val.ptr, val.len));
+ }
+ values_decoded += num_indices;
+ }
+ *out_num_values = values_decoded;
+ return Status::OK();
+ }
+};
+
+// ----------------------------------------------------------------------
+// DeltaBitPackDecoder
+
+template <typename DType>
+class DeltaBitPackDecoder : public DecoderImpl, virtual public TypedDecoder<DType> {
+ public:
+ typedef typename DType::c_type T;
+
+ explicit DeltaBitPackDecoder(const ColumnDescriptor* descr,
+ MemoryPool* pool = ::arrow::default_memory_pool())
+ : DecoderImpl(descr, Encoding::DELTA_BINARY_PACKED), pool_(pool) {
+ if (DType::type_num != Type::INT32 && DType::type_num != Type::INT64) {
+ throw ParquetException("Delta bit pack encoding should only be for integer data.");
+ }
+ }
+
+ void SetData(int num_values, const uint8_t* data, int len) override {
+ this->num_values_ = num_values;
+ decoder_ = ::arrow::BitUtil::BitReader(data, len);
+ InitHeader();
+ }
+
+ int Decode(T* buffer, int max_values) override {
+ return GetInternal(buffer, max_values);
+ }
+
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<DType>::Accumulator* out) override {
+ if (null_count != 0) {
+ ParquetException::NYI("Delta bit pack DecodeArrow with null slots");
+ }
+ std::vector<T> values(num_values);
+ GetInternal(values.data(), num_values);
+ PARQUET_THROW_NOT_OK(out->AppendValues(values));
+ return num_values;
+ }
+
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<DType>::DictAccumulator* out) override {
+ if (null_count != 0) {
+ ParquetException::NYI("Delta bit pack DecodeArrow with null slots");
+ }
+ std::vector<T> values(num_values);
+ GetInternal(values.data(), num_values);
+ PARQUET_THROW_NOT_OK(out->Reserve(num_values));
+ for (T value : values) {
+ PARQUET_THROW_NOT_OK(out->Append(value));
+ }
+ return num_values;
+ }
+
+ private:
+ static constexpr int kMaxDeltaBitWidth = static_cast<int>(sizeof(T) * 8);
+
+ void InitHeader() {
+ if (!decoder_.GetVlqInt(&values_per_block_) ||
+ !decoder_.GetVlqInt(&mini_blocks_per_block_) ||
+ !decoder_.GetVlqInt(&total_value_count_) ||
+ !decoder_.GetZigZagVlqInt(&last_value_)) {
+ ParquetException::EofException();
+ }
+
+ if (values_per_block_ == 0) {
+ throw ParquetException("cannot have zero value per block");
+ }
+ if (mini_blocks_per_block_ == 0) {
+ throw ParquetException("cannot have zero miniblock per block");
+ }
+ values_per_mini_block_ = values_per_block_ / mini_blocks_per_block_;
+ if (values_per_mini_block_ == 0) {
+ throw ParquetException("cannot have zero value per miniblock");
+ }
+ if (values_per_mini_block_ % 32 != 0) {
+ throw ParquetException(
+ "the number of values in a miniblock must be multiple of 32, but it's " +
+ std::to_string(values_per_mini_block_));
+ }
+
+ delta_bit_widths_ = AllocateBuffer(pool_, mini_blocks_per_block_);
+ block_initialized_ = false;
+ values_current_mini_block_ = 0;
+ }
+
+ void InitBlock() {
+ if (!decoder_.GetZigZagVlqInt(&min_delta_)) ParquetException::EofException();
+
+ // read the bitwidth of each miniblock
+ uint8_t* bit_width_data = delta_bit_widths_->mutable_data();
+ for (uint32_t i = 0; i < mini_blocks_per_block_; ++i) {
+ if (!decoder_.GetAligned<uint8_t>(1, bit_width_data + i)) {
+ ParquetException::EofException();
+ }
+ if (bit_width_data[i] > kMaxDeltaBitWidth) {
+ throw ParquetException("delta bit width larger than integer bit width");
+ }
+ }
+ mini_block_idx_ = 0;
+ delta_bit_width_ = bit_width_data[0];
+ values_current_mini_block_ = values_per_mini_block_;
+ block_initialized_ = true;
+ }
+
+ int GetInternal(T* buffer, int max_values) {
+ max_values = std::min(max_values, this->num_values_);
+ DCHECK_LE(static_cast<uint32_t>(max_values), total_value_count_);
+ int i = 0;
+ while (i < max_values) {
+ if (ARROW_PREDICT_FALSE(values_current_mini_block_ == 0)) {
+ if (ARROW_PREDICT_FALSE(!block_initialized_)) {
+ buffer[i++] = last_value_;
+ --total_value_count_;
+ if (ARROW_PREDICT_FALSE(i == max_values)) break;
+ InitBlock();
+ } else {
+ ++mini_block_idx_;
+ if (mini_block_idx_ < mini_blocks_per_block_) {
+ delta_bit_width_ = delta_bit_widths_->data()[mini_block_idx_];
+ values_current_mini_block_ = values_per_mini_block_;
+ } else {
+ InitBlock();
+ }
+ }
+ }
+
+ int values_decode =
+ std::min(values_current_mini_block_, static_cast<uint32_t>(max_values - i));
+ if (decoder_.GetBatch(delta_bit_width_, buffer + i, values_decode) !=
+ values_decode) {
+ ParquetException::EofException();
+ }
+ for (int j = 0; j < values_decode; ++j) {
+ // Addition between min_delta, packed int and last_value should be treated as
+ // unsigned addtion. Overflow is as expected.
+ uint64_t delta =
+ static_cast<uint64_t>(min_delta_) + static_cast<uint64_t>(buffer[i + j]);
+ buffer[i + j] = static_cast<T>(delta + static_cast<uint64_t>(last_value_));
+ last_value_ = buffer[i + j];
+ }
+ values_current_mini_block_ -= values_decode;
+ total_value_count_ -= values_decode;
+ i += values_decode;
+ }
+ this->num_values_ -= max_values;
+ return max_values;
+ }
+
+ MemoryPool* pool_;
+ ::arrow::BitUtil::BitReader decoder_;
+ uint32_t values_per_block_;
+ uint32_t mini_blocks_per_block_;
+ uint32_t values_per_mini_block_;
+ uint32_t values_current_mini_block_;
+ uint32_t total_value_count_;
+
+ bool block_initialized_;
+ T min_delta_;
+ uint32_t mini_block_idx_;
+ std::shared_ptr<ResizableBuffer> delta_bit_widths_;
+ int delta_bit_width_;
+
+ T last_value_;
+};
+
+// ----------------------------------------------------------------------
+// DELTA_LENGTH_BYTE_ARRAY
+
+class DeltaLengthByteArrayDecoder : public DecoderImpl,
+ virtual public TypedDecoder<ByteArrayType> {
+ public:
+ explicit DeltaLengthByteArrayDecoder(const ColumnDescriptor* descr,
+ MemoryPool* pool = ::arrow::default_memory_pool())
+ : DecoderImpl(descr, Encoding::DELTA_LENGTH_BYTE_ARRAY),
+ len_decoder_(nullptr, pool),
+ pool_(pool) {}
+
+ void SetData(int num_values, const uint8_t* data, int len) override {
+ num_values_ = num_values;
+ if (len == 0) return;
+ int total_lengths_len = ::arrow::util::SafeLoadAs<int32_t>(data);
+ data += 4;
+ this->len_decoder_.SetData(num_values, data, total_lengths_len);
+ data_ = data + total_lengths_len;
+ this->len_ = len - 4 - total_lengths_len;
+ }
+
+ int Decode(ByteArray* buffer, int max_values) override {
+ using VectorT = ArrowPoolVector<int>;
+ max_values = std::min(max_values, num_values_);
+ VectorT lengths(max_values, 0, ::arrow::stl::allocator<int>(pool_));
+ len_decoder_.Decode(lengths.data(), max_values);
+ for (int i = 0; i < max_values; ++i) {
+ buffer[i].len = lengths[i];
+ buffer[i].ptr = data_;
+ this->data_ += lengths[i];
+ this->len_ -= lengths[i];
+ }
+ this->num_values_ -= max_values;
+ return max_values;
+ }
+
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<ByteArrayType>::Accumulator* out) override {
+ ParquetException::NYI("DecodeArrow for DeltaLengthByteArrayDecoder");
+ }
+
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<ByteArrayType>::DictAccumulator* out) override {
+ ParquetException::NYI("DecodeArrow for DeltaLengthByteArrayDecoder");
+ }
+
+ private:
+ DeltaBitPackDecoder<Int32Type> len_decoder_;
+ ::arrow::MemoryPool* pool_;
+};
+
+// ----------------------------------------------------------------------
+// DELTA_BYTE_ARRAY
+
+class DeltaByteArrayDecoder : public DecoderImpl,
+ virtual public TypedDecoder<ByteArrayType> {
+ public:
+ explicit DeltaByteArrayDecoder(const ColumnDescriptor* descr,
+ MemoryPool* pool = ::arrow::default_memory_pool())
+ : DecoderImpl(descr, Encoding::DELTA_BYTE_ARRAY),
+ prefix_len_decoder_(nullptr, pool),
+ suffix_decoder_(nullptr, pool),
+ last_value_(0, nullptr) {}
+
+ virtual void SetData(int num_values, const uint8_t* data, int len) {
+ num_values_ = num_values;
+ if (len == 0) return;
+ int prefix_len_length = ::arrow::util::SafeLoadAs<int32_t>(data);
+ data += 4;
+ len -= 4;
+ prefix_len_decoder_.SetData(num_values, data, prefix_len_length);
+ data += prefix_len_length;
+ len -= prefix_len_length;
+ suffix_decoder_.SetData(num_values, data, len);
+ }
+
+ // TODO: this doesn't work and requires memory management. We need to allocate
+ // new strings to store the results.
+ virtual int Decode(ByteArray* buffer, int max_values) {
+ max_values = std::min(max_values, this->num_values_);
+ for (int i = 0; i < max_values; ++i) {
+ int prefix_len = 0;
+ prefix_len_decoder_.Decode(&prefix_len, 1);
+ ByteArray suffix = {0, nullptr};
+ suffix_decoder_.Decode(&suffix, 1);
+ buffer[i].len = prefix_len + suffix.len;
+
+ uint8_t* result = reinterpret_cast<uint8_t*>(malloc(buffer[i].len));
+ memcpy(result, last_value_.ptr, prefix_len);
+ memcpy(result + prefix_len, suffix.ptr, suffix.len);
+
+ buffer[i].ptr = result;
+ last_value_ = buffer[i];
+ }
+ this->num_values_ -= max_values;
+ return max_values;
+ }
+
+ private:
+ DeltaBitPackDecoder<Int32Type> prefix_len_decoder_;
+ DeltaLengthByteArrayDecoder suffix_decoder_;
+ ByteArray last_value_;
+};
+
+// ----------------------------------------------------------------------
+// BYTE_STREAM_SPLIT
+
+template <typename DType>
+class ByteStreamSplitDecoder : public DecoderImpl, virtual public TypedDecoder<DType> {
+ public:
+ using T = typename DType::c_type;
+ explicit ByteStreamSplitDecoder(const ColumnDescriptor* descr);
+
+ int Decode(T* buffer, int max_values) override;
+
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<DType>::Accumulator* builder) override;
+
+ int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<DType>::DictAccumulator* builder) override;
+
+ void SetData(int num_values, const uint8_t* data, int len) override;
+
+ T* EnsureDecodeBuffer(int64_t min_values) {
+ const int64_t size = sizeof(T) * min_values;
+ if (!decode_buffer_ || decode_buffer_->size() < size) {
+ PARQUET_ASSIGN_OR_THROW(decode_buffer_, ::arrow::AllocateBuffer(size));
+ }
+ return reinterpret_cast<T*>(decode_buffer_->mutable_data());
+ }
+
+ private:
+ int num_values_in_buffer_{0};
+ std::shared_ptr<Buffer> decode_buffer_;
+
+ static constexpr size_t kNumStreams = sizeof(T);
+};
+
+template <typename DType>
+ByteStreamSplitDecoder<DType>::ByteStreamSplitDecoder(const ColumnDescriptor* descr)
+ : DecoderImpl(descr, Encoding::BYTE_STREAM_SPLIT) {}
+
+template <typename DType>
+void ByteStreamSplitDecoder<DType>::SetData(int num_values, const uint8_t* data,
+ int len) {
+ DecoderImpl::SetData(num_values, data, len);
+ if (num_values * static_cast<int64_t>(sizeof(T)) > len) {
+ throw ParquetException("Data size too small for number of values (corrupted file?)");
+ }
+ num_values_in_buffer_ = num_values;
+}
+
+template <typename DType>
+int ByteStreamSplitDecoder<DType>::Decode(T* buffer, int max_values) {
+ const int values_to_decode = std::min(num_values_, max_values);
+ const int num_decoded_previously = num_values_in_buffer_ - num_values_;
+ const uint8_t* data = data_ + num_decoded_previously;
+
+ ::arrow::util::internal::ByteStreamSplitDecode<T>(data, values_to_decode,
+ num_values_in_buffer_, buffer);
+ num_values_ -= values_to_decode;
+ len_ -= sizeof(T) * values_to_decode;
+ return values_to_decode;
+}
+
+template <typename DType>
+int ByteStreamSplitDecoder<DType>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<DType>::Accumulator* builder) {
+ constexpr int value_size = static_cast<int>(kNumStreams);
+ int values_decoded = num_values - null_count;
+ if (ARROW_PREDICT_FALSE(len_ < value_size * values_decoded)) {
+ ParquetException::EofException();
+ }
+
+ PARQUET_THROW_NOT_OK(builder->Reserve(num_values));
+
+ const int num_decoded_previously = num_values_in_buffer_ - num_values_;
+ const uint8_t* data = data_ + num_decoded_previously;
+ int offset = 0;
+
+#if defined(ARROW_HAVE_SIMD_SPLIT)
+ // Use fast decoding into intermediate buffer. This will also decode
+ // some null values, but it's fast enough that we don't care.
+ T* decode_out = EnsureDecodeBuffer(values_decoded);
+ ::arrow::util::internal::ByteStreamSplitDecode<T>(data, values_decoded,
+ num_values_in_buffer_, decode_out);
+
+ // XXX If null_count is 0, we could even append in bulk or decode directly into
+ // builder
+ VisitNullBitmapInline(
+ valid_bits, valid_bits_offset, num_values, null_count,
+ [&]() {
+ builder->UnsafeAppend(decode_out[offset]);
+ ++offset;
+ },
+ [&]() { builder->UnsafeAppendNull(); });
+
+#else
+ VisitNullBitmapInline(
+ valid_bits, valid_bits_offset, num_values, null_count,
+ [&]() {
+ uint8_t gathered_byte_data[kNumStreams];
+ for (size_t b = 0; b < kNumStreams; ++b) {
+ const size_t byte_index = b * num_values_in_buffer_ + offset;
+ gathered_byte_data[b] = data[byte_index];
+ }
+ builder->UnsafeAppend(::arrow::util::SafeLoadAs<T>(&gathered_byte_data[0]));
+ ++offset;
+ },
+ [&]() { builder->UnsafeAppendNull(); });
+#endif
+
+ num_values_ -= values_decoded;
+ len_ -= sizeof(T) * values_decoded;
+ return values_decoded;
+}
+
+template <typename DType>
+int ByteStreamSplitDecoder<DType>::DecodeArrow(
+ int num_values, int null_count, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ typename EncodingTraits<DType>::DictAccumulator* builder) {
+ ParquetException::NYI("DecodeArrow for ByteStreamSplitDecoder");
+}
+
+} // namespace
+
+// ----------------------------------------------------------------------
+// Encoder and decoder factory functions
+
+std::unique_ptr<Encoder> MakeEncoder(Type::type type_num, Encoding::type encoding,
+ bool use_dictionary, const ColumnDescriptor* descr,
+ MemoryPool* pool) {
+ if (use_dictionary) {
+ switch (type_num) {
+ case Type::INT32:
+ return std::unique_ptr<Encoder>(new DictEncoderImpl<Int32Type>(descr, pool));
+ case Type::INT64:
+ return std::unique_ptr<Encoder>(new DictEncoderImpl<Int64Type>(descr, pool));
+ case Type::INT96:
+ return std::unique_ptr<Encoder>(new DictEncoderImpl<Int96Type>(descr, pool));
+ case Type::FLOAT:
+ return std::unique_ptr<Encoder>(new DictEncoderImpl<FloatType>(descr, pool));
+ case Type::DOUBLE:
+ return std::unique_ptr<Encoder>(new DictEncoderImpl<DoubleType>(descr, pool));
+ case Type::BYTE_ARRAY:
+ return std::unique_ptr<Encoder>(new DictEncoderImpl<ByteArrayType>(descr, pool));
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return std::unique_ptr<Encoder>(new DictEncoderImpl<FLBAType>(descr, pool));
+ default:
+ DCHECK(false) << "Encoder not implemented";
+ break;
+ }
+ } else if (encoding == Encoding::PLAIN) {
+ switch (type_num) {
+ case Type::BOOLEAN:
+ return std::unique_ptr<Encoder>(new PlainEncoder<BooleanType>(descr, pool));
+ case Type::INT32:
+ return std::unique_ptr<Encoder>(new PlainEncoder<Int32Type>(descr, pool));
+ case Type::INT64:
+ return std::unique_ptr<Encoder>(new PlainEncoder<Int64Type>(descr, pool));
+ case Type::INT96:
+ return std::unique_ptr<Encoder>(new PlainEncoder<Int96Type>(descr, pool));
+ case Type::FLOAT:
+ return std::unique_ptr<Encoder>(new PlainEncoder<FloatType>(descr, pool));
+ case Type::DOUBLE:
+ return std::unique_ptr<Encoder>(new PlainEncoder<DoubleType>(descr, pool));
+ case Type::BYTE_ARRAY:
+ return std::unique_ptr<Encoder>(new PlainEncoder<ByteArrayType>(descr, pool));
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return std::unique_ptr<Encoder>(new PlainEncoder<FLBAType>(descr, pool));
+ default:
+ DCHECK(false) << "Encoder not implemented";
+ break;
+ }
+ } else if (encoding == Encoding::BYTE_STREAM_SPLIT) {
+ switch (type_num) {
+ case Type::FLOAT:
+ return std::unique_ptr<Encoder>(
+ new ByteStreamSplitEncoder<FloatType>(descr, pool));
+ case Type::DOUBLE:
+ return std::unique_ptr<Encoder>(
+ new ByteStreamSplitEncoder<DoubleType>(descr, pool));
+ default:
+ throw ParquetException("BYTE_STREAM_SPLIT only supports FLOAT and DOUBLE");
+ break;
+ }
+ } else {
+ ParquetException::NYI("Selected encoding is not supported");
+ }
+ DCHECK(false) << "Should not be able to reach this code";
+ return nullptr;
+}
+
+std::unique_ptr<Decoder> MakeDecoder(Type::type type_num, Encoding::type encoding,
+ const ColumnDescriptor* descr) {
+ if (encoding == Encoding::PLAIN) {
+ switch (type_num) {
+ case Type::BOOLEAN:
+ return std::unique_ptr<Decoder>(new PlainBooleanDecoder(descr));
+ case Type::INT32:
+ return std::unique_ptr<Decoder>(new PlainDecoder<Int32Type>(descr));
+ case Type::INT64:
+ return std::unique_ptr<Decoder>(new PlainDecoder<Int64Type>(descr));
+ case Type::INT96:
+ return std::unique_ptr<Decoder>(new PlainDecoder<Int96Type>(descr));
+ case Type::FLOAT:
+ return std::unique_ptr<Decoder>(new PlainDecoder<FloatType>(descr));
+ case Type::DOUBLE:
+ return std::unique_ptr<Decoder>(new PlainDecoder<DoubleType>(descr));
+ case Type::BYTE_ARRAY:
+ return std::unique_ptr<Decoder>(new PlainByteArrayDecoder(descr));
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return std::unique_ptr<Decoder>(new PlainFLBADecoder(descr));
+ default:
+ break;
+ }
+ } else if (encoding == Encoding::BYTE_STREAM_SPLIT) {
+ switch (type_num) {
+ case Type::FLOAT:
+ return std::unique_ptr<Decoder>(new ByteStreamSplitDecoder<FloatType>(descr));
+ case Type::DOUBLE:
+ return std::unique_ptr<Decoder>(new ByteStreamSplitDecoder<DoubleType>(descr));
+ default:
+ throw ParquetException("BYTE_STREAM_SPLIT only supports FLOAT and DOUBLE");
+ break;
+ }
+ } else if (encoding == Encoding::DELTA_BINARY_PACKED) {
+ switch (type_num) {
+ case Type::INT32:
+ return std::unique_ptr<Decoder>(new DeltaBitPackDecoder<Int32Type>(descr));
+ case Type::INT64:
+ return std::unique_ptr<Decoder>(new DeltaBitPackDecoder<Int64Type>(descr));
+ default:
+ throw ParquetException("DELTA_BINARY_PACKED only supports INT32 and INT64");
+ break;
+ }
+ } else {
+ ParquetException::NYI("Selected encoding is not supported");
+ }
+ DCHECK(false) << "Should not be able to reach this code";
+ return nullptr;
+}
+
+namespace detail {
+std::unique_ptr<Decoder> MakeDictDecoder(Type::type type_num,
+ const ColumnDescriptor* descr,
+ MemoryPool* pool) {
+ switch (type_num) {
+ case Type::BOOLEAN:
+ ParquetException::NYI("Dictionary encoding not implemented for boolean type");
+ case Type::INT32:
+ return std::unique_ptr<Decoder>(new DictDecoderImpl<Int32Type>(descr, pool));
+ case Type::INT64:
+ return std::unique_ptr<Decoder>(new DictDecoderImpl<Int64Type>(descr, pool));
+ case Type::INT96:
+ return std::unique_ptr<Decoder>(new DictDecoderImpl<Int96Type>(descr, pool));
+ case Type::FLOAT:
+ return std::unique_ptr<Decoder>(new DictDecoderImpl<FloatType>(descr, pool));
+ case Type::DOUBLE:
+ return std::unique_ptr<Decoder>(new DictDecoderImpl<DoubleType>(descr, pool));
+ case Type::BYTE_ARRAY:
+ return std::unique_ptr<Decoder>(new DictByteArrayDecoderImpl(descr, pool));
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return std::unique_ptr<Decoder>(new DictDecoderImpl<FLBAType>(descr, pool));
+ default:
+ break;
+ }
+ DCHECK(false) << "Should not be able to reach this code";
+ return nullptr;
+}
+
+} // namespace detail
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encoding.h b/src/arrow/cpp/src/parquet/encoding.h
new file mode 100644
index 000000000..b9ca7a7ee
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encoding.h
@@ -0,0 +1,460 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <vector>
+
+#include "arrow/util/spaced.h"
+
+#include "parquet/exception.h"
+#include "parquet/platform.h"
+#include "parquet/types.h"
+
+namespace arrow {
+
+class Array;
+class ArrayBuilder;
+class BinaryArray;
+class BinaryBuilder;
+class BooleanBuilder;
+class Int32Type;
+class Int64Type;
+class FloatType;
+class DoubleType;
+class FixedSizeBinaryType;
+template <typename T>
+class NumericBuilder;
+class FixedSizeBinaryBuilder;
+template <typename T>
+class Dictionary32Builder;
+
+} // namespace arrow
+
+namespace parquet {
+
+template <typename DType>
+class TypedEncoder;
+
+using BooleanEncoder = TypedEncoder<BooleanType>;
+using Int32Encoder = TypedEncoder<Int32Type>;
+using Int64Encoder = TypedEncoder<Int64Type>;
+using Int96Encoder = TypedEncoder<Int96Type>;
+using FloatEncoder = TypedEncoder<FloatType>;
+using DoubleEncoder = TypedEncoder<DoubleType>;
+using ByteArrayEncoder = TypedEncoder<ByteArrayType>;
+using FLBAEncoder = TypedEncoder<FLBAType>;
+
+template <typename DType>
+class TypedDecoder;
+
+class BooleanDecoder;
+using Int32Decoder = TypedDecoder<Int32Type>;
+using Int64Decoder = TypedDecoder<Int64Type>;
+using Int96Decoder = TypedDecoder<Int96Type>;
+using FloatDecoder = TypedDecoder<FloatType>;
+using DoubleDecoder = TypedDecoder<DoubleType>;
+using ByteArrayDecoder = TypedDecoder<ByteArrayType>;
+class FLBADecoder;
+
+template <typename T>
+struct EncodingTraits;
+
+template <>
+struct EncodingTraits<BooleanType> {
+ using Encoder = BooleanEncoder;
+ using Decoder = BooleanDecoder;
+
+ using ArrowType = ::arrow::BooleanType;
+ using Accumulator = ::arrow::BooleanBuilder;
+ struct DictAccumulator {};
+};
+
+template <>
+struct EncodingTraits<Int32Type> {
+ using Encoder = Int32Encoder;
+ using Decoder = Int32Decoder;
+
+ using ArrowType = ::arrow::Int32Type;
+ using Accumulator = ::arrow::NumericBuilder<::arrow::Int32Type>;
+ using DictAccumulator = ::arrow::Dictionary32Builder<::arrow::Int32Type>;
+};
+
+template <>
+struct EncodingTraits<Int64Type> {
+ using Encoder = Int64Encoder;
+ using Decoder = Int64Decoder;
+
+ using ArrowType = ::arrow::Int64Type;
+ using Accumulator = ::arrow::NumericBuilder<::arrow::Int64Type>;
+ using DictAccumulator = ::arrow::Dictionary32Builder<::arrow::Int64Type>;
+};
+
+template <>
+struct EncodingTraits<Int96Type> {
+ using Encoder = Int96Encoder;
+ using Decoder = Int96Decoder;
+
+ struct Accumulator {};
+ struct DictAccumulator {};
+};
+
+template <>
+struct EncodingTraits<FloatType> {
+ using Encoder = FloatEncoder;
+ using Decoder = FloatDecoder;
+
+ using ArrowType = ::arrow::FloatType;
+ using Accumulator = ::arrow::NumericBuilder<::arrow::FloatType>;
+ using DictAccumulator = ::arrow::Dictionary32Builder<::arrow::FloatType>;
+};
+
+template <>
+struct EncodingTraits<DoubleType> {
+ using Encoder = DoubleEncoder;
+ using Decoder = DoubleDecoder;
+
+ using ArrowType = ::arrow::DoubleType;
+ using Accumulator = ::arrow::NumericBuilder<::arrow::DoubleType>;
+ using DictAccumulator = ::arrow::Dictionary32Builder<::arrow::DoubleType>;
+};
+
+template <>
+struct EncodingTraits<ByteArrayType> {
+ using Encoder = ByteArrayEncoder;
+ using Decoder = ByteArrayDecoder;
+
+ /// \brief Internal helper class for decoding BYTE_ARRAY data where we can
+ /// overflow the capacity of a single arrow::BinaryArray
+ struct Accumulator {
+ std::unique_ptr<::arrow::BinaryBuilder> builder;
+ std::vector<std::shared_ptr<::arrow::Array>> chunks;
+ };
+ using ArrowType = ::arrow::BinaryType;
+ using DictAccumulator = ::arrow::Dictionary32Builder<::arrow::BinaryType>;
+};
+
+template <>
+struct EncodingTraits<FLBAType> {
+ using Encoder = FLBAEncoder;
+ using Decoder = FLBADecoder;
+
+ using ArrowType = ::arrow::FixedSizeBinaryType;
+ using Accumulator = ::arrow::FixedSizeBinaryBuilder;
+ using DictAccumulator = ::arrow::Dictionary32Builder<::arrow::FixedSizeBinaryType>;
+};
+
+class ColumnDescriptor;
+
+// Untyped base for all encoders
+class Encoder {
+ public:
+ virtual ~Encoder() = default;
+
+ virtual int64_t EstimatedDataEncodedSize() = 0;
+ virtual std::shared_ptr<Buffer> FlushValues() = 0;
+ virtual Encoding::type encoding() const = 0;
+
+ virtual void Put(const ::arrow::Array& values) = 0;
+
+ virtual MemoryPool* memory_pool() const = 0;
+};
+
+// Base class for value encoders. Since encoders may or not have state (e.g.,
+// dictionary encoding) we use a class instance to maintain any state.
+//
+// Encode interfaces are internal, subject to change without deprecation.
+template <typename DType>
+class TypedEncoder : virtual public Encoder {
+ public:
+ typedef typename DType::c_type T;
+
+ using Encoder::Put;
+
+ virtual void Put(const T* src, int num_values) = 0;
+
+ virtual void Put(const std::vector<T>& src, int num_values = -1);
+
+ virtual void PutSpaced(const T* src, int num_values, const uint8_t* valid_bits,
+ int64_t valid_bits_offset) = 0;
+};
+
+template <typename DType>
+void TypedEncoder<DType>::Put(const std::vector<T>& src, int num_values) {
+ if (num_values == -1) {
+ num_values = static_cast<int>(src.size());
+ }
+ Put(src.data(), num_values);
+}
+
+template <>
+inline void TypedEncoder<BooleanType>::Put(const std::vector<bool>& src, int num_values) {
+ // NOTE(wesm): This stub is here only to satisfy the compiler; it is
+ // overridden later with the actual implementation
+}
+
+// Base class for dictionary encoders
+template <typename DType>
+class DictEncoder : virtual public TypedEncoder<DType> {
+ public:
+ /// Writes out any buffered indices to buffer preceded by the bit width of this data.
+ /// Returns the number of bytes written.
+ /// If the supplied buffer is not big enough, returns -1.
+ /// buffer must be preallocated with buffer_len bytes. Use EstimatedDataEncodedSize()
+ /// to size buffer.
+ virtual int WriteIndices(uint8_t* buffer, int buffer_len) = 0;
+
+ virtual int dict_encoded_size() = 0;
+ // virtual int dict_encoded_size() { return dict_encoded_size_; }
+
+ virtual int bit_width() const = 0;
+
+ /// Writes out the encoded dictionary to buffer. buffer must be preallocated to
+ /// dict_encoded_size() bytes.
+ virtual void WriteDict(uint8_t* buffer) = 0;
+
+ virtual int num_entries() const = 0;
+
+ /// \brief EXPERIMENTAL: Append dictionary indices into the encoder. It is
+ /// assumed (without any boundschecking) that the indices reference
+ /// pre-existing dictionary values
+ /// \param[in] indices the dictionary index values. Only Int32Array currently
+ /// supported
+ virtual void PutIndices(const ::arrow::Array& indices) = 0;
+
+ /// \brief EXPERIMENTAL: Append dictionary into encoder, inserting indices
+ /// separately. Currently throws exception if the current dictionary memo is
+ /// non-empty
+ /// \param[in] values the dictionary values. Only valid for certain
+ /// Parquet/Arrow type combinations, like BYTE_ARRAY/BinaryArray
+ virtual void PutDictionary(const ::arrow::Array& values) = 0;
+};
+
+// ----------------------------------------------------------------------
+// Value decoding
+
+class Decoder {
+ public:
+ virtual ~Decoder() = default;
+
+ // Sets the data for a new page. This will be called multiple times on the same
+ // decoder and should reset all internal state.
+ virtual void SetData(int num_values, const uint8_t* data, int len) = 0;
+
+ // Returns the number of values left (for the last call to SetData()). This is
+ // the number of values left in this page.
+ virtual int values_left() const = 0;
+ virtual Encoding::type encoding() const = 0;
+};
+
+template <typename DType>
+class TypedDecoder : virtual public Decoder {
+ public:
+ using T = typename DType::c_type;
+
+ /// \brief Decode values into a buffer
+ ///
+ /// Subclasses may override the more specialized Decode methods below.
+ ///
+ /// \param[in] buffer destination for decoded values
+ /// \param[in] max_values maximum number of values to decode
+ /// \return The number of values decoded. Should be identical to max_values except
+ /// at the end of the current data page.
+ virtual int Decode(T* buffer, int max_values) = 0;
+
+ /// \brief Decode the values in this data page but leave spaces for null entries.
+ ///
+ /// \param[in] buffer destination for decoded values
+ /// \param[in] num_values size of the def_levels and buffer arrays including the number
+ /// of null slots
+ /// \param[in] null_count number of null slots
+ /// \param[in] valid_bits bitmap data indicating position of valid slots
+ /// \param[in] valid_bits_offset offset into valid_bits
+ /// \return The number of values decoded, including nulls.
+ virtual int DecodeSpaced(T* buffer, int num_values, int null_count,
+ const uint8_t* valid_bits, int64_t valid_bits_offset) {
+ if (null_count > 0) {
+ int values_to_read = num_values - null_count;
+ int values_read = Decode(buffer, values_to_read);
+ if (values_read != values_to_read) {
+ throw ParquetException("Number of values / definition_levels read did not match");
+ }
+
+ return ::arrow::util::internal::SpacedExpand<T>(buffer, num_values, null_count,
+ valid_bits, valid_bits_offset);
+ } else {
+ return Decode(buffer, num_values);
+ }
+ }
+
+ /// \brief Decode into an ArrayBuilder or other accumulator
+ ///
+ /// This function assumes the definition levels were already decoded
+ /// as a validity bitmap in the given `valid_bits`. `null_count`
+ /// is the number of 0s in `valid_bits`.
+ /// As a space optimization, it is allowed for `valid_bits` to be null
+ /// if `null_count` is zero.
+ ///
+ /// \return number of values decoded
+ virtual int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<DType>::Accumulator* out) = 0;
+
+ /// \brief Decode into an ArrayBuilder or other accumulator ignoring nulls
+ ///
+ /// \return number of values decoded
+ int DecodeArrowNonNull(int num_values,
+ typename EncodingTraits<DType>::Accumulator* out) {
+ return DecodeArrow(num_values, 0, /*valid_bits=*/NULLPTR, 0, out);
+ }
+
+ /// \brief Decode into a DictionaryBuilder
+ ///
+ /// This function assumes the definition levels were already decoded
+ /// as a validity bitmap in the given `valid_bits`. `null_count`
+ /// is the number of 0s in `valid_bits`.
+ /// As a space optimization, it is allowed for `valid_bits` to be null
+ /// if `null_count` is zero.
+ ///
+ /// \return number of values decoded
+ virtual int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ typename EncodingTraits<DType>::DictAccumulator* builder) = 0;
+
+ /// \brief Decode into a DictionaryBuilder ignoring nulls
+ ///
+ /// \return number of values decoded
+ int DecodeArrowNonNull(int num_values,
+ typename EncodingTraits<DType>::DictAccumulator* builder) {
+ return DecodeArrow(num_values, 0, /*valid_bits=*/NULLPTR, 0, builder);
+ }
+};
+
+template <typename DType>
+class DictDecoder : virtual public TypedDecoder<DType> {
+ public:
+ using T = typename DType::c_type;
+
+ virtual void SetDict(TypedDecoder<DType>* dictionary) = 0;
+
+ /// \brief Insert dictionary values into the Arrow dictionary builder's memo,
+ /// but do not append any indices
+ virtual void InsertDictionary(::arrow::ArrayBuilder* builder) = 0;
+
+ /// \brief Decode only dictionary indices and append to dictionary
+ /// builder. The builder must have had the dictionary from this decoder
+ /// inserted already.
+ ///
+ /// \warning Remember to reset the builder each time the dict decoder is initialized
+ /// with a new dictionary page
+ virtual int DecodeIndicesSpaced(int num_values, int null_count,
+ const uint8_t* valid_bits, int64_t valid_bits_offset,
+ ::arrow::ArrayBuilder* builder) = 0;
+
+ /// \brief Decode only dictionary indices (no nulls)
+ ///
+ /// \warning Remember to reset the builder each time the dict decoder is initialized
+ /// with a new dictionary page
+ virtual int DecodeIndices(int num_values, ::arrow::ArrayBuilder* builder) = 0;
+
+ /// \brief Decode only dictionary indices (no nulls). Same as above
+ /// DecodeIndices but target is an array instead of a builder.
+ ///
+ /// \note API EXPERIMENTAL
+ virtual int DecodeIndices(int num_values, int32_t* indices) = 0;
+
+ /// \brief Get dictionary. The reader will call this API when it encounters a
+ /// new dictionary.
+ ///
+ /// @param[out] dictionary The pointer to dictionary values. Dictionary is owned by
+ /// the decoder and is destroyed when the decoder is destroyed.
+ /// @param[out] dictionary_length The dictionary length.
+ ///
+ /// \note API EXPERIMENTAL
+ virtual void GetDictionary(const T** dictionary, int32_t* dictionary_length) = 0;
+};
+
+// ----------------------------------------------------------------------
+// TypedEncoder specializations, traits, and factory functions
+
+class BooleanDecoder : virtual public TypedDecoder<BooleanType> {
+ public:
+ using TypedDecoder<BooleanType>::Decode;
+ virtual int Decode(uint8_t* buffer, int max_values) = 0;
+};
+
+class FLBADecoder : virtual public TypedDecoder<FLBAType> {
+ public:
+ using TypedDecoder<FLBAType>::DecodeSpaced;
+
+ // TODO(wesm): As possible follow-up to PARQUET-1508, we should examine if
+ // there is value in adding specialized read methods for
+ // FIXED_LEN_BYTE_ARRAY. If only Decimal data can occur with this data type
+ // then perhaps not
+};
+
+PARQUET_EXPORT
+std::unique_ptr<Encoder> MakeEncoder(
+ Type::type type_num, Encoding::type encoding, bool use_dictionary = false,
+ const ColumnDescriptor* descr = NULLPTR,
+ ::arrow::MemoryPool* pool = ::arrow::default_memory_pool());
+
+template <typename DType>
+std::unique_ptr<typename EncodingTraits<DType>::Encoder> MakeTypedEncoder(
+ Encoding::type encoding, bool use_dictionary = false,
+ const ColumnDescriptor* descr = NULLPTR,
+ ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) {
+ using OutType = typename EncodingTraits<DType>::Encoder;
+ std::unique_ptr<Encoder> base =
+ MakeEncoder(DType::type_num, encoding, use_dictionary, descr, pool);
+ return std::unique_ptr<OutType>(dynamic_cast<OutType*>(base.release()));
+}
+
+PARQUET_EXPORT
+std::unique_ptr<Decoder> MakeDecoder(Type::type type_num, Encoding::type encoding,
+ const ColumnDescriptor* descr = NULLPTR);
+
+namespace detail {
+
+PARQUET_EXPORT
+std::unique_ptr<Decoder> MakeDictDecoder(Type::type type_num,
+ const ColumnDescriptor* descr,
+ ::arrow::MemoryPool* pool);
+
+} // namespace detail
+
+template <typename DType>
+std::unique_ptr<DictDecoder<DType>> MakeDictDecoder(
+ const ColumnDescriptor* descr = NULLPTR,
+ ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) {
+ using OutType = DictDecoder<DType>;
+ auto decoder = detail::MakeDictDecoder(DType::type_num, descr, pool);
+ return std::unique_ptr<OutType>(dynamic_cast<OutType*>(decoder.release()));
+}
+
+template <typename DType>
+std::unique_ptr<typename EncodingTraits<DType>::Decoder> MakeTypedDecoder(
+ Encoding::type encoding, const ColumnDescriptor* descr = NULLPTR) {
+ using OutType = typename EncodingTraits<DType>::Decoder;
+ std::unique_ptr<Decoder> base = MakeDecoder(DType::type_num, encoding, descr);
+ return std::unique_ptr<OutType>(dynamic_cast<OutType*>(base.release()));
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encoding_benchmark.cc b/src/arrow/cpp/src/parquet/encoding_benchmark.cc
new file mode 100644
index 000000000..7c5eafd15
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encoding_benchmark.cc
@@ -0,0 +1,802 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/array/builder_dict.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+#include "arrow/util/byte_stream_split.h"
+
+#include "parquet/encoding.h"
+#include "parquet/platform.h"
+#include "parquet/schema.h"
+
+#include <cmath>
+#include <random>
+
+using arrow::default_memory_pool;
+using arrow::MemoryPool;
+
+namespace {
+
+// The min/max number of values used to drive each family of encoding benchmarks
+constexpr int MIN_RANGE = 1024;
+constexpr int MAX_RANGE = 65536;
+} // namespace
+
+namespace parquet {
+
+using schema::PrimitiveNode;
+
+std::shared_ptr<ColumnDescriptor> Int64Schema(Repetition::type repetition) {
+ auto node = PrimitiveNode::Make("int64", repetition, Type::INT64);
+ return std::make_shared<ColumnDescriptor>(node, repetition != Repetition::REQUIRED,
+ repetition == Repetition::REPEATED);
+}
+
+static void BM_PlainEncodingBoolean(benchmark::State& state) {
+ std::vector<bool> values(state.range(0), true);
+ auto encoder = MakeEncoder(Type::BOOLEAN, Encoding::PLAIN);
+ auto typed_encoder = dynamic_cast<BooleanEncoder*>(encoder.get());
+
+ for (auto _ : state) {
+ typed_encoder->Put(values, static_cast<int>(values.size()));
+ typed_encoder->FlushValues();
+ }
+ state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(bool));
+}
+
+BENCHMARK(BM_PlainEncodingBoolean)->Range(MIN_RANGE, MAX_RANGE);
+
+static void BM_PlainDecodingBoolean(benchmark::State& state) {
+ std::vector<bool> values(state.range(0), true);
+ bool* output = new bool[state.range(0)];
+ auto encoder = MakeEncoder(Type::BOOLEAN, Encoding::PLAIN);
+ auto typed_encoder = dynamic_cast<BooleanEncoder*>(encoder.get());
+ typed_encoder->Put(values, static_cast<int>(values.size()));
+ std::shared_ptr<Buffer> buf = encoder->FlushValues();
+
+ for (auto _ : state) {
+ auto decoder = MakeTypedDecoder<BooleanType>(Encoding::PLAIN);
+ decoder->SetData(static_cast<int>(values.size()), buf->data(),
+ static_cast<int>(buf->size()));
+ decoder->Decode(output, static_cast<int>(values.size()));
+ }
+
+ state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(bool));
+ delete[] output;
+}
+
+BENCHMARK(BM_PlainDecodingBoolean)->Range(MIN_RANGE, MAX_RANGE);
+
+static void BM_PlainEncodingInt64(benchmark::State& state) {
+ std::vector<int64_t> values(state.range(0), 64);
+ auto encoder = MakeTypedEncoder<Int64Type>(Encoding::PLAIN);
+ for (auto _ : state) {
+ encoder->Put(values.data(), static_cast<int>(values.size()));
+ encoder->FlushValues();
+ }
+ state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(int64_t));
+}
+
+BENCHMARK(BM_PlainEncodingInt64)->Range(MIN_RANGE, MAX_RANGE);
+
+static void BM_PlainDecodingInt64(benchmark::State& state) {
+ std::vector<int64_t> values(state.range(0), 64);
+ auto encoder = MakeTypedEncoder<Int64Type>(Encoding::PLAIN);
+ encoder->Put(values.data(), static_cast<int>(values.size()));
+ std::shared_ptr<Buffer> buf = encoder->FlushValues();
+
+ for (auto _ : state) {
+ auto decoder = MakeTypedDecoder<Int64Type>(Encoding::PLAIN);
+ decoder->SetData(static_cast<int>(values.size()), buf->data(),
+ static_cast<int>(buf->size()));
+ decoder->Decode(values.data(), static_cast<int>(values.size()));
+ }
+ state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(int64_t));
+}
+
+BENCHMARK(BM_PlainDecodingInt64)->Range(MIN_RANGE, MAX_RANGE);
+
+static void BM_PlainEncodingDouble(benchmark::State& state) {
+ std::vector<double> values(state.range(0), 64.0);
+ auto encoder = MakeTypedEncoder<DoubleType>(Encoding::PLAIN);
+ for (auto _ : state) {
+ encoder->Put(values.data(), static_cast<int>(values.size()));
+ encoder->FlushValues();
+ }
+ state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(double));
+}
+
+BENCHMARK(BM_PlainEncodingDouble)->Range(MIN_RANGE, MAX_RANGE);
+
+static void BM_PlainEncodingDoubleNaN(benchmark::State& state) {
+ std::vector<double> values(state.range(0), nan(""));
+ auto encoder = MakeTypedEncoder<DoubleType>(Encoding::PLAIN);
+ for (auto _ : state) {
+ encoder->Put(values.data(), static_cast<int>(values.size()));
+ encoder->FlushValues();
+ }
+ state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(double));
+}
+
+BENCHMARK(BM_PlainEncodingDoubleNaN)->Range(MIN_RANGE, MAX_RANGE);
+
+static void BM_PlainDecodingDouble(benchmark::State& state) {
+ std::vector<double> values(state.range(0), 64.0);
+ auto encoder = MakeTypedEncoder<DoubleType>(Encoding::PLAIN);
+ encoder->Put(values.data(), static_cast<int>(values.size()));
+ std::shared_ptr<Buffer> buf = encoder->FlushValues();
+
+ for (auto _ : state) {
+ auto decoder = MakeTypedDecoder<DoubleType>(Encoding::PLAIN);
+ decoder->SetData(static_cast<int>(values.size()), buf->data(),
+ static_cast<int>(buf->size()));
+ decoder->Decode(values.data(), static_cast<int>(values.size()));
+ }
+ state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(double));
+}
+
+BENCHMARK(BM_PlainDecodingDouble)->Range(MIN_RANGE, MAX_RANGE);
+
+static void BM_PlainEncodingFloat(benchmark::State& state) {
+ std::vector<float> values(state.range(0), 64.0);
+ auto encoder = MakeTypedEncoder<FloatType>(Encoding::PLAIN);
+ for (auto _ : state) {
+ encoder->Put(values.data(), static_cast<int>(values.size()));
+ encoder->FlushValues();
+ }
+ state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(float));
+}
+
+BENCHMARK(BM_PlainEncodingFloat)->Range(MIN_RANGE, MAX_RANGE);
+
+static void BM_PlainEncodingFloatNaN(benchmark::State& state) {
+ std::vector<float> values(state.range(0), nanf(""));
+ auto encoder = MakeTypedEncoder<FloatType>(Encoding::PLAIN);
+ for (auto _ : state) {
+ encoder->Put(values.data(), static_cast<int>(values.size()));
+ encoder->FlushValues();
+ }
+ state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(float));
+}
+
+BENCHMARK(BM_PlainEncodingFloatNaN)->Range(MIN_RANGE, MAX_RANGE);
+
+static void BM_PlainDecodingFloat(benchmark::State& state) {
+ std::vector<float> values(state.range(0), 64.0);
+ auto encoder = MakeTypedEncoder<FloatType>(Encoding::PLAIN);
+ encoder->Put(values.data(), static_cast<int>(values.size()));
+ std::shared_ptr<Buffer> buf = encoder->FlushValues();
+
+ for (auto _ : state) {
+ auto decoder = MakeTypedDecoder<FloatType>(Encoding::PLAIN);
+ decoder->SetData(static_cast<int>(values.size()), buf->data(),
+ static_cast<int>(buf->size()));
+ decoder->Decode(values.data(), static_cast<int>(values.size()));
+ }
+ state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(float));
+}
+
+BENCHMARK(BM_PlainDecodingFloat)->Range(MIN_RANGE, MAX_RANGE);
+
+template <typename ParquetType>
+struct BM_SpacedEncodingTraits {
+ using ArrowType = typename EncodingTraits<ParquetType>::ArrowType;
+ using ArrayType = typename ::arrow::TypeTraits<ArrowType>::ArrayType;
+ using CType = typename ParquetType::c_type;
+};
+
+template <>
+struct BM_SpacedEncodingTraits<BooleanType> {
+ // Leverage UInt8 vector array data for Boolean, the input src of PutSpaced is bool*
+ using ArrowType = ::arrow::UInt8Type;
+ using ArrayType = ::arrow::UInt8Array;
+ using CType = bool;
+};
+
+static void BM_PlainSpacedArgs(benchmark::internal::Benchmark* bench) {
+ constexpr auto kPlainSpacedSize = 32 * 1024; // 32k
+
+ bench->Args({/*size*/ kPlainSpacedSize, /*null_in_ten_thousand*/ 1});
+ bench->Args({/*size*/ kPlainSpacedSize, /*null_in_ten_thousand*/ 100});
+ bench->Args({/*size*/ kPlainSpacedSize, /*null_in_ten_thousand*/ 1000});
+ bench->Args({/*size*/ kPlainSpacedSize, /*null_in_ten_thousand*/ 5000});
+ bench->Args({/*size*/ kPlainSpacedSize, /*null_in_ten_thousand*/ 10000});
+}
+
+template <typename ParquetType>
+static void BM_PlainEncodingSpaced(benchmark::State& state) {
+ using ArrowType = typename BM_SpacedEncodingTraits<ParquetType>::ArrowType;
+ using ArrayType = typename BM_SpacedEncodingTraits<ParquetType>::ArrayType;
+ using CType = typename BM_SpacedEncodingTraits<ParquetType>::CType;
+
+ const int num_values = static_cast<int>(state.range(0));
+ const double null_percent = static_cast<double>(state.range(1)) / 10000.0;
+
+ auto rand = ::arrow::random::RandomArrayGenerator(1923);
+ const auto array = rand.Numeric<ArrowType>(num_values, -100, 100, null_percent);
+ const auto valid_bits = array->null_bitmap_data();
+ const auto array_actual = ::arrow::internal::checked_pointer_cast<ArrayType>(array);
+ const auto raw_values = array_actual->raw_values();
+ // Guarantee the type cast between raw_values and input of PutSpaced.
+ static_assert(sizeof(CType) == sizeof(*raw_values), "Type mismatch");
+ // Cast only happens for BooleanType as it use UInt8 for the array data to match a bool*
+ // input to PutSpaced.
+ const auto src = reinterpret_cast<const CType*>(raw_values);
+
+ auto encoder = MakeTypedEncoder<ParquetType>(Encoding::PLAIN);
+ for (auto _ : state) {
+ encoder->PutSpaced(src, num_values, valid_bits, 0);
+ encoder->FlushValues();
+ }
+ state.counters["null_percent"] = null_percent * 100;
+ state.SetBytesProcessed(state.iterations() * num_values * sizeof(CType));
+}
+
+static void BM_PlainEncodingSpacedBoolean(benchmark::State& state) {
+ BM_PlainEncodingSpaced<BooleanType>(state);
+}
+BENCHMARK(BM_PlainEncodingSpacedBoolean)->Apply(BM_PlainSpacedArgs);
+
+static void BM_PlainEncodingSpacedFloat(benchmark::State& state) {
+ BM_PlainEncodingSpaced<FloatType>(state);
+}
+BENCHMARK(BM_PlainEncodingSpacedFloat)->Apply(BM_PlainSpacedArgs);
+
+static void BM_PlainEncodingSpacedDouble(benchmark::State& state) {
+ BM_PlainEncodingSpaced<DoubleType>(state);
+}
+BENCHMARK(BM_PlainEncodingSpacedDouble)->Apply(BM_PlainSpacedArgs);
+
+template <typename ParquetType>
+static void BM_PlainDecodingSpaced(benchmark::State& state) {
+ using ArrowType = typename BM_SpacedEncodingTraits<ParquetType>::ArrowType;
+ using ArrayType = typename BM_SpacedEncodingTraits<ParquetType>::ArrayType;
+ using CType = typename BM_SpacedEncodingTraits<ParquetType>::CType;
+
+ const int num_values = static_cast<int>(state.range(0));
+ const auto null_percent = static_cast<double>(state.range(1)) / 10000.0;
+
+ auto rand = ::arrow::random::RandomArrayGenerator(1923);
+ const auto array = rand.Numeric<ArrowType>(num_values, -100, 100, null_percent);
+ const auto valid_bits = array->null_bitmap_data();
+ const int null_count = static_cast<int>(array->null_count());
+ const auto array_actual = ::arrow::internal::checked_pointer_cast<ArrayType>(array);
+ const auto raw_values = array_actual->raw_values();
+ // Guarantee the type cast between raw_values and input of PutSpaced.
+ static_assert(sizeof(CType) == sizeof(*raw_values), "Type mismatch");
+ // Cast only happens for BooleanType as it use UInt8 for the array data to match a bool*
+ // input to PutSpaced.
+ const auto src = reinterpret_cast<const CType*>(raw_values);
+
+ auto encoder = MakeTypedEncoder<ParquetType>(Encoding::PLAIN);
+ encoder->PutSpaced(src, num_values, valid_bits, 0);
+ std::shared_ptr<Buffer> buf = encoder->FlushValues();
+
+ auto decoder = MakeTypedDecoder<ParquetType>(Encoding::PLAIN);
+ std::vector<uint8_t> decode_values(num_values * sizeof(CType));
+ auto decode_buf = reinterpret_cast<CType*>(decode_values.data());
+ for (auto _ : state) {
+ decoder->SetData(num_values - null_count, buf->data(), static_cast<int>(buf->size()));
+ decoder->DecodeSpaced(decode_buf, num_values, null_count, valid_bits, 0);
+ }
+ state.counters["null_percent"] = null_percent * 100;
+ state.SetBytesProcessed(state.iterations() * num_values * sizeof(CType));
+}
+
+static void BM_PlainDecodingSpacedBoolean(benchmark::State& state) {
+ BM_PlainDecodingSpaced<BooleanType>(state);
+}
+BENCHMARK(BM_PlainDecodingSpacedBoolean)->Apply(BM_PlainSpacedArgs);
+
+static void BM_PlainDecodingSpacedFloat(benchmark::State& state) {
+ BM_PlainDecodingSpaced<FloatType>(state);
+}
+BENCHMARK(BM_PlainDecodingSpacedFloat)->Apply(BM_PlainSpacedArgs);
+
+static void BM_PlainDecodingSpacedDouble(benchmark::State& state) {
+ BM_PlainDecodingSpaced<DoubleType>(state);
+}
+BENCHMARK(BM_PlainDecodingSpacedDouble)->Apply(BM_PlainSpacedArgs);
+
+template <typename T, typename DecodeFunc>
+static void BM_ByteStreamSplitDecode(benchmark::State& state, DecodeFunc&& decode_func) {
+ std::vector<T> values(state.range(0), 64.0);
+ const uint8_t* values_raw = reinterpret_cast<const uint8_t*>(values.data());
+ std::vector<T> output(state.range(0), 0);
+
+ for (auto _ : state) {
+ decode_func(values_raw, static_cast<int64_t>(values.size()),
+ static_cast<int64_t>(values.size()), output.data());
+ benchmark::ClobberMemory();
+ }
+ state.SetBytesProcessed(state.iterations() * values.size() * sizeof(T));
+}
+
+template <typename T, typename EncodeFunc>
+static void BM_ByteStreamSplitEncode(benchmark::State& state, EncodeFunc&& encode_func) {
+ std::vector<T> values(state.range(0), 64.0);
+ const uint8_t* values_raw = reinterpret_cast<const uint8_t*>(values.data());
+ std::vector<uint8_t> output(state.range(0) * sizeof(T), 0);
+
+ for (auto _ : state) {
+ encode_func(values_raw, values.size(), output.data());
+ benchmark::ClobberMemory();
+ }
+ state.SetBytesProcessed(state.iterations() * values.size() * sizeof(T));
+}
+
+static void BM_ByteStreamSplitDecode_Float_Scalar(benchmark::State& state) {
+ BM_ByteStreamSplitDecode<float>(
+ state, ::arrow::util::internal::ByteStreamSplitDecodeScalar<float>);
+}
+
+static void BM_ByteStreamSplitDecode_Double_Scalar(benchmark::State& state) {
+ BM_ByteStreamSplitDecode<double>(
+ state, ::arrow::util::internal::ByteStreamSplitDecodeScalar<double>);
+}
+
+static void BM_ByteStreamSplitEncode_Float_Scalar(benchmark::State& state) {
+ BM_ByteStreamSplitEncode<float>(
+ state, ::arrow::util::internal::ByteStreamSplitEncodeScalar<float>);
+}
+
+static void BM_ByteStreamSplitEncode_Double_Scalar(benchmark::State& state) {
+ BM_ByteStreamSplitEncode<double>(
+ state, ::arrow::util::internal::ByteStreamSplitEncodeScalar<double>);
+}
+
+BENCHMARK(BM_ByteStreamSplitDecode_Float_Scalar)->Range(MIN_RANGE, MAX_RANGE);
+BENCHMARK(BM_ByteStreamSplitDecode_Double_Scalar)->Range(MIN_RANGE, MAX_RANGE);
+BENCHMARK(BM_ByteStreamSplitEncode_Float_Scalar)->Range(MIN_RANGE, MAX_RANGE);
+BENCHMARK(BM_ByteStreamSplitEncode_Double_Scalar)->Range(MIN_RANGE, MAX_RANGE);
+
+#if defined(ARROW_HAVE_SSE4_2)
+static void BM_ByteStreamSplitDecode_Float_Sse2(benchmark::State& state) {
+ BM_ByteStreamSplitDecode<float>(
+ state, ::arrow::util::internal::ByteStreamSplitDecodeSse2<float>);
+}
+
+static void BM_ByteStreamSplitDecode_Double_Sse2(benchmark::State& state) {
+ BM_ByteStreamSplitDecode<double>(
+ state, ::arrow::util::internal::ByteStreamSplitDecodeSse2<double>);
+}
+
+static void BM_ByteStreamSplitEncode_Float_Sse2(benchmark::State& state) {
+ BM_ByteStreamSplitEncode<float>(
+ state, ::arrow::util::internal::ByteStreamSplitEncodeSse2<float>);
+}
+
+static void BM_ByteStreamSplitEncode_Double_Sse2(benchmark::State& state) {
+ BM_ByteStreamSplitEncode<double>(
+ state, ::arrow::util::internal::ByteStreamSplitEncodeSse2<double>);
+}
+
+BENCHMARK(BM_ByteStreamSplitDecode_Float_Sse2)->Range(MIN_RANGE, MAX_RANGE);
+BENCHMARK(BM_ByteStreamSplitDecode_Double_Sse2)->Range(MIN_RANGE, MAX_RANGE);
+BENCHMARK(BM_ByteStreamSplitEncode_Float_Sse2)->Range(MIN_RANGE, MAX_RANGE);
+BENCHMARK(BM_ByteStreamSplitEncode_Double_Sse2)->Range(MIN_RANGE, MAX_RANGE);
+#endif
+
+#if defined(ARROW_HAVE_AVX2)
+static void BM_ByteStreamSplitDecode_Float_Avx2(benchmark::State& state) {
+ BM_ByteStreamSplitDecode<float>(
+ state, ::arrow::util::internal::ByteStreamSplitDecodeAvx2<float>);
+}
+
+static void BM_ByteStreamSplitDecode_Double_Avx2(benchmark::State& state) {
+ BM_ByteStreamSplitDecode<double>(
+ state, ::arrow::util::internal::ByteStreamSplitDecodeAvx2<double>);
+}
+
+static void BM_ByteStreamSplitEncode_Float_Avx2(benchmark::State& state) {
+ BM_ByteStreamSplitEncode<float>(
+ state, ::arrow::util::internal::ByteStreamSplitEncodeAvx2<float>);
+}
+
+static void BM_ByteStreamSplitEncode_Double_Avx2(benchmark::State& state) {
+ BM_ByteStreamSplitEncode<double>(
+ state, ::arrow::util::internal::ByteStreamSplitEncodeAvx2<double>);
+}
+
+BENCHMARK(BM_ByteStreamSplitDecode_Float_Avx2)->Range(MIN_RANGE, MAX_RANGE);
+BENCHMARK(BM_ByteStreamSplitDecode_Double_Avx2)->Range(MIN_RANGE, MAX_RANGE);
+BENCHMARK(BM_ByteStreamSplitEncode_Float_Avx2)->Range(MIN_RANGE, MAX_RANGE);
+BENCHMARK(BM_ByteStreamSplitEncode_Double_Avx2)->Range(MIN_RANGE, MAX_RANGE);
+#endif
+
+#if defined(ARROW_HAVE_AVX512)
+static void BM_ByteStreamSplitDecode_Float_Avx512(benchmark::State& state) {
+ BM_ByteStreamSplitDecode<float>(
+ state, ::arrow::util::internal::ByteStreamSplitDecodeAvx512<float>);
+}
+
+static void BM_ByteStreamSplitDecode_Double_Avx512(benchmark::State& state) {
+ BM_ByteStreamSplitDecode<double>(
+ state, ::arrow::util::internal::ByteStreamSplitDecodeAvx512<double>);
+}
+
+static void BM_ByteStreamSplitEncode_Float_Avx512(benchmark::State& state) {
+ BM_ByteStreamSplitEncode<float>(
+ state, ::arrow::util::internal::ByteStreamSplitEncodeAvx512<float>);
+}
+
+static void BM_ByteStreamSplitEncode_Double_Avx512(benchmark::State& state) {
+ BM_ByteStreamSplitEncode<double>(
+ state, ::arrow::util::internal::ByteStreamSplitEncodeAvx512<double>);
+}
+
+BENCHMARK(BM_ByteStreamSplitDecode_Float_Avx512)->Range(MIN_RANGE, MAX_RANGE);
+BENCHMARK(BM_ByteStreamSplitDecode_Double_Avx512)->Range(MIN_RANGE, MAX_RANGE);
+BENCHMARK(BM_ByteStreamSplitEncode_Float_Avx512)->Range(MIN_RANGE, MAX_RANGE);
+BENCHMARK(BM_ByteStreamSplitEncode_Double_Avx512)->Range(MIN_RANGE, MAX_RANGE);
+#endif
+
+template <typename Type>
+static void DecodeDict(std::vector<typename Type::c_type>& values,
+ benchmark::State& state) {
+ typedef typename Type::c_type T;
+ int num_values = static_cast<int>(values.size());
+
+ MemoryPool* allocator = default_memory_pool();
+ std::shared_ptr<ColumnDescriptor> descr = Int64Schema(Repetition::REQUIRED);
+
+ auto base_encoder =
+ MakeEncoder(Type::type_num, Encoding::PLAIN, true, descr.get(), allocator);
+ auto encoder =
+ dynamic_cast<typename EncodingTraits<Type>::Encoder*>(base_encoder.get());
+ auto dict_traits = dynamic_cast<DictEncoder<Type>*>(base_encoder.get());
+ encoder->Put(values.data(), num_values);
+
+ std::shared_ptr<ResizableBuffer> dict_buffer =
+ AllocateBuffer(allocator, dict_traits->dict_encoded_size());
+
+ std::shared_ptr<ResizableBuffer> indices =
+ AllocateBuffer(allocator, encoder->EstimatedDataEncodedSize());
+
+ dict_traits->WriteDict(dict_buffer->mutable_data());
+ int actual_bytes = dict_traits->WriteIndices(indices->mutable_data(),
+ static_cast<int>(indices->size()));
+
+ PARQUET_THROW_NOT_OK(indices->Resize(actual_bytes));
+
+ for (auto _ : state) {
+ auto dict_decoder = MakeTypedDecoder<Type>(Encoding::PLAIN, descr.get());
+ dict_decoder->SetData(dict_traits->num_entries(), dict_buffer->data(),
+ static_cast<int>(dict_buffer->size()));
+
+ auto decoder = MakeDictDecoder<Type>(descr.get());
+ decoder->SetDict(dict_decoder.get());
+ decoder->SetData(num_values, indices->data(), static_cast<int>(indices->size()));
+ decoder->Decode(values.data(), num_values);
+ }
+
+ state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(T));
+}
+
+static void BM_DictDecodingInt64_repeats(benchmark::State& state) {
+ typedef Int64Type Type;
+ typedef typename Type::c_type T;
+
+ std::vector<T> values(state.range(0), 64);
+ DecodeDict<Type>(values, state);
+}
+
+BENCHMARK(BM_DictDecodingInt64_repeats)->Range(MIN_RANGE, MAX_RANGE);
+
+static void BM_DictDecodingInt64_literals(benchmark::State& state) {
+ typedef Int64Type Type;
+ typedef typename Type::c_type T;
+
+ std::vector<T> values(state.range(0));
+ for (size_t i = 0; i < values.size(); ++i) {
+ values[i] = i;
+ }
+ DecodeDict<Type>(values, state);
+}
+
+BENCHMARK(BM_DictDecodingInt64_literals)->Range(MIN_RANGE, MAX_RANGE);
+
+// ----------------------------------------------------------------------
+// Shared benchmarks for decoding using arrow builders
+
+using ::arrow::BinaryBuilder;
+using ::arrow::BinaryDictionary32Builder;
+
+class BenchmarkDecodeArrow : public ::benchmark::Fixture {
+ public:
+ void SetUp(const ::benchmark::State& state) override {
+ num_values_ = static_cast<int>(state.range());
+ InitDataInputs();
+ DoEncodeArrow();
+ }
+
+ void TearDown(const ::benchmark::State& state) override {
+ buffer_.reset();
+ input_array_.reset();
+ values_.clear();
+ }
+
+ void InitDataInputs() {
+ // Generate a random string dictionary without any nulls so that this dataset can be
+ // used for benchmarking the DecodeArrowNonNull API
+ constexpr int repeat_factor = 8;
+ constexpr int64_t min_length = 2;
+ constexpr int64_t max_length = 10;
+ ::arrow::random::RandomArrayGenerator rag(0);
+ input_array_ = rag.StringWithRepeats(num_values_, num_values_ / repeat_factor,
+ min_length, max_length, /*null_probability=*/0);
+ valid_bits_ = input_array_->null_bitmap_data();
+ total_size_ = input_array_->data()->buffers[2]->size();
+
+ values_.reserve(num_values_);
+ const auto& binary_array = static_cast<const ::arrow::BinaryArray&>(*input_array_);
+ for (int64_t i = 0; i < binary_array.length(); i++) {
+ auto view = binary_array.GetView(i);
+ values_.emplace_back(static_cast<uint32_t>(view.length()),
+ reinterpret_cast<const uint8_t*>(view.data()));
+ }
+ }
+
+ virtual void DoEncodeArrow() = 0;
+ virtual void DoEncodeLowLevel() = 0;
+
+ virtual std::unique_ptr<ByteArrayDecoder> InitializeDecoder() = 0;
+
+ void EncodeArrowBenchmark(benchmark::State& state) {
+ for (auto _ : state) {
+ DoEncodeArrow();
+ }
+ state.SetBytesProcessed(state.iterations() * total_size_);
+ }
+
+ void EncodeLowLevelBenchmark(benchmark::State& state) {
+ for (auto _ : state) {
+ DoEncodeLowLevel();
+ }
+ state.SetBytesProcessed(state.iterations() * total_size_);
+ }
+
+ void DecodeArrowDenseBenchmark(benchmark::State& state) {
+ for (auto _ : state) {
+ auto decoder = InitializeDecoder();
+ typename EncodingTraits<ByteArrayType>::Accumulator acc;
+ acc.builder.reset(new BinaryBuilder);
+ decoder->DecodeArrow(num_values_, 0, valid_bits_, 0, &acc);
+ }
+ state.SetBytesProcessed(state.iterations() * total_size_);
+ }
+
+ void DecodeArrowNonNullDenseBenchmark(benchmark::State& state) {
+ for (auto _ : state) {
+ auto decoder = InitializeDecoder();
+ typename EncodingTraits<ByteArrayType>::Accumulator acc;
+ acc.builder.reset(new BinaryBuilder);
+ decoder->DecodeArrowNonNull(num_values_, &acc);
+ }
+ state.SetBytesProcessed(state.iterations() * total_size_);
+ }
+
+ void DecodeArrowDictBenchmark(benchmark::State& state) {
+ for (auto _ : state) {
+ auto decoder = InitializeDecoder();
+ BinaryDictionary32Builder builder(default_memory_pool());
+ decoder->DecodeArrow(num_values_, 0, valid_bits_, 0, &builder);
+ }
+
+ state.SetBytesProcessed(state.iterations() * total_size_);
+ }
+
+ void DecodeArrowNonNullDictBenchmark(benchmark::State& state) {
+ for (auto _ : state) {
+ auto decoder = InitializeDecoder();
+ BinaryDictionary32Builder builder(default_memory_pool());
+ decoder->DecodeArrowNonNull(num_values_, &builder);
+ }
+
+ state.SetBytesProcessed(state.iterations() * total_size_);
+ }
+
+ protected:
+ int num_values_;
+ std::shared_ptr<::arrow::Array> input_array_;
+ std::vector<ByteArray> values_;
+ uint64_t total_size_;
+ const uint8_t* valid_bits_;
+ std::shared_ptr<Buffer> buffer_;
+};
+
+// ----------------------------------------------------------------------
+// Benchmark Decoding from Plain Encoding
+class BM_ArrowBinaryPlain : public BenchmarkDecodeArrow {
+ public:
+ void DoEncodeArrow() override {
+ auto encoder = MakeTypedEncoder<ByteArrayType>(Encoding::PLAIN);
+ encoder->Put(*input_array_);
+ buffer_ = encoder->FlushValues();
+ }
+
+ void DoEncodeLowLevel() override {
+ auto encoder = MakeTypedEncoder<ByteArrayType>(Encoding::PLAIN);
+ encoder->Put(values_.data(), num_values_);
+ buffer_ = encoder->FlushValues();
+ }
+
+ std::unique_ptr<ByteArrayDecoder> InitializeDecoder() override {
+ auto decoder = MakeTypedDecoder<ByteArrayType>(Encoding::PLAIN);
+ decoder->SetData(num_values_, buffer_->data(), static_cast<int>(buffer_->size()));
+ return decoder;
+ }
+};
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryPlain, EncodeArrow)
+(benchmark::State& state) { EncodeArrowBenchmark(state); }
+BENCHMARK_REGISTER_F(BM_ArrowBinaryPlain, EncodeArrow)->Range(1 << 18, 1 << 20);
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryPlain, EncodeLowLevel)
+(benchmark::State& state) { EncodeLowLevelBenchmark(state); }
+BENCHMARK_REGISTER_F(BM_ArrowBinaryPlain, EncodeLowLevel)->Range(1 << 18, 1 << 20);
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryPlain, DecodeArrow_Dense)
+(benchmark::State& state) { DecodeArrowDenseBenchmark(state); }
+BENCHMARK_REGISTER_F(BM_ArrowBinaryPlain, DecodeArrow_Dense)->Range(MIN_RANGE, MAX_RANGE);
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryPlain, DecodeArrowNonNull_Dense)
+(benchmark::State& state) { DecodeArrowNonNullDenseBenchmark(state); }
+BENCHMARK_REGISTER_F(BM_ArrowBinaryPlain, DecodeArrowNonNull_Dense)
+ ->Range(MIN_RANGE, MAX_RANGE);
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryPlain, DecodeArrow_Dict)
+(benchmark::State& state) { DecodeArrowDictBenchmark(state); }
+BENCHMARK_REGISTER_F(BM_ArrowBinaryPlain, DecodeArrow_Dict)->Range(MIN_RANGE, MAX_RANGE);
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryPlain, DecodeArrowNonNull_Dict)
+(benchmark::State& state) { DecodeArrowNonNullDictBenchmark(state); }
+BENCHMARK_REGISTER_F(BM_ArrowBinaryPlain, DecodeArrowNonNull_Dict)
+ ->Range(MIN_RANGE, MAX_RANGE);
+
+// ----------------------------------------------------------------------
+// Benchmark Decoding from Dictionary Encoding
+class BM_ArrowBinaryDict : public BenchmarkDecodeArrow {
+ public:
+ template <typename PutValuesFunc>
+ void DoEncode(PutValuesFunc&& put_values) {
+ auto node = schema::ByteArray("name");
+ descr_ = std::unique_ptr<ColumnDescriptor>(new ColumnDescriptor(node, 0, 0));
+
+ auto encoder = MakeTypedEncoder<ByteArrayType>(Encoding::PLAIN,
+ /*use_dictionary=*/true, descr_.get());
+ put_values(encoder.get());
+ buffer_ = encoder->FlushValues();
+
+ auto dict_encoder = dynamic_cast<DictEncoder<ByteArrayType>*>(encoder.get());
+ ASSERT_NE(dict_encoder, nullptr);
+ dict_buffer_ =
+ AllocateBuffer(default_memory_pool(), dict_encoder->dict_encoded_size());
+ dict_encoder->WriteDict(dict_buffer_->mutable_data());
+ num_dict_entries_ = dict_encoder->num_entries();
+ }
+
+ template <typename IndexType>
+ void EncodeDictBenchmark(benchmark::State& state) {
+ constexpr int64_t nunique = 100;
+ constexpr int64_t min_length = 32;
+ constexpr int64_t max_length = 32;
+ ::arrow::random::RandomArrayGenerator rag(0);
+ auto dict = rag.String(nunique, min_length, max_length,
+ /*null_probability=*/0);
+ auto indices = rag.Numeric<IndexType, int32_t>(num_values_, 0, nunique - 1);
+
+ auto PutValues = [&](ByteArrayEncoder* encoder) {
+ auto dict_encoder = dynamic_cast<DictEncoder<ByteArrayType>*>(encoder);
+ dict_encoder->PutDictionary(*dict);
+ dict_encoder->PutIndices(*indices);
+ };
+ for (auto _ : state) {
+ DoEncode(std::move(PutValues));
+ }
+ state.SetItemsProcessed(state.iterations() * num_values_);
+ }
+
+ void DoEncodeArrow() override {
+ auto PutValues = [&](ByteArrayEncoder* encoder) {
+ ASSERT_NO_THROW(encoder->Put(*input_array_));
+ };
+ DoEncode(std::move(PutValues));
+ }
+
+ void DoEncodeLowLevel() override {
+ auto PutValues = [&](ByteArrayEncoder* encoder) {
+ encoder->Put(values_.data(), num_values_);
+ };
+ DoEncode(std::move(PutValues));
+ }
+
+ std::unique_ptr<ByteArrayDecoder> InitializeDecoder() override {
+ auto decoder = MakeTypedDecoder<ByteArrayType>(Encoding::PLAIN, descr_.get());
+ decoder->SetData(num_dict_entries_, dict_buffer_->data(),
+ static_cast<int>(dict_buffer_->size()));
+ auto dict_decoder = MakeDictDecoder<ByteArrayType>(descr_.get());
+ dict_decoder->SetDict(decoder.get());
+ dict_decoder->SetData(num_values_, buffer_->data(),
+ static_cast<int>(buffer_->size()));
+ return std::unique_ptr<ByteArrayDecoder>(
+ dynamic_cast<ByteArrayDecoder*>(dict_decoder.release()));
+ }
+
+ void TearDown(const ::benchmark::State& state) override {
+ BenchmarkDecodeArrow::TearDown(state);
+ dict_buffer_.reset();
+ descr_.reset();
+ }
+
+ protected:
+ std::unique_ptr<ColumnDescriptor> descr_;
+ std::shared_ptr<Buffer> dict_buffer_;
+ int num_dict_entries_;
+};
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryDict, EncodeArrow)
+(benchmark::State& state) { EncodeArrowBenchmark(state); }
+BENCHMARK_REGISTER_F(BM_ArrowBinaryDict, EncodeArrow)->Range(1 << 18, 1 << 20);
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryDict, EncodeDictDirectInt8)
+(benchmark::State& state) { EncodeDictBenchmark<::arrow::Int8Type>(state); }
+BENCHMARK_REGISTER_F(BM_ArrowBinaryDict, EncodeDictDirectInt8)->Range(1 << 20, 1 << 20);
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryDict, EncodeDictDirectInt16)
+(benchmark::State& state) { EncodeDictBenchmark<::arrow::Int16Type>(state); }
+BENCHMARK_REGISTER_F(BM_ArrowBinaryDict, EncodeDictDirectInt16)->Range(1 << 20, 1 << 20);
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryDict, EncodeDictDirectInt32)
+(benchmark::State& state) { EncodeDictBenchmark<::arrow::Int32Type>(state); }
+BENCHMARK_REGISTER_F(BM_ArrowBinaryDict, EncodeDictDirectInt32)->Range(1 << 20, 1 << 20);
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryDict, EncodeDictDirectInt64)
+(benchmark::State& state) { EncodeDictBenchmark<::arrow::Int64Type>(state); }
+BENCHMARK_REGISTER_F(BM_ArrowBinaryDict, EncodeDictDirectInt64)->Range(1 << 20, 1 << 20);
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryDict, EncodeLowLevel)
+(benchmark::State& state) { EncodeLowLevelBenchmark(state); }
+BENCHMARK_REGISTER_F(BM_ArrowBinaryDict, EncodeLowLevel)->Range(1 << 18, 1 << 20);
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryDict, DecodeArrow_Dense)(benchmark::State& state) {
+ DecodeArrowDenseBenchmark(state);
+}
+BENCHMARK_REGISTER_F(BM_ArrowBinaryDict, DecodeArrow_Dense)->Range(MIN_RANGE, MAX_RANGE);
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryDict, DecodeArrowNonNull_Dense)
+(benchmark::State& state) { DecodeArrowNonNullDenseBenchmark(state); }
+BENCHMARK_REGISTER_F(BM_ArrowBinaryDict, DecodeArrowNonNull_Dense)
+ ->Range(MIN_RANGE, MAX_RANGE);
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryDict, DecodeArrow_Dict)
+(benchmark::State& state) { DecodeArrowDictBenchmark(state); }
+BENCHMARK_REGISTER_F(BM_ArrowBinaryDict, DecodeArrow_Dict)->Range(MIN_RANGE, MAX_RANGE);
+
+BENCHMARK_DEFINE_F(BM_ArrowBinaryDict, DecodeArrowNonNull_Dict)
+(benchmark::State& state) { DecodeArrowNonNullDictBenchmark(state); }
+BENCHMARK_REGISTER_F(BM_ArrowBinaryDict, DecodeArrowNonNull_Dict)
+ ->Range(MIN_RANGE, MAX_RANGE);
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encoding_test.cc b/src/arrow/cpp/src/parquet/encoding_test.cc
new file mode 100644
index 000000000..d271d59ef
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encoding_test.cc
@@ -0,0 +1,1247 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <limits>
+#include <utility>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_dict.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_writer.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/endian.h"
+
+#include "parquet/encoding.h"
+#include "parquet/platform.h"
+#include "parquet/schema.h"
+#include "parquet/test_util.h"
+#include "parquet/types.h"
+
+using arrow::default_memory_pool;
+using arrow::MemoryPool;
+using arrow::internal::checked_cast;
+
+namespace BitUtil = arrow::BitUtil;
+
+namespace parquet {
+
+namespace test {
+
+TEST(VectorBooleanTest, TestEncodeDecode) {
+ // PARQUET-454
+ int nvalues = 10000;
+ int nbytes = static_cast<int>(BitUtil::BytesForBits(nvalues));
+
+ std::vector<bool> draws;
+ ::arrow::random_is_valid(nvalues, 0.5 /* null prob */, &draws, 0 /* seed */);
+
+ std::unique_ptr<BooleanEncoder> encoder =
+ MakeTypedEncoder<BooleanType>(Encoding::PLAIN);
+ encoder->Put(draws, nvalues);
+
+ std::unique_ptr<BooleanDecoder> decoder =
+ MakeTypedDecoder<BooleanType>(Encoding::PLAIN);
+
+ std::shared_ptr<Buffer> encode_buffer = encoder->FlushValues();
+ ASSERT_EQ(nbytes, encode_buffer->size());
+
+ std::vector<uint8_t> decode_buffer(nbytes);
+ const uint8_t* decode_data = &decode_buffer[0];
+
+ decoder->SetData(nvalues, encode_buffer->data(),
+ static_cast<int>(encode_buffer->size()));
+ int values_decoded = decoder->Decode(&decode_buffer[0], nvalues);
+ ASSERT_EQ(nvalues, values_decoded);
+
+ for (int i = 0; i < nvalues; ++i) {
+ ASSERT_EQ(draws[i], ::arrow::BitUtil::GetBit(decode_data, i)) << i;
+ }
+}
+
+// ----------------------------------------------------------------------
+// test data generation
+
+template <typename T>
+void GenerateData(int num_values, T* out, std::vector<uint8_t>* heap) {
+ // seed the prng so failure is deterministic
+ random_numbers(num_values, 0, std::numeric_limits<T>::min(),
+ std::numeric_limits<T>::max(), out);
+}
+
+template <>
+void GenerateData<bool>(int num_values, bool* out, std::vector<uint8_t>* heap) {
+ // seed the prng so failure is deterministic
+ random_bools(num_values, 0.5, 0, out);
+}
+
+template <>
+void GenerateData<Int96>(int num_values, Int96* out, std::vector<uint8_t>* heap) {
+ // seed the prng so failure is deterministic
+ random_Int96_numbers(num_values, 0, std::numeric_limits<int32_t>::min(),
+ std::numeric_limits<int32_t>::max(), out);
+}
+
+template <>
+void GenerateData<ByteArray>(int num_values, ByteArray* out, std::vector<uint8_t>* heap) {
+ // seed the prng so failure is deterministic
+ int max_byte_array_len = 12;
+ heap->resize(num_values * max_byte_array_len);
+ random_byte_array(num_values, 0, heap->data(), out, 2, max_byte_array_len);
+}
+
+static int flba_length = 8;
+
+template <>
+void GenerateData<FLBA>(int num_values, FLBA* out, std::vector<uint8_t>* heap) {
+ // seed the prng so failure is deterministic
+ heap->resize(num_values * flba_length);
+ random_fixed_byte_array(num_values, 0, heap->data(), flba_length, out);
+}
+
+template <typename T>
+void VerifyResults(T* result, T* expected, int num_values) {
+ for (int i = 0; i < num_values; ++i) {
+ ASSERT_EQ(expected[i], result[i]) << i;
+ }
+}
+
+template <typename T>
+void VerifyResultsSpaced(T* result, T* expected, int num_values,
+ const uint8_t* valid_bits, int64_t valid_bits_offset) {
+ for (auto i = 0; i < num_values; ++i) {
+ if (BitUtil::GetBit(valid_bits, valid_bits_offset + i)) {
+ ASSERT_EQ(expected[i], result[i]) << i;
+ }
+ }
+}
+
+template <>
+void VerifyResults<FLBA>(FLBA* result, FLBA* expected, int num_values) {
+ for (int i = 0; i < num_values; ++i) {
+ ASSERT_EQ(0, memcmp(expected[i].ptr, result[i].ptr, flba_length)) << i;
+ }
+}
+
+template <>
+void VerifyResultsSpaced<FLBA>(FLBA* result, FLBA* expected, int num_values,
+ const uint8_t* valid_bits, int64_t valid_bits_offset) {
+ for (auto i = 0; i < num_values; ++i) {
+ if (BitUtil::GetBit(valid_bits, valid_bits_offset + i)) {
+ ASSERT_EQ(0, memcmp(expected[i].ptr, result[i].ptr, flba_length)) << i;
+ }
+ }
+}
+
+// ----------------------------------------------------------------------
+// Create some column descriptors
+
+template <typename DType>
+std::shared_ptr<ColumnDescriptor> ExampleDescr() {
+ auto node = schema::PrimitiveNode::Make("name", Repetition::OPTIONAL, DType::type_num);
+ return std::make_shared<ColumnDescriptor>(node, 0, 0);
+}
+
+template <>
+std::shared_ptr<ColumnDescriptor> ExampleDescr<FLBAType>() {
+ auto node = schema::PrimitiveNode::Make("name", Repetition::OPTIONAL,
+ Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::DECIMAL, flba_length, 10, 2);
+ return std::make_shared<ColumnDescriptor>(node, 0, 0);
+}
+
+// ----------------------------------------------------------------------
+// Plain encoding tests
+
+template <typename Type>
+class TestEncodingBase : public ::testing::Test {
+ public:
+ using c_type = typename Type::c_type;
+ static constexpr int TYPE = Type::type_num;
+
+ void SetUp() {
+ descr_ = ExampleDescr<Type>();
+ type_length_ = descr_->type_length();
+ allocator_ = default_memory_pool();
+ }
+
+ void TearDown() {}
+
+ void InitData(int nvalues, int repeats) {
+ num_values_ = nvalues * repeats;
+ input_bytes_.resize(num_values_ * sizeof(c_type));
+ output_bytes_.resize(num_values_ * sizeof(c_type));
+ draws_ = reinterpret_cast<c_type*>(input_bytes_.data());
+ decode_buf_ = reinterpret_cast<c_type*>(output_bytes_.data());
+ GenerateData<c_type>(nvalues, draws_, &data_buffer_);
+
+ // add some repeated values
+ for (int j = 1; j < repeats; ++j) {
+ for (int i = 0; i < nvalues; ++i) {
+ draws_[nvalues * j + i] = draws_[i];
+ }
+ }
+ }
+
+ virtual void CheckRoundtrip() = 0;
+
+ virtual void CheckRoundtripSpaced(const uint8_t* valid_bits,
+ int64_t valid_bits_offset) {}
+
+ void Execute(int nvalues, int repeats) {
+ InitData(nvalues, repeats);
+ CheckRoundtrip();
+ }
+
+ void ExecuteSpaced(int nvalues, int repeats, int64_t valid_bits_offset,
+ double null_probability) {
+ InitData(nvalues, repeats);
+
+ int64_t size = num_values_ + valid_bits_offset;
+ auto rand = ::arrow::random::RandomArrayGenerator(1923);
+ const auto array = rand.UInt8(size, 0, 100, null_probability);
+ const auto valid_bits = array->null_bitmap_data();
+ if (valid_bits) {
+ CheckRoundtripSpaced(valid_bits, valid_bits_offset);
+ }
+ }
+
+ protected:
+ MemoryPool* allocator_;
+
+ int num_values_;
+ int type_length_;
+ c_type* draws_;
+ c_type* decode_buf_;
+ std::vector<uint8_t> input_bytes_;
+ std::vector<uint8_t> output_bytes_;
+ std::vector<uint8_t> data_buffer_;
+
+ std::shared_ptr<Buffer> encode_buffer_;
+ std::shared_ptr<ColumnDescriptor> descr_;
+};
+
+// Member variables are not visible to templated subclasses. Possibly figure
+// out an alternative to this class layering at some point
+#define USING_BASE_MEMBERS() \
+ using TestEncodingBase<Type>::allocator_; \
+ using TestEncodingBase<Type>::descr_; \
+ using TestEncodingBase<Type>::num_values_; \
+ using TestEncodingBase<Type>::draws_; \
+ using TestEncodingBase<Type>::data_buffer_; \
+ using TestEncodingBase<Type>::type_length_; \
+ using TestEncodingBase<Type>::encode_buffer_; \
+ using TestEncodingBase<Type>::decode_buf_;
+
+template <typename Type>
+class TestPlainEncoding : public TestEncodingBase<Type> {
+ public:
+ using c_type = typename Type::c_type;
+ static constexpr int TYPE = Type::type_num;
+
+ virtual void CheckRoundtrip() {
+ auto encoder = MakeTypedEncoder<Type>(Encoding::PLAIN, false, descr_.get());
+ auto decoder = MakeTypedDecoder<Type>(Encoding::PLAIN, descr_.get());
+ encoder->Put(draws_, num_values_);
+ encode_buffer_ = encoder->FlushValues();
+
+ decoder->SetData(num_values_, encode_buffer_->data(),
+ static_cast<int>(encode_buffer_->size()));
+ int values_decoded = decoder->Decode(decode_buf_, num_values_);
+ ASSERT_EQ(num_values_, values_decoded);
+ ASSERT_NO_FATAL_FAILURE(VerifyResults<c_type>(decode_buf_, draws_, num_values_));
+ }
+
+ void CheckRoundtripSpaced(const uint8_t* valid_bits, int64_t valid_bits_offset) {
+ auto encoder = MakeTypedEncoder<Type>(Encoding::PLAIN, false, descr_.get());
+ auto decoder = MakeTypedDecoder<Type>(Encoding::PLAIN, descr_.get());
+ int null_count = 0;
+ for (auto i = 0; i < num_values_; i++) {
+ if (!BitUtil::GetBit(valid_bits, valid_bits_offset + i)) {
+ null_count++;
+ }
+ }
+
+ encoder->PutSpaced(draws_, num_values_, valid_bits, valid_bits_offset);
+ encode_buffer_ = encoder->FlushValues();
+ decoder->SetData(num_values_ - null_count, encode_buffer_->data(),
+ static_cast<int>(encode_buffer_->size()));
+ auto values_decoded = decoder->DecodeSpaced(decode_buf_, num_values_, null_count,
+ valid_bits, valid_bits_offset);
+ ASSERT_EQ(num_values_, values_decoded);
+ ASSERT_NO_FATAL_FAILURE(VerifyResultsSpaced<c_type>(decode_buf_, draws_, num_values_,
+ valid_bits, valid_bits_offset));
+ }
+
+ protected:
+ USING_BASE_MEMBERS();
+};
+
+TYPED_TEST_SUITE(TestPlainEncoding, ParquetTypes);
+
+TYPED_TEST(TestPlainEncoding, BasicRoundTrip) {
+ ASSERT_NO_FATAL_FAILURE(this->Execute(10000, 1));
+
+ // Spaced test with different sizes and offest to guarantee SIMD implementation
+ constexpr int kAvx512Size = 64; // sizeof(__m512i) for Avx512
+ constexpr int kSimdSize = kAvx512Size; // Current the max is Avx512
+ constexpr int kMultiSimdSize = kSimdSize * 33;
+
+ for (auto null_prob : {0.001, 0.1, 0.5, 0.9, 0.999}) {
+ // Test with both size and offset up to 3 Simd block
+ for (auto i = 1; i < kSimdSize * 3; i++) {
+ ASSERT_NO_FATAL_FAILURE(this->ExecuteSpaced(i, 1, 0, null_prob));
+ ASSERT_NO_FATAL_FAILURE(this->ExecuteSpaced(i, 1, i + 1, null_prob));
+ }
+ // Large block and offset
+ ASSERT_NO_FATAL_FAILURE(this->ExecuteSpaced(kMultiSimdSize, 1, 0, null_prob));
+ ASSERT_NO_FATAL_FAILURE(this->ExecuteSpaced(kMultiSimdSize + 33, 1, 0, null_prob));
+ ASSERT_NO_FATAL_FAILURE(this->ExecuteSpaced(kMultiSimdSize, 1, 33, null_prob));
+ ASSERT_NO_FATAL_FAILURE(this->ExecuteSpaced(kMultiSimdSize + 33, 1, 33, null_prob));
+ }
+}
+
+// ----------------------------------------------------------------------
+// Dictionary encoding tests
+
+typedef ::testing::Types<Int32Type, Int64Type, Int96Type, FloatType, DoubleType,
+ ByteArrayType, FLBAType>
+ DictEncodedTypes;
+
+template <typename Type>
+class TestDictionaryEncoding : public TestEncodingBase<Type> {
+ public:
+ using c_type = typename Type::c_type;
+ static constexpr int TYPE = Type::type_num;
+
+ void CheckRoundtrip() {
+ std::vector<uint8_t> valid_bits(::arrow::BitUtil::BytesForBits(num_values_) + 1, 255);
+
+ auto base_encoder = MakeEncoder(Type::type_num, Encoding::PLAIN, true, descr_.get());
+ auto encoder =
+ dynamic_cast<typename EncodingTraits<Type>::Encoder*>(base_encoder.get());
+ auto dict_traits = dynamic_cast<DictEncoder<Type>*>(base_encoder.get());
+
+ ASSERT_NO_THROW(encoder->Put(draws_, num_values_));
+ dict_buffer_ =
+ AllocateBuffer(default_memory_pool(), dict_traits->dict_encoded_size());
+ dict_traits->WriteDict(dict_buffer_->mutable_data());
+ std::shared_ptr<Buffer> indices = encoder->FlushValues();
+
+ auto base_spaced_encoder =
+ MakeEncoder(Type::type_num, Encoding::PLAIN, true, descr_.get());
+ auto spaced_encoder =
+ dynamic_cast<typename EncodingTraits<Type>::Encoder*>(base_spaced_encoder.get());
+
+ // PutSpaced should lead to the same results
+ // This also checks the PutSpaced implementation for valid_bits=nullptr
+ ASSERT_NO_THROW(spaced_encoder->PutSpaced(draws_, num_values_, nullptr, 0));
+ std::shared_ptr<Buffer> indices_from_spaced = spaced_encoder->FlushValues();
+ ASSERT_TRUE(indices_from_spaced->Equals(*indices));
+
+ auto dict_decoder = MakeTypedDecoder<Type>(Encoding::PLAIN, descr_.get());
+ dict_decoder->SetData(dict_traits->num_entries(), dict_buffer_->data(),
+ static_cast<int>(dict_buffer_->size()));
+
+ auto decoder = MakeDictDecoder<Type>(descr_.get());
+ decoder->SetDict(dict_decoder.get());
+
+ decoder->SetData(num_values_, indices->data(), static_cast<int>(indices->size()));
+ int values_decoded = decoder->Decode(decode_buf_, num_values_);
+ ASSERT_EQ(num_values_, values_decoded);
+
+ // TODO(wesm): The DictionaryDecoder must stay alive because the decoded
+ // values' data is owned by a buffer inside the DictionaryEncoder. We
+ // should revisit when data lifetime is reviewed more generally.
+ ASSERT_NO_FATAL_FAILURE(VerifyResults<c_type>(decode_buf_, draws_, num_values_));
+
+ // Also test spaced decoding
+ decoder->SetData(num_values_, indices->data(), static_cast<int>(indices->size()));
+ // Also tests DecodeSpaced handling for valid_bits=nullptr
+ values_decoded = decoder->DecodeSpaced(decode_buf_, num_values_, 0, nullptr, 0);
+ ASSERT_EQ(num_values_, values_decoded);
+ ASSERT_NO_FATAL_FAILURE(VerifyResults<c_type>(decode_buf_, draws_, num_values_));
+ }
+
+ protected:
+ USING_BASE_MEMBERS();
+ std::shared_ptr<ResizableBuffer> dict_buffer_;
+};
+
+TYPED_TEST_SUITE(TestDictionaryEncoding, DictEncodedTypes);
+
+TYPED_TEST(TestDictionaryEncoding, BasicRoundTrip) {
+ ASSERT_NO_FATAL_FAILURE(this->Execute(2500, 2));
+}
+
+TEST(TestDictionaryEncoding, CannotDictDecodeBoolean) {
+ ASSERT_THROW(MakeDictDecoder<BooleanType>(nullptr), ParquetException);
+}
+
+// ----------------------------------------------------------------------
+// Shared arrow builder decode tests
+
+class TestArrowBuilderDecoding : public ::testing::Test {
+ public:
+ using DenseBuilder = ::arrow::internal::ChunkedBinaryBuilder;
+ using DictBuilder = ::arrow::BinaryDictionary32Builder;
+
+ void SetUp() override { null_probabilities_ = {0.0, 0.5, 1.0}; }
+ void TearDown() override {}
+
+ void InitTestCase(double null_probability) {
+ GenerateInputData(null_probability);
+ SetupEncoderDecoder();
+ }
+
+ void GenerateInputData(double null_probability) {
+ constexpr int num_unique = 100;
+ constexpr int repeat = 100;
+ constexpr int64_t min_length = 2;
+ constexpr int64_t max_length = 10;
+ ::arrow::random::RandomArrayGenerator rag(0);
+ expected_dense_ = rag.BinaryWithRepeats(repeat * num_unique, num_unique, min_length,
+ max_length, null_probability);
+
+ num_values_ = static_cast<int>(expected_dense_->length());
+ null_count_ = static_cast<int>(expected_dense_->null_count());
+ valid_bits_ = expected_dense_->null_bitmap_data();
+
+ auto builder = CreateDictBuilder();
+ ASSERT_OK(builder->AppendArray(*expected_dense_));
+ ASSERT_OK(builder->Finish(&expected_dict_));
+
+ // Initialize input_data_ for the encoder from the expected_array_ values
+ const auto& binary_array = static_cast<const ::arrow::BinaryArray&>(*expected_dense_);
+ input_data_.resize(binary_array.length());
+
+ for (int64_t i = 0; i < binary_array.length(); ++i) {
+ auto view = binary_array.GetView(i);
+ input_data_[i] = {static_cast<uint32_t>(view.length()),
+ reinterpret_cast<const uint8_t*>(view.data())};
+ }
+ }
+
+ std::unique_ptr<DictBuilder> CreateDictBuilder() {
+ return std::unique_ptr<DictBuilder>(new DictBuilder(default_memory_pool()));
+ }
+
+ // Setup encoder/decoder pair for testing with
+ virtual void SetupEncoderDecoder() = 0;
+
+ void CheckDense(int actual_num_values, const ::arrow::Array& chunk) {
+ ASSERT_EQ(actual_num_values, num_values_ - null_count_);
+ ASSERT_ARRAYS_EQUAL(chunk, *expected_dense_);
+ }
+
+ template <typename Builder>
+ void CheckDict(int actual_num_values, Builder& builder) {
+ ASSERT_EQ(actual_num_values, num_values_ - null_count_);
+ std::shared_ptr<::arrow::Array> actual;
+ ASSERT_OK(builder.Finish(&actual));
+ ASSERT_ARRAYS_EQUAL(*actual, *expected_dict_);
+ }
+
+ void CheckDecodeArrowUsingDenseBuilder() {
+ for (auto np : null_probabilities_) {
+ InitTestCase(np);
+
+ typename EncodingTraits<ByteArrayType>::Accumulator acc;
+ acc.builder.reset(new ::arrow::BinaryBuilder);
+ auto actual_num_values =
+ decoder_->DecodeArrow(num_values_, null_count_, valid_bits_, 0, &acc);
+
+ std::shared_ptr<::arrow::Array> chunk;
+ ASSERT_OK(acc.builder->Finish(&chunk));
+ CheckDense(actual_num_values, *chunk);
+ }
+ }
+
+ void CheckDecodeArrowUsingDictBuilder() {
+ for (auto np : null_probabilities_) {
+ InitTestCase(np);
+ auto builder = CreateDictBuilder();
+ auto actual_num_values =
+ decoder_->DecodeArrow(num_values_, null_count_, valid_bits_, 0, builder.get());
+ CheckDict(actual_num_values, *builder);
+ }
+ }
+
+ void CheckDecodeArrowNonNullUsingDenseBuilder() {
+ for (auto np : null_probabilities_) {
+ InitTestCase(np);
+ if (null_count_ > 0) {
+ continue;
+ }
+ typename EncodingTraits<ByteArrayType>::Accumulator acc;
+ acc.builder.reset(new ::arrow::BinaryBuilder);
+ auto actual_num_values = decoder_->DecodeArrowNonNull(num_values_, &acc);
+ std::shared_ptr<::arrow::Array> chunk;
+ ASSERT_OK(acc.builder->Finish(&chunk));
+ CheckDense(actual_num_values, *chunk);
+ }
+ }
+
+ void CheckDecodeArrowNonNullUsingDictBuilder() {
+ for (auto np : null_probabilities_) {
+ InitTestCase(np);
+ if (null_count_ > 0) {
+ continue;
+ }
+ auto builder = CreateDictBuilder();
+ auto actual_num_values = decoder_->DecodeArrowNonNull(num_values_, builder.get());
+ CheckDict(actual_num_values, *builder);
+ }
+ }
+
+ protected:
+ std::vector<double> null_probabilities_;
+ std::shared_ptr<::arrow::Array> expected_dict_;
+ std::shared_ptr<::arrow::Array> expected_dense_;
+ int num_values_;
+ int null_count_;
+ std::vector<ByteArray> input_data_;
+ const uint8_t* valid_bits_;
+ std::unique_ptr<ByteArrayEncoder> encoder_;
+ ByteArrayDecoder* decoder_;
+ std::unique_ptr<ByteArrayDecoder> plain_decoder_;
+ std::unique_ptr<DictDecoder<ByteArrayType>> dict_decoder_;
+ std::shared_ptr<Buffer> buffer_;
+};
+
+class PlainEncoding : public TestArrowBuilderDecoding {
+ public:
+ void SetupEncoderDecoder() override {
+ encoder_ = MakeTypedEncoder<ByteArrayType>(Encoding::PLAIN);
+ plain_decoder_ = MakeTypedDecoder<ByteArrayType>(Encoding::PLAIN);
+ decoder_ = plain_decoder_.get();
+ if (valid_bits_ != nullptr) {
+ ASSERT_NO_THROW(
+ encoder_->PutSpaced(input_data_.data(), num_values_, valid_bits_, 0));
+ } else {
+ ASSERT_NO_THROW(encoder_->Put(input_data_.data(), num_values_));
+ }
+ buffer_ = encoder_->FlushValues();
+ decoder_->SetData(num_values_, buffer_->data(), static_cast<int>(buffer_->size()));
+ }
+};
+
+TEST_F(PlainEncoding, CheckDecodeArrowUsingDenseBuilder) {
+ this->CheckDecodeArrowUsingDenseBuilder();
+}
+
+TEST_F(PlainEncoding, CheckDecodeArrowUsingDictBuilder) {
+ this->CheckDecodeArrowUsingDictBuilder();
+}
+
+TEST_F(PlainEncoding, CheckDecodeArrowNonNullDenseBuilder) {
+ this->CheckDecodeArrowNonNullUsingDenseBuilder();
+}
+
+TEST_F(PlainEncoding, CheckDecodeArrowNonNullDictBuilder) {
+ this->CheckDecodeArrowNonNullUsingDictBuilder();
+}
+
+TEST(PlainEncodingAdHoc, ArrowBinaryDirectPut) {
+ // Implemented as part of ARROW-3246
+
+ const int64_t size = 50;
+ const int32_t min_length = 0;
+ const int32_t max_length = 10;
+ const double null_probability = 0.25;
+
+ auto CheckSeed = [&](int seed) {
+ ::arrow::random::RandomArrayGenerator rag(seed);
+ auto values = rag.String(size, min_length, max_length, null_probability);
+
+ auto encoder = MakeTypedEncoder<ByteArrayType>(Encoding::PLAIN);
+ auto decoder = MakeTypedDecoder<ByteArrayType>(Encoding::PLAIN);
+
+ ASSERT_NO_THROW(encoder->Put(*values));
+ auto buf = encoder->FlushValues();
+
+ int num_values = static_cast<int>(values->length() - values->null_count());
+ decoder->SetData(num_values, buf->data(), static_cast<int>(buf->size()));
+
+ typename EncodingTraits<ByteArrayType>::Accumulator acc;
+ acc.builder.reset(new ::arrow::StringBuilder);
+ ASSERT_EQ(num_values,
+ decoder->DecodeArrow(static_cast<int>(values->length()),
+ static_cast<int>(values->null_count()),
+ values->null_bitmap_data(), values->offset(), &acc));
+
+ std::shared_ptr<::arrow::Array> result;
+ ASSERT_OK(acc.builder->Finish(&result));
+ ASSERT_EQ(50, result->length());
+ ::arrow::AssertArraysEqual(*values, *result);
+ };
+
+ for (auto seed : {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) {
+ CheckSeed(seed);
+ }
+}
+
+template <typename T>
+void GetDictDecoder(DictEncoder<T>* encoder, int64_t num_values,
+ std::shared_ptr<Buffer>* out_values,
+ std::shared_ptr<Buffer>* out_dict, const ColumnDescriptor* descr,
+ std::unique_ptr<TypedDecoder<T>>* out_decoder) {
+ auto decoder = MakeDictDecoder<T>(descr);
+ auto buf = encoder->FlushValues();
+ auto dict_buf = AllocateBuffer(default_memory_pool(), encoder->dict_encoded_size());
+ encoder->WriteDict(dict_buf->mutable_data());
+
+ auto dict_decoder = MakeTypedDecoder<T>(Encoding::PLAIN, descr);
+ dict_decoder->SetData(encoder->num_entries(), dict_buf->data(),
+ static_cast<int>(dict_buf->size()));
+
+ decoder->SetData(static_cast<int>(num_values), buf->data(),
+ static_cast<int>(buf->size()));
+ decoder->SetDict(dict_decoder.get());
+
+ *out_values = buf;
+ *out_dict = dict_buf;
+ ASSERT_NE(decoder, nullptr);
+ auto released = dynamic_cast<TypedDecoder<T>*>(decoder.release());
+ ASSERT_NE(released, nullptr);
+ *out_decoder = std::unique_ptr<TypedDecoder<T>>(released);
+}
+
+template <typename ParquetType>
+class EncodingAdHocTyped : public ::testing::Test {
+ public:
+ using ArrowType = typename EncodingTraits<ParquetType>::ArrowType;
+ using EncoderType = typename EncodingTraits<ParquetType>::Encoder;
+ using DecoderType = typename EncodingTraits<ParquetType>::Decoder;
+ using BuilderType = typename EncodingTraits<ParquetType>::Accumulator;
+ using DictBuilderType = typename EncodingTraits<ParquetType>::DictAccumulator;
+
+ static const ColumnDescriptor* column_descr() {
+ static auto column_descr = ExampleDescr<ParquetType>();
+ return column_descr.get();
+ }
+
+ std::shared_ptr<::arrow::Array> GetValues(int seed);
+
+ static std::shared_ptr<::arrow::DataType> arrow_type();
+
+ void Plain(int seed) {
+ auto values = GetValues(seed);
+ auto encoder = MakeTypedEncoder<ParquetType>(
+ Encoding::PLAIN, /*use_dictionary=*/false, column_descr());
+ auto decoder = MakeTypedDecoder<ParquetType>(Encoding::PLAIN, column_descr());
+
+ ASSERT_NO_THROW(encoder->Put(*values));
+ auto buf = encoder->FlushValues();
+
+ int num_values = static_cast<int>(values->length() - values->null_count());
+ decoder->SetData(num_values, buf->data(), static_cast<int>(buf->size()));
+
+ BuilderType acc(arrow_type(), ::arrow::default_memory_pool());
+ ASSERT_EQ(num_values,
+ decoder->DecodeArrow(static_cast<int>(values->length()),
+ static_cast<int>(values->null_count()),
+ values->null_bitmap_data(), values->offset(), &acc));
+
+ std::shared_ptr<::arrow::Array> result;
+ ASSERT_OK(acc.Finish(&result));
+ ASSERT_EQ(50, result->length());
+ ::arrow::AssertArraysEqual(*values, *result, /*verbose=*/true);
+ }
+
+ void ByteStreamSplit(int seed) {
+ if (!std::is_same<ParquetType, FloatType>::value &&
+ !std::is_same<ParquetType, DoubleType>::value) {
+ return;
+ }
+ auto values = GetValues(seed);
+ auto encoder = MakeTypedEncoder<ParquetType>(
+ Encoding::BYTE_STREAM_SPLIT, /*use_dictionary=*/false, column_descr());
+ auto decoder =
+ MakeTypedDecoder<ParquetType>(Encoding::BYTE_STREAM_SPLIT, column_descr());
+
+ ASSERT_NO_THROW(encoder->Put(*values));
+ auto buf = encoder->FlushValues();
+
+ int num_values = static_cast<int>(values->length() - values->null_count());
+ decoder->SetData(num_values, buf->data(), static_cast<int>(buf->size()));
+
+ BuilderType acc(arrow_type(), ::arrow::default_memory_pool());
+ ASSERT_EQ(num_values,
+ decoder->DecodeArrow(static_cast<int>(values->length()),
+ static_cast<int>(values->null_count()),
+ values->null_bitmap_data(), values->offset(), &acc));
+
+ std::shared_ptr<::arrow::Array> result;
+ ASSERT_OK(acc.Finish(&result));
+ ASSERT_EQ(50, result->length());
+ ::arrow::AssertArraysEqual(*values, *result);
+ }
+
+ void Dict(int seed) {
+ if (std::is_same<ParquetType, BooleanType>::value) {
+ return;
+ }
+
+ auto values = GetValues(seed);
+
+ auto owned_encoder =
+ MakeTypedEncoder<ParquetType>(Encoding::PLAIN,
+ /*use_dictionary=*/true, column_descr());
+ auto encoder = dynamic_cast<DictEncoder<ParquetType>*>(owned_encoder.get());
+
+ ASSERT_NO_THROW(encoder->Put(*values));
+
+ std::shared_ptr<Buffer> buf, dict_buf;
+ int num_values = static_cast<int>(values->length() - values->null_count());
+
+ std::unique_ptr<TypedDecoder<ParquetType>> decoder;
+ GetDictDecoder(encoder, num_values, &buf, &dict_buf, column_descr(), &decoder);
+
+ BuilderType acc(arrow_type(), ::arrow::default_memory_pool());
+ ASSERT_EQ(num_values,
+ decoder->DecodeArrow(static_cast<int>(values->length()),
+ static_cast<int>(values->null_count()),
+ values->null_bitmap_data(), values->offset(), &acc));
+
+ std::shared_ptr<::arrow::Array> result;
+ ASSERT_OK(acc.Finish(&result));
+ ::arrow::AssertArraysEqual(*values, *result);
+ }
+
+ void DictPutIndices() {
+ if (std::is_same<ParquetType, BooleanType>::value) {
+ return;
+ }
+
+ auto dict_values = ::arrow::ArrayFromJSON(
+ arrow_type(), std::is_same<ParquetType, FLBAType>::value
+ ? R"(["abcdefgh", "ijklmnop", "qrstuvwx"])"
+ : "[120, -37, 47]");
+ auto indices = ::arrow::ArrayFromJSON(::arrow::int32(), "[0, 1, 2]");
+ auto indices_nulls =
+ ::arrow::ArrayFromJSON(::arrow::int32(), "[null, 0, 1, null, 2]");
+
+ auto expected = ::arrow::ArrayFromJSON(
+ arrow_type(), std::is_same<ParquetType, FLBAType>::value
+ ? R"(["abcdefgh", "ijklmnop", "qrstuvwx", null,
+ "abcdefgh", "ijklmnop", null, "qrstuvwx"])"
+ : "[120, -37, 47, null, "
+ "120, -37, null, 47]");
+
+ auto owned_encoder =
+ MakeTypedEncoder<ParquetType>(Encoding::PLAIN,
+ /*use_dictionary=*/true, column_descr());
+ auto owned_decoder = MakeDictDecoder<ParquetType>();
+
+ auto encoder = dynamic_cast<DictEncoder<ParquetType>*>(owned_encoder.get());
+
+ ASSERT_NO_THROW(encoder->PutDictionary(*dict_values));
+
+ // Trying to call PutDictionary again throws
+ ASSERT_THROW(encoder->PutDictionary(*dict_values), ParquetException);
+
+ ASSERT_NO_THROW(encoder->PutIndices(*indices));
+ ASSERT_NO_THROW(encoder->PutIndices(*indices_nulls));
+
+ std::shared_ptr<Buffer> buf, dict_buf;
+ int num_values = static_cast<int>(expected->length() - expected->null_count());
+
+ std::unique_ptr<TypedDecoder<ParquetType>> decoder;
+ GetDictDecoder(encoder, num_values, &buf, &dict_buf, column_descr(), &decoder);
+
+ BuilderType acc(arrow_type(), ::arrow::default_memory_pool());
+ ASSERT_EQ(num_values, decoder->DecodeArrow(static_cast<int>(expected->length()),
+ static_cast<int>(expected->null_count()),
+ expected->null_bitmap_data(),
+ expected->offset(), &acc));
+
+ std::shared_ptr<::arrow::Array> result;
+ ASSERT_OK(acc.Finish(&result));
+ ::arrow::AssertArraysEqual(*expected, *result);
+ }
+
+ protected:
+ const int64_t size_ = 50;
+ const double null_probability_ = 0.25;
+};
+
+template <typename ParquetType>
+std::shared_ptr<::arrow::DataType> EncodingAdHocTyped<ParquetType>::arrow_type() {
+ return ::arrow::TypeTraits<ArrowType>::type_singleton();
+}
+
+template <>
+std::shared_ptr<::arrow::DataType> EncodingAdHocTyped<FLBAType>::arrow_type() {
+ return ::arrow::fixed_size_binary(sizeof(uint64_t));
+}
+
+template <typename ParquetType>
+std::shared_ptr<::arrow::Array> EncodingAdHocTyped<ParquetType>::GetValues(int seed) {
+ ::arrow::random::RandomArrayGenerator rag(seed);
+ return rag.Numeric<ArrowType>(size_, 0, 10, null_probability_);
+}
+
+template <>
+std::shared_ptr<::arrow::Array> EncodingAdHocTyped<BooleanType>::GetValues(int seed) {
+ ::arrow::random::RandomArrayGenerator rag(seed);
+ return rag.Boolean(size_, 0.1, null_probability_);
+}
+
+template <>
+std::shared_ptr<::arrow::Array> EncodingAdHocTyped<FLBAType>::GetValues(int seed) {
+ ::arrow::random::RandomArrayGenerator rag(seed);
+ std::shared_ptr<::arrow::Array> values;
+ ARROW_EXPECT_OK(
+ rag.UInt64(size_, 0, std::numeric_limits<uint64_t>::max(), null_probability_)
+ ->View(arrow_type())
+ .Value(&values));
+ return values;
+}
+
+using EncodingAdHocTypedCases =
+ ::testing::Types<BooleanType, Int32Type, Int64Type, FloatType, DoubleType, FLBAType>;
+
+TYPED_TEST_SUITE(EncodingAdHocTyped, EncodingAdHocTypedCases);
+
+TYPED_TEST(EncodingAdHocTyped, PlainArrowDirectPut) {
+ for (auto seed : {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) {
+ this->Plain(seed);
+ }
+}
+
+TYPED_TEST(EncodingAdHocTyped, ByteStreamSplitArrowDirectPut) {
+ for (auto seed : {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) {
+ this->ByteStreamSplit(seed);
+ }
+}
+
+TEST(DictEncodingAdHoc, ArrowBinaryDirectPut) {
+ // Implemented as part of ARROW-3246
+ const int64_t size = 50;
+ const int64_t min_length = 0;
+ const int64_t max_length = 10;
+ const double null_probability = 0.1;
+ ::arrow::random::RandomArrayGenerator rag(0);
+ auto values = rag.String(size, min_length, max_length, null_probability);
+
+ auto owned_encoder = MakeTypedEncoder<ByteArrayType>(Encoding::PLAIN,
+ /*use_dictionary=*/true);
+
+ auto encoder = dynamic_cast<DictEncoder<ByteArrayType>*>(owned_encoder.get());
+
+ ASSERT_NO_THROW(encoder->Put(*values));
+
+ std::unique_ptr<ByteArrayDecoder> decoder;
+ std::shared_ptr<Buffer> buf, dict_buf;
+ int num_values = static_cast<int>(values->length() - values->null_count());
+ GetDictDecoder(encoder, num_values, &buf, &dict_buf, nullptr, &decoder);
+
+ typename EncodingTraits<ByteArrayType>::Accumulator acc;
+ acc.builder.reset(new ::arrow::StringBuilder);
+ ASSERT_EQ(num_values,
+ decoder->DecodeArrow(static_cast<int>(values->length()),
+ static_cast<int>(values->null_count()),
+ values->null_bitmap_data(), values->offset(), &acc));
+
+ std::shared_ptr<::arrow::Array> result;
+ ASSERT_OK(acc.builder->Finish(&result));
+ ::arrow::AssertArraysEqual(*values, *result);
+}
+
+TYPED_TEST(EncodingAdHocTyped, DictArrowDirectPut) { this->Dict(0); }
+
+TEST(DictEncodingAdHoc, PutDictionaryPutIndices) {
+ // Part of ARROW-3246
+ auto dict_values =
+ ::arrow::ArrayFromJSON(::arrow::binary(), "[\"foo\", \"bar\", \"baz\"]");
+
+ auto CheckIndexType = [&](const std::shared_ptr<::arrow::DataType>& index_ty) {
+ auto indices = ::arrow::ArrayFromJSON(index_ty, "[0, 1, 2]");
+ auto indices_nulls = ::arrow::ArrayFromJSON(index_ty, "[null, 0, 1, null, 2]");
+
+ auto expected = ::arrow::ArrayFromJSON(::arrow::binary(),
+ "[\"foo\", \"bar\", \"baz\", null, "
+ "\"foo\", \"bar\", null, \"baz\"]");
+
+ auto owned_encoder = MakeTypedEncoder<ByteArrayType>(Encoding::PLAIN,
+ /*use_dictionary=*/true);
+ auto owned_decoder = MakeDictDecoder<ByteArrayType>();
+
+ auto encoder = dynamic_cast<DictEncoder<ByteArrayType>*>(owned_encoder.get());
+
+ ASSERT_NO_THROW(encoder->PutDictionary(*dict_values));
+
+ // Trying to call PutDictionary again throws
+ ASSERT_THROW(encoder->PutDictionary(*dict_values), ParquetException);
+
+ ASSERT_NO_THROW(encoder->PutIndices(*indices));
+ ASSERT_NO_THROW(encoder->PutIndices(*indices_nulls));
+
+ std::unique_ptr<ByteArrayDecoder> decoder;
+ std::shared_ptr<Buffer> buf, dict_buf;
+ int num_values = static_cast<int>(expected->length() - expected->null_count());
+ GetDictDecoder(encoder, num_values, &buf, &dict_buf, nullptr, &decoder);
+
+ typename EncodingTraits<ByteArrayType>::Accumulator acc;
+ acc.builder.reset(new ::arrow::BinaryBuilder);
+ ASSERT_EQ(num_values, decoder->DecodeArrow(static_cast<int>(expected->length()),
+ static_cast<int>(expected->null_count()),
+ expected->null_bitmap_data(),
+ expected->offset(), &acc));
+
+ std::shared_ptr<::arrow::Array> result;
+ ASSERT_OK(acc.builder->Finish(&result));
+ ::arrow::AssertArraysEqual(*expected, *result);
+ };
+
+ for (auto ty : ::arrow::all_dictionary_index_types()) {
+ CheckIndexType(ty);
+ }
+}
+
+TYPED_TEST(EncodingAdHocTyped, DictArrowDirectPutIndices) { this->DictPutIndices(); }
+
+class DictEncoding : public TestArrowBuilderDecoding {
+ public:
+ void SetupEncoderDecoder() override {
+ auto node = schema::ByteArray("name");
+ descr_ = std::unique_ptr<ColumnDescriptor>(new ColumnDescriptor(node, 0, 0));
+ encoder_ = MakeTypedEncoder<ByteArrayType>(Encoding::PLAIN, /*use_dictionary=*/true,
+ descr_.get());
+ if (null_count_ == 0) {
+ ASSERT_NO_THROW(encoder_->Put(input_data_.data(), num_values_));
+ } else {
+ ASSERT_NO_THROW(
+ encoder_->PutSpaced(input_data_.data(), num_values_, valid_bits_, 0));
+ }
+ buffer_ = encoder_->FlushValues();
+
+ auto dict_encoder = dynamic_cast<DictEncoder<ByteArrayType>*>(encoder_.get());
+ ASSERT_NE(dict_encoder, nullptr);
+ dict_buffer_ =
+ AllocateBuffer(default_memory_pool(), dict_encoder->dict_encoded_size());
+ dict_encoder->WriteDict(dict_buffer_->mutable_data());
+
+ // Simulate reading the dictionary page followed by a data page
+ plain_decoder_ = MakeTypedDecoder<ByteArrayType>(Encoding::PLAIN, descr_.get());
+ plain_decoder_->SetData(dict_encoder->num_entries(), dict_buffer_->data(),
+ static_cast<int>(dict_buffer_->size()));
+
+ dict_decoder_ = MakeDictDecoder<ByteArrayType>(descr_.get());
+ dict_decoder_->SetDict(plain_decoder_.get());
+ dict_decoder_->SetData(num_values_, buffer_->data(),
+ static_cast<int>(buffer_->size()));
+ decoder_ = dynamic_cast<ByteArrayDecoder*>(dict_decoder_.get());
+ }
+
+ protected:
+ std::unique_ptr<ColumnDescriptor> descr_;
+ std::shared_ptr<Buffer> dict_buffer_;
+};
+
+TEST_F(DictEncoding, CheckDecodeArrowUsingDenseBuilder) {
+ this->CheckDecodeArrowUsingDenseBuilder();
+}
+
+TEST_F(DictEncoding, CheckDecodeArrowUsingDictBuilder) {
+ this->CheckDecodeArrowUsingDictBuilder();
+}
+
+TEST_F(DictEncoding, CheckDecodeArrowNonNullDenseBuilder) {
+ this->CheckDecodeArrowNonNullUsingDenseBuilder();
+}
+
+TEST_F(DictEncoding, CheckDecodeArrowNonNullDictBuilder) {
+ this->CheckDecodeArrowNonNullUsingDictBuilder();
+}
+
+TEST_F(DictEncoding, CheckDecodeIndicesSpaced) {
+ for (auto np : null_probabilities_) {
+ InitTestCase(np);
+ auto builder = CreateDictBuilder();
+ dict_decoder_->InsertDictionary(builder.get());
+ int actual_num_values;
+ if (null_count_ == 0) {
+ actual_num_values = dict_decoder_->DecodeIndices(num_values_, builder.get());
+ } else {
+ actual_num_values = dict_decoder_->DecodeIndicesSpaced(
+ num_values_, null_count_, valid_bits_, 0, builder.get());
+ }
+ ASSERT_EQ(actual_num_values, num_values_ - null_count_);
+ std::shared_ptr<::arrow::Array> actual;
+ ASSERT_OK(builder->Finish(&actual));
+ ASSERT_ARRAYS_EQUAL(*actual, *expected_dict_);
+
+ // Check that null indices are zero-initialized
+ const auto& dict_actual = checked_cast<const ::arrow::DictionaryArray&>(*actual);
+ const auto& indices =
+ checked_cast<const ::arrow::Int32Array&>(*dict_actual.indices());
+
+ auto raw_values = indices.raw_values();
+ for (int64_t i = 0; i < indices.length(); ++i) {
+ if (indices.IsNull(i) && raw_values[i] != 0) {
+ FAIL() << "Null slot not zero-initialized";
+ }
+ }
+ }
+}
+
+TEST_F(DictEncoding, CheckDecodeIndicesNoNulls) {
+ InitTestCase(/*null_probability=*/0.0);
+ auto builder = CreateDictBuilder();
+ dict_decoder_->InsertDictionary(builder.get());
+ auto actual_num_values = dict_decoder_->DecodeIndices(num_values_, builder.get());
+ CheckDict(actual_num_values, *builder);
+}
+
+// ----------------------------------------------------------------------
+// BYTE_STREAM_SPLIT encode/decode tests.
+
+template <typename Type>
+class TestByteStreamSplitEncoding : public TestEncodingBase<Type> {
+ public:
+ using c_type = typename Type::c_type;
+ static constexpr int TYPE = Type::type_num;
+
+ void CheckRoundtrip() override {
+ auto encoder =
+ MakeTypedEncoder<Type>(Encoding::BYTE_STREAM_SPLIT, false, descr_.get());
+ auto decoder = MakeTypedDecoder<Type>(Encoding::BYTE_STREAM_SPLIT, descr_.get());
+ encoder->Put(draws_, num_values_);
+ encode_buffer_ = encoder->FlushValues();
+
+ {
+ decoder->SetData(num_values_, encode_buffer_->data(),
+ static_cast<int>(encode_buffer_->size()));
+ int values_decoded = decoder->Decode(decode_buf_, num_values_);
+ ASSERT_EQ(num_values_, values_decoded);
+ ASSERT_NO_FATAL_FAILURE(VerifyResults<c_type>(decode_buf_, draws_, num_values_));
+ }
+
+ {
+ // Try again but with a small step.
+ decoder->SetData(num_values_, encode_buffer_->data(),
+ static_cast<int>(encode_buffer_->size()));
+ int step = 131;
+ int remaining = num_values_;
+ for (int i = 0; i < num_values_; i += step) {
+ int num_decoded = decoder->Decode(decode_buf_, step);
+ ASSERT_EQ(num_decoded, std::min(step, remaining));
+ ASSERT_NO_FATAL_FAILURE(
+ VerifyResults<c_type>(decode_buf_, &draws_[i], num_decoded));
+ remaining -= num_decoded;
+ }
+ }
+
+ {
+ std::vector<uint8_t> valid_bits(::arrow::BitUtil::BytesForBits(num_values_), 0);
+ std::vector<c_type> expected_filtered_output;
+ const int every_nth = 5;
+ expected_filtered_output.reserve((num_values_ + every_nth - 1) / every_nth);
+ ::arrow::internal::BitmapWriter writer{valid_bits.data(), 0, num_values_};
+ // Set every fifth bit.
+ for (int i = 0; i < num_values_; ++i) {
+ if (i % every_nth == 0) {
+ writer.Set();
+ expected_filtered_output.push_back(draws_[i]);
+ }
+ writer.Next();
+ }
+ writer.Finish();
+ const int expected_size = static_cast<int>(expected_filtered_output.size());
+ ASSERT_NO_THROW(encoder->PutSpaced(draws_, num_values_, valid_bits.data(), 0));
+ encode_buffer_ = encoder->FlushValues();
+
+ decoder->SetData(expected_size, encode_buffer_->data(),
+ static_cast<int>(encode_buffer_->size()));
+ int values_decoded = decoder->Decode(decode_buf_, num_values_);
+ ASSERT_EQ(expected_size, values_decoded);
+ ASSERT_NO_FATAL_FAILURE(VerifyResults<c_type>(
+ decode_buf_, expected_filtered_output.data(), expected_size));
+ }
+ }
+
+ void CheckDecode();
+ void CheckEncode();
+
+ protected:
+ USING_BASE_MEMBERS();
+
+ void CheckDecode(const uint8_t* encoded_data, const int64_t encoded_data_size,
+ const c_type* expected_decoded_data, const int num_elements) {
+ std::unique_ptr<TypedDecoder<Type>> decoder =
+ MakeTypedDecoder<Type>(Encoding::BYTE_STREAM_SPLIT);
+ decoder->SetData(num_elements, encoded_data, static_cast<int>(encoded_data_size));
+ std::vector<c_type> decoded_data(num_elements);
+ int num_decoded_elements = decoder->Decode(decoded_data.data(), num_elements);
+ ASSERT_EQ(num_elements, num_decoded_elements);
+ for (size_t i = 0U; i < decoded_data.size(); ++i) {
+ ASSERT_EQ(expected_decoded_data[i], decoded_data[i]);
+ }
+ ASSERT_EQ(0, decoder->values_left());
+ }
+
+ void CheckEncode(const c_type* data, const int num_elements,
+ const uint8_t* expected_encoded_data,
+ const int64_t encoded_data_size) {
+ std::unique_ptr<TypedEncoder<Type>> encoder =
+ MakeTypedEncoder<Type>(Encoding::BYTE_STREAM_SPLIT);
+ encoder->Put(data, num_elements);
+ auto encoded_data = encoder->FlushValues();
+ ASSERT_EQ(encoded_data_size, encoded_data->size());
+ const uint8_t* encoded_data_raw = encoded_data->data();
+ for (int64_t i = 0; i < encoded_data->size(); ++i) {
+ ASSERT_EQ(expected_encoded_data[i], encoded_data_raw[i]);
+ }
+ }
+};
+
+template <typename c_type>
+static std::vector<c_type> ToLittleEndian(const std::vector<c_type>& input) {
+ std::vector<c_type> data(input.size());
+ std::transform(input.begin(), input.end(), data.begin(), [](const c_type& value) {
+ return ::arrow::BitUtil::ToLittleEndian(value);
+ });
+ return data;
+}
+
+static_assert(sizeof(float) == sizeof(uint32_t),
+ "BYTE_STREAM_SPLIT encoding tests assume float / uint32_t type sizes");
+static_assert(sizeof(double) == sizeof(uint64_t),
+ "BYTE_STREAM_SPLIT encoding tests assume double / uint64_t type sizes");
+
+template <>
+void TestByteStreamSplitEncoding<FloatType>::CheckDecode() {
+ const uint8_t data[] = {0x11, 0x22, 0x33, 0x44, 0x55, 0x66,
+ 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC};
+ const auto expected_output =
+ ToLittleEndian<uint32_t>({0xAA774411U, 0xBB885522U, 0xCC996633U});
+ CheckDecode(data, static_cast<int64_t>(sizeof(data)),
+ reinterpret_cast<const float*>(expected_output.data()),
+ static_cast<int>(sizeof(data) / sizeof(float)));
+}
+
+template <>
+void TestByteStreamSplitEncoding<DoubleType>::CheckDecode() {
+ const uint8_t data[] = {0xDE, 0xC0, 0x37, 0x13, 0x11, 0x22, 0x33, 0x44,
+ 0xAA, 0xBB, 0xCC, 0xDD, 0x55, 0x66, 0x77, 0x88};
+ const auto expected_output =
+ ToLittleEndian<uint64_t>({0x7755CCAA331137DEULL, 0x8866DDBB442213C0ULL});
+ CheckDecode(data, static_cast<int64_t>(sizeof(data)),
+ reinterpret_cast<const double*>(expected_output.data()),
+ static_cast<int>(sizeof(data) / sizeof(double)));
+}
+
+template <>
+void TestByteStreamSplitEncoding<DoubleType>::CheckEncode() {
+ const auto data = ToLittleEndian<uint64_t>(
+ {0x4142434445464748ULL, 0x0102030405060708ULL, 0xb1b2b3b4b5b6b7b8ULL});
+ const uint8_t expected_output[24] = {
+ 0x48, 0x08, 0xb8, 0x47, 0x07, 0xb7, 0x46, 0x06, 0xb6, 0x45, 0x05, 0xb5,
+ 0x44, 0x04, 0xb4, 0x43, 0x03, 0xb3, 0x42, 0x02, 0xb2, 0x41, 0x01, 0xb1,
+ };
+ CheckEncode(reinterpret_cast<const double*>(data.data()), static_cast<int>(data.size()),
+ expected_output, sizeof(expected_output));
+}
+
+template <>
+void TestByteStreamSplitEncoding<FloatType>::CheckEncode() {
+ const auto data = ToLittleEndian<uint32_t>({0xaabbccdd, 0x11223344});
+ const uint8_t expected_output[8] = {0xdd, 0x44, 0xcc, 0x33, 0xbb, 0x22, 0xaa, 0x11};
+ CheckEncode(reinterpret_cast<const float*>(data.data()), static_cast<int>(data.size()),
+ expected_output, sizeof(expected_output));
+}
+
+typedef ::testing::Types<FloatType, DoubleType> ByteStreamSplitTypes;
+TYPED_TEST_SUITE(TestByteStreamSplitEncoding, ByteStreamSplitTypes);
+
+TYPED_TEST(TestByteStreamSplitEncoding, BasicRoundTrip) {
+ for (int values = 0; values < 32; ++values) {
+ ASSERT_NO_FATAL_FAILURE(this->Execute(values, 1));
+ }
+
+ // We need to test with different sizes to guarantee that the SIMD implementation
+ // can handle both inputs with size divisible by 4/8 and sizes which would
+ // require a scalar loop for the suffix.
+ constexpr size_t kSuffixSize = 7;
+ constexpr size_t kAvx2Size = 32; // sizeof(__m256i) for AVX2
+ constexpr size_t kAvx512Size = 64; // sizeof(__m512i) for AVX512
+ constexpr size_t kMultiSimdSize = kAvx512Size * 7;
+
+ // Exercise only one SIMD loop. SSE and AVX2 covered in above loop.
+ ASSERT_NO_FATAL_FAILURE(this->Execute(kAvx512Size, 1));
+ // Exercise one SIMD loop with suffix. SSE covered in above loop.
+ ASSERT_NO_FATAL_FAILURE(this->Execute(kAvx2Size + kSuffixSize, 1));
+ ASSERT_NO_FATAL_FAILURE(this->Execute(kAvx512Size + kSuffixSize, 1));
+ // Exercise multi SIMD loop.
+ ASSERT_NO_FATAL_FAILURE(this->Execute(kMultiSimdSize, 1));
+ // Exercise multi SIMD loop with suffix.
+ ASSERT_NO_FATAL_FAILURE(this->Execute(kMultiSimdSize + kSuffixSize, 1));
+}
+
+TYPED_TEST(TestByteStreamSplitEncoding, RoundTripSingleElement) {
+ ASSERT_NO_FATAL_FAILURE(this->Execute(1, 1));
+}
+
+TYPED_TEST(TestByteStreamSplitEncoding, CheckOnlyDecode) {
+ ASSERT_NO_FATAL_FAILURE(this->CheckDecode());
+}
+
+TYPED_TEST(TestByteStreamSplitEncoding, CheckOnlyEncode) {
+ ASSERT_NO_FATAL_FAILURE(this->CheckEncode());
+}
+
+TEST(ByteStreamSplitEncodeDecode, InvalidDataTypes) {
+ // First check encoders.
+ ASSERT_THROW(MakeTypedEncoder<Int32Type>(Encoding::BYTE_STREAM_SPLIT),
+ ParquetException);
+ ASSERT_THROW(MakeTypedEncoder<Int64Type>(Encoding::BYTE_STREAM_SPLIT),
+ ParquetException);
+ ASSERT_THROW(MakeTypedEncoder<Int96Type>(Encoding::BYTE_STREAM_SPLIT),
+ ParquetException);
+ ASSERT_THROW(MakeTypedEncoder<BooleanType>(Encoding::BYTE_STREAM_SPLIT),
+ ParquetException);
+ ASSERT_THROW(MakeTypedEncoder<ByteArrayType>(Encoding::BYTE_STREAM_SPLIT),
+ ParquetException);
+ ASSERT_THROW(MakeTypedEncoder<FLBAType>(Encoding::BYTE_STREAM_SPLIT), ParquetException);
+
+ // Then check decoders.
+ ASSERT_THROW(MakeTypedDecoder<Int32Type>(Encoding::BYTE_STREAM_SPLIT),
+ ParquetException);
+ ASSERT_THROW(MakeTypedDecoder<Int64Type>(Encoding::BYTE_STREAM_SPLIT),
+ ParquetException);
+ ASSERT_THROW(MakeTypedDecoder<Int96Type>(Encoding::BYTE_STREAM_SPLIT),
+ ParquetException);
+ ASSERT_THROW(MakeTypedDecoder<BooleanType>(Encoding::BYTE_STREAM_SPLIT),
+ ParquetException);
+ ASSERT_THROW(MakeTypedDecoder<ByteArrayType>(Encoding::BYTE_STREAM_SPLIT),
+ ParquetException);
+ ASSERT_THROW(MakeTypedDecoder<FLBAType>(Encoding::BYTE_STREAM_SPLIT), ParquetException);
+}
+
+} // namespace test
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/CMakeLists.txt b/src/arrow/cpp/src/parquet/encryption/CMakeLists.txt
new file mode 100644
index 000000000..b4c977fcc
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/CMakeLists.txt
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Headers: public api
+arrow_install_all_headers("parquet/encryption")
diff --git a/src/arrow/cpp/src/parquet/encryption/crypto_factory.cc b/src/arrow/cpp/src/parquet/encryption/crypto_factory.cc
new file mode 100644
index 000000000..384516bff
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/crypto_factory.cc
@@ -0,0 +1,175 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/result.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string.h"
+#include "arrow/util/string_view.h"
+
+#include "parquet/encryption/crypto_factory.h"
+#include "parquet/encryption/encryption_internal.h"
+#include "parquet/encryption/file_key_material_store.h"
+#include "parquet/encryption/file_key_unwrapper.h"
+#include "parquet/encryption/key_toolkit_internal.h"
+
+namespace parquet {
+namespace encryption {
+
+void CryptoFactory::RegisterKmsClientFactory(
+ std::shared_ptr<KmsClientFactory> kms_client_factory) {
+ key_toolkit_.RegisterKmsClientFactory(kms_client_factory);
+}
+
+std::shared_ptr<FileEncryptionProperties> CryptoFactory::GetFileEncryptionProperties(
+ const KmsConnectionConfig& kms_connection_config,
+ const EncryptionConfiguration& encryption_config) {
+ if (!encryption_config.uniform_encryption && encryption_config.column_keys.empty()) {
+ throw ParquetException("Either column_keys or uniform_encryption must be set");
+ } else if (encryption_config.uniform_encryption &&
+ !encryption_config.column_keys.empty()) {
+ throw ParquetException("Cannot set both column_keys and uniform_encryption");
+ }
+ const std::string& footer_key_id = encryption_config.footer_key;
+ const std::string& column_key_str = encryption_config.column_keys;
+
+ std::shared_ptr<FileKeyMaterialStore> key_material_store = NULL;
+ if (!encryption_config.internal_key_material) {
+ // TODO: using external key material store with Hadoop file system
+ throw ParquetException("External key material store is not supported yet.");
+ }
+
+ FileKeyWrapper key_wrapper(&key_toolkit_, kms_connection_config, key_material_store,
+ encryption_config.cache_lifetime_seconds,
+ encryption_config.double_wrapping);
+
+ int32_t dek_length_bits = encryption_config.data_key_length_bits;
+ if (!internal::ValidateKeyLength(dek_length_bits)) {
+ std::ostringstream ss;
+ ss << "Wrong data key length : " << dek_length_bits;
+ throw ParquetException(ss.str());
+ }
+
+ int dek_length = dek_length_bits / 8;
+
+ std::string footer_key(dek_length, '\0');
+ RandBytes(reinterpret_cast<uint8_t*>(&footer_key[0]),
+ static_cast<int>(footer_key.size()));
+
+ std::string footer_key_metadata =
+ key_wrapper.GetEncryptionKeyMetadata(footer_key, footer_key_id, true);
+
+ FileEncryptionProperties::Builder properties_builder =
+ FileEncryptionProperties::Builder(footer_key);
+ properties_builder.footer_key_metadata(footer_key_metadata);
+ properties_builder.algorithm(encryption_config.encryption_algorithm);
+
+ if (!encryption_config.uniform_encryption) {
+ ColumnPathToEncryptionPropertiesMap encrypted_columns =
+ GetColumnEncryptionProperties(dek_length, column_key_str, &key_wrapper);
+ properties_builder.encrypted_columns(encrypted_columns);
+
+ if (encryption_config.plaintext_footer) {
+ properties_builder.set_plaintext_footer();
+ }
+ }
+
+ return properties_builder.build();
+}
+
+ColumnPathToEncryptionPropertiesMap CryptoFactory::GetColumnEncryptionProperties(
+ int dek_length, const std::string& column_keys, FileKeyWrapper* key_wrapper) {
+ ColumnPathToEncryptionPropertiesMap encrypted_columns;
+
+ std::vector<::arrow::util::string_view> key_to_columns =
+ ::arrow::internal::SplitString(column_keys, ';');
+ for (size_t i = 0; i < key_to_columns.size(); ++i) {
+ std::string cur_key_to_columns =
+ ::arrow::internal::TrimString(std::string(key_to_columns[i]));
+ if (cur_key_to_columns.empty()) {
+ continue;
+ }
+
+ std::vector<::arrow::util::string_view> parts =
+ ::arrow::internal::SplitString(cur_key_to_columns, ':');
+ if (parts.size() != 2) {
+ std::ostringstream message;
+ message << "Incorrect key to columns mapping in column keys property"
+ << ": [" << cur_key_to_columns << "]";
+ throw ParquetException(message.str());
+ }
+
+ std::string column_key_id = ::arrow::internal::TrimString(std::string(parts[0]));
+ if (column_key_id.empty()) {
+ throw ParquetException("Empty key name in column keys property.");
+ }
+
+ std::string column_names_str = ::arrow::internal::TrimString(std::string(parts[1]));
+ std::vector<::arrow::util::string_view> column_names =
+ ::arrow::internal::SplitString(column_names_str, ',');
+ if (0 == column_names.size()) {
+ throw ParquetException("No columns to encrypt defined for key: " + column_key_id);
+ }
+
+ for (size_t j = 0; j < column_names.size(); ++j) {
+ std::string column_name =
+ ::arrow::internal::TrimString(std::string(column_names[j]));
+ if (column_name.empty()) {
+ std::ostringstream message;
+ message << "Empty column name in column keys property for key: " << column_key_id;
+ throw ParquetException(message.str());
+ }
+
+ if (encrypted_columns.find(column_name) != encrypted_columns.end()) {
+ throw ParquetException("Multiple keys defined for the same column: " +
+ column_name);
+ }
+
+ std::string column_key(dek_length, '\0');
+ RandBytes(reinterpret_cast<uint8_t*>(&column_key[0]),
+ static_cast<int>(column_key.size()));
+ std::string column_key_key_metadata =
+ key_wrapper->GetEncryptionKeyMetadata(column_key, column_key_id, false);
+
+ std::shared_ptr<ColumnEncryptionProperties> cmd =
+ ColumnEncryptionProperties::Builder(column_name)
+ .key(column_key)
+ ->key_metadata(column_key_key_metadata)
+ ->build();
+ encrypted_columns.insert({column_name, cmd});
+ }
+ }
+ if (encrypted_columns.empty()) {
+ throw ParquetException("No column keys configured in column keys property.");
+ }
+
+ return encrypted_columns;
+}
+
+std::shared_ptr<FileDecryptionProperties> CryptoFactory::GetFileDecryptionProperties(
+ const KmsConnectionConfig& kms_connection_config,
+ const DecryptionConfiguration& decryption_config) {
+ std::shared_ptr<DecryptionKeyRetriever> key_retriever(new FileKeyUnwrapper(
+ &key_toolkit_, kms_connection_config, decryption_config.cache_lifetime_seconds));
+
+ return FileDecryptionProperties::Builder()
+ .key_retriever(key_retriever)
+ ->plaintext_files_allowed()
+ ->build();
+}
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/crypto_factory.h b/src/arrow/cpp/src/parquet/encryption/crypto_factory.h
new file mode 100644
index 000000000..d41e6ad21
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/crypto_factory.h
@@ -0,0 +1,135 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "parquet/encryption/encryption.h"
+#include "parquet/encryption/file_key_wrapper.h"
+#include "parquet/encryption/key_toolkit.h"
+#include "parquet/encryption/kms_client_factory.h"
+#include "parquet/platform.h"
+
+namespace parquet {
+namespace encryption {
+
+static constexpr ParquetCipher::type kDefaultEncryptionAlgorithm =
+ ParquetCipher::AES_GCM_V1;
+static constexpr bool kDefaultPlaintextFooter = false;
+static constexpr bool kDefaultDoubleWrapping = true;
+static constexpr double kDefaultCacheLifetimeSeconds = 600; // 10 minutes
+static constexpr bool kDefaultInternalKeyMaterial = true;
+static constexpr bool kDefaultUniformEncryption = false;
+static constexpr int32_t kDefaultDataKeyLengthBits = 128;
+
+struct PARQUET_EXPORT EncryptionConfiguration {
+ explicit EncryptionConfiguration(const std::string& footer_key)
+ : footer_key(footer_key) {}
+
+ /// ID of the master key for footer encryption/signing
+ std::string footer_key;
+
+ /// List of columns to encrypt, with master key IDs (see HIVE-21848).
+ /// Format: "masterKeyID:colName,colName;masterKeyID:colName..."
+ /// Either
+ /// (1) column_keys must be set
+ /// or
+ /// (2) uniform_encryption must be set to true
+ /// If none of (1) and (2) are true, or if both are true, an exception will be
+ /// thrown.
+ std::string column_keys;
+
+ /// Encrypt footer and all columns with the same encryption key.
+ bool uniform_encryption = kDefaultUniformEncryption;
+
+ /// Parquet encryption algorithm. Can be "AES_GCM_V1" (default), or "AES_GCM_CTR_V1".
+ ParquetCipher::type encryption_algorithm = kDefaultEncryptionAlgorithm;
+
+ /// Write files with plaintext footer.
+ /// The default is false - files are written with encrypted footer.
+ bool plaintext_footer = kDefaultPlaintextFooter;
+
+ /// Use double wrapping - where data encryption keys (DEKs) are encrypted with key
+ /// encryption keys (KEKs), which in turn are encrypted with master keys.
+ /// The default is true. If set to false, use single wrapping - where DEKs are
+ /// encrypted directly with master keys.
+ bool double_wrapping = kDefaultDoubleWrapping;
+
+ /// Lifetime of cached entities (key encryption keys, local wrapping keys, KMS client
+ /// objects).
+ /// The default is 600 (10 minutes).
+ double cache_lifetime_seconds = kDefaultCacheLifetimeSeconds;
+
+ /// Store key material inside Parquet file footers; this mode doesn’t produce
+ /// additional files. By default, true. If set to false, key material is stored in
+ /// separate files in the same folder, which enables key rotation for immutable
+ /// Parquet files.
+ bool internal_key_material = kDefaultInternalKeyMaterial;
+
+ /// Length of data encryption keys (DEKs), randomly generated by parquet key
+ /// management tools. Can be 128, 192 or 256 bits.
+ /// The default is 128 bits.
+ int32_t data_key_length_bits = kDefaultDataKeyLengthBits;
+};
+
+struct PARQUET_EXPORT DecryptionConfiguration {
+ /// Lifetime of cached entities (key encryption keys, local wrapping keys, KMS client
+ /// objects).
+ /// The default is 600 (10 minutes).
+ double cache_lifetime_seconds = kDefaultCacheLifetimeSeconds;
+};
+
+/// This is a core class, that translates the parameters of high level encryption (like
+/// the names of encrypted columns, names of master keys, etc), into parameters of low
+/// level encryption (like the key metadata, DEK, etc). A factory that produces the low
+/// level FileEncryptionProperties and FileDecryptionProperties objects, from the high
+/// level parameters.
+class PARQUET_EXPORT CryptoFactory {
+ public:
+ /// a KmsClientFactory object must be registered via this method before calling any of
+ /// GetFileEncryptionProperties()/GetFileDecryptionProperties() methods.
+ void RegisterKmsClientFactory(std::shared_ptr<KmsClientFactory> kms_client_factory);
+
+ std::shared_ptr<FileEncryptionProperties> GetFileEncryptionProperties(
+ const KmsConnectionConfig& kms_connection_config,
+ const EncryptionConfiguration& encryption_config);
+
+ /// The returned FileDecryptionProperties object will use the cache inside this
+ /// CryptoFactory object, so please keep this
+ /// CryptoFactory object alive along with the returned
+ /// FileDecryptionProperties object.
+ std::shared_ptr<FileDecryptionProperties> GetFileDecryptionProperties(
+ const KmsConnectionConfig& kms_connection_config,
+ const DecryptionConfiguration& decryption_config);
+
+ void RemoveCacheEntriesForToken(const std::string& access_token) {
+ key_toolkit_.RemoveCacheEntriesForToken(access_token);
+ }
+
+ void RemoveCacheEntriesForAllTokens() { key_toolkit_.RemoveCacheEntriesForAllTokens(); }
+
+ private:
+ ColumnPathToEncryptionPropertiesMap GetColumnEncryptionProperties(
+ int dek_length, const std::string& column_keys, FileKeyWrapper* key_wrapper);
+
+ /// Key utilities object for kms client initialization and cache control
+ KeyToolkit key_toolkit_;
+};
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/encryption.cc b/src/arrow/cpp/src/parquet/encryption/encryption.cc
new file mode 100644
index 000000000..5927503ab
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/encryption.cc
@@ -0,0 +1,412 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/encryption/encryption.h"
+
+#include <string.h>
+
+#include <map>
+#include <utility>
+
+#include "arrow/util/logging.h"
+#include "arrow/util/utf8.h"
+#include "parquet/encryption/encryption_internal.h"
+
+namespace parquet {
+
+// integer key retriever
+void IntegerKeyIdRetriever::PutKey(uint32_t key_id, const std::string& key) {
+ key_map_.insert({key_id, key});
+}
+
+std::string IntegerKeyIdRetriever::GetKey(const std::string& key_metadata) {
+ uint32_t key_id;
+ memcpy(reinterpret_cast<uint8_t*>(&key_id), key_metadata.c_str(), 4);
+
+ return key_map_.at(key_id);
+}
+
+// string key retriever
+void StringKeyIdRetriever::PutKey(const std::string& key_id, const std::string& key) {
+ key_map_.insert({key_id, key});
+}
+
+std::string StringKeyIdRetriever::GetKey(const std::string& key_id) {
+ return key_map_.at(key_id);
+}
+
+ColumnEncryptionProperties::Builder* ColumnEncryptionProperties::Builder::key(
+ std::string column_key) {
+ if (column_key.empty()) return this;
+
+ DCHECK(key_.empty());
+ key_ = column_key;
+ return this;
+}
+
+ColumnEncryptionProperties::Builder* ColumnEncryptionProperties::Builder::key_metadata(
+ const std::string& key_metadata) {
+ DCHECK(!key_metadata.empty());
+ DCHECK(key_metadata_.empty());
+ key_metadata_ = key_metadata;
+ return this;
+}
+
+ColumnEncryptionProperties::Builder* ColumnEncryptionProperties::Builder::key_id(
+ const std::string& key_id) {
+ // key_id is expected to be in UTF8 encoding
+ ::arrow::util::InitializeUTF8();
+ const uint8_t* data = reinterpret_cast<const uint8_t*>(key_id.c_str());
+ if (!::arrow::util::ValidateUTF8(data, key_id.size())) {
+ throw ParquetException("key id should be in UTF8 encoding");
+ }
+
+ DCHECK(!key_id.empty());
+ this->key_metadata(key_id);
+ return this;
+}
+
+FileDecryptionProperties::Builder* FileDecryptionProperties::Builder::column_keys(
+ const ColumnPathToDecryptionPropertiesMap& column_decryption_properties) {
+ if (column_decryption_properties.size() == 0) return this;
+
+ if (column_decryption_properties_.size() != 0)
+ throw ParquetException("Column properties already set");
+
+ for (const auto& element : column_decryption_properties) {
+ if (element.second->is_utilized()) {
+ throw ParquetException("Column properties utilized in another file");
+ }
+ element.second->set_utilized();
+ }
+
+ column_decryption_properties_ = column_decryption_properties;
+ return this;
+}
+
+void FileDecryptionProperties::WipeOutDecryptionKeys() {
+ footer_key_.clear();
+
+ for (const auto& element : column_decryption_properties_) {
+ element.second->WipeOutDecryptionKey();
+ }
+}
+
+bool FileDecryptionProperties::is_utilized() {
+ if (footer_key_.empty() && column_decryption_properties_.size() == 0 &&
+ aad_prefix_.empty())
+ return false;
+
+ return utilized_;
+}
+
+std::shared_ptr<FileDecryptionProperties> FileDecryptionProperties::DeepClone(
+ std::string new_aad_prefix) {
+ std::string footer_key_copy = footer_key_;
+ ColumnPathToDecryptionPropertiesMap column_decryption_properties_map_copy;
+
+ for (const auto& element : column_decryption_properties_) {
+ column_decryption_properties_map_copy.insert(
+ {element.second->column_path(), element.second->DeepClone()});
+ }
+
+ if (new_aad_prefix.empty()) new_aad_prefix = aad_prefix_;
+ return std::shared_ptr<FileDecryptionProperties>(new FileDecryptionProperties(
+ footer_key_copy, key_retriever_, check_plaintext_footer_integrity_, new_aad_prefix,
+ aad_prefix_verifier_, column_decryption_properties_map_copy,
+ plaintext_files_allowed_));
+}
+
+FileDecryptionProperties::Builder* FileDecryptionProperties::Builder::footer_key(
+ const std::string footer_key) {
+ if (footer_key.empty()) {
+ return this;
+ }
+ DCHECK(footer_key_.empty());
+ footer_key_ = footer_key;
+ return this;
+}
+
+FileDecryptionProperties::Builder* FileDecryptionProperties::Builder::key_retriever(
+ const std::shared_ptr<DecryptionKeyRetriever>& key_retriever) {
+ if (key_retriever == nullptr) return this;
+
+ DCHECK(key_retriever_ == nullptr);
+ key_retriever_ = key_retriever;
+ return this;
+}
+
+FileDecryptionProperties::Builder* FileDecryptionProperties::Builder::aad_prefix(
+ const std::string& aad_prefix) {
+ if (aad_prefix.empty()) {
+ return this;
+ }
+ DCHECK(aad_prefix_.empty());
+ aad_prefix_ = aad_prefix;
+ return this;
+}
+
+FileDecryptionProperties::Builder* FileDecryptionProperties::Builder::aad_prefix_verifier(
+ std::shared_ptr<AADPrefixVerifier> aad_prefix_verifier) {
+ if (aad_prefix_verifier == nullptr) return this;
+
+ DCHECK(aad_prefix_verifier_ == nullptr);
+ aad_prefix_verifier_ = std::move(aad_prefix_verifier);
+ return this;
+}
+
+ColumnDecryptionProperties::Builder* ColumnDecryptionProperties::Builder::key(
+ const std::string& key) {
+ if (key.empty()) return this;
+
+ DCHECK(!key.empty());
+ key_ = key;
+ return this;
+}
+
+std::shared_ptr<ColumnDecryptionProperties> ColumnDecryptionProperties::Builder::build() {
+ return std::shared_ptr<ColumnDecryptionProperties>(
+ new ColumnDecryptionProperties(column_path_, key_));
+}
+
+void ColumnDecryptionProperties::WipeOutDecryptionKey() { key_.clear(); }
+
+std::shared_ptr<ColumnDecryptionProperties> ColumnDecryptionProperties::DeepClone() {
+ std::string key_copy = key_;
+ return std::shared_ptr<ColumnDecryptionProperties>(
+ new ColumnDecryptionProperties(column_path_, key_copy));
+}
+
+FileEncryptionProperties::Builder* FileEncryptionProperties::Builder::footer_key_metadata(
+ const std::string& footer_key_metadata) {
+ if (footer_key_metadata.empty()) return this;
+
+ DCHECK(footer_key_metadata_.empty());
+ footer_key_metadata_ = footer_key_metadata;
+ return this;
+}
+
+FileEncryptionProperties::Builder* FileEncryptionProperties::Builder::encrypted_columns(
+ const ColumnPathToEncryptionPropertiesMap& encrypted_columns) {
+ if (encrypted_columns.size() == 0) return this;
+
+ if (encrypted_columns_.size() != 0)
+ throw ParquetException("Column properties already set");
+
+ for (const auto& element : encrypted_columns) {
+ if (element.second->is_utilized()) {
+ throw ParquetException("Column properties utilized in another file");
+ }
+ element.second->set_utilized();
+ }
+ encrypted_columns_ = encrypted_columns;
+ return this;
+}
+
+void FileEncryptionProperties::WipeOutEncryptionKeys() {
+ footer_key_.clear();
+ for (const auto& element : encrypted_columns_) {
+ element.second->WipeOutEncryptionKey();
+ }
+}
+
+std::shared_ptr<FileEncryptionProperties> FileEncryptionProperties::DeepClone(
+ std::string new_aad_prefix) {
+ std::string footer_key_copy = footer_key_;
+ ColumnPathToEncryptionPropertiesMap encrypted_columns_map_copy;
+
+ for (const auto& element : encrypted_columns_) {
+ encrypted_columns_map_copy.insert(
+ {element.second->column_path(), element.second->DeepClone()});
+ }
+
+ if (new_aad_prefix.empty()) new_aad_prefix = aad_prefix_;
+ return std::shared_ptr<FileEncryptionProperties>(new FileEncryptionProperties(
+ algorithm_.algorithm, footer_key_copy, footer_key_metadata_, encrypted_footer_,
+ new_aad_prefix, store_aad_prefix_in_file_, encrypted_columns_map_copy));
+}
+
+FileEncryptionProperties::Builder* FileEncryptionProperties::Builder::aad_prefix(
+ const std::string& aad_prefix) {
+ if (aad_prefix.empty()) return this;
+
+ DCHECK(aad_prefix_.empty());
+ aad_prefix_ = aad_prefix;
+ store_aad_prefix_in_file_ = true;
+ return this;
+}
+
+FileEncryptionProperties::Builder*
+FileEncryptionProperties::Builder::disable_aad_prefix_storage() {
+ DCHECK(!aad_prefix_.empty());
+
+ store_aad_prefix_in_file_ = false;
+ return this;
+}
+
+ColumnEncryptionProperties::ColumnEncryptionProperties(bool encrypted,
+ const std::string& column_path,
+ const std::string& key,
+ const std::string& key_metadata)
+ : column_path_(column_path) {
+ // column encryption properties object (with a column key) can be used for writing only
+ // one file.
+ // Upon completion of file writing, the encryption keys in the properties will be wiped
+ // out (set to 0 in memory).
+ utilized_ = false;
+
+ DCHECK(!column_path.empty());
+ if (!encrypted) {
+ DCHECK(key.empty() && key_metadata.empty());
+ }
+
+ if (!key.empty()) {
+ DCHECK(key.length() == 16 || key.length() == 24 || key.length() == 32);
+ }
+
+ encrypted_with_footer_key_ = (encrypted && key.empty());
+ if (encrypted_with_footer_key_) {
+ DCHECK(key_metadata.empty());
+ }
+
+ encrypted_ = encrypted;
+ key_metadata_ = key_metadata;
+ key_ = key;
+}
+
+ColumnDecryptionProperties::ColumnDecryptionProperties(const std::string& column_path,
+ const std::string& key)
+ : column_path_(column_path) {
+ utilized_ = false;
+ DCHECK(!column_path.empty());
+
+ if (!key.empty()) {
+ DCHECK(key.length() == 16 || key.length() == 24 || key.length() == 32);
+ }
+
+ key_ = key;
+}
+
+std::string FileDecryptionProperties::column_key(const std::string& column_path) const {
+ if (column_decryption_properties_.find(column_path) !=
+ column_decryption_properties_.end()) {
+ auto column_prop = column_decryption_properties_.at(column_path);
+ if (column_prop != nullptr) {
+ return column_prop->key();
+ }
+ }
+ return empty_string_;
+}
+
+FileDecryptionProperties::FileDecryptionProperties(
+ const std::string& footer_key, std::shared_ptr<DecryptionKeyRetriever> key_retriever,
+ bool check_plaintext_footer_integrity, const std::string& aad_prefix,
+ std::shared_ptr<AADPrefixVerifier> aad_prefix_verifier,
+ const ColumnPathToDecryptionPropertiesMap& column_decryption_properties,
+ bool plaintext_files_allowed) {
+ DCHECK(!footer_key.empty() || nullptr != key_retriever ||
+ 0 != column_decryption_properties.size());
+
+ if (!footer_key.empty()) {
+ DCHECK(footer_key.length() == 16 || footer_key.length() == 24 ||
+ footer_key.length() == 32);
+ }
+ if (footer_key.empty() && check_plaintext_footer_integrity) {
+ DCHECK(nullptr != key_retriever);
+ }
+ aad_prefix_verifier_ = std::move(aad_prefix_verifier);
+ footer_key_ = footer_key;
+ check_plaintext_footer_integrity_ = check_plaintext_footer_integrity;
+ key_retriever_ = std::move(key_retriever);
+ aad_prefix_ = aad_prefix;
+ column_decryption_properties_ = column_decryption_properties;
+ plaintext_files_allowed_ = plaintext_files_allowed;
+ utilized_ = false;
+}
+
+FileEncryptionProperties::Builder* FileEncryptionProperties::Builder::footer_key_id(
+ const std::string& key_id) {
+ // key_id is expected to be in UTF8 encoding
+ ::arrow::util::InitializeUTF8();
+ const uint8_t* data = reinterpret_cast<const uint8_t*>(key_id.c_str());
+ if (!::arrow::util::ValidateUTF8(data, key_id.size())) {
+ throw ParquetException("footer key id should be in UTF8 encoding");
+ }
+
+ if (key_id.empty()) {
+ return this;
+ }
+
+ return footer_key_metadata(key_id);
+}
+
+std::shared_ptr<ColumnEncryptionProperties>
+FileEncryptionProperties::column_encryption_properties(const std::string& column_path) {
+ if (encrypted_columns_.size() == 0) {
+ auto builder = std::make_shared<ColumnEncryptionProperties::Builder>(column_path);
+ return builder->build();
+ }
+ if (encrypted_columns_.find(column_path) != encrypted_columns_.end()) {
+ return encrypted_columns_[column_path];
+ }
+
+ return nullptr;
+}
+
+FileEncryptionProperties::FileEncryptionProperties(
+ ParquetCipher::type cipher, const std::string& footer_key,
+ const std::string& footer_key_metadata, bool encrypted_footer,
+ const std::string& aad_prefix, bool store_aad_prefix_in_file,
+ const ColumnPathToEncryptionPropertiesMap& encrypted_columns)
+ : footer_key_(footer_key),
+ footer_key_metadata_(footer_key_metadata),
+ encrypted_footer_(encrypted_footer),
+ aad_prefix_(aad_prefix),
+ store_aad_prefix_in_file_(store_aad_prefix_in_file),
+ encrypted_columns_(encrypted_columns) {
+ // file encryption properties object can be used for writing only one file.
+ // Upon completion of file writing, the encryption keys in the properties will be wiped
+ // out (set to 0 in memory).
+ utilized_ = false;
+
+ DCHECK(!footer_key.empty());
+ // footer_key must be either 16, 24 or 32 bytes.
+ DCHECK(footer_key.length() == 16 || footer_key.length() == 24 ||
+ footer_key.length() == 32);
+
+ uint8_t aad_file_unique[kAadFileUniqueLength];
+ memset(aad_file_unique, 0, kAadFileUniqueLength);
+ encryption::RandBytes(aad_file_unique, sizeof(kAadFileUniqueLength));
+ std::string aad_file_unique_str(reinterpret_cast<char const*>(aad_file_unique),
+ kAadFileUniqueLength);
+
+ bool supply_aad_prefix = false;
+ if (aad_prefix.empty()) {
+ file_aad_ = aad_file_unique_str;
+ } else {
+ file_aad_ = aad_prefix + aad_file_unique_str;
+ if (!store_aad_prefix_in_file) supply_aad_prefix = true;
+ }
+ algorithm_.algorithm = cipher;
+ algorithm_.aad.aad_file_unique = aad_file_unique_str;
+ algorithm_.aad.supply_aad_prefix = supply_aad_prefix;
+ if (!aad_prefix.empty() && store_aad_prefix_in_file) {
+ algorithm_.aad.aad_prefix = aad_prefix;
+ }
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/encryption.h b/src/arrow/cpp/src/parquet/encryption/encryption.h
new file mode 100644
index 000000000..8fd7ec8d3
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/encryption.h
@@ -0,0 +1,510 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "parquet/exception.h"
+#include "parquet/schema.h"
+#include "parquet/types.h"
+
+namespace parquet {
+
+static constexpr ParquetCipher::type kDefaultEncryptionAlgorithm =
+ ParquetCipher::AES_GCM_V1;
+static constexpr int32_t kMaximalAadMetadataLength = 256;
+static constexpr bool kDefaultEncryptedFooter = true;
+static constexpr bool kDefaultCheckSignature = true;
+static constexpr bool kDefaultAllowPlaintextFiles = false;
+static constexpr int32_t kAadFileUniqueLength = 8;
+
+class ColumnDecryptionProperties;
+using ColumnPathToDecryptionPropertiesMap =
+ std::map<std::string, std::shared_ptr<ColumnDecryptionProperties>>;
+
+class ColumnEncryptionProperties;
+using ColumnPathToEncryptionPropertiesMap =
+ std::map<std::string, std::shared_ptr<ColumnEncryptionProperties>>;
+
+class PARQUET_EXPORT DecryptionKeyRetriever {
+ public:
+ virtual std::string GetKey(const std::string& key_metadata) = 0;
+ virtual ~DecryptionKeyRetriever() {}
+};
+
+/// Simple integer key retriever
+class PARQUET_EXPORT IntegerKeyIdRetriever : public DecryptionKeyRetriever {
+ public:
+ void PutKey(uint32_t key_id, const std::string& key);
+ std::string GetKey(const std::string& key_metadata) override;
+
+ private:
+ std::map<uint32_t, std::string> key_map_;
+};
+
+// Simple string key retriever
+class PARQUET_EXPORT StringKeyIdRetriever : public DecryptionKeyRetriever {
+ public:
+ void PutKey(const std::string& key_id, const std::string& key);
+ std::string GetKey(const std::string& key_metadata) override;
+
+ private:
+ std::map<std::string, std::string> key_map_;
+};
+
+class PARQUET_EXPORT HiddenColumnException : public ParquetException {
+ public:
+ explicit HiddenColumnException(const std::string& columnPath)
+ : ParquetException(columnPath.c_str()) {}
+};
+
+class PARQUET_EXPORT KeyAccessDeniedException : public ParquetException {
+ public:
+ explicit KeyAccessDeniedException(const std::string& columnPath)
+ : ParquetException(columnPath.c_str()) {}
+};
+
+inline const uint8_t* str2bytes(const std::string& str) {
+ if (str.empty()) return NULLPTR;
+
+ char* cbytes = const_cast<char*>(str.c_str());
+ return reinterpret_cast<const uint8_t*>(cbytes);
+}
+
+class PARQUET_EXPORT ColumnEncryptionProperties {
+ public:
+ class PARQUET_EXPORT Builder {
+ public:
+ /// Convenience builder for encrypted columns.
+ explicit Builder(const std::string& name) : Builder(name, true) {}
+
+ /// Convenience builder for encrypted columns.
+ explicit Builder(const std::shared_ptr<schema::ColumnPath>& path)
+ : Builder(path->ToDotString(), true) {}
+
+ /// Set a column-specific key.
+ /// If key is not set on an encrypted column, the column will
+ /// be encrypted with the footer key.
+ /// keyBytes Key length must be either 16, 24 or 32 bytes.
+ /// The key is cloned, and will be wiped out (array values set to 0) upon completion
+ /// of file writing.
+ /// Caller is responsible for wiping out the input key array.
+ Builder* key(std::string column_key);
+
+ /// Set a key retrieval metadata.
+ /// use either key_metadata() or key_id(), not both
+ Builder* key_metadata(const std::string& key_metadata);
+
+ /// A convenience function to set key metadata using a string id.
+ /// Set a key retrieval metadata (converted from String).
+ /// use either key_metadata() or key_id(), not both
+ /// key_id will be converted to metadata (UTF-8 array).
+ Builder* key_id(const std::string& key_id);
+
+ std::shared_ptr<ColumnEncryptionProperties> build() {
+ return std::shared_ptr<ColumnEncryptionProperties>(
+ new ColumnEncryptionProperties(encrypted_, column_path_, key_, key_metadata_));
+ }
+
+ private:
+ const std::string column_path_;
+ bool encrypted_;
+ std::string key_;
+ std::string key_metadata_;
+
+ Builder(const std::string path, bool encrypted)
+ : column_path_(path), encrypted_(encrypted) {}
+ };
+
+ std::string column_path() const { return column_path_; }
+ bool is_encrypted() const { return encrypted_; }
+ bool is_encrypted_with_footer_key() const { return encrypted_with_footer_key_; }
+ std::string key() const { return key_; }
+ std::string key_metadata() const { return key_metadata_; }
+
+ /// Upon completion of file writing, the encryption key
+ /// will be wiped out.
+ void WipeOutEncryptionKey() { key_.clear(); }
+
+ bool is_utilized() {
+ if (key_.empty())
+ return false; // can re-use column properties without encryption keys
+ return utilized_;
+ }
+
+ /// ColumnEncryptionProperties object can be used for writing one file only.
+ /// Mark ColumnEncryptionProperties as utilized once it is used in
+ /// FileEncryptionProperties as the encryption key will be wiped out upon
+ /// completion of file writing.
+ void set_utilized() { utilized_ = true; }
+
+ std::shared_ptr<ColumnEncryptionProperties> DeepClone() {
+ std::string key_copy = key_;
+ return std::shared_ptr<ColumnEncryptionProperties>(new ColumnEncryptionProperties(
+ encrypted_, column_path_, key_copy, key_metadata_));
+ }
+
+ ColumnEncryptionProperties() = default;
+ ColumnEncryptionProperties(const ColumnEncryptionProperties& other) = default;
+ ColumnEncryptionProperties(ColumnEncryptionProperties&& other) = default;
+
+ private:
+ const std::string column_path_;
+ bool encrypted_;
+ bool encrypted_with_footer_key_;
+ std::string key_;
+ std::string key_metadata_;
+ bool utilized_;
+ explicit ColumnEncryptionProperties(bool encrypted, const std::string& column_path,
+ const std::string& key,
+ const std::string& key_metadata);
+};
+
+class PARQUET_EXPORT ColumnDecryptionProperties {
+ public:
+ class PARQUET_EXPORT Builder {
+ public:
+ explicit Builder(const std::string& name) : column_path_(name) {}
+
+ explicit Builder(const std::shared_ptr<schema::ColumnPath>& path)
+ : Builder(path->ToDotString()) {}
+
+ /// Set an explicit column key. If applied on a file that contains
+ /// key metadata for this column the metadata will be ignored,
+ /// the column will be decrypted with this key.
+ /// key length must be either 16, 24 or 32 bytes.
+ Builder* key(const std::string& key);
+
+ std::shared_ptr<ColumnDecryptionProperties> build();
+
+ private:
+ const std::string column_path_;
+ std::string key_;
+ };
+
+ ColumnDecryptionProperties() = default;
+ ColumnDecryptionProperties(const ColumnDecryptionProperties& other) = default;
+ ColumnDecryptionProperties(ColumnDecryptionProperties&& other) = default;
+
+ std::string column_path() const { return column_path_; }
+ std::string key() const { return key_; }
+ bool is_utilized() { return utilized_; }
+
+ /// ColumnDecryptionProperties object can be used for reading one file only.
+ /// Mark ColumnDecryptionProperties as utilized once it is used in
+ /// FileDecryptionProperties as the encryption key will be wiped out upon
+ /// completion of file reading.
+ void set_utilized() { utilized_ = true; }
+
+ /// Upon completion of file reading, the encryption key
+ /// will be wiped out.
+ void WipeOutDecryptionKey();
+
+ std::shared_ptr<ColumnDecryptionProperties> DeepClone();
+
+ private:
+ const std::string column_path_;
+ std::string key_;
+ bool utilized_;
+
+ /// This class is only required for setting explicit column decryption keys -
+ /// to override key retriever (or to provide keys when key metadata and/or
+ /// key retriever are not available)
+ explicit ColumnDecryptionProperties(const std::string& column_path,
+ const std::string& key);
+};
+
+class PARQUET_EXPORT AADPrefixVerifier {
+ public:
+ /// Verifies identity (AAD Prefix) of individual file,
+ /// or of file collection in a data set.
+ /// Throws exception if an AAD prefix is wrong.
+ /// In a data set, AAD Prefixes should be collected,
+ /// and then checked for missing files.
+ virtual void Verify(const std::string& aad_prefix) = 0;
+ virtual ~AADPrefixVerifier() {}
+};
+
+class PARQUET_EXPORT FileDecryptionProperties {
+ public:
+ class PARQUET_EXPORT Builder {
+ public:
+ Builder() {
+ check_plaintext_footer_integrity_ = kDefaultCheckSignature;
+ plaintext_files_allowed_ = kDefaultAllowPlaintextFiles;
+ }
+
+ /// Set an explicit footer key. If applied on a file that contains
+ /// footer key metadata the metadata will be ignored, the footer
+ /// will be decrypted/verified with this key.
+ /// If explicit key is not set, footer key will be fetched from
+ /// key retriever.
+ /// With explicit keys or AAD prefix, new encryption properties object must be
+ /// created for each encrypted file.
+ /// Explicit encryption keys (footer and column) are cloned.
+ /// Upon completion of file reading, the cloned encryption keys in the properties
+ /// will be wiped out (array values set to 0).
+ /// Caller is responsible for wiping out the input key array.
+ /// param footerKey Key length must be either 16, 24 or 32 bytes.
+ Builder* footer_key(const std::string footer_key);
+
+ /// Set explicit column keys (decryption properties).
+ /// Its also possible to set a key retriever on this property object.
+ /// Upon file decryption, availability of explicit keys is checked before
+ /// invocation of the retriever callback.
+ /// If an explicit key is available for a footer or a column,
+ /// its key metadata will be ignored.
+ Builder* column_keys(
+ const ColumnPathToDecryptionPropertiesMap& column_decryption_properties);
+
+ /// Set a key retriever callback. Its also possible to
+ /// set explicit footer or column keys on this file property object.
+ /// Upon file decryption, availability of explicit keys is checked before
+ /// invocation of the retriever callback.
+ /// If an explicit key is available for a footer or a column,
+ /// its key metadata will be ignored.
+ Builder* key_retriever(const std::shared_ptr<DecryptionKeyRetriever>& key_retriever);
+
+ /// Skip integrity verification of plaintext footers.
+ /// If not called, integrity of plaintext footers will be checked in runtime,
+ /// and an exception will be thrown in the following situations:
+ /// - footer signing key is not available
+ /// (not passed, or not found by key retriever)
+ /// - footer content and signature don't match
+ Builder* disable_footer_signature_verification() {
+ check_plaintext_footer_integrity_ = false;
+ return this;
+ }
+
+ /// Explicitly supply the file AAD prefix.
+ /// A must when a prefix is used for file encryption, but not stored in file.
+ /// If AAD prefix is stored in file, it will be compared to the explicitly
+ /// supplied value and an exception will be thrown if they differ.
+ Builder* aad_prefix(const std::string& aad_prefix);
+
+ /// Set callback for verification of AAD Prefixes stored in file.
+ Builder* aad_prefix_verifier(std::shared_ptr<AADPrefixVerifier> aad_prefix_verifier);
+
+ /// By default, reading plaintext (unencrypted) files is not
+ /// allowed when using a decryptor
+ /// - in order to detect files that were not encrypted by mistake.
+ /// However, the default behavior can be overridden by calling this method.
+ /// The caller should use then a different method to ensure encryption
+ /// of files with sensitive data.
+ Builder* plaintext_files_allowed() {
+ plaintext_files_allowed_ = true;
+ return this;
+ }
+
+ std::shared_ptr<FileDecryptionProperties> build() {
+ return std::shared_ptr<FileDecryptionProperties>(new FileDecryptionProperties(
+ footer_key_, key_retriever_, check_plaintext_footer_integrity_, aad_prefix_,
+ aad_prefix_verifier_, column_decryption_properties_, plaintext_files_allowed_));
+ }
+
+ private:
+ std::string footer_key_;
+ std::string aad_prefix_;
+ std::shared_ptr<AADPrefixVerifier> aad_prefix_verifier_;
+ ColumnPathToDecryptionPropertiesMap column_decryption_properties_;
+
+ std::shared_ptr<DecryptionKeyRetriever> key_retriever_;
+ bool check_plaintext_footer_integrity_;
+ bool plaintext_files_allowed_;
+ };
+
+ std::string column_key(const std::string& column_path) const;
+
+ std::string footer_key() const { return footer_key_; }
+
+ std::string aad_prefix() const { return aad_prefix_; }
+
+ const std::shared_ptr<DecryptionKeyRetriever>& key_retriever() const {
+ return key_retriever_;
+ }
+
+ bool check_plaintext_footer_integrity() const {
+ return check_plaintext_footer_integrity_;
+ }
+
+ bool plaintext_files_allowed() const { return plaintext_files_allowed_; }
+
+ const std::shared_ptr<AADPrefixVerifier>& aad_prefix_verifier() const {
+ return aad_prefix_verifier_;
+ }
+
+ /// Upon completion of file reading, the encryption keys in the properties
+ /// will be wiped out (array values set to 0).
+ void WipeOutDecryptionKeys();
+
+ bool is_utilized();
+
+ /// FileDecryptionProperties object can be used for reading one file only.
+ /// Mark FileDecryptionProperties as utilized once it is used to read a file as the
+ /// encryption keys will be wiped out upon completion of file reading.
+ void set_utilized() { utilized_ = true; }
+
+ /// FileDecryptionProperties object can be used for reading one file only.
+ /// (unless this object keeps the keyRetrieval callback only, and no explicit
+ /// keys or aadPrefix).
+ /// At the end, keys are wiped out in the memory.
+ /// This method allows to clone identical properties for another file,
+ /// with an option to update the aadPrefix (if newAadPrefix is null,
+ /// aadPrefix will be cloned too)
+ std::shared_ptr<FileDecryptionProperties> DeepClone(std::string new_aad_prefix = "");
+
+ private:
+ std::string footer_key_;
+ std::string aad_prefix_;
+ std::shared_ptr<AADPrefixVerifier> aad_prefix_verifier_;
+
+ const std::string empty_string_ = "";
+ ColumnPathToDecryptionPropertiesMap column_decryption_properties_;
+
+ std::shared_ptr<DecryptionKeyRetriever> key_retriever_;
+ bool check_plaintext_footer_integrity_;
+ bool plaintext_files_allowed_;
+ bool utilized_;
+
+ FileDecryptionProperties(
+ const std::string& footer_key,
+ std::shared_ptr<DecryptionKeyRetriever> key_retriever,
+ bool check_plaintext_footer_integrity, const std::string& aad_prefix,
+ std::shared_ptr<AADPrefixVerifier> aad_prefix_verifier,
+ const ColumnPathToDecryptionPropertiesMap& column_decryption_properties,
+ bool plaintext_files_allowed);
+};
+
+class PARQUET_EXPORT FileEncryptionProperties {
+ public:
+ class PARQUET_EXPORT Builder {
+ public:
+ explicit Builder(const std::string& footer_key)
+ : parquet_cipher_(kDefaultEncryptionAlgorithm),
+ encrypted_footer_(kDefaultEncryptedFooter) {
+ footer_key_ = footer_key;
+ store_aad_prefix_in_file_ = false;
+ }
+
+ /// Create files with plaintext footer.
+ /// If not called, the files will be created with encrypted footer (default).
+ Builder* set_plaintext_footer() {
+ encrypted_footer_ = false;
+ return this;
+ }
+
+ /// Set encryption algorithm.
+ /// If not called, files will be encrypted with AES_GCM_V1 (default).
+ Builder* algorithm(ParquetCipher::type parquet_cipher) {
+ parquet_cipher_ = parquet_cipher;
+ return this;
+ }
+
+ /// Set a key retrieval metadata (converted from String).
+ /// use either footer_key_metadata or footer_key_id, not both.
+ Builder* footer_key_id(const std::string& key_id);
+
+ /// Set a key retrieval metadata.
+ /// use either footer_key_metadata or footer_key_id, not both.
+ Builder* footer_key_metadata(const std::string& footer_key_metadata);
+
+ /// Set the file AAD Prefix.
+ Builder* aad_prefix(const std::string& aad_prefix);
+
+ /// Skip storing AAD Prefix in file.
+ /// If not called, and if AAD Prefix is set, it will be stored.
+ Builder* disable_aad_prefix_storage();
+
+ /// Set the list of encrypted columns and their properties (keys etc).
+ /// If not called, all columns will be encrypted with the footer key.
+ /// If called, the file columns not in the list will be left unencrypted.
+ Builder* encrypted_columns(
+ const ColumnPathToEncryptionPropertiesMap& encrypted_columns);
+
+ std::shared_ptr<FileEncryptionProperties> build() {
+ return std::shared_ptr<FileEncryptionProperties>(new FileEncryptionProperties(
+ parquet_cipher_, footer_key_, footer_key_metadata_, encrypted_footer_,
+ aad_prefix_, store_aad_prefix_in_file_, encrypted_columns_));
+ }
+
+ private:
+ ParquetCipher::type parquet_cipher_;
+ bool encrypted_footer_;
+ std::string footer_key_;
+ std::string footer_key_metadata_;
+
+ std::string aad_prefix_;
+ bool store_aad_prefix_in_file_;
+ ColumnPathToEncryptionPropertiesMap encrypted_columns_;
+ };
+ bool encrypted_footer() const { return encrypted_footer_; }
+
+ EncryptionAlgorithm algorithm() const { return algorithm_; }
+
+ std::string footer_key() const { return footer_key_; }
+
+ std::string footer_key_metadata() const { return footer_key_metadata_; }
+
+ std::string file_aad() const { return file_aad_; }
+
+ std::shared_ptr<ColumnEncryptionProperties> column_encryption_properties(
+ const std::string& column_path);
+
+ bool is_utilized() const { return utilized_; }
+
+ /// FileEncryptionProperties object can be used for writing one file only.
+ /// Mark FileEncryptionProperties as utilized once it is used to write a file as the
+ /// encryption keys will be wiped out upon completion of file writing.
+ void set_utilized() { utilized_ = true; }
+
+ /// Upon completion of file writing, the encryption keys
+ /// will be wiped out (array values set to 0).
+ void WipeOutEncryptionKeys();
+
+ /// FileEncryptionProperties object can be used for writing one file only.
+ /// (at the end, keys are wiped out in the memory).
+ /// This method allows to clone identical properties for another file,
+ /// with an option to update the aadPrefix (if newAadPrefix is null,
+ /// aadPrefix will be cloned too)
+ std::shared_ptr<FileEncryptionProperties> DeepClone(std::string new_aad_prefix = "");
+
+ ColumnPathToEncryptionPropertiesMap encrypted_columns() const {
+ return encrypted_columns_;
+ }
+
+ private:
+ EncryptionAlgorithm algorithm_;
+ std::string footer_key_;
+ std::string footer_key_metadata_;
+ bool encrypted_footer_;
+ std::string file_aad_;
+ std::string aad_prefix_;
+ bool utilized_;
+ bool store_aad_prefix_in_file_;
+ ColumnPathToEncryptionPropertiesMap encrypted_columns_;
+
+ FileEncryptionProperties(ParquetCipher::type cipher, const std::string& footer_key,
+ const std::string& footer_key_metadata, bool encrypted_footer,
+ const std::string& aad_prefix, bool store_aad_prefix_in_file,
+ const ColumnPathToEncryptionPropertiesMap& encrypted_columns);
+};
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/encryption_internal.cc b/src/arrow/cpp/src/parquet/encryption/encryption_internal.cc
new file mode 100644
index 000000000..f46274cbe
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/encryption_internal.cc
@@ -0,0 +1,613 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/encryption/encryption_internal.h"
+#include <openssl/aes.h>
+#include <openssl/evp.h>
+#include <openssl/rand.h>
+
+#include <algorithm>
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "parquet/exception.h"
+
+using parquet::ParquetException;
+
+namespace parquet {
+namespace encryption {
+
+constexpr int kGcmMode = 0;
+constexpr int kCtrMode = 1;
+constexpr int kCtrIvLength = 16;
+constexpr int kBufferSizeLength = 4;
+
+#define ENCRYPT_INIT(CTX, ALG) \
+ if (1 != EVP_EncryptInit_ex(CTX, ALG, nullptr, nullptr, nullptr)) { \
+ throw ParquetException("Couldn't init ALG encryption"); \
+ }
+
+#define DECRYPT_INIT(CTX, ALG) \
+ if (1 != EVP_DecryptInit_ex(CTX, ALG, nullptr, nullptr, nullptr)) { \
+ throw ParquetException("Couldn't init ALG decryption"); \
+ }
+
+class AesEncryptor::AesEncryptorImpl {
+ public:
+ explicit AesEncryptorImpl(ParquetCipher::type alg_id, int key_len, bool metadata);
+
+ ~AesEncryptorImpl() {
+ if (nullptr != ctx_) {
+ EVP_CIPHER_CTX_free(ctx_);
+ ctx_ = nullptr;
+ }
+ }
+
+ int Encrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key,
+ int key_len, const uint8_t* aad, int aad_len, uint8_t* ciphertext);
+
+ int SignedFooterEncrypt(const uint8_t* footer, int footer_len, const uint8_t* key,
+ int key_len, const uint8_t* aad, int aad_len,
+ const uint8_t* nonce, uint8_t* encrypted_footer);
+ void WipeOut() {
+ if (nullptr != ctx_) {
+ EVP_CIPHER_CTX_free(ctx_);
+ ctx_ = nullptr;
+ }
+ }
+
+ int ciphertext_size_delta() { return ciphertext_size_delta_; }
+
+ private:
+ EVP_CIPHER_CTX* ctx_;
+ int aes_mode_;
+ int key_length_;
+ int ciphertext_size_delta_;
+
+ int GcmEncrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key,
+ int key_len, const uint8_t* nonce, const uint8_t* aad, int aad_len,
+ uint8_t* ciphertext);
+
+ int CtrEncrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key,
+ int key_len, const uint8_t* nonce, uint8_t* ciphertext);
+};
+
+AesEncryptor::AesEncryptorImpl::AesEncryptorImpl(ParquetCipher::type alg_id, int key_len,
+ bool metadata) {
+ ctx_ = nullptr;
+
+ ciphertext_size_delta_ = kBufferSizeLength + kNonceLength;
+ if (metadata || (ParquetCipher::AES_GCM_V1 == alg_id)) {
+ aes_mode_ = kGcmMode;
+ ciphertext_size_delta_ += kGcmTagLength;
+ } else {
+ aes_mode_ = kCtrMode;
+ }
+
+ if (16 != key_len && 24 != key_len && 32 != key_len) {
+ std::stringstream ss;
+ ss << "Wrong key length: " << key_len;
+ throw ParquetException(ss.str());
+ }
+
+ key_length_ = key_len;
+
+ ctx_ = EVP_CIPHER_CTX_new();
+ if (nullptr == ctx_) {
+ throw ParquetException("Couldn't init cipher context");
+ }
+
+ if (kGcmMode == aes_mode_) {
+ // Init AES-GCM with specified key length
+ if (16 == key_len) {
+ ENCRYPT_INIT(ctx_, EVP_aes_128_gcm());
+ } else if (24 == key_len) {
+ ENCRYPT_INIT(ctx_, EVP_aes_192_gcm());
+ } else if (32 == key_len) {
+ ENCRYPT_INIT(ctx_, EVP_aes_256_gcm());
+ }
+ } else {
+ // Init AES-CTR with specified key length
+ if (16 == key_len) {
+ ENCRYPT_INIT(ctx_, EVP_aes_128_ctr());
+ } else if (24 == key_len) {
+ ENCRYPT_INIT(ctx_, EVP_aes_192_ctr());
+ } else if (32 == key_len) {
+ ENCRYPT_INIT(ctx_, EVP_aes_256_ctr());
+ }
+ }
+}
+
+int AesEncryptor::AesEncryptorImpl::SignedFooterEncrypt(
+ const uint8_t* footer, int footer_len, const uint8_t* key, int key_len,
+ const uint8_t* aad, int aad_len, const uint8_t* nonce, uint8_t* encrypted_footer) {
+ if (key_length_ != key_len) {
+ std::stringstream ss;
+ ss << "Wrong key length " << key_len << ". Should be " << key_length_;
+ throw ParquetException(ss.str());
+ }
+
+ if (kGcmMode != aes_mode_) {
+ throw ParquetException("Must use AES GCM (metadata) encryptor");
+ }
+
+ return GcmEncrypt(footer, footer_len, key, key_len, nonce, aad, aad_len,
+ encrypted_footer);
+}
+
+int AesEncryptor::AesEncryptorImpl::Encrypt(const uint8_t* plaintext, int plaintext_len,
+ const uint8_t* key, int key_len,
+ const uint8_t* aad, int aad_len,
+ uint8_t* ciphertext) {
+ if (key_length_ != key_len) {
+ std::stringstream ss;
+ ss << "Wrong key length " << key_len << ". Should be " << key_length_;
+ throw ParquetException(ss.str());
+ }
+
+ uint8_t nonce[kNonceLength];
+ memset(nonce, 0, kNonceLength);
+ // Random nonce
+ RAND_bytes(nonce, sizeof(nonce));
+
+ if (kGcmMode == aes_mode_) {
+ return GcmEncrypt(plaintext, plaintext_len, key, key_len, nonce, aad, aad_len,
+ ciphertext);
+ }
+
+ return CtrEncrypt(plaintext, plaintext_len, key, key_len, nonce, ciphertext);
+}
+
+int AesEncryptor::AesEncryptorImpl::GcmEncrypt(const uint8_t* plaintext,
+ int plaintext_len, const uint8_t* key,
+ int key_len, const uint8_t* nonce,
+ const uint8_t* aad, int aad_len,
+ uint8_t* ciphertext) {
+ int len;
+ int ciphertext_len;
+
+ uint8_t tag[kGcmTagLength];
+ memset(tag, 0, kGcmTagLength);
+
+ // Setting key and IV (nonce)
+ if (1 != EVP_EncryptInit_ex(ctx_, nullptr, nullptr, key, nonce)) {
+ throw ParquetException("Couldn't set key and nonce");
+ }
+
+ // Setting additional authenticated data
+ if ((nullptr != aad) && (1 != EVP_EncryptUpdate(ctx_, nullptr, &len, aad, aad_len))) {
+ throw ParquetException("Couldn't set AAD");
+ }
+
+ // Encryption
+ if (1 != EVP_EncryptUpdate(ctx_, ciphertext + kBufferSizeLength + kNonceLength, &len,
+ plaintext, plaintext_len)) {
+ throw ParquetException("Failed encryption update");
+ }
+
+ ciphertext_len = len;
+
+ // Finalization
+ if (1 != EVP_EncryptFinal_ex(ctx_, ciphertext + kBufferSizeLength + kNonceLength + len,
+ &len)) {
+ throw ParquetException("Failed encryption finalization");
+ }
+
+ ciphertext_len += len;
+
+ // Getting the tag
+ if (1 != EVP_CIPHER_CTX_ctrl(ctx_, EVP_CTRL_GCM_GET_TAG, kGcmTagLength, tag)) {
+ throw ParquetException("Couldn't get AES-GCM tag");
+ }
+
+ // Copying the buffer size, nonce and tag to ciphertext
+ int buffer_size = kNonceLength + ciphertext_len + kGcmTagLength;
+ ciphertext[3] = static_cast<uint8_t>(0xff & (buffer_size >> 24));
+ ciphertext[2] = static_cast<uint8_t>(0xff & (buffer_size >> 16));
+ ciphertext[1] = static_cast<uint8_t>(0xff & (buffer_size >> 8));
+ ciphertext[0] = static_cast<uint8_t>(0xff & (buffer_size));
+ std::copy(nonce, nonce + kNonceLength, ciphertext + kBufferSizeLength);
+ std::copy(tag, tag + kGcmTagLength,
+ ciphertext + kBufferSizeLength + kNonceLength + ciphertext_len);
+
+ return kBufferSizeLength + buffer_size;
+}
+
+int AesEncryptor::AesEncryptorImpl::CtrEncrypt(const uint8_t* plaintext,
+ int plaintext_len, const uint8_t* key,
+ int key_len, const uint8_t* nonce,
+ uint8_t* ciphertext) {
+ int len;
+ int ciphertext_len;
+
+ // Parquet CTR IVs are comprised of a 12-byte nonce and a 4-byte initial
+ // counter field.
+ // The first 31 bits of the initial counter field are set to 0, the last bit
+ // is set to 1.
+ uint8_t iv[kCtrIvLength];
+ memset(iv, 0, kCtrIvLength);
+ std::copy(nonce, nonce + kNonceLength, iv);
+ iv[kCtrIvLength - 1] = 1;
+
+ // Setting key and IV
+ if (1 != EVP_EncryptInit_ex(ctx_, nullptr, nullptr, key, iv)) {
+ throw ParquetException("Couldn't set key and IV");
+ }
+
+ // Encryption
+ if (1 != EVP_EncryptUpdate(ctx_, ciphertext + kBufferSizeLength + kNonceLength, &len,
+ plaintext, plaintext_len)) {
+ throw ParquetException("Failed encryption update");
+ }
+
+ ciphertext_len = len;
+
+ // Finalization
+ if (1 != EVP_EncryptFinal_ex(ctx_, ciphertext + kBufferSizeLength + kNonceLength + len,
+ &len)) {
+ throw ParquetException("Failed encryption finalization");
+ }
+
+ ciphertext_len += len;
+
+ // Copying the buffer size and nonce to ciphertext
+ int buffer_size = kNonceLength + ciphertext_len;
+ ciphertext[3] = static_cast<uint8_t>(0xff & (buffer_size >> 24));
+ ciphertext[2] = static_cast<uint8_t>(0xff & (buffer_size >> 16));
+ ciphertext[1] = static_cast<uint8_t>(0xff & (buffer_size >> 8));
+ ciphertext[0] = static_cast<uint8_t>(0xff & (buffer_size));
+ std::copy(nonce, nonce + kNonceLength, ciphertext + kBufferSizeLength);
+
+ return kBufferSizeLength + buffer_size;
+}
+
+AesEncryptor::~AesEncryptor() {}
+
+int AesEncryptor::SignedFooterEncrypt(const uint8_t* footer, int footer_len,
+ const uint8_t* key, int key_len, const uint8_t* aad,
+ int aad_len, const uint8_t* nonce,
+ uint8_t* encrypted_footer) {
+ return impl_->SignedFooterEncrypt(footer, footer_len, key, key_len, aad, aad_len, nonce,
+ encrypted_footer);
+}
+
+void AesEncryptor::WipeOut() { impl_->WipeOut(); }
+
+int AesEncryptor::CiphertextSizeDelta() { return impl_->ciphertext_size_delta(); }
+
+int AesEncryptor::Encrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key,
+ int key_len, const uint8_t* aad, int aad_len,
+ uint8_t* ciphertext) {
+ return impl_->Encrypt(plaintext, plaintext_len, key, key_len, aad, aad_len, ciphertext);
+}
+
+AesEncryptor::AesEncryptor(ParquetCipher::type alg_id, int key_len, bool metadata)
+ : impl_{std::unique_ptr<AesEncryptorImpl>(
+ new AesEncryptorImpl(alg_id, key_len, metadata))} {}
+
+class AesDecryptor::AesDecryptorImpl {
+ public:
+ explicit AesDecryptorImpl(ParquetCipher::type alg_id, int key_len, bool metadata);
+
+ ~AesDecryptorImpl() {
+ if (nullptr != ctx_) {
+ EVP_CIPHER_CTX_free(ctx_);
+ ctx_ = nullptr;
+ }
+ }
+
+ int Decrypt(const uint8_t* ciphertext, int ciphertext_len, const uint8_t* key,
+ int key_len, const uint8_t* aad, int aad_len, uint8_t* plaintext);
+
+ void WipeOut() {
+ if (nullptr != ctx_) {
+ EVP_CIPHER_CTX_free(ctx_);
+ ctx_ = nullptr;
+ }
+ }
+
+ int ciphertext_size_delta() { return ciphertext_size_delta_; }
+
+ private:
+ EVP_CIPHER_CTX* ctx_;
+ int aes_mode_;
+ int key_length_;
+ int ciphertext_size_delta_;
+ int GcmDecrypt(const uint8_t* ciphertext, int ciphertext_len, const uint8_t* key,
+ int key_len, const uint8_t* aad, int aad_len, uint8_t* plaintext);
+
+ int CtrDecrypt(const uint8_t* ciphertext, int ciphertext_len, const uint8_t* key,
+ int key_len, uint8_t* plaintext);
+};
+
+int AesDecryptor::Decrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key,
+ int key_len, const uint8_t* aad, int aad_len,
+ uint8_t* ciphertext) {
+ return impl_->Decrypt(plaintext, plaintext_len, key, key_len, aad, aad_len, ciphertext);
+}
+
+void AesDecryptor::WipeOut() { impl_->WipeOut(); }
+
+AesDecryptor::~AesDecryptor() {}
+
+AesDecryptor::AesDecryptorImpl::AesDecryptorImpl(ParquetCipher::type alg_id, int key_len,
+ bool metadata) {
+ ctx_ = nullptr;
+
+ ciphertext_size_delta_ = kBufferSizeLength + kNonceLength;
+ if (metadata || (ParquetCipher::AES_GCM_V1 == alg_id)) {
+ aes_mode_ = kGcmMode;
+ ciphertext_size_delta_ += kGcmTagLength;
+ } else {
+ aes_mode_ = kCtrMode;
+ }
+
+ if (16 != key_len && 24 != key_len && 32 != key_len) {
+ std::stringstream ss;
+ ss << "Wrong key length: " << key_len;
+ throw ParquetException(ss.str());
+ }
+
+ key_length_ = key_len;
+
+ ctx_ = EVP_CIPHER_CTX_new();
+ if (nullptr == ctx_) {
+ throw ParquetException("Couldn't init cipher context");
+ }
+
+ if (kGcmMode == aes_mode_) {
+ // Init AES-GCM with specified key length
+ if (16 == key_len) {
+ DECRYPT_INIT(ctx_, EVP_aes_128_gcm());
+ } else if (24 == key_len) {
+ DECRYPT_INIT(ctx_, EVP_aes_192_gcm());
+ } else if (32 == key_len) {
+ DECRYPT_INIT(ctx_, EVP_aes_256_gcm());
+ }
+ } else {
+ // Init AES-CTR with specified key length
+ if (16 == key_len) {
+ DECRYPT_INIT(ctx_, EVP_aes_128_ctr());
+ } else if (24 == key_len) {
+ DECRYPT_INIT(ctx_, EVP_aes_192_ctr());
+ } else if (32 == key_len) {
+ DECRYPT_INIT(ctx_, EVP_aes_256_ctr());
+ }
+ }
+}
+
+AesEncryptor* AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, bool metadata,
+ std::vector<AesEncryptor*>* all_encryptors) {
+ if (ParquetCipher::AES_GCM_V1 != alg_id && ParquetCipher::AES_GCM_CTR_V1 != alg_id) {
+ std::stringstream ss;
+ ss << "Crypto algorithm " << alg_id << " is not supported";
+ throw ParquetException(ss.str());
+ }
+
+ AesEncryptor* encryptor = new AesEncryptor(alg_id, key_len, metadata);
+ if (all_encryptors != nullptr) all_encryptors->push_back(encryptor);
+ return encryptor;
+}
+
+AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int key_len, bool metadata)
+ : impl_{std::unique_ptr<AesDecryptorImpl>(
+ new AesDecryptorImpl(alg_id, key_len, metadata))} {}
+
+AesDecryptor* AesDecryptor::Make(ParquetCipher::type alg_id, int key_len, bool metadata,
+ std::vector<AesDecryptor*>* all_decryptors) {
+ if (ParquetCipher::AES_GCM_V1 != alg_id && ParquetCipher::AES_GCM_CTR_V1 != alg_id) {
+ std::stringstream ss;
+ ss << "Crypto algorithm " << alg_id << " is not supported";
+ throw ParquetException(ss.str());
+ }
+
+ AesDecryptor* decryptor = new AesDecryptor(alg_id, key_len, metadata);
+ if (all_decryptors != nullptr) {
+ all_decryptors->push_back(decryptor);
+ }
+ return decryptor;
+}
+
+int AesDecryptor::CiphertextSizeDelta() { return impl_->ciphertext_size_delta(); }
+
+int AesDecryptor::AesDecryptorImpl::GcmDecrypt(const uint8_t* ciphertext,
+ int ciphertext_len, const uint8_t* key,
+ int key_len, const uint8_t* aad,
+ int aad_len, uint8_t* plaintext) {
+ int len;
+ int plaintext_len;
+
+ uint8_t tag[kGcmTagLength];
+ memset(tag, 0, kGcmTagLength);
+ uint8_t nonce[kNonceLength];
+ memset(nonce, 0, kNonceLength);
+
+ // Extract ciphertext length
+ int written_ciphertext_len = ((ciphertext[3] & 0xff) << 24) |
+ ((ciphertext[2] & 0xff) << 16) |
+ ((ciphertext[1] & 0xff) << 8) | ((ciphertext[0] & 0xff));
+
+ if (ciphertext_len > 0 &&
+ ciphertext_len != (written_ciphertext_len + kBufferSizeLength)) {
+ throw ParquetException("Wrong ciphertext length");
+ }
+ ciphertext_len = written_ciphertext_len + kBufferSizeLength;
+
+ // Extracting IV and tag
+ std::copy(ciphertext + kBufferSizeLength, ciphertext + kBufferSizeLength + kNonceLength,
+ nonce);
+ std::copy(ciphertext + ciphertext_len - kGcmTagLength, ciphertext + ciphertext_len,
+ tag);
+
+ // Setting key and IV
+ if (1 != EVP_DecryptInit_ex(ctx_, nullptr, nullptr, key, nonce)) {
+ throw ParquetException("Couldn't set key and IV");
+ }
+
+ // Setting additional authenticated data
+ if ((nullptr != aad) && (1 != EVP_DecryptUpdate(ctx_, nullptr, &len, aad, aad_len))) {
+ throw ParquetException("Couldn't set AAD");
+ }
+
+ // Decryption
+ if (!EVP_DecryptUpdate(
+ ctx_, plaintext, &len, ciphertext + kBufferSizeLength + kNonceLength,
+ ciphertext_len - kBufferSizeLength - kNonceLength - kGcmTagLength)) {
+ throw ParquetException("Failed decryption update");
+ }
+
+ plaintext_len = len;
+
+ // Checking the tag (authentication)
+ if (!EVP_CIPHER_CTX_ctrl(ctx_, EVP_CTRL_GCM_SET_TAG, kGcmTagLength, tag)) {
+ throw ParquetException("Failed authentication");
+ }
+
+ // Finalization
+ if (1 != EVP_DecryptFinal_ex(ctx_, plaintext + len, &len)) {
+ throw ParquetException("Failed decryption finalization");
+ }
+
+ plaintext_len += len;
+ return plaintext_len;
+}
+
+int AesDecryptor::AesDecryptorImpl::CtrDecrypt(const uint8_t* ciphertext,
+ int ciphertext_len, const uint8_t* key,
+ int key_len, uint8_t* plaintext) {
+ int len;
+ int plaintext_len;
+
+ uint8_t iv[kCtrIvLength];
+ memset(iv, 0, kCtrIvLength);
+
+ // Extract ciphertext length
+ int written_ciphertext_len = ((ciphertext[3] & 0xff) << 24) |
+ ((ciphertext[2] & 0xff) << 16) |
+ ((ciphertext[1] & 0xff) << 8) | ((ciphertext[0] & 0xff));
+
+ if (ciphertext_len > 0 &&
+ ciphertext_len != (written_ciphertext_len + kBufferSizeLength)) {
+ throw ParquetException("Wrong ciphertext length");
+ }
+ ciphertext_len = written_ciphertext_len;
+
+ // Extracting nonce
+ std::copy(ciphertext + kBufferSizeLength, ciphertext + kBufferSizeLength + kNonceLength,
+ iv);
+ // Parquet CTR IVs are comprised of a 12-byte nonce and a 4-byte initial
+ // counter field.
+ // The first 31 bits of the initial counter field are set to 0, the last bit
+ // is set to 1.
+ iv[kCtrIvLength - 1] = 1;
+
+ // Setting key and IV
+ if (1 != EVP_DecryptInit_ex(ctx_, nullptr, nullptr, key, iv)) {
+ throw ParquetException("Couldn't set key and IV");
+ }
+
+ // Decryption
+ if (!EVP_DecryptUpdate(ctx_, plaintext, &len,
+ ciphertext + kBufferSizeLength + kNonceLength,
+ ciphertext_len - kNonceLength)) {
+ throw ParquetException("Failed decryption update");
+ }
+
+ plaintext_len = len;
+
+ // Finalization
+ if (1 != EVP_DecryptFinal_ex(ctx_, plaintext + len, &len)) {
+ throw ParquetException("Failed decryption finalization");
+ }
+
+ plaintext_len += len;
+ return plaintext_len;
+}
+
+int AesDecryptor::AesDecryptorImpl::Decrypt(const uint8_t* ciphertext, int ciphertext_len,
+ const uint8_t* key, int key_len,
+ const uint8_t* aad, int aad_len,
+ uint8_t* plaintext) {
+ if (key_length_ != key_len) {
+ std::stringstream ss;
+ ss << "Wrong key length " << key_len << ". Should be " << key_length_;
+ throw ParquetException(ss.str());
+ }
+
+ if (kGcmMode == aes_mode_) {
+ return GcmDecrypt(ciphertext, ciphertext_len, key, key_len, aad, aad_len, plaintext);
+ }
+
+ return CtrDecrypt(ciphertext, ciphertext_len, key, key_len, plaintext);
+}
+
+static std::string ShortToBytesLe(int16_t input) {
+ int8_t output[2];
+ memset(output, 0, 2);
+ output[1] = static_cast<int8_t>(0xff & (input >> 8));
+ output[0] = static_cast<int8_t>(0xff & (input));
+
+ return std::string(reinterpret_cast<char const*>(output), 2);
+}
+
+std::string CreateModuleAad(const std::string& file_aad, int8_t module_type,
+ int16_t row_group_ordinal, int16_t column_ordinal,
+ int16_t page_ordinal) {
+ int8_t type_ordinal_bytes[1];
+ type_ordinal_bytes[0] = module_type;
+ std::string type_ordinal_bytes_str(reinterpret_cast<char const*>(type_ordinal_bytes),
+ 1);
+ if (kFooter == module_type) {
+ std::string result = file_aad + type_ordinal_bytes_str;
+ return result;
+ }
+ std::string row_group_ordinal_bytes = ShortToBytesLe(row_group_ordinal);
+ std::string column_ordinal_bytes = ShortToBytesLe(column_ordinal);
+ if (kDataPage != module_type && kDataPageHeader != module_type) {
+ std::ostringstream out;
+ out << file_aad << type_ordinal_bytes_str << row_group_ordinal_bytes
+ << column_ordinal_bytes;
+ return out.str();
+ }
+ std::string page_ordinal_bytes = ShortToBytesLe(page_ordinal);
+ std::ostringstream out;
+ out << file_aad << type_ordinal_bytes_str << row_group_ordinal_bytes
+ << column_ordinal_bytes << page_ordinal_bytes;
+ return out.str();
+}
+
+std::string CreateFooterAad(const std::string& aad_prefix_bytes) {
+ return CreateModuleAad(aad_prefix_bytes, kFooter, static_cast<int16_t>(-1),
+ static_cast<int16_t>(-1), static_cast<int16_t>(-1));
+}
+
+// Update last two bytes with new page ordinal (instead of creating new page AAD
+// from scratch)
+void QuickUpdatePageAad(const std::string& AAD, int16_t new_page_ordinal) {
+ std::string page_ordinal_bytes = ShortToBytesLe(new_page_ordinal);
+ int length = static_cast<int>(AAD.size());
+ std::memcpy(reinterpret_cast<int16_t*>(const_cast<char*>(AAD.c_str() + length - 2)),
+ reinterpret_cast<const int16_t*>(page_ordinal_bytes.c_str()), 2);
+}
+
+void RandBytes(unsigned char* buf, int num) { RAND_bytes(buf, num); }
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/encryption_internal.h b/src/arrow/cpp/src/parquet/encryption/encryption_internal.h
new file mode 100644
index 000000000..e50fb9d0b
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/encryption_internal.h
@@ -0,0 +1,116 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "parquet/properties.h"
+#include "parquet/types.h"
+
+using parquet::ParquetCipher;
+
+namespace parquet {
+namespace encryption {
+
+constexpr int kGcmTagLength = 16;
+constexpr int kNonceLength = 12;
+
+// Module types
+constexpr int8_t kFooter = 0;
+constexpr int8_t kColumnMetaData = 1;
+constexpr int8_t kDataPage = 2;
+constexpr int8_t kDictionaryPage = 3;
+constexpr int8_t kDataPageHeader = 4;
+constexpr int8_t kDictionaryPageHeader = 5;
+constexpr int8_t kColumnIndex = 6;
+constexpr int8_t kOffsetIndex = 7;
+
+/// Performs AES encryption operations with GCM or CTR ciphers.
+class AesEncryptor {
+ public:
+ /// Can serve one key length only. Possible values: 16, 24, 32 bytes.
+ explicit AesEncryptor(ParquetCipher::type alg_id, int key_len, bool metadata);
+
+ static AesEncryptor* Make(ParquetCipher::type alg_id, int key_len, bool metadata,
+ std::vector<AesEncryptor*>* all_encryptors);
+
+ ~AesEncryptor();
+
+ /// Size difference between plaintext and ciphertext, for this cipher.
+ int CiphertextSizeDelta();
+
+ /// Encrypts plaintext with the key and aad. Key length is passed only for validation.
+ /// If different from value in constructor, exception will be thrown.
+ int Encrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key,
+ int key_len, const uint8_t* aad, int aad_len, uint8_t* ciphertext);
+
+ /// Encrypts plaintext footer, in order to compute footer signature (tag).
+ int SignedFooterEncrypt(const uint8_t* footer, int footer_len, const uint8_t* key,
+ int key_len, const uint8_t* aad, int aad_len,
+ const uint8_t* nonce, uint8_t* encrypted_footer);
+
+ void WipeOut();
+
+ private:
+ // PIMPL Idiom
+ class AesEncryptorImpl;
+ std::unique_ptr<AesEncryptorImpl> impl_;
+};
+
+/// Performs AES decryption operations with GCM or CTR ciphers.
+class AesDecryptor {
+ public:
+ /// Can serve one key length only. Possible values: 16, 24, 32 bytes.
+ explicit AesDecryptor(ParquetCipher::type alg_id, int key_len, bool metadata);
+
+ static AesDecryptor* Make(ParquetCipher::type alg_id, int key_len, bool metadata,
+ std::vector<AesDecryptor*>* all_decryptors);
+
+ ~AesDecryptor();
+ void WipeOut();
+
+ /// Size difference between plaintext and ciphertext, for this cipher.
+ int CiphertextSizeDelta();
+
+ /// Decrypts ciphertext with the key and aad. Key length is passed only for
+ /// validation. If different from value in constructor, exception will be thrown.
+ int Decrypt(const uint8_t* ciphertext, int ciphertext_len, const uint8_t* key,
+ int key_len, const uint8_t* aad, int aad_len, uint8_t* plaintext);
+
+ private:
+ // PIMPL Idiom
+ class AesDecryptorImpl;
+ std::unique_ptr<AesDecryptorImpl> impl_;
+};
+
+std::string CreateModuleAad(const std::string& file_aad, int8_t module_type,
+ int16_t row_group_ordinal, int16_t column_ordinal,
+ int16_t page_ordinal);
+
+std::string CreateFooterAad(const std::string& aad_prefix_bytes);
+
+// Update last two bytes of page (or page header) module AAD
+void QuickUpdatePageAad(const std::string& AAD, int16_t new_page_ordinal);
+
+// Wraps OpenSSL RAND_bytes function
+void RandBytes(unsigned char* buf, int num);
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/encryption_internal_nossl.cc b/src/arrow/cpp/src/parquet/encryption/encryption_internal_nossl.cc
new file mode 100644
index 000000000..7f2edfa1d
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/encryption_internal_nossl.cc
@@ -0,0 +1,110 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/encryption/encryption_internal.h"
+#include "parquet/exception.h"
+
+namespace parquet {
+namespace encryption {
+
+void ThrowOpenSSLRequiredException() {
+ throw ParquetException(
+ "Calling encryption method in Arrow/Parquet built without OpenSSL");
+}
+
+class AesEncryptor::AesEncryptorImpl {};
+
+AesEncryptor::~AesEncryptor() {}
+
+int AesEncryptor::SignedFooterEncrypt(const uint8_t* footer, int footer_len,
+ const uint8_t* key, int key_len, const uint8_t* aad,
+ int aad_len, const uint8_t* nonce,
+ uint8_t* encrypted_footer) {
+ ThrowOpenSSLRequiredException();
+ return -1;
+}
+
+void AesEncryptor::WipeOut() { ThrowOpenSSLRequiredException(); }
+
+int AesEncryptor::CiphertextSizeDelta() {
+ ThrowOpenSSLRequiredException();
+ return -1;
+}
+
+int AesEncryptor::Encrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key,
+ int key_len, const uint8_t* aad, int aad_len,
+ uint8_t* ciphertext) {
+ ThrowOpenSSLRequiredException();
+ return -1;
+}
+
+AesEncryptor::AesEncryptor(ParquetCipher::type alg_id, int key_len, bool metadata) {
+ ThrowOpenSSLRequiredException();
+}
+
+class AesDecryptor::AesDecryptorImpl {};
+
+int AesDecryptor::Decrypt(const uint8_t* plaintext, int plaintext_len, const uint8_t* key,
+ int key_len, const uint8_t* aad, int aad_len,
+ uint8_t* ciphertext) {
+ ThrowOpenSSLRequiredException();
+ return -1;
+}
+
+void AesDecryptor::WipeOut() { ThrowOpenSSLRequiredException(); }
+
+AesDecryptor::~AesDecryptor() {}
+
+AesEncryptor* AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, bool metadata,
+ std::vector<AesEncryptor*>* all_encryptors) {
+ return NULLPTR;
+}
+
+AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int key_len, bool metadata) {
+ ThrowOpenSSLRequiredException();
+}
+
+AesDecryptor* AesDecryptor::Make(ParquetCipher::type alg_id, int key_len, bool metadata,
+ std::vector<AesDecryptor*>* all_decryptors) {
+ return NULLPTR;
+}
+
+int AesDecryptor::CiphertextSizeDelta() {
+ ThrowOpenSSLRequiredException();
+ return -1;
+}
+
+std::string CreateModuleAad(const std::string& file_aad, int8_t module_type,
+ int16_t row_group_ordinal, int16_t column_ordinal,
+ int16_t page_ordinal) {
+ ThrowOpenSSLRequiredException();
+ return "";
+}
+
+std::string CreateFooterAad(const std::string& aad_prefix_bytes) {
+ ThrowOpenSSLRequiredException();
+ return "";
+}
+
+void QuickUpdatePageAad(const std::string& AAD, int16_t new_page_ordinal) {
+ ThrowOpenSSLRequiredException();
+}
+
+void RandBytes(unsigned char* buf, int num) { ThrowOpenSSLRequiredException(); }
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/file_key_material_store.h b/src/arrow/cpp/src/parquet/encryption/file_key_material_store.h
new file mode 100644
index 000000000..8cf4af48b
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/file_key_material_store.h
@@ -0,0 +1,31 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License") = 0; you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+namespace parquet {
+namespace encryption {
+
+// Key material can be stored outside the Parquet file, for example in a separate small
+// file in the same folder. This is important for “key rotation”, when MEKs have to be
+// changed (if compromised; or periodically, just in case) - without modifying the Parquet
+// files (often immutable).
+// TODO: details will be implemented later
+class FileKeyMaterialStore {};
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/file_key_unwrapper.cc b/src/arrow/cpp/src/parquet/encryption/file_key_unwrapper.cc
new file mode 100644
index 000000000..1d8d35e73
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/file_key_unwrapper.cc
@@ -0,0 +1,114 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <iostream>
+
+#include "arrow/util/utf8.h"
+
+#include "parquet/encryption/file_key_unwrapper.h"
+#include "parquet/encryption/key_metadata.h"
+
+namespace parquet {
+namespace encryption {
+
+using internal::KeyWithMasterId;
+
+FileKeyUnwrapper::FileKeyUnwrapper(KeyToolkit* key_toolkit,
+ const KmsConnectionConfig& kms_connection_config,
+ double cache_lifetime_seconds)
+ : key_toolkit_(key_toolkit),
+ kms_connection_config_(kms_connection_config),
+ cache_entry_lifetime_seconds_(cache_lifetime_seconds) {
+ kek_per_kek_id_ = key_toolkit_->kek_read_cache_per_token().GetOrCreateInternalCache(
+ kms_connection_config.key_access_token(), cache_entry_lifetime_seconds_);
+}
+
+std::string FileKeyUnwrapper::GetKey(const std::string& key_metadata_bytes) {
+ // key_metadata is expected to be in UTF8 encoding
+ ::arrow::util::InitializeUTF8();
+ if (!::arrow::util::ValidateUTF8(
+ reinterpret_cast<const uint8_t*>(key_metadata_bytes.data()),
+ key_metadata_bytes.size())) {
+ throw ParquetException("key metadata should be in UTF8 encoding");
+ }
+ KeyMetadata key_metadata = KeyMetadata::Parse(key_metadata_bytes);
+
+ if (!key_metadata.key_material_stored_internally()) {
+ throw ParquetException("External key material store is not supported yet.");
+ }
+
+ const KeyMaterial& key_material = key_metadata.key_material();
+
+ return GetDataEncryptionKey(key_material).data_key();
+}
+
+KeyWithMasterId FileKeyUnwrapper::GetDataEncryptionKey(const KeyMaterial& key_material) {
+ auto kms_client = GetKmsClientFromConfigOrKeyMaterial(key_material);
+
+ bool double_wrapping = key_material.is_double_wrapped();
+ const std::string& master_key_id = key_material.master_key_id();
+ const std::string& encoded_wrapped_dek = key_material.wrapped_dek();
+
+ std::string data_key;
+ if (!double_wrapping) {
+ data_key = kms_client->UnwrapKey(encoded_wrapped_dek, master_key_id);
+ } else {
+ // Get Key Encryption Key
+ const std::string& encoded_kek_id = key_material.kek_id();
+ const std::string& encoded_wrapped_kek = key_material.wrapped_kek();
+
+ std::string kek_bytes = kek_per_kek_id_->GetOrInsert(
+ encoded_kek_id, [kms_client, encoded_wrapped_kek, master_key_id]() {
+ return kms_client->UnwrapKey(encoded_wrapped_kek, master_key_id);
+ });
+
+ // Decrypt the data key
+ std::string aad = ::arrow::util::base64_decode(encoded_kek_id);
+ data_key = internal::DecryptKeyLocally(encoded_wrapped_dek, kek_bytes, aad);
+ }
+
+ return KeyWithMasterId(data_key, master_key_id);
+}
+
+std::shared_ptr<KmsClient> FileKeyUnwrapper::GetKmsClientFromConfigOrKeyMaterial(
+ const KeyMaterial& key_material) {
+ std::string& kms_instance_id = kms_connection_config_.kms_instance_id;
+ if (kms_instance_id.empty()) {
+ kms_instance_id = key_material.kms_instance_id();
+ if (kms_instance_id.empty()) {
+ throw ParquetException(
+ "KMS instance ID is missing both in both kms connection configuration and file "
+ "key material");
+ }
+ }
+
+ std::string& kms_instance_url = kms_connection_config_.kms_instance_url;
+ if (kms_instance_url.empty()) {
+ kms_instance_url = key_material.kms_instance_url();
+ if (kms_instance_url.empty()) {
+ throw ParquetException(
+ "KMS instance ID is missing both in both kms connection configuration and file "
+ "key material");
+ }
+ }
+
+ return key_toolkit_->GetKmsClient(kms_connection_config_,
+ cache_entry_lifetime_seconds_);
+}
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/file_key_unwrapper.h b/src/arrow/cpp/src/parquet/encryption/file_key_unwrapper.h
new file mode 100644
index 000000000..727e82b5b
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/file_key_unwrapper.h
@@ -0,0 +1,66 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/util/concurrent_map.h"
+
+#include "parquet/encryption/encryption.h"
+#include "parquet/encryption/key_material.h"
+#include "parquet/encryption/key_toolkit.h"
+#include "parquet/encryption/key_toolkit_internal.h"
+#include "parquet/encryption/kms_client.h"
+#include "parquet/platform.h"
+
+namespace parquet {
+namespace encryption {
+
+// This class will retrieve the key from "key metadata", following these steps:
+// 1. Parse "key metadata" (see structure in KeyMetadata class).
+// 2. Retrieve "key material" which can be stored inside or outside "key metadata"
+// Currently we don't support the case "key material" stores outside "key metadata"
+// yet.
+// 3. Unwrap the "data encryption key" from "key material". There are 2 modes:
+// 3.1. single wrapping: decrypt the wrapped "data encryption key" directly with "master
+// encryption key" 3.2. double wrapping: 2 steps: 3.2.1. "key encryption key" is decrypted
+// with "master encryption key" 3.2.2. "data encryption key" is decrypted with the above
+// "key encryption key"
+class PARQUET_EXPORT FileKeyUnwrapper : public DecryptionKeyRetriever {
+ public:
+ /// key_toolkit and kms_connection_config is to get KmsClient from cache or create
+ /// KmsClient if it's not in the cache yet. cache_entry_lifetime_seconds is life time of
+ /// KmsClient in the cache.
+ FileKeyUnwrapper(KeyToolkit* key_toolkit,
+ const KmsConnectionConfig& kms_connection_config,
+ double cache_lifetime_seconds);
+
+ std::string GetKey(const std::string& key_metadata) override;
+
+ private:
+ internal::KeyWithMasterId GetDataEncryptionKey(const KeyMaterial& key_material);
+ std::shared_ptr<KmsClient> GetKmsClientFromConfigOrKeyMaterial(
+ const KeyMaterial& key_material);
+
+ /// A map of Key Encryption Key (KEK) ID -> KEK bytes, for the current token
+ std::shared_ptr<::arrow::util::ConcurrentMap<std::string, std::string>> kek_per_kek_id_;
+ KeyToolkit* key_toolkit_;
+ KmsConnectionConfig kms_connection_config_;
+ const double cache_entry_lifetime_seconds_;
+};
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/file_key_wrapper.cc b/src/arrow/cpp/src/parquet/encryption/file_key_wrapper.cc
new file mode 100644
index 000000000..b931f0fbf
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/file_key_wrapper.cc
@@ -0,0 +1,109 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/encryption/file_key_wrapper.h"
+#include "parquet/encryption/encryption_internal.h"
+#include "parquet/encryption/key_material.h"
+#include "parquet/encryption/key_metadata.h"
+#include "parquet/encryption/key_toolkit_internal.h"
+#include "parquet/exception.h"
+
+namespace parquet {
+namespace encryption {
+
+FileKeyWrapper::FileKeyWrapper(KeyToolkit* key_toolkit,
+ const KmsConnectionConfig& kms_connection_config,
+ std::shared_ptr<FileKeyMaterialStore> key_material_store,
+ double cache_entry_lifetime_seconds, bool double_wrapping)
+ : kms_connection_config_(kms_connection_config),
+ key_material_store_(key_material_store),
+ cache_entry_lifetime_seconds_(cache_entry_lifetime_seconds),
+ double_wrapping_(double_wrapping) {
+ kms_connection_config_.SetDefaultIfEmpty();
+ // Check caches upon each file writing (clean once in cache_entry_lifetime_seconds_)
+ key_toolkit->kms_client_cache_per_token().CheckCacheForExpiredTokens(
+ cache_entry_lifetime_seconds_);
+ kms_client_ =
+ key_toolkit->GetKmsClient(kms_connection_config, cache_entry_lifetime_seconds_);
+
+ if (double_wrapping) {
+ key_toolkit->kek_write_cache_per_token().CheckCacheForExpiredTokens(
+ cache_entry_lifetime_seconds_);
+ kek_per_master_key_id_ =
+ key_toolkit->kek_write_cache_per_token().GetOrCreateInternalCache(
+ kms_connection_config.key_access_token(), cache_entry_lifetime_seconds_);
+ }
+}
+
+std::string FileKeyWrapper::GetEncryptionKeyMetadata(const std::string& data_key,
+ const std::string& master_key_id,
+ bool is_footer_key) {
+ if (kms_client_ == NULL) {
+ throw ParquetException("No KMS client available. See previous errors.");
+ }
+
+ std::string encoded_kek_id;
+ std::string encoded_wrapped_kek;
+ std::string encoded_wrapped_dek;
+ if (!double_wrapping_) {
+ encoded_wrapped_dek = kms_client_->WrapKey(data_key, master_key_id);
+ } else {
+ // Find in cache, or generate KEK for Master Key ID
+ KeyEncryptionKey key_encryption_key = kek_per_master_key_id_->GetOrInsert(
+ master_key_id, [this, master_key_id]() -> KeyEncryptionKey {
+ return this->CreateKeyEncryptionKey(master_key_id);
+ });
+ // Encrypt DEK with KEK
+ const std::string& aad = key_encryption_key.kek_id();
+ const std::string& kek_bytes = key_encryption_key.kek_bytes();
+ encoded_wrapped_dek = internal::EncryptKeyLocally(data_key, kek_bytes, aad);
+ encoded_kek_id = key_encryption_key.encoded_kek_id();
+ encoded_wrapped_kek = key_encryption_key.encoded_wrapped_kek();
+ }
+
+ bool store_key_material_internally = (NULL == key_material_store_);
+
+ std::string serialized_key_material =
+ KeyMaterial::SerializeToJson(is_footer_key, kms_connection_config_.kms_instance_id,
+ kms_connection_config_.kms_instance_url, master_key_id,
+ double_wrapping_, encoded_kek_id, encoded_wrapped_kek,
+ encoded_wrapped_dek, store_key_material_internally);
+
+ // Internal key material storage: key metadata and key material are the same
+ if (store_key_material_internally) {
+ return serialized_key_material;
+ } else {
+ throw ParquetException("External key material store is not supported yet.");
+ }
+}
+
+KeyEncryptionKey FileKeyWrapper::CreateKeyEncryptionKey(
+ const std::string& master_key_id) {
+ std::string kek_bytes(kKeyEncryptionKeyLength, '\0');
+ RandBytes(reinterpret_cast<uint8_t*>(&kek_bytes[0]), kKeyEncryptionKeyLength);
+
+ std::string kek_id(kKeyEncryptionKeyIdLength, '\0');
+ RandBytes(reinterpret_cast<uint8_t*>(&kek_id[0]), kKeyEncryptionKeyIdLength);
+
+ // Encrypt KEK with Master key
+ std::string encoded_wrapped_kek = kms_client_->WrapKey(kek_bytes, master_key_id);
+
+ return KeyEncryptionKey(kek_bytes, kek_id, encoded_wrapped_kek);
+}
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/file_key_wrapper.h b/src/arrow/cpp/src/parquet/encryption/file_key_wrapper.h
new file mode 100644
index 000000000..248c5931a
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/file_key_wrapper.h
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "arrow/util/concurrent_map.h"
+
+#include "parquet/encryption/file_key_material_store.h"
+#include "parquet/encryption/key_encryption_key.h"
+#include "parquet/encryption/key_toolkit.h"
+#include "parquet/encryption/kms_client.h"
+#include "parquet/platform.h"
+
+namespace parquet {
+namespace encryption {
+
+// This class will generate "key metadata" from "data encryption key" and "master key",
+// following these steps:
+// 1. Wrap "data encryption key". There are 2 modes:
+// 1.1. single wrapping: encrypt "data encryption key" directly with "master encryption
+// key"
+// 1.2. double wrapping: 2 steps:
+// 1.2.1. "key encryption key" is randomized (see KeyEncryptionKey class)
+// 1.2.2. "data encryption key" is encrypted with the above "key encryption key"
+// 2. Create "key material" (see structure in KeyMaterial class)
+// 3. Create "key metadata" with "key material" inside or a reference to outside "key
+// material" (see structure in KeyMetadata class).
+// We don't support the case "key material" stores outside "key metadata" yet.
+class PARQUET_EXPORT FileKeyWrapper {
+ public:
+ static constexpr int kKeyEncryptionKeyLength = 16;
+ static constexpr int kKeyEncryptionKeyIdLength = 16;
+
+ /// key_toolkit and kms_connection_config is to get KmsClient from the cache or create
+ /// KmsClient if it's not in the cache yet. cache_entry_lifetime_seconds is life time of
+ /// KmsClient in the cache. key_material_store is to store "key material" outside
+ /// parquet file, NULL if "key material" is stored inside parquet file.
+ FileKeyWrapper(KeyToolkit* key_toolkit,
+ const KmsConnectionConfig& kms_connection_config,
+ std::shared_ptr<FileKeyMaterialStore> key_material_store,
+ double cache_entry_lifetime_seconds, bool double_wrapping);
+
+ /// Creates key_metadata field for a given data key, via wrapping the key with the
+ /// master key
+ std::string GetEncryptionKeyMetadata(const std::string& data_key,
+ const std::string& master_key_id,
+ bool is_footer_key);
+
+ private:
+ KeyEncryptionKey CreateKeyEncryptionKey(const std::string& master_key_id);
+
+ /// A map of Master Encryption Key ID -> KeyEncryptionKey, for the current token
+ std::shared_ptr<::arrow::util::ConcurrentMap<std::string, KeyEncryptionKey>>
+ kek_per_master_key_id_;
+
+ std::shared_ptr<KmsClient> kms_client_;
+ KmsConnectionConfig kms_connection_config_;
+ std::shared_ptr<FileKeyMaterialStore> key_material_store_;
+ const double cache_entry_lifetime_seconds_;
+ const bool double_wrapping_;
+};
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/internal_file_decryptor.cc b/src/arrow/cpp/src/parquet/encryption/internal_file_decryptor.cc
new file mode 100644
index 000000000..6381e4f37
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/internal_file_decryptor.cc
@@ -0,0 +1,240 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/encryption/internal_file_decryptor.h"
+#include "parquet/encryption/encryption.h"
+#include "parquet/encryption/encryption_internal.h"
+
+namespace parquet {
+
+// Decryptor
+Decryptor::Decryptor(encryption::AesDecryptor* aes_decryptor, const std::string& key,
+ const std::string& file_aad, const std::string& aad,
+ ::arrow::MemoryPool* pool)
+ : aes_decryptor_(aes_decryptor),
+ key_(key),
+ file_aad_(file_aad),
+ aad_(aad),
+ pool_(pool) {}
+
+int Decryptor::CiphertextSizeDelta() { return aes_decryptor_->CiphertextSizeDelta(); }
+
+int Decryptor::Decrypt(const uint8_t* ciphertext, int ciphertext_len,
+ uint8_t* plaintext) {
+ return aes_decryptor_->Decrypt(ciphertext, ciphertext_len, str2bytes(key_),
+ static_cast<int>(key_.size()), str2bytes(aad_),
+ static_cast<int>(aad_.size()), plaintext);
+}
+
+// InternalFileDecryptor
+InternalFileDecryptor::InternalFileDecryptor(FileDecryptionProperties* properties,
+ const std::string& file_aad,
+ ParquetCipher::type algorithm,
+ const std::string& footer_key_metadata,
+ ::arrow::MemoryPool* pool)
+ : properties_(properties),
+ file_aad_(file_aad),
+ algorithm_(algorithm),
+ footer_key_metadata_(footer_key_metadata),
+ pool_(pool) {
+ if (properties_->is_utilized()) {
+ throw ParquetException(
+ "Re-using decryption properties with explicit keys for another file");
+ }
+ properties_->set_utilized();
+}
+
+void InternalFileDecryptor::WipeOutDecryptionKeys() {
+ properties_->WipeOutDecryptionKeys();
+ for (auto const& i : all_decryptors_) {
+ i->WipeOut();
+ }
+}
+
+std::string InternalFileDecryptor::GetFooterKey() {
+ std::string footer_key = properties_->footer_key();
+ // ignore footer key metadata if footer key is explicitly set via API
+ if (footer_key.empty()) {
+ if (footer_key_metadata_.empty())
+ throw ParquetException("No footer key or key metadata");
+ if (properties_->key_retriever() == nullptr)
+ throw ParquetException("No footer key or key retriever");
+ try {
+ footer_key = properties_->key_retriever()->GetKey(footer_key_metadata_);
+ } catch (KeyAccessDeniedException& e) {
+ std::stringstream ss;
+ ss << "Footer key: access denied " << e.what() << "\n";
+ throw ParquetException(ss.str());
+ }
+ }
+ if (footer_key.empty()) {
+ throw ParquetException(
+ "Footer key unavailable. Could not verify "
+ "plaintext footer metadata");
+ }
+ return footer_key;
+}
+
+std::shared_ptr<Decryptor> InternalFileDecryptor::GetFooterDecryptor() {
+ std::string aad = encryption::CreateFooterAad(file_aad_);
+ return GetFooterDecryptor(aad, true);
+}
+
+std::shared_ptr<Decryptor> InternalFileDecryptor::GetFooterDecryptorForColumnMeta(
+ const std::string& aad) {
+ return GetFooterDecryptor(aad, true);
+}
+
+std::shared_ptr<Decryptor> InternalFileDecryptor::GetFooterDecryptorForColumnData(
+ const std::string& aad) {
+ return GetFooterDecryptor(aad, false);
+}
+
+std::shared_ptr<Decryptor> InternalFileDecryptor::GetFooterDecryptor(
+ const std::string& aad, bool metadata) {
+ if (metadata) {
+ if (footer_metadata_decryptor_ != nullptr) return footer_metadata_decryptor_;
+ } else {
+ if (footer_data_decryptor_ != nullptr) return footer_data_decryptor_;
+ }
+
+ std::string footer_key = properties_->footer_key();
+ if (footer_key.empty()) {
+ if (footer_key_metadata_.empty())
+ throw ParquetException("No footer key or key metadata");
+ if (properties_->key_retriever() == nullptr)
+ throw ParquetException("No footer key or key retriever");
+ try {
+ footer_key = properties_->key_retriever()->GetKey(footer_key_metadata_);
+ } catch (KeyAccessDeniedException& e) {
+ std::stringstream ss;
+ ss << "Footer key: access denied " << e.what() << "\n";
+ throw ParquetException(ss.str());
+ }
+ }
+ if (footer_key.empty()) {
+ throw ParquetException(
+ "Invalid footer encryption key. "
+ "Could not parse footer metadata");
+ }
+
+ // Create both data and metadata decryptors to avoid redundant retrieval of key
+ // from the key_retriever.
+ auto aes_metadata_decryptor = GetMetaAesDecryptor(footer_key.size());
+ auto aes_data_decryptor = GetDataAesDecryptor(footer_key.size());
+
+ footer_metadata_decryptor_ = std::make_shared<Decryptor>(
+ aes_metadata_decryptor, footer_key, file_aad_, aad, pool_);
+ footer_data_decryptor_ =
+ std::make_shared<Decryptor>(aes_data_decryptor, footer_key, file_aad_, aad, pool_);
+
+ if (metadata) return footer_metadata_decryptor_;
+ return footer_data_decryptor_;
+}
+
+std::shared_ptr<Decryptor> InternalFileDecryptor::GetColumnMetaDecryptor(
+ const std::string& column_path, const std::string& column_key_metadata,
+ const std::string& aad) {
+ return GetColumnDecryptor(column_path, column_key_metadata, aad, true);
+}
+
+std::shared_ptr<Decryptor> InternalFileDecryptor::GetColumnDataDecryptor(
+ const std::string& column_path, const std::string& column_key_metadata,
+ const std::string& aad) {
+ return GetColumnDecryptor(column_path, column_key_metadata, aad, false);
+}
+
+std::shared_ptr<Decryptor> InternalFileDecryptor::GetColumnDecryptor(
+ const std::string& column_path, const std::string& column_key_metadata,
+ const std::string& aad, bool metadata) {
+ std::string column_key;
+ // first look if we already got the decryptor from before
+ if (metadata) {
+ if (column_metadata_map_.find(column_path) != column_metadata_map_.end()) {
+ auto res(column_metadata_map_.at(column_path));
+ res->UpdateAad(aad);
+ return res;
+ }
+ } else {
+ if (column_data_map_.find(column_path) != column_data_map_.end()) {
+ auto res(column_data_map_.at(column_path));
+ res->UpdateAad(aad);
+ return res;
+ }
+ }
+
+ column_key = properties_->column_key(column_path);
+ // No explicit column key given via API. Retrieve via key metadata.
+ if (column_key.empty() && !column_key_metadata.empty() &&
+ properties_->key_retriever() != nullptr) {
+ try {
+ column_key = properties_->key_retriever()->GetKey(column_key_metadata);
+ } catch (KeyAccessDeniedException& e) {
+ std::stringstream ss;
+ ss << "HiddenColumnException, path=" + column_path + " " << e.what() << "\n";
+ throw HiddenColumnException(ss.str());
+ }
+ }
+ if (column_key.empty()) {
+ throw HiddenColumnException("HiddenColumnException, path=" + column_path);
+ }
+
+ // Create both data and metadata decryptors to avoid redundant retrieval of key
+ // using the key_retriever.
+ auto aes_metadata_decryptor = GetMetaAesDecryptor(column_key.size());
+ auto aes_data_decryptor = GetDataAesDecryptor(column_key.size());
+
+ column_metadata_map_[column_path] = std::make_shared<Decryptor>(
+ aes_metadata_decryptor, column_key, file_aad_, aad, pool_);
+ column_data_map_[column_path] =
+ std::make_shared<Decryptor>(aes_data_decryptor, column_key, file_aad_, aad, pool_);
+
+ if (metadata) return column_metadata_map_[column_path];
+ return column_data_map_[column_path];
+}
+
+int InternalFileDecryptor::MapKeyLenToDecryptorArrayIndex(int key_len) {
+ if (key_len == 16)
+ return 0;
+ else if (key_len == 24)
+ return 1;
+ else if (key_len == 32)
+ return 2;
+ throw ParquetException("decryption key must be 16, 24 or 32 bytes in length");
+}
+
+encryption::AesDecryptor* InternalFileDecryptor::GetMetaAesDecryptor(size_t key_size) {
+ int key_len = static_cast<int>(key_size);
+ int index = MapKeyLenToDecryptorArrayIndex(key_len);
+ if (meta_decryptor_[index] == nullptr) {
+ meta_decryptor_[index].reset(
+ encryption::AesDecryptor::Make(algorithm_, key_len, true, &all_decryptors_));
+ }
+ return meta_decryptor_[index].get();
+}
+
+encryption::AesDecryptor* InternalFileDecryptor::GetDataAesDecryptor(size_t key_size) {
+ int key_len = static_cast<int>(key_size);
+ int index = MapKeyLenToDecryptorArrayIndex(key_len);
+ if (data_decryptor_[index] == nullptr) {
+ data_decryptor_[index].reset(
+ encryption::AesDecryptor::Make(algorithm_, key_len, false, &all_decryptors_));
+ }
+ return data_decryptor_[index].get();
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/internal_file_decryptor.h b/src/arrow/cpp/src/parquet/encryption/internal_file_decryptor.h
new file mode 100644
index 000000000..011c4acbe
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/internal_file_decryptor.h
@@ -0,0 +1,121 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "parquet/schema.h"
+
+namespace parquet {
+
+namespace encryption {
+class AesDecryptor;
+class AesEncryptor;
+} // namespace encryption
+
+class FileDecryptionProperties;
+
+class PARQUET_EXPORT Decryptor {
+ public:
+ Decryptor(encryption::AesDecryptor* decryptor, const std::string& key,
+ const std::string& file_aad, const std::string& aad,
+ ::arrow::MemoryPool* pool);
+
+ const std::string& file_aad() const { return file_aad_; }
+ void UpdateAad(const std::string& aad) { aad_ = aad; }
+ ::arrow::MemoryPool* pool() { return pool_; }
+
+ int CiphertextSizeDelta();
+ int Decrypt(const uint8_t* ciphertext, int ciphertext_len, uint8_t* plaintext);
+
+ private:
+ encryption::AesDecryptor* aes_decryptor_;
+ std::string key_;
+ std::string file_aad_;
+ std::string aad_;
+ ::arrow::MemoryPool* pool_;
+};
+
+class InternalFileDecryptor {
+ public:
+ explicit InternalFileDecryptor(FileDecryptionProperties* properties,
+ const std::string& file_aad,
+ ParquetCipher::type algorithm,
+ const std::string& footer_key_metadata,
+ ::arrow::MemoryPool* pool);
+
+ std::string& file_aad() { return file_aad_; }
+
+ std::string GetFooterKey();
+
+ ParquetCipher::type algorithm() { return algorithm_; }
+
+ std::string& footer_key_metadata() { return footer_key_metadata_; }
+
+ FileDecryptionProperties* properties() { return properties_; }
+
+ void WipeOutDecryptionKeys();
+
+ ::arrow::MemoryPool* pool() { return pool_; }
+
+ std::shared_ptr<Decryptor> GetFooterDecryptor();
+ std::shared_ptr<Decryptor> GetFooterDecryptorForColumnMeta(const std::string& aad = "");
+ std::shared_ptr<Decryptor> GetFooterDecryptorForColumnData(const std::string& aad = "");
+ std::shared_ptr<Decryptor> GetColumnMetaDecryptor(
+ const std::string& column_path, const std::string& column_key_metadata,
+ const std::string& aad = "");
+ std::shared_ptr<Decryptor> GetColumnDataDecryptor(
+ const std::string& column_path, const std::string& column_key_metadata,
+ const std::string& aad = "");
+
+ private:
+ FileDecryptionProperties* properties_;
+ // Concatenation of aad_prefix (if exists) and aad_file_unique
+ std::string file_aad_;
+ std::map<std::string, std::shared_ptr<Decryptor>> column_data_map_;
+ std::map<std::string, std::shared_ptr<Decryptor>> column_metadata_map_;
+
+ std::shared_ptr<Decryptor> footer_metadata_decryptor_;
+ std::shared_ptr<Decryptor> footer_data_decryptor_;
+ ParquetCipher::type algorithm_;
+ std::string footer_key_metadata_;
+ std::vector<encryption::AesDecryptor*> all_decryptors_;
+
+ /// Key must be 16, 24 or 32 bytes in length. Thus there could be up to three
+ // types of meta_decryptors and data_decryptors.
+ std::unique_ptr<encryption::AesDecryptor> meta_decryptor_[3];
+ std::unique_ptr<encryption::AesDecryptor> data_decryptor_[3];
+
+ ::arrow::MemoryPool* pool_;
+
+ std::shared_ptr<Decryptor> GetFooterDecryptor(const std::string& aad, bool metadata);
+ std::shared_ptr<Decryptor> GetColumnDecryptor(const std::string& column_path,
+ const std::string& column_key_metadata,
+ const std::string& aad,
+ bool metadata = false);
+
+ encryption::AesDecryptor* GetMetaAesDecryptor(size_t key_size);
+ encryption::AesDecryptor* GetDataAesDecryptor(size_t key_size);
+
+ int MapKeyLenToDecryptorArrayIndex(int key_len);
+};
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/internal_file_encryptor.cc b/src/arrow/cpp/src/parquet/encryption/internal_file_encryptor.cc
new file mode 100644
index 000000000..15bf52b84
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/internal_file_encryptor.cc
@@ -0,0 +1,170 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/encryption/internal_file_encryptor.h"
+#include "parquet/encryption/encryption.h"
+#include "parquet/encryption/encryption_internal.h"
+
+namespace parquet {
+
+// Encryptor
+Encryptor::Encryptor(encryption::AesEncryptor* aes_encryptor, const std::string& key,
+ const std::string& file_aad, const std::string& aad,
+ ::arrow::MemoryPool* pool)
+ : aes_encryptor_(aes_encryptor),
+ key_(key),
+ file_aad_(file_aad),
+ aad_(aad),
+ pool_(pool) {}
+
+int Encryptor::CiphertextSizeDelta() { return aes_encryptor_->CiphertextSizeDelta(); }
+
+int Encryptor::Encrypt(const uint8_t* plaintext, int plaintext_len, uint8_t* ciphertext) {
+ return aes_encryptor_->Encrypt(plaintext, plaintext_len, str2bytes(key_),
+ static_cast<int>(key_.size()), str2bytes(aad_),
+ static_cast<int>(aad_.size()), ciphertext);
+}
+
+// InternalFileEncryptor
+InternalFileEncryptor::InternalFileEncryptor(FileEncryptionProperties* properties,
+ ::arrow::MemoryPool* pool)
+ : properties_(properties), pool_(pool) {
+ if (properties_->is_utilized()) {
+ throw ParquetException("Re-using encryption properties for another file");
+ }
+ properties_->set_utilized();
+}
+
+void InternalFileEncryptor::WipeOutEncryptionKeys() {
+ properties_->WipeOutEncryptionKeys();
+
+ for (auto const& i : all_encryptors_) {
+ i->WipeOut();
+ }
+}
+
+std::shared_ptr<Encryptor> InternalFileEncryptor::GetFooterEncryptor() {
+ if (footer_encryptor_ != nullptr) {
+ return footer_encryptor_;
+ }
+
+ ParquetCipher::type algorithm = properties_->algorithm().algorithm;
+ std::string footer_aad = encryption::CreateFooterAad(properties_->file_aad());
+ std::string footer_key = properties_->footer_key();
+ auto aes_encryptor = GetMetaAesEncryptor(algorithm, footer_key.size());
+ footer_encryptor_ = std::make_shared<Encryptor>(
+ aes_encryptor, footer_key, properties_->file_aad(), footer_aad, pool_);
+ return footer_encryptor_;
+}
+
+std::shared_ptr<Encryptor> InternalFileEncryptor::GetFooterSigningEncryptor() {
+ if (footer_signing_encryptor_ != nullptr) {
+ return footer_signing_encryptor_;
+ }
+
+ ParquetCipher::type algorithm = properties_->algorithm().algorithm;
+ std::string footer_aad = encryption::CreateFooterAad(properties_->file_aad());
+ std::string footer_signing_key = properties_->footer_key();
+ auto aes_encryptor = GetMetaAesEncryptor(algorithm, footer_signing_key.size());
+ footer_signing_encryptor_ = std::make_shared<Encryptor>(
+ aes_encryptor, footer_signing_key, properties_->file_aad(), footer_aad, pool_);
+ return footer_signing_encryptor_;
+}
+
+std::shared_ptr<Encryptor> InternalFileEncryptor::GetColumnMetaEncryptor(
+ const std::string& column_path) {
+ return GetColumnEncryptor(column_path, true);
+}
+
+std::shared_ptr<Encryptor> InternalFileEncryptor::GetColumnDataEncryptor(
+ const std::string& column_path) {
+ return GetColumnEncryptor(column_path, false);
+}
+
+std::shared_ptr<Encryptor>
+InternalFileEncryptor::InternalFileEncryptor::GetColumnEncryptor(
+ const std::string& column_path, bool metadata) {
+ // first look if we already got the encryptor from before
+ if (metadata) {
+ if (column_metadata_map_.find(column_path) != column_metadata_map_.end()) {
+ return column_metadata_map_.at(column_path);
+ }
+ } else {
+ if (column_data_map_.find(column_path) != column_data_map_.end()) {
+ return column_data_map_.at(column_path);
+ }
+ }
+ auto column_prop = properties_->column_encryption_properties(column_path);
+ if (column_prop == nullptr) {
+ return nullptr;
+ }
+
+ std::string key;
+ if (column_prop->is_encrypted_with_footer_key()) {
+ key = properties_->footer_key();
+ } else {
+ key = column_prop->key();
+ }
+
+ ParquetCipher::type algorithm = properties_->algorithm().algorithm;
+ auto aes_encryptor = metadata ? GetMetaAesEncryptor(algorithm, key.size())
+ : GetDataAesEncryptor(algorithm, key.size());
+
+ std::string file_aad = properties_->file_aad();
+ std::shared_ptr<Encryptor> encryptor =
+ std::make_shared<Encryptor>(aes_encryptor, key, file_aad, "", pool_);
+ if (metadata)
+ column_metadata_map_[column_path] = encryptor;
+ else
+ column_data_map_[column_path] = encryptor;
+
+ return encryptor;
+}
+
+int InternalFileEncryptor::MapKeyLenToEncryptorArrayIndex(int key_len) {
+ if (key_len == 16)
+ return 0;
+ else if (key_len == 24)
+ return 1;
+ else if (key_len == 32)
+ return 2;
+ throw ParquetException("encryption key must be 16, 24 or 32 bytes in length");
+}
+
+encryption::AesEncryptor* InternalFileEncryptor::GetMetaAesEncryptor(
+ ParquetCipher::type algorithm, size_t key_size) {
+ int key_len = static_cast<int>(key_size);
+ int index = MapKeyLenToEncryptorArrayIndex(key_len);
+ if (meta_encryptor_[index] == nullptr) {
+ meta_encryptor_[index].reset(
+ encryption::AesEncryptor::Make(algorithm, key_len, true, &all_encryptors_));
+ }
+ return meta_encryptor_[index].get();
+}
+
+encryption::AesEncryptor* InternalFileEncryptor::GetDataAesEncryptor(
+ ParquetCipher::type algorithm, size_t key_size) {
+ int key_len = static_cast<int>(key_size);
+ int index = MapKeyLenToEncryptorArrayIndex(key_len);
+ if (data_encryptor_[index] == nullptr) {
+ data_encryptor_[index].reset(
+ encryption::AesEncryptor::Make(algorithm, key_len, false, &all_encryptors_));
+ }
+ return data_encryptor_[index].get();
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/internal_file_encryptor.h b/src/arrow/cpp/src/parquet/encryption/internal_file_encryptor.h
new file mode 100644
index 000000000..3cbe53500
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/internal_file_encryptor.h
@@ -0,0 +1,109 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "parquet/encryption/encryption.h"
+#include "parquet/schema.h"
+
+namespace parquet {
+
+namespace encryption {
+class AesEncryptor;
+} // namespace encryption
+
+class FileEncryptionProperties;
+class ColumnEncryptionProperties;
+
+class PARQUET_EXPORT Encryptor {
+ public:
+ Encryptor(encryption::AesEncryptor* aes_encryptor, const std::string& key,
+ const std::string& file_aad, const std::string& aad,
+ ::arrow::MemoryPool* pool);
+ const std::string& file_aad() { return file_aad_; }
+ void UpdateAad(const std::string& aad) { aad_ = aad; }
+ ::arrow::MemoryPool* pool() { return pool_; }
+
+ int CiphertextSizeDelta();
+ int Encrypt(const uint8_t* plaintext, int plaintext_len, uint8_t* ciphertext);
+
+ bool EncryptColumnMetaData(
+ bool encrypted_footer,
+ const std::shared_ptr<ColumnEncryptionProperties>& column_encryption_properties) {
+ // if column is not encrypted then do not encrypt the column metadata
+ if (!column_encryption_properties || !column_encryption_properties->is_encrypted())
+ return false;
+ // if plaintext footer then encrypt the column metadata
+ if (!encrypted_footer) return true;
+ // if column is not encrypted with footer key then encrypt the column metadata
+ return !column_encryption_properties->is_encrypted_with_footer_key();
+ }
+
+ private:
+ encryption::AesEncryptor* aes_encryptor_;
+ std::string key_;
+ std::string file_aad_;
+ std::string aad_;
+ ::arrow::MemoryPool* pool_;
+};
+
+class InternalFileEncryptor {
+ public:
+ explicit InternalFileEncryptor(FileEncryptionProperties* properties,
+ ::arrow::MemoryPool* pool);
+
+ std::shared_ptr<Encryptor> GetFooterEncryptor();
+ std::shared_ptr<Encryptor> GetFooterSigningEncryptor();
+ std::shared_ptr<Encryptor> GetColumnMetaEncryptor(const std::string& column_path);
+ std::shared_ptr<Encryptor> GetColumnDataEncryptor(const std::string& column_path);
+ void WipeOutEncryptionKeys();
+
+ private:
+ FileEncryptionProperties* properties_;
+
+ std::map<std::string, std::shared_ptr<Encryptor>> column_data_map_;
+ std::map<std::string, std::shared_ptr<Encryptor>> column_metadata_map_;
+
+ std::shared_ptr<Encryptor> footer_signing_encryptor_;
+ std::shared_ptr<Encryptor> footer_encryptor_;
+
+ std::vector<encryption::AesEncryptor*> all_encryptors_;
+
+ // Key must be 16, 24 or 32 bytes in length. Thus there could be up to three
+ // types of meta_encryptors and data_encryptors.
+ std::unique_ptr<encryption::AesEncryptor> meta_encryptor_[3];
+ std::unique_ptr<encryption::AesEncryptor> data_encryptor_[3];
+
+ ::arrow::MemoryPool* pool_;
+
+ std::shared_ptr<Encryptor> GetColumnEncryptor(const std::string& column_path,
+ bool metadata);
+
+ encryption::AesEncryptor* GetMetaAesEncryptor(ParquetCipher::type algorithm,
+ size_t key_len);
+ encryption::AesEncryptor* GetDataAesEncryptor(ParquetCipher::type algorithm,
+ size_t key_len);
+
+ int MapKeyLenToEncryptorArrayIndex(int key_len);
+};
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/key_encryption_key.h b/src/arrow/cpp/src/parquet/encryption/key_encryption_key.h
new file mode 100644
index 000000000..153bb4b5e
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/key_encryption_key.h
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <vector>
+
+#include "arrow/util/base64.h"
+
+namespace parquet {
+namespace encryption {
+
+// In the double wrapping mode, each "data encryption key" (DEK) is encrypted with a “key
+// encryption key” (KEK), that in turn is encrypted with a "master encryption key" (MEK).
+// In a writer process, a random KEK is generated for each MEK ID, and cached in a <MEK-ID
+// : KEK> map. This allows to perform an interaction with a KMS server only once for each
+// MEK, in order to wrap its KEK. "Data encryption key" (DEK) wrapping is performed
+// locally, and does not involve an interaction with a KMS server.
+class KeyEncryptionKey {
+ public:
+ KeyEncryptionKey(std::string kek_bytes, std::string kek_id,
+ std::string encoded_wrapped_kek)
+ : kek_bytes_(std::move(kek_bytes)),
+ kek_id_(std::move(kek_id)),
+ encoded_kek_id_(::arrow::util::base64_encode(kek_id_)),
+ encoded_wrapped_kek_(std::move(encoded_wrapped_kek)) {}
+
+ const std::string& kek_bytes() const { return kek_bytes_; }
+
+ const std::string& kek_id() const { return kek_id_; }
+
+ const std::string& encoded_kek_id() const { return encoded_kek_id_; }
+
+ const std::string& encoded_wrapped_kek() const { return encoded_wrapped_kek_; }
+
+ private:
+ std::string kek_bytes_;
+ std::string kek_id_;
+ std::string encoded_kek_id_;
+ std::string encoded_wrapped_kek_;
+};
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/key_management_test.cc b/src/arrow/cpp/src/parquet/encryption/key_management_test.cc
new file mode 100644
index 000000000..81e7d89a3
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/key_management_test.cc
@@ -0,0 +1,225 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <iostream>
+#include <string>
+#include <thread>
+#include <unordered_map>
+
+#include "arrow/testing/util.h"
+#include "arrow/util/logging.h"
+
+#include "parquet/encryption/crypto_factory.h"
+#include "parquet/encryption/key_toolkit.h"
+#include "parquet/encryption/test_encryption_util.h"
+#include "parquet/encryption/test_in_memory_kms.h"
+#include "parquet/test_util.h"
+
+namespace parquet {
+namespace encryption {
+namespace test {
+
+std::unique_ptr<TemporaryDir> temp_dir;
+
+class TestEncryptionKeyManagement : public ::testing::Test {
+ public:
+ void SetUp() {
+ key_list_ = BuildKeyMap(kColumnMasterKeyIds, kColumnMasterKeys, kFooterMasterKeyId,
+ kFooterMasterKey);
+ column_key_mapping_ = BuildColumnKeyMapping();
+ }
+
+ static void SetUpTestCase();
+
+ protected:
+ FileEncryptor encryptor_;
+ FileDecryptor decryptor_;
+
+ std::unordered_map<std::string, std::string> key_list_;
+ std::string column_key_mapping_;
+ KmsConnectionConfig kms_connection_config_;
+ CryptoFactory crypto_factory_;
+ bool wrap_locally_;
+
+ void SetupCryptoFactory(bool wrap_locally) {
+ wrap_locally_ = wrap_locally;
+ std::shared_ptr<KmsClientFactory> kms_client_factory =
+ std::make_shared<TestOnlyInMemoryKmsClientFactory>(wrap_locally, key_list_);
+ crypto_factory_.RegisterKmsClientFactory(kms_client_factory);
+ }
+
+ std::string GetFileName(bool double_wrapping, bool wrap_locally, int encryption_no) {
+ std::string file_name;
+ file_name += double_wrapping ? "double_wrapping" : "no_double_wrapping";
+ file_name += wrap_locally ? "-wrap_locally" : "-wrap_on_server";
+ switch (encryption_no) {
+ case 0:
+ file_name += "-encrypt_columns_and_footer_diff_keys";
+ break;
+ case 1:
+ file_name += "-encrypt_columns_not_footer";
+ break;
+ case 2:
+ file_name += "-encrypt_columns_and_footer_same_keys";
+ break;
+ case 3:
+ file_name += "-encrypt_columns_and_footer_ctr";
+ break;
+ default:
+ file_name += "-no_encrypt";
+ break;
+ }
+ file_name += encryption_no == 4 ? ".parquet" : ".parquet.encrypted";
+ return file_name;
+ }
+
+ EncryptionConfiguration GetEncryptionConfiguration(bool double_wrapping,
+ int encryption_no) {
+ EncryptionConfiguration encryption(kFooterMasterKeyId);
+ encryption.double_wrapping = double_wrapping;
+
+ switch (encryption_no) {
+ case 0:
+ // encrypt some columns and footer, different keys
+ encryption.column_keys = column_key_mapping_;
+ break;
+ case 1:
+ // encrypt columns, plaintext footer, different keys
+ encryption.column_keys = column_key_mapping_;
+ encryption.plaintext_footer = true;
+ break;
+ case 2:
+ // encrypt some columns and footer, same key
+ encryption.uniform_encryption = true;
+ break;
+ case 3:
+ // Encrypt two columns and the footer, with different keys.
+ // Use AES_GCM_CTR_V1 algorithm.
+ encryption.column_keys = column_key_mapping_;
+ encryption.encryption_algorithm = ParquetCipher::AES_GCM_CTR_V1;
+ break;
+ default:
+ // no encryption
+ ARROW_LOG(FATAL) << "Invalid encryption_no";
+ }
+
+ return encryption;
+ }
+
+ DecryptionConfiguration GetDecryptionConfiguration() {
+ return DecryptionConfiguration();
+ }
+
+ void WriteEncryptedParquetFile(bool double_wrapping, int encryption_no) {
+ std::string file_name = GetFileName(double_wrapping, wrap_locally_, encryption_no);
+ auto encryption_config = GetEncryptionConfiguration(double_wrapping, encryption_no);
+
+ auto file_encryption_properties = crypto_factory_.GetFileEncryptionProperties(
+ kms_connection_config_, encryption_config);
+ std::string file = temp_dir->path().ToString() + file_name;
+
+ encryptor_.EncryptFile(file, file_encryption_properties);
+ }
+
+ void ReadEncryptedParquetFile(bool double_wrapping, int encryption_no) {
+ auto decryption_config = GetDecryptionConfiguration();
+ std::string file_name = GetFileName(double_wrapping, wrap_locally_, encryption_no);
+
+ auto file_decryption_properties = crypto_factory_.GetFileDecryptionProperties(
+ kms_connection_config_, decryption_config);
+ std::string file = temp_dir->path().ToString() + file_name;
+
+ decryptor_.DecryptFile(file, file_decryption_properties);
+ }
+};
+
+class TestEncryptionKeyManagementMultiThread : public TestEncryptionKeyManagement {
+ protected:
+ void WriteEncryptedParquetFiles() {
+ std::vector<std::thread> write_threads;
+ for (const bool double_wrapping : {false, true}) {
+ for (int encryption_no = 0; encryption_no < 4; encryption_no++) {
+ write_threads.push_back(std::thread([this, double_wrapping, encryption_no]() {
+ this->WriteEncryptedParquetFile(double_wrapping, encryption_no);
+ }));
+ }
+ }
+ for (auto& th : write_threads) {
+ th.join();
+ }
+ }
+
+ void ReadEncryptedParquetFiles() {
+ std::vector<std::thread> read_threads;
+ for (const bool double_wrapping : {false, true}) {
+ for (int encryption_no = 0; encryption_no < 4; encryption_no++) {
+ read_threads.push_back(std::thread([this, double_wrapping, encryption_no]() {
+ this->ReadEncryptedParquetFile(double_wrapping, encryption_no);
+ }));
+ }
+ }
+ for (auto& th : read_threads) {
+ th.join();
+ }
+ }
+};
+
+TEST_F(TestEncryptionKeyManagement, WrapLocally) {
+ this->SetupCryptoFactory(true);
+
+ for (const bool double_wrapping : {false, true}) {
+ for (int encryption_no = 0; encryption_no < 4; encryption_no++) {
+ this->WriteEncryptedParquetFile(double_wrapping, encryption_no);
+ this->ReadEncryptedParquetFile(double_wrapping, encryption_no);
+ }
+ }
+}
+
+TEST_F(TestEncryptionKeyManagement, WrapOnServer) {
+ this->SetupCryptoFactory(false);
+
+ for (const bool double_wrapping : {false, true}) {
+ for (int encryption_no = 0; encryption_no < 4; encryption_no++) {
+ this->WriteEncryptedParquetFile(double_wrapping, encryption_no);
+ this->ReadEncryptedParquetFile(double_wrapping, encryption_no);
+ }
+ }
+}
+
+TEST_F(TestEncryptionKeyManagementMultiThread, WrapLocally) {
+ this->SetupCryptoFactory(true);
+
+ this->WriteEncryptedParquetFiles();
+ this->ReadEncryptedParquetFiles();
+}
+
+TEST_F(TestEncryptionKeyManagementMultiThread, WrapOnServer) {
+ this->SetupCryptoFactory(false);
+
+ this->WriteEncryptedParquetFiles();
+ this->ReadEncryptedParquetFiles();
+}
+
+// Set temp_dir before running the write/read tests. The encrypted files will
+// be written/read from this directory.
+void TestEncryptionKeyManagement::SetUpTestCase() { temp_dir = *temp_data_dir(); }
+
+} // namespace test
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/key_material.cc b/src/arrow/cpp/src/parquet/encryption/key_material.cc
new file mode 100644
index 000000000..372279c33
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/key_material.cc
@@ -0,0 +1,159 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/json/object_parser.h"
+#include "arrow/json/object_writer.h"
+
+#include "parquet/encryption/key_material.h"
+#include "parquet/encryption/key_metadata.h"
+#include "parquet/exception.h"
+
+using ::arrow::json::internal::ObjectParser;
+using ::arrow::json::internal::ObjectWriter;
+
+namespace parquet {
+namespace encryption {
+
+constexpr const char KeyMaterial::kKeyMaterialTypeField[];
+constexpr const char KeyMaterial::kKeyMaterialType1[];
+
+constexpr const char KeyMaterial::kFooterKeyIdInFile[];
+constexpr const char KeyMaterial::kColumnKeyIdInFilePrefix[];
+
+constexpr const char KeyMaterial::kIsFooterKeyField[];
+constexpr const char KeyMaterial::kDoubleWrappingField[];
+constexpr const char KeyMaterial::kKmsInstanceIdField[];
+constexpr const char KeyMaterial::kKmsInstanceUrlField[];
+constexpr const char KeyMaterial::kMasterKeyIdField[];
+constexpr const char KeyMaterial::kWrappedDataEncryptionKeyField[];
+constexpr const char KeyMaterial::kKeyEncryptionKeyIdField[];
+constexpr const char KeyMaterial::kWrappedKeyEncryptionKeyField[];
+
+KeyMaterial::KeyMaterial(bool is_footer_key, const std::string& kms_instance_id,
+ const std::string& kms_instance_url,
+ const std::string& master_key_id, bool is_double_wrapped,
+ const std::string& kek_id,
+ const std::string& encoded_wrapped_kek,
+ const std::string& encoded_wrapped_dek)
+ : is_footer_key_(is_footer_key),
+ kms_instance_id_(kms_instance_id),
+ kms_instance_url_(kms_instance_url),
+ master_key_id_(master_key_id),
+ is_double_wrapped_(is_double_wrapped),
+ kek_id_(kek_id),
+ encoded_wrapped_kek_(encoded_wrapped_kek),
+ encoded_wrapped_dek_(encoded_wrapped_dek) {}
+
+KeyMaterial KeyMaterial::Parse(const std::string& key_material_string) {
+ ObjectParser json_parser;
+ ::arrow::Status status = json_parser.Parse(key_material_string);
+ if (!status.ok()) {
+ throw ParquetException("Failed to parse key material " + key_material_string);
+ }
+
+ // External key material - extract "key material type", and make sure it is supported
+ std::string key_material_type;
+ PARQUET_ASSIGN_OR_THROW(key_material_type,
+ json_parser.GetString(kKeyMaterialTypeField));
+ if (kKeyMaterialType1 != key_material_type) {
+ throw ParquetException("Wrong key material type: " + key_material_type + " vs " +
+ kKeyMaterialType1);
+ }
+ // Parse other fields (common to internal and external key material)
+ return Parse(&json_parser);
+}
+
+KeyMaterial KeyMaterial::Parse(const ObjectParser* key_material_json) {
+ // 2. Check if "key material" belongs to file footer key
+ bool is_footer_key;
+ PARQUET_ASSIGN_OR_THROW(is_footer_key, key_material_json->GetBool(kIsFooterKeyField));
+ std::string kms_instance_id;
+ std::string kms_instance_url;
+ if (is_footer_key) {
+ // 3. For footer key, extract KMS Instance ID
+ PARQUET_ASSIGN_OR_THROW(kms_instance_id,
+ key_material_json->GetString(kKmsInstanceIdField));
+ // 4. For footer key, extract KMS Instance URL
+ PARQUET_ASSIGN_OR_THROW(kms_instance_url,
+ key_material_json->GetString(kKmsInstanceUrlField));
+ }
+ // 5. Extract master key ID
+ std::string master_key_id;
+ PARQUET_ASSIGN_OR_THROW(master_key_id, key_material_json->GetString(kMasterKeyIdField));
+ // 6. Extract wrapped DEK
+ std::string encoded_wrapped_dek;
+ PARQUET_ASSIGN_OR_THROW(encoded_wrapped_dek,
+ key_material_json->GetString(kWrappedDataEncryptionKeyField));
+ std::string kek_id;
+ std::string encoded_wrapped_kek;
+ // 7. Check if "key material" was generated in double wrapping mode
+ bool is_double_wrapped;
+ PARQUET_ASSIGN_OR_THROW(is_double_wrapped,
+ key_material_json->GetBool(kDoubleWrappingField));
+ if (is_double_wrapped) {
+ // 8. In double wrapping mode, extract KEK ID
+ PARQUET_ASSIGN_OR_THROW(kek_id,
+ key_material_json->GetString(kKeyEncryptionKeyIdField));
+ // 9. In double wrapping mode, extract wrapped KEK
+ PARQUET_ASSIGN_OR_THROW(encoded_wrapped_kek,
+ key_material_json->GetString(kWrappedKeyEncryptionKeyField));
+ }
+
+ return KeyMaterial(is_footer_key, kms_instance_id, kms_instance_url, master_key_id,
+ is_double_wrapped, kek_id, encoded_wrapped_kek, encoded_wrapped_dek);
+}
+
+std::string KeyMaterial::SerializeToJson(
+ bool is_footer_key, const std::string& kms_instance_id,
+ const std::string& kms_instance_url, const std::string& master_key_id,
+ bool is_double_wrapped, const std::string& kek_id,
+ const std::string& encoded_wrapped_kek, const std::string& encoded_wrapped_dek,
+ bool is_internal_storage) {
+ ObjectWriter json_writer;
+ json_writer.SetString(kKeyMaterialTypeField, kKeyMaterialType1);
+
+ if (is_internal_storage) {
+ // 1. for internal storage, key material and key metadata are the same.
+ // adding the "internalStorage" field that belongs to KeyMetadata.
+ json_writer.SetBool(KeyMetadata::kKeyMaterialInternalStorageField, true);
+ }
+ // 2. Write isFooterKey
+ json_writer.SetBool(kIsFooterKeyField, is_footer_key);
+ if (is_footer_key) {
+ // 3. For footer key, write KMS Instance ID
+ json_writer.SetString(kKmsInstanceIdField, kms_instance_id);
+ // 4. For footer key, write KMS Instance URL
+ json_writer.SetString(kKmsInstanceUrlField, kms_instance_url);
+ }
+ // 5. Write master key ID
+ json_writer.SetString(kMasterKeyIdField, master_key_id);
+ // 6. Write wrapped DEK
+ json_writer.SetString(kWrappedDataEncryptionKeyField, encoded_wrapped_dek);
+ // 7. Write isDoubleWrapped
+ json_writer.SetBool(kDoubleWrappingField, is_double_wrapped);
+ if (is_double_wrapped) {
+ // 8. In double wrapping mode, write KEK ID
+ json_writer.SetString(kKeyEncryptionKeyIdField, kek_id);
+ // 9. In double wrapping mode, write wrapped KEK
+ json_writer.SetString(kWrappedKeyEncryptionKeyField, encoded_wrapped_kek);
+ }
+
+ return json_writer.Serialize();
+}
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/key_material.h b/src/arrow/cpp/src/parquet/encryption/key_material.h
new file mode 100644
index 000000000..f20d23ea3
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/key_material.h
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+
+#include "parquet/platform.h"
+
+namespace arrow {
+namespace json {
+namespace internal {
+class ObjectParser;
+} // namespace internal
+} // namespace json
+} // namespace arrow
+
+namespace parquet {
+namespace encryption {
+
+// KeyMaterial class represents the "key material", keeping the information that allows
+// readers to recover an encryption key (see description of the KeyMetadata class). The
+// keytools package (PARQUET-1373) implements the "envelope encryption" pattern, in a
+// "single wrapping" or "double wrapping" mode. In the single wrapping mode, the key
+// material is generated by encrypting the "data encryption key" (DEK) by a "master key".
+// In the double wrapping mode, the key material is generated by encrypting the DEK by a
+// "key encryption key" (KEK), that in turn is encrypted by a "master key".
+//
+// Key material is kept in a flat json object, with the following fields:
+// 1. "keyMaterialType" - a String, with the type of key material. In the current
+// version, only one value is allowed - "PKMT1" (stands
+// for "parquet key management tools, version 1"). For external key material storage,
+// this field is written in both "key metadata" and "key material" jsons. For internal
+// key material storage, this field is written only once in the common json.
+// 2. "isFooterKey" - a boolean. If true, means that the material belongs to a file footer
+// key, and keeps additional information (such as
+// KMS instance ID and URL). If false, means that the material belongs to a column
+// key.
+// 3. "kmsInstanceID" - a String, with the KMS Instance ID. Written only in footer key
+// material.
+// 4. "kmsInstanceURL" - a String, with the KMS Instance URL. Written only in footer key
+// material.
+// 5. "masterKeyID" - a String, with the ID of the master key used to generate the
+// material.
+// 6. "wrappedDEK" - a String, with the wrapped DEK (base64 encoding).
+// 7. "doubleWrapping" - a boolean. If true, means that the material was generated in
+// double wrapping mode.
+// If false - in single wrapping mode.
+// 8. "keyEncryptionKeyID" - a String, with the ID of the KEK used to generate the
+// material. Written only in double wrapping mode.
+// 9. "wrappedKEK" - a String, with the wrapped KEK (base64 encoding). Written only in
+// double wrapping mode.
+class PARQUET_EXPORT KeyMaterial {
+ public:
+ // these fields are defined in a specification and should never be changed
+ static constexpr const char kKeyMaterialTypeField[] = "keyMaterialType";
+ static constexpr const char kKeyMaterialType1[] = "PKMT1";
+
+ static constexpr const char kFooterKeyIdInFile[] = "footerKey";
+ static constexpr const char kColumnKeyIdInFilePrefix[] = "columnKey";
+
+ static constexpr const char kIsFooterKeyField[] = "isFooterKey";
+ static constexpr const char kDoubleWrappingField[] = "doubleWrapping";
+ static constexpr const char kKmsInstanceIdField[] = "kmsInstanceID";
+ static constexpr const char kKmsInstanceUrlField[] = "kmsInstanceURL";
+ static constexpr const char kMasterKeyIdField[] = "masterKeyID";
+ static constexpr const char kWrappedDataEncryptionKeyField[] = "wrappedDEK";
+ static constexpr const char kKeyEncryptionKeyIdField[] = "keyEncryptionKeyID";
+ static constexpr const char kWrappedKeyEncryptionKeyField[] = "wrappedKEK";
+
+ public:
+ KeyMaterial() = default;
+
+ static KeyMaterial Parse(const std::string& key_material_string);
+
+ static KeyMaterial Parse(
+ const ::arrow::json::internal::ObjectParser* key_material_json);
+
+ /// This method returns a json string that will be stored either inside a parquet file
+ /// or in a key material store outside the parquet file.
+ static std::string SerializeToJson(bool is_footer_key,
+ const std::string& kms_instance_id,
+ const std::string& kms_instance_url,
+ const std::string& master_key_id,
+ bool is_double_wrapped, const std::string& kek_id,
+ const std::string& encoded_wrapped_kek,
+ const std::string& encoded_wrapped_dek,
+ bool is_internal_storage);
+
+ bool is_footer_key() const { return is_footer_key_; }
+ bool is_double_wrapped() const { return is_double_wrapped_; }
+ const std::string& master_key_id() const { return master_key_id_; }
+ const std::string& wrapped_dek() const { return encoded_wrapped_dek_; }
+ const std::string& kek_id() const { return kek_id_; }
+ const std::string& wrapped_kek() const { return encoded_wrapped_kek_; }
+ const std::string& kms_instance_id() const { return kms_instance_id_; }
+ const std::string& kms_instance_url() const { return kms_instance_url_; }
+
+ private:
+ KeyMaterial(bool is_footer_key, const std::string& kms_instance_id,
+ const std::string& kms_instance_url, const std::string& master_key_id,
+ bool is_double_wrapped, const std::string& kek_id,
+ const std::string& encoded_wrapped_kek,
+ const std::string& encoded_wrapped_dek);
+
+ bool is_footer_key_;
+ std::string kms_instance_id_;
+ std::string kms_instance_url_;
+ std::string master_key_id_;
+ bool is_double_wrapped_;
+ std::string kek_id_;
+ std::string encoded_wrapped_kek_;
+ std::string encoded_wrapped_dek_;
+};
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/key_metadata.cc b/src/arrow/cpp/src/parquet/encryption/key_metadata.cc
new file mode 100644
index 000000000..624626c89
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/key_metadata.cc
@@ -0,0 +1,89 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/json/object_parser.h"
+#include "arrow/json/object_writer.h"
+
+#include "parquet/encryption/key_metadata.h"
+#include "parquet/exception.h"
+
+using ::arrow::json::internal::ObjectParser;
+using ::arrow::json::internal::ObjectWriter;
+
+namespace parquet {
+namespace encryption {
+
+constexpr const char KeyMetadata::kKeyMaterialInternalStorageField[];
+constexpr const char KeyMetadata::kKeyReferenceField[];
+
+KeyMetadata::KeyMetadata(const std::string& key_reference)
+ : is_internal_storage_(false), key_material_or_reference_(key_reference) {}
+
+KeyMetadata::KeyMetadata(const KeyMaterial& key_material)
+ : is_internal_storage_(true), key_material_or_reference_(key_material) {}
+
+KeyMetadata KeyMetadata::Parse(const std::string& key_metadata) {
+ ObjectParser json_parser;
+ ::arrow::Status status = json_parser.Parse(key_metadata);
+ if (!status.ok()) {
+ throw ParquetException("Failed to parse key metadata " + key_metadata);
+ }
+
+ // 1. Extract "key material type", and make sure it is supported
+ std::string key_material_type;
+ PARQUET_ASSIGN_OR_THROW(key_material_type,
+ json_parser.GetString(KeyMaterial::kKeyMaterialTypeField));
+ if (key_material_type != KeyMaterial::kKeyMaterialType1) {
+ throw ParquetException("Wrong key material type: " + key_material_type + " vs " +
+ KeyMaterial::kKeyMaterialType1);
+ }
+
+ // 2. Check if "key material" is stored internally in Parquet file key metadata, or is
+ // stored externally
+ bool is_internal_storage;
+ PARQUET_ASSIGN_OR_THROW(is_internal_storage,
+ json_parser.GetBool(kKeyMaterialInternalStorageField));
+
+ if (is_internal_storage) {
+ // 3.1 "key material" is stored internally, inside "key metadata" - parse it
+ KeyMaterial key_material = KeyMaterial::Parse(&json_parser);
+ return KeyMetadata(key_material);
+ } else {
+ // 3.2 "key material" is stored externally. "key metadata" keeps a reference to it
+ std::string key_reference;
+ PARQUET_ASSIGN_OR_THROW(key_reference, json_parser.GetString(kKeyReferenceField));
+ return KeyMetadata(key_reference);
+ }
+}
+
+// For external material only. For internal material, create serialized KeyMaterial
+// directly
+std::string KeyMetadata::CreateSerializedForExternalMaterial(
+ const std::string& key_reference) {
+ ObjectWriter json_writer;
+
+ json_writer.SetString(KeyMaterial::kKeyMaterialTypeField,
+ KeyMaterial::kKeyMaterialType1);
+ json_writer.SetBool(kKeyMaterialInternalStorageField, false);
+
+ json_writer.SetString(kKeyReferenceField, key_reference);
+
+ return json_writer.Serialize();
+}
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/key_metadata.h b/src/arrow/cpp/src/parquet/encryption/key_metadata.h
new file mode 100644
index 000000000..2281b96e6
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/key_metadata.h
@@ -0,0 +1,94 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+
+#include "arrow/util/variant.h"
+
+#include "parquet/encryption/key_material.h"
+#include "parquet/exception.h"
+#include "parquet/platform.h"
+
+namespace parquet {
+namespace encryption {
+
+// Parquet encryption specification defines "key metadata" as an arbitrary byte array,
+// generated by file writers for each encryption key, and passed to the low level API for
+// storage in the file footer. The "key metadata" field is made available to file readers
+// to enable recovery of the key. This interface can be utilized for implementation
+// of any key management scheme.
+//
+// The keytools package (PARQUET-1373) implements one approach, of many possible, to key
+// management and to generation of the "key metadata" fields. This approach, based on the
+// "envelope encryption" pattern, allows integration with KMS servers. It keeps the actual
+// material, required to recover a key, in a "key material" object (see the KeyMaterial
+// class for details). This class is implemented to support version 1 of the parquet key
+// management tools specification.
+//
+// KeyMetadata writes (and reads) the "key metadata" field as a flat json object,
+// with the following fields:
+// 1. "keyMaterialType" - a String, with the type of key material.
+// 2. "internalStorage" - a boolean. If true, means that "key material" is kept inside the
+// "key metadata" field. If false, "key material" is kept externally (outside Parquet
+// files) - in this case, "key metadata" keeps a reference to the external "key material".
+// 3. "keyReference" - a String, with the reference to the external "key material".
+// Written only if internalStorage is false.
+//
+// If internalStorage is true, "key material" is a part of "key metadata", and the json
+// keeps additional fields, described in the KeyMaterial class.
+class PARQUET_EXPORT KeyMetadata {
+ public:
+ static constexpr const char kKeyMaterialInternalStorageField[] = "internalStorage";
+ static constexpr const char kKeyReferenceField[] = "keyReference";
+
+ /// key_metadata_bytes is the key metadata field stored in the parquet file,
+ /// in the serialized json object format.
+ static KeyMetadata Parse(const std::string& key_metadata_bytes);
+
+ static std::string CreateSerializedForExternalMaterial(
+ const std::string& key_reference);
+
+ bool key_material_stored_internally() const { return is_internal_storage_; }
+
+ const KeyMaterial& key_material() const {
+ if (!is_internal_storage_) {
+ throw ParquetException("key material is stored externally.");
+ }
+ return ::arrow::util::get<KeyMaterial>(key_material_or_reference_);
+ }
+
+ const std::string& key_reference() const {
+ if (is_internal_storage_) {
+ throw ParquetException("key material is stored internally.");
+ }
+ return ::arrow::util::get<std::string>(key_material_or_reference_);
+ }
+
+ private:
+ explicit KeyMetadata(const KeyMaterial& key_material);
+ explicit KeyMetadata(const std::string& key_reference);
+
+ bool is_internal_storage_;
+ /// If is_internal_storage_ is true, KeyMaterial is set,
+ /// else a string referencing to an outside "key material" is set.
+ ::arrow::util::Variant<KeyMaterial, std::string> key_material_or_reference_;
+};
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/key_metadata_test.cc b/src/arrow/cpp/src/parquet/encryption/key_metadata_test.cc
new file mode 100644
index 000000000..3f891ef26
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/key_metadata_test.cc
@@ -0,0 +1,77 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <string>
+
+#include <gtest/gtest.h>
+
+#include "parquet/encryption/key_material.h"
+#include "parquet/encryption/key_metadata.h"
+
+namespace parquet {
+namespace encryption {
+namespace test {
+
+TEST(KeyMetadataTest, InternalMaterialStorage) {
+ bool is_footer_key = true;
+ std::string kms_instance_id = "DEFAULT";
+ std::string kms_instance_url = "DEFAULT";
+ std::string master_key_id = "kf";
+ bool double_wrapped = true;
+ std::string kek_id = "FANqyCuxfU1c526Uzb+MTA==";
+ std::string encoded_wrapped_kek =
+ "{\"masterKeyVersion\":\"NO_VERSION\",\"encryptedKey\":\"LAAAAGaoSfMV1YH/"
+ "oxwG2ES8Phva8wimEZcT7zi5bpuK5Jdvw9/zJuqDeIPGGFXd\"}";
+ std::string encoded_wrapped_dek =
+ "LAAAAA3RcNYT1Rxb/gqhA1KvBgHcjvEppST9+cV3bU5nLmtaZHJhsZakR20qRErX";
+ bool internal_storage = true;
+ std::string json = KeyMaterial::SerializeToJson(
+ is_footer_key, kms_instance_id, kms_instance_url, master_key_id, double_wrapped,
+ kek_id, encoded_wrapped_kek, encoded_wrapped_dek, internal_storage);
+
+ KeyMetadata key_metadata = KeyMetadata::Parse(json);
+
+ ASSERT_EQ(key_metadata.key_material_stored_internally(), true);
+
+ const KeyMaterial& key_material = key_metadata.key_material();
+ ASSERT_EQ(key_material.is_footer_key(), is_footer_key);
+ ASSERT_EQ(key_material.kms_instance_id(), kms_instance_id);
+ ASSERT_EQ(key_material.kms_instance_url(), kms_instance_url);
+ ASSERT_EQ(key_material.master_key_id(), master_key_id);
+ ASSERT_EQ(key_material.is_double_wrapped(), double_wrapped);
+ ASSERT_EQ(key_material.kek_id(), kek_id);
+ ASSERT_EQ(key_material.wrapped_kek(), encoded_wrapped_kek);
+ ASSERT_EQ(key_material.wrapped_dek(), encoded_wrapped_dek);
+}
+
+TEST(KeyMetadataTest, ExternalMaterialStorage) {
+ const std::string key_reference = "X44KIHSxDSFAS5q2223";
+
+ // generate key_metadata string in parquet file
+ std::string key_metadata_str =
+ KeyMetadata::CreateSerializedForExternalMaterial(key_reference);
+
+ // parse key_metadata string back
+ KeyMetadata key_metadata = KeyMetadata::Parse(key_metadata_str);
+
+ ASSERT_EQ(key_metadata.key_material_stored_internally(), false);
+ ASSERT_EQ(key_metadata.key_reference(), key_reference);
+}
+
+} // namespace test
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/key_toolkit.cc b/src/arrow/cpp/src/parquet/encryption/key_toolkit.cc
new file mode 100644
index 000000000..033d34891
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/key_toolkit.cc
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/encryption/key_toolkit.h"
+
+namespace parquet {
+namespace encryption {
+
+std::shared_ptr<KmsClient> KeyToolkit::GetKmsClient(
+ const KmsConnectionConfig& kms_connection_config, double cache_entry_lifetime_ms) {
+ if (kms_client_factory_ == NULL) {
+ throw ParquetException("No KmsClientFactory is registered.");
+ }
+ auto kms_client_per_kms_instance_cache =
+ kms_client_cache_per_token().GetOrCreateInternalCache(
+ kms_connection_config.key_access_token(), cache_entry_lifetime_ms);
+
+ return kms_client_per_kms_instance_cache->GetOrInsert(
+ kms_connection_config.kms_instance_id, [this, kms_connection_config]() {
+ return this->kms_client_factory_->CreateKmsClient(kms_connection_config);
+ });
+}
+
+// Flush any caches that are tied to the (compromised) access_token
+void KeyToolkit::RemoveCacheEntriesForToken(const std::string& access_token) {
+ kms_client_cache_per_token().Remove(access_token);
+ kek_write_cache_per_token().Remove(access_token);
+ kek_read_cache_per_token().Remove(access_token);
+}
+
+void KeyToolkit::RemoveCacheEntriesForAllTokens() {
+ kms_client_cache_per_token().Clear();
+ kek_write_cache_per_token().Clear();
+ kek_read_cache_per_token().Clear();
+}
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/key_toolkit.h b/src/arrow/cpp/src/parquet/encryption/key_toolkit.h
new file mode 100644
index 000000000..92b8dd302
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/key_toolkit.h
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "parquet/encryption/key_encryption_key.h"
+#include "parquet/encryption/kms_client.h"
+#include "parquet/encryption/kms_client_factory.h"
+#include "parquet/encryption/two_level_cache_with_expiration.h"
+#include "parquet/platform.h"
+
+namespace parquet {
+namespace encryption {
+
+// KeyToolkit is a utility that keeps various tools for key management (such as key
+// rotation, kms client instantiation, cache control, etc), plus a number of auxiliary
+// classes for internal use.
+class PARQUET_EXPORT KeyToolkit {
+ public:
+ /// KMS client two level cache: token -> KMSInstanceId -> KmsClient
+ TwoLevelCacheWithExpiration<std::shared_ptr<KmsClient>>& kms_client_cache_per_token() {
+ return kms_client_cache_;
+ }
+ /// Key encryption key two level cache for wrapping: token -> MasterEncryptionKeyId ->
+ /// KeyEncryptionKey
+ TwoLevelCacheWithExpiration<KeyEncryptionKey>& kek_write_cache_per_token() {
+ return key_encryption_key_write_cache_;
+ }
+
+ /// Key encryption key two level cache for unwrapping: token -> KeyEncryptionKeyId ->
+ /// KeyEncryptionKeyBytes
+ TwoLevelCacheWithExpiration<std::string>& kek_read_cache_per_token() {
+ return key_encryption_key_read_cache_;
+ }
+
+ std::shared_ptr<KmsClient> GetKmsClient(
+ const KmsConnectionConfig& kms_connection_config, double cache_entry_lifetime_ms);
+
+ /// Flush any caches that are tied to the (compromised) access_token
+ void RemoveCacheEntriesForToken(const std::string& access_token);
+
+ void RemoveCacheEntriesForAllTokens();
+
+ void RegisterKmsClientFactory(std::shared_ptr<KmsClientFactory> kms_client_factory) {
+ if (kms_client_factory_ != NULL) {
+ throw ParquetException("KMS client factory has already been registered.");
+ }
+ kms_client_factory_ = kms_client_factory;
+ }
+
+ private:
+ TwoLevelCacheWithExpiration<std::shared_ptr<KmsClient>> kms_client_cache_;
+ TwoLevelCacheWithExpiration<KeyEncryptionKey> key_encryption_key_write_cache_;
+ TwoLevelCacheWithExpiration<std::string> key_encryption_key_read_cache_;
+ std::shared_ptr<KmsClientFactory> kms_client_factory_;
+};
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/key_toolkit_internal.cc b/src/arrow/cpp/src/parquet/encryption/key_toolkit_internal.cc
new file mode 100644
index 000000000..6cfd5381f
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/key_toolkit_internal.cc
@@ -0,0 +1,80 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/base64.h"
+
+#include "parquet/encryption/encryption_internal.h"
+#include "parquet/encryption/key_toolkit_internal.h"
+
+namespace parquet {
+namespace encryption {
+namespace internal {
+
+// Acceptable key lengths in number of bits, used to validate the data key lengths
+// configured by users and the master key lengths fetched from KMS server.
+static constexpr const int32_t kAcceptableDataKeyLengths[] = {128, 192, 256};
+
+std::string EncryptKeyLocally(const std::string& key_bytes, const std::string& master_key,
+ const std::string& aad) {
+ AesEncryptor key_encryptor(ParquetCipher::AES_GCM_V1,
+ static_cast<int>(master_key.size()), false);
+
+ int encrypted_key_len =
+ static_cast<int>(key_bytes.size()) + key_encryptor.CiphertextSizeDelta();
+ std::string encrypted_key(encrypted_key_len, '\0');
+ encrypted_key_len = key_encryptor.Encrypt(
+ reinterpret_cast<const uint8_t*>(key_bytes.data()),
+ static_cast<int>(key_bytes.size()),
+ reinterpret_cast<const uint8_t*>(master_key.data()),
+ static_cast<int>(master_key.size()), reinterpret_cast<const uint8_t*>(aad.data()),
+ static_cast<int>(aad.size()), reinterpret_cast<uint8_t*>(&encrypted_key[0]));
+
+ return ::arrow::util::base64_encode(
+ ::arrow::util::string_view(encrypted_key.data(), encrypted_key_len));
+}
+
+std::string DecryptKeyLocally(const std::string& encoded_encrypted_key,
+ const std::string& master_key, const std::string& aad) {
+ std::string encrypted_key = ::arrow::util::base64_decode(encoded_encrypted_key);
+
+ AesDecryptor key_decryptor(ParquetCipher::AES_GCM_V1,
+ static_cast<int>(master_key.size()), false);
+
+ int decrypted_key_len =
+ static_cast<int>(encrypted_key.size()) - key_decryptor.CiphertextSizeDelta();
+ std::string decrypted_key(decrypted_key_len, '\0');
+
+ decrypted_key_len = key_decryptor.Decrypt(
+ reinterpret_cast<const uint8_t*>(encrypted_key.data()),
+ static_cast<int>(encrypted_key.size()),
+ reinterpret_cast<const uint8_t*>(master_key.data()),
+ static_cast<int>(master_key.size()), reinterpret_cast<const uint8_t*>(aad.data()),
+ static_cast<int>(aad.size()), reinterpret_cast<uint8_t*>(&decrypted_key[0]));
+
+ return decrypted_key;
+}
+
+bool ValidateKeyLength(int32_t key_length_bits) {
+ int32_t* found_key_length = std::find(
+ const_cast<int32_t*>(kAcceptableDataKeyLengths),
+ const_cast<int32_t*>(std::end(kAcceptableDataKeyLengths)), key_length_bits);
+ return found_key_length != std::end(kAcceptableDataKeyLengths);
+}
+
+} // namespace internal
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/key_toolkit_internal.h b/src/arrow/cpp/src/parquet/encryption/key_toolkit_internal.h
new file mode 100644
index 000000000..af5f1f08a
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/key_toolkit_internal.h
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+
+#include "parquet/platform.h"
+
+namespace parquet {
+namespace encryption {
+namespace internal {
+
+// "data encryption key" and "master key identifier" are paired together as output when
+// parsing from "key material"
+class PARQUET_EXPORT KeyWithMasterId {
+ public:
+ KeyWithMasterId(std::string key_bytes, std::string master_id)
+ : key_bytes_(std::move(key_bytes)), master_id_(std::move(master_id)) {}
+
+ const std::string& data_key() const { return key_bytes_; }
+ const std::string& master_id() const { return master_id_; }
+
+ private:
+ const std::string key_bytes_;
+ const std::string master_id_;
+};
+
+/// Encrypts "key" with "master_key", using AES-GCM and the "aad"
+PARQUET_EXPORT
+std::string EncryptKeyLocally(const std::string& key, const std::string& master_key,
+ const std::string& aad);
+
+/// Decrypts encrypted key with "master_key", using AES-GCM and the "aad"
+PARQUET_EXPORT
+std::string DecryptKeyLocally(const std::string& encoded_encrypted_key,
+ const std::string& master_key, const std::string& aad);
+
+PARQUET_EXPORT
+bool ValidateKeyLength(int32_t key_length_bits);
+
+} // namespace internal
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/key_wrapping_test.cc b/src/arrow/cpp/src/parquet/encryption/key_wrapping_test.cc
new file mode 100644
index 000000000..a745fe34c
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/key_wrapping_test.cc
@@ -0,0 +1,103 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include "parquet/encryption/file_key_unwrapper.h"
+#include "parquet/encryption/file_key_wrapper.h"
+#include "parquet/encryption/test_encryption_util.h"
+#include "parquet/encryption/test_in_memory_kms.h"
+
+namespace parquet {
+namespace encryption {
+namespace test {
+
+class KeyWrappingTest : public ::testing::Test {
+ public:
+ void SetUp() {
+ key_list_ = BuildKeyMap(kColumnMasterKeyIds, kColumnMasterKeys, kFooterMasterKeyId,
+ kFooterMasterKey);
+ }
+
+ protected:
+ void WrapThenUnwrap(std::shared_ptr<FileKeyMaterialStore> key_material_store,
+ bool double_wrapping, bool is_wrap_locally) {
+ double cache_entry_lifetime_seconds = 600;
+
+ KeyToolkit key_toolkit;
+ key_toolkit.RegisterKmsClientFactory(
+ std::make_shared<TestOnlyInMemoryKmsClientFactory>(is_wrap_locally, key_list_));
+
+ FileKeyWrapper wrapper(&key_toolkit, kms_connection_config_, key_material_store,
+ cache_entry_lifetime_seconds, double_wrapping);
+
+ std::string key_metadata_json_footer =
+ wrapper.GetEncryptionKeyMetadata(kFooterEncryptionKey, kFooterMasterKeyId, true);
+ std::string key_metadata_json_column = wrapper.GetEncryptionKeyMetadata(
+ kColumnEncryptionKey1, kColumnMasterKeyIds[0], false);
+
+ FileKeyUnwrapper unwrapper(&key_toolkit, kms_connection_config_,
+ cache_entry_lifetime_seconds);
+ std::string footer_key = unwrapper.GetKey(key_metadata_json_footer);
+ ASSERT_EQ(footer_key, kFooterEncryptionKey);
+
+ std::string column_key = unwrapper.GetKey(key_metadata_json_column);
+ ASSERT_EQ(column_key, kColumnEncryptionKey1);
+ }
+
+ // TODO: this method will be removed when material external storage is supported
+ void WrapThenUnwrapWithUnsupportedExternalStorage(bool double_wrapping,
+ bool is_wrap_locally) {
+ double cache_entry_lifetime_seconds = 600;
+
+ KeyToolkit key_toolkit;
+ key_toolkit.RegisterKmsClientFactory(
+ std::make_shared<TestOnlyInMemoryKmsClientFactory>(is_wrap_locally, key_list_));
+
+ std::shared_ptr<FileKeyMaterialStore> unsupported_material_store =
+ std::make_shared<FileKeyMaterialStore>();
+
+ FileKeyWrapper wrapper(&key_toolkit, kms_connection_config_,
+ unsupported_material_store, cache_entry_lifetime_seconds,
+ double_wrapping);
+
+ EXPECT_THROW(
+ wrapper.GetEncryptionKeyMetadata(kFooterEncryptionKey, kFooterMasterKeyId, true),
+ ParquetException);
+ }
+
+ std::unordered_map<std::string, std::string> key_list_;
+ KmsConnectionConfig kms_connection_config_;
+};
+
+TEST_F(KeyWrappingTest, InternalMaterialStorage) {
+ // key_material_store = NULL indicates that "key material" is stored inside parquet
+ // file.
+ this->WrapThenUnwrap(NULL, true, true);
+ this->WrapThenUnwrap(NULL, true, false);
+ this->WrapThenUnwrap(NULL, false, true);
+ this->WrapThenUnwrap(NULL, false, false);
+}
+
+// TODO: this test should be updated when material external storage is supported
+TEST_F(KeyWrappingTest, ExternalMaterialStorage) {
+ this->WrapThenUnwrapWithUnsupportedExternalStorage(true, true);
+}
+
+} // namespace test
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/kms_client.cc b/src/arrow/cpp/src/parquet/encryption/kms_client.cc
new file mode 100644
index 000000000..b9c720272
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/kms_client.cc
@@ -0,0 +1,44 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/encryption/kms_client.h"
+
+namespace parquet {
+namespace encryption {
+
+constexpr const char KmsClient::kKmsInstanceIdDefault[];
+constexpr const char KmsClient::kKmsInstanceUrlDefault[];
+constexpr const char KmsClient::kKeyAccessTokenDefault[];
+
+KmsConnectionConfig::KmsConnectionConfig()
+ : refreshable_key_access_token(
+ std::make_shared<KeyAccessToken>(KmsClient::kKeyAccessTokenDefault)) {}
+
+void KmsConnectionConfig::SetDefaultIfEmpty() {
+ if (kms_instance_id.empty()) {
+ kms_instance_id = KmsClient::kKmsInstanceIdDefault;
+ }
+ if (kms_instance_url.empty()) {
+ kms_instance_url = KmsClient::kKmsInstanceUrlDefault;
+ }
+ if (refreshable_key_access_token == NULL) {
+ refreshable_key_access_token = std::make_shared<KeyAccessToken>();
+ }
+}
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/kms_client.h b/src/arrow/cpp/src/parquet/encryption/kms_client.h
new file mode 100644
index 000000000..5ffa604ff
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/kms_client.h
@@ -0,0 +1,95 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "arrow/util/mutex.h"
+
+#include "parquet/exception.h"
+#include "parquet/platform.h"
+
+namespace parquet {
+namespace encryption {
+
+/// This class wraps the key access token of a KMS server. If your token changes over
+/// time, you should keep the reference to the KeyAccessToken object and call Refresh()
+/// method every time you have a new token.
+class PARQUET_EXPORT KeyAccessToken {
+ public:
+ KeyAccessToken() = default;
+
+ explicit KeyAccessToken(const std::string value) : value_(value) {}
+
+ void Refresh(const std::string& new_value) {
+ auto lock = mutex_.Lock();
+ value_ = new_value;
+ }
+
+ const std::string& value() const {
+ auto lock = mutex_.Lock();
+ return value_;
+ }
+
+ private:
+ std::string value_;
+ mutable ::arrow::util::Mutex mutex_;
+};
+
+struct PARQUET_EXPORT KmsConnectionConfig {
+ std::string kms_instance_id;
+ std::string kms_instance_url;
+ /// If the access token is changed in the future, you should keep a reference to
+ /// this object and call Refresh() on it whenever there is a new access token.
+ std::shared_ptr<KeyAccessToken> refreshable_key_access_token;
+ std::unordered_map<std::string, std::string> custom_kms_conf;
+
+ KmsConnectionConfig();
+
+ const std::string& key_access_token() const {
+ if (refreshable_key_access_token == NULL ||
+ refreshable_key_access_token->value().empty()) {
+ throw ParquetException("key access token is not set!");
+ }
+ return refreshable_key_access_token->value();
+ }
+
+ void SetDefaultIfEmpty();
+};
+
+class PARQUET_EXPORT KmsClient {
+ public:
+ static constexpr const char kKmsInstanceIdDefault[] = "DEFAULT";
+ static constexpr const char kKmsInstanceUrlDefault[] = "DEFAULT";
+ static constexpr const char kKeyAccessTokenDefault[] = "DEFAULT";
+
+ /// Wraps a key - encrypts it with the master key, encodes the result
+ /// and potentially adds a KMS-specific metadata.
+ virtual std::string WrapKey(const std::string& key_bytes,
+ const std::string& master_key_identifier) = 0;
+
+ /// Decrypts (unwraps) a key with the master key.
+ virtual std::string UnwrapKey(const std::string& wrapped_key,
+ const std::string& master_key_identifier) = 0;
+ virtual ~KmsClient() {}
+};
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/kms_client_factory.h b/src/arrow/cpp/src/parquet/encryption/kms_client_factory.h
new file mode 100644
index 000000000..eac8dfc5d
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/kms_client_factory.h
@@ -0,0 +1,40 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "parquet/encryption/kms_client.h"
+#include "parquet/platform.h"
+
+namespace parquet {
+namespace encryption {
+
+class PARQUET_EXPORT KmsClientFactory {
+ public:
+ explicit KmsClientFactory(bool wrap_locally = false) : wrap_locally_(wrap_locally) {}
+
+ virtual ~KmsClientFactory() = default;
+
+ virtual std::shared_ptr<KmsClient> CreateKmsClient(
+ const KmsConnectionConfig& kms_connection_config) = 0;
+
+ protected:
+ bool wrap_locally_;
+};
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/local_wrap_kms_client.cc b/src/arrow/cpp/src/parquet/encryption/local_wrap_kms_client.cc
new file mode 100644
index 000000000..1b89dc57d
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/local_wrap_kms_client.cc
@@ -0,0 +1,116 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/json/object_parser.h"
+#include "arrow/json/object_writer.h"
+
+#include "parquet/encryption/key_toolkit_internal.h"
+#include "parquet/encryption/local_wrap_kms_client.h"
+#include "parquet/exception.h"
+
+using ::arrow::json::internal::ObjectParser;
+using ::arrow::json::internal::ObjectWriter;
+
+namespace parquet {
+namespace encryption {
+
+constexpr const char LocalWrapKmsClient::kLocalWrapNoKeyVersion[];
+
+constexpr const char LocalWrapKmsClient::LocalKeyWrap::kLocalWrapKeyVersionField[];
+constexpr const char LocalWrapKmsClient::LocalKeyWrap::kLocalWrapEncryptedKeyField[];
+
+LocalWrapKmsClient::LocalKeyWrap::LocalKeyWrap(std::string master_key_version,
+ std::string encrypted_encoded_key)
+ : encrypted_encoded_key_(std::move(encrypted_encoded_key)),
+ master_key_version_(std::move(master_key_version)) {}
+
+std::string LocalWrapKmsClient::LocalKeyWrap::CreateSerialized(
+ const std::string& encrypted_encoded_key) {
+ ObjectWriter json_writer;
+
+ json_writer.SetString(kLocalWrapKeyVersionField, kLocalWrapNoKeyVersion);
+ json_writer.SetString(kLocalWrapEncryptedKeyField, encrypted_encoded_key);
+
+ return json_writer.Serialize();
+}
+
+LocalWrapKmsClient::LocalKeyWrap LocalWrapKmsClient::LocalKeyWrap::Parse(
+ const std::string& wrapped_key) {
+ ObjectParser json_parser;
+ auto status = json_parser.Parse(wrapped_key);
+ if (!status.ok()) {
+ throw ParquetException("Failed to parse local key wrap json " + wrapped_key);
+ }
+ PARQUET_ASSIGN_OR_THROW(const auto master_key_version,
+ json_parser.GetString(kLocalWrapKeyVersionField));
+
+ PARQUET_ASSIGN_OR_THROW(const auto encrypted_encoded_key,
+ json_parser.GetString(kLocalWrapEncryptedKeyField));
+
+ return LocalWrapKmsClient::LocalKeyWrap(std::move(master_key_version),
+ std::move(encrypted_encoded_key));
+}
+
+LocalWrapKmsClient::LocalWrapKmsClient(const KmsConnectionConfig& kms_connection_config)
+ : kms_connection_config_(kms_connection_config) {
+ master_key_cache_.Clear();
+}
+
+std::string LocalWrapKmsClient::WrapKey(const std::string& key_bytes,
+ const std::string& master_key_identifier) {
+ const auto master_key = master_key_cache_.GetOrInsert(
+ master_key_identifier, [this, master_key_identifier]() -> std::string {
+ return this->GetKeyFromServer(master_key_identifier);
+ });
+ const auto& aad = master_key_identifier;
+
+ const auto encrypted_encoded_key =
+ internal::EncryptKeyLocally(key_bytes, master_key, aad);
+ return LocalKeyWrap::CreateSerialized(encrypted_encoded_key);
+}
+
+std::string LocalWrapKmsClient::UnwrapKey(const std::string& wrapped_key,
+ const std::string& master_key_identifier) {
+ LocalKeyWrap key_wrap = LocalKeyWrap::Parse(wrapped_key);
+ const std::string& master_key_version = key_wrap.master_key_version();
+ if (kLocalWrapNoKeyVersion != master_key_version) {
+ throw ParquetException("Master key versions are not supported for local wrapping: " +
+ master_key_version);
+ }
+ const std::string& encrypted_encoded_key = key_wrap.encrypted_encoded_key();
+ const std::string master_key = master_key_cache_.GetOrInsert(
+ master_key_identifier, [this, master_key_identifier]() -> std::string {
+ return this->GetKeyFromServer(master_key_identifier);
+ });
+ const std::string& aad = master_key_identifier;
+
+ return internal::DecryptKeyLocally(encrypted_encoded_key, master_key, aad);
+}
+
+std::string LocalWrapKmsClient::GetKeyFromServer(const std::string& key_identifier) {
+ std::string master_key = GetMasterKeyFromServer(key_identifier);
+ int32_t key_length_bits = static_cast<int32_t>(master_key.size() * 8);
+ if (!internal::ValidateKeyLength(key_length_bits)) {
+ std::ostringstream ss;
+ ss << "Wrong master key length : " << key_length_bits;
+ throw ParquetException(ss.str());
+ }
+ return master_key;
+}
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/local_wrap_kms_client.h b/src/arrow/cpp/src/parquet/encryption/local_wrap_kms_client.h
new file mode 100644
index 000000000..65cf8f42c
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/local_wrap_kms_client.h
@@ -0,0 +1,96 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <unordered_map>
+#include <vector>
+
+#include "arrow/util/concurrent_map.h"
+
+#include "parquet/encryption/kms_client.h"
+#include "parquet/platform.h"
+
+namespace parquet {
+namespace encryption {
+
+/// This class supports local wrapping mode, master keys will be fetched from the KMS
+/// server and used to encrypt other keys (data encryption keys or key encryption keys).
+class PARQUET_EXPORT LocalWrapKmsClient : public KmsClient {
+ public:
+ static constexpr const char kLocalWrapNoKeyVersion[] = "NO_VERSION";
+
+ explicit LocalWrapKmsClient(const KmsConnectionConfig& kms_connection_config);
+
+ std::string WrapKey(const std::string& key_bytes,
+ const std::string& master_key_identifier) override;
+
+ std::string UnwrapKey(const std::string& wrapped_key,
+ const std::string& master_key_identifier) override;
+
+ protected:
+ /// Get master key from the remote KMS server.
+ /// Note: this function might be called by multiple threads
+ virtual std::string GetMasterKeyFromServer(
+ const std::string& master_key_identifier) = 0;
+
+ private:
+ /// KMS systems wrap keys by encrypting them by master keys, and attaching additional
+ /// information (such as the version number of the masker key) to the result of
+ /// encryption. The master key version is required in key rotation. Currently, the
+ /// local wrapping mode does not support key rotation (because not all KMS systems allow
+ /// to fetch a master key by its ID and version number). Still, the local wrapping mode
+ /// adds a placeholder for the master key version, that will enable support for key
+ /// rotation in this mode in the future, with appropriate KMS systems. This will also
+ /// enable backward compatibility, where future readers will be able to extract master
+ /// key version in the files written by the current code.
+ ///
+ /// LocalKeyWrap class writes (and reads) the "key wrap" as a flat json with the
+ /// following fields:
+ /// 1. "masterKeyVersion" - a String, with the master key version. In the current
+ /// version, only one value is allowed - "NO_VERSION".
+ /// 2. "encryptedKey" - a String, with the key encrypted by the master key
+ /// (base64-encoded).
+ class LocalKeyWrap {
+ public:
+ static constexpr const char kLocalWrapKeyVersionField[] = "masterKeyVersion";
+ static constexpr const char kLocalWrapEncryptedKeyField[] = "encryptedKey";
+
+ LocalKeyWrap(std::string master_key_version, std::string encrypted_encoded_key);
+
+ static std::string CreateSerialized(const std::string& encrypted_encoded_key);
+
+ static LocalKeyWrap Parse(const std::string& wrapped_key);
+
+ const std::string& master_key_version() const { return master_key_version_; }
+
+ const std::string& encrypted_encoded_key() const { return encrypted_encoded_key_; }
+
+ private:
+ std::string encrypted_encoded_key_;
+ std::string master_key_version_;
+ };
+
+ std::string GetKeyFromServer(const std::string& key_identifier);
+
+ protected:
+ KmsConnectionConfig kms_connection_config_;
+ ::arrow::util::ConcurrentMap<std::string, std::string> master_key_cache_;
+};
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/properties_test.cc b/src/arrow/cpp/src/parquet/encryption/properties_test.cc
new file mode 100644
index 000000000..0eb5cba20
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/properties_test.cc
@@ -0,0 +1,276 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <string>
+
+#include "parquet/encryption/encryption.h"
+#include "parquet/encryption/test_encryption_util.h"
+
+namespace parquet {
+namespace encryption {
+namespace test {
+
+TEST(TestColumnEncryptionProperties, ColumnEncryptedWithOwnKey) {
+ std::string column_path_1 = "column_1";
+ ColumnEncryptionProperties::Builder column_builder_1(column_path_1);
+ column_builder_1.key(kColumnEncryptionKey1);
+ column_builder_1.key_id("kc1");
+ std::shared_ptr<ColumnEncryptionProperties> column_props_1 = column_builder_1.build();
+
+ ASSERT_EQ(column_path_1, column_props_1->column_path());
+ ASSERT_EQ(true, column_props_1->is_encrypted());
+ ASSERT_EQ(false, column_props_1->is_encrypted_with_footer_key());
+ ASSERT_EQ(kColumnEncryptionKey1, column_props_1->key());
+ ASSERT_EQ("kc1", column_props_1->key_metadata());
+}
+
+TEST(TestColumnEncryptionProperties, ColumnEncryptedWithFooterKey) {
+ std::string column_path_1 = "column_1";
+ ColumnEncryptionProperties::Builder column_builder_1(column_path_1);
+ std::shared_ptr<ColumnEncryptionProperties> column_props_1 = column_builder_1.build();
+
+ ASSERT_EQ(column_path_1, column_props_1->column_path());
+ ASSERT_EQ(true, column_props_1->is_encrypted());
+ ASSERT_EQ(true, column_props_1->is_encrypted_with_footer_key());
+}
+
+// Encrypt all columns and the footer with the same key.
+// (uniform encryption)
+TEST(TestEncryptionProperties, UniformEncryption) {
+ FileEncryptionProperties::Builder builder(kFooterEncryptionKey);
+ builder.footer_key_metadata("kf");
+ std::shared_ptr<FileEncryptionProperties> props = builder.build();
+
+ ASSERT_EQ(true, props->encrypted_footer());
+ ASSERT_EQ(kDefaultEncryptionAlgorithm, props->algorithm().algorithm);
+ ASSERT_EQ(kFooterEncryptionKey, props->footer_key());
+ ASSERT_EQ("kf", props->footer_key_metadata());
+
+ std::shared_ptr<parquet::schema::ColumnPath> column_path =
+ parquet::schema::ColumnPath::FromDotString("a_column");
+ std::shared_ptr<ColumnEncryptionProperties> out_col_props =
+ props->column_encryption_properties(column_path->ToDotString());
+
+ ASSERT_EQ(true, out_col_props->is_encrypted());
+ ASSERT_EQ(true, out_col_props->is_encrypted_with_footer_key());
+}
+
+// Encrypt two columns with their own keys and the same key for
+// the footer and other columns
+TEST(TestEncryptionProperties, EncryptFooterAndTwoColumns) {
+ std::shared_ptr<parquet::schema::ColumnPath> column_path_1 =
+ parquet::schema::ColumnPath::FromDotString("column_1");
+ ColumnEncryptionProperties::Builder column_builder_1(column_path_1->ToDotString());
+ column_builder_1.key(kColumnEncryptionKey1);
+ column_builder_1.key_id("kc1");
+
+ std::shared_ptr<parquet::schema::ColumnPath> column_path_2 =
+ parquet::schema::ColumnPath::FromDotString("column_2");
+ ColumnEncryptionProperties::Builder column_builder_2(column_path_2->ToDotString());
+ column_builder_2.key(kColumnEncryptionKey2);
+ column_builder_2.key_id("kc2");
+
+ std::map<std::string, std::shared_ptr<ColumnEncryptionProperties>> encrypted_columns;
+ encrypted_columns[column_path_1->ToDotString()] = column_builder_1.build();
+ encrypted_columns[column_path_2->ToDotString()] = column_builder_2.build();
+
+ FileEncryptionProperties::Builder builder(kFooterEncryptionKey);
+ builder.footer_key_metadata("kf");
+ builder.encrypted_columns(encrypted_columns);
+ std::shared_ptr<FileEncryptionProperties> props = builder.build();
+
+ ASSERT_EQ(true, props->encrypted_footer());
+ ASSERT_EQ(kDefaultEncryptionAlgorithm, props->algorithm().algorithm);
+ ASSERT_EQ(kFooterEncryptionKey, props->footer_key());
+
+ std::shared_ptr<ColumnEncryptionProperties> out_col_props_1 =
+ props->column_encryption_properties(column_path_1->ToDotString());
+
+ ASSERT_EQ(column_path_1->ToDotString(), out_col_props_1->column_path());
+ ASSERT_EQ(true, out_col_props_1->is_encrypted());
+ ASSERT_EQ(false, out_col_props_1->is_encrypted_with_footer_key());
+ ASSERT_EQ(kColumnEncryptionKey1, out_col_props_1->key());
+ ASSERT_EQ("kc1", out_col_props_1->key_metadata());
+
+ std::shared_ptr<ColumnEncryptionProperties> out_col_props_2 =
+ props->column_encryption_properties(column_path_2->ToDotString());
+
+ ASSERT_EQ(column_path_2->ToDotString(), out_col_props_2->column_path());
+ ASSERT_EQ(true, out_col_props_2->is_encrypted());
+ ASSERT_EQ(false, out_col_props_2->is_encrypted_with_footer_key());
+ ASSERT_EQ(kColumnEncryptionKey2, out_col_props_2->key());
+ ASSERT_EQ("kc2", out_col_props_2->key_metadata());
+
+ std::shared_ptr<parquet::schema::ColumnPath> column_path_3 =
+ parquet::schema::ColumnPath::FromDotString("column_3");
+ std::shared_ptr<ColumnEncryptionProperties> out_col_props_3 =
+ props->column_encryption_properties(column_path_3->ToDotString());
+
+ ASSERT_EQ(NULLPTR, out_col_props_3);
+}
+
+// Encryption configuration 3: Encrypt two columns, don’t encrypt footer.
+// (plaintext footer mode, readable by legacy readers)
+TEST(TestEncryptionProperties, EncryptTwoColumnsNotFooter) {
+ std::shared_ptr<parquet::schema::ColumnPath> column_path_1 =
+ parquet::schema::ColumnPath::FromDotString("column_1");
+ ColumnEncryptionProperties::Builder column_builder_1(column_path_1);
+ column_builder_1.key(kColumnEncryptionKey1);
+ column_builder_1.key_id("kc1");
+
+ std::shared_ptr<parquet::schema::ColumnPath> column_path_2 =
+ parquet::schema::ColumnPath::FromDotString("column_2");
+ ColumnEncryptionProperties::Builder column_builder_2(column_path_2);
+ column_builder_2.key(kColumnEncryptionKey2);
+ column_builder_2.key_id("kc2");
+
+ std::map<std::string, std::shared_ptr<ColumnEncryptionProperties>> encrypted_columns;
+ encrypted_columns[column_path_1->ToDotString()] = column_builder_1.build();
+ encrypted_columns[column_path_2->ToDotString()] = column_builder_2.build();
+
+ FileEncryptionProperties::Builder builder(kFooterEncryptionKey);
+ builder.footer_key_metadata("kf");
+ builder.set_plaintext_footer();
+ builder.encrypted_columns(encrypted_columns);
+ std::shared_ptr<FileEncryptionProperties> props = builder.build();
+
+ ASSERT_EQ(false, props->encrypted_footer());
+ ASSERT_EQ(kDefaultEncryptionAlgorithm, props->algorithm().algorithm);
+ ASSERT_EQ(kFooterEncryptionKey, props->footer_key());
+
+ std::shared_ptr<ColumnEncryptionProperties> out_col_props_1 =
+ props->column_encryption_properties(column_path_1->ToDotString());
+
+ ASSERT_EQ(column_path_1->ToDotString(), out_col_props_1->column_path());
+ ASSERT_EQ(true, out_col_props_1->is_encrypted());
+ ASSERT_EQ(false, out_col_props_1->is_encrypted_with_footer_key());
+ ASSERT_EQ(kColumnEncryptionKey1, out_col_props_1->key());
+ ASSERT_EQ("kc1", out_col_props_1->key_metadata());
+
+ std::shared_ptr<ColumnEncryptionProperties> out_col_props_2 =
+ props->column_encryption_properties(column_path_2->ToDotString());
+
+ ASSERT_EQ(column_path_2->ToDotString(), out_col_props_2->column_path());
+ ASSERT_EQ(true, out_col_props_2->is_encrypted());
+ ASSERT_EQ(false, out_col_props_2->is_encrypted_with_footer_key());
+ ASSERT_EQ(kColumnEncryptionKey2, out_col_props_2->key());
+ ASSERT_EQ("kc2", out_col_props_2->key_metadata());
+
+ // other columns: encrypted with footer, footer is not encrypted
+ // so column is not encrypted as well
+ std::string column_path_3 = "column_3";
+ std::shared_ptr<ColumnEncryptionProperties> out_col_props_3 =
+ props->column_encryption_properties(column_path_3);
+
+ ASSERT_EQ(NULLPTR, out_col_props_3);
+}
+
+// Use aad_prefix
+TEST(TestEncryptionProperties, UseAadPrefix) {
+ FileEncryptionProperties::Builder builder(kFooterEncryptionKey);
+ builder.aad_prefix(kFileName);
+ std::shared_ptr<FileEncryptionProperties> props = builder.build();
+
+ ASSERT_EQ(kFileName, props->algorithm().aad.aad_prefix);
+ ASSERT_EQ(false, props->algorithm().aad.supply_aad_prefix);
+}
+
+// Use aad_prefix and
+// disable_aad_prefix_storage.
+TEST(TestEncryptionProperties, UseAadPrefixNotStoreInFile) {
+ FileEncryptionProperties::Builder builder(kFooterEncryptionKey);
+ builder.aad_prefix(kFileName);
+ builder.disable_aad_prefix_storage();
+ std::shared_ptr<FileEncryptionProperties> props = builder.build();
+
+ ASSERT_EQ("", props->algorithm().aad.aad_prefix);
+ ASSERT_EQ(true, props->algorithm().aad.supply_aad_prefix);
+}
+
+// Use AES_GCM_CTR_V1 algorithm
+TEST(TestEncryptionProperties, UseAES_GCM_CTR_V1Algorithm) {
+ FileEncryptionProperties::Builder builder(kFooterEncryptionKey);
+ builder.algorithm(ParquetCipher::AES_GCM_CTR_V1);
+ std::shared_ptr<FileEncryptionProperties> props = builder.build();
+
+ ASSERT_EQ(ParquetCipher::AES_GCM_CTR_V1, props->algorithm().algorithm);
+}
+
+TEST(TestDecryptionProperties, UseKeyRetriever) {
+ std::shared_ptr<parquet::StringKeyIdRetriever> string_kr1 =
+ std::make_shared<parquet::StringKeyIdRetriever>();
+ string_kr1->PutKey("kf", kFooterEncryptionKey);
+ string_kr1->PutKey("kc1", kColumnEncryptionKey1);
+ string_kr1->PutKey("kc2", kColumnEncryptionKey2);
+ std::shared_ptr<parquet::DecryptionKeyRetriever> kr1 =
+ std::static_pointer_cast<parquet::StringKeyIdRetriever>(string_kr1);
+
+ parquet::FileDecryptionProperties::Builder builder;
+ builder.key_retriever(kr1);
+ std::shared_ptr<parquet::FileDecryptionProperties> props = builder.build();
+
+ auto out_key_retriever = props->key_retriever();
+ ASSERT_EQ(kFooterEncryptionKey, out_key_retriever->GetKey("kf"));
+ ASSERT_EQ(kColumnEncryptionKey1, out_key_retriever->GetKey("kc1"));
+ ASSERT_EQ(kColumnEncryptionKey2, out_key_retriever->GetKey("kc2"));
+}
+
+TEST(TestDecryptionProperties, SupplyAadPrefix) {
+ parquet::FileDecryptionProperties::Builder builder;
+ builder.footer_key(kFooterEncryptionKey);
+ builder.aad_prefix(kFileName);
+ std::shared_ptr<parquet::FileDecryptionProperties> props = builder.build();
+
+ ASSERT_EQ(kFileName, props->aad_prefix());
+}
+
+TEST(ColumnDecryptionProperties, SetKey) {
+ std::shared_ptr<parquet::schema::ColumnPath> column_path_1 =
+ parquet::schema::ColumnPath::FromDotString("column_1");
+ ColumnDecryptionProperties::Builder col_builder_1(column_path_1);
+ col_builder_1.key(kColumnEncryptionKey1);
+
+ auto props = col_builder_1.build();
+ ASSERT_EQ(kColumnEncryptionKey1, props->key());
+}
+
+TEST(TestDecryptionProperties, UsingExplicitFooterAndColumnKeys) {
+ std::string column_path_1 = "column_1";
+ std::string column_path_2 = "column_2";
+ std::map<std::string, std::shared_ptr<parquet::ColumnDecryptionProperties>>
+ decryption_cols;
+ parquet::ColumnDecryptionProperties::Builder col_builder_1(column_path_1);
+ parquet::ColumnDecryptionProperties::Builder col_builder_2(column_path_2);
+
+ decryption_cols[column_path_1] = col_builder_1.key(kColumnEncryptionKey1)->build();
+ decryption_cols[column_path_2] = col_builder_2.key(kColumnEncryptionKey2)->build();
+
+ parquet::FileDecryptionProperties::Builder builder;
+ builder.footer_key(kFooterEncryptionKey);
+ builder.column_keys(decryption_cols);
+ std::shared_ptr<parquet::FileDecryptionProperties> props = builder.build();
+
+ ASSERT_EQ(kFooterEncryptionKey, props->footer_key());
+ ASSERT_EQ(kColumnEncryptionKey1, props->column_key(column_path_1));
+ ASSERT_EQ(kColumnEncryptionKey2, props->column_key(column_path_2));
+}
+
+} // namespace test
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/read_configurations_test.cc b/src/arrow/cpp/src/parquet/encryption/read_configurations_test.cc
new file mode 100644
index 000000000..c065deac6
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/read_configurations_test.cc
@@ -0,0 +1,272 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include <stdio.h>
+
+#include <fstream>
+
+#include "arrow/io/file.h"
+#include "arrow/testing/gtest_compat.h"
+
+#include "parquet/column_reader.h"
+#include "parquet/column_writer.h"
+#include "parquet/encryption/test_encryption_util.h"
+#include "parquet/file_reader.h"
+#include "parquet/test_util.h"
+
+/*
+ * This file contains a unit-test for reading encrypted Parquet files with
+ * different decryption configurations.
+ *
+ * The unit-test is called multiple times, each time to decrypt parquet files using
+ * different decryption configuration as described below.
+ * In each call two encrypted files are read: one temporary file that was generated using
+ * encryption-write-configurations-test.cc test and will be deleted upon
+ * reading it, while the second resides in
+ * parquet-testing/data repository. Those two encrypted files were encrypted using the
+ * same encryption configuration.
+ * The encrypted parquet file names are passed as parameter to the unit-test.
+ *
+ * A detailed description of the Parquet Modular Encryption specification can be found
+ * here:
+ * https://github.com/apache/parquet-format/blob/encryption/Encryption.md
+ *
+ * The following decryption configurations are used to decrypt each parquet file:
+ *
+ * - Decryption configuration 1: Decrypt using key retriever that holds the keys of
+ * two encrypted columns and the footer key.
+ * - Decryption configuration 2: Decrypt using key retriever that holds the keys of
+ * two encrypted columns and the footer key. Supplies
+ * aad_prefix to verify file identity.
+ * - Decryption configuration 3: Decrypt using explicit column and footer keys
+ * (instead of key retrieval callback).
+ * - Decryption Configuration 4: PlainText Footer mode - test legacy reads,
+ * read the footer + all non-encrypted columns.
+ * (pairs with encryption configuration 3)
+ *
+ * The encrypted parquet files that is read was encrypted using one of the configurations
+ * below:
+ *
+ * - Encryption configuration 1: Encrypt all columns and the footer with the same key.
+ * (uniform encryption)
+ * - Encryption configuration 2: Encrypt two columns and the footer, with different
+ * keys.
+ * - Encryption configuration 3: Encrypt two columns, with different keys.
+ * Don’t encrypt footer (to enable legacy readers)
+ * - plaintext footer mode.
+ * - Encryption configuration 4: Encrypt two columns and the footer, with different
+ * keys. Supply aad_prefix for file identity
+ * verification.
+ * - Encryption configuration 5: Encrypt two columns and the footer, with different
+ * keys. Supply aad_prefix, and call
+ * disable_aad_prefix_storage to prevent file
+ * identity storage in file metadata.
+ * - Encryption configuration 6: Encrypt two columns and the footer, with different
+ * keys. Use the alternative (AES_GCM_CTR_V1) algorithm.
+
+ */
+
+namespace parquet {
+namespace encryption {
+namespace test {
+
+using parquet::test::ParquetTestException;
+
+class TestDecryptionConfiguration
+ : public testing::TestWithParam<std::tuple<int, const char*>> {
+ public:
+ void SetUp() { CreateDecryptionConfigurations(); }
+
+ protected:
+ FileDecryptor decryptor_;
+ std::string path_to_double_field_ = kDoubleFieldName;
+ std::string path_to_float_field_ = kFloatFieldName;
+ // This vector will hold various decryption configurations.
+ std::vector<std::shared_ptr<parquet::FileDecryptionProperties>>
+ vector_of_decryption_configurations_;
+ std::string kFooterEncryptionKey_ = std::string(kFooterEncryptionKey);
+ std::string kColumnEncryptionKey1_ = std::string(kColumnEncryptionKey1);
+ std::string kColumnEncryptionKey2_ = std::string(kColumnEncryptionKey2);
+ std::string kFileName_ = std::string(kFileName);
+
+ void CreateDecryptionConfigurations() {
+ /**********************************************************************************
+ Creating a number of Decryption configurations
+ **********************************************************************************/
+
+ // Decryption configuration 1: Decrypt using key retriever callback that holds the
+ // keys of two encrypted columns and the footer key.
+ std::shared_ptr<parquet::StringKeyIdRetriever> string_kr1 =
+ std::make_shared<parquet::StringKeyIdRetriever>();
+ string_kr1->PutKey("kf", kFooterEncryptionKey_);
+ string_kr1->PutKey("kc1", kColumnEncryptionKey1_);
+ string_kr1->PutKey("kc2", kColumnEncryptionKey2_);
+ std::shared_ptr<parquet::DecryptionKeyRetriever> kr1 =
+ std::static_pointer_cast<parquet::StringKeyIdRetriever>(string_kr1);
+
+ parquet::FileDecryptionProperties::Builder file_decryption_builder_1;
+ vector_of_decryption_configurations_.push_back(
+ file_decryption_builder_1.key_retriever(kr1)->build());
+
+ // Decryption configuration 2: Decrypt using key retriever callback that holds the
+ // keys of two encrypted columns and the footer key. Supply aad_prefix.
+ std::shared_ptr<parquet::StringKeyIdRetriever> string_kr2 =
+ std::make_shared<parquet::StringKeyIdRetriever>();
+ string_kr2->PutKey("kf", kFooterEncryptionKey_);
+ string_kr2->PutKey("kc1", kColumnEncryptionKey1_);
+ string_kr2->PutKey("kc2", kColumnEncryptionKey2_);
+ std::shared_ptr<parquet::DecryptionKeyRetriever> kr2 =
+ std::static_pointer_cast<parquet::StringKeyIdRetriever>(string_kr2);
+
+ parquet::FileDecryptionProperties::Builder file_decryption_builder_2;
+ vector_of_decryption_configurations_.push_back(
+ file_decryption_builder_2.key_retriever(kr2)->aad_prefix(kFileName_)->build());
+
+ // Decryption configuration 3: Decrypt using explicit column and footer keys. Supply
+ // aad_prefix.
+ std::string path_float_ptr = kFloatFieldName;
+ std::string path_double_ptr = kDoubleFieldName;
+ std::map<std::string, std::shared_ptr<parquet::ColumnDecryptionProperties>>
+ decryption_cols;
+ parquet::ColumnDecryptionProperties::Builder decryption_col_builder31(
+ path_double_ptr);
+ parquet::ColumnDecryptionProperties::Builder decryption_col_builder32(path_float_ptr);
+
+ decryption_cols[path_double_ptr] =
+ decryption_col_builder31.key(kColumnEncryptionKey1_)->build();
+ decryption_cols[path_float_ptr] =
+ decryption_col_builder32.key(kColumnEncryptionKey2_)->build();
+
+ parquet::FileDecryptionProperties::Builder file_decryption_builder_3;
+ vector_of_decryption_configurations_.push_back(
+ file_decryption_builder_3.footer_key(kFooterEncryptionKey_)
+ ->column_keys(decryption_cols)
+ ->build());
+
+ // Decryption Configuration 4: use plaintext footer mode, read only footer + plaintext
+ // columns.
+ vector_of_decryption_configurations_.push_back(NULL);
+ }
+
+ void DecryptFile(std::string file, int decryption_config_num) {
+ std::string exception_msg;
+ std::shared_ptr<FileDecryptionProperties> file_decryption_properties;
+ // if we get decryption_config_num = x then it means the actual number is x+1
+ // and since we want decryption_config_num=4 we set the condition to 3
+ if (decryption_config_num != 3) {
+ file_decryption_properties =
+ vector_of_decryption_configurations_[decryption_config_num]->DeepClone();
+ }
+
+ decryptor_.DecryptFile(file, file_decryption_properties);
+ }
+
+ // Check that the decryption result is as expected.
+ void CheckResults(const std::string file_name, unsigned decryption_config_num,
+ unsigned encryption_config_num) {
+ // Encryption_configuration number five contains aad_prefix and
+ // disable_aad_prefix_storage.
+ // An exception is expected to be thrown if the file is not decrypted with aad_prefix.
+ if (encryption_config_num == 5) {
+ if (decryption_config_num == 1 || decryption_config_num == 3) {
+ EXPECT_THROW(DecryptFile(file_name, decryption_config_num - 1), ParquetException);
+ return;
+ }
+ }
+ // Decryption configuration number two contains aad_prefix. An exception is expected
+ // to be thrown if the file was not encrypted with the same aad_prefix.
+ if (decryption_config_num == 2) {
+ if (encryption_config_num != 5 && encryption_config_num != 4) {
+ EXPECT_THROW(DecryptFile(file_name, decryption_config_num - 1), ParquetException);
+ return;
+ }
+ }
+
+ // decryption config 4 can only work when the encryption configuration is 3
+ if (decryption_config_num == 4 && encryption_config_num != 3) {
+ return;
+ }
+ EXPECT_NO_THROW(DecryptFile(file_name, decryption_config_num - 1));
+ }
+
+ // Returns true if file exists. Otherwise returns false.
+ bool fexists(const std::string& filename) {
+ std::ifstream ifile(filename.c_str());
+ return ifile.good();
+ }
+};
+
+// Read encrypted parquet file.
+// The test reads two parquet files that were encrypted using the same encryption
+// configuration:
+// one was generated in encryption-write-configurations-test.cc tests and is deleted
+// once the file is read and the second exists in parquet-testing/data folder.
+// The name of the files are passed as parameters to the unit-test.
+TEST_P(TestDecryptionConfiguration, TestDecryption) {
+ int encryption_config_num = std::get<0>(GetParam());
+ const char* param_file_name = std::get<1>(GetParam());
+ // Decrypt parquet file that was generated in encryption-write-configurations-test.cc
+ // test.
+ std::string tmp_file_name = "tmp_" + std::string(param_file_name);
+ std::string file_name = temp_dir->path().ToString() + tmp_file_name;
+ if (!fexists(file_name)) {
+ std::stringstream ss;
+ ss << "File " << file_name << " is missing from temporary dir.";
+ throw ParquetTestException(ss.str());
+ }
+
+ // Iterate over the decryption configurations and use each one to read the encrypted
+ // parqeut file.
+ for (unsigned index = 0; index < vector_of_decryption_configurations_.size(); ++index) {
+ unsigned decryption_config_num = index + 1;
+ CheckResults(file_name, decryption_config_num, encryption_config_num);
+ }
+ // Delete temporary test file.
+ ASSERT_EQ(std::remove(file_name.c_str()), 0);
+
+ // Decrypt parquet file that resides in parquet-testing/data directory.
+ file_name = data_file(param_file_name);
+
+ if (!fexists(file_name)) {
+ std::stringstream ss;
+ ss << "File " << file_name << " is missing from parquet-testing repo.";
+ throw ParquetTestException(ss.str());
+ }
+
+ // Iterate over the decryption configurations and use each one to read the encrypted
+ // parqeut file.
+ for (unsigned index = 0; index < vector_of_decryption_configurations_.size(); ++index) {
+ unsigned decryption_config_num = index + 1;
+ CheckResults(file_name, decryption_config_num, encryption_config_num);
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ DecryptionTests, TestDecryptionConfiguration,
+ ::testing::Values(
+ std::make_tuple(1, "uniform_encryption.parquet.encrypted"),
+ std::make_tuple(2, "encrypt_columns_and_footer.parquet.encrypted"),
+ std::make_tuple(3, "encrypt_columns_plaintext_footer.parquet.encrypted"),
+ std::make_tuple(4, "encrypt_columns_and_footer_aad.parquet.encrypted"),
+ std::make_tuple(
+ 5, "encrypt_columns_and_footer_disable_aad_storage.parquet.encrypted"),
+ std::make_tuple(6, "encrypt_columns_and_footer_ctr.parquet.encrypted")));
+
+} // namespace test
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/test_encryption_util.cc b/src/arrow/cpp/src/parquet/encryption/test_encryption_util.cc
new file mode 100644
index 000000000..8b83154c9
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/test_encryption_util.cc
@@ -0,0 +1,502 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This module defines an abstract interface for iterating through pages in a
+// Parquet column chunk within a row group. It could be extended in the future
+// to iterate through all data pages in all chunks in a file.
+
+#include <sstream>
+
+#include <arrow/io/file.h>
+
+#include "arrow/testing/future_util.h"
+#include "parquet/encryption/test_encryption_util.h"
+#include "parquet/file_reader.h"
+#include "parquet/file_writer.h"
+#include "parquet/test_util.h"
+
+using ::arrow::io::FileOutputStream;
+
+using parquet::ConvertedType;
+using parquet::Repetition;
+using parquet::Type;
+using parquet::schema::GroupNode;
+using parquet::schema::PrimitiveNode;
+
+namespace parquet {
+namespace encryption {
+namespace test {
+
+std::string data_file(const char* file) {
+ std::string dir_string(parquet::test::get_data_dir());
+ std::stringstream ss;
+ ss << dir_string << "/" << file;
+ return ss.str();
+}
+
+std::unordered_map<std::string, std::string> BuildKeyMap(const char* const* column_ids,
+ const char* const* column_keys,
+ const char* footer_id,
+ const char* footer_key) {
+ std::unordered_map<std::string, std::string> key_map;
+ // add column keys
+ for (int i = 0; i < 6; i++) {
+ key_map.insert({column_ids[i], column_keys[i]});
+ }
+ // add footer key
+ key_map.insert({footer_id, footer_key});
+
+ return key_map;
+}
+
+std::string BuildColumnKeyMapping() {
+ std::ostringstream stream;
+ stream << kColumnMasterKeyIds[0] << ":" << kDoubleFieldName << ";"
+ << kColumnMasterKeyIds[1] << ":" << kFloatFieldName << ";"
+ << kColumnMasterKeyIds[2] << ":" << kBooleanFieldName << ";"
+ << kColumnMasterKeyIds[3] << ":" << kInt32FieldName << ";"
+ << kColumnMasterKeyIds[4] << ":" << kByteArrayFieldName << ";"
+ << kColumnMasterKeyIds[5] << ":" << kFixedLenByteArrayFieldName << ";";
+ return stream.str();
+}
+
+template <typename DType>
+struct ColumnData {
+ typedef typename DType::c_type T;
+
+ std::vector<T> values;
+ std::vector<int16_t> definition_levels;
+ std::vector<int16_t> repetition_levels;
+
+ int64_t rows() const { return values.size(); }
+ const T* raw_values() const { return values.data(); }
+ const int16_t* raw_definition_levels() const {
+ return definition_levels.size() == 0 ? nullptr : definition_levels.data();
+ }
+ const int16_t* raw_repetition_levels() const {
+ return repetition_levels.size() == 0 ? nullptr : repetition_levels.data();
+ }
+};
+
+template <typename DType>
+ColumnData<DType> GenerateSampleData(int rows) {
+ return ColumnData<DType>();
+}
+
+template <>
+ColumnData<Int32Type> GenerateSampleData<Int32Type>(int rows) {
+ ColumnData<Int32Type> int32_col;
+ // Int32 column
+ for (int i = 0; i < rows; i++) {
+ int32_col.values.push_back(i);
+ }
+ return int32_col;
+}
+
+template <>
+ColumnData<Int64Type> GenerateSampleData<Int64Type>(int rows) {
+ ColumnData<Int64Type> int64_col;
+ // The Int64 column. Each row has repeats twice.
+ for (int i = 0; i < 2 * rows; i++) {
+ int64_t value = i * 1000 * 1000;
+ value *= 1000 * 1000;
+ int16_t definition_level = 1;
+ int16_t repetition_level = 0;
+ if ((i % 2) == 0) {
+ repetition_level = 1; // start of a new record
+ }
+ int64_col.values.push_back(value);
+ int64_col.definition_levels.push_back(definition_level);
+ int64_col.repetition_levels.push_back(repetition_level);
+ }
+ return int64_col;
+}
+
+template <>
+ColumnData<Int96Type> GenerateSampleData<Int96Type>(int rows) {
+ ColumnData<Int96Type> int96_col;
+ for (int i = 0; i < rows; i++) {
+ parquet::Int96 value;
+ value.value[0] = i;
+ value.value[1] = i + 1;
+ value.value[2] = i + 2;
+ int96_col.values.push_back(value);
+ }
+ return int96_col;
+}
+
+template <>
+ColumnData<FloatType> GenerateSampleData<FloatType>(int rows) {
+ ColumnData<FloatType> float_col;
+ for (int i = 0; i < rows; i++) {
+ float value = static_cast<float>(i) * 1.1f;
+ float_col.values.push_back(value);
+ }
+ return float_col;
+}
+
+template <>
+ColumnData<DoubleType> GenerateSampleData<DoubleType>(int rows) {
+ ColumnData<DoubleType> double_col;
+ for (int i = 0; i < rows; i++) {
+ double value = i * 1.1111111;
+ double_col.values.push_back(value);
+ }
+ return double_col;
+}
+
+template <typename DType, typename NextFunc>
+void WriteBatch(int rows, const NextFunc get_next_column) {
+ ColumnData<DType> column = GenerateSampleData<DType>(rows);
+ TypedColumnWriter<DType>* writer =
+ static_cast<TypedColumnWriter<DType>*>(get_next_column());
+ writer->WriteBatch(column.rows(), column.raw_definition_levels(),
+ column.raw_repetition_levels(), column.raw_values());
+}
+
+FileEncryptor::FileEncryptor() { schema_ = SetupEncryptionSchema(); }
+
+std::shared_ptr<GroupNode> FileEncryptor::SetupEncryptionSchema() {
+ parquet::schema::NodeVector fields;
+
+ fields.push_back(PrimitiveNode::Make(kBooleanFieldName, Repetition::REQUIRED,
+ Type::BOOLEAN, ConvertedType::NONE));
+
+ fields.push_back(PrimitiveNode::Make(kInt32FieldName, Repetition::REQUIRED, Type::INT32,
+ ConvertedType::TIME_MILLIS));
+
+ fields.push_back(PrimitiveNode::Make(kInt64FieldName, Repetition::REPEATED, Type::INT64,
+ ConvertedType::NONE));
+
+ fields.push_back(PrimitiveNode::Make(kInt96FieldName, Repetition::REQUIRED, Type::INT96,
+ ConvertedType::NONE));
+
+ fields.push_back(PrimitiveNode::Make(kFloatFieldName, Repetition::REQUIRED, Type::FLOAT,
+ ConvertedType::NONE));
+
+ fields.push_back(PrimitiveNode::Make(kDoubleFieldName, Repetition::REQUIRED,
+ Type::DOUBLE, ConvertedType::NONE));
+
+ fields.push_back(PrimitiveNode::Make(kByteArrayFieldName, Repetition::OPTIONAL,
+ Type::BYTE_ARRAY, ConvertedType::NONE));
+
+ fields.push_back(PrimitiveNode::Make(kFixedLenByteArrayFieldName, Repetition::REQUIRED,
+ Type::FIXED_LEN_BYTE_ARRAY, ConvertedType::NONE,
+ kFixedLength));
+
+ return std::static_pointer_cast<GroupNode>(
+ GroupNode::Make("schema", Repetition::REQUIRED, fields));
+}
+
+void FileEncryptor::EncryptFile(
+ std::string file,
+ std::shared_ptr<parquet::FileEncryptionProperties> encryption_configurations) {
+ WriterProperties::Builder prop_builder;
+ prop_builder.compression(parquet::Compression::SNAPPY);
+ prop_builder.encryption(encryption_configurations);
+ std::shared_ptr<WriterProperties> writer_properties = prop_builder.build();
+
+ PARQUET_ASSIGN_OR_THROW(auto out_file, FileOutputStream::Open(file));
+ // Create a ParquetFileWriter instance
+ std::shared_ptr<parquet::ParquetFileWriter> file_writer =
+ parquet::ParquetFileWriter::Open(out_file, schema_, writer_properties);
+
+ for (int r = 0; r < num_rowgroups_; r++) {
+ bool buffered_mode = r % 2 == 0;
+ auto row_group_writer = buffered_mode ? file_writer->AppendBufferedRowGroup()
+ : file_writer->AppendRowGroup();
+
+ int column_index = 0;
+ // Captures i by reference; increments it by one
+ auto get_next_column = [&]() {
+ return buffered_mode ? row_group_writer->column(column_index++)
+ : row_group_writer->NextColumn();
+ };
+
+ // Write the Bool column
+ parquet::BoolWriter* bool_writer =
+ static_cast<parquet::BoolWriter*>(get_next_column());
+ for (int i = 0; i < rows_per_rowgroup_; i++) {
+ bool value = ((i % 2) == 0) ? true : false;
+ bool_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+
+ // Write the Int32 column
+ WriteBatch<Int32Type>(rows_per_rowgroup_, get_next_column);
+
+ // Write the Int64 column.
+ WriteBatch<Int64Type>(rows_per_rowgroup_, get_next_column);
+
+ // Write the INT96 column.
+ WriteBatch<Int96Type>(rows_per_rowgroup_, get_next_column);
+
+ // Write the Float column
+ WriteBatch<FloatType>(rows_per_rowgroup_, get_next_column);
+
+ // Write the Double column
+ WriteBatch<DoubleType>(rows_per_rowgroup_, get_next_column);
+
+ // Write the ByteArray column. Make every alternate values NULL
+ // Write the ByteArray column. Make every alternate values NULL
+ parquet::ByteArrayWriter* ba_writer =
+ static_cast<parquet::ByteArrayWriter*>(get_next_column());
+ for (int i = 0; i < rows_per_rowgroup_; i++) {
+ parquet::ByteArray value;
+ char hello[kFixedLength] = "parquet";
+ hello[7] = static_cast<char>(static_cast<int>('0') + i / 100);
+ hello[8] = static_cast<char>(static_cast<int>('0') + (i / 10) % 10);
+ hello[9] = static_cast<char>(static_cast<int>('0') + i % 10);
+ if (i % 2 == 0) {
+ int16_t definition_level = 1;
+ value.ptr = reinterpret_cast<const uint8_t*>(&hello[0]);
+ value.len = kFixedLength;
+ ba_writer->WriteBatch(1, &definition_level, nullptr, &value);
+ } else {
+ int16_t definition_level = 0;
+ ba_writer->WriteBatch(1, &definition_level, nullptr, nullptr);
+ }
+ }
+
+ // Write the FixedLengthByteArray column
+ parquet::FixedLenByteArrayWriter* flba_writer =
+ static_cast<parquet::FixedLenByteArrayWriter*>(get_next_column());
+ for (int i = 0; i < rows_per_rowgroup_; i++) {
+ parquet::FixedLenByteArray value;
+ char v = static_cast<char>(i);
+ char flba[kFixedLength] = {v, v, v, v, v, v, v, v, v, v};
+ value.ptr = reinterpret_cast<const uint8_t*>(&flba[0]);
+ flba_writer->WriteBatch(1, nullptr, nullptr, &value);
+ }
+ }
+
+ // Close the ParquetFileWriter
+ file_writer->Close();
+ PARQUET_THROW_NOT_OK(out_file->Close());
+
+ return;
+} // namespace test
+
+template <typename DType, typename RowGroupReader, typename RowGroupMetadata>
+void ReadAndVerifyColumn(RowGroupReader* rg_reader, RowGroupMetadata* rg_md,
+ int column_index, int rows) {
+ ColumnData<DType> expected_column_data = GenerateSampleData<DType>(rows);
+ std::shared_ptr<parquet::ColumnReader> column_reader = rg_reader->Column(column_index);
+ TypedColumnReader<DType>* reader =
+ static_cast<TypedColumnReader<DType>*>(column_reader.get());
+
+ std::unique_ptr<ColumnChunkMetaData> col_md = rg_md->ColumnChunk(column_index);
+
+ int64_t rows_should_read = expected_column_data.values.size();
+
+ // Read all the rows in the column
+ ColumnData<DType> read_col_data;
+ read_col_data.values.resize(rows_should_read);
+ int64_t values_read;
+ int64_t rows_read;
+ if (expected_column_data.definition_levels.size() > 0 &&
+ expected_column_data.repetition_levels.size() > 0) {
+ std::vector<int16_t> definition_levels(rows_should_read);
+ std::vector<int16_t> repetition_levels(rows_should_read);
+ rows_read = reader->ReadBatch(rows_should_read, definition_levels.data(),
+ repetition_levels.data(), read_col_data.values.data(),
+ &values_read);
+ ASSERT_EQ(definition_levels, expected_column_data.definition_levels);
+ ASSERT_EQ(repetition_levels, expected_column_data.repetition_levels);
+ } else {
+ rows_read = reader->ReadBatch(rows_should_read, nullptr, nullptr,
+ read_col_data.values.data(), &values_read);
+ }
+ ASSERT_EQ(rows_read, rows_should_read);
+ ASSERT_EQ(values_read, rows_should_read);
+ ASSERT_EQ(read_col_data.values, expected_column_data.values);
+ // make sure we got the same number of values the metadata says
+ ASSERT_EQ(col_md->num_values(), rows_read);
+}
+
+void FileDecryptor::DecryptFile(
+ std::string file,
+ std::shared_ptr<FileDecryptionProperties> file_decryption_properties) {
+ std::string exception_msg;
+ parquet::ReaderProperties reader_properties = parquet::default_reader_properties();
+ if (file_decryption_properties) {
+ reader_properties.file_decryption_properties(file_decryption_properties->DeepClone());
+ }
+
+ std::shared_ptr<::arrow::io::RandomAccessFile> source;
+ PARQUET_ASSIGN_OR_THROW(
+ source, ::arrow::io::ReadableFile::Open(file, reader_properties.memory_pool()));
+
+ auto file_reader = parquet::ParquetFileReader::Open(source, reader_properties);
+ CheckFile(file_reader.get(), file_decryption_properties.get());
+
+ if (file_decryption_properties) {
+ reader_properties.file_decryption_properties(file_decryption_properties->DeepClone());
+ }
+ auto fut = parquet::ParquetFileReader::OpenAsync(source, reader_properties);
+ ASSERT_FINISHES_OK(fut);
+ ASSERT_OK_AND_ASSIGN(file_reader, fut.MoveResult());
+ CheckFile(file_reader.get(), file_decryption_properties.get());
+
+ file_reader->Close();
+ PARQUET_THROW_NOT_OK(source->Close());
+}
+
+void FileDecryptor::CheckFile(parquet::ParquetFileReader* file_reader,
+ FileDecryptionProperties* file_decryption_properties) {
+ // Get the File MetaData
+ std::shared_ptr<parquet::FileMetaData> file_metadata = file_reader->metadata();
+
+ // Get the number of RowGroups
+ int num_row_groups = file_metadata->num_row_groups();
+
+ // Get the number of Columns
+ int num_columns = file_metadata->num_columns();
+ ASSERT_EQ(num_columns, 8);
+
+ // Iterate over all the RowGroups in the file
+ for (int r = 0; r < num_row_groups; ++r) {
+ // Get the RowGroup Reader
+ std::shared_ptr<parquet::RowGroupReader> row_group_reader = file_reader->RowGroup(r);
+
+ // Get the RowGroupMetaData
+ std::unique_ptr<RowGroupMetaData> rg_metadata = file_metadata->RowGroup(r);
+
+ int rows_per_rowgroup = static_cast<int>(rg_metadata->num_rows());
+
+ int64_t values_read = 0;
+ int64_t rows_read = 0;
+ int16_t definition_level;
+ // int16_t repetition_level;
+ int i;
+ std::shared_ptr<parquet::ColumnReader> column_reader;
+
+ // Get the Column Reader for the boolean column
+ column_reader = row_group_reader->Column(0);
+ parquet::BoolReader* bool_reader =
+ static_cast<parquet::BoolReader*>(column_reader.get());
+
+ // Get the ColumnChunkMetaData for the boolean column
+ std::unique_ptr<ColumnChunkMetaData> boolean_md = rg_metadata->ColumnChunk(0);
+
+ // Read all the rows in the column
+ i = 0;
+ while (bool_reader->HasNext()) {
+ bool value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = bool_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ ASSERT_EQ(rows_read, 1);
+ // There are no NULL values in the rows written
+ ASSERT_EQ(values_read, 1);
+ // Verify the value written
+ bool expected_value = ((i % 2) == 0) ? true : false;
+ ASSERT_EQ(value, expected_value);
+ i++;
+ }
+ // make sure we got the same number of values the metadata says
+ ASSERT_EQ(boolean_md->num_values(), i);
+
+ ReadAndVerifyColumn<Int32Type>(row_group_reader.get(), rg_metadata.get(), 1,
+ rows_per_rowgroup);
+
+ ReadAndVerifyColumn<Int64Type>(row_group_reader.get(), rg_metadata.get(), 2,
+ rows_per_rowgroup);
+
+ ReadAndVerifyColumn<Int96Type>(row_group_reader.get(), rg_metadata.get(), 3,
+ rows_per_rowgroup);
+
+ if (file_decryption_properties) {
+ ReadAndVerifyColumn<FloatType>(row_group_reader.get(), rg_metadata.get(), 4,
+ rows_per_rowgroup);
+
+ ReadAndVerifyColumn<DoubleType>(row_group_reader.get(), rg_metadata.get(), 5,
+ rows_per_rowgroup);
+ }
+
+ // Get the Column Reader for the ByteArray column
+ column_reader = row_group_reader->Column(6);
+ parquet::ByteArrayReader* ba_reader =
+ static_cast<parquet::ByteArrayReader*>(column_reader.get());
+
+ // Get the ColumnChunkMetaData for the ByteArray column
+ std::unique_ptr<ColumnChunkMetaData> ba_md = rg_metadata->ColumnChunk(6);
+
+ // Read all the rows in the column
+ i = 0;
+ while (ba_reader->HasNext()) {
+ parquet::ByteArray value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read =
+ ba_reader->ReadBatch(1, &definition_level, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ ASSERT_EQ(rows_read, 1);
+ // Verify the value written
+ char expected_value[kFixedLength] = "parquet";
+ expected_value[7] = static_cast<char>('0' + i / 100);
+ expected_value[8] = static_cast<char>('0' + (i / 10) % 10);
+ expected_value[9] = static_cast<char>('0' + i % 10);
+ if (i % 2 == 0) { // only alternate values exist
+ // There are no NULL values in the rows written
+ ASSERT_EQ(values_read, 1);
+ ASSERT_EQ(value.len, kFixedLength);
+ ASSERT_EQ(memcmp(value.ptr, &expected_value[0], kFixedLength), 0);
+ ASSERT_EQ(definition_level, 1);
+ } else {
+ // There are NULL values in the rows written
+ ASSERT_EQ(values_read, 0);
+ ASSERT_EQ(definition_level, 0);
+ }
+ i++;
+ }
+ // make sure we got the same number of values the metadata says
+ ASSERT_EQ(ba_md->num_values(), i);
+
+ // Get the Column Reader for the FixedLengthByteArray column
+ column_reader = row_group_reader->Column(7);
+ parquet::FixedLenByteArrayReader* flba_reader =
+ static_cast<parquet::FixedLenByteArrayReader*>(column_reader.get());
+
+ // Get the ColumnChunkMetaData for the FixedLengthByteArray column
+ std::unique_ptr<ColumnChunkMetaData> flba_md = rg_metadata->ColumnChunk(7);
+
+ // Read all the rows in the column
+ i = 0;
+ while (flba_reader->HasNext()) {
+ parquet::FixedLenByteArray value;
+ // Read one value at a time. The number of rows read is returned. values_read
+ // contains the number of non-null rows
+ rows_read = flba_reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
+ // Ensure only one value is read
+ ASSERT_EQ(rows_read, 1);
+ // There are no NULL values in the rows written
+ ASSERT_EQ(values_read, 1);
+ // Verify the value written
+ char v = static_cast<char>(i);
+ char expected_value[kFixedLength] = {v, v, v, v, v, v, v, v, v, v};
+ ASSERT_EQ(memcmp(value.ptr, &expected_value[0], kFixedLength), 0);
+ i++;
+ }
+ // make sure we got the same number of values the metadata says
+ ASSERT_EQ(flba_md->num_values(), i);
+ }
+}
+
+} // namespace test
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/test_encryption_util.h b/src/arrow/cpp/src/parquet/encryption/test_encryption_util.h
new file mode 100644
index 000000000..b5d71b995
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/test_encryption_util.h
@@ -0,0 +1,118 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This module defines an abstract interface for iterating through pages in a
+// Parquet column chunk within a row group. It could be extended in the future
+// to iterate through all data pages in all chunks in a file.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include <gtest/gtest.h>
+
+#include "arrow/util/io_util.h"
+
+#include "parquet/encryption/encryption.h"
+#include "parquet/test_util.h"
+
+namespace parquet {
+class ParquetFileReader;
+namespace encryption {
+namespace test {
+
+using ::arrow::internal::TemporaryDir;
+
+constexpr int kFixedLength = 10;
+
+const char kFooterEncryptionKey[] = "0123456789012345"; // 128bit/16
+const char kColumnEncryptionKey1[] = "1234567890123450";
+const char kColumnEncryptionKey2[] = "1234567890123451";
+const char kFileName[] = "tester";
+
+// Get the path of file inside parquet test data directory
+std::string data_file(const char* file);
+
+// A temporary directory that contains the encrypted files generated in the tests.
+extern std::unique_ptr<TemporaryDir> temp_dir;
+
+inline ::arrow::Result<std::unique_ptr<TemporaryDir>> temp_data_dir() {
+ return TemporaryDir::Make("parquet-encryption-test-");
+}
+
+const char kDoubleFieldName[] = "double_field";
+const char kFloatFieldName[] = "float_field";
+const char kBooleanFieldName[] = "boolean_field";
+const char kInt32FieldName[] = "int32_field";
+const char kInt64FieldName[] = "int64_field";
+const char kInt96FieldName[] = "int96_field";
+const char kByteArrayFieldName[] = "ba_field";
+const char kFixedLenByteArrayFieldName[] = "flba_field";
+
+const char kFooterMasterKey[] = "0123456789112345";
+const char kFooterMasterKeyId[] = "kf";
+const char* const kColumnMasterKeys[] = {"1234567890123450", "1234567890123451",
+ "1234567890123452", "1234567890123453",
+ "1234567890123454", "1234567890123455"};
+const char* const kColumnMasterKeyIds[] = {"kc1", "kc2", "kc3", "kc4", "kc5", "kc6"};
+
+// The result of this function will be used to set into TestOnlyInMemoryKmsClientFactory
+// as the key mapping to look at.
+std::unordered_map<std::string, std::string> BuildKeyMap(const char* const* column_ids,
+ const char* const* column_keys,
+ const char* footer_id,
+ const char* footer_key);
+
+// The result of this function will be used to set into EncryptionConfiguration
+// as colum keys.
+std::string BuildColumnKeyMapping();
+
+// FileEncryptor and FileDecryptor are helper classes to write/read an encrypted parquet
+// file corresponding to each pair of FileEncryptionProperties/FileDecryptionProperties.
+// FileEncryptor writes the file with fixed data values and FileDecryptor reads the file
+// and verify the correctness of data values.
+class FileEncryptor {
+ public:
+ FileEncryptor();
+
+ void EncryptFile(
+ std::string file,
+ std::shared_ptr<parquet::FileEncryptionProperties> encryption_configurations);
+
+ private:
+ std::shared_ptr<schema::GroupNode> SetupEncryptionSchema();
+
+ int num_rowgroups_ = 5;
+ int rows_per_rowgroup_ = 50;
+ std::shared_ptr<schema::GroupNode> schema_;
+};
+
+class FileDecryptor {
+ public:
+ void DecryptFile(std::string file_name,
+ std::shared_ptr<FileDecryptionProperties> file_decryption_properties);
+
+ private:
+ void CheckFile(parquet::ParquetFileReader* file_reader,
+ FileDecryptionProperties* file_decryption_properties);
+};
+
+} // namespace test
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/test_in_memory_kms.cc b/src/arrow/cpp/src/parquet/encryption/test_in_memory_kms.cc
new file mode 100644
index 000000000..c30b3418d
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/test_in_memory_kms.cc
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/base64.h"
+
+#include "parquet/encryption/key_toolkit_internal.h"
+#include "parquet/encryption/test_in_memory_kms.h"
+#include "parquet/exception.h"
+
+namespace parquet {
+namespace encryption {
+
+std::unordered_map<std::string, std::string>
+ TestOnlyLocalWrapInMemoryKms::master_key_map_;
+std::unordered_map<std::string, std::string> TestOnlyInServerWrapKms::master_key_map_;
+
+void TestOnlyLocalWrapInMemoryKms::InitializeMasterKeys(
+ const std::unordered_map<std::string, std::string>& master_keys_map) {
+ master_key_map_ = master_keys_map;
+}
+
+TestOnlyLocalWrapInMemoryKms::TestOnlyLocalWrapInMemoryKms(
+ const KmsConnectionConfig& kms_connection_config)
+ : LocalWrapKmsClient(kms_connection_config) {}
+
+std::string TestOnlyLocalWrapInMemoryKms::GetMasterKeyFromServer(
+ const std::string& master_key_identifier) {
+ // Always return the latest key version
+ return master_key_map_.at(master_key_identifier);
+}
+
+void TestOnlyInServerWrapKms::InitializeMasterKeys(
+ const std::unordered_map<std::string, std::string>& master_keys_map) {
+ master_key_map_ = master_keys_map;
+}
+
+std::string TestOnlyInServerWrapKms::WrapKey(const std::string& key_bytes,
+ const std::string& master_key_identifier) {
+ // Always use the latest key version for writing
+ if (master_key_map_.find(master_key_identifier) == master_key_map_.end()) {
+ throw ParquetException("Key not found: " + master_key_identifier);
+ }
+ const std::string& master_key = master_key_map_.at(master_key_identifier);
+
+ std::string aad = master_key_identifier;
+ return internal::EncryptKeyLocally(key_bytes, master_key, aad);
+}
+
+std::string TestOnlyInServerWrapKms::UnwrapKey(const std::string& wrapped_key,
+ const std::string& master_key_identifier) {
+ if (master_key_map_.find(master_key_identifier) == master_key_map_.end()) {
+ throw ParquetException("Key not found: " + master_key_identifier);
+ }
+ const std::string& master_key = master_key_map_.at(master_key_identifier);
+
+ std::string aad = master_key_identifier;
+ return internal::DecryptKeyLocally(wrapped_key, master_key, aad);
+}
+
+std::string TestOnlyInServerWrapKms::GetMasterKeyFromServer(
+ const std::string& master_key_identifier) {
+ // Always return the latest key version
+ return master_key_map_.at(master_key_identifier);
+}
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/test_in_memory_kms.h b/src/arrow/cpp/src/parquet/encryption/test_in_memory_kms.h
new file mode 100644
index 000000000..f54f67da5
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/test_in_memory_kms.h
@@ -0,0 +1,89 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <unordered_map>
+
+#include "arrow/util/base64.h"
+
+#include "parquet/encryption/kms_client_factory.h"
+#include "parquet/encryption/local_wrap_kms_client.h"
+#include "parquet/platform.h"
+
+namespace parquet {
+namespace encryption {
+
+// This is a mock class, built for testing only. Don't use it as an example of
+// LocalWrapKmsClient implementation.
+class TestOnlyLocalWrapInMemoryKms : public LocalWrapKmsClient {
+ public:
+ explicit TestOnlyLocalWrapInMemoryKms(const KmsConnectionConfig& kms_connection_config);
+
+ static void InitializeMasterKeys(
+ const std::unordered_map<std::string, std::string>& master_keys_map);
+
+ protected:
+ std::string GetMasterKeyFromServer(const std::string& master_key_identifier) override;
+
+ private:
+ static std::unordered_map<std::string, std::string> master_key_map_;
+};
+
+// This is a mock class, built for testing only. Don't use it as an example of KmsClient
+// implementation.
+class TestOnlyInServerWrapKms : public KmsClient {
+ public:
+ static void InitializeMasterKeys(
+ const std::unordered_map<std::string, std::string>& master_keys_map);
+
+ std::string WrapKey(const std::string& key_bytes,
+ const std::string& master_key_identifier) override;
+
+ std::string UnwrapKey(const std::string& wrapped_key,
+ const std::string& master_key_identifier) override;
+
+ private:
+ std::string GetMasterKeyFromServer(const std::string& master_key_identifier);
+
+ static std::unordered_map<std::string, std::string> master_key_map_;
+};
+
+// This is a mock class, built for testing only. Don't use it as an example of
+// KmsClientFactory implementation.
+class TestOnlyInMemoryKmsClientFactory : public KmsClientFactory {
+ public:
+ TestOnlyInMemoryKmsClientFactory(
+ bool wrap_locally,
+ const std::unordered_map<std::string, std::string>& master_keys_map)
+ : KmsClientFactory(wrap_locally) {
+ TestOnlyLocalWrapInMemoryKms::InitializeMasterKeys(master_keys_map);
+ TestOnlyInServerWrapKms::InitializeMasterKeys(master_keys_map);
+ }
+
+ std::shared_ptr<KmsClient> CreateKmsClient(
+ const KmsConnectionConfig& kms_connection_config) {
+ if (wrap_locally_) {
+ return std::make_shared<TestOnlyLocalWrapInMemoryKms>(kms_connection_config);
+ } else {
+ return std::make_shared<TestOnlyInServerWrapKms>();
+ }
+ }
+};
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/two_level_cache_with_expiration.h b/src/arrow/cpp/src/parquet/encryption/two_level_cache_with_expiration.h
new file mode 100644
index 000000000..fbd06dc7d
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/two_level_cache_with_expiration.h
@@ -0,0 +1,159 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <chrono>
+#include <unordered_map>
+
+#include "arrow/util/concurrent_map.h"
+#include "arrow/util/mutex.h"
+
+namespace parquet {
+namespace encryption {
+
+using ::arrow::util::ConcurrentMap;
+
+namespace internal {
+
+using TimePoint =
+ std::chrono::time_point<std::chrono::system_clock, std::chrono::duration<double>>;
+
+inline TimePoint CurrentTimePoint() { return std::chrono::system_clock::now(); }
+
+template <typename E>
+class ExpiringCacheEntry {
+ public:
+ ExpiringCacheEntry() = default;
+
+ ExpiringCacheEntry(E cached_item, double expiration_interval_seconds)
+ : expiration_timestamp_(CurrentTimePoint() +
+ std::chrono::duration<double>(expiration_interval_seconds)),
+ cached_item_(std::move(cached_item)) {}
+
+ bool IsExpired() const {
+ const auto now = CurrentTimePoint();
+ return (now > expiration_timestamp_);
+ }
+
+ E cached_item() { return cached_item_; }
+
+ private:
+ const TimePoint expiration_timestamp_;
+ E cached_item_;
+};
+
+// This class is to avoid the below warning when compiling KeyToolkit class with VS2015
+// warning C4503: decorated name length exceeded, name was truncated
+template <typename V>
+class ExpiringCacheMapEntry {
+ public:
+ ExpiringCacheMapEntry() = default;
+
+ explicit ExpiringCacheMapEntry(
+ std::shared_ptr<ConcurrentMap<std::string, V>> cached_item,
+ double expiration_interval_seconds)
+ : map_cache_(cached_item, expiration_interval_seconds) {}
+
+ bool IsExpired() { return map_cache_.IsExpired(); }
+
+ std::shared_ptr<ConcurrentMap<std::string, V>> cached_item() {
+ return map_cache_.cached_item();
+ }
+
+ private:
+ // ConcurrentMap object may be accessed and modified at many places at the same time,
+ // from multiple threads, or even removed from cache.
+ ExpiringCacheEntry<std::shared_ptr<ConcurrentMap<std::string, V>>> map_cache_;
+};
+
+} // namespace internal
+
+// Two-level cache with expiration of internal caches according to token lifetime.
+// External cache is per token, internal is per string key.
+// Wrapper class around:
+// std::unordered_map<std::string,
+// internal::ExpiringCacheEntry<std::unordered_map<std::string, V>>>
+// This cache is safe to be shared between threads.
+template <typename V>
+class TwoLevelCacheWithExpiration {
+ public:
+ TwoLevelCacheWithExpiration() {
+ last_cache_cleanup_timestamp_ = internal::CurrentTimePoint();
+ }
+
+ std::shared_ptr<ConcurrentMap<std::string, V>> GetOrCreateInternalCache(
+ const std::string& access_token, double cache_entry_lifetime_seconds) {
+ auto lock = mutex_.Lock();
+
+ auto external_cache_entry = cache_.find(access_token);
+ if (external_cache_entry == cache_.end() ||
+ external_cache_entry->second.IsExpired()) {
+ cache_.insert({access_token, internal::ExpiringCacheMapEntry<V>(
+ std::shared_ptr<ConcurrentMap<std::string, V>>(
+ new ConcurrentMap<std::string, V>()),
+ cache_entry_lifetime_seconds)});
+ }
+
+ return cache_[access_token].cached_item();
+ }
+
+ void CheckCacheForExpiredTokens(double cache_cleanup_period_seconds) {
+ auto lock = mutex_.Lock();
+
+ const auto now = internal::CurrentTimePoint();
+ if (now > (last_cache_cleanup_timestamp_ +
+ std::chrono::duration<double>(cache_cleanup_period_seconds))) {
+ RemoveExpiredEntriesNoMutex();
+ last_cache_cleanup_timestamp_ =
+ now + std::chrono::duration<double>(cache_cleanup_period_seconds);
+ }
+ }
+
+ void RemoveExpiredEntriesFromCache() {
+ auto lock = mutex_.Lock();
+
+ RemoveExpiredEntriesNoMutex();
+ }
+
+ void Remove(const std::string& access_token) {
+ auto lock = mutex_.Lock();
+ cache_.erase(access_token);
+ }
+
+ void Clear() {
+ auto lock = mutex_.Lock();
+ cache_.clear();
+ }
+
+ private:
+ void RemoveExpiredEntriesNoMutex() {
+ for (auto it = cache_.begin(); it != cache_.end();) {
+ if (it->second.IsExpired()) {
+ it = cache_.erase(it);
+ } else {
+ ++it;
+ }
+ }
+ }
+ std::unordered_map<std::string, internal::ExpiringCacheMapEntry<V>> cache_;
+ internal::TimePoint last_cache_cleanup_timestamp_;
+ ::arrow::util::Mutex mutex_;
+};
+
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/two_level_cache_with_expiration_test.cc b/src/arrow/cpp/src/parquet/encryption/two_level_cache_with_expiration_test.cc
new file mode 100644
index 000000000..f375a5c5b
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/two_level_cache_with_expiration_test.cc
@@ -0,0 +1,177 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <chrono>
+#include <thread>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/concurrent_map.h"
+
+#include "parquet/encryption/two_level_cache_with_expiration.h"
+
+namespace parquet {
+namespace encryption {
+namespace test {
+
+using ::arrow::SleepFor;
+
+class TwoLevelCacheWithExpirationTest : public ::testing::Test {
+ public:
+ void SetUp() {
+ // lifetime is 0.2s
+ std::shared_ptr<ConcurrentMap<std::string, int>> lifetime1 =
+ cache_.GetOrCreateInternalCache("lifetime1", 0.2);
+ lifetime1->Insert("item1", 1);
+ lifetime1->Insert("item2", 2);
+
+ // lifetime is 0.5s
+ std::shared_ptr<ConcurrentMap<std::string, int>> lifetime2 =
+ cache_.GetOrCreateInternalCache("lifetime2", 0.5);
+ lifetime2->Insert("item21", 21);
+ lifetime2->Insert("item22", 22);
+ }
+
+ protected:
+ void TaskInsert(int thread_no) {
+ for (int i = 0; i < 20; i++) {
+ std::string token = (i % 2 == 0) ? "lifetime1" : "lifetime2";
+ double lifetime = (i % 2 == 0) ? 0.2 : 0.5;
+ auto internal_cache = cache_.GetOrCreateInternalCache(token, lifetime);
+ std::stringstream ss;
+ ss << "item_" << thread_no << "_" << i;
+ internal_cache->Insert(ss.str(), i);
+ SleepFor(0.005);
+ }
+ }
+
+ void TaskClean() {
+ for (int i = 0; i < 20; i++) {
+ cache_.Clear();
+ SleepFor(0.008);
+ }
+ }
+
+ TwoLevelCacheWithExpiration<int> cache_;
+};
+
+TEST_F(TwoLevelCacheWithExpirationTest, RemoveExpiration) {
+ auto lifetime1_before_expiration = cache_.GetOrCreateInternalCache("lifetime1", 1);
+ ASSERT_EQ(lifetime1_before_expiration->size(), 2);
+
+ // wait for 0.3s, we expect:
+ // lifetime1 will be expired
+ // lifetime2 will not be expired
+ SleepFor(0.3);
+ // now clear expired items from the cache
+ cache_.RemoveExpiredEntriesFromCache();
+
+ // lifetime1 (with 2 items) is expired and has been removed from the cache.
+ // Now the cache create a new object which has no item.
+ auto lifetime1 = cache_.GetOrCreateInternalCache("lifetime1", 1);
+ ASSERT_EQ(lifetime1->size(), 0);
+
+ // However, lifetime1_before_expiration can still access normally and independently
+ // from the one in cache
+ lifetime1_before_expiration->Insert("item3", 3);
+ ASSERT_EQ(lifetime1_before_expiration->size(), 3);
+ ASSERT_EQ(lifetime1->size(), 0);
+
+ // lifetime2 is not expired and still contains 2 items.
+ std::shared_ptr<ConcurrentMap<std::string, int>> lifetime2 =
+ cache_.GetOrCreateInternalCache("lifetime2", 3);
+ ASSERT_EQ(lifetime2->size(), 2);
+}
+
+TEST_F(TwoLevelCacheWithExpirationTest, CleanupPeriodOk) {
+ // wait for 0.3s, now:
+ // lifetime1 is expired
+ // lifetime2 isn't expired
+ SleepFor(0.3);
+
+ // cleanup_period is 0.2s, less than or equals lifetime of both items, so the expired
+ // items will be removed from cache.
+ cache_.CheckCacheForExpiredTokens(0.2);
+
+ // lifetime1 (with 2 items) is expired and has been removed from the cache.
+ // Now the cache create a new object which has no item.
+ auto lifetime1 = cache_.GetOrCreateInternalCache("lifetime1", 1);
+ ASSERT_EQ(lifetime1->size(), 0);
+
+ // lifetime2 is not expired and still contains 2 items.
+ auto lifetime2 = cache_.GetOrCreateInternalCache("lifetime2", 3);
+ ASSERT_EQ(lifetime2->size(), 2);
+}
+
+TEST_F(TwoLevelCacheWithExpirationTest, RemoveByToken) {
+ cache_.Remove("lifetime1");
+
+ // lifetime1 (with 2 items) has been removed from the cache.
+ // Now the cache create a new object which has no item.
+ auto lifetime1 = cache_.GetOrCreateInternalCache("lifetime1", 1);
+ ASSERT_EQ(lifetime1->size(), 0);
+
+ // lifetime2 is still contains 2 items.
+ auto lifetime2 = cache_.GetOrCreateInternalCache("lifetime2", 3);
+ ASSERT_EQ(lifetime2->size(), 2);
+
+ cache_.Remove("lifetime2");
+ auto lifetime2_after_removed = cache_.GetOrCreateInternalCache("lifetime2", 3);
+ ASSERT_EQ(lifetime2_after_removed->size(), 0);
+}
+
+TEST_F(TwoLevelCacheWithExpirationTest, RemoveAllTokens) {
+ cache_.Clear();
+
+ // All tokens has been removed from the cache.
+ // Now the cache create a new object which has no item.
+ auto lifetime1 = cache_.GetOrCreateInternalCache("lifetime1", 1);
+ ASSERT_EQ(lifetime1->size(), 0);
+
+ auto lifetime2 = cache_.GetOrCreateInternalCache("lifetime2", 3);
+ ASSERT_EQ(lifetime2->size(), 0);
+}
+
+TEST_F(TwoLevelCacheWithExpirationTest, Clear) {
+ cache_.Clear();
+
+ // All tokens has been removed from the cache.
+ // Now the cache create a new object which has no item.
+ auto lifetime1 = cache_.GetOrCreateInternalCache("lifetime1", 1);
+ ASSERT_EQ(lifetime1->size(), 0);
+
+ auto lifetime2 = cache_.GetOrCreateInternalCache("lifetime2", 3);
+ ASSERT_EQ(lifetime2->size(), 0);
+}
+
+TEST_F(TwoLevelCacheWithExpirationTest, MultiThread) {
+ std::vector<std::thread> insert_threads;
+ for (int i = 0; i < 10; i++) {
+ insert_threads.emplace_back([this, i]() { this->TaskInsert(i); });
+ }
+ std::thread clean_thread([this]() { this->TaskClean(); });
+
+ for (auto& th : insert_threads) {
+ th.join();
+ }
+ clean_thread.join();
+}
+
+} // namespace test
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/encryption/write_configurations_test.cc b/src/arrow/cpp/src/parquet/encryption/write_configurations_test.cc
new file mode 100644
index 000000000..a7b5a284f
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/encryption/write_configurations_test.cc
@@ -0,0 +1,234 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <stdio.h>
+
+#include <arrow/io/file.h>
+
+#include "parquet/column_reader.h"
+#include "parquet/column_writer.h"
+#include "parquet/encryption/test_encryption_util.h"
+#include "parquet/file_reader.h"
+#include "parquet/file_writer.h"
+#include "parquet/platform.h"
+#include "parquet/test_util.h"
+
+/*
+ * This file contains unit-tests for writing encrypted Parquet files with
+ * different encryption configurations.
+ * The files are saved in temporary folder and will be deleted after reading
+ * them in encryption-read-configurations-test.cc test.
+ *
+ * A detailed description of the Parquet Modular Encryption specification can be found
+ * here:
+ * https://github.com/apache/parquet-format/blob/encryption/Encryption.md
+ *
+ * Each unit-test creates a single parquet file with eight columns using one of the
+ * following encryption configurations:
+ *
+ * - Encryption configuration 1: Encrypt all columns and the footer with the same key.
+ * (uniform encryption)
+ * - Encryption configuration 2: Encrypt two columns and the footer, with different
+ * keys.
+ * - Encryption configuration 3: Encrypt two columns, with different keys.
+ * Don’t encrypt footer (to enable legacy readers)
+ * - plaintext footer mode.
+ * - Encryption configuration 4: Encrypt two columns and the footer, with different
+ * keys. Supply aad_prefix for file identity
+ * verification.
+ * - Encryption configuration 5: Encrypt two columns and the footer, with different
+ * keys. Supply aad_prefix, and call
+ * disable_aad_prefix_storage to prevent file
+ * identity storage in file metadata.
+ * - Encryption configuration 6: Encrypt two columns and the footer, with different
+ * keys. Use the alternative (AES_GCM_CTR_V1) algorithm.
+ */
+
+namespace parquet {
+namespace encryption {
+namespace test {
+
+using FileClass = ::arrow::io::FileOutputStream;
+
+std::unique_ptr<TemporaryDir> temp_dir;
+
+class TestEncryptionConfiguration : public ::testing::Test {
+ public:
+ static void SetUpTestCase();
+
+ protected:
+ FileEncryptor encryptor_;
+
+ std::string path_to_double_field_ = kDoubleFieldName;
+ std::string path_to_float_field_ = kFloatFieldName;
+ std::string file_name_;
+ std::string kFooterEncryptionKey_ = std::string(kFooterEncryptionKey);
+ std::string kColumnEncryptionKey1_ = std::string(kColumnEncryptionKey1);
+ std::string kColumnEncryptionKey2_ = std::string(kColumnEncryptionKey2);
+ std::string kFileName_ = std::string(kFileName);
+
+ void EncryptFile(
+ std::shared_ptr<parquet::FileEncryptionProperties> encryption_configurations,
+ std::string file_name) {
+ std::string file = temp_dir->path().ToString() + file_name;
+ encryptor_.EncryptFile(file, encryption_configurations);
+ }
+};
+
+// Encryption configuration 1: Encrypt all columns and the footer with the same key.
+// (uniform encryption)
+TEST_F(TestEncryptionConfiguration, UniformEncryption) {
+ parquet::FileEncryptionProperties::Builder file_encryption_builder_1(
+ kFooterEncryptionKey_);
+
+ this->EncryptFile(file_encryption_builder_1.footer_key_metadata("kf")->build(),
+ "tmp_uniform_encryption.parquet.encrypted");
+}
+
+// Encryption configuration 2: Encrypt two columns and the footer, with different keys.
+TEST_F(TestEncryptionConfiguration, EncryptTwoColumnsAndTheFooter) {
+ std::map<std::string, std::shared_ptr<parquet::ColumnEncryptionProperties>>
+ encryption_cols2;
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_20(
+ path_to_double_field_);
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_21(
+ path_to_float_field_);
+ encryption_col_builder_20.key(kColumnEncryptionKey1_)->key_id("kc1");
+ encryption_col_builder_21.key(kColumnEncryptionKey2_)->key_id("kc2");
+
+ encryption_cols2[path_to_double_field_] = encryption_col_builder_20.build();
+ encryption_cols2[path_to_float_field_] = encryption_col_builder_21.build();
+
+ parquet::FileEncryptionProperties::Builder file_encryption_builder_2(
+ kFooterEncryptionKey_);
+
+ this->EncryptFile(file_encryption_builder_2.footer_key_metadata("kf")
+ ->encrypted_columns(encryption_cols2)
+ ->build(),
+ "tmp_encrypt_columns_and_footer.parquet.encrypted");
+}
+
+// Encryption configuration 3: Encrypt two columns, with different keys.
+// Don’t encrypt footer.
+// (plaintext footer mode, readable by legacy readers)
+TEST_F(TestEncryptionConfiguration, EncryptTwoColumnsWithPlaintextFooter) {
+ std::map<std::string, std::shared_ptr<parquet::ColumnEncryptionProperties>>
+ encryption_cols3;
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_30(
+ path_to_double_field_);
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_31(
+ path_to_float_field_);
+ encryption_col_builder_30.key(kColumnEncryptionKey1_)->key_id("kc1");
+ encryption_col_builder_31.key(kColumnEncryptionKey2_)->key_id("kc2");
+
+ encryption_cols3[path_to_double_field_] = encryption_col_builder_30.build();
+ encryption_cols3[path_to_float_field_] = encryption_col_builder_31.build();
+ parquet::FileEncryptionProperties::Builder file_encryption_builder_3(
+ kFooterEncryptionKey_);
+
+ this->EncryptFile(file_encryption_builder_3.footer_key_metadata("kf")
+ ->encrypted_columns(encryption_cols3)
+ ->set_plaintext_footer()
+ ->build(),
+ "tmp_encrypt_columns_plaintext_footer.parquet.encrypted");
+}
+
+// Encryption configuration 4: Encrypt two columns and the footer, with different keys.
+// Use aad_prefix.
+TEST_F(TestEncryptionConfiguration, EncryptTwoColumnsAndFooterWithAadPrefix) {
+ std::map<std::string, std::shared_ptr<parquet::ColumnEncryptionProperties>>
+ encryption_cols4;
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_40(
+ path_to_double_field_);
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_41(
+ path_to_float_field_);
+ encryption_col_builder_40.key(kColumnEncryptionKey1_)->key_id("kc1");
+ encryption_col_builder_41.key(kColumnEncryptionKey2_)->key_id("kc2");
+
+ encryption_cols4[path_to_double_field_] = encryption_col_builder_40.build();
+ encryption_cols4[path_to_float_field_] = encryption_col_builder_41.build();
+ parquet::FileEncryptionProperties::Builder file_encryption_builder_4(
+ kFooterEncryptionKey_);
+
+ this->EncryptFile(file_encryption_builder_4.footer_key_metadata("kf")
+ ->encrypted_columns(encryption_cols4)
+ ->aad_prefix(kFileName_)
+ ->build(),
+ "tmp_encrypt_columns_and_footer_aad.parquet.encrypted");
+}
+
+// Encryption configuration 5: Encrypt two columns and the footer, with different keys.
+// Use aad_prefix and disable_aad_prefix_storage.
+TEST_F(TestEncryptionConfiguration,
+ EncryptTwoColumnsAndFooterWithAadPrefixDisable_aad_prefix_storage) {
+ std::map<std::string, std::shared_ptr<parquet::ColumnEncryptionProperties>>
+ encryption_cols5;
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_50(
+ path_to_double_field_);
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_51(
+ path_to_float_field_);
+ encryption_col_builder_50.key(kColumnEncryptionKey1_)->key_id("kc1");
+ encryption_col_builder_51.key(kColumnEncryptionKey2_)->key_id("kc2");
+
+ encryption_cols5[path_to_double_field_] = encryption_col_builder_50.build();
+ encryption_cols5[path_to_float_field_] = encryption_col_builder_51.build();
+ parquet::FileEncryptionProperties::Builder file_encryption_builder_5(
+ kFooterEncryptionKey_);
+
+ this->EncryptFile(
+ file_encryption_builder_5.encrypted_columns(encryption_cols5)
+ ->footer_key_metadata("kf")
+ ->aad_prefix(kFileName_)
+ ->disable_aad_prefix_storage()
+ ->build(),
+ "tmp_encrypt_columns_and_footer_disable_aad_storage.parquet.encrypted");
+}
+
+// Encryption configuration 6: Encrypt two columns and the footer, with different keys.
+// Use AES_GCM_CTR_V1 algorithm.
+TEST_F(TestEncryptionConfiguration, EncryptTwoColumnsAndFooterUseAES_GCM_CTR) {
+ std::map<std::string, std::shared_ptr<parquet::ColumnEncryptionProperties>>
+ encryption_cols6;
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_60(
+ path_to_double_field_);
+ parquet::ColumnEncryptionProperties::Builder encryption_col_builder_61(
+ path_to_float_field_);
+ encryption_col_builder_60.key(kColumnEncryptionKey1_)->key_id("kc1");
+ encryption_col_builder_61.key(kColumnEncryptionKey2_)->key_id("kc2");
+
+ encryption_cols6[path_to_double_field_] = encryption_col_builder_60.build();
+ encryption_cols6[path_to_float_field_] = encryption_col_builder_61.build();
+ parquet::FileEncryptionProperties::Builder file_encryption_builder_6(
+ kFooterEncryptionKey_);
+
+ EXPECT_NO_THROW(
+ this->EncryptFile(file_encryption_builder_6.footer_key_metadata("kf")
+ ->encrypted_columns(encryption_cols6)
+ ->algorithm(parquet::ParquetCipher::AES_GCM_CTR_V1)
+ ->build(),
+ "tmp_encrypt_columns_and_footer_ctr.parquet.encrypted"));
+}
+
+// Set temp_dir before running the write/read tests. The encrypted files will
+// be written/read from this directory.
+void TestEncryptionConfiguration::SetUpTestCase() { temp_dir = *temp_data_dir(); }
+
+} // namespace test
+} // namespace encryption
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/exception.cc b/src/arrow/cpp/src/parquet/exception.cc
new file mode 100644
index 000000000..c333957dd
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/exception.cc
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/exception.h"
+
+namespace parquet {
+
+std::ostream& operator<<(std::ostream& os, const ParquetException& exception) {
+ os << exception.what();
+ return os;
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/exception.h b/src/arrow/cpp/src/parquet/exception.h
new file mode 100644
index 000000000..826f5bdc8
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/exception.h
@@ -0,0 +1,158 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <exception>
+#include <sstream>
+#include <string>
+#include <utility>
+
+#include "arrow/type_fwd.h"
+#include "arrow/util/string_builder.h"
+#include "parquet/platform.h"
+
+// PARQUET-1085
+#if !defined(ARROW_UNUSED)
+#define ARROW_UNUSED(x) UNUSED(x)
+#endif
+
+// Parquet exception to Arrow Status
+
+#define BEGIN_PARQUET_CATCH_EXCEPTIONS try {
+#define END_PARQUET_CATCH_EXCEPTIONS \
+ } \
+ catch (const ::parquet::ParquetStatusException& e) { \
+ return e.status(); \
+ } \
+ catch (const ::parquet::ParquetException& e) { \
+ return ::arrow::Status::IOError(e.what()); \
+ }
+
+// clang-format off
+
+#define PARQUET_CATCH_NOT_OK(s) \
+ BEGIN_PARQUET_CATCH_EXCEPTIONS \
+ (s); \
+ END_PARQUET_CATCH_EXCEPTIONS
+
+// clang-format on
+
+#define PARQUET_CATCH_AND_RETURN(s) \
+ BEGIN_PARQUET_CATCH_EXCEPTIONS \
+ return (s); \
+ END_PARQUET_CATCH_EXCEPTIONS
+
+// Arrow Status to Parquet exception
+
+#define PARQUET_IGNORE_NOT_OK(s) \
+ do { \
+ ::arrow::Status _s = ::arrow::internal::GenericToStatus(s); \
+ ARROW_UNUSED(_s); \
+ } while (0)
+
+#define PARQUET_THROW_NOT_OK(s) \
+ do { \
+ ::arrow::Status _s = ::arrow::internal::GenericToStatus(s); \
+ if (!_s.ok()) { \
+ throw ::parquet::ParquetStatusException(std::move(_s)); \
+ } \
+ } while (0)
+
+#define PARQUET_ASSIGN_OR_THROW_IMPL(status_name, lhs, rexpr) \
+ auto status_name = (rexpr); \
+ PARQUET_THROW_NOT_OK(status_name.status()); \
+ lhs = std::move(status_name).ValueOrDie();
+
+#define PARQUET_ASSIGN_OR_THROW(lhs, rexpr) \
+ PARQUET_ASSIGN_OR_THROW_IMPL(ARROW_ASSIGN_OR_RAISE_NAME(_error_or_value, __COUNTER__), \
+ lhs, rexpr);
+
+namespace parquet {
+
+class ParquetException : public std::exception {
+ public:
+ PARQUET_NORETURN static void EofException(const std::string& msg = "") {
+ static std::string prefix = "Unexpected end of stream";
+ if (msg.empty()) {
+ throw ParquetException(prefix);
+ }
+ throw ParquetException(prefix, ": ", msg);
+ }
+
+ PARQUET_NORETURN static void NYI(const std::string& msg = "") {
+ throw ParquetException("Not yet implemented: ", msg, ".");
+ }
+
+ template <typename... Args>
+ explicit ParquetException(Args&&... args)
+ : msg_(::arrow::util::StringBuilder(std::forward<Args>(args)...)) {}
+
+ explicit ParquetException(std::string msg) : msg_(std::move(msg)) {}
+
+ explicit ParquetException(const char* msg, const std::exception&) : msg_(msg) {}
+
+ ParquetException(const ParquetException&) = default;
+ ParquetException& operator=(const ParquetException&) = default;
+ ParquetException(ParquetException&&) = default;
+ ParquetException& operator=(ParquetException&&) = default;
+
+ const char* what() const noexcept override { return msg_.c_str(); }
+
+ private:
+ std::string msg_;
+};
+
+// Support printing a ParquetException.
+// This is needed for clang-on-MSVC as there operator<< is not defined for
+// std::exception.
+PARQUET_EXPORT
+std::ostream& operator<<(std::ostream& os, const ParquetException& exception);
+
+class ParquetStatusException : public ParquetException {
+ public:
+ explicit ParquetStatusException(::arrow::Status status)
+ : ParquetException(status.ToString()), status_(std::move(status)) {}
+
+ const ::arrow::Status& status() const { return status_; }
+
+ private:
+ ::arrow::Status status_;
+};
+
+// This class exists for the purpose of detecting an invalid or corrupted file.
+class ParquetInvalidOrCorruptedFileException : public ParquetStatusException {
+ public:
+ ParquetInvalidOrCorruptedFileException(const ParquetInvalidOrCorruptedFileException&) =
+ default;
+
+ template <typename Arg,
+ typename std::enable_if<
+ !std::is_base_of<ParquetInvalidOrCorruptedFileException, Arg>::value,
+ int>::type = 0,
+ typename... Args>
+ explicit ParquetInvalidOrCorruptedFileException(Arg arg, Args&&... args)
+ : ParquetStatusException(::arrow::Status::Invalid(std::forward<Arg>(arg),
+ std::forward<Args>(args)...)) {}
+};
+
+template <typename StatusReturnBlock>
+void ThrowNotOk(StatusReturnBlock&& b) {
+ PARQUET_THROW_NOT_OK(b());
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/file_deserialize_test.cc b/src/arrow/cpp/src/parquet/file_deserialize_test.cc
new file mode 100644
index 000000000..d0d333256
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/file_deserialize_test.cc
@@ -0,0 +1,372 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+
+#include "parquet/column_page.h"
+#include "parquet/column_reader.h"
+#include "parquet/exception.h"
+#include "parquet/file_reader.h"
+#include "parquet/platform.h"
+#include "parquet/test_util.h"
+#include "parquet/thrift_internal.h"
+#include "parquet/types.h"
+
+#include "arrow/io/memory.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/compression.h"
+
+namespace parquet {
+
+using ::arrow::io::BufferReader;
+
+// Adds page statistics occupying a certain amount of bytes (for testing very
+// large page headers)
+template <typename H>
+static inline void AddDummyStats(int stat_size, H& header, bool fill_all_stats = false) {
+ std::vector<uint8_t> stat_bytes(stat_size);
+ // Some non-zero value
+ std::fill(stat_bytes.begin(), stat_bytes.end(), 1);
+ header.statistics.__set_max(
+ std::string(reinterpret_cast<const char*>(stat_bytes.data()), stat_size));
+
+ if (fill_all_stats) {
+ header.statistics.__set_min(
+ std::string(reinterpret_cast<const char*>(stat_bytes.data()), stat_size));
+ header.statistics.__set_null_count(42);
+ header.statistics.__set_distinct_count(1);
+ }
+
+ header.__isset.statistics = true;
+}
+
+template <typename H>
+static inline void CheckStatistics(const H& expected, const EncodedStatistics& actual) {
+ if (expected.statistics.__isset.max) {
+ ASSERT_EQ(expected.statistics.max, actual.max());
+ }
+ if (expected.statistics.__isset.min) {
+ ASSERT_EQ(expected.statistics.min, actual.min());
+ }
+ if (expected.statistics.__isset.null_count) {
+ ASSERT_EQ(expected.statistics.null_count, actual.null_count);
+ }
+ if (expected.statistics.__isset.distinct_count) {
+ ASSERT_EQ(expected.statistics.distinct_count, actual.distinct_count);
+ }
+}
+
+class TestPageSerde : public ::testing::Test {
+ public:
+ void SetUp() {
+ data_page_header_.encoding = format::Encoding::PLAIN;
+ data_page_header_.definition_level_encoding = format::Encoding::RLE;
+ data_page_header_.repetition_level_encoding = format::Encoding::RLE;
+
+ ResetStream();
+ }
+
+ void InitSerializedPageReader(int64_t num_rows,
+ Compression::type codec = Compression::UNCOMPRESSED) {
+ EndStream();
+
+ auto stream = std::make_shared<::arrow::io::BufferReader>(out_buffer_);
+ page_reader_ = PageReader::Open(stream, num_rows, codec);
+ }
+
+ void WriteDataPageHeader(int max_serialized_len = 1024, int32_t uncompressed_size = 0,
+ int32_t compressed_size = 0) {
+ // Simplifying writing serialized data page headers which may or may not
+ // have meaningful data associated with them
+
+ // Serialize the Page header
+ page_header_.__set_data_page_header(data_page_header_);
+ page_header_.uncompressed_page_size = uncompressed_size;
+ page_header_.compressed_page_size = compressed_size;
+ page_header_.type = format::PageType::DATA_PAGE;
+
+ ThriftSerializer serializer;
+ ASSERT_NO_THROW(serializer.Serialize(&page_header_, out_stream_.get()));
+ }
+
+ void WriteDataPageHeaderV2(int max_serialized_len = 1024, int32_t uncompressed_size = 0,
+ int32_t compressed_size = 0) {
+ // Simplifying writing serialized data page V2 headers which may or may not
+ // have meaningful data associated with them
+
+ // Serialize the Page header
+ page_header_.__set_data_page_header_v2(data_page_header_v2_);
+ page_header_.uncompressed_page_size = uncompressed_size;
+ page_header_.compressed_page_size = compressed_size;
+ page_header_.type = format::PageType::DATA_PAGE_V2;
+
+ ThriftSerializer serializer;
+ ASSERT_NO_THROW(serializer.Serialize(&page_header_, out_stream_.get()));
+ }
+
+ void ResetStream() { out_stream_ = CreateOutputStream(); }
+
+ void EndStream() { PARQUET_ASSIGN_OR_THROW(out_buffer_, out_stream_->Finish()); }
+
+ protected:
+ std::shared_ptr<::arrow::io::BufferOutputStream> out_stream_;
+ std::shared_ptr<Buffer> out_buffer_;
+
+ std::unique_ptr<PageReader> page_reader_;
+ format::PageHeader page_header_;
+ format::DataPageHeader data_page_header_;
+ format::DataPageHeaderV2 data_page_header_v2_;
+};
+
+void CheckDataPageHeader(const format::DataPageHeader expected, const Page* page) {
+ ASSERT_EQ(PageType::DATA_PAGE, page->type());
+
+ const DataPageV1* data_page = static_cast<const DataPageV1*>(page);
+ ASSERT_EQ(expected.num_values, data_page->num_values());
+ ASSERT_EQ(expected.encoding, data_page->encoding());
+ ASSERT_EQ(expected.definition_level_encoding, data_page->definition_level_encoding());
+ ASSERT_EQ(expected.repetition_level_encoding, data_page->repetition_level_encoding());
+ CheckStatistics(expected, data_page->statistics());
+}
+
+// Overload for DataPageV2 tests.
+void CheckDataPageHeader(const format::DataPageHeaderV2 expected, const Page* page) {
+ ASSERT_EQ(PageType::DATA_PAGE_V2, page->type());
+
+ const DataPageV2* data_page = static_cast<const DataPageV2*>(page);
+ ASSERT_EQ(expected.num_values, data_page->num_values());
+ ASSERT_EQ(expected.num_nulls, data_page->num_nulls());
+ ASSERT_EQ(expected.num_rows, data_page->num_rows());
+ ASSERT_EQ(expected.encoding, data_page->encoding());
+ ASSERT_EQ(expected.definition_levels_byte_length,
+ data_page->definition_levels_byte_length());
+ ASSERT_EQ(expected.repetition_levels_byte_length,
+ data_page->repetition_levels_byte_length());
+ ASSERT_EQ(expected.is_compressed, data_page->is_compressed());
+ CheckStatistics(expected, data_page->statistics());
+}
+
+TEST_F(TestPageSerde, DataPageV1) {
+ int stats_size = 512;
+ const int32_t num_rows = 4444;
+ AddDummyStats(stats_size, data_page_header_, /* fill_all_stats = */ true);
+ data_page_header_.num_values = num_rows;
+
+ ASSERT_NO_FATAL_FAILURE(WriteDataPageHeader());
+ InitSerializedPageReader(num_rows);
+ std::shared_ptr<Page> current_page = page_reader_->NextPage();
+ ASSERT_NO_FATAL_FAILURE(CheckDataPageHeader(data_page_header_, current_page.get()));
+}
+
+TEST_F(TestPageSerde, DataPageV2) {
+ int stats_size = 512;
+ const int32_t num_rows = 4444;
+ AddDummyStats(stats_size, data_page_header_v2_, /* fill_all_stats = */ true);
+ data_page_header_v2_.num_values = num_rows;
+
+ ASSERT_NO_FATAL_FAILURE(WriteDataPageHeaderV2());
+ InitSerializedPageReader(num_rows);
+ std::shared_ptr<Page> current_page = page_reader_->NextPage();
+ ASSERT_NO_FATAL_FAILURE(CheckDataPageHeader(data_page_header_v2_, current_page.get()));
+}
+
+TEST_F(TestPageSerde, TestLargePageHeaders) {
+ int stats_size = 256 * 1024; // 256 KB
+ AddDummyStats(stats_size, data_page_header_);
+
+ // Any number to verify metadata roundtrip
+ const int32_t num_rows = 4141;
+ data_page_header_.num_values = num_rows;
+
+ int max_header_size = 512 * 1024; // 512 KB
+ ASSERT_NO_FATAL_FAILURE(WriteDataPageHeader(max_header_size));
+
+ ASSERT_OK_AND_ASSIGN(int64_t position, out_stream_->Tell());
+ ASSERT_GE(max_header_size, position);
+
+ // check header size is between 256 KB to 16 MB
+ ASSERT_LE(stats_size, position);
+ ASSERT_GE(kDefaultMaxPageHeaderSize, position);
+
+ InitSerializedPageReader(num_rows);
+ std::shared_ptr<Page> current_page = page_reader_->NextPage();
+ ASSERT_NO_FATAL_FAILURE(CheckDataPageHeader(data_page_header_, current_page.get()));
+}
+
+TEST_F(TestPageSerde, TestFailLargePageHeaders) {
+ const int32_t num_rows = 1337; // dummy value
+
+ int stats_size = 256 * 1024; // 256 KB
+ AddDummyStats(stats_size, data_page_header_);
+
+ // Serialize the Page header
+ int max_header_size = 512 * 1024; // 512 KB
+ ASSERT_NO_FATAL_FAILURE(WriteDataPageHeader(max_header_size));
+ ASSERT_OK_AND_ASSIGN(int64_t position, out_stream_->Tell());
+ ASSERT_GE(max_header_size, position);
+
+ int smaller_max_size = 128 * 1024;
+ ASSERT_LE(smaller_max_size, position);
+ InitSerializedPageReader(num_rows);
+
+ // Set the max page header size to 128 KB, which is less than the current
+ // header size
+ page_reader_->set_max_page_header_size(smaller_max_size);
+ ASSERT_THROW(page_reader_->NextPage(), ParquetException);
+}
+
+TEST_F(TestPageSerde, Compression) {
+ std::vector<Compression::type> codec_types;
+
+#ifdef ARROW_WITH_SNAPPY
+ codec_types.push_back(Compression::SNAPPY);
+#endif
+
+#ifdef ARROW_WITH_BROTLI
+ codec_types.push_back(Compression::BROTLI);
+#endif
+
+#ifdef ARROW_WITH_GZIP
+ codec_types.push_back(Compression::GZIP);
+#endif
+
+#ifdef ARROW_WITH_LZ4
+ codec_types.push_back(Compression::LZ4);
+ codec_types.push_back(Compression::LZ4_HADOOP);
+#endif
+
+#ifdef ARROW_WITH_ZSTD
+ codec_types.push_back(Compression::ZSTD);
+#endif
+
+ const int32_t num_rows = 32; // dummy value
+ data_page_header_.num_values = num_rows;
+
+ const int num_pages = 10;
+
+ std::vector<std::vector<uint8_t>> faux_data;
+ faux_data.resize(num_pages);
+ for (int i = 0; i < num_pages; ++i) {
+ // The pages keep getting larger
+ int page_size = (i + 1) * 64;
+ test::random_bytes(page_size, 0, &faux_data[i]);
+ }
+ for (auto codec_type : codec_types) {
+ auto codec = GetCodec(codec_type);
+
+ std::vector<uint8_t> buffer;
+ for (int i = 0; i < num_pages; ++i) {
+ const uint8_t* data = faux_data[i].data();
+ int data_size = static_cast<int>(faux_data[i].size());
+
+ int64_t max_compressed_size = codec->MaxCompressedLen(data_size, data);
+ buffer.resize(max_compressed_size);
+
+ int64_t actual_size;
+ ASSERT_OK_AND_ASSIGN(
+ actual_size, codec->Compress(data_size, data, max_compressed_size, &buffer[0]));
+
+ ASSERT_NO_FATAL_FAILURE(
+ WriteDataPageHeader(1024, data_size, static_cast<int32_t>(actual_size)));
+ ASSERT_OK(out_stream_->Write(buffer.data(), actual_size));
+ }
+
+ InitSerializedPageReader(num_rows * num_pages, codec_type);
+
+ std::shared_ptr<Page> page;
+ const DataPageV1* data_page;
+ for (int i = 0; i < num_pages; ++i) {
+ int data_size = static_cast<int>(faux_data[i].size());
+ page = page_reader_->NextPage();
+ data_page = static_cast<const DataPageV1*>(page.get());
+ ASSERT_EQ(data_size, data_page->size());
+ ASSERT_EQ(0, memcmp(faux_data[i].data(), data_page->data(), data_size));
+ }
+
+ ResetStream();
+ }
+} // namespace parquet
+
+TEST_F(TestPageSerde, LZONotSupported) {
+ // Must await PARQUET-530
+ int data_size = 1024;
+ std::vector<uint8_t> faux_data(data_size);
+ ASSERT_NO_FATAL_FAILURE(WriteDataPageHeader(1024, data_size, data_size));
+ ASSERT_OK(out_stream_->Write(faux_data.data(), data_size));
+ ASSERT_THROW(InitSerializedPageReader(data_size, Compression::LZO), ParquetException);
+}
+
+// ----------------------------------------------------------------------
+// File structure tests
+
+class TestParquetFileReader : public ::testing::Test {
+ public:
+ void AssertInvalidFileThrows(const std::shared_ptr<Buffer>& buffer) {
+ reader_.reset(new ParquetFileReader());
+
+ auto reader = std::make_shared<BufferReader>(buffer);
+
+ ASSERT_THROW(reader_->Open(ParquetFileReader::Contents::Open(reader)),
+ ParquetException);
+ }
+
+ protected:
+ std::unique_ptr<ParquetFileReader> reader_;
+};
+
+TEST_F(TestParquetFileReader, InvalidHeader) {
+ const char* bad_header = "PAR2";
+
+ auto buffer = Buffer::Wrap(bad_header, strlen(bad_header));
+ ASSERT_NO_FATAL_FAILURE(AssertInvalidFileThrows(buffer));
+}
+
+TEST_F(TestParquetFileReader, InvalidFooter) {
+ // File is smaller than FOOTER_SIZE
+ const char* bad_file = "PAR1PAR";
+ auto buffer = Buffer::Wrap(bad_file, strlen(bad_file));
+ ASSERT_NO_FATAL_FAILURE(AssertInvalidFileThrows(buffer));
+
+ // Magic number incorrect
+ const char* bad_file2 = "PAR1PAR2";
+ buffer = Buffer::Wrap(bad_file2, strlen(bad_file2));
+ ASSERT_NO_FATAL_FAILURE(AssertInvalidFileThrows(buffer));
+}
+
+TEST_F(TestParquetFileReader, IncompleteMetadata) {
+ auto stream = CreateOutputStream();
+
+ const char* magic = "PAR1";
+
+ ASSERT_OK(stream->Write(reinterpret_cast<const uint8_t*>(magic), strlen(magic)));
+ std::vector<uint8_t> bytes(10);
+ ASSERT_OK(stream->Write(bytes.data(), bytes.size()));
+ uint32_t metadata_len = 24;
+ ASSERT_OK(
+ stream->Write(reinterpret_cast<const uint8_t*>(&metadata_len), sizeof(uint32_t)));
+ ASSERT_OK(stream->Write(reinterpret_cast<const uint8_t*>(magic), strlen(magic)));
+
+ ASSERT_OK_AND_ASSIGN(auto buffer, stream->Finish());
+ ASSERT_NO_FATAL_FAILURE(AssertInvalidFileThrows(buffer));
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/file_reader.cc b/src/arrow/cpp/src/parquet/file_reader.cc
new file mode 100644
index 000000000..4e38901aa
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/file_reader.cc
@@ -0,0 +1,868 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/file_reader.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <utility>
+
+#include "arrow/io/caching.h"
+#include "arrow/io/file.h"
+#include "arrow/io/memory.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/int_util_internal.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/ubsan.h"
+#include "parquet/column_reader.h"
+#include "parquet/column_scanner.h"
+#include "parquet/encryption/encryption_internal.h"
+#include "parquet/encryption/internal_file_decryptor.h"
+#include "parquet/exception.h"
+#include "parquet/file_writer.h"
+#include "parquet/metadata.h"
+#include "parquet/platform.h"
+#include "parquet/properties.h"
+#include "parquet/schema.h"
+#include "parquet/types.h"
+
+using arrow::internal::AddWithOverflow;
+
+namespace parquet {
+
+// PARQUET-978: Minimize footer reads by reading 64 KB from the end of the file
+static constexpr int64_t kDefaultFooterReadSize = 64 * 1024;
+static constexpr uint32_t kFooterSize = 8;
+
+// For PARQUET-816
+static constexpr int64_t kMaxDictHeaderSize = 100;
+
+// ----------------------------------------------------------------------
+// RowGroupReader public API
+
+RowGroupReader::RowGroupReader(std::unique_ptr<Contents> contents)
+ : contents_(std::move(contents)) {}
+
+std::shared_ptr<ColumnReader> RowGroupReader::Column(int i) {
+ if (i >= metadata()->num_columns()) {
+ std::stringstream ss;
+ ss << "Trying to read column index " << i << " but row group metadata has only "
+ << metadata()->num_columns() << " columns";
+ throw ParquetException(ss.str());
+ }
+ const ColumnDescriptor* descr = metadata()->schema()->Column(i);
+
+ std::unique_ptr<PageReader> page_reader = contents_->GetColumnPageReader(i);
+ return ColumnReader::Make(
+ descr, std::move(page_reader),
+ const_cast<ReaderProperties*>(contents_->properties())->memory_pool());
+}
+
+std::shared_ptr<ColumnReader> RowGroupReader::ColumnWithExposeEncoding(
+ int i, ExposedEncoding encoding_to_expose) {
+ std::shared_ptr<ColumnReader> reader = Column(i);
+
+ if (encoding_to_expose == ExposedEncoding::DICTIONARY) {
+ // Check the encoding_stats to see if all data pages are dictionary encoded.
+ std::unique_ptr<ColumnChunkMetaData> col = metadata()->ColumnChunk(i);
+ const std::vector<PageEncodingStats>& encoding_stats = col->encoding_stats();
+ if (encoding_stats.empty()) {
+ // Some parquet files may have empty encoding_stats. In this case we are
+ // not sure whether all data pages are dictionary encoded. So we do not
+ // enable exposing dictionary.
+ return reader;
+ }
+ // The 1st page should be the dictionary page.
+ if (encoding_stats[0].page_type != PageType::DICTIONARY_PAGE ||
+ (encoding_stats[0].encoding != Encoding::PLAIN &&
+ encoding_stats[0].encoding != Encoding::PLAIN_DICTIONARY)) {
+ return reader;
+ }
+ // The following pages should be dictionary encoded data pages.
+ for (size_t idx = 1; idx < encoding_stats.size(); ++idx) {
+ if ((encoding_stats[idx].encoding != Encoding::RLE_DICTIONARY &&
+ encoding_stats[idx].encoding != Encoding::PLAIN_DICTIONARY) ||
+ (encoding_stats[idx].page_type != PageType::DATA_PAGE &&
+ encoding_stats[idx].page_type != PageType::DATA_PAGE_V2)) {
+ return reader;
+ }
+ }
+ } else {
+ // Exposing other encodings are not supported for now.
+ return reader;
+ }
+
+ // Set exposed encoding.
+ reader->SetExposedEncoding(encoding_to_expose);
+ return reader;
+}
+
+std::unique_ptr<PageReader> RowGroupReader::GetColumnPageReader(int i) {
+ if (i >= metadata()->num_columns()) {
+ std::stringstream ss;
+ ss << "Trying to read column index " << i << " but row group metadata has only "
+ << metadata()->num_columns() << " columns";
+ throw ParquetException(ss.str());
+ }
+ return contents_->GetColumnPageReader(i);
+}
+
+// Returns the rowgroup metadata
+const RowGroupMetaData* RowGroupReader::metadata() const { return contents_->metadata(); }
+
+/// Compute the section of the file that should be read for the given
+/// row group and column chunk.
+::arrow::io::ReadRange ComputeColumnChunkRange(FileMetaData* file_metadata,
+ int64_t source_size, int row_group_index,
+ int column_index) {
+ auto row_group_metadata = file_metadata->RowGroup(row_group_index);
+ auto column_metadata = row_group_metadata->ColumnChunk(column_index);
+
+ int64_t col_start = column_metadata->data_page_offset();
+ if (column_metadata->has_dictionary_page() &&
+ column_metadata->dictionary_page_offset() > 0 &&
+ col_start > column_metadata->dictionary_page_offset()) {
+ col_start = column_metadata->dictionary_page_offset();
+ }
+
+ int64_t col_length = column_metadata->total_compressed_size();
+ int64_t col_end;
+ if (AddWithOverflow(col_start, col_length, &col_end) || col_end > source_size) {
+ throw ParquetException("Invalid column metadata (corrupt file?)");
+ }
+
+ // PARQUET-816 workaround for old files created by older parquet-mr
+ const ApplicationVersion& version = file_metadata->writer_version();
+ if (version.VersionLt(ApplicationVersion::PARQUET_816_FIXED_VERSION())) {
+ // The Parquet MR writer had a bug in 1.2.8 and below where it didn't include the
+ // dictionary page header size in total_compressed_size and total_uncompressed_size
+ // (see IMPALA-694). We add padding to compensate.
+ int64_t bytes_remaining = source_size - col_end;
+ int64_t padding = std::min<int64_t>(kMaxDictHeaderSize, bytes_remaining);
+ col_length += padding;
+ }
+
+ return {col_start, col_length};
+}
+
+// RowGroupReader::Contents implementation for the Parquet file specification
+class SerializedRowGroup : public RowGroupReader::Contents {
+ public:
+ SerializedRowGroup(std::shared_ptr<ArrowInputFile> source,
+ std::shared_ptr<::arrow::io::internal::ReadRangeCache> cached_source,
+ int64_t source_size, FileMetaData* file_metadata,
+ int row_group_number, const ReaderProperties& props,
+ std::shared_ptr<InternalFileDecryptor> file_decryptor = nullptr)
+ : source_(std::move(source)),
+ cached_source_(std::move(cached_source)),
+ source_size_(source_size),
+ file_metadata_(file_metadata),
+ properties_(props),
+ row_group_ordinal_(row_group_number),
+ file_decryptor_(file_decryptor) {
+ row_group_metadata_ = file_metadata->RowGroup(row_group_number);
+ }
+
+ const RowGroupMetaData* metadata() const override { return row_group_metadata_.get(); }
+
+ const ReaderProperties* properties() const override { return &properties_; }
+
+ std::unique_ptr<PageReader> GetColumnPageReader(int i) override {
+ // Read column chunk from the file
+ auto col = row_group_metadata_->ColumnChunk(i);
+
+ ::arrow::io::ReadRange col_range =
+ ComputeColumnChunkRange(file_metadata_, source_size_, row_group_ordinal_, i);
+ std::shared_ptr<ArrowInputStream> stream;
+ if (cached_source_) {
+ // PARQUET-1698: if read coalescing is enabled, read from pre-buffered
+ // segments.
+ PARQUET_ASSIGN_OR_THROW(auto buffer, cached_source_->Read(col_range));
+ stream = std::make_shared<::arrow::io::BufferReader>(buffer);
+ } else {
+ stream = properties_.GetStream(source_, col_range.offset, col_range.length);
+ }
+
+ std::unique_ptr<ColumnCryptoMetaData> crypto_metadata = col->crypto_metadata();
+
+ // Column is encrypted only if crypto_metadata exists.
+ if (!crypto_metadata) {
+ return PageReader::Open(stream, col->num_values(), col->compression(),
+ properties_.memory_pool());
+ }
+
+ if (file_decryptor_ == nullptr) {
+ throw ParquetException("RowGroup is noted as encrypted but no file decryptor");
+ }
+
+ constexpr auto kEncryptedRowGroupsLimit = 32767;
+ if (i > kEncryptedRowGroupsLimit) {
+ throw ParquetException("Encrypted files cannot contain more than 32767 row groups");
+ }
+
+ // The column is encrypted
+ std::shared_ptr<Decryptor> meta_decryptor;
+ std::shared_ptr<Decryptor> data_decryptor;
+ // The column is encrypted with footer key
+ if (crypto_metadata->encrypted_with_footer_key()) {
+ meta_decryptor = file_decryptor_->GetFooterDecryptorForColumnMeta();
+ data_decryptor = file_decryptor_->GetFooterDecryptorForColumnData();
+ CryptoContext ctx(col->has_dictionary_page(), row_group_ordinal_,
+ static_cast<int16_t>(i), meta_decryptor, data_decryptor);
+ return PageReader::Open(stream, col->num_values(), col->compression(),
+ properties_.memory_pool(), &ctx);
+ }
+
+ // The column is encrypted with its own key
+ std::string column_key_metadata = crypto_metadata->key_metadata();
+ const std::string column_path = crypto_metadata->path_in_schema()->ToDotString();
+
+ meta_decryptor =
+ file_decryptor_->GetColumnMetaDecryptor(column_path, column_key_metadata);
+ data_decryptor =
+ file_decryptor_->GetColumnDataDecryptor(column_path, column_key_metadata);
+
+ CryptoContext ctx(col->has_dictionary_page(), row_group_ordinal_,
+ static_cast<int16_t>(i), meta_decryptor, data_decryptor);
+ return PageReader::Open(stream, col->num_values(), col->compression(),
+ properties_.memory_pool(), &ctx);
+ }
+
+ private:
+ std::shared_ptr<ArrowInputFile> source_;
+ // Will be nullptr if PreBuffer() is not called.
+ std::shared_ptr<::arrow::io::internal::ReadRangeCache> cached_source_;
+ int64_t source_size_;
+ FileMetaData* file_metadata_;
+ std::unique_ptr<RowGroupMetaData> row_group_metadata_;
+ ReaderProperties properties_;
+ int row_group_ordinal_;
+ std::shared_ptr<InternalFileDecryptor> file_decryptor_;
+};
+
+// ----------------------------------------------------------------------
+// SerializedFile: An implementation of ParquetFileReader::Contents that deals
+// with the Parquet file structure, Thrift deserialization, and other internal
+// matters
+
+// This class takes ownership of the provided data source
+class SerializedFile : public ParquetFileReader::Contents {
+ public:
+ SerializedFile(std::shared_ptr<ArrowInputFile> source,
+ const ReaderProperties& props = default_reader_properties())
+ : source_(std::move(source)), properties_(props) {
+ PARQUET_ASSIGN_OR_THROW(source_size_, source_->GetSize());
+ }
+
+ ~SerializedFile() override {
+ try {
+ Close();
+ } catch (...) {
+ }
+ }
+
+ void Close() override {
+ if (file_decryptor_) file_decryptor_->WipeOutDecryptionKeys();
+ }
+
+ std::shared_ptr<RowGroupReader> GetRowGroup(int i) override {
+ std::unique_ptr<SerializedRowGroup> contents(
+ new SerializedRowGroup(source_, cached_source_, source_size_,
+ file_metadata_.get(), i, properties_, file_decryptor_));
+ return std::make_shared<RowGroupReader>(std::move(contents));
+ }
+
+ std::shared_ptr<FileMetaData> metadata() const override { return file_metadata_; }
+
+ void set_metadata(std::shared_ptr<FileMetaData> metadata) {
+ file_metadata_ = std::move(metadata);
+ }
+
+ void PreBuffer(const std::vector<int>& row_groups,
+ const std::vector<int>& column_indices,
+ const ::arrow::io::IOContext& ctx,
+ const ::arrow::io::CacheOptions& options) {
+ cached_source_ =
+ std::make_shared<::arrow::io::internal::ReadRangeCache>(source_, ctx, options);
+ std::vector<::arrow::io::ReadRange> ranges;
+ for (int row : row_groups) {
+ for (int col : column_indices) {
+ ranges.push_back(
+ ComputeColumnChunkRange(file_metadata_.get(), source_size_, row, col));
+ }
+ }
+ PARQUET_THROW_NOT_OK(cached_source_->Cache(ranges));
+ }
+
+ ::arrow::Future<> WhenBuffered(const std::vector<int>& row_groups,
+ const std::vector<int>& column_indices) const {
+ if (!cached_source_) {
+ return ::arrow::Status::Invalid("Must call PreBuffer before WhenBuffered");
+ }
+ std::vector<::arrow::io::ReadRange> ranges;
+ for (int row : row_groups) {
+ for (int col : column_indices) {
+ ranges.push_back(
+ ComputeColumnChunkRange(file_metadata_.get(), source_size_, row, col));
+ }
+ }
+ return cached_source_->WaitFor(ranges);
+ }
+
+ // Metadata/footer parsing. Divided up to separate sync/async paths, and to use
+ // exceptions for error handling (with the async path converting to Future/Status).
+
+ void ParseMetaData() {
+ int64_t footer_read_size = GetFooterReadSize();
+ PARQUET_ASSIGN_OR_THROW(
+ auto footer_buffer,
+ source_->ReadAt(source_size_ - footer_read_size, footer_read_size));
+ uint32_t metadata_len = ParseFooterLength(footer_buffer, footer_read_size);
+ int64_t metadata_start = source_size_ - kFooterSize - metadata_len;
+
+ std::shared_ptr<::arrow::Buffer> metadata_buffer;
+ if (footer_read_size >= (metadata_len + kFooterSize)) {
+ metadata_buffer = SliceBuffer(
+ footer_buffer, footer_read_size - metadata_len - kFooterSize, metadata_len);
+ } else {
+ PARQUET_ASSIGN_OR_THROW(metadata_buffer,
+ source_->ReadAt(metadata_start, metadata_len));
+ }
+
+ // Parse the footer depending on encryption type
+ const bool is_encrypted_footer =
+ memcmp(footer_buffer->data() + footer_read_size - 4, kParquetEMagic, 4) == 0;
+ if (is_encrypted_footer) {
+ // Encrypted file with Encrypted footer.
+ const std::pair<int64_t, uint32_t> read_size =
+ ParseMetaDataOfEncryptedFileWithEncryptedFooter(metadata_buffer, metadata_len);
+ // Read the actual footer
+ metadata_start = read_size.first;
+ metadata_len = read_size.second;
+ PARQUET_ASSIGN_OR_THROW(metadata_buffer,
+ source_->ReadAt(metadata_start, metadata_len));
+ // Fall through
+ }
+
+ const uint32_t read_metadata_len =
+ ParseUnencryptedFileMetadata(metadata_buffer, metadata_len);
+ auto file_decryption_properties = properties_.file_decryption_properties().get();
+ if (is_encrypted_footer) {
+ // Nothing else to do here.
+ return;
+ } else if (!file_metadata_->is_encryption_algorithm_set()) { // Non encrypted file.
+ if (file_decryption_properties != nullptr) {
+ if (!file_decryption_properties->plaintext_files_allowed()) {
+ throw ParquetException("Applying decryption properties on plaintext file");
+ }
+ }
+ } else {
+ // Encrypted file with plaintext footer mode.
+ ParseMetaDataOfEncryptedFileWithPlaintextFooter(
+ file_decryption_properties, metadata_buffer, metadata_len, read_metadata_len);
+ }
+ }
+
+ // Validate the source size and get the initial read size.
+ int64_t GetFooterReadSize() {
+ if (source_size_ == 0) {
+ throw ParquetInvalidOrCorruptedFileException("Parquet file size is 0 bytes");
+ } else if (source_size_ < kFooterSize) {
+ throw ParquetInvalidOrCorruptedFileException(
+ "Parquet file size is ", source_size_,
+ " bytes, smaller than the minimum file footer (", kFooterSize, " bytes)");
+ }
+ return std::min(source_size_, kDefaultFooterReadSize);
+ }
+
+ // Validate the magic bytes and get the length of the full footer.
+ uint32_t ParseFooterLength(const std::shared_ptr<::arrow::Buffer>& footer_buffer,
+ const int64_t footer_read_size) {
+ // Check if all bytes are read. Check if last 4 bytes read have the magic bits
+ if (footer_buffer->size() != footer_read_size ||
+ (memcmp(footer_buffer->data() + footer_read_size - 4, kParquetMagic, 4) != 0 &&
+ memcmp(footer_buffer->data() + footer_read_size - 4, kParquetEMagic, 4) != 0)) {
+ throw ParquetInvalidOrCorruptedFileException(
+ "Parquet magic bytes not found in footer. Either the file is corrupted or this "
+ "is not a parquet file.");
+ }
+ // Both encrypted/unencrypted footers have the same footer length check.
+ uint32_t metadata_len = ::arrow::util::SafeLoadAs<uint32_t>(
+ reinterpret_cast<const uint8_t*>(footer_buffer->data()) + footer_read_size -
+ kFooterSize);
+ if (metadata_len > source_size_ - kFooterSize) {
+ throw ParquetInvalidOrCorruptedFileException(
+ "Parquet file size is ", source_size_,
+ " bytes, smaller than the size reported by footer's (", metadata_len, "bytes)");
+ }
+ return metadata_len;
+ }
+
+ // Does not throw.
+ ::arrow::Future<> ParseMetaDataAsync() {
+ int64_t footer_read_size;
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ footer_read_size = GetFooterReadSize();
+ END_PARQUET_CATCH_EXCEPTIONS
+ // Assumes this is kept alive externally
+ return source_->ReadAsync(source_size_ - footer_read_size, footer_read_size)
+ .Then([=](const std::shared_ptr<::arrow::Buffer>& footer_buffer)
+ -> ::arrow::Future<> {
+ uint32_t metadata_len;
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ metadata_len = ParseFooterLength(footer_buffer, footer_read_size);
+ END_PARQUET_CATCH_EXCEPTIONS
+ int64_t metadata_start = source_size_ - kFooterSize - metadata_len;
+
+ std::shared_ptr<::arrow::Buffer> metadata_buffer;
+ if (footer_read_size >= (metadata_len + kFooterSize)) {
+ metadata_buffer =
+ SliceBuffer(footer_buffer, footer_read_size - metadata_len - kFooterSize,
+ metadata_len);
+ return ParseMaybeEncryptedMetaDataAsync(footer_buffer,
+ std::move(metadata_buffer),
+ footer_read_size, metadata_len);
+ }
+ return source_->ReadAsync(metadata_start, metadata_len)
+ .Then([=](const std::shared_ptr<::arrow::Buffer>& metadata_buffer) {
+ return ParseMaybeEncryptedMetaDataAsync(footer_buffer, metadata_buffer,
+ footer_read_size, metadata_len);
+ });
+ });
+ }
+
+ // Continuation
+ ::arrow::Future<> ParseMaybeEncryptedMetaDataAsync(
+ std::shared_ptr<::arrow::Buffer> footer_buffer,
+ std::shared_ptr<::arrow::Buffer> metadata_buffer, int64_t footer_read_size,
+ uint32_t metadata_len) {
+ // Parse the footer depending on encryption type
+ const bool is_encrypted_footer =
+ memcmp(footer_buffer->data() + footer_read_size - 4, kParquetEMagic, 4) == 0;
+ if (is_encrypted_footer) {
+ // Encrypted file with Encrypted footer.
+ std::pair<int64_t, uint32_t> read_size;
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ read_size =
+ ParseMetaDataOfEncryptedFileWithEncryptedFooter(metadata_buffer, metadata_len);
+ END_PARQUET_CATCH_EXCEPTIONS
+ // Read the actual footer
+ int64_t metadata_start = read_size.first;
+ metadata_len = read_size.second;
+ return source_->ReadAsync(metadata_start, metadata_len)
+ .Then([=](const std::shared_ptr<::arrow::Buffer>& metadata_buffer) {
+ // Continue and read the file footer
+ return ParseMetaDataFinal(metadata_buffer, metadata_len, is_encrypted_footer);
+ });
+ }
+ return ParseMetaDataFinal(std::move(metadata_buffer), metadata_len,
+ is_encrypted_footer);
+ }
+
+ // Continuation
+ ::arrow::Status ParseMetaDataFinal(std::shared_ptr<::arrow::Buffer> metadata_buffer,
+ uint32_t metadata_len,
+ const bool is_encrypted_footer) {
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ const uint32_t read_metadata_len =
+ ParseUnencryptedFileMetadata(metadata_buffer, metadata_len);
+ auto file_decryption_properties = properties_.file_decryption_properties().get();
+ if (is_encrypted_footer) {
+ // Nothing else to do here.
+ return ::arrow::Status::OK();
+ } else if (!file_metadata_->is_encryption_algorithm_set()) { // Non encrypted file.
+ if (file_decryption_properties != nullptr) {
+ if (!file_decryption_properties->plaintext_files_allowed()) {
+ throw ParquetException("Applying decryption properties on plaintext file");
+ }
+ }
+ } else {
+ // Encrypted file with plaintext footer mode.
+ ParseMetaDataOfEncryptedFileWithPlaintextFooter(
+ file_decryption_properties, metadata_buffer, metadata_len, read_metadata_len);
+ }
+ END_PARQUET_CATCH_EXCEPTIONS
+ return ::arrow::Status::OK();
+ }
+
+ private:
+ std::shared_ptr<ArrowInputFile> source_;
+ std::shared_ptr<::arrow::io::internal::ReadRangeCache> cached_source_;
+ int64_t source_size_;
+ std::shared_ptr<FileMetaData> file_metadata_;
+ ReaderProperties properties_;
+
+ std::shared_ptr<InternalFileDecryptor> file_decryptor_;
+
+ // \return The true length of the metadata in bytes
+ uint32_t ParseUnencryptedFileMetadata(const std::shared_ptr<Buffer>& footer_buffer,
+ const uint32_t metadata_len);
+
+ std::string HandleAadPrefix(FileDecryptionProperties* file_decryption_properties,
+ EncryptionAlgorithm& algo);
+
+ void ParseMetaDataOfEncryptedFileWithPlaintextFooter(
+ FileDecryptionProperties* file_decryption_properties,
+ const std::shared_ptr<Buffer>& metadata_buffer, uint32_t metadata_len,
+ uint32_t read_metadata_len);
+
+ // \return The position and size of the actual footer
+ std::pair<int64_t, uint32_t> ParseMetaDataOfEncryptedFileWithEncryptedFooter(
+ const std::shared_ptr<Buffer>& crypto_metadata_buffer, uint32_t footer_len);
+};
+
+uint32_t SerializedFile::ParseUnencryptedFileMetadata(
+ const std::shared_ptr<Buffer>& metadata_buffer, const uint32_t metadata_len) {
+ if (metadata_buffer->size() != metadata_len) {
+ throw ParquetException("Failed reading metadata buffer (requested " +
+ std::to_string(metadata_len) + " bytes but got " +
+ std::to_string(metadata_buffer->size()) + " bytes)");
+ }
+ uint32_t read_metadata_len = metadata_len;
+ // The encrypted read path falls through to here, so pass in the decryptor
+ file_metadata_ =
+ FileMetaData::Make(metadata_buffer->data(), &read_metadata_len, file_decryptor_);
+ return read_metadata_len;
+}
+
+std::pair<int64_t, uint32_t>
+SerializedFile::ParseMetaDataOfEncryptedFileWithEncryptedFooter(
+ const std::shared_ptr<::arrow::Buffer>& crypto_metadata_buffer,
+ // both metadata & crypto metadata length
+ const uint32_t footer_len) {
+ // encryption with encrypted footer
+ // Check if the footer_buffer contains the entire metadata
+ if (crypto_metadata_buffer->size() != footer_len) {
+ throw ParquetException("Failed reading encrypted metadata buffer (requested " +
+ std::to_string(footer_len) + " bytes but got " +
+ std::to_string(crypto_metadata_buffer->size()) + " bytes)");
+ }
+ auto file_decryption_properties = properties_.file_decryption_properties().get();
+ if (file_decryption_properties == nullptr) {
+ throw ParquetException(
+ "Could not read encrypted metadata, no decryption found in reader's properties");
+ }
+ uint32_t crypto_metadata_len = footer_len;
+ std::shared_ptr<FileCryptoMetaData> file_crypto_metadata =
+ FileCryptoMetaData::Make(crypto_metadata_buffer->data(), &crypto_metadata_len);
+ // Handle AAD prefix
+ EncryptionAlgorithm algo = file_crypto_metadata->encryption_algorithm();
+ std::string file_aad = HandleAadPrefix(file_decryption_properties, algo);
+ file_decryptor_ = std::make_shared<InternalFileDecryptor>(
+ file_decryption_properties, file_aad, algo.algorithm,
+ file_crypto_metadata->key_metadata(), properties_.memory_pool());
+
+ int64_t metadata_offset = source_size_ - kFooterSize - footer_len + crypto_metadata_len;
+ uint32_t metadata_len = footer_len - crypto_metadata_len;
+ return std::make_pair(metadata_offset, metadata_len);
+}
+
+void SerializedFile::ParseMetaDataOfEncryptedFileWithPlaintextFooter(
+ FileDecryptionProperties* file_decryption_properties,
+ const std::shared_ptr<Buffer>& metadata_buffer, uint32_t metadata_len,
+ uint32_t read_metadata_len) {
+ // Providing decryption properties in plaintext footer mode is not mandatory, for
+ // example when reading by legacy reader.
+ if (file_decryption_properties != nullptr) {
+ EncryptionAlgorithm algo = file_metadata_->encryption_algorithm();
+ // Handle AAD prefix
+ std::string file_aad = HandleAadPrefix(file_decryption_properties, algo);
+ file_decryptor_ = std::make_shared<InternalFileDecryptor>(
+ file_decryption_properties, file_aad, algo.algorithm,
+ file_metadata_->footer_signing_key_metadata(), properties_.memory_pool());
+ // set the InternalFileDecryptor in the metadata as well, as it's used
+ // for signature verification and for ColumnChunkMetaData creation.
+ file_metadata_->set_file_decryptor(file_decryptor_);
+
+ if (file_decryption_properties->check_plaintext_footer_integrity()) {
+ if (metadata_len - read_metadata_len !=
+ (parquet::encryption::kGcmTagLength + parquet::encryption::kNonceLength)) {
+ throw ParquetInvalidOrCorruptedFileException(
+ "Failed reading metadata for encryption signature (requested ",
+ parquet::encryption::kGcmTagLength + parquet::encryption::kNonceLength,
+ " bytes but have ", metadata_len - read_metadata_len, " bytes)");
+ }
+
+ if (!file_metadata_->VerifySignature(metadata_buffer->data() + read_metadata_len)) {
+ throw ParquetInvalidOrCorruptedFileException(
+ "Parquet crypto signature verification failed");
+ }
+ }
+ }
+}
+
+std::string SerializedFile::HandleAadPrefix(
+ FileDecryptionProperties* file_decryption_properties, EncryptionAlgorithm& algo) {
+ std::string aad_prefix_in_properties = file_decryption_properties->aad_prefix();
+ std::string aad_prefix = aad_prefix_in_properties;
+ bool file_has_aad_prefix = algo.aad.aad_prefix.empty() ? false : true;
+ std::string aad_prefix_in_file = algo.aad.aad_prefix;
+
+ if (algo.aad.supply_aad_prefix && aad_prefix_in_properties.empty()) {
+ throw ParquetException(
+ "AAD prefix used for file encryption, "
+ "but not stored in file and not supplied "
+ "in decryption properties");
+ }
+
+ if (file_has_aad_prefix) {
+ if (!aad_prefix_in_properties.empty()) {
+ if (aad_prefix_in_properties.compare(aad_prefix_in_file) != 0) {
+ throw ParquetException(
+ "AAD Prefix in file and in properties "
+ "is not the same");
+ }
+ }
+ aad_prefix = aad_prefix_in_file;
+ std::shared_ptr<AADPrefixVerifier> aad_prefix_verifier =
+ file_decryption_properties->aad_prefix_verifier();
+ if (aad_prefix_verifier != nullptr) aad_prefix_verifier->Verify(aad_prefix);
+ } else {
+ if (!algo.aad.supply_aad_prefix && !aad_prefix_in_properties.empty()) {
+ throw ParquetException(
+ "AAD Prefix set in decryption properties, but was not used "
+ "for file encryption");
+ }
+ std::shared_ptr<AADPrefixVerifier> aad_prefix_verifier =
+ file_decryption_properties->aad_prefix_verifier();
+ if (aad_prefix_verifier != nullptr) {
+ throw ParquetException(
+ "AAD Prefix Verifier is set, but AAD Prefix not found in file");
+ }
+ }
+ return aad_prefix + algo.aad.aad_file_unique;
+}
+
+// ----------------------------------------------------------------------
+// ParquetFileReader public API
+
+ParquetFileReader::ParquetFileReader() {}
+
+ParquetFileReader::~ParquetFileReader() {
+ try {
+ Close();
+ } catch (...) {
+ }
+}
+
+// Open the file. If no metadata is passed, it is parsed from the footer of
+// the file
+std::unique_ptr<ParquetFileReader::Contents> ParquetFileReader::Contents::Open(
+ std::shared_ptr<ArrowInputFile> source, const ReaderProperties& props,
+ std::shared_ptr<FileMetaData> metadata) {
+ std::unique_ptr<ParquetFileReader::Contents> result(
+ new SerializedFile(std::move(source), props));
+
+ // Access private methods here, but otherwise unavailable
+ SerializedFile* file = static_cast<SerializedFile*>(result.get());
+
+ if (metadata == nullptr) {
+ // Validates magic bytes, parses metadata, and initializes the SchemaDescriptor
+ file->ParseMetaData();
+ } else {
+ file->set_metadata(std::move(metadata));
+ }
+
+ return result;
+}
+
+::arrow::Future<std::unique_ptr<ParquetFileReader::Contents>>
+ParquetFileReader::Contents::OpenAsync(std::shared_ptr<ArrowInputFile> source,
+ const ReaderProperties& props,
+ std::shared_ptr<FileMetaData> metadata) {
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ std::unique_ptr<ParquetFileReader::Contents> result(
+ new SerializedFile(std::move(source), props));
+ SerializedFile* file = static_cast<SerializedFile*>(result.get());
+ if (metadata == nullptr) {
+ // TODO(ARROW-12259): workaround since we have Future<(move-only type)>
+ struct {
+ ::arrow::Result<std::unique_ptr<ParquetFileReader::Contents>> operator()() {
+ return std::move(result);
+ }
+
+ std::unique_ptr<ParquetFileReader::Contents> result;
+ } Continuation;
+ Continuation.result = std::move(result);
+ return file->ParseMetaDataAsync().Then(std::move(Continuation));
+ } else {
+ file->set_metadata(std::move(metadata));
+ return ::arrow::Future<std::unique_ptr<ParquetFileReader::Contents>>::MakeFinished(
+ std::move(result));
+ }
+ END_PARQUET_CATCH_EXCEPTIONS
+}
+
+std::unique_ptr<ParquetFileReader> ParquetFileReader::Open(
+ std::shared_ptr<::arrow::io::RandomAccessFile> source, const ReaderProperties& props,
+ std::shared_ptr<FileMetaData> metadata) {
+ auto contents = SerializedFile::Open(std::move(source), props, std::move(metadata));
+ std::unique_ptr<ParquetFileReader> result(new ParquetFileReader());
+ result->Open(std::move(contents));
+ return result;
+}
+
+std::unique_ptr<ParquetFileReader> ParquetFileReader::OpenFile(
+ const std::string& path, bool memory_map, const ReaderProperties& props,
+ std::shared_ptr<FileMetaData> metadata) {
+ std::shared_ptr<::arrow::io::RandomAccessFile> source;
+ if (memory_map) {
+ PARQUET_ASSIGN_OR_THROW(
+ source, ::arrow::io::MemoryMappedFile::Open(path, ::arrow::io::FileMode::READ));
+ } else {
+ PARQUET_ASSIGN_OR_THROW(source,
+ ::arrow::io::ReadableFile::Open(path, props.memory_pool()));
+ }
+
+ return Open(std::move(source), props, std::move(metadata));
+}
+
+::arrow::Future<std::unique_ptr<ParquetFileReader>> ParquetFileReader::OpenAsync(
+ std::shared_ptr<::arrow::io::RandomAccessFile> source, const ReaderProperties& props,
+ std::shared_ptr<FileMetaData> metadata) {
+ BEGIN_PARQUET_CATCH_EXCEPTIONS
+ auto fut = SerializedFile::OpenAsync(std::move(source), props, std::move(metadata));
+ // TODO(ARROW-12259): workaround since we have Future<(move-only type)>
+ auto completed = ::arrow::Future<std::unique_ptr<ParquetFileReader>>::Make();
+ fut.AddCallback([fut, completed](
+ const ::arrow::Result<std::unique_ptr<ParquetFileReader::Contents>>&
+ contents) mutable {
+ if (!contents.ok()) {
+ completed.MarkFinished(contents.status());
+ return;
+ }
+ std::unique_ptr<ParquetFileReader> result(new ParquetFileReader());
+ result->Open(fut.MoveResult().MoveValueUnsafe());
+ completed.MarkFinished(std::move(result));
+ });
+ return completed;
+ END_PARQUET_CATCH_EXCEPTIONS
+}
+
+void ParquetFileReader::Open(std::unique_ptr<ParquetFileReader::Contents> contents) {
+ contents_ = std::move(contents);
+}
+
+void ParquetFileReader::Close() {
+ if (contents_) {
+ contents_->Close();
+ }
+}
+
+std::shared_ptr<FileMetaData> ParquetFileReader::metadata() const {
+ return contents_->metadata();
+}
+
+std::shared_ptr<RowGroupReader> ParquetFileReader::RowGroup(int i) {
+ if (i >= metadata()->num_row_groups()) {
+ std::stringstream ss;
+ ss << "Trying to read row group " << i << " but file only has "
+ << metadata()->num_row_groups() << " row groups";
+ throw ParquetException(ss.str());
+ }
+ return contents_->GetRowGroup(i);
+}
+
+void ParquetFileReader::PreBuffer(const std::vector<int>& row_groups,
+ const std::vector<int>& column_indices,
+ const ::arrow::io::IOContext& ctx,
+ const ::arrow::io::CacheOptions& options) {
+ // Access private methods here
+ SerializedFile* file =
+ ::arrow::internal::checked_cast<SerializedFile*>(contents_.get());
+ file->PreBuffer(row_groups, column_indices, ctx, options);
+}
+
+::arrow::Future<> ParquetFileReader::WhenBuffered(
+ const std::vector<int>& row_groups, const std::vector<int>& column_indices) const {
+ // Access private methods here
+ SerializedFile* file =
+ ::arrow::internal::checked_cast<SerializedFile*>(contents_.get());
+ return file->WhenBuffered(row_groups, column_indices);
+}
+
+// ----------------------------------------------------------------------
+// File metadata helpers
+
+std::shared_ptr<FileMetaData> ReadMetaData(
+ const std::shared_ptr<::arrow::io::RandomAccessFile>& source) {
+ return ParquetFileReader::Open(source)->metadata();
+}
+
+// ----------------------------------------------------------------------
+// File scanner for performance testing
+
+int64_t ScanFileContents(std::vector<int> columns, const int32_t column_batch_size,
+ ParquetFileReader* reader) {
+ std::vector<int16_t> rep_levels(column_batch_size);
+ std::vector<int16_t> def_levels(column_batch_size);
+
+ int num_columns = static_cast<int>(columns.size());
+
+ // columns are not specified explicitly. Add all columns
+ if (columns.size() == 0) {
+ num_columns = reader->metadata()->num_columns();
+ columns.resize(num_columns);
+ for (int i = 0; i < num_columns; i++) {
+ columns[i] = i;
+ }
+ }
+
+ std::vector<int64_t> total_rows(num_columns, 0);
+
+ for (int r = 0; r < reader->metadata()->num_row_groups(); ++r) {
+ auto group_reader = reader->RowGroup(r);
+ int col = 0;
+ for (auto i : columns) {
+ std::shared_ptr<ColumnReader> col_reader = group_reader->Column(i);
+ size_t value_byte_size = GetTypeByteSize(col_reader->descr()->physical_type());
+ std::vector<uint8_t> values(column_batch_size * value_byte_size);
+
+ int64_t values_read = 0;
+ while (col_reader->HasNext()) {
+ int64_t levels_read =
+ ScanAllValues(column_batch_size, def_levels.data(), rep_levels.data(),
+ values.data(), &values_read, col_reader.get());
+ if (col_reader->descr()->max_repetition_level() > 0) {
+ for (int64_t i = 0; i < levels_read; i++) {
+ if (rep_levels[i] == 0) {
+ total_rows[col]++;
+ }
+ }
+ } else {
+ total_rows[col] += levels_read;
+ }
+ }
+ col++;
+ }
+ }
+
+ for (int i = 1; i < num_columns; ++i) {
+ if (total_rows[0] != total_rows[i]) {
+ throw ParquetException("Parquet error: Total rows among columns do not match");
+ }
+ }
+
+ return total_rows[0];
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/file_reader.h b/src/arrow/cpp/src/parquet/file_reader.h
new file mode 100644
index 000000000..0fc840549
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/file_reader.h
@@ -0,0 +1,188 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/io/caching.h"
+#include "arrow/util/type_fwd.h"
+#include "parquet/metadata.h" // IWYU pragma: keep
+#include "parquet/platform.h"
+#include "parquet/properties.h"
+
+namespace parquet {
+
+class ColumnReader;
+class FileMetaData;
+class PageReader;
+class RowGroupMetaData;
+
+class PARQUET_EXPORT RowGroupReader {
+ public:
+ // Forward declare a virtual class 'Contents' to aid dependency injection and more
+ // easily create test fixtures
+ // An implementation of the Contents class is defined in the .cc file
+ struct Contents {
+ virtual ~Contents() {}
+ virtual std::unique_ptr<PageReader> GetColumnPageReader(int i) = 0;
+ virtual const RowGroupMetaData* metadata() const = 0;
+ virtual const ReaderProperties* properties() const = 0;
+ };
+
+ explicit RowGroupReader(std::unique_ptr<Contents> contents);
+
+ // Returns the rowgroup metadata
+ const RowGroupMetaData* metadata() const;
+
+ // Construct a ColumnReader for the indicated row group-relative
+ // column. Ownership is shared with the RowGroupReader.
+ std::shared_ptr<ColumnReader> Column(int i);
+
+ // Construct a ColumnReader, trying to enable exposed encoding.
+ //
+ // For dictionary encoding, currently we only support column chunks that are fully
+ // dictionary encoded, i.e., all data pages in the column chunk are dictionary encoded.
+ // If a column chunk uses dictionary encoding but then falls back to plain encoding, the
+ // encoding will not be exposed.
+ //
+ // The returned column reader provides an API GetExposedEncoding() for the
+ // users to check the exposed encoding and determine how to read the batches.
+ //
+ // \note API EXPERIMENTAL
+ std::shared_ptr<ColumnReader> ColumnWithExposeEncoding(
+ int i, ExposedEncoding encoding_to_expose);
+
+ std::unique_ptr<PageReader> GetColumnPageReader(int i);
+
+ private:
+ // Holds a pointer to an instance of Contents implementation
+ std::unique_ptr<Contents> contents_;
+};
+
+class PARQUET_EXPORT ParquetFileReader {
+ public:
+ // Declare a virtual class 'Contents' to aid dependency injection and more
+ // easily create test fixtures
+ // An implementation of the Contents class is defined in the .cc file
+ struct PARQUET_EXPORT Contents {
+ static std::unique_ptr<Contents> Open(
+ std::shared_ptr<::arrow::io::RandomAccessFile> source,
+ const ReaderProperties& props = default_reader_properties(),
+ std::shared_ptr<FileMetaData> metadata = NULLPTR);
+
+ static ::arrow::Future<std::unique_ptr<Contents>> OpenAsync(
+ std::shared_ptr<::arrow::io::RandomAccessFile> source,
+ const ReaderProperties& props = default_reader_properties(),
+ std::shared_ptr<FileMetaData> metadata = NULLPTR);
+
+ virtual ~Contents() = default;
+ // Perform any cleanup associated with the file contents
+ virtual void Close() = 0;
+ virtual std::shared_ptr<RowGroupReader> GetRowGroup(int i) = 0;
+ virtual std::shared_ptr<FileMetaData> metadata() const = 0;
+ };
+
+ ParquetFileReader();
+ ~ParquetFileReader();
+
+ // Create a file reader instance from an Arrow file object. Thread-safety is
+ // the responsibility of the file implementation
+ static std::unique_ptr<ParquetFileReader> Open(
+ std::shared_ptr<::arrow::io::RandomAccessFile> source,
+ const ReaderProperties& props = default_reader_properties(),
+ std::shared_ptr<FileMetaData> metadata = NULLPTR);
+
+ // API Convenience to open a serialized Parquet file on disk, using Arrow IO
+ // interfaces.
+ static std::unique_ptr<ParquetFileReader> OpenFile(
+ const std::string& path, bool memory_map = true,
+ const ReaderProperties& props = default_reader_properties(),
+ std::shared_ptr<FileMetaData> metadata = NULLPTR);
+
+ // Asynchronously open a file reader from an Arrow file object.
+ // Does not throw - all errors are reported through the Future.
+ static ::arrow::Future<std::unique_ptr<ParquetFileReader>> OpenAsync(
+ std::shared_ptr<::arrow::io::RandomAccessFile> source,
+ const ReaderProperties& props = default_reader_properties(),
+ std::shared_ptr<FileMetaData> metadata = NULLPTR);
+
+ void Open(std::unique_ptr<Contents> contents);
+ void Close();
+
+ // The RowGroupReader is owned by the FileReader
+ std::shared_ptr<RowGroupReader> RowGroup(int i);
+
+ // Returns the file metadata. Only one instance is ever created
+ std::shared_ptr<FileMetaData> metadata() const;
+
+ /// Pre-buffer the specified column indices in all row groups.
+ ///
+ /// Readers can optionally call this to cache the necessary slices
+ /// of the file in-memory before deserialization. Arrow readers can
+ /// automatically do this via an option. This is intended to
+ /// increase performance when reading from high-latency filesystems
+ /// (e.g. Amazon S3).
+ ///
+ /// After calling this, creating readers for row groups/column
+ /// indices that were not buffered may fail. Creating multiple
+ /// readers for the a subset of the buffered regions is
+ /// acceptable. This may be called again to buffer a different set
+ /// of row groups/columns.
+ ///
+ /// If memory usage is a concern, note that data will remain
+ /// buffered in memory until either \a PreBuffer() is called again,
+ /// or the reader itself is destructed. Reading - and buffering -
+ /// only one row group at a time may be useful.
+ ///
+ /// This method may throw.
+ void PreBuffer(const std::vector<int>& row_groups,
+ const std::vector<int>& column_indices,
+ const ::arrow::io::IOContext& ctx,
+ const ::arrow::io::CacheOptions& options);
+
+ /// Wait for the specified row groups and column indices to be pre-buffered.
+ ///
+ /// After the returned Future completes, reading the specified row
+ /// groups/columns will not block.
+ ///
+ /// PreBuffer must be called first. This method does not throw.
+ ::arrow::Future<> WhenBuffered(const std::vector<int>& row_groups,
+ const std::vector<int>& column_indices) const;
+
+ private:
+ // Holds a pointer to an instance of Contents implementation
+ std::unique_ptr<Contents> contents_;
+};
+
+// Read only Parquet file metadata
+std::shared_ptr<FileMetaData> PARQUET_EXPORT
+ReadMetaData(const std::shared_ptr<::arrow::io::RandomAccessFile>& source);
+
+/// \brief Scan all values in file. Useful for performance testing
+/// \param[in] columns the column numbers to scan. If empty scans all
+/// \param[in] column_batch_size number of values to read at a time when scanning column
+/// \param[in] reader a ParquetFileReader instance
+/// \return number of semantic rows in file
+PARQUET_EXPORT
+int64_t ScanFileContents(std::vector<int> columns, const int32_t column_batch_size,
+ ParquetFileReader* reader);
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/file_serialize_test.cc b/src/arrow/cpp/src/parquet/file_serialize_test.cc
new file mode 100644
index 000000000..eb1133d8a
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/file_serialize_test.cc
@@ -0,0 +1,470 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_compat.h"
+
+#include "parquet/column_reader.h"
+#include "parquet/column_writer.h"
+#include "parquet/file_reader.h"
+#include "parquet/file_writer.h"
+#include "parquet/platform.h"
+#include "parquet/test_util.h"
+#include "parquet/types.h"
+
+namespace parquet {
+
+using schema::GroupNode;
+using schema::NodePtr;
+using schema::PrimitiveNode;
+using ::testing::ElementsAre;
+
+namespace test {
+
+template <typename TestType>
+class TestSerialize : public PrimitiveTypedTest<TestType> {
+ public:
+ void SetUp() {
+ num_columns_ = 4;
+ num_rowgroups_ = 4;
+ rows_per_rowgroup_ = 50;
+ rows_per_batch_ = 10;
+ this->SetUpSchema(Repetition::OPTIONAL, num_columns_);
+ }
+
+ protected:
+ int num_columns_;
+ int num_rowgroups_;
+ int rows_per_rowgroup_;
+ int rows_per_batch_;
+
+ void FileSerializeTest(Compression::type codec_type) {
+ FileSerializeTest(codec_type, codec_type);
+ }
+
+ void FileSerializeTest(Compression::type codec_type,
+ Compression::type expected_codec_type) {
+ auto sink = CreateOutputStream();
+ auto gnode = std::static_pointer_cast<GroupNode>(this->node_);
+
+ WriterProperties::Builder prop_builder;
+
+ for (int i = 0; i < num_columns_; ++i) {
+ prop_builder.compression(this->schema_.Column(i)->name(), codec_type);
+ }
+ std::shared_ptr<WriterProperties> writer_properties = prop_builder.build();
+
+ auto file_writer = ParquetFileWriter::Open(sink, gnode, writer_properties);
+ this->GenerateData(rows_per_rowgroup_);
+ for (int rg = 0; rg < num_rowgroups_ / 2; ++rg) {
+ RowGroupWriter* row_group_writer;
+ row_group_writer = file_writer->AppendRowGroup();
+ for (int col = 0; col < num_columns_; ++col) {
+ auto column_writer =
+ static_cast<TypedColumnWriter<TestType>*>(row_group_writer->NextColumn());
+ column_writer->WriteBatch(rows_per_rowgroup_, this->def_levels_.data(), nullptr,
+ this->values_ptr_);
+ column_writer->Close();
+ // Ensure column() API which is specific to BufferedRowGroup cannot be called
+ ASSERT_THROW(row_group_writer->column(col), ParquetException);
+ }
+
+ row_group_writer->Close();
+ }
+ // Write half BufferedRowGroups
+ for (int rg = 0; rg < num_rowgroups_ / 2; ++rg) {
+ RowGroupWriter* row_group_writer;
+ row_group_writer = file_writer->AppendBufferedRowGroup();
+ for (int batch = 0; batch < (rows_per_rowgroup_ / rows_per_batch_); ++batch) {
+ for (int col = 0; col < num_columns_; ++col) {
+ auto column_writer =
+ static_cast<TypedColumnWriter<TestType>*>(row_group_writer->column(col));
+ column_writer->WriteBatch(
+ rows_per_batch_, this->def_levels_.data() + (batch * rows_per_batch_),
+ nullptr, this->values_ptr_ + (batch * rows_per_batch_));
+ // Ensure NextColumn() API which is specific to RowGroup cannot be called
+ ASSERT_THROW(row_group_writer->NextColumn(), ParquetException);
+ }
+ }
+ for (int col = 0; col < num_columns_; ++col) {
+ auto column_writer =
+ static_cast<TypedColumnWriter<TestType>*>(row_group_writer->column(col));
+ column_writer->Close();
+ }
+ row_group_writer->Close();
+ }
+ file_writer->Close();
+
+ PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish());
+
+ int num_rows_ = num_rowgroups_ * rows_per_rowgroup_;
+
+ auto source = std::make_shared<::arrow::io::BufferReader>(buffer);
+ auto file_reader = ParquetFileReader::Open(source);
+ ASSERT_EQ(num_columns_, file_reader->metadata()->num_columns());
+ ASSERT_EQ(num_rowgroups_, file_reader->metadata()->num_row_groups());
+ ASSERT_EQ(num_rows_, file_reader->metadata()->num_rows());
+
+ for (int rg = 0; rg < num_rowgroups_; ++rg) {
+ auto rg_reader = file_reader->RowGroup(rg);
+ auto rg_metadata = rg_reader->metadata();
+ ASSERT_EQ(num_columns_, rg_metadata->num_columns());
+ ASSERT_EQ(rows_per_rowgroup_, rg_metadata->num_rows());
+ // Check that the specified compression was actually used.
+ ASSERT_EQ(expected_codec_type, rg_metadata->ColumnChunk(0)->compression());
+
+ const int64_t total_byte_size = rg_metadata->total_byte_size();
+ const int64_t total_compressed_size = rg_metadata->total_compressed_size();
+ if (expected_codec_type == Compression::UNCOMPRESSED) {
+ ASSERT_EQ(total_byte_size, total_compressed_size);
+ } else {
+ ASSERT_NE(total_byte_size, total_compressed_size);
+ }
+
+ int64_t total_column_byte_size = 0;
+ int64_t total_column_compressed_size = 0;
+
+ for (int i = 0; i < num_columns_; ++i) {
+ int64_t values_read;
+ ASSERT_FALSE(rg_metadata->ColumnChunk(i)->has_index_page());
+ total_column_byte_size += rg_metadata->ColumnChunk(i)->total_uncompressed_size();
+ total_column_compressed_size +=
+ rg_metadata->ColumnChunk(i)->total_compressed_size();
+
+ std::vector<int16_t> def_levels_out(rows_per_rowgroup_);
+ std::vector<int16_t> rep_levels_out(rows_per_rowgroup_);
+ auto col_reader =
+ std::static_pointer_cast<TypedColumnReader<TestType>>(rg_reader->Column(i));
+ this->SetupValuesOut(rows_per_rowgroup_);
+ col_reader->ReadBatch(rows_per_rowgroup_, def_levels_out.data(),
+ rep_levels_out.data(), this->values_out_ptr_, &values_read);
+ this->SyncValuesOut();
+ ASSERT_EQ(rows_per_rowgroup_, values_read);
+ ASSERT_EQ(this->values_, this->values_out_);
+ ASSERT_EQ(this->def_levels_, def_levels_out);
+ }
+
+ ASSERT_EQ(total_byte_size, total_column_byte_size);
+ ASSERT_EQ(total_compressed_size, total_column_compressed_size);
+ }
+ }
+
+ void UnequalNumRows(int64_t max_rows, const std::vector<int64_t> rows_per_column) {
+ auto sink = CreateOutputStream();
+ auto gnode = std::static_pointer_cast<GroupNode>(this->node_);
+
+ std::shared_ptr<WriterProperties> props = WriterProperties::Builder().build();
+
+ auto file_writer = ParquetFileWriter::Open(sink, gnode, props);
+
+ RowGroupWriter* row_group_writer;
+ row_group_writer = file_writer->AppendRowGroup();
+
+ this->GenerateData(max_rows);
+ for (int col = 0; col < num_columns_; ++col) {
+ auto column_writer =
+ static_cast<TypedColumnWriter<TestType>*>(row_group_writer->NextColumn());
+ column_writer->WriteBatch(rows_per_column[col], this->def_levels_.data(), nullptr,
+ this->values_ptr_);
+ column_writer->Close();
+ }
+ row_group_writer->Close();
+ file_writer->Close();
+ }
+
+ void UnequalNumRowsBuffered(int64_t max_rows,
+ const std::vector<int64_t> rows_per_column) {
+ auto sink = CreateOutputStream();
+ auto gnode = std::static_pointer_cast<GroupNode>(this->node_);
+
+ std::shared_ptr<WriterProperties> props = WriterProperties::Builder().build();
+
+ auto file_writer = ParquetFileWriter::Open(sink, gnode, props);
+
+ RowGroupWriter* row_group_writer;
+ row_group_writer = file_writer->AppendBufferedRowGroup();
+
+ this->GenerateData(max_rows);
+ for (int col = 0; col < num_columns_; ++col) {
+ auto column_writer =
+ static_cast<TypedColumnWriter<TestType>*>(row_group_writer->column(col));
+ column_writer->WriteBatch(rows_per_column[col], this->def_levels_.data(), nullptr,
+ this->values_ptr_);
+ column_writer->Close();
+ }
+ row_group_writer->Close();
+ file_writer->Close();
+ }
+
+ void RepeatedUnequalRows() {
+ // Optional and repeated, so definition and repetition levels
+ this->SetUpSchema(Repetition::REPEATED);
+
+ const int kNumRows = 100;
+ this->GenerateData(kNumRows);
+
+ auto sink = CreateOutputStream();
+ auto gnode = std::static_pointer_cast<GroupNode>(this->node_);
+ std::shared_ptr<WriterProperties> props = WriterProperties::Builder().build();
+ auto file_writer = ParquetFileWriter::Open(sink, gnode, props);
+
+ RowGroupWriter* row_group_writer;
+ row_group_writer = file_writer->AppendRowGroup();
+
+ this->GenerateData(kNumRows);
+
+ std::vector<int16_t> definition_levels(kNumRows, 1);
+ std::vector<int16_t> repetition_levels(kNumRows, 0);
+
+ {
+ auto column_writer =
+ static_cast<TypedColumnWriter<TestType>*>(row_group_writer->NextColumn());
+ column_writer->WriteBatch(kNumRows, definition_levels.data(),
+ repetition_levels.data(), this->values_ptr_);
+ column_writer->Close();
+ }
+
+ definition_levels[1] = 0;
+ repetition_levels[3] = 1;
+
+ {
+ auto column_writer =
+ static_cast<TypedColumnWriter<TestType>*>(row_group_writer->NextColumn());
+ column_writer->WriteBatch(kNumRows, definition_levels.data(),
+ repetition_levels.data(), this->values_ptr_);
+ column_writer->Close();
+ }
+ }
+
+ void ZeroRowsRowGroup() {
+ auto sink = CreateOutputStream();
+ auto gnode = std::static_pointer_cast<GroupNode>(this->node_);
+
+ std::shared_ptr<WriterProperties> props = WriterProperties::Builder().build();
+
+ auto file_writer = ParquetFileWriter::Open(sink, gnode, props);
+
+ RowGroupWriter* row_group_writer;
+
+ row_group_writer = file_writer->AppendRowGroup();
+ for (int col = 0; col < num_columns_; ++col) {
+ auto column_writer =
+ static_cast<TypedColumnWriter<TestType>*>(row_group_writer->NextColumn());
+ column_writer->Close();
+ }
+ row_group_writer->Close();
+
+ row_group_writer = file_writer->AppendBufferedRowGroup();
+ for (int col = 0; col < num_columns_; ++col) {
+ auto column_writer =
+ static_cast<TypedColumnWriter<TestType>*>(row_group_writer->column(col));
+ column_writer->Close();
+ }
+ row_group_writer->Close();
+
+ file_writer->Close();
+ }
+};
+
+typedef ::testing::Types<Int32Type, Int64Type, Int96Type, FloatType, DoubleType,
+ BooleanType, ByteArrayType, FLBAType>
+ TestTypes;
+
+TYPED_TEST_SUITE(TestSerialize, TestTypes);
+
+TYPED_TEST(TestSerialize, SmallFileUncompressed) {
+ ASSERT_NO_FATAL_FAILURE(this->FileSerializeTest(Compression::UNCOMPRESSED));
+}
+
+TYPED_TEST(TestSerialize, TooFewRows) {
+ std::vector<int64_t> num_rows = {100, 100, 100, 99};
+ ASSERT_THROW(this->UnequalNumRows(100, num_rows), ParquetException);
+ ASSERT_THROW(this->UnequalNumRowsBuffered(100, num_rows), ParquetException);
+}
+
+TYPED_TEST(TestSerialize, TooManyRows) {
+ std::vector<int64_t> num_rows = {100, 100, 100, 101};
+ ASSERT_THROW(this->UnequalNumRows(101, num_rows), ParquetException);
+ ASSERT_THROW(this->UnequalNumRowsBuffered(101, num_rows), ParquetException);
+}
+
+TYPED_TEST(TestSerialize, ZeroRows) { ASSERT_NO_THROW(this->ZeroRowsRowGroup()); }
+
+TYPED_TEST(TestSerialize, RepeatedTooFewRows) {
+ ASSERT_THROW(this->RepeatedUnequalRows(), ParquetException);
+}
+
+#ifdef ARROW_WITH_SNAPPY
+TYPED_TEST(TestSerialize, SmallFileSnappy) {
+ ASSERT_NO_FATAL_FAILURE(this->FileSerializeTest(Compression::SNAPPY));
+}
+#endif
+
+#ifdef ARROW_WITH_BROTLI
+TYPED_TEST(TestSerialize, SmallFileBrotli) {
+ ASSERT_NO_FATAL_FAILURE(this->FileSerializeTest(Compression::BROTLI));
+}
+#endif
+
+#ifdef ARROW_WITH_GZIP
+TYPED_TEST(TestSerialize, SmallFileGzip) {
+ ASSERT_NO_FATAL_FAILURE(this->FileSerializeTest(Compression::GZIP));
+}
+#endif
+
+#ifdef ARROW_WITH_LZ4
+TYPED_TEST(TestSerialize, SmallFileLz4) {
+ ASSERT_NO_FATAL_FAILURE(this->FileSerializeTest(Compression::LZ4));
+}
+
+TYPED_TEST(TestSerialize, SmallFileLz4Hadoop) {
+ ASSERT_NO_FATAL_FAILURE(this->FileSerializeTest(Compression::LZ4_HADOOP));
+}
+#endif
+
+#ifdef ARROW_WITH_ZSTD
+TYPED_TEST(TestSerialize, SmallFileZstd) {
+ ASSERT_NO_FATAL_FAILURE(this->FileSerializeTest(Compression::ZSTD));
+}
+#endif
+
+TEST(TestBufferedRowGroupWriter, DisabledDictionary) {
+ // PARQUET-1706:
+ // Wrong dictionary_page_offset when writing only data pages via BufferedPageWriter
+ auto sink = CreateOutputStream();
+ auto writer_props = parquet::WriterProperties::Builder().disable_dictionary()->build();
+ schema::NodeVector fields;
+ fields.push_back(
+ PrimitiveNode::Make("col", parquet::Repetition::REQUIRED, parquet::Type::INT32));
+ auto schema = std::static_pointer_cast<GroupNode>(
+ GroupNode::Make("schema", Repetition::REQUIRED, fields));
+ auto file_writer = parquet::ParquetFileWriter::Open(sink, schema, writer_props);
+ auto rg_writer = file_writer->AppendBufferedRowGroup();
+ auto col_writer = static_cast<Int32Writer*>(rg_writer->column(0));
+ int value = 0;
+ col_writer->WriteBatch(1, nullptr, nullptr, &value);
+ rg_writer->Close();
+ file_writer->Close();
+ PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish());
+
+ auto source = std::make_shared<::arrow::io::BufferReader>(buffer);
+ auto file_reader = ParquetFileReader::Open(source);
+ ASSERT_EQ(1, file_reader->metadata()->num_row_groups());
+ auto rg_reader = file_reader->RowGroup(0);
+ ASSERT_EQ(1, rg_reader->metadata()->num_columns());
+ ASSERT_EQ(1, rg_reader->metadata()->num_rows());
+ ASSERT_FALSE(rg_reader->metadata()->ColumnChunk(0)->has_dictionary_page());
+}
+
+TEST(TestBufferedRowGroupWriter, MultiPageDisabledDictionary) {
+ constexpr int kValueCount = 10000;
+ constexpr int kPageSize = 16384;
+ auto sink = CreateOutputStream();
+ auto writer_props = parquet::WriterProperties::Builder()
+ .disable_dictionary()
+ ->data_pagesize(kPageSize)
+ ->build();
+ schema::NodeVector fields;
+ fields.push_back(
+ PrimitiveNode::Make("col", parquet::Repetition::REQUIRED, parquet::Type::INT32));
+ auto schema = std::static_pointer_cast<GroupNode>(
+ GroupNode::Make("schema", Repetition::REQUIRED, fields));
+ auto file_writer = parquet::ParquetFileWriter::Open(sink, schema, writer_props);
+ auto rg_writer = file_writer->AppendBufferedRowGroup();
+ auto col_writer = static_cast<Int32Writer*>(rg_writer->column(0));
+ std::vector<int32_t> values_in;
+ for (int i = 0; i < kValueCount; ++i) {
+ values_in.push_back((i % 100) + 1);
+ }
+ col_writer->WriteBatch(kValueCount, nullptr, nullptr, values_in.data());
+ rg_writer->Close();
+ file_writer->Close();
+ PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish());
+
+ auto source = std::make_shared<::arrow::io::BufferReader>(buffer);
+ auto file_reader = ParquetFileReader::Open(source);
+ auto file_metadata = file_reader->metadata();
+ ASSERT_EQ(1, file_reader->metadata()->num_row_groups());
+ std::vector<int32_t> values_out(kValueCount);
+ for (int r = 0; r < file_metadata->num_row_groups(); ++r) {
+ auto rg_reader = file_reader->RowGroup(r);
+ ASSERT_EQ(1, rg_reader->metadata()->num_columns());
+ ASSERT_EQ(kValueCount, rg_reader->metadata()->num_rows());
+ int64_t total_values_read = 0;
+ std::shared_ptr<parquet::ColumnReader> col_reader;
+ ASSERT_NO_THROW(col_reader = rg_reader->Column(0));
+ parquet::Int32Reader* int32_reader =
+ static_cast<parquet::Int32Reader*>(col_reader.get());
+ int64_t vn = kValueCount;
+ int32_t* vx = values_out.data();
+ while (int32_reader->HasNext()) {
+ int64_t values_read;
+ int32_reader->ReadBatch(vn, nullptr, nullptr, vx, &values_read);
+ vn -= values_read;
+ vx += values_read;
+ total_values_read += values_read;
+ }
+ ASSERT_EQ(kValueCount, total_values_read);
+ ASSERT_EQ(values_in, values_out);
+ }
+}
+
+TEST(ParquetRoundtrip, AllNulls) {
+ auto primitive_node =
+ PrimitiveNode::Make("nulls", Repetition::OPTIONAL, nullptr, Type::INT32);
+ schema::NodeVector columns({primitive_node});
+
+ auto root_node = GroupNode::Make("root", Repetition::REQUIRED, columns, nullptr);
+
+ auto sink = CreateOutputStream();
+
+ auto file_writer =
+ ParquetFileWriter::Open(sink, std::static_pointer_cast<GroupNode>(root_node));
+ auto row_group_writer = file_writer->AppendRowGroup();
+ auto column_writer = static_cast<Int32Writer*>(row_group_writer->NextColumn());
+
+ int32_t values[3];
+ int16_t def_levels[] = {0, 0, 0};
+
+ column_writer->WriteBatch(3, def_levels, nullptr, values);
+
+ column_writer->Close();
+ row_group_writer->Close();
+ file_writer->Close();
+
+ ReaderProperties props = default_reader_properties();
+ props.enable_buffered_stream();
+ PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish());
+
+ auto source = std::make_shared<::arrow::io::BufferReader>(buffer);
+ auto file_reader = ParquetFileReader::Open(source, props);
+ auto row_group_reader = file_reader->RowGroup(0);
+ auto column_reader = std::static_pointer_cast<Int32Reader>(row_group_reader->Column(0));
+
+ int64_t values_read;
+ def_levels[0] = -1;
+ def_levels[1] = -1;
+ def_levels[2] = -1;
+ column_reader->ReadBatch(3, def_levels, nullptr, values, &values_read);
+ EXPECT_THAT(def_levels, ElementsAre(0, 0, 0));
+}
+
+} // namespace test
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/file_writer.cc b/src/arrow/cpp/src/parquet/file_writer.cc
new file mode 100644
index 000000000..deac9586e
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/file_writer.cc
@@ -0,0 +1,547 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/file_writer.h"
+
+#include <cstddef>
+#include <ostream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "parquet/column_writer.h"
+#include "parquet/encryption/encryption_internal.h"
+#include "parquet/encryption/internal_file_encryptor.h"
+#include "parquet/exception.h"
+#include "parquet/platform.h"
+#include "parquet/schema.h"
+#include "parquet/types.h"
+
+using arrow::MemoryPool;
+
+using parquet::schema::GroupNode;
+
+namespace parquet {
+
+// ----------------------------------------------------------------------
+// RowGroupWriter public API
+
+RowGroupWriter::RowGroupWriter(std::unique_ptr<Contents> contents)
+ : contents_(std::move(contents)) {}
+
+void RowGroupWriter::Close() {
+ if (contents_) {
+ contents_->Close();
+ }
+}
+
+ColumnWriter* RowGroupWriter::NextColumn() { return contents_->NextColumn(); }
+
+ColumnWriter* RowGroupWriter::column(int i) { return contents_->column(i); }
+
+int64_t RowGroupWriter::total_compressed_bytes() const {
+ return contents_->total_compressed_bytes();
+}
+
+int64_t RowGroupWriter::total_bytes_written() const {
+ return contents_->total_bytes_written();
+}
+
+int RowGroupWriter::current_column() { return contents_->current_column(); }
+
+int RowGroupWriter::num_columns() const { return contents_->num_columns(); }
+
+int64_t RowGroupWriter::num_rows() const { return contents_->num_rows(); }
+
+inline void ThrowRowsMisMatchError(int col, int64_t prev, int64_t curr) {
+ std::stringstream ss;
+ ss << "Column " << col << " had " << curr << " while previous column had " << prev;
+ throw ParquetException(ss.str());
+}
+
+// ----------------------------------------------------------------------
+// RowGroupSerializer
+
+// RowGroupWriter::Contents implementation for the Parquet file specification
+class RowGroupSerializer : public RowGroupWriter::Contents {
+ public:
+ RowGroupSerializer(std::shared_ptr<ArrowOutputStream> sink,
+ RowGroupMetaDataBuilder* metadata, int16_t row_group_ordinal,
+ const WriterProperties* properties, bool buffered_row_group = false,
+ InternalFileEncryptor* file_encryptor = nullptr)
+ : sink_(std::move(sink)),
+ metadata_(metadata),
+ properties_(properties),
+ total_bytes_written_(0),
+ closed_(false),
+ row_group_ordinal_(row_group_ordinal),
+ next_column_index_(0),
+ num_rows_(0),
+ buffered_row_group_(buffered_row_group),
+ file_encryptor_(file_encryptor) {
+ if (buffered_row_group) {
+ InitColumns();
+ } else {
+ column_writers_.push_back(nullptr);
+ }
+ }
+
+ int num_columns() const override { return metadata_->num_columns(); }
+
+ int64_t num_rows() const override {
+ CheckRowsWritten();
+ // CheckRowsWritten ensures num_rows_ is set correctly
+ return num_rows_;
+ }
+
+ ColumnWriter* NextColumn() override {
+ if (buffered_row_group_) {
+ throw ParquetException(
+ "NextColumn() is not supported when a RowGroup is written by size");
+ }
+
+ if (column_writers_[0]) {
+ CheckRowsWritten();
+ }
+
+ // Throws an error if more columns are being written
+ auto col_meta = metadata_->NextColumnChunk();
+
+ if (column_writers_[0]) {
+ total_bytes_written_ += column_writers_[0]->Close();
+ }
+
+ ++next_column_index_;
+
+ const auto& path = col_meta->descr()->path();
+ auto meta_encryptor =
+ file_encryptor_ ? file_encryptor_->GetColumnMetaEncryptor(path->ToDotString())
+ : nullptr;
+ auto data_encryptor =
+ file_encryptor_ ? file_encryptor_->GetColumnDataEncryptor(path->ToDotString())
+ : nullptr;
+ std::unique_ptr<PageWriter> pager = PageWriter::Open(
+ sink_, properties_->compression(path), properties_->compression_level(path),
+ col_meta, row_group_ordinal_, static_cast<int16_t>(next_column_index_ - 1),
+ properties_->memory_pool(), false, meta_encryptor, data_encryptor);
+ column_writers_[0] = ColumnWriter::Make(col_meta, std::move(pager), properties_);
+ return column_writers_[0].get();
+ }
+
+ ColumnWriter* column(int i) override {
+ if (!buffered_row_group_) {
+ throw ParquetException(
+ "column() is only supported when a BufferedRowGroup is being written");
+ }
+
+ if (i >= 0 && i < static_cast<int>(column_writers_.size())) {
+ return column_writers_[i].get();
+ }
+ return nullptr;
+ }
+
+ int current_column() const override { return metadata_->current_column(); }
+
+ int64_t total_compressed_bytes() const override {
+ int64_t total_compressed_bytes = 0;
+ for (size_t i = 0; i < column_writers_.size(); i++) {
+ if (column_writers_[i]) {
+ total_compressed_bytes += column_writers_[i]->total_compressed_bytes();
+ }
+ }
+ return total_compressed_bytes;
+ }
+
+ int64_t total_bytes_written() const override {
+ int64_t total_bytes_written = 0;
+ for (size_t i = 0; i < column_writers_.size(); i++) {
+ if (column_writers_[i]) {
+ total_bytes_written += column_writers_[i]->total_bytes_written();
+ }
+ }
+ return total_bytes_written;
+ }
+
+ void Close() override {
+ if (!closed_) {
+ closed_ = true;
+ CheckRowsWritten();
+
+ for (size_t i = 0; i < column_writers_.size(); i++) {
+ if (column_writers_[i]) {
+ total_bytes_written_ += column_writers_[i]->Close();
+ column_writers_[i].reset();
+ }
+ }
+
+ column_writers_.clear();
+
+ // Ensures all columns have been written
+ metadata_->set_num_rows(num_rows_);
+ metadata_->Finish(total_bytes_written_, row_group_ordinal_);
+ }
+ }
+
+ private:
+ std::shared_ptr<ArrowOutputStream> sink_;
+ mutable RowGroupMetaDataBuilder* metadata_;
+ const WriterProperties* properties_;
+ int64_t total_bytes_written_;
+ bool closed_;
+ int16_t row_group_ordinal_;
+ int next_column_index_;
+ mutable int64_t num_rows_;
+ bool buffered_row_group_;
+ InternalFileEncryptor* file_encryptor_;
+
+ void CheckRowsWritten() const {
+ // verify when only one column is written at a time
+ if (!buffered_row_group_ && column_writers_.size() > 0 && column_writers_[0]) {
+ int64_t current_col_rows = column_writers_[0]->rows_written();
+ if (num_rows_ == 0) {
+ num_rows_ = current_col_rows;
+ } else if (num_rows_ != current_col_rows) {
+ ThrowRowsMisMatchError(next_column_index_, current_col_rows, num_rows_);
+ }
+ } else if (buffered_row_group_ &&
+ column_writers_.size() > 0) { // when buffered_row_group = true
+ int64_t current_col_rows = column_writers_[0]->rows_written();
+ for (int i = 1; i < static_cast<int>(column_writers_.size()); i++) {
+ int64_t current_col_rows_i = column_writers_[i]->rows_written();
+ if (current_col_rows != current_col_rows_i) {
+ ThrowRowsMisMatchError(i, current_col_rows_i, current_col_rows);
+ }
+ }
+ num_rows_ = current_col_rows;
+ }
+ }
+
+ void InitColumns() {
+ for (int i = 0; i < num_columns(); i++) {
+ auto col_meta = metadata_->NextColumnChunk();
+ const auto& path = col_meta->descr()->path();
+ auto meta_encryptor =
+ file_encryptor_ ? file_encryptor_->GetColumnMetaEncryptor(path->ToDotString())
+ : nullptr;
+ auto data_encryptor =
+ file_encryptor_ ? file_encryptor_->GetColumnDataEncryptor(path->ToDotString())
+ : nullptr;
+ std::unique_ptr<PageWriter> pager = PageWriter::Open(
+ sink_, properties_->compression(path), properties_->compression_level(path),
+ col_meta, static_cast<int16_t>(row_group_ordinal_),
+ static_cast<int16_t>(next_column_index_++), properties_->memory_pool(),
+ buffered_row_group_, meta_encryptor, data_encryptor);
+ column_writers_.push_back(
+ ColumnWriter::Make(col_meta, std::move(pager), properties_));
+ }
+ }
+
+ std::vector<std::shared_ptr<ColumnWriter>> column_writers_;
+};
+
+// ----------------------------------------------------------------------
+// FileSerializer
+
+// An implementation of ParquetFileWriter::Contents that deals with the Parquet
+// file structure, Thrift serialization, and other internal matters
+
+class FileSerializer : public ParquetFileWriter::Contents {
+ public:
+ static std::unique_ptr<ParquetFileWriter::Contents> Open(
+ std::shared_ptr<ArrowOutputStream> sink, std::shared_ptr<GroupNode> schema,
+ std::shared_ptr<WriterProperties> properties,
+ std::shared_ptr<const KeyValueMetadata> key_value_metadata) {
+ std::unique_ptr<ParquetFileWriter::Contents> result(
+ new FileSerializer(std::move(sink), std::move(schema), std::move(properties),
+ std::move(key_value_metadata)));
+
+ return result;
+ }
+
+ void Close() override {
+ if (is_open_) {
+ // If any functions here raise an exception, we set is_open_ to be false
+ // so that this does not get called again (possibly causing segfault)
+ is_open_ = false;
+ if (row_group_writer_) {
+ num_rows_ += row_group_writer_->num_rows();
+ row_group_writer_->Close();
+ }
+ row_group_writer_.reset();
+
+ // Write magic bytes and metadata
+ auto file_encryption_properties = properties_->file_encryption_properties();
+
+ if (file_encryption_properties == nullptr) { // Non encrypted file.
+ file_metadata_ = metadata_->Finish();
+ WriteFileMetaData(*file_metadata_, sink_.get());
+ } else { // Encrypted file
+ CloseEncryptedFile(file_encryption_properties);
+ }
+ }
+ }
+
+ int num_columns() const override { return schema_.num_columns(); }
+
+ int num_row_groups() const override { return num_row_groups_; }
+
+ int64_t num_rows() const override { return num_rows_; }
+
+ const std::shared_ptr<WriterProperties>& properties() const override {
+ return properties_;
+ }
+
+ RowGroupWriter* AppendRowGroup(bool buffered_row_group) {
+ if (row_group_writer_) {
+ row_group_writer_->Close();
+ }
+ num_row_groups_++;
+ auto rg_metadata = metadata_->AppendRowGroup();
+ std::unique_ptr<RowGroupWriter::Contents> contents(new RowGroupSerializer(
+ sink_, rg_metadata, static_cast<int16_t>(num_row_groups_ - 1), properties_.get(),
+ buffered_row_group, file_encryptor_.get()));
+ row_group_writer_.reset(new RowGroupWriter(std::move(contents)));
+ return row_group_writer_.get();
+ }
+
+ RowGroupWriter* AppendRowGroup() override { return AppendRowGroup(false); }
+
+ RowGroupWriter* AppendBufferedRowGroup() override { return AppendRowGroup(true); }
+
+ ~FileSerializer() override {
+ try {
+ Close();
+ } catch (...) {
+ }
+ }
+
+ private:
+ FileSerializer(std::shared_ptr<ArrowOutputStream> sink,
+ std::shared_ptr<GroupNode> schema,
+ std::shared_ptr<WriterProperties> properties,
+ std::shared_ptr<const KeyValueMetadata> key_value_metadata)
+ : ParquetFileWriter::Contents(std::move(schema), std::move(key_value_metadata)),
+ sink_(std::move(sink)),
+ is_open_(true),
+ properties_(std::move(properties)),
+ num_row_groups_(0),
+ num_rows_(0),
+ metadata_(FileMetaDataBuilder::Make(&schema_, properties_, key_value_metadata_)) {
+ PARQUET_ASSIGN_OR_THROW(int64_t position, sink_->Tell());
+ if (position == 0) {
+ StartFile();
+ } else {
+ throw ParquetException("Appending to file not implemented.");
+ }
+ }
+
+ void CloseEncryptedFile(FileEncryptionProperties* file_encryption_properties) {
+ // Encrypted file with encrypted footer
+ if (file_encryption_properties->encrypted_footer()) {
+ // encrypted footer
+ file_metadata_ = metadata_->Finish();
+
+ PARQUET_ASSIGN_OR_THROW(int64_t position, sink_->Tell());
+ uint64_t metadata_start = static_cast<uint64_t>(position);
+ auto crypto_metadata = metadata_->GetCryptoMetaData();
+ WriteFileCryptoMetaData(*crypto_metadata, sink_.get());
+
+ auto footer_encryptor = file_encryptor_->GetFooterEncryptor();
+ WriteEncryptedFileMetadata(*file_metadata_, sink_.get(), footer_encryptor, true);
+ PARQUET_ASSIGN_OR_THROW(position, sink_->Tell());
+ uint32_t footer_and_crypto_len = static_cast<uint32_t>(position - metadata_start);
+ PARQUET_THROW_NOT_OK(
+ sink_->Write(reinterpret_cast<uint8_t*>(&footer_and_crypto_len), 4));
+ PARQUET_THROW_NOT_OK(sink_->Write(kParquetEMagic, 4));
+ } else { // Encrypted file with plaintext footer
+ file_metadata_ = metadata_->Finish();
+ auto footer_signing_encryptor = file_encryptor_->GetFooterSigningEncryptor();
+ WriteEncryptedFileMetadata(*file_metadata_, sink_.get(), footer_signing_encryptor,
+ false);
+ }
+ if (file_encryptor_) {
+ file_encryptor_->WipeOutEncryptionKeys();
+ }
+ }
+
+ std::shared_ptr<ArrowOutputStream> sink_;
+ bool is_open_;
+ const std::shared_ptr<WriterProperties> properties_;
+ int num_row_groups_;
+ int64_t num_rows_;
+ std::unique_ptr<FileMetaDataBuilder> metadata_;
+ // Only one of the row group writers is active at a time
+ std::unique_ptr<RowGroupWriter> row_group_writer_;
+
+ std::unique_ptr<InternalFileEncryptor> file_encryptor_;
+
+ void StartFile() {
+ auto file_encryption_properties = properties_->file_encryption_properties();
+ if (file_encryption_properties == nullptr) {
+ // Unencrypted parquet files always start with PAR1
+ PARQUET_THROW_NOT_OK(sink_->Write(kParquetMagic, 4));
+ } else {
+ // Check that all columns in columnEncryptionProperties exist in the schema.
+ auto encrypted_columns = file_encryption_properties->encrypted_columns();
+ // if columnEncryptionProperties is empty, every column in file schema will be
+ // encrypted with footer key.
+ if (encrypted_columns.size() != 0) {
+ std::vector<std::string> column_path_vec;
+ // First, save all column paths in schema.
+ for (int i = 0; i < num_columns(); i++) {
+ column_path_vec.push_back(schema_.Column(i)->path()->ToDotString());
+ }
+ // Check if column exists in schema.
+ for (const auto& elem : encrypted_columns) {
+ auto it = std::find(column_path_vec.begin(), column_path_vec.end(), elem.first);
+ if (it == column_path_vec.end()) {
+ std::stringstream ss;
+ ss << "Encrypted column " + elem.first + " not in file schema";
+ throw ParquetException(ss.str());
+ }
+ }
+ }
+
+ file_encryptor_.reset(new InternalFileEncryptor(file_encryption_properties,
+ properties_->memory_pool()));
+ if (file_encryption_properties->encrypted_footer()) {
+ PARQUET_THROW_NOT_OK(sink_->Write(kParquetEMagic, 4));
+ } else {
+ // Encrypted file with plaintext footer mode.
+ PARQUET_THROW_NOT_OK(sink_->Write(kParquetMagic, 4));
+ }
+ }
+ }
+};
+
+// ----------------------------------------------------------------------
+// ParquetFileWriter public API
+
+ParquetFileWriter::ParquetFileWriter() {}
+
+ParquetFileWriter::~ParquetFileWriter() {
+ try {
+ Close();
+ } catch (...) {
+ }
+}
+
+std::unique_ptr<ParquetFileWriter> ParquetFileWriter::Open(
+ std::shared_ptr<::arrow::io::OutputStream> sink, std::shared_ptr<GroupNode> schema,
+ std::shared_ptr<WriterProperties> properties,
+ std::shared_ptr<const KeyValueMetadata> key_value_metadata) {
+ auto contents =
+ FileSerializer::Open(std::move(sink), std::move(schema), std::move(properties),
+ std::move(key_value_metadata));
+ std::unique_ptr<ParquetFileWriter> result(new ParquetFileWriter());
+ result->Open(std::move(contents));
+ return result;
+}
+
+void WriteFileMetaData(const FileMetaData& file_metadata, ArrowOutputStream* sink) {
+ // Write MetaData
+ PARQUET_ASSIGN_OR_THROW(int64_t position, sink->Tell());
+ uint32_t metadata_len = static_cast<uint32_t>(position);
+
+ file_metadata.WriteTo(sink);
+ PARQUET_ASSIGN_OR_THROW(position, sink->Tell());
+ metadata_len = static_cast<uint32_t>(position) - metadata_len;
+
+ // Write Footer
+ PARQUET_THROW_NOT_OK(sink->Write(reinterpret_cast<uint8_t*>(&metadata_len), 4));
+ PARQUET_THROW_NOT_OK(sink->Write(kParquetMagic, 4));
+}
+
+void WriteMetaDataFile(const FileMetaData& file_metadata, ArrowOutputStream* sink) {
+ PARQUET_THROW_NOT_OK(sink->Write(kParquetMagic, 4));
+ return WriteFileMetaData(file_metadata, sink);
+}
+
+void WriteEncryptedFileMetadata(const FileMetaData& file_metadata,
+ ArrowOutputStream* sink,
+ const std::shared_ptr<Encryptor>& encryptor,
+ bool encrypt_footer) {
+ if (encrypt_footer) { // Encrypted file with encrypted footer
+ // encrypt and write to sink
+ file_metadata.WriteTo(sink, encryptor);
+ } else { // Encrypted file with plaintext footer mode.
+ PARQUET_ASSIGN_OR_THROW(int64_t position, sink->Tell());
+ uint32_t metadata_len = static_cast<uint32_t>(position);
+ file_metadata.WriteTo(sink, encryptor);
+ PARQUET_ASSIGN_OR_THROW(position, sink->Tell());
+ metadata_len = static_cast<uint32_t>(position) - metadata_len;
+
+ PARQUET_THROW_NOT_OK(sink->Write(reinterpret_cast<uint8_t*>(&metadata_len), 4));
+ PARQUET_THROW_NOT_OK(sink->Write(kParquetMagic, 4));
+ }
+}
+
+void WriteFileCryptoMetaData(const FileCryptoMetaData& crypto_metadata,
+ ArrowOutputStream* sink) {
+ crypto_metadata.WriteTo(sink);
+}
+
+const SchemaDescriptor* ParquetFileWriter::schema() const { return contents_->schema(); }
+
+const ColumnDescriptor* ParquetFileWriter::descr(int i) const {
+ return contents_->schema()->Column(i);
+}
+
+int ParquetFileWriter::num_columns() const { return contents_->num_columns(); }
+
+int64_t ParquetFileWriter::num_rows() const { return contents_->num_rows(); }
+
+int ParquetFileWriter::num_row_groups() const { return contents_->num_row_groups(); }
+
+const std::shared_ptr<const KeyValueMetadata>& ParquetFileWriter::key_value_metadata()
+ const {
+ return contents_->key_value_metadata();
+}
+
+const std::shared_ptr<FileMetaData> ParquetFileWriter::metadata() const {
+ return file_metadata_;
+}
+
+void ParquetFileWriter::Open(std::unique_ptr<ParquetFileWriter::Contents> contents) {
+ contents_ = std::move(contents);
+}
+
+void ParquetFileWriter::Close() {
+ if (contents_) {
+ contents_->Close();
+ file_metadata_ = contents_->metadata();
+ contents_.reset();
+ }
+}
+
+RowGroupWriter* ParquetFileWriter::AppendRowGroup() {
+ return contents_->AppendRowGroup();
+}
+
+RowGroupWriter* ParquetFileWriter::AppendBufferedRowGroup() {
+ return contents_->AppendBufferedRowGroup();
+}
+
+RowGroupWriter* ParquetFileWriter::AppendRowGroup(int64_t num_rows) {
+ return AppendRowGroup();
+}
+
+const std::shared_ptr<WriterProperties>& ParquetFileWriter::properties() const {
+ return contents_->properties();
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/file_writer.h b/src/arrow/cpp/src/parquet/file_writer.h
new file mode 100644
index 000000000..4cfc24719
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/file_writer.h
@@ -0,0 +1,234 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "parquet/metadata.h"
+#include "parquet/platform.h"
+#include "parquet/properties.h"
+#include "parquet/schema.h"
+
+namespace parquet {
+
+class ColumnWriter;
+
+// FIXME: copied from reader-internal.cc
+static constexpr uint8_t kParquetMagic[4] = {'P', 'A', 'R', '1'};
+static constexpr uint8_t kParquetEMagic[4] = {'P', 'A', 'R', 'E'};
+
+class PARQUET_EXPORT RowGroupWriter {
+ public:
+ // Forward declare a virtual class 'Contents' to aid dependency injection and more
+ // easily create test fixtures
+ // An implementation of the Contents class is defined in the .cc file
+ struct Contents {
+ virtual ~Contents() = default;
+ virtual int num_columns() const = 0;
+ virtual int64_t num_rows() const = 0;
+
+ // to be used only with ParquetFileWriter::AppendRowGroup
+ virtual ColumnWriter* NextColumn() = 0;
+ // to be used only with ParquetFileWriter::AppendBufferedRowGroup
+ virtual ColumnWriter* column(int i) = 0;
+
+ virtual int current_column() const = 0;
+ virtual void Close() = 0;
+
+ // total bytes written by the page writer
+ virtual int64_t total_bytes_written() const = 0;
+ // total bytes still compressed but not written
+ virtual int64_t total_compressed_bytes() const = 0;
+ };
+
+ explicit RowGroupWriter(std::unique_ptr<Contents> contents);
+
+ /// Construct a ColumnWriter for the indicated row group-relative column.
+ ///
+ /// To be used only with ParquetFileWriter::AppendRowGroup
+ /// Ownership is solely within the RowGroupWriter. The ColumnWriter is only
+ /// valid until the next call to NextColumn or Close. As the contents are
+ /// directly written to the sink, once a new column is started, the contents
+ /// of the previous one cannot be modified anymore.
+ ColumnWriter* NextColumn();
+ /// Index of currently written column. Equal to -1 if NextColumn()
+ /// has not been called yet.
+ int current_column();
+ void Close();
+
+ int num_columns() const;
+
+ /// Construct a ColumnWriter for the indicated row group column.
+ ///
+ /// To be used only with ParquetFileWriter::AppendBufferedRowGroup
+ /// Ownership is solely within the RowGroupWriter. The ColumnWriter is
+ /// valid until Close. The contents are buffered in memory and written to sink
+ /// on Close
+ ColumnWriter* column(int i);
+
+ /**
+ * Number of rows that shall be written as part of this RowGroup.
+ */
+ int64_t num_rows() const;
+
+ int64_t total_bytes_written() const;
+ int64_t total_compressed_bytes() const;
+
+ private:
+ // Holds a pointer to an instance of Contents implementation
+ std::unique_ptr<Contents> contents_;
+};
+
+PARQUET_EXPORT
+void WriteFileMetaData(const FileMetaData& file_metadata,
+ ::arrow::io::OutputStream* sink);
+
+PARQUET_EXPORT
+void WriteMetaDataFile(const FileMetaData& file_metadata,
+ ::arrow::io::OutputStream* sink);
+
+PARQUET_EXPORT
+void WriteEncryptedFileMetadata(const FileMetaData& file_metadata,
+ ArrowOutputStream* sink,
+ const std::shared_ptr<Encryptor>& encryptor,
+ bool encrypt_footer);
+
+PARQUET_EXPORT
+void WriteEncryptedFileMetadata(const FileMetaData& file_metadata,
+ ::arrow::io::OutputStream* sink,
+ const std::shared_ptr<Encryptor>& encryptor = NULLPTR,
+ bool encrypt_footer = false);
+PARQUET_EXPORT
+void WriteFileCryptoMetaData(const FileCryptoMetaData& crypto_metadata,
+ ::arrow::io::OutputStream* sink);
+
+class PARQUET_EXPORT ParquetFileWriter {
+ public:
+ // Forward declare a virtual class 'Contents' to aid dependency injection and more
+ // easily create test fixtures
+ // An implementation of the Contents class is defined in the .cc file
+ struct Contents {
+ Contents(std::shared_ptr<::parquet::schema::GroupNode> schema,
+ std::shared_ptr<const KeyValueMetadata> key_value_metadata)
+ : schema_(), key_value_metadata_(std::move(key_value_metadata)) {
+ schema_.Init(std::move(schema));
+ }
+ virtual ~Contents() {}
+ // Perform any cleanup associated with the file contents
+ virtual void Close() = 0;
+
+ /// \note Deprecated since 1.3.0
+ RowGroupWriter* AppendRowGroup(int64_t num_rows);
+
+ virtual RowGroupWriter* AppendRowGroup() = 0;
+ virtual RowGroupWriter* AppendBufferedRowGroup() = 0;
+
+ virtual int64_t num_rows() const = 0;
+ virtual int num_columns() const = 0;
+ virtual int num_row_groups() const = 0;
+
+ virtual const std::shared_ptr<WriterProperties>& properties() const = 0;
+
+ const std::shared_ptr<const KeyValueMetadata>& key_value_metadata() const {
+ return key_value_metadata_;
+ }
+
+ // Return const-pointer to make it clear that this object is not to be copied
+ const SchemaDescriptor* schema() const { return &schema_; }
+
+ SchemaDescriptor schema_;
+
+ /// This should be the only place this is stored. Everything else is a const reference
+ std::shared_ptr<const KeyValueMetadata> key_value_metadata_;
+
+ const std::shared_ptr<FileMetaData>& metadata() const { return file_metadata_; }
+ std::shared_ptr<FileMetaData> file_metadata_;
+ };
+
+ ParquetFileWriter();
+ ~ParquetFileWriter();
+
+ static std::unique_ptr<ParquetFileWriter> Open(
+ std::shared_ptr<::arrow::io::OutputStream> sink,
+ std::shared_ptr<schema::GroupNode> schema,
+ std::shared_ptr<WriterProperties> properties = default_writer_properties(),
+ std::shared_ptr<const KeyValueMetadata> key_value_metadata = NULLPTR);
+
+ void Open(std::unique_ptr<Contents> contents);
+ void Close();
+
+ // Construct a RowGroupWriter for the indicated number of rows.
+ //
+ // Ownership is solely within the ParquetFileWriter. The RowGroupWriter is only valid
+ // until the next call to AppendRowGroup or AppendBufferedRowGroup or Close.
+ // @param num_rows The number of rows that are stored in the new RowGroup
+ //
+ // \deprecated Since 1.3.0
+ RowGroupWriter* AppendRowGroup(int64_t num_rows);
+
+ /// Construct a RowGroupWriter with an arbitrary number of rows.
+ ///
+ /// Ownership is solely within the ParquetFileWriter. The RowGroupWriter is only valid
+ /// until the next call to AppendRowGroup or AppendBufferedRowGroup or Close.
+ RowGroupWriter* AppendRowGroup();
+
+ /// Construct a RowGroupWriter that buffers all the values until the RowGroup is ready.
+ /// Use this if you want to write a RowGroup based on a certain size
+ ///
+ /// Ownership is solely within the ParquetFileWriter. The RowGroupWriter is only valid
+ /// until the next call to AppendRowGroup or AppendBufferedRowGroup or Close.
+ RowGroupWriter* AppendBufferedRowGroup();
+
+ /// Number of columns.
+ ///
+ /// This number is fixed during the lifetime of the writer as it is determined via
+ /// the schema.
+ int num_columns() const;
+
+ /// Number of rows in the yet started RowGroups.
+ ///
+ /// Changes on the addition of a new RowGroup.
+ int64_t num_rows() const;
+
+ /// Number of started RowGroups.
+ int num_row_groups() const;
+
+ /// Configuration passed to the writer, e.g. the used Parquet format version.
+ const std::shared_ptr<WriterProperties>& properties() const;
+
+ /// Returns the file schema descriptor
+ const SchemaDescriptor* schema() const;
+
+ /// Returns a column descriptor in schema
+ const ColumnDescriptor* descr(int i) const;
+
+ /// Returns the file custom metadata
+ const std::shared_ptr<const KeyValueMetadata>& key_value_metadata() const;
+
+ /// Returns the file metadata, only available after calling Close().
+ const std::shared_ptr<FileMetaData> metadata() const;
+
+ private:
+ // Holds a pointer to an instance of Contents implementation
+ std::unique_ptr<Contents> contents_;
+ std::shared_ptr<FileMetaData> file_metadata_;
+};
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/hasher.h b/src/arrow/cpp/src/parquet/hasher.h
new file mode 100644
index 000000000..d699356a6
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/hasher.h
@@ -0,0 +1,72 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include "parquet/types.h"
+
+namespace parquet {
+// Abstract class for hash
+class Hasher {
+ public:
+ /// Compute hash for 32 bits value by using its plain encoding result.
+ ///
+ /// @param value the value to hash.
+ /// @return hash result.
+ virtual uint64_t Hash(int32_t value) const = 0;
+
+ /// Compute hash for 64 bits value by using its plain encoding result.
+ ///
+ /// @param value the value to hash.
+ /// @return hash result.
+ virtual uint64_t Hash(int64_t value) const = 0;
+
+ /// Compute hash for float value by using its plain encoding result.
+ ///
+ /// @param value the value to hash.
+ /// @return hash result.
+ virtual uint64_t Hash(float value) const = 0;
+
+ /// Compute hash for double value by using its plain encoding result.
+ ///
+ /// @param value the value to hash.
+ /// @return hash result.
+ virtual uint64_t Hash(double value) const = 0;
+
+ /// Compute hash for Int96 value by using its plain encoding result.
+ ///
+ /// @param value the value to hash.
+ /// @return hash result.
+ virtual uint64_t Hash(const Int96* value) const = 0;
+
+ /// Compute hash for ByteArray value by using its plain encoding result.
+ ///
+ /// @param value the value to hash.
+ /// @return hash result.
+ virtual uint64_t Hash(const ByteArray* value) const = 0;
+
+ /// Compute hash for fixed byte array value by using its plain encoding result.
+ ///
+ /// @param value the value address.
+ /// @param len the value length.
+ virtual uint64_t Hash(const FLBA* value, uint32_t len) const = 0;
+
+ virtual ~Hasher() = default;
+};
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/level_comparison.cc b/src/arrow/cpp/src/parquet/level_comparison.cc
new file mode 100644
index 000000000..30614ae61
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/level_comparison.cc
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/level_comparison.h"
+
+#define PARQUET_IMPL_NAMESPACE standard
+#include "parquet/level_comparison_inc.h"
+#undef PARQUET_IMPL_NAMESPACE
+
+#include <vector>
+
+#include "arrow/util/dispatch.h"
+
+namespace parquet {
+namespace internal {
+
+#if defined(ARROW_HAVE_RUNTIME_AVX2)
+MinMax FindMinMaxAvx2(const int16_t* levels, int64_t num_levels);
+uint64_t GreaterThanBitmapAvx2(const int16_t* levels, int64_t num_levels, int16_t rhs);
+#endif
+
+namespace {
+
+using ::arrow::internal::DispatchLevel;
+using ::arrow::internal::DynamicDispatch;
+
+// defined in level_comparison_avx2.cc
+
+struct GreaterThanDynamicFunction {
+ using FunctionType = decltype(&GreaterThanBitmap);
+
+ static std::vector<std::pair<DispatchLevel, FunctionType>> implementations() {
+ return {
+ { DispatchLevel::NONE, standard::GreaterThanBitmapImpl }
+#if defined(ARROW_HAVE_RUNTIME_AVX2)
+ , { DispatchLevel::AVX2, GreaterThanBitmapAvx2 }
+#endif
+ };
+ }
+};
+
+struct MinMaxDynamicFunction {
+ using FunctionType = decltype(&FindMinMax);
+
+ static std::vector<std::pair<DispatchLevel, FunctionType>> implementations() {
+ return {
+ { DispatchLevel::NONE, standard::FindMinMaxImpl }
+#if defined(ARROW_HAVE_RUNTIME_AVX2)
+ , { DispatchLevel::AVX2, FindMinMaxAvx2 }
+#endif
+ };
+ }
+};
+
+} // namespace
+
+uint64_t GreaterThanBitmap(const int16_t* levels, int64_t num_levels, int16_t rhs) {
+ static DynamicDispatch<GreaterThanDynamicFunction> dispatch;
+ return dispatch.func(levels, num_levels, rhs);
+}
+
+MinMax FindMinMax(const int16_t* levels, int64_t num_levels) {
+ static DynamicDispatch<MinMaxDynamicFunction> dispatch;
+ return dispatch.func(levels, num_levels);
+}
+
+} // namespace internal
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/level_comparison.h b/src/arrow/cpp/src/parquet/level_comparison.h
new file mode 100644
index 000000000..38e7ef8e2
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/level_comparison.h
@@ -0,0 +1,40 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+
+#include "parquet/platform.h"
+
+namespace parquet {
+namespace internal {
+
+/// Builds a bitmap where each set bit indicates the corresponding level is greater
+/// than rhs.
+uint64_t PARQUET_EXPORT GreaterThanBitmap(const int16_t* levels, int64_t num_levels,
+ int16_t rhs);
+
+struct MinMax {
+ int16_t min;
+ int16_t max;
+};
+
+MinMax FindMinMax(const int16_t* levels, int64_t num_levels);
+
+} // namespace internal
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/level_comparison_avx2.cc b/src/arrow/cpp/src/parquet/level_comparison_avx2.cc
new file mode 100644
index 000000000..b33eb2e29
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/level_comparison_avx2.cc
@@ -0,0 +1,34 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#define PARQUET_IMPL_NAMESPACE avx2
+#include "parquet/level_comparison_inc.h"
+#undef PARQUET_IMPL_NAMESPACE
+
+namespace parquet {
+namespace internal {
+
+uint64_t GreaterThanBitmapAvx2(const int16_t* levels, int64_t num_levels, int16_t rhs) {
+ return avx2::GreaterThanBitmapImpl(levels, num_levels, rhs);
+}
+
+MinMax FindMinMaxAvx2(const int16_t* levels, int64_t num_levels) {
+ return avx2::FindMinMaxImpl(levels, num_levels);
+}
+
+} // namespace internal
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/level_comparison_inc.h b/src/arrow/cpp/src/parquet/level_comparison_inc.h
new file mode 100644
index 000000000..e21c3e582
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/level_comparison_inc.h
@@ -0,0 +1,65 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#pragma once
+
+#include "arrow/util/bit_util.h"
+#include "arrow/util/endian.h"
+#include "parquet/level_comparison.h"
+
+// Used to make sure ODR rule isn't violated.
+#ifndef PARQUET_IMPL_NAMESPACE
+#error "PARQUET_IMPL_NAMESPACE must be defined"
+#endif
+namespace parquet {
+namespace internal {
+namespace PARQUET_IMPL_NAMESPACE {
+/// Builds a bitmap by applying predicate to the level vector provided.
+///
+/// \param[in] levels Rep or def level array.
+/// \param[in] num_levels The number of levels to process (must be [0, 64])
+/// \param[in] predicate The predicate to apply (must have the signature `bool
+/// predicate(int16_t)`.
+/// \returns The bitmap using least significant "bit" ordering.
+///
+template <typename Predicate>
+inline uint64_t LevelsToBitmap(const int16_t* levels, int64_t num_levels,
+ Predicate predicate) {
+ // Both clang and GCC can vectorize this automatically with SSE4/AVX2.
+ uint64_t mask = 0;
+ for (int x = 0; x < num_levels; x++) {
+ mask |= static_cast<uint64_t>(predicate(levels[x]) ? 1 : 0) << x;
+ }
+ return ::arrow::BitUtil::ToLittleEndian(mask);
+}
+
+inline MinMax FindMinMaxImpl(const int16_t* levels, int64_t num_levels) {
+ MinMax out{std::numeric_limits<int16_t>::max(), std::numeric_limits<int16_t>::min()};
+ for (int x = 0; x < num_levels; x++) {
+ out.min = std::min(levels[x], out.min);
+ out.max = std::max(levels[x], out.max);
+ }
+ return out;
+}
+
+inline uint64_t GreaterThanBitmapImpl(const int16_t* levels, int64_t num_levels,
+ int16_t rhs) {
+ return LevelsToBitmap(levels, num_levels, [rhs](int16_t value) { return value > rhs; });
+}
+
+} // namespace PARQUET_IMPL_NAMESPACE
+} // namespace internal
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/level_conversion.cc b/src/arrow/cpp/src/parquet/level_conversion.cc
new file mode 100644
index 000000000..ffdca476d
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/level_conversion.cc
@@ -0,0 +1,183 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#include "parquet/level_conversion.h"
+
+#include <algorithm>
+#include <limits>
+
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/cpu_info.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/optional.h"
+#include "parquet/exception.h"
+
+#include "parquet/level_comparison.h"
+#define PARQUET_IMPL_NAMESPACE standard
+#include "parquet/level_conversion_inc.h"
+#undef PARQUET_IMPL_NAMESPACE
+
+namespace parquet {
+namespace internal {
+namespace {
+
+using ::arrow::internal::CpuInfo;
+using ::arrow::util::optional;
+
+template <typename OffsetType>
+void DefRepLevelsToListInfo(const int16_t* def_levels, const int16_t* rep_levels,
+ int64_t num_def_levels, LevelInfo level_info,
+ ValidityBitmapInputOutput* output, OffsetType* offsets) {
+ OffsetType* orig_pos = offsets;
+ optional<::arrow::internal::FirstTimeBitmapWriter> valid_bits_writer;
+ if (output->valid_bits) {
+ valid_bits_writer.emplace(output->valid_bits, output->valid_bits_offset,
+ output->values_read_upper_bound);
+ }
+ for (int x = 0; x < num_def_levels; x++) {
+ // Skip items that belong to empty or null ancestor lists and further nested lists.
+ if (def_levels[x] < level_info.repeated_ancestor_def_level ||
+ rep_levels[x] > level_info.rep_level) {
+ continue;
+ }
+
+ if (rep_levels[x] == level_info.rep_level) {
+ // A continuation of an existing list.
+ // offsets can be null for structs with repeated children (we don't need to know
+ // offsets until we get to the children).
+ if (offsets != nullptr) {
+ if (ARROW_PREDICT_FALSE(*offsets == std::numeric_limits<OffsetType>::max())) {
+ throw ParquetException("List index overflow.");
+ }
+ *offsets += 1;
+ }
+ } else {
+ if (ARROW_PREDICT_FALSE(
+ (valid_bits_writer.has_value() &&
+ valid_bits_writer->position() >= output->values_read_upper_bound) ||
+ (offsets - orig_pos) >= output->values_read_upper_bound)) {
+ std::stringstream ss;
+ ss << "Definition levels exceeded upper bound: "
+ << output->values_read_upper_bound;
+ throw ParquetException(ss.str());
+ }
+
+ // current_rep < list rep_level i.e. start of a list (ancestor empty lists are
+ // filtered out above).
+ // offsets can be null for structs with repeated children (we don't need to know
+ // offsets until we get to the children).
+ if (offsets != nullptr) {
+ ++offsets;
+ // Use cumulative offsets because variable size lists are more common then
+ // fixed size lists so it should be cheaper to make these cumulative and
+ // subtract when validating fixed size lists.
+ *offsets = *(offsets - 1);
+ if (def_levels[x] >= level_info.def_level) {
+ if (ARROW_PREDICT_FALSE(*offsets == std::numeric_limits<OffsetType>::max())) {
+ throw ParquetException("List index overflow.");
+ }
+ *offsets += 1;
+ }
+ }
+
+ if (valid_bits_writer.has_value()) {
+ // the level_info def level for lists reflects element present level.
+ // the prior level distinguishes between empty lists.
+ if (def_levels[x] >= level_info.def_level - 1) {
+ valid_bits_writer->Set();
+ } else {
+ output->null_count++;
+ valid_bits_writer->Clear();
+ }
+ valid_bits_writer->Next();
+ }
+ }
+ }
+ if (valid_bits_writer.has_value()) {
+ valid_bits_writer->Finish();
+ }
+ if (offsets != nullptr) {
+ output->values_read = offsets - orig_pos;
+ } else if (valid_bits_writer.has_value()) {
+ output->values_read = valid_bits_writer->position();
+ }
+ if (output->null_count > 0 && level_info.null_slot_usage > 1) {
+ throw ParquetException(
+ "Null values with null_slot_usage > 1 not supported."
+ "(i.e. FixedSizeLists with null values are not supported)");
+ }
+}
+
+} // namespace
+
+#if defined(ARROW_HAVE_RUNTIME_BMI2)
+// defined in level_conversion_bmi2.cc for dynamic dispatch.
+void DefLevelsToBitmapBmi2WithRepeatedParent(const int16_t* def_levels,
+ int64_t num_def_levels, LevelInfo level_info,
+ ValidityBitmapInputOutput* output);
+#endif
+
+void DefLevelsToBitmap(const int16_t* def_levels, int64_t num_def_levels,
+ LevelInfo level_info, ValidityBitmapInputOutput* output) {
+ // It is simpler to rely on rep_level here until PARQUET-1899 is done and the code
+ // is deleted in a follow-up release.
+ if (level_info.rep_level > 0) {
+#if defined(ARROW_HAVE_RUNTIME_BMI2)
+ if (CpuInfo::GetInstance()->HasEfficientBmi2()) {
+ return DefLevelsToBitmapBmi2WithRepeatedParent(def_levels, num_def_levels,
+ level_info, output);
+ }
+#endif
+ standard::DefLevelsToBitmapSimd</*has_repeated_parent=*/true>(
+ def_levels, num_def_levels, level_info, output);
+ } else {
+ standard::DefLevelsToBitmapSimd</*has_repeated_parent=*/false>(
+ def_levels, num_def_levels, level_info, output);
+ }
+}
+
+uint64_t TestOnlyExtractBitsSoftware(uint64_t bitmap, uint64_t select_bitmap) {
+ return standard::ExtractBitsSoftware(bitmap, select_bitmap);
+}
+
+void DefRepLevelsToList(const int16_t* def_levels, const int16_t* rep_levels,
+ int64_t num_def_levels, LevelInfo level_info,
+ ValidityBitmapInputOutput* output, int32_t* offsets) {
+ DefRepLevelsToListInfo<int32_t>(def_levels, rep_levels, num_def_levels, level_info,
+ output, offsets);
+}
+
+void DefRepLevelsToList(const int16_t* def_levels, const int16_t* rep_levels,
+ int64_t num_def_levels, LevelInfo level_info,
+ ValidityBitmapInputOutput* output, int64_t* offsets) {
+ DefRepLevelsToListInfo<int64_t>(def_levels, rep_levels, num_def_levels, level_info,
+ output, offsets);
+}
+
+void DefRepLevelsToBitmap(const int16_t* def_levels, const int16_t* rep_levels,
+ int64_t num_def_levels, LevelInfo level_info,
+ ValidityBitmapInputOutput* output) {
+ // DefReplevelsToListInfo assumes it for the actual list method and this
+ // method is for parent structs, so we need to bump def and ref level.
+ level_info.rep_level += 1;
+ level_info.def_level += 1;
+ DefRepLevelsToListInfo<int32_t>(def_levels, rep_levels, num_def_levels, level_info,
+ output, /*offsets=*/nullptr);
+}
+
+} // namespace internal
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/level_conversion.h b/src/arrow/cpp/src/parquet/level_conversion.h
new file mode 100644
index 000000000..e45a288e8
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/level_conversion.h
@@ -0,0 +1,199 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+
+#include "arrow/util/endian.h"
+#include "parquet/platform.h"
+#include "parquet/schema.h"
+
+namespace parquet {
+namespace internal {
+
+struct PARQUET_EXPORT LevelInfo {
+ LevelInfo()
+ : null_slot_usage(1), def_level(0), rep_level(0), repeated_ancestor_def_level(0) {}
+ LevelInfo(int32_t null_slots, int32_t definition_level, int32_t repetition_level,
+ int32_t repeated_ancestor_definition_level)
+ : null_slot_usage(null_slots),
+ def_level(definition_level),
+ rep_level(repetition_level),
+ repeated_ancestor_def_level(repeated_ancestor_definition_level) {}
+
+ bool operator==(const LevelInfo& b) const {
+ return null_slot_usage == b.null_slot_usage && def_level == b.def_level &&
+ rep_level == b.rep_level &&
+ repeated_ancestor_def_level == b.repeated_ancestor_def_level;
+ }
+
+ bool HasNullableValues() const { return repeated_ancestor_def_level < def_level; }
+
+ // How many slots an undefined but present (i.e. null) element in
+ // parquet consumes when decoding to Arrow.
+ // "Slot" is used in the same context as the Arrow specification
+ // (i.e. a value holder).
+ // This is only ever >1 for descendents of FixedSizeList.
+ int32_t null_slot_usage = 1;
+
+ // The definition level at which the value for the field
+ // is considered not null (definition levels greater than
+ // or equal to this value indicate a not-null
+ // value for the field). For list fields definition levels
+ // greater than or equal to this field indicate a present,
+ // possibly null, child value.
+ int16_t def_level = 0;
+
+ // The repetition level corresponding to this element
+ // or the closest repeated ancestor. Any repetition
+ // level less than this indicates either a new list OR
+ // an empty list (which is determined in conjunction
+ // with definition levels).
+ int16_t rep_level = 0;
+
+ // The definition level indicating the level at which the closest
+ // repeated ancestor is not empty. This is used to discriminate
+ // between a value less than |def_level| being null or excluded entirely.
+ // For instance if we have an arrow schema like:
+ // list(struct(f0: int)). Then then there are the following
+ // definition levels:
+ // 0 = null list
+ // 1 = present but empty list.
+ // 2 = a null value in the list
+ // 3 = a non null struct but null integer.
+ // 4 = a present integer.
+ // When reconstructing, the struct and integer arrays'
+ // repeated_ancestor_def_level would be 2. Any
+ // def_level < 2 indicates that there isn't a corresponding
+ // child value in the list.
+ // i.e. [null, [], [null], [{f0: null}], [{f0: 1}]]
+ // has the def levels [0, 1, 2, 3, 4]. The actual
+ // struct array is only of length 3: [not-set, set, set] and
+ // the int array is also of length 3: [N/A, null, 1].
+ //
+ int16_t repeated_ancestor_def_level = 0;
+
+ /// Increments levels according to the cardinality of node.
+ void Increment(const schema::Node& node) {
+ if (node.is_repeated()) {
+ IncrementRepeated();
+ return;
+ }
+ if (node.is_optional()) {
+ IncrementOptional();
+ return;
+ }
+ }
+
+ /// Incremetns level for a optional node.
+ void IncrementOptional() { def_level++; }
+
+ /// Increments levels for the repeated node. Returns
+ /// the previous ancestor_list_def_level.
+ int16_t IncrementRepeated() {
+ int16_t last_repeated_ancestor = repeated_ancestor_def_level;
+
+ // Repeated fields add both a repetition and definition level. This is used
+ // to distinguish between an empty list and a list with an item in it.
+ ++rep_level;
+ ++def_level;
+ // For levels >= repeated_ancenstor_def_level it indicates the list was
+ // non-null and had at least one element. This is important
+ // for later decoding because we need to add a slot for these
+ // values. for levels < current_def_level no slots are added
+ // to arrays.
+ repeated_ancestor_def_level = def_level;
+ return last_repeated_ancestor;
+ }
+
+ friend std::ostream& operator<<(std::ostream& os, const LevelInfo& levels) {
+ // This print method is to silence valgrind issues. What's printed
+ // is not important because all asserts happen directly on
+ // members.
+ os << "{def=" << levels.def_level << ", rep=" << levels.rep_level
+ << ", repeated_ancestor_def=" << levels.repeated_ancestor_def_level;
+ if (levels.null_slot_usage > 1) {
+ os << ", null_slot_usage=" << levels.null_slot_usage;
+ }
+ os << "}";
+ return os;
+ }
+};
+
+// Input/Output structure for reconstructed validity bitmaps.
+struct PARQUET_EXPORT ValidityBitmapInputOutput {
+ // Input only.
+ // The maximum number of values_read expected (actual
+ // values read must be less than or equal to this value).
+ // If this number is exceeded methods will throw a
+ // ParquetException. Exceeding this limit indicates
+ // either a corrupt or incorrectly written file.
+ int64_t values_read_upper_bound = 0;
+ // Output only. The number of values added to the encountered
+ // (this is logically the count of the number of elements
+ // for an Arrow array).
+ int64_t values_read = 0;
+ // Input/Output. The number of nulls encountered.
+ int64_t null_count = 0;
+ // Output only. The validity bitmap to populate. May be be null only
+ // for DefRepLevelsToListInfo (if all that is needed is list offsets).
+ uint8_t* valid_bits = NULLPTR;
+ // Input only, offset into valid_bits to start at.
+ int64_t valid_bits_offset = 0;
+};
+
+// Converts def_levels to validity bitmaps for non-list arrays and structs that have
+// at least one member that is not a list and has no list descendents.
+// For lists use DefRepLevelsToList and structs where all descendants contain
+// a list use DefRepLevelsToBitmap.
+void PARQUET_EXPORT DefLevelsToBitmap(const int16_t* def_levels, int64_t num_def_levels,
+ LevelInfo level_info,
+ ValidityBitmapInputOutput* output);
+
+// Reconstructs a validity bitmap and list offsets for a list arrays based on
+// def/rep levels. The first element of offsets will not be modified if rep_levels
+// starts with a new list. The first element of offsets will be used when calculating
+// the next offset. See documentation onf DefLevelsToBitmap for when to use this
+// method vs the other ones in this file for reconstruction.
+//
+// Offsets must be sized to 1 + values_read_upper_bound.
+void PARQUET_EXPORT DefRepLevelsToList(const int16_t* def_levels,
+ const int16_t* rep_levels, int64_t num_def_levels,
+ LevelInfo level_info,
+ ValidityBitmapInputOutput* output,
+ int32_t* offsets);
+void PARQUET_EXPORT DefRepLevelsToList(const int16_t* def_levels,
+ const int16_t* rep_levels, int64_t num_def_levels,
+ LevelInfo level_info,
+ ValidityBitmapInputOutput* output,
+ int64_t* offsets);
+
+// Reconstructs a validity bitmap for a struct every member is a list or has
+// a list descendant. See documentation on DefLevelsToBitmap for when more
+// details on this method compared to the other ones defined above.
+void PARQUET_EXPORT DefRepLevelsToBitmap(const int16_t* def_levels,
+ const int16_t* rep_levels,
+ int64_t num_def_levels, LevelInfo level_info,
+ ValidityBitmapInputOutput* output);
+
+// This is exposed to ensure we can properly test a software simulated pext function
+// (i.e. it isn't hidden by runtime dispatch).
+uint64_t PARQUET_EXPORT TestOnlyExtractBitsSoftware(uint64_t bitmap, uint64_t selection);
+
+} // namespace internal
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/level_conversion_benchmark.cc b/src/arrow/cpp/src/parquet/level_conversion_benchmark.cc
new file mode 100644
index 000000000..f9e91c482
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/level_conversion_benchmark.cc
@@ -0,0 +1,80 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <vector>
+
+#include "benchmark/benchmark.h"
+#include "parquet/level_conversion.h"
+
+constexpr int64_t kLevelCount = 2048;
+// Def level indicating the element is missing from the leaf
+// array (a parent repeated element was empty).
+constexpr int16_t kMissingDefLevel = 0;
+
+// Definition Level indicating the values has an entry in the leaf element.
+constexpr int16_t kPresentDefLevel = 2;
+
+// A repition level that indicates a repeated element.
+constexpr int16_t kHasRepeatedElements = 1;
+
+std::vector<uint8_t> RunDefinitionLevelsToBitmap(const std::vector<int16_t>& def_levels,
+ ::benchmark::State* state) {
+ std::vector<uint8_t> bitmap(/*count=*/def_levels.size(), 0);
+ parquet::internal::LevelInfo info;
+ info.def_level = kHasRepeatedElements;
+ info.repeated_ancestor_def_level = kPresentDefLevel;
+ info.rep_level = 1;
+ parquet::internal::ValidityBitmapInputOutput validity_io;
+ validity_io.values_read_upper_bound = def_levels.size();
+ validity_io.valid_bits = bitmap.data();
+ for (auto _ : *state) {
+ parquet::internal::DefLevelsToBitmap(def_levels.data(), def_levels.size(), info,
+ &validity_io);
+ }
+ state->SetBytesProcessed(int64_t(state->iterations()) * def_levels.size());
+ return bitmap;
+}
+
+void BM_DefinitionLevelsToBitmapRepeatedAllMissing(::benchmark::State& state) {
+ std::vector<int16_t> def_levels(/*count=*/kLevelCount, kMissingDefLevel);
+ auto result = RunDefinitionLevelsToBitmap(def_levels, &state);
+ ::benchmark::DoNotOptimize(result);
+}
+
+BENCHMARK(BM_DefinitionLevelsToBitmapRepeatedAllMissing);
+
+void BM_DefinitionLevelsToBitmapRepeatedAllPresent(::benchmark::State& state) {
+ std::vector<int16_t> def_levels(/*count=*/kLevelCount, kPresentDefLevel);
+ auto result = RunDefinitionLevelsToBitmap(def_levels, &state);
+ ::benchmark::DoNotOptimize(result);
+}
+
+BENCHMARK(BM_DefinitionLevelsToBitmapRepeatedAllPresent);
+
+void BM_DefinitionLevelsToBitmapRepeatedMostPresent(::benchmark::State& state) {
+ std::vector<int16_t> def_levels(/*count=*/kLevelCount, kPresentDefLevel);
+ for (size_t x = 0; x < def_levels.size(); x++) {
+ if (x % 10 == 0) {
+ def_levels[x] = kMissingDefLevel;
+ }
+ }
+ auto result = RunDefinitionLevelsToBitmap(def_levels, &state);
+ ::benchmark::DoNotOptimize(result);
+}
+
+BENCHMARK(BM_DefinitionLevelsToBitmapRepeatedMostPresent);
diff --git a/src/arrow/cpp/src/parquet/level_conversion_bmi2.cc b/src/arrow/cpp/src/parquet/level_conversion_bmi2.cc
new file mode 100644
index 000000000..274d54e50
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/level_conversion_bmi2.cc
@@ -0,0 +1,33 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#include "parquet/level_conversion.h"
+
+#define PARQUET_IMPL_NAMESPACE bmi2
+#include "parquet/level_conversion_inc.h"
+#undef PARQUET_IMPL_NAMESPACE
+
+namespace parquet {
+namespace internal {
+void DefLevelsToBitmapBmi2WithRepeatedParent(const int16_t* def_levels,
+ int64_t num_def_levels, LevelInfo level_info,
+ ValidityBitmapInputOutput* output) {
+ bmi2::DefLevelsToBitmapSimd</*has_repeated_parent=*/true>(def_levels, num_def_levels,
+ level_info, output);
+}
+
+} // namespace internal
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/level_conversion_inc.h b/src/arrow/cpp/src/parquet/level_conversion_inc.h
new file mode 100644
index 000000000..75c7716c4
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/level_conversion_inc.h
@@ -0,0 +1,357 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#pragma once
+
+#include "parquet/level_conversion.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_writer.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/simd.h"
+#include "parquet/exception.h"
+#include "parquet/level_comparison.h"
+
+namespace parquet {
+namespace internal {
+#ifndef PARQUET_IMPL_NAMESPACE
+#error "PARQUET_IMPL_NAMESPACE must be defined"
+#endif
+namespace PARQUET_IMPL_NAMESPACE {
+
+// clang-format off
+/* Python code to generate lookup table:
+
+kLookupBits = 5
+count = 0
+print('constexpr int kLookupBits = {};'.format(kLookupBits))
+print('constexpr uint8_t kPextTable[1 << kLookupBits][1 << kLookupBits] = {')
+print(' ', end = '')
+for mask in range(1 << kLookupBits):
+ for data in range(1 << kLookupBits):
+ bit_value = 0
+ bit_len = 0
+ for i in range(kLookupBits):
+ if mask & (1 << i):
+ bit_value |= (((data >> i) & 1) << bit_len)
+ bit_len += 1
+ out = '0x{:02X},'.format(bit_value)
+ count += 1
+ if count % (1 << kLookupBits) == 1:
+ print(' {')
+ if count % 8 == 1:
+ print(' ', end = '')
+ if count % 8 == 0:
+ print(out, end = '\n')
+ else:
+ print(out, end = ' ')
+ if count % (1 << kLookupBits) == 0:
+ print(' },', end = '')
+print('\n};')
+
+*/
+// clang-format on
+
+constexpr int kLookupBits = 5;
+constexpr uint8_t kPextTable[1 << kLookupBits][1 << kLookupBits] = {
+ {
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ },
+ {
+ 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00,
+ 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01,
+ 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01,
+ },
+ {
+ 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01,
+ 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00,
+ 0x01, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x01,
+ },
+ {
+ 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02,
+ 0x03, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01,
+ 0x02, 0x03, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02, 0x03,
+ },
+ {
+ 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00,
+ 0x00, 0x01, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01,
+ 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01,
+ },
+ {
+ 0x00, 0x01, 0x00, 0x01, 0x02, 0x03, 0x02, 0x03, 0x00, 0x01, 0x00,
+ 0x01, 0x02, 0x03, 0x02, 0x03, 0x00, 0x01, 0x00, 0x01, 0x02, 0x03,
+ 0x02, 0x03, 0x00, 0x01, 0x00, 0x01, 0x02, 0x03, 0x02, 0x03,
+ },
+ {
+ 0x00, 0x00, 0x01, 0x01, 0x02, 0x02, 0x03, 0x03, 0x00, 0x00, 0x01,
+ 0x01, 0x02, 0x02, 0x03, 0x03, 0x00, 0x00, 0x01, 0x01, 0x02, 0x02,
+ 0x03, 0x03, 0x00, 0x00, 0x01, 0x01, 0x02, 0x02, 0x03, 0x03,
+ },
+ {
+ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x00, 0x01, 0x02,
+ 0x03, 0x04, 0x05, 0x06, 0x07, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
+ 0x06, 0x07, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
+ },
+ {
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01,
+ 0x01, 0x01, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+ },
+ {
+ 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x02, 0x03, 0x02,
+ 0x03, 0x02, 0x03, 0x02, 0x03, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01,
+ 0x00, 0x01, 0x02, 0x03, 0x02, 0x03, 0x02, 0x03, 0x02, 0x03,
+ },
+ {
+ 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x01, 0x02, 0x02, 0x03,
+ 0x03, 0x02, 0x02, 0x03, 0x03, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00,
+ 0x01, 0x01, 0x02, 0x02, 0x03, 0x03, 0x02, 0x02, 0x03, 0x03,
+ },
+ {
+ 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
+ 0x07, 0x04, 0x05, 0x06, 0x07, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01,
+ 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x04, 0x05, 0x06, 0x07,
+ },
+ {
+ 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x02, 0x02, 0x02,
+ 0x02, 0x03, 0x03, 0x03, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01,
+ 0x01, 0x01, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03,
+ },
+ {
+ 0x00, 0x01, 0x00, 0x01, 0x02, 0x03, 0x02, 0x03, 0x04, 0x05, 0x04,
+ 0x05, 0x06, 0x07, 0x06, 0x07, 0x00, 0x01, 0x00, 0x01, 0x02, 0x03,
+ 0x02, 0x03, 0x04, 0x05, 0x04, 0x05, 0x06, 0x07, 0x06, 0x07,
+ },
+ {
+ 0x00, 0x00, 0x01, 0x01, 0x02, 0x02, 0x03, 0x03, 0x04, 0x04, 0x05,
+ 0x05, 0x06, 0x06, 0x07, 0x07, 0x00, 0x00, 0x01, 0x01, 0x02, 0x02,
+ 0x03, 0x03, 0x04, 0x04, 0x05, 0x05, 0x06, 0x06, 0x07, 0x07,
+ },
+ {
+ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
+ 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
+ 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
+ },
+ {
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+ 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+ },
+ {
+ 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00,
+ 0x01, 0x00, 0x01, 0x00, 0x01, 0x02, 0x03, 0x02, 0x03, 0x02, 0x03,
+ 0x02, 0x03, 0x02, 0x03, 0x02, 0x03, 0x02, 0x03, 0x02, 0x03,
+ },
+ {
+ 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01,
+ 0x01, 0x00, 0x00, 0x01, 0x01, 0x02, 0x02, 0x03, 0x03, 0x02, 0x02,
+ 0x03, 0x03, 0x02, 0x02, 0x03, 0x03, 0x02, 0x02, 0x03, 0x03,
+ },
+ {
+ 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02,
+ 0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x04, 0x05,
+ 0x06, 0x07, 0x04, 0x05, 0x06, 0x07, 0x04, 0x05, 0x06, 0x07,
+ },
+ {
+ 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00,
+ 0x00, 0x01, 0x01, 0x01, 0x01, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03,
+ 0x03, 0x03, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03,
+ },
+ {
+ 0x00, 0x01, 0x00, 0x01, 0x02, 0x03, 0x02, 0x03, 0x00, 0x01, 0x00,
+ 0x01, 0x02, 0x03, 0x02, 0x03, 0x04, 0x05, 0x04, 0x05, 0x06, 0x07,
+ 0x06, 0x07, 0x04, 0x05, 0x04, 0x05, 0x06, 0x07, 0x06, 0x07,
+ },
+ {
+ 0x00, 0x00, 0x01, 0x01, 0x02, 0x02, 0x03, 0x03, 0x00, 0x00, 0x01,
+ 0x01, 0x02, 0x02, 0x03, 0x03, 0x04, 0x04, 0x05, 0x05, 0x06, 0x06,
+ 0x07, 0x07, 0x04, 0x04, 0x05, 0x05, 0x06, 0x06, 0x07, 0x07,
+ },
+ {
+ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x00, 0x01, 0x02,
+ 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D,
+ 0x0E, 0x0F, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
+ },
+ {
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01,
+ 0x01, 0x01, 0x01, 0x01, 0x01, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+ 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
+ },
+ {
+ 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x02, 0x03, 0x02,
+ 0x03, 0x02, 0x03, 0x02, 0x03, 0x04, 0x05, 0x04, 0x05, 0x04, 0x05,
+ 0x04, 0x05, 0x06, 0x07, 0x06, 0x07, 0x06, 0x07, 0x06, 0x07,
+ },
+ {
+ 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x01, 0x02, 0x02, 0x03,
+ 0x03, 0x02, 0x02, 0x03, 0x03, 0x04, 0x04, 0x05, 0x05, 0x04, 0x04,
+ 0x05, 0x05, 0x06, 0x06, 0x07, 0x07, 0x06, 0x06, 0x07, 0x07,
+ },
+ {
+ 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
+ 0x07, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x08, 0x09,
+ 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x0C, 0x0D, 0x0E, 0x0F,
+ },
+ {
+ 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x02, 0x02, 0x02,
+ 0x02, 0x03, 0x03, 0x03, 0x03, 0x04, 0x04, 0x04, 0x04, 0x05, 0x05,
+ 0x05, 0x05, 0x06, 0x06, 0x06, 0x06, 0x07, 0x07, 0x07, 0x07,
+ },
+ {
+ 0x00, 0x01, 0x00, 0x01, 0x02, 0x03, 0x02, 0x03, 0x04, 0x05, 0x04,
+ 0x05, 0x06, 0x07, 0x06, 0x07, 0x08, 0x09, 0x08, 0x09, 0x0A, 0x0B,
+ 0x0A, 0x0B, 0x0C, 0x0D, 0x0C, 0x0D, 0x0E, 0x0F, 0x0E, 0x0F,
+ },
+ {
+ 0x00, 0x00, 0x01, 0x01, 0x02, 0x02, 0x03, 0x03, 0x04, 0x04, 0x05,
+ 0x05, 0x06, 0x06, 0x07, 0x07, 0x08, 0x08, 0x09, 0x09, 0x0A, 0x0A,
+ 0x0B, 0x0B, 0x0C, 0x0C, 0x0D, 0x0D, 0x0E, 0x0E, 0x0F, 0x0F,
+ },
+ {
+ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
+ 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
+ 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F,
+ },
+};
+
+inline uint64_t ExtractBitsSoftware(uint64_t bitmap, uint64_t select_bitmap) {
+ // A software emulation of _pext_u64
+
+ // These checks should be inline and are likely to be common cases.
+ if (select_bitmap == ~uint64_t{0}) {
+ return bitmap;
+ } else if (select_bitmap == 0) {
+ return 0;
+ }
+
+ // Fallback to lookup table method
+ uint64_t bit_value = 0;
+ int bit_len = 0;
+ constexpr uint8_t kLookupMask = (1U << kLookupBits) - 1;
+ while (select_bitmap != 0) {
+ const auto mask_len = ARROW_POPCOUNT32(select_bitmap & kLookupMask);
+ const uint64_t value = kPextTable[select_bitmap & kLookupMask][bitmap & kLookupMask];
+ bit_value |= (value << bit_len);
+ bit_len += mask_len;
+ bitmap >>= kLookupBits;
+ select_bitmap >>= kLookupBits;
+ }
+ return bit_value;
+}
+
+#ifdef ARROW_HAVE_BMI2
+
+// Use _pext_u64 on 64-bit builds, _pext_u32 on 32-bit builds,
+#if UINTPTR_MAX == 0xFFFFFFFF
+
+using extract_bitmap_t = uint32_t;
+inline extract_bitmap_t ExtractBits(extract_bitmap_t bitmap,
+ extract_bitmap_t select_bitmap) {
+ return _pext_u32(bitmap, select_bitmap);
+}
+
+#else
+
+using extract_bitmap_t = uint64_t;
+inline extract_bitmap_t ExtractBits(extract_bitmap_t bitmap,
+ extract_bitmap_t select_bitmap) {
+ return _pext_u64(bitmap, select_bitmap);
+}
+
+#endif
+
+#else // !defined(ARROW_HAVE_BMI2)
+
+// Use 64-bit pext emulation when BMI2 isn't available.
+using extract_bitmap_t = uint64_t;
+inline extract_bitmap_t ExtractBits(extract_bitmap_t bitmap,
+ extract_bitmap_t select_bitmap) {
+ return ExtractBitsSoftware(bitmap, select_bitmap);
+}
+
+#endif
+
+static constexpr int64_t kExtractBitsSize = 8 * sizeof(extract_bitmap_t);
+
+template <bool has_repeated_parent>
+int64_t DefLevelsBatchToBitmap(const int16_t* def_levels, const int64_t batch_size,
+ int64_t upper_bound_remaining, LevelInfo level_info,
+ ::arrow::internal::FirstTimeBitmapWriter* writer) {
+ DCHECK_LE(batch_size, kExtractBitsSize);
+
+ // Greater than level_info.def_level - 1 implies >= the def_level
+ auto defined_bitmap = static_cast<extract_bitmap_t>(
+ internal::GreaterThanBitmap(def_levels, batch_size, level_info.def_level - 1));
+
+ if (has_repeated_parent) {
+ // Greater than level_info.repeated_ancestor_def_level - 1 implies >= the
+ // repeated_ancestor_def_level
+ auto present_bitmap = static_cast<extract_bitmap_t>(internal::GreaterThanBitmap(
+ def_levels, batch_size, level_info.repeated_ancestor_def_level - 1));
+ auto selected_bits = ExtractBits(defined_bitmap, present_bitmap);
+ int64_t selected_count = ::arrow::BitUtil::PopCount(present_bitmap);
+ if (ARROW_PREDICT_FALSE(selected_count > upper_bound_remaining)) {
+ throw ParquetException("Values read exceeded upper bound");
+ }
+ writer->AppendWord(selected_bits, selected_count);
+ return ::arrow::BitUtil::PopCount(selected_bits);
+ } else {
+ if (ARROW_PREDICT_FALSE(batch_size > upper_bound_remaining)) {
+ std::stringstream ss;
+ ss << "Values read exceeded upper bound";
+ throw ParquetException(ss.str());
+ }
+
+ writer->AppendWord(defined_bitmap, batch_size);
+ return ::arrow::BitUtil::PopCount(defined_bitmap);
+ }
+}
+
+template <bool has_repeated_parent>
+void DefLevelsToBitmapSimd(const int16_t* def_levels, int64_t num_def_levels,
+ LevelInfo level_info, ValidityBitmapInputOutput* output) {
+ ::arrow::internal::FirstTimeBitmapWriter writer(
+ output->valid_bits,
+ /*start_offset=*/output->valid_bits_offset,
+ /*length=*/num_def_levels);
+ int64_t set_count = 0;
+ output->values_read = 0;
+ int64_t values_read_remaining = output->values_read_upper_bound;
+ while (num_def_levels > kExtractBitsSize) {
+ set_count += DefLevelsBatchToBitmap<has_repeated_parent>(
+ def_levels, kExtractBitsSize, values_read_remaining, level_info, &writer);
+ def_levels += kExtractBitsSize;
+ num_def_levels -= kExtractBitsSize;
+ values_read_remaining = output->values_read_upper_bound - writer.position();
+ }
+ set_count += DefLevelsBatchToBitmap<has_repeated_parent>(
+ def_levels, num_def_levels, values_read_remaining, level_info, &writer);
+
+ output->values_read = writer.position();
+ output->null_count += output->values_read - set_count;
+ writer.Finish();
+}
+
+} // namespace PARQUET_IMPL_NAMESPACE
+} // namespace internal
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/level_conversion_test.cc b/src/arrow/cpp/src/parquet/level_conversion_test.cc
new file mode 100644
index 000000000..bfce74ae3
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/level_conversion_test.cc
@@ -0,0 +1,361 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/level_conversion.h"
+
+#include "parquet/level_comparison.h"
+#include "parquet/test_util.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <string>
+#include <vector>
+
+#include "arrow/testing/gtest_compat.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap.h"
+#include "arrow/util/ubsan.h"
+
+namespace parquet {
+namespace internal {
+
+using ::arrow::internal::Bitmap;
+using ::testing::ElementsAreArray;
+
+std::string BitmapToString(const uint8_t* bitmap, int64_t bit_count) {
+ return ::arrow::internal::Bitmap(bitmap, /*offset*/ 0, /*length=*/bit_count).ToString();
+}
+
+std::string BitmapToString(const std::vector<uint8_t>& bitmap, int64_t bit_count) {
+ return BitmapToString(bitmap.data(), bit_count);
+}
+
+TEST(TestColumnReader, DefLevelsToBitmap) {
+ // Bugs in this function were exposed in ARROW-3930
+ std::vector<int16_t> def_levels = {3, 3, 3, 2, 3, 3, 3, 3, 3};
+
+ std::vector<uint8_t> valid_bits(2, 0);
+
+ LevelInfo level_info;
+ level_info.def_level = 3;
+ level_info.rep_level = 1;
+
+ ValidityBitmapInputOutput io;
+ io.values_read_upper_bound = def_levels.size();
+ io.values_read = -1;
+ io.valid_bits = valid_bits.data();
+
+ internal::DefLevelsToBitmap(def_levels.data(), 9, level_info, &io);
+ ASSERT_EQ(9, io.values_read);
+ ASSERT_EQ(1, io.null_count);
+
+ // Call again with 0 definition levels, make sure that valid_bits is unmodified
+ const uint8_t current_byte = valid_bits[1];
+ io.null_count = 0;
+ internal::DefLevelsToBitmap(def_levels.data(), 0, level_info, &io);
+
+ ASSERT_EQ(0, io.values_read);
+ ASSERT_EQ(0, io.null_count);
+ ASSERT_EQ(current_byte, valid_bits[1]);
+}
+
+TEST(TestColumnReader, DefLevelsToBitmapPowerOfTwo) {
+ // PARQUET-1623: Invalid memory access when decoding a valid bits vector that has a
+ // length equal to a power of two and also using a non-zero valid_bits_offset. This
+ // should not fail when run with ASAN or valgrind.
+ std::vector<int16_t> def_levels = {3, 3, 3, 2, 3, 3, 3, 3};
+ std::vector<uint8_t> valid_bits(1, 0);
+
+ LevelInfo level_info;
+ level_info.rep_level = 1;
+ level_info.def_level = 3;
+
+ ValidityBitmapInputOutput io;
+ io.values_read_upper_bound = def_levels.size();
+ io.values_read = -1;
+ io.valid_bits = valid_bits.data();
+
+ // Read the latter half of the validity bitmap
+ internal::DefLevelsToBitmap(def_levels.data() + 4, 4, level_info, &io);
+ ASSERT_EQ(4, io.values_read);
+ ASSERT_EQ(0, io.null_count);
+}
+
+#if defined(ARROW_LITTLE_ENDIAN)
+TEST(GreaterThanBitmap, GeneratesExpectedBitmasks) {
+ std::vector<int16_t> levels = {0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7,
+ 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7,
+ 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7,
+ 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7};
+ EXPECT_EQ(GreaterThanBitmap(levels.data(), /*num_levels=*/0, /*rhs*/ 0), 0);
+ EXPECT_EQ(GreaterThanBitmap(levels.data(), /*num_levels=*/64, /*rhs*/ 8), 0);
+ EXPECT_EQ(GreaterThanBitmap(levels.data(), /*num_levels=*/64, /*rhs*/ -1),
+ 0xFFFFFFFFFFFFFFFF);
+ // Should be zero padded.
+ EXPECT_EQ(GreaterThanBitmap(levels.data(), /*num_levels=*/47, /*rhs*/ -1),
+ 0x7FFFFFFFFFFF);
+ EXPECT_EQ(GreaterThanBitmap(levels.data(), /*num_levels=*/64, /*rhs*/ 6),
+ 0x8080808080808080);
+}
+#endif
+
+TEST(DefLevelsToBitmap, WithRepetitionLevelFiltersOutEmptyListValues) {
+ std::vector<uint8_t> validity_bitmap(/*count*/ 8, 0);
+
+ ValidityBitmapInputOutput io;
+ io.values_read_upper_bound = 64;
+ io.values_read = 1;
+ io.null_count = 5;
+ io.valid_bits = validity_bitmap.data();
+ io.valid_bits_offset = 1;
+
+ LevelInfo level_info;
+ level_info.repeated_ancestor_def_level = 1;
+ level_info.def_level = 2;
+ level_info.rep_level = 1;
+ // All zeros should be ignored, ones should be unset in the bitmp and 2 should be set.
+ std::vector<int16_t> def_levels = {0, 0, 0, 2, 2, 1, 0, 2};
+ DefLevelsToBitmap(def_levels.data(), def_levels.size(), level_info, &io);
+
+ EXPECT_EQ(BitmapToString(validity_bitmap, /*bit_count=*/8), "01101000");
+ for (size_t x = 1; x < validity_bitmap.size(); x++) {
+ EXPECT_EQ(validity_bitmap[x], 0) << "index: " << x;
+ }
+ EXPECT_EQ(io.null_count, /*5 + 1 =*/6);
+ EXPECT_EQ(io.values_read, 4); // value should get overwritten.
+}
+
+struct MultiLevelTestData {
+ public:
+ std::vector<int16_t> def_levels;
+ std::vector<int16_t> rep_levels;
+};
+
+MultiLevelTestData TriplyNestedList() {
+ // Triply nested list values borrow from write_path
+ // [null, [[1 , null, 3], []], []],
+ // [[[]], [[], [1, 2]], null, [[3]]],
+ // null,
+ // []
+ return MultiLevelTestData{
+ /*def_levels=*/std::vector<int16_t>{2, 7, 6, 7, 5, 3, // first row
+ 5, 5, 7, 7, 2, 7, // second row
+ 0, // third row
+ 1},
+ /*rep_levels=*/std::vector<int16_t>{0, 1, 3, 3, 2, 1, // first row
+ 0, 1, 2, 3, 1, 1, // second row
+ 0, 0}};
+}
+
+template <typename ConverterType>
+class NestedListTest : public testing::Test {
+ public:
+ void InitForLength(int length) {
+ this->validity_bits_.clear();
+ this->validity_bits_.insert(this->validity_bits_.end(), length, 0);
+ validity_io_.valid_bits = validity_bits_.data();
+ validity_io_.values_read_upper_bound = length;
+ offsets_.clear();
+ offsets_.insert(offsets_.end(), length + 1, 0);
+ }
+
+ typename ConverterType::OffsetsType* Run(const MultiLevelTestData& test_data,
+ LevelInfo level_info) {
+ return this->converter_.ComputeListInfo(test_data, level_info, &validity_io_,
+ offsets_.data());
+ }
+
+ ConverterType converter_;
+ ValidityBitmapInputOutput validity_io_;
+ std::vector<uint8_t> validity_bits_;
+ std::vector<typename ConverterType::OffsetsType> offsets_;
+};
+
+template <typename IndexType>
+struct RepDefLevelConverter {
+ using OffsetsType = IndexType;
+ OffsetsType* ComputeListInfo(const MultiLevelTestData& test_data, LevelInfo level_info,
+ ValidityBitmapInputOutput* output, IndexType* offsets) {
+ DefRepLevelsToList(test_data.def_levels.data(), test_data.rep_levels.data(),
+ test_data.def_levels.size(), level_info, output, offsets);
+ return offsets + output->values_read;
+ }
+};
+
+using ConverterTypes =
+ ::testing::Types<RepDefLevelConverter</*list_length_type=*/int32_t>,
+ RepDefLevelConverter</*list_length_type=*/int64_t>>;
+TYPED_TEST_SUITE(NestedListTest, ConverterTypes);
+
+TYPED_TEST(NestedListTest, OuterMostTest) {
+ // [null, [[1 , null, 3], []], []],
+ // [[[]], [[], [1, 2]], null, [[3]]],
+ // null,
+ // []
+ // -> 4 outer most lists (len(3), len(4), null, len(0))
+ LevelInfo level_info;
+ level_info.rep_level = 1;
+ level_info.def_level = 2;
+
+ this->InitForLength(4);
+ typename TypeParam::OffsetsType* next_position =
+ this->Run(TriplyNestedList(), level_info);
+
+ EXPECT_EQ(next_position, this->offsets_.data() + 4);
+ EXPECT_THAT(this->offsets_, testing::ElementsAre(0, 3, 7, 7, 7));
+
+ EXPECT_EQ(this->validity_io_.values_read, 4);
+ EXPECT_EQ(this->validity_io_.null_count, 1);
+ EXPECT_EQ(BitmapToString(this->validity_io_.valid_bits, /*length=*/4), "1101");
+}
+
+TYPED_TEST(NestedListTest, MiddleListTest) {
+ // [null, [[1 , null, 3], []], []],
+ // [[[]], [[], [1, 2]], null, [[3]]],
+ // null,
+ // []
+ // -> middle lists (null, len(2), len(0),
+ // len(1), len(2), null, len(1),
+ // N/A,
+ // N/A
+ LevelInfo level_info;
+ level_info.rep_level = 2;
+ level_info.def_level = 4;
+ level_info.repeated_ancestor_def_level = 2;
+
+ this->InitForLength(7);
+ typename TypeParam::OffsetsType* next_position =
+ this->Run(TriplyNestedList(), level_info);
+
+ EXPECT_EQ(next_position, this->offsets_.data() + 7);
+ EXPECT_THAT(this->offsets_, testing::ElementsAre(0, 0, 2, 2, 3, 5, 5, 6));
+
+ EXPECT_EQ(this->validity_io_.values_read, 7);
+ EXPECT_EQ(this->validity_io_.null_count, 2);
+ EXPECT_EQ(BitmapToString(this->validity_io_.valid_bits, /*length=*/7), "0111101");
+}
+
+TYPED_TEST(NestedListTest, InnerMostListTest) {
+ // [null, [[1, null, 3], []], []],
+ // [[[]], [[], [1, 2]], null, [[3]]],
+ // null,
+ // []
+ // -> 6 inner lists (N/A, [len(3), len(0)], N/A
+ // len(0), [len(0), len(2)], N/A, len(1),
+ // N/A,
+ // N/A
+ LevelInfo level_info;
+ level_info.rep_level = 3;
+ level_info.def_level = 6;
+ level_info.repeated_ancestor_def_level = 4;
+
+ this->InitForLength(6);
+ typename TypeParam::OffsetsType* next_position =
+ this->Run(TriplyNestedList(), level_info);
+
+ EXPECT_EQ(next_position, this->offsets_.data() + 6);
+ EXPECT_THAT(this->offsets_, testing::ElementsAre(0, 3, 3, 3, 3, 5, 6));
+
+ EXPECT_EQ(this->validity_io_.values_read, 6);
+ EXPECT_EQ(this->validity_io_.null_count, 0);
+ EXPECT_EQ(BitmapToString(this->validity_io_.valid_bits, /*length=*/6), "111111");
+}
+
+TYPED_TEST(NestedListTest, SimpleLongList) {
+ LevelInfo level_info;
+ level_info.rep_level = 1;
+ level_info.def_level = 2;
+ level_info.repeated_ancestor_def_level = 0;
+
+ MultiLevelTestData test_data;
+ // No empty lists.
+ test_data.def_levels = std::vector<int16_t>(65 * 9, 2);
+ for (int x = 0; x < 65; x++) {
+ test_data.rep_levels.push_back(0);
+ test_data.rep_levels.insert(test_data.rep_levels.end(), 8,
+ /*rep_level=*/1);
+ }
+
+ std::vector<typename TypeParam::OffsetsType> expected_offsets(66, 0);
+ for (size_t x = 1; x < expected_offsets.size(); x++) {
+ expected_offsets[x] = static_cast<typename TypeParam::OffsetsType>(x) * 9;
+ }
+ this->InitForLength(65);
+ typename TypeParam::OffsetsType* next_position = this->Run(test_data, level_info);
+
+ EXPECT_EQ(next_position, this->offsets_.data() + 65);
+ EXPECT_THAT(this->offsets_, testing::ElementsAreArray(expected_offsets));
+
+ EXPECT_EQ(this->validity_io_.values_read, 65);
+ EXPECT_EQ(this->validity_io_.null_count, 0);
+ EXPECT_EQ(BitmapToString(this->validity_io_.valid_bits, /*length=*/65),
+ "11111111 "
+ "11111111 "
+ "11111111 "
+ "11111111 "
+ "11111111 "
+ "11111111 "
+ "11111111 "
+ "11111111 "
+ "1");
+}
+
+TYPED_TEST(NestedListTest, TestOverflow) {
+ LevelInfo level_info;
+ level_info.rep_level = 1;
+ level_info.def_level = 2;
+ level_info.repeated_ancestor_def_level = 0;
+
+ MultiLevelTestData test_data;
+ test_data.def_levels = std::vector<int16_t>{2};
+ test_data.rep_levels = std::vector<int16_t>{0};
+
+ this->InitForLength(2);
+ // Offsets is populated as the cumulative sum of all elements,
+ // so populating the offsets[0] with max-value impacts the
+ // other values populated.
+ this->offsets_[0] = std::numeric_limits<typename TypeParam::OffsetsType>::max();
+ this->offsets_[1] = std::numeric_limits<typename TypeParam::OffsetsType>::max();
+ ASSERT_THROW(this->Run(test_data, level_info), ParquetException);
+
+ ASSERT_THROW(this->Run(test_data, level_info), ParquetException);
+
+ // Same thing should happen if the list already existed.
+ test_data.rep_levels = std::vector<int16_t>{1};
+ ASSERT_THROW(this->Run(test_data, level_info), ParquetException);
+
+ // Should be OK because it shouldn't increment.
+ test_data.def_levels = std::vector<int16_t>{0};
+ test_data.rep_levels = std::vector<int16_t>{0};
+ this->Run(test_data, level_info);
+}
+
+TEST(TestOnlyExtractBitsSoftware, BasicTest) {
+ auto check = [](uint64_t bitmap, uint64_t selection, uint64_t expected) -> void {
+ EXPECT_EQ(TestOnlyExtractBitsSoftware(bitmap, selection), expected);
+ };
+ check(0xFF, 0, 0);
+ check(0xFF, ~uint64_t{0}, 0xFF);
+ check(0xFF00FF, 0xAAAA, 0x000F);
+ check(0xFF0AFF, 0xAFAA, 0x00AF);
+ check(0xFFAAFF, 0xAFAA, 0x03AF);
+ check(0xFECBDA9876543210ULL, 0xF00FF00FF00FF00FULL, 0xFBD87430ULL);
+}
+
+} // namespace internal
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/metadata.cc b/src/arrow/cpp/src/parquet/metadata.cc
new file mode 100644
index 000000000..0f99530dd
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/metadata.cc
@@ -0,0 +1,1797 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/metadata.h"
+
+#include <algorithm>
+#include <cinttypes>
+#include <ostream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/io/memory.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/string_view.h"
+#include "parquet/encryption/encryption_internal.h"
+#include "parquet/encryption/internal_file_decryptor.h"
+#include "parquet/exception.h"
+#include "parquet/schema.h"
+#include "parquet/schema_internal.h"
+#include "parquet/statistics.h"
+#include "parquet/thrift_internal.h"
+
+namespace parquet {
+
+const ApplicationVersion& ApplicationVersion::PARQUET_251_FIXED_VERSION() {
+ static ApplicationVersion version("parquet-mr", 1, 8, 0);
+ return version;
+}
+
+const ApplicationVersion& ApplicationVersion::PARQUET_816_FIXED_VERSION() {
+ static ApplicationVersion version("parquet-mr", 1, 2, 9);
+ return version;
+}
+
+const ApplicationVersion& ApplicationVersion::PARQUET_CPP_FIXED_STATS_VERSION() {
+ static ApplicationVersion version("parquet-cpp", 1, 3, 0);
+ return version;
+}
+
+const ApplicationVersion& ApplicationVersion::PARQUET_MR_FIXED_STATS_VERSION() {
+ static ApplicationVersion version("parquet-mr", 1, 10, 0);
+ return version;
+}
+
+std::string ParquetVersionToString(ParquetVersion::type ver) {
+ switch (ver) {
+ case ParquetVersion::PARQUET_1_0:
+ return "1.0";
+ ARROW_SUPPRESS_DEPRECATION_WARNING
+ case ParquetVersion::PARQUET_2_0:
+ return "pseudo-2.0";
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
+ case ParquetVersion::PARQUET_2_4:
+ return "2.4";
+ case ParquetVersion::PARQUET_2_6:
+ return "2.6";
+ }
+
+ // This should be unreachable
+ return "UNKNOWN";
+}
+
+template <typename DType>
+static std::shared_ptr<Statistics> MakeTypedColumnStats(
+ const format::ColumnMetaData& metadata, const ColumnDescriptor* descr) {
+ // If ColumnOrder is defined, return max_value and min_value
+ if (descr->column_order().get_order() == ColumnOrder::TYPE_DEFINED_ORDER) {
+ return MakeStatistics<DType>(
+ descr, metadata.statistics.min_value, metadata.statistics.max_value,
+ metadata.num_values - metadata.statistics.null_count,
+ metadata.statistics.null_count, metadata.statistics.distinct_count,
+ metadata.statistics.__isset.max_value || metadata.statistics.__isset.min_value,
+ metadata.statistics.__isset.null_count,
+ metadata.statistics.__isset.distinct_count);
+ }
+ // Default behavior
+ return MakeStatistics<DType>(
+ descr, metadata.statistics.min, metadata.statistics.max,
+ metadata.num_values - metadata.statistics.null_count,
+ metadata.statistics.null_count, metadata.statistics.distinct_count,
+ metadata.statistics.__isset.max || metadata.statistics.__isset.min,
+ metadata.statistics.__isset.null_count, metadata.statistics.__isset.distinct_count);
+}
+
+std::shared_ptr<Statistics> MakeColumnStats(const format::ColumnMetaData& meta_data,
+ const ColumnDescriptor* descr) {
+ switch (static_cast<Type::type>(meta_data.type)) {
+ case Type::BOOLEAN:
+ return MakeTypedColumnStats<BooleanType>(meta_data, descr);
+ case Type::INT32:
+ return MakeTypedColumnStats<Int32Type>(meta_data, descr);
+ case Type::INT64:
+ return MakeTypedColumnStats<Int64Type>(meta_data, descr);
+ case Type::INT96:
+ return MakeTypedColumnStats<Int96Type>(meta_data, descr);
+ case Type::DOUBLE:
+ return MakeTypedColumnStats<DoubleType>(meta_data, descr);
+ case Type::FLOAT:
+ return MakeTypedColumnStats<FloatType>(meta_data, descr);
+ case Type::BYTE_ARRAY:
+ return MakeTypedColumnStats<ByteArrayType>(meta_data, descr);
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return MakeTypedColumnStats<FLBAType>(meta_data, descr);
+ case Type::UNDEFINED:
+ break;
+ }
+ throw ParquetException("Can't decode page statistics for selected column type");
+}
+
+// MetaData Accessor
+
+// ColumnCryptoMetaData
+class ColumnCryptoMetaData::ColumnCryptoMetaDataImpl {
+ public:
+ explicit ColumnCryptoMetaDataImpl(const format::ColumnCryptoMetaData* crypto_metadata)
+ : crypto_metadata_(crypto_metadata) {}
+
+ bool encrypted_with_footer_key() const {
+ return crypto_metadata_->__isset.ENCRYPTION_WITH_FOOTER_KEY;
+ }
+ bool encrypted_with_column_key() const {
+ return crypto_metadata_->__isset.ENCRYPTION_WITH_COLUMN_KEY;
+ }
+ std::shared_ptr<schema::ColumnPath> path_in_schema() const {
+ return std::make_shared<schema::ColumnPath>(
+ crypto_metadata_->ENCRYPTION_WITH_COLUMN_KEY.path_in_schema);
+ }
+ const std::string& key_metadata() const {
+ return crypto_metadata_->ENCRYPTION_WITH_COLUMN_KEY.key_metadata;
+ }
+
+ private:
+ const format::ColumnCryptoMetaData* crypto_metadata_;
+};
+
+std::unique_ptr<ColumnCryptoMetaData> ColumnCryptoMetaData::Make(
+ const uint8_t* metadata) {
+ return std::unique_ptr<ColumnCryptoMetaData>(new ColumnCryptoMetaData(metadata));
+}
+
+ColumnCryptoMetaData::ColumnCryptoMetaData(const uint8_t* metadata)
+ : impl_(new ColumnCryptoMetaDataImpl(
+ reinterpret_cast<const format::ColumnCryptoMetaData*>(metadata))) {}
+
+ColumnCryptoMetaData::~ColumnCryptoMetaData() = default;
+
+std::shared_ptr<schema::ColumnPath> ColumnCryptoMetaData::path_in_schema() const {
+ return impl_->path_in_schema();
+}
+bool ColumnCryptoMetaData::encrypted_with_footer_key() const {
+ return impl_->encrypted_with_footer_key();
+}
+const std::string& ColumnCryptoMetaData::key_metadata() const {
+ return impl_->key_metadata();
+}
+
+// ColumnChunk metadata
+class ColumnChunkMetaData::ColumnChunkMetaDataImpl {
+ public:
+ explicit ColumnChunkMetaDataImpl(const format::ColumnChunk* column,
+ const ColumnDescriptor* descr,
+ int16_t row_group_ordinal, int16_t column_ordinal,
+ const ApplicationVersion* writer_version,
+ std::shared_ptr<InternalFileDecryptor> file_decryptor)
+ : column_(column), descr_(descr), writer_version_(writer_version) {
+ column_metadata_ = &column->meta_data;
+ if (column->__isset.crypto_metadata) { // column metadata is encrypted
+ format::ColumnCryptoMetaData ccmd = column->crypto_metadata;
+
+ if (ccmd.__isset.ENCRYPTION_WITH_COLUMN_KEY) {
+ if (file_decryptor != nullptr && file_decryptor->properties() != nullptr) {
+ // should decrypt metadata
+ std::shared_ptr<schema::ColumnPath> path = std::make_shared<schema::ColumnPath>(
+ ccmd.ENCRYPTION_WITH_COLUMN_KEY.path_in_schema);
+ std::string key_metadata = ccmd.ENCRYPTION_WITH_COLUMN_KEY.key_metadata;
+
+ std::string aad_column_metadata = encryption::CreateModuleAad(
+ file_decryptor->file_aad(), encryption::kColumnMetaData, row_group_ordinal,
+ column_ordinal, static_cast<int16_t>(-1));
+ auto decryptor = file_decryptor->GetColumnMetaDecryptor(
+ path->ToDotString(), key_metadata, aad_column_metadata);
+ auto len = static_cast<uint32_t>(column->encrypted_column_metadata.size());
+ DeserializeThriftMsg(
+ reinterpret_cast<const uint8_t*>(column->encrypted_column_metadata.c_str()),
+ &len, &decrypted_metadata_, decryptor);
+ column_metadata_ = &decrypted_metadata_;
+ } else {
+ throw ParquetException(
+ "Cannot decrypt ColumnMetadata."
+ " FileDecryption is not setup correctly");
+ }
+ }
+ }
+ for (const auto& encoding : column_metadata_->encodings) {
+ encodings_.push_back(LoadEnumSafe(&encoding));
+ }
+ for (const auto& encoding_stats : column_metadata_->encoding_stats) {
+ encoding_stats_.push_back({LoadEnumSafe(&encoding_stats.page_type),
+ LoadEnumSafe(&encoding_stats.encoding),
+ encoding_stats.count});
+ }
+ possible_stats_ = nullptr;
+ }
+
+ bool Equals(const ColumnChunkMetaDataImpl& other) const {
+ return *column_metadata_ == *other.column_metadata_;
+ }
+
+ // column chunk
+ inline int64_t file_offset() const { return column_->file_offset; }
+ inline const std::string& file_path() const { return column_->file_path; }
+
+ inline Type::type type() const { return LoadEnumSafe(&column_metadata_->type); }
+
+ inline int64_t num_values() const { return column_metadata_->num_values; }
+
+ std::shared_ptr<schema::ColumnPath> path_in_schema() {
+ return std::make_shared<schema::ColumnPath>(column_metadata_->path_in_schema);
+ }
+
+ // Check if statistics are set and are valid
+ // 1) Must be set in the metadata
+ // 2) Statistics must not be corrupted
+ inline bool is_stats_set() const {
+ DCHECK(writer_version_ != nullptr);
+ // If the column statistics don't exist or column sort order is unknown
+ // we cannot use the column stats
+ if (!column_metadata_->__isset.statistics ||
+ descr_->sort_order() == SortOrder::UNKNOWN) {
+ return false;
+ }
+ if (possible_stats_ == nullptr) {
+ possible_stats_ = MakeColumnStats(*column_metadata_, descr_);
+ }
+ EncodedStatistics encodedStatistics = possible_stats_->Encode();
+ return writer_version_->HasCorrectStatistics(type(), encodedStatistics,
+ descr_->sort_order());
+ }
+
+ inline std::shared_ptr<Statistics> statistics() const {
+ return is_stats_set() ? possible_stats_ : nullptr;
+ }
+
+ inline Compression::type compression() const {
+ return LoadEnumSafe(&column_metadata_->codec);
+ }
+
+ const std::vector<Encoding::type>& encodings() const { return encodings_; }
+
+ const std::vector<PageEncodingStats>& encoding_stats() const { return encoding_stats_; }
+
+ inline bool has_dictionary_page() const {
+ return column_metadata_->__isset.dictionary_page_offset;
+ }
+
+ inline int64_t dictionary_page_offset() const {
+ return column_metadata_->dictionary_page_offset;
+ }
+
+ inline int64_t data_page_offset() const { return column_metadata_->data_page_offset; }
+
+ inline bool has_index_page() const {
+ return column_metadata_->__isset.index_page_offset;
+ }
+
+ inline int64_t index_page_offset() const { return column_metadata_->index_page_offset; }
+
+ inline int64_t total_compressed_size() const {
+ return column_metadata_->total_compressed_size;
+ }
+
+ inline int64_t total_uncompressed_size() const {
+ return column_metadata_->total_uncompressed_size;
+ }
+
+ inline std::unique_ptr<ColumnCryptoMetaData> crypto_metadata() const {
+ if (column_->__isset.crypto_metadata) {
+ return ColumnCryptoMetaData::Make(
+ reinterpret_cast<const uint8_t*>(&column_->crypto_metadata));
+ } else {
+ return nullptr;
+ }
+ }
+
+ private:
+ mutable std::shared_ptr<Statistics> possible_stats_;
+ std::vector<Encoding::type> encodings_;
+ std::vector<PageEncodingStats> encoding_stats_;
+ const format::ColumnChunk* column_;
+ const format::ColumnMetaData* column_metadata_;
+ format::ColumnMetaData decrypted_metadata_;
+ const ColumnDescriptor* descr_;
+ const ApplicationVersion* writer_version_;
+};
+
+std::unique_ptr<ColumnChunkMetaData> ColumnChunkMetaData::Make(
+ const void* metadata, const ColumnDescriptor* descr,
+ const ApplicationVersion* writer_version, int16_t row_group_ordinal,
+ int16_t column_ordinal, std::shared_ptr<InternalFileDecryptor> file_decryptor) {
+ return std::unique_ptr<ColumnChunkMetaData>(
+ new ColumnChunkMetaData(metadata, descr, row_group_ordinal, column_ordinal,
+ writer_version, std::move(file_decryptor)));
+}
+
+ColumnChunkMetaData::ColumnChunkMetaData(
+ const void* metadata, const ColumnDescriptor* descr, int16_t row_group_ordinal,
+ int16_t column_ordinal, const ApplicationVersion* writer_version,
+ std::shared_ptr<InternalFileDecryptor> file_decryptor)
+ : impl_{new ColumnChunkMetaDataImpl(
+ reinterpret_cast<const format::ColumnChunk*>(metadata), descr,
+ row_group_ordinal, column_ordinal, writer_version, std::move(file_decryptor))} {
+}
+
+ColumnChunkMetaData::~ColumnChunkMetaData() = default;
+
+// column chunk
+int64_t ColumnChunkMetaData::file_offset() const { return impl_->file_offset(); }
+
+const std::string& ColumnChunkMetaData::file_path() const { return impl_->file_path(); }
+
+Type::type ColumnChunkMetaData::type() const { return impl_->type(); }
+
+int64_t ColumnChunkMetaData::num_values() const { return impl_->num_values(); }
+
+std::shared_ptr<schema::ColumnPath> ColumnChunkMetaData::path_in_schema() const {
+ return impl_->path_in_schema();
+}
+
+std::shared_ptr<Statistics> ColumnChunkMetaData::statistics() const {
+ return impl_->statistics();
+}
+
+bool ColumnChunkMetaData::is_stats_set() const { return impl_->is_stats_set(); }
+
+bool ColumnChunkMetaData::has_dictionary_page() const {
+ return impl_->has_dictionary_page();
+}
+
+int64_t ColumnChunkMetaData::dictionary_page_offset() const {
+ return impl_->dictionary_page_offset();
+}
+
+int64_t ColumnChunkMetaData::data_page_offset() const {
+ return impl_->data_page_offset();
+}
+
+bool ColumnChunkMetaData::has_index_page() const { return impl_->has_index_page(); }
+
+int64_t ColumnChunkMetaData::index_page_offset() const {
+ return impl_->index_page_offset();
+}
+
+Compression::type ColumnChunkMetaData::compression() const {
+ return impl_->compression();
+}
+
+bool ColumnChunkMetaData::can_decompress() const {
+ return ::arrow::util::Codec::IsAvailable(compression());
+}
+
+const std::vector<Encoding::type>& ColumnChunkMetaData::encodings() const {
+ return impl_->encodings();
+}
+
+const std::vector<PageEncodingStats>& ColumnChunkMetaData::encoding_stats() const {
+ return impl_->encoding_stats();
+}
+
+int64_t ColumnChunkMetaData::total_uncompressed_size() const {
+ return impl_->total_uncompressed_size();
+}
+
+int64_t ColumnChunkMetaData::total_compressed_size() const {
+ return impl_->total_compressed_size();
+}
+
+std::unique_ptr<ColumnCryptoMetaData> ColumnChunkMetaData::crypto_metadata() const {
+ return impl_->crypto_metadata();
+}
+
+bool ColumnChunkMetaData::Equals(const ColumnChunkMetaData& other) const {
+ return impl_->Equals(*other.impl_);
+}
+
+// row-group metadata
+class RowGroupMetaData::RowGroupMetaDataImpl {
+ public:
+ explicit RowGroupMetaDataImpl(const format::RowGroup* row_group,
+ const SchemaDescriptor* schema,
+ const ApplicationVersion* writer_version,
+ std::shared_ptr<InternalFileDecryptor> file_decryptor)
+ : row_group_(row_group),
+ schema_(schema),
+ writer_version_(writer_version),
+ file_decryptor_(std::move(file_decryptor)) {}
+
+ bool Equals(const RowGroupMetaDataImpl& other) const {
+ return *row_group_ == *other.row_group_;
+ }
+
+ inline int num_columns() const { return static_cast<int>(row_group_->columns.size()); }
+
+ inline int64_t num_rows() const { return row_group_->num_rows; }
+
+ inline int64_t total_byte_size() const { return row_group_->total_byte_size; }
+
+ inline int64_t total_compressed_size() const {
+ return row_group_->total_compressed_size;
+ }
+
+ inline int64_t file_offset() const { return row_group_->file_offset; }
+
+ inline const SchemaDescriptor* schema() const { return schema_; }
+
+ std::unique_ptr<ColumnChunkMetaData> ColumnChunk(int i) {
+ if (i < num_columns()) {
+ return ColumnChunkMetaData::Make(&row_group_->columns[i], schema_->Column(i),
+ writer_version_, row_group_->ordinal,
+ static_cast<int16_t>(i), file_decryptor_);
+ }
+ throw ParquetException("The file only has ", num_columns(),
+ " columns, requested metadata for column: ", i);
+ }
+
+ private:
+ const format::RowGroup* row_group_;
+ const SchemaDescriptor* schema_;
+ const ApplicationVersion* writer_version_;
+ std::shared_ptr<InternalFileDecryptor> file_decryptor_;
+};
+
+std::unique_ptr<RowGroupMetaData> RowGroupMetaData::Make(
+ const void* metadata, const SchemaDescriptor* schema,
+ const ApplicationVersion* writer_version,
+ std::shared_ptr<InternalFileDecryptor> file_decryptor) {
+ return std::unique_ptr<RowGroupMetaData>(
+ new RowGroupMetaData(metadata, schema, writer_version, std::move(file_decryptor)));
+}
+
+RowGroupMetaData::RowGroupMetaData(const void* metadata, const SchemaDescriptor* schema,
+ const ApplicationVersion* writer_version,
+ std::shared_ptr<InternalFileDecryptor> file_decryptor)
+ : impl_{new RowGroupMetaDataImpl(reinterpret_cast<const format::RowGroup*>(metadata),
+ schema, writer_version, std::move(file_decryptor))} {
+}
+
+RowGroupMetaData::~RowGroupMetaData() = default;
+
+bool RowGroupMetaData::Equals(const RowGroupMetaData& other) const {
+ return impl_->Equals(*other.impl_);
+}
+
+int RowGroupMetaData::num_columns() const { return impl_->num_columns(); }
+
+int64_t RowGroupMetaData::num_rows() const { return impl_->num_rows(); }
+
+int64_t RowGroupMetaData::total_byte_size() const { return impl_->total_byte_size(); }
+
+int64_t RowGroupMetaData::total_compressed_size() const {
+ return impl_->total_compressed_size();
+}
+
+int64_t RowGroupMetaData::file_offset() const { return impl_->file_offset(); }
+
+const SchemaDescriptor* RowGroupMetaData::schema() const { return impl_->schema(); }
+
+std::unique_ptr<ColumnChunkMetaData> RowGroupMetaData::ColumnChunk(int i) const {
+ return impl_->ColumnChunk(i);
+}
+
+bool RowGroupMetaData::can_decompress() const {
+ int n_columns = num_columns();
+ for (int i = 0; i < n_columns; i++) {
+ if (!ColumnChunk(i)->can_decompress()) {
+ return false;
+ }
+ }
+ return true;
+}
+
+// file metadata
+class FileMetaData::FileMetaDataImpl {
+ public:
+ FileMetaDataImpl() = default;
+
+ explicit FileMetaDataImpl(
+ const void* metadata, uint32_t* metadata_len,
+ std::shared_ptr<InternalFileDecryptor> file_decryptor = nullptr)
+ : file_decryptor_(file_decryptor) {
+ metadata_.reset(new format::FileMetaData);
+
+ auto footer_decryptor =
+ file_decryptor_ != nullptr ? file_decryptor->GetFooterDecryptor() : nullptr;
+
+ DeserializeThriftMsg(reinterpret_cast<const uint8_t*>(metadata), metadata_len,
+ metadata_.get(), footer_decryptor);
+ metadata_len_ = *metadata_len;
+
+ if (metadata_->__isset.created_by) {
+ writer_version_ = ApplicationVersion(metadata_->created_by);
+ } else {
+ writer_version_ = ApplicationVersion("unknown 0.0.0");
+ }
+
+ InitSchema();
+ InitColumnOrders();
+ InitKeyValueMetadata();
+ }
+
+ bool VerifySignature(const void* signature) {
+ // verify decryption properties are set
+ if (file_decryptor_ == nullptr) {
+ throw ParquetException("Decryption not set properly. cannot verify signature");
+ }
+ // serialize the footer
+ uint8_t* serialized_data;
+ uint32_t serialized_len = metadata_len_;
+ ThriftSerializer serializer;
+ serializer.SerializeToBuffer(metadata_.get(), &serialized_len, &serialized_data);
+
+ // encrypt with nonce
+ auto nonce = const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(signature));
+ auto tag = const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(signature)) +
+ encryption::kNonceLength;
+
+ std::string key = file_decryptor_->GetFooterKey();
+ std::string aad = encryption::CreateFooterAad(file_decryptor_->file_aad());
+
+ auto aes_encryptor = encryption::AesEncryptor::Make(
+ file_decryptor_->algorithm(), static_cast<int>(key.size()), true, nullptr);
+
+ std::shared_ptr<Buffer> encrypted_buffer = std::static_pointer_cast<ResizableBuffer>(
+ AllocateBuffer(file_decryptor_->pool(),
+ aes_encryptor->CiphertextSizeDelta() + serialized_len));
+ uint32_t encrypted_len = aes_encryptor->SignedFooterEncrypt(
+ serialized_data, serialized_len, str2bytes(key), static_cast<int>(key.size()),
+ str2bytes(aad), static_cast<int>(aad.size()), nonce,
+ encrypted_buffer->mutable_data());
+ // Delete AES encryptor object. It was created only to verify the footer signature.
+ aes_encryptor->WipeOut();
+ delete aes_encryptor;
+ return 0 ==
+ memcmp(encrypted_buffer->data() + encrypted_len - encryption::kGcmTagLength,
+ tag, encryption::kGcmTagLength);
+ }
+
+ inline uint32_t size() const { return metadata_len_; }
+ inline int num_columns() const { return schema_.num_columns(); }
+ inline int64_t num_rows() const { return metadata_->num_rows; }
+ inline int num_row_groups() const {
+ return static_cast<int>(metadata_->row_groups.size());
+ }
+ inline int32_t version() const { return metadata_->version; }
+ inline const std::string& created_by() const { return metadata_->created_by; }
+ inline int num_schema_elements() const {
+ return static_cast<int>(metadata_->schema.size());
+ }
+
+ inline bool is_encryption_algorithm_set() const {
+ return metadata_->__isset.encryption_algorithm;
+ }
+ inline EncryptionAlgorithm encryption_algorithm() {
+ return FromThrift(metadata_->encryption_algorithm);
+ }
+ inline const std::string& footer_signing_key_metadata() {
+ return metadata_->footer_signing_key_metadata;
+ }
+
+ const ApplicationVersion& writer_version() const { return writer_version_; }
+
+ void WriteTo(::arrow::io::OutputStream* dst,
+ const std::shared_ptr<Encryptor>& encryptor) const {
+ ThriftSerializer serializer;
+ // Only in encrypted files with plaintext footers the
+ // encryption_algorithm is set in footer
+ if (is_encryption_algorithm_set()) {
+ uint8_t* serialized_data;
+ uint32_t serialized_len;
+ serializer.SerializeToBuffer(metadata_.get(), &serialized_len, &serialized_data);
+
+ // encrypt the footer key
+ std::vector<uint8_t> encrypted_data(encryptor->CiphertextSizeDelta() +
+ serialized_len);
+ unsigned encrypted_len =
+ encryptor->Encrypt(serialized_data, serialized_len, encrypted_data.data());
+
+ // write unencrypted footer
+ PARQUET_THROW_NOT_OK(dst->Write(serialized_data, serialized_len));
+ // Write signature (nonce and tag)
+ PARQUET_THROW_NOT_OK(
+ dst->Write(encrypted_data.data() + 4, encryption::kNonceLength));
+ PARQUET_THROW_NOT_OK(
+ dst->Write(encrypted_data.data() + encrypted_len - encryption::kGcmTagLength,
+ encryption::kGcmTagLength));
+ } else { // either plaintext file (when encryptor is null)
+ // or encrypted file with encrypted footer
+ serializer.Serialize(metadata_.get(), dst, encryptor);
+ }
+ }
+
+ std::unique_ptr<RowGroupMetaData> RowGroup(int i) {
+ if (!(i < num_row_groups())) {
+ std::stringstream ss;
+ ss << "The file only has " << num_row_groups()
+ << " row groups, requested metadata for row group: " << i;
+ throw ParquetException(ss.str());
+ }
+ return RowGroupMetaData::Make(&metadata_->row_groups[i], &schema_, &writer_version_,
+ file_decryptor_);
+ }
+
+ bool Equals(const FileMetaDataImpl& other) const {
+ return *metadata_ == *other.metadata_;
+ }
+
+ const SchemaDescriptor* schema() const { return &schema_; }
+
+ const std::shared_ptr<const KeyValueMetadata>& key_value_metadata() const {
+ return key_value_metadata_;
+ }
+
+ void set_file_path(const std::string& path) {
+ for (format::RowGroup& row_group : metadata_->row_groups) {
+ for (format::ColumnChunk& chunk : row_group.columns) {
+ chunk.__set_file_path(path);
+ }
+ }
+ }
+
+ format::RowGroup& row_group(int i) {
+ DCHECK_LT(i, num_row_groups());
+ return metadata_->row_groups[i];
+ }
+
+ void AppendRowGroups(const std::unique_ptr<FileMetaDataImpl>& other) {
+ if (!schema()->Equals(*other->schema())) {
+ throw ParquetException("AppendRowGroups requires equal schemas.");
+ }
+
+ // ARROW-13654: `other` may point to self, be careful not to enter an infinite loop
+ const int n = other->num_row_groups();
+ metadata_->row_groups.reserve(metadata_->row_groups.size() + n);
+ for (int i = 0; i < n; i++) {
+ format::RowGroup other_rg = other->row_group(i);
+ metadata_->num_rows += other_rg.num_rows;
+ metadata_->row_groups.push_back(std::move(other_rg));
+ }
+ }
+
+ std::shared_ptr<FileMetaData> Subset(const std::vector<int>& row_groups) {
+ for (int i : row_groups) {
+ if (i < num_row_groups()) continue;
+
+ throw ParquetException(
+ "The file only has ", num_row_groups(),
+ " row groups, but requested a subset including row group: ", i);
+ }
+
+ std::shared_ptr<FileMetaData> out(new FileMetaData());
+ out->impl_.reset(new FileMetaDataImpl());
+ out->impl_->metadata_.reset(new format::FileMetaData());
+
+ auto metadata = out->impl_->metadata_.get();
+ metadata->version = metadata_->version;
+ metadata->schema = metadata_->schema;
+
+ metadata->row_groups.resize(row_groups.size());
+ int i = 0;
+ for (int selected_index : row_groups) {
+ metadata->num_rows += row_group(selected_index).num_rows;
+ metadata->row_groups[i++] = row_group(selected_index);
+ }
+
+ metadata->key_value_metadata = metadata_->key_value_metadata;
+ metadata->created_by = metadata_->created_by;
+ metadata->column_orders = metadata_->column_orders;
+ metadata->encryption_algorithm = metadata_->encryption_algorithm;
+ metadata->footer_signing_key_metadata = metadata_->footer_signing_key_metadata;
+ metadata->__isset = metadata_->__isset;
+
+ out->impl_->schema_ = schema_;
+ out->impl_->writer_version_ = writer_version_;
+ out->impl_->key_value_metadata_ = key_value_metadata_;
+ out->impl_->file_decryptor_ = file_decryptor_;
+
+ return out;
+ }
+
+ void set_file_decryptor(std::shared_ptr<InternalFileDecryptor> file_decryptor) {
+ file_decryptor_ = file_decryptor;
+ }
+
+ private:
+ friend FileMetaDataBuilder;
+ uint32_t metadata_len_ = 0;
+ std::unique_ptr<format::FileMetaData> metadata_;
+ SchemaDescriptor schema_;
+ ApplicationVersion writer_version_;
+ std::shared_ptr<const KeyValueMetadata> key_value_metadata_;
+ std::shared_ptr<InternalFileDecryptor> file_decryptor_;
+
+ void InitSchema() {
+ if (metadata_->schema.empty()) {
+ throw ParquetException("Empty file schema (no root)");
+ }
+ schema_.Init(schema::Unflatten(&metadata_->schema[0],
+ static_cast<int>(metadata_->schema.size())));
+ }
+
+ void InitColumnOrders() {
+ // update ColumnOrder
+ std::vector<parquet::ColumnOrder> column_orders;
+ if (metadata_->__isset.column_orders) {
+ for (auto column_order : metadata_->column_orders) {
+ if (column_order.__isset.TYPE_ORDER) {
+ column_orders.push_back(ColumnOrder::type_defined_);
+ } else {
+ column_orders.push_back(ColumnOrder::undefined_);
+ }
+ }
+ } else {
+ column_orders.resize(schema_.num_columns(), ColumnOrder::undefined_);
+ }
+
+ schema_.updateColumnOrders(column_orders);
+ }
+
+ void InitKeyValueMetadata() {
+ std::shared_ptr<KeyValueMetadata> metadata = nullptr;
+ if (metadata_->__isset.key_value_metadata) {
+ metadata = std::make_shared<KeyValueMetadata>();
+ for (const auto& it : metadata_->key_value_metadata) {
+ metadata->Append(it.key, it.value);
+ }
+ }
+ key_value_metadata_ = std::move(metadata);
+ }
+};
+
+std::shared_ptr<FileMetaData> FileMetaData::Make(
+ const void* metadata, uint32_t* metadata_len,
+ std::shared_ptr<InternalFileDecryptor> file_decryptor) {
+ // This FileMetaData ctor is private, not compatible with std::make_shared
+ return std::shared_ptr<FileMetaData>(
+ new FileMetaData(metadata, metadata_len, file_decryptor));
+}
+
+FileMetaData::FileMetaData(const void* metadata, uint32_t* metadata_len,
+ std::shared_ptr<InternalFileDecryptor> file_decryptor)
+ : impl_{std::unique_ptr<FileMetaDataImpl>(
+ new FileMetaDataImpl(metadata, metadata_len, file_decryptor))} {}
+
+FileMetaData::FileMetaData()
+ : impl_{std::unique_ptr<FileMetaDataImpl>(new FileMetaDataImpl())} {}
+
+FileMetaData::~FileMetaData() = default;
+
+bool FileMetaData::Equals(const FileMetaData& other) const {
+ return impl_->Equals(*other.impl_);
+}
+
+std::unique_ptr<RowGroupMetaData> FileMetaData::RowGroup(int i) const {
+ return impl_->RowGroup(i);
+}
+
+bool FileMetaData::VerifySignature(const void* signature) {
+ return impl_->VerifySignature(signature);
+}
+
+uint32_t FileMetaData::size() const { return impl_->size(); }
+
+int FileMetaData::num_columns() const { return impl_->num_columns(); }
+
+int64_t FileMetaData::num_rows() const { return impl_->num_rows(); }
+
+int FileMetaData::num_row_groups() const { return impl_->num_row_groups(); }
+
+bool FileMetaData::can_decompress() const {
+ int n_row_groups = num_row_groups();
+ for (int i = 0; i < n_row_groups; i++) {
+ if (!RowGroup(i)->can_decompress()) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool FileMetaData::is_encryption_algorithm_set() const {
+ return impl_->is_encryption_algorithm_set();
+}
+
+EncryptionAlgorithm FileMetaData::encryption_algorithm() const {
+ return impl_->encryption_algorithm();
+}
+
+const std::string& FileMetaData::footer_signing_key_metadata() const {
+ return impl_->footer_signing_key_metadata();
+}
+
+void FileMetaData::set_file_decryptor(
+ std::shared_ptr<InternalFileDecryptor> file_decryptor) {
+ impl_->set_file_decryptor(file_decryptor);
+}
+
+ParquetVersion::type FileMetaData::version() const {
+ switch (impl_->version()) {
+ case 1:
+ return ParquetVersion::PARQUET_1_0;
+ case 2:
+ return ParquetVersion::PARQUET_2_LATEST;
+ default:
+ // Improperly set version, assuming Parquet 1.0
+ break;
+ }
+ return ParquetVersion::PARQUET_1_0;
+}
+
+const ApplicationVersion& FileMetaData::writer_version() const {
+ return impl_->writer_version();
+}
+
+const std::string& FileMetaData::created_by() const { return impl_->created_by(); }
+
+int FileMetaData::num_schema_elements() const { return impl_->num_schema_elements(); }
+
+const SchemaDescriptor* FileMetaData::schema() const { return impl_->schema(); }
+
+const std::shared_ptr<const KeyValueMetadata>& FileMetaData::key_value_metadata() const {
+ return impl_->key_value_metadata();
+}
+
+void FileMetaData::set_file_path(const std::string& path) { impl_->set_file_path(path); }
+
+void FileMetaData::AppendRowGroups(const FileMetaData& other) {
+ impl_->AppendRowGroups(other.impl_);
+}
+
+std::shared_ptr<FileMetaData> FileMetaData::Subset(
+ const std::vector<int>& row_groups) const {
+ return impl_->Subset(row_groups);
+}
+
+void FileMetaData::WriteTo(::arrow::io::OutputStream* dst,
+ const std::shared_ptr<Encryptor>& encryptor) const {
+ return impl_->WriteTo(dst, encryptor);
+}
+
+class FileCryptoMetaData::FileCryptoMetaDataImpl {
+ public:
+ FileCryptoMetaDataImpl() = default;
+
+ explicit FileCryptoMetaDataImpl(const uint8_t* metadata, uint32_t* metadata_len) {
+ metadata_.reset(new format::FileCryptoMetaData);
+ DeserializeThriftMsg(metadata, metadata_len, metadata_.get());
+ metadata_len_ = *metadata_len;
+ }
+
+ EncryptionAlgorithm encryption_algorithm() {
+ return FromThrift(metadata_->encryption_algorithm);
+ }
+ const std::string& key_metadata() { return metadata_->key_metadata; }
+ void WriteTo(::arrow::io::OutputStream* dst) const {
+ ThriftSerializer serializer;
+ serializer.Serialize(metadata_.get(), dst);
+ }
+
+ private:
+ friend FileMetaDataBuilder;
+ std::unique_ptr<format::FileCryptoMetaData> metadata_;
+ uint32_t metadata_len_;
+};
+
+EncryptionAlgorithm FileCryptoMetaData::encryption_algorithm() const {
+ return impl_->encryption_algorithm();
+}
+
+const std::string& FileCryptoMetaData::key_metadata() const {
+ return impl_->key_metadata();
+}
+
+std::shared_ptr<FileCryptoMetaData> FileCryptoMetaData::Make(
+ const uint8_t* serialized_metadata, uint32_t* metadata_len) {
+ return std::shared_ptr<FileCryptoMetaData>(
+ new FileCryptoMetaData(serialized_metadata, metadata_len));
+}
+
+FileCryptoMetaData::FileCryptoMetaData(const uint8_t* serialized_metadata,
+ uint32_t* metadata_len)
+ : impl_(new FileCryptoMetaDataImpl(serialized_metadata, metadata_len)) {}
+
+FileCryptoMetaData::FileCryptoMetaData() : impl_(new FileCryptoMetaDataImpl()) {}
+
+FileCryptoMetaData::~FileCryptoMetaData() = default;
+
+void FileCryptoMetaData::WriteTo(::arrow::io::OutputStream* dst) const {
+ impl_->WriteTo(dst);
+}
+
+std::string FileMetaData::SerializeToString() const {
+ // We need to pass in an initial size. Since it will automatically
+ // increase the buffer size to hold the metadata, we just leave it 0.
+ PARQUET_ASSIGN_OR_THROW(auto serializer, ::arrow::io::BufferOutputStream::Create(0));
+ WriteTo(serializer.get());
+ PARQUET_ASSIGN_OR_THROW(auto metadata_buffer, serializer->Finish());
+ return metadata_buffer->ToString();
+}
+
+ApplicationVersion::ApplicationVersion(std::string application, int major, int minor,
+ int patch)
+ : application_(std::move(application)), version{major, minor, patch, "", "", ""} {}
+
+namespace {
+// Parse the application version format and set parsed values to
+// ApplicationVersion.
+//
+// The application version format must be compatible parquet-mr's
+// one. See also:
+// * https://github.com/apache/parquet-mr/blob/master/parquet-common/src/main/java/org/apache/parquet/VersionParser.java
+// * https://github.com/apache/parquet-mr/blob/master/parquet-common/src/main/java/org/apache/parquet/SemanticVersion.java
+//
+// The application version format:
+// "${APPLICATION_NAME}"
+// "${APPLICATION_NAME} version ${VERSION}"
+// "${APPLICATION_NAME} version ${VERSION} (build ${BUILD_NAME})"
+//
+// Eg:
+// parquet-cpp
+// parquet-cpp version 1.5.0ab-xyz5.5.0+cd
+// parquet-cpp version 1.5.0ab-xyz5.5.0+cd (build abcd)
+//
+// The VERSION format:
+// "${MAJOR}"
+// "${MAJOR}.${MINOR}"
+// "${MAJOR}.${MINOR}.${PATCH}"
+// "${MAJOR}.${MINOR}.${PATCH}${UNKNOWN}"
+// "${MAJOR}.${MINOR}.${PATCH}${UNKNOWN}-${PRE_RELEASE}"
+// "${MAJOR}.${MINOR}.${PATCH}${UNKNOWN}-${PRE_RELEASE}+${BUILD_INFO}"
+// "${MAJOR}.${MINOR}.${PATCH}${UNKNOWN}+${BUILD_INFO}"
+// "${MAJOR}.${MINOR}.${PATCH}-${PRE_RELEASE}"
+// "${MAJOR}.${MINOR}.${PATCH}-${PRE_RELEASE}+${BUILD_INFO}"
+// "${MAJOR}.${MINOR}.${PATCH}+${BUILD_INFO}"
+//
+// Eg:
+// 1
+// 1.5
+// 1.5.0
+// 1.5.0ab
+// 1.5.0ab-cdh5.5.0
+// 1.5.0ab-cdh5.5.0+cd
+// 1.5.0ab+cd
+// 1.5.0-cdh5.5.0
+// 1.5.0-cdh5.5.0+cd
+// 1.5.0+cd
+class ApplicationVersionParser {
+ public:
+ ApplicationVersionParser(const std::string& created_by,
+ ApplicationVersion& application_version)
+ : created_by_(created_by),
+ application_version_(application_version),
+ spaces_(" \t\v\r\n\f"),
+ digits_("0123456789") {}
+
+ void Parse() {
+ application_version_.application_ = "unknown";
+ application_version_.version = {0, 0, 0, "", "", ""};
+
+ if (!ParseApplicationName()) {
+ return;
+ }
+ if (!ParseVersion()) {
+ return;
+ }
+ if (!ParseBuildName()) {
+ return;
+ }
+ }
+
+ private:
+ bool IsSpace(const std::string& string, const size_t& offset) {
+ auto target = ::arrow::util::string_view(string).substr(offset, 1);
+ return target.find_first_of(spaces_) != ::arrow::util::string_view::npos;
+ }
+
+ void RemovePrecedingSpaces(const std::string& string, size_t& start,
+ const size_t& end) {
+ while (start < end && IsSpace(string, start)) {
+ ++start;
+ }
+ }
+
+ void RemoveTrailingSpaces(const std::string& string, const size_t& start, size_t& end) {
+ while (start < (end - 1) && (end - 1) < string.size() && IsSpace(string, end - 1)) {
+ --end;
+ }
+ }
+
+ bool ParseApplicationName() {
+ std::string version_mark(" version ");
+ auto version_mark_position = created_by_.find(version_mark);
+ size_t application_name_end;
+ // No VERSION and BUILD_NAME.
+ if (version_mark_position == std::string::npos) {
+ version_start_ = std::string::npos;
+ application_name_end = created_by_.size();
+ } else {
+ version_start_ = version_mark_position + version_mark.size();
+ application_name_end = version_mark_position;
+ }
+
+ size_t application_name_start = 0;
+ RemovePrecedingSpaces(created_by_, application_name_start, application_name_end);
+ RemoveTrailingSpaces(created_by_, application_name_start, application_name_end);
+ application_version_.application_ = created_by_.substr(
+ application_name_start, application_name_end - application_name_start);
+
+ return true;
+ }
+
+ bool ParseVersion() {
+ // No VERSION.
+ if (version_start_ == std::string::npos) {
+ return false;
+ }
+
+ RemovePrecedingSpaces(created_by_, version_start_, created_by_.size());
+ version_end_ = created_by_.find(" (", version_start_);
+ // No BUILD_NAME.
+ if (version_end_ == std::string::npos) {
+ version_end_ = created_by_.size();
+ }
+ RemoveTrailingSpaces(created_by_, version_start_, version_end_);
+ // No VERSION.
+ if (version_start_ == version_end_) {
+ return false;
+ }
+ version_string_ = created_by_.substr(version_start_, version_end_ - version_start_);
+
+ if (!ParseVersionMajor()) {
+ return false;
+ }
+ if (!ParseVersionMinor()) {
+ return false;
+ }
+ if (!ParseVersionPatch()) {
+ return false;
+ }
+ if (!ParseVersionUnknown()) {
+ return false;
+ }
+ if (!ParseVersionPreRelease()) {
+ return false;
+ }
+ if (!ParseVersionBuildInfo()) {
+ return false;
+ }
+
+ return true;
+ }
+
+ bool ParseVersionMajor() {
+ size_t version_major_start = 0;
+ auto version_major_end = version_string_.find_first_not_of(digits_);
+ // MAJOR only.
+ if (version_major_end == std::string::npos) {
+ version_major_end = version_string_.size();
+ version_parsing_position_ = version_major_end;
+ } else {
+ // No ".".
+ if (version_string_[version_major_end] != '.') {
+ return false;
+ }
+ // No MAJOR.
+ if (version_major_end == version_major_start) {
+ return false;
+ }
+ version_parsing_position_ = version_major_end + 1; // +1 is for '.'.
+ }
+ auto version_major_string = version_string_.substr(
+ version_major_start, version_major_end - version_major_start);
+ application_version_.version.major = atoi(version_major_string.c_str());
+ return true;
+ }
+
+ bool ParseVersionMinor() {
+ auto version_minor_start = version_parsing_position_;
+ auto version_minor_end =
+ version_string_.find_first_not_of(digits_, version_minor_start);
+ // MAJOR.MINOR only.
+ if (version_minor_end == std::string::npos) {
+ version_minor_end = version_string_.size();
+ version_parsing_position_ = version_minor_end;
+ } else {
+ // No ".".
+ if (version_string_[version_minor_end] != '.') {
+ return false;
+ }
+ // No MINOR.
+ if (version_minor_end == version_minor_start) {
+ return false;
+ }
+ version_parsing_position_ = version_minor_end + 1; // +1 is for '.'.
+ }
+ auto version_minor_string = version_string_.substr(
+ version_minor_start, version_minor_end - version_minor_start);
+ application_version_.version.minor = atoi(version_minor_string.c_str());
+ return true;
+ }
+
+ bool ParseVersionPatch() {
+ auto version_patch_start = version_parsing_position_;
+ auto version_patch_end =
+ version_string_.find_first_not_of(digits_, version_patch_start);
+ // No UNKNOWN, PRE_RELEASE and BUILD_INFO.
+ if (version_patch_end == std::string::npos) {
+ version_patch_end = version_string_.size();
+ }
+ // No PATCH.
+ if (version_patch_end == version_patch_start) {
+ return false;
+ }
+ auto version_patch_string = version_string_.substr(
+ version_patch_start, version_patch_end - version_patch_start);
+ application_version_.version.patch = atoi(version_patch_string.c_str());
+ version_parsing_position_ = version_patch_end;
+ return true;
+ }
+
+ bool ParseVersionUnknown() {
+ // No UNKNOWN.
+ if (version_parsing_position_ == version_string_.size()) {
+ return true;
+ }
+ auto version_unknown_start = version_parsing_position_;
+ auto version_unknown_end = version_string_.find_first_of("-+", version_unknown_start);
+ // No PRE_RELEASE and BUILD_INFO
+ if (version_unknown_end == std::string::npos) {
+ version_unknown_end = version_string_.size();
+ }
+ application_version_.version.unknown = version_string_.substr(
+ version_unknown_start, version_unknown_end - version_unknown_start);
+ version_parsing_position_ = version_unknown_end;
+ return true;
+ }
+
+ bool ParseVersionPreRelease() {
+ // No PRE_RELEASE.
+ if (version_parsing_position_ == version_string_.size() ||
+ version_string_[version_parsing_position_] != '-') {
+ return true;
+ }
+
+ auto version_pre_release_start = version_parsing_position_ + 1; // +1 is for '-'.
+ auto version_pre_release_end =
+ version_string_.find_first_of("+", version_pre_release_start);
+ // No BUILD_INFO
+ if (version_pre_release_end == std::string::npos) {
+ version_pre_release_end = version_string_.size();
+ }
+ application_version_.version.pre_release = version_string_.substr(
+ version_pre_release_start, version_pre_release_end - version_pre_release_start);
+ version_parsing_position_ = version_pre_release_end;
+ return true;
+ }
+
+ bool ParseVersionBuildInfo() {
+ // No BUILD_INFO.
+ if (version_parsing_position_ == version_string_.size() ||
+ version_string_[version_parsing_position_] != '+') {
+ return true;
+ }
+
+ auto version_build_info_start = version_parsing_position_ + 1; // +1 is for '+'.
+ application_version_.version.build_info =
+ version_string_.substr(version_build_info_start);
+ return true;
+ }
+
+ bool ParseBuildName() {
+ std::string build_mark(" (build ");
+ auto build_mark_position = created_by_.find(build_mark, version_end_);
+ // No BUILD_NAME.
+ if (build_mark_position == std::string::npos) {
+ return false;
+ }
+ auto build_name_start = build_mark_position + build_mark.size();
+ RemovePrecedingSpaces(created_by_, build_name_start, created_by_.size());
+ auto build_name_end = created_by_.find_first_of(")", build_name_start);
+ // No end ")".
+ if (build_name_end == std::string::npos) {
+ return false;
+ }
+ RemoveTrailingSpaces(created_by_, build_name_start, build_name_end);
+ application_version_.build_ =
+ created_by_.substr(build_name_start, build_name_end - build_name_start);
+
+ return true;
+ }
+
+ const std::string& created_by_;
+ ApplicationVersion& application_version_;
+
+ // For parsing.
+ std::string spaces_;
+ std::string digits_;
+ size_t version_parsing_position_;
+ size_t version_start_;
+ size_t version_end_;
+ std::string version_string_;
+};
+} // namespace
+
+ApplicationVersion::ApplicationVersion(const std::string& created_by) {
+ ApplicationVersionParser parser(created_by, *this);
+ parser.Parse();
+}
+
+bool ApplicationVersion::VersionLt(const ApplicationVersion& other_version) const {
+ if (application_ != other_version.application_) return false;
+
+ if (version.major < other_version.version.major) return true;
+ if (version.major > other_version.version.major) return false;
+ DCHECK_EQ(version.major, other_version.version.major);
+ if (version.minor < other_version.version.minor) return true;
+ if (version.minor > other_version.version.minor) return false;
+ DCHECK_EQ(version.minor, other_version.version.minor);
+ return version.patch < other_version.version.patch;
+}
+
+bool ApplicationVersion::VersionEq(const ApplicationVersion& other_version) const {
+ return application_ == other_version.application_ &&
+ version.major == other_version.version.major &&
+ version.minor == other_version.version.minor &&
+ version.patch == other_version.version.patch;
+}
+
+// Reference:
+// parquet-mr/parquet-column/src/main/java/org/apache/parquet/CorruptStatistics.java
+// PARQUET-686 has more discussion on statistics
+bool ApplicationVersion::HasCorrectStatistics(Type::type col_type,
+ EncodedStatistics& statistics,
+ SortOrder::type sort_order) const {
+ // parquet-cpp version 1.3.0 and parquet-mr 1.10.0 onwards stats are computed
+ // correctly for all types
+ if ((application_ == "parquet-cpp" && VersionLt(PARQUET_CPP_FIXED_STATS_VERSION())) ||
+ (application_ == "parquet-mr" && VersionLt(PARQUET_MR_FIXED_STATS_VERSION()))) {
+ // Only SIGNED are valid unless max and min are the same
+ // (in which case the sort order does not matter)
+ bool max_equals_min = statistics.has_min && statistics.has_max
+ ? statistics.min() == statistics.max()
+ : false;
+ if (SortOrder::SIGNED != sort_order && !max_equals_min) {
+ return false;
+ }
+
+ // Statistics of other types are OK
+ if (col_type != Type::FIXED_LEN_BYTE_ARRAY && col_type != Type::BYTE_ARRAY) {
+ return true;
+ }
+ }
+ // created_by is not populated, which could have been caused by
+ // parquet-mr during the same time as PARQUET-251, see PARQUET-297
+ if (application_ == "unknown") {
+ return true;
+ }
+
+ // Unknown sort order has incorrect stats
+ if (SortOrder::UNKNOWN == sort_order) {
+ return false;
+ }
+
+ // PARQUET-251
+ if (VersionLt(PARQUET_251_FIXED_VERSION())) {
+ return false;
+ }
+
+ return true;
+}
+
+// MetaData Builders
+// row-group metadata
+class ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilderImpl {
+ public:
+ explicit ColumnChunkMetaDataBuilderImpl(std::shared_ptr<WriterProperties> props,
+ const ColumnDescriptor* column)
+ : owned_column_chunk_(new format::ColumnChunk),
+ properties_(std::move(props)),
+ column_(column) {
+ Init(owned_column_chunk_.get());
+ }
+
+ explicit ColumnChunkMetaDataBuilderImpl(std::shared_ptr<WriterProperties> props,
+ const ColumnDescriptor* column,
+ format::ColumnChunk* column_chunk)
+ : properties_(std::move(props)), column_(column) {
+ Init(column_chunk);
+ }
+
+ const void* contents() const { return column_chunk_; }
+
+ // column chunk
+ void set_file_path(const std::string& val) { column_chunk_->__set_file_path(val); }
+
+ // column metadata
+ void SetStatistics(const EncodedStatistics& val) {
+ column_chunk_->meta_data.__set_statistics(ToThrift(val));
+ }
+
+ void Finish(int64_t num_values, int64_t dictionary_page_offset,
+ int64_t index_page_offset, int64_t data_page_offset,
+ int64_t compressed_size, int64_t uncompressed_size, bool has_dictionary,
+ bool dictionary_fallback,
+ const std::map<Encoding::type, int32_t>& dict_encoding_stats,
+ const std::map<Encoding::type, int32_t>& data_encoding_stats,
+ const std::shared_ptr<Encryptor>& encryptor) {
+ if (dictionary_page_offset > 0) {
+ column_chunk_->meta_data.__set_dictionary_page_offset(dictionary_page_offset);
+ column_chunk_->__set_file_offset(dictionary_page_offset + compressed_size);
+ } else {
+ column_chunk_->__set_file_offset(data_page_offset + compressed_size);
+ }
+ column_chunk_->__isset.meta_data = true;
+ column_chunk_->meta_data.__set_num_values(num_values);
+ if (index_page_offset >= 0) {
+ column_chunk_->meta_data.__set_index_page_offset(index_page_offset);
+ }
+ column_chunk_->meta_data.__set_data_page_offset(data_page_offset);
+ column_chunk_->meta_data.__set_total_uncompressed_size(uncompressed_size);
+ column_chunk_->meta_data.__set_total_compressed_size(compressed_size);
+
+ std::vector<format::Encoding::type> thrift_encodings;
+ if (has_dictionary) {
+ thrift_encodings.push_back(ToThrift(properties_->dictionary_index_encoding()));
+ if (properties_->version() == ParquetVersion::PARQUET_1_0) {
+ thrift_encodings.push_back(ToThrift(Encoding::PLAIN));
+ } else {
+ thrift_encodings.push_back(ToThrift(properties_->dictionary_page_encoding()));
+ }
+ } else { // Dictionary not enabled
+ thrift_encodings.push_back(ToThrift(properties_->encoding(column_->path())));
+ }
+ thrift_encodings.push_back(ToThrift(Encoding::RLE));
+ // Only PLAIN encoding is supported for fallback in V1
+ // TODO(majetideepak): Use user specified encoding for V2
+ if (dictionary_fallback) {
+ thrift_encodings.push_back(ToThrift(Encoding::PLAIN));
+ }
+ column_chunk_->meta_data.__set_encodings(thrift_encodings);
+ std::vector<format::PageEncodingStats> thrift_encoding_stats;
+ // Add dictionary page encoding stats
+ for (const auto& entry : dict_encoding_stats) {
+ format::PageEncodingStats dict_enc_stat;
+ dict_enc_stat.__set_page_type(format::PageType::DICTIONARY_PAGE);
+ dict_enc_stat.__set_encoding(ToThrift(entry.first));
+ dict_enc_stat.__set_count(entry.second);
+ thrift_encoding_stats.push_back(dict_enc_stat);
+ }
+ // Add data page encoding stats
+ for (const auto& entry : data_encoding_stats) {
+ format::PageEncodingStats data_enc_stat;
+ data_enc_stat.__set_page_type(format::PageType::DATA_PAGE);
+ data_enc_stat.__set_encoding(ToThrift(entry.first));
+ data_enc_stat.__set_count(entry.second);
+ thrift_encoding_stats.push_back(data_enc_stat);
+ }
+ column_chunk_->meta_data.__set_encoding_stats(thrift_encoding_stats);
+
+ const auto& encrypt_md =
+ properties_->column_encryption_properties(column_->path()->ToDotString());
+ // column is encrypted
+ if (encrypt_md != nullptr && encrypt_md->is_encrypted()) {
+ column_chunk_->__isset.crypto_metadata = true;
+ format::ColumnCryptoMetaData ccmd;
+ if (encrypt_md->is_encrypted_with_footer_key()) {
+ // encrypted with footer key
+ ccmd.__isset.ENCRYPTION_WITH_FOOTER_KEY = true;
+ ccmd.__set_ENCRYPTION_WITH_FOOTER_KEY(format::EncryptionWithFooterKey());
+ } else { // encrypted with column key
+ format::EncryptionWithColumnKey eck;
+ eck.__set_key_metadata(encrypt_md->key_metadata());
+ eck.__set_path_in_schema(column_->path()->ToDotVector());
+ ccmd.__isset.ENCRYPTION_WITH_COLUMN_KEY = true;
+ ccmd.__set_ENCRYPTION_WITH_COLUMN_KEY(eck);
+ }
+ column_chunk_->__set_crypto_metadata(ccmd);
+
+ bool encrypted_footer =
+ properties_->file_encryption_properties()->encrypted_footer();
+ bool encrypt_metadata =
+ !encrypted_footer || !encrypt_md->is_encrypted_with_footer_key();
+ if (encrypt_metadata) {
+ ThriftSerializer serializer;
+ // Serialize and encrypt ColumnMetadata separately
+ // Thrift-serialize the ColumnMetaData structure,
+ // encrypt it with the column key, and write to encrypted_column_metadata
+ uint8_t* serialized_data;
+ uint32_t serialized_len;
+
+ serializer.SerializeToBuffer(&column_chunk_->meta_data, &serialized_len,
+ &serialized_data);
+
+ std::vector<uint8_t> encrypted_data(encryptor->CiphertextSizeDelta() +
+ serialized_len);
+ unsigned encrypted_len =
+ encryptor->Encrypt(serialized_data, serialized_len, encrypted_data.data());
+
+ const char* temp =
+ const_cast<const char*>(reinterpret_cast<char*>(encrypted_data.data()));
+ std::string encrypted_column_metadata(temp, encrypted_len);
+ column_chunk_->__set_encrypted_column_metadata(encrypted_column_metadata);
+
+ if (encrypted_footer) {
+ column_chunk_->__isset.meta_data = false;
+ } else {
+ // Keep redacted metadata version for old readers
+ column_chunk_->__isset.meta_data = true;
+ column_chunk_->meta_data.__isset.statistics = false;
+ column_chunk_->meta_data.__isset.encoding_stats = false;
+ }
+ }
+ }
+ }
+
+ void WriteTo(::arrow::io::OutputStream* sink) {
+ ThriftSerializer serializer;
+ serializer.Serialize(column_chunk_, sink);
+ }
+
+ const ColumnDescriptor* descr() const { return column_; }
+ int64_t total_compressed_size() const {
+ return column_chunk_->meta_data.total_compressed_size;
+ }
+
+ private:
+ void Init(format::ColumnChunk* column_chunk) {
+ column_chunk_ = column_chunk;
+
+ column_chunk_->meta_data.__set_type(ToThrift(column_->physical_type()));
+ column_chunk_->meta_data.__set_path_in_schema(column_->path()->ToDotVector());
+ column_chunk_->meta_data.__set_codec(
+ ToThrift(properties_->compression(column_->path())));
+ }
+
+ format::ColumnChunk* column_chunk_;
+ std::unique_ptr<format::ColumnChunk> owned_column_chunk_;
+ const std::shared_ptr<WriterProperties> properties_;
+ const ColumnDescriptor* column_;
+};
+
+std::unique_ptr<ColumnChunkMetaDataBuilder> ColumnChunkMetaDataBuilder::Make(
+ std::shared_ptr<WriterProperties> props, const ColumnDescriptor* column,
+ void* contents) {
+ return std::unique_ptr<ColumnChunkMetaDataBuilder>(
+ new ColumnChunkMetaDataBuilder(std::move(props), column, contents));
+}
+
+std::unique_ptr<ColumnChunkMetaDataBuilder> ColumnChunkMetaDataBuilder::Make(
+ std::shared_ptr<WriterProperties> props, const ColumnDescriptor* column) {
+ return std::unique_ptr<ColumnChunkMetaDataBuilder>(
+ new ColumnChunkMetaDataBuilder(std::move(props), column));
+}
+
+ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilder(
+ std::shared_ptr<WriterProperties> props, const ColumnDescriptor* column)
+ : impl_{std::unique_ptr<ColumnChunkMetaDataBuilderImpl>(
+ new ColumnChunkMetaDataBuilderImpl(std::move(props), column))} {}
+
+ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilder(
+ std::shared_ptr<WriterProperties> props, const ColumnDescriptor* column,
+ void* contents)
+ : impl_{std::unique_ptr<ColumnChunkMetaDataBuilderImpl>(
+ new ColumnChunkMetaDataBuilderImpl(
+ std::move(props), column,
+ reinterpret_cast<format::ColumnChunk*>(contents)))} {}
+
+ColumnChunkMetaDataBuilder::~ColumnChunkMetaDataBuilder() = default;
+
+const void* ColumnChunkMetaDataBuilder::contents() const { return impl_->contents(); }
+
+void ColumnChunkMetaDataBuilder::set_file_path(const std::string& path) {
+ impl_->set_file_path(path);
+}
+
+void ColumnChunkMetaDataBuilder::Finish(
+ int64_t num_values, int64_t dictionary_page_offset, int64_t index_page_offset,
+ int64_t data_page_offset, int64_t compressed_size, int64_t uncompressed_size,
+ bool has_dictionary, bool dictionary_fallback,
+ const std::map<Encoding::type, int32_t>& dict_encoding_stats,
+ const std::map<Encoding::type, int32_t>& data_encoding_stats,
+ const std::shared_ptr<Encryptor>& encryptor) {
+ impl_->Finish(num_values, dictionary_page_offset, index_page_offset, data_page_offset,
+ compressed_size, uncompressed_size, has_dictionary, dictionary_fallback,
+ dict_encoding_stats, data_encoding_stats, encryptor);
+}
+
+void ColumnChunkMetaDataBuilder::WriteTo(::arrow::io::OutputStream* sink) {
+ impl_->WriteTo(sink);
+}
+
+const ColumnDescriptor* ColumnChunkMetaDataBuilder::descr() const {
+ return impl_->descr();
+}
+
+void ColumnChunkMetaDataBuilder::SetStatistics(const EncodedStatistics& result) {
+ impl_->SetStatistics(result);
+}
+
+int64_t ColumnChunkMetaDataBuilder::total_compressed_size() const {
+ return impl_->total_compressed_size();
+}
+
+class RowGroupMetaDataBuilder::RowGroupMetaDataBuilderImpl {
+ public:
+ explicit RowGroupMetaDataBuilderImpl(std::shared_ptr<WriterProperties> props,
+ const SchemaDescriptor* schema, void* contents)
+ : properties_(std::move(props)), schema_(schema), next_column_(0) {
+ row_group_ = reinterpret_cast<format::RowGroup*>(contents);
+ InitializeColumns(schema->num_columns());
+ }
+
+ ColumnChunkMetaDataBuilder* NextColumnChunk() {
+ if (!(next_column_ < num_columns())) {
+ std::stringstream ss;
+ ss << "The schema only has " << num_columns()
+ << " columns, requested metadata for column: " << next_column_;
+ throw ParquetException(ss.str());
+ }
+ auto column = schema_->Column(next_column_);
+ auto column_builder = ColumnChunkMetaDataBuilder::Make(
+ properties_, column, &row_group_->columns[next_column_++]);
+ auto column_builder_ptr = column_builder.get();
+ column_builders_.push_back(std::move(column_builder));
+ return column_builder_ptr;
+ }
+
+ int current_column() { return next_column_ - 1; }
+
+ void Finish(int64_t total_bytes_written, int16_t row_group_ordinal) {
+ if (!(next_column_ == schema_->num_columns())) {
+ std::stringstream ss;
+ ss << "Only " << next_column_ - 1 << " out of " << schema_->num_columns()
+ << " columns are initialized";
+ throw ParquetException(ss.str());
+ }
+
+ int64_t file_offset = 0;
+ int64_t total_compressed_size = 0;
+ for (int i = 0; i < schema_->num_columns(); i++) {
+ if (!(row_group_->columns[i].file_offset >= 0)) {
+ std::stringstream ss;
+ ss << "Column " << i << " is not complete.";
+ throw ParquetException(ss.str());
+ }
+ if (i == 0) {
+ const format::ColumnMetaData& first_col = row_group_->columns[0].meta_data;
+ // As per spec, file_offset for the row group points to the first
+ // dictionary or data page of the column.
+ if (first_col.__isset.dictionary_page_offset &&
+ first_col.dictionary_page_offset > 0) {
+ file_offset = first_col.dictionary_page_offset;
+ } else {
+ file_offset = first_col.data_page_offset;
+ }
+ }
+ // sometimes column metadata is encrypted and not available to read,
+ // so we must get total_compressed_size from column builder
+ total_compressed_size += column_builders_[i]->total_compressed_size();
+ }
+
+ row_group_->__set_file_offset(file_offset);
+ row_group_->__set_total_compressed_size(total_compressed_size);
+ row_group_->__set_total_byte_size(total_bytes_written);
+ row_group_->__set_ordinal(row_group_ordinal);
+ }
+
+ void set_num_rows(int64_t num_rows) { row_group_->num_rows = num_rows; }
+
+ int num_columns() { return static_cast<int>(row_group_->columns.size()); }
+
+ int64_t num_rows() { return row_group_->num_rows; }
+
+ private:
+ void InitializeColumns(int ncols) { row_group_->columns.resize(ncols); }
+
+ format::RowGroup* row_group_;
+ const std::shared_ptr<WriterProperties> properties_;
+ const SchemaDescriptor* schema_;
+ std::vector<std::unique_ptr<ColumnChunkMetaDataBuilder>> column_builders_;
+ int next_column_;
+};
+
+std::unique_ptr<RowGroupMetaDataBuilder> RowGroupMetaDataBuilder::Make(
+ std::shared_ptr<WriterProperties> props, const SchemaDescriptor* schema_,
+ void* contents) {
+ return std::unique_ptr<RowGroupMetaDataBuilder>(
+ new RowGroupMetaDataBuilder(std::move(props), schema_, contents));
+}
+
+RowGroupMetaDataBuilder::RowGroupMetaDataBuilder(std::shared_ptr<WriterProperties> props,
+ const SchemaDescriptor* schema_,
+ void* contents)
+ : impl_{new RowGroupMetaDataBuilderImpl(std::move(props), schema_, contents)} {}
+
+RowGroupMetaDataBuilder::~RowGroupMetaDataBuilder() = default;
+
+ColumnChunkMetaDataBuilder* RowGroupMetaDataBuilder::NextColumnChunk() {
+ return impl_->NextColumnChunk();
+}
+
+int RowGroupMetaDataBuilder::current_column() const { return impl_->current_column(); }
+
+int RowGroupMetaDataBuilder::num_columns() { return impl_->num_columns(); }
+
+int64_t RowGroupMetaDataBuilder::num_rows() { return impl_->num_rows(); }
+
+void RowGroupMetaDataBuilder::set_num_rows(int64_t num_rows) {
+ impl_->set_num_rows(num_rows);
+}
+
+void RowGroupMetaDataBuilder::Finish(int64_t total_bytes_written,
+ int16_t row_group_ordinal) {
+ impl_->Finish(total_bytes_written, row_group_ordinal);
+}
+
+// file metadata
+// TODO(PARQUET-595) Support key_value_metadata
+class FileMetaDataBuilder::FileMetaDataBuilderImpl {
+ public:
+ explicit FileMetaDataBuilderImpl(
+ const SchemaDescriptor* schema, std::shared_ptr<WriterProperties> props,
+ std::shared_ptr<const KeyValueMetadata> key_value_metadata)
+ : metadata_(new format::FileMetaData()),
+ properties_(std::move(props)),
+ schema_(schema),
+ key_value_metadata_(std::move(key_value_metadata)) {
+ if (properties_->file_encryption_properties() != nullptr &&
+ properties_->file_encryption_properties()->encrypted_footer()) {
+ crypto_metadata_.reset(new format::FileCryptoMetaData());
+ }
+ }
+
+ RowGroupMetaDataBuilder* AppendRowGroup() {
+ row_groups_.emplace_back();
+ current_row_group_builder_ =
+ RowGroupMetaDataBuilder::Make(properties_, schema_, &row_groups_.back());
+ return current_row_group_builder_.get();
+ }
+
+ std::unique_ptr<FileMetaData> Finish() {
+ int64_t total_rows = 0;
+ for (auto row_group : row_groups_) {
+ total_rows += row_group.num_rows;
+ }
+ metadata_->__set_num_rows(total_rows);
+ metadata_->__set_row_groups(row_groups_);
+
+ if (key_value_metadata_) {
+ metadata_->key_value_metadata.clear();
+ metadata_->key_value_metadata.reserve(key_value_metadata_->size());
+ for (int64_t i = 0; i < key_value_metadata_->size(); ++i) {
+ format::KeyValue kv_pair;
+ kv_pair.__set_key(key_value_metadata_->key(i));
+ kv_pair.__set_value(key_value_metadata_->value(i));
+ metadata_->key_value_metadata.push_back(kv_pair);
+ }
+ metadata_->__isset.key_value_metadata = true;
+ }
+
+ int32_t file_version = 0;
+ switch (properties_->version()) {
+ case ParquetVersion::PARQUET_1_0:
+ file_version = 1;
+ break;
+ default:
+ file_version = 2;
+ break;
+ }
+ metadata_->__set_version(file_version);
+ metadata_->__set_created_by(properties_->created_by());
+
+ // Users cannot set the `ColumnOrder` since we donot not have user defined sort order
+ // in the spec yet.
+ // We always default to `TYPE_DEFINED_ORDER`. We can expose it in
+ // the API once we have user defined sort orders in the Parquet format.
+ // TypeDefinedOrder implies choose SortOrder based on ConvertedType/PhysicalType
+ format::TypeDefinedOrder type_defined_order;
+ format::ColumnOrder column_order;
+ column_order.__set_TYPE_ORDER(type_defined_order);
+ column_order.__isset.TYPE_ORDER = true;
+ metadata_->column_orders.resize(schema_->num_columns(), column_order);
+ metadata_->__isset.column_orders = true;
+
+ // if plaintext footer, set footer signing algorithm
+ auto file_encryption_properties = properties_->file_encryption_properties();
+ if (file_encryption_properties && !file_encryption_properties->encrypted_footer()) {
+ EncryptionAlgorithm signing_algorithm;
+ EncryptionAlgorithm algo = file_encryption_properties->algorithm();
+ signing_algorithm.aad.aad_file_unique = algo.aad.aad_file_unique;
+ signing_algorithm.aad.supply_aad_prefix = algo.aad.supply_aad_prefix;
+ if (!algo.aad.supply_aad_prefix) {
+ signing_algorithm.aad.aad_prefix = algo.aad.aad_prefix;
+ }
+ signing_algorithm.algorithm = ParquetCipher::AES_GCM_V1;
+
+ metadata_->__set_encryption_algorithm(ToThrift(signing_algorithm));
+ const std::string& footer_signing_key_metadata =
+ file_encryption_properties->footer_key_metadata();
+ if (footer_signing_key_metadata.size() > 0) {
+ metadata_->__set_footer_signing_key_metadata(footer_signing_key_metadata);
+ }
+ }
+
+ ToParquet(static_cast<parquet::schema::GroupNode*>(schema_->schema_root().get()),
+ &metadata_->schema);
+ auto file_meta_data = std::unique_ptr<FileMetaData>(new FileMetaData());
+ file_meta_data->impl_->metadata_ = std::move(metadata_);
+ file_meta_data->impl_->InitSchema();
+ file_meta_data->impl_->InitKeyValueMetadata();
+ return file_meta_data;
+ }
+
+ std::unique_ptr<FileCryptoMetaData> BuildFileCryptoMetaData() {
+ if (crypto_metadata_ == nullptr) {
+ return nullptr;
+ }
+
+ auto file_encryption_properties = properties_->file_encryption_properties();
+
+ crypto_metadata_->__set_encryption_algorithm(
+ ToThrift(file_encryption_properties->algorithm()));
+ std::string key_metadata = file_encryption_properties->footer_key_metadata();
+
+ if (!key_metadata.empty()) {
+ crypto_metadata_->__set_key_metadata(key_metadata);
+ }
+
+ std::unique_ptr<FileCryptoMetaData> file_crypto_metadata =
+ std::unique_ptr<FileCryptoMetaData>(new FileCryptoMetaData());
+ file_crypto_metadata->impl_->metadata_ = std::move(crypto_metadata_);
+
+ return file_crypto_metadata;
+ }
+
+ protected:
+ std::unique_ptr<format::FileMetaData> metadata_;
+ std::unique_ptr<format::FileCryptoMetaData> crypto_metadata_;
+
+ private:
+ const std::shared_ptr<WriterProperties> properties_;
+ std::vector<format::RowGroup> row_groups_;
+
+ std::unique_ptr<RowGroupMetaDataBuilder> current_row_group_builder_;
+ const SchemaDescriptor* schema_;
+ std::shared_ptr<const KeyValueMetadata> key_value_metadata_;
+};
+
+std::unique_ptr<FileMetaDataBuilder> FileMetaDataBuilder::Make(
+ const SchemaDescriptor* schema, std::shared_ptr<WriterProperties> props,
+ std::shared_ptr<const KeyValueMetadata> key_value_metadata) {
+ return std::unique_ptr<FileMetaDataBuilder>(
+ new FileMetaDataBuilder(schema, std::move(props), std::move(key_value_metadata)));
+}
+
+FileMetaDataBuilder::FileMetaDataBuilder(
+ const SchemaDescriptor* schema, std::shared_ptr<WriterProperties> props,
+ std::shared_ptr<const KeyValueMetadata> key_value_metadata)
+ : impl_{std::unique_ptr<FileMetaDataBuilderImpl>(new FileMetaDataBuilderImpl(
+ schema, std::move(props), std::move(key_value_metadata)))} {}
+
+FileMetaDataBuilder::~FileMetaDataBuilder() = default;
+
+RowGroupMetaDataBuilder* FileMetaDataBuilder::AppendRowGroup() {
+ return impl_->AppendRowGroup();
+}
+
+std::unique_ptr<FileMetaData> FileMetaDataBuilder::Finish() { return impl_->Finish(); }
+
+std::unique_ptr<FileCryptoMetaData> FileMetaDataBuilder::GetCryptoMetaData() {
+ return impl_->BuildFileCryptoMetaData();
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/metadata.h b/src/arrow/cpp/src/parquet/metadata.h
new file mode 100644
index 000000000..3dd936d90
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/metadata.h
@@ -0,0 +1,489 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "parquet/platform.h"
+#include "parquet/properties.h"
+#include "parquet/schema.h"
+#include "parquet/types.h"
+
+namespace parquet {
+
+class ColumnDescriptor;
+class EncodedStatistics;
+class Statistics;
+class SchemaDescriptor;
+
+class FileCryptoMetaData;
+class InternalFileDecryptor;
+class Decryptor;
+class Encryptor;
+class FooterSigningEncryptor;
+
+namespace schema {
+
+class ColumnPath;
+
+} // namespace schema
+
+using KeyValueMetadata = ::arrow::KeyValueMetadata;
+
+class PARQUET_EXPORT ApplicationVersion {
+ public:
+ // Known Versions with Issues
+ static const ApplicationVersion& PARQUET_251_FIXED_VERSION();
+ static const ApplicationVersion& PARQUET_816_FIXED_VERSION();
+ static const ApplicationVersion& PARQUET_CPP_FIXED_STATS_VERSION();
+ static const ApplicationVersion& PARQUET_MR_FIXED_STATS_VERSION();
+
+ // Application that wrote the file. e.g. "IMPALA"
+ std::string application_;
+ // Build name
+ std::string build_;
+
+ // Version of the application that wrote the file, expressed as
+ // (<major>.<minor>.<patch>). Unmatched parts default to 0.
+ // "1.2.3" => {1, 2, 3}
+ // "1.2" => {1, 2, 0}
+ // "1.2-cdh5" => {1, 2, 0}
+ struct {
+ int major;
+ int minor;
+ int patch;
+ std::string unknown;
+ std::string pre_release;
+ std::string build_info;
+ } version;
+
+ ApplicationVersion() = default;
+ explicit ApplicationVersion(const std::string& created_by);
+ ApplicationVersion(std::string application, int major, int minor, int patch);
+
+ // Returns true if version is strictly less than other_version
+ bool VersionLt(const ApplicationVersion& other_version) const;
+
+ // Returns true if version is strictly equal with other_version
+ bool VersionEq(const ApplicationVersion& other_version) const;
+
+ // Checks if the Version has the correct statistics for a given column
+ bool HasCorrectStatistics(Type::type primitive, EncodedStatistics& statistics,
+ SortOrder::type sort_order = SortOrder::SIGNED) const;
+};
+
+class PARQUET_EXPORT ColumnCryptoMetaData {
+ public:
+ static std::unique_ptr<ColumnCryptoMetaData> Make(const uint8_t* metadata);
+ ~ColumnCryptoMetaData();
+
+ bool Equals(const ColumnCryptoMetaData& other) const;
+
+ std::shared_ptr<schema::ColumnPath> path_in_schema() const;
+ bool encrypted_with_footer_key() const;
+ const std::string& key_metadata() const;
+
+ private:
+ explicit ColumnCryptoMetaData(const uint8_t* metadata);
+
+ class ColumnCryptoMetaDataImpl;
+ std::unique_ptr<ColumnCryptoMetaDataImpl> impl_;
+};
+
+/// \brief Public struct for Thrift PageEncodingStats in ColumnChunkMetaData
+struct PageEncodingStats {
+ PageType::type page_type;
+ Encoding::type encoding;
+ int32_t count;
+};
+
+/// \brief ColumnChunkMetaData is a proxy around format::ColumnChunkMetaData.
+class PARQUET_EXPORT ColumnChunkMetaData {
+ public:
+ // API convenience to get a MetaData accessor
+ static std::unique_ptr<ColumnChunkMetaData> Make(
+ const void* metadata, const ColumnDescriptor* descr,
+ const ApplicationVersion* writer_version = NULLPTR, int16_t row_group_ordinal = -1,
+ int16_t column_ordinal = -1,
+ std::shared_ptr<InternalFileDecryptor> file_decryptor = NULLPTR);
+
+ ~ColumnChunkMetaData();
+
+ bool Equals(const ColumnChunkMetaData& other) const;
+
+ // column chunk
+ int64_t file_offset() const;
+
+ // parameter is only used when a dataset is spread across multiple files
+ const std::string& file_path() const;
+
+ // column metadata
+ bool is_metadata_set() const;
+ Type::type type() const;
+ int64_t num_values() const;
+ std::shared_ptr<schema::ColumnPath> path_in_schema() const;
+ bool is_stats_set() const;
+ std::shared_ptr<Statistics> statistics() const;
+
+ Compression::type compression() const;
+ // Indicate if the ColumnChunk compression is supported by the current
+ // compiled parquet library.
+ bool can_decompress() const;
+
+ const std::vector<Encoding::type>& encodings() const;
+ const std::vector<PageEncodingStats>& encoding_stats() const;
+ bool has_dictionary_page() const;
+ int64_t dictionary_page_offset() const;
+ int64_t data_page_offset() const;
+ bool has_index_page() const;
+ int64_t index_page_offset() const;
+ int64_t total_compressed_size() const;
+ int64_t total_uncompressed_size() const;
+ std::unique_ptr<ColumnCryptoMetaData> crypto_metadata() const;
+
+ private:
+ explicit ColumnChunkMetaData(
+ const void* metadata, const ColumnDescriptor* descr, int16_t row_group_ordinal,
+ int16_t column_ordinal, const ApplicationVersion* writer_version = NULLPTR,
+ std::shared_ptr<InternalFileDecryptor> file_decryptor = NULLPTR);
+ // PIMPL Idiom
+ class ColumnChunkMetaDataImpl;
+ std::unique_ptr<ColumnChunkMetaDataImpl> impl_;
+};
+
+/// \brief RowGroupMetaData is a proxy around format::RowGroupMetaData.
+class PARQUET_EXPORT RowGroupMetaData {
+ public:
+ /// \brief Create a RowGroupMetaData from a serialized thrift message.
+ static std::unique_ptr<RowGroupMetaData> Make(
+ const void* metadata, const SchemaDescriptor* schema,
+ const ApplicationVersion* writer_version = NULLPTR,
+ std::shared_ptr<InternalFileDecryptor> file_decryptor = NULLPTR);
+
+ ~RowGroupMetaData();
+
+ bool Equals(const RowGroupMetaData& other) const;
+
+ /// \brief The number of columns in this row group. The order must match the
+ /// parent's column ordering.
+ int num_columns() const;
+
+ /// \brief Return the ColumnChunkMetaData of the corresponding column ordinal.
+ ///
+ /// WARNING, the returned object references memory location in it's parent
+ /// (RowGroupMetaData) object. Hence, the parent must outlive the returned
+ /// object.
+ ///
+ /// \param[in] index of the ColumnChunkMetaData to retrieve.
+ ///
+ /// \throws ParquetException if the index is out of bound.
+ std::unique_ptr<ColumnChunkMetaData> ColumnChunk(int index) const;
+
+ /// \brief Number of rows in this row group.
+ int64_t num_rows() const;
+
+ /// \brief Total byte size of all the uncompressed column data in this row group.
+ int64_t total_byte_size() const;
+
+ /// \brief Total byte size of all the compressed (and potentially encrypted)
+ /// column data in this row group.
+ ///
+ /// This information is optional and may be 0 if omitted.
+ int64_t total_compressed_size() const;
+
+ /// \brief Byte offset from beginning of file to first page (data or
+ /// dictionary) in this row group
+ ///
+ /// The file_offset field that this method exposes is optional. This method
+ /// will return 0 if that field is not set to a meaningful value.
+ int64_t file_offset() const;
+ // Return const-pointer to make it clear that this object is not to be copied
+ const SchemaDescriptor* schema() const;
+ // Indicate if all of the RowGroup's ColumnChunks can be decompressed.
+ bool can_decompress() const;
+
+ private:
+ explicit RowGroupMetaData(
+ const void* metadata, const SchemaDescriptor* schema,
+ const ApplicationVersion* writer_version = NULLPTR,
+ std::shared_ptr<InternalFileDecryptor> file_decryptor = NULLPTR);
+ // PIMPL Idiom
+ class RowGroupMetaDataImpl;
+ std::unique_ptr<RowGroupMetaDataImpl> impl_;
+};
+
+class FileMetaDataBuilder;
+
+/// \brief FileMetaData is a proxy around format::FileMetaData.
+class PARQUET_EXPORT FileMetaData {
+ public:
+ /// \brief Create a FileMetaData from a serialized thrift message.
+ static std::shared_ptr<FileMetaData> Make(
+ const void* serialized_metadata, uint32_t* inout_metadata_len,
+ std::shared_ptr<InternalFileDecryptor> file_decryptor = NULLPTR);
+
+ ~FileMetaData();
+
+ bool Equals(const FileMetaData& other) const;
+
+ /// \brief The number of top-level columns in the schema.
+ ///
+ /// Parquet thrift definition requires that nested schema elements are
+ /// flattened. This method returns the number of columns in the un-flattened
+ /// version.
+ int num_columns() const;
+
+ /// \brief The number of flattened schema elements.
+ ///
+ /// Parquet thrift definition requires that nested schema elements are
+ /// flattened. This method returns the total number of elements in the
+ /// flattened list.
+ int num_schema_elements() const;
+
+ /// \brief The total number of rows.
+ int64_t num_rows() const;
+
+ /// \brief The number of row groups in the file.
+ int num_row_groups() const;
+
+ /// \brief Return the RowGroupMetaData of the corresponding row group ordinal.
+ ///
+ /// WARNING, the returned object references memory location in it's parent
+ /// (FileMetaData) object. Hence, the parent must outlive the returned object.
+ ///
+ /// \param[in] index of the RowGroup to retrieve.
+ ///
+ /// \throws ParquetException if the index is out of bound.
+ std::unique_ptr<RowGroupMetaData> RowGroup(int index) const;
+
+ /// \brief Return the "version" of the file
+ ///
+ /// WARNING: The value returned by this method is unreliable as 1) the Parquet
+ /// file metadata stores the version as a single integer and 2) some producers
+ /// are known to always write a hardcoded value. Therefore, you cannot use
+ /// this value to know which features are used in the file.
+ ParquetVersion::type version() const;
+
+ /// \brief Return the application's user-agent string of the writer.
+ const std::string& created_by() const;
+
+ /// \brief Return the application's version of the writer.
+ const ApplicationVersion& writer_version() const;
+
+ /// \brief Size of the original thrift encoded metadata footer.
+ uint32_t size() const;
+
+ /// \brief Indicate if all of the FileMetadata's RowGroups can be decompressed.
+ ///
+ /// This will return false if any of the RowGroup's page is compressed with a
+ /// compression format which is not compiled in the current parquet library.
+ bool can_decompress() const;
+
+ bool is_encryption_algorithm_set() const;
+ EncryptionAlgorithm encryption_algorithm() const;
+ const std::string& footer_signing_key_metadata() const;
+
+ /// \brief Verify signature of FileMetaData when file is encrypted but footer
+ /// is not encrypted (plaintext footer).
+ bool VerifySignature(const void* signature);
+
+ void WriteTo(::arrow::io::OutputStream* dst,
+ const std::shared_ptr<Encryptor>& encryptor = NULLPTR) const;
+
+ /// \brief Return Thrift-serialized representation of the metadata as a
+ /// string
+ std::string SerializeToString() const;
+
+ // Return const-pointer to make it clear that this object is not to be copied
+ const SchemaDescriptor* schema() const;
+
+ const std::shared_ptr<const KeyValueMetadata>& key_value_metadata() const;
+
+ /// \brief Set a path to all ColumnChunk for all RowGroups.
+ ///
+ /// Commonly used by systems (Dask, Spark) who generates an metadata-only
+ /// parquet file. The path is usually relative to said index file.
+ ///
+ /// \param[in] path to set.
+ void set_file_path(const std::string& path);
+
+ /// \brief Merge row groups from another metadata file into this one.
+ ///
+ /// The schema of the input FileMetaData must be equal to the
+ /// schema of this object.
+ ///
+ /// This is used by systems who creates an aggregate metadata-only file by
+ /// concatenating the row groups of multiple files. This newly created
+ /// metadata file acts as an index of all available row groups.
+ ///
+ /// \param[in] other FileMetaData to merge the row groups from.
+ ///
+ /// \throws ParquetException if schemas are not equal.
+ void AppendRowGroups(const FileMetaData& other);
+
+ /// \brief Return a FileMetaData containing a subset of the row groups in this
+ /// FileMetaData.
+ std::shared_ptr<FileMetaData> Subset(const std::vector<int>& row_groups) const;
+
+ private:
+ friend FileMetaDataBuilder;
+ friend class SerializedFile;
+
+ explicit FileMetaData(const void* serialized_metadata, uint32_t* metadata_len,
+ std::shared_ptr<InternalFileDecryptor> file_decryptor = NULLPTR);
+
+ void set_file_decryptor(std::shared_ptr<InternalFileDecryptor> file_decryptor);
+
+ // PIMPL Idiom
+ FileMetaData();
+ class FileMetaDataImpl;
+ std::unique_ptr<FileMetaDataImpl> impl_;
+};
+
+class PARQUET_EXPORT FileCryptoMetaData {
+ public:
+ // API convenience to get a MetaData accessor
+ static std::shared_ptr<FileCryptoMetaData> Make(const uint8_t* serialized_metadata,
+ uint32_t* metadata_len);
+ ~FileCryptoMetaData();
+
+ EncryptionAlgorithm encryption_algorithm() const;
+ const std::string& key_metadata() const;
+
+ void WriteTo(::arrow::io::OutputStream* dst) const;
+
+ private:
+ friend FileMetaDataBuilder;
+ FileCryptoMetaData(const uint8_t* serialized_metadata, uint32_t* metadata_len);
+
+ // PIMPL Idiom
+ FileCryptoMetaData();
+ class FileCryptoMetaDataImpl;
+ std::unique_ptr<FileCryptoMetaDataImpl> impl_;
+};
+
+// Builder API
+class PARQUET_EXPORT ColumnChunkMetaDataBuilder {
+ public:
+ // API convenience to get a MetaData reader
+ static std::unique_ptr<ColumnChunkMetaDataBuilder> Make(
+ std::shared_ptr<WriterProperties> props, const ColumnDescriptor* column);
+
+ static std::unique_ptr<ColumnChunkMetaDataBuilder> Make(
+ std::shared_ptr<WriterProperties> props, const ColumnDescriptor* column,
+ void* contents);
+
+ ~ColumnChunkMetaDataBuilder();
+
+ // column chunk
+ // Used when a dataset is spread across multiple files
+ void set_file_path(const std::string& path);
+ // column metadata
+ void SetStatistics(const EncodedStatistics& stats);
+ // get the column descriptor
+ const ColumnDescriptor* descr() const;
+
+ int64_t total_compressed_size() const;
+ // commit the metadata
+
+ void Finish(int64_t num_values, int64_t dictionary_page_offset,
+ int64_t index_page_offset, int64_t data_page_offset,
+ int64_t compressed_size, int64_t uncompressed_size, bool has_dictionary,
+ bool dictionary_fallback,
+ const std::map<Encoding::type, int32_t>& dict_encoding_stats_,
+ const std::map<Encoding::type, int32_t>& data_encoding_stats_,
+ const std::shared_ptr<Encryptor>& encryptor = NULLPTR);
+
+ // The metadata contents, suitable for passing to ColumnChunkMetaData::Make
+ const void* contents() const;
+
+ // For writing metadata at end of column chunk
+ void WriteTo(::arrow::io::OutputStream* sink);
+
+ private:
+ explicit ColumnChunkMetaDataBuilder(std::shared_ptr<WriterProperties> props,
+ const ColumnDescriptor* column);
+ explicit ColumnChunkMetaDataBuilder(std::shared_ptr<WriterProperties> props,
+ const ColumnDescriptor* column, void* contents);
+ // PIMPL Idiom
+ class ColumnChunkMetaDataBuilderImpl;
+ std::unique_ptr<ColumnChunkMetaDataBuilderImpl> impl_;
+};
+
+class PARQUET_EXPORT RowGroupMetaDataBuilder {
+ public:
+ // API convenience to get a MetaData reader
+ static std::unique_ptr<RowGroupMetaDataBuilder> Make(
+ std::shared_ptr<WriterProperties> props, const SchemaDescriptor* schema_,
+ void* contents);
+
+ ~RowGroupMetaDataBuilder();
+
+ ColumnChunkMetaDataBuilder* NextColumnChunk();
+ int num_columns();
+ int64_t num_rows();
+ int current_column() const;
+
+ void set_num_rows(int64_t num_rows);
+
+ // commit the metadata
+ void Finish(int64_t total_bytes_written, int16_t row_group_ordinal = -1);
+
+ private:
+ explicit RowGroupMetaDataBuilder(std::shared_ptr<WriterProperties> props,
+ const SchemaDescriptor* schema_, void* contents);
+ // PIMPL Idiom
+ class RowGroupMetaDataBuilderImpl;
+ std::unique_ptr<RowGroupMetaDataBuilderImpl> impl_;
+};
+
+class PARQUET_EXPORT FileMetaDataBuilder {
+ public:
+ // API convenience to get a MetaData reader
+ static std::unique_ptr<FileMetaDataBuilder> Make(
+ const SchemaDescriptor* schema, std::shared_ptr<WriterProperties> props,
+ std::shared_ptr<const KeyValueMetadata> key_value_metadata = NULLPTR);
+
+ ~FileMetaDataBuilder();
+
+ // The prior RowGroupMetaDataBuilder (if any) is destroyed
+ RowGroupMetaDataBuilder* AppendRowGroup();
+
+ // Complete the Thrift structure
+ std::unique_ptr<FileMetaData> Finish();
+
+ // crypto metadata
+ std::unique_ptr<FileCryptoMetaData> GetCryptoMetaData();
+
+ private:
+ explicit FileMetaDataBuilder(
+ const SchemaDescriptor* schema, std::shared_ptr<WriterProperties> props,
+ std::shared_ptr<const KeyValueMetadata> key_value_metadata = NULLPTR);
+ // PIMPL Idiom
+ class FileMetaDataBuilderImpl;
+ std::unique_ptr<FileMetaDataBuilderImpl> impl_;
+};
+
+PARQUET_EXPORT std::string ParquetVersionToString(ParquetVersion::type ver);
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/metadata_test.cc b/src/arrow/cpp/src/parquet/metadata_test.cc
new file mode 100644
index 000000000..a89d3d97f
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/metadata_test.cc
@@ -0,0 +1,571 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/metadata.h"
+
+#include <gtest/gtest.h>
+
+#include "arrow/util/key_value_metadata.h"
+#include "parquet/schema.h"
+#include "parquet/statistics.h"
+#include "parquet/thrift_internal.h"
+#include "parquet/types.h"
+
+namespace parquet {
+
+namespace metadata {
+
+// Helper function for generating table metadata
+std::unique_ptr<parquet::FileMetaData> GenerateTableMetaData(
+ const parquet::SchemaDescriptor& schema,
+ const std::shared_ptr<WriterProperties>& props, const int64_t& nrows,
+ EncodedStatistics stats_int, EncodedStatistics stats_float) {
+ auto f_builder = FileMetaDataBuilder::Make(&schema, props);
+ auto rg1_builder = f_builder->AppendRowGroup();
+ // Write the metadata
+ // rowgroup1 metadata
+ auto col1_builder = rg1_builder->NextColumnChunk();
+ auto col2_builder = rg1_builder->NextColumnChunk();
+ // column metadata
+ std::map<Encoding::type, int32_t> dict_encoding_stats({{Encoding::RLE_DICTIONARY, 1}});
+ std::map<Encoding::type, int32_t> data_encoding_stats(
+ {{Encoding::PLAIN, 1}, {Encoding::RLE, 1}});
+ stats_int.set_is_signed(true);
+ col1_builder->SetStatistics(stats_int);
+ stats_float.set_is_signed(true);
+ col2_builder->SetStatistics(stats_float);
+ col1_builder->Finish(nrows / 2, 4, 0, 10, 512, 600, true, false, dict_encoding_stats,
+ data_encoding_stats);
+ col2_builder->Finish(nrows / 2, 24, 0, 30, 512, 600, true, false, dict_encoding_stats,
+ data_encoding_stats);
+
+ rg1_builder->set_num_rows(nrows / 2);
+ rg1_builder->Finish(1024);
+
+ // rowgroup2 metadata
+ auto rg2_builder = f_builder->AppendRowGroup();
+ col1_builder = rg2_builder->NextColumnChunk();
+ col2_builder = rg2_builder->NextColumnChunk();
+ // column metadata
+ col1_builder->SetStatistics(stats_int);
+ col2_builder->SetStatistics(stats_float);
+ dict_encoding_stats.clear();
+ col1_builder->Finish(nrows / 2, /*dictionary_page_offset=*/0, 0, 10, 512, 600,
+ /*has_dictionary=*/false, false, dict_encoding_stats,
+ data_encoding_stats);
+ col2_builder->Finish(nrows / 2, 16, 0, 26, 512, 600, true, false, dict_encoding_stats,
+ data_encoding_stats);
+
+ rg2_builder->set_num_rows(nrows / 2);
+ rg2_builder->Finish(1024);
+
+ // Return the metadata accessor
+ return f_builder->Finish();
+}
+
+TEST(Metadata, TestBuildAccess) {
+ parquet::schema::NodeVector fields;
+ parquet::schema::NodePtr root;
+ parquet::SchemaDescriptor schema;
+
+ WriterProperties::Builder prop_builder;
+
+ std::shared_ptr<WriterProperties> props =
+ prop_builder.version(ParquetVersion::PARQUET_2_6)->build();
+
+ fields.push_back(parquet::schema::Int32("int_col", Repetition::REQUIRED));
+ fields.push_back(parquet::schema::Float("float_col", Repetition::REQUIRED));
+ root = parquet::schema::GroupNode::Make("schema", Repetition::REPEATED, fields);
+ schema.Init(root);
+
+ int64_t nrows = 1000;
+ int32_t int_min = 100, int_max = 200;
+ EncodedStatistics stats_int;
+ stats_int.set_null_count(0)
+ .set_distinct_count(nrows)
+ .set_min(std::string(reinterpret_cast<const char*>(&int_min), 4))
+ .set_max(std::string(reinterpret_cast<const char*>(&int_max), 4));
+ EncodedStatistics stats_float;
+ float float_min = 100.100f, float_max = 200.200f;
+ stats_float.set_null_count(0)
+ .set_distinct_count(nrows)
+ .set_min(std::string(reinterpret_cast<const char*>(&float_min), 4))
+ .set_max(std::string(reinterpret_cast<const char*>(&float_max), 4));
+
+ // Generate the metadata
+ auto f_accessor = GenerateTableMetaData(schema, props, nrows, stats_int, stats_float);
+
+ std::string f_accessor_serialized_metadata = f_accessor->SerializeToString();
+ uint32_t expected_len = static_cast<uint32_t>(f_accessor_serialized_metadata.length());
+
+ // decoded_len is an in-out parameter
+ uint32_t decoded_len = expected_len;
+ auto f_accessor_copy =
+ FileMetaData::Make(f_accessor_serialized_metadata.data(), &decoded_len);
+
+ // Check that all of the serialized data is consumed
+ ASSERT_EQ(expected_len, decoded_len);
+
+ // Run this block twice, one for f_accessor, one for f_accessor_copy.
+ // To make sure SerializedMetadata was deserialized correctly.
+ std::vector<FileMetaData*> f_accessors = {f_accessor.get(), f_accessor_copy.get()};
+ for (int loop_index = 0; loop_index < 2; loop_index++) {
+ // file metadata
+ ASSERT_EQ(nrows, f_accessors[loop_index]->num_rows());
+ ASSERT_LE(0, static_cast<int>(f_accessors[loop_index]->size()));
+ ASSERT_EQ(2, f_accessors[loop_index]->num_row_groups());
+ ASSERT_EQ(ParquetVersion::PARQUET_2_6, f_accessors[loop_index]->version());
+ ASSERT_EQ(DEFAULT_CREATED_BY, f_accessors[loop_index]->created_by());
+ ASSERT_EQ(3, f_accessors[loop_index]->num_schema_elements());
+
+ // row group1 metadata
+ auto rg1_accessor = f_accessors[loop_index]->RowGroup(0);
+ ASSERT_EQ(2, rg1_accessor->num_columns());
+ ASSERT_EQ(nrows / 2, rg1_accessor->num_rows());
+ ASSERT_EQ(1024, rg1_accessor->total_byte_size());
+ ASSERT_EQ(1024, rg1_accessor->total_compressed_size());
+ EXPECT_EQ(rg1_accessor->file_offset(),
+ rg1_accessor->ColumnChunk(0)->dictionary_page_offset());
+
+ auto rg1_column1 = rg1_accessor->ColumnChunk(0);
+ auto rg1_column2 = rg1_accessor->ColumnChunk(1);
+ ASSERT_EQ(true, rg1_column1->is_stats_set());
+ ASSERT_EQ(true, rg1_column2->is_stats_set());
+ ASSERT_EQ(stats_float.min(), rg1_column2->statistics()->EncodeMin());
+ ASSERT_EQ(stats_float.max(), rg1_column2->statistics()->EncodeMax());
+ ASSERT_EQ(stats_int.min(), rg1_column1->statistics()->EncodeMin());
+ ASSERT_EQ(stats_int.max(), rg1_column1->statistics()->EncodeMax());
+ ASSERT_EQ(0, rg1_column1->statistics()->null_count());
+ ASSERT_EQ(0, rg1_column2->statistics()->null_count());
+ ASSERT_EQ(nrows, rg1_column1->statistics()->distinct_count());
+ ASSERT_EQ(nrows, rg1_column2->statistics()->distinct_count());
+ ASSERT_EQ(DEFAULT_COMPRESSION_TYPE, rg1_column1->compression());
+ ASSERT_EQ(DEFAULT_COMPRESSION_TYPE, rg1_column2->compression());
+ ASSERT_EQ(nrows / 2, rg1_column1->num_values());
+ ASSERT_EQ(nrows / 2, rg1_column2->num_values());
+ ASSERT_EQ(3, rg1_column1->encodings().size());
+ ASSERT_EQ(3, rg1_column2->encodings().size());
+ ASSERT_EQ(512, rg1_column1->total_compressed_size());
+ ASSERT_EQ(512, rg1_column2->total_compressed_size());
+ ASSERT_EQ(600, rg1_column1->total_uncompressed_size());
+ ASSERT_EQ(600, rg1_column2->total_uncompressed_size());
+ ASSERT_EQ(4, rg1_column1->dictionary_page_offset());
+ ASSERT_EQ(24, rg1_column2->dictionary_page_offset());
+ ASSERT_EQ(10, rg1_column1->data_page_offset());
+ ASSERT_EQ(30, rg1_column2->data_page_offset());
+ ASSERT_EQ(3, rg1_column1->encoding_stats().size());
+ ASSERT_EQ(3, rg1_column2->encoding_stats().size());
+
+ auto rg2_accessor = f_accessors[loop_index]->RowGroup(1);
+ ASSERT_EQ(2, rg2_accessor->num_columns());
+ ASSERT_EQ(nrows / 2, rg2_accessor->num_rows());
+ ASSERT_EQ(1024, rg2_accessor->total_byte_size());
+ ASSERT_EQ(1024, rg2_accessor->total_compressed_size());
+ EXPECT_EQ(rg2_accessor->file_offset(),
+ rg2_accessor->ColumnChunk(0)->data_page_offset());
+
+ auto rg2_column1 = rg2_accessor->ColumnChunk(0);
+ auto rg2_column2 = rg2_accessor->ColumnChunk(1);
+ ASSERT_EQ(true, rg2_column1->is_stats_set());
+ ASSERT_EQ(true, rg2_column2->is_stats_set());
+ ASSERT_EQ(stats_float.min(), rg2_column2->statistics()->EncodeMin());
+ ASSERT_EQ(stats_float.max(), rg2_column2->statistics()->EncodeMax());
+ ASSERT_EQ(stats_int.min(), rg1_column1->statistics()->EncodeMin());
+ ASSERT_EQ(stats_int.max(), rg1_column1->statistics()->EncodeMax());
+ ASSERT_EQ(0, rg2_column1->statistics()->null_count());
+ ASSERT_EQ(0, rg2_column2->statistics()->null_count());
+ ASSERT_EQ(nrows, rg2_column1->statistics()->distinct_count());
+ ASSERT_EQ(nrows, rg2_column2->statistics()->distinct_count());
+ ASSERT_EQ(nrows / 2, rg2_column1->num_values());
+ ASSERT_EQ(nrows / 2, rg2_column2->num_values());
+ ASSERT_EQ(DEFAULT_COMPRESSION_TYPE, rg2_column1->compression());
+ ASSERT_EQ(DEFAULT_COMPRESSION_TYPE, rg2_column2->compression());
+ ASSERT_EQ(2, rg2_column1->encodings().size());
+ ASSERT_EQ(3, rg2_column2->encodings().size());
+ ASSERT_EQ(512, rg2_column1->total_compressed_size());
+ ASSERT_EQ(512, rg2_column2->total_compressed_size());
+ ASSERT_EQ(600, rg2_column1->total_uncompressed_size());
+ ASSERT_EQ(600, rg2_column2->total_uncompressed_size());
+ EXPECT_FALSE(rg2_column1->has_dictionary_page());
+ ASSERT_EQ(0, rg2_column1->dictionary_page_offset());
+ ASSERT_EQ(16, rg2_column2->dictionary_page_offset());
+ ASSERT_EQ(10, rg2_column1->data_page_offset());
+ ASSERT_EQ(26, rg2_column2->data_page_offset());
+ ASSERT_EQ(2, rg2_column1->encoding_stats().size());
+ ASSERT_EQ(2, rg2_column2->encoding_stats().size());
+
+ // Test FileMetaData::set_file_path
+ ASSERT_TRUE(rg2_column1->file_path().empty());
+ f_accessors[loop_index]->set_file_path("/foo/bar/bar.parquet");
+ ASSERT_EQ("/foo/bar/bar.parquet", rg2_column1->file_path());
+ }
+
+ // Test AppendRowGroups
+ auto f_accessor_2 = GenerateTableMetaData(schema, props, nrows, stats_int, stats_float);
+ f_accessor->AppendRowGroups(*f_accessor_2);
+ ASSERT_EQ(4, f_accessor->num_row_groups());
+ ASSERT_EQ(nrows * 2, f_accessor->num_rows());
+ ASSERT_LE(0, static_cast<int>(f_accessor->size()));
+ ASSERT_EQ(ParquetVersion::PARQUET_2_6, f_accessor->version());
+ ASSERT_EQ(DEFAULT_CREATED_BY, f_accessor->created_by());
+ ASSERT_EQ(3, f_accessor->num_schema_elements());
+
+ // Test AppendRowGroups from self (ARROW-13654)
+ f_accessor->AppendRowGroups(*f_accessor);
+ ASSERT_EQ(8, f_accessor->num_row_groups());
+ ASSERT_EQ(nrows * 4, f_accessor->num_rows());
+ ASSERT_EQ(3, f_accessor->num_schema_elements());
+
+ // Test Subset
+ auto f_accessor_1 = f_accessor->Subset({2, 3});
+ ASSERT_TRUE(f_accessor_1->Equals(*f_accessor_2));
+
+ f_accessor_1 = f_accessor_2->Subset({0});
+ f_accessor_1->AppendRowGroups(*f_accessor->Subset({0}));
+ ASSERT_TRUE(f_accessor_1->Equals(*f_accessor->Subset({2, 0})));
+}
+
+TEST(Metadata, TestV1Version) {
+ // PARQUET-839
+ parquet::schema::NodeVector fields;
+ parquet::schema::NodePtr root;
+ parquet::SchemaDescriptor schema;
+
+ WriterProperties::Builder prop_builder;
+
+ std::shared_ptr<WriterProperties> props =
+ prop_builder.version(ParquetVersion::PARQUET_1_0)->build();
+
+ fields.push_back(parquet::schema::Int32("int_col", Repetition::REQUIRED));
+ fields.push_back(parquet::schema::Float("float_col", Repetition::REQUIRED));
+ root = parquet::schema::GroupNode::Make("schema", Repetition::REPEATED, fields);
+ schema.Init(root);
+
+ auto f_builder = FileMetaDataBuilder::Make(&schema, props);
+
+ // Read the metadata
+ auto f_accessor = f_builder->Finish();
+
+ // file metadata
+ ASSERT_EQ(ParquetVersion::PARQUET_1_0, f_accessor->version());
+}
+
+TEST(Metadata, TestKeyValueMetadata) {
+ parquet::schema::NodeVector fields;
+ parquet::schema::NodePtr root;
+ parquet::SchemaDescriptor schema;
+
+ WriterProperties::Builder prop_builder;
+
+ std::shared_ptr<WriterProperties> props =
+ prop_builder.version(ParquetVersion::PARQUET_1_0)->build();
+
+ fields.push_back(parquet::schema::Int32("int_col", Repetition::REQUIRED));
+ fields.push_back(parquet::schema::Float("float_col", Repetition::REQUIRED));
+ root = parquet::schema::GroupNode::Make("schema", Repetition::REPEATED, fields);
+ schema.Init(root);
+
+ auto kvmeta = std::make_shared<KeyValueMetadata>();
+ kvmeta->Append("test_key", "test_value");
+
+ auto f_builder = FileMetaDataBuilder::Make(&schema, props, kvmeta);
+
+ // Read the metadata
+ auto f_accessor = f_builder->Finish();
+
+ // Key value metadata
+ ASSERT_TRUE(f_accessor->key_value_metadata());
+ EXPECT_TRUE(f_accessor->key_value_metadata()->Equals(*kvmeta));
+}
+
+TEST(ApplicationVersion, Basics) {
+ ApplicationVersion version("parquet-mr version 1.7.9");
+ ApplicationVersion version1("parquet-mr version 1.8.0");
+ ApplicationVersion version2("parquet-cpp version 1.0.0");
+ ApplicationVersion version3("");
+ ApplicationVersion version4("parquet-mr version 1.5.0ab-cdh5.5.0+cd (build abcd)");
+ ApplicationVersion version5("parquet-mr");
+
+ ASSERT_EQ("parquet-mr", version.application_);
+ ASSERT_EQ(1, version.version.major);
+ ASSERT_EQ(7, version.version.minor);
+ ASSERT_EQ(9, version.version.patch);
+
+ ASSERT_EQ("parquet-cpp", version2.application_);
+ ASSERT_EQ(1, version2.version.major);
+ ASSERT_EQ(0, version2.version.minor);
+ ASSERT_EQ(0, version2.version.patch);
+
+ ASSERT_EQ("parquet-mr", version4.application_);
+ ASSERT_EQ("abcd", version4.build_);
+ ASSERT_EQ(1, version4.version.major);
+ ASSERT_EQ(5, version4.version.minor);
+ ASSERT_EQ(0, version4.version.patch);
+ ASSERT_EQ("ab", version4.version.unknown);
+ ASSERT_EQ("cdh5.5.0", version4.version.pre_release);
+ ASSERT_EQ("cd", version4.version.build_info);
+
+ ASSERT_EQ("parquet-mr", version5.application_);
+ ASSERT_EQ(0, version5.version.major);
+ ASSERT_EQ(0, version5.version.minor);
+ ASSERT_EQ(0, version5.version.patch);
+
+ ASSERT_EQ(true, version.VersionLt(version1));
+
+ EncodedStatistics stats;
+ ASSERT_FALSE(version1.HasCorrectStatistics(Type::INT96, stats, SortOrder::UNKNOWN));
+ ASSERT_TRUE(version.HasCorrectStatistics(Type::INT32, stats, SortOrder::SIGNED));
+ ASSERT_FALSE(version.HasCorrectStatistics(Type::BYTE_ARRAY, stats, SortOrder::SIGNED));
+ ASSERT_TRUE(version1.HasCorrectStatistics(Type::BYTE_ARRAY, stats, SortOrder::SIGNED));
+ ASSERT_FALSE(
+ version1.HasCorrectStatistics(Type::BYTE_ARRAY, stats, SortOrder::UNSIGNED));
+ ASSERT_TRUE(version3.HasCorrectStatistics(Type::FIXED_LEN_BYTE_ARRAY, stats,
+ SortOrder::SIGNED));
+
+ // Check that the old stats are correct if min and max are the same
+ // regardless of sort order
+ EncodedStatistics stats_str;
+ stats_str.set_min("a").set_max("b");
+ ASSERT_FALSE(
+ version1.HasCorrectStatistics(Type::BYTE_ARRAY, stats_str, SortOrder::UNSIGNED));
+ stats_str.set_max("a");
+ ASSERT_TRUE(
+ version1.HasCorrectStatistics(Type::BYTE_ARRAY, stats_str, SortOrder::UNSIGNED));
+
+ // Check that the same holds true for ints
+ int32_t int_min = 100, int_max = 200;
+ EncodedStatistics stats_int;
+ stats_int.set_min(std::string(reinterpret_cast<const char*>(&int_min), 4))
+ .set_max(std::string(reinterpret_cast<const char*>(&int_max), 4));
+ ASSERT_FALSE(
+ version1.HasCorrectStatistics(Type::BYTE_ARRAY, stats_int, SortOrder::UNSIGNED));
+ stats_int.set_max(std::string(reinterpret_cast<const char*>(&int_min), 4));
+ ASSERT_TRUE(
+ version1.HasCorrectStatistics(Type::BYTE_ARRAY, stats_int, SortOrder::UNSIGNED));
+}
+
+TEST(ApplicationVersion, Empty) {
+ ApplicationVersion version("");
+
+ ASSERT_EQ("", version.application_);
+ ASSERT_EQ("", version.build_);
+ ASSERT_EQ(0, version.version.major);
+ ASSERT_EQ(0, version.version.minor);
+ ASSERT_EQ(0, version.version.patch);
+ ASSERT_EQ("", version.version.unknown);
+ ASSERT_EQ("", version.version.pre_release);
+ ASSERT_EQ("", version.version.build_info);
+}
+
+TEST(ApplicationVersion, NoVersion) {
+ ApplicationVersion version("parquet-mr (build abcd)");
+
+ ASSERT_EQ("parquet-mr (build abcd)", version.application_);
+ ASSERT_EQ("", version.build_);
+ ASSERT_EQ(0, version.version.major);
+ ASSERT_EQ(0, version.version.minor);
+ ASSERT_EQ(0, version.version.patch);
+ ASSERT_EQ("", version.version.unknown);
+ ASSERT_EQ("", version.version.pre_release);
+ ASSERT_EQ("", version.version.build_info);
+}
+
+TEST(ApplicationVersion, VersionEmpty) {
+ ApplicationVersion version("parquet-mr version ");
+
+ ASSERT_EQ("parquet-mr", version.application_);
+ ASSERT_EQ("", version.build_);
+ ASSERT_EQ(0, version.version.major);
+ ASSERT_EQ(0, version.version.minor);
+ ASSERT_EQ(0, version.version.patch);
+ ASSERT_EQ("", version.version.unknown);
+ ASSERT_EQ("", version.version.pre_release);
+ ASSERT_EQ("", version.version.build_info);
+}
+
+TEST(ApplicationVersion, VersionNoMajor) {
+ ApplicationVersion version("parquet-mr version .");
+
+ ASSERT_EQ("parquet-mr", version.application_);
+ ASSERT_EQ("", version.build_);
+ ASSERT_EQ(0, version.version.major);
+ ASSERT_EQ(0, version.version.minor);
+ ASSERT_EQ(0, version.version.patch);
+ ASSERT_EQ("", version.version.unknown);
+ ASSERT_EQ("", version.version.pre_release);
+ ASSERT_EQ("", version.version.build_info);
+}
+
+TEST(ApplicationVersion, VersionInvalidMajor) {
+ ApplicationVersion version("parquet-mr version x1");
+
+ ASSERT_EQ("parquet-mr", version.application_);
+ ASSERT_EQ("", version.build_);
+ ASSERT_EQ(0, version.version.major);
+ ASSERT_EQ(0, version.version.minor);
+ ASSERT_EQ(0, version.version.patch);
+ ASSERT_EQ("", version.version.unknown);
+ ASSERT_EQ("", version.version.pre_release);
+ ASSERT_EQ("", version.version.build_info);
+}
+
+TEST(ApplicationVersion, VersionMajorOnly) {
+ ApplicationVersion version("parquet-mr version 1");
+
+ ASSERT_EQ("parquet-mr", version.application_);
+ ASSERT_EQ("", version.build_);
+ ASSERT_EQ(1, version.version.major);
+ ASSERT_EQ(0, version.version.minor);
+ ASSERT_EQ(0, version.version.patch);
+ ASSERT_EQ("", version.version.unknown);
+ ASSERT_EQ("", version.version.pre_release);
+ ASSERT_EQ("", version.version.build_info);
+}
+
+TEST(ApplicationVersion, VersionNoMinor) {
+ ApplicationVersion version("parquet-mr version 1.");
+
+ ASSERT_EQ("parquet-mr", version.application_);
+ ASSERT_EQ("", version.build_);
+ ASSERT_EQ(1, version.version.major);
+ ASSERT_EQ(0, version.version.minor);
+ ASSERT_EQ(0, version.version.patch);
+ ASSERT_EQ("", version.version.unknown);
+ ASSERT_EQ("", version.version.pre_release);
+ ASSERT_EQ("", version.version.build_info);
+}
+
+TEST(ApplicationVersion, VersionMajorMinorOnly) {
+ ApplicationVersion version("parquet-mr version 1.7");
+
+ ASSERT_EQ("parquet-mr", version.application_);
+ ASSERT_EQ("", version.build_);
+ ASSERT_EQ(1, version.version.major);
+ ASSERT_EQ(7, version.version.minor);
+ ASSERT_EQ(0, version.version.patch);
+ ASSERT_EQ("", version.version.unknown);
+ ASSERT_EQ("", version.version.pre_release);
+ ASSERT_EQ("", version.version.build_info);
+}
+
+TEST(ApplicationVersion, VersionInvalidMinor) {
+ ApplicationVersion version("parquet-mr version 1.x7");
+
+ ASSERT_EQ("parquet-mr", version.application_);
+ ASSERT_EQ("", version.build_);
+ ASSERT_EQ(1, version.version.major);
+ ASSERT_EQ(0, version.version.minor);
+ ASSERT_EQ(0, version.version.patch);
+ ASSERT_EQ("", version.version.unknown);
+ ASSERT_EQ("", version.version.pre_release);
+ ASSERT_EQ("", version.version.build_info);
+}
+
+TEST(ApplicationVersion, VersionNoPatch) {
+ ApplicationVersion version("parquet-mr version 1.7.");
+
+ ASSERT_EQ("parquet-mr", version.application_);
+ ASSERT_EQ("", version.build_);
+ ASSERT_EQ(1, version.version.major);
+ ASSERT_EQ(7, version.version.minor);
+ ASSERT_EQ(0, version.version.patch);
+ ASSERT_EQ("", version.version.unknown);
+ ASSERT_EQ("", version.version.pre_release);
+ ASSERT_EQ("", version.version.build_info);
+}
+
+TEST(ApplicationVersion, VersionInvalidPatch) {
+ ApplicationVersion version("parquet-mr version 1.7.x9");
+
+ ASSERT_EQ("parquet-mr", version.application_);
+ ASSERT_EQ("", version.build_);
+ ASSERT_EQ(1, version.version.major);
+ ASSERT_EQ(7, version.version.minor);
+ ASSERT_EQ(0, version.version.patch);
+ ASSERT_EQ("", version.version.unknown);
+ ASSERT_EQ("", version.version.pre_release);
+ ASSERT_EQ("", version.version.build_info);
+}
+
+TEST(ApplicationVersion, VersionNoUnknown) {
+ ApplicationVersion version("parquet-mr version 1.7.9-cdh5.5.0+cd");
+
+ ASSERT_EQ("parquet-mr", version.application_);
+ ASSERT_EQ("", version.build_);
+ ASSERT_EQ(1, version.version.major);
+ ASSERT_EQ(7, version.version.minor);
+ ASSERT_EQ(9, version.version.patch);
+ ASSERT_EQ("", version.version.unknown);
+ ASSERT_EQ("cdh5.5.0", version.version.pre_release);
+ ASSERT_EQ("cd", version.version.build_info);
+}
+
+TEST(ApplicationVersion, VersionNoPreRelease) {
+ ApplicationVersion version("parquet-mr version 1.7.9ab+cd");
+
+ ASSERT_EQ("parquet-mr", version.application_);
+ ASSERT_EQ("", version.build_);
+ ASSERT_EQ(1, version.version.major);
+ ASSERT_EQ(7, version.version.minor);
+ ASSERT_EQ(9, version.version.patch);
+ ASSERT_EQ("ab", version.version.unknown);
+ ASSERT_EQ("", version.version.pre_release);
+ ASSERT_EQ("cd", version.version.build_info);
+}
+
+TEST(ApplicationVersion, VersionNoUnknownNoPreRelease) {
+ ApplicationVersion version("parquet-mr version 1.7.9+cd");
+
+ ASSERT_EQ("parquet-mr", version.application_);
+ ASSERT_EQ("", version.build_);
+ ASSERT_EQ(1, version.version.major);
+ ASSERT_EQ(7, version.version.minor);
+ ASSERT_EQ(9, version.version.patch);
+ ASSERT_EQ("", version.version.unknown);
+ ASSERT_EQ("", version.version.pre_release);
+ ASSERT_EQ("cd", version.version.build_info);
+}
+
+TEST(ApplicationVersion, VersionNoUnknownBuildInfoPreRelease) {
+ ApplicationVersion version("parquet-mr version 1.7.9+cd-cdh5.5.0");
+
+ ASSERT_EQ("parquet-mr", version.application_);
+ ASSERT_EQ("", version.build_);
+ ASSERT_EQ(1, version.version.major);
+ ASSERT_EQ(7, version.version.minor);
+ ASSERT_EQ(9, version.version.patch);
+ ASSERT_EQ("", version.version.unknown);
+ ASSERT_EQ("", version.version.pre_release);
+ ASSERT_EQ("cd-cdh5.5.0", version.version.build_info);
+}
+
+TEST(ApplicationVersion, FullWithSpaces) {
+ ApplicationVersion version(
+ " parquet-mr \t version \v 1.5.3ab-cdh5.5.0+cd \r (build \n abcd \f) ");
+
+ ASSERT_EQ("parquet-mr", version.application_);
+ ASSERT_EQ("abcd", version.build_);
+ ASSERT_EQ(1, version.version.major);
+ ASSERT_EQ(5, version.version.minor);
+ ASSERT_EQ(3, version.version.patch);
+ ASSERT_EQ("ab", version.version.unknown);
+ ASSERT_EQ("cdh5.5.0", version.version.pre_release);
+ ASSERT_EQ("cd", version.version.build_info);
+}
+
+} // namespace metadata
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/murmur3.cc b/src/arrow/cpp/src/parquet/murmur3.cc
new file mode 100644
index 000000000..07a936e04
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/murmur3.cc
@@ -0,0 +1,222 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//-----------------------------------------------------------------------------
+// MurmurHash3 was written by Austin Appleby, and is placed in the public
+// domain. The author hereby disclaims copyright to this source code.
+
+// Note - The x86 and x64 versions do _not_ produce the same results, as the
+// algorithms are optimized for their respective platforms. You can still
+// compile and run any of them on any platform, but your performance with the
+// non-native version will be less than optimal.
+
+#include "parquet/murmur3.h"
+
+namespace parquet {
+
+#if defined(_MSC_VER)
+
+#define FORCE_INLINE __forceinline
+#define ROTL64(x, y) _rotl64(x, y)
+
+#else // defined(_MSC_VER)
+
+#define FORCE_INLINE inline __attribute__((always_inline))
+inline uint64_t rotl64(uint64_t x, int8_t r) { return (x << r) | (x >> (64 - r)); }
+#define ROTL64(x, y) rotl64(x, y)
+
+#endif // !defined(_MSC_VER)
+
+#define BIG_CONSTANT(x) (x##LLU)
+
+//-----------------------------------------------------------------------------
+// Block read - if your platform needs to do endian-swapping or can only
+// handle aligned reads, do the conversion here
+
+FORCE_INLINE uint32_t getblock32(const uint32_t* p, int i) { return p[i]; }
+
+FORCE_INLINE uint64_t getblock64(const uint64_t* p, int i) { return p[i]; }
+
+//-----------------------------------------------------------------------------
+// Finalization mix - force all bits of a hash block to avalanche
+
+FORCE_INLINE uint32_t fmix32(uint32_t h) {
+ h ^= h >> 16;
+ h *= 0x85ebca6b;
+ h ^= h >> 13;
+ h *= 0xc2b2ae35;
+ h ^= h >> 16;
+
+ return h;
+}
+
+//----------
+
+FORCE_INLINE uint64_t fmix64(uint64_t k) {
+ k ^= k >> 33;
+ k *= BIG_CONSTANT(0xff51afd7ed558ccd);
+ k ^= k >> 33;
+ k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53);
+ k ^= k >> 33;
+
+ return k;
+}
+
+//-----------------------------------------------------------------------------
+
+void Hash_x64_128(const void* key, const int len, const uint32_t seed, uint64_t out[2]) {
+ const uint8_t* data = (const uint8_t*)key;
+ const int nblocks = len / 16;
+
+ uint64_t h1 = seed;
+ uint64_t h2 = seed;
+
+ const uint64_t c1 = BIG_CONSTANT(0x87c37b91114253d5);
+ const uint64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f);
+
+ //----------
+ // body
+
+ const uint64_t* blocks = (const uint64_t*)(data);
+
+ for (int i = 0; i < nblocks; i++) {
+ uint64_t k1 = getblock64(blocks, i * 2 + 0);
+ uint64_t k2 = getblock64(blocks, i * 2 + 1);
+
+ k1 *= c1;
+ k1 = ROTL64(k1, 31);
+ k1 *= c2;
+ h1 ^= k1;
+
+ h1 = ROTL64(h1, 27);
+ h1 += h2;
+ h1 = h1 * 5 + 0x52dce729;
+
+ k2 *= c2;
+ k2 = ROTL64(k2, 33);
+ k2 *= c1;
+ h2 ^= k2;
+
+ h2 = ROTL64(h2, 31);
+ h2 += h1;
+ h2 = h2 * 5 + 0x38495ab5;
+ }
+
+ //----------
+ // tail
+
+ const uint8_t* tail = (const uint8_t*)(data + nblocks * 16);
+
+ uint64_t k1 = 0;
+ uint64_t k2 = 0;
+
+ switch (len & 15) {
+ case 15:
+ k2 ^= ((uint64_t)tail[14]) << 48; // fall through
+ case 14:
+ k2 ^= ((uint64_t)tail[13]) << 40; // fall through
+ case 13:
+ k2 ^= ((uint64_t)tail[12]) << 32; // fall through
+ case 12:
+ k2 ^= ((uint64_t)tail[11]) << 24; // fall through
+ case 11:
+ k2 ^= ((uint64_t)tail[10]) << 16; // fall through
+ case 10:
+ k2 ^= ((uint64_t)tail[9]) << 8; // fall through
+ case 9:
+ k2 ^= ((uint64_t)tail[8]) << 0;
+ k2 *= c2;
+ k2 = ROTL64(k2, 33);
+ k2 *= c1;
+ h2 ^= k2; // fall through
+
+ case 8:
+ k1 ^= ((uint64_t)tail[7]) << 56; // fall through
+ case 7:
+ k1 ^= ((uint64_t)tail[6]) << 48; // fall through
+ case 6:
+ k1 ^= ((uint64_t)tail[5]) << 40; // fall through
+ case 5:
+ k1 ^= ((uint64_t)tail[4]) << 32; // fall through
+ case 4:
+ k1 ^= ((uint64_t)tail[3]) << 24; // fall through
+ case 3:
+ k1 ^= ((uint64_t)tail[2]) << 16; // fall through
+ case 2:
+ k1 ^= ((uint64_t)tail[1]) << 8; // fall through
+ case 1:
+ k1 ^= ((uint64_t)tail[0]) << 0;
+ k1 *= c1;
+ k1 = ROTL64(k1, 31);
+ k1 *= c2;
+ h1 ^= k1;
+ }
+
+ //----------
+ // finalization
+
+ h1 ^= len;
+ h2 ^= len;
+
+ h1 += h2;
+ h2 += h1;
+
+ h1 = fmix64(h1);
+ h2 = fmix64(h2);
+
+ h1 += h2;
+ h2 += h1;
+
+ reinterpret_cast<uint64_t*>(out)[0] = h1;
+ reinterpret_cast<uint64_t*>(out)[1] = h2;
+}
+
+template <typename T>
+uint64_t HashHelper(T value, uint32_t seed) {
+ uint64_t output[2];
+ Hash_x64_128(reinterpret_cast<void*>(&value), sizeof(T), seed, output);
+ return output[0];
+}
+
+uint64_t MurmurHash3::Hash(int32_t value) const { return HashHelper(value, seed_); }
+
+uint64_t MurmurHash3::Hash(int64_t value) const { return HashHelper(value, seed_); }
+
+uint64_t MurmurHash3::Hash(float value) const { return HashHelper(value, seed_); }
+
+uint64_t MurmurHash3::Hash(double value) const { return HashHelper(value, seed_); }
+
+uint64_t MurmurHash3::Hash(const FLBA* value, uint32_t len) const {
+ uint64_t out[2];
+ Hash_x64_128(reinterpret_cast<const void*>(value->ptr), len, seed_, out);
+ return out[0];
+}
+
+uint64_t MurmurHash3::Hash(const Int96* value) const {
+ uint64_t out[2];
+ Hash_x64_128(reinterpret_cast<const void*>(value->value), sizeof(value->value), seed_,
+ out);
+ return out[0];
+}
+
+uint64_t MurmurHash3::Hash(const ByteArray* value) const {
+ uint64_t out[2];
+ Hash_x64_128(reinterpret_cast<const void*>(value->ptr), value->len, seed_, out);
+ return out[0];
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/murmur3.h b/src/arrow/cpp/src/parquet/murmur3.h
new file mode 100644
index 000000000..acf7088e4
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/murmur3.h
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//-----------------------------------------------------------------------------
+// MurmurHash3 was written by Austin Appleby, and is placed in the public
+// domain. The author hereby disclaims copyright to this source code.
+
+#pragma once
+
+#include <cstdint>
+
+#include "parquet/hasher.h"
+#include "parquet/platform.h"
+#include "parquet/types.h"
+
+namespace parquet {
+
+/// Source:
+/// https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp
+/// (Modified to adapt to coding conventions and to inherit the Hasher abstract class)
+class PARQUET_EXPORT MurmurHash3 : public Hasher {
+ public:
+ MurmurHash3() : seed_(DEFAULT_SEED) {}
+ uint64_t Hash(int32_t value) const override;
+ uint64_t Hash(int64_t value) const override;
+ uint64_t Hash(float value) const override;
+ uint64_t Hash(double value) const override;
+ uint64_t Hash(const Int96* value) const override;
+ uint64_t Hash(const ByteArray* value) const override;
+ uint64_t Hash(const FLBA* val, uint32_t len) const override;
+
+ private:
+ // Default seed for hash which comes from Bloom filter in parquet-mr, it is generated
+ // by System.nanoTime() of java.
+ static constexpr int DEFAULT_SEED = 1361930890;
+
+ uint32_t seed_;
+};
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/parquet.pc.in b/src/arrow/cpp/src/parquet/parquet.pc.in
new file mode 100644
index 000000000..3b29263a9
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/parquet.pc.in
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+prefix=@CMAKE_INSTALL_PREFIX@
+libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@
+includedir=${prefix}/@CMAKE_INSTALL_INCLUDEDIR@
+
+so_version=@ARROW_SO_VERSION@
+abi_version=@ARROW_SO_VERSION@
+full_so_version=@ARROW_FULL_SO_VERSION@
+
+Name: Apache Parquet
+Description: Apache Parquet is a columnar storage format.
+Version: @ARROW_VERSION@
+Requires: arrow
+Libs: -L${libdir} -lparquet
+Cflags: -I${includedir}
diff --git a/src/arrow/cpp/src/parquet/parquet.thrift b/src/arrow/cpp/src/parquet/parquet.thrift
new file mode 100644
index 000000000..8aa984816
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/parquet.thrift
@@ -0,0 +1,1063 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * File format description for the parquet file format
+ */
+
+cpp_include "parquet/windows_compatibility.h"
+namespace cpp parquet.format
+namespace java org.apache.parquet.format
+
+/**
+ * Types supported by Parquet. These types are intended to be used in combination
+ * with the encodings to control the on disk storage format.
+ * For example INT16 is not included as a type since a good encoding of INT32
+ * would handle this.
+ */
+enum Type {
+ BOOLEAN = 0;
+ INT32 = 1;
+ INT64 = 2;
+ INT96 = 3; // deprecated, only used by legacy implementations.
+ FLOAT = 4;
+ DOUBLE = 5;
+ BYTE_ARRAY = 6;
+ FIXED_LEN_BYTE_ARRAY = 7;
+}
+
+/**
+ * Common types used by frameworks(e.g. hive, pig) using parquet. This helps map
+ * between types in those frameworks to the base types in parquet. This is only
+ * metadata and not needed to read or write the data.
+ */
+enum ConvertedType {
+ /** a BYTE_ARRAY actually contains UTF8 encoded chars */
+ UTF8 = 0;
+
+ /** a map is converted as an optional field containing a repeated key/value pair */
+ MAP = 1;
+
+ /** a key/value pair is converted into a group of two fields */
+ MAP_KEY_VALUE = 2;
+
+ /** a list is converted into an optional field containing a repeated field for its
+ * values */
+ LIST = 3;
+
+ /** an enum is converted into a binary field */
+ ENUM = 4;
+
+ /**
+ * A decimal value.
+ *
+ * This may be used to annotate binary or fixed primitive types. The
+ * underlying byte array stores the unscaled value encoded as two's
+ * complement using big-endian byte order (the most significant byte is the
+ * zeroth element). The value of the decimal is the value * 10^{-scale}.
+ *
+ * This must be accompanied by a (maximum) precision and a scale in the
+ * SchemaElement. The precision specifies the number of digits in the decimal
+ * and the scale stores the location of the decimal point. For example 1.23
+ * would have precision 3 (3 total digits) and scale 2 (the decimal point is
+ * 2 digits over).
+ */
+ DECIMAL = 5;
+
+ /**
+ * A Date
+ *
+ * Stored as days since Unix epoch, encoded as the INT32 physical type.
+ *
+ */
+ DATE = 6;
+
+ /**
+ * A time
+ *
+ * The total number of milliseconds since midnight. The value is stored
+ * as an INT32 physical type.
+ */
+ TIME_MILLIS = 7;
+
+ /**
+ * A time.
+ *
+ * The total number of microseconds since midnight. The value is stored as
+ * an INT64 physical type.
+ */
+ TIME_MICROS = 8;
+
+ /**
+ * A date/time combination
+ *
+ * Date and time recorded as milliseconds since the Unix epoch. Recorded as
+ * a physical type of INT64.
+ */
+ TIMESTAMP_MILLIS = 9;
+
+ /**
+ * A date/time combination
+ *
+ * Date and time recorded as microseconds since the Unix epoch. The value is
+ * stored as an INT64 physical type.
+ */
+ TIMESTAMP_MICROS = 10;
+
+
+ /**
+ * An unsigned integer value.
+ *
+ * The number describes the maximum number of meaningful data bits in
+ * the stored value. 8, 16 and 32 bit values are stored using the
+ * INT32 physical type. 64 bit values are stored using the INT64
+ * physical type.
+ *
+ */
+ UINT_8 = 11;
+ UINT_16 = 12;
+ UINT_32 = 13;
+ UINT_64 = 14;
+
+ /**
+ * A signed integer value.
+ *
+ * The number describes the maximum number of meaningful data bits in
+ * the stored value. 8, 16 and 32 bit values are stored using the
+ * INT32 physical type. 64 bit values are stored using the INT64
+ * physical type.
+ *
+ */
+ INT_8 = 15;
+ INT_16 = 16;
+ INT_32 = 17;
+ INT_64 = 18;
+
+ /**
+ * An embedded JSON document
+ *
+ * A JSON document embedded within a single UTF8 column.
+ */
+ JSON = 19;
+
+ /**
+ * An embedded BSON document
+ *
+ * A BSON document embedded within a single BINARY column.
+ */
+ BSON = 20;
+
+ /**
+ * An interval of time
+ *
+ * This type annotates data stored as a FIXED_LEN_BYTE_ARRAY of length 12
+ * This data is composed of three separate little endian unsigned
+ * integers. Each stores a component of a duration of time. The first
+ * integer identifies the number of months associated with the duration,
+ * the second identifies the number of days associated with the duration
+ * and the third identifies the number of milliseconds associated with
+ * the provided duration. This duration of time is independent of any
+ * particular timezone or date.
+ */
+ INTERVAL = 21;
+}
+
+/**
+ * Representation of Schemas
+ */
+enum FieldRepetitionType {
+ /** This field is required (can not be null) and each record has exactly 1 value. */
+ REQUIRED = 0;
+
+ /** The field is optional (can be null) and each record has 0 or 1 values. */
+ OPTIONAL = 1;
+
+ /** The field is repeated and can contain 0 or more values */
+ REPEATED = 2;
+}
+
+/**
+ * Statistics per row group and per page
+ * All fields are optional.
+ */
+struct Statistics {
+ /**
+ * DEPRECATED: min and max value of the column. Use min_value and max_value.
+ *
+ * Values are encoded using PLAIN encoding, except that variable-length byte
+ * arrays do not include a length prefix.
+ *
+ * These fields encode min and max values determined by signed comparison
+ * only. New files should use the correct order for a column's logical type
+ * and store the values in the min_value and max_value fields.
+ *
+ * To support older readers, these may be set when the column order is
+ * signed.
+ */
+ 1: optional binary max;
+ 2: optional binary min;
+ /** count of null value in the column */
+ 3: optional i64 null_count;
+ /** count of distinct values occurring */
+ 4: optional i64 distinct_count;
+ /**
+ * Min and max values for the column, determined by its ColumnOrder.
+ *
+ * Values are encoded using PLAIN encoding, except that variable-length byte
+ * arrays do not include a length prefix.
+ */
+ 5: optional binary max_value;
+ 6: optional binary min_value;
+}
+
+/** Empty structs to use as logical type annotations */
+struct StringType {} // allowed for BINARY, must be encoded with UTF-8
+struct UUIDType {} // allowed for FIXED[16], must encoded raw UUID bytes
+struct MapType {} // see LogicalTypes.md
+struct ListType {} // see LogicalTypes.md
+struct EnumType {} // allowed for BINARY, must be encoded with UTF-8
+struct DateType {} // allowed for INT32
+
+/**
+ * Logical type to annotate a column that is always null.
+ *
+ * Sometimes when discovering the schema of existing data, values are always
+ * null and the physical type can't be determined. This annotation signals
+ * the case where the physical type was guessed from all null values.
+ */
+struct NullType {} // allowed for any physical type, only null values stored
+
+/**
+ * Decimal logical type annotation
+ *
+ * To maintain forward-compatibility in v1, implementations using this logical
+ * type must also set scale and precision on the annotated SchemaElement.
+ *
+ * Allowed for physical types: INT32, INT64, FIXED, and BINARY
+ */
+struct DecimalType {
+ 1: required i32 scale
+ 2: required i32 precision
+}
+
+/** Time units for logical types */
+struct MilliSeconds {}
+struct MicroSeconds {}
+struct NanoSeconds {}
+union TimeUnit {
+ 1: MilliSeconds MILLIS
+ 2: MicroSeconds MICROS
+ 3: NanoSeconds NANOS
+}
+
+/**
+ * Timestamp logical type annotation
+ *
+ * Allowed for physical types: INT64
+ */
+struct TimestampType {
+ 1: required bool isAdjustedToUTC
+ 2: required TimeUnit unit
+}
+
+/**
+ * Time logical type annotation
+ *
+ * Allowed for physical types: INT32 (millis), INT64 (micros, nanos)
+ */
+struct TimeType {
+ 1: required bool isAdjustedToUTC
+ 2: required TimeUnit unit
+}
+
+/**
+ * Integer logical type annotation
+ *
+ * bitWidth must be 8, 16, 32, or 64.
+ *
+ * Allowed for physical types: INT32, INT64
+ */
+struct IntType {
+ 1: required i8 bitWidth
+ 2: required bool isSigned
+}
+
+/**
+ * Embedded JSON logical type annotation
+ *
+ * Allowed for physical types: BINARY
+ */
+struct JsonType {
+}
+
+/**
+ * Embedded BSON logical type annotation
+ *
+ * Allowed for physical types: BINARY
+ */
+struct BsonType {
+}
+
+/**
+ * LogicalType annotations to replace ConvertedType.
+ *
+ * To maintain compatibility, implementations using LogicalType for a
+ * SchemaElement must also set the corresponding ConvertedType from the
+ * following table.
+ */
+union LogicalType {
+ 1: StringType STRING // use ConvertedType UTF8
+ 2: MapType MAP // use ConvertedType MAP
+ 3: ListType LIST // use ConvertedType LIST
+ 4: EnumType ENUM // use ConvertedType ENUM
+ 5: DecimalType DECIMAL // use ConvertedType DECIMAL
+ 6: DateType DATE // use ConvertedType DATE
+
+ // use ConvertedType TIME_MICROS for TIME(isAdjustedToUTC = *, unit = MICROS)
+ // use ConvertedType TIME_MILLIS for TIME(isAdjustedToUTC = *, unit = MILLIS)
+ 7: TimeType TIME
+
+ // use ConvertedType TIMESTAMP_MICROS for TIMESTAMP(isAdjustedToUTC = *, unit = MICROS)
+ // use ConvertedType TIMESTAMP_MILLIS for TIMESTAMP(isAdjustedToUTC = *, unit = MILLIS)
+ 8: TimestampType TIMESTAMP
+
+ // 9: reserved for INTERVAL
+ 10: IntType INTEGER // use ConvertedType INT_* or UINT_*
+ 11: NullType UNKNOWN // no compatible ConvertedType
+ 12: JsonType JSON // use ConvertedType JSON
+ 13: BsonType BSON // use ConvertedType BSON
+ 14: UUIDType UUID
+}
+
+/**
+ * Represents a element inside a schema definition.
+ * - if it is a group (inner node) then type is undefined and num_children is defined
+ * - if it is a primitive type (leaf) then type is defined and num_children is undefined
+ * the nodes are listed in depth first traversal order.
+ */
+struct SchemaElement {
+ /** Data type for this field. Not set if the current element is a non-leaf node */
+ 1: optional Type type;
+
+ /** If type is FIXED_LEN_BYTE_ARRAY, this is the byte length of the vales.
+ * Otherwise, if specified, this is the maximum bit length to store any of the values.
+ * (e.g. a low cardinality INT col could have this set to 3). Note that this is
+ * in the schema, and therefore fixed for the entire file.
+ */
+ 2: optional i32 type_length;
+
+ /** repetition of the field. The root of the schema does not have a repetition_type.
+ * All other nodes must have one */
+ 3: optional FieldRepetitionType repetition_type;
+
+ /** Name of the field in the schema */
+ 4: required string name;
+
+ /** Nested fields. Since thrift does not support nested fields,
+ * the nesting is flattened to a single list by a depth-first traversal.
+ * The children count is used to construct the nested relationship.
+ * This field is not set when the element is a primitive type
+ */
+ 5: optional i32 num_children;
+
+ /** When the schema is the result of a conversion from another model
+ * Used to record the original type to help with cross conversion.
+ */
+ 6: optional ConvertedType converted_type;
+
+ /** Used when this column contains decimal data.
+ * See the DECIMAL converted type for more details.
+ */
+ 7: optional i32 scale
+ 8: optional i32 precision
+
+ /** When the original schema supports field ids, this will save the
+ * original field id in the parquet schema
+ */
+ 9: optional i32 field_id;
+
+ /**
+ * The logical type of this SchemaElement
+ *
+ * LogicalType replaces ConvertedType, but ConvertedType is still required
+ * for some logical types to ensure forward-compatibility in format v1.
+ */
+ 10: optional LogicalType logicalType
+}
+
+/**
+ * Encodings supported by Parquet. Not all encodings are valid for all types. These
+ * enums are also used to specify the encoding of definition and repetition levels.
+ * See the accompanying doc for the details of the more complicated encodings.
+ */
+enum Encoding {
+ /** Default encoding.
+ * BOOLEAN - 1 bit per value. 0 is false; 1 is true.
+ * INT32 - 4 bytes per value. Stored as little-endian.
+ * INT64 - 8 bytes per value. Stored as little-endian.
+ * FLOAT - 4 bytes per value. IEEE. Stored as little-endian.
+ * DOUBLE - 8 bytes per value. IEEE. Stored as little-endian.
+ * BYTE_ARRAY - 4 byte length stored as little endian, followed by bytes.
+ * FIXED_LEN_BYTE_ARRAY - Just the bytes.
+ */
+ PLAIN = 0;
+
+ /** Group VarInt encoding for INT32/INT64.
+ * This encoding is deprecated. It was never used
+ */
+ // GROUP_VAR_INT = 1;
+
+ /**
+ * Deprecated: Dictionary encoding. The values in the dictionary are encoded in the
+ * plain type.
+ * in a data page use RLE_DICTIONARY instead.
+ * in a Dictionary page use PLAIN instead
+ */
+ PLAIN_DICTIONARY = 2;
+
+ /** Group packed run length encoding. Usable for definition/repetition levels
+ * encoding and Booleans (on one bit: 0 is false; 1 is true.)
+ */
+ RLE = 3;
+
+ /** Bit packed encoding. This can only be used if the data has a known max
+ * width. Usable for definition/repetition levels encoding.
+ */
+ BIT_PACKED = 4;
+
+ /** Delta encoding for integers. This can be used for int columns and works best
+ * on sorted data
+ */
+ DELTA_BINARY_PACKED = 5;
+
+ /** Encoding for byte arrays to separate the length values and the data. The lengths
+ * are encoded using DELTA_BINARY_PACKED
+ */
+ DELTA_LENGTH_BYTE_ARRAY = 6;
+
+ /** Incremental-encoded byte array. Prefix lengths are encoded using DELTA_BINARY_PACKED.
+ * Suffixes are stored as delta length byte arrays.
+ */
+ DELTA_BYTE_ARRAY = 7;
+
+ /** Dictionary encoding: the ids are encoded using the RLE encoding
+ */
+ RLE_DICTIONARY = 8;
+
+ /** Encoding for floating-point data.
+ K byte-streams are created where K is the size in bytes of the data type.
+ The individual bytes of an FP value are scattered to the corresponding stream and
+ the streams are concatenated.
+ This itself does not reduce the size of the data but can lead to better compression
+ afterwards.
+ */
+ BYTE_STREAM_SPLIT = 9;
+}
+
+/**
+ * Supported compression algorithms.
+ *
+ * Codecs added in format version X.Y can be read by readers based on X.Y and later.
+ * Codec support may vary between readers based on the format version and
+ * libraries available at runtime.
+ *
+ * See Compression.md for a detailed specification of these algorithms.
+ */
+enum CompressionCodec {
+ UNCOMPRESSED = 0;
+ SNAPPY = 1;
+ GZIP = 2;
+ LZO = 3;
+ BROTLI = 4; // Added in 2.4
+ LZ4 = 5; // DEPRECATED (Added in 2.4)
+ ZSTD = 6; // Added in 2.4
+ LZ4_RAW = 7; // Added in 2.9
+}
+
+enum PageType {
+ DATA_PAGE = 0;
+ INDEX_PAGE = 1;
+ DICTIONARY_PAGE = 2;
+ DATA_PAGE_V2 = 3;
+}
+
+/**
+ * Enum to annotate whether lists of min/max elements inside ColumnIndex
+ * are ordered and if so, in which direction.
+ */
+enum BoundaryOrder {
+ UNORDERED = 0;
+ ASCENDING = 1;
+ DESCENDING = 2;
+}
+
+/** Data page header */
+struct DataPageHeader {
+ /** Number of values, including NULLs, in this data page. **/
+ 1: required i32 num_values
+
+ /** Encoding used for this data page **/
+ 2: required Encoding encoding
+
+ /** Encoding used for definition levels **/
+ 3: required Encoding definition_level_encoding;
+
+ /** Encoding used for repetition levels **/
+ 4: required Encoding repetition_level_encoding;
+
+ /** Optional statistics for the data in this page**/
+ 5: optional Statistics statistics;
+}
+
+struct IndexPageHeader {
+ // TODO
+}
+
+struct DictionaryPageHeader {
+ /** Number of values in the dictionary **/
+ 1: required i32 num_values;
+
+ /** Encoding using this dictionary page **/
+ 2: required Encoding encoding
+
+ /** If true, the entries in the dictionary are sorted in ascending order **/
+ 3: optional bool is_sorted;
+}
+
+/**
+ * New page format allowing reading levels without decompressing the data
+ * Repetition and definition levels are uncompressed
+ * The remaining section containing the data is compressed if is_compressed is true
+ **/
+struct DataPageHeaderV2 {
+ /** Number of values, including NULLs, in this data page. **/
+ 1: required i32 num_values
+ /** Number of NULL values, in this data page.
+ Number of non-null = num_values - num_nulls which is also the number of values in the data section **/
+ 2: required i32 num_nulls
+ /** Number of rows in this data page. which means pages change on record boundaries (r = 0) **/
+ 3: required i32 num_rows
+ /** Encoding used for data in this page **/
+ 4: required Encoding encoding
+
+ // repetition levels and definition levels are always using RLE (without size in it)
+
+ /** length of the definition levels */
+ 5: required i32 definition_levels_byte_length;
+ /** length of the repetition levels */
+ 6: required i32 repetition_levels_byte_length;
+
+ /** whether the values are compressed.
+ Which means the section of the page between
+ definition_levels_byte_length + repetition_levels_byte_length + 1 and compressed_page_size (included)
+ is compressed with the compression_codec.
+ If missing it is considered compressed */
+ 7: optional bool is_compressed = 1;
+
+ /** optional statistics for the data in this page **/
+ 8: optional Statistics statistics;
+}
+
+/** Block-based algorithm type annotation. **/
+struct SplitBlockAlgorithm {}
+/** The algorithm used in Bloom filter. **/
+union BloomFilterAlgorithm {
+ /** Block-based Bloom filter. **/
+ 1: SplitBlockAlgorithm BLOCK;
+}
+
+/** Hash strategy type annotation. xxHash is an extremely fast non-cryptographic hash
+ * algorithm. It uses 64 bits version of xxHash.
+ **/
+struct XxHash {}
+
+/**
+ * The hash function used in Bloom filter. This function takes the hash of a column value
+ * using plain encoding.
+ **/
+union BloomFilterHash {
+ /** xxHash Strategy. **/
+ 1: XxHash XXHASH;
+}
+
+/**
+ * The compression used in the Bloom filter.
+ **/
+struct Uncompressed {}
+union BloomFilterCompression {
+ 1: Uncompressed UNCOMPRESSED;
+}
+
+/**
+ * Bloom filter header is stored at beginning of Bloom filter data of each column
+ * and followed by its bitset.
+ **/
+struct BloomFilterHeader {
+ /** The size of bitset in bytes **/
+ 1: required i32 numBytes;
+ /** The algorithm for setting bits. **/
+ 2: required BloomFilterAlgorithm algorithm;
+ /** The hash function used for Bloom filter. **/
+ 3: required BloomFilterHash hash;
+ /** The compression used in the Bloom filter **/
+ 4: required BloomFilterCompression compression;
+}
+
+struct PageHeader {
+ /** the type of the page: indicates which of the *_header fields is set **/
+ 1: required PageType type
+
+ /** Uncompressed page size in bytes (not including this header) **/
+ 2: required i32 uncompressed_page_size
+
+ /** Compressed (and potentially encrypted) page size in bytes, not including this header **/
+ 3: required i32 compressed_page_size
+
+ /** The 32bit CRC for the page, to be be calculated as follows:
+ * - Using the standard CRC32 algorithm
+ * - On the data only, i.e. this header should not be included. 'Data'
+ * hereby refers to the concatenation of the repetition levels, the
+ * definition levels and the column value, in this exact order.
+ * - On the encoded versions of the repetition levels, definition levels and
+ * column values
+ * - On the compressed versions of the repetition levels, definition levels
+ * and column values where possible;
+ * - For v1 data pages, the repetition levels, definition levels and column
+ * values are always compressed together. If a compression scheme is
+ * specified, the CRC shall be calculated on the compressed version of
+ * this concatenation. If no compression scheme is specified, the CRC
+ * shall be calculated on the uncompressed version of this concatenation.
+ * - For v2 data pages, the repetition levels and definition levels are
+ * handled separately from the data and are never compressed (only
+ * encoded). If a compression scheme is specified, the CRC shall be
+ * calculated on the concatenation of the uncompressed repetition levels,
+ * uncompressed definition levels and the compressed column values.
+ * If no compression scheme is specified, the CRC shall be calculated on
+ * the uncompressed concatenation.
+ * - In encrypted columns, CRC is calculated after page encryption; the
+ * encryption itself is performed after page compression (if compressed)
+ * If enabled, this allows for disabling checksumming in HDFS if only a few
+ * pages need to be read.
+ **/
+ 4: optional i32 crc
+
+ // Headers for page specific data. One only will be set.
+ 5: optional DataPageHeader data_page_header;
+ 6: optional IndexPageHeader index_page_header;
+ 7: optional DictionaryPageHeader dictionary_page_header;
+ 8: optional DataPageHeaderV2 data_page_header_v2;
+}
+
+/**
+ * Wrapper struct to store key values
+ */
+ struct KeyValue {
+ 1: required string key
+ 2: optional string value
+}
+
+/**
+ * Wrapper struct to specify sort order
+ */
+struct SortingColumn {
+ /** The column index (in this row group) **/
+ 1: required i32 column_idx
+
+ /** If true, indicates this column is sorted in descending order. **/
+ 2: required bool descending
+
+ /** If true, nulls will come before non-null values, otherwise,
+ * nulls go at the end. */
+ 3: required bool nulls_first
+}
+
+/**
+ * statistics of a given page type and encoding
+ */
+struct PageEncodingStats {
+
+ /** the page type (data/dic/...) **/
+ 1: required PageType page_type;
+
+ /** encoding of the page **/
+ 2: required Encoding encoding;
+
+ /** number of pages of this type with this encoding **/
+ 3: required i32 count;
+
+}
+
+/**
+ * Description for column metadata
+ */
+struct ColumnMetaData {
+ /** Type of this column **/
+ 1: required Type type
+
+ /** Set of all encodings used for this column. The purpose is to validate
+ * whether we can decode those pages. **/
+ 2: required list<Encoding> encodings
+
+ /** Path in schema **/
+ 3: required list<string> path_in_schema
+
+ /** Compression codec **/
+ 4: required CompressionCodec codec
+
+ /** Number of values in this column **/
+ 5: required i64 num_values
+
+ /** total byte size of all uncompressed pages in this column chunk (including the headers) **/
+ 6: required i64 total_uncompressed_size
+
+ /** total byte size of all compressed, and potentially encrypted, pages
+ * in this column chunk (including the headers) **/
+ 7: required i64 total_compressed_size
+
+ /** Optional key/value metadata **/
+ 8: optional list<KeyValue> key_value_metadata
+
+ /** Byte offset from beginning of file to first data page **/
+ 9: required i64 data_page_offset
+
+ /** Byte offset from beginning of file to root index page **/
+ 10: optional i64 index_page_offset
+
+ /** Byte offset from the beginning of file to first (only) dictionary page **/
+ 11: optional i64 dictionary_page_offset
+
+ /** optional statistics for this column chunk */
+ 12: optional Statistics statistics;
+
+ /** Set of all encodings used for pages in this column chunk.
+ * This information can be used to determine if all data pages are
+ * dictionary encoded for example **/
+ 13: optional list<PageEncodingStats> encoding_stats;
+
+ /** Byte offset from beginning of file to Bloom filter data. **/
+ 14: optional i64 bloom_filter_offset;
+}
+
+struct EncryptionWithFooterKey {
+}
+
+struct EncryptionWithColumnKey {
+ /** Column path in schema **/
+ 1: required list<string> path_in_schema
+
+ /** Retrieval metadata of column encryption key **/
+ 2: optional binary key_metadata
+}
+
+union ColumnCryptoMetaData {
+ 1: EncryptionWithFooterKey ENCRYPTION_WITH_FOOTER_KEY
+ 2: EncryptionWithColumnKey ENCRYPTION_WITH_COLUMN_KEY
+}
+
+struct ColumnChunk {
+ /** File where column data is stored. If not set, assumed to be same file as
+ * metadata. This path is relative to the current file.
+ **/
+ 1: optional string file_path
+
+ /** Byte offset in file_path to the ColumnMetaData **/
+ 2: required i64 file_offset
+
+ /** Column metadata for this chunk. This is the same content as what is at
+ * file_path/file_offset. Having it here has it replicated in the file
+ * metadata.
+ **/
+ 3: optional ColumnMetaData meta_data
+
+ /** File offset of ColumnChunk's OffsetIndex **/
+ 4: optional i64 offset_index_offset
+
+ /** Size of ColumnChunk's OffsetIndex, in bytes **/
+ 5: optional i32 offset_index_length
+
+ /** File offset of ColumnChunk's ColumnIndex **/
+ 6: optional i64 column_index_offset
+
+ /** Size of ColumnChunk's ColumnIndex, in bytes **/
+ 7: optional i32 column_index_length
+
+ /** Crypto metadata of encrypted columns **/
+ 8: optional ColumnCryptoMetaData crypto_metadata
+
+ /** Encrypted column metadata for this chunk **/
+ 9: optional binary encrypted_column_metadata
+}
+
+struct RowGroup {
+ /** Metadata for each column chunk in this row group.
+ * This list must have the same order as the SchemaElement list in FileMetaData.
+ **/
+ 1: required list<ColumnChunk> columns
+
+ /** Total byte size of all the uncompressed column data in this row group **/
+ 2: required i64 total_byte_size
+
+ /** Number of rows in this row group **/
+ 3: required i64 num_rows
+
+ /** If set, specifies a sort ordering of the rows in this RowGroup.
+ * The sorting columns can be a subset of all the columns.
+ */
+ 4: optional list<SortingColumn> sorting_columns
+
+ /** Byte offset from beginning of file to first page (data or dictionary)
+ * in this row group **/
+ 5: optional i64 file_offset
+
+ /** Total byte size of all compressed (and potentially encrypted) column data
+ * in this row group **/
+ 6: optional i64 total_compressed_size
+
+ /** Row group ordinal in the file **/
+ 7: optional i16 ordinal
+}
+
+/** Empty struct to signal the order defined by the physical or logical type */
+struct TypeDefinedOrder {}
+
+/**
+ * Union to specify the order used for the min_value and max_value fields for a
+ * column. This union takes the role of an enhanced enum that allows rich
+ * elements (which will be needed for a collation-based ordering in the future).
+ *
+ * Possible values are:
+ * * TypeDefinedOrder - the column uses the order defined by its logical or
+ * physical type (if there is no logical type).
+ *
+ * If the reader does not support the value of this union, min and max stats
+ * for this column should be ignored.
+ */
+union ColumnOrder {
+
+ /**
+ * The sort orders for logical types are:
+ * UTF8 - unsigned byte-wise comparison
+ * INT8 - signed comparison
+ * INT16 - signed comparison
+ * INT32 - signed comparison
+ * INT64 - signed comparison
+ * UINT8 - unsigned comparison
+ * UINT16 - unsigned comparison
+ * UINT32 - unsigned comparison
+ * UINT64 - unsigned comparison
+ * DECIMAL - signed comparison of the represented value
+ * DATE - signed comparison
+ * TIME_MILLIS - signed comparison
+ * TIME_MICROS - signed comparison
+ * TIMESTAMP_MILLIS - signed comparison
+ * TIMESTAMP_MICROS - signed comparison
+ * INTERVAL - unsigned comparison
+ * JSON - unsigned byte-wise comparison
+ * BSON - unsigned byte-wise comparison
+ * ENUM - unsigned byte-wise comparison
+ * LIST - undefined
+ * MAP - undefined
+ *
+ * In the absence of logical types, the sort order is determined by the physical type:
+ * BOOLEAN - false, true
+ * INT32 - signed comparison
+ * INT64 - signed comparison
+ * INT96 (only used for legacy timestamps) - undefined
+ * FLOAT - signed comparison of the represented value (*)
+ * DOUBLE - signed comparison of the represented value (*)
+ * BYTE_ARRAY - unsigned byte-wise comparison
+ * FIXED_LEN_BYTE_ARRAY - unsigned byte-wise comparison
+ *
+ * (*) Because the sorting order is not specified properly for floating
+ * point values (relations vs. total ordering) the following
+ * compatibility rules should be applied when reading statistics:
+ * - If the min is a NaN, it should be ignored.
+ * - If the max is a NaN, it should be ignored.
+ * - If the min is +0, the row group may contain -0 values as well.
+ * - If the max is -0, the row group may contain +0 values as well.
+ * - When looking for NaN values, min and max should be ignored.
+ */
+ 1: TypeDefinedOrder TYPE_ORDER;
+}
+
+struct PageLocation {
+ /** Offset of the page in the file **/
+ 1: required i64 offset
+
+ /**
+ * Size of the page, including header. Sum of compressed_page_size and header
+ * length
+ */
+ 2: required i32 compressed_page_size
+
+ /**
+ * Index within the RowGroup of the first row of the page; this means pages
+ * change on record boundaries (r = 0).
+ */
+ 3: required i64 first_row_index
+}
+
+struct OffsetIndex {
+ /**
+ * PageLocations, ordered by increasing PageLocation.offset. It is required
+ * that page_locations[i].first_row_index < page_locations[i+1].first_row_index.
+ */
+ 1: required list<PageLocation> page_locations
+}
+
+/**
+ * Description for ColumnIndex.
+ * Each <array-field>[i] refers to the page at OffsetIndex.page_locations[i]
+ */
+struct ColumnIndex {
+ /**
+ * A list of Boolean values to determine the validity of the corresponding
+ * min and max values. If true, a page contains only null values, and writers
+ * have to set the corresponding entries in min_values and max_values to
+ * byte[0], so that all lists have the same length. If false, the
+ * corresponding entries in min_values and max_values must be valid.
+ */
+ 1: required list<bool> null_pages
+
+ /**
+ * Two lists containing lower and upper bounds for the values of each page.
+ * These may be the actual minimum and maximum values found on a page, but
+ * can also be (more compact) values that do not exist on a page. For
+ * example, instead of storing ""Blart Versenwald III", a writer may set
+ * min_values[i]="B", max_values[i]="C". Such more compact values must still
+ * be valid values within the column's logical type. Readers must make sure
+ * that list entries are populated before using them by inspecting null_pages.
+ */
+ 2: required list<binary> min_values
+ 3: required list<binary> max_values
+
+ /**
+ * Stores whether both min_values and max_values are orderd and if so, in
+ * which direction. This allows readers to perform binary searches in both
+ * lists. Readers cannot assume that max_values[i] <= min_values[i+1], even
+ * if the lists are ordered.
+ */
+ 4: required BoundaryOrder boundary_order
+
+ /** A list containing the number of null values for each page **/
+ 5: optional list<i64> null_counts
+}
+
+struct AesGcmV1 {
+ /** AAD prefix **/
+ 1: optional binary aad_prefix
+
+ /** Unique file identifier part of AAD suffix **/
+ 2: optional binary aad_file_unique
+
+ /** In files encrypted with AAD prefix without storing it,
+ * readers must supply the prefix **/
+ 3: optional bool supply_aad_prefix
+}
+
+struct AesGcmCtrV1 {
+ /** AAD prefix **/
+ 1: optional binary aad_prefix
+
+ /** Unique file identifier part of AAD suffix **/
+ 2: optional binary aad_file_unique
+
+ /** In files encrypted with AAD prefix without storing it,
+ * readers must supply the prefix **/
+ 3: optional bool supply_aad_prefix
+}
+
+union EncryptionAlgorithm {
+ 1: AesGcmV1 AES_GCM_V1
+ 2: AesGcmCtrV1 AES_GCM_CTR_V1
+}
+
+/**
+ * Description for file metadata
+ */
+struct FileMetaData {
+ /** Version of this file **/
+ 1: required i32 version
+
+ /** Parquet schema for this file. This schema contains metadata for all the columns.
+ * The schema is represented as a tree with a single root. The nodes of the tree
+ * are flattened to a list by doing a depth-first traversal.
+ * The column metadata contains the path in the schema for that column which can be
+ * used to map columns to nodes in the schema.
+ * The first element is the root **/
+ 2: required list<SchemaElement> schema;
+
+ /** Number of rows in this file **/
+ 3: required i64 num_rows
+
+ /** Row groups in this file **/
+ 4: required list<RowGroup> row_groups
+
+ /** Optional key/value metadata **/
+ 5: optional list<KeyValue> key_value_metadata
+
+ /** String for application that wrote this file. This should be in the format
+ * <Application> version <App Version> (build <App Build Hash>).
+ * e.g. impala version 1.0 (build 6cf94d29b2b7115df4de2c06e2ab4326d721eb55)
+ **/
+ 6: optional string created_by
+
+ /**
+ * Sort order used for the min_value and max_value fields of each column in
+ * this file. Sort orders are listed in the order matching the columns in the
+ * schema. The indexes are not necessary the same though, because only leaf
+ * nodes of the schema are represented in the list of sort orders.
+ *
+ * Without column_orders, the meaning of the min_value and max_value fields is
+ * undefined. To ensure well-defined behaviour, if min_value and max_value are
+ * written to a Parquet file, column_orders must be written as well.
+ *
+ * The obsolete min and max fields are always sorted by signed comparison
+ * regardless of column_orders.
+ */
+ 7: optional list<ColumnOrder> column_orders;
+
+ /**
+ * Encryption algorithm. This field is set only in encrypted files
+ * with plaintext footer. Files with encrypted footer store algorithm id
+ * in FileCryptoMetaData structure.
+ */
+ 8: optional EncryptionAlgorithm encryption_algorithm
+
+ /**
+ * Retrieval metadata of key used for signing the footer.
+ * Used only in encrypted files with plaintext footer.
+ */
+ 9: optional binary footer_signing_key_metadata
+}
+
+/** Crypto metadata for files with encrypted footer **/
+struct FileCryptoMetaData {
+ /**
+ * Encryption algorithm. This field is only used for files
+ * with encrypted footer. Files with plaintext footer store algorithm id
+ * inside footer (FileMetaData structure).
+ */
+ 1: required EncryptionAlgorithm encryption_algorithm
+
+ /** Retrieval metadata of key used for encryption of footer,
+ * and (possibly) columns **/
+ 2: optional binary key_metadata
+}
+
diff --git a/src/arrow/cpp/src/parquet/parquet_version.h.in b/src/arrow/cpp/src/parquet/parquet_version.h.in
new file mode 100644
index 000000000..b7d9576c4
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/parquet_version.h.in
@@ -0,0 +1,31 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef PARQUET_VERSION_H
+#define PARQUET_VERSION_H
+
+#define PARQUET_VERSION_MAJOR @ARROW_VERSION_MAJOR@
+#define PARQUET_VERSION_MINOR @ARROW_VERSION_MINOR@
+#define PARQUET_VERSION_PATCH @ARROW_VERSION_PATCH@
+
+#define PARQUET_SO_VERSION "@ARROW_SO_VERSION@"
+#define PARQUET_FULL_SO_VERSION "@ARROW_FULL_SO_VERSION@"
+
+// define the parquet created by version
+#define CREATED_BY_VERSION "parquet-cpp-arrow version @ARROW_VERSION@"
+
+#endif // PARQUET_VERSION_H
diff --git a/src/arrow/cpp/src/parquet/pch.h b/src/arrow/cpp/src/parquet/pch.h
new file mode 100644
index 000000000..59e64bfc6
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/pch.h
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Often-used headers, for precompiling.
+// If updating this header, please make sure you check compilation speed
+// before checking in. Adding headers which are not used extremely often
+// may incur a slowdown, since it makes the precompiled header heavier to load.
+
+#include "parquet/encoding.h"
+#include "parquet/exception.h"
+#include "parquet/metadata.h"
+#include "parquet/properties.h"
+#include "parquet/schema.h"
+#include "parquet/types.h"
diff --git a/src/arrow/cpp/src/parquet/platform.cc b/src/arrow/cpp/src/parquet/platform.cc
new file mode 100644
index 000000000..5c355c28b
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/platform.cc
@@ -0,0 +1,41 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/platform.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "arrow/io/memory.h"
+
+#include "parquet/exception.h"
+
+namespace parquet {
+
+std::shared_ptr<::arrow::io::BufferOutputStream> CreateOutputStream(MemoryPool* pool) {
+ PARQUET_ASSIGN_OR_THROW(auto stream, ::arrow::io::BufferOutputStream::Create(
+ kDefaultOutputStreamSize, pool));
+ return stream;
+}
+
+std::shared_ptr<ResizableBuffer> AllocateBuffer(MemoryPool* pool, int64_t size) {
+ PARQUET_ASSIGN_OR_THROW(auto result, ::arrow::AllocateResizableBuffer(size, pool));
+ return std::move(result);
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/platform.h b/src/arrow/cpp/src/parquet/platform.h
new file mode 100644
index 000000000..00a193f14
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/platform.h
@@ -0,0 +1,111 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "arrow/buffer.h" // IWYU pragma: export
+#include "arrow/io/interfaces.h" // IWYU pragma: export
+#include "arrow/status.h" // IWYU pragma: export
+#include "arrow/type_fwd.h" // IWYU pragma: export
+#include "arrow/util/macros.h" // IWYU pragma: export
+
+#if defined(_WIN32) || defined(__CYGWIN__)
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+// Disable warning for STL types usage in DLL interface
+// https://web.archive.org/web/20130317015847/http://connect.microsoft.com/VisualStudio/feedback/details/696593/vc-10-vs-2010-basic-string-exports
+#pragma warning(disable : 4275 4251)
+// Disable diamond inheritance warnings
+#pragma warning(disable : 4250)
+// Disable macro redefinition warnings
+#pragma warning(disable : 4005)
+// Disable extern before exported template warnings
+#pragma warning(disable : 4910)
+#else
+#pragma GCC diagnostic ignored "-Wattributes"
+#endif
+
+#ifdef PARQUET_STATIC
+#define PARQUET_EXPORT
+#elif defined(PARQUET_EXPORTING)
+#define PARQUET_EXPORT __declspec(dllexport)
+#else
+#define PARQUET_EXPORT __declspec(dllimport)
+#endif
+
+#define PARQUET_NO_EXPORT
+
+#else // Not Windows
+#ifndef PARQUET_EXPORT
+#define PARQUET_EXPORT __attribute__((visibility("default")))
+#endif
+#ifndef PARQUET_NO_EXPORT
+#define PARQUET_NO_EXPORT __attribute__((visibility("hidden")))
+#endif
+#endif // Non-Windows
+
+// This is a complicated topic, some reading on it:
+// http://www.codesynthesis.com/~boris/blog/2010/01/18/dll-export-cxx-templates/
+#if defined(_MSC_VER) || defined(__clang__)
+#define PARQUET_TEMPLATE_CLASS_EXPORT
+#define PARQUET_TEMPLATE_EXPORT PARQUET_EXPORT
+#else
+#define PARQUET_TEMPLATE_CLASS_EXPORT PARQUET_EXPORT
+#define PARQUET_TEMPLATE_EXPORT
+#endif
+
+#define PARQUET_DISALLOW_COPY_AND_ASSIGN ARROW_DISALLOW_COPY_AND_ASSIGN
+
+#define PARQUET_NORETURN ARROW_NORETURN
+#define PARQUET_DEPRECATED ARROW_DEPRECATED
+
+// If ARROW_VALGRIND set when compiling unit tests, also define
+// PARQUET_VALGRIND
+#ifdef ARROW_VALGRIND
+#define PARQUET_VALGRIND
+#endif
+
+namespace parquet {
+
+using Buffer = ::arrow::Buffer;
+using Codec = ::arrow::util::Codec;
+using Compression = ::arrow::Compression;
+using MemoryPool = ::arrow::MemoryPool;
+using MutableBuffer = ::arrow::MutableBuffer;
+using ResizableBuffer = ::arrow::ResizableBuffer;
+using ResizableBuffer = ::arrow::ResizableBuffer;
+using ArrowInputFile = ::arrow::io::RandomAccessFile;
+using ArrowInputStream = ::arrow::io::InputStream;
+using ArrowOutputStream = ::arrow::io::OutputStream;
+
+constexpr int64_t kDefaultOutputStreamSize = 1024;
+
+constexpr int16_t kNonPageOrdinal = static_cast<int16_t>(-1);
+
+PARQUET_EXPORT
+std::shared_ptr<::arrow::io::BufferOutputStream> CreateOutputStream(
+ ::arrow::MemoryPool* pool = ::arrow::default_memory_pool());
+
+PARQUET_EXPORT
+std::shared_ptr<ResizableBuffer> AllocateBuffer(
+ ::arrow::MemoryPool* pool = ::arrow::default_memory_pool(), int64_t size = 0);
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/printer.cc b/src/arrow/cpp/src/parquet/printer.cc
new file mode 100644
index 000000000..dfd4bd802
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/printer.cc
@@ -0,0 +1,297 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/printer.h"
+
+#include <cstdint>
+#include <cstdio>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <vector>
+
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/string.h"
+
+#include "parquet/column_scanner.h"
+#include "parquet/exception.h"
+#include "parquet/file_reader.h"
+#include "parquet/metadata.h"
+#include "parquet/schema.h"
+#include "parquet/statistics.h"
+#include "parquet/types.h"
+
+namespace parquet {
+
+class ColumnReader;
+
+// ----------------------------------------------------------------------
+// ParquetFilePrinter::DebugPrint
+
+// the fixed initial size is just for an example
+#define COL_WIDTH 30
+
+void ParquetFilePrinter::DebugPrint(std::ostream& stream, std::list<int> selected_columns,
+ bool print_values, bool format_dump,
+ bool print_key_value_metadata, const char* filename) {
+ const FileMetaData* file_metadata = fileReader->metadata().get();
+
+ stream << "File Name: " << filename << "\n";
+ stream << "Version: " << ParquetVersionToString(file_metadata->version()) << "\n";
+ stream << "Created By: " << file_metadata->created_by() << "\n";
+ stream << "Total rows: " << file_metadata->num_rows() << "\n";
+
+ if (print_key_value_metadata && file_metadata->key_value_metadata()) {
+ auto key_value_metadata = file_metadata->key_value_metadata();
+ int64_t size_of_key_value_metadata = key_value_metadata->size();
+ stream << "Key Value File Metadata: " << size_of_key_value_metadata << " entries\n";
+ for (int64_t i = 0; i < size_of_key_value_metadata; i++) {
+ stream << " Key nr " << i << " " << key_value_metadata->key(i) << ": "
+ << key_value_metadata->value(i) << "\n";
+ }
+ }
+
+ stream << "Number of RowGroups: " << file_metadata->num_row_groups() << "\n";
+ stream << "Number of Real Columns: "
+ << file_metadata->schema()->group_node()->field_count() << "\n";
+
+ if (selected_columns.size() == 0) {
+ for (int i = 0; i < file_metadata->num_columns(); i++) {
+ selected_columns.push_back(i);
+ }
+ } else {
+ for (auto i : selected_columns) {
+ if (i < 0 || i >= file_metadata->num_columns()) {
+ throw ParquetException("Selected column is out of range");
+ }
+ }
+ }
+
+ stream << "Number of Columns: " << file_metadata->num_columns() << "\n";
+ stream << "Number of Selected Columns: " << selected_columns.size() << "\n";
+ for (auto i : selected_columns) {
+ const ColumnDescriptor* descr = file_metadata->schema()->Column(i);
+ stream << "Column " << i << ": " << descr->path()->ToDotString() << " ("
+ << TypeToString(descr->physical_type());
+ const auto& logical_type = descr->logical_type();
+ if (!logical_type->is_none()) {
+ stream << " / " << logical_type->ToString();
+ }
+ if (descr->converted_type() != ConvertedType::NONE) {
+ stream << " / " << ConvertedTypeToString(descr->converted_type());
+ if (descr->converted_type() == ConvertedType::DECIMAL) {
+ stream << "(" << descr->type_precision() << "," << descr->type_scale() << ")";
+ }
+ }
+ stream << ")" << std::endl;
+ }
+
+ for (int r = 0; r < file_metadata->num_row_groups(); ++r) {
+ stream << "--- Row Group: " << r << " ---\n";
+
+ auto group_reader = fileReader->RowGroup(r);
+ std::unique_ptr<RowGroupMetaData> group_metadata = file_metadata->RowGroup(r);
+
+ stream << "--- Total Bytes: " << group_metadata->total_byte_size() << " ---\n";
+ stream << "--- Total Compressed Bytes: " << group_metadata->total_compressed_size()
+ << " ---\n";
+ stream << "--- Rows: " << group_metadata->num_rows() << " ---\n";
+
+ // Print column metadata
+ for (auto i : selected_columns) {
+ auto column_chunk = group_metadata->ColumnChunk(i);
+ std::shared_ptr<Statistics> stats = column_chunk->statistics();
+
+ const ColumnDescriptor* descr = file_metadata->schema()->Column(i);
+ stream << "Column " << i << std::endl << " Values: " << column_chunk->num_values();
+ if (column_chunk->is_stats_set()) {
+ std::string min = stats->EncodeMin(), max = stats->EncodeMax();
+ stream << ", Null Values: " << stats->null_count()
+ << ", Distinct Values: " << stats->distinct_count() << std::endl
+ << " Max: " << FormatStatValue(descr->physical_type(), max)
+ << ", Min: " << FormatStatValue(descr->physical_type(), min);
+ } else {
+ stream << " Statistics Not Set";
+ }
+ stream << std::endl
+ << " Compression: "
+ << ::arrow::internal::AsciiToUpper(
+ Codec::GetCodecAsString(column_chunk->compression()))
+ << ", Encodings:";
+ for (auto encoding : column_chunk->encodings()) {
+ stream << " " << EncodingToString(encoding);
+ }
+ stream << std::endl
+ << " Uncompressed Size: " << column_chunk->total_uncompressed_size()
+ << ", Compressed Size: " << column_chunk->total_compressed_size()
+ << std::endl;
+ }
+
+ if (!print_values) {
+ continue;
+ }
+ stream << "--- Values ---\n";
+
+ static constexpr int bufsize = COL_WIDTH + 1;
+ char buffer[bufsize];
+
+ // Create readers for selected columns and print contents
+ std::vector<std::shared_ptr<Scanner>> scanners(selected_columns.size(), nullptr);
+ int j = 0;
+ for (auto i : selected_columns) {
+ std::shared_ptr<ColumnReader> col_reader = group_reader->Column(i);
+ // This is OK in this method as long as the RowGroupReader does not get
+ // deleted
+ auto& scanner = scanners[j++] = Scanner::Make(col_reader);
+
+ if (format_dump) {
+ stream << "Column " << i << std::endl;
+ while (scanner->HasNext()) {
+ scanner->PrintNext(stream, 0, true);
+ stream << "\n";
+ }
+ continue;
+ }
+
+ snprintf(buffer, bufsize, "%-*s", COL_WIDTH,
+ file_metadata->schema()->Column(i)->name().c_str());
+ stream << buffer << '|';
+ }
+ if (format_dump) {
+ continue;
+ }
+ stream << "\n";
+
+ bool hasRow;
+ do {
+ hasRow = false;
+ for (auto scanner : scanners) {
+ if (scanner->HasNext()) {
+ hasRow = true;
+ scanner->PrintNext(stream, COL_WIDTH);
+ stream << '|';
+ }
+ }
+ stream << "\n";
+ } while (hasRow);
+ }
+}
+
+void ParquetFilePrinter::JSONPrint(std::ostream& stream, std::list<int> selected_columns,
+ const char* filename) {
+ const FileMetaData* file_metadata = fileReader->metadata().get();
+ stream << "{\n";
+ stream << " \"FileName\": \"" << filename << "\",\n";
+ stream << " \"Version\": \"" << ParquetVersionToString(file_metadata->version())
+ << "\",\n";
+ stream << " \"CreatedBy\": \"" << file_metadata->created_by() << "\",\n";
+ stream << " \"TotalRows\": \"" << file_metadata->num_rows() << "\",\n";
+ stream << " \"NumberOfRowGroups\": \"" << file_metadata->num_row_groups() << "\",\n";
+ stream << " \"NumberOfRealColumns\": \""
+ << file_metadata->schema()->group_node()->field_count() << "\",\n";
+ stream << " \"NumberOfColumns\": \"" << file_metadata->num_columns() << "\",\n";
+
+ if (selected_columns.size() == 0) {
+ for (int i = 0; i < file_metadata->num_columns(); i++) {
+ selected_columns.push_back(i);
+ }
+ } else {
+ for (auto i : selected_columns) {
+ if (i < 0 || i >= file_metadata->num_columns()) {
+ throw ParquetException("Selected column is out of range");
+ }
+ }
+ }
+
+ stream << " \"Columns\": [\n";
+ int c = 0;
+ for (auto i : selected_columns) {
+ const ColumnDescriptor* descr = file_metadata->schema()->Column(i);
+ stream << " { \"Id\": \"" << i << "\","
+ << " \"Name\": \"" << descr->path()->ToDotString() << "\","
+ << " \"PhysicalType\": \"" << TypeToString(descr->physical_type()) << "\","
+ << " \"ConvertedType\": \"" << ConvertedTypeToString(descr->converted_type())
+ << "\","
+ << " \"LogicalType\": " << (descr->logical_type())->ToJSON() << " }";
+ c++;
+ if (c != static_cast<int>(selected_columns.size())) {
+ stream << ",\n";
+ }
+ }
+
+ stream << "\n ],\n \"RowGroups\": [\n";
+ for (int r = 0; r < file_metadata->num_row_groups(); ++r) {
+ stream << " {\n \"Id\": \"" << r << "\", ";
+
+ auto group_reader = fileReader->RowGroup(r);
+ std::unique_ptr<RowGroupMetaData> group_metadata = file_metadata->RowGroup(r);
+
+ stream << " \"TotalBytes\": \"" << group_metadata->total_byte_size() << "\", ";
+ stream << " \"TotalCompressedBytes\": \"" << group_metadata->total_compressed_size()
+ << "\", ";
+ stream << " \"Rows\": \"" << group_metadata->num_rows() << "\",\n";
+
+ // Print column metadata
+ stream << " \"ColumnChunks\": [\n";
+ int c1 = 0;
+ for (auto i : selected_columns) {
+ auto column_chunk = group_metadata->ColumnChunk(i);
+ std::shared_ptr<Statistics> stats = column_chunk->statistics();
+
+ const ColumnDescriptor* descr = file_metadata->schema()->Column(i);
+ stream << " {\"Id\": \"" << i << "\", \"Values\": \""
+ << column_chunk->num_values() << "\", "
+ << "\"StatsSet\": ";
+ if (column_chunk->is_stats_set()) {
+ stream << "\"True\", \"Stats\": {";
+ std::string min = stats->EncodeMin(), max = stats->EncodeMax();
+ stream << "\"NumNulls\": \"" << stats->null_count() << "\", "
+ << "\"DistinctValues\": \"" << stats->distinct_count() << "\", "
+ << "\"Max\": \"" << FormatStatValue(descr->physical_type(), max) << "\", "
+ << "\"Min\": \"" << FormatStatValue(descr->physical_type(), min)
+ << "\" },";
+ } else {
+ stream << "\"False\",";
+ }
+ stream << "\n \"Compression\": \""
+ << ::arrow::internal::AsciiToUpper(
+ Codec::GetCodecAsString(column_chunk->compression()))
+ << "\", \"Encodings\": \"";
+ for (auto encoding : column_chunk->encodings()) {
+ stream << EncodingToString(encoding) << " ";
+ }
+ stream << "\", "
+ << "\"UncompressedSize\": \"" << column_chunk->total_uncompressed_size()
+ << "\", \"CompressedSize\": \"" << column_chunk->total_compressed_size();
+
+ // end of a ColumnChunk
+ stream << "\" }";
+ c1++;
+ if (c1 != static_cast<int>(selected_columns.size())) {
+ stream << ",\n";
+ }
+ }
+
+ stream << "\n ]\n }";
+ if ((r + 1) != static_cast<int>(file_metadata->num_row_groups())) {
+ stream << ",\n";
+ }
+ }
+ stream << "\n ]\n}\n";
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/printer.h b/src/arrow/cpp/src/parquet/printer.h
new file mode 100644
index 000000000..6bdf5b456
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/printer.h
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <iosfwd>
+#include <list>
+
+#include "parquet/platform.h"
+
+namespace parquet {
+
+class ParquetFileReader;
+
+class PARQUET_EXPORT ParquetFilePrinter {
+ private:
+ ParquetFileReader* fileReader;
+
+ public:
+ explicit ParquetFilePrinter(ParquetFileReader* reader) : fileReader(reader) {}
+ ~ParquetFilePrinter() {}
+
+ void DebugPrint(std::ostream& stream, std::list<int> selected_columns,
+ bool print_values = false, bool format_dump = false,
+ bool print_key_value_metadata = false,
+ const char* filename = "No Name");
+
+ void JSONPrint(std::ostream& stream, std::list<int> selected_columns,
+ const char* filename = "No Name");
+};
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/properties.cc b/src/arrow/cpp/src/parquet/properties.cc
new file mode 100644
index 000000000..93638dbe2
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/properties.cc
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <sstream>
+#include <utility>
+
+#include "parquet/properties.h"
+
+#include "arrow/io/buffered.h"
+#include "arrow/io/memory.h"
+#include "arrow/util/logging.h"
+
+namespace parquet {
+
+std::shared_ptr<ArrowInputStream> ReaderProperties::GetStream(
+ std::shared_ptr<ArrowInputFile> source, int64_t start, int64_t num_bytes) {
+ if (buffered_stream_enabled_) {
+ // ARROW-6180 / PARQUET-1636 Create isolated reader that references segment
+ // of source
+ std::shared_ptr<::arrow::io::InputStream> safe_stream =
+ ::arrow::io::RandomAccessFile::GetStream(source, start, num_bytes);
+ PARQUET_ASSIGN_OR_THROW(
+ auto stream, ::arrow::io::BufferedInputStream::Create(buffer_size_, pool_,
+ safe_stream, num_bytes));
+ return std::move(stream);
+ } else {
+ PARQUET_ASSIGN_OR_THROW(auto data, source->ReadAt(start, num_bytes));
+
+ if (data->size() != num_bytes) {
+ std::stringstream ss;
+ ss << "Tried reading " << num_bytes << " bytes starting at position " << start
+ << " from file but only got " << data->size();
+ throw ParquetException(ss.str());
+ }
+ return std::make_shared<::arrow::io::BufferReader>(data);
+ }
+}
+
+ArrowReaderProperties default_arrow_reader_properties() {
+ static ArrowReaderProperties default_reader_props;
+ return default_reader_props;
+}
+
+std::shared_ptr<ArrowWriterProperties> default_arrow_writer_properties() {
+ static std::shared_ptr<ArrowWriterProperties> default_writer_properties =
+ ArrowWriterProperties::Builder().build();
+ return default_writer_properties;
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/properties.h b/src/arrow/cpp/src/parquet/properties.h
new file mode 100644
index 000000000..cd9b87b05
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/properties.h
@@ -0,0 +1,801 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "arrow/io/caching.h"
+#include "arrow/type.h"
+#include "arrow/util/compression.h"
+#include "parquet/encryption/encryption.h"
+#include "parquet/exception.h"
+#include "parquet/parquet_version.h"
+#include "parquet/platform.h"
+#include "parquet/schema.h"
+#include "parquet/type_fwd.h"
+#include "parquet/types.h"
+
+namespace parquet {
+
+/// Controls serialization format of data pages. parquet-format v2.0.0
+/// introduced a new data page metadata type DataPageV2 and serialized page
+/// structure (for example, encoded levels are no longer compressed). Prior to
+/// the completion of PARQUET-457 in 2020, this library did not implement
+/// DataPageV2 correctly, so if you use the V2 data page format, you may have
+/// forward compatibility issues (older versions of the library will be unable
+/// to read the files). Note that some Parquet implementations do not implement
+/// DataPageV2 at all.
+enum class ParquetDataPageVersion { V1, V2 };
+
+/// Align the default buffer size to a small multiple of a page size.
+constexpr int64_t kDefaultBufferSize = 4096 * 4;
+
+class PARQUET_EXPORT ReaderProperties {
+ public:
+ explicit ReaderProperties(MemoryPool* pool = ::arrow::default_memory_pool())
+ : pool_(pool) {}
+
+ MemoryPool* memory_pool() const { return pool_; }
+
+ std::shared_ptr<ArrowInputStream> GetStream(std::shared_ptr<ArrowInputFile> source,
+ int64_t start, int64_t num_bytes);
+
+ /// Buffered stream reading allows the user to control the memory usage of
+ /// parquet readers. This ensure that all `RandomAccessFile::ReadAt` calls are
+ /// wrapped in a buffered reader that uses a fix sized buffer (of size
+ /// `buffer_size()`) instead of the full size of the ReadAt.
+ ///
+ /// The primary reason for this control knobs is for resource control and not
+ /// performance.
+ bool is_buffered_stream_enabled() const { return buffered_stream_enabled_; }
+ void enable_buffered_stream() { buffered_stream_enabled_ = true; }
+ void disable_buffered_stream() { buffered_stream_enabled_ = false; }
+
+ int64_t buffer_size() const { return buffer_size_; }
+ void set_buffer_size(int64_t size) { buffer_size_ = size; }
+
+ void file_decryption_properties(std::shared_ptr<FileDecryptionProperties> decryption) {
+ file_decryption_properties_ = std::move(decryption);
+ }
+
+ const std::shared_ptr<FileDecryptionProperties>& file_decryption_properties() const {
+ return file_decryption_properties_;
+ }
+
+ private:
+ MemoryPool* pool_;
+ int64_t buffer_size_ = kDefaultBufferSize;
+ bool buffered_stream_enabled_ = false;
+ std::shared_ptr<FileDecryptionProperties> file_decryption_properties_;
+};
+
+ReaderProperties PARQUET_EXPORT default_reader_properties();
+
+static constexpr int64_t kDefaultDataPageSize = 1024 * 1024;
+static constexpr bool DEFAULT_IS_DICTIONARY_ENABLED = true;
+static constexpr int64_t DEFAULT_DICTIONARY_PAGE_SIZE_LIMIT = kDefaultDataPageSize;
+static constexpr int64_t DEFAULT_WRITE_BATCH_SIZE = 1024;
+static constexpr int64_t DEFAULT_MAX_ROW_GROUP_LENGTH = 64 * 1024 * 1024;
+static constexpr bool DEFAULT_ARE_STATISTICS_ENABLED = true;
+static constexpr int64_t DEFAULT_MAX_STATISTICS_SIZE = 4096;
+static constexpr Encoding::type DEFAULT_ENCODING = Encoding::PLAIN;
+static const char DEFAULT_CREATED_BY[] = CREATED_BY_VERSION;
+static constexpr Compression::type DEFAULT_COMPRESSION_TYPE = Compression::UNCOMPRESSED;
+
+class PARQUET_EXPORT ColumnProperties {
+ public:
+ ColumnProperties(Encoding::type encoding = DEFAULT_ENCODING,
+ Compression::type codec = DEFAULT_COMPRESSION_TYPE,
+ bool dictionary_enabled = DEFAULT_IS_DICTIONARY_ENABLED,
+ bool statistics_enabled = DEFAULT_ARE_STATISTICS_ENABLED,
+ size_t max_stats_size = DEFAULT_MAX_STATISTICS_SIZE)
+ : encoding_(encoding),
+ codec_(codec),
+ dictionary_enabled_(dictionary_enabled),
+ statistics_enabled_(statistics_enabled),
+ max_stats_size_(max_stats_size),
+ compression_level_(Codec::UseDefaultCompressionLevel()) {}
+
+ void set_encoding(Encoding::type encoding) { encoding_ = encoding; }
+
+ void set_compression(Compression::type codec) { codec_ = codec; }
+
+ void set_dictionary_enabled(bool dictionary_enabled) {
+ dictionary_enabled_ = dictionary_enabled;
+ }
+
+ void set_statistics_enabled(bool statistics_enabled) {
+ statistics_enabled_ = statistics_enabled;
+ }
+
+ void set_max_statistics_size(size_t max_stats_size) {
+ max_stats_size_ = max_stats_size;
+ }
+
+ void set_compression_level(int compression_level) {
+ compression_level_ = compression_level;
+ }
+
+ Encoding::type encoding() const { return encoding_; }
+
+ Compression::type compression() const { return codec_; }
+
+ bool dictionary_enabled() const { return dictionary_enabled_; }
+
+ bool statistics_enabled() const { return statistics_enabled_; }
+
+ size_t max_statistics_size() const { return max_stats_size_; }
+
+ int compression_level() const { return compression_level_; }
+
+ private:
+ Encoding::type encoding_;
+ Compression::type codec_;
+ bool dictionary_enabled_;
+ bool statistics_enabled_;
+ size_t max_stats_size_;
+ int compression_level_;
+};
+
+class PARQUET_EXPORT WriterProperties {
+ public:
+ class Builder {
+ public:
+ Builder()
+ : pool_(::arrow::default_memory_pool()),
+ dictionary_pagesize_limit_(DEFAULT_DICTIONARY_PAGE_SIZE_LIMIT),
+ write_batch_size_(DEFAULT_WRITE_BATCH_SIZE),
+ max_row_group_length_(DEFAULT_MAX_ROW_GROUP_LENGTH),
+ pagesize_(kDefaultDataPageSize),
+ version_(ParquetVersion::PARQUET_1_0),
+ data_page_version_(ParquetDataPageVersion::V1),
+ created_by_(DEFAULT_CREATED_BY) {}
+ virtual ~Builder() {}
+
+ Builder* memory_pool(MemoryPool* pool) {
+ pool_ = pool;
+ return this;
+ }
+
+ Builder* enable_dictionary() {
+ default_column_properties_.set_dictionary_enabled(true);
+ return this;
+ }
+
+ Builder* disable_dictionary() {
+ default_column_properties_.set_dictionary_enabled(false);
+ return this;
+ }
+
+ Builder* enable_dictionary(const std::string& path) {
+ dictionary_enabled_[path] = true;
+ return this;
+ }
+
+ Builder* enable_dictionary(const std::shared_ptr<schema::ColumnPath>& path) {
+ return this->enable_dictionary(path->ToDotString());
+ }
+
+ Builder* disable_dictionary(const std::string& path) {
+ dictionary_enabled_[path] = false;
+ return this;
+ }
+
+ Builder* disable_dictionary(const std::shared_ptr<schema::ColumnPath>& path) {
+ return this->disable_dictionary(path->ToDotString());
+ }
+
+ Builder* dictionary_pagesize_limit(int64_t dictionary_psize_limit) {
+ dictionary_pagesize_limit_ = dictionary_psize_limit;
+ return this;
+ }
+
+ Builder* write_batch_size(int64_t write_batch_size) {
+ write_batch_size_ = write_batch_size;
+ return this;
+ }
+
+ Builder* max_row_group_length(int64_t max_row_group_length) {
+ max_row_group_length_ = max_row_group_length;
+ return this;
+ }
+
+ Builder* data_pagesize(int64_t pg_size) {
+ pagesize_ = pg_size;
+ return this;
+ }
+
+ Builder* data_page_version(ParquetDataPageVersion data_page_version) {
+ data_page_version_ = data_page_version;
+ return this;
+ }
+
+ Builder* version(ParquetVersion::type version) {
+ version_ = version;
+ return this;
+ }
+
+ Builder* created_by(const std::string& created_by) {
+ created_by_ = created_by;
+ return this;
+ }
+
+ /**
+ * Define the encoding that is used when we don't utilise dictionary encoding.
+ *
+ * This either apply if dictionary encoding is disabled or if we fallback
+ * as the dictionary grew too large.
+ */
+ Builder* encoding(Encoding::type encoding_type) {
+ if (encoding_type == Encoding::PLAIN_DICTIONARY ||
+ encoding_type == Encoding::RLE_DICTIONARY) {
+ throw ParquetException("Can't use dictionary encoding as fallback encoding");
+ }
+
+ default_column_properties_.set_encoding(encoding_type);
+ return this;
+ }
+
+ /**
+ * Define the encoding that is used when we don't utilise dictionary encoding.
+ *
+ * This either apply if dictionary encoding is disabled or if we fallback
+ * as the dictionary grew too large.
+ */
+ Builder* encoding(const std::string& path, Encoding::type encoding_type) {
+ if (encoding_type == Encoding::PLAIN_DICTIONARY ||
+ encoding_type == Encoding::RLE_DICTIONARY) {
+ throw ParquetException("Can't use dictionary encoding as fallback encoding");
+ }
+
+ encodings_[path] = encoding_type;
+ return this;
+ }
+
+ /**
+ * Define the encoding that is used when we don't utilise dictionary encoding.
+ *
+ * This either apply if dictionary encoding is disabled or if we fallback
+ * as the dictionary grew too large.
+ */
+ Builder* encoding(const std::shared_ptr<schema::ColumnPath>& path,
+ Encoding::type encoding_type) {
+ return this->encoding(path->ToDotString(), encoding_type);
+ }
+
+ Builder* compression(Compression::type codec) {
+ default_column_properties_.set_compression(codec);
+ return this;
+ }
+
+ Builder* max_statistics_size(size_t max_stats_sz) {
+ default_column_properties_.set_max_statistics_size(max_stats_sz);
+ return this;
+ }
+
+ Builder* compression(const std::string& path, Compression::type codec) {
+ codecs_[path] = codec;
+ return this;
+ }
+
+ Builder* compression(const std::shared_ptr<schema::ColumnPath>& path,
+ Compression::type codec) {
+ return this->compression(path->ToDotString(), codec);
+ }
+
+ /// \brief Specify the default compression level for the compressor in
+ /// every column. In case a column does not have an explicitly specified
+ /// compression level, the default one would be used.
+ ///
+ /// The provided compression level is compressor specific. The user would
+ /// have to familiarize oneself with the available levels for the selected
+ /// compressor. If the compressor does not allow for selecting different
+ /// compression levels, calling this function would not have any effect.
+ /// Parquet and Arrow do not validate the passed compression level. If no
+ /// level is selected by the user or if the special
+ /// std::numeric_limits<int>::min() value is passed, then Arrow selects the
+ /// compression level.
+ Builder* compression_level(int compression_level) {
+ default_column_properties_.set_compression_level(compression_level);
+ return this;
+ }
+
+ /// \brief Specify a compression level for the compressor for the column
+ /// described by path.
+ ///
+ /// The provided compression level is compressor specific. The user would
+ /// have to familiarize oneself with the available levels for the selected
+ /// compressor. If the compressor does not allow for selecting different
+ /// compression levels, calling this function would not have any effect.
+ /// Parquet and Arrow do not validate the passed compression level. If no
+ /// level is selected by the user or if the special
+ /// std::numeric_limits<int>::min() value is passed, then Arrow selects the
+ /// compression level.
+ Builder* compression_level(const std::string& path, int compression_level) {
+ codecs_compression_level_[path] = compression_level;
+ return this;
+ }
+
+ /// \brief Specify a compression level for the compressor for the column
+ /// described by path.
+ ///
+ /// The provided compression level is compressor specific. The user would
+ /// have to familiarize oneself with the available levels for the selected
+ /// compressor. If the compressor does not allow for selecting different
+ /// compression levels, calling this function would not have any effect.
+ /// Parquet and Arrow do not validate the passed compression level. If no
+ /// level is selected by the user or if the special
+ /// std::numeric_limits<int>::min() value is passed, then Arrow selects the
+ /// compression level.
+ Builder* compression_level(const std::shared_ptr<schema::ColumnPath>& path,
+ int compression_level) {
+ return this->compression_level(path->ToDotString(), compression_level);
+ }
+
+ Builder* encryption(
+ std::shared_ptr<FileEncryptionProperties> file_encryption_properties) {
+ file_encryption_properties_ = std::move(file_encryption_properties);
+ return this;
+ }
+
+ Builder* enable_statistics() {
+ default_column_properties_.set_statistics_enabled(true);
+ return this;
+ }
+
+ Builder* disable_statistics() {
+ default_column_properties_.set_statistics_enabled(false);
+ return this;
+ }
+
+ Builder* enable_statistics(const std::string& path) {
+ statistics_enabled_[path] = true;
+ return this;
+ }
+
+ Builder* enable_statistics(const std::shared_ptr<schema::ColumnPath>& path) {
+ return this->enable_statistics(path->ToDotString());
+ }
+
+ Builder* disable_statistics(const std::string& path) {
+ statistics_enabled_[path] = false;
+ return this;
+ }
+
+ Builder* disable_statistics(const std::shared_ptr<schema::ColumnPath>& path) {
+ return this->disable_statistics(path->ToDotString());
+ }
+
+ std::shared_ptr<WriterProperties> build() {
+ std::unordered_map<std::string, ColumnProperties> column_properties;
+ auto get = [&](const std::string& key) -> ColumnProperties& {
+ auto it = column_properties.find(key);
+ if (it == column_properties.end())
+ return column_properties[key] = default_column_properties_;
+ else
+ return it->second;
+ };
+
+ for (const auto& item : encodings_) get(item.first).set_encoding(item.second);
+ for (const auto& item : codecs_) get(item.first).set_compression(item.second);
+ for (const auto& item : codecs_compression_level_)
+ get(item.first).set_compression_level(item.second);
+ for (const auto& item : dictionary_enabled_)
+ get(item.first).set_dictionary_enabled(item.second);
+ for (const auto& item : statistics_enabled_)
+ get(item.first).set_statistics_enabled(item.second);
+
+ return std::shared_ptr<WriterProperties>(new WriterProperties(
+ pool_, dictionary_pagesize_limit_, write_batch_size_, max_row_group_length_,
+ pagesize_, version_, created_by_, std::move(file_encryption_properties_),
+ default_column_properties_, column_properties, data_page_version_));
+ }
+
+ private:
+ MemoryPool* pool_;
+ int64_t dictionary_pagesize_limit_;
+ int64_t write_batch_size_;
+ int64_t max_row_group_length_;
+ int64_t pagesize_;
+ ParquetVersion::type version_;
+ ParquetDataPageVersion data_page_version_;
+ std::string created_by_;
+
+ std::shared_ptr<FileEncryptionProperties> file_encryption_properties_;
+
+ // Settings used for each column unless overridden in any of the maps below
+ ColumnProperties default_column_properties_;
+ std::unordered_map<std::string, Encoding::type> encodings_;
+ std::unordered_map<std::string, Compression::type> codecs_;
+ std::unordered_map<std::string, int32_t> codecs_compression_level_;
+ std::unordered_map<std::string, bool> dictionary_enabled_;
+ std::unordered_map<std::string, bool> statistics_enabled_;
+ };
+
+ inline MemoryPool* memory_pool() const { return pool_; }
+
+ inline int64_t dictionary_pagesize_limit() const { return dictionary_pagesize_limit_; }
+
+ inline int64_t write_batch_size() const { return write_batch_size_; }
+
+ inline int64_t max_row_group_length() const { return max_row_group_length_; }
+
+ inline int64_t data_pagesize() const { return pagesize_; }
+
+ inline ParquetDataPageVersion data_page_version() const {
+ return parquet_data_page_version_;
+ }
+
+ inline ParquetVersion::type version() const { return parquet_version_; }
+
+ inline std::string created_by() const { return parquet_created_by_; }
+
+ inline Encoding::type dictionary_index_encoding() const {
+ if (parquet_version_ == ParquetVersion::PARQUET_1_0) {
+ return Encoding::PLAIN_DICTIONARY;
+ } else {
+ return Encoding::RLE_DICTIONARY;
+ }
+ }
+
+ inline Encoding::type dictionary_page_encoding() const {
+ if (parquet_version_ == ParquetVersion::PARQUET_1_0) {
+ return Encoding::PLAIN_DICTIONARY;
+ } else {
+ return Encoding::PLAIN;
+ }
+ }
+
+ const ColumnProperties& column_properties(
+ const std::shared_ptr<schema::ColumnPath>& path) const {
+ auto it = column_properties_.find(path->ToDotString());
+ if (it != column_properties_.end()) return it->second;
+ return default_column_properties_;
+ }
+
+ Encoding::type encoding(const std::shared_ptr<schema::ColumnPath>& path) const {
+ return column_properties(path).encoding();
+ }
+
+ Compression::type compression(const std::shared_ptr<schema::ColumnPath>& path) const {
+ return column_properties(path).compression();
+ }
+
+ int compression_level(const std::shared_ptr<schema::ColumnPath>& path) const {
+ return column_properties(path).compression_level();
+ }
+
+ bool dictionary_enabled(const std::shared_ptr<schema::ColumnPath>& path) const {
+ return column_properties(path).dictionary_enabled();
+ }
+
+ bool statistics_enabled(const std::shared_ptr<schema::ColumnPath>& path) const {
+ return column_properties(path).statistics_enabled();
+ }
+
+ size_t max_statistics_size(const std::shared_ptr<schema::ColumnPath>& path) const {
+ return column_properties(path).max_statistics_size();
+ }
+
+ inline FileEncryptionProperties* file_encryption_properties() const {
+ return file_encryption_properties_.get();
+ }
+
+ std::shared_ptr<ColumnEncryptionProperties> column_encryption_properties(
+ const std::string& path) const {
+ if (file_encryption_properties_) {
+ return file_encryption_properties_->column_encryption_properties(path);
+ } else {
+ return NULLPTR;
+ }
+ }
+
+ private:
+ explicit WriterProperties(
+ MemoryPool* pool, int64_t dictionary_pagesize_limit, int64_t write_batch_size,
+ int64_t max_row_group_length, int64_t pagesize, ParquetVersion::type version,
+ const std::string& created_by,
+ std::shared_ptr<FileEncryptionProperties> file_encryption_properties,
+ const ColumnProperties& default_column_properties,
+ const std::unordered_map<std::string, ColumnProperties>& column_properties,
+ ParquetDataPageVersion data_page_version)
+ : pool_(pool),
+ dictionary_pagesize_limit_(dictionary_pagesize_limit),
+ write_batch_size_(write_batch_size),
+ max_row_group_length_(max_row_group_length),
+ pagesize_(pagesize),
+ parquet_data_page_version_(data_page_version),
+ parquet_version_(version),
+ parquet_created_by_(created_by),
+ file_encryption_properties_(file_encryption_properties),
+ default_column_properties_(default_column_properties),
+ column_properties_(column_properties) {}
+
+ MemoryPool* pool_;
+ int64_t dictionary_pagesize_limit_;
+ int64_t write_batch_size_;
+ int64_t max_row_group_length_;
+ int64_t pagesize_;
+ ParquetDataPageVersion parquet_data_page_version_;
+ ParquetVersion::type parquet_version_;
+ std::string parquet_created_by_;
+
+ std::shared_ptr<FileEncryptionProperties> file_encryption_properties_;
+
+ ColumnProperties default_column_properties_;
+ std::unordered_map<std::string, ColumnProperties> column_properties_;
+};
+
+PARQUET_EXPORT const std::shared_ptr<WriterProperties>& default_writer_properties();
+
+// ----------------------------------------------------------------------
+// Properties specific to Apache Arrow columnar read and write
+
+static constexpr bool kArrowDefaultUseThreads = false;
+
+// Default number of rows to read when using ::arrow::RecordBatchReader
+static constexpr int64_t kArrowDefaultBatchSize = 64 * 1024;
+
+/// EXPERIMENTAL: Properties for configuring FileReader behavior.
+class PARQUET_EXPORT ArrowReaderProperties {
+ public:
+ explicit ArrowReaderProperties(bool use_threads = kArrowDefaultUseThreads)
+ : use_threads_(use_threads),
+ read_dict_indices_(),
+ batch_size_(kArrowDefaultBatchSize),
+ pre_buffer_(false),
+ cache_options_(::arrow::io::CacheOptions::Defaults()),
+ coerce_int96_timestamp_unit_(::arrow::TimeUnit::NANO) {}
+
+ void set_use_threads(bool use_threads) { use_threads_ = use_threads; }
+
+ bool use_threads() const { return use_threads_; }
+
+ void set_read_dictionary(int column_index, bool read_dict) {
+ if (read_dict) {
+ read_dict_indices_.insert(column_index);
+ } else {
+ read_dict_indices_.erase(column_index);
+ }
+ }
+ bool read_dictionary(int column_index) const {
+ if (read_dict_indices_.find(column_index) != read_dict_indices_.end()) {
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ void set_batch_size(int64_t batch_size) { batch_size_ = batch_size; }
+
+ int64_t batch_size() const { return batch_size_; }
+
+ /// Enable read coalescing.
+ ///
+ /// When enabled, the Arrow reader will pre-buffer necessary regions
+ /// of the file in-memory. This is intended to improve performance on
+ /// high-latency filesystems (e.g. Amazon S3).
+ void set_pre_buffer(bool pre_buffer) { pre_buffer_ = pre_buffer; }
+
+ bool pre_buffer() const { return pre_buffer_; }
+
+ /// Set options for read coalescing. This can be used to tune the
+ /// implementation for characteristics of different filesystems.
+ void set_cache_options(::arrow::io::CacheOptions options) { cache_options_ = options; }
+
+ const ::arrow::io::CacheOptions& cache_options() const { return cache_options_; }
+
+ /// Set execution context for read coalescing.
+ void set_io_context(const ::arrow::io::IOContext& ctx) { io_context_ = ctx; }
+
+ const ::arrow::io::IOContext& io_context() const { return io_context_; }
+
+ /// Set timestamp unit to use for deprecated INT96-encoded timestamps
+ /// (default is NANO).
+ void set_coerce_int96_timestamp_unit(::arrow::TimeUnit::type unit) {
+ coerce_int96_timestamp_unit_ = unit;
+ }
+
+ ::arrow::TimeUnit::type coerce_int96_timestamp_unit() const {
+ return coerce_int96_timestamp_unit_;
+ }
+
+ private:
+ bool use_threads_;
+ std::unordered_set<int> read_dict_indices_;
+ int64_t batch_size_;
+ bool pre_buffer_;
+ ::arrow::io::IOContext io_context_;
+ ::arrow::io::CacheOptions cache_options_;
+ ::arrow::TimeUnit::type coerce_int96_timestamp_unit_;
+};
+
+/// EXPERIMENTAL: Constructs the default ArrowReaderProperties
+PARQUET_EXPORT
+ArrowReaderProperties default_arrow_reader_properties();
+
+class PARQUET_EXPORT ArrowWriterProperties {
+ public:
+ enum EngineVersion {
+ V1, // Supports only nested lists.
+ V2 // Full support for all nesting combinations
+ };
+ class Builder {
+ public:
+ Builder()
+ : write_timestamps_as_int96_(false),
+ coerce_timestamps_enabled_(false),
+ coerce_timestamps_unit_(::arrow::TimeUnit::SECOND),
+ truncated_timestamps_allowed_(false),
+ store_schema_(false),
+ // TODO: At some point we should flip this.
+ compliant_nested_types_(false),
+ engine_version_(V2) {}
+ virtual ~Builder() = default;
+
+ Builder* disable_deprecated_int96_timestamps() {
+ write_timestamps_as_int96_ = false;
+ return this;
+ }
+
+ Builder* enable_deprecated_int96_timestamps() {
+ write_timestamps_as_int96_ = true;
+ return this;
+ }
+
+ Builder* coerce_timestamps(::arrow::TimeUnit::type unit) {
+ coerce_timestamps_enabled_ = true;
+ coerce_timestamps_unit_ = unit;
+ return this;
+ }
+
+ Builder* allow_truncated_timestamps() {
+ truncated_timestamps_allowed_ = true;
+ return this;
+ }
+
+ Builder* disallow_truncated_timestamps() {
+ truncated_timestamps_allowed_ = false;
+ return this;
+ }
+
+ /// \brief EXPERIMENTAL: Write binary serialized Arrow schema to the file,
+ /// to enable certain read options (like "read_dictionary") to be set
+ /// automatically
+ Builder* store_schema() {
+ store_schema_ = true;
+ return this;
+ }
+
+ Builder* enable_compliant_nested_types() {
+ compliant_nested_types_ = true;
+ return this;
+ }
+
+ Builder* disable_compliant_nested_types() {
+ compliant_nested_types_ = false;
+ return this;
+ }
+
+ Builder* set_engine_version(EngineVersion version) {
+ engine_version_ = version;
+ return this;
+ }
+
+ std::shared_ptr<ArrowWriterProperties> build() {
+ return std::shared_ptr<ArrowWriterProperties>(new ArrowWriterProperties(
+ write_timestamps_as_int96_, coerce_timestamps_enabled_, coerce_timestamps_unit_,
+ truncated_timestamps_allowed_, store_schema_, compliant_nested_types_,
+ engine_version_));
+ }
+
+ private:
+ bool write_timestamps_as_int96_;
+
+ bool coerce_timestamps_enabled_;
+ ::arrow::TimeUnit::type coerce_timestamps_unit_;
+ bool truncated_timestamps_allowed_;
+
+ bool store_schema_;
+ bool compliant_nested_types_;
+ EngineVersion engine_version_;
+ };
+
+ bool support_deprecated_int96_timestamps() const { return write_timestamps_as_int96_; }
+
+ bool coerce_timestamps_enabled() const { return coerce_timestamps_enabled_; }
+ ::arrow::TimeUnit::type coerce_timestamps_unit() const {
+ return coerce_timestamps_unit_;
+ }
+
+ bool truncated_timestamps_allowed() const { return truncated_timestamps_allowed_; }
+
+ bool store_schema() const { return store_schema_; }
+
+ /// \brief Enable nested type naming according to the parquet specification.
+ ///
+ /// Older versions of arrow wrote out field names for nested lists based on the name
+ /// of the field. According to the parquet specification they should always be
+ /// "element".
+ bool compliant_nested_types() const { return compliant_nested_types_; }
+
+ /// \brief The underlying engine version to use when writing Arrow data.
+ ///
+ /// V2 is currently the latest V1 is considered deprecated but left in
+ /// place in case there are bugs detected in V2.
+ EngineVersion engine_version() const { return engine_version_; }
+
+ private:
+ explicit ArrowWriterProperties(bool write_nanos_as_int96,
+ bool coerce_timestamps_enabled,
+ ::arrow::TimeUnit::type coerce_timestamps_unit,
+ bool truncated_timestamps_allowed, bool store_schema,
+ bool compliant_nested_types,
+ EngineVersion engine_version)
+ : write_timestamps_as_int96_(write_nanos_as_int96),
+ coerce_timestamps_enabled_(coerce_timestamps_enabled),
+ coerce_timestamps_unit_(coerce_timestamps_unit),
+ truncated_timestamps_allowed_(truncated_timestamps_allowed),
+ store_schema_(store_schema),
+ compliant_nested_types_(compliant_nested_types),
+ engine_version_(engine_version) {}
+
+ const bool write_timestamps_as_int96_;
+ const bool coerce_timestamps_enabled_;
+ const ::arrow::TimeUnit::type coerce_timestamps_unit_;
+ const bool truncated_timestamps_allowed_;
+ const bool store_schema_;
+ const bool compliant_nested_types_;
+ const EngineVersion engine_version_;
+};
+
+/// \brief State object used for writing Arrow data directly to a Parquet
+/// column chunk. API possibly not stable
+struct ArrowWriteContext {
+ ArrowWriteContext(MemoryPool* memory_pool, ArrowWriterProperties* properties)
+ : memory_pool(memory_pool),
+ properties(properties),
+ data_buffer(AllocateBuffer(memory_pool)),
+ def_levels_buffer(AllocateBuffer(memory_pool)) {}
+
+ template <typename T>
+ ::arrow::Status GetScratchData(const int64_t num_values, T** out) {
+ ARROW_RETURN_NOT_OK(this->data_buffer->Resize(num_values * sizeof(T), false));
+ *out = reinterpret_cast<T*>(this->data_buffer->mutable_data());
+ return ::arrow::Status::OK();
+ }
+
+ MemoryPool* memory_pool;
+ const ArrowWriterProperties* properties;
+
+ // Buffer used for storing the data of an array converted to the physical type
+ // as expected by parquet-cpp.
+ std::shared_ptr<ResizableBuffer> data_buffer;
+
+ // We use the shared ownership of this buffer
+ std::shared_ptr<ResizableBuffer> def_levels_buffer;
+};
+
+PARQUET_EXPORT
+std::shared_ptr<ArrowWriterProperties> default_arrow_writer_properties();
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/properties_test.cc b/src/arrow/cpp/src/parquet/properties_test.cc
new file mode 100644
index 000000000..7ce96e4a7
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/properties_test.cc
@@ -0,0 +1,90 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <string>
+
+#include "arrow/buffer.h"
+#include "arrow/io/memory.h"
+
+#include "parquet/file_reader.h"
+#include "parquet/properties.h"
+
+namespace parquet {
+
+using schema::ColumnPath;
+
+namespace test {
+
+TEST(TestReaderProperties, Basics) {
+ ReaderProperties props;
+
+ ASSERT_EQ(props.buffer_size(), kDefaultBufferSize);
+ ASSERT_FALSE(props.is_buffered_stream_enabled());
+}
+
+TEST(TestWriterProperties, Basics) {
+ std::shared_ptr<WriterProperties> props = WriterProperties::Builder().build();
+
+ ASSERT_EQ(kDefaultDataPageSize, props->data_pagesize());
+ ASSERT_EQ(DEFAULT_DICTIONARY_PAGE_SIZE_LIMIT, props->dictionary_pagesize_limit());
+ ASSERT_EQ(ParquetVersion::PARQUET_1_0, props->version());
+ ASSERT_EQ(ParquetDataPageVersion::V1, props->data_page_version());
+}
+
+TEST(TestWriterProperties, AdvancedHandling) {
+ WriterProperties::Builder builder;
+ builder.compression("gzip", Compression::GZIP);
+ builder.compression("zstd", Compression::ZSTD);
+ builder.compression(Compression::SNAPPY);
+ builder.encoding(Encoding::DELTA_BINARY_PACKED);
+ builder.encoding("delta-length", Encoding::DELTA_LENGTH_BYTE_ARRAY);
+ builder.data_page_version(ParquetDataPageVersion::V2);
+ std::shared_ptr<WriterProperties> props = builder.build();
+
+ ASSERT_EQ(Compression::GZIP, props->compression(ColumnPath::FromDotString("gzip")));
+ ASSERT_EQ(Compression::ZSTD, props->compression(ColumnPath::FromDotString("zstd")));
+ ASSERT_EQ(Compression::SNAPPY,
+ props->compression(ColumnPath::FromDotString("delta-length")));
+ ASSERT_EQ(Encoding::DELTA_BINARY_PACKED,
+ props->encoding(ColumnPath::FromDotString("gzip")));
+ ASSERT_EQ(Encoding::DELTA_LENGTH_BYTE_ARRAY,
+ props->encoding(ColumnPath::FromDotString("delta-length")));
+ ASSERT_EQ(ParquetDataPageVersion::V2, props->data_page_version());
+}
+
+TEST(TestReaderProperties, GetStreamInsufficientData) {
+ // ARROW-6058
+ std::string data = "shorter than expected";
+ auto buf = std::make_shared<Buffer>(data);
+ auto reader = std::make_shared<::arrow::io::BufferReader>(buf);
+
+ ReaderProperties props;
+ try {
+ ARROW_UNUSED(props.GetStream(reader, 12, 15));
+ FAIL() << "No exception raised";
+ } catch (const ParquetException& e) {
+ std::string ex_what =
+ ("Tried reading 15 bytes starting at position 12"
+ " from file but only got 9");
+ ASSERT_EQ(ex_what, e.what());
+ }
+}
+
+} // namespace test
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/public_api_test.cc b/src/arrow/cpp/src/parquet/public_api_test.cc
new file mode 100644
index 000000000..c0ef97a70
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/public_api_test.cc
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include "parquet/api/io.h" // IWYU pragma: keep
+#include "parquet/api/reader.h" // IWYU pragma: keep
+#include "parquet/api/schema.h" // IWYU pragma: keep
+#include "parquet/api/writer.h" // IWYU pragma: keep
+
+TEST(TestPublicAPI, DoesNotIncludeThrift) {
+#ifdef _THRIFT_THRIFT_H_
+ FAIL() << "Thrift headers should not be in the public API";
+#endif
+}
+
+TEST(TestPublicAPI, DoesNotExportDCHECK) {
+#ifdef DCHECK
+ FAIL() << "parquet/util/logging.h should not be transitively included";
+#endif
+}
+
+TEST(TestPublicAPI, DoesNotIncludeZlib) {
+#ifdef ZLIB_H
+ FAIL() << "zlib.h should not be transitively included";
+#endif
+}
+
+PARQUET_NORETURN void ThrowsParquetException() {
+ throw parquet::ParquetException("This function throws");
+}
+
+TEST(TestPublicAPI, CanThrowParquetException) {
+ ASSERT_THROW(ThrowsParquetException(), parquet::ParquetException);
+}
diff --git a/src/arrow/cpp/src/parquet/reader_test.cc b/src/arrow/cpp/src/parquet/reader_test.cc
new file mode 100644
index 000000000..2d13266df
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/reader_test.cc
@@ -0,0 +1,810 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <fcntl.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <cstdint>
+#include <cstdlib>
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/io/file.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/make_unique.h"
+
+#include "parquet/column_reader.h"
+#include "parquet/column_scanner.h"
+#include "parquet/file_reader.h"
+#include "parquet/file_writer.h"
+#include "parquet/metadata.h"
+#include "parquet/platform.h"
+#include "parquet/printer.h"
+#include "parquet/test_util.h"
+
+using arrow::internal::checked_pointer_cast;
+
+namespace parquet {
+using schema::GroupNode;
+using schema::PrimitiveNode;
+
+using ReadableFile = ::arrow::io::ReadableFile;
+
+std::string data_file(const char* file) {
+ std::string dir_string(test::get_data_dir());
+ std::stringstream ss;
+ ss << dir_string << "/" << file;
+ return ss.str();
+}
+
+std::string alltypes_plain() { return data_file("alltypes_plain.parquet"); }
+
+std::string nation_dict_truncated_data_page() {
+ return data_file("nation.dict-malformed.parquet");
+}
+
+// LZ4-compressed data files.
+// These files come in three flavours:
+// - legacy "LZ4" compression type, actually compressed with block LZ4 codec
+// (as emitted by some earlier versions of parquet-cpp)
+// - legacy "LZ4" compression type, actually compressed with custom Hadoop LZ4 codec
+// (as emitted by parquet-mr)
+// - "LZ4_RAW" compression type (added in Parquet format version 2.9.0)
+
+std::string hadoop_lz4_compressed() { return data_file("hadoop_lz4_compressed.parquet"); }
+
+std::string hadoop_lz4_compressed_larger() {
+ return data_file("hadoop_lz4_compressed_larger.parquet");
+}
+
+std::string non_hadoop_lz4_compressed() {
+ return data_file("non_hadoop_lz4_compressed.parquet");
+}
+
+std::string lz4_raw_compressed() { return data_file("lz4_raw_compressed.parquet"); }
+
+std::string lz4_raw_compressed_larger() {
+ return data_file("lz4_raw_compressed_larger.parquet");
+}
+
+// TODO: Assert on definition and repetition levels
+template <typename DType, typename ValueType>
+void AssertColumnValues(std::shared_ptr<TypedColumnReader<DType>> col, int64_t batch_size,
+ int64_t expected_levels_read,
+ std::vector<ValueType>& expected_values,
+ int64_t expected_values_read) {
+ std::vector<ValueType> values(batch_size);
+ int64_t values_read;
+
+ auto levels_read =
+ col->ReadBatch(batch_size, nullptr, nullptr, values.data(), &values_read);
+ ASSERT_EQ(expected_levels_read, levels_read);
+
+ ASSERT_EQ(expected_values, values);
+ ASSERT_EQ(expected_values_read, values_read);
+}
+
+void CheckRowGroupMetadata(const RowGroupMetaData* rg_metadata,
+ bool allow_uncompressed_mismatch = false) {
+ const int64_t total_byte_size = rg_metadata->total_byte_size();
+ const int64_t total_compressed_size = rg_metadata->total_compressed_size();
+
+ ASSERT_GE(total_byte_size, 0);
+ ASSERT_GE(total_compressed_size, 0);
+
+ int64_t total_column_byte_size = 0;
+ int64_t total_column_compressed_size = 0;
+ for (int i = 0; i < rg_metadata->num_columns(); ++i) {
+ total_column_byte_size += rg_metadata->ColumnChunk(i)->total_uncompressed_size();
+ total_column_compressed_size += rg_metadata->ColumnChunk(i)->total_compressed_size();
+ }
+
+ if (!allow_uncompressed_mismatch) {
+ ASSERT_EQ(total_byte_size, total_column_byte_size);
+ }
+ if (total_compressed_size != 0) {
+ ASSERT_EQ(total_compressed_size, total_column_compressed_size);
+ }
+}
+
+class TestAllTypesPlain : public ::testing::Test {
+ public:
+ void SetUp() { reader_ = ParquetFileReader::OpenFile(alltypes_plain()); }
+
+ void TearDown() {}
+
+ protected:
+ std::unique_ptr<ParquetFileReader> reader_;
+};
+
+TEST_F(TestAllTypesPlain, NoopConstructDestruct) {}
+
+TEST_F(TestAllTypesPlain, RowGroupMetaData) {
+ auto group = reader_->RowGroup(0);
+ CheckRowGroupMetadata(group->metadata());
+}
+
+TEST_F(TestAllTypesPlain, TestBatchRead) {
+ std::shared_ptr<RowGroupReader> group = reader_->RowGroup(0);
+
+ // column 0, id
+ std::shared_ptr<Int32Reader> col =
+ std::dynamic_pointer_cast<Int32Reader>(group->Column(0));
+
+ int16_t def_levels[4];
+ int16_t rep_levels[4];
+ int32_t values[4];
+
+ // This file only has 8 rows
+ ASSERT_EQ(8, reader_->metadata()->num_rows());
+ // This file only has 1 row group
+ ASSERT_EQ(1, reader_->metadata()->num_row_groups());
+ // Size of the metadata is 730 bytes
+ ASSERT_EQ(730, reader_->metadata()->size());
+ // This row group must have 8 rows
+ ASSERT_EQ(8, group->metadata()->num_rows());
+
+ ASSERT_TRUE(col->HasNext());
+ int64_t values_read;
+ auto levels_read = col->ReadBatch(4, def_levels, rep_levels, values, &values_read);
+ ASSERT_EQ(4, levels_read);
+ ASSERT_EQ(4, values_read);
+
+ // Now read past the end of the file
+ ASSERT_TRUE(col->HasNext());
+ levels_read = col->ReadBatch(5, def_levels, rep_levels, values, &values_read);
+ ASSERT_EQ(4, levels_read);
+ ASSERT_EQ(4, values_read);
+
+ ASSERT_FALSE(col->HasNext());
+}
+
+TEST_F(TestAllTypesPlain, RowGroupColumnBoundchecking) {
+ // Part of PARQUET-1857
+ ASSERT_THROW(reader_->RowGroup(reader_->metadata()->num_row_groups()),
+ ParquetException);
+
+ auto row_group = reader_->RowGroup(0);
+ ASSERT_THROW(row_group->Column(row_group->metadata()->num_columns()), ParquetException);
+ ASSERT_THROW(row_group->GetColumnPageReader(row_group->metadata()->num_columns()),
+ ParquetException);
+}
+
+TEST_F(TestAllTypesPlain, TestFlatScannerInt32) {
+ std::shared_ptr<RowGroupReader> group = reader_->RowGroup(0);
+
+ // column 0, id
+ std::shared_ptr<Int32Scanner> scanner(new Int32Scanner(group->Column(0)));
+ int32_t val;
+ bool is_null;
+ for (int i = 0; i < 8; ++i) {
+ ASSERT_TRUE(scanner->HasNext());
+ ASSERT_TRUE(scanner->NextValue(&val, &is_null));
+ ASSERT_FALSE(is_null);
+ }
+ ASSERT_FALSE(scanner->HasNext());
+ ASSERT_FALSE(scanner->NextValue(&val, &is_null));
+}
+
+TEST_F(TestAllTypesPlain, TestSetScannerBatchSize) {
+ std::shared_ptr<RowGroupReader> group = reader_->RowGroup(0);
+
+ // column 0, id
+ std::shared_ptr<Int32Scanner> scanner(new Int32Scanner(group->Column(0)));
+
+ ASSERT_EQ(128, scanner->batch_size());
+ scanner->SetBatchSize(1024);
+ ASSERT_EQ(1024, scanner->batch_size());
+}
+
+TEST_F(TestAllTypesPlain, DebugPrintWorks) {
+ std::stringstream ss;
+
+ std::list<int> columns;
+ ParquetFilePrinter printer(reader_.get());
+ printer.DebugPrint(ss, columns);
+
+ std::string result = ss.str();
+ ASSERT_GT(result.size(), 0);
+}
+
+TEST_F(TestAllTypesPlain, ColumnSelection) {
+ std::stringstream ss;
+
+ std::list<int> columns;
+ columns.push_back(5);
+ columns.push_back(0);
+ columns.push_back(10);
+ ParquetFilePrinter printer(reader_.get());
+ printer.DebugPrint(ss, columns);
+
+ std::string result = ss.str();
+ ASSERT_GT(result.size(), 0);
+}
+
+TEST_F(TestAllTypesPlain, ColumnSelectionOutOfRange) {
+ std::stringstream ss;
+
+ std::list<int> columns;
+ columns.push_back(100);
+ ParquetFilePrinter printer1(reader_.get());
+ ASSERT_THROW(printer1.DebugPrint(ss, columns), ParquetException);
+
+ columns.clear();
+ columns.push_back(-1);
+ ParquetFilePrinter printer2(reader_.get());
+ ASSERT_THROW(printer2.DebugPrint(ss, columns), ParquetException);
+}
+
+class TestLocalFile : public ::testing::Test {
+ public:
+ void SetUp() {
+ std::string dir_string(test::get_data_dir());
+
+ std::stringstream ss;
+ ss << dir_string << "/"
+ << "alltypes_plain.parquet";
+
+ PARQUET_ASSIGN_OR_THROW(handle, ReadableFile::Open(ss.str()));
+ fileno = handle->file_descriptor();
+ }
+
+ void TearDown() {}
+
+ protected:
+ int fileno;
+ std::shared_ptr<::arrow::io::ReadableFile> handle;
+};
+
+TEST_F(TestLocalFile, OpenWithMetadata) {
+ // PARQUET-808
+ std::stringstream ss;
+ std::shared_ptr<FileMetaData> metadata = ReadMetaData(handle);
+
+ auto reader = ParquetFileReader::Open(handle, default_reader_properties(), metadata);
+
+ // Compare pointers
+ ASSERT_EQ(metadata.get(), reader->metadata().get());
+
+ std::list<int> columns;
+ ParquetFilePrinter printer(reader.get());
+ printer.DebugPrint(ss, columns, true);
+
+ // Make sure OpenFile passes on the external metadata, too
+ auto reader2 = ParquetFileReader::OpenFile(alltypes_plain(), false,
+ default_reader_properties(), metadata);
+
+ // Compare pointers
+ ASSERT_EQ(metadata.get(), reader2->metadata().get());
+}
+
+TEST(TestFileReaderAdHoc, NationDictTruncatedDataPage) {
+ // PARQUET-816. Some files generated by older Parquet implementations may
+ // contain malformed data page metadata, and we can successfully decode them
+ // if we optimistically proceed to decoding, even if there is not enough data
+ // available in the stream. Before, we had quite aggressive checking of
+ // stream reads, which are not found e.g. in Impala's Parquet implementation
+ auto reader = ParquetFileReader::OpenFile(nation_dict_truncated_data_page(), false);
+ std::stringstream ss;
+
+ // empty list means print all
+ std::list<int> columns;
+ ParquetFilePrinter printer1(reader.get());
+ printer1.DebugPrint(ss, columns, true);
+
+ reader = ParquetFileReader::OpenFile(nation_dict_truncated_data_page(), true);
+ std::stringstream ss2;
+ ParquetFilePrinter printer2(reader.get());
+ printer2.DebugPrint(ss2, columns, true);
+
+ // The memory-mapped reads runs over the end of the column chunk and succeeds
+ // by accident
+ ASSERT_EQ(ss2.str(), ss.str());
+}
+
+TEST(TestDumpWithLocalFile, DumpOutput) {
+ std::string header_output = R"###(File Name: nested_lists.snappy.parquet
+Version: 1.0
+Created By: parquet-mr version 1.8.2 (build c6522788629e590a53eb79874b95f6c3ff11f16c)
+Total rows: 3
+Number of RowGroups: 1
+Number of Real Columns: 2
+Number of Columns: 2
+Number of Selected Columns: 2
+Column 0: a.list.element.list.element.list.element (BYTE_ARRAY / String / UTF8)
+Column 1: b (INT32)
+--- Row Group: 0 ---
+--- Total Bytes: 155 ---
+--- Total Compressed Bytes: 0 ---
+--- Rows: 3 ---
+Column 0
+ Values: 18 Statistics Not Set
+ Compression: SNAPPY, Encodings: RLE PLAIN_DICTIONARY
+ Uncompressed Size: 103, Compressed Size: 104
+Column 1
+ Values: 3, Null Values: 0, Distinct Values: 0
+ Max: 1, Min: 1
+ Compression: SNAPPY, Encodings: BIT_PACKED PLAIN_DICTIONARY
+ Uncompressed Size: 52, Compressed Size: 56
+)###";
+ std::string values_output = R"###(--- Values ---
+element |b |
+a |1 |
+b |1 |
+c |1 |
+NULL |
+d |
+a |
+b |
+c |
+d |
+NULL |
+e |
+a |
+b |
+c |
+d |
+e |
+NULL |
+f |
+
+)###";
+ std::string dump_output = R"###(--- Values ---
+Column 0
+ D:7 R:0 V:a
+ D:7 R:3 V:b
+ D:7 R:2 V:c
+ D:4 R:1 NULL
+ D:7 R:2 V:d
+ D:7 R:0 V:a
+ D:7 R:3 V:b
+ D:7 R:2 V:c
+ D:7 R:3 V:d
+ D:4 R:1 NULL
+ D:7 R:2 V:e
+ D:7 R:0 V:a
+ D:7 R:3 V:b
+ D:7 R:2 V:c
+ D:7 R:3 V:d
+ D:7 R:2 V:e
+ D:4 R:1 NULL
+ D:7 R:2 V:f
+Column 1
+ D:0 R:0 V:1
+ D:0 R:0 V:1
+ D:0 R:0 V:1
+)###";
+
+ // empty list means print all
+ std::list<int> columns;
+
+ std::stringstream ss_values, ss_dump;
+ const char* file = "nested_lists.snappy.parquet";
+ auto reader_props = default_reader_properties();
+ auto reader = ParquetFileReader::OpenFile(data_file(file), false, reader_props);
+ ParquetFilePrinter printer(reader.get());
+
+ printer.DebugPrint(ss_values, columns, true, false, false, file);
+ printer.DebugPrint(ss_dump, columns, true, true, false, file);
+
+ ASSERT_EQ(header_output + values_output, ss_values.str());
+ ASSERT_EQ(header_output + dump_output, ss_dump.str());
+}
+
+TEST(TestJSONWithLocalFile, JSONOutput) {
+ std::string json_output = R"###({
+ "FileName": "alltypes_plain.parquet",
+ "Version": "1.0",
+ "CreatedBy": "impala version 1.3.0-INTERNAL (build 8a48ddb1eff84592b3fc06bc6f51ec120e1fffc9)",
+ "TotalRows": "8",
+ "NumberOfRowGroups": "1",
+ "NumberOfRealColumns": "11",
+ "NumberOfColumns": "11",
+ "Columns": [
+ { "Id": "0", "Name": "id", "PhysicalType": "INT32", "ConvertedType": "NONE", "LogicalType": {"Type": "None"} },
+ { "Id": "1", "Name": "bool_col", "PhysicalType": "BOOLEAN", "ConvertedType": "NONE", "LogicalType": {"Type": "None"} },
+ { "Id": "2", "Name": "tinyint_col", "PhysicalType": "INT32", "ConvertedType": "NONE", "LogicalType": {"Type": "None"} },
+ { "Id": "3", "Name": "smallint_col", "PhysicalType": "INT32", "ConvertedType": "NONE", "LogicalType": {"Type": "None"} },
+ { "Id": "4", "Name": "int_col", "PhysicalType": "INT32", "ConvertedType": "NONE", "LogicalType": {"Type": "None"} },
+ { "Id": "5", "Name": "bigint_col", "PhysicalType": "INT64", "ConvertedType": "NONE", "LogicalType": {"Type": "None"} },
+ { "Id": "6", "Name": "float_col", "PhysicalType": "FLOAT", "ConvertedType": "NONE", "LogicalType": {"Type": "None"} },
+ { "Id": "7", "Name": "double_col", "PhysicalType": "DOUBLE", "ConvertedType": "NONE", "LogicalType": {"Type": "None"} },
+ { "Id": "8", "Name": "date_string_col", "PhysicalType": "BYTE_ARRAY", "ConvertedType": "NONE", "LogicalType": {"Type": "None"} },
+ { "Id": "9", "Name": "string_col", "PhysicalType": "BYTE_ARRAY", "ConvertedType": "NONE", "LogicalType": {"Type": "None"} },
+ { "Id": "10", "Name": "timestamp_col", "PhysicalType": "INT96", "ConvertedType": "NONE", "LogicalType": {"Type": "None"} }
+ ],
+ "RowGroups": [
+ {
+ "Id": "0", "TotalBytes": "671", "TotalCompressedBytes": "0", "Rows": "8",
+ "ColumnChunks": [
+ {"Id": "0", "Values": "8", "StatsSet": "False",
+ "Compression": "UNCOMPRESSED", "Encodings": "RLE PLAIN_DICTIONARY PLAIN ", "UncompressedSize": "73", "CompressedSize": "73" },
+ {"Id": "1", "Values": "8", "StatsSet": "False",
+ "Compression": "UNCOMPRESSED", "Encodings": "RLE PLAIN_DICTIONARY PLAIN ", "UncompressedSize": "24", "CompressedSize": "24" },
+ {"Id": "2", "Values": "8", "StatsSet": "False",
+ "Compression": "UNCOMPRESSED", "Encodings": "RLE PLAIN_DICTIONARY PLAIN ", "UncompressedSize": "47", "CompressedSize": "47" },
+ {"Id": "3", "Values": "8", "StatsSet": "False",
+ "Compression": "UNCOMPRESSED", "Encodings": "RLE PLAIN_DICTIONARY PLAIN ", "UncompressedSize": "47", "CompressedSize": "47" },
+ {"Id": "4", "Values": "8", "StatsSet": "False",
+ "Compression": "UNCOMPRESSED", "Encodings": "RLE PLAIN_DICTIONARY PLAIN ", "UncompressedSize": "47", "CompressedSize": "47" },
+ {"Id": "5", "Values": "8", "StatsSet": "False",
+ "Compression": "UNCOMPRESSED", "Encodings": "RLE PLAIN_DICTIONARY PLAIN ", "UncompressedSize": "55", "CompressedSize": "55" },
+ {"Id": "6", "Values": "8", "StatsSet": "False",
+ "Compression": "UNCOMPRESSED", "Encodings": "RLE PLAIN_DICTIONARY PLAIN ", "UncompressedSize": "47", "CompressedSize": "47" },
+ {"Id": "7", "Values": "8", "StatsSet": "False",
+ "Compression": "UNCOMPRESSED", "Encodings": "RLE PLAIN_DICTIONARY PLAIN ", "UncompressedSize": "55", "CompressedSize": "55" },
+ {"Id": "8", "Values": "8", "StatsSet": "False",
+ "Compression": "UNCOMPRESSED", "Encodings": "RLE PLAIN_DICTIONARY PLAIN ", "UncompressedSize": "88", "CompressedSize": "88" },
+ {"Id": "9", "Values": "8", "StatsSet": "False",
+ "Compression": "UNCOMPRESSED", "Encodings": "RLE PLAIN_DICTIONARY PLAIN ", "UncompressedSize": "49", "CompressedSize": "49" },
+ {"Id": "10", "Values": "8", "StatsSet": "False",
+ "Compression": "UNCOMPRESSED", "Encodings": "RLE PLAIN_DICTIONARY PLAIN ", "UncompressedSize": "139", "CompressedSize": "139" }
+ ]
+ }
+ ]
+}
+)###";
+
+ std::stringstream ss;
+ // empty list means print all
+ std::list<int> columns;
+
+ auto reader =
+ ParquetFileReader::OpenFile(alltypes_plain(), false, default_reader_properties());
+ ParquetFilePrinter printer(reader.get());
+ printer.JSONPrint(ss, columns, "alltypes_plain.parquet");
+
+ ASSERT_EQ(json_output, ss.str());
+}
+
+TEST(TestFileReader, BufferedReadsWithDictionary) {
+ const int num_rows = 1000;
+
+ // Make schema
+ schema::NodeVector fields;
+ fields.push_back(PrimitiveNode::Make("field", Repetition::REQUIRED, Type::DOUBLE,
+ ConvertedType::NONE));
+ auto schema = std::static_pointer_cast<GroupNode>(
+ GroupNode::Make("schema", Repetition::REQUIRED, fields));
+
+ // Write small batches and small data pages
+ std::shared_ptr<WriterProperties> writer_props = WriterProperties::Builder()
+ .write_batch_size(64)
+ ->data_pagesize(128)
+ ->enable_dictionary()
+ ->build();
+
+ ASSERT_OK_AND_ASSIGN(auto out_file, ::arrow::io::BufferOutputStream::Create());
+ std::shared_ptr<ParquetFileWriter> file_writer =
+ ParquetFileWriter::Open(out_file, schema, writer_props);
+
+ RowGroupWriter* rg_writer = file_writer->AppendRowGroup();
+
+ // write one column
+ ::arrow::random::RandomArrayGenerator rag(0);
+ DoubleWriter* writer = static_cast<DoubleWriter*>(rg_writer->NextColumn());
+ std::shared_ptr<::arrow::Array> col = rag.Float64(num_rows, 0, 100);
+ const auto& col_typed = static_cast<const ::arrow::DoubleArray&>(*col);
+ writer->WriteBatch(num_rows, nullptr, nullptr, col_typed.raw_values());
+ rg_writer->Close();
+ file_writer->Close();
+
+ // Open the reader
+ ASSERT_OK_AND_ASSIGN(auto file_buf, out_file->Finish());
+ auto in_file = std::make_shared<::arrow::io::BufferReader>(file_buf);
+
+ ReaderProperties reader_props;
+ reader_props.enable_buffered_stream();
+ reader_props.set_buffer_size(64);
+ std::unique_ptr<ParquetFileReader> file_reader =
+ ParquetFileReader::Open(in_file, reader_props);
+
+ auto row_group = file_reader->RowGroup(0);
+ auto col_reader = std::static_pointer_cast<DoubleReader>(
+ row_group->ColumnWithExposeEncoding(0, ExposedEncoding::DICTIONARY));
+ EXPECT_EQ(col_reader->GetExposedEncoding(), ExposedEncoding::DICTIONARY);
+
+ auto indices = ::arrow::internal::make_unique<int32_t[]>(num_rows);
+ const double* dict = nullptr;
+ int32_t dict_len = 0;
+ for (int row_index = 0; row_index < num_rows; ++row_index) {
+ const double* tmp_dict = nullptr;
+ int32_t tmp_dict_len = 0;
+ int64_t values_read = 0;
+ int64_t levels_read = col_reader->ReadBatchWithDictionary(
+ /*batch_size=*/1, /*def_levels=*/nullptr, /*rep_levels=*/nullptr,
+ indices.get() + row_index, &values_read, &tmp_dict, &tmp_dict_len);
+
+ if (tmp_dict != nullptr) {
+ EXPECT_EQ(values_read, 1);
+ dict = tmp_dict;
+ dict_len = tmp_dict_len;
+ } else {
+ EXPECT_EQ(values_read, 0);
+ }
+
+ ASSERT_EQ(1, levels_read);
+ ASSERT_EQ(1, values_read);
+ }
+
+ // Check the results
+ for (int row_index = 0; row_index < num_rows; ++row_index) {
+ EXPECT_LT(indices[row_index], dict_len);
+ EXPECT_EQ(dict[indices[row_index]], col_typed.Value(row_index));
+ }
+}
+
+TEST(TestFileReader, BufferedReads) {
+ // PARQUET-1636: Buffered reads were broken before introduction of
+ // RandomAccessFile::GetStream
+
+ const int num_columns = 10;
+ const int num_rows = 1000;
+
+ // Make schema
+ schema::NodeVector fields;
+ for (int i = 0; i < num_columns; ++i) {
+ fields.push_back(PrimitiveNode::Make("field" + std::to_string(i),
+ Repetition::REQUIRED, Type::DOUBLE,
+ ConvertedType::NONE));
+ }
+ auto schema = std::static_pointer_cast<GroupNode>(
+ GroupNode::Make("schema", Repetition::REQUIRED, fields));
+
+ // Write small batches and small data pages
+ std::shared_ptr<WriterProperties> writer_props =
+ WriterProperties::Builder().write_batch_size(64)->data_pagesize(128)->build();
+
+ ASSERT_OK_AND_ASSIGN(auto out_file, ::arrow::io::BufferOutputStream::Create());
+ std::shared_ptr<ParquetFileWriter> file_writer =
+ ParquetFileWriter::Open(out_file, schema, writer_props);
+
+ RowGroupWriter* rg_writer = file_writer->AppendRowGroup();
+
+ ::arrow::ArrayVector column_data;
+ ::arrow::random::RandomArrayGenerator rag(0);
+
+ // Scratch space for reads
+ ::arrow::BufferVector scratch_space;
+
+ // write columns
+ for (int col_index = 0; col_index < num_columns; ++col_index) {
+ DoubleWriter* writer = static_cast<DoubleWriter*>(rg_writer->NextColumn());
+ std::shared_ptr<::arrow::Array> col = rag.Float64(num_rows, 0, 100);
+ const auto& col_typed = static_cast<const ::arrow::DoubleArray&>(*col);
+ writer->WriteBatch(num_rows, nullptr, nullptr, col_typed.raw_values());
+ column_data.push_back(col);
+
+ // We use this later for reading back the columns
+ scratch_space.push_back(
+ AllocateBuffer(::arrow::default_memory_pool(), num_rows * sizeof(double)));
+ }
+ rg_writer->Close();
+ file_writer->Close();
+
+ // Open the reader
+ ASSERT_OK_AND_ASSIGN(auto file_buf, out_file->Finish());
+ auto in_file = std::make_shared<::arrow::io::BufferReader>(file_buf);
+
+ ReaderProperties reader_props;
+ reader_props.enable_buffered_stream();
+ reader_props.set_buffer_size(64);
+ std::unique_ptr<ParquetFileReader> file_reader =
+ ParquetFileReader::Open(in_file, reader_props);
+
+ auto row_group = file_reader->RowGroup(0);
+ std::vector<std::shared_ptr<DoubleReader>> col_readers;
+ for (int col_index = 0; col_index < num_columns; ++col_index) {
+ col_readers.push_back(
+ std::static_pointer_cast<DoubleReader>(row_group->Column(col_index)));
+ }
+
+ for (int row_index = 0; row_index < num_rows; ++row_index) {
+ for (int col_index = 0; col_index < num_columns; ++col_index) {
+ double* out =
+ reinterpret_cast<double*>(scratch_space[col_index]->mutable_data()) + row_index;
+ int64_t values_read = 0;
+ int64_t levels_read =
+ col_readers[col_index]->ReadBatch(1, nullptr, nullptr, out, &values_read);
+
+ ASSERT_EQ(1, levels_read);
+ ASSERT_EQ(1, values_read);
+ }
+ }
+
+ // Check the results
+ for (int col_index = 0; col_index < num_columns; ++col_index) {
+ ASSERT_TRUE(
+ scratch_space[col_index]->Equals(*column_data[col_index]->data()->buffers[1]));
+ }
+}
+
+std::unique_ptr<ParquetFileReader> OpenBuffer(const std::string& contents) {
+ auto buffer = ::arrow::Buffer::FromString(contents);
+ return ParquetFileReader::Open(std::make_shared<::arrow::io::BufferReader>(buffer));
+}
+
+::arrow::Future<> OpenBufferAsync(const std::string& contents) {
+ auto buffer = ::arrow::Buffer::FromString(contents);
+ return ::arrow::Future<>(
+ ParquetFileReader::OpenAsync(std::make_shared<::arrow::io::BufferReader>(buffer)));
+}
+
+// https://github.com/google/googletest/pull/2904 not available in our version of
+// gtest/gmock
+#define EXPECT_THROW_THAT(callable, ex_type, property) \
+ EXPECT_THROW( \
+ try { (callable)(); } catch (const ex_type& err) { \
+ EXPECT_THAT(err, (property)); \
+ throw; \
+ }, \
+ ex_type)
+
+TEST(TestFileReader, TestOpenErrors) {
+ EXPECT_THROW_THAT(
+ []() { OpenBuffer(""); }, ParquetInvalidOrCorruptedFileException,
+ ::testing::Property(&ParquetInvalidOrCorruptedFileException::what,
+ ::testing::HasSubstr("Parquet file size is 0 bytes")));
+ EXPECT_THROW_THAT(
+ []() { OpenBuffer("AAAAPAR0"); }, ParquetInvalidOrCorruptedFileException,
+ ::testing::Property(&ParquetInvalidOrCorruptedFileException::what,
+ ::testing::HasSubstr("Parquet magic bytes not found")));
+ EXPECT_THROW_THAT(
+ []() { OpenBuffer("APAR1"); }, ParquetInvalidOrCorruptedFileException,
+ ::testing::Property(
+ &ParquetInvalidOrCorruptedFileException::what,
+ ::testing::HasSubstr(
+ "Parquet file size is 5 bytes, smaller than the minimum file footer")));
+ EXPECT_THROW_THAT(
+ []() { OpenBuffer("\xFF\xFF\xFF\x0FPAR1"); },
+ ParquetInvalidOrCorruptedFileException,
+ ::testing::Property(&ParquetInvalidOrCorruptedFileException::what,
+ ::testing::HasSubstr("Parquet file size is 8 bytes, smaller "
+ "than the size reported by footer's")));
+ EXPECT_THROW_THAT(
+ []() { OpenBuffer(std::string("\x00\x00\x00\x00PAR1", 8)); }, ParquetException,
+ ::testing::Property(
+ &ParquetException::what,
+ ::testing::HasSubstr("Couldn't deserialize thrift: No more data to read")));
+
+ EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Parquet file size is 0 bytes"), OpenBufferAsync(""));
+ EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Parquet magic bytes not found"),
+ OpenBufferAsync("AAAAPAR0"));
+ EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr(
+ "Parquet file size is 5 bytes, smaller than the minimum file footer"),
+ OpenBufferAsync("APAR1"));
+ EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr(
+ "Parquet file size is 8 bytes, smaller than the size reported by footer's"),
+ OpenBufferAsync("\xFF\xFF\xFF\x0FPAR1"));
+ EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(
+ IOError, ::testing::HasSubstr("Couldn't deserialize thrift: No more data to read"),
+ OpenBufferAsync(std::string("\x00\x00\x00\x00PAR1", 8)));
+}
+
+#undef EXPECT_THROW_THAT
+
+#ifdef ARROW_WITH_LZ4
+struct TestCodecParam {
+ std::string name;
+ std::string small_data_file;
+ std::string larger_data_file;
+};
+
+void PrintTo(const TestCodecParam& p, std::ostream* os) { *os << p.name; }
+
+class TestCodec : public ::testing::TestWithParam<TestCodecParam> {
+ protected:
+ const std::string& GetSmallDataFile() { return GetParam().small_data_file; }
+
+ const std::string& GetLargerDataFile() { return GetParam().larger_data_file; }
+};
+
+TEST_P(TestCodec, SmallFileMetadataAndValues) {
+ std::unique_ptr<ParquetFileReader> reader_ =
+ ParquetFileReader::OpenFile(GetSmallDataFile());
+ std::shared_ptr<RowGroupReader> group = reader_->RowGroup(0);
+ const auto rg_metadata = group->metadata();
+
+ // This file only has 4 rows
+ ASSERT_EQ(4, reader_->metadata()->num_rows());
+ // This file only has 3 columns
+ ASSERT_EQ(3, reader_->metadata()->num_columns());
+ // This file only has 1 row group
+ ASSERT_EQ(1, reader_->metadata()->num_row_groups());
+
+ // This row group must have 4 rows
+ ASSERT_EQ(4, rg_metadata->num_rows());
+
+ // Some parquet-cpp versions are susceptible to PARQUET-2008
+ const auto& app_ver = reader_->metadata()->writer_version();
+ const bool allow_uncompressed_mismatch =
+ (app_ver.application_ == "parquet-cpp" && app_ver.version.major == 1 &&
+ app_ver.version.minor == 5 && app_ver.version.patch == 1);
+
+ CheckRowGroupMetadata(rg_metadata, allow_uncompressed_mismatch);
+
+ // column 0, c0
+ auto col0 = checked_pointer_cast<Int64Reader>(group->Column(0));
+ std::vector<int64_t> expected_values = {1593604800, 1593604800, 1593604801, 1593604801};
+ AssertColumnValues(col0, 4, 4, expected_values, 4);
+
+ // column 1, c1
+ std::vector<ByteArray> expected_byte_arrays = {ByteArray("abc"), ByteArray("def"),
+ ByteArray("abc"), ByteArray("def")};
+ auto col1 = checked_pointer_cast<ByteArrayReader>(group->Column(1));
+ AssertColumnValues(col1, 4, 4, expected_byte_arrays, 4);
+
+ // column 2, v11
+ std::vector<double> expected_double_values = {42.0, 7.7, 42.125, 7.7};
+ auto col2 = checked_pointer_cast<DoubleReader>(group->Column(2));
+ AssertColumnValues(col2, 4, 4, expected_double_values, 4);
+}
+
+TEST_P(TestCodec, LargeFileValues) {
+ // Test codec with a larger data file such data may have been compressed
+ // in several "frames" (ARROW-9177)
+ auto file_path = GetParam().larger_data_file;
+ if (file_path.empty()) {
+ GTEST_SKIP() << "Larger data file not available for this codec";
+ }
+ auto file = ParquetFileReader::OpenFile(file_path);
+ auto group = file->RowGroup(0);
+
+ const int64_t kNumRows = 10000;
+
+ ASSERT_EQ(kNumRows, file->metadata()->num_rows());
+ ASSERT_EQ(1, file->metadata()->num_columns());
+ ASSERT_EQ(1, file->metadata()->num_row_groups());
+ ASSERT_EQ(kNumRows, group->metadata()->num_rows());
+
+ // column 0 ("a")
+ auto col = checked_pointer_cast<ByteArrayReader>(group->Column(0));
+
+ std::vector<ByteArray> values(kNumRows);
+ int64_t values_read;
+ auto levels_read =
+ col->ReadBatch(kNumRows, nullptr, nullptr, values.data(), &values_read);
+ ASSERT_EQ(kNumRows, levels_read);
+ ASSERT_EQ(kNumRows, values_read);
+ ASSERT_EQ(values[0], ByteArray("c7ce6bef-d5b0-4863-b199-8ea8c7fb117b"));
+ ASSERT_EQ(values[1], ByteArray("e8fb9197-cb9f-4118-b67f-fbfa65f61843"));
+ ASSERT_EQ(values[kNumRows - 2], ByteArray("ab52a0cc-c6bb-4d61-8a8f-166dc4b8b13c"));
+ ASSERT_EQ(values[kNumRows - 1], ByteArray("85440778-460a-41ac-aa2e-ac3ee41696bf"));
+}
+
+std::vector<TestCodecParam> test_codec_params{
+ {"LegacyLZ4Hadoop", hadoop_lz4_compressed(), hadoop_lz4_compressed_larger()},
+ {"LegacyLZ4NonHadoop", non_hadoop_lz4_compressed(), ""},
+ {"LZ4Raw", lz4_raw_compressed(), lz4_raw_compressed_larger()}};
+
+INSTANTIATE_TEST_SUITE_P(Lz4CodecTests, TestCodec, ::testing::ValuesIn(test_codec_params),
+ testing::PrintToStringParamName());
+#endif // ARROW_WITH_LZ4
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/schema.cc b/src/arrow/cpp/src/parquet/schema.cc
new file mode 100644
index 000000000..cfa6bdb29
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/schema.cc
@@ -0,0 +1,945 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/schema.h"
+
+#include <algorithm>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+
+#include "arrow/util/logging.h"
+#include "parquet/exception.h"
+#include "parquet/schema_internal.h"
+#include "parquet/thrift_internal.h"
+
+using parquet::format::SchemaElement;
+
+namespace parquet {
+
+namespace schema {
+
+namespace {
+
+void ThrowInvalidLogicalType(const LogicalType& logical_type) {
+ std::stringstream ss;
+ ss << "Invalid logical type: " << logical_type.ToString();
+ throw ParquetException(ss.str());
+}
+
+} // namespace
+
+// ----------------------------------------------------------------------
+// ColumnPath
+
+std::shared_ptr<ColumnPath> ColumnPath::FromDotString(const std::string& dotstring) {
+ std::stringstream ss(dotstring);
+ std::string item;
+ std::vector<std::string> path;
+ while (std::getline(ss, item, '.')) {
+ path.push_back(item);
+ }
+ return std::make_shared<ColumnPath>(std::move(path));
+}
+
+std::shared_ptr<ColumnPath> ColumnPath::FromNode(const Node& node) {
+ // Build the path in reverse order as we traverse the nodes to the top
+ std::vector<std::string> rpath_;
+ const Node* cursor = &node;
+ // The schema node is not part of the ColumnPath
+ while (cursor->parent()) {
+ rpath_.push_back(cursor->name());
+ cursor = cursor->parent();
+ }
+
+ // Build ColumnPath in correct order
+ std::vector<std::string> path(rpath_.crbegin(), rpath_.crend());
+ return std::make_shared<ColumnPath>(std::move(path));
+}
+
+std::shared_ptr<ColumnPath> ColumnPath::extend(const std::string& node_name) const {
+ std::vector<std::string> path;
+ path.reserve(path_.size() + 1);
+ path.resize(path_.size() + 1);
+ std::copy(path_.cbegin(), path_.cend(), path.begin());
+ path[path_.size()] = node_name;
+
+ return std::make_shared<ColumnPath>(std::move(path));
+}
+
+std::string ColumnPath::ToDotString() const {
+ std::stringstream ss;
+ for (auto it = path_.cbegin(); it != path_.cend(); ++it) {
+ if (it != path_.cbegin()) {
+ ss << ".";
+ }
+ ss << *it;
+ }
+ return ss.str();
+}
+
+const std::vector<std::string>& ColumnPath::ToDotVector() const { return path_; }
+
+// ----------------------------------------------------------------------
+// Base node
+
+const std::shared_ptr<ColumnPath> Node::path() const {
+ // TODO(itaiin): Cache the result, or more precisely, cache ->ToDotString()
+ // since it is being used to access the leaf nodes
+ return ColumnPath::FromNode(*this);
+}
+
+bool Node::EqualsInternal(const Node* other) const {
+ return type_ == other->type_ && name_ == other->name_ &&
+ repetition_ == other->repetition_ && converted_type_ == other->converted_type_ &&
+ field_id_ == other->field_id() &&
+ logical_type_->Equals(*(other->logical_type()));
+}
+
+void Node::SetParent(const Node* parent) { parent_ = parent; }
+
+// ----------------------------------------------------------------------
+// Primitive node
+
+PrimitiveNode::PrimitiveNode(const std::string& name, Repetition::type repetition,
+ Type::type type, ConvertedType::type converted_type,
+ int length, int precision, int scale, int id)
+ : Node(Node::PRIMITIVE, name, repetition, converted_type, id),
+ physical_type_(type),
+ type_length_(length) {
+ std::stringstream ss;
+
+ // PARQUET-842: In an earlier revision, decimal_metadata_.isset was being
+ // set to true, but Impala will raise an incompatible metadata in such cases
+ memset(&decimal_metadata_, 0, sizeof(decimal_metadata_));
+
+ // Check if the physical and logical types match
+ // Mapping referred from Apache parquet-mr as on 2016-02-22
+ switch (converted_type) {
+ case ConvertedType::NONE:
+ // Logical type not set
+ break;
+ case ConvertedType::UTF8:
+ case ConvertedType::JSON:
+ case ConvertedType::BSON:
+ if (type != Type::BYTE_ARRAY) {
+ ss << ConvertedTypeToString(converted_type);
+ ss << " can only annotate BYTE_ARRAY fields";
+ throw ParquetException(ss.str());
+ }
+ break;
+ case ConvertedType::DECIMAL:
+ if ((type != Type::INT32) && (type != Type::INT64) && (type != Type::BYTE_ARRAY) &&
+ (type != Type::FIXED_LEN_BYTE_ARRAY)) {
+ ss << "DECIMAL can only annotate INT32, INT64, BYTE_ARRAY, and FIXED";
+ throw ParquetException(ss.str());
+ }
+ if (precision <= 0) {
+ ss << "Invalid DECIMAL precision: " << precision
+ << ". Precision must be a number between 1 and 38 inclusive";
+ throw ParquetException(ss.str());
+ }
+ if (scale < 0) {
+ ss << "Invalid DECIMAL scale: " << scale
+ << ". Scale must be a number between 0 and precision inclusive";
+ throw ParquetException(ss.str());
+ }
+ if (scale > precision) {
+ ss << "Invalid DECIMAL scale " << scale;
+ ss << " cannot be greater than precision " << precision;
+ throw ParquetException(ss.str());
+ }
+ decimal_metadata_.isset = true;
+ decimal_metadata_.precision = precision;
+ decimal_metadata_.scale = scale;
+ break;
+ case ConvertedType::DATE:
+ case ConvertedType::TIME_MILLIS:
+ case ConvertedType::UINT_8:
+ case ConvertedType::UINT_16:
+ case ConvertedType::UINT_32:
+ case ConvertedType::INT_8:
+ case ConvertedType::INT_16:
+ case ConvertedType::INT_32:
+ if (type != Type::INT32) {
+ ss << ConvertedTypeToString(converted_type);
+ ss << " can only annotate INT32";
+ throw ParquetException(ss.str());
+ }
+ break;
+ case ConvertedType::TIME_MICROS:
+ case ConvertedType::TIMESTAMP_MILLIS:
+ case ConvertedType::TIMESTAMP_MICROS:
+ case ConvertedType::UINT_64:
+ case ConvertedType::INT_64:
+ if (type != Type::INT64) {
+ ss << ConvertedTypeToString(converted_type);
+ ss << " can only annotate INT64";
+ throw ParquetException(ss.str());
+ }
+ break;
+ case ConvertedType::INTERVAL:
+ if ((type != Type::FIXED_LEN_BYTE_ARRAY) || (length != 12)) {
+ ss << "INTERVAL can only annotate FIXED_LEN_BYTE_ARRAY(12)";
+ throw ParquetException(ss.str());
+ }
+ break;
+ case ConvertedType::ENUM:
+ if (type != Type::BYTE_ARRAY) {
+ ss << "ENUM can only annotate BYTE_ARRAY fields";
+ throw ParquetException(ss.str());
+ }
+ break;
+ case ConvertedType::NA:
+ // NA can annotate any type
+ break;
+ default:
+ ss << ConvertedTypeToString(converted_type);
+ ss << " cannot be applied to a primitive type";
+ throw ParquetException(ss.str());
+ }
+ // For forward compatibility, create an equivalent logical type
+ logical_type_ = LogicalType::FromConvertedType(converted_type_, decimal_metadata_);
+ if (!(logical_type_ && !logical_type_->is_nested() &&
+ logical_type_->is_compatible(converted_type_, decimal_metadata_))) {
+ ThrowInvalidLogicalType(*logical_type_);
+ }
+
+ if (type == Type::FIXED_LEN_BYTE_ARRAY) {
+ if (length <= 0) {
+ ss << "Invalid FIXED_LEN_BYTE_ARRAY length: " << length;
+ throw ParquetException(ss.str());
+ }
+ type_length_ = length;
+ }
+}
+
+PrimitiveNode::PrimitiveNode(const std::string& name, Repetition::type repetition,
+ std::shared_ptr<const LogicalType> logical_type,
+ Type::type physical_type, int physical_length, int id)
+ : Node(Node::PRIMITIVE, name, repetition, std::move(logical_type), id),
+ physical_type_(physical_type),
+ type_length_(physical_length) {
+ std::stringstream error;
+ if (logical_type_) {
+ // Check for logical type <=> node type consistency
+ if (!logical_type_->is_nested()) {
+ // Check for logical type <=> physical type consistency
+ if (logical_type_->is_applicable(physical_type, physical_length)) {
+ // For backward compatibility, assign equivalent legacy
+ // converted type (if possible)
+ converted_type_ = logical_type_->ToConvertedType(&decimal_metadata_);
+ } else {
+ error << logical_type_->ToString();
+ error << " can not be applied to primitive type ";
+ error << TypeToString(physical_type);
+ throw ParquetException(error.str());
+ }
+ } else {
+ error << "Nested logical type ";
+ error << logical_type_->ToString();
+ error << " can not be applied to non-group node";
+ throw ParquetException(error.str());
+ }
+ } else {
+ logical_type_ = NoLogicalType::Make();
+ converted_type_ = logical_type_->ToConvertedType(&decimal_metadata_);
+ }
+ if (!(logical_type_ && !logical_type_->is_nested() &&
+ logical_type_->is_compatible(converted_type_, decimal_metadata_))) {
+ ThrowInvalidLogicalType(*logical_type_);
+ }
+
+ if (physical_type == Type::FIXED_LEN_BYTE_ARRAY) {
+ if (physical_length <= 0) {
+ error << "Invalid FIXED_LEN_BYTE_ARRAY length: " << physical_length;
+ throw ParquetException(error.str());
+ }
+ }
+}
+
+bool PrimitiveNode::EqualsInternal(const PrimitiveNode* other) const {
+ bool is_equal = true;
+ if (physical_type_ != other->physical_type_) {
+ return false;
+ }
+ if (converted_type_ == ConvertedType::DECIMAL) {
+ is_equal &= (decimal_metadata_.precision == other->decimal_metadata_.precision) &&
+ (decimal_metadata_.scale == other->decimal_metadata_.scale);
+ }
+ if (physical_type_ == Type::FIXED_LEN_BYTE_ARRAY) {
+ is_equal &= (type_length_ == other->type_length_);
+ }
+ return is_equal;
+}
+
+bool PrimitiveNode::Equals(const Node* other) const {
+ if (!Node::EqualsInternal(other)) {
+ return false;
+ }
+ return EqualsInternal(static_cast<const PrimitiveNode*>(other));
+}
+
+void PrimitiveNode::Visit(Node::Visitor* visitor) { visitor->Visit(this); }
+
+void PrimitiveNode::VisitConst(Node::ConstVisitor* visitor) const {
+ visitor->Visit(this);
+}
+
+// ----------------------------------------------------------------------
+// Group node
+
+GroupNode::GroupNode(const std::string& name, Repetition::type repetition,
+ const NodeVector& fields, ConvertedType::type converted_type, int id)
+ : Node(Node::GROUP, name, repetition, converted_type, id), fields_(fields) {
+ // For forward compatibility, create an equivalent logical type
+ logical_type_ = LogicalType::FromConvertedType(converted_type_);
+ if (!(logical_type_ && (logical_type_->is_nested() || logical_type_->is_none()) &&
+ logical_type_->is_compatible(converted_type_))) {
+ ThrowInvalidLogicalType(*logical_type_);
+ }
+
+ field_name_to_idx_.clear();
+ auto field_idx = 0;
+ for (NodePtr& field : fields_) {
+ field->SetParent(this);
+ field_name_to_idx_.emplace(field->name(), field_idx++);
+ }
+}
+
+GroupNode::GroupNode(const std::string& name, Repetition::type repetition,
+ const NodeVector& fields,
+ std::shared_ptr<const LogicalType> logical_type, int id)
+ : Node(Node::GROUP, name, repetition, std::move(logical_type), id), fields_(fields) {
+ if (logical_type_) {
+ // Check for logical type <=> node type consistency
+ if (logical_type_->is_nested()) {
+ // For backward compatibility, assign equivalent legacy converted type (if possible)
+ converted_type_ = logical_type_->ToConvertedType(nullptr);
+ } else {
+ std::stringstream error;
+ error << "Logical type ";
+ error << logical_type_->ToString();
+ error << " can not be applied to group node";
+ throw ParquetException(error.str());
+ }
+ } else {
+ logical_type_ = NoLogicalType::Make();
+ converted_type_ = logical_type_->ToConvertedType(nullptr);
+ }
+ if (!(logical_type_ && (logical_type_->is_nested() || logical_type_->is_none()) &&
+ logical_type_->is_compatible(converted_type_))) {
+ ThrowInvalidLogicalType(*logical_type_);
+ }
+
+ field_name_to_idx_.clear();
+ auto field_idx = 0;
+ for (NodePtr& field : fields_) {
+ field->SetParent(this);
+ field_name_to_idx_.emplace(field->name(), field_idx++);
+ }
+}
+
+bool GroupNode::EqualsInternal(const GroupNode* other) const {
+ if (this == other) {
+ return true;
+ }
+ if (this->field_count() != other->field_count()) {
+ return false;
+ }
+ for (int i = 0; i < this->field_count(); ++i) {
+ if (!this->field(i)->Equals(other->field(i).get())) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool GroupNode::Equals(const Node* other) const {
+ if (!Node::EqualsInternal(other)) {
+ return false;
+ }
+ return EqualsInternal(static_cast<const GroupNode*>(other));
+}
+
+int GroupNode::FieldIndex(const std::string& name) const {
+ auto search = field_name_to_idx_.find(name);
+ if (search == field_name_to_idx_.end()) {
+ // Not found
+ return -1;
+ }
+ return search->second;
+}
+
+int GroupNode::FieldIndex(const Node& node) const {
+ auto search = field_name_to_idx_.equal_range(node.name());
+ for (auto it = search.first; it != search.second; ++it) {
+ const int idx = it->second;
+ if (&node == field(idx).get()) {
+ return idx;
+ }
+ }
+ return -1;
+}
+
+void GroupNode::Visit(Node::Visitor* visitor) { visitor->Visit(this); }
+
+void GroupNode::VisitConst(Node::ConstVisitor* visitor) const { visitor->Visit(this); }
+
+// ----------------------------------------------------------------------
+// Node construction from Parquet metadata
+
+std::unique_ptr<Node> GroupNode::FromParquet(const void* opaque_element,
+ NodeVector fields) {
+ const format::SchemaElement* element =
+ static_cast<const format::SchemaElement*>(opaque_element);
+
+ int field_id = -1;
+ if (element->__isset.field_id) {
+ field_id = element->field_id;
+ }
+
+ std::unique_ptr<GroupNode> group_node;
+ if (element->__isset.logicalType) {
+ // updated writer with logical type present
+ group_node = std::unique_ptr<GroupNode>(
+ new GroupNode(element->name, LoadEnumSafe(&element->repetition_type), fields,
+ LogicalType::FromThrift(element->logicalType), field_id));
+ } else {
+ group_node = std::unique_ptr<GroupNode>(new GroupNode(
+ element->name, LoadEnumSafe(&element->repetition_type), fields,
+ (element->__isset.converted_type ? LoadEnumSafe(&element->converted_type)
+ : ConvertedType::NONE),
+ field_id));
+ }
+
+ return std::unique_ptr<Node>(group_node.release());
+}
+
+std::unique_ptr<Node> PrimitiveNode::FromParquet(const void* opaque_element) {
+ const format::SchemaElement* element =
+ static_cast<const format::SchemaElement*>(opaque_element);
+
+ int field_id = -1;
+ if (element->__isset.field_id) {
+ field_id = element->field_id;
+ }
+
+ std::unique_ptr<PrimitiveNode> primitive_node;
+ if (element->__isset.logicalType) {
+ // updated writer with logical type present
+ primitive_node = std::unique_ptr<PrimitiveNode>(
+ new PrimitiveNode(element->name, LoadEnumSafe(&element->repetition_type),
+ LogicalType::FromThrift(element->logicalType),
+ LoadEnumSafe(&element->type), element->type_length, field_id));
+ } else if (element->__isset.converted_type) {
+ // legacy writer with converted type present
+ primitive_node = std::unique_ptr<PrimitiveNode>(new PrimitiveNode(
+ element->name, LoadEnumSafe(&element->repetition_type),
+ LoadEnumSafe(&element->type), LoadEnumSafe(&element->converted_type),
+ element->type_length, element->precision, element->scale, field_id));
+ } else {
+ // logical type not present
+ primitive_node = std::unique_ptr<PrimitiveNode>(new PrimitiveNode(
+ element->name, LoadEnumSafe(&element->repetition_type), NoLogicalType::Make(),
+ LoadEnumSafe(&element->type), element->type_length, field_id));
+ }
+
+ // Return as unique_ptr to the base type
+ return std::unique_ptr<Node>(primitive_node.release());
+}
+
+bool GroupNode::HasRepeatedFields() const {
+ for (int i = 0; i < this->field_count(); ++i) {
+ auto field = this->field(i);
+ if (field->repetition() == Repetition::REPEATED) {
+ return true;
+ }
+ if (field->is_group()) {
+ const auto& group = static_cast<const GroupNode&>(*field);
+ return group.HasRepeatedFields();
+ }
+ }
+ return false;
+}
+
+void GroupNode::ToParquet(void* opaque_element) const {
+ format::SchemaElement* element = static_cast<format::SchemaElement*>(opaque_element);
+ element->__set_name(name_);
+ element->__set_num_children(field_count());
+ element->__set_repetition_type(ToThrift(repetition_));
+ if (converted_type_ != ConvertedType::NONE) {
+ element->__set_converted_type(ToThrift(converted_type_));
+ }
+ if (field_id_ >= 0) {
+ element->__set_field_id(field_id_);
+ }
+ if (logical_type_ && logical_type_->is_serialized()) {
+ element->__set_logicalType(logical_type_->ToThrift());
+ }
+ return;
+}
+
+void PrimitiveNode::ToParquet(void* opaque_element) const {
+ format::SchemaElement* element = static_cast<format::SchemaElement*>(opaque_element);
+ element->__set_name(name_);
+ element->__set_repetition_type(ToThrift(repetition_));
+ if (converted_type_ != ConvertedType::NONE) {
+ if (converted_type_ != ConvertedType::NA) {
+ element->__set_converted_type(ToThrift(converted_type_));
+ } else {
+ // ConvertedType::NA is an unreleased, obsolete synonym for LogicalType::Null.
+ // Never emit it (see PARQUET-1990 for discussion).
+ if (!logical_type_ || !logical_type_->is_null()) {
+ throw ParquetException(
+ "ConvertedType::NA is obsolete, please use LogicalType::Null instead");
+ }
+ }
+ }
+ if (field_id_ >= 0) {
+ element->__set_field_id(field_id_);
+ }
+ if (logical_type_ && logical_type_->is_serialized() &&
+ // TODO(tpboudreau): remove the following conjunct to enable serialization
+ // of IntervalTypes after parquet.thrift recognizes them
+ !logical_type_->is_interval()) {
+ element->__set_logicalType(logical_type_->ToThrift());
+ }
+ element->__set_type(ToThrift(physical_type_));
+ if (physical_type_ == Type::FIXED_LEN_BYTE_ARRAY) {
+ element->__set_type_length(type_length_);
+ }
+ if (decimal_metadata_.isset) {
+ element->__set_precision(decimal_metadata_.precision);
+ element->__set_scale(decimal_metadata_.scale);
+ }
+ return;
+}
+
+// ----------------------------------------------------------------------
+// Schema converters
+
+std::unique_ptr<Node> Unflatten(const format::SchemaElement* elements, int length) {
+ if (elements[0].num_children == 0) {
+ if (length == 1) {
+ // Degenerate case of Parquet file with no columns
+ return GroupNode::FromParquet(elements, {});
+ } else {
+ throw ParquetException(
+ "Parquet schema had multiple nodes but root had no children");
+ }
+ }
+
+ // We don't check that the root node is repeated since this is not
+ // consistently set by implementations
+
+ int pos = 0;
+
+ std::function<std::unique_ptr<Node>()> NextNode = [&]() {
+ if (pos == length) {
+ throw ParquetException("Malformed schema: not enough elements");
+ }
+ const SchemaElement& element = elements[pos++];
+ const void* opaque_element = static_cast<const void*>(&element);
+
+ if (element.num_children == 0 && element.__isset.type) {
+ // Leaf (primitive) node: always has a type
+ return PrimitiveNode::FromParquet(opaque_element);
+ } else {
+ // Group node (may have 0 children, but cannot have a type)
+ NodeVector fields;
+ for (int i = 0; i < element.num_children; ++i) {
+ std::unique_ptr<Node> field = NextNode();
+ fields.push_back(NodePtr(field.release()));
+ }
+ return GroupNode::FromParquet(opaque_element, std::move(fields));
+ }
+ };
+ return NextNode();
+}
+
+std::shared_ptr<SchemaDescriptor> FromParquet(const std::vector<SchemaElement>& schema) {
+ if (schema.empty()) {
+ throw ParquetException("Empty file schema (no root)");
+ }
+ std::unique_ptr<Node> root = Unflatten(&schema[0], static_cast<int>(schema.size()));
+ std::shared_ptr<SchemaDescriptor> descr = std::make_shared<SchemaDescriptor>();
+ descr->Init(std::shared_ptr<GroupNode>(static_cast<GroupNode*>(root.release())));
+ return descr;
+}
+
+class SchemaVisitor : public Node::ConstVisitor {
+ public:
+ explicit SchemaVisitor(std::vector<format::SchemaElement>* elements)
+ : elements_(elements) {}
+
+ void Visit(const Node* node) override {
+ format::SchemaElement element;
+ node->ToParquet(&element);
+ elements_->push_back(element);
+
+ if (node->is_group()) {
+ const GroupNode* group_node = static_cast<const GroupNode*>(node);
+ for (int i = 0; i < group_node->field_count(); ++i) {
+ group_node->field(i)->VisitConst(this);
+ }
+ }
+ }
+
+ private:
+ std::vector<format::SchemaElement>* elements_;
+};
+
+void ToParquet(const GroupNode* schema, std::vector<format::SchemaElement>* out) {
+ SchemaVisitor visitor(out);
+ schema->VisitConst(&visitor);
+}
+
+// ----------------------------------------------------------------------
+// Schema printing
+
+static void PrintRepLevel(Repetition::type repetition, std::ostream& stream) {
+ switch (repetition) {
+ case Repetition::REQUIRED:
+ stream << "required";
+ break;
+ case Repetition::OPTIONAL:
+ stream << "optional";
+ break;
+ case Repetition::REPEATED:
+ stream << "repeated";
+ break;
+ default:
+ break;
+ }
+}
+
+static void PrintType(const PrimitiveNode* node, std::ostream& stream) {
+ switch (node->physical_type()) {
+ case Type::BOOLEAN:
+ stream << "boolean";
+ break;
+ case Type::INT32:
+ stream << "int32";
+ break;
+ case Type::INT64:
+ stream << "int64";
+ break;
+ case Type::INT96:
+ stream << "int96";
+ break;
+ case Type::FLOAT:
+ stream << "float";
+ break;
+ case Type::DOUBLE:
+ stream << "double";
+ break;
+ case Type::BYTE_ARRAY:
+ stream << "binary";
+ break;
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ stream << "fixed_len_byte_array(" << node->type_length() << ")";
+ break;
+ default:
+ break;
+ }
+}
+
+static void PrintConvertedType(const PrimitiveNode* node, std::ostream& stream) {
+ auto lt = node->converted_type();
+ auto la = node->logical_type();
+ if (la && la->is_valid() && !la->is_none()) {
+ stream << " (" << la->ToString() << ")";
+ } else if (lt == ConvertedType::DECIMAL) {
+ stream << " (" << ConvertedTypeToString(lt) << "("
+ << node->decimal_metadata().precision << "," << node->decimal_metadata().scale
+ << "))";
+ } else if (lt != ConvertedType::NONE) {
+ stream << " (" << ConvertedTypeToString(lt) << ")";
+ }
+}
+
+struct SchemaPrinter : public Node::ConstVisitor {
+ explicit SchemaPrinter(std::ostream& stream, int indent_width)
+ : stream_(stream), indent_(0), indent_width_(2) {}
+
+ void Indent() {
+ if (indent_ > 0) {
+ std::string spaces(indent_, ' ');
+ stream_ << spaces;
+ }
+ }
+
+ void Visit(const Node* node) {
+ Indent();
+ if (node->is_group()) {
+ Visit(static_cast<const GroupNode*>(node));
+ } else {
+ // Primitive
+ Visit(static_cast<const PrimitiveNode*>(node));
+ }
+ }
+
+ void Visit(const PrimitiveNode* node) {
+ PrintRepLevel(node->repetition(), stream_);
+ stream_ << " ";
+ PrintType(node, stream_);
+ stream_ << " field_id=" << node->field_id() << " " << node->name();
+ PrintConvertedType(node, stream_);
+ stream_ << ";" << std::endl;
+ }
+
+ void Visit(const GroupNode* node) {
+ PrintRepLevel(node->repetition(), stream_);
+ stream_ << " group "
+ << "field_id=" << node->field_id() << " " << node->name();
+ auto lt = node->converted_type();
+ auto la = node->logical_type();
+ if (la && la->is_valid() && !la->is_none()) {
+ stream_ << " (" << la->ToString() << ")";
+ } else if (lt != ConvertedType::NONE) {
+ stream_ << " (" << ConvertedTypeToString(lt) << ")";
+ }
+ stream_ << " {" << std::endl;
+
+ indent_ += indent_width_;
+ for (int i = 0; i < node->field_count(); ++i) {
+ node->field(i)->VisitConst(this);
+ }
+ indent_ -= indent_width_;
+ Indent();
+ stream_ << "}" << std::endl;
+ }
+
+ std::ostream& stream_;
+ int indent_;
+ int indent_width_;
+};
+
+void PrintSchema(const Node* schema, std::ostream& stream, int indent_width) {
+ SchemaPrinter printer(stream, indent_width);
+ printer.Visit(schema);
+}
+
+} // namespace schema
+
+using schema::ColumnPath;
+using schema::GroupNode;
+using schema::Node;
+using schema::NodePtr;
+using schema::PrimitiveNode;
+
+void SchemaDescriptor::Init(std::unique_ptr<schema::Node> schema) {
+ Init(NodePtr(schema.release()));
+}
+
+class SchemaUpdater : public Node::Visitor {
+ public:
+ explicit SchemaUpdater(const std::vector<ColumnOrder>& column_orders)
+ : column_orders_(column_orders), leaf_count_(0) {}
+
+ void Visit(Node* node) override {
+ if (node->is_group()) {
+ GroupNode* group_node = static_cast<GroupNode*>(node);
+ for (int i = 0; i < group_node->field_count(); ++i) {
+ group_node->field(i)->Visit(this);
+ }
+ } else { // leaf node
+ PrimitiveNode* leaf_node = static_cast<PrimitiveNode*>(node);
+ leaf_node->SetColumnOrder(column_orders_[leaf_count_++]);
+ }
+ }
+
+ private:
+ const std::vector<ColumnOrder>& column_orders_;
+ int leaf_count_;
+};
+
+void SchemaDescriptor::updateColumnOrders(const std::vector<ColumnOrder>& column_orders) {
+ if (static_cast<int>(column_orders.size()) != num_columns()) {
+ throw ParquetException("Malformed schema: not enough ColumnOrder values");
+ }
+ SchemaUpdater visitor(column_orders);
+ const_cast<GroupNode*>(group_node_)->Visit(&visitor);
+}
+
+void SchemaDescriptor::Init(NodePtr schema) {
+ schema_ = std::move(schema);
+
+ if (!schema_->is_group()) {
+ throw ParquetException("Must initialize with a schema group");
+ }
+
+ group_node_ = static_cast<const GroupNode*>(schema_.get());
+ leaves_.clear();
+
+ for (int i = 0; i < group_node_->field_count(); ++i) {
+ BuildTree(group_node_->field(i), 0, 0, group_node_->field(i));
+ }
+}
+
+bool SchemaDescriptor::Equals(const SchemaDescriptor& other) const {
+ if (this->num_columns() != other.num_columns()) {
+ return false;
+ }
+
+ for (int i = 0; i < this->num_columns(); ++i) {
+ if (!this->Column(i)->Equals(*other.Column(i))) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+void SchemaDescriptor::BuildTree(const NodePtr& node, int16_t max_def_level,
+ int16_t max_rep_level, const NodePtr& base) {
+ if (node->is_optional()) {
+ ++max_def_level;
+ } else if (node->is_repeated()) {
+ // Repeated fields add a definition level. This is used to distinguish
+ // between an empty list and a list with an item in it.
+ ++max_rep_level;
+ ++max_def_level;
+ }
+
+ // Now, walk the schema and create a ColumnDescriptor for each leaf node
+ if (node->is_group()) {
+ const GroupNode* group = static_cast<const GroupNode*>(node.get());
+ for (int i = 0; i < group->field_count(); ++i) {
+ BuildTree(group->field(i), max_def_level, max_rep_level, base);
+ }
+ } else {
+ node_to_leaf_index_[static_cast<const PrimitiveNode*>(node.get())] =
+ static_cast<int>(leaves_.size());
+
+ // Primitive node, append to leaves
+ leaves_.push_back(ColumnDescriptor(node, max_def_level, max_rep_level, this));
+ leaf_to_base_.emplace(static_cast<int>(leaves_.size()) - 1, base);
+ leaf_to_idx_.emplace(node->path()->ToDotString(),
+ static_cast<int>(leaves_.size()) - 1);
+ }
+}
+
+int SchemaDescriptor::GetColumnIndex(const PrimitiveNode& node) const {
+ auto it = node_to_leaf_index_.find(&node);
+ if (it == node_to_leaf_index_.end()) {
+ return -1;
+ }
+ return it->second;
+}
+
+ColumnDescriptor::ColumnDescriptor(schema::NodePtr node, int16_t max_definition_level,
+ int16_t max_repetition_level,
+ const SchemaDescriptor* schema_descr)
+ : node_(std::move(node)),
+ max_definition_level_(max_definition_level),
+ max_repetition_level_(max_repetition_level) {
+ if (!node_->is_primitive()) {
+ throw ParquetException("Must be a primitive type");
+ }
+ primitive_node_ = static_cast<const PrimitiveNode*>(node_.get());
+}
+
+bool ColumnDescriptor::Equals(const ColumnDescriptor& other) const {
+ return primitive_node_->Equals(other.primitive_node_) &&
+ max_repetition_level() == other.max_repetition_level() &&
+ max_definition_level() == other.max_definition_level();
+}
+
+const ColumnDescriptor* SchemaDescriptor::Column(int i) const {
+ DCHECK(i >= 0 && i < static_cast<int>(leaves_.size()));
+ return &leaves_[i];
+}
+
+int SchemaDescriptor::ColumnIndex(const std::string& node_path) const {
+ auto search = leaf_to_idx_.find(node_path);
+ if (search == leaf_to_idx_.end()) {
+ // Not found
+ return -1;
+ }
+ return search->second;
+}
+
+int SchemaDescriptor::ColumnIndex(const Node& node) const {
+ auto search = leaf_to_idx_.equal_range(node.path()->ToDotString());
+ for (auto it = search.first; it != search.second; ++it) {
+ const int idx = it->second;
+ if (&node == Column(idx)->schema_node().get()) {
+ return idx;
+ }
+ }
+ return -1;
+}
+
+const schema::Node* SchemaDescriptor::GetColumnRoot(int i) const {
+ DCHECK(i >= 0 && i < static_cast<int>(leaves_.size()));
+ return leaf_to_base_.find(i)->second.get();
+}
+
+bool SchemaDescriptor::HasRepeatedFields() const {
+ return group_node_->HasRepeatedFields();
+}
+
+std::string SchemaDescriptor::ToString() const {
+ std::ostringstream ss;
+ PrintSchema(schema_.get(), ss);
+ return ss.str();
+}
+
+std::string ColumnDescriptor::ToString() const {
+ std::ostringstream ss;
+ ss << "column descriptor = {" << std::endl
+ << " name: " << name() << "," << std::endl
+ << " path: " << path()->ToDotString() << "," << std::endl
+ << " physical_type: " << TypeToString(physical_type()) << "," << std::endl
+ << " converted_type: " << ConvertedTypeToString(converted_type()) << ","
+ << std::endl
+ << " logical_type: " << logical_type()->ToString() << "," << std::endl
+ << " max_definition_level: " << max_definition_level() << "," << std::endl
+ << " max_repetition_level: " << max_repetition_level() << "," << std::endl;
+
+ if (physical_type() == ::parquet::Type::FIXED_LEN_BYTE_ARRAY) {
+ ss << " length: " << type_length() << "," << std::endl;
+ }
+
+ if (converted_type() == parquet::ConvertedType::DECIMAL) {
+ ss << " precision: " << type_precision() << "," << std::endl
+ << " scale: " << type_scale() << "," << std::endl;
+ }
+
+ ss << "}";
+ return ss.str();
+}
+
+int ColumnDescriptor::type_scale() const {
+ return primitive_node_->decimal_metadata().scale;
+}
+
+int ColumnDescriptor::type_precision() const {
+ return primitive_node_->decimal_metadata().precision;
+}
+
+int ColumnDescriptor::type_length() const { return primitive_node_->type_length(); }
+
+const std::shared_ptr<ColumnPath> ColumnDescriptor::path() const {
+ return primitive_node_->path();
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/schema.h b/src/arrow/cpp/src/parquet/schema.h
new file mode 100644
index 000000000..83d0cf24f
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/schema.h
@@ -0,0 +1,491 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This module contains the logical parquet-cpp types (independent of Thrift
+// structures), schema nodes, and related type tools
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "parquet/platform.h"
+#include "parquet/types.h"
+
+namespace parquet {
+
+class SchemaDescriptor;
+
+namespace schema {
+
+class Node;
+
+// List encodings: using the terminology from Impala to define different styles
+// of representing logical lists (a.k.a. ARRAY types) in Parquet schemas. Since
+// the converted type named in the Parquet metadata is ConvertedType::LIST we
+// use that terminology here. It also helps distinguish from the *_ARRAY
+// primitive types.
+//
+// One-level encoding: Only allows required lists with required cells
+// repeated value_type name
+//
+// Two-level encoding: Enables optional lists with only required cells
+// <required/optional> group list
+// repeated value_type item
+//
+// Three-level encoding: Enables optional lists with optional cells
+// <required/optional> group bag
+// repeated group list
+// <required/optional> value_type item
+//
+// 2- and 1-level encoding are respectively equivalent to 3-level encoding with
+// the non-repeated nodes set to required.
+//
+// The "official" encoding recommended in the Parquet spec is the 3-level, and
+// we use that as the default when creating list types. For semantic completeness
+// we allow the other two. Since all types of encodings will occur "in the
+// wild" we need to be able to interpret the associated definition levels in
+// the context of the actual encoding used in the file.
+//
+// NB: Some Parquet writers may not set ConvertedType::LIST on the repeated
+// SchemaElement, which could make things challenging if we are trying to infer
+// that a sequence of nodes semantically represents an array according to one
+// of these encodings (versus a struct containing an array). We should refuse
+// the temptation to guess, as they say.
+struct ListEncoding {
+ enum type { ONE_LEVEL, TWO_LEVEL, THREE_LEVEL };
+};
+
+class PARQUET_EXPORT ColumnPath {
+ public:
+ ColumnPath() : path_() {}
+ explicit ColumnPath(const std::vector<std::string>& path) : path_(path) {}
+ explicit ColumnPath(std::vector<std::string>&& path) : path_(std::move(path)) {}
+
+ static std::shared_ptr<ColumnPath> FromDotString(const std::string& dotstring);
+ static std::shared_ptr<ColumnPath> FromNode(const Node& node);
+
+ std::shared_ptr<ColumnPath> extend(const std::string& node_name) const;
+ std::string ToDotString() const;
+ const std::vector<std::string>& ToDotVector() const;
+
+ protected:
+ std::vector<std::string> path_;
+};
+
+// Base class for logical schema types. A type has a name, repetition level,
+// and optionally a logical type (ConvertedType in Parquet metadata parlance)
+class PARQUET_EXPORT Node {
+ public:
+ enum type { PRIMITIVE, GROUP };
+
+ virtual ~Node() {}
+
+ bool is_primitive() const { return type_ == Node::PRIMITIVE; }
+
+ bool is_group() const { return type_ == Node::GROUP; }
+
+ bool is_optional() const { return repetition_ == Repetition::OPTIONAL; }
+
+ bool is_repeated() const { return repetition_ == Repetition::REPEATED; }
+
+ bool is_required() const { return repetition_ == Repetition::REQUIRED; }
+
+ virtual bool Equals(const Node* other) const = 0;
+
+ const std::string& name() const { return name_; }
+
+ Node::type node_type() const { return type_; }
+
+ Repetition::type repetition() const { return repetition_; }
+
+ ConvertedType::type converted_type() const { return converted_type_; }
+
+ const std::shared_ptr<const LogicalType>& logical_type() const { return logical_type_; }
+
+ /// \brief The field_id value for the serialized SchemaElement. If the
+ /// field_id is less than 0 (e.g. -1), it will not be set when serialized to
+ /// Thrift.
+ int field_id() const { return field_id_; }
+
+ const Node* parent() const { return parent_; }
+
+ const std::shared_ptr<ColumnPath> path() const;
+
+ virtual void ToParquet(void* element) const = 0;
+
+ // Node::Visitor abstract class for walking schemas with the visitor pattern
+ class Visitor {
+ public:
+ virtual ~Visitor() {}
+
+ virtual void Visit(Node* node) = 0;
+ };
+ class ConstVisitor {
+ public:
+ virtual ~ConstVisitor() {}
+
+ virtual void Visit(const Node* node) = 0;
+ };
+
+ virtual void Visit(Visitor* visitor) = 0;
+ virtual void VisitConst(ConstVisitor* visitor) const = 0;
+
+ protected:
+ friend class GroupNode;
+
+ Node(Node::type type, const std::string& name, Repetition::type repetition,
+ ConvertedType::type converted_type = ConvertedType::NONE, int field_id = -1)
+ : type_(type),
+ name_(name),
+ repetition_(repetition),
+ converted_type_(converted_type),
+ field_id_(field_id),
+ parent_(NULLPTR) {}
+
+ Node(Node::type type, const std::string& name, Repetition::type repetition,
+ std::shared_ptr<const LogicalType> logical_type, int field_id = -1)
+ : type_(type),
+ name_(name),
+ repetition_(repetition),
+ logical_type_(std::move(logical_type)),
+ field_id_(field_id),
+ parent_(NULLPTR) {}
+
+ Node::type type_;
+ std::string name_;
+ Repetition::type repetition_;
+ ConvertedType::type converted_type_;
+ std::shared_ptr<const LogicalType> logical_type_;
+ int field_id_;
+ // Nodes should not be shared, they have a single parent.
+ const Node* parent_;
+
+ bool EqualsInternal(const Node* other) const;
+ void SetParent(const Node* p_parent);
+
+ private:
+ PARQUET_DISALLOW_COPY_AND_ASSIGN(Node);
+};
+
+// Save our breath all over the place with these typedefs
+typedef std::shared_ptr<Node> NodePtr;
+typedef std::vector<NodePtr> NodeVector;
+
+// A type that is one of the primitive Parquet storage types. In addition to
+// the other type metadata (name, repetition level, logical type), also has the
+// physical storage type and their type-specific metadata (byte width, decimal
+// parameters)
+class PARQUET_EXPORT PrimitiveNode : public Node {
+ public:
+ static std::unique_ptr<Node> FromParquet(const void* opaque_element);
+
+ // A field_id -1 (or any negative value) will be serialized as null in Thrift
+ static inline NodePtr Make(const std::string& name, Repetition::type repetition,
+ Type::type type,
+ ConvertedType::type converted_type = ConvertedType::NONE,
+ int length = -1, int precision = -1, int scale = -1,
+ int field_id = -1) {
+ return NodePtr(new PrimitiveNode(name, repetition, type, converted_type, length,
+ precision, scale, field_id));
+ }
+
+ // If no logical type, pass LogicalType::None() or nullptr
+ // A field_id -1 (or any negative value) will be serialized as null in Thrift
+ static inline NodePtr Make(const std::string& name, Repetition::type repetition,
+ std::shared_ptr<const LogicalType> logical_type,
+ Type::type primitive_type, int primitive_length = -1,
+ int field_id = -1) {
+ return NodePtr(new PrimitiveNode(name, repetition, logical_type, primitive_type,
+ primitive_length, field_id));
+ }
+
+ bool Equals(const Node* other) const override;
+
+ Type::type physical_type() const { return physical_type_; }
+
+ ColumnOrder column_order() const { return column_order_; }
+
+ void SetColumnOrder(ColumnOrder column_order) { column_order_ = column_order; }
+
+ int32_t type_length() const { return type_length_; }
+
+ const DecimalMetadata& decimal_metadata() const { return decimal_metadata_; }
+
+ void ToParquet(void* element) const override;
+ void Visit(Visitor* visitor) override;
+ void VisitConst(ConstVisitor* visitor) const override;
+
+ private:
+ PrimitiveNode(const std::string& name, Repetition::type repetition, Type::type type,
+ ConvertedType::type converted_type = ConvertedType::NONE, int length = -1,
+ int precision = -1, int scale = -1, int field_id = -1);
+
+ PrimitiveNode(const std::string& name, Repetition::type repetition,
+ std::shared_ptr<const LogicalType> logical_type,
+ Type::type primitive_type, int primitive_length = -1, int field_id = -1);
+
+ Type::type physical_type_;
+ int32_t type_length_;
+ DecimalMetadata decimal_metadata_;
+ ColumnOrder column_order_;
+
+ // For FIXED_LEN_BYTE_ARRAY
+ void SetTypeLength(int32_t length) { type_length_ = length; }
+
+ bool EqualsInternal(const PrimitiveNode* other) const;
+
+ FRIEND_TEST(TestPrimitiveNode, Attrs);
+ FRIEND_TEST(TestPrimitiveNode, Equals);
+ FRIEND_TEST(TestPrimitiveNode, PhysicalLogicalMapping);
+ FRIEND_TEST(TestPrimitiveNode, FromParquet);
+};
+
+class PARQUET_EXPORT GroupNode : public Node {
+ public:
+ static std::unique_ptr<Node> FromParquet(const void* opaque_element,
+ NodeVector fields = {});
+
+ // A field_id -1 (or any negative value) will be serialized as null in Thrift
+ static inline NodePtr Make(const std::string& name, Repetition::type repetition,
+ const NodeVector& fields,
+ ConvertedType::type converted_type = ConvertedType::NONE,
+ int field_id = -1) {
+ return NodePtr(new GroupNode(name, repetition, fields, converted_type, field_id));
+ }
+
+ // If no logical type, pass nullptr
+ // A field_id -1 (or any negative value) will be serialized as null in Thrift
+ static inline NodePtr Make(const std::string& name, Repetition::type repetition,
+ const NodeVector& fields,
+ std::shared_ptr<const LogicalType> logical_type,
+ int field_id = -1) {
+ return NodePtr(new GroupNode(name, repetition, fields, logical_type, field_id));
+ }
+
+ bool Equals(const Node* other) const override;
+
+ NodePtr field(int i) const { return fields_[i]; }
+ // Get the index of a field by its name, or negative value if not found.
+ // If several fields share the same name, it is unspecified which one
+ // is returned.
+ int FieldIndex(const std::string& name) const;
+ // Get the index of a field by its node, or negative value if not found.
+ int FieldIndex(const Node& node) const;
+
+ int field_count() const { return static_cast<int>(fields_.size()); }
+
+ void ToParquet(void* element) const override;
+ void Visit(Visitor* visitor) override;
+ void VisitConst(ConstVisitor* visitor) const override;
+
+ /// \brief Return true if this node or any child node has REPEATED repetition
+ /// type
+ bool HasRepeatedFields() const;
+
+ private:
+ GroupNode(const std::string& name, Repetition::type repetition,
+ const NodeVector& fields,
+ ConvertedType::type converted_type = ConvertedType::NONE, int field_id = -1);
+
+ GroupNode(const std::string& name, Repetition::type repetition,
+ const NodeVector& fields, std::shared_ptr<const LogicalType> logical_type,
+ int field_id = -1);
+
+ NodeVector fields_;
+ bool EqualsInternal(const GroupNode* other) const;
+
+ // Mapping between field name to the field index
+ std::unordered_multimap<std::string, int> field_name_to_idx_;
+
+ FRIEND_TEST(TestGroupNode, Attrs);
+ FRIEND_TEST(TestGroupNode, Equals);
+ FRIEND_TEST(TestGroupNode, FieldIndex);
+ FRIEND_TEST(TestGroupNode, FieldIndexDuplicateName);
+};
+
+// ----------------------------------------------------------------------
+// Convenience primitive type factory functions
+
+#define PRIMITIVE_FACTORY(FuncName, TYPE) \
+ static inline NodePtr FuncName(const std::string& name, \
+ Repetition::type repetition = Repetition::OPTIONAL, \
+ int field_id = -1) { \
+ return PrimitiveNode::Make(name, repetition, Type::TYPE, ConvertedType::NONE, \
+ /*length=*/-1, /*precision=*/-1, /*scale=*/-1, field_id); \
+ }
+
+PRIMITIVE_FACTORY(Boolean, BOOLEAN)
+PRIMITIVE_FACTORY(Int32, INT32)
+PRIMITIVE_FACTORY(Int64, INT64)
+PRIMITIVE_FACTORY(Int96, INT96)
+PRIMITIVE_FACTORY(Float, FLOAT)
+PRIMITIVE_FACTORY(Double, DOUBLE)
+PRIMITIVE_FACTORY(ByteArray, BYTE_ARRAY)
+
+void PARQUET_EXPORT PrintSchema(const schema::Node* schema, std::ostream& stream,
+ int indent_width = 2);
+
+} // namespace schema
+
+// The ColumnDescriptor encapsulates information necessary to interpret
+// primitive column data in the context of a particular schema. We have to
+// examine the node structure of a column's path to the root in the schema tree
+// to be able to reassemble the nested structure from the repetition and
+// definition levels.
+class PARQUET_EXPORT ColumnDescriptor {
+ public:
+ ColumnDescriptor(schema::NodePtr node, int16_t max_definition_level,
+ int16_t max_repetition_level,
+ const SchemaDescriptor* schema_descr = NULLPTR);
+
+ bool Equals(const ColumnDescriptor& other) const;
+
+ int16_t max_definition_level() const { return max_definition_level_; }
+
+ int16_t max_repetition_level() const { return max_repetition_level_; }
+
+ Type::type physical_type() const { return primitive_node_->physical_type(); }
+
+ ConvertedType::type converted_type() const { return primitive_node_->converted_type(); }
+
+ const std::shared_ptr<const LogicalType>& logical_type() const {
+ return primitive_node_->logical_type();
+ }
+
+ ColumnOrder column_order() const { return primitive_node_->column_order(); }
+
+ SortOrder::type sort_order() const {
+ auto la = logical_type();
+ auto pt = physical_type();
+ return la ? GetSortOrder(la, pt) : GetSortOrder(converted_type(), pt);
+ }
+
+ const std::string& name() const { return primitive_node_->name(); }
+
+ const std::shared_ptr<schema::ColumnPath> path() const;
+
+ const schema::NodePtr& schema_node() const { return node_; }
+
+ std::string ToString() const;
+
+ int type_length() const;
+
+ int type_precision() const;
+
+ int type_scale() const;
+
+ private:
+ schema::NodePtr node_;
+ const schema::PrimitiveNode* primitive_node_;
+
+ int16_t max_definition_level_;
+ int16_t max_repetition_level_;
+};
+
+// Container for the converted Parquet schema with a computed information from
+// the schema analysis needed for file reading
+//
+// * Column index to Node
+// * Max repetition / definition levels for each primitive node
+//
+// The ColumnDescriptor objects produced by this class can be used to assist in
+// the reconstruction of fully materialized data structures from the
+// repetition-definition level encoding of nested data
+//
+// TODO(wesm): this object can be recomputed from a Schema
+class PARQUET_EXPORT SchemaDescriptor {
+ public:
+ SchemaDescriptor() {}
+ ~SchemaDescriptor() {}
+
+ // Analyze the schema
+ void Init(std::unique_ptr<schema::Node> schema);
+ void Init(schema::NodePtr schema);
+
+ const ColumnDescriptor* Column(int i) const;
+
+ // Get the index of a column by its dotstring path, or negative value if not found.
+ // If several columns share the same dotstring path, it is unspecified which one
+ // is returned.
+ int ColumnIndex(const std::string& node_path) const;
+ // Get the index of a column by its node, or negative value if not found.
+ int ColumnIndex(const schema::Node& node) const;
+
+ bool Equals(const SchemaDescriptor& other) const;
+
+ // The number of physical columns appearing in the file
+ int num_columns() const { return static_cast<int>(leaves_.size()); }
+
+ const schema::NodePtr& schema_root() const { return schema_; }
+
+ const schema::GroupNode* group_node() const { return group_node_; }
+
+ // Returns the root (child of the schema root) node of the leaf(column) node
+ const schema::Node* GetColumnRoot(int i) const;
+
+ const std::string& name() const { return group_node_->name(); }
+
+ std::string ToString() const;
+
+ void updateColumnOrders(const std::vector<ColumnOrder>& column_orders);
+
+ /// \brief Return column index corresponding to a particular
+ /// PrimitiveNode. Returns -1 if not found
+ int GetColumnIndex(const schema::PrimitiveNode& node) const;
+
+ /// \brief Return true if any field or their children have REPEATED repetition
+ /// type
+ bool HasRepeatedFields() const;
+
+ private:
+ friend class ColumnDescriptor;
+
+ // Root Node
+ schema::NodePtr schema_;
+ // Root Node
+ const schema::GroupNode* group_node_;
+
+ void BuildTree(const schema::NodePtr& node, int16_t max_def_level,
+ int16_t max_rep_level, const schema::NodePtr& base);
+
+ // Result of leaf node / tree analysis
+ std::vector<ColumnDescriptor> leaves_;
+
+ std::unordered_map<const schema::PrimitiveNode*, int> node_to_leaf_index_;
+
+ // Mapping between leaf nodes and root group of leaf (first node
+ // below the schema's root group)
+ //
+ // For example, the leaf `a.b.c.d` would have a link back to `a`
+ //
+ // -- a <------
+ // -- -- b |
+ // -- -- -- c |
+ // -- -- -- -- d
+ std::unordered_map<int, schema::NodePtr> leaf_to_base_;
+
+ // Mapping between ColumnPath DotString to the leaf index
+ std::unordered_multimap<std::string, int> leaf_to_idx_;
+};
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/schema_internal.h b/src/arrow/cpp/src/parquet/schema_internal.h
new file mode 100644
index 000000000..c0cfffc87
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/schema_internal.h
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Non-public Thrift schema serialization utilities
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "parquet/platform.h"
+#include "parquet/schema.h"
+#include "parquet/types.h"
+
+namespace parquet {
+
+namespace format {
+class SchemaElement;
+}
+
+namespace schema {
+
+// ----------------------------------------------------------------------
+// Conversion from Parquet Thrift metadata
+
+PARQUET_EXPORT
+std::shared_ptr<SchemaDescriptor> FromParquet(
+ const std::vector<format::SchemaElement>& schema);
+
+PARQUET_EXPORT
+std::unique_ptr<Node> Unflatten(const format::SchemaElement* elements, int length);
+
+// ----------------------------------------------------------------------
+// Conversion to Parquet Thrift metadata
+
+PARQUET_EXPORT
+void ToParquet(const GroupNode* schema, std::vector<format::SchemaElement>* out);
+
+} // namespace schema
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/schema_test.cc b/src/arrow/cpp/src/parquet/schema_test.cc
new file mode 100644
index 000000000..703bac810
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/schema_test.cc
@@ -0,0 +1,2226 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <cstdlib>
+#include <cstring>
+#include <functional>
+#include <iosfwd>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/util/checked_cast.h"
+#include "parquet/exception.h"
+#include "parquet/schema.h"
+#include "parquet/schema_internal.h"
+#include "parquet/thrift_internal.h"
+#include "parquet/types.h"
+
+using ::arrow::internal::checked_cast;
+
+namespace parquet {
+
+using format::FieldRepetitionType;
+using format::SchemaElement;
+
+namespace schema {
+
+static inline SchemaElement NewPrimitive(const std::string& name,
+ FieldRepetitionType::type repetition,
+ Type::type type, int field_id = -1) {
+ SchemaElement result;
+ result.__set_name(name);
+ result.__set_repetition_type(repetition);
+ result.__set_type(static_cast<format::Type::type>(type));
+ if (field_id >= 0) {
+ result.__set_field_id(field_id);
+ }
+ return result;
+}
+
+static inline SchemaElement NewGroup(const std::string& name,
+ FieldRepetitionType::type repetition,
+ int num_children, int field_id = -1) {
+ SchemaElement result;
+ result.__set_name(name);
+ result.__set_repetition_type(repetition);
+ result.__set_num_children(num_children);
+
+ if (field_id >= 0) {
+ result.__set_field_id(field_id);
+ }
+
+ return result;
+}
+
+template <typename NodeType>
+static void CheckNodeRoundtrip(const Node& node) {
+ format::SchemaElement serialized;
+ node.ToParquet(&serialized);
+ std::unique_ptr<Node> recovered = NodeType::FromParquet(&serialized);
+ ASSERT_TRUE(node.Equals(recovered.get()))
+ << "Recovered node not equivalent to original node constructed "
+ << "with logical type " << node.logical_type()->ToString() << " got "
+ << recovered->logical_type()->ToString();
+}
+
+static void ConfirmPrimitiveNodeRoundtrip(
+ const std::shared_ptr<const LogicalType>& logical_type, Type::type physical_type,
+ int physical_length, int field_id = -1) {
+ auto node = PrimitiveNode::Make("something", Repetition::REQUIRED, logical_type,
+ physical_type, physical_length, field_id);
+ CheckNodeRoundtrip<PrimitiveNode>(*node);
+}
+
+static void ConfirmGroupNodeRoundtrip(
+ std::string name, const std::shared_ptr<const LogicalType>& logical_type,
+ int field_id = -1) {
+ auto node = GroupNode::Make(name, Repetition::REQUIRED, {}, logical_type, field_id);
+ CheckNodeRoundtrip<GroupNode>(*node);
+}
+
+// ----------------------------------------------------------------------
+// ColumnPath
+
+TEST(TestColumnPath, TestAttrs) {
+ ColumnPath path(std::vector<std::string>({"toplevel", "leaf"}));
+
+ ASSERT_EQ(path.ToDotString(), "toplevel.leaf");
+
+ std::shared_ptr<ColumnPath> path_ptr = ColumnPath::FromDotString("toplevel.leaf");
+ ASSERT_EQ(path_ptr->ToDotString(), "toplevel.leaf");
+
+ std::shared_ptr<ColumnPath> extended = path_ptr->extend("anotherlevel");
+ ASSERT_EQ(extended->ToDotString(), "toplevel.leaf.anotherlevel");
+}
+
+// ----------------------------------------------------------------------
+// Primitive node
+
+class TestPrimitiveNode : public ::testing::Test {
+ public:
+ void SetUp() {
+ name_ = "name";
+ field_id_ = 5;
+ }
+
+ void Convert(const format::SchemaElement* element) {
+ node_ = PrimitiveNode::FromParquet(element);
+ ASSERT_TRUE(node_->is_primitive());
+ prim_node_ = static_cast<const PrimitiveNode*>(node_.get());
+ }
+
+ protected:
+ std::string name_;
+ const PrimitiveNode* prim_node_;
+
+ int field_id_;
+ std::unique_ptr<Node> node_;
+};
+
+TEST_F(TestPrimitiveNode, Attrs) {
+ PrimitiveNode node1("foo", Repetition::REPEATED, Type::INT32);
+
+ PrimitiveNode node2("bar", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::UTF8);
+
+ ASSERT_EQ("foo", node1.name());
+
+ ASSERT_TRUE(node1.is_primitive());
+ ASSERT_FALSE(node1.is_group());
+
+ ASSERT_EQ(Repetition::REPEATED, node1.repetition());
+ ASSERT_EQ(Repetition::OPTIONAL, node2.repetition());
+
+ ASSERT_EQ(Node::PRIMITIVE, node1.node_type());
+
+ ASSERT_EQ(Type::INT32, node1.physical_type());
+ ASSERT_EQ(Type::BYTE_ARRAY, node2.physical_type());
+
+ // logical types
+ ASSERT_EQ(ConvertedType::NONE, node1.converted_type());
+ ASSERT_EQ(ConvertedType::UTF8, node2.converted_type());
+
+ // repetition
+ PrimitiveNode node3("foo", Repetition::REPEATED, Type::INT32);
+ PrimitiveNode node4("foo", Repetition::REQUIRED, Type::INT32);
+ PrimitiveNode node5("foo", Repetition::OPTIONAL, Type::INT32);
+
+ ASSERT_TRUE(node3.is_repeated());
+ ASSERT_FALSE(node3.is_optional());
+
+ ASSERT_TRUE(node4.is_required());
+
+ ASSERT_TRUE(node5.is_optional());
+ ASSERT_FALSE(node5.is_required());
+}
+
+TEST_F(TestPrimitiveNode, FromParquet) {
+ SchemaElement elt =
+ NewPrimitive(name_, FieldRepetitionType::OPTIONAL, Type::INT32, field_id_);
+ ASSERT_NO_FATAL_FAILURE(Convert(&elt));
+ ASSERT_EQ(name_, prim_node_->name());
+ ASSERT_EQ(field_id_, prim_node_->field_id());
+ ASSERT_EQ(Repetition::OPTIONAL, prim_node_->repetition());
+ ASSERT_EQ(Type::INT32, prim_node_->physical_type());
+ ASSERT_EQ(ConvertedType::NONE, prim_node_->converted_type());
+
+ // Test a logical type
+ elt = NewPrimitive(name_, FieldRepetitionType::REQUIRED, Type::BYTE_ARRAY, field_id_);
+ elt.__set_converted_type(format::ConvertedType::UTF8);
+
+ ASSERT_NO_FATAL_FAILURE(Convert(&elt));
+ ASSERT_EQ(Repetition::REQUIRED, prim_node_->repetition());
+ ASSERT_EQ(Type::BYTE_ARRAY, prim_node_->physical_type());
+ ASSERT_EQ(ConvertedType::UTF8, prim_node_->converted_type());
+
+ // FIXED_LEN_BYTE_ARRAY
+ elt = NewPrimitive(name_, FieldRepetitionType::OPTIONAL, Type::FIXED_LEN_BYTE_ARRAY,
+ field_id_);
+ elt.__set_type_length(16);
+
+ ASSERT_NO_FATAL_FAILURE(Convert(&elt));
+ ASSERT_EQ(name_, prim_node_->name());
+ ASSERT_EQ(field_id_, prim_node_->field_id());
+ ASSERT_EQ(Repetition::OPTIONAL, prim_node_->repetition());
+ ASSERT_EQ(Type::FIXED_LEN_BYTE_ARRAY, prim_node_->physical_type());
+ ASSERT_EQ(16, prim_node_->type_length());
+
+ // format::ConvertedType::Decimal
+ elt = NewPrimitive(name_, FieldRepetitionType::OPTIONAL, Type::FIXED_LEN_BYTE_ARRAY,
+ field_id_);
+ elt.__set_converted_type(format::ConvertedType::DECIMAL);
+ elt.__set_type_length(6);
+ elt.__set_scale(2);
+ elt.__set_precision(12);
+
+ ASSERT_NO_FATAL_FAILURE(Convert(&elt));
+ ASSERT_EQ(Type::FIXED_LEN_BYTE_ARRAY, prim_node_->physical_type());
+ ASSERT_EQ(ConvertedType::DECIMAL, prim_node_->converted_type());
+ ASSERT_EQ(6, prim_node_->type_length());
+ ASSERT_EQ(2, prim_node_->decimal_metadata().scale);
+ ASSERT_EQ(12, prim_node_->decimal_metadata().precision);
+}
+
+TEST_F(TestPrimitiveNode, Equals) {
+ PrimitiveNode node1("foo", Repetition::REQUIRED, Type::INT32);
+ PrimitiveNode node2("foo", Repetition::REQUIRED, Type::INT64);
+ PrimitiveNode node3("bar", Repetition::REQUIRED, Type::INT32);
+ PrimitiveNode node4("foo", Repetition::OPTIONAL, Type::INT32);
+ PrimitiveNode node5("foo", Repetition::REQUIRED, Type::INT32);
+
+ ASSERT_TRUE(node1.Equals(&node1));
+ ASSERT_FALSE(node1.Equals(&node2));
+ ASSERT_FALSE(node1.Equals(&node3));
+ ASSERT_FALSE(node1.Equals(&node4));
+ ASSERT_TRUE(node1.Equals(&node5));
+
+ PrimitiveNode flba1("foo", Repetition::REQUIRED, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::DECIMAL, 12, 4, 2);
+
+ PrimitiveNode flba2("foo", Repetition::REQUIRED, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::DECIMAL, 1, 4, 2);
+ flba2.SetTypeLength(12);
+
+ PrimitiveNode flba3("foo", Repetition::REQUIRED, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::DECIMAL, 1, 4, 2);
+ flba3.SetTypeLength(16);
+
+ PrimitiveNode flba4("foo", Repetition::REQUIRED, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::DECIMAL, 12, 4, 0);
+
+ PrimitiveNode flba5("foo", Repetition::REQUIRED, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::NONE, 12, 4, 0);
+
+ ASSERT_TRUE(flba1.Equals(&flba2));
+ ASSERT_FALSE(flba1.Equals(&flba3));
+ ASSERT_FALSE(flba1.Equals(&flba4));
+ ASSERT_FALSE(flba1.Equals(&flba5));
+}
+
+TEST_F(TestPrimitiveNode, PhysicalLogicalMapping) {
+ ASSERT_NO_THROW(PrimitiveNode::Make("foo", Repetition::REQUIRED, Type::INT32,
+ ConvertedType::INT_32));
+ ASSERT_NO_THROW(PrimitiveNode::Make("foo", Repetition::REQUIRED, Type::BYTE_ARRAY,
+ ConvertedType::JSON));
+ ASSERT_THROW(
+ PrimitiveNode::Make("foo", Repetition::REQUIRED, Type::INT32, ConvertedType::JSON),
+ ParquetException);
+ ASSERT_NO_THROW(PrimitiveNode::Make("foo", Repetition::REQUIRED, Type::INT64,
+ ConvertedType::TIMESTAMP_MILLIS));
+ ASSERT_THROW(PrimitiveNode::Make("foo", Repetition::REQUIRED, Type::INT32,
+ ConvertedType::INT_64),
+ ParquetException);
+ ASSERT_THROW(PrimitiveNode::Make("foo", Repetition::REQUIRED, Type::BYTE_ARRAY,
+ ConvertedType::INT_8),
+ ParquetException);
+ ASSERT_THROW(PrimitiveNode::Make("foo", Repetition::REQUIRED, Type::BYTE_ARRAY,
+ ConvertedType::INTERVAL),
+ ParquetException);
+ ASSERT_THROW(PrimitiveNode::Make("foo", Repetition::REQUIRED,
+ Type::FIXED_LEN_BYTE_ARRAY, ConvertedType::ENUM),
+ ParquetException);
+ ASSERT_NO_THROW(PrimitiveNode::Make("foo", Repetition::REQUIRED, Type::BYTE_ARRAY,
+ ConvertedType::ENUM));
+ ASSERT_THROW(
+ PrimitiveNode::Make("foo", Repetition::REQUIRED, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::DECIMAL, 0, 2, 4),
+ ParquetException);
+ ASSERT_THROW(PrimitiveNode::Make("foo", Repetition::REQUIRED, Type::FLOAT,
+ ConvertedType::DECIMAL, 0, 2, 4),
+ ParquetException);
+ ASSERT_THROW(
+ PrimitiveNode::Make("foo", Repetition::REQUIRED, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::DECIMAL, 0, 4, 0),
+ ParquetException);
+ ASSERT_THROW(
+ PrimitiveNode::Make("foo", Repetition::REQUIRED, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::DECIMAL, 10, 0, 4),
+ ParquetException);
+ ASSERT_THROW(
+ PrimitiveNode::Make("foo", Repetition::REQUIRED, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::DECIMAL, 10, 4, -1),
+ ParquetException);
+ ASSERT_THROW(
+ PrimitiveNode::Make("foo", Repetition::REQUIRED, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::DECIMAL, 10, 2, 4),
+ ParquetException);
+ ASSERT_NO_THROW(PrimitiveNode::Make("foo", Repetition::REQUIRED,
+ Type::FIXED_LEN_BYTE_ARRAY, ConvertedType::DECIMAL,
+ 10, 6, 4));
+ ASSERT_NO_THROW(PrimitiveNode::Make("foo", Repetition::REQUIRED,
+ Type::FIXED_LEN_BYTE_ARRAY, ConvertedType::INTERVAL,
+ 12));
+ ASSERT_THROW(
+ PrimitiveNode::Make("foo", Repetition::REQUIRED, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::INTERVAL, 10),
+ ParquetException);
+}
+
+// ----------------------------------------------------------------------
+// Group node
+
+class TestGroupNode : public ::testing::Test {
+ public:
+ NodeVector Fields1() {
+ NodeVector fields;
+
+ fields.push_back(Int32("one", Repetition::REQUIRED));
+ fields.push_back(Int64("two"));
+ fields.push_back(Double("three"));
+
+ return fields;
+ }
+
+ NodeVector Fields2() {
+ // Fields with a duplicate name
+ NodeVector fields;
+
+ fields.push_back(Int32("duplicate", Repetition::REQUIRED));
+ fields.push_back(Int64("unique"));
+ fields.push_back(Double("duplicate"));
+
+ return fields;
+ }
+};
+
+TEST_F(TestGroupNode, Attrs) {
+ NodeVector fields = Fields1();
+
+ GroupNode node1("foo", Repetition::REPEATED, fields);
+ GroupNode node2("bar", Repetition::OPTIONAL, fields, ConvertedType::LIST);
+
+ ASSERT_EQ("foo", node1.name());
+
+ ASSERT_TRUE(node1.is_group());
+ ASSERT_FALSE(node1.is_primitive());
+
+ ASSERT_EQ(fields.size(), node1.field_count());
+
+ ASSERT_TRUE(node1.is_repeated());
+ ASSERT_TRUE(node2.is_optional());
+
+ ASSERT_EQ(Repetition::REPEATED, node1.repetition());
+ ASSERT_EQ(Repetition::OPTIONAL, node2.repetition());
+
+ ASSERT_EQ(Node::GROUP, node1.node_type());
+
+ // logical types
+ ASSERT_EQ(ConvertedType::NONE, node1.converted_type());
+ ASSERT_EQ(ConvertedType::LIST, node2.converted_type());
+}
+
+TEST_F(TestGroupNode, Equals) {
+ NodeVector f1 = Fields1();
+ NodeVector f2 = Fields1();
+
+ GroupNode group1("group", Repetition::REPEATED, f1);
+ GroupNode group2("group", Repetition::REPEATED, f2);
+ GroupNode group3("group2", Repetition::REPEATED, f2);
+
+ // This is copied in the GroupNode ctor, so this is okay
+ f2.push_back(Float("four", Repetition::OPTIONAL));
+ GroupNode group4("group", Repetition::REPEATED, f2);
+ GroupNode group5("group", Repetition::REPEATED, Fields1());
+
+ ASSERT_TRUE(group1.Equals(&group1));
+ ASSERT_TRUE(group1.Equals(&group2));
+ ASSERT_FALSE(group1.Equals(&group3));
+
+ ASSERT_FALSE(group1.Equals(&group4));
+ ASSERT_FALSE(group5.Equals(&group4));
+}
+
+TEST_F(TestGroupNode, FieldIndex) {
+ NodeVector fields = Fields1();
+ GroupNode group("group", Repetition::REQUIRED, fields);
+ for (size_t i = 0; i < fields.size(); i++) {
+ auto field = group.field(static_cast<int>(i));
+ ASSERT_EQ(i, group.FieldIndex(*field));
+ }
+
+ // Test a non field node
+ auto non_field_alien = Int32("alien", Repetition::REQUIRED); // other name
+ auto non_field_familiar = Int32("one", Repetition::REPEATED); // other node
+ ASSERT_LT(group.FieldIndex(*non_field_alien), 0);
+ ASSERT_LT(group.FieldIndex(*non_field_familiar), 0);
+}
+
+TEST_F(TestGroupNode, FieldIndexDuplicateName) {
+ NodeVector fields = Fields2();
+ GroupNode group("group", Repetition::REQUIRED, fields);
+ for (size_t i = 0; i < fields.size(); i++) {
+ auto field = group.field(static_cast<int>(i));
+ ASSERT_EQ(i, group.FieldIndex(*field));
+ }
+}
+
+// ----------------------------------------------------------------------
+// Test convert group
+
+class TestSchemaConverter : public ::testing::Test {
+ public:
+ void setUp() { name_ = "parquet_schema"; }
+
+ void Convert(const parquet::format::SchemaElement* elements, int length) {
+ node_ = Unflatten(elements, length);
+ ASSERT_TRUE(node_->is_group());
+ group_ = static_cast<const GroupNode*>(node_.get());
+ }
+
+ protected:
+ std::string name_;
+ const GroupNode* group_;
+ std::unique_ptr<Node> node_;
+};
+
+bool check_for_parent_consistency(const GroupNode* node) {
+ // Each node should have the group as parent
+ for (int i = 0; i < node->field_count(); i++) {
+ const NodePtr& field = node->field(i);
+ if (field->parent() != node) {
+ return false;
+ }
+ if (field->is_group()) {
+ const GroupNode* group = static_cast<GroupNode*>(field.get());
+ if (!check_for_parent_consistency(group)) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+TEST_F(TestSchemaConverter, NestedExample) {
+ SchemaElement elt;
+ std::vector<SchemaElement> elements;
+ elements.push_back(NewGroup(name_, FieldRepetitionType::REPEATED, /*num_children=*/2,
+ /*field_id=*/0));
+
+ // A primitive one
+ elements.push_back(NewPrimitive("a", FieldRepetitionType::REQUIRED, Type::INT32, 1));
+
+ // A group
+ elements.push_back(NewGroup("bag", FieldRepetitionType::OPTIONAL, 1, 2));
+
+ // 3-level list encoding, by hand
+ elt = NewGroup("b", FieldRepetitionType::REPEATED, 1, 3);
+ elt.__set_converted_type(format::ConvertedType::LIST);
+ elements.push_back(elt);
+ elements.push_back(NewPrimitive("item", FieldRepetitionType::OPTIONAL, Type::INT64, 4));
+
+ ASSERT_NO_FATAL_FAILURE(Convert(&elements[0], static_cast<int>(elements.size())));
+
+ // Construct the expected schema
+ NodeVector fields;
+ fields.push_back(Int32("a", Repetition::REQUIRED, 1));
+
+ // 3-level list encoding
+ NodePtr item = Int64("item", Repetition::OPTIONAL, 4);
+ NodePtr list(
+ GroupNode::Make("b", Repetition::REPEATED, {item}, ConvertedType::LIST, 3));
+ NodePtr bag(
+ GroupNode::Make("bag", Repetition::OPTIONAL, {list}, /*logical_type=*/nullptr, 2));
+ fields.push_back(bag);
+
+ NodePtr schema = GroupNode::Make(name_, Repetition::REPEATED, fields,
+ /*logical_type=*/nullptr, 0);
+
+ ASSERT_TRUE(schema->Equals(group_));
+
+ // Check that the parent relationship in each node is consistent
+ ASSERT_EQ(group_->parent(), nullptr);
+ ASSERT_TRUE(check_for_parent_consistency(group_));
+}
+
+TEST_F(TestSchemaConverter, ZeroColumns) {
+ // ARROW-3843
+ SchemaElement elements[1];
+ elements[0] = NewGroup("schema", FieldRepetitionType::REPEATED, 0, 0);
+ ASSERT_NO_THROW(Convert(elements, 1));
+}
+
+TEST_F(TestSchemaConverter, InvalidRoot) {
+ // According to the Parquet specification, the first element in the
+ // list<SchemaElement> is a group whose children (and their descendants)
+ // contain all of the rest of the flattened schema elements. If the first
+ // element is not a group, it is a malformed Parquet file.
+
+ SchemaElement elements[2];
+ elements[0] =
+ NewPrimitive("not-a-group", FieldRepetitionType::REQUIRED, Type::INT32, 0);
+ ASSERT_THROW(Convert(elements, 2), ParquetException);
+
+ // While the Parquet spec indicates that the root group should have REPEATED
+ // repetition type, some implementations may return REQUIRED or OPTIONAL
+ // groups as the first element. These tests check that this is okay as a
+ // practicality matter.
+ elements[0] = NewGroup("not-repeated", FieldRepetitionType::REQUIRED, 1, 0);
+ elements[1] = NewPrimitive("a", FieldRepetitionType::REQUIRED, Type::INT32, 1);
+ ASSERT_NO_FATAL_FAILURE(Convert(elements, 2));
+
+ elements[0] = NewGroup("not-repeated", FieldRepetitionType::OPTIONAL, 1, 0);
+ ASSERT_NO_FATAL_FAILURE(Convert(elements, 2));
+}
+
+TEST_F(TestSchemaConverter, NotEnoughChildren) {
+ // Throw a ParquetException, but don't core dump or anything
+ SchemaElement elt;
+ std::vector<SchemaElement> elements;
+ elements.push_back(NewGroup(name_, FieldRepetitionType::REPEATED, 2, 0));
+ ASSERT_THROW(Convert(&elements[0], 1), ParquetException);
+}
+
+// ----------------------------------------------------------------------
+// Schema tree flatten / unflatten
+
+class TestSchemaFlatten : public ::testing::Test {
+ public:
+ void setUp() { name_ = "parquet_schema"; }
+
+ void Flatten(const GroupNode* schema) { ToParquet(schema, &elements_); }
+
+ protected:
+ std::string name_;
+ std::vector<format::SchemaElement> elements_;
+};
+
+TEST_F(TestSchemaFlatten, DecimalMetadata) {
+ // Checks that DecimalMetadata is only set for DecimalTypes
+ NodePtr node = PrimitiveNode::Make("decimal", Repetition::REQUIRED, Type::INT64,
+ ConvertedType::DECIMAL, -1, 8, 4);
+ NodePtr group =
+ GroupNode::Make("group", Repetition::REPEATED, {node}, ConvertedType::LIST);
+ Flatten(reinterpret_cast<GroupNode*>(group.get()));
+ ASSERT_EQ("decimal", elements_[1].name);
+ ASSERT_TRUE(elements_[1].__isset.precision);
+ ASSERT_TRUE(elements_[1].__isset.scale);
+
+ elements_.clear();
+ // ... including those created with new logical types
+ node = PrimitiveNode::Make("decimal", Repetition::REQUIRED,
+ DecimalLogicalType::Make(10, 5), Type::INT64, -1);
+ group = GroupNode::Make("group", Repetition::REPEATED, {node}, ListLogicalType::Make());
+ Flatten(reinterpret_cast<GroupNode*>(group.get()));
+ ASSERT_EQ("decimal", elements_[1].name);
+ ASSERT_TRUE(elements_[1].__isset.precision);
+ ASSERT_TRUE(elements_[1].__isset.scale);
+
+ elements_.clear();
+ // Not for integers with no logical type
+ group = GroupNode::Make("group", Repetition::REPEATED, {Int64("int64")},
+ ConvertedType::LIST);
+ Flatten(reinterpret_cast<GroupNode*>(group.get()));
+ ASSERT_EQ("int64", elements_[1].name);
+ ASSERT_FALSE(elements_[0].__isset.precision);
+ ASSERT_FALSE(elements_[0].__isset.scale);
+}
+
+TEST_F(TestSchemaFlatten, NestedExample) {
+ SchemaElement elt;
+ std::vector<SchemaElement> elements;
+ elements.push_back(NewGroup(name_, FieldRepetitionType::REPEATED, 2, 0));
+
+ // A primitive one
+ elements.push_back(NewPrimitive("a", FieldRepetitionType::REQUIRED, Type::INT32, 1));
+
+ // A group
+ elements.push_back(NewGroup("bag", FieldRepetitionType::OPTIONAL, 1, 2));
+
+ // 3-level list encoding, by hand
+ elt = NewGroup("b", FieldRepetitionType::REPEATED, 1, 3);
+ elt.__set_converted_type(format::ConvertedType::LIST);
+ format::ListType ls;
+ format::LogicalType lt;
+ lt.__set_LIST(ls);
+ elt.__set_logicalType(lt);
+ elements.push_back(elt);
+ elements.push_back(NewPrimitive("item", FieldRepetitionType::OPTIONAL, Type::INT64, 4));
+
+ // Construct the schema
+ NodeVector fields;
+ fields.push_back(Int32("a", Repetition::REQUIRED, 1));
+
+ // 3-level list encoding
+ NodePtr item = Int64("item", Repetition::OPTIONAL, 4);
+ NodePtr list(
+ GroupNode::Make("b", Repetition::REPEATED, {item}, ConvertedType::LIST, 3));
+ NodePtr bag(GroupNode::Make("bag", Repetition::OPTIONAL, {list},
+ /*logical_type=*/nullptr, 2));
+ fields.push_back(bag);
+
+ NodePtr schema = GroupNode::Make(name_, Repetition::REPEATED, fields,
+ /*logical_type=*/nullptr, 0);
+
+ Flatten(static_cast<GroupNode*>(schema.get()));
+ ASSERT_EQ(elements_.size(), elements.size());
+ for (size_t i = 0; i < elements_.size(); i++) {
+ ASSERT_EQ(elements_[i], elements[i]);
+ }
+}
+
+TEST(TestColumnDescriptor, TestAttrs) {
+ NodePtr node = PrimitiveNode::Make("name", Repetition::OPTIONAL, Type::BYTE_ARRAY,
+ ConvertedType::UTF8);
+ ColumnDescriptor descr(node, 4, 1);
+
+ ASSERT_EQ("name", descr.name());
+ ASSERT_EQ(4, descr.max_definition_level());
+ ASSERT_EQ(1, descr.max_repetition_level());
+
+ ASSERT_EQ(Type::BYTE_ARRAY, descr.physical_type());
+
+ ASSERT_EQ(-1, descr.type_length());
+ const char* expected_descr = R"(column descriptor = {
+ name: name,
+ path: ,
+ physical_type: BYTE_ARRAY,
+ converted_type: UTF8,
+ logical_type: String,
+ max_definition_level: 4,
+ max_repetition_level: 1,
+})";
+ ASSERT_EQ(expected_descr, descr.ToString());
+
+ // Test FIXED_LEN_BYTE_ARRAY
+ node = PrimitiveNode::Make("name", Repetition::OPTIONAL, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::DECIMAL, 12, 10, 4);
+ ColumnDescriptor descr2(node, 4, 1);
+
+ ASSERT_EQ(Type::FIXED_LEN_BYTE_ARRAY, descr2.physical_type());
+ ASSERT_EQ(12, descr2.type_length());
+
+ expected_descr = R"(column descriptor = {
+ name: name,
+ path: ,
+ physical_type: FIXED_LEN_BYTE_ARRAY,
+ converted_type: DECIMAL,
+ logical_type: Decimal(precision=10, scale=4),
+ max_definition_level: 4,
+ max_repetition_level: 1,
+ length: 12,
+ precision: 10,
+ scale: 4,
+})";
+ ASSERT_EQ(expected_descr, descr2.ToString());
+}
+
+class TestSchemaDescriptor : public ::testing::Test {
+ public:
+ void setUp() {}
+
+ protected:
+ SchemaDescriptor descr_;
+};
+
+TEST_F(TestSchemaDescriptor, InitNonGroup) {
+ NodePtr node = PrimitiveNode::Make("field", Repetition::OPTIONAL, Type::INT32);
+
+ ASSERT_THROW(descr_.Init(node), ParquetException);
+}
+
+TEST_F(TestSchemaDescriptor, Equals) {
+ NodePtr schema;
+
+ NodePtr inta = Int32("a", Repetition::REQUIRED);
+ NodePtr intb = Int64("b", Repetition::OPTIONAL);
+ NodePtr intb2 = Int64("b2", Repetition::OPTIONAL);
+ NodePtr intc = ByteArray("c", Repetition::REPEATED);
+
+ NodePtr item1 = Int64("item1", Repetition::REQUIRED);
+ NodePtr item2 = Boolean("item2", Repetition::OPTIONAL);
+ NodePtr item3 = Int32("item3", Repetition::REPEATED);
+ NodePtr list(GroupNode::Make("records", Repetition::REPEATED, {item1, item2, item3},
+ ConvertedType::LIST));
+
+ NodePtr bag(GroupNode::Make("bag", Repetition::OPTIONAL, {list}));
+ NodePtr bag2(GroupNode::Make("bag", Repetition::REQUIRED, {list}));
+
+ SchemaDescriptor descr1;
+ descr1.Init(GroupNode::Make("schema", Repetition::REPEATED, {inta, intb, intc, bag}));
+
+ ASSERT_TRUE(descr1.Equals(descr1));
+
+ SchemaDescriptor descr2;
+ descr2.Init(GroupNode::Make("schema", Repetition::REPEATED, {inta, intb, intc, bag2}));
+ ASSERT_FALSE(descr1.Equals(descr2));
+
+ SchemaDescriptor descr3;
+ descr3.Init(GroupNode::Make("schema", Repetition::REPEATED, {inta, intb2, intc, bag}));
+ ASSERT_FALSE(descr1.Equals(descr3));
+
+ // Robust to name of parent node
+ SchemaDescriptor descr4;
+ descr4.Init(GroupNode::Make("SCHEMA", Repetition::REPEATED, {inta, intb, intc, bag}));
+ ASSERT_TRUE(descr1.Equals(descr4));
+
+ SchemaDescriptor descr5;
+ descr5.Init(
+ GroupNode::Make("schema", Repetition::REPEATED, {inta, intb, intc, bag, intb2}));
+ ASSERT_FALSE(descr1.Equals(descr5));
+
+ // Different max repetition / definition levels
+ ColumnDescriptor col1(inta, 5, 1);
+ ColumnDescriptor col2(inta, 6, 1);
+ ColumnDescriptor col3(inta, 5, 2);
+
+ ASSERT_TRUE(col1.Equals(col1));
+ ASSERT_FALSE(col1.Equals(col2));
+ ASSERT_FALSE(col1.Equals(col3));
+}
+
+TEST_F(TestSchemaDescriptor, BuildTree) {
+ NodeVector fields;
+ NodePtr schema;
+
+ NodePtr inta = Int32("a", Repetition::REQUIRED);
+ fields.push_back(inta);
+ fields.push_back(Int64("b", Repetition::OPTIONAL));
+ fields.push_back(ByteArray("c", Repetition::REPEATED));
+
+ // 3-level list encoding
+ NodePtr item1 = Int64("item1", Repetition::REQUIRED);
+ NodePtr item2 = Boolean("item2", Repetition::OPTIONAL);
+ NodePtr item3 = Int32("item3", Repetition::REPEATED);
+ NodePtr list(GroupNode::Make("records", Repetition::REPEATED, {item1, item2, item3},
+ ConvertedType::LIST));
+ NodePtr bag(GroupNode::Make("bag", Repetition::OPTIONAL, {list}));
+ fields.push_back(bag);
+
+ schema = GroupNode::Make("schema", Repetition::REPEATED, fields);
+
+ descr_.Init(schema);
+
+ int nleaves = 6;
+
+ // 6 leaves
+ ASSERT_EQ(nleaves, descr_.num_columns());
+
+ // mdef mrep
+ // required int32 a 0 0
+ // optional int64 b 1 0
+ // repeated byte_array c 1 1
+ // optional group bag 1 0
+ // repeated group records 2 1
+ // required int64 item1 2 1
+ // optional boolean item2 3 1
+ // repeated int32 item3 3 2
+ int16_t ex_max_def_levels[6] = {0, 1, 1, 2, 3, 3};
+ int16_t ex_max_rep_levels[6] = {0, 0, 1, 1, 1, 2};
+
+ for (int i = 0; i < nleaves; ++i) {
+ const ColumnDescriptor* col = descr_.Column(i);
+ EXPECT_EQ(ex_max_def_levels[i], col->max_definition_level()) << i;
+ EXPECT_EQ(ex_max_rep_levels[i], col->max_repetition_level()) << i;
+ }
+
+ ASSERT_EQ(descr_.Column(0)->path()->ToDotString(), "a");
+ ASSERT_EQ(descr_.Column(1)->path()->ToDotString(), "b");
+ ASSERT_EQ(descr_.Column(2)->path()->ToDotString(), "c");
+ ASSERT_EQ(descr_.Column(3)->path()->ToDotString(), "bag.records.item1");
+ ASSERT_EQ(descr_.Column(4)->path()->ToDotString(), "bag.records.item2");
+ ASSERT_EQ(descr_.Column(5)->path()->ToDotString(), "bag.records.item3");
+
+ for (int i = 0; i < nleaves; ++i) {
+ auto col = descr_.Column(i);
+ ASSERT_EQ(i, descr_.ColumnIndex(*col->schema_node()));
+ }
+
+ // Test non-column nodes find
+ NodePtr non_column_alien = Int32("alien", Repetition::REQUIRED); // other path
+ NodePtr non_column_familiar = Int32("a", Repetition::REPEATED); // other node
+ ASSERT_LT(descr_.ColumnIndex(*non_column_alien), 0);
+ ASSERT_LT(descr_.ColumnIndex(*non_column_familiar), 0);
+
+ ASSERT_EQ(inta.get(), descr_.GetColumnRoot(0));
+ ASSERT_EQ(bag.get(), descr_.GetColumnRoot(3));
+ ASSERT_EQ(bag.get(), descr_.GetColumnRoot(4));
+ ASSERT_EQ(bag.get(), descr_.GetColumnRoot(5));
+
+ ASSERT_EQ(schema.get(), descr_.group_node());
+
+ // Init clears the leaves
+ descr_.Init(schema);
+ ASSERT_EQ(nleaves, descr_.num_columns());
+}
+
+TEST_F(TestSchemaDescriptor, HasRepeatedFields) {
+ NodeVector fields;
+ NodePtr schema;
+
+ NodePtr inta = Int32("a", Repetition::REQUIRED);
+ fields.push_back(inta);
+ fields.push_back(Int64("b", Repetition::OPTIONAL));
+ fields.push_back(ByteArray("c", Repetition::REPEATED));
+
+ schema = GroupNode::Make("schema", Repetition::REPEATED, fields);
+ descr_.Init(schema);
+ ASSERT_EQ(true, descr_.HasRepeatedFields());
+
+ // 3-level list encoding
+ NodePtr item1 = Int64("item1", Repetition::REQUIRED);
+ NodePtr item2 = Boolean("item2", Repetition::OPTIONAL);
+ NodePtr item3 = Int32("item3", Repetition::REPEATED);
+ NodePtr list(GroupNode::Make("records", Repetition::REPEATED, {item1, item2, item3},
+ ConvertedType::LIST));
+ NodePtr bag(GroupNode::Make("bag", Repetition::OPTIONAL, {list}));
+ fields.push_back(bag);
+
+ schema = GroupNode::Make("schema", Repetition::REPEATED, fields);
+ descr_.Init(schema);
+ ASSERT_EQ(true, descr_.HasRepeatedFields());
+
+ // 3-level list encoding
+ NodePtr item_key = Int64("key", Repetition::REQUIRED);
+ NodePtr item_value = Boolean("value", Repetition::OPTIONAL);
+ NodePtr map(GroupNode::Make("map", Repetition::REPEATED, {item_key, item_value},
+ ConvertedType::MAP));
+ NodePtr my_map(GroupNode::Make("my_map", Repetition::OPTIONAL, {map}));
+ fields.push_back(my_map);
+
+ schema = GroupNode::Make("schema", Repetition::REPEATED, fields);
+ descr_.Init(schema);
+ ASSERT_EQ(true, descr_.HasRepeatedFields());
+ ASSERT_EQ(true, descr_.HasRepeatedFields());
+}
+
+static std::string Print(const NodePtr& node) {
+ std::stringstream ss;
+ PrintSchema(node.get(), ss);
+ return ss.str();
+}
+
+TEST(TestSchemaPrinter, Examples) {
+ // Test schema 1
+ NodeVector fields;
+ fields.push_back(Int32("a", Repetition::REQUIRED, 1));
+
+ // 3-level list encoding
+ NodePtr item1 = Int64("item1", Repetition::OPTIONAL, 4);
+ NodePtr item2 = Boolean("item2", Repetition::REQUIRED, 5);
+ NodePtr list(
+ GroupNode::Make("b", Repetition::REPEATED, {item1, item2}, ConvertedType::LIST, 3));
+ NodePtr bag(
+ GroupNode::Make("bag", Repetition::OPTIONAL, {list}, /*logical_type=*/nullptr, 2));
+ fields.push_back(bag);
+
+ fields.push_back(PrimitiveNode::Make("c", Repetition::REQUIRED, Type::INT32,
+ ConvertedType::DECIMAL, -1, 3, 2, 6));
+
+ fields.push_back(PrimitiveNode::Make("d", Repetition::REQUIRED,
+ DecimalLogicalType::Make(10, 5), Type::INT64,
+ /*length=*/-1, 7));
+
+ NodePtr schema = GroupNode::Make("schema", Repetition::REPEATED, fields,
+ /*logical_type=*/nullptr, 0);
+
+ std::string result = Print(schema);
+
+ std::string expected = R"(repeated group field_id=0 schema {
+ required int32 field_id=1 a;
+ optional group field_id=2 bag {
+ repeated group field_id=3 b (List) {
+ optional int64 field_id=4 item1;
+ required boolean field_id=5 item2;
+ }
+ }
+ required int32 field_id=6 c (Decimal(precision=3, scale=2));
+ required int64 field_id=7 d (Decimal(precision=10, scale=5));
+}
+)";
+ ASSERT_EQ(expected, result);
+}
+
+static void ConfirmFactoryEquivalence(
+ ConvertedType::type converted_type,
+ const std::shared_ptr<const LogicalType>& from_make,
+ std::function<bool(const std::shared_ptr<const LogicalType>&)> check_is_type) {
+ std::shared_ptr<const LogicalType> from_converted_type =
+ LogicalType::FromConvertedType(converted_type);
+ ASSERT_EQ(from_converted_type->type(), from_make->type())
+ << from_make->ToString() << " logical types unexpectedly do not match on type";
+ ASSERT_TRUE(from_converted_type->Equals(*from_make))
+ << from_make->ToString() << " logical types unexpectedly not equivalent";
+ ASSERT_TRUE(check_is_type(from_converted_type))
+ << from_converted_type->ToString()
+ << " logical type (from converted type) does not have expected type property";
+ ASSERT_TRUE(check_is_type(from_make))
+ << from_make->ToString()
+ << " logical type (from Make()) does not have expected type property";
+ return;
+}
+
+TEST(TestLogicalTypeConstruction, FactoryEquivalence) {
+ // For each legacy converted type, ensure that the equivalent logical type object
+ // can be obtained from either the base class's FromConvertedType() factory method or
+ // the logical type type class's Make() method (accessed via convenience methods on the
+ // base class) and that these logical type objects are equivalent
+
+ struct ConfirmFactoryEquivalenceArguments {
+ ConvertedType::type converted_type;
+ std::shared_ptr<const LogicalType> logical_type;
+ std::function<bool(const std::shared_ptr<const LogicalType>&)> check_is_type;
+ };
+
+ auto check_is_string = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_string();
+ };
+ auto check_is_map = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_map();
+ };
+ auto check_is_list = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_list();
+ };
+ auto check_is_enum = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_enum();
+ };
+ auto check_is_date = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_date();
+ };
+ auto check_is_time = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_time();
+ };
+ auto check_is_timestamp = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_timestamp();
+ };
+ auto check_is_int = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_int();
+ };
+ auto check_is_JSON = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_JSON();
+ };
+ auto check_is_BSON = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_BSON();
+ };
+ auto check_is_interval = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_interval();
+ };
+ auto check_is_none = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_none();
+ };
+
+ std::vector<ConfirmFactoryEquivalenceArguments> cases = {
+ {ConvertedType::UTF8, LogicalType::String(), check_is_string},
+ {ConvertedType::MAP, LogicalType::Map(), check_is_map},
+ {ConvertedType::MAP_KEY_VALUE, LogicalType::Map(), check_is_map},
+ {ConvertedType::LIST, LogicalType::List(), check_is_list},
+ {ConvertedType::ENUM, LogicalType::Enum(), check_is_enum},
+ {ConvertedType::DATE, LogicalType::Date(), check_is_date},
+ {ConvertedType::TIME_MILLIS, LogicalType::Time(true, LogicalType::TimeUnit::MILLIS),
+ check_is_time},
+ {ConvertedType::TIME_MICROS, LogicalType::Time(true, LogicalType::TimeUnit::MICROS),
+ check_is_time},
+ {ConvertedType::TIMESTAMP_MILLIS,
+ LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS), check_is_timestamp},
+ {ConvertedType::TIMESTAMP_MICROS,
+ LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), check_is_timestamp},
+ {ConvertedType::UINT_8, LogicalType::Int(8, false), check_is_int},
+ {ConvertedType::UINT_16, LogicalType::Int(16, false), check_is_int},
+ {ConvertedType::UINT_32, LogicalType::Int(32, false), check_is_int},
+ {ConvertedType::UINT_64, LogicalType::Int(64, false), check_is_int},
+ {ConvertedType::INT_8, LogicalType::Int(8, true), check_is_int},
+ {ConvertedType::INT_16, LogicalType::Int(16, true), check_is_int},
+ {ConvertedType::INT_32, LogicalType::Int(32, true), check_is_int},
+ {ConvertedType::INT_64, LogicalType::Int(64, true), check_is_int},
+ {ConvertedType::JSON, LogicalType::JSON(), check_is_JSON},
+ {ConvertedType::BSON, LogicalType::BSON(), check_is_BSON},
+ {ConvertedType::INTERVAL, LogicalType::Interval(), check_is_interval},
+ {ConvertedType::NONE, LogicalType::None(), check_is_none}};
+
+ for (const ConfirmFactoryEquivalenceArguments& c : cases) {
+ ConfirmFactoryEquivalence(c.converted_type, c.logical_type, c.check_is_type);
+ }
+
+ // ConvertedType::DECIMAL, LogicalType::Decimal, is_decimal
+ schema::DecimalMetadata converted_decimal_metadata;
+ converted_decimal_metadata.isset = true;
+ converted_decimal_metadata.precision = 10;
+ converted_decimal_metadata.scale = 4;
+ std::shared_ptr<const LogicalType> from_converted_type =
+ LogicalType::FromConvertedType(ConvertedType::DECIMAL, converted_decimal_metadata);
+ std::shared_ptr<const LogicalType> from_make = LogicalType::Decimal(10, 4);
+ ASSERT_EQ(from_converted_type->type(), from_make->type());
+ ASSERT_TRUE(from_converted_type->Equals(*from_make));
+ ASSERT_TRUE(from_converted_type->is_decimal());
+ ASSERT_TRUE(from_make->is_decimal());
+ ASSERT_TRUE(LogicalType::Decimal(16)->Equals(*LogicalType::Decimal(16, 0)));
+}
+
+static void ConfirmConvertedTypeCompatibility(
+ const std::shared_ptr<const LogicalType>& original,
+ ConvertedType::type expected_converted_type) {
+ ASSERT_TRUE(original->is_valid())
+ << original->ToString() << " logical type unexpectedly is not valid";
+ schema::DecimalMetadata converted_decimal_metadata;
+ ConvertedType::type converted_type =
+ original->ToConvertedType(&converted_decimal_metadata);
+ ASSERT_EQ(converted_type, expected_converted_type)
+ << original->ToString()
+ << " logical type unexpectedly returns incorrect converted type";
+ ASSERT_FALSE(converted_decimal_metadata.isset)
+ << original->ToString()
+ << " logical type unexpectedly returns converted decimal metadata that is set";
+ ASSERT_TRUE(original->is_compatible(converted_type, converted_decimal_metadata))
+ << original->ToString()
+ << " logical type unexpectedly is incompatible with converted type and decimal "
+ "metadata it returned";
+ ASSERT_FALSE(original->is_compatible(converted_type, {true, 1, 1}))
+ << original->ToString()
+ << " logical type unexpectedly is compatible with converted decimal metadata that "
+ "is "
+ "set";
+ ASSERT_TRUE(original->is_compatible(converted_type))
+ << original->ToString()
+ << " logical type unexpectedly is incompatible with converted type it returned";
+ std::shared_ptr<const LogicalType> reconstructed =
+ LogicalType::FromConvertedType(converted_type, converted_decimal_metadata);
+ ASSERT_TRUE(reconstructed->is_valid()) << "Reconstructed " << reconstructed->ToString()
+ << " logical type unexpectedly is not valid";
+ ASSERT_TRUE(reconstructed->Equals(*original))
+ << "Reconstructed logical type (" << reconstructed->ToString()
+ << ") unexpectedly not equivalent to original logical type ("
+ << original->ToString() << ")";
+ return;
+}
+
+TEST(TestLogicalTypeConstruction, ConvertedTypeCompatibility) {
+ // For each legacy converted type, ensure that the equivalent logical type
+ // emits correct, compatible converted type information and that the emitted
+ // information can be used to reconstruct another equivalent logical type.
+
+ struct ExpectedConvertedType {
+ std::shared_ptr<const LogicalType> logical_type;
+ ConvertedType::type converted_type;
+ };
+
+ std::vector<ExpectedConvertedType> cases = {
+ {LogicalType::String(), ConvertedType::UTF8},
+ {LogicalType::Map(), ConvertedType::MAP},
+ {LogicalType::List(), ConvertedType::LIST},
+ {LogicalType::Enum(), ConvertedType::ENUM},
+ {LogicalType::Date(), ConvertedType::DATE},
+ {LogicalType::Time(true, LogicalType::TimeUnit::MILLIS),
+ ConvertedType::TIME_MILLIS},
+ {LogicalType::Time(true, LogicalType::TimeUnit::MICROS),
+ ConvertedType::TIME_MICROS},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS),
+ ConvertedType::TIMESTAMP_MILLIS},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS),
+ ConvertedType::TIMESTAMP_MICROS},
+ {LogicalType::Int(8, false), ConvertedType::UINT_8},
+ {LogicalType::Int(16, false), ConvertedType::UINT_16},
+ {LogicalType::Int(32, false), ConvertedType::UINT_32},
+ {LogicalType::Int(64, false), ConvertedType::UINT_64},
+ {LogicalType::Int(8, true), ConvertedType::INT_8},
+ {LogicalType::Int(16, true), ConvertedType::INT_16},
+ {LogicalType::Int(32, true), ConvertedType::INT_32},
+ {LogicalType::Int(64, true), ConvertedType::INT_64},
+ {LogicalType::JSON(), ConvertedType::JSON},
+ {LogicalType::BSON(), ConvertedType::BSON},
+ {LogicalType::Interval(), ConvertedType::INTERVAL},
+ {LogicalType::None(), ConvertedType::NONE}};
+
+ for (const ExpectedConvertedType& c : cases) {
+ ConfirmConvertedTypeCompatibility(c.logical_type, c.converted_type);
+ }
+
+ // Special cases ...
+
+ std::shared_ptr<const LogicalType> original;
+ ConvertedType::type converted_type;
+ schema::DecimalMetadata converted_decimal_metadata;
+ std::shared_ptr<const LogicalType> reconstructed;
+
+ // DECIMAL
+ std::memset(&converted_decimal_metadata, 0x00, sizeof(converted_decimal_metadata));
+ original = LogicalType::Decimal(6, 2);
+ ASSERT_TRUE(original->is_valid());
+ converted_type = original->ToConvertedType(&converted_decimal_metadata);
+ ASSERT_EQ(converted_type, ConvertedType::DECIMAL);
+ ASSERT_TRUE(converted_decimal_metadata.isset);
+ ASSERT_EQ(converted_decimal_metadata.precision, 6);
+ ASSERT_EQ(converted_decimal_metadata.scale, 2);
+ ASSERT_TRUE(original->is_compatible(converted_type, converted_decimal_metadata));
+ reconstructed =
+ LogicalType::FromConvertedType(converted_type, converted_decimal_metadata);
+ ASSERT_TRUE(reconstructed->is_valid());
+ ASSERT_TRUE(reconstructed->Equals(*original));
+
+ // Undefined
+ original = UndefinedLogicalType::Make();
+ ASSERT_TRUE(original->is_invalid());
+ ASSERT_FALSE(original->is_valid());
+ converted_type = original->ToConvertedType(&converted_decimal_metadata);
+ ASSERT_EQ(converted_type, ConvertedType::UNDEFINED);
+ ASSERT_FALSE(converted_decimal_metadata.isset);
+ ASSERT_TRUE(original->is_compatible(converted_type, converted_decimal_metadata));
+ ASSERT_TRUE(original->is_compatible(converted_type));
+ reconstructed =
+ LogicalType::FromConvertedType(converted_type, converted_decimal_metadata);
+ ASSERT_TRUE(reconstructed->is_invalid());
+ ASSERT_TRUE(reconstructed->Equals(*original));
+}
+
+static void ConfirmNewTypeIncompatibility(
+ const std::shared_ptr<const LogicalType>& logical_type,
+ std::function<bool(const std::shared_ptr<const LogicalType>&)> check_is_type) {
+ ASSERT_TRUE(logical_type->is_valid())
+ << logical_type->ToString() << " logical type unexpectedly is not valid";
+ ASSERT_TRUE(check_is_type(logical_type))
+ << logical_type->ToString() << " logical type is not expected logical type";
+ schema::DecimalMetadata converted_decimal_metadata;
+ ConvertedType::type converted_type =
+ logical_type->ToConvertedType(&converted_decimal_metadata);
+ ASSERT_EQ(converted_type, ConvertedType::NONE)
+ << logical_type->ToString()
+ << " logical type converted type unexpectedly is not NONE";
+ ASSERT_FALSE(converted_decimal_metadata.isset)
+ << logical_type->ToString()
+ << " logical type converted decimal metadata unexpectedly is set";
+ return;
+}
+
+TEST(TestLogicalTypeConstruction, NewTypeIncompatibility) {
+ // For each new logical type, ensure that the type
+ // correctly reports that it has no legacy equivalent
+
+ struct ConfirmNewTypeIncompatibilityArguments {
+ std::shared_ptr<const LogicalType> logical_type;
+ std::function<bool(const std::shared_ptr<const LogicalType>&)> check_is_type;
+ };
+
+ auto check_is_UUID = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_UUID();
+ };
+ auto check_is_null = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_null();
+ };
+ auto check_is_time = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_time();
+ };
+ auto check_is_timestamp = [](const std::shared_ptr<const LogicalType>& logical_type) {
+ return logical_type->is_timestamp();
+ };
+
+ std::vector<ConfirmNewTypeIncompatibilityArguments> cases = {
+ {LogicalType::UUID(), check_is_UUID},
+ {LogicalType::Null(), check_is_null},
+ {LogicalType::Time(false, LogicalType::TimeUnit::MILLIS), check_is_time},
+ {LogicalType::Time(false, LogicalType::TimeUnit::MICROS), check_is_time},
+ {LogicalType::Time(false, LogicalType::TimeUnit::NANOS), check_is_time},
+ {LogicalType::Time(true, LogicalType::TimeUnit::NANOS), check_is_time},
+ {LogicalType::Timestamp(false, LogicalType::TimeUnit::NANOS), check_is_timestamp},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::NANOS), check_is_timestamp},
+ };
+
+ for (const ConfirmNewTypeIncompatibilityArguments& c : cases) {
+ ConfirmNewTypeIncompatibility(c.logical_type, c.check_is_type);
+ }
+}
+
+TEST(TestLogicalTypeConstruction, FactoryExceptions) {
+ // Ensure that logical type construction catches invalid arguments
+
+ std::vector<std::function<void()>> cases = {
+ []() {
+ TimeLogicalType::Make(true, LogicalType::TimeUnit::UNKNOWN);
+ }, // Invalid TimeUnit
+ []() {
+ TimestampLogicalType::Make(true, LogicalType::TimeUnit::UNKNOWN);
+ }, // Invalid TimeUnit
+ []() { IntLogicalType::Make(-1, false); }, // Invalid bit width
+ []() { IntLogicalType::Make(0, false); }, // Invalid bit width
+ []() { IntLogicalType::Make(1, false); }, // Invalid bit width
+ []() { IntLogicalType::Make(65, false); }, // Invalid bit width
+ []() { DecimalLogicalType::Make(-1); }, // Invalid precision
+ []() { DecimalLogicalType::Make(0); }, // Invalid precision
+ []() { DecimalLogicalType::Make(0, 0); }, // Invalid precision
+ []() { DecimalLogicalType::Make(10, -1); }, // Invalid scale
+ []() { DecimalLogicalType::Make(10, 11); } // Invalid scale
+ };
+
+ for (auto f : cases) {
+ ASSERT_ANY_THROW(f());
+ }
+}
+
+static void ConfirmLogicalTypeProperties(
+ const std::shared_ptr<const LogicalType>& logical_type, bool nested, bool serialized,
+ bool valid) {
+ ASSERT_TRUE(logical_type->is_nested() == nested)
+ << logical_type->ToString() << " logical type has incorrect nested() property";
+ ASSERT_TRUE(logical_type->is_serialized() == serialized)
+ << logical_type->ToString() << " logical type has incorrect serialized() property";
+ ASSERT_TRUE(logical_type->is_valid() == valid)
+ << logical_type->ToString() << " logical type has incorrect valid() property";
+ ASSERT_TRUE(logical_type->is_nonnested() != nested)
+ << logical_type->ToString() << " logical type has incorrect nonnested() property";
+ ASSERT_TRUE(logical_type->is_invalid() != valid)
+ << logical_type->ToString() << " logical type has incorrect invalid() property";
+ return;
+}
+
+TEST(TestLogicalTypeOperation, LogicalTypeProperties) {
+ // For each logical type, ensure that the correct general properties are reported
+
+ struct ExpectedProperties {
+ std::shared_ptr<const LogicalType> logical_type;
+ bool nested;
+ bool serialized;
+ bool valid;
+ };
+
+ std::vector<ExpectedProperties> cases = {
+ {StringLogicalType::Make(), false, true, true},
+ {MapLogicalType::Make(), true, true, true},
+ {ListLogicalType::Make(), true, true, true},
+ {EnumLogicalType::Make(), false, true, true},
+ {DecimalLogicalType::Make(16, 6), false, true, true},
+ {DateLogicalType::Make(), false, true, true},
+ {TimeLogicalType::Make(true, LogicalType::TimeUnit::MICROS), false, true, true},
+ {TimestampLogicalType::Make(true, LogicalType::TimeUnit::MICROS), false, true,
+ true},
+ {IntervalLogicalType::Make(), false, true, true},
+ {IntLogicalType::Make(8, false), false, true, true},
+ {IntLogicalType::Make(64, true), false, true, true},
+ {NullLogicalType::Make(), false, true, true},
+ {JSONLogicalType::Make(), false, true, true},
+ {BSONLogicalType::Make(), false, true, true},
+ {UUIDLogicalType::Make(), false, true, true},
+ {NoLogicalType::Make(), false, false, true},
+ };
+
+ for (const ExpectedProperties& c : cases) {
+ ConfirmLogicalTypeProperties(c.logical_type, c.nested, c.serialized, c.valid);
+ }
+}
+
+static constexpr int PHYSICAL_TYPE_COUNT = 8;
+
+static Type::type physical_type[PHYSICAL_TYPE_COUNT] = {
+ Type::BOOLEAN, Type::INT32, Type::INT64, Type::INT96,
+ Type::FLOAT, Type::DOUBLE, Type::BYTE_ARRAY, Type::FIXED_LEN_BYTE_ARRAY};
+
+static void ConfirmSinglePrimitiveTypeApplicability(
+ const std::shared_ptr<const LogicalType>& logical_type, Type::type applicable_type) {
+ for (int i = 0; i < PHYSICAL_TYPE_COUNT; ++i) {
+ if (physical_type[i] == applicable_type) {
+ ASSERT_TRUE(logical_type->is_applicable(physical_type[i]))
+ << logical_type->ToString()
+ << " logical type unexpectedly inapplicable to physical type "
+ << TypeToString(physical_type[i]);
+ } else {
+ ASSERT_FALSE(logical_type->is_applicable(physical_type[i]))
+ << logical_type->ToString()
+ << " logical type unexpectedly applicable to physical type "
+ << TypeToString(physical_type[i]);
+ }
+ }
+ return;
+}
+
+static void ConfirmAnyPrimitiveTypeApplicability(
+ const std::shared_ptr<const LogicalType>& logical_type) {
+ for (int i = 0; i < PHYSICAL_TYPE_COUNT; ++i) {
+ ASSERT_TRUE(logical_type->is_applicable(physical_type[i]))
+ << logical_type->ToString()
+ << " logical type unexpectedly inapplicable to physical type "
+ << TypeToString(physical_type[i]);
+ }
+ return;
+}
+
+static void ConfirmNoPrimitiveTypeApplicability(
+ const std::shared_ptr<const LogicalType>& logical_type) {
+ for (int i = 0; i < PHYSICAL_TYPE_COUNT; ++i) {
+ ASSERT_FALSE(logical_type->is_applicable(physical_type[i]))
+ << logical_type->ToString()
+ << " logical type unexpectedly applicable to physical type "
+ << TypeToString(physical_type[i]);
+ }
+ return;
+}
+
+TEST(TestLogicalTypeOperation, LogicalTypeApplicability) {
+ // Check that each logical type correctly reports which
+ // underlying primitive type(s) it can be applied to
+
+ struct ExpectedApplicability {
+ std::shared_ptr<const LogicalType> logical_type;
+ Type::type applicable_type;
+ };
+
+ std::vector<ExpectedApplicability> single_type_cases = {
+ {LogicalType::String(), Type::BYTE_ARRAY},
+ {LogicalType::Enum(), Type::BYTE_ARRAY},
+ {LogicalType::Date(), Type::INT32},
+ {LogicalType::Time(true, LogicalType::TimeUnit::MILLIS), Type::INT32},
+ {LogicalType::Time(true, LogicalType::TimeUnit::MICROS), Type::INT64},
+ {LogicalType::Time(true, LogicalType::TimeUnit::NANOS), Type::INT64},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS), Type::INT64},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), Type::INT64},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::NANOS), Type::INT64},
+ {LogicalType::Int(8, false), Type::INT32},
+ {LogicalType::Int(16, false), Type::INT32},
+ {LogicalType::Int(32, false), Type::INT32},
+ {LogicalType::Int(64, false), Type::INT64},
+ {LogicalType::Int(8, true), Type::INT32},
+ {LogicalType::Int(16, true), Type::INT32},
+ {LogicalType::Int(32, true), Type::INT32},
+ {LogicalType::Int(64, true), Type::INT64},
+ {LogicalType::JSON(), Type::BYTE_ARRAY},
+ {LogicalType::BSON(), Type::BYTE_ARRAY}};
+
+ for (const ExpectedApplicability& c : single_type_cases) {
+ ConfirmSinglePrimitiveTypeApplicability(c.logical_type, c.applicable_type);
+ }
+
+ std::vector<std::shared_ptr<const LogicalType>> no_type_cases = {LogicalType::Map(),
+ LogicalType::List()};
+
+ for (auto c : no_type_cases) {
+ ConfirmNoPrimitiveTypeApplicability(c);
+ }
+
+ std::vector<std::shared_ptr<const LogicalType>> any_type_cases = {
+ LogicalType::Null(), LogicalType::None(), UndefinedLogicalType::Make()};
+
+ for (auto c : any_type_cases) {
+ ConfirmAnyPrimitiveTypeApplicability(c);
+ }
+
+ // Fixed binary, exact length cases ...
+
+ struct InapplicableType {
+ Type::type physical_type;
+ int physical_length;
+ };
+
+ std::vector<InapplicableType> inapplicable_types = {{Type::FIXED_LEN_BYTE_ARRAY, 8},
+ {Type::FIXED_LEN_BYTE_ARRAY, 20},
+ {Type::BOOLEAN, -1},
+ {Type::INT32, -1},
+ {Type::INT64, -1},
+ {Type::INT96, -1},
+ {Type::FLOAT, -1},
+ {Type::DOUBLE, -1},
+ {Type::BYTE_ARRAY, -1}};
+
+ std::shared_ptr<const LogicalType> logical_type;
+
+ logical_type = LogicalType::Interval();
+ ASSERT_TRUE(logical_type->is_applicable(Type::FIXED_LEN_BYTE_ARRAY, 12));
+ for (const InapplicableType& t : inapplicable_types) {
+ ASSERT_FALSE(logical_type->is_applicable(t.physical_type, t.physical_length));
+ }
+
+ logical_type = LogicalType::UUID();
+ ASSERT_TRUE(logical_type->is_applicable(Type::FIXED_LEN_BYTE_ARRAY, 16));
+ for (const InapplicableType& t : inapplicable_types) {
+ ASSERT_FALSE(logical_type->is_applicable(t.physical_type, t.physical_length));
+ }
+}
+
+TEST(TestLogicalTypeOperation, DecimalLogicalTypeApplicability) {
+ // Check that the decimal logical type correctly reports which
+ // underlying primitive type(s) it can be applied to
+
+ std::shared_ptr<const LogicalType> logical_type;
+
+ for (int32_t precision = 1; precision <= 9; ++precision) {
+ logical_type = DecimalLogicalType::Make(precision, 0);
+ ASSERT_TRUE(logical_type->is_applicable(Type::INT32))
+ << logical_type->ToString()
+ << " unexpectedly inapplicable to physical type INT32";
+ }
+ logical_type = DecimalLogicalType::Make(10, 0);
+ ASSERT_FALSE(logical_type->is_applicable(Type::INT32))
+ << logical_type->ToString() << " unexpectedly applicable to physical type INT32";
+
+ for (int32_t precision = 1; precision <= 18; ++precision) {
+ logical_type = DecimalLogicalType::Make(precision, 0);
+ ASSERT_TRUE(logical_type->is_applicable(Type::INT64))
+ << logical_type->ToString()
+ << " unexpectedly inapplicable to physical type INT64";
+ }
+ logical_type = DecimalLogicalType::Make(19, 0);
+ ASSERT_FALSE(logical_type->is_applicable(Type::INT64))
+ << logical_type->ToString() << " unexpectedly applicable to physical type INT64";
+
+ for (int32_t precision = 1; precision <= 36; ++precision) {
+ logical_type = DecimalLogicalType::Make(precision, 0);
+ ASSERT_TRUE(logical_type->is_applicable(Type::BYTE_ARRAY))
+ << logical_type->ToString()
+ << " unexpectedly inapplicable to physical type BYTE_ARRAY";
+ }
+
+ struct PrecisionLimits {
+ int32_t physical_length;
+ int32_t precision_limit;
+ };
+
+ std::vector<PrecisionLimits> cases = {{1, 2}, {2, 4}, {3, 6}, {4, 9}, {8, 18},
+ {10, 23}, {16, 38}, {20, 47}, {32, 76}};
+
+ for (const PrecisionLimits& c : cases) {
+ int32_t precision;
+ for (precision = 1; precision <= c.precision_limit; ++precision) {
+ logical_type = DecimalLogicalType::Make(precision, 0);
+ ASSERT_TRUE(
+ logical_type->is_applicable(Type::FIXED_LEN_BYTE_ARRAY, c.physical_length))
+ << logical_type->ToString()
+ << " unexpectedly inapplicable to physical type FIXED_LEN_BYTE_ARRAY with "
+ "length "
+ << c.physical_length;
+ }
+ logical_type = DecimalLogicalType::Make(precision, 0);
+ ASSERT_FALSE(
+ logical_type->is_applicable(Type::FIXED_LEN_BYTE_ARRAY, c.physical_length))
+ << logical_type->ToString()
+ << " unexpectedly applicable to physical type FIXED_LEN_BYTE_ARRAY with length "
+ << c.physical_length;
+ }
+
+ ASSERT_FALSE((DecimalLogicalType::Make(16, 6))->is_applicable(Type::BOOLEAN));
+ ASSERT_FALSE((DecimalLogicalType::Make(16, 6))->is_applicable(Type::FLOAT));
+ ASSERT_FALSE((DecimalLogicalType::Make(16, 6))->is_applicable(Type::DOUBLE));
+}
+
+TEST(TestLogicalTypeOperation, LogicalTypeRepresentation) {
+ // Ensure that each logical type prints a correct string and
+ // JSON representation
+
+ struct ExpectedRepresentation {
+ std::shared_ptr<const LogicalType> logical_type;
+ const char* string_representation;
+ const char* JSON_representation;
+ };
+
+ std::vector<ExpectedRepresentation> cases = {
+ {UndefinedLogicalType::Make(), "Undefined", R"({"Type": "Undefined"})"},
+ {LogicalType::String(), "String", R"({"Type": "String"})"},
+ {LogicalType::Map(), "Map", R"({"Type": "Map"})"},
+ {LogicalType::List(), "List", R"({"Type": "List"})"},
+ {LogicalType::Enum(), "Enum", R"({"Type": "Enum"})"},
+ {LogicalType::Decimal(10, 4), "Decimal(precision=10, scale=4)",
+ R"({"Type": "Decimal", "precision": 10, "scale": 4})"},
+ {LogicalType::Decimal(10), "Decimal(precision=10, scale=0)",
+ R"({"Type": "Decimal", "precision": 10, "scale": 0})"},
+ {LogicalType::Date(), "Date", R"({"Type": "Date"})"},
+ {LogicalType::Time(true, LogicalType::TimeUnit::MILLIS),
+ "Time(isAdjustedToUTC=true, timeUnit=milliseconds)",
+ R"({"Type": "Time", "isAdjustedToUTC": true, "timeUnit": "milliseconds"})"},
+ {LogicalType::Time(true, LogicalType::TimeUnit::MICROS),
+ "Time(isAdjustedToUTC=true, timeUnit=microseconds)",
+ R"({"Type": "Time", "isAdjustedToUTC": true, "timeUnit": "microseconds"})"},
+ {LogicalType::Time(true, LogicalType::TimeUnit::NANOS),
+ "Time(isAdjustedToUTC=true, timeUnit=nanoseconds)",
+ R"({"Type": "Time", "isAdjustedToUTC": true, "timeUnit": "nanoseconds"})"},
+ {LogicalType::Time(false, LogicalType::TimeUnit::MILLIS),
+ "Time(isAdjustedToUTC=false, timeUnit=milliseconds)",
+ R"({"Type": "Time", "isAdjustedToUTC": false, "timeUnit": "milliseconds"})"},
+ {LogicalType::Time(false, LogicalType::TimeUnit::MICROS),
+ "Time(isAdjustedToUTC=false, timeUnit=microseconds)",
+ R"({"Type": "Time", "isAdjustedToUTC": false, "timeUnit": "microseconds"})"},
+ {LogicalType::Time(false, LogicalType::TimeUnit::NANOS),
+ "Time(isAdjustedToUTC=false, timeUnit=nanoseconds)",
+ R"({"Type": "Time", "isAdjustedToUTC": false, "timeUnit": "nanoseconds"})"},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS),
+ "Timestamp(isAdjustedToUTC=true, timeUnit=milliseconds, "
+ "is_from_converted_type=false, force_set_converted_type=false)",
+ R"({"Type": "Timestamp", "isAdjustedToUTC": true, "timeUnit": "milliseconds", )"
+ R"("is_from_converted_type": false, "force_set_converted_type": false})"},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS),
+ "Timestamp(isAdjustedToUTC=true, timeUnit=microseconds, "
+ "is_from_converted_type=false, force_set_converted_type=false)",
+ R"({"Type": "Timestamp", "isAdjustedToUTC": true, "timeUnit": "microseconds", )"
+ R"("is_from_converted_type": false, "force_set_converted_type": false})"},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::NANOS),
+ "Timestamp(isAdjustedToUTC=true, timeUnit=nanoseconds, "
+ "is_from_converted_type=false, force_set_converted_type=false)",
+ R"({"Type": "Timestamp", "isAdjustedToUTC": true, "timeUnit": "nanoseconds", )"
+ R"("is_from_converted_type": false, "force_set_converted_type": false})"},
+ {LogicalType::Timestamp(false, LogicalType::TimeUnit::MILLIS, true, true),
+ "Timestamp(isAdjustedToUTC=false, timeUnit=milliseconds, "
+ "is_from_converted_type=true, force_set_converted_type=true)",
+ R"({"Type": "Timestamp", "isAdjustedToUTC": false, "timeUnit": "milliseconds", )"
+ R"("is_from_converted_type": true, "force_set_converted_type": true})"},
+ {LogicalType::Timestamp(false, LogicalType::TimeUnit::MICROS),
+ "Timestamp(isAdjustedToUTC=false, timeUnit=microseconds, "
+ "is_from_converted_type=false, force_set_converted_type=false)",
+ R"({"Type": "Timestamp", "isAdjustedToUTC": false, "timeUnit": "microseconds", )"
+ R"("is_from_converted_type": false, "force_set_converted_type": false})"},
+ {LogicalType::Timestamp(false, LogicalType::TimeUnit::NANOS),
+ "Timestamp(isAdjustedToUTC=false, timeUnit=nanoseconds, "
+ "is_from_converted_type=false, force_set_converted_type=false)",
+ R"({"Type": "Timestamp", "isAdjustedToUTC": false, "timeUnit": "nanoseconds", )"
+ R"("is_from_converted_type": false, "force_set_converted_type": false})"},
+ {LogicalType::Interval(), "Interval", R"({"Type": "Interval"})"},
+ {LogicalType::Int(8, false), "Int(bitWidth=8, isSigned=false)",
+ R"({"Type": "Int", "bitWidth": 8, "isSigned": false})"},
+ {LogicalType::Int(16, false), "Int(bitWidth=16, isSigned=false)",
+ R"({"Type": "Int", "bitWidth": 16, "isSigned": false})"},
+ {LogicalType::Int(32, false), "Int(bitWidth=32, isSigned=false)",
+ R"({"Type": "Int", "bitWidth": 32, "isSigned": false})"},
+ {LogicalType::Int(64, false), "Int(bitWidth=64, isSigned=false)",
+ R"({"Type": "Int", "bitWidth": 64, "isSigned": false})"},
+ {LogicalType::Int(8, true), "Int(bitWidth=8, isSigned=true)",
+ R"({"Type": "Int", "bitWidth": 8, "isSigned": true})"},
+ {LogicalType::Int(16, true), "Int(bitWidth=16, isSigned=true)",
+ R"({"Type": "Int", "bitWidth": 16, "isSigned": true})"},
+ {LogicalType::Int(32, true), "Int(bitWidth=32, isSigned=true)",
+ R"({"Type": "Int", "bitWidth": 32, "isSigned": true})"},
+ {LogicalType::Int(64, true), "Int(bitWidth=64, isSigned=true)",
+ R"({"Type": "Int", "bitWidth": 64, "isSigned": true})"},
+ {LogicalType::Null(), "Null", R"({"Type": "Null"})"},
+ {LogicalType::JSON(), "JSON", R"({"Type": "JSON"})"},
+ {LogicalType::BSON(), "BSON", R"({"Type": "BSON"})"},
+ {LogicalType::UUID(), "UUID", R"({"Type": "UUID"})"},
+ {LogicalType::None(), "None", R"({"Type": "None"})"},
+ };
+
+ for (const ExpectedRepresentation& c : cases) {
+ ASSERT_STREQ(c.logical_type->ToString().c_str(), c.string_representation);
+ ASSERT_STREQ(c.logical_type->ToJSON().c_str(), c.JSON_representation);
+ }
+}
+
+TEST(TestLogicalTypeOperation, LogicalTypeSortOrder) {
+ // Ensure that each logical type reports the correct sort order
+
+ struct ExpectedSortOrder {
+ std::shared_ptr<const LogicalType> logical_type;
+ SortOrder::type sort_order;
+ };
+
+ std::vector<ExpectedSortOrder> cases = {
+ {LogicalType::String(), SortOrder::UNSIGNED},
+ {LogicalType::Map(), SortOrder::UNKNOWN},
+ {LogicalType::List(), SortOrder::UNKNOWN},
+ {LogicalType::Enum(), SortOrder::UNSIGNED},
+ {LogicalType::Decimal(8, 2), SortOrder::SIGNED},
+ {LogicalType::Date(), SortOrder::SIGNED},
+ {LogicalType::Time(true, LogicalType::TimeUnit::MILLIS), SortOrder::SIGNED},
+ {LogicalType::Time(true, LogicalType::TimeUnit::MICROS), SortOrder::SIGNED},
+ {LogicalType::Time(true, LogicalType::TimeUnit::NANOS), SortOrder::SIGNED},
+ {LogicalType::Time(false, LogicalType::TimeUnit::MILLIS), SortOrder::SIGNED},
+ {LogicalType::Time(false, LogicalType::TimeUnit::MICROS), SortOrder::SIGNED},
+ {LogicalType::Time(false, LogicalType::TimeUnit::NANOS), SortOrder::SIGNED},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS), SortOrder::SIGNED},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), SortOrder::SIGNED},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::NANOS), SortOrder::SIGNED},
+ {LogicalType::Timestamp(false, LogicalType::TimeUnit::MILLIS), SortOrder::SIGNED},
+ {LogicalType::Timestamp(false, LogicalType::TimeUnit::MICROS), SortOrder::SIGNED},
+ {LogicalType::Timestamp(false, LogicalType::TimeUnit::NANOS), SortOrder::SIGNED},
+ {LogicalType::Interval(), SortOrder::UNKNOWN},
+ {LogicalType::Int(8, false), SortOrder::UNSIGNED},
+ {LogicalType::Int(16, false), SortOrder::UNSIGNED},
+ {LogicalType::Int(32, false), SortOrder::UNSIGNED},
+ {LogicalType::Int(64, false), SortOrder::UNSIGNED},
+ {LogicalType::Int(8, true), SortOrder::SIGNED},
+ {LogicalType::Int(16, true), SortOrder::SIGNED},
+ {LogicalType::Int(32, true), SortOrder::SIGNED},
+ {LogicalType::Int(64, true), SortOrder::SIGNED},
+ {LogicalType::Null(), SortOrder::UNKNOWN},
+ {LogicalType::JSON(), SortOrder::UNSIGNED},
+ {LogicalType::BSON(), SortOrder::UNSIGNED},
+ {LogicalType::UUID(), SortOrder::UNSIGNED},
+ {LogicalType::None(), SortOrder::UNKNOWN}};
+
+ for (const ExpectedSortOrder& c : cases) {
+ ASSERT_EQ(c.logical_type->sort_order(), c.sort_order)
+ << c.logical_type->ToString() << " logical type has incorrect sort order";
+ }
+}
+
+static void ConfirmPrimitiveNodeFactoryEquivalence(
+ const std::shared_ptr<const LogicalType>& logical_type,
+ ConvertedType::type converted_type, Type::type physical_type, int physical_length,
+ int precision, int scale) {
+ std::string name = "something";
+ Repetition::type repetition = Repetition::REQUIRED;
+ NodePtr from_converted_type = PrimitiveNode::Make(
+ name, repetition, physical_type, converted_type, physical_length, precision, scale);
+ NodePtr from_logical_type =
+ PrimitiveNode::Make(name, repetition, logical_type, physical_type, physical_length);
+ ASSERT_TRUE(from_converted_type->Equals(from_logical_type.get()))
+ << "Primitive node constructed with converted type "
+ << ConvertedTypeToString(converted_type)
+ << " unexpectedly not equivalent to primitive node constructed with logical "
+ "type "
+ << logical_type->ToString();
+ return;
+}
+
+static void ConfirmGroupNodeFactoryEquivalence(
+ std::string name, const std::shared_ptr<const LogicalType>& logical_type,
+ ConvertedType::type converted_type) {
+ Repetition::type repetition = Repetition::OPTIONAL;
+ NodePtr from_converted_type = GroupNode::Make(name, repetition, {}, converted_type);
+ NodePtr from_logical_type = GroupNode::Make(name, repetition, {}, logical_type);
+ ASSERT_TRUE(from_converted_type->Equals(from_logical_type.get()))
+ << "Group node constructed with converted type "
+ << ConvertedTypeToString(converted_type)
+ << " unexpectedly not equivalent to group node constructed with logical type "
+ << logical_type->ToString();
+ return;
+}
+
+TEST(TestSchemaNodeCreation, FactoryEquivalence) {
+ // Ensure that the Node factory methods produce equivalent results regardless
+ // of whether they are given a converted type or a logical type.
+
+ // Primitive nodes ...
+
+ struct PrimitiveNodeFactoryArguments {
+ std::shared_ptr<const LogicalType> logical_type;
+ ConvertedType::type converted_type;
+ Type::type physical_type;
+ int physical_length;
+ int precision;
+ int scale;
+ };
+
+ std::vector<PrimitiveNodeFactoryArguments> cases = {
+ {LogicalType::String(), ConvertedType::UTF8, Type::BYTE_ARRAY, -1, -1, -1},
+ {LogicalType::Enum(), ConvertedType::ENUM, Type::BYTE_ARRAY, -1, -1, -1},
+ {LogicalType::Decimal(16, 6), ConvertedType::DECIMAL, Type::INT64, -1, 16, 6},
+ {LogicalType::Date(), ConvertedType::DATE, Type::INT32, -1, -1, -1},
+ {LogicalType::Time(true, LogicalType::TimeUnit::MILLIS), ConvertedType::TIME_MILLIS,
+ Type::INT32, -1, -1, -1},
+ {LogicalType::Time(true, LogicalType::TimeUnit::MICROS), ConvertedType::TIME_MICROS,
+ Type::INT64, -1, -1, -1},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS),
+ ConvertedType::TIMESTAMP_MILLIS, Type::INT64, -1, -1, -1},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS),
+ ConvertedType::TIMESTAMP_MICROS, Type::INT64, -1, -1, -1},
+ {LogicalType::Interval(), ConvertedType::INTERVAL, Type::FIXED_LEN_BYTE_ARRAY, 12,
+ -1, -1},
+ {LogicalType::Int(8, false), ConvertedType::UINT_8, Type::INT32, -1, -1, -1},
+ {LogicalType::Int(8, true), ConvertedType::INT_8, Type::INT32, -1, -1, -1},
+ {LogicalType::Int(16, false), ConvertedType::UINT_16, Type::INT32, -1, -1, -1},
+ {LogicalType::Int(16, true), ConvertedType::INT_16, Type::INT32, -1, -1, -1},
+ {LogicalType::Int(32, false), ConvertedType::UINT_32, Type::INT32, -1, -1, -1},
+ {LogicalType::Int(32, true), ConvertedType::INT_32, Type::INT32, -1, -1, -1},
+ {LogicalType::Int(64, false), ConvertedType::UINT_64, Type::INT64, -1, -1, -1},
+ {LogicalType::Int(64, true), ConvertedType::INT_64, Type::INT64, -1, -1, -1},
+ {LogicalType::JSON(), ConvertedType::JSON, Type::BYTE_ARRAY, -1, -1, -1},
+ {LogicalType::BSON(), ConvertedType::BSON, Type::BYTE_ARRAY, -1, -1, -1},
+ {LogicalType::None(), ConvertedType::NONE, Type::INT64, -1, -1, -1}};
+
+ for (const PrimitiveNodeFactoryArguments& c : cases) {
+ ConfirmPrimitiveNodeFactoryEquivalence(c.logical_type, c.converted_type,
+ c.physical_type, c.physical_length,
+ c.precision, c.scale);
+ }
+
+ // Group nodes ...
+ ConfirmGroupNodeFactoryEquivalence("map", LogicalType::Map(), ConvertedType::MAP);
+ ConfirmGroupNodeFactoryEquivalence("list", LogicalType::List(), ConvertedType::LIST);
+}
+
+TEST(TestSchemaNodeCreation, FactoryExceptions) {
+ // Ensure that the Node factory method that accepts a logical type refuses to create
+ // an object if compatibility conditions are not met
+
+ // Nested logical type on non-group node ...
+ ASSERT_ANY_THROW(PrimitiveNode::Make("map", Repetition::REQUIRED,
+ MapLogicalType::Make(), Type::INT64));
+ // Incompatible primitive type ...
+ ASSERT_ANY_THROW(PrimitiveNode::Make("string", Repetition::REQUIRED,
+ StringLogicalType::Make(), Type::BOOLEAN));
+ // Incompatible primitive length ...
+ ASSERT_ANY_THROW(PrimitiveNode::Make("interval", Repetition::REQUIRED,
+ IntervalLogicalType::Make(),
+ Type::FIXED_LEN_BYTE_ARRAY, 11));
+ // Primitive too small for given precision ...
+ ASSERT_ANY_THROW(PrimitiveNode::Make("decimal", Repetition::REQUIRED,
+ DecimalLogicalType::Make(16, 6), Type::INT32));
+ // Incompatible primitive length ...
+ ASSERT_ANY_THROW(PrimitiveNode::Make("uuid", Repetition::REQUIRED,
+ UUIDLogicalType::Make(),
+ Type::FIXED_LEN_BYTE_ARRAY, 64));
+ // Non-positive length argument for fixed length binary ...
+ ASSERT_ANY_THROW(PrimitiveNode::Make("negative_length", Repetition::REQUIRED,
+ NoLogicalType::Make(), Type::FIXED_LEN_BYTE_ARRAY,
+ -16));
+ // Non-positive length argument for fixed length binary ...
+ ASSERT_ANY_THROW(PrimitiveNode::Make("zero_length", Repetition::REQUIRED,
+ NoLogicalType::Make(), Type::FIXED_LEN_BYTE_ARRAY,
+ 0));
+ // Non-nested logical type on group node ...
+ ASSERT_ANY_THROW(
+ GroupNode::Make("list", Repetition::REPEATED, {}, JSONLogicalType::Make()));
+
+ // nullptr logical type arguments convert to NoLogicalType/ConvertedType::NONE
+ std::shared_ptr<const LogicalType> empty;
+ NodePtr node;
+ ASSERT_NO_THROW(
+ node = PrimitiveNode::Make("value", Repetition::REQUIRED, empty, Type::DOUBLE));
+ ASSERT_TRUE(node->logical_type()->is_none());
+ ASSERT_EQ(node->converted_type(), ConvertedType::NONE);
+ ASSERT_NO_THROW(node = GroupNode::Make("items", Repetition::REPEATED, {}, empty));
+ ASSERT_TRUE(node->logical_type()->is_none());
+ ASSERT_EQ(node->converted_type(), ConvertedType::NONE);
+
+ // Invalid ConvertedType in deserialized element ...
+ node = PrimitiveNode::Make("string", Repetition::REQUIRED, StringLogicalType::Make(),
+ Type::BYTE_ARRAY);
+ ASSERT_EQ(node->logical_type()->type(), LogicalType::Type::STRING);
+ ASSERT_TRUE(node->logical_type()->is_valid());
+ ASSERT_TRUE(node->logical_type()->is_serialized());
+ format::SchemaElement string_intermediary;
+ node->ToParquet(&string_intermediary);
+ // ... corrupt the Thrift intermediary ....
+ string_intermediary.logicalType.__isset.STRING = false;
+ ASSERT_ANY_THROW(node = PrimitiveNode::FromParquet(&string_intermediary));
+
+ // Invalid TimeUnit in deserialized TimeLogicalType ...
+ node = PrimitiveNode::Make("time", Repetition::REQUIRED,
+ TimeLogicalType::Make(true, LogicalType::TimeUnit::NANOS),
+ Type::INT64);
+ format::SchemaElement time_intermediary;
+ node->ToParquet(&time_intermediary);
+ // ... corrupt the Thrift intermediary ....
+ time_intermediary.logicalType.TIME.unit.__isset.NANOS = false;
+ ASSERT_ANY_THROW(PrimitiveNode::FromParquet(&time_intermediary));
+
+ // Invalid TimeUnit in deserialized TimestampLogicalType ...
+ node = PrimitiveNode::Make(
+ "timestamp", Repetition::REQUIRED,
+ TimestampLogicalType::Make(true, LogicalType::TimeUnit::NANOS), Type::INT64);
+ format::SchemaElement timestamp_intermediary;
+ node->ToParquet(&timestamp_intermediary);
+ // ... corrupt the Thrift intermediary ....
+ timestamp_intermediary.logicalType.TIMESTAMP.unit.__isset.NANOS = false;
+ ASSERT_ANY_THROW(PrimitiveNode::FromParquet(&timestamp_intermediary));
+}
+
+struct SchemaElementConstructionArguments {
+ std::string name;
+ std::shared_ptr<const LogicalType> logical_type;
+ Type::type physical_type;
+ int physical_length;
+ bool expect_converted_type;
+ ConvertedType::type converted_type;
+ bool expect_logicalType;
+ std::function<bool()> check_logicalType;
+};
+
+struct LegacySchemaElementConstructionArguments {
+ std::string name;
+ Type::type physical_type;
+ int physical_length;
+ bool expect_converted_type;
+ ConvertedType::type converted_type;
+ bool expect_logicalType;
+ std::function<bool()> check_logicalType;
+};
+
+class TestSchemaElementConstruction : public ::testing::Test {
+ public:
+ TestSchemaElementConstruction* Reconstruct(
+ const SchemaElementConstructionArguments& c) {
+ // Make node, create serializable Thrift object from it ...
+ node_ = PrimitiveNode::Make(c.name, Repetition::REQUIRED, c.logical_type,
+ c.physical_type, c.physical_length);
+ element_.reset(new format::SchemaElement);
+ node_->ToParquet(element_.get());
+
+ // ... then set aside some values for later inspection.
+ name_ = c.name;
+ expect_converted_type_ = c.expect_converted_type;
+ converted_type_ = c.converted_type;
+ expect_logicalType_ = c.expect_logicalType;
+ check_logicalType_ = c.check_logicalType;
+ return this;
+ }
+
+ TestSchemaElementConstruction* LegacyReconstruct(
+ const LegacySchemaElementConstructionArguments& c) {
+ // Make node, create serializable Thrift object from it ...
+ node_ = PrimitiveNode::Make(c.name, Repetition::REQUIRED, c.physical_type,
+ c.converted_type, c.physical_length);
+ element_.reset(new format::SchemaElement);
+ node_->ToParquet(element_.get());
+
+ // ... then set aside some values for later inspection.
+ name_ = c.name;
+ expect_converted_type_ = c.expect_converted_type;
+ converted_type_ = c.converted_type;
+ expect_logicalType_ = c.expect_logicalType;
+ check_logicalType_ = c.check_logicalType;
+ return this;
+ }
+
+ void Inspect() {
+ ASSERT_EQ(element_->name, name_);
+ if (expect_converted_type_) {
+ ASSERT_TRUE(element_->__isset.converted_type)
+ << node_->logical_type()->ToString()
+ << " logical type unexpectedly failed to generate a converted type in the "
+ "Thrift "
+ "intermediate object";
+ ASSERT_EQ(element_->converted_type, ToThrift(converted_type_))
+ << node_->logical_type()->ToString()
+ << " logical type unexpectedly failed to generate correct converted type in "
+ "the "
+ "Thrift intermediate object";
+ } else {
+ ASSERT_FALSE(element_->__isset.converted_type)
+ << node_->logical_type()->ToString()
+ << " logical type unexpectedly generated a converted type in the Thrift "
+ "intermediate object";
+ }
+ if (expect_logicalType_) {
+ ASSERT_TRUE(element_->__isset.logicalType)
+ << node_->logical_type()->ToString()
+ << " logical type unexpectedly failed to genverate a logicalType in the Thrift "
+ "intermediate object";
+ ASSERT_TRUE(check_logicalType_())
+ << node_->logical_type()->ToString()
+ << " logical type generated incorrect logicalType "
+ "settings in the Thrift intermediate object";
+ } else {
+ ASSERT_FALSE(element_->__isset.logicalType)
+ << node_->logical_type()->ToString()
+ << " logical type unexpectedly generated a logicalType in the Thrift "
+ "intermediate object";
+ }
+ return;
+ }
+
+ protected:
+ NodePtr node_;
+ std::unique_ptr<format::SchemaElement> element_;
+ std::string name_;
+ bool expect_converted_type_;
+ ConvertedType::type converted_type_; // expected converted type in Thrift object
+ bool expect_logicalType_;
+ std::function<bool()> check_logicalType_; // specialized (by logical type)
+ // logicalType check for Thrift object
+};
+
+/*
+ * The Test*SchemaElementConstruction suites confirm that the logical type
+ * and converted type members of the Thrift intermediate message object
+ * (format::SchemaElement) that is created upon serialization of an annotated
+ * schema node are correctly populated.
+ */
+
+TEST_F(TestSchemaElementConstruction, SimpleCases) {
+ auto check_nothing = []() {
+ return true;
+ }; // used for logical types that don't expect a logicalType to be set
+
+ std::vector<SchemaElementConstructionArguments> cases = {
+ {"string", LogicalType::String(), Type::BYTE_ARRAY, -1, true, ConvertedType::UTF8,
+ true, [this]() { return element_->logicalType.__isset.STRING; }},
+ {"enum", LogicalType::Enum(), Type::BYTE_ARRAY, -1, true, ConvertedType::ENUM, true,
+ [this]() { return element_->logicalType.__isset.ENUM; }},
+ {"date", LogicalType::Date(), Type::INT32, -1, true, ConvertedType::DATE, true,
+ [this]() { return element_->logicalType.__isset.DATE; }},
+ {"interval", LogicalType::Interval(), Type::FIXED_LEN_BYTE_ARRAY, 12, true,
+ ConvertedType::INTERVAL, false, check_nothing},
+ {"null", LogicalType::Null(), Type::DOUBLE, -1, false, ConvertedType::NA, true,
+ [this]() { return element_->logicalType.__isset.UNKNOWN; }},
+ {"json", LogicalType::JSON(), Type::BYTE_ARRAY, -1, true, ConvertedType::JSON, true,
+ [this]() { return element_->logicalType.__isset.JSON; }},
+ {"bson", LogicalType::BSON(), Type::BYTE_ARRAY, -1, true, ConvertedType::BSON, true,
+ [this]() { return element_->logicalType.__isset.BSON; }},
+ {"uuid", LogicalType::UUID(), Type::FIXED_LEN_BYTE_ARRAY, 16, false,
+ ConvertedType::NA, true, [this]() { return element_->logicalType.__isset.UUID; }},
+ {"none", LogicalType::None(), Type::INT64, -1, false, ConvertedType::NA, false,
+ check_nothing}};
+
+ for (const SchemaElementConstructionArguments& c : cases) {
+ this->Reconstruct(c)->Inspect();
+ }
+
+ std::vector<LegacySchemaElementConstructionArguments> legacy_cases = {
+ {"timestamp_ms", Type::INT64, -1, true, ConvertedType::TIMESTAMP_MILLIS, false,
+ check_nothing},
+ {"timestamp_us", Type::INT64, -1, true, ConvertedType::TIMESTAMP_MICROS, false,
+ check_nothing},
+ };
+
+ for (const LegacySchemaElementConstructionArguments& c : legacy_cases) {
+ this->LegacyReconstruct(c)->Inspect();
+ }
+}
+
+class TestDecimalSchemaElementConstruction : public TestSchemaElementConstruction {
+ public:
+ TestDecimalSchemaElementConstruction* Reconstruct(
+ const SchemaElementConstructionArguments& c) {
+ TestSchemaElementConstruction::Reconstruct(c);
+ const auto& decimal_logical_type =
+ checked_cast<const DecimalLogicalType&>(*c.logical_type);
+ precision_ = decimal_logical_type.precision();
+ scale_ = decimal_logical_type.scale();
+ return this;
+ }
+
+ void Inspect() {
+ TestSchemaElementConstruction::Inspect();
+ ASSERT_EQ(element_->precision, precision_);
+ ASSERT_EQ(element_->scale, scale_);
+ ASSERT_EQ(element_->logicalType.DECIMAL.precision, precision_);
+ ASSERT_EQ(element_->logicalType.DECIMAL.scale, scale_);
+ return;
+ }
+
+ protected:
+ int32_t precision_;
+ int32_t scale_;
+};
+
+TEST_F(TestDecimalSchemaElementConstruction, DecimalCases) {
+ auto check_DECIMAL = [this]() { return element_->logicalType.__isset.DECIMAL; };
+
+ std::vector<SchemaElementConstructionArguments> cases = {
+ {"decimal", LogicalType::Decimal(16, 6), Type::INT64, -1, true,
+ ConvertedType::DECIMAL, true, check_DECIMAL},
+ {"decimal", LogicalType::Decimal(1, 0), Type::INT32, -1, true,
+ ConvertedType::DECIMAL, true, check_DECIMAL},
+ {"decimal", LogicalType::Decimal(10), Type::INT64, -1, true, ConvertedType::DECIMAL,
+ true, check_DECIMAL},
+ {"decimal", LogicalType::Decimal(11, 11), Type::INT64, -1, true,
+ ConvertedType::DECIMAL, true, check_DECIMAL},
+ };
+
+ for (const SchemaElementConstructionArguments& c : cases) {
+ this->Reconstruct(c)->Inspect();
+ }
+}
+
+class TestTemporalSchemaElementConstruction : public TestSchemaElementConstruction {
+ public:
+ template <typename T>
+ TestTemporalSchemaElementConstruction* Reconstruct(
+ const SchemaElementConstructionArguments& c) {
+ TestSchemaElementConstruction::Reconstruct(c);
+ const auto& t = checked_cast<const T&>(*c.logical_type);
+ adjusted_ = t.is_adjusted_to_utc();
+ unit_ = t.time_unit();
+ return this;
+ }
+
+ template <typename T>
+ void Inspect() {
+ FAIL() << "Invalid typename specified in test suite";
+ return;
+ }
+
+ protected:
+ bool adjusted_;
+ LogicalType::TimeUnit::unit unit_;
+};
+
+template <>
+void TestTemporalSchemaElementConstruction::Inspect<format::TimeType>() {
+ TestSchemaElementConstruction::Inspect();
+ ASSERT_EQ(element_->logicalType.TIME.isAdjustedToUTC, adjusted_);
+ switch (unit_) {
+ case LogicalType::TimeUnit::MILLIS:
+ ASSERT_TRUE(element_->logicalType.TIME.unit.__isset.MILLIS);
+ break;
+ case LogicalType::TimeUnit::MICROS:
+ ASSERT_TRUE(element_->logicalType.TIME.unit.__isset.MICROS);
+ break;
+ case LogicalType::TimeUnit::NANOS:
+ ASSERT_TRUE(element_->logicalType.TIME.unit.__isset.NANOS);
+ break;
+ case LogicalType::TimeUnit::UNKNOWN:
+ default:
+ FAIL() << "Invalid time unit in test case";
+ }
+ return;
+}
+
+template <>
+void TestTemporalSchemaElementConstruction::Inspect<format::TimestampType>() {
+ TestSchemaElementConstruction::Inspect();
+ ASSERT_EQ(element_->logicalType.TIMESTAMP.isAdjustedToUTC, adjusted_);
+ switch (unit_) {
+ case LogicalType::TimeUnit::MILLIS:
+ ASSERT_TRUE(element_->logicalType.TIMESTAMP.unit.__isset.MILLIS);
+ break;
+ case LogicalType::TimeUnit::MICROS:
+ ASSERT_TRUE(element_->logicalType.TIMESTAMP.unit.__isset.MICROS);
+ break;
+ case LogicalType::TimeUnit::NANOS:
+ ASSERT_TRUE(element_->logicalType.TIMESTAMP.unit.__isset.NANOS);
+ break;
+ case LogicalType::TimeUnit::UNKNOWN:
+ default:
+ FAIL() << "Invalid time unit in test case";
+ }
+ return;
+}
+
+TEST_F(TestTemporalSchemaElementConstruction, TemporalCases) {
+ auto check_TIME = [this]() { return element_->logicalType.__isset.TIME; };
+
+ std::vector<SchemaElementConstructionArguments> time_cases = {
+ {"time_T_ms", LogicalType::Time(true, LogicalType::TimeUnit::MILLIS), Type::INT32,
+ -1, true, ConvertedType::TIME_MILLIS, true, check_TIME},
+ {"time_F_ms", LogicalType::Time(false, LogicalType::TimeUnit::MILLIS), Type::INT32,
+ -1, false, ConvertedType::NA, true, check_TIME},
+ {"time_T_us", LogicalType::Time(true, LogicalType::TimeUnit::MICROS), Type::INT64,
+ -1, true, ConvertedType::TIME_MICROS, true, check_TIME},
+ {"time_F_us", LogicalType::Time(false, LogicalType::TimeUnit::MICROS), Type::INT64,
+ -1, false, ConvertedType::NA, true, check_TIME},
+ {"time_T_ns", LogicalType::Time(true, LogicalType::TimeUnit::NANOS), Type::INT64,
+ -1, false, ConvertedType::NA, true, check_TIME},
+ {"time_F_ns", LogicalType::Time(false, LogicalType::TimeUnit::NANOS), Type::INT64,
+ -1, false, ConvertedType::NA, true, check_TIME},
+ };
+
+ for (const SchemaElementConstructionArguments& c : time_cases) {
+ this->Reconstruct<TimeLogicalType>(c)->Inspect<format::TimeType>();
+ }
+
+ auto check_TIMESTAMP = [this]() { return element_->logicalType.__isset.TIMESTAMP; };
+
+ std::vector<SchemaElementConstructionArguments> timestamp_cases = {
+ {"timestamp_T_ms", LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS),
+ Type::INT64, -1, true, ConvertedType::TIMESTAMP_MILLIS, true, check_TIMESTAMP},
+ {"timestamp_F_ms", LogicalType::Timestamp(false, LogicalType::TimeUnit::MILLIS),
+ Type::INT64, -1, false, ConvertedType::NA, true, check_TIMESTAMP},
+ {"timestamp_F_ms_force",
+ LogicalType::Timestamp(false, LogicalType::TimeUnit::MILLIS,
+ /*is_from_converted_type=*/false,
+ /*force_set_converted_type=*/true),
+ Type::INT64, -1, true, ConvertedType::TIMESTAMP_MILLIS, true, check_TIMESTAMP},
+ {"timestamp_T_us", LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS),
+ Type::INT64, -1, true, ConvertedType::TIMESTAMP_MICROS, true, check_TIMESTAMP},
+ {"timestamp_F_us", LogicalType::Timestamp(false, LogicalType::TimeUnit::MICROS),
+ Type::INT64, -1, false, ConvertedType::NA, true, check_TIMESTAMP},
+ {"timestamp_F_us_force",
+ LogicalType::Timestamp(false, LogicalType::TimeUnit::MILLIS,
+ /*is_from_converted_type=*/false,
+ /*force_set_converted_type=*/true),
+ Type::INT64, -1, true, ConvertedType::TIMESTAMP_MILLIS, true, check_TIMESTAMP},
+ {"timestamp_T_ns", LogicalType::Timestamp(true, LogicalType::TimeUnit::NANOS),
+ Type::INT64, -1, false, ConvertedType::NA, true, check_TIMESTAMP},
+ {"timestamp_F_ns", LogicalType::Timestamp(false, LogicalType::TimeUnit::NANOS),
+ Type::INT64, -1, false, ConvertedType::NA, true, check_TIMESTAMP},
+ };
+
+ for (const SchemaElementConstructionArguments& c : timestamp_cases) {
+ this->Reconstruct<TimestampLogicalType>(c)->Inspect<format::TimestampType>();
+ }
+}
+
+class TestIntegerSchemaElementConstruction : public TestSchemaElementConstruction {
+ public:
+ TestIntegerSchemaElementConstruction* Reconstruct(
+ const SchemaElementConstructionArguments& c) {
+ TestSchemaElementConstruction::Reconstruct(c);
+ const auto& int_logical_type = checked_cast<const IntLogicalType&>(*c.logical_type);
+ width_ = int_logical_type.bit_width();
+ signed_ = int_logical_type.is_signed();
+ return this;
+ }
+
+ void Inspect() {
+ TestSchemaElementConstruction::Inspect();
+ ASSERT_EQ(element_->logicalType.INTEGER.bitWidth, width_);
+ ASSERT_EQ(element_->logicalType.INTEGER.isSigned, signed_);
+ return;
+ }
+
+ protected:
+ int width_;
+ bool signed_;
+};
+
+TEST_F(TestIntegerSchemaElementConstruction, IntegerCases) {
+ auto check_INTEGER = [this]() { return element_->logicalType.__isset.INTEGER; };
+
+ std::vector<SchemaElementConstructionArguments> cases = {
+ {"uint8", LogicalType::Int(8, false), Type::INT32, -1, true, ConvertedType::UINT_8,
+ true, check_INTEGER},
+ {"uint16", LogicalType::Int(16, false), Type::INT32, -1, true,
+ ConvertedType::UINT_16, true, check_INTEGER},
+ {"uint32", LogicalType::Int(32, false), Type::INT32, -1, true,
+ ConvertedType::UINT_32, true, check_INTEGER},
+ {"uint64", LogicalType::Int(64, false), Type::INT64, -1, true,
+ ConvertedType::UINT_64, true, check_INTEGER},
+ {"int8", LogicalType::Int(8, true), Type::INT32, -1, true, ConvertedType::INT_8,
+ true, check_INTEGER},
+ {"int16", LogicalType::Int(16, true), Type::INT32, -1, true, ConvertedType::INT_16,
+ true, check_INTEGER},
+ {"int32", LogicalType::Int(32, true), Type::INT32, -1, true, ConvertedType::INT_32,
+ true, check_INTEGER},
+ {"int64", LogicalType::Int(64, true), Type::INT64, -1, true, ConvertedType::INT_64,
+ true, check_INTEGER},
+ };
+
+ for (const SchemaElementConstructionArguments& c : cases) {
+ this->Reconstruct(c)->Inspect();
+ }
+}
+
+TEST(TestLogicalTypeSerialization, SchemaElementNestedCases) {
+ // Confirm that the intermediate Thrift objects created during node serialization
+ // contain correct ConvertedType and ConvertedType information
+
+ NodePtr string_node = PrimitiveNode::Make("string", Repetition::REQUIRED,
+ StringLogicalType::Make(), Type::BYTE_ARRAY);
+ NodePtr date_node = PrimitiveNode::Make("date", Repetition::REQUIRED,
+ DateLogicalType::Make(), Type::INT32);
+ NodePtr json_node = PrimitiveNode::Make("json", Repetition::REQUIRED,
+ JSONLogicalType::Make(), Type::BYTE_ARRAY);
+ NodePtr uuid_node =
+ PrimitiveNode::Make("uuid", Repetition::REQUIRED, UUIDLogicalType::Make(),
+ Type::FIXED_LEN_BYTE_ARRAY, 16);
+ NodePtr timestamp_node = PrimitiveNode::Make(
+ "timestamp", Repetition::REQUIRED,
+ TimestampLogicalType::Make(false, LogicalType::TimeUnit::NANOS), Type::INT64);
+ NodePtr int_node = PrimitiveNode::Make("int", Repetition::REQUIRED,
+ IntLogicalType::Make(64, false), Type::INT64);
+ NodePtr decimal_node = PrimitiveNode::Make(
+ "decimal", Repetition::REQUIRED, DecimalLogicalType::Make(16, 6), Type::INT64);
+
+ NodePtr list_node = GroupNode::Make("list", Repetition::REPEATED,
+ {string_node, date_node, json_node, uuid_node,
+ timestamp_node, int_node, decimal_node},
+ ListLogicalType::Make());
+ std::vector<format::SchemaElement> list_elements;
+ ToParquet(reinterpret_cast<GroupNode*>(list_node.get()), &list_elements);
+ ASSERT_EQ(list_elements[0].name, "list");
+ ASSERT_TRUE(list_elements[0].__isset.converted_type);
+ ASSERT_TRUE(list_elements[0].__isset.logicalType);
+ ASSERT_EQ(list_elements[0].converted_type, ToThrift(ConvertedType::LIST));
+ ASSERT_TRUE(list_elements[0].logicalType.__isset.LIST);
+ ASSERT_TRUE(list_elements[1].logicalType.__isset.STRING);
+ ASSERT_TRUE(list_elements[2].logicalType.__isset.DATE);
+ ASSERT_TRUE(list_elements[3].logicalType.__isset.JSON);
+ ASSERT_TRUE(list_elements[4].logicalType.__isset.UUID);
+ ASSERT_TRUE(list_elements[5].logicalType.__isset.TIMESTAMP);
+ ASSERT_TRUE(list_elements[6].logicalType.__isset.INTEGER);
+ ASSERT_TRUE(list_elements[7].logicalType.__isset.DECIMAL);
+
+ NodePtr map_node =
+ GroupNode::Make("map", Repetition::REQUIRED, {}, MapLogicalType::Make());
+ std::vector<format::SchemaElement> map_elements;
+ ToParquet(reinterpret_cast<GroupNode*>(map_node.get()), &map_elements);
+ ASSERT_EQ(map_elements[0].name, "map");
+ ASSERT_TRUE(map_elements[0].__isset.converted_type);
+ ASSERT_TRUE(map_elements[0].__isset.logicalType);
+ ASSERT_EQ(map_elements[0].converted_type, ToThrift(ConvertedType::MAP));
+ ASSERT_TRUE(map_elements[0].logicalType.__isset.MAP);
+}
+
+TEST(TestLogicalTypeSerialization, Roundtrips) {
+ // Confirm that Thrift serialization-deserialization of nodes with logical
+ // types produces equivalent reconstituted nodes
+
+ // Primitive nodes ...
+ struct AnnotatedPrimitiveNodeFactoryArguments {
+ std::shared_ptr<const LogicalType> logical_type;
+ Type::type physical_type;
+ int physical_length;
+ };
+
+ std::vector<AnnotatedPrimitiveNodeFactoryArguments> cases = {
+ {LogicalType::String(), Type::BYTE_ARRAY, -1},
+ {LogicalType::Enum(), Type::BYTE_ARRAY, -1},
+ {LogicalType::Decimal(16, 6), Type::INT64, -1},
+ {LogicalType::Date(), Type::INT32, -1},
+ {LogicalType::Time(true, LogicalType::TimeUnit::MILLIS), Type::INT32, -1},
+ {LogicalType::Time(true, LogicalType::TimeUnit::MICROS), Type::INT64, -1},
+ {LogicalType::Time(true, LogicalType::TimeUnit::NANOS), Type::INT64, -1},
+ {LogicalType::Time(false, LogicalType::TimeUnit::MILLIS), Type::INT32, -1},
+ {LogicalType::Time(false, LogicalType::TimeUnit::MICROS), Type::INT64, -1},
+ {LogicalType::Time(false, LogicalType::TimeUnit::NANOS), Type::INT64, -1},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::MILLIS), Type::INT64, -1},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::MICROS), Type::INT64, -1},
+ {LogicalType::Timestamp(true, LogicalType::TimeUnit::NANOS), Type::INT64, -1},
+ {LogicalType::Timestamp(false, LogicalType::TimeUnit::MILLIS), Type::INT64, -1},
+ {LogicalType::Timestamp(false, LogicalType::TimeUnit::MICROS), Type::INT64, -1},
+ {LogicalType::Timestamp(false, LogicalType::TimeUnit::NANOS), Type::INT64, -1},
+ {LogicalType::Interval(), Type::FIXED_LEN_BYTE_ARRAY, 12},
+ {LogicalType::Int(8, false), Type::INT32, -1},
+ {LogicalType::Int(16, false), Type::INT32, -1},
+ {LogicalType::Int(32, false), Type::INT32, -1},
+ {LogicalType::Int(64, false), Type::INT64, -1},
+ {LogicalType::Int(8, true), Type::INT32, -1},
+ {LogicalType::Int(16, true), Type::INT32, -1},
+ {LogicalType::Int(32, true), Type::INT32, -1},
+ {LogicalType::Int(64, true), Type::INT64, -1},
+ {LogicalType::Null(), Type::BOOLEAN, -1},
+ {LogicalType::JSON(), Type::BYTE_ARRAY, -1},
+ {LogicalType::BSON(), Type::BYTE_ARRAY, -1},
+ {LogicalType::UUID(), Type::FIXED_LEN_BYTE_ARRAY, 16},
+ {LogicalType::None(), Type::BOOLEAN, -1}};
+
+ for (const AnnotatedPrimitiveNodeFactoryArguments& c : cases) {
+ ConfirmPrimitiveNodeRoundtrip(c.logical_type, c.physical_type, c.physical_length);
+ }
+
+ // Group nodes ...
+ ConfirmGroupNodeRoundtrip("map", LogicalType::Map());
+ ConfirmGroupNodeRoundtrip("list", LogicalType::List());
+}
+
+} // namespace schema
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/statistics.cc b/src/arrow/cpp/src/parquet/statistics.cc
new file mode 100644
index 000000000..715c1a1ab
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/statistics.cc
@@ -0,0 +1,887 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/statistics.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <limits>
+#include <type_traits>
+#include <utility>
+
+#include "arrow/array.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/ubsan.h"
+#include "arrow/visitor_inline.h"
+#include "parquet/encoding.h"
+#include "parquet/exception.h"
+#include "parquet/platform.h"
+#include "parquet/schema.h"
+
+using arrow::default_memory_pool;
+using arrow::MemoryPool;
+using arrow::internal::checked_cast;
+using arrow::util::SafeCopy;
+
+namespace parquet {
+namespace {
+
+// ----------------------------------------------------------------------
+// Comparator implementations
+
+constexpr int value_length(int value_length, const ByteArray& value) { return value.len; }
+constexpr int value_length(int type_length, const FLBA& value) { return type_length; }
+
+template <typename DType, bool is_signed>
+struct CompareHelper {
+ using T = typename DType::c_type;
+
+ static_assert(!std::is_unsigned<T>::value || std::is_same<T, bool>::value,
+ "T is an unsigned numeric");
+
+ constexpr static T DefaultMin() { return std::numeric_limits<T>::max(); }
+ constexpr static T DefaultMax() { return std::numeric_limits<T>::lowest(); }
+
+ // MSVC17 fix, isnan is not overloaded for IntegralType as per C++11
+ // standard requirements.
+ template <typename T1 = T>
+ static ::arrow::enable_if_t<std::is_floating_point<T1>::value, T> Coalesce(T val,
+ T fallback) {
+ return std::isnan(val) ? fallback : val;
+ }
+
+ template <typename T1 = T>
+ static ::arrow::enable_if_t<!std::is_floating_point<T1>::value, T> Coalesce(
+ T val, T fallback) {
+ return val;
+ }
+
+ static inline bool Compare(int type_length, const T& a, const T& b) { return a < b; }
+
+ static T Min(int type_length, T a, T b) { return a < b ? a : b; }
+ static T Max(int type_length, T a, T b) { return a < b ? b : a; }
+};
+
+template <typename DType>
+struct UnsignedCompareHelperBase {
+ using T = typename DType::c_type;
+ using UCType = typename std::make_unsigned<T>::type;
+
+ static_assert(!std::is_same<T, UCType>::value, "T is unsigned");
+ static_assert(sizeof(T) == sizeof(UCType), "T and UCType not the same size");
+
+ // NOTE: according to the C++ spec, unsigned-to-signed conversion is
+ // implementation-defined if the original value does not fit in the signed type
+ // (i.e., two's complement cannot be assumed even on mainstream machines,
+ // because the compiler may decide otherwise). Hence the use of `SafeCopy`
+ // below for deterministic bit-casting.
+ // (see "Integer conversions" in
+ // https://en.cppreference.com/w/cpp/language/implicit_conversion)
+
+ static const T DefaultMin() { return SafeCopy<T>(std::numeric_limits<UCType>::max()); }
+ static const T DefaultMax() { return 0; }
+
+ static T Coalesce(T val, T fallback) { return val; }
+
+ static bool Compare(int type_length, T a, T b) {
+ return SafeCopy<UCType>(a) < SafeCopy<UCType>(b);
+ }
+
+ static T Min(int type_length, T a, T b) { return Compare(type_length, a, b) ? a : b; }
+ static T Max(int type_length, T a, T b) { return Compare(type_length, a, b) ? b : a; }
+};
+
+template <>
+struct CompareHelper<Int32Type, false> : public UnsignedCompareHelperBase<Int32Type> {};
+
+template <>
+struct CompareHelper<Int64Type, false> : public UnsignedCompareHelperBase<Int64Type> {};
+
+template <bool is_signed>
+struct CompareHelper<Int96Type, is_signed> {
+ using T = typename Int96Type::c_type;
+ using msb_type = typename std::conditional<is_signed, int32_t, uint32_t>::type;
+
+ static T DefaultMin() {
+ uint32_t kMsbMax = SafeCopy<uint32_t>(std::numeric_limits<msb_type>::max());
+ uint32_t kMax = std::numeric_limits<uint32_t>::max();
+ return {kMax, kMax, kMsbMax};
+ }
+ static T DefaultMax() {
+ uint32_t kMsbMin = SafeCopy<uint32_t>(std::numeric_limits<msb_type>::min());
+ uint32_t kMin = std::numeric_limits<uint32_t>::min();
+ return {kMin, kMin, kMsbMin};
+ }
+ static T Coalesce(T val, T fallback) { return val; }
+
+ static inline bool Compare(int type_length, const T& a, const T& b) {
+ if (a.value[2] != b.value[2]) {
+ // Only the MSB bit is by Signed comparison. For little-endian, this is the
+ // last bit of Int96 type.
+ return SafeCopy<msb_type>(a.value[2]) < SafeCopy<msb_type>(b.value[2]);
+ } else if (a.value[1] != b.value[1]) {
+ return (a.value[1] < b.value[1]);
+ }
+ return (a.value[0] < b.value[0]);
+ }
+
+ static T Min(int type_length, const T& a, const T& b) {
+ return Compare(0, a, b) ? a : b;
+ }
+ static T Max(int type_length, const T& a, const T& b) {
+ return Compare(0, a, b) ? b : a;
+ }
+};
+
+template <typename T, bool is_signed>
+struct BinaryLikeComparer {};
+
+template <typename T>
+struct BinaryLikeComparer<T, /*is_signed=*/false> {
+ static bool Compare(int type_length, const T& a, const T& b) {
+ int a_length = value_length(type_length, a);
+ int b_length = value_length(type_length, b);
+ // Unsigned comparison is used for non-numeric types so straight
+ // lexiographic comparison makes sense. (a.ptr is always unsigned)....
+ return std::lexicographical_compare(a.ptr, a.ptr + a_length, b.ptr, b.ptr + b_length);
+ }
+};
+
+template <typename T>
+struct BinaryLikeComparer<T, /*is_signed=*/true> {
+ static bool Compare(int type_length, const T& a, const T& b) {
+ // Is signed is used for integers encoded as big-endian twos
+ // complement integers. (e.g. decimals).
+ int a_length = value_length(type_length, a);
+ int b_length = value_length(type_length, b);
+
+ // At least of the lengths is zero.
+ if (a_length == 0 || b_length == 0) {
+ return a_length == 0 && b_length > 0;
+ }
+
+ int8_t first_a = *a.ptr;
+ int8_t first_b = *b.ptr;
+ // We can short circuit for different signed numbers or
+ // for equal length bytes arrays that have different first bytes.
+ // The equality requirement is necessary for sign extension cases.
+ // 0xFF10 should be eqaul to 0x10 (due to big endian sign extension).
+ if ((0x80 & first_a) != (0x80 & first_b) ||
+ (a_length == b_length && first_a != first_b)) {
+ return first_a < first_b;
+ }
+ // When the lengths are unequal and the numbers are of the same
+ // sign we need to do comparison by sign extending the shorter
+ // value first, and once we get to equal sized arrays, lexicographical
+ // unsigned comparison of everything but the first byte is sufficient.
+ const uint8_t* a_start = a.ptr;
+ const uint8_t* b_start = b.ptr;
+ if (a_length != b_length) {
+ const uint8_t* lead_start = nullptr;
+ const uint8_t* lead_end = nullptr;
+ if (a_length > b_length) {
+ int lead_length = a_length - b_length;
+ lead_start = a.ptr;
+ lead_end = a.ptr + lead_length;
+ a_start += lead_length;
+ } else {
+ DCHECK_LT(a_length, b_length);
+ int lead_length = b_length - a_length;
+ lead_start = b.ptr;
+ lead_end = b.ptr + lead_length;
+ b_start += lead_length;
+ }
+ // Compare extra bytes to the sign extension of the first
+ // byte of the other number.
+ uint8_t extension = first_a < 0 ? 0xFF : 0;
+ bool not_equal = std::any_of(lead_start, lead_end,
+ [extension](uint8_t a) { return extension != a; });
+ if (not_equal) {
+ // Since sign extension are extrema values for unsigned bytes:
+ //
+ // Four cases exist:
+ // negative values:
+ // b is the longer value.
+ // b must be the lesser value: return false
+ // else:
+ // a must be the lesser value: return true
+ //
+ // positive values:
+ // b is the longer value.
+ // values in b must be greater than a: return true
+ // else:
+ // values in a must be greater than b: return false
+ bool negative_values = first_a < 0;
+ bool b_longer = a_length < b_length;
+ return negative_values != b_longer;
+ }
+ } else {
+ a_start++;
+ b_start++;
+ }
+ return std::lexicographical_compare(a_start, a.ptr + a_length, b_start,
+ b.ptr + b_length);
+ }
+};
+
+template <typename DType, bool is_signed>
+struct BinaryLikeCompareHelperBase {
+ using T = typename DType::c_type;
+
+ static T DefaultMin() { return {}; }
+ static T DefaultMax() { return {}; }
+ static T Coalesce(T val, T fallback) { return val; }
+
+ static inline bool Compare(int type_length, const T& a, const T& b) {
+ return BinaryLikeComparer<T, is_signed>::Compare(type_length, a, b);
+ }
+ static T Min(int type_length, const T& a, const T& b) {
+ if (a.ptr == nullptr) return b;
+ if (b.ptr == nullptr) return a;
+ return Compare(type_length, a, b) ? a : b;
+ }
+
+ static T Max(int type_length, const T& a, const T& b) {
+ if (a.ptr == nullptr) return b;
+ if (b.ptr == nullptr) return a;
+ return Compare(type_length, a, b) ? b : a;
+ }
+};
+
+template <bool is_signed>
+struct CompareHelper<ByteArrayType, is_signed>
+ : public BinaryLikeCompareHelperBase<ByteArrayType, is_signed> {};
+
+template <bool is_signed>
+struct CompareHelper<FLBAType, is_signed>
+ : public BinaryLikeCompareHelperBase<FLBAType, is_signed> {};
+
+using ::arrow::util::optional;
+
+template <typename T>
+::arrow::enable_if_t<std::is_integral<T>::value, optional<std::pair<T, T>>>
+CleanStatistic(std::pair<T, T> min_max) {
+ return min_max;
+}
+
+// In case of floating point types, the following rules are applied (as per
+// upstream parquet-mr):
+// - If any of min/max is NaN, return nothing.
+// - If min is 0.0f, replace with -0.0f
+// - If max is -0.0f, replace with 0.0f
+template <typename T>
+::arrow::enable_if_t<std::is_floating_point<T>::value, optional<std::pair<T, T>>>
+CleanStatistic(std::pair<T, T> min_max) {
+ T min = min_max.first;
+ T max = min_max.second;
+
+ // Ignore if one of the value is nan.
+ if (std::isnan(min) || std::isnan(max)) {
+ return ::arrow::util::nullopt;
+ }
+
+ if (min == std::numeric_limits<T>::max() && max == std::numeric_limits<T>::lowest()) {
+ return ::arrow::util::nullopt;
+ }
+
+ T zero{};
+
+ if (min == zero && !std::signbit(min)) {
+ min = -min;
+ }
+
+ if (max == zero && std::signbit(max)) {
+ max = -max;
+ }
+
+ return {{min, max}};
+}
+
+optional<std::pair<FLBA, FLBA>> CleanStatistic(std::pair<FLBA, FLBA> min_max) {
+ if (min_max.first.ptr == nullptr || min_max.second.ptr == nullptr) {
+ return ::arrow::util::nullopt;
+ }
+ return min_max;
+}
+
+optional<std::pair<ByteArray, ByteArray>> CleanStatistic(
+ std::pair<ByteArray, ByteArray> min_max) {
+ if (min_max.first.ptr == nullptr || min_max.second.ptr == nullptr) {
+ return ::arrow::util::nullopt;
+ }
+ return min_max;
+}
+
+template <bool is_signed, typename DType>
+class TypedComparatorImpl : virtual public TypedComparator<DType> {
+ public:
+ using T = typename DType::c_type;
+ using Helper = CompareHelper<DType, is_signed>;
+
+ explicit TypedComparatorImpl(int type_length = -1) : type_length_(type_length) {}
+
+ bool CompareInline(const T& a, const T& b) const {
+ return Helper::Compare(type_length_, a, b);
+ }
+
+ bool Compare(const T& a, const T& b) override { return CompareInline(a, b); }
+
+ std::pair<T, T> GetMinMax(const T* values, int64_t length) override {
+ DCHECK_GT(length, 0);
+
+ T min = Helper::DefaultMin();
+ T max = Helper::DefaultMax();
+
+ for (int64_t i = 0; i < length; i++) {
+ auto val = values[i];
+ min = Helper::Min(type_length_, min, Helper::Coalesce(val, Helper::DefaultMin()));
+ max = Helper::Max(type_length_, max, Helper::Coalesce(val, Helper::DefaultMax()));
+ }
+
+ return {min, max};
+ }
+
+ std::pair<T, T> GetMinMaxSpaced(const T* values, int64_t length,
+ const uint8_t* valid_bits,
+ int64_t valid_bits_offset) override {
+ DCHECK_GT(length, 0);
+
+ T min = Helper::DefaultMin();
+ T max = Helper::DefaultMax();
+
+ ::arrow::internal::VisitSetBitRunsVoid(
+ valid_bits, valid_bits_offset, length, [&](int64_t position, int64_t length) {
+ for (int64_t i = 0; i < length; i++) {
+ const auto val = values[i + position];
+ min = Helper::Min(type_length_, min,
+ Helper::Coalesce(val, Helper::DefaultMin()));
+ max = Helper::Max(type_length_, max,
+ Helper::Coalesce(val, Helper::DefaultMax()));
+ }
+ });
+
+ return {min, max};
+ }
+
+ std::pair<T, T> GetMinMax(const ::arrow::Array& values) override;
+
+ private:
+ int type_length_;
+};
+
+// ARROW-11675: A hand-written version of GetMinMax(), to work around
+// what looks like a MSVC code generation bug.
+// This does not seem to be required for GetMinMaxSpaced().
+template <>
+std::pair<int32_t, int32_t>
+TypedComparatorImpl</*is_signed=*/false, Int32Type>::GetMinMax(const int32_t* values,
+ int64_t length) {
+ DCHECK_GT(length, 0);
+
+ const uint32_t* unsigned_values = reinterpret_cast<const uint32_t*>(values);
+ uint32_t min = std::numeric_limits<uint32_t>::max();
+ uint32_t max = std::numeric_limits<uint32_t>::lowest();
+
+ for (int64_t i = 0; i < length; i++) {
+ const auto val = unsigned_values[i];
+ min = std::min<uint32_t>(min, val);
+ max = std::max<uint32_t>(max, val);
+ }
+
+ return {SafeCopy<int32_t>(min), SafeCopy<int32_t>(max)};
+}
+
+template <bool is_signed, typename DType>
+std::pair<typename DType::c_type, typename DType::c_type>
+TypedComparatorImpl<is_signed, DType>::GetMinMax(const ::arrow::Array& values) {
+ ParquetException::NYI(values.type()->ToString());
+}
+
+template <bool is_signed>
+std::pair<ByteArray, ByteArray> GetMinMaxBinaryHelper(
+ const TypedComparatorImpl<is_signed, ByteArrayType>& comparator,
+ const ::arrow::Array& values) {
+ using Helper = CompareHelper<ByteArrayType, is_signed>;
+
+ ByteArray min = Helper::DefaultMin();
+ ByteArray max = Helper::DefaultMax();
+ constexpr int type_length = -1;
+
+ const auto valid_func = [&](ByteArray val) {
+ min = Helper::Min(type_length, val, min);
+ max = Helper::Max(type_length, val, max);
+ };
+ const auto null_func = [&]() {};
+
+ if (::arrow::is_binary_like(values.type_id())) {
+ ::arrow::VisitArrayDataInline<::arrow::BinaryType>(
+ *values.data(), std::move(valid_func), std::move(null_func));
+ } else {
+ DCHECK(::arrow::is_large_binary_like(values.type_id()));
+ ::arrow::VisitArrayDataInline<::arrow::LargeBinaryType>(
+ *values.data(), std::move(valid_func), std::move(null_func));
+ }
+
+ return {min, max};
+}
+
+template <>
+std::pair<ByteArray, ByteArray> TypedComparatorImpl<true, ByteArrayType>::GetMinMax(
+ const ::arrow::Array& values) {
+ return GetMinMaxBinaryHelper<true>(*this, values);
+}
+
+template <>
+std::pair<ByteArray, ByteArray> TypedComparatorImpl<false, ByteArrayType>::GetMinMax(
+ const ::arrow::Array& values) {
+ return GetMinMaxBinaryHelper<false>(*this, values);
+}
+
+template <typename DType>
+class TypedStatisticsImpl : public TypedStatistics<DType> {
+ public:
+ using T = typename DType::c_type;
+
+ TypedStatisticsImpl(const ColumnDescriptor* descr, MemoryPool* pool)
+ : descr_(descr),
+ pool_(pool),
+ min_buffer_(AllocateBuffer(pool_, 0)),
+ max_buffer_(AllocateBuffer(pool_, 0)) {
+ auto comp = Comparator::Make(descr);
+ comparator_ = std::static_pointer_cast<TypedComparator<DType>>(comp);
+ Reset();
+ has_null_count_ = true;
+ has_distinct_count_ = true;
+ }
+
+ TypedStatisticsImpl(const T& min, const T& max, int64_t num_values, int64_t null_count,
+ int64_t distinct_count)
+ : pool_(default_memory_pool()),
+ min_buffer_(AllocateBuffer(pool_, 0)),
+ max_buffer_(AllocateBuffer(pool_, 0)) {
+ IncrementNumValues(num_values);
+ IncrementNullCount(null_count);
+ IncrementDistinctCount(distinct_count);
+
+ Copy(min, &min_, min_buffer_.get());
+ Copy(max, &max_, max_buffer_.get());
+ has_min_max_ = true;
+ }
+
+ TypedStatisticsImpl(const ColumnDescriptor* descr, const std::string& encoded_min,
+ const std::string& encoded_max, int64_t num_values,
+ int64_t null_count, int64_t distinct_count, bool has_min_max,
+ bool has_null_count, bool has_distinct_count, MemoryPool* pool)
+ : TypedStatisticsImpl(descr, pool) {
+ IncrementNumValues(num_values);
+ if (has_null_count_) {
+ IncrementNullCount(null_count);
+ }
+ if (has_distinct_count) {
+ IncrementDistinctCount(distinct_count);
+ }
+
+ if (!encoded_min.empty()) {
+ PlainDecode(encoded_min, &min_);
+ }
+ if (!encoded_max.empty()) {
+ PlainDecode(encoded_max, &max_);
+ }
+ has_min_max_ = has_min_max;
+ }
+
+ bool HasDistinctCount() const override { return has_distinct_count_; };
+ bool HasMinMax() const override { return has_min_max_; }
+ bool HasNullCount() const override { return has_null_count_; };
+
+ void IncrementNullCount(int64_t n) override {
+ statistics_.null_count += n;
+ has_null_count_ = true;
+ }
+
+ void IncrementNumValues(int64_t n) override { num_values_ += n; }
+
+ bool Equals(const Statistics& raw_other) const override {
+ if (physical_type() != raw_other.physical_type()) return false;
+
+ const auto& other = checked_cast<const TypedStatisticsImpl&>(raw_other);
+
+ if (has_min_max_ != other.has_min_max_) return false;
+
+ return (has_min_max_ && MinMaxEqual(other)) && null_count() == other.null_count() &&
+ distinct_count() == other.distinct_count() &&
+ num_values() == other.num_values();
+ }
+
+ bool MinMaxEqual(const TypedStatisticsImpl& other) const;
+
+ void Reset() override {
+ ResetCounts();
+ has_min_max_ = false;
+ has_distinct_count_ = false;
+ has_null_count_ = false;
+ }
+
+ void SetMinMax(const T& arg_min, const T& arg_max) override {
+ SetMinMaxPair({arg_min, arg_max});
+ }
+
+ void Merge(const TypedStatistics<DType>& other) override {
+ this->num_values_ += other.num_values();
+ if (other.HasNullCount()) {
+ this->statistics_.null_count += other.null_count();
+ }
+ if (other.HasDistinctCount()) {
+ this->statistics_.distinct_count += other.distinct_count();
+ }
+ if (other.HasMinMax()) {
+ SetMinMax(other.min(), other.max());
+ }
+ }
+
+ void Update(const T* values, int64_t num_not_null, int64_t num_null) override;
+ void UpdateSpaced(const T* values, const uint8_t* valid_bits, int64_t valid_bits_offset,
+ int64_t num_spaced_values, int64_t num_not_null,
+ int64_t num_null) override;
+
+ void Update(const ::arrow::Array& values, bool update_counts) override {
+ if (update_counts) {
+ IncrementNullCount(values.null_count());
+ IncrementNumValues(values.length() - values.null_count());
+ }
+
+ if (values.null_count() == values.length()) {
+ return;
+ }
+
+ SetMinMaxPair(comparator_->GetMinMax(values));
+ }
+
+ const T& min() const override { return min_; }
+
+ const T& max() const override { return max_; }
+
+ Type::type physical_type() const override { return descr_->physical_type(); }
+
+ const ColumnDescriptor* descr() const override { return descr_; }
+
+ std::string EncodeMin() const override {
+ std::string s;
+ if (HasMinMax()) this->PlainEncode(min_, &s);
+ return s;
+ }
+
+ std::string EncodeMax() const override {
+ std::string s;
+ if (HasMinMax()) this->PlainEncode(max_, &s);
+ return s;
+ }
+
+ EncodedStatistics Encode() override {
+ EncodedStatistics s;
+ if (HasMinMax()) {
+ s.set_min(this->EncodeMin());
+ s.set_max(this->EncodeMax());
+ }
+ if (HasNullCount()) {
+ s.set_null_count(this->null_count());
+ }
+ return s;
+ }
+
+ int64_t null_count() const override { return statistics_.null_count; }
+ int64_t distinct_count() const override { return statistics_.distinct_count; }
+ int64_t num_values() const override { return num_values_; }
+
+ private:
+ const ColumnDescriptor* descr_;
+ bool has_min_max_ = false;
+ bool has_null_count_ = false;
+ bool has_distinct_count_ = false;
+ T min_;
+ T max_;
+ ::arrow::MemoryPool* pool_;
+ int64_t num_values_ = 0;
+ EncodedStatistics statistics_;
+ std::shared_ptr<TypedComparator<DType>> comparator_;
+ std::shared_ptr<ResizableBuffer> min_buffer_, max_buffer_;
+
+ void PlainEncode(const T& src, std::string* dst) const;
+ void PlainDecode(const std::string& src, T* dst) const;
+
+ void Copy(const T& src, T* dst, ResizableBuffer*) { *dst = src; }
+
+ void IncrementDistinctCount(int64_t n) {
+ statistics_.distinct_count += n;
+ has_distinct_count_ = true;
+ }
+
+ void ResetCounts() {
+ this->statistics_.null_count = 0;
+ this->statistics_.distinct_count = 0;
+ this->num_values_ = 0;
+ }
+
+ void SetMinMaxPair(std::pair<T, T> min_max) {
+ // CleanStatistic can return a nullopt in case of erroneous values, e.g. NaN
+ auto maybe_min_max = CleanStatistic(min_max);
+ if (!maybe_min_max) return;
+
+ auto min = maybe_min_max.value().first;
+ auto max = maybe_min_max.value().second;
+
+ if (!has_min_max_) {
+ has_min_max_ = true;
+ Copy(min, &min_, min_buffer_.get());
+ Copy(max, &max_, max_buffer_.get());
+ } else {
+ Copy(comparator_->Compare(min_, min) ? min_ : min, &min_, min_buffer_.get());
+ Copy(comparator_->Compare(max_, max) ? max : max_, &max_, max_buffer_.get());
+ }
+ }
+};
+
+template <>
+inline bool TypedStatisticsImpl<FLBAType>::MinMaxEqual(
+ const TypedStatisticsImpl<FLBAType>& other) const {
+ uint32_t len = descr_->type_length();
+ return std::memcmp(min_.ptr, other.min_.ptr, len) == 0 &&
+ std::memcmp(max_.ptr, other.max_.ptr, len) == 0;
+}
+
+template <typename DType>
+bool TypedStatisticsImpl<DType>::MinMaxEqual(
+ const TypedStatisticsImpl<DType>& other) const {
+ return min_ != other.min_ && max_ != other.max_;
+}
+
+template <>
+inline void TypedStatisticsImpl<FLBAType>::Copy(const FLBA& src, FLBA* dst,
+ ResizableBuffer* buffer) {
+ if (dst->ptr == src.ptr) return;
+ uint32_t len = descr_->type_length();
+ PARQUET_THROW_NOT_OK(buffer->Resize(len, false));
+ std::memcpy(buffer->mutable_data(), src.ptr, len);
+ *dst = FLBA(buffer->data());
+}
+
+template <>
+inline void TypedStatisticsImpl<ByteArrayType>::Copy(const ByteArray& src, ByteArray* dst,
+ ResizableBuffer* buffer) {
+ if (dst->ptr == src.ptr) return;
+ PARQUET_THROW_NOT_OK(buffer->Resize(src.len, false));
+ std::memcpy(buffer->mutable_data(), src.ptr, src.len);
+ *dst = ByteArray(src.len, buffer->data());
+}
+
+template <typename DType>
+void TypedStatisticsImpl<DType>::Update(const T* values, int64_t num_not_null,
+ int64_t num_null) {
+ DCHECK_GE(num_not_null, 0);
+ DCHECK_GE(num_null, 0);
+
+ IncrementNullCount(num_null);
+ IncrementNumValues(num_not_null);
+
+ if (num_not_null == 0) return;
+ SetMinMaxPair(comparator_->GetMinMax(values, num_not_null));
+}
+
+template <typename DType>
+void TypedStatisticsImpl<DType>::UpdateSpaced(const T* values, const uint8_t* valid_bits,
+ int64_t valid_bits_offset,
+ int64_t num_spaced_values,
+ int64_t num_not_null, int64_t num_null) {
+ DCHECK_GE(num_not_null, 0);
+ DCHECK_GE(num_null, 0);
+
+ IncrementNullCount(num_null);
+ IncrementNumValues(num_not_null);
+
+ if (num_not_null == 0) return;
+ SetMinMaxPair(comparator_->GetMinMaxSpaced(values, num_spaced_values, valid_bits,
+ valid_bits_offset));
+}
+
+template <typename DType>
+void TypedStatisticsImpl<DType>::PlainEncode(const T& src, std::string* dst) const {
+ auto encoder = MakeTypedEncoder<DType>(Encoding::PLAIN, false, descr_, pool_);
+ encoder->Put(&src, 1);
+ auto buffer = encoder->FlushValues();
+ auto ptr = reinterpret_cast<const char*>(buffer->data());
+ dst->assign(ptr, buffer->size());
+}
+
+template <typename DType>
+void TypedStatisticsImpl<DType>::PlainDecode(const std::string& src, T* dst) const {
+ auto decoder = MakeTypedDecoder<DType>(Encoding::PLAIN, descr_);
+ decoder->SetData(1, reinterpret_cast<const uint8_t*>(src.c_str()),
+ static_cast<int>(src.size()));
+ decoder->Decode(dst, 1);
+}
+
+template <>
+void TypedStatisticsImpl<ByteArrayType>::PlainEncode(const T& src,
+ std::string* dst) const {
+ dst->assign(reinterpret_cast<const char*>(src.ptr), src.len);
+}
+
+template <>
+void TypedStatisticsImpl<ByteArrayType>::PlainDecode(const std::string& src,
+ T* dst) const {
+ dst->len = static_cast<uint32_t>(src.size());
+ dst->ptr = reinterpret_cast<const uint8_t*>(src.c_str());
+}
+
+} // namespace
+
+// ----------------------------------------------------------------------
+// Public factory functions
+
+std::shared_ptr<Comparator> Comparator::Make(Type::type physical_type,
+ SortOrder::type sort_order,
+ int type_length) {
+ if (SortOrder::SIGNED == sort_order) {
+ switch (physical_type) {
+ case Type::BOOLEAN:
+ return std::make_shared<TypedComparatorImpl<true, BooleanType>>();
+ case Type::INT32:
+ return std::make_shared<TypedComparatorImpl<true, Int32Type>>();
+ case Type::INT64:
+ return std::make_shared<TypedComparatorImpl<true, Int64Type>>();
+ case Type::INT96:
+ return std::make_shared<TypedComparatorImpl<true, Int96Type>>();
+ case Type::FLOAT:
+ return std::make_shared<TypedComparatorImpl<true, FloatType>>();
+ case Type::DOUBLE:
+ return std::make_shared<TypedComparatorImpl<true, DoubleType>>();
+ case Type::BYTE_ARRAY:
+ return std::make_shared<TypedComparatorImpl<true, ByteArrayType>>();
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return std::make_shared<TypedComparatorImpl<true, FLBAType>>(type_length);
+ default:
+ ParquetException::NYI("Signed Compare not implemented");
+ }
+ } else if (SortOrder::UNSIGNED == sort_order) {
+ switch (physical_type) {
+ case Type::INT32:
+ return std::make_shared<TypedComparatorImpl<false, Int32Type>>();
+ case Type::INT64:
+ return std::make_shared<TypedComparatorImpl<false, Int64Type>>();
+ case Type::INT96:
+ return std::make_shared<TypedComparatorImpl<false, Int96Type>>();
+ case Type::BYTE_ARRAY:
+ return std::make_shared<TypedComparatorImpl<false, ByteArrayType>>();
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return std::make_shared<TypedComparatorImpl<false, FLBAType>>(type_length);
+ default:
+ ParquetException::NYI("Unsigned Compare not implemented");
+ }
+ } else {
+ throw ParquetException("UNKNOWN Sort Order");
+ }
+ return nullptr;
+}
+
+std::shared_ptr<Comparator> Comparator::Make(const ColumnDescriptor* descr) {
+ return Make(descr->physical_type(), descr->sort_order(), descr->type_length());
+}
+
+std::shared_ptr<Statistics> Statistics::Make(const ColumnDescriptor* descr,
+ ::arrow::MemoryPool* pool) {
+ switch (descr->physical_type()) {
+ case Type::BOOLEAN:
+ return std::make_shared<TypedStatisticsImpl<BooleanType>>(descr, pool);
+ case Type::INT32:
+ return std::make_shared<TypedStatisticsImpl<Int32Type>>(descr, pool);
+ case Type::INT64:
+ return std::make_shared<TypedStatisticsImpl<Int64Type>>(descr, pool);
+ case Type::FLOAT:
+ return std::make_shared<TypedStatisticsImpl<FloatType>>(descr, pool);
+ case Type::DOUBLE:
+ return std::make_shared<TypedStatisticsImpl<DoubleType>>(descr, pool);
+ case Type::BYTE_ARRAY:
+ return std::make_shared<TypedStatisticsImpl<ByteArrayType>>(descr, pool);
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return std::make_shared<TypedStatisticsImpl<FLBAType>>(descr, pool);
+ default:
+ ParquetException::NYI("Statistics not implemented");
+ }
+}
+
+std::shared_ptr<Statistics> Statistics::Make(Type::type physical_type, const void* min,
+ const void* max, int64_t num_values,
+ int64_t null_count, int64_t distinct_count) {
+#define MAKE_STATS(CAP_TYPE, KLASS) \
+ case Type::CAP_TYPE: \
+ return std::make_shared<TypedStatisticsImpl<KLASS>>( \
+ *reinterpret_cast<const typename KLASS::c_type*>(min), \
+ *reinterpret_cast<const typename KLASS::c_type*>(max), num_values, null_count, \
+ distinct_count)
+
+ switch (physical_type) {
+ MAKE_STATS(BOOLEAN, BooleanType);
+ MAKE_STATS(INT32, Int32Type);
+ MAKE_STATS(INT64, Int64Type);
+ MAKE_STATS(FLOAT, FloatType);
+ MAKE_STATS(DOUBLE, DoubleType);
+ MAKE_STATS(BYTE_ARRAY, ByteArrayType);
+ MAKE_STATS(FIXED_LEN_BYTE_ARRAY, FLBAType);
+ default:
+ break;
+ }
+#undef MAKE_STATS
+ DCHECK(false) << "Cannot reach here";
+ return nullptr;
+}
+
+std::shared_ptr<Statistics> Statistics::Make(const ColumnDescriptor* descr,
+ const std::string& encoded_min,
+ const std::string& encoded_max,
+ int64_t num_values, int64_t null_count,
+ int64_t distinct_count, bool has_min_max,
+ bool has_null_count, bool has_distinct_count,
+ ::arrow::MemoryPool* pool) {
+#define MAKE_STATS(CAP_TYPE, KLASS) \
+ case Type::CAP_TYPE: \
+ return std::make_shared<TypedStatisticsImpl<KLASS>>( \
+ descr, encoded_min, encoded_max, num_values, null_count, distinct_count, \
+ has_min_max, has_null_count, has_distinct_count, pool)
+
+ switch (descr->physical_type()) {
+ MAKE_STATS(BOOLEAN, BooleanType);
+ MAKE_STATS(INT32, Int32Type);
+ MAKE_STATS(INT64, Int64Type);
+ MAKE_STATS(FLOAT, FloatType);
+ MAKE_STATS(DOUBLE, DoubleType);
+ MAKE_STATS(BYTE_ARRAY, ByteArrayType);
+ MAKE_STATS(FIXED_LEN_BYTE_ARRAY, FLBAType);
+ default:
+ break;
+ }
+#undef MAKE_STATS
+ DCHECK(false) << "Cannot reach here";
+ return nullptr;
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/statistics.h b/src/arrow/cpp/src/parquet/statistics.h
new file mode 100644
index 000000000..ac7abda90
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/statistics.h
@@ -0,0 +1,367 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "parquet/platform.h"
+#include "parquet/types.h"
+
+namespace arrow {
+
+class Array;
+class BinaryArray;
+
+} // namespace arrow
+
+namespace parquet {
+
+class ColumnDescriptor;
+
+// ----------------------------------------------------------------------
+// Value comparator interfaces
+
+/// \brief Base class for value comparators. Generally used with
+/// TypedComparator<T>
+class PARQUET_EXPORT Comparator {
+ public:
+ virtual ~Comparator() {}
+
+ /// \brief Create a comparator explicitly from physical type and
+ /// sort order
+ /// \param[in] physical_type the physical type for the typed
+ /// comparator
+ /// \param[in] sort_order either SortOrder::SIGNED or
+ /// SortOrder::UNSIGNED
+ /// \param[in] type_length for FIXED_LEN_BYTE_ARRAY only
+ static std::shared_ptr<Comparator> Make(Type::type physical_type,
+ SortOrder::type sort_order,
+ int type_length = -1);
+
+ /// \brief Create typed comparator inferring default sort order from
+ /// ColumnDescriptor
+ /// \param[in] descr the Parquet column schema
+ static std::shared_ptr<Comparator> Make(const ColumnDescriptor* descr);
+};
+
+/// \brief Interface for comparison of physical types according to the
+/// semantics of a particular logical type.
+template <typename DType>
+class TypedComparator : public Comparator {
+ public:
+ using T = typename DType::c_type;
+
+ /// \brief Scalar comparison of two elements, return true if first
+ /// is strictly less than the second
+ virtual bool Compare(const T& a, const T& b) = 0;
+
+ /// \brief Compute maximum and minimum elements in a batch of
+ /// elements without any nulls
+ virtual std::pair<T, T> GetMinMax(const T* values, int64_t length) = 0;
+
+ /// \brief Compute minimum and maximum elements from an Arrow array. Only
+ /// valid for certain Parquet Type / Arrow Type combinations, like BYTE_ARRAY
+ /// / arrow::BinaryArray
+ virtual std::pair<T, T> GetMinMax(const ::arrow::Array& values) = 0;
+
+ /// \brief Compute maximum and minimum elements in a batch of
+ /// elements with accompanying bitmap indicating which elements are
+ /// included (bit set) and excluded (bit not set)
+ ///
+ /// \param[in] values the sequence of values
+ /// \param[in] length the length of the sequence
+ /// \param[in] valid_bits a bitmap indicating which elements are
+ /// included (1) or excluded (0)
+ /// \param[in] valid_bits_offset the bit offset into the bitmap of
+ /// the first element in the sequence
+ virtual std::pair<T, T> GetMinMaxSpaced(const T* values, int64_t length,
+ const uint8_t* valid_bits,
+ int64_t valid_bits_offset) = 0;
+};
+
+/// \brief Typed version of Comparator::Make
+template <typename DType>
+std::shared_ptr<TypedComparator<DType>> MakeComparator(Type::type physical_type,
+ SortOrder::type sort_order,
+ int type_length = -1) {
+ return std::static_pointer_cast<TypedComparator<DType>>(
+ Comparator::Make(physical_type, sort_order, type_length));
+}
+
+/// \brief Typed version of Comparator::Make
+template <typename DType>
+std::shared_ptr<TypedComparator<DType>> MakeComparator(const ColumnDescriptor* descr) {
+ return std::static_pointer_cast<TypedComparator<DType>>(Comparator::Make(descr));
+}
+
+// ----------------------------------------------------------------------
+
+/// \brief Structure represented encoded statistics to be written to
+/// and from Parquet serialized metadata
+class PARQUET_EXPORT EncodedStatistics {
+ std::shared_ptr<std::string> max_, min_;
+ bool is_signed_ = false;
+
+ public:
+ EncodedStatistics()
+ : max_(std::make_shared<std::string>()), min_(std::make_shared<std::string>()) {}
+
+ const std::string& max() const { return *max_; }
+ const std::string& min() const { return *min_; }
+
+ int64_t null_count = 0;
+ int64_t distinct_count = 0;
+
+ bool has_min = false;
+ bool has_max = false;
+ bool has_null_count = false;
+ bool has_distinct_count = false;
+
+ // From parquet-mr
+ // Don't write stats larger than the max size rather than truncating. The
+ // rationale is that some engines may use the minimum value in the page as
+ // the true minimum for aggregations and there is no way to mark that a
+ // value has been truncated and is a lower bound and not in the page.
+ void ApplyStatSizeLimits(size_t length) {
+ if (max_->length() > length) {
+ has_max = false;
+ }
+ if (min_->length() > length) {
+ has_min = false;
+ }
+ }
+
+ bool is_set() const {
+ return has_min || has_max || has_null_count || has_distinct_count;
+ }
+
+ bool is_signed() const { return is_signed_; }
+
+ void set_is_signed(bool is_signed) { is_signed_ = is_signed; }
+
+ EncodedStatistics& set_max(const std::string& value) {
+ *max_ = value;
+ has_max = true;
+ return *this;
+ }
+
+ EncodedStatistics& set_min(const std::string& value) {
+ *min_ = value;
+ has_min = true;
+ return *this;
+ }
+
+ EncodedStatistics& set_null_count(int64_t value) {
+ null_count = value;
+ has_null_count = true;
+ return *this;
+ }
+
+ EncodedStatistics& set_distinct_count(int64_t value) {
+ distinct_count = value;
+ has_distinct_count = true;
+ return *this;
+ }
+};
+
+/// \brief Base type for computing column statistics while writing a file
+class PARQUET_EXPORT Statistics {
+ public:
+ virtual ~Statistics() {}
+
+ /// \brief Create a new statistics instance given a column schema
+ /// definition
+ /// \param[in] descr the column schema
+ /// \param[in] pool a memory pool to use for any memory allocations, optional
+ static std::shared_ptr<Statistics> Make(
+ const ColumnDescriptor* descr,
+ ::arrow::MemoryPool* pool = ::arrow::default_memory_pool());
+
+ /// \brief Create a new statistics instance given a column schema
+ /// definition and pre-existing state
+ /// \param[in] descr the column schema
+ /// \param[in] encoded_min the encoded minimum value
+ /// \param[in] encoded_max the encoded maximum value
+ /// \param[in] num_values total number of values
+ /// \param[in] null_count number of null values
+ /// \param[in] distinct_count number of distinct values
+ /// \param[in] has_min_max whether the min/max statistics are set
+ /// \param[in] has_null_count whether the null_count statistics are set
+ /// \param[in] has_distinct_count whether the distinct_count statistics are set
+ /// \param[in] pool a memory pool to use for any memory allocations, optional
+ static std::shared_ptr<Statistics> Make(
+ const ColumnDescriptor* descr, const std::string& encoded_min,
+ const std::string& encoded_max, int64_t num_values, int64_t null_count,
+ int64_t distinct_count, bool has_min_max, bool has_null_count,
+ bool has_distinct_count,
+ ::arrow::MemoryPool* pool = ::arrow::default_memory_pool());
+
+ /// \brief Return true if the count of null values is set
+ virtual bool HasNullCount() const = 0;
+
+ /// \brief The number of null values, may not be set
+ virtual int64_t null_count() const = 0;
+
+ /// \brief Return true if the count of distinct values is set
+ virtual bool HasDistinctCount() const = 0;
+
+ /// \brief The number of distinct values, may not be set
+ virtual int64_t distinct_count() const = 0;
+
+ /// \brief The total number of values in the column
+ virtual int64_t num_values() const = 0;
+
+ /// \brief Return true if the min and max statistics are set. Obtain
+ /// with TypedStatistics<T>::min and max
+ virtual bool HasMinMax() const = 0;
+
+ /// \brief Reset state of object to initial (no data observed) state
+ virtual void Reset() = 0;
+
+ /// \brief Plain-encoded minimum value
+ virtual std::string EncodeMin() const = 0;
+
+ /// \brief Plain-encoded maximum value
+ virtual std::string EncodeMax() const = 0;
+
+ /// \brief The finalized encoded form of the statistics for transport
+ virtual EncodedStatistics Encode() = 0;
+
+ /// \brief The physical type of the column schema
+ virtual Type::type physical_type() const = 0;
+
+ /// \brief The full type descriptor from the column schema
+ virtual const ColumnDescriptor* descr() const = 0;
+
+ /// \brief Check two Statistics for equality
+ virtual bool Equals(const Statistics& other) const = 0;
+
+ protected:
+ static std::shared_ptr<Statistics> Make(Type::type physical_type, const void* min,
+ const void* max, int64_t num_values,
+ int64_t null_count, int64_t distinct_count);
+};
+
+/// \brief A typed implementation of Statistics
+template <typename DType>
+class TypedStatistics : public Statistics {
+ public:
+ using T = typename DType::c_type;
+
+ /// \brief The current minimum value
+ virtual const T& min() const = 0;
+
+ /// \brief The current maximum value
+ virtual const T& max() const = 0;
+
+ /// \brief Update state with state of another Statistics object
+ virtual void Merge(const TypedStatistics<DType>& other) = 0;
+
+ /// \brief Batch statistics update
+ virtual void Update(const T* values, int64_t num_not_null, int64_t num_null) = 0;
+
+ /// \brief Batch statistics update with supplied validity bitmap
+ /// \param[in] values pointer to column values
+ /// \param[in] valid_bits Pointer to bitmap representing if values are non-null.
+ /// \param[in] valid_bits_offset Offset offset into valid_bits where the slice of
+ /// data begins.
+ /// \param[in] num_spaced_values The length of values in values/valid_bits to inspect
+ /// when calculating statistics. This can be smaller than
+ /// num_not_null+num_null as num_null can include nulls
+ /// from parents while num_spaced_values does not.
+ /// \param[in] num_not_null Number of values that are not null.
+ /// \param[in] num_null Number of values that are null.
+ virtual void UpdateSpaced(const T* values, const uint8_t* valid_bits,
+ int64_t valid_bits_offset, int64_t num_spaced_values,
+ int64_t num_not_null, int64_t num_null) = 0;
+
+ /// \brief EXPERIMENTAL: Update statistics with an Arrow array without
+ /// conversion to a primitive Parquet C type. Only implemented for certain
+ /// Parquet type / Arrow type combinations like BYTE_ARRAY /
+ /// arrow::BinaryArray
+ ///
+ /// If update_counts is true then the null_count and num_values will be updated
+ /// based on the null_count of values. Set to false if these are updated
+ /// elsewhere (e.g. when updating a dictionary where the counts are taken from
+ /// the indices and not the values)
+ virtual void Update(const ::arrow::Array& values, bool update_counts = true) = 0;
+
+ /// \brief Set min and max values to particular values
+ virtual void SetMinMax(const T& min, const T& max) = 0;
+
+ /// \brief Increments the null count directly
+ /// Use Update to extract the null count from data. Use this if you determine
+ /// the null count through some other means (e.g. dictionary arrays where the
+ /// null count is determined from the indices)
+ virtual void IncrementNullCount(int64_t n) = 0;
+
+ /// \brief Increments the number ov values directly
+ /// The same note on IncrementNullCount applies here
+ virtual void IncrementNumValues(int64_t n) = 0;
+};
+
+using BoolStatistics = TypedStatistics<BooleanType>;
+using Int32Statistics = TypedStatistics<Int32Type>;
+using Int64Statistics = TypedStatistics<Int64Type>;
+using FloatStatistics = TypedStatistics<FloatType>;
+using DoubleStatistics = TypedStatistics<DoubleType>;
+using ByteArrayStatistics = TypedStatistics<ByteArrayType>;
+using FLBAStatistics = TypedStatistics<FLBAType>;
+
+/// \brief Typed version of Statistics::Make
+template <typename DType>
+std::shared_ptr<TypedStatistics<DType>> MakeStatistics(
+ const ColumnDescriptor* descr,
+ ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) {
+ return std::static_pointer_cast<TypedStatistics<DType>>(Statistics::Make(descr, pool));
+}
+
+/// \brief Create Statistics initialized to a particular state
+/// \param[in] min the minimum value
+/// \param[in] max the minimum value
+/// \param[in] num_values number of values
+/// \param[in] null_count number of null values
+/// \param[in] distinct_count number of distinct values
+template <typename DType>
+std::shared_ptr<TypedStatistics<DType>> MakeStatistics(const typename DType::c_type& min,
+ const typename DType::c_type& max,
+ int64_t num_values,
+ int64_t null_count,
+ int64_t distinct_count) {
+ return std::static_pointer_cast<TypedStatistics<DType>>(Statistics::Make(
+ DType::type_num, &min, &max, num_values, null_count, distinct_count));
+}
+
+/// \brief Typed version of Statistics::Make
+template <typename DType>
+std::shared_ptr<TypedStatistics<DType>> MakeStatistics(
+ const ColumnDescriptor* descr, const std::string& encoded_min,
+ const std::string& encoded_max, int64_t num_values, int64_t null_count,
+ int64_t distinct_count, bool has_min_max, bool has_null_count,
+ bool has_distinct_count, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) {
+ return std::static_pointer_cast<TypedStatistics<DType>>(Statistics::Make(
+ descr, encoded_min, encoded_max, num_values, null_count, distinct_count,
+ has_min_max, has_null_count, has_distinct_count, pool));
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/statistics_test.cc b/src/arrow/cpp/src/parquet/statistics_test.cc
new file mode 100644
index 000000000..9552c7b91
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/statistics_test.cc
@@ -0,0 +1,1178 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <array>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/ubsan.h"
+
+#include "parquet/column_reader.h"
+#include "parquet/column_writer.h"
+#include "parquet/file_reader.h"
+#include "parquet/file_writer.h"
+#include "parquet/platform.h"
+#include "parquet/schema.h"
+#include "parquet/statistics.h"
+#include "parquet/test_util.h"
+#include "parquet/thrift_internal.h"
+#include "parquet/types.h"
+
+using arrow::default_memory_pool;
+using arrow::MemoryPool;
+using arrow::util::SafeCopy;
+
+namespace BitUtil = arrow::BitUtil;
+
+namespace parquet {
+
+using schema::GroupNode;
+using schema::NodePtr;
+using schema::PrimitiveNode;
+
+namespace test {
+
+// ----------------------------------------------------------------------
+// Test comparators
+
+static ByteArray ByteArrayFromString(const std::string& s) {
+ auto ptr = reinterpret_cast<const uint8_t*>(s.data());
+ return ByteArray(static_cast<uint32_t>(s.size()), ptr);
+}
+
+static FLBA FLBAFromString(const std::string& s) {
+ auto ptr = reinterpret_cast<const uint8_t*>(s.data());
+ return FLBA(ptr);
+}
+
+TEST(Comparison, SignedByteArray) {
+ // Signed byte array comparison is only used for Decimal comparison. When
+ // decimals are encoded as byte arrays they use twos complement big-endian
+ // encoded values. Comparisons of byte arrays of unequal types need to handle
+ // sign extension.
+ auto comparator = MakeComparator<ByteArrayType>(Type::BYTE_ARRAY, SortOrder::SIGNED);
+ struct Case {
+ std::vector<uint8_t> bytes;
+ int order;
+ ByteArray ToByteArray() const {
+ return ByteArray(static_cast<int>(bytes.size()), bytes.data());
+ }
+ };
+
+ // Test a mix of big-endian comparison values that are both equal and
+ // unequal after sign extension.
+ std::vector<Case> cases = {
+ {{0x80, 0x80, 0, 0}, 0}, {{/*0xFF,*/ 0x80, 0, 0}, 1},
+ {{0xFF, 0x80, 0, 0}, 1}, {{/*0xFF,*/ 0xFF, 0x01, 0}, 2},
+ {{/*0xFF, 0xFF,*/ 0x80, 0}, 3}, {{/*0xFF,*/ 0xFF, 0x80, 0}, 3},
+ {{0xFF, 0xFF, 0x80, 0}, 3}, {{/*0xFF,0xFF,0xFF,*/ 0x80}, 4},
+ {{/*0xFF, 0xFF, 0xFF,*/ 0xFF}, 5}, {{/*0, 0,*/ 0x01, 0x01}, 6},
+ {{/*0,*/ 0, 0x01, 0x01}, 6}, {{0, 0, 0x01, 0x01}, 6},
+ {{/*0,*/ 0x01, 0x01, 0}, 7}, {{0x01, 0x01, 0, 0}, 8}};
+
+ for (size_t x = 0; x < cases.size(); x++) {
+ const auto& case1 = cases[x];
+ // Empty array is always the smallest values
+ EXPECT_TRUE(comparator->Compare(ByteArray(), case1.ToByteArray())) << x;
+ EXPECT_FALSE(comparator->Compare(case1.ToByteArray(), ByteArray())) << x;
+ // Equals is always false.
+ EXPECT_FALSE(comparator->Compare(case1.ToByteArray(), case1.ToByteArray())) << x;
+
+ for (size_t y = 0; y < cases.size(); y++) {
+ const auto& case2 = cases[y];
+ if (case1.order < case2.order) {
+ EXPECT_TRUE(comparator->Compare(case1.ToByteArray(), case2.ToByteArray()))
+ << x << " (order: " << case1.order << ") " << y << " (order: " << case2.order
+ << ")";
+ } else {
+ EXPECT_FALSE(comparator->Compare(case1.ToByteArray(), case2.ToByteArray()))
+ << x << " (order: " << case1.order << ") " << y << " (order: " << case2.order
+ << ")";
+ }
+ }
+ }
+}
+
+TEST(Comparison, UnsignedByteArray) {
+ // Check if UTF-8 is compared using unsigned correctly
+ auto comparator = MakeComparator<ByteArrayType>(Type::BYTE_ARRAY, SortOrder::UNSIGNED);
+
+ std::string s1 = "arrange";
+ std::string s2 = "arrangement";
+ ByteArray s1ba = ByteArrayFromString(s1);
+ ByteArray s2ba = ByteArrayFromString(s2);
+ ASSERT_TRUE(comparator->Compare(s1ba, s2ba));
+
+ // Multi-byte UTF-8 characters
+ s1 = u8"braten";
+ s2 = u8"bügeln";
+ s1ba = ByteArrayFromString(s1);
+ s2ba = ByteArrayFromString(s2);
+ ASSERT_TRUE(comparator->Compare(s1ba, s2ba));
+
+ s1 = u8"ünk123456"; // ü = 252
+ s2 = u8"ănk123456"; // ă = 259
+ s1ba = ByteArrayFromString(s1);
+ s2ba = ByteArrayFromString(s2);
+ ASSERT_TRUE(comparator->Compare(s1ba, s2ba));
+}
+
+TEST(Comparison, SignedFLBA) {
+ int size = 4;
+ auto comparator =
+ MakeComparator<FLBAType>(Type::FIXED_LEN_BYTE_ARRAY, SortOrder::SIGNED, size);
+
+ std::vector<uint8_t> byte_values[] = {
+ {0x80, 0, 0, 0}, {0xFF, 0xFF, 0x01, 0}, {0xFF, 0xFF, 0x80, 0},
+ {0xFF, 0xFF, 0xFF, 0x80}, {0xFF, 0xFF, 0xFF, 0xFF}, {0, 0, 0x01, 0x01},
+ {0, 0x01, 0x01, 0}, {0x01, 0x01, 0, 0}};
+ std::vector<FLBA> values_to_compare;
+ for (auto& bytes : byte_values) {
+ values_to_compare.emplace_back(FLBA(bytes.data()));
+ }
+
+ for (size_t x = 0; x < values_to_compare.size(); x++) {
+ EXPECT_FALSE(comparator->Compare(values_to_compare[x], values_to_compare[x])) << x;
+ for (size_t y = x + 1; y < values_to_compare.size(); y++) {
+ EXPECT_TRUE(comparator->Compare(values_to_compare[x], values_to_compare[y]))
+ << x << " " << y;
+ EXPECT_FALSE(comparator->Compare(values_to_compare[y], values_to_compare[x]))
+ << y << " " << x;
+ }
+ }
+}
+
+TEST(Comparison, UnsignedFLBA) {
+ int size = 10;
+ auto comparator =
+ MakeComparator<FLBAType>(Type::FIXED_LEN_BYTE_ARRAY, SortOrder::UNSIGNED, size);
+
+ std::string s1 = "Anti123456";
+ std::string s2 = "Bunkd123456";
+ FLBA s1flba = FLBAFromString(s1);
+ FLBA s2flba = FLBAFromString(s2);
+ ASSERT_TRUE(comparator->Compare(s1flba, s2flba));
+
+ s1 = "Bunk123456";
+ s2 = "Bünk123456";
+ s1flba = FLBAFromString(s1);
+ s2flba = FLBAFromString(s2);
+ ASSERT_TRUE(comparator->Compare(s1flba, s2flba));
+}
+
+TEST(Comparison, SignedInt96) {
+ parquet::Int96 a{{1, 41, 14}}, b{{1, 41, 42}};
+ parquet::Int96 aa{{1, 41, 14}}, bb{{1, 41, 14}};
+ parquet::Int96 aaa{{1, 41, static_cast<uint32_t>(-14)}}, bbb{{1, 41, 42}};
+
+ auto comparator = MakeComparator<Int96Type>(Type::INT96, SortOrder::SIGNED);
+
+ ASSERT_TRUE(comparator->Compare(a, b));
+ ASSERT_TRUE(!comparator->Compare(aa, bb) && !comparator->Compare(bb, aa));
+ ASSERT_TRUE(comparator->Compare(aaa, bbb));
+}
+
+TEST(Comparison, UnsignedInt96) {
+ parquet::Int96 a{{1, 41, 14}}, b{{1, static_cast<uint32_t>(-41), 42}};
+ parquet::Int96 aa{{1, 41, 14}}, bb{{1, 41, static_cast<uint32_t>(-14)}};
+ parquet::Int96 aaa, bbb;
+
+ auto comparator = MakeComparator<Int96Type>(Type::INT96, SortOrder::UNSIGNED);
+
+ ASSERT_TRUE(comparator->Compare(a, b));
+ ASSERT_TRUE(comparator->Compare(aa, bb));
+
+ // INT96 Timestamp
+ aaa.value[2] = 2451545; // 2000-01-01
+ bbb.value[2] = 2451546; // 2000-01-02
+ // 12 hours + 34 minutes + 56 seconds.
+ Int96SetNanoSeconds(aaa, 45296000000000);
+ // 12 hours + 34 minutes + 50 seconds.
+ Int96SetNanoSeconds(bbb, 45290000000000);
+ ASSERT_TRUE(comparator->Compare(aaa, bbb));
+
+ aaa.value[2] = 2451545; // 2000-01-01
+ bbb.value[2] = 2451545; // 2000-01-01
+ // 11 hours + 34 minutes + 56 seconds.
+ Int96SetNanoSeconds(aaa, 41696000000000);
+ // 12 hours + 34 minutes + 50 seconds.
+ Int96SetNanoSeconds(bbb, 45290000000000);
+ ASSERT_TRUE(comparator->Compare(aaa, bbb));
+
+ aaa.value[2] = 2451545; // 2000-01-01
+ bbb.value[2] = 2451545; // 2000-01-01
+ // 12 hours + 34 minutes + 55 seconds.
+ Int96SetNanoSeconds(aaa, 45295000000000);
+ // 12 hours + 34 minutes + 56 seconds.
+ Int96SetNanoSeconds(bbb, 45296000000000);
+ ASSERT_TRUE(comparator->Compare(aaa, bbb));
+}
+
+TEST(Comparison, SignedInt64) {
+ int64_t a = 1, b = 4;
+ int64_t aa = 1, bb = 1;
+ int64_t aaa = -1, bbb = 1;
+
+ NodePtr node = PrimitiveNode::Make("SignedInt64", Repetition::REQUIRED, Type::INT64);
+ ColumnDescriptor descr(node, 0, 0);
+
+ auto comparator = MakeComparator<Int64Type>(&descr);
+
+ ASSERT_TRUE(comparator->Compare(a, b));
+ ASSERT_TRUE(!comparator->Compare(aa, bb) && !comparator->Compare(bb, aa));
+ ASSERT_TRUE(comparator->Compare(aaa, bbb));
+}
+
+TEST(Comparison, UnsignedInt64) {
+ uint64_t a = 1, b = 4;
+ uint64_t aa = 1, bb = 1;
+ uint64_t aaa = 1, bbb = -1;
+
+ NodePtr node = PrimitiveNode::Make("UnsignedInt64", Repetition::REQUIRED, Type::INT64,
+ ConvertedType::UINT_64);
+ ColumnDescriptor descr(node, 0, 0);
+
+ ASSERT_EQ(SortOrder::UNSIGNED, descr.sort_order());
+ auto comparator = MakeComparator<Int64Type>(&descr);
+
+ ASSERT_TRUE(comparator->Compare(a, b));
+ ASSERT_TRUE(!comparator->Compare(aa, bb) && !comparator->Compare(bb, aa));
+ ASSERT_TRUE(comparator->Compare(aaa, bbb));
+}
+
+TEST(Comparison, UnsignedInt32) {
+ uint32_t a = 1, b = 4;
+ uint32_t aa = 1, bb = 1;
+ uint32_t aaa = 1, bbb = -1;
+
+ NodePtr node = PrimitiveNode::Make("UnsignedInt32", Repetition::REQUIRED, Type::INT32,
+ ConvertedType::UINT_32);
+ ColumnDescriptor descr(node, 0, 0);
+
+ ASSERT_EQ(SortOrder::UNSIGNED, descr.sort_order());
+ auto comparator = MakeComparator<Int32Type>(&descr);
+
+ ASSERT_TRUE(comparator->Compare(a, b));
+ ASSERT_TRUE(!comparator->Compare(aa, bb) && !comparator->Compare(bb, aa));
+ ASSERT_TRUE(comparator->Compare(aaa, bbb));
+}
+
+TEST(Comparison, UnknownSortOrder) {
+ NodePtr node =
+ PrimitiveNode::Make("Unknown", Repetition::REQUIRED, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::INTERVAL, 12);
+ ColumnDescriptor descr(node, 0, 0);
+
+ ASSERT_THROW(Comparator::Make(&descr), ParquetException);
+}
+
+// ----------------------------------------------------------------------
+
+template <typename TestType>
+class TestStatistics : public PrimitiveTypedTest<TestType> {
+ public:
+ using c_type = typename TestType::c_type;
+
+ std::vector<c_type> GetDeepCopy(
+ const std::vector<c_type>&); // allocates new memory for FLBA/ByteArray
+
+ c_type* GetValuesPointer(std::vector<c_type>&);
+ void DeepFree(std::vector<c_type>&);
+
+ void TestMinMaxEncode() {
+ this->GenerateData(1000);
+
+ auto statistics1 = MakeStatistics<TestType>(this->schema_.Column(0));
+ statistics1->Update(this->values_ptr_, this->values_.size(), 0);
+ std::string encoded_min = statistics1->EncodeMin();
+ std::string encoded_max = statistics1->EncodeMax();
+
+ auto statistics2 =
+ MakeStatistics<TestType>(this->schema_.Column(0), encoded_min, encoded_max,
+ this->values_.size(), 0, 0, true, true, true);
+
+ auto statistics3 = MakeStatistics<TestType>(this->schema_.Column(0));
+ std::vector<uint8_t> valid_bits(
+ BitUtil::BytesForBits(static_cast<uint32_t>(this->values_.size())) + 1, 255);
+ statistics3->UpdateSpaced(this->values_ptr_, valid_bits.data(), 0,
+ this->values_.size(), this->values_.size(), 0);
+ std::string encoded_min_spaced = statistics3->EncodeMin();
+ std::string encoded_max_spaced = statistics3->EncodeMax();
+
+ ASSERT_EQ(encoded_min, statistics2->EncodeMin());
+ ASSERT_EQ(encoded_max, statistics2->EncodeMax());
+ ASSERT_EQ(statistics1->min(), statistics2->min());
+ ASSERT_EQ(statistics1->max(), statistics2->max());
+ ASSERT_EQ(encoded_min_spaced, statistics2->EncodeMin());
+ ASSERT_EQ(encoded_max_spaced, statistics2->EncodeMax());
+ ASSERT_EQ(statistics3->min(), statistics2->min());
+ ASSERT_EQ(statistics3->max(), statistics2->max());
+ }
+
+ void TestReset() {
+ this->GenerateData(1000);
+
+ auto statistics = MakeStatistics<TestType>(this->schema_.Column(0));
+ statistics->Update(this->values_ptr_, this->values_.size(), 0);
+ ASSERT_EQ(this->values_.size(), statistics->num_values());
+
+ statistics->Reset();
+ ASSERT_EQ(0, statistics->null_count());
+ ASSERT_EQ(0, statistics->num_values());
+ ASSERT_EQ(0, statistics->distinct_count());
+ ASSERT_EQ("", statistics->EncodeMin());
+ ASSERT_EQ("", statistics->EncodeMax());
+ }
+
+ void TestMerge() {
+ int num_null[2];
+ random_numbers(2, 42, 0, 100, num_null);
+
+ auto statistics1 = MakeStatistics<TestType>(this->schema_.Column(0));
+ this->GenerateData(1000);
+ statistics1->Update(this->values_ptr_, this->values_.size() - num_null[0],
+ num_null[0]);
+
+ auto statistics2 = MakeStatistics<TestType>(this->schema_.Column(0));
+ this->GenerateData(1000);
+ statistics2->Update(this->values_ptr_, this->values_.size() - num_null[1],
+ num_null[1]);
+
+ auto total = MakeStatistics<TestType>(this->schema_.Column(0));
+ total->Merge(*statistics1);
+ total->Merge(*statistics2);
+
+ ASSERT_EQ(num_null[0] + num_null[1], total->null_count());
+ ASSERT_EQ(this->values_.size() * 2 - num_null[0] - num_null[1], total->num_values());
+ ASSERT_EQ(total->min(), std::min(statistics1->min(), statistics2->min()));
+ ASSERT_EQ(total->max(), std::max(statistics1->max(), statistics2->max()));
+ }
+
+ void TestFullRoundtrip(int64_t num_values, int64_t null_count) {
+ this->GenerateData(num_values);
+
+ // compute statistics for the whole batch
+ auto expected_stats = MakeStatistics<TestType>(this->schema_.Column(0));
+ expected_stats->Update(this->values_ptr_, num_values - null_count, null_count);
+
+ auto sink = CreateOutputStream();
+ auto gnode = std::static_pointer_cast<GroupNode>(this->node_);
+ std::shared_ptr<WriterProperties> writer_properties =
+ WriterProperties::Builder().enable_statistics("column")->build();
+ auto file_writer = ParquetFileWriter::Open(sink, gnode, writer_properties);
+ auto row_group_writer = file_writer->AppendRowGroup();
+ auto column_writer =
+ static_cast<TypedColumnWriter<TestType>*>(row_group_writer->NextColumn());
+
+ // simulate the case when data comes from multiple buffers,
+ // in which case special care is necessary for FLBA/ByteArray types
+ for (int i = 0; i < 2; i++) {
+ int64_t batch_num_values = i ? num_values - num_values / 2 : num_values / 2;
+ int64_t batch_null_count = i ? null_count : 0;
+ DCHECK(null_count <= num_values); // avoid too much headache
+ std::vector<int16_t> definition_levels(batch_null_count, 0);
+ definition_levels.insert(definition_levels.end(),
+ batch_num_values - batch_null_count, 1);
+ auto beg = this->values_.begin() + i * num_values / 2;
+ auto end = beg + batch_num_values;
+ std::vector<c_type> batch = GetDeepCopy(std::vector<c_type>(beg, end));
+ c_type* batch_values_ptr = GetValuesPointer(batch);
+ column_writer->WriteBatch(batch_num_values, definition_levels.data(), nullptr,
+ batch_values_ptr);
+ DeepFree(batch);
+ }
+ column_writer->Close();
+ row_group_writer->Close();
+ file_writer->Close();
+
+ ASSERT_OK_AND_ASSIGN(auto buffer, sink->Finish());
+ auto source = std::make_shared<::arrow::io::BufferReader>(buffer);
+ auto file_reader = ParquetFileReader::Open(source);
+ auto rg_reader = file_reader->RowGroup(0);
+ auto column_chunk = rg_reader->metadata()->ColumnChunk(0);
+ if (!column_chunk->is_stats_set()) return;
+ std::shared_ptr<Statistics> stats = column_chunk->statistics();
+ // check values after serialization + deserialization
+ EXPECT_EQ(null_count, stats->null_count());
+ EXPECT_EQ(num_values - null_count, stats->num_values());
+ EXPECT_TRUE(expected_stats->HasMinMax());
+ EXPECT_EQ(expected_stats->EncodeMin(), stats->EncodeMin());
+ EXPECT_EQ(expected_stats->EncodeMax(), stats->EncodeMax());
+ }
+};
+
+template <typename TestType>
+typename TestType::c_type* TestStatistics<TestType>::GetValuesPointer(
+ std::vector<typename TestType::c_type>& values) {
+ return values.data();
+}
+
+template <>
+bool* TestStatistics<BooleanType>::GetValuesPointer(std::vector<bool>& values) {
+ static std::vector<uint8_t> bool_buffer;
+ bool_buffer.clear();
+ bool_buffer.resize(values.size());
+ std::copy(values.begin(), values.end(), bool_buffer.begin());
+ return reinterpret_cast<bool*>(bool_buffer.data());
+}
+
+template <typename TestType>
+typename std::vector<typename TestType::c_type> TestStatistics<TestType>::GetDeepCopy(
+ const std::vector<typename TestType::c_type>& values) {
+ return values;
+}
+
+template <>
+std::vector<FLBA> TestStatistics<FLBAType>::GetDeepCopy(const std::vector<FLBA>& values) {
+ std::vector<FLBA> copy;
+ MemoryPool* pool = ::arrow::default_memory_pool();
+ for (const FLBA& flba : values) {
+ uint8_t* ptr;
+ PARQUET_THROW_NOT_OK(pool->Allocate(FLBA_LENGTH, &ptr));
+ memcpy(ptr, flba.ptr, FLBA_LENGTH);
+ copy.emplace_back(ptr);
+ }
+ return copy;
+}
+
+template <>
+std::vector<ByteArray> TestStatistics<ByteArrayType>::GetDeepCopy(
+ const std::vector<ByteArray>& values) {
+ std::vector<ByteArray> copy;
+ MemoryPool* pool = default_memory_pool();
+ for (const ByteArray& ba : values) {
+ uint8_t* ptr;
+ PARQUET_THROW_NOT_OK(pool->Allocate(ba.len, &ptr));
+ memcpy(ptr, ba.ptr, ba.len);
+ copy.emplace_back(ba.len, ptr);
+ }
+ return copy;
+}
+
+template <typename TestType>
+void TestStatistics<TestType>::DeepFree(std::vector<typename TestType::c_type>& values) {}
+
+template <>
+void TestStatistics<FLBAType>::DeepFree(std::vector<FLBA>& values) {
+ MemoryPool* pool = default_memory_pool();
+ for (FLBA& flba : values) {
+ auto ptr = const_cast<uint8_t*>(flba.ptr);
+ memset(ptr, 0, FLBA_LENGTH);
+ pool->Free(ptr, FLBA_LENGTH);
+ }
+}
+
+template <>
+void TestStatistics<ByteArrayType>::DeepFree(std::vector<ByteArray>& values) {
+ MemoryPool* pool = default_memory_pool();
+ for (ByteArray& ba : values) {
+ auto ptr = const_cast<uint8_t*>(ba.ptr);
+ memset(ptr, 0, ba.len);
+ pool->Free(ptr, ba.len);
+ }
+}
+
+template <>
+void TestStatistics<ByteArrayType>::TestMinMaxEncode() {
+ this->GenerateData(1000);
+ // Test that we encode min max strings correctly
+ auto statistics1 = MakeStatistics<ByteArrayType>(this->schema_.Column(0));
+ statistics1->Update(this->values_ptr_, this->values_.size(), 0);
+ std::string encoded_min = statistics1->EncodeMin();
+ std::string encoded_max = statistics1->EncodeMax();
+
+ // encoded is same as unencoded
+ ASSERT_EQ(encoded_min,
+ std::string(reinterpret_cast<const char*>(statistics1->min().ptr),
+ statistics1->min().len));
+ ASSERT_EQ(encoded_max,
+ std::string(reinterpret_cast<const char*>(statistics1->max().ptr),
+ statistics1->max().len));
+
+ auto statistics2 =
+ MakeStatistics<ByteArrayType>(this->schema_.Column(0), encoded_min, encoded_max,
+ this->values_.size(), 0, 0, true, true, true);
+
+ ASSERT_EQ(encoded_min, statistics2->EncodeMin());
+ ASSERT_EQ(encoded_max, statistics2->EncodeMax());
+ ASSERT_EQ(statistics1->min(), statistics2->min());
+ ASSERT_EQ(statistics1->max(), statistics2->max());
+}
+
+using Types = ::testing::Types<Int32Type, Int64Type, FloatType, DoubleType, ByteArrayType,
+ FLBAType, BooleanType>;
+
+TYPED_TEST_SUITE(TestStatistics, Types);
+
+TYPED_TEST(TestStatistics, MinMaxEncode) {
+ this->SetUpSchema(Repetition::REQUIRED);
+ ASSERT_NO_FATAL_FAILURE(this->TestMinMaxEncode());
+}
+
+TYPED_TEST(TestStatistics, Reset) {
+ this->SetUpSchema(Repetition::OPTIONAL);
+ ASSERT_NO_FATAL_FAILURE(this->TestReset());
+}
+
+TYPED_TEST(TestStatistics, FullRoundtrip) {
+ this->SetUpSchema(Repetition::OPTIONAL);
+ ASSERT_NO_FATAL_FAILURE(this->TestFullRoundtrip(100, 31));
+ ASSERT_NO_FATAL_FAILURE(this->TestFullRoundtrip(1000, 415));
+ ASSERT_NO_FATAL_FAILURE(this->TestFullRoundtrip(10000, 926));
+}
+
+template <typename TestType>
+class TestNumericStatistics : public TestStatistics<TestType> {};
+
+using NumericTypes = ::testing::Types<Int32Type, Int64Type, FloatType, DoubleType>;
+
+TYPED_TEST_SUITE(TestNumericStatistics, NumericTypes);
+
+TYPED_TEST(TestNumericStatistics, Merge) {
+ this->SetUpSchema(Repetition::OPTIONAL);
+ ASSERT_NO_FATAL_FAILURE(this->TestMerge());
+}
+
+// Helper for basic statistics tests below
+void AssertStatsSet(const ApplicationVersion& version,
+ std::shared_ptr<parquet::WriterProperties> props,
+ const ColumnDescriptor* column, bool expected_is_set) {
+ auto metadata_builder = ColumnChunkMetaDataBuilder::Make(props, column);
+ auto column_chunk =
+ ColumnChunkMetaData::Make(metadata_builder->contents(), column, &version);
+ EncodedStatistics stats;
+ stats.set_is_signed(false);
+ metadata_builder->SetStatistics(stats);
+ ASSERT_EQ(column_chunk->is_stats_set(), expected_is_set);
+}
+
+// Statistics are restricted for few types in older parquet version
+TEST(CorruptStatistics, Basics) {
+ std::string created_by = "parquet-mr version 1.8.0";
+ ApplicationVersion version(created_by);
+ SchemaDescriptor schema;
+ schema::NodePtr node;
+ std::vector<schema::NodePtr> fields;
+ // Test Physical Types
+ fields.push_back(schema::PrimitiveNode::Make("col1", Repetition::OPTIONAL, Type::INT32,
+ ConvertedType::NONE));
+ fields.push_back(schema::PrimitiveNode::Make("col2", Repetition::OPTIONAL,
+ Type::BYTE_ARRAY, ConvertedType::NONE));
+ // Test Logical Types
+ fields.push_back(schema::PrimitiveNode::Make("col3", Repetition::OPTIONAL, Type::INT32,
+ ConvertedType::DATE));
+ fields.push_back(schema::PrimitiveNode::Make("col4", Repetition::OPTIONAL, Type::INT32,
+ ConvertedType::UINT_32));
+ fields.push_back(schema::PrimitiveNode::Make("col5", Repetition::OPTIONAL,
+ Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::INTERVAL, 12));
+ fields.push_back(schema::PrimitiveNode::Make("col6", Repetition::OPTIONAL,
+ Type::BYTE_ARRAY, ConvertedType::UTF8));
+ node = schema::GroupNode::Make("schema", Repetition::REQUIRED, fields);
+ schema.Init(node);
+
+ parquet::WriterProperties::Builder builder;
+ builder.created_by(created_by);
+ std::shared_ptr<parquet::WriterProperties> props = builder.build();
+
+ AssertStatsSet(version, props, schema.Column(0), true);
+ AssertStatsSet(version, props, schema.Column(1), false);
+ AssertStatsSet(version, props, schema.Column(2), true);
+ AssertStatsSet(version, props, schema.Column(3), false);
+ AssertStatsSet(version, props, schema.Column(4), false);
+ AssertStatsSet(version, props, schema.Column(5), false);
+}
+
+// Statistics for all types have no restrictions in newer parquet version
+TEST(CorrectStatistics, Basics) {
+ std::string created_by = "parquet-cpp version 1.3.0";
+ ApplicationVersion version(created_by);
+ SchemaDescriptor schema;
+ schema::NodePtr node;
+ std::vector<schema::NodePtr> fields;
+ // Test Physical Types
+ fields.push_back(schema::PrimitiveNode::Make("col1", Repetition::OPTIONAL, Type::INT32,
+ ConvertedType::NONE));
+ fields.push_back(schema::PrimitiveNode::Make("col2", Repetition::OPTIONAL,
+ Type::BYTE_ARRAY, ConvertedType::NONE));
+ // Test Logical Types
+ fields.push_back(schema::PrimitiveNode::Make("col3", Repetition::OPTIONAL, Type::INT32,
+ ConvertedType::DATE));
+ fields.push_back(schema::PrimitiveNode::Make("col4", Repetition::OPTIONAL, Type::INT32,
+ ConvertedType::UINT_32));
+ fields.push_back(schema::PrimitiveNode::Make("col5", Repetition::OPTIONAL,
+ Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::INTERVAL, 12));
+ fields.push_back(schema::PrimitiveNode::Make("col6", Repetition::OPTIONAL,
+ Type::BYTE_ARRAY, ConvertedType::UTF8));
+ node = schema::GroupNode::Make("schema", Repetition::REQUIRED, fields);
+ schema.Init(node);
+
+ parquet::WriterProperties::Builder builder;
+ builder.created_by(created_by);
+ std::shared_ptr<parquet::WriterProperties> props = builder.build();
+
+ AssertStatsSet(version, props, schema.Column(0), true);
+ AssertStatsSet(version, props, schema.Column(1), true);
+ AssertStatsSet(version, props, schema.Column(2), true);
+ AssertStatsSet(version, props, schema.Column(3), true);
+ AssertStatsSet(version, props, schema.Column(4), false);
+ AssertStatsSet(version, props, schema.Column(5), true);
+}
+
+// Test SortOrder class
+static const int NUM_VALUES = 10;
+
+template <typename TestType>
+class TestStatisticsSortOrder : public ::testing::Test {
+ public:
+ using c_type = typename TestType::c_type;
+
+ void AddNodes(std::string name) {
+ fields_.push_back(schema::PrimitiveNode::Make(
+ name, Repetition::REQUIRED, TestType::type_num, ConvertedType::NONE));
+ }
+
+ void SetUpSchema() {
+ stats_.resize(fields_.size());
+ values_.resize(NUM_VALUES);
+ schema_ = std::static_pointer_cast<GroupNode>(
+ GroupNode::Make("Schema", Repetition::REQUIRED, fields_));
+
+ parquet_sink_ = CreateOutputStream();
+ }
+
+ void SetValues();
+
+ void WriteParquet() {
+ // Add writer properties
+ parquet::WriterProperties::Builder builder;
+ builder.compression(parquet::Compression::SNAPPY);
+ builder.created_by("parquet-cpp version 1.3.0");
+ std::shared_ptr<parquet::WriterProperties> props = builder.build();
+
+ // Create a ParquetFileWriter instance
+ auto file_writer = parquet::ParquetFileWriter::Open(parquet_sink_, schema_, props);
+
+ // Append a RowGroup with a specific number of rows.
+ auto rg_writer = file_writer->AppendRowGroup();
+
+ this->SetValues();
+
+ // Insert Values
+ for (int i = 0; i < static_cast<int>(fields_.size()); i++) {
+ auto column_writer =
+ static_cast<parquet::TypedColumnWriter<TestType>*>(rg_writer->NextColumn());
+ column_writer->WriteBatch(NUM_VALUES, nullptr, nullptr, values_.data());
+ }
+ }
+
+ void VerifyParquetStats() {
+ ASSERT_OK_AND_ASSIGN(auto pbuffer, parquet_sink_->Finish());
+
+ // Create a ParquetReader instance
+ std::unique_ptr<parquet::ParquetFileReader> parquet_reader =
+ parquet::ParquetFileReader::Open(
+ std::make_shared<::arrow::io::BufferReader>(pbuffer));
+
+ // Get the File MetaData
+ std::shared_ptr<parquet::FileMetaData> file_metadata = parquet_reader->metadata();
+ std::shared_ptr<parquet::RowGroupMetaData> rg_metadata = file_metadata->RowGroup(0);
+ for (int i = 0; i < static_cast<int>(fields_.size()); i++) {
+ ARROW_SCOPED_TRACE("Statistics for field #", i);
+ std::shared_ptr<parquet::ColumnChunkMetaData> cc_metadata =
+ rg_metadata->ColumnChunk(i);
+ EXPECT_EQ(stats_[i].min(), cc_metadata->statistics()->EncodeMin());
+ EXPECT_EQ(stats_[i].max(), cc_metadata->statistics()->EncodeMax());
+ }
+ }
+
+ protected:
+ std::vector<c_type> values_;
+ std::vector<uint8_t> values_buf_;
+ std::vector<schema::NodePtr> fields_;
+ std::shared_ptr<schema::GroupNode> schema_;
+ std::shared_ptr<::arrow::io::BufferOutputStream> parquet_sink_;
+ std::vector<EncodedStatistics> stats_;
+};
+
+using CompareTestTypes = ::testing::Types<Int32Type, Int64Type, FloatType, DoubleType,
+ ByteArrayType, FLBAType>;
+
+// TYPE::INT32
+template <>
+void TestStatisticsSortOrder<Int32Type>::AddNodes(std::string name) {
+ // UINT_32 logical type to set Unsigned Statistics
+ fields_.push_back(schema::PrimitiveNode::Make(name, Repetition::REQUIRED, Type::INT32,
+ ConvertedType::UINT_32));
+ // INT_32 logical type to set Signed Statistics
+ fields_.push_back(schema::PrimitiveNode::Make(name, Repetition::REQUIRED, Type::INT32,
+ ConvertedType::INT_32));
+}
+
+template <>
+void TestStatisticsSortOrder<Int32Type>::SetValues() {
+ for (int i = 0; i < NUM_VALUES; i++) {
+ values_[i] = i - 5; // {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4};
+ }
+
+ // Write UINT32 min/max values
+ stats_[0]
+ .set_min(std::string(reinterpret_cast<const char*>(&values_[5]), sizeof(c_type)))
+ .set_max(std::string(reinterpret_cast<const char*>(&values_[4]), sizeof(c_type)));
+
+ // Write INT32 min/max values
+ stats_[1]
+ .set_min(std::string(reinterpret_cast<const char*>(&values_[0]), sizeof(c_type)))
+ .set_max(std::string(reinterpret_cast<const char*>(&values_[9]), sizeof(c_type)));
+}
+
+// TYPE::INT64
+template <>
+void TestStatisticsSortOrder<Int64Type>::AddNodes(std::string name) {
+ // UINT_64 logical type to set Unsigned Statistics
+ fields_.push_back(schema::PrimitiveNode::Make(name, Repetition::REQUIRED, Type::INT64,
+ ConvertedType::UINT_64));
+ // INT_64 logical type to set Signed Statistics
+ fields_.push_back(schema::PrimitiveNode::Make(name, Repetition::REQUIRED, Type::INT64,
+ ConvertedType::INT_64));
+}
+
+template <>
+void TestStatisticsSortOrder<Int64Type>::SetValues() {
+ for (int i = 0; i < NUM_VALUES; i++) {
+ values_[i] = i - 5; // {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4};
+ }
+
+ // Write UINT64 min/max values
+ stats_[0]
+ .set_min(std::string(reinterpret_cast<const char*>(&values_[5]), sizeof(c_type)))
+ .set_max(std::string(reinterpret_cast<const char*>(&values_[4]), sizeof(c_type)));
+
+ // Write INT64 min/max values
+ stats_[1]
+ .set_min(std::string(reinterpret_cast<const char*>(&values_[0]), sizeof(c_type)))
+ .set_max(std::string(reinterpret_cast<const char*>(&values_[9]), sizeof(c_type)));
+}
+
+// TYPE::FLOAT
+template <>
+void TestStatisticsSortOrder<FloatType>::SetValues() {
+ for (int i = 0; i < NUM_VALUES; i++) {
+ values_[i] = static_cast<float>(i) -
+ 5; // {-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0};
+ }
+
+ // Write Float min/max values
+ stats_[0]
+ .set_min(std::string(reinterpret_cast<const char*>(&values_[0]), sizeof(c_type)))
+ .set_max(std::string(reinterpret_cast<const char*>(&values_[9]), sizeof(c_type)));
+}
+
+// TYPE::DOUBLE
+template <>
+void TestStatisticsSortOrder<DoubleType>::SetValues() {
+ for (int i = 0; i < NUM_VALUES; i++) {
+ values_[i] = static_cast<float>(i) -
+ 5; // {-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0};
+ }
+
+ // Write Double min/max values
+ stats_[0]
+ .set_min(std::string(reinterpret_cast<const char*>(&values_[0]), sizeof(c_type)))
+ .set_max(std::string(reinterpret_cast<const char*>(&values_[9]), sizeof(c_type)));
+}
+
+// TYPE::ByteArray
+template <>
+void TestStatisticsSortOrder<ByteArrayType>::AddNodes(std::string name) {
+ // UTF8 logical type to set Unsigned Statistics
+ fields_.push_back(schema::PrimitiveNode::Make(name, Repetition::REQUIRED,
+ Type::BYTE_ARRAY, ConvertedType::UTF8));
+}
+
+template <>
+void TestStatisticsSortOrder<ByteArrayType>::SetValues() {
+ int max_byte_array_len = 10;
+ size_t nbytes = NUM_VALUES * max_byte_array_len;
+ values_buf_.resize(nbytes);
+ std::vector<std::string> vals = {u8"c123", u8"b123", u8"a123", u8"d123", u8"e123",
+ u8"f123", u8"g123", u8"h123", u8"i123", u8"ü123"};
+
+ uint8_t* base = &values_buf_.data()[0];
+ for (int i = 0; i < NUM_VALUES; i++) {
+ memcpy(base, vals[i].c_str(), vals[i].length());
+ values_[i].ptr = base;
+ values_[i].len = static_cast<uint32_t>(vals[i].length());
+ base += vals[i].length();
+ }
+
+ // Write String min/max values
+ stats_[0]
+ .set_min(
+ std::string(reinterpret_cast<const char*>(vals[2].c_str()), vals[2].length()))
+ .set_max(
+ std::string(reinterpret_cast<const char*>(vals[9].c_str()), vals[9].length()));
+}
+
+// TYPE::FLBAArray
+template <>
+void TestStatisticsSortOrder<FLBAType>::AddNodes(std::string name) {
+ // FLBA has only Unsigned Statistics
+ fields_.push_back(schema::PrimitiveNode::Make(name, Repetition::REQUIRED,
+ Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::NONE, FLBA_LENGTH));
+}
+
+template <>
+void TestStatisticsSortOrder<FLBAType>::SetValues() {
+ size_t nbytes = NUM_VALUES * FLBA_LENGTH;
+ values_buf_.resize(nbytes);
+ char vals[NUM_VALUES][FLBA_LENGTH] = {"b12345", "a12345", "c12345", "d12345", "e12345",
+ "f12345", "g12345", "h12345", "z12345", "a12345"};
+
+ uint8_t* base = &values_buf_.data()[0];
+ for (int i = 0; i < NUM_VALUES; i++) {
+ memcpy(base, &vals[i][0], FLBA_LENGTH);
+ values_[i].ptr = base;
+ base += FLBA_LENGTH;
+ }
+
+ // Write FLBA min,max values
+ stats_[0]
+ .set_min(std::string(reinterpret_cast<const char*>(&vals[1][0]), FLBA_LENGTH))
+ .set_max(std::string(reinterpret_cast<const char*>(&vals[8][0]), FLBA_LENGTH));
+}
+
+TYPED_TEST_SUITE(TestStatisticsSortOrder, CompareTestTypes);
+
+TYPED_TEST(TestStatisticsSortOrder, MinMax) {
+ this->AddNodes("Column ");
+ this->SetUpSchema();
+ this->WriteParquet();
+ ASSERT_NO_FATAL_FAILURE(this->VerifyParquetStats());
+}
+
+template <typename ArrowType>
+void TestByteArrayStatisticsFromArrow() {
+ using TypeTraits = ::arrow::TypeTraits<ArrowType>;
+ using ArrayType = typename TypeTraits::ArrayType;
+
+ auto values = ArrayFromJSON(TypeTraits::type_singleton(),
+ u8"[\"c123\", \"b123\", \"a123\", null, "
+ "null, \"f123\", \"g123\", \"h123\", \"i123\", \"ü123\"]");
+
+ const auto& typed_values = static_cast<const ArrayType&>(*values);
+
+ NodePtr node = PrimitiveNode::Make("field", Repetition::REQUIRED, Type::BYTE_ARRAY,
+ ConvertedType::UTF8);
+ ColumnDescriptor descr(node, 0, 0);
+ auto stats = MakeStatistics<ByteArrayType>(&descr);
+ ASSERT_NO_FATAL_FAILURE(stats->Update(*values));
+
+ ASSERT_EQ(ByteArray(typed_values.GetView(2)), stats->min());
+ ASSERT_EQ(ByteArray(typed_values.GetView(9)), stats->max());
+ ASSERT_EQ(2, stats->null_count());
+}
+
+TEST(TestByteArrayStatisticsFromArrow, StringType) {
+ // Part of ARROW-3246. Replicating TestStatisticsSortOrder test but via Arrow
+ TestByteArrayStatisticsFromArrow<::arrow::StringType>();
+}
+
+TEST(TestByteArrayStatisticsFromArrow, LargeStringType) {
+ TestByteArrayStatisticsFromArrow<::arrow::LargeStringType>();
+}
+
+// Ensure UNKNOWN sort order is handled properly
+using TestStatisticsSortOrderFLBA = TestStatisticsSortOrder<FLBAType>;
+
+TEST_F(TestStatisticsSortOrderFLBA, UnknownSortOrder) {
+ this->fields_.push_back(schema::PrimitiveNode::Make(
+ "Column 0", Repetition::REQUIRED, Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::INTERVAL, FLBA_LENGTH));
+ this->SetUpSchema();
+ this->WriteParquet();
+
+ ASSERT_OK_AND_ASSIGN(auto pbuffer, parquet_sink_->Finish());
+
+ // Create a ParquetReader instance
+ std::unique_ptr<parquet::ParquetFileReader> parquet_reader =
+ parquet::ParquetFileReader::Open(
+ std::make_shared<::arrow::io::BufferReader>(pbuffer));
+ // Get the File MetaData
+ std::shared_ptr<parquet::FileMetaData> file_metadata = parquet_reader->metadata();
+ std::shared_ptr<parquet::RowGroupMetaData> rg_metadata = file_metadata->RowGroup(0);
+ std::shared_ptr<parquet::ColumnChunkMetaData> cc_metadata = rg_metadata->ColumnChunk(0);
+
+ // stats should not be set for UNKNOWN sort order
+ ASSERT_FALSE(cc_metadata->is_stats_set());
+}
+
+template <typename Stats, typename Array, typename T = typename Array::value_type>
+void AssertMinMaxAre(Stats stats, const Array& values, T expected_min, T expected_max) {
+ stats->Update(values.data(), values.size(), 0);
+ ASSERT_TRUE(stats->HasMinMax());
+ EXPECT_EQ(stats->min(), expected_min);
+ EXPECT_EQ(stats->max(), expected_max);
+}
+
+template <typename Stats, typename Array, typename T = typename Stats::T>
+void AssertMinMaxAre(Stats stats, const Array& values, const uint8_t* valid_bitmap,
+ T expected_min, T expected_max) {
+ auto n_values = values.size();
+ auto null_count = ::arrow::internal::CountSetBits(valid_bitmap, n_values, 0);
+ auto non_null_count = n_values - null_count;
+ stats->UpdateSpaced(values.data(), valid_bitmap, 0, non_null_count + null_count,
+ non_null_count, null_count);
+ ASSERT_TRUE(stats->HasMinMax());
+ EXPECT_EQ(stats->min(), expected_min);
+ EXPECT_EQ(stats->max(), expected_max);
+}
+
+template <typename Stats, typename Array>
+void AssertUnsetMinMax(Stats stats, const Array& values) {
+ stats->Update(values.data(), values.size(), 0);
+ ASSERT_FALSE(stats->HasMinMax());
+}
+
+template <typename Stats, typename Array>
+void AssertUnsetMinMax(Stats stats, const Array& values, const uint8_t* valid_bitmap) {
+ auto n_values = values.size();
+ auto null_count = ::arrow::internal::CountSetBits(valid_bitmap, n_values, 0);
+ auto non_null_count = n_values - null_count;
+ stats->UpdateSpaced(values.data(), valid_bitmap, 0, non_null_count + null_count,
+ non_null_count, null_count);
+ ASSERT_FALSE(stats->HasMinMax());
+}
+
+template <typename ParquetType, typename T = typename ParquetType::c_type>
+void CheckExtrema() {
+ using UT = typename std::make_unsigned<T>::type;
+
+ const T smin = std::numeric_limits<T>::min();
+ const T smax = std::numeric_limits<T>::max();
+ const T umin = SafeCopy<T>(std::numeric_limits<UT>::min());
+ const T umax = SafeCopy<T>(std::numeric_limits<UT>::max());
+
+ constexpr int kNumValues = 8;
+ std::array<T, kNumValues> values{0, smin, smax, umin,
+ umax, smin + 1, smax - 1, umax - 1};
+
+ NodePtr unsigned_node = PrimitiveNode::Make(
+ "uint", Repetition::OPTIONAL,
+ LogicalType::Int(sizeof(T) * CHAR_BIT, false /*signed*/), ParquetType::type_num);
+ ColumnDescriptor unsigned_descr(unsigned_node, 1, 1);
+ NodePtr signed_node = PrimitiveNode::Make(
+ "int", Repetition::OPTIONAL,
+ LogicalType::Int(sizeof(T) * CHAR_BIT, true /*signed*/), ParquetType::type_num);
+ ColumnDescriptor signed_descr(signed_node, 1, 1);
+
+ {
+ ARROW_SCOPED_TRACE("unsigned statistics: umin = ", umin, ", umax = ", umax,
+ ", node type = ", unsigned_node->logical_type()->ToString(),
+ ", physical type = ", unsigned_descr.physical_type(),
+ ", sort order = ", unsigned_descr.sort_order());
+ auto unsigned_stats = MakeStatistics<ParquetType>(&unsigned_descr);
+ AssertMinMaxAre(unsigned_stats, values, umin, umax);
+ }
+ {
+ ARROW_SCOPED_TRACE("signed statistics: smin = ", smin, ", smax = ", smax,
+ ", node type = ", signed_node->logical_type()->ToString(),
+ ", physical type = ", signed_descr.physical_type(),
+ ", sort order = ", signed_descr.sort_order());
+ auto signed_stats = MakeStatistics<ParquetType>(&signed_descr);
+ AssertMinMaxAre(signed_stats, values, smin, smax);
+ }
+
+ // With validity bitmap
+ std::vector<bool> is_valid = {true, false, false, false, false, true, true, true};
+ std::shared_ptr<Buffer> valid_bitmap;
+ ::arrow::BitmapFromVector(is_valid, &valid_bitmap);
+ {
+ ARROW_SCOPED_TRACE("spaced unsigned statistics: umin = ", umin, ", umax = ", umax,
+ ", node type = ", unsigned_node->logical_type()->ToString(),
+ ", physical type = ", unsigned_descr.physical_type(),
+ ", sort order = ", unsigned_descr.sort_order());
+ auto unsigned_stats = MakeStatistics<ParquetType>(&unsigned_descr);
+ AssertMinMaxAre(unsigned_stats, values, valid_bitmap->data(), T{0}, umax - 1);
+ }
+ {
+ ARROW_SCOPED_TRACE("spaced signed statistics: smin = ", smin, ", smax = ", smax,
+ ", node type = ", signed_node->logical_type()->ToString(),
+ ", physical type = ", signed_descr.physical_type(),
+ ", sort order = ", signed_descr.sort_order());
+ auto signed_stats = MakeStatistics<ParquetType>(&signed_descr);
+ AssertMinMaxAre(signed_stats, values, valid_bitmap->data(), smin + 1, smax - 1);
+ }
+}
+
+TEST(TestStatistic, Int32Extrema) { CheckExtrema<Int32Type>(); }
+TEST(TestStatistic, Int64Extrema) { CheckExtrema<Int64Type>(); }
+
+// PARQUET-1225: Float NaN values may lead to incorrect min-max
+template <typename ParquetType>
+void CheckNaNs() {
+ using T = typename ParquetType::c_type;
+
+ constexpr int kNumValues = 8;
+ NodePtr node = PrimitiveNode::Make("f", Repetition::OPTIONAL, ParquetType::type_num);
+ ColumnDescriptor descr(node, 1, 1);
+
+ constexpr T nan = std::numeric_limits<T>::quiet_NaN();
+ constexpr T min = -4.0f;
+ constexpr T max = 3.0f;
+
+ std::array<T, kNumValues> all_nans{nan, nan, nan, nan, nan, nan, nan, nan};
+ std::array<T, kNumValues> some_nans{nan, max, -3.0f, -1.0f, nan, 2.0f, min, nan};
+ uint8_t valid_bitmap = 0x7F; // 0b01111111
+ // NaNs excluded
+ uint8_t valid_bitmap_no_nans = 0x6E; // 0b01101110
+
+ // Test values
+ auto some_nan_stats = MakeStatistics<ParquetType>(&descr);
+ // Ingesting only nans should not yield valid min max
+ AssertUnsetMinMax(some_nan_stats, all_nans);
+ // Ingesting a mix of NaNs and non-NaNs should not yield valid min max.
+ AssertMinMaxAre(some_nan_stats, some_nans, min, max);
+ // Ingesting only nans after a valid min/max, should have not effect
+ AssertMinMaxAre(some_nan_stats, all_nans, min, max);
+
+ some_nan_stats = MakeStatistics<ParquetType>(&descr);
+ AssertUnsetMinMax(some_nan_stats, all_nans, &valid_bitmap);
+ // NaNs should not pollute min max when excluded via null bitmap.
+ AssertMinMaxAre(some_nan_stats, some_nans, &valid_bitmap_no_nans, min, max);
+ // Ingesting NaNs with a null bitmap should not change the result.
+ AssertMinMaxAre(some_nan_stats, some_nans, &valid_bitmap, min, max);
+
+ // An array that doesn't start with NaN
+ std::array<T, kNumValues> other_nans{1.5f, max, -3.0f, -1.0f, nan, 2.0f, min, nan};
+ auto other_stats = MakeStatistics<ParquetType>(&descr);
+ AssertMinMaxAre(other_stats, other_nans, min, max);
+}
+
+TEST(TestStatistic, NaNFloatValues) { CheckNaNs<FloatType>(); }
+
+TEST(TestStatistic, NaNDoubleValues) { CheckNaNs<DoubleType>(); }
+
+// ARROW-7376
+TEST(TestStatisticsSortOrderFloatNaN, NaNAndNullsInfiniteLoop) {
+ constexpr int kNumValues = 8;
+ NodePtr node = PrimitiveNode::Make("nan_float", Repetition::OPTIONAL, Type::FLOAT);
+ ColumnDescriptor descr(node, 1, 1);
+
+ constexpr float nan = std::numeric_limits<float>::quiet_NaN();
+ std::array<float, kNumValues> nans_but_last{nan, nan, nan, nan, nan, nan, nan, 0.0f};
+
+ uint8_t all_but_last_valid = 0x7F; // 0b01111111
+ auto stats = MakeStatistics<FloatType>(&descr);
+ AssertUnsetMinMax(stats, nans_but_last, &all_but_last_valid);
+}
+
+template <typename Stats, typename Array, typename T = typename Array::value_type>
+void AssertMinMaxZeroesSign(Stats stats, const Array& values) {
+ stats->Update(values.data(), values.size(), 0);
+ ASSERT_TRUE(stats->HasMinMax());
+
+ T zero{};
+ ASSERT_EQ(stats->min(), zero);
+ ASSERT_TRUE(std::signbit(stats->min()));
+
+ ASSERT_EQ(stats->max(), zero);
+ ASSERT_FALSE(std::signbit(stats->max()));
+}
+
+// ARROW-5562: Ensure that -0.0f and 0.0f values are properly handled like in
+// parquet-mr
+template <typename ParquetType>
+void CheckNegativeZeroStats() {
+ using T = typename ParquetType::c_type;
+
+ NodePtr node = PrimitiveNode::Make("f", Repetition::OPTIONAL, ParquetType::type_num);
+ ColumnDescriptor descr(node, 1, 1);
+ T zero{};
+
+ {
+ std::array<T, 2> values{-zero, zero};
+ auto stats = MakeStatistics<ParquetType>(&descr);
+ AssertMinMaxZeroesSign(stats, values);
+ }
+
+ {
+ std::array<T, 2> values{zero, -zero};
+ auto stats = MakeStatistics<ParquetType>(&descr);
+ AssertMinMaxZeroesSign(stats, values);
+ }
+
+ {
+ std::array<T, 2> values{-zero, -zero};
+ auto stats = MakeStatistics<ParquetType>(&descr);
+ AssertMinMaxZeroesSign(stats, values);
+ }
+
+ {
+ std::array<T, 2> values{zero, zero};
+ auto stats = MakeStatistics<ParquetType>(&descr);
+ AssertMinMaxZeroesSign(stats, values);
+ }
+}
+
+TEST(TestStatistics, FloatNegativeZero) { CheckNegativeZeroStats<FloatType>(); }
+
+TEST(TestStatistics, DoubleNegativeZero) { CheckNegativeZeroStats<DoubleType>(); }
+
+// Test statistics for binary column with UNSIGNED sort order
+TEST(TestStatisticsSortOrderMinMax, Unsigned) {
+ std::string dir_string(test::get_data_dir());
+ std::stringstream ss;
+ ss << dir_string << "/binary.parquet";
+ auto path = ss.str();
+
+ // The file is generated by parquet-mr 1.10.0, the first version that
+ // supports correct statistics for binary data (see PARQUET-1025). It
+ // contains a single column of binary type. Data is just single byte values
+ // from 0x00 to 0x0B.
+ auto file_reader = ParquetFileReader::OpenFile(path);
+ auto rg_reader = file_reader->RowGroup(0);
+ auto metadata = rg_reader->metadata();
+ auto column_schema = metadata->schema()->Column(0);
+ ASSERT_EQ(SortOrder::UNSIGNED, column_schema->sort_order());
+
+ auto column_chunk = metadata->ColumnChunk(0);
+ ASSERT_TRUE(column_chunk->is_stats_set());
+
+ std::shared_ptr<Statistics> stats = column_chunk->statistics();
+ ASSERT_TRUE(stats != NULL);
+ ASSERT_EQ(0, stats->null_count());
+ ASSERT_EQ(12, stats->num_values());
+ ASSERT_EQ(0x00, stats->EncodeMin()[0]);
+ ASSERT_EQ(0x0b, stats->EncodeMax()[0]);
+}
+
+} // namespace test
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/stream_reader.cc b/src/arrow/cpp/src/parquet/stream_reader.cc
new file mode 100644
index 000000000..9a7cc8cdf
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/stream_reader.cc
@@ -0,0 +1,521 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/stream_reader.h"
+
+#include <set>
+#include <utility>
+
+namespace parquet {
+
+constexpr int64_t StreamReader::kBatchSizeOne;
+
+// The converted type expected by the stream reader does not always
+// exactly match with the schema in the Parquet file. The following
+// is a list of converted types which are allowed instead of the
+// expected converted type.
+// Each pair given is:
+// {<StreamReader expected type>, <Parquet file converted type>}
+// So for example {ConvertedType::INT_32, ConvertedType::NONE} means
+// that if the StreamReader was expecting the converted type INT_32,
+// then it will allow the Parquet file to use the converted type
+// NONE.
+//
+static const std::set<std::pair<ConvertedType::type, ConvertedType::type> >
+ converted_type_exceptions = {{ConvertedType::INT_32, ConvertedType::NONE},
+ {ConvertedType::INT_64, ConvertedType::NONE},
+ {ConvertedType::INT_32, ConvertedType::DECIMAL},
+ {ConvertedType::INT_64, ConvertedType::DECIMAL},
+ {ConvertedType::UTF8, ConvertedType::NONE}};
+
+StreamReader::StreamReader(std::unique_ptr<ParquetFileReader> reader)
+ : file_reader_{std::move(reader)}, eof_{false} {
+ file_metadata_ = file_reader_->metadata();
+
+ auto schema = file_metadata_->schema();
+ auto group_node = schema->group_node();
+
+ nodes_.resize(schema->num_columns());
+
+ for (auto i = 0; i < schema->num_columns(); ++i) {
+ nodes_[i] = std::static_pointer_cast<schema::PrimitiveNode>(group_node->field(i));
+ }
+ NextRowGroup();
+}
+
+int StreamReader::num_columns() const {
+ // Check for file metadata i.e. object is not default constructed.
+ if (file_metadata_) {
+ return file_metadata_->num_columns();
+ }
+ return 0;
+}
+
+int64_t StreamReader::num_rows() const {
+ // Check for file metadata i.e. object is not default constructed.
+ if (file_metadata_) {
+ return file_metadata_->num_rows();
+ }
+ return 0;
+}
+
+StreamReader& StreamReader::operator>>(bool& v) {
+ CheckColumn(Type::BOOLEAN, ConvertedType::NONE);
+ Read<BoolReader>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(int8_t& v) {
+ CheckColumn(Type::INT32, ConvertedType::INT_8);
+ Read<Int32Reader, int32_t>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(uint8_t& v) {
+ CheckColumn(Type::INT32, ConvertedType::UINT_8);
+ Read<Int32Reader, int32_t>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(int16_t& v) {
+ CheckColumn(Type::INT32, ConvertedType::INT_16);
+ Read<Int32Reader, int32_t>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(uint16_t& v) {
+ CheckColumn(Type::INT32, ConvertedType::UINT_16);
+ Read<Int32Reader, int32_t>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(int32_t& v) {
+ CheckColumn(Type::INT32, ConvertedType::INT_32);
+ Read<Int32Reader>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(uint32_t& v) {
+ CheckColumn(Type::INT32, ConvertedType::UINT_32);
+ Read<Int32Reader>(reinterpret_cast<int32_t*>(&v));
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(int64_t& v) {
+ CheckColumn(Type::INT64, ConvertedType::INT_64);
+ Read<Int64Reader>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(uint64_t& v) {
+ CheckColumn(Type::INT64, ConvertedType::UINT_64);
+ Read<Int64Reader>(reinterpret_cast<int64_t*>(&v));
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(std::chrono::milliseconds& v) {
+ CheckColumn(Type::INT64, ConvertedType::TIMESTAMP_MILLIS);
+ int64_t tmp;
+ Read<Int64Reader>(&tmp);
+ v = std::chrono::milliseconds{tmp};
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(std::chrono::microseconds& v) {
+ CheckColumn(Type::INT64, ConvertedType::TIMESTAMP_MICROS);
+ int64_t tmp;
+ Read<Int64Reader>(&tmp);
+ v = std::chrono::microseconds{tmp};
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(float& v) {
+ CheckColumn(Type::FLOAT, ConvertedType::NONE);
+ Read<FloatReader>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(double& v) {
+ CheckColumn(Type::DOUBLE, ConvertedType::NONE);
+ Read<DoubleReader>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(char& v) {
+ CheckColumn(Type::FIXED_LEN_BYTE_ARRAY, ConvertedType::NONE, 1);
+ FixedLenByteArray flba;
+
+ Read(&flba);
+ v = static_cast<char>(flba.ptr[0]);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(std::string& v) {
+ CheckColumn(Type::BYTE_ARRAY, ConvertedType::UTF8);
+ ByteArray ba;
+
+ Read(&ba);
+ v = std::string(reinterpret_cast<const char*>(ba.ptr), ba.len);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(optional<bool>& v) {
+ CheckColumn(Type::BOOLEAN, ConvertedType::NONE);
+ ReadOptional<BoolReader>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(optional<int8_t>& v) {
+ CheckColumn(Type::INT32, ConvertedType::INT_8);
+ ReadOptional<Int32Reader, int32_t>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(optional<uint8_t>& v) {
+ CheckColumn(Type::INT32, ConvertedType::UINT_8);
+ ReadOptional<Int32Reader, int32_t>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(optional<int16_t>& v) {
+ CheckColumn(Type::INT32, ConvertedType::INT_16);
+ ReadOptional<Int32Reader, int32_t>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(optional<uint16_t>& v) {
+ CheckColumn(Type::INT32, ConvertedType::UINT_16);
+ ReadOptional<Int32Reader, int32_t>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(optional<int32_t>& v) {
+ CheckColumn(Type::INT32, ConvertedType::INT_32);
+ ReadOptional<Int32Reader>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(optional<uint32_t>& v) {
+ CheckColumn(Type::INT32, ConvertedType::UINT_32);
+ ReadOptional<Int32Reader>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(optional<int64_t>& v) {
+ CheckColumn(Type::INT64, ConvertedType::INT_64);
+ ReadOptional<Int64Reader>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(optional<uint64_t>& v) {
+ CheckColumn(Type::INT64, ConvertedType::UINT_64);
+ ReadOptional<Int64Reader>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(optional<float>& v) {
+ CheckColumn(Type::FLOAT, ConvertedType::NONE);
+ ReadOptional<FloatReader>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(optional<double>& v) {
+ CheckColumn(Type::DOUBLE, ConvertedType::NONE);
+ ReadOptional<DoubleReader>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(optional<std::chrono::milliseconds>& v) {
+ CheckColumn(Type::INT64, ConvertedType::TIMESTAMP_MILLIS);
+ ReadOptional<Int64Reader, int64_t>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(optional<std::chrono::microseconds>& v) {
+ CheckColumn(Type::INT64, ConvertedType::TIMESTAMP_MICROS);
+ ReadOptional<Int64Reader, int64_t>(&v);
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(optional<char>& v) {
+ CheckColumn(Type::FIXED_LEN_BYTE_ARRAY, ConvertedType::NONE, 1);
+ FixedLenByteArray flba;
+
+ if (ReadOptional(&flba)) {
+ v = static_cast<char>(flba.ptr[0]);
+ } else {
+ v.reset();
+ }
+ return *this;
+}
+
+StreamReader& StreamReader::operator>>(optional<std::string>& v) {
+ CheckColumn(Type::BYTE_ARRAY, ConvertedType::UTF8);
+ ByteArray ba;
+
+ if (ReadOptional(&ba)) {
+ v = std::string(reinterpret_cast<const char*>(ba.ptr), ba.len);
+ } else {
+ v.reset();
+ }
+ return *this;
+}
+
+void StreamReader::ReadFixedLength(char* ptr, int len) {
+ CheckColumn(Type::FIXED_LEN_BYTE_ARRAY, ConvertedType::NONE, len);
+ FixedLenByteArray flba;
+ Read(&flba);
+ std::memcpy(ptr, flba.ptr, len);
+}
+
+void StreamReader::Read(ByteArray* v) {
+ const auto& node = nodes_[column_index_];
+ auto reader = static_cast<ByteArrayReader*>(column_readers_[column_index_++].get());
+ int16_t def_level;
+ int16_t rep_level;
+ int64_t values_read;
+
+ reader->ReadBatch(kBatchSizeOne, &def_level, &rep_level, v, &values_read);
+
+ if (values_read != 1) {
+ ThrowReadFailedException(node);
+ }
+}
+
+bool StreamReader::ReadOptional(ByteArray* v) {
+ const auto& node = nodes_[column_index_];
+ auto reader = static_cast<ByteArrayReader*>(column_readers_[column_index_++].get());
+ int16_t def_level;
+ int16_t rep_level;
+ int64_t values_read;
+
+ reader->ReadBatch(kBatchSizeOne, &def_level, &rep_level, v, &values_read);
+
+ if (values_read == 1) {
+ return true;
+ } else if ((values_read == 0) && (def_level == 0)) {
+ return false;
+ }
+ ThrowReadFailedException(node);
+}
+
+void StreamReader::Read(FixedLenByteArray* v) {
+ const auto& node = nodes_[column_index_];
+ auto reader =
+ static_cast<FixedLenByteArrayReader*>(column_readers_[column_index_++].get());
+ int16_t def_level;
+ int16_t rep_level;
+ int64_t values_read;
+
+ reader->ReadBatch(kBatchSizeOne, &def_level, &rep_level, v, &values_read);
+
+ if (values_read != 1) {
+ ThrowReadFailedException(node);
+ }
+}
+
+bool StreamReader::ReadOptional(FixedLenByteArray* v) {
+ const auto& node = nodes_[column_index_];
+ auto reader =
+ static_cast<FixedLenByteArrayReader*>(column_readers_[column_index_++].get());
+ int16_t def_level;
+ int16_t rep_level;
+ int64_t values_read;
+
+ reader->ReadBatch(kBatchSizeOne, &def_level, &rep_level, v, &values_read);
+
+ if (values_read == 1) {
+ return true;
+ } else if ((values_read == 0) && (def_level == 0)) {
+ return false;
+ }
+ ThrowReadFailedException(node);
+}
+
+void StreamReader::EndRow() {
+ if (!file_reader_) {
+ throw ParquetException("StreamReader not initialized");
+ }
+ if (static_cast<std::size_t>(column_index_) < nodes_.size()) {
+ throw ParquetException("Cannot end row with " + std::to_string(column_index_) +
+ " of " + std::to_string(nodes_.size()) + " columns read");
+ }
+ column_index_ = 0;
+ ++current_row_;
+
+ if (!column_readers_[0]->HasNext()) {
+ NextRowGroup();
+ }
+}
+
+void StreamReader::NextRowGroup() {
+ // Find next none-empty row group
+ while (row_group_index_ < file_metadata_->num_row_groups()) {
+ row_group_reader_ = file_reader_->RowGroup(row_group_index_);
+ ++row_group_index_;
+
+ column_readers_.resize(file_metadata_->num_columns());
+
+ for (int i = 0; i < file_metadata_->num_columns(); ++i) {
+ column_readers_[i] = row_group_reader_->Column(i);
+ }
+ if (column_readers_[0]->HasNext()) {
+ row_group_row_offset_ = current_row_;
+ return;
+ }
+ }
+ // No more row groups found.
+ SetEof();
+}
+
+void StreamReader::SetEof() {
+ // Do not reset file_metadata_ to ensure queries on the number of
+ // rows/columns still function.
+ eof_ = true;
+ file_reader_.reset();
+ row_group_reader_.reset();
+ column_readers_.clear();
+ nodes_.clear();
+}
+
+int64_t StreamReader::SkipRows(int64_t num_rows_to_skip) {
+ if (0 != column_index_) {
+ throw ParquetException("Must finish reading current row before skipping rows.");
+ }
+ int64_t num_rows_remaining_to_skip = num_rows_to_skip;
+
+ while (!eof_ && (num_rows_remaining_to_skip > 0)) {
+ int64_t num_rows_in_row_group = row_group_reader_->metadata()->num_rows();
+ int64_t num_rows_remaining_in_row_group =
+ num_rows_in_row_group - current_row_ - row_group_row_offset_;
+
+ if (num_rows_remaining_in_row_group > num_rows_remaining_to_skip) {
+ for (auto reader : column_readers_) {
+ SkipRowsInColumn(reader.get(), num_rows_remaining_to_skip);
+ }
+ current_row_ += num_rows_remaining_to_skip;
+ num_rows_remaining_to_skip = 0;
+ } else {
+ num_rows_remaining_to_skip -= num_rows_remaining_in_row_group;
+ current_row_ += num_rows_remaining_in_row_group;
+ NextRowGroup();
+ }
+ }
+ return num_rows_to_skip - num_rows_remaining_to_skip;
+}
+
+int64_t StreamReader::SkipColumns(int64_t num_columns_to_skip) {
+ int64_t num_columns_skipped = 0;
+
+ if (!eof_) {
+ for (; (num_columns_to_skip > num_columns_skipped) &&
+ static_cast<std::size_t>(column_index_) < nodes_.size();
+ ++column_index_) {
+ SkipRowsInColumn(column_readers_[column_index_].get(), 1);
+ ++num_columns_skipped;
+ }
+ }
+ return num_columns_skipped;
+}
+
+void StreamReader::SkipRowsInColumn(ColumnReader* reader, int64_t num_rows_to_skip) {
+ int64_t num_skipped = 0;
+
+ switch (reader->type()) {
+ case Type::BOOLEAN:
+ num_skipped = static_cast<BoolReader*>(reader)->Skip(num_rows_to_skip);
+ break;
+ case Type::INT32:
+ num_skipped = static_cast<Int32Reader*>(reader)->Skip(num_rows_to_skip);
+ break;
+ case Type::INT64:
+ num_skipped = static_cast<Int64Reader*>(reader)->Skip(num_rows_to_skip);
+ break;
+ case Type::BYTE_ARRAY:
+ num_skipped = static_cast<ByteArrayReader*>(reader)->Skip(num_rows_to_skip);
+ break;
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ num_skipped = static_cast<FixedLenByteArrayReader*>(reader)->Skip(num_rows_to_skip);
+ break;
+ case Type::FLOAT:
+ num_skipped = static_cast<FloatReader*>(reader)->Skip(num_rows_to_skip);
+ break;
+ case Type::DOUBLE:
+ num_skipped = static_cast<DoubleReader*>(reader)->Skip(num_rows_to_skip);
+ break;
+ case Type::INT96:
+ num_skipped = static_cast<Int96Reader*>(reader)->Skip(num_rows_to_skip);
+ break;
+ case Type::UNDEFINED:
+ throw ParquetException("Unexpected type: " + TypeToString(reader->type()));
+ break;
+ }
+ if (num_rows_to_skip != num_skipped) {
+ throw ParquetException("Skipped " + std::to_string(num_skipped) + "/" +
+ std::to_string(num_rows_to_skip) + " rows in column " +
+ reader->descr()->name());
+ }
+}
+
+void StreamReader::CheckColumn(Type::type physical_type,
+ ConvertedType::type converted_type, int length) {
+ if (static_cast<std::size_t>(column_index_) >= nodes_.size()) {
+ if (eof_) {
+ ParquetException::EofException();
+ }
+ throw ParquetException("Column index out-of-bounds. Index " +
+ std::to_string(column_index_) + " is invalid for " +
+ std::to_string(nodes_.size()) + " columns");
+ }
+ const auto& node = nodes_[column_index_];
+
+ if (physical_type != node->physical_type()) {
+ throw ParquetException("Column physical type mismatch. Column '" + node->name() +
+ "' has physical type '" + TypeToString(node->physical_type()) +
+ "' not '" + TypeToString(physical_type) + "'");
+ }
+ if (converted_type != node->converted_type()) {
+ // The converted type does not always match with the value
+ // provided so check the set of exceptions.
+ if (converted_type_exceptions.find({converted_type, node->converted_type()}) ==
+ converted_type_exceptions.end()) {
+ throw ParquetException("Column converted type mismatch. Column '" + node->name() +
+ "' has converted type '" +
+ ConvertedTypeToString(node->converted_type()) + "' not '" +
+ ConvertedTypeToString(converted_type) + "'");
+ }
+ }
+ // Length must be exact.
+ if (length != node->type_length()) {
+ throw ParquetException("Column length mismatch. Column '" + node->name() +
+ "' has length " + std::to_string(node->type_length()) +
+ "] not " + std::to_string(length));
+ }
+} // namespace parquet
+
+void StreamReader::ThrowReadFailedException(
+ const std::shared_ptr<schema::PrimitiveNode>& node) {
+ throw ParquetException("Failed to read value for column '" + node->name() +
+ "' on row " + std::to_string(current_row_));
+}
+
+StreamReader& operator>>(StreamReader& os, EndRowType) {
+ os.EndRow();
+ return os;
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/stream_reader.h b/src/arrow/cpp/src/parquet/stream_reader.h
new file mode 100644
index 000000000..806b0e8ad
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/stream_reader.h
@@ -0,0 +1,299 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <array>
+#include <chrono>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/util/optional.h"
+#include "parquet/column_reader.h"
+#include "parquet/file_reader.h"
+#include "parquet/stream_writer.h"
+
+namespace parquet {
+
+/// \brief A class for reading Parquet files using an output stream type API.
+///
+/// The values given must be of the correct type i.e. the type must
+/// match the file schema exactly otherwise a ParquetException will be
+/// thrown.
+///
+/// The user must explicitly advance to the next row using the
+/// EndRow() function or EndRow input manipulator.
+///
+/// Required and optional fields are supported:
+/// - Required fields are read using operator>>(T)
+/// - Optional fields are read with
+/// operator>>(arrow::util::optional<T>)
+///
+/// Note that operator>>(arrow::util::optional<T>) can be used to read
+/// required fields.
+///
+/// Similarly operator>>(T) can be used to read optional fields.
+/// However, if the value is not present then a ParquetException will
+/// be raised.
+///
+/// Currently there is no support for repeated fields.
+///
+class PARQUET_EXPORT StreamReader {
+ public:
+ template <typename T>
+ using optional = ::arrow::util::optional<T>;
+
+ // N.B. Default constructed objects are not usable. This
+ // constructor is provided so that the object may be move
+ // assigned afterwards.
+ StreamReader() = default;
+
+ explicit StreamReader(std::unique_ptr<ParquetFileReader> reader);
+
+ ~StreamReader() = default;
+
+ bool eof() const { return eof_; }
+
+ int current_column() const { return column_index_; }
+
+ int64_t current_row() const { return current_row_; }
+
+ int num_columns() const;
+
+ int64_t num_rows() const;
+
+ // Moving is possible.
+ StreamReader(StreamReader&&) = default;
+ StreamReader& operator=(StreamReader&&) = default;
+
+ // Copying is not allowed.
+ StreamReader(const StreamReader&) = delete;
+ StreamReader& operator=(const StreamReader&) = delete;
+
+ StreamReader& operator>>(bool& v);
+
+ StreamReader& operator>>(int8_t& v);
+
+ StreamReader& operator>>(uint8_t& v);
+
+ StreamReader& operator>>(int16_t& v);
+
+ StreamReader& operator>>(uint16_t& v);
+
+ StreamReader& operator>>(int32_t& v);
+
+ StreamReader& operator>>(uint32_t& v);
+
+ StreamReader& operator>>(int64_t& v);
+
+ StreamReader& operator>>(uint64_t& v);
+
+ StreamReader& operator>>(std::chrono::milliseconds& v);
+
+ StreamReader& operator>>(std::chrono::microseconds& v);
+
+ StreamReader& operator>>(float& v);
+
+ StreamReader& operator>>(double& v);
+
+ StreamReader& operator>>(char& v);
+
+ template <int N>
+ StreamReader& operator>>(char (&v)[N]) {
+ ReadFixedLength(v, N);
+ return *this;
+ }
+
+ template <std::size_t N>
+ StreamReader& operator>>(std::array<char, N>& v) {
+ ReadFixedLength(v.data(), static_cast<int>(N));
+ return *this;
+ }
+
+ // N.B. Cannot allow for reading to a arbitrary char pointer as the
+ // length cannot be verified. Also it would overshadow the
+ // char[N] input operator.
+ // StreamReader& operator>>(char * v);
+
+ StreamReader& operator>>(std::string& v);
+
+ // Input operators for optional fields.
+
+ StreamReader& operator>>(optional<bool>& v);
+
+ StreamReader& operator>>(optional<int8_t>& v);
+
+ StreamReader& operator>>(optional<uint8_t>& v);
+
+ StreamReader& operator>>(optional<int16_t>& v);
+
+ StreamReader& operator>>(optional<uint16_t>& v);
+
+ StreamReader& operator>>(optional<int32_t>& v);
+
+ StreamReader& operator>>(optional<uint32_t>& v);
+
+ StreamReader& operator>>(optional<int64_t>& v);
+
+ StreamReader& operator>>(optional<uint64_t>& v);
+
+ StreamReader& operator>>(optional<float>& v);
+
+ StreamReader& operator>>(optional<double>& v);
+
+ StreamReader& operator>>(optional<std::chrono::milliseconds>& v);
+
+ StreamReader& operator>>(optional<std::chrono::microseconds>& v);
+
+ StreamReader& operator>>(optional<char>& v);
+
+ StreamReader& operator>>(optional<std::string>& v);
+
+ template <std::size_t N>
+ StreamReader& operator>>(optional<std::array<char, N>>& v) {
+ CheckColumn(Type::FIXED_LEN_BYTE_ARRAY, ConvertedType::NONE, N);
+ FixedLenByteArray flba;
+ if (ReadOptional(&flba)) {
+ v = std::array<char, N>{};
+ std::memcpy(v->data(), flba.ptr, N);
+ } else {
+ v.reset();
+ }
+ return *this;
+ }
+
+ /// \brief Terminate current row and advance to next one.
+ /// \throws ParquetException if all columns in the row were not
+ /// read or skipped.
+ void EndRow();
+
+ /// \brief Skip the data in the next columns.
+ /// If the number of columns exceeds the columns remaining on the
+ /// current row then skipping is terminated - it does _not_ continue
+ /// skipping columns on the next row.
+ /// Skipping of columns still requires the use 'EndRow' even if all
+ /// remaining columns were skipped.
+ /// \return Number of columns actually skipped.
+ int64_t SkipColumns(int64_t num_columns_to_skip);
+
+ /// \brief Skip the data in the next rows.
+ /// Skipping of rows is not allowed if reading of data for the
+ /// current row is not finished.
+ /// Skipping of rows will be terminated if the end of file is
+ /// reached.
+ /// \return Number of rows actually skipped.
+ int64_t SkipRows(int64_t num_rows_to_skip);
+
+ protected:
+ [[noreturn]] void ThrowReadFailedException(
+ const std::shared_ptr<schema::PrimitiveNode>& node);
+
+ template <typename ReaderType, typename T>
+ void Read(T* v) {
+ const auto& node = nodes_[column_index_];
+ auto reader = static_cast<ReaderType*>(column_readers_[column_index_++].get());
+ int16_t def_level;
+ int16_t rep_level;
+ int64_t values_read;
+
+ reader->ReadBatch(kBatchSizeOne, &def_level, &rep_level, v, &values_read);
+
+ if (values_read != 1) {
+ ThrowReadFailedException(node);
+ }
+ }
+
+ template <typename ReaderType, typename ReadType, typename T>
+ void Read(T* v) {
+ const auto& node = nodes_[column_index_];
+ auto reader = static_cast<ReaderType*>(column_readers_[column_index_++].get());
+ int16_t def_level;
+ int16_t rep_level;
+ ReadType tmp;
+ int64_t values_read;
+
+ reader->ReadBatch(kBatchSizeOne, &def_level, &rep_level, &tmp, &values_read);
+
+ if (values_read == 1) {
+ *v = tmp;
+ } else {
+ ThrowReadFailedException(node);
+ }
+ }
+
+ template <typename ReaderType, typename ReadType = typename ReaderType::T, typename T>
+ void ReadOptional(optional<T>* v) {
+ const auto& node = nodes_[column_index_];
+ auto reader = static_cast<ReaderType*>(column_readers_[column_index_++].get());
+ int16_t def_level;
+ int16_t rep_level;
+ ReadType tmp;
+ int64_t values_read;
+
+ reader->ReadBatch(kBatchSizeOne, &def_level, &rep_level, &tmp, &values_read);
+
+ if (values_read == 1) {
+ *v = T(tmp);
+ } else if ((values_read == 0) && (def_level == 0)) {
+ v->reset();
+ } else {
+ ThrowReadFailedException(node);
+ }
+ }
+
+ void ReadFixedLength(char* ptr, int len);
+
+ void Read(ByteArray* v);
+
+ void Read(FixedLenByteArray* v);
+
+ bool ReadOptional(ByteArray* v);
+
+ bool ReadOptional(FixedLenByteArray* v);
+
+ void NextRowGroup();
+
+ void CheckColumn(Type::type physical_type, ConvertedType::type converted_type,
+ int length = 0);
+
+ void SkipRowsInColumn(ColumnReader* reader, int64_t num_rows_to_skip);
+
+ void SetEof();
+
+ private:
+ std::unique_ptr<ParquetFileReader> file_reader_;
+ std::shared_ptr<FileMetaData> file_metadata_;
+ std::shared_ptr<RowGroupReader> row_group_reader_;
+ std::vector<std::shared_ptr<ColumnReader>> column_readers_;
+ std::vector<std::shared_ptr<schema::PrimitiveNode>> nodes_;
+
+ bool eof_{true};
+ int row_group_index_{0};
+ int column_index_{0};
+ int64_t current_row_{0};
+ int64_t row_group_row_offset_{0};
+
+ static constexpr int64_t kBatchSizeOne = 1;
+}; // namespace parquet
+
+PARQUET_EXPORT
+StreamReader& operator>>(StreamReader&, EndRowType);
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/stream_reader_test.cc b/src/arrow/cpp/src/parquet/stream_reader_test.cc
new file mode 100644
index 000000000..eb7b13374
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/stream_reader_test.cc
@@ -0,0 +1,916 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/stream_reader.h"
+
+#include <fcntl.h>
+#include <gtest/gtest.h>
+
+#include <chrono>
+#include <ctime>
+#include <memory>
+#include <utility>
+
+#include "arrow/io/file.h"
+#include "parquet/exception.h"
+#include "parquet/test_util.h"
+
+namespace parquet {
+namespace test {
+
+template <typename T>
+using optional = StreamReader::optional<T>;
+using ::arrow::util::nullopt;
+
+struct TestData {
+ static void init() { std::time(&ts_offset_); }
+
+ static constexpr int num_rows = 2000;
+
+ static std::string GetString(const int i) { return "Str #" + std::to_string(i); }
+ static bool GetBool(const int i) { return i % 7 < 3; }
+ static char GetChar(const int i) { return i & 1 ? 'M' : 'F'; }
+ static std::array<char, 4> GetCharArray(const int i) {
+ return {'X', 'Y', 'Z', char('A' + i % 26)};
+ }
+ static int8_t GetInt8(const int i) { return static_cast<int8_t>((i % 256) - 128); }
+ static uint16_t GetUInt16(const int i) { return static_cast<uint16_t>(i); }
+ static int32_t GetInt32(const int i) { return 3 * i - 17; }
+ static uint64_t GetUInt64(const int i) { return (1ull << 40) + i * i + 101; }
+ static float GetFloat(const int i) { return 3.1415926535897f * i; }
+ static double GetDouble(const int i) { return 6.62607004e-34 * 3e8 * i; }
+
+ static std::chrono::microseconds GetChronoMicroseconds(const int i) {
+ return std::chrono::microseconds{(ts_offset_ + 3 * i) * 1000000ull + i};
+ }
+
+ static optional<bool> GetOptBool(const int i) {
+ if (i % 11 == 0) {
+ return nullopt;
+ }
+ return i % 7 < 3;
+ }
+
+ static optional<char> GetOptChar(const int i) {
+ if ((i + 1) % 11 == 1) {
+ return nullopt;
+ }
+ return i & 1 ? 'M' : 'F';
+ }
+
+ static optional<std::array<char, 4>> GetOptCharArray(const int i) {
+ if ((i + 2) % 11 == 1) {
+ return nullopt;
+ }
+ return std::array<char, 4>{{'X', 'Y', 'Z', char('A' + i % 26)}};
+ }
+
+ static optional<int8_t> GetOptInt8(const int i) {
+ if ((i + 3) % 11 == 1) {
+ return nullopt;
+ }
+ return static_cast<int8_t>((i % 256) - 128);
+ }
+
+ static optional<uint16_t> GetOptUInt16(const int i) {
+ if ((i + 4) % 11 == 1) {
+ return nullopt;
+ }
+ return static_cast<uint16_t>(i);
+ }
+
+ static optional<int32_t> GetOptInt32(const int i) {
+ if ((i + 5) % 11 == 1) {
+ return nullopt;
+ }
+ return 3 * i - 17;
+ }
+
+ static optional<uint64_t> GetOptUInt64(const int i) {
+ if ((i + 6) % 11 == 1) {
+ return nullopt;
+ }
+ return (1ull << 40) + i * i + 101;
+ }
+
+ static optional<std::string> GetOptString(const int i) {
+ if (i % 5 == 0) {
+ return nullopt;
+ }
+ return "Str #" + std::to_string(i);
+ }
+
+ static optional<float> GetOptFloat(const int i) {
+ if ((i + 1) % 3 == 0) {
+ return nullopt;
+ }
+ return 2.718281828459045f * i;
+ }
+
+ static optional<double> GetOptDouble(const int i) {
+ if ((i + 2) % 3 == 0) {
+ return nullopt;
+ }
+ return 6.62607004e-34 * 3e8 * i;
+ }
+
+ static optional<std::chrono::microseconds> GetOptChronoMicroseconds(const int i) {
+ if ((i + 2) % 7 == 0) {
+ return nullopt;
+ }
+ return std::chrono::microseconds{(ts_offset_ + 3 * i) * 1000000ull + i};
+ }
+
+ private:
+ static std::time_t ts_offset_;
+};
+
+std::time_t TestData::ts_offset_;
+constexpr int TestData::num_rows;
+
+class TestStreamReader : public ::testing::Test {
+ public:
+ TestStreamReader() { createTestFile(); }
+
+ protected:
+ const char* GetDataFile() const { return "stream_reader_test.parquet"; }
+
+ void SetUp() {
+ PARQUET_ASSIGN_OR_THROW(auto infile, ::arrow::io::ReadableFile::Open(GetDataFile()));
+ auto file_reader = parquet::ParquetFileReader::Open(infile);
+ reader_ = StreamReader{std::move(file_reader)};
+ }
+
+ void TearDown() { reader_ = StreamReader{}; }
+
+ std::shared_ptr<schema::GroupNode> GetSchema() {
+ schema::NodeVector fields;
+
+ fields.push_back(schema::PrimitiveNode::Make("bool_field", Repetition::REQUIRED,
+ Type::BOOLEAN, ConvertedType::NONE));
+
+ fields.push_back(schema::PrimitiveNode::Make("string_field", Repetition::REQUIRED,
+ Type::BYTE_ARRAY, ConvertedType::UTF8));
+
+ fields.push_back(schema::PrimitiveNode::Make("char_field", Repetition::REQUIRED,
+ Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::NONE, 1));
+
+ fields.push_back(schema::PrimitiveNode::Make("char[4]_field", Repetition::REQUIRED,
+ Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::NONE, 4));
+
+ fields.push_back(schema::PrimitiveNode::Make("int8_field", Repetition::REQUIRED,
+ Type::INT32, ConvertedType::INT_8));
+
+ fields.push_back(schema::PrimitiveNode::Make("uint16_field", Repetition::REQUIRED,
+ Type::INT32, ConvertedType::UINT_16));
+
+ fields.push_back(schema::PrimitiveNode::Make("int32_field", Repetition::REQUIRED,
+ Type::INT32, ConvertedType::INT_32));
+
+ fields.push_back(schema::PrimitiveNode::Make("uint64_field", Repetition::REQUIRED,
+ Type::INT64, ConvertedType::UINT_64));
+
+ fields.push_back(schema::PrimitiveNode::Make("chrono_microseconds_field",
+ Repetition::REQUIRED, Type::INT64,
+ ConvertedType::TIMESTAMP_MICROS));
+
+ fields.push_back(schema::PrimitiveNode::Make("float_field", Repetition::REQUIRED,
+ Type::FLOAT, ConvertedType::NONE));
+
+ fields.push_back(schema::PrimitiveNode::Make("double_field", Repetition::REQUIRED,
+ Type::DOUBLE, ConvertedType::NONE));
+
+ return std::static_pointer_cast<schema::GroupNode>(
+ schema::GroupNode::Make("schema", Repetition::REQUIRED, fields));
+ }
+
+ void createTestFile() {
+ PARQUET_ASSIGN_OR_THROW(auto outfile,
+ ::arrow::io::FileOutputStream::Open(GetDataFile()));
+
+ auto file_writer = ParquetFileWriter::Open(outfile, GetSchema());
+
+ StreamWriter os{std::move(file_writer)};
+
+ TestData::init();
+
+ for (auto i = 0; i < TestData::num_rows; ++i) {
+ os << TestData::GetBool(i);
+ os << TestData::GetString(i);
+ os << TestData::GetChar(i);
+ os << TestData::GetCharArray(i);
+ os << TestData::GetInt8(i);
+ os << TestData::GetUInt16(i);
+ os << TestData::GetInt32(i);
+ os << TestData::GetUInt64(i);
+ os << TestData::GetChronoMicroseconds(i);
+ os << TestData::GetFloat(i);
+ os << TestData::GetDouble(i);
+ os << EndRow;
+ }
+ }
+
+ StreamReader reader_;
+};
+
+TEST_F(TestStreamReader, DefaultConstructed) {
+ StreamReader os;
+ int i;
+ std::string s;
+
+ // N.B. Default constructor objects are not usable.
+ EXPECT_THROW(os >> i, ParquetException);
+ EXPECT_THROW(os >> s, ParquetException);
+ EXPECT_THROW(os >> EndRow, ParquetException);
+
+ EXPECT_EQ(true, os.eof());
+ EXPECT_EQ(0, os.current_column());
+ EXPECT_EQ(0, os.current_row());
+
+ EXPECT_EQ(0, os.num_columns());
+ EXPECT_EQ(0, os.num_rows());
+
+ // Skipping columns and rows is allowed.
+ //
+ EXPECT_EQ(0, os.SkipColumns(100));
+ EXPECT_EQ(0, os.SkipRows(100));
+}
+
+TEST_F(TestStreamReader, TypeChecking) {
+ bool b;
+ std::string s;
+ std::array<char, 4> char_array;
+ char c;
+ int8_t int8;
+ int16_t int16;
+ uint16_t uint16;
+ int32_t int32;
+ int64_t int64;
+ uint64_t uint64;
+ std::chrono::microseconds ts_us;
+ float f;
+ double d;
+ std::string str;
+
+ EXPECT_THROW(reader_ >> int8, ParquetException);
+ EXPECT_NO_THROW(reader_ >> b);
+ EXPECT_THROW(reader_ >> c, ParquetException);
+ EXPECT_NO_THROW(reader_ >> s);
+ EXPECT_THROW(reader_ >> s, ParquetException);
+ EXPECT_NO_THROW(reader_ >> c);
+ EXPECT_THROW(reader_ >> s, ParquetException);
+ EXPECT_NO_THROW(reader_ >> char_array);
+ EXPECT_THROW(reader_ >> int16, ParquetException);
+ EXPECT_NO_THROW(reader_ >> int8);
+ EXPECT_THROW(reader_ >> int16, ParquetException);
+ EXPECT_NO_THROW(reader_ >> uint16);
+ EXPECT_THROW(reader_ >> int64, ParquetException);
+ EXPECT_NO_THROW(reader_ >> int32);
+ EXPECT_THROW(reader_ >> int64, ParquetException);
+ EXPECT_NO_THROW(reader_ >> uint64);
+ EXPECT_THROW(reader_ >> uint64, ParquetException);
+ EXPECT_NO_THROW(reader_ >> ts_us);
+ EXPECT_THROW(reader_ >> d, ParquetException);
+ EXPECT_NO_THROW(reader_ >> f);
+ EXPECT_THROW(reader_ >> f, ParquetException);
+ EXPECT_NO_THROW(reader_ >> d);
+ EXPECT_NO_THROW(reader_ >> EndRow);
+}
+
+TEST_F(TestStreamReader, ValueChecking) {
+ bool b;
+ std::string str;
+ std::array<char, 4> char_array;
+ char c;
+ int8_t int8;
+ uint16_t uint16;
+ int32_t int32;
+ uint64_t uint64;
+ std::chrono::microseconds ts_us;
+ float f;
+ double d;
+
+ int i;
+
+ for (i = 0; !reader_.eof(); ++i) {
+ EXPECT_EQ(i, reader_.current_row());
+
+ reader_ >> b;
+ reader_ >> str;
+ reader_ >> c;
+ reader_ >> char_array;
+ reader_ >> int8;
+ reader_ >> uint16;
+ reader_ >> int32;
+ reader_ >> uint64;
+ reader_ >> ts_us;
+ reader_ >> f;
+ reader_ >> d;
+ reader_ >> EndRow;
+
+ EXPECT_EQ(b, TestData::GetBool(i)) << "index: " << i;
+ EXPECT_EQ(str, TestData::GetString(i)) << "index: " << i;
+ EXPECT_EQ(c, TestData::GetChar(i)) << "index: " << i;
+ EXPECT_EQ(char_array, TestData::GetCharArray(i)) << "index: " << i;
+ EXPECT_EQ(int8, TestData::GetInt8(i)) << "index: " << i;
+ EXPECT_EQ(uint16, TestData::GetUInt16(i)) << "index: " << i;
+ EXPECT_EQ(int32, TestData::GetInt32(i)) << "index: " << i;
+ EXPECT_EQ(uint64, TestData::GetUInt64(i)) << "index: " << i;
+ EXPECT_EQ(ts_us, TestData::GetChronoMicroseconds(i)) << "index: " << i;
+ EXPECT_FLOAT_EQ(f, TestData::GetFloat(i)) << "index: " << i;
+ EXPECT_DOUBLE_EQ(d, TestData::GetDouble(i)) << "index: " << i;
+ }
+ EXPECT_EQ(reader_.current_row(), TestData::num_rows);
+ EXPECT_EQ(reader_.num_rows(), TestData::num_rows);
+ EXPECT_EQ(i, TestData::num_rows);
+}
+
+TEST_F(TestStreamReader, ReadRequiredFieldAsOptionalField) {
+ /* Test that required fields can be read using optional types.
+
+ This can be useful if a schema is changed such that a field which
+ was optional is changed to be required. Applications can continue
+ to read the field as if it were still optional.
+ */
+
+ optional<bool> opt_bool;
+ optional<std::string> opt_string;
+ optional<std::array<char, 4>> opt_char_array;
+ optional<char> opt_char;
+ optional<int8_t> opt_int8;
+ optional<uint16_t> opt_uint16;
+ optional<int32_t> opt_int32;
+ optional<uint64_t> opt_uint64;
+ optional<std::chrono::microseconds> opt_ts_us;
+ optional<float> opt_float;
+ optional<double> opt_double;
+
+ int i;
+
+ for (i = 0; !reader_.eof(); ++i) {
+ EXPECT_EQ(i, reader_.current_row());
+
+ reader_ >> opt_bool;
+ reader_ >> opt_string;
+ reader_ >> opt_char;
+ reader_ >> opt_char_array;
+ reader_ >> opt_int8;
+ reader_ >> opt_uint16;
+ reader_ >> opt_int32;
+ reader_ >> opt_uint64;
+ reader_ >> opt_ts_us;
+ reader_ >> opt_float;
+ reader_ >> opt_double;
+ reader_ >> EndRow;
+
+ EXPECT_EQ(*opt_bool, TestData::GetBool(i)) << "index: " << i;
+ EXPECT_EQ(*opt_string, TestData::GetString(i)) << "index: " << i;
+ EXPECT_EQ(*opt_char, TestData::GetChar(i)) << "index: " << i;
+ EXPECT_EQ(*opt_char_array, TestData::GetCharArray(i)) << "index: " << i;
+ EXPECT_EQ(*opt_int8, TestData::GetInt8(i)) << "index: " << i;
+ EXPECT_EQ(*opt_uint16, TestData::GetUInt16(i)) << "index: " << i;
+ EXPECT_EQ(*opt_int32, TestData::GetInt32(i)) << "index: " << i;
+ EXPECT_EQ(*opt_uint64, TestData::GetUInt64(i)) << "index: " << i;
+ EXPECT_EQ(*opt_ts_us, TestData::GetChronoMicroseconds(i)) << "index: " << i;
+ EXPECT_FLOAT_EQ(*opt_float, TestData::GetFloat(i)) << "index: " << i;
+ EXPECT_DOUBLE_EQ(*opt_double, TestData::GetDouble(i)) << "index: " << i;
+ }
+ EXPECT_EQ(reader_.current_row(), TestData::num_rows);
+ EXPECT_EQ(reader_.num_rows(), TestData::num_rows);
+ EXPECT_EQ(i, TestData::num_rows);
+}
+
+TEST_F(TestStreamReader, SkipRows) {
+ // Skipping zero and negative number of rows is ok.
+ //
+ EXPECT_EQ(0, reader_.SkipRows(0));
+ EXPECT_EQ(0, reader_.SkipRows(-100));
+
+ EXPECT_EQ(false, reader_.eof());
+ EXPECT_EQ(0, reader_.current_row());
+ EXPECT_EQ(TestData::num_rows, reader_.num_rows());
+
+ const int iter_num_rows_to_read = 3;
+ const int iter_num_rows_to_skip = 13;
+ int num_rows_read = 0;
+ int i = 0;
+ int num_iterations;
+
+ for (num_iterations = 0; !reader_.eof(); ++num_iterations) {
+ // Each iteration of this loop reads some rows (iter_num_rows_to_read
+ // are read) and then skips some rows (iter_num_rows_to_skip will be
+ // skipped).
+ // The loop variable i is the current row being read.
+ // Loop variable j is used just to count the number of rows to
+ // read.
+ bool b;
+ std::string s;
+ std::array<char, 4> char_array;
+ char c;
+ int8_t int8;
+ uint16_t uint16;
+ int32_t int32;
+ uint64_t uint64;
+ std::chrono::microseconds ts_us;
+ float f;
+ double d;
+ std::string str;
+
+ for (int j = 0; !reader_.eof() && (j < iter_num_rows_to_read); ++i, ++j) {
+ EXPECT_EQ(i, reader_.current_row());
+
+ reader_ >> b;
+ reader_ >> s;
+ reader_ >> c;
+ reader_ >> char_array;
+ reader_ >> int8;
+ reader_ >> uint16;
+
+ // Not allowed to skip row once reading columns has started.
+ EXPECT_THROW(reader_.SkipRows(1), ParquetException);
+
+ reader_ >> int32;
+ reader_ >> uint64;
+ reader_ >> ts_us;
+ reader_ >> f;
+ reader_ >> d;
+ reader_ >> EndRow;
+ num_rows_read += 1;
+
+ EXPECT_EQ(b, TestData::GetBool(i));
+ EXPECT_EQ(s, TestData::GetString(i));
+ EXPECT_EQ(c, TestData::GetChar(i));
+ EXPECT_EQ(char_array, TestData::GetCharArray(i));
+ EXPECT_EQ(int8, TestData::GetInt8(i));
+ EXPECT_EQ(uint16, TestData::GetUInt16(i));
+ EXPECT_EQ(int32, TestData::GetInt32(i));
+ EXPECT_EQ(uint64, TestData::GetUInt64(i));
+ EXPECT_EQ(ts_us, TestData::GetChronoMicroseconds(i));
+ EXPECT_FLOAT_EQ(f, TestData::GetFloat(i));
+ EXPECT_DOUBLE_EQ(d, TestData::GetDouble(i));
+ }
+ EXPECT_EQ(iter_num_rows_to_skip, reader_.SkipRows(iter_num_rows_to_skip));
+ i += iter_num_rows_to_skip;
+ }
+ EXPECT_EQ(TestData::num_rows, reader_.current_row());
+
+ EXPECT_EQ(num_rows_read, num_iterations * iter_num_rows_to_read);
+
+ // Skipping rows at eof is allowed.
+ //
+ EXPECT_EQ(0, reader_.SkipRows(100));
+}
+
+TEST_F(TestStreamReader, SkipAllRows) {
+ EXPECT_EQ(false, reader_.eof());
+ EXPECT_EQ(0, reader_.current_row());
+
+ EXPECT_EQ(reader_.num_rows(), reader_.SkipRows(2 * reader_.num_rows()));
+
+ EXPECT_EQ(true, reader_.eof());
+ EXPECT_EQ(reader_.num_rows(), reader_.current_row());
+}
+
+TEST_F(TestStreamReader, SkipColumns) {
+ bool b;
+ std::string s;
+ std::array<char, 4> char_array;
+ char c;
+ int8_t int8;
+ uint16_t uint16;
+ int32_t int32;
+ uint64_t uint64;
+ std::chrono::microseconds ts_us;
+ float f;
+ double d;
+ std::string str;
+
+ int i;
+
+ // Skipping zero and negative number of columns is ok.
+ //
+ EXPECT_EQ(0, reader_.SkipColumns(0));
+ EXPECT_EQ(0, reader_.SkipColumns(-100));
+
+ for (i = 0; !reader_.eof(); ++i) {
+ EXPECT_EQ(i, reader_.current_row());
+ EXPECT_EQ(0, reader_.current_column());
+
+ // Skip all columns every 31 rows.
+ if (i % 31 == 0) {
+ EXPECT_EQ(reader_.num_columns(), reader_.SkipColumns(reader_.num_columns()))
+ << "index: " << i;
+ EXPECT_EQ(reader_.num_columns(), reader_.current_column()) << "index: " << i;
+ reader_ >> EndRow;
+ continue;
+ }
+ reader_ >> b;
+ EXPECT_EQ(b, TestData::GetBool(i)) << "index: " << i;
+ EXPECT_EQ(1, reader_.current_column()) << "index: " << i;
+
+ // Skip the next column every 3 rows.
+ if (i % 3 == 0) {
+ EXPECT_EQ(1, reader_.SkipColumns(1)) << "index: " << i;
+ } else {
+ reader_ >> s;
+ EXPECT_EQ(s, TestData::GetString(i)) << "index: " << i;
+ }
+ EXPECT_EQ(2, reader_.current_column()) << "index: " << i;
+
+ reader_ >> c;
+ EXPECT_EQ(c, TestData::GetChar(i)) << "index: " << i;
+ EXPECT_EQ(3, reader_.current_column()) << "index: " << i;
+ reader_ >> char_array;
+ EXPECT_EQ(char_array, TestData::GetCharArray(i)) << "index: " << i;
+ EXPECT_EQ(4, reader_.current_column()) << "index: " << i;
+ reader_ >> int8;
+ EXPECT_EQ(int8, TestData::GetInt8(i)) << "index: " << i;
+ EXPECT_EQ(5, reader_.current_column()) << "index: " << i;
+
+ // Skip the next 3 columns every 7 rows.
+ if (i % 7 == 0) {
+ EXPECT_EQ(3, reader_.SkipColumns(3)) << "index: " << i;
+ } else {
+ reader_ >> uint16;
+ EXPECT_EQ(uint16, TestData::GetUInt16(i)) << "index: " << i;
+ EXPECT_EQ(6, reader_.current_column()) << "index: " << i;
+ reader_ >> int32;
+ EXPECT_EQ(int32, TestData::GetInt32(i)) << "index: " << i;
+ EXPECT_EQ(7, reader_.current_column()) << "index: " << i;
+ reader_ >> uint64;
+ EXPECT_EQ(uint64, TestData::GetUInt64(i)) << "index: " << i;
+ }
+ EXPECT_EQ(8, reader_.current_column());
+
+ reader_ >> ts_us;
+ EXPECT_EQ(ts_us, TestData::GetChronoMicroseconds(i)) << "index: " << i;
+ EXPECT_EQ(9, reader_.current_column()) << "index: " << i;
+
+ // Skip 301 columns (i.e. all remaining) every 11 rows.
+ if (i % 11 == 0) {
+ EXPECT_EQ(2, reader_.SkipColumns(301)) << "index: " << i;
+ } else {
+ reader_ >> f;
+ EXPECT_FLOAT_EQ(f, TestData::GetFloat(i)) << "index: " << i;
+ EXPECT_EQ(10, reader_.current_column()) << "index: " << i;
+ reader_ >> d;
+ EXPECT_DOUBLE_EQ(d, TestData::GetDouble(i)) << "index: " << i;
+ }
+ EXPECT_EQ(11, reader_.current_column()) << "index: " << i;
+ reader_ >> EndRow;
+ }
+ EXPECT_EQ(i, TestData::num_rows);
+ EXPECT_EQ(reader_.current_row(), TestData::num_rows);
+
+ // Skipping columns at eof is allowed.
+ //
+ EXPECT_EQ(0, reader_.SkipColumns(100));
+}
+
+class TestOptionalFields : public ::testing::Test {
+ public:
+ TestOptionalFields() { createTestFile(); }
+
+ protected:
+ const char* GetDataFile() const { return "stream_reader_test_optional_fields.parquet"; }
+
+ void SetUp() {
+ PARQUET_ASSIGN_OR_THROW(auto infile, ::arrow::io::ReadableFile::Open(GetDataFile()));
+
+ auto file_reader = ParquetFileReader::Open(infile);
+
+ reader_ = StreamReader{std::move(file_reader)};
+ }
+
+ void TearDown() { reader_ = StreamReader{}; }
+
+ std::shared_ptr<schema::GroupNode> GetSchema() {
+ schema::NodeVector fields;
+
+ fields.push_back(schema::PrimitiveNode::Make("bool_field", Repetition::OPTIONAL,
+ Type::BOOLEAN, ConvertedType::NONE));
+
+ fields.push_back(schema::PrimitiveNode::Make("string_field", Repetition::OPTIONAL,
+ Type::BYTE_ARRAY, ConvertedType::UTF8));
+
+ fields.push_back(schema::PrimitiveNode::Make("char_field", Repetition::OPTIONAL,
+ Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::NONE, 1));
+
+ fields.push_back(schema::PrimitiveNode::Make("char[4]_field", Repetition::OPTIONAL,
+ Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::NONE, 4));
+
+ fields.push_back(schema::PrimitiveNode::Make("int8_field", Repetition::OPTIONAL,
+ Type::INT32, ConvertedType::INT_8));
+
+ fields.push_back(schema::PrimitiveNode::Make("uint16_field", Repetition::OPTIONAL,
+ Type::INT32, ConvertedType::UINT_16));
+
+ fields.push_back(schema::PrimitiveNode::Make("int32_field", Repetition::OPTIONAL,
+ Type::INT32, ConvertedType::INT_32));
+
+ fields.push_back(schema::PrimitiveNode::Make("uint64_field", Repetition::OPTIONAL,
+ Type::INT64, ConvertedType::UINT_64));
+
+ fields.push_back(schema::PrimitiveNode::Make("chrono_microseconds_field",
+ Repetition::OPTIONAL, Type::INT64,
+ ConvertedType::TIMESTAMP_MICROS));
+
+ fields.push_back(schema::PrimitiveNode::Make("float_field", Repetition::OPTIONAL,
+ Type::FLOAT, ConvertedType::NONE));
+
+ fields.push_back(schema::PrimitiveNode::Make("double_field", Repetition::OPTIONAL,
+ Type::DOUBLE, ConvertedType::NONE));
+
+ return std::static_pointer_cast<schema::GroupNode>(
+ schema::GroupNode::Make("schema", Repetition::REQUIRED, fields));
+ }
+
+ void createTestFile() {
+ PARQUET_ASSIGN_OR_THROW(auto outfile,
+ ::arrow::io::FileOutputStream::Open(GetDataFile()));
+
+ StreamWriter os{ParquetFileWriter::Open(outfile, GetSchema())};
+
+ TestData::init();
+
+ for (auto i = 0; i < TestData::num_rows; ++i) {
+ os << TestData::GetOptBool(i);
+ os << TestData::GetOptString(i);
+ os << TestData::GetOptChar(i);
+ os << TestData::GetOptCharArray(i);
+ os << TestData::GetOptInt8(i);
+ os << TestData::GetOptUInt16(i);
+ os << TestData::GetOptInt32(i);
+ os << TestData::GetOptUInt64(i);
+ os << TestData::GetOptChronoMicroseconds(i);
+ os << TestData::GetOptFloat(i);
+ os << TestData::GetOptDouble(i);
+ os << EndRow;
+ }
+ }
+
+ StreamReader reader_;
+};
+
+TEST_F(TestOptionalFields, ValueChecking) {
+ optional<bool> opt_bool;
+ optional<std::string> opt_string;
+ optional<std::array<char, 4>> opt_char_array;
+ optional<char> opt_char;
+ optional<int8_t> opt_int8;
+ optional<uint16_t> opt_uint16;
+ optional<int32_t> opt_int32;
+ optional<uint64_t> opt_uint64;
+ optional<std::chrono::microseconds> opt_ts_us;
+ optional<float> opt_float;
+ optional<double> opt_double;
+
+ int i;
+
+ for (i = 0; !reader_.eof(); ++i) {
+ EXPECT_EQ(i, reader_.current_row());
+
+ reader_ >> opt_bool;
+ reader_ >> opt_string;
+ reader_ >> opt_char;
+ reader_ >> opt_char_array;
+ reader_ >> opt_int8;
+ reader_ >> opt_uint16;
+ reader_ >> opt_int32;
+ reader_ >> opt_uint64;
+ reader_ >> opt_ts_us;
+ reader_ >> opt_float;
+ reader_ >> opt_double;
+ reader_ >> EndRow;
+
+ EXPECT_EQ(opt_bool, TestData::GetOptBool(i)) << "index: " << i;
+ EXPECT_EQ(opt_string, TestData::GetOptString(i)) << "index: " << i;
+ EXPECT_EQ(opt_char, TestData::GetOptChar(i)) << "index: " << i;
+ EXPECT_EQ(opt_char_array, TestData::GetOptCharArray(i)) << "index: " << i;
+ EXPECT_EQ(opt_int8, TestData::GetOptInt8(i)) << "index: " << i;
+ EXPECT_EQ(opt_uint16, TestData::GetOptUInt16(i)) << "index: " << i;
+ EXPECT_EQ(opt_int32, TestData::GetOptInt32(i)) << "index: " << i;
+ EXPECT_EQ(opt_uint64, TestData::GetOptUInt64(i)) << "index: " << i;
+ EXPECT_EQ(opt_ts_us, TestData::GetOptChronoMicroseconds(i)) << "index: " << i;
+ if (opt_float && TestData::GetOptFloat(i)) {
+ EXPECT_FLOAT_EQ(*opt_float, *TestData::GetOptFloat(i)) << "index: " << i;
+ } else {
+ EXPECT_EQ(opt_float, TestData::GetOptFloat(i)) << "index: " << i;
+ }
+ if (opt_double && TestData::GetOptDouble(i)) {
+ EXPECT_DOUBLE_EQ(*opt_double, *TestData::GetOptDouble(i)) << "index: " << i;
+ } else {
+ EXPECT_EQ(opt_double, TestData::GetOptDouble(i)) << "index: " << i;
+ }
+ }
+ EXPECT_EQ(reader_.current_row(), TestData::num_rows);
+ EXPECT_EQ(reader_.num_rows(), TestData::num_rows);
+ EXPECT_EQ(i, TestData::num_rows);
+}
+
+TEST_F(TestOptionalFields, ReadOptionalFieldAsRequiredField) {
+ /* Test that optional fields can be read using non-optional types
+ _provided_ that the optional value is available.
+
+ This can be useful if a schema is changed such that a required
+ field beomes optional. Applications can continue reading the
+ field as if it were mandatory and do not need to be changed if the
+ field value is always provided.
+
+ Of course if the optional value is not present, then the read will
+ fail by throwing an exception. This is also tested below.
+ */
+
+ bool b;
+ std::string s;
+ std::array<char, 4> char_array;
+ char c;
+ int8_t int8;
+ uint16_t uint16;
+ int32_t int32;
+ uint64_t uint64;
+ std::chrono::microseconds ts_us;
+ float f;
+ double d;
+ std::string str;
+
+ int i;
+
+ for (i = 0; !reader_.eof(); ++i) {
+ EXPECT_EQ(i, reader_.current_row());
+
+ if (TestData::GetOptBool(i)) {
+ reader_ >> b;
+ EXPECT_EQ(b, *TestData::GetOptBool(i)) << "index: " << i;
+ } else {
+ EXPECT_THROW(reader_ >> b, ParquetException) << "index: " << i;
+ }
+ if (TestData::GetOptString(i)) {
+ reader_ >> s;
+ EXPECT_EQ(s, *TestData::GetOptString(i)) << "index: " << i;
+ } else {
+ EXPECT_THROW(reader_ >> s, ParquetException) << "index: " << i;
+ }
+ if (TestData::GetOptChar(i)) {
+ reader_ >> c;
+ EXPECT_EQ(c, *TestData::GetOptChar(i)) << "index: " << i;
+ } else {
+ EXPECT_THROW(reader_ >> c, ParquetException) << "index: " << i;
+ }
+ if (TestData::GetOptCharArray(i)) {
+ reader_ >> char_array;
+ EXPECT_EQ(char_array, *TestData::GetOptCharArray(i)) << "index: " << i;
+ } else {
+ EXPECT_THROW(reader_ >> char_array, ParquetException) << "index: " << i;
+ }
+ if (TestData::GetOptInt8(i)) {
+ reader_ >> int8;
+ EXPECT_EQ(int8, *TestData::GetOptInt8(i)) << "index: " << i;
+ } else {
+ EXPECT_THROW(reader_ >> int8, ParquetException) << "index: " << i;
+ }
+ if (TestData::GetOptUInt16(i)) {
+ reader_ >> uint16;
+ EXPECT_EQ(uint16, *TestData::GetOptUInt16(i)) << "index: " << i;
+ } else {
+ EXPECT_THROW(reader_ >> uint16, ParquetException) << "index: " << i;
+ }
+ if (TestData::GetOptInt32(i)) {
+ reader_ >> int32;
+ EXPECT_EQ(int32, *TestData::GetOptInt32(i)) << "index: " << i;
+ } else {
+ EXPECT_THROW(reader_ >> int32, ParquetException) << "index: " << i;
+ }
+ if (TestData::GetOptUInt64(i)) {
+ reader_ >> uint64;
+ EXPECT_EQ(uint64, *TestData::GetOptUInt64(i)) << "index: " << i;
+ } else {
+ EXPECT_THROW(reader_ >> uint64, ParquetException) << "index: " << i;
+ }
+ if (TestData::GetOptChronoMicroseconds(i)) {
+ reader_ >> ts_us;
+ EXPECT_EQ(ts_us, *TestData::GetOptChronoMicroseconds(i)) << "index: " << i;
+ } else {
+ EXPECT_THROW(reader_ >> ts_us, ParquetException) << "index: " << i;
+ }
+ if (TestData::GetOptFloat(i)) {
+ reader_ >> f;
+ EXPECT_FLOAT_EQ(f, *TestData::GetOptFloat(i)) << "index: " << i;
+ } else {
+ EXPECT_THROW(reader_ >> f, ParquetException) << "index: " << i;
+ }
+ if (TestData::GetOptDouble(i)) {
+ reader_ >> d;
+ EXPECT_DOUBLE_EQ(d, *TestData::GetOptDouble(i)) << "index: " << i;
+ } else {
+ EXPECT_THROW(reader_ >> d, ParquetException) << "index: " << i;
+ }
+ reader_ >> EndRow;
+ }
+}
+
+class TestReadingDataFiles : public ::testing::Test {
+ protected:
+ std::string GetDataFile(const std::string& filename) const {
+ return std::string(get_data_dir()) + "/" + filename;
+ }
+};
+
+TEST_F(TestReadingDataFiles, AllTypesPlain) {
+ PARQUET_ASSIGN_OR_THROW(auto infile, ::arrow::io::ReadableFile::Open(
+ GetDataFile("alltypes_plain.parquet")));
+
+ auto file_reader = ParquetFileReader::Open(infile);
+ auto reader = StreamReader{std::move(file_reader)};
+
+ int32_t c0;
+ bool c1;
+ int32_t c2;
+ int32_t c3;
+ int32_t c4;
+ int64_t c5;
+ float c6;
+ double c7;
+ std::string c8;
+ std::string c9;
+
+ const char* expected_date_str[] = {"03/01/09", "03/01/09", "04/01/09", "04/01/09",
+ "02/01/09", "02/01/09", "01/01/09", "01/01/09"};
+ int i;
+
+ for (i = 0; !reader.eof(); ++i) {
+ reader >> c0 >> c1 >> c2 >> c3 >> c4 >> c5;
+ reader >> c6 >> c7;
+ reader >> c8 >> c9;
+ reader.SkipColumns(1); // Skip column with unsupported 96-bit type
+ reader >> EndRow;
+
+ EXPECT_EQ(c1, (i & 1) == 0);
+ EXPECT_EQ(c2, i & 1);
+ EXPECT_EQ(c3, i & 1);
+ EXPECT_EQ(c4, i & 1);
+ EXPECT_EQ(c5, i & 1 ? 10 : 0);
+ EXPECT_FLOAT_EQ(c6, i & 1 ? 1.1f : 0.f);
+ EXPECT_DOUBLE_EQ(c7, i & 1 ? 10.1 : 0.);
+ ASSERT_LT(static_cast<std::size_t>(i),
+ sizeof(expected_date_str) / sizeof(expected_date_str[0]));
+ EXPECT_EQ(c8, expected_date_str[i]);
+ EXPECT_EQ(c9, i & 1 ? "1" : "0");
+ }
+ EXPECT_EQ(i, sizeof(expected_date_str) / sizeof(expected_date_str[0]));
+}
+
+TEST_F(TestReadingDataFiles, Int32Decimal) {
+ PARQUET_ASSIGN_OR_THROW(
+ auto infile, ::arrow::io::ReadableFile::Open(GetDataFile("int32_decimal.parquet")));
+
+ auto file_reader = ParquetFileReader::Open(infile);
+ auto reader = StreamReader{std::move(file_reader)};
+
+ int32_t x;
+ int i;
+
+ for (i = 1; !reader.eof(); ++i) {
+ reader >> x >> EndRow;
+ EXPECT_EQ(x, i * 100);
+ }
+ EXPECT_EQ(i, 25);
+}
+
+TEST_F(TestReadingDataFiles, Int64Decimal) {
+ PARQUET_ASSIGN_OR_THROW(
+ auto infile, ::arrow::io::ReadableFile::Open(GetDataFile("int64_decimal.parquet")));
+
+ auto file_reader = ParquetFileReader::Open(infile);
+ auto reader = StreamReader{std::move(file_reader)};
+
+ int64_t x;
+ int i;
+
+ for (i = 1; !reader.eof(); ++i) {
+ reader >> x >> EndRow;
+ EXPECT_EQ(x, i * 100);
+ }
+ EXPECT_EQ(i, 25);
+}
+
+} // namespace test
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/stream_writer.cc b/src/arrow/cpp/src/parquet/stream_writer.cc
new file mode 100644
index 000000000..253ebf1bc
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/stream_writer.cc
@@ -0,0 +1,324 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/stream_writer.h"
+
+#include <utility>
+
+namespace parquet {
+
+int64_t StreamWriter::default_row_group_size_{512 * 1024 * 1024}; // 512MB
+
+constexpr int16_t StreamWriter::kDefLevelZero;
+constexpr int16_t StreamWriter::kDefLevelOne;
+constexpr int16_t StreamWriter::kRepLevelZero;
+constexpr int64_t StreamWriter::kBatchSizeOne;
+
+StreamWriter::FixedStringView::FixedStringView(const char* data_ptr)
+ : data{data_ptr}, size{std::strlen(data_ptr)} {}
+
+StreamWriter::FixedStringView::FixedStringView(const char* data_ptr, std::size_t data_len)
+ : data{data_ptr}, size{data_len} {}
+
+StreamWriter::StreamWriter(std::unique_ptr<ParquetFileWriter> writer)
+ : file_writer_{std::move(writer)},
+ row_group_writer_{file_writer_->AppendBufferedRowGroup()} {
+ auto schema = file_writer_->schema();
+ auto group_node = schema->group_node();
+
+ nodes_.resize(schema->num_columns());
+
+ for (auto i = 0; i < schema->num_columns(); ++i) {
+ nodes_[i] = std::static_pointer_cast<schema::PrimitiveNode>(group_node->field(i));
+ }
+}
+
+void StreamWriter::SetDefaultMaxRowGroupSize(int64_t max_size) {
+ default_row_group_size_ = max_size;
+}
+
+void StreamWriter::SetMaxRowGroupSize(int64_t max_size) {
+ max_row_group_size_ = max_size;
+}
+
+int StreamWriter::num_columns() const { return static_cast<int>(nodes_.size()); }
+
+StreamWriter& StreamWriter::operator<<(bool v) {
+ CheckColumn(Type::BOOLEAN, ConvertedType::NONE);
+ return Write<BoolWriter>(v);
+}
+
+StreamWriter& StreamWriter::operator<<(int8_t v) {
+ CheckColumn(Type::INT32, ConvertedType::INT_8);
+ return Write<Int32Writer>(static_cast<int32_t>(v));
+}
+
+StreamWriter& StreamWriter::operator<<(uint8_t v) {
+ CheckColumn(Type::INT32, ConvertedType::UINT_8);
+ return Write<Int32Writer>(static_cast<int32_t>(v));
+}
+
+StreamWriter& StreamWriter::operator<<(int16_t v) {
+ CheckColumn(Type::INT32, ConvertedType::INT_16);
+ return Write<Int32Writer>(static_cast<int32_t>(v));
+}
+
+StreamWriter& StreamWriter::operator<<(uint16_t v) {
+ CheckColumn(Type::INT32, ConvertedType::UINT_16);
+ return Write<Int32Writer>(static_cast<int32_t>(v));
+}
+
+StreamWriter& StreamWriter::operator<<(int32_t v) {
+ CheckColumn(Type::INT32, ConvertedType::INT_32);
+ return Write<Int32Writer>(v);
+}
+
+StreamWriter& StreamWriter::operator<<(uint32_t v) {
+ CheckColumn(Type::INT32, ConvertedType::UINT_32);
+ return Write<Int32Writer>(static_cast<int32_t>(v));
+}
+
+StreamWriter& StreamWriter::operator<<(int64_t v) {
+ CheckColumn(Type::INT64, ConvertedType::INT_64);
+ return Write<Int64Writer>(v);
+}
+
+StreamWriter& StreamWriter::operator<<(uint64_t v) {
+ CheckColumn(Type::INT64, ConvertedType::UINT_64);
+ return Write<Int64Writer>(static_cast<int64_t>(v));
+}
+
+StreamWriter& StreamWriter::operator<<(const std::chrono::milliseconds& v) {
+ CheckColumn(Type::INT64, ConvertedType::TIMESTAMP_MILLIS);
+ return Write<Int64Writer>(static_cast<int64_t>(v.count()));
+}
+
+StreamWriter& StreamWriter::operator<<(const std::chrono::microseconds& v) {
+ CheckColumn(Type::INT64, ConvertedType::TIMESTAMP_MICROS);
+ return Write<Int64Writer>(static_cast<int64_t>(v.count()));
+}
+
+StreamWriter& StreamWriter::operator<<(float v) {
+ CheckColumn(Type::FLOAT, ConvertedType::NONE);
+ return Write<FloatWriter>(v);
+}
+
+StreamWriter& StreamWriter::operator<<(double v) {
+ CheckColumn(Type::DOUBLE, ConvertedType::NONE);
+ return Write<DoubleWriter>(v);
+}
+
+StreamWriter& StreamWriter::operator<<(char v) { return WriteFixedLength(&v, 1); }
+
+StreamWriter& StreamWriter::operator<<(FixedStringView v) {
+ return WriteFixedLength(v.data, v.size);
+}
+
+StreamWriter& StreamWriter::operator<<(const char* v) {
+ return WriteVariableLength(v, std::strlen(v));
+}
+
+StreamWriter& StreamWriter::operator<<(const std::string& v) {
+ return WriteVariableLength(v.data(), v.size());
+}
+
+StreamWriter& StreamWriter::operator<<(::arrow::util::string_view v) {
+ return WriteVariableLength(v.data(), v.size());
+}
+
+StreamWriter& StreamWriter::WriteVariableLength(const char* data_ptr,
+ std::size_t data_len) {
+ CheckColumn(Type::BYTE_ARRAY, ConvertedType::UTF8);
+
+ auto writer = static_cast<ByteArrayWriter*>(row_group_writer_->column(column_index_++));
+
+ if (data_ptr != nullptr) {
+ ByteArray ba_value;
+
+ ba_value.ptr = reinterpret_cast<const uint8_t*>(data_ptr);
+ ba_value.len = static_cast<uint32_t>(data_len);
+
+ writer->WriteBatch(kBatchSizeOne, &kDefLevelOne, &kRepLevelZero, &ba_value);
+ } else {
+ writer->WriteBatch(kBatchSizeOne, &kDefLevelZero, &kRepLevelZero, nullptr);
+ }
+ if (max_row_group_size_ > 0) {
+ row_group_size_ += writer->EstimatedBufferedValueBytes();
+ }
+ return *this;
+}
+
+StreamWriter& StreamWriter::WriteFixedLength(const char* data_ptr, std::size_t data_len) {
+ CheckColumn(Type::FIXED_LEN_BYTE_ARRAY, ConvertedType::NONE,
+ static_cast<int>(data_len));
+
+ auto writer =
+ static_cast<FixedLenByteArrayWriter*>(row_group_writer_->column(column_index_++));
+
+ if (data_ptr != nullptr) {
+ FixedLenByteArray flba_value;
+
+ flba_value.ptr = reinterpret_cast<const uint8_t*>(data_ptr);
+ writer->WriteBatch(kBatchSizeOne, &kDefLevelOne, &kRepLevelZero, &flba_value);
+ } else {
+ writer->WriteBatch(kBatchSizeOne, &kDefLevelZero, &kRepLevelZero, nullptr);
+ }
+ if (max_row_group_size_ > 0) {
+ row_group_size_ += writer->EstimatedBufferedValueBytes();
+ }
+ return *this;
+}
+
+void StreamWriter::CheckColumn(Type::type physical_type,
+ ConvertedType::type converted_type, int length) {
+ if (static_cast<std::size_t>(column_index_) >= nodes_.size()) {
+ throw ParquetException("Column index out-of-bounds. Index " +
+ std::to_string(column_index_) + " is invalid for " +
+ std::to_string(nodes_.size()) + " columns");
+ }
+ const auto& node = nodes_[column_index_];
+
+ if (physical_type != node->physical_type()) {
+ throw ParquetException("Column physical type mismatch. Column '" + node->name() +
+ "' has physical type '" + TypeToString(node->physical_type()) +
+ "' not '" + TypeToString(physical_type) + "'");
+ }
+ if (converted_type != node->converted_type()) {
+ throw ParquetException("Column converted type mismatch. Column '" + node->name() +
+ "' has converted type[" +
+ ConvertedTypeToString(node->converted_type()) + "] not '" +
+ ConvertedTypeToString(converted_type) + "'");
+ }
+ // Length must be exact.
+ // A shorter length fixed array is not acceptable as it would
+ // result in array bound read errors.
+ //
+ if (length != node->type_length()) {
+ throw ParquetException("Column length mismatch. Column '" + node->name() +
+ "' has length " + std::to_string(node->type_length()) +
+ " not " + std::to_string(length));
+ }
+}
+
+int64_t StreamWriter::SkipColumns(int num_columns_to_skip) {
+ int num_columns_skipped = 0;
+
+ for (; (num_columns_to_skip > num_columns_skipped) &&
+ static_cast<std::size_t>(column_index_) < nodes_.size();
+ ++num_columns_skipped) {
+ const auto& node = nodes_[column_index_];
+
+ if (node->is_required()) {
+ throw ParquetException("Cannot skip column '" + node->name() +
+ "' as it is required.");
+ }
+ auto writer = row_group_writer_->column(column_index_++);
+
+ WriteNullValue(writer);
+ }
+ return num_columns_skipped;
+}
+
+void StreamWriter::WriteNullValue(ColumnWriter* writer) {
+ switch (writer->type()) {
+ case Type::BOOLEAN:
+ static_cast<BoolWriter*>(writer)->WriteBatch(kBatchSizeOne, &kDefLevelZero,
+ &kRepLevelZero, nullptr);
+ break;
+ case Type::INT32:
+ static_cast<Int32Writer*>(writer)->WriteBatch(kBatchSizeOne, &kDefLevelZero,
+ &kRepLevelZero, nullptr);
+ break;
+ case Type::INT64:
+ static_cast<Int64Writer*>(writer)->WriteBatch(kBatchSizeOne, &kDefLevelZero,
+ &kRepLevelZero, nullptr);
+ break;
+ case Type::BYTE_ARRAY:
+ static_cast<ByteArrayWriter*>(writer)->WriteBatch(kBatchSizeOne, &kDefLevelZero,
+ &kRepLevelZero, nullptr);
+ break;
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ static_cast<FixedLenByteArrayWriter*>(writer)->WriteBatch(
+ kBatchSizeOne, &kDefLevelZero, &kRepLevelZero, nullptr);
+ break;
+ case Type::FLOAT:
+ static_cast<FloatWriter*>(writer)->WriteBatch(kBatchSizeOne, &kDefLevelZero,
+ &kRepLevelZero, nullptr);
+ break;
+ case Type::DOUBLE:
+ static_cast<DoubleWriter*>(writer)->WriteBatch(kBatchSizeOne, &kDefLevelZero,
+ &kRepLevelZero, nullptr);
+ break;
+ case Type::INT96:
+ case Type::UNDEFINED:
+ throw ParquetException("Unexpected type: " + TypeToString(writer->type()));
+ break;
+ }
+}
+
+void StreamWriter::SkipOptionalColumn() {
+ if (SkipColumns(1) != 1) {
+ throw ParquetException("Failed to skip optional column at column index " +
+ std::to_string(column_index_));
+ }
+}
+
+void StreamWriter::EndRow() {
+ if (!file_writer_) {
+ throw ParquetException("StreamWriter not initialized");
+ }
+ if (static_cast<std::size_t>(column_index_) < nodes_.size()) {
+ throw ParquetException("Cannot end row with " + std::to_string(column_index_) +
+ " of " + std::to_string(nodes_.size()) + " columns written");
+ }
+ column_index_ = 0;
+ ++current_row_;
+
+ if (max_row_group_size_ > 0) {
+ if (row_group_size_ > max_row_group_size_) {
+ EndRowGroup();
+ }
+ // Initialize for each row with size already written
+ // (compressed + uncompressed).
+ //
+ row_group_size_ = row_group_writer_->total_bytes_written() +
+ row_group_writer_->total_compressed_bytes();
+ }
+}
+
+void StreamWriter::EndRowGroup() {
+ if (!file_writer_) {
+ throw ParquetException("StreamWriter not initialized");
+ }
+ // Avoid creating empty row groups.
+ if (row_group_writer_->num_rows() > 0) {
+ row_group_writer_->Close();
+ row_group_writer_.reset(file_writer_->AppendBufferedRowGroup());
+ }
+}
+
+StreamWriter& operator<<(StreamWriter& os, EndRowType) {
+ os.EndRow();
+ return os;
+}
+
+StreamWriter& operator<<(StreamWriter& os, EndRowGroupType) {
+ os.EndRowGroup();
+ return os;
+}
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/stream_writer.h b/src/arrow/cpp/src/parquet/stream_writer.h
new file mode 100644
index 000000000..d0db850c3
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/stream_writer.h
@@ -0,0 +1,243 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <array>
+#include <chrono>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/util/optional.h"
+#include "arrow/util/string_view.h"
+#include "parquet/column_writer.h"
+#include "parquet/file_writer.h"
+
+namespace parquet {
+
+/// \brief A class for writing Parquet files using an output stream type API.
+///
+/// The values given must be of the correct type i.e. the type must
+/// match the file schema exactly otherwise a ParquetException will be
+/// thrown.
+///
+/// The user must explicitly indicate the end of the row using the
+/// EndRow() function or EndRow output manipulator.
+///
+/// A maximum row group size can be configured, the default size is
+/// 512MB. Alternatively the row group size can be set to zero and the
+/// user can create new row groups by calling the EndRowGroup()
+/// function or using the EndRowGroup output manipulator.
+///
+/// Required and optional fields are supported:
+/// - Required fields are written using operator<<(T)
+/// - Optional fields are written using
+/// operator<<(arrow::util::optional<T>).
+///
+/// Note that operator<<(T) can be used to write optional fields.
+///
+/// Similarly, operator<<(arrow::util::optional<T>) can be used to
+/// write required fields. However if the optional parameter does not
+/// have a value (i.e. it is nullopt) then a ParquetException will be
+/// raised.
+///
+/// Currently there is no support for repeated fields.
+///
+class PARQUET_EXPORT StreamWriter {
+ public:
+ template <typename T>
+ using optional = ::arrow::util::optional<T>;
+
+ // N.B. Default constructed objects are not usable. This
+ // constructor is provided so that the object may be move
+ // assigned afterwards.
+ StreamWriter() = default;
+
+ explicit StreamWriter(std::unique_ptr<ParquetFileWriter> writer);
+
+ ~StreamWriter() = default;
+
+ static void SetDefaultMaxRowGroupSize(int64_t max_size);
+
+ void SetMaxRowGroupSize(int64_t max_size);
+
+ int current_column() const { return column_index_; }
+
+ int64_t current_row() const { return current_row_; }
+
+ int num_columns() const;
+
+ // Moving is possible.
+ StreamWriter(StreamWriter&&) = default;
+ StreamWriter& operator=(StreamWriter&&) = default;
+
+ // Copying is not allowed.
+ StreamWriter(const StreamWriter&) = delete;
+ StreamWriter& operator=(const StreamWriter&) = delete;
+
+ /// \brief Output operators for required fields.
+ /// These can also be used for optional fields when a value must be set.
+ StreamWriter& operator<<(bool v);
+
+ StreamWriter& operator<<(int8_t v);
+
+ StreamWriter& operator<<(uint8_t v);
+
+ StreamWriter& operator<<(int16_t v);
+
+ StreamWriter& operator<<(uint16_t v);
+
+ StreamWriter& operator<<(int32_t v);
+
+ StreamWriter& operator<<(uint32_t v);
+
+ StreamWriter& operator<<(int64_t v);
+
+ StreamWriter& operator<<(uint64_t v);
+
+ StreamWriter& operator<<(const std::chrono::milliseconds& v);
+
+ StreamWriter& operator<<(const std::chrono::microseconds& v);
+
+ StreamWriter& operator<<(float v);
+
+ StreamWriter& operator<<(double v);
+
+ StreamWriter& operator<<(char v);
+
+ /// \brief Helper class to write fixed length strings.
+ /// This is useful as the standard string view (such as
+ /// arrow::util::string_view) is for variable length data.
+ struct PARQUET_EXPORT FixedStringView {
+ FixedStringView() = default;
+
+ explicit FixedStringView(const char* data_ptr);
+
+ FixedStringView(const char* data_ptr, std::size_t data_len);
+
+ const char* data{NULLPTR};
+ std::size_t size{0};
+ };
+
+ /// \brief Output operators for fixed length strings.
+ template <int N>
+ StreamWriter& operator<<(const char (&v)[N]) {
+ return WriteFixedLength(v, N);
+ }
+ template <std::size_t N>
+ StreamWriter& operator<<(const std::array<char, N>& v) {
+ return WriteFixedLength(v.data(), N);
+ }
+ StreamWriter& operator<<(FixedStringView v);
+
+ /// \brief Output operators for variable length strings.
+ StreamWriter& operator<<(const char* v);
+ StreamWriter& operator<<(const std::string& v);
+ StreamWriter& operator<<(::arrow::util::string_view v);
+
+ /// \brief Output operator for optional fields.
+ template <typename T>
+ StreamWriter& operator<<(const optional<T>& v) {
+ if (v) {
+ return operator<<(*v);
+ }
+ SkipOptionalColumn();
+ return *this;
+ }
+
+ /// \brief Skip the next N columns of optional data. If there are
+ /// less than N columns remaining then the excess columns are
+ /// ignored.
+ /// \throws ParquetException if there is an attempt to skip any
+ /// required column.
+ /// \return Number of columns actually skipped.
+ int64_t SkipColumns(int num_columns_to_skip);
+
+ /// \brief Terminate the current row and advance to next one.
+ /// \throws ParquetException if all columns in the row were not
+ /// written or skipped.
+ void EndRow();
+
+ /// \brief Terminate the current row group and create new one.
+ void EndRowGroup();
+
+ protected:
+ template <typename WriterType, typename T>
+ StreamWriter& Write(const T v) {
+ auto writer = static_cast<WriterType*>(row_group_writer_->column(column_index_++));
+
+ writer->WriteBatch(kBatchSizeOne, &kDefLevelOne, &kRepLevelZero, &v);
+
+ if (max_row_group_size_ > 0) {
+ row_group_size_ += writer->EstimatedBufferedValueBytes();
+ }
+ return *this;
+ }
+
+ StreamWriter& WriteVariableLength(const char* data_ptr, std::size_t data_len);
+
+ StreamWriter& WriteFixedLength(const char* data_ptr, std::size_t data_len);
+
+ void CheckColumn(Type::type physical_type, ConvertedType::type converted_type,
+ int length = -1);
+
+ /// \brief Skip the next column which must be optional.
+ /// \throws ParquetException if the next column does not exist or is
+ /// not optional.
+ void SkipOptionalColumn();
+
+ void WriteNullValue(ColumnWriter* writer);
+
+ private:
+ using node_ptr_type = std::shared_ptr<schema::PrimitiveNode>;
+
+ struct null_deleter {
+ void operator()(void*) {}
+ };
+
+ int32_t column_index_{0};
+ int64_t current_row_{0};
+ int64_t row_group_size_{0};
+ int64_t max_row_group_size_{default_row_group_size_};
+
+ std::unique_ptr<ParquetFileWriter> file_writer_;
+ std::unique_ptr<RowGroupWriter, null_deleter> row_group_writer_;
+ std::vector<node_ptr_type> nodes_;
+
+ static constexpr int16_t kDefLevelZero = 0;
+ static constexpr int16_t kDefLevelOne = 1;
+ static constexpr int16_t kRepLevelZero = 0;
+ static constexpr int64_t kBatchSizeOne = 1;
+
+ static int64_t default_row_group_size_;
+};
+
+struct PARQUET_EXPORT EndRowType {};
+constexpr EndRowType EndRow = {};
+
+struct PARQUET_EXPORT EndRowGroupType {};
+constexpr EndRowGroupType EndRowGroup = {};
+
+PARQUET_EXPORT
+StreamWriter& operator<<(StreamWriter&, EndRowType);
+
+PARQUET_EXPORT
+StreamWriter& operator<<(StreamWriter&, EndRowGroupType);
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/stream_writer_test.cc b/src/arrow/cpp/src/parquet/stream_writer_test.cc
new file mode 100644
index 000000000..a36feb429
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/stream_writer_test.cc
@@ -0,0 +1,419 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "parquet/stream_writer.h"
+
+#include <fcntl.h>
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <utility>
+
+#include "arrow/io/file.h"
+#include "arrow/io/memory.h"
+#include "parquet/exception.h"
+
+namespace parquet {
+namespace test {
+
+template <typename T>
+using optional = StreamWriter::optional<T>;
+
+using char4_array_type = std::array<char, 4>;
+
+class TestStreamWriter : public ::testing::Test {
+ protected:
+ const char* GetDataFile() const { return "stream_writer_test.parquet"; }
+
+ void SetUp() {
+ writer_ = StreamWriter{ParquetFileWriter::Open(CreateOutputStream(), GetSchema())};
+ }
+
+ void TearDown() { writer_ = StreamWriter{}; }
+
+ std::shared_ptr<schema::GroupNode> GetSchema() {
+ schema::NodeVector fields;
+
+ fields.push_back(schema::PrimitiveNode::Make("bool_field", Repetition::REQUIRED,
+ Type::BOOLEAN, ConvertedType::NONE));
+
+ fields.push_back(schema::PrimitiveNode::Make("string_field", Repetition::REQUIRED,
+ Type::BYTE_ARRAY, ConvertedType::UTF8));
+
+ fields.push_back(schema::PrimitiveNode::Make("char_field", Repetition::REQUIRED,
+ Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::NONE, 1));
+
+ fields.push_back(schema::PrimitiveNode::Make("char[4]_field", Repetition::REQUIRED,
+ Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::NONE, 4));
+
+ fields.push_back(schema::PrimitiveNode::Make("int8_field", Repetition::REQUIRED,
+ Type::INT32, ConvertedType::INT_8));
+
+ fields.push_back(schema::PrimitiveNode::Make("uint16_field", Repetition::REQUIRED,
+ Type::INT32, ConvertedType::UINT_16));
+
+ fields.push_back(schema::PrimitiveNode::Make("int32_field", Repetition::REQUIRED,
+ Type::INT32, ConvertedType::INT_32));
+
+ fields.push_back(schema::PrimitiveNode::Make("uint64_field", Repetition::REQUIRED,
+ Type::INT64, ConvertedType::UINT_64));
+
+ fields.push_back(schema::PrimitiveNode::Make("float_field", Repetition::REQUIRED,
+ Type::FLOAT, ConvertedType::NONE));
+
+ fields.push_back(schema::PrimitiveNode::Make("double_field", Repetition::REQUIRED,
+ Type::DOUBLE, ConvertedType::NONE));
+
+ return std::static_pointer_cast<schema::GroupNode>(
+ schema::GroupNode::Make("schema", Repetition::REQUIRED, fields));
+ }
+
+ StreamWriter writer_;
+};
+
+TEST_F(TestStreamWriter, DefaultConstructed) {
+ StreamWriter os;
+
+ // Default constructor objects are not usable for writing data.
+ EXPECT_THROW(os << 4, ParquetException);
+ EXPECT_THROW(os << "bad", ParquetException);
+ EXPECT_THROW(os << EndRow, ParquetException);
+ EXPECT_THROW(os << EndRowGroup, ParquetException);
+
+ EXPECT_EQ(0, os.current_column());
+ EXPECT_EQ(0, os.current_row());
+ EXPECT_EQ(0, os.num_columns());
+ EXPECT_EQ(0, os.SkipColumns(10));
+}
+
+TEST_F(TestStreamWriter, TypeChecking) {
+ std::array<char, 3> char3_array = {'T', 'S', 'T'};
+ std::array<char, 4> char4_array = {'T', 'E', 'S', 'T'};
+ std::array<char, 5> char5_array = {'T', 'E', 'S', 'T', '2'};
+
+ // Required type: bool
+ EXPECT_EQ(0, writer_.current_column());
+ EXPECT_THROW(writer_ << 4.5, ParquetException);
+ EXPECT_NO_THROW(writer_ << true);
+
+ // Required type: Variable length string.
+ EXPECT_EQ(1, writer_.current_column());
+ EXPECT_THROW(writer_ << 5, ParquetException);
+ EXPECT_THROW(writer_ << char3_array, ParquetException);
+ EXPECT_THROW(writer_ << char4_array, ParquetException);
+ EXPECT_THROW(writer_ << char5_array, ParquetException);
+ EXPECT_NO_THROW(writer_ << "ok");
+
+ // Required type: A char.
+ EXPECT_EQ(2, writer_.current_column());
+ EXPECT_THROW(writer_ << "no good", ParquetException);
+ EXPECT_NO_THROW(writer_ << 'K');
+
+ // Required type: Fixed string of length 4
+ EXPECT_EQ(3, writer_.current_column());
+ EXPECT_THROW(writer_ << "bad", ParquetException);
+ EXPECT_THROW(writer_ << char3_array, ParquetException);
+ EXPECT_THROW(writer_ << char5_array, ParquetException);
+ EXPECT_NO_THROW(writer_ << char4_array);
+
+ // Required type: int8_t
+ EXPECT_EQ(4, writer_.current_column());
+ EXPECT_THROW(writer_ << false, ParquetException);
+ EXPECT_NO_THROW(writer_ << int8_t(51));
+
+ // Required type: uint16_t
+ EXPECT_EQ(5, writer_.current_column());
+ EXPECT_THROW(writer_ << int16_t(15), ParquetException);
+ EXPECT_NO_THROW(writer_ << uint16_t(15));
+
+ // Required type: int32_t
+ EXPECT_EQ(6, writer_.current_column());
+ EXPECT_THROW(writer_ << int16_t(99), ParquetException);
+ EXPECT_NO_THROW(writer_ << int32_t(329487));
+
+ // Required type: uint64_t
+ EXPECT_EQ(7, writer_.current_column());
+ EXPECT_THROW(writer_ << uint32_t(9832423), ParquetException);
+ EXPECT_NO_THROW(writer_ << uint64_t((1ull << 60) + 123));
+
+ // Required type: float
+ EXPECT_EQ(8, writer_.current_column());
+ EXPECT_THROW(writer_ << 5.4, ParquetException);
+ EXPECT_NO_THROW(writer_ << 5.4f);
+
+ // Required type: double
+ EXPECT_EQ(9, writer_.current_column());
+ EXPECT_THROW(writer_ << 5.4f, ParquetException);
+ EXPECT_NO_THROW(writer_ << 5.4);
+
+ EXPECT_EQ(0, writer_.current_row());
+ EXPECT_NO_THROW(writer_ << EndRow);
+ EXPECT_EQ(1, writer_.current_row());
+}
+
+TEST_F(TestStreamWriter, RequiredFieldChecking) {
+ char4_array_type char4_array = {'T', 'E', 'S', 'T'};
+
+ // Required field of type: bool
+ EXPECT_THROW(writer_ << optional<bool>(), ParquetException);
+ EXPECT_NO_THROW(writer_ << optional<bool>(true));
+
+ // Required field of type: Variable length string.
+ EXPECT_THROW(writer_ << optional<std::string>(), ParquetException);
+ EXPECT_NO_THROW(writer_ << optional<std::string>("ok"));
+
+ // Required field of type: A char.
+ EXPECT_THROW(writer_ << optional<char>(), ParquetException);
+ EXPECT_NO_THROW(writer_ << optional<char>('K'));
+
+ // Required field of type: Fixed string of length 4
+ EXPECT_THROW(writer_ << optional<char4_array_type>(), ParquetException);
+ EXPECT_NO_THROW(writer_ << optional<char4_array_type>(char4_array));
+
+ // Required field of type: int8_t
+ EXPECT_THROW(writer_ << optional<int8_t>(), ParquetException);
+ EXPECT_NO_THROW(writer_ << optional<int8_t>(51));
+
+ // Required field of type: uint16_t
+ EXPECT_THROW(writer_ << optional<uint16_t>(), ParquetException);
+ EXPECT_NO_THROW(writer_ << optional<uint16_t>(15));
+
+ // Required field of type: int32_t
+ EXPECT_THROW(writer_ << optional<int32_t>(), ParquetException);
+ EXPECT_NO_THROW(writer_ << optional<int32_t>(329487));
+
+ // Required field of type: uint64_t
+ EXPECT_THROW(writer_ << optional<uint64_t>(), ParquetException);
+ EXPECT_NO_THROW(writer_ << optional<uint64_t>((1ull << 60) + 123));
+
+ // Required field of type: float
+ EXPECT_THROW(writer_ << optional<float>(), ParquetException);
+ EXPECT_NO_THROW(writer_ << optional<float>(5.4f));
+
+ // Required field of type: double
+ EXPECT_THROW(writer_ << optional<double>(), ParquetException);
+ EXPECT_NO_THROW(writer_ << optional<double>(5.4));
+
+ EXPECT_NO_THROW(writer_ << EndRow);
+}
+
+TEST_F(TestStreamWriter, EndRow) {
+ // Attempt #1 to end row prematurely.
+ EXPECT_EQ(0, writer_.current_row());
+ EXPECT_THROW(writer_ << EndRow, ParquetException);
+ EXPECT_EQ(0, writer_.current_row());
+
+ EXPECT_NO_THROW(writer_ << true);
+ EXPECT_NO_THROW(writer_ << "eschatology");
+ EXPECT_NO_THROW(writer_ << 'z');
+ EXPECT_NO_THROW(writer_ << StreamWriter::FixedStringView("Test", 4));
+ EXPECT_NO_THROW(writer_ << int8_t(51));
+ EXPECT_NO_THROW(writer_ << uint16_t(15));
+
+ // Attempt #2 to end row prematurely.
+ EXPECT_THROW(writer_ << EndRow, ParquetException);
+ EXPECT_EQ(0, writer_.current_row());
+
+ EXPECT_NO_THROW(writer_ << int32_t(329487));
+ EXPECT_NO_THROW(writer_ << uint64_t((1ull << 60) + 123));
+ EXPECT_NO_THROW(writer_ << 25.4f);
+ EXPECT_NO_THROW(writer_ << 3.3424);
+ // Correct use of end row after all fields have been output.
+ EXPECT_NO_THROW(writer_ << EndRow);
+ EXPECT_EQ(1, writer_.current_row());
+
+ // Attempt #3 to end row prematurely.
+ EXPECT_THROW(writer_ << EndRow, ParquetException);
+ EXPECT_EQ(1, writer_.current_row());
+}
+
+TEST_F(TestStreamWriter, EndRowGroup) {
+ writer_.SetMaxRowGroupSize(0);
+
+ // It's ok to end a row group multiple times.
+ EXPECT_NO_THROW(writer_ << EndRowGroup);
+ EXPECT_NO_THROW(writer_ << EndRowGroup);
+ EXPECT_NO_THROW(writer_ << EndRowGroup);
+
+ std::array<char, 4> char_array = {'A', 'B', 'C', 'D'};
+
+ for (int i = 0; i < 20000; ++i) {
+ EXPECT_NO_THROW(writer_ << bool(i & 1)) << "index: " << i;
+ EXPECT_NO_THROW(writer_ << std::to_string(i)) << "index: " << i;
+ EXPECT_NO_THROW(writer_ << char(i % 26 + 'A')) << "index: " << i;
+ // Rotate letters.
+ {
+ char tmp{char_array[0]};
+ char_array[0] = char_array[3];
+ char_array[3] = char_array[2];
+ char_array[2] = char_array[1];
+ char_array[1] = tmp;
+ }
+ EXPECT_NO_THROW(writer_ << char_array) << "index: " << i;
+ EXPECT_NO_THROW(writer_ << int8_t(i & 0xff)) << "index: " << i;
+ EXPECT_NO_THROW(writer_ << uint16_t(7 * i)) << "index: " << i;
+ EXPECT_NO_THROW(writer_ << int32_t((1 << 30) - i * i)) << "index: " << i;
+ EXPECT_NO_THROW(writer_ << uint64_t((1ull << 60) - i * i)) << "index: " << i;
+ EXPECT_NO_THROW(writer_ << 42325.4f / float(i + 1)) << "index: " << i;
+ EXPECT_NO_THROW(writer_ << 3.2342e5 / double(i + 1)) << "index: " << i;
+ EXPECT_NO_THROW(writer_ << EndRow) << "index: " << i;
+
+ if (i % 1000 == 0) {
+ // It's ok to end a row group multiple times.
+ EXPECT_NO_THROW(writer_ << EndRowGroup);
+ EXPECT_NO_THROW(writer_ << EndRowGroup);
+ EXPECT_NO_THROW(writer_ << EndRowGroup);
+ }
+ }
+ // It's ok to end a row group multiple times.
+ EXPECT_NO_THROW(writer_ << EndRowGroup);
+ EXPECT_NO_THROW(writer_ << EndRowGroup);
+ EXPECT_NO_THROW(writer_ << EndRowGroup);
+}
+
+TEST_F(TestStreamWriter, SkipColumns) {
+ EXPECT_EQ(0, writer_.SkipColumns(0));
+ EXPECT_THROW(writer_.SkipColumns(2), ParquetException);
+ writer_ << true << std::string("Cannot skip mandatory columns");
+ EXPECT_THROW(writer_.SkipColumns(1), ParquetException);
+ writer_ << 'x' << std::array<char, 4>{'A', 'B', 'C', 'D'} << int8_t(2) << uint16_t(3)
+ << int32_t(4) << uint64_t(5) << 6.0f << 7.0;
+ writer_ << EndRow;
+}
+
+TEST_F(TestStreamWriter, AppendNotImplemented) {
+ PARQUET_ASSIGN_OR_THROW(auto outfile,
+ ::arrow::io::FileOutputStream::Open(GetDataFile()));
+
+ writer_ = StreamWriter{ParquetFileWriter::Open(outfile, GetSchema())};
+ writer_ << false << std::string("Just one row") << 'x'
+ << std::array<char, 4>{'A', 'B', 'C', 'D'} << int8_t(2) << uint16_t(3)
+ << int32_t(4) << uint64_t(5) << 6.0f << 7.0;
+ writer_ << EndRow;
+ writer_ = StreamWriter{};
+
+ // Re-open file in append mode.
+ PARQUET_ASSIGN_OR_THROW(outfile,
+ ::arrow::io::FileOutputStream::Open(GetDataFile(), true));
+
+ EXPECT_THROW(ParquetFileWriter::Open(outfile, GetSchema()), ParquetException);
+} // namespace test
+
+class TestOptionalFields : public ::testing::Test {
+ protected:
+ void SetUp() {
+ writer_ = StreamWriter{ParquetFileWriter::Open(CreateOutputStream(), GetSchema())};
+ }
+
+ void TearDown() { writer_ = StreamWriter{}; }
+
+ std::shared_ptr<schema::GroupNode> GetSchema() {
+ schema::NodeVector fields;
+
+ fields.push_back(schema::PrimitiveNode::Make("bool_field", Repetition::OPTIONAL,
+ Type::BOOLEAN, ConvertedType::NONE));
+
+ fields.push_back(schema::PrimitiveNode::Make("string_field", Repetition::OPTIONAL,
+ Type::BYTE_ARRAY, ConvertedType::UTF8));
+
+ fields.push_back(schema::PrimitiveNode::Make("char_field", Repetition::OPTIONAL,
+ Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::NONE, 1));
+
+ fields.push_back(schema::PrimitiveNode::Make("char[4]_field", Repetition::OPTIONAL,
+ Type::FIXED_LEN_BYTE_ARRAY,
+ ConvertedType::NONE, 4));
+
+ fields.push_back(schema::PrimitiveNode::Make("int8_field", Repetition::OPTIONAL,
+ Type::INT32, ConvertedType::INT_8));
+
+ fields.push_back(schema::PrimitiveNode::Make("uint16_field", Repetition::OPTIONAL,
+ Type::INT32, ConvertedType::UINT_16));
+
+ fields.push_back(schema::PrimitiveNode::Make("int32_field", Repetition::OPTIONAL,
+ Type::INT32, ConvertedType::INT_32));
+
+ fields.push_back(schema::PrimitiveNode::Make("uint64_field", Repetition::OPTIONAL,
+ Type::INT64, ConvertedType::UINT_64));
+
+ fields.push_back(schema::PrimitiveNode::Make("float_field", Repetition::OPTIONAL,
+ Type::FLOAT, ConvertedType::NONE));
+
+ fields.push_back(schema::PrimitiveNode::Make("double_field", Repetition::OPTIONAL,
+ Type::DOUBLE, ConvertedType::NONE));
+
+ return std::static_pointer_cast<schema::GroupNode>(
+ schema::GroupNode::Make("schema", Repetition::REQUIRED, fields));
+ }
+
+ StreamWriter writer_;
+};
+
+TEST_F(TestOptionalFields, OutputOperatorWithOptionalT) {
+ for (int i = 0; i < 100; ++i) {
+ // Write optional fields using operator<<(optional<T>). Writing
+ // of a value is skipped every 9 rows by using a optional<T>
+ // object without a value.
+
+ if (i % 9 == 0) {
+ writer_ << optional<bool>() << optional<std::string>() << optional<char>()
+ << optional<char4_array_type>() << optional<int8_t>()
+ << optional<uint16_t>() << optional<int32_t>() << optional<uint64_t>()
+ << optional<float>() << optional<double>();
+ } else {
+ writer_ << bool(i & 1) << optional<std::string>("#" + std::to_string(i))
+ << optional<char>('A' + i % 26)
+ << optional<char4_array_type>{{'F', 'O', 'O', 0}} << optional<int8_t>(i)
+ << optional<uint16_t>(0xffff - i) << optional<int32_t>(0x7fffffff - 3 * i)
+ << optional<uint64_t>((1ull << 60) + i) << optional<float>(5.4f * i)
+ << optional<double>(5.1322e6 * i);
+ }
+ EXPECT_NO_THROW(writer_ << EndRow);
+ }
+}
+
+TEST_F(TestOptionalFields, OutputOperatorTAndSkipColumns) {
+ constexpr int num_rows = 100;
+
+ EXPECT_EQ(0, writer_.current_row());
+
+ for (int i = 0; i < num_rows; ++i) {
+ // Write optional fields using standard operator<<(T). Writing of
+ // a value is skipped every 9 rows by using SkipColumns().
+
+ EXPECT_EQ(0, writer_.current_column());
+ EXPECT_EQ(i, writer_.current_row());
+
+ if (i % 9 == 0) {
+ EXPECT_EQ(writer_.num_columns(), writer_.SkipColumns(writer_.num_columns() + 99))
+ << "index: " << i;
+ } else {
+ writer_ << bool(i & 1) << std::string("ok") << char('A' + i % 26)
+ << char4_array_type{{'S', 'K', 'I', 'P'}} << int8_t(i)
+ << uint16_t(0xffff - i) << int32_t(0x7fffffff - 3 * i)
+ << uint64_t((1ull << 60) + i) << 5.4f * i << 5.1322e6 * i;
+ }
+ EXPECT_EQ(writer_.num_columns(), writer_.current_column()) << "index: " << i;
+ EXPECT_NO_THROW(writer_ << EndRow) << "index: " << i;
+ }
+ EXPECT_EQ(num_rows, writer_.current_row());
+}
+
+} // namespace test
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/symbols.map b/src/arrow/cpp/src/parquet/symbols.map
new file mode 100644
index 000000000..4bf032dd5
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/symbols.map
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+{
+ # Symbols marked as 'local' are not exported by the DSO and thus may not
+ # be used by client applications.
+ local:
+ # devtoolset / static-libstdc++ symbols
+ __cxa_*;
+ __once_proxy;
+
+ extern "C++" {
+ # boost
+ boost::*;
+
+ # thrift
+ apache::thrift::*;
+
+ # devtoolset or -static-libstdc++ - the Red Hat devtoolset statically
+ # links c++11 symbols into binaries so that the result may be executed on
+ # a system with an older libstdc++ which doesn't include the necessary
+ # c++11 symbols.
+ std::*;
+ *std::__once_call*;
+ };
+};
diff --git a/src/arrow/cpp/src/parquet/test_util.cc b/src/arrow/cpp/src/parquet/test_util.cc
new file mode 100644
index 000000000..9d104618b
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/test_util.cc
@@ -0,0 +1,136 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This module defines an abstract interface for iterating through pages in a
+// Parquet column chunk within a row group. It could be extended in the future
+// to iterate through all data pages in all chunks in a file.
+
+#include "parquet/test_util.h"
+
+#include <algorithm>
+#include <chrono>
+#include <limits>
+#include <memory>
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "parquet/column_page.h"
+#include "parquet/column_reader.h"
+#include "parquet/column_writer.h"
+#include "parquet/encoding.h"
+#include "parquet/platform.h"
+
+namespace parquet {
+namespace test {
+
+const char* get_data_dir() {
+ const auto result = std::getenv("PARQUET_TEST_DATA");
+ if (!result || !result[0]) {
+ throw ParquetTestException(
+ "Please point the PARQUET_TEST_DATA environment "
+ "variable to the test data directory");
+ }
+ return result;
+}
+
+std::string get_bad_data_dir() {
+ // PARQUET_TEST_DATA should point to ARROW_HOME/cpp/submodules/parquet-testing/data
+ // so need to reach one folder up to access the "bad_data" folder.
+ std::string data_dir(get_data_dir());
+ std::stringstream ss;
+ ss << data_dir << "/../bad_data";
+ return ss.str();
+}
+
+std::string get_data_file(const std::string& filename, bool is_good) {
+ std::stringstream ss;
+
+ if (is_good) {
+ ss << get_data_dir();
+ } else {
+ ss << get_bad_data_dir();
+ }
+
+ ss << "/" << filename;
+ return ss.str();
+}
+
+void random_bytes(int n, uint32_t seed, std::vector<uint8_t>* out) {
+ std::default_random_engine gen(seed);
+ std::uniform_int_distribution<int> d(0, 255);
+
+ out->resize(n);
+ for (int i = 0; i < n; ++i) {
+ (*out)[i] = static_cast<uint8_t>(d(gen));
+ }
+}
+
+void random_bools(int n, double p, uint32_t seed, bool* out) {
+ std::default_random_engine gen(seed);
+ std::bernoulli_distribution d(p);
+ for (int i = 0; i < n; ++i) {
+ out[i] = d(gen);
+ }
+}
+
+void random_Int96_numbers(int n, uint32_t seed, int32_t min_value, int32_t max_value,
+ Int96* out) {
+ std::default_random_engine gen(seed);
+ std::uniform_int_distribution<int32_t> d(min_value, max_value);
+ for (int i = 0; i < n; ++i) {
+ out[i].value[0] = d(gen);
+ out[i].value[1] = d(gen);
+ out[i].value[2] = d(gen);
+ }
+}
+
+void random_fixed_byte_array(int n, uint32_t seed, uint8_t* buf, int len, FLBA* out) {
+ std::default_random_engine gen(seed);
+ std::uniform_int_distribution<int> d(0, 255);
+ for (int i = 0; i < n; ++i) {
+ out[i].ptr = buf;
+ for (int j = 0; j < len; ++j) {
+ buf[j] = static_cast<uint8_t>(d(gen));
+ }
+ buf += len;
+ }
+}
+
+void random_byte_array(int n, uint32_t seed, uint8_t* buf, ByteArray* out, int min_size,
+ int max_size) {
+ std::default_random_engine gen(seed);
+ std::uniform_int_distribution<int> d1(min_size, max_size);
+ std::uniform_int_distribution<int> d2(0, 255);
+ for (int i = 0; i < n; ++i) {
+ int len = d1(gen);
+ out[i].len = len;
+ out[i].ptr = buf;
+ for (int j = 0; j < len; ++j) {
+ buf[j] = static_cast<uint8_t>(d2(gen));
+ }
+ buf += len;
+ }
+}
+
+void random_byte_array(int n, uint32_t seed, uint8_t* buf, ByteArray* out, int max_size) {
+ random_byte_array(n, seed, buf, out, 0, max_size);
+}
+
+} // namespace test
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/test_util.h b/src/arrow/cpp/src/parquet/test_util.h
new file mode 100644
index 000000000..d4e6de825
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/test_util.h
@@ -0,0 +1,715 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This module defines an abstract interface for iterating through pages in a
+// Parquet column chunk within a row group. It could be extended in the future
+// to iterate through all data pages in all chunks in a file.
+
+#pragma once
+
+#include <algorithm>
+#include <limits>
+#include <memory>
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/io/memory.h"
+#include "arrow/testing/util.h"
+
+#include "parquet/column_page.h"
+#include "parquet/column_reader.h"
+#include "parquet/column_writer.h"
+#include "parquet/encoding.h"
+#include "parquet/platform.h"
+
+namespace parquet {
+
+static constexpr int FLBA_LENGTH = 12;
+
+inline bool operator==(const FixedLenByteArray& a, const FixedLenByteArray& b) {
+ return 0 == memcmp(a.ptr, b.ptr, FLBA_LENGTH);
+}
+
+namespace test {
+
+typedef ::testing::Types<BooleanType, Int32Type, Int64Type, Int96Type, FloatType,
+ DoubleType, ByteArrayType, FLBAType>
+ ParquetTypes;
+
+class ParquetTestException : public parquet::ParquetException {
+ using ParquetException::ParquetException;
+};
+
+const char* get_data_dir();
+std::string get_bad_data_dir();
+
+std::string get_data_file(const std::string& filename, bool is_good = true);
+
+template <typename T>
+static inline void assert_vector_equal(const std::vector<T>& left,
+ const std::vector<T>& right) {
+ ASSERT_EQ(left.size(), right.size());
+
+ for (size_t i = 0; i < left.size(); ++i) {
+ ASSERT_EQ(left[i], right[i]) << i;
+ }
+}
+
+template <typename T>
+static inline bool vector_equal(const std::vector<T>& left, const std::vector<T>& right) {
+ if (left.size() != right.size()) {
+ return false;
+ }
+
+ for (size_t i = 0; i < left.size(); ++i) {
+ if (left[i] != right[i]) {
+ std::cerr << "index " << i << " left was " << left[i] << " right was " << right[i]
+ << std::endl;
+ return false;
+ }
+ }
+
+ return true;
+}
+
+template <typename T>
+static std::vector<T> slice(const std::vector<T>& values, int start, int end) {
+ if (end < start) {
+ return std::vector<T>(0);
+ }
+
+ std::vector<T> out(end - start);
+ for (int i = start; i < end; ++i) {
+ out[i - start] = values[i];
+ }
+ return out;
+}
+
+void random_bytes(int n, uint32_t seed, std::vector<uint8_t>* out);
+void random_bools(int n, double p, uint32_t seed, bool* out);
+
+template <typename T>
+inline void random_numbers(int n, uint32_t seed, T min_value, T max_value, T* out) {
+ std::default_random_engine gen(seed);
+ std::uniform_int_distribution<T> d(min_value, max_value);
+ for (int i = 0; i < n; ++i) {
+ out[i] = d(gen);
+ }
+}
+
+template <>
+inline void random_numbers(int n, uint32_t seed, float min_value, float max_value,
+ float* out) {
+ std::default_random_engine gen(seed);
+ std::uniform_real_distribution<float> d(min_value, max_value);
+ for (int i = 0; i < n; ++i) {
+ out[i] = d(gen);
+ }
+}
+
+template <>
+inline void random_numbers(int n, uint32_t seed, double min_value, double max_value,
+ double* out) {
+ std::default_random_engine gen(seed);
+ std::uniform_real_distribution<double> d(min_value, max_value);
+ for (int i = 0; i < n; ++i) {
+ out[i] = d(gen);
+ }
+}
+
+void random_Int96_numbers(int n, uint32_t seed, int32_t min_value, int32_t max_value,
+ Int96* out);
+
+void random_fixed_byte_array(int n, uint32_t seed, uint8_t* buf, int len, FLBA* out);
+
+void random_byte_array(int n, uint32_t seed, uint8_t* buf, ByteArray* out, int min_size,
+ int max_size);
+
+void random_byte_array(int n, uint32_t seed, uint8_t* buf, ByteArray* out, int max_size);
+
+template <typename Type, typename Sequence>
+std::shared_ptr<Buffer> EncodeValues(Encoding::type encoding, bool use_dictionary,
+ const Sequence& values, int length,
+ const ColumnDescriptor* descr) {
+ auto encoder = MakeTypedEncoder<Type>(encoding, use_dictionary, descr);
+ encoder->Put(values, length);
+ return encoder->FlushValues();
+}
+
+template <typename T>
+static void InitValues(int num_values, std::vector<T>& values,
+ std::vector<uint8_t>& buffer) {
+ random_numbers(num_values, 0, std::numeric_limits<T>::min(),
+ std::numeric_limits<T>::max(), values.data());
+}
+
+template <typename T>
+static void InitDictValues(int num_values, int num_dicts, std::vector<T>& values,
+ std::vector<uint8_t>& buffer) {
+ int repeat_factor = num_values / num_dicts;
+ InitValues<T>(num_dicts, values, buffer);
+ // add some repeated values
+ for (int j = 1; j < repeat_factor; ++j) {
+ for (int i = 0; i < num_dicts; ++i) {
+ std::memcpy(&values[num_dicts * j + i], &values[i], sizeof(T));
+ }
+ }
+ // computed only dict_per_page * repeat_factor - 1 values < num_values
+ // compute remaining
+ for (int i = num_dicts * repeat_factor; i < num_values; ++i) {
+ std::memcpy(&values[i], &values[i - num_dicts * repeat_factor], sizeof(T));
+ }
+}
+
+template <>
+inline void InitDictValues<bool>(int num_values, int num_dicts, std::vector<bool>& values,
+ std::vector<uint8_t>& buffer) {
+ // No op for bool
+}
+
+class MockPageReader : public PageReader {
+ public:
+ explicit MockPageReader(const std::vector<std::shared_ptr<Page>>& pages)
+ : pages_(pages), page_index_(0) {}
+
+ std::shared_ptr<Page> NextPage() override {
+ if (page_index_ == static_cast<int>(pages_.size())) {
+ // EOS to consumer
+ return std::shared_ptr<Page>(nullptr);
+ }
+ return pages_[page_index_++];
+ }
+
+ // No-op
+ void set_max_page_header_size(uint32_t size) override {}
+
+ private:
+ std::vector<std::shared_ptr<Page>> pages_;
+ int page_index_;
+};
+
+// TODO(wesm): this is only used for testing for now. Refactor to form part of
+// primary file write path
+template <typename Type>
+class DataPageBuilder {
+ public:
+ using c_type = typename Type::c_type;
+
+ // This class writes data and metadata to the passed inputs
+ explicit DataPageBuilder(ArrowOutputStream* sink)
+ : sink_(sink),
+ num_values_(0),
+ encoding_(Encoding::PLAIN),
+ definition_level_encoding_(Encoding::RLE),
+ repetition_level_encoding_(Encoding::RLE),
+ have_def_levels_(false),
+ have_rep_levels_(false),
+ have_values_(false) {}
+
+ void AppendDefLevels(const std::vector<int16_t>& levels, int16_t max_level,
+ Encoding::type encoding = Encoding::RLE) {
+ AppendLevels(levels, max_level, encoding);
+
+ num_values_ = std::max(static_cast<int32_t>(levels.size()), num_values_);
+ definition_level_encoding_ = encoding;
+ have_def_levels_ = true;
+ }
+
+ void AppendRepLevels(const std::vector<int16_t>& levels, int16_t max_level,
+ Encoding::type encoding = Encoding::RLE) {
+ AppendLevels(levels, max_level, encoding);
+
+ num_values_ = std::max(static_cast<int32_t>(levels.size()), num_values_);
+ repetition_level_encoding_ = encoding;
+ have_rep_levels_ = true;
+ }
+
+ void AppendValues(const ColumnDescriptor* d, const std::vector<c_type>& values,
+ Encoding::type encoding = Encoding::PLAIN) {
+ std::shared_ptr<Buffer> values_sink = EncodeValues<Type>(
+ encoding, false, values.data(), static_cast<int>(values.size()), d);
+ PARQUET_THROW_NOT_OK(sink_->Write(values_sink->data(), values_sink->size()));
+
+ num_values_ = std::max(static_cast<int32_t>(values.size()), num_values_);
+ encoding_ = encoding;
+ have_values_ = true;
+ }
+
+ int32_t num_values() const { return num_values_; }
+
+ Encoding::type encoding() const { return encoding_; }
+
+ Encoding::type rep_level_encoding() const { return repetition_level_encoding_; }
+
+ Encoding::type def_level_encoding() const { return definition_level_encoding_; }
+
+ private:
+ ArrowOutputStream* sink_;
+
+ int32_t num_values_;
+ Encoding::type encoding_;
+ Encoding::type definition_level_encoding_;
+ Encoding::type repetition_level_encoding_;
+
+ bool have_def_levels_;
+ bool have_rep_levels_;
+ bool have_values_;
+
+ // Used internally for both repetition and definition levels
+ void AppendLevels(const std::vector<int16_t>& levels, int16_t max_level,
+ Encoding::type encoding) {
+ if (encoding != Encoding::RLE) {
+ ParquetException::NYI("only rle encoding currently implemented");
+ }
+
+ // TODO: compute a more precise maximum size for the encoded levels
+ std::vector<uint8_t> encode_buffer(levels.size() * 2);
+
+ // We encode into separate memory from the output stream because the
+ // RLE-encoded bytes have to be preceded in the stream by their absolute
+ // size.
+ LevelEncoder encoder;
+ encoder.Init(encoding, max_level, static_cast<int>(levels.size()),
+ encode_buffer.data(), static_cast<int>(encode_buffer.size()));
+
+ encoder.Encode(static_cast<int>(levels.size()), levels.data());
+
+ int32_t rle_bytes = encoder.len();
+ PARQUET_THROW_NOT_OK(
+ sink_->Write(reinterpret_cast<const uint8_t*>(&rle_bytes), sizeof(int32_t)));
+ PARQUET_THROW_NOT_OK(sink_->Write(encode_buffer.data(), rle_bytes));
+ }
+};
+
+template <>
+inline void DataPageBuilder<BooleanType>::AppendValues(const ColumnDescriptor* d,
+ const std::vector<bool>& values,
+ Encoding::type encoding) {
+ if (encoding != Encoding::PLAIN) {
+ ParquetException::NYI("only plain encoding currently implemented");
+ }
+
+ auto encoder = MakeTypedEncoder<BooleanType>(Encoding::PLAIN, false, d);
+ dynamic_cast<BooleanEncoder*>(encoder.get())
+ ->Put(values, static_cast<int>(values.size()));
+ std::shared_ptr<Buffer> buffer = encoder->FlushValues();
+ PARQUET_THROW_NOT_OK(sink_->Write(buffer->data(), buffer->size()));
+
+ num_values_ = std::max(static_cast<int32_t>(values.size()), num_values_);
+ encoding_ = encoding;
+ have_values_ = true;
+}
+
+template <typename Type>
+static std::shared_ptr<DataPageV1> MakeDataPage(
+ const ColumnDescriptor* d, const std::vector<typename Type::c_type>& values,
+ int num_vals, Encoding::type encoding, const uint8_t* indices, int indices_size,
+ const std::vector<int16_t>& def_levels, int16_t max_def_level,
+ const std::vector<int16_t>& rep_levels, int16_t max_rep_level) {
+ int num_values = 0;
+
+ auto page_stream = CreateOutputStream();
+ test::DataPageBuilder<Type> page_builder(page_stream.get());
+
+ if (!rep_levels.empty()) {
+ page_builder.AppendRepLevels(rep_levels, max_rep_level);
+ }
+ if (!def_levels.empty()) {
+ page_builder.AppendDefLevels(def_levels, max_def_level);
+ }
+
+ if (encoding == Encoding::PLAIN) {
+ page_builder.AppendValues(d, values, encoding);
+ num_values = page_builder.num_values();
+ } else { // DICTIONARY PAGES
+ PARQUET_THROW_NOT_OK(page_stream->Write(indices, indices_size));
+ num_values = std::max(page_builder.num_values(), num_vals);
+ }
+
+ PARQUET_ASSIGN_OR_THROW(auto buffer, page_stream->Finish());
+
+ return std::make_shared<DataPageV1>(buffer, num_values, encoding,
+ page_builder.def_level_encoding(),
+ page_builder.rep_level_encoding(), buffer->size());
+}
+
+template <typename TYPE>
+class DictionaryPageBuilder {
+ public:
+ typedef typename TYPE::c_type TC;
+ static constexpr int TN = TYPE::type_num;
+ using SpecializedEncoder = typename EncodingTraits<TYPE>::Encoder;
+
+ // This class writes data and metadata to the passed inputs
+ explicit DictionaryPageBuilder(const ColumnDescriptor* d)
+ : num_dict_values_(0), have_values_(false) {
+ auto encoder = MakeTypedEncoder<TYPE>(Encoding::PLAIN, true, d);
+ dict_traits_ = dynamic_cast<DictEncoder<TYPE>*>(encoder.get());
+ encoder_.reset(dynamic_cast<SpecializedEncoder*>(encoder.release()));
+ }
+
+ ~DictionaryPageBuilder() {}
+
+ std::shared_ptr<Buffer> AppendValues(const std::vector<TC>& values) {
+ int num_values = static_cast<int>(values.size());
+ // Dictionary encoding
+ encoder_->Put(values.data(), num_values);
+ num_dict_values_ = dict_traits_->num_entries();
+ have_values_ = true;
+ return encoder_->FlushValues();
+ }
+
+ std::shared_ptr<Buffer> WriteDict() {
+ std::shared_ptr<Buffer> dict_buffer =
+ AllocateBuffer(::arrow::default_memory_pool(), dict_traits_->dict_encoded_size());
+ dict_traits_->WriteDict(dict_buffer->mutable_data());
+ return dict_buffer;
+ }
+
+ int32_t num_values() const { return num_dict_values_; }
+
+ private:
+ DictEncoder<TYPE>* dict_traits_;
+ std::unique_ptr<SpecializedEncoder> encoder_;
+ int32_t num_dict_values_;
+ bool have_values_;
+};
+
+template <>
+inline DictionaryPageBuilder<BooleanType>::DictionaryPageBuilder(
+ const ColumnDescriptor* d) {
+ ParquetException::NYI("only plain encoding currently implemented for boolean");
+}
+
+template <>
+inline std::shared_ptr<Buffer> DictionaryPageBuilder<BooleanType>::WriteDict() {
+ ParquetException::NYI("only plain encoding currently implemented for boolean");
+ return nullptr;
+}
+
+template <>
+inline std::shared_ptr<Buffer> DictionaryPageBuilder<BooleanType>::AppendValues(
+ const std::vector<TC>& values) {
+ ParquetException::NYI("only plain encoding currently implemented for boolean");
+ return nullptr;
+}
+
+template <typename Type>
+inline static std::shared_ptr<DictionaryPage> MakeDictPage(
+ const ColumnDescriptor* d, const std::vector<typename Type::c_type>& values,
+ const std::vector<int>& values_per_page, Encoding::type encoding,
+ std::vector<std::shared_ptr<Buffer>>& rle_indices) {
+ test::DictionaryPageBuilder<Type> page_builder(d);
+ int num_pages = static_cast<int>(values_per_page.size());
+ int value_start = 0;
+
+ for (int i = 0; i < num_pages; i++) {
+ rle_indices.push_back(page_builder.AppendValues(
+ slice(values, value_start, value_start + values_per_page[i])));
+ value_start += values_per_page[i];
+ }
+
+ auto buffer = page_builder.WriteDict();
+
+ return std::make_shared<DictionaryPage>(buffer, page_builder.num_values(),
+ Encoding::PLAIN);
+}
+
+// Given def/rep levels and values create multiple dict pages
+template <typename Type>
+inline static void PaginateDict(const ColumnDescriptor* d,
+ const std::vector<typename Type::c_type>& values,
+ const std::vector<int16_t>& def_levels,
+ int16_t max_def_level,
+ const std::vector<int16_t>& rep_levels,
+ int16_t max_rep_level, int num_levels_per_page,
+ const std::vector<int>& values_per_page,
+ std::vector<std::shared_ptr<Page>>& pages,
+ Encoding::type encoding = Encoding::RLE_DICTIONARY) {
+ int num_pages = static_cast<int>(values_per_page.size());
+ std::vector<std::shared_ptr<Buffer>> rle_indices;
+ std::shared_ptr<DictionaryPage> dict_page =
+ MakeDictPage<Type>(d, values, values_per_page, encoding, rle_indices);
+ pages.push_back(dict_page);
+ int def_level_start = 0;
+ int def_level_end = 0;
+ int rep_level_start = 0;
+ int rep_level_end = 0;
+ for (int i = 0; i < num_pages; i++) {
+ if (max_def_level > 0) {
+ def_level_start = i * num_levels_per_page;
+ def_level_end = (i + 1) * num_levels_per_page;
+ }
+ if (max_rep_level > 0) {
+ rep_level_start = i * num_levels_per_page;
+ rep_level_end = (i + 1) * num_levels_per_page;
+ }
+ std::shared_ptr<DataPageV1> data_page = MakeDataPage<Int32Type>(
+ d, {}, values_per_page[i], encoding, rle_indices[i]->data(),
+ static_cast<int>(rle_indices[i]->size()),
+ slice(def_levels, def_level_start, def_level_end), max_def_level,
+ slice(rep_levels, rep_level_start, rep_level_end), max_rep_level);
+ pages.push_back(data_page);
+ }
+}
+
+// Given def/rep levels and values create multiple plain pages
+template <typename Type>
+static inline void PaginatePlain(const ColumnDescriptor* d,
+ const std::vector<typename Type::c_type>& values,
+ const std::vector<int16_t>& def_levels,
+ int16_t max_def_level,
+ const std::vector<int16_t>& rep_levels,
+ int16_t max_rep_level, int num_levels_per_page,
+ const std::vector<int>& values_per_page,
+ std::vector<std::shared_ptr<Page>>& pages,
+ Encoding::type encoding = Encoding::PLAIN) {
+ int num_pages = static_cast<int>(values_per_page.size());
+ int def_level_start = 0;
+ int def_level_end = 0;
+ int rep_level_start = 0;
+ int rep_level_end = 0;
+ int value_start = 0;
+ for (int i = 0; i < num_pages; i++) {
+ if (max_def_level > 0) {
+ def_level_start = i * num_levels_per_page;
+ def_level_end = (i + 1) * num_levels_per_page;
+ }
+ if (max_rep_level > 0) {
+ rep_level_start = i * num_levels_per_page;
+ rep_level_end = (i + 1) * num_levels_per_page;
+ }
+ std::shared_ptr<DataPage> page = MakeDataPage<Type>(
+ d, slice(values, value_start, value_start + values_per_page[i]),
+ values_per_page[i], encoding, nullptr, 0,
+ slice(def_levels, def_level_start, def_level_end), max_def_level,
+ slice(rep_levels, rep_level_start, rep_level_end), max_rep_level);
+ pages.push_back(page);
+ value_start += values_per_page[i];
+ }
+}
+
+// Generates pages from randomly generated data
+template <typename Type>
+static inline int MakePages(const ColumnDescriptor* d, int num_pages, int levels_per_page,
+ std::vector<int16_t>& def_levels,
+ std::vector<int16_t>& rep_levels,
+ std::vector<typename Type::c_type>& values,
+ std::vector<uint8_t>& buffer,
+ std::vector<std::shared_ptr<Page>>& pages,
+ Encoding::type encoding = Encoding::PLAIN) {
+ int num_levels = levels_per_page * num_pages;
+ int num_values = 0;
+ uint32_t seed = 0;
+ int16_t zero = 0;
+ int16_t max_def_level = d->max_definition_level();
+ int16_t max_rep_level = d->max_repetition_level();
+ std::vector<int> values_per_page(num_pages, levels_per_page);
+ // Create definition levels
+ if (max_def_level > 0) {
+ def_levels.resize(num_levels);
+ random_numbers(num_levels, seed, zero, max_def_level, def_levels.data());
+ for (int p = 0; p < num_pages; p++) {
+ int num_values_per_page = 0;
+ for (int i = 0; i < levels_per_page; i++) {
+ if (def_levels[i + p * levels_per_page] == max_def_level) {
+ num_values_per_page++;
+ num_values++;
+ }
+ }
+ values_per_page[p] = num_values_per_page;
+ }
+ } else {
+ num_values = num_levels;
+ }
+ // Create repetition levels
+ if (max_rep_level > 0) {
+ rep_levels.resize(num_levels);
+ random_numbers(num_levels, seed, zero, max_rep_level, rep_levels.data());
+ }
+ // Create values
+ values.resize(num_values);
+ if (encoding == Encoding::PLAIN) {
+ InitValues<typename Type::c_type>(num_values, values, buffer);
+ PaginatePlain<Type>(d, values, def_levels, max_def_level, rep_levels, max_rep_level,
+ levels_per_page, values_per_page, pages);
+ } else if (encoding == Encoding::RLE_DICTIONARY ||
+ encoding == Encoding::PLAIN_DICTIONARY) {
+ // Calls InitValues and repeats the data
+ InitDictValues<typename Type::c_type>(num_values, levels_per_page, values, buffer);
+ PaginateDict<Type>(d, values, def_levels, max_def_level, rep_levels, max_rep_level,
+ levels_per_page, values_per_page, pages);
+ }
+
+ return num_values;
+}
+
+// ----------------------------------------------------------------------
+// Test data generation
+
+template <>
+void inline InitValues<bool>(int num_values, std::vector<bool>& values,
+ std::vector<uint8_t>& buffer) {
+ values = {};
+ ::arrow::random_is_valid(num_values, 0.5, &values,
+ static_cast<int>(::arrow::random_seed()));
+}
+
+template <>
+inline void InitValues<ByteArray>(int num_values, std::vector<ByteArray>& values,
+ std::vector<uint8_t>& buffer) {
+ int max_byte_array_len = 12;
+ int num_bytes = static_cast<int>(max_byte_array_len + sizeof(uint32_t));
+ size_t nbytes = num_values * num_bytes;
+ buffer.resize(nbytes);
+ random_byte_array(num_values, 0, buffer.data(), values.data(), max_byte_array_len);
+}
+
+inline void InitWideByteArrayValues(int num_values, std::vector<ByteArray>& values,
+ std::vector<uint8_t>& buffer, int min_len,
+ int max_len) {
+ int num_bytes = static_cast<int>(max_len + sizeof(uint32_t));
+ size_t nbytes = num_values * num_bytes;
+ buffer.resize(nbytes);
+ random_byte_array(num_values, 0, buffer.data(), values.data(), min_len, max_len);
+}
+
+template <>
+inline void InitValues<FLBA>(int num_values, std::vector<FLBA>& values,
+ std::vector<uint8_t>& buffer) {
+ size_t nbytes = num_values * FLBA_LENGTH;
+ buffer.resize(nbytes);
+ random_fixed_byte_array(num_values, 0, buffer.data(), FLBA_LENGTH, values.data());
+}
+
+template <>
+inline void InitValues<Int96>(int num_values, std::vector<Int96>& values,
+ std::vector<uint8_t>& buffer) {
+ random_Int96_numbers(num_values, 0, std::numeric_limits<int32_t>::min(),
+ std::numeric_limits<int32_t>::max(), values.data());
+}
+
+inline std::string TestColumnName(int i) {
+ std::stringstream col_name;
+ col_name << "column_" << i;
+ return col_name.str();
+}
+
+// This class lives here because of its dependency on the InitValues specializations.
+template <typename TestType>
+class PrimitiveTypedTest : public ::testing::Test {
+ public:
+ using c_type = typename TestType::c_type;
+
+ void SetUpSchema(Repetition::type repetition, int num_columns = 1) {
+ std::vector<schema::NodePtr> fields;
+
+ for (int i = 0; i < num_columns; ++i) {
+ std::string name = TestColumnName(i);
+ fields.push_back(schema::PrimitiveNode::Make(name, repetition, TestType::type_num,
+ ConvertedType::NONE, FLBA_LENGTH));
+ }
+ node_ = schema::GroupNode::Make("schema", Repetition::REQUIRED, fields);
+ schema_.Init(node_);
+ }
+
+ void GenerateData(int64_t num_values);
+ void SetupValuesOut(int64_t num_values);
+ void SyncValuesOut();
+
+ protected:
+ schema::NodePtr node_;
+ SchemaDescriptor schema_;
+
+ // Input buffers
+ std::vector<c_type> values_;
+
+ std::vector<int16_t> def_levels_;
+
+ std::vector<uint8_t> buffer_;
+ // Pointer to the values, needed as we cannot use std::vector<bool>::data()
+ c_type* values_ptr_;
+ std::vector<uint8_t> bool_buffer_;
+
+ // Output buffers
+ std::vector<c_type> values_out_;
+ std::vector<uint8_t> bool_buffer_out_;
+ c_type* values_out_ptr_;
+};
+
+template <typename TestType>
+inline void PrimitiveTypedTest<TestType>::SyncValuesOut() {}
+
+template <>
+inline void PrimitiveTypedTest<BooleanType>::SyncValuesOut() {
+ std::vector<uint8_t>::const_iterator source_iterator = bool_buffer_out_.begin();
+ std::vector<c_type>::iterator destination_iterator = values_out_.begin();
+ while (source_iterator != bool_buffer_out_.end()) {
+ *destination_iterator++ = *source_iterator++ != 0;
+ }
+}
+
+template <typename TestType>
+inline void PrimitiveTypedTest<TestType>::SetupValuesOut(int64_t num_values) {
+ values_out_.clear();
+ values_out_.resize(num_values);
+ values_out_ptr_ = values_out_.data();
+}
+
+template <>
+inline void PrimitiveTypedTest<BooleanType>::SetupValuesOut(int64_t num_values) {
+ values_out_.clear();
+ values_out_.resize(num_values);
+
+ bool_buffer_out_.clear();
+ bool_buffer_out_.resize(num_values);
+ // Write once to all values so we can copy it without getting Valgrind errors
+ // about uninitialised values.
+ std::fill(bool_buffer_out_.begin(), bool_buffer_out_.end(), true);
+ values_out_ptr_ = reinterpret_cast<bool*>(bool_buffer_out_.data());
+}
+
+template <typename TestType>
+inline void PrimitiveTypedTest<TestType>::GenerateData(int64_t num_values) {
+ def_levels_.resize(num_values);
+ values_.resize(num_values);
+
+ InitValues<c_type>(static_cast<int>(num_values), values_, buffer_);
+ values_ptr_ = values_.data();
+
+ std::fill(def_levels_.begin(), def_levels_.end(), 1);
+}
+
+template <>
+inline void PrimitiveTypedTest<BooleanType>::GenerateData(int64_t num_values) {
+ def_levels_.resize(num_values);
+ values_.resize(num_values);
+
+ InitValues<c_type>(static_cast<int>(num_values), values_, buffer_);
+ bool_buffer_.resize(num_values);
+ std::copy(values_.begin(), values_.end(), bool_buffer_.begin());
+ values_ptr_ = reinterpret_cast<bool*>(bool_buffer_.data());
+
+ std::fill(def_levels_.begin(), def_levels_.end(), 1);
+}
+
+} // namespace test
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/thrift_internal.h b/src/arrow/cpp/src/parquet/thrift_internal.h
new file mode 100644
index 000000000..99bd39c65
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/thrift_internal.h
@@ -0,0 +1,509 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/util/windows_compatibility.h"
+
+#include <cstdint>
+#include <limits>
+
+// Check if thrift version < 0.11.0
+// or if FORCE_BOOST_SMART_PTR is defined. Ref: https://thrift.apache.org/lib/cpp
+#if defined(PARQUET_THRIFT_USE_BOOST) || defined(FORCE_BOOST_SMART_PTR)
+#include <boost/shared_ptr.hpp>
+#else
+#include <memory>
+#endif
+#include <sstream>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+// TCompactProtocol requires some #defines to work right.
+#define SIGNED_RIGHT_SHIFT_IS 1
+#define ARITHMETIC_RIGHT_SHIFT 1
+#include <thrift/TApplicationException.h>
+#include <thrift/protocol/TCompactProtocol.h>
+#include <thrift/transport/TBufferTransports.h>
+
+#include "arrow/util/logging.h"
+
+#include "parquet/encryption/internal_file_decryptor.h"
+#include "parquet/encryption/internal_file_encryptor.h"
+#include "parquet/exception.h"
+#include "parquet/platform.h"
+#include "parquet/statistics.h"
+#include "parquet/types.h"
+
+#include "generated/parquet_types.h" // IYWU pragma: export
+
+namespace parquet {
+
+// Check if thrift version < 0.11.0
+// or if FORCE_BOOST_SMART_PTR is defined. Ref: https://thrift.apache.org/lib/cpp
+#if defined(PARQUET_THRIFT_USE_BOOST) || defined(FORCE_BOOST_SMART_PTR)
+using ::boost::shared_ptr;
+#else
+using ::std::shared_ptr;
+#endif
+
+// ----------------------------------------------------------------------
+// Convert Thrift enums to Parquet enums
+
+// Unsafe enum converters (input is not checked for validity)
+
+static inline Type::type FromThriftUnsafe(format::Type::type type) {
+ return static_cast<Type::type>(type);
+}
+
+static inline ConvertedType::type FromThriftUnsafe(format::ConvertedType::type type) {
+ // item 0 is NONE
+ return static_cast<ConvertedType::type>(static_cast<int>(type) + 1);
+}
+
+static inline Repetition::type FromThriftUnsafe(format::FieldRepetitionType::type type) {
+ return static_cast<Repetition::type>(type);
+}
+
+static inline Encoding::type FromThriftUnsafe(format::Encoding::type type) {
+ return static_cast<Encoding::type>(type);
+}
+
+static inline PageType::type FromThriftUnsafe(format::PageType::type type) {
+ return static_cast<PageType::type>(type);
+}
+
+static inline Compression::type FromThriftUnsafe(format::CompressionCodec::type type) {
+ switch (type) {
+ case format::CompressionCodec::UNCOMPRESSED:
+ return Compression::UNCOMPRESSED;
+ case format::CompressionCodec::SNAPPY:
+ return Compression::SNAPPY;
+ case format::CompressionCodec::GZIP:
+ return Compression::GZIP;
+ case format::CompressionCodec::LZO:
+ return Compression::LZO;
+ case format::CompressionCodec::BROTLI:
+ return Compression::BROTLI;
+ case format::CompressionCodec::LZ4:
+ return Compression::LZ4_HADOOP;
+ case format::CompressionCodec::LZ4_RAW:
+ return Compression::LZ4;
+ case format::CompressionCodec::ZSTD:
+ return Compression::ZSTD;
+ default:
+ DCHECK(false) << "Cannot reach here";
+ return Compression::UNCOMPRESSED;
+ }
+}
+
+namespace internal {
+
+template <typename T>
+struct ThriftEnumTypeTraits {};
+
+template <>
+struct ThriftEnumTypeTraits<::parquet::format::Type::type> {
+ using ParquetEnum = Type;
+};
+
+template <>
+struct ThriftEnumTypeTraits<::parquet::format::ConvertedType::type> {
+ using ParquetEnum = ConvertedType;
+};
+
+template <>
+struct ThriftEnumTypeTraits<::parquet::format::FieldRepetitionType::type> {
+ using ParquetEnum = Repetition;
+};
+
+template <>
+struct ThriftEnumTypeTraits<::parquet::format::Encoding::type> {
+ using ParquetEnum = Encoding;
+};
+
+template <>
+struct ThriftEnumTypeTraits<::parquet::format::PageType::type> {
+ using ParquetEnum = PageType;
+};
+
+// If the parquet file is corrupted it is possible the enum value decoded
+// will not be in the range of defined values, which is undefined behaviour.
+// This facility prevents this by loading the value as the underlying type
+// and checking to make sure it is in range.
+
+template <typename EnumType,
+ typename EnumTypeRaw = typename std::underlying_type<EnumType>::type>
+inline static EnumTypeRaw LoadEnumRaw(const EnumType* in) {
+ EnumTypeRaw raw_value;
+ // Use memcpy(), as a regular cast would be undefined behaviour on invalid values
+ memcpy(&raw_value, in, sizeof(EnumType));
+ return raw_value;
+}
+
+template <typename ApiType>
+struct SafeLoader {
+ using ApiTypeEnum = typename ApiType::type;
+ using ApiTypeRawEnum = typename std::underlying_type<ApiTypeEnum>::type;
+
+ template <typename ThriftType>
+ inline static ApiTypeRawEnum LoadRaw(const ThriftType* in) {
+ static_assert(sizeof(ApiTypeEnum) == sizeof(ThriftType),
+ "parquet type should always be the same size as thrift type");
+ return static_cast<ApiTypeRawEnum>(LoadEnumRaw(in));
+ }
+
+ template <typename ThriftType, bool IsUnsigned = true>
+ inline static ApiTypeEnum LoadChecked(
+ const typename std::enable_if<IsUnsigned, ThriftType>::type* in) {
+ auto raw_value = LoadRaw(in);
+ if (ARROW_PREDICT_FALSE(raw_value >=
+ static_cast<ApiTypeRawEnum>(ApiType::UNDEFINED))) {
+ return ApiType::UNDEFINED;
+ }
+ return FromThriftUnsafe(static_cast<ThriftType>(raw_value));
+ }
+
+ template <typename ThriftType, bool IsUnsigned = false>
+ inline static ApiTypeEnum LoadChecked(
+ const typename std::enable_if<!IsUnsigned, ThriftType>::type* in) {
+ auto raw_value = LoadRaw(in);
+ if (ARROW_PREDICT_FALSE(raw_value >=
+ static_cast<ApiTypeRawEnum>(ApiType::UNDEFINED) ||
+ raw_value < 0)) {
+ return ApiType::UNDEFINED;
+ }
+ return FromThriftUnsafe(static_cast<ThriftType>(raw_value));
+ }
+
+ template <typename ThriftType>
+ inline static ApiTypeEnum Load(const ThriftType* in) {
+ return LoadChecked<ThriftType, std::is_unsigned<ApiTypeRawEnum>::value>(in);
+ }
+};
+
+} // namespace internal
+
+// Safe enum loader: will check for invalid enum value before converting
+
+template <typename ThriftType,
+ typename ParquetEnum =
+ typename internal::ThriftEnumTypeTraits<ThriftType>::ParquetEnum>
+inline typename ParquetEnum::type LoadEnumSafe(const ThriftType* in) {
+ return internal::SafeLoader<ParquetEnum>::Load(in);
+}
+
+inline typename Compression::type LoadEnumSafe(const format::CompressionCodec::type* in) {
+ const auto raw_value = internal::LoadEnumRaw(in);
+ // Check bounds manually, as Compression::type doesn't have the same values
+ // as format::CompressionCodec.
+ const auto min_value =
+ static_cast<decltype(raw_value)>(format::CompressionCodec::UNCOMPRESSED);
+ const auto max_value =
+ static_cast<decltype(raw_value)>(format::CompressionCodec::LZ4_RAW);
+ if (raw_value < min_value || raw_value > max_value) {
+ return Compression::UNCOMPRESSED;
+ }
+ return FromThriftUnsafe(*in);
+}
+
+// Safe non-enum converters
+
+static inline AadMetadata FromThrift(format::AesGcmV1 aesGcmV1) {
+ return AadMetadata{aesGcmV1.aad_prefix, aesGcmV1.aad_file_unique,
+ aesGcmV1.supply_aad_prefix};
+}
+
+static inline AadMetadata FromThrift(format::AesGcmCtrV1 aesGcmCtrV1) {
+ return AadMetadata{aesGcmCtrV1.aad_prefix, aesGcmCtrV1.aad_file_unique,
+ aesGcmCtrV1.supply_aad_prefix};
+}
+
+static inline EncryptionAlgorithm FromThrift(format::EncryptionAlgorithm encryption) {
+ EncryptionAlgorithm encryption_algorithm;
+
+ if (encryption.__isset.AES_GCM_V1) {
+ encryption_algorithm.algorithm = ParquetCipher::AES_GCM_V1;
+ encryption_algorithm.aad = FromThrift(encryption.AES_GCM_V1);
+ } else if (encryption.__isset.AES_GCM_CTR_V1) {
+ encryption_algorithm.algorithm = ParquetCipher::AES_GCM_CTR_V1;
+ encryption_algorithm.aad = FromThrift(encryption.AES_GCM_CTR_V1);
+ } else {
+ throw ParquetException("Unsupported algorithm");
+ }
+ return encryption_algorithm;
+}
+
+// ----------------------------------------------------------------------
+// Convert Thrift enums from Parquet enums
+
+static inline format::Type::type ToThrift(Type::type type) {
+ return static_cast<format::Type::type>(type);
+}
+
+static inline format::ConvertedType::type ToThrift(ConvertedType::type type) {
+ // item 0 is NONE
+ DCHECK_NE(type, ConvertedType::NONE);
+ // it is forbidden to emit "NA" (PARQUET-1990)
+ DCHECK_NE(type, ConvertedType::NA);
+ DCHECK_NE(type, ConvertedType::UNDEFINED);
+ return static_cast<format::ConvertedType::type>(static_cast<int>(type) - 1);
+}
+
+static inline format::FieldRepetitionType::type ToThrift(Repetition::type type) {
+ return static_cast<format::FieldRepetitionType::type>(type);
+}
+
+static inline format::Encoding::type ToThrift(Encoding::type type) {
+ return static_cast<format::Encoding::type>(type);
+}
+
+static inline format::CompressionCodec::type ToThrift(Compression::type type) {
+ switch (type) {
+ case Compression::UNCOMPRESSED:
+ return format::CompressionCodec::UNCOMPRESSED;
+ case Compression::SNAPPY:
+ return format::CompressionCodec::SNAPPY;
+ case Compression::GZIP:
+ return format::CompressionCodec::GZIP;
+ case Compression::LZO:
+ return format::CompressionCodec::LZO;
+ case Compression::BROTLI:
+ return format::CompressionCodec::BROTLI;
+ case Compression::LZ4:
+ return format::CompressionCodec::LZ4_RAW;
+ case Compression::LZ4_HADOOP:
+ // Deprecated "LZ4" Parquet compression has Hadoop-specific framing
+ return format::CompressionCodec::LZ4;
+ case Compression::ZSTD:
+ return format::CompressionCodec::ZSTD;
+ default:
+ DCHECK(false) << "Cannot reach here";
+ return format::CompressionCodec::UNCOMPRESSED;
+ }
+}
+
+static inline format::Statistics ToThrift(const EncodedStatistics& stats) {
+ format::Statistics statistics;
+ if (stats.has_min) {
+ statistics.__set_min_value(stats.min());
+ // If the order is SIGNED, then the old min value must be set too.
+ // This for backward compatibility
+ if (stats.is_signed()) {
+ statistics.__set_min(stats.min());
+ }
+ }
+ if (stats.has_max) {
+ statistics.__set_max_value(stats.max());
+ // If the order is SIGNED, then the old max value must be set too.
+ // This for backward compatibility
+ if (stats.is_signed()) {
+ statistics.__set_max(stats.max());
+ }
+ }
+ if (stats.has_null_count) {
+ statistics.__set_null_count(stats.null_count);
+ }
+ if (stats.has_distinct_count) {
+ statistics.__set_distinct_count(stats.distinct_count);
+ }
+
+ return statistics;
+}
+
+static inline format::AesGcmV1 ToAesGcmV1Thrift(AadMetadata aad) {
+ format::AesGcmV1 aesGcmV1;
+ // aad_file_unique is always set
+ aesGcmV1.__set_aad_file_unique(aad.aad_file_unique);
+ aesGcmV1.__set_supply_aad_prefix(aad.supply_aad_prefix);
+ if (!aad.aad_prefix.empty()) {
+ aesGcmV1.__set_aad_prefix(aad.aad_prefix);
+ }
+ return aesGcmV1;
+}
+
+static inline format::AesGcmCtrV1 ToAesGcmCtrV1Thrift(AadMetadata aad) {
+ format::AesGcmCtrV1 aesGcmCtrV1;
+ // aad_file_unique is always set
+ aesGcmCtrV1.__set_aad_file_unique(aad.aad_file_unique);
+ aesGcmCtrV1.__set_supply_aad_prefix(aad.supply_aad_prefix);
+ if (!aad.aad_prefix.empty()) {
+ aesGcmCtrV1.__set_aad_prefix(aad.aad_prefix);
+ }
+ return aesGcmCtrV1;
+}
+
+static inline format::EncryptionAlgorithm ToThrift(EncryptionAlgorithm encryption) {
+ format::EncryptionAlgorithm encryption_algorithm;
+ if (encryption.algorithm == ParquetCipher::AES_GCM_V1) {
+ encryption_algorithm.__set_AES_GCM_V1(ToAesGcmV1Thrift(encryption.aad));
+ } else {
+ encryption_algorithm.__set_AES_GCM_CTR_V1(ToAesGcmCtrV1Thrift(encryption.aad));
+ }
+ return encryption_algorithm;
+}
+
+// ----------------------------------------------------------------------
+// Thrift struct serialization / deserialization utilities
+
+using ThriftBuffer = apache::thrift::transport::TMemoryBuffer;
+
+// On Thrift 0.14.0+, we want to use TConfiguration to raise the max message size
+// limit (ARROW-13655). If we wanted to protect against huge messages, we could
+// do it ourselves since we know the message size up front.
+
+inline std::shared_ptr<ThriftBuffer> CreateReadOnlyMemoryBuffer(uint8_t* buf,
+ uint32_t len) {
+#if PARQUET_THRIFT_VERSION_MAJOR > 0 || PARQUET_THRIFT_VERSION_MINOR >= 14
+ auto conf = std::make_shared<apache::thrift::TConfiguration>();
+ conf->setMaxMessageSize(std::numeric_limits<int>::max());
+ return std::make_shared<ThriftBuffer>(buf, len, ThriftBuffer::OBSERVE, conf);
+#else
+ return std::make_shared<ThriftBuffer>(buf, len);
+#endif
+}
+
+template <class T>
+inline void DeserializeThriftUnencryptedMsg(const uint8_t* buf, uint32_t* len,
+ T* deserialized_msg) {
+ // Deserialize msg bytes into c++ thrift msg using memory transport.
+ auto tmem_transport = CreateReadOnlyMemoryBuffer(const_cast<uint8_t*>(buf), *len);
+ apache::thrift::protocol::TCompactProtocolFactoryT<ThriftBuffer> tproto_factory;
+ // Protect against CPU and memory bombs
+ tproto_factory.setStringSizeLimit(100 * 1000 * 1000);
+ // Structs in the thrift definition are relatively large (at least 300 bytes).
+ // This limits total memory to the same order of magnitude as stringSize.
+ tproto_factory.setContainerSizeLimit(1000 * 1000);
+ shared_ptr<apache::thrift::protocol::TProtocol> tproto = //
+ tproto_factory.getProtocol(tmem_transport);
+ try {
+ deserialized_msg->read(tproto.get());
+ } catch (std::exception& e) {
+ std::stringstream ss;
+ ss << "Couldn't deserialize thrift: " << e.what() << "\n";
+ throw ParquetException(ss.str());
+ }
+ uint32_t bytes_left = tmem_transport->available_read();
+ *len = *len - bytes_left;
+}
+
+// Deserialize a thrift message from buf/len. buf/len must at least contain
+// all the bytes needed to store the thrift message. On return, len will be
+// set to the actual length of the header.
+template <class T>
+inline void DeserializeThriftMsg(const uint8_t* buf, uint32_t* len, T* deserialized_msg,
+ const std::shared_ptr<Decryptor>& decryptor = NULLPTR) {
+ // thrift message is not encrypted
+ if (decryptor == NULLPTR) {
+ DeserializeThriftUnencryptedMsg(buf, len, deserialized_msg);
+ } else { // thrift message is encrypted
+ uint32_t clen;
+ clen = *len;
+ // decrypt
+ std::shared_ptr<ResizableBuffer> decrypted_buffer =
+ std::static_pointer_cast<ResizableBuffer>(AllocateBuffer(
+ decryptor->pool(),
+ static_cast<int64_t>(clen - decryptor->CiphertextSizeDelta())));
+ const uint8_t* cipher_buf = buf;
+ uint32_t decrypted_buffer_len =
+ decryptor->Decrypt(cipher_buf, 0, decrypted_buffer->mutable_data());
+ if (decrypted_buffer_len <= 0) {
+ throw ParquetException("Couldn't decrypt buffer\n");
+ }
+ *len = decrypted_buffer_len + decryptor->CiphertextSizeDelta();
+ DeserializeThriftMsg(decrypted_buffer->data(), &decrypted_buffer_len,
+ deserialized_msg);
+ }
+}
+
+/// Utility class to serialize thrift objects to a binary format. This object
+/// should be reused if possible to reuse the underlying memory.
+/// Note: thrift will encode NULLs into the serialized buffer so it is not valid
+/// to treat it as a string.
+class ThriftSerializer {
+ public:
+ explicit ThriftSerializer(int initial_buffer_size = 1024)
+ : mem_buffer_(new ThriftBuffer(initial_buffer_size)) {
+ apache::thrift::protocol::TCompactProtocolFactoryT<ThriftBuffer> factory;
+ protocol_ = factory.getProtocol(mem_buffer_);
+ }
+
+ /// Serialize obj into a memory buffer. The result is returned in buffer/len. The
+ /// memory returned is owned by this object and will be invalid when another object
+ /// is serialized.
+ template <class T>
+ void SerializeToBuffer(const T* obj, uint32_t* len, uint8_t** buffer) {
+ SerializeObject(obj);
+ mem_buffer_->getBuffer(buffer, len);
+ }
+
+ template <class T>
+ void SerializeToString(const T* obj, std::string* result) {
+ SerializeObject(obj);
+ *result = mem_buffer_->getBufferAsString();
+ }
+
+ template <class T>
+ int64_t Serialize(const T* obj, ArrowOutputStream* out,
+ const std::shared_ptr<Encryptor>& encryptor = NULLPTR) {
+ uint8_t* out_buffer;
+ uint32_t out_length;
+ SerializeToBuffer(obj, &out_length, &out_buffer);
+
+ // obj is not encrypted
+ if (encryptor == NULLPTR) {
+ PARQUET_THROW_NOT_OK(out->Write(out_buffer, out_length));
+ return static_cast<int64_t>(out_length);
+ } else { // obj is encrypted
+ return SerializeEncryptedObj(out, out_buffer, out_length, encryptor);
+ }
+ }
+
+ private:
+ template <class T>
+ void SerializeObject(const T* obj) {
+ try {
+ mem_buffer_->resetBuffer();
+ obj->write(protocol_.get());
+ } catch (std::exception& e) {
+ std::stringstream ss;
+ ss << "Couldn't serialize thrift: " << e.what() << "\n";
+ throw ParquetException(ss.str());
+ }
+ }
+
+ int64_t SerializeEncryptedObj(ArrowOutputStream* out, uint8_t* out_buffer,
+ uint32_t out_length,
+ const std::shared_ptr<Encryptor>& encryptor) {
+ std::shared_ptr<ResizableBuffer> cipher_buffer =
+ std::static_pointer_cast<ResizableBuffer>(AllocateBuffer(
+ encryptor->pool(),
+ static_cast<int64_t>(encryptor->CiphertextSizeDelta() + out_length)));
+ int cipher_buffer_len =
+ encryptor->Encrypt(out_buffer, out_length, cipher_buffer->mutable_data());
+
+ PARQUET_THROW_NOT_OK(out->Write(cipher_buffer->data(), cipher_buffer_len));
+ return static_cast<int64_t>(cipher_buffer_len);
+ }
+
+ shared_ptr<ThriftBuffer> mem_buffer_;
+ shared_ptr<apache::thrift::protocol::TProtocol> protocol_;
+};
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/type_fwd.h b/src/arrow/cpp/src/parquet/type_fwd.h
new file mode 100644
index 000000000..3e66f32fc
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/type_fwd.h
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+namespace parquet {
+
+/// \brief Feature selection when writing Parquet files
+///
+/// `ParquetVersion::type` governs which data types are allowed and how they
+/// are represented. For example, uint32_t data will be written differently
+/// depending on this value (as INT64 for PARQUET_1_0, as UINT32 for other
+/// versions).
+///
+/// However, some features - such as compression algorithms, encryption,
+/// or the improved "v2" data page format - must be enabled separately in
+/// ArrowWriterProperties.
+struct ParquetVersion {
+ enum type : int {
+ /// Enable only pre-2.2 Parquet format features when writing
+ ///
+ /// This setting is useful for maximum compatibility with legacy readers.
+ /// Note that logical types may still be emitted, as long they have a
+ /// corresponding converted type.
+ PARQUET_1_0,
+
+ /// DEPRECATED: Enable Parquet format 2.6 features
+ ///
+ /// This misleadingly named enum value is roughly similar to PARQUET_2_6.
+ PARQUET_2_0 ARROW_DEPRECATED_ENUM_VALUE("use PARQUET_2_4 or PARQUET_2_6 "
+ "for fine-grained feature selection"),
+
+ /// Enable Parquet format 2.4 and earlier features when writing
+ ///
+ /// This enables UINT32 as well as logical types which don't have
+ /// a corresponding converted type.
+ ///
+ /// Note: Parquet format 2.4.0 was released in October 2017.
+ PARQUET_2_4,
+
+ /// Enable Parquet format 2.6 and earlier features when writing
+ ///
+ /// This enables the NANOS time unit in addition to the PARQUET_2_4
+ /// features.
+ ///
+ /// Note: Parquet format 2.6.0 was released in September 2018.
+ PARQUET_2_6,
+
+ /// Enable latest Parquet format 2.x features
+ ///
+ /// This value is equal to the greatest 2.x version supported by
+ /// this library.
+ PARQUET_2_LATEST = PARQUET_2_6
+ };
+};
+
+class FileMetaData;
+class SchemaDescriptor;
+
+class ReaderProperties;
+class ArrowReaderProperties;
+
+class WriterProperties;
+class WriterPropertiesBuilder;
+class ArrowWriterProperties;
+class ArrowWriterPropertiesBuilder;
+
+namespace arrow {
+
+class FileWriter;
+class FileReader;
+
+} // namespace arrow
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/types.cc b/src/arrow/cpp/src/parquet/types.cc
new file mode 100644
index 000000000..ef23c4066
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/types.cc
@@ -0,0 +1,1567 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cmath>
+#include <cstdint>
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/compression.h"
+#include "arrow/util/logging.h"
+
+#include "parquet/exception.h"
+#include "parquet/types.h"
+
+#include "generated/parquet_types.h"
+
+using arrow::internal::checked_cast;
+using arrow::util::Codec;
+
+namespace parquet {
+
+bool IsCodecSupported(Compression::type codec) {
+ switch (codec) {
+ case Compression::UNCOMPRESSED:
+ case Compression::SNAPPY:
+ case Compression::GZIP:
+ case Compression::BROTLI:
+ case Compression::ZSTD:
+ case Compression::LZ4:
+ case Compression::LZ4_HADOOP:
+ return true;
+ default:
+ return false;
+ }
+}
+
+std::unique_ptr<Codec> GetCodec(Compression::type codec) {
+ return GetCodec(codec, Codec::UseDefaultCompressionLevel());
+}
+
+std::unique_ptr<Codec> GetCodec(Compression::type codec, int compression_level) {
+ std::unique_ptr<Codec> result;
+ if (codec == Compression::LZO) {
+ throw ParquetException(
+ "While LZO compression is supported by the Parquet format in "
+ "general, it is currently not supported by the C++ implementation.");
+ }
+
+ if (!IsCodecSupported(codec)) {
+ std::stringstream ss;
+ ss << "Codec type " << Codec::GetCodecAsString(codec)
+ << " not supported in Parquet format";
+ throw ParquetException(ss.str());
+ }
+
+ PARQUET_ASSIGN_OR_THROW(result, Codec::Create(codec, compression_level));
+ return result;
+}
+
+std::string FormatStatValue(Type::type parquet_type, ::arrow::util::string_view val) {
+ std::stringstream result;
+
+ const char* bytes = val.data();
+ switch (parquet_type) {
+ case Type::BOOLEAN:
+ result << reinterpret_cast<const bool*>(bytes)[0];
+ break;
+ case Type::INT32:
+ result << reinterpret_cast<const int32_t*>(bytes)[0];
+ break;
+ case Type::INT64:
+ result << reinterpret_cast<const int64_t*>(bytes)[0];
+ break;
+ case Type::DOUBLE:
+ result << reinterpret_cast<const double*>(bytes)[0];
+ break;
+ case Type::FLOAT:
+ result << reinterpret_cast<const float*>(bytes)[0];
+ break;
+ case Type::INT96: {
+ auto const i32_val = reinterpret_cast<const int32_t*>(bytes);
+ result << i32_val[0] << " " << i32_val[1] << " " << i32_val[2];
+ break;
+ }
+ case Type::BYTE_ARRAY: {
+ return std::string(val);
+ }
+ case Type::FIXED_LEN_BYTE_ARRAY: {
+ return std::string(val);
+ }
+ case Type::UNDEFINED:
+ default:
+ break;
+ }
+ return result.str();
+}
+
+std::string EncodingToString(Encoding::type t) {
+ switch (t) {
+ case Encoding::PLAIN:
+ return "PLAIN";
+ case Encoding::PLAIN_DICTIONARY:
+ return "PLAIN_DICTIONARY";
+ case Encoding::RLE:
+ return "RLE";
+ case Encoding::BIT_PACKED:
+ return "BIT_PACKED";
+ case Encoding::DELTA_BINARY_PACKED:
+ return "DELTA_BINARY_PACKED";
+ case Encoding::DELTA_LENGTH_BYTE_ARRAY:
+ return "DELTA_LENGTH_BYTE_ARRAY";
+ case Encoding::DELTA_BYTE_ARRAY:
+ return "DELTA_BYTE_ARRAY";
+ case Encoding::RLE_DICTIONARY:
+ return "RLE_DICTIONARY";
+ case Encoding::BYTE_STREAM_SPLIT:
+ return "BYTE_STREAM_SPLIT";
+ default:
+ return "UNKNOWN";
+ }
+}
+
+std::string TypeToString(Type::type t) {
+ switch (t) {
+ case Type::BOOLEAN:
+ return "BOOLEAN";
+ case Type::INT32:
+ return "INT32";
+ case Type::INT64:
+ return "INT64";
+ case Type::INT96:
+ return "INT96";
+ case Type::FLOAT:
+ return "FLOAT";
+ case Type::DOUBLE:
+ return "DOUBLE";
+ case Type::BYTE_ARRAY:
+ return "BYTE_ARRAY";
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return "FIXED_LEN_BYTE_ARRAY";
+ case Type::UNDEFINED:
+ default:
+ return "UNKNOWN";
+ }
+}
+
+std::string ConvertedTypeToString(ConvertedType::type t) {
+ switch (t) {
+ case ConvertedType::NONE:
+ return "NONE";
+ case ConvertedType::UTF8:
+ return "UTF8";
+ case ConvertedType::MAP:
+ return "MAP";
+ case ConvertedType::MAP_KEY_VALUE:
+ return "MAP_KEY_VALUE";
+ case ConvertedType::LIST:
+ return "LIST";
+ case ConvertedType::ENUM:
+ return "ENUM";
+ case ConvertedType::DECIMAL:
+ return "DECIMAL";
+ case ConvertedType::DATE:
+ return "DATE";
+ case ConvertedType::TIME_MILLIS:
+ return "TIME_MILLIS";
+ case ConvertedType::TIME_MICROS:
+ return "TIME_MICROS";
+ case ConvertedType::TIMESTAMP_MILLIS:
+ return "TIMESTAMP_MILLIS";
+ case ConvertedType::TIMESTAMP_MICROS:
+ return "TIMESTAMP_MICROS";
+ case ConvertedType::UINT_8:
+ return "UINT_8";
+ case ConvertedType::UINT_16:
+ return "UINT_16";
+ case ConvertedType::UINT_32:
+ return "UINT_32";
+ case ConvertedType::UINT_64:
+ return "UINT_64";
+ case ConvertedType::INT_8:
+ return "INT_8";
+ case ConvertedType::INT_16:
+ return "INT_16";
+ case ConvertedType::INT_32:
+ return "INT_32";
+ case ConvertedType::INT_64:
+ return "INT_64";
+ case ConvertedType::JSON:
+ return "JSON";
+ case ConvertedType::BSON:
+ return "BSON";
+ case ConvertedType::INTERVAL:
+ return "INTERVAL";
+ case ConvertedType::UNDEFINED:
+ default:
+ return "UNKNOWN";
+ }
+}
+
+int GetTypeByteSize(Type::type parquet_type) {
+ switch (parquet_type) {
+ case Type::BOOLEAN:
+ return type_traits<BooleanType::type_num>::value_byte_size;
+ case Type::INT32:
+ return type_traits<Int32Type::type_num>::value_byte_size;
+ case Type::INT64:
+ return type_traits<Int64Type::type_num>::value_byte_size;
+ case Type::INT96:
+ return type_traits<Int96Type::type_num>::value_byte_size;
+ case Type::DOUBLE:
+ return type_traits<DoubleType::type_num>::value_byte_size;
+ case Type::FLOAT:
+ return type_traits<FloatType::type_num>::value_byte_size;
+ case Type::BYTE_ARRAY:
+ return type_traits<ByteArrayType::type_num>::value_byte_size;
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return type_traits<FLBAType::type_num>::value_byte_size;
+ case Type::UNDEFINED:
+ default:
+ return 0;
+ }
+ return 0;
+}
+
+// Return the Sort Order of the Parquet Physical Types
+SortOrder::type DefaultSortOrder(Type::type primitive) {
+ switch (primitive) {
+ case Type::BOOLEAN:
+ case Type::INT32:
+ case Type::INT64:
+ case Type::FLOAT:
+ case Type::DOUBLE:
+ return SortOrder::SIGNED;
+ case Type::BYTE_ARRAY:
+ case Type::FIXED_LEN_BYTE_ARRAY:
+ return SortOrder::UNSIGNED;
+ case Type::INT96:
+ case Type::UNDEFINED:
+ return SortOrder::UNKNOWN;
+ }
+ return SortOrder::UNKNOWN;
+}
+
+// Return the SortOrder of the Parquet Types using Logical or Physical Types
+SortOrder::type GetSortOrder(ConvertedType::type converted, Type::type primitive) {
+ if (converted == ConvertedType::NONE) return DefaultSortOrder(primitive);
+ switch (converted) {
+ case ConvertedType::INT_8:
+ case ConvertedType::INT_16:
+ case ConvertedType::INT_32:
+ case ConvertedType::INT_64:
+ case ConvertedType::DATE:
+ case ConvertedType::TIME_MICROS:
+ case ConvertedType::TIME_MILLIS:
+ case ConvertedType::TIMESTAMP_MICROS:
+ case ConvertedType::TIMESTAMP_MILLIS:
+ return SortOrder::SIGNED;
+ case ConvertedType::UINT_8:
+ case ConvertedType::UINT_16:
+ case ConvertedType::UINT_32:
+ case ConvertedType::UINT_64:
+ case ConvertedType::ENUM:
+ case ConvertedType::UTF8:
+ case ConvertedType::BSON:
+ case ConvertedType::JSON:
+ return SortOrder::UNSIGNED;
+ case ConvertedType::DECIMAL:
+ case ConvertedType::LIST:
+ case ConvertedType::MAP:
+ case ConvertedType::MAP_KEY_VALUE:
+ case ConvertedType::INTERVAL:
+ case ConvertedType::NONE: // required instead of default
+ case ConvertedType::NA: // required instead of default
+ case ConvertedType::UNDEFINED:
+ return SortOrder::UNKNOWN;
+ }
+ return SortOrder::UNKNOWN;
+}
+
+SortOrder::type GetSortOrder(const std::shared_ptr<const LogicalType>& logical_type,
+ Type::type primitive) {
+ SortOrder::type o = SortOrder::UNKNOWN;
+ if (logical_type && logical_type->is_valid()) {
+ o = (logical_type->is_none() ? DefaultSortOrder(primitive)
+ : logical_type->sort_order());
+ }
+ return o;
+}
+
+ColumnOrder ColumnOrder::undefined_ = ColumnOrder(ColumnOrder::UNDEFINED);
+ColumnOrder ColumnOrder::type_defined_ = ColumnOrder(ColumnOrder::TYPE_DEFINED_ORDER);
+
+// Static methods for LogicalType class
+
+std::shared_ptr<const LogicalType> LogicalType::FromConvertedType(
+ const ConvertedType::type converted_type,
+ const schema::DecimalMetadata converted_decimal_metadata) {
+ switch (converted_type) {
+ case ConvertedType::UTF8:
+ return StringLogicalType::Make();
+ case ConvertedType::MAP_KEY_VALUE:
+ case ConvertedType::MAP:
+ return MapLogicalType::Make();
+ case ConvertedType::LIST:
+ return ListLogicalType::Make();
+ case ConvertedType::ENUM:
+ return EnumLogicalType::Make();
+ case ConvertedType::DECIMAL:
+ return DecimalLogicalType::Make(converted_decimal_metadata.precision,
+ converted_decimal_metadata.scale);
+ case ConvertedType::DATE:
+ return DateLogicalType::Make();
+ case ConvertedType::TIME_MILLIS:
+ return TimeLogicalType::Make(true, LogicalType::TimeUnit::MILLIS);
+ case ConvertedType::TIME_MICROS:
+ return TimeLogicalType::Make(true, LogicalType::TimeUnit::MICROS);
+ case ConvertedType::TIMESTAMP_MILLIS:
+ return TimestampLogicalType::Make(true, LogicalType::TimeUnit::MILLIS,
+ /*is_from_converted_type=*/true,
+ /*force_set_converted_type=*/false);
+ case ConvertedType::TIMESTAMP_MICROS:
+ return TimestampLogicalType::Make(true, LogicalType::TimeUnit::MICROS,
+ /*is_from_converted_type=*/true,
+ /*force_set_converted_type=*/false);
+ case ConvertedType::INTERVAL:
+ return IntervalLogicalType::Make();
+ case ConvertedType::INT_8:
+ return IntLogicalType::Make(8, true);
+ case ConvertedType::INT_16:
+ return IntLogicalType::Make(16, true);
+ case ConvertedType::INT_32:
+ return IntLogicalType::Make(32, true);
+ case ConvertedType::INT_64:
+ return IntLogicalType::Make(64, true);
+ case ConvertedType::UINT_8:
+ return IntLogicalType::Make(8, false);
+ case ConvertedType::UINT_16:
+ return IntLogicalType::Make(16, false);
+ case ConvertedType::UINT_32:
+ return IntLogicalType::Make(32, false);
+ case ConvertedType::UINT_64:
+ return IntLogicalType::Make(64, false);
+ case ConvertedType::JSON:
+ return JSONLogicalType::Make();
+ case ConvertedType::BSON:
+ return BSONLogicalType::Make();
+ case ConvertedType::NA:
+ return NullLogicalType::Make();
+ case ConvertedType::NONE:
+ return NoLogicalType::Make();
+ case ConvertedType::UNDEFINED:
+ return UndefinedLogicalType::Make();
+ }
+ return UndefinedLogicalType::Make();
+}
+
+std::shared_ptr<const LogicalType> LogicalType::FromThrift(
+ const format::LogicalType& type) {
+ if (type.__isset.STRING) {
+ return StringLogicalType::Make();
+ } else if (type.__isset.MAP) {
+ return MapLogicalType::Make();
+ } else if (type.__isset.LIST) {
+ return ListLogicalType::Make();
+ } else if (type.__isset.ENUM) {
+ return EnumLogicalType::Make();
+ } else if (type.__isset.DECIMAL) {
+ return DecimalLogicalType::Make(type.DECIMAL.precision, type.DECIMAL.scale);
+ } else if (type.__isset.DATE) {
+ return DateLogicalType::Make();
+ } else if (type.__isset.TIME) {
+ LogicalType::TimeUnit::unit unit;
+ if (type.TIME.unit.__isset.MILLIS) {
+ unit = LogicalType::TimeUnit::MILLIS;
+ } else if (type.TIME.unit.__isset.MICROS) {
+ unit = LogicalType::TimeUnit::MICROS;
+ } else if (type.TIME.unit.__isset.NANOS) {
+ unit = LogicalType::TimeUnit::NANOS;
+ } else {
+ unit = LogicalType::TimeUnit::UNKNOWN;
+ }
+ return TimeLogicalType::Make(type.TIME.isAdjustedToUTC, unit);
+ } else if (type.__isset.TIMESTAMP) {
+ LogicalType::TimeUnit::unit unit;
+ if (type.TIMESTAMP.unit.__isset.MILLIS) {
+ unit = LogicalType::TimeUnit::MILLIS;
+ } else if (type.TIMESTAMP.unit.__isset.MICROS) {
+ unit = LogicalType::TimeUnit::MICROS;
+ } else if (type.TIMESTAMP.unit.__isset.NANOS) {
+ unit = LogicalType::TimeUnit::NANOS;
+ } else {
+ unit = LogicalType::TimeUnit::UNKNOWN;
+ }
+ return TimestampLogicalType::Make(type.TIMESTAMP.isAdjustedToUTC, unit);
+ // TODO(tpboudreau): activate the commented code after parquet.thrift
+ // recognizes IntervalType as a LogicalType
+ //} else if (type.__isset.INTERVAL) {
+ // return IntervalLogicalType::Make();
+ } else if (type.__isset.INTEGER) {
+ return IntLogicalType::Make(static_cast<int>(type.INTEGER.bitWidth),
+ type.INTEGER.isSigned);
+ } else if (type.__isset.UNKNOWN) {
+ return NullLogicalType::Make();
+ } else if (type.__isset.JSON) {
+ return JSONLogicalType::Make();
+ } else if (type.__isset.BSON) {
+ return BSONLogicalType::Make();
+ } else if (type.__isset.UUID) {
+ return UUIDLogicalType::Make();
+ } else {
+ throw ParquetException("Metadata contains Thrift LogicalType that is not recognized");
+ }
+}
+
+std::shared_ptr<const LogicalType> LogicalType::String() {
+ return StringLogicalType::Make();
+}
+
+std::shared_ptr<const LogicalType> LogicalType::Map() { return MapLogicalType::Make(); }
+
+std::shared_ptr<const LogicalType> LogicalType::List() { return ListLogicalType::Make(); }
+
+std::shared_ptr<const LogicalType> LogicalType::Enum() { return EnumLogicalType::Make(); }
+
+std::shared_ptr<const LogicalType> LogicalType::Decimal(int32_t precision,
+ int32_t scale) {
+ return DecimalLogicalType::Make(precision, scale);
+}
+
+std::shared_ptr<const LogicalType> LogicalType::Date() { return DateLogicalType::Make(); }
+
+std::shared_ptr<const LogicalType> LogicalType::Time(
+ bool is_adjusted_to_utc, LogicalType::TimeUnit::unit time_unit) {
+ DCHECK(time_unit != LogicalType::TimeUnit::UNKNOWN);
+ return TimeLogicalType::Make(is_adjusted_to_utc, time_unit);
+}
+
+std::shared_ptr<const LogicalType> LogicalType::Timestamp(
+ bool is_adjusted_to_utc, LogicalType::TimeUnit::unit time_unit,
+ bool is_from_converted_type, bool force_set_converted_type) {
+ DCHECK(time_unit != LogicalType::TimeUnit::UNKNOWN);
+ return TimestampLogicalType::Make(is_adjusted_to_utc, time_unit, is_from_converted_type,
+ force_set_converted_type);
+}
+
+std::shared_ptr<const LogicalType> LogicalType::Interval() {
+ return IntervalLogicalType::Make();
+}
+
+std::shared_ptr<const LogicalType> LogicalType::Int(int bit_width, bool is_signed) {
+ DCHECK(bit_width == 64 || bit_width == 32 || bit_width == 16 || bit_width == 8);
+ return IntLogicalType::Make(bit_width, is_signed);
+}
+
+std::shared_ptr<const LogicalType> LogicalType::Null() { return NullLogicalType::Make(); }
+
+std::shared_ptr<const LogicalType> LogicalType::JSON() { return JSONLogicalType::Make(); }
+
+std::shared_ptr<const LogicalType> LogicalType::BSON() { return BSONLogicalType::Make(); }
+
+std::shared_ptr<const LogicalType> LogicalType::UUID() { return UUIDLogicalType::Make(); }
+
+std::shared_ptr<const LogicalType> LogicalType::None() { return NoLogicalType::Make(); }
+
+/*
+ * The logical type implementation classes are built in four layers: (1) the base
+ * layer, which establishes the interface and provides generally reusable implementations
+ * for the ToJSON() and Equals() methods; (2) an intermediate derived layer for the
+ * "compatibility" methods, which provides implementations for is_compatible() and
+ * ToConvertedType(); (3) another intermediate layer for the "applicability" methods
+ * that provides several implementations for the is_applicable() method; and (4) the
+ * final derived classes, one for each logical type, which supply implementations
+ * for those methods that remain virtual (usually just ToString() and ToThrift()) or
+ * otherwise need to be overridden.
+ */
+
+// LogicalTypeImpl base class
+
+class LogicalType::Impl {
+ public:
+ virtual bool is_applicable(parquet::Type::type primitive_type,
+ int32_t primitive_length = -1) const = 0;
+
+ virtual bool is_compatible(ConvertedType::type converted_type,
+ schema::DecimalMetadata converted_decimal_metadata = {
+ false, -1, -1}) const = 0;
+
+ virtual ConvertedType::type ToConvertedType(
+ schema::DecimalMetadata* out_decimal_metadata) const = 0;
+
+ virtual std::string ToString() const = 0;
+
+ virtual bool is_serialized() const {
+ return !(type_ == LogicalType::Type::NONE || type_ == LogicalType::Type::UNDEFINED);
+ }
+
+ virtual std::string ToJSON() const {
+ std::stringstream json;
+ json << R"({"Type": ")" << ToString() << R"("})";
+ return json.str();
+ }
+
+ virtual format::LogicalType ToThrift() const {
+ // logical types inheriting this method should never be serialized
+ std::stringstream ss;
+ ss << "Logical type " << ToString() << " should not be serialized";
+ throw ParquetException(ss.str());
+ }
+
+ virtual bool Equals(const LogicalType& other) const { return other.type() == type_; }
+
+ LogicalType::Type::type type() const { return type_; }
+
+ SortOrder::type sort_order() const { return order_; }
+
+ Impl(const Impl&) = delete;
+ Impl& operator=(const Impl&) = delete;
+ virtual ~Impl() noexcept {}
+
+ class Compatible;
+ class SimpleCompatible;
+ class Incompatible;
+
+ class Applicable;
+ class SimpleApplicable;
+ class TypeLengthApplicable;
+ class UniversalApplicable;
+ class Inapplicable;
+
+ class String;
+ class Map;
+ class List;
+ class Enum;
+ class Decimal;
+ class Date;
+ class Time;
+ class Timestamp;
+ class Interval;
+ class Int;
+ class Null;
+ class JSON;
+ class BSON;
+ class UUID;
+ class No;
+ class Undefined;
+
+ protected:
+ Impl(LogicalType::Type::type t, SortOrder::type o) : type_(t), order_(o) {}
+ Impl() = default;
+
+ private:
+ LogicalType::Type::type type_ = LogicalType::Type::UNDEFINED;
+ SortOrder::type order_ = SortOrder::UNKNOWN;
+};
+
+// Special methods for public LogicalType class
+
+LogicalType::LogicalType() = default;
+LogicalType::~LogicalType() noexcept = default;
+
+// Delegating methods for public LogicalType class
+
+bool LogicalType::is_applicable(parquet::Type::type primitive_type,
+ int32_t primitive_length) const {
+ return impl_->is_applicable(primitive_type, primitive_length);
+}
+
+bool LogicalType::is_compatible(
+ ConvertedType::type converted_type,
+ schema::DecimalMetadata converted_decimal_metadata) const {
+ return impl_->is_compatible(converted_type, converted_decimal_metadata);
+}
+
+ConvertedType::type LogicalType::ToConvertedType(
+ schema::DecimalMetadata* out_decimal_metadata) const {
+ return impl_->ToConvertedType(out_decimal_metadata);
+}
+
+std::string LogicalType::ToString() const { return impl_->ToString(); }
+
+std::string LogicalType::ToJSON() const { return impl_->ToJSON(); }
+
+format::LogicalType LogicalType::ToThrift() const { return impl_->ToThrift(); }
+
+bool LogicalType::Equals(const LogicalType& other) const { return impl_->Equals(other); }
+
+LogicalType::Type::type LogicalType::type() const { return impl_->type(); }
+
+SortOrder::type LogicalType::sort_order() const { return impl_->sort_order(); }
+
+// Type checks for public LogicalType class
+
+bool LogicalType::is_string() const { return impl_->type() == LogicalType::Type::STRING; }
+bool LogicalType::is_map() const { return impl_->type() == LogicalType::Type::MAP; }
+bool LogicalType::is_list() const { return impl_->type() == LogicalType::Type::LIST; }
+bool LogicalType::is_enum() const { return impl_->type() == LogicalType::Type::ENUM; }
+bool LogicalType::is_decimal() const {
+ return impl_->type() == LogicalType::Type::DECIMAL;
+}
+bool LogicalType::is_date() const { return impl_->type() == LogicalType::Type::DATE; }
+bool LogicalType::is_time() const { return impl_->type() == LogicalType::Type::TIME; }
+bool LogicalType::is_timestamp() const {
+ return impl_->type() == LogicalType::Type::TIMESTAMP;
+}
+bool LogicalType::is_interval() const {
+ return impl_->type() == LogicalType::Type::INTERVAL;
+}
+bool LogicalType::is_int() const { return impl_->type() == LogicalType::Type::INT; }
+bool LogicalType::is_null() const { return impl_->type() == LogicalType::Type::NIL; }
+bool LogicalType::is_JSON() const { return impl_->type() == LogicalType::Type::JSON; }
+bool LogicalType::is_BSON() const { return impl_->type() == LogicalType::Type::BSON; }
+bool LogicalType::is_UUID() const { return impl_->type() == LogicalType::Type::UUID; }
+bool LogicalType::is_none() const { return impl_->type() == LogicalType::Type::NONE; }
+bool LogicalType::is_valid() const {
+ return impl_->type() != LogicalType::Type::UNDEFINED;
+}
+bool LogicalType::is_invalid() const { return !is_valid(); }
+bool LogicalType::is_nested() const {
+ return (impl_->type() == LogicalType::Type::LIST) ||
+ (impl_->type() == LogicalType::Type::MAP);
+}
+bool LogicalType::is_nonnested() const { return !is_nested(); }
+bool LogicalType::is_serialized() const { return impl_->is_serialized(); }
+
+// LogicalTypeImpl intermediate "compatibility" classes
+
+class LogicalType::Impl::Compatible : public virtual LogicalType::Impl {
+ protected:
+ Compatible() = default;
+};
+
+#define set_decimal_metadata(m___, i___, p___, s___) \
+ { \
+ if (m___) { \
+ (m___)->isset = (i___); \
+ (m___)->scale = (s___); \
+ (m___)->precision = (p___); \
+ } \
+ }
+
+#define reset_decimal_metadata(m___) \
+ { set_decimal_metadata(m___, false, -1, -1); }
+
+// For logical types that always translate to the same converted type
+class LogicalType::Impl::SimpleCompatible : public virtual LogicalType::Impl::Compatible {
+ public:
+ bool is_compatible(ConvertedType::type converted_type,
+ schema::DecimalMetadata converted_decimal_metadata) const override {
+ return (converted_type == converted_type_) && !converted_decimal_metadata.isset;
+ }
+
+ ConvertedType::type ToConvertedType(
+ schema::DecimalMetadata* out_decimal_metadata) const override {
+ reset_decimal_metadata(out_decimal_metadata);
+ return converted_type_;
+ }
+
+ protected:
+ explicit SimpleCompatible(ConvertedType::type c) : converted_type_(c) {}
+
+ private:
+ ConvertedType::type converted_type_ = ConvertedType::NA;
+};
+
+// For logical types that have no corresponding converted type
+class LogicalType::Impl::Incompatible : public virtual LogicalType::Impl {
+ public:
+ bool is_compatible(ConvertedType::type converted_type,
+ schema::DecimalMetadata converted_decimal_metadata) const override {
+ return (converted_type == ConvertedType::NONE ||
+ converted_type == ConvertedType::NA) &&
+ !converted_decimal_metadata.isset;
+ }
+
+ ConvertedType::type ToConvertedType(
+ schema::DecimalMetadata* out_decimal_metadata) const override {
+ reset_decimal_metadata(out_decimal_metadata);
+ return ConvertedType::NONE;
+ }
+
+ protected:
+ Incompatible() = default;
+};
+
+// LogicalTypeImpl intermediate "applicability" classes
+
+class LogicalType::Impl::Applicable : public virtual LogicalType::Impl {
+ protected:
+ Applicable() = default;
+};
+
+// For logical types that can apply only to a single
+// physical type
+class LogicalType::Impl::SimpleApplicable : public virtual LogicalType::Impl::Applicable {
+ public:
+ bool is_applicable(parquet::Type::type primitive_type,
+ int32_t primitive_length = -1) const override {
+ return primitive_type == type_;
+ }
+
+ protected:
+ explicit SimpleApplicable(parquet::Type::type t) : type_(t) {}
+
+ private:
+ parquet::Type::type type_;
+};
+
+// For logical types that can apply only to a particular
+// physical type and physical length combination
+class LogicalType::Impl::TypeLengthApplicable
+ : public virtual LogicalType::Impl::Applicable {
+ public:
+ bool is_applicable(parquet::Type::type primitive_type,
+ int32_t primitive_length = -1) const override {
+ return primitive_type == type_ && primitive_length == length_;
+ }
+
+ protected:
+ TypeLengthApplicable(parquet::Type::type t, int32_t l) : type_(t), length_(l) {}
+
+ private:
+ parquet::Type::type type_;
+ int32_t length_;
+};
+
+// For logical types that can apply to any physical type
+class LogicalType::Impl::UniversalApplicable
+ : public virtual LogicalType::Impl::Applicable {
+ public:
+ bool is_applicable(parquet::Type::type primitive_type,
+ int32_t primitive_length = -1) const override {
+ return true;
+ }
+
+ protected:
+ UniversalApplicable() = default;
+};
+
+// For logical types that can never apply to any primitive
+// physical type
+class LogicalType::Impl::Inapplicable : public virtual LogicalType::Impl {
+ public:
+ bool is_applicable(parquet::Type::type primitive_type,
+ int32_t primitive_length = -1) const override {
+ return false;
+ }
+
+ protected:
+ Inapplicable() = default;
+};
+
+// LogicalType implementation final classes
+
+#define OVERRIDE_TOSTRING(n___) \
+ std::string ToString() const override { return #n___; }
+
+#define OVERRIDE_TOTHRIFT(t___, s___) \
+ format::LogicalType ToThrift() const override { \
+ format::LogicalType type; \
+ format::t___ subtype; \
+ type.__set_##s___(subtype); \
+ return type; \
+ }
+
+class LogicalType::Impl::String final : public LogicalType::Impl::SimpleCompatible,
+ public LogicalType::Impl::SimpleApplicable {
+ public:
+ friend class StringLogicalType;
+
+ OVERRIDE_TOSTRING(String)
+ OVERRIDE_TOTHRIFT(StringType, STRING)
+
+ private:
+ String()
+ : LogicalType::Impl(LogicalType::Type::STRING, SortOrder::UNSIGNED),
+ LogicalType::Impl::SimpleCompatible(ConvertedType::UTF8),
+ LogicalType::Impl::SimpleApplicable(parquet::Type::BYTE_ARRAY) {}
+};
+
+// Each public logical type class's Make() creation method instantiates a corresponding
+// LogicalType::Impl::* object and installs that implementation in the logical type
+// it returns.
+
+#define GENERATE_MAKE(a___) \
+ std::shared_ptr<const LogicalType> a___##LogicalType::Make() { \
+ auto* logical_type = new a___##LogicalType(); \
+ logical_type->impl_.reset(new LogicalType::Impl::a___()); \
+ return std::shared_ptr<const LogicalType>(logical_type); \
+ }
+
+GENERATE_MAKE(String)
+
+class LogicalType::Impl::Map final : public LogicalType::Impl::SimpleCompatible,
+ public LogicalType::Impl::Inapplicable {
+ public:
+ friend class MapLogicalType;
+
+ bool is_compatible(ConvertedType::type converted_type,
+ schema::DecimalMetadata converted_decimal_metadata) const override {
+ return (converted_type == ConvertedType::MAP ||
+ converted_type == ConvertedType::MAP_KEY_VALUE) &&
+ !converted_decimal_metadata.isset;
+ }
+
+ OVERRIDE_TOSTRING(Map)
+ OVERRIDE_TOTHRIFT(MapType, MAP)
+
+ private:
+ Map()
+ : LogicalType::Impl(LogicalType::Type::MAP, SortOrder::UNKNOWN),
+ LogicalType::Impl::SimpleCompatible(ConvertedType::MAP) {}
+};
+
+GENERATE_MAKE(Map)
+
+class LogicalType::Impl::List final : public LogicalType::Impl::SimpleCompatible,
+ public LogicalType::Impl::Inapplicable {
+ public:
+ friend class ListLogicalType;
+
+ OVERRIDE_TOSTRING(List)
+ OVERRIDE_TOTHRIFT(ListType, LIST)
+
+ private:
+ List()
+ : LogicalType::Impl(LogicalType::Type::LIST, SortOrder::UNKNOWN),
+ LogicalType::Impl::SimpleCompatible(ConvertedType::LIST) {}
+};
+
+GENERATE_MAKE(List)
+
+class LogicalType::Impl::Enum final : public LogicalType::Impl::SimpleCompatible,
+ public LogicalType::Impl::SimpleApplicable {
+ public:
+ friend class EnumLogicalType;
+
+ OVERRIDE_TOSTRING(Enum)
+ OVERRIDE_TOTHRIFT(EnumType, ENUM)
+
+ private:
+ Enum()
+ : LogicalType::Impl(LogicalType::Type::ENUM, SortOrder::UNSIGNED),
+ LogicalType::Impl::SimpleCompatible(ConvertedType::ENUM),
+ LogicalType::Impl::SimpleApplicable(parquet::Type::BYTE_ARRAY) {}
+};
+
+GENERATE_MAKE(Enum)
+
+// The parameterized logical types (currently Decimal, Time, Timestamp, and Int)
+// generally can't reuse the simple method implementations available in the base and
+// intermediate classes and must (re)implement them all
+
+class LogicalType::Impl::Decimal final : public LogicalType::Impl::Compatible,
+ public LogicalType::Impl::Applicable {
+ public:
+ friend class DecimalLogicalType;
+
+ bool is_applicable(parquet::Type::type primitive_type,
+ int32_t primitive_length = -1) const override;
+ bool is_compatible(ConvertedType::type converted_type,
+ schema::DecimalMetadata converted_decimal_metadata) const override;
+ ConvertedType::type ToConvertedType(
+ schema::DecimalMetadata* out_decimal_metadata) const override;
+ std::string ToString() const override;
+ std::string ToJSON() const override;
+ format::LogicalType ToThrift() const override;
+ bool Equals(const LogicalType& other) const override;
+
+ int32_t precision() const { return precision_; }
+ int32_t scale() const { return scale_; }
+
+ private:
+ Decimal(int32_t p, int32_t s)
+ : LogicalType::Impl(LogicalType::Type::DECIMAL, SortOrder::SIGNED),
+ precision_(p),
+ scale_(s) {}
+ int32_t precision_ = -1;
+ int32_t scale_ = -1;
+};
+
+bool LogicalType::Impl::Decimal::is_applicable(parquet::Type::type primitive_type,
+ int32_t primitive_length) const {
+ bool ok = false;
+ switch (primitive_type) {
+ case parquet::Type::INT32: {
+ ok = (1 <= precision_) && (precision_ <= 9);
+ } break;
+ case parquet::Type::INT64: {
+ ok = (1 <= precision_) && (precision_ <= 18);
+ if (precision_ < 10) {
+ // FIXME(tpb): warn that INT32 could be used
+ }
+ } break;
+ case parquet::Type::FIXED_LEN_BYTE_ARRAY: {
+ ok = precision_ <= static_cast<int32_t>(std::floor(
+ std::log10(std::pow(2.0, (8.0 * primitive_length) - 1.0))));
+ } break;
+ case parquet::Type::BYTE_ARRAY: {
+ ok = true;
+ } break;
+ default: {
+ } break;
+ }
+ return ok;
+}
+
+bool LogicalType::Impl::Decimal::is_compatible(
+ ConvertedType::type converted_type,
+ schema::DecimalMetadata converted_decimal_metadata) const {
+ return converted_type == ConvertedType::DECIMAL &&
+ (converted_decimal_metadata.isset &&
+ converted_decimal_metadata.scale == scale_ &&
+ converted_decimal_metadata.precision == precision_);
+}
+
+ConvertedType::type LogicalType::Impl::Decimal::ToConvertedType(
+ schema::DecimalMetadata* out_decimal_metadata) const {
+ set_decimal_metadata(out_decimal_metadata, true, precision_, scale_);
+ return ConvertedType::DECIMAL;
+}
+
+std::string LogicalType::Impl::Decimal::ToString() const {
+ std::stringstream type;
+ type << "Decimal(precision=" << precision_ << ", scale=" << scale_ << ")";
+ return type.str();
+}
+
+std::string LogicalType::Impl::Decimal::ToJSON() const {
+ std::stringstream json;
+ json << R"({"Type": "Decimal", "precision": )" << precision_ << R"(, "scale": )"
+ << scale_ << "}";
+ return json.str();
+}
+
+format::LogicalType LogicalType::Impl::Decimal::ToThrift() const {
+ format::LogicalType type;
+ format::DecimalType decimal_type;
+ decimal_type.__set_precision(precision_);
+ decimal_type.__set_scale(scale_);
+ type.__set_DECIMAL(decimal_type);
+ return type;
+}
+
+bool LogicalType::Impl::Decimal::Equals(const LogicalType& other) const {
+ bool eq = false;
+ if (other.is_decimal()) {
+ const auto& other_decimal = checked_cast<const DecimalLogicalType&>(other);
+ eq = (precision_ == other_decimal.precision() && scale_ == other_decimal.scale());
+ }
+ return eq;
+}
+
+std::shared_ptr<const LogicalType> DecimalLogicalType::Make(int32_t precision,
+ int32_t scale) {
+ if (precision < 1) {
+ throw ParquetException(
+ "Precision must be greater than or equal to 1 for Decimal logical type");
+ }
+ if (scale < 0 || scale > precision) {
+ throw ParquetException(
+ "Scale must be a non-negative integer that does not exceed precision for "
+ "Decimal logical type");
+ }
+ auto* logical_type = new DecimalLogicalType();
+ logical_type->impl_.reset(new LogicalType::Impl::Decimal(precision, scale));
+ return std::shared_ptr<const LogicalType>(logical_type);
+}
+
+int32_t DecimalLogicalType::precision() const {
+ return (dynamic_cast<const LogicalType::Impl::Decimal&>(*impl_)).precision();
+}
+
+int32_t DecimalLogicalType::scale() const {
+ return (dynamic_cast<const LogicalType::Impl::Decimal&>(*impl_)).scale();
+}
+
+class LogicalType::Impl::Date final : public LogicalType::Impl::SimpleCompatible,
+ public LogicalType::Impl::SimpleApplicable {
+ public:
+ friend class DateLogicalType;
+
+ OVERRIDE_TOSTRING(Date)
+ OVERRIDE_TOTHRIFT(DateType, DATE)
+
+ private:
+ Date()
+ : LogicalType::Impl(LogicalType::Type::DATE, SortOrder::SIGNED),
+ LogicalType::Impl::SimpleCompatible(ConvertedType::DATE),
+ LogicalType::Impl::SimpleApplicable(parquet::Type::INT32) {}
+};
+
+GENERATE_MAKE(Date)
+
+#define time_unit_string(u___) \
+ ((u___) == LogicalType::TimeUnit::MILLIS \
+ ? "milliseconds" \
+ : ((u___) == LogicalType::TimeUnit::MICROS \
+ ? "microseconds" \
+ : ((u___) == LogicalType::TimeUnit::NANOS ? "nanoseconds" : "unknown")))
+
+class LogicalType::Impl::Time final : public LogicalType::Impl::Compatible,
+ public LogicalType::Impl::Applicable {
+ public:
+ friend class TimeLogicalType;
+
+ bool is_applicable(parquet::Type::type primitive_type,
+ int32_t primitive_length = -1) const override;
+ bool is_compatible(ConvertedType::type converted_type,
+ schema::DecimalMetadata converted_decimal_metadata) const override;
+ ConvertedType::type ToConvertedType(
+ schema::DecimalMetadata* out_decimal_metadata) const override;
+ std::string ToString() const override;
+ std::string ToJSON() const override;
+ format::LogicalType ToThrift() const override;
+ bool Equals(const LogicalType& other) const override;
+
+ bool is_adjusted_to_utc() const { return adjusted_; }
+ LogicalType::TimeUnit::unit time_unit() const { return unit_; }
+
+ private:
+ Time(bool a, LogicalType::TimeUnit::unit u)
+ : LogicalType::Impl(LogicalType::Type::TIME, SortOrder::SIGNED),
+ adjusted_(a),
+ unit_(u) {}
+ bool adjusted_ = false;
+ LogicalType::TimeUnit::unit unit_;
+};
+
+bool LogicalType::Impl::Time::is_applicable(parquet::Type::type primitive_type,
+ int32_t primitive_length) const {
+ return (primitive_type == parquet::Type::INT32 &&
+ unit_ == LogicalType::TimeUnit::MILLIS) ||
+ (primitive_type == parquet::Type::INT64 &&
+ (unit_ == LogicalType::TimeUnit::MICROS ||
+ unit_ == LogicalType::TimeUnit::NANOS));
+}
+
+bool LogicalType::Impl::Time::is_compatible(
+ ConvertedType::type converted_type,
+ schema::DecimalMetadata converted_decimal_metadata) const {
+ if (converted_decimal_metadata.isset) {
+ return false;
+ } else if (adjusted_ && unit_ == LogicalType::TimeUnit::MILLIS) {
+ return converted_type == ConvertedType::TIME_MILLIS;
+ } else if (adjusted_ && unit_ == LogicalType::TimeUnit::MICROS) {
+ return converted_type == ConvertedType::TIME_MICROS;
+ } else {
+ return (converted_type == ConvertedType::NONE) ||
+ (converted_type == ConvertedType::NA);
+ }
+}
+
+ConvertedType::type LogicalType::Impl::Time::ToConvertedType(
+ schema::DecimalMetadata* out_decimal_metadata) const {
+ reset_decimal_metadata(out_decimal_metadata);
+ if (adjusted_) {
+ if (unit_ == LogicalType::TimeUnit::MILLIS) {
+ return ConvertedType::TIME_MILLIS;
+ } else if (unit_ == LogicalType::TimeUnit::MICROS) {
+ return ConvertedType::TIME_MICROS;
+ }
+ }
+ return ConvertedType::NONE;
+}
+
+std::string LogicalType::Impl::Time::ToString() const {
+ std::stringstream type;
+ type << "Time(isAdjustedToUTC=" << std::boolalpha << adjusted_
+ << ", timeUnit=" << time_unit_string(unit_) << ")";
+ return type.str();
+}
+
+std::string LogicalType::Impl::Time::ToJSON() const {
+ std::stringstream json;
+ json << R"({"Type": "Time", "isAdjustedToUTC": )" << std::boolalpha << adjusted_
+ << R"(, "timeUnit": ")" << time_unit_string(unit_) << R"("})";
+ return json.str();
+}
+
+format::LogicalType LogicalType::Impl::Time::ToThrift() const {
+ format::LogicalType type;
+ format::TimeType time_type;
+ format::TimeUnit time_unit;
+ DCHECK(unit_ != LogicalType::TimeUnit::UNKNOWN);
+ if (unit_ == LogicalType::TimeUnit::MILLIS) {
+ format::MilliSeconds millis;
+ time_unit.__set_MILLIS(millis);
+ } else if (unit_ == LogicalType::TimeUnit::MICROS) {
+ format::MicroSeconds micros;
+ time_unit.__set_MICROS(micros);
+ } else if (unit_ == LogicalType::TimeUnit::NANOS) {
+ format::NanoSeconds nanos;
+ time_unit.__set_NANOS(nanos);
+ }
+ time_type.__set_isAdjustedToUTC(adjusted_);
+ time_type.__set_unit(time_unit);
+ type.__set_TIME(time_type);
+ return type;
+}
+
+bool LogicalType::Impl::Time::Equals(const LogicalType& other) const {
+ bool eq = false;
+ if (other.is_time()) {
+ const auto& other_time = checked_cast<const TimeLogicalType&>(other);
+ eq =
+ (adjusted_ == other_time.is_adjusted_to_utc() && unit_ == other_time.time_unit());
+ }
+ return eq;
+}
+
+std::shared_ptr<const LogicalType> TimeLogicalType::Make(
+ bool is_adjusted_to_utc, LogicalType::TimeUnit::unit time_unit) {
+ if (time_unit == LogicalType::TimeUnit::MILLIS ||
+ time_unit == LogicalType::TimeUnit::MICROS ||
+ time_unit == LogicalType::TimeUnit::NANOS) {
+ auto* logical_type = new TimeLogicalType();
+ logical_type->impl_.reset(new LogicalType::Impl::Time(is_adjusted_to_utc, time_unit));
+ return std::shared_ptr<const LogicalType>(logical_type);
+ } else {
+ throw ParquetException(
+ "TimeUnit must be one of MILLIS, MICROS, or NANOS for Time logical type");
+ }
+}
+
+bool TimeLogicalType::is_adjusted_to_utc() const {
+ return (dynamic_cast<const LogicalType::Impl::Time&>(*impl_)).is_adjusted_to_utc();
+}
+
+LogicalType::TimeUnit::unit TimeLogicalType::time_unit() const {
+ return (dynamic_cast<const LogicalType::Impl::Time&>(*impl_)).time_unit();
+}
+
+class LogicalType::Impl::Timestamp final : public LogicalType::Impl::Compatible,
+ public LogicalType::Impl::SimpleApplicable {
+ public:
+ friend class TimestampLogicalType;
+
+ bool is_serialized() const override;
+ bool is_compatible(ConvertedType::type converted_type,
+ schema::DecimalMetadata converted_decimal_metadata) const override;
+ ConvertedType::type ToConvertedType(
+ schema::DecimalMetadata* out_decimal_metadata) const override;
+ std::string ToString() const override;
+ std::string ToJSON() const override;
+ format::LogicalType ToThrift() const override;
+ bool Equals(const LogicalType& other) const override;
+
+ bool is_adjusted_to_utc() const { return adjusted_; }
+ LogicalType::TimeUnit::unit time_unit() const { return unit_; }
+
+ bool is_from_converted_type() const { return is_from_converted_type_; }
+ bool force_set_converted_type() const { return force_set_converted_type_; }
+
+ private:
+ Timestamp(bool adjusted, LogicalType::TimeUnit::unit unit, bool is_from_converted_type,
+ bool force_set_converted_type)
+ : LogicalType::Impl(LogicalType::Type::TIMESTAMP, SortOrder::SIGNED),
+ LogicalType::Impl::SimpleApplicable(parquet::Type::INT64),
+ adjusted_(adjusted),
+ unit_(unit),
+ is_from_converted_type_(is_from_converted_type),
+ force_set_converted_type_(force_set_converted_type) {}
+ bool adjusted_ = false;
+ LogicalType::TimeUnit::unit unit_;
+ bool is_from_converted_type_ = false;
+ bool force_set_converted_type_ = false;
+};
+
+bool LogicalType::Impl::Timestamp::is_serialized() const {
+ return !is_from_converted_type_;
+}
+
+bool LogicalType::Impl::Timestamp::is_compatible(
+ ConvertedType::type converted_type,
+ schema::DecimalMetadata converted_decimal_metadata) const {
+ if (converted_decimal_metadata.isset) {
+ return false;
+ } else if (unit_ == LogicalType::TimeUnit::MILLIS) {
+ if (adjusted_ || force_set_converted_type_) {
+ return converted_type == ConvertedType::TIMESTAMP_MILLIS;
+ } else {
+ return (converted_type == ConvertedType::NONE) ||
+ (converted_type == ConvertedType::NA);
+ }
+ } else if (unit_ == LogicalType::TimeUnit::MICROS) {
+ if (adjusted_ || force_set_converted_type_) {
+ return converted_type == ConvertedType::TIMESTAMP_MICROS;
+ } else {
+ return (converted_type == ConvertedType::NONE) ||
+ (converted_type == ConvertedType::NA);
+ }
+ } else {
+ return (converted_type == ConvertedType::NONE) ||
+ (converted_type == ConvertedType::NA);
+ }
+}
+
+ConvertedType::type LogicalType::Impl::Timestamp::ToConvertedType(
+ schema::DecimalMetadata* out_decimal_metadata) const {
+ reset_decimal_metadata(out_decimal_metadata);
+ if (adjusted_ || force_set_converted_type_) {
+ if (unit_ == LogicalType::TimeUnit::MILLIS) {
+ return ConvertedType::TIMESTAMP_MILLIS;
+ } else if (unit_ == LogicalType::TimeUnit::MICROS) {
+ return ConvertedType::TIMESTAMP_MICROS;
+ }
+ }
+ return ConvertedType::NONE;
+}
+
+std::string LogicalType::Impl::Timestamp::ToString() const {
+ std::stringstream type;
+ type << "Timestamp(isAdjustedToUTC=" << std::boolalpha << adjusted_
+ << ", timeUnit=" << time_unit_string(unit_)
+ << ", is_from_converted_type=" << is_from_converted_type_
+ << ", force_set_converted_type=" << force_set_converted_type_ << ")";
+ return type.str();
+}
+
+std::string LogicalType::Impl::Timestamp::ToJSON() const {
+ std::stringstream json;
+ json << R"({"Type": "Timestamp", "isAdjustedToUTC": )" << std::boolalpha << adjusted_
+ << R"(, "timeUnit": ")" << time_unit_string(unit_) << R"(")"
+ << R"(, "is_from_converted_type": )" << is_from_converted_type_
+ << R"(, "force_set_converted_type": )" << force_set_converted_type_ << R"(})";
+ return json.str();
+}
+
+format::LogicalType LogicalType::Impl::Timestamp::ToThrift() const {
+ format::LogicalType type;
+ format::TimestampType timestamp_type;
+ format::TimeUnit time_unit;
+ DCHECK(unit_ != LogicalType::TimeUnit::UNKNOWN);
+ if (unit_ == LogicalType::TimeUnit::MILLIS) {
+ format::MilliSeconds millis;
+ time_unit.__set_MILLIS(millis);
+ } else if (unit_ == LogicalType::TimeUnit::MICROS) {
+ format::MicroSeconds micros;
+ time_unit.__set_MICROS(micros);
+ } else if (unit_ == LogicalType::TimeUnit::NANOS) {
+ format::NanoSeconds nanos;
+ time_unit.__set_NANOS(nanos);
+ }
+ timestamp_type.__set_isAdjustedToUTC(adjusted_);
+ timestamp_type.__set_unit(time_unit);
+ type.__set_TIMESTAMP(timestamp_type);
+ return type;
+}
+
+bool LogicalType::Impl::Timestamp::Equals(const LogicalType& other) const {
+ bool eq = false;
+ if (other.is_timestamp()) {
+ const auto& other_timestamp = checked_cast<const TimestampLogicalType&>(other);
+ eq = (adjusted_ == other_timestamp.is_adjusted_to_utc() &&
+ unit_ == other_timestamp.time_unit());
+ }
+ return eq;
+}
+
+std::shared_ptr<const LogicalType> TimestampLogicalType::Make(
+ bool is_adjusted_to_utc, LogicalType::TimeUnit::unit time_unit,
+ bool is_from_converted_type, bool force_set_converted_type) {
+ if (time_unit == LogicalType::TimeUnit::MILLIS ||
+ time_unit == LogicalType::TimeUnit::MICROS ||
+ time_unit == LogicalType::TimeUnit::NANOS) {
+ auto* logical_type = new TimestampLogicalType();
+ logical_type->impl_.reset(new LogicalType::Impl::Timestamp(
+ is_adjusted_to_utc, time_unit, is_from_converted_type, force_set_converted_type));
+ return std::shared_ptr<const LogicalType>(logical_type);
+ } else {
+ throw ParquetException(
+ "TimeUnit must be one of MILLIS, MICROS, or NANOS for Timestamp logical type");
+ }
+}
+
+bool TimestampLogicalType::is_adjusted_to_utc() const {
+ return (dynamic_cast<const LogicalType::Impl::Timestamp&>(*impl_)).is_adjusted_to_utc();
+}
+
+LogicalType::TimeUnit::unit TimestampLogicalType::time_unit() const {
+ return (dynamic_cast<const LogicalType::Impl::Timestamp&>(*impl_)).time_unit();
+}
+
+bool TimestampLogicalType::is_from_converted_type() const {
+ return (dynamic_cast<const LogicalType::Impl::Timestamp&>(*impl_))
+ .is_from_converted_type();
+}
+
+bool TimestampLogicalType::force_set_converted_type() const {
+ return (dynamic_cast<const LogicalType::Impl::Timestamp&>(*impl_))
+ .force_set_converted_type();
+}
+
+class LogicalType::Impl::Interval final : public LogicalType::Impl::SimpleCompatible,
+ public LogicalType::Impl::TypeLengthApplicable {
+ public:
+ friend class IntervalLogicalType;
+
+ OVERRIDE_TOSTRING(Interval)
+ // TODO(tpboudreau): uncomment the following line to enable serialization after
+ // parquet.thrift recognizes IntervalType as a ConvertedType
+ // OVERRIDE_TOTHRIFT(IntervalType, INTERVAL)
+
+ private:
+ Interval()
+ : LogicalType::Impl(LogicalType::Type::INTERVAL, SortOrder::UNKNOWN),
+ LogicalType::Impl::SimpleCompatible(ConvertedType::INTERVAL),
+ LogicalType::Impl::TypeLengthApplicable(parquet::Type::FIXED_LEN_BYTE_ARRAY, 12) {
+ }
+};
+
+GENERATE_MAKE(Interval)
+
+class LogicalType::Impl::Int final : public LogicalType::Impl::Compatible,
+ public LogicalType::Impl::Applicable {
+ public:
+ friend class IntLogicalType;
+
+ bool is_applicable(parquet::Type::type primitive_type,
+ int32_t primitive_length = -1) const override;
+ bool is_compatible(ConvertedType::type converted_type,
+ schema::DecimalMetadata converted_decimal_metadata) const override;
+ ConvertedType::type ToConvertedType(
+ schema::DecimalMetadata* out_decimal_metadata) const override;
+ std::string ToString() const override;
+ std::string ToJSON() const override;
+ format::LogicalType ToThrift() const override;
+ bool Equals(const LogicalType& other) const override;
+
+ int bit_width() const { return width_; }
+ bool is_signed() const { return signed_; }
+
+ private:
+ Int(int w, bool s)
+ : LogicalType::Impl(LogicalType::Type::INT,
+ (s ? SortOrder::SIGNED : SortOrder::UNSIGNED)),
+ width_(w),
+ signed_(s) {}
+ int width_ = 0;
+ bool signed_ = false;
+};
+
+bool LogicalType::Impl::Int::is_applicable(parquet::Type::type primitive_type,
+ int32_t primitive_length) const {
+ return (primitive_type == parquet::Type::INT32 && width_ <= 32) ||
+ (primitive_type == parquet::Type::INT64 && width_ == 64);
+}
+
+bool LogicalType::Impl::Int::is_compatible(
+ ConvertedType::type converted_type,
+ schema::DecimalMetadata converted_decimal_metadata) const {
+ if (converted_decimal_metadata.isset) {
+ return false;
+ } else if (signed_ && width_ == 8) {
+ return converted_type == ConvertedType::INT_8;
+ } else if (signed_ && width_ == 16) {
+ return converted_type == ConvertedType::INT_16;
+ } else if (signed_ && width_ == 32) {
+ return converted_type == ConvertedType::INT_32;
+ } else if (signed_ && width_ == 64) {
+ return converted_type == ConvertedType::INT_64;
+ } else if (!signed_ && width_ == 8) {
+ return converted_type == ConvertedType::UINT_8;
+ } else if (!signed_ && width_ == 16) {
+ return converted_type == ConvertedType::UINT_16;
+ } else if (!signed_ && width_ == 32) {
+ return converted_type == ConvertedType::UINT_32;
+ } else if (!signed_ && width_ == 64) {
+ return converted_type == ConvertedType::UINT_64;
+ } else {
+ return false;
+ }
+}
+
+ConvertedType::type LogicalType::Impl::Int::ToConvertedType(
+ schema::DecimalMetadata* out_decimal_metadata) const {
+ reset_decimal_metadata(out_decimal_metadata);
+ if (signed_) {
+ switch (width_) {
+ case 8:
+ return ConvertedType::INT_8;
+ case 16:
+ return ConvertedType::INT_16;
+ case 32:
+ return ConvertedType::INT_32;
+ case 64:
+ return ConvertedType::INT_64;
+ }
+ } else { // unsigned
+ switch (width_) {
+ case 8:
+ return ConvertedType::UINT_8;
+ case 16:
+ return ConvertedType::UINT_16;
+ case 32:
+ return ConvertedType::UINT_32;
+ case 64:
+ return ConvertedType::UINT_64;
+ }
+ }
+ return ConvertedType::NONE;
+}
+
+std::string LogicalType::Impl::Int::ToString() const {
+ std::stringstream type;
+ type << "Int(bitWidth=" << width_ << ", isSigned=" << std::boolalpha << signed_ << ")";
+ return type.str();
+}
+
+std::string LogicalType::Impl::Int::ToJSON() const {
+ std::stringstream json;
+ json << R"({"Type": "Int", "bitWidth": )" << width_ << R"(, "isSigned": )"
+ << std::boolalpha << signed_ << "}";
+ return json.str();
+}
+
+format::LogicalType LogicalType::Impl::Int::ToThrift() const {
+ format::LogicalType type;
+ format::IntType int_type;
+ DCHECK(width_ == 64 || width_ == 32 || width_ == 16 || width_ == 8);
+ int_type.__set_bitWidth(static_cast<int8_t>(width_));
+ int_type.__set_isSigned(signed_);
+ type.__set_INTEGER(int_type);
+ return type;
+}
+
+bool LogicalType::Impl::Int::Equals(const LogicalType& other) const {
+ bool eq = false;
+ if (other.is_int()) {
+ const auto& other_int = checked_cast<const IntLogicalType&>(other);
+ eq = (width_ == other_int.bit_width() && signed_ == other_int.is_signed());
+ }
+ return eq;
+}
+
+std::shared_ptr<const LogicalType> IntLogicalType::Make(int bit_width, bool is_signed) {
+ if (bit_width == 8 || bit_width == 16 || bit_width == 32 || bit_width == 64) {
+ auto* logical_type = new IntLogicalType();
+ logical_type->impl_.reset(new LogicalType::Impl::Int(bit_width, is_signed));
+ return std::shared_ptr<const LogicalType>(logical_type);
+ } else {
+ throw ParquetException(
+ "Bit width must be exactly 8, 16, 32, or 64 for Int logical type");
+ }
+}
+
+int IntLogicalType::bit_width() const {
+ return (dynamic_cast<const LogicalType::Impl::Int&>(*impl_)).bit_width();
+}
+
+bool IntLogicalType::is_signed() const {
+ return (dynamic_cast<const LogicalType::Impl::Int&>(*impl_)).is_signed();
+}
+
+class LogicalType::Impl::Null final : public LogicalType::Impl::Incompatible,
+ public LogicalType::Impl::UniversalApplicable {
+ public:
+ friend class NullLogicalType;
+
+ OVERRIDE_TOSTRING(Null)
+ OVERRIDE_TOTHRIFT(NullType, UNKNOWN)
+
+ private:
+ Null() : LogicalType::Impl(LogicalType::Type::NIL, SortOrder::UNKNOWN) {}
+};
+
+GENERATE_MAKE(Null)
+
+class LogicalType::Impl::JSON final : public LogicalType::Impl::SimpleCompatible,
+ public LogicalType::Impl::SimpleApplicable {
+ public:
+ friend class JSONLogicalType;
+
+ OVERRIDE_TOSTRING(JSON)
+ OVERRIDE_TOTHRIFT(JsonType, JSON)
+
+ private:
+ JSON()
+ : LogicalType::Impl(LogicalType::Type::JSON, SortOrder::UNSIGNED),
+ LogicalType::Impl::SimpleCompatible(ConvertedType::JSON),
+ LogicalType::Impl::SimpleApplicable(parquet::Type::BYTE_ARRAY) {}
+};
+
+GENERATE_MAKE(JSON)
+
+class LogicalType::Impl::BSON final : public LogicalType::Impl::SimpleCompatible,
+ public LogicalType::Impl::SimpleApplicable {
+ public:
+ friend class BSONLogicalType;
+
+ OVERRIDE_TOSTRING(BSON)
+ OVERRIDE_TOTHRIFT(BsonType, BSON)
+
+ private:
+ BSON()
+ : LogicalType::Impl(LogicalType::Type::BSON, SortOrder::UNSIGNED),
+ LogicalType::Impl::SimpleCompatible(ConvertedType::BSON),
+ LogicalType::Impl::SimpleApplicable(parquet::Type::BYTE_ARRAY) {}
+};
+
+GENERATE_MAKE(BSON)
+
+class LogicalType::Impl::UUID final : public LogicalType::Impl::Incompatible,
+ public LogicalType::Impl::TypeLengthApplicable {
+ public:
+ friend class UUIDLogicalType;
+
+ OVERRIDE_TOSTRING(UUID)
+ OVERRIDE_TOTHRIFT(UUIDType, UUID)
+
+ private:
+ UUID()
+ : LogicalType::Impl(LogicalType::Type::UUID, SortOrder::UNSIGNED),
+ LogicalType::Impl::TypeLengthApplicable(parquet::Type::FIXED_LEN_BYTE_ARRAY, 16) {
+ }
+};
+
+GENERATE_MAKE(UUID)
+
+class LogicalType::Impl::No final : public LogicalType::Impl::SimpleCompatible,
+ public LogicalType::Impl::UniversalApplicable {
+ public:
+ friend class NoLogicalType;
+
+ OVERRIDE_TOSTRING(None)
+
+ private:
+ No()
+ : LogicalType::Impl(LogicalType::Type::NONE, SortOrder::UNKNOWN),
+ LogicalType::Impl::SimpleCompatible(ConvertedType::NONE) {}
+};
+
+GENERATE_MAKE(No)
+
+class LogicalType::Impl::Undefined final : public LogicalType::Impl::SimpleCompatible,
+ public LogicalType::Impl::UniversalApplicable {
+ public:
+ friend class UndefinedLogicalType;
+
+ OVERRIDE_TOSTRING(Undefined)
+
+ private:
+ Undefined()
+ : LogicalType::Impl(LogicalType::Type::UNDEFINED, SortOrder::UNKNOWN),
+ LogicalType::Impl::SimpleCompatible(ConvertedType::UNDEFINED) {}
+};
+
+GENERATE_MAKE(Undefined)
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/types.h b/src/arrow/cpp/src/parquet/types.h
new file mode 100644
index 000000000..505a6c5cb
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/types.h
@@ -0,0 +1,766 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include "arrow/util/string_view.h"
+
+#include "parquet/platform.h"
+#include "parquet/type_fwd.h"
+
+#ifdef _WIN32
+
+// Repetition::OPTIONAL conflicts with a #define, so we undefine it
+#ifdef OPTIONAL
+#undef OPTIONAL
+#endif
+
+#endif // _WIN32
+
+namespace arrow {
+namespace util {
+
+class Codec;
+
+} // namespace util
+} // namespace arrow
+
+namespace parquet {
+
+// ----------------------------------------------------------------------
+// Metadata enums to match Thrift metadata
+//
+// The reason we maintain our own enums is to avoid transitive dependency on
+// the compiled Thrift headers (and thus thrift/Thrift.h) for users of the
+// public API. After building parquet-cpp, you should not need to include
+// Thrift headers in your application. This means some boilerplate to convert
+// between our types and Parquet's Thrift types.
+//
+// We can also add special values like NONE to distinguish between metadata
+// values being set and not set. As an example consider ConvertedType and
+// CompressionCodec
+
+// Mirrors parquet::Type
+struct Type {
+ enum type {
+ BOOLEAN = 0,
+ INT32 = 1,
+ INT64 = 2,
+ INT96 = 3,
+ FLOAT = 4,
+ DOUBLE = 5,
+ BYTE_ARRAY = 6,
+ FIXED_LEN_BYTE_ARRAY = 7,
+ // Should always be last element.
+ UNDEFINED = 8
+ };
+};
+
+// Mirrors parquet::ConvertedType
+struct ConvertedType {
+ enum type {
+ NONE, // Not a real converted type, but means no converted type is specified
+ UTF8,
+ MAP,
+ MAP_KEY_VALUE,
+ LIST,
+ ENUM,
+ DECIMAL,
+ DATE,
+ TIME_MILLIS,
+ TIME_MICROS,
+ TIMESTAMP_MILLIS,
+ TIMESTAMP_MICROS,
+ UINT_8,
+ UINT_16,
+ UINT_32,
+ UINT_64,
+ INT_8,
+ INT_16,
+ INT_32,
+ INT_64,
+ JSON,
+ BSON,
+ INTERVAL,
+ // DEPRECATED INVALID ConvertedType for all-null data.
+ // Only useful for reading legacy files written out by interim Parquet C++ releases.
+ // For writing, always emit LogicalType::Null instead.
+ // See PARQUET-1990.
+ NA = 25,
+ UNDEFINED = 26 // Not a real converted type; should always be last element
+ };
+};
+
+// forward declaration
+namespace format {
+
+class LogicalType;
+
+}
+
+// Mirrors parquet::FieldRepetitionType
+struct Repetition {
+ enum type { REQUIRED = 0, OPTIONAL = 1, REPEATED = 2, /*Always last*/ UNDEFINED = 3 };
+};
+
+// Reference:
+// parquet-mr/parquet-hadoop/src/main/java/org/apache/parquet/
+// format/converter/ParquetMetadataConverter.java
+// Sort order for page and column statistics. Types are associated with sort
+// orders (e.g., UTF8 columns should use UNSIGNED) and column stats are
+// aggregated using a sort order. As of parquet-format version 2.3.1, the
+// order used to aggregate stats is always SIGNED and is not stored in the
+// Parquet file. These stats are discarded for types that need unsigned.
+// See PARQUET-686.
+struct SortOrder {
+ enum type { SIGNED, UNSIGNED, UNKNOWN };
+};
+
+namespace schema {
+
+struct DecimalMetadata {
+ bool isset;
+ int32_t scale;
+ int32_t precision;
+};
+
+} // namespace schema
+
+/// \brief Implementation of parquet.thrift LogicalType types.
+class PARQUET_EXPORT LogicalType {
+ public:
+ struct Type {
+ enum type {
+ UNDEFINED = 0, // Not a real logical type
+ STRING = 1,
+ MAP,
+ LIST,
+ ENUM,
+ DECIMAL,
+ DATE,
+ TIME,
+ TIMESTAMP,
+ INTERVAL,
+ INT,
+ NIL, // Thrift NullType: annotates data that is always null
+ JSON,
+ BSON,
+ UUID,
+ NONE // Not a real logical type; should always be last element
+ };
+ };
+
+ struct TimeUnit {
+ enum unit { UNKNOWN = 0, MILLIS = 1, MICROS, NANOS };
+ };
+
+ /// \brief If possible, return a logical type equivalent to the given legacy
+ /// converted type (and decimal metadata if applicable).
+ static std::shared_ptr<const LogicalType> FromConvertedType(
+ const parquet::ConvertedType::type converted_type,
+ const parquet::schema::DecimalMetadata converted_decimal_metadata = {false, -1,
+ -1});
+
+ /// \brief Return the logical type represented by the Thrift intermediary object.
+ static std::shared_ptr<const LogicalType> FromThrift(
+ const parquet::format::LogicalType& thrift_logical_type);
+
+ /// \brief Return the explicitly requested logical type.
+ static std::shared_ptr<const LogicalType> String();
+ static std::shared_ptr<const LogicalType> Map();
+ static std::shared_ptr<const LogicalType> List();
+ static std::shared_ptr<const LogicalType> Enum();
+ static std::shared_ptr<const LogicalType> Decimal(int32_t precision, int32_t scale = 0);
+ static std::shared_ptr<const LogicalType> Date();
+ static std::shared_ptr<const LogicalType> Time(bool is_adjusted_to_utc,
+ LogicalType::TimeUnit::unit time_unit);
+
+ /// \brief Create a Timestamp logical type
+ /// \param[in] is_adjusted_to_utc set true if the data is UTC-normalized
+ /// \param[in] time_unit the resolution of the timestamp
+ /// \param[in] is_from_converted_type if true, the timestamp was generated
+ /// by translating a legacy converted type of TIMESTAMP_MILLIS or
+ /// TIMESTAMP_MICROS. Default is false.
+ /// \param[in] force_set_converted_type if true, always set the
+ /// legacy ConvertedType TIMESTAMP_MICROS and TIMESTAMP_MILLIS
+ /// metadata. Default is false
+ static std::shared_ptr<const LogicalType> Timestamp(
+ bool is_adjusted_to_utc, LogicalType::TimeUnit::unit time_unit,
+ bool is_from_converted_type = false, bool force_set_converted_type = false);
+
+ static std::shared_ptr<const LogicalType> Interval();
+ static std::shared_ptr<const LogicalType> Int(int bit_width, bool is_signed);
+
+ /// \brief Create a logical type for data that's always null
+ ///
+ /// Any physical type can be annotated with this logical type.
+ static std::shared_ptr<const LogicalType> Null();
+
+ static std::shared_ptr<const LogicalType> JSON();
+ static std::shared_ptr<const LogicalType> BSON();
+ static std::shared_ptr<const LogicalType> UUID();
+
+ /// \brief Create a placeholder for when no logical type is specified
+ static std::shared_ptr<const LogicalType> None();
+
+ /// \brief Return true if this logical type is consistent with the given underlying
+ /// physical type.
+ bool is_applicable(parquet::Type::type primitive_type,
+ int32_t primitive_length = -1) const;
+
+ /// \brief Return true if this logical type is equivalent to the given legacy converted
+ /// type (and decimal metadata if applicable).
+ bool is_compatible(parquet::ConvertedType::type converted_type,
+ parquet::schema::DecimalMetadata converted_decimal_metadata = {
+ false, -1, -1}) const;
+
+ /// \brief If possible, return the legacy converted type (and decimal metadata if
+ /// applicable) equivalent to this logical type.
+ parquet::ConvertedType::type ToConvertedType(
+ parquet::schema::DecimalMetadata* out_decimal_metadata) const;
+
+ /// \brief Return a printable representation of this logical type.
+ std::string ToString() const;
+
+ /// \brief Return a JSON representation of this logical type.
+ std::string ToJSON() const;
+
+ /// \brief Return a serializable Thrift object for this logical type.
+ parquet::format::LogicalType ToThrift() const;
+
+ /// \brief Return true if the given logical type is equivalent to this logical type.
+ bool Equals(const LogicalType& other) const;
+
+ /// \brief Return the enumerated type of this logical type.
+ LogicalType::Type::type type() const;
+
+ /// \brief Return the appropriate sort order for this logical type.
+ SortOrder::type sort_order() const;
+
+ // Type checks ...
+ bool is_string() const;
+ bool is_map() const;
+ bool is_list() const;
+ bool is_enum() const;
+ bool is_decimal() const;
+ bool is_date() const;
+ bool is_time() const;
+ bool is_timestamp() const;
+ bool is_interval() const;
+ bool is_int() const;
+ bool is_null() const;
+ bool is_JSON() const;
+ bool is_BSON() const;
+ bool is_UUID() const;
+ bool is_none() const;
+ /// \brief Return true if this logical type is of a known type.
+ bool is_valid() const;
+ bool is_invalid() const;
+ /// \brief Return true if this logical type is suitable for a schema GroupNode.
+ bool is_nested() const;
+ bool is_nonnested() const;
+ /// \brief Return true if this logical type is included in the Thrift output for its
+ /// node.
+ bool is_serialized() const;
+
+ LogicalType(const LogicalType&) = delete;
+ LogicalType& operator=(const LogicalType&) = delete;
+ virtual ~LogicalType() noexcept;
+
+ protected:
+ LogicalType();
+
+ class Impl;
+ std::unique_ptr<const Impl> impl_;
+};
+
+/// \brief Allowed for physical type BYTE_ARRAY, must be encoded as UTF-8.
+class PARQUET_EXPORT StringLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make();
+
+ private:
+ StringLogicalType() = default;
+};
+
+/// \brief Allowed for group nodes only.
+class PARQUET_EXPORT MapLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make();
+
+ private:
+ MapLogicalType() = default;
+};
+
+/// \brief Allowed for group nodes only.
+class PARQUET_EXPORT ListLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make();
+
+ private:
+ ListLogicalType() = default;
+};
+
+/// \brief Allowed for physical type BYTE_ARRAY, must be encoded as UTF-8.
+class PARQUET_EXPORT EnumLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make();
+
+ private:
+ EnumLogicalType() = default;
+};
+
+/// \brief Allowed for physical type INT32, INT64, FIXED_LEN_BYTE_ARRAY, or BYTE_ARRAY,
+/// depending on the precision.
+class PARQUET_EXPORT DecimalLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make(int32_t precision, int32_t scale = 0);
+ int32_t precision() const;
+ int32_t scale() const;
+
+ private:
+ DecimalLogicalType() = default;
+};
+
+/// \brief Allowed for physical type INT32.
+class PARQUET_EXPORT DateLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make();
+
+ private:
+ DateLogicalType() = default;
+};
+
+/// \brief Allowed for physical type INT32 (for MILLIS) or INT64 (for MICROS and NANOS).
+class PARQUET_EXPORT TimeLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make(bool is_adjusted_to_utc,
+ LogicalType::TimeUnit::unit time_unit);
+ bool is_adjusted_to_utc() const;
+ LogicalType::TimeUnit::unit time_unit() const;
+
+ private:
+ TimeLogicalType() = default;
+};
+
+/// \brief Allowed for physical type INT64.
+class PARQUET_EXPORT TimestampLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make(bool is_adjusted_to_utc,
+ LogicalType::TimeUnit::unit time_unit,
+ bool is_from_converted_type = false,
+ bool force_set_converted_type = false);
+ bool is_adjusted_to_utc() const;
+ LogicalType::TimeUnit::unit time_unit() const;
+
+ /// \brief If true, will not set LogicalType in Thrift metadata
+ bool is_from_converted_type() const;
+
+ /// \brief If true, will set ConvertedType for micros and millis
+ /// resolution in legacy ConvertedType Thrift metadata
+ bool force_set_converted_type() const;
+
+ private:
+ TimestampLogicalType() = default;
+};
+
+/// \brief Allowed for physical type FIXED_LEN_BYTE_ARRAY with length 12
+class PARQUET_EXPORT IntervalLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make();
+
+ private:
+ IntervalLogicalType() = default;
+};
+
+/// \brief Allowed for physical type INT32 (for bit widths 8, 16, and 32) and INT64
+/// (for bit width 64).
+class PARQUET_EXPORT IntLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make(int bit_width, bool is_signed);
+ int bit_width() const;
+ bool is_signed() const;
+
+ private:
+ IntLogicalType() = default;
+};
+
+/// \brief Allowed for any physical type.
+class PARQUET_EXPORT NullLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make();
+
+ private:
+ NullLogicalType() = default;
+};
+
+/// \brief Allowed for physical type BYTE_ARRAY.
+class PARQUET_EXPORT JSONLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make();
+
+ private:
+ JSONLogicalType() = default;
+};
+
+/// \brief Allowed for physical type BYTE_ARRAY.
+class PARQUET_EXPORT BSONLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make();
+
+ private:
+ BSONLogicalType() = default;
+};
+
+/// \brief Allowed for physical type FIXED_LEN_BYTE_ARRAY with length 16,
+/// must encode raw UUID bytes.
+class PARQUET_EXPORT UUIDLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make();
+
+ private:
+ UUIDLogicalType() = default;
+};
+
+/// \brief Allowed for any physical type.
+class PARQUET_EXPORT NoLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make();
+
+ private:
+ NoLogicalType() = default;
+};
+
+// Internal API, for unrecognized logical types
+class PARQUET_EXPORT UndefinedLogicalType : public LogicalType {
+ public:
+ static std::shared_ptr<const LogicalType> Make();
+
+ private:
+ UndefinedLogicalType() = default;
+};
+
+// Data encodings. Mirrors parquet::Encoding
+struct Encoding {
+ enum type {
+ PLAIN = 0,
+ PLAIN_DICTIONARY = 2,
+ RLE = 3,
+ BIT_PACKED = 4,
+ DELTA_BINARY_PACKED = 5,
+ DELTA_LENGTH_BYTE_ARRAY = 6,
+ DELTA_BYTE_ARRAY = 7,
+ RLE_DICTIONARY = 8,
+ BYTE_STREAM_SPLIT = 9,
+ // Should always be last element (except UNKNOWN)
+ UNDEFINED = 10,
+ UNKNOWN = 999
+ };
+};
+
+// Exposed data encodings. It is the encoding of the data read from the file,
+// rather than the encoding of the data in the file. E.g., the data encoded as
+// RLE_DICTIONARY in the file can be read as dictionary indices by RLE
+// decoding, in which case the data read from the file is DICTIONARY encoded.
+enum class ExposedEncoding {
+ NO_ENCODING = 0, // data is not encoded, i.e. already decoded during reading
+ DICTIONARY = 1
+};
+
+/// \brief Return true if Parquet supports indicated compression type
+PARQUET_EXPORT
+bool IsCodecSupported(Compression::type codec);
+
+PARQUET_EXPORT
+std::unique_ptr<Codec> GetCodec(Compression::type codec);
+
+PARQUET_EXPORT
+std::unique_ptr<Codec> GetCodec(Compression::type codec, int compression_level);
+
+struct ParquetCipher {
+ enum type { AES_GCM_V1 = 0, AES_GCM_CTR_V1 = 1 };
+};
+
+struct AadMetadata {
+ std::string aad_prefix;
+ std::string aad_file_unique;
+ bool supply_aad_prefix;
+};
+
+struct EncryptionAlgorithm {
+ ParquetCipher::type algorithm;
+ AadMetadata aad;
+};
+
+// parquet::PageType
+struct PageType {
+ enum type {
+ DATA_PAGE,
+ INDEX_PAGE,
+ DICTIONARY_PAGE,
+ DATA_PAGE_V2,
+ // Should always be last element
+ UNDEFINED
+ };
+};
+
+class ColumnOrder {
+ public:
+ enum type { UNDEFINED, TYPE_DEFINED_ORDER };
+ explicit ColumnOrder(ColumnOrder::type column_order) : column_order_(column_order) {}
+ // Default to Type Defined Order
+ ColumnOrder() : column_order_(type::TYPE_DEFINED_ORDER) {}
+ ColumnOrder::type get_order() { return column_order_; }
+
+ static ColumnOrder undefined_;
+ static ColumnOrder type_defined_;
+
+ private:
+ ColumnOrder::type column_order_;
+};
+
+// ----------------------------------------------------------------------
+
+struct ByteArray {
+ ByteArray() : len(0), ptr(NULLPTR) {}
+ ByteArray(uint32_t len, const uint8_t* ptr) : len(len), ptr(ptr) {}
+
+ ByteArray(::arrow::util::string_view view) // NOLINT implicit conversion
+ : ByteArray(static_cast<uint32_t>(view.size()),
+ reinterpret_cast<const uint8_t*>(view.data())) {}
+ uint32_t len;
+ const uint8_t* ptr;
+};
+
+inline bool operator==(const ByteArray& left, const ByteArray& right) {
+ return left.len == right.len &&
+ (left.len == 0 || std::memcmp(left.ptr, right.ptr, left.len) == 0);
+}
+
+inline bool operator!=(const ByteArray& left, const ByteArray& right) {
+ return !(left == right);
+}
+
+struct FixedLenByteArray {
+ FixedLenByteArray() : ptr(NULLPTR) {}
+ explicit FixedLenByteArray(const uint8_t* ptr) : ptr(ptr) {}
+ const uint8_t* ptr;
+};
+
+using FLBA = FixedLenByteArray;
+
+// Julian day at unix epoch.
+//
+// The Julian Day Number (JDN) is the integer assigned to a whole solar day in
+// the Julian day count starting from noon Universal time, with Julian day
+// number 0 assigned to the day starting at noon on Monday, January 1, 4713 BC,
+// proleptic Julian calendar (November 24, 4714 BC, in the proleptic Gregorian
+// calendar),
+constexpr int64_t kJulianToUnixEpochDays = INT64_C(2440588);
+constexpr int64_t kSecondsPerDay = INT64_C(60 * 60 * 24);
+constexpr int64_t kMillisecondsPerDay = kSecondsPerDay * INT64_C(1000);
+constexpr int64_t kMicrosecondsPerDay = kMillisecondsPerDay * INT64_C(1000);
+constexpr int64_t kNanosecondsPerDay = kMicrosecondsPerDay * INT64_C(1000);
+
+MANUALLY_ALIGNED_STRUCT(1) Int96 { uint32_t value[3]; };
+STRUCT_END(Int96, 12);
+
+inline bool operator==(const Int96& left, const Int96& right) {
+ return std::equal(left.value, left.value + 3, right.value);
+}
+
+inline bool operator!=(const Int96& left, const Int96& right) { return !(left == right); }
+
+static inline std::string ByteArrayToString(const ByteArray& a) {
+ return std::string(reinterpret_cast<const char*>(a.ptr), a.len);
+}
+
+static inline void Int96SetNanoSeconds(parquet::Int96& i96, int64_t nanoseconds) {
+ std::memcpy(&i96.value, &nanoseconds, sizeof(nanoseconds));
+}
+
+struct DecodedInt96 {
+ uint64_t days_since_epoch;
+ uint64_t nanoseconds;
+};
+
+static inline DecodedInt96 DecodeInt96Timestamp(const parquet::Int96& i96) {
+ // We do the computations in the unsigned domain to avoid unsigned behaviour
+ // on overflow.
+ DecodedInt96 result;
+ result.days_since_epoch = i96.value[2] - static_cast<uint64_t>(kJulianToUnixEpochDays);
+ result.nanoseconds = 0;
+
+ memcpy(&result.nanoseconds, &i96.value, sizeof(uint64_t));
+ return result;
+}
+
+static inline int64_t Int96GetNanoSeconds(const parquet::Int96& i96) {
+ const auto decoded = DecodeInt96Timestamp(i96);
+ return static_cast<int64_t>(decoded.days_since_epoch * kNanosecondsPerDay +
+ decoded.nanoseconds);
+}
+
+static inline int64_t Int96GetMicroSeconds(const parquet::Int96& i96) {
+ const auto decoded = DecodeInt96Timestamp(i96);
+ uint64_t microseconds = decoded.nanoseconds / static_cast<uint64_t>(1000);
+ return static_cast<int64_t>(decoded.days_since_epoch * kMicrosecondsPerDay +
+ microseconds);
+}
+
+static inline int64_t Int96GetMilliSeconds(const parquet::Int96& i96) {
+ const auto decoded = DecodeInt96Timestamp(i96);
+ uint64_t milliseconds = decoded.nanoseconds / static_cast<uint64_t>(1000000);
+ return static_cast<int64_t>(decoded.days_since_epoch * kMillisecondsPerDay +
+ milliseconds);
+}
+
+static inline int64_t Int96GetSeconds(const parquet::Int96& i96) {
+ const auto decoded = DecodeInt96Timestamp(i96);
+ uint64_t seconds = decoded.nanoseconds / static_cast<uint64_t>(1000000000);
+ return static_cast<int64_t>(decoded.days_since_epoch * kSecondsPerDay + seconds);
+}
+
+static inline std::string Int96ToString(const Int96& a) {
+ std::ostringstream result;
+ std::copy(a.value, a.value + 3, std::ostream_iterator<uint32_t>(result, " "));
+ return result.str();
+}
+
+static inline std::string FixedLenByteArrayToString(const FixedLenByteArray& a, int len) {
+ std::ostringstream result;
+ std::copy(a.ptr, a.ptr + len, std::ostream_iterator<uint32_t>(result, " "));
+ return result.str();
+}
+
+template <Type::type TYPE>
+struct type_traits {};
+
+template <>
+struct type_traits<Type::BOOLEAN> {
+ using value_type = bool;
+
+ static constexpr int value_byte_size = 1;
+ static constexpr const char* printf_code = "d";
+};
+
+template <>
+struct type_traits<Type::INT32> {
+ using value_type = int32_t;
+
+ static constexpr int value_byte_size = 4;
+ static constexpr const char* printf_code = "d";
+};
+
+template <>
+struct type_traits<Type::INT64> {
+ using value_type = int64_t;
+
+ static constexpr int value_byte_size = 8;
+ static constexpr const char* printf_code =
+ (sizeof(long) == 64) ? "ld" : "lld"; // NOLINT: runtime/int
+};
+
+template <>
+struct type_traits<Type::INT96> {
+ using value_type = Int96;
+
+ static constexpr int value_byte_size = 12;
+ static constexpr const char* printf_code = "s";
+};
+
+template <>
+struct type_traits<Type::FLOAT> {
+ using value_type = float;
+
+ static constexpr int value_byte_size = 4;
+ static constexpr const char* printf_code = "f";
+};
+
+template <>
+struct type_traits<Type::DOUBLE> {
+ using value_type = double;
+
+ static constexpr int value_byte_size = 8;
+ static constexpr const char* printf_code = "lf";
+};
+
+template <>
+struct type_traits<Type::BYTE_ARRAY> {
+ using value_type = ByteArray;
+
+ static constexpr int value_byte_size = sizeof(ByteArray);
+ static constexpr const char* printf_code = "s";
+};
+
+template <>
+struct type_traits<Type::FIXED_LEN_BYTE_ARRAY> {
+ using value_type = FixedLenByteArray;
+
+ static constexpr int value_byte_size = sizeof(FixedLenByteArray);
+ static constexpr const char* printf_code = "s";
+};
+
+template <Type::type TYPE>
+struct PhysicalType {
+ using c_type = typename type_traits<TYPE>::value_type;
+ static constexpr Type::type type_num = TYPE;
+};
+
+using BooleanType = PhysicalType<Type::BOOLEAN>;
+using Int32Type = PhysicalType<Type::INT32>;
+using Int64Type = PhysicalType<Type::INT64>;
+using Int96Type = PhysicalType<Type::INT96>;
+using FloatType = PhysicalType<Type::FLOAT>;
+using DoubleType = PhysicalType<Type::DOUBLE>;
+using ByteArrayType = PhysicalType<Type::BYTE_ARRAY>;
+using FLBAType = PhysicalType<Type::FIXED_LEN_BYTE_ARRAY>;
+
+template <typename Type>
+inline std::string format_fwf(int width) {
+ std::stringstream ss;
+ ss << "%-" << width << type_traits<Type::type_num>::printf_code;
+ return ss.str();
+}
+
+PARQUET_EXPORT std::string EncodingToString(Encoding::type t);
+
+PARQUET_EXPORT std::string ConvertedTypeToString(ConvertedType::type t);
+
+PARQUET_EXPORT std::string TypeToString(Type::type t);
+
+PARQUET_EXPORT std::string FormatStatValue(Type::type parquet_type,
+ ::arrow::util::string_view val);
+
+PARQUET_EXPORT int GetTypeByteSize(Type::type t);
+
+PARQUET_EXPORT SortOrder::type DefaultSortOrder(Type::type primitive);
+
+PARQUET_EXPORT SortOrder::type GetSortOrder(ConvertedType::type converted,
+ Type::type primitive);
+
+PARQUET_EXPORT SortOrder::type GetSortOrder(
+ const std::shared_ptr<const LogicalType>& logical_type, Type::type primitive);
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/types_test.cc b/src/arrow/cpp/src/parquet/types_test.cc
new file mode 100644
index 000000000..e0ca7d635
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/types_test.cc
@@ -0,0 +1,172 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <string>
+
+#include "arrow/util/endian.h"
+#include "parquet/types.h"
+
+namespace parquet {
+
+TEST(TestTypeToString, PhysicalTypes) {
+ ASSERT_STREQ("BOOLEAN", TypeToString(Type::BOOLEAN).c_str());
+ ASSERT_STREQ("INT32", TypeToString(Type::INT32).c_str());
+ ASSERT_STREQ("INT64", TypeToString(Type::INT64).c_str());
+ ASSERT_STREQ("INT96", TypeToString(Type::INT96).c_str());
+ ASSERT_STREQ("FLOAT", TypeToString(Type::FLOAT).c_str());
+ ASSERT_STREQ("DOUBLE", TypeToString(Type::DOUBLE).c_str());
+ ASSERT_STREQ("BYTE_ARRAY", TypeToString(Type::BYTE_ARRAY).c_str());
+ ASSERT_STREQ("FIXED_LEN_BYTE_ARRAY", TypeToString(Type::FIXED_LEN_BYTE_ARRAY).c_str());
+}
+
+TEST(TestConvertedTypeToString, ConvertedTypes) {
+ ASSERT_STREQ("NONE", ConvertedTypeToString(ConvertedType::NONE).c_str());
+ ASSERT_STREQ("UTF8", ConvertedTypeToString(ConvertedType::UTF8).c_str());
+ ASSERT_STREQ("MAP", ConvertedTypeToString(ConvertedType::MAP).c_str());
+ ASSERT_STREQ("MAP_KEY_VALUE",
+ ConvertedTypeToString(ConvertedType::MAP_KEY_VALUE).c_str());
+ ASSERT_STREQ("LIST", ConvertedTypeToString(ConvertedType::LIST).c_str());
+ ASSERT_STREQ("ENUM", ConvertedTypeToString(ConvertedType::ENUM).c_str());
+ ASSERT_STREQ("DECIMAL", ConvertedTypeToString(ConvertedType::DECIMAL).c_str());
+ ASSERT_STREQ("DATE", ConvertedTypeToString(ConvertedType::DATE).c_str());
+ ASSERT_STREQ("TIME_MILLIS", ConvertedTypeToString(ConvertedType::TIME_MILLIS).c_str());
+ ASSERT_STREQ("TIME_MICROS", ConvertedTypeToString(ConvertedType::TIME_MICROS).c_str());
+ ASSERT_STREQ("TIMESTAMP_MILLIS",
+ ConvertedTypeToString(ConvertedType::TIMESTAMP_MILLIS).c_str());
+ ASSERT_STREQ("TIMESTAMP_MICROS",
+ ConvertedTypeToString(ConvertedType::TIMESTAMP_MICROS).c_str());
+ ASSERT_STREQ("UINT_8", ConvertedTypeToString(ConvertedType::UINT_8).c_str());
+ ASSERT_STREQ("UINT_16", ConvertedTypeToString(ConvertedType::UINT_16).c_str());
+ ASSERT_STREQ("UINT_32", ConvertedTypeToString(ConvertedType::UINT_32).c_str());
+ ASSERT_STREQ("UINT_64", ConvertedTypeToString(ConvertedType::UINT_64).c_str());
+ ASSERT_STREQ("INT_8", ConvertedTypeToString(ConvertedType::INT_8).c_str());
+ ASSERT_STREQ("INT_16", ConvertedTypeToString(ConvertedType::INT_16).c_str());
+ ASSERT_STREQ("INT_32", ConvertedTypeToString(ConvertedType::INT_32).c_str());
+ ASSERT_STREQ("INT_64", ConvertedTypeToString(ConvertedType::INT_64).c_str());
+ ASSERT_STREQ("JSON", ConvertedTypeToString(ConvertedType::JSON).c_str());
+ ASSERT_STREQ("BSON", ConvertedTypeToString(ConvertedType::BSON).c_str());
+ ASSERT_STREQ("INTERVAL", ConvertedTypeToString(ConvertedType::INTERVAL).c_str());
+}
+
+#ifdef __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#elif defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4996)
+#endif
+
+TEST(TypePrinter, StatisticsTypes) {
+ std::string smin;
+ std::string smax;
+ int32_t int_min = 1024;
+ int32_t int_max = 2048;
+ smin = std::string(reinterpret_cast<char*>(&int_min), sizeof(int32_t));
+ smax = std::string(reinterpret_cast<char*>(&int_max), sizeof(int32_t));
+ ASSERT_STREQ("1024", FormatStatValue(Type::INT32, smin).c_str());
+ ASSERT_STREQ("2048", FormatStatValue(Type::INT32, smax).c_str());
+
+ int64_t int64_min = 10240000000000;
+ int64_t int64_max = 20480000000000;
+ smin = std::string(reinterpret_cast<char*>(&int64_min), sizeof(int64_t));
+ smax = std::string(reinterpret_cast<char*>(&int64_max), sizeof(int64_t));
+ ASSERT_STREQ("10240000000000", FormatStatValue(Type::INT64, smin).c_str());
+ ASSERT_STREQ("20480000000000", FormatStatValue(Type::INT64, smax).c_str());
+
+ float float_min = 1.024f;
+ float float_max = 2.048f;
+ smin = std::string(reinterpret_cast<char*>(&float_min), sizeof(float));
+ smax = std::string(reinterpret_cast<char*>(&float_max), sizeof(float));
+ ASSERT_STREQ("1.024", FormatStatValue(Type::FLOAT, smin).c_str());
+ ASSERT_STREQ("2.048", FormatStatValue(Type::FLOAT, smax).c_str());
+
+ double double_min = 1.0245;
+ double double_max = 2.0489;
+ smin = std::string(reinterpret_cast<char*>(&double_min), sizeof(double));
+ smax = std::string(reinterpret_cast<char*>(&double_max), sizeof(double));
+ ASSERT_STREQ("1.0245", FormatStatValue(Type::DOUBLE, smin).c_str());
+ ASSERT_STREQ("2.0489", FormatStatValue(Type::DOUBLE, smax).c_str());
+
+#if ARROW_LITTLE_ENDIAN
+ Int96 Int96_min = {{1024, 2048, 4096}};
+ Int96 Int96_max = {{2048, 4096, 8192}};
+#else
+ Int96 Int96_min = {{2048, 1024, 4096}};
+ Int96 Int96_max = {{4096, 2048, 8192}};
+#endif
+ smin = std::string(reinterpret_cast<char*>(&Int96_min), sizeof(Int96));
+ smax = std::string(reinterpret_cast<char*>(&Int96_max), sizeof(Int96));
+ ASSERT_STREQ("1024 2048 4096", FormatStatValue(Type::INT96, smin).c_str());
+ ASSERT_STREQ("2048 4096 8192", FormatStatValue(Type::INT96, smax).c_str());
+
+ smin = std::string("abcdef");
+ smax = std::string("ijklmnop");
+ ASSERT_STREQ("abcdef", FormatStatValue(Type::BYTE_ARRAY, smin).c_str());
+ ASSERT_STREQ("ijklmnop", FormatStatValue(Type::BYTE_ARRAY, smax).c_str());
+
+ // PARQUET-1357: FormatStatValue truncates binary statistics on zero character
+ smax.push_back('\0');
+ ASSERT_EQ(smax, FormatStatValue(Type::BYTE_ARRAY, smax));
+
+ smin = std::string("abcdefgh");
+ smax = std::string("ijklmnop");
+ ASSERT_STREQ("abcdefgh", FormatStatValue(Type::FIXED_LEN_BYTE_ARRAY, smin).c_str());
+ ASSERT_STREQ("ijklmnop", FormatStatValue(Type::FIXED_LEN_BYTE_ARRAY, smax).c_str());
+}
+
+TEST(TestInt96Timestamp, Decoding) {
+ auto check = [](int32_t julian_day, uint64_t nanoseconds) {
+#if ARROW_LITTLE_ENDIAN
+ Int96 i96{static_cast<uint32_t>(nanoseconds),
+ static_cast<uint32_t>(nanoseconds >> 32),
+ static_cast<uint32_t>(julian_day)};
+#else
+ Int96 i96{static_cast<uint32_t>(nanoseconds >> 32),
+ static_cast<uint32_t>(nanoseconds), static_cast<uint32_t>(julian_day)};
+#endif
+ // Official formula according to https://github.com/apache/parquet-format/pull/49
+ int64_t expected =
+ (julian_day - 2440588) * (86400LL * 1000 * 1000 * 1000) + nanoseconds;
+ int64_t actual = Int96GetNanoSeconds(i96);
+ ASSERT_EQ(expected, actual);
+ };
+
+ // [2333837, 2547339] is the range of Julian days that can be converted to
+ // 64-bit Unix timestamps.
+ check(2333837, 0);
+ check(2333855, 0);
+ check(2547330, 0);
+ check(2547338, 0);
+ check(2547339, 0);
+
+ check(2547330, 13);
+ check(2547330, 32769);
+ check(2547330, 87654);
+ check(2547330, 0x123456789abcdefULL);
+ check(2547330, 0xfedcba9876543210ULL);
+ check(2547339, 0xffffffffffffffffULL);
+}
+
+#if !(defined(_WIN32) || defined(__CYGWIN__))
+#pragma GCC diagnostic pop
+#elif _MSC_VER
+#pragma warning(pop)
+#endif
+
+} // namespace parquet
diff --git a/src/arrow/cpp/src/parquet/windows_compatibility.h b/src/arrow/cpp/src/parquet/windows_compatibility.h
new file mode 100644
index 000000000..31ca04c8b
--- /dev/null
+++ b/src/arrow/cpp/src/parquet/windows_compatibility.h
@@ -0,0 +1,30 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/util/windows_compatibility.h"
+
+#ifdef _WIN32
+
+// parquet.thrift's OPTIONAL RepetitionType conflicts with a #define from
+// above, so we undefine it
+#ifdef OPTIONAL
+#undef OPTIONAL
+#endif
+
+#endif
diff --git a/src/arrow/cpp/src/plasma/.gitignore b/src/arrow/cpp/src/plasma/.gitignore
new file mode 100644
index 000000000..163b5c56e
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/.gitignore
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+*_generated.h
diff --git a/src/arrow/cpp/src/plasma/CMakeLists.txt b/src/arrow/cpp/src/plasma/CMakeLists.txt
new file mode 100644
index 000000000..46603d6f8
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/CMakeLists.txt
@@ -0,0 +1,235 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+add_custom_target(plasma-all)
+add_custom_target(plasma)
+add_custom_target(plasma-benchmarks)
+add_custom_target(plasma-tests)
+add_dependencies(plasma-all plasma plasma-tests plasma-benchmarks)
+
+# For the moment, Plasma is versioned like Arrow
+set(PLASMA_VERSION "${ARROW_VERSION}")
+
+find_package(Threads)
+
+# The SO version is also the ABI version
+set(PLASMA_SO_VERSION "${ARROW_SO_VERSION}")
+set(PLASMA_FULL_SO_VERSION "${ARROW_FULL_SO_VERSION}")
+
+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-conversion")
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
+
+set(PLASMA_SRCS
+ client.cc
+ common.cc
+ fling.cc
+ io.cc
+ malloc.cc
+ plasma.cc
+ protocol.cc)
+
+set(PLASMA_STORE_SRCS
+ dlmalloc.cc
+ events.cc
+ eviction_policy.cc
+ quota_aware_policy.cc
+ plasma_allocator.cc
+ store.cc
+ thirdparty/ae/ae.c)
+
+set(PLASMA_LINK_LIBS arrow_shared)
+set(PLASMA_STATIC_LINK_LIBS arrow_static)
+
+if(ARROW_CUDA)
+ list(INSERT PLASMA_LINK_LIBS 0 arrow_cuda_shared)
+ list(INSERT PLASMA_STATIC_LINK_LIBS 0 arrow_cuda_static)
+ add_definitions(-DPLASMA_CUDA)
+endif()
+
+if(CXX_LINKER_SUPPORTS_VERSION_SCRIPT)
+ set(PLASMA_SHARED_LINK_FLAGS
+ "-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/symbols.map")
+endif()
+
+add_arrow_lib(plasma
+ CMAKE_PACKAGE_NAME
+ Plasma
+ PKG_CONFIG_NAME
+ plasma
+ SOURCES
+ ${PLASMA_SRCS}
+ OUTPUTS
+ PLASMA_LIBRARIES
+ SHARED_LINK_FLAGS
+ ${PLASMA_SHARED_LINK_FLAGS}
+ SHARED_LINK_LIBS
+ ${PLASMA_LINK_LIBS}
+ STATIC_LINK_LIBS
+ ${PLASMA_STATIC_LINK_LIBS})
+
+add_dependencies(plasma ${PLASMA_LIBRARIES})
+
+foreach(LIB_TARGET ${PLASMA_LIBRARIES})
+ target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_EXPORTING)
+endforeach()
+
+# The optimization flag -O3 is suggested by dlmalloc.c, which is #included in
+# malloc.cc; we set it here regardless of whether we do a debug or release build.
+set_source_files_properties(dlmalloc.cc PROPERTIES COMPILE_FLAGS "-O3")
+
+if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
+ set_property(SOURCE dlmalloc.cc
+ APPEND_STRING
+ PROPERTY COMPILE_FLAGS
+ " -Wno-parentheses-equality \
+-Wno-null-pointer-arithmetic \
+-Wno-shorten-64-to-32 \
+-Wno-unused-macros")
+endif()
+
+if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+ set_property(SOURCE dlmalloc.cc
+ APPEND_STRING
+ PROPERTY COMPILE_FLAGS " -Wno-conversion")
+endif()
+
+list(APPEND PLASMA_EXTERNAL_STORE_SOURCES "external_store.cc" "hash_table_store.cc")
+
+# We use static libraries for the plasma-store-server executable so that it can
+# be copied around and used in different locations.
+add_executable(plasma-store-server ${PLASMA_EXTERNAL_STORE_SOURCES} ${PLASMA_STORE_SRCS})
+target_link_libraries(plasma-store-server ${GFLAGS_LIBRARIES})
+if(ARROW_BUILD_STATIC)
+ target_link_libraries(plasma-store-server plasma_static ${PLASMA_STATIC_LINK_LIBS})
+else()
+ # Fallback to shared libs in the case that static libraries are not build.
+ target_link_libraries(plasma-store-server plasma_shared ${PLASMA_LINK_LIBS})
+endif()
+add_dependencies(plasma plasma-store-server)
+
+if(ARROW_RPATH_ORIGIN)
+ if(APPLE)
+ set(_lib_install_rpath "@loader_path")
+ else()
+ set(_lib_install_rpath "\$ORIGIN")
+ endif()
+ set_target_properties(plasma-store-server PROPERTIES INSTALL_RPATH
+ ${_lib_install_rpath})
+elseif(APPLE)
+ # With OSX and conda, we need to set the correct RPATH so that dependencies
+ # are found. The installed libraries with conda have an RPATH that matches
+ # for executables and libraries lying in $ENV{CONDA_PREFIX}/bin or
+ # $ENV{CONDA_PREFIX}/lib but our test libraries and executables are not
+ # installed there.
+ if(NOT "$ENV{CONDA_PREFIX}" STREQUAL "" AND APPLE)
+ set_target_properties(plasma-store-server
+ PROPERTIES BUILD_WITH_INSTALL_RPATH TRUE
+ INSTALL_RPATH_USE_LINK_PATH TRUE
+ INSTALL_RPATH "$ENV{CONDA_PREFIX}/lib")
+ endif()
+endif()
+
+install(FILES common.h
+ compat.h
+ client.h
+ events.h
+ test_util.h
+ DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/plasma")
+
+# Plasma store
+set_target_properties(plasma-store-server PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE)
+install(TARGETS plasma-store-server ${INSTALL_IS_OPTIONAL}
+ DESTINATION ${CMAKE_INSTALL_BINDIR})
+
+if(ARROW_PLASMA_JAVA_CLIENT)
+ # Plasma java client support
+ find_package(JNI REQUIRED)
+ # add jni support
+ include_directories(${JAVA_INCLUDE_PATH})
+ include_directories(${JAVA_INCLUDE_PATH2})
+ if(JNI_FOUND)
+ message(STATUS "JNI_INCLUDE_DIRS = ${JNI_INCLUDE_DIRS}")
+ message(STATUS "JNI_LIBRARIES = ${JNI_LIBRARIES}")
+ else()
+ message(WARNING "Could not find JNI")
+ endif()
+
+ add_compile_options("-I$ENV{JAVA_HOME}/include/")
+ if(WIN32)
+ add_compile_options("-I$ENV{JAVA_HOME}/include/win32")
+ elseif(APPLE)
+ add_compile_options("-I$ENV{JAVA_HOME}/include/darwin")
+ else() # linux
+ add_compile_options("-I$ENV{JAVA_HOME}/include/linux")
+ endif()
+
+ include_directories("${CMAKE_CURRENT_LIST_DIR}/lib/java")
+
+ file(GLOB PLASMA_LIBRARY_EXT_java_SRC lib/java/*.cc lib/*.cc)
+ add_library(plasma_java SHARED ${PLASMA_LIBRARY_EXT_java_SRC})
+
+ if(APPLE)
+ target_link_libraries(plasma_java
+ plasma_shared
+ ${PLASMA_LINK_LIBS}
+ "-undefined dynamic_lookup"
+ ${PTHREAD_LIBRARY})
+ else(APPLE)
+ target_link_libraries(plasma_java plasma_shared ${PLASMA_LINK_LIBS}
+ ${PTHREAD_LIBRARY})
+ endif(APPLE)
+endif()
+#
+# Unit tests
+#
+
+# Adding unit tests part of the "arrow" portion of the test suite
+function(ADD_PLASMA_TEST REL_TEST_NAME)
+ set(options)
+ set(one_value_args)
+ set(multi_value_args)
+ cmake_parse_arguments(ARG
+ "${options}"
+ "${one_value_args}"
+ "${multi_value_args}"
+ ${ARGN})
+ add_test_case(${REL_TEST_NAME}
+ PREFIX
+ "plasma"
+ LABELS
+ "plasma-tests"
+ ${ARG_UNPARSED_ARGUMENTS})
+endfunction()
+
+if(ARROW_BUILD_SHARED)
+ set(PLASMA_TEST_LIBS plasma_shared ${PLASMA_LINK_LIBS})
+else()
+ set(PLASMA_TEST_LIBS plasma_static ${PLASMA_STATIC_LINK_LIBS})
+endif()
+
+add_plasma_test(test/serialization_tests EXTRA_LINK_LIBS ${PLASMA_TEST_LIBS})
+add_plasma_test(test/client_tests
+ EXTRA_LINK_LIBS
+ ${PLASMA_TEST_LIBS}
+ EXTRA_DEPENDENCIES
+ plasma-store-server)
+add_plasma_test(test/external_store_tests
+ EXTRA_LINK_LIBS
+ ${PLASMA_TEST_LIBS}
+ EXTRA_DEPENDENCIES
+ plasma-store-server)
diff --git a/src/arrow/cpp/src/plasma/PlasmaConfig.cmake.in b/src/arrow/cpp/src/plasma/PlasmaConfig.cmake.in
new file mode 100644
index 000000000..928b50876
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/PlasmaConfig.cmake.in
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This config sets the following variables in your project::
+#
+# PLASMA_STORE_SERVER - path to the found plasma-store-server
+# Plasma_FOUND - true if Plasma found on the system
+#
+# This config sets the following targets in your project::
+#
+# plasma_shared - for linked as shared library if shared library is built
+# plasma_static - for linked as static library if static library is built
+
+@PACKAGE_INIT@
+
+include(CMakeFindDependencyMacro)
+find_dependency(Arrow)
+
+set(PLASMA_STORE_SERVER "@CMAKE_INSTALL_PREFIX@/@CMAKE_INSTALL_BINDIR@/plasma-store-server@CMAKE_EXECUTABLE_SUFFIX@")
+
+# Load targets only once. If we load targets multiple times, CMake reports
+# already existent target error.
+if(NOT (TARGET plasma_shared OR TARGET plasma_static))
+ include("${CMAKE_CURRENT_LIST_DIR}/PlasmaTargets.cmake")
+endif()
diff --git a/src/arrow/cpp/src/plasma/client.cc b/src/arrow/cpp/src/plasma/client.cc
new file mode 100644
index 000000000..260999922
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/client.cc
@@ -0,0 +1,1224 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// PLASMA CLIENT: Client library for using the plasma store and manager
+
+#include "plasma/client.h"
+
+#ifdef _WIN32
+#include <Win32_Interop/win32_types.h>
+#endif
+
+#include <fcntl.h>
+#include <netinet/in.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <strings.h>
+#include <sys/ioctl.h>
+#include <sys/mman.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <time.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <deque>
+#include <mutex>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/util/thread_pool.h"
+
+#include "plasma/common.h"
+#include "plasma/fling.h"
+#include "plasma/io.h"
+#include "plasma/malloc.h"
+#include "plasma/plasma.h"
+#include "plasma/protocol.h"
+
+#ifdef PLASMA_CUDA
+#include "arrow/gpu/cuda_api.h"
+
+using arrow::cuda::CudaBuffer;
+using arrow::cuda::CudaBufferWriter;
+using arrow::cuda::CudaContext;
+using arrow::cuda::CudaDeviceManager;
+#endif
+
+#define XXH_INLINE_ALL 1
+#include "arrow/vendored/xxhash.h"
+
+#define XXH64_DEFAULT_SEED 0
+
+namespace fb = plasma::flatbuf;
+
+namespace plasma {
+
+using fb::MessageType;
+using fb::PlasmaError;
+
+using arrow::MutableBuffer;
+
+typedef struct XXH64_state_s XXH64_state_t;
+
+// Number of threads used for hash computations.
+constexpr int64_t kHashingConcurrency = 8;
+constexpr int64_t kBytesInMB = 1 << 20;
+
+// ----------------------------------------------------------------------
+// GPU support
+
+#ifdef PLASMA_CUDA
+
+namespace {
+
+struct GpuProcessHandle {
+ /// Pointer to CUDA buffer that is backing this GPU object.
+ std::shared_ptr<CudaBuffer> ptr;
+ /// Number of client using this GPU object.
+ int client_count;
+};
+
+// This is necessary as IPC handles can only be mapped once per process.
+// Thus if multiple clients in the same process get the same gpu object,
+// they need to access the same mapped CudaBuffer.
+std::unordered_map<ObjectID, GpuProcessHandle*> gpu_object_map;
+std::mutex gpu_mutex;
+
+// Return a new CudaBuffer pointing to the same data as the GpuProcessHandle,
+// but able to persist after the original IPC-backed buffer is closed
+// (ARROW-5924).
+std::shared_ptr<Buffer> MakeBufferFromGpuProcessHandle(GpuProcessHandle* handle) {
+ return std::make_shared<CudaBuffer>(handle->ptr->address(), handle->ptr->size(),
+ handle->ptr->context());
+}
+
+} // namespace
+
+#endif
+
+// ----------------------------------------------------------------------
+// PlasmaBuffer
+
+/// A Buffer class that automatically releases the backing plasma object
+/// when it goes out of scope. This is returned by Get.
+class ARROW_NO_EXPORT PlasmaBuffer : public Buffer {
+ public:
+ ~PlasmaBuffer();
+
+ PlasmaBuffer(std::shared_ptr<PlasmaClient::Impl> client, const ObjectID& object_id,
+ const std::shared_ptr<Buffer>& buffer)
+ : Buffer(buffer, 0, buffer->size()), client_(client), object_id_(object_id) {
+ if (buffer->is_mutable()) {
+ is_mutable_ = true;
+ }
+ }
+
+ private:
+ std::shared_ptr<PlasmaClient::Impl> client_;
+ ObjectID object_id_;
+};
+
+/// A mutable Buffer class that keeps the backing data alive by keeping a
+/// PlasmaClient shared pointer. This is returned by Create. Release will
+/// be called in the associated Seal call.
+class ARROW_NO_EXPORT PlasmaMutableBuffer : public MutableBuffer {
+ public:
+ PlasmaMutableBuffer(std::shared_ptr<PlasmaClient::Impl> client, uint8_t* mutable_data,
+ int64_t data_size)
+ : MutableBuffer(mutable_data, data_size), client_(client) {}
+
+ private:
+ std::shared_ptr<PlasmaClient::Impl> client_;
+};
+
+// ----------------------------------------------------------------------
+// PlasmaClient::Impl
+
+struct ObjectInUseEntry {
+ /// A count of the number of times this client has called PlasmaClient::Create
+ /// or
+ /// PlasmaClient::Get on this object ID minus the number of calls to
+ /// PlasmaClient::Release.
+ /// When this count reaches zero, we remove the entry from the ObjectsInUse
+ /// and decrement a count in the relevant ClientMmapTableEntry.
+ int count;
+ /// Cached information to read the object.
+ PlasmaObject object;
+ /// A flag representing whether the object has been sealed.
+ bool is_sealed;
+};
+
+class ClientMmapTableEntry {
+ public:
+ ClientMmapTableEntry(int fd, int64_t map_size)
+ : fd_(fd), pointer_(nullptr), length_(0) {
+ // We subtract kMmapRegionsGap from the length that was added
+ // in fake_mmap in malloc.h, to make map_size page-aligned again.
+ length_ = map_size - kMmapRegionsGap;
+ pointer_ = reinterpret_cast<uint8_t*>(
+ mmap(NULL, length_, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0));
+ // TODO(pcm): Don't fail here, instead return a Status.
+ if (pointer_ == MAP_FAILED) {
+ ARROW_LOG(FATAL) << "mmap failed";
+ }
+ close(fd); // Closing this fd has an effect on performance.
+ }
+
+ ~ClientMmapTableEntry() {
+ // At this point it is safe to unmap the memory, as the PlasmaBuffer
+ // keeps the PlasmaClient (and therefore the ClientMmapTableEntry)
+ // alive until it is destroyed.
+ // We don't need to close the associated file, since it has
+ // already been closed in the constructor.
+ int r = munmap(pointer_, length_);
+ if (r != 0) {
+ ARROW_LOG(ERROR) << "munmap returned " << r << ", errno = " << errno;
+ }
+ }
+
+ uint8_t* pointer() { return pointer_; }
+
+ int fd() { return fd_; }
+
+ private:
+ /// The associated file descriptor on the client.
+ int fd_;
+ /// The result of mmap for this file descriptor.
+ uint8_t* pointer_;
+ /// The length of the memory-mapped file.
+ size_t length_;
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ClientMmapTableEntry);
+};
+
+class PlasmaClient::Impl : public std::enable_shared_from_this<PlasmaClient::Impl> {
+ public:
+ Impl();
+ ~Impl();
+
+ // PlasmaClient method implementations
+
+ Status Connect(const std::string& store_socket_name,
+ const std::string& manager_socket_name, int release_delay = 0,
+ int num_retries = -1);
+
+ Status SetClientOptions(const std::string& client_name, int64_t output_memory_quota);
+
+ Status Create(const ObjectID& object_id, int64_t data_size, const uint8_t* metadata,
+ int64_t metadata_size, std::shared_ptr<Buffer>* data, int device_num = 0,
+ bool evict_if_full = true);
+
+ Status CreateAndSeal(const ObjectID& object_id, const std::string& data,
+ const std::string& metadata, bool evict_if_full = true);
+
+ Status CreateAndSealBatch(const std::vector<ObjectID>& object_ids,
+ const std::vector<std::string>& data,
+ const std::vector<std::string>& metadata,
+ bool evict_if_full = true);
+
+ Status Get(const std::vector<ObjectID>& object_ids, int64_t timeout_ms,
+ std::vector<ObjectBuffer>* object_buffers);
+
+ Status Get(const ObjectID* object_ids, int64_t num_objects, int64_t timeout_ms,
+ ObjectBuffer* object_buffers);
+
+ Status Release(const ObjectID& object_id);
+
+ Status Contains(const ObjectID& object_id, bool* has_object);
+
+ Status List(ObjectTable* objects);
+
+ Status Abort(const ObjectID& object_id);
+
+ Status Seal(const ObjectID& object_id);
+
+ Status Delete(const std::vector<ObjectID>& object_ids);
+
+ Status Evict(int64_t num_bytes, int64_t& num_bytes_evicted);
+
+ Status Refresh(const std::vector<ObjectID>& object_ids);
+
+ Status Hash(const ObjectID& object_id, uint8_t* digest);
+
+ Status Subscribe(int* fd);
+
+ Status GetNotification(int fd, ObjectID* object_id, int64_t* data_size,
+ int64_t* metadata_size);
+
+ Status DecodeNotifications(const uint8_t* buffer, std::vector<ObjectID>* object_ids,
+ std::vector<int64_t>* data_sizes,
+ std::vector<int64_t>* metadata_sizes);
+
+ Status Disconnect();
+
+ std::string DebugString();
+
+ bool IsInUse(const ObjectID& object_id);
+
+ int64_t store_capacity() { return store_capacity_; }
+
+ private:
+ /// Check if store_fd has already been received from the store. If yes,
+ /// return it. Otherwise, receive it from the store (see analogous logic
+ /// in store.cc).
+ ///
+ /// \param store_fd File descriptor to fetch from the store.
+ /// \return Client file descriptor corresponding to store_fd.
+ int GetStoreFd(int store_fd);
+
+ /// This is a helper method for marking an object as unused by this client.
+ ///
+ /// \param object_id The object ID we mark unused.
+ /// \return The return status.
+ Status MarkObjectUnused(const ObjectID& object_id);
+
+ /// Common helper for Get() variants
+ Status GetBuffers(const ObjectID* object_ids, int64_t num_objects, int64_t timeout_ms,
+ const std::function<std::shared_ptr<Buffer>(
+ const ObjectID&, const std::shared_ptr<Buffer>&)>& wrap_buffer,
+ ObjectBuffer* object_buffers);
+
+ uint8_t* LookupOrMmap(int fd, int store_fd_val, int64_t map_size);
+
+ uint8_t* LookupMmappedFile(int store_fd_val);
+
+ void IncrementObjectCount(const ObjectID& object_id, PlasmaObject* object,
+ bool is_sealed);
+
+ bool ComputeObjectHashParallel(XXH64_state_t* hash_state, const unsigned char* data,
+ int64_t nbytes);
+
+ uint64_t ComputeObjectHash(const ObjectBuffer& obj_buffer);
+
+ uint64_t ComputeObjectHashCPU(const uint8_t* data, int64_t data_size,
+ const uint8_t* metadata, int64_t metadata_size);
+
+#ifdef PLASMA_CUDA
+ arrow::Result<std::shared_ptr<CudaContext>> GetCudaContext(int device_number);
+#endif
+
+ /// File descriptor of the Unix domain socket that connects to the store.
+ int store_conn_;
+ /// Table of dlmalloc buffer files that have been memory mapped so far. This
+ /// is a hash table mapping a file descriptor to a struct containing the
+ /// address of the corresponding memory-mapped file.
+ std::unordered_map<int, std::unique_ptr<ClientMmapTableEntry>> mmap_table_;
+ /// A hash table of the object IDs that are currently being used by this
+ /// client.
+ std::unordered_map<ObjectID, std::unique_ptr<ObjectInUseEntry>> objects_in_use_;
+ /// The amount of memory available to the Plasma store. The client needs this
+ /// information to make sure that it does not delay in releasing so much
+ /// memory that the store is unable to evict enough objects to free up space.
+ int64_t store_capacity_;
+ /// A hash set to record the ids that users want to delete but still in use.
+ std::unordered_set<ObjectID> deletion_cache_;
+ /// A queue of notification
+ std::deque<std::tuple<ObjectID, int64_t, int64_t>> pending_notification_;
+ /// A mutex which protects this class.
+ std::recursive_mutex client_mutex_;
+};
+
+PlasmaBuffer::~PlasmaBuffer() { ARROW_UNUSED(client_->Release(object_id_)); }
+
+PlasmaClient::Impl::Impl() : store_conn_(0), store_capacity_(0) {}
+
+PlasmaClient::Impl::~Impl() {}
+
+// If the file descriptor fd has been mmapped in this client process before,
+// return the pointer that was returned by mmap, otherwise mmap it and store the
+// pointer in a hash table.
+uint8_t* PlasmaClient::Impl::LookupOrMmap(int fd, int store_fd_val, int64_t map_size) {
+ auto entry = mmap_table_.find(store_fd_val);
+ if (entry != mmap_table_.end()) {
+ return entry->second->pointer();
+ } else {
+ mmap_table_[store_fd_val] =
+ std::unique_ptr<ClientMmapTableEntry>(new ClientMmapTableEntry(fd, map_size));
+ return mmap_table_[store_fd_val]->pointer();
+ }
+}
+
+// Get a pointer to a file that we know has been memory mapped in this client
+// process before.
+uint8_t* PlasmaClient::Impl::LookupMmappedFile(int store_fd_val) {
+ auto entry = mmap_table_.find(store_fd_val);
+ ARROW_CHECK(entry != mmap_table_.end());
+ return entry->second->pointer();
+}
+
+bool PlasmaClient::Impl::IsInUse(const ObjectID& object_id) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ const auto elem = objects_in_use_.find(object_id);
+ return (elem != objects_in_use_.end());
+}
+
+int PlasmaClient::Impl::GetStoreFd(int store_fd) {
+ auto entry = mmap_table_.find(store_fd);
+ if (entry == mmap_table_.end()) {
+ int fd = recv_fd(store_conn_);
+ ARROW_CHECK(fd >= 0) << "recv not successful";
+ return fd;
+ } else {
+ return entry->second->fd();
+ }
+}
+
+void PlasmaClient::Impl::IncrementObjectCount(const ObjectID& object_id,
+ PlasmaObject* object, bool is_sealed) {
+ // Increment the count of the object to track the fact that it is being used.
+ // The corresponding decrement should happen in PlasmaClient::Release.
+ auto elem = objects_in_use_.find(object_id);
+ ObjectInUseEntry* object_entry;
+ if (elem == objects_in_use_.end()) {
+ // Add this object ID to the hash table of object IDs in use. The
+ // corresponding call to free happens in PlasmaClient::Release.
+ objects_in_use_[object_id] =
+ std::unique_ptr<ObjectInUseEntry>(new ObjectInUseEntry());
+ objects_in_use_[object_id]->object = *object;
+ objects_in_use_[object_id]->count = 0;
+ objects_in_use_[object_id]->is_sealed = is_sealed;
+ object_entry = objects_in_use_[object_id].get();
+ } else {
+ object_entry = elem->second.get();
+ ARROW_CHECK(object_entry->count > 0);
+ }
+ // Increment the count of the number of instances of this object that are
+ // being used by this client. The corresponding decrement should happen in
+ // PlasmaClient::Release.
+ object_entry->count += 1;
+}
+
+#ifdef PLASMA_CUDA
+arrow::Result<std::shared_ptr<CudaContext>> PlasmaClient::Impl::GetCudaContext(
+ int device_number) {
+ ARROW_ASSIGN_OR_RAISE(auto manager, CudaDeviceManager::Instance());
+ return manager->GetContext(device_number - 1);
+}
+#endif
+
+Status PlasmaClient::Impl::Create(const ObjectID& object_id, int64_t data_size,
+ const uint8_t* metadata, int64_t metadata_size,
+ std::shared_ptr<Buffer>* data, int device_num,
+ bool evict_if_full) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ ARROW_LOG(DEBUG) << "called plasma_create on conn " << store_conn_ << " with size "
+ << data_size << " and metadata size " << metadata_size;
+ RETURN_NOT_OK(SendCreateRequest(store_conn_, object_id, evict_if_full, data_size,
+ metadata_size, device_num));
+ std::vector<uint8_t> buffer;
+ RETURN_NOT_OK(PlasmaReceive(store_conn_, MessageType::PlasmaCreateReply, &buffer));
+ ObjectID id;
+ PlasmaObject object;
+ int store_fd;
+ int64_t mmap_size;
+ RETURN_NOT_OK(
+ ReadCreateReply(buffer.data(), buffer.size(), &id, &object, &store_fd, &mmap_size));
+ // If the CreateReply included an error, then the store will not send a file
+ // descriptor.
+ if (device_num == 0) {
+ int fd = GetStoreFd(store_fd);
+ ARROW_CHECK(object.data_size == data_size);
+ ARROW_CHECK(object.metadata_size == metadata_size);
+ // The metadata should come right after the data.
+ ARROW_CHECK(object.metadata_offset == object.data_offset + data_size);
+ *data = std::make_shared<PlasmaMutableBuffer>(
+ shared_from_this(), LookupOrMmap(fd, store_fd, mmap_size) + object.data_offset,
+ data_size);
+ // If plasma_create is being called from a transfer, then we will not copy the
+ // metadata here. The metadata will be written along with the data streamed
+ // from the transfer.
+ if (metadata != NULL) {
+ // Copy the metadata to the buffer.
+ memcpy((*data)->mutable_data() + object.data_size, metadata, metadata_size);
+ }
+ } else {
+#ifdef PLASMA_CUDA
+ ARROW_ASSIGN_OR_RAISE(auto context, GetCudaContext(device_num));
+ GpuProcessHandle* handle = new GpuProcessHandle();
+ handle->client_count = 2;
+ ARROW_ASSIGN_OR_RAISE(handle->ptr, context->OpenIpcBuffer(*object.ipc_handle));
+ {
+ std::lock_guard<std::mutex> lock(gpu_mutex);
+ gpu_object_map[object_id] = handle;
+ }
+ if (metadata != NULL) {
+ // Copy the metadata to the buffer.
+ CudaBufferWriter writer(handle->ptr);
+ RETURN_NOT_OK(writer.WriteAt(object.data_size, metadata, metadata_size));
+ }
+ *data = MakeBufferFromGpuProcessHandle(handle);
+#else
+ ARROW_LOG(FATAL) << "Arrow GPU library is not enabled.";
+#endif
+ }
+
+ // Increment the count of the number of instances of this object that this
+ // client is using. A call to PlasmaClient::Release is required to decrement
+ // this count. Cache the reference to the object.
+ IncrementObjectCount(object_id, &object, false);
+ // We increment the count a second time (and the corresponding decrement will
+ // happen in a PlasmaClient::Release call in plasma_seal) so even if the
+ // buffer returned by PlasmaClient::Create goes out of scope, the object does
+ // not get released before the call to PlasmaClient::Seal happens.
+ IncrementObjectCount(object_id, &object, false);
+ return Status::OK();
+}
+
+Status PlasmaClient::Impl::CreateAndSeal(const ObjectID& object_id,
+ const std::string& data,
+ const std::string& metadata,
+ bool evict_if_full) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ ARROW_LOG(DEBUG) << "called CreateAndSeal on conn " << store_conn_;
+ // Compute the object hash.
+ static unsigned char digest[kDigestSize];
+ uint64_t hash = ComputeObjectHashCPU(
+ reinterpret_cast<const uint8_t*>(data.data()), data.size(),
+ reinterpret_cast<const uint8_t*>(metadata.data()), metadata.size());
+ memcpy(&digest[0], &hash, sizeof(hash));
+
+ RETURN_NOT_OK(SendCreateAndSealRequest(store_conn_, object_id, evict_if_full, data,
+ metadata, digest));
+ std::vector<uint8_t> buffer;
+ RETURN_NOT_OK(
+ PlasmaReceive(store_conn_, MessageType::PlasmaCreateAndSealReply, &buffer));
+ RETURN_NOT_OK(ReadCreateAndSealReply(buffer.data(), buffer.size()));
+ return Status::OK();
+}
+
+Status PlasmaClient::Impl::CreateAndSealBatch(const std::vector<ObjectID>& object_ids,
+ const std::vector<std::string>& data,
+ const std::vector<std::string>& metadata,
+ bool evict_if_full) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ ARROW_LOG(DEBUG) << "called CreateAndSealBatch on conn " << store_conn_;
+
+ std::vector<std::string> digests;
+ for (size_t i = 0; i < object_ids.size(); i++) {
+ // Compute the object hash.
+ std::string digest;
+ uint64_t hash = ComputeObjectHashCPU(
+ reinterpret_cast<const uint8_t*>(data.data()), data.size(),
+ reinterpret_cast<const uint8_t*>(metadata.data()), metadata.size());
+ digest.assign(reinterpret_cast<char*>(&hash), sizeof(hash));
+ digests.push_back(digest);
+ }
+
+ RETURN_NOT_OK(SendCreateAndSealBatchRequest(store_conn_, object_ids, evict_if_full,
+ data, metadata, digests));
+ std::vector<uint8_t> buffer;
+ RETURN_NOT_OK(
+ PlasmaReceive(store_conn_, MessageType::PlasmaCreateAndSealBatchReply, &buffer));
+ RETURN_NOT_OK(ReadCreateAndSealBatchReply(buffer.data(), buffer.size()));
+
+ return Status::OK();
+}
+
+Status PlasmaClient::Impl::GetBuffers(
+ const ObjectID* object_ids, int64_t num_objects, int64_t timeout_ms,
+ const std::function<std::shared_ptr<Buffer>(
+ const ObjectID&, const std::shared_ptr<Buffer>&)>& wrap_buffer,
+ ObjectBuffer* object_buffers) {
+ // Fill out the info for the objects that are already in use locally.
+ bool all_present = true;
+ for (int64_t i = 0; i < num_objects; ++i) {
+ auto object_entry = objects_in_use_.find(object_ids[i]);
+ if (object_entry == objects_in_use_.end()) {
+ // This object is not currently in use by this client, so we need to send
+ // a request to the store.
+ all_present = false;
+ } else if (!object_entry->second->is_sealed) {
+ // This client created the object but hasn't sealed it. If we call Get
+ // with no timeout, we will deadlock, because this client won't be able to
+ // call Seal.
+ ARROW_CHECK(timeout_ms != -1)
+ << "Plasma client called get on an unsealed object that it created";
+ ARROW_LOG(WARNING)
+ << "Attempting to get an object that this client created but hasn't sealed.";
+ all_present = false;
+ } else {
+ PlasmaObject* object = &object_entry->second->object;
+ std::shared_ptr<Buffer> physical_buf;
+
+ if (object->device_num == 0) {
+ uint8_t* data = LookupMmappedFile(object->store_fd);
+ physical_buf = std::make_shared<Buffer>(
+ data + object->data_offset, object->data_size + object->metadata_size);
+ } else {
+#ifdef PLASMA_CUDA
+ std::lock_guard<std::mutex> lock(gpu_mutex);
+ auto iter = gpu_object_map.find(object_ids[i]);
+ ARROW_CHECK(iter != gpu_object_map.end());
+ iter->second->client_count++;
+ physical_buf = MakeBufferFromGpuProcessHandle(iter->second);
+#else
+ ARROW_LOG(FATAL) << "Arrow GPU library is not enabled.";
+#endif
+ }
+ physical_buf = wrap_buffer(object_ids[i], physical_buf);
+ object_buffers[i].data = SliceBuffer(physical_buf, 0, object->data_size);
+ object_buffers[i].metadata =
+ SliceBuffer(physical_buf, object->data_size, object->metadata_size);
+ object_buffers[i].device_num = object->device_num;
+ // Increment the count of the number of instances of this object that this
+ // client is using. Cache the reference to the object.
+ IncrementObjectCount(object_ids[i], object, true);
+ }
+ }
+
+ if (all_present) {
+ return Status::OK();
+ }
+
+ // If we get here, then the objects aren't all currently in use by this
+ // client, so we need to send a request to the plasma store.
+ RETURN_NOT_OK(SendGetRequest(store_conn_, &object_ids[0], num_objects, timeout_ms));
+ std::vector<uint8_t> buffer;
+ RETURN_NOT_OK(PlasmaReceive(store_conn_, MessageType::PlasmaGetReply, &buffer));
+ std::vector<ObjectID> received_object_ids(num_objects);
+ std::vector<PlasmaObject> object_data(num_objects);
+ PlasmaObject* object;
+ std::vector<int> store_fds;
+ std::vector<int64_t> mmap_sizes;
+ RETURN_NOT_OK(ReadGetReply(buffer.data(), buffer.size(), received_object_ids.data(),
+ object_data.data(), num_objects, store_fds, mmap_sizes));
+
+ // We mmap all of the file descriptors here so that we can avoid look them up
+ // in the subsequent loop based on just the store file descriptor and without
+ // having to know the relevant file descriptor received from recv_fd.
+ for (size_t i = 0; i < store_fds.size(); i++) {
+ int fd = GetStoreFd(store_fds[i]);
+ LookupOrMmap(fd, store_fds[i], mmap_sizes[i]);
+ }
+
+ for (int64_t i = 0; i < num_objects; ++i) {
+ DCHECK(received_object_ids[i] == object_ids[i]);
+ object = &object_data[i];
+ if (object_buffers[i].data) {
+ // If the object was already in use by the client, then the store should
+ // have returned it.
+ DCHECK_NE(object->data_size, -1);
+ // We've already filled out the information for this object, so we can
+ // just continue.
+ continue;
+ }
+ // If we are here, the object was not currently in use, so we need to
+ // process the reply from the object store.
+ if (object->data_size != -1) {
+ std::shared_ptr<Buffer> physical_buf;
+ if (object->device_num == 0) {
+ uint8_t* data = LookupMmappedFile(object->store_fd);
+ physical_buf = std::make_shared<Buffer>(
+ data + object->data_offset, object->data_size + object->metadata_size);
+ } else {
+#ifdef PLASMA_CUDA
+ std::lock_guard<std::mutex> lock(gpu_mutex);
+ auto iter = gpu_object_map.find(object_ids[i]);
+ if (iter == gpu_object_map.end()) {
+ ARROW_ASSIGN_OR_RAISE(auto context, GetCudaContext(object->device_num));
+ GpuProcessHandle* obj_handle = new GpuProcessHandle();
+ obj_handle->client_count = 1;
+ ARROW_ASSIGN_OR_RAISE(obj_handle->ptr,
+ context->OpenIpcBuffer(*object->ipc_handle));
+ gpu_object_map[object_ids[i]] = obj_handle;
+ physical_buf = MakeBufferFromGpuProcessHandle(obj_handle);
+ } else {
+ iter->second->client_count++;
+ physical_buf = MakeBufferFromGpuProcessHandle(iter->second);
+ }
+#else
+ ARROW_LOG(FATAL) << "Arrow GPU library is not enabled.";
+#endif
+ }
+ // Finish filling out the return values.
+ physical_buf = wrap_buffer(object_ids[i], physical_buf);
+ object_buffers[i].data = SliceBuffer(physical_buf, 0, object->data_size);
+ object_buffers[i].metadata =
+ SliceBuffer(physical_buf, object->data_size, object->metadata_size);
+ object_buffers[i].device_num = object->device_num;
+ // Increment the count of the number of instances of this object that this
+ // client is using. Cache the reference to the object.
+ IncrementObjectCount(received_object_ids[i], object, true);
+ } else {
+ // The object was not retrieved. The caller can detect this condition
+ // by checking the boolean value of the metadata/data buffers.
+ DCHECK(!object_buffers[i].metadata);
+ DCHECK(!object_buffers[i].data);
+ }
+ }
+ return Status::OK();
+}
+
+Status PlasmaClient::Impl::Get(const std::vector<ObjectID>& object_ids,
+ int64_t timeout_ms, std::vector<ObjectBuffer>* out) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ const auto wrap_buffer = [=](const ObjectID& object_id,
+ const std::shared_ptr<Buffer>& buffer) {
+ return std::make_shared<PlasmaBuffer>(shared_from_this(), object_id, buffer);
+ };
+ const size_t num_objects = object_ids.size();
+ *out = std::vector<ObjectBuffer>(num_objects);
+ return GetBuffers(&object_ids[0], num_objects, timeout_ms, wrap_buffer, &(*out)[0]);
+}
+
+Status PlasmaClient::Impl::Get(const ObjectID* object_ids, int64_t num_objects,
+ int64_t timeout_ms, ObjectBuffer* out) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ const auto wrap_buffer = [](const ObjectID& object_id,
+ const std::shared_ptr<Buffer>& buffer) { return buffer; };
+ return GetBuffers(object_ids, num_objects, timeout_ms, wrap_buffer, out);
+}
+
+Status PlasmaClient::Impl::MarkObjectUnused(const ObjectID& object_id) {
+ auto object_entry = objects_in_use_.find(object_id);
+ ARROW_CHECK(object_entry != objects_in_use_.end());
+ ARROW_CHECK(object_entry->second->count == 0);
+
+ // Remove the entry from the hash table of objects currently in use.
+ objects_in_use_.erase(object_id);
+ return Status::OK();
+}
+
+Status PlasmaClient::Impl::Release(const ObjectID& object_id) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ // If the client is already disconnected, ignore release requests.
+ if (store_conn_ < 0) {
+ return Status::OK();
+ }
+ auto object_entry = objects_in_use_.find(object_id);
+ ARROW_CHECK(object_entry != objects_in_use_.end());
+
+#ifdef PLASMA_CUDA
+ if (object_entry->second->object.device_num != 0) {
+ std::lock_guard<std::mutex> lock(gpu_mutex);
+ auto iter = gpu_object_map.find(object_id);
+ ARROW_CHECK(iter != gpu_object_map.end());
+ if (--iter->second->client_count == 0) {
+ delete iter->second;
+ gpu_object_map.erase(iter);
+ }
+ }
+#endif
+
+ object_entry->second->count -= 1;
+ ARROW_CHECK(object_entry->second->count >= 0);
+ // Check if the client is no longer using this object.
+ if (object_entry->second->count == 0) {
+ // Tell the store that the client no longer needs the object.
+ RETURN_NOT_OK(MarkObjectUnused(object_id));
+ RETURN_NOT_OK(SendReleaseRequest(store_conn_, object_id));
+ auto iter = deletion_cache_.find(object_id);
+ if (iter != deletion_cache_.end()) {
+ deletion_cache_.erase(object_id);
+ RETURN_NOT_OK(Delete({object_id}));
+ }
+ }
+ return Status::OK();
+}
+
+// This method is used to query whether the plasma store contains an object.
+Status PlasmaClient::Impl::Contains(const ObjectID& object_id, bool* has_object) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ // Check if we already have a reference to the object.
+ if (objects_in_use_.count(object_id) > 0) {
+ *has_object = 1;
+ } else {
+ // If we don't already have a reference to the object, check with the store
+ // to see if we have the object.
+ RETURN_NOT_OK(SendContainsRequest(store_conn_, object_id));
+ std::vector<uint8_t> buffer;
+ RETURN_NOT_OK(PlasmaReceive(store_conn_, MessageType::PlasmaContainsReply, &buffer));
+ ObjectID object_id2;
+ DCHECK_GT(buffer.size(), 0);
+ RETURN_NOT_OK(
+ ReadContainsReply(buffer.data(), buffer.size(), &object_id2, has_object));
+ }
+ return Status::OK();
+}
+
+Status PlasmaClient::Impl::List(ObjectTable* objects) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+ RETURN_NOT_OK(SendListRequest(store_conn_));
+ std::vector<uint8_t> buffer;
+ RETURN_NOT_OK(PlasmaReceive(store_conn_, MessageType::PlasmaListReply, &buffer));
+ return ReadListReply(buffer.data(), buffer.size(), objects);
+}
+
+static void ComputeBlockHash(const unsigned char* data, int64_t nbytes, uint64_t* hash) {
+ XXH64_state_t hash_state;
+ XXH64_reset(&hash_state, XXH64_DEFAULT_SEED);
+ XXH64_update(&hash_state, data, nbytes);
+ *hash = XXH64_digest(&hash_state);
+}
+
+bool PlasmaClient::Impl::ComputeObjectHashParallel(XXH64_state_t* hash_state,
+ const unsigned char* data,
+ int64_t nbytes) {
+ // Note that this function will likely be faster if the address of data is
+ // aligned on a 64-byte boundary.
+ auto pool = arrow::internal::GetCpuThreadPool();
+
+ const int num_threads = kHashingConcurrency;
+ uint64_t threadhash[num_threads + 1];
+ const uint64_t data_address = reinterpret_cast<uint64_t>(data);
+ const uint64_t num_blocks = nbytes / kBlockSize;
+ const uint64_t chunk_size = (num_blocks / num_threads) * kBlockSize;
+ const uint64_t right_address = data_address + chunk_size * num_threads;
+ const uint64_t suffix = (data_address + nbytes) - right_address;
+ // Now the data layout is | k * num_threads * block_size | suffix | ==
+ // | num_threads * chunk_size | suffix |, where chunk_size = k * block_size.
+ // Each thread gets a "chunk" of k blocks, except the suffix thread.
+
+ std::vector<arrow::Future<>> futures;
+ for (int i = 0; i < num_threads; i++) {
+ futures.push_back(*pool->Submit(
+ ComputeBlockHash, reinterpret_cast<uint8_t*>(data_address) + i * chunk_size,
+ chunk_size, &threadhash[i]));
+ }
+ ComputeBlockHash(reinterpret_cast<uint8_t*>(right_address), suffix,
+ &threadhash[num_threads]);
+
+ for (auto& fut : futures) {
+ ARROW_CHECK_OK(fut.status());
+ }
+
+ XXH64_update(hash_state, reinterpret_cast<unsigned char*>(threadhash),
+ sizeof(threadhash));
+ return true;
+}
+
+uint64_t PlasmaClient::Impl::ComputeObjectHash(const ObjectBuffer& obj_buffer) {
+ if (obj_buffer.device_num != 0) {
+ // TODO(wap): Create cuda program to hash data on gpu.
+ return 0;
+ }
+ return ComputeObjectHashCPU(obj_buffer.data->data(), obj_buffer.data->size(),
+ obj_buffer.metadata->data(), obj_buffer.metadata->size());
+}
+
+uint64_t PlasmaClient::Impl::ComputeObjectHashCPU(const uint8_t* data, int64_t data_size,
+ const uint8_t* metadata,
+ int64_t metadata_size) {
+ DCHECK(metadata);
+ DCHECK(data);
+ XXH64_state_t hash_state;
+ XXH64_reset(&hash_state, XXH64_DEFAULT_SEED);
+ if (data_size >= kBytesInMB) {
+ ComputeObjectHashParallel(&hash_state, reinterpret_cast<const unsigned char*>(data),
+ data_size);
+ } else {
+ XXH64_update(&hash_state, reinterpret_cast<const unsigned char*>(data), data_size);
+ }
+ XXH64_update(&hash_state, reinterpret_cast<const unsigned char*>(metadata),
+ metadata_size);
+ return XXH64_digest(&hash_state);
+}
+
+Status PlasmaClient::Impl::Seal(const ObjectID& object_id) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ // Make sure this client has a reference to the object before sending the
+ // request to Plasma.
+ auto object_entry = objects_in_use_.find(object_id);
+
+ if (object_entry == objects_in_use_.end()) {
+ return MakePlasmaError(PlasmaErrorCode::PlasmaObjectNotFound,
+ "Seal() called on an object without a reference to it");
+ }
+ if (object_entry->second->is_sealed) {
+ return MakePlasmaError(PlasmaErrorCode::PlasmaObjectAlreadySealed,
+ "Seal() called on an already sealed object");
+ }
+
+ object_entry->second->is_sealed = true;
+ /// Send the seal request to Plasma.
+ std::vector<uint8_t> digest(kDigestSize);
+ RETURN_NOT_OK(Hash(object_id, &digest[0]));
+ RETURN_NOT_OK(
+ SendSealRequest(store_conn_, object_id, std::string(digest.begin(), digest.end())));
+ std::vector<uint8_t> buffer;
+ RETURN_NOT_OK(PlasmaReceive(store_conn_, MessageType::PlasmaSealReply, &buffer));
+ ObjectID sealed_id;
+ RETURN_NOT_OK(ReadSealReply(buffer.data(), buffer.size(), &sealed_id));
+ ARROW_CHECK(sealed_id == object_id);
+ // We call PlasmaClient::Release to decrement the number of instances of this
+ // object
+ // that are currently being used by this client. The corresponding increment
+ // happened in plasma_create and was used to ensure that the object was not
+ // released before the call to PlasmaClient::Seal.
+ return Release(object_id);
+}
+
+Status PlasmaClient::Impl::Abort(const ObjectID& object_id) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+ auto object_entry = objects_in_use_.find(object_id);
+ ARROW_CHECK(object_entry != objects_in_use_.end())
+ << "Plasma client called abort on an object without a reference to it";
+ ARROW_CHECK(!object_entry->second->is_sealed)
+ << "Plasma client called abort on a sealed object";
+
+ // Make sure that the Plasma client only has one reference to the object. If
+ // it has more, then the client needs to release the buffer before calling
+ // abort.
+ if (object_entry->second->count > 1) {
+ return Status::Invalid("Plasma client cannot have a reference to the buffer.");
+ }
+
+#ifdef PLASMA_CUDA
+ if (object_entry->second->object.device_num != 0) {
+ std::lock_guard<std::mutex> lock(gpu_mutex);
+ auto iter = gpu_object_map.find(object_id);
+ ARROW_CHECK(iter != gpu_object_map.end());
+ ARROW_CHECK(iter->second->client_count == 1);
+ delete iter->second;
+ gpu_object_map.erase(iter);
+ }
+#endif
+
+ // Send the abort request.
+ RETURN_NOT_OK(SendAbortRequest(store_conn_, object_id));
+ // Decrease the reference count to zero, then remove the object.
+ object_entry->second->count--;
+ RETURN_NOT_OK(MarkObjectUnused(object_id));
+
+ std::vector<uint8_t> buffer;
+ ObjectID id;
+ MessageType type;
+ RETURN_NOT_OK(ReadMessage(store_conn_, &type, &buffer));
+ return ReadAbortReply(buffer.data(), buffer.size(), &id);
+}
+
+Status PlasmaClient::Impl::Delete(const std::vector<ObjectID>& object_ids) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ std::vector<ObjectID> not_in_use_ids;
+ for (auto& object_id : object_ids) {
+ // If the object is in used, skip it.
+ if (objects_in_use_.count(object_id) == 0) {
+ not_in_use_ids.push_back(object_id);
+ } else {
+ deletion_cache_.emplace(object_id);
+ }
+ }
+ if (not_in_use_ids.size() > 0) {
+ RETURN_NOT_OK(SendDeleteRequest(store_conn_, not_in_use_ids));
+ std::vector<uint8_t> buffer;
+ RETURN_NOT_OK(PlasmaReceive(store_conn_, MessageType::PlasmaDeleteReply, &buffer));
+ DCHECK_GT(buffer.size(), 0);
+ std::vector<PlasmaError> error_codes;
+ not_in_use_ids.clear();
+ RETURN_NOT_OK(
+ ReadDeleteReply(buffer.data(), buffer.size(), &not_in_use_ids, &error_codes));
+ }
+ return Status::OK();
+}
+
+Status PlasmaClient::Impl::Evict(int64_t num_bytes, int64_t& num_bytes_evicted) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ // Send a request to the store to evict objects.
+ RETURN_NOT_OK(SendEvictRequest(store_conn_, num_bytes));
+ // Wait for a response with the number of bytes actually evicted.
+ std::vector<uint8_t> buffer;
+ MessageType type;
+ RETURN_NOT_OK(ReadMessage(store_conn_, &type, &buffer));
+ return ReadEvictReply(buffer.data(), buffer.size(), num_bytes_evicted);
+}
+
+Status PlasmaClient::Impl::Refresh(const std::vector<ObjectID>& object_ids) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ RETURN_NOT_OK(SendRefreshLRURequest(store_conn_, object_ids));
+ std::vector<uint8_t> buffer;
+ MessageType type;
+ RETURN_NOT_OK(ReadMessage(store_conn_, &type, &buffer));
+ return ReadRefreshLRUReply(buffer.data(), buffer.size());
+}
+
+Status PlasmaClient::Impl::Hash(const ObjectID& object_id, uint8_t* digest) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ // Get the plasma object data. We pass in a timeout of 0 to indicate that
+ // the operation should timeout immediately.
+ std::vector<ObjectBuffer> object_buffers;
+ RETURN_NOT_OK(Get({object_id}, 0, &object_buffers));
+ // If the object was not retrieved, return false.
+ if (!object_buffers[0].data) {
+ return MakePlasmaError(PlasmaErrorCode::PlasmaObjectNotFound, "Object not found");
+ }
+ // Compute the hash.
+ uint64_t hash = ComputeObjectHash(object_buffers[0]);
+ memcpy(digest, &hash, sizeof(hash));
+ return Status::OK();
+}
+
+Status PlasmaClient::Impl::Subscribe(int* fd) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ int sock[2];
+ // Create a non-blocking socket pair. This will only be used to send
+ // notifications from the Plasma store to the client.
+ socketpair(AF_UNIX, SOCK_STREAM, 0, sock);
+ // Make the socket non-blocking.
+ int flags = fcntl(sock[1], F_GETFL, 0);
+ ARROW_CHECK(fcntl(sock[1], F_SETFL, flags | O_NONBLOCK) == 0);
+ // Tell the Plasma store about the subscription.
+ RETURN_NOT_OK(SendSubscribeRequest(store_conn_));
+ // Send the file descriptor that the Plasma store should use to push
+ // notifications about sealed objects to this client.
+ ARROW_CHECK(send_fd(store_conn_, sock[1]) >= 0);
+ close(sock[1]);
+ // Return the file descriptor that the client should use to read notifications
+ // about sealed objects.
+ *fd = sock[0];
+ return Status::OK();
+}
+
+Status PlasmaClient::Impl::GetNotification(int fd, ObjectID* object_id,
+ int64_t* data_size, int64_t* metadata_size) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ if (pending_notification_.empty()) {
+ auto message = ReadMessageAsync(fd);
+ if (message == NULL) {
+ return Status::IOError("Failed to read object notification from Plasma socket");
+ }
+
+ std::vector<ObjectID> object_ids;
+ std::vector<int64_t> data_sizes;
+ std::vector<int64_t> metadata_sizes;
+ RETURN_NOT_OK(
+ DecodeNotifications(message.get(), &object_ids, &data_sizes, &metadata_sizes));
+ for (size_t i = 0; i < object_ids.size(); ++i) {
+ pending_notification_.emplace_back(object_ids[i], data_sizes[i], metadata_sizes[i]);
+ }
+ }
+
+ auto notification = pending_notification_.front();
+ *object_id = std::get<0>(notification);
+ *data_size = std::get<1>(notification);
+ *metadata_size = std::get<2>(notification);
+
+ pending_notification_.pop_front();
+
+ return Status::OK();
+}
+
+Status PlasmaClient::Impl::DecodeNotifications(const uint8_t* buffer,
+ std::vector<ObjectID>* object_ids,
+ std::vector<int64_t>* data_sizes,
+ std::vector<int64_t>* metadata_sizes) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+ auto object_info = flatbuffers::GetRoot<fb::PlasmaNotification>(buffer);
+
+ for (size_t i = 0; i < object_info->object_info()->size(); ++i) {
+ auto info = object_info->object_info()->Get(i);
+ ObjectID id = ObjectID::from_binary(info->object_id()->str());
+ object_ids->push_back(id);
+ if (info->is_deletion()) {
+ data_sizes->push_back(-1);
+ metadata_sizes->push_back(-1);
+ } else {
+ data_sizes->push_back(info->data_size());
+ metadata_sizes->push_back(info->metadata_size());
+ }
+ }
+
+ return Status::OK();
+}
+
+Status PlasmaClient::Impl::Connect(const std::string& store_socket_name,
+ const std::string& manager_socket_name,
+ int release_delay, int num_retries) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ RETURN_NOT_OK(ConnectIpcSocketRetry(store_socket_name, num_retries, -1, &store_conn_));
+ if (manager_socket_name != "") {
+ return Status::NotImplemented("plasma manager is no longer supported");
+ }
+ if (release_delay != 0) {
+ ARROW_LOG(WARNING) << "The release_delay parameter in PlasmaClient::Connect "
+ << "is deprecated";
+ }
+ // Send a ConnectRequest to the store to get its memory capacity.
+ RETURN_NOT_OK(SendConnectRequest(store_conn_));
+ std::vector<uint8_t> buffer;
+ RETURN_NOT_OK(PlasmaReceive(store_conn_, MessageType::PlasmaConnectReply, &buffer));
+ RETURN_NOT_OK(ReadConnectReply(buffer.data(), buffer.size(), &store_capacity_));
+ return Status::OK();
+}
+
+Status PlasmaClient::Impl::SetClientOptions(const std::string& client_name,
+ int64_t output_memory_quota) {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+ RETURN_NOT_OK(SendSetOptionsRequest(store_conn_, client_name, output_memory_quota));
+ std::vector<uint8_t> buffer;
+ RETURN_NOT_OK(PlasmaReceive(store_conn_, MessageType::PlasmaSetOptionsReply, &buffer));
+ return ReadSetOptionsReply(buffer.data(), buffer.size());
+}
+
+Status PlasmaClient::Impl::Disconnect() {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+
+ // NOTE: We purposefully do not finish sending release calls for objects in
+ // use, so that we don't duplicate PlasmaClient::Release calls (when handling
+ // a SIGTERM, for example).
+
+ // Close the connections to Plasma. The Plasma store will release the objects
+ // that were in use by us when handling the SIGPIPE.
+ close(store_conn_);
+ store_conn_ = -1;
+ return Status::OK();
+}
+
+std::string PlasmaClient::Impl::DebugString() {
+ std::lock_guard<std::recursive_mutex> guard(client_mutex_);
+ if (!SendGetDebugStringRequest(store_conn_).ok()) {
+ return "error sending request";
+ }
+ std::vector<uint8_t> buffer;
+ if (!PlasmaReceive(store_conn_, MessageType::PlasmaGetDebugStringReply, &buffer).ok()) {
+ return "error receiving reply";
+ }
+ std::string debug_string;
+ if (!ReadGetDebugStringReply(buffer.data(), buffer.size(), &debug_string).ok()) {
+ return "error parsing reply";
+ }
+ return debug_string;
+}
+
+// ----------------------------------------------------------------------
+// PlasmaClient
+
+PlasmaClient::PlasmaClient() : impl_(std::make_shared<PlasmaClient::Impl>()) {}
+
+PlasmaClient::~PlasmaClient() {}
+
+Status PlasmaClient::Connect(const std::string& store_socket_name,
+ const std::string& manager_socket_name, int release_delay,
+ int num_retries) {
+ return impl_->Connect(store_socket_name, manager_socket_name, release_delay,
+ num_retries);
+}
+
+Status PlasmaClient::SetClientOptions(const std::string& client_name,
+ int64_t output_memory_quota) {
+ return impl_->SetClientOptions(client_name, output_memory_quota);
+}
+
+Status PlasmaClient::Create(const ObjectID& object_id, int64_t data_size,
+ const uint8_t* metadata, int64_t metadata_size,
+ std::shared_ptr<Buffer>* data, int device_num,
+ bool evict_if_full) {
+ return impl_->Create(object_id, data_size, metadata, metadata_size, data, device_num,
+ evict_if_full);
+}
+
+Status PlasmaClient::CreateAndSeal(const ObjectID& object_id, const std::string& data,
+ const std::string& metadata, bool evict_if_full) {
+ return impl_->CreateAndSeal(object_id, data, metadata, evict_if_full);
+}
+
+Status PlasmaClient::CreateAndSealBatch(const std::vector<ObjectID>& object_ids,
+ const std::vector<std::string>& data,
+ const std::vector<std::string>& metadata,
+ bool evict_if_full) {
+ return impl_->CreateAndSealBatch(object_ids, data, metadata, evict_if_full);
+}
+
+Status PlasmaClient::Get(const std::vector<ObjectID>& object_ids, int64_t timeout_ms,
+ std::vector<ObjectBuffer>* object_buffers) {
+ return impl_->Get(object_ids, timeout_ms, object_buffers);
+}
+
+Status PlasmaClient::Get(const ObjectID* object_ids, int64_t num_objects,
+ int64_t timeout_ms, ObjectBuffer* object_buffers) {
+ return impl_->Get(object_ids, num_objects, timeout_ms, object_buffers);
+}
+
+Status PlasmaClient::Release(const ObjectID& object_id) {
+ return impl_->Release(object_id);
+}
+
+Status PlasmaClient::Contains(const ObjectID& object_id, bool* has_object) {
+ return impl_->Contains(object_id, has_object);
+}
+
+Status PlasmaClient::List(ObjectTable* objects) { return impl_->List(objects); }
+
+Status PlasmaClient::Abort(const ObjectID& object_id) { return impl_->Abort(object_id); }
+
+Status PlasmaClient::Seal(const ObjectID& object_id) { return impl_->Seal(object_id); }
+
+Status PlasmaClient::Delete(const ObjectID& object_id) {
+ return impl_->Delete(std::vector<ObjectID>{object_id});
+}
+
+Status PlasmaClient::Delete(const std::vector<ObjectID>& object_ids) {
+ return impl_->Delete(object_ids);
+}
+
+Status PlasmaClient::Evict(int64_t num_bytes, int64_t& num_bytes_evicted) {
+ return impl_->Evict(num_bytes, num_bytes_evicted);
+}
+
+Status PlasmaClient::Refresh(const std::vector<ObjectID>& object_ids) {
+ return impl_->Refresh(object_ids);
+}
+
+Status PlasmaClient::Hash(const ObjectID& object_id, uint8_t* digest) {
+ return impl_->Hash(object_id, digest);
+}
+
+Status PlasmaClient::Subscribe(int* fd) { return impl_->Subscribe(fd); }
+
+Status PlasmaClient::GetNotification(int fd, ObjectID* object_id, int64_t* data_size,
+ int64_t* metadata_size) {
+ return impl_->GetNotification(fd, object_id, data_size, metadata_size);
+}
+
+Status PlasmaClient::DecodeNotifications(const uint8_t* buffer,
+ std::vector<ObjectID>* object_ids,
+ std::vector<int64_t>* data_sizes,
+ std::vector<int64_t>* metadata_sizes) {
+ return impl_->DecodeNotifications(buffer, object_ids, data_sizes, metadata_sizes);
+}
+
+Status PlasmaClient::Disconnect() { return impl_->Disconnect(); }
+
+std::string PlasmaClient::DebugString() { return impl_->DebugString(); }
+
+bool PlasmaClient::IsInUse(const ObjectID& object_id) {
+ return impl_->IsInUse(object_id);
+}
+
+int64_t PlasmaClient::store_capacity() { return impl_->store_capacity(); }
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/client.h b/src/arrow/cpp/src/plasma/client.h
new file mode 100644
index 000000000..7a70bba5f
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/client.h
@@ -0,0 +1,309 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/status.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+#include "plasma/common.h"
+
+using arrow::Buffer;
+using arrow::Status;
+
+namespace plasma {
+
+/// Object buffer data structure.
+struct ObjectBuffer {
+ /// The data buffer.
+ std::shared_ptr<Buffer> data;
+ /// The metadata buffer.
+ std::shared_ptr<Buffer> metadata;
+ /// The device number.
+ int device_num;
+};
+
+class ARROW_EXPORT PlasmaClient {
+ public:
+ PlasmaClient();
+ ~PlasmaClient();
+
+ /// Connect to the local plasma store. Return the resulting connection.
+ ///
+ /// \param store_socket_name The name of the UNIX domain socket to use to
+ /// connect to the Plasma store.
+ /// \param manager_socket_name The name of the UNIX domain socket to use to
+ /// connect to the local Plasma manager. If this is "", then this
+ /// function will not connect to a manager.
+ /// Note that plasma manager is no longer supported, this function
+ /// will return failure if this is not "".
+ /// \param release_delay Deprecated (not used).
+ /// \param num_retries number of attempts to connect to IPC socket, default 50
+ /// \return The return status.
+ Status Connect(const std::string& store_socket_name,
+ const std::string& manager_socket_name = "", int release_delay = 0,
+ int num_retries = -1);
+
+ /// Set runtime options for this client.
+ ///
+ /// \param client_name The name of the client, used in debug messages.
+ /// \param output_memory_quota The memory quota in bytes for objects created by
+ /// this client.
+ Status SetClientOptions(const std::string& client_name, int64_t output_memory_quota);
+
+ /// Create an object in the Plasma Store. Any metadata for this object must be
+ /// be passed in when the object is created.
+ ///
+ /// \param object_id The ID to use for the newly created object.
+ /// \param data_size The size in bytes of the space to be allocated for this
+ /// object's
+ /// data (this does not include space used for metadata).
+ /// \param metadata The object's metadata. If there is no metadata, this
+ /// pointer
+ /// should be NULL.
+ /// \param metadata_size The size in bytes of the metadata. If there is no
+ /// metadata, this should be 0.
+ /// \param data The address of the newly created object will be written here.
+ /// \param device_num The number of the device where the object is being
+ /// created.
+ /// device_num = 0 corresponds to the host,
+ /// device_num = 1 corresponds to GPU0,
+ /// device_num = 2 corresponds to GPU1, etc.
+ /// \param evict_if_full Whether to evict other objects to make space for
+ /// this object.
+ /// \return The return status.
+ ///
+ /// The returned object must be released once it is done with. It must also
+ /// be either sealed or aborted.
+ Status Create(const ObjectID& object_id, int64_t data_size, const uint8_t* metadata,
+ int64_t metadata_size, std::shared_ptr<Buffer>* data, int device_num = 0,
+ bool evict_if_full = true);
+
+ /// Create and seal an object in the object store. This is an optimization
+ /// which allows small objects to be created quickly with fewer messages to
+ /// the store.
+ ///
+ /// \param object_id The ID of the object to create.
+ /// \param data The data for the object to create.
+ /// \param metadata The metadata for the object to create.
+ /// \param evict_if_full Whether to evict other objects to make space for
+ /// this object.
+ /// \return The return status.
+ Status CreateAndSeal(const ObjectID& object_id, const std::string& data,
+ const std::string& metadata, bool evict_if_full = true);
+
+ /// Create and seal multiple objects in the object store. This is an optimization
+ /// of CreateAndSeal to eliminate the cost of IPC per object.
+ ///
+ /// \param object_ids The vector of IDs of the objects to create.
+ /// \param data The vector of data for the objects to create.
+ /// \param metadata The vector of metadata for the objects to create.
+ /// \param evict_if_full Whether to evict other objects to make space for
+ /// these objects.
+ /// \return The return status.
+ Status CreateAndSealBatch(const std::vector<ObjectID>& object_ids,
+ const std::vector<std::string>& data,
+ const std::vector<std::string>& metadata,
+ bool evict_if_full = true);
+
+ /// Get some objects from the Plasma Store. This function will block until the
+ /// objects have all been created and sealed in the Plasma Store or the
+ /// timeout expires.
+ ///
+ /// If an object was not retrieved, the corresponding metadata and data
+ /// fields in the ObjectBuffer structure will evaluate to false.
+ /// Objects are automatically released by the client when their buffers
+ /// get out of scope.
+ ///
+ /// \param object_ids The IDs of the objects to get.
+ /// \param timeout_ms The amount of time in milliseconds to wait before this
+ /// request times out. If this value is -1, then no timeout is set.
+ /// \param[out] object_buffers The object results.
+ /// \return The return status.
+ Status Get(const std::vector<ObjectID>& object_ids, int64_t timeout_ms,
+ std::vector<ObjectBuffer>* object_buffers);
+
+ /// Deprecated variant of Get() that doesn't automatically release buffers
+ /// when they get out of scope.
+ ///
+ /// \param object_ids The IDs of the objects to get.
+ /// \param num_objects The number of object IDs to get.
+ /// \param timeout_ms The amount of time in milliseconds to wait before this
+ /// request times out. If this value is -1, then no timeout is set.
+ /// \param object_buffers An array where the results will be stored.
+ /// \return The return status.
+ ///
+ /// The caller is responsible for releasing any retrieved objects, but it
+ /// should not release objects that were not retrieved.
+ Status Get(const ObjectID* object_ids, int64_t num_objects, int64_t timeout_ms,
+ ObjectBuffer* object_buffers);
+
+ /// Tell Plasma that the client no longer needs the object. This should be
+ /// called after Get() or Create() when the client is done with the object.
+ /// After this call, the buffer returned by Get() is no longer valid.
+ ///
+ /// \param object_id The ID of the object that is no longer needed.
+ /// \return The return status.
+ Status Release(const ObjectID& object_id);
+
+ /// Check if the object store contains a particular object and the object has
+ /// been sealed. The result will be stored in has_object.
+ ///
+ /// @todo: We may want to indicate if the object has been created but not
+ /// sealed.
+ ///
+ /// \param object_id The ID of the object whose presence we are checking.
+ /// \param has_object The function will write true at this address if
+ /// the object is present and false if it is not present.
+ /// \return The return status.
+ Status Contains(const ObjectID& object_id, bool* has_object);
+
+ /// List all the objects in the object store.
+ ///
+ /// This API is experimental and might change in the future.
+ ///
+ /// \param[out] objects ObjectTable of objects in the store. For each entry
+ /// in the map, the following fields are available:
+ /// - metadata_size: Size of the object metadata in bytes
+ /// - data_size: Size of the object data in bytes
+ /// - ref_count: Number of clients referencing the object buffer
+ /// - create_time: Unix timestamp of the object creation
+ /// - construct_duration: Object creation time in seconds
+ /// - state: Is the object still being created or already sealed?
+ /// \return The return status.
+ Status List(ObjectTable* objects);
+
+ /// Abort an unsealed object in the object store. If the abort succeeds, then
+ /// it will be as if the object was never created at all. The unsealed object
+ /// must have only a single reference (the one that would have been removed by
+ /// calling Seal).
+ ///
+ /// \param object_id The ID of the object to abort.
+ /// \return The return status.
+ Status Abort(const ObjectID& object_id);
+
+ /// Seal an object in the object store. The object will be immutable after
+ /// this
+ /// call.
+ ///
+ /// \param object_id The ID of the object to seal.
+ /// \return The return status.
+ Status Seal(const ObjectID& object_id);
+
+ /// Delete an object from the object store. This currently assumes that the
+ /// object is present, has been sealed and not used by another client. Otherwise,
+ /// it is a no operation.
+ ///
+ /// \todo We may want to allow the deletion of objects that are not present or
+ /// haven't been sealed.
+ ///
+ /// \param object_id The ID of the object to delete.
+ /// \return The return status.
+ Status Delete(const ObjectID& object_id);
+
+ /// Delete a list of objects from the object store. This currently assumes that the
+ /// object is present, has been sealed and not used by another client. Otherwise,
+ /// it is a no operation.
+ ///
+ /// \param object_ids The list of IDs of the objects to delete.
+ /// \return The return status. If all the objects are nonexistent, return OK.
+ Status Delete(const std::vector<ObjectID>& object_ids);
+
+ /// Delete objects until we have freed up num_bytes bytes or there are no more
+ /// released objects that can be deleted.
+ ///
+ /// \param num_bytes The number of bytes to try to free up.
+ /// \param num_bytes_evicted Out parameter for total number of bytes of space
+ /// retrieved.
+ /// \return The return status.
+ Status Evict(int64_t num_bytes, int64_t& num_bytes_evicted);
+
+ /// Bump objects up in the LRU cache, i.e. treat them as recently accessed.
+ /// Objects that do not exist in the store will be ignored.
+ ///
+ /// \param object_ids The IDs of the objects to bump.
+ /// \return The return status.
+ Status Refresh(const std::vector<ObjectID>& object_ids);
+
+ /// Compute the hash of an object in the object store.
+ ///
+ /// \param object_id The ID of the object we want to hash.
+ /// \param digest A pointer at which to return the hash digest of the object.
+ /// The pointer must have at least kDigestSize bytes allocated.
+ /// \return The return status.
+ Status Hash(const ObjectID& object_id, uint8_t* digest);
+
+ /// Subscribe to notifications when objects are sealed in the object store.
+ /// Whenever an object is sealed, a message will be written to the client
+ /// socket that is returned by this method.
+ ///
+ /// \param fd Out parameter for the file descriptor the client should use to
+ /// read notifications
+ /// from the object store about sealed objects.
+ /// \return The return status.
+ Status Subscribe(int* fd);
+
+ /// Receive next object notification for this client if Subscribe has been called.
+ ///
+ /// \param fd The file descriptor we are reading the notification from.
+ /// \param object_id Out parameter, the object_id of the object that was sealed.
+ /// \param data_size Out parameter, the data size of the object that was sealed.
+ /// \param metadata_size Out parameter, the metadata size of the object that was sealed.
+ /// \return The return status.
+ Status GetNotification(int fd, ObjectID* object_id, int64_t* data_size,
+ int64_t* metadata_size);
+
+ Status DecodeNotifications(const uint8_t* buffer, std::vector<ObjectID>* object_ids,
+ std::vector<int64_t>* data_sizes,
+ std::vector<int64_t>* metadata_sizes);
+
+ /// Disconnect from the local plasma instance, including the local store and
+ /// manager.
+ ///
+ /// \return The return status.
+ Status Disconnect();
+
+ /// Get the current debug string from the plasma store server.
+ ///
+ /// \return The debug string.
+ std::string DebugString();
+
+ /// Get the memory capacity of the store.
+ ///
+ /// \return Memory capacity of the store in bytes.
+ int64_t store_capacity();
+
+ private:
+ friend class PlasmaBuffer;
+ friend class PlasmaMutableBuffer;
+ FRIEND_TEST(TestPlasmaStore, GetTest);
+ FRIEND_TEST(TestPlasmaStore, LegacyGetTest);
+ FRIEND_TEST(TestPlasmaStore, AbortTest);
+
+ bool IsInUse(const ObjectID& object_id);
+
+ class ARROW_NO_EXPORT Impl;
+ std::shared_ptr<Impl> impl_;
+};
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/common.cc b/src/arrow/cpp/src/plasma/common.cc
new file mode 100644
index 000000000..e7d2643d7
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/common.cc
@@ -0,0 +1,195 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "plasma/common.h"
+
+#include <limits>
+#include <utility>
+
+#include "arrow/util/ubsan.h"
+
+#include "plasma/plasma_generated.h"
+
+namespace fb = plasma::flatbuf;
+
+namespace plasma {
+
+namespace {
+
+const char kErrorDetailTypeId[] = "plasma::PlasmaStatusDetail";
+
+class PlasmaStatusDetail : public arrow::StatusDetail {
+ public:
+ explicit PlasmaStatusDetail(PlasmaErrorCode code) : code_(code) {}
+ const char* type_id() const override { return kErrorDetailTypeId; }
+ std::string ToString() const override {
+ const char* type;
+ switch (code()) {
+ case PlasmaErrorCode::PlasmaObjectExists:
+ type = "Plasma object is exists";
+ break;
+ case PlasmaErrorCode::PlasmaObjectNotFound:
+ type = "Plasma object is not found";
+ break;
+ case PlasmaErrorCode::PlasmaStoreFull:
+ type = "Plasma store is full";
+ break;
+ case PlasmaErrorCode::PlasmaObjectAlreadySealed:
+ type = "Plasma object is already sealed";
+ break;
+ default:
+ type = "Unknown plasma error";
+ break;
+ }
+ return std::string(type);
+ }
+ PlasmaErrorCode code() const { return code_; }
+
+ private:
+ PlasmaErrorCode code_;
+};
+
+bool IsPlasmaStatus(const arrow::Status& status, PlasmaErrorCode code) {
+ if (status.ok()) {
+ return false;
+ }
+ auto* detail = status.detail().get();
+ return detail != nullptr && detail->type_id() == kErrorDetailTypeId &&
+ static_cast<PlasmaStatusDetail*>(detail)->code() == code;
+}
+
+} // namespace
+
+using arrow::Status;
+
+arrow::Status MakePlasmaError(PlasmaErrorCode code, std::string message) {
+ arrow::StatusCode arrow_code = arrow::StatusCode::UnknownError;
+ switch (code) {
+ case PlasmaErrorCode::PlasmaObjectExists:
+ arrow_code = arrow::StatusCode::AlreadyExists;
+ break;
+ case PlasmaErrorCode::PlasmaObjectNotFound:
+ arrow_code = arrow::StatusCode::KeyError;
+ break;
+ case PlasmaErrorCode::PlasmaStoreFull:
+ arrow_code = arrow::StatusCode::CapacityError;
+ break;
+ case PlasmaErrorCode::PlasmaObjectAlreadySealed:
+ // Maybe a stretch?
+ arrow_code = arrow::StatusCode::TypeError;
+ break;
+ }
+ return arrow::Status(arrow_code, std::move(message),
+ std::make_shared<PlasmaStatusDetail>(code));
+}
+
+bool IsPlasmaObjectExists(const arrow::Status& status) {
+ return IsPlasmaStatus(status, PlasmaErrorCode::PlasmaObjectExists);
+}
+bool IsPlasmaObjectNotFound(const arrow::Status& status) {
+ return IsPlasmaStatus(status, PlasmaErrorCode::PlasmaObjectNotFound);
+}
+bool IsPlasmaObjectAlreadySealed(const arrow::Status& status) {
+ return IsPlasmaStatus(status, PlasmaErrorCode::PlasmaObjectAlreadySealed);
+}
+bool IsPlasmaStoreFull(const arrow::Status& status) {
+ return IsPlasmaStatus(status, PlasmaErrorCode::PlasmaStoreFull);
+}
+
+UniqueID UniqueID::from_binary(const std::string& binary) {
+ UniqueID id;
+ std::memcpy(&id, binary.data(), sizeof(id));
+ return id;
+}
+
+const uint8_t* UniqueID::data() const { return id_; }
+
+uint8_t* UniqueID::mutable_data() { return id_; }
+
+std::string UniqueID::binary() const {
+ return std::string(reinterpret_cast<const char*>(id_), kUniqueIDSize);
+}
+
+std::string UniqueID::hex() const {
+ constexpr char hex[] = "0123456789abcdef";
+ std::string result;
+ for (int i = 0; i < kUniqueIDSize; i++) {
+ unsigned int val = id_[i];
+ result.push_back(hex[val >> 4]);
+ result.push_back(hex[val & 0xf]);
+ }
+ return result;
+}
+
+// This code is from https://sites.google.com/site/murmurhash/
+// and is public domain.
+uint64_t MurmurHash64A(const void* key, int len, unsigned int seed) {
+ const uint64_t m = 0xc6a4a7935bd1e995;
+ const int r = 47;
+
+ uint64_t h = seed ^ (len * m);
+
+ const uint64_t* data = reinterpret_cast<const uint64_t*>(key);
+ const uint64_t* end = data + (len / 8);
+
+ while (data != end) {
+ uint64_t k = arrow::util::SafeLoad(data++);
+
+ k *= m;
+ k ^= k >> r;
+ k *= m;
+
+ h ^= k;
+ h *= m;
+ }
+
+ const unsigned char* data2 = reinterpret_cast<const unsigned char*>(data);
+
+ switch (len & 7) {
+ case 7:
+ h ^= uint64_t(data2[6]) << 48; // fall through
+ case 6:
+ h ^= uint64_t(data2[5]) << 40; // fall through
+ case 5:
+ h ^= uint64_t(data2[4]) << 32; // fall through
+ case 4:
+ h ^= uint64_t(data2[3]) << 24; // fall through
+ case 3:
+ h ^= uint64_t(data2[2]) << 16; // fall through
+ case 2:
+ h ^= uint64_t(data2[1]) << 8; // fall through
+ case 1:
+ h ^= uint64_t(data2[0]);
+ h *= m;
+ }
+
+ h ^= h >> r;
+ h *= m;
+ h ^= h >> r;
+
+ return h;
+}
+
+size_t UniqueID::hash() const { return MurmurHash64A(&id_[0], kUniqueIDSize, 0); }
+
+bool UniqueID::operator==(const UniqueID& rhs) const {
+ return std::memcmp(data(), rhs.data(), kUniqueIDSize) == 0;
+}
+
+const PlasmaStoreInfo* plasma_config;
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/common.fbs b/src/arrow/cpp/src/plasma/common.fbs
new file mode 100644
index 000000000..818827a7e
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/common.fbs
@@ -0,0 +1,39 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+namespace plasma.flatbuf;
+
+// Object information data structure.
+table ObjectInfo {
+ // Object ID of this object.
+ object_id: string;
+ // Number of bytes the content of this object occupies in memory.
+ data_size: long;
+ // Number of bytes the metadata of this object occupies in memory.
+ metadata_size: long;
+ // Number of clients using the objects.
+ ref_count: int;
+ // Unix epoch of when this object was created.
+ create_time: long;
+ // How long creation of this object took.
+ construct_duration: long;
+ // Hash of the object content. If the object is not sealed yet this is
+ // an empty string.
+ digest: string;
+ // Specifies if this object was deleted or added.
+ is_deletion: bool;
+}
diff --git a/src/arrow/cpp/src/plasma/common.h b/src/arrow/cpp/src/plasma/common.h
new file mode 100644
index 000000000..071e55ea3
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/common.h
@@ -0,0 +1,155 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <stddef.h>
+
+#include <cstring>
+#include <memory>
+#include <string>
+// TODO(pcm): Convert getopt and sscanf in the store to use more idiomatic C++
+// and get rid of the next three lines:
+#ifndef __STDC_FORMAT_MACROS
+#define __STDC_FORMAT_MACROS
+#endif
+#include <unordered_map>
+
+#include "plasma/compat.h"
+
+#include "arrow/status.h"
+#ifdef PLASMA_CUDA
+#include "arrow/gpu/cuda_api.h"
+#endif
+
+namespace plasma {
+
+enum class ObjectLocation : int32_t { Local, Remote, NotFound };
+
+enum class PlasmaErrorCode : int8_t {
+ PlasmaObjectExists = 1,
+ PlasmaObjectNotFound = 2,
+ PlasmaStoreFull = 3,
+ PlasmaObjectAlreadySealed = 4,
+};
+
+ARROW_EXPORT arrow::Status MakePlasmaError(PlasmaErrorCode code, std::string message);
+/// Return true iff the status indicates an already existing Plasma object.
+ARROW_EXPORT bool IsPlasmaObjectExists(const arrow::Status& status);
+/// Return true iff the status indicates a nonexistent Plasma object.
+ARROW_EXPORT bool IsPlasmaObjectNotFound(const arrow::Status& status);
+/// Return true iff the status indicates an already sealed Plasma object.
+ARROW_EXPORT bool IsPlasmaObjectAlreadySealed(const arrow::Status& status);
+/// Return true iff the status indicates the Plasma store reached its capacity limit.
+ARROW_EXPORT bool IsPlasmaStoreFull(const arrow::Status& status);
+
+constexpr int64_t kUniqueIDSize = 20;
+
+class ARROW_EXPORT UniqueID {
+ public:
+ static UniqueID from_binary(const std::string& binary);
+ bool operator==(const UniqueID& rhs) const;
+ const uint8_t* data() const;
+ uint8_t* mutable_data();
+ std::string binary() const;
+ std::string hex() const;
+ size_t hash() const;
+ static int64_t size() { return kUniqueIDSize; }
+
+ private:
+ uint8_t id_[kUniqueIDSize];
+};
+
+static_assert(std::is_pod<UniqueID>::value, "UniqueID must be plain old data");
+
+typedef UniqueID ObjectID;
+
+/// Size of object hash digests.
+constexpr int64_t kDigestSize = sizeof(uint64_t);
+
+enum class ObjectState : int {
+ /// Object was created but not sealed in the local Plasma Store.
+ PLASMA_CREATED = 1,
+ /// Object is sealed and stored in the local Plasma Store.
+ PLASMA_SEALED = 2,
+ /// Object is evicted to external store.
+ PLASMA_EVICTED = 3,
+};
+
+namespace internal {
+
+struct CudaIpcPlaceholder {};
+
+} // namespace internal
+
+/// This type is used by the Plasma store. It is here because it is exposed to
+/// the eviction policy.
+struct ObjectTableEntry {
+ ObjectTableEntry();
+
+ ~ObjectTableEntry();
+
+ /// Memory mapped file containing the object.
+ int fd;
+ /// Device number.
+ int device_num;
+ /// Size of the underlying map.
+ int64_t map_size;
+ /// Offset from the base of the mmap.
+ ptrdiff_t offset;
+ /// Pointer to the object data. Needed to free the object.
+ uint8_t* pointer;
+ /// Size of the object in bytes.
+ int64_t data_size;
+ /// Size of the object metadata in bytes.
+ int64_t metadata_size;
+ /// Number of clients currently using this object.
+ int ref_count;
+ /// Unix epoch of when this object was created.
+ int64_t create_time;
+ /// How long creation of this object took.
+ int64_t construct_duration;
+
+ /// The state of the object, e.g., whether it is open or sealed.
+ ObjectState state;
+ /// The digest of the object. Used to see if two objects are the same.
+ unsigned char digest[kDigestSize];
+
+#ifdef PLASMA_CUDA
+ /// IPC GPU handle to share with clients.
+ std::shared_ptr<::arrow::cuda::CudaIpcMemHandle> ipc_handle;
+#else
+ std::shared_ptr<internal::CudaIpcPlaceholder> ipc_handle;
+#endif
+};
+
+/// Mapping from ObjectIDs to information about the object.
+typedef std::unordered_map<ObjectID, std::unique_ptr<ObjectTableEntry>> ObjectTable;
+
+/// Globally accessible reference to plasma store configuration.
+/// TODO(pcm): This can be avoided with some refactoring of existing code
+/// by making it possible to pass a context object through dlmalloc.
+struct PlasmaStoreInfo;
+extern const PlasmaStoreInfo* plasma_config;
+} // namespace plasma
+
+namespace std {
+template <>
+struct hash<::plasma::UniqueID> {
+ size_t operator()(const ::plasma::UniqueID& id) const { return id.hash(); }
+};
+} // namespace std
diff --git a/src/arrow/cpp/src/plasma/common_generated.h b/src/arrow/cpp/src/plasma/common_generated.h
new file mode 100644
index 000000000..ba9ef6e72
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/common_generated.h
@@ -0,0 +1,230 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_COMMON_PLASMA_FLATBUF_H_
+#define FLATBUFFERS_GENERATED_COMMON_PLASMA_FLATBUF_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+namespace plasma {
+namespace flatbuf {
+
+struct ObjectInfo;
+struct ObjectInfoBuilder;
+struct ObjectInfoT;
+
+struct ObjectInfoT : public flatbuffers::NativeTable {
+ typedef ObjectInfo TableType;
+ std::string object_id;
+ int64_t data_size;
+ int64_t metadata_size;
+ int32_t ref_count;
+ int64_t create_time;
+ int64_t construct_duration;
+ std::string digest;
+ bool is_deletion;
+ ObjectInfoT()
+ : data_size(0),
+ metadata_size(0),
+ ref_count(0),
+ create_time(0),
+ construct_duration(0),
+ is_deletion(false) {
+ }
+};
+
+struct ObjectInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ObjectInfoT NativeTableType;
+ typedef ObjectInfoBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_ID = 4,
+ VT_DATA_SIZE = 6,
+ VT_METADATA_SIZE = 8,
+ VT_REF_COUNT = 10,
+ VT_CREATE_TIME = 12,
+ VT_CONSTRUCT_DURATION = 14,
+ VT_DIGEST = 16,
+ VT_IS_DELETION = 18
+ };
+ const flatbuffers::String *object_id() const {
+ return GetPointer<const flatbuffers::String *>(VT_OBJECT_ID);
+ }
+ int64_t data_size() const {
+ return GetField<int64_t>(VT_DATA_SIZE, 0);
+ }
+ int64_t metadata_size() const {
+ return GetField<int64_t>(VT_METADATA_SIZE, 0);
+ }
+ int32_t ref_count() const {
+ return GetField<int32_t>(VT_REF_COUNT, 0);
+ }
+ int64_t create_time() const {
+ return GetField<int64_t>(VT_CREATE_TIME, 0);
+ }
+ int64_t construct_duration() const {
+ return GetField<int64_t>(VT_CONSTRUCT_DURATION, 0);
+ }
+ const flatbuffers::String *digest() const {
+ return GetPointer<const flatbuffers::String *>(VT_DIGEST);
+ }
+ bool is_deletion() const {
+ return GetField<uint8_t>(VT_IS_DELETION, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_ID) &&
+ verifier.VerifyString(object_id()) &&
+ VerifyField<int64_t>(verifier, VT_DATA_SIZE) &&
+ VerifyField<int64_t>(verifier, VT_METADATA_SIZE) &&
+ VerifyField<int32_t>(verifier, VT_REF_COUNT) &&
+ VerifyField<int64_t>(verifier, VT_CREATE_TIME) &&
+ VerifyField<int64_t>(verifier, VT_CONSTRUCT_DURATION) &&
+ VerifyOffset(verifier, VT_DIGEST) &&
+ verifier.VerifyString(digest()) &&
+ VerifyField<uint8_t>(verifier, VT_IS_DELETION) &&
+ verifier.EndTable();
+ }
+ ObjectInfoT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ObjectInfoT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ObjectInfo> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ObjectInfoT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ObjectInfoBuilder {
+ typedef ObjectInfo Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_id(flatbuffers::Offset<flatbuffers::String> object_id) {
+ fbb_.AddOffset(ObjectInfo::VT_OBJECT_ID, object_id);
+ }
+ void add_data_size(int64_t data_size) {
+ fbb_.AddElement<int64_t>(ObjectInfo::VT_DATA_SIZE, data_size, 0);
+ }
+ void add_metadata_size(int64_t metadata_size) {
+ fbb_.AddElement<int64_t>(ObjectInfo::VT_METADATA_SIZE, metadata_size, 0);
+ }
+ void add_ref_count(int32_t ref_count) {
+ fbb_.AddElement<int32_t>(ObjectInfo::VT_REF_COUNT, ref_count, 0);
+ }
+ void add_create_time(int64_t create_time) {
+ fbb_.AddElement<int64_t>(ObjectInfo::VT_CREATE_TIME, create_time, 0);
+ }
+ void add_construct_duration(int64_t construct_duration) {
+ fbb_.AddElement<int64_t>(ObjectInfo::VT_CONSTRUCT_DURATION, construct_duration, 0);
+ }
+ void add_digest(flatbuffers::Offset<flatbuffers::String> digest) {
+ fbb_.AddOffset(ObjectInfo::VT_DIGEST, digest);
+ }
+ void add_is_deletion(bool is_deletion) {
+ fbb_.AddElement<uint8_t>(ObjectInfo::VT_IS_DELETION, static_cast<uint8_t>(is_deletion), 0);
+ }
+ explicit ObjectInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ObjectInfoBuilder &operator=(const ObjectInfoBuilder &);
+ flatbuffers::Offset<ObjectInfo> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ObjectInfo>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ObjectInfo> CreateObjectInfo(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> object_id = 0,
+ int64_t data_size = 0,
+ int64_t metadata_size = 0,
+ int32_t ref_count = 0,
+ int64_t create_time = 0,
+ int64_t construct_duration = 0,
+ flatbuffers::Offset<flatbuffers::String> digest = 0,
+ bool is_deletion = false) {
+ ObjectInfoBuilder builder_(_fbb);
+ builder_.add_construct_duration(construct_duration);
+ builder_.add_create_time(create_time);
+ builder_.add_metadata_size(metadata_size);
+ builder_.add_data_size(data_size);
+ builder_.add_digest(digest);
+ builder_.add_ref_count(ref_count);
+ builder_.add_object_id(object_id);
+ builder_.add_is_deletion(is_deletion);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<ObjectInfo> CreateObjectInfoDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *object_id = nullptr,
+ int64_t data_size = 0,
+ int64_t metadata_size = 0,
+ int32_t ref_count = 0,
+ int64_t create_time = 0,
+ int64_t construct_duration = 0,
+ const char *digest = nullptr,
+ bool is_deletion = false) {
+ auto object_id__ = object_id ? _fbb.CreateString(object_id) : 0;
+ auto digest__ = digest ? _fbb.CreateString(digest) : 0;
+ return plasma::flatbuf::CreateObjectInfo(
+ _fbb,
+ object_id__,
+ data_size,
+ metadata_size,
+ ref_count,
+ create_time,
+ construct_duration,
+ digest__,
+ is_deletion);
+}
+
+flatbuffers::Offset<ObjectInfo> CreateObjectInfo(flatbuffers::FlatBufferBuilder &_fbb, const ObjectInfoT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+inline ObjectInfoT *ObjectInfo::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::ObjectInfoT> _o = std::unique_ptr<plasma::flatbuf::ObjectInfoT>(new ObjectInfoT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void ObjectInfo::UnPackTo(ObjectInfoT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_id(); if (_e) _o->object_id = _e->str(); }
+ { auto _e = data_size(); _o->data_size = _e; }
+ { auto _e = metadata_size(); _o->metadata_size = _e; }
+ { auto _e = ref_count(); _o->ref_count = _e; }
+ { auto _e = create_time(); _o->create_time = _e; }
+ { auto _e = construct_duration(); _o->construct_duration = _e; }
+ { auto _e = digest(); if (_e) _o->digest = _e->str(); }
+ { auto _e = is_deletion(); _o->is_deletion = _e; }
+}
+
+inline flatbuffers::Offset<ObjectInfo> ObjectInfo::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ObjectInfoT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateObjectInfo(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ObjectInfo> CreateObjectInfo(flatbuffers::FlatBufferBuilder &_fbb, const ObjectInfoT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ObjectInfoT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_id = _o->object_id.empty() ? 0 : _fbb.CreateString(_o->object_id);
+ auto _data_size = _o->data_size;
+ auto _metadata_size = _o->metadata_size;
+ auto _ref_count = _o->ref_count;
+ auto _create_time = _o->create_time;
+ auto _construct_duration = _o->construct_duration;
+ auto _digest = _o->digest.empty() ? 0 : _fbb.CreateString(_o->digest);
+ auto _is_deletion = _o->is_deletion;
+ return plasma::flatbuf::CreateObjectInfo(
+ _fbb,
+ _object_id,
+ _data_size,
+ _metadata_size,
+ _ref_count,
+ _create_time,
+ _construct_duration,
+ _digest,
+ _is_deletion);
+}
+
+} // namespace flatbuf
+} // namespace plasma
+
+#endif // FLATBUFFERS_GENERATED_COMMON_PLASMA_FLATBUF_H_
diff --git a/src/arrow/cpp/src/plasma/compat.h b/src/arrow/cpp/src/plasma/compat.h
new file mode 100644
index 000000000..504b523da
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/compat.h
@@ -0,0 +1,32 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+// Workaround for multithreading on XCode 9, see
+// https://issues.apache.org/jira/browse/ARROW-1622 and
+// https://github.com/tensorflow/tensorflow/issues/13220#issuecomment-331579775
+// This should be a short-term fix until the problem is fixed upstream.
+#ifdef __APPLE__
+#ifndef _MACH_PORT_T
+#define _MACH_PORT_T
+#include <sys/_types.h> /* __darwin_mach_port_t */
+typedef __darwin_mach_port_t mach_port_t;
+#include <pthread.h>
+mach_port_t pthread_mach_thread_np(pthread_t);
+#endif /* _MACH_PORT_T */
+#endif /* __APPLE__ */
diff --git a/src/arrow/cpp/src/plasma/dlmalloc.cc b/src/arrow/cpp/src/plasma/dlmalloc.cc
new file mode 100644
index 000000000..463e967e0
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/dlmalloc.cc
@@ -0,0 +1,166 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "plasma/malloc.h"
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <unistd.h>
+
+#include <cerrno>
+#include <string>
+#include <vector>
+
+#include "plasma/common.h"
+#include "plasma/plasma.h"
+
+namespace plasma {
+
+void* fake_mmap(size_t);
+int fake_munmap(void*, int64_t);
+
+#define MMAP(s) fake_mmap(s)
+#define MUNMAP(a, s) fake_munmap(a, s)
+#define DIRECT_MMAP(s) fake_mmap(s)
+#define DIRECT_MUNMAP(a, s) fake_munmap(a, s)
+#define USE_DL_PREFIX
+#define HAVE_MORECORE 0
+#define DEFAULT_MMAP_THRESHOLD MAX_SIZE_T
+#define DEFAULT_GRANULARITY ((size_t)128U * 1024U)
+
+#include "plasma/thirdparty/dlmalloc.c" // NOLINT
+
+#undef MMAP
+#undef MUNMAP
+#undef DIRECT_MMAP
+#undef DIRECT_MUNMAP
+#undef USE_DL_PREFIX
+#undef HAVE_MORECORE
+#undef DEFAULT_GRANULARITY
+
+// dlmalloc.c defined DEBUG which will conflict with ARROW_LOG(DEBUG).
+#ifdef DEBUG
+#undef DEBUG
+#endif
+
+constexpr int GRANULARITY_MULTIPLIER = 2;
+
+static void* pointer_advance(void* p, ptrdiff_t n) { return (unsigned char*)p + n; }
+
+static void* pointer_retreat(void* p, ptrdiff_t n) { return (unsigned char*)p - n; }
+
+// Create a buffer. This is creating a temporary file and then
+// immediately unlinking it so we do not leave traces in the system.
+int create_buffer(int64_t size) {
+ int fd;
+ std::string file_template = plasma_config->directory;
+#ifdef _WIN32
+ if (!CreateFileMapping(INVALID_HANDLE_VALUE, NULL, PAGE_READWRITE,
+ (DWORD)((uint64_t)size >> (CHAR_BIT * sizeof(DWORD))),
+ (DWORD)(uint64_t)size, NULL)) {
+ fd = -1;
+ }
+#else
+ file_template += "/plasmaXXXXXX";
+ std::vector<char> file_name(file_template.begin(), file_template.end());
+ file_name.push_back('\0');
+ fd = mkstemp(&file_name[0]);
+ if (fd < 0) {
+ ARROW_LOG(FATAL) << "create_buffer failed to open file " << &file_name[0];
+ return -1;
+ }
+ // Immediately unlink the file so we do not leave traces in the system.
+ if (unlink(&file_name[0]) != 0) {
+ ARROW_LOG(FATAL) << "failed to unlink file " << &file_name[0];
+ return -1;
+ }
+ if (!plasma_config->hugepages_enabled) {
+ // Increase the size of the file to the desired size. This seems not to be
+ // needed for files that are backed by the huge page fs, see also
+ // http://www.mail-archive.com/kvm-devel@lists.sourceforge.net/msg14737.html
+ if (ftruncate(fd, (off_t)size) != 0) {
+ ARROW_LOG(FATAL) << "failed to ftruncate file " << &file_name[0];
+ return -1;
+ }
+ }
+#endif
+ return fd;
+}
+
+void* fake_mmap(size_t size) {
+ // Add kMmapRegionsGap so that the returned pointer is deliberately not
+ // page-aligned. This ensures that the segments of memory returned by
+ // fake_mmap are never contiguous.
+ size += kMmapRegionsGap;
+
+ int fd = create_buffer(size);
+ ARROW_CHECK(fd >= 0) << "Failed to create buffer during mmap";
+ // MAP_POPULATE can be used to pre-populate the page tables for this memory region
+ // which avoids work when accessing the pages later. However it causes long pauses
+ // when mmapping the files. Only supported on Linux.
+ void* pointer = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
+ if (pointer == MAP_FAILED) {
+ ARROW_LOG(ERROR) << "mmap failed with error: " << std::strerror(errno);
+ if (errno == ENOMEM && plasma_config->hugepages_enabled) {
+ ARROW_LOG(ERROR)
+ << " (this probably means you have to increase /proc/sys/vm/nr_hugepages)";
+ }
+ return pointer;
+ }
+
+ // Increase dlmalloc's allocation granularity directly.
+ mparams.granularity *= GRANULARITY_MULTIPLIER;
+
+ MmapRecord& record = mmap_records[pointer];
+ record.fd = fd;
+ record.size = size;
+
+ // We lie to dlmalloc about where mapped memory actually lives.
+ pointer = pointer_advance(pointer, kMmapRegionsGap);
+ ARROW_LOG(DEBUG) << pointer << " = fake_mmap(" << size << ")";
+ return pointer;
+}
+
+int fake_munmap(void* addr, int64_t size) {
+ ARROW_LOG(DEBUG) << "fake_munmap(" << addr << ", " << size << ")";
+ addr = pointer_retreat(addr, kMmapRegionsGap);
+ size += kMmapRegionsGap;
+
+ auto entry = mmap_records.find(addr);
+
+ if (entry == mmap_records.end() || entry->second.size != size) {
+ // Reject requests to munmap that don't directly match previous
+ // calls to mmap, to prevent dlmalloc from trimming.
+ return -1;
+ }
+
+ int r = munmap(addr, size);
+ if (r == 0) {
+ close(entry->second.fd);
+ }
+
+ mmap_records.erase(entry);
+ return r;
+}
+
+void SetMallocGranularity(int value) { change_mparam(M_GRANULARITY, value); }
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/events.cc b/src/arrow/cpp/src/plasma/events.cc
new file mode 100644
index 000000000..28ff12675
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/events.cc
@@ -0,0 +1,107 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "plasma/events.h"
+
+#include <utility>
+
+#include <errno.h>
+
+extern "C" {
+#include "plasma/thirdparty/ae/ae.h"
+}
+
+namespace plasma {
+
+// Verify that the constants defined in events.h are defined correctly.
+static_assert(kEventLoopTimerDone == AE_NOMORE, "constant defined incorrectly");
+static_assert(kEventLoopOk == AE_OK, "constant defined incorrectly");
+static_assert(kEventLoopRead == AE_READABLE, "constant defined incorrectly");
+static_assert(kEventLoopWrite == AE_WRITABLE, "constant defined incorrectly");
+
+void EventLoop::FileEventCallback(aeEventLoop* loop, int fd, void* context, int events) {
+ FileCallback* callback = reinterpret_cast<FileCallback*>(context);
+ (*callback)(events);
+}
+
+int EventLoop::TimerEventCallback(aeEventLoop* loop, TimerID timer_id, void* context) {
+ TimerCallback* callback = reinterpret_cast<TimerCallback*>(context);
+ return (*callback)(timer_id);
+}
+
+constexpr int kInitialEventLoopSize = 1024;
+
+EventLoop::EventLoop() { loop_ = aeCreateEventLoop(kInitialEventLoopSize); }
+
+bool EventLoop::AddFileEvent(int fd, int events, const FileCallback& callback) {
+ if (file_callbacks_.find(fd) != file_callbacks_.end()) {
+ return false;
+ }
+ auto data = std::unique_ptr<FileCallback>(new FileCallback(callback));
+ void* context = reinterpret_cast<void*>(data.get());
+ // Try to add the file descriptor.
+ int err = aeCreateFileEvent(loop_, fd, events, EventLoop::FileEventCallback, context);
+ // If it cannot be added, increase the size of the event loop.
+ if (err == AE_ERR && errno == ERANGE) {
+ err = aeResizeSetSize(loop_, 3 * aeGetSetSize(loop_) / 2);
+ if (err != AE_OK) {
+ return false;
+ }
+ err = aeCreateFileEvent(loop_, fd, events, EventLoop::FileEventCallback, context);
+ }
+ // In any case, test if there were errors.
+ if (err == AE_OK) {
+ file_callbacks_.emplace(fd, std::move(data));
+ return true;
+ }
+ return false;
+}
+
+void EventLoop::RemoveFileEvent(int fd) {
+ aeDeleteFileEvent(loop_, fd, AE_READABLE | AE_WRITABLE);
+ file_callbacks_.erase(fd);
+}
+
+void EventLoop::Start() { aeMain(loop_); }
+
+void EventLoop::Stop() { aeStop(loop_); }
+
+void EventLoop::Shutdown() {
+ if (loop_ != nullptr) {
+ aeDeleteEventLoop(loop_);
+ loop_ = nullptr;
+ }
+}
+
+EventLoop::~EventLoop() { Shutdown(); }
+
+int64_t EventLoop::AddTimer(int64_t timeout, const TimerCallback& callback) {
+ auto data = std::unique_ptr<TimerCallback>(new TimerCallback(callback));
+ void* context = reinterpret_cast<void*>(data.get());
+ int64_t timer_id =
+ aeCreateTimeEvent(loop_, timeout, EventLoop::TimerEventCallback, context, NULL);
+ timer_callbacks_.emplace(timer_id, std::move(data));
+ return timer_id;
+}
+
+int EventLoop::RemoveTimer(int64_t timer_id) {
+ int err = aeDeleteTimeEvent(loop_, timer_id);
+ timer_callbacks_.erase(timer_id);
+ return err;
+}
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/events.h b/src/arrow/cpp/src/plasma/events.h
new file mode 100644
index 000000000..7b08d4443
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/events.h
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <unordered_map>
+
+struct aeEventLoop;
+
+namespace plasma {
+
+// The constants below are defined using hardcoded values taken from ae.h so
+// that ae.h does not need to be included in this file.
+
+/// Constant specifying that the timer is done and it will be removed.
+constexpr int kEventLoopTimerDone = -1; // AE_NOMORE
+
+/// A successful status.
+constexpr int kEventLoopOk = 0; // AE_OK
+
+/// Read event on the file descriptor.
+constexpr int kEventLoopRead = 1; // AE_READABLE
+
+/// Write event on the file descriptor.
+constexpr int kEventLoopWrite = 2; // AE_WRITABLE
+
+typedef long long TimerID; // NOLINT
+
+class EventLoop {
+ public:
+ // Signature of the handler that will be called when there is a new event
+ // on the file descriptor that this handler has been registered for.
+ //
+ // The arguments are the event flags (read or write).
+ using FileCallback = std::function<void(int)>;
+
+ // This handler will be called when a timer times out. The timer id is
+ // passed as an argument. The return is the number of milliseconds the timer
+ // shall be reset to or kEventLoopTimerDone if the timer shall not be
+ // triggered again.
+ using TimerCallback = std::function<int(int64_t)>;
+
+ EventLoop();
+
+ ~EventLoop();
+
+ /// Add a new file event handler to the event loop.
+ ///
+ /// \param fd The file descriptor we are listening to.
+ /// \param events The flags for events we are listening to (read or write).
+ /// \param callback The callback that will be called when the event happens.
+ /// \return Returns true if the event handler was added successfully.
+ bool AddFileEvent(int fd, int events, const FileCallback& callback);
+
+ /// Remove a file event handler from the event loop.
+ ///
+ /// \param fd The file descriptor of the event handler.
+ void RemoveFileEvent(int fd);
+
+ /// Register a handler that will be called after a time slice of
+ /// "timeout" milliseconds.
+ ///
+ /// \param timeout The timeout in milliseconds.
+ /// \param callback The callback for the timeout.
+ /// \return The ID of the newly created timer.
+ int64_t AddTimer(int64_t timeout, const TimerCallback& callback);
+
+ /// Remove a timer handler from the event loop.
+ ///
+ /// \param timer_id The ID of the timer that is to be removed.
+ /// \return The ae.c error code. TODO(pcm): needs to be standardized
+ int RemoveTimer(int64_t timer_id);
+
+ /// \brief Run the event loop.
+ void Start();
+
+ /// \brief Stop the event loop
+ void Stop();
+
+ void Shutdown();
+
+ private:
+ static void FileEventCallback(aeEventLoop* loop, int fd, void* context, int events);
+
+ static int TimerEventCallback(aeEventLoop* loop, TimerID timer_id, void* context);
+
+ aeEventLoop* loop_;
+ std::unordered_map<int, std::unique_ptr<FileCallback>> file_callbacks_;
+ std::unordered_map<int64_t, std::unique_ptr<TimerCallback>> timer_callbacks_;
+};
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/eviction_policy.cc b/src/arrow/cpp/src/plasma/eviction_policy.cc
new file mode 100644
index 000000000..c3b786785
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/eviction_policy.cc
@@ -0,0 +1,175 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "plasma/eviction_policy.h"
+#include "plasma/plasma_allocator.h"
+
+#include <algorithm>
+#include <sstream>
+
+namespace plasma {
+
+void LRUCache::Add(const ObjectID& key, int64_t size) {
+ auto it = item_map_.find(key);
+ ARROW_CHECK(it == item_map_.end());
+ // Note that it is important to use a list so the iterators stay valid.
+ item_list_.emplace_front(key, size);
+ item_map_.emplace(key, item_list_.begin());
+ used_capacity_ += size;
+}
+
+int64_t LRUCache::Remove(const ObjectID& key) {
+ auto it = item_map_.find(key);
+ if (it == item_map_.end()) {
+ return -1;
+ }
+ int64_t size = it->second->second;
+ used_capacity_ -= size;
+ item_list_.erase(it->second);
+ item_map_.erase(it);
+ ARROW_CHECK(used_capacity_ >= 0) << DebugString();
+ return size;
+}
+
+void LRUCache::AdjustCapacity(int64_t delta) {
+ ARROW_LOG(INFO) << "adjusting global lru capacity from " << Capacity() << " to "
+ << (Capacity() + delta) << " (max " << OriginalCapacity() << ")";
+ capacity_ += delta;
+ ARROW_CHECK(used_capacity_ >= 0) << DebugString();
+}
+
+int64_t LRUCache::Capacity() const { return capacity_; }
+
+int64_t LRUCache::OriginalCapacity() const { return original_capacity_; }
+
+int64_t LRUCache::RemainingCapacity() const { return capacity_ - used_capacity_; }
+
+void LRUCache::Foreach(std::function<void(const ObjectID&)> f) {
+ for (auto& pair : item_list_) {
+ f(pair.first);
+ }
+}
+
+std::string LRUCache::DebugString() const {
+ std::stringstream result;
+ result << "\n(" << name_ << ") capacity: " << Capacity();
+ result << "\n(" << name_
+ << ") used: " << 100. * (1. - (RemainingCapacity() / (double)OriginalCapacity()))
+ << "%";
+ result << "\n(" << name_ << ") num objects: " << item_map_.size();
+ result << "\n(" << name_ << ") num evictions: " << num_evictions_total_;
+ result << "\n(" << name_ << ") bytes evicted: " << bytes_evicted_total_;
+ return result.str();
+}
+
+int64_t LRUCache::ChooseObjectsToEvict(int64_t num_bytes_required,
+ std::vector<ObjectID>* objects_to_evict) {
+ int64_t bytes_evicted = 0;
+ auto it = item_list_.end();
+ while (bytes_evicted < num_bytes_required && it != item_list_.begin()) {
+ it--;
+ objects_to_evict->push_back(it->first);
+ bytes_evicted += it->second;
+ bytes_evicted_total_ += it->second;
+ num_evictions_total_ += 1;
+ }
+ return bytes_evicted;
+}
+
+EvictionPolicy::EvictionPolicy(PlasmaStoreInfo* store_info, int64_t max_size)
+ : pinned_memory_bytes_(0), store_info_(store_info), cache_("global lru", max_size) {}
+
+int64_t EvictionPolicy::ChooseObjectsToEvict(int64_t num_bytes_required,
+ std::vector<ObjectID>* objects_to_evict) {
+ int64_t bytes_evicted =
+ cache_.ChooseObjectsToEvict(num_bytes_required, objects_to_evict);
+ // Update the LRU cache.
+ for (auto& object_id : *objects_to_evict) {
+ cache_.Remove(object_id);
+ }
+ return bytes_evicted;
+}
+
+void EvictionPolicy::ObjectCreated(const ObjectID& object_id, Client* client,
+ bool is_create) {
+ cache_.Add(object_id, GetObjectSize(object_id));
+}
+
+bool EvictionPolicy::SetClientQuota(Client* client, int64_t output_memory_quota) {
+ return false;
+}
+
+bool EvictionPolicy::EnforcePerClientQuota(Client* client, int64_t size, bool is_create,
+ std::vector<ObjectID>* objects_to_evict) {
+ return true;
+}
+
+void EvictionPolicy::ClientDisconnected(Client* client) {}
+
+bool EvictionPolicy::RequireSpace(int64_t size, std::vector<ObjectID>* objects_to_evict) {
+ // Check if there is enough space to create the object.
+ int64_t required_space =
+ PlasmaAllocator::Allocated() + size - PlasmaAllocator::GetFootprintLimit();
+ // Try to free up at least as much space as we need right now but ideally
+ // up to 20% of the total capacity.
+ int64_t space_to_free =
+ std::max(required_space, PlasmaAllocator::GetFootprintLimit() / 5);
+ ARROW_LOG(DEBUG) << "not enough space to create this object, so evicting objects";
+ // Choose some objects to evict, and update the return pointers.
+ int64_t num_bytes_evicted = ChooseObjectsToEvict(space_to_free, objects_to_evict);
+ ARROW_LOG(INFO) << "There is not enough space to create this object, so evicting "
+ << objects_to_evict->size() << " objects to free up "
+ << num_bytes_evicted << " bytes. The number of bytes in use (before "
+ << "this eviction) is " << PlasmaAllocator::Allocated() << ".";
+ return num_bytes_evicted >= required_space && num_bytes_evicted > 0;
+}
+
+void EvictionPolicy::BeginObjectAccess(const ObjectID& object_id) {
+ // If the object is in the LRU cache, remove it.
+ cache_.Remove(object_id);
+ pinned_memory_bytes_ += GetObjectSize(object_id);
+}
+
+void EvictionPolicy::EndObjectAccess(const ObjectID& object_id) {
+ auto size = GetObjectSize(object_id);
+ // Add the object to the LRU cache.
+ cache_.Add(object_id, size);
+ pinned_memory_bytes_ -= size;
+}
+
+void EvictionPolicy::RemoveObject(const ObjectID& object_id) {
+ // If the object is in the LRU cache, remove it.
+ cache_.Remove(object_id);
+}
+
+void EvictionPolicy::RefreshObjects(const std::vector<ObjectID>& object_ids) {
+ for (const auto& object_id : object_ids) {
+ int64_t size = cache_.Remove(object_id);
+ if (size != -1) {
+ cache_.Add(object_id, size);
+ }
+ }
+}
+
+int64_t EvictionPolicy::GetObjectSize(const ObjectID& object_id) const {
+ auto entry = store_info_->objects[object_id].get();
+ return entry->data_size + entry->metadata_size;
+}
+
+std::string EvictionPolicy::DebugString() const { return cache_.DebugString(); }
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/eviction_policy.h b/src/arrow/cpp/src/plasma/eviction_policy.h
new file mode 100644
index 000000000..6c13ecf6b
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/eviction_policy.h
@@ -0,0 +1,209 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <functional>
+#include <list>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "plasma/common.h"
+#include "plasma/plasma.h"
+
+namespace plasma {
+
+// ==== The eviction policy ====
+//
+// This file contains declaration for all functions and data structures that
+// need to be provided if you want to implement a new eviction algorithm for the
+// Plasma store.
+//
+// It does not implement memory quotas; see quota_aware_policy for that.
+
+class LRUCache {
+ public:
+ LRUCache(const std::string& name, int64_t size)
+ : name_(name),
+ original_capacity_(size),
+ capacity_(size),
+ used_capacity_(0),
+ num_evictions_total_(0),
+ bytes_evicted_total_(0) {}
+
+ void Add(const ObjectID& key, int64_t size);
+
+ int64_t Remove(const ObjectID& key);
+
+ int64_t ChooseObjectsToEvict(int64_t num_bytes_required,
+ std::vector<ObjectID>* objects_to_evict);
+
+ int64_t OriginalCapacity() const;
+
+ int64_t Capacity() const;
+
+ int64_t RemainingCapacity() const;
+
+ void AdjustCapacity(int64_t delta);
+
+ void Foreach(std::function<void(const ObjectID&)>);
+
+ std::string DebugString() const;
+
+ private:
+ /// A doubly-linked list containing the items in the cache and
+ /// their sizes in LRU order.
+ typedef std::list<std::pair<ObjectID, int64_t>> ItemList;
+ ItemList item_list_;
+ /// A hash table mapping the object ID of an object in the cache to its
+ /// location in the doubly linked list item_list_.
+ std::unordered_map<ObjectID, ItemList::iterator> item_map_;
+
+ /// The name of this cache, used for debugging purposes only.
+ const std::string name_;
+ /// The original (max) capacity of this cache in bytes.
+ const int64_t original_capacity_;
+ /// The current capacity, which must be <= the original capacity.
+ int64_t capacity_;
+ /// The number of bytes used of the available capacity.
+ int64_t used_capacity_;
+ /// The number of objects evicted from this cache.
+ int64_t num_evictions_total_;
+ /// The number of bytes evicted from this cache.
+ int64_t bytes_evicted_total_;
+};
+
+/// The eviction policy.
+class EvictionPolicy {
+ public:
+ /// Construct an eviction policy.
+ ///
+ /// \param store_info Information about the Plasma store that is exposed
+ /// to the eviction policy.
+ /// \param max_size Max size in bytes total of objects to store.
+ explicit EvictionPolicy(PlasmaStoreInfo* store_info, int64_t max_size);
+
+ /// Destroy an eviction policy.
+ virtual ~EvictionPolicy() {}
+
+ /// This method will be called whenever an object is first created in order to
+ /// add it to the LRU cache. This is done so that the first time, the Plasma
+ /// store calls begin_object_access, we can remove the object from the LRU
+ /// cache.
+ ///
+ /// \param object_id The object ID of the object that was created.
+ /// \param client The pointer to the client.
+ /// \param is_create Whether we are creating a new object (vs reading an object).
+ virtual void ObjectCreated(const ObjectID& object_id, Client* client, bool is_create);
+
+ /// Set quota for a client.
+ ///
+ /// \param client The pointer to the client.
+ /// \param output_memory_quota Set the quota for this client. This can only be
+ /// called once per client. This is effectively the equivalent of giving
+ /// the client its own LRU cache instance. The memory for this is taken
+ /// out of the capacity of the global LRU cache for the client lifetime.
+ ///
+ /// \return True if enough space can be reserved for the given client quota.
+ virtual bool SetClientQuota(Client* client, int64_t output_memory_quota);
+
+ /// Determine what objects need to be evicted to enforce the given client's quota.
+ ///
+ /// \param client The pointer to the client creating the object.
+ /// \param size The size of the object to create.
+ /// \param is_create Whether we are creating a new object (vs reading an object).
+ /// \param objects_to_evict The object IDs that were chosen for eviction will
+ /// be stored into this vector.
+ ///
+ /// \return True if enough space could be freed and false otherwise.
+ virtual bool EnforcePerClientQuota(Client* client, int64_t size, bool is_create,
+ std::vector<ObjectID>* objects_to_evict);
+
+ /// Called to clean up any resources allocated by this client. This merges any
+ /// per-client LRU queue created by SetClientQuota into the global LRU queue.
+ ///
+ /// \param client The pointer to the client.
+ virtual void ClientDisconnected(Client* client);
+
+ /// This method will be called when the Plasma store needs more space, perhaps
+ /// to create a new object. When this method is called, the eviction
+ /// policy will assume that the objects chosen to be evicted will in fact be
+ /// evicted from the Plasma store by the caller.
+ ///
+ /// \param size The size in bytes of the new object, including both data and
+ /// metadata.
+ /// \param objects_to_evict The object IDs that were chosen for eviction will
+ /// be stored into this vector.
+ /// \return True if enough space can be freed and false otherwise.
+ virtual bool RequireSpace(int64_t size, std::vector<ObjectID>* objects_to_evict);
+
+ /// This method will be called whenever an unused object in the Plasma store
+ /// starts to be used. When this method is called, the eviction policy will
+ /// assume that the objects chosen to be evicted will in fact be evicted from
+ /// the Plasma store by the caller.
+ ///
+ /// \param object_id The ID of the object that is now being used.
+ virtual void BeginObjectAccess(const ObjectID& object_id);
+
+ /// This method will be called whenever an object in the Plasma store that was
+ /// being used is no longer being used. When this method is called, the
+ /// eviction policy will assume that the objects chosen to be evicted will in
+ /// fact be evicted from the Plasma store by the caller.
+ ///
+ /// \param object_id The ID of the object that is no longer being used.
+ virtual void EndObjectAccess(const ObjectID& object_id);
+
+ /// Choose some objects to evict from the Plasma store. When this method is
+ /// called, the eviction policy will assume that the objects chosen to be
+ /// evicted will in fact be evicted from the Plasma store by the caller.
+ ///
+ /// \note This method is not part of the API. It is exposed in the header file
+ /// only for testing.
+ ///
+ /// \param num_bytes_required The number of bytes of space to try to free up.
+ /// \param objects_to_evict The object IDs that were chosen for eviction will
+ /// be stored into this vector.
+ /// \return The total number of bytes of space chosen to be evicted.
+ virtual int64_t ChooseObjectsToEvict(int64_t num_bytes_required,
+ std::vector<ObjectID>* objects_to_evict);
+
+ /// This method will be called when an object is going to be removed
+ ///
+ /// \param object_id The ID of the object that is now being used.
+ virtual void RemoveObject(const ObjectID& object_id);
+
+ virtual void RefreshObjects(const std::vector<ObjectID>& object_ids);
+
+ /// Returns debugging information for this eviction policy.
+ virtual std::string DebugString() const;
+
+ protected:
+ /// Returns the size of the object
+ int64_t GetObjectSize(const ObjectID& object_id) const;
+
+ /// The number of bytes pinned by applications.
+ int64_t pinned_memory_bytes_;
+
+ /// Pointer to the plasma store info.
+ PlasmaStoreInfo* store_info_;
+ /// Datastructure for the LRU cache.
+ LRUCache cache_;
+};
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/external_store.cc b/src/arrow/cpp/src/plasma/external_store.cc
new file mode 100644
index 000000000..8cfbad179
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/external_store.cc
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <iostream>
+#include <sstream>
+
+#include "arrow/util/memory.h"
+
+#include "plasma/external_store.h"
+
+namespace plasma {
+
+Status ExternalStores::ExtractStoreName(const std::string& endpoint,
+ std::string* store_name) {
+ size_t off = endpoint.find_first_of(':');
+ if (off == std::string::npos) {
+ return Status::Invalid("Malformed endpoint " + endpoint);
+ }
+ *store_name = endpoint.substr(0, off);
+ return Status::OK();
+}
+
+void ExternalStores::RegisterStore(const std::string& store_name,
+ std::shared_ptr<ExternalStore> store) {
+ Stores().insert({store_name, store});
+}
+
+void ExternalStores::DeregisterStore(const std::string& store_name) {
+ auto it = Stores().find(store_name);
+ if (it == Stores().end()) {
+ return;
+ }
+ Stores().erase(it);
+}
+
+std::shared_ptr<ExternalStore> ExternalStores::GetStore(const std::string& store_name) {
+ auto it = Stores().find(store_name);
+ if (it == Stores().end()) {
+ return nullptr;
+ }
+ return it->second;
+}
+
+ExternalStores::StoreMap& ExternalStores::Stores() {
+ static auto* external_stores = new StoreMap();
+ return *external_stores;
+}
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/external_store.h b/src/arrow/cpp/src/plasma/external_store.h
new file mode 100644
index 000000000..c089d06c1
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/external_store.h
@@ -0,0 +1,120 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "plasma/client.h"
+
+namespace plasma {
+
+// ==== The external store ====
+//
+// This file contains declaration for all functions that need to be implemented
+// for an external storage service so that objects evicted from Plasma store
+// can be written to it.
+
+class ExternalStore {
+ public:
+ /// Default constructor.
+ ExternalStore() = default;
+
+ /// Virtual destructor.
+ virtual ~ExternalStore() = default;
+
+ /// Connect to the local plasma store. Return the resulting connection.
+ ///
+ /// \param endpoint The name of the endpoint to connect to the external
+ /// storage service. While the formatting of the endpoint name is
+ /// specific to the implementation of the external store, it always
+ /// starts with {store-name}://, where {store-name} is the name of the
+ /// external store.
+ ///
+ /// \return The return status.
+ virtual Status Connect(const std::string& endpoint) = 0;
+
+ /// This method will be called whenever an object in the Plasma store needs
+ /// to be evicted to the external store.
+ ///
+ /// This API is experimental and might change in the future.
+ ///
+ /// \param ids The IDs of the objects to put.
+ /// \param data The object data to put.
+ /// \return The return status.
+ virtual Status Put(const std::vector<ObjectID>& ids,
+ const std::vector<std::shared_ptr<Buffer>>& data) = 0;
+
+ /// This method will be called whenever an evicted object in the external
+ /// store store needs to be accessed.
+ ///
+ /// This API is experimental and might change in the future.
+ ///
+ /// \param ids The IDs of the objects to get.
+ /// \param buffers List of buffers the data should be written to.
+ /// \return The return status.
+ virtual Status Get(const std::vector<ObjectID>& ids,
+ std::vector<std::shared_ptr<Buffer>> buffers) = 0;
+};
+
+class ExternalStores {
+ public:
+ typedef std::unordered_map<std::string, std::shared_ptr<ExternalStore>> StoreMap;
+ /// Extracts the external store name from the external store endpoint.
+ ///
+ /// \param endpoint The endpoint for the external store.
+ /// \param[out] store_name The name of the external store.
+ /// \return The return status.
+ static Status ExtractStoreName(const std::string& endpoint, std::string* store_name);
+
+ /// Register a new external store.
+ ///
+ /// \param store_name Name of the new external store.
+ /// \param store The new external store object.
+ static void RegisterStore(const std::string& store_name,
+ std::shared_ptr<ExternalStore> store);
+
+ /// Remove an external store from the registry.
+ ///
+ /// \param store_name Name of the external store to remove.
+ static void DeregisterStore(const std::string& store_name);
+
+ /// Obtain the external store given its name.
+ ///
+ /// \param store_name Name of the external store.
+ /// \return The external store object.
+ static std::shared_ptr<ExternalStore> GetStore(const std::string& store_name);
+
+ private:
+ /// Obtain mapping between external store names and store instances.
+ ///
+ /// \return Mapping between external store names and store instances.
+ static StoreMap& Stores();
+};
+
+#define REGISTER_EXTERNAL_STORE(name, store) \
+ class store##Class { \
+ public: \
+ store##Class() { ExternalStores::RegisterStore(name, std::make_shared<store>()); } \
+ ~store##Class() { ExternalStores::DeregisterStore(name); } \
+ }; \
+ store##Class singleton_##store = store##Class()
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/fling.cc b/src/arrow/cpp/src/plasma/fling.cc
new file mode 100644
index 000000000..f0960aab6
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/fling.cc
@@ -0,0 +1,129 @@
+// Copyright 2013 Sharvil Nanavati
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "plasma/fling.h"
+
+#include <string.h>
+
+#include "arrow/util/logging.h"
+
+void init_msg(struct msghdr* msg, struct iovec* iov, char* buf, size_t buf_len) {
+ iov->iov_base = buf;
+ iov->iov_len = 1;
+
+ msg->msg_iov = iov;
+ msg->msg_iovlen = 1;
+ msg->msg_control = buf;
+ msg->msg_controllen = static_cast<socklen_t>(buf_len);
+ msg->msg_name = NULL;
+ msg->msg_namelen = 0;
+}
+
+int send_fd(int conn, int fd) {
+ struct msghdr msg;
+ struct iovec iov;
+ char buf[CMSG_SPACE(sizeof(int))];
+ memset(&buf, 0, CMSG_SPACE(sizeof(int)));
+
+ init_msg(&msg, &iov, buf, sizeof(buf));
+
+ struct cmsghdr* header = CMSG_FIRSTHDR(&msg);
+ if (header == nullptr) {
+ return -1;
+ }
+ header->cmsg_level = SOL_SOCKET;
+ header->cmsg_type = SCM_RIGHTS;
+ header->cmsg_len = CMSG_LEN(sizeof(int));
+ memcpy(CMSG_DATA(header), reinterpret_cast<void*>(&fd), sizeof(int));
+
+ // Send file descriptor.
+ while (true) {
+ ssize_t r = sendmsg(conn, &msg, 0);
+ if (r < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
+ continue;
+ } else if (errno == EMSGSIZE) {
+ ARROW_LOG(WARNING) << "Failed to send file descriptor"
+ << " (errno = EMSGSIZE), retrying.";
+ // If we failed to send the file descriptor, loop until we have sent it
+ // successfully. TODO(rkn): This is problematic for two reasons. First
+ // of all, sending the file descriptor should just succeed without any
+ // errors, but sometimes I see a "Message too long" error number.
+ // Second, looping like this allows a client to potentially block the
+ // plasma store event loop which should never happen.
+ continue;
+ } else {
+ ARROW_LOG(INFO) << "Error in send_fd (errno = " << errno << ")";
+ return static_cast<int>(r);
+ }
+ } else if (r == 0) {
+ ARROW_LOG(INFO) << "Encountered unexpected EOF";
+ return 0;
+ } else {
+ ARROW_CHECK(r > 0);
+ return static_cast<int>(r);
+ }
+ }
+}
+
+int recv_fd(int conn) {
+ struct msghdr msg;
+ struct iovec iov;
+ char buf[CMSG_SPACE(sizeof(int))];
+ init_msg(&msg, &iov, buf, sizeof(buf));
+
+ while (true) {
+ ssize_t r = recvmsg(conn, &msg, 0);
+ if (r == -1) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
+ continue;
+ } else {
+ ARROW_LOG(INFO) << "Error in recv_fd (errno = " << errno << ")";
+ return -1;
+ }
+ } else {
+ break;
+ }
+ }
+
+ int found_fd = -1;
+ int oh_noes = 0;
+ for (struct cmsghdr* header = CMSG_FIRSTHDR(&msg); header != NULL;
+ header = CMSG_NXTHDR(&msg, header))
+ if (header->cmsg_level == SOL_SOCKET && header->cmsg_type == SCM_RIGHTS) {
+ ssize_t count = (header->cmsg_len -
+ (CMSG_DATA(header) - reinterpret_cast<unsigned char*>(header))) /
+ sizeof(int);
+ for (int i = 0; i < count; ++i) {
+ int fd = (reinterpret_cast<int*>(CMSG_DATA(header)))[i];
+ if (found_fd == -1) {
+ found_fd = fd;
+ } else {
+ close(fd);
+ oh_noes = 1;
+ }
+ }
+ }
+
+ // The sender sent us more than one file descriptor. We've closed
+ // them all to prevent fd leaks but notify the caller that we got
+ // a bad message.
+ if (oh_noes) {
+ close(found_fd);
+ errno = EBADMSG;
+ return -1;
+ }
+
+ return found_fd;
+}
diff --git a/src/arrow/cpp/src/plasma/fling.h b/src/arrow/cpp/src/plasma/fling.h
new file mode 100644
index 000000000..d1582c3c8
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/fling.h
@@ -0,0 +1,52 @@
+// Copyright 2013 Sharvil Nanavati
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// FLING: Exchanging file descriptors over sockets
+//
+// This is a little library for sending file descriptors over a socket
+// between processes. The reason for doing that (as opposed to using
+// filenames to share the files) is so (a) no files remain in the
+// filesystem after all the processes terminate, (b) to make sure that
+// there are no name collisions and (c) to be able to control who has
+// access to the data.
+//
+// Most of the code is from https://github.com/sharvil/flingfd
+
+#include <errno.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+// This is necessary for Mac OS X, see http://www.apuebook.com/faqs2e.html
+// (10).
+#if !defined(CMSG_SPACE) && !defined(CMSG_LEN)
+#define CMSG_SPACE(len) (__DARWIN_ALIGN32(sizeof(struct cmsghdr)) + __DARWIN_ALIGN32(len))
+#define CMSG_LEN(len) (__DARWIN_ALIGN32(sizeof(struct cmsghdr)) + (len))
+#endif
+
+void init_msg(struct msghdr* msg, struct iovec* iov, char* buf, size_t buf_len);
+
+// Send a file descriptor over a unix domain socket.
+//
+// \param conn Unix domain socket to send the file descriptor over.
+// \param fd File descriptor to send over.
+// \return Status code which is < 0 on failure.
+int send_fd(int conn, int fd);
+
+// Receive a file descriptor over a unix domain socket.
+//
+// \param conn Unix domain socket to receive the file descriptor from.
+// \return File descriptor or a value < 0 on failure.
+int recv_fd(int conn);
diff --git a/src/arrow/cpp/src/plasma/hash_table_store.cc b/src/arrow/cpp/src/plasma/hash_table_store.cc
new file mode 100644
index 000000000..b77d3693f
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/hash_table_store.cc
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <string>
+
+#include "arrow/util/logging.h"
+
+#include "plasma/hash_table_store.h"
+
+namespace plasma {
+
+Status HashTableStore::Connect(const std::string& endpoint) { return Status::OK(); }
+
+Status HashTableStore::Put(const std::vector<ObjectID>& ids,
+ const std::vector<std::shared_ptr<Buffer>>& data) {
+ for (size_t i = 0; i < ids.size(); ++i) {
+ table_[ids[i]] = data[i]->ToString();
+ }
+ return Status::OK();
+}
+
+Status HashTableStore::Get(const std::vector<ObjectID>& ids,
+ std::vector<std::shared_ptr<Buffer>> buffers) {
+ ARROW_CHECK(ids.size() == buffers.size());
+ for (size_t i = 0; i < ids.size(); ++i) {
+ bool valid;
+ HashTable::iterator result;
+ {
+ result = table_.find(ids[i]);
+ valid = result != table_.end();
+ }
+ if (valid) {
+ ARROW_CHECK(buffers[i]->size() == static_cast<int64_t>(result->second.size()));
+ std::memcpy(buffers[i]->mutable_data(), result->second.data(),
+ result->second.size());
+ }
+ }
+ return Status::OK();
+}
+
+REGISTER_EXTERNAL_STORE("hashtable", HashTableStore);
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/hash_table_store.h b/src/arrow/cpp/src/plasma/hash_table_store.h
new file mode 100644
index 000000000..7940ae2db
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/hash_table_store.h
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "plasma/external_store.h"
+
+namespace plasma {
+
+// This is a sample implementation for an external store, for illustration
+// purposes only.
+
+class HashTableStore : public ExternalStore {
+ public:
+ HashTableStore() = default;
+
+ Status Connect(const std::string& endpoint) override;
+
+ Status Get(const std::vector<ObjectID>& ids,
+ std::vector<std::shared_ptr<Buffer>> buffers) override;
+
+ Status Put(const std::vector<ObjectID>& ids,
+ const std::vector<std::shared_ptr<Buffer>>& data) override;
+
+ private:
+ typedef std::unordered_map<ObjectID, std::string> HashTable;
+
+ HashTable table_;
+};
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/io.cc b/src/arrow/cpp/src/plasma/io.cc
new file mode 100644
index 000000000..002f4e999
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/io.cc
@@ -0,0 +1,250 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "plasma/io.h"
+
+#include <cstdint>
+#include <memory>
+#include <sstream>
+
+#include "arrow/status.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/logging.h"
+
+#include "plasma/common.h"
+#include "plasma/plasma_generated.h"
+
+using arrow::Status;
+
+/// Number of times we try connecting to a socket.
+constexpr int64_t kNumConnectAttempts = 80;
+/// Time to wait between connection attempts to a socket.
+constexpr int64_t kConnectTimeoutMs = 100;
+
+namespace plasma {
+
+using flatbuf::MessageType;
+
+Status WriteBytes(int fd, uint8_t* cursor, size_t length) {
+ ssize_t nbytes = 0;
+ size_t bytesleft = length;
+ size_t offset = 0;
+ while (bytesleft > 0) {
+ // While we haven't written the whole message, write to the file descriptor,
+ // advance the cursor, and decrease the amount left to write.
+ nbytes = write(fd, cursor + offset, bytesleft);
+ if (nbytes < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
+ continue;
+ }
+ return Status::IOError(strerror(errno));
+ } else if (nbytes == 0) {
+ return Status::IOError("Encountered unexpected EOF");
+ }
+ ARROW_CHECK(nbytes > 0);
+ bytesleft -= nbytes;
+ offset += nbytes;
+ }
+
+ return Status::OK();
+}
+
+Status WriteMessage(int fd, MessageType type, int64_t length, uint8_t* bytes) {
+ int64_t version = arrow::BitUtil::ToLittleEndian(kPlasmaProtocolVersion);
+ assert(sizeof(MessageType) == sizeof(int64_t));
+ type = static_cast<MessageType>(
+ arrow::BitUtil::ToLittleEndian(static_cast<int64_t>(type)));
+ int64_t length_le = arrow::BitUtil::ToLittleEndian(length);
+ RETURN_NOT_OK(WriteBytes(fd, reinterpret_cast<uint8_t*>(&version), sizeof(version)));
+ RETURN_NOT_OK(WriteBytes(fd, reinterpret_cast<uint8_t*>(&type), sizeof(type)));
+ RETURN_NOT_OK(WriteBytes(fd, reinterpret_cast<uint8_t*>(&length_le), sizeof(length)));
+ return WriteBytes(fd, bytes, length * sizeof(char));
+}
+
+Status ReadBytes(int fd, uint8_t* cursor, size_t length) {
+ ssize_t nbytes = 0;
+ // Termination condition: EOF or read 'length' bytes total.
+ size_t bytesleft = length;
+ size_t offset = 0;
+ while (bytesleft > 0) {
+ nbytes = read(fd, cursor + offset, bytesleft);
+ if (nbytes < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
+ continue;
+ }
+ return Status::IOError(strerror(errno));
+ } else if (0 == nbytes) {
+ return Status::IOError("Encountered unexpected EOF");
+ }
+ ARROW_CHECK(nbytes > 0);
+ bytesleft -= nbytes;
+ offset += nbytes;
+ }
+
+ return Status::OK();
+}
+
+Status ReadMessage(int fd, MessageType* type, std::vector<uint8_t>* buffer) {
+ int64_t version;
+ RETURN_NOT_OK_ELSE(ReadBytes(fd, reinterpret_cast<uint8_t*>(&version), sizeof(version)),
+ *type = MessageType::PlasmaDisconnectClient);
+ version = arrow::BitUtil::FromLittleEndian(version);
+ ARROW_CHECK(version == kPlasmaProtocolVersion) << "version = " << version;
+ RETURN_NOT_OK_ELSE(ReadBytes(fd, reinterpret_cast<uint8_t*>(type), sizeof(*type)),
+ *type = MessageType::PlasmaDisconnectClient);
+ assert(sizeof(MessageType) == sizeof(int64_t));
+ *type = static_cast<MessageType>(
+ arrow::BitUtil::FromLittleEndian(static_cast<int64_t>(*type)));
+ int64_t length_temp;
+ RETURN_NOT_OK_ELSE(
+ ReadBytes(fd, reinterpret_cast<uint8_t*>(&length_temp), sizeof(length_temp)),
+ *type = MessageType::PlasmaDisconnectClient);
+ // The length must be read as an int64_t, but it should be used as a size_t.
+ size_t length = static_cast<size_t>(arrow::BitUtil::FromLittleEndian(length_temp));
+ if (length > buffer->size()) {
+ buffer->resize(length);
+ }
+ RETURN_NOT_OK_ELSE(ReadBytes(fd, buffer->data(), length),
+ *type = MessageType::PlasmaDisconnectClient);
+ return Status::OK();
+}
+
+int BindIpcSock(const std::string& pathname, bool shall_listen) {
+ struct sockaddr_un socket_address;
+ int socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
+ if (socket_fd < 0) {
+ ARROW_LOG(ERROR) << "socket() failed for pathname " << pathname;
+ return -1;
+ }
+ // Tell the system to allow the port to be reused.
+ int on = 1;
+ if (setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&on),
+ sizeof(on)) < 0) {
+ ARROW_LOG(ERROR) << "setsockopt failed for pathname " << pathname;
+ close(socket_fd);
+ return -1;
+ }
+
+ unlink(pathname.c_str());
+ memset(&socket_address, 0, sizeof(socket_address));
+ socket_address.sun_family = AF_UNIX;
+ if (pathname.size() + 1 > sizeof(socket_address.sun_path)) {
+ ARROW_LOG(ERROR) << "Socket pathname is too long.";
+ close(socket_fd);
+ return -1;
+ }
+ strncpy(socket_address.sun_path, pathname.c_str(), pathname.size() + 1);
+
+ if (bind(socket_fd, reinterpret_cast<struct sockaddr*>(&socket_address),
+ sizeof(socket_address)) != 0) {
+ ARROW_LOG(ERROR) << "Bind failed for pathname " << pathname;
+ close(socket_fd);
+ return -1;
+ }
+ if (shall_listen && listen(socket_fd, 128) == -1) {
+ ARROW_LOG(ERROR) << "Could not listen to socket " << pathname;
+ close(socket_fd);
+ return -1;
+ }
+ return socket_fd;
+}
+
+Status ConnectIpcSocketRetry(const std::string& pathname, int num_retries,
+ int64_t timeout, int* fd) {
+ // Pick the default values if the user did not specify.
+ if (num_retries < 0) {
+ num_retries = kNumConnectAttempts;
+ }
+ if (timeout < 0) {
+ timeout = kConnectTimeoutMs;
+ }
+ *fd = ConnectIpcSock(pathname);
+ while (*fd < 0 && num_retries > 0) {
+ ARROW_LOG(ERROR) << "Connection to IPC socket failed for pathname " << pathname
+ << ", retrying " << num_retries << " more times";
+ // Sleep for timeout milliseconds.
+ usleep(static_cast<int>(timeout * 1000));
+ *fd = ConnectIpcSock(pathname);
+ --num_retries;
+ }
+
+ // If we could not connect to the socket, exit.
+ if (*fd == -1) {
+ return Status::IOError("Could not connect to socket ", pathname);
+ }
+
+ return Status::OK();
+}
+
+int ConnectIpcSock(const std::string& pathname) {
+ struct sockaddr_un socket_address;
+ int socket_fd;
+
+ socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
+ if (socket_fd < 0) {
+ ARROW_LOG(ERROR) << "socket() failed for pathname " << pathname;
+ return -1;
+ }
+
+ memset(&socket_address, 0, sizeof(socket_address));
+ socket_address.sun_family = AF_UNIX;
+ if (pathname.size() + 1 > sizeof(socket_address.sun_path)) {
+ ARROW_LOG(ERROR) << "Socket pathname is too long.";
+ close(socket_fd);
+ return -1;
+ }
+ strncpy(socket_address.sun_path, pathname.c_str(), pathname.size() + 1);
+
+ if (connect(socket_fd, reinterpret_cast<struct sockaddr*>(&socket_address),
+ sizeof(socket_address)) != 0) {
+ close(socket_fd);
+ return -1;
+ }
+
+ return socket_fd;
+}
+
+int AcceptClient(int socket_fd) {
+ int client_fd = accept(socket_fd, NULL, NULL);
+ if (client_fd < 0) {
+ ARROW_LOG(ERROR) << "Error reading from socket.";
+ return -1;
+ }
+ return client_fd;
+}
+
+std::unique_ptr<uint8_t[]> ReadMessageAsync(int sock) {
+ int64_t size;
+ Status s = ReadBytes(sock, reinterpret_cast<uint8_t*>(&size), sizeof(int64_t));
+ if (!s.ok()) {
+ // The other side has closed the socket.
+ ARROW_LOG(DEBUG) << "Socket has been closed, or some other error has occurred.";
+ close(sock);
+ return NULL;
+ }
+ auto message = std::unique_ptr<uint8_t[]>(new uint8_t[size]);
+ s = ReadBytes(sock, message.get(), size);
+ if (!s.ok()) {
+ // The other side has closed the socket.
+ ARROW_LOG(DEBUG) << "Socket has been closed, or some other error has occurred.";
+ close(sock);
+ return NULL;
+ }
+ return message;
+}
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/io.h b/src/arrow/cpp/src/plasma/io.h
new file mode 100644
index 000000000..c9f17169f
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/io.h
@@ -0,0 +1,67 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <inttypes.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/status.h"
+#include "plasma/compat.h"
+
+namespace plasma {
+
+namespace flatbuf {
+
+// Forward declaration outside the namespace, which is defined in plasma_generated.h.
+enum class MessageType : int64_t;
+
+} // namespace flatbuf
+
+// TODO(pcm): Replace our own custom message header (message type,
+// message length, plasma protocol version) with one that is serialized
+// using flatbuffers.
+constexpr int64_t kPlasmaProtocolVersion = 0x0000000000000000;
+
+using arrow::Status;
+
+Status WriteBytes(int fd, uint8_t* cursor, size_t length);
+
+Status WriteMessage(int fd, flatbuf::MessageType type, int64_t length, uint8_t* bytes);
+
+Status ReadBytes(int fd, uint8_t* cursor, size_t length);
+
+Status ReadMessage(int fd, flatbuf::MessageType* type, std::vector<uint8_t>* buffer);
+
+int BindIpcSock(const std::string& pathname, bool shall_listen);
+
+int ConnectIpcSock(const std::string& pathname);
+
+Status ConnectIpcSocketRetry(const std::string& pathname, int num_retries,
+ int64_t timeout, int* fd);
+
+int AcceptClient(int socket_fd);
+
+std::unique_ptr<uint8_t[]> ReadMessageAsync(int sock);
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.cc b/src/arrow/cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.cc
new file mode 100644
index 000000000..10e0fcb37
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.cc
@@ -0,0 +1,263 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.h"
+
+#include <pthread.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <cstring>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/util/logging.h"
+
+#include "plasma/client.h"
+
+constexpr jsize OBJECT_ID_SIZE = sizeof(plasma::ObjectID) / sizeof(jbyte);
+
+inline void jbyteArray_to_object_id(JNIEnv* env, jbyteArray a, plasma::ObjectID* oid) {
+ env->GetByteArrayRegion(a, 0, OBJECT_ID_SIZE, reinterpret_cast<jbyte*>(oid));
+}
+
+inline void object_id_to_jbyteArray(JNIEnv* env, jbyteArray a, plasma::ObjectID* oid) {
+ env->SetByteArrayRegion(a, 0, OBJECT_ID_SIZE, reinterpret_cast<jbyte*>(oid));
+}
+
+inline void throw_exception_if_not_OK(JNIEnv* env, const arrow::Status& status) {
+ if (!status.ok()) {
+ jclass Exception =
+ env->FindClass("org/apache/arrow/plasma/exceptions/PlasmaClientException");
+ env->ThrowNew(Exception, status.message().c_str());
+ }
+}
+
+class JByteArrayGetter {
+ private:
+ JNIEnv* _env;
+ jbyteArray _a;
+ jbyte* bp;
+
+ public:
+ JByteArrayGetter(JNIEnv* env, jbyteArray a, jbyte** out) {
+ _env = env;
+ _a = a;
+
+ bp = _env->GetByteArrayElements(_a, nullptr);
+ *out = bp;
+ }
+
+ ~JByteArrayGetter() { _env->ReleaseByteArrayElements(_a, bp, 0); }
+};
+
+JNIEXPORT jlong JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_connect(
+ JNIEnv* env, jclass cls, jstring store_socket_name, jstring manager_socket_name,
+ jint release_delay) {
+ const char* s_name = env->GetStringUTFChars(store_socket_name, nullptr);
+ const char* m_name = env->GetStringUTFChars(manager_socket_name, nullptr);
+
+ plasma::PlasmaClient* client = new plasma::PlasmaClient();
+ throw_exception_if_not_OK(env, client->Connect(s_name, m_name, release_delay));
+
+ env->ReleaseStringUTFChars(store_socket_name, s_name);
+ env->ReleaseStringUTFChars(manager_socket_name, m_name);
+ return reinterpret_cast<int64_t>(client);
+}
+
+JNIEXPORT void JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_disconnect(
+ JNIEnv* env, jclass cls, jlong conn) {
+ plasma::PlasmaClient* client = reinterpret_cast<plasma::PlasmaClient*>(conn);
+
+ throw_exception_if_not_OK(env, client->Disconnect());
+ delete client;
+ return;
+}
+
+JNIEXPORT jobject JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_create(
+ JNIEnv* env, jclass cls, jlong conn, jbyteArray object_id, jint size,
+ jbyteArray metadata) {
+ plasma::PlasmaClient* client = reinterpret_cast<plasma::PlasmaClient*>(conn);
+ plasma::ObjectID oid;
+ jbyteArray_to_object_id(env, object_id, &oid);
+
+ // prepare metadata buffer
+ uint8_t* md = nullptr;
+ jsize md_size = 0;
+ std::unique_ptr<JByteArrayGetter> md_getter;
+ if (metadata != nullptr) {
+ md_size = env->GetArrayLength(metadata);
+ }
+ if (md_size > 0) {
+ md_getter.reset(new JByteArrayGetter(env, metadata, reinterpret_cast<jbyte**>(&md)));
+ }
+
+ std::shared_ptr<Buffer> data;
+ Status s = client->Create(oid, size, md, md_size, &data);
+ if (plasma::IsPlasmaObjectExists(s)) {
+ jclass exceptionClass =
+ env->FindClass("org/apache/arrow/plasma/exceptions/DuplicateObjectException");
+ env->ThrowNew(exceptionClass, oid.hex().c_str());
+ return nullptr;
+ }
+ if (plasma::IsPlasmaStoreFull(s)) {
+ jclass exceptionClass =
+ env->FindClass("org/apache/arrow/plasma/exceptions/PlasmaOutOfMemoryException");
+ env->ThrowNew(exceptionClass, "");
+ return nullptr;
+ }
+ throw_exception_if_not_OK(env, s);
+
+ return env->NewDirectByteBuffer(data->mutable_data(), size);
+}
+
+JNIEXPORT jbyteArray JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_hash(
+ JNIEnv* env, jclass cls, jlong conn, jbyteArray object_id) {
+ plasma::PlasmaClient* client = reinterpret_cast<plasma::PlasmaClient*>(conn);
+ plasma::ObjectID oid;
+ jbyteArray_to_object_id(env, object_id, &oid);
+
+ unsigned char digest[plasma::kDigestSize];
+ bool success = client->Hash(oid, digest).ok();
+
+ if (success) {
+ jbyteArray ret = env->NewByteArray(plasma::kDigestSize);
+ env->SetByteArrayRegion(ret, 0, plasma::kDigestSize,
+ reinterpret_cast<jbyte*>(digest));
+ return ret;
+ } else {
+ return nullptr;
+ }
+}
+
+JNIEXPORT void JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_seal(
+ JNIEnv* env, jclass cls, jlong conn, jbyteArray object_id) {
+ plasma::PlasmaClient* client = reinterpret_cast<plasma::PlasmaClient*>(conn);
+ plasma::ObjectID oid;
+ jbyteArray_to_object_id(env, object_id, &oid);
+
+ throw_exception_if_not_OK(env, client->Seal(oid));
+}
+
+JNIEXPORT void JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_release(
+ JNIEnv* env, jclass cls, jlong conn, jbyteArray object_id) {
+ plasma::PlasmaClient* client = reinterpret_cast<plasma::PlasmaClient*>(conn);
+ plasma::ObjectID oid;
+ jbyteArray_to_object_id(env, object_id, &oid);
+
+ throw_exception_if_not_OK(env, client->Release(oid));
+}
+
+JNIEXPORT void JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_delete(
+ JNIEnv* env, jclass cls, jlong conn, jbyteArray object_id) {
+ plasma::PlasmaClient* client = reinterpret_cast<plasma::PlasmaClient*>(conn);
+ plasma::ObjectID oid;
+ jbyteArray_to_object_id(env, object_id, &oid);
+
+ throw_exception_if_not_OK(env, client->Delete(oid));
+}
+
+JNIEXPORT jobjectArray JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_get(
+ JNIEnv* env, jclass cls, jlong conn, jobjectArray object_ids, jint timeout_ms) {
+ plasma::PlasmaClient* client = reinterpret_cast<plasma::PlasmaClient*>(conn);
+
+ jsize num_oids = env->GetArrayLength(object_ids);
+ std::vector<plasma::ObjectID> oids(num_oids);
+ std::vector<plasma::ObjectBuffer> obufs(num_oids);
+ for (int i = 0; i < num_oids; ++i) {
+ jbyteArray_to_object_id(
+ env, reinterpret_cast<jbyteArray>(env->GetObjectArrayElement(object_ids, i)),
+ &oids[i]);
+ }
+ // TODO: may be blocked. consider to add the thread support
+ throw_exception_if_not_OK(env,
+ client->Get(oids.data(), num_oids, timeout_ms, obufs.data()));
+
+ jclass clsByteBuffer = env->FindClass("java/nio/ByteBuffer");
+ jclass clsByteBufferArray = env->FindClass("[Ljava/nio/ByteBuffer;");
+
+ jobjectArray ret = env->NewObjectArray(num_oids, clsByteBufferArray, nullptr);
+ jobjectArray o = nullptr;
+ jobject dataBuf, metadataBuf;
+ for (int i = 0; i < num_oids; ++i) {
+ o = env->NewObjectArray(2, clsByteBuffer, nullptr);
+ if (obufs[i].data && obufs[i].data->size() != -1) {
+ dataBuf = env->NewDirectByteBuffer(const_cast<uint8_t*>(obufs[i].data->data()),
+ obufs[i].data->size());
+ if (obufs[i].metadata && obufs[i].metadata->size() > 0) {
+ metadataBuf = env->NewDirectByteBuffer(
+ const_cast<uint8_t*>(obufs[i].metadata->data()), obufs[i].metadata->size());
+ } else {
+ metadataBuf = nullptr;
+ }
+ } else {
+ dataBuf = nullptr;
+ metadataBuf = nullptr;
+ }
+
+ env->SetObjectArrayElement(o, 0, dataBuf);
+ env->SetObjectArrayElement(o, 1, metadataBuf);
+ env->SetObjectArrayElement(ret, i, o);
+ }
+ return ret;
+}
+
+JNIEXPORT jboolean JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_contains(
+ JNIEnv* env, jclass cls, jlong conn, jbyteArray object_id) {
+ plasma::PlasmaClient* client = reinterpret_cast<plasma::PlasmaClient*>(conn);
+ plasma::ObjectID oid;
+ jbyteArray_to_object_id(env, object_id, &oid);
+
+ bool has_object;
+ throw_exception_if_not_OK(env, client->Contains(oid, &has_object));
+
+ return has_object;
+}
+
+JNIEXPORT jlong JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_evict(
+ JNIEnv* env, jclass cls, jlong conn, jlong num_bytes) {
+ plasma::PlasmaClient* client = reinterpret_cast<plasma::PlasmaClient*>(conn);
+
+ int64_t evicted_bytes;
+ throw_exception_if_not_OK(
+ env, client->Evict(static_cast<int64_t>(num_bytes), evicted_bytes));
+
+ return static_cast<jlong>(evicted_bytes);
+}
+
+JNIEXPORT jobjectArray JNICALL
+Java_org_apache_arrow_plasma_PlasmaClientJNI_list(JNIEnv* env, jclass cls, jlong conn) {
+ plasma::PlasmaClient* client = reinterpret_cast<plasma::PlasmaClient*>(conn);
+ plasma::ObjectTable objectTable;
+ throw_exception_if_not_OK(env, client->List(&objectTable));
+ jobjectArray ret =
+ env->NewObjectArray(objectTable.size(), env->FindClass("[B"), env->NewByteArray(1));
+ int i = 0;
+ for (const auto& id_entry_pair : objectTable) {
+ const plasma::ObjectID& id = id_entry_pair.first;
+ jbyteArray idByteArray = env->NewByteArray(OBJECT_ID_SIZE);
+ env->SetByteArrayRegion(idByteArray, 0, OBJECT_ID_SIZE,
+ reinterpret_cast<jbyte*>(const_cast<uint8_t*>(id.data())));
+ env->SetObjectArrayElement(ret, i, idByteArray);
+ i++;
+ }
+
+ return ret;
+}
diff --git a/src/arrow/cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.h b/src/arrow/cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.h
new file mode 100644
index 000000000..8a18be91d
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.h
@@ -0,0 +1,141 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/* DO NOT EDIT THIS FILE - it is machine generated */
+#include <jni.h>
+/* Header for class org_apache_arrow_plasma_PlasmaClientJNI */
+
+#ifndef _Included_org_apache_arrow_plasma_PlasmaClientJNI
+#define _Included_org_apache_arrow_plasma_PlasmaClientJNI
+#ifdef __cplusplus
+extern "C" {
+#endif
+/*
+ * Class: org_apache_arrow_plasma_PlasmaClientJNI
+ * Method: connect
+ * Signature: (Ljava/lang/String;Ljava/lang/String;I)J
+ */
+JNIEXPORT jlong JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_connect(
+ JNIEnv*, jclass, jstring, jstring, jint);
+
+/*
+ * Class: org_apache_arrow_plasma_PlasmaClientJNI
+ * Method: disconnect
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_disconnect(JNIEnv*,
+ jclass,
+ jlong);
+
+/*
+ * Class: org_apache_arrow_plasma_PlasmaClientJNI
+ * Method: create
+ * Signature: (J[BI[B)Ljava/nio/ByteBuffer;
+ */
+JNIEXPORT jobject JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_create(
+ JNIEnv*, jclass, jlong, jbyteArray, jint, jbyteArray);
+
+/*
+ * Class: org_apache_arrow_plasma_PlasmaClientJNI
+ * Method: hash
+ * Signature: (J[B)[B
+ */
+JNIEXPORT jbyteArray JNICALL
+Java_org_apache_arrow_plasma_PlasmaClientJNI_hash(JNIEnv*, jclass, jlong, jbyteArray);
+
+/*
+ * Class: org_apache_arrow_plasma_PlasmaClientJNI
+ * Method: seal
+ * Signature: (J[B)V
+ */
+JNIEXPORT void JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_seal(JNIEnv*, jclass,
+ jlong,
+ jbyteArray);
+
+/*
+ * Class: org_apache_arrow_plasma_PlasmaClientJNI
+ * Method: release
+ * Signature: (J[B)V
+ */
+JNIEXPORT void JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_release(JNIEnv*,
+ jclass, jlong,
+ jbyteArray);
+
+/*
+ * Class: org_apache_arrow_plasma_PlasmaClientJNI
+ * Method: delete
+ * Signature: (J[B)V
+ */
+JNIEXPORT void JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_delete(JNIEnv*,
+ jclass, jlong,
+ jbyteArray);
+
+/*
+ * Class: org_apache_arrow_plasma_PlasmaClientJNI
+ * Method: get
+ * Signature: (J[[BI)[[Ljava/nio/ByteBuffer;
+ */
+JNIEXPORT jobjectArray JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_get(
+ JNIEnv*, jclass, jlong, jobjectArray, jint);
+
+/*
+ * Class: org_apache_arrow_plasma_PlasmaClientJNI
+ * Method: contains
+ * Signature: (J[B)Z
+ */
+JNIEXPORT jboolean JNICALL
+Java_org_apache_arrow_plasma_PlasmaClientJNI_contains(JNIEnv*, jclass, jlong, jbyteArray);
+
+/*
+ * Class: org_apache_arrow_plasma_PlasmaClientJNI
+ * Method: fetch
+ * Signature: (J[[B)V
+ */
+JNIEXPORT void JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_fetch(JNIEnv*, jclass,
+ jlong,
+ jobjectArray);
+
+/*
+ * Class: org_apache_arrow_plasma_PlasmaClientJNI
+ * Method: wait
+ * Signature: (J[[BII)[[B
+ */
+JNIEXPORT jobjectArray JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_wait(
+ JNIEnv*, jclass, jlong, jobjectArray, jint, jint);
+
+/*
+ * Class: org_apache_arrow_plasma_PlasmaClientJNI
+ * Method: evict
+ * Signature: (JJ)J
+ */
+JNIEXPORT jlong JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_evict(JNIEnv*,
+ jclass, jlong,
+ jlong);
+
+/*
+ * Class: org_apache_arrow_plasma_PlasmaClientJNI
+ * Method: list
+ * Signature: (J)[[B
+ */
+JNIEXPORT jobjectArray JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_list(JNIEnv*,
+ jclass,
+ jlong);
+
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/src/arrow/cpp/src/plasma/malloc.cc b/src/arrow/cpp/src/plasma/malloc.cc
new file mode 100644
index 000000000..bb027a6cb
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/malloc.cc
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "plasma/malloc.h"
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <unistd.h>
+
+#include <cerrno>
+#include <string>
+#include <vector>
+
+#include "plasma/common.h"
+#include "plasma/plasma.h"
+
+namespace plasma {
+
+std::unordered_map<void*, MmapRecord> mmap_records;
+
+static void* pointer_advance(void* p, ptrdiff_t n) { return (unsigned char*)p + n; }
+
+static ptrdiff_t pointer_distance(void const* pfrom, void const* pto) {
+ return (unsigned char const*)pto - (unsigned char const*)pfrom;
+}
+
+void GetMallocMapinfo(void* addr, int* fd, int64_t* map_size, ptrdiff_t* offset) {
+ // TODO(rshin): Implement a more efficient search through mmap_records.
+ for (const auto& entry : mmap_records) {
+ if (addr >= entry.first && addr < pointer_advance(entry.first, entry.second.size)) {
+ *fd = entry.second.fd;
+ *map_size = entry.second.size;
+ *offset = pointer_distance(entry.first, addr);
+ return;
+ }
+ }
+ *fd = -1;
+ *map_size = 0;
+ *offset = 0;
+}
+
+int64_t GetMmapSize(int fd) {
+ for (const auto& entry : mmap_records) {
+ if (entry.second.fd == fd) {
+ return entry.second.size;
+ }
+ }
+ ARROW_LOG(FATAL) << "failed to find entry in mmap_records for fd " << fd;
+ return -1; // This code is never reached.
+}
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/malloc.h b/src/arrow/cpp/src/plasma/malloc.h
new file mode 100644
index 000000000..edc0763a5
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/malloc.h
@@ -0,0 +1,51 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <inttypes.h>
+#include <stddef.h>
+
+#include <unordered_map>
+
+namespace plasma {
+
+/// Gap between two consecutive mmap regions allocated by fake_mmap.
+/// This ensures that the segments of memory returned by
+/// fake_mmap are never contiguous and dlmalloc does not coalesce it
+/// (in the client we cannot guarantee that these mmaps are contiguous).
+constexpr int64_t kMmapRegionsGap = sizeof(size_t);
+
+void GetMallocMapinfo(void* addr, int* fd, int64_t* map_length, ptrdiff_t* offset);
+
+/// Get the mmap size corresponding to a specific file descriptor.
+///
+/// \param fd The file descriptor to look up.
+/// \return The size of the corresponding memory-mapped file.
+int64_t GetMmapSize(int fd);
+
+struct MmapRecord {
+ int fd;
+ int64_t size;
+};
+
+/// Hashtable that contains one entry per segment that we got from the OS
+/// via mmap. Associates the address of that segment with its file descriptor
+/// and size.
+extern std::unordered_map<void*, MmapRecord> mmap_records;
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/plasma.cc b/src/arrow/cpp/src/plasma/plasma.cc
new file mode 100644
index 000000000..6f38951fb
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/plasma.cc
@@ -0,0 +1,99 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "plasma/plasma.h"
+
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "plasma/common.h"
+#include "plasma/common_generated.h"
+#include "plasma/protocol.h"
+
+namespace fb = plasma::flatbuf;
+
+namespace plasma {
+
+ObjectTableEntry::ObjectTableEntry() : pointer(nullptr), ref_count(0) {}
+
+ObjectTableEntry::~ObjectTableEntry() { pointer = nullptr; }
+
+int WarnIfSigpipe(int status, int client_sock) {
+ if (status >= 0) {
+ return 0;
+ }
+ if (errno == EPIPE || errno == EBADF || errno == ECONNRESET) {
+ ARROW_LOG(WARNING) << "Received SIGPIPE, BAD FILE DESCRIPTOR, or ECONNRESET when "
+ "sending a message to client on fd "
+ << client_sock
+ << ". The client on the other end may "
+ "have hung up.";
+ return errno;
+ }
+ ARROW_LOG(FATAL) << "Failed to write message to client on fd " << client_sock << ".";
+ return -1; // This is never reached.
+}
+
+/**
+ * This will create a new ObjectInfo buffer. The first sizeof(int64_t) bytes
+ * of this buffer are the length of the remaining message and the
+ * remaining message is a serialized version of the object info.
+ *
+ * \param object_info The object info to be serialized
+ * \return The object info buffer. It is the caller's responsibility to free
+ * this buffer with "delete" after it has been used.
+ */
+std::unique_ptr<uint8_t[]> CreateObjectInfoBuffer(fb::ObjectInfoT* object_info) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreateObjectInfo(fbb, object_info);
+ fbb.Finish(message);
+ auto notification =
+ std::unique_ptr<uint8_t[]>(new uint8_t[sizeof(int64_t) + fbb.GetSize()]);
+ *(reinterpret_cast<int64_t*>(notification.get())) = fbb.GetSize();
+ memcpy(notification.get() + sizeof(int64_t), fbb.GetBufferPointer(), fbb.GetSize());
+ return notification;
+}
+
+std::unique_ptr<uint8_t[]> CreatePlasmaNotificationBuffer(
+ std::vector<fb::ObjectInfoT>& object_info) {
+ flatbuffers::FlatBufferBuilder fbb;
+ std::vector<flatbuffers::Offset<plasma::flatbuf::ObjectInfo>> info;
+ for (size_t i = 0; i < object_info.size(); ++i) {
+ info.push_back(fb::CreateObjectInfo(fbb, &object_info[i]));
+ }
+
+ auto info_array = fbb.CreateVector(info);
+ auto message = fb::CreatePlasmaNotification(fbb, info_array);
+ fbb.Finish(message);
+ auto notification =
+ std::unique_ptr<uint8_t[]>(new uint8_t[sizeof(int64_t) + fbb.GetSize()]);
+ *(reinterpret_cast<int64_t*>(notification.get())) = fbb.GetSize();
+ memcpy(notification.get() + sizeof(int64_t), fbb.GetBufferPointer(), fbb.GetSize());
+ return notification;
+}
+
+ObjectTableEntry* GetObjectTableEntry(PlasmaStoreInfo* store_info,
+ const ObjectID& object_id) {
+ auto it = store_info->objects.find(object_id);
+ if (it == store_info->objects.end()) {
+ return NULL;
+ }
+ return it->second.get();
+}
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/plasma.fbs b/src/arrow/cpp/src/plasma/plasma.fbs
new file mode 100644
index 000000000..62c02b96a
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/plasma.fbs
@@ -0,0 +1,357 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+include "common.fbs";
+
+// Plasma protocol specification
+namespace plasma.flatbuf;
+
+enum MessageType:long {
+ // Message that gets send when a client hangs up.
+ PlasmaDisconnectClient = 0,
+ // Create a new object.
+ PlasmaCreateRequest,
+ PlasmaCreateReply,
+ PlasmaCreateAndSealRequest,
+ PlasmaCreateAndSealReply,
+ PlasmaAbortRequest,
+ PlasmaAbortReply,
+ // Seal an object.
+ PlasmaSealRequest,
+ PlasmaSealReply,
+ // Get an object that is stored on the local Plasma store.
+ PlasmaGetRequest,
+ PlasmaGetReply,
+ // Release an object.
+ PlasmaReleaseRequest,
+ PlasmaReleaseReply,
+ // Delete an object.
+ PlasmaDeleteRequest,
+ PlasmaDeleteReply,
+ // See if the store contains an object (will be deprecated).
+ PlasmaContainsRequest,
+ PlasmaContainsReply,
+ // List all objects in the store.
+ PlasmaListRequest,
+ PlasmaListReply,
+ // Get information for a newly connecting client.
+ PlasmaConnectRequest,
+ PlasmaConnectReply,
+ // Make room for new objects in the plasma store.
+ PlasmaEvictRequest,
+ PlasmaEvictReply,
+ // Subscribe to a list of objects or to all objects.
+ PlasmaSubscribeRequest,
+ // Unsubscribe.
+ PlasmaUnsubscribeRequest,
+ // Sending and receiving data.
+ // PlasmaDataRequest initiates sending the data, there will be one
+ // such message per data transfer.
+ PlasmaDataRequest,
+ // PlasmaDataReply contains the actual data and is sent back to the
+ // object store that requested the data. For each transfer, multiple
+ // reply messages get sent. Each one contains a fixed number of bytes.
+ PlasmaDataReply,
+ // Object notifications.
+ PlasmaNotification,
+ // Set memory quota for a client.
+ PlasmaSetOptionsRequest,
+ PlasmaSetOptionsReply,
+ // Get debugging information from the store.
+ PlasmaGetDebugStringRequest,
+ PlasmaGetDebugStringReply,
+ // Create and seal a batch of objects. This should be used to save
+ // IPC for creating many small objects.
+ PlasmaCreateAndSealBatchRequest,
+ PlasmaCreateAndSealBatchReply,
+ // Touch a number of objects to bump their position in the LRU cache.
+ PlasmaRefreshLRURequest,
+ PlasmaRefreshLRUReply,
+}
+
+enum PlasmaError:int {
+ // Operation was successful.
+ OK,
+ // Trying to create an object that already exists.
+ ObjectExists,
+ // Trying to access an object that doesn't exist.
+ ObjectNotFound,
+ // Trying to create an object but there isn't enough space in the store.
+ OutOfMemory,
+ // Trying to delete an object but it's not sealed.
+ ObjectNotSealed,
+ // Trying to delete an object but it's in use.
+ ObjectInUse,
+}
+
+// Plasma store messages
+
+struct PlasmaObjectSpec {
+ // Index of the memory segment (= memory mapped file) that
+ // this object is allocated in.
+ segment_index: int;
+ // The offset in bytes in the memory mapped file of the data.
+ data_offset: ulong;
+ // The size in bytes of the data.
+ data_size: ulong;
+ // The offset in bytes in the memory mapped file of the metadata.
+ metadata_offset: ulong;
+ // The size in bytes of the metadata.
+ metadata_size: ulong;
+ // Device to create buffer on.
+ device_num: int;
+}
+
+table PlasmaSetOptionsRequest {
+ // The name of the client.
+ client_name: string;
+ // The size of the output memory limit in bytes.
+ output_memory_quota: long;
+}
+
+table PlasmaSetOptionsReply {
+ // Whether setting options succeeded.
+ error: PlasmaError;
+}
+
+table PlasmaGetDebugStringRequest {
+}
+
+table PlasmaGetDebugStringReply {
+ // The debug string from the server.
+ debug_string: string;
+}
+
+table PlasmaCreateRequest {
+ // ID of the object to be created.
+ object_id: string;
+ // Whether to evict other objects to make room for this one.
+ evict_if_full: bool;
+ // The size of the object's data in bytes.
+ data_size: ulong;
+ // The size of the object's metadata in bytes.
+ metadata_size: ulong;
+ // Device to create buffer on.
+ device_num: int;
+}
+
+table CudaHandle {
+ handle: [ubyte];
+}
+
+table PlasmaCreateReply {
+ // ID of the object that was created.
+ object_id: string;
+ // The object that is returned with this reply.
+ plasma_object: PlasmaObjectSpec;
+ // Error that occurred for this call.
+ error: PlasmaError;
+ // The file descriptor in the store that corresponds to the file descriptor
+ // being sent to the client right after this message.
+ store_fd: int;
+ // The size in bytes of the segment for the store file descriptor (needed to
+ // call mmap).
+ mmap_size: long;
+ // CUDA IPC Handle for objects on GPU.
+ ipc_handle: CudaHandle;
+}
+
+table PlasmaCreateAndSealRequest {
+ // ID of the object to be created.
+ object_id: string;
+ // Whether to evict other objects to make room for this one.
+ evict_if_full: bool;
+ // The object's data.
+ data: string;
+ // The object's metadata.
+ metadata: string;
+ // Hash of the object data.
+ digest: string;
+}
+
+table PlasmaCreateAndSealReply {
+ // Error that occurred for this call.
+ error: PlasmaError;
+}
+
+table PlasmaCreateAndSealBatchRequest {
+ object_ids: [string];
+ // Whether to evict other objects to make room for these objects.
+ evict_if_full: bool;
+ data: [string];
+ metadata: [string];
+ digest: [string];
+}
+
+table PlasmaCreateAndSealBatchReply {
+ // Error that occurred for this call.
+ error: PlasmaError;
+}
+
+table PlasmaAbortRequest {
+ // ID of the object to be aborted.
+ object_id: string;
+}
+
+table PlasmaAbortReply {
+ // ID of the object that was aborted.
+ object_id: string;
+}
+
+table PlasmaSealRequest {
+ // ID of the object to be sealed.
+ object_id: string;
+ // Hash of the object data.
+ digest: string;
+}
+
+table PlasmaSealReply {
+ // ID of the object that was sealed.
+ object_id: string;
+ // Error code.
+ error: PlasmaError;
+}
+
+table PlasmaGetRequest {
+ // IDs of the objects stored at local Plasma store we are getting.
+ object_ids: [string];
+ // The number of milliseconds before the request should timeout.
+ timeout_ms: long;
+}
+
+table PlasmaGetReply {
+ // IDs of the objects being returned.
+ // This number can be smaller than the number of requested
+ // objects if not all requested objects are stored and sealed
+ // in the local Plasma store.
+ object_ids: [string];
+ // Plasma object information, in the same order as their IDs. The number of
+ // elements in both object_ids and plasma_objects arrays must agree.
+ plasma_objects: [PlasmaObjectSpec];
+ // A list of the file descriptors in the store that correspond to the file
+ // descriptors being sent to the client. The length of this list is the number
+ // of file descriptors that the store will send to the client after this
+ // message.
+ store_fds: [int];
+ // Size in bytes of the segment for each store file descriptor (needed to call
+ // mmap). This list must have the same length as store_fds.
+ mmap_sizes: [long];
+ // The number of elements in both object_ids and plasma_objects arrays must agree.
+ handles: [CudaHandle];
+}
+
+table PlasmaReleaseRequest {
+ // ID of the object to be released.
+ object_id: string;
+}
+
+table PlasmaReleaseReply {
+ // ID of the object that was released.
+ object_id: string;
+ // Error code.
+ error: PlasmaError;
+}
+
+table PlasmaDeleteRequest {
+ // The number of objects to delete.
+ count: int;
+ // ID of the object to be deleted.
+ object_ids: [string];
+}
+
+table PlasmaDeleteReply {
+ // The number of objects to delete.
+ count: int;
+ // ID of the object that was deleted.
+ object_ids: [string];
+ // Error code.
+ errors: [PlasmaError];
+}
+
+table PlasmaContainsRequest {
+ // ID of the object we are querying.
+ object_id: string;
+}
+
+table PlasmaContainsReply {
+ // ID of the object we are querying.
+ object_id: string;
+ // 1 if the object is in the store and 0 otherwise.
+ has_object: int;
+}
+
+table PlasmaListRequest {
+}
+
+table PlasmaListReply {
+ objects: [ObjectInfo];
+}
+
+// PlasmaConnect is used by a plasma client the first time it connects with the
+// store. This is not really necessary, but is used to get some information
+// about the store such as its memory capacity.
+
+table PlasmaConnectRequest {
+}
+
+table PlasmaConnectReply {
+ // The memory capacity of the store.
+ memory_capacity: long;
+}
+
+table PlasmaEvictRequest {
+ // Number of bytes that shall be freed.
+ num_bytes: ulong;
+}
+
+table PlasmaEvictReply {
+ // Number of bytes that have been freed.
+ num_bytes: ulong;
+}
+
+table PlasmaSubscribeRequest {
+}
+
+table PlasmaNotification {
+ object_info: [ObjectInfo];
+}
+
+table PlasmaDataRequest {
+ // ID of the object that is requested.
+ object_id: string;
+ // The host address where the data shall be sent to.
+ address: string;
+ // The port of the manager the data shall be sent to.
+ port: int;
+}
+
+table PlasmaDataReply {
+ // ID of the object that will be sent.
+ object_id: string;
+ // Size of the object data in bytes.
+ object_size: ulong;
+ // Size of the metadata in bytes.
+ metadata_size: ulong;
+}
+
+table PlasmaRefreshLRURequest {
+ // ID of the objects to be bumped in the LRU cache.
+ object_ids: [string];
+}
+
+table PlasmaRefreshLRUReply {
+}
diff --git a/src/arrow/cpp/src/plasma/plasma.h b/src/arrow/cpp/src/plasma/plasma.h
new file mode 100644
index 000000000..236f5c948
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/plasma.h
@@ -0,0 +1,175 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <errno.h>
+#include <inttypes.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h> // pid_t
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "plasma/compat.h"
+
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "plasma/common.h"
+
+#ifdef PLASMA_CUDA
+using arrow::cuda::CudaIpcMemHandle;
+#endif
+
+namespace plasma {
+
+namespace flatbuf {
+struct ObjectInfoT;
+} // namespace flatbuf
+
+#define HANDLE_SIGPIPE(s, fd_) \
+ do { \
+ Status _s = (s); \
+ if (!_s.ok()) { \
+ if (errno == EPIPE || errno == EBADF || errno == ECONNRESET) { \
+ ARROW_LOG(WARNING) \
+ << "Received SIGPIPE, BAD FILE DESCRIPTOR, or ECONNRESET when " \
+ "sending a message to client on fd " \
+ << fd_ \
+ << ". " \
+ "The client on the other end may have hung up."; \
+ } else { \
+ return _s; \
+ } \
+ } \
+ } while (0);
+
+/// Allocation granularity used in plasma for object allocation.
+constexpr int64_t kBlockSize = 64;
+
+/// Contains all information that is associated with a Plasma store client.
+struct Client {
+ explicit Client(int fd);
+
+ /// The file descriptor used to communicate with the client.
+ int fd;
+
+ /// Object ids that are used by this client.
+ std::unordered_set<ObjectID> object_ids;
+
+ /// File descriptors that are used by this client.
+ std::unordered_set<int> used_fds;
+
+ /// The file descriptor used to push notifications to client. This is only valid
+ /// if client subscribes to plasma store. -1 indicates invalid.
+ int notification_fd;
+
+ std::string name = "anonymous_client";
+};
+
+// TODO(pcm): Replace this by the flatbuffers message PlasmaObjectSpec.
+struct PlasmaObject {
+#ifdef PLASMA_CUDA
+ // IPC handle for Cuda.
+ std::shared_ptr<CudaIpcMemHandle> ipc_handle;
+#endif
+ /// The file descriptor of the memory mapped file in the store. It is used as
+ /// a unique identifier of the file in the client to look up the corresponding
+ /// file descriptor on the client's side.
+ int store_fd;
+ /// The offset in bytes in the memory mapped file of the data.
+ ptrdiff_t data_offset;
+ /// The offset in bytes in the memory mapped file of the metadata.
+ ptrdiff_t metadata_offset;
+ /// The size in bytes of the data.
+ int64_t data_size;
+ /// The size in bytes of the metadata.
+ int64_t metadata_size;
+ /// Device number object is on.
+ int device_num;
+
+ bool operator==(const PlasmaObject& other) const {
+ return (
+#ifdef PLASMA_CUDA
+ (ipc_handle == other.ipc_handle) &&
+#endif
+ (store_fd == other.store_fd) && (data_offset == other.data_offset) &&
+ (metadata_offset == other.metadata_offset) && (data_size == other.data_size) &&
+ (metadata_size == other.metadata_size) && (device_num == other.device_num));
+ }
+};
+
+enum class ObjectStatus : int {
+ /// The object was not found.
+ OBJECT_NOT_FOUND = 0,
+ /// The object was found.
+ OBJECT_FOUND = 1
+};
+
+/// The plasma store information that is exposed to the eviction policy.
+struct PlasmaStoreInfo {
+ /// Objects that are in the Plasma store.
+ ObjectTable objects;
+ /// Boolean flag indicating whether to start the object store with hugepages
+ /// support enabled. Huge pages are substantially larger than normal memory
+ /// pages (e.g. 2MB or 1GB instead of 4KB) and using them can reduce
+ /// bookkeeping overhead from the OS.
+ bool hugepages_enabled;
+ /// A (platform-dependent) directory where to create the memory-backed file.
+ std::string directory;
+};
+
+/// Get an entry from the object table and return NULL if the object_id
+/// is not present.
+///
+/// \param store_info The PlasmaStoreInfo that contains the object table.
+/// \param object_id The object_id of the entry we are looking for.
+/// \return The entry associated with the object_id or NULL if the object_id
+/// is not present.
+ObjectTableEntry* GetObjectTableEntry(PlasmaStoreInfo* store_info,
+ const ObjectID& object_id);
+
+/// Print a warning if the status is less than zero. This should be used to check
+/// the success of messages sent to plasma clients. We print a warning instead of
+/// failing because the plasma clients are allowed to die. This is used to handle
+/// situations where the store writes to a client file descriptor, and the client
+/// may already have disconnected. If we have processed the disconnection and
+/// closed the file descriptor, we should get a BAD FILE DESCRIPTOR error. If we
+/// have not, then we should get a SIGPIPE. If we write to a TCP socket that
+/// isn't connected yet, then we should get an ECONNRESET.
+///
+/// \param status The status to check. If it is less less than zero, we will
+/// print a warning.
+/// \param client_sock The client socket. This is just used to print some extra
+/// information.
+/// \return The errno set.
+int WarnIfSigpipe(int status, int client_sock);
+
+std::unique_ptr<uint8_t[]> CreateObjectInfoBuffer(flatbuf::ObjectInfoT* object_info);
+
+std::unique_ptr<uint8_t[]> CreatePlasmaNotificationBuffer(
+ std::vector<flatbuf::ObjectInfoT>& object_info);
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/plasma.pc.in b/src/arrow/cpp/src/plasma/plasma.pc.in
new file mode 100644
index 000000000..17af01590
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/plasma.pc.in
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+prefix=@CMAKE_INSTALL_PREFIX@
+libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@
+includedir=${prefix}/include
+
+so_version=@PLASMA_SO_VERSION@
+abi_version=@PLASMA_SO_VERSION@
+full_so_version=@PLASMA_FULL_SO_VERSION@
+plasma_store_server=${prefix}/@CMAKE_INSTALL_BINDIR@/plasma-store-server@CMAKE_EXECUTABLE_SUFFIX@
+executable=${plasma_store_server}
+
+Name: Plasma
+Description: Plasma is an in-memory object store and cache for big data.
+Version: @PLASMA_VERSION@
+Requires: arrow
+Libs: -L${libdir} -lplasma
+Cflags: -I${includedir}
diff --git a/src/arrow/cpp/src/plasma/plasma_allocator.cc b/src/arrow/cpp/src/plasma/plasma_allocator.cc
new file mode 100644
index 000000000..b67eeea40
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/plasma_allocator.cc
@@ -0,0 +1,56 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/util/logging.h>
+
+#include "plasma/malloc.h"
+#include "plasma/plasma_allocator.h"
+
+namespace plasma {
+
+extern "C" {
+void* dlmemalign(size_t alignment, size_t bytes);
+void dlfree(void* mem);
+}
+
+int64_t PlasmaAllocator::footprint_limit_ = 0;
+int64_t PlasmaAllocator::allocated_ = 0;
+
+void* PlasmaAllocator::Memalign(size_t alignment, size_t bytes) {
+ if (allocated_ + static_cast<int64_t>(bytes) > footprint_limit_) {
+ return nullptr;
+ }
+ void* mem = dlmemalign(alignment, bytes);
+ ARROW_CHECK(mem);
+ allocated_ += bytes;
+ return mem;
+}
+
+void PlasmaAllocator::Free(void* mem, size_t bytes) {
+ dlfree(mem);
+ allocated_ -= bytes;
+}
+
+void PlasmaAllocator::SetFootprintLimit(size_t bytes) {
+ footprint_limit_ = static_cast<int64_t>(bytes);
+}
+
+int64_t PlasmaAllocator::GetFootprintLimit() { return footprint_limit_; }
+
+int64_t PlasmaAllocator::Allocated() { return allocated_; }
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/plasma_allocator.h b/src/arrow/cpp/src/plasma/plasma_allocator.h
new file mode 100644
index 000000000..f0e23246b
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/plasma_allocator.h
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+namespace plasma {
+
+class PlasmaAllocator {
+ public:
+ /// Allocates size bytes and returns a pointer to the allocated memory. The
+ /// memory address will be a multiple of alignment, which must be a power of two.
+ ///
+ /// \param alignment Memory alignment.
+ /// \param bytes Number of bytes.
+ /// \return Pointer to allocated memory.
+ static void* Memalign(size_t alignment, size_t bytes);
+
+ /// Frees the memory space pointed to by mem, which must have been returned by
+ /// a previous call to Memalign()
+ ///
+ /// \param mem Pointer to memory to free.
+ /// \param bytes Number of bytes to be freed.
+ static void Free(void* mem, size_t bytes);
+
+ /// Sets the memory footprint limit for Plasma.
+ ///
+ /// \param bytes Plasma memory footprint limit in bytes.
+ static void SetFootprintLimit(size_t bytes);
+
+ /// Get the memory footprint limit for Plasma.
+ ///
+ /// \return Plasma memory footprint limit in bytes.
+ static int64_t GetFootprintLimit();
+
+ /// Get the number of bytes allocated by Plasma so far.
+ /// \return Number of bytes allocated by Plasma so far.
+ static int64_t Allocated();
+
+ private:
+ static int64_t allocated_;
+ static int64_t footprint_limit_;
+};
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/plasma_generated.h b/src/arrow/cpp/src/plasma/plasma_generated.h
new file mode 100644
index 000000000..340f043bc
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/plasma_generated.h
@@ -0,0 +1,3984 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_PLASMA_PLASMA_FLATBUF_H_
+#define FLATBUFFERS_GENERATED_PLASMA_PLASMA_FLATBUF_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+#include "common_generated.h"
+
+namespace plasma {
+namespace flatbuf {
+
+struct PlasmaObjectSpec;
+
+struct PlasmaSetOptionsRequest;
+struct PlasmaSetOptionsRequestBuilder;
+struct PlasmaSetOptionsRequestT;
+
+struct PlasmaSetOptionsReply;
+struct PlasmaSetOptionsReplyBuilder;
+struct PlasmaSetOptionsReplyT;
+
+struct PlasmaGetDebugStringRequest;
+struct PlasmaGetDebugStringRequestBuilder;
+struct PlasmaGetDebugStringRequestT;
+
+struct PlasmaGetDebugStringReply;
+struct PlasmaGetDebugStringReplyBuilder;
+struct PlasmaGetDebugStringReplyT;
+
+struct PlasmaCreateRequest;
+struct PlasmaCreateRequestBuilder;
+struct PlasmaCreateRequestT;
+
+struct CudaHandle;
+struct CudaHandleBuilder;
+struct CudaHandleT;
+
+struct PlasmaCreateReply;
+struct PlasmaCreateReplyBuilder;
+struct PlasmaCreateReplyT;
+
+struct PlasmaCreateAndSealRequest;
+struct PlasmaCreateAndSealRequestBuilder;
+struct PlasmaCreateAndSealRequestT;
+
+struct PlasmaCreateAndSealReply;
+struct PlasmaCreateAndSealReplyBuilder;
+struct PlasmaCreateAndSealReplyT;
+
+struct PlasmaCreateAndSealBatchRequest;
+struct PlasmaCreateAndSealBatchRequestBuilder;
+struct PlasmaCreateAndSealBatchRequestT;
+
+struct PlasmaCreateAndSealBatchReply;
+struct PlasmaCreateAndSealBatchReplyBuilder;
+struct PlasmaCreateAndSealBatchReplyT;
+
+struct PlasmaAbortRequest;
+struct PlasmaAbortRequestBuilder;
+struct PlasmaAbortRequestT;
+
+struct PlasmaAbortReply;
+struct PlasmaAbortReplyBuilder;
+struct PlasmaAbortReplyT;
+
+struct PlasmaSealRequest;
+struct PlasmaSealRequestBuilder;
+struct PlasmaSealRequestT;
+
+struct PlasmaSealReply;
+struct PlasmaSealReplyBuilder;
+struct PlasmaSealReplyT;
+
+struct PlasmaGetRequest;
+struct PlasmaGetRequestBuilder;
+struct PlasmaGetRequestT;
+
+struct PlasmaGetReply;
+struct PlasmaGetReplyBuilder;
+struct PlasmaGetReplyT;
+
+struct PlasmaReleaseRequest;
+struct PlasmaReleaseRequestBuilder;
+struct PlasmaReleaseRequestT;
+
+struct PlasmaReleaseReply;
+struct PlasmaReleaseReplyBuilder;
+struct PlasmaReleaseReplyT;
+
+struct PlasmaDeleteRequest;
+struct PlasmaDeleteRequestBuilder;
+struct PlasmaDeleteRequestT;
+
+struct PlasmaDeleteReply;
+struct PlasmaDeleteReplyBuilder;
+struct PlasmaDeleteReplyT;
+
+struct PlasmaContainsRequest;
+struct PlasmaContainsRequestBuilder;
+struct PlasmaContainsRequestT;
+
+struct PlasmaContainsReply;
+struct PlasmaContainsReplyBuilder;
+struct PlasmaContainsReplyT;
+
+struct PlasmaListRequest;
+struct PlasmaListRequestBuilder;
+struct PlasmaListRequestT;
+
+struct PlasmaListReply;
+struct PlasmaListReplyBuilder;
+struct PlasmaListReplyT;
+
+struct PlasmaConnectRequest;
+struct PlasmaConnectRequestBuilder;
+struct PlasmaConnectRequestT;
+
+struct PlasmaConnectReply;
+struct PlasmaConnectReplyBuilder;
+struct PlasmaConnectReplyT;
+
+struct PlasmaEvictRequest;
+struct PlasmaEvictRequestBuilder;
+struct PlasmaEvictRequestT;
+
+struct PlasmaEvictReply;
+struct PlasmaEvictReplyBuilder;
+struct PlasmaEvictReplyT;
+
+struct PlasmaSubscribeRequest;
+struct PlasmaSubscribeRequestBuilder;
+struct PlasmaSubscribeRequestT;
+
+struct PlasmaNotification;
+struct PlasmaNotificationBuilder;
+struct PlasmaNotificationT;
+
+struct PlasmaDataRequest;
+struct PlasmaDataRequestBuilder;
+struct PlasmaDataRequestT;
+
+struct PlasmaDataReply;
+struct PlasmaDataReplyBuilder;
+struct PlasmaDataReplyT;
+
+struct PlasmaRefreshLRURequest;
+struct PlasmaRefreshLRURequestBuilder;
+struct PlasmaRefreshLRURequestT;
+
+struct PlasmaRefreshLRUReply;
+struct PlasmaRefreshLRUReplyBuilder;
+struct PlasmaRefreshLRUReplyT;
+
+enum class MessageType : int64_t {
+ PlasmaDisconnectClient = 0,
+ PlasmaCreateRequest = 1LL,
+ PlasmaCreateReply = 2LL,
+ PlasmaCreateAndSealRequest = 3LL,
+ PlasmaCreateAndSealReply = 4LL,
+ PlasmaAbortRequest = 5LL,
+ PlasmaAbortReply = 6LL,
+ PlasmaSealRequest = 7LL,
+ PlasmaSealReply = 8LL,
+ PlasmaGetRequest = 9LL,
+ PlasmaGetReply = 10LL,
+ PlasmaReleaseRequest = 11LL,
+ PlasmaReleaseReply = 12LL,
+ PlasmaDeleteRequest = 13LL,
+ PlasmaDeleteReply = 14LL,
+ PlasmaContainsRequest = 15LL,
+ PlasmaContainsReply = 16LL,
+ PlasmaListRequest = 17LL,
+ PlasmaListReply = 18LL,
+ PlasmaConnectRequest = 19LL,
+ PlasmaConnectReply = 20LL,
+ PlasmaEvictRequest = 21LL,
+ PlasmaEvictReply = 22LL,
+ PlasmaSubscribeRequest = 23LL,
+ PlasmaUnsubscribeRequest = 24LL,
+ PlasmaDataRequest = 25LL,
+ PlasmaDataReply = 26LL,
+ PlasmaNotification = 27LL,
+ PlasmaSetOptionsRequest = 28LL,
+ PlasmaSetOptionsReply = 29LL,
+ PlasmaGetDebugStringRequest = 30LL,
+ PlasmaGetDebugStringReply = 31LL,
+ PlasmaCreateAndSealBatchRequest = 32LL,
+ PlasmaCreateAndSealBatchReply = 33LL,
+ PlasmaRefreshLRURequest = 34LL,
+ PlasmaRefreshLRUReply = 35LL,
+ MIN = PlasmaDisconnectClient,
+ MAX = PlasmaRefreshLRUReply
+};
+
+inline const MessageType (&EnumValuesMessageType())[36] {
+ static const MessageType values[] = {
+ MessageType::PlasmaDisconnectClient,
+ MessageType::PlasmaCreateRequest,
+ MessageType::PlasmaCreateReply,
+ MessageType::PlasmaCreateAndSealRequest,
+ MessageType::PlasmaCreateAndSealReply,
+ MessageType::PlasmaAbortRequest,
+ MessageType::PlasmaAbortReply,
+ MessageType::PlasmaSealRequest,
+ MessageType::PlasmaSealReply,
+ MessageType::PlasmaGetRequest,
+ MessageType::PlasmaGetReply,
+ MessageType::PlasmaReleaseRequest,
+ MessageType::PlasmaReleaseReply,
+ MessageType::PlasmaDeleteRequest,
+ MessageType::PlasmaDeleteReply,
+ MessageType::PlasmaContainsRequest,
+ MessageType::PlasmaContainsReply,
+ MessageType::PlasmaListRequest,
+ MessageType::PlasmaListReply,
+ MessageType::PlasmaConnectRequest,
+ MessageType::PlasmaConnectReply,
+ MessageType::PlasmaEvictRequest,
+ MessageType::PlasmaEvictReply,
+ MessageType::PlasmaSubscribeRequest,
+ MessageType::PlasmaUnsubscribeRequest,
+ MessageType::PlasmaDataRequest,
+ MessageType::PlasmaDataReply,
+ MessageType::PlasmaNotification,
+ MessageType::PlasmaSetOptionsRequest,
+ MessageType::PlasmaSetOptionsReply,
+ MessageType::PlasmaGetDebugStringRequest,
+ MessageType::PlasmaGetDebugStringReply,
+ MessageType::PlasmaCreateAndSealBatchRequest,
+ MessageType::PlasmaCreateAndSealBatchReply,
+ MessageType::PlasmaRefreshLRURequest,
+ MessageType::PlasmaRefreshLRUReply
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesMessageType() {
+ static const char * const names[37] = {
+ "PlasmaDisconnectClient",
+ "PlasmaCreateRequest",
+ "PlasmaCreateReply",
+ "PlasmaCreateAndSealRequest",
+ "PlasmaCreateAndSealReply",
+ "PlasmaAbortRequest",
+ "PlasmaAbortReply",
+ "PlasmaSealRequest",
+ "PlasmaSealReply",
+ "PlasmaGetRequest",
+ "PlasmaGetReply",
+ "PlasmaReleaseRequest",
+ "PlasmaReleaseReply",
+ "PlasmaDeleteRequest",
+ "PlasmaDeleteReply",
+ "PlasmaContainsRequest",
+ "PlasmaContainsReply",
+ "PlasmaListRequest",
+ "PlasmaListReply",
+ "PlasmaConnectRequest",
+ "PlasmaConnectReply",
+ "PlasmaEvictRequest",
+ "PlasmaEvictReply",
+ "PlasmaSubscribeRequest",
+ "PlasmaUnsubscribeRequest",
+ "PlasmaDataRequest",
+ "PlasmaDataReply",
+ "PlasmaNotification",
+ "PlasmaSetOptionsRequest",
+ "PlasmaSetOptionsReply",
+ "PlasmaGetDebugStringRequest",
+ "PlasmaGetDebugStringReply",
+ "PlasmaCreateAndSealBatchRequest",
+ "PlasmaCreateAndSealBatchReply",
+ "PlasmaRefreshLRURequest",
+ "PlasmaRefreshLRUReply",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameMessageType(MessageType e) {
+ if (flatbuffers::IsOutRange(e, MessageType::PlasmaDisconnectClient, MessageType::PlasmaRefreshLRUReply)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesMessageType()[index];
+}
+
+enum class PlasmaError : int32_t {
+ OK = 0,
+ ObjectExists = 1,
+ ObjectNotFound = 2,
+ OutOfMemory = 3,
+ ObjectNotSealed = 4,
+ ObjectInUse = 5,
+ MIN = OK,
+ MAX = ObjectInUse
+};
+
+inline const PlasmaError (&EnumValuesPlasmaError())[6] {
+ static const PlasmaError values[] = {
+ PlasmaError::OK,
+ PlasmaError::ObjectExists,
+ PlasmaError::ObjectNotFound,
+ PlasmaError::OutOfMemory,
+ PlasmaError::ObjectNotSealed,
+ PlasmaError::ObjectInUse
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesPlasmaError() {
+ static const char * const names[7] = {
+ "OK",
+ "ObjectExists",
+ "ObjectNotFound",
+ "OutOfMemory",
+ "ObjectNotSealed",
+ "ObjectInUse",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNamePlasmaError(PlasmaError e) {
+ if (flatbuffers::IsOutRange(e, PlasmaError::OK, PlasmaError::ObjectInUse)) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesPlasmaError()[index];
+}
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) PlasmaObjectSpec FLATBUFFERS_FINAL_CLASS {
+ private:
+ int32_t segment_index_;
+ int32_t padding0__;
+ uint64_t data_offset_;
+ uint64_t data_size_;
+ uint64_t metadata_offset_;
+ uint64_t metadata_size_;
+ int32_t device_num_;
+ int32_t padding1__;
+
+ public:
+ PlasmaObjectSpec() {
+ memset(static_cast<void *>(this), 0, sizeof(PlasmaObjectSpec));
+ }
+ PlasmaObjectSpec(int32_t _segment_index, uint64_t _data_offset, uint64_t _data_size, uint64_t _metadata_offset, uint64_t _metadata_size, int32_t _device_num)
+ : segment_index_(flatbuffers::EndianScalar(_segment_index)),
+ padding0__(0),
+ data_offset_(flatbuffers::EndianScalar(_data_offset)),
+ data_size_(flatbuffers::EndianScalar(_data_size)),
+ metadata_offset_(flatbuffers::EndianScalar(_metadata_offset)),
+ metadata_size_(flatbuffers::EndianScalar(_metadata_size)),
+ device_num_(flatbuffers::EndianScalar(_device_num)),
+ padding1__(0) {
+ (void)padding0__;
+ (void)padding1__;
+ }
+ int32_t segment_index() const {
+ return flatbuffers::EndianScalar(segment_index_);
+ }
+ uint64_t data_offset() const {
+ return flatbuffers::EndianScalar(data_offset_);
+ }
+ uint64_t data_size() const {
+ return flatbuffers::EndianScalar(data_size_);
+ }
+ uint64_t metadata_offset() const {
+ return flatbuffers::EndianScalar(metadata_offset_);
+ }
+ uint64_t metadata_size() const {
+ return flatbuffers::EndianScalar(metadata_size_);
+ }
+ int32_t device_num() const {
+ return flatbuffers::EndianScalar(device_num_);
+ }
+};
+FLATBUFFERS_STRUCT_END(PlasmaObjectSpec, 48);
+
+struct PlasmaSetOptionsRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaSetOptionsRequest TableType;
+ std::string client_name;
+ int64_t output_memory_quota;
+ PlasmaSetOptionsRequestT()
+ : output_memory_quota(0) {
+ }
+};
+
+struct PlasmaSetOptionsRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaSetOptionsRequestT NativeTableType;
+ typedef PlasmaSetOptionsRequestBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_CLIENT_NAME = 4,
+ VT_OUTPUT_MEMORY_QUOTA = 6
+ };
+ const flatbuffers::String *client_name() const {
+ return GetPointer<const flatbuffers::String *>(VT_CLIENT_NAME);
+ }
+ int64_t output_memory_quota() const {
+ return GetField<int64_t>(VT_OUTPUT_MEMORY_QUOTA, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_CLIENT_NAME) &&
+ verifier.VerifyString(client_name()) &&
+ VerifyField<int64_t>(verifier, VT_OUTPUT_MEMORY_QUOTA) &&
+ verifier.EndTable();
+ }
+ PlasmaSetOptionsRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaSetOptionsRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaSetOptionsRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSetOptionsRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaSetOptionsRequestBuilder {
+ typedef PlasmaSetOptionsRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_client_name(flatbuffers::Offset<flatbuffers::String> client_name) {
+ fbb_.AddOffset(PlasmaSetOptionsRequest::VT_CLIENT_NAME, client_name);
+ }
+ void add_output_memory_quota(int64_t output_memory_quota) {
+ fbb_.AddElement<int64_t>(PlasmaSetOptionsRequest::VT_OUTPUT_MEMORY_QUOTA, output_memory_quota, 0);
+ }
+ explicit PlasmaSetOptionsRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaSetOptionsRequestBuilder &operator=(const PlasmaSetOptionsRequestBuilder &);
+ flatbuffers::Offset<PlasmaSetOptionsRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaSetOptionsRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaSetOptionsRequest> CreatePlasmaSetOptionsRequest(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> client_name = 0,
+ int64_t output_memory_quota = 0) {
+ PlasmaSetOptionsRequestBuilder builder_(_fbb);
+ builder_.add_output_memory_quota(output_memory_quota);
+ builder_.add_client_name(client_name);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaSetOptionsRequest> CreatePlasmaSetOptionsRequestDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *client_name = nullptr,
+ int64_t output_memory_quota = 0) {
+ auto client_name__ = client_name ? _fbb.CreateString(client_name) : 0;
+ return plasma::flatbuf::CreatePlasmaSetOptionsRequest(
+ _fbb,
+ client_name__,
+ output_memory_quota);
+}
+
+flatbuffers::Offset<PlasmaSetOptionsRequest> CreatePlasmaSetOptionsRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSetOptionsRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaSetOptionsReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaSetOptionsReply TableType;
+ plasma::flatbuf::PlasmaError error;
+ PlasmaSetOptionsReplyT()
+ : error(plasma::flatbuf::PlasmaError::OK) {
+ }
+};
+
+struct PlasmaSetOptionsReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaSetOptionsReplyT NativeTableType;
+ typedef PlasmaSetOptionsReplyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_ERROR = 4
+ };
+ plasma::flatbuf::PlasmaError error() const {
+ return static_cast<plasma::flatbuf::PlasmaError>(GetField<int32_t>(VT_ERROR, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_ERROR) &&
+ verifier.EndTable();
+ }
+ PlasmaSetOptionsReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaSetOptionsReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaSetOptionsReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSetOptionsReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaSetOptionsReplyBuilder {
+ typedef PlasmaSetOptionsReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_error(plasma::flatbuf::PlasmaError error) {
+ fbb_.AddElement<int32_t>(PlasmaSetOptionsReply::VT_ERROR, static_cast<int32_t>(error), 0);
+ }
+ explicit PlasmaSetOptionsReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaSetOptionsReplyBuilder &operator=(const PlasmaSetOptionsReplyBuilder &);
+ flatbuffers::Offset<PlasmaSetOptionsReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaSetOptionsReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaSetOptionsReply> CreatePlasmaSetOptionsReply(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ plasma::flatbuf::PlasmaError error = plasma::flatbuf::PlasmaError::OK) {
+ PlasmaSetOptionsReplyBuilder builder_(_fbb);
+ builder_.add_error(error);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<PlasmaSetOptionsReply> CreatePlasmaSetOptionsReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSetOptionsReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaGetDebugStringRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaGetDebugStringRequest TableType;
+ PlasmaGetDebugStringRequestT() {
+ }
+};
+
+struct PlasmaGetDebugStringRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaGetDebugStringRequestT NativeTableType;
+ typedef PlasmaGetDebugStringRequestBuilder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ PlasmaGetDebugStringRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaGetDebugStringRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaGetDebugStringRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetDebugStringRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaGetDebugStringRequestBuilder {
+ typedef PlasmaGetDebugStringRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit PlasmaGetDebugStringRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaGetDebugStringRequestBuilder &operator=(const PlasmaGetDebugStringRequestBuilder &);
+ flatbuffers::Offset<PlasmaGetDebugStringRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaGetDebugStringRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaGetDebugStringRequest> CreatePlasmaGetDebugStringRequest(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ PlasmaGetDebugStringRequestBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<PlasmaGetDebugStringRequest> CreatePlasmaGetDebugStringRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetDebugStringRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaGetDebugStringReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaGetDebugStringReply TableType;
+ std::string debug_string;
+ PlasmaGetDebugStringReplyT() {
+ }
+};
+
+struct PlasmaGetDebugStringReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaGetDebugStringReplyT NativeTableType;
+ typedef PlasmaGetDebugStringReplyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_DEBUG_STRING = 4
+ };
+ const flatbuffers::String *debug_string() const {
+ return GetPointer<const flatbuffers::String *>(VT_DEBUG_STRING);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_DEBUG_STRING) &&
+ verifier.VerifyString(debug_string()) &&
+ verifier.EndTable();
+ }
+ PlasmaGetDebugStringReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaGetDebugStringReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaGetDebugStringReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetDebugStringReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaGetDebugStringReplyBuilder {
+ typedef PlasmaGetDebugStringReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_debug_string(flatbuffers::Offset<flatbuffers::String> debug_string) {
+ fbb_.AddOffset(PlasmaGetDebugStringReply::VT_DEBUG_STRING, debug_string);
+ }
+ explicit PlasmaGetDebugStringReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaGetDebugStringReplyBuilder &operator=(const PlasmaGetDebugStringReplyBuilder &);
+ flatbuffers::Offset<PlasmaGetDebugStringReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaGetDebugStringReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaGetDebugStringReply> CreatePlasmaGetDebugStringReply(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> debug_string = 0) {
+ PlasmaGetDebugStringReplyBuilder builder_(_fbb);
+ builder_.add_debug_string(debug_string);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaGetDebugStringReply> CreatePlasmaGetDebugStringReplyDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *debug_string = nullptr) {
+ auto debug_string__ = debug_string ? _fbb.CreateString(debug_string) : 0;
+ return plasma::flatbuf::CreatePlasmaGetDebugStringReply(
+ _fbb,
+ debug_string__);
+}
+
+flatbuffers::Offset<PlasmaGetDebugStringReply> CreatePlasmaGetDebugStringReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetDebugStringReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaCreateRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaCreateRequest TableType;
+ std::string object_id;
+ bool evict_if_full;
+ uint64_t data_size;
+ uint64_t metadata_size;
+ int32_t device_num;
+ PlasmaCreateRequestT()
+ : evict_if_full(false),
+ data_size(0),
+ metadata_size(0),
+ device_num(0) {
+ }
+};
+
+struct PlasmaCreateRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaCreateRequestT NativeTableType;
+ typedef PlasmaCreateRequestBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_ID = 4,
+ VT_EVICT_IF_FULL = 6,
+ VT_DATA_SIZE = 8,
+ VT_METADATA_SIZE = 10,
+ VT_DEVICE_NUM = 12
+ };
+ const flatbuffers::String *object_id() const {
+ return GetPointer<const flatbuffers::String *>(VT_OBJECT_ID);
+ }
+ bool evict_if_full() const {
+ return GetField<uint8_t>(VT_EVICT_IF_FULL, 0) != 0;
+ }
+ uint64_t data_size() const {
+ return GetField<uint64_t>(VT_DATA_SIZE, 0);
+ }
+ uint64_t metadata_size() const {
+ return GetField<uint64_t>(VT_METADATA_SIZE, 0);
+ }
+ int32_t device_num() const {
+ return GetField<int32_t>(VT_DEVICE_NUM, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_ID) &&
+ verifier.VerifyString(object_id()) &&
+ VerifyField<uint8_t>(verifier, VT_EVICT_IF_FULL) &&
+ VerifyField<uint64_t>(verifier, VT_DATA_SIZE) &&
+ VerifyField<uint64_t>(verifier, VT_METADATA_SIZE) &&
+ VerifyField<int32_t>(verifier, VT_DEVICE_NUM) &&
+ verifier.EndTable();
+ }
+ PlasmaCreateRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaCreateRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaCreateRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaCreateRequestBuilder {
+ typedef PlasmaCreateRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_id(flatbuffers::Offset<flatbuffers::String> object_id) {
+ fbb_.AddOffset(PlasmaCreateRequest::VT_OBJECT_ID, object_id);
+ }
+ void add_evict_if_full(bool evict_if_full) {
+ fbb_.AddElement<uint8_t>(PlasmaCreateRequest::VT_EVICT_IF_FULL, static_cast<uint8_t>(evict_if_full), 0);
+ }
+ void add_data_size(uint64_t data_size) {
+ fbb_.AddElement<uint64_t>(PlasmaCreateRequest::VT_DATA_SIZE, data_size, 0);
+ }
+ void add_metadata_size(uint64_t metadata_size) {
+ fbb_.AddElement<uint64_t>(PlasmaCreateRequest::VT_METADATA_SIZE, metadata_size, 0);
+ }
+ void add_device_num(int32_t device_num) {
+ fbb_.AddElement<int32_t>(PlasmaCreateRequest::VT_DEVICE_NUM, device_num, 0);
+ }
+ explicit PlasmaCreateRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaCreateRequestBuilder &operator=(const PlasmaCreateRequestBuilder &);
+ flatbuffers::Offset<PlasmaCreateRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaCreateRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaCreateRequest> CreatePlasmaCreateRequest(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> object_id = 0,
+ bool evict_if_full = false,
+ uint64_t data_size = 0,
+ uint64_t metadata_size = 0,
+ int32_t device_num = 0) {
+ PlasmaCreateRequestBuilder builder_(_fbb);
+ builder_.add_metadata_size(metadata_size);
+ builder_.add_data_size(data_size);
+ builder_.add_device_num(device_num);
+ builder_.add_object_id(object_id);
+ builder_.add_evict_if_full(evict_if_full);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaCreateRequest> CreatePlasmaCreateRequestDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *object_id = nullptr,
+ bool evict_if_full = false,
+ uint64_t data_size = 0,
+ uint64_t metadata_size = 0,
+ int32_t device_num = 0) {
+ auto object_id__ = object_id ? _fbb.CreateString(object_id) : 0;
+ return plasma::flatbuf::CreatePlasmaCreateRequest(
+ _fbb,
+ object_id__,
+ evict_if_full,
+ data_size,
+ metadata_size,
+ device_num);
+}
+
+flatbuffers::Offset<PlasmaCreateRequest> CreatePlasmaCreateRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct CudaHandleT : public flatbuffers::NativeTable {
+ typedef CudaHandle TableType;
+ std::vector<uint8_t> handle;
+ CudaHandleT() {
+ }
+};
+
+struct CudaHandle FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef CudaHandleT NativeTableType;
+ typedef CudaHandleBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_HANDLE = 4
+ };
+ const flatbuffers::Vector<uint8_t> *handle() const {
+ return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_HANDLE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_HANDLE) &&
+ verifier.VerifyVector(handle()) &&
+ verifier.EndTable();
+ }
+ CudaHandleT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(CudaHandleT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<CudaHandle> Pack(flatbuffers::FlatBufferBuilder &_fbb, const CudaHandleT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct CudaHandleBuilder {
+ typedef CudaHandle Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_handle(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> handle) {
+ fbb_.AddOffset(CudaHandle::VT_HANDLE, handle);
+ }
+ explicit CudaHandleBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ CudaHandleBuilder &operator=(const CudaHandleBuilder &);
+ flatbuffers::Offset<CudaHandle> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<CudaHandle>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<CudaHandle> CreateCudaHandle(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> handle = 0) {
+ CudaHandleBuilder builder_(_fbb);
+ builder_.add_handle(handle);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<CudaHandle> CreateCudaHandleDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<uint8_t> *handle = nullptr) {
+ auto handle__ = handle ? _fbb.CreateVector<uint8_t>(*handle) : 0;
+ return plasma::flatbuf::CreateCudaHandle(
+ _fbb,
+ handle__);
+}
+
+flatbuffers::Offset<CudaHandle> CreateCudaHandle(flatbuffers::FlatBufferBuilder &_fbb, const CudaHandleT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaCreateReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaCreateReply TableType;
+ std::string object_id;
+ std::unique_ptr<plasma::flatbuf::PlasmaObjectSpec> plasma_object;
+ plasma::flatbuf::PlasmaError error;
+ int32_t store_fd;
+ int64_t mmap_size;
+ std::unique_ptr<plasma::flatbuf::CudaHandleT> ipc_handle;
+ PlasmaCreateReplyT()
+ : error(plasma::flatbuf::PlasmaError::OK),
+ store_fd(0),
+ mmap_size(0) {
+ }
+};
+
+struct PlasmaCreateReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaCreateReplyT NativeTableType;
+ typedef PlasmaCreateReplyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_ID = 4,
+ VT_PLASMA_OBJECT = 6,
+ VT_ERROR = 8,
+ VT_STORE_FD = 10,
+ VT_MMAP_SIZE = 12,
+ VT_IPC_HANDLE = 14
+ };
+ const flatbuffers::String *object_id() const {
+ return GetPointer<const flatbuffers::String *>(VT_OBJECT_ID);
+ }
+ const plasma::flatbuf::PlasmaObjectSpec *plasma_object() const {
+ return GetStruct<const plasma::flatbuf::PlasmaObjectSpec *>(VT_PLASMA_OBJECT);
+ }
+ plasma::flatbuf::PlasmaError error() const {
+ return static_cast<plasma::flatbuf::PlasmaError>(GetField<int32_t>(VT_ERROR, 0));
+ }
+ int32_t store_fd() const {
+ return GetField<int32_t>(VT_STORE_FD, 0);
+ }
+ int64_t mmap_size() const {
+ return GetField<int64_t>(VT_MMAP_SIZE, 0);
+ }
+ const plasma::flatbuf::CudaHandle *ipc_handle() const {
+ return GetPointer<const plasma::flatbuf::CudaHandle *>(VT_IPC_HANDLE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_ID) &&
+ verifier.VerifyString(object_id()) &&
+ VerifyField<plasma::flatbuf::PlasmaObjectSpec>(verifier, VT_PLASMA_OBJECT) &&
+ VerifyField<int32_t>(verifier, VT_ERROR) &&
+ VerifyField<int32_t>(verifier, VT_STORE_FD) &&
+ VerifyField<int64_t>(verifier, VT_MMAP_SIZE) &&
+ VerifyOffset(verifier, VT_IPC_HANDLE) &&
+ verifier.VerifyTable(ipc_handle()) &&
+ verifier.EndTable();
+ }
+ PlasmaCreateReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaCreateReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaCreateReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaCreateReplyBuilder {
+ typedef PlasmaCreateReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_id(flatbuffers::Offset<flatbuffers::String> object_id) {
+ fbb_.AddOffset(PlasmaCreateReply::VT_OBJECT_ID, object_id);
+ }
+ void add_plasma_object(const plasma::flatbuf::PlasmaObjectSpec *plasma_object) {
+ fbb_.AddStruct(PlasmaCreateReply::VT_PLASMA_OBJECT, plasma_object);
+ }
+ void add_error(plasma::flatbuf::PlasmaError error) {
+ fbb_.AddElement<int32_t>(PlasmaCreateReply::VT_ERROR, static_cast<int32_t>(error), 0);
+ }
+ void add_store_fd(int32_t store_fd) {
+ fbb_.AddElement<int32_t>(PlasmaCreateReply::VT_STORE_FD, store_fd, 0);
+ }
+ void add_mmap_size(int64_t mmap_size) {
+ fbb_.AddElement<int64_t>(PlasmaCreateReply::VT_MMAP_SIZE, mmap_size, 0);
+ }
+ void add_ipc_handle(flatbuffers::Offset<plasma::flatbuf::CudaHandle> ipc_handle) {
+ fbb_.AddOffset(PlasmaCreateReply::VT_IPC_HANDLE, ipc_handle);
+ }
+ explicit PlasmaCreateReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaCreateReplyBuilder &operator=(const PlasmaCreateReplyBuilder &);
+ flatbuffers::Offset<PlasmaCreateReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaCreateReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaCreateReply> CreatePlasmaCreateReply(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> object_id = 0,
+ const plasma::flatbuf::PlasmaObjectSpec *plasma_object = 0,
+ plasma::flatbuf::PlasmaError error = plasma::flatbuf::PlasmaError::OK,
+ int32_t store_fd = 0,
+ int64_t mmap_size = 0,
+ flatbuffers::Offset<plasma::flatbuf::CudaHandle> ipc_handle = 0) {
+ PlasmaCreateReplyBuilder builder_(_fbb);
+ builder_.add_mmap_size(mmap_size);
+ builder_.add_ipc_handle(ipc_handle);
+ builder_.add_store_fd(store_fd);
+ builder_.add_error(error);
+ builder_.add_plasma_object(plasma_object);
+ builder_.add_object_id(object_id);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaCreateReply> CreatePlasmaCreateReplyDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *object_id = nullptr,
+ const plasma::flatbuf::PlasmaObjectSpec *plasma_object = 0,
+ plasma::flatbuf::PlasmaError error = plasma::flatbuf::PlasmaError::OK,
+ int32_t store_fd = 0,
+ int64_t mmap_size = 0,
+ flatbuffers::Offset<plasma::flatbuf::CudaHandle> ipc_handle = 0) {
+ auto object_id__ = object_id ? _fbb.CreateString(object_id) : 0;
+ return plasma::flatbuf::CreatePlasmaCreateReply(
+ _fbb,
+ object_id__,
+ plasma_object,
+ error,
+ store_fd,
+ mmap_size,
+ ipc_handle);
+}
+
+flatbuffers::Offset<PlasmaCreateReply> CreatePlasmaCreateReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaCreateAndSealRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaCreateAndSealRequest TableType;
+ std::string object_id;
+ bool evict_if_full;
+ std::string data;
+ std::string metadata;
+ std::string digest;
+ PlasmaCreateAndSealRequestT()
+ : evict_if_full(false) {
+ }
+};
+
+struct PlasmaCreateAndSealRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaCreateAndSealRequestT NativeTableType;
+ typedef PlasmaCreateAndSealRequestBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_ID = 4,
+ VT_EVICT_IF_FULL = 6,
+ VT_DATA = 8,
+ VT_METADATA = 10,
+ VT_DIGEST = 12
+ };
+ const flatbuffers::String *object_id() const {
+ return GetPointer<const flatbuffers::String *>(VT_OBJECT_ID);
+ }
+ bool evict_if_full() const {
+ return GetField<uint8_t>(VT_EVICT_IF_FULL, 0) != 0;
+ }
+ const flatbuffers::String *data() const {
+ return GetPointer<const flatbuffers::String *>(VT_DATA);
+ }
+ const flatbuffers::String *metadata() const {
+ return GetPointer<const flatbuffers::String *>(VT_METADATA);
+ }
+ const flatbuffers::String *digest() const {
+ return GetPointer<const flatbuffers::String *>(VT_DIGEST);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_ID) &&
+ verifier.VerifyString(object_id()) &&
+ VerifyField<uint8_t>(verifier, VT_EVICT_IF_FULL) &&
+ VerifyOffset(verifier, VT_DATA) &&
+ verifier.VerifyString(data()) &&
+ VerifyOffset(verifier, VT_METADATA) &&
+ verifier.VerifyString(metadata()) &&
+ VerifyOffset(verifier, VT_DIGEST) &&
+ verifier.VerifyString(digest()) &&
+ verifier.EndTable();
+ }
+ PlasmaCreateAndSealRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaCreateAndSealRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaCreateAndSealRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaCreateAndSealRequestBuilder {
+ typedef PlasmaCreateAndSealRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_id(flatbuffers::Offset<flatbuffers::String> object_id) {
+ fbb_.AddOffset(PlasmaCreateAndSealRequest::VT_OBJECT_ID, object_id);
+ }
+ void add_evict_if_full(bool evict_if_full) {
+ fbb_.AddElement<uint8_t>(PlasmaCreateAndSealRequest::VT_EVICT_IF_FULL, static_cast<uint8_t>(evict_if_full), 0);
+ }
+ void add_data(flatbuffers::Offset<flatbuffers::String> data) {
+ fbb_.AddOffset(PlasmaCreateAndSealRequest::VT_DATA, data);
+ }
+ void add_metadata(flatbuffers::Offset<flatbuffers::String> metadata) {
+ fbb_.AddOffset(PlasmaCreateAndSealRequest::VT_METADATA, metadata);
+ }
+ void add_digest(flatbuffers::Offset<flatbuffers::String> digest) {
+ fbb_.AddOffset(PlasmaCreateAndSealRequest::VT_DIGEST, digest);
+ }
+ explicit PlasmaCreateAndSealRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaCreateAndSealRequestBuilder &operator=(const PlasmaCreateAndSealRequestBuilder &);
+ flatbuffers::Offset<PlasmaCreateAndSealRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaCreateAndSealRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaCreateAndSealRequest> CreatePlasmaCreateAndSealRequest(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> object_id = 0,
+ bool evict_if_full = false,
+ flatbuffers::Offset<flatbuffers::String> data = 0,
+ flatbuffers::Offset<flatbuffers::String> metadata = 0,
+ flatbuffers::Offset<flatbuffers::String> digest = 0) {
+ PlasmaCreateAndSealRequestBuilder builder_(_fbb);
+ builder_.add_digest(digest);
+ builder_.add_metadata(metadata);
+ builder_.add_data(data);
+ builder_.add_object_id(object_id);
+ builder_.add_evict_if_full(evict_if_full);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaCreateAndSealRequest> CreatePlasmaCreateAndSealRequestDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *object_id = nullptr,
+ bool evict_if_full = false,
+ const char *data = nullptr,
+ const char *metadata = nullptr,
+ const char *digest = nullptr) {
+ auto object_id__ = object_id ? _fbb.CreateString(object_id) : 0;
+ auto data__ = data ? _fbb.CreateString(data) : 0;
+ auto metadata__ = metadata ? _fbb.CreateString(metadata) : 0;
+ auto digest__ = digest ? _fbb.CreateString(digest) : 0;
+ return plasma::flatbuf::CreatePlasmaCreateAndSealRequest(
+ _fbb,
+ object_id__,
+ evict_if_full,
+ data__,
+ metadata__,
+ digest__);
+}
+
+flatbuffers::Offset<PlasmaCreateAndSealRequest> CreatePlasmaCreateAndSealRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaCreateAndSealReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaCreateAndSealReply TableType;
+ plasma::flatbuf::PlasmaError error;
+ PlasmaCreateAndSealReplyT()
+ : error(plasma::flatbuf::PlasmaError::OK) {
+ }
+};
+
+struct PlasmaCreateAndSealReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaCreateAndSealReplyT NativeTableType;
+ typedef PlasmaCreateAndSealReplyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_ERROR = 4
+ };
+ plasma::flatbuf::PlasmaError error() const {
+ return static_cast<plasma::flatbuf::PlasmaError>(GetField<int32_t>(VT_ERROR, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_ERROR) &&
+ verifier.EndTable();
+ }
+ PlasmaCreateAndSealReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaCreateAndSealReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaCreateAndSealReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaCreateAndSealReplyBuilder {
+ typedef PlasmaCreateAndSealReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_error(plasma::flatbuf::PlasmaError error) {
+ fbb_.AddElement<int32_t>(PlasmaCreateAndSealReply::VT_ERROR, static_cast<int32_t>(error), 0);
+ }
+ explicit PlasmaCreateAndSealReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaCreateAndSealReplyBuilder &operator=(const PlasmaCreateAndSealReplyBuilder &);
+ flatbuffers::Offset<PlasmaCreateAndSealReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaCreateAndSealReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaCreateAndSealReply> CreatePlasmaCreateAndSealReply(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ plasma::flatbuf::PlasmaError error = plasma::flatbuf::PlasmaError::OK) {
+ PlasmaCreateAndSealReplyBuilder builder_(_fbb);
+ builder_.add_error(error);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<PlasmaCreateAndSealReply> CreatePlasmaCreateAndSealReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaCreateAndSealBatchRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaCreateAndSealBatchRequest TableType;
+ std::vector<std::string> object_ids;
+ bool evict_if_full;
+ std::vector<std::string> data;
+ std::vector<std::string> metadata;
+ std::vector<std::string> digest;
+ PlasmaCreateAndSealBatchRequestT()
+ : evict_if_full(false) {
+ }
+};
+
+struct PlasmaCreateAndSealBatchRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaCreateAndSealBatchRequestT NativeTableType;
+ typedef PlasmaCreateAndSealBatchRequestBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_IDS = 4,
+ VT_EVICT_IF_FULL = 6,
+ VT_DATA = 8,
+ VT_METADATA = 10,
+ VT_DIGEST = 12
+ };
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *object_ids() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_OBJECT_IDS);
+ }
+ bool evict_if_full() const {
+ return GetField<uint8_t>(VT_EVICT_IF_FULL, 0) != 0;
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *data() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_DATA);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *metadata() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_METADATA);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *digest() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_DIGEST);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_IDS) &&
+ verifier.VerifyVector(object_ids()) &&
+ verifier.VerifyVectorOfStrings(object_ids()) &&
+ VerifyField<uint8_t>(verifier, VT_EVICT_IF_FULL) &&
+ VerifyOffset(verifier, VT_DATA) &&
+ verifier.VerifyVector(data()) &&
+ verifier.VerifyVectorOfStrings(data()) &&
+ VerifyOffset(verifier, VT_METADATA) &&
+ verifier.VerifyVector(metadata()) &&
+ verifier.VerifyVectorOfStrings(metadata()) &&
+ VerifyOffset(verifier, VT_DIGEST) &&
+ verifier.VerifyVector(digest()) &&
+ verifier.VerifyVectorOfStrings(digest()) &&
+ verifier.EndTable();
+ }
+ PlasmaCreateAndSealBatchRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaCreateAndSealBatchRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaCreateAndSealBatchRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealBatchRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaCreateAndSealBatchRequestBuilder {
+ typedef PlasmaCreateAndSealBatchRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_ids(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> object_ids) {
+ fbb_.AddOffset(PlasmaCreateAndSealBatchRequest::VT_OBJECT_IDS, object_ids);
+ }
+ void add_evict_if_full(bool evict_if_full) {
+ fbb_.AddElement<uint8_t>(PlasmaCreateAndSealBatchRequest::VT_EVICT_IF_FULL, static_cast<uint8_t>(evict_if_full), 0);
+ }
+ void add_data(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> data) {
+ fbb_.AddOffset(PlasmaCreateAndSealBatchRequest::VT_DATA, data);
+ }
+ void add_metadata(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> metadata) {
+ fbb_.AddOffset(PlasmaCreateAndSealBatchRequest::VT_METADATA, metadata);
+ }
+ void add_digest(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> digest) {
+ fbb_.AddOffset(PlasmaCreateAndSealBatchRequest::VT_DIGEST, digest);
+ }
+ explicit PlasmaCreateAndSealBatchRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaCreateAndSealBatchRequestBuilder &operator=(const PlasmaCreateAndSealBatchRequestBuilder &);
+ flatbuffers::Offset<PlasmaCreateAndSealBatchRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaCreateAndSealBatchRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaCreateAndSealBatchRequest> CreatePlasmaCreateAndSealBatchRequest(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> object_ids = 0,
+ bool evict_if_full = false,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> data = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> metadata = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> digest = 0) {
+ PlasmaCreateAndSealBatchRequestBuilder builder_(_fbb);
+ builder_.add_digest(digest);
+ builder_.add_metadata(metadata);
+ builder_.add_data(data);
+ builder_.add_object_ids(object_ids);
+ builder_.add_evict_if_full(evict_if_full);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaCreateAndSealBatchRequest> CreatePlasmaCreateAndSealBatchRequestDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *object_ids = nullptr,
+ bool evict_if_full = false,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *data = nullptr,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *metadata = nullptr,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *digest = nullptr) {
+ auto object_ids__ = object_ids ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*object_ids) : 0;
+ auto data__ = data ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*data) : 0;
+ auto metadata__ = metadata ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*metadata) : 0;
+ auto digest__ = digest ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*digest) : 0;
+ return plasma::flatbuf::CreatePlasmaCreateAndSealBatchRequest(
+ _fbb,
+ object_ids__,
+ evict_if_full,
+ data__,
+ metadata__,
+ digest__);
+}
+
+flatbuffers::Offset<PlasmaCreateAndSealBatchRequest> CreatePlasmaCreateAndSealBatchRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealBatchRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaCreateAndSealBatchReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaCreateAndSealBatchReply TableType;
+ plasma::flatbuf::PlasmaError error;
+ PlasmaCreateAndSealBatchReplyT()
+ : error(plasma::flatbuf::PlasmaError::OK) {
+ }
+};
+
+struct PlasmaCreateAndSealBatchReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaCreateAndSealBatchReplyT NativeTableType;
+ typedef PlasmaCreateAndSealBatchReplyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_ERROR = 4
+ };
+ plasma::flatbuf::PlasmaError error() const {
+ return static_cast<plasma::flatbuf::PlasmaError>(GetField<int32_t>(VT_ERROR, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_ERROR) &&
+ verifier.EndTable();
+ }
+ PlasmaCreateAndSealBatchReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaCreateAndSealBatchReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaCreateAndSealBatchReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealBatchReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaCreateAndSealBatchReplyBuilder {
+ typedef PlasmaCreateAndSealBatchReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_error(plasma::flatbuf::PlasmaError error) {
+ fbb_.AddElement<int32_t>(PlasmaCreateAndSealBatchReply::VT_ERROR, static_cast<int32_t>(error), 0);
+ }
+ explicit PlasmaCreateAndSealBatchReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaCreateAndSealBatchReplyBuilder &operator=(const PlasmaCreateAndSealBatchReplyBuilder &);
+ flatbuffers::Offset<PlasmaCreateAndSealBatchReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaCreateAndSealBatchReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaCreateAndSealBatchReply> CreatePlasmaCreateAndSealBatchReply(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ plasma::flatbuf::PlasmaError error = plasma::flatbuf::PlasmaError::OK) {
+ PlasmaCreateAndSealBatchReplyBuilder builder_(_fbb);
+ builder_.add_error(error);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<PlasmaCreateAndSealBatchReply> CreatePlasmaCreateAndSealBatchReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealBatchReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaAbortRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaAbortRequest TableType;
+ std::string object_id;
+ PlasmaAbortRequestT() {
+ }
+};
+
+struct PlasmaAbortRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaAbortRequestT NativeTableType;
+ typedef PlasmaAbortRequestBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_ID = 4
+ };
+ const flatbuffers::String *object_id() const {
+ return GetPointer<const flatbuffers::String *>(VT_OBJECT_ID);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_ID) &&
+ verifier.VerifyString(object_id()) &&
+ verifier.EndTable();
+ }
+ PlasmaAbortRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaAbortRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaAbortRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaAbortRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaAbortRequestBuilder {
+ typedef PlasmaAbortRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_id(flatbuffers::Offset<flatbuffers::String> object_id) {
+ fbb_.AddOffset(PlasmaAbortRequest::VT_OBJECT_ID, object_id);
+ }
+ explicit PlasmaAbortRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaAbortRequestBuilder &operator=(const PlasmaAbortRequestBuilder &);
+ flatbuffers::Offset<PlasmaAbortRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaAbortRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaAbortRequest> CreatePlasmaAbortRequest(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> object_id = 0) {
+ PlasmaAbortRequestBuilder builder_(_fbb);
+ builder_.add_object_id(object_id);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaAbortRequest> CreatePlasmaAbortRequestDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *object_id = nullptr) {
+ auto object_id__ = object_id ? _fbb.CreateString(object_id) : 0;
+ return plasma::flatbuf::CreatePlasmaAbortRequest(
+ _fbb,
+ object_id__);
+}
+
+flatbuffers::Offset<PlasmaAbortRequest> CreatePlasmaAbortRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaAbortRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaAbortReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaAbortReply TableType;
+ std::string object_id;
+ PlasmaAbortReplyT() {
+ }
+};
+
+struct PlasmaAbortReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaAbortReplyT NativeTableType;
+ typedef PlasmaAbortReplyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_ID = 4
+ };
+ const flatbuffers::String *object_id() const {
+ return GetPointer<const flatbuffers::String *>(VT_OBJECT_ID);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_ID) &&
+ verifier.VerifyString(object_id()) &&
+ verifier.EndTable();
+ }
+ PlasmaAbortReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaAbortReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaAbortReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaAbortReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaAbortReplyBuilder {
+ typedef PlasmaAbortReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_id(flatbuffers::Offset<flatbuffers::String> object_id) {
+ fbb_.AddOffset(PlasmaAbortReply::VT_OBJECT_ID, object_id);
+ }
+ explicit PlasmaAbortReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaAbortReplyBuilder &operator=(const PlasmaAbortReplyBuilder &);
+ flatbuffers::Offset<PlasmaAbortReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaAbortReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaAbortReply> CreatePlasmaAbortReply(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> object_id = 0) {
+ PlasmaAbortReplyBuilder builder_(_fbb);
+ builder_.add_object_id(object_id);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaAbortReply> CreatePlasmaAbortReplyDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *object_id = nullptr) {
+ auto object_id__ = object_id ? _fbb.CreateString(object_id) : 0;
+ return plasma::flatbuf::CreatePlasmaAbortReply(
+ _fbb,
+ object_id__);
+}
+
+flatbuffers::Offset<PlasmaAbortReply> CreatePlasmaAbortReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaAbortReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaSealRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaSealRequest TableType;
+ std::string object_id;
+ std::string digest;
+ PlasmaSealRequestT() {
+ }
+};
+
+struct PlasmaSealRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaSealRequestT NativeTableType;
+ typedef PlasmaSealRequestBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_ID = 4,
+ VT_DIGEST = 6
+ };
+ const flatbuffers::String *object_id() const {
+ return GetPointer<const flatbuffers::String *>(VT_OBJECT_ID);
+ }
+ const flatbuffers::String *digest() const {
+ return GetPointer<const flatbuffers::String *>(VT_DIGEST);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_ID) &&
+ verifier.VerifyString(object_id()) &&
+ VerifyOffset(verifier, VT_DIGEST) &&
+ verifier.VerifyString(digest()) &&
+ verifier.EndTable();
+ }
+ PlasmaSealRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaSealRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaSealRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSealRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaSealRequestBuilder {
+ typedef PlasmaSealRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_id(flatbuffers::Offset<flatbuffers::String> object_id) {
+ fbb_.AddOffset(PlasmaSealRequest::VT_OBJECT_ID, object_id);
+ }
+ void add_digest(flatbuffers::Offset<flatbuffers::String> digest) {
+ fbb_.AddOffset(PlasmaSealRequest::VT_DIGEST, digest);
+ }
+ explicit PlasmaSealRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaSealRequestBuilder &operator=(const PlasmaSealRequestBuilder &);
+ flatbuffers::Offset<PlasmaSealRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaSealRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaSealRequest> CreatePlasmaSealRequest(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> object_id = 0,
+ flatbuffers::Offset<flatbuffers::String> digest = 0) {
+ PlasmaSealRequestBuilder builder_(_fbb);
+ builder_.add_digest(digest);
+ builder_.add_object_id(object_id);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaSealRequest> CreatePlasmaSealRequestDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *object_id = nullptr,
+ const char *digest = nullptr) {
+ auto object_id__ = object_id ? _fbb.CreateString(object_id) : 0;
+ auto digest__ = digest ? _fbb.CreateString(digest) : 0;
+ return plasma::flatbuf::CreatePlasmaSealRequest(
+ _fbb,
+ object_id__,
+ digest__);
+}
+
+flatbuffers::Offset<PlasmaSealRequest> CreatePlasmaSealRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSealRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaSealReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaSealReply TableType;
+ std::string object_id;
+ plasma::flatbuf::PlasmaError error;
+ PlasmaSealReplyT()
+ : error(plasma::flatbuf::PlasmaError::OK) {
+ }
+};
+
+struct PlasmaSealReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaSealReplyT NativeTableType;
+ typedef PlasmaSealReplyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_ID = 4,
+ VT_ERROR = 6
+ };
+ const flatbuffers::String *object_id() const {
+ return GetPointer<const flatbuffers::String *>(VT_OBJECT_ID);
+ }
+ plasma::flatbuf::PlasmaError error() const {
+ return static_cast<plasma::flatbuf::PlasmaError>(GetField<int32_t>(VT_ERROR, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_ID) &&
+ verifier.VerifyString(object_id()) &&
+ VerifyField<int32_t>(verifier, VT_ERROR) &&
+ verifier.EndTable();
+ }
+ PlasmaSealReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaSealReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaSealReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSealReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaSealReplyBuilder {
+ typedef PlasmaSealReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_id(flatbuffers::Offset<flatbuffers::String> object_id) {
+ fbb_.AddOffset(PlasmaSealReply::VT_OBJECT_ID, object_id);
+ }
+ void add_error(plasma::flatbuf::PlasmaError error) {
+ fbb_.AddElement<int32_t>(PlasmaSealReply::VT_ERROR, static_cast<int32_t>(error), 0);
+ }
+ explicit PlasmaSealReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaSealReplyBuilder &operator=(const PlasmaSealReplyBuilder &);
+ flatbuffers::Offset<PlasmaSealReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaSealReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaSealReply> CreatePlasmaSealReply(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> object_id = 0,
+ plasma::flatbuf::PlasmaError error = plasma::flatbuf::PlasmaError::OK) {
+ PlasmaSealReplyBuilder builder_(_fbb);
+ builder_.add_error(error);
+ builder_.add_object_id(object_id);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaSealReply> CreatePlasmaSealReplyDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *object_id = nullptr,
+ plasma::flatbuf::PlasmaError error = plasma::flatbuf::PlasmaError::OK) {
+ auto object_id__ = object_id ? _fbb.CreateString(object_id) : 0;
+ return plasma::flatbuf::CreatePlasmaSealReply(
+ _fbb,
+ object_id__,
+ error);
+}
+
+flatbuffers::Offset<PlasmaSealReply> CreatePlasmaSealReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSealReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaGetRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaGetRequest TableType;
+ std::vector<std::string> object_ids;
+ int64_t timeout_ms;
+ PlasmaGetRequestT()
+ : timeout_ms(0) {
+ }
+};
+
+struct PlasmaGetRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaGetRequestT NativeTableType;
+ typedef PlasmaGetRequestBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_IDS = 4,
+ VT_TIMEOUT_MS = 6
+ };
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *object_ids() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_OBJECT_IDS);
+ }
+ int64_t timeout_ms() const {
+ return GetField<int64_t>(VT_TIMEOUT_MS, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_IDS) &&
+ verifier.VerifyVector(object_ids()) &&
+ verifier.VerifyVectorOfStrings(object_ids()) &&
+ VerifyField<int64_t>(verifier, VT_TIMEOUT_MS) &&
+ verifier.EndTable();
+ }
+ PlasmaGetRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaGetRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaGetRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaGetRequestBuilder {
+ typedef PlasmaGetRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_ids(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> object_ids) {
+ fbb_.AddOffset(PlasmaGetRequest::VT_OBJECT_IDS, object_ids);
+ }
+ void add_timeout_ms(int64_t timeout_ms) {
+ fbb_.AddElement<int64_t>(PlasmaGetRequest::VT_TIMEOUT_MS, timeout_ms, 0);
+ }
+ explicit PlasmaGetRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaGetRequestBuilder &operator=(const PlasmaGetRequestBuilder &);
+ flatbuffers::Offset<PlasmaGetRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaGetRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaGetRequest> CreatePlasmaGetRequest(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> object_ids = 0,
+ int64_t timeout_ms = 0) {
+ PlasmaGetRequestBuilder builder_(_fbb);
+ builder_.add_timeout_ms(timeout_ms);
+ builder_.add_object_ids(object_ids);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaGetRequest> CreatePlasmaGetRequestDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *object_ids = nullptr,
+ int64_t timeout_ms = 0) {
+ auto object_ids__ = object_ids ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*object_ids) : 0;
+ return plasma::flatbuf::CreatePlasmaGetRequest(
+ _fbb,
+ object_ids__,
+ timeout_ms);
+}
+
+flatbuffers::Offset<PlasmaGetRequest> CreatePlasmaGetRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaGetReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaGetReply TableType;
+ std::vector<std::string> object_ids;
+ std::vector<plasma::flatbuf::PlasmaObjectSpec> plasma_objects;
+ std::vector<int32_t> store_fds;
+ std::vector<int64_t> mmap_sizes;
+ std::vector<std::unique_ptr<plasma::flatbuf::CudaHandleT>> handles;
+ PlasmaGetReplyT() {
+ }
+};
+
+struct PlasmaGetReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaGetReplyT NativeTableType;
+ typedef PlasmaGetReplyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_IDS = 4,
+ VT_PLASMA_OBJECTS = 6,
+ VT_STORE_FDS = 8,
+ VT_MMAP_SIZES = 10,
+ VT_HANDLES = 12
+ };
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *object_ids() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_OBJECT_IDS);
+ }
+ const flatbuffers::Vector<const plasma::flatbuf::PlasmaObjectSpec *> *plasma_objects() const {
+ return GetPointer<const flatbuffers::Vector<const plasma::flatbuf::PlasmaObjectSpec *> *>(VT_PLASMA_OBJECTS);
+ }
+ const flatbuffers::Vector<int32_t> *store_fds() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_STORE_FDS);
+ }
+ const flatbuffers::Vector<int64_t> *mmap_sizes() const {
+ return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_MMAP_SIZES);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<plasma::flatbuf::CudaHandle>> *handles() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<plasma::flatbuf::CudaHandle>> *>(VT_HANDLES);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_IDS) &&
+ verifier.VerifyVector(object_ids()) &&
+ verifier.VerifyVectorOfStrings(object_ids()) &&
+ VerifyOffset(verifier, VT_PLASMA_OBJECTS) &&
+ verifier.VerifyVector(plasma_objects()) &&
+ VerifyOffset(verifier, VT_STORE_FDS) &&
+ verifier.VerifyVector(store_fds()) &&
+ VerifyOffset(verifier, VT_MMAP_SIZES) &&
+ verifier.VerifyVector(mmap_sizes()) &&
+ VerifyOffset(verifier, VT_HANDLES) &&
+ verifier.VerifyVector(handles()) &&
+ verifier.VerifyVectorOfTables(handles()) &&
+ verifier.EndTable();
+ }
+ PlasmaGetReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaGetReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaGetReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaGetReplyBuilder {
+ typedef PlasmaGetReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_ids(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> object_ids) {
+ fbb_.AddOffset(PlasmaGetReply::VT_OBJECT_IDS, object_ids);
+ }
+ void add_plasma_objects(flatbuffers::Offset<flatbuffers::Vector<const plasma::flatbuf::PlasmaObjectSpec *>> plasma_objects) {
+ fbb_.AddOffset(PlasmaGetReply::VT_PLASMA_OBJECTS, plasma_objects);
+ }
+ void add_store_fds(flatbuffers::Offset<flatbuffers::Vector<int32_t>> store_fds) {
+ fbb_.AddOffset(PlasmaGetReply::VT_STORE_FDS, store_fds);
+ }
+ void add_mmap_sizes(flatbuffers::Offset<flatbuffers::Vector<int64_t>> mmap_sizes) {
+ fbb_.AddOffset(PlasmaGetReply::VT_MMAP_SIZES, mmap_sizes);
+ }
+ void add_handles(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<plasma::flatbuf::CudaHandle>>> handles) {
+ fbb_.AddOffset(PlasmaGetReply::VT_HANDLES, handles);
+ }
+ explicit PlasmaGetReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaGetReplyBuilder &operator=(const PlasmaGetReplyBuilder &);
+ flatbuffers::Offset<PlasmaGetReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaGetReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaGetReply> CreatePlasmaGetReply(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> object_ids = 0,
+ flatbuffers::Offset<flatbuffers::Vector<const plasma::flatbuf::PlasmaObjectSpec *>> plasma_objects = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> store_fds = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int64_t>> mmap_sizes = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<plasma::flatbuf::CudaHandle>>> handles = 0) {
+ PlasmaGetReplyBuilder builder_(_fbb);
+ builder_.add_handles(handles);
+ builder_.add_mmap_sizes(mmap_sizes);
+ builder_.add_store_fds(store_fds);
+ builder_.add_plasma_objects(plasma_objects);
+ builder_.add_object_ids(object_ids);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaGetReply> CreatePlasmaGetReplyDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *object_ids = nullptr,
+ const std::vector<plasma::flatbuf::PlasmaObjectSpec> *plasma_objects = nullptr,
+ const std::vector<int32_t> *store_fds = nullptr,
+ const std::vector<int64_t> *mmap_sizes = nullptr,
+ const std::vector<flatbuffers::Offset<plasma::flatbuf::CudaHandle>> *handles = nullptr) {
+ auto object_ids__ = object_ids ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*object_ids) : 0;
+ auto plasma_objects__ = plasma_objects ? _fbb.CreateVectorOfStructs<plasma::flatbuf::PlasmaObjectSpec>(*plasma_objects) : 0;
+ auto store_fds__ = store_fds ? _fbb.CreateVector<int32_t>(*store_fds) : 0;
+ auto mmap_sizes__ = mmap_sizes ? _fbb.CreateVector<int64_t>(*mmap_sizes) : 0;
+ auto handles__ = handles ? _fbb.CreateVector<flatbuffers::Offset<plasma::flatbuf::CudaHandle>>(*handles) : 0;
+ return plasma::flatbuf::CreatePlasmaGetReply(
+ _fbb,
+ object_ids__,
+ plasma_objects__,
+ store_fds__,
+ mmap_sizes__,
+ handles__);
+}
+
+flatbuffers::Offset<PlasmaGetReply> CreatePlasmaGetReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaReleaseRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaReleaseRequest TableType;
+ std::string object_id;
+ PlasmaReleaseRequestT() {
+ }
+};
+
+struct PlasmaReleaseRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaReleaseRequestT NativeTableType;
+ typedef PlasmaReleaseRequestBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_ID = 4
+ };
+ const flatbuffers::String *object_id() const {
+ return GetPointer<const flatbuffers::String *>(VT_OBJECT_ID);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_ID) &&
+ verifier.VerifyString(object_id()) &&
+ verifier.EndTable();
+ }
+ PlasmaReleaseRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaReleaseRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaReleaseRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaReleaseRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaReleaseRequestBuilder {
+ typedef PlasmaReleaseRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_id(flatbuffers::Offset<flatbuffers::String> object_id) {
+ fbb_.AddOffset(PlasmaReleaseRequest::VT_OBJECT_ID, object_id);
+ }
+ explicit PlasmaReleaseRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaReleaseRequestBuilder &operator=(const PlasmaReleaseRequestBuilder &);
+ flatbuffers::Offset<PlasmaReleaseRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaReleaseRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaReleaseRequest> CreatePlasmaReleaseRequest(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> object_id = 0) {
+ PlasmaReleaseRequestBuilder builder_(_fbb);
+ builder_.add_object_id(object_id);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaReleaseRequest> CreatePlasmaReleaseRequestDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *object_id = nullptr) {
+ auto object_id__ = object_id ? _fbb.CreateString(object_id) : 0;
+ return plasma::flatbuf::CreatePlasmaReleaseRequest(
+ _fbb,
+ object_id__);
+}
+
+flatbuffers::Offset<PlasmaReleaseRequest> CreatePlasmaReleaseRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaReleaseRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaReleaseReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaReleaseReply TableType;
+ std::string object_id;
+ plasma::flatbuf::PlasmaError error;
+ PlasmaReleaseReplyT()
+ : error(plasma::flatbuf::PlasmaError::OK) {
+ }
+};
+
+struct PlasmaReleaseReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaReleaseReplyT NativeTableType;
+ typedef PlasmaReleaseReplyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_ID = 4,
+ VT_ERROR = 6
+ };
+ const flatbuffers::String *object_id() const {
+ return GetPointer<const flatbuffers::String *>(VT_OBJECT_ID);
+ }
+ plasma::flatbuf::PlasmaError error() const {
+ return static_cast<plasma::flatbuf::PlasmaError>(GetField<int32_t>(VT_ERROR, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_ID) &&
+ verifier.VerifyString(object_id()) &&
+ VerifyField<int32_t>(verifier, VT_ERROR) &&
+ verifier.EndTable();
+ }
+ PlasmaReleaseReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaReleaseReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaReleaseReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaReleaseReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaReleaseReplyBuilder {
+ typedef PlasmaReleaseReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_id(flatbuffers::Offset<flatbuffers::String> object_id) {
+ fbb_.AddOffset(PlasmaReleaseReply::VT_OBJECT_ID, object_id);
+ }
+ void add_error(plasma::flatbuf::PlasmaError error) {
+ fbb_.AddElement<int32_t>(PlasmaReleaseReply::VT_ERROR, static_cast<int32_t>(error), 0);
+ }
+ explicit PlasmaReleaseReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaReleaseReplyBuilder &operator=(const PlasmaReleaseReplyBuilder &);
+ flatbuffers::Offset<PlasmaReleaseReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaReleaseReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaReleaseReply> CreatePlasmaReleaseReply(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> object_id = 0,
+ plasma::flatbuf::PlasmaError error = plasma::flatbuf::PlasmaError::OK) {
+ PlasmaReleaseReplyBuilder builder_(_fbb);
+ builder_.add_error(error);
+ builder_.add_object_id(object_id);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaReleaseReply> CreatePlasmaReleaseReplyDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *object_id = nullptr,
+ plasma::flatbuf::PlasmaError error = plasma::flatbuf::PlasmaError::OK) {
+ auto object_id__ = object_id ? _fbb.CreateString(object_id) : 0;
+ return plasma::flatbuf::CreatePlasmaReleaseReply(
+ _fbb,
+ object_id__,
+ error);
+}
+
+flatbuffers::Offset<PlasmaReleaseReply> CreatePlasmaReleaseReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaReleaseReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaDeleteRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaDeleteRequest TableType;
+ int32_t count;
+ std::vector<std::string> object_ids;
+ PlasmaDeleteRequestT()
+ : count(0) {
+ }
+};
+
+struct PlasmaDeleteRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaDeleteRequestT NativeTableType;
+ typedef PlasmaDeleteRequestBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_COUNT = 4,
+ VT_OBJECT_IDS = 6
+ };
+ int32_t count() const {
+ return GetField<int32_t>(VT_COUNT, 0);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *object_ids() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_OBJECT_IDS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_COUNT) &&
+ VerifyOffset(verifier, VT_OBJECT_IDS) &&
+ verifier.VerifyVector(object_ids()) &&
+ verifier.VerifyVectorOfStrings(object_ids()) &&
+ verifier.EndTable();
+ }
+ PlasmaDeleteRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaDeleteRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaDeleteRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDeleteRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaDeleteRequestBuilder {
+ typedef PlasmaDeleteRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_count(int32_t count) {
+ fbb_.AddElement<int32_t>(PlasmaDeleteRequest::VT_COUNT, count, 0);
+ }
+ void add_object_ids(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> object_ids) {
+ fbb_.AddOffset(PlasmaDeleteRequest::VT_OBJECT_IDS, object_ids);
+ }
+ explicit PlasmaDeleteRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaDeleteRequestBuilder &operator=(const PlasmaDeleteRequestBuilder &);
+ flatbuffers::Offset<PlasmaDeleteRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaDeleteRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaDeleteRequest> CreatePlasmaDeleteRequest(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t count = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> object_ids = 0) {
+ PlasmaDeleteRequestBuilder builder_(_fbb);
+ builder_.add_object_ids(object_ids);
+ builder_.add_count(count);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaDeleteRequest> CreatePlasmaDeleteRequestDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t count = 0,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *object_ids = nullptr) {
+ auto object_ids__ = object_ids ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*object_ids) : 0;
+ return plasma::flatbuf::CreatePlasmaDeleteRequest(
+ _fbb,
+ count,
+ object_ids__);
+}
+
+flatbuffers::Offset<PlasmaDeleteRequest> CreatePlasmaDeleteRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDeleteRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaDeleteReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaDeleteReply TableType;
+ int32_t count;
+ std::vector<std::string> object_ids;
+ std::vector<plasma::flatbuf::PlasmaError> errors;
+ PlasmaDeleteReplyT()
+ : count(0) {
+ }
+};
+
+struct PlasmaDeleteReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaDeleteReplyT NativeTableType;
+ typedef PlasmaDeleteReplyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_COUNT = 4,
+ VT_OBJECT_IDS = 6,
+ VT_ERRORS = 8
+ };
+ int32_t count() const {
+ return GetField<int32_t>(VT_COUNT, 0);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *object_ids() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_OBJECT_IDS);
+ }
+ const flatbuffers::Vector<int32_t> *errors() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_ERRORS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_COUNT) &&
+ VerifyOffset(verifier, VT_OBJECT_IDS) &&
+ verifier.VerifyVector(object_ids()) &&
+ verifier.VerifyVectorOfStrings(object_ids()) &&
+ VerifyOffset(verifier, VT_ERRORS) &&
+ verifier.VerifyVector(errors()) &&
+ verifier.EndTable();
+ }
+ PlasmaDeleteReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaDeleteReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaDeleteReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDeleteReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaDeleteReplyBuilder {
+ typedef PlasmaDeleteReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_count(int32_t count) {
+ fbb_.AddElement<int32_t>(PlasmaDeleteReply::VT_COUNT, count, 0);
+ }
+ void add_object_ids(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> object_ids) {
+ fbb_.AddOffset(PlasmaDeleteReply::VT_OBJECT_IDS, object_ids);
+ }
+ void add_errors(flatbuffers::Offset<flatbuffers::Vector<int32_t>> errors) {
+ fbb_.AddOffset(PlasmaDeleteReply::VT_ERRORS, errors);
+ }
+ explicit PlasmaDeleteReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaDeleteReplyBuilder &operator=(const PlasmaDeleteReplyBuilder &);
+ flatbuffers::Offset<PlasmaDeleteReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaDeleteReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaDeleteReply> CreatePlasmaDeleteReply(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t count = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> object_ids = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> errors = 0) {
+ PlasmaDeleteReplyBuilder builder_(_fbb);
+ builder_.add_errors(errors);
+ builder_.add_object_ids(object_ids);
+ builder_.add_count(count);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaDeleteReply> CreatePlasmaDeleteReplyDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t count = 0,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *object_ids = nullptr,
+ const std::vector<int32_t> *errors = nullptr) {
+ auto object_ids__ = object_ids ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*object_ids) : 0;
+ auto errors__ = errors ? _fbb.CreateVector<int32_t>(*errors) : 0;
+ return plasma::flatbuf::CreatePlasmaDeleteReply(
+ _fbb,
+ count,
+ object_ids__,
+ errors__);
+}
+
+flatbuffers::Offset<PlasmaDeleteReply> CreatePlasmaDeleteReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDeleteReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaContainsRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaContainsRequest TableType;
+ std::string object_id;
+ PlasmaContainsRequestT() {
+ }
+};
+
+struct PlasmaContainsRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaContainsRequestT NativeTableType;
+ typedef PlasmaContainsRequestBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_ID = 4
+ };
+ const flatbuffers::String *object_id() const {
+ return GetPointer<const flatbuffers::String *>(VT_OBJECT_ID);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_ID) &&
+ verifier.VerifyString(object_id()) &&
+ verifier.EndTable();
+ }
+ PlasmaContainsRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaContainsRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaContainsRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaContainsRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaContainsRequestBuilder {
+ typedef PlasmaContainsRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_id(flatbuffers::Offset<flatbuffers::String> object_id) {
+ fbb_.AddOffset(PlasmaContainsRequest::VT_OBJECT_ID, object_id);
+ }
+ explicit PlasmaContainsRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaContainsRequestBuilder &operator=(const PlasmaContainsRequestBuilder &);
+ flatbuffers::Offset<PlasmaContainsRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaContainsRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaContainsRequest> CreatePlasmaContainsRequest(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> object_id = 0) {
+ PlasmaContainsRequestBuilder builder_(_fbb);
+ builder_.add_object_id(object_id);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaContainsRequest> CreatePlasmaContainsRequestDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *object_id = nullptr) {
+ auto object_id__ = object_id ? _fbb.CreateString(object_id) : 0;
+ return plasma::flatbuf::CreatePlasmaContainsRequest(
+ _fbb,
+ object_id__);
+}
+
+flatbuffers::Offset<PlasmaContainsRequest> CreatePlasmaContainsRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaContainsRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaContainsReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaContainsReply TableType;
+ std::string object_id;
+ int32_t has_object;
+ PlasmaContainsReplyT()
+ : has_object(0) {
+ }
+};
+
+struct PlasmaContainsReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaContainsReplyT NativeTableType;
+ typedef PlasmaContainsReplyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_ID = 4,
+ VT_HAS_OBJECT = 6
+ };
+ const flatbuffers::String *object_id() const {
+ return GetPointer<const flatbuffers::String *>(VT_OBJECT_ID);
+ }
+ int32_t has_object() const {
+ return GetField<int32_t>(VT_HAS_OBJECT, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_ID) &&
+ verifier.VerifyString(object_id()) &&
+ VerifyField<int32_t>(verifier, VT_HAS_OBJECT) &&
+ verifier.EndTable();
+ }
+ PlasmaContainsReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaContainsReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaContainsReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaContainsReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaContainsReplyBuilder {
+ typedef PlasmaContainsReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_id(flatbuffers::Offset<flatbuffers::String> object_id) {
+ fbb_.AddOffset(PlasmaContainsReply::VT_OBJECT_ID, object_id);
+ }
+ void add_has_object(int32_t has_object) {
+ fbb_.AddElement<int32_t>(PlasmaContainsReply::VT_HAS_OBJECT, has_object, 0);
+ }
+ explicit PlasmaContainsReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaContainsReplyBuilder &operator=(const PlasmaContainsReplyBuilder &);
+ flatbuffers::Offset<PlasmaContainsReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaContainsReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaContainsReply> CreatePlasmaContainsReply(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> object_id = 0,
+ int32_t has_object = 0) {
+ PlasmaContainsReplyBuilder builder_(_fbb);
+ builder_.add_has_object(has_object);
+ builder_.add_object_id(object_id);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaContainsReply> CreatePlasmaContainsReplyDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *object_id = nullptr,
+ int32_t has_object = 0) {
+ auto object_id__ = object_id ? _fbb.CreateString(object_id) : 0;
+ return plasma::flatbuf::CreatePlasmaContainsReply(
+ _fbb,
+ object_id__,
+ has_object);
+}
+
+flatbuffers::Offset<PlasmaContainsReply> CreatePlasmaContainsReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaContainsReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaListRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaListRequest TableType;
+ PlasmaListRequestT() {
+ }
+};
+
+struct PlasmaListRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaListRequestT NativeTableType;
+ typedef PlasmaListRequestBuilder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ PlasmaListRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaListRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaListRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaListRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaListRequestBuilder {
+ typedef PlasmaListRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit PlasmaListRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaListRequestBuilder &operator=(const PlasmaListRequestBuilder &);
+ flatbuffers::Offset<PlasmaListRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaListRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaListRequest> CreatePlasmaListRequest(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ PlasmaListRequestBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<PlasmaListRequest> CreatePlasmaListRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaListRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaListReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaListReply TableType;
+ std::vector<std::unique_ptr<plasma::flatbuf::ObjectInfoT>> objects;
+ PlasmaListReplyT() {
+ }
+};
+
+struct PlasmaListReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaListReplyT NativeTableType;
+ typedef PlasmaListReplyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECTS = 4
+ };
+ const flatbuffers::Vector<flatbuffers::Offset<plasma::flatbuf::ObjectInfo>> *objects() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<plasma::flatbuf::ObjectInfo>> *>(VT_OBJECTS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECTS) &&
+ verifier.VerifyVector(objects()) &&
+ verifier.VerifyVectorOfTables(objects()) &&
+ verifier.EndTable();
+ }
+ PlasmaListReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaListReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaListReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaListReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaListReplyBuilder {
+ typedef PlasmaListReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_objects(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<plasma::flatbuf::ObjectInfo>>> objects) {
+ fbb_.AddOffset(PlasmaListReply::VT_OBJECTS, objects);
+ }
+ explicit PlasmaListReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaListReplyBuilder &operator=(const PlasmaListReplyBuilder &);
+ flatbuffers::Offset<PlasmaListReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaListReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaListReply> CreatePlasmaListReply(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<plasma::flatbuf::ObjectInfo>>> objects = 0) {
+ PlasmaListReplyBuilder builder_(_fbb);
+ builder_.add_objects(objects);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaListReply> CreatePlasmaListReplyDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<plasma::flatbuf::ObjectInfo>> *objects = nullptr) {
+ auto objects__ = objects ? _fbb.CreateVector<flatbuffers::Offset<plasma::flatbuf::ObjectInfo>>(*objects) : 0;
+ return plasma::flatbuf::CreatePlasmaListReply(
+ _fbb,
+ objects__);
+}
+
+flatbuffers::Offset<PlasmaListReply> CreatePlasmaListReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaListReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaConnectRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaConnectRequest TableType;
+ PlasmaConnectRequestT() {
+ }
+};
+
+struct PlasmaConnectRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaConnectRequestT NativeTableType;
+ typedef PlasmaConnectRequestBuilder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ PlasmaConnectRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaConnectRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaConnectRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaConnectRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaConnectRequestBuilder {
+ typedef PlasmaConnectRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit PlasmaConnectRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaConnectRequestBuilder &operator=(const PlasmaConnectRequestBuilder &);
+ flatbuffers::Offset<PlasmaConnectRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaConnectRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaConnectRequest> CreatePlasmaConnectRequest(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ PlasmaConnectRequestBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<PlasmaConnectRequest> CreatePlasmaConnectRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaConnectRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaConnectReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaConnectReply TableType;
+ int64_t memory_capacity;
+ PlasmaConnectReplyT()
+ : memory_capacity(0) {
+ }
+};
+
+struct PlasmaConnectReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaConnectReplyT NativeTableType;
+ typedef PlasmaConnectReplyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_MEMORY_CAPACITY = 4
+ };
+ int64_t memory_capacity() const {
+ return GetField<int64_t>(VT_MEMORY_CAPACITY, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int64_t>(verifier, VT_MEMORY_CAPACITY) &&
+ verifier.EndTable();
+ }
+ PlasmaConnectReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaConnectReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaConnectReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaConnectReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaConnectReplyBuilder {
+ typedef PlasmaConnectReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_memory_capacity(int64_t memory_capacity) {
+ fbb_.AddElement<int64_t>(PlasmaConnectReply::VT_MEMORY_CAPACITY, memory_capacity, 0);
+ }
+ explicit PlasmaConnectReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaConnectReplyBuilder &operator=(const PlasmaConnectReplyBuilder &);
+ flatbuffers::Offset<PlasmaConnectReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaConnectReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaConnectReply> CreatePlasmaConnectReply(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int64_t memory_capacity = 0) {
+ PlasmaConnectReplyBuilder builder_(_fbb);
+ builder_.add_memory_capacity(memory_capacity);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<PlasmaConnectReply> CreatePlasmaConnectReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaConnectReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaEvictRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaEvictRequest TableType;
+ uint64_t num_bytes;
+ PlasmaEvictRequestT()
+ : num_bytes(0) {
+ }
+};
+
+struct PlasmaEvictRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaEvictRequestT NativeTableType;
+ typedef PlasmaEvictRequestBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_NUM_BYTES = 4
+ };
+ uint64_t num_bytes() const {
+ return GetField<uint64_t>(VT_NUM_BYTES, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint64_t>(verifier, VT_NUM_BYTES) &&
+ verifier.EndTable();
+ }
+ PlasmaEvictRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaEvictRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaEvictRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaEvictRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaEvictRequestBuilder {
+ typedef PlasmaEvictRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_num_bytes(uint64_t num_bytes) {
+ fbb_.AddElement<uint64_t>(PlasmaEvictRequest::VT_NUM_BYTES, num_bytes, 0);
+ }
+ explicit PlasmaEvictRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaEvictRequestBuilder &operator=(const PlasmaEvictRequestBuilder &);
+ flatbuffers::Offset<PlasmaEvictRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaEvictRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaEvictRequest> CreatePlasmaEvictRequest(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ uint64_t num_bytes = 0) {
+ PlasmaEvictRequestBuilder builder_(_fbb);
+ builder_.add_num_bytes(num_bytes);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<PlasmaEvictRequest> CreatePlasmaEvictRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaEvictRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaEvictReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaEvictReply TableType;
+ uint64_t num_bytes;
+ PlasmaEvictReplyT()
+ : num_bytes(0) {
+ }
+};
+
+struct PlasmaEvictReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaEvictReplyT NativeTableType;
+ typedef PlasmaEvictReplyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_NUM_BYTES = 4
+ };
+ uint64_t num_bytes() const {
+ return GetField<uint64_t>(VT_NUM_BYTES, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint64_t>(verifier, VT_NUM_BYTES) &&
+ verifier.EndTable();
+ }
+ PlasmaEvictReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaEvictReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaEvictReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaEvictReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaEvictReplyBuilder {
+ typedef PlasmaEvictReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_num_bytes(uint64_t num_bytes) {
+ fbb_.AddElement<uint64_t>(PlasmaEvictReply::VT_NUM_BYTES, num_bytes, 0);
+ }
+ explicit PlasmaEvictReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaEvictReplyBuilder &operator=(const PlasmaEvictReplyBuilder &);
+ flatbuffers::Offset<PlasmaEvictReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaEvictReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaEvictReply> CreatePlasmaEvictReply(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ uint64_t num_bytes = 0) {
+ PlasmaEvictReplyBuilder builder_(_fbb);
+ builder_.add_num_bytes(num_bytes);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<PlasmaEvictReply> CreatePlasmaEvictReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaEvictReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaSubscribeRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaSubscribeRequest TableType;
+ PlasmaSubscribeRequestT() {
+ }
+};
+
+struct PlasmaSubscribeRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaSubscribeRequestT NativeTableType;
+ typedef PlasmaSubscribeRequestBuilder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ PlasmaSubscribeRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaSubscribeRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaSubscribeRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSubscribeRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaSubscribeRequestBuilder {
+ typedef PlasmaSubscribeRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit PlasmaSubscribeRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaSubscribeRequestBuilder &operator=(const PlasmaSubscribeRequestBuilder &);
+ flatbuffers::Offset<PlasmaSubscribeRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaSubscribeRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaSubscribeRequest> CreatePlasmaSubscribeRequest(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ PlasmaSubscribeRequestBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<PlasmaSubscribeRequest> CreatePlasmaSubscribeRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSubscribeRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaNotificationT : public flatbuffers::NativeTable {
+ typedef PlasmaNotification TableType;
+ std::vector<std::unique_ptr<plasma::flatbuf::ObjectInfoT>> object_info;
+ PlasmaNotificationT() {
+ }
+};
+
+struct PlasmaNotification FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaNotificationT NativeTableType;
+ typedef PlasmaNotificationBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_INFO = 4
+ };
+ const flatbuffers::Vector<flatbuffers::Offset<plasma::flatbuf::ObjectInfo>> *object_info() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<plasma::flatbuf::ObjectInfo>> *>(VT_OBJECT_INFO);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_INFO) &&
+ verifier.VerifyVector(object_info()) &&
+ verifier.VerifyVectorOfTables(object_info()) &&
+ verifier.EndTable();
+ }
+ PlasmaNotificationT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaNotificationT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaNotification> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaNotificationT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaNotificationBuilder {
+ typedef PlasmaNotification Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_info(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<plasma::flatbuf::ObjectInfo>>> object_info) {
+ fbb_.AddOffset(PlasmaNotification::VT_OBJECT_INFO, object_info);
+ }
+ explicit PlasmaNotificationBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaNotificationBuilder &operator=(const PlasmaNotificationBuilder &);
+ flatbuffers::Offset<PlasmaNotification> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaNotification>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaNotification> CreatePlasmaNotification(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<plasma::flatbuf::ObjectInfo>>> object_info = 0) {
+ PlasmaNotificationBuilder builder_(_fbb);
+ builder_.add_object_info(object_info);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaNotification> CreatePlasmaNotificationDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<plasma::flatbuf::ObjectInfo>> *object_info = nullptr) {
+ auto object_info__ = object_info ? _fbb.CreateVector<flatbuffers::Offset<plasma::flatbuf::ObjectInfo>>(*object_info) : 0;
+ return plasma::flatbuf::CreatePlasmaNotification(
+ _fbb,
+ object_info__);
+}
+
+flatbuffers::Offset<PlasmaNotification> CreatePlasmaNotification(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaNotificationT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaDataRequestT : public flatbuffers::NativeTable {
+ typedef PlasmaDataRequest TableType;
+ std::string object_id;
+ std::string address;
+ int32_t port;
+ PlasmaDataRequestT()
+ : port(0) {
+ }
+};
+
+struct PlasmaDataRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaDataRequestT NativeTableType;
+ typedef PlasmaDataRequestBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_ID = 4,
+ VT_ADDRESS = 6,
+ VT_PORT = 8
+ };
+ const flatbuffers::String *object_id() const {
+ return GetPointer<const flatbuffers::String *>(VT_OBJECT_ID);
+ }
+ const flatbuffers::String *address() const {
+ return GetPointer<const flatbuffers::String *>(VT_ADDRESS);
+ }
+ int32_t port() const {
+ return GetField<int32_t>(VT_PORT, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_ID) &&
+ verifier.VerifyString(object_id()) &&
+ VerifyOffset(verifier, VT_ADDRESS) &&
+ verifier.VerifyString(address()) &&
+ VerifyField<int32_t>(verifier, VT_PORT) &&
+ verifier.EndTable();
+ }
+ PlasmaDataRequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaDataRequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaDataRequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDataRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaDataRequestBuilder {
+ typedef PlasmaDataRequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_id(flatbuffers::Offset<flatbuffers::String> object_id) {
+ fbb_.AddOffset(PlasmaDataRequest::VT_OBJECT_ID, object_id);
+ }
+ void add_address(flatbuffers::Offset<flatbuffers::String> address) {
+ fbb_.AddOffset(PlasmaDataRequest::VT_ADDRESS, address);
+ }
+ void add_port(int32_t port) {
+ fbb_.AddElement<int32_t>(PlasmaDataRequest::VT_PORT, port, 0);
+ }
+ explicit PlasmaDataRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaDataRequestBuilder &operator=(const PlasmaDataRequestBuilder &);
+ flatbuffers::Offset<PlasmaDataRequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaDataRequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaDataRequest> CreatePlasmaDataRequest(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> object_id = 0,
+ flatbuffers::Offset<flatbuffers::String> address = 0,
+ int32_t port = 0) {
+ PlasmaDataRequestBuilder builder_(_fbb);
+ builder_.add_port(port);
+ builder_.add_address(address);
+ builder_.add_object_id(object_id);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaDataRequest> CreatePlasmaDataRequestDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *object_id = nullptr,
+ const char *address = nullptr,
+ int32_t port = 0) {
+ auto object_id__ = object_id ? _fbb.CreateString(object_id) : 0;
+ auto address__ = address ? _fbb.CreateString(address) : 0;
+ return plasma::flatbuf::CreatePlasmaDataRequest(
+ _fbb,
+ object_id__,
+ address__,
+ port);
+}
+
+flatbuffers::Offset<PlasmaDataRequest> CreatePlasmaDataRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDataRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaDataReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaDataReply TableType;
+ std::string object_id;
+ uint64_t object_size;
+ uint64_t metadata_size;
+ PlasmaDataReplyT()
+ : object_size(0),
+ metadata_size(0) {
+ }
+};
+
+struct PlasmaDataReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaDataReplyT NativeTableType;
+ typedef PlasmaDataReplyBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_ID = 4,
+ VT_OBJECT_SIZE = 6,
+ VT_METADATA_SIZE = 8
+ };
+ const flatbuffers::String *object_id() const {
+ return GetPointer<const flatbuffers::String *>(VT_OBJECT_ID);
+ }
+ uint64_t object_size() const {
+ return GetField<uint64_t>(VT_OBJECT_SIZE, 0);
+ }
+ uint64_t metadata_size() const {
+ return GetField<uint64_t>(VT_METADATA_SIZE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_ID) &&
+ verifier.VerifyString(object_id()) &&
+ VerifyField<uint64_t>(verifier, VT_OBJECT_SIZE) &&
+ VerifyField<uint64_t>(verifier, VT_METADATA_SIZE) &&
+ verifier.EndTable();
+ }
+ PlasmaDataReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaDataReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaDataReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDataReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaDataReplyBuilder {
+ typedef PlasmaDataReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_id(flatbuffers::Offset<flatbuffers::String> object_id) {
+ fbb_.AddOffset(PlasmaDataReply::VT_OBJECT_ID, object_id);
+ }
+ void add_object_size(uint64_t object_size) {
+ fbb_.AddElement<uint64_t>(PlasmaDataReply::VT_OBJECT_SIZE, object_size, 0);
+ }
+ void add_metadata_size(uint64_t metadata_size) {
+ fbb_.AddElement<uint64_t>(PlasmaDataReply::VT_METADATA_SIZE, metadata_size, 0);
+ }
+ explicit PlasmaDataReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaDataReplyBuilder &operator=(const PlasmaDataReplyBuilder &);
+ flatbuffers::Offset<PlasmaDataReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaDataReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaDataReply> CreatePlasmaDataReply(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> object_id = 0,
+ uint64_t object_size = 0,
+ uint64_t metadata_size = 0) {
+ PlasmaDataReplyBuilder builder_(_fbb);
+ builder_.add_metadata_size(metadata_size);
+ builder_.add_object_size(object_size);
+ builder_.add_object_id(object_id);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaDataReply> CreatePlasmaDataReplyDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *object_id = nullptr,
+ uint64_t object_size = 0,
+ uint64_t metadata_size = 0) {
+ auto object_id__ = object_id ? _fbb.CreateString(object_id) : 0;
+ return plasma::flatbuf::CreatePlasmaDataReply(
+ _fbb,
+ object_id__,
+ object_size,
+ metadata_size);
+}
+
+flatbuffers::Offset<PlasmaDataReply> CreatePlasmaDataReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDataReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaRefreshLRURequestT : public flatbuffers::NativeTable {
+ typedef PlasmaRefreshLRURequest TableType;
+ std::vector<std::string> object_ids;
+ PlasmaRefreshLRURequestT() {
+ }
+};
+
+struct PlasmaRefreshLRURequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaRefreshLRURequestT NativeTableType;
+ typedef PlasmaRefreshLRURequestBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OBJECT_IDS = 4
+ };
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *object_ids() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_OBJECT_IDS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OBJECT_IDS) &&
+ verifier.VerifyVector(object_ids()) &&
+ verifier.VerifyVectorOfStrings(object_ids()) &&
+ verifier.EndTable();
+ }
+ PlasmaRefreshLRURequestT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaRefreshLRURequestT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaRefreshLRURequest> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaRefreshLRURequestT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaRefreshLRURequestBuilder {
+ typedef PlasmaRefreshLRURequest Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_object_ids(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> object_ids) {
+ fbb_.AddOffset(PlasmaRefreshLRURequest::VT_OBJECT_IDS, object_ids);
+ }
+ explicit PlasmaRefreshLRURequestBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaRefreshLRURequestBuilder &operator=(const PlasmaRefreshLRURequestBuilder &);
+ flatbuffers::Offset<PlasmaRefreshLRURequest> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaRefreshLRURequest>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaRefreshLRURequest> CreatePlasmaRefreshLRURequest(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> object_ids = 0) {
+ PlasmaRefreshLRURequestBuilder builder_(_fbb);
+ builder_.add_object_ids(object_ids);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<PlasmaRefreshLRURequest> CreatePlasmaRefreshLRURequestDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *object_ids = nullptr) {
+ auto object_ids__ = object_ids ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*object_ids) : 0;
+ return plasma::flatbuf::CreatePlasmaRefreshLRURequest(
+ _fbb,
+ object_ids__);
+}
+
+flatbuffers::Offset<PlasmaRefreshLRURequest> CreatePlasmaRefreshLRURequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaRefreshLRURequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PlasmaRefreshLRUReplyT : public flatbuffers::NativeTable {
+ typedef PlasmaRefreshLRUReply TableType;
+ PlasmaRefreshLRUReplyT() {
+ }
+};
+
+struct PlasmaRefreshLRUReply FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PlasmaRefreshLRUReplyT NativeTableType;
+ typedef PlasmaRefreshLRUReplyBuilder Builder;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ PlasmaRefreshLRUReplyT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PlasmaRefreshLRUReplyT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PlasmaRefreshLRUReply> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaRefreshLRUReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PlasmaRefreshLRUReplyBuilder {
+ typedef PlasmaRefreshLRUReply Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit PlasmaRefreshLRUReplyBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PlasmaRefreshLRUReplyBuilder &operator=(const PlasmaRefreshLRUReplyBuilder &);
+ flatbuffers::Offset<PlasmaRefreshLRUReply> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PlasmaRefreshLRUReply>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PlasmaRefreshLRUReply> CreatePlasmaRefreshLRUReply(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ PlasmaRefreshLRUReplyBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<PlasmaRefreshLRUReply> CreatePlasmaRefreshLRUReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaRefreshLRUReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+inline PlasmaSetOptionsRequestT *PlasmaSetOptionsRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaSetOptionsRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaSetOptionsRequestT>(new PlasmaSetOptionsRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaSetOptionsRequest::UnPackTo(PlasmaSetOptionsRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = client_name(); if (_e) _o->client_name = _e->str(); }
+ { auto _e = output_memory_quota(); _o->output_memory_quota = _e; }
+}
+
+inline flatbuffers::Offset<PlasmaSetOptionsRequest> PlasmaSetOptionsRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSetOptionsRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaSetOptionsRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaSetOptionsRequest> CreatePlasmaSetOptionsRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSetOptionsRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaSetOptionsRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _client_name = _o->client_name.empty() ? 0 : _fbb.CreateString(_o->client_name);
+ auto _output_memory_quota = _o->output_memory_quota;
+ return plasma::flatbuf::CreatePlasmaSetOptionsRequest(
+ _fbb,
+ _client_name,
+ _output_memory_quota);
+}
+
+inline PlasmaSetOptionsReplyT *PlasmaSetOptionsReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaSetOptionsReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaSetOptionsReplyT>(new PlasmaSetOptionsReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaSetOptionsReply::UnPackTo(PlasmaSetOptionsReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = error(); _o->error = _e; }
+}
+
+inline flatbuffers::Offset<PlasmaSetOptionsReply> PlasmaSetOptionsReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSetOptionsReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaSetOptionsReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaSetOptionsReply> CreatePlasmaSetOptionsReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSetOptionsReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaSetOptionsReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _error = _o->error;
+ return plasma::flatbuf::CreatePlasmaSetOptionsReply(
+ _fbb,
+ _error);
+}
+
+inline PlasmaGetDebugStringRequestT *PlasmaGetDebugStringRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaGetDebugStringRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaGetDebugStringRequestT>(new PlasmaGetDebugStringRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaGetDebugStringRequest::UnPackTo(PlasmaGetDebugStringRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<PlasmaGetDebugStringRequest> PlasmaGetDebugStringRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetDebugStringRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaGetDebugStringRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaGetDebugStringRequest> CreatePlasmaGetDebugStringRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetDebugStringRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaGetDebugStringRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return plasma::flatbuf::CreatePlasmaGetDebugStringRequest(
+ _fbb);
+}
+
+inline PlasmaGetDebugStringReplyT *PlasmaGetDebugStringReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaGetDebugStringReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaGetDebugStringReplyT>(new PlasmaGetDebugStringReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaGetDebugStringReply::UnPackTo(PlasmaGetDebugStringReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = debug_string(); if (_e) _o->debug_string = _e->str(); }
+}
+
+inline flatbuffers::Offset<PlasmaGetDebugStringReply> PlasmaGetDebugStringReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetDebugStringReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaGetDebugStringReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaGetDebugStringReply> CreatePlasmaGetDebugStringReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetDebugStringReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaGetDebugStringReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _debug_string = _o->debug_string.empty() ? 0 : _fbb.CreateString(_o->debug_string);
+ return plasma::flatbuf::CreatePlasmaGetDebugStringReply(
+ _fbb,
+ _debug_string);
+}
+
+inline PlasmaCreateRequestT *PlasmaCreateRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaCreateRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaCreateRequestT>(new PlasmaCreateRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaCreateRequest::UnPackTo(PlasmaCreateRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_id(); if (_e) _o->object_id = _e->str(); }
+ { auto _e = evict_if_full(); _o->evict_if_full = _e; }
+ { auto _e = data_size(); _o->data_size = _e; }
+ { auto _e = metadata_size(); _o->metadata_size = _e; }
+ { auto _e = device_num(); _o->device_num = _e; }
+}
+
+inline flatbuffers::Offset<PlasmaCreateRequest> PlasmaCreateRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaCreateRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaCreateRequest> CreatePlasmaCreateRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaCreateRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_id = _o->object_id.empty() ? 0 : _fbb.CreateString(_o->object_id);
+ auto _evict_if_full = _o->evict_if_full;
+ auto _data_size = _o->data_size;
+ auto _metadata_size = _o->metadata_size;
+ auto _device_num = _o->device_num;
+ return plasma::flatbuf::CreatePlasmaCreateRequest(
+ _fbb,
+ _object_id,
+ _evict_if_full,
+ _data_size,
+ _metadata_size,
+ _device_num);
+}
+
+inline CudaHandleT *CudaHandle::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::CudaHandleT> _o = std::unique_ptr<plasma::flatbuf::CudaHandleT>(new CudaHandleT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void CudaHandle::UnPackTo(CudaHandleT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = handle(); if (_e) { _o->handle.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->handle[_i] = _e->Get(_i); } } }
+}
+
+inline flatbuffers::Offset<CudaHandle> CudaHandle::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CudaHandleT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateCudaHandle(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<CudaHandle> CreateCudaHandle(flatbuffers::FlatBufferBuilder &_fbb, const CudaHandleT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CudaHandleT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _handle = _o->handle.size() ? _fbb.CreateVector(_o->handle) : 0;
+ return plasma::flatbuf::CreateCudaHandle(
+ _fbb,
+ _handle);
+}
+
+inline PlasmaCreateReplyT *PlasmaCreateReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaCreateReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaCreateReplyT>(new PlasmaCreateReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaCreateReply::UnPackTo(PlasmaCreateReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_id(); if (_e) _o->object_id = _e->str(); }
+ { auto _e = plasma_object(); if (_e) _o->plasma_object = std::unique_ptr<plasma::flatbuf::PlasmaObjectSpec>(new plasma::flatbuf::PlasmaObjectSpec(*_e)); }
+ { auto _e = error(); _o->error = _e; }
+ { auto _e = store_fd(); _o->store_fd = _e; }
+ { auto _e = mmap_size(); _o->mmap_size = _e; }
+ { auto _e = ipc_handle(); if (_e) _o->ipc_handle = std::unique_ptr<plasma::flatbuf::CudaHandleT>(_e->UnPack(_resolver)); }
+}
+
+inline flatbuffers::Offset<PlasmaCreateReply> PlasmaCreateReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaCreateReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaCreateReply> CreatePlasmaCreateReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaCreateReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_id = _o->object_id.empty() ? 0 : _fbb.CreateString(_o->object_id);
+ auto _plasma_object = _o->plasma_object ? _o->plasma_object.get() : 0;
+ auto _error = _o->error;
+ auto _store_fd = _o->store_fd;
+ auto _mmap_size = _o->mmap_size;
+ auto _ipc_handle = _o->ipc_handle ? CreateCudaHandle(_fbb, _o->ipc_handle.get(), _rehasher) : 0;
+ return plasma::flatbuf::CreatePlasmaCreateReply(
+ _fbb,
+ _object_id,
+ _plasma_object,
+ _error,
+ _store_fd,
+ _mmap_size,
+ _ipc_handle);
+}
+
+inline PlasmaCreateAndSealRequestT *PlasmaCreateAndSealRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaCreateAndSealRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaCreateAndSealRequestT>(new PlasmaCreateAndSealRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaCreateAndSealRequest::UnPackTo(PlasmaCreateAndSealRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_id(); if (_e) _o->object_id = _e->str(); }
+ { auto _e = evict_if_full(); _o->evict_if_full = _e; }
+ { auto _e = data(); if (_e) _o->data = _e->str(); }
+ { auto _e = metadata(); if (_e) _o->metadata = _e->str(); }
+ { auto _e = digest(); if (_e) _o->digest = _e->str(); }
+}
+
+inline flatbuffers::Offset<PlasmaCreateAndSealRequest> PlasmaCreateAndSealRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaCreateAndSealRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaCreateAndSealRequest> CreatePlasmaCreateAndSealRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaCreateAndSealRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_id = _o->object_id.empty() ? 0 : _fbb.CreateString(_o->object_id);
+ auto _evict_if_full = _o->evict_if_full;
+ auto _data = _o->data.empty() ? 0 : _fbb.CreateString(_o->data);
+ auto _metadata = _o->metadata.empty() ? 0 : _fbb.CreateString(_o->metadata);
+ auto _digest = _o->digest.empty() ? 0 : _fbb.CreateString(_o->digest);
+ return plasma::flatbuf::CreatePlasmaCreateAndSealRequest(
+ _fbb,
+ _object_id,
+ _evict_if_full,
+ _data,
+ _metadata,
+ _digest);
+}
+
+inline PlasmaCreateAndSealReplyT *PlasmaCreateAndSealReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaCreateAndSealReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaCreateAndSealReplyT>(new PlasmaCreateAndSealReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaCreateAndSealReply::UnPackTo(PlasmaCreateAndSealReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = error(); _o->error = _e; }
+}
+
+inline flatbuffers::Offset<PlasmaCreateAndSealReply> PlasmaCreateAndSealReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaCreateAndSealReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaCreateAndSealReply> CreatePlasmaCreateAndSealReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaCreateAndSealReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _error = _o->error;
+ return plasma::flatbuf::CreatePlasmaCreateAndSealReply(
+ _fbb,
+ _error);
+}
+
+inline PlasmaCreateAndSealBatchRequestT *PlasmaCreateAndSealBatchRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaCreateAndSealBatchRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaCreateAndSealBatchRequestT>(new PlasmaCreateAndSealBatchRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaCreateAndSealBatchRequest::UnPackTo(PlasmaCreateAndSealBatchRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_ids(); if (_e) { _o->object_ids.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->object_ids[_i] = _e->Get(_i)->str(); } } }
+ { auto _e = evict_if_full(); _o->evict_if_full = _e; }
+ { auto _e = data(); if (_e) { _o->data.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->data[_i] = _e->Get(_i)->str(); } } }
+ { auto _e = metadata(); if (_e) { _o->metadata.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->metadata[_i] = _e->Get(_i)->str(); } } }
+ { auto _e = digest(); if (_e) { _o->digest.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->digest[_i] = _e->Get(_i)->str(); } } }
+}
+
+inline flatbuffers::Offset<PlasmaCreateAndSealBatchRequest> PlasmaCreateAndSealBatchRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealBatchRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaCreateAndSealBatchRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaCreateAndSealBatchRequest> CreatePlasmaCreateAndSealBatchRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealBatchRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaCreateAndSealBatchRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_ids = _o->object_ids.size() ? _fbb.CreateVectorOfStrings(_o->object_ids) : 0;
+ auto _evict_if_full = _o->evict_if_full;
+ auto _data = _o->data.size() ? _fbb.CreateVectorOfStrings(_o->data) : 0;
+ auto _metadata = _o->metadata.size() ? _fbb.CreateVectorOfStrings(_o->metadata) : 0;
+ auto _digest = _o->digest.size() ? _fbb.CreateVectorOfStrings(_o->digest) : 0;
+ return plasma::flatbuf::CreatePlasmaCreateAndSealBatchRequest(
+ _fbb,
+ _object_ids,
+ _evict_if_full,
+ _data,
+ _metadata,
+ _digest);
+}
+
+inline PlasmaCreateAndSealBatchReplyT *PlasmaCreateAndSealBatchReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaCreateAndSealBatchReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaCreateAndSealBatchReplyT>(new PlasmaCreateAndSealBatchReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaCreateAndSealBatchReply::UnPackTo(PlasmaCreateAndSealBatchReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = error(); _o->error = _e; }
+}
+
+inline flatbuffers::Offset<PlasmaCreateAndSealBatchReply> PlasmaCreateAndSealBatchReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealBatchReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaCreateAndSealBatchReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaCreateAndSealBatchReply> CreatePlasmaCreateAndSealBatchReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaCreateAndSealBatchReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaCreateAndSealBatchReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _error = _o->error;
+ return plasma::flatbuf::CreatePlasmaCreateAndSealBatchReply(
+ _fbb,
+ _error);
+}
+
+inline PlasmaAbortRequestT *PlasmaAbortRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaAbortRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaAbortRequestT>(new PlasmaAbortRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaAbortRequest::UnPackTo(PlasmaAbortRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_id(); if (_e) _o->object_id = _e->str(); }
+}
+
+inline flatbuffers::Offset<PlasmaAbortRequest> PlasmaAbortRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaAbortRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaAbortRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaAbortRequest> CreatePlasmaAbortRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaAbortRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaAbortRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_id = _o->object_id.empty() ? 0 : _fbb.CreateString(_o->object_id);
+ return plasma::flatbuf::CreatePlasmaAbortRequest(
+ _fbb,
+ _object_id);
+}
+
+inline PlasmaAbortReplyT *PlasmaAbortReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaAbortReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaAbortReplyT>(new PlasmaAbortReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaAbortReply::UnPackTo(PlasmaAbortReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_id(); if (_e) _o->object_id = _e->str(); }
+}
+
+inline flatbuffers::Offset<PlasmaAbortReply> PlasmaAbortReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaAbortReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaAbortReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaAbortReply> CreatePlasmaAbortReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaAbortReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaAbortReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_id = _o->object_id.empty() ? 0 : _fbb.CreateString(_o->object_id);
+ return plasma::flatbuf::CreatePlasmaAbortReply(
+ _fbb,
+ _object_id);
+}
+
+inline PlasmaSealRequestT *PlasmaSealRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaSealRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaSealRequestT>(new PlasmaSealRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaSealRequest::UnPackTo(PlasmaSealRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_id(); if (_e) _o->object_id = _e->str(); }
+ { auto _e = digest(); if (_e) _o->digest = _e->str(); }
+}
+
+inline flatbuffers::Offset<PlasmaSealRequest> PlasmaSealRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSealRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaSealRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaSealRequest> CreatePlasmaSealRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSealRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaSealRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_id = _o->object_id.empty() ? 0 : _fbb.CreateString(_o->object_id);
+ auto _digest = _o->digest.empty() ? 0 : _fbb.CreateString(_o->digest);
+ return plasma::flatbuf::CreatePlasmaSealRequest(
+ _fbb,
+ _object_id,
+ _digest);
+}
+
+inline PlasmaSealReplyT *PlasmaSealReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaSealReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaSealReplyT>(new PlasmaSealReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaSealReply::UnPackTo(PlasmaSealReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_id(); if (_e) _o->object_id = _e->str(); }
+ { auto _e = error(); _o->error = _e; }
+}
+
+inline flatbuffers::Offset<PlasmaSealReply> PlasmaSealReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSealReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaSealReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaSealReply> CreatePlasmaSealReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSealReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaSealReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_id = _o->object_id.empty() ? 0 : _fbb.CreateString(_o->object_id);
+ auto _error = _o->error;
+ return plasma::flatbuf::CreatePlasmaSealReply(
+ _fbb,
+ _object_id,
+ _error);
+}
+
+inline PlasmaGetRequestT *PlasmaGetRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaGetRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaGetRequestT>(new PlasmaGetRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaGetRequest::UnPackTo(PlasmaGetRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_ids(); if (_e) { _o->object_ids.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->object_ids[_i] = _e->Get(_i)->str(); } } }
+ { auto _e = timeout_ms(); _o->timeout_ms = _e; }
+}
+
+inline flatbuffers::Offset<PlasmaGetRequest> PlasmaGetRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaGetRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaGetRequest> CreatePlasmaGetRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaGetRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_ids = _o->object_ids.size() ? _fbb.CreateVectorOfStrings(_o->object_ids) : 0;
+ auto _timeout_ms = _o->timeout_ms;
+ return plasma::flatbuf::CreatePlasmaGetRequest(
+ _fbb,
+ _object_ids,
+ _timeout_ms);
+}
+
+inline PlasmaGetReplyT *PlasmaGetReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaGetReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaGetReplyT>(new PlasmaGetReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaGetReply::UnPackTo(PlasmaGetReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_ids(); if (_e) { _o->object_ids.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->object_ids[_i] = _e->Get(_i)->str(); } } }
+ { auto _e = plasma_objects(); if (_e) { _o->plasma_objects.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->plasma_objects[_i] = *_e->Get(_i); } } }
+ { auto _e = store_fds(); if (_e) { _o->store_fds.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->store_fds[_i] = _e->Get(_i); } } }
+ { auto _e = mmap_sizes(); if (_e) { _o->mmap_sizes.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->mmap_sizes[_i] = _e->Get(_i); } } }
+ { auto _e = handles(); if (_e) { _o->handles.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->handles[_i] = std::unique_ptr<plasma::flatbuf::CudaHandleT>(_e->Get(_i)->UnPack(_resolver)); } } }
+}
+
+inline flatbuffers::Offset<PlasmaGetReply> PlasmaGetReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaGetReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaGetReply> CreatePlasmaGetReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaGetReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaGetReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_ids = _o->object_ids.size() ? _fbb.CreateVectorOfStrings(_o->object_ids) : 0;
+ auto _plasma_objects = _o->plasma_objects.size() ? _fbb.CreateVectorOfStructs(_o->plasma_objects) : 0;
+ auto _store_fds = _o->store_fds.size() ? _fbb.CreateVector(_o->store_fds) : 0;
+ auto _mmap_sizes = _o->mmap_sizes.size() ? _fbb.CreateVector(_o->mmap_sizes) : 0;
+ auto _handles = _o->handles.size() ? _fbb.CreateVector<flatbuffers::Offset<plasma::flatbuf::CudaHandle>> (_o->handles.size(), [](size_t i, _VectorArgs *__va) { return CreateCudaHandle(*__va->__fbb, __va->__o->handles[i].get(), __va->__rehasher); }, &_va ) : 0;
+ return plasma::flatbuf::CreatePlasmaGetReply(
+ _fbb,
+ _object_ids,
+ _plasma_objects,
+ _store_fds,
+ _mmap_sizes,
+ _handles);
+}
+
+inline PlasmaReleaseRequestT *PlasmaReleaseRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaReleaseRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaReleaseRequestT>(new PlasmaReleaseRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaReleaseRequest::UnPackTo(PlasmaReleaseRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_id(); if (_e) _o->object_id = _e->str(); }
+}
+
+inline flatbuffers::Offset<PlasmaReleaseRequest> PlasmaReleaseRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaReleaseRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaReleaseRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaReleaseRequest> CreatePlasmaReleaseRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaReleaseRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaReleaseRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_id = _o->object_id.empty() ? 0 : _fbb.CreateString(_o->object_id);
+ return plasma::flatbuf::CreatePlasmaReleaseRequest(
+ _fbb,
+ _object_id);
+}
+
+inline PlasmaReleaseReplyT *PlasmaReleaseReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaReleaseReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaReleaseReplyT>(new PlasmaReleaseReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaReleaseReply::UnPackTo(PlasmaReleaseReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_id(); if (_e) _o->object_id = _e->str(); }
+ { auto _e = error(); _o->error = _e; }
+}
+
+inline flatbuffers::Offset<PlasmaReleaseReply> PlasmaReleaseReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaReleaseReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaReleaseReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaReleaseReply> CreatePlasmaReleaseReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaReleaseReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaReleaseReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_id = _o->object_id.empty() ? 0 : _fbb.CreateString(_o->object_id);
+ auto _error = _o->error;
+ return plasma::flatbuf::CreatePlasmaReleaseReply(
+ _fbb,
+ _object_id,
+ _error);
+}
+
+inline PlasmaDeleteRequestT *PlasmaDeleteRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaDeleteRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaDeleteRequestT>(new PlasmaDeleteRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaDeleteRequest::UnPackTo(PlasmaDeleteRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = count(); _o->count = _e; }
+ { auto _e = object_ids(); if (_e) { _o->object_ids.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->object_ids[_i] = _e->Get(_i)->str(); } } }
+}
+
+inline flatbuffers::Offset<PlasmaDeleteRequest> PlasmaDeleteRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDeleteRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaDeleteRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaDeleteRequest> CreatePlasmaDeleteRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDeleteRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaDeleteRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _count = _o->count;
+ auto _object_ids = _o->object_ids.size() ? _fbb.CreateVectorOfStrings(_o->object_ids) : 0;
+ return plasma::flatbuf::CreatePlasmaDeleteRequest(
+ _fbb,
+ _count,
+ _object_ids);
+}
+
+inline PlasmaDeleteReplyT *PlasmaDeleteReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaDeleteReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaDeleteReplyT>(new PlasmaDeleteReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaDeleteReply::UnPackTo(PlasmaDeleteReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = count(); _o->count = _e; }
+ { auto _e = object_ids(); if (_e) { _o->object_ids.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->object_ids[_i] = _e->Get(_i)->str(); } } }
+ { auto _e = errors(); if (_e) { _o->errors.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->errors[_i] = static_cast<plasma::flatbuf::PlasmaError>(_e->Get(_i)); } } }
+}
+
+inline flatbuffers::Offset<PlasmaDeleteReply> PlasmaDeleteReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDeleteReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaDeleteReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaDeleteReply> CreatePlasmaDeleteReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDeleteReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaDeleteReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _count = _o->count;
+ auto _object_ids = _o->object_ids.size() ? _fbb.CreateVectorOfStrings(_o->object_ids) : 0;
+ auto _errors = _o->errors.size() ? _fbb.CreateVectorScalarCast<int32_t>(flatbuffers::data(_o->errors), _o->errors.size()) : 0;
+ return plasma::flatbuf::CreatePlasmaDeleteReply(
+ _fbb,
+ _count,
+ _object_ids,
+ _errors);
+}
+
+inline PlasmaContainsRequestT *PlasmaContainsRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaContainsRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaContainsRequestT>(new PlasmaContainsRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaContainsRequest::UnPackTo(PlasmaContainsRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_id(); if (_e) _o->object_id = _e->str(); }
+}
+
+inline flatbuffers::Offset<PlasmaContainsRequest> PlasmaContainsRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaContainsRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaContainsRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaContainsRequest> CreatePlasmaContainsRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaContainsRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaContainsRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_id = _o->object_id.empty() ? 0 : _fbb.CreateString(_o->object_id);
+ return plasma::flatbuf::CreatePlasmaContainsRequest(
+ _fbb,
+ _object_id);
+}
+
+inline PlasmaContainsReplyT *PlasmaContainsReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaContainsReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaContainsReplyT>(new PlasmaContainsReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaContainsReply::UnPackTo(PlasmaContainsReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_id(); if (_e) _o->object_id = _e->str(); }
+ { auto _e = has_object(); _o->has_object = _e; }
+}
+
+inline flatbuffers::Offset<PlasmaContainsReply> PlasmaContainsReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaContainsReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaContainsReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaContainsReply> CreatePlasmaContainsReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaContainsReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaContainsReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_id = _o->object_id.empty() ? 0 : _fbb.CreateString(_o->object_id);
+ auto _has_object = _o->has_object;
+ return plasma::flatbuf::CreatePlasmaContainsReply(
+ _fbb,
+ _object_id,
+ _has_object);
+}
+
+inline PlasmaListRequestT *PlasmaListRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaListRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaListRequestT>(new PlasmaListRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaListRequest::UnPackTo(PlasmaListRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<PlasmaListRequest> PlasmaListRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaListRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaListRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaListRequest> CreatePlasmaListRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaListRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaListRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return plasma::flatbuf::CreatePlasmaListRequest(
+ _fbb);
+}
+
+inline PlasmaListReplyT *PlasmaListReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaListReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaListReplyT>(new PlasmaListReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaListReply::UnPackTo(PlasmaListReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = objects(); if (_e) { _o->objects.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->objects[_i] = std::unique_ptr<plasma::flatbuf::ObjectInfoT>(_e->Get(_i)->UnPack(_resolver)); } } }
+}
+
+inline flatbuffers::Offset<PlasmaListReply> PlasmaListReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaListReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaListReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaListReply> CreatePlasmaListReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaListReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaListReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _objects = _o->objects.size() ? _fbb.CreateVector<flatbuffers::Offset<plasma::flatbuf::ObjectInfo>> (_o->objects.size(), [](size_t i, _VectorArgs *__va) { return CreateObjectInfo(*__va->__fbb, __va->__o->objects[i].get(), __va->__rehasher); }, &_va ) : 0;
+ return plasma::flatbuf::CreatePlasmaListReply(
+ _fbb,
+ _objects);
+}
+
+inline PlasmaConnectRequestT *PlasmaConnectRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaConnectRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaConnectRequestT>(new PlasmaConnectRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaConnectRequest::UnPackTo(PlasmaConnectRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<PlasmaConnectRequest> PlasmaConnectRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaConnectRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaConnectRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaConnectRequest> CreatePlasmaConnectRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaConnectRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaConnectRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return plasma::flatbuf::CreatePlasmaConnectRequest(
+ _fbb);
+}
+
+inline PlasmaConnectReplyT *PlasmaConnectReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaConnectReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaConnectReplyT>(new PlasmaConnectReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaConnectReply::UnPackTo(PlasmaConnectReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = memory_capacity(); _o->memory_capacity = _e; }
+}
+
+inline flatbuffers::Offset<PlasmaConnectReply> PlasmaConnectReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaConnectReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaConnectReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaConnectReply> CreatePlasmaConnectReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaConnectReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaConnectReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _memory_capacity = _o->memory_capacity;
+ return plasma::flatbuf::CreatePlasmaConnectReply(
+ _fbb,
+ _memory_capacity);
+}
+
+inline PlasmaEvictRequestT *PlasmaEvictRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaEvictRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaEvictRequestT>(new PlasmaEvictRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaEvictRequest::UnPackTo(PlasmaEvictRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = num_bytes(); _o->num_bytes = _e; }
+}
+
+inline flatbuffers::Offset<PlasmaEvictRequest> PlasmaEvictRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaEvictRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaEvictRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaEvictRequest> CreatePlasmaEvictRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaEvictRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaEvictRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _num_bytes = _o->num_bytes;
+ return plasma::flatbuf::CreatePlasmaEvictRequest(
+ _fbb,
+ _num_bytes);
+}
+
+inline PlasmaEvictReplyT *PlasmaEvictReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaEvictReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaEvictReplyT>(new PlasmaEvictReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaEvictReply::UnPackTo(PlasmaEvictReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = num_bytes(); _o->num_bytes = _e; }
+}
+
+inline flatbuffers::Offset<PlasmaEvictReply> PlasmaEvictReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaEvictReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaEvictReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaEvictReply> CreatePlasmaEvictReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaEvictReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaEvictReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _num_bytes = _o->num_bytes;
+ return plasma::flatbuf::CreatePlasmaEvictReply(
+ _fbb,
+ _num_bytes);
+}
+
+inline PlasmaSubscribeRequestT *PlasmaSubscribeRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaSubscribeRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaSubscribeRequestT>(new PlasmaSubscribeRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaSubscribeRequest::UnPackTo(PlasmaSubscribeRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<PlasmaSubscribeRequest> PlasmaSubscribeRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSubscribeRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaSubscribeRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaSubscribeRequest> CreatePlasmaSubscribeRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaSubscribeRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaSubscribeRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return plasma::flatbuf::CreatePlasmaSubscribeRequest(
+ _fbb);
+}
+
+inline PlasmaNotificationT *PlasmaNotification::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaNotificationT> _o = std::unique_ptr<plasma::flatbuf::PlasmaNotificationT>(new PlasmaNotificationT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaNotification::UnPackTo(PlasmaNotificationT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_info(); if (_e) { _o->object_info.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->object_info[_i] = std::unique_ptr<plasma::flatbuf::ObjectInfoT>(_e->Get(_i)->UnPack(_resolver)); } } }
+}
+
+inline flatbuffers::Offset<PlasmaNotification> PlasmaNotification::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaNotificationT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaNotification(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaNotification> CreatePlasmaNotification(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaNotificationT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaNotificationT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_info = _o->object_info.size() ? _fbb.CreateVector<flatbuffers::Offset<plasma::flatbuf::ObjectInfo>> (_o->object_info.size(), [](size_t i, _VectorArgs *__va) { return CreateObjectInfo(*__va->__fbb, __va->__o->object_info[i].get(), __va->__rehasher); }, &_va ) : 0;
+ return plasma::flatbuf::CreatePlasmaNotification(
+ _fbb,
+ _object_info);
+}
+
+inline PlasmaDataRequestT *PlasmaDataRequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaDataRequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaDataRequestT>(new PlasmaDataRequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaDataRequest::UnPackTo(PlasmaDataRequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_id(); if (_e) _o->object_id = _e->str(); }
+ { auto _e = address(); if (_e) _o->address = _e->str(); }
+ { auto _e = port(); _o->port = _e; }
+}
+
+inline flatbuffers::Offset<PlasmaDataRequest> PlasmaDataRequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDataRequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaDataRequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaDataRequest> CreatePlasmaDataRequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDataRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaDataRequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_id = _o->object_id.empty() ? 0 : _fbb.CreateString(_o->object_id);
+ auto _address = _o->address.empty() ? 0 : _fbb.CreateString(_o->address);
+ auto _port = _o->port;
+ return plasma::flatbuf::CreatePlasmaDataRequest(
+ _fbb,
+ _object_id,
+ _address,
+ _port);
+}
+
+inline PlasmaDataReplyT *PlasmaDataReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaDataReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaDataReplyT>(new PlasmaDataReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaDataReply::UnPackTo(PlasmaDataReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_id(); if (_e) _o->object_id = _e->str(); }
+ { auto _e = object_size(); _o->object_size = _e; }
+ { auto _e = metadata_size(); _o->metadata_size = _e; }
+}
+
+inline flatbuffers::Offset<PlasmaDataReply> PlasmaDataReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDataReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaDataReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaDataReply> CreatePlasmaDataReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaDataReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaDataReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_id = _o->object_id.empty() ? 0 : _fbb.CreateString(_o->object_id);
+ auto _object_size = _o->object_size;
+ auto _metadata_size = _o->metadata_size;
+ return plasma::flatbuf::CreatePlasmaDataReply(
+ _fbb,
+ _object_id,
+ _object_size,
+ _metadata_size);
+}
+
+inline PlasmaRefreshLRURequestT *PlasmaRefreshLRURequest::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaRefreshLRURequestT> _o = std::unique_ptr<plasma::flatbuf::PlasmaRefreshLRURequestT>(new PlasmaRefreshLRURequestT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaRefreshLRURequest::UnPackTo(PlasmaRefreshLRURequestT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = object_ids(); if (_e) { _o->object_ids.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->object_ids[_i] = _e->Get(_i)->str(); } } }
+}
+
+inline flatbuffers::Offset<PlasmaRefreshLRURequest> PlasmaRefreshLRURequest::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaRefreshLRURequestT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaRefreshLRURequest(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaRefreshLRURequest> CreatePlasmaRefreshLRURequest(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaRefreshLRURequestT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaRefreshLRURequestT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _object_ids = _o->object_ids.size() ? _fbb.CreateVectorOfStrings(_o->object_ids) : 0;
+ return plasma::flatbuf::CreatePlasmaRefreshLRURequest(
+ _fbb,
+ _object_ids);
+}
+
+inline PlasmaRefreshLRUReplyT *PlasmaRefreshLRUReply::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ std::unique_ptr<plasma::flatbuf::PlasmaRefreshLRUReplyT> _o = std::unique_ptr<plasma::flatbuf::PlasmaRefreshLRUReplyT>(new PlasmaRefreshLRUReplyT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void PlasmaRefreshLRUReply::UnPackTo(PlasmaRefreshLRUReplyT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<PlasmaRefreshLRUReply> PlasmaRefreshLRUReply::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaRefreshLRUReplyT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePlasmaRefreshLRUReply(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PlasmaRefreshLRUReply> CreatePlasmaRefreshLRUReply(flatbuffers::FlatBufferBuilder &_fbb, const PlasmaRefreshLRUReplyT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PlasmaRefreshLRUReplyT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return plasma::flatbuf::CreatePlasmaRefreshLRUReply(
+ _fbb);
+}
+
+} // namespace flatbuf
+} // namespace plasma
+
+#endif // FLATBUFFERS_GENERATED_PLASMA_PLASMA_FLATBUF_H_
diff --git a/src/arrow/cpp/src/plasma/protocol.cc b/src/arrow/cpp/src/plasma/protocol.cc
new file mode 100644
index 000000000..735636cda
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/protocol.cc
@@ -0,0 +1,829 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "plasma/protocol.h"
+
+#include <utility>
+
+#include "flatbuffers/flatbuffers.h"
+#include "plasma/common.h"
+#include "plasma/io.h"
+#include "plasma/plasma_generated.h"
+
+#ifdef PLASMA_CUDA
+#include "arrow/gpu/cuda_api.h"
+#endif
+#include "arrow/util/ubsan.h"
+
+namespace fb = plasma::flatbuf;
+
+namespace plasma {
+
+using fb::MessageType;
+using fb::PlasmaError;
+using fb::PlasmaObjectSpec;
+
+using flatbuffers::uoffset_t;
+
+#define PLASMA_CHECK_ENUM(x, y) \
+ static_assert(static_cast<int>(x) == static_cast<int>(y), "protocol mismatch")
+
+flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
+ToFlatbuffer(flatbuffers::FlatBufferBuilder* fbb, const ObjectID* object_ids,
+ int64_t num_objects) {
+ std::vector<flatbuffers::Offset<flatbuffers::String>> results;
+ for (int64_t i = 0; i < num_objects; i++) {
+ results.push_back(fbb->CreateString(object_ids[i].binary()));
+ }
+ return fbb->CreateVector(arrow::util::MakeNonNull(results.data()), results.size());
+}
+
+flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
+ToFlatbuffer(flatbuffers::FlatBufferBuilder* fbb,
+ const std::vector<std::string>& strings) {
+ std::vector<flatbuffers::Offset<flatbuffers::String>> results;
+ for (size_t i = 0; i < strings.size(); i++) {
+ results.push_back(fbb->CreateString(strings[i]));
+ }
+
+ return fbb->CreateVector(arrow::util::MakeNonNull(results.data()), results.size());
+}
+
+flatbuffers::Offset<flatbuffers::Vector<int64_t>> ToFlatbuffer(
+ flatbuffers::FlatBufferBuilder* fbb, const std::vector<int64_t>& data) {
+ return fbb->CreateVector(arrow::util::MakeNonNull(data.data()), data.size());
+}
+
+Status PlasmaReceive(int sock, MessageType message_type, std::vector<uint8_t>* buffer) {
+ MessageType type;
+ RETURN_NOT_OK(ReadMessage(sock, &type, buffer));
+ ARROW_CHECK(type == message_type)
+ << "type = " << static_cast<int64_t>(type)
+ << ", message_type = " << static_cast<int64_t>(message_type);
+ return Status::OK();
+}
+
+// Helper function to create a vector of elements from Data (Request/Reply struct).
+// The Getter function is used to extract one element from Data.
+template <typename T, typename Data, typename Getter>
+void ToVector(const Data& request, std::vector<T>* out, const Getter& getter) {
+ int count = request.count();
+ out->clear();
+ out->reserve(count);
+ for (int i = 0; i < count; ++i) {
+ out->push_back(getter(request, i));
+ }
+}
+
+template <typename T, typename FlatbufferVectorPointer, typename Converter>
+void ConvertToVector(const FlatbufferVectorPointer fbvector, std::vector<T>* out,
+ const Converter& converter) {
+ out->clear();
+ out->reserve(fbvector->size());
+ for (size_t i = 0; i < fbvector->size(); ++i) {
+ out->push_back(converter(*fbvector->Get(i)));
+ }
+}
+
+template <typename Message>
+Status PlasmaSend(int sock, MessageType message_type, flatbuffers::FlatBufferBuilder* fbb,
+ const Message& message) {
+ fbb->Finish(message);
+ return WriteMessage(sock, message_type, fbb->GetSize(), fbb->GetBufferPointer());
+}
+
+Status PlasmaErrorStatus(fb::PlasmaError plasma_error) {
+ switch (plasma_error) {
+ case fb::PlasmaError::OK:
+ return Status::OK();
+ case fb::PlasmaError::ObjectExists:
+ return MakePlasmaError(PlasmaErrorCode::PlasmaObjectExists,
+ "object already exists in the plasma store");
+ case fb::PlasmaError::ObjectNotFound:
+ return MakePlasmaError(PlasmaErrorCode::PlasmaObjectNotFound,
+ "object does not exist in the plasma store");
+ case fb::PlasmaError::OutOfMemory:
+ return MakePlasmaError(PlasmaErrorCode::PlasmaStoreFull,
+ "object does not fit in the plasma store");
+ default:
+ ARROW_LOG(FATAL) << "unknown plasma error code " << static_cast<int>(plasma_error);
+ }
+ return Status::OK();
+}
+
+// Set options messages.
+
+Status SendSetOptionsRequest(int sock, const std::string& client_name,
+ int64_t output_memory_limit) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaSetOptionsRequest(fbb, fbb.CreateString(client_name),
+ output_memory_limit);
+ return PlasmaSend(sock, MessageType::PlasmaSetOptionsRequest, &fbb, message);
+}
+
+Status ReadSetOptionsRequest(const uint8_t* data, size_t size, std::string* client_name,
+ int64_t* output_memory_quota) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaSetOptionsRequest>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ *client_name = std::string(message->client_name()->str());
+ *output_memory_quota = message->output_memory_quota();
+ return Status::OK();
+}
+
+Status SendSetOptionsReply(int sock, PlasmaError error) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaSetOptionsReply(fbb, error);
+ return PlasmaSend(sock, MessageType::PlasmaSetOptionsReply, &fbb, message);
+}
+
+Status ReadSetOptionsReply(const uint8_t* data, size_t size) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaSetOptionsReply>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ return PlasmaErrorStatus(message->error());
+}
+
+// Get debug string messages.
+
+Status SendGetDebugStringRequest(int sock) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaGetDebugStringRequest(fbb);
+ return PlasmaSend(sock, MessageType::PlasmaGetDebugStringRequest, &fbb, message);
+}
+
+Status SendGetDebugStringReply(int sock, const std::string& debug_string) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaGetDebugStringReply(fbb, fbb.CreateString(debug_string));
+ return PlasmaSend(sock, MessageType::PlasmaGetDebugStringReply, &fbb, message);
+}
+
+Status ReadGetDebugStringReply(const uint8_t* data, size_t size,
+ std::string* debug_string) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaGetDebugStringReply>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ *debug_string = message->debug_string()->str();
+ return Status::OK();
+}
+
+// Create messages.
+
+Status SendCreateRequest(int sock, ObjectID object_id, bool evict_if_full,
+ int64_t data_size, int64_t metadata_size, int device_num) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message =
+ fb::CreatePlasmaCreateRequest(fbb, fbb.CreateString(object_id.binary()),
+ evict_if_full, data_size, metadata_size, device_num);
+ return PlasmaSend(sock, MessageType::PlasmaCreateRequest, &fbb, message);
+}
+
+Status ReadCreateRequest(const uint8_t* data, size_t size, ObjectID* object_id,
+ bool* evict_if_full, int64_t* data_size, int64_t* metadata_size,
+ int* device_num) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaCreateRequest>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ *evict_if_full = message->evict_if_full();
+ *data_size = message->data_size();
+ *metadata_size = message->metadata_size();
+ *object_id = ObjectID::from_binary(message->object_id()->str());
+ *device_num = message->device_num();
+ return Status::OK();
+}
+
+Status SendCreateReply(int sock, ObjectID object_id, PlasmaObject* object,
+ PlasmaError error_code, int64_t mmap_size) {
+ flatbuffers::FlatBufferBuilder fbb;
+ PlasmaObjectSpec plasma_object(object->store_fd, object->data_offset, object->data_size,
+ object->metadata_offset, object->metadata_size,
+ object->device_num);
+ auto object_string = fbb.CreateString(object_id.binary());
+#ifdef PLASMA_CUDA
+ flatbuffers::Offset<fb::CudaHandle> ipc_handle;
+ if (object->device_num != 0) {
+ std::shared_ptr<arrow::Buffer> handle;
+ ARROW_ASSIGN_OR_RAISE(handle, object->ipc_handle->Serialize());
+ ipc_handle =
+ fb::CreateCudaHandle(fbb, fbb.CreateVector(handle->data(), handle->size()));
+ }
+#endif
+ fb::PlasmaCreateReplyBuilder crb(fbb);
+ crb.add_error(static_cast<PlasmaError>(error_code));
+ crb.add_plasma_object(&plasma_object);
+ crb.add_object_id(object_string);
+ crb.add_store_fd(object->store_fd);
+ crb.add_mmap_size(mmap_size);
+ if (object->device_num != 0) {
+#ifdef PLASMA_CUDA
+ crb.add_ipc_handle(ipc_handle);
+#else
+ ARROW_LOG(FATAL) << "This should be unreachable.";
+#endif
+ }
+ auto message = crb.Finish();
+ return PlasmaSend(sock, MessageType::PlasmaCreateReply, &fbb, message);
+}
+
+Status ReadCreateReply(const uint8_t* data, size_t size, ObjectID* object_id,
+ PlasmaObject* object, int* store_fd, int64_t* mmap_size) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaCreateReply>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ *object_id = ObjectID::from_binary(message->object_id()->str());
+ object->store_fd = message->plasma_object()->segment_index();
+ object->data_offset = message->plasma_object()->data_offset();
+ object->data_size = message->plasma_object()->data_size();
+ object->metadata_offset = message->plasma_object()->metadata_offset();
+ object->metadata_size = message->plasma_object()->metadata_size();
+
+ *store_fd = message->store_fd();
+ *mmap_size = message->mmap_size();
+
+ object->device_num = message->plasma_object()->device_num();
+#ifdef PLASMA_CUDA
+ if (object->device_num != 0) {
+ ARROW_ASSIGN_OR_RAISE(
+ object->ipc_handle,
+ CudaIpcMemHandle::FromBuffer(message->ipc_handle()->handle()->data()));
+ }
+#endif
+ return PlasmaErrorStatus(message->error());
+}
+
+Status SendCreateAndSealRequest(int sock, const ObjectID& object_id, bool evict_if_full,
+ const std::string& data, const std::string& metadata,
+ unsigned char* digest) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto digest_string = fbb.CreateString(reinterpret_cast<char*>(digest), kDigestSize);
+ auto message = fb::CreatePlasmaCreateAndSealRequest(
+ fbb, fbb.CreateString(object_id.binary()), evict_if_full, fbb.CreateString(data),
+ fbb.CreateString(metadata), digest_string);
+ return PlasmaSend(sock, MessageType::PlasmaCreateAndSealRequest, &fbb, message);
+}
+
+Status ReadCreateAndSealRequest(const uint8_t* data, size_t size, ObjectID* object_id,
+ bool* evict_if_full, std::string* object_data,
+ std::string* metadata, std::string* digest) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaCreateAndSealRequest>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+
+ *object_id = ObjectID::from_binary(message->object_id()->str());
+ *evict_if_full = message->evict_if_full();
+ *object_data = message->data()->str();
+ *metadata = message->metadata()->str();
+ ARROW_CHECK(message->digest()->size() == kDigestSize);
+ digest->assign(message->digest()->data(), kDigestSize);
+ return Status::OK();
+}
+
+Status SendCreateAndSealBatchRequest(int sock, const std::vector<ObjectID>& object_ids,
+ bool evict_if_full,
+ const std::vector<std::string>& data,
+ const std::vector<std::string>& metadata,
+ const std::vector<std::string>& digests) {
+ flatbuffers::FlatBufferBuilder fbb;
+
+ auto message = fb::CreatePlasmaCreateAndSealBatchRequest(
+ fbb, ToFlatbuffer(&fbb, object_ids.data(), object_ids.size()), evict_if_full,
+ ToFlatbuffer(&fbb, data), ToFlatbuffer(&fbb, metadata),
+ ToFlatbuffer(&fbb, digests));
+
+ return PlasmaSend(sock, MessageType::PlasmaCreateAndSealBatchRequest, &fbb, message);
+}
+
+Status ReadCreateAndSealBatchRequest(const uint8_t* data, size_t size,
+ std::vector<ObjectID>* object_ids,
+ bool* evict_if_full,
+ std::vector<std::string>* object_data,
+ std::vector<std::string>* metadata,
+ std::vector<std::string>* digests) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaCreateAndSealBatchRequest>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+
+ *evict_if_full = message->evict_if_full();
+ ConvertToVector(message->object_ids(), object_ids,
+ [](const flatbuffers::String& element) {
+ return ObjectID::from_binary(element.str());
+ });
+
+ ConvertToVector(message->data(), object_data,
+ [](const flatbuffers::String& element) { return element.str(); });
+
+ ConvertToVector(message->metadata(), metadata,
+ [](const flatbuffers::String& element) { return element.str(); });
+
+ ConvertToVector(message->digest(), digests,
+ [](const flatbuffers::String& element) { return element.str(); });
+
+ return Status::OK();
+}
+
+Status SendCreateAndSealReply(int sock, PlasmaError error) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaCreateAndSealReply(fbb, static_cast<PlasmaError>(error));
+ return PlasmaSend(sock, MessageType::PlasmaCreateAndSealReply, &fbb, message);
+}
+
+Status ReadCreateAndSealReply(const uint8_t* data, size_t size) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaCreateAndSealReply>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ return PlasmaErrorStatus(message->error());
+}
+
+Status SendCreateAndSealBatchReply(int sock, PlasmaError error) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message =
+ fb::CreatePlasmaCreateAndSealBatchReply(fbb, static_cast<PlasmaError>(error));
+ return PlasmaSend(sock, MessageType::PlasmaCreateAndSealBatchReply, &fbb, message);
+}
+
+Status ReadCreateAndSealBatchReply(const uint8_t* data, size_t size) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaCreateAndSealBatchReply>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ return PlasmaErrorStatus(message->error());
+}
+
+Status SendAbortRequest(int sock, ObjectID object_id) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaAbortRequest(fbb, fbb.CreateString(object_id.binary()));
+ return PlasmaSend(sock, MessageType::PlasmaAbortRequest, &fbb, message);
+}
+
+Status ReadAbortRequest(const uint8_t* data, size_t size, ObjectID* object_id) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaAbortRequest>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ *object_id = ObjectID::from_binary(message->object_id()->str());
+ return Status::OK();
+}
+
+Status SendAbortReply(int sock, ObjectID object_id) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaAbortReply(fbb, fbb.CreateString(object_id.binary()));
+ return PlasmaSend(sock, MessageType::PlasmaAbortReply, &fbb, message);
+}
+
+Status ReadAbortReply(const uint8_t* data, size_t size, ObjectID* object_id) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaAbortReply>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ *object_id = ObjectID::from_binary(message->object_id()->str());
+ return Status::OK();
+}
+
+// Seal messages.
+
+Status SendSealRequest(int sock, ObjectID object_id, const std::string& digest) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaSealRequest(fbb, fbb.CreateString(object_id.binary()),
+ fbb.CreateString(digest));
+ return PlasmaSend(sock, MessageType::PlasmaSealRequest, &fbb, message);
+}
+
+Status ReadSealRequest(const uint8_t* data, size_t size, ObjectID* object_id,
+ std::string* digest) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaSealRequest>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ *object_id = ObjectID::from_binary(message->object_id()->str());
+ ARROW_CHECK_EQ(message->digest()->size(), kDigestSize);
+ digest->assign(message->digest()->data(), kDigestSize);
+ return Status::OK();
+}
+
+Status SendSealReply(int sock, ObjectID object_id, PlasmaError error) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message =
+ fb::CreatePlasmaSealReply(fbb, fbb.CreateString(object_id.binary()), error);
+ return PlasmaSend(sock, MessageType::PlasmaSealReply, &fbb, message);
+}
+
+Status ReadSealReply(const uint8_t* data, size_t size, ObjectID* object_id) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaSealReply>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ *object_id = ObjectID::from_binary(message->object_id()->str());
+ return PlasmaErrorStatus(message->error());
+}
+
+// Release messages.
+
+Status SendReleaseRequest(int sock, ObjectID object_id) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message =
+ fb::CreatePlasmaReleaseRequest(fbb, fbb.CreateString(object_id.binary()));
+ return PlasmaSend(sock, MessageType::PlasmaReleaseRequest, &fbb, message);
+}
+
+Status ReadReleaseRequest(const uint8_t* data, size_t size, ObjectID* object_id) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaReleaseRequest>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ *object_id = ObjectID::from_binary(message->object_id()->str());
+ return Status::OK();
+}
+
+Status SendReleaseReply(int sock, ObjectID object_id, PlasmaError error) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message =
+ fb::CreatePlasmaReleaseReply(fbb, fbb.CreateString(object_id.binary()), error);
+ return PlasmaSend(sock, MessageType::PlasmaReleaseReply, &fbb, message);
+}
+
+Status ReadReleaseReply(const uint8_t* data, size_t size, ObjectID* object_id) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaReleaseReply>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ *object_id = ObjectID::from_binary(message->object_id()->str());
+ return PlasmaErrorStatus(message->error());
+}
+
+// Delete objects messages.
+
+Status SendDeleteRequest(int sock, const std::vector<ObjectID>& object_ids) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaDeleteRequest(
+ fbb, static_cast<int32_t>(object_ids.size()),
+ ToFlatbuffer(&fbb, &object_ids[0], object_ids.size()));
+ return PlasmaSend(sock, MessageType::PlasmaDeleteRequest, &fbb, message);
+}
+
+Status ReadDeleteRequest(const uint8_t* data, size_t size,
+ std::vector<ObjectID>* object_ids) {
+ using fb::PlasmaDeleteRequest;
+
+ DCHECK(data);
+ DCHECK(object_ids);
+ auto message = flatbuffers::GetRoot<PlasmaDeleteRequest>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ ToVector(*message, object_ids, [](const PlasmaDeleteRequest& request, int i) {
+ return ObjectID::from_binary(request.object_ids()->Get(i)->str());
+ });
+ return Status::OK();
+}
+
+Status SendDeleteReply(int sock, const std::vector<ObjectID>& object_ids,
+ const std::vector<PlasmaError>& errors) {
+ DCHECK(object_ids.size() == errors.size());
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaDeleteReply(
+ fbb, static_cast<int32_t>(object_ids.size()),
+ ToFlatbuffer(&fbb, &object_ids[0], object_ids.size()),
+ fbb.CreateVector(
+ arrow::util::MakeNonNull(reinterpret_cast<const int32_t*>(errors.data())),
+ object_ids.size()));
+ return PlasmaSend(sock, MessageType::PlasmaDeleteReply, &fbb, message);
+}
+
+Status ReadDeleteReply(const uint8_t* data, size_t size,
+ std::vector<ObjectID>* object_ids,
+ std::vector<PlasmaError>* errors) {
+ using fb::PlasmaDeleteReply;
+
+ DCHECK(data);
+ DCHECK(object_ids);
+ DCHECK(errors);
+ auto message = flatbuffers::GetRoot<PlasmaDeleteReply>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ ToVector(*message, object_ids, [](const PlasmaDeleteReply& request, int i) {
+ return ObjectID::from_binary(request.object_ids()->Get(i)->str());
+ });
+ ToVector(*message, errors, [](const PlasmaDeleteReply& request, int i) {
+ return static_cast<PlasmaError>(request.errors()->Get(i));
+ });
+ return Status::OK();
+}
+
+// Contains messages.
+
+Status SendContainsRequest(int sock, ObjectID object_id) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message =
+ fb::CreatePlasmaContainsRequest(fbb, fbb.CreateString(object_id.binary()));
+ return PlasmaSend(sock, MessageType::PlasmaContainsRequest, &fbb, message);
+}
+
+Status ReadContainsRequest(const uint8_t* data, size_t size, ObjectID* object_id) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaContainsRequest>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ *object_id = ObjectID::from_binary(message->object_id()->str());
+ return Status::OK();
+}
+
+Status SendContainsReply(int sock, ObjectID object_id, bool has_object) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaContainsReply(fbb, fbb.CreateString(object_id.binary()),
+ has_object);
+ return PlasmaSend(sock, MessageType::PlasmaContainsReply, &fbb, message);
+}
+
+Status ReadContainsReply(const uint8_t* data, size_t size, ObjectID* object_id,
+ bool* has_object) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaContainsReply>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ *object_id = ObjectID::from_binary(message->object_id()->str());
+ *has_object = message->has_object();
+ return Status::OK();
+}
+
+// List messages.
+
+Status SendListRequest(int sock) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaListRequest(fbb);
+ return PlasmaSend(sock, MessageType::PlasmaListRequest, &fbb, message);
+}
+
+Status ReadListRequest(const uint8_t* data, size_t size) { return Status::OK(); }
+
+Status SendListReply(int sock, const ObjectTable& objects) {
+ flatbuffers::FlatBufferBuilder fbb;
+ std::vector<flatbuffers::Offset<fb::ObjectInfo>> object_infos;
+ for (auto const& entry : objects) {
+ auto digest = entry.second->state == ObjectState::PLASMA_CREATED
+ ? fbb.CreateString("")
+ : fbb.CreateString(reinterpret_cast<char*>(entry.second->digest),
+ kDigestSize);
+ auto info = fb::CreateObjectInfo(fbb, fbb.CreateString(entry.first.binary()),
+ entry.second->data_size, entry.second->metadata_size,
+ entry.second->ref_count, entry.second->create_time,
+ entry.second->construct_duration, digest);
+ object_infos.push_back(info);
+ }
+ auto message = fb::CreatePlasmaListReply(
+ fbb, fbb.CreateVector(arrow::util::MakeNonNull(object_infos.data()),
+ object_infos.size()));
+ return PlasmaSend(sock, MessageType::PlasmaListReply, &fbb, message);
+}
+
+Status ReadListReply(const uint8_t* data, size_t size, ObjectTable* objects) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaListReply>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ for (auto const object : *message->objects()) {
+ ObjectID object_id = ObjectID::from_binary(object->object_id()->str());
+ auto entry = std::unique_ptr<ObjectTableEntry>(new ObjectTableEntry());
+ entry->data_size = object->data_size();
+ entry->metadata_size = object->metadata_size();
+ entry->ref_count = object->ref_count();
+ entry->create_time = object->create_time();
+ entry->construct_duration = object->construct_duration();
+ entry->state = object->digest()->size() == 0 ? ObjectState::PLASMA_CREATED
+ : ObjectState::PLASMA_SEALED;
+ (*objects)[object_id] = std::move(entry);
+ }
+ return Status::OK();
+}
+
+// Connect messages.
+
+Status SendConnectRequest(int sock) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaConnectRequest(fbb);
+ return PlasmaSend(sock, MessageType::PlasmaConnectRequest, &fbb, message);
+}
+
+Status ReadConnectRequest(const uint8_t* data) { return Status::OK(); }
+
+Status SendConnectReply(int sock, int64_t memory_capacity) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaConnectReply(fbb, memory_capacity);
+ return PlasmaSend(sock, MessageType::PlasmaConnectReply, &fbb, message);
+}
+
+Status ReadConnectReply(const uint8_t* data, size_t size, int64_t* memory_capacity) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaConnectReply>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ *memory_capacity = message->memory_capacity();
+ return Status::OK();
+}
+
+// Evict messages.
+
+Status SendEvictRequest(int sock, int64_t num_bytes) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaEvictRequest(fbb, num_bytes);
+ return PlasmaSend(sock, MessageType::PlasmaEvictRequest, &fbb, message);
+}
+
+Status ReadEvictRequest(const uint8_t* data, size_t size, int64_t* num_bytes) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaEvictRequest>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ *num_bytes = message->num_bytes();
+ return Status::OK();
+}
+
+Status SendEvictReply(int sock, int64_t num_bytes) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaEvictReply(fbb, num_bytes);
+ return PlasmaSend(sock, MessageType::PlasmaEvictReply, &fbb, message);
+}
+
+Status ReadEvictReply(const uint8_t* data, size_t size, int64_t& num_bytes) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaEvictReply>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ num_bytes = message->num_bytes();
+ return Status::OK();
+}
+
+// Get messages.
+
+Status SendGetRequest(int sock, const ObjectID* object_ids, int64_t num_objects,
+ int64_t timeout_ms) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaGetRequest(
+ fbb, ToFlatbuffer(&fbb, object_ids, num_objects), timeout_ms);
+ return PlasmaSend(sock, MessageType::PlasmaGetRequest, &fbb, message);
+}
+
+Status ReadGetRequest(const uint8_t* data, size_t size, std::vector<ObjectID>& object_ids,
+ int64_t* timeout_ms) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaGetRequest>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ for (uoffset_t i = 0; i < message->object_ids()->size(); ++i) {
+ auto object_id = message->object_ids()->Get(i)->str();
+ object_ids.push_back(ObjectID::from_binary(object_id));
+ }
+ *timeout_ms = message->timeout_ms();
+ return Status::OK();
+}
+
+Status SendGetReply(int sock, ObjectID object_ids[],
+ std::unordered_map<ObjectID, PlasmaObject>& plasma_objects,
+ int64_t num_objects, const std::vector<int>& store_fds,
+ const std::vector<int64_t>& mmap_sizes) {
+ flatbuffers::FlatBufferBuilder fbb;
+ std::vector<PlasmaObjectSpec> objects;
+
+ std::vector<flatbuffers::Offset<fb::CudaHandle>> handles;
+ for (int64_t i = 0; i < num_objects; ++i) {
+ const PlasmaObject& object = plasma_objects[object_ids[i]];
+ objects.push_back(PlasmaObjectSpec(object.store_fd, object.data_offset,
+ object.data_size, object.metadata_offset,
+ object.metadata_size, object.device_num));
+#ifdef PLASMA_CUDA
+ if (object.device_num != 0) {
+ std::shared_ptr<arrow::Buffer> handle;
+ ARROW_ASSIGN_OR_RAISE(handle, object.ipc_handle->Serialize());
+ handles.push_back(
+ fb::CreateCudaHandle(fbb, fbb.CreateVector(handle->data(), handle->size())));
+ }
+#endif
+ }
+ auto message = fb::CreatePlasmaGetReply(
+ fbb, ToFlatbuffer(&fbb, object_ids, num_objects),
+ fbb.CreateVectorOfStructs(arrow::util::MakeNonNull(objects.data()), num_objects),
+ fbb.CreateVector(arrow::util::MakeNonNull(store_fds.data()), store_fds.size()),
+ fbb.CreateVector(arrow::util::MakeNonNull(mmap_sizes.data()), mmap_sizes.size()),
+ fbb.CreateVector(arrow::util::MakeNonNull(handles.data()), handles.size()));
+ return PlasmaSend(sock, MessageType::PlasmaGetReply, &fbb, message);
+}
+
+Status ReadGetReply(const uint8_t* data, size_t size, ObjectID object_ids[],
+ PlasmaObject plasma_objects[], int64_t num_objects,
+ std::vector<int>& store_fds, std::vector<int64_t>& mmap_sizes) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaGetReply>(data);
+#ifdef PLASMA_CUDA
+ int handle_pos = 0;
+#endif
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ for (uoffset_t i = 0; i < num_objects; ++i) {
+ object_ids[i] = ObjectID::from_binary(message->object_ids()->Get(i)->str());
+ }
+ for (uoffset_t i = 0; i < num_objects; ++i) {
+ const PlasmaObjectSpec* object = message->plasma_objects()->Get(i);
+ plasma_objects[i].store_fd = object->segment_index();
+ plasma_objects[i].data_offset = object->data_offset();
+ plasma_objects[i].data_size = object->data_size();
+ plasma_objects[i].metadata_offset = object->metadata_offset();
+ plasma_objects[i].metadata_size = object->metadata_size();
+ plasma_objects[i].device_num = object->device_num();
+#ifdef PLASMA_CUDA
+ if (object->device_num() != 0) {
+ const void* ipc_handle = message->handles()->Get(handle_pos)->handle()->data();
+ ARROW_ASSIGN_OR_RAISE(plasma_objects[i].ipc_handle,
+ CudaIpcMemHandle::FromBuffer(ipc_handle));
+ handle_pos++;
+ }
+#endif
+ }
+ ARROW_CHECK(message->store_fds()->size() == message->mmap_sizes()->size());
+ for (uoffset_t i = 0; i < message->store_fds()->size(); i++) {
+ store_fds.push_back(message->store_fds()->Get(i));
+ mmap_sizes.push_back(message->mmap_sizes()->Get(i));
+ }
+ return Status::OK();
+}
+
+// Subscribe messages.
+
+Status SendSubscribeRequest(int sock) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaSubscribeRequest(fbb);
+ return PlasmaSend(sock, MessageType::PlasmaSubscribeRequest, &fbb, message);
+}
+
+// Data messages.
+
+Status SendDataRequest(int sock, ObjectID object_id, const char* address, int port) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto addr = fbb.CreateString(address, strlen(address));
+ auto message =
+ fb::CreatePlasmaDataRequest(fbb, fbb.CreateString(object_id.binary()), addr, port);
+ return PlasmaSend(sock, MessageType::PlasmaDataRequest, &fbb, message);
+}
+
+Status ReadDataRequest(const uint8_t* data, size_t size, ObjectID* object_id,
+ char** address, int* port) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaDataRequest>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ DCHECK(message->object_id()->size() == sizeof(ObjectID));
+ *object_id = ObjectID::from_binary(message->object_id()->str());
+ *address = strdup(message->address()->c_str());
+ *port = message->port();
+ return Status::OK();
+}
+
+Status SendDataReply(int sock, ObjectID object_id, int64_t object_size,
+ int64_t metadata_size) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaDataReply(fbb, fbb.CreateString(object_id.binary()),
+ object_size, metadata_size);
+ return PlasmaSend(sock, MessageType::PlasmaDataReply, &fbb, message);
+}
+
+Status ReadDataReply(const uint8_t* data, size_t size, ObjectID* object_id,
+ int64_t* object_size, int64_t* metadata_size) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaDataReply>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ *object_id = ObjectID::from_binary(message->object_id()->str());
+ *object_size = static_cast<int64_t>(message->object_size());
+ *metadata_size = static_cast<int64_t>(message->metadata_size());
+ return Status::OK();
+}
+
+// RefreshLRU messages.
+
+Status SendRefreshLRURequest(int sock, const std::vector<ObjectID>& object_ids) {
+ flatbuffers::FlatBufferBuilder fbb;
+
+ auto message = fb::CreatePlasmaRefreshLRURequest(
+ fbb, ToFlatbuffer(&fbb, object_ids.data(), object_ids.size()));
+
+ return PlasmaSend(sock, MessageType::PlasmaRefreshLRURequest, &fbb, message);
+}
+
+Status ReadRefreshLRURequest(const uint8_t* data, size_t size,
+ std::vector<ObjectID>* object_ids) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaRefreshLRURequest>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ for (uoffset_t i = 0; i < message->object_ids()->size(); ++i) {
+ auto object_id = message->object_ids()->Get(i)->str();
+ object_ids->push_back(ObjectID::from_binary(object_id));
+ }
+ return Status::OK();
+}
+
+Status SendRefreshLRUReply(int sock) {
+ flatbuffers::FlatBufferBuilder fbb;
+ auto message = fb::CreatePlasmaRefreshLRUReply(fbb);
+ return PlasmaSend(sock, MessageType::PlasmaRefreshLRUReply, &fbb, message);
+}
+
+Status ReadRefreshLRUReply(const uint8_t* data, size_t size) {
+ DCHECK(data);
+ auto message = flatbuffers::GetRoot<fb::PlasmaRefreshLRUReply>(data);
+ DCHECK(VerifyFlatbuffer(message, data, size));
+ return Status::OK();
+}
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/protocol.h b/src/arrow/cpp/src/plasma/protocol.h
new file mode 100644
index 000000000..31257be47
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/protocol.h
@@ -0,0 +1,251 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "arrow/status.h"
+#include "plasma/plasma.h"
+#include "plasma/plasma_generated.h"
+
+namespace plasma {
+
+using arrow::Status;
+
+using flatbuf::MessageType;
+using flatbuf::PlasmaError;
+
+template <class T>
+bool VerifyFlatbuffer(T* object, const uint8_t* data, size_t size) {
+ flatbuffers::Verifier verifier(data, size);
+ return object->Verify(verifier);
+}
+
+flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
+ToFlatbuffer(flatbuffers::FlatBufferBuilder* fbb, const ObjectID* object_ids,
+ int64_t num_objects);
+
+flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
+ToFlatbuffer(flatbuffers::FlatBufferBuilder* fbb,
+ const std::vector<std::string>& strings);
+
+flatbuffers::Offset<flatbuffers::Vector<int64_t>> ToFlatbuffer(
+ flatbuffers::FlatBufferBuilder* fbb, const std::vector<int64_t>& data);
+
+/* Plasma receive message. */
+
+Status PlasmaReceive(int sock, MessageType message_type, std::vector<uint8_t>* buffer);
+
+/* Set options messages. */
+
+Status SendSetOptionsRequest(int sock, const std::string& client_name,
+ int64_t output_memory_limit);
+
+Status ReadSetOptionsRequest(const uint8_t* data, size_t size, std::string* client_name,
+ int64_t* output_memory_quota);
+
+Status SendSetOptionsReply(int sock, PlasmaError error);
+
+Status ReadSetOptionsReply(const uint8_t* data, size_t size);
+
+/* Debug string messages. */
+
+Status SendGetDebugStringRequest(int sock);
+
+Status SendGetDebugStringReply(int sock, const std::string& debug_string);
+
+Status ReadGetDebugStringReply(const uint8_t* data, size_t size,
+ std::string* debug_string);
+
+/* Plasma Create message functions. */
+
+Status SendCreateRequest(int sock, ObjectID object_id, bool evict_if_full,
+ int64_t data_size, int64_t metadata_size, int device_num);
+
+Status ReadCreateRequest(const uint8_t* data, size_t size, ObjectID* object_id,
+ bool* evict_if_full, int64_t* data_size, int64_t* metadata_size,
+ int* device_num);
+
+Status SendCreateReply(int sock, ObjectID object_id, PlasmaObject* object,
+ PlasmaError error, int64_t mmap_size);
+
+Status ReadCreateReply(const uint8_t* data, size_t size, ObjectID* object_id,
+ PlasmaObject* object, int* store_fd, int64_t* mmap_size);
+
+Status SendCreateAndSealRequest(int sock, const ObjectID& object_id, bool evict_if_full,
+ const std::string& data, const std::string& metadata,
+ unsigned char* digest);
+
+Status ReadCreateAndSealRequest(const uint8_t* data, size_t size, ObjectID* object_id,
+ bool* evict_if_full, std::string* object_data,
+ std::string* metadata, std::string* digest);
+
+Status SendCreateAndSealBatchRequest(int sock, const std::vector<ObjectID>& object_ids,
+ bool evict_if_full,
+ const std::vector<std::string>& data,
+ const std::vector<std::string>& metadata,
+ const std::vector<std::string>& digests);
+
+Status ReadCreateAndSealBatchRequest(const uint8_t* data, size_t size,
+ std::vector<ObjectID>* object_id,
+ bool* evict_if_full,
+ std::vector<std::string>* object_data,
+ std::vector<std::string>* metadata,
+ std::vector<std::string>* digests);
+
+Status SendCreateAndSealReply(int sock, PlasmaError error);
+
+Status ReadCreateAndSealReply(const uint8_t* data, size_t size);
+
+Status SendCreateAndSealBatchReply(int sock, PlasmaError error);
+
+Status ReadCreateAndSealBatchReply(const uint8_t* data, size_t size);
+
+Status SendAbortRequest(int sock, ObjectID object_id);
+
+Status ReadAbortRequest(const uint8_t* data, size_t size, ObjectID* object_id);
+
+Status SendAbortReply(int sock, ObjectID object_id);
+
+Status ReadAbortReply(const uint8_t* data, size_t size, ObjectID* object_id);
+
+/* Plasma Seal message functions. */
+
+Status SendSealRequest(int sock, ObjectID object_id, const std::string& digest);
+
+Status ReadSealRequest(const uint8_t* data, size_t size, ObjectID* object_id,
+ std::string* digest);
+
+Status SendSealReply(int sock, ObjectID object_id, PlasmaError error);
+
+Status ReadSealReply(const uint8_t* data, size_t size, ObjectID* object_id);
+
+/* Plasma Get message functions. */
+
+Status SendGetRequest(int sock, const ObjectID* object_ids, int64_t num_objects,
+ int64_t timeout_ms);
+
+Status ReadGetRequest(const uint8_t* data, size_t size, std::vector<ObjectID>& object_ids,
+ int64_t* timeout_ms);
+
+Status SendGetReply(int sock, ObjectID object_ids[],
+ std::unordered_map<ObjectID, PlasmaObject>& plasma_objects,
+ int64_t num_objects, const std::vector<int>& store_fds,
+ const std::vector<int64_t>& mmap_sizes);
+
+Status ReadGetReply(const uint8_t* data, size_t size, ObjectID object_ids[],
+ PlasmaObject plasma_objects[], int64_t num_objects,
+ std::vector<int>& store_fds, std::vector<int64_t>& mmap_sizes);
+
+/* Plasma Release message functions. */
+
+Status SendReleaseRequest(int sock, ObjectID object_id);
+
+Status ReadReleaseRequest(const uint8_t* data, size_t size, ObjectID* object_id);
+
+Status SendReleaseReply(int sock, ObjectID object_id, PlasmaError error);
+
+Status ReadReleaseReply(const uint8_t* data, size_t size, ObjectID* object_id);
+
+/* Plasma Delete objects message functions. */
+
+Status SendDeleteRequest(int sock, const std::vector<ObjectID>& object_ids);
+
+Status ReadDeleteRequest(const uint8_t* data, size_t size,
+ std::vector<ObjectID>* object_ids);
+
+Status SendDeleteReply(int sock, const std::vector<ObjectID>& object_ids,
+ const std::vector<PlasmaError>& errors);
+
+Status ReadDeleteReply(const uint8_t* data, size_t size,
+ std::vector<ObjectID>* object_ids,
+ std::vector<PlasmaError>* errors);
+
+/* Plasma Contains message functions. */
+
+Status SendContainsRequest(int sock, ObjectID object_id);
+
+Status ReadContainsRequest(const uint8_t* data, size_t size, ObjectID* object_id);
+
+Status SendContainsReply(int sock, ObjectID object_id, bool has_object);
+
+Status ReadContainsReply(const uint8_t* data, size_t size, ObjectID* object_id,
+ bool* has_object);
+
+/* Plasma List message functions. */
+
+Status SendListRequest(int sock);
+
+Status ReadListRequest(const uint8_t* data, size_t size);
+
+Status SendListReply(int sock, const ObjectTable& objects);
+
+Status ReadListReply(const uint8_t* data, size_t size, ObjectTable* objects);
+
+/* Plasma Connect message functions. */
+
+Status SendConnectRequest(int sock);
+
+Status ReadConnectRequest(const uint8_t* data, size_t size);
+
+Status SendConnectReply(int sock, int64_t memory_capacity);
+
+Status ReadConnectReply(const uint8_t* data, size_t size, int64_t* memory_capacity);
+
+/* Plasma Evict message functions (no reply so far). */
+
+Status SendEvictRequest(int sock, int64_t num_bytes);
+
+Status ReadEvictRequest(const uint8_t* data, size_t size, int64_t* num_bytes);
+
+Status SendEvictReply(int sock, int64_t num_bytes);
+
+Status ReadEvictReply(const uint8_t* data, size_t size, int64_t& num_bytes);
+
+/* Plasma Subscribe message functions. */
+
+Status SendSubscribeRequest(int sock);
+
+/* Data messages. */
+
+Status SendDataRequest(int sock, ObjectID object_id, const char* address, int port);
+
+Status ReadDataRequest(const uint8_t* data, size_t size, ObjectID* object_id,
+ char** address, int* port);
+
+Status SendDataReply(int sock, ObjectID object_id, int64_t object_size,
+ int64_t metadata_size);
+
+Status ReadDataReply(const uint8_t* data, size_t size, ObjectID* object_id,
+ int64_t* object_size, int64_t* metadata_size);
+
+/* Plasma refresh LRU cache functions. */
+
+Status SendRefreshLRURequest(int sock, const std::vector<ObjectID>& object_ids);
+
+Status ReadRefreshLRURequest(const uint8_t* data, size_t size,
+ std::vector<ObjectID>* object_ids);
+
+Status SendRefreshLRUReply(int sock);
+
+Status ReadRefreshLRUReply(const uint8_t* data, size_t size);
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/quota_aware_policy.cc b/src/arrow/cpp/src/plasma/quota_aware_policy.cc
new file mode 100644
index 000000000..67c4e9248
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/quota_aware_policy.cc
@@ -0,0 +1,177 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "plasma/quota_aware_policy.h"
+#include "plasma/common.h"
+#include "plasma/plasma_allocator.h"
+
+#include <algorithm>
+#include <memory>
+#include <sstream>
+
+namespace plasma {
+
+QuotaAwarePolicy::QuotaAwarePolicy(PlasmaStoreInfo* store_info, int64_t max_size)
+ : EvictionPolicy(store_info, max_size) {}
+
+bool QuotaAwarePolicy::HasQuota(Client* client, bool is_create) {
+ if (!is_create) {
+ return false; // no quota enforcement on read requests yet
+ }
+ return per_client_cache_.find(client) != per_client_cache_.end();
+}
+
+void QuotaAwarePolicy::ObjectCreated(const ObjectID& object_id, Client* client,
+ bool is_create) {
+ if (HasQuota(client, is_create)) {
+ per_client_cache_[client]->Add(object_id, GetObjectSize(object_id));
+ owned_by_client_[object_id] = client;
+ } else {
+ EvictionPolicy::ObjectCreated(object_id, client, is_create);
+ }
+}
+
+bool QuotaAwarePolicy::SetClientQuota(Client* client, int64_t output_memory_quota) {
+ if (per_client_cache_.find(client) != per_client_cache_.end()) {
+ ARROW_LOG(WARNING) << "Cannot change the client quota once set";
+ return false;
+ }
+
+ if (cache_.Capacity() - output_memory_quota <
+ cache_.OriginalCapacity() * kGlobalLruReserveFraction) {
+ ARROW_LOG(WARNING) << "Not enough memory to set client quota: " << DebugString();
+ return false;
+ }
+
+ // those objects will be lazily evicted on the next call
+ cache_.AdjustCapacity(-output_memory_quota);
+ per_client_cache_[client] =
+ std::unique_ptr<LRUCache>(new LRUCache(client->name, output_memory_quota));
+ return true;
+}
+
+bool QuotaAwarePolicy::EnforcePerClientQuota(Client* client, int64_t size, bool is_create,
+ std::vector<ObjectID>* objects_to_evict) {
+ if (!HasQuota(client, is_create)) {
+ return true;
+ }
+
+ auto& client_cache = per_client_cache_[client];
+ if (size > client_cache->Capacity()) {
+ ARROW_LOG(WARNING) << "object too large (" << size
+ << " bytes) to fit in client quota " << client_cache->Capacity()
+ << " " << DebugString();
+ return false;
+ }
+
+ if (client_cache->RemainingCapacity() >= size) {
+ return true;
+ }
+
+ int64_t space_to_free = size - client_cache->RemainingCapacity();
+ if (space_to_free > 0) {
+ std::vector<ObjectID> candidates;
+ client_cache->ChooseObjectsToEvict(space_to_free, &candidates);
+ for (ObjectID& object_id : candidates) {
+ if (shared_for_read_.count(object_id)) {
+ // Pinned so we can't evict it, so demote the object to global LRU instead.
+ // We an do this by simply removing it from all data structures, so that
+ // the next EndObjectAccess() will add it back to global LRU.
+ shared_for_read_.erase(object_id);
+ } else {
+ objects_to_evict->push_back(object_id);
+ }
+ owned_by_client_.erase(object_id);
+ client_cache->Remove(object_id);
+ }
+ }
+ return true;
+}
+
+void QuotaAwarePolicy::BeginObjectAccess(const ObjectID& object_id) {
+ if (owned_by_client_.find(object_id) != owned_by_client_.end()) {
+ shared_for_read_.insert(object_id);
+ pinned_memory_bytes_ += GetObjectSize(object_id);
+ return;
+ }
+ EvictionPolicy::BeginObjectAccess(object_id);
+}
+
+void QuotaAwarePolicy::EndObjectAccess(const ObjectID& object_id) {
+ if (owned_by_client_.find(object_id) != owned_by_client_.end()) {
+ shared_for_read_.erase(object_id);
+ pinned_memory_bytes_ -= GetObjectSize(object_id);
+ return;
+ }
+ EvictionPolicy::EndObjectAccess(object_id);
+}
+
+void QuotaAwarePolicy::RemoveObject(const ObjectID& object_id) {
+ if (owned_by_client_.find(object_id) != owned_by_client_.end()) {
+ per_client_cache_[owned_by_client_[object_id]]->Remove(object_id);
+ owned_by_client_.erase(object_id);
+ shared_for_read_.erase(object_id);
+ return;
+ }
+ EvictionPolicy::RemoveObject(object_id);
+}
+
+void QuotaAwarePolicy::RefreshObjects(const std::vector<ObjectID>& object_ids) {
+ for (const auto& object_id : object_ids) {
+ if (owned_by_client_.find(object_id) != owned_by_client_.end()) {
+ int64_t size = per_client_cache_[owned_by_client_[object_id]]->Remove(object_id);
+ per_client_cache_[owned_by_client_[object_id]]->Add(object_id, size);
+ }
+ }
+ EvictionPolicy::RefreshObjects(object_ids);
+}
+
+void QuotaAwarePolicy::ClientDisconnected(Client* client) {
+ if (per_client_cache_.find(client) == per_client_cache_.end()) {
+ return;
+ }
+ // return capacity back to global LRU
+ cache_.AdjustCapacity(per_client_cache_[client]->Capacity());
+ // clean up any entries used to track this client's quota usage
+ per_client_cache_[client]->Foreach([this](const ObjectID& obj) {
+ if (!shared_for_read_.count(obj)) {
+ // only add it to the global LRU if we have it in pinned mode
+ // otherwise, EndObjectAccess will add it later
+ cache_.Add(obj, GetObjectSize(obj));
+ }
+ owned_by_client_.erase(obj);
+ shared_for_read_.erase(obj);
+ });
+ per_client_cache_.erase(client);
+}
+
+std::string QuotaAwarePolicy::DebugString() const {
+ std::stringstream result;
+ result << "num clients with quota: " << per_client_cache_.size();
+ result << "\nquota map size: " << owned_by_client_.size();
+ result << "\npinned quota map size: " << shared_for_read_.size();
+ result << "\nallocated bytes: " << PlasmaAllocator::Allocated();
+ result << "\nallocation limit: " << PlasmaAllocator::GetFootprintLimit();
+ result << "\npinned bytes: " << pinned_memory_bytes_;
+ result << cache_.DebugString();
+ for (const auto& pair : per_client_cache_) {
+ result << pair.second->DebugString();
+ }
+ return result.str();
+}
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/quota_aware_policy.h b/src/arrow/cpp/src/plasma/quota_aware_policy.h
new file mode 100644
index 000000000..9bb7dbccc
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/quota_aware_policy.h
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <list>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "plasma/common.h"
+#include "plasma/eviction_policy.h"
+#include "plasma/plasma.h"
+
+namespace plasma {
+
+/// Reserve this fraction of memory for shared usage. Attempts to set client
+/// quotas that would cause the global LRU memory fraction to fall below this
+/// value will be rejected.
+constexpr double kGlobalLruReserveFraction = 0.3;
+
+/// Extends the basic eviction policy to implement per-client memory quotas.
+/// This effectively gives each client its own LRU queue, which caps its
+/// memory usage and protects this memory from being evicted by other clients.
+///
+/// The quotas are enforced when objects are first created, by evicting the
+/// necessary number of objects from the client's own LRU queue to cap its
+/// memory usage. Once that is done, allocation is handled by the normal
+/// eviction policy. This may result in the eviction of objects from the
+/// global LRU queue, if not enough memory can be allocated even after the
+/// evictions from the client's own LRU queue.
+///
+/// Some special cases:
+/// - When a pinned object is "evicted" from a per-client queue, it is
+/// instead transferred into the global LRU queue.
+/// - When a client disconnects, its LRU queue is merged into the head of the
+/// global LRU queue.
+class QuotaAwarePolicy : public EvictionPolicy {
+ public:
+ /// Construct a quota-aware eviction policy.
+ ///
+ /// \param store_info Information about the Plasma store that is exposed
+ /// to the eviction policy.
+ /// \param max_size Max size in bytes total of objects to store.
+ explicit QuotaAwarePolicy(PlasmaStoreInfo* store_info, int64_t max_size);
+ void ObjectCreated(const ObjectID& object_id, Client* client, bool is_create) override;
+ bool SetClientQuota(Client* client, int64_t output_memory_quota) override;
+ bool EnforcePerClientQuota(Client* client, int64_t size, bool is_create,
+ std::vector<ObjectID>* objects_to_evict) override;
+ void ClientDisconnected(Client* client) override;
+ void BeginObjectAccess(const ObjectID& object_id) override;
+ void EndObjectAccess(const ObjectID& object_id) override;
+ void RemoveObject(const ObjectID& object_id) override;
+ void RefreshObjects(const std::vector<ObjectID>& object_ids) override;
+ std::string DebugString() const override;
+
+ private:
+ /// Returns whether we are enforcing memory quotas for an operation.
+ bool HasQuota(Client* client, bool is_create);
+
+ /// Per-client LRU caches, if quota is enabled.
+ std::unordered_map<Client*, std::unique_ptr<LRUCache>> per_client_cache_;
+ /// Tracks which client created which object. This only applies to clients
+ /// that have a memory quota set.
+ std::unordered_map<ObjectID, Client*> owned_by_client_;
+ /// Tracks which objects are mapped for read and hence can't be evicted.
+ /// However these objects are still tracked within the client caches.
+ std::unordered_set<ObjectID> shared_for_read_;
+};
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/store.cc b/src/arrow/cpp/src/plasma/store.cc
new file mode 100644
index 000000000..032a12fcf
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/store.cc
@@ -0,0 +1,1353 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// PLASMA STORE: This is a simple object store server process
+//
+// It accepts incoming client connections on a unix domain socket
+// (name passed in via the -s option of the executable) and uses a
+// single thread to serve the clients. Each client establishes a
+// connection and can create objects, wait for objects and seal
+// objects through that connection.
+//
+// It keeps a hash table that maps object_ids (which are 20 byte long,
+// just enough to store and SHA1 hash) to memory mapped files.
+
+#include "plasma/store.h"
+
+#include <assert.h>
+#include <fcntl.h>
+#include <getopt.h>
+#include <limits.h>
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/statvfs.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <ctime>
+#include <deque>
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include <gflags/gflags.h>
+
+#include "arrow/status.h"
+#include "arrow/util/config.h"
+
+#include "plasma/common.h"
+#include "plasma/common_generated.h"
+#include "plasma/fling.h"
+#include "plasma/io.h"
+#include "plasma/malloc.h"
+#include "plasma/plasma_allocator.h"
+#include "plasma/protocol.h"
+
+#ifdef PLASMA_CUDA
+#include "arrow/gpu/cuda_api.h"
+
+using arrow::cuda::CudaBuffer;
+using arrow::cuda::CudaContext;
+using arrow::cuda::CudaDeviceManager;
+#endif
+
+using arrow::util::ArrowLog;
+using arrow::util::ArrowLogLevel;
+
+namespace fb = plasma::flatbuf;
+
+namespace plasma {
+
+void SetMallocGranularity(int value);
+
+struct GetRequest {
+ GetRequest(Client* client, const std::vector<ObjectID>& object_ids);
+ /// The client that called get.
+ Client* client;
+ /// The ID of the timer that will time out and cause this wait to return to
+ /// the client if it hasn't already returned.
+ int64_t timer;
+ /// The object IDs involved in this request. This is used in the reply.
+ std::vector<ObjectID> object_ids;
+ /// The object information for the objects in this request. This is used in
+ /// the reply.
+ std::unordered_map<ObjectID, PlasmaObject> objects;
+ /// The minimum number of objects to wait for in this request.
+ int64_t num_objects_to_wait_for;
+ /// The number of object requests in this wait request that are already
+ /// satisfied.
+ int64_t num_satisfied;
+};
+
+GetRequest::GetRequest(Client* client, const std::vector<ObjectID>& object_ids)
+ : client(client),
+ timer(-1),
+ object_ids(object_ids.begin(), object_ids.end()),
+ objects(object_ids.size()),
+ num_satisfied(0) {
+ std::unordered_set<ObjectID> unique_ids(object_ids.begin(), object_ids.end());
+ num_objects_to_wait_for = unique_ids.size();
+}
+
+Client::Client(int fd) : fd(fd), notification_fd(-1) {}
+
+PlasmaStore::PlasmaStore(EventLoop* loop, std::string directory, bool hugepages_enabled,
+ const std::string& socket_name,
+ std::shared_ptr<ExternalStore> external_store)
+ : loop_(loop),
+ eviction_policy_(&store_info_, PlasmaAllocator::GetFootprintLimit()),
+ external_store_(external_store) {
+ store_info_.directory = directory;
+ store_info_.hugepages_enabled = hugepages_enabled;
+}
+
+// TODO(pcm): Get rid of this destructor by using RAII to clean up data.
+PlasmaStore::~PlasmaStore() {}
+
+const PlasmaStoreInfo* PlasmaStore::GetPlasmaStoreInfo() { return &store_info_; }
+
+// If this client is not already using the object, add the client to the
+// object's list of clients, otherwise do nothing.
+void PlasmaStore::AddToClientObjectIds(const ObjectID& object_id, ObjectTableEntry* entry,
+ Client* client) {
+ // Check if this client is already using the object.
+ if (client->object_ids.find(object_id) != client->object_ids.end()) {
+ return;
+ }
+ // If there are no other clients using this object, notify the eviction policy
+ // that the object is being used.
+ if (entry->ref_count == 0) {
+ // Tell the eviction policy that this object is being used.
+ eviction_policy_.BeginObjectAccess(object_id);
+ }
+ // Increase reference count.
+ entry->ref_count++;
+
+ // Add object id to the list of object ids that this client is using.
+ client->object_ids.insert(object_id);
+}
+
+// Allocate memory
+uint8_t* PlasmaStore::AllocateMemory(size_t size, bool evict_if_full, int* fd,
+ int64_t* map_size, ptrdiff_t* offset, Client* client,
+ bool is_create) {
+ // First free up space from the client's LRU queue if quota enforcement is on.
+ if (evict_if_full) {
+ std::vector<ObjectID> client_objects_to_evict;
+ bool quota_ok = eviction_policy_.EnforcePerClientQuota(client, size, is_create,
+ &client_objects_to_evict);
+ if (!quota_ok) {
+ return nullptr;
+ }
+ EvictObjects(client_objects_to_evict);
+ }
+
+ // Try to evict objects until there is enough space.
+ uint8_t* pointer = nullptr;
+ while (true) {
+ // Allocate space for the new object. We use memalign instead of malloc
+ // in order to align the allocated region to a 64-byte boundary. This is not
+ // strictly necessary, but it is an optimization that could speed up the
+ // computation of a hash of the data (see compute_object_hash_parallel in
+ // plasma_client.cc). Note that even though this pointer is 64-byte aligned,
+ // it is not guaranteed that the corresponding pointer in the client will be
+ // 64-byte aligned, but in practice it often will be.
+ pointer = reinterpret_cast<uint8_t*>(PlasmaAllocator::Memalign(kBlockSize, size));
+ if (pointer || !evict_if_full) {
+ // If we manage to allocate the memory, return the pointer. If we cannot
+ // allocate the space, but we are also not allowed to evict anything to
+ // make more space, return an error to the client.
+ break;
+ }
+ // Tell the eviction policy how much space we need to create this object.
+ std::vector<ObjectID> objects_to_evict;
+ bool success = eviction_policy_.RequireSpace(size, &objects_to_evict);
+ EvictObjects(objects_to_evict);
+ // Return an error to the client if not enough space could be freed to
+ // create the object.
+ if (!success) {
+ break;
+ }
+ }
+
+ if (pointer != nullptr) {
+ GetMallocMapinfo(pointer, fd, map_size, offset);
+ ARROW_CHECK(*fd != -1);
+ }
+ return pointer;
+}
+
+#ifdef PLASMA_CUDA
+arrow::Result<std::shared_ptr<CudaContext>> PlasmaStore::GetCudaContext(int device_num) {
+ DCHECK_NE(device_num, 0);
+ ARROW_ASSIGN_OR_RAISE(auto manager, CudaDeviceManager::Instance());
+ return manager->GetContext(device_num - 1);
+}
+
+Status PlasmaStore::AllocateCudaMemory(
+ int device_num, int64_t size, uint8_t** out_pointer,
+ std::shared_ptr<CudaIpcMemHandle>* out_ipc_handle) {
+ ARROW_ASSIGN_OR_RAISE(auto context, GetCudaContext(device_num));
+ ARROW_ASSIGN_OR_RAISE(auto cuda_buffer, context->Allocate(static_cast<int64_t>(size)));
+ *out_pointer = reinterpret_cast<uint8_t*>(cuda_buffer->address());
+ // The IPC handle will keep the buffer memory alive
+ return cuda_buffer->ExportForIpc().Value(out_ipc_handle);
+}
+
+Status PlasmaStore::FreeCudaMemory(int device_num, int64_t size, uint8_t* pointer) {
+ ARROW_ASSIGN_OR_RAISE(auto context, GetCudaContext(device_num));
+ RETURN_NOT_OK(context->Free(pointer, size));
+ return Status::OK();
+}
+#endif
+
+// Create a new object buffer in the hash table.
+PlasmaError PlasmaStore::CreateObject(const ObjectID& object_id, bool evict_if_full,
+ int64_t data_size, int64_t metadata_size,
+ int device_num, Client* client,
+ PlasmaObject* result) {
+ ARROW_LOG(DEBUG) << "creating object " << object_id.hex();
+
+ auto entry = GetObjectTableEntry(&store_info_, object_id);
+ if (entry != nullptr) {
+ // There is already an object with the same ID in the Plasma Store, so
+ // ignore this request.
+ return PlasmaError::ObjectExists;
+ }
+
+ int fd = -1;
+ int64_t map_size = 0;
+ ptrdiff_t offset = 0;
+ uint8_t* pointer = nullptr;
+ auto total_size = data_size + metadata_size;
+
+ if (device_num == 0) {
+ pointer =
+ AllocateMemory(total_size, evict_if_full, &fd, &map_size, &offset, client, true);
+ if (!pointer) {
+ ARROW_LOG(ERROR) << "Not enough memory to create the object " << object_id.hex()
+ << ", data_size=" << data_size
+ << ", metadata_size=" << metadata_size
+ << ", will send a reply of PlasmaError::OutOfMemory";
+ return PlasmaError::OutOfMemory;
+ }
+ } else {
+#ifdef PLASMA_CUDA
+ /// IPC GPU handle to share with clients.
+ std::shared_ptr<::arrow::cuda::CudaIpcMemHandle> ipc_handle;
+ auto st = AllocateCudaMemory(device_num, total_size, &pointer, &ipc_handle);
+ if (!st.ok()) {
+ ARROW_LOG(ERROR) << "Failed to allocate CUDA memory: " << st.ToString();
+ return PlasmaError::OutOfMemory;
+ }
+ result->ipc_handle = ipc_handle;
+#else
+ ARROW_LOG(ERROR) << "device_num != 0 but CUDA not enabled";
+ return PlasmaError::OutOfMemory;
+#endif
+ }
+
+ auto ptr = std::unique_ptr<ObjectTableEntry>(new ObjectTableEntry());
+ entry = store_info_.objects.emplace(object_id, std::move(ptr)).first->second.get();
+ entry->data_size = data_size;
+ entry->metadata_size = metadata_size;
+ entry->pointer = pointer;
+ // TODO(pcm): Set the other fields.
+ entry->fd = fd;
+ entry->map_size = map_size;
+ entry->offset = offset;
+ entry->state = ObjectState::PLASMA_CREATED;
+ entry->device_num = device_num;
+ entry->create_time = std::time(nullptr);
+ entry->construct_duration = -1;
+
+#ifdef PLASMA_CUDA
+ entry->ipc_handle = result->ipc_handle;
+#endif
+
+ result->store_fd = fd;
+ result->data_offset = offset;
+ result->metadata_offset = offset + data_size;
+ result->data_size = data_size;
+ result->metadata_size = metadata_size;
+ result->device_num = device_num;
+ // Notify the eviction policy that this object was created. This must be done
+ // immediately before the call to AddToClientObjectIds so that the
+ // eviction policy does not have an opportunity to evict the object.
+ eviction_policy_.ObjectCreated(object_id, client, true);
+ // Record that this client is using this object.
+ AddToClientObjectIds(object_id, store_info_.objects[object_id].get(), client);
+ return PlasmaError::OK;
+}
+
+void PlasmaObject_init(PlasmaObject* object, ObjectTableEntry* entry) {
+ DCHECK(object != nullptr);
+ DCHECK(entry != nullptr);
+ DCHECK(entry->state == ObjectState::PLASMA_SEALED);
+#ifdef PLASMA_CUDA
+ if (entry->device_num != 0) {
+ object->ipc_handle = entry->ipc_handle;
+ }
+#endif
+ object->store_fd = entry->fd;
+ object->data_offset = entry->offset;
+ object->metadata_offset = entry->offset + entry->data_size;
+ object->data_size = entry->data_size;
+ object->metadata_size = entry->metadata_size;
+ object->device_num = entry->device_num;
+}
+
+void PlasmaStore::RemoveGetRequest(GetRequest* get_request) {
+ // Remove the get request from each of the relevant object_get_requests hash
+ // tables if it is present there. It should only be present there if the get
+ // request timed out or if it was issued by a client that has disconnected.
+ for (ObjectID& object_id : get_request->object_ids) {
+ auto object_request_iter = object_get_requests_.find(object_id);
+ if (object_request_iter != object_get_requests_.end()) {
+ auto& get_requests = object_request_iter->second;
+ // Erase get_req from the vector.
+ auto it = std::find(get_requests.begin(), get_requests.end(), get_request);
+ if (it != get_requests.end()) {
+ get_requests.erase(it);
+ // If the vector is empty, remove the object ID from the map.
+ if (get_requests.empty()) {
+ object_get_requests_.erase(object_request_iter);
+ }
+ }
+ }
+ }
+ // Remove the get request.
+ if (get_request->timer != -1) {
+ ARROW_CHECK(loop_->RemoveTimer(get_request->timer) == kEventLoopOk);
+ }
+ delete get_request;
+}
+
+void PlasmaStore::RemoveGetRequestsForClient(Client* client) {
+ std::unordered_set<GetRequest*> get_requests_to_remove;
+ for (auto const& pair : object_get_requests_) {
+ for (GetRequest* get_request : pair.second) {
+ if (get_request->client == client) {
+ get_requests_to_remove.insert(get_request);
+ }
+ }
+ }
+
+ // It shouldn't be possible for a given client to be in the middle of multiple get
+ // requests.
+ ARROW_CHECK(get_requests_to_remove.size() <= 1);
+ for (GetRequest* get_request : get_requests_to_remove) {
+ RemoveGetRequest(get_request);
+ }
+}
+
+void PlasmaStore::ReturnFromGet(GetRequest* get_req) {
+ // Figure out how many file descriptors we need to send.
+ std::unordered_set<int> fds_to_send;
+ std::vector<int> store_fds;
+ std::vector<int64_t> mmap_sizes;
+ for (const auto& object_id : get_req->object_ids) {
+ PlasmaObject& object = get_req->objects[object_id];
+ int fd = object.store_fd;
+ if (object.data_size != -1 && fds_to_send.count(fd) == 0 && fd != -1) {
+ fds_to_send.insert(fd);
+ store_fds.push_back(fd);
+ mmap_sizes.push_back(GetMmapSize(fd));
+ }
+ }
+
+ // Send the get reply to the client.
+ Status s = SendGetReply(get_req->client->fd, &get_req->object_ids[0], get_req->objects,
+ get_req->object_ids.size(), store_fds, mmap_sizes);
+ WarnIfSigpipe(s.ok() ? 0 : -1, get_req->client->fd);
+ // If we successfully sent the get reply message to the client, then also send
+ // the file descriptors.
+ if (s.ok()) {
+ // Send all of the file descriptors for the present objects.
+ for (int store_fd : store_fds) {
+ // Only send the file descriptor if it hasn't been sent (see analogous
+ // logic in GetStoreFd in client.cc).
+ if (get_req->client->used_fds.find(store_fd) == get_req->client->used_fds.end()) {
+ WarnIfSigpipe(send_fd(get_req->client->fd, store_fd), get_req->client->fd);
+ get_req->client->used_fds.insert(store_fd);
+ }
+ }
+ }
+
+ // Remove the get request from each of the relevant object_get_requests hash
+ // tables if it is present there. It should only be present there if the get
+ // request timed out.
+ RemoveGetRequest(get_req);
+}
+
+void PlasmaStore::UpdateObjectGetRequests(const ObjectID& object_id) {
+ auto it = object_get_requests_.find(object_id);
+ // If there are no get requests involving this object, then return.
+ if (it == object_get_requests_.end()) {
+ return;
+ }
+
+ auto& get_requests = it->second;
+
+ // After finishing the loop below, get_requests and it will have been
+ // invalidated by the removal of object_id from object_get_requests_.
+ size_t index = 0;
+ size_t num_requests = get_requests.size();
+ for (size_t i = 0; i < num_requests; ++i) {
+ auto get_req = get_requests[index];
+ auto entry = GetObjectTableEntry(&store_info_, object_id);
+ ARROW_CHECK(entry != nullptr);
+
+ PlasmaObject_init(&get_req->objects[object_id], entry);
+ get_req->num_satisfied += 1;
+ // Record the fact that this client will be using this object and will
+ // be responsible for releasing this object.
+ AddToClientObjectIds(object_id, entry, get_req->client);
+
+ // If this get request is done, reply to the client.
+ if (get_req->num_satisfied == get_req->num_objects_to_wait_for) {
+ ReturnFromGet(get_req);
+ } else {
+ // The call to ReturnFromGet will remove the current element in the
+ // array, so we only increment the counter in the else branch.
+ index += 1;
+ }
+ }
+
+ // No get requests should be waiting for this object anymore. The object ID
+ // may have been removed from the object_get_requests_ by ReturnFromGet, but
+ // if the get request has not returned yet, then remove the object ID from the
+ // map here.
+ it = object_get_requests_.find(object_id);
+ if (it != object_get_requests_.end()) {
+ object_get_requests_.erase(object_id);
+ }
+}
+
+void PlasmaStore::ProcessGetRequest(Client* client,
+ const std::vector<ObjectID>& object_ids,
+ int64_t timeout_ms) {
+ // Create a get request for this object.
+ auto get_req = new GetRequest(client, object_ids);
+ std::vector<ObjectID> evicted_ids;
+ std::vector<ObjectTableEntry*> evicted_entries;
+ for (auto object_id : object_ids) {
+ // Check if this object is already present locally. If so, record that the
+ // object is being used and mark it as accounted for.
+ auto entry = GetObjectTableEntry(&store_info_, object_id);
+ if (entry && entry->state == ObjectState::PLASMA_SEALED) {
+ // Update the get request to take into account the present object.
+ PlasmaObject_init(&get_req->objects[object_id], entry);
+ get_req->num_satisfied += 1;
+ // If necessary, record that this client is using this object. In the case
+ // where entry == NULL, this will be called from SealObject.
+ AddToClientObjectIds(object_id, entry, client);
+ } else if (entry && entry->state == ObjectState::PLASMA_EVICTED) {
+ // Make sure the object pointer is not already allocated
+ ARROW_CHECK(!entry->pointer);
+
+ entry->pointer =
+ AllocateMemory(entry->data_size + entry->metadata_size, /*evict=*/true,
+ &entry->fd, &entry->map_size, &entry->offset, client, false);
+ if (entry->pointer) {
+ entry->state = ObjectState::PLASMA_CREATED;
+ entry->create_time = std::time(nullptr);
+ eviction_policy_.ObjectCreated(object_id, client, false);
+ AddToClientObjectIds(object_id, store_info_.objects[object_id].get(), client);
+ evicted_ids.push_back(object_id);
+ evicted_entries.push_back(entry);
+ } else {
+ // We are out of memory and cannot allocate memory for this object.
+ // Change the state of the object back to PLASMA_EVICTED so some
+ // other request can try again.
+ entry->state = ObjectState::PLASMA_EVICTED;
+ }
+ } else {
+ // Add a placeholder plasma object to the get request to indicate that the
+ // object is not present. This will be parsed by the client. We set the
+ // data size to -1 to indicate that the object is not present.
+ get_req->objects[object_id].data_size = -1;
+ // Add the get request to the relevant data structures.
+ object_get_requests_[object_id].push_back(get_req);
+ }
+ }
+
+ if (!evicted_ids.empty()) {
+ unsigned char digest[kDigestSize] = {};
+ std::vector<std::shared_ptr<Buffer>> buffers;
+ for (size_t i = 0; i < evicted_ids.size(); ++i) {
+ ARROW_CHECK(evicted_entries[i]->pointer != nullptr);
+ buffers.emplace_back(new arrow::MutableBuffer(evicted_entries[i]->pointer,
+ evicted_entries[i]->data_size));
+ }
+ if (external_store_->Get(evicted_ids, buffers).ok()) {
+ for (size_t i = 0; i < evicted_ids.size(); ++i) {
+ evicted_entries[i]->state = ObjectState::PLASMA_SEALED;
+ std::memcpy(&evicted_entries[i]->digest[0], &digest[0], kDigestSize);
+ evicted_entries[i]->construct_duration =
+ std::time(nullptr) - evicted_entries[i]->create_time;
+ PlasmaObject_init(&get_req->objects[evicted_ids[i]], evicted_entries[i]);
+ get_req->num_satisfied += 1;
+ }
+ } else {
+ // We tried to get the objects from the external store, but could not get them.
+ // Set the state of these objects back to PLASMA_EVICTED so some other request
+ // can try again.
+ for (size_t i = 0; i < evicted_ids.size(); ++i) {
+ evicted_entries[i]->state = ObjectState::PLASMA_EVICTED;
+ }
+ }
+ }
+
+ // If all of the objects are present already or if the timeout is 0, return to
+ // the client.
+ if (get_req->num_satisfied == get_req->num_objects_to_wait_for || timeout_ms == 0) {
+ ReturnFromGet(get_req);
+ } else if (timeout_ms != -1) {
+ // Set a timer that will cause the get request to return to the client. Note
+ // that a timeout of -1 is used to indicate that no timer should be set.
+ get_req->timer = loop_->AddTimer(timeout_ms, [this, get_req](int64_t timer_id) {
+ ReturnFromGet(get_req);
+ return kEventLoopTimerDone;
+ });
+ }
+}
+
+int PlasmaStore::RemoveFromClientObjectIds(const ObjectID& object_id,
+ ObjectTableEntry* entry, Client* client) {
+ auto it = client->object_ids.find(object_id);
+ if (it != client->object_ids.end()) {
+ client->object_ids.erase(it);
+ // Decrease reference count.
+ entry->ref_count--;
+
+ // If no more clients are using this object, notify the eviction policy
+ // that the object is no longer being used.
+ if (entry->ref_count == 0) {
+ if (deletion_cache_.count(object_id) == 0) {
+ // Tell the eviction policy that this object is no longer being used.
+ eviction_policy_.EndObjectAccess(object_id);
+ } else {
+ // Above code does not really delete an object. Instead, it just put an
+ // object to LRU cache which will be cleaned when the memory is not enough.
+ deletion_cache_.erase(object_id);
+ EvictObjects({object_id});
+ }
+ }
+ // Return 1 to indicate that the client was removed.
+ return 1;
+ } else {
+ // Return 0 to indicate that the client was not removed.
+ return 0;
+ }
+}
+
+void PlasmaStore::EraseFromObjectTable(const ObjectID& object_id) {
+ auto& object = store_info_.objects[object_id];
+ auto buff_size = object->data_size + object->metadata_size;
+ if (object->device_num == 0) {
+ PlasmaAllocator::Free(object->pointer, buff_size);
+ } else {
+#ifdef PLASMA_CUDA
+ ARROW_CHECK_OK(FreeCudaMemory(object->device_num, buff_size, object->pointer));
+#endif
+ }
+ store_info_.objects.erase(object_id);
+}
+
+void PlasmaStore::ReleaseObject(const ObjectID& object_id, Client* client) {
+ auto entry = GetObjectTableEntry(&store_info_, object_id);
+ ARROW_CHECK(entry != nullptr);
+ // Remove the client from the object's array of clients.
+ ARROW_CHECK(RemoveFromClientObjectIds(object_id, entry, client) == 1);
+}
+
+// Check if an object is present.
+ObjectStatus PlasmaStore::ContainsObject(const ObjectID& object_id) {
+ auto entry = GetObjectTableEntry(&store_info_, object_id);
+ return entry && (entry->state == ObjectState::PLASMA_SEALED ||
+ entry->state == ObjectState::PLASMA_EVICTED)
+ ? ObjectStatus::OBJECT_FOUND
+ : ObjectStatus::OBJECT_NOT_FOUND;
+}
+
+void PlasmaStore::SealObjects(const std::vector<ObjectID>& object_ids,
+ const std::vector<std::string>& digests) {
+ std::vector<ObjectInfoT> infos;
+
+ ARROW_LOG(DEBUG) << "sealing " << object_ids.size() << " objects";
+ for (size_t i = 0; i < object_ids.size(); ++i) {
+ ObjectInfoT object_info;
+ auto entry = GetObjectTableEntry(&store_info_, object_ids[i]);
+ ARROW_CHECK(entry != nullptr);
+ ARROW_CHECK(entry->state == ObjectState::PLASMA_CREATED);
+ // Set the state of object to SEALED.
+ entry->state = ObjectState::PLASMA_SEALED;
+ // Set the object digest.
+ std::memcpy(&entry->digest[0], digests[i].c_str(), kDigestSize);
+ // Set object construction duration.
+ entry->construct_duration = std::time(nullptr) - entry->create_time;
+
+ object_info.object_id = object_ids[i].binary();
+ object_info.data_size = entry->data_size;
+ object_info.metadata_size = entry->metadata_size;
+ object_info.digest = digests[i];
+ infos.push_back(object_info);
+ }
+
+ PushNotifications(infos);
+
+ for (size_t i = 0; i < object_ids.size(); ++i) {
+ UpdateObjectGetRequests(object_ids[i]);
+ }
+}
+
+int PlasmaStore::AbortObject(const ObjectID& object_id, Client* client) {
+ auto entry = GetObjectTableEntry(&store_info_, object_id);
+ ARROW_CHECK(entry != nullptr) << "To abort an object it must be in the object table.";
+ ARROW_CHECK(entry->state != ObjectState::PLASMA_SEALED)
+ << "To abort an object it must not have been sealed.";
+ auto it = client->object_ids.find(object_id);
+ if (it == client->object_ids.end()) {
+ // If the client requesting the abort is not the creator, do not
+ // perform the abort.
+ return 0;
+ } else {
+ // The client requesting the abort is the creator. Free the object.
+ EraseFromObjectTable(object_id);
+ client->object_ids.erase(it);
+ return 1;
+ }
+}
+
+PlasmaError PlasmaStore::DeleteObject(ObjectID& object_id) {
+ auto entry = GetObjectTableEntry(&store_info_, object_id);
+ // TODO(rkn): This should probably not fail, but should instead throw an
+ // error. Maybe we should also support deleting objects that have been
+ // created but not sealed.
+ if (entry == nullptr) {
+ // To delete an object it must be in the object table.
+ return PlasmaError::ObjectNotFound;
+ }
+
+ if (entry->state != ObjectState::PLASMA_SEALED) {
+ // To delete an object it must have been sealed.
+ // Put it into deletion cache, it will be deleted later.
+ deletion_cache_.emplace(object_id);
+ return PlasmaError::ObjectNotSealed;
+ }
+
+ if (entry->ref_count != 0) {
+ // To delete an object, there must be no clients currently using it.
+ // Put it into deletion cache, it will be deleted later.
+ deletion_cache_.emplace(object_id);
+ return PlasmaError::ObjectInUse;
+ }
+
+ eviction_policy_.RemoveObject(object_id);
+ EraseFromObjectTable(object_id);
+ // Inform all subscribers that the object has been deleted.
+ fb::ObjectInfoT notification;
+ notification.object_id = object_id.binary();
+ notification.is_deletion = true;
+ PushNotification(&notification);
+
+ return PlasmaError::OK;
+}
+
+void PlasmaStore::EvictObjects(const std::vector<ObjectID>& object_ids) {
+ if (object_ids.size() == 0) {
+ return;
+ }
+
+ std::vector<std::shared_ptr<arrow::Buffer>> evicted_object_data;
+ std::vector<ObjectTableEntry*> evicted_entries;
+ for (const auto& object_id : object_ids) {
+ ARROW_LOG(DEBUG) << "evicting object " << object_id.hex();
+ auto entry = GetObjectTableEntry(&store_info_, object_id);
+ // TODO(rkn): This should probably not fail, but should instead throw an
+ // error. Maybe we should also support deleting objects that have been
+ // created but not sealed.
+ ARROW_CHECK(entry != nullptr) << "To evict an object it must be in the object table.";
+ ARROW_CHECK(entry->state == ObjectState::PLASMA_SEALED)
+ << "To evict an object it must have been sealed.";
+ ARROW_CHECK(entry->ref_count == 0)
+ << "To evict an object, there must be no clients currently using it.";
+
+ // If there is a backing external store, then mark object for eviction to
+ // external store, free the object data pointer and keep a placeholder
+ // entry in ObjectTable
+ if (external_store_) {
+ evicted_object_data.push_back(std::make_shared<arrow::Buffer>(
+ entry->pointer, entry->data_size + entry->metadata_size));
+ evicted_entries.push_back(entry);
+ } else {
+ // If there is no backing external store, just erase the object entry
+ // and send a deletion notification.
+ EraseFromObjectTable(object_id);
+ // Inform all subscribers that the object has been deleted.
+ fb::ObjectInfoT notification;
+ notification.object_id = object_id.binary();
+ notification.is_deletion = true;
+ PushNotification(&notification);
+ }
+ }
+
+ if (external_store_ && !object_ids.empty()) {
+ ARROW_CHECK_OK(external_store_->Put(object_ids, evicted_object_data));
+ for (auto entry : evicted_entries) {
+ PlasmaAllocator::Free(entry->pointer, entry->data_size + entry->metadata_size);
+ entry->pointer = nullptr;
+ entry->state = ObjectState::PLASMA_EVICTED;
+ }
+ }
+}
+
+void PlasmaStore::ConnectClient(int listener_sock) {
+ int client_fd = AcceptClient(listener_sock);
+
+ Client* client = new Client(client_fd);
+ connected_clients_[client_fd] = std::unique_ptr<Client>(client);
+
+ // Add a callback to handle events on this socket.
+ // TODO(pcm): Check return value.
+ loop_->AddFileEvent(client_fd, kEventLoopRead, [this, client](int events) {
+ Status s = ProcessMessage(client);
+ if (!s.ok()) {
+ ARROW_LOG(FATAL) << "Failed to process file event: " << s;
+ }
+ });
+ ARROW_LOG(DEBUG) << "New connection with fd " << client_fd;
+}
+
+void PlasmaStore::DisconnectClient(int client_fd) {
+ ARROW_CHECK(client_fd > 0);
+ auto it = connected_clients_.find(client_fd);
+ ARROW_CHECK(it != connected_clients_.end());
+ loop_->RemoveFileEvent(client_fd);
+ // Close the socket.
+ close(client_fd);
+ ARROW_LOG(INFO) << "Disconnecting client on fd " << client_fd;
+ // Release all the objects that the client was using.
+ auto client = it->second.get();
+ eviction_policy_.ClientDisconnected(client);
+ std::unordered_map<ObjectID, ObjectTableEntry*> sealed_objects;
+ for (const auto& object_id : client->object_ids) {
+ auto it = store_info_.objects.find(object_id);
+ if (it == store_info_.objects.end()) {
+ continue;
+ }
+
+ if (it->second->state == ObjectState::PLASMA_SEALED) {
+ // Add sealed objects to a temporary list of object IDs. Do not perform
+ // the remove here, since it potentially modifies the object_ids table.
+ sealed_objects[it->first] = it->second.get();
+ } else {
+ // Abort unsealed object.
+ // Don't call AbortObject() because client->object_ids would be modified.
+ EraseFromObjectTable(object_id);
+ }
+ }
+
+ /// Remove all of the client's GetRequests.
+ RemoveGetRequestsForClient(client);
+
+ for (const auto& entry : sealed_objects) {
+ RemoveFromClientObjectIds(entry.first, entry.second, client);
+ }
+
+ if (client->notification_fd > 0) {
+ // This client has subscribed for notifications.
+ auto notify_fd = client->notification_fd;
+ loop_->RemoveFileEvent(notify_fd);
+ // Close socket.
+ close(notify_fd);
+ // Remove notification queue for this fd from global map.
+ pending_notifications_.erase(notify_fd);
+ // Reset fd.
+ client->notification_fd = -1;
+ }
+
+ connected_clients_.erase(it);
+}
+
+/// Send notifications about sealed objects to the subscribers. This is called
+/// in SealObject. If the socket's send buffer is full, the notification will
+/// be buffered, and this will be called again when the send buffer has room.
+/// Since we call erase on pending_notifications_, all iterators get
+/// invalidated, which is why we return a valid iterator to the next client to
+/// be used in PushNotification.
+///
+/// \param it Iterator that points to the client to send the notification to.
+/// \return Iterator pointing to the next client.
+PlasmaStore::NotificationMap::iterator PlasmaStore::SendNotifications(
+ PlasmaStore::NotificationMap::iterator it) {
+ int client_fd = it->first;
+ auto& notifications = it->second.object_notifications;
+
+ int num_processed = 0;
+ bool closed = false;
+ // Loop over the array of pending notifications and send as many of them as
+ // possible.
+ for (size_t i = 0; i < notifications.size(); ++i) {
+ auto& notification = notifications.at(i);
+ // Decode the length, which is the first bytes of the message.
+ int64_t size = *(reinterpret_cast<int64_t*>(notification.get()));
+
+ // Attempt to send a notification about this object ID.
+ ssize_t nbytes = send(client_fd, notification.get(), sizeof(int64_t) + size, 0);
+ if (nbytes >= 0) {
+ ARROW_CHECK(nbytes == static_cast<ssize_t>(sizeof(int64_t)) + size);
+ } else if (nbytes == -1 &&
+ (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)) {
+ ARROW_LOG(DEBUG) << "The socket's send buffer is full, so we are caching this "
+ "notification and will send it later.";
+ // Add a callback to the event loop to send queued notifications whenever
+ // there is room in the socket's send buffer. Callbacks can be added
+ // more than once here and will be overwritten. The callback is removed
+ // at the end of the method.
+ // TODO(pcm): Introduce status codes and check in case the file descriptor
+ // is added twice.
+ loop_->AddFileEvent(client_fd, kEventLoopWrite, [this, client_fd](int events) {
+ SendNotifications(pending_notifications_.find(client_fd));
+ });
+ break;
+ } else {
+ ARROW_LOG(WARNING) << "Failed to send notification to client on fd " << client_fd;
+ if (errno == EPIPE) {
+ closed = true;
+ break;
+ }
+ }
+ num_processed += 1;
+ }
+ // Remove the sent notifications from the array.
+ notifications.erase(notifications.begin(), notifications.begin() + num_processed);
+
+ // If we have sent all notifications, remove the fd from the event loop.
+ if (notifications.empty()) {
+ loop_->RemoveFileEvent(client_fd);
+ }
+
+ // Stop sending notifications if the pipe was broken.
+ if (closed) {
+ close(client_fd);
+ return pending_notifications_.erase(it);
+ } else {
+ return ++it;
+ }
+}
+
+void PlasmaStore::PushNotification(fb::ObjectInfoT* object_info) {
+ auto it = pending_notifications_.begin();
+ while (it != pending_notifications_.end()) {
+ std::vector<fb::ObjectInfoT> info;
+ info.push_back(*object_info);
+ auto notification = CreatePlasmaNotificationBuffer(info);
+ it->second.object_notifications.emplace_back(std::move(notification));
+ it = SendNotifications(it);
+ }
+}
+
+void PlasmaStore::PushNotifications(std::vector<fb::ObjectInfoT>& object_info) {
+ auto it = pending_notifications_.begin();
+ while (it != pending_notifications_.end()) {
+ auto notifications = CreatePlasmaNotificationBuffer(object_info);
+ it->second.object_notifications.emplace_back(std::move(notifications));
+ it = SendNotifications(it);
+ }
+}
+
+void PlasmaStore::PushNotification(fb::ObjectInfoT* object_info, int client_fd) {
+ auto it = pending_notifications_.find(client_fd);
+ if (it != pending_notifications_.end()) {
+ std::vector<fb::ObjectInfoT> info;
+ info.push_back(*object_info);
+ auto notification = CreatePlasmaNotificationBuffer(info);
+ it->second.object_notifications.emplace_back(std::move(notification));
+ SendNotifications(it);
+ }
+}
+
+// Subscribe to notifications about sealed objects.
+void PlasmaStore::SubscribeToUpdates(Client* client) {
+ ARROW_LOG(DEBUG) << "subscribing to updates on fd " << client->fd;
+ if (client->notification_fd > 0) {
+ // This client has already subscribed. Return.
+ return;
+ }
+
+ // TODO(rkn): The store could block here if the client doesn't send a file
+ // descriptor.
+ int fd = recv_fd(client->fd);
+ if (fd < 0) {
+ // This may mean that the client died before sending the file descriptor.
+ ARROW_LOG(WARNING) << "Failed to receive file descriptor from client on fd "
+ << client->fd << ".";
+ return;
+ }
+
+ // Add this fd to global map, which is needed for this client to receive notifications.
+ pending_notifications_[fd];
+ client->notification_fd = fd;
+
+ // Push notifications to the new subscriber about existing sealed objects.
+ for (const auto& entry : store_info_.objects) {
+ if (entry.second->state == ObjectState::PLASMA_SEALED) {
+ ObjectInfoT info;
+ info.object_id = entry.first.binary();
+ info.data_size = entry.second->data_size;
+ info.metadata_size = entry.second->metadata_size;
+ info.digest =
+ std::string(reinterpret_cast<char*>(&entry.second->digest[0]), kDigestSize);
+ PushNotification(&info, fd);
+ }
+ }
+}
+
+Status PlasmaStore::ProcessMessage(Client* client) {
+ fb::MessageType type;
+ Status s = ReadMessage(client->fd, &type, &input_buffer_);
+ ARROW_CHECK(s.ok() || s.IsIOError());
+
+ uint8_t* input = input_buffer_.data();
+ size_t input_size = input_buffer_.size();
+ ObjectID object_id;
+ PlasmaObject object = {};
+
+ // Process the different types of requests.
+ switch (type) {
+ case fb::MessageType::PlasmaCreateRequest: {
+ bool evict_if_full;
+ int64_t data_size;
+ int64_t metadata_size;
+ int device_num;
+ RETURN_NOT_OK(ReadCreateRequest(input, input_size, &object_id, &evict_if_full,
+ &data_size, &metadata_size, &device_num));
+ PlasmaError error_code = CreateObject(object_id, evict_if_full, data_size,
+ metadata_size, device_num, client, &object);
+ int64_t mmap_size = 0;
+ if (error_code == PlasmaError::OK && device_num == 0) {
+ mmap_size = GetMmapSize(object.store_fd);
+ }
+ HANDLE_SIGPIPE(
+ SendCreateReply(client->fd, object_id, &object, error_code, mmap_size),
+ client->fd);
+ // Only send the file descriptor if it hasn't been sent (see analogous
+ // logic in GetStoreFd in client.cc). Similar in ReturnFromGet.
+ if (error_code == PlasmaError::OK && device_num == 0 &&
+ client->used_fds.find(object.store_fd) == client->used_fds.end()) {
+ WarnIfSigpipe(send_fd(client->fd, object.store_fd), client->fd);
+ client->used_fds.insert(object.store_fd);
+ }
+ } break;
+ case fb::MessageType::PlasmaCreateAndSealRequest: {
+ bool evict_if_full;
+ std::string data;
+ std::string metadata;
+ std::string digest;
+ digest.reserve(kDigestSize);
+ RETURN_NOT_OK(ReadCreateAndSealRequest(input, input_size, &object_id,
+ &evict_if_full, &data, &metadata, &digest));
+ // CreateAndSeal currently only supports device_num = 0, which corresponds
+ // to the host.
+ int device_num = 0;
+ PlasmaError error_code = CreateObject(object_id, evict_if_full, data.size(),
+ metadata.size(), device_num, client, &object);
+
+ // If the object was successfully created, fill out the object data and seal it.
+ if (error_code == PlasmaError::OK) {
+ auto entry = GetObjectTableEntry(&store_info_, object_id);
+ ARROW_CHECK(entry != nullptr);
+ // Write the inlined data and metadata into the allocated object.
+ std::memcpy(entry->pointer, data.data(), data.size());
+ std::memcpy(entry->pointer + data.size(), metadata.data(), metadata.size());
+ SealObjects({object_id}, {digest});
+ // Remove the client from the object's array of clients because the
+ // object is not being used by any client. The client was added to the
+ // object's array of clients in CreateObject. This is analogous to the
+ // Release call that happens in the client's Seal method.
+ ARROW_CHECK(RemoveFromClientObjectIds(object_id, entry, client) == 1);
+ }
+
+ // Reply to the client.
+ HANDLE_SIGPIPE(SendCreateAndSealReply(client->fd, error_code), client->fd);
+ } break;
+ case fb::MessageType::PlasmaCreateAndSealBatchRequest: {
+ bool evict_if_full;
+ std::vector<ObjectID> object_ids;
+ std::vector<std::string> data;
+ std::vector<std::string> metadata;
+ std::vector<std::string> digests;
+
+ RETURN_NOT_OK(ReadCreateAndSealBatchRequest(
+ input, input_size, &object_ids, &evict_if_full, &data, &metadata, &digests));
+
+ // CreateAndSeal currently only supports device_num = 0, which corresponds
+ // to the host.
+ int device_num = 0;
+ size_t i = 0;
+ PlasmaError error_code = PlasmaError::OK;
+ for (i = 0; i < object_ids.size(); i++) {
+ error_code = CreateObject(object_ids[i], evict_if_full, data[i].size(),
+ metadata[i].size(), device_num, client, &object);
+ if (error_code != PlasmaError::OK) {
+ break;
+ }
+ }
+
+ // if OK, seal all the objects,
+ // if error, abort the previous i objects immediately
+ if (error_code == PlasmaError::OK) {
+ for (i = 0; i < object_ids.size(); i++) {
+ auto entry = GetObjectTableEntry(&store_info_, object_ids[i]);
+ ARROW_CHECK(entry != nullptr);
+ // Write the inlined data and metadata into the allocated object.
+ std::memcpy(entry->pointer, data[i].data(), data[i].size());
+ std::memcpy(entry->pointer + data[i].size(), metadata[i].data(),
+ metadata[i].size());
+ }
+
+ SealObjects(object_ids, digests);
+ // Remove the client from the object's array of clients because the
+ // object is not being used by any client. The client was added to the
+ // object's array of clients in CreateObject. This is analogous to the
+ // Release call that happens in the client's Seal method.
+ for (i = 0; i < object_ids.size(); i++) {
+ auto entry = GetObjectTableEntry(&store_info_, object_ids[i]);
+ ARROW_CHECK(RemoveFromClientObjectIds(object_ids[i], entry, client) == 1);
+ }
+ } else {
+ for (size_t j = 0; j < i; j++) {
+ AbortObject(object_ids[j], client);
+ }
+ }
+
+ HANDLE_SIGPIPE(SendCreateAndSealBatchReply(client->fd, error_code), client->fd);
+ } break;
+ case fb::MessageType::PlasmaAbortRequest: {
+ RETURN_NOT_OK(ReadAbortRequest(input, input_size, &object_id));
+ ARROW_CHECK(AbortObject(object_id, client) == 1) << "To abort an object, the only "
+ "client currently using it "
+ "must be the creator.";
+ HANDLE_SIGPIPE(SendAbortReply(client->fd, object_id), client->fd);
+ } break;
+ case fb::MessageType::PlasmaGetRequest: {
+ std::vector<ObjectID> object_ids_to_get;
+ int64_t timeout_ms;
+ RETURN_NOT_OK(ReadGetRequest(input, input_size, object_ids_to_get, &timeout_ms));
+ ProcessGetRequest(client, object_ids_to_get, timeout_ms);
+ } break;
+ case fb::MessageType::PlasmaReleaseRequest: {
+ RETURN_NOT_OK(ReadReleaseRequest(input, input_size, &object_id));
+ ReleaseObject(object_id, client);
+ } break;
+ case fb::MessageType::PlasmaDeleteRequest: {
+ std::vector<ObjectID> object_ids;
+ std::vector<PlasmaError> error_codes;
+ RETURN_NOT_OK(ReadDeleteRequest(input, input_size, &object_ids));
+ error_codes.reserve(object_ids.size());
+ for (auto& object_id : object_ids) {
+ error_codes.push_back(DeleteObject(object_id));
+ }
+ HANDLE_SIGPIPE(SendDeleteReply(client->fd, object_ids, error_codes), client->fd);
+ } break;
+ case fb::MessageType::PlasmaContainsRequest: {
+ RETURN_NOT_OK(ReadContainsRequest(input, input_size, &object_id));
+ if (ContainsObject(object_id) == ObjectStatus::OBJECT_FOUND) {
+ HANDLE_SIGPIPE(SendContainsReply(client->fd, object_id, 1), client->fd);
+ } else {
+ HANDLE_SIGPIPE(SendContainsReply(client->fd, object_id, 0), client->fd);
+ }
+ } break;
+ case fb::MessageType::PlasmaListRequest: {
+ RETURN_NOT_OK(ReadListRequest(input, input_size));
+ HANDLE_SIGPIPE(SendListReply(client->fd, store_info_.objects), client->fd);
+ } break;
+ case fb::MessageType::PlasmaSealRequest: {
+ std::string digest;
+ RETURN_NOT_OK(ReadSealRequest(input, input_size, &object_id, &digest));
+ SealObjects({object_id}, {digest});
+ HANDLE_SIGPIPE(SendSealReply(client->fd, object_id, PlasmaError::OK), client->fd);
+ } break;
+ case fb::MessageType::PlasmaEvictRequest: {
+ // This code path should only be used for testing.
+ int64_t num_bytes;
+ RETURN_NOT_OK(ReadEvictRequest(input, input_size, &num_bytes));
+ std::vector<ObjectID> objects_to_evict;
+ int64_t num_bytes_evicted =
+ eviction_policy_.ChooseObjectsToEvict(num_bytes, &objects_to_evict);
+ EvictObjects(objects_to_evict);
+ HANDLE_SIGPIPE(SendEvictReply(client->fd, num_bytes_evicted), client->fd);
+ } break;
+ case fb::MessageType::PlasmaRefreshLRURequest: {
+ std::vector<ObjectID> object_ids;
+ RETURN_NOT_OK(ReadRefreshLRURequest(input, input_size, &object_ids));
+ eviction_policy_.RefreshObjects(object_ids);
+ HANDLE_SIGPIPE(SendRefreshLRUReply(client->fd), client->fd);
+ } break;
+ case fb::MessageType::PlasmaSubscribeRequest:
+ SubscribeToUpdates(client);
+ break;
+ case fb::MessageType::PlasmaConnectRequest: {
+ HANDLE_SIGPIPE(SendConnectReply(client->fd, PlasmaAllocator::GetFootprintLimit()),
+ client->fd);
+ } break;
+ case fb::MessageType::PlasmaDisconnectClient:
+ ARROW_LOG(DEBUG) << "Disconnecting client on fd " << client->fd;
+ DisconnectClient(client->fd);
+ break;
+ case fb::MessageType::PlasmaSetOptionsRequest: {
+ std::string client_name;
+ int64_t output_memory_quota;
+ RETURN_NOT_OK(
+ ReadSetOptionsRequest(input, input_size, &client_name, &output_memory_quota));
+ client->name = client_name;
+ bool success = eviction_policy_.SetClientQuota(client, output_memory_quota);
+ HANDLE_SIGPIPE(SendSetOptionsReply(client->fd, success ? PlasmaError::OK
+ : PlasmaError::OutOfMemory),
+ client->fd);
+ } break;
+ case fb::MessageType::PlasmaGetDebugStringRequest: {
+ HANDLE_SIGPIPE(SendGetDebugStringReply(client->fd, eviction_policy_.DebugString()),
+ client->fd);
+ } break;
+ default:
+ // This code should be unreachable.
+ ARROW_CHECK(0);
+ }
+ return Status::OK();
+}
+
+class PlasmaStoreRunner {
+ public:
+ PlasmaStoreRunner() {}
+
+ void Start(char* socket_name, std::string directory, bool hugepages_enabled,
+ std::shared_ptr<ExternalStore> external_store) {
+ // Create the event loop.
+ loop_.reset(new EventLoop);
+ store_.reset(new PlasmaStore(loop_.get(), directory, hugepages_enabled, socket_name,
+ external_store));
+ plasma_config = store_->GetPlasmaStoreInfo();
+
+ // We are using a single memory-mapped file by mallocing and freeing a single
+ // large amount of space up front. According to the documentation,
+ // dlmalloc might need up to 128*sizeof(size_t) bytes for internal
+ // bookkeeping.
+ void* pointer = plasma::PlasmaAllocator::Memalign(
+ kBlockSize, PlasmaAllocator::GetFootprintLimit() - 256 * sizeof(size_t));
+ ARROW_CHECK(pointer != nullptr);
+ // This will unmap the file, but the next one created will be as large
+ // as this one (this is an implementation detail of dlmalloc).
+ plasma::PlasmaAllocator::Free(
+ pointer, PlasmaAllocator::GetFootprintLimit() - 256 * sizeof(size_t));
+
+ int socket = BindIpcSock(socket_name, true);
+ // TODO(pcm): Check return value.
+ ARROW_CHECK(socket >= 0);
+
+ loop_->AddFileEvent(socket, kEventLoopRead, [this, socket](int events) {
+ this->store_->ConnectClient(socket);
+ });
+ loop_->Start();
+ }
+
+ void Stop() { loop_->Stop(); }
+
+ void Shutdown() {
+ loop_->Shutdown();
+ loop_ = nullptr;
+ store_ = nullptr;
+ }
+
+ private:
+ std::unique_ptr<EventLoop> loop_;
+ std::unique_ptr<PlasmaStore> store_;
+};
+
+static std::unique_ptr<PlasmaStoreRunner> g_runner = nullptr;
+
+void HandleSignal(int signal) {
+ if (signal == SIGTERM) {
+ ARROW_LOG(INFO) << "SIGTERM Signal received, closing Plasma Server...";
+ if (g_runner != nullptr) {
+ g_runner->Stop();
+ }
+ }
+}
+
+void StartServer(char* socket_name, std::string plasma_directory, bool hugepages_enabled,
+ std::shared_ptr<ExternalStore> external_store) {
+ // Ignore SIGPIPE signals. If we don't do this, then when we attempt to write
+ // to a client that has already died, the store could die.
+ signal(SIGPIPE, SIG_IGN);
+
+ g_runner.reset(new PlasmaStoreRunner());
+ signal(SIGTERM, HandleSignal);
+ g_runner->Start(socket_name, plasma_directory, hugepages_enabled, external_store);
+}
+
+// Function to use (instead of ARROW_LOG(FATAL)) for usage, etc. errors before
+// the main server loop starts, so users don't get a backtrace if they
+// simply forgot a command-line switch.
+void ExitWithUsageError(const char* error_msg) {
+ std::cerr << gflags::ProgramInvocationShortName() << ": " << error_msg << std::endl;
+ exit(1);
+}
+
+} // namespace plasma
+
+#ifdef __linux__
+#define SHM_DEFAULT_PATH "/dev/shm"
+#else
+#define SHM_DEFAULT_PATH "/tmp"
+#endif
+
+// Command-line flags.
+DEFINE_string(d, SHM_DEFAULT_PATH, "directory where to create the memory-backed file");
+DEFINE_string(e, "",
+ "endpoint for external storage service, where objects "
+ "evicted from Plasma store can be written to, optional");
+DEFINE_bool(h, false, "whether to enable hugepage support");
+DEFINE_string(s, "",
+ "socket name where the Plasma store will listen for requests, required");
+DEFINE_string(m, "", "amount of memory in bytes to use for Plasma store, required");
+
+int main(int argc, char* argv[]) {
+ ArrowLog::StartArrowLog(argv[0], ArrowLogLevel::ARROW_INFO);
+ ArrowLog::InstallFailureSignalHandler();
+
+ gflags::SetUsageMessage("Shared-memory server for Arrow data.\nUsage: ");
+ gflags::SetVersionString(ARROW_VERSION_STRING);
+
+ char* socket_name = nullptr;
+ // Directory where plasma memory mapped files are stored.
+ std::string plasma_directory;
+ std::string external_store_endpoint;
+ bool hugepages_enabled = false;
+ int64_t system_memory = -1;
+
+ gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
+ plasma_directory = FLAGS_d;
+ external_store_endpoint = FLAGS_e;
+ hugepages_enabled = FLAGS_h;
+ if (!FLAGS_s.empty()) {
+ // We only check below if socket_name is null, so don't set it if the flag was empty.
+ socket_name = const_cast<char*>(FLAGS_s.c_str());
+ }
+
+ if (!FLAGS_m.empty()) {
+ char extra;
+ int scanned = sscanf(FLAGS_m.c_str(), "%" SCNd64 "%c", &system_memory, &extra);
+ if (scanned != 1) {
+ plasma::ExitWithUsageError(
+ "-m switch takes memory in bytes, with no letter suffix allowed");
+ }
+
+ // Set system memory capacity
+ plasma::PlasmaAllocator::SetFootprintLimit(static_cast<size_t>(system_memory));
+ ARROW_LOG(INFO) << "Allowing the Plasma store to use up to "
+ << static_cast<double>(system_memory) / 1000000000 << "GB of memory.";
+ }
+
+ // Sanity check command line options.
+ if (socket_name == nullptr && system_memory == -1) {
+ // Nicer error message for the case where the user ran the program without
+ // any of the required command-line switches.
+ plasma::ExitWithUsageError(
+ "please specify socket for incoming connections with -s, "
+ "and the amount of memory (in bytes) to use with -m");
+ } else if (socket_name == nullptr) {
+ plasma::ExitWithUsageError("please specify socket for incoming connections with -s");
+ } else if (system_memory == -1) {
+ plasma::ExitWithUsageError(
+ "please specify the amount of memory (in bytes) to use with -m");
+ }
+ if (hugepages_enabled && plasma_directory.empty()) {
+ plasma::ExitWithUsageError(
+ "if you want to use hugepages, please specify path to huge pages "
+ "filesystem with -d");
+ }
+ ARROW_CHECK(!plasma_directory.empty());
+ ARROW_LOG(INFO) << "Starting object store with directory " << plasma_directory
+ << " and huge page support "
+ << (hugepages_enabled ? "enabled" : "disabled");
+
+#ifdef __linux__
+ if (!hugepages_enabled) {
+ // On Linux, check that the amount of memory available in /dev/shm is large
+ // enough to accommodate the request. If it isn't, then fail.
+ int shm_fd = open(plasma_directory.c_str(), O_RDONLY);
+ struct statvfs shm_vfs_stats;
+ fstatvfs(shm_fd, &shm_vfs_stats);
+ // The value shm_vfs_stats.f_bsize is the block size, and the value
+ // shm_vfs_stats.f_bavail is the number of available blocks.
+ int64_t shm_mem_avail = shm_vfs_stats.f_bsize * shm_vfs_stats.f_bavail;
+ close(shm_fd);
+ // Keep some safety margin for allocator fragmentation.
+ shm_mem_avail = 9 * shm_mem_avail / 10;
+ if (system_memory > shm_mem_avail) {
+ ARROW_LOG(WARNING)
+ << "System memory request exceeds memory available in " << plasma_directory
+ << ". The request is for " << system_memory
+ << " bytes, and the amount available is " << shm_mem_avail
+ << " bytes. You may be able to free up space by deleting files in "
+ "/dev/shm. If you are inside a Docker container, you may need to "
+ "pass an argument with the flag '--shm-size' to 'docker run'.";
+ system_memory = shm_mem_avail;
+ }
+ } else {
+ plasma::SetMallocGranularity(1024 * 1024 * 1024); // 1 GiB
+ }
+#endif
+
+ // Get external store
+ std::shared_ptr<plasma::ExternalStore> external_store{nullptr};
+ if (!external_store_endpoint.empty()) {
+ std::string name;
+ ARROW_CHECK_OK(
+ plasma::ExternalStores::ExtractStoreName(external_store_endpoint, &name));
+ external_store = plasma::ExternalStores::GetStore(name);
+ if (external_store == nullptr) {
+ std::ostringstream error_msg;
+ error_msg << "no such external store \"" << name << "\"";
+ plasma::ExitWithUsageError(error_msg.str().c_str());
+ }
+ ARROW_LOG(DEBUG) << "connecting to external store...";
+ ARROW_CHECK_OK(external_store->Connect(external_store_endpoint));
+ }
+
+ ARROW_LOG(DEBUG) << "starting server listening on " << socket_name;
+ plasma::StartServer(socket_name, plasma_directory, hugepages_enabled, external_store);
+ plasma::g_runner->Shutdown();
+ plasma::g_runner = nullptr;
+
+ ArrowLog::UninstallSignalAction();
+ ArrowLog::ShutDownArrowLog();
+ return 0;
+}
diff --git a/src/arrow/cpp/src/plasma/store.h b/src/arrow/cpp/src/plasma/store.h
new file mode 100644
index 000000000..182798938
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/store.h
@@ -0,0 +1,245 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <deque>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "plasma/common.h"
+#include "plasma/events.h"
+#include "plasma/external_store.h"
+#include "plasma/plasma.h"
+#include "plasma/protocol.h"
+#include "plasma/quota_aware_policy.h"
+
+namespace arrow {
+class Status;
+} // namespace arrow
+
+namespace plasma {
+
+namespace flatbuf {
+struct ObjectInfoT;
+enum class PlasmaError;
+} // namespace flatbuf
+
+using flatbuf::ObjectInfoT;
+using flatbuf::PlasmaError;
+
+struct GetRequest;
+
+struct NotificationQueue {
+ /// The object notifications for clients. We notify the client about the
+ /// objects in the order that the objects were sealed or deleted.
+ std::deque<std::unique_ptr<uint8_t[]>> object_notifications;
+};
+
+class PlasmaStore {
+ public:
+ using NotificationMap = std::unordered_map<int, NotificationQueue>;
+
+ // TODO: PascalCase PlasmaStore methods.
+ PlasmaStore(EventLoop* loop, std::string directory, bool hugepages_enabled,
+ const std::string& socket_name,
+ std::shared_ptr<ExternalStore> external_store);
+
+ ~PlasmaStore();
+
+ /// Get a const pointer to the internal PlasmaStoreInfo object.
+ const PlasmaStoreInfo* GetPlasmaStoreInfo();
+
+ /// Create a new object. The client must do a call to release_object to tell
+ /// the store when it is done with the object.
+ ///
+ /// \param object_id Object ID of the object to be created.
+ /// \param evict_if_full If this is true, then when the object store is full,
+ /// try to evict objects that are not currently referenced before
+ /// creating the object. Else, do not evict any objects and
+ /// immediately return an PlasmaError::OutOfMemory.
+ /// \param data_size Size in bytes of the object to be created.
+ /// \param metadata_size Size in bytes of the object metadata.
+ /// \param device_num The number of the device where the object is being
+ /// created.
+ /// device_num = 0 corresponds to the host,
+ /// device_num = 1 corresponds to GPU0,
+ /// device_num = 2 corresponds to GPU1, etc.
+ /// \param client The client that created the object.
+ /// \param result The object that has been created.
+ /// \return One of the following error codes:
+ /// - PlasmaError::OK, if the object was created successfully.
+ /// - PlasmaError::ObjectExists, if an object with this ID is already
+ /// present in the store. In this case, the client should not call
+ /// plasma_release.
+ /// - PlasmaError::OutOfMemory, if the store is out of memory and
+ /// cannot create the object. In this case, the client should not call
+ /// plasma_release.
+ PlasmaError CreateObject(const ObjectID& object_id, bool evict_if_full,
+ int64_t data_size, int64_t metadata_size, int device_num,
+ Client* client, PlasmaObject* result);
+
+ /// Abort a created but unsealed object. If the client is not the
+ /// creator, then the abort will fail.
+ ///
+ /// \param object_id Object ID of the object to be aborted.
+ /// \param client The client who created the object. If this does not
+ /// match the creator of the object, then the abort will fail.
+ /// \return 1 if the abort succeeds, else 0.
+ int AbortObject(const ObjectID& object_id, Client* client);
+
+ /// Delete a specific object by object_id that have been created in the hash table.
+ ///
+ /// \param object_id Object ID of the object to be deleted.
+ /// \return One of the following error codes:
+ /// - PlasmaError::OK, if the object was delete successfully.
+ /// - PlasmaError::ObjectNotFound, if ths object isn't existed.
+ /// - PlasmaError::ObjectInUse, if the object is in use.
+ PlasmaError DeleteObject(ObjectID& object_id);
+
+ /// Evict objects returned by the eviction policy.
+ ///
+ /// \param object_ids Object IDs of the objects to be evicted.
+ void EvictObjects(const std::vector<ObjectID>& object_ids);
+
+ /// Process a get request from a client. This method assumes that we will
+ /// eventually have these objects sealed. If one of the objects has not yet
+ /// been sealed, the client that requested the object will be notified when it
+ /// is sealed.
+ ///
+ /// For each object, the client must do a call to release_object to tell the
+ /// store when it is done with the object.
+ ///
+ /// \param client The client making this request.
+ /// \param object_ids Object IDs of the objects to be gotten.
+ /// \param timeout_ms The timeout for the get request in milliseconds.
+ void ProcessGetRequest(Client* client, const std::vector<ObjectID>& object_ids,
+ int64_t timeout_ms);
+
+ /// Seal a vector of objects. The objects are now immutable and can be accessed with
+ /// get.
+ ///
+ /// \param object_ids The vector of Object IDs of the objects to be sealed.
+ /// \param digests The vector of digests of the objects. This is used to tell if two
+ /// objects with the same object ID are the same.
+ void SealObjects(const std::vector<ObjectID>& object_ids,
+ const std::vector<std::string>& digests);
+
+ /// Check if the plasma store contains an object:
+ ///
+ /// \param object_id Object ID that will be checked.
+ /// \return OBJECT_FOUND if the object is in the store, OBJECT_NOT_FOUND if
+ /// not
+ ObjectStatus ContainsObject(const ObjectID& object_id);
+
+ /// Record the fact that a particular client is no longer using an object.
+ ///
+ /// \param object_id The object ID of the object that is being released.
+ /// \param client The client making this request.
+ void ReleaseObject(const ObjectID& object_id, Client* client);
+
+ /// Subscribe a file descriptor to updates about new sealed objects.
+ ///
+ /// \param client The client making this request.
+ void SubscribeToUpdates(Client* client);
+
+ /// Connect a new client to the PlasmaStore.
+ ///
+ /// \param listener_sock The socket that is listening to incoming connections.
+ void ConnectClient(int listener_sock);
+
+ /// Disconnect a client from the PlasmaStore.
+ ///
+ /// \param client_fd The client file descriptor that is disconnected.
+ void DisconnectClient(int client_fd);
+
+ NotificationMap::iterator SendNotifications(NotificationMap::iterator it);
+
+ arrow::Status ProcessMessage(Client* client);
+
+ private:
+ void PushNotification(ObjectInfoT* object_notification);
+
+ void PushNotifications(std::vector<ObjectInfoT>& object_notifications);
+
+ void PushNotification(ObjectInfoT* object_notification, int client_fd);
+
+ void AddToClientObjectIds(const ObjectID& object_id, ObjectTableEntry* entry,
+ Client* client);
+
+ /// Remove a GetRequest and clean up the relevant data structures.
+ ///
+ /// \param get_request The GetRequest to remove.
+ void RemoveGetRequest(GetRequest* get_request);
+
+ /// Remove all of the GetRequests for a given client.
+ ///
+ /// \param client The client whose GetRequests should be removed.
+ void RemoveGetRequestsForClient(Client* client);
+
+ void ReturnFromGet(GetRequest* get_req);
+
+ void UpdateObjectGetRequests(const ObjectID& object_id);
+
+ int RemoveFromClientObjectIds(const ObjectID& object_id, ObjectTableEntry* entry,
+ Client* client);
+
+ void EraseFromObjectTable(const ObjectID& object_id);
+
+ uint8_t* AllocateMemory(size_t size, bool evict_if_full, int* fd, int64_t* map_size,
+ ptrdiff_t* offset, Client* client, bool is_create);
+#ifdef PLASMA_CUDA
+ arrow::Result<std::shared_ptr<arrow::cuda::CudaContext>> GetCudaContext(int device_num);
+ Status AllocateCudaMemory(int device_num, int64_t size, uint8_t** out_pointer,
+ std::shared_ptr<CudaIpcMemHandle>* out_ipc_handle);
+
+ Status FreeCudaMemory(int device_num, int64_t size, uint8_t* out_pointer);
+#endif
+
+ /// Event loop of the plasma store.
+ EventLoop* loop_;
+ /// The plasma store information, including the object tables, that is exposed
+ /// to the eviction policy.
+ PlasmaStoreInfo store_info_;
+ /// The state that is managed by the eviction policy.
+ QuotaAwarePolicy eviction_policy_;
+ /// Input buffer. This is allocated only once to avoid mallocs for every
+ /// call to process_message.
+ std::vector<uint8_t> input_buffer_;
+ /// A hash table mapping object IDs to a vector of the get requests that are
+ /// waiting for the object to arrive.
+ std::unordered_map<ObjectID, std::vector<GetRequest*>> object_get_requests_;
+ /// The pending notifications that have not been sent to subscribers because
+ /// the socket send buffers were full. This is a hash table from client file
+ /// descriptor to an array of object_ids to send to that client.
+ /// TODO(pcm): Consider putting this into the Client data structure and
+ /// reorganize the code slightly.
+ NotificationMap pending_notifications_;
+
+ std::unordered_map<int, std::unique_ptr<Client>> connected_clients_;
+
+ std::unordered_set<ObjectID> deletion_cache_;
+
+ /// Manages worker threads for handling asynchronous/multi-threaded requests
+ /// for reading/writing data to/from external store.
+ std::shared_ptr<ExternalStore> external_store_;
+};
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/symbols.map b/src/arrow/cpp/src/plasma/symbols.map
new file mode 100644
index 000000000..32c86daa4
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/symbols.map
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+{
+ # Symbols marked as 'local' are not exported by the DSO and thus may not
+ # be used by client applications.
+ local:
+ # devtoolset / static-libstdc++ symbols
+ __cxa_*;
+ __once_proxy;
+
+ extern "C++" {
+ # devtoolset or -static-libstdc++ - the Red Hat devtoolset statically
+ # links c++11 symbols into binaries so that the result may be executed on
+ # a system with an older libstdc++ which doesn't include the necessary
+ # c++11 symbols.
+ std::*;
+ *std::__once_call*;
+ };
+};
diff --git a/src/arrow/cpp/src/plasma/test/client_tests.cc b/src/arrow/cpp/src/plasma/test/client_tests.cc
new file mode 100644
index 000000000..e3d517b0a
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/test/client_tests.cc
@@ -0,0 +1,1084 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <assert.h>
+#include <signal.h>
+#include <stdlib.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <memory>
+#include <thread>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/io_util.h"
+
+#include "plasma/client.h"
+#include "plasma/common.h"
+#include "plasma/plasma.h"
+#include "plasma/protocol.h"
+#include "plasma/test_util.h"
+
+namespace plasma {
+
+using arrow::internal::TemporaryDir;
+
+std::string test_executable; // NOLINT
+
+void AssertObjectBufferEqual(const ObjectBuffer& object_buffer,
+ const std::vector<uint8_t>& metadata,
+ const std::vector<uint8_t>& data) {
+ arrow::AssertBufferEqual(*object_buffer.metadata, metadata);
+ arrow::AssertBufferEqual(*object_buffer.data, data);
+}
+
+class TestPlasmaStore : public ::testing::Test {
+ public:
+ // TODO(pcm): At the moment, stdout of the test gets mixed up with
+ // stdout of the object store. Consider changing that.
+
+ void SetUp() {
+ ASSERT_OK_AND_ASSIGN(temp_dir_, TemporaryDir::Make("cli-test-"));
+ store_socket_name_ = temp_dir_->path().ToString() + "store";
+
+ std::string plasma_directory =
+ test_executable.substr(0, test_executable.find_last_of("/"));
+ std::string plasma_command =
+ plasma_directory + "/plasma-store-server -m 10000000 -s " + store_socket_name_ +
+ " 1> /dev/null 2> /dev/null & " + "echo $! > " + store_socket_name_ + ".pid";
+ PLASMA_CHECK_SYSTEM(system(plasma_command.c_str()));
+ ARROW_CHECK_OK(client_.Connect(store_socket_name_, ""));
+ ARROW_CHECK_OK(client2_.Connect(store_socket_name_, ""));
+ }
+
+ virtual void TearDown() {
+ ARROW_CHECK_OK(client_.Disconnect());
+ ARROW_CHECK_OK(client2_.Disconnect());
+ // Kill plasma_store process that we started
+#ifdef COVERAGE_BUILD
+ // Ask plasma_store to exit gracefully and give it time to write out
+ // coverage files
+ std::string plasma_term_command =
+ "kill -TERM `cat " + store_socket_name_ + ".pid` || exit 0";
+ PLASMA_CHECK_SYSTEM(system(plasma_term_command.c_str()));
+ std::this_thread::sleep_for(std::chrono::milliseconds(200));
+#endif
+ std::string plasma_kill_command =
+ "kill -KILL `cat " + store_socket_name_ + ".pid` || exit 0";
+ PLASMA_CHECK_SYSTEM(system(plasma_kill_command.c_str()));
+ }
+
+ void CreateObject(PlasmaClient& client, const ObjectID& object_id,
+ const std::vector<uint8_t>& metadata,
+ const std::vector<uint8_t>& data, bool release = true) {
+ std::shared_ptr<Buffer> data_buffer;
+ ARROW_CHECK_OK(client.Create(object_id, data.size(), metadata.data(), metadata.size(),
+ &data_buffer));
+ for (size_t i = 0; i < data.size(); i++) {
+ data_buffer->mutable_data()[i] = data[i];
+ }
+ ARROW_CHECK_OK(client.Seal(object_id));
+ if (release) {
+ ARROW_CHECK_OK(client.Release(object_id));
+ }
+ }
+
+ protected:
+ PlasmaClient client_;
+ PlasmaClient client2_;
+ std::unique_ptr<TemporaryDir> temp_dir_;
+ std::string store_socket_name_;
+};
+
+TEST_F(TestPlasmaStore, NewSubscriberTest) {
+ PlasmaClient local_client, local_client2;
+
+ ARROW_CHECK_OK(local_client.Connect(store_socket_name_, ""));
+ ARROW_CHECK_OK(local_client2.Connect(store_socket_name_, ""));
+
+ ObjectID object_id = random_object_id();
+
+ // Test for the object being in local Plasma store.
+ // First create object.
+ int64_t data_size = 100;
+ uint8_t metadata[] = {5};
+ int64_t metadata_size = sizeof(metadata);
+ std::shared_ptr<Buffer> data;
+ ARROW_CHECK_OK(
+ local_client.Create(object_id, data_size, metadata, metadata_size, &data));
+ ARROW_CHECK_OK(local_client.Seal(object_id));
+
+ // Test that new subscriber client2 can receive notifications about existing objects.
+ int fd = -1;
+ ARROW_CHECK_OK(local_client2.Subscribe(&fd));
+ ASSERT_GT(fd, 0);
+
+ ObjectID object_id2 = random_object_id();
+ int64_t data_size2 = 0;
+ int64_t metadata_size2 = 0;
+ ARROW_CHECK_OK(
+ local_client2.GetNotification(fd, &object_id2, &data_size2, &metadata_size2));
+ ASSERT_EQ(object_id, object_id2);
+ ASSERT_EQ(data_size, data_size2);
+ ASSERT_EQ(metadata_size, metadata_size2);
+
+ // Delete the object.
+ ARROW_CHECK_OK(local_client.Release(object_id));
+ ARROW_CHECK_OK(local_client.Delete(object_id));
+
+ ARROW_CHECK_OK(
+ local_client2.GetNotification(fd, &object_id2, &data_size2, &metadata_size2));
+ ASSERT_EQ(object_id, object_id2);
+ ASSERT_EQ(-1, data_size2);
+ ASSERT_EQ(-1, metadata_size2);
+
+ ARROW_CHECK_OK(local_client2.Disconnect());
+ ARROW_CHECK_OK(local_client.Disconnect());
+}
+
+TEST_F(TestPlasmaStore, BatchNotificationTest) {
+ PlasmaClient local_client, local_client2;
+
+ ARROW_CHECK_OK(local_client.Connect(store_socket_name_, ""));
+ ARROW_CHECK_OK(local_client2.Connect(store_socket_name_, ""));
+
+ int fd = -1;
+ ARROW_CHECK_OK(local_client2.Subscribe(&fd));
+ ASSERT_GT(fd, 0);
+
+ ObjectID object_id1 = random_object_id();
+ ObjectID object_id2 = random_object_id();
+
+ std::vector<ObjectID> object_ids = {object_id1, object_id2};
+
+ std::vector<std::string> data = {"hello", "world!"};
+ std::vector<std::string> metadata = {"1", "23"};
+ ARROW_CHECK_OK(local_client.CreateAndSealBatch(object_ids, data, metadata));
+
+ ObjectID object_id = random_object_id();
+ int64_t data_size = 0;
+ int64_t metadata_size = 0;
+ ARROW_CHECK_OK(
+ local_client2.GetNotification(fd, &object_id, &data_size, &metadata_size));
+ ASSERT_EQ(object_id, object_id1);
+ ASSERT_EQ(data_size, 5);
+ ASSERT_EQ(metadata_size, 1);
+
+ ARROW_CHECK_OK(
+ local_client2.GetNotification(fd, &object_id, &data_size, &metadata_size));
+ ASSERT_EQ(object_id, object_id2);
+ ASSERT_EQ(data_size, 6);
+ ASSERT_EQ(metadata_size, 2);
+
+ ARROW_CHECK_OK(local_client2.Disconnect());
+ ARROW_CHECK_OK(local_client.Disconnect());
+}
+
+TEST_F(TestPlasmaStore, SealErrorsTest) {
+ ObjectID object_id = random_object_id();
+
+ Status result = client_.Seal(object_id);
+ ASSERT_TRUE(IsPlasmaObjectNotFound(result));
+
+ // Create object.
+ std::vector<uint8_t> data(100, 0);
+ CreateObject(client_, object_id, {42}, data, false);
+
+ // Trying to seal it again.
+ result = client_.Seal(object_id);
+ ASSERT_TRUE(IsPlasmaObjectAlreadySealed(result));
+ ARROW_CHECK_OK(client_.Release(object_id));
+}
+
+TEST_F(TestPlasmaStore, SetQuotaBasicTest) {
+ bool has_object = false;
+ ObjectID id1 = random_object_id();
+ ObjectID id2 = random_object_id();
+
+ ARROW_CHECK_OK(client_.SetClientOptions("client1", 5 * 1024 * 1024));
+ std::vector<uint8_t> big_data(3 * 1024 * 1024, 0);
+
+ // First object fits
+ CreateObject(client_, id1, {42}, big_data, true);
+ ARROW_CHECK_OK(client_.Contains(id1, &has_object));
+ ASSERT_TRUE(has_object);
+
+ // Evicts first object
+ CreateObject(client_, id2, {42}, big_data, true);
+ ARROW_CHECK_OK(client_.Contains(id2, &has_object));
+ ASSERT_TRUE(has_object);
+ ARROW_CHECK_OK(client_.Contains(id1, &has_object));
+ ASSERT_FALSE(has_object);
+
+ // Too big to fit in quota at all
+ std::shared_ptr<Buffer> data_buffer;
+ ASSERT_FALSE(
+ client_.Create(random_object_id(), 7 * 1024 * 1024, {}, 0, &data_buffer).ok());
+ ASSERT_TRUE(
+ client_.Create(random_object_id(), 4 * 1024 * 1024, {}, 0, &data_buffer).ok());
+}
+
+TEST_F(TestPlasmaStore, SetQuotaProvidesIsolationFromOtherClients) {
+ bool has_object = false;
+ ObjectID id1 = random_object_id();
+ ObjectID id2 = random_object_id();
+
+ std::vector<uint8_t> big_data(3 * 1024 * 1024, 0);
+
+ // First object, created without quota
+ CreateObject(client_, id1, {42}, big_data, true);
+ ARROW_CHECK_OK(client_.Contains(id1, &has_object));
+ ASSERT_TRUE(has_object);
+
+ // Second client creates a bunch of objects
+ for (int i = 0; i < 10; i++) {
+ CreateObject(client2_, random_object_id(), {42}, big_data, true);
+ }
+
+ // First client's object is evicted
+ ARROW_CHECK_OK(client_.Contains(id1, &has_object));
+ ASSERT_FALSE(has_object);
+
+ // Try again with quota enabled
+ ARROW_CHECK_OK(client_.SetClientOptions("client1", 5 * 1024 * 1024));
+ CreateObject(client_, id2, {42}, big_data, true);
+ ARROW_CHECK_OK(client_.Contains(id2, &has_object));
+ ASSERT_TRUE(has_object);
+
+ // Second client creates a bunch of objects
+ for (int i = 0; i < 10; i++) {
+ CreateObject(client2_, random_object_id(), {42}, big_data, true);
+ }
+
+ // First client's object is not evicted
+ ARROW_CHECK_OK(client_.Contains(id2, &has_object));
+ ASSERT_TRUE(has_object);
+}
+
+TEST_F(TestPlasmaStore, SetQuotaProtectsOtherClients) {
+ bool has_object = false;
+ ObjectID id1 = random_object_id();
+
+ std::vector<uint8_t> big_data(3 * 1024 * 1024, 0);
+
+ // First client has no quota
+ CreateObject(client_, id1, {42}, big_data, true);
+ ARROW_CHECK_OK(client_.Contains(id1, &has_object));
+ ASSERT_TRUE(has_object);
+
+ // Second client creates a bunch of objects under a quota
+ ARROW_CHECK_OK(client2_.SetClientOptions("client2", 5 * 1024 * 1024));
+ for (int i = 0; i < 10; i++) {
+ CreateObject(client2_, random_object_id(), {42}, big_data, true);
+ }
+
+ // First client's object is NOT evicted
+ ARROW_CHECK_OK(client_.Contains(id1, &has_object));
+ ASSERT_TRUE(has_object);
+}
+
+TEST_F(TestPlasmaStore, SetQuotaCannotExceedSeventyPercentMemory) {
+ ASSERT_FALSE(client_.SetClientOptions("client1", 8 * 1024 * 1024).ok());
+ ASSERT_TRUE(client_.SetClientOptions("client1", 5 * 1024 * 1024).ok());
+ // cannot set quota twice
+ ASSERT_FALSE(client_.SetClientOptions("client1", 5 * 1024 * 1024).ok());
+ // cannot exceed 70% summed
+ ASSERT_FALSE(client2_.SetClientOptions("client2", 3 * 1024 * 1024).ok());
+ ASSERT_TRUE(client2_.SetClientOptions("client2", 1 * 1024 * 1024).ok());
+}
+
+TEST_F(TestPlasmaStore, SetQuotaDemotesPinnedObjectsToGlobalLRU) {
+ bool has_object = false;
+ ASSERT_TRUE(client_.SetClientOptions("client1", 5 * 1024 * 1024).ok());
+
+ ObjectID id1 = random_object_id();
+ ObjectID id2 = random_object_id();
+ std::vector<uint8_t> big_data(3 * 1024 * 1024, 0);
+
+ // Quota is not enough to fit both id1 and id2, but global LRU is
+ CreateObject(client_, id1, {42}, big_data, false);
+ CreateObject(client_, id2, {42}, big_data, false);
+ ARROW_CHECK_OK(client_.Contains(id1, &has_object));
+ ASSERT_TRUE(has_object);
+ ARROW_CHECK_OK(client_.Contains(id2, &has_object));
+ ASSERT_TRUE(has_object);
+
+ // Release both objects. Now id1 is in global LRU and id2 is in quota
+ ARROW_CHECK_OK(client_.Release(id1));
+ ARROW_CHECK_OK(client_.Release(id2));
+
+ // This flushes id1 from the object store
+ for (int i = 0; i < 10; i++) {
+ CreateObject(client2_, random_object_id(), {42}, big_data, true);
+ }
+ ARROW_CHECK_OK(client_.Contains(id1, &has_object));
+ ASSERT_FALSE(has_object);
+ ARROW_CHECK_OK(client_.Contains(id2, &has_object));
+ ASSERT_TRUE(has_object);
+}
+
+TEST_F(TestPlasmaStore, SetQuotaDemoteDisconnectToGlobalLRU) {
+ bool has_object = false;
+ PlasmaClient local_client;
+ ARROW_CHECK_OK(local_client.Connect(store_socket_name_, ""));
+ ARROW_CHECK_OK(local_client.SetClientOptions("local", 5 * 1024 * 1024));
+
+ ObjectID id1 = random_object_id();
+ std::vector<uint8_t> big_data(3 * 1024 * 1024, 0);
+
+ // First object fits
+ CreateObject(local_client, id1, {42}, big_data, true);
+ for (int i = 0; i < 10; i++) {
+ CreateObject(client_, random_object_id(), {42}, big_data, true);
+ }
+ ARROW_CHECK_OK(client_.Contains(id1, &has_object));
+ ASSERT_TRUE(has_object);
+
+ // Object is still present after disconnect
+ ARROW_CHECK_OK(local_client.Disconnect());
+ ARROW_CHECK_OK(client_.Contains(id1, &has_object));
+ ASSERT_TRUE(has_object);
+
+ // But is eligible for global LRU
+ for (int i = 0; i < 10; i++) {
+ CreateObject(client_, random_object_id(), {42}, big_data, true);
+ }
+ ARROW_CHECK_OK(client_.Contains(id1, &has_object));
+ ASSERT_FALSE(has_object);
+}
+
+TEST_F(TestPlasmaStore, SetQuotaCleanupObjectMetadata) {
+ PlasmaClient local_client;
+ ARROW_CHECK_OK(local_client.Connect(store_socket_name_, ""));
+ ARROW_CHECK_OK(local_client.SetClientOptions("local", 5 * 1024 * 1024));
+
+ ObjectID id0 = random_object_id();
+ ObjectID id1 = random_object_id();
+ ObjectID id2 = random_object_id();
+ ObjectID id3 = random_object_id();
+ std::vector<uint8_t> big_data(3 * 1024 * 1024, 0);
+ std::vector<uint8_t> small_data(1 * 1024 * 1024, 0);
+ CreateObject(local_client, id0, {42}, small_data, false);
+ CreateObject(local_client, id1, {42}, big_data, true);
+ CreateObject(local_client, id2, {42}, big_data,
+ true); // spills id0 to global, evicts id1
+ CreateObject(local_client, id3, {42}, small_data, false);
+
+ ASSERT_TRUE(client_.DebugString().find("num clients with quota: 1") !=
+ std::string::npos);
+ ASSERT_TRUE(client_.DebugString().find("quota map size: 2") != std::string::npos);
+ ASSERT_TRUE(client_.DebugString().find("pinned quota map size: 1") !=
+ std::string::npos);
+ ASSERT_TRUE(client_.DebugString().find("(global lru) num objects: 0") !=
+ std::string::npos);
+ ASSERT_TRUE(client_.DebugString().find("(local) num objects: 2") != std::string::npos);
+
+ // release id0
+ ARROW_CHECK_OK(local_client.Release(id0));
+ ASSERT_TRUE(client_.DebugString().find("(global lru) num objects: 1") !=
+ std::string::npos);
+
+ // delete everything
+ ARROW_CHECK_OK(local_client.Delete(id0));
+ ARROW_CHECK_OK(local_client.Delete(id2));
+ ARROW_CHECK_OK(local_client.Delete(id3));
+ ARROW_CHECK_OK(local_client.Release(id3));
+ ASSERT_TRUE(client_.DebugString().find("quota map size: 0") != std::string::npos);
+ ASSERT_TRUE(client_.DebugString().find("pinned quota map size: 0") !=
+ std::string::npos);
+ ASSERT_TRUE(client_.DebugString().find("(global lru) num objects: 0") !=
+ std::string::npos);
+ ASSERT_TRUE(client_.DebugString().find("(local) num objects: 0") != std::string::npos);
+
+ ARROW_CHECK_OK(local_client.Disconnect());
+ int tries = 10; // wait for disconnect to complete
+ while (tries > 0 &&
+ client_.DebugString().find("num clients with quota: 0") == std::string::npos) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(200));
+ tries -= 1;
+ }
+ ASSERT_TRUE(client_.DebugString().find("num clients with quota: 0") !=
+ std::string::npos);
+ ASSERT_TRUE(client_.DebugString().find("(global lru) capacity: 10000000") !=
+ std::string::npos);
+ ASSERT_TRUE(client_.DebugString().find("(global lru) used: 0%") != std::string::npos);
+}
+
+TEST_F(TestPlasmaStore, SetQuotaCleanupClientDisconnect) {
+ PlasmaClient local_client;
+ ARROW_CHECK_OK(local_client.Connect(store_socket_name_, ""));
+ ARROW_CHECK_OK(local_client.SetClientOptions("local", 5 * 1024 * 1024));
+
+ ObjectID id1 = random_object_id();
+ ObjectID id2 = random_object_id();
+ ObjectID id3 = random_object_id();
+ std::vector<uint8_t> big_data(3 * 1024 * 1024, 0);
+ std::vector<uint8_t> small_data(1 * 1024 * 1024, 0);
+ CreateObject(local_client, id1, {42}, big_data, true);
+ CreateObject(local_client, id2, {42}, big_data, true);
+ CreateObject(local_client, id3, {42}, small_data, false);
+
+ ARROW_CHECK_OK(local_client.Disconnect());
+ int tries = 10; // wait for disconnect to complete
+ while (tries > 0 &&
+ client_.DebugString().find("num clients with quota: 0") == std::string::npos) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(200));
+ tries -= 1;
+ }
+ ASSERT_TRUE(client_.DebugString().find("num clients with quota: 0") !=
+ std::string::npos);
+ ASSERT_TRUE(client_.DebugString().find("quota map size: 0") != std::string::npos);
+ ASSERT_TRUE(client_.DebugString().find("pinned quota map size: 0") !=
+ std::string::npos);
+ ASSERT_TRUE(client_.DebugString().find("(global lru) num objects: 2") !=
+ std::string::npos);
+ ASSERT_TRUE(client_.DebugString().find("(global lru) capacity: 10000000") !=
+ std::string::npos);
+ ASSERT_TRUE(client_.DebugString().find("(global lru) used: 41.9431%") !=
+ std::string::npos);
+}
+
+TEST_F(TestPlasmaStore, RefreshLRUTest) {
+ bool has_object = false;
+ std::vector<ObjectID> object_ids;
+
+ for (int i = 0; i < 10; ++i) {
+ object_ids.push_back(random_object_id());
+ }
+
+ std::vector<uint8_t> small_data(1 * 1000 * 1000, 0);
+
+ // we can fit ten small objects into the store
+ for (const auto& object_id : object_ids) {
+ CreateObject(client_, object_id, {}, small_data, true);
+ ARROW_CHECK_OK(client_.Contains(object_ids[0], &has_object));
+ ASSERT_TRUE(has_object);
+ }
+
+ ObjectID id = random_object_id();
+ CreateObject(client_, id, {}, small_data, true);
+
+ // the first two objects got evicted (20% of the store)
+ ARROW_CHECK_OK(client_.Contains(object_ids[0], &has_object));
+ ASSERT_FALSE(has_object);
+
+ ARROW_CHECK_OK(client_.Contains(object_ids[1], &has_object));
+ ASSERT_FALSE(has_object);
+
+ ARROW_CHECK_OK(client_.Refresh({object_ids[2], object_ids[3]}));
+
+ id = random_object_id();
+ CreateObject(client_, id, {}, small_data, true);
+ id = random_object_id();
+ CreateObject(client_, id, {}, small_data, true);
+
+ // the refreshed objects are not evicted
+ ARROW_CHECK_OK(client_.Contains(object_ids[2], &has_object));
+ ASSERT_TRUE(has_object);
+ ARROW_CHECK_OK(client_.Contains(object_ids[3], &has_object));
+ ASSERT_TRUE(has_object);
+
+ // the next object in LRU order is evicted
+ ARROW_CHECK_OK(client_.Contains(object_ids[4], &has_object));
+ ASSERT_FALSE(has_object);
+}
+
+TEST_F(TestPlasmaStore, DeleteTest) {
+ ObjectID object_id = random_object_id();
+
+ // Test for deleting nonexistent object.
+ Status result = client_.Delete(object_id);
+ ARROW_CHECK_OK(result);
+
+ // Test for the object being in local Plasma store.
+ // First create object.
+ int64_t data_size = 100;
+ uint8_t metadata[] = {5};
+ int64_t metadata_size = sizeof(metadata);
+ std::shared_ptr<Buffer> data;
+ ARROW_CHECK_OK(client_.Create(object_id, data_size, metadata, metadata_size, &data));
+ ARROW_CHECK_OK(client_.Seal(object_id));
+
+ result = client_.Delete(object_id);
+ ARROW_CHECK_OK(result);
+ bool has_object = false;
+ ARROW_CHECK_OK(client_.Contains(object_id, &has_object));
+ ASSERT_TRUE(has_object);
+
+ ARROW_CHECK_OK(client_.Release(object_id));
+ // object_id is marked as to-be-deleted, when it is not in use, it will be deleted.
+ ARROW_CHECK_OK(client_.Contains(object_id, &has_object));
+ ASSERT_FALSE(has_object);
+ ARROW_CHECK_OK(client_.Delete(object_id));
+}
+
+TEST_F(TestPlasmaStore, DeleteObjectsTest) {
+ ObjectID object_id1 = random_object_id();
+ ObjectID object_id2 = random_object_id();
+
+ // Test for deleting nonexistent object.
+ Status result = client_.Delete(std::vector<ObjectID>{object_id1, object_id2});
+ ARROW_CHECK_OK(result);
+ // Test for the object being in local Plasma store.
+ // First create object.
+ int64_t data_size = 100;
+ uint8_t metadata[] = {5};
+ int64_t metadata_size = sizeof(metadata);
+ std::shared_ptr<Buffer> data;
+ ARROW_CHECK_OK(client_.Create(object_id1, data_size, metadata, metadata_size, &data));
+ ARROW_CHECK_OK(client_.Seal(object_id1));
+ ARROW_CHECK_OK(client_.Create(object_id2, data_size, metadata, metadata_size, &data));
+ ARROW_CHECK_OK(client_.Seal(object_id2));
+ // Release the ref count of Create function.
+ ARROW_CHECK_OK(client_.Release(object_id1));
+ ARROW_CHECK_OK(client_.Release(object_id2));
+ // Increase the ref count by calling Get using client2_.
+ std::vector<ObjectBuffer> object_buffers;
+ ARROW_CHECK_OK(client2_.Get({object_id1, object_id2}, 0, &object_buffers));
+ // Objects are still used by client2_.
+ result = client_.Delete(std::vector<ObjectID>{object_id1, object_id2});
+ ARROW_CHECK_OK(result);
+ // The object is used and it should not be deleted right now.
+ bool has_object = false;
+ ARROW_CHECK_OK(client_.Contains(object_id1, &has_object));
+ ASSERT_TRUE(has_object);
+ ARROW_CHECK_OK(client_.Contains(object_id2, &has_object));
+ ASSERT_TRUE(has_object);
+ // Decrease the ref count by deleting the PlasmaBuffer (in ObjectBuffer).
+ // client2_ won't send the release request immediately because the trigger
+ // condition is not reached. The release is only added to release cache.
+ object_buffers.clear();
+ // Delete the objects.
+ result = client2_.Delete(std::vector<ObjectID>{object_id1, object_id2});
+ ARROW_CHECK_OK(client_.Contains(object_id1, &has_object));
+ ASSERT_FALSE(has_object);
+ ARROW_CHECK_OK(client_.Contains(object_id2, &has_object));
+ ASSERT_FALSE(has_object);
+}
+
+TEST_F(TestPlasmaStore, ContainsTest) {
+ ObjectID object_id = random_object_id();
+
+ // Test for object nonexistence.
+ bool has_object;
+ ARROW_CHECK_OK(client_.Contains(object_id, &has_object));
+ ASSERT_FALSE(has_object);
+
+ // Test for the object being in local Plasma store.
+ // First create object.
+ std::vector<uint8_t> data(100, 0);
+ CreateObject(client_, object_id, {42}, data);
+ std::vector<ObjectBuffer> object_buffers;
+ ARROW_CHECK_OK(client_.Get({object_id}, -1, &object_buffers));
+ ARROW_CHECK_OK(client_.Contains(object_id, &has_object));
+ ASSERT_TRUE(has_object);
+}
+
+TEST_F(TestPlasmaStore, GetTest) {
+ std::vector<ObjectBuffer> object_buffers;
+
+ ObjectID object_id = random_object_id();
+
+ // Test for object nonexistence.
+ ARROW_CHECK_OK(client_.Get({object_id}, 0, &object_buffers));
+ ASSERT_EQ(object_buffers.size(), 1);
+ ASSERT_FALSE(object_buffers[0].metadata);
+ ASSERT_FALSE(object_buffers[0].data);
+ EXPECT_FALSE(client_.IsInUse(object_id));
+
+ // Test for the object being in local Plasma store.
+ // First create object.
+ std::vector<uint8_t> data = {3, 5, 6, 7, 9};
+ CreateObject(client_, object_id, {42}, data);
+ EXPECT_FALSE(client_.IsInUse(object_id));
+
+ object_buffers.clear();
+ ARROW_CHECK_OK(client_.Get({object_id}, -1, &object_buffers));
+ ASSERT_EQ(object_buffers.size(), 1);
+ ASSERT_EQ(object_buffers[0].device_num, 0);
+ AssertObjectBufferEqual(object_buffers[0], {42}, {3, 5, 6, 7, 9});
+
+ // Metadata keeps object in use
+ {
+ auto metadata = object_buffers[0].metadata;
+ object_buffers.clear();
+ ::arrow::AssertBufferEqual(*metadata, std::string{42});
+ EXPECT_TRUE(client_.IsInUse(object_id));
+ }
+ // Object is automatically released
+ EXPECT_FALSE(client_.IsInUse(object_id));
+}
+
+TEST_F(TestPlasmaStore, LegacyGetTest) {
+ // Test for old non-releasing Get() variant
+ ObjectID object_id = random_object_id();
+ {
+ ObjectBuffer object_buffer;
+
+ // Test for object nonexistence.
+ ARROW_CHECK_OK(client_.Get(&object_id, 1, 0, &object_buffer));
+ ASSERT_FALSE(object_buffer.metadata);
+ ASSERT_FALSE(object_buffer.data);
+ EXPECT_FALSE(client_.IsInUse(object_id));
+
+ // First create object.
+ std::vector<uint8_t> data = {3, 5, 6, 7, 9};
+ CreateObject(client_, object_id, {42}, data);
+ EXPECT_FALSE(client_.IsInUse(object_id));
+
+ ARROW_CHECK_OK(client_.Get(&object_id, 1, -1, &object_buffer));
+ AssertObjectBufferEqual(object_buffer, {42}, {3, 5, 6, 7, 9});
+ }
+ // Object needs releasing manually
+ EXPECT_TRUE(client_.IsInUse(object_id));
+ ARROW_CHECK_OK(client_.Release(object_id));
+ EXPECT_FALSE(client_.IsInUse(object_id));
+}
+
+TEST_F(TestPlasmaStore, MultipleGetTest) {
+ ObjectID object_id1 = random_object_id();
+ ObjectID object_id2 = random_object_id();
+ std::vector<ObjectID> object_ids = {object_id1, object_id2};
+ std::vector<ObjectBuffer> object_buffers;
+
+ int64_t data_size = 4;
+ uint8_t metadata[] = {5};
+ int64_t metadata_size = sizeof(metadata);
+ std::shared_ptr<Buffer> data;
+ ARROW_CHECK_OK(client_.Create(object_id1, data_size, metadata, metadata_size, &data));
+ data->mutable_data()[0] = 1;
+ ARROW_CHECK_OK(client_.Seal(object_id1));
+
+ ARROW_CHECK_OK(client_.Create(object_id2, data_size, metadata, metadata_size, &data));
+ data->mutable_data()[0] = 2;
+ ARROW_CHECK_OK(client_.Seal(object_id2));
+
+ ARROW_CHECK_OK(client_.Get(object_ids, -1, &object_buffers));
+ ASSERT_EQ(object_buffers[0].data->data()[0], 1);
+ ASSERT_EQ(object_buffers[1].data->data()[0], 2);
+}
+
+TEST_F(TestPlasmaStore, BatchCreateTest) {
+ ObjectID object_id1 = random_object_id();
+ ObjectID object_id2 = random_object_id();
+ std::vector<ObjectID> object_ids = {object_id1, object_id2};
+
+ std::vector<std::string> data = {"hello", "world"};
+ std::vector<std::string> metadata = {"1", "2"};
+
+ ARROW_CHECK_OK(client_.CreateAndSealBatch(object_ids, data, metadata));
+
+ std::vector<ObjectBuffer> object_buffers;
+
+ ARROW_CHECK_OK(client_.Get(object_ids, -1, &object_buffers));
+
+ std::string out1, out2;
+ out1.assign(reinterpret_cast<const char*>(object_buffers[0].data->data()),
+ object_buffers[0].data->size());
+ out2.assign(reinterpret_cast<const char*>(object_buffers[1].data->data()),
+ object_buffers[1].data->size());
+
+ ASSERT_STREQ(out1.c_str(), "hello");
+ ASSERT_STREQ(out2.c_str(), "world");
+}
+
+TEST_F(TestPlasmaStore, AbortTest) {
+ ObjectID object_id = random_object_id();
+ std::vector<ObjectBuffer> object_buffers;
+
+ // Test for object nonexistence.
+ ARROW_CHECK_OK(client_.Get({object_id}, 0, &object_buffers));
+ ASSERT_FALSE(object_buffers[0].data);
+
+ // Test object abort.
+ // First create object.
+ int64_t data_size = 4;
+ uint8_t metadata[] = {5};
+ int64_t metadata_size = sizeof(metadata);
+ std::shared_ptr<Buffer> data;
+ uint8_t* data_ptr;
+ ARROW_CHECK_OK(client_.Create(object_id, data_size, metadata, metadata_size, &data));
+ data_ptr = data->mutable_data();
+ // Write some data.
+ for (int64_t i = 0; i < data_size / 2; i++) {
+ data_ptr[i] = static_cast<uint8_t>(i % 4);
+ }
+ // Attempt to abort. Test that this fails before the first release.
+ Status status = client_.Abort(object_id);
+ ASSERT_TRUE(status.IsInvalid());
+ // Release, then abort.
+ ARROW_CHECK_OK(client_.Release(object_id));
+ EXPECT_TRUE(client_.IsInUse(object_id));
+
+ ARROW_CHECK_OK(client_.Abort(object_id));
+ EXPECT_FALSE(client_.IsInUse(object_id));
+
+ // Test for object nonexistence after the abort.
+ ARROW_CHECK_OK(client_.Get({object_id}, 0, &object_buffers));
+ ASSERT_FALSE(object_buffers[0].data);
+
+ // Create the object successfully this time.
+ CreateObject(client_, object_id, {42, 43}, {1, 2, 3, 4, 5});
+
+ // Test that we can get the object.
+ ARROW_CHECK_OK(client_.Get({object_id}, -1, &object_buffers));
+ AssertObjectBufferEqual(object_buffers[0], {42, 43}, {1, 2, 3, 4, 5});
+}
+
+TEST_F(TestPlasmaStore, OneIdCreateRepeatedlyTest) {
+ const int64_t loop_times = 5;
+
+ ObjectID object_id = random_object_id();
+ std::vector<ObjectBuffer> object_buffers;
+
+ // Test for object nonexistence.
+ ARROW_CHECK_OK(client_.Get({object_id}, 0, &object_buffers));
+ ASSERT_FALSE(object_buffers[0].data);
+
+ int64_t data_size = 20;
+ uint8_t metadata[] = {5};
+ int64_t metadata_size = sizeof(metadata);
+
+ // Test the sequence: create -> release -> abort -> ...
+ for (int64_t i = 0; i < loop_times; i++) {
+ std::shared_ptr<Buffer> data;
+ ARROW_CHECK_OK(client_.Create(object_id, data_size, metadata, metadata_size, &data));
+ ARROW_CHECK_OK(client_.Release(object_id));
+ ARROW_CHECK_OK(client_.Abort(object_id));
+ }
+
+ // Test the sequence: create -> seal -> release -> delete -> ...
+ for (int64_t i = 0; i < loop_times; i++) {
+ std::shared_ptr<Buffer> data;
+ ARROW_CHECK_OK(client_.Create(object_id, data_size, metadata, metadata_size, &data));
+ ARROW_CHECK_OK(client_.Seal(object_id));
+ ARROW_CHECK_OK(client_.Release(object_id));
+ ARROW_CHECK_OK(client_.Delete(object_id));
+ }
+}
+
+TEST_F(TestPlasmaStore, MultipleClientTest) {
+ ObjectID object_id = random_object_id();
+ std::vector<ObjectBuffer> object_buffers;
+
+ // Test for object nonexistence on the first client.
+ bool has_object;
+ ARROW_CHECK_OK(client_.Contains(object_id, &has_object));
+ ASSERT_FALSE(has_object);
+
+ // Test for the object being in local Plasma store.
+ // First create and seal object on the second client.
+ int64_t data_size = 100;
+ uint8_t metadata[] = {5};
+ int64_t metadata_size = sizeof(metadata);
+ std::shared_ptr<Buffer> data;
+ ARROW_CHECK_OK(client2_.Create(object_id, data_size, metadata, metadata_size, &data));
+ ARROW_CHECK_OK(client2_.Seal(object_id));
+ // Test that the first client can get the object.
+ ARROW_CHECK_OK(client_.Get({object_id}, -1, &object_buffers));
+ ASSERT_TRUE(object_buffers[0].data);
+ ARROW_CHECK_OK(client_.Contains(object_id, &has_object));
+ ASSERT_TRUE(has_object);
+
+ // Test that one client disconnecting does not interfere with the other.
+ // First create object on the second client.
+ object_id = random_object_id();
+ ARROW_CHECK_OK(client2_.Create(object_id, data_size, metadata, metadata_size, &data));
+ // Disconnect the first client.
+ ARROW_CHECK_OK(client_.Disconnect());
+ // Test that the second client can seal and get the created object.
+ ARROW_CHECK_OK(client2_.Seal(object_id));
+ ARROW_CHECK_OK(client2_.Get({object_id}, -1, &object_buffers));
+ ASSERT_TRUE(object_buffers[0].data);
+ ARROW_CHECK_OK(client2_.Contains(object_id, &has_object));
+ ASSERT_TRUE(has_object);
+}
+
+TEST_F(TestPlasmaStore, ManyObjectTest) {
+ // Create many objects on the first client. Seal one third, abort one third,
+ // and leave the last third unsealed.
+ std::vector<ObjectID> object_ids;
+ for (int i = 0; i < 100; i++) {
+ ObjectID object_id = random_object_id();
+ object_ids.push_back(object_id);
+
+ // Test for object nonexistence on the first client.
+ bool has_object;
+ ARROW_CHECK_OK(client_.Contains(object_id, &has_object));
+ ASSERT_FALSE(has_object);
+
+ // Test for the object being in local Plasma store.
+ // First create and seal object on the first client.
+ int64_t data_size = 100;
+ uint8_t metadata[] = {5};
+ int64_t metadata_size = sizeof(metadata);
+ std::shared_ptr<Buffer> data;
+ ARROW_CHECK_OK(client_.Create(object_id, data_size, metadata, metadata_size, &data));
+
+ if (i % 3 == 0) {
+ // Seal one third of the objects.
+ ARROW_CHECK_OK(client_.Seal(object_id));
+ // Test that the first client can get the object.
+ ARROW_CHECK_OK(client_.Contains(object_id, &has_object));
+ ASSERT_TRUE(has_object);
+ } else if (i % 3 == 1) {
+ // Abort one third of the objects.
+ ARROW_CHECK_OK(client_.Release(object_id));
+ ARROW_CHECK_OK(client_.Abort(object_id));
+ }
+ }
+ // Disconnect the first client. All unsealed objects should be aborted.
+ ARROW_CHECK_OK(client_.Disconnect());
+
+ // Check that the second client can query the object store for the first
+ // client's objects.
+ int i = 0;
+ for (auto const& object_id : object_ids) {
+ bool has_object;
+ ARROW_CHECK_OK(client2_.Contains(object_id, &has_object));
+ if (i % 3 == 0) {
+ // The first third should be sealed.
+ ASSERT_TRUE(has_object);
+ } else {
+ // The rest were aborted, so the object is not in the store.
+ ASSERT_FALSE(has_object);
+ }
+ i++;
+ }
+}
+
+#ifdef PLASMA_CUDA
+using arrow::cuda::CudaBuffer;
+using arrow::cuda::CudaBufferReader;
+using arrow::cuda::CudaBufferWriter;
+
+// actual CUDA device number + 1
+constexpr int kGpuDeviceNumber = 1;
+
+namespace {
+
+void AssertCudaRead(const std::shared_ptr<Buffer>& buffer,
+ const std::vector<uint8_t>& expected_data) {
+ std::shared_ptr<CudaBuffer> gpu_buffer;
+ const size_t data_size = expected_data.size();
+
+ ASSERT_OK_AND_ASSIGN(gpu_buffer, CudaBuffer::FromBuffer(buffer));
+ ASSERT_EQ(gpu_buffer->size(), data_size);
+
+ CudaBufferReader reader(gpu_buffer);
+ std::vector<uint8_t> read_data(data_size);
+ ASSERT_OK_AND_EQ(data_size, reader.Read(data_size, read_data.data()));
+
+ for (size_t i = 0; i < data_size; i++) {
+ ASSERT_EQ(read_data[i], expected_data[i]);
+ }
+}
+
+} // namespace
+
+TEST_F(TestPlasmaStore, GetGPUTest) {
+ ObjectID object_id = random_object_id();
+ std::vector<ObjectBuffer> object_buffers;
+
+ // Test for object nonexistence.
+ ARROW_CHECK_OK(client_.Get({object_id}, 0, &object_buffers));
+ ASSERT_EQ(object_buffers.size(), 1);
+ ASSERT_FALSE(object_buffers[0].data);
+
+ // Test for the object being in local Plasma store.
+ // First create object.
+ uint8_t data[] = {4, 5, 3, 1};
+ int64_t data_size = sizeof(data);
+ uint8_t metadata[] = {42};
+ int64_t metadata_size = sizeof(metadata);
+ std::shared_ptr<Buffer> data_buffer;
+ std::shared_ptr<CudaBuffer> gpu_buffer;
+ ARROW_CHECK_OK(client_.Create(object_id, data_size, metadata, metadata_size,
+ &data_buffer, kGpuDeviceNumber));
+ ASSERT_OK_AND_ASSIGN(gpu_buffer, CudaBuffer::FromBuffer(data_buffer));
+ CudaBufferWriter writer(gpu_buffer);
+ ARROW_CHECK_OK(writer.Write(data, data_size));
+ ARROW_CHECK_OK(client_.Seal(object_id));
+
+ object_buffers.clear();
+ ARROW_CHECK_OK(client_.Get({object_id}, -1, &object_buffers));
+ ASSERT_EQ(object_buffers.size(), 1);
+ ASSERT_EQ(object_buffers[0].device_num, kGpuDeviceNumber);
+ // Check data
+ AssertCudaRead(object_buffers[0].data, {4, 5, 3, 1});
+ // Check metadata
+ AssertCudaRead(object_buffers[0].metadata, {42});
+}
+
+TEST_F(TestPlasmaStore, DeleteObjectsGPUTest) {
+ ObjectID object_id1 = random_object_id();
+ ObjectID object_id2 = random_object_id();
+
+ // Test for deleting nonexistent object.
+ Status result = client_.Delete(std::vector<ObjectID>{object_id1, object_id2});
+ ARROW_CHECK_OK(result);
+ // Test for the object being in local Plasma store.
+ // First create object.
+ int64_t data_size = 100;
+ uint8_t metadata[] = {5};
+ int64_t metadata_size = sizeof(metadata);
+ std::shared_ptr<Buffer> data;
+ ARROW_CHECK_OK(client_.Create(object_id1, data_size, metadata, metadata_size, &data,
+ kGpuDeviceNumber));
+ ARROW_CHECK_OK(client_.Seal(object_id1));
+ ARROW_CHECK_OK(client_.Create(object_id2, data_size, metadata, metadata_size, &data,
+ kGpuDeviceNumber));
+ ARROW_CHECK_OK(client_.Seal(object_id2));
+ // Release the ref count of Create function.
+ ARROW_CHECK_OK(client_.Release(object_id1));
+ ARROW_CHECK_OK(client_.Release(object_id2));
+ // Increase the ref count by calling Get using client2_.
+ std::vector<ObjectBuffer> object_buffers;
+ ARROW_CHECK_OK(client2_.Get({object_id1, object_id2}, 0, &object_buffers));
+ // Objects are still used by client2_.
+ result = client_.Delete(std::vector<ObjectID>{object_id1, object_id2});
+ ARROW_CHECK_OK(result);
+ // The object is used and it should not be deleted right now.
+ bool has_object = false;
+ ARROW_CHECK_OK(client_.Contains(object_id1, &has_object));
+ ASSERT_TRUE(has_object);
+ ARROW_CHECK_OK(client_.Contains(object_id2, &has_object));
+ ASSERT_TRUE(has_object);
+ // Decrease the ref count by deleting the PlasmaBuffer (in ObjectBuffer).
+ // client2_ won't send the release request immediately because the trigger
+ // condition is not reached. The release is only added to release cache.
+ object_buffers.clear();
+ // Delete the objects.
+ result = client2_.Delete(std::vector<ObjectID>{object_id1, object_id2});
+ ARROW_CHECK_OK(client_.Contains(object_id1, &has_object));
+ ASSERT_FALSE(has_object);
+ ARROW_CHECK_OK(client_.Contains(object_id2, &has_object));
+ ASSERT_FALSE(has_object);
+}
+
+TEST_F(TestPlasmaStore, RepeatlyCreateGPUTest) {
+ const int64_t loop_times = 100;
+ const int64_t object_num = 5;
+ const int64_t data_size = 40;
+
+ std::vector<ObjectID> object_ids;
+
+ // create new gpu objects
+ for (int64_t i = 0; i < object_num; i++) {
+ object_ids.push_back(random_object_id());
+ ObjectID& object_id = object_ids[i];
+
+ std::shared_ptr<Buffer> data;
+ ARROW_CHECK_OK(client_.Create(object_id, data_size, 0, 0, &data, kGpuDeviceNumber));
+ ARROW_CHECK_OK(client_.Seal(object_id));
+ ARROW_CHECK_OK(client_.Release(object_id));
+ }
+
+ // delete and create again
+ for (int64_t i = 0; i < loop_times; i++) {
+ ObjectID& object_id = object_ids[i % object_num];
+
+ ARROW_CHECK_OK(client_.Delete(object_id));
+
+ std::shared_ptr<Buffer> data;
+ ARROW_CHECK_OK(client_.Create(object_id, data_size, 0, 0, &data, kGpuDeviceNumber));
+ ARROW_CHECK_OK(client_.Seal(object_id));
+ ARROW_CHECK_OK(client_.Release(object_id));
+ }
+
+ // delete all
+ ARROW_CHECK_OK(client_.Delete(object_ids));
+}
+
+TEST_F(TestPlasmaStore, GPUBufferLifetime) {
+ // ARROW-5924: GPU buffer is allowed to persist after Release()
+ ObjectID object_id = random_object_id();
+ const int64_t data_size = 40;
+
+ std::shared_ptr<Buffer> create_buff;
+ ARROW_CHECK_OK(
+ client_.Create(object_id, data_size, nullptr, 0, &create_buff, kGpuDeviceNumber));
+ ARROW_CHECK_OK(client_.Seal(object_id));
+ ARROW_CHECK_OK(client_.Release(object_id));
+
+ ObjectBuffer get_buff_1;
+ ARROW_CHECK_OK(client_.Get(&object_id, 1, -1, &get_buff_1));
+ ObjectBuffer get_buff_2;
+ ARROW_CHECK_OK(client_.Get(&object_id, 1, -1, &get_buff_2));
+ ARROW_CHECK_OK(client_.Release(object_id));
+ ARROW_CHECK_OK(client_.Release(object_id));
+
+ ObjectBuffer get_buff_3;
+ ARROW_CHECK_OK(client_.Get(&object_id, 1, -1, &get_buff_3));
+ ARROW_CHECK_OK(client_.Release(object_id));
+
+ ARROW_CHECK_OK(client_.Delete(object_id));
+}
+
+TEST_F(TestPlasmaStore, MultipleClientGPUTest) {
+ ObjectID object_id = random_object_id();
+ std::vector<ObjectBuffer> object_buffers;
+
+ // Test for object nonexistence on the first client.
+ bool has_object;
+ ARROW_CHECK_OK(client_.Contains(object_id, &has_object));
+ ASSERT_FALSE(has_object);
+
+ // Test for the object being in local Plasma store.
+ // First create and seal object on the second client.
+ int64_t data_size = 100;
+ uint8_t metadata[] = {5};
+ int64_t metadata_size = sizeof(metadata);
+ std::shared_ptr<Buffer> data;
+ ARROW_CHECK_OK(client2_.Create(object_id, data_size, metadata, metadata_size, &data,
+ kGpuDeviceNumber));
+ ARROW_CHECK_OK(client2_.Seal(object_id));
+ // Test that the first client can get the object.
+ ARROW_CHECK_OK(client_.Get({object_id}, -1, &object_buffers));
+ ARROW_CHECK_OK(client_.Contains(object_id, &has_object));
+ ASSERT_TRUE(has_object);
+
+ // Test that one client disconnecting does not interfere with the other.
+ // First create object on the second client.
+ object_id = random_object_id();
+ ARROW_CHECK_OK(client2_.Create(object_id, data_size, metadata, metadata_size, &data,
+ kGpuDeviceNumber));
+ // Disconnect the first client.
+ ARROW_CHECK_OK(client_.Disconnect());
+ // Test that the second client can seal and get the created object.
+ ARROW_CHECK_OK(client2_.Seal(object_id));
+ object_buffers.clear();
+ ARROW_CHECK_OK(client2_.Contains(object_id, &has_object));
+ ASSERT_TRUE(has_object);
+ ARROW_CHECK_OK(client2_.Get({object_id}, -1, &object_buffers));
+ ASSERT_EQ(object_buffers.size(), 1);
+ ASSERT_EQ(object_buffers[0].device_num, kGpuDeviceNumber);
+ AssertCudaRead(object_buffers[0].metadata, {5});
+}
+
+#endif // PLASMA_CUDA
+
+} // namespace plasma
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ plasma::test_executable = std::string(argv[0]);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/arrow/cpp/src/plasma/test/external_store_tests.cc b/src/arrow/cpp/src/plasma/test/external_store_tests.cc
new file mode 100644
index 000000000..2b7ab059f
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/test/external_store_tests.cc
@@ -0,0 +1,143 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <assert.h>
+#include <signal.h>
+#include <stdlib.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <memory>
+#include <thread>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/io_util.h"
+
+#include "plasma/client.h"
+#include "plasma/common.h"
+#include "plasma/external_store.h"
+#include "plasma/plasma.h"
+#include "plasma/protocol.h"
+#include "plasma/test_util.h"
+
+namespace plasma {
+
+using arrow::internal::TemporaryDir;
+
+std::string external_test_executable; // NOLINT
+
+void AssertObjectBufferEqual(const ObjectBuffer& object_buffer,
+ const std::string& metadata, const std::string& data) {
+ arrow::AssertBufferEqual(*object_buffer.metadata, metadata);
+ arrow::AssertBufferEqual(*object_buffer.data, data);
+}
+
+class TestPlasmaStoreWithExternal : public ::testing::Test {
+ public:
+ // TODO(pcm): At the moment, stdout of the test gets mixed up with
+ // stdout of the object store. Consider changing that.
+ void SetUp() override {
+ ASSERT_OK_AND_ASSIGN(temp_dir_, TemporaryDir::Make("ext-test-"));
+ store_socket_name_ = temp_dir_->path().ToString() + "store";
+
+ std::string plasma_directory =
+ external_test_executable.substr(0, external_test_executable.find_last_of('/'));
+ std::string plasma_command = plasma_directory +
+ "/plasma-store-server -m 1024000 -e " +
+ "hashtable://test -s " + store_socket_name_ +
+ " 1> /tmp/log.stdout 2> /tmp/log.stderr & " +
+ "echo $! > " + store_socket_name_ + ".pid";
+ PLASMA_CHECK_SYSTEM(system(plasma_command.c_str()));
+ ARROW_CHECK_OK(client_.Connect(store_socket_name_, ""));
+ }
+
+ void TearDown() override {
+ ARROW_CHECK_OK(client_.Disconnect());
+ // Kill plasma_store process that we started
+#ifdef COVERAGE_BUILD
+ // Ask plasma_store to exit gracefully and give it time to write out
+ // coverage files
+ std::string plasma_term_command =
+ "kill -TERM `cat " + store_socket_name_ + ".pid` || exit 0";
+ PLASMA_CHECK_SYSTEM(system(plasma_term_command.c_str()));
+ std::this_thread::sleep_for(std::chrono::milliseconds(200));
+#endif
+ std::string plasma_kill_command =
+ "kill -KILL `cat " + store_socket_name_ + ".pid` || exit 0";
+ PLASMA_CHECK_SYSTEM(system(plasma_kill_command.c_str()));
+ }
+
+ protected:
+ PlasmaClient client_;
+ std::unique_ptr<TemporaryDir> temp_dir_;
+ std::string store_socket_name_;
+};
+
+TEST_F(TestPlasmaStoreWithExternal, EvictionTest) {
+ std::vector<ObjectID> object_ids;
+ std::string data(100 * 1024, 'x');
+ std::string metadata;
+ for (int i = 0; i < 20; i++) {
+ ObjectID object_id = random_object_id();
+ object_ids.push_back(object_id);
+
+ // Test for object nonexistence.
+ bool has_object;
+ ARROW_CHECK_OK(client_.Contains(object_id, &has_object));
+ ASSERT_FALSE(has_object);
+
+ // Test for the object being in local Plasma store.
+ // Create and seal the object.
+ ARROW_CHECK_OK(client_.CreateAndSeal(object_id, data, metadata));
+ // Test that the client can get the object.
+ ARROW_CHECK_OK(client_.Contains(object_id, &has_object));
+ ASSERT_TRUE(has_object);
+ }
+
+ for (int i = 0; i < 20; i++) {
+ // Since we are accessing objects sequentially, every object we
+ // access would be a cache "miss" owing to LRU eviction.
+ // Try and access the object from the plasma store first, and then try
+ // external store on failure. This should succeed to fetch the object.
+ // However, it may evict the next few objects.
+ std::vector<ObjectBuffer> object_buffers;
+ ARROW_CHECK_OK(client_.Get({object_ids[i]}, -1, &object_buffers));
+ ASSERT_EQ(object_buffers.size(), 1);
+ ASSERT_EQ(object_buffers[0].device_num, 0);
+ ASSERT_TRUE(object_buffers[0].data);
+ AssertObjectBufferEqual(object_buffers[0], metadata, data);
+ }
+
+ // Make sure we still cannot fetch objects that do not exist
+ std::vector<ObjectBuffer> object_buffers;
+ ARROW_CHECK_OK(client_.Get({random_object_id()}, 100, &object_buffers));
+ ASSERT_EQ(object_buffers.size(), 1);
+ ASSERT_EQ(object_buffers[0].device_num, 0);
+ ASSERT_EQ(object_buffers[0].data, nullptr);
+ ASSERT_EQ(object_buffers[0].metadata, nullptr);
+}
+
+} // namespace plasma
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ plasma::external_test_executable = std::string(argv[0]);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/arrow/cpp/src/plasma/test/serialization_tests.cc b/src/arrow/cpp/src/plasma/test/serialization_tests.cc
new file mode 100644
index 000000000..a9eea7be7
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/test/serialization_tests.cc
@@ -0,0 +1,333 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <sstream>
+
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/io_util.h"
+
+#include "plasma/common.h"
+#include "plasma/io.h"
+#include "plasma/plasma.h"
+#include "plasma/protocol.h"
+#include "plasma/test_util.h"
+
+namespace fb = plasma::flatbuf;
+
+namespace plasma {
+
+using arrow::internal::TemporaryDir;
+
+/**
+ * Seek to the beginning of a file and read a message from it.
+ *
+ * \param fd File descriptor of the file.
+ * \param message_type Message type that we expect in the file.
+ *
+ * \return Pointer to the content of the message. Needs to be freed by the
+ * caller.
+ */
+std::vector<uint8_t> read_message_from_file(int fd, MessageType message_type) {
+ /* Go to the beginning of the file. */
+ lseek(fd, 0, SEEK_SET);
+ MessageType type;
+ std::vector<uint8_t> data;
+ Status s = ReadMessage(fd, &type, &data);
+ DCHECK_OK(s);
+ DCHECK_EQ(type, message_type);
+ return data;
+}
+
+PlasmaObject random_plasma_object(void) {
+ unsigned int seed = static_cast<unsigned int>(time(NULL));
+ int random = rand_r(&seed);
+ PlasmaObject object = {};
+ object.store_fd = random + 7;
+ object.data_offset = random + 1;
+ object.metadata_offset = random + 2;
+ object.data_size = random + 3;
+ object.metadata_size = random + 4;
+ object.device_num = 0;
+ return object;
+}
+
+class TestPlasmaSerialization : public ::testing::Test {
+ public:
+ void SetUp() { ASSERT_OK_AND_ASSIGN(temp_dir_, TemporaryDir::Make("ser-test-")); }
+
+ // Create a temporary file.
+ // A fd is returned which must be closed manually. The file itself
+ // is deleted at the end of the test.
+ int CreateTemporaryFile(void) {
+ char path[1024];
+
+ std::stringstream ss;
+ ss << temp_dir_->path().ToString() << "fileXXXXXX";
+ strncpy(path, ss.str().c_str(), sizeof(path));
+ ARROW_LOG(INFO) << "file path: '" << path << "'";
+ return mkstemp(path);
+ }
+
+ protected:
+ std::unique_ptr<TemporaryDir> temp_dir_;
+};
+
+TEST_F(TestPlasmaSerialization, CreateRequest) {
+ int fd = CreateTemporaryFile();
+ ObjectID object_id1 = random_object_id();
+ int64_t data_size1 = 42;
+ int64_t metadata_size1 = 11;
+ int device_num1 = 0;
+ ASSERT_OK(SendCreateRequest(fd, object_id1, /*evict_if_full=*/true, data_size1,
+ metadata_size1, device_num1));
+ std::vector<uint8_t> data =
+ read_message_from_file(fd, MessageType::PlasmaCreateRequest);
+ ObjectID object_id2;
+ bool evict_if_full;
+ int64_t data_size2;
+ int64_t metadata_size2;
+ int device_num2;
+ ASSERT_OK(ReadCreateRequest(data.data(), data.size(), &object_id2, &evict_if_full,
+ &data_size2, &metadata_size2, &device_num2));
+ ASSERT_TRUE(evict_if_full);
+ ASSERT_EQ(data_size1, data_size2);
+ ASSERT_EQ(metadata_size1, metadata_size2);
+ ASSERT_EQ(object_id1, object_id2);
+ ASSERT_EQ(device_num1, device_num2);
+ close(fd);
+}
+
+TEST_F(TestPlasmaSerialization, CreateReply) {
+ int fd = CreateTemporaryFile();
+ ObjectID object_id1 = random_object_id();
+ PlasmaObject object1 = random_plasma_object();
+ int64_t mmap_size1 = 1000000;
+ ASSERT_OK(SendCreateReply(fd, object_id1, &object1, PlasmaError::OK, mmap_size1));
+ std::vector<uint8_t> data = read_message_from_file(fd, MessageType::PlasmaCreateReply);
+ ObjectID object_id2;
+ PlasmaObject object2 = {};
+ int store_fd;
+ int64_t mmap_size2;
+ ASSERT_OK(ReadCreateReply(data.data(), data.size(), &object_id2, &object2, &store_fd,
+ &mmap_size2));
+ ASSERT_EQ(object_id1, object_id2);
+ ASSERT_EQ(object1.store_fd, store_fd);
+ ASSERT_EQ(mmap_size1, mmap_size2);
+ ASSERT_EQ(memcmp(&object1, &object2, sizeof(object1)), 0);
+ close(fd);
+}
+
+TEST_F(TestPlasmaSerialization, SealRequest) {
+ int fd = CreateTemporaryFile();
+ ObjectID object_id1 = random_object_id();
+ std::string digest1 = std::string(kDigestSize, 7);
+ ASSERT_OK(SendSealRequest(fd, object_id1, digest1));
+ std::vector<uint8_t> data = read_message_from_file(fd, MessageType::PlasmaSealRequest);
+ ObjectID object_id2;
+ std::string digest2;
+ ASSERT_OK(ReadSealRequest(data.data(), data.size(), &object_id2, &digest2));
+ ASSERT_EQ(object_id1, object_id2);
+ ASSERT_EQ(memcmp(digest1.data(), digest2.data(), kDigestSize), 0);
+ close(fd);
+}
+
+TEST_F(TestPlasmaSerialization, SealReply) {
+ int fd = CreateTemporaryFile();
+ ObjectID object_id1 = random_object_id();
+ ASSERT_OK(SendSealReply(fd, object_id1, PlasmaError::ObjectExists));
+ std::vector<uint8_t> data = read_message_from_file(fd, MessageType::PlasmaSealReply);
+ ObjectID object_id2;
+ Status s = ReadSealReply(data.data(), data.size(), &object_id2);
+ ASSERT_EQ(object_id1, object_id2);
+ ASSERT_TRUE(IsPlasmaObjectExists(s));
+ close(fd);
+}
+
+TEST_F(TestPlasmaSerialization, GetRequest) {
+ int fd = CreateTemporaryFile();
+ ObjectID object_ids[2];
+ object_ids[0] = random_object_id();
+ object_ids[1] = random_object_id();
+ int64_t timeout_ms = 1234;
+ ASSERT_OK(SendGetRequest(fd, object_ids, 2, timeout_ms));
+ std::vector<uint8_t> data = read_message_from_file(fd, MessageType::PlasmaGetRequest);
+ std::vector<ObjectID> object_ids_return;
+ int64_t timeout_ms_return;
+ ASSERT_OK(
+ ReadGetRequest(data.data(), data.size(), object_ids_return, &timeout_ms_return));
+ ASSERT_EQ(object_ids[0], object_ids_return[0]);
+ ASSERT_EQ(object_ids[1], object_ids_return[1]);
+ ASSERT_EQ(timeout_ms, timeout_ms_return);
+ close(fd);
+}
+
+TEST_F(TestPlasmaSerialization, GetReply) {
+ int fd = CreateTemporaryFile();
+ ObjectID object_ids[2];
+ object_ids[0] = random_object_id();
+ object_ids[1] = random_object_id();
+ std::unordered_map<ObjectID, PlasmaObject> plasma_objects;
+ plasma_objects[object_ids[0]] = random_plasma_object();
+ plasma_objects[object_ids[1]] = random_plasma_object();
+ std::vector<int> store_fds = {1, 2, 3};
+ std::vector<int64_t> mmap_sizes = {100, 200, 300};
+ ASSERT_OK(SendGetReply(fd, object_ids, plasma_objects, 2, store_fds, mmap_sizes));
+
+ std::vector<uint8_t> data = read_message_from_file(fd, MessageType::PlasmaGetReply);
+ ObjectID object_ids_return[2];
+ PlasmaObject plasma_objects_return[2];
+ std::vector<int> store_fds_return;
+ std::vector<int64_t> mmap_sizes_return;
+ memset(&plasma_objects_return, 0, sizeof(plasma_objects_return));
+ ASSERT_OK(ReadGetReply(data.data(), data.size(), object_ids_return,
+ &plasma_objects_return[0], 2, store_fds_return,
+ mmap_sizes_return));
+
+ ASSERT_EQ(object_ids[0], object_ids_return[0]);
+ ASSERT_EQ(object_ids[1], object_ids_return[1]);
+
+ PlasmaObject po, po2;
+ for (int i = 0; i < 2; ++i) {
+ po = plasma_objects[object_ids[i]];
+ po2 = plasma_objects_return[i];
+ ASSERT_EQ(po, po2);
+ }
+ ASSERT_TRUE(store_fds == store_fds_return);
+ ASSERT_TRUE(mmap_sizes == mmap_sizes_return);
+ close(fd);
+}
+
+TEST_F(TestPlasmaSerialization, ReleaseRequest) {
+ int fd = CreateTemporaryFile();
+ ObjectID object_id1 = random_object_id();
+ ASSERT_OK(SendReleaseRequest(fd, object_id1));
+ std::vector<uint8_t> data =
+ read_message_from_file(fd, MessageType::PlasmaReleaseRequest);
+ ObjectID object_id2;
+ ASSERT_OK(ReadReleaseRequest(data.data(), data.size(), &object_id2));
+ ASSERT_EQ(object_id1, object_id2);
+ close(fd);
+}
+
+TEST_F(TestPlasmaSerialization, ReleaseReply) {
+ int fd = CreateTemporaryFile();
+ ObjectID object_id1 = random_object_id();
+ ASSERT_OK(SendReleaseReply(fd, object_id1, PlasmaError::ObjectExists));
+ std::vector<uint8_t> data = read_message_from_file(fd, MessageType::PlasmaReleaseReply);
+ ObjectID object_id2;
+ Status s = ReadReleaseReply(data.data(), data.size(), &object_id2);
+ ASSERT_EQ(object_id1, object_id2);
+ ASSERT_TRUE(IsPlasmaObjectExists(s));
+ close(fd);
+}
+
+TEST_F(TestPlasmaSerialization, DeleteRequest) {
+ int fd = CreateTemporaryFile();
+ ObjectID object_id1 = random_object_id();
+ ASSERT_OK(SendDeleteRequest(fd, std::vector<ObjectID>{object_id1}));
+ std::vector<uint8_t> data =
+ read_message_from_file(fd, MessageType::PlasmaDeleteRequest);
+ std::vector<ObjectID> object_vec;
+ ASSERT_OK(ReadDeleteRequest(data.data(), data.size(), &object_vec));
+ ASSERT_EQ(object_vec.size(), 1);
+ ASSERT_EQ(object_id1, object_vec[0]);
+ close(fd);
+}
+
+TEST_F(TestPlasmaSerialization, DeleteReply) {
+ int fd = CreateTemporaryFile();
+ ObjectID object_id1 = random_object_id();
+ PlasmaError error1 = PlasmaError::ObjectExists;
+ ASSERT_OK(SendDeleteReply(fd, std::vector<ObjectID>{object_id1},
+ std::vector<PlasmaError>{error1}));
+ std::vector<uint8_t> data = read_message_from_file(fd, MessageType::PlasmaDeleteReply);
+ std::vector<ObjectID> object_vec;
+ std::vector<PlasmaError> error_vec;
+ Status s = ReadDeleteReply(data.data(), data.size(), &object_vec, &error_vec);
+ ASSERT_EQ(object_vec.size(), 1);
+ ASSERT_EQ(object_id1, object_vec[0]);
+ ASSERT_EQ(error_vec.size(), 1);
+ ASSERT_TRUE(error_vec[0] == PlasmaError::ObjectExists);
+ ASSERT_TRUE(s.ok());
+ close(fd);
+}
+
+TEST_F(TestPlasmaSerialization, EvictRequest) {
+ int fd = CreateTemporaryFile();
+ int64_t num_bytes = 111;
+ ASSERT_OK(SendEvictRequest(fd, num_bytes));
+ std::vector<uint8_t> data = read_message_from_file(fd, MessageType::PlasmaEvictRequest);
+ int64_t num_bytes_received;
+ ASSERT_OK(ReadEvictRequest(data.data(), data.size(), &num_bytes_received));
+ ASSERT_EQ(num_bytes, num_bytes_received);
+ close(fd);
+}
+
+TEST_F(TestPlasmaSerialization, EvictReply) {
+ int fd = CreateTemporaryFile();
+ int64_t num_bytes = 111;
+ ASSERT_OK(SendEvictReply(fd, num_bytes));
+ std::vector<uint8_t> data = read_message_from_file(fd, MessageType::PlasmaEvictReply);
+ int64_t num_bytes_received;
+ ASSERT_OK(ReadEvictReply(data.data(), data.size(), num_bytes_received));
+ ASSERT_EQ(num_bytes, num_bytes_received);
+ close(fd);
+}
+
+TEST_F(TestPlasmaSerialization, DataRequest) {
+ int fd = CreateTemporaryFile();
+ ObjectID object_id1 = random_object_id();
+ const char* address1 = "address1";
+ int port1 = 12345;
+ ASSERT_OK(SendDataRequest(fd, object_id1, address1, port1));
+ /* Reading message back. */
+ std::vector<uint8_t> data = read_message_from_file(fd, MessageType::PlasmaDataRequest);
+ ObjectID object_id2;
+ char* address2;
+ int port2;
+ ASSERT_OK(ReadDataRequest(data.data(), data.size(), &object_id2, &address2, &port2));
+ ASSERT_EQ(object_id1, object_id2);
+ ASSERT_EQ(strcmp(address1, address2), 0);
+ ASSERT_EQ(port1, port2);
+ free(address2);
+ close(fd);
+}
+
+TEST_F(TestPlasmaSerialization, DataReply) {
+ int fd = CreateTemporaryFile();
+ ObjectID object_id1 = random_object_id();
+ int64_t object_size1 = 146;
+ int64_t metadata_size1 = 198;
+ ASSERT_OK(SendDataReply(fd, object_id1, object_size1, metadata_size1));
+ /* Reading message back. */
+ std::vector<uint8_t> data = read_message_from_file(fd, MessageType::PlasmaDataReply);
+ ObjectID object_id2;
+ int64_t object_size2;
+ int64_t metadata_size2;
+ ASSERT_OK(ReadDataReply(data.data(), data.size(), &object_id2, &object_size2,
+ &metadata_size2));
+ ASSERT_EQ(object_id1, object_id2);
+ ASSERT_EQ(object_size1, object_size2);
+ ASSERT_EQ(metadata_size1, metadata_size2);
+}
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/test_util.h b/src/arrow/cpp/src/plasma/test_util.h
new file mode 100644
index 000000000..81dae0152
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/test_util.h
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <limits>
+#include <random>
+
+#include "plasma/common.h"
+
+namespace plasma {
+
+ObjectID random_object_id() {
+ static uint32_t random_seed = 0;
+ std::mt19937 gen(random_seed++);
+ std::uniform_int_distribution<uint32_t> d(0, std::numeric_limits<uint8_t>::max());
+ ObjectID result;
+ uint8_t* data = result.mutable_data();
+ std::generate(data, data + kUniqueIDSize,
+ [&d, &gen] { return static_cast<uint8_t>(d(gen)); });
+ return result;
+}
+
+#define PLASMA_CHECK_SYSTEM(expr) \
+ do { \
+ int status__ = (expr); \
+ EXPECT_TRUE(WIFEXITED(status__)); \
+ EXPECT_EQ(WEXITSTATUS(status__), 0); \
+ } while (false);
+
+} // namespace plasma
diff --git a/src/arrow/cpp/src/plasma/thirdparty/ae/ae.c b/src/arrow/cpp/src/plasma/thirdparty/ae/ae.c
new file mode 100644
index 000000000..dfb722444
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/thirdparty/ae/ae.c
@@ -0,0 +1,465 @@
+/* A simple event-driven programming library. Originally I wrote this code
+ * for the Jim's event-loop (Jim is a Tcl interpreter) but later translated
+ * it in form of a library for easy reuse.
+ *
+ * Copyright (c) 2006-2010, Salvatore Sanfilippo <antirez at gmail dot com>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * * Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ * * Neither the name of Redis nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without
+ * specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#include <stdio.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <stdlib.h>
+#include <poll.h>
+#include <string.h>
+#include <time.h>
+#include <errno.h>
+
+#include "plasma/thirdparty/ae/ae.h"
+#include "plasma/thirdparty/ae/zmalloc.h"
+#include "plasma/thirdparty/ae/config.h"
+
+/* Include the best multiplexing layer supported by this system.
+ * The following should be ordered by performances, descending. */
+#ifdef HAVE_EVPORT
+#include "plasma/thirdparty/ae/ae_evport.c"
+#else
+ #ifdef HAVE_EPOLL
+ #include "plasma/thirdparty/ae/ae_epoll.c"
+ #else
+ #ifdef HAVE_KQUEUE
+ #include "plasma/thirdparty/ae/ae_kqueue.c"
+ #else
+ #include "plasma/thirdparty/ae/ae_select.c"
+ #endif
+ #endif
+#endif
+
+aeEventLoop *aeCreateEventLoop(int setsize) {
+ aeEventLoop *eventLoop;
+ int i;
+
+ if ((eventLoop = zmalloc(sizeof(*eventLoop))) == NULL) goto err;
+ eventLoop->events = zmalloc(sizeof(aeFileEvent)*setsize);
+ eventLoop->fired = zmalloc(sizeof(aeFiredEvent)*setsize);
+ if (eventLoop->events == NULL || eventLoop->fired == NULL) goto err;
+ eventLoop->setsize = setsize;
+ eventLoop->lastTime = time(NULL);
+ eventLoop->timeEventHead = NULL;
+ eventLoop->timeEventNextId = 0;
+ eventLoop->stop = 0;
+ eventLoop->maxfd = -1;
+ eventLoop->beforesleep = NULL;
+ if (aeApiCreate(eventLoop) == -1) goto err;
+ /* Events with mask == AE_NONE are not set. So let's initialize the
+ * vector with it. */
+ for (i = 0; i < setsize; i++)
+ eventLoop->events[i].mask = AE_NONE;
+ return eventLoop;
+
+err:
+ if (eventLoop) {
+ zfree(eventLoop->events);
+ zfree(eventLoop->fired);
+ zfree(eventLoop);
+ }
+ return NULL;
+}
+
+/* Return the current set size. */
+int aeGetSetSize(aeEventLoop *eventLoop) {
+ return eventLoop->setsize;
+}
+
+/* Resize the maximum set size of the event loop.
+ * If the requested set size is smaller than the current set size, but
+ * there is already a file descriptor in use that is >= the requested
+ * set size minus one, AE_ERR is returned and the operation is not
+ * performed at all.
+ *
+ * Otherwise AE_OK is returned and the operation is successful. */
+int aeResizeSetSize(aeEventLoop *eventLoop, int setsize) {
+ int i;
+
+ if (setsize == eventLoop->setsize) return AE_OK;
+ if (eventLoop->maxfd >= setsize) return AE_ERR;
+ if (aeApiResize(eventLoop,setsize) == -1) return AE_ERR;
+
+ eventLoop->events = zrealloc(eventLoop->events,sizeof(aeFileEvent)*setsize);
+ eventLoop->fired = zrealloc(eventLoop->fired,sizeof(aeFiredEvent)*setsize);
+ eventLoop->setsize = setsize;
+
+ /* Make sure that if we created new slots, they are initialized with
+ * an AE_NONE mask. */
+ for (i = eventLoop->maxfd+1; i < setsize; i++)
+ eventLoop->events[i].mask = AE_NONE;
+ return AE_OK;
+}
+
+void aeDeleteEventLoop(aeEventLoop *eventLoop) {
+ aeApiFree(eventLoop);
+ zfree(eventLoop->events);
+ zfree(eventLoop->fired);
+ zfree(eventLoop);
+}
+
+void aeStop(aeEventLoop *eventLoop) {
+ eventLoop->stop = 1;
+}
+
+int aeCreateFileEvent(aeEventLoop *eventLoop, int fd, int mask,
+ aeFileProc *proc, void *clientData)
+{
+ if (fd >= eventLoop->setsize) {
+ errno = ERANGE;
+ return AE_ERR;
+ }
+ aeFileEvent *fe = &eventLoop->events[fd];
+
+ if (aeApiAddEvent(eventLoop, fd, mask) == -1)
+ return AE_ERR;
+ fe->mask |= mask;
+ if (mask & AE_READABLE) fe->rfileProc = proc;
+ if (mask & AE_WRITABLE) fe->wfileProc = proc;
+ fe->clientData = clientData;
+ if (fd > eventLoop->maxfd)
+ eventLoop->maxfd = fd;
+ return AE_OK;
+}
+
+void aeDeleteFileEvent(aeEventLoop *eventLoop, int fd, int mask)
+{
+ if (fd >= eventLoop->setsize) return;
+ aeFileEvent *fe = &eventLoop->events[fd];
+ if (fe->mask == AE_NONE) return;
+
+ aeApiDelEvent(eventLoop, fd, mask);
+ fe->mask = fe->mask & (~mask);
+ if (fd == eventLoop->maxfd && fe->mask == AE_NONE) {
+ /* Update the max fd */
+ int j;
+
+ for (j = eventLoop->maxfd-1; j >= 0; j--)
+ if (eventLoop->events[j].mask != AE_NONE) break;
+ eventLoop->maxfd = j;
+ }
+}
+
+int aeGetFileEvents(aeEventLoop *eventLoop, int fd) {
+ if (fd >= eventLoop->setsize) return 0;
+ aeFileEvent *fe = &eventLoop->events[fd];
+
+ return fe->mask;
+}
+
+static void aeGetTime(long *seconds, long *milliseconds)
+{
+ struct timeval tv;
+
+ gettimeofday(&tv, NULL);
+ *seconds = tv.tv_sec;
+ *milliseconds = tv.tv_usec/1000;
+}
+
+static void aeAddMillisecondsToNow(long long milliseconds, long *sec, long *ms) {
+ long cur_sec, cur_ms, when_sec, when_ms;
+
+ aeGetTime(&cur_sec, &cur_ms);
+ when_sec = cur_sec + milliseconds/1000;
+ when_ms = cur_ms + milliseconds%1000;
+ if (when_ms >= 1000) {
+ when_sec ++;
+ when_ms -= 1000;
+ }
+ *sec = when_sec;
+ *ms = when_ms;
+}
+
+long long aeCreateTimeEvent(aeEventLoop *eventLoop, long long milliseconds,
+ aeTimeProc *proc, void *clientData,
+ aeEventFinalizerProc *finalizerProc)
+{
+ long long id = eventLoop->timeEventNextId++;
+ aeTimeEvent *te;
+
+ te = zmalloc(sizeof(*te));
+ if (te == NULL) return AE_ERR;
+ te->id = id;
+ aeAddMillisecondsToNow(milliseconds,&te->when_sec,&te->when_ms);
+ te->timeProc = proc;
+ te->finalizerProc = finalizerProc;
+ te->clientData = clientData;
+ te->next = eventLoop->timeEventHead;
+ eventLoop->timeEventHead = te;
+ return id;
+}
+
+int aeDeleteTimeEvent(aeEventLoop *eventLoop, long long id)
+{
+ aeTimeEvent *te = eventLoop->timeEventHead;
+ while(te) {
+ if (te->id == id) {
+ te->id = AE_DELETED_EVENT_ID;
+ return AE_OK;
+ }
+ te = te->next;
+ }
+ return AE_ERR; /* NO event with the specified ID found */
+}
+
+/* Search the first timer to fire.
+ * This operation is useful to know how many time the select can be
+ * put in sleep without to delay any event.
+ * If there are no timers NULL is returned.
+ *
+ * Note that's O(N) since time events are unsorted.
+ * Possible optimizations (not needed by Redis so far, but...):
+ * 1) Insert the event in order, so that the nearest is just the head.
+ * Much better but still insertion or deletion of timers is O(N).
+ * 2) Use a skiplist to have this operation as O(1) and insertion as O(log(N)).
+ */
+static aeTimeEvent *aeSearchNearestTimer(aeEventLoop *eventLoop)
+{
+ aeTimeEvent *te = eventLoop->timeEventHead;
+ aeTimeEvent *nearest = NULL;
+
+ while(te) {
+ if (!nearest || te->when_sec < nearest->when_sec ||
+ (te->when_sec == nearest->when_sec &&
+ te->when_ms < nearest->when_ms))
+ nearest = te;
+ te = te->next;
+ }
+ return nearest;
+}
+
+/* Process time events */
+static int processTimeEvents(aeEventLoop *eventLoop) {
+ int processed = 0;
+ aeTimeEvent *te, *prev;
+ long long maxId;
+ time_t now = time(NULL);
+
+ /* If the system clock is moved to the future, and then set back to the
+ * right value, time events may be delayed in a random way. Often this
+ * means that scheduled operations will not be performed soon enough.
+ *
+ * Here we try to detect system clock skews, and force all the time
+ * events to be processed ASAP when this happens: the idea is that
+ * processing events earlier is less dangerous than delaying them
+ * indefinitely, and practice suggests it is. */
+ if (now < eventLoop->lastTime) {
+ te = eventLoop->timeEventHead;
+ while(te) {
+ te->when_sec = 0;
+ te = te->next;
+ }
+ }
+ eventLoop->lastTime = now;
+
+ prev = NULL;
+ te = eventLoop->timeEventHead;
+ maxId = eventLoop->timeEventNextId-1;
+ while(te) {
+ long now_sec, now_ms;
+ long long id;
+
+ /* Remove events scheduled for deletion. */
+ if (te->id == AE_DELETED_EVENT_ID) {
+ aeTimeEvent *next = te->next;
+ if (prev == NULL)
+ eventLoop->timeEventHead = te->next;
+ else
+ prev->next = te->next;
+ if (te->finalizerProc)
+ te->finalizerProc(eventLoop, te->clientData);
+ zfree(te);
+ te = next;
+ continue;
+ }
+
+ /* Make sure we don't process time events created by time events in
+ * this iteration. Note that this check is currently useless: we always
+ * add new timers on the head, however if we change the implementation
+ * detail, this check may be useful again: we keep it here for future
+ * defense. */
+ if (te->id > maxId) {
+ te = te->next;
+ continue;
+ }
+ aeGetTime(&now_sec, &now_ms);
+ if (now_sec > te->when_sec ||
+ (now_sec == te->when_sec && now_ms >= te->when_ms))
+ {
+ int retval;
+
+ id = te->id;
+ retval = te->timeProc(eventLoop, id, te->clientData);
+ processed++;
+ if (retval != AE_NOMORE) {
+ aeAddMillisecondsToNow(retval,&te->when_sec,&te->when_ms);
+ } else {
+ te->id = AE_DELETED_EVENT_ID;
+ }
+ }
+ prev = te;
+ te = te->next;
+ }
+ return processed;
+}
+
+/* Process every pending time event, then every pending file event
+ * (that may be registered by time event callbacks just processed).
+ * Without special flags the function sleeps until some file event
+ * fires, or when the next time event occurs (if any).
+ *
+ * If flags is 0, the function does nothing and returns.
+ * if flags has AE_ALL_EVENTS set, all the kind of events are processed.
+ * if flags has AE_FILE_EVENTS set, file events are processed.
+ * if flags has AE_TIME_EVENTS set, time events are processed.
+ * if flags has AE_DONT_WAIT set the function returns ASAP until all
+ * the events that's possible to process without to wait are processed.
+ *
+ * The function returns the number of events processed. */
+int aeProcessEvents(aeEventLoop *eventLoop, int flags)
+{
+ int processed = 0, numevents;
+
+ /* Nothing to do? return ASAP */
+ if (!(flags & AE_TIME_EVENTS) && !(flags & AE_FILE_EVENTS)) return 0;
+
+ /* Note that we want call select() even if there are no
+ * file events to process as long as we want to process time
+ * events, in order to sleep until the next time event is ready
+ * to fire. */
+ if (eventLoop->maxfd != -1 ||
+ ((flags & AE_TIME_EVENTS) && !(flags & AE_DONT_WAIT))) {
+ int j;
+ aeTimeEvent *shortest = NULL;
+ struct timeval tv, *tvp;
+
+ if (flags & AE_TIME_EVENTS && !(flags & AE_DONT_WAIT))
+ shortest = aeSearchNearestTimer(eventLoop);
+ if (shortest) {
+ long now_sec, now_ms;
+
+ aeGetTime(&now_sec, &now_ms);
+ tvp = &tv;
+
+ /* How many milliseconds we need to wait for the next
+ * time event to fire? */
+ long long ms =
+ (shortest->when_sec - now_sec)*1000 +
+ shortest->when_ms - now_ms;
+
+ if (ms > 0) {
+ tvp->tv_sec = ms/1000;
+ tvp->tv_usec = (ms % 1000)*1000;
+ } else {
+ tvp->tv_sec = 0;
+ tvp->tv_usec = 0;
+ }
+ } else {
+ /* If we have to check for events but need to return
+ * ASAP because of AE_DONT_WAIT we need to set the timeout
+ * to zero */
+ if (flags & AE_DONT_WAIT) {
+ tv.tv_sec = tv.tv_usec = 0;
+ tvp = &tv;
+ } else {
+ /* Otherwise we can block */
+ tvp = NULL; /* wait forever */
+ }
+ }
+
+ numevents = aeApiPoll(eventLoop, tvp);
+ for (j = 0; j < numevents; j++) {
+ aeFileEvent *fe = &eventLoop->events[eventLoop->fired[j].fd];
+ int mask = eventLoop->fired[j].mask;
+ int fd = eventLoop->fired[j].fd;
+ int rfired = 0;
+
+ /* note the fe->mask & mask & ... code: maybe an already processed
+ * event removed an element that fired and we still didn't
+ * processed, so we check if the event is still valid. */
+ if (fe->mask & mask & AE_READABLE) {
+ rfired = 1;
+ fe->rfileProc(eventLoop,fd,fe->clientData,mask);
+ }
+ if (fe->mask & mask & AE_WRITABLE) {
+ if (!rfired || fe->wfileProc != fe->rfileProc)
+ fe->wfileProc(eventLoop,fd,fe->clientData,mask);
+ }
+ processed++;
+ }
+ }
+ /* Check time events */
+ if (flags & AE_TIME_EVENTS)
+ processed += processTimeEvents(eventLoop);
+
+ return processed; /* return the number of processed file/time events */
+}
+
+/* Wait for milliseconds until the given file descriptor becomes
+ * writable/readable/exception */
+int aeWait(int fd, int mask, long long milliseconds) {
+ struct pollfd pfd;
+ int retmask = 0, retval;
+
+ memset(&pfd, 0, sizeof(pfd));
+ pfd.fd = fd;
+ if (mask & AE_READABLE) pfd.events |= POLLIN;
+ if (mask & AE_WRITABLE) pfd.events |= POLLOUT;
+
+ if ((retval = poll(&pfd, 1, milliseconds))== 1) {
+ if (pfd.revents & POLLIN) retmask |= AE_READABLE;
+ if (pfd.revents & POLLOUT) retmask |= AE_WRITABLE;
+ if (pfd.revents & POLLERR) retmask |= AE_WRITABLE;
+ if (pfd.revents & POLLHUP) retmask |= AE_WRITABLE;
+ return retmask;
+ } else {
+ return retval;
+ }
+}
+
+void aeMain(aeEventLoop *eventLoop) {
+ eventLoop->stop = 0;
+ while (!eventLoop->stop) {
+ if (eventLoop->beforesleep != NULL)
+ eventLoop->beforesleep(eventLoop);
+ aeProcessEvents(eventLoop, AE_ALL_EVENTS);
+ }
+}
+
+char *aeGetApiName(void) {
+ return aeApiName();
+}
+
+void aeSetBeforeSleepProc(aeEventLoop *eventLoop, aeBeforeSleepProc *beforesleep) {
+ eventLoop->beforesleep = beforesleep;
+}
diff --git a/src/arrow/cpp/src/plasma/thirdparty/ae/ae.h b/src/arrow/cpp/src/plasma/thirdparty/ae/ae.h
new file mode 100644
index 000000000..1a5e766e5
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/thirdparty/ae/ae.h
@@ -0,0 +1,121 @@
+/* A simple event-driven programming library. Originally I wrote this code
+ * for the Jim's event-loop (Jim is a Tcl interpreter) but later translated
+ * it in form of a library for easy reuse.
+ *
+ * Copyright (c) 2006-2012, Salvatore Sanfilippo <antirez at gmail dot com>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * * Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ * * Neither the name of Redis nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without
+ * specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#pragma once
+
+#include <time.h>
+
+#define AE_OK 0
+#define AE_ERR -1
+
+#define AE_NONE 0
+#define AE_READABLE 1
+#define AE_WRITABLE 2
+
+#define AE_FILE_EVENTS 1
+#define AE_TIME_EVENTS 2
+#define AE_ALL_EVENTS (AE_FILE_EVENTS|AE_TIME_EVENTS)
+#define AE_DONT_WAIT 4
+
+#define AE_NOMORE -1
+#define AE_DELETED_EVENT_ID -1
+
+/* Macros */
+#define AE_NOTUSED(V) ((void) V)
+
+struct aeEventLoop;
+
+/* Types and data structures */
+typedef void aeFileProc(struct aeEventLoop *eventLoop, int fd, void *clientData, int mask);
+typedef int aeTimeProc(struct aeEventLoop *eventLoop, long long id, void *clientData);
+typedef void aeEventFinalizerProc(struct aeEventLoop *eventLoop, void *clientData);
+typedef void aeBeforeSleepProc(struct aeEventLoop *eventLoop);
+
+/* File event structure */
+typedef struct aeFileEvent {
+ int mask; /* one of AE_(READABLE|WRITABLE) */
+ aeFileProc *rfileProc;
+ aeFileProc *wfileProc;
+ void *clientData;
+} aeFileEvent;
+
+/* Time event structure */
+typedef struct aeTimeEvent {
+ long long id; /* time event identifier. */
+ long when_sec; /* seconds */
+ long when_ms; /* milliseconds */
+ aeTimeProc *timeProc;
+ aeEventFinalizerProc *finalizerProc;
+ void *clientData;
+ struct aeTimeEvent *next;
+} aeTimeEvent;
+
+/* A fired event */
+typedef struct aeFiredEvent {
+ int fd;
+ int mask;
+} aeFiredEvent;
+
+/* State of an event based program */
+typedef struct aeEventLoop {
+ int maxfd; /* highest file descriptor currently registered */
+ int setsize; /* max number of file descriptors tracked */
+ long long timeEventNextId;
+ time_t lastTime; /* Used to detect system clock skew */
+ aeFileEvent *events; /* Registered events */
+ aeFiredEvent *fired; /* Fired events */
+ aeTimeEvent *timeEventHead;
+ int stop;
+ void *apidata; /* This is used for polling API specific data */
+ aeBeforeSleepProc *beforesleep;
+} aeEventLoop;
+
+/* Prototypes */
+aeEventLoop *aeCreateEventLoop(int setsize);
+void aeDeleteEventLoop(aeEventLoop *eventLoop);
+void aeStop(aeEventLoop *eventLoop);
+int aeCreateFileEvent(aeEventLoop *eventLoop, int fd, int mask,
+ aeFileProc *proc, void *clientData);
+void aeDeleteFileEvent(aeEventLoop *eventLoop, int fd, int mask);
+int aeGetFileEvents(aeEventLoop *eventLoop, int fd);
+long long aeCreateTimeEvent(aeEventLoop *eventLoop, long long milliseconds,
+ aeTimeProc *proc, void *clientData,
+ aeEventFinalizerProc *finalizerProc);
+int aeDeleteTimeEvent(aeEventLoop *eventLoop, long long id);
+int aeProcessEvents(aeEventLoop *eventLoop, int flags);
+int aeWait(int fd, int mask, long long milliseconds);
+void aeMain(aeEventLoop *eventLoop);
+char *aeGetApiName(void);
+void aeSetBeforeSleepProc(aeEventLoop *eventLoop, aeBeforeSleepProc *beforesleep);
+int aeGetSetSize(aeEventLoop *eventLoop);
+int aeResizeSetSize(aeEventLoop *eventLoop, int setsize);
+
diff --git a/src/arrow/cpp/src/plasma/thirdparty/ae/ae_epoll.c b/src/arrow/cpp/src/plasma/thirdparty/ae/ae_epoll.c
new file mode 100644
index 000000000..2f70550a9
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/thirdparty/ae/ae_epoll.c
@@ -0,0 +1,137 @@
+/* Linux epoll(2) based ae.c module
+ *
+ * Copyright (c) 2009-2012, Salvatore Sanfilippo <antirez at gmail dot com>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * * Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ * * Neither the name of Redis nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without
+ * specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */
+
+
+#include <sys/epoll.h>
+
+typedef struct aeApiState {
+ int epfd;
+ struct epoll_event *events;
+} aeApiState;
+
+static int aeApiCreate(aeEventLoop *eventLoop) {
+ aeApiState *state = zmalloc(sizeof(aeApiState));
+
+ if (!state) return -1;
+ state->events = zmalloc(sizeof(struct epoll_event)*eventLoop->setsize);
+ if (!state->events) {
+ zfree(state);
+ return -1;
+ }
+ state->epfd = epoll_create(1024); /* 1024 is just a hint for the kernel */
+ if (state->epfd == -1) {
+ zfree(state->events);
+ zfree(state);
+ return -1;
+ }
+ eventLoop->apidata = state;
+ return 0;
+}
+
+static int aeApiResize(aeEventLoop *eventLoop, int setsize) {
+ aeApiState *state = eventLoop->apidata;
+
+ state->events = zrealloc(state->events, sizeof(struct epoll_event)*setsize);
+ return 0;
+}
+
+static void aeApiFree(aeEventLoop *eventLoop) {
+ aeApiState *state = eventLoop->apidata;
+
+ close(state->epfd);
+ zfree(state->events);
+ zfree(state);
+}
+
+static int aeApiAddEvent(aeEventLoop *eventLoop, int fd, int mask) {
+ aeApiState *state = eventLoop->apidata;
+ struct epoll_event ee;
+ memset(&ee, 0, sizeof(struct epoll_event)); // avoid valgrind warning
+ /* If the fd was already monitored for some event, we need a MOD
+ * operation. Otherwise we need an ADD operation. */
+ int op = eventLoop->events[fd].mask == AE_NONE ?
+ EPOLL_CTL_ADD : EPOLL_CTL_MOD;
+
+ ee.events = 0;
+ mask |= eventLoop->events[fd].mask; /* Merge old events */
+ if (mask & AE_READABLE) ee.events |= EPOLLIN;
+ if (mask & AE_WRITABLE) ee.events |= EPOLLOUT;
+ ee.data.fd = fd;
+ if (epoll_ctl(state->epfd,op,fd,&ee) == -1) return -1;
+ return 0;
+}
+
+static void aeApiDelEvent(aeEventLoop *eventLoop, int fd, int delmask) {
+ aeApiState *state = eventLoop->apidata;
+ struct epoll_event ee;
+ memset(&ee, 0, sizeof(struct epoll_event)); // avoid valgrind warning
+ int mask = eventLoop->events[fd].mask & (~delmask);
+
+ ee.events = 0;
+ if (mask & AE_READABLE) ee.events |= EPOLLIN;
+ if (mask & AE_WRITABLE) ee.events |= EPOLLOUT;
+ ee.data.fd = fd;
+ if (mask != AE_NONE) {
+ epoll_ctl(state->epfd,EPOLL_CTL_MOD,fd,&ee);
+ } else {
+ /* Note, Kernel < 2.6.9 requires a non null event pointer even for
+ * EPOLL_CTL_DEL. */
+ epoll_ctl(state->epfd,EPOLL_CTL_DEL,fd,&ee);
+ }
+}
+
+static int aeApiPoll(aeEventLoop *eventLoop, struct timeval *tvp) {
+ aeApiState *state = eventLoop->apidata;
+ int retval, numevents = 0;
+
+ retval = epoll_wait(state->epfd,state->events,eventLoop->setsize,
+ tvp ? (tvp->tv_sec*1000 + tvp->tv_usec/1000) : -1);
+ if (retval > 0) {
+ int j;
+
+ numevents = retval;
+ for (j = 0; j < numevents; j++) {
+ int mask = 0;
+ struct epoll_event *e = state->events+j;
+
+ if (e->events & EPOLLIN) mask |= AE_READABLE;
+ if (e->events & EPOLLOUT) mask |= AE_WRITABLE;
+ if (e->events & EPOLLERR) mask |= AE_WRITABLE;
+ if (e->events & EPOLLHUP) mask |= AE_WRITABLE;
+ eventLoop->fired[j].fd = e->data.fd;
+ eventLoop->fired[j].mask = mask;
+ }
+ }
+ return numevents;
+}
+
+static char *aeApiName(void) {
+ return "epoll";
+}
diff --git a/src/arrow/cpp/src/plasma/thirdparty/ae/ae_evport.c b/src/arrow/cpp/src/plasma/thirdparty/ae/ae_evport.c
new file mode 100644
index 000000000..b79ed9bc7
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/thirdparty/ae/ae_evport.c
@@ -0,0 +1,320 @@
+/* ae.c module for illumos event ports.
+ *
+ * Copyright (c) 2012, Joyent, Inc. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * * Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ * * Neither the name of Redis nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without
+ * specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */
+
+
+#include <assert.h>
+#include <errno.h>
+#include <port.h>
+#include <poll.h>
+
+#include <sys/types.h>
+#include <sys/time.h>
+
+#include <stdio.h>
+
+static int evport_debug = 0;
+
+/*
+ * This file implements the ae API using event ports, present on Solaris-based
+ * systems since Solaris 10. Using the event port interface, we associate file
+ * descriptors with the port. Each association also includes the set of poll(2)
+ * events that the consumer is interested in (e.g., POLLIN and POLLOUT).
+ *
+ * There's one tricky piece to this implementation: when we return events via
+ * aeApiPoll, the corresponding file descriptors become dissociated from the
+ * port. This is necessary because poll events are level-triggered, so if the
+ * fd didn't become dissociated, it would immediately fire another event since
+ * the underlying state hasn't changed yet. We must re-associate the file
+ * descriptor, but only after we know that our caller has actually read from it.
+ * The ae API does not tell us exactly when that happens, but we do know that
+ * it must happen by the time aeApiPoll is called again. Our solution is to
+ * keep track of the last fds returned by aeApiPoll and re-associate them next
+ * time aeApiPoll is invoked.
+ *
+ * To summarize, in this module, each fd association is EITHER (a) represented
+ * only via the in-kernel association OR (b) represented by pending_fds and
+ * pending_masks. (b) is only true for the last fds we returned from aeApiPoll,
+ * and only until we enter aeApiPoll again (at which point we restore the
+ * in-kernel association).
+ */
+#define MAX_EVENT_BATCHSZ 512
+
+typedef struct aeApiState {
+ int portfd; /* event port */
+ int npending; /* # of pending fds */
+ int pending_fds[MAX_EVENT_BATCHSZ]; /* pending fds */
+ int pending_masks[MAX_EVENT_BATCHSZ]; /* pending fds' masks */
+} aeApiState;
+
+static int aeApiCreate(aeEventLoop *eventLoop) {
+ int i;
+ aeApiState *state = zmalloc(sizeof(aeApiState));
+ if (!state) return -1;
+
+ state->portfd = port_create();
+ if (state->portfd == -1) {
+ zfree(state);
+ return -1;
+ }
+
+ state->npending = 0;
+
+ for (i = 0; i < MAX_EVENT_BATCHSZ; i++) {
+ state->pending_fds[i] = -1;
+ state->pending_masks[i] = AE_NONE;
+ }
+
+ eventLoop->apidata = state;
+ return 0;
+}
+
+static int aeApiResize(aeEventLoop *eventLoop, int setsize) {
+ /* Nothing to resize here. */
+ return 0;
+}
+
+static void aeApiFree(aeEventLoop *eventLoop) {
+ aeApiState *state = eventLoop->apidata;
+
+ close(state->portfd);
+ zfree(state);
+}
+
+static int aeApiLookupPending(aeApiState *state, int fd) {
+ int i;
+
+ for (i = 0; i < state->npending; i++) {
+ if (state->pending_fds[i] == fd)
+ return (i);
+ }
+
+ return (-1);
+}
+
+/*
+ * Helper function to invoke port_associate for the given fd and mask.
+ */
+static int aeApiAssociate(const char *where, int portfd, int fd, int mask) {
+ int events = 0;
+ int rv, err;
+
+ if (mask & AE_READABLE)
+ events |= POLLIN;
+ if (mask & AE_WRITABLE)
+ events |= POLLOUT;
+
+ if (evport_debug)
+ fprintf(stderr, "%s: port_associate(%d, 0x%x) = ", where, fd, events);
+
+ rv = port_associate(portfd, PORT_SOURCE_FD, fd, events,
+ (void *)(uintptr_t)mask);
+ err = errno;
+
+ if (evport_debug)
+ fprintf(stderr, "%d (%s)\n", rv, rv == 0 ? "no error" : strerror(err));
+
+ if (rv == -1) {
+ fprintf(stderr, "%s: port_associate: %s\n", where, strerror(err));
+
+ if (err == EAGAIN)
+ fprintf(stderr, "aeApiAssociate: event port limit exceeded.");
+ }
+
+ return rv;
+}
+
+static int aeApiAddEvent(aeEventLoop *eventLoop, int fd, int mask) {
+ aeApiState *state = eventLoop->apidata;
+ int fullmask, pfd;
+
+ if (evport_debug)
+ fprintf(stderr, "aeApiAddEvent: fd %d mask 0x%x\n", fd, mask);
+
+ /*
+ * Since port_associate's "events" argument replaces any existing events, we
+ * must be sure to include whatever events are already associated when
+ * we call port_associate() again.
+ */
+ fullmask = mask | eventLoop->events[fd].mask;
+ pfd = aeApiLookupPending(state, fd);
+
+ if (pfd != -1) {
+ /*
+ * This fd was recently returned from aeApiPoll. It should be safe to
+ * assume that the consumer has processed that poll event, but we play
+ * it safer by simply updating pending_mask. The fd will be
+ * re-associated as usual when aeApiPoll is called again.
+ */
+ if (evport_debug)
+ fprintf(stderr, "aeApiAddEvent: adding to pending fd %d\n", fd);
+ state->pending_masks[pfd] |= fullmask;
+ return 0;
+ }
+
+ return (aeApiAssociate("aeApiAddEvent", state->portfd, fd, fullmask));
+}
+
+static void aeApiDelEvent(aeEventLoop *eventLoop, int fd, int mask) {
+ aeApiState *state = eventLoop->apidata;
+ int fullmask, pfd;
+
+ if (evport_debug)
+ fprintf(stderr, "del fd %d mask 0x%x\n", fd, mask);
+
+ pfd = aeApiLookupPending(state, fd);
+
+ if (pfd != -1) {
+ if (evport_debug)
+ fprintf(stderr, "deleting event from pending fd %d\n", fd);
+
+ /*
+ * This fd was just returned from aeApiPoll, so it's not currently
+ * associated with the port. All we need to do is update
+ * pending_mask appropriately.
+ */
+ state->pending_masks[pfd] &= ~mask;
+
+ if (state->pending_masks[pfd] == AE_NONE)
+ state->pending_fds[pfd] = -1;
+
+ return;
+ }
+
+ /*
+ * The fd is currently associated with the port. Like with the add case
+ * above, we must look at the full mask for the file descriptor before
+ * updating that association. We don't have a good way of knowing what the
+ * events are without looking into the eventLoop state directly. We rely on
+ * the fact that our caller has already updated the mask in the eventLoop.
+ */
+
+ fullmask = eventLoop->events[fd].mask;
+ if (fullmask == AE_NONE) {
+ /*
+ * We're removing *all* events, so use port_dissociate to remove the
+ * association completely. Failure here indicates a bug.
+ */
+ if (evport_debug)
+ fprintf(stderr, "aeApiDelEvent: port_dissociate(%d)\n", fd);
+
+ if (port_dissociate(state->portfd, PORT_SOURCE_FD, fd) != 0) {
+ perror("aeApiDelEvent: port_dissociate");
+ abort(); /* will not return */
+ }
+ } else if (aeApiAssociate("aeApiDelEvent", state->portfd, fd,
+ fullmask) != 0) {
+ /*
+ * ENOMEM is a potentially transient condition, but the kernel won't
+ * generally return it unless things are really bad. EAGAIN indicates
+ * we've reached a resource limit, for which it doesn't make sense to
+ * retry (counter-intuitively). All other errors indicate a bug. In any
+ * of these cases, the best we can do is to abort.
+ */
+ abort(); /* will not return */
+ }
+}
+
+static int aeApiPoll(aeEventLoop *eventLoop, struct timeval *tvp) {
+ aeApiState *state = eventLoop->apidata;
+ struct timespec timeout, *tsp;
+ int mask, i;
+ uint_t nevents;
+ port_event_t event[MAX_EVENT_BATCHSZ];
+
+ /*
+ * If we've returned fd events before, we must re-associate them with the
+ * port now, before calling port_get(). See the block comment at the top of
+ * this file for an explanation of why.
+ */
+ for (i = 0; i < state->npending; i++) {
+ if (state->pending_fds[i] == -1)
+ /* This fd has since been deleted. */
+ continue;
+
+ if (aeApiAssociate("aeApiPoll", state->portfd,
+ state->pending_fds[i], state->pending_masks[i]) != 0) {
+ /* See aeApiDelEvent for why this case is fatal. */
+ abort();
+ }
+
+ state->pending_masks[i] = AE_NONE;
+ state->pending_fds[i] = -1;
+ }
+
+ state->npending = 0;
+
+ if (tvp != NULL) {
+ timeout.tv_sec = tvp->tv_sec;
+ timeout.tv_nsec = tvp->tv_usec * 1000;
+ tsp = &timeout;
+ } else {
+ tsp = NULL;
+ }
+
+ /*
+ * port_getn can return with errno == ETIME having returned some events (!).
+ * So if we get ETIME, we check nevents, too.
+ */
+ nevents = 1;
+ if (port_getn(state->portfd, event, MAX_EVENT_BATCHSZ, &nevents,
+ tsp) == -1 && (errno != ETIME || nevents == 0)) {
+ if (errno == ETIME || errno == EINTR)
+ return 0;
+
+ /* Any other error indicates a bug. */
+ perror("aeApiPoll: port_get");
+ abort();
+ }
+
+ state->npending = nevents;
+
+ for (i = 0; i < nevents; i++) {
+ mask = 0;
+ if (event[i].portev_events & POLLIN)
+ mask |= AE_READABLE;
+ if (event[i].portev_events & POLLOUT)
+ mask |= AE_WRITABLE;
+
+ eventLoop->fired[i].fd = event[i].portev_object;
+ eventLoop->fired[i].mask = mask;
+
+ if (evport_debug)
+ fprintf(stderr, "aeApiPoll: fd %d mask 0x%x\n",
+ (int)event[i].portev_object, mask);
+
+ state->pending_fds[i] = event[i].portev_object;
+ state->pending_masks[i] = (uintptr_t)event[i].portev_user;
+ }
+
+ return nevents;
+}
+
+static char *aeApiName(void) {
+ return "evport";
+}
diff --git a/src/arrow/cpp/src/plasma/thirdparty/ae/ae_kqueue.c b/src/arrow/cpp/src/plasma/thirdparty/ae/ae_kqueue.c
new file mode 100644
index 000000000..6796f4ceb
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/thirdparty/ae/ae_kqueue.c
@@ -0,0 +1,138 @@
+/* Kqueue(2)-based ae.c module
+ *
+ * Copyright (C) 2009 Harish Mallipeddi - harish.mallipeddi@gmail.com
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * * Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ * * Neither the name of Redis nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without
+ * specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */
+
+
+#include <sys/types.h>
+#include <sys/event.h>
+#include <sys/time.h>
+
+typedef struct aeApiState {
+ int kqfd;
+ struct kevent *events;
+} aeApiState;
+
+static int aeApiCreate(aeEventLoop *eventLoop) {
+ aeApiState *state = zmalloc(sizeof(aeApiState));
+
+ if (!state) return -1;
+ state->events = zmalloc(sizeof(struct kevent)*eventLoop->setsize);
+ if (!state->events) {
+ zfree(state);
+ return -1;
+ }
+ state->kqfd = kqueue();
+ if (state->kqfd == -1) {
+ zfree(state->events);
+ zfree(state);
+ return -1;
+ }
+ eventLoop->apidata = state;
+ return 0;
+}
+
+static int aeApiResize(aeEventLoop *eventLoop, int setsize) {
+ aeApiState *state = eventLoop->apidata;
+
+ state->events = zrealloc(state->events, sizeof(struct kevent)*setsize);
+ return 0;
+}
+
+static void aeApiFree(aeEventLoop *eventLoop) {
+ aeApiState *state = eventLoop->apidata;
+
+ close(state->kqfd);
+ zfree(state->events);
+ zfree(state);
+}
+
+static int aeApiAddEvent(aeEventLoop *eventLoop, int fd, int mask) {
+ aeApiState *state = eventLoop->apidata;
+ struct kevent ke;
+
+ if (mask & AE_READABLE) {
+ EV_SET(&ke, fd, EVFILT_READ, EV_ADD, 0, 0, NULL);
+ if (kevent(state->kqfd, &ke, 1, NULL, 0, NULL) == -1) return -1;
+ }
+ if (mask & AE_WRITABLE) {
+ EV_SET(&ke, fd, EVFILT_WRITE, EV_ADD, 0, 0, NULL);
+ if (kevent(state->kqfd, &ke, 1, NULL, 0, NULL) == -1) return -1;
+ }
+ return 0;
+}
+
+static void aeApiDelEvent(aeEventLoop *eventLoop, int fd, int mask) {
+ aeApiState *state = eventLoop->apidata;
+ struct kevent ke;
+
+ if (mask & AE_READABLE) {
+ EV_SET(&ke, fd, EVFILT_READ, EV_DELETE, 0, 0, NULL);
+ kevent(state->kqfd, &ke, 1, NULL, 0, NULL);
+ }
+ if (mask & AE_WRITABLE) {
+ EV_SET(&ke, fd, EVFILT_WRITE, EV_DELETE, 0, 0, NULL);
+ kevent(state->kqfd, &ke, 1, NULL, 0, NULL);
+ }
+}
+
+static int aeApiPoll(aeEventLoop *eventLoop, struct timeval *tvp) {
+ aeApiState *state = eventLoop->apidata;
+ int retval, numevents = 0;
+
+ if (tvp != NULL) {
+ struct timespec timeout;
+ timeout.tv_sec = tvp->tv_sec;
+ timeout.tv_nsec = tvp->tv_usec * 1000;
+ retval = kevent(state->kqfd, NULL, 0, state->events, eventLoop->setsize,
+ &timeout);
+ } else {
+ retval = kevent(state->kqfd, NULL, 0, state->events, eventLoop->setsize,
+ NULL);
+ }
+
+ if (retval > 0) {
+ int j;
+
+ numevents = retval;
+ for(j = 0; j < numevents; j++) {
+ int mask = 0;
+ struct kevent *e = state->events+j;
+
+ if (e->filter == EVFILT_READ) mask |= AE_READABLE;
+ if (e->filter == EVFILT_WRITE) mask |= AE_WRITABLE;
+ eventLoop->fired[j].fd = e->ident;
+ eventLoop->fired[j].mask = mask;
+ }
+ }
+ return numevents;
+}
+
+static char *aeApiName(void) {
+ return "kqueue";
+}
diff --git a/src/arrow/cpp/src/plasma/thirdparty/ae/ae_select.c b/src/arrow/cpp/src/plasma/thirdparty/ae/ae_select.c
new file mode 100644
index 000000000..c039a8ea3
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/thirdparty/ae/ae_select.c
@@ -0,0 +1,106 @@
+/* Select()-based ae.c module.
+ *
+ * Copyright (c) 2009-2012, Salvatore Sanfilippo <antirez at gmail dot com>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * * Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ * * Neither the name of Redis nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without
+ * specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */
+
+
+#include <sys/select.h>
+#include <string.h>
+
+typedef struct aeApiState {
+ fd_set rfds, wfds;
+ /* We need to have a copy of the fd sets as it's not safe to reuse
+ * FD sets after select(). */
+ fd_set _rfds, _wfds;
+} aeApiState;
+
+static int aeApiCreate(aeEventLoop *eventLoop) {
+ aeApiState *state = zmalloc(sizeof(aeApiState));
+
+ if (!state) return -1;
+ FD_ZERO(&state->rfds);
+ FD_ZERO(&state->wfds);
+ eventLoop->apidata = state;
+ return 0;
+}
+
+static int aeApiResize(aeEventLoop *eventLoop, int setsize) {
+ /* Just ensure we have enough room in the fd_set type. */
+ if (setsize >= FD_SETSIZE) return -1;
+ return 0;
+}
+
+static void aeApiFree(aeEventLoop *eventLoop) {
+ zfree(eventLoop->apidata);
+}
+
+static int aeApiAddEvent(aeEventLoop *eventLoop, int fd, int mask) {
+ aeApiState *state = eventLoop->apidata;
+
+ if (mask & AE_READABLE) FD_SET(fd,&state->rfds);
+ if (mask & AE_WRITABLE) FD_SET(fd,&state->wfds);
+ return 0;
+}
+
+static void aeApiDelEvent(aeEventLoop *eventLoop, int fd, int mask) {
+ aeApiState *state = eventLoop->apidata;
+
+ if (mask & AE_READABLE) FD_CLR(fd,&state->rfds);
+ if (mask & AE_WRITABLE) FD_CLR(fd,&state->wfds);
+}
+
+static int aeApiPoll(aeEventLoop *eventLoop, struct timeval *tvp) {
+ aeApiState *state = eventLoop->apidata;
+ int retval, j, numevents = 0;
+
+ memcpy(&state->_rfds,&state->rfds,sizeof(fd_set));
+ memcpy(&state->_wfds,&state->wfds,sizeof(fd_set));
+
+ retval = select(eventLoop->maxfd+1,
+ &state->_rfds,&state->_wfds,NULL,tvp);
+ if (retval > 0) {
+ for (j = 0; j <= eventLoop->maxfd; j++) {
+ int mask = 0;
+ aeFileEvent *fe = &eventLoop->events[j];
+
+ if (fe->mask == AE_NONE) continue;
+ if (fe->mask & AE_READABLE && FD_ISSET(j,&state->_rfds))
+ mask |= AE_READABLE;
+ if (fe->mask & AE_WRITABLE && FD_ISSET(j,&state->_wfds))
+ mask |= AE_WRITABLE;
+ eventLoop->fired[numevents].fd = j;
+ eventLoop->fired[numevents].mask = mask;
+ numevents++;
+ }
+ }
+ return numevents;
+}
+
+static char *aeApiName(void) {
+ return "select";
+}
diff --git a/src/arrow/cpp/src/plasma/thirdparty/ae/config.h b/src/arrow/cpp/src/plasma/thirdparty/ae/config.h
new file mode 100644
index 000000000..9e4b20820
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/thirdparty/ae/config.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2009-2012, Salvatore Sanfilippo <antirez at gmail dot com>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * * Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ * * Neither the name of Redis nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without
+ * specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#pragma once
+
+#ifdef __APPLE__
+#include <AvailabilityMacros.h>
+#endif
+
+/* Test for polling API */
+#ifdef __linux__
+#define HAVE_EPOLL 1
+#endif
+
+#if (defined(__APPLE__) && defined(MAC_OS_X_VERSION_10_6)) || defined(__FreeBSD__) || defined(__OpenBSD__) || defined (__NetBSD__)
+#define HAVE_KQUEUE 1
+#endif
+
+#ifdef __sun
+#include <sys/feature_tests.h>
+#ifdef _DTRACE_VERSION
+#define HAVE_EVPORT 1
+#endif
+#endif
+
+
diff --git a/src/arrow/cpp/src/plasma/thirdparty/ae/zmalloc.h b/src/arrow/cpp/src/plasma/thirdparty/ae/zmalloc.h
new file mode 100644
index 000000000..8894d7605
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/thirdparty/ae/zmalloc.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2009-2012, Salvatore Sanfilippo <antirez at gmail dot com>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * * Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ * * Neither the name of Redis nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without
+ * specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#pragma once
+
+#ifndef zmalloc
+#define zmalloc malloc
+#endif
+
+#ifndef zfree
+#define zfree free
+#endif
+
+#ifndef zrealloc
+#define zrealloc realloc
+#endif
+
diff --git a/src/arrow/cpp/src/plasma/thirdparty/dlmalloc.c b/src/arrow/cpp/src/plasma/thirdparty/dlmalloc.c
new file mode 100644
index 000000000..47346ff7b
--- /dev/null
+++ b/src/arrow/cpp/src/plasma/thirdparty/dlmalloc.c
@@ -0,0 +1,6296 @@
+/*
+ This is a version (aka dlmalloc) of malloc/free/realloc written by
+ Doug Lea and released to the public domain, as explained at
+ http://creativecommons.org/publicdomain/zero/1.0/ Send questions,
+ comments, complaints, performance data, etc to dl@cs.oswego.edu
+
+* Version 2.8.6 Wed Aug 29 06:57:58 2012 Doug Lea
+ Note: There may be an updated version of this malloc obtainable at
+ ftp://gee.cs.oswego.edu/pub/misc/malloc.c
+ Check before installing!
+
+* Quickstart
+
+ This library is all in one file to simplify the most common usage:
+ ftp it, compile it (-O3), and link it into another program. All of
+ the compile-time options default to reasonable values for use on
+ most platforms. You might later want to step through various
+ compile-time and dynamic tuning options.
+
+ For convenience, an include file for code using this malloc is at:
+ ftp://gee.cs.oswego.edu/pub/misc/malloc-2.8.6.h
+ You don't really need this .h file unless you call functions not
+ defined in your system include files. The .h file contains only the
+ excerpts from this file needed for using this malloc on ANSI C/C++
+ systems, so long as you haven't changed compile-time options about
+ naming and tuning parameters. If you do, then you can create your
+ own malloc.h that does include all settings by cutting at the point
+ indicated below. Note that you may already by default be using a C
+ library containing a malloc that is based on some version of this
+ malloc (for example in linux). You might still want to use the one
+ in this file to customize settings or to avoid overheads associated
+ with library versions.
+
+* Vital statistics:
+
+ Supported pointer/size_t representation: 4 or 8 bytes
+ size_t MUST be an unsigned type of the same width as
+ pointers. (If you are using an ancient system that declares
+ size_t as a signed type, or need it to be a different width
+ than pointers, you can use a previous release of this malloc
+ (e.g. 2.7.2) supporting these.)
+
+ Alignment: 8 bytes (minimum)
+ This suffices for nearly all current machines and C compilers.
+ However, you can define MALLOC_ALIGNMENT to be wider than this
+ if necessary (up to 128bytes), at the expense of using more space.
+
+ Minimum overhead per allocated chunk: 4 or 8 bytes (if 4byte sizes)
+ 8 or 16 bytes (if 8byte sizes)
+ Each malloced chunk has a hidden word of overhead holding size
+ and status information, and additional cross-check word
+ if FOOTERS is defined.
+
+ Minimum allocated size: 4-byte ptrs: 16 bytes (including overhead)
+ 8-byte ptrs: 32 bytes (including overhead)
+
+ Even a request for zero bytes (i.e., malloc(0)) returns a
+ pointer to something of the minimum allocatable size.
+ The maximum overhead wastage (i.e., number of extra bytes
+ allocated than were requested in malloc) is less than or equal
+ to the minimum size, except for requests >= mmap_threshold that
+ are serviced via mmap(), where the worst case wastage is about
+ 32 bytes plus the remainder from a system page (the minimal
+ mmap unit); typically 4096 or 8192 bytes.
+
+ Security: static-safe; optionally more or less
+ The "security" of malloc refers to the ability of malicious
+ code to accentuate the effects of errors (for example, freeing
+ space that is not currently malloc'ed or overwriting past the
+ ends of chunks) in code that calls malloc. This malloc
+ guarantees not to modify any memory locations below the base of
+ heap, i.e., static variables, even in the presence of usage
+ errors. The routines additionally detect most improper frees
+ and reallocs. All this holds as long as the static bookkeeping
+ for malloc itself is not corrupted by some other means. This
+ is only one aspect of security -- these checks do not, and
+ cannot, detect all possible programming errors.
+
+ If FOOTERS is defined nonzero, then each allocated chunk
+ carries an additional check word to verify that it was malloced
+ from its space. These check words are the same within each
+ execution of a program using malloc, but differ across
+ executions, so externally crafted fake chunks cannot be
+ freed. This improves security by rejecting frees/reallocs that
+ could corrupt heap memory, in addition to the checks preventing
+ writes to statics that are always on. This may further improve
+ security at the expense of time and space overhead. (Note that
+ FOOTERS may also be worth using with MSPACES.)
+
+ By default detected errors cause the program to abort (calling
+ "abort()"). You can override this to instead proceed past
+ errors by defining PROCEED_ON_ERROR. In this case, a bad free
+ has no effect, and a malloc that encounters a bad address
+ caused by user overwrites will ignore the bad address by
+ dropping pointers and indices to all known memory. This may
+ be appropriate for programs that should continue if at all
+ possible in the face of programming errors, although they may
+ run out of memory because dropped memory is never reclaimed.
+
+ If you don't like either of these options, you can define
+ CORRUPTION_ERROR_ACTION and USAGE_ERROR_ACTION to do anything
+ else. And if if you are sure that your program using malloc has
+ no errors or vulnerabilities, you can define INSECURE to 1,
+ which might (or might not) provide a small performance improvement.
+
+ It is also possible to limit the maximum total allocatable
+ space, using malloc_set_footprint_limit. This is not
+ designed as a security feature in itself (calls to set limits
+ are not screened or privileged), but may be useful as one
+ aspect of a secure implementation.
+
+ Thread-safety: NOT thread-safe unless USE_LOCKS defined non-zero
+ When USE_LOCKS is defined, each public call to malloc, free,
+ etc is surrounded with a lock. By default, this uses a plain
+ pthread mutex, win32 critical section, or a spin-lock if if
+ available for the platform and not disabled by setting
+ USE_SPIN_LOCKS=0. However, if USE_RECURSIVE_LOCKS is defined,
+ recursive versions are used instead (which are not required for
+ base functionality but may be needed in layered extensions).
+ Using a global lock is not especially fast, and can be a major
+ bottleneck. It is designed only to provide minimal protection
+ in concurrent environments, and to provide a basis for
+ extensions. If you are using malloc in a concurrent program,
+ consider instead using nedmalloc
+ (http://www.nedprod.com/programs/portable/nedmalloc/) or
+ ptmalloc (See http://www.malloc.de), which are derived from
+ versions of this malloc.
+
+ System requirements: Any combination of MORECORE and/or MMAP/MUNMAP
+ This malloc can use unix sbrk or any emulation (invoked using
+ the CALL_MORECORE macro) and/or mmap/munmap or any emulation
+ (invoked using CALL_MMAP/CALL_MUNMAP) to get and release system
+ memory. On most unix systems, it tends to work best if both
+ MORECORE and MMAP are enabled. On Win32, it uses emulations
+ based on VirtualAlloc. It also uses common C library functions
+ like memset.
+
+ Compliance: I believe it is compliant with the Single Unix Specification
+ (See http://www.unix.org). Also SVID/XPG, ANSI C, and probably
+ others as well.
+
+* Overview of algorithms
+
+ This is not the fastest, most space-conserving, most portable, or
+ most tunable malloc ever written. However it is among the fastest
+ while also being among the most space-conserving, portable and
+ tunable. Consistent balance across these factors results in a good
+ general-purpose allocator for malloc-intensive programs.
+
+ In most ways, this malloc is a best-fit allocator. Generally, it
+ chooses the best-fitting existing chunk for a request, with ties
+ broken in approximately least-recently-used order. (This strategy
+ normally maintains low fragmentation.) However, for requests less
+ than 256bytes, it deviates from best-fit when there is not an
+ exactly fitting available chunk by preferring to use space adjacent
+ to that used for the previous small request, as well as by breaking
+ ties in approximately most-recently-used order. (These enhance
+ locality of series of small allocations.) And for very large requests
+ (>= 256Kb by default), it relies on system memory mapping
+ facilities, if supported. (This helps avoid carrying around and
+ possibly fragmenting memory used only for large chunks.)
+
+ All operations (except malloc_stats and mallinfo) have execution
+ times that are bounded by a constant factor of the number of bits in
+ a size_t, not counting any clearing in calloc or copying in realloc,
+ or actions surrounding MORECORE and MMAP that have times
+ proportional to the number of non-contiguous regions returned by
+ system allocation routines, which is often just 1. In real-time
+ applications, you can optionally suppress segment traversals using
+ NO_SEGMENT_TRAVERSAL, which assures bounded execution even when
+ system allocators return non-contiguous spaces, at the typical
+ expense of carrying around more memory and increased fragmentation.
+
+ The implementation is not very modular and seriously overuses
+ macros. Perhaps someday all C compilers will do as good a job
+ inlining modular code as can now be done by brute-force expansion,
+ but now, enough of them seem not to.
+
+ Some compilers issue a lot of warnings about code that is
+ dead/unreachable only on some platforms, and also about intentional
+ uses of negation on unsigned types. All known cases of each can be
+ ignored.
+
+ For a longer but out of date high-level description, see
+ http://gee.cs.oswego.edu/dl/html/malloc.html
+
+* MSPACES
+ If MSPACES is defined, then in addition to malloc, free, etc.,
+ this file also defines mspace_malloc, mspace_free, etc. These
+ are versions of malloc routines that take an "mspace" argument
+ obtained using create_mspace, to control all internal bookkeeping.
+ If ONLY_MSPACES is defined, only these versions are compiled.
+ So if you would like to use this allocator for only some allocations,
+ and your system malloc for others, you can compile with
+ ONLY_MSPACES and then do something like...
+ static mspace mymspace = create_mspace(0,0); // for example
+ #define mymalloc(bytes) mspace_malloc(mymspace, bytes)
+
+ (Note: If you only need one instance of an mspace, you can instead
+ use "USE_DL_PREFIX" to relabel the global malloc.)
+
+ You can similarly create thread-local allocators by storing
+ mspaces as thread-locals. For example:
+ static __thread mspace tlms = 0;
+ void* tlmalloc(size_t bytes) {
+ if (tlms == 0) tlms = create_mspace(0, 0);
+ return mspace_malloc(tlms, bytes);
+ }
+ void tlfree(void* mem) { mspace_free(tlms, mem); }
+
+ Unless FOOTERS is defined, each mspace is completely independent.
+ You cannot allocate from one and free to another (although
+ conformance is only weakly checked, so usage errors are not always
+ caught). If FOOTERS is defined, then each chunk carries around a tag
+ indicating its originating mspace, and frees are directed to their
+ originating spaces. Normally, this requires use of locks.
+
+ ------------------------- Compile-time options ---------------------------
+
+Be careful in setting #define values for numerical constants of type
+size_t. On some systems, literal values are not automatically extended
+to size_t precision unless they are explicitly casted. You can also
+use the symbolic values MAX_SIZE_T, SIZE_T_ONE, etc below.
+
+WIN32 default: defined if _WIN32 defined
+ Defining WIN32 sets up defaults for MS environment and compilers.
+ Otherwise defaults are for unix. Beware that there seem to be some
+ cases where this malloc might not be a pure drop-in replacement for
+ Win32 malloc: Random-looking failures from Win32 GDI API's (eg;
+ SetDIBits()) may be due to bugs in some video driver implementations
+ when pixel buffers are malloc()ed, and the region spans more than
+ one VirtualAlloc()ed region. Because dlmalloc uses a small (64Kb)
+ default granularity, pixel buffers may straddle virtual allocation
+ regions more often than when using the Microsoft allocator. You can
+ avoid this by using VirtualAlloc() and VirtualFree() for all pixel
+ buffers rather than using malloc(). If this is not possible,
+ recompile this malloc with a larger DEFAULT_GRANULARITY. Note:
+ in cases where MSC and gcc (cygwin) are known to differ on WIN32,
+ conditions use _MSC_VER to distinguish them.
+
+DLMALLOC_EXPORT default: extern
+ Defines how public APIs are declared. If you want to export via a
+ Windows DLL, you might define this as
+ #define DLMALLOC_EXPORT extern __declspec(dllexport)
+ If you want a POSIX ELF shared object, you might use
+ #define DLMALLOC_EXPORT extern __attribute__((visibility("default")))
+
+MALLOC_ALIGNMENT default: (size_t)(2 * sizeof(void *))
+ Controls the minimum alignment for malloc'ed chunks. It must be a
+ power of two and at least 8, even on machines for which smaller
+ alignments would suffice. It may be defined as larger than this
+ though. Note however that code and data structures are optimized for
+ the case of 8-byte alignment.
+
+MSPACES default: 0 (false)
+ If true, compile in support for independent allocation spaces.
+ This is only supported if HAVE_MMAP is true.
+
+ONLY_MSPACES default: 0 (false)
+ If true, only compile in mspace versions, not regular versions.
+
+USE_LOCKS default: 0 (false)
+ Causes each call to each public routine to be surrounded with
+ pthread or WIN32 mutex lock/unlock. (If set true, this can be
+ overridden on a per-mspace basis for mspace versions.) If set to a
+ non-zero value other than 1, locks are used, but their
+ implementation is left out, so lock functions must be supplied manually,
+ as described below.
+
+USE_SPIN_LOCKS default: 1 iff USE_LOCKS and spin locks available
+ If true, uses custom spin locks for locking. This is currently
+ supported only gcc >= 4.1, older gccs on x86 platforms, and recent
+ MS compilers. Otherwise, posix locks or win32 critical sections are
+ used.
+
+USE_RECURSIVE_LOCKS default: not defined
+ If defined nonzero, uses recursive (aka reentrant) locks, otherwise
+ uses plain mutexes. This is not required for malloc proper, but may
+ be needed for layered allocators such as nedmalloc.
+
+LOCK_AT_FORK default: not defined
+ If defined nonzero, performs pthread_atfork upon initialization
+ to initialize child lock while holding parent lock. The implementation
+ assumes that pthread locks (not custom locks) are being used. In other
+ cases, you may need to customize the implementation.
+
+FOOTERS default: 0
+ If true, provide extra checking and dispatching by placing
+ information in the footers of allocated chunks. This adds
+ space and time overhead.
+
+INSECURE default: 0
+ If true, omit checks for usage errors and heap space overwrites.
+
+USE_DL_PREFIX default: NOT defined
+ Causes compiler to prefix all public routines with the string 'dl'.
+ This can be useful when you only want to use this malloc in one part
+ of a program, using your regular system malloc elsewhere.
+
+MALLOC_INSPECT_ALL default: NOT defined
+ If defined, compiles malloc_inspect_all and mspace_inspect_all, that
+ perform traversal of all heap space. Unless access to these
+ functions is otherwise restricted, you probably do not want to
+ include them in secure implementations.
+
+ABORT default: defined as abort()
+ Defines how to abort on failed checks. On most systems, a failed
+ check cannot die with an "assert" or even print an informative
+ message, because the underlying print routines in turn call malloc,
+ which will fail again. Generally, the best policy is to simply call
+ abort(). It's not very useful to do more than this because many
+ errors due to overwriting will show up as address faults (null, odd
+ addresses etc) rather than malloc-triggered checks, so will also
+ abort. Also, most compilers know that abort() does not return, so
+ can better optimize code conditionally calling it.
+
+PROCEED_ON_ERROR default: defined as 0 (false)
+ Controls whether detected bad addresses cause them to bypassed
+ rather than aborting. If set, detected bad arguments to free and
+ realloc are ignored. And all bookkeeping information is zeroed out
+ upon a detected overwrite of freed heap space, thus losing the
+ ability to ever return it from malloc again, but enabling the
+ application to proceed. If PROCEED_ON_ERROR is defined, the
+ static variable malloc_corruption_error_count is compiled in
+ and can be examined to see if errors have occurred. This option
+ generates slower code than the default abort policy.
+
+DEBUG default: NOT defined
+ The DEBUG setting is mainly intended for people trying to modify
+ this code or diagnose problems when porting to new platforms.
+ However, it may also be able to better isolate user errors than just
+ using runtime checks. The assertions in the check routines spell
+ out in more detail the assumptions and invariants underlying the
+ algorithms. The checking is fairly extensive, and will slow down
+ execution noticeably. Calling malloc_stats or mallinfo with DEBUG
+ set will attempt to check every non-mmapped allocated and free chunk
+ in the course of computing the summaries.
+
+ABORT_ON_ASSERT_FAILURE default: defined as 1 (true)
+ Debugging assertion failures can be nearly impossible if your
+ version of the assert macro causes malloc to be called, which will
+ lead to a cascade of further failures, blowing the runtime stack.
+ ABORT_ON_ASSERT_FAILURE cause assertions failures to call abort(),
+ which will usually make debugging easier.
+
+MALLOC_FAILURE_ACTION default: sets errno to ENOMEM, or no-op on win32
+ The action to take before "return 0" when malloc fails to be able to
+ return memory because there is none available.
+
+HAVE_MORECORE default: 1 (true) unless win32 or ONLY_MSPACES
+ True if this system supports sbrk or an emulation of it.
+
+MORECORE default: sbrk
+ The name of the sbrk-style system routine to call to obtain more
+ memory. See below for guidance on writing custom MORECORE
+ functions. The type of the argument to sbrk/MORECORE varies across
+ systems. It cannot be size_t, because it supports negative
+ arguments, so it is normally the signed type of the same width as
+ size_t (sometimes declared as "intptr_t"). It doesn't much matter
+ though. Internally, we only call it with arguments less than half
+ the max value of a size_t, which should work across all reasonable
+ possibilities, although sometimes generating compiler warnings.
+
+MORECORE_CONTIGUOUS default: 1 (true) if HAVE_MORECORE
+ If true, take advantage of fact that consecutive calls to MORECORE
+ with positive arguments always return contiguous increasing
+ addresses. This is true of unix sbrk. It does not hurt too much to
+ set it true anyway, since malloc copes with non-contiguities.
+ Setting it false when definitely non-contiguous saves time
+ and possibly wasted space it would take to discover this though.
+
+MORECORE_CANNOT_TRIM default: NOT defined
+ True if MORECORE cannot release space back to the system when given
+ negative arguments. This is generally necessary only if you are
+ using a hand-crafted MORECORE function that cannot handle negative
+ arguments.
+
+NO_SEGMENT_TRAVERSAL default: 0
+ If non-zero, suppresses traversals of memory segments
+ returned by either MORECORE or CALL_MMAP. This disables
+ merging of segments that are contiguous, and selectively
+ releasing them to the OS if unused, but bounds execution times.
+
+HAVE_MMAP default: 1 (true)
+ True if this system supports mmap or an emulation of it. If so, and
+ HAVE_MORECORE is not true, MMAP is used for all system
+ allocation. If set and HAVE_MORECORE is true as well, MMAP is
+ primarily used to directly allocate very large blocks. It is also
+ used as a backup strategy in cases where MORECORE fails to provide
+ space from system. Note: A single call to MUNMAP is assumed to be
+ able to unmap memory that may have be allocated using multiple calls
+ to MMAP, so long as they are adjacent.
+
+HAVE_MREMAP default: 1 on linux, else 0
+ If true realloc() uses mremap() to re-allocate large blocks and
+ extend or shrink allocation spaces.
+
+MMAP_CLEARS default: 1 except on WINCE.
+ True if mmap clears memory so calloc doesn't need to. This is true
+ for standard unix mmap using /dev/zero and on WIN32 except for WINCE.
+
+USE_BUILTIN_FFS default: 0 (i.e., not used)
+ Causes malloc to use the builtin ffs() function to compute indices.
+ Some compilers may recognize and intrinsify ffs to be faster than the
+ supplied C version. Also, the case of x86 using gcc is special-cased
+ to an asm instruction, so is already as fast as it can be, and so
+ this setting has no effect. Similarly for Win32 under recent MS compilers.
+ (On most x86s, the asm version is only slightly faster than the C version.)
+
+malloc_getpagesize default: derive from system includes, or 4096.
+ The system page size. To the extent possible, this malloc manages
+ memory from the system in page-size units. This may be (and
+ usually is) a function rather than a constant. This is ignored
+ if WIN32, where page size is determined using getSystemInfo during
+ initialization.
+
+USE_DEV_RANDOM default: 0 (i.e., not used)
+ Causes malloc to use /dev/random to initialize secure magic seed for
+ stamping footers. Otherwise, the current time is used.
+
+NO_MALLINFO default: 0
+ If defined, don't compile "mallinfo". This can be a simple way
+ of dealing with mismatches between system declarations and
+ those in this file.
+
+MALLINFO_FIELD_TYPE default: size_t
+ The type of the fields in the mallinfo struct. This was originally
+ defined as "int" in SVID etc, but is more usefully defined as
+ size_t. The value is used only if HAVE_USR_INCLUDE_MALLOC_H is not set
+
+NO_MALLOC_STATS default: 0
+ If defined, don't compile "malloc_stats". This avoids calls to
+ fprintf and bringing in stdio dependencies you might not want.
+
+REALLOC_ZERO_BYTES_FREES default: not defined
+ This should be set if a call to realloc with zero bytes should
+ be the same as a call to free. Some people think it should. Otherwise,
+ since this malloc returns a unique pointer for malloc(0), so does
+ realloc(p, 0).
+
+LACKS_UNISTD_H, LACKS_FCNTL_H, LACKS_SYS_PARAM_H, LACKS_SYS_MMAN_H
+LACKS_STRINGS_H, LACKS_STRING_H, LACKS_SYS_TYPES_H, LACKS_ERRNO_H
+LACKS_STDLIB_H LACKS_SCHED_H LACKS_TIME_H default: NOT defined unless on WIN32
+ Define these if your system does not have these header files.
+ You might need to manually insert some of the declarations they provide.
+
+DEFAULT_GRANULARITY default: page size if MORECORE_CONTIGUOUS,
+ system_info.dwAllocationGranularity in WIN32,
+ otherwise 64K.
+ Also settable using mallopt(M_GRANULARITY, x)
+ The unit for allocating and deallocating memory from the system. On
+ most systems with contiguous MORECORE, there is no reason to
+ make this more than a page. However, systems with MMAP tend to
+ either require or encourage larger granularities. You can increase
+ this value to prevent system allocation functions to be called so
+ often, especially if they are slow. The value must be at least one
+ page and must be a power of two. Setting to 0 causes initialization
+ to either page size or win32 region size. (Note: In previous
+ versions of malloc, the equivalent of this option was called
+ "TOP_PAD")
+
+DEFAULT_TRIM_THRESHOLD default: 2MB
+ Also settable using mallopt(M_TRIM_THRESHOLD, x)
+ The maximum amount of unused top-most memory to keep before
+ releasing via malloc_trim in free(). Automatic trimming is mainly
+ useful in long-lived programs using contiguous MORECORE. Because
+ trimming via sbrk can be slow on some systems, and can sometimes be
+ wasteful (in cases where programs immediately afterward allocate
+ more large chunks) the value should be high enough so that your
+ overall system performance would improve by releasing this much
+ memory. As a rough guide, you might set to a value close to the
+ average size of a process (program) running on your system.
+ Releasing this much memory would allow such a process to run in
+ memory. Generally, it is worth tuning trim thresholds when a
+ program undergoes phases where several large chunks are allocated
+ and released in ways that can reuse each other's storage, perhaps
+ mixed with phases where there are no such chunks at all. The trim
+ value must be greater than page size to have any useful effect. To
+ disable trimming completely, you can set to MAX_SIZE_T. Note that the trick
+ some people use of mallocing a huge space and then freeing it at
+ program startup, in an attempt to reserve system memory, doesn't
+ have the intended effect under automatic trimming, since that memory
+ will immediately be returned to the system.
+
+DEFAULT_MMAP_THRESHOLD default: 256K
+ Also settable using mallopt(M_MMAP_THRESHOLD, x)
+ The request size threshold for using MMAP to directly service a
+ request. Requests of at least this size that cannot be allocated
+ using already-existing space will be serviced via mmap. (If enough
+ normal freed space already exists it is used instead.) Using mmap
+ segregates relatively large chunks of memory so that they can be
+ individually obtained and released from the host system. A request
+ serviced through mmap is never reused by any other request (at least
+ not directly; the system may just so happen to remap successive
+ requests to the same locations). Segregating space in this way has
+ the benefits that: Mmapped space can always be individually released
+ back to the system, which helps keep the system level memory demands
+ of a long-lived program low. Also, mapped memory doesn't become
+ `locked' between other chunks, as can happen with normally allocated
+ chunks, which means that even trimming via malloc_trim would not
+ release them. However, it has the disadvantage that the space
+ cannot be reclaimed, consolidated, and then used to service later
+ requests, as happens with normal chunks. The advantages of mmap
+ nearly always outweigh disadvantages for "large" chunks, but the
+ value of "large" may vary across systems. The default is an
+ empirically derived value that works well in most systems. You can
+ disable mmap by setting to MAX_SIZE_T.
+
+MAX_RELEASE_CHECK_RATE default: 4095 unless not HAVE_MMAP
+ The number of consolidated frees between checks to release
+ unused segments when freeing. When using non-contiguous segments,
+ especially with multiple mspaces, checking only for topmost space
+ doesn't always suffice to trigger trimming. To compensate for this,
+ free() will, with a period of MAX_RELEASE_CHECK_RATE (or the
+ current number of segments, if greater) try to release unused
+ segments to the OS when freeing chunks that result in
+ consolidation. The best value for this parameter is a compromise
+ between slowing down frees with relatively costly checks that
+ rarely trigger versus holding on to unused memory. To effectively
+ disable, set to MAX_SIZE_T. This may lead to a very slight speed
+ improvement at the expense of carrying around more memory.
+*/
+
+
+/* Version identifier to allow people to support multiple versions */
+#ifndef DLMALLOC_VERSION
+#define DLMALLOC_VERSION 20806
+#endif /* DLMALLOC_VERSION */
+
+#ifndef DLMALLOC_EXPORT
+#define DLMALLOC_EXPORT extern
+#endif
+
+#ifndef WIN32
+#ifdef _WIN32
+#define WIN32 1
+#endif /* _WIN32 */
+#ifdef _WIN32_WCE
+#define LACKS_FCNTL_H
+#define WIN32 1
+#endif /* _WIN32_WCE */
+#endif /* WIN32 */
+#ifdef WIN32
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#include <tchar.h>
+#define HAVE_MMAP 1
+#define HAVE_MORECORE 0
+#define LACKS_UNISTD_H
+#define LACKS_SYS_PARAM_H
+#define LACKS_SYS_MMAN_H
+#define LACKS_STRING_H
+#define LACKS_STRINGS_H
+#define LACKS_SYS_TYPES_H
+#define LACKS_ERRNO_H
+#define LACKS_SCHED_H
+#ifndef MALLOC_FAILURE_ACTION
+#define MALLOC_FAILURE_ACTION
+#endif /* MALLOC_FAILURE_ACTION */
+#ifndef MMAP_CLEARS
+#ifdef _WIN32_WCE /* WINCE reportedly does not clear */
+#define MMAP_CLEARS 0
+#else
+#define MMAP_CLEARS 1
+#endif /* _WIN32_WCE */
+#endif /*MMAP_CLEARS */
+#endif /* WIN32 */
+
+#if defined(DARWIN) || defined(_DARWIN)
+/* Mac OSX docs advise not to use sbrk; it seems better to use mmap */
+#ifndef HAVE_MORECORE
+#define HAVE_MORECORE 0
+#define HAVE_MMAP 1
+/* OSX allocators provide 16 byte alignment */
+#ifndef MALLOC_ALIGNMENT
+#define MALLOC_ALIGNMENT ((size_t)16U)
+#endif
+#endif /* HAVE_MORECORE */
+#endif /* DARWIN */
+
+#ifndef LACKS_SYS_TYPES_H
+#include <sys/types.h> /* For size_t */
+#endif /* LACKS_SYS_TYPES_H */
+
+/* The maximum possible size_t value has all bits set */
+#define MAX_SIZE_T (~(size_t)0)
+
+#if (defined(USE_RECURSIVE_LOCKS) && USE_RECURSIVE_LOCKS != 0)
+#define RECURSIVE_LOCKS_ENABLED 1
+#else
+#define RECURSIVE_LOCKS_ENABLED 0
+#endif
+
+#if (defined(USE_RECURSIVE_LOCKS) && USE_RECURSIVE_LOCKS != 0)
+#define SPIN_LOCKS_ENABLED 1
+#else
+#define SPIN_LOCKS_ENABLED 0
+#endif
+
+#ifndef USE_LOCKS /* ensure true if spin or recursive locks set */
+#define USE_LOCKS ((SPIN_LOCKS_ENABLED != 0) || \
+ (RECURSIVE_LOCKS_ENABLED != 0))
+#endif /* USE_LOCKS */
+
+#if USE_LOCKS /* Spin locks for gcc >= 4.1, older gcc on x86, MSC >= 1310 */
+#if ((defined(__GNUC__) && \
+ ((__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 1)) || \
+ defined(__i386__) || defined(__x86_64__))) || \
+ (defined(_MSC_VER) && _MSC_VER>=1310))
+#ifndef USE_SPIN_LOCKS
+#define USE_SPIN_LOCKS 1
+#endif /* USE_SPIN_LOCKS */
+#elif USE_SPIN_LOCKS
+#error "USE_SPIN_LOCKS defined without implementation"
+#endif /* ... locks available... */
+#elif !defined(USE_SPIN_LOCKS)
+#define USE_SPIN_LOCKS 0
+#endif /* USE_LOCKS */
+
+#ifndef ONLY_MSPACES
+#define ONLY_MSPACES 0
+#endif /* ONLY_MSPACES */
+#ifndef MSPACES
+#if ONLY_MSPACES
+#define MSPACES 1
+#else /* ONLY_MSPACES */
+#define MSPACES 0
+#endif /* ONLY_MSPACES */
+#endif /* MSPACES */
+#ifndef MALLOC_ALIGNMENT
+#define MALLOC_ALIGNMENT ((size_t)(2 * sizeof(void *)))
+#endif /* MALLOC_ALIGNMENT */
+#ifndef FOOTERS
+#define FOOTERS 0
+#endif /* FOOTERS */
+#ifndef ABORT
+#define ABORT abort()
+#endif /* ABORT */
+#ifndef ABORT_ON_ASSERT_FAILURE
+#define ABORT_ON_ASSERT_FAILURE 1
+#endif /* ABORT_ON_ASSERT_FAILURE */
+#ifndef PROCEED_ON_ERROR
+#define PROCEED_ON_ERROR 0
+#endif /* PROCEED_ON_ERROR */
+
+#ifndef INSECURE
+#define INSECURE 0
+#endif /* INSECURE */
+#ifndef MALLOC_INSPECT_ALL
+#define MALLOC_INSPECT_ALL 0
+#endif /* MALLOC_INSPECT_ALL */
+#ifndef HAVE_MMAP
+#define HAVE_MMAP 1
+#endif /* HAVE_MMAP */
+#ifndef MMAP_CLEARS
+#define MMAP_CLEARS 1
+#endif /* MMAP_CLEARS */
+#ifndef HAVE_MREMAP
+#ifdef linux
+#define HAVE_MREMAP 1
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE /* Turns on mremap() definition */
+#endif /* _GNU_SOURCE */
+#else /* linux */
+#define HAVE_MREMAP 0
+#endif /* linux */
+#endif /* HAVE_MREMAP */
+#ifndef MALLOC_FAILURE_ACTION
+#define MALLOC_FAILURE_ACTION errno = ENOMEM;
+#endif /* MALLOC_FAILURE_ACTION */
+#ifndef HAVE_MORECORE
+#if ONLY_MSPACES
+#define HAVE_MORECORE 0
+#else /* ONLY_MSPACES */
+#define HAVE_MORECORE 1
+#endif /* ONLY_MSPACES */
+#endif /* HAVE_MORECORE */
+#if !HAVE_MORECORE
+#define MORECORE_CONTIGUOUS 0
+#else /* !HAVE_MORECORE */
+#define MORECORE_DEFAULT sbrk
+#ifndef MORECORE_CONTIGUOUS
+#define MORECORE_CONTIGUOUS 1
+#endif /* MORECORE_CONTIGUOUS */
+#endif /* HAVE_MORECORE */
+#ifndef DEFAULT_GRANULARITY
+#if (MORECORE_CONTIGUOUS || defined(WIN32))
+#define DEFAULT_GRANULARITY (0) /* 0 means to compute in init_mparams */
+#else /* MORECORE_CONTIGUOUS */
+#define DEFAULT_GRANULARITY ((size_t)64U * (size_t)1024U)
+#endif /* MORECORE_CONTIGUOUS */
+#endif /* DEFAULT_GRANULARITY */
+#ifndef DEFAULT_TRIM_THRESHOLD
+#ifndef MORECORE_CANNOT_TRIM
+#define DEFAULT_TRIM_THRESHOLD ((size_t)2U * (size_t)1024U * (size_t)1024U)
+#else /* MORECORE_CANNOT_TRIM */
+#define DEFAULT_TRIM_THRESHOLD MAX_SIZE_T
+#endif /* MORECORE_CANNOT_TRIM */
+#endif /* DEFAULT_TRIM_THRESHOLD */
+#ifndef DEFAULT_MMAP_THRESHOLD
+#if HAVE_MMAP
+#define DEFAULT_MMAP_THRESHOLD ((size_t)256U * (size_t)1024U)
+#else /* HAVE_MMAP */
+#define DEFAULT_MMAP_THRESHOLD MAX_SIZE_T
+#endif /* HAVE_MMAP */
+#endif /* DEFAULT_MMAP_THRESHOLD */
+#ifndef MAX_RELEASE_CHECK_RATE
+#if HAVE_MMAP
+#define MAX_RELEASE_CHECK_RATE 4095
+#else
+#define MAX_RELEASE_CHECK_RATE MAX_SIZE_T
+#endif /* HAVE_MMAP */
+#endif /* MAX_RELEASE_CHECK_RATE */
+#ifndef USE_BUILTIN_FFS
+#define USE_BUILTIN_FFS 0
+#endif /* USE_BUILTIN_FFS */
+#ifndef USE_DEV_RANDOM
+#define USE_DEV_RANDOM 0
+#endif /* USE_DEV_RANDOM */
+#ifndef NO_MALLINFO
+#define NO_MALLINFO 0
+#endif /* NO_MALLINFO */
+#ifndef MALLINFO_FIELD_TYPE
+#define MALLINFO_FIELD_TYPE size_t
+#endif /* MALLINFO_FIELD_TYPE */
+#ifndef NO_MALLOC_STATS
+#define NO_MALLOC_STATS 0
+#endif /* NO_MALLOC_STATS */
+#ifndef NO_SEGMENT_TRAVERSAL
+#define NO_SEGMENT_TRAVERSAL 0
+#endif /* NO_SEGMENT_TRAVERSAL */
+
+/*
+ mallopt tuning options. SVID/XPG defines four standard parameter
+ numbers for mallopt, normally defined in malloc.h. None of these
+ are used in this malloc, so setting them has no effect. But this
+ malloc does support the following options.
+*/
+
+#define M_TRIM_THRESHOLD (-1)
+#define M_GRANULARITY (-2)
+#define M_MMAP_THRESHOLD (-3)
+
+/* ------------------------ Mallinfo declarations ------------------------ */
+
+#if !NO_MALLINFO
+/*
+ This version of malloc supports the standard SVID/XPG mallinfo
+ routine that returns a struct containing usage properties and
+ statistics. It should work on any system that has a
+ /usr/include/malloc.h defining struct mallinfo. The main
+ declaration needed is the mallinfo struct that is returned (by-copy)
+ by mallinfo(). The malloinfo struct contains a bunch of fields that
+ are not even meaningful in this version of malloc. These fields are
+ are instead filled by mallinfo() with other numbers that might be of
+ interest.
+
+ HAVE_USR_INCLUDE_MALLOC_H should be set if you have a
+ /usr/include/malloc.h file that includes a declaration of struct
+ mallinfo. If so, it is included; else a compliant version is
+ declared below. These must be precisely the same for mallinfo() to
+ work. The original SVID version of this struct, defined on most
+ systems with mallinfo, declares all fields as ints. But some others
+ define as unsigned long. If your system defines the fields using a
+ type of different width than listed here, you MUST #include your
+ system version and #define HAVE_USR_INCLUDE_MALLOC_H.
+*/
+
+/* #define HAVE_USR_INCLUDE_MALLOC_H */
+
+#ifdef HAVE_USR_INCLUDE_MALLOC_H
+#include "/usr/include/malloc.h"
+#else /* HAVE_USR_INCLUDE_MALLOC_H */
+#ifndef STRUCT_MALLINFO_DECLARED
+/* HP-UX (and others?) redefines mallinfo unless _STRUCT_MALLINFO is defined */
+#define _STRUCT_MALLINFO
+#define STRUCT_MALLINFO_DECLARED 1
+struct mallinfo {
+ MALLINFO_FIELD_TYPE arena; /* non-mmapped space allocated from system */
+ MALLINFO_FIELD_TYPE ordblks; /* number of free chunks */
+ MALLINFO_FIELD_TYPE smblks; /* always 0 */
+ MALLINFO_FIELD_TYPE hblks; /* always 0 */
+ MALLINFO_FIELD_TYPE hblkhd; /* space in mmapped regions */
+ MALLINFO_FIELD_TYPE usmblks; /* maximum total allocated space */
+ MALLINFO_FIELD_TYPE fsmblks; /* always 0 */
+ MALLINFO_FIELD_TYPE uordblks; /* total allocated space */
+ MALLINFO_FIELD_TYPE fordblks; /* total free space */
+ MALLINFO_FIELD_TYPE keepcost; /* releasable (via malloc_trim) space */
+};
+#endif /* STRUCT_MALLINFO_DECLARED */
+#endif /* HAVE_USR_INCLUDE_MALLOC_H */
+#endif /* NO_MALLINFO */
+
+/*
+ Try to persuade compilers to inline. The most critical functions for
+ inlining are defined as macros, so these aren't used for them.
+*/
+
+#ifndef FORCEINLINE
+ #if defined(__GNUC__)
+#define FORCEINLINE __inline __attribute__ ((always_inline))
+ #elif defined(_MSC_VER)
+ #define FORCEINLINE __forceinline
+ #endif
+#endif
+#ifndef NOINLINE
+ #if defined(__GNUC__)
+ #define NOINLINE __attribute__ ((noinline))
+ #elif defined(_MSC_VER)
+ #define NOINLINE __declspec(noinline)
+ #else
+ #define NOINLINE
+ #endif
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#ifndef FORCEINLINE
+ #define FORCEINLINE inline
+#endif
+#endif /* __cplusplus */
+#ifndef FORCEINLINE
+ #define FORCEINLINE
+#endif
+
+#if !ONLY_MSPACES
+
+/* ------------------- Declarations of public routines ------------------- */
+
+#ifndef USE_DL_PREFIX
+#define dlcalloc calloc
+#define dlfree free
+#define dlmalloc malloc
+#define dlmemalign memalign
+#define dlposix_memalign posix_memalign
+#define dlrealloc realloc
+#define dlrealloc_in_place realloc_in_place
+#define dlvalloc valloc
+#define dlpvalloc pvalloc
+#define dlmallinfo mallinfo
+#define dlmallopt mallopt
+#define dlmalloc_trim malloc_trim
+#define dlmalloc_stats malloc_stats
+#define dlmalloc_usable_size malloc_usable_size
+#define dlmalloc_footprint malloc_footprint
+#define dlmalloc_max_footprint malloc_max_footprint
+#define dlmalloc_footprint_limit malloc_footprint_limit
+#define dlmalloc_set_footprint_limit malloc_set_footprint_limit
+#define dlmalloc_inspect_all malloc_inspect_all
+#define dlindependent_calloc independent_calloc
+#define dlindependent_comalloc independent_comalloc
+#define dlbulk_free bulk_free
+#endif /* USE_DL_PREFIX */
+
+/*
+ malloc(size_t n)
+ Returns a pointer to a newly allocated chunk of at least n bytes, or
+ null if no space is available, in which case errno is set to ENOMEM
+ on ANSI C systems.
+
+ If n is zero, malloc returns a minimum-sized chunk. (The minimum
+ size is 16 bytes on most 32bit systems, and 32 bytes on 64bit
+ systems.) Note that size_t is an unsigned type, so calls with
+ arguments that would be negative if signed are interpreted as
+ requests for huge amounts of space, which will often fail. The
+ maximum supported value of n differs across systems, but is in all
+ cases less than the maximum representable value of a size_t.
+*/
+DLMALLOC_EXPORT void* dlmalloc(size_t);
+
+/*
+ free(void* p)
+ Releases the chunk of memory pointed to by p, that had been previously
+ allocated using malloc or a related routine such as realloc.
+ It has no effect if p is null. If p was not malloced or already
+ freed, free(p) will by default cause the current program to abort.
+*/
+DLMALLOC_EXPORT void dlfree(void*);
+
+/*
+ calloc(size_t n_elements, size_t element_size);
+ Returns a pointer to n_elements * element_size bytes, with all locations
+ set to zero.
+*/
+DLMALLOC_EXPORT void* dlcalloc(size_t, size_t);
+
+/*
+ realloc(void* p, size_t n)
+ Returns a pointer to a chunk of size n that contains the same data
+ as does chunk p up to the minimum of (n, p's size) bytes, or null
+ if no space is available.
+
+ The returned pointer may or may not be the same as p. The algorithm
+ prefers extending p in most cases when possible, otherwise it
+ employs the equivalent of a malloc-copy-free sequence.
+
+ If p is null, realloc is equivalent to malloc.
+
+ If space is not available, realloc returns null, errno is set (if on
+ ANSI) and p is NOT freed.
+
+ if n is for fewer bytes than already held by p, the newly unused
+ space is lopped off and freed if possible. realloc with a size
+ argument of zero (re)allocates a minimum-sized chunk.
+
+ The old unix realloc convention of allowing the last-free'd chunk
+ to be used as an argument to realloc is not supported.
+*/
+DLMALLOC_EXPORT void* dlrealloc(void*, size_t);
+
+/*
+ realloc_in_place(void* p, size_t n)
+ Resizes the space allocated for p to size n, only if this can be
+ done without moving p (i.e., only if there is adjacent space
+ available if n is greater than p's current allocated size, or n is
+ less than or equal to p's size). This may be used instead of plain
+ realloc if an alternative allocation strategy is needed upon failure
+ to expand space; for example, reallocation of a buffer that must be
+ memory-aligned or cleared. You can use realloc_in_place to trigger
+ these alternatives only when needed.
+
+ Returns p if successful; otherwise null.
+*/
+DLMALLOC_EXPORT void* dlrealloc_in_place(void*, size_t);
+
+/*
+ memalign(size_t alignment, size_t n);
+ Returns a pointer to a newly allocated chunk of n bytes, aligned
+ in accord with the alignment argument.
+
+ The alignment argument should be a power of two. If the argument is
+ not a power of two, the nearest greater power is used.
+ 8-byte alignment is guaranteed by normal malloc calls, so don't
+ bother calling memalign with an argument of 8 or less.
+
+ Overreliance on memalign is a sure way to fragment space.
+*/
+DLMALLOC_EXPORT void* dlmemalign(size_t, size_t);
+
+/*
+ int posix_memalign(void** pp, size_t alignment, size_t n);
+ Allocates a chunk of n bytes, aligned in accord with the alignment
+ argument. Differs from memalign only in that it (1) assigns the
+ allocated memory to *pp rather than returning it, (2) fails and
+ returns EINVAL if the alignment is not a power of two (3) fails and
+ returns ENOMEM if memory cannot be allocated.
+*/
+DLMALLOC_EXPORT int dlposix_memalign(void**, size_t, size_t);
+
+/*
+ valloc(size_t n);
+ Equivalent to memalign(pagesize, n), where pagesize is the page
+ size of the system. If the pagesize is unknown, 4096 is used.
+*/
+DLMALLOC_EXPORT void* dlvalloc(size_t);
+
+/*
+ mallopt(int parameter_number, int parameter_value)
+ Sets tunable parameters The format is to provide a
+ (parameter-number, parameter-value) pair. mallopt then sets the
+ corresponding parameter to the argument value if it can (i.e., so
+ long as the value is meaningful), and returns 1 if successful else
+ 0. To workaround the fact that mallopt is specified to use int,
+ not size_t parameters, the value -1 is specially treated as the
+ maximum unsigned size_t value.
+
+ SVID/XPG/ANSI defines four standard param numbers for mallopt,
+ normally defined in malloc.h. None of these are use in this malloc,
+ so setting them has no effect. But this malloc also supports other
+ options in mallopt. See below for details. Briefly, supported
+ parameters are as follows (listed defaults are for "typical"
+ configurations).
+
+ Symbol param # default allowed param values
+ M_TRIM_THRESHOLD -1 2*1024*1024 any (-1 disables)
+ M_GRANULARITY -2 page size any power of 2 >= page size
+ M_MMAP_THRESHOLD -3 256*1024 any (or 0 if no MMAP support)
+*/
+DLMALLOC_EXPORT int dlmallopt(int, int);
+
+/*
+ malloc_footprint();
+ Returns the number of bytes obtained from the system. The total
+ number of bytes allocated by malloc, realloc etc., is less than this
+ value. Unlike mallinfo, this function returns only a precomputed
+ result, so can be called frequently to monitor memory consumption.
+ Even if locks are otherwise defined, this function does not use them,
+ so results might not be up to date.
+*/
+DLMALLOC_EXPORT size_t dlmalloc_footprint(void);
+
+/*
+ malloc_max_footprint();
+ Returns the maximum number of bytes obtained from the system. This
+ value will be greater than current footprint if deallocated space
+ has been reclaimed by the system. The peak number of bytes allocated
+ by malloc, realloc etc., is less than this value. Unlike mallinfo,
+ this function returns only a precomputed result, so can be called
+ frequently to monitor memory consumption. Even if locks are
+ otherwise defined, this function does not use them, so results might
+ not be up to date.
+*/
+DLMALLOC_EXPORT size_t dlmalloc_max_footprint(void);
+
+/*
+ malloc_footprint_limit();
+ Returns the number of bytes that the heap is allowed to obtain from
+ the system, returning the last value returned by
+ malloc_set_footprint_limit, or the maximum size_t value if
+ never set. The returned value reflects a permission. There is no
+ guarantee that this number of bytes can actually be obtained from
+ the system.
+*/
+DLMALLOC_EXPORT size_t dlmalloc_footprint_limit();
+
+/*
+ malloc_set_footprint_limit();
+ Sets the maximum number of bytes to obtain from the system, causing
+ failure returns from malloc and related functions upon attempts to
+ exceed this value. The argument value may be subject to page
+ rounding to an enforceable limit; this actual value is returned.
+ Using an argument of the maximum possible size_t effectively
+ disables checks. If the argument is less than or equal to the
+ current malloc_footprint, then all future allocations that require
+ additional system memory will fail. However, invocation cannot
+ retroactively deallocate existing used memory.
+*/
+DLMALLOC_EXPORT size_t dlmalloc_set_footprint_limit(size_t bytes);
+
+#if MALLOC_INSPECT_ALL
+/*
+ malloc_inspect_all(void(*handler)(void *start,
+ void *end,
+ size_t used_bytes,
+ void* callback_arg),
+ void* arg);
+ Traverses the heap and calls the given handler for each managed
+ region, skipping all bytes that are (or may be) used for bookkeeping
+ purposes. Traversal does not include include chunks that have been
+ directly memory mapped. Each reported region begins at the start
+ address, and continues up to but not including the end address. The
+ first used_bytes of the region contain allocated data. If
+ used_bytes is zero, the region is unallocated. The handler is
+ invoked with the given callback argument. If locks are defined, they
+ are held during the entire traversal. It is a bad idea to invoke
+ other malloc functions from within the handler.
+
+ For example, to count the number of in-use chunks with size greater
+ than 1000, you could write:
+ static int count = 0;
+ void count_chunks(void* start, void* end, size_t used, void* arg) {
+ if (used >= 1000) ++count;
+ }
+ then:
+ malloc_inspect_all(count_chunks, NULL);
+
+ malloc_inspect_all is compiled only if MALLOC_INSPECT_ALL is defined.
+*/
+DLMALLOC_EXPORT void dlmalloc_inspect_all(void(*handler)(void*, void *, size_t, void*),
+ void* arg);
+
+#endif /* MALLOC_INSPECT_ALL */
+
+#if !NO_MALLINFO
+/*
+ mallinfo()
+ Returns (by copy) a struct containing various summary statistics:
+
+ arena: current total non-mmapped bytes allocated from system
+ ordblks: the number of free chunks
+ smblks: always zero.
+ hblks: current number of mmapped regions
+ hblkhd: total bytes held in mmapped regions
+ usmblks: the maximum total allocated space. This will be greater
+ than current total if trimming has occurred.
+ fsmblks: always zero
+ uordblks: current total allocated space (normal or mmapped)
+ fordblks: total free space
+ keepcost: the maximum number of bytes that could ideally be released
+ back to system via malloc_trim. ("ideally" means that
+ it ignores page restrictions etc.)
+
+ Because these fields are ints, but internal bookkeeping may
+ be kept as longs, the reported values may wrap around zero and
+ thus be inaccurate.
+*/
+DLMALLOC_EXPORT struct mallinfo dlmallinfo(void);
+#endif /* NO_MALLINFO */
+
+/*
+ independent_calloc(size_t n_elements, size_t element_size, void* chunks[]);
+
+ independent_calloc is similar to calloc, but instead of returning a
+ single cleared space, it returns an array of pointers to n_elements
+ independent elements that can hold contents of size elem_size, each
+ of which starts out cleared, and can be independently freed,
+ realloc'ed etc. The elements are guaranteed to be adjacently
+ allocated (this is not guaranteed to occur with multiple callocs or
+ mallocs), which may also improve cache locality in some
+ applications.
+
+ The "chunks" argument is optional (i.e., may be null, which is
+ probably the most typical usage). If it is null, the returned array
+ is itself dynamically allocated and should also be freed when it is
+ no longer needed. Otherwise, the chunks array must be of at least
+ n_elements in length. It is filled in with the pointers to the
+ chunks.
+
+ In either case, independent_calloc returns this pointer array, or
+ null if the allocation failed. If n_elements is zero and "chunks"
+ is null, it returns a chunk representing an array with zero elements
+ (which should be freed if not wanted).
+
+ Each element must be freed when it is no longer needed. This can be
+ done all at once using bulk_free.
+
+ independent_calloc simplifies and speeds up implementations of many
+ kinds of pools. It may also be useful when constructing large data
+ structures that initially have a fixed number of fixed-sized nodes,
+ but the number is not known at compile time, and some of the nodes
+ may later need to be freed. For example:
+
+ struct Node { int item; struct Node* next; };
+
+ struct Node* build_list() {
+ struct Node** pool;
+ int n = read_number_of_nodes_needed();
+ if (n <= 0) return 0;
+ pool = (struct Node**)(independent_calloc(n, sizeof(struct Node), 0);
+ if (pool == 0) die();
+ // organize into a linked list...
+ struct Node* first = pool[0];
+ for (i = 0; i < n-1; ++i)
+ pool[i]->next = pool[i+1];
+ free(pool); // Can now free the array (or not, if it is needed later)
+ return first;
+ }
+*/
+DLMALLOC_EXPORT void** dlindependent_calloc(size_t, size_t, void**);
+
+/*
+ independent_comalloc(size_t n_elements, size_t sizes[], void* chunks[]);
+
+ independent_comalloc allocates, all at once, a set of n_elements
+ chunks with sizes indicated in the "sizes" array. It returns
+ an array of pointers to these elements, each of which can be
+ independently freed, realloc'ed etc. The elements are guaranteed to
+ be adjacently allocated (this is not guaranteed to occur with
+ multiple callocs or mallocs), which may also improve cache locality
+ in some applications.
+
+ The "chunks" argument is optional (i.e., may be null). If it is null
+ the returned array is itself dynamically allocated and should also
+ be freed when it is no longer needed. Otherwise, the chunks array
+ must be of at least n_elements in length. It is filled in with the
+ pointers to the chunks.
+
+ In either case, independent_comalloc returns this pointer array, or
+ null if the allocation failed. If n_elements is zero and chunks is
+ null, it returns a chunk representing an array with zero elements
+ (which should be freed if not wanted).
+
+ Each element must be freed when it is no longer needed. This can be
+ done all at once using bulk_free.
+
+ independent_comalloc differs from independent_calloc in that each
+ element may have a different size, and also that it does not
+ automatically clear elements.
+
+ independent_comalloc can be used to speed up allocation in cases
+ where several structs or objects must always be allocated at the
+ same time. For example:
+
+ struct Head { ... }
+ struct Foot { ... }
+
+ void send_message(char* msg) {
+ int msglen = strlen(msg);
+ size_t sizes[3] = { sizeof(struct Head), msglen, sizeof(struct Foot) };
+ void* chunks[3];
+ if (independent_comalloc(3, sizes, chunks) == 0)
+ die();
+ struct Head* head = (struct Head*)(chunks[0]);
+ char* body = (char*)(chunks[1]);
+ struct Foot* foot = (struct Foot*)(chunks[2]);
+ // ...
+ }
+
+ In general though, independent_comalloc is worth using only for
+ larger values of n_elements. For small values, you probably won't
+ detect enough difference from series of malloc calls to bother.
+
+ Overuse of independent_comalloc can increase overall memory usage,
+ since it cannot reuse existing noncontiguous small chunks that
+ might be available for some of the elements.
+*/
+DLMALLOC_EXPORT void** dlindependent_comalloc(size_t, size_t*, void**);
+
+/*
+ bulk_free(void* array[], size_t n_elements)
+ Frees and clears (sets to null) each non-null pointer in the given
+ array. This is likely to be faster than freeing them one-by-one.
+ If footers are used, pointers that have been allocated in different
+ mspaces are not freed or cleared, and the count of all such pointers
+ is returned. For large arrays of pointers with poor locality, it
+ may be worthwhile to sort this array before calling bulk_free.
+*/
+DLMALLOC_EXPORT size_t dlbulk_free(void**, size_t n_elements);
+
+/*
+ pvalloc(size_t n);
+ Equivalent to valloc(minimum-page-that-holds(n)), that is,
+ round up n to nearest pagesize.
+ */
+DLMALLOC_EXPORT void* dlpvalloc(size_t);
+
+/*
+ malloc_trim(size_t pad);
+
+ If possible, gives memory back to the system (via negative arguments
+ to sbrk) if there is unused memory at the `high' end of the malloc
+ pool or in unused MMAP segments. You can call this after freeing
+ large blocks of memory to potentially reduce the system-level memory
+ requirements of a program. However, it cannot guarantee to reduce
+ memory. Under some allocation patterns, some large free blocks of
+ memory will be locked between two used chunks, so they cannot be
+ given back to the system.
+
+ The `pad' argument to malloc_trim represents the amount of free
+ trailing space to leave untrimmed. If this argument is zero, only
+ the minimum amount of memory to maintain internal data structures
+ will be left. Non-zero arguments can be supplied to maintain enough
+ trailing space to service future expected allocations without having
+ to re-obtain memory from the system.
+
+ Malloc_trim returns 1 if it actually released any memory, else 0.
+*/
+DLMALLOC_EXPORT int dlmalloc_trim(size_t);
+
+/*
+ malloc_stats();
+ Prints on stderr the amount of space obtained from the system (both
+ via sbrk and mmap), the maximum amount (which may be more than
+ current if malloc_trim and/or munmap got called), and the current
+ number of bytes allocated via malloc (or realloc, etc) but not yet
+ freed. Note that this is the number of bytes allocated, not the
+ number requested. It will be larger than the number requested
+ because of alignment and bookkeeping overhead. Because it includes
+ alignment wastage as being in use, this figure may be greater than
+ zero even when no user-level chunks are allocated.
+
+ The reported current and maximum system memory can be inaccurate if
+ a program makes other calls to system memory allocation functions
+ (normally sbrk) outside of malloc.
+
+ malloc_stats prints only the most commonly interesting statistics.
+ More information can be obtained by calling mallinfo.
+*/
+DLMALLOC_EXPORT void dlmalloc_stats(void);
+
+/*
+ malloc_usable_size(void* p);
+
+ Returns the number of bytes you can actually use in
+ an allocated chunk, which may be more than you requested (although
+ often not) due to alignment and minimum size constraints.
+ You can use this many bytes without worrying about
+ overwriting other allocated objects. This is not a particularly great
+ programming practice. malloc_usable_size can be more useful in
+ debugging and assertions, for example:
+
+ p = malloc(n);
+ assert(malloc_usable_size(p) >= 256);
+*/
+size_t dlmalloc_usable_size(void*);
+
+#endif /* ONLY_MSPACES */
+
+#if MSPACES
+
+/*
+ mspace is an opaque type representing an independent
+ region of space that supports mspace_malloc, etc.
+*/
+typedef void* mspace;
+
+/*
+ create_mspace creates and returns a new independent space with the
+ given initial capacity, or, if 0, the default granularity size. It
+ returns null if there is no system memory available to create the
+ space. If argument locked is non-zero, the space uses a separate
+ lock to control access. The capacity of the space will grow
+ dynamically as needed to service mspace_malloc requests. You can
+ control the sizes of incremental increases of this space by
+ compiling with a different DEFAULT_GRANULARITY or dynamically
+ setting with mallopt(M_GRANULARITY, value).
+*/
+DLMALLOC_EXPORT mspace create_mspace(size_t capacity, int locked);
+
+/*
+ destroy_mspace destroys the given space, and attempts to return all
+ of its memory back to the system, returning the total number of
+ bytes freed. After destruction, the results of access to all memory
+ used by the space become undefined.
+*/
+DLMALLOC_EXPORT size_t destroy_mspace(mspace msp);
+
+/*
+ create_mspace_with_base uses the memory supplied as the initial base
+ of a new mspace. Part (less than 128*sizeof(size_t) bytes) of this
+ space is used for bookkeeping, so the capacity must be at least this
+ large. (Otherwise 0 is returned.) When this initial space is
+ exhausted, additional memory will be obtained from the system.
+ Destroying this space will deallocate all additionally allocated
+ space (if possible) but not the initial base.
+*/
+DLMALLOC_EXPORT mspace create_mspace_with_base(void* base, size_t capacity, int locked);
+
+/*
+ mspace_track_large_chunks controls whether requests for large chunks
+ are allocated in their own untracked mmapped regions, separate from
+ others in this mspace. By default large chunks are not tracked,
+ which reduces fragmentation. However, such chunks are not
+ necessarily released to the system upon destroy_mspace. Enabling
+ tracking by setting to true may increase fragmentation, but avoids
+ leakage when relying on destroy_mspace to release all memory
+ allocated using this space. The function returns the previous
+ setting.
+*/
+DLMALLOC_EXPORT int mspace_track_large_chunks(mspace msp, int enable);
+
+
+/*
+ mspace_malloc behaves as malloc, but operates within
+ the given space.
+*/
+DLMALLOC_EXPORT void* mspace_malloc(mspace msp, size_t bytes);
+
+/*
+ mspace_free behaves as free, but operates within
+ the given space.
+
+ If compiled with FOOTERS==1, mspace_free is not actually needed.
+ free may be called instead of mspace_free because freed chunks from
+ any space are handled by their originating spaces.
+*/
+DLMALLOC_EXPORT void mspace_free(mspace msp, void* mem);
+
+/*
+ mspace_realloc behaves as realloc, but operates within
+ the given space.
+
+ If compiled with FOOTERS==1, mspace_realloc is not actually
+ needed. realloc may be called instead of mspace_realloc because
+ realloced chunks from any space are handled by their originating
+ spaces.
+*/
+DLMALLOC_EXPORT void* mspace_realloc(mspace msp, void* mem, size_t newsize);
+
+/*
+ mspace_calloc behaves as calloc, but operates within
+ the given space.
+*/
+DLMALLOC_EXPORT void* mspace_calloc(mspace msp, size_t n_elements, size_t elem_size);
+
+/*
+ mspace_memalign behaves as memalign, but operates within
+ the given space.
+*/
+DLMALLOC_EXPORT void* mspace_memalign(mspace msp, size_t alignment, size_t bytes);
+
+/*
+ mspace_independent_calloc behaves as independent_calloc, but
+ operates within the given space.
+*/
+DLMALLOC_EXPORT void** mspace_independent_calloc(mspace msp, size_t n_elements,
+ size_t elem_size, void* chunks[]);
+
+/*
+ mspace_independent_comalloc behaves as independent_comalloc, but
+ operates within the given space.
+*/
+DLMALLOC_EXPORT void** mspace_independent_comalloc(mspace msp, size_t n_elements,
+ size_t sizes[], void* chunks[]);
+
+/*
+ mspace_footprint() returns the number of bytes obtained from the
+ system for this space.
+*/
+DLMALLOC_EXPORT size_t mspace_footprint(mspace msp);
+
+/*
+ mspace_max_footprint() returns the peak number of bytes obtained from the
+ system for this space.
+*/
+DLMALLOC_EXPORT size_t mspace_max_footprint(mspace msp);
+
+
+#if !NO_MALLINFO
+/*
+ mspace_mallinfo behaves as mallinfo, but reports properties of
+ the given space.
+*/
+DLMALLOC_EXPORT struct mallinfo mspace_mallinfo(mspace msp);
+#endif /* NO_MALLINFO */
+
+/*
+ malloc_usable_size(void* p) behaves the same as malloc_usable_size;
+*/
+DLMALLOC_EXPORT size_t mspace_usable_size(const void* mem);
+
+/*
+ mspace_malloc_stats behaves as malloc_stats, but reports
+ properties of the given space.
+*/
+DLMALLOC_EXPORT void mspace_malloc_stats(mspace msp);
+
+/*
+ mspace_trim behaves as malloc_trim, but
+ operates within the given space.
+*/
+DLMALLOC_EXPORT int mspace_trim(mspace msp, size_t pad);
+
+/*
+ An alias for mallopt.
+*/
+DLMALLOC_EXPORT int mspace_mallopt(int, int);
+
+#endif /* MSPACES */
+
+#ifdef __cplusplus
+} /* end of extern "C" */
+#endif /* __cplusplus */
+
+/*
+ ========================================================================
+ To make a fully customizable malloc.h header file, cut everything
+ above this line, put into file malloc.h, edit to suit, and #include it
+ on the next line, as well as in programs that use this malloc.
+ ========================================================================
+*/
+
+/* #include "malloc.h" */
+
+/*------------------------------ internal #includes ---------------------- */
+
+#ifdef _MSC_VER
+#pragma warning( disable : 4146 ) /* no "unsigned" warnings */
+#endif /* _MSC_VER */
+#if !NO_MALLOC_STATS
+#include <stdio.h> /* for printing in malloc_stats */
+#endif /* NO_MALLOC_STATS */
+#ifndef LACKS_ERRNO_H
+#include <errno.h> /* for MALLOC_FAILURE_ACTION */
+#endif /* LACKS_ERRNO_H */
+#ifdef DEBUG
+#if ABORT_ON_ASSERT_FAILURE
+#undef assert
+#define assert(x) if(!(x)) ABORT
+#else /* ABORT_ON_ASSERT_FAILURE */
+#include <assert.h>
+#endif /* ABORT_ON_ASSERT_FAILURE */
+#else /* DEBUG */
+#ifndef assert
+#define assert(x)
+#endif
+#define DEBUG 0
+#endif /* DEBUG */
+#if !defined(WIN32) && !defined(LACKS_TIME_H)
+#include <time.h> /* for magic initialization */
+#endif /* WIN32 */
+#ifndef LACKS_STDLIB_H
+#include <stdlib.h> /* for abort() */
+#endif /* LACKS_STDLIB_H */
+#ifndef LACKS_STRING_H
+#include <string.h> /* for memset etc */
+#endif /* LACKS_STRING_H */
+#if USE_BUILTIN_FFS
+#ifndef LACKS_STRINGS_H
+#include <strings.h> /* for ffs */
+#endif /* LACKS_STRINGS_H */
+#endif /* USE_BUILTIN_FFS */
+#if HAVE_MMAP
+#ifndef LACKS_SYS_MMAN_H
+/* On some versions of linux, mremap decl in mman.h needs __USE_GNU set */
+#if (defined(linux) && !defined(__USE_GNU))
+#define __USE_GNU 1
+#include <sys/mman.h> /* for mmap */
+#undef __USE_GNU
+#else
+#include <sys/mman.h> /* for mmap */
+#endif /* linux */
+#endif /* LACKS_SYS_MMAN_H */
+#ifndef LACKS_FCNTL_H
+#include <fcntl.h>
+#endif /* LACKS_FCNTL_H */
+#endif /* HAVE_MMAP */
+#ifndef LACKS_UNISTD_H
+#include <unistd.h> /* for sbrk, sysconf */
+#else /* LACKS_UNISTD_H */
+#if !defined(__FreeBSD__) && !defined(__OpenBSD__) && !defined(__NetBSD__)
+extern void* sbrk(ptrdiff_t);
+#endif /* FreeBSD etc */
+#endif /* LACKS_UNISTD_H */
+
+/* Declarations for locking */
+#if USE_LOCKS
+#ifndef WIN32
+#if defined (__SVR4) && defined (__sun) /* solaris */
+#include <thread.h>
+#elif !defined(LACKS_SCHED_H)
+#include <sched.h>
+#endif /* solaris or LACKS_SCHED_H */
+#if (defined(USE_RECURSIVE_LOCKS) && USE_RECURSIVE_LOCKS != 0) || !USE_SPIN_LOCKS
+#include <pthread.h>
+#endif /* USE_RECURSIVE_LOCKS ... */
+#elif defined(_MSC_VER)
+#ifndef _M_AMD64
+/* These are already defined on AMD64 builds */
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+LONG __cdecl _InterlockedCompareExchange(LONG volatile *Dest, LONG Exchange, LONG Comp);
+LONG __cdecl _InterlockedExchange(LONG volatile *Target, LONG Value);
+#ifdef __cplusplus
+}
+#endif /* __cplusplus */
+#endif /* _M_AMD64 */
+#pragma intrinsic (_InterlockedCompareExchange)
+#pragma intrinsic (_InterlockedExchange)
+#define interlockedcompareexchange _InterlockedCompareExchange
+#define interlockedexchange _InterlockedExchange
+#elif defined(WIN32) && defined(__GNUC__)
+#define interlockedcompareexchange(a, b, c) __sync_val_compare_and_swap(a, c, b)
+#define interlockedexchange __sync_lock_test_and_set
+#endif /* Win32 */
+#else /* USE_LOCKS */
+#endif /* USE_LOCKS */
+
+#ifndef LOCK_AT_FORK
+#define LOCK_AT_FORK 0
+#endif
+
+/* Declarations for bit scanning on win32 */
+#if defined(_MSC_VER) && _MSC_VER>=1300
+#ifndef BitScanForward /* Try to avoid pulling in WinNT.h */
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+unsigned char _BitScanForward(unsigned long *index, unsigned long mask);
+unsigned char _BitScanReverse(unsigned long *index, unsigned long mask);
+#ifdef __cplusplus
+}
+#endif /* __cplusplus */
+
+#define BitScanForward _BitScanForward
+#define BitScanReverse _BitScanReverse
+#pragma intrinsic(_BitScanForward)
+#pragma intrinsic(_BitScanReverse)
+#endif /* BitScanForward */
+#endif /* defined(_MSC_VER) && _MSC_VER>=1300 */
+
+#ifndef WIN32
+#ifndef malloc_getpagesize
+# ifdef _SC_PAGESIZE /* some SVR4 systems omit an underscore */
+# ifndef _SC_PAGE_SIZE
+# define _SC_PAGE_SIZE _SC_PAGESIZE
+# endif
+# endif
+# ifdef _SC_PAGE_SIZE
+# define malloc_getpagesize sysconf(_SC_PAGE_SIZE)
+# else
+# if defined(BSD) || defined(DGUX) || defined(HAVE_GETPAGESIZE)
+ extern size_t getpagesize();
+# define malloc_getpagesize getpagesize()
+# else
+# ifdef WIN32 /* use supplied emulation of getpagesize */
+# define malloc_getpagesize getpagesize()
+# else
+# ifndef LACKS_SYS_PARAM_H
+# include <sys/param.h>
+# endif
+# ifdef EXEC_PAGESIZE
+# define malloc_getpagesize EXEC_PAGESIZE
+# else
+# ifdef NBPG
+# ifndef CLSIZE
+# define malloc_getpagesize NBPG
+# else
+# define malloc_getpagesize (NBPG * CLSIZE)
+# endif
+# else
+# ifdef NBPC
+# define malloc_getpagesize NBPC
+# else
+# ifdef PAGESIZE
+# define malloc_getpagesize PAGESIZE
+# else /* just guess */
+# define malloc_getpagesize ((size_t)4096U)
+# endif
+# endif
+# endif
+# endif
+# endif
+# endif
+# endif
+#endif
+#endif
+
+/* ------------------- size_t and alignment properties -------------------- */
+
+/* The byte and bit size of a size_t */
+#define SIZE_T_SIZE (sizeof(size_t))
+#define SIZE_T_BITSIZE (sizeof(size_t) << 3)
+
+/* Some constants coerced to size_t */
+/* Annoying but necessary to avoid errors on some platforms */
+#define SIZE_T_ZERO ((size_t)0)
+#define SIZE_T_ONE ((size_t)1)
+#define SIZE_T_TWO ((size_t)2)
+#define SIZE_T_FOUR ((size_t)4)
+#define TWO_SIZE_T_SIZES (SIZE_T_SIZE<<1)
+#define FOUR_SIZE_T_SIZES (SIZE_T_SIZE<<2)
+#define SIX_SIZE_T_SIZES (FOUR_SIZE_T_SIZES+TWO_SIZE_T_SIZES)
+#define HALF_MAX_SIZE_T (MAX_SIZE_T / 2U)
+
+/* The bit mask value corresponding to MALLOC_ALIGNMENT */
+#define CHUNK_ALIGN_MASK (MALLOC_ALIGNMENT - SIZE_T_ONE)
+
+/* True if address a has acceptable alignment */
+#define is_aligned(A) (((size_t)((A)) & (CHUNK_ALIGN_MASK)) == 0)
+
+/* the number of bytes to offset an address to align it */
+#define align_offset(A)\
+ ((((size_t)(A) & CHUNK_ALIGN_MASK) == 0)? 0 :\
+ ((MALLOC_ALIGNMENT - ((size_t)(A) & CHUNK_ALIGN_MASK)) & CHUNK_ALIGN_MASK))
+
+/* -------------------------- MMAP preliminaries ------------------------- */
+
+/*
+ If HAVE_MORECORE or HAVE_MMAP are false, we just define calls and
+ checks to fail so compiler optimizer can delete code rather than
+ using so many "#if"s.
+*/
+
+
+/* MORECORE and MMAP must return MFAIL on failure */
+#define MFAIL ((void*)(MAX_SIZE_T))
+#define CMFAIL ((char*)(MFAIL)) /* defined for convenience */
+
+#if HAVE_MMAP
+
+#ifndef WIN32
+#define MUNMAP_DEFAULT(a, s) munmap((a), (s))
+#define MMAP_PROT (PROT_READ|PROT_WRITE)
+#if !defined(MAP_ANONYMOUS) && defined(MAP_ANON)
+#define MAP_ANONYMOUS MAP_ANON
+#endif /* MAP_ANON */
+#ifdef MAP_ANONYMOUS
+#define MMAP_FLAGS (MAP_PRIVATE|MAP_ANONYMOUS)
+#define MMAP_DEFAULT(s) mmap(0, (s), MMAP_PROT, MMAP_FLAGS, -1, 0)
+#else /* MAP_ANONYMOUS */
+/*
+ Nearly all versions of mmap support MAP_ANONYMOUS, so the following
+ is unlikely to be needed, but is supplied just in case.
+*/
+#define MMAP_FLAGS (MAP_PRIVATE)
+static int dev_zero_fd = -1; /* Cached file descriptor for /dev/zero. */
+#define MMAP_DEFAULT(s) ((dev_zero_fd < 0) ? \
+ (dev_zero_fd = open("/dev/zero", O_RDWR), \
+ mmap(0, (s), MMAP_PROT, MMAP_FLAGS, dev_zero_fd, 0)) : \
+ mmap(0, (s), MMAP_PROT, MMAP_FLAGS, dev_zero_fd, 0))
+#endif /* MAP_ANONYMOUS */
+
+#define DIRECT_MMAP_DEFAULT(s) MMAP_DEFAULT(s)
+
+#else /* WIN32 */
+
+/* Win32 MMAP via VirtualAlloc */
+static FORCEINLINE void* win32mmap(size_t size) {
+ void* ptr = VirtualAlloc(0, size, MEM_RESERVE|MEM_COMMIT, PAGE_READWRITE);
+ return (ptr != 0)? ptr: MFAIL;
+}
+
+/* For direct MMAP, use MEM_TOP_DOWN to minimize interference */
+static FORCEINLINE void* win32direct_mmap(size_t size) {
+ void* ptr = VirtualAlloc(0, size, MEM_RESERVE|MEM_COMMIT|MEM_TOP_DOWN,
+ PAGE_READWRITE);
+ return (ptr != 0)? ptr: MFAIL;
+}
+
+/* This function supports releasing coalesced segments */
+static FORCEINLINE int win32munmap(void* ptr, size_t size) {
+ MEMORY_BASIC_INFORMATION minfo;
+ char* cptr = (char*)ptr;
+ while (size) {
+ if (VirtualQuery(cptr, &minfo, sizeof(minfo)) == 0)
+ return -1;
+ if (minfo.BaseAddress != cptr || minfo.AllocationBase != cptr ||
+ minfo.State != MEM_COMMIT || minfo.RegionSize > size)
+ return -1;
+ if (VirtualFree(cptr, 0, MEM_RELEASE) == 0)
+ return -1;
+ cptr += minfo.RegionSize;
+ size -= minfo.RegionSize;
+ }
+ return 0;
+}
+
+#define MMAP_DEFAULT(s) win32mmap(s)
+#define MUNMAP_DEFAULT(a, s) win32munmap((a), (s))
+#define DIRECT_MMAP_DEFAULT(s) win32direct_mmap(s)
+#endif /* WIN32 */
+#endif /* HAVE_MMAP */
+
+#if HAVE_MREMAP
+#ifndef WIN32
+#define MREMAP_DEFAULT(addr, osz, nsz, mv) mremap((addr), (osz), (nsz), (mv))
+#endif /* WIN32 */
+#endif /* HAVE_MREMAP */
+
+/**
+ * Define CALL_MORECORE
+ */
+#if HAVE_MORECORE
+ #ifdef MORECORE
+ #define CALL_MORECORE(S) MORECORE(S)
+ #else /* MORECORE */
+ #define CALL_MORECORE(S) MORECORE_DEFAULT(S)
+ #endif /* MORECORE */
+#else /* HAVE_MORECORE */
+ #define CALL_MORECORE(S) MFAIL
+#endif /* HAVE_MORECORE */
+
+/**
+ * Define CALL_MMAP/CALL_MUNMAP/CALL_DIRECT_MMAP
+ */
+#if HAVE_MMAP
+ #define USE_MMAP_BIT (SIZE_T_ONE)
+
+ #ifdef MMAP
+ #define CALL_MMAP(s) MMAP(s)
+ #else /* MMAP */
+ #define CALL_MMAP(s) MMAP_DEFAULT(s)
+ #endif /* MMAP */
+ #ifdef MUNMAP
+ #define CALL_MUNMAP(a, s) MUNMAP((a), (s))
+ #else /* MUNMAP */
+ #define CALL_MUNMAP(a, s) MUNMAP_DEFAULT((a), (s))
+ #endif /* MUNMAP */
+ #ifdef DIRECT_MMAP
+ #define CALL_DIRECT_MMAP(s) DIRECT_MMAP(s)
+ #else /* DIRECT_MMAP */
+ #define CALL_DIRECT_MMAP(s) DIRECT_MMAP_DEFAULT(s)
+ #endif /* DIRECT_MMAP */
+#else /* HAVE_MMAP */
+ #define USE_MMAP_BIT (SIZE_T_ZERO)
+
+ #define MMAP(s) MFAIL
+ #define MUNMAP(a, s) (-1)
+ #define DIRECT_MMAP(s) MFAIL
+ #define CALL_DIRECT_MMAP(s) DIRECT_MMAP(s)
+ #define CALL_MMAP(s) MMAP(s)
+ #define CALL_MUNMAP(a, s) MUNMAP((a), (s))
+#endif /* HAVE_MMAP */
+
+/**
+ * Define CALL_MREMAP
+ */
+#if HAVE_MMAP && HAVE_MREMAP
+ #ifdef MREMAP
+ #define CALL_MREMAP(addr, osz, nsz, mv) MREMAP((addr), (osz), (nsz), (mv))
+ #else /* MREMAP */
+ #define CALL_MREMAP(addr, osz, nsz, mv) MREMAP_DEFAULT((addr), (osz), (nsz), (mv))
+ #endif /* MREMAP */
+#else /* HAVE_MMAP && HAVE_MREMAP */
+ #define CALL_MREMAP(addr, osz, nsz, mv) MFAIL
+#endif /* HAVE_MMAP && HAVE_MREMAP */
+
+/* mstate bit set if contiguous morecore disabled or failed */
+#define USE_NONCONTIGUOUS_BIT (4U)
+
+/* segment bit set in create_mspace_with_base */
+#define EXTERN_BIT (8U)
+
+
+/* --------------------------- Lock preliminaries ------------------------ */
+
+/*
+ When locks are defined, there is one global lock, plus
+ one per-mspace lock.
+
+ The global lock_ensures that mparams.magic and other unique
+ mparams values are initialized only once. It also protects
+ sequences of calls to MORECORE. In many cases sys_alloc requires
+ two calls, that should not be interleaved with calls by other
+ threads. This does not protect against direct calls to MORECORE
+ by other threads not using this lock, so there is still code to
+ cope the best we can on interference.
+
+ Per-mspace locks surround calls to malloc, free, etc.
+ By default, locks are simple non-reentrant mutexes.
+
+ Because lock-protected regions generally have bounded times, it is
+ OK to use the supplied simple spinlocks. Spinlocks are likely to
+ improve performance for lightly contended applications, but worsen
+ performance under heavy contention.
+
+ If USE_LOCKS is > 1, the definitions of lock routines here are
+ bypassed, in which case you will need to define the type MLOCK_T,
+ and at least INITIAL_LOCK, DESTROY_LOCK, ACQUIRE_LOCK, RELEASE_LOCK
+ and TRY_LOCK. You must also declare a
+ static MLOCK_T malloc_global_mutex = { initialization values };.
+
+*/
+
+#if !USE_LOCKS
+#define USE_LOCK_BIT (0U)
+#define INITIAL_LOCK(l) (0)
+#define DESTROY_LOCK(l) (0)
+#define ACQUIRE_MALLOC_GLOBAL_LOCK()
+#define RELEASE_MALLOC_GLOBAL_LOCK()
+
+#else
+#if USE_LOCKS > 1
+/* ----------------------- User-defined locks ------------------------ */
+/* Define your own lock implementation here */
+/* #define INITIAL_LOCK(lk) ... */
+/* #define DESTROY_LOCK(lk) ... */
+/* #define ACQUIRE_LOCK(lk) ... */
+/* #define RELEASE_LOCK(lk) ... */
+/* #define TRY_LOCK(lk) ... */
+/* static MLOCK_T malloc_global_mutex = ... */
+
+#elif USE_SPIN_LOCKS
+
+/* First, define CAS_LOCK and CLEAR_LOCK on ints */
+/* Note CAS_LOCK defined to return 0 on success */
+
+#if defined(__GNUC__)&& (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 1))
+#define CAS_LOCK(sl) __sync_lock_test_and_set(sl, 1)
+#define CLEAR_LOCK(sl) __sync_lock_release(sl)
+
+#elif (defined(__GNUC__) && (defined(__i386__) || defined(__x86_64__)))
+/* Custom spin locks for older gcc on x86 */
+static FORCEINLINE int x86_cas_lock(int *sl) {
+ int ret;
+ int val = 1;
+ int cmp = 0;
+ __asm__ __volatile__ ("lock; cmpxchgl %1, %2"
+ : "=a" (ret)
+ : "r" (val), "m" (*(sl)), "0"(cmp)
+ : "memory", "cc");
+ return ret;
+}
+
+static FORCEINLINE void x86_clear_lock(int* sl) {
+ assert(*sl != 0);
+ int prev = 0;
+ int ret;
+ __asm__ __volatile__ ("lock; xchgl %0, %1"
+ : "=r" (ret)
+ : "m" (*(sl)), "0"(prev)
+ : "memory");
+}
+
+#define CAS_LOCK(sl) x86_cas_lock(sl)
+#define CLEAR_LOCK(sl) x86_clear_lock(sl)
+
+#else /* Win32 MSC */
+#define CAS_LOCK(sl) interlockedexchange(sl, (LONG)1)
+#define CLEAR_LOCK(sl) interlockedexchange (sl, (LONG)0)
+
+#endif /* ... gcc spins locks ... */
+
+/* How to yield for a spin lock */
+#define SPINS_PER_YIELD 63
+#if defined(_MSC_VER)
+#define SLEEP_EX_DURATION 50 /* delay for yield/sleep */
+#define SPIN_LOCK_YIELD SleepEx(SLEEP_EX_DURATION, FALSE)
+#elif defined (__SVR4) && defined (__sun) /* solaris */
+#define SPIN_LOCK_YIELD thr_yield();
+#elif !defined(LACKS_SCHED_H)
+#define SPIN_LOCK_YIELD sched_yield();
+#else
+#define SPIN_LOCK_YIELD
+#endif /* ... yield ... */
+
+#if !defined(USE_RECURSIVE_LOCKS) || USE_RECURSIVE_LOCKS == 0
+/* Plain spin locks use single word (embedded in malloc_states) */
+static int spin_acquire_lock(int *sl) {
+ int spins = 0;
+ while (*(volatile int *)sl != 0 || CAS_LOCK(sl)) {
+ if ((++spins & SPINS_PER_YIELD) == 0) {
+ SPIN_LOCK_YIELD;
+ }
+ }
+ return 0;
+}
+
+#define MLOCK_T int
+#define TRY_LOCK(sl) !CAS_LOCK(sl)
+#define RELEASE_LOCK(sl) CLEAR_LOCK(sl)
+#define ACQUIRE_LOCK(sl) (CAS_LOCK(sl)? spin_acquire_lock(sl) : 0)
+#define INITIAL_LOCK(sl) (*sl = 0)
+#define DESTROY_LOCK(sl) (0)
+static MLOCK_T malloc_global_mutex = 0;
+
+#else /* USE_RECURSIVE_LOCKS */
+/* types for lock owners */
+#ifdef WIN32
+#define THREAD_ID_T DWORD
+#define CURRENT_THREAD GetCurrentThreadId()
+#define EQ_OWNER(X,Y) ((X) == (Y))
+#else
+/*
+ Note: the following assume that pthread_t is a type that can be
+ initialized to (casted) zero. If this is not the case, you will need to
+ somehow redefine these or not use spin locks.
+*/
+#define THREAD_ID_T pthread_t
+#define CURRENT_THREAD pthread_self()
+#define EQ_OWNER(X,Y) pthread_equal(X, Y)
+#endif
+
+struct malloc_recursive_lock {
+ int sl;
+ unsigned int c;
+ THREAD_ID_T threadid;
+};
+
+#define MLOCK_T struct malloc_recursive_lock
+static MLOCK_T malloc_global_mutex = { 0, 0, (THREAD_ID_T)0};
+
+static FORCEINLINE void recursive_release_lock(MLOCK_T *lk) {
+ assert(lk->sl != 0);
+ if (--lk->c == 0) {
+ CLEAR_LOCK(&lk->sl);
+ }
+}
+
+static FORCEINLINE int recursive_acquire_lock(MLOCK_T *lk) {
+ THREAD_ID_T mythreadid = CURRENT_THREAD;
+ int spins = 0;
+ for (;;) {
+ if (*((volatile int *)(&lk->sl)) == 0) {
+ if (!CAS_LOCK(&lk->sl)) {
+ lk->threadid = mythreadid;
+ lk->c = 1;
+ return 0;
+ }
+ }
+ else if (EQ_OWNER(lk->threadid, mythreadid)) {
+ ++lk->c;
+ return 0;
+ }
+ if ((++spins & SPINS_PER_YIELD) == 0) {
+ SPIN_LOCK_YIELD;
+ }
+ }
+}
+
+static FORCEINLINE int recursive_try_lock(MLOCK_T *lk) {
+ THREAD_ID_T mythreadid = CURRENT_THREAD;
+ if (*((volatile int *)(&lk->sl)) == 0) {
+ if (!CAS_LOCK(&lk->sl)) {
+ lk->threadid = mythreadid;
+ lk->c = 1;
+ return 1;
+ }
+ }
+ else if (EQ_OWNER(lk->threadid, mythreadid)) {
+ ++lk->c;
+ return 1;
+ }
+ return 0;
+}
+
+#define RELEASE_LOCK(lk) recursive_release_lock(lk)
+#define TRY_LOCK(lk) recursive_try_lock(lk)
+#define ACQUIRE_LOCK(lk) recursive_acquire_lock(lk)
+#define INITIAL_LOCK(lk) ((lk)->threadid = (THREAD_ID_T)0, (lk)->sl = 0, (lk)->c = 0)
+#define DESTROY_LOCK(lk) (0)
+#endif /* USE_RECURSIVE_LOCKS */
+
+#elif defined(WIN32) /* Win32 critical sections */
+#define MLOCK_T CRITICAL_SECTION
+#define ACQUIRE_LOCK(lk) (EnterCriticalSection(lk), 0)
+#define RELEASE_LOCK(lk) LeaveCriticalSection(lk)
+#define TRY_LOCK(lk) TryEnterCriticalSection(lk)
+#define INITIAL_LOCK(lk) (!InitializeCriticalSectionAndSpinCount((lk), 0x80000000|4000))
+#define DESTROY_LOCK(lk) (DeleteCriticalSection(lk), 0)
+#define NEED_GLOBAL_LOCK_INIT
+
+static MLOCK_T malloc_global_mutex;
+static volatile LONG malloc_global_mutex_status;
+
+/* Use spin loop to initialize global lock */
+static void init_malloc_global_mutex() {
+ for (;;) {
+ long stat = malloc_global_mutex_status;
+ if (stat > 0)
+ return;
+ /* transition to < 0 while initializing, then to > 0) */
+ if (stat == 0 &&
+ interlockedcompareexchange(&malloc_global_mutex_status, (LONG)-1, (LONG)0) == 0) {
+ InitializeCriticalSection(&malloc_global_mutex);
+ interlockedexchange(&malloc_global_mutex_status, (LONG)1);
+ return;
+ }
+ SleepEx(0, FALSE);
+ }
+}
+
+#else /* pthreads-based locks */
+#define MLOCK_T pthread_mutex_t
+#define ACQUIRE_LOCK(lk) pthread_mutex_lock(lk)
+#define RELEASE_LOCK(lk) pthread_mutex_unlock(lk)
+#define TRY_LOCK(lk) (!pthread_mutex_trylock(lk))
+#define INITIAL_LOCK(lk) pthread_init_lock(lk)
+#define DESTROY_LOCK(lk) pthread_mutex_destroy(lk)
+
+#if defined(USE_RECURSIVE_LOCKS) && USE_RECURSIVE_LOCKS != 0 && defined(linux) && !defined(PTHREAD_MUTEX_RECURSIVE)
+/* Cope with old-style linux recursive lock initialization by adding */
+/* skipped internal declaration from pthread.h */
+extern int pthread_mutexattr_setkind_np __P ((pthread_mutexattr_t *__attr,
+ int __kind));
+#define PTHREAD_MUTEX_RECURSIVE PTHREAD_MUTEX_RECURSIVE_NP
+#define pthread_mutexattr_settype(x,y) pthread_mutexattr_setkind_np(x,y)
+#endif /* USE_RECURSIVE_LOCKS ... */
+
+static MLOCK_T malloc_global_mutex = PTHREAD_MUTEX_INITIALIZER;
+
+static int pthread_init_lock (MLOCK_T *lk) {
+ pthread_mutexattr_t attr;
+ if (pthread_mutexattr_init(&attr)) return 1;
+#if defined(USE_RECURSIVE_LOCKS) && USE_RECURSIVE_LOCKS != 0
+ if (pthread_mutexattr_settype(&attr, PTHREAD_MUTEX_RECURSIVE)) return 1;
+#endif
+ if (pthread_mutex_init(lk, &attr)) return 1;
+ if (pthread_mutexattr_destroy(&attr)) return 1;
+ return 0;
+}
+
+#endif /* ... lock types ... */
+
+/* Common code for all lock types */
+#define USE_LOCK_BIT (2U)
+
+#ifndef ACQUIRE_MALLOC_GLOBAL_LOCK
+#define ACQUIRE_MALLOC_GLOBAL_LOCK() ACQUIRE_LOCK(&malloc_global_mutex);
+#endif
+
+#ifndef RELEASE_MALLOC_GLOBAL_LOCK
+#define RELEASE_MALLOC_GLOBAL_LOCK() RELEASE_LOCK(&malloc_global_mutex);
+#endif
+
+#endif /* USE_LOCKS */
+
+/* ----------------------- Chunk representations ------------------------ */
+
+/*
+ (The following includes lightly edited explanations by Colin Plumb.)
+
+ The malloc_chunk declaration below is misleading (but accurate and
+ necessary). It declares a "view" into memory allowing access to
+ necessary fields at known offsets from a given base.
+
+ Chunks of memory are maintained using a `boundary tag' method as
+ originally described by Knuth. (See the paper by Paul Wilson
+ ftp://ftp.cs.utexas.edu/pub/garbage/allocsrv.ps for a survey of such
+ techniques.) Sizes of free chunks are stored both in the front of
+ each chunk and at the end. This makes consolidating fragmented
+ chunks into bigger chunks fast. The head fields also hold bits
+ representing whether chunks are free or in use.
+
+ Here are some pictures to make it clearer. They are "exploded" to
+ show that the state of a chunk can be thought of as extending from
+ the high 31 bits of the head field of its header through the
+ prev_foot and PINUSE_BIT bit of the following chunk header.
+
+ A chunk that's in use looks like:
+
+ chunk-> +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Size of previous chunk (if P = 0) |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ |P|
+ | Size of this chunk 1| +-+
+ mem-> +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | |
+ +- -+
+ | |
+ +- -+
+ | :
+ +- size - sizeof(size_t) available payload bytes -+
+ : |
+ chunk-> +- -+
+ | |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ |1|
+ | Size of next chunk (may or may not be in use) | +-+
+ mem-> +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+
+ And if it's free, it looks like this:
+
+ chunk-> +- -+
+ | User payload (must be in use, or we would have merged!) |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ |P|
+ | Size of this chunk 0| +-+
+ mem-> +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Next pointer |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Prev pointer |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | :
+ +- size - sizeof(struct chunk) unused bytes -+
+ : |
+ chunk-> +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Size of this chunk |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ |0|
+ | Size of next chunk (must be in use, or we would have merged)| +-+
+ mem-> +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | :
+ +- User payload -+
+ : |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ |0|
+ +-+
+ Note that since we always merge adjacent free chunks, the chunks
+ adjacent to a free chunk must be in use.
+
+ Given a pointer to a chunk (which can be derived trivially from the
+ payload pointer) we can, in O(1) time, find out whether the adjacent
+ chunks are free, and if so, unlink them from the lists that they
+ are on and merge them with the current chunk.
+
+ Chunks always begin on even word boundaries, so the mem portion
+ (which is returned to the user) is also on an even word boundary, and
+ thus at least double-word aligned.
+
+ The P (PINUSE_BIT) bit, stored in the unused low-order bit of the
+ chunk size (which is always a multiple of two words), is an in-use
+ bit for the *previous* chunk. If that bit is *clear*, then the
+ word before the current chunk size contains the previous chunk
+ size, and can be used to find the front of the previous chunk.
+ The very first chunk allocated always has this bit set, preventing
+ access to non-existent (or non-owned) memory. If pinuse is set for
+ any given chunk, then you CANNOT determine the size of the
+ previous chunk, and might even get a memory addressing fault when
+ trying to do so.
+
+ The C (CINUSE_BIT) bit, stored in the unused second-lowest bit of
+ the chunk size redundantly records whether the current chunk is
+ inuse (unless the chunk is mmapped). This redundancy enables usage
+ checks within free and realloc, and reduces indirection when freeing
+ and consolidating chunks.
+
+ Each freshly allocated chunk must have both cinuse and pinuse set.
+ That is, each allocated chunk borders either a previously allocated
+ and still in-use chunk, or the base of its memory arena. This is
+ ensured by making all allocations from the `lowest' part of any
+ found chunk. Further, no free chunk physically borders another one,
+ so each free chunk is known to be preceded and followed by either
+ inuse chunks or the ends of memory.
+
+ Note that the `foot' of the current chunk is actually represented
+ as the prev_foot of the NEXT chunk. This makes it easier to
+ deal with alignments etc but can be very confusing when trying
+ to extend or adapt this code.
+
+ The exceptions to all this are
+
+ 1. The special chunk `top' is the top-most available chunk (i.e.,
+ the one bordering the end of available memory). It is treated
+ specially. Top is never included in any bin, is used only if
+ no other chunk is available, and is released back to the
+ system if it is very large (see M_TRIM_THRESHOLD). In effect,
+ the top chunk is treated as larger (and thus less well
+ fitting) than any other available chunk. The top chunk
+ doesn't update its trailing size field since there is no next
+ contiguous chunk that would have to index off it. However,
+ space is still allocated for it (TOP_FOOT_SIZE) to enable
+ separation or merging when space is extended.
+
+ 3. Chunks allocated via mmap, have both cinuse and pinuse bits
+ cleared in their head fields. Because they are allocated
+ one-by-one, each must carry its own prev_foot field, which is
+ also used to hold the offset this chunk has within its mmapped
+ region, which is needed to preserve alignment. Each mmapped
+ chunk is trailed by the first two fields of a fake next-chunk
+ for sake of usage checks.
+
+*/
+
+struct malloc_chunk {
+ size_t prev_foot; /* Size of previous chunk (if free). */
+ size_t head; /* Size and inuse bits. */
+ struct malloc_chunk* fd; /* double links -- used only if free. */
+ struct malloc_chunk* bk;
+};
+
+typedef struct malloc_chunk mchunk;
+typedef struct malloc_chunk* mchunkptr;
+typedef struct malloc_chunk* sbinptr; /* The type of bins of chunks */
+typedef unsigned int bindex_t; /* Described below */
+typedef unsigned int binmap_t; /* Described below */
+typedef unsigned int flag_t; /* The type of various bit flag sets */
+
+/* ------------------- Chunks sizes and alignments ----------------------- */
+
+#define MCHUNK_SIZE (sizeof(mchunk))
+
+#if FOOTERS
+#define CHUNK_OVERHEAD (TWO_SIZE_T_SIZES)
+#else /* FOOTERS */
+#define CHUNK_OVERHEAD (SIZE_T_SIZE)
+#endif /* FOOTERS */
+
+/* MMapped chunks need a second word of overhead ... */
+#define MMAP_CHUNK_OVERHEAD (TWO_SIZE_T_SIZES)
+/* ... and additional padding for fake next-chunk at foot */
+#define MMAP_FOOT_PAD (FOUR_SIZE_T_SIZES)
+
+/* The smallest size we can malloc is an aligned minimal chunk */
+#define MIN_CHUNK_SIZE\
+ ((MCHUNK_SIZE + CHUNK_ALIGN_MASK) & ~CHUNK_ALIGN_MASK)
+
+/* conversion from malloc headers to user pointers, and back */
+#define chunk2mem(p) ((void*)((char*)(p) + TWO_SIZE_T_SIZES))
+#define mem2chunk(mem) ((mchunkptr)((char*)(mem) - TWO_SIZE_T_SIZES))
+/* chunk associated with aligned address A */
+#define align_as_chunk(A) (mchunkptr)((A) + align_offset(chunk2mem(A)))
+
+/* Bounds on request (not chunk) sizes. */
+#define MAX_REQUEST ((-MIN_CHUNK_SIZE) << 2)
+#define MIN_REQUEST (MIN_CHUNK_SIZE - CHUNK_OVERHEAD - SIZE_T_ONE)
+
+/* pad request bytes into a usable size */
+#define pad_request(req) \
+ (((req) + CHUNK_OVERHEAD + CHUNK_ALIGN_MASK) & ~CHUNK_ALIGN_MASK)
+
+/* pad request, checking for minimum (but not maximum) */
+#define request2size(req) \
+ (((req) < MIN_REQUEST)? MIN_CHUNK_SIZE : pad_request(req))
+
+
+/* ------------------ Operations on head and foot fields ----------------- */
+
+/*
+ The head field of a chunk is or'ed with PINUSE_BIT when previous
+ adjacent chunk in use, and or'ed with CINUSE_BIT if this chunk is in
+ use, unless mmapped, in which case both bits are cleared.
+
+ FLAG4_BIT is not used by this malloc, but might be useful in extensions.
+*/
+
+#define PINUSE_BIT (SIZE_T_ONE)
+#define CINUSE_BIT (SIZE_T_TWO)
+#define FLAG4_BIT (SIZE_T_FOUR)
+#define INUSE_BITS (PINUSE_BIT|CINUSE_BIT)
+#define FLAG_BITS (PINUSE_BIT|CINUSE_BIT|FLAG4_BIT)
+
+/* Head value for fenceposts */
+#define FENCEPOST_HEAD (INUSE_BITS|SIZE_T_SIZE)
+
+/* extraction of fields from head words */
+#define cinuse(p) ((p)->head & CINUSE_BIT)
+#define pinuse(p) ((p)->head & PINUSE_BIT)
+#define flag4inuse(p) ((p)->head & FLAG4_BIT)
+#define is_inuse(p) (((p)->head & INUSE_BITS) != PINUSE_BIT)
+#define is_mmapped(p) (((p)->head & INUSE_BITS) == 0)
+
+#define chunksize(p) ((p)->head & ~(FLAG_BITS))
+
+#define clear_pinuse(p) ((p)->head &= ~PINUSE_BIT)
+#define set_flag4(p) ((p)->head |= FLAG4_BIT)
+#define clear_flag4(p) ((p)->head &= ~FLAG4_BIT)
+
+/* Treat space at ptr +/- offset as a chunk */
+#define chunk_plus_offset(p, s) ((mchunkptr)(((char*)(p)) + (s)))
+#define chunk_minus_offset(p, s) ((mchunkptr)(((char*)(p)) - (s)))
+
+/* Ptr to next or previous physical malloc_chunk. */
+#define next_chunk(p) ((mchunkptr)( ((char*)(p)) + ((p)->head & ~FLAG_BITS)))
+#define prev_chunk(p) ((mchunkptr)( ((char*)(p)) - ((p)->prev_foot) ))
+
+/* extract next chunk's pinuse bit */
+#define next_pinuse(p) ((next_chunk(p)->head) & PINUSE_BIT)
+
+/* Get/set size at footer */
+#define get_foot(p, s) (((mchunkptr)((char*)(p) + (s)))->prev_foot)
+#define set_foot(p, s) (((mchunkptr)((char*)(p) + (s)))->prev_foot = (s))
+
+/* Set size, pinuse bit, and foot */
+#define set_size_and_pinuse_of_free_chunk(p, s)\
+ ((p)->head = (s|PINUSE_BIT), set_foot(p, s))
+
+/* Set size, pinuse bit, foot, and clear next pinuse */
+#define set_free_with_pinuse(p, s, n)\
+ (clear_pinuse(n), set_size_and_pinuse_of_free_chunk(p, s))
+
+/* Get the internal overhead associated with chunk p */
+#define overhead_for(p)\
+ (is_mmapped(p)? MMAP_CHUNK_OVERHEAD : CHUNK_OVERHEAD)
+
+/* Return true if malloced space is not necessarily cleared */
+#if MMAP_CLEARS
+#define calloc_must_clear(p) (!is_mmapped(p))
+#else /* MMAP_CLEARS */
+#define calloc_must_clear(p) (1)
+#endif /* MMAP_CLEARS */
+
+/* ---------------------- Overlaid data structures ----------------------- */
+
+/*
+ When chunks are not in use, they are treated as nodes of either
+ lists or trees.
+
+ "Small" chunks are stored in circular doubly-linked lists, and look
+ like this:
+
+ chunk-> +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Size of previous chunk |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ `head:' | Size of chunk, in bytes |P|
+ mem-> +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Forward pointer to next chunk in list |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Back pointer to previous chunk in list |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Unused space (may be 0 bytes long) .
+ . .
+ . |
+nextchunk-> +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ `foot:' | Size of chunk, in bytes |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+
+ Larger chunks are kept in a form of bitwise digital trees (aka
+ tries) keyed on chunksizes. Because malloc_tree_chunks are only for
+ free chunks greater than 256 bytes, their size doesn't impose any
+ constraints on user chunk sizes. Each node looks like:
+
+ chunk-> +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Size of previous chunk |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ `head:' | Size of chunk, in bytes |P|
+ mem-> +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Forward pointer to next chunk of same size |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Back pointer to previous chunk of same size |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Pointer to left child (child[0]) |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Pointer to right child (child[1]) |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Pointer to parent |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | bin index of this chunk |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Unused space .
+ . |
+nextchunk-> +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ `foot:' | Size of chunk, in bytes |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+
+ Each tree holding treenodes is a tree of unique chunk sizes. Chunks
+ of the same size are arranged in a circularly-linked list, with only
+ the oldest chunk (the next to be used, in our FIFO ordering)
+ actually in the tree. (Tree members are distinguished by a non-null
+ parent pointer.) If a chunk with the same size an an existing node
+ is inserted, it is linked off the existing node using pointers that
+ work in the same way as fd/bk pointers of small chunks.
+
+ Each tree contains a power of 2 sized range of chunk sizes (the
+ smallest is 0x100 <= x < 0x180), which is is divided in half at each
+ tree level, with the chunks in the smaller half of the range (0x100
+ <= x < 0x140 for the top nose) in the left subtree and the larger
+ half (0x140 <= x < 0x180) in the right subtree. This is, of course,
+ done by inspecting individual bits.
+
+ Using these rules, each node's left subtree contains all smaller
+ sizes than its right subtree. However, the node at the root of each
+ subtree has no particular ordering relationship to either. (The
+ dividing line between the subtree sizes is based on trie relation.)
+ If we remove the last chunk of a given size from the interior of the
+ tree, we need to replace it with a leaf node. The tree ordering
+ rules permit a node to be replaced by any leaf below it.
+
+ The smallest chunk in a tree (a common operation in a best-fit
+ allocator) can be found by walking a path to the leftmost leaf in
+ the tree. Unlike a usual binary tree, where we follow left child
+ pointers until we reach a null, here we follow the right child
+ pointer any time the left one is null, until we reach a leaf with
+ both child pointers null. The smallest chunk in the tree will be
+ somewhere along that path.
+
+ The worst case number of steps to add, find, or remove a node is
+ bounded by the number of bits differentiating chunks within
+ bins. Under current bin calculations, this ranges from 6 up to 21
+ (for 32 bit sizes) or up to 53 (for 64 bit sizes). The typical case
+ is of course much better.
+*/
+
+struct malloc_tree_chunk {
+ /* The first four fields must be compatible with malloc_chunk */
+ size_t prev_foot;
+ size_t head;
+ struct malloc_tree_chunk* fd;
+ struct malloc_tree_chunk* bk;
+
+ struct malloc_tree_chunk* child[2];
+ struct malloc_tree_chunk* parent;
+ bindex_t index;
+};
+
+typedef struct malloc_tree_chunk tchunk;
+typedef struct malloc_tree_chunk* tchunkptr;
+typedef struct malloc_tree_chunk* tbinptr; /* The type of bins of trees */
+
+/* A little helper macro for trees */
+#define leftmost_child(t) ((t)->child[0] != 0? (t)->child[0] : (t)->child[1])
+
+/* ----------------------------- Segments -------------------------------- */
+
+/*
+ Each malloc space may include non-contiguous segments, held in a
+ list headed by an embedded malloc_segment record representing the
+ top-most space. Segments also include flags holding properties of
+ the space. Large chunks that are directly allocated by mmap are not
+ included in this list. They are instead independently created and
+ destroyed without otherwise keeping track of them.
+
+ Segment management mainly comes into play for spaces allocated by
+ MMAP. Any call to MMAP might or might not return memory that is
+ adjacent to an existing segment. MORECORE normally contiguously
+ extends the current space, so this space is almost always adjacent,
+ which is simpler and faster to deal with. (This is why MORECORE is
+ used preferentially to MMAP when both are available -- see
+ sys_alloc.) When allocating using MMAP, we don't use any of the
+ hinting mechanisms (inconsistently) supported in various
+ implementations of unix mmap, or distinguish reserving from
+ committing memory. Instead, we just ask for space, and exploit
+ contiguity when we get it. It is probably possible to do
+ better than this on some systems, but no general scheme seems
+ to be significantly better.
+
+ Management entails a simpler variant of the consolidation scheme
+ used for chunks to reduce fragmentation -- new adjacent memory is
+ normally prepended or appended to an existing segment. However,
+ there are limitations compared to chunk consolidation that mostly
+ reflect the fact that segment processing is relatively infrequent
+ (occurring only when getting memory from system) and that we
+ don't expect to have huge numbers of segments:
+
+ * Segments are not indexed, so traversal requires linear scans. (It
+ would be possible to index these, but is not worth the extra
+ overhead and complexity for most programs on most platforms.)
+ * New segments are only appended to old ones when holding top-most
+ memory; if they cannot be prepended to others, they are held in
+ different segments.
+
+ Except for the top-most segment of an mstate, each segment record
+ is kept at the tail of its segment. Segments are added by pushing
+ segment records onto the list headed by &mstate.seg for the
+ containing mstate.
+
+ Segment flags control allocation/merge/deallocation policies:
+ * If EXTERN_BIT set, then we did not allocate this segment,
+ and so should not try to deallocate or merge with others.
+ (This currently holds only for the initial segment passed
+ into create_mspace_with_base.)
+ * If USE_MMAP_BIT set, the segment may be merged with
+ other surrounding mmapped segments and trimmed/de-allocated
+ using munmap.
+ * If neither bit is set, then the segment was obtained using
+ MORECORE so can be merged with surrounding MORECORE'd segments
+ and deallocated/trimmed using MORECORE with negative arguments.
+*/
+
+struct malloc_segment {
+ char* base; /* base address */
+ size_t size; /* allocated size */
+ struct malloc_segment* next; /* ptr to next segment */
+ flag_t sflags; /* mmap and extern flag */
+};
+
+#define is_mmapped_segment(S) ((S)->sflags & USE_MMAP_BIT)
+#define is_extern_segment(S) ((S)->sflags & EXTERN_BIT)
+
+typedef struct malloc_segment msegment;
+typedef struct malloc_segment* msegmentptr;
+
+/* ---------------------------- malloc_state ----------------------------- */
+
+/*
+ A malloc_state holds all of the bookkeeping for a space.
+ The main fields are:
+
+ Top
+ The topmost chunk of the currently active segment. Its size is
+ cached in topsize. The actual size of topmost space is
+ topsize+TOP_FOOT_SIZE, which includes space reserved for adding
+ fenceposts and segment records if necessary when getting more
+ space from the system. The size at which to autotrim top is
+ cached from mparams in trim_check, except that it is disabled if
+ an autotrim fails.
+
+ Designated victim (dv)
+ This is the preferred chunk for servicing small requests that
+ don't have exact fits. It is normally the chunk split off most
+ recently to service another small request. Its size is cached in
+ dvsize. The link fields of this chunk are not maintained since it
+ is not kept in a bin.
+
+ SmallBins
+ An array of bin headers for free chunks. These bins hold chunks
+ with sizes less than MIN_LARGE_SIZE bytes. Each bin contains
+ chunks of all the same size, spaced 8 bytes apart. To simplify
+ use in double-linked lists, each bin header acts as a malloc_chunk
+ pointing to the real first node, if it exists (else pointing to
+ itself). This avoids special-casing for headers. But to avoid
+ waste, we allocate only the fd/bk pointers of bins, and then use
+ repositioning tricks to treat these as the fields of a chunk.
+
+ TreeBins
+ Treebins are pointers to the roots of trees holding a range of
+ sizes. There are 2 equally spaced treebins for each power of two
+ from TREE_SHIFT to TREE_SHIFT+16. The last bin holds anything
+ larger.
+
+ Bin maps
+ There is one bit map for small bins ("smallmap") and one for
+ treebins ("treemap). Each bin sets its bit when non-empty, and
+ clears the bit when empty. Bit operations are then used to avoid
+ bin-by-bin searching -- nearly all "search" is done without ever
+ looking at bins that won't be selected. The bit maps
+ conservatively use 32 bits per map word, even if on 64bit system.
+ For a good description of some of the bit-based techniques used
+ here, see Henry S. Warren Jr's book "Hacker's Delight" (and
+ supplement at http://hackersdelight.org/). Many of these are
+ intended to reduce the branchiness of paths through malloc etc, as
+ well as to reduce the number of memory locations read or written.
+
+ Segments
+ A list of segments headed by an embedded malloc_segment record
+ representing the initial space.
+
+ Address check support
+ The least_addr field is the least address ever obtained from
+ MORECORE or MMAP. Attempted frees and reallocs of any address less
+ than this are trapped (unless INSECURE is defined).
+
+ Magic tag
+ A cross-check field that should always hold same value as mparams.magic.
+
+ Max allowed footprint
+ The maximum allowed bytes to allocate from system (zero means no limit)
+
+ Flags
+ Bits recording whether to use MMAP, locks, or contiguous MORECORE
+
+ Statistics
+ Each space keeps track of current and maximum system memory
+ obtained via MORECORE or MMAP.
+
+ Trim support
+ Fields holding the amount of unused topmost memory that should trigger
+ trimming, and a counter to force periodic scanning to release unused
+ non-topmost segments.
+
+ Locking
+ If USE_LOCKS is defined, the "mutex" lock is acquired and released
+ around every public call using this mspace.
+
+ Extension support
+ A void* pointer and a size_t field that can be used to help implement
+ extensions to this malloc.
+*/
+
+/* Bin types, widths and sizes */
+#define NSMALLBINS (32U)
+#define NTREEBINS (32U)
+#define SMALLBIN_SHIFT (3U)
+#define SMALLBIN_WIDTH (SIZE_T_ONE << SMALLBIN_SHIFT)
+#define TREEBIN_SHIFT (8U)
+#define MIN_LARGE_SIZE (SIZE_T_ONE << TREEBIN_SHIFT)
+#define MAX_SMALL_SIZE (MIN_LARGE_SIZE - SIZE_T_ONE)
+#define MAX_SMALL_REQUEST (MAX_SMALL_SIZE - CHUNK_ALIGN_MASK - CHUNK_OVERHEAD)
+
+struct malloc_state {
+ binmap_t smallmap;
+ binmap_t treemap;
+ size_t dvsize;
+ size_t topsize;
+ char* least_addr;
+ mchunkptr dv;
+ mchunkptr top;
+ size_t trim_check;
+ size_t release_checks;
+ size_t magic;
+ mchunkptr smallbins[(NSMALLBINS+1)*2];
+ tbinptr treebins[NTREEBINS];
+ size_t footprint;
+ size_t max_footprint;
+ size_t footprint_limit; /* zero means no limit */
+ flag_t mflags;
+#if USE_LOCKS
+ MLOCK_T mutex; /* locate lock among fields that rarely change */
+#endif /* USE_LOCKS */
+ msegment seg;
+ void* extp; /* Unused but available for extensions */
+ size_t exts;
+};
+
+typedef struct malloc_state* mstate;
+
+/* ------------- Global malloc_state and malloc_params ------------------- */
+
+/*
+ malloc_params holds global properties, including those that can be
+ dynamically set using mallopt. There is a single instance, mparams,
+ initialized in init_mparams. Note that the non-zeroness of "magic"
+ also serves as an initialization flag.
+*/
+
+struct malloc_params {
+ size_t magic;
+ size_t page_size;
+ size_t granularity;
+ size_t mmap_threshold;
+ size_t trim_threshold;
+ flag_t default_mflags;
+};
+
+static struct malloc_params mparams;
+
+/* Ensure mparams initialized */
+#define ensure_initialization() (void)(mparams.magic != 0 || init_mparams())
+
+#if !ONLY_MSPACES
+
+/* The global malloc_state used for all non-"mspace" calls */
+static struct malloc_state _gm_;
+#define gm (&_gm_)
+#define is_global(M) ((M) == &_gm_)
+
+#endif /* !ONLY_MSPACES */
+
+#define is_initialized(M) ((M)->top != 0)
+
+/* -------------------------- system alloc setup ------------------------- */
+
+/* Operations on mflags */
+
+#define use_lock(M) ((M)->mflags & USE_LOCK_BIT)
+#define enable_lock(M) ((M)->mflags |= USE_LOCK_BIT)
+#if USE_LOCKS
+#define disable_lock(M) ((M)->mflags &= ~USE_LOCK_BIT)
+#else
+#define disable_lock(M)
+#endif
+
+#define use_mmap(M) ((M)->mflags & USE_MMAP_BIT)
+#define enable_mmap(M) ((M)->mflags |= USE_MMAP_BIT)
+#if HAVE_MMAP
+#define disable_mmap(M) ((M)->mflags &= ~USE_MMAP_BIT)
+#else
+#define disable_mmap(M)
+#endif
+
+#define use_noncontiguous(M) ((M)->mflags & USE_NONCONTIGUOUS_BIT)
+#define disable_contiguous(M) ((M)->mflags |= USE_NONCONTIGUOUS_BIT)
+
+#define set_lock(M,L)\
+ ((M)->mflags = (L)?\
+ ((M)->mflags | USE_LOCK_BIT) :\
+ ((M)->mflags & ~USE_LOCK_BIT))
+
+/* page-align a size */
+#define page_align(S)\
+ (((S) + (mparams.page_size - SIZE_T_ONE)) & ~(mparams.page_size - SIZE_T_ONE))
+
+/* granularity-align a size */
+#define granularity_align(S)\
+ (((S) + (mparams.granularity - SIZE_T_ONE))\
+ & ~(mparams.granularity - SIZE_T_ONE))
+
+
+/* For mmap, use granularity alignment on windows, else page-align */
+#ifdef WIN32
+#define mmap_align(S) granularity_align(S)
+#else
+#define mmap_align(S) page_align(S)
+#endif
+
+/* For sys_alloc, enough padding to ensure can malloc request on success */
+#define SYS_ALLOC_PADDING (TOP_FOOT_SIZE + MALLOC_ALIGNMENT)
+
+#define is_page_aligned(S)\
+ (((size_t)(S) & (mparams.page_size - SIZE_T_ONE)) == 0)
+#define is_granularity_aligned(S)\
+ (((size_t)(S) & (mparams.granularity - SIZE_T_ONE)) == 0)
+
+/* True if segment S holds address A */
+#define segment_holds(S, A)\
+ ((char*)(A) >= S->base && (char*)(A) < S->base + S->size)
+
+/* Return segment holding given address */
+static msegmentptr segment_holding(mstate m, char* addr) {
+ msegmentptr sp = &m->seg;
+ for (;;) {
+ if (addr >= sp->base && addr < sp->base + sp->size)
+ return sp;
+ if ((sp = sp->next) == 0)
+ return 0;
+ }
+}
+
+/* Return true if segment contains a segment link */
+static int has_segment_link(mstate m, msegmentptr ss) {
+ msegmentptr sp = &m->seg;
+ for (;;) {
+ if ((char*)sp >= ss->base && (char*)sp < ss->base + ss->size)
+ return 1;
+ if ((sp = sp->next) == 0)
+ return 0;
+ }
+}
+
+#ifndef MORECORE_CANNOT_TRIM
+#define should_trim(M,s) ((s) > (M)->trim_check)
+#else /* MORECORE_CANNOT_TRIM */
+#define should_trim(M,s) (0)
+#endif /* MORECORE_CANNOT_TRIM */
+
+/*
+ TOP_FOOT_SIZE is padding at the end of a segment, including space
+ that may be needed to place segment records and fenceposts when new
+ noncontiguous segments are added.
+*/
+#define TOP_FOOT_SIZE\
+ (align_offset(chunk2mem(0))+pad_request(sizeof(struct malloc_segment))+MIN_CHUNK_SIZE)
+
+
+/* ------------------------------- Hooks -------------------------------- */
+
+/*
+ PREACTION should be defined to return 0 on success, and nonzero on
+ failure. If you are not using locking, you can redefine these to do
+ anything you like.
+*/
+
+#if USE_LOCKS
+#define PREACTION(M) ((use_lock(M))? ACQUIRE_LOCK(&(M)->mutex) : 0)
+#define POSTACTION(M) { if (use_lock(M)) RELEASE_LOCK(&(M)->mutex); }
+#else /* USE_LOCKS */
+
+#ifndef PREACTION
+#define PREACTION(M) (0)
+#endif /* PREACTION */
+
+#ifndef POSTACTION
+#define POSTACTION(M)
+#endif /* POSTACTION */
+
+#endif /* USE_LOCKS */
+
+/*
+ CORRUPTION_ERROR_ACTION is triggered upon detected bad addresses.
+ USAGE_ERROR_ACTION is triggered on detected bad frees and
+ reallocs. The argument p is an address that might have triggered the
+ fault. It is ignored by the two predefined actions, but might be
+ useful in custom actions that try to help diagnose errors.
+*/
+
+#if PROCEED_ON_ERROR
+
+/* A count of the number of corruption errors causing resets */
+int malloc_corruption_error_count;
+
+/* default corruption action */
+static void reset_on_error(mstate m);
+
+#define CORRUPTION_ERROR_ACTION(m) reset_on_error(m)
+#define USAGE_ERROR_ACTION(m, p)
+
+#else /* PROCEED_ON_ERROR */
+
+#ifndef CORRUPTION_ERROR_ACTION
+#define CORRUPTION_ERROR_ACTION(m) ABORT
+#endif /* CORRUPTION_ERROR_ACTION */
+
+#ifndef USAGE_ERROR_ACTION
+#define USAGE_ERROR_ACTION(m,p) ABORT
+#endif /* USAGE_ERROR_ACTION */
+
+#endif /* PROCEED_ON_ERROR */
+
+
+/* -------------------------- Debugging setup ---------------------------- */
+
+#if ! DEBUG
+
+#define check_free_chunk(M,P)
+#define check_inuse_chunk(M,P)
+#define check_malloced_chunk(M,P,N)
+#define check_mmapped_chunk(M,P)
+#define check_malloc_state(M)
+#define check_top_chunk(M,P)
+
+#else /* DEBUG */
+#define check_free_chunk(M,P) do_check_free_chunk(M,P)
+#define check_inuse_chunk(M,P) do_check_inuse_chunk(M,P)
+#define check_top_chunk(M,P) do_check_top_chunk(M,P)
+#define check_malloced_chunk(M,P,N) do_check_malloced_chunk(M,P,N)
+#define check_mmapped_chunk(M,P) do_check_mmapped_chunk(M,P)
+#define check_malloc_state(M) do_check_malloc_state(M)
+
+static void do_check_any_chunk(mstate m, mchunkptr p);
+static void do_check_top_chunk(mstate m, mchunkptr p);
+static void do_check_mmapped_chunk(mstate m, mchunkptr p);
+static void do_check_inuse_chunk(mstate m, mchunkptr p);
+static void do_check_free_chunk(mstate m, mchunkptr p);
+static void do_check_malloced_chunk(mstate m, void* mem, size_t s);
+static void do_check_tree(mstate m, tchunkptr t);
+static void do_check_treebin(mstate m, bindex_t i);
+static void do_check_smallbin(mstate m, bindex_t i);
+static void do_check_malloc_state(mstate m);
+static int bin_find(mstate m, mchunkptr x);
+static size_t traverse_and_check(mstate m);
+#endif /* DEBUG */
+
+/* ---------------------------- Indexing Bins ---------------------------- */
+
+#define is_small(s) (((s) >> SMALLBIN_SHIFT) < NSMALLBINS)
+#define small_index(s) (bindex_t)((s) >> SMALLBIN_SHIFT)
+#define small_index2size(i) ((i) << SMALLBIN_SHIFT)
+#define MIN_SMALL_INDEX (small_index(MIN_CHUNK_SIZE))
+
+/* addressing by index. See above about smallbin repositioning */
+#define smallbin_at(M, i) ((sbinptr)((char*)&((M)->smallbins[(i)<<1])))
+#define treebin_at(M,i) (&((M)->treebins[i]))
+
+/* assign tree index for size S to variable I. Use x86 asm if possible */
+#if defined(__GNUC__) && (defined(__i386__) || defined(__x86_64__))
+#define compute_tree_index(S, I)\
+{\
+ unsigned int X = S >> TREEBIN_SHIFT;\
+ if (X == 0)\
+ I = 0;\
+ else if (X > 0xFFFF)\
+ I = NTREEBINS-1;\
+ else {\
+ unsigned int K = (unsigned) sizeof(X)*__CHAR_BIT__ - 1 - (unsigned) __builtin_clz(X); \
+ I = (bindex_t)((K << 1) + ((S >> (K + (TREEBIN_SHIFT-1)) & 1)));\
+ }\
+}
+
+#elif defined (__INTEL_COMPILER)
+#define compute_tree_index(S, I)\
+{\
+ size_t X = S >> TREEBIN_SHIFT;\
+ if (X == 0)\
+ I = 0;\
+ else if (X > 0xFFFF)\
+ I = NTREEBINS-1;\
+ else {\
+ unsigned int K = _bit_scan_reverse (X); \
+ I = (bindex_t)((K << 1) + ((S >> (K + (TREEBIN_SHIFT-1)) & 1)));\
+ }\
+}
+
+#elif defined(_MSC_VER) && _MSC_VER>=1300
+#define compute_tree_index(S, I)\
+{\
+ size_t X = S >> TREEBIN_SHIFT;\
+ if (X == 0)\
+ I = 0;\
+ else if (X > 0xFFFF)\
+ I = NTREEBINS-1;\
+ else {\
+ unsigned int K;\
+ _BitScanReverse((DWORD *) &K, (DWORD) X);\
+ I = (bindex_t)((K << 1) + ((S >> (K + (TREEBIN_SHIFT-1)) & 1)));\
+ }\
+}
+
+#else /* GNUC */
+#define compute_tree_index(S, I)\
+{\
+ size_t X = S >> TREEBIN_SHIFT;\
+ if (X == 0)\
+ I = 0;\
+ else if (X > 0xFFFF)\
+ I = NTREEBINS-1;\
+ else {\
+ unsigned int Y = (unsigned int)X;\
+ unsigned int N = ((Y - 0x100) >> 16) & 8;\
+ unsigned int K = (((Y <<= N) - 0x1000) >> 16) & 4;\
+ N += K;\
+ N += K = (((Y <<= K) - 0x4000) >> 16) & 2;\
+ K = 14 - N + ((Y <<= K) >> 15);\
+ I = (K << 1) + ((S >> (K + (TREEBIN_SHIFT-1)) & 1));\
+ }\
+}
+#endif /* GNUC */
+
+/* Bit representing maximum resolved size in a treebin at i */
+#define bit_for_tree_index(i) \
+ (i == NTREEBINS-1)? (SIZE_T_BITSIZE-1) : (((i) >> 1) + TREEBIN_SHIFT - 2)
+
+/* Shift placing maximum resolved bit in a treebin at i as sign bit */
+#define leftshift_for_tree_index(i) \
+ ((i == NTREEBINS-1)? 0 : \
+ ((SIZE_T_BITSIZE-SIZE_T_ONE) - (((i) >> 1) + TREEBIN_SHIFT - 2)))
+
+/* The size of the smallest chunk held in bin with index i */
+#define minsize_for_tree_index(i) \
+ ((SIZE_T_ONE << (((i) >> 1) + TREEBIN_SHIFT)) | \
+ (((size_t)((i) & SIZE_T_ONE)) << (((i) >> 1) + TREEBIN_SHIFT - 1)))
+
+
+/* ------------------------ Operations on bin maps ----------------------- */
+
+/* bit corresponding to given index */
+#define idx2bit(i) ((binmap_t)(1) << (i))
+
+/* Mark/Clear bits with given index */
+#define mark_smallmap(M,i) ((M)->smallmap |= idx2bit(i))
+#define clear_smallmap(M,i) ((M)->smallmap &= ~idx2bit(i))
+#define smallmap_is_marked(M,i) ((M)->smallmap & idx2bit(i))
+
+#define mark_treemap(M,i) ((M)->treemap |= idx2bit(i))
+#define clear_treemap(M,i) ((M)->treemap &= ~idx2bit(i))
+#define treemap_is_marked(M,i) ((M)->treemap & idx2bit(i))
+
+/* isolate the least set bit of a bitmap */
+#define least_bit(x) ((x) & -(x))
+
+/* mask with all bits to left of least bit of x on */
+#define left_bits(x) ((x<<1) | -(x<<1))
+
+/* mask with all bits to left of or equal to least bit of x on */
+#define same_or_left_bits(x) ((x) | -(x))
+
+/* index corresponding to given bit. Use x86 asm if possible */
+
+#if defined(__GNUC__) && (defined(__i386__) || defined(__x86_64__))
+#define compute_bit2idx(X, I)\
+{\
+ unsigned int J;\
+ J = __builtin_ctz(X); \
+ I = (bindex_t)J;\
+}
+
+#elif defined (__INTEL_COMPILER)
+#define compute_bit2idx(X, I)\
+{\
+ unsigned int J;\
+ J = _bit_scan_forward (X); \
+ I = (bindex_t)J;\
+}
+
+#elif defined(_MSC_VER) && _MSC_VER>=1300
+#define compute_bit2idx(X, I)\
+{\
+ unsigned int J;\
+ _BitScanForward((DWORD *) &J, X);\
+ I = (bindex_t)J;\
+}
+
+#elif USE_BUILTIN_FFS
+#define compute_bit2idx(X, I) I = ffs(X)-1
+
+#else
+#define compute_bit2idx(X, I)\
+{\
+ unsigned int Y = X - 1;\
+ unsigned int K = Y >> (16-4) & 16;\
+ unsigned int N = K; Y >>= K;\
+ N += K = Y >> (8-3) & 8; Y >>= K;\
+ N += K = Y >> (4-2) & 4; Y >>= K;\
+ N += K = Y >> (2-1) & 2; Y >>= K;\
+ N += K = Y >> (1-0) & 1; Y >>= K;\
+ I = (bindex_t)(N + Y);\
+}
+#endif /* GNUC */
+
+
+/* ----------------------- Runtime Check Support ------------------------- */
+
+/*
+ For security, the main invariant is that malloc/free/etc never
+ writes to a static address other than malloc_state, unless static
+ malloc_state itself has been corrupted, which cannot occur via
+ malloc (because of these checks). In essence this means that we
+ believe all pointers, sizes, maps etc held in malloc_state, but
+ check all of those linked or offsetted from other embedded data
+ structures. These checks are interspersed with main code in a way
+ that tends to minimize their run-time cost.
+
+ When FOOTERS is defined, in addition to range checking, we also
+ verify footer fields of inuse chunks, which can be used guarantee
+ that the mstate controlling malloc/free is intact. This is a
+ streamlined version of the approach described by William Robertson
+ et al in "Run-time Detection of Heap-based Overflows" LISA'03
+ http://www.usenix.org/events/lisa03/tech/robertson.html The footer
+ of an inuse chunk holds the xor of its mstate and a random seed,
+ that is checked upon calls to free() and realloc(). This is
+ (probabalistically) unguessable from outside the program, but can be
+ computed by any code successfully malloc'ing any chunk, so does not
+ itself provide protection against code that has already broken
+ security through some other means. Unlike Robertson et al, we
+ always dynamically check addresses of all offset chunks (previous,
+ next, etc). This turns out to be cheaper than relying on hashes.
+*/
+
+#if !INSECURE
+/* Check if address a is at least as high as any from MORECORE or MMAP */
+#define ok_address(M, a) ((char*)(a) >= (M)->least_addr)
+/* Check if address of next chunk n is higher than base chunk p */
+#define ok_next(p, n) ((char*)(p) < (char*)(n))
+/* Check if p has inuse status */
+#define ok_inuse(p) is_inuse(p)
+/* Check if p has its pinuse bit on */
+#define ok_pinuse(p) pinuse(p)
+
+#else /* !INSECURE */
+#define ok_address(M, a) (1)
+#define ok_next(b, n) (1)
+#define ok_inuse(p) (1)
+#define ok_pinuse(p) (1)
+#endif /* !INSECURE */
+
+#if (FOOTERS && !INSECURE)
+/* Check if (alleged) mstate m has expected magic field */
+#define ok_magic(M) ((M)->magic == mparams.magic)
+#else /* (FOOTERS && !INSECURE) */
+#define ok_magic(M) (1)
+#endif /* (FOOTERS && !INSECURE) */
+
+/* In gcc, use __builtin_expect to minimize impact of checks */
+#if !INSECURE
+#if defined(__GNUC__) && __GNUC__ >= 3
+#define RTCHECK(e) __builtin_expect(e, 1)
+#else /* GNUC */
+#define RTCHECK(e) (e)
+#endif /* GNUC */
+#else /* !INSECURE */
+#define RTCHECK(e) (1)
+#endif /* !INSECURE */
+
+/* macros to set up inuse chunks with or without footers */
+
+#if !FOOTERS
+
+#define mark_inuse_foot(M,p,s)
+
+/* Macros for setting head/foot of non-mmapped chunks */
+
+/* Set cinuse bit and pinuse bit of next chunk */
+#define set_inuse(M,p,s)\
+ ((p)->head = (((p)->head & PINUSE_BIT)|s|CINUSE_BIT),\
+ ((mchunkptr)(((char*)(p)) + (s)))->head |= PINUSE_BIT)
+
+/* Set cinuse and pinuse of this chunk and pinuse of next chunk */
+#define set_inuse_and_pinuse(M,p,s)\
+ ((p)->head = (s|PINUSE_BIT|CINUSE_BIT),\
+ ((mchunkptr)(((char*)(p)) + (s)))->head |= PINUSE_BIT)
+
+/* Set size, cinuse and pinuse bit of this chunk */
+#define set_size_and_pinuse_of_inuse_chunk(M, p, s)\
+ ((p)->head = (s|PINUSE_BIT|CINUSE_BIT))
+
+#else /* FOOTERS */
+
+/* Set foot of inuse chunk to be xor of mstate and seed */
+#define mark_inuse_foot(M,p,s)\
+ (((mchunkptr)((char*)(p) + (s)))->prev_foot = ((size_t)(M) ^ mparams.magic))
+
+#define get_mstate_for(p)\
+ ((mstate)(((mchunkptr)((char*)(p) +\
+ (chunksize(p))))->prev_foot ^ mparams.magic))
+
+#define set_inuse(M,p,s)\
+ ((p)->head = (((p)->head & PINUSE_BIT)|s|CINUSE_BIT),\
+ (((mchunkptr)(((char*)(p)) + (s)))->head |= PINUSE_BIT), \
+ mark_inuse_foot(M,p,s))
+
+#define set_inuse_and_pinuse(M,p,s)\
+ ((p)->head = (s|PINUSE_BIT|CINUSE_BIT),\
+ (((mchunkptr)(((char*)(p)) + (s)))->head |= PINUSE_BIT),\
+ mark_inuse_foot(M,p,s))
+
+#define set_size_and_pinuse_of_inuse_chunk(M, p, s)\
+ ((p)->head = (s|PINUSE_BIT|CINUSE_BIT),\
+ mark_inuse_foot(M, p, s))
+
+#endif /* !FOOTERS */
+
+/* ---------------------------- setting mparams -------------------------- */
+
+#if LOCK_AT_FORK
+static void pre_fork(void) { ACQUIRE_LOCK(&(gm)->mutex); }
+static void post_fork_parent(void) { RELEASE_LOCK(&(gm)->mutex); }
+static void post_fork_child(void) { INITIAL_LOCK(&(gm)->mutex); }
+#endif /* LOCK_AT_FORK */
+
+/* Initialize mparams */
+static int init_mparams(void) {
+#ifdef NEED_GLOBAL_LOCK_INIT
+ if (malloc_global_mutex_status <= 0)
+ init_malloc_global_mutex();
+#endif
+
+ ACQUIRE_MALLOC_GLOBAL_LOCK();
+ if (mparams.magic == 0) {
+ size_t magic;
+ size_t psize;
+ size_t gsize;
+
+#ifndef WIN32
+ psize = malloc_getpagesize;
+ gsize = ((DEFAULT_GRANULARITY != 0)? DEFAULT_GRANULARITY : psize);
+#else /* WIN32 */
+ {
+ SYSTEM_INFO system_info;
+ GetSystemInfo(&system_info);
+ psize = system_info.dwPageSize;
+ gsize = ((DEFAULT_GRANULARITY != 0)?
+ DEFAULT_GRANULARITY : system_info.dwAllocationGranularity);
+ }
+#endif /* WIN32 */
+
+ /* Sanity-check configuration:
+ size_t must be unsigned and as wide as pointer type.
+ ints must be at least 4 bytes.
+ alignment must be at least 8.
+ Alignment, min chunk size, and page size must all be powers of 2.
+ */
+ if ((sizeof(size_t) != sizeof(char*)) ||
+ (MAX_SIZE_T < MIN_CHUNK_SIZE) ||
+ (sizeof(int) < 4) ||
+ (MALLOC_ALIGNMENT < (size_t)8U) ||
+ ((MALLOC_ALIGNMENT & (MALLOC_ALIGNMENT-SIZE_T_ONE)) != 0) ||
+ ((MCHUNK_SIZE & (MCHUNK_SIZE-SIZE_T_ONE)) != 0) ||
+ ((gsize & (gsize-SIZE_T_ONE)) != 0) ||
+ ((psize & (psize-SIZE_T_ONE)) != 0))
+ ABORT;
+ mparams.granularity = gsize;
+ mparams.page_size = psize;
+ mparams.mmap_threshold = DEFAULT_MMAP_THRESHOLD;
+ mparams.trim_threshold = DEFAULT_TRIM_THRESHOLD;
+#if MORECORE_CONTIGUOUS
+ mparams.default_mflags = USE_LOCK_BIT|USE_MMAP_BIT;
+#else /* MORECORE_CONTIGUOUS */
+ mparams.default_mflags = USE_LOCK_BIT|USE_MMAP_BIT|USE_NONCONTIGUOUS_BIT;
+#endif /* MORECORE_CONTIGUOUS */
+
+#if !ONLY_MSPACES
+ /* Set up lock for main malloc area */
+ gm->mflags = mparams.default_mflags;
+ (void)INITIAL_LOCK(&gm->mutex);
+#endif
+#if LOCK_AT_FORK
+ pthread_atfork(&pre_fork, &post_fork_parent, &post_fork_child);
+#endif
+
+ {
+#if USE_DEV_RANDOM
+ int fd;
+ unsigned char buf[sizeof(size_t)];
+ /* Try to use /dev/urandom, else fall back on using time */
+ if ((fd = open("/dev/urandom", O_RDONLY)) >= 0 &&
+ read(fd, buf, sizeof(buf)) == sizeof(buf)) {
+ magic = *((size_t *) buf);
+ close(fd);
+ }
+ else
+#endif /* USE_DEV_RANDOM */
+#ifdef WIN32
+ magic = (size_t)(GetTickCount() ^ (size_t)0x55555555U);
+#elif defined(LACKS_TIME_H)
+ magic = (size_t)&magic ^ (size_t)0x55555555U;
+#else
+ magic = (size_t)(time(0) ^ (size_t)0x55555555U);
+#endif
+ magic |= (size_t)8U; /* ensure nonzero */
+ magic &= ~(size_t)7U; /* improve chances of fault for bad values */
+ /* Until memory modes commonly available, use volatile-write */
+ (*(volatile size_t *)(&(mparams.magic))) = magic;
+ }
+ }
+
+ RELEASE_MALLOC_GLOBAL_LOCK();
+ return 1;
+}
+
+/* support for mallopt */
+static int change_mparam(int param_number, int value) {
+ size_t val;
+ ensure_initialization();
+ val = (value == -1)? MAX_SIZE_T : (size_t)value;
+ switch(param_number) {
+ case M_TRIM_THRESHOLD:
+ mparams.trim_threshold = val;
+ return 1;
+ case M_GRANULARITY:
+ if (val >= mparams.page_size && ((val & (val-1)) == 0)) {
+ mparams.granularity = val;
+ return 1;
+ }
+ else
+ return 0;
+ case M_MMAP_THRESHOLD:
+ mparams.mmap_threshold = val;
+ return 1;
+ default:
+ return 0;
+ }
+}
+
+#if DEBUG
+/* ------------------------- Debugging Support --------------------------- */
+
+/* Check properties of any chunk, whether free, inuse, mmapped etc */
+static void do_check_any_chunk(mstate m, mchunkptr p) {
+ assert((is_aligned(chunk2mem(p))) || (p->head == FENCEPOST_HEAD));
+ assert(ok_address(m, p));
+}
+
+/* Check properties of top chunk */
+static void do_check_top_chunk(mstate m, mchunkptr p) {
+ msegmentptr sp = segment_holding(m, (char*)p);
+ size_t sz = p->head & ~INUSE_BITS; /* third-lowest bit can be set! */
+ assert(sp != 0);
+ assert((is_aligned(chunk2mem(p))) || (p->head == FENCEPOST_HEAD));
+ assert(ok_address(m, p));
+ assert(sz == m->topsize);
+ assert(sz > 0);
+ assert(sz == ((sp->base + sp->size) - (char*)p) - TOP_FOOT_SIZE);
+ assert(pinuse(p));
+ assert(!pinuse(chunk_plus_offset(p, sz)));
+}
+
+/* Check properties of (inuse) mmapped chunks */
+static void do_check_mmapped_chunk(mstate m, mchunkptr p) {
+ size_t sz = chunksize(p);
+ size_t len = (sz + (p->prev_foot) + MMAP_FOOT_PAD);
+ assert(is_mmapped(p));
+ assert(use_mmap(m));
+ assert((is_aligned(chunk2mem(p))) || (p->head == FENCEPOST_HEAD));
+ assert(ok_address(m, p));
+ assert(!is_small(sz));
+ assert((len & (mparams.page_size-SIZE_T_ONE)) == 0);
+ assert(chunk_plus_offset(p, sz)->head == FENCEPOST_HEAD);
+ assert(chunk_plus_offset(p, sz+SIZE_T_SIZE)->head == 0);
+}
+
+/* Check properties of inuse chunks */
+static void do_check_inuse_chunk(mstate m, mchunkptr p) {
+ do_check_any_chunk(m, p);
+ assert(is_inuse(p));
+ assert(next_pinuse(p));
+ /* If not pinuse and not mmapped, previous chunk has OK offset */
+ assert(is_mmapped(p) || pinuse(p) || next_chunk(prev_chunk(p)) == p);
+ if (is_mmapped(p))
+ do_check_mmapped_chunk(m, p);
+}
+
+/* Check properties of free chunks */
+static void do_check_free_chunk(mstate m, mchunkptr p) {
+ size_t sz = chunksize(p);
+ mchunkptr next = chunk_plus_offset(p, sz);
+ do_check_any_chunk(m, p);
+ assert(!is_inuse(p));
+ assert(!next_pinuse(p));
+ assert (!is_mmapped(p));
+ if (p != m->dv && p != m->top) {
+ if (sz >= MIN_CHUNK_SIZE) {
+ assert((sz & CHUNK_ALIGN_MASK) == 0);
+ assert(is_aligned(chunk2mem(p)));
+ assert(next->prev_foot == sz);
+ assert(pinuse(p));
+ assert (next == m->top || is_inuse(next));
+ assert(p->fd->bk == p);
+ assert(p->bk->fd == p);
+ }
+ else /* markers are always of size SIZE_T_SIZE */
+ assert(sz == SIZE_T_SIZE);
+ }
+}
+
+/* Check properties of malloced chunks at the point they are malloced */
+static void do_check_malloced_chunk(mstate m, void* mem, size_t s) {
+ if (mem != 0) {
+ mchunkptr p = mem2chunk(mem);
+ size_t sz = p->head & ~INUSE_BITS;
+ do_check_inuse_chunk(m, p);
+ assert((sz & CHUNK_ALIGN_MASK) == 0);
+ assert(sz >= MIN_CHUNK_SIZE);
+ assert(sz >= s);
+ /* unless mmapped, size is less than MIN_CHUNK_SIZE more than request */
+ assert(is_mmapped(p) || sz < (s + MIN_CHUNK_SIZE));
+ }
+}
+
+/* Check a tree and its subtrees. */
+static void do_check_tree(mstate m, tchunkptr t) {
+ tchunkptr head = 0;
+ tchunkptr u = t;
+ bindex_t tindex = t->index;
+ size_t tsize = chunksize(t);
+ bindex_t idx;
+ compute_tree_index(tsize, idx);
+ assert(tindex == idx);
+ assert(tsize >= MIN_LARGE_SIZE);
+ assert(tsize >= minsize_for_tree_index(idx));
+ assert((idx == NTREEBINS-1) || (tsize < minsize_for_tree_index((idx+1))));
+
+ do { /* traverse through chain of same-sized nodes */
+ do_check_any_chunk(m, ((mchunkptr)u));
+ assert(u->index == tindex);
+ assert(chunksize(u) == tsize);
+ assert(!is_inuse(u));
+ assert(!next_pinuse(u));
+ assert(u->fd->bk == u);
+ assert(u->bk->fd == u);
+ if (u->parent == 0) {
+ assert(u->child[0] == 0);
+ assert(u->child[1] == 0);
+ }
+ else {
+ assert(head == 0); /* only one node on chain has parent */
+ head = u;
+ assert(u->parent != u);
+ assert (u->parent->child[0] == u ||
+ u->parent->child[1] == u ||
+ *((tbinptr*)(u->parent)) == u);
+ if (u->child[0] != 0) {
+ assert(u->child[0]->parent == u);
+ assert(u->child[0] != u);
+ do_check_tree(m, u->child[0]);
+ }
+ if (u->child[1] != 0) {
+ assert(u->child[1]->parent == u);
+ assert(u->child[1] != u);
+ do_check_tree(m, u->child[1]);
+ }
+ if (u->child[0] != 0 && u->child[1] != 0) {
+ assert(chunksize(u->child[0]) < chunksize(u->child[1]));
+ }
+ }
+ u = u->fd;
+ } while (u != t);
+ assert(head != 0);
+}
+
+/* Check all the chunks in a treebin. */
+static void do_check_treebin(mstate m, bindex_t i) {
+ tbinptr* tb = treebin_at(m, i);
+ tchunkptr t = *tb;
+ int empty = (m->treemap & (1U << i)) == 0;
+ if (t == 0)
+ assert(empty);
+ if (!empty)
+ do_check_tree(m, t);
+}
+
+/* Check all the chunks in a smallbin. */
+static void do_check_smallbin(mstate m, bindex_t i) {
+ sbinptr b = smallbin_at(m, i);
+ mchunkptr p = b->bk;
+ unsigned int empty = (m->smallmap & (1U << i)) == 0;
+ if (p == b)
+ assert(empty);
+ if (!empty) {
+ for (; p != b; p = p->bk) {
+ size_t size = chunksize(p);
+ mchunkptr q;
+ /* each chunk claims to be free */
+ do_check_free_chunk(m, p);
+ /* chunk belongs in bin */
+ assert(small_index(size) == i);
+ assert(p->bk == b || chunksize(p->bk) == chunksize(p));
+ /* chunk is followed by an inuse chunk */
+ q = next_chunk(p);
+ if (q->head != FENCEPOST_HEAD)
+ do_check_inuse_chunk(m, q);
+ }
+ }
+}
+
+/* Find x in a bin. Used in other check functions. */
+static int bin_find(mstate m, mchunkptr x) {
+ size_t size = chunksize(x);
+ if (is_small(size)) {
+ bindex_t sidx = small_index(size);
+ sbinptr b = smallbin_at(m, sidx);
+ if (smallmap_is_marked(m, sidx)) {
+ mchunkptr p = b;
+ do {
+ if (p == x)
+ return 1;
+ } while ((p = p->fd) != b);
+ }
+ }
+ else {
+ bindex_t tidx;
+ compute_tree_index(size, tidx);
+ if (treemap_is_marked(m, tidx)) {
+ tchunkptr t = *treebin_at(m, tidx);
+ size_t sizebits = size << leftshift_for_tree_index(tidx);
+ while (t != 0 && chunksize(t) != size) {
+ t = t->child[(sizebits >> (SIZE_T_BITSIZE-SIZE_T_ONE)) & 1];
+ sizebits <<= 1;
+ }
+ if (t != 0) {
+ tchunkptr u = t;
+ do {
+ if (u == (tchunkptr)x)
+ return 1;
+ } while ((u = u->fd) != t);
+ }
+ }
+ }
+ return 0;
+}
+
+/* Traverse each chunk and check it; return total */
+static size_t traverse_and_check(mstate m) {
+ size_t sum = 0;
+ if (is_initialized(m)) {
+ msegmentptr s = &m->seg;
+ sum += m->topsize + TOP_FOOT_SIZE;
+ while (s != 0) {
+ mchunkptr q = align_as_chunk(s->base);
+ mchunkptr lastq = 0;
+ assert(pinuse(q));
+ while (segment_holds(s, q) &&
+ q != m->top && q->head != FENCEPOST_HEAD) {
+ sum += chunksize(q);
+ if (is_inuse(q)) {
+ assert(!bin_find(m, q));
+ do_check_inuse_chunk(m, q);
+ }
+ else {
+ assert(q == m->dv || bin_find(m, q));
+ assert(lastq == 0 || is_inuse(lastq)); /* Not 2 consecutive free */
+ do_check_free_chunk(m, q);
+ }
+ lastq = q;
+ q = next_chunk(q);
+ }
+ s = s->next;
+ }
+ }
+ return sum;
+}
+
+
+/* Check all properties of malloc_state. */
+static void do_check_malloc_state(mstate m) {
+ bindex_t i;
+ size_t total;
+ /* check bins */
+ for (i = 0; i < NSMALLBINS; ++i)
+ do_check_smallbin(m, i);
+ for (i = 0; i < NTREEBINS; ++i)
+ do_check_treebin(m, i);
+
+ if (m->dvsize != 0) { /* check dv chunk */
+ do_check_any_chunk(m, m->dv);
+ assert(m->dvsize == chunksize(m->dv));
+ assert(m->dvsize >= MIN_CHUNK_SIZE);
+ assert(bin_find(m, m->dv) == 0);
+ }
+
+ if (m->top != 0) { /* check top chunk */
+ do_check_top_chunk(m, m->top);
+ /*assert(m->topsize == chunksize(m->top)); redundant */
+ assert(m->topsize > 0);
+ assert(bin_find(m, m->top) == 0);
+ }
+
+ total = traverse_and_check(m);
+ assert(total <= m->footprint);
+ assert(m->footprint <= m->max_footprint);
+}
+#endif /* DEBUG */
+
+/* ----------------------------- statistics ------------------------------ */
+
+#if !NO_MALLINFO
+static struct mallinfo internal_mallinfo(mstate m) {
+ struct mallinfo nm = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
+ ensure_initialization();
+ if (!PREACTION(m)) {
+ check_malloc_state(m);
+ if (is_initialized(m)) {
+ size_t nfree = SIZE_T_ONE; /* top always free */
+ size_t mfree = m->topsize + TOP_FOOT_SIZE;
+ size_t sum = mfree;
+ msegmentptr s = &m->seg;
+ while (s != 0) {
+ mchunkptr q = align_as_chunk(s->base);
+ while (segment_holds(s, q) &&
+ q != m->top && q->head != FENCEPOST_HEAD) {
+ size_t sz = chunksize(q);
+ sum += sz;
+ if (!is_inuse(q)) {
+ mfree += sz;
+ ++nfree;
+ }
+ q = next_chunk(q);
+ }
+ s = s->next;
+ }
+
+ nm.arena = sum;
+ nm.ordblks = nfree;
+ nm.hblkhd = m->footprint - sum;
+ nm.usmblks = m->max_footprint;
+ nm.uordblks = m->footprint - mfree;
+ nm.fordblks = mfree;
+ nm.keepcost = m->topsize;
+ }
+
+ POSTACTION(m);
+ }
+ return nm;
+}
+#endif /* !NO_MALLINFO */
+
+#if !NO_MALLOC_STATS
+static void internal_malloc_stats(mstate m) {
+ ensure_initialization();
+ if (!PREACTION(m)) {
+ size_t maxfp = 0;
+ size_t fp = 0;
+ size_t used = 0;
+ check_malloc_state(m);
+ if (is_initialized(m)) {
+ msegmentptr s = &m->seg;
+ maxfp = m->max_footprint;
+ fp = m->footprint;
+ used = fp - (m->topsize + TOP_FOOT_SIZE);
+
+ while (s != 0) {
+ mchunkptr q = align_as_chunk(s->base);
+ while (segment_holds(s, q) &&
+ q != m->top && q->head != FENCEPOST_HEAD) {
+ if (!is_inuse(q))
+ used -= chunksize(q);
+ q = next_chunk(q);
+ }
+ s = s->next;
+ }
+ }
+ POSTACTION(m); /* drop lock */
+ fprintf(stderr, "max system bytes = %10lu\n", (unsigned long)(maxfp));
+ fprintf(stderr, "system bytes = %10lu\n", (unsigned long)(fp));
+ fprintf(stderr, "in use bytes = %10lu\n", (unsigned long)(used));
+ }
+}
+#endif /* NO_MALLOC_STATS */
+
+/* ----------------------- Operations on smallbins ----------------------- */
+
+/*
+ Various forms of linking and unlinking are defined as macros. Even
+ the ones for trees, which are very long but have very short typical
+ paths. This is ugly but reduces reliance on inlining support of
+ compilers.
+*/
+
+/* Link a free chunk into a smallbin */
+#define insert_small_chunk(M, P, S) {\
+ bindex_t I = small_index(S);\
+ mchunkptr B = smallbin_at(M, I);\
+ mchunkptr F = B;\
+ assert(S >= MIN_CHUNK_SIZE);\
+ if (!smallmap_is_marked(M, I))\
+ mark_smallmap(M, I);\
+ else if (RTCHECK(ok_address(M, B->fd)))\
+ F = B->fd;\
+ else {\
+ CORRUPTION_ERROR_ACTION(M);\
+ }\
+ B->fd = P;\
+ F->bk = P;\
+ P->fd = F;\
+ P->bk = B;\
+}
+
+/* Unlink a chunk from a smallbin */
+#define unlink_small_chunk(M, P, S) {\
+ mchunkptr F = P->fd;\
+ mchunkptr B = P->bk;\
+ bindex_t I = small_index(S);\
+ assert(P != B);\
+ assert(P != F);\
+ assert(chunksize(P) == small_index2size(I));\
+ if (RTCHECK(F == smallbin_at(M,I) || (ok_address(M, F) && F->bk == P))) { \
+ if (B == F) {\
+ clear_smallmap(M, I);\
+ }\
+ else if (RTCHECK(B == smallbin_at(M,I) ||\
+ (ok_address(M, B) && B->fd == P))) {\
+ F->bk = B;\
+ B->fd = F;\
+ }\
+ else {\
+ CORRUPTION_ERROR_ACTION(M);\
+ }\
+ }\
+ else {\
+ CORRUPTION_ERROR_ACTION(M);\
+ }\
+}
+
+/* Unlink the first chunk from a smallbin */
+#define unlink_first_small_chunk(M, B, P, I) {\
+ mchunkptr F = P->fd;\
+ assert(P != B);\
+ assert(P != F);\
+ assert(chunksize(P) == small_index2size(I));\
+ if (B == F) {\
+ clear_smallmap(M, I);\
+ }\
+ else if (RTCHECK(ok_address(M, F) && F->bk == P)) {\
+ F->bk = B;\
+ B->fd = F;\
+ }\
+ else {\
+ CORRUPTION_ERROR_ACTION(M);\
+ }\
+}
+
+/* Replace dv node, binning the old one */
+/* Used only when dvsize known to be small */
+#define replace_dv(M, P, S) {\
+ size_t DVS = M->dvsize;\
+ assert(is_small(DVS));\
+ if (DVS != 0) {\
+ mchunkptr DV = M->dv;\
+ insert_small_chunk(M, DV, DVS);\
+ }\
+ M->dvsize = S;\
+ M->dv = P;\
+}
+
+/* ------------------------- Operations on trees ------------------------- */
+
+/* Insert chunk into tree */
+#define insert_large_chunk(M, X, S) {\
+ tbinptr* H;\
+ bindex_t I;\
+ compute_tree_index(S, I);\
+ H = treebin_at(M, I);\
+ X->index = I;\
+ X->child[0] = X->child[1] = 0;\
+ if (!treemap_is_marked(M, I)) {\
+ mark_treemap(M, I);\
+ *H = X;\
+ X->parent = (tchunkptr)H;\
+ X->fd = X->bk = X;\
+ }\
+ else {\
+ tchunkptr T = *H;\
+ size_t K = S << leftshift_for_tree_index(I);\
+ for (;;) {\
+ if (chunksize(T) != S) {\
+ tchunkptr* C = &(T->child[(K >> (SIZE_T_BITSIZE-SIZE_T_ONE)) & 1]);\
+ K <<= 1;\
+ if (*C != 0)\
+ T = *C;\
+ else if (RTCHECK(ok_address(M, C))) {\
+ *C = X;\
+ X->parent = T;\
+ X->fd = X->bk = X;\
+ break;\
+ }\
+ else {\
+ CORRUPTION_ERROR_ACTION(M);\
+ break;\
+ }\
+ }\
+ else {\
+ tchunkptr F = T->fd;\
+ if (RTCHECK(ok_address(M, T) && ok_address(M, F))) {\
+ T->fd = F->bk = X;\
+ X->fd = F;\
+ X->bk = T;\
+ X->parent = 0;\
+ break;\
+ }\
+ else {\
+ CORRUPTION_ERROR_ACTION(M);\
+ break;\
+ }\
+ }\
+ }\
+ }\
+}
+
+/*
+ Unlink steps:
+
+ 1. If x is a chained node, unlink it from its same-sized fd/bk links
+ and choose its bk node as its replacement.
+ 2. If x was the last node of its size, but not a leaf node, it must
+ be replaced with a leaf node (not merely one with an open left or
+ right), to make sure that lefts and rights of descendents
+ correspond properly to bit masks. We use the rightmost descendent
+ of x. We could use any other leaf, but this is easy to locate and
+ tends to counteract removal of leftmosts elsewhere, and so keeps
+ paths shorter than minimally guaranteed. This doesn't loop much
+ because on average a node in a tree is near the bottom.
+ 3. If x is the base of a chain (i.e., has parent links) relink
+ x's parent and children to x's replacement (or null if none).
+*/
+
+#define unlink_large_chunk(M, X) {\
+ tchunkptr XP = X->parent;\
+ tchunkptr R;\
+ if (X->bk != X) {\
+ tchunkptr F = X->fd;\
+ R = X->bk;\
+ if (RTCHECK(ok_address(M, F) && F->bk == X && R->fd == X)) {\
+ F->bk = R;\
+ R->fd = F;\
+ }\
+ else {\
+ CORRUPTION_ERROR_ACTION(M);\
+ }\
+ }\
+ else {\
+ tchunkptr* RP;\
+ if (((R = *(RP = &(X->child[1]))) != 0) ||\
+ ((R = *(RP = &(X->child[0]))) != 0)) {\
+ tchunkptr* CP;\
+ while ((*(CP = &(R->child[1])) != 0) ||\
+ (*(CP = &(R->child[0])) != 0)) {\
+ R = *(RP = CP);\
+ }\
+ if (RTCHECK(ok_address(M, RP)))\
+ *RP = 0;\
+ else {\
+ CORRUPTION_ERROR_ACTION(M);\
+ }\
+ }\
+ }\
+ if (XP != 0) {\
+ tbinptr* H = treebin_at(M, X->index);\
+ if (X == *H) {\
+ if ((*H = R) == 0) \
+ clear_treemap(M, X->index);\
+ }\
+ else if (RTCHECK(ok_address(M, XP))) {\
+ if (XP->child[0] == X) \
+ XP->child[0] = R;\
+ else \
+ XP->child[1] = R;\
+ }\
+ else\
+ CORRUPTION_ERROR_ACTION(M);\
+ if (R != 0) {\
+ if (RTCHECK(ok_address(M, R))) {\
+ tchunkptr C0, C1;\
+ R->parent = XP;\
+ if ((C0 = X->child[0]) != 0) {\
+ if (RTCHECK(ok_address(M, C0))) {\
+ R->child[0] = C0;\
+ C0->parent = R;\
+ }\
+ else\
+ CORRUPTION_ERROR_ACTION(M);\
+ }\
+ if ((C1 = X->child[1]) != 0) {\
+ if (RTCHECK(ok_address(M, C1))) {\
+ R->child[1] = C1;\
+ C1->parent = R;\
+ }\
+ else\
+ CORRUPTION_ERROR_ACTION(M);\
+ }\
+ }\
+ else\
+ CORRUPTION_ERROR_ACTION(M);\
+ }\
+ }\
+}
+
+/* Relays to large vs small bin operations */
+
+#define insert_chunk(M, P, S)\
+ if (is_small(S)) insert_small_chunk(M, P, S)\
+ else { tchunkptr TP = (tchunkptr)(P); insert_large_chunk(M, TP, S); }
+
+#define unlink_chunk(M, P, S)\
+ if (is_small(S)) unlink_small_chunk(M, P, S)\
+ else { tchunkptr TP = (tchunkptr)(P); unlink_large_chunk(M, TP); }
+
+
+/* Relays to internal calls to malloc/free from realloc, memalign etc */
+
+#if ONLY_MSPACES
+#define internal_malloc(m, b) mspace_malloc(m, b)
+#define internal_free(m, mem) mspace_free(m,mem);
+#else /* ONLY_MSPACES */
+#if MSPACES
+#define internal_malloc(m, b)\
+ ((m == gm)? dlmalloc(b) : mspace_malloc(m, b))
+#define internal_free(m, mem)\
+ if (m == gm) dlfree(mem); else mspace_free(m,mem);
+#else /* MSPACES */
+#define internal_malloc(m, b) dlmalloc(b)
+#define internal_free(m, mem) dlfree(mem)
+#endif /* MSPACES */
+#endif /* ONLY_MSPACES */
+
+/* ----------------------- Direct-mmapping chunks ----------------------- */
+
+/*
+ Directly mmapped chunks are set up with an offset to the start of
+ the mmapped region stored in the prev_foot field of the chunk. This
+ allows reconstruction of the required argument to MUNMAP when freed,
+ and also allows adjustment of the returned chunk to meet alignment
+ requirements (especially in memalign).
+*/
+
+/* Malloc using mmap */
+static void* mmap_alloc(mstate m, size_t nb) {
+ size_t mmsize = mmap_align(nb + SIX_SIZE_T_SIZES + CHUNK_ALIGN_MASK);
+ if (m->footprint_limit != 0) {
+ size_t fp = m->footprint + mmsize;
+ if (fp <= m->footprint || fp > m->footprint_limit)
+ return 0;
+ }
+ if (mmsize > nb) { /* Check for wrap around 0 */
+ char* mm = (char*)(CALL_DIRECT_MMAP(mmsize));
+ if (mm != CMFAIL) {
+ size_t offset = align_offset(chunk2mem(mm));
+ size_t psize = mmsize - offset - MMAP_FOOT_PAD;
+ mchunkptr p = (mchunkptr)(mm + offset);
+ p->prev_foot = offset;
+ p->head = psize;
+ mark_inuse_foot(m, p, psize);
+ chunk_plus_offset(p, psize)->head = FENCEPOST_HEAD;
+ chunk_plus_offset(p, psize+SIZE_T_SIZE)->head = 0;
+
+ if (m->least_addr == 0 || mm < m->least_addr)
+ m->least_addr = mm;
+ if ((m->footprint += mmsize) > m->max_footprint)
+ m->max_footprint = m->footprint;
+ assert(is_aligned(chunk2mem(p)));
+ check_mmapped_chunk(m, p);
+ return chunk2mem(p);
+ }
+ }
+ return 0;
+}
+
+/* Realloc using mmap */
+static mchunkptr mmap_resize(mstate m, mchunkptr oldp, size_t nb, int flags) {
+ size_t oldsize = chunksize(oldp);
+ (void)flags; /* placate people compiling -Wunused */
+ if (is_small(nb)) /* Can't shrink mmap regions below small size */
+ return 0;
+ /* Keep old chunk if big enough but not too big */
+ if (oldsize >= nb + SIZE_T_SIZE &&
+ (oldsize - nb) <= (mparams.granularity << 1))
+ return oldp;
+ else {
+ size_t offset = oldp->prev_foot;
+ size_t oldmmsize = oldsize + offset + MMAP_FOOT_PAD;
+ size_t newmmsize = mmap_align(nb + SIX_SIZE_T_SIZES + CHUNK_ALIGN_MASK);
+ char* cp = (char*)CALL_MREMAP((char*)oldp - offset,
+ oldmmsize, newmmsize, flags);
+ if (cp != CMFAIL) {
+ mchunkptr newp = (mchunkptr)(cp + offset);
+ size_t psize = newmmsize - offset - MMAP_FOOT_PAD;
+ newp->head = psize;
+ mark_inuse_foot(m, newp, psize);
+ chunk_plus_offset(newp, psize)->head = FENCEPOST_HEAD;
+ chunk_plus_offset(newp, psize+SIZE_T_SIZE)->head = 0;
+
+ if (cp < m->least_addr)
+ m->least_addr = cp;
+ if ((m->footprint += newmmsize - oldmmsize) > m->max_footprint)
+ m->max_footprint = m->footprint;
+ check_mmapped_chunk(m, newp);
+ return newp;
+ }
+ }
+ return 0;
+}
+
+
+/* -------------------------- mspace management -------------------------- */
+
+/* Initialize top chunk and its size */
+static void init_top(mstate m, mchunkptr p, size_t psize) {
+ /* Ensure alignment */
+ size_t offset = align_offset(chunk2mem(p));
+ p = (mchunkptr)((char*)p + offset);
+ psize -= offset;
+
+ m->top = p;
+ m->topsize = psize;
+ p->head = psize | PINUSE_BIT;
+ /* set size of fake trailing chunk holding overhead space only once */
+ chunk_plus_offset(p, psize)->head = TOP_FOOT_SIZE;
+ m->trim_check = mparams.trim_threshold; /* reset on each update */
+}
+
+/* Initialize bins for a new mstate that is otherwise zeroed out */
+static void init_bins(mstate m) {
+ /* Establish circular links for smallbins */
+ bindex_t i;
+ for (i = 0; i < NSMALLBINS; ++i) {
+ sbinptr bin = smallbin_at(m,i);
+ bin->fd = bin->bk = bin;
+ }
+}
+
+#if PROCEED_ON_ERROR
+
+/* default corruption action */
+static void reset_on_error(mstate m) {
+ int i;
+ ++malloc_corruption_error_count;
+ /* Reinitialize fields to forget about all memory */
+ m->smallmap = m->treemap = 0;
+ m->dvsize = m->topsize = 0;
+ m->seg.base = 0;
+ m->seg.size = 0;
+ m->seg.next = 0;
+ m->top = m->dv = 0;
+ for (i = 0; i < NTREEBINS; ++i)
+ *treebin_at(m, i) = 0;
+ init_bins(m);
+}
+#endif /* PROCEED_ON_ERROR */
+
+/* Allocate chunk and prepend remainder with chunk in successor base. */
+static void* prepend_alloc(mstate m, char* newbase, char* oldbase,
+ size_t nb) {
+ mchunkptr p = align_as_chunk(newbase);
+ mchunkptr oldfirst = align_as_chunk(oldbase);
+ size_t psize = (char*)oldfirst - (char*)p;
+ mchunkptr q = chunk_plus_offset(p, nb);
+ size_t qsize = psize - nb;
+ set_size_and_pinuse_of_inuse_chunk(m, p, nb);
+
+ assert((char*)oldfirst > (char*)q);
+ assert(pinuse(oldfirst));
+ assert(qsize >= MIN_CHUNK_SIZE);
+
+ /* consolidate remainder with first chunk of old base */
+ if (oldfirst == m->top) {
+ size_t tsize = m->topsize += qsize;
+ m->top = q;
+ q->head = tsize | PINUSE_BIT;
+ check_top_chunk(m, q);
+ }
+ else if (oldfirst == m->dv) {
+ size_t dsize = m->dvsize += qsize;
+ m->dv = q;
+ set_size_and_pinuse_of_free_chunk(q, dsize);
+ }
+ else {
+ if (!is_inuse(oldfirst)) {
+ size_t nsize = chunksize(oldfirst);
+ unlink_chunk(m, oldfirst, nsize);
+ oldfirst = chunk_plus_offset(oldfirst, nsize);
+ qsize += nsize;
+ }
+ set_free_with_pinuse(q, qsize, oldfirst);
+ insert_chunk(m, q, qsize);
+ check_free_chunk(m, q);
+ }
+
+ check_malloced_chunk(m, chunk2mem(p), nb);
+ return chunk2mem(p);
+}
+
+/* Add a segment to hold a new noncontiguous region */
+static void add_segment(mstate m, char* tbase, size_t tsize, flag_t mmapped) {
+ /* Determine locations and sizes of segment, fenceposts, old top */
+ char* old_top = (char*)m->top;
+ msegmentptr oldsp = segment_holding(m, old_top);
+ char* old_end = oldsp->base + oldsp->size;
+ size_t ssize = pad_request(sizeof(struct malloc_segment));
+ char* rawsp = old_end - (ssize + FOUR_SIZE_T_SIZES + CHUNK_ALIGN_MASK);
+ size_t offset = align_offset(chunk2mem(rawsp));
+ char* asp = rawsp + offset;
+ char* csp = (asp < (old_top + MIN_CHUNK_SIZE))? old_top : asp;
+ mchunkptr sp = (mchunkptr)csp;
+ msegmentptr ss = (msegmentptr)(chunk2mem(sp));
+ mchunkptr tnext = chunk_plus_offset(sp, ssize);
+ mchunkptr p = tnext;
+ int nfences = 0;
+
+ /* reset top to new space */
+ init_top(m, (mchunkptr)tbase, tsize - TOP_FOOT_SIZE);
+
+ /* Set up segment record */
+ assert(is_aligned(ss));
+ set_size_and_pinuse_of_inuse_chunk(m, sp, ssize);
+ *ss = m->seg; /* Push current record */
+ m->seg.base = tbase;
+ m->seg.size = tsize;
+ m->seg.sflags = mmapped;
+ m->seg.next = ss;
+
+ /* Insert trailing fenceposts */
+ for (;;) {
+ mchunkptr nextp = chunk_plus_offset(p, SIZE_T_SIZE);
+ p->head = FENCEPOST_HEAD;
+ ++nfences;
+ if ((char*)(&(nextp->head)) < old_end)
+ p = nextp;
+ else
+ break;
+ }
+ assert(nfences >= 2);
+
+ /* Insert the rest of old top into a bin as an ordinary free chunk */
+ if (csp != old_top) {
+ mchunkptr q = (mchunkptr)old_top;
+ size_t psize = csp - old_top;
+ mchunkptr tn = chunk_plus_offset(q, psize);
+ set_free_with_pinuse(q, psize, tn);
+ insert_chunk(m, q, psize);
+ }
+
+ check_top_chunk(m, m->top);
+}
+
+/* -------------------------- System allocation -------------------------- */
+
+/* Get memory from system using MORECORE or MMAP */
+static void* sys_alloc(mstate m, size_t nb) {
+ char* tbase = CMFAIL;
+ size_t tsize = 0;
+ flag_t mmap_flag = 0;
+ size_t asize; /* allocation size */
+
+ ensure_initialization();
+
+ /* Directly map large chunks, but only if already initialized */
+ if (use_mmap(m) && nb >= mparams.mmap_threshold && m->topsize != 0) {
+ void* mem = mmap_alloc(m, nb);
+ if (mem != 0)
+ return mem;
+ }
+
+ asize = granularity_align(nb + SYS_ALLOC_PADDING);
+ if (asize <= nb)
+ return 0; /* wraparound */
+ if (m->footprint_limit != 0) {
+ size_t fp = m->footprint + asize;
+ if (fp <= m->footprint || fp > m->footprint_limit)
+ return 0;
+ }
+
+ /*
+ Try getting memory in any of three ways (in most-preferred to
+ least-preferred order):
+ 1. A call to MORECORE that can normally contiguously extend memory.
+ (disabled if not MORECORE_CONTIGUOUS or not HAVE_MORECORE or
+ or main space is mmapped or a previous contiguous call failed)
+ 2. A call to MMAP new space (disabled if not HAVE_MMAP).
+ Note that under the default settings, if MORECORE is unable to
+ fulfill a request, and HAVE_MMAP is true, then mmap is
+ used as a noncontiguous system allocator. This is a useful backup
+ strategy for systems with holes in address spaces -- in this case
+ sbrk cannot contiguously expand the heap, but mmap may be able to
+ find space.
+ 3. A call to MORECORE that cannot usually contiguously extend memory.
+ (disabled if not HAVE_MORECORE)
+
+ In all cases, we need to request enough bytes from system to ensure
+ we can malloc nb bytes upon success, so pad with enough space for
+ top_foot, plus alignment-pad to make sure we don't lose bytes if
+ not on boundary, and round this up to a granularity unit.
+ */
+
+ if (MORECORE_CONTIGUOUS && !use_noncontiguous(m)) {
+ char* br = CMFAIL;
+ size_t ssize = asize; /* sbrk call size */
+ msegmentptr ss = (m->top == 0)? 0 : segment_holding(m, (char*)m->top);
+ ACQUIRE_MALLOC_GLOBAL_LOCK();
+
+ if (ss == 0) { /* First time through or recovery */
+ char* base = (char*)CALL_MORECORE(0);
+ if (base != CMFAIL) {
+ size_t fp;
+ /* Adjust to end on a page boundary */
+ if (!is_page_aligned(base))
+ ssize += (page_align((size_t)base) - (size_t)base);
+ fp = m->footprint + ssize; /* recheck limits */
+ if (ssize > nb && ssize < HALF_MAX_SIZE_T &&
+ (m->footprint_limit == 0 ||
+ (fp > m->footprint && fp <= m->footprint_limit)) &&
+ (br = (char*)(CALL_MORECORE(ssize))) == base) {
+ tbase = base;
+ tsize = ssize;
+ }
+ }
+ }
+ else {
+ /* Subtract out existing available top space from MORECORE request. */
+ ssize = granularity_align(nb - m->topsize + SYS_ALLOC_PADDING);
+ /* Use mem here only if it did continuously extend old space */
+ if (ssize < HALF_MAX_SIZE_T &&
+ (br = (char*)(CALL_MORECORE(ssize))) == ss->base+ss->size) {
+ tbase = br;
+ tsize = ssize;
+ }
+ }
+
+ if (tbase == CMFAIL) { /* Cope with partial failure */
+ if (br != CMFAIL) { /* Try to use/extend the space we did get */
+ if (ssize < HALF_MAX_SIZE_T &&
+ ssize < nb + SYS_ALLOC_PADDING) {
+ size_t esize = granularity_align(nb + SYS_ALLOC_PADDING - ssize);
+ if (esize < HALF_MAX_SIZE_T) {
+ char* end = (char*)CALL_MORECORE(esize);
+ if (end != CMFAIL)
+ ssize += esize;
+ else { /* Can't use; try to release */
+ (void) CALL_MORECORE(-ssize);
+ br = CMFAIL;
+ }
+ }
+ }
+ }
+ if (br != CMFAIL) { /* Use the space we did get */
+ tbase = br;
+ tsize = ssize;
+ }
+ else
+ disable_contiguous(m); /* Don't try contiguous path in the future */
+ }
+
+ RELEASE_MALLOC_GLOBAL_LOCK();
+ }
+
+ if (HAVE_MMAP && tbase == CMFAIL) { /* Try MMAP */
+ char* mp = (char*)(CALL_MMAP(asize));
+ if (mp != CMFAIL) {
+ tbase = mp;
+ tsize = asize;
+ mmap_flag = USE_MMAP_BIT;
+ }
+ }
+
+ if (HAVE_MORECORE && tbase == CMFAIL) { /* Try noncontiguous MORECORE */
+ if (asize < HALF_MAX_SIZE_T) {
+ char* br = CMFAIL;
+ char* end = CMFAIL;
+ ACQUIRE_MALLOC_GLOBAL_LOCK();
+ br = (char*)(CALL_MORECORE(asize));
+ end = (char*)(CALL_MORECORE(0));
+ RELEASE_MALLOC_GLOBAL_LOCK();
+ if (br != CMFAIL && end != CMFAIL && br < end) {
+ size_t ssize = end - br;
+ if (ssize > nb + TOP_FOOT_SIZE) {
+ tbase = br;
+ tsize = ssize;
+ }
+ }
+ }
+ }
+
+ if (tbase != CMFAIL) {
+
+ if ((m->footprint += tsize) > m->max_footprint)
+ m->max_footprint = m->footprint;
+
+ if (!is_initialized(m)) { /* first-time initialization */
+ if (m->least_addr == 0 || tbase < m->least_addr)
+ m->least_addr = tbase;
+ m->seg.base = tbase;
+ m->seg.size = tsize;
+ m->seg.sflags = mmap_flag;
+ m->magic = mparams.magic;
+ m->release_checks = MAX_RELEASE_CHECK_RATE;
+ init_bins(m);
+#if !ONLY_MSPACES
+ if (is_global(m))
+ init_top(m, (mchunkptr)tbase, tsize - TOP_FOOT_SIZE);
+ else
+#endif
+ {
+ /* Offset top by embedded malloc_state */
+ mchunkptr mn = next_chunk(mem2chunk(m));
+ init_top(m, mn, (size_t)((tbase + tsize) - (char*)mn) -TOP_FOOT_SIZE);
+ }
+ }
+
+ else {
+ /* Try to merge with an existing segment */
+ msegmentptr sp = &m->seg;
+ /* Only consider most recent segment if traversal suppressed */
+ while (sp != 0 && tbase != sp->base + sp->size)
+ sp = (NO_SEGMENT_TRAVERSAL) ? 0 : sp->next;
+ if (sp != 0 &&
+ !is_extern_segment(sp) &&
+ (sp->sflags & USE_MMAP_BIT) == mmap_flag &&
+ segment_holds(sp, m->top)) { /* append */
+ sp->size += tsize;
+ init_top(m, m->top, m->topsize + tsize);
+ }
+ else {
+ if (tbase < m->least_addr)
+ m->least_addr = tbase;
+ sp = &m->seg;
+ while (sp != 0 && sp->base != tbase + tsize)
+ sp = (NO_SEGMENT_TRAVERSAL) ? 0 : sp->next;
+ if (sp != 0 &&
+ !is_extern_segment(sp) &&
+ (sp->sflags & USE_MMAP_BIT) == mmap_flag) {
+ char* oldbase = sp->base;
+ sp->base = tbase;
+ sp->size += tsize;
+ return prepend_alloc(m, tbase, oldbase, nb);
+ }
+ else
+ add_segment(m, tbase, tsize, mmap_flag);
+ }
+ }
+
+ if (nb < m->topsize) { /* Allocate from new or extended top space */
+ size_t rsize = m->topsize -= nb;
+ mchunkptr p = m->top;
+ mchunkptr r = m->top = chunk_plus_offset(p, nb);
+ r->head = rsize | PINUSE_BIT;
+ set_size_and_pinuse_of_inuse_chunk(m, p, nb);
+ check_top_chunk(m, m->top);
+ check_malloced_chunk(m, chunk2mem(p), nb);
+ return chunk2mem(p);
+ }
+ }
+
+ MALLOC_FAILURE_ACTION;
+ return 0;
+}
+
+/* ----------------------- system deallocation -------------------------- */
+
+/* Unmap and unlink any mmapped segments that don't contain used chunks */
+static size_t release_unused_segments(mstate m) {
+ size_t released = 0;
+ int nsegs = 0;
+ msegmentptr pred = &m->seg;
+ msegmentptr sp = pred->next;
+ while (sp != 0) {
+ char* base = sp->base;
+ size_t size = sp->size;
+ msegmentptr next = sp->next;
+ ++nsegs;
+ if (is_mmapped_segment(sp) && !is_extern_segment(sp)) {
+ mchunkptr p = align_as_chunk(base);
+ size_t psize = chunksize(p);
+ /* Can unmap if first chunk holds entire segment and not pinned */
+ if (!is_inuse(p) && (char*)p + psize >= base + size - TOP_FOOT_SIZE) {
+ tchunkptr tp = (tchunkptr)p;
+ assert(segment_holds(sp, (char*)sp));
+ if (p == m->dv) {
+ m->dv = 0;
+ m->dvsize = 0;
+ }
+ else {
+ unlink_large_chunk(m, tp);
+ }
+ if (CALL_MUNMAP(base, size) == 0) {
+ released += size;
+ m->footprint -= size;
+ /* unlink obsoleted record */
+ sp = pred;
+ sp->next = next;
+ }
+ else { /* back out if cannot unmap */
+ insert_large_chunk(m, tp, psize);
+ }
+ }
+ }
+ if (NO_SEGMENT_TRAVERSAL) /* scan only first segment */
+ break;
+ pred = sp;
+ sp = next;
+ }
+ /* Reset check counter */
+ m->release_checks = (((size_t) nsegs > (size_t) MAX_RELEASE_CHECK_RATE)?
+ (size_t) nsegs : (size_t) MAX_RELEASE_CHECK_RATE);
+ return released;
+}
+
+static int sys_trim(mstate m, size_t pad) {
+ size_t released = 0;
+ ensure_initialization();
+ if (pad < MAX_REQUEST && is_initialized(m)) {
+ pad += TOP_FOOT_SIZE; /* ensure enough room for segment overhead */
+
+ if (m->topsize > pad) {
+ /* Shrink top space in granularity-size units, keeping at least one */
+ size_t unit = mparams.granularity;
+ size_t extra = ((m->topsize - pad + (unit - SIZE_T_ONE)) / unit -
+ SIZE_T_ONE) * unit;
+ msegmentptr sp = segment_holding(m, (char*)m->top);
+
+ if (!is_extern_segment(sp)) {
+ if (is_mmapped_segment(sp)) {
+ if (HAVE_MMAP &&
+ sp->size >= extra &&
+ !has_segment_link(m, sp)) { /* can't shrink if pinned */
+ size_t newsize = sp->size - extra;
+ (void)newsize; /* placate people compiling -Wunused-variable */
+ /* Prefer mremap, fall back to munmap */
+ if ((CALL_MREMAP(sp->base, sp->size, newsize, 0) != MFAIL) ||
+ (CALL_MUNMAP(sp->base + newsize, extra) == 0)) {
+ released = extra;
+ }
+ }
+ }
+ else if (HAVE_MORECORE) {
+ if (extra >= HALF_MAX_SIZE_T) /* Avoid wrapping negative */
+ extra = (HALF_MAX_SIZE_T) + SIZE_T_ONE - unit;
+ ACQUIRE_MALLOC_GLOBAL_LOCK();
+ {
+ /* Make sure end of memory is where we last set it. */
+ char* old_br = (char*)(CALL_MORECORE(0));
+ if (old_br == sp->base + sp->size) {
+ char* rel_br = (char*)(CALL_MORECORE(-extra));
+ char* new_br = (char*)(CALL_MORECORE(0));
+ if (rel_br != CMFAIL && new_br < old_br)
+ released = old_br - new_br;
+ }
+ }
+ RELEASE_MALLOC_GLOBAL_LOCK();
+ }
+ }
+
+ if (released != 0) {
+ sp->size -= released;
+ m->footprint -= released;
+ init_top(m, m->top, m->topsize - released);
+ check_top_chunk(m, m->top);
+ }
+ }
+
+ /* Unmap any unused mmapped segments */
+ if (HAVE_MMAP)
+ released += release_unused_segments(m);
+
+ /* On failure, disable autotrim to avoid repeated failed future calls */
+ if (released == 0 && m->topsize > m->trim_check)
+ m->trim_check = MAX_SIZE_T;
+ }
+
+ return (released != 0)? 1 : 0;
+}
+
+/* Consolidate and bin a chunk. Differs from exported versions
+ of free mainly in that the chunk need not be marked as inuse.
+*/
+static void dispose_chunk(mstate m, mchunkptr p, size_t psize) {
+ mchunkptr next = chunk_plus_offset(p, psize);
+ if (!pinuse(p)) {
+ mchunkptr prev;
+ size_t prevsize = p->prev_foot;
+ if (is_mmapped(p)) {
+ psize += prevsize + MMAP_FOOT_PAD;
+ if (CALL_MUNMAP((char*)p - prevsize, psize) == 0)
+ m->footprint -= psize;
+ return;
+ }
+ prev = chunk_minus_offset(p, prevsize);
+ psize += prevsize;
+ p = prev;
+ if (RTCHECK(ok_address(m, prev))) { /* consolidate backward */
+ if (p != m->dv) {
+ unlink_chunk(m, p, prevsize);
+ }
+ else if ((next->head & INUSE_BITS) == INUSE_BITS) {
+ m->dvsize = psize;
+ set_free_with_pinuse(p, psize, next);
+ return;
+ }
+ }
+ else {
+ CORRUPTION_ERROR_ACTION(m);
+ return;
+ }
+ }
+ if (RTCHECK(ok_address(m, next))) {
+ if (!cinuse(next)) { /* consolidate forward */
+ if (next == m->top) {
+ size_t tsize = m->topsize += psize;
+ m->top = p;
+ p->head = tsize | PINUSE_BIT;
+ if (p == m->dv) {
+ m->dv = 0;
+ m->dvsize = 0;
+ }
+ return;
+ }
+ else if (next == m->dv) {
+ size_t dsize = m->dvsize += psize;
+ m->dv = p;
+ set_size_and_pinuse_of_free_chunk(p, dsize);
+ return;
+ }
+ else {
+ size_t nsize = chunksize(next);
+ psize += nsize;
+ unlink_chunk(m, next, nsize);
+ set_size_and_pinuse_of_free_chunk(p, psize);
+ if (p == m->dv) {
+ m->dvsize = psize;
+ return;
+ }
+ }
+ }
+ else {
+ set_free_with_pinuse(p, psize, next);
+ }
+ insert_chunk(m, p, psize);
+ }
+ else {
+ CORRUPTION_ERROR_ACTION(m);
+ }
+}
+
+/* ---------------------------- malloc --------------------------- */
+
+/* allocate a large request from the best fitting chunk in a treebin */
+static void* tmalloc_large(mstate m, size_t nb) {
+ tchunkptr v = 0;
+ size_t rsize = -nb; /* Unsigned negation */
+ tchunkptr t;
+ bindex_t idx;
+ compute_tree_index(nb, idx);
+ if ((t = *treebin_at(m, idx)) != 0) {
+ /* Traverse tree for this bin looking for node with size == nb */
+ size_t sizebits = nb << leftshift_for_tree_index(idx);
+ tchunkptr rst = 0; /* The deepest untaken right subtree */
+ for (;;) {
+ tchunkptr rt;
+ size_t trem = chunksize(t) - nb;
+ if (trem < rsize) {
+ v = t;
+ if ((rsize = trem) == 0)
+ break;
+ }
+ rt = t->child[1];
+ t = t->child[(sizebits >> (SIZE_T_BITSIZE-SIZE_T_ONE)) & 1];
+ if (rt != 0 && rt != t)
+ rst = rt;
+ if (t == 0) {
+ t = rst; /* set t to least subtree holding sizes > nb */
+ break;
+ }
+ sizebits <<= 1;
+ }
+ }
+ if (t == 0 && v == 0) { /* set t to root of next non-empty treebin */
+ binmap_t leftbits = left_bits(idx2bit(idx)) & m->treemap;
+ if (leftbits != 0) {
+ bindex_t i;
+ binmap_t leastbit = least_bit(leftbits);
+ compute_bit2idx(leastbit, i);
+ t = *treebin_at(m, i);
+ }
+ }
+
+ while (t != 0) { /* find smallest of tree or subtree */
+ size_t trem = chunksize(t) - nb;
+ if (trem < rsize) {
+ rsize = trem;
+ v = t;
+ }
+ t = leftmost_child(t);
+ }
+
+ /* If dv is a better fit, return 0 so malloc will use it */
+ if (v != 0 && rsize < (size_t)(m->dvsize - nb)) {
+ if (RTCHECK(ok_address(m, v))) { /* split */
+ mchunkptr r = chunk_plus_offset(v, nb);
+ assert(chunksize(v) == rsize + nb);
+ if (RTCHECK(ok_next(v, r))) {
+ unlink_large_chunk(m, v);
+ if (rsize < MIN_CHUNK_SIZE)
+ set_inuse_and_pinuse(m, v, (rsize + nb));
+ else {
+ set_size_and_pinuse_of_inuse_chunk(m, v, nb);
+ set_size_and_pinuse_of_free_chunk(r, rsize);
+ insert_chunk(m, r, rsize);
+ }
+ return chunk2mem(v);
+ }
+ }
+ CORRUPTION_ERROR_ACTION(m);
+ }
+ return 0;
+}
+
+/* allocate a small request from the best fitting chunk in a treebin */
+static void* tmalloc_small(mstate m, size_t nb) {
+ tchunkptr t, v;
+ size_t rsize;
+ bindex_t i;
+ binmap_t leastbit = least_bit(m->treemap);
+ compute_bit2idx(leastbit, i);
+ v = t = *treebin_at(m, i);
+ rsize = chunksize(t) - nb;
+
+ while ((t = leftmost_child(t)) != 0) {
+ size_t trem = chunksize(t) - nb;
+ if (trem < rsize) {
+ rsize = trem;
+ v = t;
+ }
+ }
+
+ if (RTCHECK(ok_address(m, v))) {
+ mchunkptr r = chunk_plus_offset(v, nb);
+ assert(chunksize(v) == rsize + nb);
+ if (RTCHECK(ok_next(v, r))) {
+ unlink_large_chunk(m, v);
+ if (rsize < MIN_CHUNK_SIZE)
+ set_inuse_and_pinuse(m, v, (rsize + nb));
+ else {
+ set_size_and_pinuse_of_inuse_chunk(m, v, nb);
+ set_size_and_pinuse_of_free_chunk(r, rsize);
+ replace_dv(m, r, rsize);
+ }
+ return chunk2mem(v);
+ }
+ }
+
+ CORRUPTION_ERROR_ACTION(m);
+ return 0;
+}
+
+#if !ONLY_MSPACES
+
+void* dlmalloc(size_t bytes) {
+ /*
+ Basic algorithm:
+ If a small request (< 256 bytes minus per-chunk overhead):
+ 1. If one exists, use a remainderless chunk in associated smallbin.
+ (Remainderless means that there are too few excess bytes to
+ represent as a chunk.)
+ 2. If it is big enough, use the dv chunk, which is normally the
+ chunk adjacent to the one used for the most recent small request.
+ 3. If one exists, split the smallest available chunk in a bin,
+ saving remainder in dv.
+ 4. If it is big enough, use the top chunk.
+ 5. If available, get memory from system and use it
+ Otherwise, for a large request:
+ 1. Find the smallest available binned chunk that fits, and use it
+ if it is better fitting than dv chunk, splitting if necessary.
+ 2. If better fitting than any binned chunk, use the dv chunk.
+ 3. If it is big enough, use the top chunk.
+ 4. If request size >= mmap threshold, try to directly mmap this chunk.
+ 5. If available, get memory from system and use it
+
+ The ugly goto's here ensure that postaction occurs along all paths.
+ */
+
+#if USE_LOCKS
+ ensure_initialization(); /* initialize in sys_alloc if not using locks */
+#endif
+
+ if (!PREACTION(gm)) {
+ void* mem;
+ size_t nb;
+ if (bytes <= MAX_SMALL_REQUEST) {
+ bindex_t idx;
+ binmap_t smallbits;
+ nb = (bytes < MIN_REQUEST)? MIN_CHUNK_SIZE : pad_request(bytes);
+ idx = small_index(nb);
+ smallbits = gm->smallmap >> idx;
+
+ if ((smallbits & 0x3U) != 0) { /* Remainderless fit to a smallbin. */
+ mchunkptr b, p;
+ idx += ~smallbits & 1; /* Uses next bin if idx empty */
+ b = smallbin_at(gm, idx);
+ p = b->fd;
+ assert(chunksize(p) == small_index2size(idx));
+ unlink_first_small_chunk(gm, b, p, idx);
+ set_inuse_and_pinuse(gm, p, small_index2size(idx));
+ mem = chunk2mem(p);
+ check_malloced_chunk(gm, mem, nb);
+ goto postaction;
+ }
+
+ else if (nb > gm->dvsize) {
+ if (smallbits != 0) { /* Use chunk in next nonempty smallbin */
+ mchunkptr b, p, r;
+ size_t rsize;
+ bindex_t i;
+ binmap_t leftbits = (smallbits << idx) & left_bits(idx2bit(idx));
+ binmap_t leastbit = least_bit(leftbits);
+ compute_bit2idx(leastbit, i);
+ b = smallbin_at(gm, i);
+ p = b->fd;
+ assert(chunksize(p) == small_index2size(i));
+ unlink_first_small_chunk(gm, b, p, i);
+ rsize = small_index2size(i) - nb;
+ /* Fit here cannot be remainderless if 4byte sizes */
+ if (SIZE_T_SIZE != 4 && rsize < MIN_CHUNK_SIZE)
+ set_inuse_and_pinuse(gm, p, small_index2size(i));
+ else {
+ set_size_and_pinuse_of_inuse_chunk(gm, p, nb);
+ r = chunk_plus_offset(p, nb);
+ set_size_and_pinuse_of_free_chunk(r, rsize);
+ replace_dv(gm, r, rsize);
+ }
+ mem = chunk2mem(p);
+ check_malloced_chunk(gm, mem, nb);
+ goto postaction;
+ }
+
+ else if (gm->treemap != 0 && (mem = tmalloc_small(gm, nb)) != 0) {
+ check_malloced_chunk(gm, mem, nb);
+ goto postaction;
+ }
+ }
+ }
+ else if (bytes >= MAX_REQUEST)
+ nb = MAX_SIZE_T; /* Too big to allocate. Force failure (in sys alloc) */
+ else {
+ nb = pad_request(bytes);
+ if (gm->treemap != 0 && (mem = tmalloc_large(gm, nb)) != 0) {
+ check_malloced_chunk(gm, mem, nb);
+ goto postaction;
+ }
+ }
+
+ if (nb <= gm->dvsize) {
+ size_t rsize = gm->dvsize - nb;
+ mchunkptr p = gm->dv;
+ if (rsize >= MIN_CHUNK_SIZE) { /* split dv */
+ mchunkptr r = gm->dv = chunk_plus_offset(p, nb);
+ gm->dvsize = rsize;
+ set_size_and_pinuse_of_free_chunk(r, rsize);
+ set_size_and_pinuse_of_inuse_chunk(gm, p, nb);
+ }
+ else { /* exhaust dv */
+ size_t dvs = gm->dvsize;
+ gm->dvsize = 0;
+ gm->dv = 0;
+ set_inuse_and_pinuse(gm, p, dvs);
+ }
+ mem = chunk2mem(p);
+ check_malloced_chunk(gm, mem, nb);
+ goto postaction;
+ }
+
+ else if (nb < gm->topsize) { /* Split top */
+ size_t rsize = gm->topsize -= nb;
+ mchunkptr p = gm->top;
+ mchunkptr r = gm->top = chunk_plus_offset(p, nb);
+ r->head = rsize | PINUSE_BIT;
+ set_size_and_pinuse_of_inuse_chunk(gm, p, nb);
+ mem = chunk2mem(p);
+ check_top_chunk(gm, gm->top);
+ check_malloced_chunk(gm, mem, nb);
+ goto postaction;
+ }
+
+ mem = sys_alloc(gm, nb);
+
+ postaction:
+ POSTACTION(gm);
+ return mem;
+ }
+
+ return 0;
+}
+
+/* ---------------------------- free --------------------------- */
+
+void dlfree(void* mem) {
+ /*
+ Consolidate freed chunks with preceding or succeeding bordering
+ free chunks, if they exist, and then place in a bin. Intermixed
+ with special cases for top, dv, mmapped chunks, and usage errors.
+ */
+
+ if (mem != 0) {
+ mchunkptr p = mem2chunk(mem);
+#if FOOTERS
+ mstate fm = get_mstate_for(p);
+ if (!ok_magic(fm)) {
+ USAGE_ERROR_ACTION(fm, p);
+ return;
+ }
+#else /* FOOTERS */
+#define fm gm
+#endif /* FOOTERS */
+ if (!PREACTION(fm)) {
+ check_inuse_chunk(fm, p);
+ if (RTCHECK(ok_address(fm, p) && ok_inuse(p))) {
+ size_t psize = chunksize(p);
+ mchunkptr next = chunk_plus_offset(p, psize);
+ if (!pinuse(p)) {
+ size_t prevsize = p->prev_foot;
+ if (is_mmapped(p)) {
+ psize += prevsize + MMAP_FOOT_PAD;
+ if (CALL_MUNMAP((char*)p - prevsize, psize) == 0)
+ fm->footprint -= psize;
+ goto postaction;
+ }
+ else {
+ mchunkptr prev = chunk_minus_offset(p, prevsize);
+ psize += prevsize;
+ p = prev;
+ if (RTCHECK(ok_address(fm, prev))) { /* consolidate backward */
+ if (p != fm->dv) {
+ unlink_chunk(fm, p, prevsize);
+ }
+ else if ((next->head & INUSE_BITS) == INUSE_BITS) {
+ fm->dvsize = psize;
+ set_free_with_pinuse(p, psize, next);
+ goto postaction;
+ }
+ }
+ else
+ goto erroraction;
+ }
+ }
+
+ if (RTCHECK(ok_next(p, next) && ok_pinuse(next))) {
+ if (!cinuse(next)) { /* consolidate forward */
+ if (next == fm->top) {
+ size_t tsize = fm->topsize += psize;
+ fm->top = p;
+ p->head = tsize | PINUSE_BIT;
+ if (p == fm->dv) {
+ fm->dv = 0;
+ fm->dvsize = 0;
+ }
+ if (should_trim(fm, tsize))
+ sys_trim(fm, 0);
+ goto postaction;
+ }
+ else if (next == fm->dv) {
+ size_t dsize = fm->dvsize += psize;
+ fm->dv = p;
+ set_size_and_pinuse_of_free_chunk(p, dsize);
+ goto postaction;
+ }
+ else {
+ size_t nsize = chunksize(next);
+ psize += nsize;
+ unlink_chunk(fm, next, nsize);
+ set_size_and_pinuse_of_free_chunk(p, psize);
+ if (p == fm->dv) {
+ fm->dvsize = psize;
+ goto postaction;
+ }
+ }
+ }
+ else
+ set_free_with_pinuse(p, psize, next);
+
+ if (is_small(psize)) {
+ insert_small_chunk(fm, p, psize);
+ check_free_chunk(fm, p);
+ }
+ else {
+ tchunkptr tp = (tchunkptr)p;
+ insert_large_chunk(fm, tp, psize);
+ check_free_chunk(fm, p);
+ if (--fm->release_checks == 0)
+ release_unused_segments(fm);
+ }
+ goto postaction;
+ }
+ }
+ erroraction:
+ USAGE_ERROR_ACTION(fm, p);
+ postaction:
+ POSTACTION(fm);
+ }
+ }
+#if !FOOTERS
+#undef fm
+#endif /* FOOTERS */
+}
+
+void* dlcalloc(size_t n_elements, size_t elem_size) {
+ void* mem;
+ size_t req = 0;
+ if (n_elements != 0) {
+ req = n_elements * elem_size;
+ if (((n_elements | elem_size) & ~(size_t)0xffff) &&
+ (req / n_elements != elem_size))
+ req = MAX_SIZE_T; /* force downstream failure on overflow */
+ }
+ mem = dlmalloc(req);
+ if (mem != 0 && calloc_must_clear(mem2chunk(mem)))
+ memset(mem, 0, req);
+ return mem;
+}
+
+#endif /* !ONLY_MSPACES */
+
+/* ------------ Internal support for realloc, memalign, etc -------------- */
+
+/* Try to realloc; only in-place unless can_move true */
+static mchunkptr try_realloc_chunk(mstate m, mchunkptr p, size_t nb,
+ int can_move) {
+ mchunkptr newp = 0;
+ size_t oldsize = chunksize(p);
+ mchunkptr next = chunk_plus_offset(p, oldsize);
+ if (RTCHECK(ok_address(m, p) && ok_inuse(p) &&
+ ok_next(p, next) && ok_pinuse(next))) {
+ if (is_mmapped(p)) {
+ newp = mmap_resize(m, p, nb, can_move);
+ }
+ else if (oldsize >= nb) { /* already big enough */
+ size_t rsize = oldsize - nb;
+ if (rsize >= MIN_CHUNK_SIZE) { /* split off remainder */
+ mchunkptr r = chunk_plus_offset(p, nb);
+ set_inuse(m, p, nb);
+ set_inuse(m, r, rsize);
+ dispose_chunk(m, r, rsize);
+ }
+ newp = p;
+ }
+ else if (next == m->top) { /* extend into top */
+ if (oldsize + m->topsize > nb) {
+ size_t newsize = oldsize + m->topsize;
+ size_t newtopsize = newsize - nb;
+ mchunkptr newtop = chunk_plus_offset(p, nb);
+ set_inuse(m, p, nb);
+ newtop->head = newtopsize |PINUSE_BIT;
+ m->top = newtop;
+ m->topsize = newtopsize;
+ newp = p;
+ }
+ }
+ else if (next == m->dv) { /* extend into dv */
+ size_t dvs = m->dvsize;
+ if (oldsize + dvs >= nb) {
+ size_t dsize = oldsize + dvs - nb;
+ if (dsize >= MIN_CHUNK_SIZE) {
+ mchunkptr r = chunk_plus_offset(p, nb);
+ mchunkptr n = chunk_plus_offset(r, dsize);
+ set_inuse(m, p, nb);
+ set_size_and_pinuse_of_free_chunk(r, dsize);
+ clear_pinuse(n);
+ m->dvsize = dsize;
+ m->dv = r;
+ }
+ else { /* exhaust dv */
+ size_t newsize = oldsize + dvs;
+ set_inuse(m, p, newsize);
+ m->dvsize = 0;
+ m->dv = 0;
+ }
+ newp = p;
+ }
+ }
+ else if (!cinuse(next)) { /* extend into next free chunk */
+ size_t nextsize = chunksize(next);
+ if (oldsize + nextsize >= nb) {
+ size_t rsize = oldsize + nextsize - nb;
+ unlink_chunk(m, next, nextsize);
+ if (rsize < MIN_CHUNK_SIZE) {
+ size_t newsize = oldsize + nextsize;
+ set_inuse(m, p, newsize);
+ }
+ else {
+ mchunkptr r = chunk_plus_offset(p, nb);
+ set_inuse(m, p, nb);
+ set_inuse(m, r, rsize);
+ dispose_chunk(m, r, rsize);
+ }
+ newp = p;
+ }
+ }
+ }
+ else {
+ USAGE_ERROR_ACTION(m, chunk2mem(p));
+ }
+ return newp;
+}
+
+static void* internal_memalign(mstate m, size_t alignment, size_t bytes) {
+ void* mem = 0;
+ if (alignment < MIN_CHUNK_SIZE) /* must be at least a minimum chunk size */
+ alignment = MIN_CHUNK_SIZE;
+ if ((alignment & (alignment-SIZE_T_ONE)) != 0) {/* Ensure a power of 2 */
+ size_t a = MALLOC_ALIGNMENT << 1;
+ while (a < alignment) a <<= 1;
+ alignment = a;
+ }
+ if (bytes >= MAX_REQUEST - alignment) {
+ if (m != 0) { /* Test isn't needed but avoids compiler warning */
+ MALLOC_FAILURE_ACTION;
+ }
+ }
+ else {
+ size_t nb = request2size(bytes);
+ size_t req = nb + alignment + MIN_CHUNK_SIZE - CHUNK_OVERHEAD;
+ mem = internal_malloc(m, req);
+ if (mem != 0) {
+ mchunkptr p = mem2chunk(mem);
+ if (PREACTION(m))
+ return 0;
+ if ((((size_t)(mem)) & (alignment - 1)) != 0) { /* misaligned */
+ /*
+ Find an aligned spot inside chunk. Since we need to give
+ back leading space in a chunk of at least MIN_CHUNK_SIZE, if
+ the first calculation places us at a spot with less than
+ MIN_CHUNK_SIZE leader, we can move to the next aligned spot.
+ We've allocated enough total room so that this is always
+ possible.
+ */
+ char* br = (char*)mem2chunk((size_t)(((size_t)((char*)mem + alignment -
+ SIZE_T_ONE)) &
+ -alignment));
+ char* pos = ((size_t)(br - (char*)(p)) >= MIN_CHUNK_SIZE)?
+ br : br+alignment;
+ mchunkptr newp = (mchunkptr)pos;
+ size_t leadsize = pos - (char*)(p);
+ size_t newsize = chunksize(p) - leadsize;
+
+ if (is_mmapped(p)) { /* For mmapped chunks, just adjust offset */
+ newp->prev_foot = p->prev_foot + leadsize;
+ newp->head = newsize;
+ }
+ else { /* Otherwise, give back leader, use the rest */
+ set_inuse(m, newp, newsize);
+ set_inuse(m, p, leadsize);
+ dispose_chunk(m, p, leadsize);
+ }
+ p = newp;
+ }
+
+ /* Give back spare room at the end */
+ if (!is_mmapped(p)) {
+ size_t size = chunksize(p);
+ if (size > nb + MIN_CHUNK_SIZE) {
+ size_t remainder_size = size - nb;
+ mchunkptr remainder = chunk_plus_offset(p, nb);
+ set_inuse(m, p, nb);
+ set_inuse(m, remainder, remainder_size);
+ dispose_chunk(m, remainder, remainder_size);
+ }
+ }
+
+ mem = chunk2mem(p);
+ assert (chunksize(p) >= nb);
+ assert(((size_t)mem & (alignment - 1)) == 0);
+ check_inuse_chunk(m, p);
+ POSTACTION(m);
+ }
+ }
+ return mem;
+}
+
+/*
+ Common support for independent_X routines, handling
+ all of the combinations that can result.
+ The opts arg has:
+ bit 0 set if all elements are same size (using sizes[0])
+ bit 1 set if elements should be zeroed
+*/
+static void** ialloc(mstate m,
+ size_t n_elements,
+ size_t* sizes,
+ int opts,
+ void* chunks[]) {
+
+ size_t element_size; /* chunksize of each element, if all same */
+ size_t contents_size; /* total size of elements */
+ size_t array_size; /* request size of pointer array */
+ void* mem; /* malloced aggregate space */
+ mchunkptr p; /* corresponding chunk */
+ size_t remainder_size; /* remaining bytes while splitting */
+ void** marray; /* either "chunks" or malloced ptr array */
+ mchunkptr array_chunk; /* chunk for malloced ptr array */
+ flag_t was_enabled; /* to disable mmap */
+ size_t size;
+ size_t i;
+
+ ensure_initialization();
+ /* compute array length, if needed */
+ if (chunks != 0) {
+ if (n_elements == 0)
+ return chunks; /* nothing to do */
+ marray = chunks;
+ array_size = 0;
+ }
+ else {
+ /* if empty req, must still return chunk representing empty array */
+ if (n_elements == 0)
+ return (void**)internal_malloc(m, 0);
+ marray = 0;
+ array_size = request2size(n_elements * (sizeof(void*)));
+ }
+
+ /* compute total element size */
+ if (opts & 0x1) { /* all-same-size */
+ element_size = request2size(*sizes);
+ contents_size = n_elements * element_size;
+ }
+ else { /* add up all the sizes */
+ element_size = 0;
+ contents_size = 0;
+ for (i = 0; i != n_elements; ++i)
+ contents_size += request2size(sizes[i]);
+ }
+
+ size = contents_size + array_size;
+
+ /*
+ Allocate the aggregate chunk. First disable direct-mmapping so
+ malloc won't use it, since we would not be able to later
+ free/realloc space internal to a segregated mmap region.
+ */
+ was_enabled = use_mmap(m);
+ disable_mmap(m);
+ mem = internal_malloc(m, size - CHUNK_OVERHEAD);
+ if (was_enabled)
+ enable_mmap(m);
+ if (mem == 0)
+ return 0;
+
+ if (PREACTION(m)) return 0;
+ p = mem2chunk(mem);
+ remainder_size = chunksize(p);
+
+ assert(!is_mmapped(p));
+
+ if (opts & 0x2) { /* optionally clear the elements */
+ memset((size_t*)mem, 0, remainder_size - SIZE_T_SIZE - array_size);
+ }
+
+ /* If not provided, allocate the pointer array as final part of chunk */
+ if (marray == 0) {
+ size_t array_chunk_size;
+ array_chunk = chunk_plus_offset(p, contents_size);
+ array_chunk_size = remainder_size - contents_size;
+ marray = (void**) (chunk2mem(array_chunk));
+ set_size_and_pinuse_of_inuse_chunk(m, array_chunk, array_chunk_size);
+ remainder_size = contents_size;
+ }
+
+ /* split out elements */
+ for (i = 0; ; ++i) {
+ marray[i] = chunk2mem(p);
+ if (i != n_elements-1) {
+ if (element_size != 0)
+ size = element_size;
+ else
+ size = request2size(sizes[i]);
+ remainder_size -= size;
+ set_size_and_pinuse_of_inuse_chunk(m, p, size);
+ p = chunk_plus_offset(p, size);
+ }
+ else { /* the final element absorbs any overallocation slop */
+ set_size_and_pinuse_of_inuse_chunk(m, p, remainder_size);
+ break;
+ }
+ }
+
+#if DEBUG
+ if (marray != chunks) {
+ /* final element must have exactly exhausted chunk */
+ if (element_size != 0) {
+ assert(remainder_size == element_size);
+ }
+ else {
+ assert(remainder_size == request2size(sizes[i]));
+ }
+ check_inuse_chunk(m, mem2chunk(marray));
+ }
+ for (i = 0; i != n_elements; ++i)
+ check_inuse_chunk(m, mem2chunk(marray[i]));
+
+#endif /* DEBUG */
+
+ POSTACTION(m);
+ return marray;
+}
+
+/* Try to free all pointers in the given array.
+ Note: this could be made faster, by delaying consolidation,
+ at the price of disabling some user integrity checks, We
+ still optimize some consolidations by combining adjacent
+ chunks before freeing, which will occur often if allocated
+ with ialloc or the array is sorted.
+*/
+static size_t internal_bulk_free(mstate m, void* array[], size_t nelem) {
+ size_t unfreed = 0;
+ if (!PREACTION(m)) {
+ void** a;
+ void** fence = &(array[nelem]);
+ for (a = array; a != fence; ++a) {
+ void* mem = *a;
+ if (mem != 0) {
+ mchunkptr p = mem2chunk(mem);
+ size_t psize = chunksize(p);
+#if FOOTERS
+ if (get_mstate_for(p) != m) {
+ ++unfreed;
+ continue;
+ }
+#endif
+ check_inuse_chunk(m, p);
+ *a = 0;
+ if (RTCHECK(ok_address(m, p) && ok_inuse(p))) {
+ void ** b = a + 1; /* try to merge with next chunk */
+ mchunkptr next = next_chunk(p);
+ if (b != fence && *b == chunk2mem(next)) {
+ size_t newsize = chunksize(next) + psize;
+ set_inuse(m, p, newsize);
+ *b = chunk2mem(p);
+ }
+ else
+ dispose_chunk(m, p, psize);
+ }
+ else {
+ CORRUPTION_ERROR_ACTION(m);
+ break;
+ }
+ }
+ }
+ if (should_trim(m, m->topsize))
+ sys_trim(m, 0);
+ POSTACTION(m);
+ }
+ return unfreed;
+}
+
+/* Traversal */
+#if MALLOC_INSPECT_ALL
+static void internal_inspect_all(mstate m,
+ void(*handler)(void *start,
+ void *end,
+ size_t used_bytes,
+ void* callback_arg),
+ void* arg) {
+ if (is_initialized(m)) {
+ mchunkptr top = m->top;
+ msegmentptr s;
+ for (s = &m->seg; s != 0; s = s->next) {
+ mchunkptr q = align_as_chunk(s->base);
+ while (segment_holds(s, q) && q->head != FENCEPOST_HEAD) {
+ mchunkptr next = next_chunk(q);
+ size_t sz = chunksize(q);
+ size_t used;
+ void* start;
+ if (is_inuse(q)) {
+ used = sz - CHUNK_OVERHEAD; /* must not be mmapped */
+ start = chunk2mem(q);
+ }
+ else {
+ used = 0;
+ if (is_small(sz)) { /* offset by possible bookkeeping */
+ start = (void*)((char*)q + sizeof(struct malloc_chunk));
+ }
+ else {
+ start = (void*)((char*)q + sizeof(struct malloc_tree_chunk));
+ }
+ }
+ if (start < (void*)next) /* skip if all space is bookkeeping */
+ handler(start, next, used, arg);
+ if (q == top)
+ break;
+ q = next;
+ }
+ }
+ }
+}
+#endif /* MALLOC_INSPECT_ALL */
+
+/* ------------------ Exported realloc, memalign, etc -------------------- */
+
+#if !ONLY_MSPACES
+
+void* dlrealloc(void* oldmem, size_t bytes) {
+ void* mem = 0;
+ if (oldmem == 0) {
+ mem = dlmalloc(bytes);
+ }
+ else if (bytes >= MAX_REQUEST) {
+ MALLOC_FAILURE_ACTION;
+ }
+#ifdef REALLOC_ZERO_BYTES_FREES
+ else if (bytes == 0) {
+ dlfree(oldmem);
+ }
+#endif /* REALLOC_ZERO_BYTES_FREES */
+ else {
+ size_t nb = request2size(bytes);
+ mchunkptr oldp = mem2chunk(oldmem);
+#if ! FOOTERS
+ mstate m = gm;
+#else /* FOOTERS */
+ mstate m = get_mstate_for(oldp);
+ if (!ok_magic(m)) {
+ USAGE_ERROR_ACTION(m, oldmem);
+ return 0;
+ }
+#endif /* FOOTERS */
+ if (!PREACTION(m)) {
+ mchunkptr newp = try_realloc_chunk(m, oldp, nb, 1);
+ POSTACTION(m);
+ if (newp != 0) {
+ check_inuse_chunk(m, newp);
+ mem = chunk2mem(newp);
+ }
+ else {
+ mem = internal_malloc(m, bytes);
+ if (mem != 0) {
+ size_t oc = chunksize(oldp) - overhead_for(oldp);
+ memcpy(mem, oldmem, (oc < bytes)? oc : bytes);
+ internal_free(m, oldmem);
+ }
+ }
+ }
+ }
+ return mem;
+}
+
+void* dlrealloc_in_place(void* oldmem, size_t bytes) {
+ void* mem = 0;
+ if (oldmem != 0) {
+ if (bytes >= MAX_REQUEST) {
+ MALLOC_FAILURE_ACTION;
+ }
+ else {
+ size_t nb = request2size(bytes);
+ mchunkptr oldp = mem2chunk(oldmem);
+#if ! FOOTERS
+ mstate m = gm;
+#else /* FOOTERS */
+ mstate m = get_mstate_for(oldp);
+ if (!ok_magic(m)) {
+ USAGE_ERROR_ACTION(m, oldmem);
+ return 0;
+ }
+#endif /* FOOTERS */
+ if (!PREACTION(m)) {
+ mchunkptr newp = try_realloc_chunk(m, oldp, nb, 0);
+ POSTACTION(m);
+ if (newp == oldp) {
+ check_inuse_chunk(m, newp);
+ mem = oldmem;
+ }
+ }
+ }
+ }
+ return mem;
+}
+
+void* dlmemalign(size_t alignment, size_t bytes) {
+ if (alignment <= MALLOC_ALIGNMENT) {
+ return dlmalloc(bytes);
+ }
+ return internal_memalign(gm, alignment, bytes);
+}
+
+int dlposix_memalign(void** pp, size_t alignment, size_t bytes) {
+ void* mem = 0;
+ if (alignment == MALLOC_ALIGNMENT)
+ mem = dlmalloc(bytes);
+ else {
+ size_t d = alignment / sizeof(void*);
+ size_t r = alignment % sizeof(void*);
+ if (r != 0 || d == 0 || (d & (d-SIZE_T_ONE)) != 0)
+ return EINVAL;
+ else if (bytes <= MAX_REQUEST - alignment) {
+ if (alignment < MIN_CHUNK_SIZE)
+ alignment = MIN_CHUNK_SIZE;
+ mem = internal_memalign(gm, alignment, bytes);
+ }
+ }
+ if (mem == 0)
+ return ENOMEM;
+ else {
+ *pp = mem;
+ return 0;
+ }
+}
+
+void* dlvalloc(size_t bytes) {
+ size_t pagesz;
+ ensure_initialization();
+ pagesz = mparams.page_size;
+ return dlmemalign(pagesz, bytes);
+}
+
+void* dlpvalloc(size_t bytes) {
+ size_t pagesz;
+ ensure_initialization();
+ pagesz = mparams.page_size;
+ return dlmemalign(pagesz, (bytes + pagesz - SIZE_T_ONE) & ~(pagesz - SIZE_T_ONE));
+}
+
+void** dlindependent_calloc(size_t n_elements, size_t elem_size,
+ void* chunks[]) {
+ size_t sz = elem_size; /* serves as 1-element array */
+ return ialloc(gm, n_elements, &sz, 3, chunks);
+}
+
+void** dlindependent_comalloc(size_t n_elements, size_t sizes[],
+ void* chunks[]) {
+ return ialloc(gm, n_elements, sizes, 0, chunks);
+}
+
+size_t dlbulk_free(void* array[], size_t nelem) {
+ return internal_bulk_free(gm, array, nelem);
+}
+
+#if MALLOC_INSPECT_ALL
+void dlmalloc_inspect_all(void(*handler)(void *start,
+ void *end,
+ size_t used_bytes,
+ void* callback_arg),
+ void* arg) {
+ ensure_initialization();
+ if (!PREACTION(gm)) {
+ internal_inspect_all(gm, handler, arg);
+ POSTACTION(gm);
+ }
+}
+#endif /* MALLOC_INSPECT_ALL */
+
+int dlmalloc_trim(size_t pad) {
+ int result = 0;
+ ensure_initialization();
+ if (!PREACTION(gm)) {
+ result = sys_trim(gm, pad);
+ POSTACTION(gm);
+ }
+ return result;
+}
+
+size_t dlmalloc_footprint(void) {
+ return gm->footprint;
+}
+
+size_t dlmalloc_max_footprint(void) {
+ return gm->max_footprint;
+}
+
+size_t dlmalloc_footprint_limit(void) {
+ size_t maf = gm->footprint_limit;
+ return maf == 0 ? MAX_SIZE_T : maf;
+}
+
+size_t dlmalloc_set_footprint_limit(size_t bytes) {
+ ensure_initialization();
+ size_t result; /* invert sense of 0 */
+ if (bytes == 0)
+ result = granularity_align(1); /* Use minimal size */
+ if (bytes == MAX_SIZE_T)
+ result = 0; /* disable */
+ else
+ result = granularity_align(bytes);
+ return gm->footprint_limit = result;
+}
+
+#if !NO_MALLINFO
+struct mallinfo dlmallinfo(void) {
+ return internal_mallinfo(gm);
+}
+#endif /* NO_MALLINFO */
+
+#if !NO_MALLOC_STATS
+void dlmalloc_stats() {
+ internal_malloc_stats(gm);
+}
+#endif /* NO_MALLOC_STATS */
+
+int dlmallopt(int param_number, int value) {
+ return change_mparam(param_number, value);
+}
+
+size_t dlmalloc_usable_size(void* mem) {
+ if (mem != 0) {
+ mchunkptr p = mem2chunk(mem);
+ if (is_inuse(p))
+ return chunksize(p) - overhead_for(p);
+ }
+ return 0;
+}
+
+#endif /* !ONLY_MSPACES */
+
+/* ----------------------------- user mspaces ---------------------------- */
+
+#if MSPACES
+
+static mstate init_user_mstate(char* tbase, size_t tsize) {
+ size_t msize = pad_request(sizeof(struct malloc_state));
+ mchunkptr mn;
+ mchunkptr msp = align_as_chunk(tbase);
+ mstate m = (mstate)(chunk2mem(msp));
+ memset(m, 0, msize);
+ (void)INITIAL_LOCK(&m->mutex);
+ msp->head = (msize|INUSE_BITS);
+ m->seg.base = m->least_addr = tbase;
+ m->seg.size = m->footprint = m->max_footprint = tsize;
+ m->magic = mparams.magic;
+ m->release_checks = MAX_RELEASE_CHECK_RATE;
+ m->mflags = mparams.default_mflags;
+ m->extp = 0;
+ m->exts = 0;
+ disable_contiguous(m);
+ init_bins(m);
+ mn = next_chunk(mem2chunk(m));
+ init_top(m, mn, (size_t)((tbase + tsize) - (char*)mn) - TOP_FOOT_SIZE);
+ check_top_chunk(m, m->top);
+ return m;
+}
+
+mspace create_mspace(size_t capacity, int locked) {
+ mstate m = 0;
+ size_t msize;
+ ensure_initialization();
+ msize = pad_request(sizeof(struct malloc_state));
+ if (capacity < (size_t) -(msize + TOP_FOOT_SIZE + mparams.page_size)) {
+ size_t rs = ((capacity == 0)? mparams.granularity :
+ (capacity + TOP_FOOT_SIZE + msize));
+ size_t tsize = granularity_align(rs);
+ char* tbase = (char*)(CALL_MMAP(tsize));
+ if (tbase != CMFAIL) {
+ m = init_user_mstate(tbase, tsize);
+ m->seg.sflags = USE_MMAP_BIT;
+ set_lock(m, locked);
+ }
+ }
+ return (mspace)m;
+}
+
+mspace create_mspace_with_base(void* base, size_t capacity, int locked) {
+ mstate m = 0;
+ size_t msize;
+ ensure_initialization();
+ msize = pad_request(sizeof(struct malloc_state));
+ if (capacity > msize + TOP_FOOT_SIZE &&
+ capacity < (size_t) -(msize + TOP_FOOT_SIZE + mparams.page_size)) {
+ m = init_user_mstate((char*)base, capacity);
+ m->seg.sflags = EXTERN_BIT;
+ set_lock(m, locked);
+ }
+ return (mspace)m;
+}
+
+int mspace_track_large_chunks(mspace msp, int enable) {
+ int ret = 0;
+ mstate ms = (mstate)msp;
+ if (!PREACTION(ms)) {
+ if (!use_mmap(ms)) {
+ ret = 1;
+ }
+ if (!enable) {
+ enable_mmap(ms);
+ } else {
+ disable_mmap(ms);
+ }
+ POSTACTION(ms);
+ }
+ return ret;
+}
+
+size_t destroy_mspace(mspace msp) {
+ size_t freed = 0;
+ mstate ms = (mstate)msp;
+ if (ok_magic(ms)) {
+ msegmentptr sp = &ms->seg;
+ (void)DESTROY_LOCK(&ms->mutex); /* destroy before unmapped */
+ while (sp != 0) {
+ char* base = sp->base;
+ size_t size = sp->size;
+ flag_t flag = sp->sflags;
+ (void)base; /* placate people compiling -Wunused-variable */
+ sp = sp->next;
+ if ((flag & USE_MMAP_BIT) && !(flag & EXTERN_BIT) &&
+ CALL_MUNMAP(base, size) == 0)
+ freed += size;
+ }
+ }
+ else {
+ USAGE_ERROR_ACTION(ms,ms);
+ }
+ return freed;
+}
+
+/*
+ mspace versions of routines are near-clones of the global
+ versions. This is not so nice but better than the alternatives.
+*/
+
+void* mspace_malloc(mspace msp, size_t bytes) {
+ mstate ms = (mstate)msp;
+ if (!ok_magic(ms)) {
+ USAGE_ERROR_ACTION(ms,ms);
+ return 0;
+ }
+ if (!PREACTION(ms)) {
+ void* mem;
+ size_t nb;
+ if (bytes <= MAX_SMALL_REQUEST) {
+ bindex_t idx;
+ binmap_t smallbits;
+ nb = (bytes < MIN_REQUEST)? MIN_CHUNK_SIZE : pad_request(bytes);
+ idx = small_index(nb);
+ smallbits = ms->smallmap >> idx;
+
+ if ((smallbits & 0x3U) != 0) { /* Remainderless fit to a smallbin. */
+ mchunkptr b, p;
+ idx += ~smallbits & 1; /* Uses next bin if idx empty */
+ b = smallbin_at(ms, idx);
+ p = b->fd;
+ assert(chunksize(p) == small_index2size(idx));
+ unlink_first_small_chunk(ms, b, p, idx);
+ set_inuse_and_pinuse(ms, p, small_index2size(idx));
+ mem = chunk2mem(p);
+ check_malloced_chunk(ms, mem, nb);
+ goto postaction;
+ }
+
+ else if (nb > ms->dvsize) {
+ if (smallbits != 0) { /* Use chunk in next nonempty smallbin */
+ mchunkptr b, p, r;
+ size_t rsize;
+ bindex_t i;
+ binmap_t leftbits = (smallbits << idx) & left_bits(idx2bit(idx));
+ binmap_t leastbit = least_bit(leftbits);
+ compute_bit2idx(leastbit, i);
+ b = smallbin_at(ms, i);
+ p = b->fd;
+ assert(chunksize(p) == small_index2size(i));
+ unlink_first_small_chunk(ms, b, p, i);
+ rsize = small_index2size(i) - nb;
+ /* Fit here cannot be remainderless if 4byte sizes */
+ if (SIZE_T_SIZE != 4 && rsize < MIN_CHUNK_SIZE)
+ set_inuse_and_pinuse(ms, p, small_index2size(i));
+ else {
+ set_size_and_pinuse_of_inuse_chunk(ms, p, nb);
+ r = chunk_plus_offset(p, nb);
+ set_size_and_pinuse_of_free_chunk(r, rsize);
+ replace_dv(ms, r, rsize);
+ }
+ mem = chunk2mem(p);
+ check_malloced_chunk(ms, mem, nb);
+ goto postaction;
+ }
+
+ else if (ms->treemap != 0 && (mem = tmalloc_small(ms, nb)) != 0) {
+ check_malloced_chunk(ms, mem, nb);
+ goto postaction;
+ }
+ }
+ }
+ else if (bytes >= MAX_REQUEST)
+ nb = MAX_SIZE_T; /* Too big to allocate. Force failure (in sys alloc) */
+ else {
+ nb = pad_request(bytes);
+ if (ms->treemap != 0 && (mem = tmalloc_large(ms, nb)) != 0) {
+ check_malloced_chunk(ms, mem, nb);
+ goto postaction;
+ }
+ }
+
+ if (nb <= ms->dvsize) {
+ size_t rsize = ms->dvsize - nb;
+ mchunkptr p = ms->dv;
+ if (rsize >= MIN_CHUNK_SIZE) { /* split dv */
+ mchunkptr r = ms->dv = chunk_plus_offset(p, nb);
+ ms->dvsize = rsize;
+ set_size_and_pinuse_of_free_chunk(r, rsize);
+ set_size_and_pinuse_of_inuse_chunk(ms, p, nb);
+ }
+ else { /* exhaust dv */
+ size_t dvs = ms->dvsize;
+ ms->dvsize = 0;
+ ms->dv = 0;
+ set_inuse_and_pinuse(ms, p, dvs);
+ }
+ mem = chunk2mem(p);
+ check_malloced_chunk(ms, mem, nb);
+ goto postaction;
+ }
+
+ else if (nb < ms->topsize) { /* Split top */
+ size_t rsize = ms->topsize -= nb;
+ mchunkptr p = ms->top;
+ mchunkptr r = ms->top = chunk_plus_offset(p, nb);
+ r->head = rsize | PINUSE_BIT;
+ set_size_and_pinuse_of_inuse_chunk(ms, p, nb);
+ mem = chunk2mem(p);
+ check_top_chunk(ms, ms->top);
+ check_malloced_chunk(ms, mem, nb);
+ goto postaction;
+ }
+
+ mem = sys_alloc(ms, nb);
+
+ postaction:
+ POSTACTION(ms);
+ return mem;
+ }
+
+ return 0;
+}
+
+void mspace_free(mspace msp, void* mem) {
+ if (mem != 0) {
+ mchunkptr p = mem2chunk(mem);
+#if FOOTERS
+ mstate fm = get_mstate_for(p);
+ (void)msp; /* placate people compiling -Wunused */
+#else /* FOOTERS */
+ mstate fm = (mstate)msp;
+#endif /* FOOTERS */
+ if (!ok_magic(fm)) {
+ USAGE_ERROR_ACTION(fm, p);
+ return;
+ }
+ if (!PREACTION(fm)) {
+ check_inuse_chunk(fm, p);
+ if (RTCHECK(ok_address(fm, p) && ok_inuse(p))) {
+ size_t psize = chunksize(p);
+ mchunkptr next = chunk_plus_offset(p, psize);
+ if (!pinuse(p)) {
+ size_t prevsize = p->prev_foot;
+ if (is_mmapped(p)) {
+ psize += prevsize + MMAP_FOOT_PAD;
+ if (CALL_MUNMAP((char*)p - prevsize, psize) == 0)
+ fm->footprint -= psize;
+ goto postaction;
+ }
+ else {
+ mchunkptr prev = chunk_minus_offset(p, prevsize);
+ psize += prevsize;
+ p = prev;
+ if (RTCHECK(ok_address(fm, prev))) { /* consolidate backward */
+ if (p != fm->dv) {
+ unlink_chunk(fm, p, prevsize);
+ }
+ else if ((next->head & INUSE_BITS) == INUSE_BITS) {
+ fm->dvsize = psize;
+ set_free_with_pinuse(p, psize, next);
+ goto postaction;
+ }
+ }
+ else
+ goto erroraction;
+ }
+ }
+
+ if (RTCHECK(ok_next(p, next) && ok_pinuse(next))) {
+ if (!cinuse(next)) { /* consolidate forward */
+ if (next == fm->top) {
+ size_t tsize = fm->topsize += psize;
+ fm->top = p;
+ p->head = tsize | PINUSE_BIT;
+ if (p == fm->dv) {
+ fm->dv = 0;
+ fm->dvsize = 0;
+ }
+ if (should_trim(fm, tsize))
+ sys_trim(fm, 0);
+ goto postaction;
+ }
+ else if (next == fm->dv) {
+ size_t dsize = fm->dvsize += psize;
+ fm->dv = p;
+ set_size_and_pinuse_of_free_chunk(p, dsize);
+ goto postaction;
+ }
+ else {
+ size_t nsize = chunksize(next);
+ psize += nsize;
+ unlink_chunk(fm, next, nsize);
+ set_size_and_pinuse_of_free_chunk(p, psize);
+ if (p == fm->dv) {
+ fm->dvsize = psize;
+ goto postaction;
+ }
+ }
+ }
+ else
+ set_free_with_pinuse(p, psize, next);
+
+ if (is_small(psize)) {
+ insert_small_chunk(fm, p, psize);
+ check_free_chunk(fm, p);
+ }
+ else {
+ tchunkptr tp = (tchunkptr)p;
+ insert_large_chunk(fm, tp, psize);
+ check_free_chunk(fm, p);
+ if (--fm->release_checks == 0)
+ release_unused_segments(fm);
+ }
+ goto postaction;
+ }
+ }
+ erroraction:
+ USAGE_ERROR_ACTION(fm, p);
+ postaction:
+ POSTACTION(fm);
+ }
+ }
+}
+
+void* mspace_calloc(mspace msp, size_t n_elements, size_t elem_size) {
+ void* mem;
+ size_t req = 0;
+ mstate ms = (mstate)msp;
+ if (!ok_magic(ms)) {
+ USAGE_ERROR_ACTION(ms,ms);
+ return 0;
+ }
+ if (n_elements != 0) {
+ req = n_elements * elem_size;
+ if (((n_elements | elem_size) & ~(size_t)0xffff) &&
+ (req / n_elements != elem_size))
+ req = MAX_SIZE_T; /* force downstream failure on overflow */
+ }
+ mem = internal_malloc(ms, req);
+ if (mem != 0 && calloc_must_clear(mem2chunk(mem)))
+ memset(mem, 0, req);
+ return mem;
+}
+
+void* mspace_realloc(mspace msp, void* oldmem, size_t bytes) {
+ void* mem = 0;
+ if (oldmem == 0) {
+ mem = mspace_malloc(msp, bytes);
+ }
+ else if (bytes >= MAX_REQUEST) {
+ MALLOC_FAILURE_ACTION;
+ }
+#ifdef REALLOC_ZERO_BYTES_FREES
+ else if (bytes == 0) {
+ mspace_free(msp, oldmem);
+ }
+#endif /* REALLOC_ZERO_BYTES_FREES */
+ else {
+ size_t nb = request2size(bytes);
+ mchunkptr oldp = mem2chunk(oldmem);
+#if ! FOOTERS
+ mstate m = (mstate)msp;
+#else /* FOOTERS */
+ mstate m = get_mstate_for(oldp);
+ if (!ok_magic(m)) {
+ USAGE_ERROR_ACTION(m, oldmem);
+ return 0;
+ }
+#endif /* FOOTERS */
+ if (!PREACTION(m)) {
+ mchunkptr newp = try_realloc_chunk(m, oldp, nb, 1);
+ POSTACTION(m);
+ if (newp != 0) {
+ check_inuse_chunk(m, newp);
+ mem = chunk2mem(newp);
+ }
+ else {
+ mem = mspace_malloc(m, bytes);
+ if (mem != 0) {
+ size_t oc = chunksize(oldp) - overhead_for(oldp);
+ memcpy(mem, oldmem, (oc < bytes)? oc : bytes);
+ mspace_free(m, oldmem);
+ }
+ }
+ }
+ }
+ return mem;
+}
+
+void* mspace_realloc_in_place(mspace msp, void* oldmem, size_t bytes) {
+ void* mem = 0;
+ if (oldmem != 0) {
+ if (bytes >= MAX_REQUEST) {
+ MALLOC_FAILURE_ACTION;
+ }
+ else {
+ size_t nb = request2size(bytes);
+ mchunkptr oldp = mem2chunk(oldmem);
+#if ! FOOTERS
+ mstate m = (mstate)msp;
+#else /* FOOTERS */
+ mstate m = get_mstate_for(oldp);
+ (void)msp; /* placate people compiling -Wunused */
+ if (!ok_magic(m)) {
+ USAGE_ERROR_ACTION(m, oldmem);
+ return 0;
+ }
+#endif /* FOOTERS */
+ if (!PREACTION(m)) {
+ mchunkptr newp = try_realloc_chunk(m, oldp, nb, 0);
+ POSTACTION(m);
+ if (newp == oldp) {
+ check_inuse_chunk(m, newp);
+ mem = oldmem;
+ }
+ }
+ }
+ }
+ return mem;
+}
+
+void* mspace_memalign(mspace msp, size_t alignment, size_t bytes) {
+ mstate ms = (mstate)msp;
+ if (!ok_magic(ms)) {
+ USAGE_ERROR_ACTION(ms,ms);
+ return 0;
+ }
+ if (alignment <= MALLOC_ALIGNMENT)
+ return mspace_malloc(msp, bytes);
+ return internal_memalign(ms, alignment, bytes);
+}
+
+void** mspace_independent_calloc(mspace msp, size_t n_elements,
+ size_t elem_size, void* chunks[]) {
+ size_t sz = elem_size; /* serves as 1-element array */
+ mstate ms = (mstate)msp;
+ if (!ok_magic(ms)) {
+ USAGE_ERROR_ACTION(ms,ms);
+ return 0;
+ }
+ return ialloc(ms, n_elements, &sz, 3, chunks);
+}
+
+void** mspace_independent_comalloc(mspace msp, size_t n_elements,
+ size_t sizes[], void* chunks[]) {
+ mstate ms = (mstate)msp;
+ if (!ok_magic(ms)) {
+ USAGE_ERROR_ACTION(ms,ms);
+ return 0;
+ }
+ return ialloc(ms, n_elements, sizes, 0, chunks);
+}
+
+size_t mspace_bulk_free(mspace msp, void* array[], size_t nelem) {
+ return internal_bulk_free((mstate)msp, array, nelem);
+}
+
+#if MALLOC_INSPECT_ALL
+void mspace_inspect_all(mspace msp,
+ void(*handler)(void *start,
+ void *end,
+ size_t used_bytes,
+ void* callback_arg),
+ void* arg) {
+ mstate ms = (mstate)msp;
+ if (ok_magic(ms)) {
+ if (!PREACTION(ms)) {
+ internal_inspect_all(ms, handler, arg);
+ POSTACTION(ms);
+ }
+ }
+ else {
+ USAGE_ERROR_ACTION(ms,ms);
+ }
+}
+#endif /* MALLOC_INSPECT_ALL */
+
+int mspace_trim(mspace msp, size_t pad) {
+ int result = 0;
+ mstate ms = (mstate)msp;
+ if (ok_magic(ms)) {
+ if (!PREACTION(ms)) {
+ result = sys_trim(ms, pad);
+ POSTACTION(ms);
+ }
+ }
+ else {
+ USAGE_ERROR_ACTION(ms,ms);
+ }
+ return result;
+}
+
+#if !NO_MALLOC_STATS
+void mspace_malloc_stats(mspace msp) {
+ mstate ms = (mstate)msp;
+ if (ok_magic(ms)) {
+ internal_malloc_stats(ms);
+ }
+ else {
+ USAGE_ERROR_ACTION(ms,ms);
+ }
+}
+#endif /* NO_MALLOC_STATS */
+
+size_t mspace_footprint(mspace msp) {
+ size_t result = 0;
+ mstate ms = (mstate)msp;
+ if (ok_magic(ms)) {
+ result = ms->footprint;
+ }
+ else {
+ USAGE_ERROR_ACTION(ms,ms);
+ }
+ return result;
+}
+
+size_t mspace_max_footprint(mspace msp) {
+ size_t result = 0;
+ mstate ms = (mstate)msp;
+ if (ok_magic(ms)) {
+ result = ms->max_footprint;
+ }
+ else {
+ USAGE_ERROR_ACTION(ms,ms);
+ }
+ return result;
+}
+
+size_t mspace_footprint_limit(mspace msp) {
+ size_t result = 0;
+ mstate ms = (mstate)msp;
+ if (ok_magic(ms)) {
+ size_t maf = ms->footprint_limit;
+ result = (maf == 0) ? MAX_SIZE_T : maf;
+ }
+ else {
+ USAGE_ERROR_ACTION(ms,ms);
+ }
+ return result;
+}
+
+size_t mspace_set_footprint_limit(mspace msp, size_t bytes) {
+ size_t result = 0;
+ mstate ms = (mstate)msp;
+ if (ok_magic(ms)) {
+ if (bytes == 0)
+ result = granularity_align(1); /* Use minimal size */
+ if (bytes == MAX_SIZE_T)
+ result = 0; /* disable */
+ else
+ result = granularity_align(bytes);
+ ms->footprint_limit = result;
+ }
+ else {
+ USAGE_ERROR_ACTION(ms,ms);
+ }
+ return result;
+}
+
+#if !NO_MALLINFO
+struct mallinfo mspace_mallinfo(mspace msp) {
+ mstate ms = (mstate)msp;
+ if (!ok_magic(ms)) {
+ USAGE_ERROR_ACTION(ms,ms);
+ }
+ return internal_mallinfo(ms);
+}
+#endif /* NO_MALLINFO */
+
+size_t mspace_usable_size(const void* mem) {
+ if (mem != 0) {
+ mchunkptr p = mem2chunk(mem);
+ if (is_inuse(p))
+ return chunksize(p) - overhead_for(p);
+ }
+ return 0;
+}
+
+int mspace_mallopt(int param_number, int value) {
+ return change_mparam(param_number, value);
+}
+
+#endif /* MSPACES */
+
+
+/* -------------------- Alternative MORECORE functions ------------------- */
+
+/*
+ Guidelines for creating a custom version of MORECORE:
+
+ * For best performance, MORECORE should allocate in multiples of pagesize.
+ * MORECORE may allocate more memory than requested. (Or even less,
+ but this will usually result in a malloc failure.)
+ * MORECORE must not allocate memory when given argument zero, but
+ instead return one past the end address of memory from previous
+ nonzero call.
+ * For best performance, consecutive calls to MORECORE with positive
+ arguments should return increasing addresses, indicating that
+ space has been contiguously extended.
+ * Even though consecutive calls to MORECORE need not return contiguous
+ addresses, it must be OK for malloc'ed chunks to span multiple
+ regions in those cases where they do happen to be contiguous.
+ * MORECORE need not handle negative arguments -- it may instead
+ just return MFAIL when given negative arguments.
+ Negative arguments are always multiples of pagesize. MORECORE
+ must not misinterpret negative args as large positive unsigned
+ args. You can suppress all such calls from even occurring by defining
+ MORECORE_CANNOT_TRIM,
+
+ As an example alternative MORECORE, here is a custom allocator
+ kindly contributed for pre-OSX macOS. It uses virtually but not
+ necessarily physically contiguous non-paged memory (locked in,
+ present and won't get swapped out). You can use it by uncommenting
+ this section, adding some #includes, and setting up the appropriate
+ defines above:
+
+ #define MORECORE osMoreCore
+
+ There is also a shutdown routine that should somehow be called for
+ cleanup upon program exit.
+
+ #define MAX_POOL_ENTRIES 100
+ #define MINIMUM_MORECORE_SIZE (64 * 1024U)
+ static int next_os_pool;
+ void *our_os_pools[MAX_POOL_ENTRIES];
+
+ void *osMoreCore(int size)
+ {
+ void *ptr = 0;
+ static void *sbrk_top = 0;
+
+ if (size > 0)
+ {
+ if (size < MINIMUM_MORECORE_SIZE)
+ size = MINIMUM_MORECORE_SIZE;
+ if (CurrentExecutionLevel() == kTaskLevel)
+ ptr = PoolAllocateResident(size + RM_PAGE_SIZE, 0);
+ if (ptr == 0)
+ {
+ return (void *) MFAIL;
+ }
+ // save ptrs so they can be freed during cleanup
+ our_os_pools[next_os_pool] = ptr;
+ next_os_pool++;
+ ptr = (void *) ((((size_t) ptr) + RM_PAGE_MASK) & ~RM_PAGE_MASK);
+ sbrk_top = (char *) ptr + size;
+ return ptr;
+ }
+ else if (size < 0)
+ {
+ // we don't currently support shrink behavior
+ return (void *) MFAIL;
+ }
+ else
+ {
+ return sbrk_top;
+ }
+ }
+
+ // cleanup any allocated memory pools
+ // called as last thing before shutting down driver
+
+ void osCleanupMem(void)
+ {
+ void **ptr;
+
+ for (ptr = our_os_pools; ptr < &our_os_pools[MAX_POOL_ENTRIES]; ptr++)
+ if (*ptr)
+ {
+ PoolDeallocate(*ptr);
+ *ptr = 0;
+ }
+ }
+
+*/
+
+
+/* -----------------------------------------------------------------------
+History:
+ v2.8.6 Wed Aug 29 06:57:58 2012 Doug Lea
+ * fix bad comparison in dlposix_memalign
+ * don't reuse adjusted asize in sys_alloc
+ * add LOCK_AT_FORK -- thanks to Kirill Artamonov for the suggestion
+ * reduce compiler warnings -- thanks to all who reported/suggested these
+
+ v2.8.5 Sun May 22 10:26:02 2011 Doug Lea (dl at gee)
+ * Always perform unlink checks unless INSECURE
+ * Add posix_memalign.
+ * Improve realloc to expand in more cases; expose realloc_in_place.
+ Thanks to Peter Buhr for the suggestion.
+ * Add footprint_limit, inspect_all, bulk_free. Thanks
+ to Barry Hayes and others for the suggestions.
+ * Internal refactorings to avoid calls while holding locks
+ * Use non-reentrant locks by default. Thanks to Roland McGrath
+ for the suggestion.
+ * Small fixes to mspace_destroy, reset_on_error.
+ * Various configuration extensions/changes. Thanks
+ to all who contributed these.
+
+ V2.8.4a Thu Apr 28 14:39:43 2011 (dl at gee.cs.oswego.edu)
+ * Update Creative Commons URL
+
+ V2.8.4 Wed May 27 09:56:23 2009 Doug Lea (dl at gee)
+ * Use zeros instead of prev foot for is_mmapped
+ * Add mspace_track_large_chunks; thanks to Jean Brouwers
+ * Fix set_inuse in internal_realloc; thanks to Jean Brouwers
+ * Fix insufficient sys_alloc padding when using 16byte alignment
+ * Fix bad error check in mspace_footprint
+ * Adaptations for ptmalloc; thanks to Wolfram Gloger.
+ * Reentrant spin locks; thanks to Earl Chew and others
+ * Win32 improvements; thanks to Niall Douglas and Earl Chew
+ * Add NO_SEGMENT_TRAVERSAL and MAX_RELEASE_CHECK_RATE options
+ * Extension hook in malloc_state
+ * Various small adjustments to reduce warnings on some compilers
+ * Various configuration extensions/changes for more platforms. Thanks
+ to all who contributed these.
+
+ V2.8.3 Thu Sep 22 11:16:32 2005 Doug Lea (dl at gee)
+ * Add max_footprint functions
+ * Ensure all appropriate literals are size_t
+ * Fix conditional compilation problem for some #define settings
+ * Avoid concatenating segments with the one provided
+ in create_mspace_with_base
+ * Rename some variables to avoid compiler shadowing warnings
+ * Use explicit lock initialization.
+ * Better handling of sbrk interference.
+ * Simplify and fix segment insertion, trimming and mspace_destroy
+ * Reinstate REALLOC_ZERO_BYTES_FREES option from 2.7.x
+ * Thanks especially to Dennis Flanagan for help on these.
+
+ V2.8.2 Sun Jun 12 16:01:10 2005 Doug Lea (dl at gee)
+ * Fix memalign brace error.
+
+ V2.8.1 Wed Jun 8 16:11:46 2005 Doug Lea (dl at gee)
+ * Fix improper #endif nesting in C++
+ * Add explicit casts needed for C++
+
+ V2.8.0 Mon May 30 14:09:02 2005 Doug Lea (dl at gee)
+ * Use trees for large bins
+ * Support mspaces
+ * Use segments to unify sbrk-based and mmap-based system allocation,
+ removing need for emulation on most platforms without sbrk.
+ * Default safety checks
+ * Optional footer checks. Thanks to William Robertson for the idea.
+ * Internal code refactoring
+ * Incorporate suggestions and platform-specific changes.
+ Thanks to Dennis Flanagan, Colin Plumb, Niall Douglas,
+ Aaron Bachmann, Emery Berger, and others.
+ * Speed up non-fastbin processing enough to remove fastbins.
+ * Remove useless cfree() to avoid conflicts with other apps.
+ * Remove internal memcpy, memset. Compilers handle builtins better.
+ * Remove some options that no one ever used and rename others.
+
+ V2.7.2 Sat Aug 17 09:07:30 2002 Doug Lea (dl at gee)
+ * Fix malloc_state bitmap array misdeclaration
+
+ V2.7.1 Thu Jul 25 10:58:03 2002 Doug Lea (dl at gee)
+ * Allow tuning of FIRST_SORTED_BIN_SIZE
+ * Use PTR_UINT as type for all ptr->int casts. Thanks to John Belmonte.
+ * Better detection and support for non-contiguousness of MORECORE.
+ Thanks to Andreas Mueller, Conal Walsh, and Wolfram Gloger
+ * Bypass most of malloc if no frees. Thanks To Emery Berger.
+ * Fix freeing of old top non-contiguous chunk im sysmalloc.
+ * Raised default trim and map thresholds to 256K.
+ * Fix mmap-related #defines. Thanks to Lubos Lunak.
+ * Fix copy macros; added LACKS_FCNTL_H. Thanks to Neal Walfield.
+ * Branch-free bin calculation
+ * Default trim and mmap thresholds now 256K.
+
+ V2.7.0 Sun Mar 11 14:14:06 2001 Doug Lea (dl at gee)
+ * Introduce independent_comalloc and independent_calloc.
+ Thanks to Michael Pachos for motivation and help.
+ * Make optional .h file available
+ * Allow > 2GB requests on 32bit systems.
+ * new WIN32 sbrk, mmap, munmap, lock code from <Walter@GeNeSys-e.de>.
+ Thanks also to Andreas Mueller <a.mueller at paradatec.de>,
+ and Anonymous.
+ * Allow override of MALLOC_ALIGNMENT (Thanks to Ruud Waij for
+ helping test this.)
+ * memalign: check alignment arg
+ * realloc: don't try to shift chunks backwards, since this
+ leads to more fragmentation in some programs and doesn't
+ seem to help in any others.
+ * Collect all cases in malloc requiring system memory into sysmalloc
+ * Use mmap as backup to sbrk
+ * Place all internal state in malloc_state
+ * Introduce fastbins (although similar to 2.5.1)
+ * Many minor tunings and cosmetic improvements
+ * Introduce USE_PUBLIC_MALLOC_WRAPPERS, USE_MALLOC_LOCK
+ * Introduce MALLOC_FAILURE_ACTION, MORECORE_CONTIGUOUS
+ Thanks to Tony E. Bennett <tbennett@nvidia.com> and others.
+ * Include errno.h to support default failure action.
+
+ V2.6.6 Sun Dec 5 07:42:19 1999 Doug Lea (dl at gee)
+ * return null for negative arguments
+ * Added Several WIN32 cleanups from Martin C. Fong <mcfong at yahoo.com>
+ * Add 'LACKS_SYS_PARAM_H' for those systems without 'sys/param.h'
+ (e.g. WIN32 platforms)
+ * Cleanup header file inclusion for WIN32 platforms
+ * Cleanup code to avoid Microsoft Visual C++ compiler complaints
+ * Add 'USE_DL_PREFIX' to quickly allow co-existence with existing
+ memory allocation routines
+ * Set 'malloc_getpagesize' for WIN32 platforms (needs more work)
+ * Use 'assert' rather than 'ASSERT' in WIN32 code to conform to
+ usage of 'assert' in non-WIN32 code
+ * Improve WIN32 'sbrk()' emulation's 'findRegion()' routine to
+ avoid infinite loop
+ * Always call 'fREe()' rather than 'free()'
+
+ V2.6.5 Wed Jun 17 15:57:31 1998 Doug Lea (dl at gee)
+ * Fixed ordering problem with boundary-stamping
+
+ V2.6.3 Sun May 19 08:17:58 1996 Doug Lea (dl at gee)
+ * Added pvalloc, as recommended by H.J. Liu
+ * Added 64bit pointer support mainly from Wolfram Gloger
+ * Added anonymously donated WIN32 sbrk emulation
+ * Malloc, calloc, getpagesize: add optimizations from Raymond Nijssen
+ * malloc_extend_top: fix mask error that caused wastage after
+ foreign sbrks
+ * Add linux mremap support code from HJ Liu
+
+ V2.6.2 Tue Dec 5 06:52:55 1995 Doug Lea (dl at gee)
+ * Integrated most documentation with the code.
+ * Add support for mmap, with help from
+ Wolfram Gloger (Gloger@lrz.uni-muenchen.de).
+ * Use last_remainder in more cases.
+ * Pack bins using idea from colin@nyx10.cs.du.edu
+ * Use ordered bins instead of best-fit threshold
+ * Eliminate block-local decls to simplify tracing and debugging.
+ * Support another case of realloc via move into top
+ * Fix error occurring when initial sbrk_base not word-aligned.
+ * Rely on page size for units instead of SBRK_UNIT to
+ avoid surprises about sbrk alignment conventions.
+ * Add mallinfo, mallopt. Thanks to Raymond Nijssen
+ (raymond@es.ele.tue.nl) for the suggestion.
+ * Add `pad' argument to malloc_trim and top_pad mallopt parameter.
+ * More precautions for cases where other routines call sbrk,
+ courtesy of Wolfram Gloger (Gloger@lrz.uni-muenchen.de).
+ * Added macros etc., allowing use in linux libc from
+ H.J. Lu (hjl@gnu.ai.mit.edu)
+ * Inverted this history list
+
+ V2.6.1 Sat Dec 2 14:10:57 1995 Doug Lea (dl at gee)
+ * Re-tuned and fixed to behave more nicely with V2.6.0 changes.
+ * Removed all preallocation code since under current scheme
+ the work required to undo bad preallocations exceeds
+ the work saved in good cases for most test programs.
+ * No longer use return list or unconsolidated bins since
+ no scheme using them consistently outperforms those that don't
+ given above changes.
+ * Use best fit for very large chunks to prevent some worst-cases.
+ * Added some support for debugging
+
+ V2.6.0 Sat Nov 4 07:05:23 1995 Doug Lea (dl at gee)
+ * Removed footers when chunks are in use. Thanks to
+ Paul Wilson (wilson@cs.texas.edu) for the suggestion.
+
+ V2.5.4 Wed Nov 1 07:54:51 1995 Doug Lea (dl at gee)
+ * Added malloc_trim, with help from Wolfram Gloger
+ (wmglo@Dent.MED.Uni-Muenchen.DE).
+
+ V2.5.3 Tue Apr 26 10:16:01 1994 Doug Lea (dl at g)
+
+ V2.5.2 Tue Apr 5 16:20:40 1994 Doug Lea (dl at g)
+ * realloc: try to expand in both directions
+ * malloc: swap order of clean-bin strategy;
+ * realloc: only conditionally expand backwards
+ * Try not to scavenge used bins
+ * Use bin counts as a guide to preallocation
+ * Occasionally bin return list chunks in first scan
+ * Add a few optimizations from colin@nyx10.cs.du.edu
+
+ V2.5.1 Sat Aug 14 15:40:43 1993 Doug Lea (dl at g)
+ * faster bin computation & slightly different binning
+ * merged all consolidations to one part of malloc proper
+ (eliminating old malloc_find_space & malloc_clean_bin)
+ * Scan 2 returns chunks (not just 1)
+ * Propagate failure in realloc if malloc returns 0
+ * Add stuff to allow compilation on non-ANSI compilers
+ from kpv@research.att.com
+
+ V2.5 Sat Aug 7 07:41:59 1993 Doug Lea (dl at g.oswego.edu)
+ * removed potential for odd address access in prev_chunk
+ * removed dependency on getpagesize.h
+ * misc cosmetics and a bit more internal documentation
+ * anticosmetics: mangled names in macros to evade debugger strangeness
+ * tested on sparc, hp-700, dec-mips, rs6000
+ with gcc & native cc (hp, dec only) allowing
+ Detlefs & Zorn comparison study (in SIGPLAN Notices.)
+
+ Trial version Fri Aug 28 13:14:29 1992 Doug Lea (dl at g.oswego.edu)
+ * Based loosely on libg++-1.2X malloc. (It retains some of the overall
+ structure of old version, but most details differ.)
+
+*/
diff --git a/src/arrow/cpp/submodules/parquet-testing/LICENSE.txt b/src/arrow/cpp/submodules/parquet-testing/LICENSE.txt
new file mode 100644
index 000000000..d64569567
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/LICENSE.txt
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/src/arrow/cpp/submodules/parquet-testing/README.md b/src/arrow/cpp/submodules/parquet-testing/README.md
new file mode 100644
index 000000000..a11eb6afd
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/README.md
@@ -0,0 +1,19 @@
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one
+ ~ or more contributor license agreements. See the NOTICE file
+ ~ distributed with this work for additional information
+ ~ regarding copyright ownership. The ASF licenses this file
+ ~ to you under the Apache License, Version 2.0 (the
+ ~ "License"); you may not use this file except in compliance
+ ~ with the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing,
+ ~ software distributed under the License is distributed on an
+ ~ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ ~ KIND, either express or implied. See the License for the
+ ~ specific language governing permissions and limitations
+ ~ under the License.
+ -->
+# Testing Data and Utilities for Apache Parquet
diff --git a/src/arrow/cpp/submodules/parquet-testing/bad_data/PARQUET-1481.parquet b/src/arrow/cpp/submodules/parquet-testing/bad_data/PARQUET-1481.parquet
new file mode 100644
index 000000000..614912f63
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/bad_data/PARQUET-1481.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/bad_data/README.md b/src/arrow/cpp/submodules/parquet-testing/bad_data/README.md
new file mode 100644
index 000000000..472865b0b
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/bad_data/README.md
@@ -0,0 +1,24 @@
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one
+ ~ or more contributor license agreements. See the NOTICE file
+ ~ distributed with this work for additional information
+ ~ regarding copyright ownership. The ASF licenses this file
+ ~ to you under the Apache License, Version 2.0 (the
+ ~ "License"); you may not use this file except in compliance
+ ~ with the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing,
+ ~ software distributed under the License is distributed on an
+ ~ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ ~ KIND, either express or implied. See the License for the
+ ~ specific language governing permissions and limitations
+ ~ under the License.
+ -->
+# "Bad Data" files
+
+These are files used for reproducing various bugs that have been reported.
+
+* PARQUET-1481.parquet: tests a case where a schema Thrift value has been
+ corrupted
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/README.md b/src/arrow/cpp/submodules/parquet-testing/data/README.md
new file mode 100644
index 000000000..80674f303
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/README.md
@@ -0,0 +1,58 @@
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one
+ ~ or more contributor license agreements. See the NOTICE file
+ ~ distributed with this work for additional information
+ ~ regarding copyright ownership. The ASF licenses this file
+ ~ to you under the Apache License, Version 2.0 (the
+ ~ "License"); you may not use this file except in compliance
+ ~ with the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing,
+ ~ software distributed under the License is distributed on an
+ ~ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ ~ KIND, either express or implied. See the License for the
+ ~ specific language governing permissions and limitations
+ ~ under the License.
+ -->
+
+# Test data files for Parquet compatibility and regression testing
+
+| File | Description |
+|---|---|
+| delta_binary_packed.parquet | INT32 and INT64 columns with DELTA_BINARY_PACKED encoding. See [delta_binary_packed.md](delta_binary_packed.md) for details. |
+| nested_structs.rust.parquet | Used to test that the Rust Arrow reader can lookup the correct field from a nested struct. See [ARROW-11452](https://issues.apache.org/jira/browse/ARROW-11452) |
+
+TODO: Document what each file is in the table above.
+
+## Encrypted Files
+
+Tests files with .parquet.encrypted suffix are encrypted using Parquet Modular Encryption.
+
+A detailed description of the Parquet Modular Encryption specification can be found here:
+```
+ https://github.com/apache/parquet-format/blob/encryption/Encryption.md
+```
+
+Following are the keys and key ids (when using key\_retriever) used to encrypt the encrypted columns and footer in the all the encrypted files:
+* Encrypted/Signed Footer:
+ * key: {0,1,2,3,4,5,6,7,8,9,0,1,2,3,4,5}
+ * key_id: "kf"
+* Encrypted column named double_field:
+ * key: {1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,0}
+ * key_id: "kc1"
+* Encrypted column named float_field:
+ * key: {1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,1}
+ * key_id: "kc2"
+
+The following files are encrypted with AAD prefix "tester":
+1. encrypt\_columns\_and\_footer\_disable\_aad\_storage.parquet.encrypted
+2. encrypt\_columns\_and\_footer\_aad.parquet.encrypted
+
+
+A sample that reads and checks these files can be found at the following tests:
+```
+cpp/src/parquet/encryption-read-configurations-test.cc
+cpp/src/parquet/test-encryption-util.h
+```
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/alltypes_dictionary.parquet b/src/arrow/cpp/submodules/parquet-testing/data/alltypes_dictionary.parquet
new file mode 100755
index 000000000..e6da6ab7b
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/alltypes_dictionary.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/alltypes_plain.parquet b/src/arrow/cpp/submodules/parquet-testing/data/alltypes_plain.parquet
new file mode 100755
index 000000000..a63f5dca7
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/alltypes_plain.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/alltypes_plain.snappy.parquet b/src/arrow/cpp/submodules/parquet-testing/data/alltypes_plain.snappy.parquet
new file mode 100755
index 000000000..9809d6765
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/alltypes_plain.snappy.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/binary.parquet b/src/arrow/cpp/submodules/parquet-testing/data/binary.parquet
new file mode 100644
index 000000000..fc8c04669
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/binary.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/bloom_filter.bin b/src/arrow/cpp/submodules/parquet-testing/data/bloom_filter.bin
new file mode 100644
index 000000000..c0e30ce74
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/bloom_filter.bin
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/byte_array_decimal.parquet b/src/arrow/cpp/submodules/parquet-testing/data/byte_array_decimal.parquet
new file mode 100644
index 000000000..798cb2aad
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/byte_array_decimal.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/datapage_v2.snappy.parquet b/src/arrow/cpp/submodules/parquet-testing/data/datapage_v2.snappy.parquet
new file mode 100644
index 000000000..2b77bb1e9
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/datapage_v2.snappy.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/delta_binary_packed.md b/src/arrow/cpp/submodules/parquet-testing/data/delta_binary_packed.md
new file mode 100644
index 000000000..27a6df490
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/delta_binary_packed.md
@@ -0,0 +1,440 @@
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one
+ ~ or more contributor license agreements. See the NOTICE file
+ ~ distributed with this work for additional information
+ ~ regarding copyright ownership. The ASF licenses this file
+ ~ to you under the Apache License, Version 2.0 (the
+ ~ "License"); you may not use this file except in compliance
+ ~ with the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing,
+ ~ software distributed under the License is distributed on an
+ ~ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ ~ KIND, either express or implied. See the License for the
+ ~ specific language governing permissions and limitations
+ ~ under the License.
+ -->
+
+`delta_binary_packed.parquet` is generated with parquet-mr version 1.10.0.
+The expected file contents are in `delta_binary_packed_expect.csv`.
+
+Each column is DELTA_BINARY_PACKED-encoded, with a different delta bitwidth
+for each column.
+
+Each column has 200 rows including 1 first value and 2 blocks. The first block
+has 4 miniblocks, the second block has 3. Each miniblock contains 32 values,
+except the last one which only contains 7 values (1 + 6 \* 32 + 7 == 200).
+
+Here is the file structure:
+```
+File Name: /home/antoine/parquet/testing/data/delta_binary_packed.parquet
+Version: 1.0
+Created By: parquet-mr version 1.10.0 (build 031a6654009e3b82020012a18434c582bd74c73a)
+Total rows: 200
+Number of RowGroups: 1
+Number of Real Columns: 66
+Number of Columns: 66
+Number of Selected Columns: 66
+Column 0: bitwidth0 (INT64)
+Column 1: bitwidth1 (INT64)
+Column 2: bitwidth2 (INT64)
+Column 3: bitwidth3 (INT64)
+Column 4: bitwidth4 (INT64)
+Column 5: bitwidth5 (INT64)
+Column 6: bitwidth6 (INT64)
+Column 7: bitwidth7 (INT64)
+Column 8: bitwidth8 (INT64)
+Column 9: bitwidth9 (INT64)
+Column 10: bitwidth10 (INT64)
+Column 11: bitwidth11 (INT64)
+Column 12: bitwidth12 (INT64)
+Column 13: bitwidth13 (INT64)
+Column 14: bitwidth14 (INT64)
+Column 15: bitwidth15 (INT64)
+Column 16: bitwidth16 (INT64)
+Column 17: bitwidth17 (INT64)
+Column 18: bitwidth18 (INT64)
+Column 19: bitwidth19 (INT64)
+Column 20: bitwidth20 (INT64)
+Column 21: bitwidth21 (INT64)
+Column 22: bitwidth22 (INT64)
+Column 23: bitwidth23 (INT64)
+Column 24: bitwidth24 (INT64)
+Column 25: bitwidth25 (INT64)
+Column 26: bitwidth26 (INT64)
+Column 27: bitwidth27 (INT64)
+Column 28: bitwidth28 (INT64)
+Column 29: bitwidth29 (INT64)
+Column 30: bitwidth30 (INT64)
+Column 31: bitwidth31 (INT64)
+Column 32: bitwidth32 (INT64)
+Column 33: bitwidth33 (INT64)
+Column 34: bitwidth34 (INT64)
+Column 35: bitwidth35 (INT64)
+Column 36: bitwidth36 (INT64)
+Column 37: bitwidth37 (INT64)
+Column 38: bitwidth38 (INT64)
+Column 39: bitwidth39 (INT64)
+Column 40: bitwidth40 (INT64)
+Column 41: bitwidth41 (INT64)
+Column 42: bitwidth42 (INT64)
+Column 43: bitwidth43 (INT64)
+Column 44: bitwidth44 (INT64)
+Column 45: bitwidth45 (INT64)
+Column 46: bitwidth46 (INT64)
+Column 47: bitwidth47 (INT64)
+Column 48: bitwidth48 (INT64)
+Column 49: bitwidth49 (INT64)
+Column 50: bitwidth50 (INT64)
+Column 51: bitwidth51 (INT64)
+Column 52: bitwidth52 (INT64)
+Column 53: bitwidth53 (INT64)
+Column 54: bitwidth54 (INT64)
+Column 55: bitwidth55 (INT64)
+Column 56: bitwidth56 (INT64)
+Column 57: bitwidth57 (INT64)
+Column 58: bitwidth58 (INT64)
+Column 59: bitwidth59 (INT64)
+Column 60: bitwidth60 (INT64)
+Column 61: bitwidth61 (INT64)
+Column 62: bitwidth62 (INT64)
+Column 63: bitwidth63 (INT64)
+Column 64: bitwidth64 (INT64)
+Column 65: int_value (INT32)
+--- Row Group: 0 ---
+--- Total Bytes: 65467 ---
+--- Total Compressed Bytes: 0 ---
+--- Rows: 200 ---
+Column 0
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 6374628540732951412, Min: 6374628540732951412
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 95, Compressed Size: 95
+Column 1
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 0, Min: -104
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 114, Compressed Size: 114
+Column 2
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 0, Min: -82
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 144, Compressed Size: 144
+Column 3
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 0, Min: -96
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 172, Compressed Size: 172
+Column 4
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 0, Min: -132
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 200, Compressed Size: 200
+Column 5
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 24, Min: -290
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 228, Compressed Size: 228
+Column 6
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 259, Min: -93
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 256, Compressed Size: 256
+Column 7
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 476, Min: -64
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 284, Compressed Size: 284
+Column 8
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 387, Min: -732
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 314, Compressed Size: 314
+Column 9
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 194, Min: -1572
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 342, Compressed Size: 342
+Column 10
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 5336, Min: -2353
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 370, Compressed Size: 370
+Column 11
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 13445, Min: -8028
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 398, Compressed Size: 398
+Column 12
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 2017, Min: -35523
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 426, Compressed Size: 426
+Column 13
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 48649, Min: -4096
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 454, Compressed Size: 454
+Column 14
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 65709, Min: -8244
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 482, Compressed Size: 482
+Column 15
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 69786, Min: -106702
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 512, Compressed Size: 512
+Column 16
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 162951, Min: -347012
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 540, Compressed Size: 540
+Column 17
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 0, Min: -1054098
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 568, Compressed Size: 568
+Column 18
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 664380, Min: -372793
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 596, Compressed Size: 596
+Column 19
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 4001179, Min: -402775
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 624, Compressed Size: 624
+Column 20
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 788039, Min: -4434785
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 652, Compressed Size: 652
+Column 21
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 12455554, Min: -1070042
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 680, Compressed Size: 680
+Column 22
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 2189135, Min: -17987827
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 710, Compressed Size: 710
+Column 23
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 25967351, Min: -19361900
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 738, Compressed Size: 738
+Column 24
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 95688064, Min: -17271207
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 766, Compressed Size: 766
+Column 25
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 169215083, Min: -18759951
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 794, Compressed Size: 794
+Column 26
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 163626565, Min: -168761837
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 822, Compressed Size: 822
+Column 27
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 131734874, Min: -736933601
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 850, Compressed Size: 850
+Column 28
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 913547745, Min: -490714808
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 878, Compressed Size: 878
+Column 29
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 500305035, Min: -5834684238
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 908, Compressed Size: 908
+Column 30
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 566280334, Min: -7728643109
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 936, Compressed Size: 936
+Column 31
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 18831788461, Min: -2498101101
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 964, Compressed Size: 964
+Column 32
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 23720914586, Min: -2147483648
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 992, Compressed Size: 992
+Column 33
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 24075494509, Min: -4817999329
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1020, Compressed Size: 1020
+Column 34
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 12118456329, Min: -156025641218
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1048, Compressed Size: 1048
+Column 35
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 41614351758, Min: -114682966820
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1076, Compressed Size: 1076
+Column 36
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 82484946621, Min: -244178626927
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1106, Compressed Size: 1106
+Column 37
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 361459323159, Min: -275190620271
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1134, Compressed Size: 1134
+Column 38
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 1665294434042, Min: -420452598502
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1162, Compressed Size: 1162
+Column 39
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 110454290134, Min: -2926211785103
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1190, Compressed Size: 1190
+Column 40
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 3215717068302, Min: -4988823342986
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1218, Compressed Size: 1218
+Column 41
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 2166086616318, Min: -6488418568768
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1246, Compressed Size: 1246
+Column 42
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 10182365256028, Min: -8738522616121
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1274, Compressed Size: 1274
+Column 43
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 22909885827147, Min: -21214625470327
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1304, Compressed Size: 1304
+Column 44
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 67902133645749, Min: -9796939892175
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1332, Compressed Size: 1332
+Column 45
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 199494208930939, Min: -102473613757961
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1360, Compressed Size: 1360
+Column 46
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 18564971260296, Min: -359696498357610
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1388, Compressed Size: 1388
+Column 47
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 65624006999260, Min: -933995610201533
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1416, Compressed Size: 1416
+Column 48
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 983500521840940, Min: -878019827431629
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1444, Compressed Size: 1444
+Column 49
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 975533803684560, Min: -2091164446177739
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1472, Compressed Size: 1472
+Column 50
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 1276327559487856, Min: -5741928190724373
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1502, Compressed Size: 1502
+Column 51
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 0, Min: -15996275819941210
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1530, Compressed Size: 1530
+Column 52
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 12697545666077932, Min: -8823113595895130
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1558, Compressed Size: 1558
+Column 53
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 4785870085681342, Min: -24800000653307089
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1586, Compressed Size: 1586
+Column 54
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 26202576654140994, Min: -94647392931900711
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1614, Compressed Size: 1614
+Column 55
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 358302517069012889, Min: -32197353745654772
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1642, Compressed Size: 1642
+Column 56
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 678791154000627912, Min: -36028797018963968
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1670, Compressed Size: 1670
+Column 57
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 1444945950888122232, Min: -79246600304853010
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1700, Compressed Size: 1700
+Column 58
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 140747723990970254, Min: -1492687553985044679
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1728, Compressed Size: 1728
+Column 59
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 6318360990909070, Min: -3778424577629102559
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1756, Compressed Size: 1756
+Column 60
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 4574162334421819801, Min: -576460752303423488
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1784, Compressed Size: 1784
+Column 61
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 8803535686130338880, Min: -1155450847100943978
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1812, Compressed Size: 1812
+Column 62
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 9026687750017193101, Min: -4454039315625288390
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1840, Compressed Size: 1840
+Column 63
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 9150047972721273816, Min: -9220123451143279334
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1868, Compressed Size: 1868
+Column 64
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 8846115173408951296, Min: -9223372036854775808
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 1898, Compressed Size: 1898
+Column 65
+ Values: 200, Null Values: 0, Distinct Values: 0
+ Max: 2142811258, Min: -2078683524
+ Compression: UNCOMPRESSED, Encodings: DELTA_BINARY_PACKED
+ Uncompressed Size: 980, Compressed Size: 980
+```
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/delta_binary_packed.parquet b/src/arrow/cpp/submodules/parquet-testing/data/delta_binary_packed.parquet
new file mode 100644
index 000000000..4bb56e90e
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/delta_binary_packed.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/delta_binary_packed_expect.csv b/src/arrow/cpp/submodules/parquet-testing/data/delta_binary_packed_expect.csv
new file mode 100644
index 000000000..60a241383
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/delta_binary_packed_expect.csv
@@ -0,0 +1,201 @@
+bitwidth0,bitwidth1,bitwidth2,bitwidth3,bitwidth4,bitwidth5,bitwidth6,bitwidth7,bitwidth8,bitwidth9,bitwidth10,bitwidth11,bitwidth12,bitwidth13,bitwidth14,bitwidth15,bitwidth16,bitwidth17,bitwidth18,bitwidth19,bitwidth20,bitwidth21,bitwidth22,bitwidth23,bitwidth24,bitwidth25,bitwidth26,bitwidth27,bitwidth28,bitwidth29,bitwidth30,bitwidth31,bitwidth32,bitwidth33,bitwidth34,bitwidth35,bitwidth36,bitwidth37,bitwidth38,bitwidth39,bitwidth40,bitwidth41,bitwidth42,bitwidth43,bitwidth44,bitwidth45,bitwidth46,bitwidth47,bitwidth48,bitwidth49,bitwidth50,bitwidth51,bitwidth52,bitwidth53,bitwidth54,bitwidth55,bitwidth56,bitwidth57,bitwidth58,bitwidth59,bitwidth60,bitwidth61,bitwidth62,bitwidth63,bitwidth64,int_value
+6374628540732951412,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,-2070986743
+6374628540732951412,-1,-2,-4,-8,-16,-32,-64,-128,-256,-512,-1024,-2048,-4096,-8192,-16384,-32768,-65536,-131072,-262144,-524288,-1048576,-2097152,-4194304,-8388608,-16777216,-33554432,-67108864,-134217728,-268435456,-536870912,-1073741824,-2147483648,-4294967296,-8589934592,-17179869184,-34359738368,-68719476736,-137438953472,-274877906944,-549755813888,-1099511627776,-2199023255552,-4398046511104,-8796093022208,-17592186044416,-35184372088832,-70368744177664,-140737488355328,-281474976710656,-562949953421312,-1125899906842624,-2251799813685248,-4503599627370496,-9007199254740992,-18014398509481984,-36028797018963968,-72057594037927936,-144115188075855872,-288230376151711744,-576460752303423488,-1152921504606846976,-2305843009213693952,-4611686018427387904,-9223372036854775808,-22783326
+6374628540732951412,-1,-1,-1,-5,-15,-3,-3,-65,-135,-151,-884,-498,-541,-2357,-13235,-5032,-53240,-29011,-250328,-472084,-14542,-1696277,-1840215,-3928478,-7271486,-21651755,-32837885,-111433329,-241601040,-50065479,-163270799,-1986952430,-4199967586,-1344496374,-16288562574,-28644417946,-65240709120,-70538406032,-191624370730,-296605683166,-1011770835153,-1608485630598,-4014593275381,-6105818392192,-16882699520562,-24806559006505,-2633546645122,-117376993649456,-108462505318873,-17587725526794,-404921059206593,-156565041626495,-314837785970473,-5814847697611025,-12226190471490179,-25585079015954934,-15550441703940602,-56208863042127981,-90038694817934139,-119581030183793765,-506045522191644725,-608563416973139311,-268654784870221847,-725202170854031360,-1782018724
+6374628540732951412,-1,-3,-3,-3,-14,19,9,-70,-291,-145,-275,-779,-4047,-8244,-9036,19242,-48154,57337,-76872,-140006,-606577,-1112970,-1741237,-9834470,-18759951,-6702647,-48746571,-75629270,-261813622,-372076116,-209790348,-1454655246,14148394,3594034401,-22765140849,-37447824121,973131962,-184009634540,-414171928390,-33166024765,-1197833965973,-2945889030754,-4176475312470,-3485626798002,-33267360685874,-30046841485160,-44798370319938,-169294344644831,43524599724868,469339429031087,-249413706008860,1783640864794895,2532719170052087,-6021542079991477,-22672611942859889,1385417775479117,25457797741384317,15212409470706824,-293860965757446445,-421796672552175066,-739907890298874818,-359496579899035406,1767081710130543267,-9103419799827896320,-795597708
+6374628540732951412,-1,-2,-7,-7,-23,-4,51,-107,-322,-69,-404,-497,-1478,-1446,3148,-5651,-72427,2324,-180065,-73164,-1070042,-2913370,-2556489,-13152247,-18307794,-21400650,-21970468,-49244077,-154493129,-865963106,-675026787,269282718,-1936781439,11152842629,-8330016935,-20375075134,4843904084,-154956898562,-348410255803,-57366930449,-1692439856605,-1475642885659,-8535248964530,-6025906219756,-16857628299601,-62685577005613,-27221312319071,-198785994498891,171238115176440,660593922279105,-943597583604423,2827707839263247,4070733887796967,1851107790825924,-10035362931120011,-34178564238813898,-15400877254369937,14534187603609527,-436091502934230264,103753242946746509,-433954430332814130,-2277734064379320830,-1952725388208511715,-598707373574004736,-50404127
+6374628540732951412,-1,-1,-6,-7,-21,-29,58,-174,-370,147,-303,-257,-1037,5630,8895,-35608,-117842,-12437,-27731,60486,-707663,-2929145,-2079255,-17271207,-1875268,-19770821,-80161864,60023242,-421079858,-1199795687,-1596156003,-1428636547,-3315076943,4309884401,-1119928926,-7573364471,-29439686716,-77734727780,-381153221847,438498389411,-1075668472404,688004155483,-12155875302808,-2421765938549,-1412459333102,-57897024522480,-67231976161515,-102178629353955,200912794242206,437990823358964,-691145191610179,3467057874929278,985930545990978,7330199804693107,-15787463490118793,-3529811048518225,-79246600304853010,131968846278002895,-282604809562268943,393659097234343208,-117358901750833047,-4454039315625288390,1057622537212765505,-467456400607547392,-1324028940
+6374628540732951412,-2,-1,-7,-8,-11,-4,102,-90,-428,640,248,-666,1697,12168,11689,-47454,-130484,20117,-75384,-93320,203268,-4986664,-2162578,-8995125,-6760008,-32244703,-106596127,177836390,-266837798,-762866147,-2468267208,389760980,-992988241,7180141332,14885016605,-1511637873,37404525522,-197046149773,-187479057508,230274097668,-927050760313,1799760668852,-10729406875828,-9796939892175,11804765938550,-50966479811265,-44073994829259,-96429064661339,286836747775733,826008137253651,-819172490150060,2491046898575225,371401587078797,-573582956573021,-22424390824700858,8968658959701591,-18708944194978315,140747723990970254,-456990195146363428,195559856224514923,-58121504121408814,-3938402690046704430,-2644179234690480376,-8590184427836801024,1224303596
+6374628540732951412,-2,0,-5,-4,-23,15,78,25,-298,323,-656,-335,1998,18308,10291,-24779,-168487,52710,163729,312623,1124572,-6004940,365099,-10397375,8574828,-63826718,-60362169,256227242,-3856403,-975832397,-2498101101,-880256444,-2948367208,6203611716,19687928212,-9781230924,8614389461,-181589169937,-139212010426,371875472672,-1174473260226,470808489818,-7350458341287,-1895363706255,10243781943563,-21934176136905,11986018548423,-38783203690109,58988343681595,919985218147182,-1879216418388626,785282571415803,-1798356971526639,4083217559270234,-20622206685995410,34030617276605758,44907398151239119,85907244363722199,-200542403147528933,144951121694234471,94062120218954510,-2429296149216326625,-5213688279934623354,-8082003127769759744,1429112635
+6374628540732951412,-3,-1,-4,-11,-28,32,106,-73,-445,672,255,598,1099,21048,16342,-20646,-231878,84900,91072,713147,2048208,-5340953,2871160,-3849966,13201519,-72225014,-126474549,167271468,-265832301,-1316680326,-1662545575,776304783,-2552033190,-1183286497,18537372866,-15878307818,-41899616987,-89943248617,-399408328189,820052639502,-1184279940837,444515555337,-7255722839195,6632009472646,15949617801314,-28984673820583,62072552514313,-154483198313094,176341372647806,1276327559487856,-1225848427310965,925797371224733,-5158669449317863,-2507988839360628,-27478929288360382,23092989509606526,58254636827550401,6355344246857438,-416555529995365351,311878861125735415,2030877170934473,-520036700878024894,-4146332965026751351,-3911694380420096000,834042975
+6374628540732951412,-3,-1,-3,-9,-43,26,151,-197,-429,570,922,-520,4352,19030,14737,-20869,-232220,38560,303791,315091,1936025,-4881785,-806807,-1604961,14823469,-52806631,-103530652,168599986,-150605862,-844930571,-1108215124,2709901091,-1419850009,5502072164,34311620738,-14853235224,1912837594,-203072339472,-283401907392,501037634663,-1166430618400,1320201073039,-9026458927213,15062471783783,25751297219542,-33368654655909,65624006999260,-131585197525211,33216777823211,781671457575793,-1708360133018130,-702483609363014,-5589580148509433,2362753843908645,-25547339482354497,53583587186612101,64202458217750873,-92289171240387763,-446963789889191093,826714302131941278,-836855863139878147,492305742138744518,-2349617351498839192,2019598164191703040,2046362238
+6374628540732951412,-4,-1,-7,-8,-52,-3,136,-269,-351,204,173,-2031,7542,26111,8591,9313,-183842,87325,156974,788039,2272401,-4139456,-3179991,-9334829,30019285,-78394579,-150140121,67683155,-25564796,-1170509324,-827263943,1241753895,-856982445,1549831302,41614351758,-2854031612,37186447971,-209264261928,-196576495532,65581139599,-565237868852,-27181425997,-8140784131223,20921681146776,10912291877720,-3031050085521,36447651522370,9004339030065,-42551590553472,843011773130246,-987900361291078,-1159122838303987,-3103971948799911,-745754957166194,-28422076921040800,35435080255190746,67507706621992576,-162164502488186937,-238844968946687909,1085551322902928420,-750441688166605150,292109076211582142,-6091285987971363204,1264186545259435008,-153007359
+6374628540732951412,-4,-2,-6,-11,-45,-6,194,-378,-221,691,-767,-2336,5464,20543,16606,4413,-242365,61023,27624,381755,1328223,-4063731,-7210880,-15643397,25006907,-54101392,-180677852,-60475857,-172210328,-1258660760,-183835141,1409848176,359114787,1726104739,39972878265,-5574976223,50385356686,-326608050977,-114130189114,237653488482,-1041584369602,717701933094,-10027343371185,22321763710601,20394761493740,5491020402777,64940797932572,122936888914550,-141679511062404,785008273684721,-1843069583722744,-1346066034615013,-4985979727434387,-9419333174161675,-18025836127805424,14789131126627303,117949773758419237,-110224478645965308,-24575572426686944,985411710366330475,-724098107667330840,-1293865679598679971,-8435258790141432536,-7349077577920466944,1051233348
+6374628540732951412,-4,-2,-8,-12,-48,15,184,-460,-81,251,-122,-1453,7784,21735,8660,24604,-221858,183825,203235,21300,482590,-5463685,-5342146,-8265414,36887959,-33434707,-128286361,3262009,-151978673,-960520618,279184899,2538798575,-1265526613,3726788297,24506739045,26428010315,44446136672,-419215517257,-163619049256,457258611444,-406557459633,-796142081746,-10243623720779,19089143898331,35568348582543,14490484773703,41306150210456,116003021152398,-81382130315426,254138347978205,-2915860617063731,-1231506759575816,-6702514319070943,-12413603378500972,-28457876542639136,755256011086418,184751697689979146,-169688843895160882,-77122821522994119,1501442858855176059,-50471319400792401,-325175908277318769,-7252655540844764416,17097750381551616,210007250
+6374628540732951412,-5,-3,-12,-16,-34,43,183,-351,51,66,-709,-204,10563,19194,13801,-4729,-259171,134222,130879,207048,1366368,-3822822,-4739546,-6954400,51827180,-24975715,-175712930,34807814,-197864177,-1399742781,1236190915,1633477006,1945899508,5809365546,27986070347,36166072626,78394978013,-342607723169,110454290134,-82650350547,-1053697599636,-337619588982,-8602981862148,26487377084827,19291311473788,18564971260296,25346102357092,127195492574191,75756773172574,90863646384293,-2259980444104873,-2609233950205194,-6895991431055500,-3631917991776925,-32197353745654772,22067561152949294,184073407071804587,-283677461028711535,-15443749445198100,1533792676030418717,-116049510484975592,616057014773691101,-8957145411688676723,5961558049880928256,-1817882083
+6374628540732951412,-5,-5,-11,-11,-36,17,188,-369,-83,-88,-1112,-331,8027,25335,3146,-37453,-288979,134856,-118943,57541,2346101,-4000344,-7297616,-2888212,61535280,6230976,-128400592,-94491126,-126667984,-977930788,2248165546,1630382608,4238653570,4067672939,16165293331,23660827184,16622401717,-304622135824,-129706910361,-392968647967,-54049523633,-1484979317482,-5326191138472,22152796125967,13180333220852,682079674370,-33644269875469,31261039474359,66636843103768,480987092292444,-2330628995264461,-2581213163271129,-8411395908554576,-3660781120922581,-15214738516782829,57686961596767129,235233019872249182,-319015277570937246,6318360990909070,1462381756863886818,360371993800348532,1578197976097379657,-6130027235446772267,1461822930334877696,220205244
+6374628540732951412,-6,-7,-11,-13,-34,-13,203,-255,145,-39,-774,507,10738,31029,4496,-13567,-272491,228921,-55161,-200429,2383928,-4972456,-3265353,-5130609,48002890,-5161856,-168289205,-71711182,-247244088,-706223501,2742860997,3324113516,2471315729,-1927418155,21112511291,48354160341,27196068163,-244274773926,-25856549636,-254454962946,1040185696684,-1712710383630,-5616874798653,22622515352035,6846895600224,14285840937912,-47792518855222,1910771044941,-44422979644211,965456312943592,-2362758329552222,-1833886426438552,-8491675275163907,-10034900047112753,1294695905928017,69438008375973772,296455797182964681,-461812148548489539,-152142795067870614,1666014290688223233,390213064730011165,3167985562120293098,-9220123451143279334,3726157112338110464,82429627
+6374628540732951412,-6,-7,-11,-21,-47,13,183,-182,135,140,-1017,2017,6917,24795,17912,1664,-324152,310180,-148535,-568423,2018144,-6275527,-5413837,-3752353,53826389,-18130804,-220232917,30029037,-134469831,-248795173,3479351773,1952067057,-1620707824,5024643918,17682641333,82484946621,-39099521050,-217472849943,-147556293200,259085378020,2120897038470,-1746422614133,-1483307040992,19077897884413,-380739342863,-18049761857411,-79473547495057,-77742730744717,60767719291339,629838582980701,-3006652892076308,-2889327032596333,-4598982272107974,-16572845775814946,-8391381010490785,98021515791900951,228083297569685478,-459855145020909508,-350303676973690668,1384676125802549361,-379235671452725086,4947000944862738882,-6195030935369173820,-4494873081776057344,702155563
+6374628540732951412,-6,-6,-8,-29,-40,-14,214,-188,194,-279,-1623,1463,9935,29722,18081,-23177,-336960,416841,-47697,-162188,3062057,-5102894,-7773642,-3394421,56702732,-8333926,-270462406,-92340512,24455474,-14764520,3424410295,2279700482,-106210289,2126282317,3894081107,76766210637,-70459672031,-114650718198,-394936601262,495494915946,1242889596156,-2699892005051,-5208119149999,25244760306304,-13653179193405,-3993310646291,-53603136741176,50649356933026,182976060585478,94017047965540,-4036907043267656,-999326825977612,-7412685455461500,-9571854190591394,3175869249099665,132067903992614990,257759141497873824,-577566623834731614,-553576334223816830,1884070233151006445,581591214091267573,6258135181772316562,-2319688758903463783,-1424218481538575360,1911942950
+6374628540732951412,-6,-8,-10,-22,-29,3,273,-209,34,-475,-612,604,10333,30688,5129,-9078,-279131,541212,193376,-669756,3749363,-6208273,-5215874,-2526744,54709238,-32834249,-266454902,-55841137,179501559,515579198,3004095619,4353264343,-1745035242,8027806589,-4473603776,54862416602,-25193005383,13113868685,-431671473367,305819879998,2166086616318,-2020698400740,-4993004219332,31872833909537,-22918334444952,-38482433529913,-121353638315875,-63379425113142,62907114751970,-447673261879456,-4335762216711756,-3233987753437520,-11195252519669560,-1384461551943916,11368852892270041,151090287387553419,267101516573481406,-549575652310474165,-686912809629082843,1971926527489838425,828578743263410397,8041214653566045388,-5005505913581152564,-5821039515811758080,-905379917
+6374628540732951412,-7,-9,-13,-17,-15,-17,216,-256,-101,-282,-7,-1086,6482,35438,-11017,23415,-262156,593730,405272,-877996,3873681,-6830441,-5274200,2493144,62720416,-21870765,-277301229,66252901,405243011,566280334,3377744777,5652709694,1315439537,12118456329,-20903199284,21161439999,-89220532851,74388733865,-654101126849,569798028789,2041406648205,-791527054543,-4489949735744,36270235012165,-37442516489125,-27662545371525,-167329002922621,23195875336988,205261825765061,-927699251907591,-5422451385958253,-5130245430716973,-7985830625292740,7019785753428889,9416182713810714,172181767568916677,273947917764265087,-437044298514736582,-884248574882876627,2512640667946386373,-43405003765765305,8231151728822495711,-4568379919163274696,-5137630757571825664,-1030925156
+6374628540732951412,-7,-11,-15,-15,-3,-43,198,-255,-106,-514,-888,-679,8368,28919,-12598,42101,-249917,642756,209826,-1399006,4757968,-8132876,-2582997,5639371,77060564,-15790245,-307133087,119234000,272584998,426700004,2907033078,5756061491,-2907710893,8092897899,-3888807941,-5313371742,-64459862894,-39259230388,-615214276119,188035061342,1065998312584,-2387917451040,-3513505231567,36648724688893,-53474604321052,-38127628843154,-143430021932474,131180796319771,58189125995042,-1339895145639924,-5989443249041422,-6343561513426457,-3557123276292794,12756985501257801,13606512727540182,203430540769661668,268769334053725175,-501936925097179159,-742841835511873883,1966079345590272135,-189669085635190385,6836769423491319718,-1555569406885449218,-3440069589541775360,448016346
+6374628540732951412,-7,-10,-13,-21,11,-55,164,-163,93,-900,-1689,-1982,5702,24966,797,40053,-205814,524000,254235,-1681473,5316902,-10084007,-512193,1600394,64643127,-28292930,-300590287,197806232,74650755,2609645,2870804347,6346517724,-4181582725,6589106581,-3355235832,18526361405,-92843541365,-144004661460,-553509135314,533651126176,15797993759,-2789670591291,-273080585072,31934975430349,-54238301917814,-15492375971963,-174062952951388,258109393018223,159158144615396,-1318695021861251,-6213870353171897,-5535461574234650,-1225646822383584,18798752987688257,17648440030659685,193947526478255070,308538978966923002,-398592995170696524,-883721907880047266,2207688497788404302,-535893864296582626,5206804591582589346,-5600099273922787095,4835791230042253312,-1069926607
+6374628540732951412,-8,-11,-16,-21,24,-36,179,-184,83,-1338,-1339,-2276,2834,21429,-4715,21535,-263200,474253,263017,-2036880,6052891,-9654337,674651,7173212,55638712,-48630025,-297460073,126491345,-190928965,-208406241,3812322220,4768586569,37144455,4849257923,-5864400307,-15789487221,-88700197099,-169050461282,-584635350404,63886471587,126602314151,-4604332216853,2643257839695,24523469167362,-37048449461483,-48984649210139,-186613861658060,146973411695559,144814412567366,-1317073504614932,-6254855905240751,-5951355776843536,-1685869425988856,15755297163469064,24133668071218416,195850668106291840,269093073060965243,-461239849008195924,-734860787648762402,2641298396212742760,132937022314482838,7333681238432168342,-2059030261806018300,-318225084124536832,1577807398
+6374628540732951412,-9,-10,-15,-25,15,-55,223,-178,55,-1781,-403,-2462,5660,29134,-56,-425,-298945,556026,369783,-2539443,5224121,-9938522,-2347713,9824244,69618452,-28785876,-269793693,195291497,-164968678,-696511018,4349893233,6673277082,3622273031,12096664539,-22419880956,-13707074339,-37937180892,-191145028587,-385395971689,524773601806,-206247158342,-3901645285313,1753450037209,32427913726542,-49009370782132,-76331674636789,-142701056393917,153563878473207,209274116695192,-1233007473695194,-5859526505388317,-7729727857937996,-4983342825287396,21747107949988642,37417011818772190,210379610089757935,216591967140164379,-490273296904166706,-796701247476686999,2901152837793554559,1047827036544931557,5577896658935768353,-3795511167352361300,-5297612974525495296,121762752
+6374628540732951412,-9,-10,-14,-18,19,-69,256,-138,-99,-1839,-1386,-2599,7715,34957,8959,-23568,-358293,570914,478701,-3015743,4770668,-9985638,-2262731,4738730,69386227,-56573733,-295696709,294332071,-176965370,-913647424,4574141473,6426702734,5955404,11832841384,-11876892203,-30427132825,2016350265,-85234332419,-242368677502,244358389804,-1223700769473,-2703247922495,1407882774781,26096732853006,-41800310846867,-45842994146548,-169962267118368,222746400316687,267118110669187,-747040292680090,-6390589589280166,-6599361727142087,-7286790358028413,21689742493884593,54896452125498397,198277800139263978,160075559716653835,-355436471220344016,-589927812848240945,3194841189277626776,1895524699585408548,6457179915967248253,-2453532199460285291,-2748579601026834432,1157398905
+6374628540732951412,-9,-9,-18,-17,23,-38,296,-98,-305,-1795,-770,-1357,9652,27937,-3122,-31972,-359322,489024,612671,-3212826,5166521,-9622805,-1302732,8459983,71678200,-84235931,-329328072,255414269,1011330,-626227712,3759762688,7385666091,-21260882,7746504192,-22011352668,-30669943268,-36830496533,-80211721825,-153164905498,-47663517347,-305576756157,-1437820045902,3492460433651,18760046044602,-46032004359799,-29475948798842,-177589447364285,288274561418854,-13169928785126,-817999285542971,-5530087477425067,-8557308794603521,-6532431888823801,13774065038201723,65730886240147813,208013450238313757,153786904441469635,-381545634802663893,-657974171030196652,2769549824608328743,2858127770648797317,7050027473204189574,998329856251766612,37497956641886208,-159149608
+6374628540732951412,-9,-10,-19,-22,9,-24,317,-163,-333,-1581,-1151,-1031,12384,32438,-2852,-26220,-412192,521641,366379,-3087723,5313980,-8043179,-4094622,8139868,84929057,-53660364,-316905627,151449620,-188265758,-369457719,2754276670,9275949257,1592825001,3448408297,-28911853258,-54969826220,-63182733617,-164249919271,-318225448358,-527102425213,-961171142140,-3294309119656,-47704534592,15229468694543,-62168491952535,-46946091253840,-229420124112813,169667836867750,-94418288240333,-1272730549559843,-6547627218884509,-8823113595895130,-3832177918342639,19392550704699610,80703572233376554,188240194091383092,177718844298696452,-413394525389424589,-600966474218437774,2928301575356738642,1995294379347026019,8177402787881295676,1923416419787834329,4168656297356431360,-1086596487
+6374628540732951412,-10,-12,-20,-23,5,-49,363,-65,-282,-1534,-137,-1476,15089,30926,-9272,-44728,-397026,531332,278981,-3169282,5796530,-6708287,-2723017,5119126,97103539,-63514661,-287042395,228697652,-405937451,-169725998,2885276424,8809000259,828914763,-1062998081,-30108971520,-66398964789,-120866470593,-194313610030,-219778964664,-241892901676,-271698577002,-1709155762212,3292600241831,13912331400544,-64884725198252,-66876064391487,-232499874890819,171987178850940,174976204232292,-1550752855883718,-6646605530794593,-7446173126356457,-3530313888312526,12301943031809921,80435299812069038,200950139138545565,206871310705383192,-482086455947559342,-767150788140729332,3493755165742407231,1176679341705998418,6370479215378955155,29678279876942237,7526432242651009024,-349032759
+6374628540732951412,-10,-13,-17,-26,-6,-32,340,54,-366,-1615,41,-1753,19038,29954,-5948,-48315,-365462,407814,239015,-3094137,5480318,-4901397,1337188,10210256,83430061,-90243676,-295369285,290529614,-163868264,-700595007,2935500970,8539422859,2986364096,-9625291508,-30674563961,-73107224309,-185769681407,-138115913131,-222717662927,24037455568,-1030687165663,-36893096303,6216440537611,21810299824350,-53951101649165,-97226015005635,-274559582547469,292578890262204,381821871557391,-2037751041932764,-7420163114621595,-6344666642891020,-1856424323285978,19606415099164595,69415616959700911,176820092353864277,265231284332579624,-606989810214179494,-669184506905165901,3092840321466921120,115018389510852727,8610733155318618797,-238587226421700120,1545198226872764416,644234840
+6374628540732951412,-11,-15,-19,-21,-20,-9,331,32,-620,-1409,-574,-3400,20360,37714,-4262,-77463,-405813,517230,401111,-2833695,5364747,-4435979,4390475,1995277,94204816,-83311910,-305956971,286578945,-419534631,-421501262,3226787361,7524807864,3494728539,-9921914525,-40985591149,-39345249088,-187496242498,-258174211386,-337185205680,328060019770,-431642369017,1232846212359,2293664945920,22405969701330,-67770392771882,-75892284622158,-321229647287550,307640402052568,536641693649011,-2508895275420143,-8447383531388522,-4209569150590663,-4539358493789232,26202576654140994,57840215874100081,143523501948146984,322186899454866783,-558070341236042273,-672132905550983258,2783062030436588563,-381922776183626073,6985997114914243383,-2477356827895347551,4401709278912694272,-1216197075
+6374628540732951412,-12,-14,-16,-16,-11,-4,297,156,-432,-1437,-1199,-4715,23745,42635,-11966,-85876,-435397,612151,573168,-2716063,4738391,-3753829,2870222,8443722,92697983,-78130665,-296240162,154468080,-536802736,-251053743,3827561206,6551584771,1880365107,-9821258722,-51929147032,-67798911689,-187116767930,-379957778510,-286167687715,421096378873,-1190649799022,-398612340807,-360401622929,19622261262665,-85227672225366,-46500752457323,-268697676872730,238296976462905,754661634821101,-2162916078253437,-8404025640956742,-5733313778499498,-5876548164553541,17634193130132256,73846454642453579,171017461123053813,387502958992869216,-466907314002110320,-606947866988752185,3103399792089889875,205338861488470052,7928934054248201212,-3192191356225298962,-1867070008845563904,-1937155996
+6374628540732951412,-13,-15,-13,-15,2,-19,357,283,-292,-1308,-344,-3469,26381,44788,-24243,-80060,-479867,568313,614816,-2743216,4207025,-3630400,4095406,11603039,94176581,-109409579,-312242223,265521868,-298945090,-15363098,3108738835,6257631615,2050878870,-13212138142,-66317841094,-70711023448,-166759463499,-420452598502,-375134143619,362966227096,-901178967176,1034669844683,-3777608321125,22205171028681,-94314483461731,-67954301284810,-306528189011690,174892198922400,520327929093463,-2374589746001296,-7774723386722536,-3660443724706266,-3183083770534717,18272112102530356,74628335593744450,206674219763050683,418392384849370438,-474818763353870275,-818405703708370529,3117786407273818183,-464529999663020490,6077167840762513290,-6492852893620660061,-8460822699315977216,-911957403
+6374628540732951412,-14,-17,-12,-13,2,-31,370,163,-105,-1491,-837,-1948,22490,51249,-19897,-62809,-436872,508504,730999,-2818572,4891261,-5495816,2885545,18394560,101808346,-80442726,-274140642,161483490,-250908727,32791188,2094506038,5475946569,568191868,-21241411973,-73134131575,-40775329385,-149678044746,-355497787449,-619234644309,366741620165,-1739548665098,2582363510672,-320409855812,25169606226543,-83070692972088,-41232744075945,-325127238951542,257604126570703,415042372246060,-2472664801900255,-8369691428055968,-3588008955054592,-3078176613865274,10474627898002339,80968412144568983,207068867455615641,397168568230954559,-385150639796247728,-578729451978929707,3542071359603439961,144778666441948564,4303415959865918492,-6178938366746460898,-416656101233648640,-1167656573
+6374628540732951412,-15,-16,-14,-8,16,-52,314,104,-34,-1387,-49,-3708,24171,46571,-25347,-38753,-452572,535356,663732,-2939679,4841781,-5696384,2006407,18055904,100535280,-88873923,-303770991,155626488,-429503792,-342879674,1070962238,5498529597,-2580875907,-21767644630,-82362127294,-10175733405,-217195419213,-262800168303,-643706233426,749434404795,-1334108868611,4575973462830,2069167698904,33307177785424,-92604536532341,-18903459172202,-268900243802667,361213288385709,378699321685629,-1941504544645224,-7899212515628531,-2014188844696944,-3820703577924663,2510569136778914,63229705088570631,230261775915154809,368867750854688196,-269612858450797591,-697179142540116309,3272181471539000090,734438764385526615,2264713900450134860,-5378847394193988534,-8493817543748187136,912053501
+6374628540732951412,-15,-15,-15,-3,6,-41,367,186,74,-1185,-794,-4508,24173,44418,-29505,-46576,-398879,656960,672557,-2821472,3830633,-6107474,2735596,23762144,114695020,-115393752,-342413824,203943734,-246697217,-205562070,906562385,4067552167,-826758089,-18969498692,-78785551260,-20225382662,-230418773036,-301034401181,-484321468970,651172647870,-1156959967248,5835710612829,5905974209528,31269621249875,-84898411852240,12121324873435,-294406456363797,450059825756725,155361481828104,-2442755326843188,-7649417359105478,-1451712835333711,-6218722009993353,-1337315187832507,58497893189786230,231603102163625693,338015355741745400,-192063099392167476,-782325558659377827,2783988493023621589,1583096784859832464,1221754274217258835,-7140137912812301894,-6795163138174600192,-467195949
+6374628540732951412,-15,-14,-15,-2,-3,-46,382,81,-68,-1586,-482,-5050,21908,45725,-26442,-32799,-446277,556705,468549,-3299852,3257804,-6916326,2319847,26028513,114709633,-94914446,-392368271,247962410,-414975497,-393371763,817889430,5126720226,3032747546,-14784501891,-82455289277,-48363958971,-275190620271,-190514738762,-232949294949,585677276451,-1135973705403,4148243825952,5273219273176,33657642747692,-93687028836509,-8760543071725,-249969312651618,427128706089234,97866005641336,-2232374629829352,-7431428829019524,-2048442767379370,-10034573167975778,-716915063402358,65339235741135996,215777687235209860,330918537816242343,-164193048086955334,-612885014294449003,2781980004285670506,2255306993518215630,-619850076667596843,-7626204476201418742,1456881716723723264,-391325924
+6374628540732951412,-15,-14,-18,-2,3,-32,415,-33,-160,-1606,-762,-4570,21581,39036,-21477,-50434,-456284,532726,364649,-2834125,2864892,-7166942,4524402,34096233,109913374,-67384226,-425389757,305929238,-451821651,-584124216,-217369034,6929566751,3153345160,-14801852724,-92342522814,-81812462976,-209667546095,-170153563011,-293287073299,170519658482,-1834593541734,5574764475729,1483584803062,35891605374865,-85055029404813,-42454032166239,-285028356858230,399692352310551,-51586546348519,-2476117766724632,-6542525389280671,-1757406507198323,-5766789391763037,-591928121309074,48676040689074725,202955556514741586,324298025618949625,-270652068698066915,-826370011117569640,2512938230402595920,1666729046694635599,-1367399044260600207,-4217526661054670469,7314653561562862592,927502675
+6374628540732951412,-16,-13,-22,-9,-13,-4,436,-132,-39,-1377,-1431,-5190,20617,35544,-32346,-61517,-475939,471464,122011,-2580165,2334536,-6176854,5102677,28804843,99479563,-83587288,-473203667,209618132,-420346278,-280440646,506105778,8912027903,4479894895,-13203244963,-79172142620,-60383041429,-195565910669,-75346607269,-545962910188,92862639000,-2905287938025,5893659096880,5061029845116,30589531722957,-93039618564640,-56698284723497,-319583510546540,373271923106385,-45034803958494,-3033292614554714,-7382813145704473,-2562348812308202,-3521357234647699,-7321713443621082,52396403363424174,189553425081746207,256180175453754554,-146753449197444711,-699384318800296194,3005869587830999375,918464264450635523,-1182292510937672947,-8725571179793426087,3061067560556047360,-586384928
+6374628540732951412,-17,-14,-26,-13,-25,-17,418,-208,-96,-1853,-2072,-7099,22738,29135,-36538,-94012,-429354,367683,187862,-2447925,1502593,-7075752,1919522,21811512,92668705,-80269779,-467574472,248533699,-659850967,-631338539,991712186,7869308190,8737936847,-12828569191,-78734906451,-66247867535,-208586328861,-183679381467,-757336546557,93444601539,-2998145753640,4645466074559,4941935776329,34553195638673,-91277516283884,-47434701770659,-279247122804492,389198134226247,187345455965173,-2605632739849931,-8474527115993330,-934262638103140,-3684025147104135,-7701746858180436,43745447158936286,167240720166876223,258028363227550071,-290318436648786447,-449090067283393585,2788096328660147690,-8329471268203350,-2180393221627606021,-8487917663441362070,2497767675092794368,-2061074542
+6374628540732951412,-18,-15,-24,-14,-32,-9,420,-244,-222,-1945,-1610,-6958,19339,22248,-41527,-94050,-461168,417454,-34807,-2298203,1323311,-5858954,-21577,21125388,104726971,-91897211,-523027083,297985363,-866766502,-513970396,1554177624,8798736295,12215938504,-21274255608,-64933797073,-62856004287,-192689562239,-210068512086,-964675017817,207062450707,-2922504930594,5405406061341,7424386312522,41576767280778,-102473613757961,-69468619558233,-294054782495576,358617546589719,310967671688035,-2983098488935851,-7626054312384489,-2096175368665746,-7656494711369800,-14333891851096452,37780897926491333,133316549656636501,299893612505416751,-327427255926341493,-695516586028107477,2754825478583752377,-1155450847100943978,-1815444463718485070,-7266643888513060551,-1133972867425997824,-643614834
+6374628540732951412,-18,-17,-23,-13,-38,-29,356,-340,-93,-2353,-2300,-6716,16080,25853,-48444,-122910,-446827,399162,-67246,-1952692,1674508,-6375799,-1199090,26991154,96879106,-95592139,-496035780,398791277,-780647322,-982818199,1888343516,7321435042,10844268939,-20683444722,-48360434507,-31711513362,-130624032676,-104627662396,-1180576569065,705757909967,-2010943564896,5288775602162,4529604877429,38988925352577,-95399215179730,-101187669573281,-236392997704381,450474692161099,339575890037273,-2804845827833715,-8402808139667706,-1868710528095673,-10633047782701784,-7950745695704750,20574698771383522,159505681839169697,308024276949332141,-406698466751900414,-444263261088634046,3204593562924936498,-728506678311573207,325504978007853427,-8698758097516707603,-3657859641528118272,1677819594
+6374628540732951412,-18,-18,-24,-18,-44,-27,315,-314,-246,-2229,-3141,-5410,15532,29682,-62493,-138272,-440100,271400,59403,-1981498,1930649,-7639152,-5365717,29626314,83160496,-114908872,-537908882,271681467,-669152049,-621214805,1572535748,7753769253,13702020359,-19390540427,-48614525767,-9708551088,-84984423190,-97136375759,-1424929407923,721834047309,-1275250422850,7006197948249,3643868622454,34919268481676,-89998306841985,-123548445582568,-250445945325240,550790598619666,318955633905969,-2649969199352914,-9236533584956939,-834837039669361,-8644074185105271,-7560576436478758,6683922803236859,194708017490183072,355675017076353130,-375158168489637103,-479322054038436342,2832063956239644698,341652807500365508,-1054191344449115628,-7992376446337881451,2179375781545911296,1356356082
+6374628540732951412,-18,-20,-22,-12,-41,-16,347,-375,-458,-2042,-3647,-6311,18268,23892,-76157,-125538,-417281,198477,-190721,-2081825,2761164,-5751491,-5426924,22852688,71313136,-103156428,-511905015,369423698,-716307481,-671891685,1429938926,9476772180,13716717930,-24869036362,-51418289947,-3893878526,-20389790316,-130016141268,-1181223864292,279748028101,-2347353187124,5446571940337,4373290887835,34828185845368,-76615515224419,-140277393771421,-234150588693528,465486588396286,569153857320210,-2795509602954751,-10151184040271244,-2567280619937823,-8230258866737662,858082422275149,21385735745645891,194410431013859998,324901842395557882,-485562425920614842,-713998077091176070,3046098254557843363,1485537618608922363,-1507557121276532622,-4431038937645019830,8326591918756537344,-1352516827
+6374628540732951412,-18,-21,-24,-12,-54,-8,309,-311,-489,-1740,-3549,-7802,18655,23765,-64460,-101519,-431290,152067,-241537,-2399560,2260576,-6872762,-6514711,27854319,66022101,-123175071,-497022362,497095914,-606190657,-937336143,2190879922,9259324611,16527120765,-22847585253,-65910570458,-2108810480,-46957848782,-44233524342,-1045873163093,186036561458,-2238781522284,5623807577509,6787062246470,26122838475757,-62683074359360,-136567797979950,-239374111040930,582200562035618,836014609751890,-2711821402795696,-10914643067454514,-1720717803971210,-8400415444262269,-5128558516205996,6231149602069255,230050118467153058,363653690357763067,-521276879943735417,-564890977272331542,3012217488526692074,751237396738962666,-360885910745046818,-8132427014424454271,8846115173408951296,-225556450
+6374628540732951412,-19,-22,-25,-6,-39,5,347,-367,-692,-1874,-3759,-7771,22516,28009,-59017,-93514,-384072,94963,-83766,-2218772,2301885,-5594721,-10124021,30797411,54597936,-143692397,-540267934,602310745,-853156572,-1364972701,1183125576,8657775710,14940614456,-16945248696,-51972431903,-9484873134,10773463957,-61703382830,-1220196697169,-79116954179,-1596747465766,6833154855816,9458236561650,23423650394735,-53542524369290,-122644442670593,-205780212091217,568998649382263,975533803684560,-2768842942600986,-10645194830347358,-543433891777610,-5689074319669578,-3521651344707690,1103509047795451,254909198767174352,411032161673517047,-417280016384856585,-677968009921641672,3433128179636982601,307445181151263986,-809477710037820863,-4786639857950736392,1338470934294153216,-1952200131
+6374628540732951412,-19,-22,-28,-13,-24,-8,346,-301,-919,-1659,-3359,-6431,24671,26194,-71633,-84661,-396067,-30465,-165409,-2646427,3234402,-5890357,-7344606,37198998,38305453,-168761837,-526251441,659357255,-747956045,-913526941,1251068460,9134956793,14637541146,-15478055664,-61633499767,8841772877,-20332464785,26307200504,-1003768172256,14948241773,-1174627491655,5914813061098,13229597275461,17773914995676,-41108395473625,-109913388380010,-266690599716813,581887343889974,846432263401130,-2756436136524610,-10867470987662254,1223433357914035,-9624496880903382,-6823272252719896,13536561821967676,231619800331238411,418205986125349431,-533825251328162368,-916214194812374768,3496955125626130047,209471921013672987,-72894387487028593,-6934161570124393307,-7851699792404943872,1512239260
+6374628540732951412,-19,-21,-27,-21,-11,1,285,-325,-1058,-1298,-3899,-7629,27866,22459,-71785,-83693,-405799,15271,83147,-3156912,4172656,-3942570,-10657127,30964429,34180403,-152036381,-460263515,763366788,-914440647,-921211515,427808705,9218024075,18800657524,-23804137354,-46602391228,-17548013073,42626325099,138603818573,-1266595818168,318793648294,-402546485154,6629670777401,17598365111788,9935199310970,-46706378478257,-91552744921399,-279724103470707,504237982484329,852858611901438,-2200970650097168,-11982049676784393,1998035169813916,-5191229811155895,-14890496071170961,25997259728311269,239534896229511804,369627986339866617,-542511259391570138,-1113271809873883883,3042147948828709848,947665440244282024,1836254853600593252,-6755131329726768401,-5680206036334330880,-1465878623
+6374628540732951412,-20,-20,-30,-27,-17,-6,282,-344,-1073,-1728,-4600,-6218,26574,18394,-69412,-71210,-458098,-38881,-91481,-2793757,3235573,-3419637,-13035348,37980630,44277377,-157002230,-408091469,832825986,-788290669,-1425975973,1395870976,11210024413,14593871734,-26078303420,-43875931634,-30587770007,63030102963,66443234936,-1343007815857,424456067490,-37425496382,7848576680937,14159752816370,13816974390266,-41056497266549,-91278483777335,-311332852111543,515924733817123,651381535896689,-1915856395315274,-11031116536208079,674960439108185,-6550609501995788,-20574536676808494,25843375987264621,262977283277063135,436904549237568965,-435376191865966579,-1074299134014850642,2804867448099337645,1714691865725307076,111443116727196697,-2980138006655807382,1538444585985465344,-1238759026
+6374628540732951412,-20,-20,-34,-22,-24,-31,319,-375,-1265,-1454,-4382,-5297,25902,24894,-83200,-41947,-518236,81615,-302759,-2372215,3827466,-3457866,-11337903,40729100,50267371,-166006555,-362928725,913547745,-828036088,-1340194988,2283686433,11200369802,11876859767,-18964892566,-38751086894,-12565136758,82273711929,-54575537424,-1467226561199,17115149537,-336558999690,7821129157216,17304514041287,5463593972249,-44482985101503,-77488925320050,-299961625039579,638930532842871,623357369679991,-1912836851896783,-11754470838264198,-79519785599603,-3130439276225855,-24961342075604187,38137649973423994,263578936570840752,485328174863574810,-510793628945161408,-1355564073554426765,2934646277635547350,1907139648686164255,1608890748314726965,-5876906154456770222,-1129451278663628800,506443817
+6374628540732951412,-21,-21,-36,-30,-13,-18,267,-297,-1201,-997,-4722,-5907,27773,26500,-87413,-70676,-538693,-30476,-402775,-1867528,4132901,-1545590,-7742250,35208807,43767705,-140818397,-297263467,864625221,-604612734,-1682805954,1915261648,12452666412,8687135142,-12724075637,-31184277870,-30598746357,75860132676,2822525899,-1213716778630,91493275869,-1346331692439,6974484547299,17630524795294,8747786068432,-33701277173227,-104903363388984,-236031440282446,660692794374765,685583449918235,-2124322892684822,-11730596675232447,440720028058952,-1296503487753052,-17999625329798209,35731667897913360,285523138923974462,512743118484640380,-649983825195132384,-1628600386615309086,3358825538949690737,1750985852398388322,2554080614842095799,-7442507329664712121,-4991469765973311488,-510642380
+6374628540732951412,-21,-23,-36,-28,-29,3,269,-387,-1450,-647,-5531,-5609,27213,30038,-88479,-96044,-554991,-52514,-326653,-1790480,4840318,-3352995,-9484163,35586017,56244611,-141907780,-245684523,910143713,-440519676,-1187743085,1298682665,11687259948,8801960230,-19331991005,-30908326555,-22024030616,91741639122,19153614408,-1223750711149,447637381428,-2431197531075,6175870463242,15655276534319,375067676649,-37788648677875,-113891913985120,-257933330582760,784130890913268,929258541196908,-2648898094117356,-11469050456095296,60307541788768,-4202668957881376,-24999952628563331,44990542655948070,294041972737875088,522588059490036101,-744122014372070762,-1883484596371604935,3007776037770833407,1251243007486992652,3613382545295093983,-5289892678632676608,3335053611255703552,-1451880677
+6374628540732951412,-21,-25,-35,-35,-15,18,315,-301,-1235,-207,-6009,-5667,27040,25648,-99488,-117186,-612514,-120881,-170247,-1819582,5883009,-4563581,-12299654,35860126,53844257,-138768993,-188040257,828122254,-419619815,-1724436064,1028220636,12752234724,12762067073,-13583777366,-40350484569,-15968743805,102879142347,123505296031,-1277511022189,294305900830,-2691699423872,5956301711480,14665239701128,7749918208065,-24797824280428,-96717916523202,-198753174394845,839005941370572,764274672254570,-3184452003927502,-12324257118504106,1872273939421449,-8154277563954189,-30345676785648085,60380724836558351,277597438510498247,513043681941427103,-808199238051792682,-1916681330610004179,3544895396685067218,1516311203927484504,2559346419214511375,-4145679316684798073,-4176830774137947136,1009396979
+6374628540732951412,-22,-27,-32,-41,-24,28,313,-269,-1332,-651,-5186,-4929,24874,21333,-88426,-128387,-561309,-47241,-71585,-1498930,5986434,-3952275,-10020155,36439502,54627049,-155406241,-158998017,741682204,-213443828,-2244655951,1149748384,13177185833,14865095740,-15690015201,-43350103616,-40862045878,165670891253,182769942156,-1019337547956,-97697201385,-2170381519941,7365542538070,11472543753952,11757172262926,-12723732139414,-85577032753905,-220249639489497,799028330527008,750536822463940,-2659727153356840,-12915947709423616,4105599114930811,-6858226644336706,-23047011837167918,63255874609761794,305819616044243590,515908853514978032,-864308682414414258,-1932719250043901318,3319279458390344545,954958155478343966,2212165184407359330,-19253023564425288,-6797018343578760192,-1915496749
+6374628540732951412,-22,-29,-31,-41,-24,18,372,-304,-1209,-284,-5983,-6573,27943,14260,-85262,-119305,-561863,76478,-149599,-1641753,6014142,-4167358,-7500403,30524665,46361692,-147675491,-122778804,649752249,-337345664,-2613992741,406939149,15266909060,15118771047,-18112904741,-52646731738,-32176365286,98143317302,175071447933,-905137961175,-440711416636,-1512695290165,5966329489298,9777613079236,18708274388600,236649115556,-83490570442314,-179780092198271,856664830248611,504595802686027,-2155332028218850,-12521638242863397,5848980784609863,-6030632353935264,-29450189320015370,75163161047925529,327075982241192267,570326134442024814,-893588437018826329,-1949688741870996009,2980533639676272063,1117876811158954443,3781878974103687838,-1836339499248986159,-1144925759177536512,1335955453
+6374628540732951412,-22,-30,-34,-48,-17,30,344,-190,-1331,24,-5704,-6026,28482,11860,-76512,-103128,-615760,92406,-66318,-1313722,5491339,-2127355,-6244624,28272564,48587330,-135920549,-90139080,521388020,-160136880,-2361594069,346229595,15895782908,11485540902,-19684222052,-36128057498,-41366100800,84636779790,215473413462,-716616767857,-554888927089,-1481950834604,5039944194645,6065731965643,26214510749586,6937911890726,-88884027194786,-169065390621305,739380382257749,246733428886267,-2667760791366857,-13351752063335887,3644829798113737,-2599630813735435,-21775121830760938,79332736973427041,315034209246006541,612530182419543110,-988860188874021554,-1710784016945046651,3412582219481141184,1306867491131184675,4647703201655719385,-4446512569852790139,-2753356675799821312,1112757104
+6374628540732951412,-23,-29,-31,-56,-30,54,315,-290,-1142,33,-6216,-7455,29971,15112,-87008,-135737,-672305,-10218,16241,-1613632,5430750,-1714196,-3812728,24534320,36015419,-104112471,-114710936,580408578,-162711777,-2570708647,549521089,14628368516,14925412727,-11114017578,-21569201133,-62360993208,34815427028,249588781153,-525251256291,-322205146300,-1441597504620,7025864490142,8330841567785,20898488991598,7357476689476,-87516609716892,-133753115888123,762048477342333,180904189296300,-2373526567803127,-12591579113498020,1752969990540626,1727437182953542,-26612413588673878,80261847678894620,344182408174877076,649427986631254301,-1103169829890367601,-1790853638625523341,2894790170137265823,1923052298553527491,4563578816013249730,-7970255600695938404,-5463976787314476032,-117656487
+6374628540732951412,-23,-31,-34,-52,-45,33,360,-322,-1313,425,-5299,-7413,27720,21357,-102079,-140986,-730553,36024,-71089,-1312519,4713721,-1203937,-6036164,29074688,44527497,-101894115,-156649646,692857708,-93417649,-3001056427,194475728,16242922218,13554325391,-18362879829,-12466966804,-55788845270,103212831081,384767330787,-679329986623,133209506373,-1410495502207,7788717154217,4244967139962,23968882582316,-7852914808749,-100441071659773,-178485130868217,858256584272895,233519326212079,-2784731147448042,-12110991172071793,3849253257665401,3048954018411997,-29440486268052924,70422077685195434,379957092855168047,624777631689930660,-1227750869271183562,-1937230656633551497,3330585754467613869,2630124990839939487,2624562177969298995,-8261777722248790339,-8487791775731974144,1714747228
+6374628540732951412,-23,-31,-36,-51,-59,49,323,-205,-1095,636,-5144,-6850,25895,23808,-91746,-173553,-674973,1836,68044,-1616843,4275601,-2623863,-8155119,32792986,47461554,-115366406,-189355866,627390580,23854134,-2952649488,-506469217,16066436240,16827608296,-12724696660,640887446,-25044066543,59702455384,458189514352,-618554878644,-208085544895,-1786341842216,6067005100449,7340041685681,22790213841260,-2823629070534,-91433675115284,-213382093692314,983500521840940,315297375148271,-2558580177581570,-11631903505297848,1802054310609068,4462041909549595,-29724325146073137,72480708991690396,344043882286243254,661907597091445502,-1215691762947300724,-1855048210693030660,3509807133343996874,2025638963399571338,4104981388557648190,-6830937064746506649,-7010957477052416000,964863654
+6374628540732951412,-23,-31,-40,-51,-67,68,289,-304,-898,618,-4167,-8030,27940,25412,-100111,-178560,-703465,1438,312361,-1302570,3822242,-4501633,-7397032,37950314,56598656,-116147981,-216626069,676033893,-237523588,-3457753922,-1014808333,18046428832,13024886856,-10548821965,-14252666265,-24446918106,47812259137,500823881341,-715363139388,-744867811428,-1033783846642,6629952803935,10900063147626,22710767230704,11636462989658,-76490060190349,-226450933748815,915258297623461,326962819410402,-2876850390376801,-11754402267049806,1720004538324844,1254401970424111,-38229380085798918,79444988566587835,348359126234743594,615867668144359320,-1297937610977074801,-1773559813656467941,3951285961886327536,2161099900481276813,3816015710474213155,-9057297884305891743,-4661849274934372352,-242482968
+6374628540732951412,-24,-30,-43,-50,-81,91,313,-337,-866,446,-4615,-9936,31648,23283,-103013,-164320,-689818,-121458,214946,-1795747,3167277,-5837245,-3884978,40621679,68103209,-104046855,-223912308,672880410,-197460387,-3075270343,-539704027,17923819761,11027494625,-14689809973,-18465326047,-25147919044,63706892703,555363191867,-811136772212,-567666320172,-2129638439133,6842313763534,9180021736617,23204833253852,620332527125,-106110894628591,-238206802620618,823872666817351,235506656216007,-2662863377169243,-12595282729252584,2730902375150050,-499073106457482,-42919245664150806,62629683859553601,340258728846916772,648111236840801866,-1233319547982891089,-1713886883799026531,4347291176081837997,1948553771224058444,2165344708784677963,-8182918300525240279,-4672423884810116096,1850970679
+6374628540732951412,-24,-29,-41,-50,-86,88,334,-343,-822,552,-4773,-8563,28541,21778,-106702,-181114,-702599,-18871,199871,-1682715,3911870,-7269521,50233,45400773,78569833,-90590356,-217307905,580254856,-50322635,-2679754295,-1460297713,18672355577,7068338577,-19014189952,-28378464458,1005482909,45416439514,624878576889,-793539468797,-942669828484,-3137368941078,8730400441810,6920425933638,23654691685607,8126278693763,-117571342258820,-182683482627695,912546946310362,317950641905452,-3108668975389727,-11541281016535503,3292596177627828,698003877076363,-47531601302720737,80122885815625945,363663619274314206,588556565903686115,-1123796533462488485,-1555772394394221793,4574162334421819801,2472082259050469006,4359474323906212139,-8417891917457325716,-3966927420500290560,2021858393
+6374628540732951412,-25,-29,-39,-58,-72,99,392,-330,-1068,449,-5378,-6849,31587,21850,-94576,-175090,-713154,77377,426434,-1202019,3031788,-8072216,1960027,53331471,78175460,-61959463,-259673425,537309045,-277191069,-2540364847,-980819178,17848665376,4554083175,-25899116979,-24686652332,-32863386980,17284054724,558517585189,-974431708815,-441540123080,-4203229173643,8642910365886,9028179886472,15536227913339,23356332984188,-125263936146191,-241279871503902,940199782410110,110520441037498,-3555197426415677,-10871690367158006,5108756022474372,-826533613371802,-54737242467421790,79349321517040462,346240899467885650,586054715695041538,-1050772393601967559,-1680064081935633183,4173578945659425618,3007417367181377648,4889839446169468105,-4279233551295749363,527161608402093056,-819473984
+6374628540732951412,-26,-29,-39,-52,-83,78,377,-319,-884,276,-6147,-5867,31691,18585,-97780,-190148,-672640,50535,612671,-1173038,3762523,-8476711,1040506,45511463,85529216,-67473945,-196937797,467251531,-147416367,-2400916986,-166207138,18308032878,2936074223,-33581715563,-18493716526,-53071674574,67449137753,633810052153,-954551635060,-485154547878,-4468004512266,8844363866044,8903255252018,18784112421894,22762704522609,-112734989714220,-255713072603588,828746968821042,-144729696657168,-3164651299749503,-10830915138900120,5844047512894691,630756959939318,-54110644498187458,89652945821050498,363347549777626730,598163956194650345,-1181672018674235778,-1748670156278398563,3716965187794332277,2006087793921667301,3983710801361025449,-6746803086810530478,-759754240423782400,-859081176
+6374628540732951412,-26,-29,-41,-55,-80,109,349,-380,-848,426,-7122,-4652,33409,22826,-83317,-163074,-695743,51331,638485,-1006786,3630081,-8331245,3571598,50495009,99543452,-44754466,-143467845,385888090,-324956747,-2830058998,180047147,19275948767,1134632709,-37543295228,-15988038525,-83550320240,68587789943,767197183337,-985687419615,34031055528,-4572515650460,8225876690703,11397636852686,11108653389857,15675149691759,-131717940808890,-274256999448615,834508262801872,123377733740548,-3007282669827243,-11466845376169081,7723157581798743,-3642790180640193,-53742958184052361,84943443235257853,333580610209172818,637349744167635474,-1151456253561790417,-1475572508660954769,3549731674385784181,2651277779441737291,5993273107366667577,-2341851212113195901,-3008250531965436928,1631043917
+6374628540732951412,-27,-29,-41,-54,-82,126,391,-424,-646,9,-8028,-5049,33617,15716,-68362,-183324,-721781,55984,533693,-887214,4163274,-6264922,1058332,42924448,89916730,-39750557,-101751167,329853834,-363394326,-2422231529,698704117,18305735424,679586336,-35985115540,-18844501837,-99175678161,111716724583,852540852839,-801677655996,82630285966,-4264830790175,10182365256028,15469505092819,10766633654120,4847917566734,-150847425804549,-310259783170727,863309936107864,395893935189264,-3255562165648646,-10780414662719223,8294449406109194,-6344909097028304,-49988036188660852,98561562851066440,369233830762565615,674194576027136343,-1155814073787791672,-1474841854690264260,3590994061086197579,3549514630753557215,4102967937842445532,-2341903495434569145,5382068879421342720,-1868650121
+6374628540732951412,-27,-28,-44,-48,-74,110,385,-368,-689,-367,-7137,-5767,30417,14169,-52886,-157447,-663825,15247,629319,-908747,4538847,-5230691,1719220,45356782,103841166,-49037336,-126167562,344341065,-109099430,-2366639609,768897333,17278049983,4139773616,-37412074921,-7624390744,-115956548170,48902981441,736991130638,-825698052553,349796036898,-3557349868307,8615502219770,18289475058146,12530716997766,19521579090772,-169662686263352,-363482557069094,806226432698985,267846450574520,-2709511930848729,-11858892326964368,8654166745035541,-5998544230237573,-52728755488190177,115351276884350728,361065425271933814,708772433087792969,-1050762640052730100,-1436426791475377346,3780958668024065756,4571786363213641585,3491203706672467090,-1566528373914751634,915568884374215680,-1733277845
+6374628540732951412,-27,-29,-45,-45,-75,132,338,-371,-747,79,-7401,-5804,33841,21762,-60885,-139701,-649890,-16032,620000,-1217204,5383183,-3193911,-747641,52537839,115803723,-73725155,-80885825,338484802,-120008497,-2693046931,1716155628,17963317350,633747828,-43211476635,-9520611870,-129297755026,79706905732,825391612616,-819049819342,388753504848,-2611160809276,7006708588429,20758146368567,13250532757447,17768888675028,-143794086101153,-379265664055629,790650012441659,33040056936476,-2994920089533115,-11830805998978633,7607739119550963,-9853853940546650,-49346240535824532,131338332165113731,336950022290349422,776171688983460319,-1142934338040761469,-1717349059225756873,3358647398887242470,4551994949591167475,3319390963241411071,-2285065972829249301,2799590967865009152,586207257
+6374628540732951412,-28,-29,-42,-46,-66,154,390,-358,-849,-126,-6879,-5655,31250,23453,-68613,-107372,-635733,114889,631118,-844843,5291382,-1287907,-3924478,46960171,101809914,-62627391,-70392732,374643345,-97112784,-2836087770,991389637,18995156083,927162238,-43355056704,-17905599145,-141147431019,41140851017,869133444005,-1060828138655,330866705147,-2795714526594,7063591787760,22909885827147,13262451089237,10404633313659,-112595927055615,-359707453993361,860818386059880,-139370593036764,-3271457449683538,-12576720349391612,9664223990061644,-10487672088590154,-51531540231962246,142765278810870721,311366995857085211,757943967604309111,-1280031553693344508,-1909947965126151990,3754516155946937373,5602883412283687304,4687650319282916213,-644365272789806908,7221334535370609664,597837078
+6374628540732951412,-29,-31,-39,-54,-59,167,351,-248,-827,383,-6518,-4833,34196,30041,-56377,-139831,-572561,116167,428104,-647860,4942964,221058,-3075495,50358778,114440154,-71248288,-47728588,242038198,-119793296,-2419519028,499691976,20706575118,-1780606920,-50859696417,-29315753256,-134386943788,62255613011,773418390984,-796298812591,630998851903,-2353521142538,5609987665915,22317294676879,10991209645865,19529578265127,-143908269950006,-310757354205719,760789052151308,59551418092140,-2804686902309756,-12421616574182134,7452292576380820,-9593059592805417,-48470638928483043,135320964107276796,312584835963037467,704994781778226591,-1174830542067833656,-1698681352808516536,3814490983557386990,6215406691397651296,6781509277901650422,-2960027300194949621,8599457568205368320,1387707060
+6374628540732951412,-29,-32,-37,-55,-46,183,312,-207,-815,-121,-7063,-4472,36731,24105,-51346,-122055,-606896,31547,370011,-265978,4767109,-94360,-6547901,53818475,120672734,-69120925,-88046214,123457758,-187547010,-2729914025,-164638557,20008691866,888594491,-53995094842,-35548543216,-123244119564,24444203084,827901706927,-1013011963383,857105679719,-2201890895966,4114411667702,20973230312380,11715108666734,31448361575197,-157587916604913,-287142255471065,731255258604309,-79062169965247,-2623390808120700,-11952651478599570,5249723721124049,-12704523493249033,-46441484185799608,149989944294032142,347484078884261316,645175018614062059,-1312675634122876186,-1432670369298625910,3359693309915171257,5519808213064122792,6956397731198547195,-2469731514148137325,3638917426092862464,-834578620
+6374628540732951412,-30,-33,-34,-60,-52,205,283,-307,-1039,70,-6573,-6020,38718,27569,-63704,-101030,-602653,34518,613262,-624020,5804322,1608213,-3357721,62182569,133064317,-88805742,-48403348,51694853,-295017314,-2264582234,719767433,18404233966,-1490482505,-48670558948,-50245835492,-125823587931,25246696657,803593660174,-1248444396585,843971901100,-2722506532066,3548702803774,17813778401787,15885252853310,28778589371228,-189868148751065,-283669315870966,664707767095115,-260272641595583,-2815292145765531,-11068854416286253,7298641190693305,-16311304411440126,-47876189273836578,165531397545199794,370475193039159545,654691430527973928,-1298457016401700529,-1629460169402113663,3832802627702231533,4607001287179126453,6655010837468901589,1026499539943863833,562811945294381056,968721222
+6374628540732951412,-30,-33,-33,-60,-41,234,342,-239,-1111,286,-6521,-5634,40750,33779,-59789,-132458,-648210,-87447,635061,-765129,6289204,486699,-1948434,64514242,133388221,-97177094,-36811839,158468107,-269123202,-2543120802,-279198924,19722352635,-3304507471,-56842773115,-50382454842,-126874605031,2140656214,930574028146,-1361022136580,862948371880,-3745560167436,3422946390493,14879013234526,21898501684940,42829146391205,-167791109209821,-295827269205028,624334250427501,-211861250690722,-3028417343733473,-10085678739329675,9260965682146499,-19248278063837653,-39334634788652299,182907978028319442,356569739809918970,626932697754022424,-1412387314194179688,-1687676504901894070,3873736336248542883,3987983795706048594,7075614529258895281,-1196730680420125693,-7693037553057636352,-270375334
+6374628540732951412,-31,-33,-34,-61,-55,259,371,-114,-1002,538,-6703,-6462,37311,27681,-65788,-162829,-681521,-6225,578398,-859374,7077659,886262,-1588611,69527432,122165378,-93502320,-33434139,271865734,-32930701,-2879885286,679284375,19170677094,-4504417716,-56839971365,-45728976668,-107109972235,26297011474,1067842604958,-1563036137672,1334160039450,-3466736542413,3411249296185,17695749391554,21233365262545,51062046059886,-184152677098052,-316510796859343,579946300844466,-475219590210051,-2707829592699115,-10039622607068415,9975626786779570,-20686070315464230,-34581932125890733,200358179523900970,367829404113409047,589177606771321008,-1307005752570686504,-1745222237966610976,4197714662601101152,2860658291209618435,7357110645705943356,1547189576841534049,1390490051015888896,-1931846469
+6374628540732951412,-32,-35,-34,-56,-61,238,387,-142,-879,454,-6785,-8232,39000,27057,-79113,-146859,-735615,63696,399168,-652343,7667954,161989,-313702,67220951,137641400,-65158534,24746358,339999241,198675135,-3255605413,-62761293,17860116469,-915729834,-49370284085,-50609923689,-100345718190,80152948409,1052872400713,-1761033972218,1743287295790,-3703395619673,5486657634625,18195568372892,27227298709955,64713455270693,-159813978757185,-365406493758592,659577127881713,-631502330073207,-2713556339461514,-10247303280768487,10050473347675265,-20596309596305331,-43583067772001292,189558392024882273,367789965166426806,562358000996702240,-1348177233673573287,-1701170979108713350,4537075309006861442,3908421581431534083,5057702399534898142,-75093012368510313,319523718827382784,-1147510604
+6374628540732951412,-32,-35,-36,-49,-73,210,395,-88,-683,798,-6623,-8750,35902,29376,-83218,-166361,-686661,107320,548116,-1130945,7999458,1269228,-2584948,67936174,150062760,-69277548,16484731,313875011,300174499,-3293482065,-153530633,16021509863,-4263544575,-52487391905,-57696382441,-95241908427,97792586190,1115556421987,-1674515478914,2038868522943,-3940556322677,4136243086008,14028728530458,32543036001385,60345117871148,-132785931924176,-406369904860897,695540993358241,-683617492559173,-3159602310762828,-11099732850775891,10521137272734348,-22459233423478221,-43212666499760255,201174538773140016,368787157987013733,516136215550841824,-1375749382268455940,-1512537768290256799,4062400608244641174,2941457990658869322,5323842645364349622,-1683814418770308861,-4713824146316736512,2119942957
+6374628540732951412,-33,-36,-33,-51,-59,200,456,-33,-511,871,-5821,-8193,32731,31635,-78086,-183043,-740200,167114,755534,-1587096,8991087,-61586,-3212075,74998701,142362942,-37260214,77148857,319968811,427441862,-3252069320,688939596,14461752149,-2223746018,-59456675339,-47438359837,-82697293015,113852707725,1097292244854,-1851794986617,2573455943292,-4138485061084,5992728645095,10745516114917,41261327700876,56097094651285,-164965141847621,-474598994319124,634666817617455,-797614157789124,-2916219851764874,-11585235093282977,9680872799108263,-20560727550603491,-40988612375761199,185872014533907748,375459834348701210,539004024122903572,-1426606582989446659,-1427919426516358931,3862533955924101645,3028502586166123451,5752741405597911302,-3139569336954903936,-8750611836306165760,-1312764566
+6374628540732951412,-34,-37,-32,-46,-69,168,449,54,-319,1202,-4900,-9494,34250,23464,-89852,-199697,-780795,268114,707326,-1278597,9173748,-1526069,149641,70148128,128364764,-11616684,21736559,316075280,405541061,-3618351795,-280502106,13315711936,355281679,-64065467615,-59798066890,-71952893738,127359432575,979117814462,-1817961841519,2727024059867,-4413108769306,7963384569070,7461317858250,45084248480375,48050473299833,-162891731769046,-486891252444997,516462235941360,-763866483172996,-3397160425825000,-11374098882396597,11081919829338028,-18291088263581720,-39088437104616887,184885902514903785,345493954295969238,484090689603310796,-1325094510160187206,-1679926153254772971,4193564668877245685,3371643531014518612,6559529206423869547,-5624217652107245425,-8964602401387195392,-18651112
+6374628540732951412,-35,-36,-31,-44,-66,141,427,58,-344,810,-3996,-8292,33709,28541,-80069,-174199,-745099,397743,965838,-963218,9669045,-981234,-3828765,75061945,132484544,-20266281,-17015972,268633752,500305035,-4097961103,-601887388,12209298198,554219998,-61431270411,-68281976117,-38371349623,187119656865,964017328591,-1960530516884,3044910475232,-4624768915348,8728290789373,9443005162250,50333808719485,62498276754141,-130507859408836,-496276898569243,476857389325869,-967310339640129,-2914693843721439,-11851987198539835,11141772155322901,-16582167086129016,-47339872290235058,200358774625930289,332546931530735076,523031831930157961,-1278091478894988461,-1846991330235613382,4539727135772333939,3667693881213861295,5800577281724027836,-7804059993101097329,-7246222198078055424,-696773221
+6374628540732951412,-35,-38,-30,-51,-81,140,390,112,-452,463,-3421,-9288,36497,31094,-94736,-178677,-785372,472173,743452,-1305845,9878395,-2937911,-7847314,77554412,133567241,-3220732,35816178,169028273,251228761,-3803067341,-690405988,10412678038,-2508569259,-65233770863,-53735171128,-52400084179,219013653662,1053754091550,-1850548051070,2654143532921,-4394829392222,8968127333526,7157623542861,51671818652474,47902945111741,-153761690047551,-436781851943408,425650723521684,-1168599607615926,-2670597172179341,-12475121261880929,11298750214105987,-12425597189204039,-53432309422321561,185429645528184900,316401487804192399,552378677011776630,-1304740991866260480,-1926181195088964621,4042046887417767345,2520077858788022246,5545109424419471780,-7572066790002063058,-4623709365198115840,-1782824032
+6374628540732951412,-35,-37,-31,-52,-96,143,386,147,-411,607,-2529,-9265,33109,32858,-88993,-208579,-847356,519180,909197,-1102998,9369384,-2276875,-9478847,83271836,118525678,-18800173,51645029,143682890,110785665,-3661969259,147577588,10490268375,-4817999329,-63919679112,-49531552680,-43508526658,268711162578,1156271756144,-2041878918852,2742672346052,-4813304723054,7816815555196,5022259708412,51325111610269,38267920068030,-145470040323284,-464978943920603,540833214889127,-1349867262676186,-2962453063011131,-12643349902609981,12697545666077932,-12787282857116113,-45410005773282560,199365002615858268,330385024417930879,559817953747283291,-1447504804658862946,-2129632050517639029,4049173915021189752,2539173320504621237,5648832293461106206,-4011091876437234077,-8155382000988869632,218797988
+6374628540732951412,-36,-36,-35,-54,-100,134,390,189,-217,1107,-1533,-9205,29365,36155,-89915,-178219,-823813,599089,963225,-828045,9240000,-3068928,-10234370,78954712,134882766,-34163148,15657685,29880875,-70741667,-4110930846,990710997,12270200833,-1164910735,-61206434555,-51628989407,-76424101354,210219003481,1226911519073,-1767159249572,2320891831521,-4888621591225,6681853169851,3348074878822,54663688311006,39652043262358,-143535336814988,-432352742625447,548236063354786,-1557636867375348,-3172930602691686,-13274064259385860,10525980405330434,-9654943545816983,-43541240129072189,184760395007388490,348195422216145397,518327874054627715,-1381740638904889838,-2028780846377700081,3825876645994706469,2385912857898249852,6469088456516849043,-3881376639156926481,-3046572392092380160,-1628947536
+6374628540732951412,-36,-37,-33,-58,-102,135,401,212,-395,675,-1088,-7291,27901,28990,-97689,-181102,-888108,664380,757721,-692275,9077535,-2222241,-12541867,81798755,144174126,-15376696,-8074758,36775651,71312393,-4356490237,1587674924,13176757562,-1523295305,-60842628039,-65917235907,-89310396626,207692941917,1132492649970,-1788013453331,2046656688188,-4101030318241,8513950320813,2025212182038,54384910671007,30159641059214,-113429231453043,-478664225455628,596045498972902,-1536712610127004,-3475453645388184,-13665745049168249,10096275304205054,-13356789968653376,-42153193503852175,193209635842732498,361459177520211192,507565598380196088,-1386904062047173794,-2246101483124320112,4066467083332929209,2838435398654213993,4695880952680864157,-3737197436952441822,1006344604582767616,-2020304383
+6374628540732951412,-37,-38,-35,-63,-91,127,384,145,-586,634,-1636,-6988,26520,24677,-91439,-173667,-868579,562087,733425,-821980,8473862,-1697317,-11444338,83641653,133297506,3700066,-49015317,147128870,-131213217,-4796881277,2447383801,12803015486,2641225121,-56367166415,-82117959043,-116466990654,221550239827,1157364013614,-1774507064534,2484006515517,-3195488743636,9630615920171,-1107553791594,49130618754575,36871865471451,-121855573437191,-471522249946398,588771938028933,-1471705699516769,-3065332090573557,-13441824848902052,8875588989667905,-12690004666336556,-49395157471916096,182553898367991181,386116365551016818,463738139522647067,-1492687553985044679,-2202949158961743414,3727799956393412621,2013575971966514551,3035869558209061876,-7063369388130612261,1905586765039367168,1420851349
+6374628540732951412,-37,-37,-38,-70,-98,99,410,272,-425,713,-1548,-8562,29280,30743,-93606,-159550,-871198,457979,749436,-710865,7934854,-2289062,-9622468,76323729,140693753,-10503090,-29401102,232303081,710066,-4774378198,1510912840,13950015151,-1131704956,-54378861767,-86451728862,-85498410395,258608143800,1098292895338,-1615025208618,2633957663171,-3821615185856,8688456136561,-4582913627485,52444098295979,49162767440539,-125638099662821,-526871779540349,566000445379216,-1534873268548053,-3625311341085152,-14515239475159464,8721392048692499,-11827955056515350,-44587165014154547,196392854656307131,392140118658621806,496108286447176225,-1387400811787095952,-1922737526766814185,3959223033228567916,1823376785636355488,2498984934392389608,-6001528151098105696,7049871471162862592,-1208217736
+6374628540732951412,-37,-38,-36,-69,-106,98,424,247,-438,1112,-1850,-8589,32120,32367,-87316,-132748,-822494,421776,706151,-1222118,8389128,-276673,-6504356,78242512,128054989,14104427,-25162390,173790831,-241197046,-5279133004,2283333825,13762158781,3134640119,-57615979420,-84812552834,-82093356851,310664641995,1001976415224,-1739428962916,2461816816141,-4407500488438,6770286837882,-6326063871940,58236169012878,59022462084594,-115962620084340,-479137788264829,638048878890547,-1793748246707591,-3834417455354363,-14530671313264767,6894369592225015,-12041145033544887,-40776841873267423,199492478866436543,407664286956371146,492978963754699173,-1272839999154479408,-1782859969978254821,3730231973657755111,1961773712188979036,2631624889575987806,-4895620516882982912,8068617170484591616,-1063256458
+6374628540732951412,-38,-39,-36,-70,-104,96,446,198,-508,1609,-1704,-7947,36010,36482,-88989,-134490,-764771,432620,454331,-794952,8463424,-1835700,-7117612,75936309,116329699,15205710,-35323034,233984480,-11800995,-5294951893,3269772412,13800912569,4068552249,-59869380397,-75004059301,-72789084931,361459323159,993101486068,-2011183964492,2947139125097,-3614864651386,6865502579738,-6970026570420,62019874053701,52580115796023,-104876480854087,-539266192850218,543235507801684,-1618948100181212,-4379837938609756,-13555573806876559,7038883407125186,-9188654165298011,-48068811820514134,204082463066948019,408219706521008072,450287141580927653,-1288226103861452411,-1912908938133599444,3665966296159911573,1789907880443199816,3457565015121435644,-672217644680183049,8818238964264250368,-938070979
+6374628540732951412,-39,-38,-33,-68,-112,112,397,209,-707,1856,-1230,-6827,37728,32179,-94169,-166159,-730517,491974,377896,-1206774,7472863,-2916539,-5004923,77683349,109936009,16963938,14579259,208246418,-17452286,-5317072125,4271102667,15819468210,3590454468,-64529467174,-78770508317,-84496127953,354542832350,1024022702862,-2084735807606,3043081212014,-4406654519209,8731386399214,-9513346421248,60641295403138,62282800272563,-139786953417748,-559735802096419,473419663058883,-1845722843934277,-4838159783160914,-13713493408830183,8041148277283472,-11355048202690935,-39505081427954485,218952093680577134,394590864417886704,452073566036931873,-1215779654226480112,-1663670085556740360,4044286612133845046,1611335907442210108,4606927745709274601,1570163329205903552,3741467196762590208,235872980
+6374628540732951412,-39,-39,-31,-66,-107,130,425,198,-832,2275,-512,-8382,34881,27237,-98966,-189981,-763075,497511,166761,-1324310,6739212,-4931269,-7018756,82178697,107849817,7118638,53167754,201883388,-173112849,-5523494500,4726024479,15294927522,2918117933,-65820801831,-64275880120,-103752872945,288932310039,1050727330554,-1959106728325,2998213012177,-5287764328284,8633448170977,-11081907357963,64562693757811,74413773367943,-147268051747912,-585128424719848,463509452078866,-1976593071980034,-4960229002379471,-12978515711768109,7905943288581118,-12741219856283495,-46737484666548505,228311833873800405,413822157893590112,386447355580938194,-1199780169111837705,-1729430081782634181,4009767433956683079,766555038267993419,6619798827737886397,1223739056458355497,-2649656295289261056,-814454686
+6374628540732951412,-39,-40,-34,-71,-114,100,412,138,-868,2023,-181,-8157,35066,23814,-90235,-202602,-708402,410266,388319,-1802151,6363820,-6496939,-5955214,88284981,116626133,18482874,39025734,168392681,37255327,-5721678667,5461657020,15746086533,3896518288,-57935649176,-60740245177,-91049896397,291181761661,921713437190,-1887129410365,2550807258987,-5870124315076,6776961083241,-14808386615231,67902133645749,62946255037780,-149506002539210,-522002556421039,495290559600615,-1708493396809432,-4651456252788809,-11942587722350094,7266920227318177,-8658178893721205,-50666709216046751,222204404010457913,437353079208013138,387763448465586349,-1184034325542636484,-1882495499875818898,3504990963889696988,315558345066612773,7790165577589091849,-359274093626188583,-676987823396453376,-1660238084
+6374628540732951412,-39,-40,-37,-67,-123,128,414,197,-792,2275,392,-8507,36334,30566,-75096,-196668,-727453,338788,262222,-1387971,6110540,-8023431,-5964956,81811186,128645556,-12021648,73612670,174842442,198845136,-6098300044,6463044074,17631967249,7004511816,-54855036685,-69752674094,-115418726549,317948855325,1009941991241,-1825609062074,2815940789372,-6144222041158,6316513805722,-18845563919021,66404802464090,77212262746668,-137123973411890,-525142754004568,627554447260146,-1459888974211256,-4759390715494352,-13024874604481443,6575072153742266,-8704933743162753,-49463170484588459,211985042440978822,472244685922998934,418960255103920035,-1141623649288192474,-1894627889919696224,3297260722114140546,1000916164731161882,5601318004953738809,968288850436984346,-4785318585294678016,-636905219
+6374628540732951412,-40,-42,-37,-72,-110,104,381,239,-743,1874,1346,-8995,34426,32613,-76718,-225952,-786032,225573,312683,-881571,7067828,-7599791,-8958856,89459296,139022704,5419357,115361859,116893297,-49315771,-6138672362,6939429309,19427982154,9127067251,-63251415441,-73087824128,-118590976769,306615374549,1066021483478,-2067102003783,2623309659715,-6488418568768,4583007058972,-16500888250869,64750613324388,85086461509022,-129899373581070,-579741357658476,701160890118273,-1510627801739556,-4972810321150988,-12138944486174407,7973281253977700,-11727732387113745,-47077254585814651,203823971082997939,447680736980023485,360466565499038950,-1172902350805848018,-1771779094644968436,2782323821850230912,1199766413955172033,6120544835931937752,1453576488978946045,-3523409814968320000,-1811260799
+6374628540732951412,-41,-43,-41,-78,-95,90,363,295,-785,1817,709,-9657,37304,34017,-75656,-193213,-807146,204165,553620,-797831,7447004,-7495583,-9729390,92074063,151938472,24364485,131734874,33466355,-315502314,-5727153554,7114286185,20130946341,6354404357,-67983845788,-89941926256,-132928137156,275522672274,1180221393508,-1997220639481,2868446489023,-6022830167580,4988218246616,-14477525678258,57040689222166,95316245325535,-155816394865680,-639119570072293,646432917666315,-1709856373260501,-5075461904493525,-11041728029216525,9935692070689387,-9330218231617743,-40449230328864339,202945524349407769,463694593843374150,364353269733837924,-1049047466446615147,-1774069708261589029,2607199539649723720,1895558856784038479,4956270224424155334,-1108533672657095756,-4083979310053133312,411888961
+6374628540732951412,-42,-44,-41,-76,-103,64,367,208,-646,1412,1060,-11353,38264,25876,-80450,-197911,-824246,223687,794400,-771918,7302764,-6441378,-10171897,84159121,135515265,32885915,118695356,-2765699,-479918800,-5531017940,7904641316,20362903283,6493193866,-70296652012,-101383883763,-124801764393,226111785262,1306467180009,-2156947173426,2966047511093,-5236343152562,4204620477322,-14515603409396,56757823485397,89894921007348,-179981110711619,-620522669604062,687054362846987,-1606907790051087,-5471938005804126,-11184550775395370,10749666339610601,-10332152605165534,-46461517993474916,190033447928841088,429557166349231319,373813540270527122,-1048888920731058353,-2061795826086994254,2694927246764476601,3025572460426180328,5118737071565588031,-646116016036719598,2551857541280099328,-1285929030
+6374628540732951412,-43,-46,-43,-79,-114,53,356,175,-831,1754,1606,-12572,36300,31421,-85493,-203852,-800663,315547,990971,-1286899,6552019,-5292893,-9137492,77299312,122316960,45185786,75759257,-73099490,-515973865,-5914603679,8834659052,19258002376,5610402719,-75459508527,-91981481231,-95802240594,269159189823,1257308124826,-1938422631655,3215717068302,-5463787233257,5651562161014,-12637026969111,50206134327993,82339013886794,-194091054808604,-582678010642877,628711860460212,-1877761224962800,-5370740028514652,-12265920263362348,10969266341613166,-11821208062284066,-46405271987195483,188364154843314559,431436928106815150,437645869446525908,-922471520261745407,-1877421992467518680,2479928264236581456,2336713371370163973,6049317990492099554,1199557678540785016,-5222799002775834624,-475713454
+6374628540732951412,-44,-47,-47,-77,-118,51,400,251,-712,1580,1816,-12829,37771,34054,-78769,-215511,-786728,191937,1237419,-939712,6930107,-5262283,-10000502,76609194,127789749,23898148,15879537,-83519235,-728860715,-5662603398,8147245453,17509319215,7662317441,-80519731828,-104375449728,-71693067212,254347073828,1284085779356,-1963465058010,2982340961158,-4556520561396,4076429546590,-11806557655370,50951709173187,71839184865956,-197414210861315,-596428188445548,659348643363765,-1853514774312183,-4919736995202344,-12632803693274249,11838703993428235,-12422311167529969,-48474264547464888,182747326600687560,461582091886722530,504856109734314625,-882162103188112676,-2089873400507627679,2592640091251622504,1949653930042493493,8324198275173543888,1439809101880815644,-5247098221493956608,1732679049
+6374628540732951412,-45,-47,-50,-71,-130,65,400,298,-583,1963,2778,-13341,34470,30678,-73560,-226158,-725492,188354,1046382,-955903,6110340,-6687553,-10489954,81686843,120444726,-8411077,57970865,-167948627,-831254913,-6195929841,7861782656,16747050830,8252419308,-84643192488,-112335611167,-79998634127,218114176886,1366889331538,-1744116811227,3160362475947,-5469043961396,2042402607453,-11632536941387,53158857883839,72190710506782,-166837885487115,-579384480705387,581283393180936,-1944677691129861,-5126350249643400,-11528449663650775,12176681624673375,-11044267482258341,-46739704154098134,174542805502314169,496454348192259990,575016298783264771,-961987025232707828,-1936869549477960060,2491999201774458950,1965770238569355480,8383522717497255185,1779291633912754271,387858304686735360,-451372708
+6374628540732951412,-46,-49,-51,-66,-117,95,373,232,-504,1808,2079,-13057,34054,30074,-67965,-242715,-751136,94143,1283734,-1332932,6363715,-8716040,-11748823,80749085,114833847,23970326,14897560,-108052257,-862047670,-5756835510,7504317864,16606581278,6373547349,-81031039138,-114682966820,-88000907907,233895785438,1340562930914,-1835605110310,2823495302826,-4785728618595,2872713674922,-14648007234045,53530822178267,80141689204285,-154794612149429,-568627273657184,459844173018070,-1961880484079096,-5147415883570885,-11676388685957091,11625609279377472,-15439236979576653,-51108626328622011,189837000946077122,475156471886042912,625308922839990300,-852239746214340763,-2104373073677337022,2682788201611642701,1990200668451555062,7189155810294342023,517656945757861140,6940487025590984704,-1553763394
+6374628540732951412,-46,-48,-55,-65,-130,104,366,227,-723,2299,2297,-13321,37939,37573,-69139,-242937,-812585,54133,1420777,-1068256,6478492,-10262503,-10803715,82257867,118361002,40365755,-22327208,-7253850,-961433640,-5423529250,7809953454,18626533481,8106426012,-79943000118,-99464764679,-85641969231,171072200538,1310275600635,-1891456048531,2559615295139,-4922378147666,2972851420022,-16326277586205,59586597187478,89698811502550,-183486206210301,-629880336823450,571425364638686,-2074528465437639,-5656124409767369,-12107901944377058,12127034437433224,-12079471266681041,-47095705917805386,174273113550767564,482684000538631195,688915217050824423,-979657970062266319,-2155874136405689572,2742530858211229942,1879143600504441295,5680536487428100046,-622680556378811197,7902354714122369024,-2039789440
+6374628540732951412,-46,-47,-57,-60,-139,133,313,141,-575,2434,2465,-14503,40054,45521,-59530,-224835,-791177,-73897,1553609,-1255860,6676478,-8653931,-12132881,75447649,105979286,48057902,-42835060,-19524297,-851404659,-5949917717,8531146704,17309148686,9400152064,-76856862501,-86436694646,-71144146997,198147318162,1373335934659,-1695983952664,2443293617401,-5422930591462,4079930503566,-16326523326500,57808120457333,90484194495550,-205637054338371,-616280477822855,566590171841639,-1939882800065379,-5741928190724373,-11098682940172617,11047958452784360,-14546641559978326,-52071272278094735,187472496937267689,517699842299861637,756038819852606903,-969011049186827796,-2204398557520880811,2169642142327102791,858828122194633032,6540990926119196466,-2719162247971361239,7333478977289017344,695340617
+6374628540732951412,-47,-49,-59,-55,-137,117,342,64,-475,2083,2595,-14006,42008,40347,-47264,-231725,-815411,-23759,1669193,-1630072,7158822,-8262808,-11495979,69115579,98949362,50356210,18902748,25844967,-647631334,-5814984982,8622900138,18377007830,8449233914,-85122041735,-78246783072,-71155188479,138069765273,1299908377724,-1598697282952,2874599942694,-5692657271845,2691476948080,-18632887634450,60876617807365,85491637826038,-219803692170910,-641822881403569,500828708292134,-1921854893731975,-5349205513890604,-10278592917410981,8847742247340635,-14307575137522503,-49576783304868622,173125846014563448,545607791642628234,700079608493977302,-876815073374780747,-2213096197477791559,1662563642865584337,1940014138896002262,4507882363400025276,-6251344903165323077,-1403637558209166336,1442964907
+6374628540732951412,-48,-48,-56,-48,-123,122,287,152,-264,2547,1688,-15685,41970,43351,-46458,-254215,-805148,3144,1462137,-2130958,6360308,-9697794,-10746047,72425867,101319821,31331176,-7858523,118990842,-847864974,-6344994921,8864948380,19875119867,11644443447,-79039685856,-78583408148,-89221737322,158283374180,1310508228409,-1681385132563,2389319807811,-4806450228606,3193721153127,-17350039848972,63943491332523,92035149746785,-208272726055286,-587945825106399,619663114243942,-1656654717029550,-4798936425400349,-11237351851740681,7072307529787665,-16406117738291975,-48734649738693866,190708913873683214,571445256369487145,685868732065150513,-882334582854549957,-1929422683452269447,1211573572899084146,2118061882429118029,5183154157461311526,-6101587858578549513,-167283067045387264,1555938110
+6374628540732951412,-49,-48,-55,-49,-136,93,276,209,-512,3027,1768,-17676,39031,44491,-36341,-226713,-831779,-113184,1572580,-1939678,6338090,-9911102,-9136323,65820295,95352137,49078429,58520379,37057409,-1001859995,-6778028616,9779200281,19316422769,14350316492,-72904599968,-78816439252,-64194003066,213842252047,1274182931904,-1756795071918,2269793838794,-3923755732333,1565380838628,-19722547816375,59315058853939,102527903552662,-242261713989715,-552783188497336,689160438038705,-1741164277276200,-4549168698998689,-10428083453815915,6413262290837841,-17978282688900716,-46831339889872198,192873046428648719,595108337784693306,698971965464792708,-802193649956560603,-1906275901860295822,1064557284487289674,2179778294146500336,6578115720360839459,-7994898938695673833,-2614823640976075776,1210151157
+6374628540732951412,-50,-49,-59,-45,-151,67,271,305,-616,3430,1236,-19587,35330,50937,-50760,-259045,-881974,-64828,1809636,-1679373,7169661,-10956594,-7883141,68844570,88833091,45950417,17596187,-29707549,-866390733,-6738494850,10460426657,17309629268,18005714280,-80920156274,-76520226543,-46464434203,148209015763,1257155576022,-1646205937663,2255687153501,-3062773276531,3477405684434,-19457260507041,62027053067436,107068162922067,-256687173297797,-582703412609961,579720201114934,-1548464316769697,-4209118039074111,-11471571258602589,4414796007100298,-16856478810064608,-43414051936448393,203761950041907471,587335177142113380,631636873026977044,-700818957307189211,-1739894714040849267,1055070515741145372,1955752311412394578,5452863932871713908,-4733179990366215979,-3197342028759286784,423952213
+6374628540732951412,-50,-49,-57,-40,-163,35,256,385,-838,3339,1127,-21479,38232,57725,-61920,-227137,-851566,-76695,1985901,-2103596,6264599,-11983545,-11777854,63011240,102157238,65454135,-49212589,-51354575,-919633446,-6968593140,11146705394,18826820116,18671741975,-84982241477,-83108835206,-61837761679,202589197065,1314979875309,-1740356787340,2062350494417,-4051736059322,2509034185254,-21214625470327,55943993615440,119610717831157,-247879023410693,-559471145707694,647706453736780,-1794375864872445,-4437028333585386,-12524691697461148,3152548997481018,-14548394636500358,-40360702123479854,215718899541244238,590603629162599591,700602065030532390,-665276109615931635,-1696805852524625721,1420811944713932733,3083297226469798189,5480738574728036944,-2679632797956836572,2938874603099812864,1041026686
+6374628540732951412,-50,-49,-60,-35,-151,46,267,387,-1059,3671,1241,-22084,35171,65709,-62694,-234871,-907065,-11113,2208770,-2326381,5723463,-13025905,-12630209,70730392,87085393,76408813,-112909322,-12568767,-749870607,-7130130039,11465928668,19459549930,15507836653,-89116830315,-95795181547,-55593302659,163553364364,1220884916852,-1668803733690,1836707382578,-4101733390480,3247750877245,-18971119678320,63901852728061,107364885362989,-262893136287578,-593927531736093,508059487422070,-1766056592342286,-4411022904180447,-13387288407241338,3660578993053929,-11191920360938041,-44124354994665752,202052730520799975,561012064334022859,707667765691200701,-579280175915995384,-1747231062123573055,1793749658217685564,2746232389848703948,5339812799281735673,-7003948660854754465,3920753964206982144,22197833
+6374628540732951412,-51,-49,-61,-33,-165,26,303,302,-961,3737,515,-23679,37600,62148,-74372,-216956,-949091,85475,2131216,-2681056,6622243,-14351272,-16444771,66409143,87491409,93821007,-130399498,-68669007,-699598065,-7506349380,11574595998,18412313084,15325443681,-88324568988,-93658074171,-88670769582,222465789597,1299382966740,-1840661106153,2347380754258,-5106990504214,1923281197489,-16340381472482,65899019119768,117179990294330,-277205829313524,-660616646409058,606975159971694,-1621112116618137,-4078925246616441,-14300113142200274,5800170962969693,-11218015947792405,-38784801033999917,205507175570830359,543211939544254672,641093687148055937,-660131266865159858,-1768159019384920721,1604002027657934184,3220875037028116476,3349419034099937974,-6932196335227378849,-2443007973000950784,622534857
+6374628540732951412,-52,-50,-61,-38,-181,40,262,331,-994,3568,1282,-25382,34472,60145,-71635,-228292,-976161,176727,1884787,-2255896,6571141,-12355414,-19361900,69936148,76313644,99157185,-67266385,-129997380,-604231211,-7054912232,10712635670,18777611043,14760209439,-83986201722,-105670220084,-89941688097,198868197571,1245306929876,-1729444514796,2110075691860,-4561892361556,-30358042608,-13171242675504,59476413945757,125067184400070,-284979538332092,-653433479878894,485361040772147,-1895002241710044,-4028428462807692,-13602916896305312,7979796030762473,-7757179410268048,-47273047696899135,220949166408669670,533347665435080067,700781082762738160,-536390663546183547,-2018465043982888278,1524949912157107249,4287108616847611866,4698977507783970191,-5727470510526279777,2111357080852558848,-621822108
+6374628540732951412,-52,-52,-61,-35,-172,17,269,332,-992,3353,1774,-25071,36205,52191,-58198,-213270,-1029340,198528,1681628,-2144951,7550608,-13624381,-15405454,71961224,71668165,120420561,-24351910,-1595284,-482816062,-7273632332,10943226333,19587529712,14486054140,-80063901699,-99354311719,-113552559695,147679084053,1366734732305,-1734799767562,1567622133750,-3967548232952,-1341264250983,-11873527478739,57295162754252,118081567439616,-283914836804983,-595731886304484,623561559505921,-1855498069541521,-4319343370908339,-14309126383288885,8336205471201308,-7542945571121765,-49494426872641560,224584538762621750,564821075952517368,725735797072033111,-581666686439559470,-2217669588578280534,1609081582447377725,4980474348149343848,3597326847308165153,-7767906826370564737,3067004336514333696,788493476
+6374628540732951412,-53,-53,-61,-36,-182,31,229,288,-924,3458,2425,-23689,38729,55824,-60364,-231208,-1002396,266460,1893272,-1660275,6778169,-12147742,-11403780,68542918,57323420,138082574,32856717,-1291925,-607483728,-7497452479,11933415063,19404049776,12760983699,-87313114237,-82717757231,-110114225655,145594075977,1332338494577,-1464593928733,1385390762267,-4818805498469,-1345921610270,-8551076714101,53924268659535,110195261981209,-275023950958654,-538145787088806,681378979203267,-2091164446177739,-4372553251285275,-13293863798333184,8184012923156865,-6836929275856782,-47878582081477436,241987821481776827,560654109563618493,756220616749575859,-554541603520125029,-2255971975076092562,1296034881450679223,5357766212586020049,3904674416341925108,-5009951308624706707,-3649788855659347968,79321857
+6374628540732951412,-53,-54,-64,-33,-182,33,184,297,-764,3458,1832,-23320,39978,48141,-69416,-254323,-979368,163542,1911044,-1972051,5766261,-11095764,-8363796,72034634,43902653,163626565,12491699,-64124284,-765552565,-7728643109,12541436786,19851709192,10102426749,-88465452875,-68503613257,-139529471526,150454015187,1350024696840,-1454820139591,1005191445406,-4774466791619,-740426955459,-5354161124705,46671198523388,124903357442838,-304673049752844,-541967828957445,614394531709493,-2065860012171296,-4393261045687890,-13155737647468722,10089579916393117,-7952073755123904,-50452482420172858,250943962321412007,564215437283337612,744212093860481030,-512012533196501256,-2148669436652415883,1320138212370364120,4634490801279917606,2843975525620056001,-6716011747887536618,798425618423955456,-1749802096
+6374628540732951412,-54,-53,-61,-38,-190,15,206,340,-742,3074,2773,-22802,40958,41533,-55152,-277734,-944132,88515,1920900,-1576642,6302705,-11271044,-5705400,63680925,56487720,146308811,-47887290,-103135641,-723231634,-7281229715,11979679809,20668868822,8377145352,-90986967142,-85532140705,-122153892711,85320287301,1463729862519,-1617944400153,615583695314,-3928787924069,-1054747513797,-8322921391432,46388775260189,140671535259707,-289469792950918,-601529429541536,588365252786116,-1804650505066921,-4672662474910832,-12187480301699957,8812548848302593,-9021737357381310,-49216473942352244,241550980430314229,530101984941967520,674647495708235297,-534369603468338876,-2168401793311129871,824632577334471426,5513711426328704946,1948377211214846377,-4858593730735577558,-5126247557571821568,-1534912089
+6374628540732951412,-55,-54,-62,-35,-183,-17,169,286,-936,3019,3299,-23283,37910,40580,-48900,-307203,-899409,-30467,1976858,-1509599,6819728,-10717686,-3796432,67939031,61014567,114764117,-25493990,-5647392,-919082111,-7275882714,12421190936,20866786748,11987252103,-82804796759,-68414183389,-112827471156,107761006851,1511150985261,-1497451624023,580652235928,-2977102695882,-2276152856445,-9886765492869,47209441354547,137330223040966,-297756622428120,-550989839884154,656770878899585,-1609912511262045,-4342350336246961,-11954676193703704,8183182998065905,-9016406632273322,-42820739206378649,238425617377228617,560986638754400925,723036202095426367,-658554585045417960,-2058589944139959379,717211828867526717,4714566399587736772,3068687342862130098,-1519125043214860756,446028044030550016,-1431202288
+6374628540732951412,-55,-53,-66,-36,-186,-49,210,202,-1113,3257,2912,-24489,36639,48294,-56896,-308256,-875256,-151256,1970349,-1161385,6715878,-12623712,-5316153,76312129,75503417,86550776,-17012253,-32474152,-857429960,-7010985002,13470447282,18875814854,15309606005,-76470850375,-85369437810,-137211467788,105813860914,1406778161004,-1236776456583,1110142260347,-3155301016486,-2260376036981,-9264475897640,42016383373654,129250351514340,-283638339136113,-529034832398457,613116981151072,-1383742280214691,-4490170182862974,-11959633406259924,8474454841832193,-6725670254256008,-41949610485835320,241596240195162933,583791929737860140,765773691265385781,-720871638448707190,-2281685842310316352,1055611316942118858,5614667356366182086,5024702147610388993,-4464987916692976219,-1433175099473879040,-351702518
+6374628540732951412,-56,-53,-70,-34,-189,-45,252,120,-966,3024,3810,-26167,38789,54086,-73066,-299901,-813050,-123161,2006318,-1102076,6518670,-13875068,-9307411,68441050,75502996,98385587,-5464692,43109373,-747290982,-6506924136,13232229451,17733344290,12260572060,-81385569578,-70693234092,-147128119025,108013125141,1291686525172,-980748930939,1525974680755,-3246453226972,-3568075576511,-7869769749198,42371545354735,139052214841211,-258865938045558,-569957788313107,612594117017143,-1454666356703857,-4286035159478905,-12380627439407220,8970926073747863,-7327914795202501,-37790288619122478,245849122064823460,605118737802951149,792709575796533089,-851014207235546890,-2433222539208338994,751980718679344414,6449473643662610349,3773174406570656256,-4856936427189025146,-1762571256639486976,2142811258
+6374628540732951412,-57,-53,-71,-33,-178,-47,263,221,-891,3089,2864,-27326,40392,47976,-67609,-303715,-808155,-93670,1924971,-668051,5980714,-15288879,-9156722,69308158,70893967,104747642,-41831442,15913788,-841082546,-6586651715,13003828273,19136098397,13099787929,-84796628378,-81831233768,-125808337135,77767619028,1162250525511,-1103969547480,1230144308178,-3290997707910,-4559381547836,-8851812465308,47274251597806,156419703758422,-249672176992025,-569514330850843,534586453865125,-1503096168319371,-3874841449822236,-12837435203506293,8877328662858137,-7120147456987875,-43016253508903738,247632369732910509,615658133393516284,735448885825372317,-885701304270805007,-2613685142816348300,757533706644839090,6299397379475353811,3199468272085599185,-6253360485992063332,5923543398281102336,1257470651
+6374628540732951412,-57,-54,-71,-37,-165,-41,315,289,-867,3376,2432,-28419,39268,42211,-70032,-312592,-806275,-195278,2180639,-796788,5501178,-15068018,-9444936,62657927,72391126,108293746,-104839412,-108090176,-795204969,-6418296146,13855878672,20604973766,13579828472,-87940890464,-67641091869,-139824286101,146375599290,1206091444376,-964964749781,1759153305027,-3237904173959,-6682494987462,-5993604246973,50508752900199,143806592416326,-275729020525252,-563618661051927,472203817063785,-1326856749319032,-3729485587563225,-12202443233143853,7740851755829195,-3916699346459914,-34064655234802188,259682687337958367,626010432699524829,775158824552385559,-814594953934282577,-2586573160229410024,1298962787368301340,7136951285075546146,1840389020386918252,-5526707427916262499,-1540166123846088704,1746145889
+6374628540732951412,-57,-53,-73,-35,-165,-24,278,284,-688,3562,3276,-26789,40111,43478,-69223,-333884,-742625,-202833,2294556,-698802,6537502,-16140525,-6900232,58053449,77689745,87712488,-155107712,-96879285,-1006529354,-6056280201,14081452534,21690675548,11791582302,-84800061290,-76221994609,-153596582682,154225425997,1134754712306,-810767620718,1629636449581,-3427265293596,-5751173004355,-1628928529538,42407979160983,134990441695991,-258210926692309,-540712650386360,471343971575021,-1149470038963501,-4276112593583798,-11175310101500560,8262970807530702,-1327424867553263,-39879629153121794,256562551488598990,630336790689104987,792722403043694309,-714715264425706720,-2533955121267415392,952247044252734060,6942274201671703054,2666873300537375319,-8777244044758871405,4334750615438155776,-894802465
+6374628540732951412,-58,-55,-73,-34,-172,-28,290,170,-488,3621,2252,-26739,41389,47138,-71095,-317905,-750831,-157755,2076444,-764931,7474344,-17987827,-4138453,65837000,84467135,93321921,-161415539,-16291882,-1054930001,-5765252746,14090959999,21075137300,15349459562,-77487897921,-72114041227,-149272145446,144527325364,1146014801704,-971903369647,1842093330773,-3870172201771,-6907678682738,248821179556,43576219571110,150870589147488,-291817462523120,-485824522069850,394961265758962,-1234076926608954,-4480192277705225,-11017974164595208,8021545845632372,-4296446504431794,-36365117591539354,267378843157061099,615071565288066932,853321122125472645,-576756244888406050,-2371034396488263593,467463791596147654,7578967028293700487,2416525895963057490,-5089401691193470568,319600478081062912,1162853737
+6374628540732951412,-58,-55,-72,-41,-185,-13,279,269,-409,4062,2790,-26869,39496,54058,-83884,-331825,-775880,-126932,2140693,-835947,7610558,-15926930,-5271083,71993407,95242615,88973316,-137552632,-116955144,-1038663672,-6001731478,13926874029,21968599309,15867389406,-74514431039,-59553324731,-168166256721,76988474001,1258350288543,-1124690115626,2233414538560,-3429686845325,-6124441806100,4005475356237,49224754642250,148028991525776,-309933077678358,-472441172719872,385392601737145,-977730026318556,-4590066711802729,-10483579111380478,6291471852155243,-7163739655159596,-32810456759436613,278497168140821387,580857913440889390,892382038071422355,-680740620597725056,-2463951373114058450,492209989501824035,8672980848172287119,3740148477767059118,-3315201418655793873,817552772347598848,-1470789998
+6374628540732951412,-58,-56,-69,-43,-177,-45,321,175,-359,4356,2765,-27683,37738,54707,-72008,-318430,-829784,-16193,2072100,-497767,7355898,-15325770,-3721266,76294845,78872058,77123477,-201112557,-17734788,-1175128147,-5618978222,14831041762,21892981453,15506150219,-81835314067,-45767478254,-164320641040,84358160354,1212247720249,-1064959761926,1958200142706,-3063136960119,-3987248420034,4338284705993,43024832309173,139804180481590,-338203052066211,-436748747552712,490453875631874,-993609363645074,-4381997685870018,-10072648193185950,7013469753468290,-5159608845471243,-28689478784319190,291883720414574907,579737155416817442,914047117765975961,-615129090813892274,-2496116080503984530,774805743286380390,8783685114524175792,3125533900243676294,-2192475125505490182,1425649057641590784,-1152005073
+6374628540732951412,-58,-57,-66,-41,-185,-16,257,148,-340,4746,1866,-27923,39567,47710,-67401,-336298,-883467,-52935,2143388,-152197,7108678,-15876373,-6874313,72489089,69008730,97723073,-190743456,-57265623,-1080390361,-5939618226,13989988689,22661130806,11888996145,-90199259596,-34936600686,-190470678195,28182777446,1310135038876,-1243625941788,1410219974804,-2715745433395,-3443646279769,1024412814753,48783381643462,152372259673689,-359696498357610,-439063181214863,441239461140295,-891411082859214,-4481824328039713,-11037641501976245,5195367633090444,-8202211807373973,-35656242796873198,297184580221928437,558961317791965382,975972142727600547,-708370934112494507,-2692515409874798062,270349610916384835,8803535686130338880,4612782136157352223,-4571833309228366949,2313160042320857088,297457703
+6374628540732951412,-59,-58,-70,-47,-177,-4,296,260,-478,4813,1125,-28782,39238,43638,-57111,-340344,-836587,-16688,2181300,-671898,7178036,-14743469,-8761175,75266314,81689237,76588091,-199906493,-155333326,-973742727,-5476885645,13197006483,22693979777,12909498664,-82798128374,-30201925878,-196822815539,580197761,1294230486030,-1234778327639,1342891359281,-2929017063435,-2679720379785,4373563802683,46338728494535,152904211943553,-327796431859113,-450697851205107,515948658647393,-672357624445928,-4691848474132507,-11233032837146841,6192858721628495,-7055034956899720,-39130251438815956,286371702610658209,571677214245392366,955077310591235120,-719669966890007200,-2565060180844682649,439742563319016859,8793398021007861989,4888729906966441163,-839682022443064416,4765413486522873856,353671258
+6374628540732951412,-59,-58,-67,-40,-169,-28,348,224,-416,4939,1632,-28111,38205,38093,-66866,-332600,-878619,-15188,2246209,-1146975,7736698,-13341477,-12869624,71538923,89302951,93086489,-192458516,-176960361,-881298884,-5637077120,13795631369,20966569873,8630042948,-80414080242,-28050821754,-212351721086,49101774071,1187597687812,-1220284809111,1039844114226,-2834941188295,-1461910554644,133864499674,49250328097563,165437913172563,-350198996562119,-444052162742804,467212955135041,-432412212128514,-4497743563267937,-11707064730959208,6024511421256888,-3473841074155911,-41191245600181235,292714215672932293,574544219644734294,997020919853241068,-659894914505595900,-2841471158373420807,-72811216757352649,8096806697047897801,6546995759356956626,-2620952060971995535,-3379016287932151808,-506650075
+6374628540732951412,-60,-59,-68,-46,-184,-27,305,104,-559,5155,2312,-27571,38464,41209,-65156,-327461,-828732,22843,2267704,-1398764,6948233,-12933792,-16461527,74768535,93055563,78611586,-177307853,-280237325,-979308030,-5745438674,13609295403,21831501389,12098040111,-87042595778,-19896322766,-226128415849,88268655706,1135519857284,-1033850735272,1228907084711,-3408121906174,-732858891974,-1918957017147,52727461921404,179218054928089,-322962337198541,-431253228935376,553454001150245,-324100578465447,-4204428482660586,-12528060969656291,4249943943256608,-7875083576130197,-39581391562829823,309696196801709867,559320128212695506,1047572649831036094,-521086380140768825,-2718229440540886610,117485358416072806,7011783585313980799,6455931794349002763,-5385886694131930293,-7196626931429637120,1555481201
+6374628540732951412,-61,-58,-69,-48,-179,-37,263,15,-463,5336,2798,-27982,42083,41922,-67458,-319628,-876363,139603,2181209,-1721078,6674230,-11472159,-14254483,70661621,103043067,103232084,-244385058,-219796975,-1135176542,-5266020868,14454544314,22808823746,12946100992,-83384831094,-25749320639,-202160513984,58031756084,1126519614163,-925722763857,703693102534,-2584669229467,761287051643,2454537746454,50442736868698,195098675163886,-290454178976391,-482347222591720,499454006444109,-353974167527568,-3871837066716952,-11822366447091499,2835081472550665,-3855640556911695,-47652318513445692,301366545715726650,564687830612795101,1029207296363481972,-474136464540793861,-2871206110796828105,410739197283982672,6387629004394443403,7797230907131195793,-803479101895614862,-878843141610558464,-1973042833
+6374628540732951412,-62,-57,-67,-53,-175,-22,207,130,-344,5104,2553,-28328,39069,34796,-53621,-347012,-832378,128789,2310388,-2120540,5891844,-10264709,-10570241,72168981,110296339,81610906,-308237295,-304088600,-1345320845,-4962959209,15237036717,21159105929,12982692691,-85410160923,-39092632867,-180758293053,19953811864,1172124115909,-1034907379542,218795333229,-3324052600844,1884928252970,3168402839537,47483841202949,184880485821656,-292691221206989,-549141066160213,516969597424509,-341310451198383,-4099451318272042,-11646815832692614,3243270207167177,-2155651866037950,-50851204184374796,291354207938932682,598377175119059377,1094950727109995172,-388977166318585215,-3061893446089471326,444751456512509424,6043075048266662422,6384666573083345075,3015948803179798375,2019794285800592384,1578654448
+6374628540732951412,-63,-58,-66,-50,-174,-52,149,137,-284,4828,1732,-27116,36678,30881,-61477,-332945,-865630,132182,2442739,-2324916,5477042,-8420756,-6495608,69278994,114821942,57368481,-283696389,-177742930,-1401736171,-4519507168,15556851952,20982652962,13257358218,-84716376495,-30054661279,-152419270240,87565720011,1309241028752,-1295170344197,315802724100,-2779455894754,876128096073,2477080700708,45057928121211,199494208930939,-267101430874835,-581455953780086,628351064585732,-253209846490854,-3594473107876429,-12023890175856642,1779320442751849,256241409952066,-44308887621364175,299037290088282911,586930388000406262,1060042322942092980,-379999792091572351,-3128583205719038866,981701650717490916,5042808273043578913,7053237180113572264,-186732450416039097,8544674548969305088,-613354219
+6374628540732951412,-64,-57,-63,-54,-186,-81,188,217,-256,4928,2343,-25717,36913,29100,-58020,-301844,-810596,40100,2395277,-2486227,4800357,-7733739,-7221688,66963171,101562835,27143581,-244705322,-97109713,-1376380621,-4263254683,15835380693,21558236199,13984542464,-93006398834,-19554378597,-129489951814,77074320658,1248691896200,-1164338439293,45907336441,-2523054806967,2363377114661,2932838836031,42818570225991,187122707528797,-249215520505479,-585949040648565,610395930085855,-40862062900473,-4090704590285358,-11897595533986270,1473509753499058,1500973858904341,-50599080182998403,308263400182531205,612358056430393875,1004343313045864819,-474498421281813500,-3190082562241013985,1462997650906934925,4802485719385162188,5729239302497279179,-2221669896576746531,211778512021168128,-665463398
+6374628540732951412,-65,-59,-61,-60,-195,-57,127,204,-314,4768,2413,-24689,39484,37131,-59387,-276158,-875521,11215,2604679,-2872090,5718737,-8482680,-3254995,74265844,115700416,8608327,-269607168,-7990644,-1414563633,-3919083307,16445623289,23559554875,11394219224,-92897536854,-18506431155,-131795337584,113133486453,1198716502773,-1128094197977,177868810083,-2555242068135,1380432622014,-9844245272,48712647494208,192263313695484,-230806191537810,-522215784601646,510000468341470,93247516573604,-4103482703146747,-11898371179586417,-457100597037398,-869058677572942,-48094233996597916,294149152671847857,608861415182592331,1058506426612214398,-332455868862243162,-3353237886539782716,1604793543308985808,5954869174942203965,7386321383760379384,-4517436972077685974,-5411549867167260672,-1440521076
+6374628540732951412,-66,-59,-58,-63,-188,-44,185,80,-440,4912,3302,-22929,41550,34657,-57445,-280533,-902479,50934,2345399,-3272297,6412105,-9289517,-3094558,77834509,118809300,16553281,-248877863,-130031645,-1602004877,-4229123990,15602186714,22365613686,14354105375,-93251256417,-29793332361,-147421761693,146988844155,1077636580907,-1077926825212,673969874990,-3325079482653,3241061472963,1167398823860,57331204988391,189389588316675,-211947775879711,-545542718763884,544561344507049,-40910284089305,-3752879315594960,-12290362998990219,1006099998334497,320560999289939,-45432508706462026,304052474817917020,640505772012334660,996771888357240853,-314073023990963096,-3580110432860014749,1728969645396644873,6297661338977474386,6495447093654543809,-6177579746736788030,-375287600366555136,1495913870
+6374628540732951412,-66,-58,-56,-71,-174,-39,199,59,-696,4409,3020,-21884,42581,35136,-70790,-276260,-952208,-63025,2571693,-3290802,6863979,-10537158,-5818000,85721303,114317217,21093625,-209904017,-210898362,-1864308449,-3778896072,15285192156,22261686632,18423549289,-96859786739,-26160707634,-118340250747,82193599537,1151252470668,-1208230102673,1196119618271,-3939845569166,4052565489795,4914922269087,58555212168709,174454650248295,-211234957196227,-523112381594102,493206270727420,-171533453035276,-4203380970044864,-12457476351953482,416887248443856,-3979009937298810,-53847841620642866,310715773025701256,648930034644828517,1008917805701062716,-239094361095104255,-3447549143295999062,2248955611683612231,5794464649776834146,6602586672103985031,-2928545001563443554,4255928398714867712,-1334468129
+6374628540732951412,-67,-58,-58,-75,-188,-47,211,140,-933,4592,3744,-22566,46273,34337,-56896,-267308,-965122,-36734,2337446,-3659397,7694080,-9213732,-4939667,91364212,117701146,24715982,-200794035,-278073512,-2125748305,-4029201978,14976285033,22033169211,14229863510,-102438907667,-41959680701,-114778754315,137820193878,1112019478990,-1343312995066,1264963082007,-4964695444711,4881110927653,5631891941495,51533060507223,173638442501192,-178525404549117,-572146802018139,562366017743202,68019121068404,-4008988632179184,-11519992030470390,155319304782712,-6315206439595369,-55848874191985371,302796653346536397,618249942961491122,1060484043524283636,-259930123526136032,-3309583462663861492,1982800041501632025,5689561057800945418,6497596426496750141,484661297717584186,-4711615674089097216,-221281690
+6374628540732951412,-68,-58,-58,-80,-180,-63,175,239,-960,4119,4286,-21265,45923,32029,-45178,-253379,-932224,-165085,2138859,-4059507,8087467,-8794556,-2085262,83865256,133792406,34284308,-246644156,-338269637,-2239385046,-3526941830,15208005632,21845848152,15811969857,-100835483588,-25610930242,-91385554036,91072587632,1200078961165,-1499497503696,1651369207066,-4191263990191,4518661807296,2133105032641,56598431613730,177814128312153,-144662400759520,-562511773846830,457345710908930,-109081182637144,-3643556240493438,-10800705459878448,389892536300933,-5336763783139864,-49212191500962198,310216690023656179,601337243714567715,1096137158475539032,-268803149998979734,-3037502251991100705,2204964736447287592,4867984545009171900,7028183737589667761,-2382489348458269156,2808881147843248128,1872797913
+6374628540732951412,-69,-58,-57,-77,-186,-51,193,159,-880,4170,5053,-21310,45872,25658,-53649,-239159,-945391,-66305,2324968,-4232635,8066073,-9097968,668405,89543185,117460080,47058816,-265432212,-377720892,-2317060484,-3779993682,16185231405,21890925344,12868639244,-95447465078,-36009752860,-77319229665,87376439090,1205188898903,-1292406670050,1372525703473,-4189177197477,5308331779069,2600932122453,51243367954647,177862613921154,-156173165193925,-535181663859491,434150742620346,-41087040864254,-3400375565551364,-9835801805921587,-357786434581254,-1223998707818129,-51166087450198005,312747353422530093,570915803771719090,1148965600673479336,-237349667698534370,-3301633919118977829,1744174669439606019,5101566819465356716,5890748307829993709,-4358662055038081080,-4977167024830817280,1925118055
+6374628540732951412,-69,-58,-60,-84,-187,-53,216,140,-951,3698,5532,-21133,44355,19749,-61417,-265340,-903076,-85284,2181177,-4434785,9111742,-10835861,-1426187,93276498,101479199,49919668,-322398086,-490714808,-2246728319,-3399700413,17232680781,23639042217,13079294702,-101840978745,-35497881493,-111467463541,129814734081,1173169412603,-1441841638371,1565950043265,-3615703836139,4936457341769,3083919890281,54882560395647,165823547712304,-187349873087936,-590216574405952,312289660136214,-23870885210006,-3323901216172574,-9754209293471395,145914459364900,1969174611026194,-50560564803752582,317294051713880341,538030404089546657,1217838278182375495,-171730195722793479,-3385385810182317112,2192862169361056897,5616713367663678568,6221997105970991590,-1579863038477603887,-6348444339711259648,522368345
+6374628540732951412,-70,-58,-60,-91,-178,-51,190,192,-1050,3226,6498,-22865,42391,22623,-74550,-273926,-889127,-127445,2235198,-3916551,8791092,-10230859,909276,95688064,112862324,34133617,-267883253,-421044234,-2228800213,-3604785285,17472070461,23720914586,16145119939,-109291992231,-30350413533,-126673734633,116423437211,1256777353609,-1263582253579,1917425369709,-3822574870025,6590736898595,1896819159302,47448181079455,166112084101785,-221716660645954,-660018323604819,247059766384648,133308843652161,-2888102436309470,-9910652679840591,1286134708828077,-2168446168970790,-59529654294988032,307441882581434039,572716531005583098,1240855603655526909,-256020716099721448,-3552131638462719128,1906837793468592043,4834185555087919126,4019798212491169569,2440631178962210340,-745127935655894016,483392733
+6374628540732951412,-71,-58,-59,-99,-165,-48,126,74,-1076,3315,6083,-23090,45583,17869,-87702,-290993,-837010,-244903,2066111,-3904741,8652803,-8251718,-3245388,89123860,104296884,31704905,-269639204,-475605582,-2469802877,-3960535153,17818838308,23205073411,16440479227,-101512546693,-23225032140,-136177512376,107141067921,1138399192533,-1503181191769,1969066175793,-3270348999027,6871184771539,4672332836913,49496770573373,156104297363467,-251758345548341,-727457966750946,215909831466590,-68654881964695,-3380093142754319,-9071570248684300,2418515856447182,-2392875182794670,-68236063005388440,293304038923804958,541990659376728441,1176614887841554444,-229247212346704310,-3778424577629102559,1694225176522542637,5352623964519865383,1983160582143908009,1399943279835251544,-86903260734004224,1412551999
+6374628540732951412,-71,-60,-62,-98,-168,-74,118,169,-1081,3389,5154,-24416,48649,21916,-74974,-317159,-830532,-278276,2047122,-3648699,8559968,-8444133,-4639904,92752880,104761177,35677860,-208740196,-407627849,-2690705609,-3470329790,18831788461,21814132446,18064143127,-98950228016,-34540179509,-116549574594,105739623616,1189998997165,-1755688647144,2052220498652,-3732488782709,6908584384773,3035450511431,48658927225078,160681788053567,-285636123457444,-719035180964982,77874513979011,-60458470422178,-3201226371268173,-8982177857009310,1943350950211935,-386156910436141,-59579484163935694,284134574887708285,572892030023783178,1174309001142805843,-227896524062231262,-3685353989577762625,1612446408695102921,5298148960466192753,989335961314819684,-1865459557750485784,-3554800211595777024,846302975
+6374628540732951412,-72,-59,-62,-95,-179,-48,163,95,-1059,3146,5722,-25296,44726,28636,-60504,-287430,-767220,-178064,1897732,-3438712,9043436,-10133550,-1462213,89533499,91775094,38017261,-195595147,-299508523,-2862362934,-3640333115,18614794049,20717808777,19672927969,-97044691223,-41315791787,-123775220971,170456363058,1176666337033,-1796727635913,1584746366032,-3379892877937,7417791096297,1409279672242,49573668830237,149289782458303,-263791705876866,-766036245832290,184371196159305,30810089359164,-3759984679495385,-8539496191913081,273448553005984,235949741991716,-59591215158051216,297773438581622516,600042471771807505,1148619836041264897,-155274314140775089,-3568055306925122081,2097052844260599970,4701746185357707055,248132777828087469,-663967469902269819,2152339254410800128,-1505922148
+6374628540732951412,-72,-61,-62,-91,-195,-27,129,59,-1163,3423,5616,-25423,44375,22437,-49206,-268308,-768313,-148957,2128876,-3042218,8787477,-8863354,-2645953,85171416,91441341,71111788,-200603021,-200362030,-3097161633,-3674197707,17791886538,21457668713,23244199777,-101439317879,-53887208954,-100306954242,135538113904,1248047817553,-1856519160445,2007676848305,-3157931692241,5477603372567,1306271806570,53869087552662,155402203477335,-295760746979614,-747684077784169,229016578060199,-206519357020220,-3531018597778016,-9061540052881168,-1665117882817420,-2951106954536577,-57652171398035977,296001231021174054,596601480392501541,1110587777349018350,-17601162104878607,-3428871494122228231,2283907829000896563,4240973336425210201,100774328921267807,-3029317509197251334,1440769670253911040,1843301998
+6374628540732951412,-73,-63,-64,-92,-206,-21,86,-54,-953,3760,5798,-24259,43845,25687,-35174,-257088,-709421,-194823,2050114,-2739416,8100716,-9725098,-734315,79359822,84110866,68583722,-139048226,-231195609,-3343996867,-3281063990,18250398919,20167249781,20825024532,-94147250772,-56066470514,-117244971865,74186210726,1230640850375,-1989197719830,1904174203303,-2443389890105,6039207916676,-1843171717198,60809670286538,142214223197317,-270619757357946,-688825137089200,242588764549323,50072246036229,-3279700058518208,-9237522213125021,-2638003156342644,-660970509701263,-61601702802956406,279736826243980804,575241114356468705,1175739496321310711,93936808759462376,-3326456451783436954,2610046636456990040,3530419301302686066,-1914513792628609085,-1895237119831207858,5565949087359415296,-1034124461
+6374628540732951412,-73,-63,-61,-94,-205,-19,73,-174,-1147,3439,6161,-23914,42950,33291,-35036,-243824,-772601,-170807,1894072,-2693574,7124313,-11701809,-1254891,83085803,77817635,56110492,-136148419,-182744047,-3189759629,-3736407767,17459555859,19788794047,17090622197,-101798951802,-61139181280,-96458807708,124775832489,1367075592793,-2053521548522,1833276830785,-2004774087408,4875591927632,-1096674601643,62788592409445,150174939845323,-262296688634279,-754186508020270,116624178221542,-45502471207813,-3582166096987961,-9932460337598693,-3971559771165542,123903613956545,-66708562591229500,270139989892586543,589166555391292692,1149750522985347362,9068149101971758,-3462516837151028376,3185865862156096250,3650274177735751914,-154339294865495770,2393971382658812223,453983209317253120,104332920
+6374628540732951412,-73,-64,-61,-94,-219,-16,88,-168,-1191,3139,6733,-23212,39533,28843,-21166,-220983,-784964,-285993,2018034,-2783621,6778412,-11056883,-4082944,84214870,94011842,26173078,-167294345,-144042251,-3177278060,-3644898311,17339323409,18048934801,17221349351,-108009618777,-45570937717,-120932175153,188302529046,1355706422429,-1808793997018,1452936457296,-1804249170043,4512614680433,-5169805070057,54327006214326,146934953590129,-293268740582382,-771234320971969,31996928194530,-154691778801195,-3466107875229703,-9595239748102723,-2121038647037401,3802335485243838,-66361035927436750,280495399138463376,591496026349851363,1129137695386557044,-82195571198141475,-3270013836875518678,3016881139456820186,4326100118545134636,-52243153744987648,-809091328763599811,-6879100064698967040,-738229721
+6374628540732951412,-74,-66,-58,-99,-204,6,133,-257,-1202,3472,7043,-21482,37007,31772,-20473,-202121,-723287,-257197,1861901,-2675061,7238372,-12631932,-1932113,79196224,96037536,53894554,-226750856,-14091385,-3418893445,-3781257973,18026735903,16290324147,17805311740,-116330783442,-55493478299,-103284604228,230280603599,1366105830779,-2049927517816,1307386691958,-1256225221909,6182349743798,-5709089117335,47448154346093,150330222111879,-262730685076881,-760637697754305,133230736367457,-145504496837309,-3737099570977735,-9730701335538128,-3183459422271080,4785870085681342,-58352214119621866,271540753143944486,593380583773292318,1154813771923894030,24003683971137525,-3129007349566279244,3057870861761137520,4563676244410561768,-1741652180506717543,-2499507919895063841,-2698681114387768320,-118550346
+6374628540732951412,-75,-65,-60,-92,-209,28,71,-335,-1250,3264,7756,-22799,39387,37870,-22580,-189198,-738831,-372793,1728843,-2473398,6471790,-12723115,110460,74854618,109531639,85506834,-177494888,-100590602,-3299890024,-3578023116,16979170144,16422707866,19534470876,-108314631019,-63478030269,-101786429839,190360090817,1453702108313,-2068284619548,1122463298921,-1074251216374,8294667469650,-6137772855499,44207469146213,133710496588987,-292837558553783,-810613380683171,140524903637948,109098365627117,-3976706690291293,-8954108748023807,-1281751345972852,422014296471542,-57863965978708065,281860784337911677,624187151311239298,1139803685338002982,-104852950822770402,-3298643237469398703,3109366964882489214,3947078665691640801,-1888936637549218134,-53085559149257261,-54960883838313472,-1074429320
+6374628540732951412,-75,-64,-62,-97,-214,51,17,-378,-1485,3526,8020,-22535,41238,37962,-31480,-157098,-792898,-312601,1693543,-2185126,7226577,-10877800,-3352292,81072480,111579375,99289584,-145648375,-37564631,-3339698110,-3043611682,15929437254,16737198169,19234901097,-108424852479,-76101716766,-72294601468,138437060344,1362414977460,-1814311938465,648166659283,-1347559896278,6785025594164,-7263478858870,52611273640918,126335187492903,-306450398305686,-818304415687921,123067659307665,-35289239174926,-3860929772851986,-9326177886769213,-2960826405730125,-3311663890663492,-53198978866982303,293872701694432872,604960118479434755,1196571140298484262,-175735066766119170,-3339814776445649841,2887220993241709396,4942404503627845697,-1309598637983847326,-3679643457238937700,-2289236623772830720,1836985248
+6374628540732951412,-76,-64,-66,-103,-215,48,72,-498,-1388,3848,8154,-21036,38315,41843,-33036,-171051,-826262,-229986,1926867,-2050406,6387190,-9410408,-149969,78817433,126592045,81994317,-198276637,9850126,-3494092817,-2690201107,15505280471,18741083271,18952763689,-114554784964,-87434827039,-96395517488,145136227324,1249778691025,-1890030834712,200132550866,-1592688340854,6272941927052,-7443987898894,51481359120702,137728803389611,-306129522303348,-864778185151790,194321571903522,-149817755341048,-4136169621810644,-10215830575913247,-966165516213317,-5348673073269918,-54225444332979874,295719381720462597,622799621982663307,1233274134628458299,-145694396643728882,-3325456343176764193,2496856714044730081,5561621031870866178,535020433403755408,-2478667931290527870,-4455464710630201344,-625874078
+6374628540732951412,-76,-63,-69,-101,-227,51,52,-598,-1200,3931,8082,-21565,37582,49584,-17096,-150252,-836005,-126645,1790515,-2268063,6187642,-9697827,2004248,70820330,138995933,100131962,-157073477,86033336,-3544949896,-2517952941,14760628370,20376915713,20700991282,-119472765300,-75152848587,-120758257371,179881263783,1326812709165,-1803733122651,-20481103340,-1865226554655,7912993187209,-5426112766473,45516402786549,124000390514696,-324485533580718,-801464639293251,312953036697931,-351859562585981,-3787813739325638,-11128552833776230,-533585409329709,-7470222971336011,-60010205080148926,290165707748131369,602083946158023823,1273340086784117145,-223252526265982218,-3066682400509989695,2702175954195771196,6254510076869777175,-312807438343288479,-4028926223061305375,-7781275647862695936,1785069594
+6374628540732951412,-77,-65,-70,-107,-215,65,0,-526,-1309,3750,8827,-19795,34647,57042,-1662,-138248,-890685,-228568,1633079,-2725524,6163040,-10491702,697237,74565050,127767368,93532213,-209675435,-32116110,-3624917040,-2927805617,15581947241,21893441661,24075494509,-116822110531,-59683724667,-146210393388,114424524398,1341560771667,-1638742564263,520137948334,-2927347587131,7415137351179,-7963937754617,52157548982266,140803519083304,-294675355729619,-868076759504942,181128508165910,-91753231238823,-3511241416682881,-11545256318348796,-1952917510863468,-5161099825338866,-59726090170689781,287170105829528178,607326725526282735,1253194074103207461,-162443950414676064,-2964036188567341159,2602201580810990248,5581917084094110764,111739292249520611,-315687754593838642,-2804420462762366976,-2078683524
+6374628540732951412,-78,-67,-70,-104,-208,50,25,-623,-1187,3391,9739,-20154,33067,52103,-12694,-107660,-916626,-152198,1559465,-2667311,7030237,-8691249,1904923,76319509,140223746,103928605,-174921862,-153306984,-3385428163,-2706885189,15283784465,20634072096,22714148135,-114815487874,-57091612899,-142632639974,179528438499,1374129028047,-1585326364054,-18216163874,-2157779605185,5971568380320,-8248990257826,59840551103342,125712490704524,-307669634982741,-858094683247971,93893891048153,11148633321077,-3965770435394003,-12650916058119903,-2795866845057190,-2842039882731798,-52868794593707851,291851494824491430,613405825980084560,1195048607563687933,-273939860456041198,-2919907262065495844,2990262610803219987,4583129402810376165,181580162213099465,4246959372948686659,-8956300373167760384,-318239032
+6374628540732951412,-79,-66,-70,-103,-211,63,37,-595,-1179,3396,10362,-21973,32988,51028,1266,-90551,-974321,-31613,1714996,-3175743,7803348,-8287477,501893,69359083,137088342,98910512,-173819718,-131745320,-3145931336,-2686199273,14437121852,19278314851,23599545262,-120697088759,-41604248551,-134253143854,125780751857,1472026148608,-1776618224012,-39505326085,-2318306820897,6901170432395,-4924425394640,64990567565099,111652036534947,-305332110367408,-914163310042753,181646301536320,252098642650979,-4383324287711565,-11602686401257828,-2435353525115818,-3560138161105291,-51126739228676270,300782461836280108,605316482272626717,1261733050143910498,-211981074866465731,-2685765669824997821,2653189904999257520,3782321278929390916,691839715579758494,2922430866435912196,-6539643740485585920,591125801
+6374628540732951412,-79,-65,-68,-106,-224,77,89,-607,-1135,3451,9484,-23295,34025,48550,-1096,-117519,-1023386,75593,1973180,-2917550,7448519,-6719255,1194548,75565956,121842844,85511068,-236484389,-77920914,-3347810940,-2432406449,14398761485,19550588206,19319015041,-120331621122,-49291537846,-113645692387,186079119005,1445912072723,-1594239632237,187884174606,-1667593682845,4842491314977,-2037243955694,60592866568126,100707139232128,-271636012975317,-901512658325393,107877404575704,96843321318993,-4413459108735178,-12684262107537249,-1288575614077404,-4753229847204999,-59450164750841760,286819049822454845,627828035190966617,1333676316697423272,-231701725633497827,-2915387117003651933,2141225897956204193,4741105099975353488,-636369368218702947,1161463240285619111,-4728258995180919808,-1552550893
+6374628540732951412,-79,-66,-71,-112,-216,45,35,-732,-1194,3795,9477,-21700,37841,44077,-15724,-150155,-1054098,171533,1898280,-2557159,7946295,-6078699,1017876,75174339,126891433,67681852,-256285872,-108624465,-3165630000,-2137454747,13948703263,18836772055,16903715361,-122282296689,-43755576224,-109784652688,229128900914,1335153791826,-1442328168239,-52306649800,-2155622370297,4005199050133,-5381958908769,52491834504675,94347469815820,-271626584797028,-898354901511229,201233335343275,115697087772265,-4364496388658983,-11800659059499920,-3202742649983913,-5491614099014783,-62506861293490140,282869092497009652,633887985829034563,1404650407457743606,-351621794812411035,-3192513442559103226,1709467503862433074,5078622062740096249,-358684656197809748,4234835756101177291,-3459764008912583680,-485578186
+6374628540732951412,-79,-66,-70,-110,-222,72,-9,-696,-1334,3335,9310,-21766,38735,50839,-4538,-163191,-1039232,198436,1736294,-3027254,7709621,-7847689,2676984,68848363,134281982,60318167,-278581669,-105043277,-2909443977,-2174286128,13477379339,17790111190,19091856608,-115538174122,-47114388946,-121296166479,165807156817,1436633213435,-1541488188355,-332257209955,-3069270164580,6031543714670,-6272108745706,60287216098125,95314096859118,-243572488356326,-933995610201533,195347971045851,22225154066364,-4445547869103942,-12101061349573389,-2987314711644919,-5426225800605162,-66584078265296336,293907625614072653,631057647418015205,1444945950888122232,-345707771529625552,-2963590236474531880,1516887593515730388,5696718101738613677,669119346156327898,2326897505454854752,-4614179417881133056,2014100340
+6374628540732951412,-79,-67,-74,-117,-215,45,11,-690,-1171,3071,9267,-22758,39108,55846,2258,-173523,-1007109,309859,1994388,-3409199,8127957,-6167031,4466716,63414039,143908089,32114088,-312462087,-83237120,-3053034313,-2357337371,12509346691,16209630717,15038385577,-110651917075,-31394971321,-124661488701,203579607878,1484035544256,-1507296531366,-845711274347,-2296863848896,3905869013836,-7521694724643,54217628486091,79540748719699,-209493501050146,-893176422552463,76466457140808,-244681349726851,-4271151983050883,-13104069670398667,-2511738146919007,-3620855730439323,-69396715005175675,285387379668513639,657418696472180953,1419022292908601041,-254028452376496144,-2946867248371434647,1546209102998135807,5059196862256488574,2796432364741585932,5089599960843410500,4335512773123899392,162640946
+6374628540732951412,-80,-66,-77,-121,-216,71,74,-626,-1409,3339,9168,-21410,41917,52892,-12574,-141068,-1013550,246586,1801662,-2889452,7195778,-4172443,3310205,69052815,139372881,59034718,-327119574,-12653046,-2888129216,-2744822643,12799693333,16918047615,17739119954,-105364043087,-25218457794,-94332800219,232192572072,1480859262369,-1738914999010,-391913744699,-2636555057128,2270804170430,-5600392513311,49019944743142,77884541904560,-241764048600387,-854865415428877,-61173551589161,-342972931576759,-4067736178810690,-13071176873556527,-1952403589973405,-1371111316457433,-71368306683781182,279140425186007357,650531759202957051,1439484811853485266,-247307351552497022,-2781941693591622438,1119325859969156527,4727662105519595016,3575558328637526142,4216628698772647036,8365962594021490688,1785915259
+6374628540732951412,-80,-67,-80,-127,-209,96,115,-568,-1217,3781,8631,-22483,40415,51258,241,-126890,-1049346,201618,2046577,-2704311,7020283,-4594470,5906125,64201007,151552074,48116535,-327136842,54062074,-3043009806,-2387597096,11781589477,15033113492,19437002906,-109008949514,-10491838578,-94444914570,210372477025,1420932321595,-1835385244333,-819733271443,-2579394901486,2827417431243,-5331787207658,40471883142627,79355033922458,-254266376479208,-853576192561556,-99420035886584,-438363427199951,-4263899620340750,-13587781126130073,-3877122394920300,2356869320962565,-73701805004254175,283520080474027713,675039783198056906,1421473492973994210,-234304636999666499,-2947889458413887128,1367580126706392372,4151869166532315028,2239851383601073051,3234842715380942549,2931913832230537216,-359918765
+6374628540732951412,-81,-68,-83,-121,-222,99,156,-601,-969,3334,8902,-23880,38378,53007,-1274,-98886,-1003906,292076,2306886,-2984211,6963658,-3234817,9148658,57007929,140818504,50768093,-312622034,177020693,-3197225745,-2545210327,12692451313,14629227973,19194198291,-111446979835,2275516820,-117195380026,154529606523,1394875919460,-1767059671131,-568289693655,-2682703195410,3640411820791,-9110714273399,37315637790022,62344474133367,-254943349172608,-835986468900785,-213505991190477,-699395531797690,-3851478496124250,-12791552862941850,-3817010017050906,2050746304391527,-67138942766389373,298434650622967976,643568250534871846,1427785882778246596,-318683532917985143,-2685372644191153003,1039969996505718926,3690273198755661880,4040230470477413167,2604472055608454793,-4988708846615500800,1941422918
+6374628540732951412,-81,-69,-85,-122,-238,67,212,-640,-1206,3035,8519,-25642,35107,48201,6953,-67169,-990239,243815,2487876,-3057229,7286542,-1180481,5187242,55324664,124920617,77691359,-376916885,150599159,-3154875559,-2978599942,13698531558,15749240323,15108664112,-108593170241,-13119949135,-137322419148,152184010235,1496307977496,-2007102174375,-940061241769,-2082536509900,2929865517903,-4783210685437,43187933082901,77808640101470,-252249102947991,-818878566679714,-275732807344728,-673223328654241,-4155039366186772,-12175760361658563,-5992494046357295,-1754594539279534,-75514506615364195,314246182844217826,678791154000627912,1396271858031419325,-299696595322054531,-2475927039186692158,1273912056899618306,4061448176943454865,4225383726000354190,18552536228170086,-128294725090347008,1837153026
+6374628540732951412,-81,-68,-86,-126,-247,38,232,-583,-1270,3040,8248,-25488,38896,43490,2320,-57964,-971122,190591,2527897,-3174153,6423579,-1830537,6858037,52302205,131246208,62862903,-400768319,233908357,-3128060756,-2961616595,14527763235,14481764552,14504830510,-107042791827,-22430257747,-149253907380,124641587958,1450408737583,-1895446704261,-771622519335,-2242572058029,3110050526734,-2881349894049,39078981190580,77316889567326,-242065069232668,-777298316668156,-251938614733833,-544954466788181,-4024214378525623,-12290561600970663,-4935902777462847,-5536679364163291,-78473786056914280,326274896904746261,642981252769316316,1351797858729528115,-410165329845723436,-2429107783297626305,1255426527643447581,4216389064012020045,3350690233221099270,96863131900153931,7292842344215490560,-939062277
+6374628540732951412,-81,-69,-87,-132,-245,58,170,-473,-1291,3014,8332,-26009,37132,46518,1263,-49165,-932144,127690,2669497,-3495005,6894500,-3564181,10238957,45557738,120713944,77272884,-375998884,327731336,-3230506723,-2843791959,14116244934,16024147303,15542826455,-112758826619,-37760148499,-158958348932,72891923465,1359548335667,-1860540560617,-332279922428,-1692089274077,4000478295669,-821700808424,44614794336819,88004637453584,-227628157883024,-766309937553429,-389612493785852,-567376417450064,-4442802899355231,-11403884182092159,-2903946759788054,-3895365224006874,-83239822852595761,311248048148620354,660978128751778385,1346226206163305945,-484685175821636969,-2340037077734203106,1692780145551470878,3452600154193343791,2703318322904830532,476023853837070083,-1573439380478895104,1140306395
+6374628540732951412,-81,-68,-88,-125,-254,58,193,-408,-1294,2554,7664,-27279,33851,45932,-6161,-65830,-887009,183973,2446256,-3814130,6917263,-3748535,8763392,40445194,115435749,90859457,-385035511,272471187,-3283841063,-2484349667,13150569836,16214469768,11586894208,-120109461491,-26902862663,-154100966983,17058605226,1305727386685,-1654964698870,-632753671097,-1081904684771,4409571114011,-3051931539516,51080831852367,102975210158143,-215219586846187,-766410586583480,-474944455144982,-695024129345161,-4167602099211335,-11299237470124842,-731856381799858,-3439553580308718,-83025859465746229,309331006898910287,631695111836652875,1337681453203084964,-618878022547284363,-2460788684646110509,1228502954388287529,3481914065719902593,2974163847589024230,-2580562524074712970,7387918360693492736,-1568357236
+6374628540732951412,-82,-68,-87,-119,-267,31,247,-379,-1536,2314,8276,-25366,32871,51902,-21982,-50964,-937418,253742,2688878,-3403290,7707744,-2298328,12500851,36699620,103044797,116034719,-341839704,138547599,-3362856651,-2047484692,13629474504,14378856916,13070728329,-126443850861,-32577273585,-145438750299,-23409641057,1215678287906,-1596212670017,-962339630968,-435137634694,3188513146268,-3236295539010,55973862150426,114513806846252,-216933894590146,-769551702825008,-570531835413966,-746379441083552,-4482010753626679,-11937502614730367,-529841995150424,-5912111462246252,-86138041527213496,304333568583578461,596948482250213653,1389700738502543532,-699741964933216833,-2192880376805024354,1109952496636271448,3879556913383456908,4720062702052554068,-2950458634759063902,7041677573917393920,-1015707823
+6374628540732951412,-82,-69,-85,-120,-271,11,307,-319,-1572,2278,9153,-23867,33148,52707,-12046,-24374,-919668,298089,2583088,-3013298,6943879,-329022,10960831,29613765,103693014,101761856,-279759660,8153672,-3511964742,-1807232980,13580957090,16291033756,13570236399,-129035531785,-41445802133,-120940742460,-456689748,1149051518083,-1797509664781,-1100128948408,-1182726085409,2065720463353,-280335994459,54470790038948,118775542217236,-224198069479428,-833607666683697,-494832017509895,-605110415437613,-4122072620931174,-12718919186749016,-1459570304838270,-9069417035326515,-85193149950526768,317672389264332471,575848201402198881,1337719942886650372,-806190215913813264,-2137680079301477532,1235102760314684323,4479481629251983517,4462004615583708132,-7348248970873948696,5000432420456630272,1015247445
+6374628540732951412,-83,-68,-89,-118,-256,17,250,-405,-1516,2335,8473,-25675,35761,54766,-17351,-31846,-920748,334104,2471548,-3330932,7287291,-2296803,8204418,32199669,101491614,77051961,-277409104,77590152,-3422331961,-2077313635,13248844153,17240997069,13246307363,-123628492105,-56289292114,-136088942092,-60962247225,1125584257134,-1552661084067,-630620464133,-1114801807945,1813743699234,395182986372,50308310349410,115217609107227,-208549594887035,-775542427520993,-468889713759549,-452915560433363,-3839486052762301,-13286414007120331,-1846813346445447,-13418559004113412,-84339943928282559,319298080310302085,590315004387529725,1271200436166049629,-674233755093332378,-1950789289497729387,1082831045886719409,4672588562516160974,2598287775728419359,-4637419236949910337,2093793665536733184,767228679
+6374628540732951412,-84,-69,-92,-113,-248,32,250,-495,-1271,1885,8361,-26566,34265,54664,-5145,-16924,-878833,264815,2694914,-2881364,7201171,-548325,10486391,28599077,111514695,64773252,-308666785,4497910,-3630246935,-2500419719,13876429038,16087802265,10074061168,-116620403390,-49100901153,-108247433943,-7105094551,1116793629350,-1667434431774,-867356720925,-624950775424,1059554100364,-2163453526545,57159389124725,116315213889956,-241832286710963,-764702770696767,-454459055437178,-541998041694344,-3405127368505405,-13221897574851498,-1738634255680743,-15650584367760422,-87961533878063719,306460398248412760,605671515834830268,1199950927737631218,-585485537466430357,-2229390893122827265,1598117876937004074,5392311371425429769,875360926444166744,-4018975973459816979,-1833095765983158272,1889021218
+6374628540732951412,-84,-71,-90,-108,-261,62,186,-454,-1254,2311,8227,-25534,32907,60763,-19725,3973,-876760,363765,2843718,-2596143,8216454,-1240394,14530830,24706713,103344538,68666701,-352092491,132023602,-3612679400,-2533600371,13883256789,14570982550,13752228599,-122344248191,-40779748463,-90759422773,-10406939524,1237848624773,-1478251599270,-1411930982118,-111248472354,-788680868074,-4243180970590,48893723692201,132915634508584,-246133763962663,-759135936355140,-434411659855671,-820897996317742,-3965931549189249,-14225002563422067,482074837131194,-14630733114037662,-85437020578240126,315154197807714333,633154943690597563,1176196870769851155,-648468982254256131,-1993617760234617187,1590426689214190835,5340371535041620244,2347671882456098153,-2599348479071073373,-4025427107959108608,-421048908
+6374628540732951412,-85,-73,-89,-107,-250,48,221,-365,-1490,2781,9109,-27440,32448,64785,-12917,19872,-936724,460023,2957999,-2969670,8668303,-2251723,15653023,31888713,92684775,52456559,-368301436,235303572,-3470736370,-2592378738,13433993735,15454583841,9503587853,-116981798350,-25935319365,-97617649071,-17136707737,1141050272260,-1709972893039,-983296539511,-896281380901,-2347439720862,-4393771726829,46684195799707,131044020145124,-261869817204520,-714713804120632,-575010300533080,-1083610401717574,-4155766186461024,-14896240137132574,-967868351986859,-17784527030850480,-93501323611186089,312529455882873735,637779759094253169,1130916085415231801,-712372714532910331,-2002358197187213306,1169047231535677523,5875524549642120660,3714472485880451996,-172894393713673652,3687211034383320064,-905892578
+6374628540732951412,-85,-72,-86,-104,-247,61,167,-463,-1514,3040,9712,-28231,32911,62970,-7732,34718,-944778,408922,3063458,-2725415,8910497,-3029052,14794618,23625943,85835817,58399584,-390347442,242307080,-3628861960,-2917359605,14144168095,15511296655,7711308235,-122293819180,-10723873119,-130461916708,-28905525564,1225427273680,-1960362347423,-1188761462535,-1541488656946,-3792661377144,-8525292302538,43289743568705,124182470925578,-251463319313585,-748081260603783,-532411275641930,-1016295766826954,-4256280090665980,-15996275819941210,-1144699420844080,-21947444465540478,-85738319085839088,313448077477567560,629473548716637747,1147092107008153637,-852908085111815313,-1996845112690978978,1264385847894502457,6141156403513155746,4735517693662464677,-1862490465908796587,-4153247260559246336,-462180864
+6374628540732951412,-86,-73,-88,-105,-251,30,196,-481,-1261,2572,9681,-26520,32997,58634,5642,55338,-962582,370849,3171465,-3051654,8674248,-2031488,18146447,24478683,94699749,52561880,-404877404,323005330,-3630101448,-2746671982,13215326111,15435020041,11082132873,-125786946413,486300068,-139429212518,17013989505,1241878418785,-2059348489032,-1704452472926,-668232926239,-5082276919024,-6940881210115,35809603465431,122989102265800,-222495015727631,-684764568058433,-581046045876546,-851635265406564,-4041492096769703,-15238892543095568,114834119838220,-24800000653307089,-83759299224632800,315141190837407796,635772313037137935,1107558475511670688,-810799969839688253,-1851984389808990670,970759119700039242,6686684121425336949,6783718056934529693,-6241760656906767251,947173554328357888,-1494333306
+6374628540732951412,-86,-72,-88,-112,-266,24,226,-448,-1276,2088,9271,-26782,34328,50867,5162,63344,-932537,348891,3342470,-2791247,8947264,-1853869,16945349,22249102,105459723,23375449,-432370583,418228854,-3711778749,-3214188822,13860460924,15343631541,13970247939,-119264575962,14450863140,-143401755217,-33289276339,1226556371979,-2201528068767,-1853058628476,348711070419,-3318160933603,-3676315061004,37682932226657,114328962805508,-193494935396573,-724206240590090,-699847179656307,-879531173828287,-4097144902883413,-15510172272318588,1108894907582754,-24533310695754713,-88704797415051334,304280087076437687,647897975795045877,1142750673873202581,-802040048798296068,-1789095861088765936,953031813296761899,5972557435190066211,5872551221731341148,-8556077275024715542,1630576758686258176,-893424967
+6374628540732951412,-87,-72,-90,-110,-281,1,210,-405,-1338,2535,8656,-26482,37898,55075,4627,54531,-913162,456000,3501482,-3155384,9038761,72684,20404410,14932637,112614242,24166607,-471090402,471131777,-3540999343,-2929497709,14458550299,14970773033,15649153463,-125050450570,8766029076,-110703554734,5175923155,1089299325773,-1974617804154,-2032652356399,178080289175,-4060681664509,-1273793511885,30165531112108,120208340703380,-225719600727935,-755269578375519,-783090199875242,-1044562561176796,-4043987633862300,-15541262148604523,684356736206468,-23391263332052254,-93390784601523431,298335628044409878,619601895337317746,1140767251378782935,-775141271720077583,-1870249657449129446,508549049268996129,6921090493272115237,4185383976032674392,-4844667923845567131,331966904971448320,1225191101
+6374628540732951412,-87,-71,-91,-115,-275,30,179,-525,-1233,2167,8472,-25919,33810,53748,9897,42387,-900049,377464,3580887,-3360307,9800448,1143232,20335679,8648182,113522448,35799661,-532155481,512025996,-3652078293,-3019909957,14797406329,15700585538,12498821596,-123104274224,2362377556,-114326185115,-52947278556,1093509874464,-2129048811086,-2577513288239,-674367173917,-2413721305397,-2756167343985,32589760427429,120046519492258,-252006544888975,-694049385764645,-789320865471282,-794786475324685,-4214730855905785,-15799009842896542,-444394024228861,-21971064395970489,-92365818288821000,286847467794654740,631703113615540680,1190448921176319126,-649710457423039121,-1980578368318853871,464340069245504168,7755921336839904177,4687132135770723003,-1228693758789605259,-7384809820388201472,810009443
+6374628540732951412,-87,-71,-93,-116,-289,25,148,-459,-1174,2538,8569,-26717,35033,48563,11829,58443,-874153,464514,3514924,-3428802,9953551,1304376,18745603,1217519,107623383,40237069,-540789309,505077873,-3693940444,-3427913479,14551984590,14159144002,13948903738,-121595455646,67926717,-80687064387,-93883166861,1143301627255,-2157006526796,-2872389561750,-1687233877834,-2606279256818,-848347093247,31054673922113,125552316556877,-281293747564480,-722740070788495,-819853898837523,-992528152885549,-4417086377872183,-15518582415647190,-658866858524925,-19735876592993056,-89945250492623744,287454174044844977,630328078612392293,1186825800189213506,-673330812467463800,-1748394558776204996,113184072394355434,8385154401566275102,5590565649237451568,-2506639547564750287,-1724763376679445504,-2075311278
+6374628540732951412,-88,-70,-91,-111,-279,11,112,-417,-920,3027,8458,-26821,31801,45359,3914,48843,-877042,432425,3762988,-3091537,10633146,228910,16615191,9368077,107349696,70064579,-504908995,433595305,-3829264536,-2917776683,13695109853,14878739239,16294485702,-126462419064,-16672371326,-95740312896,-56204580156,1208139668420,-2033557944757,-2728427745843,-668909619888,-1249329391428,-2025365822564,29356807485234,141586368697905,-268212159470340,-672854531207608,-861754918433180,-1273123871409938,-4846188056192645,-15617829168629454,-1945279776492512,-19961502719344288,-89807771101252029,298141758745948268,636412720972922333,1215800562899086475,-630174940487965593,-1468782984774619017,308496860870965284,8626371421772904009,6847141361896021773,1174868714271493589,6347055187035315200,35074056
+6374628540732951412,-88,-69,-93,-116,-286,-16,161,-330,-829,3183,8849,-25850,29802,40870,8626,59525,-919490,472572,3895386,-3477924,10841359,2189135,18510857,9895224,112332282,52504944,-509257256,354245320,-3846087545,-3418163521,13409204731,13017199327,14804615407,-126885725130,-3909433124,-125759402232,-57807544762,1334381606340,-2228723881386,-2805620349656,-607053528997,-3229732329551,-768929565966,24616059021010,142195322191771,-248756530551046,-604287581079860,-878019827431629,-1293267473478871,-5283313368771159,-14898407718733062,-1044725139941145,-16365501787927874,-81374335700352365,292196396225976060,604013575400421204,1262048029979440561,-570727237421247458,-1395785292256177061,644276967857314087,7922343612658379410,5904538956686114291,883725006535260112,-1211801468718089216,-1515615921
+6374628540732951412,-89,-69,-96,-120,-282,15,199,-341,-939,3634,9049,-23873,31767,37862,20453,46805,-903866,376742,3770115,-3223334,11830959,578680,18981861,6090539,116634394,52278929,-506492766,267703719,-3987149993,-3059460730,12758482738,13896644859,11598231733,-135456817073,-2928464264,-97033321084,-20557157238,1280547166767,-1999528746131,-2290700903837,-1544206817786,-3461337030986,-2768090903181,16735270403755,139034502366662,-269030042867779,-559230274150096,-844816426348185,-1547739930604233,-5074034952042986,-15097580326399624,464304669546988,-14484597034834450,-89062546594559424,308537768759157122,619830040591808236,1246352977859527343,-478048065433537374,-1619098356884573558,716792987135479170,8340111830615972479,8191289557793343966,3096851376522508090,-427965549603437568,-187725302
+6374628540732951412,-89,-71,-94,-120,-279,-11,258,-447,-883,3853,9882,-22022,35134,40726,13170,73580,-922028,264884,3744198,-3125566,11216269,62120,15891394,8411832,118423175,25023477,-504449842,234511105,-3793311071,-2657594053,12952182955,13346366879,7448682019,-135803763138,-12337899144,-102724764802,-65288936311,1406767789408,-2036100172977,-2188186193321,-2478408406792,-2392198427199,-2675288822234,11387391211356,124978245115535,-255658970879677,-508306376491318,-783420480124735,-1501552798864938,-5350066244628466,-14330218092595409,-1626757890406917,-17489707412101225,-94144256097927144,322091640779474289,618219072262047143,1203191976753882922,-339524595036676889,-1364521646425971818,324571363085664859,8032707836138612087,7135715838207851317,7677919220416605919,-8276081720365067264,-1419344549
+6374628540732951412,-90,-73,-95,-121,-264,2,291,-483,-1079,3750,10156,-23629,35068,36035,-357,68171,-967457,326209,4001179,-3572795,10865958,1646177,11985199,5315043,108106064,21423107,-491435602,183498752,-3727812127,-2449258886,12091075953,15052685268,6801240888,-136739150073,-23447385759,-135814013293,-131145055202,1387419415567,-2256533544627,-2582333105644,-2752718884058,-1988908623088,-4219172998237,12807421038589,119851673738476,-273689268153116,-536077478145023,-688331369724131,-1373047177119176,-5734285358832262,-14063304856292086,-2852294738323079,-17121393108535388,-90890415098756082,332491826830499123,652971897879791491,1138386189960855067,-340149640376262466,-1312119115047069145,861702444263773535,7199338741056390764,8036736800036662171,5280652288227664373,-8304918865524432896,1431169164
+6374628540732951412,-91,-73,-93,-117,-263,-7,347,-541,-1050,3422,10043,-22557,39027,40306,2929,65420,-944359,407249,3806439,-3260858,11373115,1664520,11097268,6953150,108913235,47280430,-533886737,271631460,-3704768357,-2417610381,12817511410,15184174932,10610469088,-144071275601,-16392146842,-121469509600,-128706014014,1462001004273,-2306254756864,-2500726049130,-3396801429877,-3654448629526,-4073636957836,18943639316669,120418980173757,-258581845003295,-474222416026818,-603048829144728,-1380498277317129,-5337018697224292,-13170549689252034,-1273451768325736,-16542968511569363,-86709169980754148,328867762173952890,631848794564215787,1192027583177391224,-368145066355921757,-1110718599560401019,1362653316082324283,7399702617193283531,6142880904654730722,4056116775073437991,-4252242141757933568,1727433986
+6374628540732951412,-91,-73,-92,-112,-266,21,287,-649,-1143,3571,10365,-20789,39622,44634,-3868,88925,-892623,356008,3782167,-3287822,12100229,1200634,11695958,5140371,110578196,70394250,-532819671,180039126,-3963152961,-2225562840,13689164626,13949793802,14607490369,-148789527980,-6770585135,-122596353661,-189765905695,1459081097405,-2213844856388,-2339353280212,-2983015913152,-1556205451798,25566736062,14295875236735,132039077398484,-244487878962694,-469707077162451,-738711437210229,-1267864833040064,-4839062796040822,-13351421120352804,51543395957214,-19528951246061241,-80427209177304942,337842452983651939,660975364818453783,1227380162852160253,-346626410348368864,-1156817887460561608,979113867199223198,6500499995491284919,4487026678092276455,6186341845630425026,-7683917940323639296,-1061753866
+6374628540732951412,-92,-72,-94,-105,-270,-6,301,-663,-992,3239,9502,-22568,36166,48967,2757,113127,-924372,333457,3581292,-2834725,12455554,-406322,14050299,2523241,100903532,83832141,-567618334,300631032,-3770083099,-2003469564,13770748026,13438968529,17493762453,-152156790434,-11655956041,-145824742246,-257356284200,1324099670654,-2444755649833,-2417519844642,-2529559654450,-3362584094460,3289108029176,9705328866892,140378243918095,-278909572348556,-421246493172386,-728212553065951,-1520309801532664,-4528448735794951,-12989020171207020,827007838448102,-21063372191469123,-73729266894567944,333859662810921194,662164416828286070,1288260242878322187,-349883812420785044,-1035319444487018083,1295150441795835087,6507389337561690877,5459043462728158113,9150047972721273816,-4534521183599484928,163947535
+6374628540732951412,-92,-74,-91,-98,-284,-19,307,-541,-858,2768,9452,-24187,39555,52983,15917,105701,-906180,427934,3575661,-3032768,11917967,-552829,12393818,2625314,114509177,94314027,-593233197,354869027,-3820755609,-1917659924,12830253027,13473437311,20493825270,-156025641218,-4291009675,-169531043447,-274583365194,1336521924127,-2468986938632,-2230889936099,-1568464286063,-4506138225158,3229877818638,12399906212740,125680907365617,-283704004540928,-473441877644328,-665179942262310,-1589043853060464,-4243394368824771,-12242674002270851,1569199339549448,-18246937796713894,-75934348175205895,349169802770085628,652811417756521819,1219956191577636328,-264243856174830190,-785249022535119935,1394097662154219162,6568377759082379973,6447871099537247833,4935878625375546314,-6255078185115333632,1925460554
+6374628540732951412,-93,-73,-92,-100,-270,-12,243,-437,-844,3261,10200,-25855,37839,47901,18330,115358,-869576,365750,3779930,-2580387,12011359,380982,9279524,8970075,102492636,115081051,-564667834,412725661,-4022441230,-1672252423,13086784295,12583317784,18710495921,-152546209341,11549281749,-185147335527,-226662404797,1430117645100,-2528818446350,-2110551700379,-2418608573457,-2760323399688,4463648640362,9880550422885,125682999914610,-261738593074756,-483482581619059,-575776125348385,-1556843158926203,-4137821586533573,-12990096256147305,1796776478264853,-17639301170601489,-79040055442673505,340931700954299662,655776952084629485,1170861521006366720,-199688467064771183,-1069384153125819787,964483112426815818,5493795069880998415,7872280582705865001,4131718348684427682,2718544975792475136,-1535105731
+6374628540732951412,-94,-73,-90,-108,-266,-43,259,-511,-1067,3454,10157,-25616,41194,47126,30937,116613,-869545,302916,3706564,-2124291,12024990,-1234465,13183233,15782220,107690115,124496107,-533556506,337757757,-4269096594,-1647581801,13601037602,11602460405,19419126804,-155840951103,17851516364,-192954679281,-168166460755,1425506853659,-2587850645652,-2591165363816,-1678841219923,-3623998462398,2332266822074,3187402862601,130722590061697,-242195758772420,-504566825137475,-672693386617208,-1337279562083030,-3605462663093676,-12809920507488510,1564717623658287,-16917642551015928,-73198164726938073,335461781129407020,671795764644357700,1224051377734557166,-140557684512501831,-1227108390392883607,1119830376134225033,5847672573572042255,8418579144180573935,42033708978990185,5882610786591183872,-1078195912
+6374628540732951412,-94,-73,-87,-111,-275,-65,320,-441,-847,3353,10915,-25803,41851,44758,32526,140887,-834155,257998,3884501,-1783026,12447740,362826,17115223,11108579,100538156,116647004,-553512631,284914948,-4316918482,-2061591042,13019125285,13502646410,15617603181,-151676821678,7295460182,-190751969371,-179539010909,1476627249921,-2553701352325,-2886444160077,-2075786908291,-1507137420810,3108409295628,11391168653024,123500870740351,-240316936144707,-556264583322812,-559922876087191,-1477101320782159,-3108550477945727,-12186816063241851,3623230303790535,-13131465285891121,-79847917471341981,351910645683181684,646973495382368618,1223317011778016333,-129386953612570799,-952714746671963026,1578820985800825108,6232994782183574705,8683242297608930976,2495922847868606130,-2693411037047842816,1003419371
+6374628540732951412,-95,-75,-84,-110,-290,-47,323,-445,-676,3207,11688,-27530,42280,39493,30062,155530,-826818,361465,3640683,-2079171,12015580,-1334860,18472915,9161521,115848029,139355519,-564386330,289435047,-4429018566,-1677731185,13543420133,15033768490,15299808040,-143215199070,-9339817536,-212027363021,-131933537437,1419175213270,-2557758268072,-3395034025922,-2373034585105,564871742743,4716079761510,19284160931957,138905160848425,-222839137464167,-606805354226332,-428582035285524,-1702297100836908,-2630021518053102,-11265361492265321,1492979145965242,-10087553636789105,-87250853979087055,344061176915445511,660453203359103648,1213833912889432999,-18013524728681433,-957590121479313533,1718477975958575198,6500352298839798219,9026687750017193101,3185732947462236075,-7497153657970736128,-1808745234
+6374628540732951412,-96,-76,-81,-105,-284,-29,358,-519,-912,2785,11080,-28215,43951,44416,18640,162951,-845746,461291,3657189,-2545776,12221440,-2292393,19921966,11657069,125644906,144779801,-623441841,160342780,-4635127703,-1897305742,13160743632,13495806178,13335754672,-142954197939,-14640370155,-206698697992,-68698761900,1308722393127,-2369748066940,-3872229320946,-1846000315763,-288159137305,7893718524599,20773514805084,125678014062661,-195709081031849,-551247692836492,-289823509180869,-1813914071400986,-2870452739616087,-11155885267344188,-78350118122599,-12089376144543105,-90844871981077022,340388653300870411,640413953130207902,1274184228519175459,-58735802085736013,-1163426472201013615,2128754959469431468,6130837229303530910,7607000716961308589,232343714733070361,-6393911084598049792,1081677334
+6374628540732951412,-97,-77,-82,-105,-289,-54,336,-417,-904,2820,12019,-29931,41838,50337,6854,153109,-800183,419351,3707286,-2705051,12302461,-1852268,16467681,9125078,112963451,122141962,-688885088,173495371,-4759186564,-2129373503,13563606259,13447775391,10414036647,-141573178387,-18681752251,-216343473735,-49998500998,1405629582358,-2175699208232,-3729619429171,-1833046802084,386685205382,5767587042486,18286919982130,134088471082032,-193984330740372,-568956484311066,-358165557661821,-1568316094981543,-2358966642617600,-11397015580389180,181244091427808,-8111550545665694,-89608509844619671,336253150305998555,626511989741421302,1310453728204084276,-132697897580023684,-939835302709398892,1900428086554009659,5874017840770118001,7885278474134609717,1868918813417314847,2425185733846680576,1076184091
+6374628540732951412,-97,-78,-83,-101,-277,-54,352,-380,-1107,3140,12398,-29984,39628,52798,21722,123446,-785850,497128,3845929,-2287139,12054196,-1574139,16462608,15723848,115921526,126311724,-736933601,243793010,-4908995517,-2064441515,13666849018,13534091544,9153499524,-145075681359,-18096114625,-210393849849,-83182369823,1491640316082,-2112544743575,-4189692342028,-865600354454,303798670376,9198703134322,21221862394095,149019342207506,-196736585251638,-515714384524355,-243424554543377,-1512836848519809,-2681476938306880,-11949486363068769,-990384461375962,-10256771793862319,-86337000646354558,330863058267482561,598605929805438903,1265291945995030141,-183879963339222012,-1041504069883276687,2143332489235862210,6775377306042214062,8672530717683779428,4723732348066713460,1516843097386757120,-1844391978
+6374628540732951412,-97,-77,-81,-98,-264,-85,401,-323,-1011,2784,11536,-31393,43660,56677,22088,130016,-814019,374424,3617847,-1980622,11364952,-2999099,16018767,10352961,126501310,97473079,-690204590,200622179,-5076482949,-1764083170,12933347111,11826248786,6560351642,-147552065439,-15142281568,-223585896149,-33227671386,1538599762175,-2362921007285,-3842494269476,-1373663369535,-852841858119,5799948666156,22731339821848,154930063725569,-163457157374824,-561656261654081,-138020128526929,-1489678851056393,-2440568531014139,-12097412756592508,23675917128353,-12500371193614451,-93802339455219790,330411633364304894,575874890105110270,1206912044361606409,-172982254780896399,-1250668862010774221,1818961442268800363,6503887458921090956,7940917196478188227,8256352863375251982,-6375055782190919680,-1170074517
+6374628540732951412,-98,-77,-79,-94,-268,-93,435,-297,-992,3029,12332,-32220,42716,56693,37157,140038,-753592,328432,3435044,-1639684,11488779,-4520200,19618162,4736713,132397186,120366688,-707412649,178583328,-5101042373,-1641909274,13608442787,11949591288,9744196157,-146357423476,-23050939771,-244178626927,-20347866857,1466357936174,-2380232157533,-4377238241361,-2188203995042,-2397751481716,6482566615530,17411397689829,162648908822640,-179081766728533,-554926593786784,-225458087713757,-1401415700896974,-2853895084719263,-12456890656234958,-906832043785907,-13594126155054595,-94092562020842673,335300447320897759,601324108284830769,1175414403728411052,-247339820533912278,-1248795588172426413,2064046934682788474,7119080588502240048,7514299980188634294,8953706209014158096,-1883000484130080768,-1537703209
+6374628540732951412,-99,-78,-77,-89,-254,-62,432,-368,-970,3319,12002,-33534,42946,50703,46009,113241,-761554,346472,3423380,-1708876,11711162,-5521477,22844427,-1798476,147108904,88078757,-648370478,261833017,-5259443484,-1256899523,13385215366,12887392316,10998289662,-147163520873,-33802911107,-210867772017,-59380454205,1401435646825,-2572306538125,-4743189497224,-1887240569410,-2045372302050,3721323601074,12343451658586,147681697393328,-173631606029040,-504496883954036,-364044910004097,-1350592421136002,-2919783655275585,-12183897187040112,-2330300938477796,-9293894880241862,-94647392931900711,343943716412179710,608781623628790024,1129454885116734185,-337499997306114485,-1248486123630796652,2615366492889756017,6345116122915994512,6632887521689425565,6988480586687270911,778616428197083136,-1077352087
+6374628540732951412,-100,-79,-74,-83,-261,-38,432,-438,-854,3794,12190,-34021,42034,53855,54004,126283,-698857,243056,3527333,-1844719,11839216,-3774513,20583298,1306929,135165934,112746240,-693249006,220524802,-5352565547,-1445558721,14085837211,12069100020,8134921714,-139991207848,-48667480572,-180938968517,-75285604108,1452435319571,-2845809622956,-4483711106295,-1373591168907,-3513349491131,-583853487440,10887880330555,145374639353376,-158449650543217,-462358026026177,-291445694767592,-1335492669183622,-3404519259018197,-11519200095099184,-1779079486776686,-8636934320742974,-86029003917349578,358302517069012889,572940782694805973,1196487646124111155,-366597711339136635,-1508317170472846023,2989916733904352642,5666598452818716729,6795994566807492946,3687806819332783733,457474461569725440,1739998497
+6374628540732951412,-100,-79,-72,-77,-252,-20,412,-461,-818,3598,11542,-32285,43352,54594,69786,139542,-707295,250342,3434792,-1803412,10939757,-4856584,22067096,-3032200,128322292,115304089,-667697931,249254895,-5245028251,-1356327422,14417262962,10107116418,7064673460,-134129341135,-35429765134,-199765546174,-140364184058,1576982420503,-2894822519325,-4032541925514,-1059624020870,-3455781185296,2539561448002,7938539578742,136029910631035,-193265867139148,-448382548643265,-205655233382150,-1516868351876918,-3211597771608176,-11594309183340874,-3745339737199961,-7715780961270925,-84639081207171051,349907439755764838,540365425882720435,1220688172790277774,-358267483413840644,-1396762397922836866,2614184465637246193,6232910895744952206,5242985859007589448,1473783855700242825,5827945093090948096,1620621738
+6374628540732951412,-101,-79,-72,-84,-240,-48,360,-465,-881,3925,12218,-34063,43589,58417,69737,128335,-741870,336528,3546277,-1332772,11043844,-4655269,22503499,-7832990,140183618,144155849,-606341388,139878353,-5257127603,-1127496444,13546063015,9997712606,3963701829,-134381584617,-45206531902,-221582815555,-85050850261,1589303680812,-2660725312450,-4386993052650,-2059262809481,-3258597701165,631809257476,15695363675497,142946857706115,-212163260674953,-411453656385908,-86363404721932,-1332667873067094,-3365705520796607,-11031013390055446,-3990179289581347,-9105233918759693,-89957307508746340,357171120578071231,532603832683726379,1181327672906547036,-432890972624289430,-1235277922664510558,2592225859168162366,6371838621351677894,5437083258914735530,5075744262921326338,3416877765364260864,894696646
+6374628540732951412,-102,-81,-71,-79,-238,-25,323,-366,-754,3704,11990,-35523,42029,57374,64814,155742,-736391,248421,3432438,-1256840,11490839,-6177121,24234003,-2234980,136490145,144886252,-560460322,204877359,-5205329756,-652031115,13071902699,10707600187,5339923233,-136992714066,-41903318065,-231711924579,-73203318523,1553619909950,-2922224063547,-3938857796970,-1493782445471,-4567289389713,1082836651016,17121916691604,127852308845483,-218032263213729,-374648737367576,-144466674950788,-1291640621216687,-3299131803613346,-11222212754960877,-5168495411584519,-10728078292646261,-83808227567696513,342172331312616772,554895802904898733,1231483395721661918,-495008864807653086,-1000917543314614477,2486631645881191778,5774339079395439910,3336238468409181154,2208048563781175730,5509432560499826688,311454177
+6374628540732951412,-102,-81,-72,-79,-242,-49,372,-345,-893,3193,12497,-34200,44350,56647,60800,160625,-775420,273606,3418085,-1055162,10683242,-6853968,21679458,-1711068,142126488,144150308,-517394302,129962664,-5410943210,-1030115368,12097962562,8574587096,7740706332,-139280596381,-42113486467,-197572824349,-32618428856,1570235129998,-2699754702181,-4335884340854,-1131746858516,-4789922135582,-896653327630,10970732404821,120571335830821,-194164033892690,-337302927993282,-209539858212534,-1101464186586358,-2905180154787969,-12252098427538032,-4075832679862612,-10576929492634263,-90243993537452932,346794972723518953,557375139662679092,1184439216743852522,-368095751099526622,-1104376364325472106,2744934846281972561,4875910337145147369,2539545632307129033,6627027981251675780,2751481182911687680,-33600110
+6374628540732951412,-103,-82,-76,-80,-240,-49,420,-453,-910,2776,12441,-33014,41071,51462,66377,135127,-812910,159863,3509618,-852090,11285074,-6229635,24050571,-400762,156951013,143317984,-567114910,52802641,-5592873465,-777478954,13138009521,6857533677,8863588889,-137395323638,-37549130165,-208408025658,-89076585568,1613148303285,-2856942090246,-4661280838604,-1939905361775,-6910453057480,1377889278995,12817885499419,123714273910533,-170949294568307,-404665304859990,-184217031234261,-1376797329011516,-3206197664723563,-12655908669418963,-4265107467907752,-9363077003546189,-87072525573919378,339327984088896748,547805038716748472,1143261452286758060,-340410414981654839,-1287669391695280471,2665314427058622212,5939343424110852343,412712778301938597,8308586053446153357,2582452610370829312,254136712
+6374628540732951412,-104,-81,-79,-76,-225,-56,476,-445,-705,2991,13445,-31482,37591,56830,54892,117120,-869404,169701,3486458,-738990,11468562,-7810176,25967351,-3614446,169215083,126508638,-509609350,-73897743,-5834684238,-1116937816,12486888195,5519472531,6827835625,-134137965682,-45476514027,-235545293061,-38569682708,1665294434042,-2926211785103,-4988823342986,-1815807069420,-8738522616121,3421723231327,4819946441429,123492356993182,-166142443625906,-437742562887761,-167953568755481,-1110066079910075,-3202631183416215,-12820676404046078,-2077434338564761,-9717723443718841,-78396717836077841,330050710980862265,583116630203587195,1172623084473726962,-482599285977119474,-1527139128303696219,2383699001671886398,5463299955536442082,302758700234446107,5110441496851123501,-204551969942868992,697406929
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/dict-page-offset-zero.parquet b/src/arrow/cpp/submodules/parquet-testing/data/dict-page-offset-zero.parquet
new file mode 100644
index 000000000..f9dbd7fd5
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/dict-page-offset-zero.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer.parquet.encrypted b/src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer.parquet.encrypted
new file mode 100644
index 000000000..460d05b37
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer.parquet.encrypted
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer_aad.parquet.encrypted b/src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer_aad.parquet.encrypted
new file mode 100644
index 000000000..863ccb27f
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer_aad.parquet.encrypted
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer_ctr.parquet.encrypted b/src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer_ctr.parquet.encrypted
new file mode 100644
index 000000000..0591df491
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer_ctr.parquet.encrypted
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer_disable_aad_storage.parquet.encrypted b/src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer_disable_aad_storage.parquet.encrypted
new file mode 100644
index 000000000..abd40e711
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_and_footer_disable_aad_storage.parquet.encrypted
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_plaintext_footer.parquet.encrypted b/src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_plaintext_footer.parquet.encrypted
new file mode 100644
index 000000000..0625080a3
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/encrypt_columns_plaintext_footer.parquet.encrypted
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/fixed_length_decimal.parquet b/src/arrow/cpp/submodules/parquet-testing/data/fixed_length_decimal.parquet
new file mode 100644
index 000000000..69fce531e
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/fixed_length_decimal.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/fixed_length_decimal_legacy.parquet b/src/arrow/cpp/submodules/parquet-testing/data/fixed_length_decimal_legacy.parquet
new file mode 100644
index 000000000..b0df62a2e
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/fixed_length_decimal_legacy.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/hadoop_lz4_compressed.parquet b/src/arrow/cpp/submodules/parquet-testing/data/hadoop_lz4_compressed.parquet
new file mode 100644
index 000000000..b5fadcd49
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/hadoop_lz4_compressed.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/hadoop_lz4_compressed_larger.parquet b/src/arrow/cpp/submodules/parquet-testing/data/hadoop_lz4_compressed_larger.parquet
new file mode 100644
index 000000000..0f133f897
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/hadoop_lz4_compressed_larger.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/int32_decimal.parquet b/src/arrow/cpp/submodules/parquet-testing/data/int32_decimal.parquet
new file mode 100644
index 000000000..5bf2d4ea3
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/int32_decimal.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/int64_decimal.parquet b/src/arrow/cpp/submodules/parquet-testing/data/int64_decimal.parquet
new file mode 100644
index 000000000..5043bcac5
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/int64_decimal.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/list_columns.parquet b/src/arrow/cpp/submodules/parquet-testing/data/list_columns.parquet
new file mode 100644
index 000000000..ecd7597e2
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/list_columns.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/lz4_raw_compressed.parquet b/src/arrow/cpp/submodules/parquet-testing/data/lz4_raw_compressed.parquet
new file mode 100644
index 000000000..4f78711b5
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/lz4_raw_compressed.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/lz4_raw_compressed_larger.parquet b/src/arrow/cpp/submodules/parquet-testing/data/lz4_raw_compressed_larger.parquet
new file mode 100644
index 000000000..b83c59e5f
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/lz4_raw_compressed_larger.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/nation.dict-malformed.parquet b/src/arrow/cpp/submodules/parquet-testing/data/nation.dict-malformed.parquet
new file mode 100644
index 000000000..5008ac0b2
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/nation.dict-malformed.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/nested_lists.snappy.parquet b/src/arrow/cpp/submodules/parquet-testing/data/nested_lists.snappy.parquet
new file mode 100644
index 000000000..f66ba04b6
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/nested_lists.snappy.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/nested_maps.snappy.parquet b/src/arrow/cpp/submodules/parquet-testing/data/nested_maps.snappy.parquet
new file mode 100644
index 000000000..6645527df
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/nested_maps.snappy.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/nested_structs.rust.parquet b/src/arrow/cpp/submodules/parquet-testing/data/nested_structs.rust.parquet
new file mode 100644
index 000000000..1355cff47
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/nested_structs.rust.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/non_hadoop_lz4_compressed.parquet b/src/arrow/cpp/submodules/parquet-testing/data/non_hadoop_lz4_compressed.parquet
new file mode 100644
index 000000000..cfbdc7ef2
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/non_hadoop_lz4_compressed.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/nonnullable.impala.parquet b/src/arrow/cpp/submodules/parquet-testing/data/nonnullable.impala.parquet
new file mode 100644
index 000000000..f4be08287
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/nonnullable.impala.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/nullable.impala.parquet b/src/arrow/cpp/submodules/parquet-testing/data/nullable.impala.parquet
new file mode 100644
index 000000000..2c72f52f3
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/nullable.impala.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/nulls.snappy.parquet b/src/arrow/cpp/submodules/parquet-testing/data/nulls.snappy.parquet
new file mode 100644
index 000000000..4046d79b7
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/nulls.snappy.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/repeated_no_annotation.parquet b/src/arrow/cpp/submodules/parquet-testing/data/repeated_no_annotation.parquet
new file mode 100644
index 000000000..02f20a64c
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/repeated_no_annotation.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/single_nan.parquet b/src/arrow/cpp/submodules/parquet-testing/data/single_nan.parquet
new file mode 100644
index 000000000..84dac10f0
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/single_nan.parquet
Binary files differ
diff --git a/src/arrow/cpp/submodules/parquet-testing/data/uniform_encryption.parquet.encrypted b/src/arrow/cpp/submodules/parquet-testing/data/uniform_encryption.parquet.encrypted
new file mode 100644
index 000000000..048f35f7e
--- /dev/null
+++ b/src/arrow/cpp/submodules/parquet-testing/data/uniform_encryption.parquet.encrypted
Binary files differ
diff --git a/src/arrow/cpp/thirdparty/README.md b/src/arrow/cpp/thirdparty/README.md
new file mode 100644
index 000000000..86ccfa71d
--- /dev/null
+++ b/src/arrow/cpp/thirdparty/README.md
@@ -0,0 +1,25 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# Arrow C++ Thirdparty Dependencies
+
+See the "Build Dependency Management" section in the [C++ Developer
+Documentation][1].
+
+[1]: https://github.com/apache/arrow/blob/master/docs/source/developers/cpp/building.rst
diff --git a/src/arrow/cpp/thirdparty/download_dependencies.sh b/src/arrow/cpp/thirdparty/download_dependencies.sh
new file mode 100755
index 000000000..7ffffa08c
--- /dev/null
+++ b/src/arrow/cpp/thirdparty/download_dependencies.sh
@@ -0,0 +1,63 @@
+#!/usr/bin/env bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This script downloads all the thirdparty dependencies as a series of tarballs
+# that can be used for offline builds, etc.
+
+set -eu
+
+SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+
+if [ "$#" -ne 1 ]; then
+ DESTDIR=$(pwd)
+else
+ DESTDIR=$1
+fi
+
+DESTDIR=$(readlink -f "${DESTDIR}")
+
+download_dependency() {
+ local url=$1
+ local out=$2
+
+ wget --quiet --continue --output-document="${out}" "${url}" || \
+ (echo "Failed downloading ${url}" 1>&2; exit 1)
+}
+
+main() {
+ mkdir -p "${DESTDIR}"
+
+ # Load `DEPENDENCIES` variable.
+ source ${SOURCE_DIR}/versions.txt
+
+ echo "# Environment variables for offline Arrow build"
+ for ((i = 0; i < ${#DEPENDENCIES[@]}; i++)); do
+ local dep_packed=${DEPENDENCIES[$i]}
+
+ # Unpack each entry of the form "$home_var $tar_out $dep_url"
+ IFS=" " read -r dep_url_var dep_tar_name dep_url <<< "${dep_packed}"
+
+ local out=${DESTDIR}/${dep_tar_name}
+ download_dependency "${dep_url}" "${out}"
+
+ echo "export ${dep_url_var}=${out}"
+ done
+}
+
+main
diff --git a/src/arrow/cpp/thirdparty/flatbuffers/include/flatbuffers/base.h b/src/arrow/cpp/thirdparty/flatbuffers/include/flatbuffers/base.h
new file mode 100644
index 000000000..955738067
--- /dev/null
+++ b/src/arrow/cpp/thirdparty/flatbuffers/include/flatbuffers/base.h
@@ -0,0 +1,398 @@
+#ifndef FLATBUFFERS_BASE_H_
+#define FLATBUFFERS_BASE_H_
+
+// clang-format off
+
+// If activate should be declared and included first.
+#if defined(FLATBUFFERS_MEMORY_LEAK_TRACKING) && \
+ defined(_MSC_VER) && defined(_DEBUG)
+ // The _CRTDBG_MAP_ALLOC inside <crtdbg.h> will replace
+ // calloc/free (etc) to its debug version using #define directives.
+ #define _CRTDBG_MAP_ALLOC
+ #include <stdlib.h>
+ #include <crtdbg.h>
+ // Replace operator new by trace-enabled version.
+ #define DEBUG_NEW new(_NORMAL_BLOCK, __FILE__, __LINE__)
+ #define new DEBUG_NEW
+#endif
+
+#if !defined(FLATBUFFERS_ASSERT)
+#include <assert.h>
+#define FLATBUFFERS_ASSERT assert
+#elif defined(FLATBUFFERS_ASSERT_INCLUDE)
+// Include file with forward declaration
+#include FLATBUFFERS_ASSERT_INCLUDE
+#endif
+
+#ifndef ARDUINO
+#include <cstdint>
+#endif
+
+#include <cstddef>
+#include <cstdlib>
+#include <cstring>
+
+#if defined(ARDUINO) && !defined(ARDUINOSTL_M_H)
+ #include <utility.h>
+#else
+ #include <utility>
+#endif
+
+#include <string>
+#include <type_traits>
+#include <vector>
+#include <set>
+#include <algorithm>
+#include <iterator>
+#include <memory>
+
+#ifdef _STLPORT_VERSION
+ #define FLATBUFFERS_CPP98_STL
+#endif
+#ifndef FLATBUFFERS_CPP98_STL
+ #include <functional>
+#endif
+
+#include "flatbuffers/stl_emulation.h"
+
+#if defined(__ICCARM__)
+#include <intrinsics.h>
+#endif
+
+// Note the __clang__ check is needed, because clang presents itself
+// as an older GNUC compiler (4.2).
+// Clang 3.3 and later implement all of the ISO C++ 2011 standard.
+// Clang 3.4 and later implement all of the ISO C++ 2014 standard.
+// http://clang.llvm.org/cxx_status.html
+
+// Note the MSVC value '__cplusplus' may be incorrect:
+// The '__cplusplus' predefined macro in the MSVC stuck at the value 199711L,
+// indicating (erroneously!) that the compiler conformed to the C++98 Standard.
+// This value should be correct starting from MSVC2017-15.7-Preview-3.
+// The '__cplusplus' will be valid only if MSVC2017-15.7-P3 and the `/Zc:__cplusplus` switch is set.
+// Workaround (for details see MSDN):
+// Use the _MSC_VER and _MSVC_LANG definition instead of the __cplusplus for compatibility.
+// The _MSVC_LANG macro reports the Standard version regardless of the '/Zc:__cplusplus' switch.
+
+#if defined(__GNUC__) && !defined(__clang__)
+ #define FLATBUFFERS_GCC (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__)
+#else
+ #define FLATBUFFERS_GCC 0
+#endif
+
+#if defined(__clang__)
+ #define FLATBUFFERS_CLANG (__clang_major__ * 10000 + __clang_minor__ * 100 + __clang_patchlevel__)
+#else
+ #define FLATBUFFERS_CLANG 0
+#endif
+
+/// @cond FLATBUFFERS_INTERNAL
+#if __cplusplus <= 199711L && \
+ (!defined(_MSC_VER) || _MSC_VER < 1600) && \
+ (!defined(__GNUC__) || \
+ (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__ < 40400))
+ #error A C++11 compatible compiler with support for the auto typing is \
+ required for FlatBuffers.
+ #error __cplusplus _MSC_VER __GNUC__ __GNUC_MINOR__ __GNUC_PATCHLEVEL__
+#endif
+
+#if !defined(__clang__) && \
+ defined(__GNUC__) && \
+ (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__ < 40600)
+ // Backwards compatibility for g++ 4.4, and 4.5 which don't have the nullptr
+ // and constexpr keywords. Note the __clang__ check is needed, because clang
+ // presents itself as an older GNUC compiler.
+ #ifndef nullptr_t
+ const class nullptr_t {
+ public:
+ template<class T> inline operator T*() const { return 0; }
+ private:
+ void operator&() const;
+ } nullptr = {};
+ #endif
+ #ifndef constexpr
+ #define constexpr const
+ #endif
+#endif
+
+// The wire format uses a little endian encoding (since that's efficient for
+// the common platforms).
+#if defined(__s390x__)
+ #define FLATBUFFERS_LITTLEENDIAN 0
+#endif // __s390x__
+#if !defined(FLATBUFFERS_LITTLEENDIAN)
+ #if defined(__GNUC__) || defined(__clang__) || defined(__ICCARM__)
+ #if (defined(__BIG_ENDIAN__) || \
+ (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__))
+ #define FLATBUFFERS_LITTLEENDIAN 0
+ #else
+ #define FLATBUFFERS_LITTLEENDIAN 1
+ #endif // __BIG_ENDIAN__
+ #elif defined(_MSC_VER)
+ #if defined(_M_PPC)
+ #define FLATBUFFERS_LITTLEENDIAN 0
+ #else
+ #define FLATBUFFERS_LITTLEENDIAN 1
+ #endif
+ #else
+ #error Unable to determine endianness, define FLATBUFFERS_LITTLEENDIAN.
+ #endif
+#endif // !defined(FLATBUFFERS_LITTLEENDIAN)
+
+#define FLATBUFFERS_VERSION_MAJOR 1
+#define FLATBUFFERS_VERSION_MINOR 12
+#define FLATBUFFERS_VERSION_REVISION 0
+#define FLATBUFFERS_STRING_EXPAND(X) #X
+#define FLATBUFFERS_STRING(X) FLATBUFFERS_STRING_EXPAND(X)
+namespace flatbuffers {
+ // Returns version as string "MAJOR.MINOR.REVISION".
+ const char* FLATBUFFERS_VERSION();
+}
+
+#if (!defined(_MSC_VER) || _MSC_VER > 1600) && \
+ (!defined(__GNUC__) || (__GNUC__ * 100 + __GNUC_MINOR__ >= 407)) || \
+ defined(__clang__)
+ #define FLATBUFFERS_FINAL_CLASS final
+ #define FLATBUFFERS_OVERRIDE override
+ #define FLATBUFFERS_VTABLE_UNDERLYING_TYPE : flatbuffers::voffset_t
+#else
+ #define FLATBUFFERS_FINAL_CLASS
+ #define FLATBUFFERS_OVERRIDE
+ #define FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+#endif
+
+#if (!defined(_MSC_VER) || _MSC_VER >= 1900) && \
+ (!defined(__GNUC__) || (__GNUC__ * 100 + __GNUC_MINOR__ >= 406)) || \
+ (defined(__cpp_constexpr) && __cpp_constexpr >= 200704)
+ #define FLATBUFFERS_CONSTEXPR constexpr
+#else
+ #define FLATBUFFERS_CONSTEXPR const
+#endif
+
+#if (defined(__cplusplus) && __cplusplus >= 201402L) || \
+ (defined(__cpp_constexpr) && __cpp_constexpr >= 201304)
+ #define FLATBUFFERS_CONSTEXPR_CPP14 FLATBUFFERS_CONSTEXPR
+#else
+ #define FLATBUFFERS_CONSTEXPR_CPP14
+#endif
+
+#if (defined(__GXX_EXPERIMENTAL_CXX0X__) && (__GNUC__ * 100 + __GNUC_MINOR__ >= 406)) || \
+ (defined(_MSC_FULL_VER) && (_MSC_FULL_VER >= 190023026)) || \
+ defined(__clang__)
+ #define FLATBUFFERS_NOEXCEPT noexcept
+#else
+ #define FLATBUFFERS_NOEXCEPT
+#endif
+
+// NOTE: the FLATBUFFERS_DELETE_FUNC macro may change the access mode to
+// private, so be sure to put it at the end or reset access mode explicitly.
+#if (!defined(_MSC_VER) || _MSC_FULL_VER >= 180020827) && \
+ (!defined(__GNUC__) || (__GNUC__ * 100 + __GNUC_MINOR__ >= 404)) || \
+ defined(__clang__)
+ #define FLATBUFFERS_DELETE_FUNC(func) func = delete;
+#else
+ #define FLATBUFFERS_DELETE_FUNC(func) private: func;
+#endif
+
+#ifndef FLATBUFFERS_HAS_STRING_VIEW
+ // Only provide flatbuffers::string_view if __has_include can be used
+ // to detect a header that provides an implementation
+ #if defined(__has_include)
+ // Check for std::string_view (in c++17)
+ #if __has_include(<string_view>) && (__cplusplus >= 201606 || (defined(_HAS_CXX17) && _HAS_CXX17))
+ #include <string_view>
+ namespace flatbuffers {
+ typedef std::string_view string_view;
+ }
+ #define FLATBUFFERS_HAS_STRING_VIEW 1
+ // Check for std::experimental::string_view (in c++14, compiler-dependent)
+ #elif __has_include(<experimental/string_view>) && (__cplusplus >= 201411)
+ #include <experimental/string_view>
+ namespace flatbuffers {
+ typedef std::experimental::string_view string_view;
+ }
+ #define FLATBUFFERS_HAS_STRING_VIEW 1
+ // Check for absl::string_view
+ #elif __has_include("absl/strings/string_view.h")
+ #include "absl/strings/string_view.h"
+ namespace flatbuffers {
+ typedef absl::string_view string_view;
+ }
+ #define FLATBUFFERS_HAS_STRING_VIEW 1
+ #endif
+ #endif // __has_include
+#endif // !FLATBUFFERS_HAS_STRING_VIEW
+
+#ifndef FLATBUFFERS_HAS_NEW_STRTOD
+ // Modern (C++11) strtod and strtof functions are available for use.
+ // 1) nan/inf strings as argument of strtod;
+ // 2) hex-float as argument of strtod/strtof.
+ #if (defined(_MSC_VER) && _MSC_VER >= 1900) || \
+ (defined(__GNUC__) && (__GNUC__ * 100 + __GNUC_MINOR__ >= 409)) || \
+ (defined(__clang__))
+ #define FLATBUFFERS_HAS_NEW_STRTOD 1
+ #endif
+#endif // !FLATBUFFERS_HAS_NEW_STRTOD
+
+#ifndef FLATBUFFERS_LOCALE_INDEPENDENT
+ // Enable locale independent functions {strtof_l, strtod_l,strtoll_l, strtoull_l}.
+ // They are part of the POSIX-2008 but not part of the C/C++ standard.
+ // GCC/Clang have definition (_XOPEN_SOURCE>=700) if POSIX-2008.
+ #if ((defined(_MSC_VER) && _MSC_VER >= 1800) || \
+ (defined(_XOPEN_SOURCE) && (_XOPEN_SOURCE>=700)))
+ #define FLATBUFFERS_LOCALE_INDEPENDENT 1
+ #else
+ #define FLATBUFFERS_LOCALE_INDEPENDENT 0
+ #endif
+#endif // !FLATBUFFERS_LOCALE_INDEPENDENT
+
+// Suppress Undefined Behavior Sanitizer (recoverable only). Usage:
+// - __supress_ubsan__("undefined")
+// - __supress_ubsan__("signed-integer-overflow")
+#if defined(__clang__) && (__clang_major__ > 3 || (__clang_major__ == 3 && __clang_minor__ >=7))
+ #define __supress_ubsan__(type) __attribute__((no_sanitize(type)))
+#elif defined(__GNUC__) && (__GNUC__ * 100 + __GNUC_MINOR__ >= 409)
+ #define __supress_ubsan__(type) __attribute__((no_sanitize_undefined))
+#else
+ #define __supress_ubsan__(type)
+#endif
+
+// This is constexpr function used for checking compile-time constants.
+// Avoid `#pragma warning(disable: 4127) // C4127: expression is constant`.
+template<typename T> FLATBUFFERS_CONSTEXPR inline bool IsConstTrue(T t) {
+ return !!t;
+}
+
+// Enable C++ attribute [[]] if std:c++17 or higher.
+#if ((__cplusplus >= 201703L) \
+ || (defined(_MSVC_LANG) && (_MSVC_LANG >= 201703L)))
+ // All attributes unknown to an implementation are ignored without causing an error.
+ #define FLATBUFFERS_ATTRIBUTE(attr) [[attr]]
+
+ #define FLATBUFFERS_FALLTHROUGH() [[fallthrough]]
+#else
+ #define FLATBUFFERS_ATTRIBUTE(attr)
+
+ #if FLATBUFFERS_CLANG >= 30800
+ #define FLATBUFFERS_FALLTHROUGH() [[clang::fallthrough]]
+ #elif FLATBUFFERS_GCC >= 70300
+ #define FLATBUFFERS_FALLTHROUGH() [[gnu::fallthrough]]
+ #else
+ #define FLATBUFFERS_FALLTHROUGH()
+ #endif
+#endif
+
+/// @endcond
+
+/// @file
+namespace flatbuffers {
+
+/// @cond FLATBUFFERS_INTERNAL
+// Our default offset / size type, 32bit on purpose on 64bit systems.
+// Also, using a consistent offset type maintains compatibility of serialized
+// offset values between 32bit and 64bit systems.
+typedef uint32_t uoffset_t;
+
+// Signed offsets for references that can go in both directions.
+typedef int32_t soffset_t;
+
+// Offset/index used in v-tables, can be changed to uint8_t in
+// format forks to save a bit of space if desired.
+typedef uint16_t voffset_t;
+
+typedef uintmax_t largest_scalar_t;
+
+// In 32bits, this evaluates to 2GB - 1
+#define FLATBUFFERS_MAX_BUFFER_SIZE ((1ULL << (sizeof(::flatbuffers::soffset_t) * 8 - 1)) - 1)
+
+// We support aligning the contents of buffers up to this size.
+#define FLATBUFFERS_MAX_ALIGNMENT 16
+
+#if defined(_MSC_VER)
+ #pragma warning(push)
+ #pragma warning(disable: 4127) // C4127: conditional expression is constant
+#endif
+
+template<typename T> T EndianSwap(T t) {
+ #if defined(_MSC_VER)
+ #define FLATBUFFERS_BYTESWAP16 _byteswap_ushort
+ #define FLATBUFFERS_BYTESWAP32 _byteswap_ulong
+ #define FLATBUFFERS_BYTESWAP64 _byteswap_uint64
+ #elif defined(__ICCARM__)
+ #define FLATBUFFERS_BYTESWAP16 __REV16
+ #define FLATBUFFERS_BYTESWAP32 __REV
+ #define FLATBUFFERS_BYTESWAP64(x) \
+ ((__REV(static_cast<uint32_t>(x >> 32U))) | (static_cast<uint64_t>(__REV(static_cast<uint32_t>(x)))) << 32U)
+ #else
+ #if defined(__GNUC__) && __GNUC__ * 100 + __GNUC_MINOR__ < 408 && !defined(__clang__)
+ // __builtin_bswap16 was missing prior to GCC 4.8.
+ #define FLATBUFFERS_BYTESWAP16(x) \
+ static_cast<uint16_t>(__builtin_bswap32(static_cast<uint32_t>(x) << 16))
+ #else
+ #define FLATBUFFERS_BYTESWAP16 __builtin_bswap16
+ #endif
+ #define FLATBUFFERS_BYTESWAP32 __builtin_bswap32
+ #define FLATBUFFERS_BYTESWAP64 __builtin_bswap64
+ #endif
+ if (sizeof(T) == 1) { // Compile-time if-then's.
+ return t;
+ } else if (sizeof(T) == 2) {
+ union { T t; uint16_t i; } u = { t };
+ u.i = FLATBUFFERS_BYTESWAP16(u.i);
+ return u.t;
+ } else if (sizeof(T) == 4) {
+ union { T t; uint32_t i; } u = { t };
+ u.i = FLATBUFFERS_BYTESWAP32(u.i);
+ return u.t;
+ } else if (sizeof(T) == 8) {
+ union { T t; uint64_t i; } u = { t };
+ u.i = FLATBUFFERS_BYTESWAP64(u.i);
+ return u.t;
+ } else {
+ FLATBUFFERS_ASSERT(0);
+ return t;
+ }
+}
+
+#if defined(_MSC_VER)
+ #pragma warning(pop)
+#endif
+
+
+template<typename T> T EndianScalar(T t) {
+ #if FLATBUFFERS_LITTLEENDIAN
+ return t;
+ #else
+ return EndianSwap(t);
+ #endif
+}
+
+template<typename T>
+// UBSAN: C++ aliasing type rules, see std::bit_cast<> for details.
+__supress_ubsan__("alignment")
+T ReadScalar(const void *p) {
+ return EndianScalar(*reinterpret_cast<const T *>(p));
+}
+
+template<typename T>
+// UBSAN: C++ aliasing type rules, see std::bit_cast<> for details.
+__supress_ubsan__("alignment")
+void WriteScalar(void *p, T t) {
+ *reinterpret_cast<T *>(p) = EndianScalar(t);
+}
+
+template<typename T> struct Offset;
+template<typename T> __supress_ubsan__("alignment") void WriteScalar(void *p, Offset<T> t) {
+ *reinterpret_cast<uoffset_t *>(p) = EndianScalar(t.o);
+}
+
+// Computes how many bytes you'd have to pad to be able to write an
+// "scalar_size" scalar if the buffer had grown to "buf_size" (downwards in
+// memory).
+__supress_ubsan__("unsigned-integer-overflow")
+inline size_t PaddingBytes(size_t buf_size, size_t scalar_size) {
+ return ((~buf_size) + 1) & (scalar_size - 1);
+}
+
+} // namespace flatbuffers
+#endif // FLATBUFFERS_BASE_H_
diff --git a/src/arrow/cpp/thirdparty/flatbuffers/include/flatbuffers/flatbuffers.h b/src/arrow/cpp/thirdparty/flatbuffers/include/flatbuffers/flatbuffers.h
new file mode 100644
index 000000000..c4dc5bcd0
--- /dev/null
+++ b/src/arrow/cpp/thirdparty/flatbuffers/include/flatbuffers/flatbuffers.h
@@ -0,0 +1,2783 @@
+/*
+ * Copyright 2014 Google Inc. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FLATBUFFERS_H_
+#define FLATBUFFERS_H_
+
+#include "flatbuffers/base.h"
+
+#if defined(FLATBUFFERS_NAN_DEFAULTS)
+# include <cmath>
+#endif
+
+namespace flatbuffers {
+// Generic 'operator==' with conditional specialisations.
+// T e - new value of a scalar field.
+// T def - default of scalar (is known at compile-time).
+template<typename T> inline bool IsTheSameAs(T e, T def) { return e == def; }
+
+#if defined(FLATBUFFERS_NAN_DEFAULTS) && \
+ defined(FLATBUFFERS_HAS_NEW_STRTOD) && (FLATBUFFERS_HAS_NEW_STRTOD > 0)
+// Like `operator==(e, def)` with weak NaN if T=(float|double).
+template<typename T> inline bool IsFloatTheSameAs(T e, T def) {
+ return (e == def) || ((def != def) && (e != e));
+}
+template<> inline bool IsTheSameAs<float>(float e, float def) {
+ return IsFloatTheSameAs(e, def);
+}
+template<> inline bool IsTheSameAs<double>(double e, double def) {
+ return IsFloatTheSameAs(e, def);
+}
+#endif
+
+// Check 'v' is out of closed range [low; high].
+// Workaround for GCC warning [-Werror=type-limits]:
+// comparison is always true due to limited range of data type.
+template<typename T>
+inline bool IsOutRange(const T &v, const T &low, const T &high) {
+ return (v < low) || (high < v);
+}
+
+// Check 'v' is in closed range [low; high].
+template<typename T>
+inline bool IsInRange(const T &v, const T &low, const T &high) {
+ return !IsOutRange(v, low, high);
+}
+
+// Wrapper for uoffset_t to allow safe template specialization.
+// Value is allowed to be 0 to indicate a null object (see e.g. AddOffset).
+template<typename T> struct Offset {
+ uoffset_t o;
+ Offset() : o(0) {}
+ Offset(uoffset_t _o) : o(_o) {}
+ Offset<void> Union() const { return Offset<void>(o); }
+ bool IsNull() const { return !o; }
+};
+
+inline void EndianCheck() {
+ int endiantest = 1;
+ // If this fails, see FLATBUFFERS_LITTLEENDIAN above.
+ FLATBUFFERS_ASSERT(*reinterpret_cast<char *>(&endiantest) ==
+ FLATBUFFERS_LITTLEENDIAN);
+ (void)endiantest;
+}
+
+template<typename T> FLATBUFFERS_CONSTEXPR size_t AlignOf() {
+ // clang-format off
+ #ifdef _MSC_VER
+ return __alignof(T);
+ #else
+ #ifndef alignof
+ return __alignof__(T);
+ #else
+ return alignof(T);
+ #endif
+ #endif
+ // clang-format on
+}
+
+// When we read serialized data from memory, in the case of most scalars,
+// we want to just read T, but in the case of Offset, we want to actually
+// perform the indirection and return a pointer.
+// The template specialization below does just that.
+// It is wrapped in a struct since function templates can't overload on the
+// return type like this.
+// The typedef is for the convenience of callers of this function
+// (avoiding the need for a trailing return decltype)
+template<typename T> struct IndirectHelper {
+ typedef T return_type;
+ typedef T mutable_return_type;
+ static const size_t element_stride = sizeof(T);
+ static return_type Read(const uint8_t *p, uoffset_t i) {
+ return EndianScalar((reinterpret_cast<const T *>(p))[i]);
+ }
+};
+template<typename T> struct IndirectHelper<Offset<T>> {
+ typedef const T *return_type;
+ typedef T *mutable_return_type;
+ static const size_t element_stride = sizeof(uoffset_t);
+ static return_type Read(const uint8_t *p, uoffset_t i) {
+ p += i * sizeof(uoffset_t);
+ return reinterpret_cast<return_type>(p + ReadScalar<uoffset_t>(p));
+ }
+};
+template<typename T> struct IndirectHelper<const T *> {
+ typedef const T *return_type;
+ typedef T *mutable_return_type;
+ static const size_t element_stride = sizeof(T);
+ static return_type Read(const uint8_t *p, uoffset_t i) {
+ return reinterpret_cast<const T *>(p + i * sizeof(T));
+ }
+};
+
+// An STL compatible iterator implementation for Vector below, effectively
+// calling Get() for every element.
+template<typename T, typename IT> struct VectorIterator {
+ typedef std::random_access_iterator_tag iterator_category;
+ typedef IT value_type;
+ typedef ptrdiff_t difference_type;
+ typedef IT *pointer;
+ typedef IT &reference;
+
+ VectorIterator(const uint8_t *data, uoffset_t i)
+ : data_(data + IndirectHelper<T>::element_stride * i) {}
+ VectorIterator(const VectorIterator &other) : data_(other.data_) {}
+ VectorIterator() : data_(nullptr) {}
+
+ VectorIterator &operator=(const VectorIterator &other) {
+ data_ = other.data_;
+ return *this;
+ }
+
+ // clang-format off
+ #if !defined(FLATBUFFERS_CPP98_STL)
+ VectorIterator &operator=(VectorIterator &&other) {
+ data_ = other.data_;
+ return *this;
+ }
+ #endif // !defined(FLATBUFFERS_CPP98_STL)
+ // clang-format on
+
+ bool operator==(const VectorIterator &other) const {
+ return data_ == other.data_;
+ }
+
+ bool operator<(const VectorIterator &other) const {
+ return data_ < other.data_;
+ }
+
+ bool operator!=(const VectorIterator &other) const {
+ return data_ != other.data_;
+ }
+
+ difference_type operator-(const VectorIterator &other) const {
+ return (data_ - other.data_) / IndirectHelper<T>::element_stride;
+ }
+
+ IT operator*() const { return IndirectHelper<T>::Read(data_, 0); }
+
+ IT operator->() const { return IndirectHelper<T>::Read(data_, 0); }
+
+ VectorIterator &operator++() {
+ data_ += IndirectHelper<T>::element_stride;
+ return *this;
+ }
+
+ VectorIterator operator++(int) {
+ VectorIterator temp(data_, 0);
+ data_ += IndirectHelper<T>::element_stride;
+ return temp;
+ }
+
+ VectorIterator operator+(const uoffset_t &offset) const {
+ return VectorIterator(data_ + offset * IndirectHelper<T>::element_stride,
+ 0);
+ }
+
+ VectorIterator &operator+=(const uoffset_t &offset) {
+ data_ += offset * IndirectHelper<T>::element_stride;
+ return *this;
+ }
+
+ VectorIterator &operator--() {
+ data_ -= IndirectHelper<T>::element_stride;
+ return *this;
+ }
+
+ VectorIterator operator--(int) {
+ VectorIterator temp(data_, 0);
+ data_ -= IndirectHelper<T>::element_stride;
+ return temp;
+ }
+
+ VectorIterator operator-(const uoffset_t &offset) const {
+ return VectorIterator(data_ - offset * IndirectHelper<T>::element_stride,
+ 0);
+ }
+
+ VectorIterator &operator-=(const uoffset_t &offset) {
+ data_ -= offset * IndirectHelper<T>::element_stride;
+ return *this;
+ }
+
+ private:
+ const uint8_t *data_;
+};
+
+template<typename Iterator>
+struct VectorReverseIterator : public std::reverse_iterator<Iterator> {
+ explicit VectorReverseIterator(Iterator iter)
+ : std::reverse_iterator<Iterator>(iter) {}
+
+ typename Iterator::value_type operator*() const {
+ return *(std::reverse_iterator<Iterator>::current);
+ }
+
+ typename Iterator::value_type operator->() const {
+ return *(std::reverse_iterator<Iterator>::current);
+ }
+};
+
+struct String;
+
+// This is used as a helper type for accessing vectors.
+// Vector::data() assumes the vector elements start after the length field.
+template<typename T> class Vector {
+ public:
+ typedef VectorIterator<T, typename IndirectHelper<T>::mutable_return_type>
+ iterator;
+ typedef VectorIterator<T, typename IndirectHelper<T>::return_type>
+ const_iterator;
+ typedef VectorReverseIterator<iterator> reverse_iterator;
+ typedef VectorReverseIterator<const_iterator> const_reverse_iterator;
+
+ uoffset_t size() const { return EndianScalar(length_); }
+
+ // Deprecated: use size(). Here for backwards compatibility.
+ FLATBUFFERS_ATTRIBUTE(deprecated("use size() instead"))
+ uoffset_t Length() const { return size(); }
+
+ typedef typename IndirectHelper<T>::return_type return_type;
+ typedef typename IndirectHelper<T>::mutable_return_type mutable_return_type;
+
+ return_type Get(uoffset_t i) const {
+ FLATBUFFERS_ASSERT(i < size());
+ return IndirectHelper<T>::Read(Data(), i);
+ }
+
+ return_type operator[](uoffset_t i) const { return Get(i); }
+
+ // If this is a Vector of enums, T will be its storage type, not the enum
+ // type. This function makes it convenient to retrieve value with enum
+ // type E.
+ template<typename E> E GetEnum(uoffset_t i) const {
+ return static_cast<E>(Get(i));
+ }
+
+ // If this a vector of unions, this does the cast for you. There's no check
+ // to make sure this is the right type!
+ template<typename U> const U *GetAs(uoffset_t i) const {
+ return reinterpret_cast<const U *>(Get(i));
+ }
+
+ // If this a vector of unions, this does the cast for you. There's no check
+ // to make sure this is actually a string!
+ const String *GetAsString(uoffset_t i) const {
+ return reinterpret_cast<const String *>(Get(i));
+ }
+
+ const void *GetStructFromOffset(size_t o) const {
+ return reinterpret_cast<const void *>(Data() + o);
+ }
+
+ iterator begin() { return iterator(Data(), 0); }
+ const_iterator begin() const { return const_iterator(Data(), 0); }
+
+ iterator end() { return iterator(Data(), size()); }
+ const_iterator end() const { return const_iterator(Data(), size()); }
+
+ reverse_iterator rbegin() { return reverse_iterator(end() - 1); }
+ const_reverse_iterator rbegin() const {
+ return const_reverse_iterator(end() - 1);
+ }
+
+ reverse_iterator rend() { return reverse_iterator(begin() - 1); }
+ const_reverse_iterator rend() const {
+ return const_reverse_iterator(begin() - 1);
+ }
+
+ const_iterator cbegin() const { return begin(); }
+
+ const_iterator cend() const { return end(); }
+
+ const_reverse_iterator crbegin() const { return rbegin(); }
+
+ const_reverse_iterator crend() const { return rend(); }
+
+ // Change elements if you have a non-const pointer to this object.
+ // Scalars only. See reflection.h, and the documentation.
+ void Mutate(uoffset_t i, const T &val) {
+ FLATBUFFERS_ASSERT(i < size());
+ WriteScalar(data() + i, val);
+ }
+
+ // Change an element of a vector of tables (or strings).
+ // "val" points to the new table/string, as you can obtain from
+ // e.g. reflection::AddFlatBuffer().
+ void MutateOffset(uoffset_t i, const uint8_t *val) {
+ FLATBUFFERS_ASSERT(i < size());
+ static_assert(sizeof(T) == sizeof(uoffset_t), "Unrelated types");
+ WriteScalar(data() + i,
+ static_cast<uoffset_t>(val - (Data() + i * sizeof(uoffset_t))));
+ }
+
+ // Get a mutable pointer to tables/strings inside this vector.
+ mutable_return_type GetMutableObject(uoffset_t i) const {
+ FLATBUFFERS_ASSERT(i < size());
+ return const_cast<mutable_return_type>(IndirectHelper<T>::Read(Data(), i));
+ }
+
+ // The raw data in little endian format. Use with care.
+ const uint8_t *Data() const {
+ return reinterpret_cast<const uint8_t *>(&length_ + 1);
+ }
+
+ uint8_t *Data() { return reinterpret_cast<uint8_t *>(&length_ + 1); }
+
+ // Similarly, but typed, much like std::vector::data
+ const T *data() const { return reinterpret_cast<const T *>(Data()); }
+ T *data() { return reinterpret_cast<T *>(Data()); }
+
+ template<typename K> return_type LookupByKey(K key) const {
+ void *search_result = std::bsearch(
+ &key, Data(), size(), IndirectHelper<T>::element_stride, KeyCompare<K>);
+
+ if (!search_result) {
+ return nullptr; // Key not found.
+ }
+
+ const uint8_t *element = reinterpret_cast<const uint8_t *>(search_result);
+
+ return IndirectHelper<T>::Read(element, 0);
+ }
+
+ protected:
+ // This class is only used to access pre-existing data. Don't ever
+ // try to construct these manually.
+ Vector();
+
+ uoffset_t length_;
+
+ private:
+ // This class is a pointer. Copying will therefore create an invalid object.
+ // Private and unimplemented copy constructor.
+ Vector(const Vector &);
+ Vector &operator=(const Vector &);
+
+ template<typename K> static int KeyCompare(const void *ap, const void *bp) {
+ const K *key = reinterpret_cast<const K *>(ap);
+ const uint8_t *data = reinterpret_cast<const uint8_t *>(bp);
+ auto table = IndirectHelper<T>::Read(data, 0);
+
+ // std::bsearch compares with the operands transposed, so we negate the
+ // result here.
+ return -table->KeyCompareWithValue(*key);
+ }
+};
+
+// Represent a vector much like the template above, but in this case we
+// don't know what the element types are (used with reflection.h).
+class VectorOfAny {
+ public:
+ uoffset_t size() const { return EndianScalar(length_); }
+
+ const uint8_t *Data() const {
+ return reinterpret_cast<const uint8_t *>(&length_ + 1);
+ }
+ uint8_t *Data() { return reinterpret_cast<uint8_t *>(&length_ + 1); }
+
+ protected:
+ VectorOfAny();
+
+ uoffset_t length_;
+
+ private:
+ VectorOfAny(const VectorOfAny &);
+ VectorOfAny &operator=(const VectorOfAny &);
+};
+
+#ifndef FLATBUFFERS_CPP98_STL
+template<typename T, typename U>
+Vector<Offset<T>> *VectorCast(Vector<Offset<U>> *ptr) {
+ static_assert(std::is_base_of<T, U>::value, "Unrelated types");
+ return reinterpret_cast<Vector<Offset<T>> *>(ptr);
+}
+
+template<typename T, typename U>
+const Vector<Offset<T>> *VectorCast(const Vector<Offset<U>> *ptr) {
+ static_assert(std::is_base_of<T, U>::value, "Unrelated types");
+ return reinterpret_cast<const Vector<Offset<T>> *>(ptr);
+}
+#endif
+
+// Convenient helper function to get the length of any vector, regardless
+// of whether it is null or not (the field is not set).
+template<typename T> static inline size_t VectorLength(const Vector<T> *v) {
+ return v ? v->size() : 0;
+}
+
+// This is used as a helper type for accessing arrays.
+template<typename T, uint16_t length> class Array {
+ typedef
+ typename flatbuffers::integral_constant<bool,
+ flatbuffers::is_scalar<T>::value>
+ scalar_tag;
+ typedef
+ typename flatbuffers::conditional<scalar_tag::value, T, const T *>::type
+ IndirectHelperType;
+
+ public:
+ typedef typename IndirectHelper<IndirectHelperType>::return_type return_type;
+ typedef VectorIterator<T, return_type> const_iterator;
+ typedef VectorReverseIterator<const_iterator> const_reverse_iterator;
+
+ FLATBUFFERS_CONSTEXPR uint16_t size() const { return length; }
+
+ return_type Get(uoffset_t i) const {
+ FLATBUFFERS_ASSERT(i < size());
+ return IndirectHelper<IndirectHelperType>::Read(Data(), i);
+ }
+
+ return_type operator[](uoffset_t i) const { return Get(i); }
+
+ // If this is a Vector of enums, T will be its storage type, not the enum
+ // type. This function makes it convenient to retrieve value with enum
+ // type E.
+ template<typename E> E GetEnum(uoffset_t i) const {
+ return static_cast<E>(Get(i));
+ }
+
+ const_iterator begin() const { return const_iterator(Data(), 0); }
+ const_iterator end() const { return const_iterator(Data(), size()); }
+
+ const_reverse_iterator rbegin() const {
+ return const_reverse_iterator(end());
+ }
+ const_reverse_iterator rend() const { return const_reverse_iterator(end()); }
+
+ const_iterator cbegin() const { return begin(); }
+ const_iterator cend() const { return end(); }
+
+ const_reverse_iterator crbegin() const { return rbegin(); }
+ const_reverse_iterator crend() const { return rend(); }
+
+ // Get a mutable pointer to elements inside this array.
+ // This method used to mutate arrays of structs followed by a @p Mutate
+ // operation. For primitive types use @p Mutate directly.
+ // @warning Assignments and reads to/from the dereferenced pointer are not
+ // automatically converted to the correct endianness.
+ typename flatbuffers::conditional<scalar_tag::value, void, T *>::type
+ GetMutablePointer(uoffset_t i) const {
+ FLATBUFFERS_ASSERT(i < size());
+ return const_cast<T *>(&data()[i]);
+ }
+
+ // Change elements if you have a non-const pointer to this object.
+ void Mutate(uoffset_t i, const T &val) { MutateImpl(scalar_tag(), i, val); }
+
+ // The raw data in little endian format. Use with care.
+ const uint8_t *Data() const { return data_; }
+
+ uint8_t *Data() { return data_; }
+
+ // Similarly, but typed, much like std::vector::data
+ const T *data() const { return reinterpret_cast<const T *>(Data()); }
+ T *data() { return reinterpret_cast<T *>(Data()); }
+
+ protected:
+ void MutateImpl(flatbuffers::integral_constant<bool, true>, uoffset_t i,
+ const T &val) {
+ FLATBUFFERS_ASSERT(i < size());
+ WriteScalar(data() + i, val);
+ }
+
+ void MutateImpl(flatbuffers::integral_constant<bool, false>, uoffset_t i,
+ const T &val) {
+ *(GetMutablePointer(i)) = val;
+ }
+
+ // This class is only used to access pre-existing data. Don't ever
+ // try to construct these manually.
+ // 'constexpr' allows us to use 'size()' at compile time.
+ // @note Must not use 'FLATBUFFERS_CONSTEXPR' here, as const is not allowed on
+ // a constructor.
+#if defined(__cpp_constexpr)
+ constexpr Array();
+#else
+ Array();
+#endif
+
+ uint8_t data_[length * sizeof(T)];
+
+ private:
+ // This class is a pointer. Copying will therefore create an invalid object.
+ // Private and unimplemented copy constructor.
+ Array(const Array &);
+ Array &operator=(const Array &);
+};
+
+// Specialization for Array[struct] with access using Offset<void> pointer.
+// This specialization used by idl_gen_text.cpp.
+template<typename T, uint16_t length> class Array<Offset<T>, length> {
+ static_assert(flatbuffers::is_same<T, void>::value, "unexpected type T");
+
+ public:
+ typedef const void *return_type;
+
+ const uint8_t *Data() const { return data_; }
+
+ // Make idl_gen_text.cpp::PrintContainer happy.
+ return_type operator[](uoffset_t) const {
+ FLATBUFFERS_ASSERT(false);
+ return nullptr;
+ }
+
+ private:
+ // This class is only used to access pre-existing data.
+ Array();
+ Array(const Array &);
+ Array &operator=(const Array &);
+
+ uint8_t data_[1];
+};
+
+// Lexicographically compare two strings (possibly containing nulls), and
+// return true if the first is less than the second.
+static inline bool StringLessThan(const char *a_data, uoffset_t a_size,
+ const char *b_data, uoffset_t b_size) {
+ const auto cmp = memcmp(a_data, b_data, (std::min)(a_size, b_size));
+ return cmp == 0 ? a_size < b_size : cmp < 0;
+}
+
+struct String : public Vector<char> {
+ const char *c_str() const { return reinterpret_cast<const char *>(Data()); }
+ std::string str() const { return std::string(c_str(), size()); }
+
+ // clang-format off
+ #ifdef FLATBUFFERS_HAS_STRING_VIEW
+ flatbuffers::string_view string_view() const {
+ return flatbuffers::string_view(c_str(), size());
+ }
+ #endif // FLATBUFFERS_HAS_STRING_VIEW
+ // clang-format on
+
+ bool operator<(const String &o) const {
+ return StringLessThan(this->data(), this->size(), o.data(), o.size());
+ }
+};
+
+// Convenience function to get std::string from a String returning an empty
+// string on null pointer.
+static inline std::string GetString(const String *str) {
+ return str ? str->str() : "";
+}
+
+// Convenience function to get char* from a String returning an empty string on
+// null pointer.
+static inline const char *GetCstring(const String *str) {
+ return str ? str->c_str() : "";
+}
+
+// Allocator interface. This is flatbuffers-specific and meant only for
+// `vector_downward` usage.
+class Allocator {
+ public:
+ virtual ~Allocator() {}
+
+ // Allocate `size` bytes of memory.
+ virtual uint8_t *allocate(size_t size) = 0;
+
+ // Deallocate `size` bytes of memory at `p` allocated by this allocator.
+ virtual void deallocate(uint8_t *p, size_t size) = 0;
+
+ // Reallocate `new_size` bytes of memory, replacing the old region of size
+ // `old_size` at `p`. In contrast to a normal realloc, this grows downwards,
+ // and is intended specifcally for `vector_downward` use.
+ // `in_use_back` and `in_use_front` indicate how much of `old_size` is
+ // actually in use at each end, and needs to be copied.
+ virtual uint8_t *reallocate_downward(uint8_t *old_p, size_t old_size,
+ size_t new_size, size_t in_use_back,
+ size_t in_use_front) {
+ FLATBUFFERS_ASSERT(new_size > old_size); // vector_downward only grows
+ uint8_t *new_p = allocate(new_size);
+ memcpy_downward(old_p, old_size, new_p, new_size, in_use_back,
+ in_use_front);
+ deallocate(old_p, old_size);
+ return new_p;
+ }
+
+ protected:
+ // Called by `reallocate_downward` to copy memory from `old_p` of `old_size`
+ // to `new_p` of `new_size`. Only memory of size `in_use_front` and
+ // `in_use_back` will be copied from the front and back of the old memory
+ // allocation.
+ void memcpy_downward(uint8_t *old_p, size_t old_size, uint8_t *new_p,
+ size_t new_size, size_t in_use_back,
+ size_t in_use_front) {
+ memcpy(new_p + new_size - in_use_back, old_p + old_size - in_use_back,
+ in_use_back);
+ memcpy(new_p, old_p, in_use_front);
+ }
+};
+
+// DefaultAllocator uses new/delete to allocate memory regions
+class DefaultAllocator : public Allocator {
+ public:
+ uint8_t *allocate(size_t size) FLATBUFFERS_OVERRIDE {
+ return new uint8_t[size];
+ }
+
+ void deallocate(uint8_t *p, size_t) FLATBUFFERS_OVERRIDE { delete[] p; }
+
+ static void dealloc(void *p, size_t) { delete[] static_cast<uint8_t *>(p); }
+};
+
+// These functions allow for a null allocator to mean use the default allocator,
+// as used by DetachedBuffer and vector_downward below.
+// This is to avoid having a statically or dynamically allocated default
+// allocator, or having to move it between the classes that may own it.
+inline uint8_t *Allocate(Allocator *allocator, size_t size) {
+ return allocator ? allocator->allocate(size)
+ : DefaultAllocator().allocate(size);
+}
+
+inline void Deallocate(Allocator *allocator, uint8_t *p, size_t size) {
+ if (allocator)
+ allocator->deallocate(p, size);
+ else
+ DefaultAllocator().deallocate(p, size);
+}
+
+inline uint8_t *ReallocateDownward(Allocator *allocator, uint8_t *old_p,
+ size_t old_size, size_t new_size,
+ size_t in_use_back, size_t in_use_front) {
+ return allocator ? allocator->reallocate_downward(old_p, old_size, new_size,
+ in_use_back, in_use_front)
+ : DefaultAllocator().reallocate_downward(
+ old_p, old_size, new_size, in_use_back, in_use_front);
+}
+
+// DetachedBuffer is a finished flatbuffer memory region, detached from its
+// builder. The original memory region and allocator are also stored so that
+// the DetachedBuffer can manage the memory lifetime.
+class DetachedBuffer {
+ public:
+ DetachedBuffer()
+ : allocator_(nullptr),
+ own_allocator_(false),
+ buf_(nullptr),
+ reserved_(0),
+ cur_(nullptr),
+ size_(0) {}
+
+ DetachedBuffer(Allocator *allocator, bool own_allocator, uint8_t *buf,
+ size_t reserved, uint8_t *cur, size_t sz)
+ : allocator_(allocator),
+ own_allocator_(own_allocator),
+ buf_(buf),
+ reserved_(reserved),
+ cur_(cur),
+ size_(sz) {}
+
+ // clang-format off
+ #if !defined(FLATBUFFERS_CPP98_STL)
+ // clang-format on
+ DetachedBuffer(DetachedBuffer &&other)
+ : allocator_(other.allocator_),
+ own_allocator_(other.own_allocator_),
+ buf_(other.buf_),
+ reserved_(other.reserved_),
+ cur_(other.cur_),
+ size_(other.size_) {
+ other.reset();
+ }
+ // clang-format off
+ #endif // !defined(FLATBUFFERS_CPP98_STL)
+ // clang-format on
+
+ // clang-format off
+ #if !defined(FLATBUFFERS_CPP98_STL)
+ // clang-format on
+ DetachedBuffer &operator=(DetachedBuffer &&other) {
+ if (this == &other) return *this;
+
+ destroy();
+
+ allocator_ = other.allocator_;
+ own_allocator_ = other.own_allocator_;
+ buf_ = other.buf_;
+ reserved_ = other.reserved_;
+ cur_ = other.cur_;
+ size_ = other.size_;
+
+ other.reset();
+
+ return *this;
+ }
+ // clang-format off
+ #endif // !defined(FLATBUFFERS_CPP98_STL)
+ // clang-format on
+
+ ~DetachedBuffer() { destroy(); }
+
+ const uint8_t *data() const { return cur_; }
+
+ uint8_t *data() { return cur_; }
+
+ size_t size() const { return size_; }
+
+ // clang-format off
+ #if 0 // disabled for now due to the ordering of classes in this header
+ template <class T>
+ bool Verify() const {
+ Verifier verifier(data(), size());
+ return verifier.Verify<T>(nullptr);
+ }
+
+ template <class T>
+ const T* GetRoot() const {
+ return flatbuffers::GetRoot<T>(data());
+ }
+
+ template <class T>
+ T* GetRoot() {
+ return flatbuffers::GetRoot<T>(data());
+ }
+ #endif
+ // clang-format on
+
+ // clang-format off
+ #if !defined(FLATBUFFERS_CPP98_STL)
+ // clang-format on
+ // These may change access mode, leave these at end of public section
+ FLATBUFFERS_DELETE_FUNC(DetachedBuffer(const DetachedBuffer &other))
+ FLATBUFFERS_DELETE_FUNC(
+ DetachedBuffer &operator=(const DetachedBuffer &other))
+ // clang-format off
+ #endif // !defined(FLATBUFFERS_CPP98_STL)
+ // clang-format on
+
+ protected:
+ Allocator *allocator_;
+ bool own_allocator_;
+ uint8_t *buf_;
+ size_t reserved_;
+ uint8_t *cur_;
+ size_t size_;
+
+ inline void destroy() {
+ if (buf_) Deallocate(allocator_, buf_, reserved_);
+ if (own_allocator_ && allocator_) { delete allocator_; }
+ reset();
+ }
+
+ inline void reset() {
+ allocator_ = nullptr;
+ own_allocator_ = false;
+ buf_ = nullptr;
+ reserved_ = 0;
+ cur_ = nullptr;
+ size_ = 0;
+ }
+};
+
+// This is a minimal replication of std::vector<uint8_t> functionality,
+// except growing from higher to lower addresses. i.e push_back() inserts data
+// in the lowest address in the vector.
+// Since this vector leaves the lower part unused, we support a "scratch-pad"
+// that can be stored there for temporary data, to share the allocated space.
+// Essentially, this supports 2 std::vectors in a single buffer.
+class vector_downward {
+ public:
+ explicit vector_downward(size_t initial_size, Allocator *allocator,
+ bool own_allocator, size_t buffer_minalign)
+ : allocator_(allocator),
+ own_allocator_(own_allocator),
+ initial_size_(initial_size),
+ buffer_minalign_(buffer_minalign),
+ reserved_(0),
+ buf_(nullptr),
+ cur_(nullptr),
+ scratch_(nullptr) {}
+
+ // clang-format off
+ #if !defined(FLATBUFFERS_CPP98_STL)
+ vector_downward(vector_downward &&other)
+ #else
+ vector_downward(vector_downward &other)
+ #endif // defined(FLATBUFFERS_CPP98_STL)
+ // clang-format on
+ : allocator_(other.allocator_),
+ own_allocator_(other.own_allocator_),
+ initial_size_(other.initial_size_),
+ buffer_minalign_(other.buffer_minalign_),
+ reserved_(other.reserved_),
+ buf_(other.buf_),
+ cur_(other.cur_),
+ scratch_(other.scratch_) {
+ // No change in other.allocator_
+ // No change in other.initial_size_
+ // No change in other.buffer_minalign_
+ other.own_allocator_ = false;
+ other.reserved_ = 0;
+ other.buf_ = nullptr;
+ other.cur_ = nullptr;
+ other.scratch_ = nullptr;
+ }
+
+ // clang-format off
+ #if !defined(FLATBUFFERS_CPP98_STL)
+ // clang-format on
+ vector_downward &operator=(vector_downward &&other) {
+ // Move construct a temporary and swap idiom
+ vector_downward temp(std::move(other));
+ swap(temp);
+ return *this;
+ }
+ // clang-format off
+ #endif // defined(FLATBUFFERS_CPP98_STL)
+ // clang-format on
+
+ ~vector_downward() {
+ clear_buffer();
+ clear_allocator();
+ }
+
+ void reset() {
+ clear_buffer();
+ clear();
+ }
+
+ void clear() {
+ if (buf_) {
+ cur_ = buf_ + reserved_;
+ } else {
+ reserved_ = 0;
+ cur_ = nullptr;
+ }
+ clear_scratch();
+ }
+
+ void clear_scratch() { scratch_ = buf_; }
+
+ void clear_allocator() {
+ if (own_allocator_ && allocator_) { delete allocator_; }
+ allocator_ = nullptr;
+ own_allocator_ = false;
+ }
+
+ void clear_buffer() {
+ if (buf_) Deallocate(allocator_, buf_, reserved_);
+ buf_ = nullptr;
+ }
+
+ // Relinquish the pointer to the caller.
+ uint8_t *release_raw(size_t &allocated_bytes, size_t &offset) {
+ auto *buf = buf_;
+ allocated_bytes = reserved_;
+ offset = static_cast<size_t>(cur_ - buf_);
+
+ // release_raw only relinquishes the buffer ownership.
+ // Does not deallocate or reset the allocator. Destructor will do that.
+ buf_ = nullptr;
+ clear();
+ return buf;
+ }
+
+ // Relinquish the pointer to the caller.
+ DetachedBuffer release() {
+ // allocator ownership (if any) is transferred to DetachedBuffer.
+ DetachedBuffer fb(allocator_, own_allocator_, buf_, reserved_, cur_,
+ size());
+ if (own_allocator_) {
+ allocator_ = nullptr;
+ own_allocator_ = false;
+ }
+ buf_ = nullptr;
+ clear();
+ return fb;
+ }
+
+ size_t ensure_space(size_t len) {
+ FLATBUFFERS_ASSERT(cur_ >= scratch_ && scratch_ >= buf_);
+ if (len > static_cast<size_t>(cur_ - scratch_)) { reallocate(len); }
+ // Beyond this, signed offsets may not have enough range:
+ // (FlatBuffers > 2GB not supported).
+ FLATBUFFERS_ASSERT(size() < FLATBUFFERS_MAX_BUFFER_SIZE);
+ return len;
+ }
+
+ inline uint8_t *make_space(size_t len) {
+ size_t space = ensure_space(len);
+ cur_ -= space;
+ return cur_;
+ }
+
+ // Returns nullptr if using the DefaultAllocator.
+ Allocator *get_custom_allocator() { return allocator_; }
+
+ uoffset_t size() const {
+ return static_cast<uoffset_t>(reserved_ - (cur_ - buf_));
+ }
+
+ uoffset_t scratch_size() const {
+ return static_cast<uoffset_t>(scratch_ - buf_);
+ }
+
+ size_t capacity() const { return reserved_; }
+
+ uint8_t *data() const {
+ FLATBUFFERS_ASSERT(cur_);
+ return cur_;
+ }
+
+ uint8_t *scratch_data() const {
+ FLATBUFFERS_ASSERT(buf_);
+ return buf_;
+ }
+
+ uint8_t *scratch_end() const {
+ FLATBUFFERS_ASSERT(scratch_);
+ return scratch_;
+ }
+
+ uint8_t *data_at(size_t offset) const { return buf_ + reserved_ - offset; }
+
+ void push(const uint8_t *bytes, size_t num) {
+ if (num > 0) { memcpy(make_space(num), bytes, num); }
+ }
+
+ // Specialized version of push() that avoids memcpy call for small data.
+ template<typename T> void push_small(const T &little_endian_t) {
+ make_space(sizeof(T));
+ *reinterpret_cast<T *>(cur_) = little_endian_t;
+ }
+
+ template<typename T> void scratch_push_small(const T &t) {
+ ensure_space(sizeof(T));
+ *reinterpret_cast<T *>(scratch_) = t;
+ scratch_ += sizeof(T);
+ }
+
+ // fill() is most frequently called with small byte counts (<= 4),
+ // which is why we're using loops rather than calling memset.
+ void fill(size_t zero_pad_bytes) {
+ make_space(zero_pad_bytes);
+ for (size_t i = 0; i < zero_pad_bytes; i++) cur_[i] = 0;
+ }
+
+ // Version for when we know the size is larger.
+ // Precondition: zero_pad_bytes > 0
+ void fill_big(size_t zero_pad_bytes) {
+ memset(make_space(zero_pad_bytes), 0, zero_pad_bytes);
+ }
+
+ void pop(size_t bytes_to_remove) { cur_ += bytes_to_remove; }
+ void scratch_pop(size_t bytes_to_remove) { scratch_ -= bytes_to_remove; }
+
+ void swap(vector_downward &other) {
+ using std::swap;
+ swap(allocator_, other.allocator_);
+ swap(own_allocator_, other.own_allocator_);
+ swap(initial_size_, other.initial_size_);
+ swap(buffer_minalign_, other.buffer_minalign_);
+ swap(reserved_, other.reserved_);
+ swap(buf_, other.buf_);
+ swap(cur_, other.cur_);
+ swap(scratch_, other.scratch_);
+ }
+
+ void swap_allocator(vector_downward &other) {
+ using std::swap;
+ swap(allocator_, other.allocator_);
+ swap(own_allocator_, other.own_allocator_);
+ }
+
+ private:
+ // You shouldn't really be copying instances of this class.
+ FLATBUFFERS_DELETE_FUNC(vector_downward(const vector_downward &))
+ FLATBUFFERS_DELETE_FUNC(vector_downward &operator=(const vector_downward &))
+
+ Allocator *allocator_;
+ bool own_allocator_;
+ size_t initial_size_;
+ size_t buffer_minalign_;
+ size_t reserved_;
+ uint8_t *buf_;
+ uint8_t *cur_; // Points at location between empty (below) and used (above).
+ uint8_t *scratch_; // Points to the end of the scratchpad in use.
+
+ void reallocate(size_t len) {
+ auto old_reserved = reserved_;
+ auto old_size = size();
+ auto old_scratch_size = scratch_size();
+ reserved_ +=
+ (std::max)(len, old_reserved ? old_reserved / 2 : initial_size_);
+ reserved_ = (reserved_ + buffer_minalign_ - 1) & ~(buffer_minalign_ - 1);
+ if (buf_) {
+ buf_ = ReallocateDownward(allocator_, buf_, old_reserved, reserved_,
+ old_size, old_scratch_size);
+ } else {
+ buf_ = Allocate(allocator_, reserved_);
+ }
+ cur_ = buf_ + reserved_ - old_size;
+ scratch_ = buf_ + old_scratch_size;
+ }
+};
+
+// Converts a Field ID to a virtual table offset.
+inline voffset_t FieldIndexToOffset(voffset_t field_id) {
+ // Should correspond to what EndTable() below builds up.
+ const int fixed_fields = 2; // Vtable size and Object Size.
+ return static_cast<voffset_t>((field_id + fixed_fields) * sizeof(voffset_t));
+}
+
+template<typename T, typename Alloc>
+const T *data(const std::vector<T, Alloc> &v) {
+ // Eventually the returned pointer gets passed down to memcpy, so
+ // we need it to be non-null to avoid undefined behavior.
+ static uint8_t t;
+ return v.empty() ? reinterpret_cast<const T *>(&t) : &v.front();
+}
+template<typename T, typename Alloc> T *data(std::vector<T, Alloc> &v) {
+ // Eventually the returned pointer gets passed down to memcpy, so
+ // we need it to be non-null to avoid undefined behavior.
+ static uint8_t t;
+ return v.empty() ? reinterpret_cast<T *>(&t) : &v.front();
+}
+
+/// @endcond
+
+/// @addtogroup flatbuffers_cpp_api
+/// @{
+/// @class FlatBufferBuilder
+/// @brief Helper class to hold data needed in creation of a FlatBuffer.
+/// To serialize data, you typically call one of the `Create*()` functions in
+/// the generated code, which in turn call a sequence of `StartTable`/
+/// `PushElement`/`AddElement`/`EndTable`, or the builtin `CreateString`/
+/// `CreateVector` functions. Do this is depth-first order to build up a tree to
+/// the root. `Finish()` wraps up the buffer ready for transport.
+class FlatBufferBuilder {
+ public:
+ /// @brief Default constructor for FlatBufferBuilder.
+ /// @param[in] initial_size The initial size of the buffer, in bytes. Defaults
+ /// to `1024`.
+ /// @param[in] allocator An `Allocator` to use. If null will use
+ /// `DefaultAllocator`.
+ /// @param[in] own_allocator Whether the builder/vector should own the
+ /// allocator. Defaults to / `false`.
+ /// @param[in] buffer_minalign Force the buffer to be aligned to the given
+ /// minimum alignment upon reallocation. Only needed if you intend to store
+ /// types with custom alignment AND you wish to read the buffer in-place
+ /// directly after creation.
+ explicit FlatBufferBuilder(
+ size_t initial_size = 1024, Allocator *allocator = nullptr,
+ bool own_allocator = false,
+ size_t buffer_minalign = AlignOf<largest_scalar_t>())
+ : buf_(initial_size, allocator, own_allocator, buffer_minalign),
+ num_field_loc(0),
+ max_voffset_(0),
+ nested(false),
+ finished(false),
+ minalign_(1),
+ force_defaults_(false),
+ dedup_vtables_(true),
+ string_pool(nullptr) {
+ EndianCheck();
+ }
+
+ // clang-format off
+ /// @brief Move constructor for FlatBufferBuilder.
+ #if !defined(FLATBUFFERS_CPP98_STL)
+ FlatBufferBuilder(FlatBufferBuilder &&other)
+ #else
+ FlatBufferBuilder(FlatBufferBuilder &other)
+ #endif // #if !defined(FLATBUFFERS_CPP98_STL)
+ : buf_(1024, nullptr, false, AlignOf<largest_scalar_t>()),
+ num_field_loc(0),
+ max_voffset_(0),
+ nested(false),
+ finished(false),
+ minalign_(1),
+ force_defaults_(false),
+ dedup_vtables_(true),
+ string_pool(nullptr) {
+ EndianCheck();
+ // Default construct and swap idiom.
+ // Lack of delegating constructors in vs2010 makes it more verbose than needed.
+ Swap(other);
+ }
+ // clang-format on
+
+ // clang-format off
+ #if !defined(FLATBUFFERS_CPP98_STL)
+ // clang-format on
+ /// @brief Move assignment operator for FlatBufferBuilder.
+ FlatBufferBuilder &operator=(FlatBufferBuilder &&other) {
+ // Move construct a temporary and swap idiom
+ FlatBufferBuilder temp(std::move(other));
+ Swap(temp);
+ return *this;
+ }
+ // clang-format off
+ #endif // defined(FLATBUFFERS_CPP98_STL)
+ // clang-format on
+
+ void Swap(FlatBufferBuilder &other) {
+ using std::swap;
+ buf_.swap(other.buf_);
+ swap(num_field_loc, other.num_field_loc);
+ swap(max_voffset_, other.max_voffset_);
+ swap(nested, other.nested);
+ swap(finished, other.finished);
+ swap(minalign_, other.minalign_);
+ swap(force_defaults_, other.force_defaults_);
+ swap(dedup_vtables_, other.dedup_vtables_);
+ swap(string_pool, other.string_pool);
+ }
+
+ ~FlatBufferBuilder() {
+ if (string_pool) delete string_pool;
+ }
+
+ void Reset() {
+ Clear(); // clear builder state
+ buf_.reset(); // deallocate buffer
+ }
+
+ /// @brief Reset all the state in this FlatBufferBuilder so it can be reused
+ /// to construct another buffer.
+ void Clear() {
+ ClearOffsets();
+ buf_.clear();
+ nested = false;
+ finished = false;
+ minalign_ = 1;
+ if (string_pool) string_pool->clear();
+ }
+
+ /// @brief The current size of the serialized buffer, counting from the end.
+ /// @return Returns an `uoffset_t` with the current size of the buffer.
+ uoffset_t GetSize() const { return buf_.size(); }
+
+ /// @brief Get the serialized buffer (after you call `Finish()`).
+ /// @return Returns an `uint8_t` pointer to the FlatBuffer data inside the
+ /// buffer.
+ uint8_t *GetBufferPointer() const {
+ Finished();
+ return buf_.data();
+ }
+
+ /// @brief Get a pointer to an unfinished buffer.
+ /// @return Returns a `uint8_t` pointer to the unfinished buffer.
+ uint8_t *GetCurrentBufferPointer() const { return buf_.data(); }
+
+ /// @brief Get the released pointer to the serialized buffer.
+ /// @warning Do NOT attempt to use this FlatBufferBuilder afterwards!
+ /// @return A `FlatBuffer` that owns the buffer and its allocator and
+ /// behaves similar to a `unique_ptr` with a deleter.
+ FLATBUFFERS_ATTRIBUTE(deprecated("use Release() instead"))
+ DetachedBuffer ReleaseBufferPointer() {
+ Finished();
+ return buf_.release();
+ }
+
+ /// @brief Get the released DetachedBuffer.
+ /// @return A `DetachedBuffer` that owns the buffer and its allocator.
+ DetachedBuffer Release() {
+ Finished();
+ return buf_.release();
+ }
+
+ /// @brief Get the released pointer to the serialized buffer.
+ /// @param size The size of the memory block containing
+ /// the serialized `FlatBuffer`.
+ /// @param offset The offset from the released pointer where the finished
+ /// `FlatBuffer` starts.
+ /// @return A raw pointer to the start of the memory block containing
+ /// the serialized `FlatBuffer`.
+ /// @remark If the allocator is owned, it gets deleted when the destructor is
+ /// called..
+ uint8_t *ReleaseRaw(size_t &size, size_t &offset) {
+ Finished();
+ return buf_.release_raw(size, offset);
+ }
+
+ /// @brief get the minimum alignment this buffer needs to be accessed
+ /// properly. This is only known once all elements have been written (after
+ /// you call Finish()). You can use this information if you need to embed
+ /// a FlatBuffer in some other buffer, such that you can later read it
+ /// without first having to copy it into its own buffer.
+ size_t GetBufferMinAlignment() {
+ Finished();
+ return minalign_;
+ }
+
+ /// @cond FLATBUFFERS_INTERNAL
+ void Finished() const {
+ // If you get this assert, you're attempting to get access a buffer
+ // which hasn't been finished yet. Be sure to call
+ // FlatBufferBuilder::Finish with your root table.
+ // If you really need to access an unfinished buffer, call
+ // GetCurrentBufferPointer instead.
+ FLATBUFFERS_ASSERT(finished);
+ }
+ /// @endcond
+
+ /// @brief In order to save space, fields that are set to their default value
+ /// don't get serialized into the buffer.
+ /// @param[in] fd When set to `true`, always serializes default values that
+ /// are set. Optional fields which are not set explicitly, will still not be
+ /// serialized.
+ void ForceDefaults(bool fd) { force_defaults_ = fd; }
+
+ /// @brief By default vtables are deduped in order to save space.
+ /// @param[in] dedup When set to `true`, dedup vtables.
+ void DedupVtables(bool dedup) { dedup_vtables_ = dedup; }
+
+ /// @cond FLATBUFFERS_INTERNAL
+ void Pad(size_t num_bytes) { buf_.fill(num_bytes); }
+
+ void TrackMinAlign(size_t elem_size) {
+ if (elem_size > minalign_) minalign_ = elem_size;
+ }
+
+ void Align(size_t elem_size) {
+ TrackMinAlign(elem_size);
+ buf_.fill(PaddingBytes(buf_.size(), elem_size));
+ }
+
+ void PushFlatBuffer(const uint8_t *bytes, size_t size) {
+ PushBytes(bytes, size);
+ finished = true;
+ }
+
+ void PushBytes(const uint8_t *bytes, size_t size) { buf_.push(bytes, size); }
+
+ void PopBytes(size_t amount) { buf_.pop(amount); }
+
+ template<typename T> void AssertScalarT() {
+ // The code assumes power of 2 sizes and endian-swap-ability.
+ static_assert(flatbuffers::is_scalar<T>::value, "T must be a scalar type");
+ }
+
+ // Write a single aligned scalar to the buffer
+ template<typename T> uoffset_t PushElement(T element) {
+ AssertScalarT<T>();
+ T litle_endian_element = EndianScalar(element);
+ Align(sizeof(T));
+ buf_.push_small(litle_endian_element);
+ return GetSize();
+ }
+
+ template<typename T> uoffset_t PushElement(Offset<T> off) {
+ // Special case for offsets: see ReferTo below.
+ return PushElement(ReferTo(off.o));
+ }
+
+ // When writing fields, we track where they are, so we can create correct
+ // vtables later.
+ void TrackField(voffset_t field, uoffset_t off) {
+ FieldLoc fl = { off, field };
+ buf_.scratch_push_small(fl);
+ num_field_loc++;
+ max_voffset_ = (std::max)(max_voffset_, field);
+ }
+
+ // Like PushElement, but additionally tracks the field this represents.
+ template<typename T> void AddElement(voffset_t field, T e, T def) {
+ // We don't serialize values equal to the default.
+ if (IsTheSameAs(e, def) && !force_defaults_) return;
+ auto off = PushElement(e);
+ TrackField(field, off);
+ }
+
+ template<typename T> void AddOffset(voffset_t field, Offset<T> off) {
+ if (off.IsNull()) return; // Don't store.
+ AddElement(field, ReferTo(off.o), static_cast<uoffset_t>(0));
+ }
+
+ template<typename T> void AddStruct(voffset_t field, const T *structptr) {
+ if (!structptr) return; // Default, don't store.
+ Align(AlignOf<T>());
+ buf_.push_small(*structptr);
+ TrackField(field, GetSize());
+ }
+
+ void AddStructOffset(voffset_t field, uoffset_t off) {
+ TrackField(field, off);
+ }
+
+ // Offsets initially are relative to the end of the buffer (downwards).
+ // This function converts them to be relative to the current location
+ // in the buffer (when stored here), pointing upwards.
+ uoffset_t ReferTo(uoffset_t off) {
+ // Align to ensure GetSize() below is correct.
+ Align(sizeof(uoffset_t));
+ // Offset must refer to something already in buffer.
+ FLATBUFFERS_ASSERT(off && off <= GetSize());
+ return GetSize() - off + static_cast<uoffset_t>(sizeof(uoffset_t));
+ }
+
+ void NotNested() {
+ // If you hit this, you're trying to construct a Table/Vector/String
+ // during the construction of its parent table (between the MyTableBuilder
+ // and table.Finish().
+ // Move the creation of these sub-objects to above the MyTableBuilder to
+ // not get this assert.
+ // Ignoring this assert may appear to work in simple cases, but the reason
+ // it is here is that storing objects in-line may cause vtable offsets
+ // to not fit anymore. It also leads to vtable duplication.
+ FLATBUFFERS_ASSERT(!nested);
+ // If you hit this, fields were added outside the scope of a table.
+ FLATBUFFERS_ASSERT(!num_field_loc);
+ }
+
+ // From generated code (or from the parser), we call StartTable/EndTable
+ // with a sequence of AddElement calls in between.
+ uoffset_t StartTable() {
+ NotNested();
+ nested = true;
+ return GetSize();
+ }
+
+ // This finishes one serialized object by generating the vtable if it's a
+ // table, comparing it against existing vtables, and writing the
+ // resulting vtable offset.
+ uoffset_t EndTable(uoffset_t start) {
+ // If you get this assert, a corresponding StartTable wasn't called.
+ FLATBUFFERS_ASSERT(nested);
+ // Write the vtable offset, which is the start of any Table.
+ // We fill it's value later.
+ auto vtableoffsetloc = PushElement<soffset_t>(0);
+ // Write a vtable, which consists entirely of voffset_t elements.
+ // It starts with the number of offsets, followed by a type id, followed
+ // by the offsets themselves. In reverse:
+ // Include space for the last offset and ensure empty tables have a
+ // minimum size.
+ max_voffset_ =
+ (std::max)(static_cast<voffset_t>(max_voffset_ + sizeof(voffset_t)),
+ FieldIndexToOffset(0));
+ buf_.fill_big(max_voffset_);
+ auto table_object_size = vtableoffsetloc - start;
+ // Vtable use 16bit offsets.
+ FLATBUFFERS_ASSERT(table_object_size < 0x10000);
+ WriteScalar<voffset_t>(buf_.data() + sizeof(voffset_t),
+ static_cast<voffset_t>(table_object_size));
+ WriteScalar<voffset_t>(buf_.data(), max_voffset_);
+ // Write the offsets into the table
+ for (auto it = buf_.scratch_end() - num_field_loc * sizeof(FieldLoc);
+ it < buf_.scratch_end(); it += sizeof(FieldLoc)) {
+ auto field_location = reinterpret_cast<FieldLoc *>(it);
+ auto pos = static_cast<voffset_t>(vtableoffsetloc - field_location->off);
+ // If this asserts, it means you've set a field twice.
+ FLATBUFFERS_ASSERT(
+ !ReadScalar<voffset_t>(buf_.data() + field_location->id));
+ WriteScalar<voffset_t>(buf_.data() + field_location->id, pos);
+ }
+ ClearOffsets();
+ auto vt1 = reinterpret_cast<voffset_t *>(buf_.data());
+ auto vt1_size = ReadScalar<voffset_t>(vt1);
+ auto vt_use = GetSize();
+ // See if we already have generated a vtable with this exact same
+ // layout before. If so, make it point to the old one, remove this one.
+ if (dedup_vtables_) {
+ for (auto it = buf_.scratch_data(); it < buf_.scratch_end();
+ it += sizeof(uoffset_t)) {
+ auto vt_offset_ptr = reinterpret_cast<uoffset_t *>(it);
+ auto vt2 = reinterpret_cast<voffset_t *>(buf_.data_at(*vt_offset_ptr));
+ auto vt2_size = ReadScalar<voffset_t>(vt2);
+ if (vt1_size != vt2_size || 0 != memcmp(vt2, vt1, vt1_size)) continue;
+ vt_use = *vt_offset_ptr;
+ buf_.pop(GetSize() - vtableoffsetloc);
+ break;
+ }
+ }
+ // If this is a new vtable, remember it.
+ if (vt_use == GetSize()) { buf_.scratch_push_small(vt_use); }
+ // Fill the vtable offset we created above.
+ // The offset points from the beginning of the object to where the
+ // vtable is stored.
+ // Offsets default direction is downward in memory for future format
+ // flexibility (storing all vtables at the start of the file).
+ WriteScalar(buf_.data_at(vtableoffsetloc),
+ static_cast<soffset_t>(vt_use) -
+ static_cast<soffset_t>(vtableoffsetloc));
+
+ nested = false;
+ return vtableoffsetloc;
+ }
+
+ FLATBUFFERS_ATTRIBUTE(deprecated("call the version above instead"))
+ uoffset_t EndTable(uoffset_t start, voffset_t /*numfields*/) {
+ return EndTable(start);
+ }
+
+ // This checks a required field has been set in a given table that has
+ // just been constructed.
+ template<typename T> void Required(Offset<T> table, voffset_t field);
+
+ uoffset_t StartStruct(size_t alignment) {
+ Align(alignment);
+ return GetSize();
+ }
+
+ uoffset_t EndStruct() { return GetSize(); }
+
+ void ClearOffsets() {
+ buf_.scratch_pop(num_field_loc * sizeof(FieldLoc));
+ num_field_loc = 0;
+ max_voffset_ = 0;
+ }
+
+ // Aligns such that when "len" bytes are written, an object can be written
+ // after it with "alignment" without padding.
+ void PreAlign(size_t len, size_t alignment) {
+ TrackMinAlign(alignment);
+ buf_.fill(PaddingBytes(GetSize() + len, alignment));
+ }
+ template<typename T> void PreAlign(size_t len) {
+ AssertScalarT<T>();
+ PreAlign(len, sizeof(T));
+ }
+ /// @endcond
+
+ /// @brief Store a string in the buffer, which can contain any binary data.
+ /// @param[in] str A const char pointer to the data to be stored as a string.
+ /// @param[in] len The number of bytes that should be stored from `str`.
+ /// @return Returns the offset in the buffer where the string starts.
+ Offset<String> CreateString(const char *str, size_t len) {
+ NotNested();
+ PreAlign<uoffset_t>(len + 1); // Always 0-terminated.
+ buf_.fill(1);
+ PushBytes(reinterpret_cast<const uint8_t *>(str), len);
+ PushElement(static_cast<uoffset_t>(len));
+ return Offset<String>(GetSize());
+ }
+
+ /// @brief Store a string in the buffer, which is null-terminated.
+ /// @param[in] str A const char pointer to a C-string to add to the buffer.
+ /// @return Returns the offset in the buffer where the string starts.
+ Offset<String> CreateString(const char *str) {
+ return CreateString(str, strlen(str));
+ }
+
+ /// @brief Store a string in the buffer, which is null-terminated.
+ /// @param[in] str A char pointer to a C-string to add to the buffer.
+ /// @return Returns the offset in the buffer where the string starts.
+ Offset<String> CreateString(char *str) {
+ return CreateString(str, strlen(str));
+ }
+
+ /// @brief Store a string in the buffer, which can contain any binary data.
+ /// @param[in] str A const reference to a std::string to store in the buffer.
+ /// @return Returns the offset in the buffer where the string starts.
+ Offset<String> CreateString(const std::string &str) {
+ return CreateString(str.c_str(), str.length());
+ }
+
+ // clang-format off
+ #ifdef FLATBUFFERS_HAS_STRING_VIEW
+ /// @brief Store a string in the buffer, which can contain any binary data.
+ /// @param[in] str A const string_view to copy in to the buffer.
+ /// @return Returns the offset in the buffer where the string starts.
+ Offset<String> CreateString(flatbuffers::string_view str) {
+ return CreateString(str.data(), str.size());
+ }
+ #endif // FLATBUFFERS_HAS_STRING_VIEW
+ // clang-format on
+
+ /// @brief Store a string in the buffer, which can contain any binary data.
+ /// @param[in] str A const pointer to a `String` struct to add to the buffer.
+ /// @return Returns the offset in the buffer where the string starts
+ Offset<String> CreateString(const String *str) {
+ return str ? CreateString(str->c_str(), str->size()) : 0;
+ }
+
+ /// @brief Store a string in the buffer, which can contain any binary data.
+ /// @param[in] str A const reference to a std::string like type with support
+ /// of T::c_str() and T::length() to store in the buffer.
+ /// @return Returns the offset in the buffer where the string starts.
+ template<typename T> Offset<String> CreateString(const T &str) {
+ return CreateString(str.c_str(), str.length());
+ }
+
+ /// @brief Store a string in the buffer, which can contain any binary data.
+ /// If a string with this exact contents has already been serialized before,
+ /// instead simply returns the offset of the existing string.
+ /// @param[in] str A const char pointer to the data to be stored as a string.
+ /// @param[in] len The number of bytes that should be stored from `str`.
+ /// @return Returns the offset in the buffer where the string starts.
+ Offset<String> CreateSharedString(const char *str, size_t len) {
+ if (!string_pool)
+ string_pool = new StringOffsetMap(StringOffsetCompare(buf_));
+ auto size_before_string = buf_.size();
+ // Must first serialize the string, since the set is all offsets into
+ // buffer.
+ auto off = CreateString(str, len);
+ auto it = string_pool->find(off);
+ // If it exists we reuse existing serialized data!
+ if (it != string_pool->end()) {
+ // We can remove the string we serialized.
+ buf_.pop(buf_.size() - size_before_string);
+ return *it;
+ }
+ // Record this string for future use.
+ string_pool->insert(off);
+ return off;
+ }
+
+ /// @brief Store a string in the buffer, which null-terminated.
+ /// If a string with this exact contents has already been serialized before,
+ /// instead simply returns the offset of the existing string.
+ /// @param[in] str A const char pointer to a C-string to add to the buffer.
+ /// @return Returns the offset in the buffer where the string starts.
+ Offset<String> CreateSharedString(const char *str) {
+ return CreateSharedString(str, strlen(str));
+ }
+
+ /// @brief Store a string in the buffer, which can contain any binary data.
+ /// If a string with this exact contents has already been serialized before,
+ /// instead simply returns the offset of the existing string.
+ /// @param[in] str A const reference to a std::string to store in the buffer.
+ /// @return Returns the offset in the buffer where the string starts.
+ Offset<String> CreateSharedString(const std::string &str) {
+ return CreateSharedString(str.c_str(), str.length());
+ }
+
+ /// @brief Store a string in the buffer, which can contain any binary data.
+ /// If a string with this exact contents has already been serialized before,
+ /// instead simply returns the offset of the existing string.
+ /// @param[in] str A const pointer to a `String` struct to add to the buffer.
+ /// @return Returns the offset in the buffer where the string starts
+ Offset<String> CreateSharedString(const String *str) {
+ return CreateSharedString(str->c_str(), str->size());
+ }
+
+ /// @cond FLATBUFFERS_INTERNAL
+ uoffset_t EndVector(size_t len) {
+ FLATBUFFERS_ASSERT(nested); // Hit if no corresponding StartVector.
+ nested = false;
+ return PushElement(static_cast<uoffset_t>(len));
+ }
+
+ void StartVector(size_t len, size_t elemsize) {
+ NotNested();
+ nested = true;
+ PreAlign<uoffset_t>(len * elemsize);
+ PreAlign(len * elemsize, elemsize); // Just in case elemsize > uoffset_t.
+ }
+
+ // Call this right before StartVector/CreateVector if you want to force the
+ // alignment to be something different than what the element size would
+ // normally dictate.
+ // This is useful when storing a nested_flatbuffer in a vector of bytes,
+ // or when storing SIMD floats, etc.
+ void ForceVectorAlignment(size_t len, size_t elemsize, size_t alignment) {
+ PreAlign(len * elemsize, alignment);
+ }
+
+ // Similar to ForceVectorAlignment but for String fields.
+ void ForceStringAlignment(size_t len, size_t alignment) {
+ PreAlign((len + 1) * sizeof(char), alignment);
+ }
+
+ /// @endcond
+
+ /// @brief Serialize an array into a FlatBuffer `vector`.
+ /// @tparam T The data type of the array elements.
+ /// @param[in] v A pointer to the array of type `T` to serialize into the
+ /// buffer as a `vector`.
+ /// @param[in] len The number of elements to serialize.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ template<typename T> Offset<Vector<T>> CreateVector(const T *v, size_t len) {
+ // If this assert hits, you're specifying a template argument that is
+ // causing the wrong overload to be selected, remove it.
+ AssertScalarT<T>();
+ StartVector(len, sizeof(T));
+ // clang-format off
+ #if FLATBUFFERS_LITTLEENDIAN
+ PushBytes(reinterpret_cast<const uint8_t *>(v), len * sizeof(T));
+ #else
+ if (sizeof(T) == 1) {
+ PushBytes(reinterpret_cast<const uint8_t *>(v), len);
+ } else {
+ for (auto i = len; i > 0; ) {
+ PushElement(v[--i]);
+ }
+ }
+ #endif
+ // clang-format on
+ return Offset<Vector<T>>(EndVector(len));
+ }
+
+ template<typename T>
+ Offset<Vector<Offset<T>>> CreateVector(const Offset<T> *v, size_t len) {
+ StartVector(len, sizeof(Offset<T>));
+ for (auto i = len; i > 0;) { PushElement(v[--i]); }
+ return Offset<Vector<Offset<T>>>(EndVector(len));
+ }
+
+ /// @brief Serialize a `std::vector` into a FlatBuffer `vector`.
+ /// @tparam T The data type of the `std::vector` elements.
+ /// @param v A const reference to the `std::vector` to serialize into the
+ /// buffer as a `vector`.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ template<typename T> Offset<Vector<T>> CreateVector(const std::vector<T> &v) {
+ return CreateVector(data(v), v.size());
+ }
+
+ // vector<bool> may be implemented using a bit-set, so we can't access it as
+ // an array. Instead, read elements manually.
+ // Background: https://isocpp.org/blog/2012/11/on-vectorbool
+ Offset<Vector<uint8_t>> CreateVector(const std::vector<bool> &v) {
+ StartVector(v.size(), sizeof(uint8_t));
+ for (auto i = v.size(); i > 0;) {
+ PushElement(static_cast<uint8_t>(v[--i]));
+ }
+ return Offset<Vector<uint8_t>>(EndVector(v.size()));
+ }
+
+ // clang-format off
+ #ifndef FLATBUFFERS_CPP98_STL
+ /// @brief Serialize values returned by a function into a FlatBuffer `vector`.
+ /// This is a convenience function that takes care of iteration for you.
+ /// @tparam T The data type of the `std::vector` elements.
+ /// @param f A function that takes the current iteration 0..vector_size-1 and
+ /// returns any type that you can construct a FlatBuffers vector out of.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ template<typename T> Offset<Vector<T>> CreateVector(size_t vector_size,
+ const std::function<T (size_t i)> &f) {
+ std::vector<T> elems(vector_size);
+ for (size_t i = 0; i < vector_size; i++) elems[i] = f(i);
+ return CreateVector(elems);
+ }
+ #endif
+ // clang-format on
+
+ /// @brief Serialize values returned by a function into a FlatBuffer `vector`.
+ /// This is a convenience function that takes care of iteration for you.
+ /// @tparam T The data type of the `std::vector` elements.
+ /// @param f A function that takes the current iteration 0..vector_size-1,
+ /// and the state parameter returning any type that you can construct a
+ /// FlatBuffers vector out of.
+ /// @param state State passed to f.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ template<typename T, typename F, typename S>
+ Offset<Vector<T>> CreateVector(size_t vector_size, F f, S *state) {
+ std::vector<T> elems(vector_size);
+ for (size_t i = 0; i < vector_size; i++) elems[i] = f(i, state);
+ return CreateVector(elems);
+ }
+
+ /// @brief Serialize a `std::vector<std::string>` into a FlatBuffer `vector`.
+ /// This is a convenience function for a common case.
+ /// @param v A const reference to the `std::vector` to serialize into the
+ /// buffer as a `vector`.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ Offset<Vector<Offset<String>>> CreateVectorOfStrings(
+ const std::vector<std::string> &v) {
+ std::vector<Offset<String>> offsets(v.size());
+ for (size_t i = 0; i < v.size(); i++) offsets[i] = CreateString(v[i]);
+ return CreateVector(offsets);
+ }
+
+ /// @brief Serialize an array of structs into a FlatBuffer `vector`.
+ /// @tparam T The data type of the struct array elements.
+ /// @param[in] v A pointer to the array of type `T` to serialize into the
+ /// buffer as a `vector`.
+ /// @param[in] len The number of elements to serialize.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ template<typename T>
+ Offset<Vector<const T *>> CreateVectorOfStructs(const T *v, size_t len) {
+ StartVector(len * sizeof(T) / AlignOf<T>(), AlignOf<T>());
+ PushBytes(reinterpret_cast<const uint8_t *>(v), sizeof(T) * len);
+ return Offset<Vector<const T *>>(EndVector(len));
+ }
+
+ /// @brief Serialize an array of native structs into a FlatBuffer `vector`.
+ /// @tparam T The data type of the struct array elements.
+ /// @tparam S The data type of the native struct array elements.
+ /// @param[in] v A pointer to the array of type `S` to serialize into the
+ /// buffer as a `vector`.
+ /// @param[in] len The number of elements to serialize.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ template<typename T, typename S>
+ Offset<Vector<const T *>> CreateVectorOfNativeStructs(const S *v,
+ size_t len) {
+ extern T Pack(const S &);
+ std::vector<T> vv(len);
+ std::transform(v, v + len, vv.begin(), Pack);
+ return CreateVectorOfStructs<T>(data(vv), vv.size());
+ }
+
+ // clang-format off
+ #ifndef FLATBUFFERS_CPP98_STL
+ /// @brief Serialize an array of structs into a FlatBuffer `vector`.
+ /// @tparam T The data type of the struct array elements.
+ /// @param[in] filler A function that takes the current iteration 0..vector_size-1
+ /// and a pointer to the struct that must be filled.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ /// This is mostly useful when flatbuffers are generated with mutation
+ /// accessors.
+ template<typename T> Offset<Vector<const T *>> CreateVectorOfStructs(
+ size_t vector_size, const std::function<void(size_t i, T *)> &filler) {
+ T* structs = StartVectorOfStructs<T>(vector_size);
+ for (size_t i = 0; i < vector_size; i++) {
+ filler(i, structs);
+ structs++;
+ }
+ return EndVectorOfStructs<T>(vector_size);
+ }
+ #endif
+ // clang-format on
+
+ /// @brief Serialize an array of structs into a FlatBuffer `vector`.
+ /// @tparam T The data type of the struct array elements.
+ /// @param[in] f A function that takes the current iteration 0..vector_size-1,
+ /// a pointer to the struct that must be filled and the state argument.
+ /// @param[in] state Arbitrary state to pass to f.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ /// This is mostly useful when flatbuffers are generated with mutation
+ /// accessors.
+ template<typename T, typename F, typename S>
+ Offset<Vector<const T *>> CreateVectorOfStructs(size_t vector_size, F f,
+ S *state) {
+ T *structs = StartVectorOfStructs<T>(vector_size);
+ for (size_t i = 0; i < vector_size; i++) {
+ f(i, structs, state);
+ structs++;
+ }
+ return EndVectorOfStructs<T>(vector_size);
+ }
+
+ /// @brief Serialize a `std::vector` of structs into a FlatBuffer `vector`.
+ /// @tparam T The data type of the `std::vector` struct elements.
+ /// @param[in] v A const reference to the `std::vector` of structs to
+ /// serialize into the buffer as a `vector`.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ template<typename T, typename Alloc>
+ Offset<Vector<const T *>> CreateVectorOfStructs(
+ const std::vector<T, Alloc> &v) {
+ return CreateVectorOfStructs(data(v), v.size());
+ }
+
+ /// @brief Serialize a `std::vector` of native structs into a FlatBuffer
+ /// `vector`.
+ /// @tparam T The data type of the `std::vector` struct elements.
+ /// @tparam S The data type of the `std::vector` native struct elements.
+ /// @param[in] v A const reference to the `std::vector` of structs to
+ /// serialize into the buffer as a `vector`.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ template<typename T, typename S>
+ Offset<Vector<const T *>> CreateVectorOfNativeStructs(
+ const std::vector<S> &v) {
+ return CreateVectorOfNativeStructs<T, S>(data(v), v.size());
+ }
+
+ /// @cond FLATBUFFERS_INTERNAL
+ template<typename T> struct StructKeyComparator {
+ bool operator()(const T &a, const T &b) const {
+ return a.KeyCompareLessThan(&b);
+ }
+
+ private:
+ StructKeyComparator &operator=(const StructKeyComparator &);
+ };
+ /// @endcond
+
+ /// @brief Serialize a `std::vector` of structs into a FlatBuffer `vector`
+ /// in sorted order.
+ /// @tparam T The data type of the `std::vector` struct elements.
+ /// @param[in] v A const reference to the `std::vector` of structs to
+ /// serialize into the buffer as a `vector`.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ template<typename T>
+ Offset<Vector<const T *>> CreateVectorOfSortedStructs(std::vector<T> *v) {
+ return CreateVectorOfSortedStructs(data(*v), v->size());
+ }
+
+ /// @brief Serialize a `std::vector` of native structs into a FlatBuffer
+ /// `vector` in sorted order.
+ /// @tparam T The data type of the `std::vector` struct elements.
+ /// @tparam S The data type of the `std::vector` native struct elements.
+ /// @param[in] v A const reference to the `std::vector` of structs to
+ /// serialize into the buffer as a `vector`.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ template<typename T, typename S>
+ Offset<Vector<const T *>> CreateVectorOfSortedNativeStructs(
+ std::vector<S> *v) {
+ return CreateVectorOfSortedNativeStructs<T, S>(data(*v), v->size());
+ }
+
+ /// @brief Serialize an array of structs into a FlatBuffer `vector` in sorted
+ /// order.
+ /// @tparam T The data type of the struct array elements.
+ /// @param[in] v A pointer to the array of type `T` to serialize into the
+ /// buffer as a `vector`.
+ /// @param[in] len The number of elements to serialize.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ template<typename T>
+ Offset<Vector<const T *>> CreateVectorOfSortedStructs(T *v, size_t len) {
+ std::sort(v, v + len, StructKeyComparator<T>());
+ return CreateVectorOfStructs(v, len);
+ }
+
+ /// @brief Serialize an array of native structs into a FlatBuffer `vector` in
+ /// sorted order.
+ /// @tparam T The data type of the struct array elements.
+ /// @tparam S The data type of the native struct array elements.
+ /// @param[in] v A pointer to the array of type `S` to serialize into the
+ /// buffer as a `vector`.
+ /// @param[in] len The number of elements to serialize.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ template<typename T, typename S>
+ Offset<Vector<const T *>> CreateVectorOfSortedNativeStructs(S *v,
+ size_t len) {
+ extern T Pack(const S &);
+ typedef T (*Pack_t)(const S &);
+ std::vector<T> vv(len);
+ std::transform(v, v + len, vv.begin(), static_cast<Pack_t &>(Pack));
+ return CreateVectorOfSortedStructs<T>(vv, len);
+ }
+
+ /// @cond FLATBUFFERS_INTERNAL
+ template<typename T> struct TableKeyComparator {
+ TableKeyComparator(vector_downward &buf) : buf_(buf) {}
+ TableKeyComparator(const TableKeyComparator &other) : buf_(other.buf_) {}
+ bool operator()(const Offset<T> &a, const Offset<T> &b) const {
+ auto table_a = reinterpret_cast<T *>(buf_.data_at(a.o));
+ auto table_b = reinterpret_cast<T *>(buf_.data_at(b.o));
+ return table_a->KeyCompareLessThan(table_b);
+ }
+ vector_downward &buf_;
+
+ private:
+ TableKeyComparator &operator=(const TableKeyComparator &other) {
+ buf_ = other.buf_;
+ return *this;
+ }
+ };
+ /// @endcond
+
+ /// @brief Serialize an array of `table` offsets as a `vector` in the buffer
+ /// in sorted order.
+ /// @tparam T The data type that the offset refers to.
+ /// @param[in] v An array of type `Offset<T>` that contains the `table`
+ /// offsets to store in the buffer in sorted order.
+ /// @param[in] len The number of elements to store in the `vector`.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ template<typename T>
+ Offset<Vector<Offset<T>>> CreateVectorOfSortedTables(Offset<T> *v,
+ size_t len) {
+ std::sort(v, v + len, TableKeyComparator<T>(buf_));
+ return CreateVector(v, len);
+ }
+
+ /// @brief Serialize an array of `table` offsets as a `vector` in the buffer
+ /// in sorted order.
+ /// @tparam T The data type that the offset refers to.
+ /// @param[in] v An array of type `Offset<T>` that contains the `table`
+ /// offsets to store in the buffer in sorted order.
+ /// @return Returns a typed `Offset` into the serialized data indicating
+ /// where the vector is stored.
+ template<typename T>
+ Offset<Vector<Offset<T>>> CreateVectorOfSortedTables(
+ std::vector<Offset<T>> *v) {
+ return CreateVectorOfSortedTables(data(*v), v->size());
+ }
+
+ /// @brief Specialized version of `CreateVector` for non-copying use cases.
+ /// Write the data any time later to the returned buffer pointer `buf`.
+ /// @param[in] len The number of elements to store in the `vector`.
+ /// @param[in] elemsize The size of each element in the `vector`.
+ /// @param[out] buf A pointer to a `uint8_t` pointer that can be
+ /// written to at a later time to serialize the data into a `vector`
+ /// in the buffer.
+ uoffset_t CreateUninitializedVector(size_t len, size_t elemsize,
+ uint8_t **buf) {
+ NotNested();
+ StartVector(len, elemsize);
+ buf_.make_space(len * elemsize);
+ auto vec_start = GetSize();
+ auto vec_end = EndVector(len);
+ *buf = buf_.data_at(vec_start);
+ return vec_end;
+ }
+
+ /// @brief Specialized version of `CreateVector` for non-copying use cases.
+ /// Write the data any time later to the returned buffer pointer `buf`.
+ /// @tparam T The data type of the data that will be stored in the buffer
+ /// as a `vector`.
+ /// @param[in] len The number of elements to store in the `vector`.
+ /// @param[out] buf A pointer to a pointer of type `T` that can be
+ /// written to at a later time to serialize the data into a `vector`
+ /// in the buffer.
+ template<typename T>
+ Offset<Vector<T>> CreateUninitializedVector(size_t len, T **buf) {
+ AssertScalarT<T>();
+ return CreateUninitializedVector(len, sizeof(T),
+ reinterpret_cast<uint8_t **>(buf));
+ }
+
+ template<typename T>
+ Offset<Vector<const T *>> CreateUninitializedVectorOfStructs(size_t len,
+ T **buf) {
+ return CreateUninitializedVector(len, sizeof(T),
+ reinterpret_cast<uint8_t **>(buf));
+ }
+
+ // @brief Create a vector of scalar type T given as input a vector of scalar
+ // type U, useful with e.g. pre "enum class" enums, or any existing scalar
+ // data of the wrong type.
+ template<typename T, typename U>
+ Offset<Vector<T>> CreateVectorScalarCast(const U *v, size_t len) {
+ AssertScalarT<T>();
+ AssertScalarT<U>();
+ StartVector(len, sizeof(T));
+ for (auto i = len; i > 0;) { PushElement(static_cast<T>(v[--i])); }
+ return Offset<Vector<T>>(EndVector(len));
+ }
+
+ /// @brief Write a struct by itself, typically to be part of a union.
+ template<typename T> Offset<const T *> CreateStruct(const T &structobj) {
+ NotNested();
+ Align(AlignOf<T>());
+ buf_.push_small(structobj);
+ return Offset<const T *>(GetSize());
+ }
+
+ /// @brief The length of a FlatBuffer file header.
+ static const size_t kFileIdentifierLength = 4;
+
+ /// @brief Finish serializing a buffer by writing the root offset.
+ /// @param[in] file_identifier If a `file_identifier` is given, the buffer
+ /// will be prefixed with a standard FlatBuffers file header.
+ template<typename T>
+ void Finish(Offset<T> root, const char *file_identifier = nullptr) {
+ Finish(root.o, file_identifier, false);
+ }
+
+ /// @brief Finish a buffer with a 32 bit size field pre-fixed (size of the
+ /// buffer following the size field). These buffers are NOT compatible
+ /// with standard buffers created by Finish, i.e. you can't call GetRoot
+ /// on them, you have to use GetSizePrefixedRoot instead.
+ /// All >32 bit quantities in this buffer will be aligned when the whole
+ /// size pre-fixed buffer is aligned.
+ /// These kinds of buffers are useful for creating a stream of FlatBuffers.
+ template<typename T>
+ void FinishSizePrefixed(Offset<T> root,
+ const char *file_identifier = nullptr) {
+ Finish(root.o, file_identifier, true);
+ }
+
+ void SwapBufAllocator(FlatBufferBuilder &other) {
+ buf_.swap_allocator(other.buf_);
+ }
+
+ protected:
+ // You shouldn't really be copying instances of this class.
+ FlatBufferBuilder(const FlatBufferBuilder &);
+ FlatBufferBuilder &operator=(const FlatBufferBuilder &);
+
+ void Finish(uoffset_t root, const char *file_identifier, bool size_prefix) {
+ NotNested();
+ buf_.clear_scratch();
+ // This will cause the whole buffer to be aligned.
+ PreAlign((size_prefix ? sizeof(uoffset_t) : 0) + sizeof(uoffset_t) +
+ (file_identifier ? kFileIdentifierLength : 0),
+ minalign_);
+ if (file_identifier) {
+ FLATBUFFERS_ASSERT(strlen(file_identifier) == kFileIdentifierLength);
+ PushBytes(reinterpret_cast<const uint8_t *>(file_identifier),
+ kFileIdentifierLength);
+ }
+ PushElement(ReferTo(root)); // Location of root.
+ if (size_prefix) { PushElement(GetSize()); }
+ finished = true;
+ }
+
+ struct FieldLoc {
+ uoffset_t off;
+ voffset_t id;
+ };
+
+ vector_downward buf_;
+
+ // Accumulating offsets of table members while it is being built.
+ // We store these in the scratch pad of buf_, after the vtable offsets.
+ uoffset_t num_field_loc;
+ // Track how much of the vtable is in use, so we can output the most compact
+ // possible vtable.
+ voffset_t max_voffset_;
+
+ // Ensure objects are not nested.
+ bool nested;
+
+ // Ensure the buffer is finished before it is being accessed.
+ bool finished;
+
+ size_t minalign_;
+
+ bool force_defaults_; // Serialize values equal to their defaults anyway.
+
+ bool dedup_vtables_;
+
+ struct StringOffsetCompare {
+ StringOffsetCompare(const vector_downward &buf) : buf_(&buf) {}
+ bool operator()(const Offset<String> &a, const Offset<String> &b) const {
+ auto stra = reinterpret_cast<const String *>(buf_->data_at(a.o));
+ auto strb = reinterpret_cast<const String *>(buf_->data_at(b.o));
+ return StringLessThan(stra->data(), stra->size(), strb->data(),
+ strb->size());
+ }
+ const vector_downward *buf_;
+ };
+
+ // For use with CreateSharedString. Instantiated on first use only.
+ typedef std::set<Offset<String>, StringOffsetCompare> StringOffsetMap;
+ StringOffsetMap *string_pool;
+
+ private:
+ // Allocates space for a vector of structures.
+ // Must be completed with EndVectorOfStructs().
+ template<typename T> T *StartVectorOfStructs(size_t vector_size) {
+ StartVector(vector_size * sizeof(T) / AlignOf<T>(), AlignOf<T>());
+ return reinterpret_cast<T *>(buf_.make_space(vector_size * sizeof(T)));
+ }
+
+ // End the vector of structues in the flatbuffers.
+ // Vector should have previously be started with StartVectorOfStructs().
+ template<typename T>
+ Offset<Vector<const T *>> EndVectorOfStructs(size_t vector_size) {
+ return Offset<Vector<const T *>>(EndVector(vector_size));
+ }
+};
+/// @}
+
+/// @cond FLATBUFFERS_INTERNAL
+// Helpers to get a typed pointer to the root object contained in the buffer.
+template<typename T> T *GetMutableRoot(void *buf) {
+ EndianCheck();
+ return reinterpret_cast<T *>(
+ reinterpret_cast<uint8_t *>(buf) +
+ EndianScalar(*reinterpret_cast<uoffset_t *>(buf)));
+}
+
+template<typename T> const T *GetRoot(const void *buf) {
+ return GetMutableRoot<T>(const_cast<void *>(buf));
+}
+
+template<typename T> const T *GetSizePrefixedRoot(const void *buf) {
+ return GetRoot<T>(reinterpret_cast<const uint8_t *>(buf) + sizeof(uoffset_t));
+}
+
+/// Helpers to get a typed pointer to objects that are currently being built.
+/// @warning Creating new objects will lead to reallocations and invalidates
+/// the pointer!
+template<typename T>
+T *GetMutableTemporaryPointer(FlatBufferBuilder &fbb, Offset<T> offset) {
+ return reinterpret_cast<T *>(fbb.GetCurrentBufferPointer() + fbb.GetSize() -
+ offset.o);
+}
+
+template<typename T>
+const T *GetTemporaryPointer(FlatBufferBuilder &fbb, Offset<T> offset) {
+ return GetMutableTemporaryPointer<T>(fbb, offset);
+}
+
+/// @brief Get a pointer to the the file_identifier section of the buffer.
+/// @return Returns a const char pointer to the start of the file_identifier
+/// characters in the buffer. The returned char * has length
+/// 'flatbuffers::FlatBufferBuilder::kFileIdentifierLength'.
+/// This function is UNDEFINED for FlatBuffers whose schema does not include
+/// a file_identifier (likely points at padding or the start of a the root
+/// vtable).
+inline const char *GetBufferIdentifier(const void *buf,
+ bool size_prefixed = false) {
+ return reinterpret_cast<const char *>(buf) +
+ ((size_prefixed) ? 2 * sizeof(uoffset_t) : sizeof(uoffset_t));
+}
+
+// Helper to see if the identifier in a buffer has the expected value.
+inline bool BufferHasIdentifier(const void *buf, const char *identifier,
+ bool size_prefixed = false) {
+ return strncmp(GetBufferIdentifier(buf, size_prefixed), identifier,
+ FlatBufferBuilder::kFileIdentifierLength) == 0;
+}
+
+// Helper class to verify the integrity of a FlatBuffer
+class Verifier FLATBUFFERS_FINAL_CLASS {
+ public:
+ Verifier(const uint8_t *buf, size_t buf_len, uoffset_t _max_depth = 64,
+ uoffset_t _max_tables = 1000000, bool _check_alignment = true)
+ : buf_(buf),
+ size_(buf_len),
+ depth_(0),
+ max_depth_(_max_depth),
+ num_tables_(0),
+ max_tables_(_max_tables),
+ upper_bound_(0),
+ check_alignment_(_check_alignment) {
+ FLATBUFFERS_ASSERT(size_ < FLATBUFFERS_MAX_BUFFER_SIZE);
+ }
+
+ // Central location where any verification failures register.
+ bool Check(bool ok) const {
+ // clang-format off
+ #ifdef FLATBUFFERS_DEBUG_VERIFICATION_FAILURE
+ FLATBUFFERS_ASSERT(ok);
+ #endif
+ #ifdef FLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE
+ if (!ok)
+ upper_bound_ = 0;
+ #endif
+ // clang-format on
+ return ok;
+ }
+
+ // Verify any range within the buffer.
+ bool Verify(size_t elem, size_t elem_len) const {
+ // clang-format off
+ #ifdef FLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE
+ auto upper_bound = elem + elem_len;
+ if (upper_bound_ < upper_bound)
+ upper_bound_ = upper_bound;
+ #endif
+ // clang-format on
+ return Check(elem_len < size_ && elem <= size_ - elem_len);
+ }
+
+ template<typename T> bool VerifyAlignment(size_t elem) const {
+ return Check((elem & (sizeof(T) - 1)) == 0 || !check_alignment_);
+ }
+
+ // Verify a range indicated by sizeof(T).
+ template<typename T> bool Verify(size_t elem) const {
+ return VerifyAlignment<T>(elem) && Verify(elem, sizeof(T));
+ }
+
+ bool VerifyFromPointer(const uint8_t *p, size_t len) {
+ auto o = static_cast<size_t>(p - buf_);
+ return Verify(o, len);
+ }
+
+ // Verify relative to a known-good base pointer.
+ bool Verify(const uint8_t *base, voffset_t elem_off, size_t elem_len) const {
+ return Verify(static_cast<size_t>(base - buf_) + elem_off, elem_len);
+ }
+
+ template<typename T>
+ bool Verify(const uint8_t *base, voffset_t elem_off) const {
+ return Verify(static_cast<size_t>(base - buf_) + elem_off, sizeof(T));
+ }
+
+ // Verify a pointer (may be NULL) of a table type.
+ template<typename T> bool VerifyTable(const T *table) {
+ return !table || table->Verify(*this);
+ }
+
+ // Verify a pointer (may be NULL) of any vector type.
+ template<typename T> bool VerifyVector(const Vector<T> *vec) const {
+ return !vec || VerifyVectorOrString(reinterpret_cast<const uint8_t *>(vec),
+ sizeof(T));
+ }
+
+ // Verify a pointer (may be NULL) of a vector to struct.
+ template<typename T> bool VerifyVector(const Vector<const T *> *vec) const {
+ return VerifyVector(reinterpret_cast<const Vector<T> *>(vec));
+ }
+
+ // Verify a pointer (may be NULL) to string.
+ bool VerifyString(const String *str) const {
+ size_t end;
+ return !str || (VerifyVectorOrString(reinterpret_cast<const uint8_t *>(str),
+ 1, &end) &&
+ Verify(end, 1) && // Must have terminator
+ Check(buf_[end] == '\0')); // Terminating byte must be 0.
+ }
+
+ // Common code between vectors and strings.
+ bool VerifyVectorOrString(const uint8_t *vec, size_t elem_size,
+ size_t *end = nullptr) const {
+ auto veco = static_cast<size_t>(vec - buf_);
+ // Check we can read the size field.
+ if (!Verify<uoffset_t>(veco)) return false;
+ // Check the whole array. If this is a string, the byte past the array
+ // must be 0.
+ auto size = ReadScalar<uoffset_t>(vec);
+ auto max_elems = FLATBUFFERS_MAX_BUFFER_SIZE / elem_size;
+ if (!Check(size < max_elems))
+ return false; // Protect against byte_size overflowing.
+ auto byte_size = sizeof(size) + elem_size * size;
+ if (end) *end = veco + byte_size;
+ return Verify(veco, byte_size);
+ }
+
+ // Special case for string contents, after the above has been called.
+ bool VerifyVectorOfStrings(const Vector<Offset<String>> *vec) const {
+ if (vec) {
+ for (uoffset_t i = 0; i < vec->size(); i++) {
+ if (!VerifyString(vec->Get(i))) return false;
+ }
+ }
+ return true;
+ }
+
+ // Special case for table contents, after the above has been called.
+ template<typename T> bool VerifyVectorOfTables(const Vector<Offset<T>> *vec) {
+ if (vec) {
+ for (uoffset_t i = 0; i < vec->size(); i++) {
+ if (!vec->Get(i)->Verify(*this)) return false;
+ }
+ }
+ return true;
+ }
+
+ __supress_ubsan__("unsigned-integer-overflow") bool VerifyTableStart(
+ const uint8_t *table) {
+ // Check the vtable offset.
+ auto tableo = static_cast<size_t>(table - buf_);
+ if (!Verify<soffset_t>(tableo)) return false;
+ // This offset may be signed, but doing the subtraction unsigned always
+ // gives the result we want.
+ auto vtableo = tableo - static_cast<size_t>(ReadScalar<soffset_t>(table));
+ // Check the vtable size field, then check vtable fits in its entirety.
+ return VerifyComplexity() && Verify<voffset_t>(vtableo) &&
+ VerifyAlignment<voffset_t>(ReadScalar<voffset_t>(buf_ + vtableo)) &&
+ Verify(vtableo, ReadScalar<voffset_t>(buf_ + vtableo));
+ }
+
+ template<typename T>
+ bool VerifyBufferFromStart(const char *identifier, size_t start) {
+ if (identifier && (size_ < 2 * sizeof(flatbuffers::uoffset_t) ||
+ !BufferHasIdentifier(buf_ + start, identifier))) {
+ return false;
+ }
+
+ // Call T::Verify, which must be in the generated code for this type.
+ auto o = VerifyOffset(start);
+ return o && reinterpret_cast<const T *>(buf_ + start + o)->Verify(*this)
+ // clang-format off
+ #ifdef FLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE
+ && GetComputedSize()
+ #endif
+ ;
+ // clang-format on
+ }
+
+ // Verify this whole buffer, starting with root type T.
+ template<typename T> bool VerifyBuffer() { return VerifyBuffer<T>(nullptr); }
+
+ template<typename T> bool VerifyBuffer(const char *identifier) {
+ return VerifyBufferFromStart<T>(identifier, 0);
+ }
+
+ template<typename T> bool VerifySizePrefixedBuffer(const char *identifier) {
+ return Verify<uoffset_t>(0U) &&
+ ReadScalar<uoffset_t>(buf_) == size_ - sizeof(uoffset_t) &&
+ VerifyBufferFromStart<T>(identifier, sizeof(uoffset_t));
+ }
+
+ uoffset_t VerifyOffset(size_t start) const {
+ if (!Verify<uoffset_t>(start)) return 0;
+ auto o = ReadScalar<uoffset_t>(buf_ + start);
+ // May not point to itself.
+ if (!Check(o != 0)) return 0;
+ // Can't wrap around / buffers are max 2GB.
+ if (!Check(static_cast<soffset_t>(o) >= 0)) return 0;
+ // Must be inside the buffer to create a pointer from it (pointer outside
+ // buffer is UB).
+ if (!Verify(start + o, 1)) return 0;
+ return o;
+ }
+
+ uoffset_t VerifyOffset(const uint8_t *base, voffset_t start) const {
+ return VerifyOffset(static_cast<size_t>(base - buf_) + start);
+ }
+
+ // Called at the start of a table to increase counters measuring data
+ // structure depth and amount, and possibly bails out with false if
+ // limits set by the constructor have been hit. Needs to be balanced
+ // with EndTable().
+ bool VerifyComplexity() {
+ depth_++;
+ num_tables_++;
+ return Check(depth_ <= max_depth_ && num_tables_ <= max_tables_);
+ }
+
+ // Called at the end of a table to pop the depth count.
+ bool EndTable() {
+ depth_--;
+ return true;
+ }
+
+ // Returns the message size in bytes
+ size_t GetComputedSize() const {
+ // clang-format off
+ #ifdef FLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE
+ uintptr_t size = upper_bound_;
+ // Align the size to uoffset_t
+ size = (size - 1 + sizeof(uoffset_t)) & ~(sizeof(uoffset_t) - 1);
+ return (size > size_) ? 0 : size;
+ #else
+ // Must turn on FLATBUFFERS_TRACK_VERIFIER_BUFFER_SIZE for this to work.
+ (void)upper_bound_;
+ FLATBUFFERS_ASSERT(false);
+ return 0;
+ #endif
+ // clang-format on
+ }
+
+ private:
+ const uint8_t *buf_;
+ size_t size_;
+ uoffset_t depth_;
+ uoffset_t max_depth_;
+ uoffset_t num_tables_;
+ uoffset_t max_tables_;
+ mutable size_t upper_bound_;
+ bool check_alignment_;
+};
+
+// Convenient way to bundle a buffer and its length, to pass it around
+// typed by its root.
+// A BufferRef does not own its buffer.
+struct BufferRefBase {}; // for std::is_base_of
+template<typename T> struct BufferRef : BufferRefBase {
+ BufferRef() : buf(nullptr), len(0), must_free(false) {}
+ BufferRef(uint8_t *_buf, uoffset_t _len)
+ : buf(_buf), len(_len), must_free(false) {}
+
+ ~BufferRef() {
+ if (must_free) free(buf);
+ }
+
+ const T *GetRoot() const { return flatbuffers::GetRoot<T>(buf); }
+
+ bool Verify() {
+ Verifier verifier(buf, len);
+ return verifier.VerifyBuffer<T>(nullptr);
+ }
+
+ uint8_t *buf;
+ uoffset_t len;
+ bool must_free;
+};
+
+// "structs" are flat structures that do not have an offset table, thus
+// always have all members present and do not support forwards/backwards
+// compatible extensions.
+
+class Struct FLATBUFFERS_FINAL_CLASS {
+ public:
+ template<typename T> T GetField(uoffset_t o) const {
+ return ReadScalar<T>(&data_[o]);
+ }
+
+ template<typename T> T GetStruct(uoffset_t o) const {
+ return reinterpret_cast<T>(&data_[o]);
+ }
+
+ const uint8_t *GetAddressOf(uoffset_t o) const { return &data_[o]; }
+ uint8_t *GetAddressOf(uoffset_t o) { return &data_[o]; }
+
+ private:
+ // private constructor & copy constructor: you obtain instances of this
+ // class by pointing to existing data only
+ Struct();
+ Struct(const Struct &);
+ Struct &operator=(const Struct &);
+
+ uint8_t data_[1];
+};
+
+// "tables" use an offset table (possibly shared) that allows fields to be
+// omitted and added at will, but uses an extra indirection to read.
+class Table {
+ public:
+ const uint8_t *GetVTable() const {
+ return data_ - ReadScalar<soffset_t>(data_);
+ }
+
+ // This gets the field offset for any of the functions below it, or 0
+ // if the field was not present.
+ voffset_t GetOptionalFieldOffset(voffset_t field) const {
+ // The vtable offset is always at the start.
+ auto vtable = GetVTable();
+ // The first element is the size of the vtable (fields + type id + itself).
+ auto vtsize = ReadScalar<voffset_t>(vtable);
+ // If the field we're accessing is outside the vtable, we're reading older
+ // data, so it's the same as if the offset was 0 (not present).
+ return field < vtsize ? ReadScalar<voffset_t>(vtable + field) : 0;
+ }
+
+ template<typename T> T GetField(voffset_t field, T defaultval) const {
+ auto field_offset = GetOptionalFieldOffset(field);
+ return field_offset ? ReadScalar<T>(data_ + field_offset) : defaultval;
+ }
+
+ template<typename P> P GetPointer(voffset_t field) {
+ auto field_offset = GetOptionalFieldOffset(field);
+ auto p = data_ + field_offset;
+ return field_offset ? reinterpret_cast<P>(p + ReadScalar<uoffset_t>(p))
+ : nullptr;
+ }
+ template<typename P> P GetPointer(voffset_t field) const {
+ return const_cast<Table *>(this)->GetPointer<P>(field);
+ }
+
+ template<typename P> P GetStruct(voffset_t field) const {
+ auto field_offset = GetOptionalFieldOffset(field);
+ auto p = const_cast<uint8_t *>(data_ + field_offset);
+ return field_offset ? reinterpret_cast<P>(p) : nullptr;
+ }
+
+ template<typename T> bool SetField(voffset_t field, T val, T def) {
+ auto field_offset = GetOptionalFieldOffset(field);
+ if (!field_offset) return IsTheSameAs(val, def);
+ WriteScalar(data_ + field_offset, val);
+ return true;
+ }
+
+ bool SetPointer(voffset_t field, const uint8_t *val) {
+ auto field_offset = GetOptionalFieldOffset(field);
+ if (!field_offset) return false;
+ WriteScalar(data_ + field_offset,
+ static_cast<uoffset_t>(val - (data_ + field_offset)));
+ return true;
+ }
+
+ uint8_t *GetAddressOf(voffset_t field) {
+ auto field_offset = GetOptionalFieldOffset(field);
+ return field_offset ? data_ + field_offset : nullptr;
+ }
+ const uint8_t *GetAddressOf(voffset_t field) const {
+ return const_cast<Table *>(this)->GetAddressOf(field);
+ }
+
+ bool CheckField(voffset_t field) const {
+ return GetOptionalFieldOffset(field) != 0;
+ }
+
+ // Verify the vtable of this table.
+ // Call this once per table, followed by VerifyField once per field.
+ bool VerifyTableStart(Verifier &verifier) const {
+ return verifier.VerifyTableStart(data_);
+ }
+
+ // Verify a particular field.
+ template<typename T>
+ bool VerifyField(const Verifier &verifier, voffset_t field) const {
+ // Calling GetOptionalFieldOffset should be safe now thanks to
+ // VerifyTable().
+ auto field_offset = GetOptionalFieldOffset(field);
+ // Check the actual field.
+ return !field_offset || verifier.Verify<T>(data_, field_offset);
+ }
+
+ // VerifyField for required fields.
+ template<typename T>
+ bool VerifyFieldRequired(const Verifier &verifier, voffset_t field) const {
+ auto field_offset = GetOptionalFieldOffset(field);
+ return verifier.Check(field_offset != 0) &&
+ verifier.Verify<T>(data_, field_offset);
+ }
+
+ // Versions for offsets.
+ bool VerifyOffset(const Verifier &verifier, voffset_t field) const {
+ auto field_offset = GetOptionalFieldOffset(field);
+ return !field_offset || verifier.VerifyOffset(data_, field_offset);
+ }
+
+ bool VerifyOffsetRequired(const Verifier &verifier, voffset_t field) const {
+ auto field_offset = GetOptionalFieldOffset(field);
+ return verifier.Check(field_offset != 0) &&
+ verifier.VerifyOffset(data_, field_offset);
+ }
+
+ private:
+ // private constructor & copy constructor: you obtain instances of this
+ // class by pointing to existing data only
+ Table();
+ Table(const Table &other);
+ Table &operator=(const Table &);
+
+ uint8_t data_[1];
+};
+
+template<typename T>
+void FlatBufferBuilder::Required(Offset<T> table, voffset_t field) {
+ auto table_ptr = reinterpret_cast<const Table *>(buf_.data_at(table.o));
+ bool ok = table_ptr->GetOptionalFieldOffset(field) != 0;
+ // If this fails, the caller will show what field needs to be set.
+ FLATBUFFERS_ASSERT(ok);
+ (void)ok;
+}
+
+/// @brief This can compute the start of a FlatBuffer from a root pointer, i.e.
+/// it is the opposite transformation of GetRoot().
+/// This may be useful if you want to pass on a root and have the recipient
+/// delete the buffer afterwards.
+inline const uint8_t *GetBufferStartFromRootPointer(const void *root) {
+ auto table = reinterpret_cast<const Table *>(root);
+ auto vtable = table->GetVTable();
+ // Either the vtable is before the root or after the root.
+ auto start = (std::min)(vtable, reinterpret_cast<const uint8_t *>(root));
+ // Align to at least sizeof(uoffset_t).
+ start = reinterpret_cast<const uint8_t *>(reinterpret_cast<uintptr_t>(start) &
+ ~(sizeof(uoffset_t) - 1));
+ // Additionally, there may be a file_identifier in the buffer, and the root
+ // offset. The buffer may have been aligned to any size between
+ // sizeof(uoffset_t) and FLATBUFFERS_MAX_ALIGNMENT (see "force_align").
+ // Sadly, the exact alignment is only known when constructing the buffer,
+ // since it depends on the presence of values with said alignment properties.
+ // So instead, we simply look at the next uoffset_t values (root,
+ // file_identifier, and alignment padding) to see which points to the root.
+ // None of the other values can "impersonate" the root since they will either
+ // be 0 or four ASCII characters.
+ static_assert(FlatBufferBuilder::kFileIdentifierLength == sizeof(uoffset_t),
+ "file_identifier is assumed to be the same size as uoffset_t");
+ for (auto possible_roots = FLATBUFFERS_MAX_ALIGNMENT / sizeof(uoffset_t) + 1;
+ possible_roots; possible_roots--) {
+ start -= sizeof(uoffset_t);
+ if (ReadScalar<uoffset_t>(start) + start ==
+ reinterpret_cast<const uint8_t *>(root))
+ return start;
+ }
+ // We didn't find the root, either the "root" passed isn't really a root,
+ // or the buffer is corrupt.
+ // Assert, because calling this function with bad data may cause reads
+ // outside of buffer boundaries.
+ FLATBUFFERS_ASSERT(false);
+ return nullptr;
+}
+
+/// @brief This return the prefixed size of a FlatBuffer.
+inline uoffset_t GetPrefixedSize(const uint8_t *buf) {
+ return ReadScalar<uoffset_t>(buf);
+}
+
+// Base class for native objects (FlatBuffer data de-serialized into native
+// C++ data structures).
+// Contains no functionality, purely documentative.
+struct NativeTable {};
+
+/// @brief Function types to be used with resolving hashes into objects and
+/// back again. The resolver gets a pointer to a field inside an object API
+/// object that is of the type specified in the schema using the attribute
+/// `cpp_type` (it is thus important whatever you write to this address
+/// matches that type). The value of this field is initially null, so you
+/// may choose to implement a delayed binding lookup using this function
+/// if you wish. The resolver does the opposite lookup, for when the object
+/// is being serialized again.
+typedef uint64_t hash_value_t;
+// clang-format off
+#ifdef FLATBUFFERS_CPP98_STL
+ typedef void (*resolver_function_t)(void **pointer_adr, hash_value_t hash);
+ typedef hash_value_t (*rehasher_function_t)(void *pointer);
+#else
+ typedef std::function<void (void **pointer_adr, hash_value_t hash)>
+ resolver_function_t;
+ typedef std::function<hash_value_t (void *pointer)> rehasher_function_t;
+#endif
+// clang-format on
+
+// Helper function to test if a field is present, using any of the field
+// enums in the generated code.
+// `table` must be a generated table type. Since this is a template parameter,
+// this is not typechecked to be a subclass of Table, so beware!
+// Note: this function will return false for fields equal to the default
+// value, since they're not stored in the buffer (unless force_defaults was
+// used).
+template<typename T>
+bool IsFieldPresent(const T *table, typename T::FlatBuffersVTableOffset field) {
+ // Cast, since Table is a private baseclass of any table types.
+ return reinterpret_cast<const Table *>(table)->CheckField(
+ static_cast<voffset_t>(field));
+}
+
+// Utility function for reverse lookups on the EnumNames*() functions
+// (in the generated C++ code)
+// names must be NULL terminated.
+inline int LookupEnum(const char **names, const char *name) {
+ for (const char **p = names; *p; p++)
+ if (!strcmp(*p, name)) return static_cast<int>(p - names);
+ return -1;
+}
+
+// These macros allow us to layout a struct with a guarantee that they'll end
+// up looking the same on different compilers and platforms.
+// It does this by disallowing the compiler to do any padding, and then
+// does padding itself by inserting extra padding fields that make every
+// element aligned to its own size.
+// Additionally, it manually sets the alignment of the struct as a whole,
+// which is typically its largest element, or a custom size set in the schema
+// by the force_align attribute.
+// These are used in the generated code only.
+
+// clang-format off
+#if defined(_MSC_VER)
+ #define FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(alignment) \
+ __pragma(pack(1)) \
+ struct __declspec(align(alignment))
+ #define FLATBUFFERS_STRUCT_END(name, size) \
+ __pragma(pack()) \
+ static_assert(sizeof(name) == size, "compiler breaks packing rules")
+#elif defined(__GNUC__) || defined(__clang__) || defined(__ICCARM__)
+ #define FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(alignment) \
+ _Pragma("pack(1)") \
+ struct __attribute__((aligned(alignment)))
+ #define FLATBUFFERS_STRUCT_END(name, size) \
+ _Pragma("pack()") \
+ static_assert(sizeof(name) == size, "compiler breaks packing rules")
+#else
+ #error Unknown compiler, please define structure alignment macros
+#endif
+// clang-format on
+
+// Minimal reflection via code generation.
+// Besides full-fat reflection (see reflection.h) and parsing/printing by
+// loading schemas (see idl.h), we can also have code generation for mimimal
+// reflection data which allows pretty-printing and other uses without needing
+// a schema or a parser.
+// Generate code with --reflect-types (types only) or --reflect-names (names
+// also) to enable.
+// See minireflect.h for utilities using this functionality.
+
+// These types are organized slightly differently as the ones in idl.h.
+enum SequenceType { ST_TABLE, ST_STRUCT, ST_UNION, ST_ENUM };
+
+// Scalars have the same order as in idl.h
+// clang-format off
+#define FLATBUFFERS_GEN_ELEMENTARY_TYPES(ET) \
+ ET(ET_UTYPE) \
+ ET(ET_BOOL) \
+ ET(ET_CHAR) \
+ ET(ET_UCHAR) \
+ ET(ET_SHORT) \
+ ET(ET_USHORT) \
+ ET(ET_INT) \
+ ET(ET_UINT) \
+ ET(ET_LONG) \
+ ET(ET_ULONG) \
+ ET(ET_FLOAT) \
+ ET(ET_DOUBLE) \
+ ET(ET_STRING) \
+ ET(ET_SEQUENCE) // See SequenceType.
+
+enum ElementaryType {
+ #define FLATBUFFERS_ET(E) E,
+ FLATBUFFERS_GEN_ELEMENTARY_TYPES(FLATBUFFERS_ET)
+ #undef FLATBUFFERS_ET
+};
+
+inline const char * const *ElementaryTypeNames() {
+ static const char * const names[] = {
+ #define FLATBUFFERS_ET(E) #E,
+ FLATBUFFERS_GEN_ELEMENTARY_TYPES(FLATBUFFERS_ET)
+ #undef FLATBUFFERS_ET
+ };
+ return names;
+}
+// clang-format on
+
+// Basic type info cost just 16bits per field!
+struct TypeCode {
+ uint16_t base_type : 4; // ElementaryType
+ uint16_t is_vector : 1;
+ int16_t sequence_ref : 11; // Index into type_refs below, or -1 for none.
+};
+
+static_assert(sizeof(TypeCode) == 2, "TypeCode");
+
+struct TypeTable;
+
+// Signature of the static method present in each type.
+typedef const TypeTable *(*TypeFunction)();
+
+struct TypeTable {
+ SequenceType st;
+ size_t num_elems; // of type_codes, values, names (but not type_refs).
+ const TypeCode *type_codes; // num_elems count
+ const TypeFunction *type_refs; // less than num_elems entries (see TypeCode).
+ const int64_t *values; // Only set for non-consecutive enum/union or structs.
+ const char *const *names; // Only set if compiled with --reflect-names.
+};
+
+// String which identifies the current version of FlatBuffers.
+// flatbuffer_version_string is used by Google developers to identify which
+// applications uploaded to Google Play are using this library. This allows
+// the development team at Google to determine the popularity of the library.
+// How it works: Applications that are uploaded to the Google Play Store are
+// scanned for this version string. We track which applications are using it
+// to measure popularity. You are free to remove it (of course) but we would
+// appreciate if you left it in.
+
+// Weak linkage is culled by VS & doesn't work on cygwin.
+// clang-format off
+#if !defined(_WIN32) && !defined(__CYGWIN__)
+
+extern volatile __attribute__((weak)) const char *flatbuffer_version_string;
+volatile __attribute__((weak)) const char *flatbuffer_version_string =
+ "FlatBuffers "
+ FLATBUFFERS_STRING(FLATBUFFERS_VERSION_MAJOR) "."
+ FLATBUFFERS_STRING(FLATBUFFERS_VERSION_MINOR) "."
+ FLATBUFFERS_STRING(FLATBUFFERS_VERSION_REVISION);
+
+#endif // !defined(_WIN32) && !defined(__CYGWIN__)
+
+#define FLATBUFFERS_DEFINE_BITMASK_OPERATORS(E, T)\
+ inline E operator | (E lhs, E rhs){\
+ return E(T(lhs) | T(rhs));\
+ }\
+ inline E operator & (E lhs, E rhs){\
+ return E(T(lhs) & T(rhs));\
+ }\
+ inline E operator ^ (E lhs, E rhs){\
+ return E(T(lhs) ^ T(rhs));\
+ }\
+ inline E operator ~ (E lhs){\
+ return E(~T(lhs));\
+ }\
+ inline E operator |= (E &lhs, E rhs){\
+ lhs = lhs | rhs;\
+ return lhs;\
+ }\
+ inline E operator &= (E &lhs, E rhs){\
+ lhs = lhs & rhs;\
+ return lhs;\
+ }\
+ inline E operator ^= (E &lhs, E rhs){\
+ lhs = lhs ^ rhs;\
+ return lhs;\
+ }\
+ inline bool operator !(E rhs) \
+ {\
+ return !bool(T(rhs)); \
+ }
+/// @endcond
+} // namespace flatbuffers
+
+// clang-format on
+
+#endif // FLATBUFFERS_H_
diff --git a/src/arrow/cpp/thirdparty/flatbuffers/include/flatbuffers/stl_emulation.h b/src/arrow/cpp/thirdparty/flatbuffers/include/flatbuffers/stl_emulation.h
new file mode 100644
index 000000000..8bae61bfd
--- /dev/null
+++ b/src/arrow/cpp/thirdparty/flatbuffers/include/flatbuffers/stl_emulation.h
@@ -0,0 +1,307 @@
+/*
+ * Copyright 2017 Google Inc. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef FLATBUFFERS_STL_EMULATION_H_
+#define FLATBUFFERS_STL_EMULATION_H_
+
+// clang-format off
+
+#include <string>
+#include <type_traits>
+#include <vector>
+#include <memory>
+#include <limits>
+
+#if defined(_STLPORT_VERSION) && !defined(FLATBUFFERS_CPP98_STL)
+ #define FLATBUFFERS_CPP98_STL
+#endif // defined(_STLPORT_VERSION) && !defined(FLATBUFFERS_CPP98_STL)
+
+#if defined(FLATBUFFERS_CPP98_STL)
+ #include <cctype>
+#endif // defined(FLATBUFFERS_CPP98_STL)
+
+// Check if we can use template aliases
+// Not possible if Microsoft Compiler before 2012
+// Possible is the language feature __cpp_alias_templates is defined well
+// Or possible if the C++ std is C+11 or newer
+#if (defined(_MSC_VER) && _MSC_VER > 1700 /* MSVC2012 */) \
+ || (defined(__cpp_alias_templates) && __cpp_alias_templates >= 200704) \
+ || (defined(__cplusplus) && __cplusplus >= 201103L)
+ #define FLATBUFFERS_TEMPLATES_ALIASES
+#endif
+
+// This header provides backwards compatibility for C++98 STLs like stlport.
+namespace flatbuffers {
+
+// Retrieve ::back() from a string in a way that is compatible with pre C++11
+// STLs (e.g stlport).
+inline char& string_back(std::string &value) {
+ return value[value.length() - 1];
+}
+
+inline char string_back(const std::string &value) {
+ return value[value.length() - 1];
+}
+
+// Helper method that retrieves ::data() from a vector in a way that is
+// compatible with pre C++11 STLs (e.g stlport).
+template <typename T> inline T *vector_data(std::vector<T> &vector) {
+ // In some debug environments, operator[] does bounds checking, so &vector[0]
+ // can't be used.
+ return vector.empty() ? nullptr : &vector[0];
+}
+
+template <typename T> inline const T *vector_data(
+ const std::vector<T> &vector) {
+ return vector.empty() ? nullptr : &vector[0];
+}
+
+template <typename T, typename V>
+inline void vector_emplace_back(std::vector<T> *vector, V &&data) {
+ #if defined(FLATBUFFERS_CPP98_STL)
+ vector->push_back(data);
+ #else
+ vector->emplace_back(std::forward<V>(data));
+ #endif // defined(FLATBUFFERS_CPP98_STL)
+}
+
+#ifndef FLATBUFFERS_CPP98_STL
+ #if defined(FLATBUFFERS_TEMPLATES_ALIASES)
+ template <typename T>
+ using numeric_limits = std::numeric_limits<T>;
+ #else
+ template <typename T> class numeric_limits :
+ public std::numeric_limits<T> {};
+ #endif // defined(FLATBUFFERS_TEMPLATES_ALIASES)
+#else
+ template <typename T> class numeric_limits :
+ public std::numeric_limits<T> {
+ public:
+ // Android NDK fix.
+ static T lowest() {
+ return std::numeric_limits<T>::min();
+ }
+ };
+
+ template <> class numeric_limits<float> :
+ public std::numeric_limits<float> {
+ public:
+ static float lowest() { return -FLT_MAX; }
+ };
+
+ template <> class numeric_limits<double> :
+ public std::numeric_limits<double> {
+ public:
+ static double lowest() { return -DBL_MAX; }
+ };
+
+ template <> class numeric_limits<unsigned long long> {
+ public:
+ static unsigned long long min() { return 0ULL; }
+ static unsigned long long max() { return ~0ULL; }
+ static unsigned long long lowest() {
+ return numeric_limits<unsigned long long>::min();
+ }
+ };
+
+ template <> class numeric_limits<long long> {
+ public:
+ static long long min() {
+ return static_cast<long long>(1ULL << ((sizeof(long long) << 3) - 1));
+ }
+ static long long max() {
+ return static_cast<long long>(
+ (1ULL << ((sizeof(long long) << 3) - 1)) - 1);
+ }
+ static long long lowest() {
+ return numeric_limits<long long>::min();
+ }
+ };
+#endif // FLATBUFFERS_CPP98_STL
+
+#if defined(FLATBUFFERS_TEMPLATES_ALIASES)
+ #ifndef FLATBUFFERS_CPP98_STL
+ template <typename T> using is_scalar = std::is_scalar<T>;
+ template <typename T, typename U> using is_same = std::is_same<T,U>;
+ template <typename T> using is_floating_point = std::is_floating_point<T>;
+ template <typename T> using is_unsigned = std::is_unsigned<T>;
+ template <typename T> using is_enum = std::is_enum<T>;
+ template <typename T> using make_unsigned = std::make_unsigned<T>;
+ template<bool B, class T, class F>
+ using conditional = std::conditional<B, T, F>;
+ template<class T, T v>
+ using integral_constant = std::integral_constant<T, v>;
+ #else
+ // Map C++ TR1 templates defined by stlport.
+ template <typename T> using is_scalar = std::tr1::is_scalar<T>;
+ template <typename T, typename U> using is_same = std::tr1::is_same<T,U>;
+ template <typename T> using is_floating_point =
+ std::tr1::is_floating_point<T>;
+ template <typename T> using is_unsigned = std::tr1::is_unsigned<T>;
+ template <typename T> using is_enum = std::tr1::is_enum<T>;
+ // Android NDK doesn't have std::make_unsigned or std::tr1::make_unsigned.
+ template<typename T> struct make_unsigned {
+ static_assert(is_unsigned<T>::value, "Specialization not implemented!");
+ using type = T;
+ };
+ template<> struct make_unsigned<char> { using type = unsigned char; };
+ template<> struct make_unsigned<short> { using type = unsigned short; };
+ template<> struct make_unsigned<int> { using type = unsigned int; };
+ template<> struct make_unsigned<long> { using type = unsigned long; };
+ template<>
+ struct make_unsigned<long long> { using type = unsigned long long; };
+ template<bool B, class T, class F>
+ using conditional = std::tr1::conditional<B, T, F>;
+ template<class T, T v>
+ using integral_constant = std::tr1::integral_constant<T, v>;
+ #endif // !FLATBUFFERS_CPP98_STL
+#else
+ // MSVC 2010 doesn't support C++11 aliases.
+ template <typename T> struct is_scalar : public std::is_scalar<T> {};
+ template <typename T, typename U> struct is_same : public std::is_same<T,U> {};
+ template <typename T> struct is_floating_point :
+ public std::is_floating_point<T> {};
+ template <typename T> struct is_unsigned : public std::is_unsigned<T> {};
+ template <typename T> struct is_enum : public std::is_enum<T> {};
+ template <typename T> struct make_unsigned : public std::make_unsigned<T> {};
+ template<bool B, class T, class F>
+ struct conditional : public std::conditional<B, T, F> {};
+ template<class T, T v>
+ struct integral_constant : public std::integral_constant<T, v> {};
+#endif // defined(FLATBUFFERS_TEMPLATES_ALIASES)
+
+#ifndef FLATBUFFERS_CPP98_STL
+ #if defined(FLATBUFFERS_TEMPLATES_ALIASES)
+ template <class T> using unique_ptr = std::unique_ptr<T>;
+ #else
+ // MSVC 2010 doesn't support C++11 aliases.
+ // We're manually "aliasing" the class here as we want to bring unique_ptr
+ // into the flatbuffers namespace. We have unique_ptr in the flatbuffers
+ // namespace we have a completely independent implemenation (see below)
+ // for C++98 STL implementations.
+ template <class T> class unique_ptr : public std::unique_ptr<T> {
+ public:
+ unique_ptr() {}
+ explicit unique_ptr(T* p) : std::unique_ptr<T>(p) {}
+ unique_ptr(std::unique_ptr<T>&& u) { *this = std::move(u); }
+ unique_ptr(unique_ptr&& u) { *this = std::move(u); }
+ unique_ptr& operator=(std::unique_ptr<T>&& u) {
+ std::unique_ptr<T>::reset(u.release());
+ return *this;
+ }
+ unique_ptr& operator=(unique_ptr&& u) {
+ std::unique_ptr<T>::reset(u.release());
+ return *this;
+ }
+ unique_ptr& operator=(T* p) {
+ return std::unique_ptr<T>::operator=(p);
+ }
+ };
+ #endif // defined(FLATBUFFERS_TEMPLATES_ALIASES)
+#else
+ // Very limited implementation of unique_ptr.
+ // This is provided simply to allow the C++ code generated from the default
+ // settings to function in C++98 environments with no modifications.
+ template <class T> class unique_ptr {
+ public:
+ typedef T element_type;
+
+ unique_ptr() : ptr_(nullptr) {}
+ explicit unique_ptr(T* p) : ptr_(p) {}
+ unique_ptr(unique_ptr&& u) : ptr_(nullptr) { reset(u.release()); }
+ unique_ptr(const unique_ptr& u) : ptr_(nullptr) {
+ reset(const_cast<unique_ptr*>(&u)->release());
+ }
+ ~unique_ptr() { reset(); }
+
+ unique_ptr& operator=(const unique_ptr& u) {
+ reset(const_cast<unique_ptr*>(&u)->release());
+ return *this;
+ }
+
+ unique_ptr& operator=(unique_ptr&& u) {
+ reset(u.release());
+ return *this;
+ }
+
+ unique_ptr& operator=(T* p) {
+ reset(p);
+ return *this;
+ }
+
+ const T& operator*() const { return *ptr_; }
+ T* operator->() const { return ptr_; }
+ T* get() const noexcept { return ptr_; }
+ explicit operator bool() const { return ptr_ != nullptr; }
+
+ // modifiers
+ T* release() {
+ T* value = ptr_;
+ ptr_ = nullptr;
+ return value;
+ }
+
+ void reset(T* p = nullptr) {
+ T* value = ptr_;
+ ptr_ = p;
+ if (value) delete value;
+ }
+
+ void swap(unique_ptr& u) {
+ T* temp_ptr = ptr_;
+ ptr_ = u.ptr_;
+ u.ptr_ = temp_ptr;
+ }
+
+ private:
+ T* ptr_;
+ };
+
+ template <class T> bool operator==(const unique_ptr<T>& x,
+ const unique_ptr<T>& y) {
+ return x.get() == y.get();
+ }
+
+ template <class T, class D> bool operator==(const unique_ptr<T>& x,
+ const D* y) {
+ return static_cast<D*>(x.get()) == y;
+ }
+
+ template <class T> bool operator==(const unique_ptr<T>& x, intptr_t y) {
+ return reinterpret_cast<intptr_t>(x.get()) == y;
+ }
+
+ template <class T> bool operator!=(const unique_ptr<T>& x, decltype(nullptr)) {
+ return !!x;
+ }
+
+ template <class T> bool operator!=(decltype(nullptr), const unique_ptr<T>& x) {
+ return !!x;
+ }
+
+ template <class T> bool operator==(const unique_ptr<T>& x, decltype(nullptr)) {
+ return !x;
+ }
+
+ template <class T> bool operator==(decltype(nullptr), const unique_ptr<T>& x) {
+ return !x;
+ }
+
+#endif // !FLATBUFFERS_CPP98_STL
+
+} // namespace flatbuffers
+
+#endif // FLATBUFFERS_STL_EMULATION_H_
diff --git a/src/arrow/cpp/thirdparty/hadoop/include/hdfs.h b/src/arrow/cpp/thirdparty/hadoop/include/hdfs.h
new file mode 100644
index 000000000..a4df6ae3b
--- /dev/null
+++ b/src/arrow/cpp/thirdparty/hadoop/include/hdfs.h
@@ -0,0 +1,1024 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBHDFS_HDFS_H
+#define LIBHDFS_HDFS_H
+
+#include <errno.h> /* for EINTERNAL, etc. */
+#include <fcntl.h> /* for O_RDONLY, O_WRONLY */
+#include <stdint.h> /* for uint64_t, etc. */
+#include <time.h> /* for time_t */
+
+/*
+ * Support export of DLL symbols during libhdfs build, and import of DLL symbols
+ * during client application build. A client application may optionally define
+ * symbol LIBHDFS_DLL_IMPORT in its build. This is not strictly required, but
+ * the compiler can produce more efficient code with it.
+ */
+#ifdef WIN32
+ #ifdef LIBHDFS_DLL_EXPORT
+ #define LIBHDFS_EXTERNAL __declspec(dllexport)
+ #elif LIBHDFS_DLL_IMPORT
+ #define LIBHDFS_EXTERNAL __declspec(dllimport)
+ #else
+ #define LIBHDFS_EXTERNAL
+ #endif
+#else
+ #ifdef LIBHDFS_DLL_EXPORT
+ #define LIBHDFS_EXTERNAL __attribute__((visibility("default")))
+ #elif LIBHDFS_DLL_IMPORT
+ #define LIBHDFS_EXTERNAL __attribute__((visibility("default")))
+ #else
+ #define LIBHDFS_EXTERNAL
+ #endif
+#endif
+
+#ifndef O_RDONLY
+#define O_RDONLY 1
+#endif
+
+#ifndef O_WRONLY
+#define O_WRONLY 2
+#endif
+
+#ifndef EINTERNAL
+#define EINTERNAL 255
+#endif
+
+#define ELASTIC_BYTE_BUFFER_POOL_CLASS \
+ "org/apache/hadoop/io/ElasticByteBufferPool"
+
+/** All APIs set errno to meaningful values */
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+ /**
+ * Some utility decls used in libhdfs.
+ */
+ struct hdfsBuilder;
+ typedef int32_t tSize; /// size of data for read/write io ops
+ typedef time_t tTime; /// time type in seconds
+ typedef int64_t tOffset;/// offset within the file
+ typedef uint16_t tPort; /// port
+ typedef enum tObjectKind {
+ kObjectKindFile = 'F',
+ kObjectKindDirectory = 'D',
+ } tObjectKind;
+ struct hdfsStreamBuilder;
+
+
+ /**
+ * The C reflection of org.apache.org.hadoop.FileSystem .
+ */
+ struct hdfs_internal;
+ typedef struct hdfs_internal* hdfsFS;
+
+ struct hdfsFile_internal;
+ typedef struct hdfsFile_internal* hdfsFile;
+
+ struct hadoopRzOptions;
+
+ struct hadoopRzBuffer;
+
+ /**
+ * Determine if a file is open for read.
+ *
+ * @param file The HDFS file
+ * @return 1 if the file is open for read; 0 otherwise
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsFileIsOpenForRead(hdfsFile file);
+
+ /**
+ * Determine if a file is open for write.
+ *
+ * @param file The HDFS file
+ * @return 1 if the file is open for write; 0 otherwise
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsFileIsOpenForWrite(hdfsFile file);
+
+ struct hdfsReadStatistics {
+ uint64_t totalBytesRead;
+ uint64_t totalLocalBytesRead;
+ uint64_t totalShortCircuitBytesRead;
+ uint64_t totalZeroCopyBytesRead;
+ };
+
+ /**
+ * Get read statistics about a file. This is only applicable to files
+ * opened for reading.
+ *
+ * @param file The HDFS file
+ * @param stats (out parameter) on a successful return, the read
+ * statistics. Unchanged otherwise. You must free the
+ * returned statistics with hdfsFileFreeReadStatistics.
+ * @return 0 if the statistics were successfully returned,
+ * -1 otherwise. On a failure, please check errno against
+ * ENOTSUP. webhdfs, LocalFilesystem, and so forth may
+ * not support read statistics.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsFileGetReadStatistics(hdfsFile file,
+ struct hdfsReadStatistics **stats);
+
+ /**
+ * @param stats HDFS read statistics for a file.
+ *
+ * @return the number of remote bytes read.
+ */
+ LIBHDFS_EXTERNAL
+ int64_t hdfsReadStatisticsGetRemoteBytesRead(
+ const struct hdfsReadStatistics *stats);
+
+ /**
+ * Clear the read statistics for a file.
+ *
+ * @param file The file to clear the read statistics of.
+ *
+ * @return 0 on success; the error code otherwise.
+ * EINVAL: the file is not open for reading.
+ * ENOTSUP: the file does not support clearing the read
+ * statistics.
+ * Errno will also be set to this code on failure.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsFileClearReadStatistics(hdfsFile file);
+
+ /**
+ * Free some HDFS read statistics.
+ *
+ * @param stats The HDFS read statistics to free.
+ */
+ LIBHDFS_EXTERNAL
+ void hdfsFileFreeReadStatistics(struct hdfsReadStatistics *stats);
+
+ /**
+ * hdfsConnectAsUser - Connect to a hdfs file system as a specific user
+ * Connect to the hdfs.
+ * @param nn The NameNode. See hdfsBuilderSetNameNode for details.
+ * @param port The port on which the server is listening.
+ * @param user the user name (this is hadoop domain user). Or NULL is equivelant to hhdfsConnect(host, port)
+ * @return Returns a handle to the filesystem or NULL on error.
+ * @deprecated Use hdfsBuilderConnect instead.
+ */
+ LIBHDFS_EXTERNAL
+ hdfsFS hdfsConnectAsUser(const char* nn, tPort port, const char *user);
+
+ /**
+ * hdfsConnect - Connect to a hdfs file system.
+ * Connect to the hdfs.
+ * @param nn The NameNode. See hdfsBuilderSetNameNode for details.
+ * @param port The port on which the server is listening.
+ * @return Returns a handle to the filesystem or NULL on error.
+ * @deprecated Use hdfsBuilderConnect instead.
+ */
+ LIBHDFS_EXTERNAL
+ hdfsFS hdfsConnect(const char* nn, tPort port);
+
+ /**
+ * hdfsConnect - Connect to an hdfs file system.
+ *
+ * Forces a new instance to be created
+ *
+ * @param nn The NameNode. See hdfsBuilderSetNameNode for details.
+ * @param port The port on which the server is listening.
+ * @param user The user name to use when connecting
+ * @return Returns a handle to the filesystem or NULL on error.
+ * @deprecated Use hdfsBuilderConnect instead.
+ */
+ LIBHDFS_EXTERNAL
+ hdfsFS hdfsConnectAsUserNewInstance(const char* nn, tPort port, const char *user );
+
+ /**
+ * hdfsConnect - Connect to an hdfs file system.
+ *
+ * Forces a new instance to be created
+ *
+ * @param nn The NameNode. See hdfsBuilderSetNameNode for details.
+ * @param port The port on which the server is listening.
+ * @return Returns a handle to the filesystem or NULL on error.
+ * @deprecated Use hdfsBuilderConnect instead.
+ */
+ LIBHDFS_EXTERNAL
+ hdfsFS hdfsConnectNewInstance(const char* nn, tPort port);
+
+ /**
+ * Connect to HDFS using the parameters defined by the builder.
+ *
+ * The HDFS builder will be freed, whether or not the connection was
+ * successful.
+ *
+ * Every successful call to hdfsBuilderConnect should be matched with a call
+ * to hdfsDisconnect, when the hdfsFS is no longer needed.
+ *
+ * @param bld The HDFS builder
+ * @return Returns a handle to the filesystem, or NULL on error.
+ */
+ LIBHDFS_EXTERNAL
+ hdfsFS hdfsBuilderConnect(struct hdfsBuilder *bld);
+
+ /**
+ * Create an HDFS builder.
+ *
+ * @return The HDFS builder, or NULL on error.
+ */
+ LIBHDFS_EXTERNAL
+ struct hdfsBuilder *hdfsNewBuilder(void);
+
+ /**
+ * Force the builder to always create a new instance of the FileSystem,
+ * rather than possibly finding one in the cache.
+ *
+ * @param bld The HDFS builder
+ */
+ LIBHDFS_EXTERNAL
+ void hdfsBuilderSetForceNewInstance(struct hdfsBuilder *bld);
+
+ /**
+ * Set the HDFS NameNode to connect to.
+ *
+ * @param bld The HDFS builder
+ * @param nn The NameNode to use.
+ *
+ * If the string given is 'default', the default NameNode
+ * configuration will be used (from the XML configuration files)
+ *
+ * If NULL is given, a LocalFileSystem will be created.
+ *
+ * If the string starts with a protocol type such as file:// or
+ * hdfs://, this protocol type will be used. If not, the
+ * hdfs:// protocol type will be used.
+ *
+ * You may specify a NameNode port in the usual way by
+ * passing a string of the format hdfs://<hostname>:<port>.
+ * Alternately, you may set the port with
+ * hdfsBuilderSetNameNodePort. However, you must not pass the
+ * port in two different ways.
+ */
+ LIBHDFS_EXTERNAL
+ void hdfsBuilderSetNameNode(struct hdfsBuilder *bld, const char *nn);
+
+ /**
+ * Set the port of the HDFS NameNode to connect to.
+ *
+ * @param bld The HDFS builder
+ * @param port The port.
+ */
+ LIBHDFS_EXTERNAL
+ void hdfsBuilderSetNameNodePort(struct hdfsBuilder *bld, tPort port);
+
+ /**
+ * Set the username to use when connecting to the HDFS cluster.
+ *
+ * @param bld The HDFS builder
+ * @param userName The user name. The string will be shallow-copied.
+ */
+ LIBHDFS_EXTERNAL
+ void hdfsBuilderSetUserName(struct hdfsBuilder *bld, const char *userName);
+
+ /**
+ * Set the path to the Kerberos ticket cache to use when connecting to
+ * the HDFS cluster.
+ *
+ * @param bld The HDFS builder
+ * @param kerbTicketCachePath The Kerberos ticket cache path. The string
+ * will be shallow-copied.
+ */
+ LIBHDFS_EXTERNAL
+ void hdfsBuilderSetKerbTicketCachePath(struct hdfsBuilder *bld,
+ const char *kerbTicketCachePath);
+
+ /**
+ * Free an HDFS builder.
+ *
+ * It is normally not necessary to call this function since
+ * hdfsBuilderConnect frees the builder.
+ *
+ * @param bld The HDFS builder
+ */
+ LIBHDFS_EXTERNAL
+ void hdfsFreeBuilder(struct hdfsBuilder *bld);
+
+ /**
+ * Set a configuration string for an HdfsBuilder.
+ *
+ * @param key The key to set.
+ * @param val The value, or NULL to set no value.
+ * This will be shallow-copied. You are responsible for
+ * ensuring that it remains valid until the builder is
+ * freed.
+ *
+ * @return 0 on success; nonzero error code otherwise.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsBuilderConfSetStr(struct hdfsBuilder *bld, const char *key,
+ const char *val);
+
+ /**
+ * Get a configuration string.
+ *
+ * @param key The key to find
+ * @param val (out param) The value. This will be set to NULL if the
+ * key isn't found. You must free this string with
+ * hdfsConfStrFree.
+ *
+ * @return 0 on success; nonzero error code otherwise.
+ * Failure to find the key is not an error.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsConfGetStr(const char *key, char **val);
+
+ /**
+ * Get a configuration integer.
+ *
+ * @param key The key to find
+ * @param val (out param) The value. This will NOT be changed if the
+ * key isn't found.
+ *
+ * @return 0 on success; nonzero error code otherwise.
+ * Failure to find the key is not an error.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsConfGetInt(const char *key, int32_t *val);
+
+ /**
+ * Free a configuration string found with hdfsConfGetStr.
+ *
+ * @param val A configuration string obtained from hdfsConfGetStr
+ */
+ LIBHDFS_EXTERNAL
+ void hdfsConfStrFree(char *val);
+
+ /**
+ * hdfsDisconnect - Disconnect from the hdfs file system.
+ * Disconnect from hdfs.
+ * @param fs The configured filesystem handle.
+ * @return Returns 0 on success, -1 on error.
+ * Even if there is an error, the resources associated with the
+ * hdfsFS will be freed.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsDisconnect(hdfsFS fs);
+
+ /**
+ * hdfsOpenFile - Open a hdfs file in given mode.
+ * @deprecated Use the hdfsStreamBuilder functions instead.
+ * This function does not support setting block sizes bigger than 2 GB.
+ *
+ * @param fs The configured filesystem handle.
+ * @param path The full path to the file.
+ * @param flags - an | of bits/fcntl.h file flags - supported flags are O_RDONLY, O_WRONLY (meaning create or overwrite i.e., implies O_TRUNCAT),
+ * O_WRONLY|O_APPEND. Other flags are generally ignored other than (O_RDWR || (O_EXCL & O_CREAT)) which return NULL and set errno equal ENOTSUP.
+ * @param bufferSize Size of buffer for read/write - pass 0 if you want
+ * to use the default configured values.
+ * @param replication Block replication - pass 0 if you want to use
+ * the default configured values.
+ * @param blocksize Size of block - pass 0 if you want to use the
+ * default configured values. Note that if you want a block size bigger
+ * than 2 GB, you must use the hdfsStreamBuilder API rather than this
+ * deprecated function.
+ * @return Returns the handle to the open file or NULL on error.
+ */
+ LIBHDFS_EXTERNAL
+ hdfsFile hdfsOpenFile(hdfsFS fs, const char* path, int flags,
+ int bufferSize, short replication, tSize blocksize);
+
+ /**
+ * hdfsStreamBuilderAlloc - Allocate an HDFS stream builder.
+ *
+ * @param fs The configured filesystem handle.
+ * @param path The full path to the file. Will be deep-copied.
+ * @param flags The open flags, as in hdfsOpenFile.
+ * @return Returns the hdfsStreamBuilder, or NULL on error.
+ */
+ LIBHDFS_EXTERNAL
+ struct hdfsStreamBuilder *hdfsStreamBuilderAlloc(hdfsFS fs,
+ const char *path, int flags);
+
+ /**
+ * hdfsStreamBuilderFree - Free an HDFS file builder.
+ *
+ * It is normally not necessary to call this function since
+ * hdfsStreamBuilderBuild frees the builder.
+ *
+ * @param bld The hdfsStreamBuilder to free.
+ */
+ LIBHDFS_EXTERNAL
+ void hdfsStreamBuilderFree(struct hdfsStreamBuilder *bld);
+
+ /**
+ * hdfsStreamBuilderSetBufferSize - Set the stream buffer size.
+ *
+ * @param bld The hdfs stream builder.
+ * @param bufferSize The buffer size to set.
+ *
+ * @return 0 on success, or -1 on error. Errno will be set on error.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsStreamBuilderSetBufferSize(struct hdfsStreamBuilder *bld,
+ int32_t bufferSize);
+
+ /**
+ * hdfsStreamBuilderSetReplication - Set the replication for the stream.
+ * This is only relevant for output streams, which will create new blocks.
+ *
+ * @param bld The hdfs stream builder.
+ * @param replication The replication to set.
+ *
+ * @return 0 on success, or -1 on error. Errno will be set on error.
+ * If you call this on an input stream builder, you will get
+ * EINVAL, because this configuration is not relevant to input
+ * streams.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsStreamBuilderSetReplication(struct hdfsStreamBuilder *bld,
+ int16_t replication);
+
+ /**
+ * hdfsStreamBuilderSetDefaultBlockSize - Set the default block size for
+ * the stream. This is only relevant for output streams, which will create
+ * new blocks.
+ *
+ * @param bld The hdfs stream builder.
+ * @param defaultBlockSize The default block size to set.
+ *
+ * @return 0 on success, or -1 on error. Errno will be set on error.
+ * If you call this on an input stream builder, you will get
+ * EINVAL, because this configuration is not relevant to input
+ * streams.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsStreamBuilderSetDefaultBlockSize(struct hdfsStreamBuilder *bld,
+ int64_t defaultBlockSize);
+
+ /**
+ * hdfsStreamBuilderBuild - Build the stream by calling open or create.
+ *
+ * @param bld The hdfs stream builder. This pointer will be freed, whether
+ * or not the open succeeds.
+ *
+ * @return the stream pointer on success, or NULL on error. Errno will be
+ * set on error.
+ */
+ LIBHDFS_EXTERNAL
+ hdfsFile hdfsStreamBuilderBuild(struct hdfsStreamBuilder *bld);
+
+ /**
+ * hdfsTruncateFile - Truncate a hdfs file to given lenght.
+ * @param fs The configured filesystem handle.
+ * @param path The full path to the file.
+ * @param newlength The size the file is to be truncated to
+ * @return 1 if the file has been truncated to the desired newlength
+ * and is immediately available to be reused for write operations
+ * such as append.
+ * 0 if a background process of adjusting the length of the last
+ * block has been started, and clients should wait for it to
+ * complete before proceeding with further file updates.
+ * -1 on error.
+ */
+ int hdfsTruncateFile(hdfsFS fs, const char* path, tOffset newlength);
+
+ /**
+ * hdfsUnbufferFile - Reduce the buffering done on a file.
+ *
+ * @param file The file to unbuffer.
+ * @return 0 on success
+ * ENOTSUP if the file does not support unbuffering
+ * Errno will also be set to this value.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsUnbufferFile(hdfsFile file);
+
+ /**
+ * hdfsCloseFile - Close an open file.
+ * @param fs The configured filesystem handle.
+ * @param file The file handle.
+ * @return Returns 0 on success, -1 on error.
+ * On error, errno will be set appropriately.
+ * If the hdfs file was valid, the memory associated with it will
+ * be freed at the end of this call, even if there was an I/O
+ * error.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsCloseFile(hdfsFS fs, hdfsFile file);
+
+
+ /**
+ * hdfsExists - Checks if a given path exsits on the filesystem
+ * @param fs The configured filesystem handle.
+ * @param path The path to look for
+ * @return Returns 0 on success, -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsExists(hdfsFS fs, const char *path);
+
+
+ /**
+ * hdfsSeek - Seek to given offset in file.
+ * This works only for files opened in read-only mode.
+ * @param fs The configured filesystem handle.
+ * @param file The file handle.
+ * @param desiredPos Offset into the file to seek into.
+ * @return Returns 0 on success, -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsSeek(hdfsFS fs, hdfsFile file, tOffset desiredPos);
+
+
+ /**
+ * hdfsTell - Get the current offset in the file, in bytes.
+ * @param fs The configured filesystem handle.
+ * @param file The file handle.
+ * @return Current offset, -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ tOffset hdfsTell(hdfsFS fs, hdfsFile file);
+
+
+ /**
+ * hdfsRead - Read data from an open file.
+ * @param fs The configured filesystem handle.
+ * @param file The file handle.
+ * @param buffer The buffer to copy read bytes into.
+ * @param length The length of the buffer.
+ * @return On success, a positive number indicating how many bytes
+ * were read.
+ * On end-of-file, 0.
+ * On error, -1. Errno will be set to the error code.
+ * Just like the POSIX read function, hdfsRead will return -1
+ * and set errno to EINTR if data is temporarily unavailable,
+ * but we are not yet at the end of the file.
+ */
+ LIBHDFS_EXTERNAL
+ tSize hdfsRead(hdfsFS fs, hdfsFile file, void* buffer, tSize length);
+
+ /**
+ * hdfsPread - Positional read of data from an open file.
+ * @param fs The configured filesystem handle.
+ * @param file The file handle.
+ * @param position Position from which to read
+ * @param buffer The buffer to copy read bytes into.
+ * @param length The length of the buffer.
+ * @return See hdfsRead
+ */
+ LIBHDFS_EXTERNAL
+ tSize hdfsPread(hdfsFS fs, hdfsFile file, tOffset position,
+ void* buffer, tSize length);
+
+
+ /**
+ * hdfsWrite - Write data into an open file.
+ * @param fs The configured filesystem handle.
+ * @param file The file handle.
+ * @param buffer The data.
+ * @param length The no. of bytes to write.
+ * @return Returns the number of bytes written, -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ tSize hdfsWrite(hdfsFS fs, hdfsFile file, const void* buffer,
+ tSize length);
+
+
+ /**
+ * hdfsWrite - Flush the data.
+ * @param fs The configured filesystem handle.
+ * @param file The file handle.
+ * @return Returns 0 on success, -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsFlush(hdfsFS fs, hdfsFile file);
+
+
+ /**
+ * hdfsHFlush - Flush out the data in client's user buffer. After the
+ * return of this call, new readers will see the data.
+ * @param fs configured filesystem handle
+ * @param file file handle
+ * @return 0 on success, -1 on error and sets errno
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsHFlush(hdfsFS fs, hdfsFile file);
+
+
+ /**
+ * hdfsHSync - Similar to posix fsync, Flush out the data in client's
+ * user buffer. all the way to the disk device (but the disk may have
+ * it in its cache).
+ * @param fs configured filesystem handle
+ * @param file file handle
+ * @return 0 on success, -1 on error and sets errno
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsHSync(hdfsFS fs, hdfsFile file);
+
+
+ /**
+ * hdfsAvailable - Number of bytes that can be read from this
+ * input stream without blocking.
+ * @param fs The configured filesystem handle.
+ * @param file The file handle.
+ * @return Returns available bytes; -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsAvailable(hdfsFS fs, hdfsFile file);
+
+
+ /**
+ * hdfsCopy - Copy file from one filesystem to another.
+ * @param srcFS The handle to source filesystem.
+ * @param src The path of source file.
+ * @param dstFS The handle to destination filesystem.
+ * @param dst The path of destination file.
+ * @return Returns 0 on success, -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsCopy(hdfsFS srcFS, const char* src, hdfsFS dstFS, const char* dst);
+
+
+ /**
+ * hdfsMove - Move file from one filesystem to another.
+ * @param srcFS The handle to source filesystem.
+ * @param src The path of source file.
+ * @param dstFS The handle to destination filesystem.
+ * @param dst The path of destination file.
+ * @return Returns 0 on success, -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsMove(hdfsFS srcFS, const char* src, hdfsFS dstFS, const char* dst);
+
+
+ /**
+ * hdfsDelete - Delete file.
+ * @param fs The configured filesystem handle.
+ * @param path The path of the file.
+ * @param recursive if path is a directory and set to
+ * non-zero, the directory is deleted else throws an exception. In
+ * case of a file the recursive argument is irrelevant.
+ * @return Returns 0 on success, -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsDelete(hdfsFS fs, const char* path, int recursive);
+
+ /**
+ * hdfsRename - Rename file.
+ * @param fs The configured filesystem handle.
+ * @param oldPath The path of the source file.
+ * @param newPath The path of the destination file.
+ * @return Returns 0 on success, -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsRename(hdfsFS fs, const char* oldPath, const char* newPath);
+
+
+ /**
+ * hdfsGetWorkingDirectory - Get the current working directory for
+ * the given filesystem.
+ * @param fs The configured filesystem handle.
+ * @param buffer The user-buffer to copy path of cwd into.
+ * @param bufferSize The length of user-buffer.
+ * @return Returns buffer, NULL on error.
+ */
+ LIBHDFS_EXTERNAL
+ char* hdfsGetWorkingDirectory(hdfsFS fs, char *buffer, size_t bufferSize);
+
+
+ /**
+ * hdfsSetWorkingDirectory - Set the working directory. All relative
+ * paths will be resolved relative to it.
+ * @param fs The configured filesystem handle.
+ * @param path The path of the new 'cwd'.
+ * @return Returns 0 on success, -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsSetWorkingDirectory(hdfsFS fs, const char* path);
+
+
+ /**
+ * hdfsCreateDirectory - Make the given file and all non-existent
+ * parents into directories.
+ * @param fs The configured filesystem handle.
+ * @param path The path of the directory.
+ * @return Returns 0 on success, -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsCreateDirectory(hdfsFS fs, const char* path);
+
+
+ /**
+ * hdfsSetReplication - Set the replication of the specified
+ * file to the supplied value
+ * @param fs The configured filesystem handle.
+ * @param path The path of the file.
+ * @return Returns 0 on success, -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsSetReplication(hdfsFS fs, const char* path, int16_t replication);
+
+
+ /**
+ * hdfsFileInfo - Information about a file/directory.
+ */
+ typedef struct {
+ tObjectKind mKind; /* file or directory */
+ char *mName; /* the name of the file */
+ tTime mLastMod; /* the last modification time for the file in seconds */
+ tOffset mSize; /* the size of the file in bytes */
+ short mReplication; /* the count of replicas */
+ tOffset mBlockSize; /* the block size for the file */
+ char *mOwner; /* the owner of the file */
+ char *mGroup; /* the group associated with the file */
+ short mPermissions; /* the permissions associated with the file */
+ tTime mLastAccess; /* the last access time for the file in seconds */
+ } hdfsFileInfo;
+
+
+ /**
+ * hdfsListDirectory - Get list of files/directories for a given
+ * directory-path. hdfsFreeFileInfo should be called to deallocate memory.
+ * @param fs The configured filesystem handle.
+ * @param path The path of the directory.
+ * @param numEntries Set to the number of files/directories in path.
+ * @return Returns a dynamically-allocated array of hdfsFileInfo
+ * objects; NULL on error or empty directory.
+ * errno is set to non-zero on error or zero on success.
+ */
+ LIBHDFS_EXTERNAL
+ hdfsFileInfo *hdfsListDirectory(hdfsFS fs, const char* path,
+ int *numEntries);
+
+
+ /**
+ * hdfsGetPathInfo - Get information about a path as a (dynamically
+ * allocated) single hdfsFileInfo struct. hdfsFreeFileInfo should be
+ * called when the pointer is no longer needed.
+ * @param fs The configured filesystem handle.
+ * @param path The path of the file.
+ * @return Returns a dynamically-allocated hdfsFileInfo object;
+ * NULL on error.
+ */
+ LIBHDFS_EXTERNAL
+ hdfsFileInfo *hdfsGetPathInfo(hdfsFS fs, const char* path);
+
+
+ /**
+ * hdfsFreeFileInfo - Free up the hdfsFileInfo array (including fields)
+ * @param hdfsFileInfo The array of dynamically-allocated hdfsFileInfo
+ * objects.
+ * @param numEntries The size of the array.
+ */
+ LIBHDFS_EXTERNAL
+ void hdfsFreeFileInfo(hdfsFileInfo *hdfsFileInfo, int numEntries);
+
+ /**
+ * hdfsFileIsEncrypted: determine if a file is encrypted based on its
+ * hdfsFileInfo.
+ * @return -1 if there was an error (errno will be set), 0 if the file is
+ * not encrypted, 1 if the file is encrypted.
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsFileIsEncrypted(hdfsFileInfo *hdfsFileInfo);
+
+
+ /**
+ * hdfsGetHosts - Get hostnames where a particular block (determined by
+ * pos & blocksize) of a file is stored. The last element in the array
+ * is NULL. Due to replication, a single block could be present on
+ * multiple hosts.
+ * @param fs The configured filesystem handle.
+ * @param path The path of the file.
+ * @param start The start of the block.
+ * @param length The length of the block.
+ * @return Returns a dynamically-allocated 2-d array of blocks-hosts;
+ * NULL on error.
+ */
+ LIBHDFS_EXTERNAL
+ char*** hdfsGetHosts(hdfsFS fs, const char* path,
+ tOffset start, tOffset length);
+
+
+ /**
+ * hdfsFreeHosts - Free up the structure returned by hdfsGetHosts
+ * @param hdfsFileInfo The array of dynamically-allocated hdfsFileInfo
+ * objects.
+ * @param numEntries The size of the array.
+ */
+ LIBHDFS_EXTERNAL
+ void hdfsFreeHosts(char ***blockHosts);
+
+
+ /**
+ * hdfsGetDefaultBlockSize - Get the default blocksize.
+ *
+ * @param fs The configured filesystem handle.
+ * @deprecated Use hdfsGetDefaultBlockSizeAtPath instead.
+ *
+ * @return Returns the default blocksize, or -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ tOffset hdfsGetDefaultBlockSize(hdfsFS fs);
+
+
+ /**
+ * hdfsGetDefaultBlockSizeAtPath - Get the default blocksize at the
+ * filesystem indicated by a given path.
+ *
+ * @param fs The configured filesystem handle.
+ * @param path The given path will be used to locate the actual
+ * filesystem. The full path does not have to exist.
+ *
+ * @return Returns the default blocksize, or -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ tOffset hdfsGetDefaultBlockSizeAtPath(hdfsFS fs, const char *path);
+
+
+ /**
+ * hdfsGetCapacity - Return the raw capacity of the filesystem.
+ * @param fs The configured filesystem handle.
+ * @return Returns the raw-capacity; -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ tOffset hdfsGetCapacity(hdfsFS fs);
+
+
+ /**
+ * hdfsGetUsed - Return the total raw size of all files in the filesystem.
+ * @param fs The configured filesystem handle.
+ * @return Returns the total-size; -1 on error.
+ */
+ LIBHDFS_EXTERNAL
+ tOffset hdfsGetUsed(hdfsFS fs);
+
+ /**
+ * Change the user and/or group of a file or directory.
+ *
+ * @param fs The configured filesystem handle.
+ * @param path the path to the file or directory
+ * @param owner User string. Set to NULL for 'no change'
+ * @param group Group string. Set to NULL for 'no change'
+ * @return 0 on success else -1
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsChown(hdfsFS fs, const char* path, const char *owner,
+ const char *group);
+
+ /**
+ * hdfsChmod
+ * @param fs The configured filesystem handle.
+ * @param path the path to the file or directory
+ * @param mode the bitmask to set it to
+ * @return 0 on success else -1
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsChmod(hdfsFS fs, const char* path, short mode);
+
+ /**
+ * hdfsUtime
+ * @param fs The configured filesystem handle.
+ * @param path the path to the file or directory
+ * @param mtime new modification time or -1 for no change
+ * @param atime new access time or -1 for no change
+ * @return 0 on success else -1
+ */
+ LIBHDFS_EXTERNAL
+ int hdfsUtime(hdfsFS fs, const char* path, tTime mtime, tTime atime);
+
+ /**
+ * Allocate a zero-copy options structure.
+ *
+ * You must free all options structures allocated with this function using
+ * hadoopRzOptionsFree.
+ *
+ * @return A zero-copy options structure, or NULL if one could
+ * not be allocated. If NULL is returned, errno will
+ * contain the error number.
+ */
+ LIBHDFS_EXTERNAL
+ struct hadoopRzOptions *hadoopRzOptionsAlloc(void);
+
+ /**
+ * Determine whether we should skip checksums in read0.
+ *
+ * @param opts The options structure.
+ * @param skip Nonzero to skip checksums sometimes; zero to always
+ * check them.
+ *
+ * @return 0 on success; -1 plus errno on failure.
+ */
+ LIBHDFS_EXTERNAL
+ int hadoopRzOptionsSetSkipChecksum(
+ struct hadoopRzOptions *opts, int skip);
+
+ /**
+ * Set the ByteBufferPool to use with read0.
+ *
+ * @param opts The options structure.
+ * @param className If this is NULL, we will not use any
+ * ByteBufferPool. If this is non-NULL, it will be
+ * treated as the name of the pool class to use.
+ * For example, you can use
+ * ELASTIC_BYTE_BUFFER_POOL_CLASS.
+ *
+ * @return 0 if the ByteBufferPool class was found and
+ * instantiated;
+ * -1 plus errno otherwise.
+ */
+ LIBHDFS_EXTERNAL
+ int hadoopRzOptionsSetByteBufferPool(
+ struct hadoopRzOptions *opts, const char *className);
+
+ /**
+ * Free a hadoopRzOptionsFree structure.
+ *
+ * @param opts The options structure to free.
+ * Any associated ByteBufferPool will also be freed.
+ */
+ LIBHDFS_EXTERNAL
+ void hadoopRzOptionsFree(struct hadoopRzOptions *opts);
+
+ /**
+ * Perform a byte buffer read.
+ * If possible, this will be a zero-copy (mmap) read.
+ *
+ * @param file The file to read from.
+ * @param opts An options structure created by hadoopRzOptionsAlloc.
+ * @param maxLength The maximum length to read. We may read fewer bytes
+ * than this length.
+ *
+ * @return On success, we will return a new hadoopRzBuffer.
+ * This buffer will continue to be valid and readable
+ * until it is released by readZeroBufferFree. Failure to
+ * release a buffer will lead to a memory leak.
+ * You can access the data within the hadoopRzBuffer with
+ * hadoopRzBufferGet. If you have reached EOF, the data
+ * within the hadoopRzBuffer will be NULL. You must still
+ * free hadoopRzBuffer instances containing NULL.
+ *
+ * On failure, we will return NULL plus an errno code.
+ * errno = EOPNOTSUPP indicates that we could not do a
+ * zero-copy read, and there was no ByteBufferPool
+ * supplied.
+ */
+ LIBHDFS_EXTERNAL
+ struct hadoopRzBuffer* hadoopReadZero(hdfsFile file,
+ struct hadoopRzOptions *opts, int32_t maxLength);
+
+ /**
+ * Determine the length of the buffer returned from readZero.
+ *
+ * @param buffer a buffer returned from readZero.
+ * @return the length of the buffer.
+ */
+ LIBHDFS_EXTERNAL
+ int32_t hadoopRzBufferLength(const struct hadoopRzBuffer *buffer);
+
+ /**
+ * Get a pointer to the raw buffer returned from readZero.
+ *
+ * To find out how many bytes this buffer contains, call
+ * hadoopRzBufferLength.
+ *
+ * @param buffer a buffer returned from readZero.
+ * @return a pointer to the start of the buffer. This will be
+ * NULL when end-of-file has been reached.
+ */
+ LIBHDFS_EXTERNAL
+ const void *hadoopRzBufferGet(const struct hadoopRzBuffer *buffer);
+
+ /**
+ * Release a buffer obtained through readZero.
+ *
+ * @param file The hdfs stream that created this buffer. This must be
+ * the same stream you called hadoopReadZero on.
+ * @param buffer The buffer to release.
+ */
+ LIBHDFS_EXTERNAL
+ void hadoopRzBufferFree(hdfsFile file, struct hadoopRzBuffer *buffer);
+
+#ifdef __cplusplus
+}
+#endif
+
+#undef LIBHDFS_EXTERNAL
+#endif /*LIBHDFS_HDFS_H*/
+
+/**
+ * vim: ts=4: sw=4: et
+ */
diff --git a/src/arrow/cpp/thirdparty/versions.txt b/src/arrow/cpp/thirdparty/versions.txt
new file mode 100644
index 000000000..12da46458
--- /dev/null
+++ b/src/arrow/cpp/thirdparty/versions.txt
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Toolchain library versions
+#
+# This file is used by `download_dependencies.sh` and cmake to figure out which
+# version of a dependency to fetch. In order to add a new dependency, add a
+# version variable, e.g. MY_DEP_VERSION and append an entry in the
+# `DEPENDENCIES` array (see the comment on top of the declaration for the
+# format).
+
+ARROW_ABSL_BUILD_VERSION=20210324.2
+ARROW_ABSL_BUILD_SHA256_CHECKSUM=59b862f50e710277f8ede96f083a5bb8d7c9595376146838b9580be90374ee1f
+ARROW_AWSSDK_BUILD_VERSION=1.8.133
+ARROW_AWSSDK_BUILD_SHA256_CHECKSUM=d6c495bc06be5e21dac716571305d77437e7cfd62a2226b8fe48d9ab5785a8d6
+ARROW_AWS_CHECKSUMS_BUILD_VERSION=v0.1.10
+ARROW_AWS_CHECKSUMS_BUILD_SHA256_CHECKSUM=c9d0100a5743765fc8034e34e2310f77f59b1adab6f2e2f2d4d2a3bd81b2a36d
+ARROW_AWS_C_COMMON_BUILD_VERSION=v0.6.9
+ARROW_AWS_C_COMMON_BUILD_SHA256_CHECKSUM=928a3e36f24d1ee46f9eec360ec5cebfe8b9b8994fe39d4fa74ff51aebb12717
+ARROW_AWS_C_EVENT_STREAM_BUILD_VERSION=v0.1.5
+ARROW_AWS_C_EVENT_STREAM_BUILD_SHA256_CHECKSUM=f1b423a487b5d6dca118bfc0d0c6cc596dc476b282258a3228e73a8f730422d4
+ARROW_BOOST_BUILD_VERSION=1.75.0
+ARROW_BOOST_BUILD_SHA256_CHECKSUM=cb97b36e2295a321c34851e0455bc2630ad6c691d4f9f589170066cd11c835b4
+ARROW_BROTLI_BUILD_VERSION=v1.0.9
+ARROW_BROTLI_BUILD_SHA256_CHECKSUM=f9e8d81d0405ba66d181529af42a3354f838c939095ff99930da6aa9cdf6fe46
+ARROW_BZIP2_BUILD_VERSION=1.0.8
+ARROW_BZIP2_BUILD_SHA256_CHECKSUM=ab5a03176ee106d3f0fa90e381da478ddae405918153cca248e682cd0c4a2269
+ARROW_CARES_BUILD_VERSION=1.17.1
+ARROW_CARES_BUILD_SHA256_CHECKSUM=d73dd0f6de824afd407ce10750ea081af47eba52b8a6cb307d220131ad93fc40
+ARROW_CRC32C_BUILD_VERSION=1.1.2
+ARROW_CRC32C_BUILD_SHA256_CHECKSUM=ac07840513072b7fcebda6e821068aa04889018f24e10e46181068fb214d7e56
+ARROW_GBENCHMARK_BUILD_VERSION=v1.5.2
+ARROW_GBENCHMARK_BUILD_SHA256_CHECKSUM=dccbdab796baa1043f04982147e67bb6e118fe610da2c65f88912d73987e700c
+ARROW_GFLAGS_BUILD_VERSION=v2.2.2
+ARROW_GFLAGS_BUILD_SHA256_CHECKSUM=34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf
+ARROW_GLOG_BUILD_VERSION=v0.4.0
+ARROW_GLOG_BUILD_SHA256_CHECKSUM=f28359aeba12f30d73d9e4711ef356dc842886968112162bc73002645139c39c
+ARROW_GOOGLE_CLOUD_CPP_BUILD_VERSION=v1.31.1
+ARROW_GOOGLE_CLOUD_CPP_BUILD_SHA256_CHECKSUM=dc7cbf95b506a84b48cf71e0462985d262183edeaabdacaaee2109852394a609
+ARROW_GRPC_BUILD_VERSION=v1.35.0
+ARROW_GRPC_BUILD_SHA256_CHECKSUM=27dd2fc5c9809ddcde8eb6fa1fa278a3486566dfc28335fca13eb8df8bd3b958
+ARROW_GTEST_BUILD_VERSION=1.10.0
+ARROW_GTEST_BUILD_SHA256_CHECKSUM=9dc9157a9a1551ec7a7e43daea9a694a0bb5fb8bec81235d8a1e6ef64c716dcb
+ARROW_JEMALLOC_BUILD_VERSION=5.2.1
+ARROW_JEMALLOC_BUILD_SHA256_CHECKSUM=34330e5ce276099e2e8950d9335db5a875689a4c6a56751ef3b1d8c537f887f6
+ARROW_LZ4_BUILD_VERSION=v1.9.3
+ARROW_LZ4_BUILD_SHA256_CHECKSUM=030644df4611007ff7dc962d981f390361e6c97a34e5cbc393ddfbe019ffe2c1
+# mimalloc 1.6.7 didn't build on Visual Studio 2015
+# https://github.com/microsoft/mimalloc/issues/353
+ARROW_MIMALLOC_BUILD_VERSION=v1.7.2
+ARROW_MIMALLOC_BUILD_SHA256_CHECKSUM=b1912e354565a4b698410f7583c0f83934a6dbb3ade54ab7ddcb1569320936bd
+ARROW_NLOHMANN_JSON_BUILD_VERSION=v3.10.2
+ARROW_NLOHMANN_JSON_BUILD_SHA256_CHECKSUM=081ed0f9f89805c2d96335c3acfa993b39a0a5b4b4cef7edb68dd2210a13458c
+ARROW_ORC_BUILD_VERSION=1.7.0
+ARROW_ORC_BUILD_SHA256_CHECKSUM=45d6ba9149ffa2aaa168d61ab326f61181861c94529f26da3918a9aa2f801e39
+ARROW_PROTOBUF_BUILD_VERSION=v3.17.3
+ARROW_PROTOBUF_BUILD_SHA256_CHECKSUM=77ad26d3f65222fd96ccc18b055632b0bfedf295cb748b712a98ba1ac0b704b2
+# Because of https://github.com/Tencent/rapidjson/pull/1323, we require
+# a pre-release version of RapidJSON to build with GCC 8 without
+# warnings.
+ARROW_RAPIDJSON_BUILD_VERSION=1a803826f1197b5e30703afe4b9c0e7dd48074f5
+ARROW_RAPIDJSON_BUILD_SHA256_CHECKSUM=0b6b780b6c534bfb0b23d29910bfe361e486bcfeaf106db8bc8995792072905a
+ARROW_RE2_BUILD_VERSION=2021-02-02
+ARROW_RE2_BUILD_SHA256_CHECKSUM=1396ab50c06c1a8885fb68bf49a5ecfd989163015fd96699a180d6414937f33f
+ARROW_SNAPPY_BUILD_VERSION=1.1.8
+ARROW_SNAPPY_BUILD_SHA256_CHECKSUM=16b677f07832a612b0836178db7f374e414f94657c138e6993cbfc5dcc58651f
+ARROW_THRIFT_BUILD_VERSION=0.13.0
+ARROW_THRIFT_BUILD_SHA256_CHECKSUM=7ad348b88033af46ce49148097afe354d513c1fca7c607b59c33ebb6064b5179
+ARROW_UTF8PROC_BUILD_VERSION=v2.6.1
+ARROW_UTF8PROC_BUILD_SHA256_CHECKSUM=4c06a9dc4017e8a2438ef80ee371d45868bda2237a98b26554de7a95406b283b
+ARROW_XSIMD_BUILD_VERSION=aeec9c872c8b475dedd7781336710f2dd2666cb2
+ARROW_XSIMD_BUILD_SHA256_CHECKSUM=0a841e6c8acf216150e4fc19fca8e29fbab9614b56ac7b96e56019264ca27b26
+ARROW_ZLIB_BUILD_VERSION=1.2.11
+ARROW_ZLIB_BUILD_SHA256_CHECKSUM=c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1
+ARROW_ZSTD_BUILD_VERSION=v1.5.0
+ARROW_ZSTD_BUILD_SHA256_CHECKSUM=0d9ade222c64e912d6957b11c923e214e2e010a18f39bec102f572e693ba2867
+
+
+# The first field is the name of the environment variable expected by cmake.
+# This _must_ match what is defined. The second field is the name of the
+# generated archive file. The third field is the url of the project for the
+# given version.
+DEPENDENCIES=(
+ "ARROW_ABSL_URL absl-${ARROW_ABSL_BUILD_VERSION}.tar.gz https://github.com/abseil/abseil-cpp/archive/${ARROW_ABSL_BUILD_VERSION}.tar.gz"
+ "ARROW_AWSSDK_URL aws-sdk-cpp-${ARROW_AWSSDK_BUILD_VERSION}.tar.gz https://github.com/aws/aws-sdk-cpp/archive/${ARROW_AWSSDK_BUILD_VERSION}.tar.gz"
+ "ARROW_AWS_CHECKSUMS_URL aws-checksums-${ARROW_AWS_CHECKSUMS_BUILD_VERSION}.tar.gz https://github.com/awslabs/aws-checksums/archive/${ARROW_AWS_CHECKSUMS_BUILD_VERSION}.tar.gz"
+ "ARROW_AWS_C_COMMON_URL aws-c-common-${ARROW_AWS_C_COMMON_BUILD_VERSION}.tar.gz https://github.com/awslabs/aws-c-common/archive/${ARROW_AWS_C_COMMON_BUILD_VERSION}.tar.gz"
+ "ARROW_AWS_C_EVENT_STREAM_URL aws-c-event-stream-${ARROW_AWS_C_EVENT_STREAM_BUILD_VERSION}.tar.gz https://github.com/awslabs/aws-c-event-stream/archive/${ARROW_AWS_C_EVENT_STREAM_BUILD_VERSION}.tar.gz"
+ "ARROW_BOOST_URL boost-${ARROW_BOOST_BUILD_VERSION}.tar.gz https://github.com/ursa-labs/thirdparty/releases/download/latest/boost_${ARROW_BOOST_BUILD_VERSION//./_}.tar.gz"
+ "ARROW_BROTLI_URL brotli-${ARROW_BROTLI_BUILD_VERSION}.tar.gz https://github.com/google/brotli/archive/${ARROW_BROTLI_BUILD_VERSION}.tar.gz"
+ "ARROW_BZIP2_URL bzip2-${ARROW_BZIP2_BUILD_VERSION}.tar.gz https://sourceware.org/pub/bzip2/bzip2-${ARROW_BZIP2_BUILD_VERSION}.tar.gz"
+ "ARROW_CARES_URL cares-${ARROW_CARES_BUILD_VERSION}.tar.gz https://c-ares.haxx.se/download/c-ares-${ARROW_CARES_BUILD_VERSION}.tar.gz"
+ "ARROW_CRC32C_URL crc32c-${ARROW_CRC32C_BUILD_VERSION}.tar.gz https://github.com/google/crc32c/archive/refs/tags/${ARROW_CRC32C_BUILD_VERSION}.tar.gz"
+ "ARROW_GBENCHMARK_URL gbenchmark-${ARROW_GBENCHMARK_BUILD_VERSION}.tar.gz https://github.com/google/benchmark/archive/${ARROW_GBENCHMARK_BUILD_VERSION}.tar.gz"
+ "ARROW_GFLAGS_URL gflags-${ARROW_GFLAGS_BUILD_VERSION}.tar.gz https://github.com/gflags/gflags/archive/${ARROW_GFLAGS_BUILD_VERSION}.tar.gz"
+ "ARROW_GLOG_URL glog-${ARROW_GLOG_BUILD_VERSION}.tar.gz https://github.com/google/glog/archive/${ARROW_GLOG_BUILD_VERSION}.tar.gz"
+ "ARROW_GOOGLE_CLOUD_CPP_URL google-cloud-cpp-${ARROW_GOOGLE_CLOUD_CPP_BUILD_VERSION}.tar.gz https://github.com/googleapis/google-cloud-cpp/archive/refs/tags/${ARROW_GOOGLE_CLOUD_CPP_BUILD_VERSION}.tar.gz"
+ "ARROW_GRPC_URL grpc-${ARROW_GRPC_BUILD_VERSION}.tar.gz https://github.com/grpc/grpc/archive/${ARROW_GRPC_BUILD_VERSION}.tar.gz"
+ "ARROW_GTEST_URL gtest-${ARROW_GTEST_BUILD_VERSION}.tar.gz https://github.com/google/googletest/archive/release-${ARROW_GTEST_BUILD_VERSION}.tar.gz"
+ "ARROW_JEMALLOC_URL jemalloc-${ARROW_JEMALLOC_BUILD_VERSION}.tar.bz2 https://github.com/jemalloc/jemalloc/releases/download/${ARROW_JEMALLOC_BUILD_VERSION}/jemalloc-${ARROW_JEMALLOC_BUILD_VERSION}.tar.bz2"
+ "ARROW_LZ4_URL lz4-${ARROW_LZ4_BUILD_VERSION}.tar.gz https://github.com/lz4/lz4/archive/${ARROW_LZ4_BUILD_VERSION}.tar.gz"
+ "ARROW_MIMALLOC_URL mimalloc-${ARROW_MIMALLOC_BUILD_VERSION}.tar.gz https://github.com/microsoft/mimalloc/archive/${ARROW_MIMALLOC_BUILD_VERSION}.tar.gz"
+ "ARROW_NLOHMANN_JSON_URL nlohmann-json-${ARROW_NLOHMANN_JSON_BUILD_VERSION}.tar.gz https://github.com/nlohmann/json/archive/refs/tags/${ARROW_NLOHMANN_JSON_BUILD_VERSION}.tar.gz"
+ "ARROW_ORC_URL orc-${ARROW_ORC_BUILD_VERSION}.tar.gz https://github.com/apache/orc/archive/rel/release-${ARROW_ORC_BUILD_VERSION}.tar.gz"
+ "ARROW_PROTOBUF_URL protobuf-${ARROW_PROTOBUF_BUILD_VERSION}.tar.gz https://github.com/google/protobuf/releases/download/${ARROW_PROTOBUF_BUILD_VERSION}/protobuf-all-${ARROW_PROTOBUF_BUILD_VERSION:1}.tar.gz"
+ "ARROW_RAPIDJSON_URL rapidjson-${ARROW_RAPIDJSON_BUILD_VERSION}.tar.gz https://github.com/miloyip/rapidjson/archive/${ARROW_RAPIDJSON_BUILD_VERSION}.tar.gz"
+ "ARROW_RE2_URL re2-${ARROW_RE2_BUILD_VERSION}.tar.gz https://github.com/google/re2/archive/${ARROW_RE2_BUILD_VERSION}.tar.gz"
+ "ARROW_SNAPPY_URL snappy-${ARROW_SNAPPY_BUILD_VERSION}.tar.gz https://github.com/google/snappy/archive/${ARROW_SNAPPY_BUILD_VERSION}.tar.gz"
+ "ARROW_THRIFT_URL thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz https://archive.apache.org/dist/thrift/${ARROW_THRIFT_BUILD_VERSION}/thrift-${ARROW_THRIFT_BUILD_VERSION}.tar.gz"
+ "ARROW_UTF8PROC_URL utf8proc-${ARROW_UTF8PROC_BUILD_VERSION}.tar.gz https://github.com/JuliaStrings/utf8proc/archive/${ARROW_UTF8PROC_BUILD_VERSION}.tar.gz"
+ "ARROW_XSIMD_URL xsimd-${ARROW_XSIMD_BUILD_VERSION}.tar.gz https://github.com/xtensor-stack/xsimd/archive/${ARROW_XSIMD_BUILD_VERSION}.tar.gz"
+ "ARROW_ZLIB_URL zlib-${ARROW_ZLIB_BUILD_VERSION}.tar.gz https://zlib.net/fossils/zlib-${ARROW_ZLIB_BUILD_VERSION}.tar.gz"
+ "ARROW_ZSTD_URL zstd-${ARROW_ZSTD_BUILD_VERSION}.tar.gz https://github.com/facebook/zstd/archive/${ARROW_ZSTD_BUILD_VERSION}.tar.gz"
+)
+
+# vim: set filetype=sh:
diff --git a/src/arrow/cpp/tools/parquet/CMakeLists.txt b/src/arrow/cpp/tools/parquet/CMakeLists.txt
new file mode 100644
index 000000000..81ab49421
--- /dev/null
+++ b/src/arrow/cpp/tools/parquet/CMakeLists.txt
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(PARQUET_BUILD_EXECUTABLES)
+ set(PARQUET_TOOLS parquet-dump-schema parquet-reader parquet-scan)
+
+ foreach(TOOL ${PARQUET_TOOLS})
+ string(REGEX REPLACE "-" "_" TOOL_SOURCE ${TOOL})
+ add_executable(${TOOL} "${TOOL_SOURCE}.cc")
+ if(ARROW_BUILD_SHARED)
+ target_link_libraries(${TOOL} parquet_shared)
+ else()
+ target_link_libraries(${TOOL} parquet_static)
+ endif()
+ # Avoid unsetting RPATH when installing
+ set_target_properties(${TOOL} PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE)
+ install(TARGETS ${TOOL} ${INSTALL_IS_OPTIONAL}
+ RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
+ endforeach(TOOL)
+
+ add_dependencies(parquet ${PARQUET_TOOLS})
+endif()
diff --git a/src/arrow/cpp/tools/parquet/parquet_dump_schema.cc b/src/arrow/cpp/tools/parquet/parquet_dump_schema.cc
new file mode 100644
index 000000000..0d7c2428f
--- /dev/null
+++ b/src/arrow/cpp/tools/parquet/parquet_dump_schema.cc
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <iostream>
+
+#include "parquet/api/reader.h"
+#include "parquet/api/schema.h"
+
+int main(int argc, char** argv) {
+ bool help_flag = false;
+ std::string filename;
+
+ for (int i = 1; i < argc; i++) {
+ if (!std::strcmp(argv[i], "-?") || !std::strcmp(argv[i], "-h") ||
+ !std::strcmp(argv[i], "--help")) {
+ help_flag = true;
+ } else {
+ filename = argv[i];
+ }
+ }
+
+ if (argc != 2 || help_flag) {
+ std::cerr << "Usage: parquet-dump-schema [-h] [--help]"
+ << " <filename>" << std::endl;
+ return -1;
+ }
+
+ try {
+ std::unique_ptr<parquet::ParquetFileReader> reader =
+ parquet::ParquetFileReader::OpenFile(filename);
+ PrintSchema(reader->metadata()->schema()->schema_root().get(), std::cout);
+ } catch (const std::exception& e) {
+ std::cerr << "Parquet error: " << e.what() << std::endl;
+ return -1;
+ }
+
+ return 0;
+}
diff --git a/src/arrow/cpp/tools/parquet/parquet_reader.cc b/src/arrow/cpp/tools/parquet/parquet_reader.cc
new file mode 100644
index 000000000..c7db8e11a
--- /dev/null
+++ b/src/arrow/cpp/tools/parquet/parquet_reader.cc
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <iostream>
+#include <list>
+#include <memory>
+
+#include "parquet/api/reader.h"
+
+int main(int argc, char** argv) {
+ if (argc > 5 || argc < 2) {
+ std::cerr << "Usage: parquet-reader [--only-metadata] [--no-memory-map] [--json] "
+ << "[--dump] [--print-key-value-metadata] [--columns=...] <file>"
+ << std::endl;
+ return -1;
+ }
+
+ std::string filename;
+ bool print_values = true;
+ bool print_key_value_metadata = false;
+ bool memory_map = true;
+ bool format_json = false;
+ bool format_dump = false;
+
+ // Read command-line options
+ const std::string COLUMNS_PREFIX = "--columns=";
+ std::list<int> columns;
+
+ char *param, *value;
+ for (int i = 1; i < argc; i++) {
+ if ((param = std::strstr(argv[i], "--only-metadata"))) {
+ print_values = false;
+ } else if ((param = std::strstr(argv[i], "--print-key-value-metadata"))) {
+ print_key_value_metadata = true;
+ } else if ((param = std::strstr(argv[i], "--no-memory-map"))) {
+ memory_map = false;
+ } else if ((param = std::strstr(argv[i], "--json"))) {
+ format_json = true;
+ } else if ((param = std::strstr(argv[i], "--dump"))) {
+ format_dump = true;
+ } else if ((param = std::strstr(argv[i], COLUMNS_PREFIX.c_str()))) {
+ value = std::strtok(param + COLUMNS_PREFIX.length(), ",");
+ while (value) {
+ columns.push_back(std::atoi(value));
+ value = std::strtok(nullptr, ",");
+ }
+ } else {
+ filename = argv[i];
+ }
+ }
+
+ try {
+ std::unique_ptr<parquet::ParquetFileReader> reader =
+ parquet::ParquetFileReader::OpenFile(filename, memory_map);
+ parquet::ParquetFilePrinter printer(reader.get());
+ if (format_json) {
+ printer.JSONPrint(std::cout, columns, filename.c_str());
+ } else {
+ printer.DebugPrint(std::cout, columns, print_values, format_dump,
+ print_key_value_metadata, filename.c_str());
+ }
+ } catch (const std::exception& e) {
+ std::cerr << "Parquet error: " << e.what() << std::endl;
+ return -1;
+ }
+
+ return 0;
+}
diff --git a/src/arrow/cpp/tools/parquet/parquet_scan.cc b/src/arrow/cpp/tools/parquet/parquet_scan.cc
new file mode 100644
index 000000000..2a7721e58
--- /dev/null
+++ b/src/arrow/cpp/tools/parquet/parquet_scan.cc
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <ctime>
+#include <iostream>
+#include <list>
+#include <memory>
+
+#include "parquet/api/reader.h"
+
+int main(int argc, char** argv) {
+ if (argc > 4 || argc < 1) {
+ std::cerr << "Usage: parquet-scan [--batch-size=] [--columns=...] <file>"
+ << std::endl;
+ return -1;
+ }
+
+ std::string filename;
+
+ // Read command-line options
+ int batch_size = 256;
+ const std::string COLUMNS_PREFIX = "--columns=";
+ const std::string BATCH_SIZE_PREFIX = "--batch-size=";
+ std::vector<int> columns;
+ int num_columns = 0;
+
+ char *param, *value;
+ for (int i = 1; i < argc; i++) {
+ if ((param = std::strstr(argv[i], COLUMNS_PREFIX.c_str()))) {
+ value = std::strtok(param + COLUMNS_PREFIX.length(), ",");
+ while (value) {
+ columns.push_back(std::atoi(value));
+ value = std::strtok(nullptr, ",");
+ num_columns++;
+ }
+ } else if ((param = std::strstr(argv[i], BATCH_SIZE_PREFIX.c_str()))) {
+ value = std::strtok(param + BATCH_SIZE_PREFIX.length(), " ");
+ if (value) {
+ batch_size = std::atoi(value);
+ }
+ } else {
+ filename = argv[i];
+ }
+ }
+
+ try {
+ double total_time;
+ std::clock_t start_time = std::clock();
+ std::unique_ptr<parquet::ParquetFileReader> reader =
+ parquet::ParquetFileReader::OpenFile(filename);
+
+ int64_t total_rows = parquet::ScanFileContents(columns, batch_size, reader.get());
+
+ total_time = static_cast<double>(std::clock() - start_time) /
+ static_cast<double>(CLOCKS_PER_SEC);
+ std::cout << total_rows << " rows scanned in " << total_time << " seconds."
+ << std::endl;
+ } catch (const std::exception& e) {
+ std::cerr << "Parquet error: " << e.what() << std::endl;
+ return -1;
+ }
+
+ return 0;
+}
diff --git a/src/arrow/cpp/valgrind.supp b/src/arrow/cpp/valgrind.supp
new file mode 100644
index 000000000..8d2d5da90
--- /dev/null
+++ b/src/arrow/cpp/valgrind.supp
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+{
+ # Casting to/from boolean might read uninitialized data as the null bitmap isn't considered
+ <boolean_cast>
+ Memcheck:Cond
+ fun:*CastFunctor*BooleanType*
+}
+{
+ <llvm>:Conditional jump or move depends on uninitialised value(s)
+ Memcheck:Cond
+ ...
+ fun:*llvm*PassManager*
+}
+{
+ <re2>:Conditional jump or move depends on uninitialised value(s)
+ Memcheck:Cond
+ ...
+ fun:*re2*RE2*
+}
+{
+ <re2>:Use of uninitialised value of size 8
+ Memcheck:Value8
+ ...
+ fun:*re2*RE2*
+}
+{
+ <re2>:Conditional jump or move depends on uninitialised value(s)
+ Memcheck:Cond
+ ...
+ fun:*re2*Prog*
+}
+{
+ <re2>:Use of uninitialised value of size 8
+ Memcheck:Value8
+ ...
+ fun:*re2*Prog*
+}
diff --git a/src/arrow/cpp/vcpkg.json b/src/arrow/cpp/vcpkg.json
new file mode 100644
index 000000000..be8612d3b
--- /dev/null
+++ b/src/arrow/cpp/vcpkg.json
@@ -0,0 +1,46 @@
+{
+ "name": "arrow",
+ "version-string": "6.0.1",
+ "dependencies": [
+ "abseil",
+ {
+ "name": "aws-sdk-cpp",
+ "features": [
+ "config",
+ "cognito-identity",
+ "identity-management",
+ "s3",
+ "sts",
+ "transfer"
+ ]
+ },
+ "benchmark",
+ "boost-filesystem",
+ "boost-multiprecision",
+ "boost-system",
+ "brotli",
+ "bzip2",
+ "c-ares",
+ "curl",
+ "flatbuffers",
+ "gflags",
+ "glog",
+ "grpc",
+ "gtest",
+ "lz4",
+ "openssl",
+ "orc",
+ "protobuf",
+ "rapidjson",
+ "re2",
+ "snappy",
+ "thrift",
+ "utf8proc",
+ "zlib",
+ "zstd"
+ ],
+ "overrides": [
+ { "name": "gtest", "version": "1.10.0", "port-version": 4 }
+ ],
+ "builtin-baseline": "a267ab118c09f56f3dae96c9a4b3410820ad2f0b"
+}